import time
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass

from .index import VectorStore


# Maximum number of characters of each hit's content shown in formatted output.
_SNIPPET_CHARS = 200


def _snippet(text: str, limit: int = _SNIPPET_CHARS) -> str:
    """Return the first ``limit`` characters of ``text``.

    "..." is appended only when the text was actually truncated, so short
    documents are not displayed as if they had been cut off.
    """
    return text if len(text) <= limit else f"{text[:limit]}..."


@dataclass
class RetrievalResult:
    """Result from a retrieval pipeline.

    Attributes:
        content: Human-readable, numbered summary of the retrieved hits.
        sources: The raw hit dicts returned by ``VectorStore.search``.
        latency: Seconds elapsed during retrieval (``time.perf_counter`` delta).
        metadata: Pipeline details such as the strategy name, ``k`` and filters.
    """

    content: str
    sources: List[Dict[str, Any]]
    latency: float
    metadata: Dict[str, Any]


class BaseRAG:
    """Standard RAG pipeline: plain vector similarity, no metadata filtering."""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Bind the pipeline to a vector store and the collection to search."""
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
        """Retrieve the top-``k`` documents for ``query`` by vector similarity.

        Args:
            query: Free-text query passed to the vector store.
            k: Number of results requested from the vector store.

        Returns:
            A ``RetrievalResult`` holding the formatted hits, the raw hits,
            the search latency and ``{"strategy": "base_rag", "k": k}``.
        """
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with system clock adjustments and skew (even negate) the latency.
        start_time = time.perf_counter()
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            k=k,
        )
        latency = time.perf_counter() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={"strategy": "base_rag", "k": k},
        )

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format hits as numbered snippets with scores, separated by blank lines.

        Each hit is read via its ``"content"`` and ``"score"`` keys (presumably
        the shape ``VectorStore.search`` returns — confirm in ``.index``).
        """
        return "\n\n".join(
            f"[{i}] {_snippet(result['content'])}\n(Score: {result['score']:.3f})"
            for i, result in enumerate(results, 1)
        )


class HierarchicalRAG:
    """Hierarchical RAG pipeline that narrows the search with metadata filters."""

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Bind the pipeline to a vector store and the collection to search."""
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5,
                 level1: Optional[str] = None,
                 level2: Optional[str] = None,
                 level3: Optional[str] = None,
                 doc_type: Optional[str] = None) -> RetrievalResult:
        """Retrieve the top-``k`` documents for ``query``, filtered by hierarchy.

        Args:
            query: Free-text query passed to the vector store.
            k: Number of results requested from the vector store.
            level1, level2, level3: Optional hierarchy values; each one that is
                set (truthy) becomes an exact-match metadata filter.
            doc_type: Optional document-type filter, applied the same way.

        Returns:
            A ``RetrievalResult`` whose metadata records the strategy
            (``"hier_rag"``), ``k`` and the filters actually sent (or ``None``).
        """
        # Monotonic clock for interval timing (see BaseRAG.retrieve).
        start_time = time.perf_counter()

        filters = self._build_filters(level1, level2, level3, doc_type)
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=filters,
            k=k,
        )
        latency = time.perf_counter() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={
                "strategy": "hier_rag",
                "k": k,
                "filters": filters,
            },
        )

    def _build_filters(self, level1: Optional[str], level2: Optional[str],
                       level3: Optional[str],
                       doc_type: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build the metadata filter dict for a hierarchical search.

        Only truthy arguments are included (``None`` and ``""`` are skipped),
        in the order level1, level2, level3, doc_type.

        Returns:
            The filter dict, or ``None`` when no filter was given, so the
            search runs unfiltered.
        """
        candidates = {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": doc_type,
        }
        filters = {key: value for key, value in candidates.items() if value}
        # NOTE(review): if VectorStore passes this dict unchanged to Chroma as a
        # `where` clause, more than one key must be wrapped in an explicit
        # {"$and": [...]} — confirm how `.index.VectorStore.search` handles it.
        return filters or None

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format hits as numbered snippets with their hierarchy path and score.

        Missing hierarchy levels are shown as ``N/A``; a hit without a
        ``"metadata"`` entry (or with ``None``) is treated as having none.
        """
        formatted = []
        for i, result in enumerate(results, 1):
            # Guard against absent/None metadata instead of raising KeyError/TypeError.
            metadata = result.get("metadata") or {}
            formatted.append(
                f"[{i}] {_snippet(result['content'])}\n"
                f" Domain: {metadata.get('level1', 'N/A')} > "
                f"{metadata.get('level2', 'N/A')} > "
                f"{metadata.get('level3', 'N/A')}\n"
                f" Score: {result['score']:.3f}"
            )
        return "\n\n".join(formatted)


class RAGManager:
    """Runs Base-RAG and Hier-RAG against one shared vector store for comparison."""

    def __init__(self, persist_directory: str = "/data/chroma"):
        """Open the vector store at ``persist_directory`` and build both pipelines."""
        self.vector_store = VectorStore(persist_directory)
        self.base_rag = BaseRAG(self.vector_store)
        self.hier_rag = HierarchicalRAG(self.vector_store)

    def compare_retrieval(self, query: str, k: int = 5,
                          level1: Optional[str] = None,
                          level2: Optional[str] = None,
                          level3: Optional[str] = None,
                          doc_type: Optional[str] = None
                          ) -> Tuple[RetrievalResult, RetrievalResult]:
        """Run the same query through Base-RAG and Hier-RAG.

        The hierarchy/doc-type arguments only affect the hierarchical pipeline.

        Returns:
            ``(base_result, hier_result)``.
        """
        base_result = self.base_rag.retrieve(query, k)
        # Keyword arguments keep this call correct if retrieve()'s optional
        # parameters are ever reordered.
        hier_result = self.hier_rag.retrieve(
            query,
            k,
            level1=level1,
            level2=level2,
            level3=level3,
            doc_type=doc_type,
        )
        return base_result, hier_result