Zeggai Abdellah committed on
Commit
5a74e30
·
1 Parent(s): f5c821c

Update the handling of the complex query

Browse files
Files changed (2) hide show
  1. prepare_env.py +294 -273
  2. rag_pipeline.py +273 -294
prepare_env.py CHANGED
@@ -1,16 +1,22 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Enhanced RAG Pipeline for vaccine assistant - Fixed version with max iterations control
4
- Handles agent creation and question answering with sequential citation numbering
5
  """
6
 
 
7
  import json
8
  import re
9
- from llama_index.core import PromptTemplate
10
- from llama_index.core.agent import ReActAgent
11
- from llama_index.llms.google_genai import GoogleGenAI
12
- from langdetect import detect
13
- import os
 
 
 
 
 
14
 
15
 
16
  def extract_source_ids(response_text):
@@ -47,13 +53,8 @@ def extract_source_ids(response_text):
47
  ids = [id_str.strip() for id_str in citation.split(',')]
48
  all_ids.extend(ids)
49
 
50
- # Get unique source IDs while preserving order
51
- seen = set()
52
- source_ids = []
53
- for id_str in all_ids:
54
- if id_str not in seen:
55
- seen.add(id_str)
56
- source_ids.append(id_str)
57
 
58
  if not source_ids:
59
  print("Warning: No valid source IDs found after filtering.")
@@ -62,301 +63,321 @@ def extract_source_ids(response_text):
62
  return source_ids
63
 
64
 
65
- def convert_citations_to_sequential(response_text, source_id_to_number_map):
66
- """
67
- Convert source IDs in response text to sequential numbers.
 
 
 
68
 
69
- Args:
70
- response_text (str): The response text with source ID citations
71
- source_id_to_number_map (dict): Mapping from source IDs to sequential numbers
72
-
73
- Returns:
74
- str: Response text with sequential number citations
75
- """
76
- def replace_citation(match):
77
- citation_content = match.group(1)
78
- # Handle multiple IDs in one citation (comma-separated)
79
- ids = [id_str.strip() for id_str in citation_content.split(',')]
80
-
81
- # Convert each ID to its sequential number
82
- numbers = []
83
- for id_str in ids:
84
- if id_str in source_id_to_number_map:
85
- numbers.append(str(source_id_to_number_map[id_str]))
86
-
87
- # Return the formatted citation with sequential numbers
88
- if len(numbers) == 1:
89
- return f"[{numbers[0]}]"
90
- elif len(numbers) > 1:
91
- return f"[{','.join(numbers)}]"
92
- else:
93
- return match.group(0) # Return original if no mapping found
94
 
95
- # Replace all citations in the text
96
- sequential_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, response_text)
97
- return sequential_response
98
 
99
 
100
- def create_safe_custom_prompt(tools, llm):
101
- """Create a safe version that won't have formatting conflicts"""
 
 
 
102
 
103
- custom_instructions = """
104
- ## MEDICAL ASSISTANT ROLE
105
- You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
106
- You provide evidence-based guidance using only information from official vaccine medical documents.
107
- Answer the doctor's question accurately and concisely using only the provided information.
108
-
109
- ## CRITICAL RULES FOR EFFICIENCY
110
-
111
- ### Tool Usage Strategy
112
- 1. **MAXIMUM 3 TOOL CALLS**: You must provide a complete answer within 3 tool calls maximum.
113
- 2. **Smart Tool Selection**: Choose the most relevant tool first based on the question topic.
114
- 3. **Comparative Questions**: For questions comparing documents/protocols:
115
- - First tool call: Get information from primary source (e.g., Algerian guide)
116
- - Second tool call: Get information from secondary source (e.g., WHO document)
117
- - Third tool call: Only if absolutely necessary for missing details
118
- 4. **Stop Early**: If you have sufficient information after 1-2 tool calls, provide your answer immediately.
119
-
120
- ### Citation and Sourcing
121
- 1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
122
- 2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
123
- 3. If a fact is supported by multiple sources, use adjacent citations: [source1][source2]
124
- 4. Use ONLY the provided information and never include facts from your general knowledge.
125
-
126
- ### Content Formatting
127
- 1. When rendering tables:
128
- - Convert HTML tables into clean Markdown format
129
- - Preserve all original headers and data rows exactly
130
- - Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
131
- 2. For lists, maintain the original bullet points/numbering and include citations.
132
- 3. Present information concisely but ensure clinical accuracy is never compromised.
133
-
134
- ### Answer Completeness Guidelines
135
- - If you find relevant information from 1-2 sources, synthesize and provide a complete answer
136
- - Don't keep searching for more sources unless critical information is missing
137
- - For comparative questions, clearly structure your answer with sections for each source
138
- - If information is not available in the documents, clearly state this limitation
139
-
140
- ---
141
-
142
- """
143
-
144
- # Get the exact original template first
145
- temp_agent = ReActAgent.from_tools(tools, llm=llm, verbose=False)
146
- original_prompts = temp_agent.get_prompts()
147
- original_template = original_prompts["agent_worker:system_prompt"].template
148
-
149
- # Add instructions at the very beginning
150
- safe_template = f"{custom_instructions}{original_template}"
151
-
152
- # Create new prompt with same metadata as original
153
- original_prompt = original_prompts["agent_worker:system_prompt"]
154
 
155
- try:
156
- new_prompt = PromptTemplate(
157
- template=safe_template,
158
- template_vars=original_prompt.template_vars,
159
- metadata=original_prompt.metadata if hasattr(original_prompt, 'metadata') else None
160
- )
161
- return new_prompt
162
- except:
163
- # Even safer fallback
164
- return PromptTemplate(template=safe_template)
165
 
 
 
166
 
167
- def create_agent(tools, llm):
168
- """Create the ReAct agent with custom prompt and controlled max iterations"""
169
-
170
- # Create agent with controlled max iterations (reduced from default 10 to 5)
171
- agent = ReActAgent.from_tools(
172
- tools,
173
- llm=llm,
174
- verbose=True,
175
- max_iterations=5, # Reduced max iterations
176
  )
 
177
 
178
- # Create and apply safe custom prompt
179
- try:
180
- safe_custom_prompt = create_safe_custom_prompt(tools, llm)
181
- agent.update_prompts({"agent_worker:system_prompt": safe_custom_prompt})
182
- print("✅ Successfully updated with safe custom prompt and max_iterations=5")
183
- except Exception as e:
184
- print(f"❌ Safe prompt update failed: {e}")
185
- print("⚠️ Using original agent without modifications")
186
-
187
- return agent
188
 
189
-
190
- def initialize_rag_pipeline(tools):
191
- """Initialize the RAG pipeline with tools"""
 
 
 
 
192
 
193
- # Initialize LlamaIndex LLM with specific parameters to improve efficiency
194
- llama_index_llm = GoogleGenAI(
195
- model="models/gemini-2.0-flash",
196
- api_key=os.getenv('GOOGLE_API_KEY'),
197
- temperature=0.1, # Lower temperature for more focused responses
 
 
 
198
  )
