"""Module 4 - Thought Diary (née Cognitive Journal): CBT journal plus deterministic clinical signals.

The original journal path stays intact: user entry -> crisis check ->
structured LLM JSON -> distortion cards. Around it we add non-LLM clinical
surfaces: PHQ-9, GAD-7, daily check-ins, and an auditable stepped-care
dashboard.
"""

from __future__ import annotations

from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple

import plotly.graph_objects as go
import streamlit as st
from pydantic import BaseModel, Field

from backend.claude_client import chat_structured
from backend.i18n import claude_language_name, t
from backend.resources import load_screeners
from backend.safeguards import check_crisis, render_crisis_banner

MODULE_NAME = "cognitive_journal"

# --- Session-state keys (all namespaced with "cognitive_journal_") ---------
ENTRIES_KEY = "cognitive_journal_entries"
LAST_ANALYSIS_KEY = "cognitive_journal_last"
SECTION_KEY = "cognitive_journal_section"
JOURNAL_WIDGET_VERSION_KEY = "cognitive_journal_widget_version"
JOURNAL_NOTICE_KEY = "cognitive_journal_notice"
PHQ9_SCORE_KEY = "cognitive_journal_phq9_score"
PHQ9_ANSWERS_KEY = "cognitive_journal_phq9_answers"
PHQ9_TAKEN_AT_KEY = "cognitive_journal_phq9_taken_at"
PHQ9_HISTORY_KEY = "cognitive_journal_phq9_history"
PHQ9_ITEM9_KEY = "cognitive_journal_phq9_item9_positive"
PHQ9_WIDGET_VERSION_KEY = "cognitive_journal_phq9_widget_version"
GAD7_SCORE_KEY = "cognitive_journal_gad7_score"
GAD7_ANSWERS_KEY = "cognitive_journal_gad7_answers"
GAD7_TAKEN_AT_KEY = "cognitive_journal_gad7_taken_at"
GAD7_HISTORY_KEY = "cognitive_journal_gad7_history"
GAD7_WIDGET_VERSION_KEY = "cognitive_journal_gad7_widget_version"
CHECKINS_KEY = "cognitive_journal_checkins"
LEGACY_CLEANUP_KEY = "cognitive_journal_legacy_cleanup_done"

ScreenerKind = Literal["phq9", "gad7"]

# (section key, i18n label key) pairs, in display order.
SECTION_CHOICES = [
    ("journal", "journal_section_journal_label"),
    ("phq9", "journal_section_phq9_label"),
    ("gad7", "journal_section_gad7_label"),
    ("checkin", "journal_section_checkin_label"),
    ("dashboard", "journal_section_dashboard_label"),
]

DistortionType = Literal[
    "catastrophizing",
    "mind_reading",
    "all_or_nothing",
    "fortune_telling",
    "personalization",
    "mental_filter",
    "emotional_reasoning",
    "should_statements",
]

MoodType = Literal[
    "anxious",
    "sad",
    "frustrated",
    "hopeful",
    "neutral",
    "overwhelmed",
]


class Distortion(BaseModel):
    """One cognitive distortion tagged in a journal entry."""

    type: DistortionType
    phrase: str  # the quoted fragment of the entry, rendered as a blockquote
    explanation: str
    reframe: str
    evidence_question: str


class JournalAnalysis(BaseModel):
    """Structured analysis schema passed to ``chat_structured``."""

    overall_mood: MoodType
    distortions: List[Distortion] = Field(default_factory=list)
    summary: str
    needs_professional_signal: bool = False
    # Both fields below are overwritten locally by _analysis_to_journal_record.
    ts: str = ""
    entry_text: str = ""


DISTORTION_LABELS = {
    "catastrophizing": "Catastrophizing",
    "mind_reading": "Mind-reading",
    "all_or_nothing": "All-or-nothing thinking",
    "fortune_telling": "Fortune-telling",
    "personalization": "Personalization",
    "mental_filter": "Mental filter",
    "emotional_reasoning": "Emotional reasoning",
    "should_statements": "Should-statements",
}

# Phrases / check-in values that identify legacy seed data; _init_state
# strips matching records once per session.
_LEGACY_SEED_PHRASES = {
    "my whole placement season is ruined",
    "everyone in my batch thinks I'm useless",
    "I should have studied harder in first year",
    "my parents are disappointed because of me",
    "I'm going to freeze up in the viva and fail",
    "if I don't get an A I've wasted the whole semester",
    "I only remember the parts I got wrong in the last test",
}

_LEGACY_SEED_CHECKINS = [
    {"mood": 5, "sleep_h": 6.5, "stress": 6},
    {"mood": 6, "sleep_h": 7.0, "stress": 5},
]


def _is_legacy_seed_entry(entry: Any) -> bool:
    """Return True when every non-empty distortion phrase of ``entry`` is a seed phrase.

    An entry with no non-empty phrases is never treated as seed data.
    """
    phrases = {
        phrase
        for phrase in (_distortion_phrase(d) for d in _entry_distortions(entry))
        if phrase
    }
    return bool(phrases) and phrases.issubset(_LEGACY_SEED_PHRASES)


def _drop_legacy_seed_entries(entries: Sequence[Any]) -> List[Any]:
    """Return ``entries`` without the records recognised as legacy seed data."""
    return [entry for entry in entries if not _is_legacy_seed_entry(entry)]


def _same_checkin_values(item: Mapping[str, Any], expected: Mapping[str, Any]) -> bool:
    """Return True when ``item`` has the same mood/sleep/stress values as ``expected``.

    Values are compared after numeric coercion. Malformed input (missing or
    non-numeric values, or an ``item`` that is not a mapping at all) counts as
    "not the same" instead of raising, so one corrupt session-state record
    cannot crash ``_init_state``.
    """
    try:
        return (
            int(item.get("mood")) == int(expected["mood"])
            and float(item.get("sleep_h")) == float(expected["sleep_h"])
            and int(item.get("stress")) == int(expected["stress"])
        )
    except (AttributeError, TypeError, ValueError):
        return False


def _drop_legacy_seed_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
    """Drop the leading legacy seed check-ins, but only if *all* of them match in order.

    If the list is shorter than the seed list, or any leading item differs,
    the check-ins are returned unchanged (as a new list).
    """
    checkins_list = list(checkins)
    if len(checkins_list) < len(_LEGACY_SEED_CHECKINS):
        return checkins_list
    for item, expected in zip(checkins_list, _LEGACY_SEED_CHECKINS):
        if not _same_checkin_values(item, expected):
            return checkins_list
    return checkins_list[len(_LEGACY_SEED_CHECKINS) :]


