| import logging |
| from typing import List, Optional, Tuple |
|
|
| import numpy as np |
| import faiss |
|
|
| from backend.app.config import SearchConfig |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class DualFAISSIndex:
    """
    Two parallel FAISS indices (image + text) fused via Reciprocal Rank Fusion.

    Both indices are built from row-aligned embedding matrices, so row ``i`` of
    the image matrix and row ``i`` of the text matrix share the id ``i``. A
    fused search queries both indices with the same vector and combines the
    two rankings as ``weight / (rrf_k + rank)`` per index.

    Attributes:
        dim: Dimensionality of the stored vectors.
        config: Search configuration; this class reads ``n_clusters``,
            ``n_probe``, ``rrf_k``, ``image_index_weight`` and
            ``text_index_weight`` from it.
        image_index: FAISS index over image embeddings (``None`` until
            ``build`` or ``load`` is called).
        text_index: FAISS index over text embeddings (``None`` until
            ``build`` or ``load`` is called).
    """

    # Below this many vectors an exact (flat) index is used instead of IVF.
    _EXACT_SEARCH_THRESHOLD = 5000
    # A fused search over-fetches this multiple of ``top_k`` from each index
    # so that items ranked well in only one modality can still surface.
    _CANDIDATE_MULTIPLIER = 3

    def __init__(self, dim: int, config: "SearchConfig"):
        """Store the vector dimensionality and search configuration."""
        self.dim = dim
        self.config = config
        self.image_index = None
        self.text_index = None

    def _create_index(self, n_vectors: int) -> "faiss.Index":
        """Return an empty inner-product index sized for ``n_vectors`` vectors.

        Small collections get an exact ``IndexFlatIP``; larger ones get an
        ``IndexIVFFlat`` whose cluster count is capped by ``config.n_clusters``
        and never drops below 16.
        """
        if n_vectors < self._EXACT_SEARCH_THRESHOLD:
            logger.info(f"Using IndexFlatIP (exact, n={n_vectors:,})")
            return faiss.IndexFlatIP(self.dim)

        n_clusters = min(self.config.n_clusters, max(16, n_vectors // 40))
        logger.info(f"Using IndexIVFFlat (n={n_vectors:,}, clusters={n_clusters})")
        quantizer = faiss.IndexFlatIP(self.dim)
        return faiss.IndexIVFFlat(
            quantizer, self.dim, n_clusters, faiss.METRIC_INNER_PRODUCT
        )

    def _build_single(self, embeddings: np.ndarray, label: str) -> "faiss.Index":
        """Create, train and fill one index from an ``(n, dim)`` float32 matrix."""
        logger.info(f"Building {label} FAISS index...")
        index = self._create_index(embeddings.shape[0])
        try:
            # A no-op for flat indices; IVF indices learn their centroids here.
            index.train(embeddings)
        except Exception:
            # Deliberate best-effort: exact search is slower but always valid.
            # Log the failure so the fallback is visible to operators.
            logger.warning(
                "Training the %s index failed; falling back to IndexFlatIP",
                label,
                exc_info=True,
            )
            index = faiss.IndexFlatIP(self.dim)
        index.add(embeddings)
        return index

    def build(self, image_embeddings: np.ndarray, text_embeddings: np.ndarray):
        """Build both indices from row-aligned embedding matrices.

        Args:
            image_embeddings: ``(n, dim)`` matrix of image vectors.
            text_embeddings: ``(n, dim)`` matrix of text vectors; row ``i``
                must describe the same item as row ``i`` of
                ``image_embeddings``.

        Raises:
            ValueError: If the two matrices differ in shape, or are not
                two-dimensional with ``dim`` columns.
        """
        # FAISS works on C-contiguous float32 data.
        image_embeddings = np.ascontiguousarray(image_embeddings, dtype=np.float32)
        text_embeddings = np.ascontiguousarray(text_embeddings, dtype=np.float32)

        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if image_embeddings.shape != text_embeddings.shape:
            raise ValueError(
                f"Shape mismatch: images {image_embeddings.shape} vs text {text_embeddings.shape}"
            )
        if image_embeddings.ndim != 2 or image_embeddings.shape[1] != self.dim:
            raise ValueError(
                f"Expected embeddings of shape (n, {self.dim}), got {image_embeddings.shape}"
            )

        self.image_index = self._build_single(image_embeddings, "image")
        self.text_index = self._build_single(text_embeddings, "text")

        logger.info(
            f"Dual index built: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )

    @staticmethod
    def _require_index(index, label: str):
        """Return ``index``, raising a clear error if it has not been created."""
        if index is None:
            raise RuntimeError(
                f"The {label} index is not available; call build() or load() first"
            )
        return index

    def _search(self, index, label: str, query: np.ndarray, top_k: int):
        """Run a single-query search against ``index`` and return FAISS' result."""
        index = self._require_index(index, label)
        q = np.ascontiguousarray(query, dtype=np.float32).reshape(1, -1)
        # Only IVF-style indices expose ``nprobe``; flat indices are skipped.
        if hasattr(index, "nprobe"):
            index.nprobe = self.config.n_probe
        return index.search(q, top_k)

    def search_image_index(self, query: np.ndarray, top_k: int):
        """Search the image index.

        Args:
            query: Query vector with ``dim`` elements (any shape that can be
                flattened to a single row).
            top_k: Number of neighbours to request.

        Returns:
            The ``(scores, ids)`` pair returned by the index's ``search``.

        Raises:
            RuntimeError: If the index has not been built or loaded.
        """
        return self._search(self.image_index, "image", query, top_k)

    def search_text_index(self, query: np.ndarray, top_k: int):
        """Search the text index.

        Args:
            query: Query vector with ``dim`` elements (any shape that can be
                flattened to a single row).
            top_k: Number of neighbours to request.

        Returns:
            The ``(scores, ids)`` pair returned by the index's ``search``.

        Raises:
            RuntimeError: If the index has not been built or loaded.
        """
        return self._search(self.text_index, "text", query, top_k)

    def search_fused(
        self, query: np.ndarray, top_k: int,
        image_weight: Optional[float] = None,
        text_weight: Optional[float] = None,
    ) -> Tuple[List[int], List[float]]:
        """Search both indices and merge the rankings with weighted RRF.

        Args:
            query: Query vector with ``dim`` elements.
            top_k: Maximum number of fused results to return.
            image_weight: Weight of the image ranking. ``None`` selects
                ``config.image_index_weight``; an explicit ``0.0`` is honoured
                and disables the image contribution.
            text_weight: Weight of the text ranking. ``None`` selects
                ``config.text_index_weight``; an explicit ``0.0`` is honoured.

        Returns:
            ``(ids, scores)``: ids ordered by descending fused score (ties
            broken by ascending id) and their matching scores. Both lists are
            empty when ``top_k <= 0`` or the image index holds no vectors.

        Raises:
            RuntimeError: If either index has not been built or loaded.
        """
        # ``is None`` rather than ``or``: a caller-supplied 0.0 is falsy but valid.
        iw = self.config.image_index_weight if image_weight is None else image_weight
        tw = self.config.text_index_weight if text_weight is None else text_weight
        rrf_k = self.config.rrf_k

        image_index = self._require_index(self.image_index, "image")
        self._require_index(self.text_index, "text")

        # FAISS rejects k == 0, so answer the degenerate cases directly.
        if top_k <= 0 or image_index.ntotal == 0:
            return [], []

        broad_k = min(top_k * self._CANDIDATE_MULTIPLIER, image_index.ntotal)

        _, img_ids = self.search_image_index(query, broad_k)
        _, txt_ids = self.search_text_index(query, broad_k)

        # FAISS pads missing neighbours with -1; those slots are ignored.
        img_rank = {
            int(idx): rank for rank, idx in enumerate(img_ids[0], start=1) if idx >= 0
        }
        txt_rank = {
            int(idx): rank for rank, idx in enumerate(txt_ids[0], start=1) if idx >= 0
        }

        scores = {}
        for ranks, weight in ((img_rank, iw), (txt_rank, tw)):
            for idx, rank in ranks.items():
                scores[idx] = scores.get(idx, 0.0) + weight / (rrf_k + rank)

        # Secondary key on the id keeps the ordering reproducible on ties.
        ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:top_k]
        return [idx for idx, _ in ranked], [score for _, score in ranked]

    def save(self, image_path: str, text_path: str):
        """Write both indices to disk.

        Args:
            image_path: Destination file for the image index.
            text_path: Destination file for the text index.

        Raises:
            RuntimeError: If either index has not been built or loaded.
        """
        # Validate both before writing so a failure leaves no partial output.
        image_index = self._require_index(self.image_index, "image")
        text_index = self._require_index(self.text_index, "text")
        faiss.write_index(image_index, image_path)
        faiss.write_index(text_index, text_path)
        logger.info(f"Saved dual index to {image_path} and {text_path}")

    def load(self, image_path: str, text_path: str):
        """Read both indices from disk, replacing any indices currently held.

        Args:
            image_path: File containing the image index.
            text_path: File containing the text index.
        """
        self.image_index = faiss.read_index(image_path)
        self.text_index = faiss.read_index(text_path)
        logger.info(
            f"Loaded dual index: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )
        # Fusion assumes both indices share row ids; surface obvious mismatches
        # without refusing to load, which keeps the previous permissive behavior.
        if self.image_index.ntotal != self.text_index.ntotal:
            logger.warning(
                "Loaded indices differ in size (%d image vs %d text vectors); "
                "fused results may be misaligned",
                self.image_index.ntotal,
                self.text_index.ntotal,
            )
        for label, index in (("image", self.image_index), ("text", self.text_index)):
            if index.d != self.dim:
                logger.warning(
                    "Loaded %s index has dimension %d but %d was expected",
                    label,
                    index.d,
                    self.dim,
                )
|
|
|
|
| __all__ = ["DualFAISSIndex"] |
|
|