# NOTE(review): the original first lines ("Spaces:" / "Sleeping" x2) were
# scraping residue from a Hugging Face Spaces listing, not valid Python;
# converted to this comment so the module parses.
# Standard library
import glob
import json
import os
import re

# Third-party
from dotenv import load_dotenv
from langchain_core.documents import Document
from langchain_google_genai import GoogleGenerativeAI
from langdetect import detect

# Side effect at import time: pull environment variables (notably
# GOOGLE_API_KEY) from a local .env file into os.environ.
load_dotenv()
def generate_rag_response(query, retrieved_documents, model="gemini-2.0-flash-exp"):
    """
    Perform Retrieval-Augmented Generation (RAG) using Google's Gemini.

    Args:
        query (str): The user's query.
        retrieved_documents (list of str): The documents retrieved from the retriever.
        model (str): The Gemini model to use.

    Returns:
        str: The generated response text.

    Raises:
        ValueError: If the GOOGLE_API_KEY environment variable is not set.
    """
    # Fail fast with a clear message instead of an opaque downstream
    # authentication error from the Gemini client.
    api_key = os.getenv("GOOGLE_API_KEY")
    if not api_key:
        raise ValueError("GOOGLE_API_KEY environment variable is not set.")

    information = "\n\n".join(retrieved_documents)

    prompt = f"""You are a helpful and knowledgeable AI-powered vaccine assistant designed to support doctors in clinical decision-making.
You provide evidence-based guidance using only information from official vaccine medical documents.
Answer the doctor's question accurately and concisely using only the provided information.
IMPORTANT REQUIREMENTS:
### Language Settings
1. DETECT THE LANGUAGE OF THE DOCTOR'S QUERY.
2. YOU MUST RESPOND ONLY IN ONE OF THESE THREE LANGUAGES:
- English (en): If the doctor's query is in English OR in any language not listed below
- Arabic (ar): ONLY if the doctor's query is in Arabic
- French (fr): ONLY if the doctor's query is in French
3. DO NOT switch languages mid-response. Use ONLY ONE language throughout your entire answer.
### Citation and Sourcing
1. For each fact in your response, include an inline citation in the format [Source ID] immediately following the information, e.g., [e795ebd28318886c0b1a5395ac30ad90].
2. Do NOT use 'Source ID:' in the citation format; use only the source ID in square brackets.
3. If a fact is supported by multiple sources, use the following format:
- Use adjacent citations: [e795ebd28318886c0b1a5395ac30ad90][21a932b2340bb16707763f57f0ad2]
4. Use ONLY the provided information and never include facts from your general knowledge.
### Content Formatting
1. When rendering tables:
- Convert HTML tables into clean Markdown format
- Preserve all original headers and data rows exactly
- Include the citation in the table caption, e.g., "Table: Vaccination Schedule [Source ID]"
2. For lists, maintain the original bullet points/numbering and include citations.
3. Present information concisely but ensure clinical accuracy is never compromised.
### Professional Tone
1. Maintain a professional, clinical tone appropriate for physician communication.
2. Prioritize clarity and precision in medical terminology.
### Response Handling
1. If the question cannot be answered with the provided documents:
- English: "I don't have sufficient information in the provided documents to answer this question completely. Please consult additional official vaccine resources or a specialist for guidance on this topic."
- Arabic: "ليس لدي معلومات كافية في الوثائق المقدمة للإجابة على هذا السؤال بشكل كامل. يرجى استشارة مصادر لقاح رسمية إضافية أو متخصص للحصول على إرشادات حول هذا الموضوع."
- French: "Je n'ai pas suffisamment d'informations dans les documents fournis pour répondre complètement à cette question. Veuillez consulter des ressources officielles sur les vaccins ou un spécialiste pour obtenir des conseils sur ce sujet."
2. If the question is clearly unrelated to vaccines or medicine:
- English: "I'm specialized in providing vaccine information for healthcare professionals. Could you please ask a question related to vaccines or immunization? I'd be happy to help with that."
- Arabic: "أنا متخصص في تقديم معلومات اللقاحات للمهنيين الصحيين. هل يمكنك طرح سؤال يتعلق باللقاحات أو التطعيم؟ سأكون سعيدًا بمساعدتك في ذلك."
- French: "Je suis spécialisé dans la fourniture d'informations sur les vaccins pour les professionnels de la santé. Pourriez-vous poser une question liée aux vaccins ou à l'immunisation ? Je serais heureux de vous aider avec ça."
3. For simple greetings:
- Respond with a simple formal greeting in the same language as the query.
Question: {query}
Information: {information}
"""
    # Initialize the LLM - using GoogleGenerativeAI instead of ChatGoogleGenerativeAI
    # (the plain text-completion interface, not the chat one).
    llm = GoogleGenerativeAI(
        model=model,
        google_api_key=api_key
    )
    # Generate response using langchain
    response = llm.invoke(prompt)
    return response
