Spaces:
Sleeping
Sleeping
| """Module 4 - Thought Diary (née Cognitive Journal): CBT journal plus deterministic clinical signals. | |
| The original journal path stays intact: user entry -> crisis check -> structured | |
| LLM JSON -> distortion cards. Around it we add non-LLM clinical surfaces: | |
| PHQ-9, GAD-7, daily check-ins, and an auditable stepped-care dashboard. | |
| """ | |
| from __future__ import annotations | |
| from collections import Counter | |
| from datetime import datetime | |
| from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple | |
| import plotly.graph_objects as go | |
| import streamlit as st | |
| from pydantic import BaseModel, Field | |
| from backend.claude_client import chat_structured | |
| from backend.i18n import claude_language_name, t | |
| from backend.resources import load_screeners | |
| from backend.safeguards import check_crisis, render_crisis_banner | |
# Identifier passed as `module=` to chat_structured() when analysing an entry.
MODULE_NAME = "cognitive_journal"

# --- st.session_state keys: journal ------------------------------------------
ENTRIES_KEY = "cognitive_journal_entries"  # list of saved journal records
LAST_ANALYSIS_KEY = "cognitive_journal_last"  # most recently appended record, or None
SECTION_KEY = "cognitive_journal_section"  # section id selected in the router radio
JOURNAL_WIDGET_VERSION_KEY = "cognitive_journal_widget_version"  # counter baked into the text-area key
JOURNAL_NOTICE_KEY = "cognitive_journal_notice"  # pending {"level", "message"} notice, or None

# --- st.session_state keys: PHQ-9 --------------------------------------------
PHQ9_SCORE_KEY = "cognitive_journal_phq9_score"  # latest total score, or None
PHQ9_ANSWERS_KEY = "cognitive_journal_phq9_answers"  # latest answers (initialised to nine zeros)
PHQ9_TAKEN_AT_KEY = "cognitive_journal_phq9_taken_at"  # ISO timestamp of latest attempt, "" if none
PHQ9_HISTORY_KEY = "cognitive_journal_phq9_history"  # list of past attempts
PHQ9_ITEM9_KEY = "cognitive_journal_phq9_item9_positive"  # True when the ninth answer was > 0
PHQ9_WIDGET_VERSION_KEY = "cognitive_journal_phq9_widget_version"  # counter, starts at 0 (usage not shown in this chunk)

# --- st.session_state keys: GAD-7 --------------------------------------------
GAD7_SCORE_KEY = "cognitive_journal_gad7_score"  # latest total score, or None
GAD7_ANSWERS_KEY = "cognitive_journal_gad7_answers"  # latest answers (initialised to seven zeros)
GAD7_TAKEN_AT_KEY = "cognitive_journal_gad7_taken_at"  # ISO timestamp of latest attempt, "" if none
GAD7_HISTORY_KEY = "cognitive_journal_gad7_history"  # list of past attempts
GAD7_WIDGET_VERSION_KEY = "cognitive_journal_gad7_widget_version"  # counter, starts at 0 (usage not shown in this chunk)

# --- st.session_state keys: check-ins and housekeeping -----------------------
CHECKINS_KEY = "cognitive_journal_checkins"  # list of daily check-in mappings
LEGACY_CLEANUP_KEY = "cognitive_journal_legacy_cleanup_done"  # set True once _init_state() has run its seed cleanup
# The two screeners this module knows about; used to pick per-screener state keys.
ScreenerKind = Literal["phq9", "gad7"]

# (section id, i18n key of its label) pairs, in the order the section radio
# shows them. The section id is what gets stored under SECTION_KEY.
SECTION_CHOICES = [
    ("journal", "journal_section_journal_label"),
    ("phq9", "journal_section_phq9_label"),
    ("gad7", "journal_section_gad7_label"),
    ("checkin", "journal_section_checkin_label"),
    ("dashboard", "journal_section_dashboard_label"),
]
# Closed set of distortion categories accepted by the Distortion model.
# DISTORTION_LABELS maps each of these ids to a display label.
DistortionType = Literal[
    "catastrophizing",
    "mind_reading",
    "all_or_nothing",
    "fortune_telling",
    "personalization",
    "mental_filter",
    "emotional_reasoning",
    "should_statements",
]

# Closed set of values accepted for JournalAnalysis.overall_mood.
MoodType = Literal[
    "anxious",
    "sad",
    "frustrated",
    "hopeful",
    "neutral",
    "overwhelmed",
]
class Distortion(BaseModel):
    """One cognitive distortion attached to a journal analysis."""

    # Category id; must be one of the DistortionType literals.
    type: DistortionType
    # Text rendered as a block quote on the distortion card.
    phrase: str
    # Text rendered next to the category label on the card.
    explanation: str
    # Text rendered under the "reframe" heading on the card.
    reframe: str
    # Text rendered under the "question" heading on the card.
    evidence_question: str
class JournalAnalysis(BaseModel):
    """Schema handed to chat_structured() for the analysis of one journal entry."""

    # Overall mood; must be one of the MoodType literals.
    overall_mood: MoodType
    # Distortions found in the entry; defaults to an empty list.
    distortions: List[Distortion] = Field(default_factory=list)
    # Summary text shown in the info box above the distortion cards.
    summary: str
    # When True, the UI shows the "needs professional support" warning.
    needs_professional_signal: bool = False
    # Timestamp and raw entry text default to ""; _analysis_to_journal_record()
    # overwrites both when the analysis is converted into a stored record.
    ts: str = ""
    entry_text: str = ""
# Display label for each distortion type id. Lookups fall back to the raw id
# when a type is missing from this table.
DISTORTION_LABELS = {
    "catastrophizing": "Catastrophizing",
    "mind_reading": "Mind-reading",
    "all_or_nothing": "All-or-nothing thinking",
    "fortune_telling": "Fortune-telling",
    "personalization": "Personalization",
    "mental_filter": "Mental filter",
    "emotional_reasoning": "Emotional reasoning",
    "should_statements": "Should-statements",
}
# Distortion phrases used to recognise legacy seed journal entries: an entry
# whose distortion phrases all appear in this set is dropped by
# _drop_legacy_seed_entries().
_LEGACY_SEED_PHRASES = {
    "my whole placement season is ruined",
    "everyone in my batch thinks I'm useless",
    "I should have studied harder in first year",
    "my parents are disappointed because of me",
    "I'm going to freeze up in the viva and fail",
    "if I don't get an A I've wasted the whole semester",
    "I only remember the parts I got wrong in the last test",
}

# Values of the legacy seed check-ins. _drop_legacy_seed_checkins() removes
# them only when the stored list starts with exactly these values, in order.
_LEGACY_SEED_CHECKINS = [
    {"mood": 5, "sleep_h": 6.5, "stress": 6},
    {"mood": 6, "sleep_h": 7.0, "stress": 5},
]
def _is_legacy_seed_entry(entry: Any) -> bool:
    """Tell whether *entry* consists solely of legacy seed distortions.

    The entry (dict-style or attribute-style) counts as a seed entry when it
    has at least one non-empty distortion phrase and every such phrase is a
    member of ``_LEGACY_SEED_PHRASES``.
    """

    def _read(source: Any, field: str, default: Any) -> Any:
        # Uniform field access for mappings and plain objects.
        if isinstance(source, Mapping):
            return source.get(field, default)
        return getattr(source, field, default)

    raw_phrases = (
        _read(item, "phrase", "")
        for item in (_read(entry, "distortions", []) or [])
    )
    seen = {str(raw) for raw in raw_phrases if raw}
    if not seen:
        return False
    return seen <= _LEGACY_SEED_PHRASES
def _drop_legacy_seed_entries(entries: Sequence[Any]) -> List[Any]:
    """Return a new list holding every entry that is not a legacy seed entry."""
    kept: List[Any] = []
    for candidate in entries:
        if _is_legacy_seed_entry(candidate):
            continue
        kept.append(candidate)
    return kept
| def _same_checkin_values(item: Mapping[str, Any], expected: Mapping[str, Any]) -> bool: | |
| try: | |
| return ( | |
| int(item.get("mood")) == int(expected["mood"]) | |
| and float(item.get("sleep_h")) == float(expected["sleep_h"]) | |
| and int(item.get("stress")) == int(expected["stress"]) | |
| ) | |
| except (TypeError, ValueError): | |
| return False | |
def _drop_legacy_seed_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
    """Strip the legacy seed check-ins from the front of *checkins*.

    The leading items are removed only when the list is at least as long as
    ``_LEGACY_SEED_CHECKINS`` and its first items match the seed values one
    for one. Otherwise an unmodified copy of the list is returned.
    """
    materialized = list(checkins)
    seed_count = len(_LEGACY_SEED_CHECKINS)
    head = materialized[:seed_count]
    starts_with_seed = len(head) == seed_count and all(
        _same_checkin_values(actual, seed)
        for actual, seed in zip(head, _LEGACY_SEED_CHECKINS)
    )
    return materialized[seed_count:] if starts_with_seed else materialized
