| """ |
| Document retrieval with semantic search. |
| |
| This module provides retrieval functionality using ChromaDB vector search |
| with support for filtering and relevance scoring. |
| """ |
|
|
| from typing import List, Optional |
| from dataclasses import dataclass, field |
| import json |
| import numpy as np |
| from src.config.settings import get_settings |
| from src.embedding.embedder import Embedder |
| from src.embedding.vector_store import VectorStore |
| from src.utils.logging import get_logger, log_retrieval |
| import time |
|
|
# Module-level logger keyed to this module's import path (src.utils.logging wrapper).
logger = get_logger(__name__)
|
|
|
|
@dataclass
class RetrievedChunk:
    """A chunk returned by vector search, annotated with relevance info.

    Fields mirror the metadata stored alongside each chunk in the vector
    store, plus the similarity score computed at query time.
    """

    chunk_id: str  # unique id of the chunk in the vector store
    text: str  # raw chunk text
    filename: str  # source document filename
    document_id: str  # id of the parent document
    score: float  # query-time relevance score (higher is better)
    token_count: int  # token count recorded for the chunk
    chunk_index: int  # position of the chunk within its document
    page_numbers: List[int]  # pages the chunk spans (may be empty)
    metadata: dict = field(default_factory=dict)  # extra per-chunk metadata

    def to_dict(self) -> dict:
        """Return a plain-dict representation of this chunk."""
        keys = (
            "chunk_id",
            "text",
            "filename",
            "document_id",
            "score",
            "token_count",
            "chunk_index",
            "page_numbers",
            "metadata",
        )
        return {key: getattr(self, key) for key in keys}

    @property
    def source_type(self) -> str:
        """Source type of the chunk ('local' when not recorded in metadata)."""
        return self.metadata.get("source_type", "local")

    @property
    def url(self) -> Optional[str]:
        """URL for web/scientific sources, or None when none was recorded."""
        return self.metadata.get("url")
