Spaces:
Sleeping
Sleeping
| from fastapi.testclient import TestClient | |
| from sqlalchemy import create_engine | |
| from sqlalchemy.orm import sessionmaker | |
| from main import app | |
| from config.db import get_db | |
| from models.ml import MLModel | |
| from models.ml_inputs import MLInput | |
| from models.ml_output import MLOutput | |
| import uuid | |
| from datetime import datetime, timezone | |
def test_simple_predict(tmp_path):
    """POST /predict for an existing, active model returns one prediction.

    Builds a throwaway SQLite database under pytest's ``tmp_path``, seeds it
    with three active ``MLModel`` rows, routes the app's ``get_db`` dependency
    to that database and checks the response shape for ``best_model``.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # SQLite refuses cross-thread use of a connection by default; the
        # request may be served on a different thread than the test body.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)
    session = SQLSession()

    def get_db_override():
        # Hand the app the same session the test seeds, so it sees the rows.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        # raise_server_exceptions=False: server errors come back as HTTP
        # responses instead of being re-raised inside the test.
        client = TestClient(app, raise_server_exceptions=False)
        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    is_active=True,
                ),
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000002"),
                    name="best_model",
                    description="XGB v1",
                    created_at=created,
                    is_active=True,
                ),
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000003"),
                    name="logistic_regression",
                    description="Logistic Regression",
                    created_at=created,
                    is_active=True,
                ),
            ]
        )
        session.commit()
        payload = {
            "model_name": "best_model",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11,
            }],
        }
        resp = client.post("/predict", json=payload)
    finally:
        # Undo global/app state even when seeding or the request fails, so
        # later tests never inherit this override or an open SQLite handle.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    # Include the body in the failure message instead of printing it.
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["model_name"] == "best_model"
    assert isinstance(data["results"], list)
    assert len(data["results"]) == 1
    result = data["results"][0]
    assert result["label"] == "reste_dans_l_entreprise"
    assert isinstance(result["proba"], float)
    assert 0 <= result["proba"] <= 1
def test_not_found_model(tmp_path):
    """POST /predict for a model name absent from the database returns 404.

    Only ``baseline`` is seeded; the request asks for ``best_model`` and the
    test checks the 404 status and the exact ``detail`` message.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # SQLite refuses cross-thread use of a connection by default; the
        # request may be served on a different thread than the test body.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)
    session = SQLSession()

    def get_db_override():
        # Hand the app the same session the test seeds, so it sees the rows.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        client = TestClient(app, raise_server_exceptions=False)
        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    is_active=True,
                ),
            ]
        )
        session.commit()
        payload = {
            # Deliberately a name that was NOT seeded above.
            "model_name": "best_model",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11,
            }],
        }
        resp = client.post("/predict", json=payload)
    finally:
        # Undo global/app state even when seeding or the request fails.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    assert resp.status_code == 404, resp.text
    data = resp.json()
    assert data["detail"] == "Modèle introuvable ou inactif"
def test_inactif_model(tmp_path):
    """POST /predict for a model that exists but is inactive returns 404.

    ``baseline`` is seeded with ``is_active=False`` and then requested; the
    test checks the 404 status and the exact ``detail`` message.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # SQLite refuses cross-thread use of a connection by default; the
        # request may be served on a different thread than the test body.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)
    session = SQLSession()

    def get_db_override():
        # Hand the app the same session the test seeds, so it sees the rows.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        client = TestClient(app, raise_server_exceptions=False)
        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    # The row exists but is disabled: this is what is tested.
                    is_active=False,
                ),
            ]
        )
        session.commit()
        payload = {
            "model_name": "baseline",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11,
            }],
        }
        resp = client.post("/predict", json=payload)
    finally:
        # Undo global/app state even when seeding or the request fails.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    # Include the body in the failure message instead of printing it.
    assert resp.status_code == 404, resp.text
    data = resp.json()
    assert data["detail"] == "Modèle introuvable ou inactif"
def test_list_models_returns_500_when_db_fails():
    """POST /predict answers 500 when every database call fails.

    NOTE(review): despite the name, this exercises ``/predict``, not a
    list-models endpoint; consider renaming once no one depends on the name.
    """

    class BrokenSession:
        """Stand-in session whose every public method raises ``RuntimeError``.

        Failing on *any* method (not just ``query``) keeps the test honest
        regardless of whether the route uses ``query``, ``execute``, ``get``,
        etc.; otherwise a 500 could come from an ``AttributeError`` instead of
        the simulated outage.
        """

        def __getattr__(self, name):
            # Leave dunder lookups alone so introspection behaves normally.
            if name.startswith("__"):
                raise AttributeError(name)

            def _fail(*a, **kw):
                raise RuntimeError("DB is down")

            return _fail

    def get_db_override():
        yield BrokenSession()

    app.dependency_overrides[get_db] = get_db_override
    try:
        client = TestClient(app, raise_server_exceptions=False)
        payload = {
            "model_name": "baseline",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11,
            }],
        }
        resp = client.post("/predict", json=payload)
    finally:
        # Never leak the broken-session override into other tests.
        app.dependency_overrides.clear()

    assert resp.status_code == 500, resp.text