File size: 540 Bytes
3b2ce9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from fastapi import FastAPI
import torch

# Local directory holding the saved tokenizer and model files.
MODEL_DIR = "./model"

app = FastAPI()

# Load the tokenizer and classifier once at import time so every request
# reuses the same objects instead of reloading them from disk.
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)

@app.post("/predict")
def predict(text: str):
    """Classify ``text`` as productive or unproductive.

    Because ``text`` is a bare ``str`` parameter, FastAPI reads it from the
    query string (``POST /predict?text=...``), not from the request body.

    Args:
        text: The input text to classify.

    Returns:
        ``{"label": "productive"}`` when the model's highest-scoring class
        index is 1, otherwise ``{"label": "unproductive"}``.
    """
    # truncation=True clips the encoding to the tokenizer's maximum length.
    # Without it, an over-long text yields more positions than models with
    # fixed position embeddings (e.g. BERT-style) support, and the forward
    # pass raises instead of returning a prediction.
    enc = tokenizer(text, return_tensors="pt", truncation=True)
    with torch.no_grad():
        logits = model(**enc).logits
    # One-item batch: pick the highest-scoring class along dim 1.
    pred = torch.argmax(logits, dim=1).item()
    label = "productive" if pred == 1 else "unproductive"
    return {"label": label}