Spaces:
Running
Running
Komalpreet Kaur
feat: implement memory consolidation service with sleep cycle, add frontend visualization components, and integrate backend database orchestration.
2e9dd8e unverified | import json | |
| import sqlite3 | |
| from contextlib import contextmanager | |
| from app.core.config import settings | |
# ── Backend detection ─────────────────────────────────────────────
# If DATABASE_URL is set (Supabase/Postgres), try to use psycopg2.
# Otherwise -- or when the psycopg2 driver is not installed -- fall back to
# local SQLite (development).
USE_POSTGRES = bool(settings.DATABASE_URL)
DB_PATH = settings.SQLITE_DB_PATH  # Always defined for fallback

if USE_POSTGRES:
    try:
        import psycopg2
        import psycopg2.extras
        from psycopg2 import IntegrityError as DBIntegrityError
    except ImportError:
        # DATABASE_URL is configured but the driver is missing: degrade to
        # SQLite (like init_session_db does for an unreachable server)
        # instead of crashing at import time.
        print("[!] psycopg2 is not installed. Falling back to SQLite.")
        USE_POSTGRES = False

if not USE_POSTGRES:
    # Same alias as the Postgres branch so callers can catch DBIntegrityError
    # without caring which driver was selected here.
    from sqlite3 import IntegrityError as DBIntegrityError
def _test_postgres_connection():
    """Report whether the configured Postgres server accepts connections.

    Returns True straight away when Postgres is not configured (nothing to
    test). Otherwise opens and immediately closes one connection using a
    3-second connect timeout, returning True on success and False if either
    step raises. Acting on a False result is left to the caller.
    """
    if not USE_POSTGRES:
        return True
    try:
        psycopg2.connect(settings.DATABASE_URL, connect_timeout=3).close()
    except Exception:
        # Deliberately broad: any failure to connect counts as "unusable".
        return False
    return True
# Backend actually used at runtime. Seeded from USE_POSTGRES;
# init_session_db() may later downgrade it to "sqlite" if Postgres is unreachable.
_db_backend = "sqlite"
if USE_POSTGRES:
    _db_backend = "postgres"
@contextmanager
def get_conn():
    """Yield a DB-API connection for the active backend inside a transaction.

    Intended usage::

        with get_conn() as conn:
            ...

    The backend is chosen from ``_db_backend`` at call time: a psycopg2
    connection to ``settings.DATABASE_URL`` for ``"postgres"``, otherwise a
    sqlite3 connection to ``DB_PATH``. The transaction is committed when the
    ``with`` body finishes normally, rolled back (and the exception re-raised)
    when it raises, and the connection is always closed.
    """
    # Only read here, so no ``global`` declaration is needed.
    if _db_backend == "postgres":
        conn = psycopg2.connect(settings.DATABASE_URL)
    else:
        conn = sqlite3.connect(DB_PATH)
    try:
        yield conn
        conn.commit()
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()
def _cursor(conn):
    """Open a cursor on *conn* suited to the active backend.

    Postgres gets a ``RealDictCursor`` (rows are addressed by column name);
    SQLite gets the driver's default cursor (rows are tuples).
    """
    if _db_backend != "postgres":
        return conn.cursor()
    return conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
# ── Schema init ───────────────────────────────────────────────────
def init_session_db():
    """Create the ``users``, ``messages`` and ``neural_sparks`` tables if absent.

    If Postgres is the active backend but ``_test_postgres_connection()``
    reports it unreachable, the module-level ``_db_backend`` is first switched
    to ``"sqlite"``. This call and every later ``get_conn()`` then use the
    local SQLite file.
    """
    global _db_backend

    if _db_backend == "postgres" and not _test_postgres_connection():
        print("[!] Postgres connection failed. Falling back to SQLite.")
        _db_backend = "sqlite"

    on_postgres = _db_backend == "postgres"

    # The two dialects differ only in auto-increment and timestamp column types.
    if on_postgres:
        ddl_statements = (
            '''
            CREATE TABLE IF NOT EXISTS users (
                id SERIAL PRIMARY KEY,
                username TEXT UNIQUE NOT NULL,
                hashed_password TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT NOW()
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS messages (
                id SERIAL PRIMARY KEY,
                session_id TEXT,
                role TEXT,
                content TEXT,
                timestamp TIMESTAMP DEFAULT NOW()
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS neural_sparks (
                id SERIAL PRIMARY KEY,
                content TEXT,
                entities TEXT,
                user_id TEXT DEFAULT 'default_user',
                timestamp TIMESTAMP DEFAULT NOW()
            )
            ''',
        )
    else:
        ddl_statements = (
            '''
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                hashed_password TEXT NOT NULL,
                created_at DATETIME DEFAULT CURRENT_TIMESTAMP
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS messages (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT,
                role TEXT,
                content TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
            ''',
            '''
            CREATE TABLE IF NOT EXISTS neural_sparks (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT,
                entities TEXT,
                user_id TEXT DEFAULT 'default_user',
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
            ''',
        )

    with get_conn() as conn:
        cur = conn.cursor()
        for ddl in ddl_statements:
            cur.execute(ddl)

        if not on_postgres:
            # SQLite migration: add user_id when an existing neural_sparks table
            # lacks it. A table created by the DDL above already has the column,
            # so SQLite raises OperationalError, which is intentionally ignored.
            try:
                cur.execute("ALTER TABLE neural_sparks ADD COLUMN user_id TEXT DEFAULT 'default_user'")
            except sqlite3.OperationalError:
                pass
