| """ |
| MLE Memory Module: Sparse Address Table |
| ======================================== |
| Distributed memory indexed by 4096-bit binary vectors. |
| Semantic proximity is encoded via Hamming distance. |
| |
| Features: |
| - Bit-packed storage (512 bytes/vector) with cache-aligned layout |
| - LSH index for sub-linear approximate nearest neighbor search |
| - Multi-resolution indexing (coarse + fine search) |
| - Metadata/payload attachment per entry |
| """ |
|
|
| import numpy as np |
| from collections import defaultdict |
| from typing import List, Tuple, Optional, Dict, Any |
| import logging |
|
|
| from ..utils.simd_ops import ( |
| N_BITS, N_WORDS, N_BYTES, |
| random_binary_vector, random_binary_vectors, |
| hamming_distance, hamming_batch, hamming_topk, |
| xor_vectors, popcount, majority_vote, hamming_similarity |
| ) |
|
|
# Module-level logger; SparseAddressTable._grow reports storage growth through it.
logger = logging.getLogger(__name__)
|
|
|
|
class HammingLSH:
    """Locality-Sensitive Hashing for Hamming space.

    Uses random bit sampling as the LSH family:
        h_i(v) = v[bit_index_i]
        P(h(a) == h(b)) = 1 - hamming(a,b)/n

    Multiple hash tables with K-bit signatures for amplification: each of
    ``n_tables`` tables samples ``n_projections`` distinct bit positions, and
    a vector is filed under the bucket keyed by the packed bytes of those
    sampled bits.
    """

    def __init__(
        self,
        n_bits: int = N_BITS,
        n_tables: int = 32,
        n_projections: int = 8,
        seed: int = 42
    ):
        self.n_bits = n_bits
        self.n_tables = n_tables
        self.n_projections = n_projections

        # Seeded RNG: indexes built with the same seed sample the same bits.
        rng = np.random.RandomState(seed)
        # One array of n_projections *distinct* bit positions per table
        # (replace=False), so flipping signature bit i is equivalent to
        # flipping exactly one bit of the full vector.
        self.bit_indices = [
            rng.choice(n_bits, n_projections, replace=False)
            for _ in range(n_tables)
        ]

        # Per table: signature bytes -> list of indexed vector ids.
        self.tables: List[Dict[bytes, List[int]]] = [
            defaultdict(list) for _ in range(n_tables)
        ]
        self.n_indexed = 0

    @staticmethod
    def _pack_signature(sig: np.ndarray) -> bytes:
        """Pack a 0/1 signature array into the bytes used as a bucket key."""
        return np.packbits(sig).tobytes()

    def _compute_hash(self, bits_unpacked: np.ndarray, table_idx: int) -> bytes:
        """Return the bucket key of an unpacked bit array for one table."""
        return self._pack_signature(bits_unpacked[self.bit_indices[table_idx]])

    def _unpack_vector(self, packed: np.ndarray) -> np.ndarray:
        """Unpack a packed vector (normally uint64) into a flat 0/1 uint8 array.

        The input is made contiguous first because reinterpreting it as bytes
        with ``.view(np.uint8)`` requires a contiguous buffer. The resulting
        bit order depends on the native byte order, but it is applied
        identically when adding and when querying.
        """
        return np.unpackbits(np.ascontiguousarray(packed).view(np.uint8))

    def add(self, packed_vector: np.ndarray, idx: int):
        """Add a single vector, under id ``idx``, to every hash table."""
        bits = self._unpack_vector(packed_vector)
        for t, table in enumerate(self.tables):
            table[self._compute_hash(bits, t)].append(idx)
        self.n_indexed += 1

    def add_batch(self, packed_vectors: np.ndarray, start_idx: int = 0):
        """Add rows of ``packed_vectors`` under consecutive ids from ``start_idx``."""
        for offset, vec in enumerate(packed_vectors):
            self.add(vec, start_idx + offset)

    def query_candidates(self, packed_query: np.ndarray, max_candidates: int = 2000) -> np.ndarray:
        """Find candidate indices via LSH (before exact reranking).

        Collects the exact-match bucket from each table, stopping early once
        at least ``max_candidates`` ids are gathered.

        Returns:
            Deduplicated candidate ids (int64), at most ``max_candidates``.
        """
        bits = self._unpack_vector(packed_query)
        candidates = set()
        for t, table in enumerate(self.tables):
            # .get does not insert empty buckets into the defaultdict.
            candidates.update(table.get(self._compute_hash(bits, t), ()))
            if len(candidates) >= max_candidates:
                break
        return np.array(list(candidates)[:max_candidates], dtype=np.int64)

    def query_multi_probe(self, packed_query: np.ndarray, n_probes: int = 3,
                          max_candidates: int = 2000) -> np.ndarray:
        """Multi-probe LSH: also check neighboring buckets by flipping bits.

        For each table this probes:
          1. the exact bucket of the query's signature;
          2. the buckets reached by flipping each one of the first
             ``m = min(n_probes, n_projections)`` signature bits;
          3. when ``m >= 2``, the buckets reached by flipping every pair of
             those ``m`` bits.
        Probing stops after the table at which at least ``max_candidates``
        ids have been collected. This increases recall at the cost of more
        bucket lookups.

        Returns:
            Deduplicated candidate ids (int64), at most ``max_candidates``.
        """
        bits = self._unpack_vector(packed_query)
        n_flip = max(0, min(n_probes, self.n_projections))
        candidates = set()

        for t, table in enumerate(self.tables):
            # Fancy indexing copies, so flipping sig never touches `bits`.
            sig = bits[self.bit_indices[t]]
            candidates.update(table.get(self._pack_signature(sig), ()))

            # Single-bit probes: flip, look up, restore.
            for i in range(n_flip):
                sig[i] ^= 1
                candidates.update(table.get(self._pack_signature(sig), ()))
                sig[i] ^= 1

            # Two-bit probes over the same leading signature bits.
            for i in range(n_flip):
                for j in range(i + 1, n_flip):
                    sig[i] ^= 1
                    sig[j] ^= 1
                    candidates.update(table.get(self._pack_signature(sig), ()))
                    sig[i] ^= 1
                    sig[j] ^= 1

            if len(candidates) >= max_candidates:
                break

        return np.array(list(candidates)[:max_candidates], dtype=np.int64)
