# Simple_RAG / rag_pipeline.py
# Author: Zeggai Abdellah
# Last change: "update the system" (commit c23c6b4)
import json
import re
import glob
import os
from langchain_google_genai import GoogleGenerativeAI
from langchain_core.documents import Document
from langdetect import detect
from dotenv import load_dotenv

# Load environment variables (e.g. GOOGLE_API_KEY) from a local .env file
# at import time, so the LLM client below can read them via os.getenv.
load_dotenv()
def generate_rag_response(query, retrieved_documents, model="gemini-2.0-flash-exp"):
    """
    Run the generation step of Retrieval-Augmented Generation (RAG) with Gemini.

    Args:
        query (str): The doctor's question.
        retrieved_documents (list of str): Text chunks returned by the retriever.
        model (str): Name of the Gemini model to invoke.

    Returns:
        str: The model's generated answer text.
    """
    # Plain GoogleGenerativeAI (text completion) is used here rather than the
    # chat variant; the API key is read from the environment (.env is loaded
    # at module import time).
    gemini = GoogleGenerativeAI(
        google_api_key=os.getenv("GOOGLE_API_KEY"),
        model=model,
    )

    # Concatenate all retrieved chunks into a single context section.
    information = "\n\n".join(retrieved_documents)

    prompt = f"""You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
You provide evidence-based guidance using only information from official vaccine medical documents.
Answer the doctor's question accurately and concisely using only the provided information.
IMPORTANT REQUIREMENTS:
### Language Settings
1. DETECT THE LANGUAGE OF THE DOCTOR'S QUERY.
2. YOU MUST RESPOND ONLY IN ONE OF THESE THREE LANGUAGES:
- English (en): If the doctor's query is in English OR in any language not listed below
- Arabic (ar): ONLY if the doctor's query is in Arabic
- French (fr): ONLY if the doctor's query is in French
3. DO NOT switch languages mid-response. Use ONLY ONE language throughout your entire answer.
### Citation and Sourcing
1. For each fact in your response, include an inline citation in the format [Source ID] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
2. Do NOT use 'Source ID:' in the citation format; use only the source ID in square brackets.
3. If a fact is supported by multiple sources, use the following format:
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
4. Use ONLY the provided information and never include facts from your general knowledge.
### Content Formatting
1. When rendering tables:
- Convert HTML tables into clean Markdown format
- Preserve all original headers and data rows exactly
- Include the citation in the table caption, e.g., "Table: Vaccination Schedule [Source ID]"
2. For lists, maintain the original bullet points/numbering and include citations.
3. Present information concisely but ensure clinical accuracy is never compromised.
### Professional Tone
1. Maintain a professional, clinical tone appropriate for physician communication.
2. Prioritize clarity and precision in medical terminology.
### Response Handling
1. If the question cannot be answered with the provided documents:
- English: "I don't have sufficient information in the provided documents to answer this question completely. Please consult additional official vaccine resources or a specialist for guidance on this topic."
- Arabic: "ليس لدي معلومات كافية في الوثائق المقدمة للإجابة على هذا السؤال بشكل كامل. يرجى استشارة مصادر لقاح رسمية إضافية أو متخصص للحصول على إرشادات حول هذا الموضوع."
- French: "Je n'ai pas suffisamment d'informations dans les documents fournis pour répondre complètement à cette question. Veuillez consulter des ressources officielles sur les vaccins ou un spécialiste pour obtenir des conseils sur ce sujet."
2. If the question is clearly unrelated to vaccines or medicine:
- English: "I'm specialized in providing vaccine information for healthcare professionals. Could you please ask a question related to vaccines or immunization? I'd be happy to help with that."
- Arabic: "أنا متخصص في تقديم معلومات اللقاحات للمهنيين الصحيين. هل يمكنك طرح سؤال يتعلق باللقاحات أو التطعيم؟ سأكون سعيدًا بمساعدتك في ذلك."
- French: "Je suis spécialisé dans la fourniture d'informations sur les vaccins pour les professionnels de la santé. Pourriez-vous poser une question liée aux vaccins ou à l'immunisation ? Je serais heureux de vous aider avec ça."
3. For simple greetings:
- Respond with a simple formal greeting in the same language as the query.
Question: {query}
Information: {information}
"""
    # Single-shot completion call through langchain.
    return gemini.invoke(prompt)