# ── User helpers ──────────────────────────────────────────────────
def create_user(username: str, hashed_password: str) -> bool:
    """Insert a new row into ``users``.

    Args:
        username: Login name; the ``users.username`` column is UNIQUE.
        hashed_password: Value stored as-is in ``users.hashed_password``.

    Returns:
        True if the row was inserted, False if the insert violated an
        integrity constraint (e.g. the username already exists).
    """
    if _db_backend == "postgres":
        sql = 'INSERT INTO users (username, hashed_password) VALUES (%s, %s)'
    else:
        sql = 'INSERT INTO users (username, hashed_password) VALUES (?, ?)'
    try:
        with get_conn() as conn:
            conn.cursor().execute(sql, (username, hashed_password))
    except (sqlite3.IntegrityError, DBIntegrityError):
        # DBIntegrityError is bound at import time from the *configured*
        # backend. After init_session_db falls back from Postgres to SQLite
        # it is still psycopg2's class, so sqlite3's must be caught too.
        return False
    return True
def get_user(username: str):
    """Look up a user by name, ignoring case.

    Returns a ``(username, hashed_password)`` tuple for the matching row, or
    ``None`` when no row matches.
    """
    with get_conn() as conn:
        cur = _cursor(conn)

        if _db_backend != "postgres":
            # SQLite: NOCASE collation does the case-folding, and the default
            # cursor already yields a tuple (or None).
            cur.execute(
                'SELECT username, hashed_password FROM users WHERE username = ? COLLATE NOCASE',
                (username,)
            )
            return cur.fetchone()

        cur.execute(
            'SELECT username, hashed_password FROM users WHERE LOWER(username) = LOWER(%s)',
            (username,)
        )
        found = cur.fetchone()
        if not found:
            return None
        # RealDictCursor rows are dict-like; normalise to the SQLite tuple shape.
        return found['username'], found['hashed_password']
# ── Message helpers ───────────────────────────────────────────────
def add_message(session_id: str, role: str, content: str):
    """Append one message (with its role) to the history of *session_id*."""
    insert_sql = (
        'INSERT INTO messages (session_id, role, content) VALUES (%s, %s, %s)'
        if _db_backend == "postgres"
        else 'INSERT INTO messages (session_id, role, content) VALUES (?, ?, ?)'
    )
    row = (session_id, role, content)
    with get_conn() as conn:
        conn.cursor().execute(insert_sql, row)
def get_recent_messages(session_id: str, exchanges: int = 5):
    """Return the latest messages of *session_id* in chronological order.

    Args:
        session_id: Conversation whose history is read.
        exchanges: Number of exchanges to return; each exchange is assumed to
            be two messages, so up to ``exchanges * 2`` rows are fetched.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts, oldest first.
    """
    on_postgres = _db_backend == "postgres"
    # ``id`` breaks timestamp ties. SQLite's CURRENT_TIMESTAMP has one-second
    # resolution, so a message and its reply often share a timestamp and would
    # otherwise be returned in arbitrary order.
    if on_postgres:
        sql = ('SELECT role, content FROM messages WHERE session_id = %s '
               'ORDER BY timestamp DESC, id DESC LIMIT %s')
    else:
        sql = ('SELECT role, content FROM messages WHERE session_id = ? '
               'ORDER BY timestamp DESC, id DESC LIMIT ?')

    with get_conn() as conn:
        cur = _cursor(conn)
        cur.execute(sql, (session_id, exchanges * 2))
        rows = cur.fetchall()

    # The query returns newest first; reverse to get chronological order.
    if on_postgres:
        return [{"role": r['role'], "content": r['content']} for r in reversed(rows)]
    return [{"role": r[0], "content": r[1]} for r in reversed(rows)]
