"""FastAPI inference service for a fine-tuned BERT binary classifier.

Loads the model weights, tokenizer, and label encoder once at import time
(standard FastAPI pattern: module-level state shared across requests) and
exposes a single POST /predict endpoint.
"""

import pickle

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BertForSequenceClassification, BertTokenizer

app = FastAPI()


class InputText(BaseModel):
    """Request body for /predict: a single raw text string."""

    text: str


# --- Load model components once at startup -------------------------------
# num_labels=2 matches the binary classification head the weights were
# trained with.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2
)
# weights_only=True restricts deserialization to tensors/state dicts,
# avoiding arbitrary-code execution from a tampered checkpoint (and the
# torch deprecation warning for the old default). Safe here because the
# file is a plain state dict.
model.load_state_dict(
    torch.load(
        "app/bert_model.pth",
        map_location=torch.device("cpu"),
        weights_only=True,
    )
)
model.eval()  # disable dropout/batchnorm updates for inference

# SECURITY NOTE(review): pickle.load executes arbitrary code from the file.
# These artifacts are assumed to be trusted, locally produced files — never
# point these paths at untrusted input. Prefer tokenizer.save_pretrained /
# BertTokenizer.from_pretrained for the tokenizer if the artifacts can be
# regenerated.
with open("app/tokenizer.pkl", "rb") as f:
    tokenizer = pickle.load(f)

with open("app/label_encoder.pkl", "rb") as f:
    label_encoder = pickle.load(f)


@app.post("/predict")
def predict(input_data: InputText):
    """Classify *input_data.text* and return the decoded label.

    Returns:
        dict: ``{"prediction": <label>}`` where the label is the
        label-encoder's original class name for the argmax logit.
    """
    # Fixed-length encoding (max_length=32) — inputs longer than 32 tokens
    # are truncated; shorter ones are padded.
    encoding = tokenizer(
        input_data.text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=32,
    )
    # Inference only: no_grad avoids building the autograd graph.
    with torch.no_grad():
        outputs = model(**encoding)
    prediction = torch.argmax(outputs.logits, dim=1).item()
    # Map the integer class index back to the human-readable label.
    label = label_encoder.inverse_transform([prediction])[0]
    return {"prediction": label}