# vibecheck-api / app.py
# Last commit: chore(api): remove /classify/debug diagnostic endpoint (0ae9c14, itsLu)
import asyncio
import importlib.util
import os
from contextlib import asynccontextmanager
from typing import Literal
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Optional Hugging Face auth token for private model repos (None -> anonymous).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Quick Vibe: flat 8-class sequence classifier; repo id overridable via env.
QUICK_VIBE_MODEL = os.environ.get(
    "QUICK_VIBE_MODEL",
    "itsLu/mentalbert-v5-flat-8class",
)
# Deep Dive: hierarchical pipeline repo that ships its own handler.py.
DEEP_DIVE_MODEL = os.environ.get(
    "DEEP_DIVE_MODEL",
    "itsLu/mentalbert-v5-hierarchical-v3",
)
# Raw model label -> public API snake_case classification vocabulary.
LABEL_MAP: dict[str, str] = {
    "Anxiety": "anxiety",
    "Bipolar": "bipolar",
    "Depression": "depression",
    "Directed Aggression": "unhinged",
    "Normal": "normal",
    "Personality Disorder": "personality_disorder",
    "Stress": "stress",
    "Suicidal": "suicidal",
}
# Deep Dive stage label order (matches API_DOCUMENTATION.md §"stage_probs semantics")
STAGE2_LABELS = ["Anxiety", "Bipolar", "Depression", "Personality Disorder", "Stress"]
# CPU-only inference for this deployment.
DEVICE = torch.device("cpu")
# Token cap for Quick Vibe inputs; longer texts are truncated by the tokenizer.
QUICK_VIBE_MAX_LEN = 256
# Lazy-loaded singletons + per-model locks
_state: dict[str, object] = {}
_load_locks: dict[str, asyncio.Lock] = {}
def _lock(name: str) -> asyncio.Lock:
if name not in _load_locks:
_load_locks[name] = asyncio.Lock()
return _load_locks[name]
def _load_quick_vibe_sync() -> dict:
    """Blocking load of the Quick Vibe tokenizer and classifier onto CPU.

    Meant to run in an executor thread (see _get_state) so the event loop
    is not blocked while weights download/deserialize.

    Returns:
        dict with keys "tokenizer" and "model" (model in eval mode on DEVICE).
    """
    tok = AutoTokenizer.from_pretrained(QUICK_VIBE_MODEL, token=HF_TOKEN)
    clf = AutoModelForSequenceClassification.from_pretrained(
        QUICK_VIBE_MODEL, token=HF_TOKEN
    )
    clf.to(DEVICE).eval()
    return {"tokenizer": tok, "model": clf}
