SavlonBhai committed on
Commit
13a03b8
·
verified ·
1 Parent(s): 35283b2

Create Model

Browse files
Files changed (1) hide show
  1. Model +53 -0
Model ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from fastapi import FastAPI, HTTPException
6
+ from pydantic import BaseModel, Field
7
+ from transformers import Pipeline, pipeline
8
+
9
# Human-readable title handed to the FastAPI application below.
APP_TITLE = "Sentiment Analysis API"

# Model identifier; the MODEL_NAME environment variable overrides the default.
MODEL_NAME = os.environ.get(
    "MODEL_NAME",
    "distilbert-base-uncased-finetuned-sst-2-english",
)

# The ASGI application object that the route decorators in this module attach to.
app = FastAPI(title=APP_TITLE)
13
+
14
# Request body for POST /predict: a non-empty list of texts to classify.
# (Documented with comments rather than a docstring so the generated schema
# description stays exactly as it was.)
class PredictRequest(BaseModel):
    # Required field (`...`) that must contain at least one string.
    inputs: List[str] = Field(
        ...,
        min_items=1,
        description="List of input texts",
    )
16
+
17
# One classification result: the label reported by the pipeline and its score.
class Prediction(BaseModel):
    # Label string taken from the pipeline output's "label" entry.
    label: str
    # Score taken from the pipeline output's "score" entry, coerced to float.
    score: float
20
+
21
# Response body for POST /predict.
class PredictResponse(BaseModel):
    # One Prediction per pipeline output, in the order the pipeline returned them.
    predictions: List[Prediction]
23
+
24
# Shared inference pipeline; stays None until the startup hook below has run.
sentiment_pipe: Pipeline | None = None


@app.on_event("startup")
def load_model() -> None:
    """Build the sentiment-analysis pipeline and publish it module-wide.

    Registered as a FastAPI startup handler. The result is stored in the
    module-level ``sentiment_pipe`` so request handlers can reuse it.
    """
    global sentiment_pipe

    # Device index 0 when CUDA is available; otherwise -1, which the
    # transformers pipeline API treats as "run on CPU".
    if torch.cuda.is_available():
        target_device = 0
    else:
        target_device = -1

    sentiment_pipe = pipeline(
        task="sentiment-analysis",
        model=MODEL_NAME,
        device=target_device,
    )
35
+
36
# Health-check endpoint: always answers with a fixed {"status": "ok"} payload.
# It does not inspect whether the model has been loaded.
@app.get("/health")
def health() -> dict:
    status_payload = {"status": "ok"}
    return status_payload
39
+
40
@app.post("/predict", response_model=PredictResponse)
def predict(req: PredictRequest) -> PredictResponse:
    """Classify the sentiment of every text in ``req.inputs``.

    Returns one prediction (label and score) per pipeline output.

    Raises:
        HTTPException: 503 if the model has not been loaded yet; 500 if
            inference or result conversion fails.
    """
    if sentiment_pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    try:
        # truncation=True lets over-long texts be cut down instead of failing.
        outputs = sentiment_pipe(req.inputs, truncation=True)
        preds = [
            Prediction(label=o["label"], score=float(o["score"]))
            for o in outputs
        ]
    except Exception as e:
        # The request body has already passed PredictRequest validation, so a
        # failure at this point is a server-side inference problem. Report it
        # as 500 rather than blaming the client with 400, and chain the cause
        # so the original traceback is preserved.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return PredictResponse(predictions=preds)
50
+
51
if __name__ == "__main__":
    # Imported here so uvicorn is only needed when the module is run as a script.
    import uvicorn

    # NOTE(review): the "app:app" import string assumes this module is
    # importable as `app` (i.e. the file is app.py) — confirm the filename.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)