""" Routing optimisé via Hamming distance + bit-slicing SIMD Le routing est le cœur de la récupération rapide dans la mémoire distribuée. On utilise : - Bit-slicing : découpe les vecteurs 4096 bits en tranches de 64 bits - Hamming distance via popcount par tranche (SIMD-friendly) - Index inversé pour sauts rapides """ import numpy as np from numba import njit, prange, uint64 from typing import List, Tuple, Dict, Optional import logging logger = logging.getLogger(__name__) VECTOR_SIZE = 4096 SLICE_BITS = 64 NUM_SLICES = VECTOR_SIZE // SLICE_BITS @njit(cache=True) def pack_bits_to_uint64(bits: np.ndarray) -> np.ndarray: """ Pack un vecteur 4096 bits (uint8) en 64 uint64. Chaque uint64 contient 64 bits du vecteur original. """ result = np.zeros(NUM_SLICES, dtype=np.uint64) for i in range(NUM_SLICES): val = np.uint64(0) for j in range(SLICE_BITS): if bits[i * SLICE_BITS + j]: val |= np.uint64(1) << np.uint64(j) result[i] = val return result @njit(cache=True) def unpack_uint64_to_bits(packed: np.ndarray) -> np.ndarray: """Unpack 64 uint64 en un vecteur 4096 bits.""" result = np.zeros(VECTOR_SIZE, dtype=np.uint8) for i in range(NUM_SLICES): val = packed[i] for j in range(SLICE_BITS): result[i * SLICE_BITS + j] = np.uint8((val >> np.uint64(j)) & np.uint64(1)) return result @njit(parallel=True, cache=True) def hamming_uint64_batch(query_packed: np.ndarray, table_packed: np.ndarray) -> np.ndarray: """ Distance de Hamming entre un vecteur packé et N vecteurs packés. Utilise XOR + popcount via bit-twiddling (très rapide). Args: query_packed: (NUM_SLICES,) uint64 table_packed: (N, NUM_SLICES) uint64 Returns: distances: (N,) int32 """ N = table_packed.shape[0] distances = np.empty(N, dtype=np.int32) for i in prange(N): dist = np.int32(0) for j in range(NUM_SLICES): xor_val = query_packed[j] ^ table_packed[i, j] # Popcount pour uint64 x = xor_val x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555)) x = (x & np.uint64(0x3333333333333333)) + ((x >> np.uint64(2)) & np.uint64(0x3333333333333333)) x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F) x = x + (x >> np.uint64(8)) x = x + (x >> np.uint64(16)) x = x + (x >> np.uint64(32)) dist += np.int32(x & np.uint64(0x7F)) distances[i] = dist return distances class BitSliceIndex: """ Index inversé par tranches de bits. Pour chaque tranche (64 bits), maintient une table de hachage des tranches observées vers les vecteurs qui les contiennent. """ def __init__(self): # slice_idx -> {slice_hash -> [vector_indices]} self.slice_maps: List[Dict[int, List[int]]] = [ {} for _ in range(NUM_SLICES) ] def add_vector(self, vector_idx: int, packed: np.ndarray): """Ajoute un vecteur à l'index.""" for slice_idx in range(NUM_SLICES): slice_val = int(packed[slice_idx]) if slice_val not in self.slice_maps[slice_idx]: self.slice_maps[slice_idx][slice_val] = [] self.slice_maps[slice_idx][slice_val].append(vector_idx) def remove_vector(self, vector_idx: int, packed: np.ndarray): """Retire un vecteur de l'index.""" for slice_idx in range(NUM_SLICES): slice_val = int(packed[slice_idx]) if slice_val in self.slice_maps[slice_idx]: lst = self.slice_maps[slice_idx][slice_val] if vector_idx in lst: lst.remove(vector_idx) def query_candidates( self, query_packed: np.ndarray, top_slices: int = 8, max_candidates: int = 100 ) -> List[Tuple[int, int]]: """ Retourne les candidats qui partagent le plus de tranches avec la requête. Returns: List of (vector_idx, match_count) """ counts: Dict[int, int] = {} # Prend les top_slices tranches les plus discriminantes # (celles qui ont le moins d'entrées dans l'index) slice_counts = [] for slice_idx in range(NUM_SLICES): slice_val = int(query_packed[slice_idx]) n_entries = len(self.slice_maps[slice_idx].get(slice_val, [])) slice_counts.append((slice_idx, n_entries)) # Trie : préfère les tranches rares (plus discriminantes) slice_counts.sort(key=lambda x: x[1]) for slice_idx, _ in slice_counts[:top_slices]: slice_val = int(query_packed[slice_idx]) for vec_idx in self.slice_maps[slice_idx].get(slice_val, []): counts[vec_idx] = counts.get(vec_idx, 0) + 1 # Trie par nombre de matches candidates = sorted(counts.items(), key=lambda x: -x[1]) return candidates[:max_candidates] class HammingRouter: """ Routeur optimisé basé sur Hamming + bit-slicing. Responsabilités : - Convertir les vecteurs en format packé uint64 - Maintenir l'index inversé - Router les requêtes vers les voisins les plus proches - Apprendre les routes fréquentes """ def __init__( self, use_index: bool = True, index_top_slices: int = 8, cache_size: int = 1000, learn_routes: bool = True, ): self.use_index = use_index self.index_top_slices = index_top_slices self.learn_routes = learn_routes # Stockage packé self.packed_vectors: Dict[int, np.ndarray] = {} # vector_idx -> packed self.index = BitSliceIndex() # Cache de routes : pattern_hash -> [(target_idx, frequency)] self.route_cache: Dict[int, List[Tuple[int, float]]] = {} self.cache_size = cache_size self.cache_hits = 0 self.cache_misses = 0 # Stats self.query_count = 0 self.avg_query_time = 0.0 def add_vector(self, vector_idx: int, vector: np.ndarray): """Ajoute un vecteur au routeur.""" packed = pack_bits_to_uint64(vector) self.packed_vectors[vector_idx] = packed if self.use_index: self.index.add_vector(vector_idx, packed) def remove_vector(self, vector_idx: int, vector: np.ndarray): """Retire un vecteur du routeur.""" if vector_idx in self.packed_vectors: packed = self.packed_vectors[vector_idx] if self.use_index: self.index.remove_vector(vector_idx, packed) del self.packed_vectors[vector_idx] def update_vector(self, vector_idx: int, new_vector: np.ndarray): """Met à jour un vecteur existant.""" old_packed = self.packed_vectors.get(vector_idx) new_packed = pack_bits_to_uint64(new_vector) self.packed_vectors[vector_idx] = new_packed if self.use_index and old_packed is not None: self.index.remove_vector(vector_idx, old_packed) self.index.add_vector(vector_idx, new_packed) def _pattern_hash(self, packed: np.ndarray) -> int: """Hachage rapide d'un vecteur packé pour le cache.""" # XOR fold des tranches avec mix simple (en Python int pour éviter overflow) h = 0xcbf29ce484222325 # FNV offset basis for i in range(NUM_SLICES): h ^= int(packed[i]) h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF return h def route( self, query: np.ndarray, candidate_indices: Optional[List[int]] = None, k: int = 5, use_cache: bool = True ) -> List[Tuple[int, float]]: """ Route une requête vers les k voisins les plus proches. Args: query: vecteur (4096,) uint8 candidate_indices: indices candidats (si None, utilise tous) k: nombre de résultats Returns: [(vector_idx, distance)] trié par distance """ import time t0 = time.time() query_packed = pack_bits_to_uint64(query) # Essaie le cache if use_cache and self.learn_routes: ph = self._pattern_hash(query_packed) if ph in self.route_cache: self.cache_hits += 1 # Filtre les entrées qui existent encore results = [ (idx, 0.0) for idx, freq in self.route_cache[ph] if idx in self.packed_vectors ] if len(results) >= k: t1 = time.time() self._update_timing(t1 - t0) return results[:k] else: self.cache_misses += 1 # Détermine les candidats if candidate_indices is None: if self.use_index and len(self.packed_vectors) > 100: # Utilise l'index pour réduire les candidats candidates = self.index.query_candidates( query_packed, top_slices=self.index_top_slices, max_candidates=200 ) candidate_indices = [idx for idx, _ in candidates] else: candidate_indices = list(self.packed_vectors.keys()) if len(candidate_indices) == 0: return [] # Calcule les distances de Hamming candidates_packed = np.array([ self.packed_vectors[idx] for idx in candidate_indices ], dtype=np.uint64) distances = hamming_uint64_batch(query_packed, candidates_packed) # Trie sorted_idx = np.argsort(distances)[:k] results = [ (candidate_indices[si], float(distances[si])) for si in sorted_idx ] # Met à jour le cache if self.learn_routes and use_cache: ph = self._pattern_hash(query_packed) if ph not in self.route_cache: if len(self.route_cache) >= self.cache_size: # Éviction aléatoire keys = list(self.route_cache.keys()) to_evict = keys[np.random.randint(0, len(keys))] del self.route_cache[to_evict] # Stocke la route apprise self.route_cache[ph] = [ (idx, 1.0 / (1.0 + dist)) for idx, dist in results ] t1 = time.time() self._update_timing(t1 - t0) return results def _update_timing(self, elapsed: float): """Met à jour les statistiques de temps.""" self.query_count += 1 self.avg_query_time = ( self.avg_query_time * (self.query_count - 1) + elapsed ) / self.query_count def get_stats(self) -> Dict: total = self.cache_hits + self.cache_misses hit_rate = self.cache_hits / total if total > 0 else 0.0 return { 'query_count': self.query_count, 'avg_query_time_ms': self.avg_query_time * 1000, 'cache_hit_rate': hit_rate, 'n_vectors': len(self.packed_vectors), }