|
|
""" |
|
|
Retriever Agent |
|
|
|
|
|
Implements hybrid retrieval combining dense and sparse methods. |
|
|
Follows FAANG best practices for production RAG systems. |
|
|
|
|
|
Key Features: |
|
|
- Dense retrieval (embedding-based semantic search) |
|
|
- Sparse retrieval (BM25/TF-IDF keyword matching) |
|
|
- Reciprocal Rank Fusion (RRF) for combining results |
|
|
- Query expansion using planner output |
|
|
- Adaptive retrieval based on query intent |
|
|
""" |
|
|
|
|
|
from typing import List, Optional, Dict, Any, Tuple |
|
|
from pydantic import BaseModel, Field |
|
|
from loguru import logger |
|
|
from dataclasses import dataclass |
|
|
from collections import defaultdict |
|
|
import re |
|
|
import math |
|
|
|
|
|
from ..store import VectorStore, VectorSearchResult, get_vector_store, VectorStoreConfig |
|
|
from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig |
|
|
from .query_planner import QueryPlan, SubQuery, QueryIntent |
|
|
|
|
|
|
|
|
class HybridSearchConfig(BaseModel):
    """Configuration for hybrid retrieval."""

    # Dense (embedding) channel: relative weight in fusion and candidate pool size.
    dense_weight: float = Field(default=0.7, ge=0.0, le=1.0)
    dense_top_k: int = Field(default=20, ge=1)

    # Sparse (BM25-style) channel: relative weight in fusion and candidate pool size.
    # NOTE: dense_weight + sparse_weight is 1.0 by default but is not enforced.
    sparse_weight: float = Field(default=0.3, ge=0.0, le=1.0)
    sparse_top_k: int = Field(default=20, ge=1)

    # Reciprocal Rank Fusion constant and the number of fused results returned.
    rrf_k: int = Field(default=60, description="RRF constant (typically 60)")
    final_top_k: int = Field(default=10, ge=1)

    # Query expansion: when enabled, planner-suggested terms are appended to the
    # original query, capped at max_expanded_queries extra queries.
    use_query_expansion: bool = Field(default=True)
    max_expanded_queries: int = Field(default=3, ge=1)

    # When True, dense/sparse weights are overridden per query intent.
    adapt_to_intent: bool = Field(default=True)
|
|
|
|
|
|
|
|
class RetrievalResult(BaseModel):
    """Result from hybrid retrieval."""

    # Identity and content of the retrieved chunk.
    chunk_id: str
    document_id: str
    text: str
    # Fused RRF score used for final ranking.
    score: float
    # Per-channel diagnostics; None when the chunk was absent from that channel.
    dense_score: Optional[float] = None
    sparse_score: Optional[float] = None
    dense_rank: Optional[int] = None
    sparse_rank: Optional[int] = None

    # Source/provenance fields (populated from the dense result payload when
    # available; sparse-only hits carry defaults).
    page: Optional[int] = None
    chunk_type: Optional[str] = None
    source_path: Optional[str] = None
    metadata: Dict[str, Any] = Field(default_factory=dict)

    # Bounding box of the chunk on its page, if the ingester recorded one.
    bbox: Optional[Dict[str, float]] = None