def _load_deep_dive_sync() -> object:
    """Blocking load of the Deep Dive pipeline from its Hub repo.

    Pulls the entire repo (weights + handler.py + config) into the local
    cache, imports the repo's handler.py as a module, and instantiates its
    EndpointHandler rooted at the cached repo directory.

    Raises:
        RuntimeError: if handler.py is missing, cannot be imported, or does
            not define EndpointHandler.
    """
    repo_dir = snapshot_download(DEEP_DIVE_MODEL, token=HF_TOKEN)
    handler_path = os.path.join(repo_dir, "handler.py")
    if not os.path.exists(handler_path):
        raise RuntimeError(
            f"handler.py not found in {DEEP_DIVE_MODEL} (looked at {handler_path})"
        )
    spec = importlib.util.spec_from_file_location("dd_handler", handler_path)
    # Explicit checks instead of `assert`: asserts are stripped under -O,
    # and spec_from_file_location may itself return None.
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not build an import spec for {handler_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if not hasattr(module, "EndpointHandler"):
        raise RuntimeError("handler.py does not define EndpointHandler")
    return module.EndpointHandler(repo_dir)
async def _get_state(name: str) -> object:
    """Return the loaded state for *name*, lazily loading it on first use.

    Uses double-checked locking: the fast path avoids the lock entirely,
    and the re-check inside the lock covers tasks that were waiting while
    another task finished the load.

    Raises:
        HTTPException(400): for an unrecognized model name.
    """
    if name in _state:
        return _state[name]
    async with _lock(name):
        if name in _state:
            return _state[name]
        # get_running_loop() is the supported API inside a coroutine;
        # get_event_loop() is deprecated here since Python 3.10.
        loop = asyncio.get_running_loop()
        if name == "mentalbert":
            _state[name] = await loop.run_in_executor(None, _load_quick_vibe_sync)
        elif name == "longformer":
            _state[name] = await loop.run_in_executor(None, _load_deep_dive_sync)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")
    return _state[name]
def _run_quick_vibe(text: str, state: dict) -> tuple[str, float]:
    """Classify *text* with the Quick Vibe model.

    Args:
        text: input string to classify.
        state: dict with "tokenizer" and "model" (see _load_quick_vibe_sync).

    Returns:
        (raw_label, confidence) where raw_label comes from the model's
        id2label map ("Normal" fallback) and confidence is the softmax
        probability of the argmax class.
    """
    tok = state["tokenizer"]
    clf = state["model"]
    encoded = tok(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=QUICK_VIBE_MAX_LEN,
    )
    encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = clf(**encoded).logits
        probs = F.softmax(logits, dim=-1)[0]
        best = int(torch.argmax(probs).item())
    label = clf.config.id2label.get(best, "Normal")
    return label, float(probs[best].item())
def _derive_confidence(label: str, exit_stage: str, stage_probs: dict) -> float:
probs = stage_probs.get(exit_stage) or []
if exit_stage == "stage0":
return float(probs[1]) if label == "Directed Aggression" else float(probs[0])
if exit_stage == "stage1a":
return float(probs[1]) if label == "Suicidal" else float(probs[0])
if exit_stage == "stage1b":
return float(probs[0]) if label == "Normal" else float(probs[1])
if exit_stage == "stage2":
try:
idx = STAGE2_LABELS.index(label)
return float(probs[idx])
except (ValueError, IndexError):
return 0.0
if exit_stage == "stage3":
return float(probs[0]) if label == "Depression" else float(probs[1])
return 0.0
def _run_deep_dive(text: str, handler, sensitive_mode: bool) -> dict:
payload = {"inputs": text, "mode": "safety" if sensitive_mode else "balanced"}
result = handler(payload)
if not isinstance(result, dict) or "label" not in result:
raise HTTPException(
status_code=502,
detail=f"Unexpected Deep Dive handler output: {type(result).__name__}",
)
return result
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """App lifespan: nothing to do at startup (models load lazily on first
    request); clear the cached model state at shutdown."""
    yield
    _state.clear()
app = FastAPI(title="VibeCheck API", version="5.1.0", lifespan=lifespan)
# Open CORS for browser clients: any origin, but credentials disabled and
# methods/headers limited to what this API actually uses.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type"],
)
class ClassifyRequest(BaseModel):
    """Input payload for POST /classify."""

    # Text to classify; whitespace-only input is rejected with 422.
    text: str
    # "mentalbert" = Quick Vibe flat classifier, "longformer" = Deep Dive pipeline.
    model: Literal["mentalbert", "longformer"] = "mentalbert"
    # Deep Dive only: run the handler in "safety" mode instead of "balanced".
    sensitive_mode: bool = Field(default=False)
class ClassifyResponse(BaseModel):
    """Response for POST /classify: public snake_case label + confidence
    rounded to 4 decimal places."""

    classification: str
    confidence: float
@app.get("/")
def health():
    """Health/status probe: reports configured model repo ids and which
    models have been lazily loaded so far (keys of the _state cache)."""
    return {
        "status": "ok",
        "quick_vibe_model": QUICK_VIBE_MODEL,
        "deep_dive_model": DEEP_DIVE_MODEL,
        "loaded": list(_state.keys()),
    }
@app.post("/classify", response_model=ClassifyResponse)
async def classify(req: ClassifyRequest):
    """Classify req.text with the requested model.

    - "mentalbert": Quick Vibe flat 8-class classifier.
    - "longformer": Deep Dive hierarchical pipeline (honors sensitive_mode,
      deriving confidence from the exit stage's probability vector).

    Inference runs in a thread executor so the event loop stays responsive.

    Raises:
        HTTPException(422): for empty/whitespace-only text.
    """
    text = req.text.strip()
    if not text:
        raise HTTPException(status_code=422, detail="text must not be empty")
    state = await _get_state(req.model)
    # get_running_loop() is the supported API inside a coroutine;
    # get_event_loop() is deprecated here since Python 3.10.
    loop = asyncio.get_running_loop()
    if req.model == "mentalbert":
        raw_label, confidence = await loop.run_in_executor(
            None, _run_quick_vibe, text, state
        )
    else:
        result = await loop.run_in_executor(
            None, _run_deep_dive, text, state, req.sensitive_mode
        )
        raw_label = result["label"]
        confidence = _derive_confidence(
            raw_label,
            result.get("exit_stage", ""),
            result.get("stage_probs", {}),
        )
    # Map raw model labels to the public snake_case vocabulary; unknown
    # labels fall back to "normal".
    return ClassifyResponse(
        classification=LABEL_MAP.get(raw_label, "normal"),
        confidence=round(confidence, 4),
    )