| | """ |
| | Agentic RAG Orchestrator |
| | |
| | Coordinates the multi-agent RAG pipeline with self-correction loop. |
| | Follows FAANG best practices for production RAG systems. |
| | |
| | Pipeline: |
| | Query -> Plan -> Retrieve -> Rerank -> Synthesize -> Validate -> (Revise?) -> Response |
| | |
| | Key Features: |
| | - LangGraph-style state machine |
| | - Self-correction loop (up to N attempts) |
| | - Streaming support |
| | - Comprehensive logging and metrics |
| | - Graceful degradation |
| | """ |
| |
|
| | from typing import List, Optional, Dict, Any, Generator, Tuple |
| | from pydantic import BaseModel, Field |
| | from loguru import logger |
| | from dataclasses import dataclass, field |
| | from enum import Enum |
| | import time |
| |
|
| | from ..store import VectorStore, get_vector_store, VectorStoreConfig |
| | from ..embeddings import EmbeddingAdapter, get_embedding_adapter, EmbeddingConfig |
| |
|
| | from .query_planner import QueryPlannerAgent, QueryPlan, SubQuery |
| | from .retriever import RetrieverAgent, RetrievalResult, HybridSearchConfig |
| | from .reranker import RerankerAgent, RankedResult, RerankerConfig |
| | from .synthesizer import SynthesizerAgent, SynthesisResult, Citation, SynthesizerConfig |
| | from .critic import CriticAgent, CriticResult, ValidationIssue, CriticConfig |
| |
|
| |
|
class PipelineStage(str, Enum):
    """Named steps of the RAG pipeline.

    Because the enum mixes in ``str``, every member compares equal to its
    plain string value (e.g. ``PipelineStage.RETRIEVAL == "retrieval"``).
    """

    # Front half: understand the question and gather evidence.
    PLANNING = "planning"
    RETRIEVAL = "retrieval"
    RERANKING = "reranking"
    # Back half: produce the answer, check it, and optionally repair it.
    SYNTHESIS = "synthesis"
    VALIDATION = "validation"
    REVISION = "revision"
    # Terminal marker for a finished run.
    COMPLETE = "complete"
| |
|
| |
|
class RAGConfig(BaseModel):
    """Tunable settings for the agentic RAG orchestrator.

    Covers the shared LLM endpoint, per-stage feature toggles, retrieval
    depths, the synthesis confidence threshold, and logging verbosity.
    """

    # --- LLM endpoint shared by every agent ---
    model: str = Field("llama3.2:3b")
    base_url: str = Field("http://localhost:11434")

    # --- Pipeline toggles ---
    # How many revise-and-revalidate rounds may follow a failed validation.
    max_revision_attempts: int = Field(2, ge=0, le=5)
    # Forwarded to the query planner as its ``use_llm`` switch.
    enable_query_planning: bool = Field(True)
    # Forwarded to the reranker as its ``use_llm_rerank`` switch.
    enable_reranking: bool = Field(True)
    # When False, the critic/revision loop is skipped entirely.
    enable_validation: bool = Field(True)

    # --- Retrieval depth ---
    # Candidates fetched by hybrid search vs. chunks kept after reranking.
    retrieval_top_k: int = Field(10, ge=1)
    final_top_k: int = Field(5, ge=1)

    # --- Quality gate ---
    # Passed to the synthesizer as its confidence threshold.
    min_confidence: float = Field(0.5, ge=0.0, le=1.0)

    # --- Diagnostics ---
    # Emit per-stage progress logs.
    verbose: bool = Field(False)
| |
|
| |
|
@dataclass
class RAGState:
    """Mutable record threaded through every pipeline stage.

    Each stage method reads what earlier stages stored here and writes its
    own output back onto the same object.
    """

    # The user's original question and the stage currently executing.
    query: str
    stage: PipelineStage = field(default=PipelineStage.PLANNING)

    # Per-stage outputs, filled in pipeline order.
    query_plan: Optional[QueryPlan] = field(default=None)
    retrieved_chunks: List[RetrievalResult] = field(default_factory=list)
    ranked_chunks: List[RankedResult] = field(default_factory=list)
    synthesis_result: Optional[SynthesisResult] = field(default=None)
    critic_result: Optional[CriticResult] = field(default=None)

    # Self-correction bookkeeping: revisions performed so far and the
    # synthesis results they replaced.
    revision_attempt: int = field(default=0)
    revision_history: List[SynthesisResult] = field(default_factory=list)

    # Timing: wall-clock start (``time.time``) and seconds spent per stage.
    start_time: float = field(default_factory=time.time)
    stage_times: Dict[str, float] = field(default_factory=dict)

    # Error messages collected when the pipeline fails.
    errors: List[str] = field(default_factory=list)
