|
|
""" |
|
|
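Simple test script for the GAIA agent: checks that the required API keys are
configured, builds the agent graph, and runs a couple of sample questions
through it, printing each answer.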
|
|
""" |
|
|
import os
import sys
import traceback

from dotenv import load_dotenv
from langchain_core.messages import HumanMessage

from agent import build_graph
|
|
|
|
|
|
|
|
# Load a .env file (if present) into the process environment before reading keys.
load_dotenv()

print("Checking API keys...")

groq_key = os.getenv("GROQ_API_KEY")
tavily_key = os.getenv("TAVILY_API_KEY")


def _report_api_key(name, value):
    """Print whether the environment variable *name* holds a non-empty value.

    Only a 4-character prefix of the value is echoed (enough to tell keys
    apart) so that running this script in a shared terminal or CI log does
    not leak secret key material.
    """
    if not value:
        # Missing and empty values are both reported as "not found".
        print(f"❌ {name} not found in environment")
    else:
        print(f"✅ {name} found: {value[:4]}...")


_report_api_key("GROQ_API_KEY", groq_key)
_report_api_key("TAVILY_API_KEY", tavily_key)
|
|
|
|
|
print("\n" + "=" * 60)
print("Building agent...")
print("=" * 60)

# Everything below needs the agent, so a construction failure is fatal:
# report it (with the full traceback, like the per-test error handling does)
# and exit with a non-zero status.
try:
    agent = build_graph()
except Exception as e:
    print(f"❌ Error building agent: {e}")
    traceback.print_exc()
    # sys.exit instead of the exit() builtin: exit() is injected by the `site`
    # module for interactive use and is absent under `python -S` or in
    # frozen executables.
    sys.exit(1)

print("✅ Agent built successfully!")
|
|
|
|
|
|
|
|
# Sample questions sent to the agent. Each case records the prompt text, the
# kind of answer expected ("number" or "text"), and a short label used in the
# progress output.
test_questions = [
    dict(
        question="What is 25 * 4?",
        expected_type="number",
        description="Simple calculation test",
    ),
    dict(
        question=(
            "Who was the first president of the United States? "
            "Answer with just the name."
        ),
        expected_type="text",
        description="Simple knowledge test",
    ),
]
|
|
|
|
|
def _extract_final_answer(content):
    """Return the answer text carried by a message's ``content``.

    ``content`` is normally a string. A list of parts is also accepted: plain
    string parts and ``{"type": "text", "text": ...}`` parts are concatenated
    and any other parts are ignored. If the resulting text contains
    ``"Final Answer:"``, only the stripped text after its last occurrence is
    returned; otherwise the text is returned unchanged.
    """
    if isinstance(content, str):
        text = content
    else:
        # NOTE(review): assumes LangChain's multi-part content shape (str or
        # {"type": "text", ...} items) -- confirm for the model in use.
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(part.get("text", ""))
        text = "".join(pieces)
    if "Final Answer:" in text:
        text = text.split("Final Answer:")[-1].strip()
    return text


def _matches_expected_type(answer, expected_type):
    """Return True if *answer* looks like the kind of value *expected_type* names.

    ``"number"`` requires the answer to parse as a float (commas used as
    thousands separators are ignored); any other type only requires a
    non-blank answer.
    """
    if expected_type == "number":
        try:
            float(answer.replace(",", "").strip())
        except ValueError:
            return False
        return True
    return bool(answer.strip())


print("\n" + "=" * 60)
print("Running tests...")
print("=" * 60)

failures = 0

for i, test in enumerate(test_questions, 1):
    print(f"\n{'=' * 60}")
    print(f"Test {i}: {test['description']}")
    print(f"Question: {test['question']}")
    print("=" * 60)

    expected_type = test.get("expected_type", "text")

    try:
        # A distinct thread_id per test keeps any per-thread state the graph
        # holds (presumably via a checkpointer) from leaking between questions.
        config = {"configurable": {"thread_id": f"test_{i}"}}
        result = agent.invoke(
            {"messages": [HumanMessage(content=test["question"])]},
            config=config,
        )
        answer = _extract_final_answer(result["messages"][-1].content)
    except Exception as e:
        # Report and continue so one broken question does not hide the rest.
        failures += 1
        print(f"❌ Error: {e}")
        traceback.print_exc()
        continue

    if _matches_expected_type(answer, expected_type):
        print(f"✅ Answer: {answer}")
    else:
        failures += 1
        print(f"❌ Answer: {answer} (expected a {expected_type} answer)")

print("\n" + "=" * 60)
if failures:
    print(f"Tests completed with {failures} failure(s)")
else:
    print("Tests completed!")
print("=" * 60)

# Surface the outcome to shells / CI through the exit status, not just stdout.
if failures:
    sys.exit(1)
|
|
|
|
|
|