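"""Evaluation harness for a GAIA-style agent.

Loads a GAIA dataset (JSON or JSONL), runs each task through GAIAAgent,
scores normalized exact-match answers, and writes a submission.jsonl file.

Assumes the local ``agent`` module provides a GAIAAgent class exposing the
``process_task(task)`` and ``test_grok()`` methods used below.
"""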
import argparse
import json
import os
from typing import Dict, List, Optional

from agent import GAIAAgent


def normalize_answer(answer: str) -> str:
    """Normalize an answer string for comparison."""
    if not answer:
        return ""
    # Trim surrounding whitespace
    answer = answer.strip()
    # Remove quotes if they wrap the entire answer
    if (answer.startswith('"') and answer.endswith('"')) or (answer.startswith("'") and answer.endswith("'")):
        answer = answer[1:-1]
    # Lowercase so the comparison is case-insensitive
    return answer.lower().strip()
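# For example, normalize_answer('"Paris"') and normalize_answer("PARIS ")
# both yield "paris", so quoting, casing, and stray whitespace differences
# are not counted as errors during scoring.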
def extract_final_answer(response: str) -> str:
    """Extract the final answer from the model response."""
    if "FINAL ANSWER:" in response:
        answer = response.split("FINAL ANSWER:")[1].strip()
        # Clean up the answer - keep only the first line, dropping any trailing explanation
        answer = answer.split('\n')[0].strip()
        return answer
    # If there is no FINAL ANSWER marker, fall back to the last line of the response
    lines = response.strip().split('\n')
    return lines[-1].strip()
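# For example, extract_final_answer("15 + 27 = 42\nFINAL ANSWER: 42") returns
# "42", while a response without the marker falls back to its last line.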
def load_gaia_dataset(dataset_path: str) -> List[Dict]:
    """Load GAIA dataset from a JSON or JSONL file."""
    tasks = []
    if not os.path.exists(dataset_path):
        print(f"Dataset file not found: {dataset_path}")
        return tasks
    try:
        with open(dataset_path, "r", encoding="utf-8") as f:
            if dataset_path.endswith('.jsonl'):
                # JSONL format - one JSON object per line
                for line_num, line in enumerate(f, 1):
                    line = line.strip()
                    if line:
                        try:
                            task = json.loads(line)
                            tasks.append(task)
                        except json.JSONDecodeError as e:
                            print(f"Error parsing line {line_num}: {e}")
            else:
                # Regular JSON format
                data = json.load(f)
                if isinstance(data, list):
                    tasks = data
                elif isinstance(data, dict) and 'tasks' in data:
                    tasks = data['tasks']
                else:
                    print("Unexpected JSON format")
    except Exception as e:
        print(f"Error loading dataset: {e}")
    print(f"Loaded {len(tasks)} tasks from {dataset_path}")
    return tasks
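# Each loaded task record is expected to carry at least "task_id", "question",
# and "answer" (the only keys evaluate_agent() reads), e.g.:
#   {"task_id": "t1", "question": "What is 15 + 27?", "answer": "42"}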
def create_sample_dataset() -> List[Dict]:
    """Create a sample dataset for testing if no GAIA dataset is available."""
    sample_tasks = [
        {
            "task_id": "sample_1",
            "question": "What is 15 + 27?",
            "answer": "42",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_2",
            "question": "What is the capital of France?",
            "answer": "Paris",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_3",
            "question": "How many days are in a leap year?",
            "answer": "366",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_4",
            "question": "What is 2 * 6 * 7?",
            "answer": "84",
            "level": 1,
            "file_name": None
        },
        {
            "task_id": "sample_5",
            "question": "What year did World War II end?",
            "answer": "1945",
            "level": 1,
            "file_name": None
        }
    ]
    print("Using sample dataset for testing")
    return sample_tasks
def evaluate_agent(dataset_path: Optional[str] = None, max_tasks: Optional[int] = None) -> float:
    """Evaluate the GAIA agent on the dataset."""
    # Load dataset
    if dataset_path and os.path.exists(dataset_path):
        tasks = load_gaia_dataset(dataset_path)
    else:
        print("No dataset file found, using sample tasks for testing")
        tasks = create_sample_dataset()
    if not tasks:
        print("No tasks to evaluate")
        return 0.0
    # Limit number of tasks if specified
    if max_tasks:
        tasks = tasks[:max_tasks]
        print(f"Evaluating on first {len(tasks)} tasks")
    # Initialize agent
    print("Initializing GAIA agent...")
    agent = GAIAAgent()
    # Test API connection first
    print("Testing API connection...")
    test_response = agent.test_grok()
    if "error" in test_response.lower():
        print(f"API test failed: {test_response}")
        return 0.0
    else:
        print("API connection successful!")
    # Process tasks
    correct = 0
    total = len(tasks)
    submission_entries = []
    print(f"\nStarting evaluation on {total} tasks...")
    print("=" * 50)
    for i, task in enumerate(tasks, 1):
        task_id = task.get("task_id", f"task_{i}")
        question = task.get("question", "")
        expected_answer = task.get("answer", "")
        print(f"\nTask {i}/{total}: {task_id}")
        print(f"Question: {question[:100]}{'...' if len(question) > 100 else ''}")
        try:
            # Process task with agent
            response = agent.process_task(task)
            predicted_answer = extract_final_answer(response)
            print(f"Expected: {expected_answer}")
            print(f"Predicted: {predicted_answer}")
            # Compare answers (normalized)
            is_correct = normalize_answer(predicted_answer) == normalize_answer(expected_answer)
            if is_correct:
                correct += 1
                print("✅ CORRECT")
            else:
                print("❌ INCORRECT")
            # Store submission entry
            submission_entries.append({
                "task_id": task_id,
                "model_answer": predicted_answer,
                "reasoning_trace": response
            })
        except Exception as e:
            print(f"Error processing task {task_id}: {e}")
            submission_entries.append({
                "task_id": task_id,
                "model_answer": "ERROR",
                "reasoning_trace": f"Error: {str(e)}"
            })
        # Progress update
        current_score = (correct / i) * 100
        print(f"Current score: {correct}/{i} = {current_score:.1f}%")
        print("-" * 30)
    # Final score
    final_score = (correct / total) * 100
    # Save submission file
    try:
        with open("submission.jsonl", "w", encoding="utf-8") as f:
            for entry in submission_entries:
                f.write(json.dumps(entry) + "\n")
        print("\nSubmission saved to submission.jsonl")
    except Exception as e:
        print(f"Error saving submission: {e}")
    # Print final results
    print("=" * 50)
    print("FINAL RESULTS")
    print("=" * 50)
    print(f"Total tasks: {total}")
    print(f"Correct answers: {correct}")
    print(f"Final score: {final_score:.2f}%")
    if final_score >= 30:
        print("🎉 CONGRATULATIONS! Score ≥ 30% - Certificate achieved!")
    else:
        print(f"📉 Score below 30%. Need {30 - final_score:.2f}% more for certificate.")
    return final_score
def main():
    """Main evaluation entry point."""
    parser = argparse.ArgumentParser(description="Evaluate GAIA agent")
    parser.add_argument("--dataset", type=str, default="gaia_test.json",
                        help="Path to GAIA dataset file")
    parser.add_argument("--max-tasks", type=int, default=None,
                        help="Maximum number of tasks to evaluate")
    args = parser.parse_args()
    score = evaluate_agent(args.dataset, args.max_tasks)
    print(f"\nFinal evaluation score: {score:.2f}%")
    if score >= 30:
        print("Certificate requirements met! 🎉")
    else:
        print("Keep working to reach 30% for the certificate! 💪")


if __name__ == "__main__":
    main()
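# Typical invocation, assuming this file is saved as evaluate.py (the filename
# and the gaia_test.jsonl path are placeholders - adjust to your setup):
#
#   python evaluate.py --dataset gaia_test.jsonl --max-tasks 5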