Isateles commited on
Commit
a7b80a9
·
1 Parent(s): 394d24e

Update GAIA agent-fixed extract answer

Browse files
Files changed (1) hide show
  1. app.py +62 -131
app.py CHANGED
@@ -71,49 +71,44 @@ def setup_llm():
71
 
72
 
73
  def extract_final_answer(response_text: str) -> str:
74
- """Extract answer aligned with GAIA scoring rules"""
 
 
 
75
 
76
  # Look for FINAL ANSWER pattern
77
  match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE | re.DOTALL)
78
 
79
  if not match:
80
- # Fallback: look for answer at the end of response
81
- lines = response_text.strip().split('\n')
82
- if lines:
83
- # Check if last line looks like an answer
84
- last_line = lines[-1].strip()
85
- if len(last_line) < 100 and not last_line.startswith(('I', 'The', 'To', 'Based')):
86
- answer = last_line
87
- else:
88
- logger.warning("No FINAL ANSWER found")
89
- return ""
90
- else:
91
- return ""
92
- else:
93
- answer = match.group(1).strip()
94
 
95
- # Remove any trailing punctuation that's not part of the answer
96
- answer = answer.rstrip('.')
 
 
 
 
 
 
 
 
 
 
 
97
 
98
  # Clean for GAIA scoring
99
 
100
- # 1. Handle numbers with more precision
101
  if re.match(r'^[\d\s.,\-+e]+$', answer):
102
- # Remove all formatting
103
  cleaned = answer.replace(',', '').replace(' ', '')
104
  try:
105
- # Try to parse as float
106
  num = float(cleaned)
107
- # Return integer if whole number, otherwise keep precision
108
- if num.is_integer():
109
- return str(int(num))
110
- else:
111
- # Keep original precision, don't round
112
- return str(num)
113
  except:
114
  pass
115
 
116
- # 2. Handle percentages (remove % sign)
117
  if answer.endswith('%'):
118
  answer = answer[:-1].strip()
119
  try:
@@ -122,43 +117,30 @@ def extract_final_answer(response_text: str) -> str:
122
  except:
123
  pass
124
 
125
- # 3. Lists: clean and standardize
126
- if ',' in answer or ' and ' in answer.lower():
127
- # Split on commas and 'and'
128
- parts = re.split(r',|\s+and\s+', answer)
129
- cleaned_parts = []
130
-
131
- for part in parts:
132
- part = part.strip()
133
- if not part:
134
- continue
135
-
136
- # Try to parse as number
137
- try:
138
- num = float(part.replace('$', '').replace('%', '').replace(',', ''))
139
- cleaned_parts.append(str(int(num)) if num.is_integer() else str(num))
140
- except:
141
- # Remove articles from strings
142
- words = part.split()
143
- if words and words[0].lower() in ['the', 'a', 'an']:
144
- cleaned_parts.append(' '.join(words[1:]))
145
- else:
146
- cleaned_parts.append(part)
147
-
148
- return ', '.join(cleaned_parts)
149
-
150
- # 4. Yes/No answers
151
  if answer.lower() in ['yes', 'no']:
152
  return answer.lower()
153
 
154
- # 5. Single words/phrases: remove articles
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  words = answer.split()
156
  if words and words[0].lower() in ['the', 'a', 'an']:
157
  return ' '.join(words[1:])
158
 
159
  return answer
160
 
161
-
162
  class GAIAAgent:
163
  """GAIA RAG Agent using LlamaIndex AgentWorkflow"""
164
 
@@ -197,119 +179,68 @@ class GAIAAgent:
197
 
198
  import warnings
199
  warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Event loop is closed.*")
200
-
201
-
202
  try:
203
- # Create new event loop for async operations
204
  loop = asyncio.new_event_loop()
205
  asyncio.set_event_loop(loop)
206
 
207
  try:
208
  async def run_agent():
209
- # Track what happened during execution
210
- tool_calls = []
211
- response_chunks = []
212
-
213
  try:
214
- # Start the agent workflow
215
  handler = self.agent.run(user_msg=question)
216
 