def extract_source_ids(response_text):
    """
    Extract unique source IDs from inline citations in a generated response.

    Handles every citation format the prompt allows:
    - Standard format: [Source ID]
    - Adjacent brackets: [Source ID1][Source ID2]
    - Comma-separated inside one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.

    Returns:
        list of str: Unique source IDs in order of first appearance. The order
        is deterministic (unlike a plain set()) so that the downstream
        sequential numbering [1], [2], ... is stable across runs.
    """
    # Merge adjacent citations "[id1][id2]" (optionally whitespace-separated)
    # into a single bracket "[id1, id2]" to standardize the format.
    consolidated_text = re.sub(r'\]\s*\[', ', ', response_text)

    # Grab the contents of every remaining bracket pair.
    inline_citations = re.findall(r'\[([^\[\]]+)\]', consolidated_text)
    if not inline_citations:
        print("Warning: No source IDs found in the response text.")
        return []

    # Each bracket may hold several comma-separated IDs.
    all_ids = []
    for citation in inline_citations:
        all_ids.extend(id_str.strip() for id_str in citation.split(','))

    # Deduplicate while preserving first-occurrence order.
    source_ids = list(dict.fromkeys(all_ids))
    if not source_ids:
        print("Warning: No valid source IDs found after filtering.")
        return []
    return source_ids
def format_response_with_sequential_citations(response_text, unique_ids, clean_all_citations=False):
    """
    Replace inline source-ID citations with sequential numbers, or strip them.

    Handles every citation format the prompt allows:
    - Standard format: [Source ID]
    - Adjacent brackets: [Source ID1][Source ID2]
    - Comma-separated inside one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.
        unique_ids (list): Unique source IDs; their order defines the numbering
            ([1] maps to unique_ids[0], and so on).
        clean_all_citations (bool): If True, remove all citations completely
            instead of numbering them.

    Returns:
        str: The formatted response text, stripped of surrounding whitespace.
    """
    if not unique_ids:
        return response_text

    # Map each source ID to its 1-based sequential number.
    id_to_number = {source_id: str(i + 1) for i, source_id in enumerate(unique_ids)}

    if clean_all_citations:
        # Drop every bracketed citation.
        formatted_response = re.sub(r'\[[^\[\]]+?\]', '', response_text)
        # Collapse only runs of spaces/tabs (NOT newlines) so Markdown tables
        # and lists requested by the prompt survive the cleanup.
        formatted_response = re.sub(r'[ \t]+', ' ', formatted_response)
        # Trim spaces left dangling at end-of-line by citation removal.
        formatted_response = re.sub(r'[ \t]+\n', '\n', formatted_response)
        return formatted_response.strip()

    # Merge adjacent citations "[id1][id2]" into one bracket "[id1, id2]".
    formatted_response = re.sub(r'\]\s*\[', ', ', response_text)

    def replace_citation(match):
        # Translate every known ID inside the bracket into its number;
        # leave brackets with no known IDs untouched.
        ids = [id_str.strip() for id_str in match.group(1).split(',')]
        numbers = [id_to_number[id_str] for id_str in ids if id_str in id_to_number]
        if numbers:
            return f"[{', '.join(numbers)}]"
        return match.group(0)

    formatted_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, formatted_response)
    return formatted_response.strip()
