Spaces:
Running
Running
| from __future__ import annotations | |
| """ | |
| Enhanced Database Layer β Session-aware chat with LSTM state persistence. | |
| Production improvements over original: | |
| - WAL mode for concurrent read/write | |
| - Database indices for query performance | |
| - Session cleanup/expiry utility | |
| - Structured logging instead of silent failures | |
| """ | |
| import logging | |
| import os | |
| import sqlite3 | |
| import uuid | |
| from datetime import datetime, timedelta | |
| logger = logging.getLogger(__name__) | |
| DB = os.getenv("DB_PATH", "chat_memory.db") | |
| def _connect(): | |
| conn = sqlite3.connect(DB) | |
| # Enable WAL mode for better concurrent read/write performance | |
| conn.execute("PRAGMA journal_mode=WAL") | |
| conn.execute("PRAGMA busy_timeout=5000") | |
| return conn | |
| def init_db(): | |
| conn = _connect() | |
| c = conn.cursor() | |
| # Original chat history table (preserved for backward compat) | |
| c.execute(""" | |
| CREATE TABLE IF NOT EXISTS chat_history ( | |
| id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| role TEXT, | |
| content TEXT, | |
| session_id TEXT, | |
| timestamp DATETIME DEFAULT CURRENT_TIMESTAMP | |
| ) | |
| """) | |
| # Session management table | |
| c.execute(""" | |
| CREATE TABLE IF NOT EXISTS chat_sessions ( | |
| session_id TEXT PRIMARY KEY, | |
| created_at DATETIME DEFAULT CURRENT_TIMESTAMP, | |
| last_active DATETIME DEFAULT CURRENT_TIMESTAMP, | |
| lstm_state BLOB, | |
| context_summary TEXT DEFAULT '' | |
| ) | |
| """) | |
| # Add session_id column to chat_history if missing (migration) | |
| try: | |
| c.execute("SELECT session_id FROM chat_history LIMIT 1") | |
| except sqlite3.OperationalError: | |
| c.execute("ALTER TABLE chat_history ADD COLUMN session_id TEXT") | |
| # ββ Indices for query performance βββββββββββββββββββββββββββ | |
| c.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_chat_history_session_id | |
| ON chat_history(session_id) | |
| """) | |
| c.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_chat_history_timestamp | |
| ON chat_history(timestamp) | |
| """) | |
| c.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_chat_history_session_id_id | |
| ON chat_history(session_id, id) | |
| """) | |
| c.execute(""" | |
| CREATE INDEX IF NOT EXISTS idx_chat_sessions_last_active | |
| ON chat_sessions(last_active) | |
| """) | |
| conn.commit() | |
| conn.close() | |
| logger.info("Database initialised (WAL mode, indices created)") | |
| # ββ Session Management ββββββββββββββββββββββββββββββββββββββββββ | |
| def create_session() -> str: | |
| """Create a new chat session. Returns session_id.""" | |
| session_id = str(uuid.uuid4()) | |
| conn = _connect() | |
| conn.execute( | |
| "INSERT INTO chat_sessions (session_id) VALUES (?)", | |
| (session_id,) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| logger.debug("Session created: %s", session_id) | |
| return session_id | |
| def get_session(session_id: str) -> dict | None: | |
| """Get session details.""" | |
| conn = _connect() | |
| c = conn.cursor() | |
| c.execute( | |
| "SELECT session_id, created_at, last_active, context_summary " | |
| "FROM chat_sessions WHERE session_id = ?", | |
| (session_id,) | |
| ) | |
| row = c.fetchone() | |
| conn.close() | |
| if not row: | |
| return None | |
| return { | |
| "session_id": row[0], | |
| "created_at": row[1], | |
| "last_active": row[2], | |
| "context_summary": row[3] or "", | |
| } | |
| def touch_session(session_id: str): | |
| """Update last_active timestamp.""" | |
| conn = _connect() | |
| conn.execute( | |
| "UPDATE chat_sessions SET last_active = ? WHERE session_id = ?", | |
| (datetime.now().isoformat(), session_id) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| def save_lstm_state(session_id: str, state_bytes: bytes): | |
| """Save serialised LSTM hidden state for a session.""" | |
| conn = _connect() | |
| conn.execute( | |
| "UPDATE chat_sessions SET lstm_state = ?, last_active = ? WHERE session_id = ?", | |
| (state_bytes, datetime.now().isoformat(), session_id) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| def load_lstm_state(session_id: str) -> bytes | None: | |
| """Load serialised LSTM hidden state for a session.""" | |
| conn = _connect() | |
| c = conn.cursor() | |
| c.execute("SELECT lstm_state FROM chat_sessions WHERE session_id = ?", | |
| (session_id,)) | |
| row = c.fetchone() | |
| conn.close() | |
| return row[0] if row and row[0] else None | |
| def update_context_summary(session_id: str, summary: str): | |
| """Update the context summary for a session.""" | |
| conn = _connect() | |
| conn.execute( | |
| "UPDATE chat_sessions SET context_summary = ? WHERE session_id = ?", | |
| (summary, session_id) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| # ββ Message Management ββββββββββββββββββββββββββββββββββββββββββ | |
| def add_message(role: str, content: str, session_id: str = None): | |
| """Add a message to chat history.""" | |
| conn = _connect() | |
| conn.execute( | |
| "INSERT INTO chat_history (role, content, session_id) VALUES (?, ?, ?)", | |
| (role, content, session_id) | |
| ) | |
| conn.commit() | |
| conn.close() | |
| def get_recent_messages(limit: int = 6, session_id: str = None): | |
| """Get recent messages, optionally filtered by session.""" | |
| conn = _connect() | |
| c = conn.cursor() | |
| if session_id: | |
| c.execute( | |
| "SELECT role, content FROM chat_history " | |
| "WHERE session_id = ? ORDER BY id DESC LIMIT ?", | |
| (session_id, limit) | |
| ) | |
| else: | |
| c.execute( | |
| "SELECT role, content FROM chat_history ORDER BY id DESC LIMIT ?", | |
| (limit,) | |
| ) | |
| rows = c.fetchall() | |
| conn.close() | |
| return rows[::-1] | |
| def get_session_messages(session_id: str, limit: int = 50) -> list: | |
| """Get all messages for a session (for LSTM processing).""" | |
| conn = _connect() | |
| c = conn.cursor() | |
| c.execute( | |
| "SELECT role, content, timestamp FROM chat_history " | |
| "WHERE session_id = ? ORDER BY id ASC LIMIT ?", | |
| (session_id, limit) | |
| ) | |
| rows = c.fetchall() | |
| conn.close() | |
| return [{"role": r, "content": c, "timestamp": t} for r, c, t in rows] | |
| def clear_history(session_id: str = None): | |
| """Clear chat history. If session_id given, clear only that session.""" | |
| conn = _connect() | |
| if session_id: | |
| conn.execute("DELETE FROM chat_history WHERE session_id = ?", (session_id,)) | |
| conn.execute("DELETE FROM chat_sessions WHERE session_id = ?", (session_id,)) | |
| else: | |
| conn.execute("DELETE FROM chat_history") | |
| conn.execute("DELETE FROM chat_sessions") | |
| conn.commit() | |
| conn.close() | |
| # ββ Session Cleanup βββββββββββββββββββββββββββββββββββββββββββββ | |
| def cleanup_old_sessions(max_age_days: int = 30) -> int: | |
| """ | |
| Delete sessions and their messages older than max_age_days. | |
| Returns the number of sessions deleted. | |
| """ | |
| cutoff = (datetime.now() - timedelta(days=max_age_days)).isoformat() | |
| conn = _connect() | |
| c = conn.cursor() | |
| # Find expired sessions | |
| c.execute( | |
| "SELECT session_id FROM chat_sessions WHERE last_active < ?", | |
| (cutoff,) | |
| ) | |
| expired = [row[0] for row in c.fetchall()] | |
| if expired: | |
| placeholders = ",".join("?" * len(expired)) | |
| conn.execute( | |
| f"DELETE FROM chat_history WHERE session_id IN ({placeholders})", | |
| expired, | |
| ) | |
| conn.execute( | |
| f"DELETE FROM chat_sessions WHERE session_id IN ({placeholders})", | |
| expired, | |
| ) | |
| conn.commit() | |
| logger.info("Cleaned up %d expired sessions (older than %d days)", len(expired), max_age_days) | |
| conn.close() | |
| return len(expired) | |