File size: 3,114 Bytes
5fa8558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108

import pytest

# Known-good request body for POST /predict, shared by the tests below.
PAYLOAD_OK = {
    "age": 41,
    "genre": "homme",
    "revenu_mensuel": 3993,
    "statut_marital": "célibataire",
    "departement": "commercial",
    "poste": "cadre commercial",
    "nombre_experiences_precedentes": 2,
    "annees_dans_l_entreprise": 5,
    "satisfaction_employee_environnement": 4,
    "satisfaction_employee_nature_travail": 1,
    "satisfaction_employee_equipe": 1,
    "satisfaction_employee_equilibre_pro_perso": 1,
    "heure_supplementaires": True,
    "augmentation_salaire_precedente": 11,
    "nombre_participation_pee": 0,
    "nb_formations_suivies": 0,
    "distance_domicile_travail": 1,
    "niveau_education": 2,
    "domaine_etude": "infra & cloud",
    "frequence_deplacement": "occasionnel",
    "annees_sous_responsable_actuel": 0,
    "annees_dans_le_poste_actuel": 0,
    "note_evaluation_actuelle": 0,
    "note_evaluation_precedente": 0,
    "annees_depuis_la_derniere_promotion": 0,
}


# -------------------------------------------------------------------
# Helper: inject a minimal state into the app (no HF / no DB)
# -------------------------------------------------------------------
def _inject_dummy_state(monkeypatch=None):
    """Install a stub model, a fixed threshold and no DB engine on ``app.state``.

    The stub's ``predict_proba`` ignores its input and always returns
    ``[[0.2, 0.8]]`` (0.8 in the class-1 column). The threshold is set to
    0.292 and ``engine`` to ``None``.

    Args:
        monkeypatch: optional pytest ``monkeypatch`` fixture. When given, the
            attributes are set through it, so pytest restores the previous
            ``app.state`` values at test teardown. When omitted (the original
            behaviour), the attributes are assigned directly and stay on the
            module-level ``app`` object for the rest of the test session.
    """
    from app.main import app

    class DummyModel:
        def predict_proba(self, X):
            # A single row: [proba of class 0, proba of class 1].
            return [[0.2, 0.8]]

    # Insertion order matches the original assignment order.
    dummy_state = {
        "model": DummyModel(),
        "threshold": 0.292,
        "engine": None,
    }
    for name, value in dummy_state.items():
        if monkeypatch is None:
            setattr(app.state, name, value)
        else:
            # raising=False: the attribute may not exist yet on app.state.
            monkeypatch.setattr(app.state, name, value, raising=False)


# =========================
# /predict (POST)
# =========================
def test_post_predict_unauthorized_without_api_key(client):
    """POST /predict without an X-API-Key header is rejected with 401."""
    response = client.post("/predict", json=PAYLOAD_OK)
    assert response.status_code == 401


def test_post_predict_unauthorized_with_wrong_api_key(client):
    """POST /predict with an invalid X-API-Key value is rejected with 401."""
    bad_headers = {"X-API-Key": "WRONG"}
    response = client.post("/predict", json=PAYLOAD_OK, headers=bad_headers)
    assert response.status_code == 401


def test_post_predict_ok_with_api_key(client, auth_headers):
    """A valid API key yields 200 and a prediction derived from the stub model.

    ``_inject_dummy_state`` installs a model whose class-1 probability is 0.8
    and a threshold of 0.292. Assuming the endpoint compares the class-1
    probability with the threshold, the only correct prediction is 1.
    Checking ``prediction in (0, 1)`` alone would also pass if that
    comparison were inverted.
    """
    _inject_dummy_state()

    r = client.post("/predict", json=PAYLOAD_OK, headers=auth_headers)
    assert r.status_code == 200, r.text

    body = r.json()
    assert body["threshold"] == 0.292
    # 0.8 (class-1 proba from the stub) is above the 0.292 threshold.
    assert body["prediction"] == 1, body
    assert body["proba"] == pytest.approx(0.8), body


# =========================
# /predict/{id} (GET)
# =========================
def test_get_predict_by_id_unauthorized_without_api_key(client):
    """GET /predict/{id} without an X-API-Key header is rejected with 401."""
    response = client.get("/predict/7")
    assert response.status_code == 401


def test_get_predict_by_id_unauthorized_with_wrong_api_key(client):
    """GET /predict/{id} with an invalid X-API-Key value is rejected with 401."""
    bad_headers = {"X-API-Key": "WRONG"}
    response = client.get("/predict/7", headers=bad_headers)
    assert response.status_code == 401


def test_get_predict_by_id_ok_with_api_key(client, auth_headers, monkeypatch):
    """A valid API key yields 200 and the endpoint relays ``run_predict_by_id``.

    ``run_predict_by_id`` on ``app.main`` is replaced with a fake that records
    the ``id_employee`` it receives and returns ``(0.55, 1, {...})``. The test
    therefore checks both that the fake was really used for the id in the URL
    and that its result is what the response reports.
    """
    _inject_dummy_state()

    import app.main as main_module

    seen_ids = []

    def fake_run_predict_by_id(*, id_employee, model, threshold, engine):
        seen_ids.append(id_employee)
        return 0.55, 1, {"id_employee": id_employee}

    monkeypatch.setattr(main_module, "run_predict_by_id", fake_run_predict_by_id)

    r = client.get("/predict/7", headers=auth_headers)
    assert r.status_code == 200, r.text

    # The fake must have been called exactly once. The id is compared as text
    # so the test does not depend on the path parameter's declared type.
    assert len(seen_ids) == 1, seen_ids
    assert str(seen_ids[0]) == "7"

    body = r.json()
    assert body["threshold"] == 0.292
    assert body["prediction"] == 1, body
    assert body["proba"] == pytest.approx(0.55), body