Update gaia_agent.py

gaia_agent.py  CHANGED  (+98 −12)
@@ -1,5 +1,5 @@
 """
-Enhanced GAIA Agent with
+Enhanced GAIA Agent with Strict Output Formatting for Hugging Face Course
 """

 import os
@@ -15,7 +15,7 @@ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline
 class EnhancedGAIAAgent:
     """
     An enhanced agent designed to pass the GAIA evaluation by combining rule-based precision
-    with LLM-powered flexibility
+    with LLM-powered flexibility and strict output formatting.
     """

     def __init__(self, model_name="google/flan-t5-large", device=None):
@@ -64,21 +64,85 @@ class EnhancedGAIAAgent:
         self.tokenizer = None
         self.model = None

-    def __call__(self, question: str) -> str:
-        """
+    def __call__(self, question: str, task_id: str = None) -> str:
+        """
+        Process a question and return a formatted answer according to GAIA benchmark requirements.
+
+        Args:
+            question: The question to answer
+            task_id: Optional task ID for the GAIA benchmark
+
+        Returns:
+            JSON string with the required GAIA format
+        """
         print(f"Processing question: {question}")

         # Determine question type
         question_type = self._classify_question(question)
         print(f"Classified as: {question_type}")

-        #
+        # Generate reasoning trace if appropriate
+        reasoning_trace = self._generate_reasoning_trace(question, question_type)
+
+        # Use the appropriate handler to get the answer
+        model_answer = self.handlers[question_type](question)

         # Ensure answer is concise and specific
+        model_answer = self._ensure_concise_answer(model_answer, question_type)

+        # Format the response according to GAIA requirements
+        response = {
+            "task_id": task_id if task_id else "unknown_task",
+            "model_answer": model_answer,
+            "reasoning_trace": reasoning_trace
+        }
+
+        # Return the formatted JSON response
+        return json.dumps(response, ensure_ascii=False)
+
+    def _generate_reasoning_trace(self, question: str, question_type: str) -> str:
+        """Generate a reasoning trace for the question if appropriate."""
+        # For calculation and reasoning questions, provide a trace
+        if question_type == 'calculation':
+            # Extract numbers and operation from the question
+            numbers = re.findall(r'\d+', question)
+
+            if len(numbers) >= 2:
+                if re.search(r'(sum|add|plus|\+)', question.lower()):
+                    return f"To find the sum, I add the numbers: {' + '.join(numbers)} = {sum(int(num) for num in numbers)}"
+                elif re.search(r'(difference|subtract|minus|\-)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the difference, I subtract: {numbers[0]} - {numbers[1]} = {int(numbers[0]) - int(numbers[1])}"
+                elif re.search(r'(product|multiply|times|\*)', question.lower()) and len(numbers) >= 2:
+                    return f"To find the product, I multiply: {numbers[0]} × {numbers[1]} = {int(numbers[0]) * int(numbers[1])}"
+                elif re.search(r'(divide|division|\/)', question.lower()) and len(numbers) >= 2:
+                    if int(numbers[1]) != 0:
+                        return f"To find the quotient, I divide: {numbers[0]} ÷ {numbers[1]} = {int(numbers[0]) / int(numbers[1])}"
+
+            # If we can't generate a specific trace, use a generic one
+            return "I need to identify the numbers and operations in the question, then perform the calculation step by step."
+
+        elif question_type in ['factual', 'general'] and self.llm_available:
+            # For factual and general questions, use LLM to generate a trace
+            try:
+                prompt = f"Explain your reasoning for answering this question: {question}"
+                inputs = self.tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True).to(self.device)
+                outputs = self.model.generate(
+                    inputs["input_ids"],
+                    max_length=150,
+                    min_length=20,
+                    temperature=0.3,
+                    top_p=0.95,
+                    do_sample=True,
+                    num_return_sequences=1
+                )
+
+                trace = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                return trace[:200]  # Limit trace length
+            except:
+                pass
+
+        # For other question types or if LLM fails, provide a minimal trace
+        return ""

     def _classify_question(self, question: str) -> str:
         """Determine the type of question for specialized handling."""
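With this change, __call__ no longer returns a bare answer string: every result is wrapped in the GAIA JSON envelope. A minimal sketch of the new contract (the "42" answer is assumed for illustration; it is whatever the arithmetic handler returns):

import json

agent = EnhancedGAIAAgent()
raw = agent("What is 12 plus 30?", task_id="task_001")
parsed = json.loads(raw)

# Expected shape of the parsed response:
# parsed["task_id"]         -> "task_001"
# parsed["model_answer"]    -> "42" (assumed handler output)
# parsed["reasoning_trace"] -> "To find the sum, I add the numbers: 12 + 30 = 42"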
@@ -503,15 +567,25 @@ class EvaluationRunner:
                 continue

             try:
+                # Call agent with task_id to ensure proper formatting
+                json_response = agent(question_text, task_id)
+
+                # Parse the JSON response
+                response_obj = json.loads(json_response)
+
+                # Extract the model_answer for submission
+                submitted_answer = response_obj.get("model_answer", "")
+
                 answers_payload.append({
                     "task_id": task_id,
                     "submitted_answer": submitted_answer
                 })
+
                 results_log.append({
                     "Task ID": task_id,
                     "Question": question_text,
-                    "Submitted Answer": submitted_answer
+                    "Submitted Answer": submitted_answer,
+                    "Full Response": json_response
                 })
             except Exception as e:
                 print(f"Error running agent on task {task_id}: {e}")
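Since the agent now returns a JSON string, the runner has to unwrap model_answer before building answers_payload; if the agent ever emits plain text instead, json.loads raises and the task is logged as an error. A more forgiving unwrapper, sketched here as a suggestion rather than part of this commit, falls back to the raw string:

import json

def extract_answer(json_response: str) -> str:
    # Unwrap "model_answer" from the agent's JSON envelope;
    # fall back to the raw string for non-JSON output.
    try:
        return json.loads(json_response).get("model_answer", "")
    except (json.JSONDecodeError, AttributeError):
        return json_response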
@@ -598,9 +672,21 @@ def test_agent():

     print("\n=== AGENT TEST RESULTS ===")
     for question in test_questions:
+        # Generate a mock task_id for testing
+        task_id = f"test_{hash(question) % 10000}"
+
+        # Get formatted JSON response
+        json_response = agent(question, task_id)
+
         print(f"\nQ: {question}")
-        print(f"
+        print(f"Response: {json_response}")
+
+        # Parse and print the model_answer for clarity
+        try:
+            response_obj = json.loads(json_response)
+            print(f"Model Answer: {response_obj.get('model_answer', '')}")
+        except:
+            print("Error parsing JSON response")

     return "Test completed successfully"
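One caveat in the test helper: Python 3 salts the built-in hash() for strings per process, so test_{hash(question) % 10000} produces different mock IDs on every run. If stable IDs matter (for example, to diff logs between runs), a hashlib-based variant, again a sketch rather than part of this commit, is deterministic:

import hashlib

def mock_task_id(question: str) -> str:
    # An md5 digest of the question text is stable across processes,
    # unlike the per-process salted built-in hash().
    digest = hashlib.md5(question.encode("utf-8")).hexdigest()
    return f"test_{int(digest, 16) % 10000}"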