"""Test script for RAG Agent logic.

Tests the agent workflow, nodes, state management, and retrieval.

Run directly as a script: every test prints a human-readable report and the
process exits with status 0 only when all tests pass.
"""

import sys
import traceback
from pathlib import Path

# Put this script's directory (assumed to be the project root) at the front
# of the import path so the project-local imports below resolve.
sys.path.insert(0, str(Path(__file__).parent))

from langchain_core.messages import HumanMessage, AIMessage  # noqa: E402
from agent.state import AgentState  # noqa: E402,F401
from core.rag_agent import RAGAgent  # noqa: E402


def print_separator(title: str):
    """Print a visual separator with *title* between two rules of '='."""
    print("\n" + "=" * 70)
    print(f" {title}")
    print("=" * 70 + "\n")


def _report_exception(message: str, exc: Exception) -> None:
    """Print a failure line for *exc* followed by the active traceback.

    Must be called from inside an ``except`` block so that
    ``traceback.print_exc()`` has an exception to print.
    """
    print(f"✗ {message}: {exc}")
    traceback.print_exc()


def _invoke_query(agent: RAGAgent, query: str):
    """Send *query* as a single human message through the agent graph.

    Returns whatever ``agent.agent_graph.invoke`` returns; the callers treat
    it as a mapping with a ``"messages"`` entry.
    """
    return agent.agent_graph.invoke(
        {"messages": [HumanMessage(content=query)]},
        config=agent.get_config(),
    )


def _latest_ai_messages(messages) -> list:
    """Return the AI messages that follow the most recent human message.

    All query tests reuse the same agent thread, and the conversation-memory
    test relies on history surviving between invocations of that thread.
    The returned message list can therefore presumably contain replies from
    earlier turns; counting those would make every "was a response
    generated?" check pass trivially. Restricting the search to messages
    after the last ``HumanMessage`` ties the check to the current turn.

    If *messages* contains no ``HumanMessage`` at all, every ``AIMessage``
    in it is returned.
    """
    last_human = -1
    for index, message in enumerate(messages):
        if isinstance(message, HumanMessage):
            last_human = index
    return [m for m in messages[last_human + 1:] if isinstance(m, AIMessage)]


def test_agent_initialization():
    """Test RAGAgent can be initialized properly.

    Returns:
        The constructed ``RAGAgent`` on success, or ``None`` when
        construction (or reading its attributes for the report) raises.
    """
    print_separator("TEST 1: Agent Initialization")

    try:
        agent = RAGAgent()
        # Not every LLM wrapper exposes ``model_name``; fall back to a
        # generic marker instead of failing the report.
        model_name = getattr(agent.llm, "model_name", "initialized")

        print("✓ RAGAgent initialized successfully")
        print(f" - Thread ID: {agent.thread_id}")
        print(f" - LLM Model: {model_name}")
        print(f" - Graph: {type(agent.agent_graph).__name__}")
        return agent
    except Exception as exc:
        _report_exception("Failed to initialize RAGAgent", exc)
        return None


def test_simple_query(agent: RAGAgent):
    """Test a simple query execution.

    Returns:
        ``True`` when the current turn produced at least one AI message,
        ``False`` when it did not, when *agent* is ``None``, or when the
        invocation raised.
    """
    print_separator("TEST 2: Simple Query")

    if agent is None:
        print("✗ Skipping - agent not initialized")
        return False

    try:
        query = "What is DeepAnalyze?"
        print(f"Query: '{query}'")

        result = _invoke_query(agent, query)
        messages = result.get("messages", [])
        ai_messages = _latest_ai_messages(messages)

        if not ai_messages:
            print("✗ No AI response generated")
            return False

        answer = ai_messages[-1].content
        print("✓ Query executed successfully")
        print(f" Total messages: {len(messages)}")
        print(f" Response length: {len(answer)} chars")
        print("\n Response preview:")
        print(f" {answer[:300]}...")
        return True
    except Exception as exc:
        _report_exception("Query execution failed", exc)
        return False


def _run_routed_query(
    agent: RAGAgent,
    *,
    title: str,
    query: str,
    expectation: str,
    label: str,
    preview: bool = True,
) -> bool:
    """Run one routing test: send *query* and report the routing decision.

    Args:
        agent: Agent under test; ``None`` skips the test (reported as failed).
        title: Heading passed to ``print_separator``.
        query: Text sent to the agent as a human message.
        expectation: Route the query is expected to take; only printed as a
            hint, never asserted.
        label: Human-readable test label used in the success/failure lines.
        preview: When true, print only the first 300 characters of the
            response; otherwise print the whole response.

    Returns:
        ``True`` when the current turn produced at least one AI message,
        ``False`` otherwise (including skip and exceptions).
    """
    print_separator(title)

    if agent is None:
        print("✗ Skipping - agent not initialized")
        return False

    try:
        print(f"Query: '{query}' (should use {expectation})")

        result = _invoke_query(agent, query)
        rag_method = result.get("rag_method", "UNKNOWN")
        ai_messages = _latest_ai_messages(result.get("messages", []))

        print(f" Routing decision: {rag_method}")

        if not ai_messages:
            print("✗ No response generated")
            return False

        answer = ai_messages[-1].content
        print(f"✓ {label} executed")
        if preview:
            print(" Response preview:")
            print(f" {answer[:300]}...")
        else:
            print(f" Response: {answer}")
        return True
    except Exception as exc:
        _report_exception(f"{label} failed", exc)
        return False


