File size: 22,773 Bytes
2f8b061 9f64956 2f8b061 9f64956 2f8b061 9f64956 2f8b061 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 
513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 | """
MindForge AI — Distress Detection & Conversational Escalation Engine (Gradio demo).
Primary path: deterministic rules classifier so the Space always boots fast and stays reliable.
Optional: transformers + PEFT adapters when USE_MODEL=true and weights load successfully.
"""
from __future__ import annotations
import json
import logging
import os
import re
from typing import Any, Literal
import gradio as gr
import pandas as pd
from pydantic import BaseModel, Field
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("mindforge")
# ---------------------------------------------------------------------------
# Environment — optional HF model path (non-blocking)
# ---------------------------------------------------------------------------
USE_MODEL = os.getenv("USE_MODEL", "false").strip().lower() in ("1", "true", "yes")
BASE_MODEL_ID = os.getenv("BASE_MODEL_ID", "").strip()
ADAPTER_MODEL_ID = os.getenv("ADAPTER_MODEL_ID", "").strip()
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HUGGING_FACE_HUB_TOKEN", "")).strip()
# Lazy singletons for optional inference (never required for demo)
_model_bundle: dict[str, Any] | None = None
_model_load_error: str | None = None
def _try_load_models() -> None:
    """Load the optional tokenizer/model pair at most once; never raises.

    The outcome is recorded in module globals:
    * success -> ``_model_bundle`` holds the tokenizer, the model and the ``torch`` module;
    * disabled (``USE_MODEL`` falsy or ``BASE_MODEL_ID`` empty) -> ``_model_load_error = "skipped"``;
    * any exception while importing or loading -> ``_model_load_error`` holds its message.
    Once either global is set, later calls return immediately.
    """
    global _model_bundle, _model_load_error
    already_attempted = _model_bundle is not None or _model_load_error is not None
    if already_attempted:
        return
    if not (USE_MODEL and BASE_MODEL_ID):
        _model_load_error = "skipped"
        return
    try:
        # Heavy optional dependencies are imported lazily so the rules-only path needs none of them.
        import torch
        from peft import PeftModel
        from transformers import AutoModelForCausalLM, AutoTokenizer

        auth_token = HF_TOKEN or None
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, token=auth_token)
        gpu = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            token=auth_token,
            torch_dtype=torch.float16 if gpu else torch.float32,
            device_map="auto" if gpu else None,
        )
        if ADAPTER_MODEL_ID:
            # Wrap the base weights with the configured PEFT adapter.
            model = PeftModel.from_pretrained(model, ADAPTER_MODEL_ID, token=auth_token)
        _model_bundle = {"tokenizer": tokenizer, "model": model, "torch": torch}
        logger.info("Optional MindForge model weights loaded.")
    except Exception as exc:
        # Best-effort by design: remember why, and let callers fall back to the rules engine.
        _model_load_error = str(exc)
        logger.warning("Optional model load failed; using rules fallback. %s", exc)
