Spaces:
Sleeping
Sleeping
| """SQLite-based chat history persistence.""" | |
| import sqlite3 | |
| import json | |
| import time | |
| import os | |
| import threading | |
# Database location. Point DATA_DIR at a persistent directory (e.g. /data on
# HF Spaces persistent storage); when it is unset *or empty*, fall back to the
# parent of this file's directory. Paths are made absolute so the DB location
# does not depend on the process's working directory.
_DEFAULT_DATA_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DB_DIR = os.path.abspath(os.getenv("DATA_DIR") or _DEFAULT_DATA_DIR)
DB_PATH = os.path.join(DB_DIR, "chat_history.db")

# Per-thread connection cache: each thread lazily opens and reuses its own
# sqlite3.Connection instead of sharing one across threads (see _get_conn).
_local = threading.local()
def _get_conn() -> sqlite3.Connection:
    """Return this thread's SQLite connection, opening it on first use.

    The connection is cached on the thread-local ``_local`` so every call from
    the same thread reuses it. New connections return ``sqlite3.Row`` rows
    (column access by name) and switch the database to WAL journaling.

    Raises:
        sqlite3.Error / OSError: if the directory or database cannot be opened
        or configured; nothing is cached in that case, so the next call retries.
    """
    conn = getattr(_local, "conn", None)
    if conn is not None:
        return conn

    db_dir = os.path.dirname(DB_PATH)
    if db_dir:
        # sqlite3 creates the database file but not missing parent directories
        # (e.g. a fresh DATA_DIR); without this, connect() fails with
        # "unable to open database file".
        os.makedirs(db_dir, exist_ok=True)

    # check_same_thread=False is kept from the original setup; the connection is
    # still only handed out to the thread that opened it (thread-local cache).
    conn = sqlite3.connect(DB_PATH, check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA journal_mode=WAL")  # Better concurrent reads
    except Exception:
        # Don't leak or cache a half-configured connection.
        conn.close()
        raise

    # Cache only after configuration succeeded.
    _local.conn = conn
    return conn
def init_db():
    """Ensure the ``chat_messages`` table and its session index exist.

    Every statement uses ``IF NOT EXISTS``, so repeated calls are harmless.
    """
    connection = _get_conn()
    schema_statements = (
        """
        CREATE TABLE IF NOT EXISTS chat_messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            timestamp REAL NOT NULL,
            created_at TEXT DEFAULT (datetime('now'))
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_session_id
        ON chat_messages(session_id)
        """,
    )
    for ddl in schema_statements:
        connection.execute(ddl)
    connection.commit()
def save_message(session_id: str, role: str, content: str):
    """Append one chat message for ``session_id`` and commit it.

    Args:
        session_id: Identifier of the conversation the message belongs to.
        role: Speaker role, stored verbatim (not validated here).
        content: Message text.

    The numeric ``timestamp`` column is set to ``time.time()`` (epoch seconds);
    ``created_at`` is filled by the column default.

    Raises:
        sqlite3.Error: if the INSERT fails (e.g. ``IntegrityError`` when a value
        is None, since all three columns are NOT NULL). The transaction is
        rolled back before the exception propagates.
    """
    conn = _get_conn()
    # ``with conn`` commits on success and rolls back on error. With a bare
    # execute() + commit(), a failed INSERT would leave the implicitly opened
    # transaction active on this long-lived thread-local connection, which can
    # keep other connections from writing until this thread commits again.
    with conn:
        conn.execute(
            "INSERT INTO chat_messages (session_id, role, content, timestamp) VALUES (?, ?, ?, ?)",
            (session_id, role, content, time.time()),
        )
def get_chat_history(session_id: str) -> list[dict]:
    """Return all stored messages for ``session_id`` in insertion order.

    Each item is ``{"role": ..., "content": ..., "timestamp": ...}``. Note that
    "timestamp" carries the row's ``created_at`` text (SQLite
    ``datetime('now')``, UTC), not the numeric ``timestamp`` column.
    An unknown session yields an empty list.
    """
    sql = (
        "SELECT role, content, created_at FROM chat_messages "
        "WHERE session_id = ? ORDER BY id ASC"
    )
    records = _get_conn().execute(sql, (session_id,)).fetchall()
    return [
        {"role": role, "content": content, "timestamp": created_at}
        for role, content, created_at in records
    ]
def get_all_sessions() -> list[dict]:
    """Summarise every chat session, most recently written first.

    Returns one dict per session with keys ``session_id``, ``message_count``,
    ``started_at`` and ``last_active``. The last two are the MIN/MAX of the
    session rows' ``created_at`` text. Sessions are ordered by their newest
    row id, descending.
    """
    summary_sql = """
        SELECT
            session_id,
            COUNT(*) as message_count,
            MIN(created_at) as started_at,
            MAX(created_at) as last_active
        FROM chat_messages
        GROUP BY session_id
        ORDER BY MAX(id) DESC
    """
    cursor = _get_conn().execute(summary_sql)
    # Rows are sqlite3.Row (set in _get_conn): keys() yields the column aliases
    # in SELECT order, so dict(row) builds the same mapping as listing them.
    return [dict(row) for row in cursor.fetchall()]
# Initialize database on import: create the schema up front so the functions
# above can assume the chat_messages table exists. As a side effect this opens
# (and caches) a connection for the importing thread.
init_db()