# Soma / app/db/session.py
# Author: Komalpreet Kaur
# Commit 2e9dd8e (unverified) — feat: implement memory consolidation service with sleep cycle,
# add frontend visualization components, and integrate backend database orchestration.
import json
import sqlite3
from contextlib import contextmanager
from app.core.config import settings
# ── Backend detection ─────────────────────────────────────────────
# If DATABASE_URL is set (Supabase/Postgres), try to use psycopg2.
# Otherwise fall back to local SQLite (development).
USE_POSTGRES = bool(settings.DATABASE_URL)
DB_PATH = settings.SQLITE_DB_PATH  # Always defined for fallback
if USE_POSTGRES:
    try:
        import psycopg2
        import psycopg2.extras
        from psycopg2 import IntegrityError as DBIntegrityError
    except ImportError:
        # The driver is not installed: degrade to SQLite, the same way
        # init_session_db() degrades when Postgres is unreachable, instead of
        # making the whole module fail to import.
        print("[!] psycopg2 is not installed. Falling back to SQLite.")
        USE_POSTGRES = False
if not USE_POSTGRES:
    # Name the backend's integrity error uniformly so callers can catch
    # DBIntegrityError regardless of which driver is active.
    from sqlite3 import IntegrityError as DBIntegrityError
def _test_postgres_connection():
    """Probe whether the configured Postgres server accepts connections.

    Returns:
        True when Postgres is not configured (there is nothing to probe), or
        when a connection can be opened (3-second connect timeout) and closed
        again. False if either step raises, so the caller can fall back to
        SQLite.
    """
    if not USE_POSTGRES:
        return True
    try:
        psycopg2.connect(settings.DATABASE_URL, connect_timeout=3).close()
    except Exception:
        # Deliberately broad: any failure just means "use the fallback".
        return False
    return True
# Backend actually used at runtime. It starts from the configuration and is
# switched to "sqlite" by init_session_db() if the Postgres connection test
# fails.
if USE_POSTGRES:
    _db_backend = "postgres"
else:
    _db_backend = "sqlite"
@contextmanager
def get_conn():
    """Yield a connection to the active backend, wrapped in a transaction.

    The connection goes to Postgres (``settings.DATABASE_URL``) when
    ``_db_backend == "postgres"``, and to the SQLite file at ``DB_PATH``
    otherwise.

    On normal exit of the ``with`` body the transaction is committed. If the
    body (or the commit) raises an ``Exception``, the transaction is rolled
    back and the exception is re-raised. The connection is always closed.
    """
    # _db_backend is only read here, so no ``global`` declaration is needed.
    if _db_backend == "postgres":
        conn = psycopg2.connect(settings.DATABASE_URL)
    else:
        conn = sqlite3.connect(DB_PATH)
    # One transaction wrapper for both backends, so commit/rollback/close
    # semantics cannot drift apart between them.
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def _cursor(conn):
    """Open a cursor on ``conn`` suited to the active backend.

    On SQLite this is the default cursor, whose rows are tuples. On Postgres
    it is a ``RealDictCursor``, whose rows are addressed by column name.
    """
    if _db_backend != "postgres":
        return conn.cursor()
    return conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