| |
|
| |
|
class RAGResponse(BaseModel):
    """Answer handed back to callers of the RAG system, plus diagnostics."""

    # Core answer payload.
    answer: str
    citations: List[Citation]
    confidence: float

    # Provenance and bookkeeping.
    query: str
    num_sources: int
    validated: bool
    revision_attempts: int

    # Optional diagnostics: summaries of the query plan and of the critic's
    # verdict, plus end-to-end latency in milliseconds.
    query_plan: Optional[Dict[str, Any]] = Field(None)
    validation_details: Optional[Dict[str, Any]] = Field(None)
    latency_ms: float = Field(0.0)
| |
|
| |
|
class AgenticRAG:
    """
    Multi-agent RAG system.

    Orchestrates:
    - QueryPlannerAgent: query decomposition and planning
    - RetrieverAgent: hybrid (dense + sparse) retrieval
    - RerankerAgent: reranking of retrieved candidates (optionally LLM-based)
    - SynthesizerAgent: answer generation
    - CriticAgent: validation of the synthesized answer

    Features:
    - Self-correction loop (validate -> revise, bounded by
      ``config.max_revision_attempts``)
    - Graceful degradation: pipeline exceptions become error responses
    - Per-stage timing in ``RAGState.stage_times``
    """

    def __init__(
        self,
        config: Optional[RAGConfig] = None,
        vector_store: Optional[VectorStore] = None,
        embedding_adapter: Optional[EmbeddingAdapter] = None,
    ):
        """
        Initialize the Agentic RAG system.

        Args:
            config: RAG configuration; defaults to ``RAGConfig()``.
            vector_store: Vector store for retrieval/indexing. If omitted, it
                is created lazily via ``get_vector_store()`` on first access
                of ``self.store``.
            embedding_adapter: Embedding adapter. If omitted, it is created
                lazily via ``get_embedding_adapter()`` on first access of
                ``self.embedder``.
        """
        self.config = config or RAGConfig()

        self._store = vector_store
        self._embedder = embedding_adapter

        self._init_agents()

        logger.info(
            f"AgenticRAG initialized (model={self.config.model}, "
            f"revision_attempts={self.config.max_revision_attempts})"
        )

    def _init_agents(self):
        """Build every agent from the shared configuration."""
        self.planner = QueryPlannerAgent(
            model=self.config.model,
            base_url=self.config.base_url,
            use_llm=self.config.enable_query_planning,
        )

        # Dense, sparse and fused result counts all use retrieval_top_k; the
        # cut down to final_top_k happens in the reranker.
        retriever_config = HybridSearchConfig(
            dense_top_k=self.config.retrieval_top_k,
            sparse_top_k=self.config.retrieval_top_k,
            final_top_k=self.config.retrieval_top_k,
        )
        # NOTE(review): when no store/embedder was injected, the retriever
        # receives None here, while ``self.store``/``self.embedder`` lazily
        # call get_vector_store()/get_embedding_adapter(). Confirm both paths
        # resolve to the same instances, or indexed data may not be searchable.
        self.retriever = RetrieverAgent(
            config=retriever_config,
            vector_store=self._store,
            embedding_adapter=self._embedder,
        )

        reranker_config = RerankerConfig(
            model=self.config.model,
            base_url=self.config.base_url,
            top_k=self.config.final_top_k,
            use_llm_rerank=self.config.enable_reranking,
            min_relevance_score=0.1,
        )
        self.reranker = RerankerAgent(config=reranker_config)

        synth_config = SynthesizerConfig(
            model=self.config.model,
            base_url=self.config.base_url,
            confidence_threshold=self.config.min_confidence,
        )
        self.synthesizer = SynthesizerAgent(config=synth_config)

        critic_config = CriticConfig(
            model=self.config.model,
            base_url=self.config.base_url,
        )
        self.critic = CriticAgent(config=critic_config)

    @property
    def store(self) -> VectorStore:
        """Return the vector store, creating it via get_vector_store() on first use."""
        if self._store is None:
            self._store = get_vector_store()
        return self._store

    @property
    def embedder(self) -> EmbeddingAdapter:
        """Return the embedding adapter, creating it via get_embedding_adapter() on first use."""
        if self._embedder is None:
            self._embedder = get_embedding_adapter()
        return self._embedder

    def query(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> RAGResponse:
        """
        Run a question through the full RAG pipeline.

        Stages: plan -> retrieve -> rerank -> synthesize, then
        validate/revise when ``config.enable_validation`` is set.

        Args:
            question: User's question.
            filters: Optional metadata filters forwarded to retrieval.

        Returns:
            RAGResponse with the answer and metadata. Any exception raised by
            a stage is caught and turned into an error response
            (``confidence=0.0``, ``validated=False``) instead of propagating.
        """
        state = RAGState(query=question)

        try:
            state = self._plan(state)
            state = self._retrieve(state, filters)
            state = self._rerank(state)
            state = self._synthesize(state)
            if self.config.enable_validation:
                state = self._validate_and_revise(state)
            return self._build_response(state)

        except Exception as e:
            # Top-level boundary: degrade to an error response, but keep the
            # traceback in the logs so the failure stays debuggable.
            logger.exception(f"RAG pipeline error: {e}")
            state.errors.append(str(e))
            return self._build_error_response(state, str(e))

    def query_stream(
        self,
        question: str,
        filters: Optional[Dict[str, Any]] = None,
    ) -> Generator[Tuple[PipelineStage, Any], None, None]:
        """
        Run the pipeline, yielding a progress update after each stage.

        Args:
            question: User's question.
            filters: Optional metadata filters forwarded to retrieval.

        Yields:
            ``(stage, payload)`` tuples:
            - PLANNING: the QueryPlan
            - RETRIEVAL: number of retrieved chunks
            - RERANKING: number of ranked chunks
            - SYNTHESIS: the initial SynthesisResult (before any revision)
            - VALIDATION: the final CriticResult (only if validation is enabled)
            - COMPLETE: the RAGResponse, or an error response if a stage raised
        """
        state = RAGState(query=question)

        try:
            state = self._plan(state)
            yield PipelineStage.PLANNING, state.query_plan

            state = self._retrieve(state, filters)
            yield PipelineStage.RETRIEVAL, len(state.retrieved_chunks)

            state = self._rerank(state)
            yield PipelineStage.RERANKING, len(state.ranked_chunks)

            state = self._synthesize(state)
            yield PipelineStage.SYNTHESIS, state.synthesis_result

            if self.config.enable_validation:
                state = self._validate_and_revise(state)
                yield PipelineStage.VALIDATION, state.critic_result

            response = self._build_response(state)
            yield PipelineStage.COMPLETE, response

        except Exception as e:
            # Same degradation policy as query(): log with traceback, record
            # the error on the state, and finish with an error response.
            logger.exception(f"Streaming error: {e}")
            state.errors.append(str(e))
            yield PipelineStage.COMPLETE, self._build_error_response(state, str(e))

    def _plan(self, state: RAGState) -> RAGState:
        """Run query planning and store the plan on ``state.query_plan``."""
        start = time.perf_counter()
        state.stage = PipelineStage.PLANNING

        if self.config.verbose:
            logger.info(f"Planning query: {state.query}")

        state.query_plan = self.planner.plan(state.query)

        state.stage_times["planning"] = time.perf_counter() - start

        if self.config.verbose:
            logger.info(
                f"Query plan: intent={state.query_plan.intent}, "
                f"sub_queries={len(state.query_plan.sub_queries)}"
            )

        return state

    def _retrieve(
        self,
        state: RAGState,
        filters: Optional[Dict[str, Any]],
    ) -> RAGState:
        """Retrieve candidate chunks for the query (guided by the plan) into ``state.retrieved_chunks``."""
        start = time.perf_counter()
        state.stage = PipelineStage.RETRIEVAL

        if self.config.verbose:
            logger.info("Retrieving relevant chunks...")

        state.retrieved_chunks = self.retriever.retrieve(
            query=state.query,
            plan=state.query_plan,
            top_k=self.config.retrieval_top_k,
            filters=filters,
        )

        state.stage_times["retrieval"] = time.perf_counter() - start

        if self.config.verbose:
            logger.info(f"Retrieved {len(state.retrieved_chunks)} chunks")

        return state

    def _rerank(self, state: RAGState) -> RAGState:
        """Rerank retrieved chunks down to ``final_top_k`` into ``state.ranked_chunks``."""
        start = time.perf_counter()
        state.stage = PipelineStage.RERANKING

        if state.retrieved_chunks:
            if self.config.verbose:
                logger.info("Reranking results...")

            state.ranked_chunks = self.reranker.rerank(
                query=state.query,
                results=state.retrieved_chunks,
                top_k=self.config.final_top_k,
            )

            if self.config.verbose:
                logger.info(f"Reranked to {len(state.ranked_chunks)} chunks")
        else:
            # Nothing retrieved: skip the reranker call entirely.
            state.ranked_chunks = []

        # Recorded on both paths so every executed stage has a timing entry.
        state.stage_times["reranking"] = time.perf_counter() - start
        return state

    def _synthesize(self, state: RAGState) -> RAGState:
        """Generate an answer from the ranked chunks into ``state.synthesis_result``."""
        start = time.perf_counter()
        state.stage = PipelineStage.SYNTHESIS

        if self.config.verbose:
            logger.info("Synthesizing answer...")

        state.synthesis_result = self.synthesizer.synthesize(
            query=state.query,
            results=state.ranked_chunks,
            plan=state.query_plan,
        )

        state.stage_times["synthesis"] = time.perf_counter() - start

        if self.config.verbose:
            logger.info(
                f"Synthesized answer (confidence={state.synthesis_result.confidence:.2f})"
            )

        return state

    def _validate_and_revise(self, state: RAGState) -> RAGState:
        """
        Validate the answer and, while it fails, revise and re-validate.

        Performs at most ``config.max_revision_attempts`` revisions; every
        revision is validated again, so the final ``state.critic_result``
        always describes the final ``state.synthesis_result``. Replaced
        answers are kept in ``state.revision_history``. The recorded
        "validation" time includes any revision time.
        """
        start = time.perf_counter()

        while state.revision_attempt <= self.config.max_revision_attempts:
            state.stage = PipelineStage.VALIDATION

            if self.config.verbose:
                logger.info(f"Validating (attempt {state.revision_attempt + 1})...")

            state.critic_result = self.critic.validate(
                synthesis_result=state.synthesis_result,
                sources=state.ranked_chunks,
            )

            if state.critic_result.is_valid:
                if self.config.verbose:
                    logger.info("Validation passed!")
                break

            # Out of revision budget: keep the last (invalid) answer.
            if state.revision_attempt >= self.config.max_revision_attempts:
                if self.config.verbose:
                    logger.warning("Max revision attempts reached")
                break

            state.stage = PipelineStage.REVISION
            state.revision_attempt += 1
            state.revision_history.append(state.synthesis_result)

            if self.config.verbose:
                logger.info(f"Revising answer (attempt {state.revision_attempt})...")

            state.synthesis_result = self._revise_synthesis(state)

        state.stage_times["validation"] = time.perf_counter() - start
        return state

    def _revise_synthesis(self, state: RAGState) -> SynthesisResult:
        """
        Produce a revised answer after a failed validation.

        NOTE(review): critic feedback (``state.critic_result``) is not passed
        to the synthesizer; this re-runs synthesis on exactly the same inputs
        as the first attempt, so a revision can only differ if
        ``synthesize()`` is non-deterministic or stateful. Feeding the
        critic's issues back in would require a synthesizer API change.
        """
        return self.synthesizer.synthesize(
            query=state.query,
            results=state.ranked_chunks,
            plan=state.query_plan,
        )

    def _build_response(self, state: RAGState) -> RAGResponse:
        """
        Assemble the final RAGResponse from a completed state.

        Falls back to an error response if no synthesis result exists.
        ``latency_ms`` is measured from ``state.start_time`` (``time.time``).
        """
        total_time = (time.time() - state.start_time) * 1000

        synthesis = state.synthesis_result
        if synthesis is None:
            return self._build_error_response(state, "No synthesis result")

        # Compact summary of the plan (at most 5 expanded terms).
        query_plan_dict = None
        if state.query_plan:
            query_plan_dict = {
                "intent": state.query_plan.intent.value,
                "sub_queries": len(state.query_plan.sub_queries),
                "expanded_terms": state.query_plan.expanded_terms[:5],
            }

        # Compact summary of the critic's verdict (issue count, not details).
        validation_dict = None
        if state.critic_result:
            validation_dict = {
                "is_valid": state.critic_result.is_valid,
                "confidence": state.critic_result.confidence,
                "hallucination_score": state.critic_result.hallucination_score,
                "citation_accuracy": state.critic_result.citation_accuracy,
                "issues": len(state.critic_result.issues),
            }

        return RAGResponse(
            answer=synthesis.answer,
            citations=synthesis.citations,
            confidence=synthesis.confidence,
            query=state.query,
            num_sources=synthesis.num_sources_used,
            validated=state.critic_result.is_valid if state.critic_result else False,
            revision_attempts=state.revision_attempt,
            query_plan=query_plan_dict,
            validation_details=validation_dict,
            latency_ms=total_time,
        )

    def _build_error_response(
        self,
        state: RAGState,
        error: str,
    ) -> RAGResponse:
        """
        Build a zero-confidence, unvalidated response describing ``error``.

        NOTE(review): the raw error text is embedded in the user-facing
        answer; confirm this is acceptable for untrusted clients.
        """
        return RAGResponse(
            answer=f"I encountered an error processing your query: {error}",
            citations=[],
            confidence=0.0,
            query=state.query,
            num_sources=0,
            validated=False,
            revision_attempts=state.revision_attempt,
            latency_ms=(time.time() - state.start_time) * 1000,
        )

    def index_text(
        self,
        text: str,
        document_id: str,
        metadata: Optional[Dict[str, Any]] = None,
        chunk_size: int = 500,
        overlap: int = 50,
        min_chunk_chars: int = 50,
    ) -> int:
        """
        Split text into overlapping fixed-size character windows, embed each
        window, and add them to the vector store.

        Args:
            text: Text content to index.
            document_id: Unique document identifier. Chunk ids are
                ``f"{document_id}_chunk_{n}"`` with ``n`` counting kept chunks.
            metadata: Optional metadata. Only ``metadata["filename"]`` is
                used; it is stored as each chunk's ``source_path``.
            chunk_size: Window length in characters.
            overlap: Characters shared by consecutive windows; must satisfy
                ``0 <= overlap < chunk_size``.
            min_chunk_chars: Windows whose whitespace-stripped length is
                below this are skipped.

        Returns:
            Number of chunks indexed. Returns 0 without touching the store
            when no window qualifies.

        Raises:
            ValueError: If ``chunk_size`` is not positive or ``overlap`` is
                outside ``[0, chunk_size)``.
        """
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, got {chunk_size}")
        if not 0 <= overlap < chunk_size:
            raise ValueError(
                f"overlap must satisfy 0 <= overlap < chunk_size, got {overlap}"
            )

        source_path = metadata.get("filename", "") if metadata else ""
        step = chunk_size - overlap
        chunks: List[Dict[str, Any]] = []
        embeddings = []

        for start in range(0, len(text), step):
            chunk_text = text[start:start + chunk_size]

            if len(chunk_text.strip()) >= min_chunk_chars:
                chunks.append({
                    "chunk_id": f"{document_id}_chunk_{len(chunks)}",
                    "document_id": document_id,
                    "text": chunk_text,
                    "page": 0,
                    "chunk_type": "text",
                    "source_path": source_path,
                })
                embeddings.append(self.embedder.embed_text(chunk_text))

            # Once a window reaches the end of the text, every later window
            # would be a suffix of it, i.e. duplicate content; stop here.
            if start + chunk_size >= len(text):
                break

        if not chunks:
            return 0

        self.store.add_chunks(chunks, embeddings)

        logger.info(f"Indexed {len(chunks)} chunks for document {document_id}")
        return len(chunks)

    def get_stats(self) -> Dict[str, Any]:
        """
        Report store size, LLM model and embedding model details.

        Accessing ``self.store``/``self.embedder`` here triggers their lazy
        creation if they were not injected.
        """
        return {
            "total_chunks": self.store.count(),
            "model": self.config.model,
            "embedding_model": self.embedder.model_name,
            "embedding_dimension": self.embedder.embedding_dimension,
        }
| |
|