# NLP-RAG / retriever/retriever.py
# Synced backend Docker context from GitHub main (commit c64aaec, verified) by Qar-Raz.
import numpy as np
import time
import re
from rank_bm25 import BM25Okapi
from sklearn.metrics.pairwise import cosine_similarity
from typing import Optional, List
# NOTE(@Qamare): MMR now takes the number of results to return (final_k) as a
# parameter; it was previously hardcoded to 3.
# Try to import FlashRank for CPU optimization, fallback to sentence-transformers
# try:
# from flashrank import Ranker, RerankRequest
# FLASHRANK_AVAILABLE = True
# except ImportError:
# from sentence_transformers import CrossEncoder
# FLASHRANK_AVAILABLE = False
class HybridRetriever:
    """Hybrid retriever combining dense semantic search and sparse BM25 search,
    with reranking (Voyage AI or a local cross-encoder), RRF fusion, and an
    optional MMR diversity filter.

    NOTE(review): `_build_chunking_index_map` reads ``self.final_chunks``,
    which is never assigned in ``__init__`` — presumably set by a caller
    before use; confirm.
    """

    def __init__(self, embed_model, rerank_model_name='jinaai/jina-reranker-v1-tiny-en', verbose: bool = True):
        """Initialize the retriever and pick a reranking backend.

        :param embed_model: Embedding model exposing ``encode(text)``
            (assumed to return a numpy-like vector — TODO confirm against caller).
        :param rerank_model_name: Reranker model id; normalized via
            ``_normalize_rerank_model_name``.
        :param verbose: Default verbosity for ``search`` (overridable per call).

        Backend selection: if ``VOYAGE_API_KEY`` is set and the ``voyageai``
        client initializes, Voyage AI is used; otherwise a local
        sentence-transformers CrossEncoder is loaded as fallback.
        """
        import sys
        import os
        print(f"[DEBUG-HybridRetriever] Starting init", flush=True)
        self.embed_model = embed_model
        self.verbose = verbose
        # Normalize user input (blank -> default model, bare name -> prefixed).
        self.rerank_model_name = self._normalize_rerank_model_name(rerank_model_name)
        print(f"[DEBUG-HybridRetriever] Rerank model name: {self.rerank_model_name}", flush=True)
        self.vo_client = None       # Voyage AI client (None when unavailable)
        self.ce_reranker = None     # local CrossEncoder fallback
        self.reranker_backend = "cross-encoder"
        voyage_api_key = os.getenv("VOYAGE_API_KEY")
        if voyage_api_key:
            try:
                import voyageai
                self.vo_client = voyageai.Client(api_key=voyage_api_key)
                self.reranker_backend = "voyageai"
                # Voyage uses model IDs like rerank-2.5; keep a safe default.
                if not self.rerank_model_name.startswith("rerank-"):
                    self.rerank_model_name = "rerank-2.5"
                print(f"[DEBUG-HybridRetriever] Voyage AI client initialized", flush=True)
            except Exception as exc:
                # Any failure (missing package, bad key) falls through to the
                # local cross-encoder below.
                print(f"[DEBUG-HybridRetriever] Voyage unavailable ({exc}); falling back to cross-encoder", flush=True)
        if self.vo_client is None:
            from sentence_transformers import CrossEncoder
            ce_model_name = self.rerank_model_name
            # CrossEncoder path only accepts cross-encoder/* ids; anything else
            # is replaced with the known-good MiniLM default.
            if not ce_model_name.startswith("cross-encoder/"):
                ce_model_name = "cross-encoder/ms-marco-MiniLM-L-6-v2"
            self.ce_reranker = CrossEncoder(ce_model_name)
            self.rerank_model_name = ce_model_name
            self.reranker_backend = "cross-encoder"
            print(f"[DEBUG-HybridRetriever] Cross-encoder reranker initialized: {ce_model_name}", flush=True)
        sys.stdout.flush()
        print(f"[DEBUG-HybridRetriever] Init complete", flush=True)
