# MindForgeAI / app.py
# harriswarren's picture
# MindForge AI: Gradio distress detection demo with rules classifier and optional model hooks
# 2f8b061
# Raw
# History Blame Contribute Delete
# 22.8 kB
"""
MindForge AI — Distress Detection & Conversational Escalation Engine (Gradio demo).
Primary path: deterministic rules classifier so the Space always boots fast and stays reliable.
Optional: transformers + PEFT adapters when USE_MODEL=true and weights load successfully.
"""
from __future__ import annotations
import json
import logging
import os
import re
from typing import Any, Literal
import gradio as gr
import pandas as pd
from pydantic import BaseModel, Field
# Configure root logging once at import time and create the app-wide logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mindforge")
# ---------------------------------------------------------------------------
# Environment — optional HF model path (non-blocking)
# ---------------------------------------------------------------------------
# USE_MODEL is true only for "1", "true" or "yes" (case-insensitive, surrounding
# whitespace ignored); any other value, or an unset variable, disables the model path.
USE_MODEL = os.getenv("USE_MODEL", "false").strip().lower() in ("1", "true", "yes")
# Identifier handed to `from_pretrained` for the base model; empty means "not configured".
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "").strip()
# Optional adapter identifier; when empty, no PEFT adapter is applied.
ADAPTER_MODEL_ID = os.getenv("ADAPTER_MODEL_ID", "").strip()
# HUGGING_FACE_HUB_TOKEN is consulted only when HF_TOKEN is unset (an HF_TOKEN that is
# set but empty does NOT fall through to the second variable).
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGING_FACE_HUB_TOKEN", "")).strip()
# Lazy singletons for optional inference (never required for demo)
# _model_bundle: dict with "tokenizer", "model" and "torch" entries once loading succeeds.
_model_bundle: dict[str, Any] | None = None
# _model_load_error: None until a load is attempted; then "skipped" or the exception text.
_model_load_error: str | None = None
def _try_load_models() -> None:
    """Load the optional tokenizer/model pair at most once.

    The outcome is recorded in the module-level singletons: on success
    ``_model_bundle`` holds the tokenizer, the model and the ``torch`` module;
    otherwise ``_model_load_error`` holds ``"skipped"`` (model path disabled or
    no base model configured) or the text of the exception that was raised.
    Nothing is ever raised to the caller.
    """
    global _model_bundle, _model_load_error

    already_attempted = _model_bundle is not None or _model_load_error is not None
    if already_attempted:
        return

    if not (USE_MODEL and BASE_MODEL_ID):
        _model_load_error = "skipped"
        return

    try:
        # Heavy dependencies are imported lazily so the rules-only demo boots without them.
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        hub_token = HF_TOKEN or None
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=hub_token)

        on_gpu = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            token=hub_token,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
        )
        if ADAPTER_MODEL_ID:
            model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID, token=hub_token)

        _model_bundle = {"tokenizer": tokenizer, "model": model, "torch": torch}
        logger.info("Optional MindForge model weights loaded.")
    except Exception as e:
        # Deliberate best-effort: any failure (missing package, download, OOM) falls back to rules.
        _model_load_error = str(e)
        logger.warning("Optional model load failed; using rules fallback. %s", e)