def optional_base_generate(note: str, max_new_tokens: int = 120) -> str | None:
    """
    Best-effort generic supportive continuation from the (optional) base model.

    Not used for routing/safety decisions — those always come from rules + templates.

    Args:
        note: Free-text user note; only the first 1500 characters (after stripping) are sent.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        Only the newly generated text (prompt excluded), stripped. Returns None when the
        optional model is unavailable, generation raises, or the continuation is blank.
    """
    _try_load_models()
    if not _model_bundle:
        return None
    try:
        torch = _model_bundle["torch"]
        tok = _model_bundle["tokenizer"]
        model = _model_bundle["model"]
        prompt = (
            "You are a concise assistant. Reply with one short paragraph that acknowledges the user's "
            "message in general supportive language. Do not diagnose, prescribe medication, or give "
            "instructions related to self-harm. User message:\n"
            f"{note.strip()[:1500]}"
        )
        inputs = tok(prompt, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = {k: v.cuda() for k, v in inputs.items()}
        # Number of prompt tokens; generate() on a causal LM returns prompt + continuation.
        prompt_len = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
        # Decode only the new tokens. Comparing decoded text against the prompt string is
        # unreliable: a tokenize/decode round trip can alter whitespace, which would leak
        # the whole prompt into the "base model" panel.
        text = tok.decode(out[0][prompt_len:], skip_special_tokens=True)
        return text.strip() or None
    except Exception as e:
        logger.warning("optional_base_generate failed: %s", e)
        return None
# ---------------------------------------------------------------------------
# Safe response templates (exact strings from spec)
# ---------------------------------------------------------------------------
# Template id -> pre-vetted reply text. route_template() returns one of these ids and
# analyze_note() copies the matching text into MindForgeAnalysis.safe_response.
# Keep the wording verbatim; it comes from the spec.
TEMPLATES: dict[str, str] = {
    # distress level 0 -> recommended_action "supportive_dialogue"
    "supportive_dialogue_01": (
        "Thanks for checking in. I’m glad you shared this. Let’s keep tracking how you’re feeling "
        "and note anything that changes."
    ),
    # distress level 1 -> "monitor_and_follow_up"
    "monitor_followup_01": (
        "I hear that things feel difficult right now. Let’s slow down, capture what changed, "
        "and consider a supportive follow-up."
    ),
    # distress level 2 -> "escalate_caregiver"
    "caregiver_support_01": (
        "I hear you. This sounds important enough for a human support person to review soon. "
        "Let’s capture the key signals and connect you with support."
    ),
    # distress level 3 (and any unexpected level) -> "escalate_crisis_protocol"
    "crisis_protocol_01": (
        "This may require immediate human support. If there is immediate danger, contact local "
        "emergency services or a crisis hotline now. This system will flag the note for urgent "
        "human review."
    ),
}
# Displayed in the page banner and repeated in the footer of the Gradio UI.
EXACT_DISCLAIMER = (
    "MindForge AI does not diagnose, treat, or replace licensed care. It is a human-in-the-loop "
    "risk review assistant that organizes distress signals and recommends escalation pathways for review."
)
# Rendered above the analysis card by run_analyze() when distress_level >= 3.
# NOTE(review): wording differs slightly from TEMPLATES["crisis_protocol_01"]
# ("flags the note" vs "will flag the note"); confirm which one the spec intends.
CRISIS_BANNER = (
    "This may require immediate human support. If there is immediate danger, contact local emergency "
    "services or a crisis hotline now. This system flags the note for urgent human review."
)
# Keyword buckets — priority: crisis > distressed > concerned > calm > default concerned.
# Each entry is a plain substring tested against the normalized note (lower-cased,
# whitespace-collapsed), so it also fires inside longer words and inside negations.
# NOTE(review): "numb" also matches "number", and "hurt myself" matches
# "not planning to hurt myself" (crisis route). Over-triggering toward escalation seems
# intended for safety — confirm. Explicit phrases such as "want to die" are not covered
# yet; adding them would also catch "don't want to die", which elsewhere counts as protective.

# Level 3 — crisis route.
CRISIS_KW = [
    "kill myself",
    "end my life",
    "suicide",
    "overdose",
    "hurt myself",
    "hurt someone",
    "goodbye forever",
    "can't go on",
    "cant go on",
    "cannot go on",  # spelling variant of "can't go on"
    "i have a plan",
]
# Level 2 — distressed route (sleep, medication, withdrawal, hopelessness, functional impact).
DISTRESSED_KW = [
    "haven't slept",
    "havent slept",
    "not sleeping",
    "not slept",  # "have not slept", "had not slept"
    "stopped taking medication",
    "stopped taking my medication",
    "stopped taking my meds",
    "skipping meds",
    "disappearing",
    "hopeless",
    "numb",
    "can't function",
    "cant function",
    "missed work",
    "missing work",
    "isolating",
]
# Level 1 — concerned route.
CONCERNED_KW = [
    "anxious",
    "overwhelmed",
    "sad",
    "stressed",
    "worried",
    "lonely",
    "panic",
]
# Level 0 — calm route (only reached when no higher bucket matched).
CALM_KW = [
    "better",
    "okay",
    "walk",
    "called",
    "grateful",
    "good day",
    "stable",
]
class MindForgeAnalysis(BaseModel):
    """Structured output aligned with the hackathon JSON contract.

    Built by analyze_note() and serialized with model_dump() for the raw-JSON panels.
    pydantic dumps fields in declaration order, so reordering fields changes the JSON key order.
    """
    # 0 calm, 1 concerned, 2 distressed, 3 crisis (see classify_distress()).
    distress_level: Literal[0, 1, 2, 3]
    distress_label: Literal["calm", "concerned", "distressed", "crisis"]
    # Action selected by route_template() from distress_level.
    recommended_action: Literal[
        "supportive_dialogue",
        "monitor_and_follow_up",
        "escalate_caregiver",
        "escalate_crisis_protocol",
    ]
    # Fixed signal phrases from the keyword helpers; each defaults to an empty list.
    risk_signals: list[str] = Field(default_factory=list)
    protective_signals: list[str] = Field(default_factory=list)
    medication_or_adherence_flags: list[str] = Field(default_factory=list)
    sleep_mood_flags: list[str] = Field(default_factory=list)
    # analyze_note() sets this for distress_level >= 1.
    requires_human_review: bool
    # analyze_note() sets this for distress_level >= 2.
    escalation_flag: bool
    # Key into TEMPLATES; safe_response holds that template's text.
    safe_response_template_id: str
    safe_response: str
    care_team_summary: str
    patient_safe_explanation: str
    # Deterministic pseudo-confidence from _stable_confidence(); validated to [0, 1].
    confidence: float = Field(ge=0.0, le=1.0)
    model_boundary: str
def _normalize(text: str) -> str:
return re.sub(r"\s+", " ", text.strip().lower())
def _contains_any(hay: str, needles: list[str]) -> list[str]:
matched = []
for n in needles:
if n in hay:
matched.append(n)
return matched
def _stable_confidence(note: str, level: int) -> float:
"""Deterministic pseudo-confidence from text — no RNG for reproducible demos."""
base = sum(ord(c) for c in note[:800])
tweak = (base % 17) / 100.0
anchor = 0.74 + tweak + 0.01 * level
return round(min(0.93, max(0.62, anchor)), 2)
def classify_distress(note: str) -> tuple[int, str]:
    """
    Map a note onto the rules-first distress taxonomy and return (level, label).

    Levels: 0 calm, 1 concerned, 2 distressed, 3 crisis. Keyword buckets are checked from
    most to least severe and the first bucket with any hit wins. A note that matches no
    bucket is treated as uncertain and routed to level 1 ("concerned").
    """
    normalized = _normalize(note)
    ladder = (
        (CRISIS_KW, 3, "crisis"),
        (DISTRESSED_KW, 2, "distressed"),
        (CONCERNED_KW, 1, "concerned"),
        (CALM_KW, 0, "calm"),
    )
    for keywords, level, label in ladder:
        if _contains_any(normalized, keywords):
            return level, label
    # Uncertain input defaults to "concerned" so it still gets human follow-up.
    return 1, "concerned"
# "help", "helped", "helping", "helpful", "self-help" — but not "helpless"/"helplessness",
# which describe hopelessness rather than help-seeking. The leading \b also skips "unhelpful".
_HELP_SEEKING_RE = re.compile(r"\bhelp(?!less)")


def _extract_risk_protective(note: str, level: int, label: str) -> tuple[list[str], list[str]]:
    """Extract coarse risk and protective signal phrases for review summaries — no clinical claims.

    Args:
        note: Raw note text; matched after _normalize().
        level: Distress level from classify_distress(); level >= 3 adds a keyword-route risk entry.
        label: Distress label; currently unused, kept for interface compatibility.

    Returns:
        (risks, protective), each de-duplicated in first-seen order. ``protective`` is never
        empty: when no protective cue matches, a generic review-context entry is returned.
    """
    h = _normalize(note)

    def mentions(*cues: str) -> bool:
        # Plain substring test on the normalized note, like the module's keyword buckets.
        return any(cue in h for cue in cues)

    risks: list[str] = []
    if mentions("sleep", "slept", "insomnia", "not sleeping"):
        risks.append("sleep disruption")
    if mentions("withdraw", "isolat", "disappear"):
        risks.append("social withdrawal")
    if mentions("hopeless", "pointless", "numb"):
        risks.append("low mood / hopelessness narrative")
    if mentions("work", "missed work", "can't function", "cant function"):
        risks.append("functional impact mentioned")
    if mentions("medication", "meds", "stopped taking", "skipping"):
        risks.append("medication or adherence mention")
    if level >= 3:
        risks.append("potential urgent safety concern (keyword route)")

    protective: list[str] = []
    if mentions("not planning to hurt", "don't want to die", "dont want to die"):
        protective.append("explicit absence of immediate self-harm plan (self-report)")
    if mentions("called", "talked to", "reach out", "reached out", "reaching out") or _HELP_SEEKING_RE.search(h):
        protective.append("help-seeking or connection behavior mentioned")
    if mentions("walk", "grateful", "better"):
        protective.append("positive routine or stabilizing activity mentioned")
    if not protective:
        # Fallback keeps the list non-empty with a generic review-context entry.
        protective.append("help-seeking communication (review context)")

    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(risks)), list(dict.fromkeys(protective))
