Spaces:
Sleeping
Sleeping
Zeggai Abdellah
commited on
Commit
·
c181ce0
1
Parent(s):
ee6a617
update the code with new simpler version
Browse files- main.py +95 -98
- models.py +0 -14
- prepare_env.py +254 -0
- rag_pipeline.py +189 -0
- rag_system.py +0 -345
main.py
CHANGED
|
@@ -1,138 +1,135 @@
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
-
|
| 4 |
-
|
| 5 |
"""
|
| 6 |
|
| 7 |
-
import
|
| 8 |
-
from fastapi import FastAPI, HTTPException, Query
|
| 9 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
from
|
| 13 |
-
from
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
|
| 17 |
-
rag_system = AgenticRAGSystem(config)
|
| 18 |
|
| 19 |
-
#
|
| 20 |
-
|
|
|
|
| 21 |
|
| 22 |
-
# FastAPI app
|
| 23 |
app = FastAPI(
|
| 24 |
-
title="
|
| 25 |
-
description="
|
| 26 |
version="1.0.0"
|
| 27 |
)
|
| 28 |
|
| 29 |
# Add CORS middleware
|
| 30 |
app.add_middleware(
|
| 31 |
CORSMiddleware,
|
| 32 |
-
allow_origins=["*"], # Configure
|
| 33 |
allow_credentials=True,
|
| 34 |
allow_methods=["*"],
|
| 35 |
allow_headers=["*"],
|
| 36 |
)
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
@app.on_event("startup")
|
| 39 |
async def startup_event():
|
| 40 |
-
"""Initialize the RAG
|
| 41 |
-
global
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
@app.get("/"
|
| 58 |
async def root():
|
| 59 |
-
"""Root endpoint"""
|
| 60 |
-
return HealthResponse(
|
| 61 |
-
status="online",
|
| 62 |
-
message="Agentic RAG Vaccination Assistant API is running"
|
| 63 |
-
)
|
| 64 |
-
|
| 65 |
-
@app.get("/health", response_model=HealthResponse)
|
| 66 |
-
async def health_check():
|
| 67 |
"""Health check endpoint"""
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
else:
|
| 74 |
-
return HealthResponse(
|
| 75 |
-
status="initializing",
|
| 76 |
-
message="System is still initializing. Please wait."
|
| 77 |
-
)
|
| 78 |
|
| 79 |
-
@app.get("/ask"
|
| 80 |
-
async def ask_question(
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
|
| 91 |
if not question.strip():
|
| 92 |
-
raise HTTPException(
|
| 93 |
-
status_code=400,
|
| 94 |
-
detail="Question cannot be empty"
|
| 95 |
-
)
|
| 96 |
|
| 97 |
try:
|
| 98 |
-
|
| 99 |
-
import concurrent.futures
|
| 100 |
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
answer = await loop.run_in_executor(
|
| 104 |
-
executor,
|
| 105 |
-
rag_system.ask_question,
|
| 106 |
-
question,
|
| 107 |
-
with_citations
|
| 108 |
-
)
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
|
|
|
| 115 |
|
| 116 |
except Exception as e:
|
| 117 |
-
|
| 118 |
-
raise HTTPException(
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
-
@app.post("/ask", response_model=QuestionResponse)
|
| 124 |
-
async def ask_question_post(request: QuestionRequest):
|
| 125 |
-
"""Ask a question to the vaccination assistant (POST version)"""
|
| 126 |
-
return await ask_question(request.question, request.with_citations)
|
| 127 |
|
| 128 |
if __name__ == "__main__":
|
| 129 |
import uvicorn
|
| 130 |
-
|
| 131 |
-
print("Starting Agentic RAG API server...")
|
| 132 |
-
uvicorn.run(
|
| 133 |
-
"main:app",
|
| 134 |
-
host="0.0.0.0",
|
| 135 |
-
port=8000,
|
| 136 |
-
reload=True,
|
| 137 |
-
log_level="info"
|
| 138 |
-
)
|
|
|
|
| 1 |
# -*- coding: utf-8 -*-
|
| 2 |
"""
|
| 3 |
+
FastAPI server for vaccine assistant
|
| 4 |
+
Main entry point for the application
|
| 5 |
"""
|
| 6 |
|
| 7 |
+
from fastapi import FastAPI, Query, HTTPException
|
|
|
|
| 8 |
from fastapi.middleware.cors import CORSMiddleware
|
| 9 |
+
import os
|
| 10 |
+
from dotenv import load_dotenv
|
| 11 |
+
import logging
|
| 12 |
|
| 13 |
+
# Import our modules
|
| 14 |
+
from prepare_env import prepare_environment
|
| 15 |
+
from rag_pipeline import initialize_rag_pipeline, process_question
|
| 16 |
|
| 17 |
+
# Load environment variables
|
| 18 |
+
load_dotenv()
|
|
|
|
| 19 |
|
| 20 |
+
# Setup logging
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
+
# Initialize FastAPI app
|
| 25 |
app = FastAPI(
|
| 26 |
+
title="Vaccine Assistant API",
|
| 27 |
+
description="AI-powered vaccine assistant for medical professionals",
|
| 28 |
version="1.0.0"
|
| 29 |
)
|
| 30 |
|
| 31 |
# Add CORS middleware
|
| 32 |
app.add_middleware(
|
| 33 |
CORSMiddleware,
|
| 34 |
+
allow_origins=["*"], # Configure appropriately for production
|
| 35 |
allow_credentials=True,
|
| 36 |
allow_methods=["*"],
|
| 37 |
allow_headers=["*"],
|
| 38 |
)
|
| 39 |
|
| 40 |
+
# Global variables for the agent (initialized on startup)
|
| 41 |
+
agent = None
|
| 42 |
+
|
| 43 |
@app.on_event("startup")
|
| 44 |
async def startup_event():
|
| 45 |
+
"""Initialize the RAG pipeline on startup"""
|
| 46 |
+
global agent
|
| 47 |
+
try:
|
| 48 |
+
logger.info("Starting up vaccine assistant...")
|
| 49 |
+
|
| 50 |
+
# Check for required environment variables
|
| 51 |
+
if not os.getenv('GOOGLE_API_KEY'):
|
| 52 |
+
logger.warning("GOOGLE_API_KEY not found in environment variables")
|
| 53 |
+
|
| 54 |
+
# Prepare environment and tools
|
| 55 |
+
logger.info("Preparing environment and tools...")
|
| 56 |
+
tools, llm = prepare_environment()
|
| 57 |
|
| 58 |
+
# Initialize RAG pipeline
|
| 59 |
+
logger.info("Initializing RAG pipeline...")
|
| 60 |
+
agent = initialize_rag_pipeline(tools)
|
| 61 |
+
|
| 62 |
+
logger.info("✅ Vaccine assistant startup completed successfully!")
|
| 63 |
+
|
| 64 |
+
except Exception as e:
|
| 65 |
+
logger.error(f"❌ Error during startup: {e}")
|
| 66 |
+
raise e
|
| 67 |
|
| 68 |
+
@app.get("/")
|
| 69 |
async def root():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
"""Health check endpoint"""
|
| 71 |
+
return {
|
| 72 |
+
"message": "Vaccine Assistant API is running",
|
| 73 |
+
"status": "healthy",
|
| 74 |
+
"version": "1.0.0"
|
| 75 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
+
@app.get("/ask")
|
| 78 |
+
async def ask_question(question: str = Query(..., description="The medical question to ask")):
|
| 79 |
+
"""
|
| 80 |
+
Main endpoint for asking questions to the vaccine assistant
|
| 81 |
+
|
| 82 |
+
Args:
|
| 83 |
+
question: The medical question related to vaccines
|
| 84 |
+
|
| 85 |
+
Returns:
|
| 86 |
+
JSON response with question and answer
|
| 87 |
+
"""
|
| 88 |
+
global agent
|
| 89 |
+
|
| 90 |
+
if agent is None:
|
| 91 |
+
raise HTTPException(status_code=503, detail="Agent not initialized. Please try again later.")
|
| 92 |
|
| 93 |
if not question.strip():
|
| 94 |
+
raise HTTPException(status_code=400, detail="Question cannot be empty")
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
try:
|
| 97 |
+
logger.info(f"Processing question: {question[:100]}...")
|
|
|
|
| 98 |
|
| 99 |
+
# Process the question through RAG pipeline
|
| 100 |
+
answer = process_question(agent, question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
logger.info("Question processed successfully")
|
| 103 |
+
|
| 104 |
+
return {
|
| 105 |
+
"question": question,
|
| 106 |
+
"answer": answer,
|
| 107 |
+
"status": "success"
|
| 108 |
+
}
|
| 109 |
|
| 110 |
except Exception as e:
|
| 111 |
+
logger.error(f"Error processing question: {e}")
|
| 112 |
+
raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")
|
| 113 |
+
|
| 114 |
+
@app.get("/health")
|
| 115 |
+
async def health_check():
|
| 116 |
+
"""Detailed health check endpoint"""
|
| 117 |
+
global agent
|
| 118 |
+
|
| 119 |
+
health_status = {
|
| 120 |
+
"status": "healthy" if agent is not None else "unhealthy",
|
| 121 |
+
"agent_initialized": agent is not None,
|
| 122 |
+
"google_api_key_configured": bool(os.getenv('GOOGLE_API_KEY')),
|
| 123 |
+
"version": "1.0.0"
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
if agent is None:
|
| 127 |
+
health_status["status"] = "unhealthy"
|
| 128 |
+
health_status["message"] = "Agent not initialized"
|
| 129 |
+
|
| 130 |
+
return health_status
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
|
| 133 |
if __name__ == "__main__":
|
| 134 |
import uvicorn
|
| 135 |
+
uvicorn.run(app, host="0.0.0.0", port=7860)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models.py
DELETED
|
@@ -1,14 +0,0 @@
|
|
| 1 |
-
from pydantic import BaseModel
|
| 2 |
-
|
| 3 |
-
class QuestionRequest(BaseModel):
|
| 4 |
-
question: str
|
| 5 |
-
with_citations: bool = False
|
| 6 |
-
|
| 7 |
-
class QuestionResponse(BaseModel):
|
| 8 |
-
question: str
|
| 9 |
-
answer: str
|
| 10 |
-
status: str = "success"
|
| 11 |
-
|
| 12 |
-
class HealthResponse(BaseModel):
|
| 13 |
-
status: str
|
| 14 |
-
message: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
prepare_env.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
Environment preparation script for vaccine assistant
|
| 4 |
+
Creates vector stores and retrieval tools
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
import json
|
| 9 |
+
import nest_asyncio
|
| 10 |
+
from typing import List
|
| 11 |
+
from langchain_community.vectorstores import Chroma
|
| 12 |
+
from langchain_core.documents import Document
|
| 13 |
+
from langchain.embeddings import HuggingFaceEmbeddings
|
| 14 |
+
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
| 15 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 16 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 17 |
+
from llama_index.core.tools import FunctionTool
|
| 18 |
+
from llama_index.core.schema import TextNode
|
| 19 |
+
|
| 20 |
+
# Apply nest_asyncio for compatibility
|
| 21 |
+
nest_asyncio.apply()
|
| 22 |
+
|
| 23 |
+
def setup_models():
|
| 24 |
+
"""Initialize embedding model and LLM"""
|
| 25 |
+
# Initialize embedding model
|
| 26 |
+
embedding_function = HuggingFaceEmbeddings(
|
| 27 |
+
model_name="intfloat/multilingual-e5-base"
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# Initialize LLM
|
| 31 |
+
genai_api_key = os.getenv('GOOGLE_API_KEY', 'AIzaSyBho3W4W9fR7wHUJbX18JKH-12wDSD7pWg')
|
| 32 |
+
llm = ChatGoogleGenerativeAI(
|
| 33 |
+
model="gemini-2.0-flash",
|
| 34 |
+
google_api_key=genai_api_key
|
| 35 |
+
)
|
| 36 |
+
|
| 37 |
+
return embedding_function, llm
|
| 38 |
+
|
| 39 |
+
def create_vectorstore_from_json(json_path: str, collection_name: str, embedding_function):
|
| 40 |
+
"""Create vector store from JSON chunks"""
|
| 41 |
+
# Load the chunks.json
|
| 42 |
+
with open(json_path, "r", encoding="utf-8") as f:
|
| 43 |
+
chunks_data = json.load(f)
|
| 44 |
+
|
| 45 |
+
documents = []
|
| 46 |
+
for element in chunks_data:
|
| 47 |
+
text = element["text"]
|
| 48 |
+
metadata = {
|
| 49 |
+
"language": "fra",
|
| 50 |
+
"source": element["filename"],
|
| 51 |
+
"filetype": element["filetype"],
|
| 52 |
+
"element_id": element["element_id"]
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
if "TableElement" == element["type"]:
|
| 56 |
+
metadata["table_text_as_html"] = element["table_text_as_html"]
|
| 57 |
+
|
| 58 |
+
doc = Document(page_content=text, metadata=metadata)
|
| 59 |
+
documents.append(doc)
|
| 60 |
+
|
| 61 |
+
# Create vector store
|
| 62 |
+
vectorstore = Chroma.from_documents(
|
| 63 |
+
documents=documents,
|
| 64 |
+
embedding=embedding_function,
|
| 65 |
+
collection_name=collection_name,
|
| 66 |
+
persist_directory="chroma_db_multilingual"
|
| 67 |
+
)
|
| 68 |
+
return vectorstore, documents
|
| 69 |
+
|
| 70 |
+
def create_retriever(vectorstore, docs, llm):
|
| 71 |
+
"""Create ensemble retriever with vector and BM25 search"""
|
| 72 |
+
# Vector retriever
|
| 73 |
+
vector_retriever = vectorstore.as_retriever(
|
| 74 |
+
search_type="similarity",
|
| 75 |
+
search_kwargs={"k": 6}
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
# BM25 retriever
|
| 79 |
+
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 80 |
+
bm25_retriever.k = 2
|
| 81 |
+
|
| 82 |
+
# Ensemble retriever
|
| 83 |
+
ensemble_retriever = EnsembleRetriever(
|
| 84 |
+
retrievers=[vector_retriever, bm25_retriever],
|
| 85 |
+
weights=[0.5, 0.5]
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
# Multi-query expanding retriever
|
| 89 |
+
expanding_retriever = MultiQueryRetriever.from_llm(
|
| 90 |
+
retriever=ensemble_retriever,
|
| 91 |
+
llm=llm
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
return expanding_retriever
|
| 95 |
+
|
| 96 |
+
def convert_chromadb_to_llamaindex_nodes(chromadb_documents: List) -> List[TextNode]:
|
| 97 |
+
"""Convert ChromaDB Document objects to LlamaIndex TextNode objects"""
|
| 98 |
+
nodes = []
|
| 99 |
+
for i, doc in enumerate(chromadb_documents):
|
| 100 |
+
try:
|
| 101 |
+
text = doc.page_content
|
| 102 |
+
metadata = doc.metadata.copy()
|
| 103 |
+
element_id = metadata.get("element_id", f"doc_{i}")
|
| 104 |
+
source = metadata.get("source", "unknown")
|
| 105 |
+
node_id = f"{source}_{element_id}"
|
| 106 |
+
|
| 107 |
+
node = TextNode(
|
| 108 |
+
text=text,
|
| 109 |
+
metadata=metadata,
|
| 110 |
+
id_=node_id
|
| 111 |
+
)
|
| 112 |
+
nodes.append(node)
|
| 113 |
+
except Exception as e:
|
| 114 |
+
continue
|
| 115 |
+
return nodes
|
| 116 |
+
|
| 117 |
+
def section_tool_wrapper(retriever, section_path_chunks, query):
|
| 118 |
+
"""Generic section tool wrapper"""
|
| 119 |
+
try:
|
| 120 |
+
retrieved_docs = retriever.get_relevant_documents(query)
|
| 121 |
+
nodes_from_retrieved_docs = convert_chromadb_to_llamaindex_nodes(retrieved_docs)
|
| 122 |
+
|
| 123 |
+
if not nodes_from_retrieved_docs:
|
| 124 |
+
return "No relevant documents found for the query."
|
| 125 |
+
|
| 126 |
+
chunk_ids = [node.metadata['element_id'] for node in retrieved_docs]
|
| 127 |
+
with open(section_path_chunks, "r", encoding="utf-8") as f:
|
| 128 |
+
chunks_data = json.load(f)
|
| 129 |
+
|
| 130 |
+
chunks_unique = [node for node in chunks_data if node.get('element_id', 'Unknown') in chunk_ids]
|
| 131 |
+
combined_text = []
|
| 132 |
+
|
| 133 |
+
for chu in chunks_unique:
|
| 134 |
+
if "TableElement" == chu["type"]:
|
| 135 |
+
text = f"[Source: {chu['elements']['element_id']}]\n CONTENT: \n{chu['text']}\n HTML: \n {chu['table_text_as_html']} \n\n"
|
| 136 |
+
combined_text.append(text)
|
| 137 |
+
else:
|
| 138 |
+
for element in chu["elements"]:
|
| 139 |
+
text = f"[Source: {element['element_id']}]\n CONTENT: \n{element['text']} \n\n"
|
| 140 |
+
combined_text.append(text)
|
| 141 |
+
|
| 142 |
+
result = "\n---\n".join(combined_text)
|
| 143 |
+
print(f"Retrieved {len(nodes_from_retrieved_docs)} documents for query: {query[:50]}...")
|
| 144 |
+
return result
|
| 145 |
+
except Exception as e:
|
| 146 |
+
print(f"Error in section tool: {e}")
|
| 147 |
+
return f"Error retrieving documents: {str(e)}"
|
| 148 |
+
|
| 149 |
+
def create_section_tools(embedding_function, llm):
|
| 150 |
+
"""Create all section-specific retrieval tools"""
|
| 151 |
+
|
| 152 |
+
# Define section paths
|
| 153 |
+
section_paths = {
|
| 154 |
+
# 'one': 'section_one_chunks.json',
|
| 155 |
+
# 'two': 'section_two_chunks.json',
|
| 156 |
+
# 'three': 'section_three_chunks.json',
|
| 157 |
+
# 'four': 'section_four_chunks.json',
|
| 158 |
+
# 'five': 'section_five_chunks.json',
|
| 159 |
+
# 'six': 'section_six_chunks.json',
|
| 160 |
+
# 'seven': 'section_seven_chunks.json',
|
| 161 |
+
# 'eight': 'section_eight_chunks.json',
|
| 162 |
+
# 'nine': 'section_nine_chunks.json',
|
| 163 |
+
'ten': './data/section_ten_chunks.json'
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
# Create retrievers for each section
|
| 167 |
+
section_retrievers = {}
|
| 168 |
+
for section, path in section_paths.items():
|
| 169 |
+
if os.path.exists(path):
|
| 170 |
+
vstore, docs = create_vectorstore_from_json(path, f"Guide_2023_{section}", embedding_function)
|
| 171 |
+
section_retrievers[section] = create_retriever(vstore, docs, llm)
|
| 172 |
+
|
| 173 |
+
# # Create main guide retriever
|
| 174 |
+
# guide_path = 'Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json'
|
| 175 |
+
# if os.path.exists(guide_path):
|
| 176 |
+
# guide_vstore, guide_docs = create_vectorstore_from_json(guide_path, "Guide_2023_multilingual", embedding_function)
|
| 177 |
+
# guide_retriever = create_retriever(guide_vstore, guide_docs, llm)
|
| 178 |
+
# else:
|
| 179 |
+
# guide_retriever = None
|
| 180 |
+
|
| 181 |
+
# # Define tool functions
|
| 182 |
+
# def guide_retrieval_tool(query: str) -> str:
|
| 183 |
+
# """General-purpose retrieval tool for the entire Algerian National Vaccination Guide"""
|
| 184 |
+
# if not guide_retriever:
|
| 185 |
+
# return "Guide retriever not available"
|
| 186 |
+
# return section_tool_wrapper(guide_retriever, guide_path, query)
|
| 187 |
+
|
| 188 |
+
# def section_one_tool(query: str) -> str:
|
| 189 |
+
# """Section 1: Programme Élargi de Vaccination"""
|
| 190 |
+
# return section_tool_wrapper(section_retrievers['one'], section_paths['one'], query)
|
| 191 |
+
|
| 192 |
+
# def section_two_tool(query: str) -> str:
|
| 193 |
+
# """Section 2: Maladies Ciblées"""
|
| 194 |
+
# return section_tool_wrapper(section_retrievers['two'], section_paths['two'], query)
|
| 195 |
+
|
| 196 |
+
# def section_three_tool(query: str) -> str:
|
| 197 |
+
# """Section 3: Vaccins du Calendrier"""
|
| 198 |
+
# return section_tool_wrapper(section_retrievers['three'], section_paths['three'], query)
|
| 199 |
+
|
| 200 |
+
# def section_four_tool(query: str) -> str:
|
| 201 |
+
# """Section 4: Rattrapage Vaccinal"""
|
| 202 |
+
# return section_tool_wrapper(section_retrievers['four'], section_paths['four'], query)
|
| 203 |
+
|
| 204 |
+
# def section_five_tool(query: str) -> str:
|
| 205 |
+
# """Section 5: Populations Particulières"""
|
| 206 |
+
# return section_tool_wrapper(section_retrievers['five'], section_paths['five'], query)
|
| 207 |
+
|
| 208 |
+
# def section_six_tool(query: str) -> str:
|
| 209 |
+
# """Section 6: Chaîne du Froid"""
|
| 210 |
+
# return section_tool_wrapper(section_retrievers['six'], section_paths['six'], query)
|
| 211 |
+
|
| 212 |
+
# def section_seven_tool(query: str) -> str:
|
| 213 |
+
# """Section 7: Sécurité des Injections"""
|
| 214 |
+
# return section_tool_wrapper(section_retrievers['seven'], section_paths['seven'], query)
|
| 215 |
+
|
| 216 |
+
# def section_eight_tool(query: str) -> str:
|
| 217 |
+
# """Section 8: Séance de Vaccination & Vaccinovigilance"""
|
| 218 |
+
# return section_tool_wrapper(section_retrievers['eight'], section_paths['eight'], query)
|
| 219 |
+
|
| 220 |
+
# def section_nine_tool(query: str) -> str:
|
| 221 |
+
# """Section 9: Planification des Séances de Vaccination"""
|
| 222 |
+
# return section_tool_wrapper(section_retrievers['nine'], section_paths['nine'], query)
|
| 223 |
+
|
| 224 |
+
def section_ten_tool(query: str) -> str:
|
| 225 |
+
"""Section 10: Mobilisation Sociale"""
|
| 226 |
+
return section_tool_wrapper(section_retrievers['ten'], section_paths['ten'], query)
|
| 227 |
+
|
| 228 |
+
# Create FunctionTool objects
|
| 229 |
+
tools = [
|
| 230 |
+
# FunctionTool.from_defaults(name="Guide_vector_tool", fn=guide_retrieval_tool),
|
| 231 |
+
# FunctionTool.from_defaults(name="section_one_vector_query_tool", fn=section_one_tool),
|
| 232 |
+
# FunctionTool.from_defaults(name="section_two_vector_query_tool", fn=section_two_tool),
|
| 233 |
+
# FunctionTool.from_defaults(name="section_three_vector_query_tool", fn=section_three_tool),
|
| 234 |
+
# FunctionTool.from_defaults(name="section_four_vector_query_tool", fn=section_four_tool),
|
| 235 |
+
# FunctionTool.from_defaults(name="section_five_vector_query_tool", fn=section_five_tool),
|
| 236 |
+
# FunctionTool.from_defaults(name="section_six_vector_query_tool", fn=section_six_tool),
|
| 237 |
+
# FunctionTool.from_defaults(name="section_seven_vector_query_tool", fn=section_seven_tool),
|
| 238 |
+
# FunctionTool.from_defaults(name="section_eight_vector_query_tool", fn=section_eight_tool),
|
| 239 |
+
# FunctionTool.from_defaults(name="section_nine_vector_query_tool", fn=section_nine_tool),
|
| 240 |
+
FunctionTool.from_defaults(name="section_ten_vector_query_tool", fn=section_ten_tool),
|
| 241 |
+
]
|
| 242 |
+
|
| 243 |
+
return tools
|
| 244 |
+
|
| 245 |
+
def prepare_environment():
|
| 246 |
+
"""Main function to prepare the environment and return tools"""
|
| 247 |
+
print("Setting up models...")
|
| 248 |
+
embedding_function, llm = setup_models()
|
| 249 |
+
|
| 250 |
+
print("Creating section tools...")
|
| 251 |
+
tools = create_section_tools(embedding_function, llm)
|
| 252 |
+
|
| 253 |
+
print("Environment prepared successfully!")
|
| 254 |
+
return tools, llm
|
rag_pipeline.py
ADDED
|
@@ -0,0 +1,189 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
RAG Pipeline for vaccine assistant
|
| 4 |
+
Handles agent creation and question answering
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from llama_index.core import PromptTemplate
|
| 8 |
+
from llama_index.core.agent import ReActAgent
|
| 9 |
+
from llama_index.llms.google_genai import GoogleGenAI
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
def create_custom_prompt():
|
| 13 |
+
"""Create custom prompt with medical assistant instructions"""
|
| 14 |
+
|
| 15 |
+
custom_instructions = """
|
| 16 |
+
## MEDICAL ASSISTANT ROLE
|
| 17 |
+
You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
|
| 18 |
+
You provide evidence-based guidance using only information from official vaccine medical documents.
|
| 19 |
+
Answer the doctor's question accurately and concisely using only the provided information.
|
| 20 |
+
|
| 21 |
+
## IMPORTANT REQUIREMENTS
|
| 22 |
+
|
| 23 |
+
### Citation and Sourcing
|
| 24 |
+
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
|
| 25 |
+
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 26 |
+
3. If a fact is supported by multiple sources, use the following format:
|
| 27 |
+
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
|
| 28 |
+
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 29 |
+
|
| 30 |
+
### Content Formatting
|
| 31 |
+
1. When rendering tables:
|
| 32 |
+
- Convert HTML tables into clean Markdown format
|
| 33 |
+
- Preserve all original headers and data rows exactly
|
| 34 |
+
- Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
|
| 35 |
+
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 36 |
+
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 37 |
+
|
| 38 |
+
## Tools
|
| 39 |
+
|
| 40 |
+
You have access to a wide variety of tools. You are responsible for using the tools in any sequence you deem appropriate to complete the task at hand.
|
| 41 |
+
This may require breaking the task into subtasks and using different tools to complete each subtask.
|
| 42 |
+
|
| 43 |
+
You have access to the following tools:
|
| 44 |
+
{tool_desc}
|
| 45 |
+
|
| 46 |
+
## Output Format
|
| 47 |
+
|
| 48 |
+
Please answer in the same language as the question and use the following format:
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
Thought: The current language of the user is: (user's language). I need to use a tool to help me answer the question.
|
| 52 |
+
Action: tool name (one of {tool_names}) if using a tool.
|
| 53 |
+
Action Input: the input to the tool, in a JSON format representing the kwargs (e.g. {{"input": "hello world", "num_beams": 5}})
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Please ALWAYS start with a Thought.
|
| 57 |
+
|
| 58 |
+
NEVER surround your response with markdown code markers. You may use code markers within your response if you need to.
|
| 59 |
+
|
| 60 |
+
Please use a valid JSON format for the Action Input. Do NOT do this {{"input": "hello world", "num_beams": 5}}.
|
| 61 |
+
|
| 62 |
+
If this format is used, the tool will respond in the following format:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
Observation: tool response
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
You should keep repeating the above format till you have enough information to answer the question without using any more tools. At that point, you MUST respond in one of the following two formats:
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
Thought: I can answer without using any more tools. I'll use the user's language to answer. Remember to include proper citations
|
| 72 |
+
Answer: [your answer here with proper citations (In the same language as the user's question)]
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
```
|
| 76 |
+
Thought: I cannot answer the question with the provided tools.
|
| 77 |
+
Answer: [your answer here (In the same language as the user's question)]
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
## Current Conversation
|
| 81 |
+
|
| 82 |
+
Below is the current conversation consisting of interleaving human and assistant messages.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
try:
|
| 86 |
+
custom_prompt = PromptTemplate(
|
| 87 |
+
template=custom_instructions,
|
| 88 |
+
template_vars=["tool_desc", "tool_names"]
|
| 89 |
+
)
|
| 90 |
+
return custom_prompt
|
| 91 |
+
except:
|
| 92 |
+
# Fallback to simple template
|
| 93 |
+
return PromptTemplate(template=custom_instructions)
|
| 94 |
+
|
| 95 |
+
def create_safe_custom_prompt(tools, llm):
|
| 96 |
+
"""Create a safe version that won't have formatting conflicts"""
|
| 97 |
+
|
| 98 |
+
custom_instructions = """
|
| 99 |
+
## MEDICAL ASSISTANT ROLE
|
| 100 |
+
You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
|
| 101 |
+
You provide evidence-based guidance using only information from official vaccine medical documents.
|
| 102 |
+
Answer the doctor's question accurately and concisely using only the provided information.
|
| 103 |
+
|
| 104 |
+
## IMPORTANT REQUIREMENTS
|
| 105 |
+
|
| 106 |
+
### Citation and Sourcing
|
| 107 |
+
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
|
| 108 |
+
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 109 |
+
3. If a fact is supported by multiple sources, use the following format:
|
| 110 |
+
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
|
| 111 |
+
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 112 |
+
|
| 113 |
+
### Content Formatting
|
| 114 |
+
1. When rendering tables:
|
| 115 |
+
- Convert HTML tables into clean Markdown format
|
| 116 |
+
- Preserve all original headers and data rows exactly
|
| 117 |
+
- Include the citation in the table caption, e.g., 'Table: Vaccination Schedule [Source]'
|
| 118 |
+
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 119 |
+
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
"""
|
| 124 |
+
|
| 125 |
+
# Get the exact original template first
|
| 126 |
+
temp_agent = ReActAgent.from_tools(tools, llm=llm, verbose=False)
|
| 127 |
+
original_prompts = temp_agent.get_prompts()
|
| 128 |
+
original_template = original_prompts["agent_worker:system_prompt"].template
|
| 129 |
+
|
| 130 |
+
# Add instructions at the very beginning
|
| 131 |
+
safe_template = f"{custom_instructions}{original_template}"
|
| 132 |
+
|
| 133 |
+
# Create new prompt with same metadata as original
|
| 134 |
+
original_prompt = original_prompts["agent_worker:system_prompt"]
|
| 135 |
+
|
| 136 |
+
try:
|
| 137 |
+
new_prompt = PromptTemplate(
|
| 138 |
+
template=safe_template,
|
| 139 |
+
template_vars=original_prompt.template_vars,
|
| 140 |
+
metadata=original_prompt.metadata if hasattr(original_prompt, 'metadata') else None
|
| 141 |
+
)
|
| 142 |
+
return new_prompt
|
| 143 |
+
except:
|
| 144 |
+
# Even safer fallback
|
| 145 |
+
return PromptTemplate(template=safe_template)
|
| 146 |
+
|
| 147 |
+
def create_agent(tools, llm):
|
| 148 |
+
"""Create the ReAct agent with custom prompt"""
|
| 149 |
+
|
| 150 |
+
# Create agent
|
| 151 |
+
agent = ReActAgent.from_tools(
|
| 152 |
+
tools,
|
| 153 |
+
llm=llm,
|
| 154 |
+
verbose=True,
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Create and apply safe custom prompt
|
| 158 |
+
try:
|
| 159 |
+
safe_custom_prompt = create_safe_custom_prompt(tools, llm)
|
| 160 |
+
agent.update_prompts({"agent_worker:system_prompt": safe_custom_prompt})
|
| 161 |
+
print("✅ Successfully updated with safe custom prompt")
|
| 162 |
+
except Exception as e:
|
| 163 |
+
print(f"❌ Safe prompt update failed: {e}")
|
| 164 |
+
print("⚠️ Using original agent without modifications")
|
| 165 |
+
|
| 166 |
+
return agent
|
| 167 |
+
|
| 168 |
+
def initialize_rag_pipeline(tools):
|
| 169 |
+
"""Initialize the RAG pipeline with tools"""
|
| 170 |
+
|
| 171 |
+
# Initialize LlamaIndex LLM
|
| 172 |
+
llama_index_llm = GoogleGenAI(
|
| 173 |
+
model="models/gemini-2.0-flash",
|
| 174 |
+
api_key=os.getenv('GOOGLE_API_KEY', 'AIzaSyDsbC8H6e08TKDwa5WPE3SiBA39e20K4co'),
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Create agent
|
| 178 |
+
agent = create_agent(tools, llama_index_llm)
|
| 179 |
+
|
| 180 |
+
return agent
|
| 181 |
+
|
| 182 |
+
def process_question(agent, question: str) -> str:
|
| 183 |
+
"""Process a question through the RAG pipeline"""
|
| 184 |
+
try:
|
| 185 |
+
response = agent.chat(question)
|
| 186 |
+
return response.response
|
| 187 |
+
except Exception as e:
|
| 188 |
+
print(f"Error processing question: {e}")
|
| 189 |
+
return f"Error processing your question: {str(e)}"
|
rag_system.py
DELETED
|
@@ -1,345 +0,0 @@
|
|
| 1 |
-
import os
|
| 2 |
-
import json
|
| 3 |
-
from typing import List, Dict, Any, Optional
|
| 4 |
-
|
| 5 |
-
# LlamaIndex imports
|
| 6 |
-
from llama_index.core import Settings, PromptTemplate
|
| 7 |
-
from llama_index.llms.google_genai import GoogleGenAI
|
| 8 |
-
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 9 |
-
from llama_index.core.agent import ReActAgent
|
| 10 |
-
from llama_index.core.tools import FunctionTool
|
| 11 |
-
from llama_index.core.schema import TextNode
|
| 12 |
-
|
| 13 |
-
# LangChain imports
|
| 14 |
-
from langchain_community.vectorstores import Chroma
|
| 15 |
-
from langchain_core.documents import Document
|
| 16 |
-
from langchain.embeddings import HuggingFaceEmbeddings
|
| 17 |
-
from langchain.retrievers import BM25Retriever, EnsembleRetriever
|
| 18 |
-
from langchain.retrievers.multi_query import MultiQueryRetriever
|
| 19 |
-
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 20 |
-
|
| 21 |
-
from config import Config
|
| 22 |
-
|
| 23 |
-
class AgenticRAGSystem:
|
| 24 |
-
def __init__(self, config: Config):
|
| 25 |
-
self.config = config
|
| 26 |
-
self.setup_llm_and_embeddings()
|
| 27 |
-
self.guide_retriever = None
|
| 28 |
-
self.section_retrievers = {}
|
| 29 |
-
self.agent = None
|
| 30 |
-
|
| 31 |
-
def setup_llm_and_embeddings(self):
|
| 32 |
-
"""Initialize LLM and embedding models"""
|
| 33 |
-
# LlamaIndex settings
|
| 34 |
-
Settings.llm = GoogleGenAI(
|
| 35 |
-
model=self.config.LLM_MODEL,
|
| 36 |
-
api_key=self.config.GOOGLE_API_KEY_1,
|
| 37 |
-
)
|
| 38 |
-
Settings.embed_model = HuggingFaceEmbedding(
|
| 39 |
-
model_name=self.config.EMBEDDING_MODEL
|
| 40 |
-
)
|
| 41 |
-
|
| 42 |
-
# LangChain components
|
| 43 |
-
self.embedding_function = HuggingFaceEmbeddings(
|
| 44 |
-
model_name=self.config.EMBEDDING_MODEL
|
| 45 |
-
)
|
| 46 |
-
self.llm = ChatGoogleGenerativeAI(
|
| 47 |
-
model="gemini-2.0-flash",
|
| 48 |
-
google_api_key=self.config.GOOGLE_API_KEY_2
|
| 49 |
-
)
|
| 50 |
-
|
| 51 |
-
def create_vectorstore_from_json(self, json_path: str, collection_name: str):
|
| 52 |
-
"""Create vector store from JSON chunks"""
|
| 53 |
-
if not os.path.exists(json_path):
|
| 54 |
-
raise FileNotFoundError(f"JSON file not found: {json_path}")
|
| 55 |
-
|
| 56 |
-
with open(json_path, "r", encoding="utf-8") as f:
|
| 57 |
-
chunks_data = json.load(f)
|
| 58 |
-
|
| 59 |
-
documents = []
|
| 60 |
-
for element in chunks_data:
|
| 61 |
-
text = element.get("text", "").strip()
|
| 62 |
-
if not text:
|
| 63 |
-
continue
|
| 64 |
-
|
| 65 |
-
metadata = {
|
| 66 |
-
"language": "fra",
|
| 67 |
-
"source": element.get("filename", "unknown"),
|
| 68 |
-
"filetype": element.get("filetype", "unknown"),
|
| 69 |
-
"element_id": element.get("element_id", "unknown")
|
| 70 |
-
}
|
| 71 |
-
|
| 72 |
-
if element.get("type") == "TableElement":
|
| 73 |
-
metadata["table_text_as_html"] = element.get("table_text_as_html", "")
|
| 74 |
-
|
| 75 |
-
doc = Document(page_content=text, metadata=metadata)
|
| 76 |
-
documents.append(doc)
|
| 77 |
-
|
| 78 |
-
vectorstore = Chroma.from_documents(
|
| 79 |
-
documents=documents,
|
| 80 |
-
embedding=self.embedding_function,
|
| 81 |
-
collection_name=collection_name,
|
| 82 |
-
persist_directory=self.config.CHROMA_DB_PATH
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
return vectorstore, documents
|
| 86 |
-
|
| 87 |
-
def create_retriever(self, vectorstore, docs):
|
| 88 |
-
"""Create ensemble retriever with semantic and BM25 search"""
|
| 89 |
-
retriever_multilingual = vectorstore.as_retriever(
|
| 90 |
-
search_type="similarity",
|
| 91 |
-
search_kwargs={"k": 6}
|
| 92 |
-
)
|
| 93 |
-
|
| 94 |
-
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 95 |
-
bm25_retriever.k = 2
|
| 96 |
-
|
| 97 |
-
ensemble_retriever = EnsembleRetriever(
|
| 98 |
-
retrievers=[retriever_multilingual, bm25_retriever],
|
| 99 |
-
weights=[0.5, 0.5]
|
| 100 |
-
)
|
| 101 |
-
|
| 102 |
-
expanding_retriever = MultiQueryRetriever.from_llm(
|
| 103 |
-
retriever=ensemble_retriever,
|
| 104 |
-
llm=self.llm
|
| 105 |
-
)
|
| 106 |
-
|
| 107 |
-
return expanding_retriever
|
| 108 |
-
|
| 109 |
-
def convert_chromadb_to_llamaindex_nodes(self, chromadb_documents: List) -> List[TextNode]:
|
| 110 |
-
"""Convert ChromaDB documents to LlamaIndex TextNode objects"""
|
| 111 |
-
nodes = []
|
| 112 |
-
for i, doc in enumerate(chromadb_documents):
|
| 113 |
-
try:
|
| 114 |
-
text = doc.page_content
|
| 115 |
-
metadata = doc.metadata.copy()
|
| 116 |
-
element_id = metadata.get("element_id", f"doc_{i}")
|
| 117 |
-
source = metadata.get("source", "unknown")
|
| 118 |
-
node_id = f"{source}_{element_id}"
|
| 119 |
-
|
| 120 |
-
node = TextNode(
|
| 121 |
-
text=text,
|
| 122 |
-
metadata=metadata,
|
| 123 |
-
id_=node_id
|
| 124 |
-
)
|
| 125 |
-
nodes.append(node)
|
| 126 |
-
except Exception as e:
|
| 127 |
-
print(f"Error converting document {i}: {e}")
|
| 128 |
-
continue
|
| 129 |
-
|
| 130 |
-
return nodes
|
| 131 |
-
|
| 132 |
-
def section_tool_wrapper(self, retriever, section_path_chunks, query):
|
| 133 |
-
"""Generic wrapper for section-specific tools"""
|
| 134 |
-
try:
|
| 135 |
-
retrieved_docs = retriever.get_relevant_documents(query)
|
| 136 |
-
|
| 137 |
-
if not retrieved_docs:
|
| 138 |
-
return "No relevant documents found for the query."
|
| 139 |
-
|
| 140 |
-
chunk_ids = [doc.metadata.get('element_id') for doc in retrieved_docs]
|
| 141 |
-
|
| 142 |
-
if not os.path.exists(section_path_chunks):
|
| 143 |
-
return f"Section data file not found: {section_path_chunks}"
|
| 144 |
-
|
| 145 |
-
with open(section_path_chunks, "r", encoding="utf-8") as f:
|
| 146 |
-
chunks_data = json.load(f)
|
| 147 |
-
|
| 148 |
-
chunks_unique = [node for node in chunks_data if node.get('element_id', 'Unknown') in chunk_ids]
|
| 149 |
-
combined_text = []
|
| 150 |
-
|
| 151 |
-
for chu in chunks_unique:
|
| 152 |
-
if chu.get("type") == "TableElement":
|
| 153 |
-
text = f"[Source: {chu.get('element_id', 'Unknown')}]\nCONTENT:\n{chu.get('text', '')}\nHTML:\n{chu.get('table_text_as_html', '')}\n\n"
|
| 154 |
-
combined_text.append(text)
|
| 155 |
-
else:
|
| 156 |
-
elements = chu.get("elements", [chu]) # Handle both formats
|
| 157 |
-
for element in elements:
|
| 158 |
-
text = f"[Source: {element.get('element_id', 'Unknown')}]\nCONTENT:\n{element.get('text', '')}\n\n"
|
| 159 |
-
combined_text.append(text)
|
| 160 |
-
|
| 161 |
-
result = "\n---\n".join(combined_text)
|
| 162 |
-
print(f"Retrieved {len(retrieved_docs)} documents for query: {query[:50]}...")
|
| 163 |
-
return result
|
| 164 |
-
|
| 165 |
-
except Exception as e:
|
| 166 |
-
print(f"Error in section tool wrapper: {e}")
|
| 167 |
-
return f"Error retrieving documents: {str(e)}"
|
| 168 |
-
|
| 169 |
-
def initialize_system(self):
|
| 170 |
-
"""Initialize all retrievers and create the agent"""
|
| 171 |
-
try:
|
| 172 |
-
# File paths - make these configurable
|
| 173 |
-
json_files = {
|
| 174 |
-
# "guide": "Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json",
|
| 175 |
-
# "section_1": "section_one_chunks.json",
|
| 176 |
-
# "section_2": "section_two_chunks.json",
|
| 177 |
-
# "section_3": "section_three_chunks.json",
|
| 178 |
-
# "section_4": "section_four_chunks.json",
|
| 179 |
-
# "section_5": "section_five_chunks.json",
|
| 180 |
-
# "section_6": "section_six_chunks.json",
|
| 181 |
-
# "section_7": "section_seven_chunks.json",
|
| 182 |
-
# "section_8": "section_eight_chunks.json",
|
| 183 |
-
# "section_9": "section_nine_chunks.json",
|
| 184 |
-
"section_10": "section_ten_chunks.json",
|
| 185 |
-
}
|
| 186 |
-
|
| 187 |
-
# Check if files exist
|
| 188 |
-
for name, filepath in json_files.items():
|
| 189 |
-
full_path = os.path.join(self.config.BASE_PATH, filepath)
|
| 190 |
-
if not os.path.exists(full_path):
|
| 191 |
-
print(f"Warning: {name} file not found at {full_path}")
|
| 192 |
-
|
| 193 |
-
# Initialize main guide retriever
|
| 194 |
-
guide_path = os.path.join(self.config.BASE_PATH, json_files["guide"])
|
| 195 |
-
if os.path.exists(guide_path):
|
| 196 |
-
guide_vstore, guide_doc = self.create_vectorstore_from_json(guide_path, "Guide_2023_multilingual")
|
| 197 |
-
self.guide_retriever = self.create_retriever(guide_vstore, guide_doc)
|
| 198 |
-
|
| 199 |
-
# Initialize section retrievers
|
| 200 |
-
for i in range(1, 11):
|
| 201 |
-
section_key = f"section_{i}"
|
| 202 |
-
section_path = os.path.join(self.config.BASE_PATH, json_files[section_key])
|
| 203 |
-
if os.path.exists(section_path):
|
| 204 |
-
vstore, doc = self.create_vectorstore_from_json(section_path, f"Section_{i}_multilingual")
|
| 205 |
-
self.section_retrievers[section_key] = {
|
| 206 |
-
'retriever': self.create_retriever(vstore, doc),
|
| 207 |
-
'path': section_path
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
# Create tools
|
| 211 |
-
tools = self.create_tools()
|
| 212 |
-
|
| 213 |
-
# Create agent
|
| 214 |
-
self.agent = ReActAgent.from_tools(
|
| 215 |
-
tools,
|
| 216 |
-
llm=Settings.llm,
|
| 217 |
-
verbose=True
|
| 218 |
-
)
|
| 219 |
-
|
| 220 |
-
# Apply custom prompt
|
| 221 |
-
self.apply_custom_prompt()
|
| 222 |
-
|
| 223 |
-
print("✅ Agentic RAG system initialized successfully")
|
| 224 |
-
return True
|
| 225 |
-
|
| 226 |
-
except Exception as e:
|
| 227 |
-
print(f"❌ Failed to initialize system: {e}")
|
| 228 |
-
return False
|
| 229 |
-
|
| 230 |
-
def create_tools(self):
|
| 231 |
-
"""Create all the function tools"""
|
| 232 |
-
tools = []
|
| 233 |
-
|
| 234 |
-
# # Main guide tool
|
| 235 |
-
# if self.guide_retriever:
|
| 236 |
-
# def guide_tool(query: str) -> str:
|
| 237 |
-
# """General-purpose retrieval tool for the Algerian National Vaccination Guide (2023)"""
|
| 238 |
-
# return self.section_tool_wrapper(
|
| 239 |
-
# self.guide_retriever,
|
| 240 |
-
# os.path.join(self.config.BASE_PATH, "Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json"),
|
| 241 |
-
# query
|
| 242 |
-
# )
|
| 243 |
-
|
| 244 |
-
# tools.append(FunctionTool.from_defaults(name="Guide_vector_tool", fn=guide_tool))
|
| 245 |
-
|
| 246 |
-
# Section tools
|
| 247 |
-
section_descriptions = {
|
| 248 |
-
# "section_1": "Programme Élargi de Vaccination - General national immunization program in Algeria",
|
| 249 |
-
# "section_2": "Maladies Ciblées - Diseases targeted by the national vaccination calendar",
|
| 250 |
-
# "section_3": "Vaccins du Calendrier - Vaccines themselves: types, administration methods, compositions",
|
| 251 |
-
# "section_4": "Rattrapage Vaccinal - Catch-up vaccination procedures and schedules",
|
| 252 |
-
# "section_5": "Populations Particulières - Vaccination of special populations (premature, immunosuppressed, etc.)",
|
| 253 |
-
# "section_6": "Chaîne du Froid - Vaccine cold chain logistics and storage",
|
| 254 |
-
# "section_7": "Sécurité des Injections - Safe injection practices",
|
| 255 |
-
# "section_8": "Séance de Vaccination & Vaccinovigilance - Vaccination sessions and adverse event monitoring",
|
| 256 |
-
# "section_9": "Planification des Séances - Planning of vaccination sessions",
|
| 257 |
-
"section_10": "Mobilisation Sociale - Community mobilization and vaccine hesitancy"
|
| 258 |
-
}
|
| 259 |
-
|
| 260 |
-
for section_key, description in section_descriptions.items():
|
| 261 |
-
if section_key in self.section_retrievers:
|
| 262 |
-
def create_section_tool(section_data, desc):
|
| 263 |
-
def section_tool(query: str) -> str:
|
| 264 |
-
return self.section_tool_wrapper(
|
| 265 |
-
section_data['retriever'],
|
| 266 |
-
section_data['path'],
|
| 267 |
-
query
|
| 268 |
-
)
|
| 269 |
-
section_tool.__doc__ = f"Handles queries about {desc}"
|
| 270 |
-
return section_tool
|
| 271 |
-
|
| 272 |
-
section_tool_func = create_section_tool(self.section_retrievers[section_key], description)
|
| 273 |
-
tools.append(FunctionTool.from_defaults(
|
| 274 |
-
name=f"{section_key}_vector_query_tool",
|
| 275 |
-
fn=section_tool_func
|
| 276 |
-
))
|
| 277 |
-
|
| 278 |
-
return tools
|
| 279 |
-
|
| 280 |
-
def apply_custom_prompt(self):
|
| 281 |
-
"""Apply custom instructions to the agent"""
|
| 282 |
-
custom_instructions = """
|
| 283 |
-
## MEDICAL ASSISTANT ROLE
|
| 284 |
-
You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
|
| 285 |
-
You provide evidence-based guidance using only information from official vaccine medical documents.
|
| 286 |
-
Answer the doctor's question accurately and concisely using only the provided information.
|
| 287 |
-
|
| 288 |
-
## IMPORTANT REQUIREMENTS
|
| 289 |
-
|
| 290 |
-
### Citation and Sourcing
|
| 291 |
-
1. For each fact in your response, include an inline citation in the format [Source] immediately following the information.
|
| 292 |
-
2. Do NOT use 'Source:' in the citation format; use only the Source in square brackets.
|
| 293 |
-
3. If a fact is supported by multiple sources, use adjacent citations.
|
| 294 |
-
4. Use ONLY the provided information and never include facts from your general knowledge.
|
| 295 |
-
|
| 296 |
-
### Content Formatting
|
| 297 |
-
1. When rendering tables: Convert HTML tables into clean Markdown format
|
| 298 |
-
2. For lists, maintain the original bullet points/numbering and include citations.
|
| 299 |
-
3. Present information concisely but ensure clinical accuracy is never compromised.
|
| 300 |
-
"""
|
| 301 |
-
|
| 302 |
-
try:
|
| 303 |
-
# Create safe custom prompt
|
| 304 |
-
temp_agent = ReActAgent.from_tools([], llm=Settings.llm, verbose=False)
|
| 305 |
-
original_prompts = temp_agent.get_prompts()
|
| 306 |
-
original_template = original_prompts["agent_worker:system_prompt"].template
|
| 307 |
-
|
| 308 |
-
safe_template = f"""{custom_instructions}
|
| 309 |
-
|
| 310 |
-
---
|
| 311 |
-
|
| 312 |
-
{original_template}"""
|
| 313 |
-
|
| 314 |
-
original_prompt = original_prompts["agent_worker:system_prompt"]
|
| 315 |
-
new_prompt = PromptTemplate(
|
| 316 |
-
template=safe_template,
|
| 317 |
-
template_vars=original_prompt.template_vars,
|
| 318 |
-
metadata=getattr(original_prompt, 'metadata', None)
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
self.agent.update_prompts({"agent_worker:system_prompt": new_prompt})
|
| 322 |
-
print("✅ Successfully updated with custom prompt")
|
| 323 |
-
|
| 324 |
-
except Exception as e:
|
| 325 |
-
print(f"❌ Custom prompt update failed: {e}")
|
| 326 |
-
|
| 327 |
-
def ask_question(self, question: str, with_citations: bool = False) -> str:
|
| 328 |
-
"""Process a question using the agentic RAG system"""
|
| 329 |
-
if not self.agent:
|
| 330 |
-
raise ValueError("Agent not initialized. Call initialize_system() first.")
|
| 331 |
-
|
| 332 |
-
try:
|
| 333 |
-
response = self.agent.chat(question)
|
| 334 |
-
answer = response.response
|
| 335 |
-
|
| 336 |
-
if not with_citations:
|
| 337 |
-
# Simple regex to remove citations if not wanted
|
| 338 |
-
import re
|
| 339 |
-
answer = re.sub(r'\[[\w\d-]+\]', '', answer)
|
| 340 |
-
|
| 341 |
-
return answer
|
| 342 |
-
|
| 343 |
-
except Exception as e:
|
| 344 |
-
print(f"Error processing question: {e}")
|
| 345 |
-
return f"Sorry, I encountered an error while processing your question: {str(e)}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|