| """ | |
| Cogni-Engine v1 — TiDB Persistence Layer | |
| Handles all database communication: schema, CRUD, buffering, sync. | |
| The "long-term memory" — survives reboots, crashes, rebuilds. | |
| """ | |
| import time | |
| import json | |
| import threading | |
| import traceback | |
| from typing import List, Dict, Optional, Any, Tuple | |
| from contextlib import contextmanager | |
| import pymysql | |
| from pymysql.cursors import DictCursor | |
| import config | |
| import utils | |
# ═══════════════════════════════════════════════════════════
# CONNECTION MANAGER
# ═══════════════════════════════════════════════════════════

class ConnectionPool:
    """
    Simple connection pool for TiDB.
    Manages multiple reusable connections with auto-reconnect.
    """

    def __init__(self):
        self._connections: List[pymysql.Connection] = []
        self._lock = threading.Lock()
        self._available: List[pymysql.Connection] = []
        self._max_size = config.TIDB_POOL_SIZE
        self._initialized = False

    def _create_connection(self) -> pymysql.Connection:
        """Create a new TiDB connection."""
        connect_kwargs = {
            "host": config.TIDB_HOST,
            "port": config.TIDB_PORT,
            "user": config.TIDB_USER,
            "password": config.TIDB_PASSWORD,
            "database": config.TIDB_DATABASE,
            "connect_timeout": config.TIDB_CONNECT_TIMEOUT,
            "read_timeout": config.TIDB_READ_TIMEOUT,
            "write_timeout": config.TIDB_WRITE_TIMEOUT,
            "charset": "utf8mb4",
            "cursorclass": DictCursor,
            "autocommit": True
        }
        if config.TIDB_SSL:
            connect_kwargs["ssl"] = {"ssl_mode": "VERIFY_IDENTITY"}
        conn = pymysql.connect(**connect_kwargs)
        return conn

    def _test_connection(self, conn: pymysql.Connection) -> bool:
        """Test if connection is still alive."""
        try:
            conn.ping(reconnect=False)
            return True
        except Exception:
            return False

    def acquire(self) -> pymysql.Connection:
        """Get a connection from the pool."""
        with self._lock:
            # Try to reuse an existing connection
            while self._available:
                conn = self._available.pop()
                if self._test_connection(conn):
                    return conn
                else:
                    # Dead connection, discard
                    try:
                        conn.close()
                    except Exception:
                        pass
            # Create a new one if under the limit
            if len(self._connections) < self._max_size:
                conn = self._create_connection()
                self._connections.append(conn)
                return conn
            # All connections busy, create a temporary one
            return self._create_connection()

    def release(self, conn: pymysql.Connection):
        """Return a connection to the pool."""
        with self._lock:
            if self._test_connection(conn):
                self._available.append(conn)
            else:
                try:
                    conn.close()
                except Exception:
                    pass
                # Remove from tracked connections
                if conn in self._connections:
                    self._connections.remove(conn)

    def close_all(self):
        """Close all connections."""
        with self._lock:
            for conn in self._connections:
                try:
                    conn.close()
                except Exception:
                    pass
            self._connections.clear()
            self._available.clear()
    @contextmanager
    def connection(self):
        """Context manager for auto acquire/release."""
        conn = None
        try:
            conn = self.acquire()
            yield conn
        finally:
            if conn:
                self.release(conn)
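# Usage sketch (illustrative only; assumes config supplies valid TIDB_*
# credentials). The @contextmanager decorator above is what makes the
# `with pool.connection() as conn:` form work: release() runs even if
# the body raises.
#
#     pool = ConnectionPool()
#     with pool.connection() as conn:
#         with conn.cursor() as cur:
#             cur.execute("SELECT 1")
#             print(cur.fetchone())
#     pool.close_all()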
# ═══════════════════════════════════════════════════════════
# WRITE BUFFER
# ═══════════════════════════════════════════════════════════

class WriteBuffer:
    """
    Buffers write operations and flushes in batches.
    Prevents excessive DB writes during rapid thinking cycles.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._node_upserts: Dict[str, dict] = {}      # id → node_data
        self._edge_upserts: Dict[str, dict] = {}      # id → edge_data
        self._chain_upserts: Dict[str, dict] = {}     # id → chain_data
        self._node_deletes: set = set()               # ids to delete
        self._edge_deletes: set = set()               # ids to delete
        self._state_update: Optional[dict] = None     # thinking state
        self._checksum_updates: Dict[str, str] = {}   # filename → checksum
        self._operation_count = 0
        self._last_flush_time = time.time()

    def buffer_node(self, node_data: dict):
        """Buffer a node upsert."""
        with self._lock:
            self._node_upserts[node_data["id"]] = node_data
            self._operation_count += 1

    def buffer_edge(self, edge_data: dict):
        """Buffer an edge upsert."""
        with self._lock:
            self._edge_upserts[edge_data["id"]] = edge_data
            self._operation_count += 1

    def buffer_chain(self, chain_data: dict):
        """Buffer a chain upsert."""
        with self._lock:
            self._chain_upserts[chain_data["id"]] = chain_data
            self._operation_count += 1

    def buffer_node_delete(self, node_id: str):
        """Buffer a node deletion."""
        with self._lock:
            self._node_deletes.add(node_id)
            self._node_upserts.pop(node_id, None)
            self._operation_count += 1

    def buffer_edge_delete(self, edge_id: str):
        """Buffer an edge deletion."""
        with self._lock:
            self._edge_deletes.add(edge_id)
            self._edge_upserts.pop(edge_id, None)
            self._operation_count += 1

    def buffer_state(self, state: dict):
        """Buffer thinking state update."""
        with self._lock:
            self._state_update = state

    def buffer_checksum(self, filename: str, checksum: str):
        """Buffer file checksum update."""
        with self._lock:
            self._checksum_updates[filename] = checksum

    def pending_count(self) -> int:
        """Number of pending operations."""
        return self._operation_count

    def seconds_since_flush(self) -> float:
        """Seconds elapsed since last flush."""
        return time.time() - self._last_flush_time
    def should_flush(self) -> bool:
        """Check if buffer should be flushed based on config thresholds."""
        if self._operation_count == 0 and self._state_update is None:
            return False
        if self._operation_count >= config.SYNC_INTERVAL_CYCLES:
            return True
        if self.seconds_since_flush() >= config.SYNC_INTERVAL_SECONDS:
            return True
        return False
    def drain(self) -> dict:
        """
        Extract all buffered operations and reset buffer.
        Returns dict with all pending operations.
        """
        with self._lock:
            data = {
                "node_upserts": dict(self._node_upserts),
                "edge_upserts": dict(self._edge_upserts),
                "chain_upserts": dict(self._chain_upserts),
                "node_deletes": set(self._node_deletes),
                "edge_deletes": set(self._edge_deletes),
                "state_update": self._state_update,
                "checksum_updates": dict(self._checksum_updates)
            }
            self._node_upserts.clear()
            self._edge_upserts.clear()
            self._chain_upserts.clear()
            self._node_deletes.clear()
            self._edge_deletes.clear()
            self._state_update = None
            self._checksum_updates.clear()
            self._operation_count = 0
            self._last_flush_time = time.time()
            return data
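# Coalescing sketch (hypothetical ids, for illustration): repeated writes
# to the same key collapse into a single pending upsert, and a delete
# cancels any buffered upsert for that id, so one flush performs at most
# one DB operation per entity.
#
#     buf = WriteBuffer()
#     buf.buffer_node({"id": "n1", "content": "v1"})
#     buf.buffer_node({"id": "n1", "content": "v2"})  # overwrites v1 in-buffer
#     buf.buffer_node_delete("n1")                    # drops the pending upsert
#     pending = buf.drain()
#     assert pending["node_upserts"] == {} and "n1" in pending["node_deletes"]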
# ═══════════════════════════════════════════════════════════
# MAIN MEMORY CLASS
# ═══════════════════════════════════════════════════════════

class Memory:
    """
    TiDB persistence layer.
    Handles schema creation, CRUD, buffered writes, and full state load/save.
    """

    def __init__(self):
        self.pool = ConnectionPool()
        self.buffer = WriteBuffer()
        self._connected = False
        self._schema_ready = False

    # ───────────────────────────────────────────────────
    # INITIALIZATION
    # ───────────────────────────────────────────────────

    def initialize(self) -> bool:
        """
        Initialize database: test connection and create schema.
        Returns True if successful.
        """
        if not config.TIDB_HOST:
            print("[MEMORY] No TiDB host configured. Running without persistence.")
            return False
        for attempt in range(config.TIDB_RETRY_ATTEMPTS):
            try:
                with self.pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.execute("SELECT 1")
                        result = cur.fetchone()
                        if result:
                            self._connected = True
                            print(f"[MEMORY] Connected to TiDB at {config.TIDB_HOST}")
                            self._create_schema()
                            return True
            except Exception as e:
                print(f"[MEMORY] Connection attempt {attempt + 1} failed: {e}")
                if attempt < config.TIDB_RETRY_ATTEMPTS - 1:
                    time.sleep(config.TIDB_RETRY_DELAY)
        print("[MEMORY] Failed to connect to TiDB after all retries.")
        return False

    def _create_schema(self):
        """Create all required tables if they don't exist."""
        schema_sql = [
            """
            CREATE TABLE IF NOT EXISTS nodes (
                id VARCHAR(32) PRIMARY KEY,
                type VARCHAR(32) NOT NULL,
                content TEXT NOT NULL,
                vector JSON,
                weight FLOAT DEFAULT 1.0,
                connections INT DEFAULT 0,
                source VARCHAR(16) DEFAULT 'data',
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_type (type),
                INDEX idx_weight (weight),
                INDEX idx_source (source)
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS edges (
                id VARCHAR(32) PRIMARY KEY,
                from_node VARCHAR(32) NOT NULL,
                to_node VARCHAR(32) NOT NULL,
                relation VARCHAR(32) NOT NULL,
                weight FLOAT DEFAULT 1.0,
                confidence FLOAT DEFAULT 1.0,
                source VARCHAR(16) DEFAULT 'data',
                used_count INT DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_from (from_node),
                INDEX idx_to (to_node),
                INDEX idx_relation (relation),
                INDEX idx_weight (weight),
                INDEX idx_confidence (confidence)
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS reasoning_chains (
                id VARCHAR(32) PRIMARY KEY,
                path JSON NOT NULL,
                conclusion TEXT,
                confidence FLOAT DEFAULT 0.5,
                used_count INT DEFAULT 0,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                INDEX idx_confidence (confidence),
                INDEX idx_used (used_count)
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS thinking_state (
                id INT PRIMARY KEY DEFAULT 1,
                current_cycle BIGINT DEFAULT 0,
                total_cycles BIGINT DEFAULT 0,
                cursor_position VARCHAR(64) DEFAULT '',
                phase VARCHAR(32) DEFAULT 'init',
                metrics JSON,
                started_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS file_checksums (
                filename VARCHAR(255) PRIMARY KEY,
                checksum VARCHAR(64) NOT NULL,
                processed_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                line_count INT DEFAULT 0
            )
            """,
            """
            CREATE TABLE IF NOT EXISTS config_store (
                k VARCHAR(64) PRIMARY KEY,
                v TEXT NOT NULL,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP
            )
            """
        ]
        try:
            with self.pool.connection() as conn:
                with conn.cursor() as cur:
                    for sql in schema_sql:
                        cur.execute(sql)
                    # Initialize thinking_state if empty
                    cur.execute("SELECT COUNT(*) as cnt FROM thinking_state")
                    row = cur.fetchone()
                    if row["cnt"] == 0:
                        cur.execute("""
                            INSERT INTO thinking_state
                            (id, current_cycle, total_cycles, cursor_position, phase, metrics)
                            VALUES (1, 0, 0, '', 'init', '{}')
                        """)
            self._schema_ready = True
            print("[MEMORY] Schema ready.")
        except Exception as e:
            print(f"[MEMORY] Schema creation failed: {e}")
            traceback.print_exc()

    def is_connected(self) -> bool:
        return self._connected and self._schema_ready
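    # Degradation sketch: if initialize() fails (or no host is configured),
    # _connected stays False, _execute returns None, reads fall back to
    # empty defaults, and flush() refuses to touch the network, so callers
    # keep the same code path either way:
    #
    #     mem = Memory()
    #     if not mem.initialize():
    #         print("running without persistence")  # Memory becomes a no-op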
    # ───────────────────────────────────────────────────
    # EXECUTE HELPERS
    # ───────────────────────────────────────────────────

    def _execute(self, sql: str, params: Optional[tuple] = None, fetch: str = "none") -> Any:
        """
        Execute SQL with auto-retry.
        fetch: "none", "one", "all"
        """
        if not self._connected:
            return None
        for attempt in range(config.TIDB_RETRY_ATTEMPTS):
            try:
                with self.pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.execute(sql, params)
                        if fetch == "one":
                            return cur.fetchone()
                        elif fetch == "all":
                            return cur.fetchall()
                        return True
            except pymysql.err.OperationalError as e:
                if attempt < config.TIDB_RETRY_ATTEMPTS - 1:
                    time.sleep(config.TIDB_RETRY_DELAY)
                else:
                    print(f"[MEMORY] Execute failed after retries: {e}")
                    return None
            except Exception as e:
                print(f"[MEMORY] Execute error: {e}")
                return None
    def _execute_many(self, sql: str, params_list: List[tuple]) -> bool:
        """Execute SQL for multiple parameter sets (batch insert/update)."""
        if not self._connected or not params_list:
            return False
        for attempt in range(config.TIDB_RETRY_ATTEMPTS):
            try:
                with self.pool.connection() as conn:
                    with conn.cursor() as cur:
                        cur.executemany(sql, params_list)
                        return True
            except pymysql.err.OperationalError as e:
                if attempt < config.TIDB_RETRY_ATTEMPTS - 1:
                    time.sleep(config.TIDB_RETRY_DELAY)
                else:
                    print(f"[MEMORY] ExecuteMany failed: {e}")
                    return False
            except Exception as e:
                print(f"[MEMORY] ExecuteMany error: {e}")
                return False
    # ───────────────────────────────────────────────────
    # NODE OPERATIONS
    # ───────────────────────────────────────────────────

    def save_node(self, node: dict):
        """Buffer a node for batch writing."""
        self.buffer.buffer_node(node)

    def save_nodes_immediate(self, nodes: List[dict]) -> bool:
        """Write nodes directly to DB (bypass buffer). Used for bulk import."""
        if not nodes:
            return True
        sql = """
            INSERT INTO nodes (id, type, content, vector, weight, connections, source)
            VALUES (%s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                content = VALUES(content),
                vector = VALUES(vector),
                weight = VALUES(weight),
                connections = VALUES(connections),
                source = VALUES(source),
                updated_at = CURRENT_TIMESTAMP
        """
        params = [
            (
                n["id"],
                n.get("type", "fact"),
                n.get("content", ""),
                json.dumps(n.get("vector", [])),
                n.get("weight", 1.0),
                n.get("connections", 0),
                n.get("source", "data")
            )
            for n in nodes
        ]
        return self._execute_many(sql, params)
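    # Upsert sketch: because the statement uses INSERT ... ON DUPLICATE KEY
    # UPDATE, re-importing an existing id is an update rather than an error
    # (the VALUES() form is the classic MySQL syntax, which TiDB accepts).
    # Hypothetical calls:
    #
    #     mem.save_nodes_immediate([{"id": "n1", "type": "fact", "content": "v1"}])
    #     mem.save_nodes_immediate([{"id": "n1", "type": "fact", "content": "v2"}])  # updates n1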
    def load_all_nodes(self) -> List[dict]:
        """Load all nodes from DB. Used at startup."""
        rows = self._execute(
            "SELECT id, type, content, vector, weight, connections, source, "
            "created_at, updated_at FROM nodes",
            fetch="all"
        )
        if not rows:
            return []
        nodes = []
        for row in rows:
            vector_data = row.get("vector")
            if isinstance(vector_data, str):
                vector_data = json.loads(vector_data)
            nodes.append({
                "id": row["id"],
                "type": row["type"],
                "content": row["content"],
                "vector": vector_data if vector_data else [],
                "weight": float(row["weight"]),
                "connections": int(row["connections"]),
                "source": row["source"],
                "created_at": str(row["created_at"]) if row.get("created_at") else "",
                "updated_at": str(row["updated_at"]) if row.get("updated_at") else ""
            })
        return nodes

    def delete_node(self, node_id: str):
        """Buffer a node deletion."""
        self.buffer.buffer_node_delete(node_id)

    def get_node_count(self) -> int:
        """Get total node count from DB."""
        row = self._execute("SELECT COUNT(*) as cnt FROM nodes", fetch="one")
        return row["cnt"] if row else 0

    # ───────────────────────────────────────────────────
    # EDGE OPERATIONS
    # ───────────────────────────────────────────────────

    def save_edge(self, edge: dict):
        """Buffer an edge for batch writing."""
        self.buffer.buffer_edge(edge)

    def save_edges_immediate(self, edges: List[dict]) -> bool:
        """Write edges directly to DB (bypass buffer)."""
        if not edges:
            return True
        sql = """
            INSERT INTO edges (id, from_node, to_node, relation, weight, confidence, source, used_count)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s)
            ON DUPLICATE KEY UPDATE
                weight = VALUES(weight),
                confidence = VALUES(confidence),
                used_count = VALUES(used_count),
                updated_at = CURRENT_TIMESTAMP
        """
        params = [
            (
                e["id"],
                e["from_node"],
                e["to_node"],
                e.get("relation", "related_to"),
                e.get("weight", 1.0),
                e.get("confidence", 1.0),
                e.get("source", "data"),
                e.get("used_count", 0)
            )
            for e in edges
        ]
        return self._execute_many(sql, params)

    def load_all_edges(self) -> List[dict]:
        """Load all edges from DB. Used at startup."""
        rows = self._execute(
            "SELECT id, from_node, to_node, relation, weight, confidence, "
            "source, used_count, created_at FROM edges",
            fetch="all"
        )
        if not rows:
            return []
        return [
            {
                "id": row["id"],
                "from_node": row["from_node"],
                "to_node": row["to_node"],
                "relation": row["relation"],
                "weight": float(row["weight"]),
                "confidence": float(row["confidence"]),
                "source": row["source"],
                "used_count": int(row["used_count"]),
                "created_at": str(row["created_at"]) if row.get("created_at") else ""
            }
            for row in rows
        ]

    def delete_edge(self, edge_id: str):
        """Buffer an edge deletion."""
        self.buffer.buffer_edge_delete(edge_id)

    def get_edge_count(self) -> int:
        """Get total edge count from DB."""
        row = self._execute("SELECT COUNT(*) as cnt FROM edges", fetch="one")
        return row["cnt"] if row else 0
    # ───────────────────────────────────────────────────
    # CHAIN OPERATIONS
    # ───────────────────────────────────────────────────

    def save_chain(self, chain: dict):
        """Buffer a reasoning chain for batch writing."""
        self.buffer.buffer_chain(chain)

    def load_top_chains(self, limit: int = 10000) -> List[dict]:
        """Load top-scored reasoning chains from DB."""
        rows = self._execute(
            "SELECT id, path, conclusion, confidence, used_count, created_at "
            "FROM reasoning_chains ORDER BY confidence DESC, used_count DESC LIMIT %s",
            (limit,),
            fetch="all"
        )
        if not rows:
            return []
        chains = []
        for row in rows:
            path_data = row.get("path")
            if isinstance(path_data, str):
                path_data = json.loads(path_data)
            chains.append({
                "id": row["id"],
                "path": path_data if path_data else [],
                "conclusion": row.get("conclusion", ""),
                "confidence": float(row["confidence"]),
                "used_count": int(row["used_count"]),
                "created_at": str(row["created_at"]) if row.get("created_at") else ""
            })
        return chains

    def get_chain_count(self) -> int:
        """Get total chain count."""
        row = self._execute("SELECT COUNT(*) as cnt FROM reasoning_chains", fetch="one")
        return row["cnt"] if row else 0
    def prune_weak_chains(self, min_confidence: float = 0.2, max_age_days: int = 30) -> int:
        """Delete unused, low-confidence, old chains (best-effort)."""
        self._execute(
            "DELETE FROM reasoning_chains WHERE confidence < %s "
            "AND used_count = 0 AND created_at < DATE_SUB(NOW(), INTERVAL %s DAY)",
            (min_confidence, max_age_days)
        )
        # _execute does not surface the affected rowcount, so callers
        # always receive 0 regardless of how many rows were deleted.
        return 0
    # ───────────────────────────────────────────────────
    # THINKING STATE
    # ───────────────────────────────────────────────────

    def save_thinking_state(self, state: dict):
        """Buffer thinking state update."""
        self.buffer.buffer_state(state)

    def load_thinking_state(self) -> dict:
        """Load thinking state from DB."""
        row = self._execute(
            "SELECT current_cycle, total_cycles, cursor_position, phase, "
            "metrics, started_at, updated_at FROM thinking_state WHERE id = 1",
            fetch="one"
        )
        if not row:
            return {
                "current_cycle": 0,
                "total_cycles": 0,
                "cursor_position": "",
                "phase": "init",
                "metrics": {},
                "started_at": utils.timestamp_now(),
                "updated_at": utils.timestamp_now()
            }
        metrics = row.get("metrics")
        if isinstance(metrics, str):
            try:
                metrics = json.loads(metrics)
            except (json.JSONDecodeError, TypeError):
                metrics = {}
        return {
            "current_cycle": int(row.get("current_cycle", 0)),
            "total_cycles": int(row.get("total_cycles", 0)),
            "cursor_position": row.get("cursor_position", ""),
            "phase": row.get("phase", "init"),
            "metrics": metrics if metrics else {},
            "started_at": str(row["started_at"]) if row.get("started_at") else "",
            "updated_at": str(row["updated_at"]) if row.get("updated_at") else ""
        }
    # ───────────────────────────────────────────────────
    # FILE CHECKSUMS
    # ───────────────────────────────────────────────────

    def save_file_checksum(self, filename: str, checksum: str, line_count: int = 0):
        """Buffer file checksum update.

        Note: line_count matches the schema column but is not persisted
        by the buffered path; only filename and checksum are written.
        """
        self.buffer.buffer_checksum(filename, checksum)
    def load_file_checksums(self) -> Dict[str, str]:
        """Load all file checksums. Returns {filename: checksum}."""
        rows = self._execute(
            "SELECT filename, checksum FROM file_checksums",
            fetch="all"
        )
        if not rows:
            return {}
        return {row["filename"]: row["checksum"] for row in rows}
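    # Change-detection sketch (hypothetical helper names): compare stored
    # checksums against fresh ones to skip files that have not changed
    # since the last ingest pass.
    #
    #     seen = mem.load_file_checksums()
    #     digest = utils.sha256_of_file(path)   # assumed helper in utils
    #     if seen.get(path.name) != digest:
    #         ingest(path)                      # hypothetical re-ingest step
    #         mem.save_file_checksum(path.name, digest)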
    # ───────────────────────────────────────────────────
    # CONFIG STORE
    # ───────────────────────────────────────────────────

    def save_config(self, key: str, value: str):
        """Save a config key-value pair."""
        self._execute(
            "INSERT INTO config_store (k, v) VALUES (%s, %s) "
            "ON DUPLICATE KEY UPDATE v = VALUES(v), updated_at = CURRENT_TIMESTAMP",
            (key, value)
        )

    def load_config(self, key: str, default: str = "") -> str:
        """Load a config value."""
        row = self._execute(
            "SELECT v FROM config_store WHERE k = %s",
            (key,),
            fetch="one"
        )
        return row["v"] if row else default
    # ───────────────────────────────────────────────────
    # FLUSH (Buffer → DB)
    # ───────────────────────────────────────────────────

    def flush(self) -> dict:
        """
        Flush all buffered operations to TiDB.
        Returns summary of what was flushed.
        """
        if not self._connected:
            return {"status": "not_connected", "flushed": 0}
        data = self.buffer.drain()
        summary = {
            "nodes_upserted": 0,
            "edges_upserted": 0,
            "chains_upserted": 0,
            "nodes_deleted": 0,
            "edges_deleted": 0,
            "state_updated": False,
            "checksums_updated": 0
        }
        try:
            # ── Upsert nodes ──
            if data["node_upserts"]:
                nodes = list(data["node_upserts"].values())
                if self.save_nodes_immediate(nodes):
                    summary["nodes_upserted"] = len(nodes)
            # ── Upsert edges ──
            if data["edge_upserts"]:
                edges = list(data["edge_upserts"].values())
                if self.save_edges_immediate(edges):
                    summary["edges_upserted"] = len(edges)
            # ── Upsert chains ──
            if data["chain_upserts"]:
                chains = list(data["chain_upserts"].values())
                sql = """
                    INSERT INTO reasoning_chains (id, path, conclusion, confidence, used_count)
                    VALUES (%s, %s, %s, %s, %s)
                    ON DUPLICATE KEY UPDATE
                        confidence = VALUES(confidence),
                        used_count = VALUES(used_count)
                """
                params = [
                    (
                        c["id"],
                        json.dumps(c.get("path", [])),
                        c.get("conclusion", ""),
                        c.get("confidence", 0.5),
                        c.get("used_count", 0)
                    )
                    for c in chains
                ]
                if self._execute_many(sql, params):
                    summary["chains_upserted"] = len(chains)
            # ── Delete nodes ──
            if data["node_deletes"]:
                for node_id in data["node_deletes"]:
                    self._execute("DELETE FROM nodes WHERE id = %s", (node_id,))
                    # Also delete connected edges
                    self._execute(
                        "DELETE FROM edges WHERE from_node = %s OR to_node = %s",
                        (node_id, node_id)
                    )
                summary["nodes_deleted"] = len(data["node_deletes"])
            # ── Delete edges ──
            if data["edge_deletes"]:
                for edge_id in data["edge_deletes"]:
                    self._execute("DELETE FROM edges WHERE id = %s", (edge_id,))
                summary["edges_deleted"] = len(data["edge_deletes"])
            # ── Update thinking state ──
            if data["state_update"]:
                state = data["state_update"]
                self._execute(
                    """
                    UPDATE thinking_state SET
                        current_cycle = %s,
                        total_cycles = %s,
                        cursor_position = %s,
                        phase = %s,
                        metrics = %s,
                        updated_at = CURRENT_TIMESTAMP
                    WHERE id = 1
                    """,
                    (
                        state.get("current_cycle", 0),
                        state.get("total_cycles", 0),
                        state.get("cursor_position", ""),
                        state.get("phase", ""),
                        json.dumps(state.get("metrics", {}))
                    )
                )
                summary["state_updated"] = True
            # ── Update file checksums ──
            if data["checksum_updates"]:
                sql = """
                    INSERT INTO file_checksums (filename, checksum, processed_at)
                    VALUES (%s, %s, CURRENT_TIMESTAMP)
                    ON DUPLICATE KEY UPDATE
                        checksum = VALUES(checksum),
                        processed_at = CURRENT_TIMESTAMP
                """
                params = [
                    (fname, chk)
                    for fname, chk in data["checksum_updates"].items()
                ]
                if self._execute_many(sql, params):
                    summary["checksums_updated"] = len(params)
        except Exception as e:
            print(f"[MEMORY] Flush error: {e}")
            traceback.print_exc()
            summary["error"] = str(e)
        total = (
            summary["nodes_upserted"] + summary["edges_upserted"] +
            summary["chains_upserted"] + summary["nodes_deleted"] +
            summary["edges_deleted"]
        )
        if total > 0:
            print(f"[MEMORY] Flushed: {summary}")
        return summary
    def flush_if_needed(self) -> Optional[dict]:
        """Flush only if buffer thresholds are met."""
        if self.buffer.should_flush():
            return self.flush()
        return None
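    # Main-loop sketch: the intended write path is buffer-then-maybe-flush,
    # so hot cycles stay off the network until a threshold trips.
    #
    #     while thinking:                  # hypothetical driver loop
    #         node = produce_thought()     # hypothetical producer
    #         mem.save_node(node)
    #         mem.flush_if_needed()        # batches per SYNC_INTERVAL_* config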
    # ───────────────────────────────────────────────────
    # FULL STATE LOAD (Startup)
    # ───────────────────────────────────────────────────

    def load_full_state(self) -> dict:
        """
        Load complete brain state from DB.
        Called once at startup.
        Returns dict with all components.
        """
        if not self._connected:
            return {
                "nodes": [],
                "edges": [],
                "chains": [],
                "thinking_state": {
                    "current_cycle": 0,
                    "total_cycles": 0,
                    "cursor_position": "",
                    "phase": "init",
                    "metrics": {}
                },
                "file_checksums": {},
                "loaded": False
            }
        print("[MEMORY] Loading full state from TiDB...")
        start = time.time()
        nodes = self.load_all_nodes()
        print(f"[MEMORY] Loaded {len(nodes)} nodes")
        edges = self.load_all_edges()
        print(f"[MEMORY] Loaded {len(edges)} edges")
        chains = self.load_top_chains(limit=10000)
        print(f"[MEMORY] Loaded {len(chains)} chains")
        state = self.load_thinking_state()
        print(f"[MEMORY] Loaded thinking state (cycle {state['total_cycles']})")
        checksums = self.load_file_checksums()
        print(f"[MEMORY] Loaded {len(checksums)} file checksums")
        elapsed = time.time() - start
        print(f"[MEMORY] Full state loaded in {elapsed:.1f}s")
        return {
            "nodes": nodes,
            "edges": edges,
            "chains": chains,
            "thinking_state": state,
            "file_checksums": checksums,
            "loaded": True
        }
    # ───────────────────────────────────────────────────
    # MAINTENANCE
    # ───────────────────────────────────────────────────

    def prune_weak_edges(self, threshold: Optional[float] = None) -> int:
        """Delete inferred edges below weight threshold directly from DB (best-effort)."""
        if threshold is None:
            threshold = config.PRUNE_WEIGHT_THRESHOLD
        self._execute(
            "DELETE FROM edges WHERE weight < %s AND source = 'inferred'",
            (threshold,)
        )
        return 0  # rowcount not surfaced by _execute

    def prune_orphan_nodes(self) -> int:
        """Delete inferred nodes with no edges and low weight (best-effort)."""
        self._execute(
            """
            DELETE FROM nodes WHERE connections = 0
            AND weight < %s AND source = 'inferred'
            """,
            (config.WEIGHT_MIN * 2,)
        )
        return 0  # rowcount not surfaced by _execute
    def get_db_stats(self) -> dict:
        """Get database-level statistics."""
        if not self._connected:
            return {"connected": False}
        node_count = self.get_node_count()
        edge_count = self.get_edge_count()
        chain_count = self.get_chain_count()
        # Count by source
        inferred_nodes = self._execute(
            "SELECT COUNT(*) as cnt FROM nodes WHERE source = 'inferred'",
            fetch="one"
        )
        inferred_edges = self._execute(
            "SELECT COUNT(*) as cnt FROM edges WHERE source = 'inferred'",
            fetch="one"
        )
        return {
            "connected": True,
            "total_nodes": node_count,
            "total_edges": edge_count,
            "total_chains": chain_count,
            "inferred_nodes": inferred_nodes["cnt"] if inferred_nodes else 0,
            "inferred_edges": inferred_edges["cnt"] if inferred_edges else 0,
            "buffer_pending": self.buffer.pending_count()
        }
    # ───────────────────────────────────────────────────
    # CLEANUP
    # ───────────────────────────────────────────────────

    def shutdown(self):
        """Graceful shutdown: flush buffer and close connections."""
        print("[MEMORY] Shutting down...")
        # Final flush
        if self._connected:
            try:
                self.flush()
                print("[MEMORY] Final flush completed.")
            except Exception as e:
                print(f"[MEMORY] Final flush error: {e}")
        # Close pool
        self.pool.close_all()
        print("[MEMORY] Connections closed.")