File size: 3,971 Bytes
fff8e78
 
 
2f4273f
 
 
fff8e78
2f4273f
 
e2d1705
 
fff8e78
 
 
 
 
 
 
 
 
2f4273f
 
 
 
 
 
 
fff8e78
2f4273f
 
 
 
 
fff8e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f4273f
fff8e78
 
 
 
 
 
 
 
 
 
 
 
 
2f4273f
fff8e78
 
 
 
 
 
 
 
 
 
 
 
 
2f4273f
fff8e78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2f4273f
 
 
fff8e78
 
 
 
 
 
 
 
 
2f4273f
 
 
fff8e78
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""Session metadata store — SQLite so uvicorn --reload and multi-worker restarts don't kill sessions."""
import sqlite3
import json
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

import os as _os
# SQLite file is "unmask_sessions.db" inside $DATA_DIR; when DATA_DIR is unset
# the path is relative to the process's current working directory (".").
_DB_PATH = Path(_os.getenv("DATA_DIR", ".")) / "unmask_sessions.db"

# Bulk LangGraph fields are checkpointed by SqliteSaver — we only store lightweight metadata here.
# This is the allow-list applied by _session_to_data: keys of Session.state that
# are not listed here are dropped when a session is serialized.
_SLIM_KEYS = frozenset({
    "phase", "mastery_scores", "weak_topics", "diagnostic_complete",
    "consecutive_correct", "consecutive_incorrect", "current_topic",
    "study_focus", "learning_mode", "turn_count",
})

# Maximum session age used by _purge_stale. It is compared against
# session_start, which nothing in this module updates after creation, so the
# limit counts from creation rather than from last activity.
_SESSION_TTL_SEC = 7200  # 2 hours


@dataclass
class Session:
    """In-process record for one session, persisted as one row of ``sessions``.

    ``session_id`` and ``session_start`` are stored in their own columns; the
    remaining fields are serialized into the row's JSON ``data`` column by
    ``_session_to_data`` and restored by ``_row_to_session``.
    """

    # Primary key of the row; create_session generates it with uuid.uuid4().
    session_id: str
    # Defaults to time.time() at construction; _purge_stale compares it
    # against the TTL cutoff.
    session_start: float = field(default_factory=time.time)
    # Free-form state dict. Only the keys listed in _SLIM_KEYS are written to
    # SQLite; anything else lives in memory only as far as this module goes.
    state: dict[str, Any] = field(default_factory=dict)
    # NOTE(review): the diag_* fields look like diagnostic-question ordering
    # and progress; they are only stored and loaded here — confirm with callers.
    diag_order: list = field(default_factory=list)
    diag_total: int = 0
    diag_q_index: int = 0
    # The fields below are round-tripped verbatim through the JSON column;
    # their meaning is defined by code outside this file.
    warmup_done: bool = False
    study_focus: str = ""
    learning_mode: str = "text"
    last_diagram_concept: Optional[str] = None


def _conn() -> sqlite3.Connection:
    """Open the session database and make sure the ``sessions`` table exists.

    Returns:
        A connection to the SQLite file at ``_DB_PATH``. It is opened with
        ``check_same_thread=False``, which turns off sqlite3's check that the
        connection is only used from the thread that created it.
    """
    create_table_sql = """CREATE TABLE IF NOT EXISTS sessions (
        session_id TEXT PRIMARY KEY,
        session_start REAL,
        data TEXT
    )"""
    connection = sqlite3.connect(str(_DB_PATH), check_same_thread=False)
    # Idempotent thanks to IF NOT EXISTS, so this is safe on every start-up.
    connection.execute(create_table_sql)
    connection.commit()
    return connection


# Shared module-level connection: opened once when this module is imported and
# reused by every function below.
_db = _conn()


def _row_to_session(row) -> "Session":
    """Rebuild a ``Session`` from a ``(session_id, session_start, data)`` row.

    ``data`` (``row[2]``) is parsed as JSON. Any key missing from the parsed
    document falls back to the same default the ``Session`` dataclass declares
    for that field.
    """
    payload = json.loads(row[2])
    return Session(
        session_id=row[0],
        session_start=row[1],
        state=payload.get("state", {}),
        diag_order=payload.get("diag_order", []),
        diag_total=payload.get("diag_total", 0),
        diag_q_index=payload.get("diag_q_index", 0),
        warmup_done=payload.get("warmup_done", False),
        study_focus=payload.get("study_focus", ""),
        learning_mode=payload.get("learning_mode", "text"),
        last_diagram_concept=payload.get("last_diagram_concept"),
    )