# Whole-word "pill"/"pills" only: a bare substring test also fired on "pillow" and "spilled".
_PILL_RE = re.compile(r"\bpills?\b")


def _med_sleep_flags(note: str) -> tuple[list[str], list[str]]:
    """Return (medication/adherence flags, sleep/mood flags) for the note.

    Matching runs on the _normalize()d note. All cues are substring tests except "pill(s)",
    which must be a whole word. Each returned list holds at most the fixed phrases below.
    """
    h = _normalize(note)
    med: list[str] = []
    sleep_mood: list[str] = []
    if _PILL_RE.search(h) or any(k in h for k in ["medication", "meds", "stopped taking", "skipping"]):
        med.append("medication adherence mentioned — flag for human review")
    if any(k in h for k in ["sleep", "slept", "insomnia", "not sleeping", "haven't slept"]):
        sleep_mood.append("sleep disruption")
    if any(k in h for k in ["overwhelm", "anxious", "panic", "sad", "mood"]):
        sleep_mood.append("mood distress indicators")
    return med, sleep_mood
# Levels 0–2 each have their own route; every other level falls through to the crisis route.
_ROUTE_BY_LEVEL: dict[int, tuple[str, str]] = {
    0: ("supportive_dialogue_01", "supportive_dialogue"),
    1: ("monitor_followup_01", "monitor_and_follow_up"),
    2: ("caregiver_support_01", "escalate_caregiver"),
}
_CRISIS_ROUTE: tuple[str, str] = ("crisis_protocol_01", "escalate_crisis_protocol")


