"""Query engine orchestrating the full RAG pipeline.""" import logging from pathlib import Path from typing import Generator, List, Optional from pydantic import BaseModel, Field from rich.console import Console from src.rag.retriever import HybridRetriever, RetrievalResult from src.rag.reranker import CrossEncoderReranker from src.llm.llm_client import LLMClient # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Medical disclaimer (default) MEDICAL_DISCLAIMER = ( "**Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the " "American Academy of Ophthalmology (AAO). It is not a substitute for professional " "medical advice, diagnosis, or treatment. AI systems can make errors. Always consult " "with a qualified ophthalmologist or eye care professional for medical concerns and " "verify any critical information with authoritative sources." ) # Default system prompt DEFAULT_SYSTEM_PROMPT = """You are an expert ophthalmology assistant with comprehensive knowledge of eye diseases, treatments, and procedures. Your role is to provide accurate, evidence-based information from the EyeWiki medical knowledge base. Guidelines: - Base your answers strictly on the provided context - Cite sources using [Source: Title] format when referencing information - If the context doesn't contain enough information, say so explicitly - Use clear, precise medical terminology while remaining accessible - Structure your responses logically with appropriate sections - For treatment information, emphasize the importance of professional consultation - Always maintain professional medical standards""" class SourceInfo(BaseModel): """ Information about a source document. Attributes: title: Document title url: Source URL section: Section within document relevance_score: Relevance score (cross-encoder scores, unbounded) """ title: str = Field(..., description="Document title") url: str = Field(..., description="Source URL") section: str = Field(default="", description="Section within document") relevance_score: float = Field(..., description="Relevance score (cross-encoder, unbounded)") class QueryResponse(BaseModel): """ Response from query engine. Attributes: answer: Generated answer text sources: List of source documents used confidence: Confidence score based on retrieval disclaimer: Medical disclaimer text query: Original query """ answer: str = Field(..., description="Generated answer") sources: List[SourceInfo] = Field(default_factory=list, description="Source documents") confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score") disclaimer: str = Field(default=MEDICAL_DISCLAIMER, description="Medical disclaimer") query: str = Field(..., description="Original query") class EyeWikiQueryEngine: """ Query engine orchestrating the full RAG pipeline. Pipeline: 1. Query � Retriever (hybrid search) 2. Results � Reranker (cross-encoder) 3. Top results � Context assembly 4. Context + Query � LLM generation 5. 


class EyeWikiQueryEngine:
    """
    Query engine orchestrating the full RAG pipeline.

    Pipeline:
        1. Query → Retriever (hybrid search)
        2. Results → Reranker (cross-encoder)
        3. Top results → Context assembly
        4. Context + Query → LLM generation
        5. Response + Sources + Disclaimer

    Features:
        - Two-stage retrieval (fast + precise)
        - Context assembly with token limits
        - Source diversity prioritization
        - Medical disclaimer inclusion
        - Streaming and non-streaming modes
    """

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: CrossEncoderReranker,
        llm_client: LLMClient,
        system_prompt_path: Optional[Path] = None,
        query_prompt_path: Optional[Path] = None,
        disclaimer_path: Optional[Path] = None,
        max_context_tokens: int = 4000,
        retrieval_k: int = 20,
        rerank_k: int = 5,
    ):
        """
        Initialize query engine.

        Args:
            retriever: HybridRetriever instance
            reranker: CrossEncoderReranker instance
            llm_client: LLMClient instance (OllamaClient or OpenAIClient)
            system_prompt_path: Path to custom system prompt file
            query_prompt_path: Path to custom query prompt template
            disclaimer_path: Path to custom medical disclaimer file
            max_context_tokens: Maximum tokens for context
            retrieval_k: Number of documents to retrieve initially
            rerank_k: Number of documents after reranking
        """
        self.retriever = retriever
        self.reranker = reranker
        self.llm_client = llm_client
        self.max_context_tokens = max_context_tokens
        self.retrieval_k = retrieval_k
        self.rerank_k = rerank_k
        self.console = Console()

        # Load system prompt
        if system_prompt_path and system_prompt_path.exists():
            with open(system_prompt_path, "r") as f:
                self.system_prompt = f.read()
            logger.info(f"Loaded system prompt from {system_prompt_path}")
        else:
            self.system_prompt = DEFAULT_SYSTEM_PROMPT
            logger.info("Using default system prompt")

        # Load query prompt template
        if query_prompt_path and query_prompt_path.exists():
            with open(query_prompt_path, "r") as f:
                self.query_prompt_template = f.read()
            logger.info(f"Loaded query prompt from {query_prompt_path}")
        else:
            self.query_prompt_template = None
            logger.info("Using inline query prompt formatting")

        # Load medical disclaimer
        if disclaimer_path and disclaimer_path.exists():
            with open(disclaimer_path, "r") as f:
                self.medical_disclaimer = f.read().strip()
            logger.info(f"Loaded medical disclaimer from {disclaimer_path}")
        else:
            self.medical_disclaimer = MEDICAL_DISCLAIMER
            logger.info("Using default medical disclaimer")

    def _estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text.

        Uses simple heuristic: ~4 characters per token.

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        return len(text) // 4

    def _prioritize_diverse_sources(
        self, results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Prioritize results from diverse sources.

        Ensures we don't just get multiple chunks from the same article.

        Args:
            results: Sorted list of retrieval results

        Returns:
            Reordered list prioritizing diversity
        """
        seen_documents = set()
        diverse_results = []
        remaining_results = []

        # First pass: one chunk per document
        for result in results:
            doc_title = result.document_title
            if doc_title not in seen_documents:
                diverse_results.append(result)
                seen_documents.add(doc_title)
            else:
                remaining_results.append(result)

        # Second pass: add remaining high-scoring chunks
        diverse_results.extend(remaining_results)

        return diverse_results
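
    # Worked example of the reordering above (hypothetical document titles):
    # given chunks from documents titled
    #     [Glaucoma, Glaucoma, Cataract, Uveitis, Cataract]
    # the first pass keeps the first chunk per title and the second pass
    # appends the duplicates, yielding
    #     [Glaucoma, Cataract, Uveitis, Glaucoma, Cataract]
    # so the token budget in _assemble_context is spent across distinct
    # articles before revisiting any single one.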

    def _assemble_context(self, results: List[RetrievalResult]) -> str:
        """
        Assemble context from retrieval results.

        Features:
            - Formats with section headers
            - Limits to max_context_tokens
            - Prioritizes diverse sources
            - Includes source citations

        Args:
            results: List of retrieval results

        Returns:
            Formatted context string
        """
        if not results:
            return ""

        # Prioritize diversity
        diverse_results = self._prioritize_diverse_sources(results)

        context_parts = []
        total_tokens = 0

        for i, result in enumerate(diverse_results, 1):
            # Format context chunk
            chunk_text = f"[Source {i}: {result.document_title}"
            if result.section:
                chunk_text += f" - {result.section}"
            chunk_text += f"]\n{result.content}\n"

            # Check token limit
            chunk_tokens = self._estimate_tokens(chunk_text)
            if total_tokens + chunk_tokens > self.max_context_tokens:
                logger.info(
                    f"Reached context token limit ({self.max_context_tokens}), "
                    f"using {i - 1} of {len(diverse_results)} chunks"
                )
                break

            context_parts.append(chunk_text)
            total_tokens += chunk_tokens

        context = "\n".join(context_parts)

        logger.info(
            f"Assembled context: {len(context_parts)} chunks, "
            f"~{total_tokens} tokens"
        )

        return context

    def _extract_sources(self, results: List[RetrievalResult]) -> List[SourceInfo]:
        """
        Extract source information from results.

        Args:
            results: List of retrieval results

        Returns:
            List of SourceInfo objects
        """
        sources = []
        seen_titles = set()

        for result in results:
            # Deduplicate by title
            if result.document_title not in seen_titles:
                source = SourceInfo(
                    title=result.document_title,
                    url=result.source_url,
                    section=result.section,
                    relevance_score=result.score,
                )
                sources.append(source)
                seen_titles.add(result.document_title)

        return sources

    def _calculate_confidence(self, results: List[RetrievalResult]) -> float:
        """
        Calculate confidence score based on retrieval scores.

        Uses average of top reranked scores.

        Args:
            results: List of retrieval results

        Returns:
            Confidence score (0-1)
        """
        if not results:
            return 0.0

        # Use average of top scores
        top_scores = [r.score for r in results[: self.rerank_k]]
        if not top_scores:
            return 0.0

        avg_score = sum(top_scores) / len(top_scores)

        # Clamp to the 0-1 range; this is a rough heuristic, since
        # cross-encoder scores are not guaranteed to fall in [0, 1]
        confidence = min(max(avg_score, 0.0), 1.0)

        return confidence

    def _format_prompt(self, query: str, context: str) -> str:
        """
        Format the prompt for LLM.

        Uses query_prompt_template if loaded, otherwise uses default format.

        Args:
            query: User query
            context: Assembled context

        Returns:
            Formatted prompt
        """
        if self.query_prompt_template:
            # Use template with placeholders
            prompt = self.query_prompt_template.format(
                context=context, question=query
            )
        else:
            # Default inline formatting
            prompt = f"""Context from EyeWiki medical knowledge base:

{context}

---

Question: {query}

Please provide a comprehensive answer based on the context above. Structure your response clearly and cite sources where appropriate."""

        return prompt
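
    # Sketch of a custom template that could be supplied via query_prompt_path
    # (hypothetical file contents, not shipped with the project). str.format()
    # in _format_prompt supplies exactly two keys, {context} and {question},
    # so a template should use only those placeholders:
    #
    #     Answer the question using only the EyeWiki excerpts below.
    #
    #     {context}
    #
    #     Question: {question}
    #
    #     Cite sources as [Source: Title] and state clearly when the excerpts
    #     are insufficient to answer.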

    def query(
        self,
        question: str,
        include_sources: bool = True,
        filters: Optional[dict] = None,
    ) -> QueryResponse:
        """
        Query the engine and get response.

        Pipeline:
            1. Retrieve documents (retrieval_k)
            2. Rerank with cross-encoder (rerank_k)
            3. Assemble context with token limits
            4. Generate answer with LLM
            5. Return response with sources and disclaimer

        Args:
            question: User question
            include_sources: Include source information in response
            filters: Optional metadata filters for retrieval

        Returns:
            QueryResponse object
        """
        logger.info(f"Processing query: '{question}'")

        # Step 1: Retrieve documents
        logger.info(f"Retrieving top {self.retrieval_k} candidates...")
        retrieval_results = self.retriever.retrieve(
            query=question,
            top_k=self.retrieval_k,
            filters=filters,
        )

        if not retrieval_results:
            logger.warning("No results found for query")
            return QueryResponse(
                answer=(
                    "I couldn't find relevant information to answer this "
                    "question in the EyeWiki knowledge base."
                ),
                sources=[],
                confidence=0.0,
                query=question,
            )

        # Step 2: Rerank for precision
        logger.info(f"Reranking to top {self.rerank_k}...")
        reranked_results = self.reranker.rerank(
            query=question,
            documents=retrieval_results,
            top_k=self.rerank_k,
        )

        # Step 3: Assemble context
        context = self._assemble_context(reranked_results)

        # Step 4: Generate answer
        logger.info("Generating answer with LLM...")
        prompt = self._format_prompt(question, context)

        try:
            answer = self.llm_client.generate(
                prompt=prompt,
                system_prompt=self.system_prompt,
                temperature=0.1,  # Low temperature for factual responses
            )
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            answer = (
                "I encountered an error while generating the answer. "
                "Please try again or rephrase your question."
            )

        # Step 5: Extract sources
        sources = self._extract_sources(reranked_results) if include_sources else []

        # Step 6: Calculate confidence
        confidence = self._calculate_confidence(reranked_results)

        # Create response
        response = QueryResponse(
            answer=answer,
            sources=sources,
            confidence=confidence,
            query=question,
        )

        logger.info(
            f"Query complete: {len(sources)} sources, "
            f"confidence: {confidence:.2f}"
        )

        return response

    def stream_query(
        self,
        question: str,
        filters: Optional[dict] = None,
    ) -> Generator[str, None, None]:
        """
        Query with streaming response.

        Yields answer chunks in real-time.

        Args:
            question: User question
            filters: Optional metadata filters

        Yields:
            Answer chunks as they are generated
        """
        logger.info(f"Processing streaming query: '{question}'")

        # Retrieval and reranking (same as query())
        retrieval_results = self.retriever.retrieve(
            query=question,
            top_k=self.retrieval_k,
            filters=filters,
        )

        if not retrieval_results:
            yield "I couldn't find relevant information to answer this question."
            return

        reranked_results = self.reranker.rerank(
            query=question,
            documents=retrieval_results,
            top_k=self.rerank_k,
        )

        # Assemble context
        context = self._assemble_context(reranked_results)

        # Generate prompt
        prompt = self._format_prompt(question, context)

        # Stream generation
        try:
            for chunk in self.llm_client.stream_generate(
                prompt=prompt,
                system_prompt=self.system_prompt,
                temperature=0.1,
            ):
                yield chunk
        except Exception as e:
            logger.error(f"Error in streaming generation: {e}")
            yield "\n\n[Error: Failed to generate response]"

    def batch_query(
        self,
        questions: List[str],
        include_sources: bool = True,
    ) -> List[QueryResponse]:
        """
        Process multiple queries.

        Args:
            questions: List of questions
            include_sources: Include sources in responses

        Returns:
            List of QueryResponse objects
        """
        responses = []
        for question in questions:
            response = self.query(question, include_sources=include_sources)
            responses.append(response)
        return responses

    def get_pipeline_info(self) -> dict:
        """
        Get information about the pipeline configuration.

        Returns:
            Dictionary with pipeline settings
        """
        return {
            "retrieval_k": self.retrieval_k,
            "rerank_k": self.rerank_k,
            "max_context_tokens": self.max_context_tokens,
            "retriever_config": {
                "dense_weight": self.retriever.dense_weight,
                "sparse_weight": self.retriever.sparse_weight,
                "term_expansion": self.retriever.enable_term_expansion,
            },
            "reranker_info": self.reranker.get_model_info(),
            "llm_model": self.llm_client.llm_model,
        }
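

# Usage sketch (illustrative only): the engine calls below use this module's
# own API, but the constructor arguments for HybridRetriever,
# CrossEncoderReranker, and the LLM client are placeholders rather than their
# actual signatures, which live in their respective modules.
#
#     engine = EyeWikiQueryEngine(
#         retriever=HybridRetriever(...),
#         reranker=CrossEncoderReranker(...),
#         llm_client=LLMClient(...),  # e.g. an OllamaClient or OpenAIClient
#     )
#
#     # Non-streaming: full QueryResponse with sources and confidence
#     response = engine.query("What are the treatment options for glaucoma?")
#     print(response.answer)
#     for source in response.sources:
#         print(f"- {source.title} ({source.url})")
#
#     # Streaming: print chunks as they arrive
#     for chunk in engine.stream_query("How is a cataract diagnosed?"):
#         print(chunk, end="", flush=True)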