def _init_state() -> None:
    """Make sure every session-state key this module uses exists.

    Missing keys receive their default value. The journal-entry and check-in
    lists additionally get legacy seed data removed, but only while
    ``LEGACY_CLEANUP_KEY`` is not yet set; the flag is set at the end so the
    cleanup happens once per session.
    """
    state = st.session_state
    cleanup_pending = not bool(state.get(LEGACY_CLEANUP_KEY))

    def _seed_or_scrub(key: str, scrub: Any) -> None:
        # Create the list if absent; otherwise scrub it during the one-off cleanup.
        if key not in state:
            state[key] = []
        elif cleanup_pending:
            state[key] = scrub(list(state.get(key, []) or []))

    _seed_or_scrub(ENTRIES_KEY, _drop_legacy_seed_entries)

    # Factories (not values) so that list defaults are fresh objects.
    fresh_defaults = (
        (LAST_ANALYSIS_KEY, lambda: None),
        (SECTION_KEY, lambda: "journal"),
        (JOURNAL_WIDGET_VERSION_KEY, lambda: 0),
        (JOURNAL_NOTICE_KEY, lambda: None),
        (PHQ9_SCORE_KEY, lambda: None),
        (PHQ9_ANSWERS_KEY, lambda: [0] * 9),
        (PHQ9_TAKEN_AT_KEY, lambda: ""),
        (PHQ9_HISTORY_KEY, lambda: []),
        (PHQ9_ITEM9_KEY, lambda: False),
        (PHQ9_WIDGET_VERSION_KEY, lambda: 0),
        (GAD7_SCORE_KEY, lambda: None),
        (GAD7_ANSWERS_KEY, lambda: [0] * 7),
        (GAD7_TAKEN_AT_KEY, lambda: ""),
        (GAD7_HISTORY_KEY, lambda: []),
        (GAD7_WIDGET_VERSION_KEY, lambda: 0),
    )
    for key, make_default in fresh_defaults:
        if key not in state:
            state[key] = make_default()

    _seed_or_scrub(CHECKINS_KEY, _drop_legacy_seed_checkins)
    state[LEGACY_CLEANUP_KEY] = True
def _render_section_router(lang: str) -> str:
    """Draw the horizontal section selector and return the chosen section id."""
    label_keys = dict(SECTION_CHOICES)
    section_ids = list(label_keys)

    def _label(section_id: str) -> str:
        return t(label_keys[section_id], lang)

    # Repair an invalid stored value before the widget reads it, so we never
    # have to pass both `index=` and an existing `key=` (Streamlit forbids
    # that combination).
    if st.session_state.get(SECTION_KEY, "journal") not in section_ids:
        st.session_state[SECTION_KEY] = "journal"

    selected = st.radio(
        "journal-section",
        options=section_ids,
        format_func=_label,
        horizontal=True,
        label_visibility="collapsed",
        key=SECTION_KEY,
    )
    return str(selected)
def _render_journal_explainer(lang: str) -> None:
    """Show the collapsed "about the journal" expander."""
    heading = t("journal_about_heading", lang)
    panel = st.expander(heading, expanded=False)
    with panel:
        body = t("journal_about_body", lang)
        st.markdown(body)
def _render_analysis(analysis: Any, lang: str) -> None:
    """Render one analysis: summary, optional support warning, distortion cards.

    When the analysis has no distortions a success message is shown instead
    of the cards.
    """
    summary_heading = t("journal_summary_heading", lang)
    st.markdown(f"##### {summary_heading}")
    st.info(_entry_summary(analysis))

    if _entry_needs_pro(analysis):
        st.warning(t("journal_needs_pro", lang))

    found = _entry_distortions(analysis)
    if found:
        cards_heading = t("journal_distortions_heading", lang)
        st.markdown(f"##### {cards_heading}")
        for item in found:
            kind = _distortion_type(item)
            with st.container(border=True):
                quoted = _distortion_phrase(item)
                st.markdown(f"> *{quoted}*")

                label = DISTORTION_LABELS.get(kind, kind)
                explanation = _distortion_explanation(item)
                st.markdown(f"**{label}** - {explanation}")

                reframe_heading = t("journal_reframe_heading", lang)
                reframe = _distortion_reframe(item)
                st.markdown(f"**{reframe_heading}:** {reframe}")

                question_heading = t("journal_question_heading", lang)
                question = _distortion_evidence_question(item)
                st.caption(f"**{question_heading}:** {question}")
    else:
        st.success(t("journal_no_distortions", lang))
def _entry_distortions(entry: Any) -> Sequence[Any]:
    """Return the distortions of *entry*, whatever shape the entry has.

    A ``JournalAnalysis`` yields its ``distortions`` list as is. Mappings and
    other objects yield their ``distortions`` value, with a missing or falsy
    value replaced by an empty list.
    """
    if isinstance(entry, JournalAnalysis):
        return entry.distortions
    if isinstance(entry, Mapping):
        found = entry.get("distortions", [])
    else:
        found = getattr(entry, "distortions", [])
    return found or []
| def _entry_needs_pro(entry: Any) -> bool: | |
| if isinstance(entry, JournalAnalysis): | |
| return bool(entry.needs_professional_signal) | |
| if isinstance(entry, Mapping): | |
| return bool(entry.get("needs_professional_signal")) | |
| return bool(getattr(entry, "needs_professional_signal", False)) | |
| def _entry_timestamp(entry: Any) -> str: | |
| if isinstance(entry, JournalAnalysis): | |
| return str(entry.ts or "") | |
| if isinstance(entry, Mapping): | |
| return str(entry.get("ts", "")) | |
| return str(getattr(entry, "ts", "")) | |
| def _entry_summary(entry: Any) -> str: | |
| if isinstance(entry, JournalAnalysis): | |
| return str(entry.summary or "") | |
| if isinstance(entry, Mapping): | |
| return str(entry.get("summary", "")) | |
| return str(getattr(entry, "summary", "")) | |
| def _entry_mood(entry: Any) -> str: | |
| if isinstance(entry, JournalAnalysis): | |
| return str(entry.overall_mood or "") | |
| if isinstance(entry, Mapping): | |
| return str(entry.get("overall_mood", "")) | |
| return str(getattr(entry, "overall_mood", "")) | |
| def _entry_text(entry: Any) -> str: | |
| if isinstance(entry, JournalAnalysis): | |
| return str(entry.entry_text or "") | |
| if isinstance(entry, Mapping): | |
| return str(entry.get("entry_text", "")) | |
| return str(getattr(entry, "entry_text", "")) | |
| def _distortion_type(distortion: Any) -> str: | |
| if isinstance(distortion, Distortion): | |
| return distortion.type | |
| if isinstance(distortion, Mapping): | |
| return str(distortion.get("type", "")) | |
| return str(getattr(distortion, "type", "")) | |
| def _distortion_phrase(distortion: Any) -> str: | |
| if isinstance(distortion, Distortion): | |
| return str(distortion.phrase or "") | |
| if isinstance(distortion, Mapping): | |
| return str(distortion.get("phrase", "")) | |
| return str(getattr(distortion, "phrase", "")) | |
| def _distortion_explanation(distortion: Any) -> str: | |
| if isinstance(distortion, Distortion): | |
| return str(distortion.explanation or "") | |
| if isinstance(distortion, Mapping): | |
| return str(distortion.get("explanation", "")) | |
| return str(getattr(distortion, "explanation", "")) | |
| def _distortion_reframe(distortion: Any) -> str: | |
| if isinstance(distortion, Distortion): | |
| return str(distortion.reframe or "") | |
| if isinstance(distortion, Mapping): | |
| return str(distortion.get("reframe", "")) | |
| return str(getattr(distortion, "reframe", "")) | |
| def _distortion_evidence_question(distortion: Any) -> str: | |
| if isinstance(distortion, Distortion): | |
| return str(distortion.evidence_question or "") | |
| if isinstance(distortion, Mapping): | |
| return str(distortion.get("evidence_question", "")) | |
| return str(getattr(distortion, "evidence_question", "")) | |
| def _analysis_to_journal_record(analysis: Any, entry_text: str) -> Dict[str, Any]: | |
| if isinstance(analysis, Mapping): | |
| record = dict(analysis) | |
| elif hasattr(analysis, "model_dump"): | |
| record = dict(analysis.model_dump(mode="python")) | |
| elif hasattr(analysis, "dict"): | |
| record = dict(analysis.dict()) | |
| else: | |
| record = { | |
| "overall_mood": getattr(analysis, "overall_mood", "neutral"), | |
| "distortions": getattr(analysis, "distortions", []) or [], | |
| "summary": getattr(analysis, "summary", ""), | |
| "needs_professional_signal": bool( | |
| getattr(analysis, "needs_professional_signal", False) | |
| ), | |
| } | |
| record["ts"] = datetime.now().isoformat(timespec="seconds") | |
| record["entry_text"] = entry_text.strip() | |
| record.setdefault("distortions", []) | |
| record.setdefault("summary", "") | |
| record.setdefault("overall_mood", "neutral") | |
| record.setdefault("needs_professional_signal", False) | |
| return record | |
def _fallback_journal_record(entry_text: str, lang: str, error: Exception) -> Dict[str, Any]:
    """Build the record saved when the analysis call fails.

    The record has a neutral mood, no distortions and a localized "analysis
    failed" summary. The first 240 characters of the error text are kept
    under ``analysis_error``.
    """
    placeholder = {
        "overall_mood": "neutral",
        "distortions": [],
        "summary": t("journal_analysis_failed_summary", lang),
        "needs_professional_signal": False,
    }
    fallback = _analysis_to_journal_record(placeholder, entry_text)
    fallback["analysis_error"] = str(error)[:240]
    return fallback
def _crisis_journal_record(entry_text: str, lang: str) -> Dict[str, Any]:
    """Build the record stored when crisis language is detected.

    Crisis text is deliberately NOT sent to the LLM; the banner takes
    priority. An entry is still recorded so the user's history is preserved
    and the dashboard and chart keep moving forward.

    A default "catastrophizing" distortion is attached, quoting at most 120
    characters of the entry, so that crisis entries do not show up as "no
    distortions" in the chart. That would be misleading: the text almost
    certainly contains distortions, it just cannot safely be analysed.
    """
    trimmed = entry_text.strip()
    # Quote the entry itself, shortened with an ellipsis once it exceeds 120 chars.
    quoted = trimmed if len(trimmed) <= 120 else trimmed[:117].rstrip() + "..."

    tagged_distortion = {
        "type": "catastrophizing",
        "phrase": quoted,
        "explanation": t("journal_crisis_distortion_explanation", lang),
        "reframe": t("journal_crisis_distortion_reframe", lang),
        "evidence_question": t("journal_crisis_distortion_question", lang),
    }
    payload = {
        "overall_mood": "overwhelmed",
        "distortions": [tagged_distortion],
        "summary": t("journal_crisis_summary", lang),
        "needs_professional_signal": True,
    }
    crisis_record = _analysis_to_journal_record(payload, entry_text)
    crisis_record["crisis_flagged"] = True
    return crisis_record
