marintosti12 commited on
Commit
4ba8e3d
·
1 Parent(s): 9ad00b1

feat(test) : add some fonctionals tets

Browse files
alembic/env.py CHANGED
@@ -7,7 +7,8 @@ from sqlalchemy.ext.asyncio import create_async_engine, AsyncEngine
7
  from sqlalchemy import pool
8
 
9
  from src.config.db import Base
10
-
 
11
  # Alembic Config
12
  config = context.config
13
  if config.config_file_name is not None:
 
7
  from sqlalchemy import pool
8
 
9
  from src.config.db import Base
10
+ from dotenv import load_dotenv, find_dotenv
11
+ load_dotenv(find_dotenv())
12
  # Alembic Config
13
  config = context.config
14
  if config.config_file_name is not None:
alembic/versions/24251a13df00_ml_outputs.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ml outputs
2
+
3
+ Revision ID: 24251a13df00
4
+ Revises: ecd589af543e
5
+ Create Date: 2025-09-15 16:34:46.842373
6
+
7
+ """
8
+ from typing import Sequence, Union
9
+
10
+ from alembic import op
11
+ import sqlalchemy as sa
12
+ from sqlalchemy.dialects import postgresql
13
+
14
+
15
+ # revision identifiers, used by Alembic.
16
+ revision: str = '24251a13df00'
17
+ down_revision: Union[str, Sequence[str], None] = 'ecd589af543e'
18
+ branch_labels: Union[str, Sequence[str], None] = None
19
+ depends_on: Union[str, Sequence[str], None] = None
20
+
21
+
22
+ def upgrade() -> None:
23
+ op.create_table(
24
+ "ml_outputs",
25
+ sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True,
26
+ server_default=sa.text("gen_random_uuid()")),
27
+ sa.Column("created_at", sa.DateTime(timezone=True), nullable=False,
28
+ server_default=sa.text("TIMEZONE('utc', now())")),
29
+ sa.Column("input_id", postgresql.UUID(as_uuid=True), nullable=False),
30
+ sa.Column("prediction", sa.String(length=100), nullable=False),
31
+ sa.Column("prob", sa.Float(), nullable=True),
32
+ sa.Column("error", sa.String(length=500), nullable=True),
33
+ sa.ForeignKeyConstraint(["input_id"], ["ml_inputs.id"], ondelete="CASCADE"),
34
+ )
35
+
36
+
37
+ def downgrade() -> None:
38
+ op.drop_table("ml_outputs")
39
+
src/config/db.py CHANGED
@@ -13,3 +13,10 @@ class Base(DeclarativeBase):
13
 
14
  engine = create_engine(settings.DATABASE_URL, echo=True, future=True)
15
  SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
 
 
 
 
 
 
 
 
13
 
14
  engine = create_engine(settings.DATABASE_URL, echo=True, future=True)
15
  SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
16
+
17
+ def get_db():
18
+ db = SessionLocal()
19
+ try:
20
+ yield db
21
+ finally:
22
+ db.close()
src/controllers/home_controller.py CHANGED
@@ -1,27 +1,27 @@
1
- from fastapi import APIRouter, HTTPException
2
- from config.db import SessionLocal
3
  from models.ml import MLModel
 
4
 
5
  router = APIRouter()
6
 
7
  @router.get("/", tags=["models"])
8
- def list_ml_models():
9
  try:
10
- with SessionLocal() as s:
11
- rows = (
12
- s.query(MLModel)
13
- .order_by(MLModel.created_at.desc())
14
- .all()
15
- )
16
- return [
17
- {
18
- "id": str(r.id),
19
- "name": r.name,
20
- "description": r.description,
21
- "created_at": r.created_at.isoformat() if r.created_at else None,
22
- "is_active": r.is_active,
23
- }
24
- for r in rows
25
- ]
26
  except Exception as e:
27
  raise HTTPException(status_code=500, detail=str(e))
 
1
+ from fastapi import APIRouter, Depends, HTTPException
2
+ from config.db import get_db
3
  from models.ml import MLModel
4
+ from sqlalchemy.orm import Session
5
 
6
  router = APIRouter()
7
 
8
  @router.get("/", tags=["models"])
9
+ def list_ml_models(db: Session = Depends(get_db)):
10
  try:
11
+ rows = (
12
+ db.query(MLModel)
13
+ .order_by(MLModel.created_at.desc())
14
+ .all()
15
+ )
16
+ return [
17
+ {
18
+ "id": str(r.id),
19
+ "name": r.name,
20
+ "description": r.description,
21
+ "created_at": r.created_at.isoformat() if r.created_at else None,
22
+ "is_active": r.is_active,
23
+ }
24
+ for r in rows
25
+ ]
 
26
  except Exception as e:
27
  raise HTTPException(status_code=500, detail=str(e))
src/controllers/predict_controller.py CHANGED
@@ -1,11 +1,11 @@
1
- # src/controllers/predict_controller.py
2
- from fastapi import APIRouter, HTTPException
3
 
4
- from config.db import SessionLocal
5
  from models.ml import MLModel
6
 
7
  # Schemas
8
  from models.ml_inputs import MLInput
 
9
 
10
  import pandas as pd
11
  from model_loader import load_model
@@ -13,32 +13,31 @@ from features import compute_features
13
  from schemas.PredictItemResult import PredictItemResult
14
  from schemas.PredictResponse import PredictResponse
15
  from schemas.PredictRequest import PredictRequest
 
16
 
17
  router = APIRouter(prefix="/predict", tags=["inference"])
18
 
19
- # (optionnel) mapping lisible des classes
20
  LABELS = {
21
  "0": "reste_dans_l_entreprise",
22
  "1": "parti_de_l_entreprise",
23
  }
24
 
25
- # --------- Route ----------
26
  @router.post("/", response_model=PredictResponse)
27
- def batch_predict(payload: PredictRequest):
28
- with SessionLocal() as s:
29
- row = (
30
- s.query(MLModel)
31
- .filter(MLModel.name == payload.model_name)
32
- .first()
33
- )
34
-
35
- objs = [MLInput(**x.model_dump()) for x in payload.inputs]
36
- s.add_all(objs)
37
- s.commit()
38
 
 
 
 
 
39
 
40
- if not row or getattr(row, "is_active", True) is False:
41
- raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif")
42
 
43
  try:
44
  m = load_model(payload.model_name)
@@ -53,16 +52,33 @@ def batch_predict(payload: PredictRequest):
53
 
54
  probas = m.predict_proba(X)
55
  classes = getattr(m, "classes_", None)
56
- for p in probas:
 
57
  i = int(p.argmax())
58
  key = str(classes[i]) if classes is not None else str(i)
59
  label = LABELS.get(key, key)
60
- results.append(PredictItemResult(label=label, proba=float(p[i])))
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  except Exception as e:
 
63
  raise HTTPException(status_code=400, detail=f"Erreur pendant la prédiction: {e}")
64
 
65
  return PredictResponse(
66
  model_name=payload.model_name,
67
  results=results,
68
  )
 
 
 
1
+ from fastapi import APIRouter, Depends, HTTPException
 
2
 
3
+ from config.db import SessionLocal, get_db
4
  from models.ml import MLModel
5
 
6
  # Schemas
7
  from models.ml_inputs import MLInput
8
+ from models.ml_output import MLOutput
9
 
10
  import pandas as pd
11
  from model_loader import load_model
 
13
  from schemas.PredictItemResult import PredictItemResult
14
  from schemas.PredictResponse import PredictResponse
15
  from schemas.PredictRequest import PredictRequest
16
+ from sqlalchemy.orm import Session
17
 
18
  router = APIRouter(prefix="/predict", tags=["inference"])
19
 
 
20
  LABELS = {
21
  "0": "reste_dans_l_entreprise",
22
  "1": "parti_de_l_entreprise",
23
  }
24
 
 
25
  @router.post("/", response_model=PredictResponse)
26
+ @router.post("/", response_model=PredictResponse)
27
+ def batch_predict(payload: PredictRequest, db: Session = Depends(get_db)):
28
+ row = (
29
+ db.query(MLModel)
30
+ .filter(MLModel.name == payload.model_name)
31
+ .first()
32
+ )
 
 
 
 
33
 
34
+ # --- stocker les inputs
35
+ objs = [MLInput(**x.model_dump()) for x in payload.inputs]
36
+ db.add_all(objs)
37
+ db.commit()
38
 
39
+ if not row or getattr(row, "is_active", True) is False:
40
+ raise HTTPException(status_code=404, detail="Modèle introuvable ou inactif")
41
 
42
  try:
43
  m = load_model(payload.model_name)
 
52
 
53
  probas = m.predict_proba(X)
54
  classes = getattr(m, "classes_", None)
55
+
56
+ for idx, p in enumerate(probas):
57
  i = int(p.argmax())
58
  key = str(classes[i]) if classes is not None else str(i)
59
  label = LABELS.get(key, key)
60
+
61
+ pred = PredictItemResult(label=label, proba=float(p[i]))
62
+ results.append(pred)
63
+
64
+ print(objs[idx].id)
65
+ db.add(
66
+ MLOutput(
67
+ input_id=objs[idx].id,
68
+ prediction=label,
69
+ prob=float(p[i]),
70
+ )
71
+ )
72
+
73
+ db.commit()
74
 
75
  except Exception as e:
76
+ db.rollback()
77
  raise HTTPException(status_code=400, detail=f"Erreur pendant la prédiction: {e}")
78
 
79
  return PredictResponse(
80
  model_name=payload.model_name,
81
  results=results,
82
  )
83
+
84
+
src/main.py CHANGED
@@ -1,19 +1,10 @@
1
 
2
  from fastapi import FastAPI
3
 
4
- from config.db import SessionLocal
5
-
6
 
7
  from controllers.home_controller import router as ml_home_router
8
  from controllers.predict_controller import router as predict_router
9
 
10
- def get_db():
11
- db = SessionLocal()
12
- try:
13
- yield db
14
- finally:
15
- db.close()
16
-
17
 
18
  app = FastAPI()
19
 
 
1
 
2
  from fastapi import FastAPI
3
 
 
 
4
 
5
  from controllers.home_controller import router as ml_home_router
6
  from controllers.predict_controller import router as predict_router
7
 
 
 
 
 
 
 
 
8
 
9
  app = FastAPI()
10
 
src/models/ml_output.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/models/ml_output.py
2
+ from __future__ import annotations
3
+ import uuid
4
+ from datetime import datetime
5
+ from typing import Optional
6
+
7
+ from sqlalchemy import String, Float, DateTime, ForeignKey
8
+ from sqlalchemy.orm import Mapped, mapped_column
9
+ from sqlalchemy.sql import func
10
+ from sqlalchemy.dialects.postgresql import UUID
11
+
12
+ from .base import Base
13
+
14
+
15
+ class MLOutput(Base):
16
+ __tablename__ = "ml_outputs"
17
+
18
+ id: Mapped[uuid.UUID] = mapped_column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
19
+
20
+
21
+ created_at: Mapped[datetime] = mapped_column(
22
+ DateTime(timezone=True), server_default=func.now()
23
+ )
24
+
25
+ input_id: Mapped[str] = mapped_column(
26
+ UUID(as_uuid=True), ForeignKey("ml_inputs.id", ondelete="CASCADE"), nullable=False
27
+ )
28
+
29
+ prediction: Mapped[str] = mapped_column(String(255), nullable=False)
30
+ prob: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
31
+ error: Mapped[Optional[str]] = mapped_column(String(500), nullable=True)
src/seeds/ml_models_seed.py CHANGED
@@ -1,17 +1,15 @@
1
- # src/seeds/ml_models_seed.py
2
  import os
3
  from datetime import datetime, timezone
4
  from sqlalchemy import create_engine, text
5
  from sqlalchemy.orm import Session
6
 
7
- # (optionnel) charge .env automatiquement
8
  try:
9
  from dotenv import load_dotenv
10
  load_dotenv()
11
  except Exception:
12
  pass
13
 
14
- DATABASE_URL = os.environ["DATABASE_URL"] # ex: postgresql+psycopg2://...
15
  engine = create_engine(DATABASE_URL, future=True)
16
 
17
  UPSERT = text("""
 
 
