# Uploaded by mjpsm ("Upload 3 files"), commit 950cef0 (verified).
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Application instance; the title/description/version keywords are the
# metadata FastAPI shows in its generated API docs.
app = FastAPI(
    version="1.0",
    description="Classifies check-ins as DETAILED or NOT_DETAILED",
    title="Check-in Detail Classifier API",
)
# Load the tokenizer and classifier a single time at import, so every request
# reuses the same objects instead of reloading them per call.
MODEL_NAME = "mjpsm/checkin-detail-classifier"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# nn.Module.eval() returns the module itself, so it can be chained here; it
# puts the model in inference mode (e.g. disables dropout).
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME).eval()
# Request schema
class Request(BaseModel):
    """JSON body accepted by ``POST /predict``."""

    # The raw check-in text to classify.
    text: str
# Root route
@app.get("/")
def root():
    """Return a fixed welcome message for the API root."""
    greeting = "Welcome to the Check-in Detail Classifier API"
    return dict(message=greeting)
# Classification logic
def classify(text: str):
    """Classify one check-in text with the module-level tokenizer and model.

    Args:
        text: Raw check-in text. Input longer than the tokenizer's maximum
            length is truncated.

    Returns:
        A ``(label, confidence)`` tuple: the label name looked up in
        ``model.config.id2label`` and the softmax probability of that label.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    # DistilBERT fix: DistilBERT's forward() takes no token_type_ids, so drop
    # them if the tokenizer produced any.
    inputs.pop("token_type_ids", None)

    with torch.no_grad():
        outputs = model(**inputs)

    # A single string yields a batch of one row. Take that row and reduce over
    # the class axis explicitly. A bare argmax() would index the flattened
    # tensor, which is only correct by accident for batch size 1.
    probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
    pred = int(torch.argmax(probs, dim=-1).item())
    confidence = probs[pred].item()
    label = model.config.id2label[pred]
    return label, confidence
# Predict endpoint
@app.post("/predict")
def predict(req: Request):
    """Classify ``req.text`` and echo it back with the label and confidence.

    The confidence is rounded to four decimal places in the response.
    """
    submitted = req.text
    prediction, score = classify(submitted)
    # Keys are inserted in the same order as before: input, prediction, confidence.
    payload = {"input": submitted, "prediction": prediction}
    payload["confidence"] = round(score, 4)
    return payload