def _append_journal_record(record: Mapping[str, Any]) -> List[Any]:
    """Store *record* as the newest journal entry and as the last analysis.

    Separate dict copies of the record are written to ``ENTRIES_KEY`` (inside
    a new list) and to ``LAST_ANALYSIS_KEY``. Returns the new entries list.
    """
    existing = st.session_state.get(ENTRIES_KEY, []) or []
    log = [*existing, dict(record)]
    st.session_state[ENTRIES_KEY] = log
    st.session_state[LAST_ANALYSIS_KEY] = dict(record)
    return log
| def _short_text(text: str, limit: int = 96) -> str: | |
| cleaned = " ".join(str(text or "").split()) | |
| if not cleaned: | |
| return "-" | |
| if len(cleaned) <= limit: | |
| return cleaned | |
| return f"{cleaned[: limit - 3].rstrip()}..." | |
def _journal_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Build the table rows for the journal history.

    One row per stored entry, ordered by parsed timestamp; entries whose
    timestamp does not parse sort last, and storage order breaks ties. The
    ``entry`` column is renumbered from 1 after sorting.
    """
    staged = []
    stored = list(st.session_state.get(ENTRIES_KEY, []) or [])
    for position, item in enumerate(stored, start=1):
        stamp = _entry_timestamp(item)
        found = _entry_distortions(item)
        visible: Dict[str, Any] = {
            "entry": position,
            "logged_at": _timestamp_label(stamp),
            "mood": _entry_mood(item) or "-",
            "distortions": len(found),
            "support": t("journal_support_yes", lang)
            if _entry_needs_pro(item)
            else t("journal_support_no", lang),
            "note": _short_text(_entry_text(item)),
            "observation": _short_text(_entry_summary(item), limit=140),
        }
        sort_key = (_parse_timestamp(stamp) or datetime.max, position)
        staged.append((sort_key, visible))

    staged.sort(key=lambda pair: pair[0])

    ordered: List[Dict[str, Any]] = []
    for number, (_, visible) in enumerate(staged, start=1):
        visible["entry"] = number
        ordered.append(visible)
    return ordered
def _render_journal_history(lang: str) -> None:
    """Render the journal history table, or an empty-state caption."""
    heading = t("dashboard_journal_history_heading", lang)
    st.markdown(f"##### {heading}")

    table = _journal_history_rows(lang)
    if not table:
        st.caption(t("dashboard_journal_history_empty", lang))
        return

    number = st.column_config.NumberColumn
    text = st.column_config.TextColumn
    # (row field, column type, i18n key of the header, width)
    layout = (
        ("entry", number, "journal_col_entry", "small"),
        ("logged_at", text, "journal_col_logged", "small"),
        ("mood", text, "journal_col_mood", "small"),
        ("distortions", number, "journal_col_distortions", "small"),
        ("support", text, "journal_col_support", "small"),
        ("note", text, "journal_col_note", "medium"),
        ("observation", text, "journal_col_observation", "large"),
    )
    columns = {
        field: make(t(header_key, lang), width=width)
        for field, make, header_key, width in layout
    }
    st.dataframe(
        table,
        use_container_width=True,
        hide_index=True,
        column_config=columns,
    )
def _distortion_chart_rows(entries: Sequence[Any]) -> List[Dict[str, Any]]:
    """Count distortions by type across *entries*, most frequent first.

    Each row is ``{"distortion": <display label>, "count": <int>}``.
    Distortions with an empty type are ignored.
    """
    kinds = (
        _distortion_type(item)
        for entry in entries
        for item in _entry_distortions(entry)
    )
    tally = Counter(kind for kind in kinds if kind)

    rows: List[Dict[str, Any]] = []
    for kind, hits in tally.most_common():
        rows.append({"distortion": DISTORTION_LABELS.get(kind, kind), "count": hits})
    return rows
def _render_chart(lang: str) -> None:
    """Render the distortion-frequency bar chart with its summary line.

    With no entries at all a short hint is shown. With entries but no tagged
    distortions an explanatory note plus example prompts are shown instead
    of the chart.
    """
    saved = list(st.session_state.get(ENTRIES_KEY, []) or [])
    tally = _distortion_chart_rows(saved)

    if not tally and not saved:
        st.caption(t("dashboard_distortion_empty_hint", lang))
        return

    if not tally:
        # Entries exist, but Saathi didn't tag any CBT distortions in them.
        # This is common when the text describes feelings ("I feel lonely")
        # rather than distorted thoughts ("my life is ruined forever").
        note = t("dashboard_distortion_no_distortions", lang)
        st.info(note.format(count=len(saved)))
        st.caption(t("dashboard_distortion_try_examples", lang))
        examples_heading = t("dashboard_distortion_examples_heading", lang)
        with st.expander(examples_heading, expanded=False):
            st.markdown(t("dashboard_distortion_examples_body", lang))
        return

    tallest = max(int(row["count"]) for row in tally)
    labels = [row["distortion"] for row in tally]
    heights = [int(row["count"]) for row in tally]
    captions = [str(row["count"]) for row in tally]

    bars = go.Bar(
        x=labels,
        y=heights,
        text=captions,
        textposition="outside",
        marker=dict(color="#2563EB", line=dict(color="#1D4ED8", width=1)),
        hovertemplate=f"<b>%{{x}}</b><br>{t('journal_chart_yaxis', lang)} %{{y}}<extra></extra>",
    )
    fig = go.Figure()
    fig.add_trace(bars)

    layout_options = dict(
        title=t("journal_chart_title", lang),
        xaxis_title=t("journal_chart_xaxis", lang),
        yaxis_title=t("journal_chart_yaxis", lang),
        showlegend=False,
        height=320,
        margin=dict(l=10, r=10, t=50, b=10),
        bargap=0.45,
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    fig.update_layout(**layout_options)
    fig.update_xaxes(type="category", tickangle=0)
    # Integer ticks, with one unit of headroom above the tallest bar.
    fig.update_yaxes(
        range=[0, max(2, tallest + 1)],
        tickmode="linear",
        dtick=1,
        rangemode="tozero",
    )
    fig.update_traces(cliponaxis=False)

    st.plotly_chart(fig, use_container_width=True)
    _render_distortion_summary_line(saved, lang)
def _render_screener_summary_line(kind: ScreenerKind, lang: str) -> None:
    """Show a one-line info box summarising the history of screener *kind*.

    Nothing is rendered when there is no history. A single attempt gets the
    "single attempt" wording; several attempts additionally show the signed
    change between first and latest score and a direction label.
    """
    stats = _summarize_screener_history(kind)
    if not stats:
        return

    display_name = str(_screener_config(kind).get("name", kind.upper()))

    if stats["direction"] != "single":
        change = int(stats["delta"])
        signed_change = f"{change:+d}"
        trend_text = t(f"screener_direction_{stats['direction']}", lang)
        template = t("screener_summary_line", lang)
        message = template.format(
            name=display_name,
            attempts=stats["attempts"],
            latest=stats["latest"],
            max=stats["max"],
            band=t(f"band_{stats['latest_band']}", lang),
            delta_signed=signed_change,
            direction=trend_text,
        )
    else:
        template = t("screener_summary_single_attempt", lang)
        message = template.format(
            name=display_name,
            latest=stats["latest"],
            max=stats["max"],
            band=t(f"band_{stats['latest_band']}", lang),
        )
    st.info(message)
def _render_checkin_summary_line(lang: str) -> None:
    """Show a one-line summary of the stored daily check-ins.

    Nothing is rendered without check-ins; a single check-in only gets a
    caption. Otherwise averages (one decimal, ``-`` when unavailable) and
    mood / stress trend labels are shown in an info box.
    """
    stored = list(st.session_state.get(CHECKINS_KEY, []) or [])
    stats = _summarize_checkins(stored)
    if not stats:
        return
    if stats["count"] < 2:
        st.caption(t("checkin_summary_single", lang))
        return

    def _one_decimal(number: Optional[float]) -> str:
        return "-" if number is None else f"{number:.1f}"

    template = t("checkin_summary_line", lang)
    message = template.format(
        count=stats["count"],
        mood_avg=_one_decimal(stats["mood_avg"]),
        sleep_avg=_one_decimal(stats["sleep_avg"]),
        stress_avg=_one_decimal(stats["stress_avg"]),
        mood_trend=t(f"checkin_trend_{stats['mood_trend']}", lang),
        stress_trend=t(f"checkin_trend_{stats['stress_trend']}", lang),
    )
    st.info(message)
def _render_distortion_summary_line(entries: Sequence[Any], lang: str) -> None:
    """Show the most frequent distortion(s) plus a matching tip.

    Nothing is rendered when *entries* contain no distortions. The wording
    includes a runner-up type only when one exists.
    """
    stats = _summarize_distortions(entries)
    if not stats:
        return

    def _label(kind: str) -> str:
        return DISTORTION_LABELS.get(kind, kind)

    fields = {
        "top": _label(stats["top_type"]),
        "top_count": stats["top_count"],
        "total_entries": stats["total_entries"],
    }
    template_key = "distortion_summary_single_line"
    if stats.get("second_type"):
        template_key = "distortion_summary_line"
        fields["second"] = _label(stats["second_type"])
        fields["second_count"] = stats["second_count"]

    sentence = t(template_key, lang).format(**fields)
    advice = t(stats["tip_key"], lang)
    st.info(f"{sentence}\n\n**{t('distortion_summary_tip_prefix', lang)}** {advice}")
def _render_journal(lang: str) -> None:
    """Render the Journal section: pending notice, input box, latest analysis.

    Flow when the send button is pressed with non-empty text:
    * crisis language -> show the crisis banner, log a crisis record, queue a
      warning notice and stop this run without calling the LLM;
    * otherwise -> call the structured LLM analysis, log the result (or a
      fallback record if the call raises), queue a notice and rerun.
    """
    # Show the notice queued by the previous run, then clear it so it is
    # displayed only once.
    notice = st.session_state.get(JOURNAL_NOTICE_KEY)
    if isinstance(notice, Mapping) and notice.get("message"):
        level = notice.get("level")
        if level == "warning":
            st.warning(str(notice["message"]))
        else:
            st.success(str(notice["message"]))
        st.session_state[JOURNAL_NOTICE_KEY] = None
    # The widget key embeds a version counter; bumping the counter after a
    # submission gives the text area a new key on the next run.
    input_key = f"cognitive_journal_input_v{st.session_state[JOURNAL_WIDGET_VERSION_KEY]}"
    entry = st.text_area(
        "journal_entry",
        placeholder=t("journal_input_placeholder", lang),
        label_visibility="collapsed",
        height=140,
        key=input_key,
    )
    if st.button(t("journal_send_button", lang), type="primary", key="cognitive_journal_send_button"):
        entry_text = entry.strip()
        if entry_text:
            if check_crisis(entry_text):
                # Crisis detected: show banner immediately, and STILL log the
                # entry (with the pro-signal flag set) so the user's record is
                # preserved and the dashboard keeps working.
                render_crisis_banner(lang)
                record = _crisis_journal_record(entry_text, lang)
                entries = _append_journal_record(record)
                st.session_state[JOURNAL_NOTICE_KEY] = {
                    "level": "warning",
                    "message": t("journal_crisis_logged", lang).format(count=len(entries)),
                }
                st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1
                # No st.rerun() on this path: the run ends here, after the
                # banner has been rendered.
                return
            with st.spinner("..."):
                try:
                    analysis = chat_structured(
                        module=MODULE_NAME,
                        user_text=entry_text,
                        language_name=claude_language_name(lang),
                        schema=JournalAnalysis,
                    )
                    record = _analysis_to_journal_record(analysis, entry_text)
                    entries = _append_journal_record(record)
                    st.session_state[JOURNAL_NOTICE_KEY] = {
                        "level": "success",
                        "message": t("journal_logged", lang).format(count=len(entries)),
                    }
                except Exception as e:
                    # Best effort: any failure still saves the entry, with a
                    # fallback record that carries the error text.
                    record = _fallback_journal_record(entry_text, lang, e)
                    entries = _append_journal_record(record)
                    st.session_state[JOURNAL_NOTICE_KEY] = {
                        "level": "warning",
                        "message": t("journal_saved_without_analysis", lang).format(
                            count=len(entries)
                        ),
                    }
            # Bump the widget version and rerun; the queued notice is shown
            # at the top of the next run.
            st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1
            st.rerun()
    if st.session_state[LAST_ANALYSIS_KEY] is not None:
        _render_analysis(st.session_state[LAST_ANALYSIS_KEY], lang)
    _render_recent_journal_entries(lang)
def _render_recent_journal_entries(lang: str) -> None:
    """Show the five most recent entries in a collapsed expander.

    Rendered directly under the Journal tab, and only once more than one
    entry has been stored.
    """
    stored = list(st.session_state.get(ENTRIES_KEY, []) or [])
    if len(stored) <= 1:
        return

    with st.expander(t("journal_recent_heading", lang), expanded=False):
        table = _journal_history_rows(lang)
        if not table:
            st.caption(t("journal_recent_empty", lang))
            return

        # Newest first, at most five rows.
        newest_first = table[::-1][:5]

        number = st.column_config.NumberColumn
        text = st.column_config.TextColumn
        # (row field, column type, i18n key of the header, width)
        layout = (
            ("entry", number, "journal_col_entry", "small"),
            ("logged_at", text, "journal_col_logged", "small"),
            ("mood", text, "journal_col_mood", "small"),
            ("distortions", number, "journal_col_distortions", "small"),
            ("support", text, "journal_col_support", "small"),
            ("note", text, "journal_col_note", "medium"),
            ("observation", text, "journal_col_observation", "large"),
        )
        columns = {
            field: make(t(header_key, lang), width=width)
            for field, make, header_key, width in layout
        }
        st.dataframe(
            newest_first,
            use_container_width=True,
            hide_index=True,
            column_config=columns,
        )
| def _score_screener(answers: List[int]) -> int: | |
| return sum(int(a) for a in answers) | |
| def _band_for(score: int, bands: List[dict]) -> str: | |
| for band in bands: | |
| if score <= int(band["max"]): | |
| return str(band["key"]) | |
| return str(bands[-1]["key"]) if bands else "none" | |
| # --- Deterministic summary helpers (no Streamlit, unit-testable) ----------- | |
| def _direction_from_delta(delta: int) -> str: | |
| """Lower score = improving for both PHQ-9 and GAD-7.""" | |
| if delta <= -2: | |
| return "improving" | |
| if delta >= 2: | |
| return "worsening" | |
| return "stable" | |
def _summarize_screener_history(kind: ScreenerKind) -> Optional[Dict[str, Any]]:
    """Summarise the stored attempt history of screener *kind*.

    Returns ``None`` when there is no history or no attempt has a usable
    score. Otherwise returns a dict with ``attempts``, ``latest``,
    ``latest_band``, ``max``, ``first``, ``delta`` (latest minus first) and
    ``direction``. With exactly one usable attempt ``delta`` is ``0`` and
    ``direction`` is ``"single"``.
    """
    attempts = list(st.session_state.get(_history_key(kind), []) or [])
    if not attempts:
        return None

    coerced = (_coerce_int(attempt.get("score")) for attempt in attempts)
    usable = [value for value in coerced if value is not None]
    if not usable:
        return None

    config = _screener_config(kind)
    bands = list(config.get("bands") or [])
    newest = int(usable[-1])
    oldest = int(usable[0])
    newest_band = _band_for(newest, bands)
    ceiling = _max_score(config)

    only_one = len(usable) == 1
    change = 0 if only_one else newest - oldest
    return {
        "attempts": len(usable),
        "latest": newest,
        "latest_band": newest_band,
        "max": ceiling,
        "first": oldest,
        "delta": change,
        "direction": "single" if only_one else _direction_from_delta(change),
    }
| def _trend_label(prev_avg: Optional[float], last_avg: Optional[float]) -> str: | |
| if prev_avg is None or last_avg is None: | |
| return "flat" | |
| diff = last_avg - prev_avg | |
| if diff >= 0.75: | |
| return "up" | |
| if diff <= -0.75: | |
| return "down" | |
| return "flat" | |
| def _mean(values: Sequence[float]) -> Optional[float]: | |
| nums = [float(v) for v in values if v is not None] | |
| return sum(nums) / len(nums) if nums else None | |
def _summarize_checkins(
    checkins: Sequence[Mapping[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Summarise daily check-ins: count, averages and mood / stress trends.

    Returns ``None`` when there are no check-ins. Trends compare the mean of
    the last three values with the mean of the three before them; with fewer
    than six values the comparison window is everything before the last
    three.
    """
    ordered = list(_sorted_checkins(checkins))
    if not ordered:
        return None

    def _series(field: str) -> List[float]:
        # Values of one field in sorted order, skipping check-ins without it.
        return [float(row.get(field)) for row in ordered if row.get(field) is not None]

    def _earlier_window(series: List[float]) -> List[float]:
        if len(series) >= 6:
            return series[-6:-3]
        return series[:-3]

    mood_series = _series("mood")
    sleep_series = _series("sleep_h")
    stress_series = _series("stress")

    mood_trend = _trend_label(
        _mean(_earlier_window(mood_series)), _mean(mood_series[-3:])
    )
    stress_trend = _trend_label(
        _mean(_earlier_window(stress_series)), _mean(stress_series[-3:])
    )
    return {
        "count": len(ordered),
        "mood_avg": _mean(mood_series),
        "sleep_avg": _mean(sleep_series),
        "stress_avg": _mean(stress_series),
        "mood_trend": mood_trend,
        "stress_trend": stress_trend,
    }
