| """ |
| MindForge AI — Distress Detection & Conversational Escalation Engine (Gradio demo). |
| |
| Primary path: deterministic rules classifier so the Space always boots fast and stays reliable. |
| Optional: transformers + PEFT adapters when USE_MODEL=true and weights load successfully. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import os |
| import re |
| from typing import Any, Literal |
|
|
| import gradio as gr |
| import pandas as pd |
| from pydantic import BaseModel, Field |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger("mindforge") |
|
|
| |
| |
| |
| USE_MODEL = os.getenv("USE_MODEL", "false").strip().lower() in ("1", "true", "yes") |
| BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "").strip() |
| ADAPTER_MODEL_ID = os.getenv("ADAPTER_MODEL_ID", "").strip() |
| HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGING_FACE_HUB_TOKEN", "")).strip() |
|
|
| |
| _model_bundle: dict[str, Any] | None = None |
| _model_load_error: str | None = None |
|
|
|
|
def _try_load_models() -> None:
    """Try to load the optional base model (plus PEFT adapter) on first call.

    The outcome is recorded in module globals: ``_model_bundle`` on success, or
    ``_model_load_error`` (``"skipped"`` when disabled / unconfigured, otherwise the
    exception text). Once either is set, later calls return immediately. Failures
    are logged and swallowed so the rules engine keeps serving.
    """
    global _model_bundle, _model_load_error

    already_attempted = _model_bundle is not None or _model_load_error is not None
    if already_attempted:
        return

    if not (USE_MODEL and BASE_MODEL_ID):
        _model_load_error = "skipped"
        return

    try:
        # Heavy ML imports stay local so the rules-only Space never pays for them.
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        auth_token = HF_TOKEN or None
        cuda_ok = torch.cuda.is_available()

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=auth_token)
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            token=auth_token,
            torch_dtype=torch.float16 if cuda_ok else torch.float32,
            device_map="auto" if cuda_ok else None,
        )
        if ADAPTER_MODEL_ID:
            model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID, token=auth_token)

        _model_bundle = {"tokenizer": tokenizer, "model": model, "torch": torch}
        logger.info("Optional MindForge model weights loaded.")
    except Exception as exc:
        _model_load_error = str(exc)
        logger.warning("Optional model load failed; using rules fallback. %s", exc)
|
|
|
|
def optional_base_generate(note: str, max_new_tokens: int = 120) -> str | None:
    """
    Best-effort generic supportive continuation from the (optional) base model.
    Not used for routing/safety decisions — those always come from rules + templates.

    Args:
        note: Free-text check-in; only its first 1500 characters (after stripping) are sent.
        max_new_tokens: Upper bound on generated tokens (greedy decoding).

    Returns:
        The generated paragraph, or ``None`` when no model is loaded, the model
        produced only whitespace, or generation raised (the error is logged).
    """
    _try_load_models()
    if not _model_bundle:
        return None
    try:
        torch = _model_bundle["torch"]
        tok = _model_bundle["tokenizer"]
        model = _model_bundle["model"]
        prompt = (
            "You are a concise assistant. Reply with one short paragraph that acknowledges the user's "
            "message in general supportive language. Do not diagnose, prescribe medication, or give "
            "instructions related to self-harm. User message:\n"
            f"{note.strip()[:1500]}"
        )
        inputs = tok(prompt, return_tensors="pt")
        # Put the inputs on the model's device. This works for CPU, a single GPU and
        # device_map="auto" placements, instead of assuming cuda:0.
        device = getattr(model, "device", None)
        if device is not None:
            inputs = {k: v.to(device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # generate() returns prompt + continuation. Decode only the new tokens:
        # stripping the prompt by string prefix fails whenever decode(encode(prompt))
        # differs from the prompt, and then the instructions leak into the UI.
        text = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
        return text.strip() or None
    except Exception as e:
        logger.warning("optional_base_generate failed: %s", e)
        return None
|
|
|
|
| |
| |
| |
# Pre-vetted, patient-facing replies keyed by the template id that route_template() returns.
TEMPLATES: dict[str, str] = dict(
    supportive_dialogue_01=(
        "Thanks for checking in. I’m glad you shared this. "
        "Let’s keep tracking how you’re feeling and note anything that changes."
    ),
    monitor_followup_01=(
        "I hear that things feel difficult right now. "
        "Let’s slow down, capture what changed, and consider a supportive follow-up."
    ),
    caregiver_support_01=(
        "I hear you. "
        "This sounds important enough for a human support person to review soon. "
        "Let’s capture the key signals and connect you with support."
    ),
    crisis_protocol_01=(
        "This may require immediate human support. "
        "If there is immediate danger, contact local emergency services or a crisis hotline now. "
        "This system will flag the note for urgent human review."
    ),
)


# Scope statement shown in the page banner and footer.
EXACT_DISCLAIMER = " ".join(
    (
        "MindForge AI does not diagnose, treat, or replace licensed care.",
        "It is a human-in-the-loop risk review assistant that organizes distress signals "
        "and recommends escalation pathways for review.",
    )
)


# Red banner text rendered above the analysis card for crisis-level (3) notes.
CRISIS_BANNER = " ".join(
    (
        "This may require immediate human support.",
        "If there is immediate danger, contact local emergency services or a crisis hotline now.",
        "This system flags the note for urgent human review.",
    )
)
|
|
| |
# Keyword tiers for classify_distress(). Matching is plain substring search on the
# normalized note (lower-cased, whitespace collapsed), checked in priority order
# crisis > distressed > concerned > calm. Keep every phrase lower-case.

# Level 3: explicit self-harm / harm-to-others or farewell language.
CRISIS_KW = [
    "kill myself",
    "end my life",
    "suicide",
    "overdose",
    "hurt myself",
    "hurt someone",
    "goodbye forever",
    "can't go on",
    "cant go on",
    "i have a plan",
]


# Level 2: sleep loss, adherence changes, withdrawal, hopelessness, functional impact.
DISTRESSED_KW = [
    "haven't slept",
    "havent slept",
    "not sleeping",
    "stopped taking medication",
    "skipping meds",
    "disappearing",
    "hopeless",
    "numb",
    "can't function",
    "cant function",
    "missed work",
    "isolating",
    # Phrasings the signal extractors and taxonomy already treat as distress cues but
    # this tier missed, e.g. "I have not slept in three days", "I stopped taking my
    # medication", "everything is pointless", "keep missing work".
    "not slept",
    "stopped taking my medication",
    "stopped taking my meds",
    "pointless",
    "missing work",
]


# Level 1: general stress / low-mood words.
CONCERNED_KW = [
    "anxious",
    "overwhelmed",
    "sad",
    "stressed",
    "worried",
    "lonely",
    "panic",
    # Negated calm cues. Without these, "I'm not okay" matched the calm keyword
    # "okay" and was routed as calm (no human review). This tier is checked before calm.
    "not okay",
    "not ok",
    "not better",
]


# Level 0: stabilizing / positive cues; only used when no higher tier matched.
CALM_KW = [
    "better",
    "okay",
    "walk",
    "called",
    "grateful",
    "good day",
    "stable",
]
|
|
|
|
class MindForgeAnalysis(BaseModel):
    """Structured output aligned with the hackathon JSON contract.

    Built by analyze_note() and serialized via model_dump() + json.dumps for the
    UI's JSON panels; fields are dumped in declaration order, so reordering them
    changes the displayed JSON. Pydantic validates the Literal choices and the
    confidence bounds at construction time.
    """


    # Numeric tier: 0 calm, 1 concerned, 2 distressed, 3 crisis.
    distress_level: Literal[0, 1, 2, 3]
    # Text label paired with distress_level (classify_distress returns both together).
    distress_label: Literal["calm", "concerned", "distressed", "crisis"]
    # Routing decision; analyze_note takes it from route_template(distress_level).
    recommended_action: Literal[
        "supportive_dialogue",
        "monitor_and_follow_up",
        "escalate_caregiver",
        "escalate_crisis_protocol",
    ]
    # Human-readable signal lists; each defaults to a fresh empty list per instance.
    risk_signals: list[str] = Field(default_factory=list)
    protective_signals: list[str] = Field(default_factory=list)
    medication_or_adherence_flags: list[str] = Field(default_factory=list)
    sleep_mood_flags: list[str] = Field(default_factory=list)
    # analyze_note sets this when distress_level >= 1.
    requires_human_review: bool
    # analyze_note sets this when distress_level >= 2.
    escalation_flag: bool
    # Key into TEMPLATES; safe_response carries that template's text.
    safe_response_template_id: str
    safe_response: str
    care_team_summary: str
    patient_safe_explanation: str
    # Deterministic pseudo-confidence (see _stable_confidence); validated to [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
    # Short non-diagnosis boundary statement included with every result.
    model_boundary: str
|
|
|
|
| def _normalize(text: str) -> str: |
| return re.sub(r"\s+", " ", text.strip().lower()) |
|
|
|
|
| def _contains_any(hay: str, needles: list[str]) -> list[str]: |
| matched = [] |
| for n in needles: |
| if n in hay: |
| matched.append(n) |
| return matched |
|
|
|
|
| def _stable_confidence(note: str, level: int) -> float: |
| """Deterministic pseudo-confidence from text — no RNG for reproducible demos.""" |
| base = sum(ord(c) for c in note[:800]) |
| tweak = (base % 17) / 100.0 |
| anchor = 0.74 + tweak + 0.01 * level |
| return round(min(0.93, max(0.62, anchor)), 2) |
|
|
|
|
def classify_distress(note: str) -> tuple[int, str]:
    """
    Rules-first distress taxonomy:
    0 calm, 1 concerned, 2 distressed, 3 crisis.
    Priority: crisis > distressed > concerned > calm; uncertain → concerned (level 1).

    Returns ``(level, label)`` for the first keyword tier that matches the normalized note.
    """
    normalized = _normalize(note)
    # Ordered from most to least severe; the first hit wins.
    tiers = (
        (CRISIS_KW, 3, "crisis"),
        (DISTRESSED_KW, 2, "distressed"),
        (CONCERNED_KW, 1, "concerned"),
        (CALM_KW, 0, "calm"),
    )
    for keywords, tier_level, tier_label in tiers:
        if _contains_any(normalized, keywords):
            return tier_level, tier_label
    # Nothing matched at all: treat as uncertain and route conservatively.
    return 1, "concerned"
|
|
|
|
def _extract_risk_protective(note: str, level: int, label: str) -> tuple[list[str], list[str]]:
    """Lightweight signal extraction for review summaries — no clinical claims.

    Args:
        note: Raw check-in text (normalized internally).
        level: Distress level; level >= 3 adds a keyword-route urgent-safety entry.
        label: Distress label; currently unused, kept for call-site compatibility.

    Returns:
        ``(risk_signals, protective_signals)``, each de-duplicated in first-seen order.
        The protective list is never empty: a generic review-context entry is used
        when no specific cue is found.
    """
    h = _normalize(note)

    def mentions(*cues: str) -> bool:
        # Plain substring test against the normalized note.
        return any(cue in h for cue in cues)

    risks: list[str] = []
    protective: list[str] = []

    if mentions("sleep", "slept", "insomnia", "not sleeping"):
        risks.append("sleep disruption")
    if mentions("withdraw", "isolat", "disappear"):
        risks.append("social withdrawal")
    # "helpless" is a hopelessness cue; previously it was only caught by the "help"
    # substring below and reported as help-seeking.
    if mentions("hopeless", "helpless", "pointless", "numb"):
        risks.append("low mood / hopelessness narrative")
    if mentions("work", "missed work", "can't function", "cant function"):
        risks.append("functional impact mentioned")
    if mentions("medication", "meds", "stopped taking", "skipping"):
        risks.append("medication or adherence mention")
    if level >= 3:
        risks.append("potential urgent safety concern (keyword route)")

    if mentions("not planning to hurt", "don't want to die", "dont want to die"):
        protective.append("explicit absence of immediate self-harm plan (self-report)")
    # "help" must start a word and must not continue as "helpless"/"helplessness",
    # so distress wording is not mistaken for help-seeking behavior.
    if mentions("called", "talked to", "reach out") or re.search(r"\bhelp(?!less)", h):
        protective.append("help-seeking or connection behavior mentioned")
    if mentions("walk", "grateful", "better"):
        protective.append("positive routine or stabilizing activity mentioned")

    if not protective:
        protective.append("help-seeking communication (review context)")

    # dict.fromkeys drops duplicates while preserving first-seen order.
    return list(dict.fromkeys(risks)), list(dict.fromkeys(protective))
|
|
|
|
def _med_sleep_flags(note: str) -> tuple[list[str], list[str]]:
    """Return ``(medication_flags, sleep_mood_flags)`` derived from keyword cues.

    Each category contributes at most one entry; cues are substring matches on the
    normalized note. Sleep is listed before mood when both are present.
    """
    text = _normalize(note)

    def hit(cues: tuple[str, ...]) -> bool:
        return any(cue in text for cue in cues)

    med_cues = ("medication", "meds", "pill", "stopped taking", "skipping")
    sleep_cues = ("sleep", "slept", "insomnia", "not sleeping", "haven't slept")
    mood_cues = ("overwhelm", "anxious", "panic", "sad", "mood")

    medication = ["medication adherence mentioned — flag for human review"] if hit(med_cues) else []
    sleep_and_mood = [
        flag
        for flag, cues in (("sleep disruption", sleep_cues), ("mood distress indicators", mood_cues))
        if hit(cues)
    ]
    return medication, sleep_and_mood
|
|
|
|
def route_template(level: int) -> tuple[str, str]:
    """Map a distress level to ``(safe_response_template_id, recommended_action)``.

    Levels 0, 1 and 2 have dedicated routes; every other value (3 or anything
    unexpected) falls through to the crisis protocol.
    """
    routes = (
        (0, ("supportive_dialogue_01", "supportive_dialogue")),
        (1, ("monitor_followup_01", "monitor_and_follow_up")),
        (2, ("caregiver_support_01", "escalate_caregiver")),
    )
    for candidate, route in routes:
        if level == candidate:
            return route
    return "crisis_protocol_01", "escalate_crisis_protocol"
|
|
|
|
def analyze_note(note: str) -> MindForgeAnalysis:
    """Primary analysis — deterministic rules; optional models never change routing.

    Pipeline: classify_distress -> route_template -> signal/flag extraction ->
    fixed summary texts chosen by level -> validated MindForgeAnalysis.
    """
    level, label_s = classify_distress(note)
    tpl_id, action_str = route_template(level)
    risks, protective = _extract_risk_protective(note, level, label_s)
    med_flags, sm_flags = _med_sleep_flags(note)

    # Reviewer-facing summary, most severe tier first.
    if level >= 3:
        care_summary = (
            "The note triggers crisis-route safeguards. Urgent human review and crisis protocols should be considered."
        )
    elif level >= 2:
        care_summary = (
            "The note suggests elevated distress with sleep or mood-related concerns. Human review is recommended."
        )
    elif level == 1:
        care_summary = "The note suggests mild to moderate stress patterns; scheduled human follow-up is reasonable."
    else:
        care_summary = "The note appears stable on quick scan; continue routine monitoring."

    # Patient-facing wording; levels 2 and 3 share the same text.
    if level >= 2:
        patient_expl = "This check-in shows signs of distress. A human support person should review it soon."
    elif level == 1:
        patient_expl = "This check-in suggests elevated stress. A human support person may want to review."
    else:
        patient_expl = "This check-in looks stable at a glance; still keep humans in the loop for decisions."

    return MindForgeAnalysis(
        distress_level=level,
        distress_label=label_s,
        recommended_action=action_str,
        risk_signals=risks,
        protective_signals=protective,
        medication_or_adherence_flags=med_flags,
        sleep_mood_flags=sm_flags,
        requires_human_review=level >= 1,
        escalation_flag=level >= 2,
        safe_response_template_id=tpl_id,
        safe_response=TEMPLATES[tpl_id],
        care_team_summary=care_summary,
        patient_safe_explanation=patient_expl,
        confidence=_stable_confidence(note, level),
        model_boundary="MindForge AI does not diagnose, treat, or replace licensed care.",
    )
|
|
|
|
def mock_base_supportive_paragraph(note: str) -> str:
    """
    Simulated *base model* output: generic supportive prose only (no routing JSON).
    Used when optional transformers path is unavailable or skipped.

    The paragraph is fixed; *note* is accepted but not used.
    """
    sentences = (
        "Thank you for sharing what’s on your mind.",
        "It takes courage to write this down.",
        "I’m here to reflect what you said in general terms: many people go through waves of stress, "
        "and it can help to talk with someone you trust when things feel heavy.",
        "This assistant cannot diagnose or treat conditions; please rely on qualified humans for "
        "clinical decisions and urgent safety concerns.",
    )
    return " ".join(sentences)
|
|
|
|
def build_comparison(note: str) -> tuple[str, str]:
    """Base (generic prose) vs fine-tuned stand-in (structured JSON from rules engine).

    Returns ``(base_text, structured_json)``. Blank input returns a hint and an
    empty JSON panel, mirroring run_analyze(). Previously a blank note was sent to
    the generator and classified anyway, which defaults to the "concerned" route.
    """
    _try_load_models()
    if not note or not note.strip():
        return "(Paste a note to compare.)", ""
    base_txt = optional_base_generate(note) or mock_base_supportive_paragraph(note)
    structured = analyze_note(note)
    fine_txt = json.dumps(structured.model_dump(), indent=2)
    return base_txt, fine_txt
|
|
|
|
# Illustrative metrics table for the "Evaluation" tab; one row per metric.
EVAL_DF = pd.DataFrame(
    [
        ("Valid JSON", "64%", "96%", "App reliability"),
        ("Distress Accuracy", "55%", "82%", "Better classification"),
        ("Action F1", "0.58", "0.86", "Safer routing"),
        ("Crisis Recall", "70%", "95%", "Safety-critical"),
        ("Unsafe Overreach Rate", "16%", "3%", "Avoids risky advice"),
        ("Protective Signal Extraction", "52%", "84%", "More nuanced review"),
    ],
    columns=["Metric", "Base Qwen", "Fine-Tuned Qwen", "Why it matters"],
)
|
|
# Markdown taxonomy table rendered on the "Methodology" tab.
TAXONOMY_MD = """
| Level | Label | Typical cues | Recommended action |
|---:|---|---|---|
| 0 | Calm | stable check-in, no urgent concern | supportive_dialogue |
| 1 | Concerned | stress, anxiety, sadness, mild impairment | monitor_and_follow_up |
| 2 | Distressed | sleep disruption, withdrawal, hopelessness, medication concern, functional impact | escalate_caregiver |
| 3 | Crisis | immediate danger, explicit self-harm intent, overdose risk, abuse, or urgent safety concern | escalate_crisis_protocol |
"""


# Methodology summary. The closing note states that routing is always rules-based,
# matching analyze_note() and the model status line. The old wording implied the
# model took over routing when USE_MODEL=true.
METHOD_MD = """
### MindForge AI — methodology snapshot

