Spaces:
Sleeping
Sleeping
Zeggai Abdellah
commited on
Commit
·
f5c821c
1
Parent(s):
938709d
fix handleing a complicated queations
Browse files- prepare_env.py +288 -405
- rag_pipeline.py +298 -308
prepare_env.py
CHANGED
|
@@ -1,22 +1,16 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
import os
|
| 8 |
import json
|
| 9 |
import re
|
| 10 |
-
import
|
| 11 |
-
from
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
-
|
| 15 |
-
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
| 16 |
-
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 17 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 18 |
-
from llama_index.core.tools import FunctionTool
|
| 19 |
-
from llama_index.core.schema import TextNode
|
| 20 |
|
| 21 |
|
| 22 |
def extract_source_ids(response_text):
|
|
@@ -53,13 +47,13 @@ def extract_source_ids(response_text):
|
|
| 53 |
ids = [id_str.strip() for id_str in citation.split(',')]
|
| 54 |
all_ids.extend(ids)
|
| 55 |
|
| 56 |
-
# Get unique source IDs
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
|
| 64 |
if not source_ids:
|
| 65 |
print("Warning: No valid source IDs found after filtering.")
|
|
@@ -68,412 +62,301 @@ def extract_source_ids(response_text):
|
|
| 68 |
return source_ids
|
| 69 |
|
| 70 |
|
| 71 |
-
def
|
| 72 |
-
"""
|
| 73 |
-
|
| 74 |
-
embedding_function = HuggingFaceEmbeddings(
|
| 75 |
-
model_name="intfloat/multilingual-e5-base"
|
| 76 |
-
)
|
| 77 |
-
|
| 78 |
-
# Initialize LLM
|
| 79 |
-
genai_api_key = os.getenv('GOOGLE_API_KEY')
|
| 80 |
-
llm = ChatGoogleGenerativeAI(
|
| 81 |
-
model="gemini-2.0-flash",
|
| 82 |
-
google_api_key=genai_api_key
|
| 83 |
-
)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
metadata = {
|
| 97 |
-
"language": "fra",
|
| 98 |
-
"source": element["filename"],
|
| 99 |
-
"filetype": element["filetype"],
|
| 100 |
-
"element_id": element["element_id"]
|
| 101 |
-
}
|
| 102 |
-
|
| 103 |
-
if "TableElement" == element["type"]:
|
| 104 |
-
metadata["table_text_as_html"] = element["table_text_as_html"]
|
| 105 |
|
| 106 |
-
doc = Document(page_content=text, metadata=metadata)
|
| 107 |
-
documents.append(doc)
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
documents=documents,
|
| 112 |
-
embedding=embedding_function,
|
| 113 |
-
collection_name=collection_name,
|
| 114 |
-
persist_directory="chroma_db_multilingual"
|
| 115 |
-
)
|
| 116 |
-
return vectorstore, documents
|
| 117 |
-
|
| 118 |
-
def create_retriever(vectorstore, docs, llm):
|
| 119 |
-
"""Create ensemble retriever with vector and BM25 search"""
|
| 120 |
-
# Vector retriever
|
| 121 |
-
vector_retriever = vectorstore.as_retriever(
|
| 122 |
-
search_type="similarity",
|
| 123 |
-
search_kwargs={"k": 6}
|
| 124 |
-
)
|
| 125 |
-
|
| 126 |
-
# BM25 retriever
|
| 127 |
-
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 128 |
-
bm25_retriever.k = 2
|
| 129 |
-
|
| 130 |
-
# Ensemble retriever
|
| 131 |
-
ensemble_retriever = EnsembleRetriever(
|
| 132 |
-
retrievers=[vector_retriever, bm25_retriever],
|
| 133 |
-
weights=[0.5, 0.5]
|
| 134 |
-
)
|
| 135 |
-
|
| 136 |
-
# Multi-query expanding retriever
|
| 137 |
-
expanding_retriever = MultiQueryRetriever.from_llm(
|
| 138 |
-
retriever=ensemble_retriever,
|
| 139 |
-
llm=llm
|
| 140 |
-
)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
-
|
| 172 |
-
return "No relevant documents found for the query."
|
| 173 |
-
|
| 174 |
-
chunk_ids = [node.metadata['element_id'] for node in retrieved_docs]
|
| 175 |
-
with open(section_path_chunks, "r", encoding="utf-8") as f:
|
| 176 |
-
chunks_data = json.load(f)
|
| 177 |
-
|
| 178 |
-
chunks_unique = [node for node in chunks_data if node.get('element_id', 'Unknown') in chunk_ids]
|
| 179 |
-
combined_text = []
|
| 180 |
-
|
| 181 |
-
for chu in chunks_unique:
|
| 182 |
-
if "TableElement" == chu["type"]:
|
| 183 |
-
text = f"[Source: {chu['element_id']}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
|
| 184 |
-
combined_text.append(text)
|
| 185 |
-
else:
|
| 186 |
-
for element in chu["elements"]:
|
| 187 |
-
text = f"[Source: {element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
|
| 188 |
-
combined_text.append(text)
|
| 189 |
-
|
| 190 |
-
result = "\n---\n".join(combined_text)
|
| 191 |
-
print(f"Retrieved {len(nodes_from_retrieved_docs)} documents for query: {query[:50]}...")
|
| 192 |
-
return result
|
| 193 |
-
except Exception as e:
|
| 194 |
-
print(f"Error in section tool: {e}")
|
| 195 |
-
return f"Error retrieving documents: {str(e)}"
|
| 196 |
-
|
| 197 |
-
def create_section_tools(embedding_function, llm):
|
| 198 |
-
"""Create all section-specific retrieval tools"""
|
| 199 |
-
|
| 200 |
-
# Define section paths
|
| 201 |
-
section_paths = {
|
| 202 |
-
'one': 'section_one_chunks.json',
|
| 203 |
-
'two': 'section_two_chunks.json',
|
| 204 |
-
'three': 'section_three_chunks.json',
|
| 205 |
-
'four': 'section_four_chunks.json',
|
| 206 |
-
'five': 'section_five_chunks.json',
|
| 207 |
-
'six': 'section_six_chunks.json',
|
| 208 |
-
'seven': 'section_seven_chunks.json',
|
| 209 |
-
'eight': 'section_eight_chunks.json',
|
| 210 |
-
'nine': 'section_nine_chunks.json',
|
| 211 |
-
'ten': 'section_ten_chunks.json'
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
# Create retrievers for each section
|
| 215 |
-
section_retrievers = {}
|
| 216 |
-
for section, path in section_paths.items():
|
| 217 |
-
if os.path.exists(path):
|
| 218 |
-
vstore, docs = create_vectorstore_from_json(f'./data/{path}', f"Guide_2023_{section}", embedding_function)
|
| 219 |
-
section_retrievers[section] = create_retriever(vstore, docs, llm)
|
| 220 |
|
| 221 |
-
#
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
|
| 226 |
-
else:
|
| 227 |
-
guide_retriever = None
|
| 228 |
-
# General-purpose tool (entire Algerian guide)
|
| 229 |
-
def guide_retrieval_tool(query: str) -> str:
|
| 230 |
-
"""
|
| 231 |
-
General-purpose retrieval tool for the entire Algerian National Vaccination Guide (2023).
|
| 232 |
-
|
| 233 |
-
Use this tool when a query spans multiple sections or cannot be routed confidently to a specific tool.
|
| 234 |
-
This is the fallback and all-encompassing tool to retrieve any vaccination-related information
|
| 235 |
-
from the national guide.
|
| 236 |
-
|
| 237 |
-
Secondary source: The WHO Immunization Guide can be queried separately via `immunization_tool`.
|
| 238 |
-
|
| 239 |
-
Args:
|
| 240 |
-
query (str): A general or complex question related to vaccination policy, schedules, or practice.
|
| 241 |
-
|
| 242 |
-
Returns:
|
| 243 |
-
str: Synthesized response based on the full Algerian guide.
|
| 244 |
-
"""
|
| 245 |
-
if not guide_retriever:
|
| 246 |
-
return "Guide retriever not available"
|
| 247 |
-
return section_tool_wrapper(guide_retriever, guide_path, query)
|
| 248 |
-
|
| 249 |
-
# Primary + Secondary Document Paths
|
| 250 |
-
immunization_path = './data/Immunization_in_Practice_WHO_eng_2015.json'
|
| 251 |
-
|
| 252 |
-
# WHO Immunization in Practice Tool
|
| 253 |
-
if os.path.exists(immunization_path):
|
| 254 |
-
immunization_vstore, immunization_docs = create_vectorstore_from_json(
|
| 255 |
-
immunization_path,
|
| 256 |
-
"Immunization_in_Practice_WHO_eng_2015",
|
| 257 |
-
embedding_function
|
| 258 |
-
)
|
| 259 |
-
immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
|
| 260 |
-
else:
|
| 261 |
-
immunization_retriever = None
|
| 262 |
-
|
| 263 |
-
def immunization_tool(query: str) -> str:
|
| 264 |
-
"""
|
| 265 |
-
WHO Immunization in Practice 2015 retrieval tool.
|
| 266 |
-
|
| 267 |
-
Use this tool to provide global best practices and operational guidance on immunization,
|
| 268 |
-
especially when context or clarification is needed beyond the Algerian national guide.
|
| 269 |
-
This can serve as a secondary source for training, logistics, and procedural reference.
|
| 270 |
-
|
| 271 |
-
Args:
|
| 272 |
-
query (str): A question related to immunization practice in general.
|
| 273 |
-
|
| 274 |
-
Returns:
|
| 275 |
-
str: Retrieved guidance from the WHO Immunization in Practice manual (2015).
|
| 276 |
-
"""
|
| 277 |
-
if not immunization_retriever:
|
| 278 |
-
return "Immunization in Practice retriever not available"
|
| 279 |
-
return section_tool_wrapper(immunization_retriever, immunization_path, query)
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
# Section-Specific Tools (Primary: Algerian National Vaccination Guide)
|
| 284 |
-
|
| 285 |
-
def section_one_tool(query: str) -> str:
|
| 286 |
-
"""
|
| 287 |
-
Section 1: Programme Élargi de Vaccination (PEV)
|
| 288 |
-
|
| 289 |
-
Use this tool to retrieve information about the Algerian immunization program:
|
| 290 |
-
its objectives, historical background, strengths and weaknesses, and justification
|
| 291 |
-
for calendar updates.
|
| 292 |
-
|
| 293 |
-
Primary source: Algerian National Vaccination Guide, Section 1.
|
| 294 |
-
Secondary source for operational benchmarks: WHO Immunization in Practice (optional).
|
| 295 |
-
|
| 296 |
-
Args:
|
| 297 |
-
query (str): A question about Algeria’s national immunization strategy.
|
| 298 |
-
|
| 299 |
-
Returns:
|
| 300 |
-
str: Relevant content from Section 1 of the guide.
|
| 301 |
-
"""
|
| 302 |
-
return section_tool_wrapper(section_retrievers['one'], section_paths['one'], query)
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
def section_two_tool(query: str) -> str:
|
| 306 |
-
"""
|
| 307 |
-
Section 2: Maladies Ciblées par la Vaccination
|
| 308 |
-
|
| 309 |
-
Use this tool for questions about the diseases targeted by the national vaccination calendar:
|
| 310 |
-
symptoms, transmission, complications, and prevention strategies.
|
| 311 |
-
|
| 312 |
-
Primary source: Algerian National Guide, Section 2.
|
| 313 |
-
Secondary source: WHO guide may support contextual insights.
|
| 314 |
-
|
| 315 |
-
Args:
|
| 316 |
-
query (str): A question about a vaccine-preventable disease (e.g. polio, rougeole).
|
| 317 |
-
|
| 318 |
-
Returns:
|
| 319 |
-
str: Disease-specific guidance from Section 2.
|
| 320 |
-
"""
|
| 321 |
-
return section_tool_wrapper(section_retrievers['two'], section_paths['two'], query)
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
def section_three_tool(query: str) -> str:
|
| 325 |
-
"""
|
| 326 |
-
Section 3: Vaccins du Calendrier
|
| 327 |
-
|
| 328 |
-
Use this tool to retrieve technical and procedural information about the vaccines used in the calendar:
|
| 329 |
-
names, contents, administration method, and dosing details.
|
| 330 |
-
|
| 331 |
-
Args:
|
| 332 |
-
query (str): A question about a specific vaccine's type or method of use.
|
| 333 |
-
|
| 334 |
-
Returns:
|
| 335 |
-
str: Vaccine information from Section 3.
|
| 336 |
-
"""
|
| 337 |
-
return section_tool_wrapper(section_retrievers['three'], section_paths['three'], query)
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
def section_four_tool(query: str) -> str:
|
| 341 |
-
"""
|
| 342 |
-
Section 4: Rattrapage Vaccinal
|
| 343 |
-
|
| 344 |
-
Use this tool to determine catch-up strategies for children who missed or delayed one or more doses.
|
| 345 |
-
It provides age-adjusted rescheduling rules and justifications.
|
| 346 |
-
|
| 347 |
-
Args:
|
| 348 |
-
query (str): A question about how to manage missed vaccinations.
|
| 349 |
-
|
| 350 |
-
Returns:
|
| 351 |
-
str: Catch-up guidelines from Section 4.
|
| 352 |
-
"""
|
| 353 |
-
return section_tool_wrapper(section_retrievers['four'], section_paths['four'], query)
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
def section_five_tool(query: str) -> str:
|
| 357 |
-
"""
|
| 358 |
-
Section 5: Vaccination des Populations Particulières
|
| 359 |
|
| 360 |
-
|
| 361 |
-
|
| 362 |
|
| 363 |
-
|
| 364 |
-
|
| 365 |
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
Use this tool for logistics, storage conditions, temperature monitoring,
|
| 377 |
-
and emergency procedures in case of cold chain failure.
|
| 378 |
-
|
| 379 |
-
Args:
|
| 380 |
-
query (str): A question about how vaccines should be stored and transported.
|
| 381 |
-
|
| 382 |
-
Returns:
|
| 383 |
-
str: Operational cold chain standards from Section 6.
|
| 384 |
-
"""
|
| 385 |
-
return section_tool_wrapper(section_retrievers['six'], section_paths['six'], query)
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
def section_seven_tool(query: str) -> str:
|
| 389 |
-
"""
|
| 390 |
-
Section 7: Sécurité des Injections
|
| 391 |
-
|
| 392 |
-
Use this tool to ensure injection safety: handling equipment, preventing needle-stick injuries,
|
| 393 |
-
and disposing of biomedical waste.
|
| 394 |
-
|
| 395 |
-
Args:
|
| 396 |
-
query (str): A question about safe injection practices.
|
| 397 |
-
|
| 398 |
-
Returns:
|
| 399 |
-
str: Procedures and guidelines from Section 7.
|
| 400 |
-
"""
|
| 401 |
-
return section_tool_wrapper(section_retrievers['seven'], section_paths['seven'], query)
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
def section_eight_tool(query: str) -> str:
|
| 405 |
-
"""
|
| 406 |
-
Section 8: Tenue d'une Séance de Vaccination & Vaccinovigilance
|
| 407 |
-
|
| 408 |
-
Use this tool to plan and monitor vaccination sessions, including material preparation,
|
| 409 |
-
injection recording, and handling of adverse events post-immunization (AEFI).
|
| 410 |
-
|
| 411 |
-
Args:
|
| 412 |
-
query (str): A question about session operations or vaccine side effect monitoring.
|
| 413 |
-
|
| 414 |
-
Returns:
|
| 415 |
-
str: Guidelines from Section 8.
|
| 416 |
-
"""
|
| 417 |
-
return section_tool_wrapper(section_retrievers['eight'], section_paths['eight'], query)
|
| 418 |
|
| 419 |
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 423 |
|
| 424 |
-
|
| 425 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 426 |
|
| 427 |
-
|
| 428 |
-
query (str): A question about planning and organizing vaccination sessions.
|
| 429 |
|
| 430 |
-
Returns:
|
| 431 |
-
str: Recommendations from Section 9.
|
| 432 |
-
"""
|
| 433 |
-
return section_tool_wrapper(section_retrievers['nine'], section_paths['nine'], query)
|
| 434 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
|
| 436 |
-
def section_ten_tool(query: str) -> str:
|
| 437 |
-
"""
|
| 438 |
-
Section 10: Mobilisation Sociale
|
| 439 |
|
| 440 |
-
|
| 441 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 442 |
|
| 443 |
-
Args:
|
| 444 |
-
query (str): A question about public communication and trust-building around vaccines.
|
| 445 |
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
return section_tool_wrapper(section_retrievers['ten'], section_paths['ten'], query)
|
| 450 |
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
-
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
]
|
| 467 |
|
| 468 |
-
return
|
| 469 |
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
print("Creating section tools...")
|
| 476 |
-
tools = create_section_tools(embedding_function, llm)
|
| 477 |
|
| 478 |
-
|
| 479 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
+
Enhanced RAG Pipeline for vaccine assistant - Fixed version with max iterations control
|
| 4 |
+
Handles agent creation and question answering with sequential citation numbering
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import json
|
| 8 |
import re
|
| 9 |
+
from llama_index.core import PromptTemplate
|
| 10 |
+
from llama_index.core.agent import ReActAgent
|
| 11 |
+
from llama_index.llms.google_genai import GoogleGenAI
|
| 12 |
+
from langdetect import detect
|
| 13 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def extract_source_ids(response_text):
|
|
|
|
| 47 |
ids = [id_str.strip() for id_str in citation.split(',')]
|
| 48 |
all_ids.extend(ids)
|
| 49 |
|
| 50 |
+
# Get unique source IDs while preserving order
|
| 51 |
+
seen = set()
|
| 52 |
+
source_ids = []
|
| 53 |
+
for id_str in all_ids:
|
| 54 |
+
if id_str not in seen:
|
| 55 |
+
seen.add(id_str)
|
| 56 |
+
source_ids.append(id_str)
|
| 57 |
|
| 58 |
if not source_ids:
|
| 59 |
print("Warning: No valid source IDs found after filtering.")
|
|
|
|
| 62 |
return source_ids
|
| 63 |
|
| 64 |
|
| 65 |
+
def convert_citations_to_sequential(response_text, source_id_to_number_map):
|
| 66 |
+
"""
|
| 67 |
+
Convert source IDs in response text to sequential numbers.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
+
Args:
|
| 70 |
+
response_text (str): The response text with source ID citations
|
| 71 |
+
source_id_to_number_map (dict): Mapping from source IDs to sequential numbers
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
str: Response text with sequential number citations
|
| 75 |
+
"""
|
| 76 |
+
def replace_citation(match):
|
| 77 |
+
citation_content = match.group(1)
|
| 78 |
+
# Handle multiple IDs in one citation (comma-separated)
|
| 79 |
+
ids = [id_str.strip() for id_str in citation_content.split(',')]
|
| 80 |
+
|
| 81 |
+
# Convert each ID to its sequential number
|
| 82 |
+
numbers = []
|
| 83 |
+
for id_str in ids:
|
| 84 |
+
if id_str in source_id_to_number_map:
|
| 85 |
+
numbers.append(str(source_id_to_number_map[id_str]))
|
| 86 |
+
|
| 87 |
+
# Return the formatted citation with sequential numbers
|
| 88 |
+
if len(numbers) == 1:
|
| 89 |
+
return f"[{numbers[0]}]"
|
| 90 |
+
elif len(numbers) > 1:
|
| 91 |
+
return f"[{','.join(numbers)}]"
|
| 92 |
+
else:
|
| 93 |
+
return match.group(0) # Return original if no mapping found
|
| 94 |
|
| 95 |
+
# Replace all citations in the text
|
| 96 |
+
sequential_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, response_text)
|
| 97 |
+
return sequential_response
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
def create_safe_custom_prompt(tools, llm):
|
| 101 |
+
"""Create a safe version that won't have formatting conflicts"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
+
custom_instructions = """
|
| 104 |
+
## MEDICAL ASSISTANT ROLE
|
| 105 |
+
You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
|
| 106 |
+
You provide evidence-based guidance using only information from official vaccine medical documents.
|
| 107 |
+
Answer the doctor's question accurately and concisely using only the provided information.
|
| 108 |
+
|
| 109 |
+
## CRITICAL RULES FOR EFFICIENCY
|
| 110 |
+
|
| 111 |
+
### Tool Usage Strategy
|
| 112 |
+
1. **MAXIMUM 3 TOOL CALLS**: You must provide a complete answer within 3 tool calls maximum.
|
| 113 |
+
2. **Smart Tool Selection**: Choose the most relevant tool first based on the question topic.
|
| 114 |
+
3. **Comparative Questions**: For questions comparing documents/protocols:
|
| 115 |
+
- First tool call: Get information from primary source (e.g., Algerian guide)
|
| 116 |
+
- Second tool call: Get information from secondary source (e.g., WHO document)
|
| 117 |
+
- Third tool call: Only if absolutely necessary for missing details
|
| 118 |
+
4. **Stop Early**: If you have sufficient information after 1-2 tool calls, provide your answer immediately.
|
| 119 |
+
|
| 120 |
+
### Citation and Sourcing
|
| 121 |
+
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
|
| 122 |
+
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 123 |
+
3. If a fact is supported by multiple sources, use adjacent citations: [source1][source2]
|
| 124 |
+
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 125 |
+
|
| 126 |
+
### Content Formatting
|
| 127 |
+
1. When rendering tables:
|
| 128 |
+
- Convert HTML tables into clean Markdown format
|
| 129 |
+
- Preserve all original headers and data rows exactly
|
| 130 |
+
- Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
|
| 131 |
+
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 132 |
+
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 133 |
+
|
| 134 |
+
### Answer Completeness Guidelines
|
| 135 |
+
- If you find relevant information from 1-2 sources, synthesize and provide a complete answer
|
| 136 |
+
- Don't keep searching for more sources unless critical information is missing
|
| 137 |
+
- For comparative questions, clearly structure your answer with sections for each source
|
| 138 |
+
- If information is not available in the documents, clearly state this limitation
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
|
| 142 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
+
# Get the exact original template first
|
| 145 |
+
temp_agent = ReActAgent.from_tools(tools, llm=llm, verbose=False)
|
| 146 |
+
original_prompts = temp_agent.get_prompts()
|
| 147 |
+
original_template = original_prompts["agent_worker:system_prompt"].template
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
# Add instructions at the very beginning
|
| 150 |
+
safe_template = f"{custom_instructions}{original_template}"
|
| 151 |
|
| 152 |
+
# Create new prompt with same metadata as original
|
| 153 |
+
original_prompt = original_prompts["agent_worker:system_prompt"]
|
| 154 |
|
| 155 |
+
try:
|
| 156 |
+
new_prompt = PromptTemplate(
|
| 157 |
+
template=safe_template,
|
| 158 |
+
template_vars=original_prompt.template_vars,
|
| 159 |
+
metadata=original_prompt.metadata if hasattr(original_prompt, 'metadata') else None
|
| 160 |
+
)
|
| 161 |
+
return new_prompt
|
| 162 |
+
except:
|
| 163 |
+
# Even safer fallback
|
| 164 |
+
return PromptTemplate(template=safe_template)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
|
| 167 |
+
def create_agent(tools, llm):
|
| 168 |
+
"""Create the ReAct agent with custom prompt and controlled max iterations"""
|
| 169 |
+
|
| 170 |
+
# Create agent with controlled max iterations (reduced from default 10 to 5)
|
| 171 |
+
agent = ReActAgent.from_tools(
|
| 172 |
+
tools,
|
| 173 |
+
llm=llm,
|
| 174 |
+
verbose=True,
|
| 175 |
+
max_iterations=5, # Reduced max iterations
|
| 176 |
+
)
|
| 177 |
|
| 178 |
+
# Create and apply safe custom prompt
|
| 179 |
+
try:
|
| 180 |
+
safe_custom_prompt = create_safe_custom_prompt(tools, llm)
|
| 181 |
+
agent.update_prompts({"agent_worker:system_prompt": safe_custom_prompt})
|
| 182 |
+
print("✅ Successfully updated with safe custom prompt and max_iterations=5")
|
| 183 |
+
except Exception as e:
|
| 184 |
+
print(f"❌ Safe prompt update failed: {e}")
|
| 185 |
+
print("⚠️ Using original agent without modifications")
|
| 186 |
|
| 187 |
+
return agent
|
|
|
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
|
| 190 |
+
def initialize_rag_pipeline(tools):
|
| 191 |
+
"""Initialize the RAG pipeline with tools"""
|
| 192 |
+
|
| 193 |
+
# Initialize LlamaIndex LLM with specific parameters to improve efficiency
|
| 194 |
+
llama_index_llm = GoogleGenAI(
|
| 195 |
+
model="models/gemini-2.0-flash",
|
| 196 |
+
api_key=os.getenv('GOOGLE_API_KEY'),
|
| 197 |
+
temperature=0.1, # Lower temperature for more focused responses
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
# Create agent
|
| 201 |
+
agent = create_agent(tools, llama_index_llm)
|
| 202 |
+
|
| 203 |
+
return agent
|
| 204 |
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
+
def process_question(agent, question: str) -> str:
|
| 207 |
+
"""Process a question through the RAG pipeline with timeout handling"""
|
| 208 |
+
try:
|
| 209 |
+
# Add timeout/retry logic
|
| 210 |
+
response = agent.chat(question)
|
| 211 |
+
return response.response
|
| 212 |
+
except Exception as e:
|
| 213 |
+
error_msg = str(e)
|
| 214 |
+
print(f"Error processing question: {error_msg}")
|
| 215 |
+
|
| 216 |
+
# Handle specific "max iterations" error
|
| 217 |
+
if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
|
| 218 |
+
return ("I apologize, but I was unable to find a complete answer within the allowed search attempts. "
|
| 219 |
+
"This might be because the specific comparison you're asking about requires information "
|
| 220 |
+
"that spans multiple sections of the documents. Could you please rephrase your question "
|
| 221 |
+
"to be more specific about which aspect of the difference you're most interested in?")
|
| 222 |
+
|
| 223 |
+
return f"Error processing your question: {error_msg}"
|
| 224 |
|
|
|
|
|
|
|
| 225 |
|
| 226 |
+
def aswer_language_detection(response_text: str) -> str:
|
| 227 |
+
"""
|
| 228 |
+
Detect the language of the response text.
|
|
|
|
| 229 |
|
| 230 |
+
Args:
|
| 231 |
+
response_text (str): The response text to analyze.
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
str: Detected language code (e.g., 'en', 'fr', etc.)
|
| 235 |
+
"""
|
| 236 |
+
try:
|
| 237 |
+
# Detect the language of the first 5 words of the response
|
| 238 |
+
first_line = " ".join(response_text.split()[:5])
|
| 239 |
+
first_line = re.sub(r'\[.*?\]', '', first_line) # Remove citations
|
| 240 |
+
answer_language = detect(first_line)
|
| 241 |
+
if answer_language not in ['en', 'ar', 'fr']:
|
| 242 |
+
answer_language = 'en'
|
| 243 |
+
except:
|
| 244 |
+
answer_language = 'en'
|
|
|
|
| 245 |
|
| 246 |
+
return answer_language
|
| 247 |
|
| 248 |
+
|
| 249 |
+
def process_question_with_sequential_citations(agent, question: str, chunks_directory="./data/") -> dict:
|
| 250 |
+
"""
|
| 251 |
+
Process a question through the RAG pipeline and return response with sequential citation numbers.
|
| 252 |
+
Enhanced with better error handling for max iterations.
|
|
|
|
|
|
|
| 253 |
|
| 254 |
+
Args:
|
| 255 |
+
agent: The initialized RAG agent
|
| 256 |
+
question (str): The user's question
|
| 257 |
+
chunks_directory (str): Path to the directory containing JSON files
|
| 258 |
+
|
| 259 |
+
Returns:
|
| 260 |
+
dict: {
|
| 261 |
+
"response": str, # Response with sequential citation numbers [1], [2], etc.
|
| 262 |
+
"cited_elements_json": str, # JSON array of cited elements in order
|
| 263 |
+
"unique_ids": list, # Original source IDs in order
|
| 264 |
+
"citation_mapping": dict # Mapping from source ID to citation number
|
| 265 |
+
}
|
| 266 |
+
"""
|
| 267 |
+
try:
|
| 268 |
+
# Get the response from the agent with improved error handling
|
| 269 |
+
response = agent.chat(question)
|
| 270 |
+
response_text = response.response
|
| 271 |
+
|
| 272 |
+
# Check if the response indicates max iterations was reached
|
| 273 |
+
if "max iterations" in response_text.lower() or len(response_text.strip()) == 0:
|
| 274 |
+
# Provide a more helpful fallback response
|
| 275 |
+
response_text = ("I apologize, but I encountered difficulties processing your comparative question "
|
| 276 |
+
"within the allowed search attempts. For questions comparing different protocols "
|
| 277 |
+
"or documents, please try asking about each aspect separately. For example, "
|
| 278 |
+
"first ask about the Algerian definition of Diphtheria, then ask about the WHO definition.")
|
| 279 |
+
|
| 280 |
+
# Extract source IDs from the response (preserving order)
|
| 281 |
+
unique_ids = extract_source_ids(response_text)
|
| 282 |
+
|
| 283 |
+
# Create mapping from source ID to sequential number
|
| 284 |
+
source_id_to_number = {source_id: i + 1 for i, source_id in enumerate(unique_ids)}
|
| 285 |
+
|
| 286 |
+
# Convert citations to sequential numbers
|
| 287 |
+
sequential_response = convert_citations_to_sequential(response_text, source_id_to_number)
|
| 288 |
+
|
| 289 |
+
# Load all chunks data to find cited elements
|
| 290 |
+
all_chunks_data = []
|
| 291 |
+
min_chunks_files = ["Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json",
|
| 292 |
+
"Immunization_in_Practice_WHO_eng_2015.json"]
|
| 293 |
+
|
| 294 |
+
for json_file in min_chunks_files:
|
| 295 |
+
json_path = os.path.join(chunks_directory, json_file)
|
| 296 |
+
try:
|
| 297 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 298 |
+
chunks_data = json.load(f)
|
| 299 |
+
all_chunks_data.extend(chunks_data)
|
| 300 |
+
except Exception as e:
|
| 301 |
+
print(f"Warning: Could not load {json_file}: {e}")
|
| 302 |
+
|
| 303 |
+
# Get cited elements in the same order as the sequential citations
|
| 304 |
+
cited_elements_ordered = []
|
| 305 |
+
for source_id in unique_ids: # This preserves the order
|
| 306 |
+
for element in all_chunks_data:
|
| 307 |
+
if element.get("type") == 'TableElement':
|
| 308 |
+
if element.get("element_id") == source_id:
|
| 309 |
+
cited_elements_ordered.append(element)
|
| 310 |
+
break
|
| 311 |
+
else:
|
| 312 |
+
if "elements" in element:
|
| 313 |
+
for nested_element in element["elements"]:
|
| 314 |
+
if nested_element.get("element_id") == source_id:
|
| 315 |
+
cited_elements_ordered.append(nested_element)
|
| 316 |
+
break
|
| 317 |
+
else:
|
| 318 |
+
continue
|
| 319 |
+
break
|
| 320 |
+
|
| 321 |
+
# Convert to JSON
|
| 322 |
+
cited_elements_json = json.dumps(cited_elements_ordered, ensure_ascii=False, indent=2)
|
| 323 |
+
answer_language = aswer_language_detection(response_text)
|
| 324 |
+
|
| 325 |
+
return {
|
| 326 |
+
"response": sequential_response,
|
| 327 |
+
"cited_elements_json": cited_elements_json,
|
| 328 |
+
"unique_ids": unique_ids,
|
| 329 |
+
"citation_mapping": source_id_to_number,
|
| 330 |
+
"answer_language": answer_language
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
except Exception as e:
|
| 334 |
+
error_msg = str(e)
|
| 335 |
+
print(f"Error processing question: {error_msg}")
|
| 336 |
+
|
| 337 |
+
# Create appropriate fallback response based on error type
|
| 338 |
+
if "max iterations" in error_msg.lower() or "reached max" in error_msg.lower():
|
| 339 |
+
fallback_response = ("I apologize, but I was unable to complete the comparison within the allowed search attempts. "
|
| 340 |
+
"For complex comparative questions like yours about the differences between Algerian and WHO "
|
| 341 |
+
"definitions of Diphtheria, please try asking about each source separately: \n\n"
|
| 342 |
+
"1. First ask: 'What is the definition of Diphtheria in the Algerian vaccination guide?'\n"
|
| 343 |
+
"2. Then ask: 'What is the definition of Diphtheria in the WHO document?'\n\n"
|
| 344 |
+
"This will help me provide you with more focused and complete information.")
|
| 345 |
+
else:
|
| 346 |
+
fallback_response = f"I encountered an error while processing your question: {error_msg}"
|
| 347 |
+
|
| 348 |
+
return {
|
| 349 |
+
"response": fallback_response,
|
| 350 |
+
"cited_elements_json": "[]",
|
| 351 |
+
"unique_ids": [],
|
| 352 |
+
"citation_mapping": {},
|
| 353 |
+
"answer_language": "en"
|
| 354 |
+
}
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def process_question_with_citations(agent, question: str, chunks_directory="./data/") -> dict:
    """Backward-compatible entry point for the citation pipeline.

    Kept so existing callers keep working; it simply delegates to the
    newer sequential-citation implementation.

    Args:
        agent: The initialized RAG agent.
        question (str): The user's question.
        chunks_directory (str): Directory containing the chunk JSON files.

    Returns:
        dict: Whatever process_question_with_sequential_citations returns.
    """
    result = process_question_with_sequential_citations(agent, question, chunks_directory)
    return result
|
rag_pipeline.py
CHANGED
|
@@ -1,16 +1,22 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
|
|
|
| 7 |
import json
|
| 8 |
import re
|
| 9 |
-
|
| 10 |
-
from
|
| 11 |
-
from
|
| 12 |
-
from
|
| 13 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
|
| 16 |
def extract_source_ids(response_text):
|
|
@@ -47,13 +53,8 @@ def extract_source_ids(response_text):
|
|
| 47 |
ids = [id_str.strip() for id_str in citation.split(',')]
|
| 48 |
all_ids.extend(ids)
|
| 49 |
|
| 50 |
-
# Get unique source IDs
|
| 51 |
-
|
| 52 |
-
source_ids = []
|
| 53 |
-
for id_str in all_ids:
|
| 54 |
-
if id_str not in seen:
|
| 55 |
-
seen.add(id_str)
|
| 56 |
-
source_ids.append(id_str)
|
| 57 |
|
| 58 |
if not source_ids:
|
| 59 |
print("Warning: No valid source IDs found after filtering.")
|
|
@@ -62,332 +63,321 @@ def extract_source_ids(response_text):
|
|
| 62 |
return source_ids
|
| 63 |
|
| 64 |
|
| 65 |
-
def
|
| 66 |
-
"""
|
| 67 |
-
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
def replace_citation(match):
|
| 77 |
-
citation_content = match.group(1)
|
| 78 |
-
# Handle multiple IDs in one citation (comma-separated)
|
| 79 |
-
ids = [id_str.strip() for id_str in citation_content.split(',')]
|
| 80 |
-
|
| 81 |
-
# Convert each ID to its sequential number
|
| 82 |
-
numbers = []
|
| 83 |
-
for id_str in ids:
|
| 84 |
-
if id_str in source_id_to_number_map:
|
| 85 |
-
numbers.append(str(source_id_to_number_map[id_str]))
|
| 86 |
-
|
| 87 |
-
# Return the formatted citation with sequential numbers
|
| 88 |
-
if len(numbers) == 1:
|
| 89 |
-
return f"[{numbers[0]}]"
|
| 90 |
-
elif len(numbers) > 1:
|
| 91 |
-
return f"[{','.join(numbers)}]"
|
| 92 |
-
else:
|
| 93 |
-
return match.group(0) # Return original if no mapping found
|
| 94 |
|
| 95 |
-
|
| 96 |
-
sequential_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, response_text)
|
| 97 |
-
return sequential_response
|
| 98 |
|
| 99 |
|
| 100 |
-
def
|
| 101 |
-
"""Create
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
|
| 113 |
-
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 114 |
-
3. If a fact is supported by multiple sources, use the following format:
|
| 115 |
-
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
|
| 116 |
-
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 117 |
-
|
| 118 |
-
### Content Formatting
|
| 119 |
-
1. When rendering tables:
|
| 120 |
-
- Convert HTML tables into clean Markdown format
|
| 121 |
-
- Preserve all original headers and data rows exactly
|
| 122 |
-
- Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
|
| 123 |
-
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 124 |
-
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 125 |
-
|
| 126 |
-
## Tools
|
| 127 |
-
|
| 128 |
-
You have access to a wide variety of tools. You are responsible for using the tools in any sequence you deem appropriate to complete the task at hand.
|
| 129 |
-
This may require breaking the task into subtasks and using different tools to complete each subtask.
|
| 130 |
-
|
| 131 |
-
You have access to the following tools:
|
| 132 |
-
{tool_desc}
|
| 133 |
-
|
| 134 |
-
## Output Format
|
| 135 |
-
|
| 136 |
-
Please answer in the same language as the question and use the following format:
|
| 137 |
-
|
| 138 |
-
```
|
| 139 |
-
Thought: The current language of the user is: (user's language). I need to use a tool to help me answer the question.
|
| 140 |
-
Action: tool name (one of {tool_names}) if using a tool.
|
| 141 |
-
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
|
| 142 |
-
```
|
| 143 |
-
|
| 144 |
-
Please ALWAYS start with a Thought.
|
| 145 |
-
|
| 146 |
-
NEVER surround your response with markdown code markers. You may use code markers within your response if you need to.
|
| 147 |
-
|
| 148 |
-
Please use a valid JSON format for the Action Input. Do NOT do this {{"input": "hello world", "num_beams": 5}}.
|
| 149 |
-
|
| 150 |
-
If this format is used, the tool will respond in the following format:
|
| 151 |
-
|
| 152 |
-
```
|
| 153 |
-
Observation: tool response
|
| 154 |
-
```
|
| 155 |
|
| 156 |
-
|
|
|
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
Answer: [your answer here with proper citations (In the same language as the user's question)]
|
| 161 |
-
```
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
-
## Current Conversation
|
| 169 |
|
| 170 |
-
|
| 171 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
template_vars=["tool_desc", "tool_names"]
|
| 177 |
-
)
|
| 178 |
-
return custom_prompt
|
| 179 |
-
except:
|
| 180 |
-
# Fallback to simple template
|
| 181 |
-
return PromptTemplate(template=custom_instructions)
|
| 182 |
-
|
| 183 |
-
def create_safe_custom_prompt(tools, llm):
|
| 184 |
-
"""Create a safe version that won't have formatting conflicts"""
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
## IMPORTANT REQUIREMENTS
|
| 193 |
-
|
| 194 |
-
### Citation and Sourcing
|
| 195 |
-
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
|
| 196 |
-
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 197 |
-
3. If a fact is supported by multiple sources, use the following format:
|
| 198 |
-
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
|
| 199 |
-
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 200 |
-
|
| 201 |
-
### Content Formatting
|
| 202 |
-
1. When rendering tables:
|
| 203 |
-
- Convert HTML tables into clean Markdown format
|
| 204 |
-
- Preserve all original headers and data rows exactly
|
| 205 |
-
- Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
|
| 206 |
-
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 207 |
-
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 208 |
-
|
| 209 |
-
---
|
| 210 |
-
|
| 211 |
-
"""
|
| 212 |
|
| 213 |
-
#
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
try:
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
template_vars=original_prompt.template_vars,
|
| 228 |
-
metadata=original_prompt.metadata if hasattr(original_prompt, 'metadata') else None
|
| 229 |
-
)
|
| 230 |
-
return new_prompt
|
| 231 |
-
except:
|
| 232 |
-
# Even safer fallback
|
| 233 |
-
return PromptTemplate(template=safe_template)
|
| 234 |
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
# Create agent
|
| 239 |
-
agent = ReActAgent.from_tools(
|
| 240 |
-
tools,
|
| 241 |
-
llm=llm,
|
| 242 |
-
verbose=True,
|
| 243 |
-
)
|
| 244 |
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
except Exception as e:
|
| 251 |
-
print(f"
|
| 252 |
-
|
| 253 |
|
| 254 |
-
return agent
|
| 255 |
|
| 256 |
-
def
|
| 257 |
-
"""
|
| 258 |
|
| 259 |
-
#
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
-
# Create
|
| 266 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
-
|
| 271 |
-
"""Process a question through the RAG pipeline"""
|
| 272 |
-
try:
|
| 273 |
-
response = agent.chat(question)
|
| 274 |
-
return response.response
|
| 275 |
-
except Exception as e:
|
| 276 |
-
print(f"Error processing question: {e}")
|
| 277 |
-
return f"Error processing your question: {str(e)}"
|
| 278 |
|
| 279 |
-
def
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
Args:
|
| 284 |
-
response_text (str): The response text to analyze.
|
| 285 |
|
| 286 |
-
|
| 287 |
-
str: Detected language code (e.g., 'en', 'fr', etc.)
|
| 288 |
-
"""
|
| 289 |
-
|
| 290 |
-
try:
|
| 291 |
-
# Detect the language of the first 5 words of the response
|
| 292 |
-
first_line = " ".join(response_text.split()[:5])
|
| 293 |
-
first_line = re.sub(r'\[.*?\]', '', first_line) # Remove citations
|
| 294 |
-
answer_language = detect(first_line)
|
| 295 |
-
if answer_language not in ['en', 'ar', 'fr']:
|
| 296 |
-
answer_language ='en'
|
| 297 |
-
except:
|
| 298 |
-
answer_language ='en'
|
| 299 |
-
|
| 300 |
-
finally:
|
| 301 |
-
return answer_language
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
def process_question_with_sequential_citations(agent, question: str, chunks_directory="./data/") -> dict:
|
| 305 |
-
"""
|
| 306 |
-
Process a question through the RAG pipeline and return response with sequential citation numbers.
|
| 307 |
-
|
| 308 |
-
Args:
|
| 309 |
-
agent: The initialized RAG agent
|
| 310 |
-
question (str): The user's question
|
| 311 |
-
chunks_directory (str): Path to the directory containing JSON files
|
| 312 |
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
}
|
| 320 |
-
"""
|
| 321 |
-
try:
|
| 322 |
-
# Get the response from the agent
|
| 323 |
-
response = agent.chat(question)
|
| 324 |
-
response_text = response.response
|
| 325 |
|
| 326 |
-
|
| 327 |
-
unique_ids = extract_source_ids(response_text)
|
| 328 |
|
| 329 |
-
|
| 330 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 331 |
|
| 332 |
-
|
| 333 |
-
sequential_response = convert_citations_to_sequential(response_text, source_id_to_number)
|
| 334 |
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
|
|
|
|
| 339 |
|
| 340 |
-
|
| 341 |
-
json_path = os.path.join(chunks_directory, json_file)
|
| 342 |
-
try:
|
| 343 |
-
with open(json_path, "r", encoding="utf-8") as f:
|
| 344 |
-
chunks_data = json.load(f)
|
| 345 |
-
all_chunks_data.extend(chunks_data)
|
| 346 |
-
except Exception as e:
|
| 347 |
-
print(f"Warning: Could not load {json_file}: {e}")
|
| 348 |
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
continue
|
| 365 |
-
break
|
| 366 |
|
| 367 |
-
|
| 368 |
-
cited_elements_json = json.dumps(cited_elements_ordered, ensure_ascii=False, indent=2)
|
| 369 |
-
aswer_language= aswer_language_detection(response_text)
|
| 370 |
-
return {
|
| 371 |
-
"response": sequential_response,
|
| 372 |
-
"cited_elements_json": cited_elements_json,
|
| 373 |
-
"unique_ids": unique_ids,
|
| 374 |
-
"citation_mapping": source_id_to_number,
|
| 375 |
-
"answer_language":aswer_language
|
| 376 |
-
}
|
| 377 |
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
+
Environment preparation script for vaccine assistant - Improved version
|
| 4 |
+
Creates vector stores and retrieval tools with better descriptions for efficient agent routing
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
import os
|
| 8 |
import json
|
| 9 |
import re
|
| 10 |
+
import nest_asyncio
|
| 11 |
+
from typing import List
|
| 12 |
+
from langchain_community.vectorstores import Chroma
|
| 13 |
+
from langchain_core.documents import Document
|
| 14 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 15 |
+
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
| 16 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 17 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 18 |
+
from llama_index.core.tools import FunctionTool
|
| 19 |
+
from llama_index.core.schema import TextNode
|
| 20 |
|
| 21 |
|
| 22 |
def extract_source_ids(response_text):
|
|
|
|
| 53 |
ids = [id_str.strip() for id_str in citation.split(',')]
|
| 54 |
all_ids.extend(ids)
|
| 55 |
|
| 56 |
+
# Get unique source IDs
|
| 57 |
+
source_ids = list(set(all_ids))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
if not source_ids:
|
| 60 |
print("Warning: No valid source IDs found after filtering.")
|
|
|
|
| 63 |
return source_ids
|
| 64 |
|
| 65 |
|
| 66 |
+
def setup_models():
    """Build the shared embedding model and chat LLM for the pipeline.

    Returns:
        tuple: (embedding_function, llm) — a multilingual HuggingFace
        embedder and a Gemini chat model tuned for focused responses.
    """
    # Multilingual embedder (handles French/Arabic/English source docs).
    embedding_function = HuggingFaceEmbeddings(
        model_name="intfloat/multilingual-e5-base"
    )

    # Gemini chat model; the API key comes from the environment.
    api_key = os.getenv('GOOGLE_API_KEY')
    llm = ChatGoogleGenerativeAI(
        model="gemini-2.0-flash",
        google_api_key=api_key,
        temperature=0.1  # Lower temperature for more focused responses
    )

    return embedding_function, llm
|
|
|
|
|
|
|
| 82 |
|
| 83 |
|
| 84 |
+
def create_vectorstore_from_json(json_path: str, collection_name: str, embedding_function):
    """Load pre-chunked JSON records and index them in a Chroma collection.

    Args:
        json_path (str): Path to the chunks JSON file.
        collection_name (str): Name of the Chroma collection to create.
        embedding_function: Embedding model used to vectorize the chunks.

    Returns:
        tuple: (vectorstore, documents) — the persisted Chroma store and
        the LangChain Document objects that were indexed.
    """
    with open(json_path, "r", encoding="utf-8") as f:
        chunks_data = json.load(f)

    documents = []
    for element in chunks_data:
        text = element["text"]
        # NOTE(review): language is hard-coded to "fra" even for English
        # sources (e.g. the WHO 2015 document) — confirm nothing filters on it.
        meta = {
            "language": "fra",
            "source": element["filename"],
            "filetype": element["filetype"],
            "element_id": element["element_id"]
        }

        # Tables additionally keep their original HTML rendering.
        if element["type"] == "TableElement":
            meta["table_text_as_html"] = element["table_text_as_html"]

        documents.append(Document(page_content=text, metadata=meta))

    # Build and persist the vector store on disk.
    vectorstore = Chroma.from_documents(
        documents=documents,
        embedding=embedding_function,
        collection_name=collection_name,
        persist_directory="chroma_db_multilingual"
    )
    return vectorstore, documents
|
| 114 |
|
|
|
|
| 115 |
|
| 116 |
+
def create_retriever(vectorstore, docs, llm):
    """Compose a multi-query retriever on top of a dense+sparse ensemble.

    Args:
        vectorstore: Chroma store used for dense similarity search.
        docs: Raw Document list used to build the sparse BM25 index.
        llm: Chat model used by MultiQueryRetriever to expand queries.

    Returns:
        MultiQueryRetriever: expands each query via the LLM and searches
        the 50/50-weighted vector+BM25 ensemble.
    """
    # Dense retriever: top-4 similarity hits (kept small for efficiency).
    dense = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 4}
    )

    # Sparse retriever: BM25 keyword matching, top 2 hits.
    sparse = BM25Retriever.from_documents(docs)
    sparse.k = 2

    # Blend dense and sparse results with equal weight.
    hybrid = EnsembleRetriever(
        retrievers=[dense, sparse],
        weights=[0.5, 0.5]
    )

    # LLM-driven query expansion improves recall on paraphrased questions.
    return MultiQueryRetriever.from_llm(retriever=hybrid, llm=llm)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
    """Convert LangChain/ChromaDB Document objects to LlamaIndex TextNodes.

    Each node gets a deterministic id of the form "<source>_<element_id>",
    falling back to "doc_<index>" / "unknown" when metadata is missing.

    Args:
        chromadb_documents (List): Documents exposing .page_content and
            .metadata (a dict).

    Returns:
        List[TextNode]: One node per convertible document; documents that
        fail to convert are skipped (with a warning) rather than aborting.
    """
    nodes = []
    for i, doc in enumerate(chromadb_documents):
        try:
            text = doc.page_content
            metadata = doc.metadata.copy()
            element_id = metadata.get("element_id", f"doc_{i}")
            source = metadata.get("source", "unknown")
            node_id = f"{source}_{element_id}"

            node = TextNode(
                text=text,
                metadata=metadata,
                id_=node_id
            )
            nodes.append(node)
        except Exception as e:
            # Fix: previously this swallowed the error silently; report it
            # (matching the print-based error style used elsewhere) so bad
            # documents are visible instead of vanishing from results.
            print(f"Warning: skipping document {i} during node conversion: {e}")
            continue
    return nodes
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def section_tool_wrapper(retriever, section_path_chunks, query):
    """Retrieve relevant chunks for *query* and format them for the agent.

    Runs the retriever, maps the retrieved documents back to the full chunk
    records in the section's JSON file, and returns one formatted string in
    which every chunk is prefixed with its element_id for inline citation.

    Args:
        retriever: A LangChain retriever exposing get_relevant_documents().
        section_path_chunks (str): Path to the JSON file with full chunk
            records for this section/document.
        query (str): The user's search query.

    Returns:
        str: Formatted chunk texts separated by "---", a "no results"
        notice, or an error message — this function never raises.
    """
    try:
        retrieved_docs = retriever.get_relevant_documents(query)
        nodes_from_retrieved_docs = convert_chromadb_to_llamaindex_nodes(retrieved_docs)

        if not nodes_from_retrieved_docs:
            return "No relevant documents found for the query."

        # Fix: use a set for O(1) membership tests below (was a list,
        # giving O(n) per chunk in the filter).
        chunk_ids = {doc.metadata['element_id'] for doc in retrieved_docs}
        with open(section_path_chunks, "r", encoding="utf-8") as f:
            chunks_data = json.load(f)

        chunks_unique = [chunk for chunk in chunks_data if chunk.get('element_id', 'Unknown') in chunk_ids]
        combined_text = []

        # Limit the number of chunks to avoid overwhelming the context
        max_chunks = 8  # Reasonable limit
        for chu in chunks_unique[:max_chunks]:
            if chu["type"] == "TableElement":
                # Tables carry both plain text and the original HTML markup.
                text = f"[{chu['element_id']}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
                combined_text.append(text)
            else:
                # Non-table records group sub-elements, each cited separately.
                # assumes every non-table record carries an "elements" list — TODO confirm
                for element in chu["elements"]:
                    text = f"[{element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
                    combined_text.append(text)

        result = "\n---\n".join(combined_text)
        print(f"Retrieved {len(nodes_from_retrieved_docs)} documents for query: {query[:50]}...")
        return result
    except Exception as e:
        print(f"Error in section tool: {e}")
        return f"Error retrieving documents: {str(e)}"
|
| 198 |
|
|
|
|
| 199 |
|
| 200 |
+
def create_section_tools(embedding_function, llm):
    """Create all section-specific retrieval tools with improved descriptions"""
    # Builds one retriever per section file found on disk, plus retrievers
    # for the full Algerian guide and the WHO document, then wraps a
    # focused subset of them as LlamaIndex FunctionTools for the agent.

    # Define section paths
    section_paths = {
        'one': 'section_one_chunks.json',
        'two': 'section_two_chunks.json',
        'three': 'section_three_chunks.json',
        'four': 'section_four_chunks.json',
        'five': 'section_five_chunks.json',
        'six': 'section_six_chunks.json',
        'seven': 'section_seven_chunks.json',
        'eight': 'section_eight_chunks.json',
        'nine': 'section_nine_chunks.json',
        'ten': 'section_ten_chunks.json'
    }

    # Create retrievers for each section
    # NOTE(review): retrievers are built for all ten sections, but only
    # sections 'two' and 'three' are exposed as tools below — confirm the
    # remaining eight are intentionally unused (they still cost indexing time).
    section_retrievers = {}
    for section, path in section_paths.items():
        if os.path.exists(f'./data/{path}'):
            vstore, docs = create_vectorstore_from_json(f'./data/{path}', f"Guide_2023_{section}", embedding_function)
            section_retrievers[section] = create_retriever(vstore, docs, llm)

    # Create main guide retriever
    guide_path = './data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
    if os.path.exists(guide_path):
        guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
        guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
    else:
        # Missing file: the tool below degrades to a "not available" message.
        guide_retriever = None

    # Primary + Secondary Document Paths
    immunization_path = './data/Immunization_in_Practice_WHO_eng_2015.json'

    # WHO Immunization in Practice Tool
    if os.path.exists(immunization_path):
        immunization_vstore, immunization_docs = create_vectorstore_from_json(
            immunization_path,
            "Immunization_in_Practice_WHO_eng_2015",
            embedding_function
        )
        immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
    else:
        immunization_retriever = None

    # Tool Functions with Improved Efficiency Focus
    # These closures capture the retrievers/paths built above; their
    # docstrings are written agent-facing (the explicit `description=`
    # passed to FunctionTool below is what the agent router actually sees).

    def guide_retrieval_tool(query: str) -> str:
        """
        **PRIMARY TOOL - USE FIRST FOR MOST QUESTIONS**

        Comprehensive search across the entire Algerian National Vaccination Guide (2023).

        **When to use this tool:**
        - General vaccination questions
        - Disease definitions and descriptions
        - Vaccine schedules and protocols
        - Comparative questions needing Algerian perspective
        - Any question about Algeria's vaccination program

        **Keywords that indicate this tool:** Algeria, Algerian, national, calendrier, vaccination, PEV, diseases (diphteria, polio, measles, etc.)

        Args:
            query (str): Any vaccination-related question about Algeria's national program

        Returns:
            str: Comprehensive information from the Algerian guide with citations
        """
        if not guide_retriever:
            return "Guide retriever not available"
        return section_tool_wrapper(guide_retriever, guide_path, query)

    def immunization_tool(query: str) -> str:
        """
        **SECONDARY TOOL - USE FOR WHO/INTERNATIONAL PERSPECTIVE**

        WHO Immunization in Practice 2015 - Global best practices and international standards.

        **When to use this tool:**
        - Questions specifically asking about WHO recommendations
        - International/global immunization practices
        - Comparative questions needing WHO perspective
        - Technical immunization procedures and best practices

        **Keywords that indicate this tool:** WHO, international, global, best practices, standards

        Args:
            query (str): Question about international immunization practices or WHO recommendations

        Returns:
            str: WHO guidance and international best practices with citations
        """
        if not immunization_retriever:
            return "Immunization in Practice retriever not available"
        return section_tool_wrapper(immunization_retriever, immunization_path, query)

    # Section-Specific Tools (USE ONLY IF QUESTION IS VERY SPECIFIC TO THE SECTION)

    def section_two_tool(query: str) -> str:
        """
        **DISEASE-SPECIFIC TOOL**

        Section 2: Vaccine-preventable diseases - definitions, symptoms, transmission, complications.

        **Use ONLY for specific disease definition questions like:**
        - "What is diphtheria?"
        - "Define measles according to Algerian protocol"
        - "Symptoms of polio"

        **Keywords:** definition, symptoms, transmission, complications, disease characteristics

        Args:
            query (str): Specific question about disease definitions or characteristics

        Returns:
            str: Disease-specific medical information with citations
        """
        if 'two' not in section_retrievers:
            return "Section 2 retriever not available"
        return section_tool_wrapper(section_retrievers['two'], f'./data/{section_paths["two"]}', query)

    def section_three_tool(query: str) -> str:
        """
        **VACCINE-SPECIFIC TOOL**

        Section 3: Vaccine details - types, composition, administration methods.

        **Use ONLY for specific vaccine technical questions like:**
        - "What type of vaccine is used for diphtheria?"
        - "How is the MMR vaccine administered?"
        - "Vaccine composition and dosage"

        **Keywords:** vaccine type, composition, administration, dosage, technical details

        Args:
            query (str): Technical question about specific vaccines

        Returns:
            str: Technical vaccine information with citations
        """
        if 'three' not in section_retrievers:
            return "Section 3 retriever not available"
        return section_tool_wrapper(section_retrievers['three'], f'./data/{section_paths["three"]}', query)

    # Create FunctionTool objects with focused selection
    tools = [
        # Primary tools - most commonly used
        FunctionTool.from_defaults(
            name="algerian_guide_search",
            fn=guide_retrieval_tool,
            description="PRIMARY TOOL: Search the complete Algerian National Vaccination Guide for any vaccination-related question"
        ),
        FunctionTool.from_defaults(
            name="who_immunization_search",
            fn=immunization_tool,
            description="SECONDARY TOOL: Search WHO Immunization in Practice for international standards and WHO recommendations"
        ),
        # Specialized tools - use only when very specific
        FunctionTool.from_defaults(
            name="disease_definitions_search",
            fn=section_two_tool,
            description="SPECIALIZED: Search for specific disease definitions, symptoms, and characteristics"
        ),
        FunctionTool.from_defaults(
            name="vaccine_technical_search",
            fn=section_three_tool,
            description="SPECIALIZED: Search for technical vaccine details, composition, and administration methods"
        ),
    ]

    return tools
|
| 372 |
|
| 373 |
+
|
| 374 |
+
def prepare_environment():
    """Bootstrap the retrieval environment for the vaccine assistant.

    Returns:
        tuple: (tools, llm) — the FunctionTool list for the agent and the
        configured chat LLM.
    """
    print("Setting up models...")
    embedder, chat_llm = setup_models()

    print("Creating section tools...")
    section_tools = create_section_tools(embedder, chat_llm)

    print("Environment prepared successfully!")
    return section_tools, chat_llm
|