marintosti12
fix(folder) : rename test folder
67092ff
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from main import app
from config.db import get_db
from models.ml import MLModel
from models.ml_inputs import MLInput
from models.ml_output import MLOutput
import uuid
from datetime import datetime, timezone
def test_simple_predict(tmp_path):
    """Check that POST /predict succeeds for an active, registered model.

    A throwaway SQLite database is created under ``tmp_path`` and seeded with
    three active models.  The ``get_db`` dependency is overridden so the app
    uses that database, then a single employee record is posted for
    ``best_model``.  The response must be HTTP 200 and contain exactly one
    result labelled ``reste_dans_l_entreprise`` with a float probability in
    ``[0, 1]``.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # NOTE(review): presumably required because the app handles the
        # request on a different thread than the one that opened the
        # connection -- confirm.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)

    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)

    session = SQLSession()

    def get_db_override():
        # Hand the app the very session the test seeds its data through.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        # raise_server_exceptions=False: unhandled server errors come back as
        # an HTTP response instead of being re-raised inside the test.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    is_active=True,
                ),
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000002"),
                    name="best_model",
                    description="XGB v1",
                    created_at=created,
                    is_active=True,
                ),
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000003"),
                    name="logistic_regression",
                    description="Logistic Regression",
                    created_at=created,
                    is_active=True,
                ),
            ]
        )
        session.commit()

        payload = {
            "model_name": "best_model",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11
            }]
        }

        resp = client.post("/predict", json=payload)
        # Captured by pytest; shown only when the test fails.
        print("STATUS:", resp.status_code)
        print("BODY:", resp.text)
    finally:
        # Always restore global app state and release DB resources, even if
        # seeding or the request raised, so later tests are not affected.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    assert resp.status_code == 200
    data = resp.json()

    assert data["model_name"] == "best_model"
    assert isinstance(data["results"], list)
    assert len(data["results"]) == 1

    result = data["results"][0]
    assert result["label"] == "reste_dans_l_entreprise"
    assert isinstance(result["proba"], float)
    assert 0 <= result["proba"] <= 1
def test_not_found_model(tmp_path):
    """Check that POST /predict returns 404 for a model that is not registered.

    The temporary SQLite database (under ``tmp_path``) only contains the
    ``baseline`` model, while the request asks for ``best_model``.  The
    response must be HTTP 404 with the detail message
    ``"Modèle introuvable ou inactif"``.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # NOTE(review): presumably required because the app handles the
        # request on a different thread than the one that opened the
        # connection -- confirm.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)

    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)

    session = SQLSession()

    def get_db_override():
        # Hand the app the very session the test seeds its data through.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        # raise_server_exceptions=False: unhandled server errors come back as
        # an HTTP response instead of being re-raised inside the test.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    is_active=True,
                ),
            ]
        )
        session.commit()

        payload = {
            # Deliberately a name that was NOT seeded above.
            "model_name": "best_model",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11
            }]
        }

        resp = client.post("/predict", json=payload)
    finally:
        # Always restore global app state and release DB resources, even if
        # seeding or the request raised, so later tests are not affected.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    assert resp.status_code == 404
    data = resp.json()

    assert data["detail"] == "Modèle introuvable ou inactif"
def test_inactif_model(tmp_path):
    """Check that POST /predict returns 404 for a registered but inactive model.

    The temporary SQLite database (under ``tmp_path``) contains ``baseline``
    with ``is_active=False`` and the request targets that same model.  The
    response must be HTTP 404 with the detail message
    ``"Modèle introuvable ou inactif"``.
    """
    db_path = tmp_path / "testing.db"
    engine = create_engine(
        f"sqlite:///{db_path}",
        # NOTE(review): presumably required because the app handles the
        # request on a different thread than the one that opened the
        # connection -- confirm.
        connect_args={"check_same_thread": False},
        future=True,
    )
    SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)

    MLModel.metadata.create_all(engine)
    MLInput.metadata.create_all(engine)
    MLOutput.metadata.create_all(engine)

    session = SQLSession()

    def get_db_override():
        # Hand the app the very session the test seeds its data through.
        return session

    app.dependency_overrides[get_db] = get_db_override
    try:
        # raise_server_exceptions=False: unhandled server errors come back as
        # an HTTP response instead of being re-raised inside the test.
        client = TestClient(app, raise_server_exceptions=False)

        created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
        session.add_all(
            [
                MLModel(
                    id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
                    name="baseline",
                    description="Baseline model",
                    created_at=created,
                    # The model exists but is switched off.
                    is_active=False,
                ),
            ]
        )
        session.commit()

        payload = {
            "model_name": "baseline",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11
            }]
        }

        resp = client.post("/predict", json=payload)
        # Captured by pytest; shown only when the test fails.
        print("STATUS:", resp.status_code)
        print("BODY:", resp.text)
    finally:
        # Always restore global app state and release DB resources, even if
        # seeding or the request raised, so later tests are not affected.
        app.dependency_overrides.clear()
        session.close()
        engine.dispose()

    assert resp.status_code == 404
    data = resp.json()

    assert data["detail"] == "Modèle introuvable ou inactif"
def test_list_models_returns_500_when_db_fails():
    """Check that POST /predict answers HTTP 500 when the database is unusable.

    NOTE(review): despite the ``list_models`` in its name, this test posts to
    ``/predict``; the name is kept so existing test selections keep working.

    The ``get_db`` dependency is replaced by a generator yielding a stub
    session whose ``query`` method raises ``RuntimeError``.
    """

    class BrokenSession:
        """Stand-in session whose ``query`` always fails."""

        def query(self, *a, **kw):
            raise RuntimeError("DB is down")

    def get_db_override():
        yield BrokenSession()

    app.dependency_overrides[get_db] = get_db_override
    try:
        # raise_server_exceptions=False: the unhandled RuntimeError must come
        # back as an HTTP response instead of being re-raised inside the test.
        client = TestClient(app, raise_server_exceptions=False)

        payload = {
            "model_name": "baseline",
            "inputs": [{
                "id_employee": 123,
                "age": 35,
                "genre": "Homme",
                "revenu_mensuel": 4200,
                "statut_marital": "Célibataire",
                "departement": "Ventes",
                "poste": "Commercial",
                "nombre_experiences_precedentes": 2,
                "nombre_heures_travailless": 40,
                "annee_experience_totale": 5,
                "annees_dans_l_entreprise": 2,
                "annees_dans_le_poste_actuel": 1,
                "nombre_participation_pee": 1,
                "nb_formations_suivies": 3,
                "nombre_employee_sous_responsabilite": 0,
                "code_sondage": 7,
                "distance_domicile_travail": 12,
                "niveau_education": 3,
                "domaine_etude": "Marketing",
                "ayant_enfants": "Non",
                "frequence_deplacement": "Rarement",
                "annees_depuis_la_derniere_promotion": 0,
                "annes_sous_responsable_actuel": 1,
                "satisfaction_employee_environnement": 3,
                "note_evaluation_precedente": 4,
                "niveau_hierarchique_poste": 2,
                "satisfaction_employee_nature_travail": 3,
                "satisfaction_employee_equipe": 4,
                "satisfaction_employee_equilibre_pro_perso": 3,
                "eval_number": "E2",
                "note_evaluation_actuelle": 4,
                "heure_supplementaires": "Non",
                "augementation_salaire_precedente": 11
            }]
        }

        resp = client.post("/predict", json=payload)
    finally:
        # Always remove the broken dependency, even if the request itself
        # raised, so it cannot leak into later tests.
        app.dependency_overrides.clear()

    assert resp.status_code == 500