"""FastAPI inference service for a sequence-classification "router" model.

Loads a Hugging Face checkpoint once at import time and exposes:
  GET  /health  -- liveness probe reporting model id and device
  POST /label   -- classify a text; returns top label, score, optional
                   full probability map and top-k list
"""

import os
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Deployment configuration via environment variables.
MODEL_ID = os.getenv("MODEL_ID", "ethnmcl/mvp-router-roberta")
MAX_LENGTH = int(os.getenv("MAX_LENGTH", "128"))
HF_TOKEN = os.getenv("HF_TOKEN", None)  # optional: private-repo access token

# Load tokenizer and model once at startup; inference only, so eval mode.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_ID, token=HF_TOKEN
).to(device)
model.eval()

# id2label keys may be ints or strings depending on how the config was
# serialized (JSON round-trips turn int keys into str), hence _label below.
id2label = getattr(model.config, "id2label", {}) or {}


def _label(i: int) -> str:
    """Map class index *i* to its label, tolerating int or str keys.

    Falls back to ``str(i)`` when the index is missing from id2label.
    """
    return id2label.get(i) or id2label.get(str(i)) or str(i)


app = FastAPI(title="MVP Router API", version="1.0.0")


class LabelRequest(BaseModel):
    # text: input to classify; return_probs/top_k control response verbosity.
    text: str
    return_probs: bool = True
    top_k: int = 5


class LabelResponse(BaseModel):
    # probs/top_k are omitted (None) unless requested / computed.
    model_id: str
    label: str
    score: float
    probs: Optional[Dict[str, float]] = None
    top_k: Optional[List[Dict[str, Any]]] = None


@app.get("/health")
def health():
    """Liveness probe: reports status, served model id, and compute device."""
    return {"status": "ok", "model_id": MODEL_ID, "device": device}


@app.post("/label", response_model=LabelResponse)
def label(req: LabelRequest):
    """Classify ``req.text`` and return the predicted label with scores.

    Empty or whitespace-only input yields a sentinel INVALID_INPUT response
    (HTTP 200) rather than an error, preserving the existing API contract.
    """
    text = req.text.strip()
    if not text:
        return LabelResponse(model_id=MODEL_ID, label="INVALID_INPUT", score=0.0)

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LENGTH,
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Single-example batch: squeeze to a 1-D class-probability vector.
    probs = F.softmax(logits, dim=-1).squeeze(0)
    probs_np = probs.detach().cpu().numpy()

    pred_id = int(probs_np.argmax())
    pred_label = _label(pred_id)
    pred_score = float(probs_np[pred_id])

    # Clamp top_k into [1, num_classes] so out-of-range client values are safe.
    k = max(1, min(int(req.top_k), probs_np.shape[0]))
    top_idx = probs_np.argsort()[::-1][:k]
    top_k = [
        {"label": _label(int(i)), "score": float(probs_np[int(i)])} for i in top_idx
    ]

    probs_dict = None
    if req.return_probs:
        probs_dict = {_label(i): float(probs_np[i]) for i in range(probs_np.shape[0])}

    return LabelResponse(
        model_id=MODEL_ID,
        label=pred_label,
        score=pred_score,
        probs=probs_dict,
        top_k=top_k,
    )