def optional_base_generate(note: str, max_new_tokens: int = 120) -> str | None:
    """Return a short supportive continuation from the optional base model.

    Best-effort only: the result is never used for routing or safety decisions,
    which always come from the rules engine and the fixed templates.

    Args:
        note: User text; stripped and truncated to 1500 characters for the prompt.
        max_new_tokens: Upper bound on generated tokens (greedy decoding).

    Returns:
        The generated continuation without the prompt, or ``None`` when no model
        is loaded, generation fails, or the continuation is empty.
    """
    _try_load_models()
    if not _model_bundle:
        return None
    try:
        torch = _model_bundle["torch"]
        tok = _model_bundle["tokenizer"]
        model = _model_bundle["model"]
        prompt = (
            "You are a concise assistant. Reply with one short paragraph that acknowledges the user's "
            "message in general supportive language. Do not diagnose, prescribe medication, or give "
            "instructions related to self-harm. User message:\n"
            f"{note.strip()[:1500]}"
        )
        inputs = tok(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # A causal LM returns prompt tokens followed by the new tokens. Dropping the
        # prompt by token count is reliable, whereas comparing the decoded text with
        # the prompt string breaks whenever the tokenizer round-trip is not byte-exact.
        continuation = out[0][prompt_len:]
        text = tok.decode(continuation, skip_special_tokens=True)
        return text.strip() or None
    except Exception as e:
        # Deliberate best-effort: callers fall back to the canned paragraph on None.
        logger.warning("optional_base_generate failed: %s", e)
        return None
# ---------------------------------------------------------------------------
# Safe response templates (exact strings from spec)
# ---------------------------------------------------------------------------
# Keyed by template id; `route_template` picks the id for a distress level and
# `analyze_note` copies the text verbatim into `safe_response`. Do not edit the wording.
TEMPLATES: dict[str, str] = {
    # Level 0 (calm)
    "supportive_dialogue_01": (
        "Thanks for checking in. I’m glad you shared this. Let’s keep tracking how you’re feeling "
        "and note anything that changes."
    ),
    # Level 1 (concerned)
    "monitor_followup_01": (
        "I hear that things feel difficult right now. Let’s slow down, capture what changed, "
        "and consider a supportive follow-up."
    ),
    # Level 2 (distressed)
    "caregiver_support_01": (
        "I hear you. This sounds important enough for a human support person to review soon. "
        "Let’s capture the key signals and connect you with support."
    ),
    # Level 3 (crisis)
    "crisis_protocol_01": (
        "This may require immediate human support. If there is immediate danger, contact local "
        "emergency services or a crisis hotline now. This system will flag the note for urgent "
        "human review."
    ),
}
# Disclaimer shown in the page banner and repeated in the footer.
EXACT_DISCLAIMER = (
    "MindForge AI does not diagnose, treat, or replace licensed care. It is a human-in-the-loop "
    "risk review assistant that organizes distress signals and recommends escalation pathways for review."
)
# Red banner prepended to the analysis panel when the distress level is 3.
# NOTE(review): wording differs slightly from `crisis_protocol_01` ("flags" vs "will flag") —
# confirm that is intentional.
CRISIS_BANNER = (
    "This may require immediate human support. If there is immediate danger, contact local emergency "
    "services or a crisis hotline now. This system flags the note for urgent human review."
)
# Keyword buckets — priority: crisis > distressed > concerned > calm > default concerned
# Entries are matched as plain lowercase substrings of the normalized note, so each
# spelling variant that must trigger a bucket has to be listed explicitly (for
# example "suicide" is not a substring of "suicidal").
CRISIS_KW = [
    "kill myself",
    "end my life",
    "suicide",
    "overdose",
    "hurt myself",
    "hurt someone",
    "goodbye forever",
    "can't go on",
    "cant go on",
    "i have a plan",
    # Adjective form is not covered by the "suicide" substring.
    "suicidal",
]
DISTRESSED_KW = [
    "haven't slept",
    "havent slept",
    "not sleeping",
    "stopped taking medication",
    "skipping meds",
    "disappearing",
    "hopeless",
    "numb",
    "can't function",
    "cant function",
    "missed work",
    "isolating",
    # Uncontracted / possessive phrasings that the entries above miss, e.g.
    # "I have not slept in three days" or "I stopped taking my medication".
    "have not slept",
    "stopped taking my med",
    # Already treated as a hopelessness cue by the risk-signal extractor.
    "pointless",
]
CONCERNED_KW = [
    "anxious",
    "overwhelmed",
    "sad",
    "stressed",
    "worried",
    "lonely",
    "panic",
]
CALM_KW = [
    "better",
    "okay",
    "walk",
    "called",
    "grateful",
    "good day",
    "stable",
]
class MindForgeAnalysis(BaseModel):
    """Structured output aligned with the hackathon JSON contract.

    Field order is the order in which keys appear in the dumped JSON.
    """

    # 0 calm, 1 concerned, 2 distressed, 3 crisis.
    distress_level: Literal[0, 1, 2, 3]
    distress_label: Literal["calm", "concerned", "distressed", "crisis"]
    # Routing decision paired with the distress level.
    recommended_action: Literal[
        "supportive_dialogue",
        "monitor_and_follow_up",
        "escalate_caregiver",
        "escalate_crisis_protocol",
    ]
    # Human-readable signal descriptions; each list defaults to empty.
    risk_signals: list[str] = Field(default_factory=list)
    protective_signals: list[str] = Field(default_factory=list)
    medication_or_adherence_flags: list[str] = Field(default_factory=list)
    sleep_mood_flags: list[str] = Field(default_factory=list)
    requires_human_review: bool
    escalation_flag: bool
    # Key into the response-template table and the template text itself.
    safe_response_template_id: str
    safe_response: str
    care_team_summary: str
    patient_safe_explanation: str
    # Validated to lie within [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0)
    # Short boundary/disclaimer statement shipped with every result.
    model_boundary: str
def _normalize(text: str) -> str:
return re.sub(r"\s+", " ", text.strip().lower())
def _contains_any(hay: str, needles: list[str]) -> list[str]:
matched = []
for n in needles:
if n in hay:
matched.append(n)
return matched
def _stable_confidence(note: str, level: int) -> float:
"""Deterministic pseudo-confidence from text — no RNG for reproducible demos."""
base = sum(ord(c) for c in note[:800])
tweak = (base % 17) / 100.0
anchor = 0.74 + tweak + 0.01 * level
return round(min(0.93, max(0.62, anchor)), 2)
def classify_distress(note: str) -> tuple[int, str]:
    """Map a note to ``(level, label)`` using the keyword buckets.

    Levels: 0 calm, 1 concerned, 2 distressed, 3 crisis. Buckets are checked
    from most to least severe and the first bucket with any match wins. A note
    that matches nothing is treated as concerned (level 1).
    """
    text = _normalize(note)
    tiers = (
        (3, "crisis", CRISIS_KW),
        (2, "distressed", DISTRESSED_KW),
        (1, "concerned", CONCERNED_KW),
        (0, "calm", CALM_KW),
    )
    for tier_level, tier_label, keywords in tiers:
        if _contains_any(text, keywords):
            return tier_level, tier_label
    # Uncertain input defaults upward to "concerned" rather than "calm".
    return 1, "concerned"
def _extract_risk_protective(note: str, level: int, label: str) -> tuple[list[str], list[str]]:
    """Collect review-summary signals from a note — no clinical claims.

    Args:
        note: Raw note text; normalized internally before matching.
        level: Distress level already assigned; level 3 or above adds an
            urgent-safety risk entry.
        label: Distress label; accepted for signature compatibility and not
            used by the extraction itself.

    Returns:
        ``(risk_signals, protective_signals)``. The protective list is never
        empty: a generic entry is added when nothing specific matched.
    """
    h = _normalize(note)

    def mentions(*stems: str) -> bool:
        """True when any stem occurs as a substring of the normalized note."""
        return any(stem in h for stem in stems)

    risks: list[str] = []
    if mentions("sleep", "slept", "insomnia", "not sleeping"):
        risks.append("sleep disruption")
    if mentions("withdraw", "isolat", "disappear"):
        risks.append("social withdrawal")
    if mentions("hopeless", "pointless", "numb"):
        risks.append("low mood / hopelessness narrative")
    if mentions("work", "missed work", "can't function"):
        risks.append("functional impact mentioned")
    if mentions("medication", "meds", "stopped taking", "skipping"):
        risks.append("medication or adherence mention")
    if level >= 3:
        risks.append("potential urgent safety concern (keyword route)")

    protective: list[str] = []
    if mentions("not planning to hurt", "don't want to die", "dont want to die"):
        protective.append("explicit absence of immediate self-harm plan (self-report)")
    # "help" must start a word and must not be "helpless": a bare substring test
    # would report help-seeking for "I feel helpless" or "that was unhelpful".
    if mentions("called", "talked to", "reach out") or re.search(r"\bhelp(?!less)", h):
        protective.append("help-seeking or connection behavior mentioned")
    if mentions("walk", "grateful", "better"):
        protective.append("positive routine or stabilizing activity mentioned")
    if not protective:
        protective.append("help-seeking communication (review context)")

    # Each description is appended at most once above, so no de-duplication pass is needed.
    return risks, protective
def _med_sleep_flags(note: str) -> tuple[list[str], list[str]]:
    """Return ``(medication_flags, sleep_mood_flags)`` found by substring cues in *note*."""
    text = _normalize(note)

    def hit(stems: tuple[str, ...]) -> bool:
        return any(stem in text for stem in stems)

    medication_cues = ("medication", "meds", "pill", "stopped taking", "skipping")
    sleep_cues = ("sleep", "slept", "insomnia", "not sleeping", "haven't slept")
    mood_cues = ("overwhelm", "anxious", "panic", "sad", "mood")

    med: list[str] = (
        ["medication adherence mentioned — flag for human review"] if hit(medication_cues) else []
    )
    sleep_mood: list[str] = []
    if hit(sleep_cues):
        sleep_mood.append("sleep disruption")
    if hit(mood_cues):
        sleep_mood.append("mood distress indicators")
    return med, sleep_mood
def route_template(level: int) -> tuple[str, str]:
    """Return ``(safe_response_template_id, recommended_action)`` for a distress level.

    Levels 0, 1 and 2 have dedicated routes; every other value takes the
    crisis route.
    """
    routes = {
        0: ("supportive_dialogue_01", "supportive_dialogue"),
        1: ("monitor_followup_01", "monitor_and_follow_up"),
        2: ("caregiver_support_01", "escalate_caregiver"),
    }
    crisis_route = ("crisis_protocol_01", "escalate_crisis_protocol")
    return routes.get(level, crisis_route)
def analyze_note(note: str) -> MindForgeAnalysis:
    """Run the deterministic rules pipeline on *note* and package the result.

    Classification, template routing, signal extraction and the summary texts
    all come from the rules engine; optional models never influence them.
    Human review is requested from level 1 upward and the escalation flag is
    set from level 2 upward.
    """
    level, label_s = classify_distress(note)
    tpl_id, action_str = route_template(level)
    risks, protective = _extract_risk_protective(note, level, label_s)
    med_flags, sm_flags = _med_sleep_flags(note)

    # Care-team summary: the crisis wording takes precedence over the level-2 wording.
    if level >= 3:
        care_summary = (
            "The note triggers crisis-route safeguards. Urgent human review and crisis protocols should be considered."
        )
    elif level >= 2:
        care_summary = (
            "The note suggests elevated distress with sleep or mood-related concerns. Human review is recommended."
        )
    elif level == 1:
        care_summary = (
            "The note suggests mild to moderate stress patterns; scheduled human follow-up is reasonable."
        )
    else:
        care_summary = "The note appears stable on quick scan; continue routine monitoring."

    # Patient-facing explanation: levels 2 and 3 share the same wording.
    if level >= 2:
        patient_expl = "This check-in shows signs of distress. A human support person should review it soon."
    elif level == 1:
        patient_expl = "This check-in suggests elevated stress. A human support person may want to review."
    else:
        patient_expl = "This check-in looks stable at a glance; still keep humans in the loop for decisions."

    return MindForgeAnalysis(
        distress_level=level,  # type: ignore[arg-type]
        distress_label=label_s,  # type: ignore[arg-type]
        recommended_action=action_str,  # type: ignore[arg-type]
        risk_signals=risks,
        protective_signals=protective,
        medication_or_adherence_flags=med_flags,
        sleep_mood_flags=sm_flags,
        requires_human_review=level >= 1,
        escalation_flag=level >= 2,
        safe_response_template_id=tpl_id,
        safe_response=TEMPLATES[tpl_id],
        care_team_summary=care_summary,
        patient_safe_explanation=patient_expl,
        confidence=_stable_confidence(note, level),
        model_boundary="MindForge AI does not diagnose, treat, or replace licensed care.",
    )
def mock_base_supportive_paragraph(note: str) -> str:
    """Return the canned stand-in for base-model output.

    The text is fixed, generic supportive prose with no routing information;
    *note* is accepted for signature compatibility and does not affect it.
    """
    fragments = [
        "Thank you for sharing what’s on your mind. It takes courage to write this down.",
        "I’m here to reflect what you said in general terms: many people go through waves of stress,",
        "and it can help to talk with someone you trust when things feel heavy.",
        "This assistant cannot diagnose or treat conditions; please rely on qualified humans for",
        "clinical decisions and urgent safety concerns.",
    ]
    return " ".join(fragments)
def build_comparison(note: str) -> tuple[str, str]:
    """Produce the two panels of the comparison tab for *note*.

    Returns:
        ``(base_text, structured_json)``. The first item is generic prose from
        the optional base model, or the canned paragraph when no model output
        is available. The second is the rules-engine analysis as indented JSON.
    """
    _try_load_models()
    base_txt = optional_base_generate(note) or mock_base_supportive_paragraph(note)
    structured = analyze_note(note)
    # ensure_ascii=False keeps the templates' typographic apostrophes readable
    # in the JSON pane instead of rendering them as "\u2019" escapes.
    fine_txt = json.dumps(structured.model_dump(), indent=2, ensure_ascii=False)
    return base_txt, fine_txt
# Table shown on the Evaluation tab. The UI labels these values as placeholders
# to be replaced with measured results from held-out evaluation.
EVAL_DF = pd.DataFrame(
    {
        "Metric": [
            "Valid JSON",
            "Distress Accuracy",
            "Action F1",
            "Crisis Recall",
            "Unsafe Overreach Rate",
            "Protective Signal Extraction",
        ],
        # Columns are parallel lists: row i of each column describes Metric[i].
        "Base Qwen": ["64%", "55%", "0.58", "70%", "16%", "52%"],
        "Fine-Tuned Qwen": ["96%", "82%", "0.86", "95%", "3%", "84%"],
        "Why it matters": [
            "App reliability",
            "Better classification",
            "Safer routing",
            "Safety-critical",
            "Avoids risky advice",
            "More nuanced review",
        ],
    }
)
# Markdown table of the four distress levels, rendered on the Methodology tab.
TAXONOMY_MD = """
| Level | Label | Typical cues | Recommended action |
|---:|---|---|---|
| 0 | Calm | stable check-in, no urgent concern | supportive_dialogue |
| 1 | Concerned | stress, anxiety, sadness, mild impairment | monitor_and_follow_up |
| 2 | Distressed | sleep disruption, withdrawal, hopelessness, medication concern, functional impact | escalate_caregiver |
| 3 | Crisis | immediate danger, explicit self-harm intent, overdose risk, abuse, or urgent safety concern | escalate_crisis_protocol |
"""
# Markdown summary of the project methodology, rendered on the Methodology tab.
METHOD_MD = """
### MindForge AI — methodology snapshot
- **Base model:** Qwen Instruct
- **Fine-tuning method:** LoRA SFT
- **Dataset:** synthetic safety-labeled mental-health check-ins
- **Compute:** AMD Developer Cloud, ROCm, MI300X
- **Serving:** Hugging Face Space
- **Safety:** rule-based crisis override + pre-vetted response templates
This demo uses the rules engine for deterministic escalation when `USE_MODEL=false` or when weights fail to load.
"""
# Stylesheet injected through an inline <style> tag when the UI is built.
# Classes: .mf-banner (disclaimer banner), .mf-crisis (crisis alert),
# .mf-card / .mf-kv (two-column result panel), footer.mf-foot (page footer).
CUSTOM_CSS = """
.mf-wrap { font-family: system-ui, sans-serif; }
.mf-banner {
background: linear-gradient(120deg, #1e3a5f 0%, #3d2b5c 100%);
color: #f8fafc;
padding: 14px 18px;
border-radius: 12px;
margin-bottom: 14px;
line-height: 1.5;
}
.mf-crisis {
background: #7f1d1d;
color: #fef2f2;
padding: 12px 14px;
border-radius: 10px;
margin-bottom: 12px;
font-weight: 600;
}
.mf-card {
border: 1px solid #e5e7eb;
border-radius: 12px;
padding: 12px 14px;
background: #fafafa;
}
.mf-card pre {
white-space: pre-wrap;
word-break: break-word;
font-size: 12px;
}
.mf-kv { display: grid; grid-template-columns: 220px 1fr; gap: 8px 12px; align-items: start; }
.mf-kv b { color: #374151; }
footer.mf-foot { font-size: 12px; color: #6b7280; margin-top: 10px; }
"""
def model_status_line() -> str:
    """Describe, in one Markdown line, whether optional weights are in use.

    Three cases are reported: weights loaded, model path skipped, or a load
    failure (the recorded error text is included in the message).
    """
    if _model_bundle:
        status = (
            "Optional base weights loaded for the comparison tab; **routing and JSON always use the rules engine** "
            "(crisis override + templates)."
        )
    elif _model_load_error == "skipped":
        status = (
            "Running **deterministic rules classifier** (`USE_MODEL=false` or `BASE_MODEL_ID` unset). "
            "No Hub secrets required for this mode."
        )
    else:
        status = (
            f"Optional model load failed ({_model_load_error}); **routing uses rules only**. "
            "Check `BASE_MODEL_ID` / GPU memory / token."
        )
    return status
def run_analyze(note: str) -> tuple[str, str, str]:
    """Analyze *note* for the "Analyze Check-In" tab.

    Returns:
        ``(panel_html, raw_json, status_markdown)``. For empty or
        whitespace-only input a placeholder panel, an empty JSON string and a
        waiting message are returned without running the analysis.
    """
    _try_load_models()
    if not note or not note.strip():
        # The first output feeds an HTML component, so the hint is emphasised with
        # HTML; Markdown underscores would be displayed literally there.
        return ("<em>(Paste a note to analyze.)</em>", "", "Waiting for input.")
    a = analyze_note(note)
    # The crisis banner is shown only for level 3.
    crisis = f"<div class='mf-crisis'>{CRISIS_BANNER}</div>" if a.distress_level >= 3 else ""
    # All interpolated values are produced by the rules engine from fixed strings,
    # never copied from the user's note.
    kv = (
        f"<div class='mf-card'><div class='mf-kv'>"
        f"<b>Distress level</b><span>{a.distress_level}</span>"
        f"<b>Distress label</b><span>{a.distress_label}</span>"
        f"<b>Recommended action</b><span>{a.recommended_action}</span>"
        f"<b>Risk signals</b><span>{', '.join(a.risk_signals) or '—'}</span>"
        f"<b>Protective signals</b><span>{', '.join(a.protective_signals) or '—'}</span>"
        f"<b>Medication / adherence flags</b><span>{', '.join(a.medication_or_adherence_flags) or '—'}</span>"
        f"<b>Sleep / mood flags</b><span>{', '.join(a.sleep_mood_flags) or '—'}</span>"
        f"<b>Requires human review</b><span>{a.requires_human_review}</span>"
        f"<b>Escalation flag</b><span>{a.escalation_flag}</span>"
        f"<b>Confidence</b><span>{a.confidence}</span>"
        f"<b>Safe response template</b><span>{a.safe_response_template_id}</span>"
        f"<b>Safe response</b><span>{a.safe_response}</span>"
        f"<b>Care team summary</b><span>{a.care_team_summary}</span>"
        f"<b>Patient-safe explanation</b><span>{a.patient_safe_explanation}</span>"
        f"<b>Safety disclaimer</b><span>{a.model_boundary}</span>"
        f"</div></div>"
    )
    # ensure_ascii=False keeps typographic apostrophes in the template text
    # readable instead of emitting "\u2019" escapes in the JSON pane.
    raw_json = json.dumps(a.model_dump(), indent=2, ensure_ascii=False)
    return crisis + kv, raw_json, model_status_line()
# Sample notes offered as one-click examples on both input tabs. Each inner list
# holds the value for the single textbox input.
# NOTE(review): the third example contains the phrase "hurt myself" (negated), which is
# also a CRISIS_KW entry matched by substring — confirm the level it is meant to show.
EXAMPLES = [
    [
        "I had a better day today. I went for a walk, called my sister, and finally finished a small task I had been avoiding."
    ],
    [
        "I feel anxious and overwhelmed. Work has been piling up and I keep worrying that I am falling behind."
    ],
    [
        "I have not slept much this week. I feel overwhelmed and keep missing work. I am not planning to hurt myself, but I feel like I am disappearing."
    ],
    [
        "I stopped taking my medication because it makes me feel strange. I have not slept in three days and I feel like everything is pointless."
    ],
]
# Build the Gradio UI: disclaimer banner, four tabs, and a footer. Components
# appear on the page in the order they are created inside each context manager.
with gr.Blocks(title="MindForge AI: Distress Detection & Conversational Escalation Engine") as demo:
    # Inline stylesheet keeps styling compatible across Gradio 4–6 (Blocks `css=` moved in Gradio 6).
    gr.HTML(f"<style>{CUSTOM_CSS.strip()}</style>")
    gr.Markdown(
        "# MindForge AI: Distress Detection & Conversational Escalation Engine\n"
        "**Team MindForge AI · hackathon demo** — structured distress review, escalation routing, and human-in-the-loop summaries."
    )
    gr.HTML(f"<div class='mf-banner'>{EXACT_DISCLAIMER}</div>")
    with gr.Tabs():
        # Tab 1: single-note analysis (HTML summary, raw JSON, status line).
        with gr.Tab("Analyze Check-In"):
            note_in = gr.Textbox(
                label="Check-in, journal entry, caregiver note, or support note",
                lines=8,
                placeholder="Paste text here. This demo classifies and routes — it does not provide therapy.",
            )
            analyze_btn = gr.Button("Analyze note", variant="primary")
            status = gr.Markdown()
            out_panel = gr.HTML()
            out_json = gr.Code(label="Raw JSON (fine-tuned-style structured output)", language="json")
            # Output order must match the tuple returned by run_analyze: panel, JSON, status.
            analyze_btn.click(fn=run_analyze, inputs=[note_in], outputs=[out_panel, out_json, status])
            gr.Examples(EXAMPLES, inputs=[note_in], label="Example inputs")
        # Tab 2: generic prose vs structured rules-engine JSON, side by side.
        with gr.Tab("Base vs Fine-Tuned"):
            gr.Markdown(
                "**Base model panel:** generic supportive prose (optional live Qwen if `USE_MODEL=true`). "
                "**Fine-tuned panel:** structured JSON + routing from the MindForge rules engine "
                "(stand-in for LoRA JSON outputs)."
            )
            cmp_in = gr.Textbox(label="Note to compare", lines=6)
            cmp_btn = gr.Button("Run comparison", variant="primary")
            with gr.Row():
                base_out = gr.Textbox(label="Base model output (generic supportive prose)", lines=14)
                ft_out = gr.Code(label="Fine-tuned model output (structured JSON + action routing)", language="json")

            def _cmp(note: str):
                # Thin adapter around build_comparison; returns (base text, JSON text).
                b, f = build_comparison(note)
                return b, f

            cmp_btn.click(fn=_cmp, inputs=[cmp_in], outputs=[base_out, ft_out])
            gr.Examples(EXAMPLES, inputs=[cmp_in], label="Try an example")
        # Tab 3: static metrics table (placeholder values).
        with gr.Tab("Evaluation"):
            gr.Markdown(
                "**Hackathon evaluation framework.** Replace with measured values after running held-out evals."
            )
            gr.Dataframe(EVAL_DF, label="Metrics (placeholder)", interactive=False)
        # Tab 4: methodology notes and the distress taxonomy table.
        with gr.Tab("Methodology"):
            gr.Markdown(METHOD_MD)
            gr.Markdown("### Distress taxonomy (exact table)")
            gr.Markdown(TAXONOMY_MD)
    # Footer: lists the optional environment variables and repeats the disclaimer.
    gr.Markdown(
        f"<footer class='mf-foot'>Optional env: `USE_MODEL`, `BASE_MODEL_ID`, `ADAPTER_MODEL_ID`, `HF_TOKEN`. "
        f"Rules engine remains authoritative for safety routing. {EXACT_DISCLAIMER}</footer>"
    )