# AgenticRAG / prepare_env.py
# Zeggai Abdellah
# update the k retriever
# e84f370
# -*- coding: utf-8 -*-
"""
Environment preparation script for vaccine assistant
Creates vector stores and retrieval tools
"""
import os
import json
import re
import nest_asyncio
from typing import List
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers import BM25Retriever, EnsembleRetriever
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain_google_genai import ChatGoogleGenerativeAI
from llama_index.core.tools import FunctionTool
from llama_index.core.schema import TextNode
from langchain.prompts import PromptTemplate
import logging
logging.basicConfig()
# Raise the MultiQueryRetriever's logger to INFO so the rephrased queries it
# generates are visible in the log output.
_multi_query_logger = logging.getLogger("langchain.retrievers.multi_query")
_multi_query_logger.setLevel(logging.INFO)
def extract_source_ids(response_text):
    """
    Extract the source IDs cited inline in a generated response.

    Supported citation formats:
    - Standard format: [Source ID]
    - Multiple sources in adjacent brackets: [Source ID1][Source ID2]
      (whitespace between the brackets is allowed)
    - Multiple sources in one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.
            ``None`` or an empty string is treated as "no citations".

    Returns:
        list of str: Unique, non-empty source IDs in order of first appearance.
            An empty list is returned (after printing a warning) when nothing
            usable is found.
    """
    if not response_text:
        print("Warning: No source IDs found in the response text.")
        return []

    # Merge adjacent citations "[ID1] [ID2]" / "[ID1][ID2]" into "[ID1, ID2]" so
    # every bracket group can be parsed the same way below.
    consolidated_text = re.sub(r'\]\s*\[', ', ', response_text)

    # Capture the contents of every [...] group (single ID or comma-separated IDs).
    inline_citations = re.findall(r'\[([^\[\]]+)\]', consolidated_text)
    if not inline_citations:
        print("Warning: No source IDs found in the response text.")
        return []

    # A dict de-duplicates while keeping first-seen order, so the result is
    # deterministic (set ordering of str varies between interpreter runs).
    # Blank entries from "[ , ]" or trailing commas are dropped.
    unique_ids = {}
    for citation in inline_citations:
        for part in citation.split(','):
            source_id = part.strip()
            if source_id:
                unique_ids.setdefault(source_id, None)
    source_ids = list(unique_ids)

    if not source_ids:
        print("Warning: No valid source IDs found after filtering.")
        return []
    return source_ids
def setup_models():
    """Build the embedding model and the chat LLM shared by all retrievers.

    Returns:
        tuple: ``(embedding_function, llm)`` — a HuggingFace embedding model
        (``intfloat/multilingual-e5-base``) and a ``gemini-2.0-flash`` chat
        model whose API key is read from the ``GOOGLE_API_KEY`` environment
        variable.
    """
    print("πŸ”§ Setting up embedding model and LLM...")

    embedding_model_name = "intfloat/multilingual-e5-base"
    embedding_function = HuggingFaceEmbeddings(model_name=embedding_model_name)
    print(f"βœ… Embedding model initialized: {embedding_model_name}")

    llm_model_name = "gemini-2.0-flash"
    api_key = os.getenv('GOOGLE_API_KEY')
    llm = ChatGoogleGenerativeAI(model=llm_model_name, google_api_key=api_key)
    print(f"βœ… LLM initialized: {llm_model_name}")

    return embedding_function, llm