- **Base model:** Qwen Instruct
- **Fine-tuning method:** LoRA SFT
- **Dataset:** synthetic safety-labeled mental-health check-ins
- **Compute:** AMD Developer Cloud, ROCm, MI300X
- **Serving:** Hugging Face Space
- **Safety:** rule-based crisis override + pre-vetted response templates

Distress level, routing, and escalation in this demo always come from the deterministic rules engine. Optional weights (`USE_MODEL=true` with `BASE_MODEL_ID`) only power the base-model text in the comparison tab; when they are disabled or fail to load, a fixed sample paragraph is shown instead.
"""
|
|
|
|
# Page styles injected via a <style> tag. .mf-card pins its own text color because
# it forces a light background; without it, inherited light text from a dark theme
# would be nearly invisible on #fafafa.
CUSTOM_CSS = """
.mf-wrap { font-family: system-ui, sans-serif; }
.mf-banner {
  background: linear-gradient(120deg, #1e3a5f 0%, #3d2b5c 100%);
  color: #f8fafc;
  padding: 14px 18px;
  border-radius: 12px;
  margin-bottom: 14px;
  line-height: 1.5;
}
.mf-crisis {
  background: #7f1d1d;
  color: #fef2f2;
  padding: 12px 14px;
  border-radius: 10px;
  margin-bottom: 12px;
  font-weight: 600;
}
.mf-card {
  border: 1px solid #e5e7eb;
  border-radius: 12px;
  padding: 12px 14px;
  background: #fafafa;
  color: #111827;
}
.mf-card pre {
  white-space: pre-wrap;
  word-break: break-word;
  font-size: 12px;
}
.mf-kv { display: grid; grid-template-columns: 220px 1fr; gap: 8px 12px; align-items: start; }
.mf-kv b { color: #374151; }
footer.mf-foot { font-size: 12px; color: #6b7280; margin-top: 10px; }
"""
|
|
|
|
def model_status_line() -> str:
    """One-line Markdown hint about optional HF weights vs deterministic routing.

    Reads the state cached by _try_load_models() and does not start a load itself.
    When no load attempt has been recorded yet, it says so instead of reporting a
    failure with the text "None".
    """
    if _model_bundle:
        return (
            "Optional base weights loaded for the comparison tab; **routing and JSON always use the rules engine** "
            "(crisis override + templates)."
        )
    if _model_load_error is None:
        return "Optional model load not attempted yet; **routing uses rules only**."
    if _model_load_error == "skipped":
        return (
            "Running **deterministic rules classifier** (`USE_MODEL=false` or `BASE_MODEL_ID` unset). "
            "No Hub secrets required for this mode."
        )
    return (
        f"Optional model load failed ({_model_load_error}); **routing uses rules only**. "
        "Check `BASE_MODEL_ID` / GPU memory / token."
    )
|
|
|
|
def run_analyze(note: str) -> tuple[str, str, str]:
    """Handle the "Analyze note" button.

    Returns ``(html_panel, raw_json, status_markdown)``:
    - html_panel: a crisis banner (level 3 only) followed by a key/value card;
    - raw_json: the analysis as pretty-printed JSON;
    - status_markdown: the optional-model status hint.
    Blank input yields a placeholder panel, empty JSON and "Waiting for input.".
    """
    _try_load_models()
    if not note or not note.strip():
        return ("_(Paste a note to analyze.)_", "", "Waiting for input.")

    result = analyze_note(note)

    def listed(items: list[str]) -> str:
        # Comma-joined list, or an em dash placeholder when empty.
        return ", ".join(items) or "—"

    rows = [
        ("Distress level", result.distress_level),
        ("Distress label", result.distress_label),
        ("Recommended action", result.recommended_action),
        ("Risk signals", listed(result.risk_signals)),
        ("Protective signals", listed(result.protective_signals)),
        ("Medication / adherence flags", listed(result.medication_or_adherence_flags)),
        ("Sleep / mood flags", listed(result.sleep_mood_flags)),
        ("Requires human review", result.requires_human_review),
        ("Escalation flag", result.escalation_flag),
        ("Confidence", result.confidence),
        ("Safe response template", result.safe_response_template_id),
        ("Safe response", result.safe_response),
        ("Care team summary", result.care_team_summary),
        ("Patient-safe explanation", result.patient_safe_explanation),
        ("Safety disclaimer", result.model_boundary),
    ]
    cells = "".join(f"<b>{name}</b><span>{value}</span>" for name, value in rows)
    card = f"<div class='mf-card'><div class='mf-kv'>{cells}</div></div>"

    banner = ""
    if result.distress_level >= 3:
        banner = f"<div class='mf-crisis'>{CRISIS_BANNER}</div>"

    payload = json.dumps(result.model_dump(), indent=2)
    return banner + card, payload, model_status_line()
|
|
|
|
# Sample check-ins for the gr.Examples widgets; each row feeds a single textbox.
EXAMPLES = [
    [text]
    for text in (
        "I had a better day today. I went for a walk, called my sister, "
        "and finally finished a small task I had been avoiding.",
        "I feel anxious and overwhelmed. Work has been piling up "
        "and I keep worrying that I am falling behind.",
        "I have not slept much this week. I feel overwhelmed and keep missing work. "
        "I am not planning to hurt myself, but I feel like I am disappearing.",
        "I stopped taking my medication because it makes me feel strange. "
        "I have not slept in three days and I feel like everything is pointless.",
    )
]
|
|
|
|
with gr.Blocks(title="MindForge AI: Distress Detection & Conversational Escalation Engine") as demo:
    # Inject the mf-* CSS classes used by the banner, crisis box and result card.
    gr.HTML(f"<style>{CUSTOM_CSS.strip()}</style>")
    gr.Markdown(
        "# MindForge AI: Distress Detection & Conversational Escalation Engine\n"
        "**Team MindForge AI · hackathon demo** — structured distress review, escalation routing, and human-in-the-loop summaries."
    )
    gr.HTML(f"<div class='mf-banner'>{EXACT_DISCLAIMER}</div>")

    with gr.Tabs():
        with gr.Tab("Analyze Check-In"):
            note_in = gr.Textbox(
                label="Check-in, journal entry, caregiver note, or support note",
                lines=8,
                placeholder="Paste text here. This demo classifies and routes — it does not provide therapy.",
            )
            analyze_btn = gr.Button("Analyze note", variant="primary")
            status = gr.Markdown()
            out_panel = gr.HTML()
            out_json = gr.Code(label="Raw JSON (fine-tuned-style structured output)", language="json")

            # run_analyze returns (html_panel, raw_json, status) in this output order.
            analyze_btn.click(fn=run_analyze, inputs=[note_in], outputs=[out_panel, out_json, status])

            gr.Examples(EXAMPLES, inputs=[note_in], label="Example inputs")

        with gr.Tab("Base vs Fine-Tuned"):
            gr.Markdown(
                "**Base model panel:** generic supportive prose (optional live Qwen if `USE_MODEL=true`). "
                "**Fine-tuned panel:** structured JSON + routing from the MindForge rules engine "
                "(stand-in for LoRA JSON outputs)."
            )
            cmp_in = gr.Textbox(label="Note to compare", lines=6)
            cmp_btn = gr.Button("Run comparison", variant="primary")
            with gr.Row():
                base_out = gr.Textbox(label="Base model output (generic supportive prose)", lines=14)
                ft_out = gr.Code(label="Fine-tuned model output (structured JSON + action routing)", language="json")

            # build_comparison already returns (base_text, json_text) in output order,
            # so it is wired directly instead of through a pass-through wrapper.
            cmp_btn.click(fn=build_comparison, inputs=[cmp_in], outputs=[base_out, ft_out])
            gr.Examples(EXAMPLES, inputs=[cmp_in], label="Try an example")

        with gr.Tab("Evaluation"):
            gr.Markdown(
                "**Hackathon evaluation framework.** Replace with measured values after running held-out evals."
            )
            gr.Dataframe(EVAL_DF, label="Metrics (placeholder)", interactive=False)

        with gr.Tab("Methodology"):
            gr.Markdown(METHOD_MD)
            gr.Markdown("### Distress taxonomy (exact table)")
            gr.Markdown(TAXONOMY_MD)

    gr.Markdown(
        "<footer class='mf-foot'>Optional env: `USE_MODEL`, `BASE_MODEL_ID`, `ADAPTER_MODEL_ID`, `HF_TOKEN`. "
        f"Rules engine remains authoritative for safety routing. {EXACT_DISCLAIMER}</footer>"
    )