# zeta/src/retrieval/retriever.py
# Commit 9b457ed (rodrigo-moonray): "Deploy zeta-only embeddings (NV-Embed-v2 + E5-small)"
"""
Document retrieval with semantic search.
This module provides retrieval functionality using ChromaDB vector search
with support for filtering and relevance scoring.
"""
from typing import List, Optional
from dataclasses import dataclass, field
import json
import numpy as np
from src.config.settings import get_settings
from src.embedding.embedder import Embedder
from src.embedding.vector_store import VectorStore
from src.utils.logging import get_logger, log_retrieval
import time
logger = get_logger(__name__)
@dataclass
class RetrievedChunk:
    """One search hit: chunk text plus provenance and relevance metadata."""

    chunk_id: str             # unique id of the chunk in the vector store
    text: str                 # raw chunk text
    filename: str             # source document filename
    document_id: str          # id of the parent document
    score: float              # similarity score (higher = more relevant)
    token_count: int          # token length of the chunk
    chunk_index: int          # position of the chunk within its document
    page_numbers: List[int]   # pages the chunk spans in the source document
    metadata: dict = field(default_factory=dict)  # extra per-chunk metadata

    def to_dict(self) -> dict:
        """Serialize the chunk to a plain dictionary (keys mirror the fields)."""
        return dict(
            chunk_id=self.chunk_id,
            text=self.text,
            filename=self.filename,
            document_id=self.document_id,
            score=self.score,
            token_count=self.token_count,
            chunk_index=self.chunk_index,
            page_numbers=self.page_numbers,
            metadata=self.metadata,
        )

    @property
    def source_type(self) -> str:
        """Origin of the chunk; 'local' when metadata carries no source_type."""
        return self.metadata.get("source_type", "local")

    @property
    def url(self) -> Optional[str]:
        """URL for web/scientific sources, or None when metadata has no url."""
        return self.metadata.get("url")
class Retriever:
"""Retrieve relevant document chunks for a query."""
def __init__(self):
"""Initialize retriever with embedder and vector store."""
settings = get_settings()
self.embedder = Embedder()
self.vector_store = VectorStore()
self.top_k = settings.top_k_retrieval
self.score_threshold = settings.retrieval_score_threshold
def retrieve(
self,
query: str,
top_k: Optional[int] = None,
filter_filename: Optional[str] = None,
filter_filenames: Optional[List[str]] = None,
) -> List[RetrievedChunk]:
"""
Retrieve relevant chunks for a query.
Args:
query: User query text
top_k: Number of results to return (default from settings)
filter_filename: Optional single filename to filter results (deprecated, use filter_filenames)
filter_filenames: Optional list of filenames to filter results
Returns:
List[RetrievedChunk]: Retrieved chunks sorted by relevance
"""
start_time = time.time()
k = top_k or self.top_k
logger.debug(f"Retrieving chunks for query: {query[:100]}...")
# Handle both single filename and list of filenames
filenames_filter = filter_filenames
if filter_filename and not filenames_filter:
filenames_filter = [filter_filename]
# Generate query embedding (is_query=True for models that use query instructions)
query_embedding = self.embedder.encode_single(query, is_query=True)
# Query vector store with filtering
results = self.vector_store.query(
query_embedding,
top_k=k,
filter_filenames=filenames_filter,
)
# Process results
chunks = []
if results and results.get('ids') and len(results['ids']) > 0:
ids = results['ids'][0]
documents = results['documents'][0]
metadatas = results['metadatas'][0]
distances = results['distances'][0]
for i, (chunk_id, text, metadata, distance) in enumerate(
zip(ids, documents, metadatas, distances)
):
# Convert distance to similarity score (ChromaDB uses L2 distance)
# Lower distance = higher similarity
score = 1.0 / (1.0 + distance)
# Apply score threshold
if score < self.score_threshold:
continue
# Parse page_numbers from JSON string
page_numbers_raw = metadata.get('page_numbers', '[]')
try:
page_numbers = json.loads(page_numbers_raw) if isinstance(page_numbers_raw, str) else page_numbers_raw
except (json.JSONDecodeError, TypeError):
page_numbers = []
chunk = RetrievedChunk(
chunk_id=chunk_id,
text=text,
filename=metadata.get('filename', 'unknown'),
document_id=metadata.get('document_id', ''),
score=score,
token_count=metadata.get('token_count', 0),
chunk_index=metadata.get('chunk_index', 0),
page_numbers=page_numbers,
)
chunks.append(chunk)
# Log retrieval metrics
duration_ms = (time.time() - start_time) * 1000
log_retrieval(logger, query, len(chunks), duration_ms)
return chunks
def retrieve_with_diversity(
self,
query: str,
top_k: Optional[int] = None,
diversity_threshold: float = 0.8,
filter_filenames: Optional[List[str]] = None,
) -> List[RetrievedChunk]:
"""
Retrieve chunks with diversity filtering to avoid redundant results.
Uses maximal marginal relevance (MMR) to balance relevance and diversity.
Args:
query: User query text
top_k: Number of diverse results to return
diversity_threshold: Similarity threshold for diversity filtering
filter_filenames: Optional list of filenames to filter results
Returns:
List[RetrievedChunk]: Diverse retrieved chunks
"""
k = top_k or self.top_k
# Retrieve more candidates for diversity filtering
candidates = self.retrieve(query, top_k=k * 3, filter_filenames=filter_filenames)
if not candidates:
return []
# Apply simple diversity filtering based on text similarity
diverse_chunks = [candidates[0]] # Always include top result
for candidate in candidates[1:]:
if len(diverse_chunks) >= k:
break
# Check if candidate is sufficiently different from selected chunks
is_diverse = True
for selected in diverse_chunks:
# Simple text overlap check
overlap = self._text_overlap(candidate.text, selected.text)
if overlap > diversity_threshold:
is_diverse = False
break
if is_diverse:
diverse_chunks.append(candidate)
logger.debug(f"Diversity filtering: {len(candidates)} -> {len(diverse_chunks)} chunks")
return diverse_chunks
def _text_overlap(self, text1: str, text2: str) -> float:
"""
Calculate simple text overlap ratio.
Args:
text1: First text
text2: Second text
Returns:
float: Overlap ratio (0-1)
"""
words1 = set(text1.lower().split())
words2 = set(text2.lower().split())
if not words1 or not words2:
return 0.0
intersection = len(words1 & words2)
union = len(words1 | words2)
return intersection / union if union > 0 else 0.0
def get_stats(self) -> dict:
"""Get retriever statistics."""
return {
"vector_store": self.vector_store.get_collection_stats(),
"embedding_model": self.embedder.model_name,
"top_k": self.top_k,
"score_threshold": self.score_threshold,
}