Spaces:
Sleeping
Sleeping
citation filtering
Browse files
- .gitignore +2 -1
- utils/generator.py +78 -11
.gitignore
CHANGED
|
@@ -1 +1,2 @@
|
|
| 1 |
-
.DS_Store
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
| 2 |
+
.env
|
utils/generator.py
CHANGED
|
@@ -2,6 +2,7 @@ import logging
|
|
| 2 |
import asyncio
|
| 3 |
import json
|
| 4 |
import ast
|
|
|
|
| 5 |
from typing import List, Dict, Any, Union, Generator, AsyncGenerator
|
| 6 |
from dotenv import load_dotenv
|
| 7 |
|
|
@@ -73,6 +74,51 @@ def get_chat_model():
|
|
| 73 |
# Initialize provider-agnostic chat model
|
| 74 |
chat_model = get_chat_model()
|
| 75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
# ---------------------------------------------------------------------
|
| 77 |
# Context processing - may need further refinement (i.e. to manage other data sources)
|
| 78 |
# ---------------------------------------------------------------------
|
|
@@ -189,14 +235,21 @@ def build_messages(question: str, context: str) -> list:
|
|
| 189 |
Returns:
|
| 190 |
List of LangChain message objects
|
| 191 |
"""
|
| 192 |
-
system_content =
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
|
| 202 |
|
|
@@ -253,9 +306,15 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui
|
|
| 253 |
# Return ChatUI format
|
| 254 |
result = {"answer": answer}
|
| 255 |
if processed_results:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
# Extract sources for ChatUI
|
| 257 |
sources = []
|
| 258 |
-
for result_item in
|
| 259 |
filename = result_item.get('filename', 'Unknown')
|
| 260 |
page = result_item.get('page', 'Unknown')
|
| 261 |
year = result_item.get('year', 'Unknown')
|
|
@@ -349,8 +408,10 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
|
|
| 349 |
try:
|
| 350 |
messages = build_messages(query, formatted_context)
|
| 351 |
|
| 352 |
-
# Stream the text response
|
|
|
|
| 353 |
async for chunk in _call_llm_streaming(messages):
|
|
|
|
| 354 |
if chatui_format:
|
| 355 |
yield {"event": "data", "data": chunk}
|
| 356 |
else:
|
|
@@ -358,8 +419,14 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
|
|
| 358 |
|
| 359 |
# Send sources at the end if available and in ChatUI format
|
| 360 |
if chatui_format and processed_results:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 361 |
sources = []
|
| 362 |
-
for result in
|
| 363 |
filename = result.get('filename', 'Unknown')
|
| 364 |
page = result.get('page', 'Unknown')
|
| 365 |
year = result.get('year', 'Unknown')
|
|
|
|
| 2 |
import asyncio
|
| 3 |
import json
|
| 4 |
import ast
|
| 5 |
+
import re
|
| 6 |
from typing import List, Dict, Any, Union, Generator, AsyncGenerator
|
| 7 |
from dotenv import load_dotenv
|
| 8 |
|
|
|
|
| 74 |
# Initialize provider-agnostic chat model
|
| 75 |
chat_model = get_chat_model()
|
| 76 |
|
| 77 |
+
# ---------------------------------------------------------------------
|
| 78 |
+
# Citation parsing and source filtering
|
| 79 |
+
# ---------------------------------------------------------------------
|
| 80 |
+
def parse_citations_from_response(response: str) -> List[int]:
    """
    Parse citation numbers from the generated response.

    Recognises single bracketed citations such as ``[1]``, adjacent ones such
    as ``[1][2]``, and comma-separated groups such as ``[1, 2]``.

    Args:
        response: The generated response text. ``None`` or an empty string is
            treated as containing no citations.

    Returns:
        Sorted list of unique citation numbers found in the response
    """
    # Guard against a missing/empty answer so re.findall never sees None.
    if not response:
        return []

    # One bracket pair holding one or more comma-separated integers.
    # Plain "[1]" still matches exactly as before; "[1, 2]" is matched too,
    # so sources cited in a grouped bracket are not silently dropped.
    citation_pattern = r'\[(\d+(?:\s*,\s*\d+)*)\]'

    citation_numbers = set()
    for group in re.findall(citation_pattern, response):
        # int() tolerates the surrounding whitespace left by the split.
        citation_numbers.update(int(number) for number in group.split(','))

    return sorted(citation_numbers)
|
| 97 |
+
|
| 98 |
+
def filter_sources_by_citations(processed_results: List[Dict[str, Any]], cited_numbers: List[int]) -> List[Dict[str, Any]]:
    """
    Filter sources to only include those that were cited in the response.

    Citation numbers are 1-indexed positions into ``processed_results``.
    Numbers that fall outside the list (including 0 and negatives) are
    ignored, and a number that appears more than once contributes its
    source only once, at the position of its first occurrence.

    Args:
        processed_results: All processed retrieval results
        cited_numbers: List of citation numbers found in the response

    Returns:
        List of sources that were actually cited, in the order of
        ``cited_numbers``; empty if either argument is empty or None
    """
    if not cited_numbers or not processed_results:
        return []

    total = len(processed_results)
    seen = set()
    cited_sources = []
    for citation_num in cited_numbers:
        # Skip repeats so the same source is never listed twice.
        if citation_num in seen:
            continue
        seen.add(citation_num)

        # Convert to 0-indexed for list access; the lower bound check stops
        # citation 0 or a negative number from wrapping around to the end.
        source_index = citation_num - 1
        if 0 <= source_index < total:
            cited_sources.append(processed_results[source_index])

    return cited_sources
|
| 121 |
+
|
| 122 |
# ---------------------------------------------------------------------
|
| 123 |
# Context processing - may need further refinement (i.e. to manage other data sources)
|
| 124 |
# ---------------------------------------------------------------------
|
|
|
|
| 235 |
Returns:
|
| 236 |
List of LangChain message objects
|
| 237 |
"""
|
| 238 |
+
system_content = """
|
| 239 |
+
You are AuditQ&A, an AI Assistant created by Auditors and Data Scientist. \
|
| 240 |
+
You are given a question and extracted passages of the consolidated/departmental/thematic focus audit reports.\
|
| 241 |
+
Provide a clear and structured answer based on the passages/context provided and the guidelines.
|
| 242 |
+
Guidelines:
|
| 243 |
+
- If the passages have useful facts or numbers, use them in your answer.
|
| 244 |
+
- Do not just summarize each passage one by one. Group your summaries to highlight the key parts in the explanation.
|
| 245 |
+
- If it makes sense, use bullet points and lists to make your answers easier to understand.
|
| 246 |
+
- You do not need to use every passage. Only use the ones that help answer the question.
|
| 247 |
+
- Answer the USER question using only the CONTEXT provided.
|
| 248 |
+
- When referencing information from the context, use inline citations in square brackets like [1], [2], etc. to reference the document numbers shown in the context.
|
| 249 |
+
- Use multiple citations when information comes from multiple documents, like [1][2].
|
| 250 |
+
- Do not use the sentence 'Doc x says ...' to say where information came from, but rather just include the citation at the end of the sentence.
|
| 251 |
+
- If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
|
| 252 |
+
"""
|
| 253 |
|
| 254 |
user_content = f"### CONTEXT\n{context}\n\n### USER QUESTION\n{question}"
|
| 255 |
|
|
|
|
| 306 |
# Return ChatUI format
|
| 307 |
result = {"answer": answer}
|
| 308 |
if processed_results:
|
| 309 |
+
# Parse citations from the response
|
| 310 |
+
cited_numbers = parse_citations_from_response(answer)
|
| 311 |
+
|
| 312 |
+
# Filter sources to only include cited ones
|
| 313 |
+
cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
|
| 314 |
+
|
| 315 |
# Extract sources for ChatUI
|
| 316 |
sources = []
|
| 317 |
+
for result_item in cited_sources: # Only cited sources
|
| 318 |
filename = result_item.get('filename', 'Unknown')
|
| 319 |
page = result_item.get('page', 'Unknown')
|
| 320 |
year = result_item.get('year', 'Unknown')
|
|
|
|
| 408 |
try:
|
| 409 |
messages = build_messages(query, formatted_context)
|
| 410 |
|
| 411 |
+
# Stream the text response and accumulate it for citation parsing
|
| 412 |
+
accumulated_response = ""
|
| 413 |
async for chunk in _call_llm_streaming(messages):
|
| 414 |
+
accumulated_response += chunk
|
| 415 |
if chatui_format:
|
| 416 |
yield {"event": "data", "data": chunk}
|
| 417 |
else:
|
|
|
|
| 419 |
|
| 420 |
# Send sources at the end if available and in ChatUI format
|
| 421 |
if chatui_format and processed_results:
|
| 422 |
+
# Parse citations from the complete response
|
| 423 |
+
cited_numbers = parse_citations_from_response(accumulated_response)
|
| 424 |
+
|
| 425 |
+
# Filter sources to only include cited ones
|
| 426 |
+
cited_sources = filter_sources_by_citations(processed_results, cited_numbers)
|
| 427 |
+
|
| 428 |
sources = []
|
| 429 |
+
for result in cited_sources: # Only cited sources
|
| 430 |
filename = result.get('filename', 'Unknown')
|
| 431 |
page = result.get('page', 'Unknown')
|
| 432 |
year = result.get('year', 'Unknown')
|