| """ | |
| Retriever indexer module for DocChat. | |
| Provides utilities for building different types of retrievers: | |
| - Vector-based retriever (ChromaDB + embeddings) | |
| - Hybrid retriever (BM25 + Vector with ensemble) | |
| """ | |

import logging
import sys
from typing import List, Any, Optional
import time
import hashlib
import os
import json
import threading

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_core.vectorstores import VectorStoreRetriever

from configuration.parameters import parameters

logger = logging.getLogger(__name__)

# Thread lock for manifest file access
_manifest_lock = threading.Lock()


def doc_id(doc) -> str:
    """Generate a unique ID for a document based on source, page, chunk_id, and content hash."""
    src = doc.metadata.get("source", "")
    page = doc.metadata.get("page", "")
    chunk = doc.metadata.get("chunk_id", "")
    # Include content hash to ensure uniqueness even if chunk_id is missing
    content = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()[:16]
    base = f"{src}::{page}::{chunk}::{content}"
    return hashlib.sha256(base.encode("utf-8")).hexdigest()


def content_hash(doc) -> str:
    """Return the SHA-256 hex digest of the document's page content."""
    return hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()


def load_manifest(path):
    """Load the indexing manifest; callers must hold _manifest_lock for thread safety."""
    if os.path.exists(path):
        try:
            with open(path, "r") as f:
                return json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to load manifest, starting fresh: {e}")
            return {}
    return {}


def save_manifest(path, manifest):
    """Save the manifest atomically (temp file + rename); callers must hold _manifest_lock."""
    temp_path = path + ".tmp"
    try:
        with open(temp_path, "w") as f:
            json.dump(manifest, f)
        os.replace(temp_path, path)  # Atomic rename
    except Exception as e:
        logger.error(f"Failed to save manifest: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)


class EnsembleRetriever(BaseRetriever):
    """
    Custom Ensemble Retriever combining multiple retrievers with weighted RRF.

    Attributes:
        retrievers: List of retriever instances
        weights: List of weights (should sum to 1.0)
        c: RRF constant (default: 60)
        k: Max documents to return (default: 10)
    """
    retrievers: List[Any]
    weights: List[float]
    c: int = 60
    k: int = 10

    class Config:
        arbitrary_types_allowed = True

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: Optional[CallbackManagerForRetrieverRun] = None,
    ) -> List[Document]:
        """Retrieve and combine documents using weighted RRF, deduplicating chunks by doc_id and aggregating page numbers."""
        logger.debug(f"[ENSEMBLE] Query: {query[:80]}...")
        all_docs_with_scores = {}
        retriever_names = ["BM25", "Vector"]
        for idx, (retriever, weight) in enumerate(zip(self.retrievers, self.weights)):
            retriever_name = retriever_names[idx] if idx < len(retriever_names) else f"Retriever_{idx}"
            try:
                docs = retriever.invoke(query)
                logger.debug(f"[ENSEMBLE] {retriever_name}: {len(docs)} docs (weight: {weight})")
                for rank, doc in enumerate(docs):
                    # Deduplicate by doc_id only
                    doc_key = doc_id(doc)
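                    # Weighted Reciprocal Rank Fusion: a doc at 0-based rank r from a
                    # retriever with weight w contributes w / (r + 1 + c); scores for
                    # the same doc_id across retrievers are summed.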
                    rrf_score = weight / (rank + 1 + self.c)
                    if doc_key in all_docs_with_scores:
                        existing_doc, existing_score = all_docs_with_scores[doc_key]
                        # Aggregate page numbers
                        existing_pages = set()
                        if isinstance(existing_doc.metadata.get('page'), list):
                            existing_pages.update(existing_doc.metadata['page'])
                        else:
                            existing_pages.add(existing_doc.metadata.get('page'))
                        existing_pages.add(doc.metadata.get('page'))
                        # Update metadata to include all pages
                        existing_doc.metadata['page'] = sorted(p for p in existing_pages if p is not None)
                        all_docs_with_scores[doc_key] = (existing_doc, existing_score + rrf_score)
                    else:
                        all_docs_with_scores[doc_key] = (doc, rrf_score)
            except Exception as e:
                logger.warning(f"[ENSEMBLE] {retriever_name} failed: {e}")
                continue
        sorted_docs = sorted(all_docs_with_scores.values(), key=lambda x: x[1], reverse=True)
        result = [doc for doc, score in sorted_docs[:self.k]]
        logger.debug(f"[ENSEMBLE] Returning {len(result)} documents")
        return result


class RetrieverBuilder:
    """Builder class for creating document retrievers with caching."""

    def __init__(self):
        """Initialize with embeddings model."""
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=parameters.GOOGLE_API_KEY,
            batch_size=100,  # Google's embedding API accepts batches of up to 100
        )
        self._retriever_cache = {}  # {docset_hash: retriever}
        self._bm25_cache = {}  # {docset_hash: bm25_retriever}
        self._vector_store_cache = {}  # {chroma_dir: vector_store}
        logger.debug("RetrieverBuilder initialized with caching enabled")

    def _hash_docs(self, docs):
        """Create a hash of all document contents and metadata."""
        m = hashlib.sha256()
        for doc in docs:
            m.update(doc.page_content.encode('utf-8'))
            for k, v in sorted(doc.metadata.items()):
                m.update(str(k).encode('utf-8'))
                m.update(str(v).encode('utf-8'))
        return m.hexdigest()

    def build_hybrid_retriever(self, docs, session_id: Optional[str] = None) -> EnsembleRetriever:
        """
        Build hybrid retriever using BM25 and vector search.

        Args:
            docs: List of documents to index
            session_id: Optional session ID for user isolation (recommended for multi-user)

        Returns:
            EnsembleRetriever combining BM25 and vector search
        """
        logger.info(f"Building hybrid retriever with {len(docs)} documents...")
        if not docs:
            raise ValueError("No documents provided")

        # Generate cache key from document content hashes
        cache_key = self._hash_docs(docs)

        # Check retriever cache first (skips a full rebuild for repeat docsets)
        if cache_key in self._retriever_cache:
            logger.info(f"✅ Using cached retriever for docset {cache_key[:8]}... (CACHE HIT)")
            return self._retriever_cache[cache_key]
        logger.debug(f"Cache miss for docset {cache_key[:8]}..., building new retriever")

        # Use session-specific directory if provided (for multi-user isolation)
        if session_id:
            chroma_dir = os.path.join(parameters.CHROMA_DB_PATH, f"session_{session_id}")
        else:
            chroma_dir = parameters.CHROMA_DB_PATH
        manifest_path = os.path.join(chroma_dir, "indexed_manifest.json")
        os.makedirs(chroma_dir, exist_ok=True)

        # Thread-safe manifest access
        with _manifest_lock:
            manifest = load_manifest(manifest_path)

        t_vector_start = time.time()
        # Check vector store cache (reuse ChromaDB connections)
        if chroma_dir in self._vector_store_cache:
            logger.debug(f"Reusing cached vector store connection for {chroma_dir}")
            vector_store = self._vector_store_cache[chroma_dir]
        else:
            vector_store = Chroma(
                embedding_function=self.embeddings,
                persist_directory=chroma_dir,
            )
            self._vector_store_cache[chroma_dir] = vector_store
            logger.debug(f"Created new vector store connection for {chroma_dir}")

        to_add = []
        ids_to_add = []
        to_delete_ids = []
        current_ids = set()
        for d in docs:
            _id = doc_id(d)
            _hash = content_hash(d)
            current_ids.add(_id)
            if _id not in manifest:
                to_add.append(d)
                ids_to_add.append(_id)
                manifest[_id] = _hash
            elif manifest[_id] != _hash:
                # Content changed: remove the stale vector and re-index
                to_delete_ids.append(_id)
                to_add.append(d)
                ids_to_add.append(_id)
                manifest[_id] = _hash
        if to_delete_ids:
            vector_store.delete(ids=to_delete_ids)

        if to_add:
            # Safety net: de-dupe before add_documents
            seen = set()
            uniq_docs, uniq_ids = [], []
            for doc, _id in zip(to_add, ids_to_add):
                if _id in seen:
                    continue
                seen.add(_id)
                uniq_docs.append(doc)
                uniq_ids.append(_id)
            # Log duplicate count for debugging
            dupe_count = len(to_add) - len(uniq_docs)
            if dupe_count > 0:
                logger.debug(f"Filtered {dupe_count} duplicate documents before indexing")

            # Batch add documents for better performance
            logger.info(f"[PROFILE] Adding {len(uniq_docs)} new documents to vector store...")
            t_add_start = time.time()
            # Add in batches for progress tracking and memory efficiency
            batch_size = 100
            for i in range(0, len(uniq_docs), batch_size):
                batch_docs = uniq_docs[i:i+batch_size]
                batch_ids = uniq_ids[i:i+batch_size]
                vector_store.add_documents(batch_docs, ids=batch_ids)
                if len(uniq_docs) > batch_size:
                    logger.debug(f"[PROFILE] Indexed batch {i//batch_size + 1}/{(len(uniq_docs)-1)//batch_size + 1}")
            t_add_end = time.time()
            logger.info(f"[PROFILE] Vector store add_documents: {t_add_end - t_add_start:.2f}s")
        t_vector_end = time.time()
        logger.info(f"[PROFILE] Total vector store setup: {t_vector_end - t_vector_start:.2f}s")

        # Thread-safe manifest save
        with _manifest_lock:
            save_manifest(manifest_path, manifest)

        # Create BM25 retriever
        t_bm25_start = time.time()
        # Check BM25 cache (avoid rebuilding for same documents)
        if cache_key in self._bm25_cache:
            logger.debug(f"Reusing cached BM25 retriever for docset {cache_key[:8]}...")
            bm25_retriever = self._bm25_cache[cache_key]
        else:
            texts = [doc.page_content for doc in docs]
            metadatas = [doc.metadata for doc in docs]
            bm25_retriever = BM25Retriever.from_texts(texts=texts, metadatas=metadatas)
            bm25_retriever.k = parameters.BM25_SEARCH_K
            self._bm25_cache[cache_key] = bm25_retriever
            logger.debug(f"Created new BM25 retriever for docset {cache_key[:8]}...")
        t_bm25_end = time.time()
        logger.info(f"[PROFILE] BM25 retriever creation: {t_bm25_end - t_bm25_start:.2f}s")
        logger.debug(f"BM25 indexed {len(docs)} texts, k={bm25_retriever.k}")

        t_vec_retr_start = time.time()
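        # MMR (maximal marginal relevance) re-ranks fetch_k candidates for a mix
        # of relevance and diversity; lambda_mult=0.7 leans toward relevance
        # (1.0 = pure relevance, 0.0 = maximum diversity).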
        vector_retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": parameters.VECTOR_SEARCH_K_CHROMA,
                "fetch_k": parameters.VECTOR_FETCH_K,
                "lambda_mult": 0.7,
            },
        )
        t_vec_retr_end = time.time()
        logger.info(f"[PROFILE] Vector retriever creation: {t_vec_retr_end - t_vec_retr_start:.2f}s")
        logger.debug("Vector retriever created")

        t_ensemble_start = time.time()
        hybrid_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=parameters.HYBRID_RETRIEVER_WEIGHTS,
            k=parameters.VECTOR_SEARCH_K,
        )
        t_ensemble_end = time.time()
        logger.info(f"[PROFILE] Ensemble retriever creation: {t_ensemble_end - t_ensemble_start:.2f}s")
        logger.info(f"Hybrid retriever created (k={parameters.VECTOR_SEARCH_K})")
        logger.info(f"[PROFILE] Total hybrid retriever build: {t_ensemble_end - t_vector_start:.2f}s")

        # Cache the complete retriever so later builds of this docset hit the cache
        self._retriever_cache[cache_key] = hybrid_retriever
        logger.debug(f"Cached retriever for docset {cache_key[:8]}...")
        return hybrid_retriever
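

if __name__ == "__main__":
    # Minimal usage sketch, not part of the DocChat pipeline. Assumes
    # GOOGLE_API_KEY and CHROMA_DB_PATH are set in configuration.parameters;
    # the documents and query below are hypothetical.
    logging.basicConfig(level=logging.INFO)
    sample_docs = [
        Document(
            page_content="Quarterly revenue grew 12% year over year.",
            metadata={"source": "report.pdf", "page": 1, "chunk_id": 0},
        ),
        Document(
            page_content="Operating costs remained flat across all segments.",
            metadata={"source": "report.pdf", "page": 2, "chunk_id": 1},
        ),
    ]
    builder = RetrieverBuilder()
    retriever = builder.build_hybrid_retriever(sample_docs, session_id="demo")
    for result in retriever.invoke("How did revenue change this quarter?"):
        print(result.metadata.get("page"), result.page_content[:60])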