File size: 3,670 Bytes
58a1fee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import unittest
from unittest.mock import MagicMock, patch
import os

# Set dummy key if not present to avoid init errors
if "GEMINI_API_KEY" not in os.environ:
    os.environ["GEMINI_API_KEY"] = "dummy_key"
if "TAVILY_API_KEY" not in os.environ:
    os.environ["TAVILY_API_KEY"] = "dummy_key"

from agentic_rag_v2_graph import build_agentic_rag_v2_graph

class TestRagGraph(unittest.TestCase):
    @patch('agentic_rag_v2_graph.genai.GenerativeModel')
    @patch('agentic_rag_v2_graph.TavilyClient')
    def test_web_search_flow(self, mock_tavily, mock_genai):
        print("\n\n=== πŸ§ͺ STARTING DRY RUN GRAPH TEST ===")
        print("Goal: Verify 'research_web' -> 'verifier' -> 'responder' flow without API calls.\n")

        # === πŸ”§ DEMO CONFIGURATION (EDIT HERE) ===
        # Change these values to simulate different questions!
        DEMO_QUERY = "Who is the father of the computer?"
        EXPECTED_WEB_CONTENT = "Charles Babbage is considered by many as the father of the computer."
        VERIFIER_NOTE = "βœ… VERIFIED: Search results confirm Charles Babbage invented the Analytical Engine."
        FINAL_ANSWER = "Charles Babbage is the father of the computer."
        # =========================================

        # --- Setup Mocks ---
        mock_model = MagicMock()
        mock_genai.return_value = mock_model
        
        # Helper to create dummy response object
        def create_response(text):
            r = MagicMock()
            r.text = text
            return r

        # Sequence of LLM outputs (Order matters!):
        # 1. Supervisor: "research_web"
        # 2. Verifier: "The info looks consistent."
        # 3. Supervisor: "responder"
        # 4. Responder: "The Answer."
        
        mock_model.generate_content.side_effect = [
            create_response("research_web"),
            create_response(VERIFIER_NOTE),
            create_response("responder"),
            create_response(FINAL_ANSWER)
        ]
        
        # Mock Web Search Tool
        mock_tavily_instance = MagicMock()
        mock_tavily.return_value = mock_tavily_instance
        mock_tavily_instance.get_search_context.return_value = EXPECTED_WEB_CONTENT

        # --- Build Graph ---
        print("πŸ› οΈ  Building Graph...")
        try:
            graph = build_agentic_rag_v2_graph()
            print("βœ… Graph built successfully.")
        except Exception as e:
            self.fail(f"❌ Graph build failed: {e}")
        
        # --- Run Graph ---
        initial_state = {
            "messages": [],
            "query": DEMO_QUERY,
            "final_answer": "",
            "next_node": "",
            "current_tool": "",
            "tool_outputs": [],
            "verification_notes": "",
            "retries": 0
        }
        
        print("\nπŸƒ Invoking Graph (Mocked LLM)...")
        result = graph.invoke(initial_state, config={"configurable": {"thread_id": "test_dry_run"}})
        
        # --- Assertions ---
        print("\n\n=== πŸ“Š TEST RESULT ANALYSIS ===")
        print(f"Final Answer: {result['final_answer']}")
        print(f"Verification Notes: {result['verification_notes']}")
        
        self.assertIn("VERIFIED", result['verification_notes'], "❌ verifier_node did not populate verification_notes!")
        self.assertIn(FINAL_ANSWER, result['final_answer'], "❌ Responder did not fail gracefully.")
        
        print("\nβœ… SUCCESS: The Graph followed the correct path: Supervisor -> Web -> Verifier -> Supervisor -> Responder")
        print("βœ… SUCCESS: 'verifier_node' executed and produced notes.")

if __name__ == "__main__":
    unittest.main()