import json
import os
import re
from langchain_google_genai import GoogleGenerativeAI
from langdetect import detect
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

def generate_rag_response(query, retrieved_documents, model="gemini-2.0-flash-exp"):
    """
    Perform Retrieval-Augmented Generation (RAG) using Google's Gemini.
    Args:
        query (str): The user's query.
        retrieved_documents (list of str): The documents retrieved from the retriever.
        model (str): The Gemini model to use.
    Returns:
        str: The generated response text.
    """
    information = "\n\n".join(retrieved_documents)
    
    prompt = f"""You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
        You provide evidence-based guidance using only information from official vaccine medical documents.
        Answer the doctor's question accurately and concisely using only the provided information.

        IMPORTANT REQUIREMENTS:

        ### Language Settings
        1. DETECT THE LANGUAGE OF THE DOCTOR'S QUERY.
        2. YOU MUST RESPOND ONLY IN ONE OF THESE THREE LANGUAGES:
           - English (en): If the doctor's query is in English OR in any language not listed below
           - Arabic (ar): ONLY if the doctor's query is in Arabic
           - French (fr): ONLY if the doctor's query is in French
        3. DO NOT switch languages mid-response. Use ONLY ONE language throughout your entire answer.

        ### Citation and Sourcing
        1. For each fact in your response, include an inline citation in the format [Source ID] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
        2. Do NOT use 'Source ID:' in the citation format; use only the source ID in square brackets.
        3. If a fact is supported by multiple sources, use the following format:
           - Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
        4. Use ONLY the provided information and never include facts from your general knowledge.

        ### Content Formatting
        1. When rendering tables:
           - Convert HTML tables into clean Markdown format
           - Preserve all original headers and data rows exactly
           - Include the citation in the table caption, e.g., "Table: Vaccination Schedule [Source ID]"
        2. For lists, maintain the original bullet points/numbering and include citations.
        3. Present information concisely but ensure clinical accuracy is never compromised.

        ### Professional Tone
        1. Maintain a professional, clinical tone appropriate for physician communication.
        2. Prioritize clarity and precision in medical terminology.

        ### Response Handling
        1. If the question cannot be answered with the provided documents:
           - English: "I don't have sufficient information in the provided documents to answer this question completely. Please consult additional official vaccine resources or a specialist for guidance on this topic."
           - Arabic: "ليس لدي معلومات كافية في الوثائق المقدمة للإجابة على هذا السؤال بشكل كامل. يرجى استشارة مصادر لقاح رسمية إضافية أو متخصص للحصول على إرشادات حول هذا الموضوع."
           - French: "Je n'ai pas suffisamment d'informations dans les documents fournis pour répondre complètement à cette question. Veuillez consulter des ressources officielles sur les vaccins ou un spécialiste pour obtenir des conseils sur ce sujet."
        2. If the question is clearly unrelated to vaccines or medicine:
           - English: "I'm specialized in providing vaccine information for healthcare professionals. Could you please ask a question related to vaccines or immunization? I'd be happy to help with that."
           - Arabic: "أنا متخصص في تقديم معلومات اللقاحات للمهنيين الصحيين. هل يمكنك طرح سؤال يتعلق باللقاحات أو التطعيم؟ سأكون سعيدًا بمساعدتك في ذلك."
           - French: "Je suis spécialisé dans la fourniture d'informations sur les vaccins pour les professionnels de la santé. Pourriez-vous poser une question liée aux vaccins ou à l'immunisation ? Je serais heureux de vous aider avec ça."
        3. For simple greetings:
           - Respond with a simple formal greeting in the same language as the query.

        Question: {query}
        Information: {information}
        """
    
    # Initialize the LLM - using GoogleGenerativeAI instead of ChatGoogleGenerativeAI
    llm = GoogleGenerativeAI(
        model=model,
        google_api_key=os.getenv("GOOGLE_API_KEY")
    )
    
    # Generate response using langchain
    response = llm.invoke(prompt)
    return response


def extract_source_ids(response_text):
    """
    Extract source IDs from the response, handling different citation formats:
    - Standard format: [Source ID]
    - Multiple sources in one citation: [Source ID1][Source ID2]
    - Multiple sources in one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.

    Returns:
        list of str: List of unique source IDs found in the response text.
    """
    # Normalize adjacent citations [ID1] [ID2]: first drop any whitespace
    # between the brackets, then merge "][" into ", " so they become [ID1, ID2]
    consolidated_text = re.sub(r'\][\s]*\[', '][', response_text)
    consolidated_text = re.sub(r'\]\[', ', ', consolidated_text)
    
    # Now extract all source IDs from any format (single ID or comma-separated IDs)
    inline_citations = re.findall(r'\[([^\[\]]+)\]', consolidated_text)
    
    if not inline_citations:
        print("Warning: No source IDs found in the response text.")
        return []

    # Process each citation which might contain multiple comma-separated IDs
    all_ids = []
    for citation in inline_citations:
        # Split by comma and strip whitespace
        ids = [id_str.strip() for id_str in citation.split(',')]
        all_ids.extend(ids)
    
    # Get unique source IDs
    source_ids = list(set(all_ids))
    
    # Filter out any non-UUID-like IDs (if needed)
    # This is now optional as we're handling various source ID formats
    # uuid_pattern = r'^[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}$'
    # source_ids = [source_id for source_id in source_ids if re.match(uuid_pattern, source_id, re.IGNORECASE)]
    
    if not source_ids:
        print("Warning: No valid source IDs found after filtering.")
        return []

    return source_ids
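

# The normalization that extract_source_ids performs can be seen on a toy
# response; the hex IDs below are made up for illustration:
#
# ```python
# import re
#
# # A response using all three citation shapes the extractor handles
# sample = (
#     "Dose one at 2 months [aaa111]. "
#     "Boost at 12 months [bbb222][ccc333]. "
#     "Contraindicated after anaphylaxis [ddd444, eee555]."
# )
#
# # Same two-step normalization as extract_source_ids:
# # [A][B] -> [A, B], then pull out the bracket contents
# text = re.sub(r'\][\s]*\[', '][', sample)
# text = re.sub(r'\]\[', ', ', text)
#
# ids = []
# for citation in re.findall(r'\[([^\[\]]+)\]', text):
#     ids.extend(s.strip() for s in citation.split(','))
#
# sorted(set(ids))
# # ['aaa111', 'bbb222', 'ccc333', 'ddd444', 'eee555']
# ```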


def format_response_with_sequential_citations(response_text, unique_ids, clean_all_citations=False):
    """
    Format the response text by either:
    - Replacing source IDs with sequential numbers (default)
    - Completely removing all citations (if clean_all_citations=True)
    
    Handles multiple citation formats:
    - Standard format: [Source ID]
    - Multiple sources in one citation: [Source ID1][Source ID2]
    - Multiple sources in one bracket: [Source ID1, Source ID2]
    
    Args:
        response_text (str): The generated response text with inline citations.
        unique_ids (list): List of unique source IDs found in the response.
        clean_all_citations (bool): If True, removes all citations completely.
                                   If False, formats them as numbers.
        
    Returns:
        str: The formatted response text.
    """
    if not unique_ids:
        return response_text

    formatted_response = response_text
    
    # Create a mapping from source ID to sequential number
    id_to_number = {source_id: str(i+1) for i, source_id in enumerate(unique_ids)}
    
    if clean_all_citations:
        # Remove all citations completely
        formatted_response = re.sub(r'\[[^\[\]]+?\]', '', formatted_response)
        # Collapse runs of spaces/tabs left behind, without touching newlines
        # (a blanket \s+ collapse would flatten markdown tables and lists)
        formatted_response = re.sub(r'[ \t]{2,}', ' ', formatted_response)
    else:
        # First, standardize adjacent citations [ID1][ID2] to [ID1, ID2]
        formatted_response = re.sub(r'\][\s]*\[', '][', formatted_response)
        formatted_response = re.sub(r'\]\[', ', ', formatted_response)
        
        # Now handle citations with multiple IDs
        def replace_citation(match):
            content = match.group(1)
            # Check if there are multiple IDs separated by commas
            if ',' in content:
                ids = [id_str.strip() for id_str in content.split(',')]
                numbers = []
                for id_str in ids:
                    if id_str in id_to_number:
                        numbers.append(id_to_number[id_str])
                if numbers:
                    return f"[{', '.join(numbers)}]"
            # Single ID case
            elif content in id_to_number:
                return f"[{id_to_number[content]}]"
            return match.group(0)
        
        # Replace citations with their sequential numbers
        formatted_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, formatted_response)
    
    return formatted_response.strip()
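

# The sequential renumbering can be sketched on toy data (the hashes below
# stand in for real chunk IDs); this mirrors the non-cleaning branch above:
#
# ```python
# import re
#
# response = ("MMR first dose at 12 months [deadbeef]. "
#             "Second at 4 years [deadbeef][cafef00d].")
# unique_ids = ["deadbeef", "cafef00d"]
# id_to_number = {sid: str(i + 1) for i, sid in enumerate(unique_ids)}
#
# # Normalize adjacent citations, then map each ID to its number
# text = re.sub(r'\][\s]*\[', '][', response)
# text = re.sub(r'\]\[', ', ', text)
#
# def number(match):
#     parts = [p.strip() for p in match.group(1).split(',')]
#     nums = [id_to_number[p] for p in parts if p in id_to_number]
#     return f"[{', '.join(nums)}]" if nums else match.group(0)
#
# re.sub(r'\[([^\[\]]+)\]', number, text)
# # 'MMR first dose at 12 months [1]. Second at 4 years [1, 2].'
# ```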

def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory="./data/"):
    """
    Retrieve relevant documents and prepare them for the RAG generation.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object (e.g., returned by prepare_environment_and_retriever).
        chunks_directory (str): Path to the directory containing JSON files.

    Returns:
        tuple: (source_texts_for_rag, retrieved_elements_full)
    """
    # Get documents - query expansion happens automatically
    # (invoke() is the current retriever API; get_relevant_documents() is deprecated)
    retrieved_docs = expanding_retriever.invoke(query)

    retrieved_chunk_ids = [doc.metadata["element_id"] for doc in retrieved_docs]
    
    # Get unique filenames from retrieved documents
    needed_filenames = set(doc.metadata["source"] for doc in retrieved_docs)
    
    # Convert PDF filenames to JSON filenames (e.g., "file.pdf" -> "file.json")
    needed_json_files = []
    for filename in needed_filenames:
        # Remove extension and add .json
        base_name = os.path.splitext(filename)[0]
        json_filename = f"{base_name}.json"
        json_path = os.path.join(chunks_directory, json_filename)
        if os.path.exists(json_path):
            needed_json_files.append(json_path)
        else:
            print(f"Warning: JSON file not found: {json_path}")

    # Load only the needed JSON files
    all_chunks_data = []
    for json_file in needed_json_files:
        print(f"Loading: {os.path.basename(json_file)}")
        with open(json_file, "r", encoding="utf-8") as f:
            chunks_data = json.load(f)
            all_chunks_data.extend(chunks_data)

    source_retrieved_texts = []
    retrieved_elements_full = []

    for chu in all_chunks_data:
        if chu["element_id"] in retrieved_chunk_ids:
            if chu.get("type") == "TableElement":
                text = (
                    f"[Source ID: {chu['element_id']}]\n"
                    f"CONTENT:\n{chu['text']}\n"
                    f"HTML:\n{chu['table_text_as_html']}\n\n"
                )
                source_retrieved_texts.append(text)
                retrieved_elements_full.append(chu)
            else:
                for element in chu.get("elements", []):
                    text = (
                        f"[Source ID: {element['element_id']}]\n"
                        f"CONTENT:\n{element['text']}\n\n"
                    )
                    source_retrieved_texts.append(text)
                    retrieved_elements_full.append(element)

    return source_retrieved_texts, retrieved_elements_full
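

# The chunk-to-prompt conversion above can be sketched on toy records in the
# shape the loader expects (keys and IDs here are illustrative, not real data):
#
# ```python
# chunks = [
#     {"element_id": "t1", "type": "TableElement",
#      "text": "Dose | Age", "table_text_as_html": "<table>...</table>"},
#     {"element_id": "p0", "type": "CompositeElement",
#      "elements": [{"element_id": "p1", "text": "Give IM in the deltoid."}]},
# ]
# retrieved_ids = {"t1", "p0"}
#
# texts = []
# for chu in chunks:
#     if chu["element_id"] in retrieved_ids:
#         if chu.get("type") == "TableElement":
#             # Tables keep both plain text and HTML for the prompt
#             texts.append(f"[Source ID: {chu['element_id']}]\nCONTENT:\n"
#                          f"{chu['text']}\nHTML:\n{chu['table_text_as_html']}\n\n")
#         else:
#             # Non-table chunks contribute one entry per child element
#             for el in chu.get("elements", []):
#                 texts.append(f"[Source ID: {el['element_id']}]\n"
#                              f"CONTENT:\n{el['text']}\n\n")
# # texts now holds two prompt-ready snippets, one per retrieved chunk
# ```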

def full_rag_pipeline(query, expanding_retriever, chunks_directory="./data/", model="gemini-2.0-flash-exp", clean_all_citations=False):
    """
    Full RAG pipeline from query to RAG response + extracted sources.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object.
        chunks_directory (str): Path to the directory containing JSON files.
        model (str): Gemini model.
        clean_all_citations (bool): Whether to remove all citations from response.

    Returns:
        dict: {
            "response": str,
            "cited_elements_json": str,
            "answer_language": str
        }
    """
    source_texts, retrieved_elements = retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory)

    # Step 1: RAG
    response_text = generate_rag_response(query, source_texts, model=model)

    # Step 2: Extract cited sources
    unique_ids = extract_source_ids(response_text)
    
    # Step 2.1: Format the response text with sequential citations
    response_text = format_response_with_sequential_citations(response_text, unique_ids, clean_all_citations=clean_all_citations)

    # Step 3: Get only the cited elements
    cited_elements = [element for element in retrieved_elements if element["element_id"] in unique_ids]

    cited_elements_json = json.dumps(cited_elements, ensure_ascii=False, indent=2)
    
    # Detect the answer language from the start of the response; fall back to
    # the query, then to English, so the result is always one of en/ar/fr
    try:
        # Detect the language of the first 5 words of the response
        first_line = " ".join(response_text.split()[:5])
        first_line = re.sub(r'\[.*?\]', '', first_line)  # Remove citations
        answer_language = detect(first_line)
        if answer_language not in ('en', 'ar', 'fr'):
            answer_language = detect(query)
    except Exception:
        answer_language = 'en'
    if answer_language not in ('en', 'ar', 'fr'):
        answer_language = 'en'

    return {
        "response": response_text,
        "cited_elements_json": cited_elements_json,
        "answer_language": answer_language
    }