def _summarize_distortions(entries: Sequence[Any]) -> Optional[Dict[str, Any]]:
    """Summarise which distortion types occur most often in *entries*.

    Returns ``None`` when no typed distortion is found. Otherwise the dict
    holds the top (and, if present, second) type with their counts, the
    number of entries, the number of distortions and the i18n key of the tip
    for the top type.
    """
    kinds = (
        _distortion_type(item)
        for entry in entries
        for item in _entry_distortions(entry)
    )
    tally = Counter(kind for kind in kinds if kind)
    grand_total = sum(tally.values())
    if grand_total == 0:
        return None

    ranked = tally.most_common(2)
    leader, leader_hits = ranked[0]
    if len(ranked) > 1:
        runner_up, runner_up_hits = ranked[1]
    else:
        runner_up, runner_up_hits = None, 0

    return {
        "top_type": leader,
        "top_count": int(leader_hits),
        "second_type": runner_up,
        "second_count": int(runner_up_hits),
        "total_entries": len(entries),
        "total_distortions": grand_total,
        "tip_key": f"distortion_tip_{leader}",
    }
def _screener_config(kind: ScreenerKind) -> Dict[str, Any]:
    """Return a dict copy of the loaded configuration for screener *kind*.

    An empty dict is returned when the screener has no (or an empty)
    configuration.
    """
    section = load_screeners().get(kind)
    if not section:
        return {}
    return dict(section)
| def _max_score(config: Mapping[str, Any]) -> int: | |
| bands = config.get("bands") or [] | |
| return int(bands[-1]["max"]) if bands else 0 | |
| def _response_labels(config: Mapping[str, Any], lang: str) -> List[str]: | |
| labels = config.get("response_labels") or {} | |
| selected = labels.get(lang) or labels.get("en") or [] | |
| return list(selected) | |
| def _localized_item(item: Mapping[str, Any], lang: str) -> str: | |
| if item.get(lang): | |
| return str(item[lang]) | |
| text = str(item.get("en", "")) | |
| if lang != "en": | |
| return f"{text} (translation pending)" | |
| return text | |
def _answers_key(kind: ScreenerKind) -> str:
    """Return the session-state key holding the current answers for *kind*."""
    # Anything other than "phq9" maps to the GAD-7 key.
    if kind == "phq9":
        return PHQ9_ANSWERS_KEY
    return GAD7_ANSWERS_KEY
