mtyrrell committed on
Commit
d049b68
·
1 Parent(s): f852f01

citation filtering

Browse files
Files changed (2) hide show
  1. .gitignore +2 -1
  2. utils/generator.py +78 -11
.gitignore CHANGED
@@ -1 +1,2 @@
1
- .DS_Store
 
 
1
+ .DS_Store
2
+ .env
utils/generator.py CHANGED
@@ -2,6 +2,7 @@ import logging
2
  import asyncio
3
  import json
4
  import ast
 
5
  from typing import List, Dict, Any, Union, Generator, AsyncGenerator
6
  from dotenv import load_dotenv
7
 
@@ -73,6 +74,51 @@ def get_chat_model():
73
  # Initialize provider-agnostic chat model
74
  chat_model = get_chat_model()
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # ---------------------------------------------------------------------
77
  # Context processing - may need further refinement (i.e. to manage other data sources)
78
  # ---------------------------------------------------------------------
@@ -189,14 +235,21 @@ def build_messages(question: str, context: str) -> list:
189
  Returns:
190
  List of LangChain message objects
191
  """
192
- system_content = (
193
- "You are an expert assistant. Answer the USER question using only the "
194
- "CONTEXT provided. When referencing information from the context, use inline "
195
- "citations in square brackets like [1], [2], etc. to reference the document "
196
- "numbers shown in the context. Use multiple citations when information comes "
197
- "from multiple documents, like [1][2]. If the context is insufficient, say "
198
- "'I don't know.'"
199
- )
 
 
 
 
 
 
 
200
 
201
  user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
202
 
@@ -253,9 +306,15 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui
253
  # Return ChatUI format
254
  result = {"answer": answer}
255
  if processed_results:
 
 
 
 
 
 
256
  # Extract sources for ChatUI
257
  sources = []
258
- for result_item in processed_results:
259
  filename = result_item.get('filename', 'Unknown')
260
  page = result_item.get('page', 'Unknown')
261
  year = result_item.get('year', 'Unknown')
@@ -349,8 +408,10 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
349
  try:
350
  messages = build_messages(query, formatted_context)
351
 
352
- # Stream the text response
 
353
  async for chunk in _call_llm_streaming(messages):
 
354
  if chatui_format:
355
  yield {"event": "data", "data": chunk}
356
  else:
@@ -358,8 +419,14 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
358
 
359
  # Send sources at the end if available and in ChatUI format
360
  if chatui_format and processed_results:
 
 
 
 
 
 
361
  sources = []
362
- for result in processed_results:
363
  filename = result.get('filename', 'Unknown')
364
  page = result.get('page', 'Unknown')
365
  year = result.get('year', 'Unknown')
 
2
  import asyncio
3
  import json
4
  import ast
5
+ import re
6
  from typing import List, Dict, Any, Union, Generator, AsyncGenerator
7
  from dotenv import load_dotenv
8
 
 
74
  # Initialize provider-agnostic chat model
75
  chat_model = get_chat_model()
76
 
77
+ # ---------------------------------------------------------------------
78
+ # Citation parsing and source filtering
79
+ # ---------------------------------------------------------------------
80
def parse_citations_from_response(response: str) -> List[int]:
    """
    Parse citation numbers from the generated response.

    Scans for inline citations written as bracketed integers ([1], [2],
    [1][2], ...) and returns the distinct numbers in ascending order.

    Args:
        response: The generated response text.

    Returns:
        Sorted list of unique citation numbers found in the response.
    """
    # Bracketed integers such as [1]; adjacent citations like [1][2]
    # match individually, one group per bracket pair.
    citation_pattern = r'\[(\d+)\]'
    matches = re.findall(citation_pattern, response)

    # De-duplicate via a set comprehension; sorted() consumes the set
    # directly, so no intermediate list() is needed.
    return sorted({int(match) for match in matches})
97
+
98
def filter_sources_by_citations(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
    """
    Filter sources to only include those that were cited in the response.

    Citation numbers are 1-indexed positions into processed_results;
    numbers outside the valid range (e.g. hallucinated citations that
    point past the retrieved results) are silently dropped.

    Args:
        processed_results: All processed retrieval results.
        cited_numbers: Citation numbers found in the response.

    Returns:
        The cited sources, in the order the citation numbers are given.
        Empty when cited_numbers is empty or nothing is in range.
    """
    # A comprehension replaces the manual append loop; the bounds check
    # (1-indexed: 1 <= num <= len) is equivalent to the original
    # 0 <= num - 1 < len(processed_results) guard.
    return [
        processed_results[num - 1]
        for num in cited_numbers
        if 1 <= num <= len(processed_results)
    ]
121
+
122
  # ---------------------------------------------------------------------
123
  # Context processing - may need further refinement (i.e. to manage other data sources)
124
  # ---------------------------------------------------------------------
 
235
  Returns:
236
  List of LangChain message objects
237
  """
238
+ system_content = """
239
+ You are AuditQ&A, an AI Assistant created by Auditors and Data Scientist. \
240
+ You are given a question and extracted passages of the consolidated/departmental/thematic focus audit reports.\
241
+ Provide a clear and structured answer based on the passages/context provided and the guidelines.
242
+ Guidelines:
243
+ - If the passages have useful facts or numbers, use them in your answer.
244
+ - Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
245
+ - If it makes sense, use bullet points and lists to make your answers easier to understand.
246
+ - You do not need to use every passage. Only use the ones that help answer the question.
247
+ - Answer the USER question using only the CONTEXT provided.
248
+ - When referencing information from the context, use inline citations in square brackets like [1], [2], etc. to reference the document numbers shown in the context.
249
+ - Use multiple citations when information comes from multiple documents, like [1][2].
250
+ - Do not use the sentence 'Doc x says ...' to say where information came from, but rather just include the citation at the end of the sentence.
251
+ - If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
252
+ """
253
 
254
  user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
255
 
 
306
  # Return ChatUI format
307
  result = {"answer": answer}
308
  if processed_results:
309
+ # Parse citations from the response
310
+ cited_numbers = parse_citations_from_response(answer)
311
+
312
+ # Filter sources to only include cited ones
313
+ cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
314
+
315
  # Extract sources for ChatUI
316
  sources = []
317
+ for result_item in cited_sources: # Only cited sources
318
  filename = result_item.get('filename', 'Unknown')
319
  page = result_item.get('page', 'Unknown')
320
  year = result_item.get('year', 'Unknown')
 
408
  try:
409
  messages = build_messages(query, formatted_context)
410
 
411
+ # Stream the text response and accumulate it for citation parsing
412
+ accumulated_response = ""
413
  async for chunk in _call_llm_streaming(messages):
414
+ accumulated_response += chunk
415
  if chatui_format:
416
  yield {"event": "data", "data": chunk}
417
  else:
 
419
 
420
  # Send sources at the end if available and in ChatUI format
421
  if chatui_format and processed_results:
422
+ # Parse citations from the complete response
423
+ cited_numbers = parse_citations_from_response(accumulated_response)
424
+
425
+ # Filter sources to only include cited ones
426
+ cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
427
+
428
  sources = []
429
+ for result in cited_sources: # Only cited sources
430
  filename = result.get('filename', 'Unknown')
431
  page = result.get('page', 'Unknown')
432
  year = result.get('year', 'Unknown')