Update gaia_agent.py
Browse files - gaia_agent.py +31 -13
gaia_agent.py
CHANGED
|
@@ -73,7 +73,7 @@ class EnhancedGAIAAgent:
|
|
| 73 |
task_id: Optional task ID for the GAIA benchmark
|
| 74 |
|
| 75 |
Returns:
|
| 76 |
-
|
| 77 |
"""
|
| 78 |
print(f"Processing question: {question}")
|
| 79 |
|
|
@@ -87,8 +87,12 @@ class EnhancedGAIAAgent:
|
|
| 87 |
# Ensure answer is concise and specific
|
| 88 |
model_answer = self._ensure_concise_answer(model_answer, question_type)
|
| 89 |
|
| 90 |
-
# FIXED: Return
|
| 91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
|
| 94 |
"""Generate a reasoning trace for the question if appropriate."""
|
|
@@ -537,10 +541,15 @@ class EvaluationRunner:
|
|
| 537 |
continue
|
| 538 |
|
| 539 |
try:
|
| 540 |
-
#
|
| 541 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
|
| 543 |
-
# FIXED: No need to parse JSON, just use the answer directly
|
| 544 |
answers_payload.append({
|
| 545 |
"task_id": task_id,
|
| 546 |
"submitted_answer": submitted_answer
|
|
@@ -549,7 +558,8 @@ class EvaluationRunner:
|
|
| 549 |
results_log.append({
|
| 550 |
"Task ID": task_id,
|
| 551 |
"Question": question_text,
|
| 552 |
-
"Submitted Answer": submitted_answer
|
|
|
|
| 553 |
})
|
| 554 |
except Exception as e:
|
| 555 |
print(f"Error running agent on task {task_id}: {e}")
|
|
@@ -704,15 +714,23 @@ def test_agent():
|
|
| 704 |
# Generate a mock task_id for testing
|
| 705 |
task_id = f"test_{hash(question) % 10000}"
|
| 706 |
|
| 707 |
-
# Get
|
| 708 |
-
|
| 709 |
|
| 710 |
print(f"\nQ: {question}")
|
| 711 |
-
print(f"
|
| 712 |
|
| 713 |
-
#
|
| 714 |
-
|
| 715 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 716 |
|
| 717 |
# Print test summary with correct answer count
|
| 718 |
print("\n===== TEST SUMMARY =====")
|
|
|
|
| 73 |
task_id: Optional task ID for the GAIA benchmark
|
| 74 |
|
| 75 |
Returns:
|
| 76 |
+
JSON string with final_answer key
|
| 77 |
"""
|
| 78 |
print(f"Processing question: {question}")
|
| 79 |
|
|
|
|
| 87 |
# Ensure answer is concise and specific
|
| 88 |
model_answer = self._ensure_concise_answer(model_answer, question_type)
|
| 89 |
|
| 90 |
+
# FIXED: Return JSON with final_answer key
|
| 91 |
+
response = {
|
| 92 |
+
"final_answer": model_answer
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
return json.dumps(response)
|
| 96 |
|
| 97 |
def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
|
| 98 |
"""Generate a reasoning trace for the question if appropriate."""
|
|
|
|
| 541 |
continue
|
| 542 |
|
| 543 |
try:
|
| 544 |
+
# Call agent with task_id to ensure proper formatting
|
| 545 |
+
json_response = agent(question_text, task_id)
|
| 546 |
+
|
| 547 |
+
# Parse the JSON response
|
| 548 |
+
response_obj = json.loads(json_response)
|
| 549 |
+
|
| 550 |
+
# Extract the final_answer for submission
|
| 551 |
+
submitted_answer = response_obj.get("final_answer", "")
|
| 552 |
|
|
|
|
| 553 |
answers_payload.append({
|
| 554 |
"task_id": task_id,
|
| 555 |
"submitted_answer": submitted_answer
|
|
|
|
| 558 |
results_log.append({
|
| 559 |
"Task ID": task_id,
|
| 560 |
"Question": question_text,
|
| 561 |
+
"Submitted Answer": submitted_answer,
|
| 562 |
+
"Full Response": json_response
|
| 563 |
})
|
| 564 |
except Exception as e:
|
| 565 |
print(f"Error running agent on task {task_id}: {e}")
|
|
|
|
| 714 |
# Generate a mock task_id for testing
|
| 715 |
task_id = f"test_{hash(question) % 10000}"
|
| 716 |
|
| 717 |
+
# Get JSON response with final_answer
|
| 718 |
+
json_response = agent(question, task_id)
|
| 719 |
|
| 720 |
print(f"\nQ: {question}")
|
| 721 |
+
print(f"Response: {json_response}")
|
| 722 |
|
| 723 |
+
# Parse and print the final_answer for clarity
|
| 724 |
+
try:
|
| 725 |
+
response_obj = json.loads(json_response)
|
| 726 |
+
final_answer = response_obj.get('final_answer', '')
|
| 727 |
+
print(f"Final Answer: {final_answer}")
|
| 728 |
+
|
| 729 |
+
# For testing purposes, count any non-empty answer that is not an "AGENT ERROR" as correct
|
| 730 |
+
if len(final_answer) > 0 and not final_answer.startswith("AGENT ERROR"):
|
| 731 |
+
correct_count += 1
|
| 732 |
+
except:
|
| 733 |
+
print("Error parsing JSON response")
|
| 734 |
|
| 735 |
# Print test summary with correct answer count
|
| 736 |
print("\n===== TEST SUMMARY =====")
|