File size: 4,599 Bytes
5fa8558
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# tests/test_services_predict.py
import pytest


def test_run_predict_manual_without_engine(monkeypatch):
    """
    Simple case: engine=None => no audit; returns proba/pred/enriched payload.

    The ML call (`predict_manual`) is faked so only the service's orchestration
    is exercised: inputs must be forwarded unchanged, the fake's outputs must be
    returned as-is, and `log_audit` must never be invoked.
    """
    from app.services import predict as predict_service

    dummy_model = object()
    received = []

    # Fake predict_manual (ML): records what the service forwards to it.
    def fake_predict_manual(payload, model, threshold):
        received.append((payload, model, threshold))
        return 0.8, 1, {"x": 1, "enrich": True}

    monkeypatch.setattr(predict_service, "predict_manual", fake_predict_manual)

    # Without an engine there must be no audit write; record any call so the
    # "no audit" claim above is actually checked rather than assumed.
    audit_calls = []
    monkeypatch.setattr(
        predict_service,
        "log_audit",
        lambda *args, **kwargs: audit_calls.append((args, kwargs)),
    )

    proba, pred, payload_enrichi = predict_service.run_predict_manual(
        payload={"x": 1},
        model=dummy_model,
        threshold=0.3,
        engine=None,
    )

    assert proba == 0.8
    assert pred == 1
    assert payload_enrichi == {"x": 1, "enrich": True}

    # Inputs reach the ML layer unchanged, exactly once.
    assert len(received) == 1
    forwarded_payload, forwarded_model, forwarded_threshold = received[0]
    assert forwarded_payload == {"x": 1}
    assert forwarded_model is dummy_model
    assert forwarded_threshold == 0.3

    assert audit_calls == []


def test_run_predict_manual_with_engine_calls_audit(monkeypatch):
    """
    Engine supplied: the prediction must be written to the audit log exactly
    once, using the connection yielded by ``engine.begin()``.
    """
    from app.services import predict as predict_service

    # ML layer stub: fixed probability, class and enriched payload.
    monkeypatch.setattr(
        predict_service,
        "predict_manual",
        lambda payload, model, threshold: (0.2, 0, {"foo": "bar"}),
    )

    # Audit spy: every invocation is appended as a positional tuple.
    audit_log = []

    def record_audit(conn, payload, proba, prediction, threshold):
        audit_log.append((conn, payload, proba, prediction, threshold))
        return 123

    monkeypatch.setattr(predict_service, "log_audit", record_audit)

    class FakeEngine:
        # begin() hands back itself; entering it yields a sentinel "connection".
        def begin(self):
            return self

        def __enter__(self):
            return "dummy-conn"

        def __exit__(self, exc_type, exc, tb):
            return False

    result = predict_service.run_predict_manual(
        payload={"hello": "world"},
        model=object(),
        threshold=0.292,
        engine=FakeEngine(),
    )
    proba, pred, payload_enrichi = result

    assert (proba, pred) == (0.2, 0)
    assert payload_enrichi == {"foo": "bar"}
    assert len(audit_log) == 1
    assert audit_log[0] == ("dummy-conn", {"foo": "bar"}, 0.2, 0, 0.292)


def test_run_predict_by_id_not_found_raises_keyerror(monkeypatch):
    """
    Unknown id: the feature lookup yields None, so the service must raise KeyError.
    """
    from app.services import predict as predict_service

    def lookup_nothing(engine, id_employee):
        # Simulates an employee id with no matching features.
        return None

    monkeypatch.setattr(
        predict_service,
        "get_employee_features_by_id",
        lookup_nothing,
    )

    call_kwargs = dict(
        id_employee=999,
        model=object(),
        threshold=0.5,
        engine=object(),
    )
    with pytest.raises(KeyError):
        predict_service.run_predict_by_id(**call_kwargs)


def test_run_predict_by_id_with_engine_calls_audit_and_adds_id(monkeypatch):
    """
    Nominal case: fetch an employee, predict, write an audit entry, and add
    id_employee to payload_enrichi *before* the audit entry is written.
    """
    from app.services import predict as predict_service

    # Fake features fetch: records which id the service asked for.
    fetched_ids = []

    def fake_get_employee_features_by_id(engine, id_employee):
        fetched_ids.append(id_employee)
        return {"id_employee": id_employee, "age": 40}

    monkeypatch.setattr(predict_service, "get_employee_features_by_id", fake_get_employee_features_by_id)

    # Fake predict_from_employee_features
    def fake_predict_from_employee_features(employee, model, threshold):
        # Enriched payload deliberately lacks the id -> the service must add it.
        return 0.55, 1, {"age": employee["age"]}

    monkeypatch.setattr(predict_service, "predict_from_employee_features", fake_predict_from_employee_features)

    # Spy log_audit
    calls = {"count": 0, "payload": None, "rest": None}

    def fake_log_audit(conn, payload, proba, prediction, threshold):
        calls["count"] += 1
        # Snapshot the payload: keeping a reference to the same dict would let
        # a mutation made *after* logging (e.g. adding id_employee too late)
        # leak into the assertions below and hide the bug.
        calls["payload"] = dict(payload)
        calls["rest"] = (conn, proba, prediction, threshold)
        return 456

    monkeypatch.setattr(predict_service, "log_audit", fake_log_audit)

    class DummyEngine:
        # begin() returns itself; entering it yields a sentinel "connection".
        def begin(self):
            return self

        def __enter__(self):
            return "dummy-conn"

        def __exit__(self, exc_type, exc, tb):
            return False

    proba, pred, payload_enrichi = predict_service.run_predict_by_id(
        id_employee=7,
        model=object(),
        threshold=0.292,
        engine=DummyEngine(),
    )

    assert fetched_ids == [7]

    assert proba == 0.55
    assert pred == 1
    assert payload_enrichi["age"] == 40
    assert payload_enrichi["id_employee"] == 7

    assert calls["count"] == 1
    # Checked against the snapshot, so this proves the id was present at log time.
    assert calls["payload"]["id_employee"] == 7
    assert calls["payload"]["age"] == 40
    assert calls["rest"] == ("dummy-conn", 0.55, 1, 0.292)