def _init_state() -> None:
    """Seed every session-state key this module reads, never overwriting existing values.

    On the first call per session (tracked by ``LEGACY_CLEANUP_KEY``) existing
    journal entries and check-ins are also filtered through the legacy-seed
    cleanup. The defaults table is rebuilt on every call so mutable defaults
    (lists) are fresh objects and never shared between sessions.
    """
    state = st.session_state
    needs_legacy_cleanup = not bool(state.get(LEGACY_CLEANUP_KEY))

    if ENTRIES_KEY not in state:
        state[ENTRIES_KEY] = []
    elif needs_legacy_cleanup:
        state[ENTRIES_KEY] = _drop_legacy_seed_entries(
            list(state.get(ENTRIES_KEY, []) or [])
        )

    defaults: Dict[str, Any] = {
        LAST_ANALYSIS_KEY: None,
        SECTION_KEY: "journal",
        JOURNAL_WIDGET_VERSION_KEY: 0,
        JOURNAL_NOTICE_KEY: None,
        PHQ9_SCORE_KEY: None,
        PHQ9_ANSWERS_KEY: [0] * 9,
        PHQ9_TAKEN_AT_KEY: "",
        PHQ9_HISTORY_KEY: [],
        PHQ9_ITEM9_KEY: False,
        PHQ9_WIDGET_VERSION_KEY: 0,
        GAD7_SCORE_KEY: None,
        GAD7_ANSWERS_KEY: [0] * 7,
        GAD7_TAKEN_AT_KEY: "",
        GAD7_HISTORY_KEY: [],
        GAD7_WIDGET_VERSION_KEY: 0,
    }
    for key, value in defaults.items():
        if key not in state:
            state[key] = value

    if CHECKINS_KEY not in state:
        state[CHECKINS_KEY] = []
    elif needs_legacy_cleanup:
        state[CHECKINS_KEY] = _drop_legacy_seed_checkins(
            list(state.get(CHECKINS_KEY, []) or [])
        )

    state[LEGACY_CLEANUP_KEY] = True


def _render_section_router(lang: str) -> str:
    """Render the horizontal section switcher and return the selected section key."""
    options = [key for key, _ in SECTION_CHOICES]
    label_keys = dict(SECTION_CHOICES)
    # Normalise the session-state value before the widget reads it, so we never
    # pass both `index=` and an existing `key=` (Streamlit forbids that combo).
    current = st.session_state.get(SECTION_KEY, "journal")
    if current not in options:
        st.session_state[SECTION_KEY] = "journal"
    choice = st.radio(
        "journal-section",
        options=options,
        format_func=lambda key: t(label_keys[key], lang),
        horizontal=True,
        label_visibility="collapsed",
        key=SECTION_KEY,
    )
    return str(choice)


def _render_journal_explainer(lang: str) -> None:
    """Render the collapsed "about this journal" expander."""
    with st.expander(t("journal_about_heading", lang), expanded=False):
        st.markdown(t("journal_about_body", lang))


def _render_analysis(analysis: Any, lang: str) -> None:
    """Render one analysis (model, mapping, or object): summary, pro-help warning, distortion cards."""
    st.markdown(f"##### {t('journal_summary_heading', lang)}")
    st.info(_entry_summary(analysis))
    if _entry_needs_pro(analysis):
        st.warning(t("journal_needs_pro", lang))

    distortions = _entry_distortions(analysis)
    if not distortions:
        st.success(t("journal_no_distortions", lang))
        return

    st.markdown(f"##### {t('journal_distortions_heading', lang)}")
    for d in distortions:
        distortion_type = _distortion_type(d)
        with st.container(border=True):
            st.markdown(f"> *{_distortion_phrase(d)}*")
            st.markdown(
                f"**{DISTORTION_LABELS.get(distortion_type, distortion_type)}** - "
                f"{_distortion_explanation(d)}"
            )
            st.markdown(f"**{t('journal_reframe_heading', lang)}:** {_distortion_reframe(d)}")
            st.caption(
                f"**{t('journal_question_heading', lang)}:** "
                f"{_distortion_evidence_question(d)}"
            )


# --- Field accessors --------------------------------------------------------
# Journal entries may be JournalAnalysis/Distortion models, plain dicts (the
# stored records), or arbitrary objects; these read fields uniformly.


def _read_field(obj: Any, name: str, default: Any = None) -> Any:
    """Return ``obj[name]`` for mappings, else ``getattr(obj, name)``, or ``default``."""
    if isinstance(obj, Mapping):
        return obj.get(name, default)
    return getattr(obj, name, default)


def _read_text(obj: Any, name: str) -> str:
    """Like ``_read_field`` but always a string; a missing or None field yields ''."""
    value = _read_field(obj, name)
    return "" if value is None else str(value)


def _entry_distortions(entry: Any) -> Sequence[Any]:
    """Return the entry's distortions, or an empty list when missing/None."""
    return _read_field(entry, "distortions") or []


def _entry_needs_pro(entry: Any) -> bool:
    """Return the entry's ``needs_professional_signal`` flag (False when missing)."""
    return bool(_read_field(entry, "needs_professional_signal", False))


def _entry_timestamp(entry: Any) -> str:
    """Return the entry's ISO timestamp string, or ''."""
    return _read_text(entry, "ts")


def _entry_summary(entry: Any) -> str:
    """Return the entry's summary text, or ''."""
    return _read_text(entry, "summary")


def _entry_mood(entry: Any) -> str:
    """Return the entry's overall mood label, or ''."""
    return _read_text(entry, "overall_mood")


def _entry_text(entry: Any) -> str:
    """Return the raw journal text stored with the entry, or ''."""
    return _read_text(entry, "entry_text")


def _distortion_type(distortion: Any) -> str:
    """Return the distortion's type key (e.g. "catastrophizing"), or ''."""
    return _read_text(distortion, "type")


def _distortion_phrase(distortion: Any) -> str:
    """Return the quoted phrase the distortion was found in, or ''."""
    return _read_text(distortion, "phrase")


def _distortion_explanation(distortion: Any) -> str:
    """Return the distortion's explanation, or ''."""
    return _read_text(distortion, "explanation")


def _distortion_reframe(distortion: Any) -> str:
    """Return the suggested reframe, or ''."""
    return _read_text(distortion, "reframe")


def _distortion_evidence_question(distortion: Any) -> str:
    """Return the evidence-check question, or ''."""
    return _read_text(distortion, "evidence_question")


def _analysis_to_journal_record(analysis: Any, entry_text: str) -> Dict[str, Any]:
    """Convert an analysis (mapping, pydantic model, or object) into a stored record dict.

    The record always gets a fresh local ``ts`` and the stripped
    ``entry_text``. The core fields (distortions, summary, overall_mood,
    needs_professional_signal) are guaranteed to be present and non-None.
    """
    if isinstance(analysis, Mapping):
        record = dict(analysis)
    elif hasattr(analysis, "model_dump"):  # pydantic v2
        record = dict(analysis.model_dump(mode="python"))
    elif hasattr(analysis, "dict"):  # pydantic v1
        record = dict(analysis.dict())
    else:
        record = {
            "overall_mood": getattr(analysis, "overall_mood", "neutral"),
            "distortions": getattr(analysis, "distortions", []) or [],
            "summary": getattr(analysis, "summary", ""),
            "needs_professional_signal": bool(
                getattr(analysis, "needs_professional_signal", False)
            ),
        }
    record["ts"] = datetime.now().isoformat(timespec="seconds")
    record["entry_text"] = entry_text.strip()
    # `setdefault` alone would keep an explicit None, which the accessors
    # would later render as the literal text "None"; replace None as well.
    if record.get("distortions") is None:
        record["distortions"] = []
    if record.get("summary") is None:
        record["summary"] = ""
    if record.get("overall_mood") is None:
        record["overall_mood"] = "neutral"
    if record.get("needs_professional_signal") is None:
        record["needs_professional_signal"] = False
    return record


def _fallback_journal_record(entry_text: str, lang: str, error: Exception) -> Dict[str, Any]:
    """Build a neutral, distortion-free record used when the LLM analysis fails.

    The first 240 characters of the error message are kept in
    ``analysis_error`` for auditing.
    """
    record = _analysis_to_journal_record(
        {
            "overall_mood": "neutral",
            "distortions": [],
            "summary": t("journal_analysis_failed_summary", lang),
            "needs_professional_signal": False,
        },
        entry_text,
    )
    record["analysis_error"] = str(error)[:240]
    return record


