Zeggai Abdellah commited on
Commit
1817834
·
1 Parent(s): 8355f0c

add number fo the citation

Browse files
Files changed (1) hide show
  1. rag_pipeline.py +99 -42
rag_pipeline.py CHANGED
@@ -1,7 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- RAG Pipeline for vaccine assistant
4
- Handles agent creation and question answering
5
  """
6
 
7
  import json
@@ -46,13 +46,13 @@ def extract_source_ids(response_text):
46
  ids = [id_str.strip() for id_str in citation.split(',')]
47
  all_ids.extend(ids)
48
 
49
- # Get unique source IDs
50
- source_ids = list(set(all_ids))
51
-
52
- # Filter out any non-UUID-like IDs (if needed)
53
- # This is now optional as we're handling various source ID formats
54
- # uuid_pattern = r'^[0-9a-f]{8}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{4}-?[0-9a-f]{12}$'
55
- # source_ids = [source_id for source_id in source_ids if re.match(uuid_pattern, source_id, re.IGNORECASE)]
56
 
57
  if not source_ids:
58
  print("Warning: No valid source IDs found after filtering.")
@@ -61,6 +61,41 @@ def extract_source_ids(response_text):
61
  return source_ids
62
 
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def create_custom_prompt():
65
  """Create custom prompt with medical assistant instructions"""
66
 
@@ -240,9 +275,9 @@ def process_question(agent, question: str) -> str:
240
  print(f"Error processing question: {e}")
241
  return f"Error processing your question: {str(e)}"
242
 
243
- def process_question_with_citations(agent, question: str, chunks_directory="./data/") -> dict:
244
  """
245
- Process a question through the RAG pipeline and extract cited elements.
246
 
247
  Args:
248
  agent: The initialized RAG agent
