# reframe / session.py (3.27 kB)
# Venkatesh Rajagopal — commit 4ae4ae8:
#   REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
"""Session state management for REFRAME — all persisted via gr.BrowserState."""
from __future__ import annotations

import json
import uuid
from dataclasses import asdict, dataclass, field, fields
from datetime import datetime
@dataclass
class ThoughtCard:
    """A single thought card: a situation, the automatic thought it
    triggered, the distortions spotted, the evidence on each side and the
    balanced alternative, plus emotion ratings before and after.

    Cards are flattened with ``dataclasses.asdict`` before being appended to
    ``SessionState.cards``, and the session is later JSON-encoded, so every
    field holds JSON-compatible data. Field order defines the positional
    constructor and the key order of ``asdict``; do not reorder fields.
    """

    # Short random identifier: the first 8 hex digits of a fresh UUID4.
    id: str = field(default_factory=lambda: uuid.uuid4().hex[:8])
    # Local calendar date the card object was created, as YYYY-MM-DD.
    date: str = field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d")
    )

    situation: str = ""
    automatic_thought: str = ""
    # Distortion labels; each one bumps SessionState.distortion_counts
    # when the card is saved with save_card.
    distortions: list[str] = field(default_factory=list)
    evidence_for: list[str] = field(default_factory=list)
    evidence_against: list[str] = field(default_factory=list)
    balanced_thought: str = ""

    # Presumably (emotion label, intensity) pairs — TODO confirm scale.
    emotion_before: tuple[str, int] = ("", 0)
    emotion_after: tuple[str, int] = ("", 0)
@dataclass
class Experiment:
    """A behavioral experiment record and its eventual result.

    Stored in ``SessionState.experiments`` as a plain dict produced by
    ``dataclasses.asdict`` (see ``save_experiment``). Field order defines the
    positional constructor and the key order of ``asdict``; keep it stable.
    """

    # Free-text description of the experiment.
    description: str = ""
    # Local calendar date the record was created, as YYYY-MM-DD.
    set_date: str = field(
        default_factory=lambda: datetime.now().strftime("%Y-%m-%d")
    )
    # Intended values: "pending" | "completed" | "skipped".
    # Nothing in this class enforces that set.
    status: str = "pending"
    # Free-text result and takeaway; presumably filled in once the
    # experiment is no longer pending — TODO confirm against callers.
    outcome: str = ""
    insight: str = ""
@dataclass
class SessionState:
    """All per-user REFRAME state that is persisted through gr.BrowserState.

    The object is JSON-encoded by ``serialize_session`` and rebuilt by
    ``deserialize_session``, so every field holds JSON-compatible data.
    That is why cards and experiments are kept as plain dicts rather than
    ``ThoughtCard`` / ``Experiment`` instances. Field order defines the
    positional constructor; do not reorder fields.
    """

    # -- progress ---------------------------------------------------------
    # Incremented by start_new_session.
    session_count: int = 0

    # -- saved work (each entry is dataclasses.asdict of the source object) --
    cards: list[dict] = field(default_factory=list)
    experiments: list[dict] = field(default_factory=list)
    # Distortion label -> number of times it appeared on a saved card.
    distortion_counts: dict[str, int] = field(default_factory=dict)

    # -- mood / continuity context ----------------------------------------
    # Presumably (label, score) pairs — TODO confirm. After a JSON
    # round-trip the inner tuples come back as lists.
    mood_log: list[tuple[str, int]] = field(default_factory=list)
    last_summary: str = ""
    days_since_last: int = 0
    mood_trend: str = ""
    conversation_summaries: list[str] = field(default_factory=list)
def serialize_session(state: SessionState) -> str:
    """Encode *state* as the JSON string stored in gr.BrowserState.

    ``dataclasses.asdict`` first turns the dataclass into plain containers
    (recursively); ``json.dumps`` then emits tuples as JSON arrays and
    escapes non-ASCII characters (its defaults).
    """
    as_plain_dict = asdict(state)
    encoded = json.dumps(as_plain_dict)
    return encoded
