"""RAG query pipeline: similarity search over NHS health content + streamed LLM answers."""

import argparse
import logging
import os
from typing import Dict, Generator, List, Optional, Tuple

import voyageai
from openai import OpenAI

from config import Config, InfoSource
from search_engine import SearchEngine

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


class RAGSystem:
    """Main RAG system class.

    Retrieves similar documents via :class:`SearchEngine` (Voyage embeddings)
    and streams an answer from an LLM (Gemini via its OpenAI-compatible
    endpoint, or OpenAI directly), grounded strictly in the retrieved context.
    """

    def __init__(self, shared_data=None):
        # NOTE(review): `shared_data` is accepted but unused here — kept for
        # interface compatibility with callers; confirm whether it is needed.
        self.config = Config()

        # Gemini is accessed through its OpenAI-compatible REST surface,
        # so both backends share the same client class.
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if gemini_api_key:
            self.gemini_client = OpenAI(
                api_key=gemini_api_key,
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
            )
        else:
            self.gemini_client = None

        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_api_key:
            self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            self.openai_client = None

        self.voyage_client = voyageai.Client(api_key=os.getenv("VOYAGE_API_KEY"))
        self.search_engine = SearchEngine(self.voyage_client)

    def _validate_inputs(self, query_text: str, similarity_k: int, info_source: str) -> None:
        """Validate input parameters.

        Raises:
            ValueError: if the query is empty, ``similarity_k`` is not positive,
                or ``info_source`` is not a known :class:`InfoSource` value.
        """
        if not query_text or not query_text.strip():
            raise ValueError("Query text cannot be empty")
        if similarity_k <= 0:
            raise ValueError("similarity_k must be a positive integer")
        try:
            InfoSource(info_source.lower())
        except ValueError:
            valid_sources = [s.value for s in InfoSource]
            # `from None` suppresses the internal enum-lookup traceback so the
            # caller sees only the meaningful validation error.
            raise ValueError(
                f"Invalid info_source '{info_source}'. Must be one of: {valid_sources}"
            ) from None

    def _clean_section_id(self, section_id: str) -> str:
        """Clean section ID for display - NHS format: condition__section__part"""
        if not section_id or section_id == 'Unknown section':
            return section_id

        # Handle NHS format: "adhd-adults__Overview__Part_1"
        if '__' in section_id:
            parts = section_id.split('__')
            if len(parts) >= 2:
                # Get condition and section, ignore part number
                condition = parts[0].replace('-', ' ').replace('_', ' ').title()
                section = parts[1].replace('_', ' ').title()
                return f"{condition} - {section}"

        # Fallback: just clean up underscores and dashes
        clean_section = section_id.replace('_', ' ').replace('-', ' ').title()
        return clean_section

    def _get_context_text(self, results: List[Dict]) -> str:
        """Generate context text from search results.

        Each result contributes a section header, its document text, and (when
        present) its URL, so the LLM can cite sources; sections are separated
        by a markdown horizontal rule.
        """
        context_text_sections = []
        for doc in results:
            section_id = doc['metadata'].get('original_id', 'Unknown section')
            url = doc['metadata'].get('url', '')
            document_text = doc['metadata'].get('document', '')

            # Clean up section_id for display
            clean_section_id = self._clean_section_id(section_id)

            # Create formatted section without showing URL explicitly
            # The URL will be available in the document_text if it was part of the original content
            formatted_section = (
                f"Source Information: [Section: {clean_section_id}]\n"
                f"Context: {document_text}"
                f"{f' Available at: {url}' if url else ''}"  # Include URL for LLM to use
            )
            context_text_sections.append(formatted_section)
        return "\n\n---\n\n".join(context_text_sections)

    def _create_system_prompt(self, context_text: str, context_description: str,
                              not_found_message: str, query_text: str) -> List[Dict]:
        """Create the chat message list (system rules, context, user query) for the LLM."""
        return [
            {
                "role": "system",
                "content": (
                    f"You are a medical AI assistant tasked with answering clinical questions strictly based on the provided {context_description} context. Follow the requirements below to ensure accurate, consistent, and professional responses.\n\n"
                    "# Response Rules\n\n"
                    "1. **Context Restriction**:\n"
                    " - Only use information given in the provided NHS health information context.\n"
                    " - Do not generate or speculate with information not explicitly found in the given context.\n\n"
                    "2. **Answer Format**:\n"
                    " - Provide a clear and concise response based solely on the context.\n"
                    " - When including a list, use standard markdown bullet points (`*` or `-`).\n"
                    " - If a list follows introductory text, insert a line break before the first bullet point.\n"
                    " - Each bullet point must be on its own line.\n\n"
                    "3. **Preserve Tables**:\n"
                    " - If relevant markdown tables appear in the context, reproduce them in your answer.\n"
                    " - Maintain the original structure, formatting, and content of any included tables.\n\n"
                    "4. **Links and URLs**:\n"
                    " - Include any URLs or web links from the context directly in your response when relevant.\n"
                    " - Integrate links naturally within sentences, using markdown syntax for clickable text links.\n"
                    " - DO NOT generate or invent any URLs not explicitly present in the context.\n\n"
                    "5. **Markdown Link Formatting**:\n"
                    " - In responses, only the descriptive text in brackets should be visible and clickable (e.g., `[NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/)`).\n"
                    " - Readers should never see raw URLs in the text.\n"
                    " - Use descriptive link text like 'NHS ADHD information' or 'NHS depression guide' rather than generic terms.\n\n"
                    "6. **If No Relevant Information**:\n"
                    " - If the context contains no relevant information, state clearly:\n"
                    f" *\"{not_found_message}\"*\n\n"
                    "# Output Format\n\n"
                    "- All responses should be in plain text, using markdown formatting for lists and links as required.\n"
                    "- Do not use code blocks.\n"
                    "- Answers should be concise, accurate, and formatted according to the rules above.\n\n"
                    "# Examples\n\n"
                    "**Example 1: Integration of markdown link in context**\n"
                    "Question: \"What are the symptoms of ADHD?\"\n"
                    "Context snippet: ...see the NHS information on ADHD symptoms...\n"
                    "Output:\n"
                    "According to the [NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/), symptoms include...\n\n"
                    "**Example 2: Multiple condition references**\n"
                    "According to NHS guidance:\n"
                    "* Initial symptoms may include difficulty concentrating.\n"
                    "* For detailed information, see the [NHS ADHD guide](https://www.nhs.uk/conditions/adhd/).\n\n"
                    "**Example 3: No relevant context**\n"
                    f"{not_found_message}\n\n"
                    "# Notes\n\n"
                    "- Never output information beyond what is provided in the supplied context.\n"
                    "- Always use markdown for lists and links.\n"
                    "- Make sure all markdown tables from context are preserved in your answer if relevant.\n"
                    "- Present links only as clickable text, not as bare URLs.\n"
                    "- Use descriptive link text that indicates the specific NHS condition or topic.\n\n"
                    "**REMINDER:**\n"
                    "Strictly adhere to all formatting and content rules above for every response."
                ),
            },
            {
                "role": "assistant",
                "content": (
                    f"Here is the context from {context_description} that you should use to answer the following question:\n\n{context_text}\n\n"
                ),
            },
            {
                "role": "user",
                "content": query_text,
            },
        ]

    def get_sources_from_results(self, results: List[Dict], info_source: str) -> List[Dict]:
        """Extract formatted sources from search results.

        ``info_source`` is currently unused but kept for interface stability.
        """
        sources = []
        for doc in results:
            metadata = doc.get('metadata', {})
            section_id = metadata.get('original_id', 'Unknown section')
            source = metadata.get('source', 'Unknown')
            url = metadata.get('url', '')

            # Clean section ID for display
            clean_section_id = self._clean_section_id(section_id)

            source_info = {
                'metadata': {
                    'source': source,
                    'original_id': section_id,
                    'url': url,
                    'clean_section': clean_section_id,
                }
            }
            sources.append(source_info)
        return sources

    def query_rag_stream(self, query_text: str, llm_model: str,
                         similarity_k: int = 25, info_source: str = "NHS",
                         filename_filter: Optional[str] = None,
                         namespace: str = "nhs_guidelines_voyage_3_large"
                         ) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Query the RAG system, yielding (chunk_text, sources) tuples.

        Args:
            query_text: The user's question.
            llm_model: Model identifier (e.g. "gemini-2.0-flash" or an OpenAI model).
            similarity_k: Number of documents to retrieve.
            info_source: Which configured information source to use.
            filename_filter: Reserved for future filtering; currently unused.
            namespace: Vector-store namespace to search (previously hard-coded).
        """
        try:
            self._validate_inputs(query_text, similarity_k, info_source)
            source_config = self.config.get_source_config(info_source)

            # Get similar documents using only similarity search
            results = self.search_engine.similarity_search(
                query_text,
                namespace=namespace,
                top_k=similarity_k,
            )
            if not results:
                yield "I couldn't find any relevant information to answer your question.", []
                return

            # Generate context and system prompt
            context_text = self._get_context_text(results)
            system_messages = self._create_system_prompt(
                context_text,
                source_config.context_description,
                source_config.not_found_message,
                query_text,
            )

            # Get sources for response
            sources_data = self.get_sources_from_results(results, info_source)

            # Stream LLM response
            yield from self._stream_llm_response(system_messages, query_text, llm_model, sources_data)
        except Exception as e:
            logger.error("Error in query_rag_stream: %s", e)
            yield f"An error occurred while processing your query: {str(e)}", []

    def _stream_llm_response(self, system_messages: List[Dict], query_text: str,
                             llm_model: str, sources_data: List[Dict]
                             ) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Stream the LLM response, yielding (chunk_text, sources) tuples.

        Dispatches to the Gemini client for "gemini" models, otherwise to the
        OpenAI client (previously non-Gemini models were rejected even though
        an OpenAI client was initialized but never used).
        """
        try:
            # Select the backend client by model name.
            if "gemini" in llm_model.lower():
                client = self.gemini_client
            else:
                client = self.openai_client

            if client is None:
                error_msg = f"Unsupported LLM model or client not available: {llm_model}"
                logger.error(error_msg)
                yield error_msg, []
                return

            stream = client.chat.completions.create(
                model=llm_model,
                messages=system_messages,
                temperature=0,  # deterministic output for clinical answers
                stream=True,
            )
            for chunk in stream:
                # Guard against keep-alive chunks with no delta content.
                if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content, sources_data
        except Exception as e:
            logger.error("Error in LLM completion: %s", e)
            yield f"Error generating response: {str(e)}", []


