| """ |
| Routing optimisé via Hamming distance + bit-slicing SIMD |
| |
| Le routing est le cœur de la récupération rapide dans la mémoire distribuée. |
| On utilise : |
| - Bit-slicing : découpe les vecteurs 4096 bits en tranches de 64 bits |
| - Hamming distance via popcount par tranche (SIMD-friendly) |
| - Index inversé pour sauts rapides |
| """ |
|
|
| import numpy as np |
| from numba import njit, prange, uint64 |
| from typing import List, Tuple, Dict, Optional |
| import logging |
|
|
logger = logging.getLogger(__name__)

# Geometry of the binary vectors handled by this module.
VECTOR_SIZE = 4096  # bits per vector
SLICE_BITS = 64  # bits per packed word (one uint64)
NUM_SLICES = VECTOR_SIZE // SLICE_BITS  # packed words per vector

# Packing/unpacking work on whole uint64 words; a remainder here would make
# `//` silently drop the trailing bits of every vector, so fail at import.
if VECTOR_SIZE % SLICE_BITS != 0:
    raise ValueError("VECTOR_SIZE must be a multiple of SLICE_BITS")
|
|
|
|
@njit(cache=True)
def pack_bits_to_uint64(bits: np.ndarray) -> np.ndarray:
    """
    Pack a VECTOR_SIZE-bit vector (one element per bit) into NUM_SLICES
    uint64 words.

    Bit ``j`` of word ``i`` is set iff ``bits[i * SLICE_BITS + j]`` is
    non-zero (least-significant bit first within each word). Elements past
    ``VECTOR_SIZE`` are ignored.

    Args:
        bits: array indexed by a single position, with at least
            ``VECTOR_SIZE`` elements.

    Returns:
        ``(NUM_SLICES,)`` uint64 array.

    Raises:
        ValueError: if ``bits`` has fewer than ``VECTOR_SIZE`` elements.
            Numba does not bounds-check indexing by default, so without this
            check a short input would be read past its end.
    """
    if bits.shape[0] < VECTOR_SIZE:
        raise ValueError("pack_bits_to_uint64: expected at least VECTOR_SIZE bits")

    result = np.zeros(NUM_SLICES, dtype=np.uint64)
    one = np.uint64(1)
    for word_idx in range(NUM_SLICES):
        base = word_idx * SLICE_BITS
        word = np.uint64(0)
        for bit in range(SLICE_BITS):
            if bits[base + bit]:
                word |= one << np.uint64(bit)
        result[word_idx] = word
    return result
|
|
|
|
@njit(cache=True)
def unpack_uint64_to_bits(packed: np.ndarray) -> np.ndarray:
    """
    Unpack NUM_SLICES uint64 words into a VECTOR_SIZE-element uint8 bit vector.

    Inverse of ``pack_bits_to_uint64``: element ``i * SLICE_BITS + j`` of the
    result is bit ``j`` (least-significant first) of ``packed[i]``, as 0 or 1.

    Returns:
        ``(VECTOR_SIZE,)`` uint8 array.

    Raises:
        ValueError: if ``packed`` has fewer than ``NUM_SLICES`` words. Numba
            does not bounds-check indexing by default, so a short input would
            otherwise be read past its end.
    """
    if packed.shape[0] < NUM_SLICES:
        raise ValueError("unpack_uint64_to_bits: expected at least NUM_SLICES words")

    result = np.zeros(VECTOR_SIZE, dtype=np.uint8)
    one = np.uint64(1)
    for word_idx in range(NUM_SLICES):
        word = packed[word_idx]
        base = word_idx * SLICE_BITS
        for bit in range(SLICE_BITS):
            result[base + bit] = np.uint8((word >> np.uint64(bit)) & one)
    return result
|
|
|
|
@njit(cache=True)
def _popcount_u64(x):
    """Return the number of set bits of the uint64 ``x`` as int32 (SWAR popcount)."""
    # Sum bits pairwise, then in nibbles, then bytes; the shifted adds fold the
    # byte counts into the low byte, which can hold the maximum of 64 (0x40).
    x = x - ((x >> np.uint64(1)) & np.uint64(0x5555555555555555))
    x = (x & np.uint64(0x3333333333333333)) + ((x >> np.uint64(2)) & np.uint64(0x3333333333333333))
    x = (x + (x >> np.uint64(4))) & np.uint64(0x0F0F0F0F0F0F0F0F)
    x = x + (x >> np.uint64(8))
    x = x + (x >> np.uint64(16))
    x = x + (x >> np.uint64(32))
    return np.int32(x & np.uint64(0x7F))


