# robertathing / app.py
# ganeshkonapalli's picture
# Upload 4 files
# 4790396 verified
import os

import torch
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import RobertaTokenizer

from config import LABEL_COLUMNS, MAX_LEN, DEVICE, MODEL_SAVE_DIR
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel
app = FastAPI()

# --- Model, tokenizer, and label encoders ----------------------------------
# Everything below runs once at import time; every request served by this
# process shares the same tokenizer, encoders, and model instance.

# Tokenizer loaded by its pretrained model name.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# One fitted encoder per output head, keyed by label-column name
# (predict() looks them up with the entries of LABEL_COLUMNS).
label_encoders = load_label_encoders()

# Class count for each output head, in the dict's iteration order.
# NOTE(review): predict() decodes head i with LABEL_COLUMNS[i], so this
# assumes load_label_encoders() yields its entries in LABEL_COLUMNS order
# (and that training built the heads in that same order) — confirm.
num_classes_per_label = [len(le.classes_) for le in label_encoders.values()]

model = RobertaMultiOutputModel(num_classes_per_label)

# os.path.join (rather than "+") resolves the checkpoint correctly whether
# or not MODEL_SAVE_DIR ends with a path separator.
_model_path = os.path.join(MODEL_SAVE_DIR, "ROBERTA_model.pth")
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True on torch versions that support it).
model.load_state_dict(torch.load(_model_path, map_location=DEVICE))
model.to(DEVICE)
model.eval()  # eval mode: disables training-only behavior such as dropout
class InputText(BaseModel):
    # Request body for POST /predict: one raw string to run through the model.
    # Documented with comments rather than a docstring so the generated
    # OpenAPI schema stays exactly as before.
    text: str
def _to_builtin(value):
    """Return ``value`` as a plain Python object when it is a NumPy scalar.

    ``inverse_transform`` may hand back NumPy scalars (e.g. ``numpy.int64``
    when an encoder was fit on integer labels), which FastAPI cannot
    JSON-encode. NumPy scalars expose ``.item()``; plain Python values such
    as ``str`` do not and are returned unchanged.
    """
    return value.item() if hasattr(value, "item") else value


@app.post("/predict")
def predict(input_data: InputText):
    """Run the multi-output model on one text and return its decoded labels.

    Args:
        input_data: request body carrying the ``text`` to score.

    Returns:
        ``{"predictions": {label_column: decoded_label, ...}}`` with one
        entry per (label column, output head) pair.
    """
    # Single string -> batch of one, truncated/padded to MAX_LEN tokens.
    encoded = tokenizer(
        input_data.text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}

    # No autograd bookkeeping is needed for inference.
    with torch.no_grad():
        outputs = model(**encoded)

    # assumes `outputs` is one logits tensor per head, shaped
    # (1, num_classes) — TODO confirm against RobertaMultiOutputModel.
    predictions = [torch.argmax(logits, dim=1).item() for logits in outputs]

    # zip() stops at the shorter sequence: if LABEL_COLUMNS and the model's
    # heads differ in count, the extras are silently dropped.
    decoded_preds = {
        col: _to_builtin(label_encoders[col].inverse_transform([pred])[0])
        for col, pred in zip(LABEL_COLUMNS, predictions)
    }
    return {"predictions": decoded_preds}