# Author: mjpsm
# Commit 1001133 (verified): "Update app.py"
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# The ASGI application. Title, description and version are shown in the
# auto-generated API documentation.
app = FastAPI(
    version="1.0",
    title="Roadblock Detection API",
    description="Detects whether a check-in contains a roadblock",
)
# Hugging Face model id of the fine-tuned roadblock classifier.
MODEL_NAME = "mjpsm/roadblock-classifier-v1"

# Tokenizer and model are loaded once, at import time, and shared by every
# request handled by this module.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so it can be chained; it switches
# layers such as dropout to inference behaviour.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()
# JSON body accepted by POST /predict.
class Request(BaseModel):
    # The check-in text to classify.
    text: str
@app.get("/")
def root():
    # Liveness endpoint: always answers with the same fixed payload so callers
    # can tell the service is up.
    status = {"message": "Roadblock API is live"}
    return status
@app.post("/predict")
def predict(req: Request):
    """Classify whether a check-in text contains a roadblock.

    Args:
        req: Request body; ``req.text`` is the check-in to classify.

    Returns:
        A dict with the echoed ``input`` text, the ``prediction`` label
        (looked up in ``model.config.id2label``) and the softmax
        ``confidence`` of that label, rounded to 3 decimals.
    """
    # Tell the tokenizer not to emit token_type_ids at all, instead of building
    # them and popping them afterwards: they must not reach this model
    # (presumably its forward() does not accept that argument -- TODO confirm).
    # A single string needs no padding; truncation=True caps it at the
    # tokenizer's maximum length.
    inputs = tokenizer(
        req.text,
        return_tensors="pt",
        truncation=True,
        return_token_type_ids=False,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # The batch holds exactly one sequence, so work on row 0. Softmax is
    # monotonic, so the argmax of the raw logits is the most probable class.
    row = logits[0]
    pred = int(torch.argmax(row).item())
    confidence = torch.softmax(row, dim=-1)[pred].item()
    label = model.config.id2label[pred]

    return {
        "input": req.text,
        "prediction": label,
        "confidence": round(confidence, 3),
    }