import logging
import re
import threading
from typing import Any, Dict, List, Optional

import chromadb
from rank_bm25 import BM25Okapi

logger = logging.getLogger(__name__)


class VectorStoreManager:
    """Thread-safe singleton over a ChromaDB collection plus an in-memory
    BM25 index, supporting semantic-only and hybrid (RRF-fused) retrieval.
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Singleton creation guarded by the class-level lock.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self):
        # __init__ runs on every VectorStoreManager() call; only the first
        # successful call (under the lock) performs the expensive setup.
        with self._lock:
            if not self._initialized:
                self._initialize()
                VectorStoreManager._initialized = True

    def _initialize(self):
        """Initialize vector store with single collection + BM25 index.

        Raises:
            ValueError: if the required 'rag_documents' collection is absent.
            Exception: any ChromaDB client/collection error is logged and
                re-raised after resetting the initialized flag.
        """
        try:
            logger.info("Initializing vector store components...")
            self.client = None
            self.collection = None

            db_path = "output/chromadb"  # Match your pipeline path
            self.client = chromadb.PersistentClient(path=db_path)
            logger.info(f"ChromaDB client initialized at path: {db_path}")

            available_collections = [
                col.name for col in self.client.list_collections()
            ]
            logger.info(f"Available collections: {available_collections}")

            try:
                self.collection = self.client.get_collection("rag_documents")
                collection_count = self.collection.count()
                logger.info(
                    f"Collection 'rag_documents' loaded with {collection_count} documents"
                )
            except Exception as e:
                logger.error(f"Collection 'rag_documents' not found: {str(e)}")
                # FIX: original string literal was broken across a physical
                # line; rebuilt as implicit adjacent-literal concatenation.
                raise ValueError(
                    "Required collection 'rag_documents' not found. "
                    f"Available: {available_collections}"
                )

            # ---- Build BM25 index from all stored docs ----
            logger.info("Building BM25 index from Chroma collection...")
            data = self.collection.get(include=["documents", "metadatas"])
            self.all_ids: List[str] = data["ids"]
            self.all_docs: List[str] = data["documents"]
            self.all_metas: List[Dict[str, Any]] = data["metadatas"]

            self.tokenized_corpus = [self._tokenize(d) for d in self.all_docs]
            # FIX: BM25Okapi raises on an empty corpus (division by zero in
            # its IDF computation). Keep bm25=None for an empty collection
            # and let retrieve_documents() fall back to dense-only search.
            if self.tokenized_corpus:
                self.bm25: Optional[BM25Okapi] = BM25Okapi(self.tokenized_corpus)
                logger.info(f"BM25 index ready with {len(self.all_docs)} chunks")
            else:
                self.bm25 = None
                logger.warning("Collection is empty; BM25 index disabled")

            logger.info("Vector store initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize vector store: {str(e)}")
            # Allow a later instantiation to retry initialization.
            VectorStoreManager._initialized = False
            raise

    # ----------------- Helpers -----------------
    def _tokenize(self, text: str) -> List[str]:
        """Lowercased \\w+ tokenization, shared by corpus and queries."""
        return re.findall(r"\w+", (text or "").lower())

    def _matches_filters(
        self,
        meta: Dict[str, Any],
        doc_text: str,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> bool:
        """Mirror Chroma's metadata/document filters for the BM25 corpus.

        Only exact-equality metadata matches and the ``{"$contains": ...}``
        document operator are supported (the only forms used by callers).
        """
        if where_filters:
            for k, v in where_filters.items():
                if meta.get(k) != v:
                    return False
        if where_document:
            # you only use {"$contains": "..."}
            contains = where_document.get("$contains")
            if contains and contains.lower() not in (doc_text or "").lower():
                return False
        return True

    def _rrf_fuse(
        self,
        dense_ranked: List[Dict[str, Any]],
        sparse_ranked: List[Dict[str, Any]],
        k: int = 60,
        w_dense: float = 0.6,
        w_sparse: float = 0.4,
    ) -> List[Dict[str, Any]]:
        """
        Reciprocal Rank Fusion
        score = w_dense/(k+rank_dense) + w_sparse/(k+rank_sparse)

        Items appearing in only one list simply contribute a single term.
        The first-seen item dict wins for a given id (dense list is
        processed first, so its richer `distance` field is kept).
        """
        scores: Dict[str, Dict[str, Any]] = {}
        for rank, item in enumerate(dense_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_dense / (k + rank + 1)
        for rank, item in enumerate(sparse_ranked):
            doc_id = item["id"]
            scores.setdefault(doc_id, {"score": 0.0, "item": item})
            scores[doc_id]["score"] += w_sparse / (k + rank + 1)
        fused = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
        return [x["item"] for x in fused]

    # ----------------- Main retrieval -----------------
    def retrieve_documents(
        self,
        question: str,
        n_results: int = 5,
        where_filters: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        enable_bm25: bool = False,
        bm25_k: Optional[int] = None,
        alpha: float = 0.6,  # dense weight in hybrid fusion
    ) -> List[Dict[str, Any]]:
        """
        Retrieve documents using:
          - semantic-only (Chroma)
          - or hybrid semantic + BM25 (RRF fusion)

        Returns a list of dicts:
        {id, text, metadata, distance, bm25_score(optional)}

        NOTE: the hybrid path returns the full fused list (up to
        dense_k + bm25_k items), not truncated to n_results — preserved
        as existing callers may rely on it.

        Raises:
            RuntimeError: if the singleton was never successfully initialized.
        """
        # FIX: use getattr so a half-failed init raises the intended
        # RuntimeError instead of AttributeError.
        if not self._initialized or getattr(self, "collection", None) is None:
            raise RuntimeError("VectorStoreManager not properly initialized")

        logger.info(f"Retrieving documents for query: {question[:50]}...")
        dense_k = n_results
        # FIX: explicit None check — a caller-supplied bm25_k of 0 should
        # not silently fall back to n_results.
        bm25_k = n_results if bm25_k is None else bm25_k

        # ----- Dense retrieval (semantic via Chroma) -----
        try:
            dense_res = self.collection.query(
                query_texts=[question],
                n_results=dense_k,
                include=["documents", "metadatas", "distances"],
                where=where_filters if where_filters else None,
                where_document=where_document if where_document else None,
            )
        except Exception as e:
            logger.error(f"Dense retrieval failed: {str(e)}")
            raise

        dense_ranked: List[Dict[str, Any]] = []
        if dense_res and dense_res.get("documents") and dense_res["documents"][0]:
            for i in range(len(dense_res["documents"][0])):
                meta = dense_res["metadatas"][0][i]
                dense_ranked.append({
                    "id": dense_res["ids"][0][i],
                    "text": dense_res["documents"][0][i],
                    "metadata": meta,
                    "distance": float(dense_res["distances"][0][i]),
                    "source": meta.get("source", "Unknown"),
                })

        if not enable_bm25:
            logger.info(f"Semantic-only retrieved {len(dense_ranked)} docs")
            return dense_ranked

        # FIX: graceful fallback when the BM25 index could not be built
        # (empty collection at init time).
        if self.bm25 is None:
            logger.warning("BM25 index unavailable; returning dense results only")
            return dense_ranked

        # ----- Sparse retrieval (BM25) -----
        q_tokens = self._tokenize(question)
        scores = self.bm25.get_scores(q_tokens)

        # Apply same filters to BM25 corpus
        valid_indices = []
        for idx, (doc, meta) in enumerate(zip(self.all_docs, self.all_metas)):
            if self._matches_filters(meta, doc, where_filters, where_document):
                valid_indices.append(idx)

        # take top bm25_k from valid indices
        valid_scores = [(idx, scores[idx]) for idx in valid_indices]
        valid_scores.sort(key=lambda x: x[1], reverse=True)
        top_sparse = valid_scores[:bm25_k]

        sparse_ranked: List[Dict[str, Any]] = []
        for idx, s in top_sparse:
            meta = self.all_metas[idx]
            sparse_ranked.append({
                "id": self.all_ids[idx],
                "text": self.all_docs[idx],
                "metadata": meta,
                "bm25_score": float(s),
                "distance": None,  # may be absent if not in dense top-k
                "source": meta.get("source", "Unknown"),
            })

        # ----- Fuse dense + sparse -----
        fused = self._rrf_fuse(
            dense_ranked,
            sparse_ranked,
            w_dense=alpha,
            w_sparse=1.0 - alpha,
        )

        logger.info(
            f"Hybrid retrieved dense={len(dense_ranked)} sparse={len(sparse_ranked)} "
            f"fused={len(fused)}"
        )
        return fused


def get_vector_store() -> VectorStoreManager:
    """FastAPI dependency for injecting VectorStoreManager"""
    return VectorStoreManager()