199
 
200
- # Create agent
201
- agent = create_agent(tools, llama_index_llm)
 
 
 
202
 
203
- return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
 
 
205
 
206
- def process_question(agent, question: str) -> str:
207
- """Process a question through the RAG pipeline with timeout handling"""
208
- try:
209
- # Add timeout/retry logic
210
- response = agent.chat(question)
211
- return response.response
212
- except Exception as e:
213
- error_msg = str(e)
214
- print(f"Error processing question: {error_msg}")
215
-
216
- # Handle specific "max iterations" error
217
- if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
218
- return ("I apologize, but I was unable to find a complete answer within the allowed search attempts. "
219
- "This might be because the specific comparison you're asking about requires information "
220
- "that spans multiple sections of the documents. Could you please rephrase your question "
221
- "to be more specific about which aspect of the difference you're most interested in?")
222
 
223
- return f"Error processing your question: {error_msg}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
 
226
- def aswer_language_detection(response_text: str) -> str:
227
- """
228
- Detect the language of the response text.
229
 
230
- Args:
231
- response_text (str): The response text to analyze.
232
-
233
- Returns:
234
- str: Detected language code (e.g., 'en', 'fr', etc.)
235
- """
236
- try:
237
- # Detect the language of the first 5 words of the response
238
- first_line = " ".join(response_text.split()[:5])
239
- first_line = re.sub(r'\[.*?\]', '', first_line) # Remove citations
240
- answer_language = detect(first_line)
241
- if answer_language not in ['en', 'ar', 'fr']:
242
- answer_language = 'en'
243
- except:
244
- answer_language = 'en'
 
 
 
 
 
245
 
246
- return answer_language
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
 
248
 
249
- def process_question_with_sequential_citations(agent, question: str, chunks_directory="./data/") -> dict:
250
- """
251
- Process a question through the RAG pipeline and return response with sequential citation numbers.
252
- Enhanced with better error handling for max iterations.
253
-
254
- Args:
255
- agent: The initialized RAG agent
256
- question (str): The user's question
257
- chunks_directory (str): Path to the directory containing JSON files
258
 
259
- Returns:
260
- dict: {
261
- "response": str, # Response with sequential citation numbers [1], [2], etc.
262
- "cited_elements_json": str, # JSON array of cited elements in order
263
- "unique_ids": list, # Original source IDs in order
264
- "citation_mapping": dict # Mapping from source ID to citation number
265
- }
266
- """
267
- try:
268
- # Get the response from the agent with improved error handling
269
- response = agent.chat(question)
270
- response_text = response.response
271
 
272
- # Check if the response indicates max iterations was reached
273
- if "max iterations" in response_text.lower() or len(response_text.strip()) == 0:
274
- # Provide a more helpful fallback response
275
- response_text = ("I apologize, but I encountered difficulties processing your comparative question "
276
- "within the allowed search attempts. For questions comparing different protocols "
277
- "or documents, please try asking about each aspect separately. For example, "
278
- "first ask about the Algerian definition of Diphtheria, then ask about the WHO definition.")
279
 
280
- # Extract source IDs from the response (preserving order)
281
- unique_ids = extract_source_ids(response_text)
282
 
283
- # Create mapping from source ID to sequential number
284
- source_id_to_number = {source_id: i + 1 for i, source_id in enumerate(unique_ids)}
 
 
 
 
 
 
 
 
 
 
 
285
 
286
- # Convert citations to sequential numbers
287
- sequential_response = convert_citations_to_sequential(response_text, source_id_to_number)
288
 
289
- # Load all chunks data to find cited elements
290
- all_chunks_data = []
291
- min_chunks_files = ["Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json",
292
- "Immunization_in_Practice_WHO_eng_2015.json"]
 
293
 
294
- for json_file in min_chunks_files:
295
- json_path = os.path.join(chunks_directory, json_file)
296
- try:
297
- with open(json_path, "r", encoding="utf-8") as f:
298
- chunks_data = json.load(f)
299
- all_chunks_data.extend(chunks_data)
300
- except Exception as e:
301
- print(f"Warning: Could not load {json_file}: {e}")
302
 
303
- # Get cited elements in the same order as the sequential citations
304
- cited_elements_ordered = []
305
- for source_id in unique_ids: # This preserves the order
306
- for element in all_chunks_data:
307
- if element.get("type") == 'TableElement':
308
- if element.get("element_id") == source_id:
309
- cited_elements_ordered.append(element)
310
- break
311
- else:
312
- if "elements" in element:
313
- for nested_element in element["elements"]:
314
- if nested_element.get("element_id") == source_id:
315
- cited_elements_ordered.append(nested_element)
316
- break
317
- else:
318
- continue
319
- break
320
 
321
- # Convert to JSON
322
- cited_elements_json = json.dumps(cited_elements_ordered, ensure_ascii=False, indent=2)
323
- answer_language = aswer_language_detection(response_text)
324
 
325
- return {
326
- "response": sequential_response,
327
- "cited_elements_json": cited_elements_json,
328
- "unique_ids": unique_ids,
329
- "citation_mapping": source_id_to_number,
330
- "answer_language": answer_language
331
- }
332
 
333
- except Exception as e:
334
- error_msg = str(e)
335
- print(f"Error processing question: {error_msg}")
336
 
337
- # Create appropriate fallback response based on error type
338
- if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
339
- fallback_response = ("I apologize, but I was unable to complete the comparison within the allowed search attempts. "
340
- "For complex comparative questions like yours about the differences between Algerian and WHO "
341
- "definitions of Diphtheria, please try asking about each source separately: \n\n"
342
- "1. First ask: 'What is the definition of Diphtheria in the Algerian vaccination guide?'\n"
343
- "2. Then ask: 'What is the definition of Diphtheria in the WHO document?'\n\n"
344
- "This will help me provide you with more focused and complete information.")
345
- else:
346
- fallback_response = f"I encountered an error while processing your question: {error_msg}"
 
 
 
347
 
