"""FastAPI service exposing a Hugging Face sentiment-analysis pipeline.

The pipeline is created once at application startup. ``POST /predict``
classifies a list of texts; ``GET /health`` reports that the process is up.
"""

import logging
import os
from typing import List

import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from transformers import Pipeline, pipeline

logger = logging.getLogger(__name__)

APP_TITLE = "Sentiment Analysis API"
# Model passed to ``pipeline(task="sentiment-analysis", ...)``; can be
# overridden with the MODEL_NAME environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "distilbert-base-uncased-finetuned-sst-2-english")

app = FastAPI(title=APP_TITLE)


class PredictRequest(BaseModel):
    """Request body for ``POST /predict``."""

    # At least one text is required; an empty list fails request validation.
    inputs: List[str] = Field(..., min_items=1, description="List of input texts")


class Prediction(BaseModel):
    """A single classification result taken from the pipeline output."""

    label: str  # the pipeline's "label" field
    score: float  # the pipeline's "score" field, coerced to float


class PredictResponse(BaseModel):
    """Response body for ``POST /predict``.

    NOTE(review): assumes the pipeline yields one result per input text, in
    input order — confirm for the configured model.
    """

    predictions: List[Prediction]


# Set by ``load_model`` at startup; remains ``None`` until then.
sentiment_pipe: Pipeline | None = None


@app.on_event("startup")
def load_model() -> None:
    """Create the sentiment pipeline on the first CUDA device if available, else CPU."""
    global sentiment_pipe
    # transformers' integer device convention: -1 = CPU, 0 = first CUDA device.
    device = 0 if torch.cuda.is_available() else -1
    sentiment_pipe = pipeline(task="sentiment-analysis", model=MODEL_NAME, device=device)


@app.get("/health")
def health() -> dict:
    """Liveness probe: report that the process is serving requests."""
    return {"status": "ok"}


@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest) -> PredictResponse:
    """Classify every text in ``req.inputs`` with the loaded pipeline.

    Raises:
        HTTPException(503): the model has not been loaded yet.
        HTTPException(400): the pipeline rejected the input (``ValueError``).
        HTTPException(500): any other inference failure; the traceback is
            logged server-side and not returned to the client.
    """
    if sentiment_pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # truncation=True clips texts longer than the model's maximum length
        # instead of erroring on them.
        outputs = sentiment_pipe(req.inputs, truncation=True)
    except ValueError as e:
        # Input-related rejection from the pipeline: a genuine client error.
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Server-side failures (e.g. out of memory, device errors) are not the
        # client's fault: log full details, return a generic 500 that does not
        # leak internals.
        logger.exception("Sentiment inference failed")
        raise HTTPException(status_code=500, detail="Inference failed") from e

    preds = [Prediction(label=o["label"], score=float(o["score"])) for o in outputs]
    return PredictResponse(predictions=preds)


if __name__ == "__main__":
    import uvicorn

    # reload=True requires an "module:attr" import string; derive the module
    # from this file's name so the entry point works even if it isn't app.py.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=8000, reload=True)