File size: 14,643 Bytes
ebaf2ce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
"""
MLE Memory Module: Sparse Address Table
========================================
Distributed memory indexed by 4096-bit binary vectors.
Semantic proximity is encoded via Hamming distance.

Features:
- Bit-packed storage (512 bytes/vector) with cache-aligned layout
- LSH index for sub-linear approximate nearest neighbor search
- Multi-resolution indexing (coarse + fine search)
- Metadata/payload attachment per entry
"""

import numpy as np
from collections import defaultdict
from typing import List, Tuple, Optional, Dict, Any
import logging

from ..utils.simd_ops import (
    N_BITS, N_WORDS, N_BYTES,
    random_binary_vector, random_binary_vectors,
    hamming_distance, hamming_batch, hamming_topk,
    xor_vectors, popcount, majority_vote, hamming_similarity
)

logger = logging.getLogger(__name__)


class HammingLSH:
    """Locality-Sensitive Hashing for Hamming space.

    Uses random bit sampling as the LSH family:
        h_i(v) = v[bit_index_i]
        P(h_i(a) == h_i(b)) = 1 - hamming(a, b) / n

    Each table concatenates ``n_projections`` sampled bits into one
    signature, so two vectors land in the same bucket of a given table with
    probability (1 - hamming(a, b) / n) ** n_projections. Many tables are
    used to amplify recall.
    """

    def __init__(
        self,
        n_bits: int = N_BITS,
        n_tables: int = 32,
        n_projections: int = 8,
        seed: int = 42
    ):
        """Sample the per-table bit positions and create empty hash tables.

        Args:
            n_bits: length of the unpacked bit vectors being hashed.
            n_tables: number of independent hash tables.
            n_projections: sampled bits per signature; must not exceed
                ``n_bits`` because sampling is without replacement.
            seed: seed for the bit sampling, making the index reproducible.
        """
        self.n_bits = n_bits
        self.n_tables = n_tables
        self.n_projections = n_projections

        rng = np.random.RandomState(seed)
        # Random bit indices for each table: which bits to sample.
        # Without replacement, so positions within one signature are distinct.
        self.bit_indices = [
            rng.choice(n_bits, n_projections, replace=False)
            for _ in range(n_tables)
        ]

        # Hash tables: table_idx -> {hash_key -> list of vector indices}
        self.tables: List[Dict[bytes, List[int]]] = [
            defaultdict(list) for _ in range(n_tables)
        ]
        # Number of add() calls performed.
        self.n_indexed = 0

    def _compute_hash(self, bits_unpacked: np.ndarray, table_idx: int) -> bytes:
        """Extract the signature of table ``table_idx`` from an unpacked bit array
        and pack it into a hashable bytes key."""
        sig = bits_unpacked[self.bit_indices[table_idx]]
        return np.packbits(sig).tobytes()

    def _unpack_vector(self, packed: np.ndarray) -> np.ndarray:
        """Unpack a packed (e.g. uint64) vector into a 0/1 uint8 bit array."""
        return np.unpackbits(packed.view(np.uint8))

    def add(self, packed_vector: np.ndarray, idx: int):
        """Add a single vector, identified by ``idx``, to all hash tables."""
        bits = self._unpack_vector(packed_vector)
        for t in range(self.n_tables):
            h = self._compute_hash(bits, t)
            self.tables[t][h].append(idx)
        self.n_indexed += 1

    def add_batch(self, packed_vectors: np.ndarray, start_idx: int = 0):
        """Add multiple vectors; row i is indexed as ``start_idx + i``."""
        for offset, vec in enumerate(packed_vectors):
            self.add(vec, start_idx + offset)

    def query_candidates(self, packed_query: np.ndarray, max_candidates: int = 2000) -> np.ndarray:
        """Find candidate indices via exact-bucket LSH lookups (before reranking).

        Tables are scanned in order and scanning stops once at least
        ``max_candidates`` distinct indices have been collected.

        Returns:
            Deduplicated int64 candidate indices, at most ``max_candidates``,
            in the order they were first found.
        """
        if max_candidates <= 0:
            return np.empty(0, dtype=np.int64)
        bits = self._unpack_vector(packed_query)
        # A dict is an insertion-ordered set: truncation keeps the earliest
        # hits instead of an arbitrary hash-ordered subset of a set().
        candidates: Dict[int, None] = {}
        for t in range(self.n_tables):
            bucket = self.tables[t].get(self._compute_hash(bits, t))
            if bucket:
                candidates.update(dict.fromkeys(bucket))
                if len(candidates) >= max_candidates:
                    break
        return np.array(list(candidates)[:max_candidates], dtype=np.int64)

    def query_multi_probe(self, packed_query: np.ndarray, n_probes: int = 3,
                          max_candidates: int = 2000) -> np.ndarray:
        """Multi-probe LSH: also check neighboring buckets by flipping bits.

        Probes, per table, the exact bucket, the buckets reached by flipping
        each of the first ``min(n_probes, n_projections)`` signature bits, and
        (when that count is >= 2) every pair of those bits.

        Probes are visited in order of increasing number of flipped bits
        across ALL tables: every table's exact bucket first, then all 1-bit
        flips, then all 2-bit flips. An exact-signature match is far more
        likely for a near neighbor than a perturbed one. Visiting table by
        table could exhaust ``max_candidates`` on a single table's probes.

        Returns:
            Deduplicated int64 candidate indices, at most ``max_candidates``,
            in probe order.
        """
        if max_candidates <= 0:
            return np.empty(0, dtype=np.int64)
        bits = self._unpack_vector(packed_query)
        # Extract each table's short signature once; probes flip entries of
        # these small arrays instead of copying the full bit vector.
        signatures = [bits[idx] for idx in self.bit_indices]

        n_flip = min(n_probes, self.n_projections)
        perturbations: List[Tuple[int, ...]] = [()]
        perturbations.extend((i,) for i in range(n_flip))
        perturbations.extend(
            (i, j) for i in range(n_flip) for j in range(i + 1, n_flip)
        )

        candidates: Dict[int, None] = {}  # insertion-ordered set
        for flips in perturbations:
            flip_positions = list(flips)
            for t, sig in enumerate(signatures):
                probe = sig
                if flip_positions:
                    probe = sig.copy()
                    probe[flip_positions] ^= 1
                bucket = self.tables[t].get(np.packbits(probe).tobytes())
                if not bucket:
                    continue
                candidates.update(dict.fromkeys(bucket))
                if len(candidates) >= max_candidates:
                    return np.array(list(candidates)[:max_candidates], dtype=np.int64)

        return np.array(list(candidates), dtype=np.int64)


