"""FastAPI inference service for a multi-output RoBERTa text classifier.

On import, loads the tokenizer, the fitted per-column label encoders, and the
trained model weights, then exposes a single POST /predict endpoint that maps
raw text to one decoded label per output column.
"""

from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer

from config import LABEL_COLUMNS, MAX_LEN, DEVICE, MODEL_SAVE_DIR
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel

app = FastAPI()

# ---------------------------------------------------------------------------
# One-time setup: tokenizer, label encoders, and model are loaded at import
# time so every request reuses them (FastAPI workers each pay this cost once).
# ---------------------------------------------------------------------------
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
label_encoders = load_label_encoders()

# Index the encoders by LABEL_COLUMNS explicitly so the per-head class counts
# are guaranteed to line up with the column order used when decoding
# predictions below. (Relying on dict iteration order of label_encoders
# would silently misalign heads if it ever differed from LABEL_COLUMNS.)
num_classes_per_label = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = RobertaMultiOutputModel(num_classes_per_label)
# weights_only=True restricts unpickling to tensors/primitives — a plain
# state_dict load never needs arbitrary code execution from the checkpoint.
# NOTE(review): MODEL_SAVE_DIR is concatenated directly, so it is assumed to
# end with a path separator — confirm against config.
model.load_state_dict(
    torch.load(
        MODEL_SAVE_DIR + "ROBERTA_model.pth",
        map_location=DEVICE,
        weights_only=True,
    )
)
model.to(DEVICE)
model.eval()


class InputText(BaseModel):
    """Request body for /predict: the raw text to classify."""

    text: str


@app.post("/predict")
def predict(input_data: InputText):
    """Classify *input_data.text* and return one decoded label per column.

    Returns:
        dict: ``{"predictions": {column_name: decoded_label, ...}}`` with one
        entry per entry in ``LABEL_COLUMNS``.
    """
    # Tokenize to fixed length so the tensor shape matches training.
    inputs = tokenizer(
        input_data.text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        # The model returns one logits tensor per output column.
        outputs = model(**inputs)

    # argmax over the class dimension of each head -> one class index per column.
    predictions = [torch.argmax(output, dim=1).item() for output in outputs]

    # Map each class index back to its original string label.
    decoded_preds = {
        col: label_encoders[col].inverse_transform([pred])[0]
        for col, pred in zip(LABEL_COLUMNS, predictions)
    }
    return {"predictions": decoded_preds}