Zeggai Abdellah committed on
Commit
c23c6b4
·
1 Parent(s): 7f51074

update the system

Browse files
chunks.json → data/Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json RENAMED
The diff for this file is too large to render. See raw diff
 
data/Immunization in Practice_WHO_eng_2015.json ADDED
The diff for this file is too large to render. See raw diff
 
prepare_env.py CHANGED
@@ -1,10 +1,16 @@
1
  import json
2
  import os
 
 
3
  from dotenv import load_dotenv
 
4
 
5
  # Load environment variables from .env file
6
  load_dotenv()
 
7
  from langchain_core.documents import Document
 
 
8
  from langchain_community.vectorstores import Chroma
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from langchain_community.retrievers import BM25Retriever
@@ -12,78 +18,278 @@ from langchain.retrievers import EnsembleRetriever
12
  from langchain.retrievers.multi_query import MultiQueryRetriever
13
  from langchain_google_genai import GoogleGenerativeAI
14
 
15
- def prepare_environment_and_retriever(
16
- chunks_path="./chunks.json",
17
- model_name="intfloat/multilingual-e5-base",
18
- collection_name="Guide_2023_e5_multilingual",
19
- persist_directory="chroma_db_multilingual",
20
- k_vector=6,
21
- k_sparse=2,
22
- weights=[0.5, 0.5],
23
- llm_model_name="gemini-2.0-flash"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  ):
25
- # Load the chunks.json
26
- with open(chunks_path, "r", encoding="utf-8") as f:
27
- chunks_data = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
 
 
 
 
 
 
 
 
 
29
  documents = []
 
30
 
31
- for element in chunks_data:
32
- text = element["text"]
33
- metadata = {
34
- "source": element["filename"],
35
- "filetype": element["filetype"],
36
- "element_id": element["element_id"]
37
- }
38
 
39
- if element.get("type") == "TableElement":
40
- metadata["table_text_as_html"] = element["table_text_as_html"]
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- doc = Document(page_content=text, metadata=metadata)
43
- documents.append(doc)
 
 
 
 
44
 
45
- # Create the embedding function
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  embedding_function = HuggingFaceEmbeddings(
47
- model_name=model_name
48
  )
49
 
