bert-fastapi-hf / main.py
ganeshkonapalli's picture
Upload 6 files
3eecc60 verified
raw
history blame contribute delete
988 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
import torch
import pickle
from transformers import BertTokenizer, BertForSequenceClassification
# FastAPI application instance; the /predict route is registered on it
# through the @app.post decorator.
app = FastAPI()
# Request body schema accepted by the /predict endpoint.
class InputText(BaseModel):
    # Raw text to be classified.
    text: str
# Inference artifacts: everything here is loaded once, at import time.


def _unpickle(path):
    """Return the object deserialized from the pickle file at *path*."""
    # NOTE(review): unpickling can execute arbitrary code; these files must
    # come from a trusted source.
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Instantiate the two-label BERT classifier, then replace its weights with the
# state dict stored in app/bert_model.pth (tensors mapped onto the CPU).
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
)
model.load_state_dict(
    torch.load("app/bert_model.pth", map_location=torch.device("cpu"))
)
model.eval()  # evaluation mode for serving

# Pickled tokenizer and label encoder used by the /predict handler.
tokenizer = _unpickle("app/tokenizer.pkl")
label_encoder = _unpickle("app/label_encoder.pkl")
@app.post("/predict")
def predict(input_data: InputText):
    """Classify the submitted text and return its decoded label.

    Args:
        input_data: Request body; its ``text`` field is the string to classify.

    Returns:
        A dict ``{"prediction": label}`` where ``label`` is what
        ``label_encoder.inverse_transform`` yields for the class index with
        the highest logit, converted to a plain Python value when possible.
    """
    # Every input is truncated or padded to exactly 32 tokens.
    encoding = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32,
    )

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        outputs = model(**encoding)
        prediction = torch.argmax(outputs.logits, dim=1).item()

    label = label_encoder.inverse_transform([prediction])[0]

    # NOTE(review): presumably label_encoder is a scikit-learn LabelEncoder, in
    # which case `label` is a NumPy scalar. Non-str NumPy scalars (e.g.
    # numpy.int64) are not JSON-serializable by FastAPI, so unwrap to the
    # native Python value whenever the object offers `.item()`.
    to_native = getattr(label, "item", None)
    if callable(to_native):
        label = to_native()

    return {"prediction": label}