Spaces:
Sleeping
Sleeping
Venkatesh Rajagopal
REFRAME: live CBT studio — fine-tuned Gemma 12B on Modal + Cohere voice (ZeroGPU)
4ae4ae8 | """Session state management for REFRAME — all persisted via gr.BrowserState.""" | |
| from __future__ import annotations | |
| import json | |
| import uuid | |
| from dataclasses import asdict, dataclass, field | |
| from datetime import datetime | |
@dataclass
class ThoughtCard:
    """One completed CBT thought record.

    ``save_card`` stores instances in ``SessionState.cards`` as plain dicts
    (via ``asdict``), so every field must be JSON-serializable.
    """

    # Short random identifier: the first 8 characters of a UUID4 string.
    id: str = field(default_factory=lambda: str(uuid.uuid4())[:8])
    # Creation date formatted "YYYY-MM-DD" from naive local time.
    date: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
    situation: str = ""
    automatic_thought: str = ""
    # Distortion labels; save_card adds each one to SessionState.distortion_counts.
    distortions: list[str] = field(default_factory=list)
    evidence_for: list[str] = field(default_factory=list)
    evidence_against: list[str] = field(default_factory=list)
    balanced_thought: str = ""
    # (emotion label, intensity) pairs — NOTE(review): intensity scale not shown here; confirm.
    # After a JSON round trip these come back as 2-element lists, not tuples.
    emotion_before: tuple[str, int] = ("", 0)
    emotion_after: tuple[str, int] = ("", 0)
@dataclass
class Experiment:
    """A behavioral experiment assigned to the user during a session.

    ``save_experiment`` stores instances in ``SessionState.experiments`` as
    plain dicts (via ``asdict``).
    """

    description: str = ""
    # Date the experiment was set, formatted "YYYY-MM-DD" from naive local time.
    set_date: str = field(default_factory=lambda: datetime.now().strftime("%Y-%m-%d"))
    status: str = "pending"  # pending | completed | skipped
    outcome: str = ""
    insight: str = ""
@dataclass
class SessionState:
    """Everything REFRAME persists about a user between visits.

    Written out by ``serialize_session`` and rebuilt by
    ``deserialize_session``. Every field has a default, so an empty or
    partial payload still produces a valid state. Cards and experiments are
    kept as plain dicts, and any tuples come back as lists after the JSON
    round trip.
    """

    # Incremented by start_new_session.
    session_count: int = 0
    # ThoughtCard dicts, oldest first; save_card keeps only the newest MAX_CARDS.
    cards: list[dict] = field(default_factory=list)
    # Experiment dicts in the order they were saved.
    experiments: list[dict] = field(default_factory=list)
    # Distortion label -> cumulative count; not reduced when old cards are trimmed.
    distortion_counts: dict[str, int] = field(default_factory=dict)
    # NOTE(review): meaning of each (str, int) pair is not shown in this module; confirm.
    mood_log: list[tuple[str, int]] = field(default_factory=list)
    last_summary: str = ""
    # Not computed anywhere in this module — TODO confirm where it is set.
    days_since_last: int = 0
    mood_trend: str = ""
    conversation_summaries: list[str] = field(default_factory=list)
def serialize_session(state: SessionState) -> str:
    """Encode *state* as the JSON text kept in ``gr.BrowserState``.

    ``asdict`` first converts the dataclass recursively into plain dicts,
    lists and tuples. ``json.dumps`` then writes tuples as JSON arrays.
    """
    as_plain = asdict(state)
    encoded = json.dumps(as_plain)
    return encoded
def deserialize_session(data: str | dict | None) -> SessionState:
    """Rebuild a ``SessionState`` from the value stored in ``gr.BrowserState``.

    Args:
        data: The JSON string produced by ``serialize_session``, an
            already-decoded dict, or ``None``/empty on a first visit.

    Returns:
        The decoded session. Keys that are not ``SessionState`` fields, for
        example left over from an older or newer schema, are ignored. A
        fresh ``SessionState`` is returned when ``data`` is empty, is not
        valid JSON, or does not decode to a JSON object.
    """
    from dataclasses import fields

    if not data:
        return SessionState()
    try:
        decoded = json.loads(data) if isinstance(data, str) else data
    except json.JSONDecodeError:
        return SessionState()
    # A JSON array, number, etc. cannot describe a session.
    if not isinstance(decoded, dict):
        return SessionState()
    # Filter to known fields. Passing an unexpected keyword would raise
    # TypeError and previously discarded the user's whole history.
    known = {f.name for f in fields(SessionState)}
    return SessionState(**{k: v for k, v in decoded.items() if k in known})
def save_card(state: SessionState, card: ThoughtCard) -> SessionState:
    """Append *card* to the session, tally its distortions, and cap history.

    The card is stored as a plain dict so the session stays JSON-serializable.
    Distortion tallies are cumulative: dropping the oldest cards beyond
    ``config.MAX_CARDS`` does not reduce them. Mutates and returns *state*.
    """
    state.cards.append(asdict(card))

    tally = state.distortion_counts
    for label in card.distortions:
        tally[label] = tally.get(label, 0) + 1

    # Imported at call time — presumably to avoid an import cycle; TODO confirm.
    from config import MAX_CARDS

    cap = MAX_CARDS
    if len(state.cards) > cap:
        # Keep only the newest `cap` cards (rebinds to a new list).
        state.cards = state.cards[-cap:]
    return state
def save_experiment(state: SessionState, experiment: Experiment) -> SessionState:
    """Record *experiment* on the session as a plain dict; mutates and returns *state*."""
    record = asdict(experiment)
    state.experiments.append(record)
    return state
def start_new_session(state: SessionState) -> SessionState:
    """Mark the start of a new visit by incrementing ``session_count``.

    Mutates and returns *state*. It does not update ``days_since_last`` or
    ``mood_trend``; those fields keep whatever value was last stored.
    """
    # NOTE(review): SessionState stores no last-visit timestamp, so
    # days_since_last cannot be derived here — TODO confirm where it is set.
    state.session_count += 1
    return state
def get_session_context(state: SessionState) -> dict:
    """Collect the session fields used for prompt injection into one dict.

    Values are the state's own objects (not copies), plus a derived
    ``card_count``. ``mood_log`` and ``conversation_summaries`` are not
    included.
    """
    context = {name: getattr(state, name) for name in ("session_count", "cards")}
    context["card_count"] = len(context["cards"])
    for name in (
        "experiments",
        "distortion_counts",
        "last_summary",
        "days_since_last",
        "mood_trend",
    ):
        context[name] = getattr(state, name)
    return context