def route_template(level: int) -> tuple[str, str]:
    """Return (safe_response_template_id, recommended_action) for a distress level.

    Any level other than 0, 1 or 2 (including 3) maps to the crisis protocol.
    """
    return _ROUTE_BY_LEVEL.get(level, _CRISIS_ROUTE)
def analyze_note(note: str) -> MindForgeAnalysis:
    """Run the deterministic rules pipeline on *note* and return the structured analysis.

    The optional model is never consulted: level, action, template, signals, summaries
    and confidence all come from the keyword rules and fixed strings in this module.
    """
    level, label_s = classify_distress(note)
    tpl_id, action_str = route_template(level)
    risks, protective = _extract_risk_protective(note, level, label_s)
    med_flags, sm_flags = _med_sleep_flags(note)

    # Care-team wording gets more urgent as the level rises; crisis has its own sentence.
    if level >= 3:
        care_summary = (
            "The note triggers crisis-route safeguards. Urgent human review and crisis protocols should be considered."
        )
    elif level >= 2:
        care_summary = "The note suggests elevated distress with sleep or mood-related concerns. Human review is recommended."
    elif level == 1:
        care_summary = "The note suggests mild to moderate stress patterns; scheduled human follow-up is reasonable."
    else:
        care_summary = "The note appears stable on quick scan; continue routine monitoring."

    # Patient-facing wording; levels 2 and 3 share the same sentence.
    if level >= 2:
        patient_expl = "This check-in shows signs of distress. A human support person should review it soon."
    elif level == 1:
        patient_expl = "This check-in suggests elevated stress. A human support person may want to review."
    else:
        patient_expl = "This check-in looks stable at a glance; still keep humans in the loop for decisions."

    return MindForgeAnalysis(
        distress_level=level,  # type: ignore[arg-type]
        distress_label=label_s,  # type: ignore[arg-type]
        recommended_action=action_str,  # type: ignore[arg-type]
        risk_signals=risks,
        protective_signals=protective,
        medication_or_adherence_flags=med_flags,
        sleep_mood_flags=sm_flags,
        requires_human_review=level >= 1,
        escalation_flag=level >= 2,
        safe_response_template_id=tpl_id,
        safe_response=TEMPLATES[tpl_id],
        care_team_summary=care_summary,
        patient_safe_explanation=patient_expl,
        confidence=_stable_confidence(note, level),
        model_boundary="MindForge AI does not diagnose, treat, or replace licensed care.",
    )