def extract_source_ids(response_text):
    """
    Extract unique source IDs from inline citations in the response.

    Handles the citation formats the prompt allows:
    - Standard format: [Source ID]
    - Adjacent citations: [Source ID1][Source ID2]
    - Comma-separated in one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.

    Returns:
        list of str: Unique source IDs in order of first appearance. The
        order matters: format_response_with_sequential_citations numbers
        citations by position in this list.
    """
    # Normalize adjacent citations "[A] [B]" / "[A][B]" into a single
    # comma-separated citation "[A, B]" so one regex handles every format.
    consolidated_text = re.sub(r'\][\s]*\[', '][', response_text)
    consolidated_text = re.sub(r'\]\[', ', ', consolidated_text)

    # Grab the contents of every bracketed citation.
    inline_citations = re.findall(r'\[([^\[\]]+)\]', consolidated_text)
    if not inline_citations:
        print("Warning: No source IDs found in the response text.")
        return []

    # Each citation may contain several comma-separated IDs.
    all_ids = []
    for citation in inline_citations:
        all_ids.extend(id_str.strip() for id_str in citation.split(','))

    # Deduplicate while preserving first-occurrence order. list(set(...))
    # would scramble the order nondeterministically, making the sequential
    # citation numbers ([1], [2], ...) arbitrary between runs.
    source_ids = list(dict.fromkeys(all_ids))

    if not source_ids:
        print("Warning: No valid source IDs found after filtering.")
        return []
    return source_ids
def format_response_with_sequential_citations(response_text, unique_ids, clean_all_citations=False):
    """
    Format the response text by either:
    - Replacing source IDs with sequential numbers (default)
    - Completely removing all citations (if clean_all_citations=True)

    Handles multiple citation formats:
    - Standard format: [Source ID]
    - Multiple sources in one citation: [Source ID1][Source ID2]
    - Multiple sources in one bracket: [Source ID1, Source ID2]

    Args:
        response_text (str): The generated response text with inline citations.
        unique_ids (list): Unique source IDs; a citation's number is its
            1-based position in this list.
        clean_all_citations (bool): If True, removes all citations completely.
            If False, formats them as numbers.

    Returns:
        str: The formatted response text (returned unchanged when
        unique_ids is empty).
    """
    if not unique_ids:
        return response_text

    formatted_response = response_text
    # Map each source ID to its sequential citation number.
    id_to_number = {source_id: str(i + 1) for i, source_id in enumerate(unique_ids)}

    if clean_all_citations:
        # Strip every bracketed citation.
        formatted_response = re.sub(r'\[[^\[\]]+?\]', '', formatted_response)
        # Tidy the leftover whitespace, but only horizontal runs: collapsing
        # '\s+' would also eat newlines and destroy the Markdown tables and
        # lists the prompt requires the model to produce.
        formatted_response = re.sub(r'[ \t]{2,}', ' ', formatted_response)
        formatted_response = re.sub(r'[ \t]+(?=\n)', '', formatted_response)
    else:
        # Standardize adjacent citations [ID1][ID2] into one bracket [ID1, ID2].
        formatted_response = re.sub(r'\][\s]*\[', '][', formatted_response)
        formatted_response = re.sub(r'\]\[', ', ', formatted_response)

        def replace_citation(match):
            # Rewrite one bracketed citation; unknown IDs are left untouched.
            content = match.group(1)
            if ',' in content:
                ids = [id_str.strip() for id_str in content.split(',')]
                numbers = [id_to_number[id_str] for id_str in ids if id_str in id_to_number]
                if numbers:
                    return f"[{', '.join(numbers)}]"
            elif content in id_to_number:
                return f"[{id_to_number[content]}]"
            return match.group(0)

        formatted_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, formatted_response)

    return formatted_response.strip()