1
  import os
2
  from datetime import datetime, timezone
3
  from sqlalchemy import create_engine, text
4
  from sqlalchemy.orm import Session
5
 
 
6
  try:
7
  from dotenv import load_dotenv
8
  load_dotenv()
9
  except Exception:
10
  pass
11
 
12
+ DATABASE_URL = os.environ["DATABASE_URL"]
13
  engine = create_engine(DATABASE_URL, future=True)
14
 
15
  UPSERT = text("""
tests/test_home.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.orm import sessionmaker
4
+
5
+ from main import app
6
+ from config.db import get_db
7
+
8
+ from config.db import Base
9
+ from models.ml import MLModel
10
+
11
+ import uuid
12
+ from datetime import datetime, timezone
13
+
14
+
15
+ def test_list_models_simple(tmp_path):
16
+ db_path = tmp_path / "testing.db"
17
+ engine = create_engine(
18
+ f"sqlite:///{db_path}",
19
+ connect_args={"check_same_thread": False},
20
+ future=True,
21
+ )
22
+ SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
23
+
24
+ MLModel.metadata.create_all(engine)
25
+
26
+ session = SQLSession()
27
+
28
+ def get_db_override():
29
+ return session
30
+
31
+ app.dependency_overrides[get_db] = get_db_override
32
+
33
+ client = TestClient(app, raise_server_exceptions=False)
34
+
35
+ created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
36
+ session.add_all(
37
+ [
38
+ MLModel(
39
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
40
+ name="baseline",
41
+ description="Baseline model",
42
+ created_at=created,
43
+ is_active=True,
44
+ ),
45
+ MLModel(
46
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000002"),
47
+ name="best_model",
48
+ description="XGB v1",
49
+ created_at=created,
50
+ is_active=True,
51
+ ),
52
+ MLModel(
53
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000003"),
54
+ name="logistic_regression",
55
+ description="Logistic Regression",
56
+ created_at=created,
57
+ is_active=True,
58
+ ),
59
+ ]
60
+ )
61
+ session.commit()
62
+
63
+ resp = client.get("/")
64
+
65
+
66
+ app.dependency_overrides.clear()
67
+ session.close()
68
+
69
+ assert resp.status_code == 200
70
+ data = resp.json()
71
+ names = {row["name"] for row in data}
72
+ assert names == {"baseline", "best_model", 'logistic_regression'}
73
+
74
+
75
+ def test_list_models_returns_500_when_db_fails():
76
+ class BrokenSession:
77
+ def query(self, *a, **kw):
78
+ raise RuntimeError("DB is down")
79
+
80
+ def get_db_override():
81
+ yield BrokenSession()
82
+
83
+ app.dependency_overrides[get_db] = get_db_override
84
+ client = TestClient(app, raise_server_exceptions=False)
85
+
86
+ resp = client.get("/")
87
+
88
+ app.dependency_overrides.clear()
89
+
90
+ assert resp.status_code == 500
91
+ body = resp.json()
92
+ assert "DB is down" in body["detail"]
93
+
tests/test_main.py DELETED
@@ -1,8 +0,0 @@
1
- from fastapi.testclient import TestClient
2
-
3
- from main import app
4
-
5
- client = TestClient(app)
6
-
7
- def test_root_ok():
8
- assert True
 
 
 
 
 
 
 
 
 
tests/test_predict.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi.testclient import TestClient
2
+ from sqlalchemy import create_engine
3
+ from sqlalchemy.orm import sessionmaker
4
+
5
+ from main import app
6
+ from config.db import get_db
7
+
8
+ from config.db import Base
9
+ from models.ml import MLModel
10
+ from models.ml_inputs import MLInput
11
+ from models.ml_output import MLOutput
12
+
13
+ import uuid
14
+ from datetime import datetime, timezone
15
+
16
+
17
+ def test_simple_predict(tmp_path):
18
+ db_path = tmp_path / "testing.db"
19
+ engine = create_engine(
20
+ f"sqlite:///{db_path}",
21
+ connect_args={"check_same_thread": False},
22
+ future=True,
23
+ )
24
+ SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
25
+
26
+ MLModel.metadata.create_all(engine)
27
+ MLInput.metadata.create_all(engine)
28
+ MLOutput.metadata.create_all(engine)
29
+
30
+ session = SQLSession()
31
+
32
+ def get_db_override():
33
+ return session
34
+
35
+
36
+ app.dependency_overrides[get_db] = get_db_override
37
+
38
+ client = TestClient(app, raise_server_exceptions=False)
39
+
40
+ created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
41
+ session.add_all(
42
+ [
43
+ MLModel(
44
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
45
+ name="baseline",
46
+ description="Baseline model",
47
+ created_at=created,
48
+ is_active=True,
49
+ ),
50
+ MLModel(
51
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000002"),
52
+ name="best_model",
53
+ description="XGB v1",
54
+ created_at=created,
55
+ is_active=True,
56
+ ),
57
+ MLModel(
58
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000003"),
59
+ name="logistic_regression",
60
+ description="Logistic Regression",
61
+ created_at=created,
62
+ is_active=True,
63
+ ),
64
+ ]
65
+ )
66
+ session.commit()
67
+
68
+
69
+ payload = {
70
+ "model_name": "best_model",
71
+ "inputs": [{
72
+ "id_employee": 123,
73
+ "age": 35,
74
+ "genre": "Homme",
75
+ "revenu_mensuel": 4200,
76
+ "statut_marital": "Célibataire",
77
+ "departement": "Ventes",
78
+ "poste": "Commercial",
79
+ "nombre_experiences_precedentes": 2,
80
+ "nombre_heures_travailless": 40,
81
+ "annee_experience_totale": 5,
82
+ "annees_dans_l_entreprise": 2,
83
+ "annees_dans_le_poste_actuel": 1,
84
+ "nombre_participation_pee": 1,
85
+ "nb_formations_suivies": 3,
86
+ "nombre_employee_sous_responsabilite": 0,
87
+ "code_sondage": 7,
88
+ "distance_domicile_travail": 12,
89
+ "niveau_education": 3,
90
+ "domaine_etude": "Marketing",
91
+ "ayant_enfants": "Non",
92
+ "frequence_deplacement": "Rarement",
93
+ "annees_depuis_la_derniere_promotion": 0,
94
+ "annes_sous_responsable_actuel": 1,
95
+ "satisfaction_employee_environnement": 3,
96
+ "note_evaluation_precedente": 4,
97
+ "niveau_hierarchique_poste": 2,
98
+ "satisfaction_employee_nature_travail": 3,
99
+ "satisfaction_employee_equipe": 4,
100
+ "satisfaction_employee_equilibre_pro_perso": 3,
101
+ "eval_number": "E2",
102
+ "note_evaluation_actuelle": 4,
103
+ "heure_supplementaires": "Non",
104
+ "augementation_salaire_precedente": 11
105
+ }]
106
+ }
107
+
108
+
109
+ resp = client.post("/predict", json=payload)
110
+
111
+ print("STATUS:", resp.status_code)
112
+ print("BODY:", resp.text)
113
+
114
+ app.dependency_overrides.clear()
115
+ session.close()
116
+
117
+ assert resp.status_code == 200
118
+ data = resp.json()
119
+ assert data["model_name"] == "best_model"
120
+ assert isinstance(data["results"], list)
121
+ assert len(data["results"]) == 1
122
+
123
+ result = data["results"][0]
124
+ assert result["label"] == "reste_dans_l_entreprise"
125
+ assert isinstance(result["proba"], float)
126
+ assert 0 <= result["proba"] <= 1
127
+
128
+
129
+ def test_not_found_model(tmp_path):
130
+ db_path = tmp_path / "testing.db"
131
+ engine = create_engine(
132
+ f"sqlite:///{db_path}",
133
+ connect_args={"check_same_thread": False},
134
+ future=True,
135
+ )
136
+ SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
137
+
138
+ MLModel.metadata.create_all(engine)
139
+ MLInput.metadata.create_all(engine)
140
+ MLOutput.metadata.create_all(engine)
141
+
142
+ session = SQLSession()
143
+
144
+ def get_db_override():
145
+ return session
146
+
147
+
148
+ app.dependency_overrides[get_db] = get_db_override
149
+
150
+ client = TestClient(app, raise_server_exceptions=False)
151
+
152
+ created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
153
+ session.add_all(
154
+ [
155
+ MLModel(
156
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
157
+ name="baseline",
158
+ description="Baseline model",
159
+ created_at=created,
160
+ is_active=True,
161
+ ),
162
+ ]
163
+ )
164
+ session.commit()
165
+
166
+
167
+ payload = {
168
+ "model_name": "best_model",
169
+ "inputs": [{
170
+ "id_employee": 123,
171
+ "age": 35,
172
+ "genre": "Homme",
173
+ "revenu_mensuel": 4200,
174
+ "statut_marital": "Célibataire",
175
+ "departement": "Ventes",
176
+ "poste": "Commercial",
177
+ "nombre_experiences_precedentes": 2,
178
+ "nombre_heures_travailless": 40,
179
+ "annee_experience_totale": 5,
180
+ "annees_dans_l_entreprise": 2,
181
+ "annees_dans_le_poste_actuel": 1,
182
+ "nombre_participation_pee": 1,
183
+ "nb_formations_suivies": 3,
184
+ "nombre_employee_sous_responsabilite": 0,
185
+ "code_sondage": 7,
186
+ "distance_domicile_travail": 12,
187
+ "niveau_education": 3,
188
+ "domaine_etude": "Marketing",
189
+ "ayant_enfants": "Non",
190
+ "frequence_deplacement": "Rarement",
191
+ "annees_depuis_la_derniere_promotion": 0,
192
+ "annes_sous_responsable_actuel": 1,
193
+ "satisfaction_employee_environnement": 3,
194
+ "note_evaluation_precedente": 4,
195
+ "niveau_hierarchique_poste": 2,
196
+ "satisfaction_employee_nature_travail": 3,
197
+ "satisfaction_employee_equipe": 4,
198
+ "satisfaction_employee_equilibre_pro_perso": 3,
199
+ "eval_number": "E2",
200
+ "note_evaluation_actuelle": 4,
201
+ "heure_supplementaires": "Non",
202
+ "augementation_salaire_precedente": 11
203
+ }]
204
+ }
205
+
206
+
207
+ resp = client.post("/predict", json=payload)
208
+
209
+
210
+ app.dependency_overrides.clear()
211
+ session.close()
212
+
213
+ assert resp.status_code == 404
214
+ data = resp.json()
215
+ assert data["detail"] == "Modèle introuvable ou inactif"
216
+
217
+
218
+ def test_inactif_model(tmp_path):
219
+ db_path = tmp_path / "testing.db"
220
+ engine = create_engine(
221
+ f"sqlite:///{db_path}",
222
+ connect_args={"check_same_thread": False},
223
+ future=True,
224
+ )
225
+ SQLSession = sessionmaker(bind=engine, autoflush=False, autocommit=False, future=True)
226
+
227
+ MLModel.metadata.create_all(engine)
228
+ MLInput.metadata.create_all(engine)
229
+ MLOutput.metadata.create_all(engine)
230
+
231
+ session = SQLSession()
232
+
233
+ def get_db_override():
234
+ return session
235
+
236
+
237
+ app.dependency_overrides[get_db] = get_db_override
238
+
239
+ client = TestClient(app, raise_server_exceptions=False)
240
+
241
+ created = datetime(2025, 9, 15, 10, 11, 3, 950802, tzinfo=timezone.utc)
242
+ session.add_all(
243
+ [
244
+ MLModel(
245
+ id=uuid.UUID("5b1c7b3a-0000-4000-8000-000000000001"),
246
+ name="baseline",
247
+ description="Baseline model",
248
+ created_at=created,
249
+ is_active=False,
250
+ ),
251
+ ]
252
+ )
253
+ session.commit()
254
+
255
+
256
+ payload = {
257
+ "model_name": "baseline",
258
+ "inputs": [{
259
+ "id_employee": 123,
260
+ "age": 35,
261
+ "genre": "Homme",
262
+ "revenu_mensuel": 4200,
263
+ "statut_marital": "Célibataire",
264
+ "departement": "Ventes",
265
+ "poste": "Commercial",
266
+ "nombre_experiences_precedentes": 2,
267
+ "nombre_heures_travailless": 40,
268
+ "annee_experience_totale": 5,
269
+ "annees_dans_l_entreprise": 2,
270
+ "annees_dans_le_poste_actuel": 1,
271
+ "nombre_participation_pee": 1,
272
+ "nb_formations_suivies": 3,
273
+ "nombre_employee_sous_responsabilite": 0,
274
+ "code_sondage": 7,
275
+ "distance_domicile_travail": 12,
276
+ "niveau_education": 3,
277
+ "domaine_etude": "Marketing",
278
+ "ayant_enfants": "Non",
279
+ "frequence_deplacement": "Rarement",
280
+ "annees_depuis_la_derniere_promotion": 0,
281
+ "annes_sous_responsable_actuel": 1,
282
+ "satisfaction_employee_environnement": 3,
283
+ "note_evaluation_precedente": 4,
284
+ "niveau_hierarchique_poste": 2,
285
+ "satisfaction_employee_nature_travail": 3,
286
+ "satisfaction_employee_equipe": 4,
287
+ "satisfaction_employee_equilibre_pro_perso": 3,
288
+ "eval_number": "E2",
289
+ "note_evaluation_actuelle": 4,
290
+ "heure_supplementaires": "Non",
291
+ "augementation_salaire_precedente": 11
292
+ }]
293
+ }
294
+
295
+
296
+ resp = client.post("/predict", json=payload)
297
+
298
+ print("STATUS:", resp.status_code)
299
+ print("BODY:", resp.text)
300
+
301
+ app.dependency_overrides.clear()
302
+ session.close()
303
+
304
+ assert resp.status_code == 404
305
+ data = resp.json()
306
+ assert data["detail"] == "Modèle introuvable ou inactif"
307
+
308
+
309
+ def test_list_models_returns_500_when_db_fails():
310
+ class BrokenSession:
311
+ def query(self, *a, **kw):
312
+ raise RuntimeError("DB is down")
313
+
314
+ def get_db_override():
315
+ yield BrokenSession()
316
+
317
+ app.dependency_overrides[get_db] = get_db_override
318
+ client = TestClient(app, raise_server_exceptions=False)
319
+
320
+ payload = {
321
+ "model_name": "baseline",
322
+ "inputs": [{
323
+ "id_employee": 123,
324
+ "age": 35,
325
+ "genre": "Homme",
326
+ "revenu_mensuel": 4200,
327
+ "statut_marital": "Célibataire",
328
+ "departement": "Ventes",
329
+ "poste": "Commercial",
330
+ "nombre_experiences_precedentes": 2,
331
+ "nombre_heures_travailless": 40,
332
+ "annee_experience_totale": 5,
333
+ "annees_dans_l_entreprise": 2,
334
+ "annees_dans_le_poste_actuel": 1,
335
+ "nombre_participation_pee": 1,
336
+ "nb_formations_suivies": 3,
337
+ "nombre_employee_sous_responsabilite": 0,
338
+ "code_sondage": 7,
339
+ "distance_domicile_travail": 12,
340
+ "niveau_education": 3,
341
+ "domaine_etude": "Marketing",
342
+ "ayant_enfants": "Non",
343
+ "frequence_deplacement": "Rarement",
344
+ "annees_depuis_la_derniere_promotion": 0,
345
+ "annes_sous_responsable_actuel": 1,
346
+ "satisfaction_employee_environnement": 3,
347
+ "note_evaluation_precedente": 4,
348
+ "niveau_hierarchique_poste": 2,
349
+ "satisfaction_employee_nature_travail": 3,
350
+ "satisfaction_employee_equipe": 4,
351
+ "satisfaction_employee_equilibre_pro_perso": 3,
352
+ "eval_number": "E2",
353
+ "note_evaluation_actuelle": 4,
354
+ "heure_supplementaires": "Non",
355
+ "augementation_salaire_precedente": 11
356
+ }]
357
+ }
358
+
359
+
360
+ resp = client.post("/predict", json=payload)
361
+
362
+ app.dependency_overrides.clear()
363
+
364
+ assert resp.status_code == 500
365
+