@njit(parallel=True, cache=True)
def hamming_uint64_batch(query_packed: np.ndarray, table_packed: np.ndarray) -> np.ndarray:
    """
    Hamming distance between one packed vector and every row of a packed table.

    Each distance is the popcount of the XOR of corresponding uint64 words,
    summed over the words; rows are processed in parallel with ``prange``.

    Args:
        query_packed: ``(W,)`` uint64 (``W == NUM_SLICES`` for this module's
            vectors, but any word count works).
        table_packed: ``(N, W)`` uint64.

    Returns:
        distances: ``(N,)`` int32.

    Raises:
        ValueError: if the table width differs from the query length. Numba
            does not bounds-check indexing by default, so a mismatch would
            otherwise read out of bounds.
    """
    n_rows = table_packed.shape[0]
    n_words = query_packed.shape[0]
    if table_packed.shape[1] != n_words:
        raise ValueError("hamming_uint64_batch: query and table word counts differ")

    distances = np.empty(n_rows, dtype=np.int32)
    for row in prange(n_rows):
        dist = np.int32(0)
        for w in range(n_words):
            dist += _popcount_u64(query_packed[w] ^ table_packed[row, w])
        distances[row] = dist
    return distances
|
|
|
|
class BitSliceIndex:
    """
    Inverted index over the 64-bit slices of packed vectors.

    For each of the ``NUM_SLICES`` slice positions it maps every slice value
    seen so far to the ids of the vectors holding exactly that value at that
    position. Query candidates are the vectors sharing the most exact slice
    values with the query.
    """

    def __init__(self):
        # slice_maps[slice_idx][slice_value] -> ids of vectors with that value
        self.slice_maps: List[Dict[int, List[int]]] = [
            {} for _ in range(NUM_SLICES)
        ]

    def add_vector(self, vector_idx: int, packed: np.ndarray):
        """
        Register ``vector_idx`` under each of its slice values.

        ``packed`` must hold at least ``NUM_SLICES`` words. Adding the same id
        twice creates duplicate postings; remove the old entry first to
        replace a vector.
        """
        for slice_idx, slice_map in enumerate(self.slice_maps):
            slice_map.setdefault(int(packed[slice_idx]), []).append(vector_idx)

    def remove_vector(self, vector_idx: int, packed: np.ndarray):
        """
        Remove one posting of ``vector_idx`` per slice value of ``packed``.

        ``packed`` should be the representation the id was added with; slices
        where the id is not listed are skipped silently.
        """
        for slice_idx, slice_map in enumerate(self.slice_maps):
            slice_val = int(packed[slice_idx])
            postings = slice_map.get(slice_val)
            if postings is None:
                continue
            try:
                postings.remove(vector_idx)
            except ValueError:
                continue  # id not registered under this value
            if not postings:
                # Drop emptied lists so add/remove churn does not leave an
                # ever-growing set of empty keys behind.
                del slice_map[slice_val]

    def query_candidates(
        self,
        query_packed: np.ndarray,
        top_slices: int = 8,
        max_candidates: int = 100,
    ) -> List[Tuple[int, int]]:
        """
        Return the vectors that share the most exact slice values with the query.

        Among the query slices that match at least one stored vector, the
        ``top_slices`` with the shortest postings lists (the most selective)
        are scanned, and each vector found gets one vote per matching slice.

        Args:
            query_packed: packed query with at least ``NUM_SLICES`` words.
            top_slices: number of matching slices to scan; ``<= 0`` scans none.
            max_candidates: maximum number of results; ``<= 0`` returns none.

        Returns:
            List of ``(vector_idx, match_count)`` sorted by decreasing
            ``match_count`` (ties keep first-seen order).
        """
        postings_per_slice = []
        for slice_idx, slice_map in enumerate(self.slice_maps):
            postings = slice_map.get(int(query_packed[slice_idx]))
            # Skip slices with no stored match: they would sort as the "rarest"
            # and use up the top_slices budget without yielding any candidate.
            if postings:
                postings_per_slice.append(postings)

        # Most selective first; sort is stable, so ties keep slice order.
        postings_per_slice.sort(key=len)

        counts: Dict[int, int] = {}
        for postings in postings_per_slice[:max(top_slices, 0)]:
            for vec_idx in postings:
                counts[vec_idx] = counts.get(vec_idx, 0) + 1

        candidates = sorted(counts.items(), key=lambda item: -item[1])
        return candidates[:max(max_candidates, 0)]
