File size: 2,155 Bytes
223e45d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
"""
Simple test script for the GAIA agent
"""
import os
from dotenv import load_dotenv
from langchain_core.messages import HumanMessage
from agent import build_graph
# Load environment variables before the os.getenv() lookups below run.
# NOTE(review): python-dotenv presumably reads them from a local .env file
# and leaves already-set variables untouched — confirm against its docs.
load_dotenv()
# Verify API keys are set
def key_status(name, value):
    """Return a one-line status message for the API key named *name*.

    Only the presence and length of the key are reported. The key material
    itself is never echoed, so running this script in a shared terminal or a
    CI job does not leak part of the secret into the logs.

    Args:
        name: Environment variable name, used verbatim in the message.
        value: The value read from the environment; ``None`` or an empty
            string is treated as "not set".

    Returns:
        A message starting with the cross mark when the key is missing and
        with the check mark when it is present.
    """
    if not value:
        return f"❌ {name} not found in environment"
    return f"✅ {name} found (length {len(value)})"


print("Checking API keys...")
# Kept as module-level names so anything that reads them later still works.
groq_key = os.getenv("GROQ_API_KEY")
tavily_key = os.getenv("TAVILY_API_KEY")
# A missing key is only reported here; the script deliberately carries on.
print(key_status("GROQ_API_KEY", groq_key))
print(key_status("TAVILY_API_KEY", tavily_key))
print("\n" + "="*60)
print("Building agent...")
print("="*60)
# Build the agent once up front. The test loop further down calls
# agent.invoke(), so a construction failure aborts the run with exit status 1.
try:
    agent = build_graph()
except Exception as e:
    print(f"❌ Error building agent: {e}")
    # The bare exit() helper is injected by the ``site`` module for
    # interactive sessions and is missing under ``python -S`` or in frozen
    # builds; raising SystemExit is the always-available equivalent.
    raise SystemExit(1) from e
else:
    # Kept outside the try body so that only build_graph() failures are
    # reported as build errors.
    print("✅ Agent built successfully!")
# Simple prompts used to verify basic agent functionality. Each entry holds
# the question text, an expected answer type and a short description; in the
# visible part of this script only "question" and "description" are read.
test_questions = [
    {"question": prompt, "expected_type": kind, "description": label}
    for prompt, kind, label in (
        ("What is 25 * 4?", "number", "Simple calculation test"),
        (
            "Who was the first president of the United States? Answer with just the name.",
            "text",
            "Simple knowledge test",
        ),
    )
]
banner = "=" * 60
marker = "Final Answer:"

print(f"\n{banner}")
print("Running tests...")
print(banner)

# Send every question to the agent under its own thread_id. An exception from
# one question is printed with its traceback and the loop moves on to the next.
for i, test in enumerate(test_questions, 1):
    print("\n" + banner)
    print(f"Test {i}: {test['description']}")
    print(f"Question: {test['question']}")
    print(banner)
    try:
        config = {"configurable": {"thread_id": f"test_{i}"}}
        result = agent.invoke(
            {"messages": [HumanMessage(content=test["question"])]},
            config=config,
        )
        answer = result["messages"][-1].content
        # When the reply carries the marker, keep only the text after its
        # last occurrence.
        if marker in answer:
            tail = answer.split(marker)[-1]
            answer = tail.strip()
        print(f"✅ Answer: {answer}")
    except Exception as e:
        print(f"❌ Error: {e}")
        import traceback
        traceback.print_exc()

print(f"\n{banner}")
print("Tests completed!")
print(banner)
|