# Hugging Face Space status header captured during extraction ("Spaces: Running").
| import asyncio | |
| import importlib.util | |
| import os | |
| from contextlib import asynccontextmanager | |
| from typing import Literal | |
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from huggingface_hub import snapshot_download | |
| from pydantic import BaseModel, Field | |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
# Hugging Face auth token (optional; None lets public repos download anonymously).
HF_TOKEN = os.environ.get("HF_TOKEN")
# Repo id for the fast single-pass "Quick Vibe" classifier (overridable via env var).
QUICK_VIBE_MODEL = os.environ.get(
    "QUICK_VIBE_MODEL",
    "itsLu/mentalbert-v5-flat-8class",
)
# Repo id for the hierarchical "Deep Dive" classifier (overridable via env var).
DEEP_DIVE_MODEL = os.environ.get(
    "DEEP_DIVE_MODEL",
    "itsLu/mentalbert-v5-hierarchical-v3",
)
# Maps raw model output labels to the snake_case classification strings the API returns.
LABEL_MAP: dict[str, str] = {
    "Anxiety": "anxiety",
    "Bipolar": "bipolar",
    "Depression": "depression",
    "Directed Aggression": "unhinged",
    "Normal": "normal",
    "Personality Disorder": "personality_disorder",
    "Stress": "stress",
    "Suicidal": "suicidal",
}
# Deep Dive stage label order (matches API_DOCUMENTATION.md §"stage_probs semantics")
STAGE2_LABELS = ["Anxiety", "Bipolar", "Depression", "Personality Disorder", "Stress"]
# Inference runs on CPU only.
DEVICE = torch.device("cpu")
# Tokenizer truncation/padding length for Quick Vibe inputs.
QUICK_VIBE_MAX_LEN = 256
| # Lazy-loaded singletons + per-model locks | |
| _state: dict[str, object] = {} | |
| _load_locks: dict[str, asyncio.Lock] = {} | |
| def _lock(name: str) -> asyncio.Lock: | |
| if name not in _load_locks: | |
| _load_locks[name] = asyncio.Lock() | |
| return _load_locks[name] | |
def _load_quick_vibe_sync() -> dict:
    """Blocking load of the Quick Vibe tokenizer + model (CPU, eval mode).

    Returns a dict with keys "tokenizer" and "model", consumed by
    _run_quick_vibe.
    """
    tok = AutoTokenizer.from_pretrained(QUICK_VIBE_MODEL, token=HF_TOKEN)
    clf = AutoModelForSequenceClassification.from_pretrained(
        QUICK_VIBE_MODEL, token=HF_TOKEN
    )
    clf.to(DEVICE).eval()
    return {"tokenizer": tok, "model": clf}
def _load_deep_dive_sync() -> object:
    """Download the Deep Dive repo and instantiate its bundled EndpointHandler.

    Pulls the entire repo (weights + handler.py + config) into the local HF
    cache, dynamically imports handler.py from the snapshot, and returns an
    EndpointHandler bound to the snapshot directory.

    Raises:
        RuntimeError: if handler.py is missing, cannot be turned into an
            import spec, or does not define EndpointHandler.
    """
    repo_dir = snapshot_download(DEEP_DIVE_MODEL, token=HF_TOKEN)
    handler_path = os.path.join(repo_dir, "handler.py")
    if not os.path.exists(handler_path):
        raise RuntimeError(
            f"handler.py not found in {DEEP_DIVE_MODEL} (looked at {handler_path})"
        )
    spec = importlib.util.spec_from_file_location("dd_handler", handler_path)
    # Explicit checks instead of `assert`: asserts are stripped under -O, and a
    # None spec would otherwise surface as an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise RuntimeError(f"could not create an import spec for {handler_path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    if not hasattr(module, "EndpointHandler"):
        raise RuntimeError("handler.py does not define EndpointHandler")
    return module.EndpointHandler(repo_dir)
async def _get_state(name: str) -> object:
    """Return the lazily-loaded state for model *name*, loading on first use.

    Uses double-checked locking with a per-model asyncio.Lock so concurrent
    requests trigger at most one load, and runs the blocking load in the
    default thread-pool executor to keep the event loop responsive.

    Raises:
        HTTPException: 400 for an unknown model name.
    """
    if name in _state:  # fast path: already loaded, no lock needed
        return _state[name]
    async with _lock(name):
        if name in _state:  # another task finished loading while we waited
            return _state[name]
        # get_running_loop() is the correct call inside a coroutine;
        # get_event_loop() is deprecated in this context since Python 3.10.
        loop = asyncio.get_running_loop()
        if name == "mentalbert":
            _state[name] = await loop.run_in_executor(None, _load_quick_vibe_sync)
        elif name == "longformer":
            _state[name] = await loop.run_in_executor(None, _load_deep_dive_sync)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")
    return _state[name]
