# MLE-Morpho-Logic-Engine / mle/memory/sparse_address_table.py
# Harry00's picture
# feat: complete MLE engine implementation
# ebaf2ce verified
"""
MLE Memory Module: Sparse Address Table
========================================
Distributed memory indexed by 4096-bit binary vectors.
Semantic proximity is encoded via Hamming distance.
Features:
- Bit-packed storage (512 bytes/vector) in contiguous row-major matrices
- LSH index for sub-linear approximate nearest neighbor search
- Two-stage search: coarse LSH candidate filtering, then exact Hamming reranking
- Metadata/payload attachment per entry
"""
import numpy as np
from collections import defaultdict
from typing import List, Tuple, Optional, Dict, Any
import logging
from ..utils.simd_ops import (
N_BITS, N_WORDS, N_BYTES,
random_binary_vector, random_binary_vectors,
hamming_distance, hamming_batch, hamming_topk,
xor_vectors, popcount, majority_vote, hamming_similarity
)
logger = logging.getLogger(__name__)
class HammingLSH:
    """Locality-Sensitive Hashing for Hamming space.

    Uses random bit sampling as the LSH family:
        h_i(v) = v[bit_index_i]
        P(h_i(a) == h_i(b)) = 1 - hamming(a,b)/n
    Each table concatenates ``n_projections`` distinct sampled bits into one
    bucket key, so two vectors share a bucket in a given table with
    probability roughly (1 - hamming(a,b)/n) ** n_projections. Querying
    ``n_tables`` independent tables amplifies recall.

    Buckets store the integer indices passed to ``add``/``add_batch``, not
    the vectors themselves; callers rerank candidates exactly.
    """

    def __init__(
        self,
        n_bits: int = N_BITS,
        n_tables: int = 32,
        n_projections: int = 8,
        seed: int = 42
    ):
        self.n_bits = n_bits
        self.n_tables = n_tables
        self.n_projections = n_projections
        rng = np.random.RandomState(seed)
        # Bit positions sampled by each table. Drawn without replacement, so
        # positions are distinct within a table (multi-probe relies on this).
        self.bit_indices = [
            rng.choice(n_bits, n_projections, replace=False)
            for _ in range(n_tables)
        ]
        # Hash tables: table_idx -> {hash_key -> list of vector indices}
        self.tables: List[Dict[bytes, List[int]]] = [
            defaultdict(list) for _ in range(n_tables)
        ]
        # Count of add() calls (re-adding an index counts again).
        self.n_indexed = 0

    @staticmethod
    def _pack_signature(sig: np.ndarray) -> bytes:
        """Pack a 0/1 signature into a hashable bytes key (zero-padded to whole bytes)."""
        return np.packbits(sig).tobytes()

    def _compute_hash(self, bits_unpacked: np.ndarray, table_idx: int) -> bytes:
        """Return table ``table_idx``'s bucket key for an unpacked bit array."""
        return self._pack_signature(bits_unpacked[self.bit_indices[table_idx]])

    def _unpack_vector(self, packed: np.ndarray) -> np.ndarray:
        """Unpack a packed vector's raw bytes into a flat 0/1 uint8 array.

        ``view(np.uint8)`` requires a contiguous buffer; strided inputs (for
        example a column slice) are copied first instead of raising
        ValueError. Already-contiguous inputs are not copied.
        """
        return np.unpackbits(np.ascontiguousarray(packed).view(np.uint8))

    def add(self, packed_vector: np.ndarray, idx: int):
        """Insert index ``idx`` into every table under ``packed_vector``'s keys."""
        bits = self._unpack_vector(packed_vector)
        for t, table in enumerate(self.tables):
            table[self._compute_hash(bits, t)].append(idx)
        self.n_indexed += 1

    def add_batch(self, packed_vectors: np.ndarray, start_idx: int = 0):
        """Add each row of ``packed_vectors`` with consecutive indices from ``start_idx``."""
        for offset, vec in enumerate(packed_vectors):
            self.add(vec, start_idx + offset)

    def query_candidates(self, packed_query: np.ndarray, max_candidates: int = 2000) -> np.ndarray:
        """Find candidate indices via LSH (before exact reranking).

        Unions the query's exact bucket from each table, stopping early once
        ``max_candidates`` distinct indices are collected.

        Returns:
            Deduplicated int64 index array of length <= ``max_candidates``.
            When truncated, which candidates are kept follows set iteration
            order, not distance.
        """
        bits = self._unpack_vector(packed_query)
        candidates = set()
        for t, table in enumerate(self.tables):
            candidates.update(table.get(self._compute_hash(bits, t), []))
            if len(candidates) >= max_candidates:
                break
        return np.array(list(candidates)[:max_candidates], dtype=np.int64)

    def query_multi_probe(self, packed_query: np.ndarray, n_probes: int = 3,
                          max_candidates: int = 2000) -> np.ndarray:
        """Multi-probe LSH: also check neighboring buckets by flipping bits.

        For each table, probes the query's own bucket, then the buckets
        reached by flipping each one of the first ``min(n_probes,
        n_projections)`` signature bits, then every pair of those bits.
        Increases recall at the cost of more bucket lookups.

        Returns:
            Deduplicated int64 index array of length <= ``max_candidates``
            (the limit is checked after each whole table is probed).
        """
        bits = self._unpack_vector(packed_query)
        candidates = set()
        n_flip = min(n_probes, self.n_projections)
        for t, table in enumerate(self.tables):
            # Sample this table's signature once. Because bit_indices are
            # distinct within a table, flipping signature position i is the
            # same as flipping source bit bit_indices[t][i] and re-sampling,
            # without copying the whole unpacked vector per probe.
            sig = bits[self.bit_indices[t]]
            candidates.update(table.get(self._pack_signature(sig), []))

            # 1-bit neighbours
            for i in range(n_flip):
                probe = sig.copy()
                probe[i] ^= 1
                candidates.update(table.get(self._pack_signature(probe), []))

            # 2-bit neighbours (no pairs exist when n_flip < 2)
            for i in range(n_flip):
                for j in range(i + 1, n_flip):
                    probe = sig.copy()
                    probe[i] ^= 1
                    probe[j] ^= 1
                    candidates.update(table.get(self._pack_signature(probe), []))

            if len(candidates) >= max_candidates:
                break
        return np.array(list(candidates)[:max_candidates], dtype=np.int64)
