morpho-logic-engine / mle /routing.py
Harry00's picture
Upload mle/routing.py
b4b14c2 verified
"""
Routing optimisé via Hamming distance + bit-slicing SIMD
Le routing est le cœur de la récupération rapide dans la mémoire distribuée.
On utilise :
- Bit-slicing : découpe les vecteurs 4096 bits en tranches de 64 bits
- Hamming distance via popcount par tranche (SIMD-friendly)
- Index inversé pour sauts rapides
"""
import numpy as np
from numba import njit, prange, uint64
from typing import List, Tuple, Dict, Optional
import logging
logger = logging.getLogger(__name__)
VECTOR_SIZE = 4096
SLICE_BITS = 64
NUM_SLICES = VECTOR_SIZE // SLICE_BITS
@njit(cache=True)
def pack_bits_to_uint64(bits: np.ndarray) -> np.ndarray:
"""
Pack un vecteur 4096 bits (uint8) en 64 uint64.
Chaque uint64 contient 64 bits du vecteur original.
"""
result = np.zeros(NUM_SLICES, dtype=np.uint64)
for i in range(NUM_SLICES):
val = np.uint64(0)
for j in range(SLICE_BITS):
if bits[i * SLICE_BITS + j]:
val |= np.uint64(1) << np.uint64(j)
result[i] = val
return result
@njit(cache=True)
def unpack_uint64_to_bits(packed: np.ndarray) -> np.ndarray:
"""Unpack 64 uint64 en un vecteur 4096 bits."""
result = np.zeros(VECTOR_SIZE, dtype=np.uint8)
for i in range(NUM_SLICES):
val = packed[i]
for j in range(SLICE_BITS):
result[i * SLICE_BITS + j] = np.uint8((val >> np.uint64(j)) & np.uint64(1))
return result
@njit(parallel=True, cache=True)
def hamming_uint64_batch(query_packed: np.ndarray, table_packed: np.ndarray) -> np.ndarray:
"""
Distance de Hamming entre un vecteur packé et N vecteurs packés.
Utilise XOR + popcount via bit-twiddling (très rapide).
Args:
query_packed: (NUM_SLICES,) uint64
table_packed: (N, NUM_SLICES) uint64
Returns:
distances: (N,) int32
"""
N = table_packed.shape[0]
distances = np.empty(N, dtype=np.int32)
for i in prange(N):
dist = np.int32(0)
for j in range(NUM_SLICES):
xor_val = query_packed[j] ^ table_packed[i, j]
# Popcount pour uint64
x = xor_val
x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555))
x = (x & np.uint64(0x3333333333333333)) + ((x >> np.uint64(2)) & np.uint64(0x3333333333333333))
x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F)
x = x + (x >> np.uint64(8))
x = x + (x >> np.uint64(16))
x = x + (x >> np.uint64(32))
dist += np.int32(x & np.uint64(0x7F))
distances[i] = dist
return distances
class BitSliceIndex:
"""
Index inversé par tranches de bits.
Pour chaque tranche (64 bits), maintient une table de hachage des
tranches observées vers les vecteurs qui les contiennent.
"""
def __init__(self):
# slice_idx -> {slice_hash -> [vector_indices]}
self.slice_maps: List[Dict[int, List[int]]] = [
{} for _ in range(NUM_SLICES)
]
def add_vector(self, vector_idx: int, packed: np.ndarray):
"""Ajoute un vecteur à l'index."""
for slice_idx in range(NUM_SLICES):
slice_val = int(packed[slice_idx])
if slice_val not in self.slice_maps[slice_idx]:
self.slice_maps[slice_idx][slice_val] = []
self.slice_maps[slice_idx][slice_val].append(vector_idx)
def remove_vector(self, vector_idx: int, packed: np.ndarray):
"""Retire un vecteur de l'index."""
for slice_idx in range(NUM_SLICES):
slice_val = int(packed[slice_idx])
if slice_val in self.slice_maps[slice_idx]:
lst = self.slice_maps[slice_idx][slice_val]
if vector_idx in lst:
lst.remove(vector_idx)
def query_candidates(
self,
query_packed: np.ndarray,
top_slices: int = 8,
max_candidates: int = 100
) -> List[Tuple[int, int]]:
"""
Retourne les candidats qui partagent le plus de tranches avec la requête.
Returns:
List of (vector_idx, match_count)
"""
counts: Dict[int, int] = {}
# Prend les top_slices tranches les plus discriminantes
# (celles qui ont le moins d'entrées dans l'index)
slice_counts = []
for slice_idx in range(NUM_SLICES):
slice_val = int(query_packed[slice_idx])
n_entries = len(self.slice_maps[slice_idx].get(slice_val, []))
slice_counts.append((slice_idx, n_entries))
# Trie : préfère les tranches rares (plus discriminantes)
slice_counts.sort(key=lambda x: x[1])
for slice_idx, _ in slice_counts[:top_slices]:
slice_val = int(query_packed[slice_idx])
for vec_idx in self.slice_maps[slice_idx].get(slice_val, []):
counts[vec_idx] = counts.get(vec_idx, 0) + 1
# Trie par nombre de matches
candidates = sorted(counts.items(), key=lambda x: -x[1])
return candidates[:max_candidates]
class HammingRouter:
"""
Routeur optimisé basé sur Hamming + bit-slicing.
Responsabilités :
- Convertir les vecteurs en format packé uint64
- Maintenir l'index inversé
- Router les requêtes vers les voisins les plus proches
- Apprendre les routes fréquentes
"""
def __init__(
self,
use_index: bool = True,
index_top_slices: int = 8,
cache_size: int = 1000,
learn_routes: bool = True,
):
self.use_index = use_index
self.index_top_slices = index_top_slices
self.learn_routes = learn_routes
# Stockage packé
self.packed_vectors: Dict[int, np.ndarray] = {} # vector_idx -> packed
self.index = BitSliceIndex()
# Cache de routes : pattern_hash -> [(target_idx, frequency)]
self.route_cache: Dict[int, List[Tuple[int, float]]] = {}
self.cache_size = cache_size
self.cache_hits = 0
self.cache_misses = 0
# Stats
self.query_count = 0
self.avg_query_time = 0.0
def add_vector(self, vector_idx: int, vector: np.ndarray):
"""Ajoute un vecteur au routeur."""
packed = pack_bits_to_uint64(vector)
self.packed_vectors[vector_idx] = packed
if self.use_index:
self.index.add_vector(vector_idx, packed)
def remove_vector(self, vector_idx: int, vector: np.ndarray):
"""Retire un vecteur du routeur."""
if vector_idx in self.packed_vectors:
packed = self.packed_vectors[vector_idx]
if self.use_index:
self.index.remove_vector(vector_idx, packed)
del self.packed_vectors[vector_idx]
def update_vector(self, vector_idx: int, new_vector: np.ndarray):
"""Met à jour un vecteur existant."""
old_packed = self.packed_vectors.get(vector_idx)
new_packed = pack_bits_to_uint64(new_vector)
self.packed_vectors[vector_idx] = new_packed
if self.use_index and old_packed is not None:
self.index.remove_vector(vector_idx, old_packed)
self.index.add_vector(vector_idx, new_packed)
def _pattern_hash(self, packed: np.ndarray) -> int:
"""Hachage rapide d'un vecteur packé pour le cache."""
# XOR fold des tranches avec mix simple (en Python int pour éviter overflow)
h = 0xcbf29ce484222325 # FNV offset basis
for i in range(NUM_SLICES):
h ^= int(packed[i])
h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF
return h
def route(
self,
query: np.ndarray,
candidate_indices: Optional[List[int]] = None,
k: int = 5,
use_cache: bool = True
) -> List[Tuple[int, float]]:
"""
Route une requête vers les k voisins les plus proches.
Args:
query: vecteur (4096,) uint8
candidate_indices: indices candidats (si None, utilise tous)
k: nombre de résultats
Returns:
[(vector_idx, distance)] trié par distance
"""
import time
t0 = time.time()
query_packed = pack_bits_to_uint64(query)
# Essaie le cache
if use_cache and self.learn_routes:
ph = self._pattern_hash(query_packed)
if ph in self.route_cache:
self.cache_hits += 1
# Filtre les entrées qui existent encore
results = [
(idx, 0.0) for idx, freq in self.route_cache[ph]
if idx in self.packed_vectors
]
if len(results) >= k:
t1 = time.time()
self._update_timing(t1 - t0)
return results[:k]
else:
self.cache_misses += 1
# Détermine les candidats
if candidate_indices is None:
if self.use_index and len(self.packed_vectors) > 100:
# Utilise l'index pour réduire les candidats
candidates = self.index.query_candidates(
query_packed,
top_slices=self.index_top_slices,
max_candidates=200
)
candidate_indices = [idx for idx, _ in candidates]
else:
candidate_indices = list(self.packed_vectors.keys())
if len(candidate_indices) == 0:
return []
# Calcule les distances de Hamming
candidates_packed = np.array([
self.packed_vectors[idx] for idx in candidate_indices
], dtype=np.uint64)
distances = hamming_uint64_batch(query_packed, candidates_packed)
# Trie
sorted_idx = np.argsort(distances)[:k]
results = [
(candidate_indices[si], float(distances[si]))
for si in sorted_idx
]
# Met à jour le cache
if self.learn_routes and use_cache:
ph = self._pattern_hash(query_packed)
if ph not in self.route_cache:
if len(self.route_cache) >= self.cache_size:
# Éviction aléatoire
keys = list(self.route_cache.keys())
to_evict = keys[np.random.randint(0, len(keys))]
del self.route_cache[to_evict]
# Stocke la route apprise
self.route_cache[ph] = [
(idx, 1.0 / (1.0 + dist)) for idx, dist in results
]
t1 = time.time()
self._update_timing(t1 - t0)
return results
def _update_timing(self, elapsed: float):
"""Met à jour les statistiques de temps."""
self.query_count += 1
self.avg_query_time = (
self.avg_query_time * (self.query_count - 1) + elapsed
) / self.query_count
def get_stats(self) -> Dict:
total = self.cache_hits + self.cache_misses
hit_rate = self.cache_hits / total if total > 0 else 0.0
return {
'query_count': self.query_count,
'avg_query_time_ms': self.avg_query_time * 1000,
'cache_hit_rate': hit_rate,
'n_vectors': len(self.packed_vectors),
}