def mock_base_supportive_paragraph(note: str) -> str:
    """
    Return a fixed stand-in for *base model* output: generic supportive prose, no routing JSON.

    *note* is accepted so this can replace optional_base_generate() in build_comparison(),
    but it is ignored; every call returns the same paragraph.
    """
    sentences = (
        "Thank you for sharing what’s on your mind.",
        "It takes courage to write this down.",
        "I’m here to reflect what you said in general terms: many people go through waves of stress, "
        "and it can help to talk with someone you trust when things feel heavy.",
        "This assistant cannot diagnose or treat conditions; please rely on qualified humans for "
        "clinical decisions and urgent safety concerns.",
    )
    return " ".join(sentences)
def build_comparison(note: str) -> tuple[str, str]:
    """Return (base_text, fine_tuned_json) for the side-by-side comparison tab.

    base_text is the optional live model's reply when it loaded and produced text,
    otherwise the fixed mock paragraph. fine_tuned_json is the rules-engine analysis
    serialized with two-space indentation, standing in for LoRA JSON output.
    """
    _try_load_models()
    generated = optional_base_generate(note)
    base_txt = generated if generated else mock_base_supportive_paragraph(note)
    analysis = analyze_note(note)
    return base_txt, json.dumps(analysis.model_dump(), indent=2)
# Evaluation-tab table, one tuple per metric row. The UI labels these figures as
# placeholders to be replaced with measured held-out results.
_EVAL_COLUMNS = ["Metric", "Base Qwen", "Fine-Tuned Qwen", "Why it matters"]
_EVAL_ROWS = [
    ("Valid JSON", "64%", "96%", "App reliability"),
    ("Distress Accuracy", "55%", "82%", "Better classification"),
    ("Action F1", "0.58", "0.86", "Safer routing"),
    ("Crisis Recall", "70%", "95%", "Safety-critical"),
    ("Unsafe Overreach Rate", "16%", "3%", "Avoids risky advice"),
    ("Protective Signal Extraction", "52%", "84%", "More nuanced review"),
]
EVAL_DF = pd.DataFrame(_EVAL_ROWS, columns=_EVAL_COLUMNS)
# Markdown table shown on the "Methodology" tab under "Distress taxonomy (exact table)".
# The level -> recommended-action pairs match route_template(); keep the text verbatim.
TAXONOMY_MD = """
| Level | Label | Typical cues | Recommended action |
|---:|---|---|---|
| 0 | Calm | stable check-in, no urgent concern | supportive_dialogue |
| 1 | Concerned | stress, anxiety, sadness, mild impairment | monitor_and_follow_up |
| 2 | Distressed | sleep disruption, withdrawal, hopelessness, medication concern, functional impact | escalate_caregiver |
| 3 | Crisis | immediate danger, explicit self-harm intent, overdose risk, abuse, or urgent safety concern | escalate_crisis_protocol |
"""
# Methodology summary rendered on the "Methodology" tab.
# The blank line before the closing paragraph matters: without it Markdown treats the
# paragraph as a lazy continuation of the last bullet and renders it inside that item.
# Routing is always rules-based; optional weights only feed the comparison tab's base panel.
METHOD_MD = """
### MindForge AI — methodology snapshot
- **Base model:** Qwen Instruct
- **Fine-tuning method:** LoRA SFT
- **Dataset:** synthetic safety-labeled mental-health check-ins
- **Compute:** AMD Developer Cloud, ROCm, MI300X
- **Serving:** Hugging Face Space
- **Safety:** rule-based crisis override + pre-vetted response templates

This demo always uses the rules engine for deterministic escalation routing. Optional model weights (`USE_MODEL=true`) only feed the base-model comparison panel and never change routing.
"""
# Stylesheet injected as an inline <style> tag at the top of the Blocks layout.
# Class usage in this file: mf-banner (disclaimer banner), mf-crisis (crisis banner built
# in run_analyze), mf-card / mf-kv (analysis card and its label/value grid), mf-foot (footer).
CUSTOM_CSS = """
.mf-wrap { font-family: system-ui, sans-serif; }
.mf-banner {
background: linear-gradient(120deg, #1e3a5f 0%, #3d2b5c 100%);
color: #f8fafc;
padding: 14px 18px;
border-radius: 12px;
margin-bottom: 14px;
line-height: 1.5;
}
.mf-crisis {
background: #7f1d1d;
color: #fef2f2;
padding: 12px 14px;
border-radius: 10px;
margin-bottom: 12px;
font-weight: 600;
}
.mf-card {
border: 1px solid #e5e7eb;
border-radius: 12px;
padding: 12px 14px;
background: #fafafa;
}
.mf-card pre {
white-space: pre-wrap;
word-break: break-word;
font-size: 12px;
}
.mf-kv { display: grid; grid-template-columns: 220px 1fr; gap: 8px 12px; align-items: start; }
.mf-kv b { color: #374151; }
footer.mf-foot { font-size: 12px; color: #6b7280; margin-top: 10px; }
"""
def model_status_line() -> str:
    """Return a one-line Markdown hint about optional HF weights vs deterministic routing.

    Reads the globals set by _try_load_models(); it never triggers a load itself.
    """
    if _model_bundle:
        return (
            "Optional base weights loaded for the comparison tab; **routing and JSON always use the rules engine** "
            "(crisis override + templates)."
        )
    if _model_load_error is None:
        # No load attempt has finished yet. Without this branch the failure message
        # below would render as "Optional model load failed (None)".
        return (
            "Optional model weights have not been loaded yet; **routing uses the rules engine** "
            "(crisis override + templates)."
        )
    if _model_load_error == "skipped":
        return (
            "Running **deterministic rules classifier** (`USE_MODEL=false` or `BASE_MODEL_ID` unset). "
            "No Hub secrets required for this mode."
        )
    return (
        f"Optional model load failed ({_model_load_error}); **routing uses rules only**. "
        "Check `BASE_MODEL_ID` / GPU memory / token."
    )