def get_all_session_ids():
    """Return every distinct ``session_id`` present in ``messages``."""
    with get_conn() as conn:
        # A plain cursor (not _cursor) yields tuples on both sqlite3 and
        # psycopg2, so the value is read positionally on either backend;
        # key-based access would raise TypeError on Postgres.
        cur = conn.cursor()
        cur.execute('SELECT DISTINCT session_id FROM messages')
        return [row[0] for row in cur.fetchall()]
def get_message_count(session_id: str):
    """Return the number of messages stored for *session_id*."""
    placeholder = "%s" if _db_backend == "postgres" else "?"
    with get_conn() as conn:
        # A plain cursor returns COUNT(*) as a 1-tuple on both backends (a
        # psycopg2 default cursor has no dict-style ``.values()``).
        cur = conn.cursor()
        cur.execute(
            f'SELECT COUNT(*) FROM messages WHERE session_id = {placeholder}',
            (session_id,)
        )
        return cur.fetchone()[0]
def prune_old_messages(session_id: str, keep_recent: int = 10):
    """Delete all but the newest *keep_recent* messages of *session_id*.

    Returns:
        The number of messages left for the session after pruning.
    """
    # ``id DESC`` breaks ties between rows sharing a timestamp (one-second
    # resolution in SQLite), so the rows kept are always the newest ones.
    if _db_backend == "postgres":
        sql = '''
            DELETE FROM messages WHERE session_id = %s AND id NOT IN (
                SELECT id FROM messages WHERE session_id = %s
                ORDER BY timestamp DESC, id DESC LIMIT %s
            )
        '''
    else:
        sql = '''
            DELETE FROM messages WHERE session_id = ? AND id NOT IN (
                SELECT id FROM messages WHERE session_id = ?
                ORDER BY timestamp DESC, id DESC LIMIT ?
            )
        '''

    with get_conn() as conn:
        conn.cursor().execute(sql, (session_id, session_id, keep_recent))

    # Count only after the with-block has committed the DELETE.
    # get_message_count opens its own connection, which does not see another
    # connection's uncommitted changes.
    return get_message_count(session_id)
# ── Spark helpers ─────────────────────────────────────────────────
def add_spark(content: str, entities: list, user_id: str = "default_user"):
    """Store one neural spark for *user_id*; *entities* is saved as a JSON string."""
    marks = "%s, %s, %s" if _db_backend == "postgres" else "?, ?, ?"
    statement = f'INSERT INTO neural_sparks (content, entities, user_id) VALUES ({marks})'
    with get_conn() as conn:
        values = (content, json.dumps(entities), user_id)
        conn.cursor().execute(statement, values)
def clear_user_messages(user_id: str):
    """Delete all messages and sparks for a user (fresh session).

    Removes ``messages`` rows whose ``session_id`` equals *user_id* and
    ``neural_sparks`` rows whose ``user_id`` equals *user_id*. Both deletes
    run on one connection and are committed together.
    """
    mark = "%s" if _db_backend == "postgres" else "?"
    targets = (("messages", "session_id"), ("neural_sparks", "user_id"))
    with get_conn() as conn:
        cur = conn.cursor()
        for table, column in targets:
            cur.execute(f'DELETE FROM {table} WHERE {column} = {mark}', (user_id,))
def get_recent_sparks(user_id: str = "default_user", limit: int = 5):
    """Return the newest *limit* neural sparks for *user_id*, newest first.

    Each item is ``{"content": ..., "entities": ..., "timestamp": ...}``.
    ``entities`` is decoded from the stored JSON text (an empty list when the
    column is NULL). On Postgres the timestamp is stringified with ``str()``;
    on SQLite it is returned exactly as stored.
    """
    on_postgres = _db_backend == "postgres"
    mark = "%s" if on_postgres else "?"
    # ``id DESC`` breaks ties between sparks sharing a timestamp (one-second
    # resolution in SQLite), keeping the order deterministic.
    sql = (
        'SELECT content, entities, timestamp FROM neural_sparks '
        f'WHERE user_id = {mark} ORDER BY timestamp DESC, id DESC LIMIT {mark}'
    )

    with get_conn() as conn:
        cur = _cursor(conn)
        cur.execute(sql, (user_id, limit))
        rows = cur.fetchall()

    if on_postgres:
        # RealDictCursor rows -> (content, entities, timestamp) tuples.
        rows = [(r['content'], r['entities'], str(r['timestamp'])) for r in rows]

    return [
        {
            "content": content,
            # ``entities`` is a nullable TEXT column; json.loads(None) would raise.
            "entities": json.loads(raw_entities) if raw_entities is not None else [],
            "timestamp": stamp,
        }
        for content, raw_entities, stamp in rows
    ]