# NHS-CHAT / src/query_rag.py
# Committed by matthewlewis06: "Updated frontend and cleaned imports" (57d180b)
import os
import argparse
import logging
from typing import Dict, List, Optional, Generator, Tuple
from openai import OpenAI
from config import Config, InfoSource
from search_engine import SearchEngine
import voyageai
# Setup logging: configure the root logger when this module is imported
# (INFO level; each record shows timestamp, level name and message), then
# create the module-level logger used by the code below.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
class RAGSystem:
    """Main RAG system class.

    Retrieves documents through ``SearchEngine.similarity_search``, formats
    them into a context block, builds the chat messages and streams the
    answer from a Gemini model reached through the OpenAI-compatible client.
    """

    # Vector-store namespace passed to the search engine by query_rag_stream.
    DEFAULT_NAMESPACE = "nhs_guidelines_voyage_3_large"

    def __init__(self, shared_data=None):
        """Create the API clients and the search engine.

        Clients whose API key is missing from the environment are set to
        ``None`` rather than raising.

        Args:
            shared_data: Accepted for interface compatibility; not used by
                this constructor.
        """
        self.config = Config()

        # Gemini is reached through its OpenAI-compatible endpoint.
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if gemini_api_key:
            self.gemini_client = OpenAI(
                api_key=gemini_api_key,
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
            )
        else:
            self.gemini_client = None

        # NOTE(review): openai_client is created here but no method of this
        # class uses it; _stream_llm_response only serves "gemini" models.
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_api_key:
            self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            self.openai_client = None

        self.voyage_client = voyageai.Client(api_key=os.getenv("VOYAGE_API_KEY"))
        self.search_engine = SearchEngine(self.voyage_client)

    def _validate_inputs(self, query_text: str, similarity_k: int, info_source: str):
        """Validate input parameters.

        Raises:
            ValueError: If the query is empty or whitespace-only, if
                ``similarity_k`` is not positive, or if ``info_source`` is
                not a string naming a member value of ``InfoSource``
                (compared case-insensitively).
        """
        if not query_text or not query_text.strip():
            raise ValueError("Query text cannot be empty")
        if similarity_k <= 0:
            raise ValueError("similarity_k must be a positive integer")
        try:
            InfoSource(info_source.lower())
        except (ValueError, AttributeError):
            # AttributeError covers a non-string info_source (e.g. None), so
            # callers always see the same ValueError for a bad source.
            valid_sources = [s.value for s in InfoSource]
            raise ValueError(
                f"Invalid info_source '{info_source}'. Must be one of: {valid_sources}"
            ) from None

    def _clean_section_id(self, section_id: str) -> str:
        """Clean section ID for display - NHS format: condition__section__part.

        ``"adhd-adults__Overview__Part_1"`` becomes
        ``"Adhd Adults - Overview"``. Empty values and the placeholder
        ``'Unknown section'`` are returned unchanged; IDs without ``__`` just
        have underscores and dashes replaced by spaces and are title-cased.
        """
        if not section_id or section_id == 'Unknown section':
            return section_id

        if '__' in section_id:
            # Splitting on a separator that is present always yields at
            # least two parts; anything after the section (the part number)
            # is ignored.
            condition, section, *_ = section_id.split('__')
            condition = condition.replace('-', ' ').replace('_', ' ').title()
            section = section.replace('_', ' ').title()
            return f"{condition} - {section}"

        # Fallback: just clean up underscores and dashes
        return section_id.replace('_', ' ').replace('-', ' ').title()

    def _get_context_text(self, results: List[Dict]) -> str:
        """Generate context text from search results.

        Each result contributes one section holding its cleaned section name,
        its document text and, when present, its URL. Sections are joined
        with a ``---`` separator. Results without a ``'metadata'`` entry are
        rendered with placeholder values instead of raising ``KeyError``,
        matching ``get_sources_from_results``.
        """
        context_text_sections = []
        for doc in results:
            metadata = doc.get('metadata') or {}
            section_id = metadata.get('original_id', 'Unknown section')
            url = metadata.get('url', '')
            document_text = metadata.get('document', '')

            clean_section_id = self._clean_section_id(section_id)

            # The URL is appended after the text so the LLM can cite it.
            url_suffix = f" Available at: {url}" if url else ""
            formatted_section = (
                f"Source Information: [Section: {clean_section_id}]\n"
                f"Context: {document_text}"
                f"{url_suffix}"
            )
            context_text_sections.append(formatted_section)
        return "\n\n---\n\n".join(context_text_sections)

    def _create_system_prompt(self, context_text: str, context_description: str,
                              not_found_message: str, query_text: str) -> List[Dict]:
        """Create the chat messages for the LLM.

        Returns three messages in order: the system instructions, a message
        carrying ``context_text``, and the user's ``query_text``.
        """
        system_content = (
            f"You are a medical AI assistant tasked with answering clinical questions strictly based on the provided {context_description} context. Follow the requirements below to ensure accurate, consistent, and professional responses.\n\n"
            "# Response Rules\n\n"
            "1. **Context Restriction**:\n"
            " - Only use information given in the provided NHS health information context.\n"
            " - Do not generate or speculate with information not explicitly found in the given context.\n\n"
            "2. **Answer Format**:\n"
            " - Provide a clear and concise response based solely on the context.\n"
            " - When including a list, use standard markdown bullet points (`*` or `-`).\n"
            " - If a list follows introductory text, insert a line break before the first bullet point.\n"
            " - Each bullet point must be on its own line.\n\n"
            "3. **Preserve Tables**:\n"
            " - If relevant markdown tables appear in the context, reproduce them in your answer.\n"
            " - Maintain the original structure, formatting, and content of any included tables.\n\n"
            "4. **Links and URLs**:\n"
            " - Include any URLs or web links from the context directly in your response when relevant.\n"
            " - Integrate links naturally within sentences, using markdown syntax for clickable text links.\n"
            " - DO NOT generate or invent any URLs not explicitly present in the context.\n\n"
            "5. **Markdown Link Formatting**:\n"
            " - In responses, only the descriptive text in brackets should be visible and clickable (e.g., `[NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/)`).\n"
            " - Readers should never see raw URLs in the text.\n"
            " - Use descriptive link text like 'NHS ADHD information' or 'NHS depression guide' rather than generic terms.\n\n"
            "6. **If No Relevant Information**:\n"
            " - If the context contains no relevant information, state clearly:\n"
            f" *\"{not_found_message}\"*\n\n"
            "# Output Format\n\n"
            "- All responses should be in plain text, using markdown formatting for lists and links as required.\n"
            "- Do not use code blocks.\n"
            "- Answers should be concise, accurate, and formatted according to the rules above.\n\n"
            "# Examples\n\n"
            "**Example 1: Integration of markdown link in context**\n"
            "Question: \"What are the symptoms of ADHD?\"\n"
            "Context snippet: ...see the NHS information on ADHD symptoms...\n"
            "Output:\n"
            "According to the [NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/), symptoms include...\n\n"
            "**Example 2: Multiple condition references**\n"
            "According to NHS guidance:\n"
            "* Initial symptoms may include difficulty concentrating.\n"
            "* For detailed information, see the [NHS ADHD guide](https://www.nhs.uk/conditions/adhd/).\n\n"
            "**Example 3: No relevant context**\n"
            f"{not_found_message}\n\n"
            "# Notes\n\n"
            "- Never output information beyond what is provided in the supplied context.\n"
            "- Always use markdown for lists and links.\n"
            "- Make sure all markdown tables from context are preserved in your answer if relevant.\n"
            "- Present links only as clickable text, not as bare URLs.\n"
            "- Use descriptive link text that indicates the specific NHS condition or topic.\n\n"
            "**REMINDER:**\n"
            "Strictly adhere to all formatting and content rules above for every response."
        )
        context_content = (
            f"Here is the context from {context_description} that you should use to answer the following question:\n\n{context_text}\n\n"
        )
        return [
            {"role": "system", "content": system_content},
            # NOTE(review): the retrieved context is sent as an "assistant"
            # turn; confirm this role is intentional.
            {"role": "assistant", "content": context_content},
            {"role": "user", "content": query_text},
        ]

    def get_sources_from_results(self, results: List[Dict], info_source: str) -> List[Dict]:
        """Extract formatted sources from search results.

        Args:
            results: Search results; each may carry a ``'metadata'`` dict.
            info_source: Accepted for interface compatibility; not used here.

        Returns:
            One ``{'metadata': {...}}`` dict per result with the keys
            ``source``, ``original_id``, ``url`` and ``clean_section``.
        """
        sources = []
        for doc in results:
            metadata = doc.get('metadata') or {}
            section_id = metadata.get('original_id', 'Unknown section')
            sources.append({
                'metadata': {
                    'source': metadata.get('source', 'Unknown'),
                    'original_id': section_id,
                    'url': metadata.get('url', ''),
                    'clean_section': self._clean_section_id(section_id),
                }
            })
        return sources

    def query_rag_stream(self, query_text: str, llm_model: str, similarity_k: int = 25, info_source: str = "NHS",
                         filename_filter: Optional[str] = None) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Query RAG system with streaming response.

        Args:
            query_text: The user's question.
            llm_model: Model name forwarded to the LLM client.
            similarity_k: Number of documents requested from the search.
            info_source: Source name; validated against ``InfoSource``.
            filename_filter: Accepted for interface compatibility; this
                method does not apply it to the search.

        Yields:
            ``(text_chunk, sources)`` tuples. Any failure is reported as a
            single tuple holding an error message and an empty source list
            instead of being raised.
        """
        try:
            self._validate_inputs(query_text, similarity_k, info_source)
            source_config = self.config.get_source_config(info_source)

            # Get similar documents using only similarity search
            results = self.search_engine.similarity_search(
                query_text,
                namespace=self.DEFAULT_NAMESPACE,
                top_k=similarity_k
            )
            if not results:
                yield "I couldn't find any relevant information to answer your question.", []
                return

            # Generate context and system prompt
            context_text = self._get_context_text(results)
            system_messages = self._create_system_prompt(
                context_text,
                source_config.context_description,
                source_config.not_found_message,
                query_text
            )

            # Get sources for response
            sources_data = self.get_sources_from_results(results, info_source)

            # Stream LLM response
            yield from self._stream_llm_response(system_messages, query_text, llm_model, sources_data)
        except Exception as e:
            # Boundary handler: log with traceback and surface the message
            # to the consumer of the stream.
            logger.exception("Error in query_rag_stream: %s", e)
            yield f"An error occurred while processing your query: {str(e)}", []

    def _stream_llm_response(self, system_messages: List[Dict], query_text: str,
                             llm_model: str, sources_data: List[Dict]) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Stream LLM response.

        Only models whose name contains "gemini" are served, and only when
        the Gemini client was created. Every non-empty content delta is
        yielded together with ``sources_data``; on an unsupported model or an
        exception a single error message with an empty source list is
        yielded instead.
        """
        try:
            if not ("gemini" in llm_model.lower() and self.gemini_client):
                error_msg = f"Unsupported LLM model or client not available: {llm_model}"
                logger.error(error_msg)
                yield error_msg, []
                return

            stream = self.gemini_client.chat.completions.create(
                model=llm_model,
                messages=system_messages,
                temperature=0,
                stream=True
            )
            for chunk in stream:
                # Skip chunks that carry no text content.
                if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                    yield chunk.choices[0].delta.content, sources_data
        except Exception as e:
            logger.exception("Error in LLM completion: %s", e)
            yield f"Error generating response: {str(e)}", []