# ── Schema init ───────────────────────────────────────────────────
def _create_schema_postgres(cur):
    """Create the users, messages and neural_sparks tables on Postgres if absent."""
    cur.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id SERIAL PRIMARY KEY,
            username TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT NOW()
        )
    ''')
    cur.execute('''
        CREATE TABLE IF NOT EXISTS messages (
            id SERIAL PRIMARY KEY,
            session_id TEXT,
            role TEXT,
            content TEXT,
            timestamp TIMESTAMP DEFAULT NOW()
        )
    ''')
    cur.execute('''
        CREATE TABLE IF NOT EXISTS neural_sparks (
            id SERIAL PRIMARY KEY,
            content TEXT,
            entities TEXT,
            user_id TEXT DEFAULT 'default_user',
            timestamp TIMESTAMP DEFAULT NOW()
        )
    ''')


def _create_schema_sqlite(cur):
    """Create the SQLite tables if absent and add neural_sparks.user_id to older databases."""
    cur.execute('''
        CREATE TABLE IF NOT EXISTS users (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            username TEXT UNIQUE NOT NULL,
            hashed_password TEXT NOT NULL,
            created_at DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    cur.execute('''
        CREATE TABLE IF NOT EXISTS messages (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            session_id TEXT,
            role TEXT,
            content TEXT,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    cur.execute('''
        CREATE TABLE IF NOT EXISTS neural_sparks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            content TEXT,
            entities TEXT,
            user_id TEXT DEFAULT 'default_user',
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
    ''')
    # Migration: neural_sparks tables created before user_id existed lack the
    # column. Add it only when it is actually missing. Attempting ALTER TABLE
    # on every start and swallowing every OperationalError would also hide
    # real failures (for example, a locked database).
    cur.execute("PRAGMA table_info(neural_sparks)")
    columns = {row[1] for row in cur.fetchall()}  # row[1] is the column name
    if "user_id" not in columns:
        cur.execute("ALTER TABLE neural_sparks ADD COLUMN user_id TEXT DEFAULT 'default_user'")


def init_session_db():
    """Select the runtime backend and create the schema if it does not exist.

    If Postgres is configured but the connection test fails, a warning is
    printed and this process switches to SQLite for the rest of its
    lifetime. The users, messages and neural_sparks tables are then created
    (if missing) on the selected backend, in a single transaction.
    """
    global _db_backend  # reassigned below on Postgres failure
    if _db_backend == "postgres" and not _test_postgres_connection():
        print("[!] Postgres connection failed. Falling back to SQLite.")
        _db_backend = "sqlite"
    with get_conn() as conn:
        cur = conn.cursor()
        if _db_backend == "postgres":
            _create_schema_postgres(cur)
        else:
            _create_schema_sqlite(cur)
# ── User helpers ──────────────────────────────────────────────────
def create_user(username: str, hashed_password: str) -> bool:
    """Insert a new user row.

    Args:
        username: Login name to store.
        hashed_password: Password hash, stored exactly as given.

    Returns:
        True if the row was inserted. False if the insert hit an integrity
        error, e.g. a duplicate ``username`` (the column is declared UNIQUE
        by init_session_db).
    """
    # NOTE(review): get_user() matches usernames case-insensitively, but the
    # UNIQUE constraint on users.username is case-sensitive on both backends,
    # so "Bob" and "bob" can both be inserted — confirm this is intended.
    try:
        with get_conn() as conn:
            cur = conn.cursor()
            if _db_backend == "postgres":
                sql = 'INSERT INTO users (username, hashed_password) VALUES (%s, %s)'
            else:
                sql = 'INSERT INTO users (username, hashed_password) VALUES (?, ?)'
            cur.execute(sql, (username, hashed_password))
    except (DBIntegrityError, sqlite3.IntegrityError):
        # DBIntegrityError is bound at import time from the *configured*
        # backend. init_session_db() may switch to SQLite at runtime, and
        # SQLite then raises sqlite3.IntegrityError instead.
        return False
    return True
def get_user(username: str):
    """Find a user by name, ignoring letter case.

    Returns:
        A ``(username, hashed_password)`` tuple for the matching row, or None
        when no user matches.
    """
    with get_conn() as conn:
        cursor = _cursor(conn)
        if _db_backend != "postgres":
            # SQLite: the default cursor already yields a tuple (or None).
            cursor.execute(
                'SELECT username, hashed_password FROM users WHERE username = ? COLLATE NOCASE',
                (username,)
            )
            return cursor.fetchone()
        # Postgres: RealDictCursor rows are keyed by column name; repack them
        # into the same tuple shape SQLite returns.
        cursor.execute(
            'SELECT username, hashed_password FROM users WHERE LOWER(username) = LOWER(%s)',
            (username,)
        )
        found = cursor.fetchone()
        if not found:
            return None
        return (found['username'], found['hashed_password'])
# ── Message helpers ───────────────────────────────────────────────
def add_message(session_id: str, role: str, content: str):
    """Append one message (role + content) to the history of ``session_id``."""
    with get_conn() as conn:
        cursor = conn.cursor()
        # Only the parameter placeholder style differs between the drivers.
        if _db_backend == "postgres":
            statement = 'INSERT INTO messages (session_id, role, content) VALUES (%s, %s, %s)'
        else:
            statement = 'INSERT INTO messages (session_id, role, content) VALUES (?, ?, ?)'
        cursor.execute(statement, (session_id, role, content))
def get_recent_messages(session_id: str, exchanges: int = 5):
    """Return the most recent messages of a session, oldest first.

    Args:
        session_id: Session whose history is read.
        exchanges: Number of exchanges to return. Each exchange counts as two
            messages (presumably one user and one assistant turn — confirm
            against callers), so at most ``exchanges * 2`` rows are fetched.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts in chronological
        order.
    """
    limit = exchanges * 2
    with get_conn() as conn:
        cur = _cursor(conn)
        postgres = _db_backend == "postgres"
        # ``id DESC`` breaks ties between messages that share a timestamp.
        # SQLite's CURRENT_TIMESTAMP has one-second resolution, so without a
        # tiebreaker same-second turns come back in unspecified order.
        if postgres:
            cur.execute(
                'SELECT role, content FROM messages WHERE session_id = %s '
                'ORDER BY timestamp DESC, id DESC LIMIT %s',
                (session_id, limit)
            )
        else:
            cur.execute(
                'SELECT role, content FROM messages WHERE session_id = ? '
                'ORDER BY timestamp DESC, id DESC LIMIT ?',
                (session_id, limit)
            )
        rows = cur.fetchall()
    # Rows were fetched newest-first; reverse them into chronological order.
    if postgres:
        return [{"role": r['role'], "content": r['content']} for r in reversed(rows)]
    return [{"role": r[0], "content": r[1]} for r in reversed(rows)]
def get_all_session_ids():
    """Return every distinct ``session_id`` present in the messages table.

    A plain ``conn.cursor()`` is used on both backends, so each row is a
    tuple and the id is ``row[0]``. The previous ``row['session_id']`` lookup
    raised TypeError on Postgres, where the default psycopg2 cursor also
    returns tuples.
    """
    with get_conn() as conn:
        cur = conn.cursor()
        cur.execute('SELECT DISTINCT session_id FROM messages')
        return [row[0] for row in cur.fetchall()]
def get_message_count(session_id: str):
    """Return how many messages are stored for ``session_id``.

    A plain ``conn.cursor()`` is used on both backends, so the single
    ``COUNT(*)`` row is a tuple. The previous ``row.values()`` call raised
    AttributeError on Postgres.
    """
    with get_conn() as conn:
        cur = conn.cursor()
        if _db_backend == "postgres":
            cur.execute(
                'SELECT COUNT(*) FROM messages WHERE session_id = %s',
                (session_id,)
            )
        else:
            cur.execute(
                'SELECT COUNT(*) FROM messages WHERE session_id = ?',
                (session_id,)
            )
        # COUNT(*) without GROUP BY always yields exactly one row.
        return cur.fetchone()[0]
def prune_old_messages(session_id: str, keep_recent: int = 10):
    """Delete all but the ``keep_recent`` newest messages of a session.

    Args:
        session_id: Session whose history is pruned.
        keep_recent: Number of newest messages to keep.

    Returns:
        The number of messages remaining for the session, counted after the
        delete has been committed.
    """
    with get_conn() as conn:
        cur = conn.cursor()
        # ``id DESC`` makes the kept set deterministic when several messages
        # share a timestamp (SQLite's CURRENT_TIMESTAMP has one-second
        # resolution).
        if _db_backend == "postgres":
            cur.execute('''
                DELETE FROM messages WHERE session_id = %s AND id NOT IN (
                    SELECT id FROM messages WHERE session_id = %s
                    ORDER BY timestamp DESC, id DESC LIMIT %s
                )
            ''', (session_id, session_id, keep_recent))
        else:
            cur.execute('''
                DELETE FROM messages WHERE session_id = ? AND id NOT IN (
                    SELECT id FROM messages WHERE session_id = ?
                    ORDER BY timestamp DESC, id DESC LIMIT ?
                )
            ''', (session_id, session_id, keep_recent))
    # Count outside the ``with`` so the DELETE is already committed; a second
    # connection opened mid-transaction may still see the pre-delete rows.
    return get_message_count(session_id)
# ── Spark helpers ─────────────────────────────────────────────────
def add_spark(content: str, entities: list, user_id: str = "default_user"):
    """Store one neural spark for ``user_id``.

    ``entities`` is serialized with ``json.dumps`` and stored as text in the
    ``entities`` column.
    """
    with get_conn() as conn:
        cursor = conn.cursor()
        if _db_backend != "postgres":
            cursor.execute(
                'INSERT INTO neural_sparks (content, entities, user_id) VALUES (?, ?, ?)',
                (content, json.dumps(entities), user_id)
            )
        else:
            cursor.execute(
                'INSERT INTO neural_sparks (content, entities, user_id) VALUES (%s, %s, %s)',
                (content, json.dumps(entities), user_id)
            )
def clear_user_messages(user_id: str):
    """Delete all messages and sparks for a user (fresh session).

    Messages are matched on ``session_id == user_id``. This relies on callers
    using the user id as the session id (confirm). Sparks are matched on
    ``user_id``. Both deletes run on one connection and are committed
    together when the ``get_conn()`` block exits.
    """
    with get_conn() as conn:
        cursor = conn.cursor()
        postgres = _db_backend == "postgres"
        statements = (
            'DELETE FROM messages WHERE session_id = %s' if postgres
            else 'DELETE FROM messages WHERE session_id = ?',
            'DELETE FROM neural_sparks WHERE user_id = %s' if postgres
            else 'DELETE FROM neural_sparks WHERE user_id = ?',
        )
        for statement in statements:
            cursor.execute(statement, (user_id,))
def get_recent_sparks(user_id: str = "default_user", limit: int = 5):
    """Return the newest sparks stored for a user, newest first.

    Args:
        user_id: Owner whose sparks are returned.
        limit: Maximum number of sparks to return.

    Returns:
        A list of dicts with the following keys:

        - ``content``
        - ``entities``: decoded with ``json.loads`` from the stored text.
        - ``timestamp``: the SQLite value as stored, or ``str()`` of the
          Postgres value. The two string formats may differ.
    """
    with get_conn() as conn:
        cur = _cursor(conn)
        postgres = _db_backend == "postgres"
        # ``id DESC`` breaks timestamp ties (SQLite's CURRENT_TIMESTAMP has
        # one-second resolution), so LIMIT keeps a deterministic set.
        if postgres:
            cur.execute(
                'SELECT content, entities, timestamp FROM neural_sparks '
                'WHERE user_id = %s ORDER BY timestamp DESC, id DESC LIMIT %s',
                (user_id, limit)
            )
        else:
            cur.execute(
                'SELECT content, entities, timestamp FROM neural_sparks '
                'WHERE user_id = ? ORDER BY timestamp DESC, id DESC LIMIT ?',
                (user_id, limit)
            )
        rows = cur.fetchall()
    if postgres:
        return [
            {"content": r['content'], "entities": json.loads(r['entities']), "timestamp": str(r['timestamp'])}
            for r in rows
        ]
    return [{"content": r[0], "entities": json.loads(r[1]), "timestamp": r[2]} for r in rows]