File size: 13,183 Bytes
50fcf88
 
 
 
 
 
 
 
 
 
 
 
 
 
b5d7f6d
50fcf88
 
 
 
 
 
 
 
 
 
 
 
b5d7f6d
 
 
50fcf88
 
b5d7f6d
50fcf88
 
 
b5d7f6d
 
 
50fcf88
 
 
 
 
 
 
 
b5d7f6d
50fcf88
b5d7f6d
 
 
 
 
 
50fcf88
 
 
 
b5d7f6d
 
 
 
 
 
 
 
 
 
50fcf88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a7fd26
50fcf88
 
 
 
 
 
 
 
 
2a7fd26
 
50fcf88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c63c58
50fcf88
 
8c63c58
 
 
50fcf88
 
 
 
 
 
 
 
 
 
 
b5d7f6d
50fcf88
 
 
 
 
b5d7f6d
50fcf88
 
 
 
 
 
 
b5d7f6d
8c63c58
 
 
 
 
 
 
 
 
 
b5d7f6d
 
 
 
 
 
50fcf88
 
b5d7f6d
 
 
 
 
 
 
8c63c58
 
 
 
 
 
 
 
 
 
 
 
50fcf88
 
 
 
b5d7f6d
50fcf88
 
 
 
 
 
 
 
 
 
 
 
 
b5d7f6d
50fcf88
 
 
 
 
 
 
 
 
 
b5d7f6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50fcf88
 
8c63c58
 
 
 
 
 
 
 
 
 
 
 
 
50fcf88
 
8c63c58
b5d7f6d
50fcf88
 
 
 
8c63c58
50fcf88
 
 
 
 
 
 
b5d7f6d
50fcf88
 
 
8c63c58
50fcf88
 
 
 
 
b5d7f6d
 
8c63c58
 
 
 
50fcf88
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""
Retriever indexer module for DocChat.