def _crisis_journal_record(entry_text: str, lang: str) -> Dict[str, Any]:
    """Record used when crisis language is detected.

    We intentionally do NOT call the LLM on crisis text — the banner is the
    priority. We still append an entry so the user's history is preserved and
    the dashboard / chart keeps moving forward.

    We also tag a default distortion so the distortion chart still reflects
    crisis entries (otherwise they'd show as "no distortions", which is
    misleading — the user's crisis text almost certainly contains cognitive
    distortions, we just can't safely send it to the LLM).
    """
    # Short single-line snippet (at most 120 chars) for the phrase field.
    # Whitespace, including blank lines, is collapsed because _render_analysis
    # shows the phrase as "> *phrase*": a blank line would end the blockquote
    # and leave the italics unclosed.
    snippet = " ".join(entry_text.split())
    if len(snippet) > 120:
        snippet = snippet[:117].rstrip() + "..."
    record = _analysis_to_journal_record(
        {
            "overall_mood": "overwhelmed",
            "distortions": [
                {
                    "type": "catastrophizing",
                    "phrase": snippet,
                    "explanation": t("journal_crisis_distortion_explanation", lang),
                    "reframe": t("journal_crisis_distortion_reframe", lang),
                    "evidence_question": t("journal_crisis_distortion_question", lang),
                },
            ],
            "summary": t("journal_crisis_summary", lang),
            "needs_professional_signal": True,
        },
        entry_text,
    )
    record["crisis_flagged"] = True
    return record


def _append_journal_record(record: Mapping[str, Any]) -> List[Any]:
    """Append a copy of ``record`` to the history and make it the latest analysis.

    Returns the updated entry list; callers use its length in their notices.
    """
    entries = list(st.session_state.get(ENTRIES_KEY, []) or [])
    entries.append(dict(record))
    st.session_state[ENTRIES_KEY] = entries
    st.session_state[LAST_ANALYSIS_KEY] = dict(record)
    return entries


def _short_text(text: str, limit: int = 96) -> str:
    """Collapse whitespace and truncate to ``limit`` chars with "..."; empty text -> "-"."""
    cleaned = " ".join(str(text or "").split())
    if not cleaned:
        return "-"
    if len(cleaned) <= limit:
        return cleaned
    return f"{cleaned[: limit - 3].rstrip()}..."
