# app.py (current version)
from fastapi import FastAPI
import torch
from transformers import BertTokenizer
from model.sentiment_model import SentimentAnalysisModel
from schemas.sentiment import SentimentRequest, SentimentResponse
from services.inference import predict_sentiment
from schemas.sentiment import BatchSentimentRequest
from services.inference import batch_predict
from fastapi.concurrency import run_in_threadpool
app = FastAPI()
@app.on_event("startup")
def startup_event():
    """Load the tokenizer and fine-tuned BERT sentiment model once at startup.

    Publishes ``tokenizer``, ``model`` and ``device`` as module-level globals,
    which the prediction endpoints below read.

    NOTE(review): ``@app.on_event`` is deprecated in recent FastAPI releases;
    consider migrating to a lifespan context manager.
    """
    global tokenizer, model, device
    # Prefer GPU when available; map_location below remaps the checkpoint
    # onto the same device so CPU-only hosts can still load GPU-saved weights.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = SentimentAnalysisModel("bert-base-uncased")
    model.load_state_dict(torch.load("bert_imdb_sentiment.pth", map_location=device))
    model.to(device)
    model.eval()  # disable dropout/batch-norm updates for inference
@app.post("/predict", response_model=SentimentResponse)
async def predict_api(req: SentimentRequest):
    """Classify the sentiment of a single text.

    Delegates to ``predict_sentiment`` inside a threadpool so the blocking
    model inference does not stall the event loop.

    Returns a ``SentimentResponse`` with the predicted label and confidence.
    """
    label, conf = await run_in_threadpool(
        predict_sentiment, req.text, tokenizer, model, device
    )
    return SentimentResponse(label=label, confidence=conf)
@app.post("/predict_batch")
async def predict_batch_api(req: BatchSentimentRequest):
    """Classify the sentiment of a batch of texts in one model pass.

    Runs ``batch_predict`` in a threadpool (same pattern as ``/predict``)
    and returns its result unchanged.

    NOTE(review): no ``response_model`` is declared, so the response schema
    is whatever ``batch_predict`` returns — consider adding one for OpenAPI
    documentation and validation.
    """
    results = await run_in_threadpool(
        batch_predict, req.texts, tokenizer, model, device
    )
    return results
@app.get("/health")
def health():
    """Liveness probe: always returns a static OK payload."""
    return {"status": "ok"}