Provides utilities for building different types of retrievers:
- Vector-based retriever (ChromaDB + embeddings)
- Hybrid retriever (BM25 + Vector with ensemble)
"""
import logging
import sys
from typing import List, Any
import time
import hashlib
import os
import json
import threading

from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_chroma import Chroma
from langchain_community.retrievers import BM25Retriever
from langchain_google_genai import GoogleGenerativeAIEmbeddings
from langchain_core.vectorstores import VectorStoreRetriever
from configuration.parameters import parameters

logger = logging.getLogger(__name__)

# Thread lock for manifest file access
_manifest_lock = threading.Lock()


def doc_id(doc) -> str:
    """Generate a unique ID for a document based on source, page, chunk_id, and content hash."""
    src = doc.metadata.get("source", "")
    page = doc.metadata.get("page", "")
    chunk = doc.metadata.get("chunk_id", "")
    # Include content hash to ensure uniqueness even if chunk_id is missing
    content = hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()[:16]
    base = f"{src}::{page}::{chunk}::{content}"
    return hashlib.sha256(base.encode("utf-8")).hexdigest()


def content_hash(doc) -> str:
    return hashlib.sha256(doc.page_content.encode("utf-8")).hexdigest()


def load_manifest(path):
    """Thread-safe manifest loading."""
    if os.path.exists(path):
        try:
            with open(path, "r") as f:
                return json.load(f)
        except (json.JSONDecodeError, IOError) as e:
            logger.warning(f"Failed to load manifest, starting fresh: {e}")
            return {}
    return {}


def save_manifest(path, manifest):
    """Thread-safe manifest saving with atomic write."""
    temp_path = path + ".tmp"
    try:
        with open(temp_path, "w") as f:
            json.dump(manifest, f)
        os.replace(temp_path, path)  # Atomic rename
    except Exception as e:
        logger.error(f"Failed to save manifest: {e}")
        if os.path.exists(temp_path):
            os.remove(temp_path)


class EnsembleRetriever(BaseRetriever):
    """
    Custom Ensemble Retriever combining multiple retrievers with weighted RRF.
    
    Attributes:
        retrievers: List of retriever instances
        weights: List of weights (should sum to 1.0)
        c: RRF constant (default: 60)
        k: Max documents to return (default: 10)
    """
    
    retrievers: List[Any]
    weights: List[float]
    c: int = 60
    k: int = 10
    
    class Config:
        arbitrary_types_allowed = True
    
    def _get_relevant_documents(
        self, 
        query: str, 
        *, 
        run_manager: CallbackManagerForRetrieverRun = None
    ) -> List[Document]:
        """Retrieve and combine documents using weighted RRF, deduplicating charts by doc_id and aggregating page numbers."""
        logger.debug(f"[ENSEMBLE] Query: {query[:80]}...")
        all_docs_with_scores = {}
        retriever_names = ["BM25", "Vector"]
        for idx, (retriever, weight) in enumerate(zip(self.retrievers, self.weights)):
            retriever_name = retriever_names[idx] if idx < len(retriever_names) else f"Retriever_{idx}"
            try:
                docs = retriever.invoke(query)
                logger.debug(f"[ENSEMBLE] {retriever_name}: {len(docs)} docs (weight: {weight})")
                for rank, doc in enumerate(docs):
                    # Deduplicate by doc_id only
                    doc_key = doc_id(doc)
                    rrf_score = weight / (rank + 1 + self.c)
                    if doc_key in all_docs_with_scores:
                        existing_doc, existing_score = all_docs_with_scores[doc_key]
                        # Aggregate page numbers
                        existing_pages = set()
                        if isinstance(existing_doc.metadata.get('page'), list):
                            existing_pages.update(existing_doc.metadata['page'])
                        else:
                            existing_pages.add(existing_doc.metadata.get('page'))
                        existing_pages.add(doc.metadata.get('page'))
                        # Update metadata to include all pages
                        existing_doc.metadata['page'] = sorted(p for p in existing_pages if p is not None)
                        all_docs_with_scores[doc_key] = (existing_doc, existing_score + rrf_score)
                    else:
                        all_docs_with_scores[doc_key] = (doc, rrf_score)
            except Exception as e:
                logger.warning(f"[ENSEMBLE] {retriever_name} failed: {e}")
                continue
        sorted_docs = sorted(all_docs_with_scores.values(), key=lambda x: x[1], reverse=True)
        result = [doc for doc, score in sorted_docs[:self.k]]
        logger.debug(f"[ENSEMBLE] Returning {len(result)} documents")
        return result


class RetrieverBuilder:
    """Builder class for creating document retrievers with caching."""
    
    def __init__(self):
        """Initialize with embeddings model."""
        self.embeddings = GoogleGenerativeAIEmbeddings(
            model="models/text-embedding-004",
            google_api_key=parameters.GOOGLE_API_KEY,
            batch_size=100,  # Increased from 32 to 100 for 3× faster embedding (Google supports up to 100)
        )
        self._retriever_cache = {}  # {docset_hash: retriever}
        self._bm25_cache = {}  # {docset_hash: bm25_retriever} - NEW: Cache BM25 retrievers
        self._vector_store_cache = {}  # {chroma_dir: vector_store} - NEW: Reuse ChromaDB connections
        logger.debug("RetrieverBuilder initialized with caching enabled")

    def _hash_docs(self, docs):
        # Create a hash of all document contents and metadata
        m = hashlib.sha256()
        for doc in docs:
            m.update(doc.page_content.encode('utf-8'))
            for k, v in sorted(doc.metadata.items()):
                m.update(str(k).encode('utf-8'))
                m.update(str(v).encode('utf-8'))
        return m.hexdigest()

    def build_hybrid_retriever(self, docs, session_id: str = None) -> EnsembleRetriever:
        """
        Build hybrid retriever using BM25 and vector search.
        
        Args:
            docs: List of documents to index
            session_id: Optional session ID for user isolation (recommended for multi-user)
            
        Returns:
            EnsembleRetriever combining BM25 and vector search
        """
        logger.info(f"Building hybrid retriever with {len(docs)} documents...")
        if not docs:
            raise ValueError("No documents provided")
        
        # Generate cache key from document content hashes
        cache_key = self._hash_docs(docs)
        
        # Check retriever cache first (10-200× speedup for repeat queries)
        if cache_key in self._retriever_cache:
            logger.info(f"✅ Using cached retriever for docset {cache_key[:8]}... (CACHE HIT)")
            return self._retriever_cache[cache_key]
        
        logger.debug(f"Cache miss for docset {cache_key[:8]}..., building new retriever")
        
        # Use session-specific directory if provided (for multi-user isolation)
        if session_id:
            chroma_dir = os.path.join(parameters.CHROMA_DB_PATH, f"session_{session_id}")
        else:
            chroma_dir = parameters.CHROMA_DB_PATH
            
        manifest_path = os.path.join(chroma_dir, "indexed_manifest.json")
        os.makedirs(chroma_dir, exist_ok=True)
        
        # Thread-safe manifest access
        with _manifest_lock:
            manifest = load_manifest(manifest_path)
        
        t_vector_start = time.time()
        
        # Check vector store cache (reuse ChromaDB connections)
        if chroma_dir in self._vector_store_cache:
            logger.debug(f"Reusing cached vector store connection for {chroma_dir}")
            vector_store = self._vector_store_cache[chroma_dir]
        else:
            vector_store = Chroma(
                embedding_function=self.embeddings,
                persist_directory=chroma_dir,
            )
            self._vector_store_cache[chroma_dir] = vector_store
            logger.debug(f"Created new vector store connection for {chroma_dir}")

        to_add = []
        ids_to_add = []
        to_delete_ids = []
        current_ids = set()
        
        for d in docs:
            _id = doc_id(d)
            _hash = content_hash(d)
            current_ids.add(_id)
            if _id not in manifest:
                to_add.append(d)
                ids_to_add.append(_id)
                manifest[_id] = _hash
            elif manifest[_id] != _hash:
                to_delete_ids.append(_id)
                to_add.append(d)
                ids_to_add.append(_id)
                manifest[_id] = _hash             
        
        if to_add:
            # Safety net: de-dupe before add_documents
            seen = set()
            uniq_docs, uniq_ids = [], []
            for doc, _id in zip(to_add, ids_to_add):
                if _id in seen:
                    continue
                seen.add(_id)
                uniq_docs.append(doc)
                uniq_ids.append(_id)
            
            # Log duplicate count for debugging
            dupe_count = len(to_add) - len(uniq_docs)
            if dupe_count > 0:
                logger.debug(f"Filtered {dupe_count} duplicate documents before indexing")
            
            # Batch add documents for better performance
            logger.info(f"[PROFILE] Adding {len(uniq_docs)} new documents to vector store...")
            t_add_start = time.time()
            
            # Add in batches for progress tracking and memory efficiency
            batch_size = 100
            for i in range(0, len(uniq_docs), batch_size):
                batch_docs = uniq_docs[i:i+batch_size]
                batch_ids = uniq_ids[i:i+batch_size]
                vector_store.add_documents(batch_docs, ids=batch_ids)
                if len(uniq_docs) > batch_size:
                    logger.debug(f"[PROFILE] Indexed batch {i//batch_size + 1}/{(len(uniq_docs)-1)//batch_size + 1}")
            
            t_add_end = time.time()
            logger.info(f"[PROFILE] Vector store add_documents: {t_add_end - t_add_start:.2f}s")
        
        t_vector_end = time.time()
        logger.info(f"[PROFILE] Total vector store setup: {t_vector_end - t_vector_start:.2f}s")
        
        # Thread-safe manifest save
        with _manifest_lock:
            save_manifest(manifest_path, manifest)
        
        # Create BM25 retriever
        t_bm25_start = time.time()
        
        # Check BM25 cache (avoid rebuilding for same documents)
        if cache_key in self._bm25_cache:
            logger.debug(f"Reusing cached BM25 retriever for docset {cache_key[:8]}...")
            bm25_retriever = self._bm25_cache[cache_key]
        else:
            texts = [doc.page_content for doc in docs]
            metadatas = [doc.metadata for doc in docs]
            bm25_retriever = BM25Retriever.from_texts(texts=texts, metadatas=metadatas)
            bm25_retriever.k = parameters.BM25_SEARCH_K
            self._bm25_cache[cache_key] = bm25_retriever
            logger.debug(f"Created new BM25 retriever for docset {cache_key[:8]}...")
        
        t_bm25_end = time.time()
        logger.info(f"[PROFILE] BM25 retriever creation: {t_bm25_end - t_bm25_start:.2f}s")
        logger.debug(f"BM25 indexed {len(docs)} texts, k={bm25_retriever.k}")
        
        t_vec_retr_start = time.time()
        vector_retriever = vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={
                "k": parameters.VECTOR_SEARCH_K_CHROMA,
                "fetch_k": parameters.VECTOR_FETCH_K,
                "lambda_mult": 0.7,
            },
        )
        t_vec_retr_end = time.time()
        logger.info(f"[PROFILE] Vector retriever creation: {t_vec_retr_end - t_vec_retr_start:.2f}s")
        logger.debug("Vector retriever created")
        
        t_ensemble_start = time.time()
        hybrid_retriever = EnsembleRetriever(
            retrievers=[bm25_retriever, vector_retriever],
            weights=parameters.HYBRID_RETRIEVER_WEIGHTS,                                                                
            k=parameters.VECTOR_SEARCH_K,
        )
        t_ensemble_end = time.time()
        logger.info(f"[PROFILE] Ensemble retriever creation: {t_ensemble_end - t_ensemble_start:.2f}s")
        logger.info(f"Hybrid retriever created (k={parameters.VECTOR_SEARCH_K})")
        logger.info(f"[PROFILE] Total hybrid retriever build: {t_ensemble_end - t_vector_start:.2f}s")
        
        # Cache the complete retriever for future use
        self._retriever_cache[cache_key] = hybrid_retriever
        logger.debug(f"Cached retriever for docset {cache_key[:8]}... (future requests will be instant)")
        
        return hybrid_retriever