""" Retriever Agent Implements hybrid retrieval combining dense and sparse methods. Follows FAANG best practices for production RAG systems. Key Features: - Dense retrieval (embedding-based semantic search) - Sparse retrieval (BM25/TF-IDF keyword matching) - Reciprocal Rank Fusion (RRF) for combining results - Query expansion using planner output - Adaptive retrieval based on query intent """ from typing import List, Optional, Dict, Any, Tuple from pydantic import BaseModel, Field from loguru import logger from dataclasses import dataclass from collections import defaultdict import re import math from ..store import VectorStore, VectorSearchResult, get_vector_store, VectorStoreConfig from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig from .query_planner import QueryPlan, SubQuery, QueryIntent class HybridSearchConfig(BaseModel): """Configuration for hybrid retrieval.""" # Dense retrieval settings dense_weight: float = Field(default=0.7, ge=0.0, le=1.0) dense_top_k: int = Field(default=20, ge=1) # Sparse retrieval settings sparse_weight: float = Field(default=0.3, ge=0.0, le=1.0) sparse_top_k: int = Field(default=20, ge=1) # Fusion settings rrf_k: int = Field(default=60, description="RRF constant (typically 60)") final_top_k: int = Field(default=10, ge=1) # Query expansion use_query_expansion: bool = Field(default=True) max_expanded_queries: int = Field(default=3, ge=1) # Intent-based adaptation adapt_to_intent: bool = Field(default=True) class RetrievalResult(BaseModel): """Result from hybrid retrieval.""" chunk_id: str document_id: str text: str score: float # Combined RRF score dense_score: Optional[float] = None sparse_score: Optional[float] = None dense_rank: Optional[int] = None sparse_rank: Optional[int] = None # Metadata page: Optional[int] = None chunk_type: Optional[str] = None source_path: Optional[str] = None metadata: Dict[str, Any] = Field(default_factory=dict) # For evidence grounding bbox: Optional[Dict[str, float]] = None class RetrieverAgent: """ Hybrid retrieval agent combining dense and sparse search. Capabilities: 1. Dense retrieval via embedding similarity 2. Sparse retrieval via BM25-style keyword matching 3. Reciprocal Rank Fusion for result combination 4. Query expansion from planner 5. Intent-aware retrieval adaptation """ def __init__( self, config: Optional[HybridSearchConfig] = None, vector_store: Optional[VectorStore] = None, embedding_adapter: Optional[EmbeddingAdapter] = None, ): """ Initialize Retriever Agent. Args: config: Hybrid search configuration vector_store: Vector store for dense retrieval embedding_adapter: Embedding adapter for query encoding """ self.config = config or HybridSearchConfig() self._store = vector_store self._embedder = embedding_adapter # BM25 parameters self._k1 = 1.5 self._b = 0.75 # Document statistics for BM25 (computed lazily) self._doc_stats: Optional[Dict[str, Any]] = None logger.info("RetrieverAgent initialized with hybrid search") @property def store(self) -> VectorStore: """Get vector store (lazy initialization).""" if self._store is None: self._store = get_vector_store() return self._store @property def embedder(self) -> EmbeddingAdapter: """Get embedding adapter (lazy initialization).""" if self._embedder is None: self._embedder = get_embedding_adapter() return self._embedder def retrieve( self, query: str, plan: Optional[QueryPlan] = None, top_k: Optional[int] = None, filters: Optional[Dict[str, Any]] = None, ) -> List[RetrievalResult]: """ Perform hybrid retrieval for a query. Args: query: Search query plan: Optional query plan for expansion and intent top_k: Number of results (overrides config) filters: Metadata filters Returns: List of retrieval results ranked by RRF score """ top_k = top_k or self.config.final_top_k # Get queries to run (original + expanded) queries = self._get_queries(query, plan) # Adapt retrieval based on intent dense_weight, sparse_weight = self._adapt_weights(plan) # Run dense retrieval dense_results = self._dense_retrieve(queries, filters) # Run sparse retrieval sparse_results = self._sparse_retrieve(queries, filters) # Combine with RRF combined = self._reciprocal_rank_fusion( dense_results, sparse_results, dense_weight, sparse_weight, ) # Return top-k results = sorted(combined.values(), key=lambda x: x.score, reverse=True) return results[:top_k] def retrieve_for_subqueries( self, sub_queries: List[SubQuery], filters: Optional[Dict[str, Any]] = None, ) -> Dict[str, List[RetrievalResult]]: """ Retrieve for multiple sub-queries, respecting dependencies. Args: sub_queries: List of sub-queries from planner filters: Optional metadata filters Returns: Dict mapping sub-query ID to retrieval results """ results = {} # Sort by priority and dependencies sorted_queries = self._topological_sort(sub_queries) for sq in sorted_queries: # Retrieve for this sub-query sq_results = self.retrieve( sq.query, top_k=self.config.final_top_k, filters=filters, ) results[sq.id] = sq_results return results def _get_queries( self, query: str, plan: Optional[QueryPlan], ) -> List[str]: """Get list of queries to run (original + expanded).""" queries = [query] if plan and self.config.use_query_expansion: # Add expanded terms as additional queries for term in plan.expanded_terms[:self.config.max_expanded_queries]: # Combine original query with expanded term expanded = f"{query} {term}" queries.append(expanded) return queries def _adapt_weights( self, plan: Optional[QueryPlan], ) -> Tuple[float, float]: """Adapt dense/sparse weights based on query intent.""" if not plan or not self.config.adapt_to_intent: return self.config.dense_weight, self.config.sparse_weight intent = plan.intent # Factoid queries benefit from keyword matching if intent == QueryIntent.FACTOID: return 0.6, 0.4 # Definition queries benefit from semantic search if intent == QueryIntent.DEFINITION: return 0.8, 0.2 # Comparison needs both if intent == QueryIntent.COMPARISON: return 0.5, 0.5 # Aggregation needs broad semantic coverage if intent == QueryIntent.AGGREGATION: return 0.75, 0.25 # List queries benefit from keyword precision if intent == QueryIntent.LIST: return 0.5, 0.5 return self.config.dense_weight, self.config.sparse_weight def _dense_retrieve( self, queries: List[str], filters: Optional[Dict[str, Any]], ) -> Dict[str, Tuple[int, float]]: """ Perform dense (embedding) retrieval. Returns: Dict mapping chunk_id to (rank, score) """ all_results: Dict[str, List[Tuple[int, float, VectorSearchResult]]] = defaultdict(list) for query in queries: # Embed query query_embedding = self.embedder.embed_text(query) # Search results = self.store.search( query_embedding=query_embedding, top_k=self.config.dense_top_k, filters=filters, ) # Record results with rank for rank, result in enumerate(results, 1): all_results[result.chunk_id].append((rank, result.similarity, result)) # Aggregate scores across queries (take best rank/score) aggregated = {} for chunk_id, scores in all_results.items(): best_rank = min(s[0] for s in scores) best_score = max(s[1] for s in scores) aggregated[chunk_id] = (best_rank, best_score, scores[0][2]) return aggregated def _sparse_retrieve( self, queries: List[str], filters: Optional[Dict[str, Any]], ) -> Dict[str, Tuple[int, float]]: """ Perform sparse (BM25-style) retrieval. Returns: Dict mapping chunk_id to (rank, score) """ # Get all chunks from vector store for sparse search # In production, this would use an inverted index try: all_chunks = self._get_all_chunks(filters) except Exception as e: logger.warning(f"Sparse retrieval failed: {e}") return {} if not all_chunks: return {} # Compute document statistics if needed if self._doc_stats is None: self._compute_doc_stats(all_chunks) # Score all chunks for each query all_scores: Dict[str, List[float]] = defaultdict(list) for query in queries: query_terms = self._tokenize(query) for chunk_id, text in all_chunks.items(): score = self._bm25_score(query_terms, text) all_scores[chunk_id].append(score) # Aggregate scores (take max) aggregated = {} for chunk_id, scores in all_scores.items(): best_score = max(scores) aggregated[chunk_id] = best_score # Rank by score ranked = sorted(aggregated.items(), key=lambda x: x[1], reverse=True) result = {} for rank, (chunk_id, score) in enumerate(ranked[:self.config.sparse_top_k], 1): result[chunk_id] = (rank, score, None) return result def _get_all_chunks( self, filters: Optional[Dict[str, Any]], ) -> Dict[str, str]: """Get all chunks for sparse retrieval.""" # This is a simplified implementation # In production, use an inverted index # Get chunk IDs from dense search with generic query query_embedding = self.embedder.embed_text("document content information") results = self.store.search( query_embedding=query_embedding, top_k=1000, # Get as many as possible filters=filters, ) chunks = {} for result in results: chunks[result.chunk_id] = result.text return chunks def _compute_doc_stats(self, chunks: Dict[str, str]): """Compute document statistics for BM25.""" doc_lengths = [] df = defaultdict(int) # Document frequency for text in chunks.values(): terms = self._tokenize(text) doc_lengths.append(len(terms)) for term in set(terms): df[term] += 1 self._doc_stats = { "avg_dl": sum(doc_lengths) / len(doc_lengths) if doc_lengths else 1, "n_docs": len(chunks), "df": dict(df), } def _tokenize(self, text: str) -> List[str]: """Simple tokenization.""" text = text.lower() text = re.sub(r'[^\w\s]', ' ', text) return text.split() def _bm25_score(self, query_terms: List[str], doc_text: str) -> float: """Compute BM25 score.""" if not self._doc_stats: return 0.0 doc_terms = self._tokenize(doc_text) dl = len(doc_terms) avg_dl = self._doc_stats["avg_dl"] n_docs = self._doc_stats["n_docs"] df = self._doc_stats["df"] # Count term frequencies in document tf = defaultdict(int) for term in doc_terms: tf[term] += 1 score = 0.0 for term in query_terms: if term not in tf: continue # IDF doc_freq = df.get(term, 0) idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1) # TF with saturation term_freq = tf[term] tf_component = (term_freq * (self._k1 + 1)) / ( term_freq + self._k1 * (1 - self._b + self._b * dl / avg_dl) ) score += idf * tf_component return score def _reciprocal_rank_fusion( self, dense_results: Dict[str, Tuple[int, float, Any]], sparse_results: Dict[str, Tuple[int, float, Any]], dense_weight: float, sparse_weight: float, ) -> Dict[str, RetrievalResult]: """ Combine dense and sparse results using RRF. RRF score = sum(1 / (k + rank)) for each ranking """ k = self.config.rrf_k combined = {} # Get all unique chunk IDs all_chunk_ids = set(dense_results.keys()) | set(sparse_results.keys()) for chunk_id in all_chunk_ids: dense_rank = dense_results.get(chunk_id, (1000, 0, None))[0] dense_score = dense_results.get(chunk_id, (1000, 0, None))[1] sparse_rank = sparse_results.get(chunk_id, (1000, 0, None))[0] sparse_score = sparse_results.get(chunk_id, (1000, 0, None))[1] # RRF formula rrf_dense = dense_weight / (k + dense_rank) if chunk_id in dense_results else 0 rrf_sparse = sparse_weight / (k + sparse_rank) if chunk_id in sparse_results else 0 rrf_score = rrf_dense + rrf_sparse # Get metadata from dense results if available metadata = {} page = None chunk_type = None source_path = None text = "" document_id = "" bbox = None if chunk_id in dense_results: result_obj = dense_results[chunk_id][2] if result_obj: text = result_obj.text document_id = result_obj.document_id page = result_obj.page chunk_type = result_obj.chunk_type metadata = result_obj.metadata source_path = metadata.get("source_path", "") bbox = result_obj.bbox combined[chunk_id] = RetrievalResult( chunk_id=chunk_id, document_id=document_id, text=text, score=rrf_score, dense_score=dense_score if chunk_id in dense_results else None, sparse_score=sparse_score if chunk_id in sparse_results else None, dense_rank=dense_rank if chunk_id in dense_results else None, sparse_rank=sparse_rank if chunk_id in sparse_results else None, page=page, chunk_type=chunk_type, source_path=source_path, metadata=metadata, bbox=bbox, ) return combined def _topological_sort(self, sub_queries: List[SubQuery]) -> List[SubQuery]: """Sort sub-queries by dependencies.""" # Simple topological sort sorted_queries = [] remaining = list(sub_queries) completed = set() while remaining: for sq in remaining[:]: if all(dep in completed for dep in sq.depends_on): sorted_queries.append(sq) completed.add(sq.id) remaining.remove(sq) break else: # Cycle detected or invalid dependencies, just append rest sorted_queries.extend(remaining) break return sorted_queries