217
- # IMPORTANT: Process events WITHOUT consuming them
218
- # We need to collect BOTH tool usage AND response content
219
- from llama_index.core.agent.workflow import ToolCallResult
220
-
221
- # Stream events and collect information
222
- async for event in handler.stream_events():
223
- # Log tool usage
224
- if isinstance(event, ToolCallResult):
225
- tool_info = f"{event.tool_name}: {str(event.result)[:100]}..."
226
- tool_calls.append(tool_info)
227
- logger.info(f"Tool used: {tool_info}")
228
-
229
- # Also collect any text responses
230
- # Different event types might have content in different attributes
231
- if hasattr(event, 'delta'):
232
- response_chunks.append(str(event.delta))
233
- elif hasattr(event, 'content'):
234
- response_chunks.append(str(event.content))
235
- elif hasattr(event, 'response'):
236
- response_chunks.append(str(event.response))
237
-
238
- # Get the final result after streaming
239
  result = await handler
240
 
241
- # Extract the final response text
242
- # Priority: accumulated chunks > result.response > str(result)
243
- if response_chunks:
244
- response_text = ''.join(response_chunks)
245
- elif hasattr(result, 'response'):
246
- response_text = str(result.response)
 
 
 
 
 
 
 
 
 
 
247
  else:
248
  response_text = str(result)
249
 
250
- # Log what tools were used for debugging
251
- if tool_calls:
252
- logger.info(f"Tools used in this query: {', '.join(set(tool_calls))}")
253
-
254
- # CRITICAL: Check if we got a meaningful response
255
- # This prevents infinite loops
256
- if not response_text or len(response_text.strip()) < 10:
257
- logger.warning("Got empty or too short response from agent")
258
- # Return a fallback response
259
- return "FINAL ANSWER: Unable to determine answer"
260
 
261
  return response_text
262
 
263
- except asyncio.TimeoutError:
264
- # Prevent infinite waiting
265
- logger.error("Agent timeout - preventing infinite loop")
266
- return "FINAL ANSWER: Request timeout"
267
-
268
  except Exception as e:
269
  logger.error(f"Agent execution error: {e}")
270
- # Return structured error response
271
- return f"FINAL ANSWER: Error occurred"
 
272
 
273
- # Run with timeout to prevent infinite loops
274
  response_text = loop.run_until_complete(
275
- asyncio.wait_for(run_agent(), timeout=120) # 2 minute timeout
276
  )
277
 
278
  # Extract clean answer
279
  clean_answer = extract_final_answer(response_text)
280
 
281
- # VALIDATION: Ensure we have a valid answer
282
- if not clean_answer:
283
- logger.warning("No answer extracted, using fallback")
284
- # Try to extract any number or short phrase from response
285
- # This prevents returning empty string to GAIA
286
- numbers = re.findall(r'\b\d+\.?\d*\b', response_text)
287
- if numbers:
288
- clean_answer = numbers[-1] # Use last number found
289
- else:
290
- # Look for any short phrase that could be an answer
291
- sentences = response_text.split('.')
292
- for sent in reversed(sentences):
293
- sent = sent.strip()
294
- if 0 < len(sent) < 50 and not sent.startswith(('I', 'The', 'To')):
295
- clean_answer = sent
296
- break
297
-
298
  logger.info(f"Full response preview: {response_text[:200]}...")
299
  logger.info(f"Extracted answer: '{clean_answer}'")
300
 
301
  return clean_answer
302
 
303
  finally:
304
- # Always close the loop
305
  loop.close()
306
 
307
  except Exception as e:
308
  logger.error(f"Error processing question: {e}")
309
- # Never return empty string to GAIA - always return something
310
- return "0" # Safe fallback for math questions
311
 
312
-
313
  def run_and_submit_all(profile: gr.OAuthProfile | None):
314
  """Run GAIA evaluation following course template structure"""
315
 
 
71
 
72
 
73
  def extract_final_answer(response_text: str) -> str:
74
+ """Extract answer aligned with GAIA scoring rules - FIXED VERSION"""
75
+
76
+ # First, remove any "assistant:" prefix that might have been added
77
+ response_text = re.sub(r'^assistant:\s*', '', response_text, flags=re.IGNORECASE)
78
 
79
  # Look for FINAL ANSWER pattern
80
  match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE | re.DOTALL)
81
 
82
  if not match:
