REESEPALLATH's picture
Initial deploy
7cc1e9a
from __future__ import annotations
"""
Enhanced Database Layer — Session-aware chat with LSTM state persistence.
Production improvements over original:
- WAL mode for concurrent read/write
- Database indices for query performance
- Session cleanup/expiry utility
- Structured logging instead of silent failures
"""
import logging
import os
import sqlite3
import uuid
from datetime import datetime, timedelta
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Path of the SQLite database file; taken from the DB_PATH environment
# variable, falling back to "chat_memory.db".
DB = os.getenv("DB_PATH", "chat_memory.db")
def _connect() -> sqlite3.Connection:
    """Open a configured connection to the SQLite database at ``DB``.

    Each new connection is switched to WAL journaling and given a
    5000 ms busy timeout before it is handed back.

    Returns:
        An open ``sqlite3.Connection``. The caller is responsible for
        closing it.

    Raises:
        sqlite3.Error: If the database cannot be opened or a PRAGMA fails.
    """
    conn = sqlite3.connect(DB)
    try:
        # Enable WAL mode for better concurrent read/write performance
        conn.execute("PRAGMA journal_mode=WAL")
        # Wait on a locked database instead of failing immediately
        conn.execute("PRAGMA busy_timeout=5000")
    except Exception:
        # Configuration failed: release the handle rather than leaking it.
        conn.close()
        raise
    return conn
def init_db():
    """Create the chat tables and their indices if they are missing.

    Creates ``chat_history`` and ``chat_sessions``, adds the
    ``session_id`` column to a pre-existing ``chat_history`` table that
    lacks it, and creates the lookup indices. Every statement uses
    ``IF NOT EXISTS`` (or is guarded), so calling this repeatedly is safe.

    Raises:
        sqlite3.Error: If any statement fails. The connection is closed
            without committing in that case.
    """
    conn = _connect()
    try:
        c = conn.cursor()
        # Original chat history table (preserved for backward compat)
        c.execute("""
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                role TEXT,
                content TEXT,
                session_id TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        # Session management table
        c.execute("""
            CREATE TABLE IF NOT EXISTS chat_sessions (
                session_id TEXT PRIMARY KEY,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
                last_active DATETIME DEFAULT CURRENT_TIMESTAMP,
                lstm_state BLOB,
                context_summary TEXT DEFAULT ''
            )
        """)
        # Add session_id column to chat_history if missing (migration).
        # The probe SELECT fails with OperationalError on a table that
        # predates the column.
        try:
            c.execute("SELECT session_id FROM chat_history LIMIT 1")
        except sqlite3.OperationalError:
            c.execute("ALTER TABLE chat_history ADD COLUMN session_id TEXT")
        # ── Indices for query performance ───────────────────────────
        c.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_history_session_id
            ON chat_history(session_id)
        """)
        c.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_history_timestamp
            ON chat_history(timestamp)
        """)
        c.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_history_session_id_id
            ON chat_history(session_id, id)
        """)
        c.execute("""
            CREATE INDEX IF NOT EXISTS idx_chat_sessions_last_active
            ON chat_sessions(last_active)
        """)
        conn.commit()
    finally:
        # Always release the handle, even when a statement above raised.
        conn.close()
    logger.info("Database initialised (WAL mode, indices created)")
# ── Session Management ──────────────────────────────────────────
def create_session() -> str:
    """Create a new chat session row and return its id.

    The id is a random UUID4 string; the remaining columns take their
    table defaults.

    Returns:
        The newly generated ``session_id``.

    Raises:
        sqlite3.Error: If the insert or commit fails.
    """
    session_id = str(uuid.uuid4())
    conn = _connect()
    try:
        conn.execute(
            "INSERT INTO chat_sessions (session_id) VALUES (?)",
            (session_id,),
        )
        conn.commit()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
    logger.debug("Session created: %s", session_id)
    return session_id