def deserialize_session(data: str | dict | None) -> SessionState:
    """Rebuild a SessionState from the value stored in gr.BrowserState.

    Args:
        data: A JSON string produced by ``serialize_session``, an
            already-decoded dict, or ``None``/empty when nothing is stored.

    Returns:
        The restored session. A fresh default ``SessionState`` is returned
        when *data* is empty, is not valid JSON, or does not decode to a
        JSON object.

    Keys that are not fields of ``SessionState`` are ignored instead of
    causing the whole stored session to be discarded, so state written by a
    build with a different set of fields still loads; fields missing from
    *data* take their defaults.
    """
    if not data:
        return SessionState()

    if isinstance(data, str):
        try:
            decoded = json.loads(data)
        except json.JSONDecodeError:
            return SessionState()
    else:
        decoded = data

    # Top-level JSON that is not an object (list, number, ...) cannot map
    # onto the dataclass.
    if not isinstance(decoded, dict):
        return SessionState()

    # Passing an unknown key to SessionState(**...) raises TypeError; drop
    # such keys so schema drift does not wipe the user's saved data.
    known = {f.name for f in fields(SessionState)}
    return SessionState(**{k: v for k, v in decoded.items() if k in known})
def save_card(
    state: SessionState,
    card: ThoughtCard,
    *,
    max_cards: int | None = None,
) -> SessionState:
    """Append a completed thought card to the session and cap the history.

    The card is stored as a plain dict (``dataclasses.asdict``) so the
    session stays JSON-serializable. Each distortion label on the card
    increments ``state.distortion_counts``. Those counts are cumulative and
    are not reduced when old cards are dropped by the cap.

    Args:
        state: Session to update; mutated in place and also returned.
        card: The completed card to record.
        max_cards: Maximum number of cards to keep; the oldest are dropped
            first. Defaults to ``config.MAX_CARDS``.

    Returns:
        The same ``state`` object.
    """
    state.cards.append(asdict(card))

    for label in card.distortions:
        state.distortion_counts[label] = state.distortion_counts.get(label, 0) + 1

    if max_cards is None:
        # Kept as a lazy, function-level import as in the original code —
        # presumably to avoid an import cycle; TODO confirm.
        from config import MAX_CARDS

        max_cards = MAX_CARDS

    # Compute the overflow explicitly: ``cards[-max_cards:]`` keeps *every*
    # card when the limit is 0, because ``-0 == 0``.
    excess = len(state.cards) - max_cards
    if excess > 0:
        state.cards = state.cards[excess:]
    return state
def save_experiment(state: SessionState, experiment: Experiment) -> SessionState:
    """Record a behavioral experiment on the session.

    The experiment is flattened to a plain dict with ``dataclasses.asdict``
    and appended to ``state.experiments``; *state* is mutated in place and
    returned for convenience. No cap is applied here.
    """
    entry = asdict(experiment)
    state.experiments.append(entry)
    return state
def start_new_session(state: SessionState) -> SessionState:
    """Mark the start of a new session by incrementing ``session_count``.

    Only the counter changes. ``days_since_last`` is not computed here:
    ``SessionState`` has no field recording when the previous session took
    place, so that value is left exactly as it was.
    TODO: persist a last-session date if days-since-last is needed.

    Args:
        state: Session to update; mutated in place.

    Returns:
        The same ``state`` object.
    """
    state.session_count += 1
    return state
def get_session_context(state: SessionState) -> dict:
    """Collect the session values used for prompt injection.

    The returned dict is new, but its list/dict values are the session's own
    objects rather than copies, so mutating them mutates *state*.
    ``mood_log`` and ``conversation_summaries`` are not included.
    """
    cards = state.cards
    return dict(
        session_count=state.session_count,
        cards=cards,
        card_count=len(cards),
        experiments=state.experiments,
        distortion_counts=state.distortion_counts,
        last_summary=state.last_summary,
        days_since_last=state.days_since_last,
        mood_trend=state.mood_trend,
    )