def _score_key(kind: ScreenerKind) -> str:
    """Return the session-state key holding the latest score for *kind*."""
    # Anything other than "phq9" maps to the GAD-7 key.
    if kind == "phq9":
        return PHQ9_SCORE_KEY
    return GAD7_SCORE_KEY
def _taken_key(kind: ScreenerKind) -> str:
    """Return the session-state key holding the last-taken timestamp for *kind*."""
    # Anything other than "phq9" maps to the GAD-7 key.
    if kind == "phq9":
        return PHQ9_TAKEN_AT_KEY
    return GAD7_TAKEN_AT_KEY
def _history_key(kind: ScreenerKind) -> str:
    """Return the session-state key holding the attempt history for *kind*."""
    # Anything other than "phq9" maps to the GAD-7 key.
    if kind == "phq9":
        return PHQ9_HISTORY_KEY
    return GAD7_HISTORY_KEY
def _version_key(kind: ScreenerKind) -> str:
    """Return the session-state key holding the widget version counter for *kind*."""
    # Anything other than "phq9" maps to the GAD-7 key.
    if kind == "phq9":
        return PHQ9_WIDGET_VERSION_KEY
    return GAD7_WIDGET_VERSION_KEY
def _next_screener_log_order() -> int:
    """Return the 1-based position the next screener record takes across both histories."""
    recorded = sum(
        len(st.session_state.get(history_key, []) or [])
        for history_key in (PHQ9_HISTORY_KEY, GAD7_HISTORY_KEY)
    )
    return recorded + 1
def _ensure_answers(kind: ScreenerKind, n_items: int) -> List[int]:
    """Return the stored answers for *kind*, resetting them when the length is wrong.

    If the stored list does not have exactly *n_items* values, a list of zeros
    of that length is written to session state and returned; otherwise a copy
    of the stored answers is returned.
    """
    state_key = _answers_key(kind)
    stored = list(st.session_state.get(state_key) or [])
    if len(stored) == n_items:
        return stored

    blank = [0] * n_items
    st.session_state[state_key] = blank
    return blank
def _record_screener(kind: ScreenerKind, config: Mapping[str, Any], answers: List[int]) -> None:
    """Score *answers* and store the result as both "latest" state and a history entry.

    Writes the score, a copy of the answers and the timestamp under the
    per-screener keys, updates the PHQ-9 item-9 flag for ``phq9``, and appends
    a record to that screener's history list.
    """
    total = _score_screener(answers)
    band = _band_for(total, list(config.get("bands") or []))
    taken_at = datetime.now().isoformat(timespec="seconds")
    # Computed before appending, so it counts only records that already exist.
    position = _next_screener_log_order()
    # Item 9 is the ninth PHQ-9 answer; it is never flagged for other screeners.
    item9_flag = bool(kind == "phq9" and len(answers) >= 9 and answers[8] > 0)

    state = st.session_state
    state[_score_key(kind)] = total
    state[_answers_key(kind)] = list(answers)
    state[_taken_key(kind)] = taken_at
    if kind == "phq9":
        state[PHQ9_ITEM9_KEY] = item9_flag

    record = {
        "ts": taken_at,
        "log_order": position,
        "score": total,
        "band": band,
        "answers": list(answers),
        "q9_positive": item9_flag,
    }
    history_key = _history_key(kind)
    # Assign a new list instead of mutating the stored one in place.
    state[history_key] = [*state.get(history_key, []), record]
| def _parse_timestamp(ts: str) -> Optional[datetime]: | |
| if not ts: | |
| return None | |
| try: | |
| return datetime.fromisoformat(ts) | |
| except ValueError: | |
| return None | |
def _timestamp_label(ts: str) -> str:
    """Format a stored timestamp for display.

    Returns ``"-"`` for an empty value and the raw string when it cannot be
    parsed. Seconds are shown only when they are non-zero.
    """
    if not ts:
        return "-"
    parsed = _parse_timestamp(ts)
    if parsed is None:
        return ts
    pattern = "%d %b, %H:%M:%S" if parsed.second else "%d %b, %H:%M"
    return parsed.strftime(pattern)
def _render_screener_result(kind: ScreenerKind, config: Mapping[str, Any], score: int, lang: str) -> None:
    """Show the score, band and last-taken time for a scored screener in a bordered card."""
    band = _band_for(score, list(config.get("bands") or []))
    metrics = (
        (t("screener_score_label", lang), f"{score}/{_max_score(config)}"),
        (t("screener_band_label", lang), t(f"band_{band}", lang)),
        (t("screener_last_taken", lang), _timestamp_label(st.session_state[_taken_key(kind)])),
    )
    interpretation = t(f"band_{band}_interpretation", lang)

    with st.container(border=True):
        # One metric per column, left to right.
        for column, (label, value) in zip(st.columns(3), metrics):
            column.metric(label, value)
        st.write(interpretation)
        st.caption(t("screener_disclaimer", lang))
def _render_single_screener_timeline(kind: ScreenerKind, lang: str) -> None:
    """Plot one screener's score history as a line chart with one point per attempt.

    Nothing is rendered when there is no history or no row carries a score.
    """
    history = _screener_history_rows(kind, lang)
    if not history:
        return

    def _chronological(row: Mapping[str, Any]) -> Tuple[datetime, int]:
        # Rows without a parseable timestamp sort last; attempt breaks ties.
        return (row.get("_sort_ts") or datetime.max, int(row.get("attempt") or 0))

    scored = [row for row in sorted(history, key=_chronological) if row.get("score") is not None]
    if not scored:
        return

    config = _screener_config(kind)
    display_name = config.get("name", kind.upper())
    attempts = [int(row["attempt"]) for row in scored]

    trace = go.Scatter(
        x=attempts,
        y=[int(row["score"]) for row in scored],
        mode="lines+markers+text",
        name=display_name,
        text=[str(row["score"]) for row in scored],
        textposition="top center",
        customdata=[[row.get("taken_at", "-")] for row in scored],
        hovertemplate=(
            "<b>%{fullData.name}</b><br>"
            "Attempt %{x}<br>"
            "Score %{y}<br>"
            "Taken %{customdata[0]}<extra></extra>"
        ),
        marker=dict(size=14),
        line=dict(width=3),
    )

    fig = go.Figure()
    fig.add_trace(trace)
    fig.update_layout(
        title=f"{display_name} {t('dashboard_screener_timeline_heading', lang)}",
        xaxis_title="Attempt",
        yaxis_title="Score",
        height=260,
        margin=dict(l=10, r=10, t=50, b=10),
        showlegend=False,
    )
    # Half a unit of padding on both sides keeps the end markers off the frame.
    fig.update_xaxes(
        type="linear",
        range=[0.5, max(attempts) + 0.5],
        tickmode="linear",
        dtick=1,
    )
    fig.update_yaxes(range=[0, _max_score(config)], rangemode="tozero")
    st.plotly_chart(fig, use_container_width=True)
def _render_screener(kind: ScreenerKind, lang: str) -> None:
    """Render one screener questionnaire with its score/retake controls and result.

    The form shows one radio group (values 0-3) per configured item. Pressing
    the score button records the attempt; pressing retake clears the current
    answers/score and bumps the widget version so the radios are recreated
    under new keys. Both actions end with ``st.rerun()`` so that everything
    rendered from session state (including the two buttons' ``disabled``
    flags, which are computed before the click is handled) reflects the new
    state immediately.
    """
    config = _screener_config(kind)
    if not config:
        st.error("Screener data could not be loaded.")
        return

    items = list(config.get("items") or [])
    answers = _ensure_answers(kind, len(items))
    labels = _response_labels(config, lang)
    if len(labels) != 4:
        # format_func indexes labels by the 0-3 answer value, so exactly four
        # labels are required; otherwise fall back to the English wording.
        labels = ["Not at all", "Several days", "More than half the days", "Nearly every day"]

    st.subheader(t(f"{kind}_header", lang))
    st.caption(t(f"{kind}_sub", lang))
    st.caption(f"**{t('screener_timeframe_prefix', lang)}:** {config.get('timeframe', '')}")
    if lang != "en":
        st.info(t("screener_translation_pending", lang))
    st.caption(t("screener_source_caption", lang))

    # The version is part of every radio key: bumping it on retake makes
    # Streamlit treat the radios as new widgets with fresh default values.
    version = st.session_state[_version_key(kind)]
    for idx, item in enumerate(items):
        answer = st.radio(
            _localized_item(item, lang),
            options=[0, 1, 2, 3],
            index=int(answers[idx]),
            # Bind labels as a default argument so the lambda does not depend
            # on the enclosing variable at call time.
            format_func=lambda value, labels=labels: f"{value} - {labels[value]}",
            horizontal=True,
            key=f"cognitive_journal_{kind}_v{version}_item_{idx}",
        )
        answers[idx] = int(answer)
    st.session_state[_answers_key(kind)] = answers

    already_scored = st.session_state[_score_key(kind)] is not None
    button_cols = st.columns([1, 1, 4])
    with button_cols[0]:
        score_clicked = st.button(
            t(f"{kind}_cta", lang),
            type="primary",
            key=f"cognitive_journal_{kind}_score_button",
            disabled=already_scored,
        )
    with button_cols[1]:
        retake_clicked = st.button(
            t(f"{kind}_retake", lang),
            key=f"cognitive_journal_{kind}_retake",
            disabled=not already_scored,
        )

    if score_clicked:
        _record_screener(kind, config, answers)
        # Without a rerun the buttons drawn above keep their pre-click
        # disabled state (score still clickable, retake still greyed out) and
        # anything rendered earlier in this run still shows the old score.
        st.rerun()

    if retake_clicked:
        st.session_state[_answers_key(kind)] = [0] * len(items)
        st.session_state[_score_key(kind)] = None
        st.session_state[_taken_key(kind)] = ""
        if kind == "phq9":
            st.session_state[PHQ9_ITEM9_KEY] = False
        st.session_state[_version_key(kind)] += 1
        st.rerun()

    score = st.session_state[_score_key(kind)]
    if score is not None:
        _render_screener_result(kind, config, int(score), lang)
    _render_single_screener_timeline(kind, lang)
    _render_screener_summary_line(kind, lang)