def get_session(session_id: str) -> dict | None:
    """Look up a session's metadata.

    Args:
        session_id: Id of the session to fetch.

    Returns:
        A dict with ``session_id``, ``created_at``, ``last_active`` and
        ``context_summary`` (``""`` when the stored value is NULL or
        empty), or ``None`` if no such session exists.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = _connect()
    try:
        row = conn.execute(
            "SELECT session_id, created_at, last_active, context_summary "
            "FROM chat_sessions WHERE session_id = ?",
            (session_id,),
        ).fetchone()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
    if not row:
        return None
    return {
        "session_id": row[0],
        "created_at": row[1],
        "last_active": row[2],
        "context_summary": row[3] or "",
    }
def touch_session(session_id: str) -> None:
    """Set a session's ``last_active`` to the current time.

    Updating an unknown ``session_id`` is a no-op.

    Args:
        session_id: Id of the session to mark as active.

    Raises:
        sqlite3.Error: If the update or commit fails.
    """
    # NOTE(review): this stores a naive local-time ISO-8601 string, while the
    # column default in the schema is SQLite's CURRENT_TIMESTAMP; the two
    # textual formats differ.
    now = datetime.now().isoformat()
    conn = _connect()
    try:
        conn.execute(
            "UPDATE chat_sessions SET last_active = ? WHERE session_id = ?",
            (now, session_id),
        )
        conn.commit()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
def save_lstm_state(session_id: str, state_bytes: bytes) -> None:
    """Store the serialised LSTM hidden state for a session.

    Also refreshes the session's ``last_active`` timestamp. Updating an
    unknown ``session_id`` is a no-op.

    Args:
        session_id: Id of the session the state belongs to.
        state_bytes: Serialised state, stored as-is in the BLOB column.

    Raises:
        sqlite3.Error: If the update or commit fails.
    """
    now = datetime.now().isoformat()
    conn = _connect()
    try:
        conn.execute(
            "UPDATE chat_sessions SET lstm_state = ?, last_active = ? WHERE session_id = ?",
            (state_bytes, now, session_id),
        )
        conn.commit()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
def load_lstm_state(session_id: str) -> bytes | None:
    """Load the serialised LSTM hidden state for a session.

    Args:
        session_id: Id of the session whose state is wanted.

    Returns:
        The stored bytes, or ``None`` when the session does not exist or
        its stored state is NULL or empty.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = _connect()
    try:
        row = conn.execute(
            "SELECT lstm_state FROM chat_sessions WHERE session_id = ?",
            (session_id,),
        ).fetchone()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
    return row[0] if row and row[0] else None
def update_context_summary(session_id: str, summary: str) -> None:
    """Replace the stored context summary for a session.

    Updating an unknown ``session_id`` is a no-op.

    Args:
        session_id: Id of the session to update.
        summary: New summary text.

    Raises:
        sqlite3.Error: If the update or commit fails.
    """
    conn = _connect()
    try:
        conn.execute(
            "UPDATE chat_sessions SET context_summary = ? WHERE session_id = ?",
            (summary, session_id),
        )
        conn.commit()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
# ── Message Management ──────────────────────────────────────────
def add_message(role: str, content: str, session_id: str | None = None) -> None:
    """Append a message to the chat history.

    Args:
        role: Value stored in the ``role`` column.
        content: Message text.
        session_id: Session the message belongs to, or ``None`` to store
            it without a session.

    Raises:
        sqlite3.Error: If the insert or commit fails.
    """
    conn = _connect()
    try:
        conn.execute(
            "INSERT INTO chat_history (role, content, session_id) VALUES (?, ?, ?)",
            (role, content, session_id),
        )
        conn.commit()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
