Update GAIA agent: fix final-answer extraction
Browse files
app.py
CHANGED
|
@@ -71,49 +71,44 @@ def setup_llm():
|
|
| 71 |
|
| 72 |
|
| 73 |
def extract_final_answer(response_text: str) -> str:
|
| 74 |
-
"""Extract answer aligned with GAIA scoring rules"""
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
# Look for FINAL ANSWER pattern
|
| 77 |
match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE | re.DOTALL)
|
| 78 |
|
| 79 |
if not match:
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
if lines:
|
| 83 |
-
# Check if last line looks like an answer
|
| 84 |
-
last_line = lines[-1].strip()
|
| 85 |
-
if len(last_line) < 100 and not last_line.startswith(('I', 'The', 'To', 'Based')):
|
| 86 |
-
answer = last_line
|
| 87 |
-
else:
|
| 88 |
-
logger.warning("No FINAL ANSWER found")
|
| 89 |
-
return ""
|
| 90 |
-
else:
|
| 91 |
-
return ""
|
| 92 |
-
else:
|
| 93 |
-
answer = match.group(1).strip()
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
|
| 98 |
# Clean for GAIA scoring
|
| 99 |
|
| 100 |
-
# 1. Handle numbers
|
| 101 |
if re.match(r'^[\d\s.,\-+e]+$', answer):
|
| 102 |
-
# Remove all formatting
|
| 103 |
cleaned = answer.replace(',', '').replace(' ', '')
|
| 104 |
try:
|
| 105 |
-
# Try to parse as float
|
| 106 |
num = float(cleaned)
|
| 107 |
-
|
| 108 |
-
if num.is_integer():
|
| 109 |
-
return str(int(num))
|
| 110 |
-
else:
|
| 111 |
-
# Keep original precision, don't round
|
| 112 |
-
return str(num)
|
| 113 |
except:
|
| 114 |
pass
|
| 115 |
|
| 116 |
-
# 2. Handle percentages
|
| 117 |
if answer.endswith('%'):
|
| 118 |
answer = answer[:-1].strip()
|
| 119 |
try:
|
|
@@ -122,43 +117,30 @@ def extract_final_answer(response_text: str) -> str:
|
|
| 122 |
except:
|
| 123 |
pass
|
| 124 |
|
| 125 |
-
# 3.
|
| 126 |
-
if ',' in answer or ' and ' in answer.lower():
|
| 127 |
-
# Split on commas and 'and'
|
| 128 |
-
parts = re.split(r',|\s+and\s+', answer)
|
| 129 |
-
cleaned_parts = []
|
| 130 |
-
|
| 131 |
-
for part in parts:
|
| 132 |
-
part = part.strip()
|
| 133 |
-
if not part:
|
| 134 |
-
continue
|
| 135 |
-
|
| 136 |
-
# Try to parse as number
|
| 137 |
-
try:
|
| 138 |
-
num = float(part.replace('$', '').replace('%', '').replace(',', ''))
|
| 139 |
-
cleaned_parts.append(str(int(num)) if num.is_integer() else str(num))
|
| 140 |
-
except:
|
| 141 |
-
# Remove articles from strings
|
| 142 |
-
words = part.split()
|
| 143 |
-
if words and words[0].lower() in ['the', 'a', 'an']:
|
| 144 |
-
cleaned_parts.append(' '.join(words[1:]))
|
| 145 |
-
else:
|
| 146 |
-
cleaned_parts.append(part)
|
| 147 |
-
|
| 148 |
-
return ', '.join(cleaned_parts)
|
| 149 |
-
|
| 150 |
-
# 4. Yes/No answers
|
| 151 |
if answer.lower() in ['yes', 'no']:
|
| 152 |
return answer.lower()
|
| 153 |
|
| 154 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
words = answer.split()
|
| 156 |
if words and words[0].lower() in ['the', 'a', 'an']:
|
| 157 |
return ' '.join(words[1:])
|
| 158 |
|
| 159 |
return answer
|
| 160 |
|
| 161 |
-
|
| 162 |
class GAIAAgent:
|
| 163 |
"""GAIA RAG Agent using LlamaIndex AgentWorkflow"""
|
| 164 |
|
|
@@ -197,119 +179,68 @@ class GAIAAgent:
|
|
| 197 |
|
| 198 |
import warnings
|
| 199 |
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Event loop is closed.*")
|
| 200 |
-
|
| 201 |
-
|
| 202 |
try:
|
| 203 |
-
# Create new event loop for async operations
|
| 204 |
loop = asyncio.new_event_loop()
|
| 205 |
asyncio.set_event_loop(loop)
|
| 206 |
|
| 207 |
try:
|
| 208 |
async def run_agent():
|
| 209 |
-
# Track what happened during execution
|
| 210 |
-
tool_calls = []
|
| 211 |
-
response_chunks = []
|
| 212 |
-
|
| 213 |
try:
|
| 214 |
-
# Start the agent workflow
|
| 215 |
handler = self.agent.run(user_msg=question)
|
| 216 |
|
| 217 |
-
#
|
| 218 |
-
# We need to collect BOTH tool usage AND response content
|
| 219 |
-
from llama_index.core.agent.workflow import ToolCallResult
|
| 220 |
-
|
| 221 |
-
# Stream events and collect information
|
| 222 |
-
async for event in handler.stream_events():
|
| 223 |
-
# Log tool usage
|
| 224 |
-
if isinstance(event, ToolCallResult):
|
| 225 |
-
tool_info = f"{event.tool_name}: {str(event.result)[:100]}..."
|
| 226 |
-
tool_calls.append(tool_info)
|
| 227 |
-
logger.info(f"Tool used: {tool_info}")
|
| 228 |
-
|
| 229 |
-
# Also collect any text responses
|
| 230 |
-
# Different event types might have content in different attributes
|
| 231 |
-
if hasattr(event, 'delta'):
|
| 232 |
-
response_chunks.append(str(event.delta))
|
| 233 |
-
elif hasattr(event, 'content'):
|
| 234 |
-
response_chunks.append(str(event.content))
|
| 235 |
-
elif hasattr(event, 'response'):
|
| 236 |
-
response_chunks.append(str(event.response))
|
| 237 |
-
|
| 238 |
-
# Get the final result after streaming
|
| 239 |
result = await handler
|
| 240 |
|
| 241 |
-
# Extract
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
else:
|
| 248 |
response_text = str(result)
|
| 249 |
|
| 250 |
-
#
|
| 251 |
-
|
| 252 |
-
logger.info(f"Tools used in this query: {', '.join(set(tool_calls))}")
|
| 253 |
-
|
| 254 |
-
# CRITICAL: Check if we got a meaningful response
|
| 255 |
-
# This prevents infinite loops
|
| 256 |
-
if not response_text or len(response_text.strip()) < 10:
|
| 257 |
-
logger.warning("Got empty or too short response from agent")
|
| 258 |
-
# Return a fallback response
|
| 259 |
-
return "FINAL ANSWER: Unable to determine answer"
|
| 260 |
|
| 261 |
return response_text
|
| 262 |
|
| 263 |
-
except asyncio.TimeoutError:
|
| 264 |
-
# Prevent infinite waiting
|
| 265 |
-
logger.error("Agent timeout - preventing infinite loop")
|
| 266 |
-
return "FINAL ANSWER: Request timeout"
|
| 267 |
-
|
| 268 |
except Exception as e:
|
| 269 |
logger.error(f"Agent execution error: {e}")
|
| 270 |
-
|
| 271 |
-
|
|
|
|
| 272 |
|
| 273 |
-
# Run with timeout to prevent infinite loops
|
| 274 |
response_text = loop.run_until_complete(
|
| 275 |
-
asyncio.wait_for(run_agent(), timeout=
|
| 276 |
)
|
| 277 |
|
| 278 |
# Extract clean answer
|
| 279 |
clean_answer = extract_final_answer(response_text)
|
| 280 |
|
| 281 |
-
# VALIDATION: Ensure we have a valid answer
|
| 282 |
-
if not clean_answer:
|
| 283 |
-
logger.warning("No answer extracted, using fallback")
|
| 284 |
-
# Try to extract any number or short phrase from response
|
| 285 |
-
# This prevents returning empty string to GAIA
|
| 286 |
-
numbers = re.findall(r'\b\d+\.?\d*\b', response_text)
|
| 287 |
-
if numbers:
|
| 288 |
-
clean_answer = numbers[-1] # Use last number found
|
| 289 |
-
else:
|
| 290 |
-
# Look for any short phrase that could be an answer
|
| 291 |
-
sentences = response_text.split('.')
|
| 292 |
-
for sent in reversed(sentences):
|
| 293 |
-
sent = sent.strip()
|
| 294 |
-
if 0 < len(sent) < 50 and not sent.startswith(('I', 'The', 'To')):
|
| 295 |
-
clean_answer = sent
|
| 296 |
-
break
|
| 297 |
-
|
| 298 |
logger.info(f"Full response preview: {response_text[:200]}...")
|
| 299 |
logger.info(f"Extracted answer: '{clean_answer}'")
|
| 300 |
|
| 301 |
return clean_answer
|
| 302 |
|
| 303 |
finally:
|
| 304 |
-
# Always close the loop
|
| 305 |
loop.close()
|
| 306 |
|
| 307 |
except Exception as e:
|
| 308 |
logger.error(f"Error processing question: {e}")
|
| 309 |
-
|
| 310 |
-
return "0" # Safe fallback for math questions
|
| 311 |
|
| 312 |
-
|
| 313 |
def run_and_submit_all(profile: gr.OAuthProfile | None):
|
| 314 |
"""Run GAIA evaluation following course template structure"""
|
| 315 |
|
|
|
|
| 71 |
|
| 72 |
|
| 73 |
def extract_final_answer(response_text: str) -> str:
|
| 74 |
+
"""Extract answer aligned with GAIA scoring rules - FIXED VERSION"""
|
| 75 |
+
|
| 76 |
+
# First, remove any "assistant:" prefix that might have been added
|
| 77 |
+
response_text = re.sub(r'^assistant:\s*', '', response_text, flags=re.IGNORECASE)
|
| 78 |
|
| 79 |
# Look for FINAL ANSWER pattern
|
| 80 |
match = re.search(r"FINAL ANSWER:\s*(.+?)(?:\n|$)", response_text, re.IGNORECASE | re.DOTALL)
|
| 81 |
|
| 82 |
if not match:
|
| 83 |
+
logger.warning("No FINAL ANSWER found in response")
|
| 84 |
+
return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
|
| 86 |
+
answer = match.group(1).strip()
|
| 87 |
+
|
| 88 |
+
# CRITICAL: Stop processing if we hit "assistant:" or any reasoning text
|
| 89 |
+
if 'assistant:' in answer:
|
| 90 |
+
answer = answer.split('assistant:')[0].strip()
|
| 91 |
+
|
| 92 |
+
# Remove any trailing explanatory text (usually starts with lowercase after answer)
|
| 93 |
+
sentences = answer.split('.')
|
| 94 |
+
if len(sentences) > 1:
|
| 95 |
+
# Check if second sentence starts with lowercase (indicates explanation)
|
| 96 |
+
first_sentence = sentences[0].strip()
|
| 97 |
+
if first_sentence and (not sentences[1].strip() or sentences[1].strip()[0].islower()):
|
| 98 |
+
answer = first_sentence
|
| 99 |
|
| 100 |
# Clean for GAIA scoring
|
| 101 |
|
| 102 |
+
# 1. Handle pure numbers
|
| 103 |
if re.match(r'^[\d\s.,\-+e]+$', answer):
|
|
|
|
| 104 |
cleaned = answer.replace(',', '').replace(' ', '')
|
| 105 |
try:
|
|
|
|
| 106 |
num = float(cleaned)
|
| 107 |
+
return str(int(num)) if num.is_integer() else str(num)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
except:
|
| 109 |
pass
|
| 110 |
|
| 111 |
+
# 2. Handle percentages
|
| 112 |
if answer.endswith('%'):
|
| 113 |
answer = answer[:-1].strip()
|
| 114 |
try:
|
|
|
|
| 117 |
except:
|
| 118 |
pass
|
| 119 |
|
| 120 |
+
# 3. Handle yes/no
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
if answer.lower() in ['yes', 'no']:
|
| 122 |
return answer.lower()
|
| 123 |
|
| 124 |
+
# 4. Handle lists
|
| 125 |
+
if ',' in answer:
|
| 126 |
+
items = [item.strip() for item in answer.split(',')]
|
| 127 |
+
cleaned_items = []
|
| 128 |
+
for item in items:
|
| 129 |
+
# Remove articles
|
| 130 |
+
words = item.split()
|
| 131 |
+
if words and words[0].lower() in ['the', 'a', 'an']:
|
| 132 |
+
cleaned_items.append(' '.join(words[1:]))
|
| 133 |
+
else:
|
| 134 |
+
cleaned_items.append(item)
|
| 135 |
+
return ', '.join(cleaned_items)
|
| 136 |
+
|
| 137 |
+
# 5. Single answer - remove articles
|
| 138 |
words = answer.split()
|
| 139 |
if words and words[0].lower() in ['the', 'a', 'an']:
|
| 140 |
return ' '.join(words[1:])
|
| 141 |
|
| 142 |
return answer
|
| 143 |
|
|
|
|
| 144 |
class GAIAAgent:
|
| 145 |
"""GAIA RAG Agent using LlamaIndex AgentWorkflow"""
|
| 146 |
|
|
|
|
| 179 |
|
| 180 |
import warnings
|
| 181 |
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Event loop is closed.*")
|
| 182 |
+
|
|
|
|
| 183 |
try:
|
|
|
|
| 184 |
loop = asyncio.new_event_loop()
|
| 185 |
asyncio.set_event_loop(loop)
|
| 186 |
|
| 187 |
try:
|
| 188 |
async def run_agent():
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
try:
|
|
|
|
| 190 |
handler = self.agent.run(user_msg=question)
|
| 191 |
|
| 192 |
+
# Wait for the result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
result = await handler
|
| 194 |
|
| 195 |
+
# Extract response text more carefully
|
| 196 |
+
response_text = ""
|
| 197 |
+
|
| 198 |
+
# Try different ways to get the response
|
| 199 |
+
if hasattr(result, 'response'):
|
| 200 |
+
if hasattr(result.response, 'message'):
|
| 201 |
+
if hasattr(result.response.message, 'content'):
|
| 202 |
+
response_text = result.response.message.content
|
| 203 |
+
else:
|
| 204 |
+
response_text = str(result.response.message)
|
| 205 |
+
else:
|
| 206 |
+
response_text = str(result.response)
|
| 207 |
+
elif hasattr(result, 'content'):
|
| 208 |
+
response_text = result.content
|
| 209 |
+
elif hasattr(result, 'output'):
|
| 210 |
+
response_text = result.output
|
| 211 |
else:
|
| 212 |
response_text = str(result)
|
| 213 |
|
| 214 |
+
# Clean up any streaming artifacts
|
| 215 |
+
response_text = re.sub(r'assistant:\s*', '', response_text, flags=re.IGNORECASE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
|
| 217 |
return response_text
|
| 218 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 219 |
except Exception as e:
|
| 220 |
logger.error(f"Agent execution error: {e}")
|
| 221 |
+
import traceback
|
| 222 |
+
logger.error(traceback.format_exc())
|
| 223 |
+
return "FINAL ANSWER: "
|
| 224 |
|
|
|
|
| 225 |
response_text = loop.run_until_complete(
|
| 226 |
+
asyncio.wait_for(run_agent(), timeout=60)
|
| 227 |
)
|
| 228 |
|
| 229 |
# Extract clean answer
|
| 230 |
clean_answer = extract_final_answer(response_text)
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
logger.info(f"Full response preview: {response_text[:200]}...")
|
| 233 |
logger.info(f"Extracted answer: '{clean_answer}'")
|
| 234 |
|
| 235 |
return clean_answer
|
| 236 |
|
| 237 |
finally:
|
|
|
|
| 238 |
loop.close()
|
| 239 |
|
| 240 |
except Exception as e:
|
| 241 |
logger.error(f"Error processing question: {e}")
|
| 242 |
+
return ""
|
|
|
|
| 243 |
|
|
|
|
| 244 |
def run_and_submit_all(profile: gr.OAuthProfile | None):
|
| 245 |
"""Run GAIA evaluation following course template structure"""
|
| 246 |
|