# NOTE: the original capture began with "Spaces:" / "Build error" lines —
# Hugging Face Spaces build-log residue, not part of app.py.
| # app.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import joblib, numpy as np, os, sys, types, re | |
| import pandas as pd | |
| from typing import Dict, List, Any | |
| import preprocess_utils as pu | |
# Semantic version reported by the health endpoint and the OpenAPI docs.
APP_VERSION = "1.0.0"
# Path to the serialized sklearn pipeline (joblib pickle) inside the image.
PIPELINE_PATH = "artifacts/model.joblib"
| # ---------- MONKEY-PATCH pour pickles entraînés dans un notebook ---------- | |
| def _ensure_main_module(): | |
| main_mod = sys.modules.get("__main__") | |
| if main_mod is None or not isinstance(main_mod, types.ModuleType): | |
| main_mod = types.ModuleType("__main__") | |
| sys.modules["__main__"] = main_mod | |
| return main_mod | |
def _expose_helpers_in_main():
    """Copy the preprocess helper functions into ``__main__``.

    Pickles trained inside a notebook reference their helper functions via
    the ``__main__`` module, so they must exist there before unpickling.
    """
    target = _ensure_main_module()
    for helper_name in (
        "to_float",
        "ravel_1d",
        "to_1d",
        "combine_text_cols",
        "clean_target_series",
    ):
        setattr(target, helper_name, getattr(pu, helper_name))
def _try_load_with_monkeypatch(path, max_retries=5):
    """Load a joblib pickle trained in a notebook, stubbing missing ``__main__`` attrs.

    Exposes the known preprocess helpers in ``__main__`` first, then retries
    the load, injecting an identity stub for each attribute joblib reports
    as missing ("Can't get attribute '...'").

    Args:
        path: filesystem path to the joblib artifact.
        max_retries: upper bound on load attempts (one stub added per retry).

    Returns:
        The unpickled object.

    Raises:
        AttributeError: the last load error when stubbing cannot help.
        RuntimeError: if no attempt was ever made (defensive; max_retries <= 0).
    """
    _expose_helpers_in_main()
    last_err = None
    for _ in range(max_retries):
        try:
            return joblib.load(path)
        except AttributeError as e:
            last_err = e
            m = re.search(r"Can't get attribute '([^']+)'", str(e))
            if not m:
                # Error is not about a missing attribute: retrying cannot help.
                break
            missing = m.group(1)
            main_mod = _ensure_main_module()
            if hasattr(main_mod, missing):
                # FIX: the attribute is already present (possibly a stub we
                # injected) yet loading still fails — further retries make no
                # progress, so stop instead of burning the remaining attempts.
                break
            # Identity stub so unpickling can continue past a small missing helper.
            setattr(main_mod, missing, lambda *a, **k: a[0] if a else None)
    raise last_err if last_err is not None else RuntimeError("Unable to load model.")