def _render_phq9(lang: str) -> None:
    """Render the PHQ-9 screener section."""
    _render_screener(kind="phq9", lang=lang)
def _render_gad7(lang: str) -> None:
    """Render the GAD-7 screener section."""
    _render_screener(kind="gad7", lang=lang)
def _append_checkin(mood: int, sleep_h: float, stress: int) -> None:
    """Append a timestamped check-in (mood, sleep hours, stress) to session state."""
    record = {
        "ts": datetime.now().isoformat(timespec="seconds"),
        "mood": int(mood),
        "sleep_h": float(sleep_h),
        "stress": int(stress),
    }
    # Assign a new list instead of mutating the stored one in place.
    st.session_state[CHECKINS_KEY] = [*st.session_state.get(CHECKINS_KEY, []), record]
def _render_daily_checkin(lang: str) -> None:
    """Render the daily check-in form, the five most recent check-ins and the trend chart."""
    st.subheader(t("checkin_header", lang))
    st.caption(t("checkin_sub", lang))

    mood = st.slider(
        t("checkin_mood_label", lang),
        min_value=1,
        max_value=10,
        value=6,
        key="cognitive_journal_checkin_mood",
    )
    sleep_h = st.slider(
        t("checkin_sleep_label", lang),
        min_value=0.0,
        max_value=12.0,
        value=7.0,
        step=0.5,
        key="cognitive_journal_checkin_sleep",
    )
    stress = st.slider(
        t("checkin_stress_label", lang),
        min_value=1,
        max_value=10,
        value=5,
        key="cognitive_journal_checkin_stress",
    )

    submitted = st.button(
        t("checkin_submit", lang),
        type="primary",
        key="cognitive_journal_checkin_submit",
    )
    if submitted:
        _append_checkin(mood, sleep_h, stress)
        st.success(t("checkin_logged", lang))

    st.markdown(f"##### {t('checkin_history_heading', lang)}")
    logged = st.session_state.get(CHECKINS_KEY, [])
    if not logged:
        st.caption(t("checkin_empty_history", lang))
        return

    # Last five in chronological order, displayed newest first.
    newest_first = _sorted_checkins(logged)[-5:][::-1]
    for item in newest_first:
        stamp = _timestamp_label(str(item.get("ts", "")))
        st.write(
            f"{stamp} · "
            f"mood {item.get('mood')}/10 · sleep {item.get('sleep_h')}h · stress {item.get('stress')}/10"
        )
    _render_checkin_trend(lang)
| def _coerce_int(value: Any) -> Optional[int]: | |
| if value is None or value == "": | |
| return None | |
| try: | |
| return int(value) | |
| except (TypeError, ValueError): | |
| return None | |
| def _recent_mood_avg(checkins: Sequence[Mapping[str, Any]]) -> Optional[float]: | |
| if not checkins: | |
| return None | |
| recent = checkins[-3:] | |
| moods = [] | |
| for item in recent: | |
| try: | |
| moods.append(float(item.get("mood"))) | |
| except (TypeError, ValueError): | |
| continue | |
| if not moods: | |
| return None | |
| return sum(moods) / len(moods) | |
def _aggregate_session_signals() -> dict:
    """Collect the session-state values used by the stepped-care rules and the footer.

    Journal-derived signals look only at the three most recent entries.
    ``crisis_history`` is always reported as ``False`` by this function.
    """
    state = st.session_state

    latest_entries = list(state.get(ENTRIES_KEY, []))[-3:]
    tagged = sum(len(_entry_distortions(entry)) for entry in latest_entries)
    if latest_entries:
        density = tagged / len(latest_entries)
    else:
        density = 0.0
    needs_pro = len([entry for entry in latest_entries if _entry_needs_pro(entry)])

    ordered_checkins = _sorted_checkins(state.get(CHECKINS_KEY, []) or [])

    # Activity in other modules, read straight from their session keys.
    history = state.get("saathi_chat_history", []) or []
    user_turns = len(
        [msg for msg in history if isinstance(msg, Mapping) and msg.get("role") == "user"]
    )
    # These three are presence flags (0 or 1), not real counts.
    used_student = int(bool(state.get("student_response") or state.get("student_situation")))
    used_legal = int(bool(state.get("legal_aid_response") or state.get("legal_aid_sections")))
    made_poem = int(bool(state.get("soothe_poem")))

    return {
        "phq9_score": _coerce_int(state.get(PHQ9_SCORE_KEY)),
        "gad7_score": _coerce_int(state.get(GAD7_SCORE_KEY)),
        "phq9_item9_positive": bool(state.get(PHQ9_ITEM9_KEY)),
        "recent_mood_avg": _recent_mood_avg(ordered_checkins),
        "distortion_density": density,
        "pro_signal_count": needs_pro,
        "crisis_history": False,
        "chat_turns": user_turns,
        "student_events": used_student,
        "legal_lookups": used_legal,
        "soothe_poems": made_poem,
    }
| def _score_at_least(score: Optional[int], threshold: int) -> bool: | |
| return score is not None and score >= threshold | |
def _compute_stepped_care_recommendation(signals: dict) -> Tuple[str, str]:
    """Map aggregated signals to a care tier and the i18n key of its rationale.

    Tiers are checked from most to least intensive and the first match wins:

    * ``urgent_professional`` - PHQ-9 >= 20, GAD-7 >= 15, two or more
      professional-support flags, a positive PHQ-9 item 9, or crisis history.
    * ``guided_support`` - PHQ-9 >= 15, GAD-7 >= 10, or recent mood average
      <= 3 together with a distortion density >= 2.
    * ``self_help`` - PHQ-9 >= 10, GAD-7 >= 5, or distortion density >= 1.5.
    * ``self_care`` - everything else.
    """
    phq9 = _coerce_int(signals.get("phq9_score"))
    gad7 = _coerce_int(signals.get("gad7_score"))
    mood_avg = signals.get("recent_mood_avg")
    density = float(signals.get("distortion_density") or 0.0)
    pro_flags = int(signals.get("pro_signal_count") or 0)

    low_mood_with_distortions = bool(
        mood_avg is not None and float(mood_avg) <= 3 and density >= 2
    )

    urgent = any(
        (
            _score_at_least(phq9, 20),
            _score_at_least(gad7, 15),
            pro_flags >= 2,
            bool(signals.get("phq9_item9_positive")),
            bool(signals.get("crisis_history")),
        )
    )
    guided = any(
        (
            _score_at_least(phq9, 15),
            _score_at_least(gad7, 10),
            low_mood_with_distortions,
        )
    )
    assisted = any(
        (
            _score_at_least(phq9, 10),
            _score_at_least(gad7, 5),
            density >= 1.5,
        )
    )

    ladder = (
        (urgent, ("urgent_professional", "tier_reason_urgent_professional")),
        (guided, ("guided_support", "tier_reason_guided_support")),
        (assisted, ("self_help", "tier_reason_self_help")),
    )
    for triggered, outcome in ladder:
        if triggered:
            return outcome
    return "self_care", "tier_reason_self_care"
def _render_stepped_care_card(lang: str, tier: str, rationale_key: str) -> None:
    """Show the care recommendation in an alert box whose severity matches *tier*."""
    heading = t(f"tier_{tier}", lang)
    details = t(f"tier_{tier}_body", lang)
    why = t(rationale_key, lang)
    message = f"**{heading}**\n\n{details}\n\n_{why}_"

    # Unlisted tiers (including "self_care") use the success box.
    alert_for_tier = {
        "urgent_professional": st.error,
        "guided_support": st.warning,
        "self_help": st.info,
    }
    alert_for_tier.get(tier, st.success)(message)
    st.caption(t("tier_disclaimer", lang))
def _render_screener_metric(kind: ScreenerKind, lang: str) -> None:
    """Show the latest score for *kind* as a metric card, or a nudge if never taken."""
    config = _screener_config(kind)
    stored = st.session_state.get(_score_key(kind))
    if stored is None:
        st.info(t(f"{kind}_not_taken_nudge", lang))
        return

    points = int(stored)
    band = _band_for(points, list(config.get("bands") or []))
    with st.container(border=True):
        st.metric(config.get("name", kind.upper()), f"{points}/{_max_score(config)}")
        band_text = t(f"band_{band}", lang)
        taken_text = _timestamp_label(st.session_state[_taken_key(kind)])
        st.caption(f"{band_text} · {t('screener_last_taken', lang)}: {taken_text}")
def _render_screener_snapshot(lang: str) -> None:
    """Show the latest PHQ-9 and GAD-7 results side by side."""
    st.markdown(f"##### {t('dashboard_screener_history_heading', lang)}")
    for column, kind in zip(st.columns(2), ("phq9", "gad7")):
        with column:
            _render_screener_metric(kind, lang)
def _screener_history_rows(kind: ScreenerKind, lang: str) -> List[Dict[str, Any]]:
    """Turn the stored history of *kind* into display rows, one per attempt.

    ``attempt`` is the 1-based position in the stored history. Keys starting
    with an underscore carry sorting data (raw timestamp, parsed timestamp and
    a positional fallback).
    """
    display_name = _screener_config(kind).get("name", kind.upper())
    stored = st.session_state.get(_history_key(kind), []) or []

    rows: List[Dict[str, Any]] = []
    for attempt, record in enumerate(stored, start=1):
        raw_ts = str(record.get("ts", ""))
        band_code = str(record.get("band") or "none")
        rows.append(
            {
                "log": _coerce_int(record.get("log_order")),
                "attempt": attempt,
                "screener": display_name,
                "score": _coerce_int(record.get("score")),
                "band": t(f"band_{band_code}", lang),
                "taken_at": _timestamp_label(raw_ts),
                "_raw_ts": raw_ts,
                "_sort_ts": _parse_timestamp(raw_ts),
                "_fallback_order": attempt,
            }
        )
    return rows
