import asyncio
import glob
import importlib.util
import inspect
import json
import logging
import os
import re
from contextlib import asynccontextmanager
from typing import Literal

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from huggingface_hub import snapshot_download
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s | %(message)s",
)
logger = logging.getLogger("vibecheck")

# Response headers that defeat browser/proxy caching on every endpoint.
NO_CACHE_HEADERS = {
    "Cache-Control": "no-store, no-cache, must-revalidate, max-age=0",
    "Pragma": "no-cache",
    "Expires": "0",
}

# Hugging Face auth + model repo ids, all overridable via environment.
HF_TOKEN = os.environ.get("HF_TOKEN")
QUICK_VIBE_MODEL = os.environ.get(
    "QUICK_VIBE_MODEL",
    "itsLu/mentalbert-v6-flat",
)
DEEP_DIVE_MODEL = os.environ.get(
    "DEEP_DIVE_MODEL",
    "itsLu/mentalbert-v6-hierarchical",
)

# Operational flags. All off / restrictive by default; flip via Space secrets.
DEBUG_ENDPOINTS = os.environ.get("DEBUG_ENDPOINTS", "false").strip().lower() == "true"
MAX_INPUT_CHARS = int(os.environ.get("MAX_INPUT_CHARS", "4096"))

# CORS allowlist. Comma-separated origins via env var; sensible defaults
# cover local dev + the current Vercel deployment. Add new origins by
# setting the ALLOWED_ORIGINS secret on the Space — no code change needed.
# Default CORS origins: local dev ports plus the production Vercel app.
_DEFAULT_ALLOWED_ORIGINS = (
    "http://localhost:3000,"
    "http://localhost:3001,"
    "https://vibecheck-eosin.vercel.app"
)
ALLOWED_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", _DEFAULT_ALLOWED_ORIGINS).split(",")
    if origin.strip()
]

# Raw model class name -> API-facing classification string.
LABEL_MAP: dict[str, str] = {
    "Anxiety": "anxiety",
    "Bipolar": "bipolar",
    "Depression": "depression",
    "Directed Aggression": "unhinged",
    "Normal": "normal",
    "Personality Disorder": "personality_disorder",
    "Stress": "stress",
    "Suicidal": "suicidal",
}

# Hardcoded fallback for the flat 8-class head — alphabetical, matches
# sklearn.preprocessing.LabelEncoder applied to the v5/v6 class set.
# Used only when the model's own config + repo files don't supply labels.
FALLBACK_QUICK_VIBE_LABELS = [
    "Anxiety",
    "Bipolar",
    "Depression",
    "Directed Aggression",
    "Normal",
    "Personality Disorder",
    "Stress",
    "Suicidal",
]

# Deep Dive stage label order (matches API_DOCUMENTATION.md §"stage_probs semantics")
STAGE2_LABELS = ["Anxiety", "Bipolar", "Depression", "Personality Disorder", "Stress"]

# Explicit-threat pre-filter. Matches (1st-person modal) + (violent verb) +
# (2nd/3rd-person target). Skipping the model for cases that should never
# require it. Carefully NOT matching:
#   - "I want to kill myself"        (target "myself" not in target list)
#   - "I want to kiss my friend"     (verb "kiss" not in verb list)
#   - "killing me softly with this song" (no modal verb prefix)
EXPLICIT_THREAT_PATTERN = re.compile(
    r"\b(want\s+to|wanna|gonna|going\s+to|gotta|will|need\s+to|finna|about\s+to|tryna)\s+"
    r"(kill|murder|hurt|harm|beat|stab|shoot|attack|strangle|choke|smash|bash|destroy|punch)\s+"
    r"(my|that|the|him|her|them|those|these|you|this)\b",
    flags=re.IGNORECASE,
)

# Sensitive-mode override on the Deep Dive cascade: if Stage 0's
# P(Directed Aggression) clears this bar but the handler's own
# threshold didn't fire, the proxy promotes the label to DA.
# Balanced mode unchanged — proxy trusts whatever the handler returns.
# Stage 0 P(DA) floor used by the Sensitive-mode promotion (see _sensitive_override).
SENSITIVE_T0_OVERRIDE = 0.25


