File size: 16,369 Bytes
d38c1d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""
Memory-mapped CSR (Compressed Sparse Row) storage for co-occurrence graphs.

Replaces Python dict-of-dicts with three numpy.memmap arrays:
  indptr.bin   — row pointers (V+1 int32)
  indices.bin  — column indices (nnz int32), sorted within each row
  data.bin     — edge weights (nnz float32)

43M edges: ~346MB on disk, ~50MB RSS (OS pages in only accessed rows).
Python dicts for the same: ~4GB. 11x reduction.

Write-Ahead Log (WAL) buffers real-time edge updates (teach/RLHF)
without rebuilding the CSR. Reads merge CSR + WAL transparently.
"""

import os
import math
import threading
import numpy as np

try:
    from scipy.sparse import csr_matrix as scipy_csr, coo_matrix as scipy_coo
    from scipy.sparse.linalg import norm as sparse_norm_scipy
    HAS_SCIPY = True
except ImportError:
    HAS_SCIPY = False


class MMapCSR:
    """Memory-mapped CSR matrix backed by three flat files.

    Zero-copy row access via numpy views into mmap'd arrays.
    OS page cache manages hot/cold rows automatically.
    """

    def __init__(self, path: str, readonly: bool = True):
        """Open an existing CSR from disk.

        Args:
            path: directory containing indptr.bin, indices.bin, data.bin
            readonly: if True, open read-only (no writes to mmap)
        """
        self._path = path
        mode = 'r' if readonly else 'r+'

        self._indptr = np.memmap(
            os.path.join(path, 'indptr.bin'),
            dtype=np.int32, mode=mode,
        )
        self._indices = np.memmap(
            os.path.join(path, 'indices.bin'),
            dtype=np.int32, mode=mode,
        )
        self._data = np.memmap(
            os.path.join(path, 'data.bin'),
            dtype=np.float32, mode=mode,
        )

        self.num_rows = len(self._indptr) - 1
        self.nnz = len(self._indices)

        # Pre-compute per-row norms for fast cosine
        self._row_norms = None  # lazy

        # scipy view (lazy)
        self._scipy_mat = None

    @property
    def scipy_matrix(self):
        """scipy.sparse.csr_matrix view over the same mmap'd data. No copy."""
        if self._scipy_mat is None:
            if not HAS_SCIPY:
                raise ImportError("scipy required for spmv. pip install scipy")
            self._scipy_mat = scipy_csr(
                (self._data, self._indices, self._indptr),
                shape=(self.num_rows, self.num_rows),
            )
        return self._scipy_mat

    def get_row(self, row_idx: int):
        """Get (col_indices, weights) for a row. Zero-copy numpy views.

        Returns:
            (np.ndarray[int32], np.ndarray[float32]) — column indices and weights
        """
        if row_idx < 0 or row_idx >= self.num_rows:
            return np.array([], dtype=np.int32), np.array([], dtype=np.float32)
        start = int(self._indptr[row_idx])
        end = int(self._indptr[row_idx + 1])
        if start == end:
            return np.array([], dtype=np.int32), np.array([], dtype=np.float32)
        return self._indices[start:end], self._data[start:end]

    def get_row_dict(self, row_idx: int) -> dict:
        """Get row as {col_idx: weight} dict. Compatibility shim for convergence."""
        cols, vals = self.get_row(row_idx)
        if len(cols) == 0:
            return {}
        return dict(zip(cols.astype(int), vals.astype(float)))

    def row_dot(self, row_idx: int, query: dict) -> float:
        """Dot product between a CSR row and a sparse query dict."""
        cols, vals = self.get_row(row_idx)
        if len(cols) == 0:
            return 0.0
        total = 0.0
        for i in range(len(cols)):
            c = int(cols[i])
            if c in query:
                total += float(vals[i]) * query[c]
        return total

    def row_norm(self, row_idx: int) -> float:
        """L2 norm of a row."""
        _, vals = self.get_row(row_idx)
        if len(vals) == 0:
            return 0.0
        return float(np.sqrt(np.dot(vals, vals)))

    def row_cosine(self, row_idx: int, query: dict, q_norm: float = 0.0) -> float:
        """Cosine similarity between a CSR row and a sparse query dict."""
        dot = self.row_dot(row_idx, query)
        if dot == 0.0:
            return 0.0
        rn = self.row_norm(row_idx)
        if rn == 0.0:
            return 0.0
        if q_norm == 0.0:
            q_norm = math.sqrt(sum(v * v for v in query.values()))
        if q_norm == 0.0:
            return 0.0
        return dot / (rn * q_norm)

    def spmv(self, query_vec: np.ndarray) -> np.ndarray:
        """Sparse matrix-vector multiply: CSR @ query_vec.

        Uses scipy BLAS if available, else manual row iteration.
        """
        if HAS_SCIPY:
            return np.asarray(self.scipy_matrix @ query_vec).ravel()
        # Fallback: manual
        result = np.zeros(self.num_rows, dtype=np.float32)
        for i in range(self.num_rows):
            start = int(self._indptr[i])
            end = int(self._indptr[i + 1])
            if start == end:
                continue
            cols = self._indices[start:end]
            vals = self._data[start:end]
            result[i] = np.dot(vals, query_vec[cols])
        return result

    @property
    def row_norms_vec(self):
        """Precomputed L2 norms for all rows. Lazy, cached."""
        if self._row_norms is None:
            if HAS_SCIPY:
                mat = self.scipy_matrix
                # Per-row L2 norm
                self._row_norms = np.sqrt(
                    np.asarray(mat.multiply(mat).sum(axis=1)).ravel()
                )
            else:
                norms = np.zeros(self.num_rows, dtype=np.float32)
                for i in range(self.num_rows):
                    _, vals = self.get_row(i)
                    if len(vals) > 0:
                        norms[i] = float(np.sqrt(np.dot(vals, vals)))
                self._row_norms = norms
        return self._row_norms

    def stats(self) -> dict:
        """Return size stats."""
        return {
            'num_rows': self.num_rows,
            'nnz': self.nnz,
            'avg_degree': self.nnz / max(self.num_rows, 1),
            'disk_mb': (
                os.path.getsize(os.path.join(self._path, 'indptr.bin')) +
                os.path.getsize(os.path.join(self._path, 'indices.bin')) +
                os.path.getsize(os.path.join(self._path, 'data.bin'))
            ) / (1024 * 1024),
        }