def _session_to_data(sess: "Session") -> str:
    """Serialize the persistable part of *sess* to a JSON string.

    ``sess.state`` is reduced to the keys present in ``_SLIM_KEYS`` (a missing
    or empty state becomes ``{}``); the other stored fields are copied as-is.
    ``session_id`` and ``session_start`` are not included — they have their
    own columns.
    """
    # Only persist slim state keys — LangGraph checkpointer owns the full state
    full_state = sess.state or {}
    payload = {
        "state": {
            key: full_state[key] for key in full_state if key in _SLIM_KEYS
        },
    }
    # Insertion order here fixes the key order of the emitted JSON.
    for attr in (
        "diag_order",
        "diag_total",
        "diag_q_index",
        "warmup_done",
        "study_focus",
        "learning_mode",
        "last_diagram_concept",
    ):
        payload[attr] = getattr(sess, attr)
    return json.dumps(payload)


def _purge_stale() -> None:
    """Drop every session older than ``_SESSION_TTL_SEC`` from SQLite and the cache.

    Age is measured from ``session_start``. The same cutoff is applied to the
    database rows and to the in-memory ``_cache`` so the two cannot disagree
    about whether a session still exists.
    """
    cutoff = time.time() - _SESSION_TTL_SEC
    # `with _db` commits on success and rolls back if the DELETE raises.
    with _db:
        _db.execute("DELETE FROM sessions WHERE session_start < ?", (cutoff,))
    # Evict the matching cache entries too. Without this, expired sessions
    # would stay in memory indefinitely, get_session would keep returning
    # them, and save_session would silently update zero rows.
    # Build the list first: the dict must not change size while iterated.
    expired_ids = [
        sid for sid, cached in _cache.items() if cached.session_start < cutoff
    ]
    for sid in expired_ids:
        _cache.pop(sid, None)


# In-memory cache so hot path (get_session) doesn't hit SQLite on every SSE token.
# Maps session_id -> Session. It is a plain module-level dict, so each process
# has its own copy; get_session fills it lazily from SQLite on a miss.
_cache: dict[str, "Session"] = {}


def create_session() -> "Session":
    """Create a new session, persist it, and add it to the cache.

    Expired sessions are purged first via ``_purge_stale``.

    Returns:
        The new ``Session``, whose ``session_id`` is a fresh ``uuid4`` string.

    Raises:
        sqlite3.Error: If the INSERT fails. The transaction is rolled back and
            the session is not cached.
    """
    _purge_stale()
    sess = Session(session_id=str(uuid.uuid4()))
    # `with _db` commits on success and rolls back on error.
    with _db:
        _db.execute(
            "INSERT INTO sessions (session_id, session_start, data) VALUES (?, ?, ?)",
            (sess.session_id, sess.session_start, _session_to_data(sess)),
        )
    # Cache only once the row is committed: if the INSERT raised, the caller
    # never receives the id, so a cached entry would be an unreachable leak.
    _cache[sess.session_id] = sess
    return sess


def get_session(session_id: str) -> Optional["Session"]:
    """Return the session for *session_id*, or ``None`` if it is unknown.

    The in-memory cache is consulted first. On a miss the row is read from
    SQLite, converted with ``_row_to_session``, and stored in the cache before
    being returned.
    """
    try:
        return _cache[session_id]
    except KeyError:
        pass  # not cached in this process — fall through to SQLite

    select_sql = (
        "SELECT session_id, session_start, data FROM sessions WHERE session_id = ?"
    )
    row = _db.execute(select_sql, (session_id,)).fetchone()
    if not row:
        return None

    loaded = _row_to_session(row)
    _cache[session_id] = loaded
    return loaded


def save_session(session_id: str) -> None:
    """Write the cached session's current data back to its SQLite row.

    Only the ``data`` column is updated. If *session_id* is not in this
    process's cache the call does nothing.
    """
    cached = _cache.get(session_id)
    if not cached:
        # Nothing held in memory for this id, so there is nothing to write.
        return
    serialized = _session_to_data(cached)
    _db.execute(
        "UPDATE sessions SET data = ? WHERE session_id = ?",
        (serialized, session_id),
    )
    _db.commit()


def delete_session(session_id: str) -> None:
    """Remove a session from both the in-memory cache and SQLite.

    Unknown ids are not an error: a missing cache entry is skipped and the
    DELETE simply matches no rows.
    """
    if session_id in _cache:
        del _cache[session_id]
    delete_sql = "DELETE FROM sessions WHERE session_id = ?"
    _db.execute(delete_sql, (session_id,))
    _db.commit()