def _combined_screener_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Merge PHQ-9 and GAD-7 history rows into one chronologically numbered list.

    Rows are ordered by parsed timestamp (unparseable ones last), then by the
    stored log order, or by their position in the merged list when no log
    order was stored. ``log`` is then renumbered 1..n in that final order.
    """
    pooled = _screener_history_rows("phq9", lang) + _screener_history_rows("gad7", lang)

    merged: List[Dict[str, Any]] = []
    for position, source_row in enumerate(pooled, start=1):
        # Copy so the rows returned by the helper are not modified.
        merged.append({**source_row, "_fallback_order": source_row.get("log") or position})

    def _chronological(row: Mapping[str, Any]) -> Tuple[datetime, int]:
        return (row.get("_sort_ts") or datetime.max, int(row.get("_fallback_order") or 0))

    merged.sort(key=_chronological)
    for position, row in enumerate(merged, start=1):
        row["log"] = position
    return merged
| def _visible_screener_history_rows(rows: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]: | |
| visible_keys = ("log", "screener", "attempt", "score", "band", "taken_at") | |
| return [{key: row.get(key) for key in visible_keys} for row in rows] | |
def _render_screener_timeline(lang: str) -> None:
    """Render the combined PHQ-9 / GAD-7 timeline chart and its history table.

    One line is drawn per screener, with the x position taken from the
    combined log order. The chart is skipped when no row carries a score; the
    table is shown whenever there is any history. With no history at all an
    "empty" caption is shown instead.
    """
    st.markdown(f"##### {t('dashboard_screener_timeline_heading', lang)}")
    rows = _combined_screener_history_rows(lang)
    if not rows:
        st.caption(t("dashboard_screener_timeline_empty", lang))
        return

    plot_rows = [row for row in rows if row.get("score") is not None]
    if plot_rows:
        # Group rows by screener name, keeping first-appearance order so the
        # trace (and legend) order follows the timeline.
        rows_by_screener: Dict[str, List[Dict[str, Any]]] = {}
        for row in plot_rows:
            name = str(row.get("screener") or "")
            if name:
                rows_by_screener.setdefault(name, []).append(row)

        fig = go.Figure()
        for screener_name, screener_rows in rows_by_screener.items():
            fig.add_trace(
                go.Scatter(
                    x=[int(row["log"]) for row in screener_rows],
                    y=[int(row["score"]) for row in screener_rows],
                    mode="lines+markers+text",
                    name=screener_name,
                    text=[str(row["score"]) for row in screener_rows],
                    textposition="top center",
                    customdata=[
                        [row.get("taken_at", "-"), row.get("attempt", "-")]
                        for row in screener_rows
                    ],
                    hovertemplate=(
                        "<b>%{fullData.name}</b><br>"
                        "Log %{x}<br>"
                        "Score %{y}<br>"
                        "Attempt %{customdata[1]}<br>"
                        "Taken %{customdata[0]}<extra></extra>"
                    ),
                    marker=dict(size=14),
                    line=dict(width=3),
                )
            )

        max_log = max(int(row["log"]) for row in plot_rows)
        # Derive the y-axis ceiling from the screener definitions, as the
        # single-screener chart does. 27 (the previous fixed value) is kept as
        # a fallback for when neither definition yields a maximum.
        y_max = max(_max_score(_screener_config(kind)) for kind in ("phq9", "gad7")) or 27

        fig.update_layout(
            xaxis_title="Log order",
            yaxis_title="Score",
            height=280,
            margin=dict(l=10, r=10, t=50, b=10),
            title=t("dashboard_screener_timeline_heading", lang),
        )
        # Half a unit of padding on both sides keeps the end markers off the frame.
        fig.update_xaxes(
            type="linear",
            range=[0.5, max_log + 0.5],
            tickmode="linear",
            dtick=1,
        )
        fig.update_yaxes(range=[0, y_max], rangemode="tozero")
        st.plotly_chart(fig, use_container_width=True)

    st.dataframe(_visible_screener_history_rows(rows), use_container_width=True, hide_index=True)
| def _score_snapshot(score: Optional[int], max_score: int) -> str: | |
| if score is None: | |
| return "-" | |
| return f"{score}/{max_score}" | |
def _render_session_snapshot(lang: str) -> None:
    """Show a four-metric summary card: PHQ-9, GAD-7, check-in count and care tier."""
    signals = _aggregate_session_signals()
    tier = _compute_stepped_care_recommendation(signals)[0]
    phq9_max = _max_score(_screener_config("phq9"))
    gad7_max = _max_score(_screener_config("gad7"))
    logged = st.session_state.get(CHECKINS_KEY, [])

    metrics = (
        ("PHQ-9", _score_snapshot(signals.get("phq9_score"), phq9_max)),
        ("GAD-7", _score_snapshot(signals.get("gad7_score"), gad7_max)),
        (t("journal_snapshot_checkins", lang), len(logged)),
        (t("journal_snapshot_care_tier", lang), t(f"tier_{tier}", lang)),
    )

    with st.container(border=True):
        st.caption(f"**{t('journal_snapshot_heading', lang)}**")
        for column, (label, value) in zip(st.columns(4), metrics):
            column.metric(label, value)
        st.caption(t("journal_snapshot_hint", lang))
def _render_checkin_trend(lang: str) -> None:
    """Plot the check-in metrics (one line per metric) against check-in order.

    Shows an "empty" caption instead when there is nothing to plot, and ends
    with the textual summary line when a chart was drawn.
    """
    st.markdown(f"##### {t('dashboard_checkin_heading', lang)}")
    points = _checkin_trend_rows(list(st.session_state.get(CHECKINS_KEY, []) or []))
    if not points:
        st.caption(t("checkin_trend_caption_empty", lang))
        return

    def _point_label(value: Any) -> str:
        # Whole numbers are labelled without a trailing ".0".
        return str(int(value)) if float(value).is_integer() else str(value)

    # Distinct, non-empty metric names in first-appearance order.
    metric_names = list(
        dict.fromkeys(
            name for name in (str(point.get("metric") or "") for point in points) if name
        )
    )

    fig = go.Figure()
    for metric_name in metric_names:
        series = [point for point in points if point.get("metric") == metric_name]
        fig.add_trace(
            go.Scatter(
                x=[int(point["check_in"]) for point in series],
                y=[float(point["value"]) for point in series],
                mode="lines+markers+text",
                name=metric_name,
                text=[_point_label(point["value"]) for point in series],
                textposition="top center",
                customdata=[[point.get("logged_at", "-")] for point in series],
                hovertemplate=(
                    "<b>%{fullData.name}</b><br>"
                    "Check-in %{x}<br>"
                    "Value %{y}<br>"
                    "Logged %{customdata[0]}<extra></extra>"
                ),
                marker=dict(size=14),
                line=dict(width=3),
            )
        )

    last_checkin = max(int(point["check_in"]) for point in points)
    fig.update_layout(
        xaxis_title="Check-in order",
        yaxis_title="Value",
        height=300,
        margin=dict(l=10, r=10, t=50, b=10),
        title=t("checkin_trend_title", lang),
    )
    # Half a unit of padding on both sides keeps the end markers off the frame.
    fig.update_xaxes(
        type="linear",
        range=[0.5, last_checkin + 0.5],
        tickmode="linear",
        dtick=1,
    )
    fig.update_yaxes(range=[0, 12], rangemode="tozero")
    st.plotly_chart(fig, use_container_width=True)
    _render_checkin_summary_line(lang)
def _sorted_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
    """Return *checkins* ordered by timestamp, oldest first.

    Items without a parseable timestamp go to the end; ties keep their
    original relative order.
    """

    def _order(pair: Tuple[int, Mapping[str, Any]]) -> Tuple[datetime, int]:
        position, item = pair
        return (_parse_timestamp(str(item.get("ts", ""))) or datetime.max, position)

    return [item for _, item in sorted(enumerate(checkins), key=_order)]
def _checkin_trend_rows(checkins: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten check-ins into one row per (check-in, metric) for plotting.

    ``check_in`` is the 1-based chronological position. Metrics that are
    missing or not convertible to ``float`` produce no row.
    """
    series = (
        ("mood", "Mood"),
        ("sleep_h", "Sleep hours"),
        ("stress", "Stress"),
    )

    def _numeric(raw: Any) -> Optional[float]:
        if raw is None:
            return None
        try:
            return float(raw)
        except (TypeError, ValueError):
            return None

    flattened: List[Dict[str, Any]] = []
    for ordinal, item in enumerate(_sorted_checkins(checkins), start=1):
        logged_at = _timestamp_label(str(item.get("ts", "")))
        for field, label in series:
            number = _numeric(item.get(field))
            if number is None:
                continue
            flattened.append(
                {
                    "check_in": ordinal,
                    "metric": label,
                    "value": number,
                    "logged_at": logged_at,
                }
            )
    return flattened
