""" MLE Memory Module: Sparse Address Table ======================================== Distributed memory indexed by 4096-bit binary vectors. Semantic proximity is encoded via Hamming distance. Features: - Bit-packed storage (512 bytes/vector) with cache-aligned layout - LSH index for sub-linear approximate nearest neighbor search - Multi-resolution indexing (coarse + fine search) - Metadata/payload attachment per entry """ import numpy as np from collections import defaultdict from typing import List, Tuple, Optional, Dict, Any import logging from ..utils.simd_ops import ( N_BITS, N_WORDS, N_BYTES, random_binary_vector, random_binary_vectors, hamming_distance, hamming_batch, hamming_topk, xor_vectors, popcount, majority_vote, hamming_similarity ) logger = logging.getLogger(__name__) class HammingLSH: """Locality-Sensitive Hashing for Hamming space. Uses random bit sampling as the LSH family: h_i(v) = v[bit_index_i] P(h(a) == h(b)) = 1 - hamming(a,b)/n Multiple hash tables with K-bit signatures for amplification. """ def __init__( self, n_bits: int = N_BITS, n_tables: int = 32, n_projections: int = 8, seed: int = 42 ): self.n_bits = n_bits self.n_tables = n_tables self.n_projections = n_projections rng = np.random.RandomState(seed) # Random bit indices for each table: which bits to sample self.bit_indices = [ rng.choice(n_bits, n_projections, replace=False) for _ in range(n_tables) ] # Hash tables: table_idx -> {hash_key -> list of vector indices} self.tables: List[Dict[bytes, List[int]]] = [ defaultdict(list) for _ in range(n_tables) ] self.n_indexed = 0 def _compute_hash(self, bits_unpacked: np.ndarray, table_idx: int) -> bytes: """Extract hash signature from unpacked bit array.""" sig = bits_unpacked[self.bit_indices[table_idx]] return np.packbits(sig).tobytes() def _unpack_vector(self, packed: np.ndarray) -> np.ndarray: """Unpack uint64 vector to bit array.""" return np.unpackbits(packed.view(np.uint8)) def add(self, packed_vector: np.ndarray, idx: int): """Add a single vector to all hash tables.""" bits = self._unpack_vector(packed_vector) for t in range(self.n_tables): h = self._compute_hash(bits, t) self.tables[t][h].append(idx) self.n_indexed += 1 def add_batch(self, packed_vectors: np.ndarray, start_idx: int = 0): """Add multiple vectors to all hash tables.""" for i in range(len(packed_vectors)): self.add(packed_vectors[i], start_idx + i) def query_candidates(self, packed_query: np.ndarray, max_candidates: int = 2000) -> np.ndarray: """Find candidate indices via LSH (before exact reranking). Returns deduplicated candidate indices. """ bits = self._unpack_vector(packed_query) candidates = set() for t in range(self.n_tables): h = self._compute_hash(bits, t) bucket = self.tables[t].get(h, []) candidates.update(bucket) if len(candidates) >= max_candidates: break return np.array(list(candidates)[:max_candidates], dtype=np.int64) def query_multi_probe(self, packed_query: np.ndarray, n_probes: int = 3, max_candidates: int = 2000) -> np.ndarray: """Multi-probe LSH: also check neighboring buckets by flipping bits. Increases recall at cost of more bucket lookups. For short signatures (n_projections <= 12), we can flip multiple bits combinatorially. """ bits = self._unpack_vector(packed_query) candidates = set() for t in range(self.n_tables): # Original bucket h = self._compute_hash(bits, t) candidates.update(self.tables[t].get(h, [])) # Probe neighboring buckets: flip each single projection bit probe_bits = bits.copy() n_probe_bits = min(n_probes, self.n_projections) for probe in range(n_probe_bits): bit_pos = self.bit_indices[t][probe] probe_bits[bit_pos] ^= 1 h2 = self._compute_hash(probe_bits, t) candidates.update(self.tables[t].get(h2, [])) probe_bits[bit_pos] ^= 1 # restore # Also probe 2-bit flips for the first few bits if n_probes >= 2 and self.n_projections >= 2: for i in range(min(n_probes, self.n_projections)): for j in range(i + 1, min(n_probes, self.n_projections)): probe_bits = bits.copy() probe_bits[self.bit_indices[t][i]] ^= 1 probe_bits[self.bit_indices[t][j]] ^= 1 h3 = self._compute_hash(probe_bits, t) candidates.update(self.tables[t].get(h3, [])) if len(candidates) >= max_candidates: break return np.array(list(candidates)[:max_candidates], dtype=np.int64) class MemoryEntry: """A single entry in the Sparse Address Table.""" __slots__ = ['address', 'content', 'metadata', 'activation', 'timestamp'] def __init__(self, address: np.ndarray, content: np.ndarray, metadata: Optional[Dict[str, Any]] = None): self.address = address # (N_WORDS,) uint64 - the index key self.content = content # (N_WORDS,) uint64 - stored data self.metadata = metadata or {} # arbitrary metadata self.activation = 0.0 # current activation level self.timestamp = 0 # last access time class SparseAddressTable: """ Distributed memory indexed by 4096-bit binary vectors. Architecture: - Primary storage: contiguous (N, N_WORDS) uint64 matrix for SIMD batch ops - LSH index: multi-table bit-sampling for sub-linear ANN search - Content storage: separate matrix (decoupled address/content) - Activation tracking: for energy-based dynamics Memory layout is Structure of Arrays (SoA) for cache locality during batch Hamming distance computation. """ def __init__( self, capacity: int = 100_000, lsh_tables: int = 32, lsh_projections: int = 8, lsh_seed: int = 42 ): self.capacity = capacity self.size = 0 # SoA layout: addresses and contents as contiguous matrices self._addresses = np.zeros((capacity, N_WORDS), dtype=np.uint64) self._contents = np.zeros((capacity, N_WORDS), dtype=np.uint64) # Metadata and activation stored separately self._metadata: List[Dict[str, Any]] = [None] * capacity self._activations = np.zeros(capacity, dtype=np.float64) self._timestamps = np.zeros(capacity, dtype=np.int64) # LSH index — use short signatures (8-bit) with many tables (32) # for high recall on 4096-bit vectors self.lsh = HammingLSH( n_bits=N_BITS, n_tables=lsh_tables, n_projections=lsh_projections, seed=lsh_seed ) # Global step counter for timestamps self._step = 0 # Symbol table: name -> index mapping for named concepts self._symbol_table: Dict[str, int] = {} @property def addresses(self) -> np.ndarray: """Active address vectors. Shape: (size, N_WORDS).""" return self._addresses[:self.size] @property def contents(self) -> np.ndarray: """Active content vectors. Shape: (size, N_WORDS).""" return self._contents[:self.size] @property def activations(self) -> np.ndarray: """Active activation levels. Shape: (size,).""" return self._activations[:self.size] def store(self, address: np.ndarray, content: np.ndarray, metadata: Optional[Dict[str, Any]] = None, name: Optional[str] = None) -> int: """Store a new entry. Returns the entry index.""" if self.size >= self.capacity: self._grow() idx = self.size self._addresses[idx] = address self._contents[idx] = content self._metadata[idx] = metadata or {} self._timestamps[idx] = self._step self._step += 1 # Index in LSH self.lsh.add(address, idx) if name: self._symbol_table[name] = idx self.size += 1 return idx def store_concept(self, name: str, content: Optional[np.ndarray] = None, metadata: Optional[Dict[str, Any]] = None) -> int: """Store a named concept with auto-generated address.""" address = random_binary_vector() if content is None: content = random_binary_vector() meta = metadata or {} meta['name'] = name return self.store(address, content, metadata=meta, name=name) def get_by_name(self, name: str) -> Optional[Tuple[np.ndarray, np.ndarray, Dict]]: """Retrieve entry by symbolic name.""" idx = self._symbol_table.get(name) if idx is None: return None return (self._addresses[idx].copy(), self._contents[idx].copy(), self._metadata[idx]) def get_address_by_name(self, name: str) -> Optional[np.ndarray]: """Get the address vector for a named concept.""" idx = self._symbol_table.get(name) if idx is None: return None return self._addresses[idx].copy() def get_content_by_name(self, name: str) -> Optional[np.ndarray]: """Get the content vector for a named concept.""" idx = self._symbol_table.get(name) if idx is None: return None return self._contents[idx].copy() def query_nearest(self, query: np.ndarray, k: int = 10, use_lsh: bool = True) -> List[Tuple[int, int]]: """Find k nearest entries by Hamming distance to query address. Args: query: (N_WORDS,) uint64 query vector k: number of results use_lsh: if True, use LSH pre-filter; if False, exact scan Returns: List of (index, distance) tuples, sorted by distance ascending. """ if self.size == 0: return [] if use_lsh and self.size > 1000: # LSH pre-filter → exact rerank candidates = self.lsh.query_multi_probe(query, max_candidates=max(k * 10, 2000)) if len(candidates) == 0: # Fallback to exact candidates = np.arange(self.size, dtype=np.int64) candidate_vecs = np.ascontiguousarray(self._addresses[candidates]) dists = hamming_batch(query, candidate_vecs) if k < len(candidates): top_local = np.argpartition(dists, k)[:k] else: top_local = np.arange(len(candidates)) order = np.argsort(dists[top_local]) sorted_local = top_local[order] return [(int(candidates[i]), int(dists[i])) for i in sorted_local] else: # Exact search indices, distances = hamming_topk(query, self.addresses, k=k) return [(int(idx), int(dist)) for idx, dist in zip(indices, distances)] def query_radius(self, query: np.ndarray, radius: int) -> List[Tuple[int, int]]: """Find all entries within Hamming radius of query.""" if self.size == 0: return [] dists = hamming_batch(query, self.addresses) mask = dists <= radius indices = np.where(mask)[0] return [(int(i), int(dists[i])) for i in indices] def activate(self, indices: np.ndarray, strengths: np.ndarray): """Set activation levels for specified entries.""" self._activations[indices] = strengths def decay_activations(self, factor: float = 0.95): """Exponential decay of all activations.""" self._activations[:self.size] *= factor def get_active(self, threshold: float = 0.1) -> np.ndarray: """Get indices of entries with activation above threshold.""" return np.where(self._activations[:self.size] > threshold)[0] def read_activated(self, threshold: float = 0.1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Read contents of activated entries. Returns: (indices, content_vectors, activation_strengths) """ active_idx = self.get_active(threshold) if len(active_idx) == 0: return (np.array([], dtype=np.int64), np.zeros((0, N_WORDS), dtype=np.uint64), np.array([], dtype=np.float64)) return (active_idx, self._contents[active_idx], self._activations[active_idx]) def _grow(self, factor: float = 1.5): """Grow internal storage when capacity is exceeded.""" new_cap = int(self.capacity * factor) logger.info(f"Growing SparseAddressTable from {self.capacity} to {new_cap}") new_addr = np.zeros((new_cap, N_WORDS), dtype=np.uint64) new_cont = np.zeros((new_cap, N_WORDS), dtype=np.uint64) new_act = np.zeros(new_cap, dtype=np.float64) new_ts = np.zeros(new_cap, dtype=np.int64) new_addr[:self.size] = self._addresses[:self.size] new_cont[:self.size] = self._contents[:self.size] new_act[:self.size] = self._activations[:self.size] new_ts[:self.size] = self._timestamps[:self.size] self._addresses = new_addr self._contents = new_cont self._activations = new_act self._timestamps = new_ts self._metadata.extend([None] * (new_cap - self.capacity)) self.capacity = new_cap def stats(self) -> Dict[str, Any]: """Return memory statistics.""" mem_bytes = self.size * N_BYTES * 2 # addresses + contents return { 'size': self.size, 'capacity': self.capacity, 'memory_mb': mem_bytes / (1024 * 1024), 'lsh_tables': self.lsh.n_tables, 'lsh_projections': self.lsh.n_projections, 'active_entries': int((self._activations[:self.size] > 0.1).sum()), 'named_symbols': len(self._symbol_table), } def __repr__(self): return (f"SparseAddressTable(size={self.size}, capacity={self.capacity}, " f"symbols={len(self._symbol_table)})")