# subbu-roberta / app.py
# (HuggingFace Space file; uploaded by subbu123456 — "Upload 4 files", commit 18177db)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import pickle
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
class ModelClass(nn.Module):
    """Multi-task classifier: a shared RoBERTa encoder with one linear
    classification head per task.

    Parameters
    ----------
    num_labels_per_task : iterable of int
        Number of output classes for each task; one head is created per entry.
    """

    def __init__(self, num_labels_per_task):
        super().__init__()
        # Shared backbone; all task heads read the same pooled representation.
        self.encoder = AutoModel.from_pretrained("roberta-base")
        dim = self.encoder.config.hidden_size
        heads = []
        for n_classes in num_labels_per_task:
            heads.append(nn.Linear(dim, n_classes))
        self.classifiers = nn.ModuleList(heads)

    def forward(self, input_ids, attention_mask):
        """Return a list of per-task logit tensors, one per classifier head."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # NOTE(review): uses pooler_output (CLS token through a tanh dense
        # layer); the pooler is randomly initialized unless trained — assumed
        # fine-tuned weights are loaded via load_state_dict elsewhere.
        pooled = encoded.pooler_output
        logits = []
        for head in self.classifiers:
            logits.append(head(pooled))
        return logits
# Load the serialized bundle (tokenizer + fitted label encoders + trained
# weights) produced at training time.
# SECURITY NOTE(review): pickle.load executes arbitrary code from the file;
# only safe because roberta_model.pkl ships with this Space — never point
# this at untrusted input.
with open("roberta_model.pkl", "rb") as fh:
    _bundle = pickle.load(fh)

tokenizer = _bundle["tokenizer"]
label_encoders = _bundle["label_encoders"]

# Task order is fixed by the encoder dict's key order; the model's heads and
# the decode step in /predict both rely on this same ordering.
label_columns = list(label_encoders.keys())
_head_sizes = [len(enc.classes_) for enc in label_encoders.values()]

model = ModelClass(_head_sizes)
model.load_state_dict(_bundle["model_state_dict"])
model.eval()  # inference only: disable dropout etc.
app = FastAPI()
class Request(BaseModel):
    """Request body for /predict."""
    # Raw input text to classify across all tasks.
    text: str
@app.post("/predict")
def predict(req: Request):
    """Classify the request text for every task.

    Returns ``{"predictions": {column_name: decoded_label, ...}}`` with one
    entry per label encoder. Any failure is surfaced as an HTTP 500 carrying
    the exception message (deliberate catch-all at the API boundary).
    """
    try:
        encoded = tokenizer(
            req.text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        # Inference only — no gradients needed.
        with torch.no_grad():
            per_task_logits = model(**encoded)

        decoded = {}
        for column, logits in zip(label_columns, per_task_logits):
            class_idx = torch.argmax(logits, dim=1).item()
            # Map the class index back to its original string label.
            decoded[column] = label_encoders[column].inverse_transform([class_idx])[0]
        return {"predictions": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))