# ParshvPatel's picture
# feat: HuggingFace Spaces deployment
# d992912
import logging
from typing import List, Optional, Tuple
import numpy as np
import faiss
from backend.app.config import SearchConfig
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class DualFAISSIndex:
    """
    Two parallel FAISS indices (image + text) fused via Reciprocal Rank Fusion.

    Both indices use the inner-product metric. The same query vector is
    searched against each index and the two ranked lists are merged with a
    weighted RRF score: ``weight / (rrf_k + rank)`` summed over the lists in
    which a candidate appears (ranks are 1-based).

    The ``config`` object is read for these attributes: ``n_clusters``,
    ``n_probe``, ``image_index_weight``, ``text_index_weight`` and ``rrf_k``.
    """

    # Below this many vectors an exact (flat) index is used instead of IVF.
    _EXACT_INDEX_MAX_VECTORS = 5000
    # IVF cluster count is capped at roughly one cluster per this many vectors.
    _VECTORS_PER_CLUSTER = 40
    # Lower bound applied to the vectors-per-cluster heuristic.
    _MIN_CLUSTERS = 16

    # Annotations naming ``SearchConfig`` / ``faiss`` are quoted so they are
    # not evaluated when the class body is executed.
    def __init__(self, dim: int, config: "SearchConfig"):
        """Store the vector dimension and search config; indices start unset.

        Args:
            dim: Dimensionality used when creating the FAISS indices.
            config: Search configuration (see class docstring for the
                attributes that are read).
        """
        self.dim = dim
        self.config = config
        # Populated by build() or load().
        self.image_index = None
        self.text_index = None

    def _create_index(self, n_vectors: int) -> "faiss.Index":
        """Create an empty inner-product index sized for ``n_vectors``.

        Returns an exact ``IndexFlatIP`` for fewer than 5000 vectors,
        otherwise an untrained ``IndexIVFFlat`` whose cluster count is
        ``min(config.n_clusters, max(16, n_vectors // 40))``.
        """
        if n_vectors < self._EXACT_INDEX_MAX_VECTORS:
            logger.info(f"Using IndexFlatIP (exact, n={n_vectors:,})")
            return faiss.IndexFlatIP(self.dim)

        n_clusters = min(
            self.config.n_clusters,
            max(self._MIN_CLUSTERS, n_vectors // self._VECTORS_PER_CLUSTER),
        )
        logger.info(f"Using IndexIVFFlat (n={n_vectors:,}, clusters={n_clusters})")
        quantizer = faiss.IndexFlatIP(self.dim)
        return faiss.IndexIVFFlat(
            quantizer, self.dim, n_clusters, faiss.METRIC_INNER_PRODUCT
        )

    def _build_single(self, embeddings: np.ndarray) -> "faiss.Index":
        """Create, train if required, and fill one index from ``embeddings``.

        If training raises, the failure is logged and an exact
        ``IndexFlatIP`` is used instead so the build still succeeds.
        """
        index = self._create_index(embeddings.shape[0])
        # Flat indices report is_trained=True and need no training step.
        if not index.is_trained:
            try:
                index.train(embeddings)
            except Exception:
                # Deliberate best-effort: keep the build going, but make the
                # downgrade visible instead of swallowing it silently.
                logger.warning(
                    "FAISS index training failed; falling back to IndexFlatIP",
                    exc_info=True,
                )
                index = faiss.IndexFlatIP(self.dim)
        index.add(embeddings)
        return index

    def build(self, image_embeddings: np.ndarray, text_embeddings: np.ndarray):
        """Build both indices from equally-shaped embedding matrices.

        Inputs are converted to C-contiguous float32 arrays before use.

        Args:
            image_embeddings: Image vectors, one row per item.
            text_embeddings: Text vectors, one row per item; must have the
                same shape as ``image_embeddings``.

        Raises:
            ValueError: If the two arrays do not have the same shape.
        """
        image_embeddings = np.ascontiguousarray(image_embeddings, dtype=np.float32)
        text_embeddings = np.ascontiguousarray(text_embeddings, dtype=np.float32)

        # Explicit raise rather than assert: asserts are stripped under -O.
        if image_embeddings.shape != text_embeddings.shape:
            raise ValueError(
                f"Shape mismatch: images {image_embeddings.shape} vs text {text_embeddings.shape}"
            )

        logger.info("Building image FAISS index...")
        self.image_index = self._build_single(image_embeddings)

        logger.info("Building text FAISS index...")
        self.text_index = self._build_single(text_embeddings)

        logger.info(
            f"Dual index built: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )

    def _search(self, index, query: np.ndarray, top_k: int):
        """Run a single-query search against ``index``.

        The query is flattened into one float32 row. If the index exposes an
        ``nprobe`` attribute it is set to ``config.n_probe`` before searching.
        """
        q = np.ascontiguousarray(query, dtype=np.float32).reshape(1, -1)
        if hasattr(index, "nprobe"):
            index.nprobe = self.config.n_probe
        return index.search(q, top_k)

    def search_image_index(self, query: np.ndarray, top_k: int):
        """Search the image index; returns whatever ``index.search`` returns."""
        return self._search(self.image_index, query, top_k)

    def search_text_index(self, query: np.ndarray, top_k: int):
        """Search the text index; returns whatever ``index.search`` returns."""
        return self._search(self.text_index, query, top_k)

    def search_fused(
        self, query: np.ndarray, top_k: int,
        image_weight: Optional[float] = None,
        text_weight: Optional[float] = None,
    ) -> Tuple[List[int], List[float]]:
        """Search both indices and merge the results with weighted RRF.

        Args:
            query: Query vector, searched against both indices.
            top_k: Maximum number of fused results to return.
            image_weight: Weight of the image ranking. ``None`` means use
                ``config.image_index_weight``; an explicit ``0.0`` is honoured.
            text_weight: Weight of the text ranking. ``None`` means use
                ``config.text_index_weight``; an explicit ``0.0`` is honoured.

        Returns:
            ``(ids, scores)`` sorted by descending fused score; ties are
            ordered by ascending id. Both lists are empty when ``top_k`` is
            not positive or the image index holds no vectors.
        """
        # `is None` (not `or`) so that a caller-supplied 0.0 is respected.
        iw = self.config.image_index_weight if image_weight is None else image_weight
        tw = self.config.text_index_weight if text_weight is None else text_weight
        rrf_k = self.config.rrf_k

        if top_k <= 0:
            return [], []

        # Over-fetch so candidates ranked low in one list can still surface.
        broad_k = min(top_k * 3, self.image_index.ntotal)
        if broad_k <= 0:
            return [], []

        _, img_ids = self.search_image_index(query, broad_k)
        _, txt_ids = self.search_text_index(query, broad_k)

        scores = {}
        for weight, ids in ((iw, img_ids[0]), (tw, txt_ids[0])):
            # Negative ids are skipped (slots the search did not fill).
            ranks = {
                int(idx): rank
                for rank, idx in enumerate(ids, start=1)
                if idx >= 0
            }
            for idx, rank in ranks.items():
                scores[idx] = scores.get(idx, 0.0) + weight / (rrf_k + rank)

        ranked = sorted(scores.items(), key=lambda item: (-item[1], item[0]))[:top_k]
        return [idx for idx, _ in ranked], [score for _, score in ranked]

    def save(self, image_path: str, text_path: str):
        """Write the image and text indices to the given file paths."""
        faiss.write_index(self.image_index, image_path)
        faiss.write_index(self.text_index, text_path)
        logger.info(f"Saved dual index to {image_path} and {text_path}")

    def load(self, image_path: str, text_path: str):
        """Read the image and text indices from the given file paths."""
        self.image_index = faiss.read_index(image_path)
        self.text_index = faiss.read_index(text_path)
        logger.info(
            f"Loaded dual index: {self.image_index.ntotal:,} image, "
            f"{self.text_index.ntotal:,} text vectors"
        )
# Names exported by a star-import of this module.
__all__ = ["DualFAISSIndex"]