"""Session state management for REFRAME — all persisted via gr.BrowserState.""" from __future__ import annotations import json import uuid from dataclasses import asdict, dataclass, field from datetime import datetime @dataclass class ThoughtCard: id: str = field(default_factory=lambda: str(uuid.uuid4())[:8]) date: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) situation: str = "" automatic_thought: str = "" distortions: list[str] = field(default_factory=list) evidence_for: list[str] = field(default_factory=list) evidence_against: list[str] = field(default_factory=list) balanced_thought: str = "" emotion_before: tuple[str, int] = ("", 0) emotion_after: tuple[str, int] = ("", 0) @dataclass class Experiment: description: str = "" set_date: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d")) status: str = "pending" # pending | completed | skipped outcome: str = "" insight: str = "" @dataclass class SessionState: session_count: int = 0 cards: list[dict] = field(default_factory=list) experiments: list[dict] = field(default_factory=list) distortion_counts: dict[str, int] = field(default_factory=dict) mood_log: list[tuple[str, int]] = field(default_factory=list) last_summary: str = "" days_since_last: int = 0 mood_trend: str = "" conversation_summaries: list[str] = field(default_factory=list) def serialize_session(state: SessionState) -> str: """Serialize session to JSON string for BrowserState.""" return json.dumps(asdict(state)) def deserialize_session(data: str | None) -> SessionState: """Deserialize session from BrowserState JSON string.""" if not data: return SessionState() try: d = json.loads(data) if isinstance(data, str) else data return SessionState(**d) except (json.JSONDecodeError, TypeError): return SessionState() def save_card(state: SessionState, card: ThoughtCard) -> SessionState: """Save a completed thought card to session.""" state.cards.append(asdict(card)) # Update distortion counts for d in card.distortions: state.distortion_counts[d] = state.distortion_counts.get(d, 0) + 1 # Cap at max from config import MAX_CARDS if len(state.cards) > MAX_CARDS: state.cards = state.cards[-MAX_CARDS:] return state def save_experiment(state: SessionState, experiment: Experiment) -> SessionState: """Save a behavioral experiment to session.""" state.experiments.append(asdict(experiment)) return state def start_new_session(state: SessionState) -> SessionState: """Increment session count and calculate days since last.""" state.session_count += 1 return state def get_session_context(state: SessionState) -> dict: """Get session data formatted for prompt injection.""" return { "session_count": state.session_count, "cards": state.cards, "card_count": len(state.cards), "experiments": state.experiments, "distortion_counts": state.distortion_counts, "last_summary": state.last_summary, "days_since_last": state.days_since_last, "mood_trend": state.mood_trend, }