Saathi / modules /cognitive_journal.py
Samarth Gupta
Fixed everything
a6406c6
"""Module 4 - Thought Diary (née Cognitive Journal): CBT journal plus deterministic clinical signals.
The original journal path stays intact: user entry -> crisis check -> structured
LLM JSON -> distortion cards. Around it we add non-LLM clinical surfaces:
PHQ-9, GAD-7, daily check-ins, and an auditable stepped-care dashboard.
"""
from __future__ import annotations
from collections import Counter
from datetime import datetime
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple
import plotly.graph_objects as go
import streamlit as st
from pydantic import BaseModel, Field
from backend.claude_client import chat_structured
from backend.i18n import claude_language_name, t
from backend.resources import load_screeners
from backend.safeguards import check_crisis, render_crisis_banner
MODULE_NAME = "cognitive_journal"  # passed as `module=` to chat_structured
# Streamlit session_state keys; all prefixed "cognitive_journal_" so this
# module's state stays namespaced inside the shared session.
ENTRIES_KEY = "cognitive_journal_entries"  # list of saved journal records
LAST_ANALYSIS_KEY = "cognitive_journal_last"  # latest record, rendered under the input
SECTION_KEY = "cognitive_journal_section"  # active sub-section; bound to the router radio
JOURNAL_WIDGET_VERSION_KEY = "cognitive_journal_widget_version"  # bumped to give the text area a fresh key
JOURNAL_NOTICE_KEY = "cognitive_journal_notice"  # one-shot {"level", "message"} notice for the next run
# PHQ-9 screener state.
PHQ9_SCORE_KEY = "cognitive_journal_phq9_score"  # latest score, or None while unscored
PHQ9_ANSWERS_KEY = "cognitive_journal_phq9_answers"  # current per-item answers
PHQ9_TAKEN_AT_KEY = "cognitive_journal_phq9_taken_at"  # ISO timestamp of the latest score ("" if none)
PHQ9_HISTORY_KEY = "cognitive_journal_phq9_history"  # list of recorded attempts
PHQ9_ITEM9_KEY = "cognitive_journal_phq9_item9_positive"  # latest answer to item 9 was > 0
PHQ9_WIDGET_VERSION_KEY = "cognitive_journal_phq9_widget_version"  # bumped on retake so radios get fresh keys
# GAD-7 screener state (same roles as the PHQ-9 keys above; no item-9 flag).
GAD7_SCORE_KEY = "cognitive_journal_gad7_score"
GAD7_ANSWERS_KEY = "cognitive_journal_gad7_answers"
GAD7_TAKEN_AT_KEY = "cognitive_journal_gad7_taken_at"
GAD7_HISTORY_KEY = "cognitive_journal_gad7_history"
GAD7_WIDGET_VERSION_KEY = "cognitive_journal_gad7_widget_version"
CHECKINS_KEY = "cognitive_journal_checkins"  # list of daily check-in dicts
LEGACY_CLEANUP_KEY = "cognitive_journal_legacy_cleanup_done"  # True once legacy seed data was stripped
ScreenerKind = Literal["phq9", "gad7"]
# (section key, i18n label key) pairs, in the order the section router shows them.
SECTION_CHOICES = [
    ("journal", "journal_section_journal_label"),
    ("phq9", "journal_section_phq9_label"),
    ("gad7", "journal_section_gad7_label"),
    ("checkin", "journal_section_checkin_label"),
    ("dashboard", "journal_section_dashboard_label"),
]
# Distortion categories allowed in the structured LLM output (type of `Distortion.type`).
DistortionType = Literal[
    "catastrophizing",
    "mind_reading",
    "all_or_nothing",
    "fortune_telling",
    "personalization",
    "mental_filter",
    "emotional_reasoning",
    "should_statements",
]
# Overall mood labels allowed in the structured LLM output (type of `JournalAnalysis.overall_mood`).
MoodType = Literal[
    "anxious",
    "sad",
    "frustrated",
    "hopeful",
    "neutral",
    "overwhelmed",
]
# One cognitive distortion found in a journal entry; part of the structured
# LLM output schema (nested inside JournalAnalysis.distortions).
# NOTE(review): documented with comments instead of a docstring because pydantic
# copies class docstrings into the JSON-schema "description", which would
# presumably alter the schema chat_structured hands to the model.
class Distortion(BaseModel):
    type: DistortionType  # one of the DistortionType literals
    phrase: str  # the distorted phrase; rendered as a quote in the distortion card
    explanation: str  # shown next to the distortion's display label
    reframe: str  # shown under the "reframe" heading
    evidence_question: str  # shown as a caption under the "question" heading
# Structured analysis of one journal entry; this model is the `schema=` passed
# to chat_structured.
# NOTE(review): comments instead of a docstring on purpose; pydantic would copy a
# class docstring into the JSON-schema "description" and presumably change the
# schema the model receives.
class JournalAnalysis(BaseModel):
    overall_mood: MoodType
    distortions: List[Distortion] = Field(default_factory=list)
    summary: str  # shown in the info box under the summary heading
    needs_professional_signal: bool = False  # True -> the "consider professional support" warning is shown
    ts: str = ""  # overwritten locally with the save time when the record is stored
    entry_text: str = ""  # overwritten locally with the stripped user text
# English display labels for distortion types, used by the distortion cards,
# the chart bars and the summary line. Unknown types fall back to the raw key.
DISTORTION_LABELS = {
    "catastrophizing": "Catastrophizing",
    "mind_reading": "Mind-reading",
    "all_or_nothing": "All-or-nothing thinking",
    "fortune_telling": "Fortune-telling",
    "personalization": "Personalization",
    "mental_filter": "Mental filter",
    "emotional_reasoning": "Emotional reasoning",
    "should_statements": "Should-statements",
}
# Distortion phrases that identify legacy seed entries. An entry whose distortion
# phrases are ALL in this set is removed by the one-time cleanup in _init_state.
_LEGACY_SEED_PHRASES = {
    "my whole placement season is ruined",
    "everyone in my batch thinks I'm useless",
    "I should have studied harder in first year",
    "my parents are disappointed because of me",
    "I'm going to freeze up in the viva and fail",
    "if I don't get an A I've wasted the whole semester",
    "I only remember the parts I got wrong in the last test",
}
# Legacy seed check-in values, in order. They are removed only when they form the
# exact leading prefix of the stored check-ins.
_LEGACY_SEED_CHECKINS = [
    {"mood": 5, "sleep_h": 6.5, "stress": 6},
    {"mood": 6, "sleep_h": 7.0, "stress": 5},
]
def _is_legacy_seed_entry(entry: Any) -> bool:
    """Tell whether *entry* is one of the legacy seed journal entries.

    An entry counts as a seed when it carries at least one non-empty
    distortion phrase and every such phrase is a known seed phrase. Entries
    without phrases (for example, real entries with no distortions) are
    never treated as seeds.
    """
    if isinstance(entry, Mapping):
        raw = entry.get("distortions", [])
    else:
        raw = getattr(entry, "distortions", [])
    seen: set = set()
    for item in raw or []:
        if isinstance(item, Mapping):
            text = item.get("phrase", "")
        else:
            text = getattr(item, "phrase", "")
        if text:
            seen.add(str(text))
    return len(seen) > 0 and seen <= _LEGACY_SEED_PHRASES
def _drop_legacy_seed_entries(entries: Sequence[Any]) -> List[Any]:
    """Return a new list holding only the entries that are not legacy seeds."""
    kept: List[Any] = []
    for candidate in entries:
        if _is_legacy_seed_entry(candidate):
            continue
        kept.append(candidate)
    return kept
def _same_checkin_values(item: Mapping[str, Any], expected: Mapping[str, Any]) -> bool:
    """Check whether a stored check-in matches an expected (mood, sleep, stress) triple.

    Mood and stress are compared as ints, sleep hours as a float, in that
    order and stopping at the first mismatch. A missing or non-numeric value
    in *item* makes the check fail instead of raising.
    """
    comparisons = (
        (int, "mood"),
        (float, "sleep_h"),
        (int, "stress"),
    )
    try:
        for cast, field in comparisons:
            if cast(item.get(field)) != cast(expected[field]):
                return False
    except (TypeError, ValueError):
        return False
    return True
def _drop_legacy_seed_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
    """Strip the legacy seed check-ins when they are the exact leading prefix.

    If the first check-ins match every seed row in order, those rows are
    removed. Otherwise the check-ins come back unchanged, always as a new list.
    """
    items = list(checkins)
    seed_len = len(_LEGACY_SEED_CHECKINS)
    prefix_matches = len(items) >= seed_len and all(
        _same_checkin_values(row, seed)
        for row, seed in zip(items, _LEGACY_SEED_CHECKINS)
    )
    return items[seed_len:] if prefix_matches else items
def _init_state() -> None:
    """Create every session_state key this module relies on, never overwriting.

    On the first call in a session (LEGACY_CLEANUP_KEY not set yet), existing
    journal entries and check-ins are stripped of legacy seed data. The flag
    is then set, so that cleanup never repeats.
    """
    state = st.session_state
    cleanup_pending = not bool(state.get(LEGACY_CLEANUP_KEY))

    if ENTRIES_KEY not in state:
        state[ENTRIES_KEY] = []
    elif cleanup_pending:
        state[ENTRIES_KEY] = _drop_legacy_seed_entries(
            list(state.get(ENTRIES_KEY, []) or [])
        )

    # (key, default) pairs. The tuple is rebuilt on every call, so each list
    # default is a fresh object.
    for key, default in (
        (LAST_ANALYSIS_KEY, None),
        (SECTION_KEY, "journal"),
        (JOURNAL_WIDGET_VERSION_KEY, 0),
        (JOURNAL_NOTICE_KEY, None),
        (PHQ9_SCORE_KEY, None),
        (PHQ9_ANSWERS_KEY, [0] * 9),
        (PHQ9_TAKEN_AT_KEY, ""),
        (PHQ9_HISTORY_KEY, []),
        (PHQ9_ITEM9_KEY, False),
        (PHQ9_WIDGET_VERSION_KEY, 0),
        (GAD7_SCORE_KEY, None),
        (GAD7_ANSWERS_KEY, [0] * 7),
        (GAD7_TAKEN_AT_KEY, ""),
        (GAD7_HISTORY_KEY, []),
        (GAD7_WIDGET_VERSION_KEY, 0),
    ):
        if key not in state:
            state[key] = default

    if CHECKINS_KEY not in state:
        state[CHECKINS_KEY] = []
    elif cleanup_pending:
        state[CHECKINS_KEY] = _drop_legacy_seed_checkins(
            list(state.get(CHECKINS_KEY, []) or [])
        )

    state[LEGACY_CLEANUP_KEY] = True
def _render_section_router(lang: str) -> str:
    """Render the horizontal section picker and return the selected section key.

    The radio is bound to SECTION_KEY. An unknown stored value is reset to
    "journal" *before* the widget is created, so we never pass both `index=`
    and an existing `key=` (Streamlit forbids that combo).
    """
    label_for = dict(SECTION_CHOICES)
    section_keys = list(label_for)
    if st.session_state.get(SECTION_KEY, "journal") not in section_keys:
        st.session_state[SECTION_KEY] = "journal"
    selected = st.radio(
        "journal-section",
        options=section_keys,
        format_func=lambda section: t(label_for[section], lang),
        horizontal=True,
        label_visibility="collapsed",
        key=SECTION_KEY,
    )
    return str(selected)
def _render_journal_explainer(lang: str) -> None:
    """Show the collapsed "about this journal" expander in the active language."""
    heading = t("journal_about_heading", lang)
    body = t("journal_about_body", lang)
    with st.expander(heading, expanded=False):
        st.markdown(body)
def _render_analysis(analysis: Any, lang: str) -> None:
    """Render one analysed entry: summary, optional support warning, distortion cards.

    *analysis* may be a JournalAnalysis, a stored dict record, or any object
    exposing the same attributes. With no distortions, a success message is
    shown instead of cards.
    """
    st.markdown(f"##### {t('journal_summary_heading', lang)}")
    st.info(_entry_summary(analysis))
    if _entry_needs_pro(analysis):
        st.warning(t("journal_needs_pro", lang))

    cards = _entry_distortions(analysis)
    if not cards:
        st.success(t("journal_no_distortions", lang))
        return

    st.markdown(f"##### {t('journal_distortions_heading', lang)}")
    reframe_heading = t("journal_reframe_heading", lang)
    question_heading = t("journal_question_heading", lang)
    for card in cards:
        kind = _distortion_type(card)
        label = DISTORTION_LABELS.get(kind, kind)
        with st.container(border=True):
            st.markdown(f"> *{_distortion_phrase(card)}*")
            st.markdown(f"**{label}** - {_distortion_explanation(card)}")
            st.markdown(f"**{reframe_heading}:** {_distortion_reframe(card)}")
            st.caption(f"**{question_heading}:** {_distortion_evidence_question(card)}")
def _record_field_text(obj: Any, name: str) -> str:
    """Read field *name* from a record or distortion and return it as text.

    Works for dicts (any Mapping), pydantic models and plain objects alike.
    A missing field and an explicit ``None`` both become "", so callers never
    render or chart the literal string "None".
    """
    if isinstance(obj, Mapping):
        value = obj.get(name)
    else:
        value = getattr(obj, name, None)
    return "" if value is None else str(value)


def _entry_distortions(entry: Any) -> Sequence[Any]:
    """Return the distortions stored on *entry*, or [] when there are none."""
    if isinstance(entry, Mapping):
        found = entry.get("distortions")
    else:
        found = getattr(entry, "distortions", None)
    return found or []


def _entry_needs_pro(entry: Any) -> bool:
    """Return whether *entry* is flagged as needing professional support."""
    if isinstance(entry, Mapping):
        return bool(entry.get("needs_professional_signal"))
    return bool(getattr(entry, "needs_professional_signal", False))


def _entry_timestamp(entry: Any) -> str:
    """ISO timestamp string of *entry* ("" if missing)."""
    return _record_field_text(entry, "ts")


def _entry_summary(entry: Any) -> str:
    """Summary text of *entry* ("" if missing)."""
    return _record_field_text(entry, "summary")


def _entry_mood(entry: Any) -> str:
    """Overall mood label of *entry* ("" if missing)."""
    return _record_field_text(entry, "overall_mood")


def _entry_text(entry: Any) -> str:
    """Original user text of *entry* ("" if missing)."""
    return _record_field_text(entry, "entry_text")


def _distortion_type(distortion: Any) -> str:
    """Distortion type key, e.g. "catastrophizing" ("" if missing)."""
    return _record_field_text(distortion, "type")


def _distortion_phrase(distortion: Any) -> str:
    """Quoted phrase that shows the distortion ("" if missing)."""
    return _record_field_text(distortion, "phrase")


def _distortion_explanation(distortion: Any) -> str:
    """Explanation of why the phrase is a distortion ("" if missing)."""
    return _record_field_text(distortion, "explanation")


def _distortion_reframe(distortion: Any) -> str:
    """Suggested reframe for the distortion ("" if missing)."""
    return _record_field_text(distortion, "reframe")


def _distortion_evidence_question(distortion: Any) -> str:
    """Evidence-checking question for the distortion ("" if missing)."""
    return _record_field_text(distortion, "evidence_question")
def _analysis_to_journal_record(analysis: Any, entry_text: str) -> Dict[str, Any]:
    """Turn an analysis result into a plain dict record suitable for session_state.

    Accepts a Mapping, an object with ``model_dump`` (pydantic v2 style), an
    object with ``dict()`` (e.g. pydantic v1), or a bare object carrying the
    expected attributes. The input is not mutated. The record is stamped with
    the current local time (seconds precision) and the stripped entry text,
    and missing core fields get neutral defaults.
    """
    if isinstance(analysis, Mapping):
        base = dict(analysis)
    elif hasattr(analysis, "model_dump"):
        base = dict(analysis.model_dump(mode="python"))
    elif hasattr(analysis, "dict"):
        base = dict(analysis.dict())
    else:
        base = {
            "overall_mood": getattr(analysis, "overall_mood", "neutral"),
            "distortions": getattr(analysis, "distortions", []) or [],
            "summary": getattr(analysis, "summary", ""),
            "needs_professional_signal": bool(
                getattr(analysis, "needs_professional_signal", False)
            ),
        }
    base.update(
        ts=datetime.now().isoformat(timespec="seconds"),
        entry_text=entry_text.strip(),
    )
    for field, fallback in (
        ("distortions", []),
        ("summary", ""),
        ("overall_mood", "neutral"),
        ("needs_professional_signal", False),
    ):
        base.setdefault(field, fallback)
    return base
def _fallback_journal_record(entry_text: str, lang: str, error: Exception) -> Dict[str, Any]:
    """Build the record saved when the LLM analysis fails.

    The user's text is kept, the analysis fields are neutral with a localized
    "analysis failed" summary, and the first 240 characters of the error
    message are attached under ``analysis_error``.
    """
    neutral_analysis = {
        "overall_mood": "neutral",
        "distortions": [],
        "summary": t("journal_analysis_failed_summary", lang),
        "needs_professional_signal": False,
    }
    saved = _analysis_to_journal_record(neutral_analysis, entry_text)
    saved["analysis_error"] = str(error)[:240]
    return saved
def _crisis_journal_record(entry_text: str, lang: str) -> Dict[str, Any]:
    """Build the record saved when crisis language is detected.

    The LLM is intentionally NOT called on crisis text; the crisis banner is
    the priority. The entry is still stored so the user's history is
    preserved and the dashboard / chart keep moving forward.

    One default "catastrophizing" distortion is attached so the distortion
    chart still reflects crisis entries. Without it they would appear
    distortion-free, which is misleading, since the text is simply never
    sent to the LLM. The record is flagged for professional support and
    marked ``crisis_flagged``.
    """
    stripped = entry_text.strip()
    # Quote at most 120 characters of the entry in the distortion card.
    phrase = stripped if len(stripped) <= 120 else stripped[:117].rstrip() + "..."
    crisis_distortion = {
        "type": "catastrophizing",
        "phrase": phrase,
        "explanation": t("journal_crisis_distortion_explanation", lang),
        "reframe": t("journal_crisis_distortion_reframe", lang),
        "evidence_question": t("journal_crisis_distortion_question", lang),
    }
    saved = _analysis_to_journal_record(
        {
            "overall_mood": "overwhelmed",
            "distortions": [crisis_distortion],
            "summary": t("journal_crisis_summary", lang),
            "needs_professional_signal": True,
        },
        entry_text,
    )
    saved["crisis_flagged"] = True
    return saved
def _append_journal_record(record: Mapping[str, Any]) -> List[Any]:
    """Append a copy of *record* to the journal and make it the latest analysis.

    The entries list is rebuilt rather than mutated in place. The new list is
    returned so callers can report the updated entry count.
    """
    updated = [*(st.session_state.get(ENTRIES_KEY, []) or []), dict(record)]
    st.session_state[ENTRIES_KEY] = updated
    st.session_state[LAST_ANALYSIS_KEY] = dict(record)
    return updated
def _short_text(text: str, limit: int = 96) -> str:
    """Collapse whitespace in *text* and clip it to at most *limit* characters.

    Empty or falsy input becomes "-". Clipped text ends with "...", and the
    ellipsis counts toward *limit*.
    """
    flat = " ".join(str(text or "").split())
    if not flat:
        return "-"
    if len(flat) > limit:
        return flat[: limit - 3].rstrip() + "..."
    return flat
def _journal_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Build display rows for the journal history table, oldest first.

    Rows are ordered by parsed timestamp. Entries whose timestamp is missing
    or unparseable go last, and ties keep their insertion order. After
    sorting, the "entry" column is renumbered 1..n, and only the visible
    columns are returned.
    """
    columns = (
        "entry",
        "logged_at",
        "mood",
        "distortions",
        "support",
        "note",
        "observation",
    )
    saved = list(st.session_state.get(ENTRIES_KEY, []) or [])
    staged: List[Any] = []
    for position, item in enumerate(saved, start=1):
        stamp = _entry_timestamp(item)
        parsed = _parse_timestamp(stamp)
        if _entry_needs_pro(item):
            support = t("journal_support_yes", lang)
        else:
            support = t("journal_support_no", lang)
        row = {
            "entry": position,
            "logged_at": _timestamp_label(stamp),
            "mood": _entry_mood(item) or "-",
            "distortions": len(_entry_distortions(item)),
            "support": support,
            "note": _short_text(_entry_text(item)),
            "observation": _short_text(_entry_summary(item), limit=140),
        }
        # Sort key: parsed time (unknown -> end), then original position.
        staged.append(((parsed or datetime.max, position), row))
    staged.sort(key=lambda pair: pair[0])
    ordered: List[Dict[str, Any]] = []
    for number, (_, row) in enumerate(staged, start=1):
        row["entry"] = number
        ordered.append({column: row.get(column) for column in columns})
    return ordered
def _render_journal_history(lang: str) -> None:
    """Render the dashboard's full journal-history table, oldest entry first.

    An empty-state caption replaces the table when nothing has been logged.
    """
    st.markdown(f"##### {t('dashboard_journal_history_heading', lang)}")
    table_rows = _journal_history_rows(lang)
    if table_rows:
        cfg = st.column_config
        st.dataframe(
            table_rows,
            use_container_width=True,
            hide_index=True,
            column_config={
                "entry": cfg.NumberColumn(t("journal_col_entry", lang), width="small"),
                "logged_at": cfg.TextColumn(t("journal_col_logged", lang), width="small"),
                "mood": cfg.TextColumn(t("journal_col_mood", lang), width="small"),
                "distortions": cfg.NumberColumn(t("journal_col_distortions", lang), width="small"),
                "support": cfg.TextColumn(t("journal_col_support", lang), width="small"),
                "note": cfg.TextColumn(t("journal_col_note", lang), width="medium"),
                "observation": cfg.TextColumn(t("journal_col_observation", lang), width="large"),
            },
        )
    else:
        st.caption(t("dashboard_journal_history_empty", lang))
def _distortion_chart_rows(entries: Sequence[Any]) -> List[Dict[str, Any]]:
    """Count distortion types across *entries*, most frequent first.

    Each row holds the display label under "distortion" and the tally under
    "count". Distortions without a type are ignored. Equal counts keep
    first-seen order.
    """
    tally: Counter[str] = Counter(
        kind
        for item in entries
        for kind in map(_distortion_type, _entry_distortions(item))
        if kind
    )
    rows: List[Dict[str, Any]] = []
    for kind, hits in tally.most_common():
        rows.append({"distortion": DISTORTION_LABELS.get(kind, kind), "count": hits})
    return rows
def _render_chart(lang: str) -> None:
    """Render the distortion-frequency bar chart followed by its summary line.

    With no entries, only a hint caption appears. When entries exist but none
    carries a tagged distortion, an explanation and an expandable set of
    examples replace the chart.
    """
    saved_entries = list(st.session_state.get(ENTRIES_KEY, []) or [])
    bars = _distortion_chart_rows(saved_entries)
    if not bars and not saved_entries:
        st.caption(t("dashboard_distortion_empty_hint", lang))
        return
    if not bars:
        # Entries exist, but Saathi didn't tag any CBT distortions in them.
        # This is common when the text describes feelings ("I feel lonely")
        # rather than distorted thoughts ("my life is ruined forever").
        st.info(
            t("dashboard_distortion_no_distortions", lang).format(
                count=len(saved_entries)
            )
        )
        st.caption(t("dashboard_distortion_try_examples", lang))
        with st.expander(t("dashboard_distortion_examples_heading", lang), expanded=False):
            st.markdown(t("dashboard_distortion_examples_body", lang))
        return

    names = [bar["distortion"] for bar in bars]
    counts = [int(bar["count"]) for bar in bars]
    count_label = t("journal_chart_yaxis", lang)
    figure = go.Figure(
        data=[
            go.Bar(
                x=names,
                y=counts,
                text=[str(bar["count"]) for bar in bars],
                textposition="outside",
                # Let the "outside" count labels draw past the plot area.
                cliponaxis=False,
                marker=dict(color="#2563EB", line=dict(color="#1D4ED8", width=1)),
                hovertemplate=f"<b>%{{x}}</b><br>{count_label} %{{y}}<extra></extra>",
            )
        ]
    )
    figure.update_layout(
        title=t("journal_chart_title", lang),
        xaxis_title=t("journal_chart_xaxis", lang),
        yaxis_title=count_label,
        showlegend=False,
        height=320,
        margin=dict(l=10, r=10, t=50, b=10),
        bargap=0.45,
        plot_bgcolor="rgba(0,0,0,0)",
        paper_bgcolor="rgba(0,0,0,0)",
    )
    figure.update_xaxes(type="category", tickangle=0)
    # Integer ticks, with headroom above the tallest bar for its label.
    figure.update_yaxes(
        range=[0, max(2, max(counts) + 1)],
        tickmode="linear",
        dtick=1,
        rangemode="tozero",
    )
    st.plotly_chart(figure, use_container_width=True)
    _render_distortion_summary_line(saved_entries, lang)
def _render_screener_summary_line(kind: ScreenerKind, lang: str) -> None:
    """Show a one-line summary of a screener's score history (nothing if empty).

    A single attempt gets a simpler line. Otherwise the line includes the
    signed change from the first to the latest score and its localized
    direction.
    """
    summary = _summarize_screener_history(kind)
    if not summary:
        return
    screener_name = str(_screener_config(kind).get("name", kind.upper()))
    band_text = t(f"band_{summary['latest_band']}", lang)
    if summary["direction"] != "single":
        st.info(
            t("screener_summary_line", lang).format(
                name=screener_name,
                attempts=summary["attempts"],
                latest=summary["latest"],
                max=summary["max"],
                band=band_text,
                delta_signed=f"{int(summary['delta']):+d}",
                direction=t(f"screener_direction_{summary['direction']}", lang),
            )
        )
        return
    st.info(
        t("screener_summary_single_attempt", lang).format(
            name=screener_name,
            latest=summary["latest"],
            max=summary["max"],
            band=band_text,
        )
    )
def _render_checkin_summary_line(lang: str) -> None:
    """Show average mood/sleep/stress and recent trends across the daily check-ins.

    Nothing is shown without check-ins, and a single check-in only gets a
    caption. Averages print with one decimal, or "-" when unavailable.
    """
    stats = _summarize_checkins(list(st.session_state.get(CHECKINS_KEY, []) or []))
    if not stats:
        return
    if stats["count"] < 2:
        st.caption(t("checkin_summary_single", lang))
        return

    def _one_decimal(value: Optional[float]) -> str:
        return "-" if value is None else f"{value:.1f}"

    line = t("checkin_summary_line", lang).format(
        count=stats["count"],
        mood_avg=_one_decimal(stats["mood_avg"]),
        sleep_avg=_one_decimal(stats["sleep_avg"]),
        stress_avg=_one_decimal(stats["stress_avg"]),
        mood_trend=t(f"checkin_trend_{stats['mood_trend']}", lang),
        stress_trend=t(f"checkin_trend_{stats['stress_trend']}", lang),
    )
    st.info(line)
def _render_distortion_summary_line(entries: Sequence[Any], lang: str) -> None:
    """Summarise the most frequent distortion(s) and append a matching tip.

    Nothing is shown when no typed distortion exists. When a second
    distortion type exists, it is mentioned too.
    """
    stats = _summarize_distortions(entries)
    if not stats:
        return

    def _label(kind: Any) -> Any:
        return DISTORTION_LABELS.get(kind, kind)

    shared = {
        "top": _label(stats["top_type"]),
        "top_count": stats["top_count"],
        "total_entries": stats["total_entries"],
    }
    if stats.get("second_type"):
        line = t("distortion_summary_line", lang).format(
            **shared,
            second=_label(stats["second_type"]),
            second_count=stats["second_count"],
        )
    else:
        line = t("distortion_summary_single_line", lang).format(**shared)
    tip = t(stats["tip_key"], lang)
    prefix = t("distortion_summary_tip_prefix", lang)
    st.info(f"{line}\n\n**{prefix}** {tip}")
def _render_journal(lang: str) -> None:
    """Render the journal input, save a submitted entry, and show the latest analysis.

    On "send": whitespace-only input is ignored. Crisis text shows the crisis
    banner and is saved without calling the LLM. Any other text is analysed
    with ``chat_structured`` and saved; if that call fails, a neutral
    fallback record is saved instead. Every save queues a one-shot notice and
    bumps the input widget version.
    """
    # One-shot notice queued by the previous run: shown once, then cleared.
    notice = st.session_state.get(JOURNAL_NOTICE_KEY)
    if isinstance(notice, Mapping) and notice.get("message"):
        level = notice.get("level")
        if level == "warning":
            st.warning(str(notice["message"]))
        else:
            st.success(str(notice["message"]))
        st.session_state[JOURNAL_NOTICE_KEY] = None
    # The text area's key embeds a version counter. Bumping the counter after a
    # save gives the next run a brand-new widget, i.e. an empty input box.
    input_key = f"cognitive_journal_input_v{st.session_state[JOURNAL_WIDGET_VERSION_KEY]}"
    entry = st.text_area(
        "journal_entry",
        placeholder=t("journal_input_placeholder", lang),
        label_visibility="collapsed",
        height=140,
        key=input_key,
    )
    if st.button(t("journal_send_button", lang), type="primary", key="cognitive_journal_send_button"):
        entry_text = entry.strip()
        # Whitespace-only input is silently ignored.
        if entry_text:
            if check_crisis(entry_text):
                # Crisis detected: show banner immediately, and STILL log the
                # entry (with the pro-signal flag set) so the user's record is
                # preserved and the dashboard keeps working.
                render_crisis_banner(lang)
                record = _crisis_journal_record(entry_text, lang)
                entries = _append_journal_record(record)
                st.session_state[JOURNAL_NOTICE_KEY] = {
                    "level": "warning",
                    "message": t("journal_crisis_logged", lang).format(count=len(entries)),
                }
                st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1
                # No st.rerun() here: returning keeps the crisis banner on screen
                # for this run; the queued notice appears on the next run.
                return
            with st.spinner("..."):
                try:
                    analysis = chat_structured(
                        module=MODULE_NAME,
                        user_text=entry_text,
                        language_name=claude_language_name(lang),
                        schema=JournalAnalysis,
                    )
                    record = _analysis_to_journal_record(analysis, entry_text)
                    entries = _append_journal_record(record)
                    st.session_state[JOURNAL_NOTICE_KEY] = {
                        "level": "success",
                        "message": t("journal_logged", lang).format(count=len(entries)),
                    }
                except Exception as e:
                    # Deliberately broad: whatever goes wrong with the analysis,
                    # the entry is still saved (neutral fields + error text).
                    record = _fallback_journal_record(entry_text, lang, e)
                    entries = _append_journal_record(record)
                    st.session_state[JOURNAL_NOTICE_KEY] = {
                        "level": "warning",
                        "message": t("journal_saved_without_analysis", lang).format(
                            count=len(entries)
                        ),
                    }
            st.session_state[JOURNAL_WIDGET_VERSION_KEY] += 1
            # Rerun so the cleared input, the queued notice and the new analysis
            # render straight away. Kept outside the try block on purpose.
            # NOTE(review): st.rerun() is understood to work by raising, so it
            # must not sit inside the broad except above.
            st.rerun()
    if st.session_state[LAST_ANALYSIS_KEY] is not None:
        _render_analysis(st.session_state[LAST_ANALYSIS_KEY], lang)
    _render_recent_journal_entries(lang)
def _render_recent_journal_entries(lang: str) -> None:
    """Compact 'previous entries' list shown directly under the Journal tab.

    It appears only once at least two entries exist, and shows the five most
    recent rows, newest first, inside a collapsed expander.
    """
    if len(list(st.session_state.get(ENTRIES_KEY, []) or [])) < 2:
        return
    with st.expander(t("journal_recent_heading", lang), expanded=False):
        table_rows = _journal_history_rows(lang)
        if not table_rows:
            st.caption(t("journal_recent_empty", lang))
            return
        newest_first = table_rows[::-1][:5]
        cfg = st.column_config
        st.dataframe(
            newest_first,
            use_container_width=True,
            hide_index=True,
            column_config={
                "entry": cfg.NumberColumn(t("journal_col_entry", lang), width="small"),
                "logged_at": cfg.TextColumn(t("journal_col_logged", lang), width="small"),
                "mood": cfg.TextColumn(t("journal_col_mood", lang), width="small"),
                "distortions": cfg.NumberColumn(t("journal_col_distortions", lang), width="small"),
                "support": cfg.TextColumn(t("journal_col_support", lang), width="small"),
                "note": cfg.TextColumn(t("journal_col_note", lang), width="medium"),
                "observation": cfg.TextColumn(t("journal_col_observation", lang), width="large"),
            },
        )
def _score_screener(answers: List[int]) -> int:
    """Total screener score: the sum of the item answers, each coerced to int."""
    total = 0
    for answer in answers:
        total += int(answer)
    return total
def _band_for(score: int, bands: List[dict]) -> str:
    """Return the key of the first band whose "max" is at least *score*.

    Bands are expected in ascending order of "max". A score above every band
    maps to the last band, and an empty band list yields "none".
    """
    if not bands:
        return "none"
    matching = next((band for band in bands if score <= int(band["max"])), bands[-1])
    return str(matching["key"])
# --- Deterministic summary helpers (no Streamlit, unit-testable) -----------
def _direction_from_delta(delta: int) -> str:
    """Classify a score change; a lower score is better for both PHQ-9 and GAD-7.

    A move of two or more points is "worsening" (up) or "improving" (down);
    anything smaller is "stable".
    """
    if delta >= 2:
        return "worsening"
    return "improving" if delta <= -2 else "stable"
def _summarize_screener_history(kind: ScreenerKind) -> Optional[Dict[str, Any]]:
    """Return a short structured summary of a screener's score history.

    Returns None when the history is empty or holds no usable score (scores
    that ``_coerce_int`` maps to None are skipped). Otherwise returns a dict
    with:

    - ``attempts``: number of usable scores
    - ``latest``: the most recent score
    - ``latest_band``: the band of the most recent score
    - ``max``: the highest possible score
    - ``first``: the earliest score
    - ``delta``: latest minus first
    - ``direction``: see below

    With exactly one usable score, ``first`` equals ``latest``, ``delta`` is
    0 and ``direction`` is "single". Otherwise ``direction`` comes from
    ``_direction_from_delta``.
    """
    history = list(st.session_state.get(_history_key(kind), []) or [])
    scores = [
        score
        for score in (_coerce_int(item.get("score")) for item in history)
        if score is not None
    ]
    if not scores:
        return None
    config = _screener_config(kind)
    latest_score = int(scores[-1])
    first_score = int(scores[0])
    single = len(scores) == 1
    delta = 0 if single else latest_score - first_score
    return {
        "attempts": len(scores),
        "latest": latest_score,
        "latest_band": _band_for(latest_score, list(config.get("bands") or [])),
        "max": _max_score(config),
        "first": first_score,
        "delta": delta,
        "direction": "single" if single else _direction_from_delta(delta),
    }
def _trend_label(prev_avg: Optional[float], last_avg: Optional[float]) -> str:
    """Label the move from *prev_avg* to *last_avg* as "up", "down" or "flat".

    A change of at least 0.75 either way is a move. Anything smaller, or a
    missing average, is "flat".
    """
    if prev_avg is not None and last_avg is not None:
        change = last_avg - prev_avg
        if change >= 0.75:
            return "up"
        if change <= -0.75:
            return "down"
    return "flat"
def _mean(values: Sequence[float]) -> Optional[float]:
    """Arithmetic mean of the non-None *values*, or None when none are present."""
    present = [float(value) for value in values if value is not None]
    if not present:
        return None
    return sum(present) / len(present)
def _summarize_checkins(
    checkins: Sequence[Mapping[str, Any]],
) -> Optional[Dict[str, Any]]:
    """Summarise check-ins: count, overall averages and short-term trends.

    Each trend compares the mean of the last three values with the mean of
    the (up to three) values just before them. With three or fewer values,
    the earlier window is empty and the trend is "flat". Returns None when
    there are no check-ins.
    """
    ordered = list(_sorted_checkins(checkins))
    if not ordered:
        return None

    def _series(field: str) -> List[float]:
        return [float(row.get(field)) for row in ordered if row.get(field) is not None]

    moods = _series("mood")
    sleeps = _series("sleep_h")
    stresses = _series("stress")
    # For fewer than six values, [-6:-3] clamps to the start of the list, so
    # the earlier window is simply everything before the last three.
    return {
        "count": len(ordered),
        "mood_avg": _mean(moods),
        "sleep_avg": _mean(sleeps),
        "stress_avg": _mean(stresses),
        "mood_trend": _trend_label(_mean(moods[-6:-3]), _mean(moods[-3:])),
        "stress_trend": _trend_label(_mean(stresses[-6:-3]), _mean(stresses[-3:])),
    }
def _summarize_distortions(entries: Sequence[Any]) -> Optional[Dict[str, Any]]:
    """Find the most and second-most frequent distortion types in *entries*.

    Returns None when no typed distortion exists. ``total_entries`` counts
    every entry, including those without distortions. ``tip_key`` is the
    i18n key of the tip for the top distortion. ``second_type`` is None
    (with ``second_count`` 0) when only one type occurs.
    """
    tally: Counter[str] = Counter()
    for item in entries:
        tally.update(
            kind for kind in map(_distortion_type, _entry_distortions(item)) if kind
        )
    distortion_total = sum(tally.values())
    if not distortion_total:
        return None
    leaders = tally.most_common(2)
    top_kind, top_hits = leaders[0]
    runner_kind, runner_hits = leaders[1] if len(leaders) > 1 else (None, 0)
    return {
        "top_type": top_kind,
        "top_count": int(top_hits),
        "second_type": runner_kind,
        "second_count": int(runner_hits),
        "total_entries": len(entries),
        "total_distortions": distortion_total,
        "tip_key": f"distortion_tip_{top_kind}",
    }
def _screener_config(kind: ScreenerKind) -> Dict[str, Any]:
    """Return a shallow copy of the screener definition for *kind* ({} if absent)."""
    screeners = load_screeners()
    found = screeners.get(kind)
    return dict(found) if found else {}
def _max_score(config: Mapping[str, Any]) -> int:
    """Highest possible score: the "max" of the last (top) band, or 0 without bands."""
    bands = config.get("bands")
    if not bands:
        return 0
    return int(bands[-1]["max"])
def _response_labels(config: Mapping[str, Any], lang: str) -> List[str]:
    """Answer-option labels for *lang*, else the English ones, else []."""
    by_language = config.get("response_labels") or {}
    for code in (lang, "en"):
        chosen = by_language.get(code)
        if chosen:
            return list(chosen)
    return []
def _localized_item(item: Mapping[str, Any], lang: str) -> str:
    """Item text in *lang*; otherwise English, tagged when a translation is missing."""
    translated = item.get(lang)
    if translated:
        return str(translated)
    english = str(item.get("en", ""))
    return english if lang == "en" else f"{english} (translation pending)"
def _screener_state_key(kind: ScreenerKind, phq9_key: str, gad7_key: str) -> str:
    """Pick *phq9_key* for "phq9" and *gad7_key* for any other screener kind."""
    return phq9_key if kind == "phq9" else gad7_key


def _answers_key(kind: ScreenerKind) -> str:
    """session_state key of the in-progress answers for *kind*."""
    return _screener_state_key(kind, PHQ9_ANSWERS_KEY, GAD7_ANSWERS_KEY)


def _score_key(kind: ScreenerKind) -> str:
    """session_state key of the latest score for *kind*."""
    return _screener_state_key(kind, PHQ9_SCORE_KEY, GAD7_SCORE_KEY)


def _taken_key(kind: ScreenerKind) -> str:
    """session_state key of the latest scoring timestamp for *kind*."""
    return _screener_state_key(kind, PHQ9_TAKEN_AT_KEY, GAD7_TAKEN_AT_KEY)


def _history_key(kind: ScreenerKind) -> str:
    """session_state key of the attempt history for *kind*."""
    return _screener_state_key(kind, PHQ9_HISTORY_KEY, GAD7_HISTORY_KEY)


def _version_key(kind: ScreenerKind) -> str:
    """session_state key of the widget-version counter for *kind*."""
    return _screener_state_key(kind, PHQ9_WIDGET_VERSION_KEY, GAD7_WIDGET_VERSION_KEY)
def _next_screener_log_order() -> int:
    """Next 1-based attempt number shared across the PHQ-9 and GAD-7 histories."""
    logged = 0
    for history_key in (PHQ9_HISTORY_KEY, GAD7_HISTORY_KEY):
        logged += len(st.session_state.get(history_key, []) or [])
    return logged + 1
def _ensure_answers(kind: ScreenerKind, n_items: int) -> List[int]:
    """Return the stored answers for *kind* as a list of exactly *n_items* entries.

    A stored list with the wrong length is replaced by zeros. The returned
    list object is also written back to session_state.
    """
    state_key = _answers_key(kind)
    stored = list(st.session_state.get(state_key) or [])
    answers = stored if len(stored) == n_items else [0] * n_items
    st.session_state[state_key] = answers
    return answers
def _record_screener(kind: ScreenerKind, config: Mapping[str, Any], answers: List[int]) -> None:
    """Score a completed screener and store the result in session_state.

    This updates the latest score, answers and timestamp for *kind*. For
    PHQ-9 it also sets the item-9 flag (a non-zero answer to the ninth item).
    Finally it appends an attempt to the screener's history, carrying a
    ``log_order`` that is shared by both screeners.
    """
    score = _score_screener(answers)
    band = _band_for(score, list(config.get("bands") or []))
    ts = datetime.now().isoformat(timespec="seconds")
    log_order = _next_screener_log_order()
    # Item 9 is the ninth answer (index 8). Coerce to int the same way
    # _score_screener does, so string answers cannot raise on the comparison.
    q9_positive = bool(kind == "phq9" and len(answers) >= 9 and int(answers[8]) > 0)
    st.session_state[_score_key(kind)] = score
    st.session_state[_answers_key(kind)] = list(answers)
    st.session_state[_taken_key(kind)] = ts
    if kind == "phq9":
        st.session_state[PHQ9_ITEM9_KEY] = q9_positive
    # `or []` matches every other history reader and tolerates a None value.
    history = list(st.session_state.get(_history_key(kind), []) or [])
    history.append(
        {
            "ts": ts,
            "log_order": log_order,
            "score": score,
            "band": band,
            "answers": list(answers),
            "q9_positive": q9_positive,
        }
    )
    st.session_state[_history_key(kind)] = history
def _parse_timestamp(ts: str) -> Optional[datetime]:
    """Parse an ISO-8601 timestamp string into a datetime.

    Returns None for empty input and for anything that cannot be parsed.
    That includes non-string values, such as a number left in an old
    record, which would otherwise raise ``TypeError``.
    """
    if not ts:
        return None
    try:
        return datetime.fromisoformat(ts)
    except (TypeError, ValueError):
        return None
def _timestamp_label(ts: str) -> str:
    """Human-friendly label such as "05 Mar, 14:30" for an ISO timestamp.

    Seconds are appended only when they are non-zero. Empty input gives "-",
    and an unparseable value is returned unchanged.
    """
    if not ts:
        return "-"
    moment = _parse_timestamp(ts)
    if moment is None:
        return ts
    pattern = "%d %b, %H:%M:%S" if moment.second else "%d %b, %H:%M"
    return moment.strftime(pattern)
def _render_screener_result(kind: ScreenerKind, config: Mapping[str, Any], score: int, lang: str) -> None:
    """Show score, severity band, last-taken time and interpretation for a scored screener."""
    band_key = _band_for(score, list(config.get("bands") or []))
    score_text = f"{score}/{_max_score(config)}"
    band_text = t(f"band_{band_key}", lang)
    interpretation = t(f"band_{band_key}_interpretation", lang)
    with st.container(border=True):
        score_col, band_col, taken_col = st.columns(3)
        score_col.metric(t("screener_score_label", lang), score_text)
        band_col.metric(t("screener_band_label", lang), band_text)
        taken_col.metric(
            t("screener_last_taken", lang),
            _timestamp_label(st.session_state[_taken_key(kind)]),
        )
        st.write(interpretation)
        st.caption(t("screener_disclaimer", lang))
def _render_single_screener_timeline(kind: ScreenerKind, lang: str) -> None:
    """Plot one screener's score per attempt as a line chart.

    Rows from ``_screener_history_rows`` are ordered by timestamp (rows
    without one go last), then by attempt number. Rows lacking a score or an
    attempt number cannot be placed on the chart and are skipped. Nothing is
    drawn when no plottable row remains.
    """
    rows = _screener_history_rows(kind, lang)
    if not rows:
        return
    rows.sort(
        key=lambda row: (
            row.get("_sort_ts") or datetime.max,
            int(row.get("attempt") or 0),
        )
    )
    # The sort key tolerates a missing attempt number, but the x-axis needs a
    # real one (int(None) would raise below), so such rows are filtered out.
    plot_rows = [
        row
        for row in rows
        if row.get("score") is not None and row.get("attempt") is not None
    ]
    if not plot_rows:
        return
    config = _screener_config(kind)
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=[int(row["attempt"]) for row in plot_rows],
            y=[int(row["score"]) for row in plot_rows],
            mode="lines+markers+text",
            name=config.get("name", kind.upper()),
            text=[str(row["score"]) for row in plot_rows],
            textposition="top center",
            customdata=[[row.get("taken_at", "-")] for row in plot_rows],
            hovertemplate=(
                "<b>%{fullData.name}</b><br>"
                "Attempt %{x}<br>"
                "Score %{y}<br>"
                "Taken %{customdata[0]}<extra></extra>"
            ),
            marker=dict(size=14),
            line=dict(width=3),
        )
    )
    max_attempt = max(int(row["attempt"]) for row in plot_rows)
    fig.update_layout(
        title=f"{config.get('name', kind.upper())} {t('dashboard_screener_timeline_heading', lang)}",
        xaxis_title="Attempt",
        yaxis_title="Score",
        height=260,
        margin=dict(l=10, r=10, t=50, b=10),
        showlegend=False,
    )
    # One integer tick per attempt, with half a step of padding on each side.
    fig.update_xaxes(
        type="linear",
        range=[0.5, max_attempt + 0.5],
        tickmode="linear",
        dtick=1,
    )
    fig.update_yaxes(range=[0, _max_score(config)], rangemode="tozero")
    st.plotly_chart(fig, use_container_width=True)
def _render_screener(kind: ScreenerKind, lang: str) -> None:
    """Render one questionnaire (PHQ-9 or GAD-7): items, score/retake buttons, result.

    Each item is a 0-3 radio. The answers list is stored in session state under
    ``_answers_key(kind)``. Widget keys embed ``_version_key(kind)``, so bumping the
    version on retake yields freshly reset radios. Scoring is delegated to
    ``_record_screener``. The result card is shown only when a score exists. The
    history timeline and the summary line are rendered after it.
    """
    config = _screener_config(kind)
    if not config:
        st.error("Screener data could not be loaded.")
        return
    items = list(config.get("items") or [])
    answers = _ensure_answers(kind, len(items))
    labels = _response_labels(config, lang)
    if len(labels) != 4:
        # Fall back to the English anchors when the configured labels do not
        # cover all four response options.
        labels = ["Not at all", "Several days", "More than half the days", "Nearly every day"]
    st.subheader(t(f"{kind}_header", lang))
    st.caption(t(f"{kind}_sub", lang))
    st.caption(f"**{t('screener_timeframe_prefix', lang)}:** {config.get('timeframe', '')}")
    if lang != "en":
        st.info(t("screener_translation_pending", lang))
    st.caption(t("screener_source_caption", lang))

    version = st.session_state[_version_key(kind)]
    for idx, item in enumerate(items):
        answer = st.radio(
            _localized_item(item, lang),
            options=[0, 1, 2, 3],
            index=int(answers[idx]),
            format_func=lambda value, labels=labels: f"{value} - {labels[value]}",
            horizontal=True,
            key=f"cognitive_journal_{kind}_v{version}_item_{idx}",
        )
        answers[idx] = int(answer)
    st.session_state[_answers_key(kind)] = answers

    already_scored = st.session_state[_score_key(kind)] is not None
    button_cols = st.columns([1, 1, 4])
    with button_cols[0]:
        score_clicked = st.button(
            t(f"{kind}_cta", lang),
            type="primary",
            key=f"cognitive_journal_{kind}_score_button",
            disabled=already_scored,
        )
    with button_cols[1]:
        retake_clicked = st.button(
            t(f"{kind}_retake", lang),
            key=f"cognitive_journal_{kind}_retake",
            disabled=not already_scored,
        )

    # The disabled flags above were decided before this run's click was handled,
    # so in the run that records a score, the Score button is still drawn enabled.
    # Re-check the live score so a click on that stale button cannot record the
    # same attempt into the history a second time.
    if score_clicked and st.session_state[_score_key(kind)] is None:
        _record_screener(kind, config, answers)
    if retake_clicked:
        st.session_state[_answers_key(kind)] = [0] * len(items)
        st.session_state[_score_key(kind)] = None
        st.session_state[_taken_key(kind)] = ""
        if kind == "phq9":
            st.session_state[PHQ9_ITEM9_KEY] = False
        # New widget keys so the radios reset to their defaults on rerun.
        st.session_state[_version_key(kind)] += 1
        st.rerun()

    score = st.session_state[_score_key(kind)]
    if score is not None:
        _render_screener_result(kind, config, int(score), lang)
    _render_single_screener_timeline(kind, lang)
    _render_screener_summary_line(kind, lang)
def _render_phq9(lang: str) -> None:
    """Render the PHQ-9 questionnaire section (thin wrapper over ``_render_screener``)."""
    _render_screener(kind="phq9", lang=lang)
def _render_gad7(lang: str) -> None:
    """Render the GAD-7 questionnaire section (thin wrapper over ``_render_screener``)."""
    _render_screener(kind="gad7", lang=lang)
def _append_checkin(mood: int, sleep_h: float, stress: int) -> None:
    """Append a timestamped mood/sleep/stress record to the session's check-in list.

    The stored list is replaced by a new list (the previous one is copied, not
    mutated in place). Timestamps are local-time ISO strings at second precision.
    """
    previous = list(st.session_state.get(CHECKINS_KEY, []))
    record = dict(
        ts=datetime.now().isoformat(timespec="seconds"),
        mood=int(mood),
        sleep_h=float(sleep_h),
        stress=int(stress),
    )
    st.session_state[CHECKINS_KEY] = previous + [record]
def _render_daily_checkin(lang: str) -> None:
    """Render the daily check-in form, the last five check-ins and the trend chart.

    Submitting appends a record via ``_append_checkin``. The five most recent
    check-ins (by parsed timestamp) are listed newest first. With no check-ins,
    an empty-history caption is shown and the chart is skipped.
    """
    st.subheader(t("checkin_header", lang))
    st.caption(t("checkin_sub", lang))
    mood = st.slider(
        t("checkin_mood_label", lang),
        min_value=1,
        max_value=10,
        value=6,
        key="cognitive_journal_checkin_mood",
    )
    sleep_h = st.slider(
        t("checkin_sleep_label", lang),
        min_value=0.0,
        max_value=12.0,
        value=7.0,
        step=0.5,
        key="cognitive_journal_checkin_sleep",
    )
    stress = st.slider(
        t("checkin_stress_label", lang),
        min_value=1,
        max_value=10,
        value=5,
        key="cognitive_journal_checkin_stress",
    )
    submitted = st.button(t("checkin_submit", lang), type="primary", key="cognitive_journal_checkin_submit")
    if submitted:
        _append_checkin(mood, sleep_h, stress)
        st.success(t("checkin_logged", lang))

    st.markdown(f"##### {t('checkin_history_heading', lang)}")
    history = st.session_state.get(CHECKINS_KEY, [])
    if not history:
        st.caption(t("checkin_empty_history", lang))
        return
    latest_five = _sorted_checkins(history)[-5:]
    for record in latest_five[::-1]:
        stamp = _timestamp_label(str(record.get("ts", "")))
        st.write(
            f"{stamp} · "
            f"mood {record.get('mood')}/10 · sleep {record.get('sleep_h')}h · stress {record.get('stress')}/10"
        )
    _render_checkin_trend(lang)
def _coerce_int(value: Any) -> Optional[int]:
if value is None or value == "":
return None
try:
return int(value)
except (TypeError, ValueError):
return None
def _recent_mood_avg(checkins: Sequence[Mapping[str, Any]]) -> Optional[float]:
if not checkins:
return None
recent = checkins[-3:]
moods = []
for item in recent:
try:
moods.append(float(item.get("mood")))
except (TypeError, ValueError):
continue
if not moods:
return None
return sum(moods) / len(moods)
def _aggregate_session_signals() -> dict:
    """Collect the deterministic inputs for the stepped-care rules from session state.

    Journal-derived signals use only the three most recent entries:
    ``distortion_density`` is the mean distortions per entry, and
    ``pro_signal_count`` counts entries flagged by ``_entry_needs_pro``.
    ``crisis_history`` is always ``False`` here. ``chat_turns`` counts
    user-role messages in ``saathi_chat_history``. The student/legal/soothe
    values are 0/1 presence flags, not counts.
    """
    state = st.session_state
    window = list(state.get(ENTRIES_KEY, []))[-3:]
    if window:
        density = sum(len(_entry_distortions(entry)) for entry in window) / len(window)
    else:
        density = 0.0
    flagged = sum(1 for entry in window if _entry_needs_pro(entry))
    ordered_checkins = _sorted_checkins(state.get(CHECKINS_KEY, []) or [])
    history = state.get("saathi_chat_history", []) or []
    user_turns = sum(
        1 for message in history if isinstance(message, Mapping) and message.get("role") == "user"
    )

    def _presence(*keys: str) -> int:
        # 1 when any of the given session keys holds a truthy value.
        return 1 if any(state.get(key) for key in keys) else 0

    return {
        "phq9_score": _coerce_int(state.get(PHQ9_SCORE_KEY)),
        "gad7_score": _coerce_int(state.get(GAD7_SCORE_KEY)),
        "phq9_item9_positive": bool(state.get(PHQ9_ITEM9_KEY)),
        "recent_mood_avg": _recent_mood_avg(ordered_checkins),
        "distortion_density": density,
        "pro_signal_count": flagged,
        "crisis_history": False,
        "chat_turns": user_turns,
        "student_events": _presence("student_response", "student_situation"),
        "legal_lookups": _presence("legal_aid_response", "legal_aid_sections"),
        "soothe_poems": _presence("soothe_poem"),
    }
def _score_at_least(score: Optional[int], threshold: int) -> bool:
return score is not None and score >= threshold
def _compute_stepped_care_recommendation(signals: dict) -> Tuple[str, str]:
    """Map session signals to a stepped-care tier and its rationale i18n key.

    Rules are checked from most to least intensive; the first match wins:

    * ``urgent_professional``: PHQ-9 >= 20, GAD-7 >= 15, two or more
      professional-support-flagged entries, a positive PHQ-9 item 9, or
      ``crisis_history``.
    * ``guided_support``: PHQ-9 >= 15, GAD-7 >= 10, or a recent mood average
      <= 3 together with a distortion density >= 2.
    * ``self_help``: PHQ-9 >= 10, GAD-7 >= 5, or a distortion density >= 1.5.
    * ``self_care``: otherwise.

    Missing screener scores never trigger a rule.

    Returns:
        ``(tier, rationale_key)``, where ``rationale_key`` is ``"tier_reason_" + tier``.
    """
    phq9 = _coerce_int(signals.get("phq9_score"))
    gad7 = _coerce_int(signals.get("gad7_score"))
    mood_avg = signals.get("recent_mood_avg")
    density = float(signals.get("distortion_density") or 0.0)
    pro_flags = int(signals.get("pro_signal_count") or 0)
    low_mood_with_distortions = mood_avg is not None and float(mood_avg) <= 3 and density >= 2

    urgent = (
        _score_at_least(phq9, 20)
        or _score_at_least(gad7, 15)
        or pro_flags >= 2
        or bool(signals.get("phq9_item9_positive"))
        or bool(signals.get("crisis_history"))
    )
    guided = _score_at_least(phq9, 15) or _score_at_least(gad7, 10) or low_mood_with_distortions
    self_help = _score_at_least(phq9, 10) or _score_at_least(gad7, 5) or density >= 1.5

    ladder = (
        (urgent, ("urgent_professional", "tier_reason_urgent_professional")),
        (guided, ("guided_support", "tier_reason_guided_support")),
        (self_help, ("self_help", "tier_reason_self_help")),
    )
    for triggered, outcome in ladder:
        if triggered:
            return outcome
    return "self_care", "tier_reason_self_care"
def _render_stepped_care_card(lang: str, tier: str, rationale_key: str) -> None:
    """Show the stepped-care tier as a coloured Streamlit alert, followed by a disclaimer.

    ``urgent_professional`` uses ``st.error``, ``guided_support`` uses ``st.warning``,
    ``self_help`` uses ``st.info``, and any other tier uses ``st.success``.
    """
    heading = t(f"tier_{tier}", lang)
    explanation = t(f"tier_{tier}_body", lang)
    why = t(rationale_key, lang)
    alert_for_tier = {
        "urgent_professional": st.error,
        "guided_support": st.warning,
        "self_help": st.info,
    }
    show = alert_for_tier.get(tier, st.success)
    show(f"**{heading}**\n\n{explanation}\n\n_{why}_")
    st.caption(t("tier_disclaimer", lang))
def _render_screener_metric(kind: ScreenerKind, lang: str) -> None:
    """Show a bordered card with one screener's latest score, band and time taken.

    When the screener has not been scored yet, a localized nudge is shown instead.
    """
    config = _screener_config(kind)
    raw_score = st.session_state.get(_score_key(kind))
    if raw_score is not None:
        value = int(raw_score)
        band = _band_for(value, list(config.get("bands") or []))
        with st.container(border=True):
            st.metric(config.get("name", kind.upper()), f"{value}/{_max_score(config)}")
            taken = _timestamp_label(st.session_state[_taken_key(kind)])
            st.caption(f"{t(f'band_{band}', lang)} · {t('screener_last_taken', lang)}: {taken}")
    else:
        st.info(t(f"{kind}_not_taken_nudge", lang))
def _render_screener_snapshot(lang: str) -> None:
    """Render the PHQ-9 and GAD-7 metric cards side by side under a heading."""
    st.markdown(f"##### {t('dashboard_screener_history_heading', lang)}")
    for column, kind in zip(st.columns(2), ("phq9", "gad7")):
        with column:
            _render_screener_metric(kind, lang)
def _screener_history_rows(kind: ScreenerKind, lang: str) -> List[Dict[str, Any]]:
    """Turn the stored attempts of one screener into display and sort rows.

    ``attempt`` is the 1-based position in the stored history. ``log`` is the
    stored ``log_order`` coerced to ``int`` (``None`` when absent). Keys that start
    with ``_`` are internal:

    * ``_raw_ts`` is the raw timestamp.
    * ``_sort_ts`` is the parsed timestamp, used for sorting.
    * ``_fallback_order`` equals ``attempt``.
    """
    display_name = _screener_config(kind).get("name", kind.upper())
    history = st.session_state.get(_history_key(kind), []) or []

    def _row(attempt: int, record: Mapping[str, Any]) -> Dict[str, Any]:
        stamp = str(record.get("ts", ""))
        band_key = str(record.get("band") or "none")
        return {
            "log": _coerce_int(record.get("log_order")),
            "attempt": attempt,
            "screener": display_name,
            "score": _coerce_int(record.get("score")),
            "band": t(f"band_{band_key}", lang),
            "taken_at": _timestamp_label(stamp),
            "_raw_ts": stamp,
            "_sort_ts": _parse_timestamp(stamp),
            "_fallback_order": attempt,
        }

    return [_row(attempt, record) for attempt, record in enumerate(history, start=1)]
def _combined_screener_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Merge PHQ-9 and GAD-7 history into one chronologically ordered list.

    Rows are copied, then sorted by parsed timestamp (rows without one go last).
    The tie-breaker is the stored ``log`` order or, when that is missing or 0, the
    row's position in the PHQ-9-then-GAD-7 concatenation. Finally ``log`` is
    rewritten as a fresh 1-based sequence over the merged list.
    """
    concatenated = [*_screener_history_rows("phq9", lang), *_screener_history_rows("gad7", lang)]
    merged: List[Dict[str, Any]] = []
    for position, source_row in enumerate(concatenated, start=1):
        clone = dict(source_row)
        clone["_fallback_order"] = clone.get("log") or position
        merged.append(clone)
    merged.sort(
        key=lambda entry: (
            entry.get("_sort_ts") or datetime.max,
            int(entry.get("_fallback_order") or 0),
        )
    )
    for sequence, entry in enumerate(merged, start=1):
        entry["log"] = sequence
    return merged
def _visible_screener_history_rows(rows: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
visible_keys = ("log", "screener", "attempt", "score", "band", "taken_at")
return [{key: row.get(key) for key in visible_keys} for row in rows]
def _render_screener_timeline(lang: str) -> None:
    """Render the combined PHQ-9/GAD-7 timeline chart and the full history table.

    Rows come from ``_combined_screener_history_rows`` (one global ``log`` order
    across both screeners). Scored rows are plotted with one trace per screener
    name. The table below lists every row, including unscored ones. With no
    history at all, only an empty-state caption is shown.
    """
    st.markdown(f"##### {t('dashboard_screener_timeline_heading', lang)}")
    rows = _combined_screener_history_rows(lang)
    if not rows:
        st.caption(t("dashboard_screener_timeline_empty", lang))
        return
    plot_rows = [row for row in rows if row.get("score") is not None]
    if plot_rows:
        fig = go.Figure()
        screener_names: List[str] = []
        for row in plot_rows:
            name = str(row.get("screener") or "")
            if name and name not in screener_names:
                screener_names.append(name)
        for screener_name in screener_names:
            screener_rows = [row for row in plot_rows if row.get("screener") == screener_name]
            fig.add_trace(
                go.Scatter(
                    x=[int(row["log"]) for row in screener_rows],
                    y=[int(row["score"]) for row in screener_rows],
                    mode="lines+markers+text",
                    name=screener_name,
                    text=[str(row["score"]) for row in screener_rows],
                    textposition="top center",
                    customdata=[
                        [row.get("taken_at", "-"), row.get("attempt", "-")]
                        for row in screener_rows
                    ],
                    hovertemplate=(
                        "<b>%{fullData.name}</b><br>"
                        "Log %{x}<br>"
                        "Score %{y}<br>"
                        "Attempt %{customdata[1]}<br>"
                        "Taken %{customdata[0]}<extra></extra>"
                    ),
                    marker=dict(size=14),
                    line=dict(width=3),
                )
            )
        max_log = max(int(row["log"]) for row in plot_rows)
        # The y-axis ceiling is derived from the loaded screener configs, as the
        # per-screener chart does, rather than from a fixed number. It never drops
        # below the highest plotted score, so every point stays on the chart. The
        # floor of 1 avoids a zero-height axis.
        y_max = max(
            [1]
            + [_max_score(_screener_config(kind)) for kind in ("phq9", "gad7")]
            + [int(row["score"]) for row in plot_rows]
        )
        fig.update_layout(
            xaxis_title="Log order",
            yaxis_title="Score",
            height=280,
            margin=dict(l=10, r=10, t=50, b=10),
            title=t("dashboard_screener_timeline_heading", lang),
        )
        fig.update_xaxes(
            type="linear",
            range=[0.5, max_log + 0.5],
            tickmode="linear",
            dtick=1,
        )
        fig.update_yaxes(range=[0, y_max], rangemode="tozero")
        st.plotly_chart(fig, use_container_width=True)
    st.dataframe(_visible_screener_history_rows(rows), use_container_width=True, hide_index=True)
def _score_snapshot(score: Optional[int], max_score: int) -> str:
if score is None:
return "-"
return f"{score}/{max_score}"
def _render_session_snapshot(lang: str) -> None:
    """Render a compact four-metric card: PHQ-9, GAD-7, check-in count, care tier.

    Screener values come from ``_aggregate_session_signals``; a missing score is
    shown as ``"-"``. The tier is the one chosen by
    ``_compute_stepped_care_recommendation``.
    """
    signals = _aggregate_session_signals()
    tier = _compute_stepped_care_recommendation(signals)[0]
    phq9_display = _score_snapshot(signals.get("phq9_score"), _max_score(_screener_config("phq9")))
    gad7_display = _score_snapshot(signals.get("gad7_score"), _max_score(_screener_config("gad7")))
    checkin_count = len(st.session_state.get(CHECKINS_KEY, []))
    with st.container(border=True):
        st.caption(f"**{t('journal_snapshot_heading', lang)}**")
        phq9_col, gad7_col, checkin_col, tier_col = st.columns(4)
        phq9_col.metric("PHQ-9", phq9_display)
        gad7_col.metric("GAD-7", gad7_display)
        checkin_col.metric(t("journal_snapshot_checkins", lang), checkin_count)
        tier_col.metric(t("journal_snapshot_care_tier", lang), t(f"tier_{tier}", lang))
        st.caption(t("journal_snapshot_hint", lang))
def _render_checkin_trend(lang: str) -> None:
    """Plot mood, sleep hours and stress across check-ins, then the summary line.

    The chart has one trace per metric label, with check-in order on the x-axis.
    When nothing is plottable, only an explanatory caption is shown and the
    summary line is skipped.
    """
    st.markdown(f"##### {t('dashboard_checkin_heading', lang)}")
    trend = _checkin_trend_rows(list(st.session_state.get(CHECKINS_KEY, []) or []))
    if not trend:
        st.caption(t("checkin_trend_caption_empty", lang))
        return

    def _point_label(value: Any) -> str:
        # Whole numbers print without ".0"; fractional values (sleep) keep decimals.
        return str(int(value)) if float(value).is_integer() else str(value)

    # Metric labels in first-seen order, ignoring empty labels.
    labels_in_order = [
        label for label in dict.fromkeys(str(entry.get("metric") or "") for entry in trend) if label
    ]
    figure = go.Figure()
    for label in labels_in_order:
        series = [entry for entry in trend if entry.get("metric") == label]
        figure.add_trace(
            go.Scatter(
                x=[int(entry["check_in"]) for entry in series],
                y=[float(entry["value"]) for entry in series],
                mode="lines+markers+text",
                name=label,
                text=[_point_label(entry["value"]) for entry in series],
                textposition="top center",
                customdata=[[entry.get("logged_at", "-")] for entry in series],
                hovertemplate=(
                    "<b>%{fullData.name}</b><br>"
                    "Check-in %{x}<br>"
                    "Value %{y}<br>"
                    "Logged %{customdata[0]}<extra></extra>"
                ),
                marker=dict(size=14),
                line=dict(width=3),
            )
        )
    last_checkin = max(int(entry["check_in"]) for entry in trend)
    figure.update_layout(
        xaxis_title="Check-in order",
        yaxis_title="Value",
        height=300,
        margin=dict(l=10, r=10, t=50, b=10),
        title=t("checkin_trend_title", lang),
    )
    figure.update_xaxes(
        type="linear",
        range=[0.5, last_checkin + 0.5],
        tickmode="linear",
        dtick=1,
    )
    figure.update_yaxes(range=[0, 12], rangemode="tozero")
    st.plotly_chart(figure, use_container_width=True)
    _render_checkin_summary_line(lang)
def _sorted_checkins(checkins: Sequence[Mapping[str, Any]]) -> List[Mapping[str, Any]]:
    """Return check-ins ordered by parsed ``ts``; missing/unparseable timestamps go last.

    Ties keep their original relative order, because the input index is the
    secondary sort key. The input sequence itself is not modified.
    """

    def _order(pair: Tuple[int, Mapping[str, Any]]) -> Tuple[datetime, int]:
        position, record = pair
        return (_parse_timestamp(str(record.get("ts", ""))) or datetime.max, position)

    return [record for _, record in sorted(enumerate(checkins), key=_order)]
def _checkin_trend_rows(checkins: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten check-ins into long-format plot rows, one per check-in per metric.

    A metric whose value is missing or not ``float``-convertible is skipped.
    Either way, the check-in keeps its 1-based position from ``_sorted_checkins``.
    """
    tracked = (("mood", "Mood"), ("sleep_h", "Sleep hours"), ("stress", "Stress"))
    flattened: List[Dict[str, Any]] = []
    for position, record in enumerate(_sorted_checkins(checkins), start=1):
        for source_key, label in tracked:
            raw = record.get(source_key)
            if raw is None:
                continue
            try:
                numeric = float(raw)
            except (TypeError, ValueError):
                continue
            flattened.append(
                {
                    "check_in": position,
                    "metric": label,
                    "value": numeric,
                    "logged_at": _timestamp_label(str(record.get("ts", ""))),
                }
            )
    return flattened
def _checkin_history_rows(lang: str) -> List[Dict[str, Any]]:
    """Build one table row (with a 1-based ``check_in`` number) per stored check-in.

    Rows follow ``_sorted_checkins`` order. ``lang`` is currently unused.
    """
    ordered = _sorted_checkins(st.session_state.get(CHECKINS_KEY, []) or [])
    return [
        {
            "check_in": position,
            "logged_at": _timestamp_label(str(record.get("ts", ""))),
            "mood": record.get("mood"),
            "sleep_h": record.get("sleep_h"),
            "stress": record.get("stress"),
        }
        for position, record in enumerate(ordered, start=1)
    ]
def _render_checkin_history(lang: str) -> None:
    """Show every check-in as a table under a heading; render nothing when empty."""
    table = _checkin_history_rows(lang)
    if table:
        st.markdown(f"##### {t('dashboard_checkin_history_heading', lang)}")
        st.dataframe(table, use_container_width=True, hide_index=True)
def _render_crossmodule_footer(lang: str) -> None:
    """List activity recorded by other Saathi modules inside an expander.

    Only non-zero counters from ``_aggregate_session_signals`` are listed. When
    none are non-zero, a localized "nothing yet" caption is shown instead.
    """
    signals = _aggregate_session_signals()
    formatters = (
        ("chat_turns", lambda n: f"{n} chat turn(s) in {t('tab_chat', lang)}"),
        ("student_events", lambda n: f"{n} student support event(s)"),
        ("legal_lookups", lambda n: f"{n} legal lookup(s)"),
        ("soothe_poems", lambda n: f"{n} poem(s) generated"),
    )
    notes = [describe(signals[key]) for key, describe in formatters if signals[key]]
    with st.expander(t("crossmodule_footer_heading", lang)):
        for note in notes or [t("crossmodule_footer_none", lang)]:
            st.caption(note)
def _render_insights_dashboard(lang: str) -> None:
    """Render the Insights dashboard in a fixed order.

    The order is: stepped-care card and disclaimer; screener snapshot and
    timeline; check-in trend and history; journal history; distortion chart;
    cross-module footer. Sections are separated by dividers.
    """
    st.subheader(t("dashboard_header", lang))
    st.caption(t("dashboard_sub", lang))
    tier, rationale_key = _compute_stepped_care_recommendation(_aggregate_session_signals())
    _render_stepped_care_card(lang, tier, rationale_key)
    st.caption(t("screener_disclaimer", lang))

    # Each group is drawn in sequence and followed by a divider.
    grouped_sections = (
        (_render_screener_snapshot, _render_screener_timeline),
        (_render_checkin_trend, _render_checkin_history),
        (_render_journal_history,),
    )
    for group in grouped_sections:
        for draw in group:
            draw(lang)
        st.divider()

    st.markdown(f"##### {t('dashboard_distortion_heading', lang)}")
    _render_chart(lang)
    st.divider()
    _render_crossmodule_footer(lang)
def render(lang: str) -> None:
    """Module entry point: page header, explainer, section router, then the chosen section.

    Every section except the dashboard is preceded by the compact session
    snapshot. An unrecognised section key renders nothing after the router.
    """
    _init_state()
    st.header(t("journal_header", lang))
    st.caption(t("journal_sub", lang))
    _render_journal_explainer(lang)
    section = _render_section_router(lang)
    if section != "dashboard":
        _render_session_snapshot(lang)
    section_renderers = {
        "journal": _render_journal,
        "phq9": _render_phq9,
        "gad7": _render_gad7,
        "checkin": _render_daily_checkin,
        "dashboard": _render_insights_dashboard,
    }
    renderer = section_renderers.get(section)
    if renderer is not None:
        renderer(lang)
# ---------------------------------------------------------------------------
# Cross-module memory — read-only digest exported to Saathi Chat
# ---------------------------------------------------------------------------
#
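# Design contract:
# - This function is the ONLY way other modules may read Thought Diary
#   state. It returns an English digest string.
# - The digest carries counts, top category names, scores and trends, plus
#   a clipped excerpt (at most 200 characters) of each of the five most
#   recent journal entries and of their stored observation (at most 150
#   characters). Those excerpts are the user's own words, so treat the
#   digest as sensitive user data, not as anonymised aggregates.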
# - Returns "" (empty string) when the user has no journal data yet, so the
# calling module can insert the string into a prompt template unconditionally
# without breaking the layout when it's the user's first turn.
# - Deterministic: no LLM call, no randomness. Safe to invoke every chat turn.
# - English on purpose — this is a system-prompt block for the LLM, not for
# the user. The LLM is already instructed to respond in the user's language.
def get_cognitive_journal_context() -> str:
    """Return an English digest of the user's Thought Diary activity for Saathi Chat.

    Designed to be injected into Saathi Chat's system prompt so the chat module can
    reference the user's recent journal / screener / check-in activity without
    holding its own copy of that state.

    The digest contains:

    * the entry count and most common mood tag;
    * the top distortion types;
    * PHQ-9/GAD-7 history;
    * check-in averages and trends;
    * the current stepped-care tier;
    * one-line excerpts of up to the five most recent entries. Entry text is
      clipped to 200 characters and the stored observation to 150.

    Entry text is the user's own writing, so treat the result as sensitive.

    Returns:
        The digest: a ``#``-prefixed header followed by ``- `` bullet lines. Returns
        ``""`` when there is no journal, check-in or screener data yet.
    """
    entries = list(st.session_state.get(ENTRIES_KEY, []) or [])
    checkins = list(st.session_state.get(CHECKINS_KEY, []) or [])
    phq9_history = list(st.session_state.get(PHQ9_HISTORY_KEY, []) or [])
    gad7_history = list(st.session_state.get(GAD7_HISTORY_KEY, []) or [])
    if not (entries or checkins or phq9_history or gad7_history):
        return ""
    lines: List[str] = []

    # --- Journal entries + distortion pattern ---
    if entries:
        distortion_summary = _summarize_distortions(entries)
        crisis_flags = sum(1 for e in entries if _entry_needs_pro(e))
        mood_counter: Counter[str] = Counter()
        for e in entries:
            mood = _entry_mood(e)
            if mood:
                mood_counter[mood] += 1
        top_mood = mood_counter.most_common(1)
        top_mood_label = top_mood[0][0] if top_mood else None
        journal_line = f"- Journal entries this session: {len(entries)}"
        if top_mood_label:
            journal_line += f" (most common mood tag: {top_mood_label})"
        lines.append(journal_line)

        # --- Excerpts of the most recent entries so Chat can reference specifics ---
        # Entry text is free-form user input. _digest_excerpt flattens it to a single
        # line, so a multi-line entry cannot break the one-bullet-per-line layout or
        # start a line that mimics the "#" header instructions.
        recent = entries[-5:]  # last 5 entries max
        lines.append("- Recent journal entries (newest last):")
        for idx, entry in enumerate(recent, 1):
            text = _digest_excerpt(_entry_text(entry), 200)
            mood = _entry_mood(entry) or "unknown"
            distortion_types = [
                dtype
                for dtype in (_distortion_type(d) for d in _entry_distortions(entry))
                if dtype
            ]
            dtypes_str = ", ".join(distortion_types) if distortion_types else "none"
            observation = _digest_excerpt(_entry_summary(entry), 150)
            lines.append(
                f' {idx}. "{text}" '
                f"[mood: {mood}, distortions: {dtypes_str}] "
                f"Observation: {observation}"
            )
        if distortion_summary:
            top_type = distortion_summary["top_type"]
            top_count = distortion_summary["top_count"]
            top_label = DISTORTION_LABELS.get(top_type, top_type)
            total = distortion_summary["total_distortions"]
            dist_line = (
                f"- Top cognitive distortion: {top_label} "
                f"({top_count}x of {total} total distortions tagged)"
            )
            if distortion_summary.get("second_type"):
                second_label = DISTORTION_LABELS.get(
                    distortion_summary["second_type"],
                    distortion_summary["second_type"],
                )
                dist_line += f"; second most common: {second_label}"
            lines.append(dist_line)
        else:
            lines.append(
                "- No CBT distortions were tagged in those entries "
                "(the entries mostly describe feelings, not distorted thoughts)."
            )
        if crisis_flags:
            lines.append(
                f"- {crisis_flags} entry/entries were flagged as needing "
                "professional support (crisis language or LLM-flagged severity)."
            )

    # --- PHQ-9 / GAD-7 history ---
    for kind, label in (("phq9", "PHQ-9 (depression)"), ("gad7", "GAD-7 (anxiety)")):
        screener_summary = _summarize_screener_history(kind)
        if screener_summary:
            lines.append(_digest_screener_line(label, screener_summary))

    # --- Daily check-ins ---
    checkin_summary = _summarize_checkins(checkins)
    if checkin_summary and checkin_summary["count"] >= 1:
        def _f(value: Optional[float]) -> str:
            return "-" if value is None else f"{value:.1f}"
        if checkin_summary["count"] >= 2:
            lines.append(
                f"- Daily check-ins: {checkin_summary['count']} logged. "
                f"Mood avg {_f(checkin_summary['mood_avg'])}/10 "
                f"(trend: {checkin_summary['mood_trend']}), "
                f"sleep avg {_f(checkin_summary['sleep_avg'])} h, "
                f"stress avg {_f(checkin_summary['stress_avg'])}/10 "
                f"(trend: {checkin_summary['stress_trend']})."
            )
        else:
            lines.append(
                f"- Daily check-ins: 1 logged so far "
                f"(mood {_f(checkin_summary['mood_avg'])}/10, "
                f"sleep {_f(checkin_summary['sleep_avg'])} h, "
                f"stress {_f(checkin_summary['stress_avg'])}/10). "
                f"Not enough data for a trend yet."
            )

    # --- Stepped-care tier (deterministic rule-based) ---
    try:
        signals = _aggregate_session_signals()
        tier, _reason_key = _compute_stepped_care_recommendation(signals)
        tier_human = {
            "self_care": "self-care",
            "self_help": "self-help tools",
            "guided_support": "guided support",
            "urgent_professional": "urgent professional help",
        }.get(tier, tier)
        lines.append(f"- Current stepped-care recommendation: {tier_human}.")
    except Exception:
        # Stepped-care is nice-to-have; never fail the digest on a rule-engine bug.
        pass

    if not lines:
        return ""
    header = (
        "# What Saathi knows from this user's Thought Diary (this session only)\n"
        "# You may reference this naturally when it is relevant, but do NOT read\n"
        "# it out like a database dump. Use it to personalise your reply.\n"
        "# Never claim you can see journal entries the user did not mention —\n"
        "# only cite the patterns below."
    )
    return header + "\n" + "\n".join(lines)


def _digest_excerpt(text: Any, limit: int) -> str:
    """Flatten *text* to a single line and clip it to at most *limit* characters.

    Runs of whitespace (including newlines) collapse to single spaces, and ``None``
    becomes ``""``. Text longer than *limit* keeps its first ``limit - 3``
    characters (trailing whitespace stripped), followed by ``"..."``.
    """
    flat = " ".join(str(text or "").split())
    if len(flat) > limit:
        flat = flat[: limit - 3].rstrip() + "..."
    return flat


def _digest_screener_line(label: str, summary: Mapping[str, Any]) -> str:
    """Format one screener-history bullet for the chat digest.

    *summary* is the mapping returned by ``_summarize_screener_history``. This
    function reads its ``direction``, ``attempts``, ``latest``, ``max``,
    ``latest_band`` and ``delta`` keys. A ``"single"`` direction produces the
    one-attempt wording; any other direction produces the multi-attempt wording.
    """
    if summary["direction"] == "single":
        return (
            f"- {label}: 1 attempt, score "
            f"{summary['latest']}/{summary['max']} "
            f"(band: {summary['latest_band']})."
        )
    return (
        f"- {label}: {summary['attempts']} attempts, "
        f"latest {summary['latest']}/{summary['max']} "
        f"(band: {summary['latest_band']}), "
        f"change from first attempt: {int(summary['delta']):+d} "
        f"({summary['direction']})."
    )