Spaces:
Sleeping
Sleeping
| import logging | |
| from typing import Optional, Dict, Any, List | |
| import threading | |
| import re | |
| import numpy as np | |
| import chromadb | |
| from rank_bm25 import BM25Okapi | |
| logger = logging.getLogger(__name__) | |
class VectorStoreManager:
    """Thread-safe singleton wrapping a ChromaDB collection plus a BM25 index.

    Supports semantic-only retrieval (Chroma vector search) and hybrid
    retrieval (dense + sparse BM25, fused with Reciprocal Rank Fusion).
    """

    _instance = None
    _lock = threading.Lock()
    _initialized = False

    def __new__(cls):
        # Lock-guarded singleton: only one instance is ever created,
        # even under concurrent first access.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance

    def __init__(self):
        # __init__ runs on every VectorStoreManager() call; only the first
        # successful call performs the (expensive) initialization. The flag
        # is set on the class so all future calls skip re-initialization.
        with self._lock:
            if not self._initialized:
                self._initialize()
                VectorStoreManager._initialized = True

    def _initialize(self):
        """Connect to ChromaDB and build the in-memory BM25 index.

        Raises:
            ValueError: if the required 'rag_documents' collection is absent.
            Exception: any other client failure is logged and re-raised after
                resetting the initialized flag so a later call can retry.
        """
        try:
            logger.info("Initializing vector store components...")
            self.client = None
            self.collection = None
            db_path = "output/chromadb"  # Match your pipeline path
            self.client = chromadb.PersistentClient(path=db_path)
            logger.info(f"ChromaDB client initialized at path: {db_path}")
            available_collections = [col.name for col in self.client.list_collections()]
            logger.info(f"Available collections: {available_collections}")
            try:
                self.collection = self.client.get_collection("rag_documents")
                collection_count = self.collection.count()
                logger.info(
                    f"Collection 'rag_documents' loaded with {collection_count} documents"
                )
            except Exception as e:
                logger.error(f"Collection 'rag_documents' not found: {str(e)}")
                # Chain the original error so the real cause stays visible.
                raise ValueError(
                    "Required collection 'rag_documents' not found. "
                    f"Available: {available_collections}"
                ) from e
            # ---- Build BM25 index from all stored docs ----
            logger.info("Building BM25 index from Chroma collection...")
            data = self.collection.get(include=["documents", "metadatas"])
            self.all_ids: List[str] = data["ids"]
            self.all_docs: List[str] = data["documents"]
            self.all_metas: List[Dict[str, Any]] = data["metadatas"]
            self.tokenized_corpus = [self._tokenize(d) for d in self.all_docs]
            # BM25Okapi divides by the corpus size, so an empty collection
            # would crash (ZeroDivisionError). Keep bm25=None in that case
            # and fall back to dense-only retrieval later.
            self.bm25 = BM25Okapi(self.tokenized_corpus) if self.tokenized_corpus else None
            logger.info(f"BM25 index ready with {len(self.all_docs)} chunks")
            logger.info("Vector store initialized successfully")
        except Exception as e:
            logger.error(f"Failed to initialize vector store: {str(e)}")
            # Reset so the next constructor call retries initialization.
            VectorStoreManager._initialized = False
            raise

    # ----------------- Helpers -----------------
    def _tokenize(self, text: str) -> List[str]:
        """Lowercase word tokenization for BM25; tolerates None input."""
        return re.findall(r"\w+", (text or "").lower())

    def _matches_filters(
        self,
        meta: Dict[str, Any],
        doc_text: str,
        where_filters: Optional[Dict[str, Any]],
        where_document: Optional[Dict[str, Any]],
    ) -> bool:
        """Mirror Chroma's filtering for the BM25 corpus.

        Supports exact-match metadata equality and the {"$contains": ...}
        document operator — the only forms callers use here.
        """
        if where_filters:
            for key, expected in where_filters.items():
                if meta.get(key) != expected:
                    return False
        if where_document:
            # you only use {"$contains": "..."}
            contains = where_document.get("$contains")
            if contains and contains.lower() not in (doc_text or "").lower():
                return False
        return True

    def _rrf_fuse(
        self,
        dense_ranked: List[Dict[str, Any]],
        sparse_ranked: List[Dict[str, Any]],
        k: int = 60,
        w_dense: float = 0.6,
        w_sparse: float = 0.4,
    ) -> List[Dict[str, Any]]:
        """Reciprocal Rank Fusion of the dense and sparse result lists.

        score(doc) = w_dense/(k+rank_dense) + w_sparse/(k+rank_sparse)

        A document appearing in both lists keeps fields from both entries
        (e.g. the dense 'distance' AND the sparse 'bm25_score'), instead of
        silently dropping whichever list came second.
        """
        scores: Dict[str, Dict[str, Any]] = {}

        def _accumulate(ranked: List[Dict[str, Any]], weight: float) -> None:
            # enumerate is 0-based; +1 matches the standard RRF formula.
            for rank, item in enumerate(ranked):
                entry = scores.setdefault(item["id"], {"score": 0.0, "item": dict(item)})
                entry["score"] += weight / (k + rank + 1)
                # Merge fields from the second occurrence of an overlapping doc.
                for field, value in item.items():
                    if entry["item"].get(field) is None:
                        entry["item"][field] = value

        _accumulate(dense_ranked, w_dense)
        _accumulate(sparse_ranked, w_sparse)
        fused = sorted(scores.values(), key=lambda x: x["score"], reverse=True)
        return [x["item"] for x in fused]

    # ----------------- Main retrieval -----------------
    def retrieve_documents(
        self,
        question: str,
        n_results: int = 5,
        where_filters: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        enable_bm25: bool = False,
        bm25_k: Optional[int] = None,
        alpha: float = 0.6,  # dense weight in hybrid fusion
    ) -> List[Dict[str, Any]]:
        """Retrieve documents semantically, or hybrid semantic + BM25 (RRF).

        Args:
            question: natural-language query.
            n_results: maximum number of documents to return.
            where_filters: Chroma metadata equality filter.
            where_document: Chroma document filter ({"$contains": ...}).
            enable_bm25: fuse BM25 results with the dense results when True.
            bm25_k: BM25 candidate pool size (defaults to n_results).
            alpha: dense weight in the fusion; sparse weight is 1 - alpha.

        Returns:
            List of dicts: {id, text, metadata, distance, source,
            bm25_score (hybrid only)}.

        Raises:
            RuntimeError: if the manager was never successfully initialized.
        """
        # getattr guard: if _initialize() failed partway, 'collection' may not
        # exist at all; raise the intended RuntimeError, not AttributeError.
        if not self._initialized or getattr(self, "collection", None) is None:
            raise RuntimeError("VectorStoreManager not properly initialized")
        logger.info(f"Retrieving documents for query: {question[:50]}...")

        dense_k = n_results
        # 'is None' (not 'or') so an explicit bm25_k=0 is honored.
        bm25_k = n_results if bm25_k is None else bm25_k

        # ----- Dense retrieval (semantic via Chroma) -----
        try:
            dense_res = self.collection.query(
                query_texts=[question],
                n_results=dense_k,
                include=["documents", "metadatas", "distances"],
                where=where_filters if where_filters else None,
                where_document=where_document if where_document else None,
            )
        except Exception as e:
            logger.error(f"Dense retrieval failed: {str(e)}")
            raise

        dense_ranked: List[Dict[str, Any]] = []
        if dense_res and dense_res.get("documents") and dense_res["documents"][0]:
            for i in range(len(dense_res["documents"][0])):
                meta = dense_res["metadatas"][0][i]
                dense_ranked.append({
                    "id": dense_res["ids"][0][i],
                    "text": dense_res["documents"][0][i],
                    "metadata": meta,
                    "distance": float(dense_res["distances"][0][i]),
                    "source": meta.get("source", "Unknown"),
                })

        # Fall back to semantic-only when BM25 is disabled, or when the index
        # could not be built because the collection was empty.
        if not enable_bm25 or self.bm25 is None:
            logger.info(f"Semantic-only retrieved {len(dense_ranked)} docs")
            return dense_ranked

        # ----- Sparse retrieval (BM25) -----
        q_tokens = self._tokenize(question)
        scores = self.bm25.get_scores(q_tokens)
        # Apply the same filters Chroma used, so both rankings draw from
        # the same candidate set.
        valid_scores = [
            (idx, scores[idx])
            for idx, (doc, meta) in enumerate(zip(self.all_docs, self.all_metas))
            if self._matches_filters(meta, doc, where_filters, where_document)
        ]
        valid_scores.sort(key=lambda pair: pair[1], reverse=True)
        top_sparse = valid_scores[:bm25_k]

        sparse_ranked: List[Dict[str, Any]] = []
        for idx, score in top_sparse:
            meta = self.all_metas[idx]
            sparse_ranked.append({
                "id": self.all_ids[idx],
                "text": self.all_docs[idx],
                "metadata": meta,
                "bm25_score": float(score),
                "distance": None,  # may be absent if not in dense top-k
                "source": meta.get("source", "Unknown"),
            })

        # ----- Fuse dense + sparse -----
        fused = self._rrf_fuse(
            dense_ranked,
            sparse_ranked,
            w_dense=alpha,
            w_sparse=1.0 - alpha,
        )
        logger.info(
            f"Hybrid retrieved dense={len(dense_ranked)} sparse={len(sparse_ranked)} "
            f"fused={len(fused)}"
        )
        # Honor the n_results contract in hybrid mode too: the fused list can
        # otherwise contain up to dense_k + bm25_k unique documents.
        return fused[:n_results]
def get_vector_store() -> VectorStoreManager:
    """FastAPI dependency that provides the process-wide VectorStoreManager.

    The constructor is a singleton, so every injection site receives the
    same underlying instance.
    """
    manager = VectorStoreManager()
    return manager