def main():
    """Main function for CLI usage"""
    parser = argparse.ArgumentParser(description="RAG System Query Interface")
    parser.add_argument("--query_text", type=str,
                        default="What are the symptoms of ADHD in adults?",
                        help="The query text.")
    parser.add_argument("--llm_model", type=str, default="gemini-2.0-flash",
                        help="The LLM model to use.")
    parser.add_argument("--similarity_k", type=int, default=5,
                        help="Number of results to retrieve in similarity search.")
    parser.add_argument("--info_source", type=str, default="NHS",
                        choices=["nhs", "NHS"],
                        help="Information source to query.")
    args = parser.parse_args()

    try:
        print("Initializing RAG system...")
        rag_system = RAGSystem()

        print(f"\n=== Query: {args.query_text} ===")
        print(f"Source: {args.info_source}")
        print(f"LLM Model: {args.llm_model}")
        print("\n=== LLM Response ===\n")

        response_text, sources_data = "", []
        for chunk, sources in rag_system.query_rag_stream(
            query_text=args.query_text,
            llm_model=args.llm_model,
            similarity_k=args.similarity_k,
            info_source=args.info_source,
        ):
            print(chunk, end="", flush=True)
            response_text += chunk
            sources_data = sources

        print("\n\n=== Sources Data ===\n")
        for i, source in enumerate(sources_data, 1):
            metadata = source.get('metadata', {})
            print(f"Source {i}:")
            print(f"  Clean Section: {metadata.get('clean_section', 'Unknown')}")
            print(f"  URL: {metadata.get('url', 'No URL')}")
            print()
    except Exception as e:
        logger.error("Error in main: %s", e)
        print(f"Error: {e}")


if __name__ == "__main__":
    main()