import numpy as np
import time
import re
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from typing import Optional, List

# Changed MMR to accept final_k as a parameter; it was previously hardcoded to 3.
# --@Qamare

# FlashRank import kept for reference (currently disabled): the reranker below
# uses Voyage AI when VOYAGE_API_KEY is set, otherwise a sentence-transformers
# CrossEncoder.
# try:
#     from flashrank import Ranker, RerankRequest
#     FLASHRANK_AVAILABLE = True
# except ImportError:
#     from sentence_transformers import CrossEncoder
#     FLASHRANK_AVAILABLE = False

class HybridRetriever:
    def __init__(self, embed_model, rerank_model_name='jinaai/jina-reranker-v1-tiny-en', verbose: bool = True):
        import sys
        import os
        print(f"[DEBUG-HybridRetriever] Starting init", flush=True)
        self.embed_model = embed_model
        self.verbose = verbose
        self.rerank_model_name = self._normalize_rerank_model_name(rerank_model_name)
        print(f"[DEBUG-HybridRetriever] Rerank model name: {self.rerank_model_name}", flush=True)

        self.vo_client = None
        self.ce_reranker = None
        self.reranker_backend = "cross-encoder"

        voyage_api_key = os.getenv("VOYAGE_API_KEY")
        if voyage_api_key:
            try:
                import voyageai
                self.vo_client = voyageai.Client(api_key=voyage_api_key)
                self.reranker_backend = "voyageai"
                # Voyage uses model IDs like rerank-2.5; keep a safe default.
                if not self.rerank_model_name.startswith("rerank-"):
                    self.rerank_model_name = "rerank-2.5"
                print(f"[DEBUG-HybridRetriever] Voyage AI client initialized", flush=True)
            except Exception as exc:
                print(f"[DEBUG-HybridRetriever] Voyage unavailable ({exc}); falling back to cross-encoder", flush=True)

        if self.vo_client is None:
            from sentence_transformers import CrossEncoder
            ce_model_name = self.rerank_model_name
            if not ce_model_name.startswith("cross-encoder/"):
                ce_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
            self.ce_reranker = CrossEncoder(ce_model_name)
            self.rerank_model_name = ce_model_name
            self.reranker_backend = "cross-encoder"
            print(f"[DEBUG-HybridRetriever] Cross-encoder reranker initialized: {ce_model_name}", flush=True)
        
        sys.stdout.flush()
        print(f"[DEBUG-HybridRetriever] Init complete", flush=True)

    def _normalize_rerank_model_name(self, model_name: str) -> str:
        """Coerce a bare model name into a fully qualified cross-encoder ID."""
        normalized = (model_name or "").strip()
        if not normalized:
            return "cross-encoder/ms-marco-MiniLM-L-6-v2"
        if "/" in normalized:
            return normalized
        return f"cross-encoder/{normalized}"

    def _tokenize(self, text: str) -> List[str]:
        """Tokenize text using regex to strip punctuation."""
        return re.findall(r'\w+', text.lower())

    # Helper methods: build an index of chunks by their chunking_technique
    # metadata, and normalize the chunking_technique parameter.
    def _build_chunking_index_map(self) -> dict[str, List[int]]:
        """Map each chunking technique to the indices of its chunks.

        Expects self.final_chunks (a list of dicts with a 'metadata' key,
        populated elsewhere) to be set before this is called.
        """
        mapping: dict[str, List[int]] = {}
        for idx, chunk in enumerate(self.final_chunks):
            metadata = chunk.get('metadata', {})
            technique = (metadata.get('chunking_technique') or '').strip().lower()
            if not technique:
                continue
            mapping.setdefault(technique, []).append(idx)
        return mapping
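    # Illustrative result: if self.final_chunks[0] and [2] carry
    # chunking_technique "markdown" and [1] carries "semantic", the map is
    # {"markdown": [0, 2], "semantic": [1]}.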

    def _normalize_chunking_technique(self, chunking_technique: Optional[str]) -> Optional[str]:
        """Lowercase and trim the technique name; treat "all"/"any"/"*"/"none" as no filter."""
        if not chunking_technique:
            return None
        normalized = str(chunking_technique).strip().lower()
        if not normalized or normalized in {"all", "any", "*", "none"}:
            return None
        return normalized

    # ------------------------------------------------------------------
    # Retrieval
    # ------------------------------------------------------------------

    def _semantic_search(self, query, index, top_k, technique_name: Optional[str] = None) -> tuple[np.ndarray, List[str]]:
        """Dense retrieval against the Pinecone index, optionally filtered by chunking technique."""
        query_vector = self.embed_model.encode(query)
        query_kwargs = {
            "vector": query_vector.tolist(),
            "top_k": top_k,
            "include_metadata": True,
        }
        if technique_name:
            query_kwargs["filter"] = {"chunking_technique": {"$eq": technique_name}}

        res = index.query(**query_kwargs)
        chunks = [match['metadata']['text'] for match in res['matches']]
        return query_vector, chunks

    def _bm25_search(self, query, index, top_k=50, technique_name: Optional[str] = None) -> List[str]:
        """BM25 retrieval via the Pinecone sparse index, with a local fallback.

        If the sparse query fails, fetch candidates from the dense index and
        rank them locally with rank_bm25.
        """
        try:
            import os
            from pinecone import Pinecone
            from pinecone_text.sparse import BM25Encoder
            encoder = BM25Encoder().default()
            pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
            sparse_index = pc.Index("cbt-book-sparse")
            sparse_vector = encoder.encode_queries(query)
            query_kwargs = {
                "sparse_vector": sparse_vector,
                "top_k": top_k,
                "include_metadata": True,
            }
            if technique_name:
                query_kwargs["filter"] = {"chunking_technique": {"$eq": technique_name}}

            res = sparse_index.query(**query_kwargs)
            return [match["metadata"]["text"] for match in res["matches"]]
        except Exception as e:
            print(f"Error in BM25 search against Pinecone: {e}; falling back to local BM25 ranking")

        # Fallback: fetch chunks from the dense index and rank them locally.
        # Fetch more candidates than needed so BM25 has something to rank
        # against, capped at 25 to avoid over-fetching.
        fetch_limit = min(top_k * 4, 25)
        query_kwargs = {
            "vector": [0.0] * 512,  # Dummy vector (BM25 doesn't use embeddings); 512 must match the index dimension
            "top_k": fetch_limit,
            "include_metadata": True,
        }
        if technique_name:
            query_kwargs["filter"] = {"chunking_technique": {"$eq": technique_name}}
        res = index.query(**query_kwargs)

        # Extract chunks
        chunks = [match['metadata']['text'] for match in res['matches']]
        if not chunks:
            return []

        # Build BM25 index on these chunks
        tokenized_corpus = [self._tokenize(chunk) for chunk in chunks]
        bm25 = BM25Okapi(tokenized_corpus)

        # Score query against chunks and return the top_k
        tokenized_query = self._tokenize(query)
        scores = bm25.get_scores(tokenized_query)
        top_indices = np.argsort(scores)[::-1][:top_k]
        return [chunks[i] for i in top_indices]

    # ------------------------------------------------------------------
    # Fusion
    # ------------------------------------------------------------------

    def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
        """Fuse two ranked lists with Reciprocal Rank Fusion (RRF).

        Each chunk scores the sum of 1 / (k + rank + 1) over the lists it
        appears in; k=60 is the conventional RRF smoothing constant.
        """
        scores = {}
        for rank, chunk in enumerate(semantic_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
        for rank, chunk in enumerate(bm25_results):
            scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
        return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
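    # Worked example of the fusion above (illustrative numbers): with k=60, a
    # chunk ranked 1st semantically (rank 0) and 3rd by BM25 (rank 2) scores
    # 1/61 + 1/63 = 0.0164 + 0.0159 = 0.0323, while a chunk ranked 1st in only
    # one list scores 1/61 = 0.0164, so appearing in both lists dominates.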

    # ------------------------------------------------------------------
    # Reranking
    # ------------------------------------------------------------------

    def _cross_encoder_rerank(self, query, chunks, final_k) -> tuple[List[str], List[float]]:
        """Rerank chunks with Voyage AI when configured, else the local cross-encoder."""
        if not chunks:
            return [], []

        if self.vo_client is not None:
            reranking = self.vo_client.rerank(query, chunks, model=self.rerank_model_name, top_k=final_k)
            ranked_chunks = [result.document for result in reranking.results]
            ranked_scores = [result.relevance_score for result in reranking.results]
            return ranked_chunks, ranked_scores

        pairs = [[query, chunk] for chunk in chunks]
        scores = self.ce_reranker.predict(pairs)
        ranked_indices = np.argsort(scores)[::-1][:final_k]
        ranked_chunks = [chunks[i] for i in ranked_indices]
        ranked_scores = [float(scores[i]) for i in ranked_indices]
        return ranked_chunks, ranked_scores

    # ------------------------------------------------------------------
    # MMR (applied after reranking as a diversity filter)
    # ------------------------------------------------------------------

    def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=10) -> List[str]:
        """
        Maximum Marginal Relevance (MMR) for diversity filtering.
        
        DIVISION BY ZERO DEBUGGING:
        - This method can cause division by zero in cosine_similarity if vectors are zero
        - We've added multiple safeguards to prevent this
        """
        print(f"    [MMR DEBUG] Starting MMR with {len(chunks)} chunks, top_k={top_k}")
        
        if not chunks:
            print(f"    [MMR DEBUG] No chunks, returning empty list")
            return []
        
        # STEP 1: Encode chunks to get embeddings
        print(f"    [MMR DEBUG] Encoding {len(chunks)} chunks...")
        try:
            chunk_embeddings = np.array([self.embed_model.encode(c) for c in chunks])
            print(f"    [MMR DEBUG] Chunk embeddings shape: {chunk_embeddings.shape}")
        except Exception as e:
            print(f"    [MMR DEBUG] ERROR encoding chunks: {e}")
            return chunks[:top_k]
        
        # STEP 2: Reshape query vector
        query_embedding = query_vector.reshape(1, -1)
        print(f"    [MMR DEBUG] Query embedding shape: {query_embedding.shape}")
        
        # STEP 3: Check for zero vectors (POTENTIAL DIVISION BY ZERO SOURCE)
        print(f"    [MMR DEBUG] Checking for zero vectors...")
        query_norm = np.linalg.norm(query_embedding)
        chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
        
        print(f"    [MMR DEBUG] Query norm: {query_norm}")
        print(f"    [MMR DEBUG] Chunk norms min: {chunk_norms.min()}, max: {chunk_norms.max()}")
        
        # Check for zero or near-zero vectors
        if query_norm < 1e-10 or np.any(chunk_norms < 1e-10):
            print(f"    [MMR DEBUG] WARNING: Zero or near-zero vectors detected!")
            print(f"    [MMR DEBUG] Query norm < 1e-10: {query_norm < 1e-10}")
            print(f"    [MMR DEBUG] Any chunk norm < 1e-10: {np.any(chunk_norms < 1e-10)}")
            print(f"    [MMR DEBUG] Falling back to simple selection without MMR")
            return chunks[:top_k]
        
        # STEP 4: Compute relevance scores (POTENTIAL DIVISION BY ZERO SOURCE)
        print(f"    [MMR DEBUG] Computing relevance scores with cosine_similarity...")
        try:
            relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
            print(f"    [MMR DEBUG] Relevance scores computed successfully")
            print(f"    [MMR DEBUG] Relevance scores shape: {relevance_scores.shape}")
            print(f"    [MMR DEBUG] Relevance scores min: {relevance_scores.min()}, max: {relevance_scores.max()}")
        except Exception as e:
            print(f"    [MMR DEBUG] ERROR computing relevance scores: {e}")
            print(f"    [MMR DEBUG] Falling back to simple selection")
            return chunks[:top_k]
        
        # STEP 5: Initialize selection
        selected, unselected = [], list(range(len(chunks)))
        
        first = int(np.argmax(relevance_scores))
        selected.append(first)
        unselected.remove(first)
        print(f"    [MMR DEBUG] Selected first chunk: index {first}")
        
        # STEP 6: Iteratively select chunks using MMR
        print(f"    [MMR DEBUG] Starting MMR iteration...")
        iteration = 0
        while len(selected) < min(top_k, len(chunks)):
            iteration += 1
            print(f"    [MMR DEBUG] Iteration {iteration}: selected={len(selected)}, unselected={len(unselected)}")
            
            # Calculate MMR scores
            mmr_scores = []
            for i in unselected:
                # Compute max similarity to already selected items
                max_sim = -1
                for s in selected:
                    try:
                        # POTENTIAL DIVISION BY ZERO SOURCE: cosine_similarity
                        sim = cosine_similarity(
                            chunk_embeddings[i].reshape(1, -1),
                            chunk_embeddings[s].reshape(1, -1)
                        )[0][0]
                        max_sim = max(max_sim, sim)
                    except Exception as e:
                        print(f"    [MMR DEBUG] ERROR computing similarity between chunk {i} and {s}: {e}")
                        # If similarity computation fails, use 0
                        max_sim = max(max_sim, 0)
                
                mmr_score = lambda_param * relevance_scores[i] - (1 - lambda_param) * max_sim
                mmr_scores.append((i, mmr_score))
            
            # Select chunk with highest MMR score
            if mmr_scores:
                best, best_score = max(mmr_scores, key=lambda x: x[1])
                selected.append(best)
                unselected.remove(best)
                print(f"    [MMR DEBUG] Selected chunk {best} with MMR score {best_score:.4f}")
            else:
                print(f"    [MMR DEBUG] No MMR scores computed, breaking")
                break
        
        print(f"    [MMR DEBUG] MMR complete. Selected {len(selected)} chunks")
        return [chunks[i] for i in selected]
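    # Illustrative MMR step with lambda_param=0.5: a candidate with relevance
    # 0.80 whose closest selected neighbor has similarity 0.90 scores
    # 0.5*0.80 - 0.5*0.90 = -0.05, losing to a less relevant (0.70) but more
    # novel (max similarity 0.30) candidate scoring 0.5*0.70 - 0.5*0.30 = 0.20.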

    # ------------------------------------------------------------------
    # Main search
    # ------------------------------------------------------------------

    def search(self, query, index, top_k=50, final_k=5, mode="hybrid",
               rerank_strategy="cross-encoder", use_mmr=False, lambda_param=0.5,
               technique_name: Optional[str] = None,
               chunking_technique: Optional[str] = None,
               verbose: Optional[bool] = None, test: bool = False) -> tuple[List[str], float]:
        """
        :param mode:             "semantic", "bm25", or "hybrid"
        :param rerank_strategy:  "cross-encoder", "rrf", or "none"
        :param use_mmr:          Whether to apply MMR diversity filter after reranking
        :param lambda_param:     MMR trade-off between relevance (1.0) and diversity (0.0)
        :param technique_name:   Chunking technique to filter by (default: "markdown")
        :returns:                Tuple of (ranked_chunks, avg_chunk_score)
        """
        should_print = verbose if verbose is not None else self.verbose
        requested_technique = self._normalize_chunking_technique(chunking_technique or technique_name)
        total_start = time.perf_counter()
        semantic_time = 0.0
        bm25_time = 0.0
        rerank_time = 0.0
        mmr_time = 0.0
        
        if should_print:
            self._print_search_header(query, mode, rerank_strategy, top_k, final_k)
            if requested_technique:
                print(f"Chunking Filter: {requested_technique}")

        # 1. Retrieve candidates
        query_vector = None
        semantic_chunks, bm25_chunks = [], []

        if mode in ["semantic", "hybrid"]:
            semantic_start = time.perf_counter()
            query_vector, semantic_chunks = self._semantic_search(query, index, top_k, requested_technique)
            semantic_time = time.perf_counter() - semantic_start
            print(f"[DEBUG-FLOW] retrieved {len(semantic_chunks)} chunks from semantic search", flush=True)
            if should_print:
                self._print_candidates("Semantic Search", semantic_chunks)
                print(f"Semantic time: {semantic_time:.3f}s")

        if mode in ["bm25", "hybrid"]:
            bm25_start = time.perf_counter()
            bm25_chunks = self._bm25_search(query, index, top_k, requested_technique)
            bm25_time = time.perf_counter() - bm25_start
            print(f"[DEBUG-FLOW] retrieved {len(bm25_chunks)} chunks from BM25 search", flush=True)
            if should_print:
                self._print_candidates("BM25 Search", bm25_chunks)
                print(f"BM25 time: {bm25_time:.3f}s")
                print("All BM25 results:")
                for i, chunk in enumerate(bm25_chunks):
                    print(f"  [{i}] {chunk[:200]}..." if len(chunk) > 200 else f"  [{i}] {chunk}")

        # 2. Fuse / rerank
        rerank_start = time.perf_counter()
        chunk_scores = []
        if rerank_strategy == "rrf":
            candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
            label = "RRF"
        elif rerank_strategy == "cross-encoder":
            combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
            print(f"[DEBUG-FLOW] {len(combined)} unique chunks went into cross-encoder", flush=True)
            candidates, chunk_scores = self._cross_encoder_rerank(query, combined, final_k)
            print(f"[DEBUG-FLOW] {len(candidates)} chunks got out of cross-encoder", flush=True)
            label = "Cross-Encoder"
        elif rerank_strategy == "voyage":
            import voyageai
            voyage_client = voyageai.Client()
            combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
            print(f"[DEBUG-FLOW] {len(combined)} unique chunks went into voyage reranker", flush=True)
            if not combined:
                candidates, chunk_scores = [], []
            else:
                try:
                    reranking = voyage_client.rerank(query=query, documents=combined, model=self.rerank_model_name, top_k=final_k)
                    candidates = [r.document for r in reranking.results]
                    chunk_scores = [r.relevance_score for r in reranking.results]
                    print(f"[DEBUG-FLOW] {len(candidates)} chunks got out of voyage reranker", flush=True)
                except Exception as e:
                    print(f"Error calling Voyage API: {e}")
                    candidates = combined[:final_k]
                    chunk_scores = []
            label = "Voyage"
        else:  # "none"
            candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
            label = "No Reranking"
        rerank_time = time.perf_counter() - rerank_start
        
        # Compute average chunk score
        avg_chunk_score = float(np.mean(chunk_scores)) if chunk_scores else 0.0

        # 3. MMR diversity filter (applied after reranking)
        if use_mmr and candidates:
            mmr_start = time.perf_counter()
            if query_vector is None:
                query_vector = self.embed_model.encode(query)
            candidates = self._maximal_marginal_relevance(query_vector, candidates,
                                                          lambda_param=lambda_param, top_k=final_k)
            label += " + MMR"
            mmr_time = time.perf_counter() - mmr_start

        # Safety cap: always honor requested final_k regardless of retrieval strategy.
        candidates = candidates[:final_k]

        if test and rerank_strategy != "cross-encoder" and candidates:
            _, test_scores = self._cross_encoder_rerank(query, candidates, len(candidates))
            avg_chunk_score = float(np.mean(test_scores)) if test_scores else 0.0

        total_time = time.perf_counter() - total_start

        if should_print:
            self._print_final_results(candidates, label)
            self._print_timing_summary(semantic_time, bm25_time, rerank_time, mmr_time, total_time)

        return candidates, avg_chunk_score

    # ------------------------------------------------------------------
    # Printing
    # ------------------------------------------------------------------

    def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
        print("\n" + "="*80)
        print(f" SEARCH QUERY: {query}")
        print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
        print(f"Top-K: {top_k} | Final-K: {final_k}")
        print("-" * 80)

    def _print_candidates(self, label, chunks, preview_n=3):
        print(f"{label}: Retrieved {len(chunks)} candidates")
        for i, chunk in enumerate(chunks[:preview_n]):
            preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
            print(f"   [{i}] {preview}")

    def _print_final_results(self, results, strategy_label):
        print(f"\n Final {len(results)} Results ({strategy_label}):")
        for i, chunk in enumerate(results):
            preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
            print(f"   [{i+1}] {preview}")
        print("="*80)

    def _print_timing_summary(self, semantic_time, bm25_time, rerank_time, mmr_time, total_time):
        print(" Retrieval Timing:")
        print(f"   Semantic: {semantic_time:.3f}s")
        print(f"   BM25: {bm25_time:.3f}s")
        print(f"   Rerank/Fusion: {rerank_time:.3f}s")
        print(f"   MMR: {mmr_time:.3f}s")
        print(f"   Total Retrieval: {total_time:.3f}s")