"""SQLite-based chat history persistence.""" import sqlite3 import json import time import os import threading # Database path — use /data on HF Spaces (persistent), fallback to local DB_DIR = os.getenv("DATA_DIR", os.path.dirname(os.path.dirname(__file__))) DB_PATH = os.path.join(DB_DIR, "chat_history.db") # Thread-local storage for SQLite connections (SQLite isn't thread-safe) _local = threading.local() def _get_conn() -> sqlite3.Connection: """Get a thread-local SQLite connection.""" if not hasattr(_local, "conn") or _local.conn is None: _local.conn = sqlite3.connect(DB_PATH, check_same_thread=False) _local.conn.row_factory = sqlite3.Row _local.conn.execute("PRAGMA journal_mode=WAL") # Better concurrent reads return _local.conn def init_db(): """Create tables if they don't exist.""" conn = _get_conn() conn.execute(""" CREATE TABLE IF NOT EXISTS chat_messages ( id INTEGER PRIMARY KEY AUTOINCREMENT, session_id TEXT NOT NULL, role TEXT NOT NULL, content TEXT NOT NULL, timestamp REAL NOT NULL, created_at TEXT DEFAULT (datetime('now')) ) """) conn.execute(""" CREATE INDEX IF NOT EXISTS idx_session_id ON chat_messages(session_id) """) conn.commit() def save_message(session_id: str, role: str, content: str): """Save a single chat message to the database.""" conn = _get_conn() conn.execute( "INSERT INTO chat_messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)", (session_id, role, content, time.time()), ) conn.commit() def get_chat_history(session_id: str) -> list[dict]: """Retrieve full chat history for a session.""" conn = _get_conn() rows = conn.execute( "SELECT role, content, created_at FROM chat_messages WHERE session_id = ? ORDER BY id ASC", (session_id,), ).fetchall() return [{"role": r["role"], "content": r["content"], "timestamp": r["created_at"]} for r in rows] def get_all_sessions() -> list[dict]: """Get a summary of all chat sessions.""" conn = _get_conn() rows = conn.execute(""" SELECT session_id, COUNT(*) as message_count, MIN(created_at) as started_at, MAX(created_at) as last_active FROM chat_messages GROUP BY session_id ORDER BY MAX(id) DESC """).fetchall() return [ { "session_id": r["session_id"], "message_count": r["message_count"], "started_at": r["started_at"], "last_active": r["last_active"], } for r in rows ] # Initialize database on import init_db()