subbu / app.py
subbu123456's picture
Upload 4 files
0b761ae verified
raw
history blame contribute delete
714 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer
app = FastAPI()

# Load tokenizer and model once at import time so every request reuses them.
tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

# NOTE(security): torch.load unpickles arbitrary Python objects and can run code
# embedded in the file. This file holds a whole pickled model object (it is
# called and put in eval mode below), not just a state_dict, so it must be loaded
# with weights_only=False. Since PyTorch 2.6 the default is weights_only=True,
# which rejects such files. The keyword itself requires PyTorch >= 1.13.
# Only ever load this from a trusted, locally shipped artifact; never from a
# user-supplied path.
model = torch.load(
    "roberta_model.pkl",
    map_location=torch.device("cpu"),
    weights_only=False,
)
# Switch to evaluation mode for inference (presumably an nn.Module, so this
# turns off training-only behaviour such as dropout).
model.eval()
# Request payload for the POST /predict endpoint.
# Comments are used instead of a docstring on purpose: pydantic copies a class
# docstring into the generated OpenAPI schema.
class RequestBody(BaseModel):
    # Raw input text that is tokenized and scored by the model.
    text: str
@app.post("/predict")
def predict(data: RequestBody):
    """Score the submitted text with the loaded model.

    The text is tokenized, run through the model without gradient tracking,
    and a sigmoid is applied to every logit.

    Returns a JSON object ``{"predictions": [[...], ...]}`` holding the sigmoid
    scores as nested lists.
    """
    # A single string is tokenized; truncation=True caps it at the tokenizer's
    # maximum length, and padding=True has no effect for one sequence.
    inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model(**inputs)

    if hasattr(outputs, "logits"):
        # Hugging Face ModelOutput-style result.
        logits = outputs.logits
    elif isinstance(outputs, (tuple, list)):
        # return_dict=False style result. With no labels passed, HF models put
        # the logits first; assumed to hold for this model as well.
        logits = outputs[0]
    else:
        # Model returns the logits tensor directly.
        logits = outputs

    # Element-wise sigmoid: presumably independent per-label probabilities for
    # a multi-label head, with shape (batch=1, num_labels). TODO confirm.
    preds = torch.sigmoid(logits).cpu().numpy().tolist()
    return {"predictions": preds}