"""FastAPI inference service for a multi-output RoBERTa text classifier.

Loads the fitted label encoders, tokenizer, and trained model weights once at
startup, then serves a single POST /predict endpoint that returns one decoded
label per output head.
"""
import logging
import os

import numpy as np  # NOTE(review): unused in this file; kept in case other chunks rely on it
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaTokenizer

from config import (
    DEVICE,
    LABEL_COLUMNS,
    LABEL_ENCODERS_PATH,
    MAX_LEN,
    MODEL_SAVE_DIR,
    ROBERTA_MODEL_NAME,
)
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel

logger = logging.getLogger(__name__)

app = FastAPI()

# --- One-time startup: encoders, tokenizer, model weights -------------------

# Fitted sklearn-style LabelEncoders, one per column in LABEL_COLUMNS.
label_encoders = load_label_encoders()

tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_MODEL_NAME)

# Number of classes per output head, in LABEL_COLUMNS order.
num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = RobertaMultiOutputModel(num_labels)
# FIX: use os.path.join instead of raw string concatenation — the original
# `MODEL_SAVE_DIR + "ROBERTA_model.pth"` silently produced a wrong path
# whenever MODEL_SAVE_DIR did not end with a path separator.
model.load_state_dict(
    torch.load(os.path.join(MODEL_SAVE_DIR, "ROBERTA_model.pth"), map_location=DEVICE)
)
model.to(DEVICE)
model.eval()


class RequestText(BaseModel):
    """Request body for /predict: the raw text to classify."""

    text: str


@app.post("/predict")
def predict_labels(request: RequestText):
    """Classify `request.text` and return one decoded label per output column.

    Returns a dict mapping each name in LABEL_COLUMNS to the human-readable
    label predicted by the corresponding model head. Raises HTTP 500 on any
    internal failure (tokenization, inference, or decoding).
    """
    try:
        inputs = tokenizer(
            request.text,
            padding='max_length',
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        input_ids = inputs['input_ids'].to(DEVICE)
        attention_mask = inputs['attention_mask'].to(DEVICE)

        with torch.no_grad():
            # Model returns one logits tensor per output head, in
            # LABEL_COLUMNS order (shape [1, num_classes] for batch of 1).
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # Argmax per head -> integer class index for the single batch item.
        predictions = [torch.argmax(logits, dim=1).cpu().numpy()[0] for logits in outputs]

        # Map integer indices back to the original string labels.
        decoded_predictions = {
            label: label_encoders[label].inverse_transform([pred])[0]
            for label, pred in zip(LABEL_COLUMNS, predictions)
        }
        return decoded_predictions
    except Exception as e:
        # Boundary handler: log the full traceback server-side instead of
        # losing it, then surface a 500 to the client.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=str(e))