File size: 1,442 Bytes
fba3401
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# app.py (current version)
from fastapi import FastAPI
import torch
from transformers import BertTokenizer
from model.sentiment_model import SentimentAnalysisModel
from schemas.sentiment import SentimentRequest, SentimentResponse
from services.inference import predict_sentiment
from schemas.sentiment import BatchSentimentRequest
from services.inference import batch_predict
from fastapi.concurrency import run_in_threadpool


app = FastAPI()


@app.on_event("startup")
def startup_event():
    global tokenizer, model, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SentimentAnalysisModel("bert-base-uncased")
    model.load_state_dict(torch.load("bert_imdb_sentiment.pth", map_location=device))
    model.to(device)
    model.eval()


@app.post("/predict", response_model=SentimentResponse)
async def predict_api(req: SentimentRequest):
    label, conf = await run_in_threadpool(
        predict_sentiment, req.text, tokenizer, model, device
    )
    return SentimentResponse(label=label, confidence=conf)



@app.post("/predict_batch")
async def predict_batch_api(req: BatchSentimentRequest):
    results = await run_in_threadpool(
        batch_predict, req.texts, tokenizer, model, device
    )
    return results



@app.get("/health")
def health():
    return {"status": "ok"}