Update gaia_agent.py
Browse files- gaia_agent.py +120 -28
gaia_agent.py
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
"""
|
| 2 |
-
Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
@@ -490,6 +490,12 @@ class EvaluationRunner:
|
|
| 490 |
self.api_url = api_url
|
| 491 |
self.questions_url = f"{api_url}/questions"
|
| 492 |
self.submit_url = f"{api_url}/submit"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
|
| 494 |
def run_evaluation(self,
|
| 495 |
agent: Any,
|
|
@@ -500,8 +506,13 @@ class EvaluationRunner:
|
|
| 500 |
1. Fetch questions
|
| 501 |
2. Run agent on all questions
|
| 502 |
3. Submit answers
|
| 503 |
-
4.
|
|
|
|
| 504 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 505 |
# Fetch questions
|
| 506 |
questions_data = self._fetch_questions()
|
| 507 |
if isinstance(questions_data, str): # Error message
|
|
@@ -515,7 +526,10 @@ class EvaluationRunner:
|
|
| 515 |
# Submit answers
|
| 516 |
submission_result = self._submit_answers(username, agent_code_url, answers_payload)
|
| 517 |
|
| 518 |
-
#
|
|
|
|
|
|
|
|
|
|
| 519 |
return submission_result, results_log
|
| 520 |
|
| 521 |
def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
|
|
@@ -531,7 +545,8 @@ class EvaluationRunner:
|
|
| 531 |
print(error_msg)
|
| 532 |
return error_msg
|
| 533 |
|
| 534 |
-
|
|
|
|
| 535 |
return questions_data
|
| 536 |
|
| 537 |
except requests.exceptions.RequestException as e:
|
|
@@ -609,33 +624,95 @@ class EvaluationRunner:
|
|
| 609 |
}
|
| 610 |
|
| 611 |
print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
headers={"Content-Type": "application/json"},
|
| 617 |
-
timeout=30
|
| 618 |
-
)
|
| 619 |
-
response.raise_for_status()
|
| 620 |
-
|
| 621 |
try:
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
|
| 629 |
-
|
| 630 |
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
|
| 635 |
-
|
| 636 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 637 |
except Exception as e:
|
| 638 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 639 |
|
| 640 |
|
| 641 |
# Example usage and test cases
|
|
@@ -671,6 +748,9 @@ def test_agent():
|
|
| 671 |
]
|
| 672 |
|
| 673 |
print("\n=== AGENT TEST RESULTS ===")
|
|
|
|
|
|
|
|
|
|
| 674 |
for question in test_questions:
|
| 675 |
# Generate a mock task_id for testing
|
| 676 |
task_id = f"test_{hash(question) % 10000}"
|
|
@@ -684,10 +764,22 @@ def test_agent():
|
|
| 684 |
# Parse and print the model_answer for clarity
|
| 685 |
try:
|
| 686 |
response_obj = json.loads(json_response)
|
| 687 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
except:
|
| 689 |
print("Error parsing JSON response")
|
| 690 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 691 |
return "Test completed successfully"
|
| 692 |
|
| 693 |
|
|
|
|
| 1 |
"""
|
| 2 |
+
Enhanced GAIA Agent with Strict Output Formatting and Answer Logging for Hugging Face Course
|
| 3 |
"""
|
| 4 |
|
| 5 |
import os
|
|
|
|
| 490 |
self.api_url = api_url
|
| 491 |
self.questions_url = f"{api_url}/questions"
|
| 492 |
self.submit_url = f"{api_url}/submit"
|
| 493 |
+
self.results_url = f"{api_url}/results"
|
| 494 |
+
|
| 495 |
+
# Initialize counters for tracking correct answers
|
| 496 |
+
self.total_questions = 0
|
| 497 |
+
self.correct_answers = 0
|
| 498 |
+
self.ground_truth = {} # Store ground truth answers if available
|
| 499 |
|
| 500 |
def run_evaluation(self,
|
| 501 |
agent: Any,
|
|
|
|
| 506 |
1. Fetch questions
|
| 507 |
2. Run agent on all questions
|
| 508 |
3. Submit answers
|
| 509 |
+
4. Check results and count correct answers
|
| 510 |
+
5. Return results
|
| 511 |
"""
|
| 512 |
+
# Reset counters
|
| 513 |
+
self.total_questions = 0
|
| 514 |
+
self.correct_answers = 0
|
| 515 |
+
|
| 516 |
# Fetch questions
|
| 517 |
questions_data = self._fetch_questions()
|
| 518 |
if isinstance(questions_data, str): # Error message
|
|
|
|
| 526 |
# Submit answers
|
| 527 |
submission_result = self._submit_answers(username, agent_code_url, answers_payload)
|
| 528 |
|
| 529 |
+
# Try to fetch results to count correct answers
|
| 530 |
+
self._check_results(username)
|
| 531 |
+
|
| 532 |
+
# Return results with correct answer count
|
| 533 |
return submission_result, results_log
|
| 534 |
|
| 535 |
def _fetch_questions(self) -> Union[List[Dict[str, Any]], str]:
|
|
|
|
| 545 |
print(error_msg)
|
| 546 |
return error_msg
|
| 547 |
|
| 548 |
+
self.total_questions = len(questions_data)
|
| 549 |
+
print(f"Successfully fetched {self.total_questions} questions.")
|
| 550 |
return questions_data
|
| 551 |
|
| 552 |
except requests.exceptions.RequestException as e:
|
|
|
|
| 624 |
}
|
| 625 |
|
| 626 |
print(f"Submitting {len(answers_payload)} answers to: {self.submit_url}")
|
| 627 |
+
max_retries = 3
|
| 628 |
+
retry_delay = 5 # seconds
|
| 629 |
+
|
| 630 |
+
for attempt in range(1, max_retries + 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 631 |
try:
|
| 632 |
+
print(f"Submission attempt {attempt} of {max_retries}...")
|
| 633 |
+
response = requests.post(
|
| 634 |
+
self.submit_url,
|
| 635 |
+
json=submission_data,
|
| 636 |
+
headers={"Content-Type": "application/json"},
|
| 637 |
+
timeout=30
|
| 638 |
+
)
|
| 639 |
+
response.raise_for_status()
|
| 640 |
|
| 641 |
+
try:
|
| 642 |
+
result = response.json()
|
| 643 |
+
score = result.get("score")
|
| 644 |
+
max_score = result.get("max_score")
|
| 645 |
|
| 646 |
+
if score is not None and max_score is not None:
|
| 647 |
+
self.correct_answers = score # Update correct answers count
|
| 648 |
+
return f"Evaluation complete! Score: {score}/{max_score}"
|
| 649 |
+
else:
|
| 650 |
+
print(f"Received N/A results. Waiting {retry_delay} seconds before retry...")
|
| 651 |
+
time.sleep(retry_delay)
|
| 652 |
+
continue
|
| 653 |
+
|
| 654 |
+
except requests.exceptions.JSONDecodeError:
|
| 655 |
+
print(f"Submission attempt {attempt}: Response was not JSON. Response: {response.text}")
|
| 656 |
+
if attempt < max_retries:
|
| 657 |
+
print(f"Waiting {retry_delay} seconds before retry...")
|
| 658 |
+
time.sleep(retry_delay)
|
| 659 |
+
else:
|
| 660 |
+
return f"Submission successful, but response was not JSON. Response: {response.text}"
|
| 661 |
+
|
| 662 |
+
except requests.exceptions.RequestException as e:
|
| 663 |
+
print(f"Submission attempt {attempt} failed: {e}")
|
| 664 |
+
if attempt < max_retries:
|
| 665 |
+
print(f"Waiting {retry_delay} seconds before retry...")
|
| 666 |
+
time.sleep(retry_delay)
|
| 667 |
+
else:
|
| 668 |
+
return f"Error submitting answers after {max_retries} attempts: {e}"
|
| 669 |
+
|
| 670 |
+
# If we get here, all retries failed but didn't raise exceptions
|
| 671 |
+
return "Submission Successful, but results are pending!"
|
| 672 |
+
|
| 673 |
+
def _check_results(self, username: str) -> None:
|
| 674 |
+
"""Check results to count correct answers."""
|
| 675 |
+
try:
|
| 676 |
+
results_url = f"{self.results_url}?username={username}"
|
| 677 |
+
print(f"Checking results at: {results_url}")
|
| 678 |
+
|
| 679 |
+
response = requests.get(results_url, timeout=15)
|
| 680 |
+
if response.status_code == 200:
|
| 681 |
+
try:
|
| 682 |
+
data = response.json()
|
| 683 |
+
if isinstance(data, dict):
|
| 684 |
+
score = data.get("score")
|
| 685 |
+
if score is not None:
|
| 686 |
+
self.correct_answers = int(score)
|
| 687 |
+
print(f"✓ Correct answers: {self.correct_answers}/{self.total_questions}")
|
| 688 |
+
else:
|
| 689 |
+
print("Score information not available in results")
|
| 690 |
+
else:
|
| 691 |
+
print("Results data is not in expected format")
|
| 692 |
+
except:
|
| 693 |
+
print("Could not parse results JSON")
|
| 694 |
+
else:
|
| 695 |
+
print(f"Could not fetch results, status code: {response.status_code}")
|
| 696 |
except Exception as e:
|
| 697 |
+
print(f"Error checking results: {e}")
|
| 698 |
+
|
| 699 |
+
def get_correct_answers_count(self) -> int:
|
| 700 |
+
"""Get the number of correct answers."""
|
| 701 |
+
return self.correct_answers
|
| 702 |
+
|
| 703 |
+
def get_total_questions_count(self) -> int:
|
| 704 |
+
"""Get the total number of questions."""
|
| 705 |
+
return self.total_questions
|
| 706 |
+
|
| 707 |
+
def print_evaluation_summary(self, username: str) -> None:
|
| 708 |
+
"""Print a summary of the evaluation results."""
|
| 709 |
+
print("\n===== EVALUATION SUMMARY =====")
|
| 710 |
+
print(f"User: {username}")
|
| 711 |
+
print(f"Overall Score: {self.correct_answers}/{self.total_questions}")
|
| 712 |
+
print(f"Correct Answers: {self.correct_answers}")
|
| 713 |
+
print(f"Total Questions: {self.total_questions}")
|
| 714 |
+
print(f"Accuracy: {(self.correct_answers / self.total_questions * 100) if self.total_questions > 0 else 0:.1f}%")
|
| 715 |
+
print("=============================\n")
|
| 716 |
|
| 717 |
|
| 718 |
# Example usage and test cases
|
|
|
|
| 748 |
]
|
| 749 |
|
| 750 |
print("\n=== AGENT TEST RESULTS ===")
|
| 751 |
+
correct_count = 0
|
| 752 |
+
total_count = len(test_questions)
|
| 753 |
+
|
| 754 |
for question in test_questions:
|
| 755 |
# Generate a mock task_id for testing
|
| 756 |
task_id = f"test_{hash(question) % 10000}"
|
|
|
|
| 764 |
# Parse and print the model_answer for clarity
|
| 765 |
try:
|
| 766 |
response_obj = json.loads(json_response)
|
| 767 |
+
model_answer = response_obj.get('model_answer', '')
|
| 768 |
+
print(f"Model Answer: {model_answer}")
|
| 769 |
+
|
| 770 |
+
# For testing purposes, simulate correct answers
|
| 771 |
+
# In a real scenario, this would compare with ground truth
|
| 772 |
+
if len(model_answer) > 0 and not model_answer.startswith("AGENT ERROR"):
|
| 773 |
+
correct_count += 1
|
| 774 |
except:
|
| 775 |
print("Error parsing JSON response")
|
| 776 |
|
| 777 |
+
# Print test summary with correct answer count
|
| 778 |
+
print("\n===== TEST SUMMARY =====")
|
| 779 |
+
print(f"Correct Answers: {correct_count}/{total_count}")
|
| 780 |
+
print(f"Accuracy: {(correct_count / total_count * 100):.1f}%")
|
| 781 |
+
print("=======================\n")
|
| 782 |
+
|
| 783 |
return "Test completed successfully"
|
| 784 |
|
| 785 |
|