from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer

from app.roberta_model import RobertaMultiOutputModel
from app.dataset_utils import load_label_encoders
from app.config import (
    MAX_LEN,
    LABEL_COLUMNS,
    MODEL_SAVE_DIR,
    LABEL_ENCODERS_PATH,
    TOKENIZER_PATH,
)

app = FastAPI()


class InputText(BaseModel):
    """Request body for ``POST /predict``: raw sanction-context text."""

    sanction_context: str


# --- One-time startup: load all inference artifacts at import time so ---
# --- per-request latency is just tokenization + a forward pass.       ---

# One label encoder per output column; each maps integer class indices
# back to their original string labels.
label_encoders = load_label_encoders(LABEL_ENCODERS_PATH)
num_classes_per_label = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

model = RobertaMultiOutputModel(num_classes_per_label)
# map_location="cpu" keeps serving CPU-only regardless of the device the
# checkpoint was saved from.
model.load_state_dict(
    torch.load(f"{MODEL_SAVE_DIR}/ROBERTA_model.pth", map_location="cpu")
)
model.eval()  # disable dropout etc. for deterministic inference

tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)


@app.post("/predict")
def predict(input_text: InputText):
    """Classify a sanction context across every configured label column.

    Returns ``{"predictions": {label: decoded_class, ...}}`` with one
    entry per column in LABEL_COLUMNS.
    """
    inputs = tokenizer(
        input_text.sanction_context,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**inputs)

    # The model yields one logits tensor per label head; take the arg-max
    # class index of each head for this single-item batch.
    predicted = [torch.argmax(logits, dim=1).item() for logits in outputs]

    # Decode integer class indices back to their original string labels.
    decoded = {
        label: label_encoders[label].inverse_transform([idx])[0]
        for label, idx in zip(LABEL_COLUMNS, predicted)
    }
    return {"predictions": decoded}