class CSRWriteAheadLog:
    """Buffers real-time edge updates without rebuilding CSR.

    teach() writes to WAL + LMDB. ask() reads CSR + WAL overlay merged.
    When WAL exceeds max_entries, signal for background CSR rebuild.
    """

    def __init__(self, max_entries: int = 100_000):
        self._log = {}       # word_idx → {neighbor_idx: weight}
        self._count = 0
        self._max = max_entries
        self._lock = threading.Lock()

    @property
    def needs_flush(self) -> bool:
        return self._count >= self._max

    @property
    def entry_count(self) -> int:
        return self._count

    def add_edge(self, word_a: int, word_b: int, weight: float):
        """Buffer a co-occurrence edge update."""
        with self._lock:
            if word_a not in self._log:
                self._log[word_a] = {}
            if word_b not in self._log:
                self._log[word_b] = {}
            old_a = self._log[word_a].get(word_b, 0)
            old_b = self._log[word_b].get(word_a, 0)
            self._log[word_a][word_b] = old_a + weight
            self._log[word_b][word_a] = old_b + weight
            self._count += 2

    def get_overlay(self, word_idx: int) -> dict:
        """Get pending WAL updates for one word. Returns {} if none."""
        with self._lock:
            return dict(self._log.get(word_idx, {}))

    def merge_read(self, csr: MMapCSR, word_idx: int) -> dict:
        """Read CSR row + WAL overlay, merged into one dict.

        This is the primary read path for convergence.
        """
        # Base from CSR (zero-copy read, converted to dict)
        if word_idx < csr.num_rows:
            base = csr.get_row_dict(word_idx)
        else:
            # New word beyond CSR — lives entirely in WAL
            base = {}

        # Overlay from WAL
        overlay = self.get_overlay(word_idx)
        if not overlay:
            # Self-connection if empty
            if not base:
                return {word_idx: 1.0}
            return base

        # Merge: WAL adds to base
        merged = dict(base)
        for k, v in overlay.items():
            merged[k] = merged.get(k, 0) + v
        if not merged:
            merged[word_idx] = 1.0
        return merged

    def clear(self):
        """Reset WAL after successful CSR rebuild."""
        with self._lock:
            self._log.clear()
            self._count = 0
            self._scipy_cache = None

    def snapshot(self) -> dict:
        """Get a frozen copy of all WAL entries (for CSR rebuild)."""
        with self._lock:
            return {k: dict(v) for k, v in self._log.items()}

    # --- Persistence (LMDB-backed, crash-safe) ---

    _EDGE_FMT = '<iif'  # word_a, word_b, weight — 12 bytes per edge
    _EDGE_SIZE = 12

    def persist_to_lmdb(self, env, wal_db):
        """Write all WAL edges to LMDB wal_edges sub-db. Atomic transaction."""
        import struct
        snapshot = self.snapshot()
        if not snapshot:
            return 0
        # Flatten to unique directed edges (a→b, not both a→b and b→a)
        seen = set()
        edges = []
        for word_a, neighbors in snapshot.items():
            for word_b, weight in neighbors.items():
                key = (min(word_a, word_b), max(word_a, word_b))
                if key not in seen:
                    seen.add(key)
                    # Clamp to float32 range to prevent overflow
                    clamped = max(-1e30, min(1e30, weight))
                    edges.append((word_a, word_b, clamped))
        with env.begin(write=True) as txn:
            # Clear old WAL entries
            txn.drop(wal_db, delete=False)
            # Write all edges as sequential keys
            for i, (a, b, w) in enumerate(edges):
                txn.put(
                    struct.pack('<i', i),
                    struct.pack(self._EDGE_FMT, a, b, w),
                    db=wal_db,
                )
            txn.put(b'__count__', struct.pack('<i', len(edges)), db=wal_db)
        return len(edges)

    def load_from_lmdb(self, env, wal_db):
        """Load persisted WAL edges from LMDB into memory. Called at startup."""
        import struct
        loaded = 0
        with env.begin(db=wal_db) as txn:
            count_raw = txn.get(b'__count__', db=wal_db)
            if not count_raw:
                return 0
            count = struct.unpack('<i', count_raw)[0]
            for i in range(count):
                val = txn.get(struct.pack('<i', i), db=wal_db)
                if val and len(val) == self._EDGE_SIZE:
                    a, b, w = struct.unpack(self._EDGE_FMT, val)
                    # Add directly to log without double-counting
                    if a not in self._log:
                        self._log[a] = {}
                    if b not in self._log:
                        self._log[b] = {}
                    self._log[a][b] = self._log[a].get(b, 0) + w
                    self._log[b][a] = self._log[b].get(a, 0) + w
                    self._count += 2
                    loaded += 1
        return loaded

    # --- Background flush thread ---

    _flush_thread = None
    _flush_stop = None
    _flush_env = None
    _flush_db = None
    _last_persisted_count = 0

    def start_background_flush(self, env, wal_db, interval_sec: float = 5.0):
        """Start background thread that persists WAL to LMDB every interval_sec."""
        import time as _time
        self._flush_env = env
        self._flush_db = wal_db
        self._flush_stop = threading.Event()
        self._last_persisted_count = self._count

        def _flush_loop():
            while not self._flush_stop.is_set():
                self._flush_stop.wait(interval_sec)
                if self._flush_stop.is_set():
                    break
                if self._count != self._last_persisted_count:
                    self.persist_to_lmdb(env, wal_db)
                    self._last_persisted_count = self._count

        self._flush_thread = threading.Thread(target=_flush_loop, daemon=True)
        self._flush_thread.start()

    def stop_background_flush(self):
        """Stop background flush and do a final persist."""
        if self._flush_stop:
            self._flush_stop.set()
        if self._flush_thread:
            self._flush_thread.join(timeout=5)
        # Final flush
        if self._flush_env and self._flush_db and self._count > self._last_persisted_count:
            self.persist_to_lmdb(self._flush_env, self._flush_db)

    _scipy_cache = None
    _scipy_cache_count = -1

    def effective_scipy_matrix(self, csr: 'MMapCSR'):
        """Get CSR + WAL merged as a single scipy sparse matrix. Cached.

        Rebuilds only when WAL has changed since last call.
        """
        if not HAS_SCIPY:
            raise ImportError("scipy required")
        if self._scipy_cache is not None and self._scipy_cache_count == self._count:
            return self._scipy_cache
        base = csr.scipy_matrix
        if self._count == 0:
            self._scipy_cache = base
            self._scipy_cache_count = 0
            return base
        # Build COO from WAL entries
        rows, cols, vals = [], [], []
        with self._lock:
            for word_idx, neighbors in self._log.items():
                for neighbor_idx, weight in neighbors.items():
                    rows.append(word_idx)
                    cols.append(neighbor_idx)
                    vals.append(weight)
            count_now = self._count
        if not rows:
            self._scipy_cache = base
            self._scipy_cache_count = count_now
            return base
        # Expand shape if WAL has new words beyond CSR
        max_idx = max(max(rows), max(cols), base.shape[0] - 1)
        shape = (max_idx + 1, max_idx + 1)
        if shape != base.shape:
            base.resize(shape)  # in-place, returns None
        wal_sparse = scipy_coo(
            (np.array(vals, dtype=np.float32),
             (np.array(rows, dtype=np.int32), np.array(cols, dtype=np.int32))),
            shape=shape,
        ).tocsr()
        self._scipy_cache = base + wal_sparse
        self._scipy_cache_count = count_now
        return self._scipy_cache


