# Spaces:
# Runtime error
# Runtime error
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
from config import LABEL_COLUMNS, MAX_LEN, DEVICE, MODEL_SAVE_DIR
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel

# Application and model setup — runs once at import time.
app = FastAPI()

# Load model, tokenizer, and label encoders
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
label_encoders = load_label_encoders()
# One classification head per label column; each head's size is the number
# of classes its fitted LabelEncoder knows about.
num_classes_per_label = [len(le.classes_) for le in label_encoders.values()]
model = RobertaMultiOutputModel(num_classes_per_label)
# NOTE(review): plain string concatenation assumes MODEL_SAVE_DIR ends with a
# path separator — confirm, or this resolves to the wrong filename.
model.load_state_dict(torch.load(MODEL_SAVE_DIR + "ROBERTA_model.pth", map_location=DEVICE))
model.to(DEVICE)
model.eval()  # inference mode: disables dropout / batch-norm updates
class InputText(BaseModel):
    """Request body schema: a single free-text string to classify."""
    # Raw input text to be tokenized and classified.
    text: str
@app.post("/predict")
def predict(input_data: InputText):
    """Run multi-output classification on the posted text.

    Fix: the original function was never registered on ``app`` (no route
    decorator), so the API exposed no endpoint at all; ``@app.post`` makes
    it reachable at POST /predict.

    Args:
        input_data: Request body containing a single ``text`` field.

    Returns:
        dict: ``{"predictions": {label_column: decoded_class_label, ...}}``.
    """
    # Tokenize to fixed-length tensors, then move every tensor to the
    # model's device.
    inputs = tokenizer(
        input_data.text,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        # The model yields one logits tensor per label head.
        outputs = model(**inputs)
    # argmax over the class dimension of each head -> encoded class index.
    predictions = [torch.argmax(output, dim=1).item() for output in outputs]
    # Map encoded indices back to the original string labels, per column.
    decoded_preds = {
        col: label_encoders[col].inverse_transform([pred])[0]
        for col, pred in zip(LABEL_COLUMNS, predictions)
    }
    return {"predictions": decoded_preds}