# Source: DataSage12 — "Update app/main.py" (commit f3b41f5, verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
from fastapi.staticfiles import StaticFiles
import logging
from fastapi.responses import FileResponse
# app/main.py
import os, pathlib
# 👉 Guarantee a writable model-cache directory inside the container:
# honor HF_HOME when set, otherwise fall back to /app/.cache, and create it
# up-front so the transformers download below cannot fail on a missing dir.
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s [%(name)s] %(message)s"
)
logger = logging.getLogger("sentiment-api")
# FastAPI application; title/description/version feed the auto-generated docs.
app = FastAPI(
    title="Sentiment Analysis API",
    description="FastAPI + Hugging Face Transformers. Endpoints: /predict (principal), /analyze (compat).",
    version="0.2.0"
)
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# disallowed by the CORS spec in browsers — confirm no cookie/auth flows are
# expected, or list explicit origins instead.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)
# Serve front-end assets from ./static under the /static path.
app.mount("/static", StaticFiles(directory="static"), name="static")
# Option 1 (usually OK):
# Load the sentiment pipeline once at import time; model files are cached in
# the writable CACHE_DIR so cold starts in the container can download there.
analyzer = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    model_kwargs={"cache_dir": CACHE_DIR}
)
class PredictIn(BaseModel):
    """Request body for /predict and /analyze."""
    text: str  # raw text to classify; whitespace-only input is rejected with 400
class PredictOut(BaseModel):
    """Response body for /predict and /analyze."""
    text: str  # the stripped input text, echoed back
    sentiment: str  # normalized label: "negative" / "neutral" / "positive"
    confidence: float  # model score, rounded to 4 decimal places
@app.get("/", include_in_schema=False)
def root():
    """Serve the single-page front-end (hidden from the OpenAPI schema)."""
    index_page = os.path.join("static", "index.html")
    return FileResponse(index_page)
def _normalize_label(label: str) -> str:
l = (label or "").lower()
if l.startswith("neg"):
return "negative"
if l.startswith("neu"):
return "neutral"
if l.startswith("pos"):
return "positive"
return "positive" if "1" in l or "5" in l or "star" in l else "negative"
@app.post("/predict", response_model=PredictOut)
def predict(payload: PredictIn):
    """Classify the sentiment of the submitted text.

    Returns the stripped input text, a normalized sentiment label and the
    model confidence rounded to 4 decimals. Responds 400 on empty input and
    500 when inference fails.
    """
    cleaned = (payload.text or "").strip()
    if not cleaned:
        logger.warning("Texte vide reçu")
        raise HTTPException(status_code=400, detail="Text is empty")
    try:
        first = analyzer(cleaned)[0]
        sentiment = _normalize_label(first.get("label", ""))
        confidence = round(float(first.get("score", 0.0)), 4)
        response = {"text": cleaned, "sentiment": sentiment, "confidence": confidence}
        logger.info("Prediction: %s", response)
        return response
    except Exception as exc:
        logger.exception("Erreur d'inférence: %s", exc)
        raise HTTPException(status_code=500, detail="Prediction failed")
# Legacy route: /analyze behaves exactly like /predict.
@app.post("/analyze", response_model=PredictOut)
def analyze_compat(payload: PredictIn):
    """Backward-compatible alias that delegates to predict()."""
    return predict(payload)