def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory="./data/"):
    """
    Retrieve relevant documents and assemble the text inputs for RAG generation.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object (query expansion happens
            inside it automatically).
        chunks_directory (str): Directory containing the per-document JSON
            chunk files.

    Returns:
        tuple: (source_texts_for_rag, retrieved_elements_full)
    """
    docs = expanding_retriever.get_relevant_documents(query)
    wanted_ids = [doc.metadata["element_id"] for doc in docs]

    # Map each retrieved source file (e.g. "file.pdf") to its chunk file
    # ("file.json") inside chunks_directory; warn about missing ones.
    json_paths = []
    for source_name in {doc.metadata["source"] for doc in docs}:
        candidate = os.path.join(
            chunks_directory, os.path.splitext(source_name)[0] + ".json"
        )
        if os.path.exists(candidate):
            json_paths.append(candidate)
        else:
            print(f"Warning: JSON file not found: {candidate}")

    # Load only the chunk files we actually need.
    chunks = []
    for path in json_paths:
        print(f"Loading: {os.path.basename(path)}")
        with open(path, "r", encoding="utf-8") as handle:
            chunks.extend(json.load(handle))

    source_texts = []
    full_elements = []
    for chunk in chunks:
        if chunk["element_id"] not in wanted_ids:
            continue
        if chunk.get("type") == "TableElement":
            # Tables carry both plain text and their HTML rendering.
            source_texts.append(
                f"[Source ID: {chunk['element_id']}]\n"
                f"CONTENT:\n{chunk['text']}\n"
                f"HTML:\n{chunk['table_text_as_html']}\n\n"
            )
            full_elements.append(chunk)
        else:
            # Composite chunks expand into their individual sub-elements.
            for sub_element in chunk.get("elements", []):
                source_texts.append(
                    f"[Source ID: {sub_element['element_id']}]\n"
                    f"CONTENT:\n{sub_element['text']}\n\n"
                )
                full_elements.append(sub_element)

    return source_texts, full_elements
def full_rag_pipeline(query, expanding_retriever, chunks_directory="./data/", model="gemini-2.0-flash-exp", clean_all_citations=False):
    """
    Full RAG pipeline from query to RAG response + extracted sources.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object.
        chunks_directory (str): Path to the directory containing JSON files.
        model (str): Gemini model name.
        clean_all_citations (bool): Whether to remove all citations from response.

    Returns:
        dict: {
            "response": str,
            "cited_elements_json": str,
            "answer_language": str  # one of 'en', 'ar', 'fr'
        }
    """
    source_texts, retrieved_elements = retrieve_documents_and_prepare_inputs(
        query, expanding_retriever, chunks_directory
    )

    # Step 1: generate the answer from the retrieved context.
    response_text = generate_rag_response(query, source_texts, model=model)

    # Step 2: find which sources the model actually cited.
    unique_ids = extract_source_ids(response_text)

    # Step 2.1: renumber (or strip) the inline citations.
    response_text = format_response_with_sequential_citations(
        response_text, unique_ids, clean_all_citations=clean_all_citations
    )

    # Step 3: keep only the cited elements (set for O(1) membership tests).
    cited_ids = set(unique_ids)
    cited_elements = [el for el in retrieved_elements if el["element_id"] in cited_ids]
    cited_elements_json = json.dumps(cited_elements, ensure_ascii=False, indent=2)

    answer_language = _detect_answer_language(response_text, query)

    return {
        "response": response_text,
        "cited_elements_json": cited_elements_json,
        "answer_language": answer_language,
    }


def _detect_answer_language(response_text, query, supported=('en', 'ar', 'fr'), default='en'):
    """
    Best-effort detection of the answer language, constrained to `supported`.

    Tries the start of the response first, then falls back to the query, and
    finally to `default`. langdetect can raise (e.g. on empty or symbol-only
    text), so every detect() call is guarded — the original bare `except:`
    with an unguarded detect(query) inside the handler could itself crash.
    """
    try:
        # Detect from the first 5 words of the response, with citations removed.
        snippet = re.sub(r'\[.*?\]', '', " ".join(response_text.split()[:5]))
        lang = detect(snippet)
        if lang in supported:
            return lang
    except Exception:
        pass
    try:
        lang = detect(query)
        if lang in supported:
            return lang
    except Exception:
        pass
    return default