Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Environment preparation script for vaccine assistant | |
| Creates vector stores and retrieval tools | |
| """ | |
| import os | |
| import json | |
| import re | |
| import nest_asyncio | |
| from typing import List | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_core.documents import Document | |
| from langchain.embeddings import HuggingFaceEmbeddings | |
| from langchain.retrievers import BM25Retriever, EnsembleRetriever | |
| from langchain.retrievers.multi_query import MultiQueryRetriever | |
| from langchain_google_genai import ChatGoogleGenerativeAI | |
| from llama_index.core.tools import FunctionTool | |
| from llama_index.core.schema import TextNode | |
| from langchain.prompts import PromptTemplate | |
| import logging | |
# Install the default root logging handler, then raise the multi-query
# retriever's logger to INFO (presumably so the query variants it generates
# show up in the logs - confirm against the LangChain version in use).
logging.basicConfig()
logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)
def extract_source_ids(response_text):
    """Extract the unique source IDs cited inline in a generated response.

    Supported citation formats:
    - Standard format: [Source ID]
    - Adjacent brackets: [Source ID1][Source ID2] (whitespace between the
      brackets is tolerated)
    - Several IDs in one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.

    Returns:
        list of str: Unique, non-empty source IDs in order of first
        appearance. An empty list is returned (and a warning is printed)
        when no usable ID is found.
    """
    # Collapse adjacent citations ("[a][b]" or "[a] [b]") into the
    # comma-separated form "[a, b]" so a single pattern covers every format.
    consolidated_text = re.sub(r'\]\s*\[', ', ', response_text)

    inline_citations = re.findall(r'\[([^\[\]]+)\]', consolidated_text)
    if not inline_citations:
        print("Warning: No source IDs found in the response text.")
        return []

    # A dict keeps insertion order, giving de-duplication with a stable,
    # first-seen ordering (a plain set would make the order arbitrary).
    unique_ids = {}
    for citation in inline_citations:
        for raw_id in citation.split(','):
            source_id = raw_id.strip()
            # Skip blanks produced by stray commas or whitespace-only brackets.
            if source_id:
                unique_ids.setdefault(source_id, None)

    source_ids = list(unique_ids)
    if not source_ids:
        print("Warning: No valid source IDs found after filtering.")
        return []
    return source_ids
def setup_models():
    """Build the embedding model and the chat LLM shared by the retrieval tools.

    Returns:
        tuple: ``(embedding_function, llm)`` - the HuggingFace embedding
        wrapper followed by the Gemini chat model.
    """
    print("π§ Setting up embedding model and LLM...")

    # Embeddings come first; the model name is fixed by this module.
    embedder = HuggingFaceEmbeddings(model_name="intfloat/multilingual-e5-base")
    print("β Embedding model initialized: intfloat/multilingual-e5-base")

    # The API key is read from the environment and passed through as-is
    # (it is None when GOOGLE_API_KEY is not set).
    api_key = os.getenv('GOOGLE_API_KEY')
    chat_model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", google_api_key=api_key)
    print("β LLM initialized: gemini-2.0-flash")

    return embedder, chat_model
def create_vectorstore_from_json(
    json_path: str,
    collection_name: str,
    embedding_function,
    language: str = "fra",
    persist_directory: str = "chroma_db_multilingual",
):
    """Create a Chroma vector store from a JSON file of pre-chunked elements.

    Each entry of the JSON array must provide ``text``, ``filename``,
    ``filetype``, ``element_id`` and ``type``; entries whose ``type`` is
    ``"TableElement"`` must also provide ``table_text_as_html``.

    Args:
        json_path: Path to the chunks JSON file.
        collection_name: Name of the Chroma collection to write to.
        embedding_function: Embedding model handed to Chroma.
        language: Language code stored in every document's metadata.
            Defaults to ``"fra"``, the value that used to be hard-coded.
        persist_directory: Directory where Chroma persists the collection.
            Defaults to the previously hard-coded ``"chroma_db_multilingual"``.

    Returns:
        tuple: ``(vectorstore, documents)`` - the Chroma store and the list of
        ``Document`` objects that were indexed (reused for BM25 retrieval).

    Raises:
        OSError: If ``json_path`` cannot be opened.
        KeyError: If a chunk lacks one of the required keys.
    """
    print(f"π Creating vector store from: {json_path}")

    with open(json_path, "r", encoding="utf-8") as f:
        chunks_data = json.load(f)
    print(f"π Loaded {len(chunks_data)} chunks from JSON")

    documents = []
    for element in chunks_data:
        text = element["text"]
        metadata = {
            "language": language,
            "source": element["filename"],
            "filetype": element["filetype"],
            "element_id": element["element_id"],
        }
        # Tables keep their HTML rendering next to the flattened text.
        if element["type"] == "TableElement":
            metadata["table_text_as_html"] = element["table_text_as_html"]
        documents.append(Document(page_content=text, metadata=metadata))

    # NOTE(review): no explicit ids are passed, so re-running against an
    # already persisted collection presumably appends duplicates - confirm.
    vectorstore = Chroma.from_documents(
        documents=documents,
        embedding=embedding_function,
        collection_name=collection_name,
        persist_directory=persist_directory,
    )
    print(f"β Vector store created with collection: {collection_name}")
    return vectorstore, documents
def create_retriever(vectorstore, docs, llm, bm25_k=3,vector_k=6):
    """Build the hybrid retriever used by every tool.

    A similarity retriever over ``vectorstore`` and a BM25 retriever over
    ``docs`` are combined with equal weights in an ensemble, which is then
    wrapped in a multi-query retriever driven by ``llm`` and a
    vaccine-specific rephrasing prompt.

    Args:
        vectorstore: The vector store for similarity search
        docs: Documents for BM25 retriever
        llm: Language model for multi-query generation
        bm25_k: Number of documents to retrieve with BM25
        vector_k: Number of documents to retrieve with vector search
    Returns:
        The multi-query retriever wrapping the ensemble retriever.
    """
    print("π Creating ensemble retriever...")

    # Dense (embedding similarity) side of the ensemble.
    dense_retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": vector_k}
    )
    print(f"β Vector retriever created (k={vector_k})")

    # Sparse (keyword) side of the ensemble.
    sparse_retriever = BM25Retriever.from_documents(docs)
    sparse_retriever.k = bm25_k
    print(f"β BM25 retriever created (k={bm25_k})")

    hybrid_retriever = EnsembleRetriever(
        retrievers=[dense_retriever, sparse_retriever], weights=[0.5, 0.5]
    )
    print("β Ensemble retriever created (weights: 0.5, 0.5)")

    # Prompt that asks the LLM for rephrasings of the user question.
    multiquery_template = """You are an AI assistant specialized in vaccine-related medical information retrieval.