|
|
|
|
class HammingRouter:
    """
    Route binary query vectors to their nearest stored vectors by Hamming distance.

    Responsibilities:
    - convert vectors to the packed uint64 format (``pack_bits_to_uint64``);
    - keep the inverted slice index (``BitSliceIndex``) in sync with storage;
    - answer k-nearest-neighbour queries over all vectors or a given subset;
    - remember recent answers in a bounded, least-recently-used route cache.

    The route cache is keyed on a 64-bit hash of the packed query alone, so it
    is only read and written for unrestricted queries
    (``candidate_indices is None``). A cached entry stores the ids of an
    earlier answer. Their distances are recomputed on every hit. Vectors added
    after the entry was stored are not considered until the entry is evicted
    or refreshed by a miss.
    """

    def __init__(
        self,
        use_index: bool = True,
        index_top_slices: int = 8,
        cache_size: int = 1000,
        learn_routes: bool = True,
        index_min_vectors: int = 100,
        index_max_candidates: int = 200,
    ):
        """
        Args:
            use_index: maintain the inverted slice index and use it for
                unrestricted queries.
            index_top_slices: number of matching query slices the index scans.
            cache_size: maximum number of cached routes; ``<= 0`` disables
                caching.
            learn_routes: enable the route cache.
            index_min_vectors: the index is consulted only when more than this
                many vectors are stored; smaller stores are scanned exhaustively.
            index_max_candidates: cap on candidates taken from the index
                (raised to ``k`` when ``k`` is larger).
        """
        self.use_index = use_index
        self.index_top_slices = index_top_slices
        self.learn_routes = learn_routes
        self.index_min_vectors = index_min_vectors
        self.index_max_candidates = index_max_candidates

        # vector id -> packed uint64 representation
        self.packed_vectors: Dict[int, np.ndarray] = {}
        self.index = BitSliceIndex()

        # query hash -> [(vector id, 1 / (1 + distance))]; dict order is LRU order
        self.route_cache: Dict[int, List[Tuple[int, float]]] = {}
        self.cache_size = cache_size
        self.cache_hits = 0
        self.cache_misses = 0

        self.query_count = 0
        self.avg_query_time = 0.0  # seconds

    def _store_packed(self, vector_idx: int, packed: np.ndarray):
        """Store ``packed`` under ``vector_idx``, replacing any previous index postings."""
        old_packed = self.packed_vectors.get(vector_idx)
        self.packed_vectors[vector_idx] = packed
        if self.use_index:
            if old_packed is not None:
                # Without this, re-adding an id leaves stale postings behind.
                self.index.remove_vector(vector_idx, old_packed)
            self.index.add_vector(vector_idx, packed)

    def add_vector(self, vector_idx: int, vector: np.ndarray):
        """Add ``vector`` under ``vector_idx``; an existing id is replaced."""
        self._store_packed(vector_idx, pack_bits_to_uint64(vector))

    def remove_vector(self, vector_idx: int, vector: np.ndarray):
        """
        Remove the vector stored under ``vector_idx`` (no-op if unknown).

        ``vector`` is not used: the stored packed form is what gets unindexed.
        """
        packed = self.packed_vectors.pop(vector_idx, None)
        if packed is not None and self.use_index:
            self.index.remove_vector(vector_idx, packed)

    def update_vector(self, vector_idx: int, new_vector: np.ndarray):
        """Replace (or insert) the vector stored under ``vector_idx``."""
        self._store_packed(vector_idx, pack_bits_to_uint64(new_vector))

    def _pattern_hash(self, packed: np.ndarray) -> int:
        """
        64-bit FNV-1a-style hash of a packed vector, applied per uint64 word.

        Used as the route-cache key; distinct queries can collide.
        """
        h = 0xcbf29ce484222325
        for i in range(NUM_SLICES):
            h ^= int(packed[i])
            h = (h * 0x100000001b3) & 0xFFFFFFFFFFFFFFFF
        return h

    def route(
        self,
        query: np.ndarray,
        candidate_indices: Optional[List[int]] = None,
        k: int = 5,
        use_cache: bool = True,
    ) -> List[Tuple[int, float]]:
        """
        Return the ``k`` stored vectors closest to ``query``.

        Args:
            query: bit vector accepted by ``pack_bits_to_uint64``.
            candidate_indices: ids to search. ``None`` searches index
                candidates (large stores) or every stored vector. Every id must
                be stored, otherwise ``KeyError`` is raised.
            k: number of results; ``k <= 0`` returns ``[]``.
            use_cache: allow this call to read and write the route cache.

        Returns:
            ``[(vector_idx, hamming_distance)]`` sorted by increasing distance
            (ties keep candidate order), at most ``k`` entries.
        """
        import time

        t0 = time.perf_counter()  # monotonic, unlike time.time()
        results = self._route(query, candidate_indices, k, use_cache)
        self._update_timing(time.perf_counter() - t0)
        return results

    def _route(self, query, candidate_indices, k, use_cache):
        """Untimed body of ``route``."""
        query_packed = pack_bits_to_uint64(query)
        if k <= 0:
            return []

        # The cache key ignores candidate_indices, so restricted queries must
        # bypass it; cache_size <= 0 means there is no room to cache anything.
        cacheable = (
            use_cache
            and self.learn_routes
            and candidate_indices is None
            and self.cache_size > 0
        )
        ph = 0
        if cacheable:
            ph = self._pattern_hash(query_packed)
            cached = self._cached_route(ph, query_packed, k)
            if cached is not None:
                self.cache_hits += 1
                return cached
            self.cache_misses += 1

        if candidate_indices is None:
            candidate_indices = self._default_candidates(query_packed, k)
        results = self._rank(query_packed, candidate_indices, k)

        if cacheable and results:
            self._remember(ph, results)
        return results

    def _default_candidates(self, query_packed: np.ndarray, k: int) -> List[int]:
        """Pick the ids to score for an unrestricted query: index hits or everything."""
        if self.use_index and len(self.packed_vectors) > self.index_min_vectors:
            hits = self.index.query_candidates(
                query_packed,
                top_slices=self.index_top_slices,
                max_candidates=max(self.index_max_candidates, k),
            )
            ids = [idx for idx, _ in hits if idx in self.packed_vectors]
            if ids:
                return ids
            # No query slice matched a stored slice exactly: scan everything
            # rather than report "no neighbours" for a populated router.
        return list(self.packed_vectors)

    def _rank(self, query_packed: np.ndarray, candidate_indices, k: int) -> List[Tuple[int, float]]:
        """Return the ``k`` candidates closest to ``query_packed`` as ``(id, distance)``."""
        ids = list(candidate_indices)
        if not ids:
            return []
        table = np.array([self.packed_vectors[idx] for idx in ids], dtype=np.uint64)
        distances = hamming_uint64_batch(query_packed, table)
        # Stable sort: ties keep candidate order, so results are reproducible.
        order = np.argsort(distances, kind="stable")[:k]
        return [(ids[i], float(distances[i])) for i in order]

    def _cached_route(self, ph: int, query_packed: np.ndarray, k: int) -> Optional[List[Tuple[int, float]]]:
        """
        Answer from the cache entry for ``ph``, or return ``None`` if it cannot.

        Ids removed since caching are skipped; with fewer than ``k`` left the
        entry cannot serve the query. Distances are recomputed so vectors
        updated since caching report their current distance.
        """
        entry = self.route_cache.get(ph)
        if entry is None:
            return None
        live = [idx for idx, _ in entry if idx in self.packed_vectors]
        if len(live) < k:
            return None
        # Re-insert so dict order tracks recency for LRU eviction.
        self.route_cache[ph] = self.route_cache.pop(ph)
        return self._rank(query_packed, live, k)

    def _remember(self, ph: int, results: List[Tuple[int, float]]):
        """Cache ``results`` under ``ph``, evicting least-recently-used entries when full."""
        self.route_cache.pop(ph, None)  # refresh replaces the old entry
        # cache_size > 0 is guaranteed by the caller, so this loop terminates.
        while len(self.route_cache) >= self.cache_size:
            del self.route_cache[next(iter(self.route_cache))]
        self.route_cache[ph] = [
            (idx, 1.0 / (1.0 + dist)) for idx, dist in results
        ]

    def _update_timing(self, elapsed: float):
        """Fold ``elapsed`` seconds into the running mean query time."""
        self.query_count += 1
        self.avg_query_time += (elapsed - self.avg_query_time) / self.query_count

    def get_stats(self) -> Dict:
        """Return query count, mean query time (ms), cache hit rate and store size."""
        total = self.cache_hits + self.cache_misses
        hit_rate = self.cache_hits / total if total > 0 else 0.0
        return {
            'query_count': self.query_count,
            'avg_query_time_ms': self.avg_query_time * 1000,
            'cache_hit_rate': hit_rate,
            'n_vectors': len(self.packed_vectors),
        }