Spaces:
Runtime error
Runtime error
import io
import pickle

import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
class ModelClass(nn.Module):
    """Multi-task text classifier: one shared encoder, one linear head per task.

    Args:
        num_labels_per_task: Number of output classes for each task. One
            ``nn.Linear`` head is created per entry, in the same order.
        encoder_name: Model id or local path passed to
            ``AutoModel.from_pretrained``. Defaults to ``"roberta-base"``, the
            encoder this class previously hard-coded. A saved ``state_dict``
            only loads into a model built with the same encoder architecture.
            The chosen encoder must expose ``pooler_output``.
    """

    def __init__(self, num_labels_per_task, encoder_name: str = "roberta-base"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(encoder_name)
        hidden_size = self.encoder.config.hidden_size
        # Independent classification heads; all of them read the same pooled
        # encoder representation.
        self.classifiers = nn.ModuleList(
            [nn.Linear(hidden_size, num_labels) for num_labels in num_labels_per_task]
        )

    def forward(self, input_ids, attention_mask=None):
        """Run the encoder once and apply every task head to its pooled output.

        Args:
            input_ids: Token ids from the matching tokenizer (presumably shaped
                ``(batch, seq_len)`` -- TODO confirm against callers).
            attention_mask: Optional padding mask with the same shape as
                ``input_ids``. ``None`` is passed through to the encoder.

        Returns:
            A list with one logits tensor per head, in head order. The last
            dimension of each tensor is that task's ``num_labels``.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        return [clf(pooled_output) for clf in self.classifiers]
def _load_storage_on_cpu(data):
    """Rebuild a pickled torch storage from its serialized bytes, placed on the CPU."""
    return torch.load(io.BytesIO(data), map_location="cpu", weights_only=False)


class _CPUUnpickler(pickle.Unpickler):
    """``pickle.Unpickler`` that restores pickled torch tensor storages onto the CPU.

    When a tensor goes through plain ``pickle``, PyTorch records its storage
    as a call to ``torch.storage._load_from_bytes``. That function re-reads the
    bytes with ``torch.load`` and no ``map_location``. As a result, a bundle
    pickled while the model was on a GPU fails to load on a CPU-only host.

    Redirecting that single global to ``_load_storage_on_cpu`` makes such
    bundles loadable anywhere. Bundles saved from the CPU load exactly as before.
    """

    def find_class(self, module, name):
        if module == "torch.storage" and name == "_load_from_bytes":
            return _load_storage_on_cpu
        return super().find_class(module, name)


# NOTE(security): unpickling executes arbitrary code embedded in the file.
# Only load a bundle produced by a trusted training run, never user input.
with open("roberta_model.pkl", "rb") as f:
    bundle = _CPUUnpickler(f).load()

# Bundle contents as read here:
#   - the tokenizer
#   - a mapping of label column -> fitted label encoder
#   - the trained weights for ModelClass
tokenizer = bundle["tokenizer"]
label_encoders = bundle["label_encoders"]
model_state_dict = bundle["model_state_dict"]

# Head i is sized from encoder i and is decoded with column i. Both lists
# follow the dict's insertion order, so they stay aligned.
label_columns = list(label_encoders.keys())
num_labels_per_task = [len(le.classes_) for le in label_encoders.values()]

model = ModelClass(num_labels_per_task)
model.load_state_dict(model_state_dict)
model.eval()  # inference only: switch to evaluation mode (e.g. disables dropout)
app = FastAPI()


class Request(BaseModel):
    # JSON body of a prediction request: the raw text to classify.
    text: str


def _to_builtin(value):
    """Convert a NumPy scalar to the equivalent plain Python value.

    FastAPI's JSON encoder rejects some NumPy scalar types (e.g.
    ``numpy.int64``). Values without an ``item()`` method are returned
    unchanged.
    """
    return value.item() if hasattr(value, "item") else value


# NOTE(review): the original ``predict`` had no route decorator, so FastAPI
# exposed no endpoint. "/predict" is an assumed path; confirm it matches what
# clients call.
@app.post("/predict")
def predict(req: Request):
    """Classify ``req.text`` with every task head and return the decoded labels.

    Returns:
        ``{"predictions": {label_column: label, ...}}`` with one entry per
        label encoder. Labels are converted to plain Python values so the
        response can be JSON-encoded.

    Raises:
        HTTPException: Status 500 carrying the error message if tokenization,
            inference, or label decoding fails.
    """
    try:
        inputs = tokenizer(
            req.text, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        with torch.no_grad():
            # Pass only the inputs the model's forward() accepts. Some
            # tokenizers also emit keys such as ``token_type_ids``, which
            # ``**inputs`` would forward and forward() would reject.
            logits = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs.get("attention_mask"),
            )
        # One argmax per head; .item() assumes a single input row.
        preds = [torch.argmax(logit, dim=1).item() for logit in logits]
        decoded = {
            col: _to_builtin(label_encoders[col].inverse_transform([pred])[0])
            for col, pred in zip(label_columns, preds)
        }
    except Exception as e:
        # Request boundary: report any failure to the client as a 500.
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"predictions": decoded}