class MemoryEntry:
    """Record type for one Sparse Address Table slot.

    Holds an address key, the content it maps to, a metadata dict, and two
    pieces of dynamic state (activation level and timestamp). Uses
    ``__slots__`` so instances carry no per-instance ``__dict__``.
    """

    __slots__ = ['address', 'content', 'metadata', 'activation', 'timestamp']

    def __init__(self, address: np.ndarray, content: np.ndarray,
                 metadata: Optional[Dict[str, Any]] = None):
        # Key and payload are kept by reference (no copy is made).
        # Presumably (N_WORDS,) uint64 vectors; not checked here.
        self.address, self.content = address, content
        # Any falsy metadata (None or an empty dict) becomes a fresh dict.
        self.metadata = {} if not metadata else metadata
        # Dynamic state starts at rest.
        self.activation = 0.0
        self.timestamp = 0
class SparseAddressTable:
    """
    Distributed memory indexed by 4096-bit binary vectors.

    Architecture:
    - Primary storage: contiguous (capacity, N_WORDS) uint64 matrix for SIMD batch ops
    - LSH index: multi-table bit-sampling for sub-linear ANN search
    - Content storage: separate matrix (decoupled address/content)
    - Activation tracking: per-entry float activation with exponential decay

    Memory layout is Structure of Arrays (SoA) for cache locality
    during batch Hamming distance computation. Only the first ``size`` rows
    are live; storage is reallocated by ``_grow`` when ``store`` is called
    at full capacity.
    """

    def __init__(
        self,
        capacity: int = 100_000,
        lsh_tables: int = 32,
        lsh_projections: int = 8,
        lsh_seed: int = 42
    ):
        self.capacity = capacity
        self.size = 0
        # SoA layout: addresses and contents as contiguous matrices
        self._addresses = np.zeros((capacity, N_WORDS), dtype=np.uint64)
        self._contents = np.zeros((capacity, N_WORDS), dtype=np.uint64)
        # Metadata and activation stored separately; unused slots hold None.
        self._metadata: List[Optional[Dict[str, Any]]] = [None] * capacity
        self._activations = np.zeros(capacity, dtype=np.float64)
        self._timestamps = np.zeros(capacity, dtype=np.int64)
        # LSH index. Defaults use short signatures (8 bits) across many
        # tables (32) for high recall on 4096-bit vectors.
        self.lsh = HammingLSH(
            n_bits=N_BITS,
            n_tables=lsh_tables,
            n_projections=lsh_projections,
            seed=lsh_seed
        )
        # Global step counter for timestamps
        self._step = 0
        # Symbol table: name -> index mapping for named concepts
        self._symbol_table: Dict[str, int] = {}

    @property
    def addresses(self) -> np.ndarray:
        """Active address vectors (a view, not a copy). Shape: (size, N_WORDS)."""
        return self._addresses[:self.size]

    @property
    def contents(self) -> np.ndarray:
        """Active content vectors (a view, not a copy). Shape: (size, N_WORDS)."""
        return self._contents[:self.size]

    @property
    def activations(self) -> np.ndarray:
        """Active activation levels (a view, not a copy). Shape: (size,)."""
        return self._activations[:self.size]

    def store(self, address: np.ndarray, content: np.ndarray,
              metadata: Optional[Dict[str, Any]] = None,
              name: Optional[str] = None) -> int:
        """Append a new entry and index its address in the LSH.

        Args:
            address: (N_WORDS,) key vector; copied into uint64 storage.
            content: (N_WORDS,) payload vector; copied into uint64 storage.
            metadata: stored by reference; a falsy value is replaced with {}.
            name: if truthy, registered in the symbol table. A name that is
                already registered is silently re-pointed to the new entry.

        Returns:
            The new entry's index.
        """
        if self.size >= self.capacity:
            self._grow()
        idx = self.size
        self._addresses[idx] = address
        self._contents[idx] = content
        self._metadata[idx] = metadata or {}
        # A fresh entry starts inactive even if activate() previously wrote
        # into this (then unused) slot.
        self._activations[idx] = 0.0
        self._timestamps[idx] = self._step
        self._step += 1
        # Index the stored uint64 row rather than the raw argument, so the LSH
        # hashes exactly the bytes query_nearest reranks against (and lists or
        # strided views are accepted).
        self.lsh.add(self._addresses[idx], idx)
        if name:
            self._symbol_table[name] = idx
        self.size += 1
        return idx

    def store_concept(self, name: str, content: Optional[np.ndarray] = None,
                      metadata: Optional[Dict[str, Any]] = None) -> int:
        """Store a named concept with a random address.

        A random content vector is generated when ``content`` is None. The
        stored metadata is a copy of ``metadata`` with ``'name'`` set; the
        caller's dict is not modified.
        """
        address = random_binary_vector()
        if content is None:
            content = random_binary_vector()
        # Copy so that adding 'name' does not mutate the caller's dict.
        meta = dict(metadata) if metadata else {}
        meta['name'] = name
        return self.store(address, content, metadata=meta, name=name)

    def get_by_name(self, name: str) -> Optional[Tuple[np.ndarray, np.ndarray, Dict]]:
        """Return (address copy, content copy, metadata) for ``name``, or None.

        The metadata dict is returned by reference, not copied.
        """
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return (self._addresses[idx].copy(),
                self._contents[idx].copy(),
                self._metadata[idx])

    def get_address_by_name(self, name: str) -> Optional[np.ndarray]:
        """Return a copy of the address vector for ``name``, or None if unknown."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._addresses[idx].copy()

    def get_content_by_name(self, name: str) -> Optional[np.ndarray]:
        """Return a copy of the content vector for ``name``, or None if unknown."""
        idx = self._symbol_table.get(name)
        if idx is None:
            return None
        return self._contents[idx].copy()

    def query_nearest(self, query: np.ndarray, k: int = 10,
                      use_lsh: bool = True) -> List[Tuple[int, int]]:
        """Find the k nearest entries by Hamming distance to the query address.

        Args:
            query: (N_WORDS,) uint64 query vector
            k: number of results; clamped to the number of stored entries,
                and k <= 0 yields an empty list.
            use_lsh: if True and the table holds more than 1000 entries,
                rerank LSH candidates exactly; if False, always do an exact
                scan. If LSH yields fewer than k candidates, an exact scan is
                used so that k results are still returned.

        Returns:
            List of (index, distance) tuples, sorted by distance ascending.
        """
        # Never ask for more results than there are entries.
        k = min(k, self.size)
        if k <= 0:
            return []

        if use_lsh and self.size > 1000:
            # LSH pre-filter -> exact rerank
            candidates = self.lsh.query_multi_probe(query, max_candidates=max(k * 10, 2000))
            # With fewer than k candidates the rerank could not return k
            # results; fall through to the exact scan below instead.
            if len(candidates) >= k:
                candidate_vecs = np.ascontiguousarray(self._addresses[candidates])
                dists = hamming_batch(query, candidate_vecs)
                if k < len(candidates):
                    # O(n) selection of the k smallest, then sort just those.
                    top_local = np.argpartition(dists, k)[:k]
                else:
                    top_local = np.arange(len(candidates))
                sorted_local = top_local[np.argsort(dists[top_local])]
                return [(int(candidates[i]), int(dists[i])) for i in sorted_local]

        # Exact search
        indices, distances = hamming_topk(query, self.addresses, k=k)
        return [(int(idx), int(dist)) for idx, dist in zip(indices, distances)]

    def query_radius(self, query: np.ndarray, radius: int) -> List[Tuple[int, int]]:
        """Exact scan for all entries with Hamming distance <= ``radius``.

        Returns:
            (index, distance) tuples in ascending index order.
        """
        if self.size == 0:
            return []
        dists = hamming_batch(query, self.addresses)
        indices = np.where(dists <= radius)[0]
        return [(int(i), int(dists[i])) for i in indices]

    def activate(self, indices: np.ndarray, strengths: np.ndarray):
        """Set activation levels for the given entries.

        ``strengths`` follows numpy assignment broadcasting (a scalar or an
        array matching ``indices``). Indices are not checked against ``size``.
        """
        self._activations[indices] = strengths

    def decay_activations(self, factor: float = 0.95):
        """Multiply all live activations by ``factor`` (exponential decay)."""
        self._activations[:self.size] *= factor

    def get_active(self, threshold: float = 0.1) -> np.ndarray:
        """Return indices of live entries whose activation is strictly above ``threshold``."""
        return np.where(self._activations[:self.size] > threshold)[0]

    def read_activated(self, threshold: float = 0.1) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Read contents of activated entries.

        Returns:
            (indices, content_vectors, activation_strengths); the arrays are
            copies produced by fancy indexing, and are empty (with shapes
            (0,), (0, N_WORDS), (0,)) when nothing is active.
        """
        active_idx = self.get_active(threshold)
        if len(active_idx) == 0:
            return (np.array([], dtype=np.int64),
                    np.zeros((0, N_WORDS), dtype=np.uint64),
                    np.array([], dtype=np.float64))
        return (active_idx,
                self._contents[active_idx],
                self._activations[active_idx])

    def _grow(self, factor: float = 1.5):
        """Reallocate storage to a larger capacity, preserving live entries.

        The new capacity is ``int(capacity * factor)`` but at least
        ``capacity + 1``. Without that floor, capacities 0 and 1 (or a
        factor <= 1) would not grow, and the following ``store`` would index
        past the end of the arrays.
        """
        new_cap = max(int(self.capacity * factor), self.capacity + 1)
        logger.info("Growing SparseAddressTable from %d to %d", self.capacity, new_cap)
        new_addr = np.zeros((new_cap, N_WORDS), dtype=np.uint64)
        new_cont = np.zeros((new_cap, N_WORDS), dtype=np.uint64)
        new_act = np.zeros(new_cap, dtype=np.float64)
        new_ts = np.zeros(new_cap, dtype=np.int64)
        new_addr[:self.size] = self._addresses[:self.size]
        new_cont[:self.size] = self._contents[:self.size]
        new_act[:self.size] = self._activations[:self.size]
        new_ts[:self.size] = self._timestamps[:self.size]
        self._addresses = new_addr
        self._contents = new_cont
        self._activations = new_act
        self._timestamps = new_ts
        self._metadata.extend([None] * (new_cap - self.capacity))
        self.capacity = new_cap

    def stats(self) -> Dict[str, Any]:
        """Return memory statistics.

        ``memory_mb`` covers only the live address and content vectors, not
        the allocated capacity, metadata, or LSH tables. ``active_entries``
        uses the default activation threshold of 0.1.
        """
        mem_bytes = self.size * N_BYTES * 2  # addresses + contents
        return {
            'size': self.size,
            'capacity': self.capacity,
            'memory_mb': mem_bytes / (1024 * 1024),
            'lsh_tables': self.lsh.n_tables,
            'lsh_projections': self.lsh.n_projections,
            'active_entries': int(len(self.get_active(0.1))),
            'named_symbols': len(self._symbol_table),
        }

    def __repr__(self):
        return (f"SparseAddressTable(size={self.size}, capacity={self.capacity}, "
                f"symbols={len(self._symbol_table)})")