50
- # Create and persist the vector store
51
- vectorstore = Chroma.from_documents(
52
- documents=documents,
53
- embedding=embedding_function,
54
- collection_name=collection_name,
55
- persist_directory=persist_directory
56
- )
57
- # vectorstore.persist()
58
- print("✅ Stored with multilingual embeddings.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- # Build retrievers
61
- retriever_multilingual = vectorstore.as_retriever(
 
 
 
62
  search_type="similarity",
63
  search_kwargs={"k": k_vector}
64
  )
65
 
 
66
  bm25_retriever = BM25Retriever.from_documents(documents)
67
  bm25_retriever.k = k_sparse
68
 
69
  # Ensemble retriever (combining vector + sparse search)
70
  ensemble_retriever = EnsembleRetriever(
71
- retrievers=[retriever_multilingual, bm25_retriever],
72
- weights=weights
73
  )
 
74
 
75
  # Language model for multi-query expansion
76
- # Using GoogleGenerativeAI instead of ChatGoogleGenerativeAI
77
- llm = GoogleGenerativeAI(
78
- model=llm_model_name,
79
- google_api_key=os.getenv("GOOGLE_API_KEY")
80
- )
81
-
82
- expanding_retriever = MultiQueryRetriever.from_llm(
83
- retriever=ensemble_retriever,
84
- llm=llm
85
- )
86
-
87
- print("✅ Retrieval system ready (vector + sparse + ensemble + multi-query).")
 
 
 
 
88
 
89
- return expanding_retriever # Return the final retriever
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import os
3
+ import glob
4
+ from typing import List, Optional
5
  from dotenv import load_dotenv
6
+ import logging
7
 
8
  # Load environment variables from .env file
9
  load_dotenv()
10
+
11
  from langchain_core.documents import Document
12
+ from langchain_core.output_parsers import BaseOutputParser
13
+ from langchain_core.prompts import PromptTemplate
14
  from langchain_community.vectorstores import Chroma
15
  from langchain_huggingface import HuggingFaceEmbeddings
16
  from langchain_community.retrievers import BM25Retriever
 
18
  from langchain.retrievers.multi_query import MultiQueryRetriever
19
  from langchain_google_genai import GoogleGenerativeAI
20
 
21
+ # Set up logging
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that splits LLM output into a list of query lines.

    Lines that are empty, comment-like (starting with ``#``), or very short
    (5 characters or fewer) are discarded; leading enumeration ("1. ") and
    bullet ("- ") markers are stripped from the surviving lines. The result
    is guaranteed to be non-empty.
    """

    def parse(self, text: str) -> List[str]:
        """Parse the LLM output into a list of queries."""
        try:
            candidates = [raw.strip() for raw in text.strip().split("\n")]
            queries: List[str] = []
            for candidate in candidates:
                # Keep only substantive, non-comment lines.
                if not candidate or candidate.startswith("#") or len(candidate) <= 5:
                    continue
                # Strip leading numbering such as "1. query".
                if candidate[0].isdigit() and ". " in candidate:
                    candidate = candidate.split(". ", 1)[1]
                # Strip leading bullet markers such as "- query".
                elif candidate.startswith("- "):
                    candidate = candidate[2:]
                queries.append(candidate)
            # Fall back to the raw text so callers always get at least one query.
            return queries if queries else [text.strip()]
        except Exception as e:
            logger.warning(f"Error parsing output: {e}. Returning original text.")
            return [text.strip()] if text.strip() else [""]
52
+
53
def create_custom_multi_query_retriever(
    base_retriever,
    llm,
    num_queries: int = 5,
    include_original: bool = True
):
    """Create a MultiQueryRetriever wrapping *base_retriever*.

    Args:
        base_retriever: Retriever invoked for each generated query variation.
        llm: Language model used to generate the query variations.
        num_queries: Desired number of query variations.
            NOTE(review): ``MultiQueryRetriever.from_llm`` does not accept this
            knob, so the value is currently not forwarded — the parameter is
            kept for interface stability. Confirm whether a custom prompt is
            needed to honor it.
        include_original: Whether the original query is searched as well.

    Returns:
        MultiQueryRetriever: Configured retriever.
    """
    return MultiQueryRetriever.from_llm(
        retriever=base_retriever,
        llm=llm,
        include_original=include_original
    )
90
+
91
def validate_environment():
    """Ensure all required environment variables are set.

    Raises:
        ValueError: If any required variable is missing or empty.
    """
    required = ["GOOGLE_API_KEY"]
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        raise ValueError(f"Missing required environment variables: {missing}")
    logger.info("✅ Environment variables validated.")
100
 
101
def load_documents_from_json(chunks_directory: str) -> List[Document]:
    """Load Documents from every ``*.json`` chunk file in *chunks_directory*.

    Each JSON file is expected to contain a list of chunk dicts with a "text"
    field; "filename", "filetype", "element_id" and "type" are copied into
    metadata when present. Elements with empty text are skipped; a malformed
    element or file is logged and skipped rather than aborting the whole load.

    Args:
        chunks_directory: Directory containing the chunk JSON files.

    Returns:
        List[Document]: All successfully parsed documents.

    Raises:
        ValueError: If no JSON files exist in the directory, or no valid
            documents could be loaded from any of them.
    """
    json_files = glob.glob(os.path.join(chunks_directory, "*.json"))

    if not json_files:
        raise ValueError(f"No JSON files found in directory: {chunks_directory}")

    logger.info(f"Found {len(json_files)} JSON files: {[os.path.basename(f) for f in json_files]}")

    documents = []

    for json_file in json_files:
        try:
            logger.info(f"Processing: {os.path.basename(json_file)}")

            with open(json_file, "r", encoding="utf-8") as f:
                chunks_data = json.load(f)

            file_doc_count = 0
            for element in chunks_data:
                try:
                    text = element.get("text", "").strip()
                    if not text:  # Skip elements with no usable text
                        continue

                    metadata = {
                        "source": element.get("filename", "unknown"),
                        "filetype": element.get("filetype", "unknown"),
                        "element_id": element.get("element_id", "unknown"),
                        # Track which JSON file the chunk came from
                        "json_source": os.path.basename(json_file)
                    }

                    # Add table-specific metadata if present.
                    # NOTE(review): table elements deliberately do not get an
                    # "element_type" entry (mirrors the original) — confirm.
                    if element.get("type") == "TableElement" and element.get("table_text_as_html"):
                        metadata["table_text_as_html"] = element["table_text_as_html"]
                    else:
                        metadata["element_type"] = element.get("type", "text")

                    doc = Document(page_content=text, metadata=metadata)
                    documents.append(doc)
                    file_doc_count += 1

                except Exception as e:
                    logger.warning(f"Error processing element in {json_file}: {e}")
                    continue

            logger.info(f"  → Loaded {file_doc_count} documents from {os.path.basename(json_file)}")

        except Exception as e:
            logger.error(f"Error processing file {json_file}: {e}")
            continue

    if not documents:
        raise ValueError("No valid documents were loaded from any JSON files.")

    logger.info(f"✅ Total loaded: {len(documents)} documents from {len(json_files)} JSON files.")
    return documents
161
+
162
def prepare_environment_and_retriever(
    chunks_directory: str = "./data/",
    model_name: str = "intfloat/multilingual-e5-base",
    collection_name: str = "Guide_2023_e5_multilingual",
    persist_directory: str = "chroma_db_multilingual",
    k_vector: int = 6,
    k_sparse: int = 2,
    ensemble_weights: Optional[List[float]] = None,
    llm_model_name: str = "gemini-2.0-flash-exp",
    num_query_variations: int = 5,
    include_original_query: bool = True,
    temperature: float = 0.1
):
    """
    Prepare the complete retrieval environment with MultiQueryRetriever.

    Args:
        chunks_directory: Directory containing JSON files with document chunks
        model_name: HuggingFace embedding model name
        collection_name: Chroma collection name
        persist_directory: Directory to persist Chroma database
        k_vector: Number of documents to retrieve from vector search
        k_sparse: Number of documents to retrieve from BM25 search
        ensemble_weights: Weights for ensemble retriever [vector, sparse];
            defaults to [0.5, 0.5] when None
        llm_model_name: Google Gemini model name for query expansion
        num_query_variations: Number of query variations to generate
        include_original_query: Whether to include original query in search
        temperature: LLM temperature for query generation

    Returns:
        MultiQueryRetriever: Configured retriever ready for use (falls back to
        the plain EnsembleRetriever if multi-query setup fails).

    Raises:
        ValueError: If required environment variables or chunk files are missing.
    """
    # None sentinel instead of a mutable [0.5, 0.5] default argument.
    if ensemble_weights is None:
        ensemble_weights = [0.5, 0.5]

    # Validate environment
    validate_environment()

    # Load documents
    documents = load_documents_from_json(chunks_directory)

    # Create embedding function
    logger.info(f"Creating embeddings with model: {model_name}")
    embedding_function = HuggingFaceEmbeddings(
        model_name=model_name,
    )

    # Create or load vector store
    logger.info("Creating/loading vector store...")
    try:
        if os.path.exists(persist_directory):
            # Reuse the persisted collection.
            # NOTE(review): the freshly loaded `documents` are NOT added to an
            # existing store, so new JSON files won't appear in vector search
            # until the directory is rebuilt — confirm this is intended.
            vectorstore = Chroma(
                collection_name=collection_name,
                embedding_function=embedding_function,
                persist_directory=persist_directory
            )
            logger.info("✅ Loaded existing vector store.")
        else:
            # Create new vectorstore
            vectorstore = Chroma.from_documents(
                documents=documents,
                embedding=embedding_function,
                collection_name=collection_name,
                persist_directory=persist_directory
            )
            logger.info("✅ Created new vector store with multilingual embeddings.")
    except Exception as e:
        logger.warning(f"Error with persistent storage: {e}. Creating in-memory store.")
        vectorstore = Chroma.from_documents(
            documents=documents,
            embedding=embedding_function,
            collection_name=collection_name
        )

    # Create base retrievers
    logger.info("Setting up retrievers...")

    # Vector retriever
    vector_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": k_vector}
    )

    # BM25 (sparse) retriever — always built from the freshly loaded documents
    bm25_retriever = BM25Retriever.from_documents(documents)
    bm25_retriever.k = k_sparse

    # Ensemble retriever (combining vector + sparse search)
    ensemble_retriever = EnsembleRetriever(
        retrievers=[vector_retriever, bm25_retriever],
        weights=ensemble_weights
    )
    logger.info(f"✅ Ensemble retriever created with weights: {ensemble_weights}")

    # Language model for multi-query expansion
    logger.info(f"Initializing LLM: {llm_model_name}")
    try:
        llm = GoogleGenerativeAI(
            model=llm_model_name,
            google_api_key=os.getenv("GOOGLE_API_KEY"),
            temperature=temperature,
            max_output_tokens=1000  # Reasonable limit for query generation
        )

        # Smoke-test the LLM connection; the response itself is unused.
        llm.invoke("Generate a simple test query about artificial intelligence.")
        logger.info("✅ LLM connection verified.")

    except Exception as e:
        logger.error(f"Error initializing LLM: {e}")
        raise

    # Create MultiQueryRetriever with custom configuration
    logger.info("Creating MultiQueryRetriever...")
    try:
        multi_query_retriever = create_custom_multi_query_retriever(
            base_retriever=ensemble_retriever,
            llm=llm,
            num_queries=num_query_variations,
            include_original=include_original_query
        )

        logger.info(f"✅ MultiQueryRetriever ready:")
        logger.info(f"   - Vector search: top-{k_vector}")
        logger.info(f"   - Sparse search: top-{k_sparse}")
        logger.info(f"   - Ensemble weights: {ensemble_weights}")
        logger.info(f"   - Query variations: {num_query_variations}")
        logger.info(f"   - Include original: {include_original_query}")

        return multi_query_retriever

    except Exception as e:
        logger.error(f"Error creating MultiQueryRetriever: {e}")
        logger.info("Falling back to ensemble retriever without query expansion.")
        return ensemble_retriever
rag_pipeline.py CHANGED
@@ -1,15 +1,16 @@
1
  import json
2
  import re
 
 
3
  from langchain_google_genai import GoogleGenerativeAI
4
  from langchain_core.documents import Document
5
  from langdetect import detect
6
- import os
7
  from dotenv import load_dotenv
8
 
9
  # Load environment variables from .env file
10
  load_dotenv()
11
 
12
- def generate_rag_response(query, retrieved_documents, model="gemini-2.0-flash"):
13
  """
14
  Perform Retrieval-Augmented Generation (RAG) using Google's Gemini.
15
  Args:
@@ -192,14 +193,14 @@ def format_response_with_sequential_citations(response_text, unique_ids, clean_a
192
 
193
  return formatted_response.strip()
194
 
195
- def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_path="./chunks.json"):
196
  """
197
  Retrieve relevant documents and prepare them for the RAG generation.
198
 
199
  Args:
200
  query (str): The user's query.
201
  expanding_retriever: The retriever object (e.g., returned by prepare_environment_and_retriever).
202
- chunks_path (str): Path to the chunks.json file.
203
 
204
  Returns:
205
  tuple: (source_texts_for_rag, retrieved_elements_full)
@@ -208,23 +209,43 @@ def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_pat
208
  retrieved_docs = expanding_retriever.get_relevant_documents(query)
209
 
210
  retrieved_chunk_ids = [doc.metadata["element_id"] for doc in retrieved_docs]
211
-
212
- # Load all chunks
213
- with open(chunks_path, "r", encoding="utf-8") as f:
214
- chunks_data = json.load(f)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  source_retrieved_texts = []
217
  retrieved_elements_full = []
218
 
219
- for chu in chunks_data:
220
  if chu["element_id"] in retrieved_chunk_ids:
221
  if chu.get("type") == "TableElement":
222
  text = (
223
- f"[Source ID: {chu['elements']['element_id']}]\n"
224
  f"CONTENT:\n{chu['text']}\n"
225
  f"HTML:\n{chu['table_text_as_html']}\n\n"
226
  )
227
  source_retrieved_texts.append(text)
 
228
  else:
229
  for element in chu.get("elements", []):
230
  text = (
@@ -236,15 +257,16 @@ def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_pat
236
 
237
  return source_retrieved_texts, retrieved_elements_full
238
 
239
- def full_rag_pipeline(query, expanding_retriever, chunks_path="./chunks.json", model="gemini-2.0-flash", clean_all_citations=False):
240
  """
241
  Full RAG pipeline from query to RAG response + extracted sources.
242
 
243
  Args:
244
  query (str): The user's query.
245
  expanding_retriever: The retriever object.
246
- chunks_path (str): Path to the chunks.json.
247
  model (str): Gemini model.
 
248
 
249
  Returns:
250
  dict: {
@@ -253,12 +275,11 @@ def full_rag_pipeline(query, expanding_retriever, chunks_path="./chunks.json", m
253
  "answer_language": str
254
  }
255
  """
256
- source_texts, retrieved_elements = retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_path)
257
 
258
  # Step 1: RAG
259
  response_text = generate_rag_response(query, source_texts, model=model)
260
 
261
-
262
  # Step 2: Extract cited sources
263
  unique_ids = extract_source_ids(response_text)
264
 
 
1
  import json
2
  import re
3
+ import glob
4
+ import os
5
  from langchain_google_genai import GoogleGenerativeAI
6
  from langchain_core.documents import Document
7
  from langdetect import detect
 
8
  from dotenv import load_dotenv
9
 
10
  # Load environment variables from .env file
11
  load_dotenv()
12
 
13
+ def generate_rag_response(query, retrieved_documents, model="gemini-2.0-flash-exp"):
14
  """
15
  Perform Retrieval-Augmented Generation (RAG) using Google's Gemini.
16
  Args:
 
193
 
194
  return formatted_response.strip()
195
 
196
+ def retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory="./data/"):
197
  """
198
  Retrieve relevant documents and prepare them for the RAG generation.
199
 
200
  Args:
201
  query (str): The user's query.
202
  expanding_retriever: The retriever object (e.g., returned by prepare_environment_and_retriever).
203
+ chunks_directory (str): Path to the directory containing JSON files.
204
 
205
  Returns:
206
  tuple: (source_texts_for_rag, retrieved_elements_full)
 
209
  retrieved_docs = expanding_retriever.get_relevant_documents(query)
210
 
211
  retrieved_chunk_ids = [doc.metadata["element_id"] for doc in retrieved_docs]
212
+
213
+ # Get unique filenames from retrieved documents
214
+ needed_filenames = set(doc.metadata["source"] for doc in retrieved_docs)
215
+
216
+ # Convert PDF filenames to JSON filenames (e.g., "file.pdf" -> "file.json")
217
+ needed_json_files = []
218
+ for filename in needed_filenames:
219
+ # Remove extension and add .json
220
+ base_name = os.path.splitext(filename)[0]
221
+ json_filename = f"{base_name}.json"
222
+ json_path = os.path.join(chunks_directory, json_filename)
223
+ if os.path.exists(json_path):
224
+ needed_json_files.append(json_path)
225
+ else:
226
+ print(f"Warning: JSON file not found: {json_path}")
227
+
228
+ # Load only the needed JSON files
229
+ all_chunks_data = []
230
+ for json_file in needed_json_files:
231
+ print(f"Loading: {os.path.basename(json_file)}")
232
+ with open(json_file, "r", encoding="utf-8") as f:
233
+ chunks_data = json.load(f)
234
+ all_chunks_data.extend(chunks_data)
235
 
236
  source_retrieved_texts = []
237
  retrieved_elements_full = []
238
 
239
+ for chu in all_chunks_data:
240
  if chu["element_id"] in retrieved_chunk_ids:
241
  if chu.get("type") == "TableElement":
242
  text = (
243
+ f"[Source ID: {chu['element_id']}]\n"
244
  f"CONTENT:\n{chu['text']}\n"
245
  f"HTML:\n{chu['table_text_as_html']}\n\n"
246
  )
247
  source_retrieved_texts.append(text)
248
+ retrieved_elements_full.append(chu)
249
  else:
250
  for element in chu.get("elements", []):
251
  text = (
 
257
 
258
  return source_retrieved_texts, retrieved_elements_full
259
 
260
+ def full_rag_pipeline(query, expanding_retriever, chunks_directory="./data/", model="gemini-2.0-flash-exp", clean_all_citations=False):
261
  """
262
  Full RAG pipeline from query to RAG response + extracted sources.
263
 
264
  Args:
265
  query (str): The user's query.
266
  expanding_retriever: The retriever object.
267
+ chunks_directory (str): Path to the directory containing JSON files.
268
  model (str): Gemini model.
269
+ clean_all_citations (bool): Whether to remove all citations from response.
270
 
271
  Returns:
272
  dict: {
 
275
  "answer_language": str
276
  }
277
  """
278
+ source_texts, retrieved_elements = retrieve_documents_and_prepare_inputs(query, expanding_retriever, chunks_directory)
279
 
280
  # Step 1: RAG
281
  response_text = generate_rag_response(query, source_texts, model=model)
282
 
 
283
  # Step 2: Extract cited sources
284
  unique_ids = extract_source_ids(response_text)
285