Your task is to generate multiple search queries based on the original question to find relevant information from official vaccine medical documents.
IMPORTANT GUIDELINES:
- Keep all vaccine-specific terminology and medical terms intact
- Maintain the clinical and medical context
- Focus on evidence-based vaccine information
- Preserve any specific vaccine names, diseases, or medical conditions mentioned
- Generate queries that would help retrieve information about vaccine schedules, dosing, contraindications, adverse events, and disease prevention
Original question: {question}
Generate 4 different search queries that rephrase the original question while maintaining vaccine terminology and medical accuracy. Each query should approach the topic from a slightly different angle to maximize retrieval from vaccine medical documents.
Provide only the alternative questions, one per line."""
    multiquery_prompt = PromptTemplate(
        template=multiquery_template, input_variables=["question"]
    )

    multi_query_retriever = MultiQueryRetriever.from_llm(
        retriever=hybrid_retriever, llm=llm, prompt=multiquery_prompt
    )
    print("β Multi-query expanding retriever created")
    return multi_query_retriever
def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
    """Convert retrieved Document objects to LlamaIndex TextNode objects.

    Each node's id is ``"<source>_<element_id>"``, built from the document
    metadata; ``"unknown"`` and ``"doc_<index>"`` are used when the
    respective metadata key is missing.

    Args:
        chromadb_documents: Documents exposing ``page_content`` and a
            ``metadata`` dict. ``None`` is treated as an empty list.

    Returns:
        The converted nodes, in input order. Documents that cannot be
        converted are skipped, and a warning naming the failure is printed
        for each one instead of dropping it silently.
    """
    nodes = []
    for i, doc in enumerate(chromadb_documents or []):
        try:
            metadata = doc.metadata.copy()
            element_id = metadata.get("element_id", f"doc_{i}")
            source = metadata.get("source", "unknown")
            node = TextNode(
                text=doc.page_content,
                metadata=metadata,
                id_=f"{source}_{element_id}",
            )
        except Exception as e:
            # Conversion stays best-effort, but a dropped document is reported
            # so that missing retrieval results can be traced back.
            print(f"Warning: skipping document {i} during node conversion: {e}")
            continue
        nodes.append(node)
    return nodes
def section_tool_wrapper(retriever, section_path_chunks, query):
    """Run a retrieval query and format the matching chunks as cited text.

    The retriever returns documents whose ``element_id`` metadata is used to
    look the full chunks up again in the JSON file at ``section_path_chunks``.
    Every matched chunk is rendered as one or more ``[Source: <id>]`` sections.

    Args:
        retriever: Retriever exposing ``get_relevant_documents(query)``.
        section_path_chunks: Path to the chunks JSON file of the section.
        query (str): The user query to search for.

    Returns:
        str: The formatted sections joined by ``---`` separators, a fixed
        message when nothing relevant is found, or an
        ``"Error retrieving documents: ..."`` message if anything fails
        (errors are reported to the caller as text rather than raised).
    """
    print(f"π TOOL CALL: Searching for query: '{query[:100]}...' in {section_path_chunks}")
    try:
        retrieved_docs = retriever.get_relevant_documents(query)
        print(f"π Retrieved {len(retrieved_docs)} documents")

        nodes_from_retrieved_docs = convert_chromadb_to_llamaindex_nodes(retrieved_docs)
        if not nodes_from_retrieved_docs:
            print("β No relevant documents found for the query")
            return "No relevant documents found for the query."

        chunk_ids = [doc.metadata['element_id'] for doc in retrieved_docs]
        print(f"π Found chunk IDs: {chunk_ids}")

        with open(section_path_chunks, "r", encoding="utf-8") as f:
            chunks_data = json.load(f)

        # Set lookup avoids rescanning the id list for every chunk in the file.
        wanted_ids = set(chunk_ids)
        chunks_unique = [
            chunk for chunk in chunks_data
            if chunk.get('element_id', 'Unknown') in wanted_ids
        ]
        print(f"β Matched {len(chunks_unique)} unique chunks")

        combined_text = []
        for chu in chunks_unique:
            if "TableElement" == chu["type"]:
                # "elements" is iterated as a list for non-table chunks; if a
                # table chunk carries the same list shape, indexing it with a
                # string key would raise and abort the whole tool call. Keep
                # the dict lookup when it applies, otherwise cite the chunk's
                # own element_id.
                elements = chu.get("elements")
                if isinstance(elements, dict) and "element_id" in elements:
                    source_id = elements["element_id"]
                else:
                    source_id = chu.get("element_id", "Unknown")
                text = f"[Source: {source_id}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
                combined_text.append(text)
            else:
                # Non-table chunks are cited per constituent element.
                for element in chu["elements"]:
                    text = f"[Source: {element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
                    combined_text.append(text)

        result = "\n---\n".join(combined_text)
        print(f"β TOOL RESPONSE: Generated response with {len(combined_text)} text sections")
        return result
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        print(f"β TOOL ERROR: {e}")
        return f"Error retrieving documents: {str(e)}"
def create_section_tools(embedding_function, llm):
    """
    Create all section-specific retrieval tools with improved descriptions for accurate routing.

    For each of the ten section chunk files, the full national guide and the
    WHO guide, a vector store and a retriever are built when the JSON file
    exists. A missing file or a build error leaves that retriever as None,
    and the matching tool then answers with a "not available" message.

    Args:
        embedding_function: Embedding model passed to create_vectorstore_from_json.
        llm: Language model passed to create_retriever.
    Returns:
        list: Twelve FunctionTool objects (general guide, WHO guide and one
        per section).
    """
    print("π οΈ Creating section-specific retrieval tools with enhanced descriptions...")
    # Chunk file for each guide section, keyed by the section number spelled out.
    section_paths = {
        'one': './data/section_one_chunks.json',
        'two': './data/section_two_chunks.json',
        'three': './data/section_three_chunks.json',
        'four': './data/section_four_chunks.json',
        'five': './data/section_five_chunks.json',
        'six': './data/section_six_chunks.json',
        'seven': './data/section_seven_chunks.json',
        'eight': './data/section_eight_chunks.json',
        'nine': './data/section_nine_chunks.json',
        'ten': './data/section_ten_chunks.json'
    }
    # Build one retriever per section; None marks a section that is unavailable
    # (file missing or construction failed), so one bad section does not stop
    # the others from loading.
    section_retrievers = {}
    for section, path in section_paths.items():
        try:
            if os.path.exists(path):
                print(f"π Creating retriever for section {section} from {path}")
                vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
                section_retrievers[section] = create_retriever(vstore, docs, llm,)
                print(f"β Successfully created retriever for section {section}")
            else:
                print(f"β οΈ Warning: File not found for section {section}: {path}")
                section_retrievers[section] = None
        except Exception as e:
            print(f"β Error creating retriever for section {section}: {e}")
            section_retrievers[section] = None
    # Retriever over the whole guide, used by the fallback tool; stays None on
    # a missing file or any error.
    guide_path = './data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
    guide_retriever = None
    try:
        if os.path.exists(guide_path):
            print("π Creating main guide retriever...")
            guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
            guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
            print("β Successfully created main guide retriever")
        else:
            print(f"β οΈ Warning: Main guide file not found: {guide_path}")
    except Exception as e:
        print(f"β Error creating main guide retriever: {e}")
    # Retriever over the WHO "Immunization in Practice" guide; same None
    # convention as above.
    immunization_path = './data/Immunization in Practice_WHO_eng_2015.json'
    immunization_retriever = None
    try:
        if os.path.exists(immunization_path):
            print("π Creating immunization retriever...")
            immunization_vstore, immunization_docs = create_vectorstore_from_json(
                immunization_path,
                "Immunization_in_Practice_WHO_eng_2015",
                embedding_function
            )
            immunization_retriever = create_retriever(immunization_vstore, immunization_docs, llm)
            print("β Successfully created immunization retriever")
        else:
            print(f"β οΈ Warning: Immunization file not found: {immunization_path}")
    except Exception as e:
        print(f"β Error creating immunization retriever: {e}")

    # --- Tool Definitions with Improved Descriptions ---
    # Each tool is a closure over the retrievers and paths built above.
    # NOTE(review): FunctionTool.from_defaults below receives no explicit
    # description, so the docstrings of these functions presumably become the
    # tool descriptions used for routing - treat them as runtime text and
    # confirm before rewording.
    def general_guide_tool(query: str) -> str:
        """
        A general-purpose tool for the Algerian National Vaccination Guide.
        **Use this tool as a fallback** if no other specific tool seems appropriate, or for very broad, multi-topic questions
        (e.g., 'Summarize the Algerian vaccination policy and its safety measures').
        **Always prefer a more specific tool if the query matches its description** (e.g., use 'cold_chain_tool' for temperature questions).
        Args:
        query (str): A broad or ambiguous question about the Algerian National Vaccination Guide.
        Returns:
        str: Content retrieved from the entire guide.
        """
        print(f"π₯ GENERAL GUIDE TOOL CALLED (FALLBACK): {query[:50]}...")
        if not guide_retriever:
            return "Guide retriever not available - main guide file may be missing"
        return section_tool_wrapper(guide_retriever, guide_path, query)

    def who_immunization_tool(query: str) -> str:
        """
        Provides information from the WHO's 'Immunization in Practice' guide. Use this for questions about
        **global immunization standards**, international best practices, or for comparing Algerian policy to
        general WHO recommendations on topics like cold chain, safety, and disease control.
        Args:
        query (str): A question seeking global or general immunization practices.
        Returns:
        str: Content from the WHO Immunization in Practice guide.
        """
        print(f"π WHO TOOL CALLED: {query[:50]}...")
        if not immunization_retriever:
            return "Immunization in Practice retriever not available - WHO guide file may be missing"
        return section_tool_wrapper(immunization_retriever, immunization_path, query)

    def program_overview_tool(query: str) -> str:
        """
        (Section 1) The primary tool for questions about the **history, objectives, and structure** of Algeria's
        national immunization program (PEV - Programme Γlargi de Vaccination). Use this for topics like
        the program's rationale, key achievements, and the reasons for updates to the vaccination calendar.
        Args:
        query (str): A question about the foundation or evolution of the PEV.
        Returns:
        str: Response from Section 1.
        """
        print(f"π PROGRAM OVERVIEW (S1) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('one'):
            return "Section 1 retriever not available"
        return section_tool_wrapper(section_retrievers['one'], section_paths['one'], query)

    def disease_info_tool(query: str) -> str:
        """
        (Section 2) The definitive tool for information on **specific vaccine-preventable diseases**.
        Use this to find details on **symptoms, transmission methods, complications**, and prevention
        strategies for diseases like Diphtheria, Measles, Polio, Tetanus, etc.
        Args:
        query (str): A question about a disease covered by the national vaccination program.
        Returns:
        str: Disease-specific content from Section 2.
        """
        print(f"π¦ DISEASE INFO (S2) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('two'):
            return "Section 2 retriever not available"
        return section_tool_wrapper(section_retrievers['two'], section_paths['two'], query)

    def vaccine_properties_tool(query: str) -> str:
        """
        (Section 3) The specific tool for questions about the **vaccines themselves**: their types (e.g., BCG, ROR,
        DTCaVPI), composition, whether they are live or inactivated, and the correct **method of administration**
        (e.g., intradermal, intramuscular, oral).
        Args:
        query (str): A question about a vaccine's formulation or how it is administered.
        Returns:
        str: Vaccine-specific info from Section 3.
        """
        print(f"π VACCINE PROPERTIES (S3) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('three'):
            return "Section 3 retriever not available"
        return section_tool_wrapper(section_retrievers['three'], section_paths['three'], query)

    def catch_up_vaccination_tool(query: str) -> str:
        """
        (Section 4) Specialized tool for **missed or delayed vaccinations (rattrapage vaccinal)**.
        Use this for questions about creating a **catch-up schedule** for a child who is behind
        on their shots, based on their age and vaccination history.
        Args:
        query (str): A question about catch-up vaccination due to a delay or missed dose.
        Returns:
        str: Catch-up schedule guidance from Section 4.
        """
        print(f"π CATCH-UP (S4) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('four'):
            return "Section 4 retriever not available"
        return section_tool_wrapper(section_retrievers['four'], section_paths['four'], query)

    def special_populations_tool(query: str) -> str:
        """
        (Section 5) The designated tool for vaccination guidelines concerning **special populations**.
        Use for questions about vaccinating preterm infants, allergic children, or patients with
        immunosuppression, chronic illnesses (cardiac, pulmonary), or other specific health conditions.
        Args:
        query (str): A question about tailored vaccination for a vulnerable or special group.
        Returns:
        str: Custom recommendations from Section 5.
        """
        print(f"π₯ SPECIAL POPULATIONS (S5) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('five'):
            return "Section 5 retriever not available"
        return section_tool_wrapper(section_retrievers['five'], section_paths['five'], query)

    def cold_chain_tool(query: str) -> str:
        """
        (Section 6) The definitive tool for all questions about the **cold chain**, including vaccine **storage
        temperatures**, transport protocols, refrigerators, temperature monitoring (like PCV pastilles),
        and procedures for handling cold chain failures or power outages.
        Args:
        query (str): A logistics-related question about vaccine temperature management.
        Returns:
        str: Cold chain instructions from Section 6.
        """
        print(f"βοΈ COLD CHAIN (S6) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('six'):
            return "Section 6 retriever not available"
        return section_tool_wrapper(section_retrievers['six'], section_paths['six'], query)

    def injection_safety_tool(query: str) -> str:
        """
        (Section 7) The primary tool for questions related to the **safe administration of injections**.
        Use for topics like sterile equipment, proper injection techniques, preventing needlestick injuries,
        and safe disposal of medical waste (DASRI).
        Args:
        query (str): A question about how to perform vaccine injections safely.
        Returns:
        str: Best practices from Section 7.
        """
        print(f"π‘οΈ INJECTION SAFETY (S7) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('seven'):
            return "Section 7 retriever not available"
        return section_tool_wrapper(section_retrievers['seven'], section_paths['seven'], query)

    def session_management_tool(query: str) -> str:
        """
        (Section 8) Use this tool for questions about the **operational conduct of a vaccination session**
        and **vaccinovigilance**. This includes preparing the session, material setup, registering vaccination
        acts, and monitoring/reporting adverse events post-vaccination (MPVI).
        Args:
        query (str): A question about running a vaccination session or post-vaccine monitoring.
        Returns:
        str: Workflow and safety monitoring details from Section 8.
        """
        print(f"π SESSION MGMT (S8) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('eight'):
            return "Section 8 retriever not available"
        return section_tool_wrapper(section_retrievers['eight'], section_paths['eight'], query)

    def planning_and_logistics_tool(query: str) -> str:
        """
        (Section 9) This tool is for **planning vaccination sessions and managing logistics**. Use it for
        questions about creating operational maps, estimating vaccine and supply needs, managing stock,
        and reducing vaccine wastage.
        Args:
        query (str): A question about organizing vaccination services or managing stock.
        Returns:
        str: Planning and stock guidance from Section 9.
        """
        print(f"π PLANNING & LOGISTICS (S9) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('nine'):
            return "Section 9 retriever not available"
        return section_tool_wrapper(section_retrievers['nine'], section_paths['nine'], query)

    def communication_tool(query: str) -> str:
        """
        (Section 10) The specific tool for **social mobilization and communication**. Use this for
        questions about communication strategies, addressing **vaccine hesitancy**, managing rumors,
        and community outreach to promote vaccination.
        Args:
        query (str): A question about public engagement or communication for vaccination.
        Returns:
        str: Public mobilization strategies from Section 10.
        """
        print(f"π’ COMMUNICATION (S10) TOOL CALLED: {query[:50]}...")
        if not section_retrievers.get('ten'):
            return "Section 10 retriever not available"
        return section_tool_wrapper(section_retrievers['ten'], section_paths['ten'], query)

    # Wrap every function as a FunctionTool; the tool name mirrors the
    # function name. The two guide-wide tools come first.
    tools = [
        FunctionTool.from_defaults(name="general_guide_tool", fn=general_guide_tool),
        FunctionTool.from_defaults(name="who_immunization_tool", fn=who_immunization_tool),
        # Section-specific tools, in section order (1 to 10)
        FunctionTool.from_defaults(name="program_overview_tool", fn=program_overview_tool),
        FunctionTool.from_defaults(name="disease_info_tool", fn=disease_info_tool),
        FunctionTool.from_defaults(name="vaccine_properties_tool", fn=vaccine_properties_tool),
        FunctionTool.from_defaults(name="catch_up_vaccination_tool", fn=catch_up_vaccination_tool),
        FunctionTool.from_defaults(name="special_populations_tool", fn=special_populations_tool),
        FunctionTool.from_defaults(name="cold_chain_tool", fn=cold_chain_tool),
        FunctionTool.from_defaults(name="injection_safety_tool", fn=injection_safety_tool),
        FunctionTool.from_defaults(name="session_management_tool", fn=session_management_tool),
        FunctionTool.from_defaults(name="planning_and_logistics_tool", fn=planning_and_logistics_tool),
        FunctionTool.from_defaults(name="communication_tool", fn=communication_tool),
    ]
    print(f"β Created {len(tools)} tools with improved routing descriptions")
    return tools
def prepare_environment():
    """Entry point: initialize the models, build the tools and return both.

    Returns:
        tuple: ``(tools, llm)`` - the list of retrieval tools and the chat
        model they were built with.
    """
    print("π Starting environment preparation...")

    # Step 1: embedding model + LLM.
    print("π§ Setting up models...")
    models = setup_models()
    embedder, chat_model = models

    # Step 2: one retrieval tool per document / section.
    print("π οΈ Creating section tools...")
    section_tools = create_section_tools(embedder, chat_model)
    tool_count = len(section_tools)

    print("β Environment prepared successfully!")
    print(f"π Created {tool_count} tools")
    return section_tools, chat_model