soft.engineer
init project
e71fabd
import time
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from .index import VectorStore
@dataclass
class RetrievalResult:
    """Outcome of one retrieval call, returned by every pipeline in this module.

    Fields are positional, in the order declared below.
    """

    # Formatted text summary of the retrieved hits.
    content: str
    # Raw hits as returned by the vector store's ``search`` call.
    sources: List[Dict[str, Any]]
    # Seconds spent retrieving.
    latency: float
    # Pipeline-specific details (strategy name, ``k``, filters, ...).
    metadata: Dict[str, Any]
class BaseRAG:
    """Standard RAG pipeline: plain vector-similarity search, no metadata filtering."""

    # Number of characters of each hit's content shown in the formatted output.
    snippet_chars: int = 200

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Bind the pipeline to one collection of a vector store.

        Args:
            vector_store: Store whose ``search`` method performs the query.
            collection_name: Name of the collection to query.
        """
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5) -> RetrievalResult:
        """Retrieve up to ``k`` documents for ``query`` by vector similarity.

        Args:
            query: Free-text query forwarded to the vector store.
            k: Number of results requested from the store.

        Returns:
            A ``RetrievalResult`` holding the formatted hits, the raw hits,
            the elapsed search time in seconds, and strategy metadata.
        """
        # perf_counter is monotonic and high-resolution; time.time() follows the
        # system clock, which can be adjusted mid-call and yield bogus latencies.
        start = time.perf_counter()
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            k=k,
        )
        latency = time.perf_counter() - start
        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={"strategy": "base_rag", "k": k},
        )

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Render hits as numbered snippets separated by blank lines.

        Each hit must provide ``content`` (str) and ``score`` (number) keys.
        Content longer than ``snippet_chars`` is cut and marked with ``...``;
        shorter content is shown in full without the marker.
        """
        blocks = []
        for rank, hit in enumerate(results, 1):
            text = hit["content"]
            snippet = text[: self.snippet_chars]
            if len(text) > self.snippet_chars:
                # Only signal truncation when something was actually cut off.
                snippet += "..."
            blocks.append(f"[{rank}] {snippet} (Score: {hit['score']:.3f})")
        return "\n\n".join(blocks)
class HierarchicalRAG:
    """Hierarchical RAG pipeline: vector search narrowed by metadata filters.

    Up to three taxonomy levels (``level1`` > ``level2`` > ``level3``) and a
    document type may be supplied; the non-empty ones are passed to the vector
    store's ``search`` call as a ``filters`` mapping.
    """

    # Number of characters of each hit's content shown in the formatted output.
    snippet_chars: int = 200

    def __init__(self, vector_store: VectorStore, collection_name: str = "documents"):
        """Bind the pipeline to one collection of a vector store.

        Args:
            vector_store: Store whose ``search`` method performs the query.
            collection_name: Name of the collection to query.
        """
        self.vector_store = vector_store
        self.collection_name = collection_name

    def retrieve(self, query: str, k: int = 5,
                 level1: Optional[str] = None,
                 level2: Optional[str] = None,
                 level3: Optional[str] = None,
                 doc_type: Optional[str] = None) -> RetrievalResult:
        """Retrieve up to ``k`` documents for ``query`` under optional filters.

        Args:
            query: Free-text query forwarded to the vector store.
            k: Number of results requested from the store.
            level1, level2, level3: Optional hierarchy values to filter on;
                falsy values (``None``, ``""``) are ignored.
            doc_type: Optional document-type value to filter on; falsy is ignored.

        Returns:
            A ``RetrievalResult`` holding the formatted hits, the raw hits, the
            elapsed time in seconds, and metadata including the filters used
            (``None`` when no filter was given).
        """
        # perf_counter is monotonic; time.time() can jump with clock adjustments.
        start = time.perf_counter()
        filters = self._build_filters(level1, level2, level3, doc_type)
        results = self.vector_store.search(
            collection_name=self.collection_name,
            query=query,
            filters=filters,
            k=k,
        )
        latency = time.perf_counter() - start
        return RetrievalResult(
            content=self._format_results(results),
            sources=results,
            latency=latency,
            metadata={
                "strategy": "hier_rag",
                "k": k,
                "filters": filters,
            },
        )

    def _build_filters(self, level1: Optional[str], level2: Optional[str],
                       level3: Optional[str], doc_type: Optional[str]) -> Optional[Dict[str, Any]]:
        """Collect the truthy filter arguments into a dict, or ``None`` if none are set.

        Key order is level1, level2, level3, doc_type.
        """
        # NOTE(review): whether several keys are combined with AND or OR is decided
        # by VectorStore.search, which is not visible here — confirm there.
        candidates = {
            "level1": level1,
            "level2": level2,
            "level3": level3,
            "doc_type": doc_type,
        }
        filters = {key: value for key, value in candidates.items() if value}
        return filters or None

    def _format_results(self, results: List[Dict[str, Any]]) -> str:
        """Render hits as numbered snippets with their hierarchy path and score.

        Each hit must provide ``content`` (str) and ``score`` (number) keys; a
        missing or null ``metadata`` entry is treated as empty, so every level
        shows ``N/A``. Content longer than ``snippet_chars`` is cut and marked
        with ``...``.
        """
        blocks = []
        for rank, hit in enumerate(results, 1):
            # Hits without metadata would otherwise raise KeyError/AttributeError.
            meta = hit.get("metadata") or {}
            text = hit["content"]
            snippet = text[: self.snippet_chars]
            if len(text) > self.snippet_chars:
                # Only signal truncation when something was actually cut off.
                snippet += "..."
            blocks.append(
                f"[{rank}] {snippet}\n"
                f" Domain: {meta.get('level1', 'N/A')} > "
                f"{meta.get('level2', 'N/A')} > "
                f"{meta.get('level3', 'N/A')}\n"
                f" Score: {hit['score']:.3f}"
            )
        return "\n\n".join(blocks)
class RAGManager:
    """Owns a single vector store and exposes both retrieval pipelines over it.

    ``base_rag`` and ``hier_rag`` are built on the same ``VectorStore`` instance,
    so a side-by-side comparison queries the same underlying data.
    """

    def __init__(self, persist_directory: str = "/data/chroma"):
        """Open the vector store at ``persist_directory`` and wire up both pipelines."""
        store = VectorStore(persist_directory)
        self.vector_store = store
        self.base_rag = BaseRAG(store)
        self.hier_rag = HierarchicalRAG(store)

    def compare_retrieval(self, query: str, k: int = 5,
                          level1: Optional[str] = None,
                          level2: Optional[str] = None,
                          level3: Optional[str] = None,
                          doc_type: Optional[str] = None) -> Tuple[RetrievalResult, RetrievalResult]:
        """Run one query through Base-RAG (unfiltered) and then Hier-RAG (filtered).

        Returns:
            ``(base_result, hier_result)``, in that order.
        """
        unfiltered = self.base_rag.retrieve(query, k)
        filtered = self.hier_rag.retrieve(
            query,
            k,
            level1=level1,
            level2=level2,
            level3=level3,
            doc_type=doc_type,
        )
        return unfiltered, filtered