Spaces:
Runtime error
Runtime error
import os

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import RobertaTokenizer

from config import (
    DEVICE,
    LABEL_COLUMNS,
    LABEL_ENCODERS_PATH,
    MAX_LEN,
    MODEL_SAVE_DIR,
    ROBERTA_MODEL_NAME,
)
from dataset_utils import load_label_encoders
from models.roberta_model import RobertaMultiOutputModel
app = FastAPI()

# Fitted label encoders, looked up by the column names in LABEL_COLUMNS.
# NOTE(review): the `.classes_` / `.inverse_transform` usage suggests
# sklearn-style LabelEncoder objects — confirm in dataset_utils.
label_encoders = load_label_encoders()

# Tokenizer plus the fine-tuned multi-output model. The model is constructed
# with the number of classes of each label column (presumably one output head
# per column — verify in models.roberta_model), then its saved weights are
# restored.
tokenizer = RobertaTokenizer.from_pretrained(ROBERTA_MODEL_NAME)
num_labels = [len(label_encoders[col].classes_) for col in LABEL_COLUMNS]
model = RobertaMultiOutputModel(num_labels)
# os.path.join is correct whether or not MODEL_SAVE_DIR ends with a path
# separator (and also accepts a pathlib.Path); plain string concatenation
# silently produced a wrong filename when the separator was missing.
model.load_state_dict(
    torch.load(os.path.join(MODEL_SAVE_DIR, "ROBERTA_model.pth"), map_location=DEVICE)
)
model.to(DEVICE)
# Switch to evaluation mode: this module only serves predictions.
model.eval()
class RequestText(BaseModel):
    """Request payload consumed by ``predict_labels``: the raw text to classify."""

    # Passed as-is to the tokenizer; no additional validation is applied here.
    text: str
def predict_labels(request: RequestText):
    """Predict one label for each entry of LABEL_COLUMNS from the request text.

    The text is tokenized (padded/truncated to MAX_LEN), run through the
    multi-output model without gradient tracking, and each output's argmax
    class index is decoded back to a label with the matching label encoder.

    Args:
        request: Request body carrying the raw ``text`` to classify.

    Returns:
        Dict mapping each name in LABEL_COLUMNS to its decoded label. Values
        exposing ``.item()`` (NumPy scalars) are unwrapped to built-in Python
        types so FastAPI can JSON-encode the response.

    Raises:
        HTTPException: status 500 carrying the underlying error message if
            tokenization, inference or decoding fails.
    """
    # NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on
    # this function, so it is not registered on ``app`` — confirm the intended
    # path and method.
    try:
        inputs = tokenizer(
            request.text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"].to(DEVICE)
        attention_mask = inputs["attention_mask"].to(DEVICE)

        with torch.no_grad():
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # One logits tensor per label column; take the top class index of the
        # single example (batch index 0) as a plain Python int.
        predictions = [int(torch.argmax(logits, dim=1)[0]) for logits in outputs]

        decoded_predictions = {}
        for label, pred in zip(LABEL_COLUMNS, predictions):
            value = label_encoders[label].inverse_transform([pred])[0]
            # The encoder presumably returns NumPy scalars; numeric ones (e.g.
            # numpy.int64) are not JSON-serializable by FastAPI and would fail
            # after this function returns, outside the try. Unwrap them here.
            decoded_predictions[label] = value.item() if hasattr(value, "item") else value
        return decoded_predictions
    except Exception as e:
        # Preserve the original traceback as the cause for server-side logs.
        raise HTTPException(status_code=500, detail=str(e)) from e