def create_vectorstore_from_json(json_path: str, collection_name: str, embedding_function,
                                 language: str = "fra",
                                 persist_directory: str = "chroma_db_multilingual"):
    """Create a persisted Chroma vector store from a JSON list of chunks.

    Each JSON element is read for its ``text``, ``filename``, ``filetype``,
    ``element_id`` and ``type`` keys. ``TableElement`` entries may also carry
    ``table_text_as_html``, which is copied into the document metadata when present.

    Args:
        json_path: Path to the chunks JSON file (a list of element dicts).
        collection_name: Chroma collection to write into.
        embedding_function: Embeddings object used to embed the chunk texts.
        language: Value stored in every document's ``language`` metadata.
            Defaults to ``"fra"``, the previously hard-coded value.
        persist_directory: Directory where Chroma persists the collection.

    Returns:
        tuple: ``(vectorstore, documents)`` — the Chroma store and the list of
        ``Document`` objects it was built from.
    """
    print(f"πŸ“š Creating vector store from: {json_path}")
    with open(json_path, "r", encoding="utf-8") as f:
        chunks_data = json.load(f)
    print(f"πŸ“Š Loaded {len(chunks_data)} chunks from JSON")

    documents = []
    for element in chunks_data:
        metadata = {
            "language": language,
            "source": element["filename"],
            "filetype": element["filetype"],
            "element_id": element["element_id"],
        }
        # Only attach the HTML when it exists. Previously a table chunk without
        # this key raised KeyError and aborted the whole store build.
        table_html = element.get("table_text_as_html")
        if element.get("type") == "TableElement" and table_html is not None:
            metadata["table_text_as_html"] = table_html
        documents.append(Document(page_content=element["text"], metadata=metadata))

    # Stable, unique IDs make Chroma upsert on re-runs. Without them every run
    # generates fresh random IDs and appends another copy of each chunk to the
    # persisted collection. The index prefix keeps IDs unique even if two chunks
    # share an element_id.
    ids = [f"{i}-{doc.metadata['element_id']}" for i, doc in enumerate(documents)]

    vectorstore = Chroma.from_documents(
        documents=documents,
        embedding=embedding_function,
        ids=ids,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    print(f"βœ… Vector store created with collection: {collection_name}")
    return vectorstore, documents
def create_retriever(vectorstore, docs, llm, bm25_k=3, vector_k=6,
                     weights=(0.5, 0.5), use_multi_query=True):
    """Create a hybrid (vector + BM25) ensemble retriever, optionally with multi-query expansion.

    Args:
        vectorstore: The vector store used for similarity search.
        docs: Documents indexed by the BM25 retriever.
        llm: Language model that ``MultiQueryRetriever`` uses to generate query
            rephrasings (unused when ``use_multi_query`` is False).
        bm25_k: Number of documents to retrieve with BM25.
        vector_k: Number of documents to retrieve with vector search.
        weights: Exactly two ensemble weights, ``(vector_weight, bm25_weight)``.
        use_multi_query: When True (default, the previous behaviour), wrap the
            ensemble in a ``MultiQueryRetriever``. When False, return the
            ``EnsembleRetriever`` itself.

    Returns:
        ``MultiQueryRetriever`` when ``use_multi_query`` is True, otherwise the
        ``EnsembleRetriever``.

    Raises:
        ValueError: If ``weights`` does not contain exactly two values.
    """
    weights = list(weights)
    if len(weights) != 2:
        raise ValueError(f"weights must contain exactly 2 values (vector, bm25), got {weights!r}")

    print("πŸ” Creating ensemble retriever...")

    # Dense retriever over the embedded chunks.
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": vector_k},
    )
    print(f"βœ… Vector retriever created (k={vector_k})")

    # Sparse keyword retriever over the same documents.
    bm25_retriever = BM25Retriever.from_documents(docs)
    bm25_retriever.k = bm25_k
    print(f"βœ… BM25 retriever created (k={bm25_k})")

    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=weights,
    )
    print(f"βœ… Ensemble retriever created (weights: {weights[0]}, {weights[1]})")

    if not use_multi_query:
        return ensemble_retriever

    # PromptTemplate for Vaccine Assistant MultiQuery Retriever
    vaccine_multiquery_prompt = PromptTemplate(
        input_variables=["question"],
        template="""You are an AI assistant specialized in vaccine-related medical information retrieval.
Your task is to generate multiple search queries based on the original question to find relevant information from official vaccine medical documents.
IMPORTANT GUIDELINES:
- Keep all vaccine-specific terminology and medical terms intact
- Maintain the clinical and medical context
- Focus on evidence-based vaccine information
- Preserve any specific vaccine names, diseases, or medical conditions mentioned
- Generate queries that would help retrieve information about vaccine schedules, dosing, contraindications, adverse events, and disease prevention
Original question: {question}
Generate 4 different search queries that rephrase the original question while maintaining vaccine terminology and medical accuracy. Each query should approach the topic from a slightly different angle to maximize retrieval from vaccine medical documents.
Provide only the alternative questions, one per line."""
    )

    # The LLM rewrites each query into several variants; results from the
    # ensemble for every variant are merged.
    expanding_retriever = MultiQueryRetriever.from_llm(
        retriever=ensemble_retriever,
        llm=llm,
        prompt=vaccine_multiquery_prompt,
    )
    print("βœ… Multi-query expanding retriever created")
    return expanding_retriever
