technova-ml-api / tests /test_api.py
github-actions
deploy: snapshot
5fa8558
import pytest
# Payload valide pour /predict
PAYLOAD_OK = {
"age": 41,
"genre": "homme",
"revenu_mensuel": 3993,
"statut_marital": "célibataire",
"departement": "commercial",
"poste": "cadre commercial",
"nombre_experiences_precedentes": 2,
"annees_dans_l_entreprise": 5,
"satisfaction_employee_environnement": 4,
"satisfaction_employee_nature_travail": 1,
"satisfaction_employee_equipe": 1,
"satisfaction_employee_equilibre_pro_perso": 1,
"heure_supplementaires": True,
"augmentation_salaire_precedente": 11,
"nombre_participation_pee": 0,
"nb_formations_suivies": 0,
"distance_domicile_travail": 1,
"niveau_education": 2,
"domaine_etude": "infra & cloud",
"frequence_deplacement": "occasionnel",
"annees_sous_responsable_actuel": 0,
"annees_dans_le_poste_actuel": 0,
"note_evaluation_actuelle": 0,
"note_evaluation_precedente": 0,
"annees_depuis_la_derniere_promotion": 0
}
# -------------------------------------------------------------------
# Utilitaire : injecter un état minimal dans l'app (pas de HF / pas de DB)
# -------------------------------------------------------------------
def _inject_dummy_state():
from app.main import app
class DummyModel:
def predict_proba(self, X):
return [[0.2, 0.8]] # proba classe 1
app.state.model = DummyModel()
app.state.threshold = 0.292
app.state.engine = None
# =========================
# /predict (POST)
# =========================
def test_post_predict_unauthorized_without_api_key(client):
r = client.post("/predict", json=PAYLOAD_OK)
assert r.status_code == 401
def test_post_predict_unauthorized_with_wrong_api_key(client):
r = client.post(
"/predict",
json=PAYLOAD_OK,
headers={"X-API-Key": "WRONG"},
)
assert r.status_code == 401
def test_post_predict_ok_with_api_key(client, auth_headers):
_inject_dummy_state()
r = client.post("/predict", json=PAYLOAD_OK, headers=auth_headers)
assert r.status_code == 200, r.text
body = r.json()
assert body["threshold"] == 0.292
assert body["prediction"] in (0, 1)
assert "proba" in body
# =========================
# /predict/{id} (GET)
# =========================
def test_get_predict_by_id_unauthorized_without_api_key(client):
r = client.get("/predict/7")
assert r.status_code == 401
def test_get_predict_by_id_unauthorized_with_wrong_api_key(client):
r = client.get("/predict/7", headers={"X-API-Key": "WRONG"})
assert r.status_code == 401
def test_get_predict_by_id_ok_with_api_key(client, auth_headers, monkeypatch):
_inject_dummy_state()
import app.main as main_module
def fake_run_predict_by_id(*, id_employee, model, threshold, engine):
return 0.55, 1, {"id_employee": id_employee}
monkeypatch.setattr(main_module, "run_predict_by_id", fake_run_predict_by_id)
r = client.get("/predict/7", headers=auth_headers)
assert r.status_code == 200, r.text
body = r.json()
assert body["threshold"] == 0.292
assert body["prediction"] in (0, 1)
assert "proba" in body