|
|
|
|
class Retriever:
    """Retrieve relevant document chunks for a query via vector search."""

    def __init__(self):
        """Initialize retriever with embedder, vector store, and settings-derived defaults."""
        settings = get_settings()
        self.embedder = Embedder()
        self.vector_store = VectorStore()
        # Per-call defaults; callers may override top_k per query.
        self.top_k = settings.top_k_retrieval
        self.score_threshold = settings.retrieval_score_threshold

    def retrieve(
        self,
        query: str,
        top_k: Optional[int] = None,
        filter_filename: Optional[str] = None,
        filter_filenames: Optional[List[str]] = None,
    ) -> List["RetrievedChunk"]:
        """
        Retrieve relevant chunks for a query.

        Args:
            query: User query text
            top_k: Number of results to return (default from settings)
            filter_filename: Optional single filename to filter results (deprecated, use filter_filenames)
            filter_filenames: Optional list of filenames to filter results

        Returns:
            List[RetrievedChunk]: Retrieved chunks sorted by relevance
        """
        start_time = time.time()
        # `is not None` so an explicit top_k=0 is honored; the previous
        # `top_k or self.top_k` silently replaced any falsy value with
        # the default.
        k = top_k if top_k is not None else self.top_k

        logger.debug(f"Retrieving chunks for query: {query[:100]}...")

        # Back-compat: fold the deprecated single-filename filter into the
        # list form; an explicit filter_filenames always wins.
        filenames_filter = filter_filenames
        if filter_filename and not filenames_filter:
            filenames_filter = [filter_filename]

        query_embedding = self.embedder.encode_single(query, is_query=True)

        results = self.vector_store.query(
            query_embedding,
            top_k=k,
            filter_filenames=filenames_filter,
        )

        chunks: List["RetrievedChunk"] = []

        # Chroma-style result layout: each field is a list with one entry
        # per query embedding; we issue a single query, hence index [0].
        if results and results.get('ids') and len(results['ids']) > 0:
            ids = results['ids'][0]
            documents = results['documents'][0]
            metadatas = results['metadatas'][0]
            distances = results['distances'][0]

            for chunk_id, text, metadata, distance in zip(
                ids, documents, metadatas, distances
            ):
                # Map distance (0 = identical) to a similarity-style score
                # in (0, 1]; larger score means more relevant.
                score = 1.0 / (1.0 + distance)

                # Drop results below the configured relevance floor.
                if score < self.score_threshold:
                    continue

                chunk = RetrievedChunk(
                    chunk_id=chunk_id,
                    text=text,
                    filename=metadata.get('filename', 'unknown'),
                    document_id=metadata.get('document_id', ''),
                    score=score,
                    token_count=metadata.get('token_count', 0),
                    chunk_index=metadata.get('chunk_index', 0),
                    page_numbers=self._parse_page_numbers(
                        metadata.get('page_numbers', '[]')
                    ),
                )
                chunks.append(chunk)

        duration_ms = (time.time() - start_time) * 1000
        log_retrieval(logger, query, len(chunks), duration_ms)

        return chunks

    @staticmethod
    def _parse_page_numbers(raw) -> List[int]:
        """
        Decode the 'page_numbers' metadata value.

        The value is stored as a JSON-encoded string but may already be a
        list (depending on how the metadata was written). Any malformed or
        unexpected value yields an empty list instead of leaking a non-list
        (e.g. None) into RetrievedChunk.page_numbers.

        Args:
            raw: Raw metadata value (JSON string, list, or anything else)

        Returns:
            List[int]: Decoded page numbers, or [] on bad data
        """
        if isinstance(raw, str):
            try:
                decoded = json.loads(raw)
            except (json.JSONDecodeError, TypeError):
                return []
            return decoded if isinstance(decoded, list) else []
        return raw if isinstance(raw, list) else []

    def retrieve_with_diversity(
        self,
        query: str,
        top_k: Optional[int] = None,
        diversity_threshold: float = 0.8,
        filter_filenames: Optional[List[str]] = None,
    ) -> List["RetrievedChunk"]:
        """
        Retrieve chunks with diversity filtering to avoid redundant results.

        Greedily selects candidates in relevance order, skipping any chunk
        whose word-level Jaccard similarity with an already-selected chunk
        exceeds diversity_threshold (a simple approximation of maximal
        marginal relevance).

        Args:
            query: User query text
            top_k: Number of diverse results to return
            diversity_threshold: Similarity threshold for diversity filtering
            filter_filenames: Optional list of filenames to filter results

        Returns:
            List[RetrievedChunk]: Diverse retrieved chunks
        """
        # `is not None` so an explicit top_k=0 is honored (see retrieve()).
        k = top_k if top_k is not None else self.top_k

        # Over-fetch so diversity filtering still has enough candidates
        # left after duplicates are dropped.
        candidates = self.retrieve(query, top_k=k * 3, filter_filenames=filter_filenames)

        if not candidates:
            return []

        # The most relevant chunk is always kept.
        diverse_chunks = [candidates[0]]

        for candidate in candidates[1:]:
            if len(diverse_chunks) >= k:
                break

            # Keep the candidate only if it is dissimilar to everything
            # already selected.
            is_diverse = True
            for selected in diverse_chunks:
                overlap = self._text_overlap(candidate.text, selected.text)
                if overlap > diversity_threshold:
                    is_diverse = False
                    break

            if is_diverse:
                diverse_chunks.append(candidate)

        logger.debug(f"Diversity filtering: {len(candidates)} -> {len(diverse_chunks)} chunks")
        return diverse_chunks

    def _text_overlap(self, text1: str, text2: str) -> float:
        """
        Calculate word-level Jaccard similarity between two texts.

        Args:
            text1: First text
            text2: Second text

        Returns:
            float: |words1 & words2| / |words1 | words2|, in [0, 1];
                0.0 when either text has no words
        """
        words1 = set(text1.lower().split())
        words2 = set(text2.lower().split())

        if not words1 or not words2:
            return 0.0

        intersection = len(words1 & words2)
        union = len(words1 | words2)

        return intersection / union if union > 0 else 0.0

    def get_stats(self) -> dict:
        """Get retriever statistics (store stats plus configured defaults)."""
        return {
            "vector_store": self.vector_store.get_collection_stats(),
            "embedding_model": self.embedder.model_name,
            "top_k": self.top_k,
            "score_threshold": self.score_threshold,
        }
|
|