| | """Query engine orchestrating the full RAG pipeline.""" |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Generator, List, Optional |
| |
|
| | from pydantic import BaseModel, Field |
| | from rich.console import Console |
| |
|
| | from src.rag.retriever import HybridRetriever, RetrievalResult |
| | from src.rag.reranker import CrossEncoderReranker |
| | from src.llm.llm_client import LLMClient |
| |
|
| |
|
| | |
# Module-level logging configuration.
# NOTE(review): calling basicConfig() at import time configures the root
# logger for the whole application — consider moving this to the program's
# entry point so library consumers keep control of logging setup.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
| |
|
| |
|
| | |
# Default medical disclaimer appended to every QueryResponse.  Can be
# overridden per-engine via the `disclaimer_path` constructor argument of
# EyeWikiQueryEngine.  Markdown bold markup is intentional: the text is
# rendered by Markdown-capable frontends.
MEDICAL_DISCLAIMER = (
    "**Medical Disclaimer:** This information is sourced from EyeWiki, a resource of the "
    "American Academy of Ophthalmology (AAO). It is not a substitute for professional "
    "medical advice, diagnosis, or treatment. AI systems can make errors. Always consult "
    "with a qualified ophthalmologist or eye care professional for medical concerns and "
    "verify any critical information with authoritative sources."
)
| |
|
| | |
# Default LLM system prompt used when no custom prompt file is supplied via
# `system_prompt_path`.  Instructs the model to ground answers strictly in
# retrieved context and to cite sources with the [Source: Title] convention
# that matches the headers produced by context assembly.
DEFAULT_SYSTEM_PROMPT = """You are an expert ophthalmology assistant with comprehensive knowledge of eye diseases, treatments, and procedures.

Your role is to provide accurate, evidence-based information from the EyeWiki medical knowledge base.

Guidelines:
- Base your answers strictly on the provided context
- Cite sources using [Source: Title] format when referencing information
- If the context doesn't contain enough information, say so explicitly
- Use clear, precise medical terminology while remaining accessible
- Structure your responses logically with appropriate sections
- For treatment information, emphasize the importance of professional consultation
- Always maintain professional medical standards"""
| |
|
| |
|
class SourceInfo(BaseModel):
    """
    Information about a single source document cited in an answer.

    Attributes:
        title: Document title (also used for source de-duplication upstream).
        url: Source URL of the document.
        section: Section within the document; empty string when unknown.
        relevance_score: Relevance score. Cross-encoder scores are unbounded,
            so this field is deliberately not range-constrained.
    """

    title: str = Field(..., description="Document title")
    url: str = Field(..., description="Source URL")
    section: str = Field(default="", description="Section within document")
    relevance_score: float = Field(..., description="Relevance score (cross-encoder, unbounded)")
|
| |
|
class QueryResponse(BaseModel):
    """
    Structured response returned by the query engine.

    Attributes:
        answer: Generated answer text from the LLM.
        sources: Source documents used to build the context (may be empty
            when the caller requested no sources or retrieval found nothing).
        confidence: Confidence score derived from retrieval scores,
            clamped to [0, 1] (enforced by the ge/le constraints).
        disclaimer: Medical disclaimer text; defaults to the module-level
            MEDICAL_DISCLAIMER constant.
        query: The original user query, echoed back for traceability.
    """

    answer: str = Field(..., description="Generated answer")
    sources: List[SourceInfo] = Field(default_factory=list, description="Source documents")
    confidence: float = Field(..., ge=0.0, le=1.0, description="Confidence score")
    disclaimer: str = Field(default=MEDICAL_DISCLAIMER, description="Medical disclaimer")
    query: str = Field(..., description="Original query")
