|
|
|
|
|
""" |
|
|
Quick test script for specific GAIA questions. |
|
|
Use this to verify fixes without running full evaluation. |
|
|
|
|
|
Usage: |
|
|
uv run python test/test_quick_fixes.py |
|
|
""" |
|
|
|
|
|
import os |
|
|
import sys |
|
|
|
|
|
|
|
|
# Make the repository root importable so `src.agent.graph` resolves when this
# script is run directly (this file lives one level below the repo root).
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.agent.graph import GAIAAgent
from dotenv import load_dotenv

# Pull API keys / provider settings from a local .env file, if one exists.
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Regression cases: GAIA questions that previously exposed specific bugs.
# Each entry carries the upstream GAIA task_id, a short label naming the fix
# under test, the verbatim question text, and the ground-truth answer.
TEST_QUESTIONS = [
    {
        "task_id": "2d83110e-a098-4ebb-9987-066c06fa42d0",
        "name": "Reverse sentence (calculator threading fix)",
        # The question is written right-to-left on purpose; the agent must
        # reverse it to understand the task. Do not "fix" the spelling.
        "question": ".rewsna eht sa \"tfel\" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI",
        "expected": "Right",
    },
    {
        "task_id": "6f37996b-2ac7-44b0-8e68-6d28256631b4",
        "name": "Table commutativity (LLM issue - table in question)",
        # Markdown operation table embedded directly in the question text;
        # kept byte-for-byte as GAIA supplies it.
        "question": '''Given this table defining * on the set S = {a, b, c, d, e}

|*|a|b|c|d|e|
|---|---|---|---|
|a|a|b|c|b|d|
|b|b|c|a|e|c|
|c|c|a|b|b|a|
|d|b|e|b|e|d|
|e|d|b|a|d|c|

provide the subset of S involved in any possible counter-examples that prove * is not commutative. Provide your answer as a comma separated list of the elements in the set in alphabetical order.''',
        "expected": "b, e",
    },
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_question(agent: GAIAAgent, test_case: dict) -> dict: |
|
|
"""Test a single question and return result.""" |
|
|
task_id = test_case["task_id"] |
|
|
question = test_case["question"] |
|
|
expected = test_case.get("expected", "N/A") |
|
|
|
|
|
print(f"\n{'='*60}") |
|
|
print(f"Testing: {test_case['name']}") |
|
|
print(f"Task ID: {task_id}") |
|
|
print(f"Expected: {expected}") |
|
|
print(f"{'='*60}") |
|
|
|
|
|
try: |
|
|
answer = agent(question, file_path=None) |
|
|
|
|
|
|
|
|
is_correct = answer.strip().lower() == expected.lower().strip() |
|
|
|
|
|
result = { |
|
|
"task_id": task_id, |
|
|
"name": test_case["name"], |
|
|
"question": question[:100] + "..." if len(question) > 100 else question, |
|
|
"expected": expected, |
|
|
"answer": answer, |
|
|
"correct": is_correct, |
|
|
"status": "success", |
|
|
} |
|
|
|
|
|
|
|
|
if not answer: |
|
|
result["system_error"] = "yes" |
|
|
elif answer.lower().startswith("error:") or "no evidence collected" in answer.lower(): |
|
|
result["system_error"] = "yes" |
|
|
result["error_log"] = answer |
|
|
else: |
|
|
result["system_error"] = "no" |
|
|
|
|
|
except Exception as e: |
|
|
result = { |
|
|
"task_id": task_id, |
|
|
"name": test_case["name"], |
|
|
"question": question[:100] + "..." if len(question) > 100 else question, |
|
|
"expected": expected, |
|
|
"answer": f"ERROR: {str(e)}", |
|
|
"correct": False, |
|
|
"status": "error", |
|
|
"system_error": "yes", |
|
|
"error_log": str(e), |
|
|
} |
|
|
|
|
|
|
|
|
status_icon = "✅" if result["correct"] else "❌" if result["system_error"] == "no" else "⚠️" |
|
|
print(f"\n{status_icon} Result: {result['answer'][:100]}") |
|
|
if result["system_error"] == "yes": |
|
|
print(f" System Error: Yes") |
|
|
if result.get("error_log"): |
|
|
print(f" Error: {result['error_log'][:100]}") |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
def main():
    """Initialize the agent and run every question in TEST_QUESTIONS.

    Prints a per-question result, then a summary that splits failures into
    system errors (infrastructure problems) vs. wrong answers (model
    quality), matching the classification done by ``test_question``.
    """
    print("\n" + "="*60)
    print("GAIA Quick Test - Verify Fixes")
    print("="*60)

    # Informational only; the provider is actually consumed by GAIAAgent.
    llm_provider = os.getenv("LLM_PROVIDER", "gemini")
    print(f"\nLLM Provider: {llm_provider}")

    print("\nInitializing agent...")
    try:
        agent = GAIAAgent()
        print("✅ Agent initialized")
    except Exception as e:
        # Without an agent nothing can run; report the cause and bail out.
        print(f"❌ Agent initialization failed: {e}")
        return

    results = []
    for test_case in TEST_QUESTIONS:
        result = test_question(agent, test_case)
        results.append(result)

    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")

    success_count = sum(1 for r in results if r["correct"])
    error_count = sum(1 for r in results if r["system_error"] == "yes")
    ai_fail_count = sum(1 for r in results if r["system_error"] == "no" and not r["correct"])

    print(f"\nTotal: {len(results)}")
    print(f"✅ Correct: {success_count}")
    print(f"⚠️ System Errors: {error_count}")
    print(f"❌ AI Wrong: {ai_fail_count}")

    print("\nDetailed Results:")  # plain string: no placeholders (fixes F541)
    for r in results:
        # Same icon semantics as test_question: ✅ correct, ⚠️ system error,
        # ❌ wrong answer.
        status = "✅" if r["correct"] else "⚠️" if r["system_error"] == "yes" else "❌"
        print(f" {status} {r['name']}: {r['answer'][:50]}{'...' if len(r['answer']) > 50 else ''}")


if __name__ == "__main__":
    main()
|
|
|