"""Tests for the ``/health`` and ``/predict`` endpoints.

The application's ``get_db`` dependency is overridden so that every request
made through the test client uses an in-memory SQLite database instead of the
configured one.
"""

import pytest
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool

from app.main import app
from app.core.database import Base, get_db
from app.core.config import settings

# Setup in-memory SQLite database for testing.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    # The connection may be used from a thread other than the one that
    # created it, so SQLite's same-thread check is disabled.
    connect_args={"check_same_thread": False},
    # StaticPool reuses a single connection, so all sessions share the same
    # in-memory database rather than each getting an empty one.
    poolclass=StaticPool,
)
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)


def override_get_db():
    """Yield a session bound to the test engine and close it afterwards.

    Installed below as the replacement for the application's ``get_db``
    dependency.
    """
    # Create the session *before* entering ``try``: if the factory raises
    # there is nothing to close, and ``finally`` must not reference an
    # unbound ``db`` (which would mask the real error).
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()


app.dependency_overrides[get_db] = override_get_db

client = TestClient(app)


@pytest.fixture(scope="module")
def test_db():
    """Create all tables on the test engine for the module, then drop them."""
    # Create tables
    Base.metadata.create_all(bind=engine)
    yield
    # Drop tables
    Base.metadata.drop_all(bind=engine)


def test_health_check():
    """``GET /health`` returns 200 with the ``healthy`` status payload."""
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json() == {"status": "healthy"}


def test_predict_unauthorized():
    """``POST /predict`` without an ``X-API-KEY`` header is rejected with 422."""
    response = client.post("/predict", json={})
    assert response.status_code == 422


def test_predict_invalid_key():
    """``POST /predict`` with a wrong ``X-API-KEY`` is rejected with 401."""
    response = client.post("/predict", headers={"X-API-KEY": "wrong_key"}, json={})
    assert response.status_code == 401


def test_predict_success(test_db):
    """A valid payload with the configured API key is predicted and logged."""
    # Valid input data based on updated schema
    payload = {
        "age": 30,
        "genre": "M",
        "revenu_mensuel": 5000,
        "statut_marital": "Célibataire",
        "departement": "R&D",
        "poste": "Ingénieur",
        "nombre_experiences_precedentes": 2,
        "nombre_heures_travailless": 40,
        "annee_experience_totale": 5,
        "annees_dans_l_entreprise": 2,
        "annees_dans_le_poste_actuel": 1,
        "satisfaction_employee_environnement": 3,
        "note_evaluation_precedente": 3,
        "niveau_hierarchique_poste": 2,
        "satisfaction_employee_nature_travail": 3,
        "satisfaction_employee_equipe": 4,
        "satisfaction_employee_equilibre_pro_perso": 3,
        "note_evaluation_actuelle": 3,
        "heure_supplementaires": "Non",
        "augementation_salaire_precedente": "10-15%",
        "nombre_participation_pee": 0,
        "nb_formations_suivies": 1,
        "nombre_employee_sous_responsabilite": 0,
        "distance_domicile_travail": 10,
        "niveau_education": 3,
        "domaine_etude": "Sciences",
        "ayant_enfants": "Non",
        "frequence_deplacement": "Rare",
        "annees_depuis_la_derniere_promotion": 1,
        "annes_sous_responsable_actuel": 1,
    }

    response = client.post(
        "/predict",
        headers={"X-API-KEY": settings.API_KEY},
        json=payload,
    )

    # Report the raw body on failure; ``response.json()`` could itself raise
    # on a non-JSON error body and hide the actual status-code failure.
    assert response.status_code == 200, f"Response error: {response.text}"

    data = response.json()
    assert "prediction" in data
    assert "probability" in data

    # Verify DB insertion. ``try/finally`` guarantees the session is closed
    # even when one of the assertions below fails.
    db = TestingSessionLocal()
    try:
        log = db.query(Base.metadata.tables["prediction_logs"]).first()
        assert log is not None
        assert log.age == 30
        assert log.genre == "M"
    finally:
        db.close()