| |
|
| |
|
class EyeWikiQueryEngine:
    """
    Query engine orchestrating the full RAG pipeline.

    Pipeline:
        1. Query -> Retriever (hybrid search)
        2. Results -> Reranker (cross-encoder)
        3. Top results -> Context assembly
        4. Context + Query -> LLM generation
        5. Response + Sources + Disclaimer

    Features:
        - Two-stage retrieval (fast + precise)
        - Context assembly with token limits
        - Source diversity prioritization
        - Medical disclaimer inclusion
        - Streaming and non-streaming modes
    """

    def __init__(
        self,
        retriever: HybridRetriever,
        reranker: CrossEncoderReranker,
        llm_client: LLMClient,
        system_prompt_path: Optional[Path] = None,
        query_prompt_path: Optional[Path] = None,
        disclaimer_path: Optional[Path] = None,
        max_context_tokens: int = 4000,
        retrieval_k: int = 20,
        rerank_k: int = 5,
    ):
        """
        Initialize query engine.

        Args:
            retriever: HybridRetriever instance
            reranker: CrossEncoderReranker instance
            llm_client: LLMClient instance (OllamaClient or OpenAIClient)
            system_prompt_path: Path to custom system prompt file
            query_prompt_path: Path to custom query prompt template
            disclaimer_path: Path to custom medical disclaimer file
            max_context_tokens: Maximum tokens for context
            retrieval_k: Number of documents to retrieve initially
            rerank_k: Number of documents after reranking
        """
        self.retriever = retriever
        self.reranker = reranker
        self.llm_client = llm_client
        self.max_context_tokens = max_context_tokens
        self.retrieval_k = retrieval_k
        self.rerank_k = rerank_k

        self.console = Console()

        # Optional file-based overrides; each falls back to the module-level
        # default when no path is given or the file does not exist.  Prompt
        # files are read as UTF-8 explicitly so behavior does not depend on
        # the platform's locale encoding.
        if system_prompt_path and system_prompt_path.exists():
            with open(system_prompt_path, "r", encoding="utf-8") as f:
                self.system_prompt = f.read()
            logger.info("Loaded system prompt from %s", system_prompt_path)
        else:
            self.system_prompt = DEFAULT_SYSTEM_PROMPT
            logger.info("Using default system prompt")

        if query_prompt_path and query_prompt_path.exists():
            with open(query_prompt_path, "r", encoding="utf-8") as f:
                self.query_prompt_template = f.read()
            logger.info("Loaded query prompt from %s", query_prompt_path)
        else:
            # None signals _format_prompt() to use its inline default format.
            self.query_prompt_template = None
            logger.info("Using inline query prompt formatting")

        if disclaimer_path and disclaimer_path.exists():
            with open(disclaimer_path, "r", encoding="utf-8") as f:
                self.medical_disclaimer = f.read().strip()
            logger.info("Loaded medical disclaimer from %s", disclaimer_path)
        else:
            self.medical_disclaimer = MEDICAL_DISCLAIMER
            logger.info("Using default medical disclaimer")

    def _estimate_tokens(self, text: str) -> int:
        """
        Estimate token count for text.

        Uses simple heuristic: ~4 characters per token.

        Args:
            text: Input text

        Returns:
            Estimated token count
        """
        return len(text) // 4

    def _prioritize_diverse_sources(
        self, results: List[RetrievalResult]
    ) -> List[RetrievalResult]:
        """
        Prioritize results from diverse sources.

        Ensures we don't just get multiple chunks from the same article:
        the first chunk seen for each distinct document title is moved to
        the front (preserving relative order), and duplicate-document
        chunks follow afterwards, also in their original order.

        Args:
            results: Sorted list of retrieval results

        Returns:
            Reordered list prioritizing diversity
        """
        seen_documents = set()
        diverse_results = []
        remaining_results = []

        for result in results:
            doc_title = result.document_title
            if doc_title not in seen_documents:
                diverse_results.append(result)
                seen_documents.add(doc_title)
            else:
                remaining_results.append(result)

        # Duplicate-document chunks are kept, just demoted behind the
        # first chunk of every distinct document.
        diverse_results.extend(remaining_results)

        return diverse_results

    def _assemble_context(self, results: List[RetrievalResult]) -> str:
        """
        Assemble context from retrieval results.

        Features:
        - Formats with section headers
        - Limits to max_context_tokens
        - Prioritizes diverse sources
        - Includes source citations

        Args:
            results: List of retrieval results

        Returns:
            Formatted context string
        """
        if not results:
            return ""

        diverse_results = self._prioritize_diverse_sources(results)

        context_parts = []
        total_tokens = 0

        for i, result in enumerate(diverse_results, 1):
            # Header format must stay in sync with the [Source: Title]
            # citation convention in the system prompt.
            chunk_text = f"[Source {i}: {result.document_title}"
            if result.section:
                chunk_text += f" - {result.section}"
            chunk_text += f"]\n{result.content}\n"

            chunk_tokens = self._estimate_tokens(chunk_text)

            # Stop before exceeding the budget.  NOTE(review): if the very
            # first chunk alone exceeds max_context_tokens, the assembled
            # context is empty — confirm this is the intended behavior.
            if total_tokens + chunk_tokens > self.max_context_tokens:
                logger.info(
                    "Reached context token limit (%d), using %d of %d chunks",
                    self.max_context_tokens,
                    i - 1,
                    len(diverse_results),
                )
                break

            context_parts.append(chunk_text)
            total_tokens += chunk_tokens

        context = "\n".join(context_parts)

        logger.info(
            "Assembled context: %d chunks, ~%d tokens",
            len(context_parts),
            total_tokens,
        )

        return context

    def _extract_sources(self, results: List[RetrievalResult]) -> List[SourceInfo]:
        """
        Extract source information from results.

        De-duplicates by document title, keeping the first (highest-ranked)
        occurrence of each document.

        Args:
            results: List of retrieval results

        Returns:
            List of SourceInfo objects
        """
        sources = []
        seen_titles = set()

        for result in results:
            if result.document_title not in seen_titles:
                source = SourceInfo(
                    title=result.document_title,
                    url=result.source_url,
                    section=result.section,
                    relevance_score=result.score,
                )
                sources.append(source)
                seen_titles.add(result.document_title)

        return sources

    def _calculate_confidence(self, results: List[RetrievalResult]) -> float:
        """
        Calculate confidence score based on retrieval scores.

        Uses average of top reranked scores, clamped into [0, 1].
        NOTE(review): cross-encoder scores are unbounded, so clamping is a
        crude normalization — scores above 1 all map to full confidence.

        Args:
            results: List of retrieval results

        Returns:
            Confidence score (0-1)
        """
        if not results:
            return 0.0

        top_scores = [r.score for r in results[:self.rerank_k]]

        # Guards against rerank_k == 0 producing an empty slice.
        if not top_scores:
            return 0.0

        avg_score = sum(top_scores) / len(top_scores)

        # Clamp into [0, 1] to satisfy QueryResponse's field constraints.
        confidence = min(max(avg_score, 0.0), 1.0)

        return confidence

    def _format_prompt(self, query: str, context: str) -> str:
        """
        Format the prompt for LLM.

        Uses query_prompt_template if loaded, otherwise uses default format.

        Args:
            query: User query
            context: Assembled context

        Returns:
            Formatted prompt
        """
        if self.query_prompt_template:
            # Template must contain {context} and {question} placeholders.
            # NOTE(review): str.format raises KeyError/ValueError if the
            # template contains any other literal braces — verify templates
            # are authored with that in mind.
            prompt = self.query_prompt_template.format(
                context=context,
                question=query
            )
        else:
            prompt = f"""Context from EyeWiki medical knowledge base:

{context}

---

Question: {query}

Please provide a comprehensive answer based on the context above. Structure your response clearly and cite sources where appropriate."""

        return prompt

    def query(
        self,
        question: str,
        include_sources: bool = True,
        filters: Optional[dict] = None,
    ) -> QueryResponse:
        """
        Query the engine and get response.

        Pipeline:
        1. Retrieve documents (retrieval_k)
        2. Rerank with cross-encoder (rerank_k)
        3. Assemble context with token limits
        4. Generate answer with LLM
        5. Return response with sources and disclaimer

        Args:
            question: User question
            include_sources: Include source information in response
            filters: Optional metadata filters for retrieval

        Returns:
            QueryResponse object
        """
        logger.info("Processing query: '%s'", question)

        # Stage 1: fast hybrid retrieval over a wide candidate pool.
        logger.info("Retrieving top %d candidates...", self.retrieval_k)
        retrieval_results = self.retriever.retrieve(
            query=question,
            top_k=self.retrieval_k,
            filters=filters,
        )

        if not retrieval_results:
            logger.warning("No results found for query")
            return QueryResponse(
                answer="I couldn't find relevant information to answer this question in the EyeWiki knowledge base.",
                sources=[],
                confidence=0.0,
                query=question,
            )

        # Stage 2: precise cross-encoder reranking of the candidates.
        logger.info("Reranking to top %d...", self.rerank_k)
        reranked_results = self.reranker.rerank(
            query=question,
            documents=retrieval_results,
            top_k=self.rerank_k,
        )

        context = self._assemble_context(reranked_results)

        logger.info("Generating answer with LLM...")
        prompt = self._format_prompt(question, context)

        try:
            # Low temperature for factual, deterministic medical answers.
            answer = self.llm_client.generate(
                prompt=prompt,
                system_prompt=self.system_prompt,
                temperature=0.1,
            )
        except Exception as e:
            # Degrade gracefully: return an apology answer rather than
            # propagating LLM/transport failures to the caller.
            logger.error("Error generating answer: %s", e)
            answer = (
                "I encountered an error while generating the answer. "
                "Please try again or rephrase your question."
            )

        sources = self._extract_sources(reranked_results) if include_sources else []

        confidence = self._calculate_confidence(reranked_results)

        response = QueryResponse(
            answer=answer,
            sources=sources,
            confidence=confidence,
            query=question,
        )

        logger.info(
            "Query complete: %d sources, confidence: %.2f",
            len(sources),
            confidence,
        )

        return response

    def stream_query(
        self,
        question: str,
        filters: Optional[dict] = None,
    ) -> Generator[str, None, None]:
        """
        Query with streaming response.

        Yields answer chunks in real-time. Unlike query(), this does not
        return sources or a confidence score.

        Args:
            question: User question
            filters: Optional metadata filters

        Yields:
            Answer chunks as they are generated
        """
        logger.info("Processing streaming query: '%s'", question)

        retrieval_results = self.retriever.retrieve(
            query=question,
            top_k=self.retrieval_k,
            filters=filters,
        )

        if not retrieval_results:
            yield "I couldn't find relevant information to answer this question."
            return

        reranked_results = self.reranker.rerank(
            query=question,
            documents=retrieval_results,
            top_k=self.rerank_k,
        )

        context = self._assemble_context(reranked_results)

        prompt = self._format_prompt(question, context)

        try:
            for chunk in self.llm_client.stream_generate(
                prompt=prompt,
                system_prompt=self.system_prompt,
                temperature=0.1,
            ):
                yield chunk

        except Exception as e:
            # Emit an in-band error marker so streaming consumers see the
            # failure instead of a silently truncated answer.
            logger.error("Error in streaming generation: %s", e)
            yield "\n\n[Error: Failed to generate response]"

    def batch_query(
        self,
        questions: List[str],
        include_sources: bool = True,
    ) -> List[QueryResponse]:
        """
        Process multiple queries sequentially.

        Args:
            questions: List of questions
            include_sources: Include sources in responses

        Returns:
            List of QueryResponse objects, one per question, in order
        """
        responses = []

        for question in questions:
            response = self.query(question, include_sources=include_sources)
            responses.append(response)

        return responses

    def get_pipeline_info(self) -> dict:
        """
        Get information about the pipeline configuration.

        Returns:
            Dictionary with pipeline settings
        """
        return {
            "retrieval_k": self.retrieval_k,
            "rerank_k": self.rerank_k,
            "max_context_tokens": self.max_context_tokens,
            "retriever_config": {
                "dense_weight": self.retriever.dense_weight,
                "sparse_weight": self.retriever.sparse_weight,
                "term_expansion": self.retriever.enable_term_expansion,
            },
            "reranker_info": self.reranker.get_model_info(),
            "llm_model": self.llm_client.llm_model,
        }
| |
|