def test_rag_query(agent: RAGAgent):
    """Test a query that should use RAG (local documents)."""
    return _run_routed_query(
        agent,
        title="TEST 3: RAG Query",
        query="Explain the architecture of SAM 3",
        expectation="local documents",
        label="RAG query",
    )


def test_web_search_query(agent: RAGAgent):
    """Test a query that should use web search."""
    return _run_routed_query(
        agent,
        title="TEST 4: Web Search Query",
        query="What's the latest news about AI in 2025?",
        expectation="web search",
        label="Web search query",
    )


def test_general_query(agent: RAGAgent):
    """Test a general query that doesn't need RAG or web search."""
    return _run_routed_query(
        agent,
        title="TEST 5: General Query",
        query="What is 15 multiplied by 7?",
        expectation="general LLM",
        label="General query",
        preview=False,
    )


def test_conversation_memory(agent: RAGAgent):
    """Test multi-turn conversation with memory.

    Resets the agent thread, asks a question, then asks a follow-up that
    only makes sense with the first turn as context.

    Returns:
        ``True`` when both turns produced an AI message of their own (the
        keyword check on the second answer only changes the printed note),
        ``False`` on a missing response, a ``None`` agent, or an exception.
    """
    print_separator("TEST 6: Conversation Memory")

    if agent is None:
        print("✗ Skipping - agent not initialized")
        return False

    try:
        # Reset thread for clean test
        agent.reset_thread()

        # First turn
        print("Turn 1: 'What is DeepAnalyze?'")
        result1 = _invoke_query(agent, "What is DeepAnalyze?")
        ai_msg_1 = _latest_ai_messages(result1["messages"])

        if not ai_msg_1:
            print("✗ No response in turn 1")
            return False

        print(f"✓ Turn 1 response: {ai_msg_1[-1].content[:100]}...")

        # Second turn - follow-up question
        print("\nTurn 2: 'What are its main features?' (requires context)")
        result2 = _invoke_query(agent, "What are its main features?")
        # Only messages after the follow-up count; otherwise the turn-1
        # reply kept in the thread history would satisfy this check.
        ai_msg_2 = _latest_ai_messages(result2["messages"])

        if not ai_msg_2:
            print("✗ No response in turn 2")
            return False

        print(f"✓ Turn 2 response: {ai_msg_2[-1].content[:100]}...")

        # Check if response makes sense in context (keyword heuristic)
        response = ai_msg_2[-1].content.lower()
        if any(word in response for word in ("deepanalyze", "feature", "agent")):
            print("✓ Conversation memory working - response uses context")
        else:
            print("⚠ Response may not be using conversation context properly")
        # Still pass either way, as a response was generated
        return True
    except Exception as exc:
        _report_exception("Conversation memory test failed", exc)
        return False


def test_thread_reset(agent: RAGAgent):
    """Test thread reset functionality.

    Returns:
        ``True`` when ``agent.thread_id`` differs after ``reset_thread()``,
        ``False`` otherwise (including a ``None`` agent or an exception).
    """
    print_separator("TEST 7: Thread Reset")

    if agent is None:
        print("✗ Skipping - agent not initialized")
        return False

    try:
        old_thread_id = agent.thread_id
        print(f"Old thread ID: {old_thread_id}")

        agent.reset_thread()
        new_thread_id = agent.thread_id
        print(f"New thread ID: {new_thread_id}")

        if old_thread_id == new_thread_id:
            print("✗ Thread ID unchanged after reset")
            return False

        print("✓ Thread reset successfully")
        return True
    except Exception as exc:
        _report_exception("Thread reset failed", exc)
        return False


def run_all_tests():
    """Run all tests and provide summary.

    The agent is created once by ``test_agent_initialization`` and shared by
    the remaining tests; if creation fails nothing else is run.

    Returns:
        ``True`` when every test after initialization passed, else ``False``.
    """
    print("\n" + "█" * 70)
    print(" RAG AGENT TEST SUITE")
    print("█" * 70)

    # Initialize agent once
    agent = test_agent_initialization()
    if agent is None:
        print("\n✗ Cannot proceed - agent initialization failed")
        return False

    tests = [
        ("Simple Query", test_simple_query),
        ("RAG Query", test_rag_query),
        ("Web Search Query", test_web_search_query),
        ("General Query", test_general_query),
        ("Conversation Memory", test_conversation_memory),
        ("Thread Reset", test_thread_reset),
    ]

    results = {}
    for name, test_func in tests:
        try:
            results[name] = test_func(agent)
        except Exception as exc:
            # The tests catch their own errors; this is a last-resort guard
            # so one crashing test cannot abort the whole run.
            print(f"\n✗ Test '{name}' crashed: {exc}")
            traceback.print_exc()
            results[name] = False

    # Print summary
    print_separator("TEST SUMMARY")

    passed = sum(results.values())
    total = len(results)

    for name, passed_test in results.items():
        status = "✓ PASS" if passed_test else "✗ FAIL"
        print(f"{status}: {name}")

    print(f"\n{'=' * 70}")
    print(f" TOTAL: {passed}/{total} tests passed ({passed / total * 100:.1f}%)")
    print(f"{'=' * 70}\n")

    return passed == total


if __name__ == "__main__":
    success = run_all_tests()
    sys.exit(0 if success else 1)