def _run_quick_vibe(text: str, state: dict) -> tuple[str, float]:
    """Classify *text* with the Quick Vibe model.

    Returns (raw_label, confidence): the id2label string for the argmax class
    and its softmax probability. Falls back to "Normal" if the argmax index is
    missing from id2label.
    """
    tokenizer, model = state["tokenizer"], state["model"]
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=QUICK_VIBE_MAX_LEN,
    )
    encoded = {key: tensor.to(DEVICE) for key, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    distribution = F.softmax(logits, dim=-1)[0]
    best = int(distribution.argmax().item())
    label = model.config.id2label.get(best, "Normal")
    return label, float(distribution[best].item())
| def _derive_confidence(label: str, exit_stage: str, stage_probs: dict) -> float: | |
| probs = stage_probs.get(exit_stage) or [] | |
| if exit_stage == "stage0": | |
| return float(probs[1]) if label == "Directed Aggression" else float(probs[0]) | |
| if exit_stage == "stage1a": | |
| return float(probs[1]) if label == "Suicidal" else float(probs[0]) | |
| if exit_stage == "stage1b": | |
| return float(probs[0]) if label == "Normal" else float(probs[1]) | |
| if exit_stage == "stage2": | |
| try: | |
| idx = STAGE2_LABELS.index(label) | |
| return float(probs[idx]) | |
| except (ValueError, IndexError): | |
| return 0.0 | |
| if exit_stage == "stage3": | |
| return float(probs[0]) if label == "Depression" else float(probs[1]) | |
| return 0.0 | |
| def _run_deep_dive(text: str, handler, sensitive_mode: bool) -> dict: | |
| payload = {"inputs": text, "mode": "safety" if sensitive_mode else "balanced"} | |
| result = handler(payload) | |
| if not isinstance(result, dict) or "label" not in result: | |
| raise HTTPException( | |
| status_code=502, | |
| detail=f"Unexpected Deep Dive handler output: {type(result).__name__}", | |
| ) | |
| return result | |
@asynccontextmanager
async def lifespan(_app: "FastAPI"):
    """App lifespan: models load lazily on demand; drop cached state on shutdown.

    NOTE(review): the @asynccontextmanager decorator was missing even though it
    is imported at the top of the file; FastAPI's `lifespan=` parameter expects
    an async context manager factory (a bare async generator function is only
    accepted by older Starlette versions, with a deprecation warning).
    """
    yield
    # Shutdown: release references to loaded models/tokenizers so they can be GC'd.
    _state.clear()
# Application instance; `lifespan` clears cached model state on shutdown.
app = FastAPI(title="VibeCheck API", version="5.1.0", lifespan=lifespan)
# Wide-open CORS (any origin, no credentials), limited to the verbs/headers the API uses.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type"],
)
class ClassifyRequest(BaseModel):
    """Request body for text classification."""

    # Free-form user text to classify (must be non-empty after stripping).
    text: str
    # Backend selector: "mentalbert" (Quick Vibe) or "longformer" (Deep Dive).
    model: Literal["mentalbert", "longformer"] = "mentalbert"
    # Deep Dive only: run the handler in "safety" mode instead of "balanced".
    sensitive_mode: bool = Field(default=False)
class ClassifyResponse(BaseModel):
    """Response body for text classification."""

    # snake_case label from LABEL_MAP (falls back to "normal" for unknown raw labels).
    classification: str
    # Probability-derived confidence, rounded to 4 decimal places.
    confidence: float
# NOTE(review): no route decorator was present and nothing else references this
# function, so it was never registered; "/health" is the conventional path —
# confirm against API_DOCUMENTATION.md.
@app.get("/health")
def health():
    """Health probe: report configured model repos and which models are loaded."""
    return {
        "status": "ok",
        "quick_vibe_model": QUICK_VIBE_MODEL,
        "deep_dive_model": DEEP_DIVE_MODEL,
        "loaded": list(_state.keys()),
    }
# NOTE(review): no route decorator was present and nothing else references this
# function, so it was never registered; "/classify" + POST assumed — confirm
# against API_DOCUMENTATION.md.
@app.post("/classify", response_model=ClassifyResponse)
async def classify(req: ClassifyRequest):
    """Classify request text with the selected model.

    Quick Vibe ("mentalbert"): single softmax pass; confidence is the argmax
    probability. Deep Dive ("longformer"): delegates to the repo's
    EndpointHandler and derives confidence from its exit stage. Blocking
    inference runs in the default thread-pool executor.

    Raises:
        HTTPException: 422 for empty text; 400/502 propagated from helpers.
    """
    text = req.text.strip()
    if not text:
        raise HTTPException(status_code=422, detail="text must not be empty")
    state = await _get_state(req.model)
    # get_running_loop() replaces the deprecated-in-coroutine get_event_loop().
    loop = asyncio.get_running_loop()
    if req.model == "mentalbert":
        raw_label, confidence = await loop.run_in_executor(
            None, _run_quick_vibe, text, state
        )
    else:
        result = await loop.run_in_executor(
            None, _run_deep_dive, text, state, req.sensitive_mode
        )
        raw_label = result["label"]
        confidence = _derive_confidence(
            raw_label,
            result.get("exit_stage", ""),
            result.get("stage_probs", {}),
        )
    return ClassifyResponse(
        classification=LABEL_MAP.get(raw_label, "normal"),
        confidence=round(confidence, 4),
    )