| """ |
| Memory-mapped CSR (Compressed Sparse Row) storage for co-occurrence graphs. |
| |
| Replaces Python dict-of-dicts with three numpy.memmap arrays: |
| indptr.bin — row pointers (V+1 int32) |
| indices.bin — column indices (nnz int32), sorted within each row |
| data.bin — edge weights (nnz float32) |
| |
| 43M edges: ~346MB on disk, ~50MB RSS (OS pages in only accessed rows). |
| Python dicts for the same: ~4GB. 11x reduction. |
| |
| Write-Ahead Log (WAL) buffers real-time edge updates (teach/RLHF) |
| without rebuilding the CSR. Reads merge CSR + WAL transparently. |
| """ |
|
|
| import os |
| import math |
| import threading |
| import numpy as np |
|
|
| try: |
| from scipy.sparse import csr_matrix as scipy_csr, coo_matrix as scipy_coo |
| from scipy.sparse.linalg import norm as sparse_norm_scipy |
| HAS_SCIPY = True |
| except ImportError: |
| HAS_SCIPY = False |
|
|
|
|
class MMapCSR:
    """Memory-mapped CSR matrix backed by three flat files.

    The directory passed to the constructor must contain:
        indptr.bin  -- row pointers, int32, one entry per row plus one
        indices.bin -- column indices, int32
        data.bin    -- edge weights, float32

    Rows are returned as numpy views into the mapped arrays, so only the
    pages that are actually touched need to be resident.
    """

    def __init__(self, path: str, readonly: bool = True):
        """Open an existing CSR from disk.

        Args:
            path: directory containing indptr.bin, indices.bin, data.bin
            readonly: if True, map the files read-only ('r'); otherwise
                map them read-write ('r+').

        Raises:
            FileNotFoundError: if any of the three files is missing.
        """
        self._path = path
        mode = 'r' if readonly else 'r+'

        self._indptr = self._open_array(
            os.path.join(path, 'indptr.bin'), np.int32, mode)
        self._indices = self._open_array(
            os.path.join(path, 'indices.bin'), np.int32, mode)
        self._data = self._open_array(
            os.path.join(path, 'data.bin'), np.float32, mode)

        # A zero-byte indptr file describes a matrix with no rows; give it
        # the single leading pointer every CSR needs so num_rows is 0, not -1.
        if len(self._indptr) == 0:
            self._indptr = np.zeros(1, dtype=np.int32)

        self.num_rows = len(self._indptr) - 1
        self.nnz = len(self._indices)

        # Lazily computed caches (see row_norms_vec / scipy_matrix).
        self._row_norms = None
        self._scipy_mat = None

    @staticmethod
    def _open_array(file_path: str, dtype, mode: str):
        """Map file_path as a flat array, tolerating zero-byte files.

        np.memmap refuses to map an empty file, but an edge-less graph
        legitimately has empty indices.bin / data.bin, so those are
        represented by an ordinary empty array of the same dtype.
        """
        if os.path.getsize(file_path) == 0:
            return np.empty(0, dtype=dtype)
        return np.memmap(file_path, dtype=dtype, mode=mode)

    @property
    def scipy_matrix(self):
        """scipy.sparse.csr_matrix built over the mapped arrays. Lazy, cached.

        The matrix is declared square: (num_rows, num_rows).

        Raises:
            ImportError: if scipy is not installed.
        """
        if self._scipy_mat is None:
            if not HAS_SCIPY:
                raise ImportError("scipy required for spmv. pip install scipy")
            self._scipy_mat = scipy_csr(
                (self._data, self._indices, self._indptr),
                shape=(self.num_rows, self.num_rows),
            )
        return self._scipy_mat

    def get_row(self, row_idx: int):
        """Get (col_indices, weights) for a row as numpy views.

        Out-of-range (including negative) row indices and empty rows both
        yield a pair of empty arrays rather than raising.

        Returns:
            (np.ndarray[int32], np.ndarray[float32]) -- column indices and weights
        """
        if row_idx < 0 or row_idx >= self.num_rows:
            return np.array([], dtype=np.int32), np.array([], dtype=np.float32)
        start = int(self._indptr[row_idx])
        end = int(self._indptr[row_idx + 1])
        if start == end:
            return np.array([], dtype=np.int32), np.array([], dtype=np.float32)
        return self._indices[start:end], self._data[start:end]

    def get_row_dict(self, row_idx: int) -> dict:
        """Get row as a {col_idx: weight} dict of native Python int/float.

        Compatibility shim for dict-based callers. Returns {} for empty or
        out-of-range rows.
        """
        cols, vals = self.get_row(row_idx)
        if len(cols) == 0:
            return {}
        # tolist() converts to built-in int/float in one C-level pass.
        return dict(zip(cols.tolist(), vals.tolist()))

    def row_dot(self, row_idx: int, query: dict) -> float:
        """Dot product between a CSR row and a sparse query dict.

        Args:
            row_idx: row to read; empty/out-of-range rows give 0.0.
            query: {col_idx: value}; columns absent from it contribute 0.
        """
        cols, vals = self.get_row(row_idx)
        if len(cols) == 0 or not query:
            return 0.0
        total = 0.0
        for col, weight in zip(cols.tolist(), vals.tolist()):
            if col in query:
                total += weight * query[col]
        return total

    def row_norm(self, row_idx: int) -> float:
        """L2 norm of a row (0.0 for empty or out-of-range rows)."""
        _, vals = self.get_row(row_idx)
        if len(vals) == 0:
            return 0.0
        return float(np.sqrt(np.dot(vals, vals)))

    def row_cosine(self, row_idx: int, query: dict, q_norm: float = 0.0) -> float:
        """Cosine similarity between a CSR row and a sparse query dict.

        Args:
            row_idx: row to compare.
            query: {col_idx: value}.
            q_norm: precomputed L2 norm of query; 0.0 (the default) means
                "compute it here".

        Returns:
            dot / (|row| * |query|), or 0.0 when the dot product or either
            norm is zero.
        """
        dot = self.row_dot(row_idx, query)
        if dot == 0.0:
            return 0.0
        rn = self.row_norm(row_idx)
        if rn == 0.0:
            return 0.0
        if q_norm == 0.0:
            q_norm = math.sqrt(sum(v * v for v in query.values()))
        if q_norm == 0.0:
            return 0.0
        return dot / (rn * q_norm)

    def spmv(self, query_vec: np.ndarray) -> np.ndarray:
        """Sparse matrix-vector multiply: CSR @ query_vec.

        Uses scipy when available, otherwise a per-row numpy dot product.

        Returns:
            1-D array with one entry per row.
        """
        if HAS_SCIPY:
            return np.asarray(self.scipy_matrix @ query_vec).ravel()

        query_vec = np.asarray(query_vec)
        # Use the promoted dtype so the fallback agrees with the scipy path
        # instead of always truncating to float32.
        result = np.zeros(
            self.num_rows, dtype=np.result_type(np.float32, query_vec.dtype))
        for i in range(self.num_rows):
            start = int(self._indptr[i])
            end = int(self._indptr[i + 1])
            if start == end:
                continue
            cols = self._indices[start:end]
            vals = self._data[start:end]
            result[i] = np.dot(vals, query_vec[cols])
        return result

    @property
    def row_norms_vec(self):
        """L2 norms of all rows as a 1-D array. Lazy, cached."""
        if self._row_norms is None:
            if HAS_SCIPY:
                mat = self.scipy_matrix
                # Element-wise square, then sum along each row.
                self._row_norms = np.sqrt(
                    np.asarray(mat.multiply(mat).sum(axis=1)).ravel()
                )
            else:
                norms = np.zeros(self.num_rows, dtype=np.float32)
                for i in range(self.num_rows):
                    norms[i] = self.row_norm(i)
                self._row_norms = norms
        return self._row_norms

    def stats(self) -> dict:
        """Return size stats: num_rows, nnz, avg_degree and disk_mb."""
        disk_bytes = sum(
            os.path.getsize(os.path.join(self._path, name))
            for name in ('indptr.bin', 'indices.bin', 'data.bin')
        )
        return {
            'num_rows': self.num_rows,
            'nnz': self.nnz,
            # max(..., 1) avoids dividing by zero for a matrix with no rows.
            'avg_degree': self.nnz / max(self.num_rows, 1),
            'disk_mb': disk_bytes / (1024 * 1024),
        }
