Spaces:
Running
Running
File size: 1,200 Bytes
c08c05a 1001133 c08c05a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Application object. FastAPI publishes the version, title and description
# below in its generated OpenAPI schema and interactive docs pages.
app = FastAPI(
    version="1.0",
    title="Roadblock Detection API",
    description="Detects whether a check-in contains a roadblock",
)

# Hugging Face Hub checkpoint served by this API.
MODEL_NAME = "mjpsm/roadblock-classifier-v1"

# Tokenizer and model are loaded once, when the module is imported, and are
# then reused by every request handler below.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# `.eval()` switches the module to evaluation mode (e.g. disables dropout)
# and returns the same module, so it can be chained onto the load call.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()
# Request body for POST /predict: a single check-in message to classify.
# Documented with comments rather than a docstring because pydantic would
# publish a class docstring as the schema description in /openapi.json.
class Request(BaseModel):
    # Check-in text; handed unchanged to the tokenizer in `predict`.
    text: str
# GET /: liveness endpoint. It always answers with the same fixed message,
# so callers can confirm the service is up.
# Documented with comments rather than a docstring because FastAPI would
# publish a docstring as this operation's OpenAPI description.
@app.get("/")
def root():
    status = {"message": "Roadblock API is live"}
    return status
# POST /predict: classify one check-in message.
#
# The response is a JSON object with three fields:
#   "input"      - the original text, echoed back
#   "prediction" - the label name looked up in model.config.id2label
#   "confidence" - the softmax probability of that label, rounded to 3 decimals
# Documented with comments rather than a docstring because FastAPI would
# publish a docstring as this operation's OpenAPI description.
@app.post("/predict")
def predict(req: Request):
    encoded = tokenizer(
        req.text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    # Remove segment ids before the forward pass.
    # NOTE(review): presumably the checkpoint's forward() does not accept
    # `token_type_ids` (or they are unwanted). Confirm against the model
    # architecture.
    encoded.pop("token_type_ids", None)

    # No gradients are needed for inference.
    with torch.no_grad():
        scores = model(**encoded).logits

    # Assumes logits are (batch, num_labels) with batch == 1 here, so the
    # single prediction lives in row 0. TODO confirm.
    distribution = scores.softmax(dim=1)
    best = distribution.argmax(dim=1).item()

    return {
        "input": req.text,
        "prediction": model.config.id2label[best],
        "confidence": round(distribution[0][best].item(), 3),
    }