348
- return {
349
- "response": fallback_response,
350
- "cited_elements_json": "[]",
351
- "unique_ids": [],
352
- "citation_mapping": {},
353
- "answer_language": "en"
354
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
- def process_question_with_citations(agent, question: str, chunks_directory="./data/") -> dict:
358
- """
359
- Legacy function - maintained for backward compatibility.
360
- Now calls the new sequential citation function.
361
- """
362
- return process_question_with_sequential_citations(agent, question, chunks_directory)
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Environment preparation script for vaccine assistant - Improved version
4
+ Creates vector stores and retrieval tools with better descriptions for efficient agent routing
5
  """
6
 
7
+ import os
8
  import json
9
  import re
10
+ import nest_asyncio
11
+ from typing import List
12
+ from langchain_community.vectorstores import Chroma
13
+ from langchain_core.documents import Document
14
+ from langchain.embeddings import HuggingFaceEmbeddings
15
+ from langchain.retrievers import BM25Retriever, EnsembleRetriever
16
+ from langchain.retrievers.multi_query import MultiQueryRetriever
17
+ from langchain_google_genai import ChatGoogleGenerativeAI
18
+ from llama_index.core.tools import FunctionTool
19
+ from llama_index.core.schema import TextNode
20
 
21
 
22
  def extract_source_ids(response_text):
 
53
  ids = [id_str.strip() for id_str in citation.split(',')]
54
  all_ids.extend(ids)
55
 
56
+ # Get unique source IDs
57
+ source_ids = list(set(all_ids))
 
 
 
 
 
58
 
59
  if not source_ids:
60
  print("Warning: No valid source IDs found after filtering.")
 
63
  return source_ids
64
 
65
 
66
+ def setup_models():
67
+ """Initialize embedding model and LLM"""
68
+ # Initialize embedding model
69
+ embedding_function = HuggingFaceEmbeddings(
70
+ model_name="intfloat/multilingual-e5-base"
71
+ )
72
 
73
+ # Initialize LLM with better parameters for focused responses
74
+ genai_api_key = os.getenv('GOOGLE_API_KEY')
75
+ llm = ChatGoogleGenerativeAI(
76
+ model="gemini-2.0-flash",
77
+ google_api_key=genai_api_key,
78
+ temperature=0.1 # Lower temperature for more focused responses
79
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
+ return embedding_function, llm
 
 
82
 
83
 
84
+ def create_vectorstore_from_json(json_path: str, collection_name: str, embedding_function):
85
+ """Create vector store from JSON chunks"""
86
+ # Load the chunks.json
87
+ with open(json_path, "r", encoding="utf-8") as f:
88
+ chunks_data = json.load(f)
89
 
90
+ documents = []
91
+ for element in chunks_data:
92
+ text = element["text"]
93
+ metadata = {
94
+ "language": "fra",
95
+ "source": element["filename"],
96
+ "filetype": element["filetype"],
97
+ "element_id": element["element_id"]
98
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
+ if "TableElement" == element["type"]:
101
+ metadata["table_text_as_html"] = element["table_text_as_html"]
 
 
 
 
 
 
 
 
102
 
103
+ doc = Document(page_content=text, metadata=metadata)
104
+ documents.append(doc)
105
 
106
+ # Create vector store
107
+ vectorstore = Chroma.from_documents(
108
+ documents=documents,
109
+ embedding=embedding_function,
110
+ collection_name=collection_name,
111
+ persist_directory="chroma_db_multilingual"
 
 
 
112
  )
113
+ return vectorstore, documents
114
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ def create_retriever(vectorstore, docs, llm):
117
+ """Create ensemble retriever with vector and BM25 search"""
118
+ # Vector retriever
119
+ vector_retriever = vectorstore.as_retriever(
120
+ search_type="similarity",
121
+ search_kwargs={"k": 4} # Reduced from 6 to 4 for efficiency
122
+ )
123
 
124
+ # BM25 retriever
125
+ bm25_retriever = BM25Retriever.from_documents(docs)
126
+ bm25_retriever.k = 2
127
+
128
+ # Ensemble retriever
129
+ ensemble_retriever = EnsembleRetriever(
130
+ retrievers=[vector_retriever, bm25_retriever],
131
+ weights=[0.5, 0.5]
132
  )
133
 
134
+ # Multi-query expanding retriever (with reduced complexity for efficiency)
135
+ expanding_retriever = MultiQueryRetriever.from_llm(
136
+ retriever=ensemble_retriever,
137
+ llm=llm
138
+ )
139
 
140
+ return expanding_retriever
141
+
142
+
143
+ def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
144
+ """Convert ChromaDB Document objects to LlamaIndex TextNode objects"""
145
+ nodes = []
146
+ for i, doc in enumerate(chromadb_documents):
147
+ try:
148
+ text = doc.page_content
149
+ metadata = doc.metadata.copy()
150
+ element_id = metadata.get("element_id", f"doc_{i}")
151
+ source = metadata.get("source", "unknown")
152
+ node_id = f"{source}_{element_id}"
153
+
154
+ node = TextNode(
155
+ text=text,
156
+ metadata=metadata,
157
+ id_=node_id
158
+ )
159
+ nodes.append(node)
160
+ except Exception as e:
161
+ continue
162
+ return nodes
163
+
164
+
165
+ def section_tool_wrapper(retriever, section_path_chunks, query):
166
+ """Generic section tool wrapper with improved efficiency"""
167
+ try:
168
+ retrieved_docs = retriever.get_relevant_documents(query)
169
+ nodes_from_retrieved_docs = convert_chromadb_to_llamaindex_nodes(retrieved_docs)
170
 
171
+ if not nodes_from_retrieved_docs:
172
+ return "No relevant documents found for the query."
173
 
174
+ chunk_ids = [node.metadata['element_id'] for node in retrieved_docs]
175
+ with open(section_path_chunks, "r", encoding="utf-8") as f:
176
+ chunks_data = json.load(f)
177
+
178
+ chunks_unique = [node for node in chunks_data if node.get('element_id', 'Unknown') in chunk_ids]
179
+ combined_text = []
 
 
 
 
 
 
 
 
 
 
180
 
181
+ # Limit the number of chunks to avoid overwhelming the context
182
+ max_chunks = 8 # Reasonable limit
183
+ for chu in chunks_unique[:max_chunks]:
184
+ if "TableElement" == chu["type"]:
185
+ text = f"[{chu['element_id']}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
186
+ combined_text.append(text)
187
+ else:
188
+ for element in chu["elements"]:
189
+ text = f"[{element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
190
+ combined_text.append(text)
191
+
192
+ result = "\n---\n".join(combined_text)
193
+ print(f"Retrieved {len(nodes_from_retrieved_docs)} documents for query: {query[:50]}...")
194
+ return result
195
+ except Exception as e:
196
+ print(f"Error in section tool: {e}")
197
+ return f"Error retrieving documents: {str(e)}"
198
 
199
 
200
+ def create_section_tools(embedding_function, llm):
201
+ """Create all section-specific retrieval tools with improved descriptions"""
 
202
 
203
+ # Define section paths
204
+ section_paths = {
205
+ 'one': 'section_one_chunks.json',
206
+ 'two': 'section_two_chunks.json',
207
+ 'three': 'section_three_chunks.json',
208
+ 'four': 'section_four_chunks.json',
209
+ 'five': 'section_five_chunks.json',
210
+ 'six': 'section_six_chunks.json',
211
+ 'seven': 'section_seven_chunks.json',
212
+ 'eight': 'section_eight_chunks.json',
213
+ 'nine': 'section_nine_chunks.json',
214
+ 'ten': 'section_ten_chunks.json'
215
+ }
216
+
217
+ # Create retrievers for each section
218
+ section_retrievers = {}
219
+ for section, path in section_paths.items():
220
+ if os.path.exists(f'./data/{path}'):
221
+ vstore, docs = create_vectorstore_from_json(f'./data/{path}', f"Guide_2023_{section}", embedding_function)
222
+ section_retrievers[section] = create_retriever(vstore, docs, llm)
223
 
224
+ # Create main guide retriever
225
+ guide_path = './data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
226
+ if os.path.exists(guide_path):
227
+ guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
228
+ guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
229
+ else:
230
+ guide_retriever = None
231
+
232
+ # Primary + Secondary Document Paths
233
+ immunization_path = './data/Immunization_in_Practice_WHO_eng_2015.json'
234
+
235
+ # WHO Immunization in Practice Tool
236
+ if os.path.exists(immunization_path):
237
+ immunization_vstore, immunization_docs = create_vectorstore_from_json(
238
+ immunization_path,
239
+ "Immunization_in_Practice_WHO_eng_2015",
240
+ embedding_function
241
+ )
242
+ immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
243
+ else:
244
+ immunization_retriever = None
245
 
246
+ # Tool Functions with Improved Efficiency Focus
247
 
248
+ def guide_retrieval_tool(query: str) -> str:
249
+ """
250
+ **PRIMARY TOOL - USE FIRST FOR MOST QUESTIONS**
 
 
 
 
 
 
251
 
252
+ Comprehensive search across the entire Algerian National Vaccination Guide (2023).
 
 
 
 
 
 
 
 
 
 
 
253
 
254
+ **When to use this tool:**
255
+ - General vaccination questions
256
+ - Disease definitions and descriptions
257
+ - Vaccine schedules and protocols
258
+ - Comparative questions needing Algerian perspective
259
+ - Any question about Algeria's vaccination program
 
260
 
261
+ **Keywords that indicate this tool:** Algeria, Algerian, national, calendrier, vaccination, PEV, diseases (diphteria, polio, measles, etc.)
 
262
 
263
+ Args:
264
+ query (str): Any vaccination-related question about Algeria's national program
265
+
266
+ Returns:
267
+ str: Comprehensive information from the Algerian guide with citations
268
+ """
269
+ if not guide_retriever:
270
+ return "Guide retriever not available"
271
+ return section_tool_wrapper(guide_retriever, guide_path, query)
272
+
273
+ def immunization_tool(query: str) -> str:
274
+ """
275
+ **SECONDARY TOOL - USE FOR WHO/INTERNATIONAL PERSPECTIVE**
276
 
277
+ WHO Immunization in Practice 2015 - Global best practices and international standards.
 
278
 
279
+ **When to use this tool:**
280
+ - Questions specifically asking about WHO recommendations
281
+ - International/global immunization practices
282
+ - Comparative questions needing WHO perspective
283
+ - Technical immunization procedures and best practices
284
 
285
+ **Keywords that indicate this tool:** WHO, international, global, best practices, standards
 
 
 
 
 
 
 
286
 
287
+ Args:
288
+ query (str): Question about international immunization practices or WHO recommendations
289
+
290
+ Returns:
291
+ str: WHO guidance and international best practices with citations
292
+ """
293
+ if not immunization_retriever:
294
+ return "Immunization in Practice retriever not available"
295
+ return section_tool_wrapper(immunization_retriever, immunization_path, query)
296
+
297
+ # Section-Specific Tools (USE ONLY IF QUESTION IS VERY SPECIFIC TO THE SECTION)
298
+
299
+ def section_two_tool(query: str) -> str:
300
+ """
301
+ **DISEASE-SPECIFIC TOOL**
 
 
302
 
303
+ Section 2: Vaccine-preventable diseases - definitions, symptoms, transmission, complications.
 
 
304
 
305
+ **Use ONLY for specific disease definition questions like:**
306
+ - "What is diphtheria?"
307
+ - "Define measles according to Algerian protocol"
308
+ - "Symptoms of polio"
 
 
 
309
 
310
+ **Keywords:** definition, symptoms, transmission, complications, disease characteristics
 
 
311
 
312
+ Args:
313
+ query (str): Specific question about disease definitions or characteristics
314
+
315
+ Returns:
316
+ str: Disease-specific medical information with citations
317
+ """
318
+ if 'two' not in section_retrievers:
319
+ return "Section 2 retriever not available"
320
+ return section_tool_wrapper(section_retrievers['two'], f'./data/{section_paths["two"]}', query)
321
+
322
+ def section_three_tool(query: str) -> str:
323
+ """
324
+ **VACCINE-SPECIFIC TOOL**
325
 
326
+ Section 3: Vaccine details - types, composition, administration methods.
327
+
328
+ **Use ONLY for specific vaccine technical questions like:**
329
+ - "What type of vaccine is used for diphtheria?"
330
+ - "How is the MMR vaccine administered?"
331
+ - "Vaccine composition and dosage"
332
+
333
+ **Keywords:** vaccine type, composition, administration, dosage, technical details
334
+
335
+ Args:
336
+ query (str): Technical question about specific vaccines
337
+
338
+ Returns:
339
+ str: Technical vaccine information with citations
340
+ """
341
+ if 'three' not in section_retrievers:
342
+ return "Section 3 retriever not available"
343
+ return section_tool_wrapper(section_retrievers['three'], f'./data/{section_paths["three"]}', query)
344
+
345
+ # Create FunctionTool objects with focused selection
346
+ tools = [
347
+ # Primary tools - most commonly used
348
+ FunctionTool.from_defaults(
349
+ name="algerian_guide_search",
350
+ fn=guide_retrieval_tool,
351
+ description="PRIMARY TOOL: Search the complete Algerian National Vaccination Guide for any vaccination-related question"
352
+ ),
353
+ FunctionTool.from_defaults(
354
+ name="who_immunization_search",
355
+ fn=immunization_tool,
356
+ description="SECONDARY TOOL: Search WHO Immunization in Practice for international standards and WHO recommendations"
357
+ ),
358
+ # Specialized tools - use only when very specific
359
+ FunctionTool.from_defaults(
360
+ name="disease_definitions_search",
361
+ fn=section_two_tool,
362
+ description="SPECIALIZED: Search for specific disease definitions, symptoms, and characteristics"
363
+ ),
364
+ FunctionTool.from_defaults(
365
+ name="vaccine_technical_search",
366
+ fn=section_three_tool,
367
+ description="SPECIALIZED: Search for technical vaccine details, composition, and administration methods"
368
+ ),
369
+ ]
370
+
371
+ return tools
372
 
373
 
374
+ def prepare_environment():
375
+ """Main function to prepare the environment and return tools"""
376
+ print("Setting up models...")
377
+ embedding_function, llm = setup_models()
378
+
379
+ print("Creating section tools...")
380
+ tools = create_section_tools(embedding_function, llm)
381
+
382
+ print("Environment prepared successfully!")
383
+ return tools, llm
rag_pipeline.py CHANGED
@@ -1,22 +1,16 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Environment preparation script for vaccine assistant - Improved version
4
- Creates vector stores and retrieval tools with better descriptions for efficient agent routing
5
  """
6
 
7
- import os
8
  import json
9
  import re
10
- import nest_asyncio
11
- from typing import List
12
- from langchain_community.vectorstores import Chroma
13
- from langchain_core.documents import Document
14
- from langchain.embeddings import HuggingFaceEmbeddings
15
- from langchain.retrievers import BM25Retriever, EnsembleRetriever
16
- from langchain.retrievers.multi_query import MultiQueryRetriever
17
- from langchain_google_genai import ChatGoogleGenerativeAI
18
- from llama_index.core.tools import FunctionTool
19
- from llama_index.core.schema import TextNode
20
 
21
 
22
  def extract_source_ids(response_text):
@@ -53,8 +47,13 @@ def extract_source_ids(response_text):
53
  ids = [id_str.strip() for id_str in citation.split(',')]
54
  all_ids.extend(ids)
55
 
56
- # Get unique source IDs
57
- source_ids = list(set(all_ids))
 
 
 
 
 
58
 
59
  if not source_ids:
60
  print("Warning: No valid source IDs found after filtering.")
@@ -63,321 +62,301 @@ def extract_source_ids(response_text):
63
  return source_ids
64
 
65
 
66
- def setup_models():
67
- """Initialize embedding model and LLM"""
68
- # Initialize embedding model
69
- embedding_function = HuggingFaceEmbeddings(
70
- model_name="intfloat/multilingual-e5-base"
71
- )
72
 
73
- # Initialize LLM with better parameters for focused responses
74
- genai_api_key = os.getenv('GOOGLE_API_KEY')
75
- llm = ChatGoogleGenerativeAI(
76
- model="gemini-2.0-flash",
77
- google_api_key=genai_api_key,
78
- temperature=0.1 # Lower temperature for more focused responses
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
- return embedding_function, llm
 
 
82
 
83
 
84
- def create_vectorstore_from_json(json_path: str, collection_name: str, embedding_function):
85
- """Create vector store from JSON chunks"""
86
- # Load the chunks.json
87
- with open(json_path, "r", encoding="utf-8") as f:
88
- chunks_data = json.load(f)
89
 
90
- documents = []
91
- for element in chunks_data:
92
- text = element["text"]
93
- metadata = {
94
- "language": "fra",
95
- "source": element["filename"],
96
- "filetype": element["filetype"],
97
- "element_id": element["element_id"]
98
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- if "TableElement" == element["type"]:
101
- metadata["table_text_as_html"] = element["table_text_as_html"]
102
-
103
- doc = Document(page_content=text, metadata=metadata)
104
- documents.append(doc)
 
105
 
106
- # Create vector store
107
- vectorstore = Chroma.from_documents(
108
- documents=documents,
109
- embedding=embedding_function,
110
- collection_name=collection_name,
111
- persist_directory="chroma_db_multilingual"
112
- )
113
- return vectorstore, documents
114
 
 
 
115
 
116
- def create_retriever(vectorstore, docs, llm):
117
- """Create ensemble retriever with vector and BM25 search"""
118
- # Vector retriever
119
- vector_retriever = vectorstore.as_retriever(
120
- search_type="similarity",
121
- search_kwargs={"k": 4} # Reduced from 6 to 4 for efficiency
122
- )
123
-
124
- # BM25 retriever
125
- bm25_retriever = BM25Retriever.from_documents(docs)
126
- bm25_retriever.k = 2
127
-
128
- # Ensemble retriever
129
- ensemble_retriever = EnsembleRetriever(
130
- retrievers=[vector_retriever, bm25_retriever],
131
- weights=[0.5, 0.5]
132
- )
133
-
134
- # Multi-query expanding retriever (with reduced complexity for efficiency)
135
- expanding_retriever = MultiQueryRetriever.from_llm(
136
- retriever=ensemble_retriever,
137
- llm=llm
138
- )
139
-
140
- return expanding_retriever
141
-
142
-
143
- def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
144
- """Convert ChromaDB Document objects to LlamaIndex TextNode objects"""
145
- nodes = []
146
- for i, doc in enumerate(chromadb_documents):
147
- try:
148
- text = doc.page_content
149
- metadata = doc.metadata.copy()
150
- element_id = metadata.get("element_id", f"doc_{i}")
151
- source = metadata.get("source", "unknown")
152
- node_id = f"{source}_{element_id}"
153
-
154
- node = TextNode(
155
- text=text,
156
- metadata=metadata,
157
- id_=node_id
158
- )
159
- nodes.append(node)
160
- except Exception as e:
161
- continue
162
- return nodes
163
-
164
-
165
- def section_tool_wrapper(retriever, section_path_chunks, query):
166
- """Generic section tool wrapper with improved efficiency"""
167
  try:
168
- retrieved_docs = retriever.get_relevant_documents(query)
169
- nodes_from_retrieved_docs = convert_chromadb_to_llamaindex_nodes(retrieved_docs)
 
 
 
 
 
 
 
170
 
171
- if not nodes_from_retrieved_docs:
172
- return "No relevant documents found for the query."
173
 
174
- chunk_ids = [node.metadata['element_id'] for node in retrieved_docs]
175
- with open(section_path_chunks, "r", encoding="utf-8") as f:
176
- chunks_data = json.load(f)
 
 
 
 
 
 
 
177
 
178
- chunks_unique = [node for node in chunks_data if node.get('element_id', 'Unknown') in chunk_ids]
179
- combined_text = []
180
-
181
- # Limit the number of chunks to avoid overwhelming the context
182
- max_chunks = 8 # Reasonable limit
183
- for chu in chunks_unique[:max_chunks]:
184
- if "TableElement" == chu["type"]:
185
- text = f"[{chu['element_id']}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
186
- combined_text.append(text)
187
- else:
188
- for element in chu["elements"]:
189
- text = f"[{element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
190
- combined_text.append(text)
191
-
192
- result = "\n---\n".join(combined_text)
193
- print(f"Retrieved {len(nodes_from_retrieved_docs)} documents for query: {query[:50]}...")
194
- return result
195
  except Exception as e:
196
- print(f"Error in section tool: {e}")
197
- return f"Error retrieving documents: {str(e)}"
 
 
198
 
199
 
200
- def create_section_tools(embedding_function, llm):
201
- """Create all section-specific retrieval tools with improved descriptions"""
202
 
203
- # Define section paths
204
- section_paths = {
205
- 'one': 'section_one_chunks.json',
206
- 'two': 'section_two_chunks.json',
207
- 'three': 'section_three_chunks.json',
208
- 'four': 'section_four_chunks.json',
209
- 'five': 'section_five_chunks.json',
210
- 'six': 'section_six_chunks.json',
211
- 'seven': 'section_seven_chunks.json',
212
- 'eight': 'section_eight_chunks.json',
213
- 'nine': 'section_nine_chunks.json',
214
- 'ten': 'section_ten_chunks.json'
215
- }
216
 
217
- # Create retrievers for each section
218
- section_retrievers = {}
219
- for section, path in section_paths.items():
220
- if os.path.exists(f'./data/{path}'):
221
- vstore, docs = create_vectorstore_from_json(f'./data/{path}', f"Guide_2023_{section}", embedding_function)
222
- section_retrievers[section] = create_retriever(vstore, docs, llm)
223
 
224
- # Create main guide retriever
225
- guide_path = './data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
226
- if os.path.exists(guide_path):
227
- guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
228
- guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
229
- else:
230
- guide_retriever = None
231
-
232
- # Primary + Secondary Document Paths
233
- immunization_path = './data/Immunization_in_Practice_WHO_eng_2015.json'
234
-
235
- # WHO Immunization in Practice Tool
236
- if os.path.exists(immunization_path):
237
- immunization_vstore, immunization_docs = create_vectorstore_from_json(
238
- immunization_path,
239
- "Immunization_in_Practice_WHO_eng_2015",
240
- embedding_function
241
- )
242
- immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
243
- else:
244
- immunization_retriever = None
245
 
246
- # Tool Functions with Improved Efficiency Focus
247
 
248
- def guide_retrieval_tool(query: str) -> str:
249
- """
250
- **PRIMARY TOOL - USE FIRST FOR MOST QUESTIONS**
 
 
 
 
 
 
251
 
252
- Comprehensive search across the entire Algerian National Vaccination Guide (2023).
 
 
 
 
 
253
 
254
- **When to use this tool:**
255
- - General vaccination questions
256
- - Disease definitions and descriptions
257
- - Vaccine schedules and protocols
258
- - Comparative questions needing Algerian perspective
259
- - Any question about Algeria's vaccination program
 
 
 
260
 
261
- **Keywords that indicate this tool:** Algeria, Algerian, national, calendrier, vaccination, PEV, diseases (diphteria, polio, measles, etc.)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
- Args:
264
- query (str): Any vaccination-related question about Algeria's national program
265
-
266
- Returns:
267
- str: Comprehensive information from the Algerian guide with citations
268
- """
269
- if not guide_retriever:
270
- return "Guide retriever not available"
271
- return section_tool_wrapper(guide_retriever, guide_path, query)
272
-
273
- def immunization_tool(query: str) -> str:
274
- """
275
- **SECONDARY TOOL - USE FOR WHO/INTERNATIONAL PERSPECTIVE**
276
 
277
- WHO Immunization in Practice 2015 - Global best practices and international standards.
 
 
 
 
 
 
278
 
279
- **When to use this tool:**
280
- - Questions specifically asking about WHO recommendations
281
- - International/global immunization practices
282
- - Comparative questions needing WHO perspective
283
- - Technical immunization procedures and best practices
284
 
285
- **Keywords that indicate this tool:** WHO, international, global, best practices, standards
 
286
 
287
- Args:
288
- query (str): Question about international immunization practices or WHO recommendations
289
-
290
- Returns:
291
- str: WHO guidance and international best practices with citations
292
- """
293
- if not immunization_retriever:
294
- return "Immunization in Practice retriever not available"
295
- return section_tool_wrapper(immunization_retriever, immunization_path, query)
296
-
297
- # Section-Specific Tools (USE ONLY IF QUESTION IS VERY SPECIFIC TO THE SECTION)
298
-
299
- def section_two_tool(query: str) -> str:
300
- """
301
- **DISEASE-SPECIFIC TOOL**
302
 
303
- Section 2: Vaccine-preventable diseases - definitions, symptoms, transmission, complications.
 
 
 
304
 
305
- **Use ONLY for specific disease definition questions like:**
306
- - "What is diphtheria?"
307
- - "Define measles according to Algerian protocol"
308
- - "Symptoms of polio"
 
 
 
 
309
 
310
- **Keywords:** definition, symptoms, transmission, complications, disease characteristics
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
- Args:
313
- query (str): Specific question about disease definitions or characteristics
314
-
315
- Returns:
316
- str: Disease-specific medical information with citations
317
- """
318
- if 'two' not in section_retrievers:
319
- return "Section 2 retriever not available"
320
- return section_tool_wrapper(section_retrievers['two'], f'./data/{section_paths["two"]}', query)
321
-
322
- def section_three_tool(query: str) -> str:
323
- """
324
- **VACCINE-SPECIFIC TOOL**
325
 
326
- Section 3: Vaccine details - types, composition, administration methods.
 
 
 
 
 
 
327
 
328
- **Use ONLY for specific vaccine technical questions like:**
329
- - "What type of vaccine is used for diphtheria?"
330
- - "How is the MMR vaccine administered?"
331
- - "Vaccine composition and dosage"
332
 
333
- **Keywords:** vaccine type, composition, administration, dosage, technical details
 
 
 
 
 
 
 
 
 
334
 
335
- Args:
336
- query (str): Technical question about specific vaccines
337
-
338
- Returns:
339
- str: Technical vaccine information with citations
340
- """
341
- if 'three' not in section_retrievers:
342
- return "Section 3 retriever not available"
343
- return section_tool_wrapper(section_retrievers['three'], f'./data/{section_paths["three"]}', query)
344
-
345
- # Create FunctionTool objects with focused selection
346
- tools = [
347
- # Primary tools - most commonly used
348
- FunctionTool.from_defaults(
349
- name="algerian_guide_search",
350
- fn=guide_retrieval_tool,
351
- description="PRIMARY TOOL: Search the complete Algerian National Vaccination Guide for any vaccination-related question"
352
- ),
353
- FunctionTool.from_defaults(
354
- name="who_immunization_search",
355
- fn=immunization_tool,
356
- description="SECONDARY TOOL: Search WHO Immunization in Practice for international standards and WHO recommendations"
357
- ),
358
- # Specialized tools - use only when very specific
359
- FunctionTool.from_defaults(
360
- name="disease_definitions_search",
361
- fn=section_two_tool,
362
- description="SPECIALIZED: Search for specific disease definitions, symptoms, and characteristics"
363
- ),
364
- FunctionTool.from_defaults(
365
- name="vaccine_technical_search",
366
- fn=section_three_tool,
367
- description="SPECIALIZED: Search for technical vaccine details, composition, and administration methods"
368
- ),
369
- ]
370
-
371
- return tools
372
 
373
 
374
- def prepare_environment():
375
- """Main function to prepare the environment and return tools"""
376
- print("Setting up models...")
377
- embedding_function, llm = setup_models()
378
-
379
- print("Creating section tools...")
380
- tools = create_section_tools(embedding_function, llm)
381
-
382
- print("Environment prepared successfully!")
383
- return tools, llm
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Enhanced RAG Pipeline for vaccine assistant - Fixed version with max iterations control
4
+ Handles agent creation and question answering with sequential citation numbering
5
  """
6
 
 
7
  import json
8
  import re
9
+ from llama_index.core import PromptTemplate
10
+ from llama_index.core.agent import ReActAgent
11
+ from llama_index.llms.google_genai import GoogleGenAI
12
+ from langdetect import detect
13
+ import os
 
 
 
 
 
14
 
15
 
16
  def extract_source_ids(response_text):
 
47
  ids = [id_str.strip() for id_str in citation.split(',')]
48
  all_ids.extend(ids)
49
 
50
+ # Get unique source IDs while preserving order
51
+ seen = set()
52
+ source_ids = []
53
+ for id_str in all_ids:
54
+ if id_str not in seen:
55
+ seen.add(id_str)
56
+ source_ids.append(id_str)
57
 
58
  if not source_ids:
59
  print("Warning: No valid source IDs found after filtering.")
 
62
  return source_ids
63
 
64
 
65
+ def convert_citations_to_sequential(response_text, source_id_to_number_map):
66
+ """
67
+ Convert source IDs in response text to sequential numbers.
 
 
 
68
 
69
+ Args:
70
+ response_text (str): The response text with source ID citations
71
+ source_id_to_number_map (dict): Mapping from source IDs to sequential numbers
72
+
73
+ Returns:
74
+ str: Response text with sequential number citations
75
+ """
76
+ def replace_citation(match):
77
+ citation_content = match.group(1)
78
+ # Handle multiple IDs in one citation (comma-separated)
79
+ ids = [id_str.strip() for id_str in citation_content.split(',')]
80
+
81
+ # Convert each ID to its sequential number
82
+ numbers = []
83
+ for id_str in ids:
84
+ if id_str in source_id_to_number_map:
85
+ numbers.append(str(source_id_to_number_map[id_str]))
86
+
87
+ # Return the formatted citation with sequential numbers
88
+ if len(numbers) == 1:
89
+ return f"[{numbers[0]}]"
90
+ elif len(numbers) > 1:
91
+ return f"[{','.join(numbers)}]"
92
+ else:
93
+ return match.group(0) # Return original if no mapping found
94
 
95
+ # Replace all citations in the text
96
+ sequential_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, response_text)
97
+ return sequential_response
98
 
99
 
100
+ def create_safe_custom_prompt(tools, llm):
101
+ """Create a safe version that won't have formatting conflicts"""
 
 
 
102
 
103
+ custom_instructions = """
104
+ ## MEDICAL ASSISTANT ROLE
105
+ You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
106
+ You provide evidence-based guidance using only information from official vaccine medical documents.
107
+ Answer the doctor's question accurately and concisely using only the provided information.
108
+
109
+ ## CRITICAL RULES FOR EFFICIENCY
110
+
111
+ ### Tool Usage Strategy
112
+ 1. **MAXIMUM 3 TOOL CALLS**: You must provide a complete answer within 3 tool calls maximum.
113
+ 2. **Smart Tool Selection**: Choose the most relevant tool first based on the question topic.
114
+ 3. **Comparative Questions**: For questions comparing documents/protocols:
115
+ - First tool call: Get information from primary source (e.g., Algerian guide)
116
+ - Second tool call: Get information from secondary source (e.g., WHO document)
117
+ - Third tool call: Only if absolutely necessary for missing details
118
+ 4. **Stop Early**: If you have sufficient information after 1-2 tool calls, provide your answer immediately.
119
+
120
+ ### Citation and Sourcing
121
+ 1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
122
+ 2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
123
+ 3. If a fact is supported by multiple sources, use adjacent citations: [source1][source2]
124
+ 4. Use ONLY the provided information and never include facts from your general knowledge.
125
+
126
+ ### Content Formatting
127
+ 1. When rendering tables:
128
+ - Convert HTML tables into clean Markdown format
129
+ - Preserve all original headers and data rows exactly
130
+ - Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
131
+ 2. For lists, maintain the original bullet points/numbering and include citations.
132
+ 3. Present information concisely but ensure clinical accuracy is never compromised.
133
+
134
+ ### Answer Completeness Guidelines
135
+ - If you find relevant information from 1-2 sources, synthesize and provide a complete answer
136
+ - Don't keep searching for more sources unless critical information is missing
137
+ - For comparative questions, clearly structure your answer with sections for each source
138
+ - If information is not available in the documents, clearly state this limitation
139
+
140
+ ---
141
 
142
+ """
143
+
144
+ # Get the exact original template first
145
+ temp_agent = ReActAgent.from_tools(tools, llm=llm, verbose=False)
146
+ original_prompts = temp_agent.get_prompts()
147
+ original_template = original_prompts["agent_worker:system_prompt"].template
148
 
149
+ # Add instructions at the very beginning
150
+ safe_template = f"{custom_instructions}{original_template}"
 
 
 
 
 
 
151
 
152
+ # Create new prompt with same metadata as original
153
+ original_prompt = original_prompts["agent_worker:system_prompt"]
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  try:
156
+ new_prompt = PromptTemplate(
157
+ template=safe_template,
158
+ template_vars=original_prompt.template_vars,
159
+ metadata=original_prompt.metadata if hasattr(original_prompt, 'metadata') else None
160
+ )
161
+ return new_prompt
162
+ except:
163
+ # Even safer fallback
164
+ return PromptTemplate(template=safe_template)
165
 
 
 
166
 
167
+ def create_agent(tools, llm):
168
+ """Create the ReAct agent with custom prompt and controlled max iterations"""
169
+
170
+ # Create agent with controlled max iterations (reduced from default 10 to 5)
171
+ agent = ReActAgent.from_tools(
172
+ tools,
173
+ llm=llm,
174
+ verbose=True,
175
+ max_iterations=5, # Reduced max iterations
176
+ )
177
 
178
+ # Create and apply safe custom prompt
179
+ try:
180
+ safe_custom_prompt = create_safe_custom_prompt(tools, llm)
181
+ agent.update_prompts({"agent_worker:system_prompt": safe_custom_prompt})
182
+ print("✅ Successfully updated with safe custom prompt and max_iterations=5")
 
 
 
 
 
 
 
 
 
 
 
 
183
  except Exception as e:
184
+ print(f" Safe prompt update failed: {e}")
185
+ print("⚠️ Using original agent without modifications")
186
+
187
+ return agent
188
 
189
 
190
+ def initialize_rag_pipeline(tools):
191
+ """Initialize the RAG pipeline with tools"""
192
 
193
+ # Initialize LlamaIndex LLM with specific parameters to improve efficiency
194
+ llama_index_llm = GoogleGenAI(
195
+ model="models/gemini-2.0-flash",
196
+ api_key=os.getenv('GOOGLE_API_KEY'),
197
+ temperature=0.1, # Lower temperature for more focused responses
198
+ )
 
 
 
 
 
 
 
199
 
200
+ # Create agent
201
+ agent = create_agent(tools, llama_index_llm)
 
 
 
 
202
 
203
+ return agent
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
 
205
 
206
+ def process_question(agent, question: str) -> str:
207
+ """Process a question through the RAG pipeline with timeout handling"""
208
+ try:
209
+ # Add timeout/retry logic
210
+ response = agent.chat(question)
211
+ return response.response
212
+ except Exception as e:
213
+ error_msg = str(e)
214
+ print(f"Error processing question: {error_msg}")
215
 
216
+ # Handle specific "max iterations" error
217
+ if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
218
+ return ("I apologize, but I was unable to find a complete answer within the allowed search attempts. "
219
+ "This might be because the specific comparison you're asking about requires information "
220
+ "that spans multiple sections of the documents. Could you please rephrase your question "
221
+ "to be more specific about which aspect of the difference you're most interested in?")
222
 
223
+ return f"Error processing your question: {error_msg}"
224
+
225
+
226
+ def aswer_language_detection(response_text: str) -> str:
227
+ """
228
+ Detect the language of the response text.
229
+
230
+ Args:
231
+ response_text (str): The response text to analyze.
232
 
233
+ Returns:
234
+ str: Detected language code (e.g., 'en', 'fr', etc.)
235
+ """
236
+ try:
237
+ # Detect the language of the first 5 words of the response
238
+ first_line = " ".join(response_text.split()[:5])
239
+ first_line = re.sub(r'\[.*?\]', '', first_line) # Remove citations
240
+ answer_language = detect(first_line)
241
+ if answer_language not in ['en', 'ar', 'fr']:
242
+ answer_language = 'en'
243
+ except:
244
+ answer_language = 'en'
245
+
246
+ return answer_language
247
+
248
+
249
+ def process_question_with_sequential_citations(agent, question: str, chunks_directory="./data/") -> dict:
250
+ """
251
+ Process a question through the RAG pipeline and return response with sequential citation numbers.
252
+ Enhanced with better error handling for max iterations.
253
+
254
+ Args:
255
+ agent: The initialized RAG agent
256
+ question (str): The user's question
257
+ chunks_directory (str): Path to the directory containing JSON files
258
 
259
+ Returns:
260
+ dict: {
261
+ "response": str, # Response with sequential citation numbers [1], [2], etc.
262
+ "cited_elements_json": str, # JSON array of cited elements in order
263
+ "unique_ids": list, # Original source IDs in order
264
+ "citation_mapping": dict # Mapping from source ID to citation number
265
+ }
266
+ """
267
+ try:
268
+ # Get the response from the agent with improved error handling
269
+ response = agent.chat(question)
270
+ response_text = response.response
 
271
 
272
+ # Check if the response indicates max iterations was reached
273
+ if "max iterations" in response_text.lower() or len(response_text.strip()) == 0:
274
+ # Provide a more helpful fallback response
275
+ response_text = ("I apologize, but I encountered difficulties processing your comparative question "
276
+ "within the allowed search attempts. For questions comparing different protocols "
277
+ "or documents, please try asking about each aspect separately. For example, "
278
+ "first ask about the Algerian definition of Diphtheria, then ask about the WHO definition.")
279
 
280
+ # Extract source IDs from the response (preserving order)
281
+ unique_ids = extract_source_ids(response_text)
 
 
 
282
 
283
+ # Create mapping from source ID to sequential number
284
+ source_id_to_number = {source_id: i + 1 for i, source_id in enumerate(unique_ids)}
285
 
286
+ # Convert citations to sequential numbers
287
+ sequential_response = convert_citations_to_sequential(response_text, source_id_to_number)
 
 
 
 
 
 
 
 
 
 
 
 
 
288
 
289
+ # Load all chunks data to find cited elements
290
+ all_chunks_data = []
291
+ min_chunks_files = ["Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json",
292
+ "Immunization_in_Practice_WHO_eng_2015.json"]
293
 
294
+ for json_file in min_chunks_files:
295
+ json_path = os.path.join(chunks_directory, json_file)
296
+ try:
297
+ with open(json_path, "r", encoding="utf-8") as f:
298
+ chunks_data = json.load(f)
299
+ all_chunks_data.extend(chunks_data)
300
+ except Exception as e:
301
+ print(f"Warning: Could not load {json_file}: {e}")
302
 
303
+ # Get cited elements in the same order as the sequential citations
304
+ cited_elements_ordered = []
305
+ for source_id in unique_ids: # This preserves the order
306
+ for element in all_chunks_data:
307
+ if element.get("type") == 'TableElement':
308
+ if element.get("element_id") == source_id:
309
+ cited_elements_ordered.append(element)
310
+ break
311
+ else:
312
+ if "elements" in element:
313
+ for nested_element in element["elements"]:
314
+ if nested_element.get("element_id") == source_id:
315
+ cited_elements_ordered.append(nested_element)
316
+ break
317
+ else:
318
+ continue
319
+ break
320
 
321
+ # Convert to JSON
322
+ cited_elements_json = json.dumps(cited_elements_ordered, ensure_ascii=False, indent=2)
323
+ answer_language = aswer_language_detection(response_text)
 
 
 
 
 
 
 
 
 
 
324
 
325
+ return {
326
+ "response": sequential_response,
327
+ "cited_elements_json": cited_elements_json,
328
+ "unique_ids": unique_ids,
329
+ "citation_mapping": source_id_to_number,
330
+ "answer_language": answer_language
331
+ }
332
 
333
+ except Exception as e:
334
+ error_msg = str(e)
335
+ print(f"Error processing question: {error_msg}")
 
336
 
337
+ # Create appropriate fallback response based on error type
338
+ if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
339
+ fallback_response = ("I apologize, but I was unable to complete the comparison within the allowed search attempts. "
340
+ "For complex comparative questions like yours about the differences between Algerian and WHO "
341
+ "definitions of Diphtheria, please try asking about each source separately: \n\n"
342
+ "1. First ask: 'What is the definition of Diphtheria in the Algerian vaccination guide?'\n"
343
+ "2. Then ask: 'What is the definition of Diphtheria in the WHO document?'\n\n"
344
+ "This will help me provide you with more focused and complete information.")
345
+ else:
346
+ fallback_response = f"I encountered an error while processing your question: {error_msg}"
347
 
348
+ return {
349
+ "response": fallback_response,
350
+ "cited_elements_json": "[]",
351
+ "unique_ids": [],
352
+ "citation_mapping": {},
353
+ "answer_language": "en"
354
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355
 
356
 
357
+ def process_question_with_citations(agent, question: str, chunks_directory="./data/") -> dict:
358
+ """
359
+ Legacy function - maintained for backward compatibility.
360
+ Now calls the new sequential citation function.
361
+ """
362
+ return process_question_with_sequential_citations(agent, question, chunks_directory)