@@ -251,9 +286,10 @@ def process_question_with_citations(agent, question: str, chunks_directory="./da
251
 
252
  Returns:
253
  dict: {
254
- "response": str,
255
- "cited_elements_json": str,
256
- "unique_ids": list
 
257
  }
258
  """
259
  try:
@@ -261,48 +297,69 @@ def process_question_with_citations(agent, question: str, chunks_directory="./da
261
  response = agent.chat(question)
262
  response_text = response.response
263
 
264
- # Extract source IDs from the response
265
  unique_ids = extract_source_ids(response_text)
266
 
 
 
 
 
 
 
267
  # Load all chunks data to find cited elements
268
  all_chunks_data = []
269
- # the ids is only in the two main files, so we can load them all at once
270
- min_chunks_files = ["Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json", "Immunization in Practice_WHO_eng_2015.json"]
 
271
  for json_file in min_chunks_files:
272
- if json_file.endswith('.json'):
273
- json_path = os.path.join(chunks_directory, json_file)
274
- try:
275
- with open(json_path, "r", encoding="utf-8") as f:
276
- chunks_data = json.load(f)
277
- all_chunks_data.extend(chunks_data)
278
- except Exception as e:
279
- print(f"Warning: Could not load {json_file}: {e}")
280
 
281
- # Get only the cited elements
282
- cited_elements = []
283
- for element in all_chunks_data:
284
- if element.get("type") =='TableElement':
285
- if element.get("element_id") in unique_ids:
286
- cited_elements.append(element['elements'])
287
- else :
288
- if "elements" in element:
289
- for nested_element in element["elements"]:
290
- if nested_element.get("element_id") in unique_ids:
291
- cited_elements.append(nested_element)
 
 
 
 
 
 
292
 
293
  # Convert to JSON
294
- cited_elements_json = json.dumps(cited_elements, ensure_ascii=False, indent=2)
295
 
296
  return {
297
- "response": response_text,
298
  "cited_elements_json": cited_elements_json,
299
- "unique_ids": unique_ids
 
300
  }
301
 
302
  except Exception as e:
303
  print(f"Error processing question: {e}")
304
  return {
305
- "response": response_text,
306
  "cited_elements_json": "[]",
307
- "unique_ids": []
308
- }
 
 
 
 
 
 
 
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ Enhanced RAG Pipeline for vaccine assistant
4
+ Handles agent creation and question answering with sequential citation numbering
5
  """
6
 
7
  import json
 
46
  ids = [id_str.strip() for id_str in citation.split(',')]
47
  all_ids.extend(ids)
48
 
49
+ # Get unique source IDs while preserving order
50
+ seen = set()
51
+ source_ids = []
52
+ for id_str in all_ids:
53
+ if id_str not in seen:
54
+ seen.add(id_str)
55
+ source_ids.append(id_str)
56
 
57
  if not source_ids:
58
  print("Warning: No valid source IDs found after filtering.")
 
61
  return source_ids
62
 
63
 
64
+ def convert_citations_to_sequential(response_text, source_id_to_number_map):
65
+ """
66
+ Convert source IDs in response text to sequential numbers.
67
+
68
+ Args:
69
+ response_text (str): The response text with source ID citations
70
+ source_id_to_number_map (dict): Mapping from source IDs to sequential numbers
71
+
72
+ Returns:
73
+ str: Response text with sequential number citations
74
+ """
75
+ def replace_citation(match):
76
+ citation_content = match.group(1)
77
+ # Handle multiple IDs in one citation (comma-separated)
78
+ ids = [id_str.strip() for id_str in citation_content.split(',')]
79
+
80
+ # Convert each ID to its sequential number
81
+ numbers = []
82
+ for id_str in ids:
83
+ if id_str in source_id_to_number_map:
84
+ numbers.append(str(source_id_to_number_map[id_str]))
85
+
86
+ # Return the formatted citation with sequential numbers
87
+ if len(numbers) == 1:
88
+ return f"[{numbers[0]}]"
89
+ elif len(numbers) > 1:
90
+ return f"[{','.join(numbers)}]"
91
+ else:
92
+ return match.group(0) # Return original if no mapping found
93
+
94
+ # Replace all citations in the text
95
+ sequential_response = re.sub(r'\[([^\[\]]+)\]', replace_citation, response_text)
96
+ return sequential_response
97
+
98
+
99
  def create_custom_prompt():
100
  """Create custom prompt with medical assistant instructions"""
101
 
 
275
  print(f"Error processing question: {e}")
276
  return f"Error processing your question: {str(e)}"
277
 
278
+ def process_question_with_sequential_citations(agent, question: str, chunks_directory="./data/") -> dict:
279
  """
280
+ Process a question through the RAG pipeline and return response with sequential citation numbers.
281
 
282
  Args:
283
  agent: The initialized RAG agent
 
286
 
287
  Returns:
288
  dict: {
289
+ "response": str, # Response with sequential citation numbers [1], [2], etc.
290
+ "cited_elements_json": str, # JSON array of cited elements in order
291
+ "unique_ids": list, # Original source IDs in order
292
+ "citation_mapping": dict # Mapping from source ID to citation number
293
  }
294
  """
295
  try:
 
297
  response = agent.chat(question)
298
  response_text = response.response
299
 
300
+ # Extract source IDs from the response (preserving order)
301
  unique_ids = extract_source_ids(response_text)
302
 
303
+ # Create mapping from source ID to sequential number
304
+ source_id_to_number = {source_id: i + 1 for i, source_id in enumerate(unique_ids)}
305
+
306
+ # Convert citations to sequential numbers
307
+ sequential_response = convert_citations_to_sequential(response_text, source_id_to_number)
308
+
309
  # Load all chunks data to find cited elements
310
  all_chunks_data = []
311
+ min_chunks_files = ["Guide-pratique-de-mise-en-oeuvre-du-calendrier-national-de-vaccination-2023.json",
312
+ "Immunization_in_Practice_WHO_eng_2015.json"]
313
+
314
  for json_file in min_chunks_files:
315
+ json_path = os.path.join(chunks_directory, json_file)
316
+ try:
317
+ with open(json_path, "r", encoding="utf-8") as f:
318
+ chunks_data = json.load(f)
319
+ all_chunks_data.extend(chunks_data)
320
+ except Exception as e:
321
+ print(f"Warning: Could not load {json_file}: {e}")
 
322
 
323
+ # Get cited elements in the same order as the sequential citations
324
+ cited_elements_ordered = []
325
+ for source_id in unique_ids: # This preserves the order
326
+ for element in all_chunks_data:
327
+ if element.get("type") == 'TableElement':
328
+ if element.get("element_id") == source_id:
329
+ cited_elements_ordered.append(element)
330
+ break
331
+ else:
332
+ if "elements" in element:
333
+ for nested_element in element["elements"]:
334
+ if nested_element.get("element_id") == source_id:
335
+ cited_elements_ordered.append(nested_element)
336
+ break
337
+ else:
338
+ continue
339
+ break
340
 
341
  # Convert to JSON
342
+ cited_elements_json = json.dumps(cited_elements_ordered, ensure_ascii=False, indent=2)
343
 
344
  return {
345
+ "response": sequential_response,
346
  "cited_elements_json": cited_elements_json,
347
+ "unique_ids": unique_ids,
348
+ "citation_mapping": source_id_to_number
349
  }
350
 
351
  except Exception as e:
352
  print(f"Error processing question: {e}")
353
  return {
354
+ "response": response_text if 'response_text' in locals() else "Error occurred",
355
  "cited_elements_json": "[]",
356
+ "unique_ids": [],
357
+ "citation_mapping": {}
358
+ }
359
+
360
+ def process_question_with_citations(agent, question: str, chunks_directory="./data/") -> dict:
361
+ """
362
+ Legacy function - maintained for backward compatibility.
363
+ Now calls the new sequential citation function.
364
+ """
365
+ return process_question_with_sequential_citations(agent, question, chunks_directory)