|
|
|
|
class CSRWriteAheadLog:
    """Buffers real-time edge updates without rebuilding CSR.

    Edges are kept in memory as a symmetric dict-of-dicts
    ({a: {b: w}, b: {a: w}}). Reads can be served as a CSR row merged with
    the pending updates (merge_read) or as one scipy matrix
    (effective_scipy_matrix). The buffer can be persisted to / reloaded
    from an LMDB sub-database, optionally by a background thread.
    When the update counter reaches max_entries, needs_flush turns True to
    signal that the CSR should be rebuilt.
    """

    # On-disk edge record: little-endian int32 a, int32 b, float32 weight.
    _EDGE_FMT = '<iif'
    _EDGE_SIZE = 12

    # Background-flush state; class-level defaults so the attributes exist
    # even before start_background_flush() is called.
    _flush_thread = None
    _flush_stop = None
    _flush_env = None
    _flush_db = None
    _last_persisted_count = 0

    # Cache for effective_scipy_matrix(), keyed on the update counter.
    _scipy_cache = None
    _scipy_cache_count = -1

    def __init__(self, max_entries: int = 100_000):
        """Create an empty WAL.

        Args:
            max_entries: update-counter threshold at which needs_flush
                becomes True.
        """
        self._log = {}
        self._count = 0
        self._max = max_entries
        self._lock = threading.Lock()

    @property
    def needs_flush(self) -> bool:
        """True once the update counter has reached max_entries."""
        return self._count >= self._max

    @property
    def entry_count(self) -> int:
        """Update counter: grows by 2 for every buffered edge update."""
        return self._count

    def _accumulate(self, word_a: int, word_b: int, weight: float):
        """Add weight to the symmetric edge a<->b. Caller must hold _lock.

        A self-loop (a == b) is a single dict entry, so its weight is
        added exactly once.
        """
        row_a = self._log.setdefault(word_a, {})
        row_a[word_b] = row_a.get(word_b, 0) + weight
        if word_b != word_a:
            row_b = self._log.setdefault(word_b, {})
            row_b[word_a] = row_b.get(word_a, 0) + weight
        self._count += 2

    def add_edge(self, word_a: int, word_b: int, weight: float):
        """Buffer a co-occurrence edge update (accumulates on repeat)."""
        with self._lock:
            self._accumulate(word_a, word_b, weight)

    def get_overlay(self, word_idx: int) -> dict:
        """Get a copy of the pending WAL updates for one word ({} if none)."""
        with self._lock:
            return dict(self._log.get(word_idx, {}))

    def merge_read(self, csr: 'MMapCSR', word_idx: int) -> dict:
        """Read CSR row + WAL overlay, merged into one dict.

        Args:
            csr: object exposing num_rows and get_row_dict(row_idx).
            word_idx: row to read; indices at or beyond csr.num_rows are
                served from the WAL alone.

        Returns:
            {col_idx: weight} with overlay weights added to the CSR
            weights. If neither source has any edge for the word, the
            result is {word_idx: 1.0}.
        """
        base = csr.get_row_dict(word_idx) if word_idx < csr.num_rows else {}

        overlay = self.get_overlay(word_idx)
        if not overlay:
            return base if base else {word_idx: 1.0}

        merged = dict(base)
        for neighbor, delta in overlay.items():
            merged[neighbor] = merged.get(neighbor, 0) + delta
        # overlay is non-empty here, so merged always has at least one entry.
        return merged

    def clear(self):
        """Reset WAL after successful CSR rebuild."""
        with self._lock:
            self._log.clear()
            self._count = 0
            self._scipy_cache = None
            self._scipy_cache_count = -1

    def snapshot(self) -> dict:
        """Get a frozen copy of all WAL entries (for CSR rebuild)."""
        with self._lock:
            return {k: dict(v) for k, v in self._log.items()}

    def persist_to_lmdb(self, env, wal_db):
        """Replace the contents of the LMDB sub-db with the current WAL.

        Runs in one write transaction: the sub-db is emptied, each
        undirected edge is written once under its little-endian int32
        ordinal, and the edge total is stored under b'__count__'.
        An empty WAL still empties the sub-db, so edges that were cleared
        after a CSR rebuild are not replayed by a later load_from_lmdb().

        Args:
            env: LMDB environment (must provide begin(write=True)).
            wal_db: handle of the sub-database to overwrite.

        Returns:
            Number of edges written.
        """
        import struct

        # Copy the log and the counter under one lock so that the recorded
        # "persisted" counter matches exactly what is written.
        with self._lock:
            snapshot = {k: dict(v) for k, v in self._log.items()}
            count_at_snapshot = self._count

        seen = set()
        edges = []
        for word_a, neighbors in snapshot.items():
            for word_b, weight in neighbors.items():
                # The log is symmetric; keep one record per unordered pair.
                key = (min(word_a, word_b), max(word_a, word_b))
                if key in seen:
                    continue
                seen.add(key)
                # Keep the value inside float32 range so struct.pack
                # cannot overflow.
                clamped = max(-1e30, min(1e30, weight))
                edges.append((word_a, word_b, clamped))

        with env.begin(write=True) as txn:
            txn.drop(wal_db, delete=False)
            for i, (a, b, w) in enumerate(edges):
                txn.put(
                    struct.pack('<i', i),
                    struct.pack(self._EDGE_FMT, a, b, w),
                    db=wal_db,
                )
            txn.put(b'__count__', struct.pack('<i', len(edges)), db=wal_db)

        self._last_persisted_count = count_at_snapshot
        return len(edges)

    def load_from_lmdb(self, env, wal_db):
        """Load persisted WAL edges from LMDB into memory. Called at startup.

        Loaded weights are accumulated onto whatever is already buffered.
        Records that are missing or not exactly _EDGE_SIZE bytes long are
        skipped.

        Returns:
            Number of edge records applied.
        """
        import struct

        decoded = []
        with env.begin(db=wal_db) as txn:
            count_raw = txn.get(b'__count__', db=wal_db)
            if not count_raw:
                return 0
            count = struct.unpack('<i', count_raw)[0]
            for i in range(count):
                val = txn.get(struct.pack('<i', i), db=wal_db)
                if val and len(val) == self._EDGE_SIZE:
                    decoded.append(struct.unpack(self._EDGE_FMT, val))

        # Apply through the same locked path as add_edge so concurrent
        # writers are safe and self-loops are not double-counted.
        with self._lock:
            for a, b, w in decoded:
                self._accumulate(a, b, w)
        return len(decoded)

    def start_background_flush(self, env, wal_db, interval_sec: float = 5.0):
        """Start a daemon thread that persists the WAL every interval_sec.

        A persist only happens when the update counter differs from the
        value recorded at the last persist. If a flush thread is already
        running it is stopped first.
        """
        if self._flush_thread is not None and self._flush_thread.is_alive():
            self.stop_background_flush()

        stop_event = threading.Event()
        self._flush_env = env
        self._flush_db = wal_db
        self._flush_stop = stop_event
        self._last_persisted_count = self._count

        def _flush_loop():
            # Event.wait returns True as soon as a stop is requested.
            while not stop_event.wait(interval_sec):
                if self._count != self._last_persisted_count:
                    # persist_to_lmdb records the persisted counter itself.
                    self.persist_to_lmdb(env, wal_db)

        self._flush_thread = threading.Thread(target=_flush_loop, daemon=True)
        self._flush_thread.start()

    def stop_background_flush(self):
        """Stop background flush and do a final persist if anything changed."""
        if self._flush_stop is not None:
            self._flush_stop.set()
        if self._flush_thread is not None:
            self._flush_thread.join(timeout=5)

        has_target = self._flush_env is not None and self._flush_db is not None
        if has_target and self._count != self._last_persisted_count:
            self.persist_to_lmdb(self._flush_env, self._flush_db)

    def effective_scipy_matrix(self, csr: 'MMapCSR'):
        """Get CSR + WAL merged as a single scipy sparse matrix. Cached.

        The result is rebuilt only when the update counter has changed
        since the last call. With an empty WAL the base matrix itself is
        returned. When the WAL references indices beyond the base matrix,
        the result is grown to a larger square shape; the base matrix is
        left untouched.

        Args:
            csr: object exposing a scipy_matrix attribute.

        Raises:
            ImportError: if scipy is not installed.
        """
        if not HAS_SCIPY:
            raise ImportError("scipy required")
        if self._scipy_cache is not None and self._scipy_cache_count == self._count:
            return self._scipy_cache

        base = csr.scipy_matrix

        rows, cols, vals = [], [], []
        with self._lock:
            for word_idx, neighbors in self._log.items():
                for neighbor_idx, weight in neighbors.items():
                    rows.append(word_idx)
                    cols.append(neighbor_idx)
                    vals.append(weight)
            count_now = self._count

        if not rows:
            self._scipy_cache = base
            self._scipy_cache_count = count_now
            return base

        size = max(max(rows) + 1, max(cols) + 1, base.shape[0], base.shape[1])
        shape = (size, size)
        if base.shape != shape:
            # Do not call base.resize(): that would change the matrix cached
            # by the CSR object. Build a padded matrix that reuses the data
            # and indices arrays; only the row pointers are extended, with
            # the new trailing rows left empty.
            old_rows = base.shape[0]
            padded_indptr = np.empty(size + 1, dtype=base.indptr.dtype)
            padded_indptr[:old_rows + 1] = base.indptr
            padded_indptr[old_rows + 1:] = base.indptr[old_rows]
            base = scipy_csr(
                (base.data, base.indices, padded_indptr), shape=shape)

        wal_sparse = scipy_coo(
            (np.array(vals, dtype=np.float32),
             (np.array(rows, dtype=np.int32), np.array(cols, dtype=np.int32))),
            shape=shape,
        ).tocsr()
        self._scipy_cache = base + wal_sparse
        self._scipy_cache_count = count_now
        return self._scipy_cache
