# roberta/app.py — FastAPI inference service for a multi-output RoBERTa classifier.
# (Hosting-page header artifacts removed; original upload commit: d7db76e.)
import os

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaTokenizer

from config import DEVICE, MAX_LEN, LABEL_COLUMNS, ROBERTA_MODEL_NAME, MODEL_SAVE_DIR, LABEL_ENCODERS_PATH
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel
app = FastAPI()

# Fitted label encoders, one per output column (loaded from LABEL_ENCODERS_PATH
# by dataset_utils.load_label_encoders — see that module for the file format).
label_encoders = load_label_encoders()

# Tokenizer must match the base model the multi-output head was trained on.
tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_MODEL_NAME)

# Number of classes for each output head, in LABEL_COLUMNS order.
num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = RobertaMultiOutputModel(num_labels)
# os.path.join instead of string concatenation: the original
# `MODEL_SAVE_DIR + "ROBERTA_model.pth"` silently produced a wrong path
# whenever MODEL_SAVE_DIR had no trailing separator. join handles both cases.
model.load_state_dict(
    torch.load(os.path.join(MODEL_SAVE_DIR, "ROBERTA_model.pth"), map_location=DEVICE)
)
model.to(DEVICE)
model.eval()  # inference mode: disable dropout / batch-norm updates
class RequestText(BaseModel):
    """Request body for POST /predict: a single free-text string to classify."""

    # Raw input text; tokenized and truncated to MAX_LEN by the endpoint.
    text: str
@app.post("/predict")
def predict_labels(request: RequestText):
try:
inputs = tokenizer(
request.text,
padding='max_length',
truncation=True,
max_length=MAX_LEN,
return_tensors="pt"
)
input_ids = inputs['input_ids'].to(DEVICE)
attention_mask = inputs['attention_mask'].to(DEVICE)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
predictions = [torch.argmax(logits, dim=1).cpu().numpy()[0] for logits in outputs]
decoded_predictions = {
label: label_encoders[label].inverse_transform([pred])[0]
for label, pred in zip(LABEL_COLUMNS, predictions)
}
return decoded_predictions
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))