Spaces:
Sleeping
Sleeping
| import pytest | |
| from fastapi.testclient import TestClient | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker | |
| from sqlalchemy.pool import StaticPool | |
| from app.main import app | |
| from app.core.database import Base, get_db | |
| from app.core.config import settings | |
# Tests run against a private in-memory SQLite database instead of the
# application's configured database. StaticPool hands out one shared
# connection, so every session sees the same in-memory database;
# check_same_thread=False lets that connection be used from threads other
# than the one that created it.
SQLALCHEMY_DATABASE_URL = "sqlite:///:memory:"

engine = create_engine(
    SQLALCHEMY_DATABASE_URL,
    poolclass=StaticPool,
    connect_args={"check_same_thread": False},
)

# Session factory bound to the test engine; used by override_get_db and by
# tests that inspect the database directly.
TestingSessionLocal = sessionmaker(bind=engine, autoflush=False, autocommit=False)
def override_get_db():
    """Yield a session bound to the in-memory test database.

    Intended to replace the application's ``get_db`` dependency so request
    handlers use the test engine. The session is closed once the generator
    is finalized, whether or not the consumer raised.
    """
    # Create the session outside ``try``: if construction fails, the original
    # error propagates instead of an UnboundLocalError raised by
    # ``db.close()`` in the ``finally`` clause.
    db = TestingSessionLocal()
    try:
        yield db
    finally:
        db.close()
# Route the app's ``get_db`` dependency to the test session generator, then
# build the single client shared by the tests in this module.
app.dependency_overrides.update({get_db: override_get_db})
client = TestClient(app)
@pytest.fixture
def test_db():
    """Create all ORM tables before a test and drop them afterwards.

    This must be a fixture: without ``@pytest.fixture`` pytest would collect
    this generator as a test of its own, and tests that request ``test_db``
    as a parameter would fail with "fixture 'test_db' not found".
    """
    Base.metadata.create_all(bind=engine)
    # Everything after ``yield`` runs as fixture teardown, even when the
    # test using the fixture fails.
    yield
    Base.metadata.drop_all(bind=engine)
def test_health_check():
    """GET /health answers 200 with a fixed JSON status payload."""
    resp = client.get("/health")
    expected_body = {"status": "healthy"}
    assert resp.status_code == 200
    assert resp.json() == expected_body
def test_predict_unauthorized():
    """POST /predict without an X-API-KEY header is rejected with 422."""
    # NOTE(review): 422 (validation error) rather than 401 suggests the API
    # key header is declared as a required request parameter — confirm in app.
    resp = client.post("/predict", json={})
    assert resp.status_code == 422
def test_predict_invalid_key():
    """POST /predict with an incorrect X-API-KEY header yields 401."""
    bad_headers = {"X-API-KEY": "wrong_key"}
    resp = client.post("/predict", headers=bad_headers, json={})
    assert resp.status_code == 401
def test_predict_success(test_db):
    """A valid payload with the configured API key returns a prediction.

    Also checks that a row was written to the ``prediction_logs`` table with
    some of the submitted values.
    """
    # Valid input data based on the updated schema.
    # NOTE(review): keys are presumably required to match the /predict request
    # schema exactly, including spellings such as "nombre_heures_travailless"
    # and "augementation_salaire_precedente" — confirm before renaming.
    payload = {
        "age": 30,
        "genre": "M",
        "revenu_mensuel": 5000,
        "statut_marital": "Célibataire",
        "departement": "R&D",
        "poste": "Ingénieur",
        "nombre_experiences_precedentes": 2,
        "nombre_heures_travailless": 40,
        "annee_experience_totale": 5,
        "annees_dans_l_entreprise": 2,
        "annees_dans_le_poste_actuel": 1,
        "satisfaction_employee_environnement": 3,
        "note_evaluation_precedente": 3,
        "niveau_hierarchique_poste": 2,
        "satisfaction_employee_nature_travail": 3,
        "satisfaction_employee_equipe": 4,
        "satisfaction_employee_equilibre_pro_perso": 3,
        "note_evaluation_actuelle": 3,
        "heure_supplementaires": "Non",
        "augementation_salaire_precedente": "10-15%",
        "nombre_participation_pee": 0,
        "nb_formations_suivies": 1,
        "nombre_employee_sous_responsabilite": 0,
        "distance_domicile_travail": 10,
        "niveau_education": 3,
        "domaine_etude": "Sciences",
        "ayant_enfants": "Non",
        "frequence_deplacement": "Rare",
        "annees_depuis_la_derniere_promotion": 1,
        "annes_sous_responsable_actuel": 1,
    }
    response = client.post(
        "/predict",
        headers={"X-API-KEY": settings.API_KEY},
        json=payload,
    )
    # Put the raw body in the assertion message: pytest reports it on failure,
    # and unlike ``response.json()`` it cannot raise on a non-JSON error body.
    assert response.status_code == 200, f"Response error: {response.text}"
    data = response.json()
    assert "prediction" in data
    assert "probability" in data

    # Verify the DB insertion. Close the session in ``finally`` so a failing
    # assertion does not leak it.
    db = TestingSessionLocal()
    try:
        log = db.query(Base.metadata.tables["prediction_logs"]).first()
        assert log is not None
        assert log.age == 30
        assert log.genre == "M"
    finally:
        db.close()