83
+ logger.warning("No FINAL ANSWER found in response")
84
+ return ""
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ answer = match.group(1).strip()
87
+
88
+ # CRITICAL: Stop processing if we hit "assistant:" or any reasoning text
89
+ if 'assistant:' in answer:
90
+ answer = answer.split('assistant:')[0].strip()
91
+
92
+ # Remove any trailing explanatory text (usually starts with lowercase after answer)
93
+ sentences = answer.split('.')
94
+ if len(sentences) > 1:
95
+ # Check if second sentence starts with lowercase (indicates explanation)
96
+ first_sentence = sentences[0].strip()
97
+ if first_sentence and (not sentences[1].strip() or sentences[1].strip()[0].islower()):
98
+ answer = first_sentence
99
 
100
  # Clean for GAIA scoring
101
 
102
+ # 1. Handle pure numbers
103
  if re.match(r'^[\d\s.,\-+e]+$', answer):
 
104
  cleaned = answer.replace(',', '').replace(' ', '')
105
  try:
 
106
  num = float(cleaned)
107
+ return str(int(num)) if num.is_integer() else str(num)
 
 
 
 
 
108
  except:
109
  pass
110
 
111
+ # 2. Handle percentages
112
  if answer.endswith('%'):
113
  answer = answer[:-1].strip()
114
  try:
 
117
  except:
118
  pass
119
 
120
+ # 3. Handle yes/no
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if answer.lower() in ['yes', 'no']:
122
  return answer.lower()
123
 
124
+ # 4. Handle lists
125
+ if ',' in answer:
126
+ items = [item.strip() for item in answer.split(',')]
127
+ cleaned_items = []
128
+ for item in items:
129
+ # Remove articles
130
+ words = item.split()
131
+ if words and words[0].lower() in ['the', 'a', 'an']:
132
+ cleaned_items.append(' '.join(words[1:]))
133
+ else:
134
+ cleaned_items.append(item)
135
+ return ', '.join(cleaned_items)
136
+
137
+ # 5. Single answer - remove articles
138
  words = answer.split()
139
  if words and words[0].lower() in ['the', 'a', 'an']:
140
  return ' '.join(words[1:])
141
 
142
  return answer
143
 
 
144
  class GAIAAgent:
145
  """GAIA RAG Agent using LlamaIndex AgentWorkflow"""
146
 
 
179
 
180
  import warnings
181
  warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Event loop is closed.*")
182
+
 
183
  try:
 
184
  loop = asyncio.new_event_loop()
185
  asyncio.set_event_loop(loop)
186
 
187
  try:
188
  async def run_agent():
 
 
 
 
189
  try:
 
190
  handler = self.agent.run(user_msg=question)
191
 
192
+ # Wait for the result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  result = await handler
194
 
195
+ # Extract response text more carefully
196
+ response_text = ""
197
+
198
+ # Try different ways to get the response
199
+ if hasattr(result, 'response'):
200
+ if hasattr(result.response, 'message'):
201
+ if hasattr(result.response.message, 'content'):
202
+ response_text = result.response.message.content
203
+ else:
204
+ response_text = str(result.response.message)
205
+ else:
206
+ response_text = str(result.response)
207
+ elif hasattr(result, 'content'):
208
+ response_text = result.content
209
+ elif hasattr(result, 'output'):
210
+ response_text = result.output
211
  else:
212
  response_text = str(result)
213
 
214
+ # Clean up any streaming artifacts
215
+ response_text = re.sub(r'assistant:\s*', '', response_text, flags=re.IGNORECASE)
 
 
 
 
 
 
 
 
216
 
217
  return response_text
218
 
 
 
 
 
 
219
  except Exception as e:
220
  logger.error(f"Agent execution error: {e}")
221
+ import traceback
222
+ logger.error(traceback.format_exc())
223
+ return "FINAL ANSWER: "
224
 
 
225
  response_text = loop.run_until_complete(
226
+ asyncio.wait_for(run_agent(), timeout=60)
227
  )
228
 
229
  # Extract clean answer
230
  clean_answer = extract_final_answer(response_text)
231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
232
  logger.info(f"Full response preview: {response_text[:200]}...")
233
  logger.info(f"Extracted answer: '{clean_answer}'")
234
 
235
  return clean_answer
236
 
237
  finally:
 
238
  loop.close()
239
 
240
  except Exception as e:
241
  logger.error(f"Error processing question: {e}")
242
+ return ""
 
243
 
 
244
  def run_and_submit_all(profile: gr.OAuthProfile | None):
245
  """Run GAIA evaluation following course template structure"""
246