rag_agent / test_scripts.py
Cheh Kit Hong
fixing gradio
aa018e3
"""
Test script for RAG Agent logic.
Tests the agent workflow, nodes, state management, and retrieval.
"""
import sys
from pathlib import Path
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))
from langchain_core.messages import HumanMessage, AIMessage
from agent.state import AgentState
from core.rag_agent import RAGAgent
def print_separator(title: str):
    """Print *title* framed above and below by a 70-character rule of '='.

    A blank line precedes the top rule and another follows the bottom rule,
    so consecutive sections are visually separated in the console output.
    """
    rule = "=" * 70
    print(f"\n{rule}\n {title}\n{rule}\n")
def test_agent_initialization():
    """Construct a ``RAGAgent`` and print a short description of it.

    The description covers the agent's thread ID, the LLM's ``model_name``
    (or the word "initialized" when the LLM has no such attribute) and the
    class name of the compiled graph.

    Returns:
        The new agent on success. If construction or the description step
        raises, the traceback is printed and None is returned.
    """
    print_separator("TEST 1: Agent Initialization")
    try:
        instance = RAGAgent()
        print("βœ“ RAGAgent initialized successfully")
        print(f" - Thread ID: {instance.thread_id}")
        model_label = getattr(instance.llm, "model_name", "initialized")
        print(f" - LLM Model: {model_label}")
        print(f" - Graph: {type(instance.agent_graph).__name__}")
    except Exception as exc:
        print(f"βœ— Failed to initialize RAGAgent: {exc}")
        import traceback
        traceback.print_exc()
        return None
    return instance
def test_simple_query(agent: RAGAgent):
    """Send one plain question through the agent graph and report the reply.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True when the graph's final state contains at least one
        ``AIMessage``. False when the agent is missing, no AI reply was
        produced, or the invocation raised; the traceback is printed in the
        last case.
    """
    print_separator("TEST 2: Simple Query")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False

    try:
        question = "What is DeepAnalyze?"
        print(f"Query: '{question}'")
        outcome = agent.agent_graph.invoke(
            {"messages": [HumanMessage(content=question)]},
            config=agent.get_config(),
        )
        history = outcome.get("messages", [])
        replies = [msg for msg in history if isinstance(msg, AIMessage)]
        if not replies:
            print("βœ— No AI response generated")
            return False

        latest = replies[-1].content
        print("βœ“ Query executed successfully")
        print(f" Total messages: {len(history)}")
        print(f" Response length: {len(latest)} chars")
        print("\n Response preview:")
        print(f" {latest[:300]}...")
        return True
    except Exception as exc:
        print(f"βœ— Query execution failed: {exc}")
        import traceback
        traceback.print_exc()
        return False