def run_analyze(note: str) -> tuple[str, str, str]:
    """Handle the "Analyze note" click.

    Returns:
        (html_panel, raw_json, status_markdown). For empty or whitespace-only input the
        panel holds a placeholder, raw_json is "" and the status is "Waiting for input.".
    """
    # Check the input first so an empty click does not trigger the optional model load.
    if not note or not note.strip():
        # The panel is a gr.HTML component, which does not render Markdown `_..._` emphasis.
        return ("<em>(Paste a note to analyze.)</em>", "", "Waiting for input.")
    # analyze_note() never uses the model; loading here only keeps model_status_line() current.
    _try_load_models()
    a = analyze_note(note)
    crisis = f"<div class='mf-crisis'>{CRISIS_BANNER}</div>" if a.distress_level >= 3 else ""
    # (label, value) rows for the two-column .mf-kv grid; empty lists render as an em dash.
    rows = (
        ("Distress level", a.distress_level),
        ("Distress label", a.distress_label),
        ("Recommended action", a.recommended_action),
        ("Risk signals", ", ".join(a.risk_signals) or "—"),
        ("Protective signals", ", ".join(a.protective_signals) or "—"),
        ("Medication / adherence flags", ", ".join(a.medication_or_adherence_flags) or "—"),
        ("Sleep / mood flags", ", ".join(a.sleep_mood_flags) or "—"),
        ("Requires human review", a.requires_human_review),
        ("Escalation flag", a.escalation_flag),
        ("Confidence", a.confidence),
        ("Safe response template", a.safe_response_template_id),
        ("Safe response", a.safe_response),
        ("Care team summary", a.care_team_summary),
        ("Patient-safe explanation", a.patient_safe_explanation),
        ("Safety disclaimer", a.model_boundary),
    )
    cells = "".join(f"<b>{label}</b><span>{value}</span>" for label, value in rows)
    kv = f"<div class='mf-card'><div class='mf-kv'>{cells}</div></div>"
    raw_json = json.dumps(a.model_dump(), indent=2)
    return crisis + kv, raw_json, model_status_line()
