File size: 14,515 Bytes
b69b364
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
import os
import argparse
import logging
from typing import Dict, List, Optional, Generator, Tuple
from openai import OpenAI
from config import Config, InfoSource
from search_engine import SearchEngine
import voyageai

# Setup logging
# Configured once at import time; every message carries a timestamp and level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

class RAGSystem:
    """Retrieval-augmented generation system over NHS health content.

    Wires together:
      * a Voyage AI client used by ``SearchEngine`` for embedding/search,
      * a Gemini chat client (reached via Google's OpenAI-compatible endpoint),
      * an optional native OpenAI chat client.

    Retrieved document sections are formatted into a strict system prompt and
    the selected LLM streams back a grounded answer together with the sources.
    """

    def __init__(self, shared_data=None):
        # NOTE(review): `shared_data` is accepted but never read here —
        # presumably reserved for callers that share state; confirm before removing.
        self.config = Config()

        # Gemini is driven through the OpenAI SDK pointed at Google's
        # OpenAI-compatible endpoint. Missing key -> client stays None and
        # _stream_llm_response reports the model as unavailable.
        gemini_api_key = os.getenv("GEMINI_API_KEY")
        if gemini_api_key:
            self.gemini_client = OpenAI(
                api_key=gemini_api_key, 
                base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
            )
        else:
            self.gemini_client = None

        # Native OpenAI client for GPT-family models (optional).
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if openai_api_key:
            self.openai_client = OpenAI(api_key=openai_api_key)
        else:
            self.openai_client = None

        # Embedding + vector search backend (requires VOYAGE_API_KEY).
        self.voyage_client = voyageai.Client(api_key=os.getenv("VOYAGE_API_KEY"))
        self.search_engine = SearchEngine(self.voyage_client)

    def _validate_inputs(self, query_text: str, similarity_k: int, info_source: str) -> None:
        """Validate query parameters.

        Raises:
            ValueError: if the query is blank, ``similarity_k`` is not a
                positive integer, or ``info_source`` is not a valid
                ``InfoSource`` value.
        """
        if not query_text or not query_text.strip():
            raise ValueError("Query text cannot be empty")
        
        if similarity_k <= 0:
            raise ValueError("similarity_k must be a positive integer")
        
        try:
            InfoSource(info_source.lower())
        except ValueError:
            valid_sources = [s.value for s in InfoSource]
            raise ValueError(f"Invalid info_source '{info_source}'. Must be one of: {valid_sources}")

    def _clean_section_id(self, section_id: str) -> str:
        """Turn a raw NHS section id into a human-readable label.

        NHS ids look like ``condition__Section__Part_N`` (e.g.
        ``adhd-adults__Overview__Part_1``); the part number is dropped and the
        condition/section pieces are title-cased. Any other id just has its
        underscores and dashes replaced with spaces.
        """
        if not section_id or section_id == 'Unknown section':
            return section_id
        
        # Handle NHS format: "adhd-adults__Overview__Part_1"
        if '__' in section_id:
            parts = section_id.split('__')
            if len(parts) >= 2:
                # Get condition and section, ignore part number
                condition = parts[0].replace('-', ' ').replace('_', ' ').title()
                section = parts[1].replace('_', ' ').title()
                return f"{condition} - {section}"
        
        # Fallback: just clean up underscores and dashes
        clean_section = section_id.replace('_', ' ').replace('-', ' ').title()
        return clean_section
    
    def _get_context_text(self, results: List[Dict]) -> str:
        """Build the LLM context string from search results.

        Each result contributes a "Source Information" header plus its
        document text (and URL, when present); sections are joined with a
        ``---`` separator.
        """
        context_text_sections = []

        for doc in results:
            section_id = doc['metadata'].get('original_id', 'Unknown section')
            url = doc['metadata'].get('url', '')
            document_text = doc['metadata'].get('document', '')
            
            # Clean up section_id for display
            clean_section_id = self._clean_section_id(section_id)
            
            # Create formatted section without showing URL explicitly
            # The URL will be available in the document_text if it was part of the original content
            formatted_section = (
                f"Source Information: [Section: {clean_section_id}]\n"
                f"Context: {document_text}"
                f"{f' Available at: {url}' if url else ''}"  # Include URL for LLM to use
            )
            context_text_sections.append(formatted_section)
        
        return "\n\n---\n\n".join(context_text_sections)
        
    def _create_system_prompt(self, context_text: str, context_description: str, 
                            not_found_message: str, query_text: str) -> List[Dict]:
        """Assemble the chat message list sent to the LLM.

        Returns a three-message conversation: a system prompt with the
        formatting/grounding rules, an assistant message carrying the
        retrieved context, and the user's question.
        """
        return [
            {
                "role": "system",
                "content": (
                    f"You are a medical AI assistant tasked with answering clinical questions strictly based on the provided {context_description} context. Follow the requirements below to ensure accurate, consistent, and professional responses.\n\n"
                    "# Response Rules\n\n"
                    "1. **Context Restriction**:\n"
                    "   - Only use information given in the provided NHS health information context.\n"
                    "   - Do not generate or speculate with information not explicitly found in the given context.\n\n"
                    "2. **Answer Format**:\n"
                    "   - Provide a clear and concise response based solely on the context.\n"
                    "   - When including a list, use standard markdown bullet points (`*` or `-`).\n"
                    "   - If a list follows introductory text, insert a line break before the first bullet point.\n"
                    "   - Each bullet point must be on its own line.\n\n"
                    "3. **Preserve Tables**:\n"
                    "   - If relevant markdown tables appear in the context, reproduce them in your answer.\n"
                    "   - Maintain the original structure, formatting, and content of any included tables.\n\n"
                    "4. **Links and URLs**:\n"
                    "   - Include any URLs or web links from the context directly in your response when relevant.\n"
                    "   - Integrate links naturally within sentences, using markdown syntax for clickable text links.\n"
                    "   - DO NOT generate or invent any URLs not explicitly present in the context.\n\n"
                    "5. **Markdown Link Formatting**:\n"
                    "   - In responses, only the descriptive text in brackets should be visible and clickable (e.g., `[NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/)`).\n"
                    "   - Readers should never see raw URLs in the text.\n"
                    "   - Use descriptive link text like 'NHS ADHD information' or 'NHS depression guide' rather than generic terms.\n\n"
                    "6. **If No Relevant Information**:\n"
                    "   - If the context contains no relevant information, state clearly:\n"
                    f"      *\"{not_found_message}\"*\n\n"
                    "# Output Format\n\n"
                    "- All responses should be in plain text, using markdown formatting for lists and links as required.\n"
                    "- Do not use code blocks.\n"
                    "- Answers should be concise, accurate, and formatted according to the rules above.\n\n"
                    "# Examples\n\n"
                    "**Example 1: Integration of markdown link in context**\n"
                    "Question: \"What are the symptoms of ADHD?\"\n"
                    "Context snippet: ...see the NHS information on ADHD symptoms...\n"
                    "Output:\n"
                    "According to the [NHS ADHD information](https://www.nhs.uk/conditions/attention-deficit-hyperactivity-disorder-adhd/), symptoms include...\n\n"
                    "**Example 2: Multiple condition references**\n"
                    "According to NHS guidance:\n"
                    "* Initial symptoms may include difficulty concentrating.\n"
                    "* For detailed information, see the [NHS ADHD guide](https://www.nhs.uk/conditions/adhd/).\n\n"
                    "**Example 3: No relevant context**\n"
                    f"{not_found_message}\n\n"
                    "# Notes\n\n"
                    "- Never output information beyond what is provided in the supplied context.\n"
                    "- Always use markdown for lists and links.\n"
                    "- Make sure all markdown tables from context are preserved in your answer if relevant.\n"
                    "- Present links only as clickable text, not as bare URLs.\n"
                    "- Use descriptive link text that indicates the specific NHS condition or topic.\n\n"
                    "**REMINDER:**\n"
                    "Strictly adhere to all formatting and content rules above for every response."
                ),
            },
            {
                "role": "assistant", 
                "content": (
                    f"Here is the context from {context_description} that you should use to answer the following question:\n\n{context_text}\n\n"
                ),
            },
            {
                "role": "user",
                "content": query_text,
            },
        ]

    def get_sources_from_results(self, results: List[Dict], info_source: str) -> List[Dict]:
        """Extract display-ready source metadata from search results.

        NOTE(review): ``info_source`` is currently unused; it is kept for
        interface stability with existing callers.
        """
        sources = []
        for doc in results:
            metadata = doc.get('metadata', {})
            section_id = metadata.get('original_id', 'Unknown section')
            source = metadata.get('source', 'Unknown')
            url = metadata.get('url', '')
            
            # Clean section ID for display
            clean_section_id = self._clean_section_id(section_id)
            
            source_info = {
                'metadata': {
                    'source': source,
                    'original_id': section_id,
                    'url': url,
                    'clean_section': clean_section_id
                }
            }
            sources.append(source_info)
        return sources
    
    def query_rag_stream(self, query_text: str, llm_model: str, similarity_k: int = 25, info_source: str = "NHS", 
                        filename_filter: Optional[str] = None) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Run a full RAG query, yielding ``(chunk, sources)`` tuples.

        Validates the inputs, retrieves ``similarity_k`` similar documents,
        builds the prompt, and streams the LLM answer. On any failure a
        single error-message tuple is yielded instead of raising.

        NOTE(review): ``filename_filter`` is accepted but not applied to the
        search — presumably a planned feature; confirm before relying on it.
        """
        try:
            self._validate_inputs(query_text, similarity_k, info_source)
            source_config = self.config.get_source_config(info_source)
            
            # Use the correct namespace from your test
            namespace = "nhs_guidelines_voyage_3_large"
            
            # Get similar documents using only similarity search
            results = self.search_engine.similarity_search(
                query_text, 
                namespace=namespace,
                top_k=similarity_k
            )
            
            if not results:
                yield "I couldn't find any relevant information to answer your question.", []
                return
            
            # Generate context and system prompt
            context_text = self._get_context_text(results)
            system_messages = self._create_system_prompt(
                context_text, 
                source_config.context_description,
                source_config.not_found_message, 
                query_text
            )
            
            # Get sources for response
            sources_data = self.get_sources_from_results(results, info_source)
            
            # Stream LLM response
            yield from self._stream_llm_response(system_messages, query_text, llm_model, sources_data)
            
        except Exception as e:
            logger.error(f"Error in query_rag_stream: {e}")
            yield f"An error occurred while processing your query: {str(e)}", []
    
    def _stream_llm_response(self, system_messages: List[Dict], query_text: str, 
                           llm_model: str, sources_data: List[Dict]) -> Generator[Tuple[str, List[Dict]], None, None]:
        """Stream LLM response chunks, each paired with the retrieval sources.

        Routes Gemini models to the Gemini client and GPT models to the
        OpenAI client; any other model (or a missing client) yields a single
        error message with no sources.
        """
        try:
            # Pick the client matching the requested model family.
            if "gemini" in llm_model.lower() and self.gemini_client:
                client = self.gemini_client
            elif "gpt" in llm_model.lower() and self.openai_client:
                # Fix: the OpenAI client was initialised in __init__ but never
                # used — GPT-family models are now routed to it.
                client = self.openai_client
            else:
                error_msg = f"Unsupported LLM model or client not available: {llm_model}"
                logger.error(error_msg)
                yield error_msg, []
                return

            stream = client.chat.completions.create(
                model=llm_model,
                messages=system_messages,
                temperature=0,
                stream=True
            )

            for chunk in stream:
                # Skip keep-alive/empty chunks; forward only real content.
                if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                    content = chunk.choices[0].delta.content
                    yield content, sources_data

        except Exception as e:
            logger.error(f"Error in LLM completion: {e}")
            yield f"Error generating response: {str(e)}", []



def main():
    """CLI entry point: parse arguments, run one streaming RAG query, and
    print the answer followed by its source metadata."""
    arg_parser = argparse.ArgumentParser(description="RAG System Query Interface")
    arg_parser.add_argument("--query_text", type=str,
                            default="What are the symptoms of ADHD in adults?",
                            help="The query text.")
    arg_parser.add_argument("--llm_model", type=str, default="gemini-2.0-flash",
                            help="The LLM model to use.")
    arg_parser.add_argument("--similarity_k", type=int, default=5,
                            help="Number of results to retrieve in similarity search.")
    arg_parser.add_argument("--info_source", type=str, default="NHS",
                            choices=["nhs", "NHS"],
                            help="Information source to query.")
    opts = arg_parser.parse_args()

    try:
        print("Initializing RAG system...")
        rag = RAGSystem()

        print(f"\n=== Query: {opts.query_text} ===")
        print(f"Source: {opts.info_source}")
        print(f"LLM Model: {opts.llm_model}")
        print("\n=== LLM Response ===\n")

        # Stream the answer; keep the most recent sources list for the footer.
        collected = []
        latest_sources = []
        stream = rag.query_rag_stream(
            query_text=opts.query_text,
            llm_model=opts.llm_model,
            similarity_k=opts.similarity_k,
            info_source=opts.info_source,
        )
        for piece, src_list in stream:
            print(piece, end="", flush=True)
            collected.append(piece)
            latest_sources = src_list

        print("\n\n=== Sources Data ===\n")
        for idx, entry in enumerate(latest_sources, 1):
            meta = entry.get('metadata', {})
            print(f"Source {idx}:")
            print(f"  Clean Section: {meta.get('clean_section', 'Unknown')}")
            print(f"  URL: {meta.get('url', 'No URL')}")
            print()

    except Exception as e:
        logger.error(f"Error in main: {e}")
        print(f"Error: {e}")


# Run the CLI only when this module is executed directly, not on import.
if __name__ == "__main__":
    main()