Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI | |
| from pydantic import BaseModel | |
| import torch | |
| import os | |
| from transformers import RobertaTokenizer | |
class InputData(BaseModel):
    """JSON request body accepted by the prediction handler.

    Attributes:
        text: The raw input string; it is passed unchanged to the tokenizer.
    """

    text: str
app = FastAPI()

# Load model.
# The checkpoint is a whole pickled model object (it is called and put in eval
# mode below), not a state_dict.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes such as a full nn.Module. Opt out explicitly.
# The weights_only keyword requires torch >= 1.13.
# SECURITY: unpickling can execute arbitrary code. Only ever load this bundled,
# trusted file this way, never a user-supplied path.
model_path = os.path.join(os.path.dirname(__file__), "roberta_model.pkl")
model = torch.load(
    model_path,
    map_location=torch.device("cpu"),
    weights_only=False,
)
model.eval()  # inference mode (e.g. disables dropout)

# Load tokenizer. Fetched/cached by name via transformers' from_pretrained.
# NOTE(review): assumes the model was fine-tuned with the stock roberta-base
# vocabulary. Confirm against the training setup.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
@app.get("/")
async def root():
    """Liveness check: report that the service is up."""
    # Registered on `app`; without the decorator FastAPI never exposes it.
    return {"message": "RoBERTa FastAPI Space is running!"}
def _predict_label(text):
    """Tokenize *text*, run the model, and return the arg-max class index.

    Synchronous and blocking; `predict` runs it in a worker thread.
    ``torch.no_grad()`` is thread-local, so it must be entered here, in the
    thread that actually performs the forward pass.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # Assumes the loaded model returns an object exposing `.logits`, as Hugging
    # Face sequence-classification heads do. TODO confirm for this checkpoint.
    logits = outputs.logits
    # Presumably (batch, num_labels) with batch == 1, so .item() yields one int.
    return torch.argmax(logits, dim=1).item()


@app.post("/predict")
async def predict(data: InputData):
    """Classify ``data.text`` and return ``{"prediction": <class index>}``."""
    import asyncio

    loop = asyncio.get_running_loop()
    # Tokenization and the forward pass are blocking CPU work. Running them in
    # the default executor keeps the event loop free for other requests.
    prediction = await loop.run_in_executor(None, _predict_label, data.text)
    return {"prediction": prediction}