Spaces:
Configuration error
Configuration error
File size: 988 Bytes
3eecc60 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
from transformers import BertTokenizer, BertForSequenceClassification
# FastAPI application instance; the /predict route below is registered on it.
app = FastAPI()


class InputText(BaseModel):
    """Request body for ``POST /predict``.

    Attributes:
        text: Raw input string to classify.
    """

    text: str  # raw text sent by the client; tokenized by the predict handler
# --- Load model components --------------------------------------------------
# Runs once at import time, so every request reuses the same objects.
# NOTE(review): the "app/..." paths are relative to the process working
# directory (presumably the repository root) — confirm the server is started
# from there, otherwise these opens fail at startup.

# Build the BERT sequence-classification architecture, then replace its
# weights with the fine-tuned state dict.
# num_labels must match the classifier head stored in the checkpoint.
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.load_state_dict(
    torch.load(
        "app/bert_model.pth",
        map_location=torch.device("cpu"),  # CPU-only serving
        # A state dict only holds tensors in plain containers, so restricting
        # the unpickler is safe here and prevents a tampered checkpoint from
        # executing arbitrary code during load.
        weights_only=True,
    )
)
model.eval()  # evaluation mode (e.g. disables dropout) for inference

# SECURITY NOTE(review): pickle.load can execute arbitrary code. These files
# are assumed to be trusted artifacts shipped with the app; never point these
# paths at user-supplied or downloaded data.
with open("app/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

with open("app/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)
@app.post("/predict")
def predict(input_data: InputText):
    """Classify ``input_data.text`` with the fine-tuned BERT model.

    Args:
        input_data: Request body containing the text to classify.

    Returns:
        ``{"prediction": label}``. ``label`` is the class decoded by
        ``label_encoder``, converted to a built-in Python type so it is
        JSON-serializable.
    """
    # Truncate/pad to exactly 32 tokens.
    # NOTE(review): presumably matches the training-time max_length — confirm.
    encoding = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32,
    )
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(**encoding)
    # logits are (batch, num_labels); with a single input, argmax over dim 1
    # yields one class index.
    prediction = torch.argmax(outputs.logits, dim=1).item()
    label = label_encoder.inverse_transform([prediction])[0]
    # Indexing the array returned by inverse_transform gives a NumPy scalar
    # (e.g. numpy.int64 for integer labels). FastAPI's JSON encoder cannot
    # serialize that, so convert it to the equivalent built-in
    # (numpy.str_ -> str, numpy.int64 -> int). Plain Python values have no
    # .item() and pass through unchanged.
    if hasattr(label, "item"):
        label = label.item()
    return {"prediction": label}
|