class MemoryEntry:
    """One record of the Sparse Address Table.

    Bundles the address (index key) and content (stored data) vectors with
    per-entry bookkeeping. Uses ``__slots__`` so instances carry no
    per-instance ``__dict__``.
    """
    __slots__ = ['address', 'content', 'metadata', 'activation', 'timestamp']

    def __init__(self, address: np.ndarray, content: np.ndarray,
                 metadata: Optional[Dict[str, Any]] = None):
        # Bookkeeping starts from a clean slate.
        self.timestamp = 0              # last access time
        self.activation = 0.0           # current activation level
        # A falsy metadata argument (None or an empty dict) becomes a new dict.
        self.metadata = metadata if metadata else {}
        self.content = content          # (N_WORDS,) uint64 - stored data
        self.address = address          # (N_WORDS,) uint64 - the index key


class SparseAddressTable:
    """
    Distributed memory indexed by 4096-bit binary vectors.

    Architecture:
    - Primary storage: contiguous (capacity, N_WORDS) uint64 matrix for SIMD
      batch ops; only the first ``size`` rows are live
    - LSH index: multi-table bit-sampling (HammingLSH) used to pre-filter
      candidates before exact Hamming reranking
    - Content storage: separate matrix (decoupled address/content)
    - Activation tracking: per-entry float activation for energy-based dynamics
    - Symbol table: optional name -> index mapping for named concepts

    Memory layout is Structure of Arrays (SoA) for cache locality
    during batch Hamming distance computation. Storage grows automatically
    when ``capacity`` is exceeded; this class provides no way to remove entries.
    """

    def __init__(
        self,
        capacity: int = 100_000,
        lsh_tables: int = 32,
        lsh_projections: int = 8,
        lsh_seed: int = 42
    ):
        """Pre-allocate storage and build an empty LSH index.

        Args:
            capacity: initial number of rows allocated; grown on demand.
            lsh_tables: number of LSH hash tables.
            lsh_projections: sampled bits per LSH signature.
            lsh_seed: seed for the LSH bit sampling.
        """
        self.capacity = capacity
        self.size = 0

        # SoA layout: addresses and contents as contiguous matrices
        self._addresses = np.zeros((capacity, N_WORDS), dtype=np.uint64)
        self._contents = np.zeros((capacity, N_WORDS), dtype=np.uint64)

        # Metadata and activation stored separately; unused slots hold None.
        self._metadata: List[Optional[Dict[str, Any]]] = [None] * capacity
        self._activations = np.zeros(capacity, dtype=np.float64)
        self._timestamps = np.zeros(capacity, dtype=np.int64)

        # LSH index — defaults use short signatures (8-bit) with many
        # tables (32) for high recall on 4096-bit vectors
        self.lsh = HammingLSH(
            n_bits=N_BITS,
            n_tables=lsh_tables,
            n_projections=lsh_projections,
            seed=lsh_seed
        )

        # Global step counter for timestamps (incremented on every store)
        self._step = 0

        # Symbol table: name -> index mapping for named concepts
        self._symbol_table: Dict[str, int] = {}

    @property
    def addresses(self) -> np.ndarray:
        """Active address vectors (a view, not a copy). Shape: (size, N_WORDS)."""
        return self._addresses[:self.size]

    @property
    def contents(self) -> np.ndarray:
        """Active content vectors (a view, not a copy). Shape: (size, N_WORDS)."""
        return self._contents[:self.size]

    @property
    def activations(self) -> np.ndarray:
        """Active activation levels (a view, not a copy). Shape: (size,)."""
        return self._activations[:self.size]

    def store(self, address: np.ndarray, content: np.ndarray,
              metadata: Optional[Dict[str, Any]] = None,
              name: Optional[str] = None) -> int:
        """Store a new entry and return its index.

        Args:
            address: key vector, written into a (N_WORDS,) uint64 row and
                indexed in the LSH tables.
            content: payload vector, written into a (N_WORDS,) uint64 row.
            metadata: optional dict attached to the entry (stored by
                reference; a falsy value is replaced by a new empty dict).
            name: optional symbolic name; reusing a name rebinds it to the
                new entry.

        Returns:
            Index of the newly stored entry.
        """
        if self.size >= self.capacity:
            self._grow()

        idx = self.size
        self._addresses[idx] = address
        self._contents[idx] = content
        self._metadata[idx] = metadata or {}
        self._timestamps[idx] = self._step
        self._step += 1

        # Hash the row as stored (contiguous uint64) rather than the caller's
        # object, so the LSH key always matches the data query_nearest reranks.
        self.lsh.add(self._addresses[idx], idx)

        if name:
            self._symbol_table[name] = idx

        self.size += 1
        return idx

    def store_concept(self, name: str, content: Optional[np.ndarray] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> int:
        """Store a named concept under a freshly generated random address.

        If ``content`` is None a random content vector is generated too. The
        stored metadata is a copy of ``metadata`` with ``'name'`` added; the
        caller's dict is not modified.
        """
        address = random_binary_vector()
        if content is None:
            content = random_binary_vector()
        # Copy before adding 'name' so the caller's dict is not mutated.
        meta = dict(metadata) if metadata else {}
        meta['name'] = name
        return self.store(address, content, metadata=meta, name=name)

    def get_by_name(self, name: str) -> Optional[Tuple[np.ndarray, np.ndarray, Dict]]:
        """Retrieve entry by symbolic name.

        Returns:
            (address copy, content copy, metadata dict) or None if the name is
            unknown. The metadata dict is the stored object, not a copy.
        """
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return (self._addresses[idx].copy(),
                self._contents[idx].copy(),
                self._metadata[idx])

    def get_address_by_name(self, name: str) -> Optional[np.ndarray]:
        """Get a copy of the address vector for a named concept, or None."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._addresses[idx].copy()

    def get_content_by_name(self, name: str) -> Optional[np.ndarray]:
        """Get a copy of the content vector for a named concept, or None."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._contents[idx].copy()

    def query_nearest(self, query: np.ndarray, k: int = 10,
                      use_lsh: bool = True) -> List[Tuple[int, int]]:
        """Find k nearest entries by Hamming distance to query address.

        Args:
            query: (N_WORDS,) uint64 query vector
            k: number of results; ``k <= 0`` returns an empty list
            use_lsh: if True (and more than 1000 entries are stored), use the
                LSH pre-filter followed by exact reranking; otherwise exact scan.
                If LSH yields fewer than ``min(k, size)`` candidates, every
                live entry is reranked instead.

        Returns:
            List of (index, distance) tuples, sorted by distance ascending.
        """
        if self.size == 0 or k <= 0:
            return []

        if use_lsh and self.size > 1000:
            # LSH pre-filter → exact rerank
            candidates = self.lsh.query_multi_probe(query, max_candidates=max(k * 10, 2000))
            if len(candidates) < min(k, self.size):
                # LSH recall too low to fill k results (including the
                # zero-candidate case): rerank every live entry instead.
                candidates = np.arange(self.size, dtype=np.int64)
            candidate_vecs = np.ascontiguousarray(self._addresses[candidates])
            dists = hamming_batch(query, candidate_vecs)
            if k < len(candidates):
                top_local = np.argpartition(dists, k)[:k]
            else:
                top_local = np.arange(len(candidates))
            # argpartition leaves the selected k unordered; sort just those.
            sorted_local = top_local[np.argsort(dists[top_local])]
            return [(int(candidates[i]), int(dists[i])) for i in sorted_local]

        # Exact search
        indices, distances = hamming_topk(query, self.addresses, k=k)
        return [(int(idx), int(dist)) for idx, dist in zip(indices, distances)]

    def query_radius(self, query: np.ndarray, radius: int) -> List[Tuple[int, int]]:
        """Find all entries within Hamming ``radius`` (inclusive) of ``query``.

        Exact scan over all live entries; results are in index order.
        """
        if self.size == 0:
            return []
        dists = hamming_batch(query, self.addresses)
        indices = np.where(dists <= radius)[0]
        return [(int(i), int(dists[i])) for i in indices]

    def activate(self, indices: np.ndarray, strengths: np.ndarray):
        """Set activation levels for specified entries.

        Indices are not checked against ``size``.
        """
        self._activations[indices] = strengths

    def decay_activations(self, factor: float = 0.95):
        """Multiply every live entry's activation by ``factor``."""
        self._activations[:self.size] *= factor

    def get_active(self, threshold: float = 0.1) -> np.ndarray:
        """Get indices of live entries with activation strictly above ``threshold``."""
        return np.where(self._activations[:self.size] > threshold)[0]

    def read_activated(self, threshold: float = 0.1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read contents of activated entries.

        Returns: (indices, content_vectors, activation_strengths); empty
        arrays of the matching dtypes when nothing is above ``threshold``.
        """
        active_idx = self.get_active(threshold)
        if len(active_idx) == 0:
            return (np.array([], dtype=np.int64),
                    np.zeros((0, N_WORDS), dtype=np.uint64),
                    np.array([], dtype=np.float64))
        return (active_idx,
                self._contents[active_idx],
                self._activations[active_idx])

    def _grow(self, factor: float = 1.5):
        """Reallocate storage with a larger capacity, preserving live entries.

        New capacity is ``int(capacity * factor)`` but at least
        ``capacity + 1``. Otherwise tiny capacities would not grow:
        ``int(1 * 1.5) == 1`` and ``int(0 * 1.5) == 0``.
        """
        new_cap = max(self.capacity + 1, int(self.capacity * factor))
        logger.info("Growing SparseAddressTable from %d to %d", self.capacity, new_cap)

        new_addr = np.zeros((new_cap, N_WORDS), dtype=np.uint64)
        new_cont = np.zeros((new_cap, N_WORDS), dtype=np.uint64)
        new_act = np.zeros(new_cap, dtype=np.float64)
        new_ts = np.zeros(new_cap, dtype=np.int64)

        new_addr[:self.size] = self._addresses[:self.size]
        new_cont[:self.size] = self._contents[:self.size]
        new_act[:self.size] = self._activations[:self.size]
        new_ts[:self.size] = self._timestamps[:self.size]

        self._addresses = new_addr
        self._contents = new_cont
        self._activations = new_act
        self._timestamps = new_ts
        self._metadata.extend([None] * (new_cap - self.capacity))
        self.capacity = new_cap

    def stats(self) -> Dict[str, Any]:
        """Return memory statistics.

        ``memory_mb`` counts only the live address and content rows, not the
        allocated capacity, metadata, or LSH tables. ``active_entries`` uses
        a fixed 0.1 activation threshold.
        """
        mem_bytes = self.size * N_BYTES * 2  # addresses + contents
        return {
            'size': self.size,
            'capacity': self.capacity,
            'memory_mb': mem_bytes / (1024 * 1024),
            'lsh_tables': self.lsh.n_tables,
            'lsh_projections': self.lsh.n_projections,
            'active_entries': int((self._activations[:self.size] > 0.1).sum()),
            'named_symbols': len(self._symbol_table),
        }

    def __repr__(self):
        return (f"SparseAddressTable(size={self.size}, capacity={self.capacity}, "
                f"symbols={len(self._symbol_table)})")