Spaces:
Sleeping
Sleeping
| import time | |
| from typing import List, Dict, Any, Optional, Tuple | |
| from dataclasses import dataclass | |
| from .index import VectorStore | |
@dataclass
class RetrievalResult:
    """Result from a retrieval pipeline.

    Instances are created with keyword arguments by the pipelines'
    ``retrieve`` methods, so the class must be a dataclass to get the
    generated ``__init__`` (plus ``__repr__`` and ``__eq__``).

    Attributes:
        content: Human-readable text built from the retrieved results.
        sources: The raw result records returned by the vector store search.
        latency: Time spent in the search call, in seconds.
        metadata: Details of the run, such as the strategy name, ``k`` and
            any filters that were applied.
    """

    content: str
    sources: List[Dict[str, Any]]
    latency: float
    metadata: Dict[str, Any]
class BaseRAG:
    """Standard RAG pipeline without hierarchical filtering."""

    # Maximum number of characters of each result's content shown in the
    # formatted text.
    _PREVIEW_CHARS = 200

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Remember which store and collection to search.

        Args:
            vector_store: Store whose ``search`` method performs the lookup.
            collection_name: Name of the collection passed to every search.
        """
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
        """Retrieve documents using standard vector similarity.

        Args:
            query: Text to search for.
            k: Number of results requested from the vector store.

        Returns:
            A ``RetrievalResult`` holding the formatted text, the raw search
            results, the search latency in seconds and the run metadata
            (``strategy`` and ``k``).
        """
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot go negative if the wall clock is adjusted.
        start_time = time.perf_counter()
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            k=k,
        )
        # Only the search call is timed; formatting happens afterwards.
        latency = time.perf_counter() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={"strategy": "base_rag", "k": k},
        )

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieval results into numbered text entries.

        Each result must provide a ``content`` string and a numeric ``score``.
        Entries are numbered from 1 and separated by a blank line.

        Args:
            results: Result records as returned by the vector store search.

        Returns:
            The formatted text, or an empty string when ``results`` is empty.
        """
        limit = self._PREVIEW_CHARS
        formatted = []
        for i, result in enumerate(results, 1):
            text = result["content"]
            # Add the ellipsis only when the content was actually cut short.
            preview = f"{text[:limit]}..." if len(text) > limit else text
            formatted.append(f"[{i}] {preview} (Score: {result['score']:.3f})")
        return "\n\n".join(formatted)
class HierarchicalRAG:
    """Hierarchical RAG pipeline with metadata filtering."""

    # Maximum number of characters of each result's content shown in the
    # formatted text.
    _PREVIEW_CHARS = 200

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Remember which store and collection to search.

        Args:
            vector_store: Store whose ``search`` method performs the lookup.
            collection_name: Name of the collection passed to every search.
        """
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5,
                 level1: Optional[str] = None,
                 level2: Optional[str] = None,
                 level3: Optional[str] = None,
                 doc_type: Optional[str] = None) -> RetrievalResult:
        """Retrieve documents with hierarchical filtering.

        Args:
            query: Text to search for.
            k: Number of results requested from the vector store.
            level1: Optional value for the ``level1`` metadata filter.
            level2: Optional value for the ``level2`` metadata filter.
            level3: Optional value for the ``level3`` metadata filter.
            doc_type: Optional value for the ``doc_type`` metadata filter.

        Returns:
            A ``RetrievalResult`` holding the formatted text, the raw search
            results, the latency in seconds and the run metadata
            (``strategy``, ``k`` and the ``filters`` that were applied, which
            is ``None`` when no filter value was given).
        """
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot go negative if the wall clock is adjusted.
        start_time = time.perf_counter()

        # Build metadata filters
        filters = self._build_filters(level1, level2, level3, doc_type)

        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=filters,
            k=k,
        )
        # Formatting is not included in the reported latency.
        latency = time.perf_counter() - start_time

        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={
                "strategy": "hier_rag",
                "k": k,
                "filters": filters,
            },
        )

    def _build_filters(self, level1: Optional[str], level2: Optional[str],
                       level3: Optional[str],
                       doc_type: Optional[str]) -> Optional[Dict[str, Any]]:
        """Build metadata filters for hierarchical search.

        Only truthy values become filters, so ``None`` and empty strings are
        ignored.

        Returns:
            A dict keyed by ``level1``/``level2``/``level3``/``doc_type`` for
            the values that were supplied, or ``None`` when none were.
        """
        candidates = {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": doc_type,
        }
        filters = {key: value for key, value in candidates.items() if value}
        return filters or None

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Format retrieval results into numbered text entries.

        Each result must provide a ``content`` string and a numeric ``score``.
        The hierarchy levels are read from the result's ``metadata`` and shown
        as ``N/A`` when absent. Entries are separated by a blank line.

        Args:
            results: Result records as returned by the vector store search.

        Returns:
            The formatted text, or an empty string when ``results`` is empty.
        """
        limit = self._PREVIEW_CHARS
        formatted = []
        for i, result in enumerate(results, 1):
            # A record without metadata (missing key or None) is displayed
            # with N/A levels instead of raising.
            metadata = result.get("metadata") or {}
            text = result["content"]
            # Add the ellipsis only when the content was actually cut short.
            preview = f"{text[:limit]}..." if len(text) > limit else text
            formatted.append(
                f"[{i}] {preview}\n"
                f" Domain: {metadata.get('level1', 'N/A')} > "
                f"{metadata.get('level2', 'N/A')} > "
                f"{metadata.get('level3', 'N/A')}\n"
                f" Score: {result['score']:.3f}"
            )
        return "\n\n".join(formatted)
class RAGManager:
    """Manager for both RAG pipelines."""

    def __init__(self, persist_directory: str = "/data/chroma",
                 collection_name: str = "documents"):
        """Create one vector store and both pipelines on top of it.

        Args:
            persist_directory: Directory handed to ``VectorStore``.
            collection_name: Collection searched by both pipelines. The
                default matches the pipelines' own default, so existing
                callers are unaffected.
        """
        # Both pipelines share the same store instance and collection so that
        # their results are directly comparable.
        self.vector_store = VectorStore(persist_directory)
        self.base_rag = BaseRAG(self.vector_store, collection_name)
        self.hier_rag = HierarchicalRAG(self.vector_store, collection_name)

    def compare_retrieval(self, query: str, k: int = 5,
                          level1: Optional[str] = None,
                          level2: Optional[str] = None,
                          level3: Optional[str] = None,
                          doc_type: Optional[str] = None) -> Tuple[RetrievalResult, RetrievalResult]:
        """Compare Base-RAG vs Hier-RAG on the same query.

        Args:
            query: Text to search for with both pipelines.
            k: Number of results requested from each pipeline.
            level1: Optional ``level1`` filter, used by Hier-RAG only.
            level2: Optional ``level2`` filter, used by Hier-RAG only.
            level3: Optional ``level3`` filter, used by Hier-RAG only.
            doc_type: Optional ``doc_type`` filter, used by Hier-RAG only.

        Returns:
            A ``(base_result, hier_result)`` tuple; the base pipeline runs
            first.
        """
        base_result = self.base_rag.retrieve(query, k)
        hier_result = self.hier_rag.retrieve(query, k, level1, level2, level3, doc_type)
        return base_result, hier_result