# NOTE(review): the lines below were pasted from a Hugging Face Spaces
# "Runtime error" page; the table markup has been stripped to restore
# valid Python source.
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer

from app.config import MAX_LEN, LABEL_COLUMNS, MODEL_SAVE_DIR, LABEL_ENCODERS_PATH, TOKENIZER_PATH
from app.dataset_utils import load_label_encoders
from app.roberta_model import RobertaMultiOutputModel
app = FastAPI()


class InputText(BaseModel):
    """Request body for the prediction endpoint.

    Attributes:
        sanction_context: free-text sanction description to classify.
    """

    sanction_context: str
# --- One-time model setup at import time -------------------------------------
# Load the fitted label encoders so we can (a) size each classification head
# and (b) decode integer predictions back to label strings at request time.
label_encoders = load_label_encoders(LABEL_ENCODERS_PATH)
num_classes_per_label = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]

# Rebuild the multi-head RoBERTa classifier and restore its trained weights.
# map_location="cpu" keeps this loadable on machines without a GPU.
model = RobertaMultiOutputModel(num_classes_per_label)
model.load_state_dict(torch.load(f"{MODEL_SAVE_DIR}/ROBERTA_model.pth", map_location="cpu"))
model.eval()  # inference mode: disables dropout, freezes batch-norm stats

tokenizer = RobertaTokenizer.from_pretrained(TOKENIZER_PATH)
@app.post("/predict")
def predict(input_text: InputText):
    """Classify a sanction context into one value per label column.

    Fix: the original function was never registered on the FastAPI app, so
    the service exposed no route; ``@app.post("/predict")`` makes it
    reachable while leaving the function directly callable as before.

    Args:
        input_text: request body holding the raw ``sanction_context`` string.

    Returns:
        ``{"predictions": {label: decoded_class, ...}}`` with one decoded
        class string per entry in ``LABEL_COLUMNS``.
    """
    inputs = tokenizer(
        input_text.sanction_context,
        padding='max_length',
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt"
    )
    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # One logits tensor per label head; argmax over the class dimension.
    predicted = [torch.argmax(logit, dim=1).item() for logit in outputs]

    # Map each integer class index back to its original string label.
    # NOTE(review): assumes model heads are ordered the same as LABEL_COLUMNS
    # — confirm against RobertaMultiOutputModel.
    decoded = {
        label: label_encoders[label].inverse_transform([pred])[0]
        for label, pred in zip(LABEL_COLUMNS, predicted)
    }
    return {"predictions": decoded}