# Page header captured together with the file ("Spaces: Running"); not part of the module.
import gc
import json
import math
import os
import subprocess
import warnings
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch

from .codebook import ScalarQuantizer
from .rotation import get_orthogonal_matrix, rotate_backward, rotate_forward
from .tq_bridge import tq_native
@dataclass
class ProdQuantized:
    """Quantized representation of a batch of vectors.

    The class is instantiated with keyword arguments throughout this module,
    so it must be a dataclass (a plain class with only annotations has no
    matching ``__init__``).  It is deliberately mutable: the engine replaces
    the array fields in place when dynamic shards are merged.
    """

    # Scalar-quantizer codes, one row per vector.
    sq_codes: np.ndarray
    # Sign bits of the quantization residual, one row per vector.
    qjl_signs: np.ndarray
    # L2 norm of each input vector.
    norms: np.ndarray
    # Centroids (codebook) of the scalar quantizer.
    centroids: np.ndarray
    # Vector dimensionality.
    dim: int
    # Bits used by the scalar-quantizer stage.
    sq_bits: int
    # Total bit budget per dimension (``TQEngine.bits``).
    total_bits: int
    # Scale factor stored alongside the sign bits.
    qjl_scale: float
    # Rotation operator applied before quantization.
    rot_op: np.ndarray
    # Per-vector residual norms produced by the quantizer.
    res_norms: np.ndarray
@dataclass
class IVFData:
    """Inverted-file index: coarse centroids plus cluster-ordered quantized data.

    Instantiated with keyword arguments by the engine, hence a dataclass.
    Kept mutable because the engine rewrites ``vector_ids`` and
    ``list_offsets`` when dynamic shards are merged.
    """

    # Coarse (cluster) centroids, one row per inverted list.
    coarse_centroids: torch.Tensor
    # Quantized vectors, stored grouped by inverted list.
    pq_data: "ProdQuantized"
    # Original vector id for each stored row.
    vector_ids: np.ndarray
    # Row range of list ``i`` is ``list_offsets[i]:list_offsets[i + 1]``.
    list_offsets: np.ndarray
    # Number of inverted lists.
    n_list: int
    # Default number of lists probed per query.
    n_probe: int
class TQEngine:
    """Vector quantization engine with an optional IVF (inverted-file) index.

    Vectors are rotated, scalar-quantized with ``bits - 1`` bits and
    complemented by one residual sign bit per dimension.
    """

    def __init__(
        self,
        dim: int = 768,
        bits: int = 4,
        device: str = None,
        use_ivf: bool = False,
        ivf_nlist: int = 1024,
        ivf_nprobe: int = 32,
    ):
        """Configure the engine and create its quantizer and rotation operator.

        Args:
            dim: Vector dimensionality.
            bits: Total bits per dimension; only 2 and 4 are accepted.
            device: Torch device name. When empty/None, "cuda" is used if
                available, otherwise "cpu".
            use_ivf: Build an IVF index in ``index()`` instead of a flat one.
            ivf_nlist: Number of coarse clusters for IVF.
            ivf_nprobe: Lists probed per query; ``search_batch`` switches to
                the automatic heuristic when this is <= 0.

        Raises:
            ValueError: If ``bits`` is neither 2 nor 4.
        """
        if bits not in (2, 4):
            raise ValueError(f"TurboQuant currently only supports 2-bit (1+1) and 4-bit (3+1) configurations. Received: {bits}")

        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Core configuration.
        self.dim = dim
        self.bits = bits
        self.sq_bits = bits - 1  # one bit is reserved for the residual sign
        self.device = device

        # IVF settings and the knobs read by _auto_nprobe().
        self.use_ivf = use_ivf
        self.ivf_nlist = ivf_nlist
        self.ivf_nprobe = int(ivf_nprobe)
        self.ivf_target_candidates = 20000
        self.ivf_nprobe_min = 1
        self.ivf_nprobe_max = None

        # Mutable index state.
        self.deleted_ids = set()
        self.dynamic_shards = {}
        self.current_ivf_data = None
        self.raw_vectors = None

        # Quantizer and rotation operator (torch tensor plus float32 NumPy copy).
        self.sq_quantizer = ScalarQuantizer(dim=dim, bits=self.sq_bits, device=self.device)
        self.rot_op_t = get_orthogonal_matrix(dim, device=self.device)
        self.rot_op_np = self.rot_op_t.cpu().numpy().astype(np.float32)
        self.qjl_scale = 1.0

        # Optional HNSW navigator over the coarse centroids.
        self.hnsw_navigator = None
        self.use_hnsw = False