|
|
|
|
class MemoryEntry:
    """A single entry in the Sparse Address Table.

    Pairs an address vector with a content vector plus a metadata dict.
    ``activation`` starts at 0.0 and ``timestamp`` at 0. ``__slots__`` keeps
    instances free of a per-instance ``__dict__``.
    """

    __slots__ = ['address', 'content', 'metadata', 'activation', 'timestamp']

    def __init__(self, address: np.ndarray, content: np.ndarray,
                 metadata: Optional[Dict[str, Any]] = None):
        self.address, self.content = address, content
        # Falsy metadata (None or an empty dict) is replaced by a new dict.
        self.metadata = metadata if metadata else {}
        self.timestamp = 0
        self.activation = 0.0
|
|
|
|
class SparseAddressTable:
    """
    Distributed memory indexed by 4096-bit binary vectors.

    Architecture:
    - Primary storage: contiguous (N, N_WORDS) uint64 matrix for SIMD batch ops
    - LSH index: multi-table bit-sampling for sub-linear ANN search
    - Content storage: separate matrix (decoupled address/content)
    - Activation tracking: for energy-based dynamics

    Memory layout is Structure of Arrays (SoA) for cache locality
    during batch Hamming distance computation.

    Storage is preallocated to ``capacity`` rows and reallocated by
    ``_grow`` when full; only the first ``size`` rows are live.
    """

    # At or below this many entries query_nearest always does an exact scan;
    # the LSH pre-filter is only used on larger tables.
    _LSH_MIN_SIZE = 1000

    def __init__(
        self,
        capacity: int = 100_000,
        lsh_tables: int = 32,
        lsh_projections: int = 8,
        lsh_seed: int = 42
    ):
        self.capacity = capacity
        self.size = 0

        # Address and content matrices: one packed row per entry.
        self._addresses = np.zeros((capacity, N_WORDS), dtype=np.uint64)
        self._contents = np.zeros((capacity, N_WORDS), dtype=np.uint64)

        # Per-entry side data; slots at or beyond ``size`` are unused.
        self._metadata: List[Optional[Dict[str, Any]]] = [None] * capacity
        self._activations = np.zeros(capacity, dtype=np.float64)
        self._timestamps = np.zeros(capacity, dtype=np.int64)

        # Approximate nearest-neighbour index over the address rows.
        self.lsh = HammingLSH(
            n_bits=N_BITS,
            n_tables=lsh_tables,
            n_projections=lsh_projections,
            seed=lsh_seed
        )

        # Insertion counter; its value at store time goes into _timestamps.
        self._step = 0

        # name -> entry index for symbolic lookup.
        self._symbol_table: Dict[str, int] = {}

    @property
    def addresses(self) -> np.ndarray:
        """Active address vectors (a view). Shape: (size, N_WORDS)."""
        return self._addresses[:self.size]

    @property
    def contents(self) -> np.ndarray:
        """Active content vectors (a view). Shape: (size, N_WORDS)."""
        return self._contents[:self.size]

    @property
    def activations(self) -> np.ndarray:
        """Active activation levels (a view). Shape: (size,)."""
        return self._activations[:self.size]

    def store(self, address: np.ndarray, content: np.ndarray,
              metadata: Optional[Dict[str, Any]] = None,
              name: Optional[str] = None) -> int:
        """Append a new entry and return its index.

        Args:
            address: packed address vector, copied into the uint64 address
                matrix (numpy casts/broadcasts it on assignment).
            content: packed content vector, copied into the content matrix.
            metadata: dict attached to the entry (stored by reference); None
                or an empty dict becomes a fresh ``{}``.
            name: if truthy, registers ``name -> index`` in the symbol table,
                replacing any previous entry registered under that name.

        Returns:
            Index of the new entry.
        """
        if self.size >= self.capacity:
            self._grow()

        idx = self.size
        self._addresses[idx] = address
        self._contents[idx] = content
        self._metadata[idx] = metadata or {}
        self._timestamps[idx] = self._step
        self._step += 1

        # Hash the stored row, not the caller's object: the row is always a
        # contiguous uint64 vector, so the LSH sees exactly what is stored,
        # whatever array-like the caller passed.
        self.lsh.add(self._addresses[idx], idx)

        if name:
            self._symbol_table[name] = idx

        self.size += 1
        return idx

    def store_concept(self, name: str, content: Optional[np.ndarray] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> int:
        """Store a named concept with a random address (and random content if none given).

        The stored metadata is a copy of ``metadata`` with ``'name'`` set;
        the caller's dict is left unmodified.
        """
        address = random_binary_vector()
        if content is None:
            content = random_binary_vector()
        # Copy so that adding 'name' does not mutate the caller's dict.
        meta = dict(metadata) if metadata else {}
        meta['name'] = name
        return self.store(address, content, metadata=meta, name=name)

    def get_by_name(self, name: str) -> Optional[Tuple[np.ndarray, np.ndarray, Dict]]:
        """Return (address copy, content copy, metadata) for ``name``, or None if unknown."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return (self._addresses[idx].copy(),
                self._contents[idx].copy(),
                self._metadata[idx])

    def get_address_by_name(self, name: str) -> Optional[np.ndarray]:
        """Return a copy of the address vector for ``name``, or None if unknown."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._addresses[idx].copy()

    def get_content_by_name(self, name: str) -> Optional[np.ndarray]:
        """Return a copy of the content vector for ``name``, or None if unknown."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._contents[idx].copy()

    def query_nearest(self, query: np.ndarray, k: int = 10,
                      use_lsh: bool = True) -> List[Tuple[int, int]]:
        """Find k nearest entries by Hamming distance to query address.

        Args:
            query: (N_WORDS,) uint64 query vector
            k: number of results
            use_lsh: if True and the table holds more than ``_LSH_MIN_SIZE``
                entries, use the LSH pre-filter followed by exact reranking;
                otherwise do an exact scan.

        Returns:
            Up to ``k`` (index, distance) tuples, sorted by distance
            ascending. Empty when the table is empty or ``k <= 0``.
        """
        # k <= 0 must short-circuit: with a negative k, argpartition(d, k)[:k]
        # would return almost every candidate instead of none.
        if self.size == 0 or k <= 0:
            return []

        if not (use_lsh and self.size > self._LSH_MIN_SIZE):
            # Exact scan over every live address.
            indices, distances = hamming_topk(query, self.addresses, k=k)
            return [(int(i), int(d)) for i, d in zip(indices, distances)]

        candidates = self.lsh.query_multi_probe(query, max_candidates=max(k * 10, 2000))
        if len(candidates) < min(k, self.size):
            # LSH recall too low to fill k results: rerank everything instead.
            candidates = np.arange(self.size, dtype=np.int64)

        candidate_vecs = np.ascontiguousarray(self._addresses[candidates])
        dists = hamming_batch(query, candidate_vecs)
        if k < len(candidates):
            # Unordered k smallest in O(n); sorted below.
            top_local = np.argpartition(dists, k)[:k]
        else:
            top_local = np.arange(len(candidates))
        sorted_local = top_local[np.argsort(dists[top_local])]
        return [(int(candidates[i]), int(dists[i])) for i in sorted_local]

    def query_radius(self, query: np.ndarray, radius: int) -> List[Tuple[int, int]]:
        """Return (index, distance) for every entry within Hamming ``radius``, in index order."""
        if self.size == 0:
            return []
        dists = hamming_batch(query, self.addresses)
        indices = np.where(dists <= radius)[0]
        return [(int(i), int(dists[i])) for i in indices]

    def activate(self, indices: np.ndarray, strengths: np.ndarray):
        """Set activation levels for specified live entries.

        Args:
            indices: integer index/indices in ``[0, size)`` (or a boolean
                mask, passed straight through to numpy indexing).
            strengths: value(s) assigned; numpy broadcasting applies.

        Raises:
            IndexError: if any integer index lies outside ``[0, size)``.
                Without this check, such indices silently wrote into unused
                preallocated slots; negative ones wrap from ``capacity``,
                not ``size``.
        """
        idx = np.asarray(indices)
        if idx.dtype.kind in 'iu' and idx.size and (idx.min() < 0 or idx.max() >= self.size):
            raise IndexError(
                f"activation index out of range for table of size {self.size}")
        # Index with the caller's object so empty lists keep working
        # (np.asarray([]) is float64 and not a valid index array).
        self._activations[indices] = strengths

    def decay_activations(self, factor: float = 0.95):
        """Multiply every live activation by ``factor`` (exponential decay)."""
        self._activations[:self.size] *= factor

    def get_active(self, threshold: float = 0.1) -> np.ndarray:
        """Indices of live entries whose activation is strictly above ``threshold``."""
        return np.where(self._activations[:self.size] > threshold)[0]

    def read_activated(self, threshold: float = 0.1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read contents of activated entries.

        Returns:
            (indices, content_vectors, activation_strengths); empty arrays
            with matching dtypes when nothing is above ``threshold``.
        """
        active_idx = self.get_active(threshold)
        if len(active_idx) == 0:
            return (np.array([], dtype=np.int64),
                    np.zeros((0, N_WORDS), dtype=np.uint64),
                    np.array([], dtype=np.float64))
        return (active_idx,
                self._contents[active_idx],
                self._activations[active_idx])

    def _grow(self, factor: float = 1.5):
        """Reallocate storage to a larger capacity, preserving live rows.

        New capacity is ``int(capacity * factor)`` but never less than
        ``capacity + 1``. For capacity 0 or 1, ``int(capacity * 1.5)``
        equals ``capacity``, which used to make the write in ``store`` go
        out of bounds.
        """
        new_cap = max(int(self.capacity * factor), self.capacity + 1)
        logger.info("Growing SparseAddressTable from %d to %d", self.capacity, new_cap)

        live = self.size

        def _copy_live(old: np.ndarray, shape) -> np.ndarray:
            # Fresh zeroed buffer of the same dtype with the live rows copied in.
            new = np.zeros(shape, dtype=old.dtype)
            new[:live] = old[:live]
            return new

        self._addresses = _copy_live(self._addresses, (new_cap, N_WORDS))
        self._contents = _copy_live(self._contents, (new_cap, N_WORDS))
        self._activations = _copy_live(self._activations, new_cap)
        self._timestamps = _copy_live(self._timestamps, new_cap)
        self._metadata.extend([None] * (new_cap - self.capacity))
        self.capacity = new_cap

    def stats(self) -> Dict[str, Any]:
        """Return memory statistics.

        ``memory_mb`` counts the address + content payload of live entries
        only, not the preallocated capacity.
        """
        mem_bytes = self.size * N_BYTES * 2
        return {
            'size': self.size,
            'capacity': self.capacity,
            'memory_mb': mem_bytes / (1024 * 1024),
            'lsh_tables': self.lsh.n_tables,
            'lsh_projections': self.lsh.n_projections,
            'active_entries': int((self._activations[:self.size] > 0.1).sum()),
            'named_symbols': len(self._symbol_table),
        }

    def __repr__(self):
        return (f"SparseAddressTable(size={self.size}, capacity={self.capacity}, "
                f"symbols={len(self._symbol_table)})")
|
|