File size: 3,747 Bytes
b59fc2c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""
SQLite-backed session store with TTL for RAG sessions.
Replaces the in-memory dict to persist sessions across server restarts
and prevent unbounded memory growth.
"""
import sqlite3
import json
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Optional, Dict, Any, Iterator

# Session lifetime used by SessionStore: a session whose updated_at timestamp
# is more than this many seconds in the past is treated as expired.
SESSION_TTL_SECONDS = 24 * 60 * 60  # 24 hours


class SessionStore:
    """SQLite-backed store of JSON-serialisable session dicts with a sliding TTL.

    Each row holds the JSON-encoded session data plus ``created_at`` and
    ``updated_at`` timestamps (seconds since the epoch, from ``time.time()``).
    A session expires once ``ttl_seconds`` have passed since its last ``set``.
    Expired rows are removed lazily by ``get`` or in bulk by
    ``cleanup_expired``.

    Every operation opens its own short-lived connection and closes it when
    done, so no connection outlives a single call.
    """

    def __init__(
        self,
        db_path: Optional[str] = None,
        ttl_seconds: Optional[float] = None,
    ):
        """Open (creating if needed) the session database.

        Args:
            db_path: Path to the SQLite file. Defaults to ``sessions.db`` in the
                parent of the directory containing this module.
            ttl_seconds: Session lifetime in seconds. When omitted, the
                module-level ``SESSION_TTL_SECONDS`` is used, read at call time.
        """
        if db_path is None:
            db_path = str(Path(__file__).parent.parent / "sessions.db")
        self.db_path = db_path
        self._ttl_seconds = ttl_seconds
        self._init_db()
        print(f"[SessionStore] SQLite store at: {self.db_path}")

    @property
    def ttl_seconds(self) -> float:
        """Effective session lifetime in seconds."""
        # Looked up per call so changes to the module constant still apply
        # when no explicit TTL was passed to the constructor.
        if self._ttl_seconds is None:
            return SESSION_TTL_SECONDS
        return self._ttl_seconds

    def _get_conn(self) -> sqlite3.Connection:
        """Open a new connection to the store. The caller is responsible for closing it."""
        conn = sqlite3.connect(self.db_path, check_same_thread=False)
        conn.execute("PRAGMA journal_mode=WAL")  # Better concurrent read performance
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def _connection(self) -> Iterator[sqlite3.Connection]:
        """Yield a connection wrapped in a transaction and always close it.

        ``sqlite3.Connection`` used as a context manager only commits (or rolls
        back on error); it does NOT close the connection. Without this wrapper,
        every call would leak an open connection until garbage collection.
        """
        conn = self._get_conn()
        try:
            with conn:  # commit on success, roll back on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``sessions`` table if it does not exist yet."""
        with self._connection() as conn:
            conn.execute(
                """
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    data       TEXT NOT NULL,
                    created_at REAL NOT NULL,
                    updated_at REAL NOT NULL
                )
                """
            )

    def get(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Return the session data, or None if it is missing or expired.

        An expired session is deleted as a side effect. The delete only
        matches the exact ``updated_at`` value that was read, so a concurrent
        ``set`` that refreshed the session in the meantime is not lost.
        """
        with self._connection() as conn:
            row = conn.execute(
                "SELECT data, updated_at FROM sessions WHERE session_id = ?",
                (session_id,),
            ).fetchone()
            if row is None:
                return None

            updated_at = row["updated_at"]
            if time.time() - updated_at > self.ttl_seconds:
                conn.execute(
                    "DELETE FROM sessions WHERE session_id = ? AND updated_at = ?",
                    (session_id, updated_at),
                )
                return None

            data_json = row["data"]
        return json.loads(data_json)

    def set(self, session_id: str, data: Dict[str, Any]):
        """Persist or update a session (all values must be JSON-serialisable).

        On update, ``created_at`` is preserved and ``updated_at`` is refreshed,
        which restarts the session's TTL.
        """
        now = time.time()
        # Serialise before touching the DB so a TypeError leaves it unchanged.
        data_json = json.dumps(data, ensure_ascii=False)
        with self._connection() as conn:
            conn.execute(
                """
                INSERT INTO sessions (session_id, data, created_at, updated_at)
                VALUES (?, ?, ?, ?)
                ON CONFLICT(session_id) DO UPDATE
                    SET data = excluded.data, updated_at = excluded.updated_at
                """,
                (session_id, data_json, now, now),
            )

    def delete(self, session_id: str):
        """Remove a session. Deleting an unknown id is a no-op."""
        with self._connection() as conn:
            conn.execute(
                "DELETE FROM sessions WHERE session_id = ?", (session_id,)
            )

    def cleanup_expired(self) -> int:
        """Delete every session not updated within the TTL.

        Returns:
            The number of rows deleted.
        """
        cutoff = time.time() - self.ttl_seconds
        with self._connection() as conn:
            deleted = conn.execute(
                "DELETE FROM sessions WHERE updated_at < ?", (cutoff,)
            ).rowcount
        if deleted:
            print(f"[SessionStore] Cleaned up {deleted} expired session(s)")
        return deleted