def test_rag_query(agent: RAGAgent):
    """Test a query that should use RAG (local documents).

    The query runs on a freshly reset thread. The routing decision is then
    made on this question alone, not on the conversation left behind by
    earlier tests that share the same agent.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True if the graph produced at least one ``AIMessage``. False if the
        agent is missing, no reply was produced, or the invocation raised.
        The printed routing decision is informational only and is not checked.
    """
    print_separator("TEST 3: RAG Query")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False
    try:
        # Start from an empty conversation, as the conversation-memory test
        # does "for clean test". Without this, the query is routed in the
        # context of the previous tests' turns on the same thread.
        agent.reset_thread()
        query = "Explain the architecture of SAM 3"
        print(f"Query: '{query}' (should use local documents)")
        initial_state = {
            "messages": [HumanMessage(content=query)],
        }
        result = agent.agent_graph.invoke(
            initial_state,
            config=agent.get_config(),
        )
        messages = result.get("messages", [])
        # Falls back to "UNKNOWN" when the final state has no "rag_method".
        rag_method = result.get("rag_method", "UNKNOWN")
        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
        print(f" Routing decision: {rag_method}")
        if not ai_messages:
            print("βœ— No response generated")
            return False
        reply = ai_messages[-1].content
        # AIMessage.content may be a list of content blocks rather than a str;
        # preview it as text. Only mark it with "..." when it was truncated.
        text = reply if isinstance(reply, str) else str(reply)
        ellipsis = "..." if len(text) > 300 else ""
        print("βœ“ RAG query executed")
        print(" Response preview:")
        print(f" {text[:300]}{ellipsis}")
        return True
    except Exception as e:
        print(f"βœ— RAG query failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_web_search_query(agent: RAGAgent):
    """Test a query that should use web search.

    The query runs on a freshly reset thread. The routing decision is then
    made on this question alone, not on the conversation left behind by
    earlier tests that share the same agent.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True if the graph produced at least one ``AIMessage``. False if the
        agent is missing, no reply was produced, or the invocation raised.
        The printed routing decision is informational only and is not checked.
    """
    print_separator("TEST 4: Web Search Query")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False
    try:
        # Start from an empty conversation, as the conversation-memory test
        # does "for clean test". Without this, the query is routed in the
        # context of the previous tests' turns on the same thread.
        agent.reset_thread()
        query = "What's the latest news about AI in 2025?"
        print(f"Query: '{query}' (should use web search)")
        initial_state = {
            "messages": [HumanMessage(content=query)],
        }
        result = agent.agent_graph.invoke(
            initial_state,
            config=agent.get_config(),
        )
        messages = result.get("messages", [])
        # Falls back to "UNKNOWN" when the final state has no "rag_method".
        rag_method = result.get("rag_method", "UNKNOWN")
        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
        print(f" Routing decision: {rag_method}")
        if not ai_messages:
            print("βœ— No response generated")
            return False
        reply = ai_messages[-1].content
        # AIMessage.content may be a list of content blocks rather than a str;
        # preview it as text. Only mark it with "..." when it was truncated.
        text = reply if isinstance(reply, str) else str(reply)
        ellipsis = "..." if len(text) > 300 else ""
        print("βœ“ Web search query executed")
        print(" Response preview:")
        print(f" {text[:300]}{ellipsis}")
        return True
    except Exception as e:
        print(f"βœ— Web search query failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_general_query(agent: RAGAgent):
    """Test a general query that doesn't need RAG or web search.

    The query runs on a freshly reset thread, so the arithmetic question is
    not routed in the context of the document or news questions asked by
    earlier tests on the same agent.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True if the graph produced at least one ``AIMessage``. False if the
        agent is missing, no reply was produced, or the invocation raised.
        The printed routing decision and answer are not checked.
    """
    print_separator("TEST 5: General Query")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False
    try:
        # Start from an empty conversation, as the conversation-memory test
        # does "for clean test". Without this, earlier tests' turns on the
        # same thread can influence routing.
        agent.reset_thread()
        query = "What is 15 multiplied by 7?"
        print(f"Query: '{query}' (should use general LLM)")
        initial_state = {
            "messages": [HumanMessage(content=query)],
        }
        result = agent.agent_graph.invoke(
            initial_state,
            config=agent.get_config(),
        )
        messages = result.get("messages", [])
        # Falls back to "UNKNOWN" when the final state has no "rag_method".
        rag_method = result.get("rag_method", "UNKNOWN")
        ai_messages = [m for m in messages if isinstance(m, AIMessage)]
        print(f" Routing decision: {rag_method}")
        if not ai_messages:
            print("βœ— No response generated")
            return False
        print("βœ“ General query executed")
        print(f" Response: {ai_messages[-1].content}")
        return True
    except Exception as e:
        print(f"βœ— General query failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_conversation_memory(agent: RAGAgent):
    """Test multi-turn conversation with memory.

    Turn 1 names a subject ("DeepAnalyze"). Turn 2 refers to it only as
    "its", so a sensible answer requires the turn-1 context. Both turns run on
    one freshly reset thread.

    Whether the turn-2 reply actually uses that context is checked only with
    a keyword heuristic. A miss is reported as a warning, not a failure.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True when both turns produced an ``AIMessage``. False if the agent is
        missing, either turn produced no reply, or an invocation raised.
    """
    print_separator("TEST 6: Conversation Memory")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False
    try:
        # Reset thread for clean test
        agent.reset_thread()

        # First turn
        print("Turn 1: 'What is DeepAnalyze?'")
        state1 = {
            "messages": [HumanMessage(content="What is DeepAnalyze?")],
        }
        result1 = agent.agent_graph.invoke(state1, config=agent.get_config())
        ai_msg_1 = [
            m for m in result1.get("messages", []) if isinstance(m, AIMessage)
        ]
        if not ai_msg_1:
            print("βœ— No response in turn 1")
            return False
        print(f"βœ“ Turn 1 response: {ai_msg_1[-1].content[:100]}...")

        # Second turn - follow-up question that only makes sense with context
        print("\nTurn 2: 'What are its main features?' (requires context)")
        state2 = {
            "messages": [HumanMessage(content="What are its main features?")],
        }
        result2 = agent.agent_graph.invoke(state2, config=agent.get_config())
        ai_msg_2 = [
            m for m in result2.get("messages", []) if isinstance(m, AIMessage)
        ]
        if not ai_msg_2:
            print("βœ— No response in turn 2")
            return False
        print(f"βœ“ Turn 2 response: {ai_msg_2[-1].content[:100]}...")

        # AIMessage.content may be a list of content blocks rather than a str.
        # Coerce it so .lower() cannot raise and report a working run as failed.
        content = ai_msg_2[-1].content
        response = (content if isinstance(content, str) else str(content)).lower()
        # "feature" is deliberately not used as a signal: it echoes the turn-2
        # question itself, so it would match even if turn-1 context were lost.
        if "deepanalyze" in response or "agent" in response:
            print("βœ“ Conversation memory working - response uses context")
        else:
            print("⚠ Response may not be using conversation context properly")
        return True  # Pass either way, as a response was generated
    except Exception as e:
        print(f"βœ— Conversation memory test failed: {e}")
        import traceback
        traceback.print_exc()
        return False
def test_thread_reset(agent: RAGAgent):
    """Verify that ``reset_thread`` moves the agent onto a different thread ID.

    Args:
        agent: The agent under test, or None if initialization failed.

    Returns:
        True when ``thread_id`` differs before and after the reset. False when
        it is unchanged, when the agent is missing, or when anything raised;
        the traceback is printed in the last case.
    """
    print_separator("TEST 7: Thread Reset")
    if agent is None:
        print("βœ— Skipping - agent not initialized")
        return False

    try:
        before = agent.thread_id
        print(f"Old thread ID: {before}")
        agent.reset_thread()
        after = agent.thread_id
        print(f"New thread ID: {after}")
        if before != after:
            print("βœ“ Thread reset successfully")
            return True
        print("βœ— Thread ID unchanged after reset")
        return False
    except Exception as exc:
        print(f"βœ— Thread reset failed: {exc}")
        import traceback
        traceback.print_exc()
        return False
def run_all_tests():
    """Run the whole suite against one shared agent and print a summary.

    The agent is built once by ``test_agent_initialization``. If that fails,
    nothing else runs and False is returned. Each remaining test receives the
    shared agent. An exception escaping a test is printed and recorded as a
    failure instead of aborting the run.

    Returns:
        True only when every test after initialization passed.
    """
    banner = "β–ˆ" * 70
    print("\n" + banner)
    print(" RAG AGENT TEST SUITE")
    print(banner)

    # Initialize agent once
    agent = test_agent_initialization()
    if agent is None:
        print("\nβœ— Cannot proceed - agent initialization failed")
        return False

    suite = (
        ("Simple Query", test_simple_query),
        ("RAG Query", test_rag_query),
        ("Web Search Query", test_web_search_query),
        ("General Query", test_general_query),
        ("Conversation Memory", test_conversation_memory),
        ("Thread Reset", test_thread_reset),
    )

    outcomes = {}
    for label, check in suite:
        try:
            outcomes[label] = check(agent)
        except Exception as exc:
            print(f"\nβœ— Test '{label}' crashed: {exc}")
            import traceback
            traceback.print_exc()
            outcomes[label] = False

    # Summary
    print_separator("TEST SUMMARY")
    passed = sum(outcomes.values())
    total = len(outcomes)
    for label, ok in outcomes.items():
        verdict = "βœ“ PASS" if ok else "βœ— FAIL"
        print(f"{verdict}: {label}")

    rule = "=" * 70
    print(f"\n{rule}")
    print(f" TOTAL: {passed}/{total} tests passed ({passed/total*100:.1f}%)")
    print(f"{rule}\n")
    return passed == total
if __name__ == "__main__":
    # Process exit status: 0 when every test passed, 1 otherwise.
    all_passed = run_all_tests()
    exit_code = 0 if all_passed else 1
    sys.exit(exit_code)