def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
    """Convert LangChain ``Document`` objects (from the Chroma-backed retrievers)
    into LlamaIndex ``TextNode`` objects.

    Each node id is ``"{source}_{element_id}"``. A missing ``source`` becomes
    ``"unknown"`` and a missing ``element_id`` becomes ``"doc_{i}"``. The
    metadata dict is copied, so the input documents are not modified.

    Args:
        chromadb_documents: Objects exposing ``page_content`` and a ``metadata``
            dict. ``None`` is treated as an empty list.

    Returns:
        List[TextNode]: One node per successfully converted document, in input
        order. Documents that fail to convert are skipped with a printed warning.
    """
    nodes = []
    for i, doc in enumerate(chromadb_documents or []):
        try:
            metadata = doc.metadata.copy()
            element_id = metadata.get("element_id", f"doc_{i}")
            source = metadata.get("source", "unknown")
            nodes.append(TextNode(
                text=doc.page_content,
                metadata=metadata,
                id_=f"{source}_{element_id}",
            ))
        except Exception as exc:
            # Best effort: one malformed document should not drop the whole batch,
            # but the skip is reported instead of silently swallowed.
            print(f"Warning: skipping document {i} during node conversion: {exc}")
    return nodes
def section_tool_wrapper(retriever, section_path_chunks, query):
    """Retrieve chunks for ``query`` and render them as citable text for the agent.

    The retriever's documents are matched back, by ``element_id``, to the full
    chunk records in the JSON file, which also provide the sub-elements and table
    HTML. Each rendered piece starts with ``[Source: <id>]``, and pieces are
    joined with a ``---`` separator line.

    Args:
        retriever: A retriever exposing ``invoke(query)`` that returns documents
            whose ``metadata`` holds an ``element_id``.
        section_path_chunks (str): Path to the JSON chunk file the retriever was built from.
        query (str): The search query.

    Returns:
        str: The formatted chunks, ``"No relevant documents found for the query."``,
        or ``"Error retrieving documents: ..."``. Errors are returned as text
        rather than raised.
    """
    print(f"πŸ” TOOL CALL: Searching for query: '{query[:100]}...' in {section_path_chunks}")
    try:
        # `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated.
        retrieved_docs = retriever.invoke(query)
        print(f"πŸ“„ Retrieved {len(retrieved_docs)} documents")

        # Documents without an element_id cannot be matched back to the JSON
        # file, so they are ignored instead of raising KeyError.
        chunk_ids = [doc.metadata['element_id'] for doc in retrieved_docs
                     if 'element_id' in doc.metadata]
        if not chunk_ids:
            print("❌ No relevant documents found for the query")
            return "No relevant documents found for the query."
        print(f"πŸ†” Found chunk IDs: {chunk_ids}")

        with open(section_path_chunks, "r", encoding="utf-8") as f:
            chunks_data = json.load(f)

        wanted_ids = set(chunk_ids)  # O(1) membership instead of a list scan per chunk
        chunks_unique = [chunk for chunk in chunks_data
                         if chunk.get('element_id', 'Unknown') in wanted_ids]
        print(f"βœ… Matched {len(chunks_unique)} unique chunks")

        combined_text = []
        for chunk in chunks_unique:
            if chunk.get("type") == "TableElement":
                combined_text.append(
                    f"[Source: {_table_source_id(chunk)}]\n CONTENT: \n{chunk['text']}\n HTML: \n {chunk.get('table_text_as_html', '')} \n\n"
                )
            else:
                # Cite each sub-element separately. A chunk without sub-elements
                # is rendered as a single element itself.
                for element in chunk.get("elements") or [chunk]:
                    combined_text.append(
                        f"[Source: {element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
                    )

        result = "\n---\n".join(combined_text)
        print(f"βœ… TOOL RESPONSE: Generated response with {len(combined_text)} text sections")
        return result
    except Exception as e:
        print(f"❌ TOOL ERROR: {e}")
        return f"Error retrieving documents: {str(e)}"


