# app.py
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import os
import re
import sys
import types

import joblib
import numpy as np
import pandas as pd
from typing import Dict, List, Any
import preprocess_utils as pu
APP_VERSION = "1.0.0"
PIPELINE_PATH = "artifacts/model.joblib"
# ---------- MONKEY-PATCH for pickles trained in a notebook ----------
def _ensure_main_module():
    main_mod = sys.modules.get("__main__")
    if main_mod is None or not isinstance(main_mod, types.ModuleType):
        main_mod = types.ModuleType("__main__")
        sys.modules["__main__"] = main_mod
    return main_mod
def _expose_helpers_in_main():
    main_mod = _ensure_main_module()
    helpers = {
        "to_float": pu.to_float,
        "ravel_1d": pu.ravel_1d,
        "to_1d": pu.to_1d,
        "combine_text_cols": pu.combine_text_cols,
        "clean_target_series": pu.clean_target_series,
    }
    for k, v in helpers.items():
        setattr(main_mod, k, v)
def _try_load_with_monkeypatch(path, max_retries=5):
    _expose_helpers_in_main()
    last_err = None
    for _ in range(max_retries):
        try:
            return joblib.load(path)
        except AttributeError as e:
            last_err = e
            m = re.search(r"Can't get attribute '([^']+)'", str(e))
            if m:
                missing = m.group(1)
                main_mod = _ensure_main_module()
                if not hasattr(main_mod, missing):
                    # identity stub so loading can continue when a small custom function is missing
                    setattr(main_mod, missing, lambda *a, **k: a[0] if a else None)
            else:
                break
    raise last_err if last_err else RuntimeError("Unable to load model.")
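
# Usage sketch (hypothetical failure): if the pickle references a helper that
# only existed in the training notebook, joblib raises e.g.
#   AttributeError: Can't get attribute 'combine_text_cols' on <module '__main__'>
# The loop above exposes the real helper (or an identity stub) under __main__
# and retries:
#   pipeline = _try_load_with_monkeypatch("artifacts/model.joblib")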
# --------------------------------------------------------------------------
# ---------- Introspect the columns expected by the preprocessing step ----------
REQUIRED: Dict[str, List[str]] = {"num": [], "cat": [], "text": []}
def _introspect_required_columns(pipeline_obj) -> Dict[str, List[str]]:
    req = {"num": [], "cat": [], "text": []}
    try:
        prep = pipeline_obj.named_steps.get("prep")
        transformers = getattr(prep, "transformers_", None) or getattr(prep, "transformers", [])
        for name, trans, cols in transformers:
            if name == "num" and isinstance(cols, list):
                req["num"] = list(cols)
            elif name == "cat" and isinstance(cols, list):
                req["cat"] = list(cols)
            elif isinstance(cols, list):
                # everything else (txt_...) -> individual text columns
                req["text"].extend(list(cols))
    except Exception:
        pass
    # deduplicate while preserving order
    req["text"] = list(dict.fromkeys(req["text"]))
    return req
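
# Illustrative layout this introspection assumes (column names are made up,
# not taken from the real model): a ColumnTransformer stored under the "prep"
# step whose fitted transformers_ look like
#   [("num", num_pipe, ["avg_grade"]),
#    ("cat", cat_pipe, ["department"]),
#    ("txt_title", text_pipe, ["course_title"])]   # any other name -> "text"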
def _make_inference_df(prof: str, course: str) -> pd.DataFrame | None:
    data: Dict[str, Any] = {}
    # text columns: fill each one as sensibly as we can
    for col in REQUIRED.get("text", []):
        c = col.lower()
        if "global" in c:
            data[col] = f"{prof} | {course}"
        elif "title" in c:
            data[col] = course
        elif "desc" in c:  # also covers "description"
            data[col] = course
        else:
            data[col] = f"{prof} {course}".strip()
    # numeric columns
    for col in REQUIRED.get("num", []):
        data[col] = 0
    # categorical columns
    for col in REQUIRED.get("cat", []):
        data[col] = "missing"
    return pd.DataFrame([data]) if data else None
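
# Example (with the hypothetical columns above): _make_inference_df("Jane Doe",
# "Linear Algebra") would return a one-row DataFrame such as
#   course_title      avg_grade  department
#   "Linear Algebra"  0          "missing"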
# --------------------------------------------------------------------------
app = FastAPI(title="Satisfaction Grades API", version=APP_VERSION)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class Payload(BaseModel):
    professor: str
    course: str

class Prediction(BaseModel):
    grade: float
@app.on_event("startup")
def load_model():
    global pipeline, REQUIRED
    if not os.path.exists(PIPELINE_PATH):
        raise RuntimeError(f"Model file not found at {PIPELINE_PATH}")
    pipeline = _try_load_with_monkeypatch(PIPELINE_PATH)
    REQUIRED = _introspect_required_columns(pipeline)
    print(f"✅ Model loaded from {PIPELINE_PATH}")
    print("Expected columns:", REQUIRED)
@app.get("/health")
def health():
    return {"status": "ok", "version": APP_VERSION}
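
# Quick smoke test (assuming the server runs on port 7860, the usual
# Hugging Face Spaces port):
#   curl http://localhost:7860/health
#   -> {"status": "ok", "version": "1.0.0"}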
@app.post("/api/predict", response_model=Prediction)
def predict(payload: Payload):
    prof = (payload.professor or "").strip()
    course = (payload.course or "").strip()
    if not prof and not course:
        raise HTTPException(status_code=422, detail="Provide at least 'professor' or 'course'.")
    # 1) try a DataFrame matching what the preprocessing step expects
    try:
        df = _make_inference_df(prof, course)
        if df is not None:
            y_pred = pipeline.predict(df)
            val = float(np.ravel(y_pred)[0])
            return {"grade": max(1.0, min(5.0, val))}
    except Exception as e_df:
        print(f"[WARN] DF inference failed → fallback string. Reason: {e_df}")
    # 2) fallback: pipeline that expects a plain string (single TF-IDF case)
    try:
        text = f"{prof} | {course}"
        y_pred = pipeline.predict([text])
        val = float(np.ravel(y_pred)[0])
        return {"grade": max(1.0, min(5.0, val))}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {e}")
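
# Local entry point: a minimal sketch, assuming uvicorn is installed
# (on Hugging Face Spaces the server is usually launched for you).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request (field names match the Payload model above; the grade shown
# is illustrative):
#   curl -X POST http://localhost:7860/api/predict \
#        -H "Content-Type: application/json" \
#        -d '{"professor": "Jane Doe", "course": "Linear Algebra"}'
#   -> {"grade": 4.2}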