def get_recent_messages(limit: int = 6, session_id: str | None = None) -> list:
    """Return the most recent messages in chronological order.

    Args:
        limit: Maximum number of messages to return.
        session_id: Restrict the result to this session. A falsy value
            (``None`` or ``""``) selects messages across all sessions.

    Returns:
        A list of ``(role, content)`` tuples, oldest first.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = _connect()
    try:
        if session_id:
            cur = conn.execute(
                "SELECT role, content FROM chat_history "
                "WHERE session_id = ? ORDER BY id DESC LIMIT ?",
                (session_id, limit),
            )
        else:
            cur = conn.execute(
                "SELECT role, content FROM chat_history ORDER BY id DESC LIMIT ?",
                (limit,),
            )
        rows = cur.fetchall()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
    # The query returns newest first; flip to chronological order.
    return rows[::-1]
def get_session_messages(session_id: str, limit: int = 50) -> list:
    """Return a session's messages, oldest first (for LSTM processing).

    At most ``limit`` messages are returned, counted from the start of
    the session, so a longer session is truncated at its newest end.

    Args:
        session_id: Id of the session to read.
        limit: Maximum number of messages to return.

    Returns:
        A list of dicts with ``role``, ``content`` and ``timestamp`` keys.

    Raises:
        sqlite3.Error: If the query fails.
    """
    conn = _connect()
    try:
        rows = conn.execute(
            "SELECT role, content, timestamp FROM chat_history "
            "WHERE session_id = ? ORDER BY id ASC LIMIT ?",
            (session_id, limit),
        ).fetchall()
    finally:
        # Close on both success and failure so the handle never leaks.
        conn.close()
    return [
        {"role": role, "content": content, "timestamp": timestamp}
        for role, content, timestamp in rows
    ]
def clear_history(session_id: str | None = None) -> None:
    """Delete chat history and session rows.

    Both deletes are committed together; if either fails nothing is
    committed.

    Args:
        session_id: Session to remove. A falsy value (``None`` or ``""``)
            deletes ALL messages and ALL sessions.

    Raises:
        sqlite3.Error: If a delete or the commit fails.
    """
    conn = _connect()
    try:
        if session_id:
            conn.execute("DELETE FROM chat_history WHERE session_id = ?", (session_id,))
            conn.execute("DELETE FROM chat_sessions WHERE session_id = ?", (session_id,))
        else:
            conn.execute("DELETE FROM chat_history")
            conn.execute("DELETE FROM chat_sessions")
        conn.commit()
    finally:
        # Closing without a commit discards any partial delete.
        conn.close()
# ── Session Cleanup ─────────────────────────────────────────────
def cleanup_old_sessions(max_age_days: int = 30) -> int:
    """Delete sessions, and their messages, idle for over ``max_age_days``.

    A session is expired when its ``last_active`` is earlier than now
    minus ``max_age_days``. Both sides of the comparison go through
    SQLite's ``datetime()`` so that rows holding the schema default
    (``YYYY-MM-DD HH:MM:SS``) and rows written as Python ISO-8601 text
    (``YYYY-MM-DDTHH:MM:SS.ffffff``) are compared by value rather than as
    raw strings.

    NOTE(review): the cutoff is naive local time, as are the timestamps
    this module writes from Python, whereas CURRENT_TIMESTAMP defaults
    are UTC. Rows that were never updated may therefore be off by the
    local UTC offset.

    Args:
        max_age_days: Idle age, in days, beyond which a session is removed.

    Returns:
        The number of sessions deleted.

    Raises:
        sqlite3.Error: If a delete or the commit fails; nothing is
            committed in that case.
    """
    cutoff = (datetime.now() - timedelta(days=max_age_days)).isoformat()
    conn = _connect()
    try:
        # Select the expired ids with a subquery instead of binding one
        # parameter per id, which can exceed SQLite's bound-variable limit
        # when many sessions expire at once. Messages go first, while the
        # session rows that identify them still exist.
        conn.execute(
            "DELETE FROM chat_history WHERE session_id IN ("
            "SELECT session_id FROM chat_sessions "
            "WHERE datetime(last_active) < datetime(?))",
            (cutoff,),
        )
        deleted = conn.execute(
            "DELETE FROM chat_sessions "
            "WHERE datetime(last_active) < datetime(?)",
            (cutoff,),
        ).rowcount
        conn.commit()
    finally:
        # Closing without a commit discards any partial delete.
        conn.close()
    if deleted:
        logger.info(
            "Cleaned up %d expired sessions (older than %d days)",
            deleted,
            max_age_days,
        )
    return deleted