|
|
|
|
def build_csr_from_dicts(cooc: dict, num_words: int, output_path: str):
    """Build CSR files from a dict-of-dicts co-occurrence graph.

    Writes three flat binary files into output_path (created if needed):
    indptr.bin (num_words + 1 int32 row pointers), indices.bin (int32
    column indices, ascending within each row) and data.bin (float32
    weights). A one-line summary is printed when done.

    Only rows 0 .. num_words - 1 are written; entries of cooc keyed by
    any other index are ignored. Column indices are stored as given.

    Args:
        cooc: {word_idx: {neighbor_idx: weight, ...}, ...}
        num_words: total vocabulary size (V)
        output_path: directory to write indptr.bin, indices.bin, data.bin

    Returns:
        (indptr_path, indices_path, data_path)

    Raises:
        OverflowError: if the total edge count does not fit in the int32
            row pointers used by the file format.
    """
    os.makedirs(output_path, exist_ok=True)

    # Count edges per row in int64 so an oversized graph is detected
    # instead of silently wrapping the int32 row pointers.
    row_counts = np.zeros(num_words, dtype=np.int64)
    for row_idx in range(num_words):
        neighbors = cooc.get(row_idx)
        if neighbors:
            row_counts[row_idx] = len(neighbors)

    nnz = int(row_counts.sum())
    if nnz > np.iinfo(np.int32).max:
        raise OverflowError(
            f"{nnz:,} edges exceed the int32 row-pointer range of the CSR format")

    # indptr[i] is the offset of row i's first edge; indptr[-1] == nnz.
    indptr = np.zeros(num_words + 1, dtype=np.int32)
    indptr[1:] = np.cumsum(row_counts)

    indices = np.zeros(nnz, dtype=np.int32)
    data = np.zeros(nnz, dtype=np.float32)

    for row_idx in range(num_words):
        neighbors = cooc.get(row_idx)
        if not neighbors:
            continue
        start = int(indptr[row_idx])
        end = start + len(neighbors)
        ordered_cols = sorted(neighbors)
        indices[start:end] = ordered_cols
        data[start:end] = [neighbors[col] for col in ordered_cols]

    indptr_path = os.path.join(output_path, 'indptr.bin')
    indices_path = os.path.join(output_path, 'indices.bin')
    data_path = os.path.join(output_path, 'data.bin')

    indptr.tofile(indptr_path)
    indices.tofile(indices_path)
    data.tofile(data_path)

    total_bytes = (
        os.path.getsize(indptr_path)
        + os.path.getsize(indices_path)
        + os.path.getsize(data_path)
    )
    print(f"CSR built: {num_words:,} rows, {nnz:,} edges, "
          f"{total_bytes / 1024 / 1024:.1f} MB")

    return indptr_path, indices_path, data_path
|
|