| # -------------------------------------------------------------------------- | |
| # ---------- Introspection des colonnes attendues par le preprocess ---------- | |
| REQUIRED: Dict[str, List[str]] = {"num": [], "cat": [], "text": []} | |
| def _introspect_required_columns(pipeline_obj) -> Dict[str, List[str]]: | |
| req = {"num": [], "cat": [], "text": []} | |
| try: | |
| prep = pipeline_obj.named_steps.get("prep") | |
| transformers = getattr(prep, "transformers_", None) or getattr(prep, "transformers", []) | |
| for name, trans, cols in transformers: | |
| if name == "num" and isinstance(cols, list): | |
| req["num"] = list(cols) | |
| elif name == "cat" and isinstance(cols, list): | |
| req["cat"] = list(cols) | |
| elif isinstance(cols, list): | |
| # tout le reste (txt_...) -> colonnes texte individuelles | |
| req["text"].extend(list(cols)) | |
| except Exception: | |
| pass | |
| # unicité | |
| req["text"] = list(dict.fromkeys(req["text"])) | |
| return req | |
| def _make_inference_df(prof: str, course: str) -> pd.DataFrame | None: | |
| data: Dict[str, Any] = {} | |
| # colonnes texte : on remplit intelligemment | |
| for col in REQUIRED.get("text", []): | |
| c = col.lower() | |
| if "global" in c: | |
| data[col] = f"{prof} | {course}" | |
| elif "title" in c: | |
| data[col] = course | |
| elif "desc" in c or "description" in c: | |
| data[col] = course | |
| else: | |
| data[col] = (f"{prof} {course}").strip() | |
| # colonnes num | |
| for col in REQUIRED.get("num", []): | |
| data[col] = 0 | |
| # colonnes cat | |
| for col in REQUIRED.get("cat", []): | |
| data[col] = "missing" | |
| return pd.DataFrame([data]) if data else None | |
| # -------------------------------------------------------------------------- | |
# FastAPI application; title/version are surfaced in the OpenAPI docs.
app = FastAPI(title="Satisfaction Grades API", version=APP_VERSION)
# Wide-open CORS (any origin/method/header) — typical for a demo deployment.
# NOTE(review): tighten allow_origins before any production use.
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_credentials=True,
                   allow_methods=["*"], allow_headers=["*"])
class Payload(BaseModel):
    """Prediction request body: professor name and course title."""
    professor: str
    course: str
class Prediction(BaseModel):
    """Prediction response body: a satisfaction grade in [1.0, 5.0]."""
    grade: float
def load_model():
    """Load the serialized pipeline and record its expected input columns.

    Mutates the module globals ``pipeline`` and ``REQUIRED``.

    NOTE(review): no startup decorator is visible on this function —
    presumably it was registered (e.g. ``@app.on_event("startup")``);
    confirm against the deployment entry point.

    Raises:
        RuntimeError: if the model artifact is missing on disk.
    """
    global pipeline, REQUIRED
    if not os.path.exists(PIPELINE_PATH):
        raise RuntimeError(f"Model file not found at {PIPELINE_PATH}")
    loaded = _try_load_with_monkeypatch(PIPELINE_PATH)
    columns = _introspect_required_columns(loaded)
    pipeline, REQUIRED = loaded, columns
    print(f"✅ Model loaded from {PIPELINE_PATH}")
    print("Expected columns:", REQUIRED)
def health():
    """Liveness probe payload.

    NOTE(review): no ``@app.get`` route decorator is visible here —
    confirm the endpoint is actually registered.
    """
    return dict(status="ok", version=APP_VERSION)
def predict(payload: Payload):
    """Predict a satisfaction grade, clamped to [1.0, 5.0].

    First tries a DataFrame shaped to the preprocessor's introspected
    columns; on failure (or when no columns are known) falls back to a plain
    "prof | course" string for text-only pipelines.

    NOTE(review): no ``@app.post`` route decorator is visible here —
    confirm the endpoint is actually registered.

    Raises:
        HTTPException: 422 when both inputs are blank; 500 when both
            inference paths fail.
    """
    prof = (payload.professor or "").strip()
    course = (payload.course or "").strip()
    if not (prof or course):
        raise HTTPException(status_code=422, detail="Provide at least 'professor' or 'course'.")

    def _clamped(raw) -> Dict[str, float]:
        # Model output may drift outside the valid grade range; clamp it.
        value = float(np.ravel(raw)[0])
        return {"grade": max(1.0, min(5.0, value))}

    # 1) DataFrame conforming to the preprocessor's expected columns.
    try:
        frame = _make_inference_df(prof, course)
        if frame is not None:
            return _clamped(pipeline.predict(frame))
    except Exception as e_df:
        print(f"[WARN] DF inference failed → fallback string. Reason: {e_df}")
    # 2) Fallback: pipeline that expects a single raw string (pure TF-IDF case).
    try:
        return _clamped(pipeline.predict([f"{prof} | {course}"]))
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Inference error: {e}")