def _resolve_quick_vibe_labels(repo_dir: str, model_config) -> tuple[list[str], str]:
    """Resolve the index→class-name list for the flat head.

    Order of preference:
      1. model.config.id2label (if not the LABEL_N placeholders)
      2. label_encoder.joblib in the model repo (sklearn LabelEncoder)
      3. inference_config.json with a "classes" array
      4. hardcoded 8-class alphabetical fallback

    Returns:
        (labels, source) so /diag can surface which path won.
    """
    # 1. Trust the model's own config if it has real labels.
    cfg = getattr(model_config, "id2label", None) or {}
    if cfg:
        ordered = [cfg[k] for k in sorted(cfg.keys(), key=lambda x: int(x))]
        if ordered and not all(str(v).startswith("LABEL_") for v in ordered):
            return ordered, "model.config.id2label"

    # 2. sklearn LabelEncoder pickled into the repo.
    le_path = os.path.join(repo_dir, "label_encoder.joblib")
    if os.path.isfile(le_path):
        try:
            # Imported lazily: joblib is only needed when this file exists.
            import joblib

            le = joblib.load(le_path)
            classes = [str(c) for c in list(le.classes_)]
            if classes:
                return classes, "label_encoder.joblib"
        except Exception as e:
            logger.warning("[labels] failed to load label_encoder.joblib: %s", e)

    # 3. inference_config.json with a "classes" array.
    ic_path = os.path.join(repo_dir, "inference_config.json")
    if os.path.isfile(ic_path):
        try:
            with open(ic_path) as f:
                ic = json.load(f)
            classes = ic.get("classes")
            if isinstance(classes, list) and classes:
                return [str(c) for c in classes], "inference_config.json"
        except Exception as e:
            logger.warning("[labels] failed to load inference_config.json: %s", e)

    # 4. Hardcoded fallback.
    return list(FALLBACK_QUICK_VIBE_LABELS), "hardcoded_fallback"


DEVICE = torch.device("cpu")
QUICK_VIBE_MAX_LEN = 256

# Lazy-loaded singletons + per-model locks
_state: dict[str, object] = {}
_load_locks: dict[str, asyncio.Lock] = {}


def _lock(name: str) -> asyncio.Lock:
    """Return (creating on first use) the asyncio load lock for *name*."""
    if name not in _load_locks:
        _load_locks[name] = asyncio.Lock()
    return _load_locks[name]


def _load_quick_vibe_sync() -> dict:
    """Blocking load of the Quick Vibe tokenizer + flat classification head.

    Returns a state dict consumed by _run_quick_vibe / _debug_quick_vibe / /diag.
    """
    logger.info("[load] Quick Vibe begin: %s (pid=%d)", QUICK_VIBE_MODEL, os.getpid())
    # snapshot_download so we can inspect label_encoder.joblib /
    # inference_config.json alongside the weights
    repo_dir = snapshot_download(QUICK_VIBE_MODEL, token=HF_TOKEN)
    tokenizer = AutoTokenizer.from_pretrained(repo_dir)
    model = AutoModelForSequenceClassification.from_pretrained(repo_dir)
    model.to(DEVICE).eval()
    labels, labels_source = _resolve_quick_vibe_labels(repo_dir, model.config)
    n_params = sum(p.numel() for p in model.parameters())
    logger.info(
        "[load] Quick Vibe ready: %s | params=%s | repo_dir=%s | labels_source=%s | labels=%s",
        QUICK_VIBE_MODEL,
        f"{n_params:,}",
        repo_dir,
        labels_source,
        labels,
    )
    if len(labels) != getattr(model.config, "num_labels", len(labels)):
        logger.warning(
            "[labels] decoder list length (%d) != model.num_labels (%d) — predictions will be wrong",
            len(labels),
            model.config.num_labels,
        )
    return {
        "tokenizer": tokenizer,
        "model": model,
        "params": n_params,
        "labels": labels,
        "labels_source": labels_source,
        "repo_dir": repo_dir,
    }