def main():
    """Main function for CLI usage.

    Parses the command-line options, runs one streaming query through
    ``RAGSystem.query_rag_stream``, echoes the answer chunks as they arrive
    and then lists the sources delivered with the last chunk. Any exception
    is logged and printed rather than propagated.
    """
    parser = argparse.ArgumentParser(description="RAG System Query Interface")
    parser.add_argument("--query_text", type=str, default="What are the symptoms of ADHD in adults?",
                        help="The query text.")
    parser.add_argument("--llm_model", type=str, default="gemini-2.0-flash",
                        help="The LLM model to use.")
    parser.add_argument("--similarity_k", type=int, default=5,
                        help="Number of results to retrieve in similarity search.")
    parser.add_argument("--info_source", type=str, default="NHS",
                        choices=["nhs", "NHS"],
                        help="Information source to query.")
    args = parser.parse_args()

    try:
        print("Initializing RAG system...")
        rag_system = RAGSystem()

        print(f"\n=== Query: {args.query_text} ===")
        print(f"Source: {args.info_source}")
        print(f"LLM Model: {args.llm_model}")
        print("\n=== LLM Response ===\n")

        # Every chunk arrives with a sources list; keep the most recent one.
        sources_data = []
        for chunk, sources in rag_system.query_rag_stream(
            query_text=args.query_text,
            llm_model=args.llm_model,
            similarity_k=args.similarity_k,
            info_source=args.info_source
        ):
            # Flush per chunk so the streamed answer appears immediately.
            print(chunk, end="", flush=True)
            sources_data = sources

        print("\n\n=== Sources Data ===\n")
        for i, source in enumerate(sources_data, 1):
            metadata = source.get('metadata', {})
            print(f"Source {i}:")
            print(f" Clean Section: {metadata.get('clean_section', 'Unknown')}")
            print(f" URL: {metadata.get('url', 'No URL')}")
            print()
    except Exception as e:
        # Top-level boundary: keep the traceback in the log, show a short
        # message to the user.
        logger.exception("Error in main: %s", e)
        print(f"Error: {e}")


if __name__ == "__main__":
    main()