def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory="./data/"):
    """
    Retrieve relevant documents and prepare them for the RAG generation.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object (e.g., returned by prepare_environment_and_retriever).
        chunks_directory (str): Path to the directory containing JSON files.

    Returns:
        tuple: (source_texts_for_rag, retrieved_elements_full)
    """
    # Query expansion happens inside the retriever itself.
    docs = expanding_retriever.get_relevant_documents(query)
    wanted_ids = [doc.metadata["element_id"] for doc in docs]

    # Map each unique source filename (e.g. "file.pdf") to its sibling
    # chunk file ("file.json") inside chunks_directory.
    json_paths = []
    for source_name in {doc.metadata["source"] for doc in docs}:
        base_name = os.path.splitext(source_name)[0]
        json_path = os.path.join(chunks_directory, f"{base_name}.json")
        if os.path.exists(json_path):
            json_paths.append(json_path)
        else:
            print(f"Warning: JSON file not found: {json_path}")

    # Load only the chunk files we actually need.
    chunk_records = []
    for json_file in json_paths:
        print(f"Loading: {os.path.basename(json_file)}")
        with open(json_file, "r", encoding="utf-8") as f:
            chunk_records.extend(json.load(f))

    source_retrieved_texts = []
    retrieved_elements_full = []
    for record in chunk_records:
        if record["element_id"] not in wanted_ids:
            continue
        if record.get("type") == "TableElement":
            # Tables carry both plain text and the original HTML markup.
            source_retrieved_texts.append(
                f"[Source ID: {record['element_id']}]\n"
                f"CONTENT:\n{record['text']}\n"
                f"HTML:\n{record['table_text_as_html']}\n\n"
            )
            retrieved_elements_full.append(record)
        else:
            # Non-table chunks expose their constituent sub-elements.
            for sub_element in record.get("elements", []):
                source_retrieved_texts.append(
                    f"[Source ID: {sub_element['element_id']}]\n"
                    f"CONTENT:\n{sub_element['text']}\n\n"
                )
                retrieved_elements_full.append(sub_element)

    return source_retrieved_texts, retrieved_elements_full
def full_rag_pipeline(query, expanding_retriever, chunks_directory="./data/", model="gemini-2.0-flash-exp", clean_all_citations=False):
    """
    Full RAG pipeline from query to RAG response + extracted sources.

    Args:
        query (str): The user's query.
        expanding_retriever: The retriever object.
        chunks_directory (str): Path to the directory containing JSON files.
        model (str): Gemini model.
        clean_all_citations (bool): Whether to remove all citations from response.

    Returns:
        dict: {
            "response": str,
            "cited_elements_json": str,
            "answer_language": str  # one of 'en', 'ar', 'fr'
        }
    """
    source_texts, retrieved_elements = retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory)

    # Step 1: RAG
    response_text = generate_rag_response(query, source_texts, model=model)

    # Step 2: Extract cited sources
    unique_ids = extract_source_ids(response_text)

    # Step 2.1: Format the response text with sequential citations
    response_text = format_response_with_sequential_citations(response_text, unique_ids, clean_all_citations=clean_all_citations)

    # Step 3: Keep only the elements the model actually cited
    cited_elements = [element for element in retrieved_elements if element["element_id"] in unique_ids]
    cited_elements_json = json.dumps(cited_elements, ensure_ascii=False, indent=2)

    return {
        "response": response_text,
        "cited_elements_json": cited_elements_json,
        "answer_language": _detect_answer_language(response_text, query)
    }


def _detect_answer_language(response_text, query):
    """
    Best-effort detection of the answer language, restricted to the
    supported set ('en', 'ar', 'fr').

    Tries the start of the response first, then the query, then falls back
    to 'en'. langdetect can raise (e.g. LangDetectException on text with no
    detectable features), so each attempt is guarded: the original bare
    `except:` also let a second detect() call inside the handler crash the
    whole pipeline.
    """
    supported = ('en', 'ar', 'fr')
    try:
        # Detect from the first 5 words of the response, citations removed.
        snippet = re.sub(r'\[.*?\]', '', " ".join(response_text.split()[:5]))
        language = detect(snippet)
        if language in supported:
            return language
    except Exception:
        pass
    try:
        language = detect(query)
        if language in supported:
            return language
    except Exception:
        pass
    return 'en'