| def _auto_nprobe(self) -> int: | |
| ivf = self.current_ivf_data | |
| if ivf is None or not isinstance(ivf, IVFData): | |
| return max(1, self.ivf_nprobe) | |
| offsets = np.asarray(ivf.list_offsets, dtype=np.int64) | |
| if offsets.ndim != 1 or offsets.size < 2: | |
| return max(1, self.ivf_nprobe) | |
| counts = offsets[1:] - offsets[:-1] | |
| if counts.size == 0: | |
| return max(1, self.ivf_nprobe) | |
| med = float(np.median(counts)) | |
| if med <= 0: | |
| med = float(np.mean(counts)) if float(np.mean(counts)) > 0 else 1.0 | |
| target = max(1, int(self.ivf_target_candidates)) | |
| nprobe = int(np.ceil(target / med)) | |
| nlist = int(ivf.n_list) if hasattr(ivf, "n_list") else int(counts.size) | |
| nprobe = max(self.ivf_nprobe_min, nprobe) | |
| if self.ivf_nprobe_max is not None: | |
| nprobe = min(int(self.ivf_nprobe_max), nprobe) | |
| nprobe = min(nlist, nprobe) | |
| return max(1, nprobe) | |
| def _train_kmeans(self, x_sample: torch.Tensor, n_list: int, iters: int = 10): | |
| N, D = x_sample.shape | |
| indices = torch.randperm(N)[:n_list] | |
| centroids = x_sample[indices].clone() | |
| for i in range(iters): | |
| assignments = torch.zeros(N, dtype=torch.long, device=self.device) | |
| chunk_size = 8192 | |
| for s in range(0, N, chunk_size): | |
| e = min(s + chunk_size, N) | |
| batch = x_sample[s:e] | |
| scores = torch.mm(batch, centroids.t()) | |
| assignments[s:e] = scores.argmax(dim=1) | |
| del scores, batch | |
| new_centroids = torch.zeros_like(centroids) | |
| counts = torch.zeros(n_list, 1, device=self.device) | |
| ones = torch.ones(N, 1, device=self.device) | |
| new_centroids.index_add_(0, assignments, x_sample) | |
| counts.index_add_(0, assignments, ones) | |
| counts = torch.clamp(counts, min=1.0) | |
| centroids = new_centroids / counts | |
| centroids = centroids / (torch.norm(centroids, dim=1, keepdim=True) + 1e-12) | |
| gc.collect() | |
| return centroids | |
| def bind_raw_data(self, npy_path: str): | |
| if os.path.exists(npy_path): | |
| self.raw_vectors = np.load(npy_path, mmap_mode='r') | |
| else: | |
| print(f"⚠️ Warning: Raw data file not found: {npy_path}") | |
| def index(self, x: Union[torch.Tensor, np.ndarray], online_clustering: bool = False, | |
| save_path: str = None, n_train: int = None, train_iters: int = 20, | |
| init_centroids: np.ndarray = None): | |
| if isinstance(x, torch.Tensor): | |
| x_np = x.detach().cpu().numpy().astype(np.float32, copy=False) | |
| else: | |
| x_np = np.asarray(x, dtype=np.float32) | |
| N, D = x_np.shape | |
| self.dim = D | |
| if self.use_ivf: | |
| if init_centroids is not None: | |
| print(f" Using pre-trained coarse centroids ({init_centroids.shape[0]} clusters)...") | |
| centroids_np = np.asarray(init_centroids, dtype=np.float32) | |
| n_sample = n_train if n_train is not None else 180000 | |
| else: | |
| if n_train is None: | |
| n_train = min(N, 256 * self.ivf_nlist) | |
| n_sample = min(N, n_train) | |
| indices = np.random.choice(N, n_sample, replace=False) | |
| x_sample_np = x_np[indices].copy() | |
| print(f" Training {self.ivf_nlist} coarse centroids on {n_sample:,} samples (iters={train_iters})...") | |
| centroids_np = tq_native.tq_kmeans_train(x_sample_np, self.ivf_nlist, iters=train_iters) | |
| del x_sample_np | |
| coarse_centroids = torch.from_numpy(centroids_np).to(self.device) | |
| print(f" Assigning {N:,} vectors to clusters...") | |
| assignments = tq_native.tq_assign_clusters(x_np, centroids_np) | |
| with torch.no_grad(): | |
| n_scale = n_sample | |
| scale_idx = np.random.choice(N, n_scale, replace=False) | |
| x_scale = torch.from_numpy(np.asarray(x_np[scale_idx], dtype=np.float32)).to(self.device) | |
| c_idx = torch.from_numpy(assignments[scale_idx]).to(self.device).long() | |
| c = coarse_centroids.index_select(0, c_idx) | |
| residual = x_scale - c | |
| residual_rot = rotate_forward(residual, self.rot_op_t) | |
| print(f" Calibrating Lloyd-Max Codebook for {self.sq_bits}-bit SQ...") | |
| self.sq_quantizer.fit(residual_rot, iterations=30) | |
| self.qjl_scale = float(torch.mean(torch.abs(residual_rot)).item()) | |
| del x_scale, c_idx, c, residual, residual_rot | |
| index_dir = save_path if save_path else "tq_index_temp" | |
| os.makedirs(index_dir, exist_ok=True) | |
| counts = np.bincount(assignments, minlength=self.ivf_nlist) | |
| offsets = np.zeros(self.ivf_nlist + 1, dtype=np.int64) | |
| offsets[1:] = np.cumsum(counts) | |
| sq_packed_dim = self.dim | |
| if self.sq_bits == 1: sq_packed_dim = self.dim // 8 | |
| elif self.sq_bits == 3: sq_packed_dim = self.dim // 2 | |
| qjl_packed_dim = self.dim // 8 | |
| from numpy.lib.format import open_memmap | |
| f_sq = open_memmap(os.path.join(index_dir, "sq_codes.npy"), mode='w+', dtype=np.uint8, shape=(N, sq_packed_dim)) | |
| f_signs = open_memmap(os.path.join(index_dir, "qjl_signs.npy"), mode='w+', dtype=np.uint8, shape=(N, qjl_packed_dim)) | |
| f_norms = open_memmap(os.path.join(index_dir, "norms.npy"), mode='w+', dtype=np.float32, shape=(N,)) | |
| f_res_norms = open_memmap(os.path.join(index_dir, "res_norms.npy"), mode='w+', dtype=np.float32, shape=(N,)) | |
| f_ids = open_memmap(os.path.join(index_dir, "vector_ids.npy"), mode='w+', dtype=np.int64, shape=(N,)) | |
| dummy_pq = self._quantize_flat(torch.zeros((1, self.dim), device=self.device)) | |
| print(f" Streaming Indexing: Rotating & Quantizing {N:,} vectors in batches...") | |
| cluster_counters = np.zeros(self.ivf_nlist, dtype=np.int64) | |
| batch_size = 100000 | |
| with torch.no_grad(): | |
| for s in range(0, N, batch_size): | |
| e = min(s + batch_size, N) | |
| batch_x = torch.from_numpy(x_np[s:e]).to(self.device) | |
| batch_c = torch.from_numpy(centroids_np[assignments[s:e]]).to(self.device) | |
| batch_res = batch_x - batch_c | |
| batch_rot_res = rotate_forward(batch_res, self.rot_op_t) | |
| batch_q_res = self._quantize_flat(batch_x, online_clustering=False, x_rot=batch_rot_res) | |
| target_positions = np.zeros(e - s, dtype=np.int64) | |
| batch_ass = assignments[s:e] | |
| for i, c_id in enumerate(batch_ass): | |
| target_positions[i] = offsets[c_id] + cluster_counters[c_id] | |
| cluster_counters[c_id] += 1 | |
| f_sq[target_positions] = batch_q_res.sq_codes | |
| f_signs[target_positions] = batch_q_res.qjl_signs | |
| f_norms[target_positions] = batch_q_res.norms | |
| f_res_norms[target_positions] = batch_q_res.res_norms | |
| f_ids[target_positions] = np.arange(s, e, dtype=np.int64) | |
| if s % 1000000 == 0: | |
| f_sq.flush(); f_signs.flush(); f_norms.flush(); f_res_norms.flush(); f_ids.flush() | |
| gc.collect() | |
| del batch_x, batch_c, batch_res, batch_rot_res, batch_q_res | |
| import faiss | |
| self.hnsw_navigator = faiss.IndexHNSWFlat(self.dim, 32, faiss.METRIC_INNER_PRODUCT) | |
| self.hnsw_navigator.add(centroids_np) | |
| print(f" Finalizing index files...") | |
| f_sq.flush(); f_signs.flush(); f_norms.flush(); f_res_norms.flush(); f_ids.flush() | |
| # Lưu các file metadata quan trọng | |
| np.save(os.path.join(index_dir, "list_offsets.npy"), offsets.astype(np.int32)) | |
| np.save(os.path.join(index_dir, "coarse_centroids.npy"), centroids_np) | |
| np.save(os.path.join(index_dir, "sq_centroids.npy"), dummy_pq.centroids) | |
| np.save(os.path.join(index_dir, "rot_op.npy"), self.rot_op_np) | |
| if hasattr(self, "hnsw_navigator") and self.hnsw_navigator is not None: | |
| faiss.write_index(self.hnsw_navigator, os.path.join(index_dir, "centroids.hnsw")) | |
| import json | |
| with open(os.path.join(index_dir, "metadata.json"), "w", encoding='utf-8') as f: | |
| json.dump({ | |
| "dim": int(self.dim), "bits": int(self.bits), "qjl_scale": float(self.qjl_scale), | |
| "n_list": int(self.ivf_nlist), "n_probe": int(self.ivf_nprobe), "deleted_ids": [] | |
| }, f, indent=2) | |
| flat_pq = ProdQuantized( | |
| sq_codes=f_sq, qjl_signs=f_signs, norms=f_norms, | |
| centroids=dummy_pq.centroids, dim=self.dim, sq_bits=self.sq_bits, total_bits=self.bits, | |
| qjl_scale=self.qjl_scale, rot_op=self.rot_op_np, res_norms=f_res_norms | |
| ) | |
| self.current_ivf_data = IVFData( | |
| coarse_centroids=coarse_centroids, pq_data=flat_pq, vector_ids=f_ids, | |
| list_offsets=offsets.astype(np.int32), n_list=self.ivf_nlist, n_probe=self.ivf_nprobe | |
| ) | |
| return self.current_ivf_data | |
| else: | |
| return self._quantize_flat(torch.from_numpy(np.array(x, dtype=np.float32)).to(self.device), online_clustering) | |
| def add(self, vector: torch.Tensor, vector_id: int): | |
| if self.current_ivf_data is None: | |
| raise ValueError("Cần gọi index() hoặc load_index() trước khi add().") | |
| if vector.device.type != self.device: | |
| vector = vector.to(self.device) | |
| if vector.dim() == 1: | |
| vector = vector.unsqueeze(0) | |
| scores = torch.mm(vector, self.current_ivf_data.coarse_centroids.t()) | |
| c_idx = scores.argmax(dim=1).item() | |
| centroid = self.current_ivf_data.coarse_centroids[c_idx].unsqueeze(0) | |
| pq_single = self._quantize_flat(vector, online_clustering=False, centroid=centroid) | |
| if c_idx not in self.dynamic_shards: | |
| self.dynamic_shards[c_idx] = [] | |
| self.dynamic_shards[c_idx].append((vector_id, pq_single)) | |
| if vector_id in self.deleted_ids: | |
| self.deleted_ids.remove(vector_id) | |
| def merge_dynamic_shards(self): | |
| ivf = self.current_ivf_data | |
| if not isinstance(ivf, IVFData) or not self.dynamic_shards: | |
| return | |
| new_total_size = len(ivf.vector_ids) + sum(len(v) for v in self.dynamic_shards.values()) | |
| updated_sq_codes = torch.zeros((new_total_size, ivf.pq_data.sq_codes.shape[1]), dtype=torch.uint8) | |
| updated_qjl_signs = torch.zeros((new_total_size, ivf.pq_data.qjl_signs.shape[1]), dtype=torch.int8) | |
| updated_norms = torch.zeros(new_total_size) | |
| updated_res_norms = torch.zeros(new_total_size) | |
| updated_vector_ids = np.zeros(new_total_size, dtype=np.int64) | |
| updated_offsets = [0] | |
| curr_pos = 0 | |
| for c_idx in range(ivf.n_list): | |
| old_start = ivf.list_offsets[c_idx] | |
| old_end = ivf.list_offsets[c_idx+1] | |
| old_size = old_end - old_start | |
| if old_size > 0: | |
| updated_sq_codes[curr_pos:curr_pos+old_size] = torch.from_numpy(ivf.pq_data.sq_codes[old_start:old_end]) | |
| updated_qjl_signs[curr_pos:curr_pos+old_size] = torch.from_numpy(ivf.pq_data.qjl_signs[old_start:old_end]) | |
| updated_norms[curr_pos:curr_pos+old_size] = torch.from_numpy(ivf.pq_data.norms[old_start:old_end]) | |
| updated_res_norms[curr_pos:curr_pos+old_size] = torch.from_numpy(ivf.pq_data.res_norms[old_start:old_end]) | |
| updated_vector_ids[curr_pos:curr_pos+old_size] = ivf.vector_ids[old_start:old_end] | |
| curr_pos += old_size | |
| if c_idx in self.dynamic_shards: | |
| for vid, dpq in self.dynamic_shards[c_idx]: | |
| updated_sq_codes[curr_pos] = torch.from_numpy(dpq.sq_codes) if isinstance(dpq.sq_codes, np.ndarray) else dpq.sq_codes | |
| updated_qjl_signs[curr_pos] = torch.from_numpy(dpq.qjl_signs) if isinstance(dpq.qjl_signs, np.ndarray) else dpq.qjl_signs | |
| updated_norms[curr_pos] = float(dpq.norms[0]) if isinstance(dpq.norms, np.ndarray) else float(dpq.norms) | |
| updated_res_norms[curr_pos] = float(dpq.res_norms[0]) if isinstance(dpq.res_norms, np.ndarray) else float(dpq.res_norms) | |
| updated_vector_ids[curr_pos] = vid | |
| curr_pos += 1 | |
| updated_offsets.append(curr_pos) | |
| ivf.pq_data.sq_codes = updated_sq_codes | |
| ivf.pq_data.qjl_signs = updated_qjl_signs | |
| ivf.pq_data.norms = updated_norms | |
| ivf.pq_data.res_norms = updated_res_norms | |
| ivf.vector_ids = updated_vector_ids | |
| ivf.list_offsets = torch.tensor(updated_offsets, dtype=torch.long) | |
| self.dynamic_shards.clear() | |
| def delete(self, vector_id: int): | |
| self.deleted_ids.add(vector_id) | |
| def _quantize_flat(self, x: torch.Tensor, online_clustering: bool = False, x_rot: torch.Tensor = None, centroid: torch.Tensor = None) -> ProdQuantized: | |
| if x.device.type != self.device: | |
| x = x.to(self.device) | |
| if centroid is not None: | |
| if centroid.device.type != self.device: | |
| centroid = centroid.to(self.device) | |
| x_target = x - centroid | |
| else: | |
| x_target = x | |
| if x_rot is None: | |
| x_rot = rotate_forward(x_target, self.rot_op_t) | |
| norms = torch.norm(x, dim=-1) | |
| res_norms = torch.norm(x_target, dim=-1) | |
| if online_clustering: | |
| self.sq_quantizer.fit(x_rot) | |
| x_rot_np = np.ascontiguousarray(x_rot.detach().cpu().numpy(), dtype=np.float32) | |
| sq_centroids_np = np.ascontiguousarray(self.sq_quantizer.centroids.detach().cpu().numpy(), dtype=np.float32) | |
| try: | |
| sq_codes_packed, qjl_signs_packed, res_norms_np = tq_native.tq_quantize_rotated(x_rot_np, sq_centroids_np, int(self.sq_bits)) | |
| sq_codes_np = np.asarray(sq_codes_packed, dtype=np.uint8) | |
| qjl_signs = np.asarray(qjl_signs_packed, dtype=np.uint8) | |
| res_norms_np = np.asarray(res_norms_np, dtype=np.float32) | |
| except Exception: | |
| sq_q = self.sq_quantizer.quantize(x_rot) | |
| x_hat_1 = self.sq_quantizer.reconstruct(sq_q.indices) | |
| residual = x_rot - x_hat_1 | |
| res_norms_np = torch.norm(residual, dim=-1).detach().cpu().numpy().astype(np.float32) | |
| signs = (residual > 0).to(torch.uint8).cpu().numpy() | |
| qjl_signs = np.packbits(signs, axis=-1, bitorder='little').astype(np.uint8) | |
| sq_codes_np = sq_q.indices.cpu().numpy().astype(np.uint8) | |
| return ProdQuantized( | |
| sq_codes=sq_codes_np, qjl_signs=qjl_signs.astype(np.uint8), | |
| norms=norms.cpu().numpy().astype(np.float32), | |
| centroids=self.sq_quantizer.centroids.cpu().numpy().astype(np.float32), | |
| dim=self.dim, sq_bits=self.sq_bits, total_bits=self.bits, | |
| qjl_scale=self.qjl_scale, rot_op=self.rot_op_np, res_norms=res_norms_np | |
| ) | |
| def search_batch(self, queries: torch.Tensor, top_k: int = 100, n_probe: int = None, | |
| allowed_ids: Optional[List[int]] = None, | |
| raw_corpus: Optional[np.ndarray] = None, | |
| rerank_factor: Optional[int] = None) -> List[Tuple[torch.Tensor, torch.Tensor]]: | |
| ivf = self.current_ivf_data | |
| if ivf is None: | |
| raise ValueError("No data indexed. Call index() first.") | |
| if queries.device.type != self.device: | |
| queries = queries.to(self.device) | |
| if queries.dim() == 1: | |
| queries = queries.unsqueeze(0) | |
| nprobe = int(n_probe) if n_probe is not None else (self._auto_nprobe() if self.ivf_nprobe <= 0 else int(self.ivf_nprobe)) | |
| num_queries = queries.shape[0] | |
| queries_np = np.ascontiguousarray(queries.detach().cpu().numpy(), dtype=np.float32) | |
| pq = ivf.pq_data | |
| # Fast Path: If HNSW is not active, run 100% in Rust using tq_unified_search | |
| if not self.use_hnsw or self.hnsw_navigator is None: | |
| allowed_arr = np.array(allowed_ids, dtype=np.int64) if allowed_ids is not None else None | |
| raw_arr = raw_corpus if raw_corpus is not None else (self.raw_vectors if hasattr(self, "raw_vectors") else None) | |
| scores, indices = tq_native.tq_unified_search( | |
| queries_np, | |
| self.rot_op_np, | |
| ivf.coarse_centroids.cpu().numpy(), | |
| ivf.list_offsets, | |
| ivf.vector_ids, | |
| pq.sq_codes, | |
| pq.centroids, | |
| pq.norms, | |
| pq.qjl_signs, | |
| pq.res_norms, | |
| float(pq.qjl_scale), | |
| int(self.dim), | |
| int(self.sq_bits), | |
| int(nprobe), | |
| int(top_k), | |
| allowed_arr, | |
| raw_arr, | |
| int(rerank_factor) if rerank_factor is not None else None | |
| ) | |
| results = [] | |
| for i in range(num_queries): | |
| final_ids = torch.from_numpy(indices[i]).to(self.device) | |
| final_scores = torch.from_numpy(scores[i]).to(self.device) | |
| results.append((final_ids, final_scores)) | |
| return results | |
| # Slow/Fallback Path: Standard HNSW scan (uses tq_ivf_scan_with_clusters) | |
| if self.use_hnsw and self.hnsw_navigator is not None: | |
| cluster_scores_np, cluster_ids_np = self.hnsw_navigator.search(queries_np, nprobe) | |
| cluster_ids = torch.from_numpy(cluster_ids_np).to(self.device).long() | |
| cluster_scores = torch.from_numpy(cluster_scores_np).to(self.device).float() | |
| else: | |
| scores_c = torch.mm(queries, ivf.coarse_centroids.t()) | |
| cluster_scores, cluster_ids = torch.topk(scores_c, nprobe, dim=1) | |
| q_rot = rotate_forward(queries, self.rot_op_t) | |
| q_rot_np = np.ascontiguousarray(q_rot.cpu().numpy(), dtype=np.float32) | |
| scores, indices = tq_native.tq_ivf_scan_with_clusters( | |
| queries_np, pq.sq_codes, | |
| pq.centroids, | |
| pq.norms, | |
| pq.qjl_signs, | |
| pq.res_norms, | |
| q_rot_np, ivf.list_offsets, | |
| np.ascontiguousarray(cluster_ids.cpu().numpy(), dtype=np.int32), | |
| np.ascontiguousarray(cluster_scores.cpu().numpy(), dtype=np.float32), | |
| float(pq.qjl_scale), int(self.dim), int(self.sq_bits), | |
| int(top_k if allowed_ids is None else top_k * 10) | |
| ) | |
| results = [] | |
| global_ids = ivf.vector_ids | |
| allowed_set = set(allowed_ids) if allowed_ids is not None else None | |
| for i in range(num_queries): | |
| valid_mask = indices[i] != -1 | |
| q_indices = indices[i][valid_mask] | |
| q_scores = scores[i][valid_mask] | |
| q_global_ids = global_ids[q_indices] | |
| # Apply rerank if requested in slow path | |
| if raw_corpus is not None and rerank_factor is not None: | |
| candidates_raw = torch.from_numpy(raw_corpus[q_global_ids]).to(self.device) | |
| exact_scores = torch.mm(queries[i:i+1], candidates_raw.t()).view(-1) | |
| _, final_idx = torch.topk(exact_scores, min(top_k, len(exact_scores))) | |
| q_global_ids = q_global_ids[final_idx.cpu().numpy()] | |
| q_scores = exact_scores[final_idx].cpu().numpy() | |
| if allowed_set is not None: | |
| mask = np.isin(q_global_ids, list(allowed_set)) | |
| q_global_ids = q_global_ids[mask] | |
| q_scores = q_scores[mask] | |
| q_global_ids = q_global_ids[:top_k] | |
| q_scores = q_scores[:top_k] | |
| final_ids = torch.from_numpy(q_global_ids.copy()).to(self.device) | |
| final_scores = torch.from_numpy(q_scores.copy()).to(self.device) | |
| results.append((final_ids, final_scores)) | |
| return results | |
| def search(self, query: torch.Tensor, top_k: int = 100, n_probe: int = None, allowed_ids: Optional[List[int]] = None) -> tuple[torch.Tensor, torch.Tensor]: | |
| ivf = self.current_ivf_data | |
| if ivf is None: | |
| raise ValueError("No data indexed. Call index() first.") | |
| if query.device.type != self.device: | |
| query = query.to(self.device) | |
| if isinstance(ivf, IVFData): | |
| if query.dim() == 1: | |
| query = query.unsqueeze(0) | |
| results = self.search_batch(query, top_k=top_k, n_probe=n_probe, allowed_ids=allowed_ids) | |
| return results[0] | |
| else: | |
| return self._native_cosine_search_flat(query, ivf, top_k) | |
| def _native_cosine_search_flat(self, query: torch.Tensor, pq: ProdQuantized, top_k: int = 100) -> tuple[torch.Tensor, torch.Tensor]: | |
| if query.device.type != self.device: | |
| query = query.to(self.device) | |
| if query.dim() == 1: | |
| query = query.unsqueeze(0) | |
| q_rot = rotate_forward(query, self.rot_op_t).squeeze(0) | |
| q_np = q_rot.cpu().numpy().astype(np.float32) | |
| query_1d = np.array(q_np, dtype=np.float32, order='C') | |
| total_vectors = pq.sq_codes.shape[0] | |
| ram_gb = 4.0 | |
| h = 10**(len(str(total_vectors)) - 1) | |
| raw_batch_size = int((0.3 * total_vectors * (h * 100) / total_vectors) / (ram_gb * (ram_gb / 0.4))) + 1 | |
| compression_ratio = 32 // self.bits | |
| tq_batch_size = raw_batch_size * compression_ratio | |
| all_scores = [] | |
| for start_idx in range(0, total_vectors, tq_batch_size): | |
| end_idx = min(start_idx + tq_batch_size, total_vectors) | |
| sq_batch = np.ascontiguousarray(pq.sq_codes[start_idx:end_idx], dtype=np.uint8) | |
| qjl_batch = np.ascontiguousarray(pq.qjl_signs[start_idx:end_idx], dtype=np.uint8) | |
| norms_batch = np.ascontiguousarray(pq.norms[start_idx:end_idx], dtype=np.float32) | |
| res_norms_batch = np.ascontiguousarray(pq.res_norms[start_idx:end_idx], dtype=np.float32) | |
| centroids_1d = np.ascontiguousarray(pq.centroids, dtype=np.float32) | |
| batch_scores = tq_native.tq_scan(query_1d, sq_batch, centroids_1d, norms_batch, qjl_batch, res_norms_batch, query_1d, float(pq.qjl_scale), int(self.dim), int(self.sq_bits)) | |
| all_scores.append(batch_scores) | |
| final_scores = np.concatenate(all_scores) | |
| scores_t = torch.from_numpy(final_scores).view(-1) | |
| top_scores, top_indices = torch.topk(scores_t, min(top_k, len(scores_t))) | |
| return top_indices, top_scores | |
| def save_index(self, path: str): | |
| from pathlib import Path | |
| import faiss | |
| save_path = Path(path).resolve() | |
| if self.current_ivf_data is None: | |
| raise ValueError("Không có dữ liệu để lưu. Hãy gọi index() trước.") | |
| gc.collect() | |
| if not save_path.exists(): | |
| save_path.mkdir(parents=True, exist_ok=True) | |
| ivf = self.current_ivf_data | |
| pq = ivf.pq_data | |
| def to_np(obj): | |
| if isinstance(obj, torch.Tensor): | |
| return obj.detach().cpu().numpy() | |
| return obj | |
| files_to_save = {} | |
| if not isinstance(pq.sq_codes, np.memmap): | |
| files_to_save["sq_codes.npy"] = to_np(pq.sq_codes) | |
| files_to_save["qjl_signs.npy"] = to_np(pq.qjl_signs) | |
| files_to_save["norms.npy"] = to_np(pq.norms) | |
| files_to_save["res_norms.npy"] = to_np(pq.res_norms) | |
| files_to_save["vector_ids.npy"] = to_np(ivf.vector_ids) | |
| # Các file này phải luôn được lưu | |
| files_to_save["list_offsets.npy"] = to_np(ivf.list_offsets) | |
| files_to_save["coarse_centroids.npy"] = to_np(ivf.coarse_centroids) | |
| files_to_save["rot_op.npy"] = to_np(pq.rot_op) | |
| files_to_save["sq_centroids.npy"] = to_np(pq.centroids) | |
| for filename, data in files_to_save.items(): | |
| np.save(str(save_path / filename), data) | |
| if self.hnsw_navigator is not None: | |
| faiss.write_index(self.hnsw_navigator, str(save_path / "centroids.hnsw")) | |
| meta = { | |
| "dim": int(self.dim), "bits": int(self.bits), "qjl_scale": float(self.qjl_scale), | |
| "n_list": int(ivf.n_list), "n_probe": int(ivf.n_probe), "deleted_ids": list(self.deleted_ids) | |
| } | |
| with open(str(save_path / "metadata.json"), "w", encoding='utf-8') as f: | |
| json.dump(meta, f, indent=2) | |
| def load_index(self, path: str, use_mmap: bool = True): | |
| from pathlib import Path | |
| import platform | |
| import faiss | |
| load_path = Path(path).resolve() | |
| if not load_path.exists(): | |
| raise FileNotFoundError(f"Thư mục index không tồn tại: {load_path}") | |
| with open(str(load_path / "metadata.json"), "r", encoding='utf-8') as f: | |
| meta = json.load(f) | |
| self.dim = meta["dim"] | |
| self.bits = meta["bits"] | |
| self.sq_bits = self.bits - 1 | |
| self.qjl_scale = meta["qjl_scale"] | |
| self.ivf_nlist = meta["n_list"] | |
| self.ivf_nprobe = meta["n_probe"] | |
| self.deleted_ids = set(meta["deleted_ids"]) | |
| mmap_val = 'r' | |
| if platform.system() == "Windows" and not use_mmap: | |
| mmap_val = None | |
| elif platform.system() != "Windows": | |
| mmap_val = 'r' | |
| coarse_centroids = torch.from_numpy(np.load(os.path.join(path, "coarse_centroids.npy"))).to(self.device) | |
| sq_codes = np.load(os.path.join(path, "sq_codes.npy"), mmap_mode=mmap_val) | |
| qjl_signs = np.load(os.path.join(path, "qjl_signs.npy"), mmap_mode=mmap_val) | |
| norms = np.load(os.path.join(path, "norms.npy"), mmap_mode=mmap_val) | |
| res_norms = np.load(os.path.join(path, "res_norms.npy"), mmap_mode=mmap_val) | |
| vector_ids = np.load(os.path.join(path, "vector_ids.npy"), mmap_mode=mmap_val) | |
| list_offsets = np.load(os.path.join(path, "list_offsets.npy"), mmap_mode=mmap_val) | |
| rot_op = np.load(os.path.join(path, "rot_op.npy"), mmap_mode=mmap_val) | |
| sq_centroids = np.load(os.path.join(path, "sq_centroids.npy"), mmap_mode=mmap_val) | |
| self.rot_op_np = rot_op | |
| self.rot_op_t = torch.from_numpy(rot_op).to(self.device) | |
| pq_data = ProdQuantized( | |
| sq_codes=sq_codes, qjl_signs=qjl_signs, norms=norms, centroids=sq_centroids, | |
| dim=self.dim, sq_bits=self.sq_bits, total_bits=self.bits, | |
| qjl_scale=self.qjl_scale, rot_op=rot_op, res_norms=res_norms | |
| ) | |
| self.current_ivf_data = IVFData( | |
| coarse_centroids=coarse_centroids, pq_data=pq_data, vector_ids=vector_ids, | |
| list_offsets=list_offsets, n_list=self.ivf_nlist, n_probe=self.ivf_nprobe | |
| ) | |
| # Load HNSW Navigator nếu có | |
| hnsw_path = os.path.join(path, "centroids.hnsw") | |
| if os.path.exists(hnsw_path): | |
| import faiss | |
| idx = faiss.read_index(str(hnsw_path)) | |
| # Kiểm tra metric, nếu là L2 (0) thì phải tạo lại vì TQ dùng IP (1) | |
| if idx.metric_type != faiss.METRIC_INNER_PRODUCT: | |
| print(" Warning: HNSW index uses L2 metric. Rebuilding with Inner Product...") | |
| self.hnsw_navigator = faiss.IndexHNSWFlat(self.dim, 32, faiss.METRIC_INNER_PRODUCT) | |
| self.hnsw_navigator.add(coarse_centroids.cpu().numpy()) | |
| faiss.write_index(self.hnsw_navigator, str(hnsw_path)) | |
| else: | |
| self.hnsw_navigator = idx | |
| elif self.use_hnsw: | |
| print(" Warning: HNSW index missing. Building now...") | |
| import faiss | |
| self.hnsw_navigator = faiss.IndexHNSWFlat(self.dim, 32, faiss.METRIC_INNER_PRODUCT) | |
| self.hnsw_navigator.add(coarse_centroids.cpu().numpy()) | |
| faiss.write_index(self.hnsw_navigator, str(hnsw_path)) | |
| print(f"Loaded index from: {path} (mmap={mmap_val})") | |