|
|
|
|
|
|
|
|
class RetrieverAgent:
    """
    Hybrid retrieval agent combining dense and sparse search.

    Capabilities:
    1. Dense retrieval via embedding similarity
    2. Sparse retrieval via BM25-style keyword matching
    3. Reciprocal Rank Fusion for result combination
    4. Query expansion from planner
    5. Intent-aware retrieval adaptation
    """

    def __init__(
        self,
        config: Optional[HybridSearchConfig] = None,
        vector_store: Optional[VectorStore] = None,
        embedding_adapter: Optional[EmbeddingAdapter] = None,
    ):
        """
        Initialize Retriever Agent.

        Args:
            config: Hybrid search configuration (defaults created if None)
            vector_store: Vector store for dense retrieval (lazily created if None)
            embedding_adapter: Embedding adapter for query encoding (lazily created if None)
        """
        self.config = config or HybridSearchConfig()
        self._store = vector_store
        self._embedder = embedding_adapter

        # Okapi BM25 parameters: k1 controls term-frequency saturation,
        # b controls document-length normalization.
        self._k1 = 1.5
        self._b = 0.75

        # Corpus statistics for BM25 (avg doc length, doc frequencies, corpus
        # size). Computed lazily on first sparse retrieval and cached for the
        # lifetime of the agent — stale if the index changes afterwards.
        self._doc_stats: Optional[Dict[str, Any]] = None

        logger.info("RetrieverAgent initialized with hybrid search")

    @property
    def store(self) -> VectorStore:
        """Get vector store (lazy initialization)."""
        if self._store is None:
            self._store = get_vector_store()
        return self._store

    @property
    def embedder(self) -> EmbeddingAdapter:
        """Get embedding adapter (lazy initialization)."""
        if self._embedder is None:
            self._embedder = get_embedding_adapter()
        return self._embedder

    def retrieve(
        self,
        query: str,
        plan: Optional[QueryPlan] = None,
        top_k: Optional[int] = None,
        filters: Optional[Dict[str, Any]] = None,
    ) -> List[RetrievalResult]:
        """
        Perform hybrid retrieval for a query.

        Args:
            query: Search query
            plan: Optional query plan for expansion and intent adaptation
            top_k: Number of results (overrides config.final_top_k)
            filters: Metadata filters passed through to both channels

        Returns:
            List of retrieval results ranked by fused RRF score (descending)
        """
        top_k = top_k or self.config.final_top_k

        # Original query plus any planner-driven expansions.
        queries = self._get_queries(query, plan)

        # Intent-aware weighting of the two retrieval channels.
        dense_weight, sparse_weight = self._adapt_weights(plan)

        dense_results = self._dense_retrieve(queries, filters)
        sparse_results = self._sparse_retrieve(queries, filters)

        combined = self._reciprocal_rank_fusion(
            dense_results,
            sparse_results,
            dense_weight,
            sparse_weight,
        )

        results = sorted(combined.values(), key=lambda x: x.score, reverse=True)
        return results[:top_k]

    def retrieve_for_subqueries(
        self,
        sub_queries: List[SubQuery],
        filters: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, List[RetrievalResult]]:
        """
        Retrieve for multiple sub-queries, respecting dependencies.

        Args:
            sub_queries: List of sub-queries from planner
            filters: Optional metadata filters

        Returns:
            Dict mapping sub-query ID to retrieval results
        """
        results: Dict[str, List[RetrievalResult]] = {}

        # Process in dependency order. NOTE(review): ordering is respected but
        # prerequisite results are not injected into dependent queries — TODO
        # if dependents should be conditioned on earlier answers.
        for sq in self._topological_sort(sub_queries):
            results[sq.id] = self.retrieve(
                sq.query,
                top_k=self.config.final_top_k,
                filters=filters,
            )

        return results

    def _get_queries(
        self,
        query: str,
        plan: Optional[QueryPlan],
    ) -> List[str]:
        """Get list of queries to run (original + expanded).

        Each expansion appends one planner-suggested term to the original
        query, capped at config.max_expanded_queries extra queries.
        """
        queries = [query]

        if plan and self.config.use_query_expansion:
            for term in plan.expanded_terms[:self.config.max_expanded_queries]:
                queries.append(f"{query} {term}")

        return queries

    def _adapt_weights(
        self,
        plan: Optional[QueryPlan],
    ) -> Tuple[float, float]:
        """Adapt (dense, sparse) weights based on query intent.

        Falls back to the configured weights when no plan is available,
        adaptation is disabled, or the intent has no specific override.
        """
        default = (self.config.dense_weight, self.config.sparse_weight)
        if not plan or not self.config.adapt_to_intent:
            return default

        # Intent-specific overrides: keyword-oriented intents lean sparse,
        # semantic intents lean dense.
        overrides = {
            QueryIntent.FACTOID: (0.6, 0.4),
            QueryIntent.DEFINITION: (0.8, 0.2),
            QueryIntent.COMPARISON: (0.5, 0.5),
            QueryIntent.AGGREGATION: (0.75, 0.25),
            QueryIntent.LIST: (0.5, 0.5),
        }
        return overrides.get(plan.intent, default)

    def _dense_retrieve(
        self,
        queries: List[str],
        filters: Optional[Dict[str, Any]],
    ) -> Dict[str, Tuple[int, float, VectorSearchResult]]:
        """
        Perform dense (embedding) retrieval for all queries.

        Returns:
            Dict mapping chunk_id to (best_rank, best_score, result_object)
        """
        # chunk_id -> list of (rank, similarity, raw result) across queries.
        all_results: Dict[str, List[Tuple[int, float, VectorSearchResult]]] = defaultdict(list)

        for query in queries:
            query_embedding = self.embedder.embed_text(query)

            results = self.store.search(
                query_embedding=query_embedding,
                top_k=self.config.dense_top_k,
                filters=filters,
            )

            for rank, result in enumerate(results, 1):
                all_results[result.chunk_id].append((rank, result.similarity, result))

        # A chunk hit by several queries keeps its best (lowest) rank and its
        # best (highest) similarity.
        aggregated = {}
        for chunk_id, hits in all_results.items():
            best_hit = min(hits, key=lambda h: h[0])
            best_score = max(h[1] for h in hits)
            # Fix: carry the result object from the best-ranked hit instead of
            # an arbitrary one, keeping the rank/payload pairing consistent.
            aggregated[chunk_id] = (best_hit[0], best_score, best_hit[2])

        return aggregated

    def _sparse_retrieve(
        self,
        queries: List[str],
        filters: Optional[Dict[str, Any]],
    ) -> Dict[str, Tuple[int, float, None]]:
        """
        Perform sparse (BM25-style) retrieval for all queries.

        Returns:
            Dict mapping chunk_id to (rank, score, None); the trailing None
            mirrors the dense triple shape (no raw payload for sparse hits).
        """
        # Sparse retrieval is best-effort: on any failure fall back to
        # dense-only results rather than failing the whole query.
        try:
            all_chunks = self._get_all_chunks(filters)
        except Exception as e:
            logger.warning(f"Sparse retrieval failed: {e}")
            return {}

        if not all_chunks:
            return {}

        if self._doc_stats is None:
            self._compute_doc_stats(all_chunks)

        # Tokenize each chunk exactly once instead of once per (query, chunk)
        # pair — this was the dominant cost of the original implementation.
        tokenized = {cid: self._tokenize(text) for cid, text in all_chunks.items()}

        all_scores: Dict[str, List[float]] = defaultdict(list)
        for query in queries:
            query_terms = self._tokenize(query)
            for chunk_id, doc_terms in tokenized.items():
                all_scores[chunk_id].append(self._bm25_score_terms(query_terms, doc_terms))

        # Keep each chunk's best score across the expanded queries.
        aggregated = {cid: max(scores) for cid, scores in all_scores.items()}

        # Rank by score and keep the configured top-k.
        ranked = sorted(aggregated.items(), key=lambda x: x[1], reverse=True)
        result: Dict[str, Tuple[int, float, None]] = {}
        for rank, (chunk_id, score) in enumerate(ranked[:self.config.sparse_top_k], 1):
            result[chunk_id] = (rank, score, None)

        return result

    def _get_all_chunks(
        self,
        filters: Optional[Dict[str, Any]],
    ) -> Dict[str, str]:
        """Get all chunks for sparse retrieval.

        NOTE(review): the store exposes no scan/list API here, so "all chunks"
        is approximated by a broad similarity search capped at 1000 results —
        larger corpora will be truncated. Confirm against the store's API.
        """
        query_embedding = self.embedder.embed_text("document content information")
        results = self.store.search(
            query_embedding=query_embedding,
            top_k=1000,
            filters=filters,
        )

        return {result.chunk_id: result.text for result in results}

    def _compute_doc_stats(self, chunks: Dict[str, str]):
        """Compute and cache corpus statistics needed by BM25 scoring."""
        doc_lengths = []
        df: Dict[str, int] = defaultdict(int)

        for text in chunks.values():
            terms = self._tokenize(text)
            doc_lengths.append(len(terms))
            # Document frequency counts each term once per document.
            for term in set(terms):
                df[term] += 1

        self._doc_stats = {
            # avg_dl of 1 for an empty corpus avoids division by zero.
            "avg_dl": sum(doc_lengths) / len(doc_lengths) if doc_lengths else 1,
            "n_docs": len(chunks),
            "df": dict(df),
        }

    def _tokenize(self, text: str) -> List[str]:
        """Lowercase, replace punctuation with spaces, split on whitespace."""
        text = text.lower()
        text = re.sub(r'[^\w\s]', ' ', text)
        return text.split()

    def _bm25_score(self, query_terms: List[str], doc_text: str) -> float:
        """Compute the BM25 score of raw document text against query terms."""
        return self._bm25_score_terms(query_terms, self._tokenize(doc_text))

    def _bm25_score_terms(self, query_terms: List[str], doc_terms: List[str]) -> float:
        """Compute the Okapi BM25 score from pre-tokenized document terms.

        Returns 0.0 when corpus statistics have not been computed yet.
        """
        if not self._doc_stats:
            return 0.0

        dl = len(doc_terms)
        avg_dl = self._doc_stats["avg_dl"]
        n_docs = self._doc_stats["n_docs"]
        df = self._doc_stats["df"]

        # Term frequencies within this document.
        tf: Dict[str, int] = defaultdict(int)
        for term in doc_terms:
            tf[term] += 1

        score = 0.0
        for term in query_terms:
            if term not in tf:
                continue

            doc_freq = df.get(term, 0)
            # The "+1" keeps IDF non-negative for very common terms.
            idf = math.log((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1)

            term_freq = tf[term]
            tf_component = (term_freq * (self._k1 + 1)) / (
                term_freq + self._k1 * (1 - self._b + self._b * dl / avg_dl)
            )

            score += idf * tf_component

        return score

    def _reciprocal_rank_fusion(
        self,
        dense_results: Dict[str, Tuple[int, float, Any]],
        sparse_results: Dict[str, Tuple[int, float, Any]],
        dense_weight: float,
        sparse_weight: float,
    ) -> Dict[str, RetrievalResult]:
        """
        Combine dense and sparse results using weighted RRF.

        RRF score = sum over rankings of weight / (k + rank); a chunk absent
        from one ranking contributes nothing for that ranking.
        """
        k = self.config.rrf_k
        combined: Dict[str, RetrievalResult] = {}

        for chunk_id in set(dense_results) | set(sparse_results):
            # One lookup per channel instead of four .get() calls per chunk.
            dense_hit = dense_results.get(chunk_id)
            sparse_hit = sparse_results.get(chunk_id)

            rrf_score = 0.0
            if dense_hit is not None:
                rrf_score += dense_weight / (k + dense_hit[0])
            if sparse_hit is not None:
                rrf_score += sparse_weight / (k + sparse_hit[0])

            # Payload fields come from the dense result object when available;
            # sparse-only hits keep these defaults.
            metadata: Dict[str, Any] = {}
            page = None
            chunk_type = None
            source_path = None
            text = ""
            document_id = ""
            bbox = None

            if dense_hit is not None and dense_hit[2] is not None:
                result_obj = dense_hit[2]
                text = result_obj.text
                document_id = result_obj.document_id
                page = result_obj.page
                chunk_type = result_obj.chunk_type
                metadata = result_obj.metadata
                source_path = metadata.get("source_path", "")
                bbox = result_obj.bbox

            combined[chunk_id] = RetrievalResult(
                chunk_id=chunk_id,
                document_id=document_id,
                text=text,
                score=rrf_score,
                dense_score=dense_hit[1] if dense_hit is not None else None,
                sparse_score=sparse_hit[1] if sparse_hit is not None else None,
                dense_rank=dense_hit[0] if dense_hit is not None else None,
                sparse_rank=sparse_hit[0] if sparse_hit is not None else None,
                page=page,
                chunk_type=chunk_type,
                source_path=source_path,
                metadata=metadata,
                bbox=bbox,
            )

        return combined

    def _topological_sort(self, sub_queries: List[SubQuery]) -> List[SubQuery]:
        """Sort sub-queries so dependencies come before dependents.

        On a dependency cycle (or a dependency on an unknown id) the remaining
        queries are appended in their original order instead of raising, so
        retrieval always proceeds.
        """
        sorted_queries: List[SubQuery] = []
        remaining = list(sub_queries)
        completed: set = set()

        while remaining:
            for sq in remaining[:]:
                if all(dep in completed for dep in sq.depends_on):
                    sorted_queries.append(sq)
                    completed.add(sq.id)
                    remaining.remove(sq)
                    break
            else:
                # No progress possible: cycle or missing dependency.
                sorted_queries.extend(remaining)
                break

        return sorted_queries
|
|
|