File size: 2,545 Bytes
736b100
 
 
 
 
 
f3b41f5
736b100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa7de1a
 
 
 
 
 
736b100
 
 
 
 
 
 
 
 
f3b41f5
 
 
 
736b100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
from fastapi.staticfiles import StaticFiles
import logging
from fastapi.responses import FileResponse



# app/main.py
import os, pathlib

# Ensure a writable cache directory inside the container: use HF_HOME when it
# is set, otherwise default to /app/.cache. The directory is created eagerly
# (parents included, no error if it already exists) so it can be handed to the
# model loader as its cache_dir.
CACHE_DIR = os.getenv("HF_HOME", "/app/.cache")
pathlib.Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)


# Root logging configuration: every record carries a timestamp, its level and
# the emitting logger's name, at INFO verbosity and above.
logging.basicConfig(
    format="%(asctime)s %(levelname)s [%(name)s] %(message)s",
    level=logging.INFO,
)

# Named logger used by the request handlers in this module.
logger = logging.getLogger("sentiment-api")

# Application instance; title/description/version are the API metadata
# exposed by FastAPI.
app = FastAPI(
    title="Sentiment Analysis API",
    description="FastAPI + Hugging Face Transformers. Endpoints: /predict (principal), /analyze (compat).",
    version="0.2.0"
)

# CORS: every origin, method and header is accepted.
# NOTE(review): allow_credentials=True combined with a wildcard origin is
# broader than the endpoints in this file need (none of them read cookies or
# auth headers). Consider allow_credentials=False or an explicit origin list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True,
    allow_methods=["*"], allow_headers=["*"],
)

# Serve the relative "static" directory (resolved against the process working
# directory) under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# Sentiment-analysis pipeline, built once at module import time and reused by
# the prediction endpoint.
# Option 1 (usually fine): pass the cache directory to model loading through
# model_kwargs.
# NOTE(review): confirm the tokenizer/config downloads also land in CACHE_DIR
# rather than the default Hugging Face cache location.
analyzer = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    model_kwargs={"cache_dir": CACHE_DIR}
)

class PredictIn(BaseModel):
    """Request body accepted by /predict and /analyze."""

    text: str  # raw text to classify; the endpoint strips it and rejects blank input

class PredictOut(BaseModel):
    """Response body returned by /predict and /analyze."""

    text: str  # the input text after whitespace stripping
    sentiment: str  # normalized label: "positive", "neutral" or "negative"
    confidence: float  # classifier score, rounded to 4 decimals


@app.get("/", include_in_schema=False)
def root():
    """Serve the front-end entry page, ``static/index.html``.

    The route is excluded from the generated OpenAPI schema.
    """
    index_file = os.path.join("static", "index.html")
    return FileResponse(index_file)

def _normalize_label(label: str) -> str:
    l = (label or "").lower()
    if l.startswith("neg"):
        return "negative"
    if l.startswith("neu"):
        return "neutral"
    if l.startswith("pos"):
        return "positive"
    return "positive" if "1" in l or "5" in l or "star" in l else "negative"

@app.post("/predict", response_model=PredictOut)
def predict(payload: PredictIn):
    """Classify the sentiment of ``payload.text``.

    Returns a dict matching ``PredictOut``:
    * the stripped input text;
    * the normalized sentiment label;
    * the model score rounded to 4 decimals.

    Raises:
        HTTPException: 400 when the text is empty or whitespace-only;
            500 when inference fails or returns an unusable result.
    """
    text = (payload.text or "").strip()
    if not text:
        logger.warning("Texte vide reçu")
        raise HTTPException(status_code=400, detail="Text is empty")

    try:
        # truncation=True: inputs longer than the model's maximum sequence
        # length are cut to that length instead of failing inside the model
        # (which previously surfaced as a 500 for long texts). In that case
        # only the leading tokens are scored.
        res = analyzer(text, truncation=True)[0]
        label = _normalize_label(res.get("label", ""))
        score = float(res.get("score", 0.0))
    except Exception as e:
        logger.exception("Erreur d'inférence: %s", e)
        # Chain the cause so the underlying error stays attached.
        raise HTTPException(status_code=500, detail="Prediction failed") from e

    out = {"text": text, "sentiment": label, "confidence": round(score, 4)}
    logger.info("Prediction: %s", out)
    return out

@app.post("/analyze", response_model=PredictOut)
def analyze_compat(payload: PredictIn):
    """Backward-compatible alias for the ``/predict`` endpoint.

    Delegates straight to ``predict``, so both routes share validation,
    inference and error handling.
    """
    response = predict(payload)
    return response