| """ | |
| core/retriever.py - Document Retrieval with Optional Reranking | |
| ============================================================== | |
| Retrieves relevant chunks from vector store with optional reranking. | |
| - Vector similarity search via VectorStore (ChromaDB) | |
| - Optional reranking via Azure Cohere rerank API (v2) | |
| """ | |
| from typing import List, Dict, Optional | |
| from dataclasses import dataclass | |
| import time | |
| import os | |
| import httpx # For Azure Cohere rerank API | |
| # LLM clients (kept for future, not used directly here) | |
| from openai import OpenAI # For completeness, not used in rerank now | |
| # Our components | |
| try: | |
| from .vector_store import VectorStore | |
| except ImportError: | |
| from vector_store import VectorStore | |
| # Config | |
| from dotenv import load_dotenv | |
| load_dotenv() | |

@dataclass
class RetrievalResult:
    """Container for retrieval result with metadata"""
    chunks: List[str]
    scores: List[float]
    metadata: List[Dict]
    retrieval_time_ms: float
    reranking_time_ms: float
    total_chunks_retrieved: int
    reranked: bool
    rerank_cost_usd: float = 0.0


class Retriever:
    """Document retriever with optional reranking"""

    # Cohere rerank pricing: ~1 USD per 1M tokens (0.001 per 1K)
    RERANK_COST_PER_1K_TOKENS = 0.001
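
    # Rough worked example of the cost estimate in retrieve() (illustrative
    # numbers only; actual billing depends on the provider): a query plus 15
    # candidate chunks totaling ~2,000 words is ~2,600 tokens (word count * 1.3),
    # so one rerank call costs about 2.6 * 0.001 = $0.0026 at the rate above.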

    def __init__(
        self,
        vector_store: VectorStore,
        top_k: int = 5,
        use_reranking: bool = False,
        rerank_top_n: int = 3,
    ):
        """
        Initialize retriever

        Args:
            vector_store: VectorStore instance
            top_k: Number of chunks to retrieve initially
            use_reranking: Whether to rerank results (uses Azure Cohere)
            rerank_top_n: Number of top chunks to return after reranking
        """
        self.vector_store = vector_store
        self.top_k = top_k
        self.use_reranking = use_reranking
        self.rerank_top_n = rerank_top_n

        # Initialize Azure Cohere rerank config
        if use_reranking:
            self.azure_rerank_endpoint = os.getenv("AZURE_COHERE_RERANK_ENDPOINT")
            self.azure_rerank_key = os.getenv("AZURE_COHERE_RERANK_KEY")
            # Validate both values up front; the endpoint is used in the POST call
            if not self.azure_rerank_endpoint or not self.azure_rerank_key:
                raise RuntimeError(
                    "Reranking requested but Azure Cohere rerank is not configured. "
                    "Set AZURE_COHERE_RERANK_ENDPOINT and AZURE_COHERE_RERANK_KEY in your .env."
                )
            print(f"✅ Cohere reranking enabled via Azure AI Foundry (top-{self.rerank_top_n})")

    def retrieve(
        self,
        query: str,
        filter: Optional[Dict] = None,
    ) -> RetrievalResult:
        """
        Retrieve relevant chunks for query

        Args:
            query: User question
            filter: Optional metadata filter

        Returns:
            RetrievalResult with chunks and metadata
        """
        # Step 1: Vector similarity search
        start_time = time.time()
        results = self.vector_store.similarity_search_with_score(
            query=query,
            # Over-fetch 3x candidates when reranking so the reranker has more to choose from
            k=self.top_k * 3 if self.use_reranking else self.top_k,
            filter=filter,
        )
        retrieval_time_ms = (time.time() - start_time) * 1000

        # Extract chunks, scores, metadata
        chunks = [r["text"] for r in results]
        scores = [r["score"] for r in results]
        metadata = [r["metadata"] for r in results]

        reranking_time_ms = 0.0
        reranked = False
        rerank_cost = 0.0

        # Step 2: Optional reranking via Azure Cohere
        if self.use_reranking and len(chunks) > 0:
            start_rerank = time.time()

            # Estimate rerank cost (query + all candidate chunks)
            total_text = query + " ".join(chunks)
            estimated_tokens = len(total_text.split()) * 1.3  # rough word-to-token factor
            rerank_cost = (estimated_tokens / 1000) * self.RERANK_COST_PER_1K_TOKENS

            # Call Azure Cohere rerank v2 API
            try:
                reranked_chunks, reranked_scores, reranked_metadata = self._rerank_with_azure_cohere(
                    query=query,
                    chunks=chunks,
                    metadatas=metadata,
                )
                # Truncate to rerank_top_n
                chunks = reranked_chunks[: self.rerank_top_n]
                scores = reranked_scores[: self.rerank_top_n]
                metadata = reranked_metadata[: self.rerank_top_n]
                reranked = True
            except Exception as e:
                # Hard fail: no fallback to direct Cohere API; chain the
                # original exception so the traceback is preserved
                raise RuntimeError(
                    f"Azure Cohere reranking failed or is not configured correctly: {e}"
                ) from e

            reranking_time_ms = (time.time() - start_rerank) * 1000

        return RetrievalResult(
            chunks=chunks,
            scores=scores,
            metadata=metadata,
            retrieval_time_ms=retrieval_time_ms,
            reranking_time_ms=reranking_time_ms,
            total_chunks_retrieved=len(chunks),
            reranked=reranked,
            rerank_cost_usd=rerank_cost,
        )

    def _rerank_with_azure_cohere(
        self,
        query: str,
        chunks: List[str],
        metadatas: List[Dict],
    ):
        """
        Rerank chunks using the Azure Cohere rerank v2 API.
        Model: cohere-rerank-v4.0-fast
        """
        payload = {
            "model": "cohere-rerank-v4.0-fast",
            "query": query,
            "documents": chunks,
            "top_n": self.rerank_top_n,
            "return_documents": True,
        }
        with httpx.Client() as client:
            response = client.post(
                self.azure_rerank_endpoint,
                headers={
                    "Authorization": f"Bearer {self.azure_rerank_key}",
                    "Content-Type": "application/json",
                },
                json=payload,
                timeout=10.0,
            )
            response.raise_for_status()
            data = response.json()
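
        # Assumed response shape, following Cohere's rerank API (exact fields
        # may differ on Azure - verify against your deployment's docs):
        #   {"results": [{"index": 2, "relevance_score": 0.97,
        #                 "document": {"text": "..."}}, ...]}
        # Each "index" refers back into the candidate list we sent.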
        reranked_chunks: List[str] = []
        reranked_scores: List[float] = []
        reranked_metadatas: List[Dict] = []
        for result in data.get("results", []):
            idx = result.get("index")
            reranked_chunks.append(chunks[idx])
            reranked_scores.append(result.get("relevance_score", 0.0))
            reranked_metadatas.append(metadatas[idx])
        return reranked_chunks, reranked_scores, reranked_metadatas

    def get_context_string(self, result: RetrievalResult) -> str:
        """Format retrieved chunks as context string for LLM"""
        return "\n\n".join(
            f"[{i+1}] {chunk}" for i, chunk in enumerate(result.chunks)
        )
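
    # Illustrative get_context_string() output for a three-chunk result:
    #   [1] first chunk text...
    #
    #   [2] second chunk text...
    #
    #   [3] third chunk text...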


# ============================================================================
# USAGE EXAMPLE
# ============================================================================
if __name__ == "__main__":
    print("🔍 Retriever Test")
    print("=" * 80)

    # Initialize embedder
    print("\n1️⃣ Initializing embedder...")
    from embedder import Embedder

    embedder = Embedder(
        provider="sentence-transformers",
        model_name="all-MiniLM-L6-v2",
    )
    print(f"   ✅ Embedder ready: {embedder.model_name}")

    # Initialize vector store
    print("\n2️⃣ Initializing vector store...")
    from vector_store import VectorStore

    vector_store = VectorStore(
        collection_name="retriever_test",
        embedder=embedder,
    )

    # Add sample documents
    print("\n3️⃣ Adding sample documents...")
    from chunker import Chunker

    sample_docs = [
        "The RAG Pipeline Optimizer uses wiki_dpr and Natural Questions datasets for evaluation.",
        "Six different pipeline configurations are tested simultaneously with varying chunk sizes.",
        "Azure OpenAI GPT-5, Cohere, DeepSeek, Claude 4, and Groq Llama models are compared for cost and quality.",
        "Embeddings can be generated using local sentence-transformers (free) or Azure OpenAI (paid).",
        "ChromaDB provides local vector storage with persistence and fast similarity search.",
        "Pipeline A uses GPT-5 with 256 token chunks for speed.",
        "Pipeline B uses GPT-5 with 512 token chunks and Cohere reranking for maximum accuracy.",
        "Pipeline C uses Cohere Command with 512 token chunks for balanced performance.",
        "Pipeline D uses Claude 4 Sonnet with 1024 token chunks for complex reasoning.",
        "Pipeline E uses DeepSeek with 2048 token chunks for cost optimization.",
        "Pipeline F uses Groq Llama with 600 token chunks for experimental fast inference.",
    ]

    chunker = Chunker(chunk_size=100, chunk_overlap=0)
    chunks = []
    for i, doc in enumerate(sample_docs):
        doc_chunks = chunker.chunk(doc, strategy="recursive")
        for chunk in doc_chunks:
            chunk.metadata = {"doc_id": i, "source": f"doc_{i}"}
        chunks.extend(doc_chunks)

    vector_store.add_chunks(chunks, metadata={"test": "retriever_demo"})
    print(f"   ✅ Added {len(chunks)} chunks from {len(sample_docs)} documents")

    # Test queries
    test_queries = [
        "What models are used in the RAG optimizer?",
        "Which pipeline is cheapest?",
        "How does Pipeline B differ from Pipeline A?",
    ]

    print("\n" + "=" * 80)

    # Run each query through both configurations
    for idx, query in enumerate(test_queries, 1):
| print(f"\n📋 Query {idx}: '{query}'") | |
| print("-" * 80) | |
| print("\n4️⃣ Retrieval WITHOUT reranking:") | |
| retriever = Retriever( | |
| vector_store=vector_store, | |
| top_k=3, | |
| use_reranking=False, | |
| ) | |
| result = retriever.retrieve(query) | |
| print(f" ✅ Retrieved: {result.total_chunks_retrieved} chunks") | |
| print(f" ⏱️ Time: {result.retrieval_time_ms:.0f}ms") | |
| print(f" 🔄 Reranked: {result.reranked}") | |
| print(f"\n Top chunks:") | |
| for i, (chunk, score) in enumerate(zip(result.chunks, result.scores), 1): | |
| print(f" {i}. (score: {score:.4f}) {chunk[:100]}...") | |

        # Test retrieval WITH reranking (requires Azure Cohere env vars;
        # skip gracefully if they are not set)
        print("\n5️⃣ Retrieval WITH reranking:")
        if not os.getenv("AZURE_COHERE_RERANK_KEY"):
            print("   ⚠️ Skipping: AZURE_COHERE_RERANK_KEY not set")
        else:
            retriever_rerank = Retriever(
                vector_store=vector_store,
                top_k=5,
                use_reranking=True,
                rerank_top_n=3,
            )
            result_rerank = retriever_rerank.retrieve(query)
            print(f"   ✅ Retrieved: {result_rerank.total_chunks_retrieved} chunks (from {retriever_rerank.top_k} initial)")
            print(f"   ⏱️ Retrieval time: {result_rerank.retrieval_time_ms:.0f}ms")
            print(f"   ⏱️ Reranking time: {result_rerank.reranking_time_ms:.0f}ms")
            print(f"   💰 Rerank cost: ${result_rerank.rerank_cost_usd:.6f}")
            print(f"   🔄 Reranked: {result_rerank.reranked}")
            print("\n   Top reranked chunks:")
            for i, (chunk, score) in enumerate(zip(result_rerank.chunks, result_rerank.scores), 1):
                print(f"   {i}. (score: {score:.4f}) {chunk[:100]}...")

    # Test context string formatting
    print("\n" + "=" * 80)
    print("\n6️⃣ Testing context string formatting:")
    retriever = Retriever(vector_store=vector_store, top_k=3)
    result = retriever.retrieve(test_queries[0])
    context_str = retriever.get_context_string(result)
    print(f"\n{context_str}\n")

    # Cleanup
    print("\n7️⃣ Cleaning up test collection...")
    vector_store.delete_collection()

    print("\n" + "=" * 80)
    print("✅ Retriever test complete!")
    print("\n💡 Key features:")
    print("   - Vector similarity search: ~50-100ms")
    print("   - Azure Cohere reranking for Pipeline B (accuracy-focused)")
    print("   - Skip reranking for Pipelines A,C,D,E,F (speed/cost-focused)")
    print("\n🚀 Next: Build pipeline orchestrator to combine all components!")