# mvp-router-api / app.py
# (Hugging Face Space page header — author: ethnmcl, "Update app.py", commit e8c4436 verified)
import os
from typing import Dict, Any, Optional, List
import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Configuration is pulled from the environment so the deployment can be
# re-pointed at a different model without code changes.
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/mvp-router-roberta")  # HF Hub repo of the classifier
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "128"))  # tokenizer truncation length
HF_TOKEN = os.getenv("HF_TOKEN", None)  # optional auth token for private/gated repos

# Load tokenizer and model once at import time; prefer GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID, token=HF_TOKEN).to(device)
model.eval()  # inference mode: disables dropout etc.
# Index -> human-readable label map from the model config; may be absent or
# empty, in which case _label() falls back to the numeric id as a string.
id2label = getattr(model.config, "id2label", {}) or {}
def _label(i: int) -> str:
    """Resolve class index *i* to its configured name, falling back to str(i).

    Checks both the int and string forms of the key, since id2label maps
    loaded from JSON configs often use string keys.
    """
    for key in (i, str(i)):
        name = id2label.get(key)
        if name:
            return name
    return str(i)
# FastAPI application exposing the /health and /label endpoints.
app = FastAPI(title="MVP Router API", version="1.0.0")
class LabelRequest(BaseModel):
    """Request body for POST /label."""
    text: str  # input text to classify
    return_probs: bool = True  # include the full per-label probability map
    top_k: int = 5  # number of highest-scoring labels to return (clamped server-side)
class LabelResponse(BaseModel):
    """Response body for POST /label."""
    model_id: str  # model that produced the prediction
    label: str  # top predicted label ("INVALID_INPUT" for blank text)
    score: float  # probability of the top label (0.0 for blank text)
    probs: Optional[Dict[str, float]] = None  # full label -> probability map, if requested
    top_k: Optional[List[Dict[str, Any]]] = None  # ranked [{"label", "score"}] entries
@app.get("/health")
def health():
    """Liveness probe: reports the serving model id and compute device."""
    payload = {"status": "ok"}
    payload["model_id"] = MODEL_ID
    payload["device"] = device
    return payload
@app.post("/label", response_model=LabelResponse)
def label(req: LabelRequest):
    """Classify req.text and return the top label, its score, and rankings.

    Blank (or whitespace-only) input short-circuits to a sentinel
    INVALID_INPUT response with score 0.0 rather than an HTTP error.
    """
    cleaned = req.text.strip()
    if not cleaned:
        return LabelResponse(model_id=MODEL_ID, label="INVALID_INPUT", score=0.0)

    # Tokenize with truncation so arbitrarily long input stays within MAX_LENGTH.
    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH
    ).to(device)

    with torch.no_grad():
        raw_logits = model(**encoded).logits

    # Single-example batch: squeeze to a 1-D probability vector on the CPU.
    distribution = F.softmax(raw_logits, dim=-1).squeeze(0).detach().cpu().numpy()
    n_classes = distribution.shape[0]

    best_id = int(distribution.argmax())
    best_score = float(distribution[best_id])

    # Clamp the requested k into [1, n_classes] and rank descending.
    k = max(1, min(int(req.top_k), n_classes))
    ranked_ids = distribution.argsort()[::-1][:k]
    ranking = [
        {"label": _label(int(idx)), "score": float(distribution[int(idx)])}
        for idx in ranked_ids
    ]

    full_probs = (
        {_label(i): float(distribution[i]) for i in range(n_classes)}
        if req.return_probs
        else None
    )

    return LabelResponse(
        model_id=MODEL_ID,
        label=_label(best_id),
        score=best_score,
        probs=full_probs,
        top_k=ranking
    )