def _journal_history_rows(lang: str) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] entries = list(st.session_state.get(ENTRIES_KEY, []) or []) for fallback_order, entry in enumerate(entries, start=1): ts = _entry_timestamp(entry) distortions = _entry_distortions(entry) rows.append( { "entry": fallback_order, "logged_at": _timestamp_label(ts), "mood": _entry_mood(entry) or "-", "distortions": len(distortions), "support": t("journal_support_yes", lang) if _entry_needs_pro(entry) else t("journal_support_no", lang), "note": _short_text(_entry_text(entry)), "observation": _short_text(_entry_summary(entry), limit=140), "_sort_ts": _parse_timestamp(ts), "_fallback_order": fallback_order, } ) rows.sort( key=lambda row: ( row.get("_sort_ts") or datetime.max, int(row.get("_fallback_order") or 0), ) ) for idx, row in enumerate(rows, start=1): row["entry"] = idx visible_keys = ( "entry", "logged_at", "mood", "distortions", "support", "note", "observation", ) return [{key: row.get(key) for key in visible_keys} for row in rows] def _render_journal_history(lang: str) -> None: st.markdown(f"##### {t('dashboard_journal_history_heading', lang)}") rows = _journal_history_rows(lang) if not rows: st.caption(t("dashboard_journal_history_empty", lang)) return st.dataframe( rows, use_container_width=True, hide_index=True, column_config={ "entry": st.column_config.NumberColumn(t("journal_col_entry", lang), width="small"), "logged_at": st.column_config.TextColumn(t("journal_col_logged", lang), width="small"), "mood": st.column_config.TextColumn(t("journal_col_mood", lang), width="small"), "distortions": st.column_config.NumberColumn(t("journal_col_distortions", lang), width="small"), "support": st.column_config.TextColumn(t("journal_col_support", lang), width="small"), "note": st.column_config.TextColumn(t("journal_col_note", lang), width="medium"), "observation": st.column_config.TextColumn(t("journal_col_observation", lang), width="large"), }, ) def 
_distortion_chart_rows(entries: Sequence[Any]) -> List[Dict[str, Any]]: counter: Counter[str] = Counter() for entry in entries: for d in _entry_distortions(entry): key = _distortion_type(d) if key: counter[key] += 1 return [ {"distortion": DISTORTION_LABELS.get(k, k), "count": v} for k, v in counter.most_common() ] def _render_chart(lang: str) -> None: entries = list(st.session_state.get(ENTRIES_KEY, []) or []) rows = _distortion_chart_rows(entries) if not rows: if entries: # Entries exist, but Saathi didn't tag any CBT distortions in them. # This is common when the text describes feelings ("I feel lonely") # rather than distorted thoughts ("my life is ruined forever"). st.info( t("dashboard_distortion_no_distortions", lang).format( count=len(entries) ) ) st.caption(t("dashboard_distortion_try_examples", lang)) with st.expander(t("dashboard_distortion_examples_heading", lang), expanded=False): st.markdown(t("dashboard_distortion_examples_body", lang)) else: st.caption(t("dashboard_distortion_empty_hint", lang)) return max_count = max(int(row["count"]) for row in rows) fig = go.Figure() fig.add_trace( go.Bar( x=[row["distortion"] for row in rows], y=[int(row["count"]) for row in rows], text=[str(row["count"]) for row in rows], textposition="outside", marker=dict(color="#2563EB", line=dict(color="#1D4ED8", width=1)), hovertemplate=f"%{{x}}
{t('journal_chart_yaxis', lang)} %{{y}}", ) ) fig.update_layout( title=t("journal_chart_title", lang), xaxis_title=t("journal_chart_xaxis", lang), yaxis_title=t("journal_chart_yaxis", lang), showlegend=False, height=320, margin=dict(l=10, r=10, t=50, b=10), bargap=0.45, plot_bgcolor="rgba(0,0,0,0)", paper_bgcolor="rgba(0,0,0,0)", ) fig.update_xaxes(type="category", tickangle=0) fig.update_yaxes( range=[0, max(2, max_count + 1)], tickmode="linear", dtick=1, rangemode="tozero", ) fig.update_traces(cliponaxis=False) st.plotly_chart(fig, use_container_width=True) _render_distortion_summary_line(entries, lang) def _render_screener_summary_line(kind: ScreenerKind, lang: str) -> None: summary = _summarize_screener_history(kind) if not summary: return config = _screener_config(kind) name = str(config.get("name", kind.upper())) if summary["direction"] == "single": st.info( t("screener_summary_single_attempt", lang).format( name=name, latest=summary["latest"], max=summary["max"], band=t(f"band_{summary['latest_band']}", lang), ) ) return delta = int(summary["delta"]) delta_signed = f"{delta:+d}" direction_label = t(f"screener_direction_{summary['direction']}", lang) st.info( t("screener_summary_line", lang).format( name=name, attempts=summary["attempts"], latest=summary["latest"], max=summary["max"], band=t(f"band_{summary['latest_band']}", lang), delta_signed=delta_signed, direction=direction_label, ) ) def _render_checkin_summary_line(lang: str) -> None: checkins = list(st.session_state.get(CHECKINS_KEY, []) or []) summary = _summarize_checkins(checkins) if not summary: return if summary["count"] < 2: st.caption(t("checkin_summary_single", lang)) return def _fmt(value: Optional[float]) -> str: if value is None: return "-" return f"{value:.1f}" st.info( t("checkin_summary_line", lang).format( count=summary["count"], mood_avg=_fmt(summary["mood_avg"]), sleep_avg=_fmt(summary["sleep_avg"]), stress_avg=_fmt(summary["stress_avg"]), 
mood_trend=t(f"checkin_trend_{summary['mood_trend']}", lang), stress_trend=t(f"checkin_trend_{summary['stress_trend']}", lang), ) ) def _render_distortion_summary_line(entries: Sequence[Any], lang: str) -> None: summary = _summarize_distortions(entries) if not summary: return top_label = DISTORTION_LABELS.get(summary["top_type"], summary["top_type"]) if summary.get("second_type"): second_label = DISTORTION_LABELS.get( summary["second_type"], summary["second_type"] ) line = t("distortion_summary_line", lang).format( top=top_label, top_count=summary["top_count"], total_entries=summary["total_entries"], second=second_label, second_count=summary["second_count"], ) else: line = t("distortion_summary_single_line", lang).format( top=top_label, top_count=summary["top_count"], total_entries=summary["total_entries"], ) tip = t(summary["tip_key"], lang) st.info(f"{line}\n\n**{t('distortion_summary_tip_prefix', lang)}** {tip}") def _render_journal(lang: str) -> None: notice = st.session_state.get(JOURNAL_NOTICE_KEY) if isinstance(notice, Mapping) and notice.get("message"): level = notice.get("level") if level == "warning": st.warning(str(notice["message"])) else: st.success(str(notice["message"])) st.session_state[JOURNAL_NOTICE_KEY] = None input_key = f"cognitive_journal_input_v{st.session_state[JOURNAL_WIDGET_VERSION_KEY]}" entry = st.text_area( "journal_entry", placeholder=t("journal_input_placeholder", lang), label_visibility="collapsed", height=140, key=input_key, ) if st.button(t("journal_send_button", lang), type="primary", key="cognitive_journal_send_button"): entry_text = entry.strip() if entry_text: if check_crisis(entry_text): # Crisis detected: show banner immediately, and STILL log the # entry (with the pro-signal flag set) so the user's record is # preserved and the dashboard keeps working. 
render_crisis_banner(lang) record = _crisis_journal_record(entry_text, lang) entries = _append_journal_record(record) st.session_state[JOURNAL_NOTICE_KEY] = { "level": "warning", "message": t("journal_crisis_logged", lang).format(count=len(entries)), } st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1 return with st.spinner("..."): try: analysis = chat_structured( module=MODULE_NAME, user_text=entry_text, language_name=claude_language_name(lang), schema=JournalAnalysis, ) record = _analysis_to_journal_record(analysis, entry_text) entries = _append_journal_record(record) st.session_state[JOURNAL_NOTICE_KEY] = { "level": "success", "message": t("journal_logged", lang).format(count=len(entries)), } except Exception as e: record = _fallback_journal_record(entry_text, lang, e) entries = _append_journal_record(record) st.session_state[JOURNAL_NOTICE_KEY] = { "level": "warning", "message": t("journal_saved_without_analysis", lang).format( count=len(entries) ), } st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1 st.rerun() if st.session_state[LAST_ANALYSIS_KEY] is not None: _render_analysis(st.session_state[LAST_ANALYSIS_KEY], lang) _render_recent_journal_entries(lang) def _render_recent_journal_entries(lang: str) -> None: """Compact 'previous entries' list shown directly under the Journal tab.""" entries = list(st.session_state.get(ENTRIES_KEY, []) or []) if len(entries) <= 1: return with st.expander(t("journal_recent_heading", lang), expanded=False): rows = _journal_history_rows(lang) if not rows: st.caption(t("journal_recent_empty", lang)) return recent_rows = list(reversed(rows))[:5] st.dataframe( recent_rows, use_container_width=True, hide_index=True, column_config={ "entry": st.column_config.NumberColumn(t("journal_col_entry", lang), width="small"), "logged_at": st.column_config.TextColumn(t("journal_col_logged", lang), width="small"), "mood": st.column_config.TextColumn(t("journal_col_mood", lang), width="small"), "distortions": 
st.column_config.NumberColumn(t("journal_col_distortions", lang), width="small"), "support": st.column_config.TextColumn(t("journal_col_support", lang), width="small"), "note": st.column_config.TextColumn(t("journal_col_note", lang), width="medium"), "observation": st.column_config.TextColumn(t("journal_col_observation", lang), width="large"), }, ) def _score_screener(answers: List[int]) -> int: return sum(int(a) for a in answers) def _band_for(score: int, bands: List[dict]) -> str: for band in bands: if score <= int(band["max"]): return str(band["key"]) return str(bands[-1]["key"]) if bands else "none" # --- Deterministic summary helpers (no Streamlit, unit-testable) ----------- def _direction_from_delta(delta: int) -> str: """Lower score = improving for both PHQ-9 and GAD-7.""" if delta <= -2: return "improving" if delta >= 2: return "worsening" return "stable" def _summarize_screener_history(kind: ScreenerKind) -> Optional[Dict[str, Any]]: """Return a short structured summary of a screener's history. Returns None when history is empty. When exactly one attempt exists, `delta` and `direction` are None / "single". 
""" history = list(st.session_state.get(_history_key(kind), []) or []) if not history: return None scores = [_coerce_int(item.get("score")) for item in history] scores = [s for s in scores if s is not None] if not scores: return None config = _screener_config(kind) bands = list(config.get("bands") or []) latest_score = int(scores[-1]) first_score = int(scores[0]) latest_band = _band_for(latest_score, bands) max_score = _max_score(config) if len(scores) == 1: return { "attempts": 1, "latest": latest_score, "latest_band": latest_band, "max": max_score, "first": latest_score, "delta": 0, "direction": "single", } delta = latest_score - first_score return { "attempts": len(scores), "latest": latest_score, "latest_band": latest_band, "max": max_score, "first": first_score, "delta": delta, "direction": _direction_from_delta(delta), } def _trend_label(prev_avg: Optional[float], last_avg: Optional[float]) -> str: if prev_avg is None or last_avg is None: return "flat" diff = last_avg - prev_avg if diff >= 0.75: return "up" if diff <= -0.75: return "down" return "flat" def _mean(values: Sequence[float]) -> Optional[float]: nums = [float(v) for v in values if v is not None] return sum(nums) / len(nums) if nums else None def _summarize_checkins( checkins: Sequence[Mapping[str, Any]], ) -> Optional[Dict[str, Any]]: sorted_items = list(_sorted_checkins(checkins)) if not sorted_items: return None moods = [float(item.get("mood")) for item in sorted_items if item.get("mood") is not None] sleeps = [ float(item.get("sleep_h")) for item in sorted_items if item.get("sleep_h") is not None ] stresses = [ float(item.get("stress")) for item in sorted_items if item.get("stress") is not None ] last_moods = moods[-3:] prev_moods = moods[-6:-3] if len(moods) >= 6 else moods[:-3] last_stress = stresses[-3:] prev_stress = stresses[-6:-3] if len(stresses) >= 6 else stresses[:-3] return { "count": len(sorted_items), "mood_avg": _mean(moods), "sleep_avg": _mean(sleeps), "stress_avg": 
_mean(stresses), "mood_trend": _trend_label(_mean(prev_moods), _mean(last_moods)), "stress_trend": _trend_label(_mean(prev_stress), _mean(last_stress)), } def _summarize_distortions(entries: Sequence[Any]) -> Optional[Dict[str, Any]]: counter: Counter[str] = Counter() total_distortions = 0 for entry in entries: for d in _entry_distortions(entry): key = _distortion_type(d) if key: counter[key] += 1 total_distortions += 1 if total_distortions == 0: return None top_entries = counter.most_common(2) top_type, top_count = top_entries[0] second_type, second_count = (top_entries[1] if len(top_entries) > 1 else (None, 0)) return { "top_type": top_type, "top_count": int(top_count), "second_type": second_type, "second_count": int(second_count), "total_entries": len(entries), "total_distortions": total_distortions, "tip_key": f"distortion_tip_{top_type}", } def _screener_config(kind: ScreenerKind) -> Dict[str, Any]: data = load_screeners() return dict(data.get(kind) or {}) def _max_score(config: Mapping[str, Any]) -> int: bands = config.get("bands") or [] return int(bands[-1]["max"]) if bands else 0 def _response_labels(config: Mapping[str, Any], lang: str) -> List[str]: labels = config.get("response_labels") or {} selected = labels.get(lang) or labels.get("en") or [] return list(selected) def _localized_item(item: Mapping[str, Any], lang: str) -> str: if item.get(lang): return str(item[lang]) text = str(item.get("en", "")) if lang != "en": return f"{text} (translation pending)" return text def _answers_key(kind: ScreenerKind) -> str: return PHQ9_ANSWERS_KEY if kind == "phq9" else GAD7_ANSWERS_KEY def _score_key(kind: ScreenerKind) -> str: return PHQ9_SCORE_KEY if kind == "phq9" else GAD7_SCORE_KEY def _taken_key(kind: ScreenerKind) -> str: return PHQ9_TAKEN_AT_KEY if kind == "phq9" else GAD7_TAKEN_AT_KEY def _history_key(kind: ScreenerKind) -> str: return PHQ9_HISTORY_KEY if kind == "phq9" else GAD7_HISTORY_KEY def _version_key(kind: ScreenerKind) -> str: return 
PHQ9_WIDGET_VERSION_KEY if kind == "phq9" else GAD7_WIDGET_VERSION_KEY def _next_screener_log_order() -> int: return ( len(st.session_state.get(PHQ9_HISTORY_KEY, []) or []) + len(st.session_state.get(GAD7_HISTORY_KEY, []) or []) + 1 ) def _ensure_answers(kind: ScreenerKind, n_items: int) -> List[int]: key = _answers_key(kind) answers = list(st.session_state.get(key) or []) if len(answers) != n_items: answers = [0] * n_items st.session_state[key] = answers return answers def _record_screener(kind: ScreenerKind, config: Mapping[str, Any], answers: List[int]) -> None: score = _score_screener(answers) band = _band_for(score, list(config.get("bands") or [])) ts = datetime.now().isoformat(timespec="seconds") log_order = _next_screener_log_order() q9_positive = bool(kind == "phq9" and len(answers) >= 9 and answers[8] > 0) st.session_state[_score_key(kind)] = score st.session_state[_answers_key(kind)] = list(answers) st.session_state[_taken_key(kind)] = ts if kind == "phq9": st.session_state[PHQ9_ITEM9_KEY] = q9_positive history = list(st.session_state.get(_history_key(kind), [])) history.append( { "ts": ts, "log_order": log_order, "score": score, "band": band, "answers": list(answers), "q9_positive": q9_positive, } ) st.session_state[_history_key(kind)] = history def _parse_timestamp(ts: str) -> Optional[datetime]: if not ts: return None try: return datetime.fromisoformat(ts) except ValueError: return None def _timestamp_label(ts: str) -> str: if not ts: return "-" dt = _parse_timestamp(ts) if dt is None: return ts if dt.second: return dt.strftime("%d %b, %H:%M:%S") return dt.strftime("%d %b, %H:%M") def _render_screener_result(kind: ScreenerKind, config: Mapping[str, Any], score: int, lang: str) -> None: band = _band_for(score, list(config.get("bands") or [])) max_score = _max_score(config) band_label = t(f"band_{band}", lang) interp = t(f"band_{band}_interpretation", lang) with st.container(border=True): cols = st.columns(3) cols[0].metric(t("screener_score_label", 
lang), f"{score}/{max_score}") cols[1].metric(t("screener_band_label", lang), band_label) cols[2].metric(t("screener_last_taken", lang), _timestamp_label(st.session_state[_taken_key(kind)])) st.write(interp) st.caption(t("screener_disclaimer", lang)) def _render_single_screener_timeline(kind: ScreenerKind, lang: str) -> None: rows = _screener_history_rows(kind, lang) if not rows: return rows.sort( key=lambda row: ( row.get("_sort_ts") or datetime.max, int(row.get("attempt") or 0), ) ) plot_rows = [row for row in rows if row.get("score") is not None] if not plot_rows: return config = _screener_config(kind) fig = go.Figure() fig.add_trace( go.Scatter( x=[int(row["attempt"]) for row in plot_rows], y=[int(row["score"]) for row in plot_rows], mode="lines+markers+text", name=config.get("name", kind.upper()), text=[str(row["score"]) for row in plot_rows], textposition="top center", customdata=[[row.get("taken_at", "-")] for row in plot_rows], hovertemplate=( "%{fullData.name}
" "Attempt %{x}
" "Score %{y}
" "Taken %{customdata[0]}" ), marker=dict(size=14), line=dict(width=3), ) ) max_attempt = max(int(row["attempt"]) for row in plot_rows) fig.update_layout( title=f"{config.get('name', kind.upper())} {t('dashboard_screener_timeline_heading', lang)}", xaxis_title="Attempt", yaxis_title="Score", height=260, margin=dict(l=10, r=10, t=50, b=10), showlegend=False, ) fig.update_xaxes( type="linear", range=[0.5, max_attempt + 0.5], tickmode="linear", dtick=1, ) fig.update_yaxes(range=[0, _max_score(config)], rangemode="tozero") st.plotly_chart(fig, use_container_width=True) def _render_screener(kind: ScreenerKind, lang: str) -> None: config = _screener_config(kind) if not config: st.error("Screener data could not be loaded.") return items = list(config.get("items") or []) answers = _ensure_answers(kind, len(items)) labels = _response_labels(config, lang) if len(labels) != 4: labels = ["Not at all", "Several days", "More than half the days", "Nearly every day"] st.subheader(t(f"{kind}_header", lang)) st.caption(t(f"{kind}_sub", lang)) st.caption(f"**{t('screener_timeframe_prefix', lang)}:** {config.get('timeframe', '')}") if lang != "en": st.info(t("screener_translation_pending", lang)) st.caption(t("screener_source_caption", lang)) version = st.session_state[_version_key(kind)] for idx, item in enumerate(items): answer = st.radio( _localized_item(item, lang), options=[0, 1, 2, 3], index=int(answers[idx]), format_func=lambda value, labels=labels: f"{value} - {labels[value]}", horizontal=True, key=f"cognitive_journal_{kind}_v{version}_item_{idx}", ) answers[idx] = int(answer) st.session_state[_answers_key(kind)] = answers button_cols = st.columns([1, 1, 4]) with button_cols[0]: score_clicked = st.button( t(f"{kind}_cta", lang), type="primary", key=f"cognitive_journal_{kind}_score_button", disabled=st.session_state[_score_key(kind)] is not None, ) with button_cols[1]: retake_clicked = st.button( t(f"{kind}_retake", lang), key=f"cognitive_journal_{kind}_retake", 
disabled=st.session_state[_score_key(kind)] is None, ) if score_clicked: _record_screener(kind, config, answers) if retake_clicked: st.session_state[_answers_key(kind)] = [0] * len(items) st.session_state[_score_key(kind)] = None st.session_state[_taken_key(kind)] = "" if kind == "phq9": st.session_state[PHQ9_ITEM9_KEY] = False st.session_state[_version_key(kind)] += 1 st.rerun() score = st.session_state[_score_key(kind)] if score is not None: _render_screener_result(kind, config, int(score), lang) _render_single_screener_timeline(kind, lang) _render_screener_summary_line(kind, lang) def _render_phq9(lang: str) -> None: _render_screener("phq9", lang) def _render_gad7(lang: str) -> None: _render_screener("gad7", lang) def _append_checkin(mood: int, sleep_h: float, stress: int) -> None: checkins = list(st.session_state.get(CHECKINS_KEY, [])) checkins.append( { "ts": datetime.now().isoformat(timespec="seconds"), "mood": int(mood), "sleep_h": float(sleep_h), "stress": int(stress), } ) st.session_state[CHECKINS_KEY] = checkins def _render_daily_checkin(lang: str) -> None: st.subheader(t("checkin_header", lang)) st.caption(t("checkin_sub", lang)) mood = st.slider(t("checkin_mood_label", lang), 1, 10, 6, key="cognitive_journal_checkin_mood") sleep_h = st.slider( t("checkin_sleep_label", lang), 0.0, 12.0, 7.0, 0.5, key="cognitive_journal_checkin_sleep", ) stress = st.slider(t("checkin_stress_label", lang), 1, 10, 5, key="cognitive_journal_checkin_stress") if st.button(t("checkin_submit", lang), type="primary", key="cognitive_journal_checkin_submit"): _append_checkin(mood, sleep_h, stress) st.success(t("checkin_logged", lang)) st.markdown(f"##### {t('checkin_history_heading', lang)}") checkins = st.session_state.get(CHECKINS_KEY, []) if not checkins: st.caption(t("checkin_empty_history", lang)) return for item in reversed(_sorted_checkins(checkins)[-5:]): st.write( f"{_timestamp_label(str(item.get('ts', '')))} · " f"mood {item.get('mood')}/10 · sleep {item.get('sleep_h')}h 
· stress {item.get('stress')}/10" ) _render_checkin_trend(lang) def _coerce_int(value: Any) -> Optional[int]: if value is None or value == "": return None try: return int(value) except (TypeError, ValueError): return None def _recent_mood_avg(checkins: Sequence[Mapping[str, Any]]) -> Optional[float]: if not checkins: return None recent = checkins[-3:] moods = [] for item in recent: try: moods.append(float(item.get("mood"))) except (TypeError, ValueError): continue if not moods: return None return sum(moods) / len(moods) def _aggregate_session_signals() -> dict: entries = list(st.session_state.get(ENTRIES_KEY, [])) recent_entries = entries[-3:] distortion_total = sum(len(_entry_distortions(entry)) for entry in recent_entries) distortion_density = distortion_total / len(recent_entries) if recent_entries else 0.0 pro_signal_count = sum(1 for entry in recent_entries if _entry_needs_pro(entry)) checkins = _sorted_checkins(st.session_state.get(CHECKINS_KEY, []) or []) chat_history = st.session_state.get("saathi_chat_history", []) or [] chat_turns = sum(1 for msg in chat_history if isinstance(msg, Mapping) and msg.get("role") == "user") student_events = 1 if st.session_state.get("student_response") or st.session_state.get("student_situation") else 0 legal_lookups = 1 if st.session_state.get("legal_aid_response") or st.session_state.get("legal_aid_sections") else 0 soothe_poems = 1 if st.session_state.get("soothe_poem") else 0 return { "phq9_score": _coerce_int(st.session_state.get(PHQ9_SCORE_KEY)), "gad7_score": _coerce_int(st.session_state.get(GAD7_SCORE_KEY)), "phq9_item9_positive": bool(st.session_state.get(PHQ9_ITEM9_KEY)), "recent_mood_avg": _recent_mood_avg(checkins), "distortion_density": distortion_density, "pro_signal_count": pro_signal_count, "crisis_history": False, "chat_turns": chat_turns, "student_events": student_events, "legal_lookups": legal_lookups, "soothe_poems": soothe_poems, } def _score_at_least(score: Optional[int], threshold: int) -> bool: return 
score is not None and score >= threshold def _compute_stepped_care_recommendation(signals: dict) -> Tuple[str, str]: phq9 = _coerce_int(signals.get("phq9_score")) gad7 = _coerce_int(signals.get("gad7_score")) recent_mood_avg = signals.get("recent_mood_avg") distortion_density = float(signals.get("distortion_density") or 0.0) pro_signal_count = int(signals.get("pro_signal_count") or 0) low_mood_with_distortions = ( recent_mood_avg is not None and float(recent_mood_avg) <= 3 and distortion_density >= 2 ) if ( _score_at_least(phq9, 20) or _score_at_least(gad7, 15) or pro_signal_count >= 2 or bool(signals.get("phq9_item9_positive")) or bool(signals.get("crisis_history")) ): return "urgent_professional", "tier_reason_urgent_professional" if _score_at_least(phq9, 15) or _score_at_least(gad7, 10) or low_mood_with_distortions: return "guided_support", "tier_reason_guided_support" if _score_at_least(phq9, 10) or _score_at_least(gad7, 5) or distortion_density >= 1.5: return "self_help", "tier_reason_self_help" return "self_care", "tier_reason_self_care" def _render_stepped_care_card(lang: str, tier: str, rationale_key: str) -> None: title = t(f"tier_{tier}", lang) body = t(f"tier_{tier}_body", lang) rationale = t(rationale_key, lang) message = f"**{title}**\n\n{body}\n\n_{rationale}_" if tier == "urgent_professional": st.error(message) elif tier == "guided_support": st.warning(message) elif tier == "self_help": st.info(message) else: st.success(message) st.caption(t("tier_disclaimer", lang)) def _render_screener_metric(kind: ScreenerKind, lang: str) -> None: config = _screener_config(kind) score = st.session_state.get(_score_key(kind)) if score is None: st.info(t(f"{kind}_not_taken_nudge", lang)) return score_int = int(score) band = _band_for(score_int, list(config.get("bands") or [])) with st.container(border=True): st.metric(config.get("name", kind.upper()), f"{score_int}/{_max_score(config)}") st.caption(f"{t(f'band_{band}', lang)} · {t('screener_last_taken', lang)}: 
{_timestamp_label(st.session_state[_taken_key(kind)])}") def _render_screener_snapshot(lang: str) -> None: st.markdown(f"##### {t('dashboard_screener_history_heading', lang)}") cols = st.columns(2) with cols[0]: _render_screener_metric("phq9", lang) with cols[1]: _render_screener_metric("gad7", lang) def _screener_history_rows(kind: ScreenerKind, lang: str) -> List[Dict[str, Any]]: config = _screener_config(kind) rows = [] for idx, item in enumerate(st.session_state.get(_history_key(kind), []) or [], start=1): score = _coerce_int(item.get("score")) band = str(item.get("band") or "none") ts = str(item.get("ts", "")) log_order = _coerce_int(item.get("log_order")) rows.append( { "log": log_order, "attempt": idx, "screener": config.get("name", kind.upper()), "score": score, "band": t(f"band_{band}", lang), "taken_at": _timestamp_label(ts), "_raw_ts": ts, "_sort_ts": _parse_timestamp(ts), "_fallback_order": idx, } ) return rows def _combined_screener_history_rows(lang: str) -> List[Dict[str, Any]]: rows = [] for global_idx, row in enumerate( _screener_history_rows("phq9", lang) + _screener_history_rows("gad7", lang), start=1, ): row = dict(row) row["_fallback_order"] = row.get("log") or global_idx rows.append(row) rows.sort( key=lambda row: ( row.get("_sort_ts") or datetime.max, int(row.get("_fallback_order") or 0), ) ) for idx, row in enumerate(rows, start=1): row["log"] = idx return rows def _visible_screener_history_rows(rows: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]: visible_keys = ("log", "screener", "attempt", "score", "band", "taken_at") return [{key: row.get(key) for key in visible_keys} for row in rows] def _render_screener_timeline(lang: str) -> None: st.markdown(f"##### {t('dashboard_screener_timeline_heading', lang)}") rows = _combined_screener_history_rows(lang) if not rows: st.caption(t("dashboard_screener_timeline_empty", lang)) return plot_rows = [row for row in rows if row.get("score") is not None] if plot_rows: fig = go.Figure() 
screener_names = [] for row in plot_rows: name = str(row.get("screener") or "") if name and name not in screener_names: screener_names.append(name) for screener_name in screener_names: screener_rows = [row for row in plot_rows if row.get("screener") == screener_name] fig.add_trace( go.Scatter( x=[int(row["log"]) for row in screener_rows], y=[int(row["score"]) for row in screener_rows], mode="lines+markers+text", name=screener_name, text=[str(row["score"]) for row in screener_rows], textposition="top center", customdata=[ [row.get("taken_at", "-"), row.get("attempt", "-")] for row in screener_rows ], hovertemplate=( "%{fullData.name}
" "Log %{x}
" "Score %{y}
" "Attempt %{customdata[1]}
" "Taken %{customdata[0]}" ), marker=dict(size=14), line=dict(width=3), ) ) max_log = max(int(row["log"]) for row in plot_rows) fig.update_layout( xaxis_title="Log order", yaxis_title="Score", height=280, margin=dict(l=10, r=10, t=50, b=10), title=t("dashboard_screener_timeline_heading", lang), ) fig.update_xaxes( type="linear", range=[0.5, max_log + 0.5], tickmode="linear", dtick=1, ) fig.update_yaxes(range=[0, 27], rangemode="tozero") st.plotly_chart(fig, use_container_width=True) st.dataframe(_visible_screener_history_rows(rows), use_container_width=True, hide_index=True) def _score_snapshot(score: Optional[int], max_score: int) -> str: if score is None: return "-" return f"{score}/{max_score}" def _render_session_snapshot(lang: str) -> None: signals = _aggregate_session_signals() tier, _ = _compute_stepped_care_recommendation(signals) phq9_config = _screener_config("phq9") gad7_config = _screener_config("gad7") checkins = st.session_state.get(CHECKINS_KEY, []) with st.container(border=True): st.caption(f"**{t('journal_snapshot_heading', lang)}**") cols = st.columns(4) cols[0].metric( "PHQ-9", _score_snapshot(signals.get("phq9_score"), _max_score(phq9_config)), ) cols[1].metric( "GAD-7", _score_snapshot(signals.get("gad7_score"), _max_score(gad7_config)), ) cols[2].metric(t("journal_snapshot_checkins", lang), len(checkins)) cols[3].metric(t("journal_snapshot_care_tier", lang), t(f"tier_{tier}", lang)) st.caption(t("journal_snapshot_hint", lang)) def _render_checkin_trend(lang: str) -> None: st.markdown(f"##### {t('dashboard_checkin_heading', lang)}") checkins = list(st.session_state.get(CHECKINS_KEY, []) or []) rows = _checkin_trend_rows(checkins) if not rows: st.caption(t("checkin_trend_caption_empty", lang)) return fig = go.Figure() metric_names = [] for row in rows: name = str(row.get("metric") or "") if name and name not in metric_names: metric_names.append(name) for metric_name in metric_names: metric_rows = [row for row in rows if row.get("metric") == 
metric_name] fig.add_trace( go.Scatter( x=[int(row["check_in"]) for row in metric_rows], y=[float(row["value"]) for row in metric_rows], mode="lines+markers+text", name=metric_name, text=[ str(int(row["value"])) if float(row["value"]).is_integer() else str(row["value"]) for row in metric_rows ], textposition="top center", customdata=[[row.get("logged_at", "-")] for row in metric_rows], hovertemplate=( "%{fullData.name}
" "Check-in %{x}
" "Value %{y}
" "Logged %{customdata[0]}" ), marker=dict(size=14), line=dict(width=3), ) ) max_checkin = max(int(row["check_in"]) for row in rows) fig.update_layout( xaxis_title="Check-in order", yaxis_title="Value", height=300, margin=dict(l=10, r=10, t=50, b=10), title=t("checkin_trend_title", lang), ) fig.update_xaxes( type="linear", range=[0.5, max_checkin + 0.5], tickmode="linear", dtick=1, ) fig.update_yaxes(range=[0, 12], rangemode="tozero") st.plotly_chart(fig, use_container_width=True) _render_checkin_summary_line(lang) def _sorted_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]: indexed_rows = list(enumerate(checkins)) indexed_rows.sort( key=lambda indexed_item: ( _parse_timestamp(str(indexed_item[1].get("ts", ""))) or datetime.max, indexed_item[0], ) ) return [item for _, item in indexed_rows] def _checkin_trend_rows(checkins: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]: rows: List[Dict[str, Any]] = [] for idx, item in enumerate(_sorted_checkins(checkins), start=1): for metric_key, label in ( ("mood", "Mood"), ("sleep_h", "Sleep hours"), ("stress", "Stress"), ): value = item.get(metric_key) if value is None: continue try: value = float(value) except (TypeError, ValueError): continue rows.append( { "check_in": idx, "metric": label, "value": value, "logged_at": _timestamp_label(str(item.get("ts", ""))), } ) return rows def _checkin_history_rows(lang: str) -> List[Dict[str, Any]]: rows = [] for idx, item in enumerate(_sorted_checkins(st.session_state.get(CHECKINS_KEY, []) or []), start=1): rows.append( { "check_in": idx, "logged_at": _timestamp_label(str(item.get("ts", ""))), "mood": item.get("mood"), "sleep_h": item.get("sleep_h"), "stress": item.get("stress"), } ) return rows def _render_checkin_history(lang: str) -> None: rows = _checkin_history_rows(lang) if not rows: return st.markdown(f"##### {t('dashboard_checkin_history_heading', lang)}") st.dataframe(rows, use_container_width=True, hide_index=True) def 
_render_crossmodule_footer(lang: str) -> None: signals = _aggregate_session_signals() rows = [] if signals["chat_turns"]: rows.append(f"{signals['chat_turns']} chat turn(s) in {t('tab_chat', lang)}") if signals["student_events"]: rows.append(f"{signals['student_events']} student support event(s)") if signals["legal_lookups"]: rows.append(f"{signals['legal_lookups']} legal lookup(s)") if signals["soothe_poems"]: rows.append(f"{signals['soothe_poems']} poem(s) generated") with st.expander(t("crossmodule_footer_heading", lang)): if not rows: st.caption(t("crossmodule_footer_none", lang)) for row in rows: st.caption(row) def _render_insights_dashboard(lang: str) -> None: st.subheader(t("dashboard_header", lang)) st.caption(t("dashboard_sub", lang)) signals = _aggregate_session_signals() tier, rationale_key = _compute_stepped_care_recommendation(signals) _render_stepped_care_card(lang, tier, rationale_key) st.caption(t("screener_disclaimer", lang)) _render_screener_snapshot(lang) _render_screener_timeline(lang) st.divider() _render_checkin_trend(lang) _render_checkin_history(lang) st.divider() _render_journal_history(lang) st.divider() st.markdown(f"##### {t('dashboard_distortion_heading', lang)}") _render_chart(lang) st.divider() _render_crossmodule_footer(lang) def render(lang: str) -> None: _init_state() st.header(t("journal_header", lang)) st.caption(t("journal_sub", lang)) _render_journal_explainer(lang) section = _render_section_router(lang) if section != "dashboard": _render_session_snapshot(lang) if section == "journal": _render_journal(lang) elif section == "phq9": _render_phq9(lang) elif section == "gad7": _render_gad7(lang) elif section == "checkin": _render_daily_checkin(lang) elif section == "dashboard": _render_insights_dashboard(lang) # --------------------------------------------------------------------------- # Cross-module memory — read-only digest exported to Saathi Chat # --------------------------------------------------------------------------- # # 
Design contract: # - This function is the ONLY way other modules may read Thought Diary # state. It returns a short English digest string, never raw entry text. # - No PII leaves the module: only counts, top category names, scores, trends. # - Returns "" (empty string) when the user has no journal data yet, so the # calling module can insert the string into a prompt template unconditionally # without breaking the layout when it's the user's first turn. # - Deterministic: no LLM call, no randomness. Safe to invoke every chat turn. # - English on purpose — this is a system-prompt block for the LLM, not for # the user. The LLM is already instructed to respond in the user's language. def get_cognitive_journal_context() -> str: """Return a short English digest of the user's Thought Diary activity. Designed to be injected into Saathi Chat's system prompt so the chat module can reference the user's recent journal / screener / check-in activity without holding its own copy of that state. Empty string if no data is available yet. 
""" entries = list(st.session_state.get(ENTRIES_KEY, []) or []) checkins = list(st.session_state.get(CHECKINS_KEY, []) or []) phq9_history = list(st.session_state.get(PHQ9_HISTORY_KEY, []) or []) gad7_history = list(st.session_state.get(GAD7_HISTORY_KEY, []) or []) if not (entries or checkins or phq9_history or gad7_history): return "" lines: List[str] = [] # --- Journal entries + distortion pattern --- if entries: distortion_summary = _summarize_distortions(entries) crisis_flags = sum(1 for e in entries if _entry_needs_pro(e)) mood_counter: Counter[str] = Counter() for e in entries: mood = _entry_mood(e) if mood: mood_counter[mood] += 1 top_mood = mood_counter.most_common(1) top_mood_label = top_mood[0][0] if top_mood else None journal_line = f"- Journal entries this session: {len(entries)}" if top_mood_label: journal_line += f" (most common mood tag: {top_mood_label})" lines.append(journal_line) # --- Include actual entry text so Chat can reference specifics --- recent = entries[-5:] # last 5 entries max lines.append("- Recent journal entries (newest last):") for idx, entry in enumerate(recent, 1): text = _entry_text(entry) if len(text) > 200: text = text[:197].rstrip() + "..." mood = _entry_mood(entry) or "unknown" distortions_list = _entry_distortions(entry) distortion_types = [_distortion_type(d) for d in distortions_list if _distortion_type(d)] dtypes_str = ", ".join(distortion_types) if distortion_types else "none" summary = _entry_summary(entry) if len(summary) > 150: summary = summary[:147].rstrip() + "..." lines.append( f' {idx}. 
"{text}" ' f"[mood: {mood}, distortions: {dtypes_str}] " f"Observation: {summary}" ) if distortion_summary: top_type = distortion_summary["top_type"] top_count = distortion_summary["top_count"] top_label = DISTORTION_LABELS.get(top_type, top_type) total = distortion_summary["total_distortions"] dist_line = ( f"- Top cognitive distortion: {top_label} " f"({top_count}x of {total} total distortions tagged)" ) if distortion_summary.get("second_type"): second_label = DISTORTION_LABELS.get( distortion_summary["second_type"], distortion_summary["second_type"], ) dist_line += f"; second most common: {second_label}" lines.append(dist_line) else: lines.append( "- No CBT distortions were tagged in those entries " "(the entries mostly describe feelings, not distorted thoughts)." ) if crisis_flags: lines.append( f"- {crisis_flags} entry/entries were flagged as needing " "professional support (crisis language or LLM-flagged severity)." ) # --- PHQ-9 history --- phq9_summary = _summarize_screener_history("phq9") if phq9_summary: if phq9_summary["direction"] == "single": lines.append( f"- PHQ-9 (depression): 1 attempt, score " f"{phq9_summary['latest']}/{phq9_summary['max']} " f"(band: {phq9_summary['latest_band']})." ) else: lines.append( f"- PHQ-9 (depression): {phq9_summary['attempts']} attempts, " f"latest {phq9_summary['latest']}/{phq9_summary['max']} " f"(band: {phq9_summary['latest_band']}), " f"change from first attempt: {int(phq9_summary['delta']):+d} " f"({phq9_summary['direction']})." ) # --- GAD-7 history --- gad7_summary = _summarize_screener_history("gad7") if gad7_summary: if gad7_summary["direction"] == "single": lines.append( f"- GAD-7 (anxiety): 1 attempt, score " f"{gad7_summary['latest']}/{gad7_summary['max']} " f"(band: {gad7_summary['latest_band']})." 
) else: lines.append( f"- GAD-7 (anxiety): {gad7_summary['attempts']} attempts, " f"latest {gad7_summary['latest']}/{gad7_summary['max']} " f"(band: {gad7_summary['latest_band']}), " f"change from first attempt: {int(gad7_summary['delta']):+d} " f"({gad7_summary['direction']})." ) # --- Daily check-ins --- checkin_summary = _summarize_checkins(checkins) if checkin_summary and checkin_summary["count"] >= 1: def _f(value: Optional[float]) -> str: return "-" if value is None else f"{value:.1f}" if checkin_summary["count"] >= 2: lines.append( f"- Daily check-ins: {checkin_summary['count']} logged. " f"Mood avg {_f(checkin_summary['mood_avg'])}/10 " f"(trend: {checkin_summary['mood_trend']}), " f"sleep avg {_f(checkin_summary['sleep_avg'])} h, " f"stress avg {_f(checkin_summary['stress_avg'])}/10 " f"(trend: {checkin_summary['stress_trend']})." ) else: lines.append( f"- Daily check-ins: 1 logged so far " f"(mood {_f(checkin_summary['mood_avg'])}/10, " f"sleep {_f(checkin_summary['sleep_avg'])} h, " f"stress {_f(checkin_summary['stress_avg'])}/10). " f"Not enough data for a trend yet." ) # --- Stepped-care tier (deterministic rule-based) --- try: signals = _aggregate_session_signals() tier, _reason_key = _compute_stepped_care_recommendation(signals) tier_human = { "self_care": "self-care", "self_help": "self-help tools", "guided_support": "guided support", "urgent_professional": "urgent professional help", }.get(tier, tier) lines.append(f"- Current stepped-care recommendation: {tier_human}.") except Exception: # Stepped-care is nice-to-have; never fail the digest on a rule-engine bug. pass if not lines: return "" header = ( "# What Saathi knows from this user's Thought Diary (this session only)\n" "# You may reference this naturally when it is relevant, but do NOT read\n" "# it out like a database dump. Use it to personalise your reply.\n" "# Never claim you can see journal entries the user did not mention —\n" "# only cite the patterns below." 
) return header + "\n" + "\n".join(lines)