Spaces:
Sleeping
Sleeping
| """Session metadata store — SQLite so uvicorn --reload and multi-worker restarts don't kill sessions.""" | |
| import sqlite3 | |
| import json | |
| import time | |
| import uuid | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| import os as _os | |
# SQLite file lives under $DATA_DIR when that variable is set, else the current directory.
_DB_PATH = Path(_os.environ.get("DATA_DIR", "."), "unmask_sessions.db")

# Bulk LangGraph fields are checkpointed by SqliteSaver; only these lightweight
# state keys are written to the sessions table.
_SLIM_KEYS = frozenset((
    "phase",
    "mastery_scores",
    "weak_topics",
    "diagnostic_complete",
    "consecutive_correct",
    "consecutive_incorrect",
    "current_topic",
    "study_focus",
    "learning_mode",
    "turn_count",
))

# Rows whose session_start is older than this many seconds are purged (2 hours).
_SESSION_TTL_SEC = 2 * 60 * 60
@dataclass
class Session:
    """Lightweight per-session metadata record.

    Instances are built with keyword arguments (``Session(session_id=...)``)
    and rely on ``field(default_factory=...)`` defaults, so the class must be
    a dataclass: without the decorator the constructor rejects keyword
    arguments and the defaults would be raw ``Field`` objects.

    Only ``session_id`` is required; every other field has a default.
    """

    # Unique identifier; used as the primary key of the ``sessions`` table.
    session_id: str
    # Wall-clock creation time from ``time.time()``; also the TTL reference point.
    session_start: float = field(default_factory=time.time)
    # Free-form state mapping; a fresh dict per instance.
    state: dict[str, Any] = field(default_factory=dict)
    # Diagnostic bookkeeping (presumably question order / count / cursor; confirm with callers).
    diag_order: list = field(default_factory=list)
    diag_total: int = 0
    diag_q_index: int = 0
    warmup_done: bool = False
    study_focus: str = ""
    learning_mode: str = "text"
    last_diagram_concept: Optional[str] = None
def _conn() -> sqlite3.Connection:
    """Open the session database and make sure the ``sessions`` table exists.

    The connection is opened with ``check_same_thread=False`` so the single
    module-level handle can be used from threads other than the one that
    created it.

    Returns:
        An open ``sqlite3.Connection`` to ``_DB_PATH``.
    """
    create_table_sql = """CREATE TABLE IF NOT EXISTS sessions (
session_id TEXT PRIMARY KEY,
session_start REAL,
data TEXT
)"""
    connection = sqlite3.connect(str(_DB_PATH), check_same_thread=False)
    connection.execute(create_table_sql)
    connection.commit()
    return connection


# Single shared connection, opened at import time.
_db = _conn()
def _row_to_session(row) -> "Session":
    """Rebuild a ``Session`` from a ``(session_id, session_start, data)`` row.

    ``data`` is the JSON document produced by ``_session_to_data``; any key
    missing from it falls back to the same default the ``Session`` class uses.
    """
    payload = json.loads(row[2])
    restored = Session(session_id=row[0], session_start=row[1])
    # (attribute name, fallback when the key is absent from the stored JSON).
    # Built per call so the mutable fallbacks are never shared between sessions.
    fallbacks = (
        ("state", {}),
        ("diag_order", []),
        ("diag_total", 0),
        ("diag_q_index", 0),
        ("warmup_done", False),
        ("study_focus", ""),
        ("learning_mode", "text"),
        ("last_diagram_concept", None),
    )
    for attr, fallback in fallbacks:
        setattr(restored, attr, payload.get(attr, fallback))
    return restored
def _session_to_data(sess: "Session") -> str:
    """Serialize the persistable part of ``sess`` to a JSON string.

    Only state keys listed in ``_SLIM_KEYS`` are kept; the LangGraph
    checkpointer owns the full state. A ``None`` or empty ``sess.state``
    serializes as an empty mapping.
    """
    slim_state = {}
    for key, value in (sess.state or {}).items():
        if key in _SLIM_KEYS:
            slim_state[key] = value

    payload = {"state": slim_state}
    # Plain attributes copied verbatim, in the order they appear in the stored JSON.
    for attr in (
        "diag_order",
        "diag_total",
        "diag_q_index",
        "warmup_done",
        "study_focus",
        "learning_mode",
        "last_diagram_concept",
    ):
        payload[attr] = getattr(sess, attr)
    return json.dumps(payload)
def _purge_stale() -> None:
    """Drop sessions older than ``_SESSION_TTL_SEC`` from SQLite and the cache.

    Age is measured from ``session_start`` (creation time), not from last
    activity.
    """
    cutoff = time.time() - _SESSION_TTL_SEC
    _db.execute("DELETE FROM sessions WHERE session_start < ?", (cutoff,))
    _db.commit()
    # Evict the same sessions from the in-memory cache. Otherwise get_session()
    # keeps returning a session whose row is gone, save_session()'s UPDATE
    # silently matches nothing, and the cache grows without bound.
    expired_ids = [
        sid for sid, cached in _cache.items() if cached.session_start < cutoff
    ]
    for sid in expired_ids:
        _cache.pop(sid, None)


# In-memory cache so hot path (get_session) doesn't hit SQLite on every SSE token
_cache: dict[str, "Session"] = {}
def create_session() -> "Session":
    """Create a new session, cache it, and insert its row into SQLite.

    Stale sessions are purged first. The new session gets a random UUID4
    string as its id.

    Returns:
        The newly created ``Session``.
    """
    _purge_stale()
    new_id = str(uuid.uuid4())
    created = Session(session_id=new_id)
    _cache[created.session_id] = created
    insert_params = (
        created.session_id,
        created.session_start,
        _session_to_data(created),
    )
    _db.execute(
        "INSERT INTO sessions (session_id, session_start, data) VALUES (?, ?, ?)",
        insert_params,
    )
    _db.commit()
    return created
def get_session(session_id: str) -> Optional["Session"]:
    """Return the session for ``session_id``, or ``None`` if it is unknown.

    The in-memory cache is consulted first; on a miss the row is loaded from
    SQLite and stored in the cache for later calls.
    """
    try:
        return _cache[session_id]
    except KeyError:
        pass  # cache miss, fall through to the database

    lookup_sql = (
        "SELECT session_id, session_start, data FROM sessions WHERE session_id = ?"
    )
    row = _db.execute(lookup_sql, (session_id,)).fetchone()
    if not row:
        return None

    loaded = _row_to_session(row)
    _cache[session_id] = loaded
    return loaded
def save_session(session_id: str) -> None:
    """Write the cached session's current data back to its SQLite row.

    Does nothing when ``session_id`` is not in the in-memory cache. Only the
    ``data`` column is updated; ``session_start`` is left as stored.
    """
    cached = _cache.get(session_id)
    if cached:
        serialized = _session_to_data(cached)
        _db.execute(
            "UPDATE sessions SET data = ? WHERE session_id = ?",
            (serialized, session_id),
        )
        _db.commit()
def delete_session(session_id: str) -> None:
    """Remove ``session_id`` from the cache and delete its SQLite row.

    Unknown ids are ignored in both places.
    """
    # Cache first, then the database: same order as before.
    if session_id in _cache:
        del _cache[session_id]
    delete_sql = "DELETE FROM sessions WHERE session_id = ?"
    _db.execute(delete_sql, (session_id,))
    _db.commit()