# Sample notes for gr.Examples on the Analyze and Compare tabs. Each inner list holds the
# value for the single note Textbox input.
EXAMPLES = [
    [example_note]
    for example_note in (
        "I had a better day today. I went for a walk, called my sister, and finally finished a small task I had been avoiding.",
        "I feel anxious and overwhelmed. Work has been piling up and I keep worrying that I am falling behind.",
        "I have not slept much this week. I feel overwhelmed and keep missing work. I am not planning to hurt myself, but I feel like I am disappearing.",
        "I stopped taking my medication because it makes me feel strange. I have not slept in three days and I feel like everything is pointless.",
    )
]
with gr.Blocks(title="MindForge AI: Distress Detection & Conversational Escalation Engine") as demo:
    # Inline stylesheet keeps styling compatible across Gradio 4–6 (Blocks `css=` moved in Gradio 6).
    gr.HTML(f"<style>{CUSTOM_CSS.strip()}</style>")
    gr.Markdown(
        "# MindForge AI: Distress Detection & Conversational Escalation Engine\n"
        "**Team MindForge AI · hackathon demo** — structured distress review, escalation routing, and human-in-the-loop summaries."
    )
    gr.HTML(f"<div class='mf-banner'>{EXACT_DISCLAIMER}</div>")
    with gr.Tabs():
        with gr.Tab("Analyze Check-In"):
            note_in = gr.Textbox(
                label="Check-in, journal entry, caregiver note, or support note",
                lines=8,
                placeholder="Paste text here. This demo classifies and routes — it does not provide therapy.",
            )
            analyze_btn = gr.Button("Analyze note", variant="primary")
            status = gr.Markdown()
            out_panel = gr.HTML()
            out_json = gr.Code(label="Raw JSON (fine-tuned-style structured output)", language="json")
            # run_analyze returns (html_panel, raw_json, status) in this output order.
            analyze_btn.click(fn=run_analyze, inputs=[note_in], outputs=[out_panel, out_json, status])
            gr.Examples(EXAMPLES, inputs=[note_in], label="Example inputs")
        with gr.Tab("Base vs Fine-Tuned"):
            gr.Markdown(
                "**Base model panel:** generic supportive prose (optional live Qwen if `USE_MODEL=true`). "
                "**Fine-tuned panel:** structured JSON + routing from the MindForge rules engine "
                "(stand-in for LoRA JSON outputs)."
            )
            cmp_in = gr.Textbox(label="Note to compare", lines=6)
            cmp_btn = gr.Button("Run comparison", variant="primary")
            with gr.Row():
                base_out = gr.Textbox(label="Base model output (generic supportive prose)", lines=14)
                ft_out = gr.Code(label="Fine-tuned model output (structured JSON + action routing)", language="json")
            # build_comparison already returns (base_text, json_text) in output order,
            # so it is wired directly instead of through a pass-through wrapper.
            cmp_btn.click(fn=build_comparison, inputs=[cmp_in], outputs=[base_out, ft_out])
            gr.Examples(EXAMPLES, inputs=[cmp_in], label="Try an example")
        with gr.Tab("Evaluation"):
            gr.Markdown(
                "**Hackathon evaluation framework.** Replace with measured values after running held-out evals."
            )
            gr.Dataframe(EVAL_DF, label="Metrics (placeholder)", interactive=False)
        with gr.Tab("Methodology"):
            gr.Markdown(METHOD_MD)
            gr.Markdown("### Distress taxonomy (exact table)")
            gr.Markdown(TAXONOMY_MD)
    # Markdown is not parsed inside an HTML block such as <footer>, so inline code uses
    # <code> tags; backticks there would be displayed literally.
    gr.Markdown(
        "<footer class='mf-foot'>Optional env: <code>USE_MODEL</code>, <code>BASE_MODEL_ID</code>, "
        "<code>ADAPTER_MODEL_ID</code>, <code>HF_TOKEN</code>. "
        f"Rules engine remains authoritative for safety routing. {EXACT_DISCLAIMER}</footer>"
    )


# Running the file as a script (e.g. `python app.py`) must start the server. The guard
# avoids a second launch when another runner imports `demo` and launches it itself.
if __name__ == "__main__":
    demo.launch()