File size: 2,762 Bytes
6ca2339
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""SQLite-based chat history persistence."""

import sqlite3
import json
import time
import os
import threading

# Location of the SQLite file. $DATA_DIR wins when set (on HF Spaces this is
# the persistent /data volume); otherwise the parent of this module's
# directory is used.
DB_DIR = os.environ.get("DATA_DIR", os.path.dirname(os.path.dirname(__file__)))
DB_PATH = os.path.join(
    DB_DIR,
    "chat_history.db",
)

# Per-thread holder for the connection opened by _get_conn(); Python's
# sqlite3 connections are not meant to be shared between threads by default,
# so each thread lazily opens and keeps its own.
_local = threading.local()


def _get_conn() -> sqlite3.Connection:
    """Return the calling thread's SQLite connection, opening it on first use.

    The connection is cached on the thread-local ``_local`` object, so each
    thread reuses a single connection to ``DB_PATH``. When a new connection
    is opened:

    * the directory containing ``DB_PATH`` is created if it is missing,
    * rows are returned as ``sqlite3.Row`` (index and column-name access),
    * the database is switched to WAL journaling.

    The connection is stored only after it is fully configured, so a failure
    during setup does not leave a half-initialised connection cached.
    """
    conn = getattr(_local, "conn", None)
    if conn is None:
        db_dir = os.path.dirname(DB_PATH)
        if db_dir:
            # sqlite3.connect() cannot create parent directories and fails with
            # "unable to open database file" if e.g. $DATA_DIR does not exist.
            os.makedirs(db_dir, exist_ok=True)
        conn = sqlite3.connect(DB_PATH, check_same_thread=False)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")  # Better concurrent reads
        _local.conn = conn
    return conn


def init_db():
    """Ensure the chat_messages table and its session_id index exist.

    Both statements use IF NOT EXISTS, so repeated calls are harmless. The
    work is done on the calling thread's connection and committed before
    returning.
    """
    create_table = """
        CREATE TABLE IF NOT EXISTS chat_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            timestamp REAL NOT NULL,
            created_at TEXT DEFAULT (datetime('now'))
        )
    """
    create_index = """
        CREATE INDEX IF NOT EXISTS idx_session_id 
        ON chat_messages(session_id)
    """
    conn = _get_conn()
    for statement in (create_table, create_index):
        conn.execute(statement)
    conn.commit()


def save_message(session_id: str, role: str, content: str):
    """Append one chat message to ``chat_messages`` and commit it.

    Args:
        session_id: Conversation identifier the message belongs to.
        role: Speaker role stored verbatim in the ``role`` column.
        content: Message text.

    The ``timestamp`` column receives ``time.time()`` at the moment of the
    call; ``created_at`` is filled by the column default.

    Raises:
        sqlite3.Error: if the insert fails (for example a NOT NULL violation
            when an argument is None). The transaction is rolled back before
            the error propagates. Without that rollback, the implicit
            transaction (and any write lock it took) would stay open on this
            thread's connection until some later commit.
    """
    conn = _get_conn()
    # Using the connection as a context manager commits on success and rolls
    # back on any exception; it does not close the connection.
    with conn:
        conn.execute(
            "INSERT INTO chat_messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)",
            (session_id, role, content, time.time()),
        )


def get_chat_history(session_id: str) -> list[dict]:
    """Return every stored message for ``session_id`` in insertion order.

    Each element is a dict with "role", "content" and "timestamp" keys. The
    "timestamp" value is taken from the row's ``created_at`` text column, not
    from the numeric ``timestamp`` column. A session with no messages yields
    an empty list.
    """
    query = "SELECT role, content, created_at FROM chat_messages WHERE session_id = ? ORDER BY id ASC"
    cursor = _get_conn().execute(query, (session_id,))
    history = []
    for role, content, created_at in cursor.fetchall():
        history.append({"role": role, "content": content, "timestamp": created_at})
    return history


def get_all_sessions() -> list[dict]:
    """Summarise every session that has at least one stored message.

    Returns one dict per session with keys "session_id", "message_count",
    "started_at" (earliest ``created_at``) and "last_active" (latest
    ``created_at``). Sessions are ordered by their most recently inserted
    message (highest ``id``) first.
    """
    summary_sql = """
        SELECT 
            session_id, 
            COUNT(*) as message_count,
            MIN(created_at) as started_at,
            MAX(created_at) as last_active
        FROM chat_messages 
        GROUP BY session_id 
        ORDER BY MAX(id) DESC
    """
    columns = ("session_id", "message_count", "started_at", "last_active")
    cursor = _get_conn().execute(summary_sql)
    return [{name: row[name] for name in columns} for row in cursor.fetchall()]


# Initialize database on import: init_db() opens a connection for the
# importing thread (via _get_conn) and creates the chat_messages table and its
# session_id index if missing, so the helpers above work immediately.
init_db()