# app/main.py
"""Sentiment analysis HTTP API.

FastAPI service wrapping a Hugging Face Transformers ``sentiment-analysis``
pipeline. Routes:

* ``GET /``         -- serves ``static/index.html`` (hidden from the schema).
* ``GET /static/*`` -- static front-end assets.
* ``POST /predict`` -- main prediction endpoint.
* ``POST /analyze`` -- backward-compatible alias of ``/predict``.
"""

import logging
import os
import pathlib

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from transformers import pipeline

# Ensure the model cache directory exists and is writable inside the container.
# HF_HOME wins when set; otherwise fall back to /app/.cache.
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Front-end assets directory, resolved relative to the process working directory.
STATIC_DIR = "static"

MODEL_NAME = "distilbert-base-uncased-finetuned-sst-2-english"

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
)
logger = logging.getLogger("sentiment-api")

app = FastAPI(
    title="Sentiment Analysis API",
    description="FastAPI + Hugging Face Transformers. Endpoints: /predict (principal), /analyze (compat).",
    version="0.2.0",
)

# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive. It is acceptable only while this API relies on no cookies or
# auth headers. Restrict allow_origins before adding any authentication.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")

# Loaded once at import time; every request reuses this pipeline instance.
analyzer = pipeline(
    "sentiment-analysis",
    model=MODEL_NAME,
    model_kwargs={"cache_dir": CACHE_DIR},
)


class PredictIn(BaseModel):
    """Request body for ``/predict`` and ``/analyze``."""

    text: str


class PredictOut(BaseModel):
    """Response body for ``/predict`` and ``/analyze``.

    ``text`` is the stripped input. ``sentiment`` is the normalized label.
    ``confidence`` is the model score rounded to 4 decimals.
    """

    text: str
    sentiment: str
    confidence: float


@app.get("/", include_in_schema=False)
def root():
    """Serve the single-page front end, or 404 if it is missing."""
    index_path = os.path.join(STATIC_DIR, "index.html")
    if not os.path.isfile(index_path):
        # Fail fast with a clear status instead of erroring mid-response.
        raise HTTPException(status_code=404, detail="index.html not found")
    return FileResponse(index_path)


def _normalize_label(label: str) -> str:
    """Map a raw pipeline label onto "negative", "neutral" or "positive".

    * Labels starting with "neg"/"neu"/"pos" (case-insensitive) map directly.
    * Star-rating labels with a leading digit ("1 star" .. "5 stars") map
      1-2 -> negative, 3 -> neutral, 4-5 -> positive.
    * Any other label is "positive" if it contains "1" or "5"
      (e.g. "LABEL_1"), otherwise "negative".
    """
    name = (label or "").strip().lower()
    if name.startswith("neg"):
        return "negative"
    if name.startswith("neu"):
        return "neutral"
    if name.startswith("pos"):
        return "positive"
    if "star" in name and name[:1].isdigit():
        # The previous '"1" in label or "star" in label' test made
        # "1 star" positive; read the rating instead.
        stars = int(name[0])
        if stars <= 2:
            return "negative"
        if stars == 3:
            return "neutral"
        return "positive"
    return "positive" if "1" in name or "5" in name else "negative"


@app.post("/predict", response_model=PredictOut)
def predict(payload: PredictIn):
    """Classify the sentiment of ``payload.text``.

    Returns a dict matching ``PredictOut``.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only;
            500 if inference fails.
    """
    text = (payload.text or "").strip()
    if not text:
        logger.warning("Texte vide reçu")
        raise HTTPException(status_code=400, detail="Text is empty")

    try:
        # truncation=True clips inputs longer than the model's maximum input
        # length instead of failing the whole request.
        res = analyzer(text, truncation=True)[0]
        label = _normalize_label(res.get("label", ""))
        score = float(res.get("score", 0.0))
    except Exception as e:
        logger.exception("Erreur d'inférence: %s", e)
        raise HTTPException(status_code=500, detail="Prediction failed") from e

    out = {"text": text, "sentiment": label, "confidence": round(score, 4)}
    logger.info("Prediction: %s", out)
    return out


@app.post("/analyze", response_model=PredictOut)
def analyze_compat(payload: PredictIn):
    """Backward-compatible alias for ``POST /predict``."""
    return predict(payload)