File size: 988 Bytes
3eecc60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32

from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
from transformers import BertTokenizer, BertForSequenceClassification

# Single application instance; the route handlers below register against it.
app = FastAPI()

# Request body schema for POST /predict: one free-text field to classify.
# NOTE: deliberately no docstring — a docstring on a pydantic model is picked
# up as the schema description in the generated OpenAPI docs (behavior change).
class InputText(BaseModel):
    text: str

# Load model components once at import time so every request reuses them.
# The architecture is rebuilt from the hub config, then the fine-tuned
# 2-class head weights are restored from the local checkpoint.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
# weights_only=True restricts torch.load to tensor/primitive data, blocking
# arbitrary code execution from a tampered checkpoint (a state dict needs
# nothing more). Requires torch >= 1.13.
model.load_state_dict(
    torch.load("app/bert_model.pth", map_location=torch.device("cpu"), weights_only=True)
)
model.eval()  # inference mode: disable dropout etc.

# SECURITY: pickle.load executes arbitrary code embedded in the file — these
# artifacts must come from a trusted build pipeline. Prefer
# tokenizer.save_pretrained() / BertTokenizer.from_pretrained() and a JSON
# label map over pickled objects where the artifacts can be regenerated.
with open("app/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

with open("app/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)

@app.post("/predict")
def predict(input_data: InputText):
    """Classify the submitted text and return the decoded class label."""
    tokens = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32,
    )
    # Inference only — no gradient tracking needed.
    with torch.no_grad():
        logits = model(**tokens).logits
        class_idx = int(logits.argmax(dim=1))
    # Map the numeric class index back to its original string label.
    return {"prediction": label_encoder.inverse_transform([class_idx])[0]}