def _checkin_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Build table rows for every stored check-in, in chronological order.

    *lang* is accepted for signature parity with the other row builders; this
    function does not use it.
    """
    ordered = _sorted_checkins(st.session_state.get(CHECKINS_KEY, []) or [])
    return [
        {
            "check_in": ordinal,
            "logged_at": _timestamp_label(str(item.get("ts", ""))),
            "mood": item.get("mood"),
            "sleep_h": item.get("sleep_h"),
            "stress": item.get("stress"),
        }
        for ordinal, item in enumerate(ordered, start=1)
    ]
def _render_checkin_history(lang: str) -> None:
    """Show the check-in history table; render nothing when there are no check-ins."""
    table = _checkin_history_rows(lang)
    if table:
        st.markdown(f"##### {t('dashboard_checkin_history_heading', lang)}")
        st.dataframe(table, use_container_width=True, hide_index=True)
def _render_crossmodule_footer(lang: str) -> None:
    """List activity seen in other modules this session inside an expander."""
    signals = _aggregate_session_signals()

    # (signal key, formatter) pairs in display order; zero counts are omitted.
    templates = (
        ("chat_turns", lambda n: f"{n} chat turn(s) in {t('tab_chat', lang)}"),
        ("student_events", lambda n: f"{n} student support event(s)"),
        ("legal_lookups", lambda n: f"{n} legal lookup(s)"),
        ("soothe_poems", lambda n: f"{n} poem(s) generated"),
    )
    notes = [describe(signals[key]) for key, describe in templates if signals[key]]

    with st.expander(t("crossmodule_footer_heading", lang)):
        if not notes:
            st.caption(t("crossmodule_footer_none", lang))
        for note in notes:
            st.caption(note)
def _render_insights_dashboard(lang: str) -> None:
    """Render the insights dashboard: care tier, screeners, check-ins, journal, distortions, footer."""
    st.subheader(t("dashboard_header", lang))
    st.caption(t("dashboard_sub", lang))

    # Stepped-care recommendation derived from the current session signals.
    recommendation = _compute_stepped_care_recommendation(_aggregate_session_signals())
    _render_stepped_care_card(lang, *recommendation)
    st.caption(t("screener_disclaimer", lang))

    # Screeners.
    for render_part in (_render_screener_snapshot, _render_screener_timeline):
        render_part(lang)
    st.divider()

    # Daily check-ins.
    for render_part in (_render_checkin_trend, _render_checkin_history):
        render_part(lang)
    st.divider()

    # Journal history.
    _render_journal_history(lang)
    st.divider()

    # Distortion chart.
    st.markdown(f"##### {t('dashboard_distortion_heading', lang)}")
    _render_chart(lang)
    st.divider()

    _render_crossmodule_footer(lang)
def render(lang: str) -> None:
    """Entry point of the module: render the header, section router and chosen section.

    The session snapshot card is shown above every section except the
    dashboard. An unrecognised section renders nothing below the snapshot.
    """
    _init_state()
    st.header(t("journal_header", lang))
    st.caption(t("journal_sub", lang))
    _render_journal_explainer(lang)

    section = _render_section_router(lang)
    if section != "dashboard":
        _render_session_snapshot(lang)

    section_renderers = {
        "journal": _render_journal,
        "phq9": _render_phq9,
        "gad7": _render_gad7,
        "checkin": _render_daily_checkin,
        "dashboard": _render_insights_dashboard,
    }
    renderer = section_renderers.get(section)
    if renderer is not None:
        renderer(lang)
# ---------------------------------------------------------------------------
# Cross-module memory — read-only digest exported to Saathi Chat
# ---------------------------------------------------------------------------
#
# Design contract:
# - This function is the ONLY way other modules may read Thought Diary
#   state. It returns a short English digest string.
# - What leaves the module: counts, top category names, scores, trends, and
#   truncated excerpts (at most 200 characters each) of the five most recent
#   journal entries with their mood tag, distortion types and observation
#   summary. The excerpts are raw user text and may contain PII, so callers
#   must treat the digest as sensitive.
# - Returns "" (empty string) when the user has no journal data yet, so the
#   calling module can insert the string into a prompt template unconditionally
#   without breaking the layout when it's the user's first turn.
# - Deterministic: no LLM call, no randomness. Safe to invoke every chat turn.
# - English on purpose — this is a system-prompt block for the LLM, not for
#   the user. The LLM is already instructed to respond in the user's language.
def get_cognitive_journal_context() -> str:
    """Return a short English digest of the user's Thought Diary activity.

    Designed to be injected into Saathi Chat's system prompt so the chat
    module can reference the user's recent journal / screener / check-in
    activity without holding its own copy of that state. Empty string if
    no data is available yet.

    When data exists, the digest is a comment-style header followed by one
    bullet per topic: the journal entry count (with the most common mood tag),
    single-line excerpts of up to five recent entries, the top distortion
    types, PHQ-9 and GAD-7 summaries, check-in averages/trends and the current
    stepped-care tier. Entry excerpts are capped at 200 characters and
    observation summaries at 150.
    """
    # Function-scope import: only the best-effort stepped-care step logs.
    import logging

    entries = list(st.session_state.get(ENTRIES_KEY, []) or [])
    checkins = list(st.session_state.get(CHECKINS_KEY, []) or [])
    phq9_history = list(st.session_state.get(PHQ9_HISTORY_KEY, []) or [])
    gad7_history = list(st.session_state.get(GAD7_HISTORY_KEY, []) or [])
    if not (entries or checkins or phq9_history or gad7_history):
        return ""

    def _clip(value: Any, limit: int) -> str:
        """Collapse whitespace to a single line and cap the result at *limit* characters."""
        # Multi-line text would break the one-line-per-item layout of the digest.
        flat = " ".join(str(value or "").split())
        if len(flat) > limit:
            flat = flat[: limit - 3].rstrip() + "..."
        return flat

    def _screener_line(kind: ScreenerKind, label: str) -> Optional[str]:
        """Format the digest bullet for one screener, or ``None`` without a summary."""
        summary = _summarize_screener_history(kind)
        if not summary:
            return None
        score = f"{summary['latest']}/{summary['max']}"
        band = f"(band: {summary['latest_band']})"
        if summary["direction"] == "single":
            return f"- {label}: 1 attempt, score {score} {band}."
        return (
            f"- {label}: {summary['attempts']} attempts, "
            f"latest {score} {band}, "
            f"change from first attempt: {int(summary['delta']):+d} "
            f"({summary['direction']})."
        )

    def _f(value: Optional[float]) -> str:
        """Format an average with one decimal, or ``"-"`` when it is missing."""
        return "-" if value is None else f"{value:.1f}"

    lines: List[str] = []

    # --- Journal entries + distortion pattern ---
    if entries:
        distortion_summary = _summarize_distortions(entries)
        crisis_flags = sum(1 for e in entries if _entry_needs_pro(e))

        mood_counter: Counter[str] = Counter()
        for e in entries:
            mood = _entry_mood(e)
            if mood:
                mood_counter[mood] += 1
        top_mood = mood_counter.most_common(1)
        top_mood_label = top_mood[0][0] if top_mood else None

        journal_line = f"- Journal entries this session: {len(entries)}"
        if top_mood_label:
            journal_line += f" (most common mood tag: {top_mood_label})"
        lines.append(journal_line)

        # --- Include actual entry text so Chat can reference specifics ---
        lines.append("- Recent journal entries (newest last):")
        for idx, entry in enumerate(entries[-5:], 1):  # last 5 entries max
            text = _clip(_entry_text(entry), 200)
            summary = _clip(_entry_summary(entry), 150)
            mood = _entry_mood(entry) or "unknown"
            # Resolve each distortion's type once and keep the truthy ones.
            distortion_types = [
                dtype
                for dtype in map(_distortion_type, _entry_distortions(entry))
                if dtype
            ]
            dtypes_str = ", ".join(distortion_types) if distortion_types else "none"
            lines.append(
                f' {idx}. "{text}" '
                f"[mood: {mood}, distortions: {dtypes_str}] "
                f"Observation: {summary}"
            )

        if distortion_summary:
            top_type = distortion_summary["top_type"]
            top_label = DISTORTION_LABELS.get(top_type, top_type)
            dist_line = (
                f"- Top cognitive distortion: {top_label} "
                f"({distortion_summary['top_count']}x of "
                f"{distortion_summary['total_distortions']} total distortions tagged)"
            )
            second_type = distortion_summary.get("second_type")
            if second_type:
                second_label = DISTORTION_LABELS.get(second_type, second_type)
                dist_line += f"; second most common: {second_label}"
            lines.append(dist_line)
        else:
            lines.append(
                "- No CBT distortions were tagged in those entries "
                "(the entries mostly describe feelings, not distorted thoughts)."
            )

        if crisis_flags:
            lines.append(
                f"- {crisis_flags} entry/entries were flagged as needing "
                "professional support (crisis language or LLM-flagged severity)."
            )

    # --- PHQ-9 / GAD-7 history ---
    for kind, label in (("phq9", "PHQ-9 (depression)"), ("gad7", "GAD-7 (anxiety)")):
        screener_line = _screener_line(kind, label)
        if screener_line:
            lines.append(screener_line)

    # --- Daily check-ins ---
    checkin_summary = _summarize_checkins(checkins)
    if checkin_summary and checkin_summary["count"] >= 1:
        if checkin_summary["count"] >= 2:
            lines.append(
                f"- Daily check-ins: {checkin_summary['count']} logged. "
                f"Mood avg {_f(checkin_summary['mood_avg'])}/10 "
                f"(trend: {checkin_summary['mood_trend']}), "
                f"sleep avg {_f(checkin_summary['sleep_avg'])} h, "
                f"stress avg {_f(checkin_summary['stress_avg'])}/10 "
                f"(trend: {checkin_summary['stress_trend']})."
            )
        else:
            lines.append(
                f"- Daily check-ins: 1 logged so far "
                f"(mood {_f(checkin_summary['mood_avg'])}/10, "
                f"sleep {_f(checkin_summary['sleep_avg'])} h, "
                f"stress {_f(checkin_summary['stress_avg'])}/10). "
                f"Not enough data for a trend yet."
            )

    # --- Stepped-care tier (deterministic rule-based) ---
    try:
        signals = _aggregate_session_signals()
        tier, _reason_key = _compute_stepped_care_recommendation(signals)
    except Exception:
        # Stepped-care is nice-to-have; never fail the digest on a rule-engine
        # bug. Log it so the failure is not invisible.
        logging.getLogger(__name__).warning(
            "Stepped-care recommendation failed; omitting it from the journal digest.",
            exc_info=True,
        )
    else:
        tier_human = {
            "self_care": "self-care",
            "self_help": "self-help tools",
            "guided_support": "guided support",
            "urgent_professional": "urgent professional help",
        }.get(tier, tier)
        lines.append(f"- Current stepped-care recommendation: {tier_human}.")

    if not lines:
        return ""

    header = (
        "# What Saathi knows from this user's Thought Diary (this session only)\n"
        "# You may reference this naturally when it is relevant, but do NOT read\n"
        "# it out like a database dump. Use it to personalise your reply.\n"
        "# Never claim you can see journal entries the user did not mention —\n"
        "# only cite the patterns below."
    )
    return header + "\n" + "\n".join(lines)