def _table_source_id(chunk):
    """Return the citation id for a ``TableElement`` chunk.

    The previous code read ``chunk['elements']['element_id']`` (a dict), while
    non-table chunks store ``elements`` as a list. That raised TypeError whenever
    a table chunk used the list shape, so both shapes are accepted here.
    NOTE(review): confirm which shape the chunk JSON actually uses for tables.
    Falls back to the chunk's own ``element_id``.
    """
    elements = chunk.get("elements")
    if isinstance(elements, dict) and "element_id" in elements:
        return elements["element_id"]
    if isinstance(elements, list) and elements:
        first = elements[0]
        if isinstance(first, dict) and "element_id" in first:
            return first["element_id"]
    return chunk.get("element_id", "Unknown")
def create_section_tools(embedding_function, llm) -> List[FunctionTool]:
    """
    Create all section-specific retrieval tools with improved descriptions for accurate routing.

    Builds one retriever per section JSON file found under ``./data``, plus
    retrievers for the full Algerian vaccination guide and the WHO
    "Immunization in Practice" guide. Each retriever sits behind a query function
    exposed as a LlamaIndex ``FunctionTool``. If a source file is missing or its
    retriever fails to build, the retriever stays ``None``, and the matching tool
    returns a "... not available" message instead of raising.

    Args:
        embedding_function: Embedding model passed to ``create_vectorstore_from_json``.
        llm: Language model passed to ``create_retriever``.

    Returns:
        List[FunctionTool]: Twelve tools: the general guide tool, the WHO tool,
        and one tool per guide section (1-10).

    Note:
        The nested tool functions' docstrings act as the tool descriptions the
        agent routes on. Presumably ``FunctionTool.from_defaults`` reads them
        because no explicit ``description`` is passed. Editing those docstrings
        therefore changes routing behaviour.
    """
    print("πŸ› οΈ Creating section-specific retrieval tools with enhanced descriptions...")
    # Define section paths - Fixed path structure
    # (paths are relative to the current working directory)
    section_paths = {
        'one': './data/section_one_chunks.json',
        'two': './data/section_two_chunks.json',
        'three': './data/section_three_chunks.json',
        'four': './data/section_four_chunks.json',
        'five': './data/section_five_chunks.json',
        'six': './data/section_six_chunks.json',
        'seven': './data/section_seven_chunks.json',
        'eight': './data/section_eight_chunks.json',
        'nine': './data/section_nine_chunks.json',
        'ten': './data/section_ten_chunks.json'
    }
    # Create retrievers for each section
    # Each section gets its own Chroma collection ("Guide_2023_<section>").
    # A missing file or a build error records None for that section.
    section_retrievers = {}
    for section, path in section_paths.items():
        try:
            if os.path.exists(path):
                print(f"πŸ“ Creating retriever for section {section} from {path}")
                vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
                section_retrievers[section] = create_retriever(vstore, docs, llm,)
                print(f"βœ… Successfully created retriever for section {section}")
            else:
                print(f"⚠️ Warning: File not found for section {section}: {path}")
                section_retrievers[section] = None
        except Exception as e:
            print(f"❌ Error creating retriever for section {section}: {e}")
            section_retrievers[section] = None
    # Create main guide retriever
    # (stays None if the file is missing or the build fails)
    guide_path = './data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
    guide_retriever = None
    try:
        if os.path.exists(guide_path):
            print("πŸ“š Creating main guide retriever...")
            guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
            guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
            print("βœ… Successfully created main guide retriever")
        else:
            print(f"⚠️ Warning: Main guide file not found: {guide_path}")
    except Exception as e:
        print(f"❌ Error creating main guide retriever: {e}")
    # WHO Immunization in Practice Tool
    # NOTE(review): the filename indicates an English document, but
    # create_vectorstore_from_json is called without a language argument here,
    # so its "fra" language metadata default applies. Confirm whether
    # that matters downstream.
    immunization_path = './data/Immunization in Practice_WHO_eng_2015.json'
    immunization_retriever = None
    try:
        if os.path.exists(immunization_path):
            print("🌍 Creating immunization retriever...")
            immunization_vstore, immunization_docs = create_vectorstore_from_json(
                immunization_path,
                "Immunization_in_Practice_WHO_eng_2015",
                embedding_function
            )
            immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
            print("βœ… Successfully created immunization retriever")
        else:
            print(f"⚠️ Warning: Immunization file not found: {immunization_path}")
    except Exception as e:
        print(f"❌ Error creating immunization retriever: {e}")
    # --- Tool Definitions with Improved Descriptions ---
    # Each tool closes over the retrievers/paths built above. It first checks that
    # its retriever exists, then delegates to section_tool_wrapper with the JSON
    # path the retriever was built from. The docstrings are runtime tool
    # descriptions; keep them in sync with the routing intent.
    def general_guide_tool(query: str) -> str:
        """
        A general-purpose tool for the Algerian National Vaccination Guide.
        **Use this tool as a fallback** if no other specific tool seems appropriate, or for very broad, multi-topic questions
        (e.g., 'Summarize the Algerian vaccination policy and its safety measures').
        **Always prefer a more specific tool if the query matches its description** (e.g., use 'cold_chain_tool' for temperature questions).
        Args:
            query (str): A broad or ambiguous question about the Algerian National Vaccination Guide.
        Returns:
            str: Content retrieved from the entire guide.
        """
        print(f"πŸ₯ GENERAL GUIDE TOOL CALLED (FALLBACK): {query[:50]}...")
        if not guide_retriever:
            return "Guide retriever not available - main guide file may be missing"
        return section_tool_wrapper(guide_retriever, guide_path, query)
    def who_immunization_tool(query: str) -> str:
        """
        Provides information from the WHO's 'Immunization in Practice' guide. Use this for questions about
        **global immunization standards**, international best practices, or for comparing Algerian policy to
        general WHO recommendations on topics like cold chain, safety, and disease control.
        Args:
            query (str): A question seeking global or general immunization practices.
        Returns:
            str: Content from the WHO Immunization in Practice guide.
        """
        print(f"🌍 WHO TOOL CALLED: {query[:50]}...")
        if not immunization_retriever:
            return "Immunization in Practice retriever not available - WHO guide file may be missing"
        return section_tool_wrapper(immunization_retriever, immunization_path, query)
    def program_overview_tool(query: str) -> str:
        """
        (Section 1) The primary tool for questions about the **history, objectives, and structure** of Algeria's
        national immunization program (PEV - Programme Γ‰largi de Vaccination). Use this for topics like
        the program's rationale, key achievements, and the reasons for updates to the vaccination calendar.
        Args:
            query (str): A question about the foundation or evolution of the PEV.
        Returns:
            str: Response from Section 1.
        """
        print(f"πŸ“‹ PROGRAM OVERVIEW (S1) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('one'):
            return "Section 1 retriever not available"
        return section_tool_wrapper(section_retrievers['one'], section_paths['one'], query)
    def disease_info_tool(query: str) -> str:
        """
        (Section 2) The definitive tool for information on **specific vaccine-preventable diseases**.
        Use this to find details on **symptoms, transmission methods, complications**, and prevention
        strategies for diseases like Diphtheria, Measles, Polio, Tetanus, etc.
        Args:
            query (str): A question about a disease covered by the national vaccination program.
        Returns:
            str: Disease-specific content from Section 2.
        """
        print(f"🦠 DISEASE INFO (S2) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('two'):
            return "Section 2 retriever not available"
        return section_tool_wrapper(section_retrievers['two'], section_paths['two'], query)
    def vaccine_properties_tool(query: str) -> str:
        """
        (Section 3) The specific tool for questions about the **vaccines themselves**: their types (e.g., BCG, ROR,
        DTCaVPI), composition, whether they are live or inactivated, and the correct **method of administration**
        (e.g., intradermal, intramuscular, oral).
        Args:
            query (str): A question about a vaccine's formulation or how it is administered.
        Returns:
            str: Vaccine-specific info from Section 3.
        """
        print(f"πŸ’‰ VACCINE PROPERTIES (S3) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('three'):
            return "Section 3 retriever not available"
        return section_tool_wrapper(section_retrievers['three'], section_paths['three'], query)
    def catch_up_vaccination_tool(query: str) -> str:
        """
        (Section 4) Specialized tool for **missed or delayed vaccinations (rattrapage vaccinal)**.
        Use this for questions about creating a **catch-up schedule** for a child who is behind
        on their shots, based on their age and vaccination history.
        Args:
            query (str): A question about catch-up vaccination due to a delay or missed dose.
        Returns:
            str: Catch-up schedule guidance from Section 4.
        """
        print(f"πŸ”„ CATCH-UP (S4) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('four'):
            return "Section 4 retriever not available"
        return section_tool_wrapper(section_retrievers['four'], section_paths['four'], query)
    def special_populations_tool(query: str) -> str:
        """
        (Section 5) The designated tool for vaccination guidelines concerning **special populations**.
        Use for questions about vaccinating preterm infants, allergic children, or patients with
        immunosuppression, chronic illnesses (cardiac, pulmonary), or other specific health conditions.
        Args:
            query (str): A question about tailored vaccination for a vulnerable or special group.
        Returns:
            str: Custom recommendations from Section 5.
        """
        print(f"πŸ‘₯ SPECIAL POPULATIONS (S5) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('five'):
            return "Section 5 retriever not available"
        return section_tool_wrapper(section_retrievers['five'], section_paths['five'], query)
    def cold_chain_tool(query: str) -> str:
        """
        (Section 6) The definitive tool for all questions about the **cold chain**, including vaccine **storage
        temperatures**, transport protocols, refrigerators, temperature monitoring (like PCV pastilles),
        and procedures for handling cold chain failures or power outages.
        Args:
            query (str): A logistics-related question about vaccine temperature management.
        Returns:
            str: Cold chain instructions from Section 6.
        """
        print(f"❄️ COLD CHAIN (S6) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('six'):
            return "Section 6 retriever not available"
        return section_tool_wrapper(section_retrievers['six'], section_paths['six'], query)
    def injection_safety_tool(query: str) -> str:
        """
        (Section 7) The primary tool for questions related to the **safe administration of injections**.
        Use for topics like sterile equipment, proper injection techniques, preventing needlestick injuries,
        and safe disposal of medical waste (DASRI).
        Args:
            query (str): A question about how to perform vaccine injections safely.
        Returns:
            str: Best practices from Section 7.
        """
        print(f"πŸ›‘οΈ INJECTION SAFETY (S7) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('seven'):
            return "Section 7 retriever not available"
        return section_tool_wrapper(section_retrievers['seven'], section_paths['seven'], query)
    def session_management_tool(query: str) -> str:
        """
        (Section 8) Use this tool for questions about the **operational conduct of a vaccination session**
        and **vaccinovigilance**. This includes preparing the session, material setup, registering vaccination
        acts, and monitoring/reporting adverse events post-vaccination (MPVI).
        Args:
            query (str): A question about running a vaccination session or post-vaccine monitoring.
        Returns:
            str: Workflow and safety monitoring details from Section 8.
        """
        print(f"πŸ“Š SESSION MGMT (S8) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('eight'):
            return "Section 8 retriever not available"
        return section_tool_wrapper(section_retrievers['eight'], section_paths['eight'], query)
    def planning_and_logistics_tool(query: str) -> str:
        """
        (Section 9) This tool is for **planning vaccination sessions and managing logistics**. Use it for
        questions about creating operational maps, estimating vaccine and supply needs, managing stock,
        and reducing vaccine wastage.
        Args:
            query (str): A question about organizing vaccination services or managing stock.
        Returns:
            str: Planning and stock guidance from Section 9.
        """
        print(f"πŸ“… PLANNING & LOGISTICS (S9) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('nine'):
            return "Section 9 retriever not available"
        return section_tool_wrapper(section_retrievers['nine'], section_paths['nine'], query)
    def communication_tool(query: str) -> str:
        """
        (Section 10) The specific tool for **social mobilization and communication**. Use this for
        questions about communication strategies, addressing **vaccine hesitancy**, managing rumors,
        and community outreach to promote vaccination.
        Args:
            query (str): A question about public engagement or communication for vaccination.
        Returns:
            str: Public mobilization strategies from Section 10.
        """
        print(f"πŸ“’ COMMUNICATION (S10) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('ten'):
            return "Section 10 retriever not available"
        return section_tool_wrapper(section_retrievers['ten'], section_paths['ten'], query)
    # Create FunctionTool objects with new, clearer names
    # (each tool's registered name matches its Python function name)
    tools = [
        FunctionTool.from_defaults(name="general_guide_tool", fn=general_guide_tool),
        FunctionTool.from_defaults(name="who_immunization_tool", fn=who_immunization_tool),
        # Section-specific tools
        FunctionTool.from_defaults(name="program_overview_tool", fn=program_overview_tool),
        FunctionTool.from_defaults(name="disease_info_tool", fn=disease_info_tool),
        FunctionTool.from_defaults(name="vaccine_properties_tool", fn=vaccine_properties_tool),
        FunctionTool.from_defaults(name="catch_up_vaccination_tool", fn=catch_up_vaccination_tool),
        FunctionTool.from_defaults(name="special_populations_tool", fn=special_populations_tool),
        FunctionTool.from_defaults(name="cold_chain_tool", fn=cold_chain_tool),
        FunctionTool.from_defaults(name="injection_safety_tool", fn=injection_safety_tool),
        FunctionTool.from_defaults(name="session_management_tool", fn=session_management_tool),
        FunctionTool.from_defaults(name="planning_and_logistics_tool", fn=planning_and_logistics_tool),
        FunctionTool.from_defaults(name="communication_tool", fn=communication_tool),
    ]
    print(f"βœ… Created {len(tools)} tools with improved routing descriptions")
    return tools
def prepare_environment():
    """Entry point: initialise the models, then build every retrieval tool.

    Returns:
        tuple: ``(tools, llm)`` — the tool list produced by
        ``create_section_tools`` and the chat LLM returned by ``setup_models``.
    """
    print("πŸš€ Starting environment preparation...")

    print("πŸ”§ Setting up models...")
    embeddings, chat_llm = setup_models()

    print("πŸ› οΈ Creating section tools...")
    retrieval_tools = create_section_tools(embeddings, chat_llm)

    tool_count = len(retrieval_tools)
    print("βœ… Environment prepared successfully!")
    print(f"πŸ“‹ Created {tool_count} tools")
    return retrieval_tools, chat_llm