def build_csr_from_dicts(cooc: dict, num_words: int, output_path: str):
    """Build CSR files from a dict-of-dicts co-occurrence graph.

    Args:
        cooc: {word_idx: {neighbor_idx: weight, ...}, ...}
        num_words: total vocabulary size (V)
        output_path: directory to write indptr.bin, indices.bin, data.bin
    """
    os.makedirs(output_path, exist_ok=True)

    # Pass 1: count edges per row
    row_counts = np.zeros(num_words, dtype=np.int32)
    for row_idx in range(num_words):
        if row_idx in cooc:
            row_counts[row_idx] = len(cooc[row_idx])

    # Build indptr
    indptr = np.zeros(num_words + 1, dtype=np.int32)
    np.cumsum(row_counts, out=indptr[1:])
    nnz = int(indptr[-1])

    # Pass 2: fill indices and data, sorted within rows
    indices = np.zeros(nnz, dtype=np.int32)
    data = np.zeros(nnz, dtype=np.float32)

    for row_idx in range(num_words):
        if row_idx not in cooc or not cooc[row_idx]:
            continue
        start = int(indptr[row_idx])
        edges = sorted(cooc[row_idx].items(), key=lambda x: x[0])
        for j, (col, weight) in enumerate(edges):
            indices[start + j] = col
            data[start + j] = weight

    # Write files
    indptr_path = os.path.join(output_path, 'indptr.bin')
    indices_path = os.path.join(output_path, 'indices.bin')
    data_path = os.path.join(output_path, 'data.bin')

    indptr.tofile(indptr_path)
    indices.tofile(indices_path)
    data.tofile(data_path)

    print(f"CSR built: {num_words:,} rows, {nnz:,} edges, "
          f"{(os.path.getsize(indptr_path) + os.path.getsize(indices_path) + os.path.getsize(data_path)) / 1024 / 1024:.1f} MB")

    return indptr_path, indices_path, data_path