def _load_deep_dive_sync() -> object:
    """Blocking load of the Deep Dive cascade via the repo's own handler.py.

    Raises:
        RuntimeError: if the repo lacks handler.py or an EndpointHandler class.
    """
    logger.info("[load] Deep Dive begin: %s (pid=%d)", DEEP_DIVE_MODEL, os.getpid())
    # Pull the entire repo (weights + handler.py + config) into local cache
    repo_dir = snapshot_download(DEEP_DIVE_MODEL, token=HF_TOKEN)
    handler_path = os.path.join(repo_dir, "handler.py")
    if not os.path.exists(handler_path):
        raise RuntimeError(
            f"handler.py not found in {DEEP_DIVE_MODEL} (looked at {handler_path})"
        )
    spec = importlib.util.spec_from_file_location("dd_handler", handler_path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    spec.loader.exec_module(module)
    if not hasattr(module, "EndpointHandler"):
        raise RuntimeError("handler.py does not define EndpointHandler")
    handler = module.EndpointHandler(repo_dir)
    logger.info(
        "[load] Deep Dive ready: %s | repo_dir=%s | handler=%s",
        DEEP_DIVE_MODEL,
        repo_dir,
        type(handler).__name__,
    )
    return handler


def _list_hf_cache_models() -> list[str]:
    """List model repo dirs currently sitting in the local HF hub cache."""
    cache_root = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
    hub_dir = os.path.join(cache_root, "hub")
    if not os.path.isdir(hub_dir):
        return []
    return sorted(
        os.path.basename(p) for p in glob.glob(os.path.join(hub_dir, "models--*"))
    )


async def _get_state(name: str) -> object:
    """Return the loaded state for *name*, loading it (once) on first use.

    Double-checked locking: lock-free fast path, re-check under the per-model
    lock so concurrent first requests trigger exactly one load.

    Raises:
        HTTPException(400): for an unknown model name.
    """
    if name in _state:
        return _state[name]
    async with _lock(name):
        if name in _state:
            return _state[name]
        # We are inside a coroutine, so get_running_loop() is the correct
        # (non-deprecated) way to reach the loop for run_in_executor.
        loop = asyncio.get_running_loop()
        if name == "mentalbert":
            _state[name] = await loop.run_in_executor(None, _load_quick_vibe_sync)
        elif name == "longformer":
            _state[name] = await loop.run_in_executor(None, _load_deep_dive_sync)
        else:
            raise HTTPException(status_code=400, detail=f"Unknown model: {name}")
    return _state[name]


def _run_quick_vibe(text: str, state: dict) -> tuple[str, float]:
    """One forward pass of the flat head; returns (raw_label, confidence)."""
    tokenizer = state["tokenizer"]
    model = state["model"]
    labels = state["labels"]
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=QUICK_VIBE_MAX_LEN,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
        probs = F.softmax(logits, dim=-1)[0]
    idx = int(torch.argmax(probs).item())
    # Defensive decode: out-of-range argmax falls back to a LABEL_N placeholder.
    raw_label = labels[idx] if 0 <= idx < len(labels) else f"LABEL_{idx}"
    return raw_label, float(probs[idx].item())


def _debug_quick_vibe(text: str, state: dict) -> dict:
    """Like _run_quick_vibe but returns the full logits/probs/label breakdown."""
    tokenizer = state["tokenizer"]
    model = state["model"]
    labels = state["labels"]
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=QUICK_VIBE_MAX_LEN,
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        logits_t = model(**inputs).logits
        probs_t = F.softmax(logits_t, dim=-1)[0]
    logits = [round(x, 6) for x in logits_t[0].tolist()]
    probs = [round(x, 6) for x in probs_t.tolist()]
    idx = int(torch.argmax(probs_t).item())
    raw_label = labels[idx] if 0 <= idx < len(labels) else f"LABEL_{idx}"
    return {
        "model": "mentalbert",
        "logits": logits,
        "probs": probs,
        "argmax_idx": idx,
        "raw_label": raw_label,
        "classification": LABEL_MAP.get(raw_label, "normal"),
        "confidence": round(float(probs_t[idx].item()), 4),
        "labels_in_order": labels,
        "labels_source": state.get("labels_source"),
    }


def _derive_confidence(label: str, exit_stage: str, stage_probs: dict) -> float:
    """Map the handler's exit stage + label to P(label) from that stage's probs.

    Defensive: the stage_probs shape comes from an externally-loaded handler,
    so any missing/short probability vector yields 0.0 instead of IndexError.
    """
    probs = stage_probs.get(exit_stage) or []

    def _p(i: int) -> float:
        # Safe indexed read of the stage probability vector.
        try:
            return float(probs[i])
        except (IndexError, TypeError, ValueError):
            return 0.0

    if exit_stage == "stage0":
        return _p(1) if label == "Directed Aggression" else _p(0)
    if exit_stage == "stage1a":
        return _p(1) if label == "Suicidal" else _p(0)
    if exit_stage == "stage1b":
        return _p(0) if label == "Normal" else _p(1)
    if exit_stage == "stage2":
        try:
            return _p(STAGE2_LABELS.index(label))
        except ValueError:
            return 0.0
    if exit_stage == "stage3":
        return _p(0) if label == "Depression" else _p(1)
    return 0.0


def _run_deep_dive(text: str, handler, sensitive_mode: bool) -> dict:
    """Invoke the Deep Dive handler; 502 if its output isn't a labelled dict."""
    payload = {"inputs": text, "mode": "safety" if sensitive_mode else "balanced"}
    result = handler(payload)
    if not isinstance(result, dict) or "label" not in result:
        raise HTTPException(
            status_code=502,
            detail=f"Unexpected Deep Dive handler output: {type(result).__name__}",
        )
    return result


def _sensitive_override(sensitive_mode: bool, label: str, stage0: list) -> bool:
    """True when Sensitive mode should promote *label* to Directed Aggression.

    Fires only when the handler did NOT already say DA and Stage 0's
    P(DA) clears SENSITIVE_T0_OVERRIDE. Shared by /classify and
    /debug_classify so the two endpoints can never disagree.
    """
    return (
        sensitive_mode
        and label != "Directed Aggression"
        and len(stage0) >= 2
        and float(stage0[1]) >= SENSITIVE_T0_OVERRIDE
    )


def _validated_text(raw: str) -> str:
    """Strip and validate request text (shared by /classify and /debug_classify).

    Raises:
        HTTPException(422): empty after stripping.
        HTTPException(413): longer than MAX_INPUT_CHARS.
    """
    text = raw.strip()
    if not text:
        raise HTTPException(status_code=422, detail="text must not be empty")
    if len(text) > MAX_INPUT_CHARS:
        raise HTTPException(
            status_code=413,
            detail=f"Input too long: {len(text)} > {MAX_INPUT_CHARS} chars",
        )
    return text


@asynccontextmanager
async def lifespan(_app: FastAPI):
    """App lifespan: nothing to warm up (models load lazily); drop refs on shutdown."""
    yield
    _state.clear()


app = FastAPI(title="VibeCheck API", version="6.0.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=ALLOWED_ORIGINS,
    allow_credentials=False,
    allow_methods=["POST", "GET"],
    allow_headers=["Content-Type"],
)


class ClassifyRequest(BaseModel):
    # text: input to classify; model: which backend to run;
    # sensitive_mode: lowers the Deep Dive DA threshold (no-op for mentalbert).
    text: str
    model: Literal["mentalbert", "longformer"] = "mentalbert"
    sensitive_mode: bool = Field(default=False)


class ClassifyResponse(BaseModel):
    # classification: one of LABEL_MAP's values; confidence: probability, 4 dp.
    classification: str
    confidence: float


@app.get("/")
def health():
    """Liveness probe: process id, configured models, and what's loaded here."""
    return JSONResponse(
        content={
            "status": "ok",
            "pid": os.getpid(),
            "quick_vibe_model": QUICK_VIBE_MODEL,
            "deep_dive_model": DEEP_DIVE_MODEL,
            "loaded": list(_state.keys()),
        },
        headers=NO_CACHE_HEADERS,
    )


if DEBUG_ENDPOINTS:

    @app.get("/diag")
    def diag():
        """Diagnostic snapshot: which models are actually loaded in this
        worker, where they came from, and what's sitting in the HF cache
        on disk. Only registered when DEBUG_ENDPOINTS=true."""
        info: dict = {
            "status": "ok",
            "pid": os.getpid(),
            "quick_vibe_model_id": QUICK_VIBE_MODEL,
            "deep_dive_model_id": DEEP_DIVE_MODEL,
            "loaded": list(_state.keys()),
            "hf_cache_models": _list_hf_cache_models(),
            "hf_token_configured": bool(HF_TOKEN),
        }
        if "mentalbert" in _state:
            s = _state["mentalbert"]
            model = s["model"]
            info["quick_vibe"] = {
                "params": s.get("params") or sum(p.numel() for p in model.parameters()),
                "name_or_path": getattr(model.config, "_name_or_path", None),
                "id2label_from_config": dict(model.config.id2label),
                "labels_in_use": s.get("labels"),
                "labels_source": s.get("labels_source"),
                "num_labels": getattr(model.config, "num_labels", None),
                "repo_dir": s.get("repo_dir"),
            }
        if "longformer" in _state:
            handler = _state["longformer"]
            try:
                handler_file = inspect.getfile(type(handler))
            except Exception:
                handler_file = None
            info["deep_dive"] = {
                "handler_class": type(handler).__name__,
                "handler_file": handler_file,
            }
        return JSONResponse(content=info, headers=NO_CACHE_HEADERS)


if DEBUG_ENDPOINTS:

    @app.post("/debug_classify")
    async def debug_classify(req: ClassifyRequest):
        """Stage-by-stage diagnostic. Reports whether the explicit-threat
        pre-filter fired before the model ran. For Quick Vibe: returns full
        logits/probs/argmax/decoded label. For Deep Dive: raw handler output
        plus whether the Sensitive-mode t0 override changed the label.
        Only registered when DEBUG_ENDPOINTS=true."""
        text = _validated_text(req.text)
        threat_match = EXPLICIT_THREAT_PATTERN.search(text)
        if threat_match:
            return JSONResponse(
                content={
                    "model": req.model,
                    "prefilter": "explicit_threat",
                    "matched_span": threat_match.group(0),
                    "raw_label": "Directed Aggression",
                    "classification": LABEL_MAP["Directed Aggression"],
                    "confidence": 0.99,
                },
                headers=NO_CACHE_HEADERS,
            )
        state = await _get_state(req.model)
        loop = asyncio.get_running_loop()
        if req.model == "mentalbert":
            result = await loop.run_in_executor(None, _debug_quick_vibe, text, state)
            return JSONResponse(content=result, headers=NO_CACHE_HEADERS)
        handler_result = await loop.run_in_executor(
            None, _run_deep_dive, text, state, req.sensitive_mode
        )
        handler_label = handler_result["label"]
        raw_label = handler_label
        stage0 = handler_result.get("stage_probs", {}).get("stage0") or []
        override_applied = False
        if _sensitive_override(req.sensitive_mode, handler_label, stage0):
            raw_label = "Directed Aggression"
            override_applied = True
            confidence = float(stage0[1])
        else:
            confidence = _derive_confidence(
                handler_label,
                handler_result.get("exit_stage", ""),
                handler_result.get("stage_probs", {}),
            )
        return JSONResponse(
            content={
                "model": "longformer",
                "sensitive_mode": req.sensitive_mode,
                "handler_label": handler_label,
                "raw_label": raw_label,
                "classification": LABEL_MAP.get(raw_label, "normal"),
                "confidence": round(confidence, 4),
                "exit_stage": handler_result.get("exit_stage"),
                "mode": handler_result.get("mode"),
                "stage_probs": handler_result.get("stage_probs", {}),
                "stage1a_raw": handler_result.get("stage1a_raw"),
                "stage3_raw": handler_result.get("stage3_raw"),
                "sensitive_t0_override": {
                    "threshold": SENSITIVE_T0_OVERRIDE,
                    "applied": override_applied,
                    "stage0_p_da": float(stage0[1]) if len(stage0) >= 2 else None,
                },
            },
            headers=NO_CACHE_HEADERS,
        )


@app.post("/classify", response_model=ClassifyResponse)
async def classify(req: ClassifyRequest):
    """Classify text via Quick Vibe (flat head) or Deep Dive (cascade)."""
    text = _validated_text(req.text)
    # Pre-filter: explicit threats short-circuit model inference.
    # Applies to both mentalbert and longformer.
    if EXPLICIT_THREAT_PATTERN.search(text):
        return ClassifyResponse(
            classification=LABEL_MAP["Directed Aggression"],
            confidence=0.99,
        )
    state = await _get_state(req.model)
    loop = asyncio.get_running_loop()
    if req.model == "mentalbert":
        raw_label, confidence = await loop.run_in_executor(
            None, _run_quick_vibe, text, state
        )
    else:
        result = await loop.run_in_executor(
            None, _run_deep_dive, text, state, req.sensitive_mode
        )
        raw_label = result["label"]
        stage0 = result.get("stage_probs", {}).get("stage0") or []
        # Sensitive-mode stage 0 override: lower threshold from handler's
        # default (~0.40) to 0.25 so the cascade fires DA on weaker cues.
        if _sensitive_override(req.sensitive_mode, raw_label, stage0):
            raw_label = "Directed Aggression"
            confidence = float(stage0[1])
        else:
            confidence = _derive_confidence(
                raw_label,
                result.get("exit_stage", ""),
                result.get("stage_probs", {}),
            )
    return ClassifyResponse(
        classification=LABEL_MAP.get(raw_label, "normal"),
        confidence=round(confidence, 4),
    )