mtyrrell commited on
Commit
f2a3674
·
1 Parent(s): 4a1d809

added validation guardrails

Browse files
Files changed (1) hide show
  1. utils/generator.py +41 -1
utils/generator.py CHANGED
@@ -72,6 +72,33 @@ def _extract_sources(processed_results: List[Dict[str, Any]], cited_numbers: Lis
72
 
73
  return cited_sources
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def _process_context(context: Union[str, List[Dict[str, Any]]]) -> tuple[str, List[Dict[str, Any]]]:
76
  """Process context and return formatted context string and processed results"""
77
  processed_results = []
@@ -138,6 +165,9 @@ CITATION FORMAT (CRITICAL):
138
  - CORRECT: "Revenue increased by 15% [3]."
139
  - INCORRECT: "(Document 3)", "(Doc 3)", "Document 3 states", "according to document 3"
140
  - NEVER use phrases like "Doc x says" or "(Document x)" - ONLY use [x] format.
 
 
 
141
 
142
  - If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
143
  """
@@ -200,6 +230,11 @@ async def generate(query: str, context: Union[str, List[Dict[str, Any]]], chatui
200
  messages = _build_messages(query, formatted_context)
201
  answer = await _call_llm(messages)
202
 
 
 
 
 
 
203
  if chatui_format:
204
  result = {"answer": answer}
205
  if processed_results:
@@ -238,9 +273,14 @@ async def generate_streaming(query: str, context: Union[str, List[Dict[str, Any]
238
  else:
239
  yield chunk
240
 
 
 
 
 
 
241
  # Send sources at the end if available and in ChatUI format
242
  if chatui_format and processed_results:
243
- cited_numbers = _parse_citations(accumulated_response)
244
  cited_sources = _extract_sources(processed_results, cited_numbers)
245
  sources = _create_sources_list(cited_sources)
246
  yield {"event": "sources", "data": {"sources": sources}}
 
72
 
73
  return cited_sources
74
 
75
+ def normalize_citations(response: str) -> str:
76
+ """Convert non-compliant citation formats to [x] format"""
77
+ # Convert (Document X) to [X]
78
+ response = re.sub(r'\(Document\s+(\d+)\)', r'[\1]', response, flags=re.IGNORECASE)
79
+ # Convert (Doc X) to [X]
80
+ response = re.sub(r'\(Doc\s+(\d+)\)', r'[\1]', response, flags=re.IGNORECASE)
81
+ # Convert "Document X says" to [X]
82
+ response = re.sub(r'Document\s+(\d+)\s+(?:says|states|mentions)', r'[\1]', response, flags=re.IGNORECASE)
83
+ return response
84
+
85
+ def clean_response(response: str) -> str:
86
+ """Remove unwanted reference sections"""
87
+ # Split by common reference section headers
88
+ patterns = [
89
+ r'\n\s*References?\s*:',
90
+ r'\n\s*Sources?\s*:',
91
+ r'\n\s*Bibliography\s*:',
92
+ r'\n\s*Citations?\s*:',
93
+ ]
94
+
95
+ for pattern in patterns:
96
+ if re.search(pattern, response, re.IGNORECASE):
97
+ response = re.split(pattern, response, flags=re.IGNORECASE)[0]
98
+ break
99
+
100
+ return response.strip()
101
+
102
  def _process_context(context: Union[str, List[Dict[str, Any]]]) -> tuple[str, List[Dict[str, Any]]]:
103
  """Process context and return formatted context string and processed results"""
104
  processed_results = []
 
165
  - CORRECT: "Revenue increased by 15% [3]."
166
  - INCORRECT: "(Document 3)", "(Doc 3)", "Document 3 states", "according to document 3"
167
  - NEVER use phrases like "Doc x says" or "(Document x)" - ONLY use [x] format.
168
+ - DO NOT add a "References" section at the end of your response.
169
+ - DO NOT list out the full document names, page numbers, or years at the end.
170
+ - Your response should END after your answer - no bibliography, no references list, no sources section.
171
 
172
  - If the context is insufficient, say "I don't have sufficient information to answer the question. Please try rephrasing your query."
173
  """
 
230
  messages = _build_messages(query, formatted_context)
231
  answer = await _call_llm(messages)
232
 
233
+ # Normalize citations to ensure proper format
234
+ answer = normalize_citations(answer)
235
+ # Clean response to remove unwanted reference sections
236
+ answer = clean_response(answer)
237
+
238
  if chatui_format:
239
  result = {"answer": answer}
240
  if processed_results:
 
273
  else:
274
  yield chunk
275
 
276
+ # Normalize citations in the complete response
277
+ normalized_response = normalize_citations(accumulated_response)
278
+ # Clean response to remove unwanted reference sections
279
+ cleaned_response = clean_response(normalized_response)
280
+
281
  # Send sources at the end if available and in ChatUI format
282
  if chatui_format and processed_results:
283
+ cited_numbers = _parse_citations(cleaned_response)
284
  cited_sources = _extract_sources(processed_results, cited_numbers)
285
  sources = _create_sources_list(cited_sources)
286
  yield {"event": "sources", "data": {"sources": sources}}