def _normalize_rerank_model_name(self, model_name: str) -> str:
normalized = (model_name or "").strip()
if not normalized:
return "cross-encoder/ms-marco-MiniLM-L-6-v2"
if "/" in normalized:
return normalized
return f"cross-encoder/{normalized}"
def _tokenize(self, text: str) -> List[str]:
"""Tokenize text using regex to strip punctuation."""
return re.findall(r'\w+', text.lower())
# Helpers for filtering by the `chunking_technique` metadata field: one builds an
# index map from chunk metadata, the other normalizes the technique parameter.
def _build_chunking_index_map(self) -> dict[str, List[int]]:
mapping: dict[str, List[int]] = {}
for idx, chunk in enumerate(self.final_chunks):
metadata = chunk.get('metadata', {})
technique = (metadata.get('chunking_technique') or '').strip().lower()
if not technique:
continue
mapping.setdefault(technique, []).append(idx)
return mapping
def _normalize_chunking_technique(self, chunking_technique: Optional[str]) -> Optional[str]:
if not chunking_technique:
return None
normalized = str(chunking_technique).strip().lower()
if not normalized or normalized in {"all", "any", "*", "none"}:
return None
return normalized
# ------------------------------------------------------------------
# Retrieval
# ------------------------------------------------------------------
def _semantic_search(self, query, index, top_k, technique_name: Optional[str] = None) -> tuple[np.ndarray, List[str]]:
query_vector = self.embed_model.encode(query)
query_kwargs = {
"vector": query_vector.tolist(),
"top_k": top_k,
"include_metadata": True,
}
if technique_name:
query_kwargs["filter"] = {"chunking_technique": {"$eq": technique_name}}
res = index.query(
**query_kwargs
)
chunks = [match['metadata']['text'] for match in res['matches']]
return query_vector, chunks
def _bm25_search(self, query, index, top_k=50, technique_name: Optional[str] = None) -> List[str]:
try:
import os
from pinecone import Pinecone
from pinecone_text.sparse import BM25Encoder
encoder = BM25Encoder().default()
pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
sparse_index = pc.Index("cbt-book-sparse")
sparse_vector = encoder.encode_queries(query)
query_kwargs = {
"sparse_vector": sparse_vector,
"top_k": top_k,
"include_metadata": True,
}
if technique_name:
query_kwargs["filter"] = {"chunking_technique": {"$eq": technique_name}}
res = sparse_index.query(**query_kwargs)
return [match["metadata"]["text"] for match in res["matches"]]
except Exception as e:
print(f"Error in BM25 search against Pinecone: {e}")
return []
"""Fetch chunks from Pinecone and perform BM25 ranking locally."""
# Fetch more candidates than needed for BM25 to rank against
# Use a reasonable multiplier to get enough candidates without over-fetching
fetch_limit = min(top_k * 4,25) # e.g., 4*4=16, capped at 50
res = index.query(
vector=[0.0] * 512, # Dummy vector (BM25 doesn't use embeddings)
top_k=fetch_limit,
include_metadata=True,
filter={"chunking_technique": {"$eq": technique_name}}
)
# Extract chunks
chunks = [match['metadata']['text'] for match in res['matches']]
if not chunks:
return []
# Build BM25 index on these chunks
tokenized_corpus = [self._tokenize(chunk) for chunk in chunks]
bm25 = BM25Okapi(tokenized_corpus)
# Score query against chunks
tokenized_query = self._tokenize(query)
scores = bm25.get_scores(tokenized_query)
top_indices = np.argsort(scores)[::-1][:top_k]
return [chunks[i] for i in top_indices]
# ------------------------------------------------------------------
# Fusion
# ------------------------------------------------------------------
def _rrf_score(self, semantic_results, bm25_results, k=60) -> List[str]:
scores = {}
for rank, chunk in enumerate(semantic_results):
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
for rank, chunk in enumerate(bm25_results):
scores[chunk] = scores.get(chunk, 0) + 1 / (k + rank + 1)
return [chunk for chunk, _ in sorted(scores.items(), key=lambda x: x[1], reverse=True)]
# ------------------------------------------------------------------
# Reranking
# ------------------------------------------------------------------
def _cross_encoder_rerank(self, query, chunks, final_k) -> tuple[List[str], List[float]]:
if not chunks:
return [], []
if self.vo_client is not None:
reranking = self.vo_client.rerank(query, chunks, model=self.rerank_model_name, top_k=final_k)
ranked_chunks = [result.document for result in reranking.results]
ranked_scores = [result.relevance_score for result in reranking.results]
return ranked_chunks, ranked_scores
pairs = [[query, chunk] for chunk in chunks]
scores = self.ce_reranker.predict(pairs)
ranked_indices = np.argsort(scores)[::-1][:final_k]
ranked_chunks = [chunks[i] for i in ranked_indices]
ranked_scores = [float(scores[i]) for i in ranked_indices]
return ranked_chunks, ranked_scores
# ------------------------------------------------------------------
# MMR (applied after reranking as a diversity filter)
# ------------------------------------------------------------------
def _maximal_marginal_relevance(self, query_vector, chunks, lambda_param=0.5, top_k=10) -> List[str]:
"""
Maximum Marginal Relevance (MMR) for diversity filtering.
DIVISION BY ZERO DEBUGGING:
- This method can cause division by zero in cosine_similarity if vectors are zero
- We've added multiple safeguards to prevent this
"""
print(f" [MMR DEBUG] Starting MMR with {len(chunks)} chunks, top_k={top_k}")
if not chunks:
print(f" [MMR DEBUG] No chunks, returning empty list")
return []
# STEP 1: Encode chunks to get embeddings
print(f" [MMR DEBUG] Encoding {len(chunks)} chunks...")
try:
chunk_embeddings = np.array([self.embed_model.encode(c) for c in chunks])
print(f" [MMR DEBUG] Chunk embeddings shape: {chunk_embeddings.shape}")
except Exception as e:
print(f" [MMR DEBUG] ERROR encoding chunks: {e}")
return chunks[:top_k]
# STEP 2: Reshape query vector
query_embedding = query_vector.reshape(1, -1)
print(f" [MMR DEBUG] Query embedding shape: {query_embedding.shape}")
# STEP 3: Check for zero vectors (POTENTIAL DIVISION BY ZERO SOURCE)
print(f" [MMR DEBUG] Checking for zero vectors...")
query_norm = np.linalg.norm(query_embedding)
chunk_norms = np.linalg.norm(chunk_embeddings, axis=1)
print(f" [MMR DEBUG] Query norm: {query_norm}")
print(f" [MMR DEBUG] Chunk norms min: {chunk_norms.min()}, max: {chunk_norms.max()}")
# Check for zero or near-zero vectors
if query_norm < 1e-10 or np.any(chunk_norms < 1e-10):
print(f" [MMR DEBUG] WARNING: Zero or near-zero vectors detected!")
print(f" [MMR DEBUG] Query norm < 1e-10: {query_norm < 1e-10}")
print(f" [MMR DEBUG] Any chunk norm < 1e-10: {np.any(chunk_norms < 1e-10)}")
print(f" [MMR DEBUG] Falling back to simple selection without MMR")
return chunks[:top_k]
# STEP 4: Compute relevance scores (POTENTIAL DIVISION BY ZERO SOURCE)
print(f" [MMR DEBUG] Computing relevance scores with cosine_similarity...")
try:
relevance_scores = cosine_similarity(query_embedding, chunk_embeddings)[0]
print(f" [MMR DEBUG] Relevance scores computed successfully")
print(f" [MMR DEBUG] Relevance scores shape: {relevance_scores.shape}")
print(f" [MMR DEBUG] Relevance scores min: {relevance_scores.min()}, max: {relevance_scores.max()}")
except Exception as e:
print(f" [MMR DEBUG] ERROR computing relevance scores: {e}")
print(f" [MMR DEBUG] Falling back to simple selection")
return chunks[:top_k]
# STEP 5: Initialize selection
selected, unselected = [], list(range(len(chunks)))
first = int(np.argmax(relevance_scores))
selected.append(first)
unselected.remove(first)
print(f" [MMR DEBUG] Selected first chunk: index {first}")
# STEP 6: Iteratively select chunks using MMR
print(f" [MMR DEBUG] Starting MMR iteration...")
iteration = 0
while len(selected) < min(top_k, len(chunks)):
iteration += 1
print(f" [MMR DEBUG] Iteration {iteration}: selected={len(selected)}, unselected={len(unselected)}")
# Calculate MMR scores
mmr_scores = []
for i in unselected:
# Compute max similarity to already selected items
max_sim = -1
for s in selected:
try:
# POTENTIAL DIVISION BY ZERO SOURCE: cosine_similarity
sim = cosine_similarity(
chunk_embeddings[i].reshape(1, -1),
chunk_embeddings[s].reshape(1, -1)
)[0][0]
max_sim = max(max_sim, sim)
except Exception as e:
print(f" [MMR DEBUG] ERROR computing similarity between chunk {i} and {s}: {e}")
# If similarity computation fails, use 0
max_sim = max(max_sim, 0)
mmr_score = lambda_param * relevance_scores[i] - (1 - lambda_param) * max_sim
mmr_scores.append((i, mmr_score))
# Select chunk with highest MMR score
if mmr_scores:
best, best_score = max(mmr_scores, key=lambda x: x[1])
selected.append(best)
unselected.remove(best)
print(f" [MMR DEBUG] Selected chunk {best} with MMR score {best_score:.4f}")
else:
print(f" [MMR DEBUG] No MMR scores computed, breaking")
break
print(f" [MMR DEBUG] MMR complete. Selected {len(selected)} chunks")
return [chunks[i] for i in selected]
# ------------------------------------------------------------------
# Main search
# ------------------------------------------------------------------
    def search(self, query, index, top_k=50, final_k=5, mode="hybrid",
               rerank_strategy="cross-encoder", use_mmr=False, lambda_param=0.5,
               technique_name: Optional[str] = None,
               chunking_technique: Optional[str] = None,
               verbose: Optional[bool] = None, test: bool = False) -> tuple[List[str], float]:
        """Full retrieval pipeline: candidate retrieval, fusion/reranking, MMR.

        :param query: Natural-language query string.
        :param index: Dense vector index (Pinecone-style ``query`` API).
        :param top_k: Candidates fetched per retrieval leg.
        :param final_k: Number of chunks ultimately returned.
        :param mode: "semantic", "bm25", or "hybrid".
        :param rerank_strategy: "cross-encoder", "rrf", "voyage", or "none".
        :param use_mmr: Whether to apply MMR diversity filter after reranking.
        :param lambda_param: MMR trade-off between relevance (1.0) and diversity (0.0).
        :param technique_name: Chunking technique to filter by (no filter by default).
        :param chunking_technique: Alias for ``technique_name``; takes
            precedence when both are provided.
        :param verbose: Per-call override of ``self.verbose`` when not None.
        :param test: When True and the strategy is not "cross-encoder",
            rescore the final candidates with the cross-encoder so average
            scores are comparable across strategies.
        :returns: Tuple of (ranked_chunks, avg_chunk_score)
        """
        should_print = verbose if verbose is not None else self.verbose
        # Either filter param is accepted; chunking_technique wins when set.
        requested_technique = self._normalize_chunking_technique(chunking_technique or technique_name)
        total_start = time.perf_counter()
        semantic_time = 0.0
        bm25_time = 0.0
        rerank_time = 0.0
        mmr_time = 0.0
        if should_print:
            self._print_search_header(query, mode, rerank_strategy, top_k, final_k)
            if requested_technique:
                print(f"Chunking Filter: {requested_technique}")
        # 1. Retrieve candidates from each enabled leg.
        query_vector = None
        semantic_chunks, bm25_chunks = [], []
        if mode in ["semantic", "hybrid"]:
            semantic_start = time.perf_counter()
            query_vector, semantic_chunks = self._semantic_search(query, index, top_k, requested_technique)
            semantic_time = time.perf_counter() - semantic_start
            print(f"[DEBUG-FLOW] retrieved {len(semantic_chunks)} chunks from semantic search", flush=True)
            if should_print:
                self._print_candidates("Semantic Search", semantic_chunks)
                print(f"Semantic time: {semantic_time:.3f}s")
        if mode in ["bm25", "hybrid"]:
            bm25_start = time.perf_counter()
            bm25_chunks = self._bm25_search(query, index, top_k, requested_technique)
            bm25_time = time.perf_counter() - bm25_start
            print(f"[DEBUG-FLOW] retrieved {len(bm25_chunks)} chunks from BM25 search", flush=True)
            if should_print:
                self._print_candidates("BM25 Search", bm25_chunks)
                print(f"BM25 time: {bm25_time:.3f}s")
                print("All BM25 results:")
                for i, chunk in enumerate(bm25_chunks):
                    print(f" [{i}] {chunk[:200]}..." if len(chunk) > 200 else f" [{i}] {chunk}")
        # 2. Fuse / rerank the candidate pools.
        rerank_start = time.perf_counter()
        chunk_scores = []
        if rerank_strategy == "rrf":
            candidates = self._rrf_score(semantic_chunks, bm25_chunks)[:final_k]
            label = "RRF"
        elif rerank_strategy == "cross-encoder":
            # dict.fromkeys dedupes while preserving first-seen order.
            combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
            print(f"[DEBUG-FLOW] {len(combined)} unique chunks went into cross-encoder", flush=True)
            candidates, chunk_scores = self._cross_encoder_rerank(query, combined, final_k)
            print(f"[DEBUG-FLOW] {len(candidates)} chunks got out of cross-encoder", flush=True)
            label = "Cross-Encoder"
        elif rerank_strategy == "voyage":
            # NOTE(review): builds a fresh Client() (reads VOYAGE_API_KEY from
            # the env) instead of reusing self.vo_client from __init__ —
            # confirm this duplication is intended.
            import voyageai
            voyage_client = voyageai.Client()
            combined = list(dict.fromkeys(semantic_chunks + bm25_chunks))
            print(f"[DEBUG-FLOW] {len(combined)} unique chunks went into voyage reranker", flush=True)
            if not combined:
                candidates, chunk_scores = [], []
            else:
                try:
                    reranking = voyage_client.rerank(query=query, documents=combined, model=self.rerank_model_name, top_k=final_k)
                    candidates = [r.document for r in reranking.results]
                    chunk_scores = [r.relevance_score for r in reranking.results]
                    print(f"[DEBUG-FLOW] {len(candidates)} chunks got out of voyage reranker", flush=True)
                except Exception as e:
                    # Best-effort: fall back to the fused order (no scores)
                    # rather than failing the whole search.
                    print(f"Error calling Voyage API: {e}")
                    candidates = combined[:final_k]
                    chunk_scores = []
            label = "Voyage"
        else: # "none"
            candidates = list(dict.fromkeys(semantic_chunks + bm25_chunks))[:final_k]
            label = "No Reranking"
        rerank_time = time.perf_counter() - rerank_start
        # Compute average chunk score (0.0 when the strategy yields no scores).
        avg_chunk_score = float(np.mean(chunk_scores)) if chunk_scores else 0.0
        # 3. MMR diversity filter (applied after reranking)
        if use_mmr and candidates:
            mmr_start = time.perf_counter()
            if query_vector is None:
                # bm25-only mode never embedded the query; do it now for MMR.
                query_vector = self.embed_model.encode(query)
            candidates = self._maximal_marginal_relevance(query_vector, candidates,
                                                          lambda_param=lambda_param, top_k=final_k)
            label += " + MMR"
            mmr_time = time.perf_counter() - mmr_start
        # Safety cap: always honor requested final_k regardless of retrieval strategy.
        candidates = candidates[:final_k]
        if test and rerank_strategy != "cross-encoder" and candidates:
            _, test_scores = self._cross_encoder_rerank(query, candidates, len(candidates))
            avg_chunk_score = float(np.mean(test_scores)) if test_scores else 0.0
        total_time = time.perf_counter() - total_start
        if should_print:
            self._print_final_results(candidates, label)
            self._print_timing_summary(semantic_time, bm25_time, rerank_time, mmr_time, total_time)
        return candidates, avg_chunk_score
