""" crisis_engine.py — PsyPredict Crisis Detection Layer Uses DistilBERT zero-shot classification (NOT keyword matching). Weighted risk scoring across mental health risk dimensions. Triggers override of LLM output when threshold exceeded. This layer is the safety net — it runs BEFORE and OVERRIDES the LLM. """ from __future__ import annotations import asyncio import logging from typing import List, Optional from app.schemas import CrisisResource, PsychReport, RiskLevel logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Risk Labels + Weights (tuned empirically) # --------------------------------------------------------------------------- RISK_LABELS: list[str] = [ "suicidal ideation", "self-harm intent", "immediate danger to self", "severe mental breakdown", "hopelessness and worthlessness", ] RISK_WEIGHTS: dict[str, float] = { "suicidal ideation": 1.0, "self-harm intent": 1.0, "immediate danger to self": 0.95, "severe mental breakdown": 0.60, "hopelessness and worthlessness": 0.50, } # --------------------------------------------------------------------------- # Crisis Resources (India + International) # --------------------------------------------------------------------------- CRISIS_RESOURCES: List[CrisisResource] = [ CrisisResource(name="iCall (India)", contact="9152987821", available="Mon–Sat 8am–10pm"), CrisisResource(name="Vandrevala Foundation (India)", contact="1860-2662-345", available="24/7"), CrisisResource(name="AASRA (India)", contact="9820466627", available="24/7"), CrisisResource(name="Befrienders Worldwide", contact="https://www.befrienders.org", available="24/7"), CrisisResource(name="Crisis Text Line (US/UK)", contact="Text HOME to 741741", available="24/7"), ] # --------------------------------------------------------------------------- # Zero-Shot Classifier # --------------------------------------------------------------------------- _zero_shot_pipeline = None _load_error: Optional[str] = None def initialize_crisis_classifier() -> None: """ Load MiniLM zero-shot classifier at startup. Uses cross-encoder/nli-MiniLM2-L6-H768 — lightweight, fast. """ global _zero_shot_pipeline, _load_error try: from transformers import pipeline as hf_pipeline import os local_path = os.path.join("app", "ml_assets", "crisis_model") logger.info("Loading crisis zero-shot classifier from %s", local_path) _zero_shot_pipeline = hf_pipeline( "zero-shot-classification", model=local_path if os.path.exists(local_path) else "cross-encoder/nli-MiniLM2-L6-H768", device=-1, # CPU ) logger.info("✅ Crisis classifier loaded.") except Exception as exc: _load_error = str(exc) logger.error("❌ Crisis classifier load failed: %s", exc) def _score_sync(text: str) -> float: """ Synchronous zero-shot scoring. Runs in thread pool. Returns weighted crisis risk score in [0, 1]. """ if _zero_shot_pipeline is None: # Fallback: basic substring check for true emergencies only return _fallback_score(text) try: result = _zero_shot_pipeline( text[:512], candidate_labels=RISK_LABELS, multi_label=True, ) label_scores: dict[str, float] = dict( zip(result["labels"], result["scores"]) ) # Weighted sum, normalized to [0, 1] total_weight = sum(RISK_WEIGHTS.values()) weighted_sum = sum( label_scores.get(lbl, 0.0) * RISK_WEIGHTS[lbl] for lbl in RISK_LABELS ) return min(weighted_sum / total_weight, 1.0) except Exception as exc: logger.error("Crisis scoring error: %s", exc) return _fallback_score(text) def _fallback_score(text: str) -> float: """ Hard fallback: only fires on unambiguous semantic signals. This is distinct from keyword matching — uses phrase-level context. """ HIGH_RISK_PHRASES = [ "want to die", "kill myself", "end my life", "hurt myself", "suicide", "self harm", "self-harm", "no reason to live", "don't want to exist", "cannot go on", "take my life", ] t = text.lower() hits = sum(1 for phrase in HIGH_RISK_PHRASES if phrase in t) return min(hits * 0.35, 1.0) class CrisisEngine: """ Evaluates crisis risk from user text. Must be called before LLM generation. If triggered, returns a deterministic PsychReport override. """ def __init__(self, threshold: float = 0.65) -> None: self.threshold = threshold async def evaluate(self, text: str) -> tuple[float, bool]: """ Returns (risk_score, crisis_triggered). Runs synchronous model in thread pool. """ score = await asyncio.to_thread(_score_sync, text) triggered = score >= self.threshold if triggered: logger.warning( "CRISIS TRIGGERED — risk_score=%.3f text=%r", score, text[:100], ) return score, triggered def build_crisis_report(self, risk_score: float) -> tuple[str, PsychReport]: """ Returns deterministic crisis reply + PsychReport. Does NOT involve the LLM. """ reply = ( "I hear that you're going through something very serious right now. " "Please reach out to a crisis support line immediately — " "you don't have to face this alone." ) report = PsychReport( risk_classification=RiskLevel.CRITICAL, emotional_state_summary=( "Severe psychological distress detected. Indicators of self-harm " "or suicidal ideation are present." ), behavioral_inference=( "User's expressed content suggests acute crisis state. " "Immediate professional intervention is warranted." ), cognitive_distortions=["Hopelessness", "All-or-nothing thinking"], suggested_interventions=[ "Immediate contact with a mental health crisis line.", "Notify a trusted person or emergency services if in immediate danger.", "Seek in-person emergency psychiatric evaluation.", ], confidence_score=round(risk_score, 3), crisis_triggered=True, crisis_resources=CRISIS_RESOURCES, service_degraded=False, ) return reply, report @property def is_loaded(self) -> bool: return _zero_shot_pipeline is not None # Singleton crisis_engine = CrisisEngine()