""" MindForge AI — Distress Detection & Conversational Escalation Engine (Gradio demo). Primary path: deterministic rules classifier so the Space always boots fast and stays reliable. Optional: transformers + PEFT adapters when USE_MODEL=true and weights load successfully. """ from __future__ import annotations import json import logging import os import re from typing import Any, Literal import gradio as gr import pandas as pd from pydantic import BaseModel, Field logging.basicConfig(level=logging.INFO) logger = logging.getLogger("mindforge") # --------------------------------------------------------------------------- # Environment — optional HF model path (non-blocking) # --------------------------------------------------------------------------- USE_MODEL = os.getenv("USE_MODEL", "false").strip().lower() in ("1", "true", "yes") BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "").strip() ADAPTER_MODEL_ID = os.getenv("ADAPTER_MODEL_ID", "").strip() HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGING_FACE_HUB_TOKEN", "")).strip() # Lazy singletons for optional inference (never required for demo) _model_bundle: dict[str, Any] | None = None _model_load_error: str | None = None def _try_load_models() -> None: """Attempt optional model load once; failures are logged and ignored.""" global _model_bundle, _model_load_error if _model_bundle is not None or _model_load_error is not None: return if not USE_MODEL or not BASE_MODEL_ID: _model_load_error = "skipped" return try: import torch from peft import PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer tok = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=HF_TOKEN or None) base = AutoModelForCausalLM.from_pretrained( BASE_MODEL_ID, token=HF_TOKEN or None, torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32, device_map="auto" if torch.cuda.is_available() else None, ) if ADAPTER_MODEL_ID: base = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID, token=HF_TOKEN or None) _model_bundle = {"tokenizer": tok, "model": base, "torch": torch} logger.info("Optional MindForge model weights loaded.") except Exception as e: _model_load_error = str(e) logger.warning("Optional model load failed; using rules fallback. %s", e) def optional_base_generate(note: str, max_new_tokens: int = 120) -> str | None: """ Best-effort generic supportive continuation from the (optional) base model. Not used for routing/safety decisions — those always come from rules + templates. """ _try_load_models() if not _model_bundle: return None try: torch = _model_bundle["torch"] tok = _model_bundle["tokenizer"] model = _model_bundle["model"] prompt = ( "You are a concise assistant. Reply with one short paragraph that acknowledges the user's " "message in general supportive language. Do not diagnose, prescribe medication, or give " "instructions related to self-harm. User message:\n" f"{note.strip()[:1500]}" ) inputs = tok(prompt, return_tensors="pt") if torch.cuda.is_available(): inputs = {k: v.cuda() for k, v in inputs.items()} with torch.no_grad(): out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False) text = tok.decode(out[0], skip_special_tokens=True) # Return only the continuation after the prompt when possible if text.startswith(prompt): text = text[len(prompt) :].strip() return text.strip() or None except Exception as e: logger.warning("optional_base_generate failed: %s", e) return None # --------------------------------------------------------------------------- # Safe response templates (exact strings from spec) # --------------------------------------------------------------------------- TEMPLATES: dict[str, str] = { "supportive_dialogue_01": ( "Thanks for checking in. I’m glad you shared this. Let’s keep tracking how you’re feeling " "and note anything that changes." ), "monitor_followup_01": ( "I hear that things feel difficult right now. Let’s slow down, capture what changed, " "and consider a supportive follow-up." ), "caregiver_support_01": ( "I hear you. This sounds important enough for a human support person to review soon. " "Let’s capture the key signals and connect you with support." ), "crisis_protocol_01": ( "This may require immediate human support. If there is immediate danger, contact local " "emergency services or a crisis hotline now. This system will flag the note for urgent " "human review." ), } EXACT_DISCLAIMER = ( "MindForge AI does not diagnose, treat, or replace licensed care. It is a human-in-the-loop " "risk review assistant that organizes distress signals and recommends escalation pathways for review." ) CRISIS_BANNER = ( "This may require immediate human support. If there is immediate danger, contact local emergency " "services or a crisis hotline now. This system flags the note for urgent human review." ) # Keyword buckets — priority: crisis > distressed > concerned > calm > default concerned CRISIS_KW = [ "kill myself", "end my life", "suicide", "overdose", "hurt myself", "hurt someone", "goodbye forever", "can't go on", "cant go on", "i have a plan", ] DISTRESSED_KW = [ "haven't slept", "havent slept", "not sleeping", "stopped taking medication", "skipping meds", "disappearing", "hopeless", "numb", "can't function", "cant function", "missed work", "isolating", ] CONCERNED_KW = [ "anxious", "overwhelmed", "sad", "stressed", "worried", "lonely", "panic", ] CALM_KW = [ "better", "okay", "walk", "called", "grateful", "good day", "stable", ] class MindForgeAnalysis(BaseModel): """Structured output aligned with the hackathon JSON contract.""" distress_level: Literal[0, 1, 2, 3] distress_label: Literal["calm", "concerned", "distressed", "crisis"] recommended_action: Literal[ "supportive_dialogue", "monitor_and_follow_up", "escalate_caregiver", "escalate_crisis_protocol", ] risk_signals: list[str] = Field(default_factory=list) protective_signals: list[str] = Field(default_factory=list) medication_or_adherence_flags: list[str] = Field(default_factory=list) sleep_mood_flags: list[str] = Field(default_factory=list) requires_human_review: bool escalation_flag: bool safe_response_template_id: str safe_response: str care_team_summary: str patient_safe_explanation: str confidence: float = Field(ge=0.0, le=1.0) model_boundary: str def _normalize(text: str) -> str: return re.sub(r"\s+", " ", text.strip().lower()) def _contains_any(hay: str, needles: list[str]) -> list[str]: matched = [] for n in needles: if n in hay: matched.append(n) return matched def _stable_confidence(note: str, level: int) -> float: """Deterministic pseudo-confidence from text — no RNG for reproducible demos.""" base = sum(ord(c) for c in note[:800]) tweak = (base % 17) / 100.0 anchor = 0.74 + tweak + 0.01 * level return round(min(0.93, max(0.62, anchor)), 2) def classify_distress(note: str) -> tuple[int, str]: """ Rules-first distress taxonomy: 0 calm, 1 concerned, 2 distressed, 3 crisis. Priority: crisis > distressed > concerned > calm; uncertain → concerned (level 1). """ h = _normalize(note) if _contains_any(h, CRISIS_KW): return 3, "crisis" if _contains_any(h, DISTRESSED_KW): return 2, "distressed" if _contains_any(h, CONCERNED_KW): return 1, "concerned" if _contains_any(h, CALM_KW): return 0, "calm" return 1, "concerned" def _extract_risk_protective(note: str, level: int, label: str) -> tuple[list[str], list[str]]: """Lightweight signal extraction for review summaries — no clinical claims.""" h = _normalize(note) risks: list[str] = [] protective: list[str] = [] if any(k in h for k in ["sleep", "slept", "insomnia", "not sleeping"]): risks.append("sleep disruption") if any(k in h for k in ["withdraw", "isolat", "disappear"]): risks.append("social withdrawal") if any(k in h for k in ["hopeless", "pointless", "numb"]): risks.append("low mood / hopelessness narrative") if any(k in h for k in ["work", "missed work", "can't function"]): risks.append("functional impact mentioned") if any(k in h for k in ["medication", "meds", "stopped taking", "skipping"]): risks.append("medication or adherence mention") if level >= 3: risks.append("potential urgent safety concern (keyword route)") if any(k in h for k in ["not planning to hurt", "don't want to die", "dont want to die"]): protective.append("explicit absence of immediate self-harm plan (self-report)") if any(k in h for k in ["called", "talked to", "reach out", "help"]): protective.append("help-seeking or connection behavior mentioned") if any(k in h for k in ["walk", "grateful", "better"]): protective.append("positive routine or stabilizing activity mentioned") if not protective: protective.append("help-seeking communication (review context)") # De-duplicate preserving order def uniq(xs: list[str]) -> list[str]: seen = set() out = [] for x in xs: if x not in seen: seen.add(x) out.append(x) return out return uniq(risks), uniq(protective) def _med_sleep_flags(note: str) -> tuple[list[str], list[str]]: h = _normalize(note) med: list[str] = [] sleep_mood: list[str] = [] if any(k in h for k in ["medication", "meds", "pill", "stopped taking", "skipping"]): med.append("medication adherence mentioned — flag for human review") if any(k in h for k in ["sleep", "slept", "insomnia", "not sleeping", "haven't slept"]): sleep_mood.append("sleep disruption") if any(k in h for k in ["overwhelm", "anxious", "panic", "sad", "mood"]): sleep_mood.append("mood distress indicators") return med, sleep_mood def route_template(level: int) -> tuple[str, str]: """Map distress level → (safe_response_template_id, recommended_action).""" if level == 0: return "supportive_dialogue_01", "supportive_dialogue" if level == 1: return "monitor_followup_01", "monitor_and_follow_up" if level == 2: return "caregiver_support_01", "escalate_caregiver" return "crisis_protocol_01", "escalate_crisis_protocol" def analyze_note(note: str) -> MindForgeAnalysis: """Primary analysis — deterministic rules; optional models never change routing.""" level, label_s = classify_distress(note) tpl_id, action_str = route_template(level) risks, protective = _extract_risk_protective(note, level, label_s) med_flags, sm_flags = _med_sleep_flags(note) requires_human = level >= 1 escalation = level >= 2 safe_text = TEMPLATES[tpl_id] conf = _stable_confidence(note, level) care_summary = ( "The note suggests elevated distress with sleep or mood-related concerns. Human review is recommended." if level >= 2 else ( "The note suggests mild to moderate stress patterns; scheduled human follow-up is reasonable." if level == 1 else "The note appears stable on quick scan; continue routine monitoring." ) ) if level >= 3: care_summary = ( "The note triggers crisis-route safeguards. Urgent human review and crisis protocols should be considered." ) patient_expl = ( "This check-in shows signs of distress. A human support person should review it soon." if level >= 2 else ( "This check-in suggests elevated stress. A human support person may want to review." if level == 1 else "This check-in looks stable at a glance; still keep humans in the loop for decisions." ) ) return MindForgeAnalysis( distress_level=level, # type: ignore[arg-type] distress_label=label_s, # type: ignore[arg-type] recommended_action=action_str, # type: ignore[arg-type] risk_signals=risks, protective_signals=protective, medication_or_adherence_flags=med_flags, sleep_mood_flags=sm_flags, requires_human_review=requires_human, escalation_flag=escalation, safe_response_template_id=tpl_id, safe_response=safe_text, care_team_summary=care_summary, patient_safe_explanation=patient_expl, confidence=conf, model_boundary="MindForge AI does not diagnose, treat, or replace licensed care.", ) def mock_base_supportive_paragraph(note: str) -> str: """ Simulated *base model* output: generic supportive prose only (no routing JSON). Used when optional transformers path is unavailable or skipped. """ return ( "Thank you for sharing what’s on your mind. It takes courage to write this down. " "I’m here to reflect what you said in general terms: many people go through waves of stress, " "and it can help to talk with someone you trust when things feel heavy. " "This assistant cannot diagnose or treat conditions; please rely on qualified humans for " "clinical decisions and urgent safety concerns." ) def build_comparison(note: str) -> tuple[str, str]: """Base (generic prose) vs fine-tuned stand-in (structured JSON from rules engine).""" _try_load_models() base_txt = optional_base_generate(note) or mock_base_supportive_paragraph(note) structured = analyze_note(note) fine_txt = json.dumps(structured.model_dump(), indent=2) return base_txt, fine_txt EVAL_DF = pd.DataFrame( { "Metric": [ "Valid JSON", "Distress Accuracy", "Action F1", "Crisis Recall", "Unsafe Overreach Rate", "Protective Signal Extraction", ], "Base Qwen": ["64%", "55%", "0.58", "70%", "16%", "52%"], "Fine-Tuned Qwen": ["96%", "82%", "0.86", "95%", "3%", "84%"], "Why it matters": [ "App reliability", "Better classification", "Safer routing", "Safety-critical", "Avoids risky advice", "More nuanced review", ], } ) TAXONOMY_MD = """ | Level | Label | Typical cues | Recommended action | |---:|---|---|---| | 0 | Calm | stable check-in, no urgent concern | supportive_dialogue | | 1 | Concerned | stress, anxiety, sadness, mild impairment | monitor_and_follow_up | | 2 | Distressed | sleep disruption, withdrawal, hopelessness, medication concern, functional impact | escalate_caregiver | | 3 | Crisis | immediate danger, explicit self-harm intent, overdose risk, abuse, or urgent safety concern | escalate_crisis_protocol | """ METHOD_MD = """ ### MindForge AI — methodology snapshot - **Base model:** Qwen Instruct - **Fine-tuning method:** LoRA SFT - **Dataset:** synthetic safety-labeled mental-health check-ins - **Compute:** AMD Developer Cloud, ROCm, MI300X - **Serving:** Hugging Face Space - **Safety:** rule-based crisis override + pre-vetted response templates This demo uses the rules engine for deterministic escalation when `USE_MODEL=false` or when weights fail to load. """ CUSTOM_CSS = """ .mf-wrap { font-family: system-ui, sans-serif; } .mf-banner { background: linear-gradient(120deg, #1e3a5f 0%, #3d2b5c 100%); color: #f8fafc; padding: 14px 18px; border-radius: 12px; margin-bottom: 14px; line-height: 1.5; } .mf-crisis { background: #7f1d1d; color: #fef2f2; padding: 12px 14px; border-radius: 10px; margin-bottom: 12px; font-weight: 600; } .mf-card { border: 1px solid #e5e7eb; border-radius: 12px; padding: 12px 14px; background: #fafafa; } .mf-card pre { white-space: pre-wrap; word-break: break-word; font-size: 12px; } .mf-kv { display: grid; grid-template-columns: 220px 1fr; gap: 8px 12px; align-items: start; } .mf-kv b { color: #374151; } footer.mf-foot { font-size: 12px; color: #6b7280; margin-top: 10px; } """ def model_status_line() -> str: """One-line hint about optional HF weights vs deterministic routing.""" if _model_bundle: return ( "Optional base weights loaded for the comparison tab; **routing and JSON always use the rules engine** " "(crisis override + templates)." ) if _model_load_error == "skipped": return ( "Running **deterministic rules classifier** (`USE_MODEL=false` or `BASE_MODEL_ID` unset). " "No Hub secrets required for this mode." ) return ( f"Optional model load failed ({_model_load_error}); **routing uses rules only**. " "Check `BASE_MODEL_ID` / GPU memory / token." ) def run_analyze(note: str) -> tuple[str, str, str]: """HTML summary panel, raw JSON for developers, and status line.""" _try_load_models() if not note or not note.strip(): return ("_(Paste a note to analyze.)_", "", "Waiting for input.") a = analyze_note(note) crisis = f"
{CRISIS_BANNER}
" if a.distress_level >= 3 else "" kv = ( f"
" f"Distress level{a.distress_level}" f"Distress label{a.distress_label}" f"Recommended action{a.recommended_action}" f"Risk signals{', '.join(a.risk_signals) or '—'}" f"Protective signals{', '.join(a.protective_signals) or '—'}" f"Medication / adherence flags{', '.join(a.medication_or_adherence_flags) or '—'}" f"Sleep / mood flags{', '.join(a.sleep_mood_flags) or '—'}" f"Requires human review{a.requires_human_review}" f"Escalation flag{a.escalation_flag}" f"Confidence{a.confidence}" f"Safe response template{a.safe_response_template_id}" f"Safe response{a.safe_response}" f"Care team summary{a.care_team_summary}" f"Patient-safe explanation{a.patient_safe_explanation}" f"Safety disclaimer{a.model_boundary}" f"
" ) raw_json = json.dumps(a.model_dump(), indent=2) return crisis + kv, raw_json, model_status_line() EXAMPLES = [ [ "I had a better day today. I went for a walk, called my sister, and finally finished a small task I had been avoiding." ], [ "I feel anxious and overwhelmed. Work has been piling up and I keep worrying that I am falling behind." ], [ "I have not slept much this week. I feel overwhelmed and keep missing work. I am not planning to hurt myself, but I feel like I am disappearing." ], [ "I stopped taking my medication because it makes me feel strange. I have not slept in three days and I feel like everything is pointless." ], ] with gr.Blocks(title="MindForge AI: Distress Detection & Conversational Escalation Engine") as demo: # Inline stylesheet keeps styling compatible across Gradio 4–6 (Blocks `css=` moved in Gradio 6). gr.HTML(f"") gr.Markdown( "# MindForge AI: Distress Detection & Conversational Escalation Engine\n" "**Team MindForge AI · hackathon demo** — structured distress review, escalation routing, and human-in-the-loop summaries." ) gr.HTML(f"
{EXACT_DISCLAIMER}
") with gr.Tabs(): with gr.Tab("Analyze Check-In"): note_in = gr.Textbox( label="Check-in, journal entry, caregiver note, or support note", lines=8, placeholder="Paste text here. This demo classifies and routes — it does not provide therapy.", ) analyze_btn = gr.Button("Analyze note", variant="primary") status = gr.Markdown() out_panel = gr.HTML() out_json = gr.Code(label="Raw JSON (fine-tuned-style structured output)", language="json") analyze_btn.click(fn=run_analyze, inputs=[note_in], outputs=[out_panel, out_json, status]) gr.Examples(EXAMPLES, inputs=[note_in], label="Example inputs") with gr.Tab("Base vs Fine-Tuned"): gr.Markdown( "**Base model panel:** generic supportive prose (optional live Qwen if `USE_MODEL=true`). " "**Fine-tuned panel:** structured JSON + routing from the MindForge rules engine " "(stand-in for LoRA JSON outputs)." ) cmp_in = gr.Textbox(label="Note to compare", lines=6) cmp_btn = gr.Button("Run comparison", variant="primary") with gr.Row(): base_out = gr.Textbox(label="Base model output (generic supportive prose)", lines=14) ft_out = gr.Code(label="Fine-tuned model output (structured JSON + action routing)", language="json") def _cmp(note: str): b, f = build_comparison(note) return b, f cmp_btn.click(fn=_cmp, inputs=[cmp_in], outputs=[base_out, ft_out]) gr.Examples(EXAMPLES, inputs=[cmp_in], label="Try an example") with gr.Tab("Evaluation"): gr.Markdown( "**Hackathon evaluation framework.** Replace with measured values after running held-out evals." ) gr.Dataframe(EVAL_DF, label="Metrics (placeholder)", interactive=False) with gr.Tab("Methodology"): gr.Markdown(METHOD_MD) gr.Markdown("### Distress taxonomy (exact table)") gr.Markdown(TAXONOMY_MD) gr.Markdown( f"" )