# ------------------------------------------------------------------
# Printing
# ------------------------------------------------------------------
def _print_search_header(self, query, mode, rerank_strategy, top_k, final_k):
print("\n" + "="*80)
print(f" SEARCH QUERY: {query}")
print(f"Mode: {mode.upper()} | Rerank: {rerank_strategy.upper()}")
print(f"Top-K: {top_k} | Final-K: {final_k}")
print("-" * 80)
def _print_candidates(self, label, chunks, preview_n=3):
print(f"{label}: Retrieved {len(chunks)} candidates")
for i, chunk in enumerate(chunks[:preview_n]):
preview = chunk[:100] + "..." if len(chunk) > 100 else chunk
print(f" [{i}] {preview}")
def _print_final_results(self, results, strategy_label):
print(f"\n Final {len(results)} Results ({strategy_label}):")
for i, chunk in enumerate(results):
preview = chunk[:150] + "..." if len(chunk) > 150 else chunk
print(f" [{i+1}] {preview}")
print("="*80)
def _print_timing_summary(self, semantic_time, bm25_time, rerank_time, mmr_time, total_time):
print(" Retrieval Timing:")
print(f" Semantic: {semantic_time:.3f}s")
print(f" BM25: {bm25_time:.3f}s")
print(f" Rerank/Fusion: {rerank_time:.3f}s")
print(f" MMR: {mmr_time:.3f}s")
print(f" Total Retrieval: {total_time:.3f}s")