File size: 6,268 Bytes
fcf8749
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
"""
Tests for LangGraph allocation workflow.
Verifies workflow equivalence, decision logging, and performance.
"""

import pytest
import time
from datetime import date
from uuid import uuid4

from app.schemas.allocation_state import AllocationState
from app.services.langgraph_nodes import (
    ml_effort_node,
    route_planner_node,
    fairness_check_node,
    should_reoptimize,
)
from app.services.langgraph_workflow import (
    create_allocation_graph,
    get_workflow_visualization,
)


class TestAllocationState:
    """Schema-level checks for ``AllocationState``: defaults and dict round-trips."""

    def test_allocation_state_defaults(self):
        """A state built with no arguments exposes empty / ``None`` defaults."""
        fresh = AllocationState()

        # Fields that start out unset.
        assert fresh.config_used is None
        assert fresh.effort_matrix is None

        # Fields that start out as empty containers.
        assert fresh.request == {}
        assert fresh.explanations == {}
        assert fresh.decision_logs == []

    def test_allocation_state_serialization(self):
        """``model_dump(mode="json")`` returns the supplied fields unchanged."""
        request_payload = {"test": "data"}
        drivers = [{"id": "d1", "name": "Driver 1"}]

        dumped = AllocationState(
            request=request_payload,
            driver_models=drivers,
        ).model_dump(mode="json")

        # Compare against fresh literals rather than the input objects.
        assert dumped["request"] == {"test": "data"}
        assert dumped["driver_models"] == [{"id": "d1", "name": "Driver 1"}]

    def test_allocation_state_deserialization(self):
        """``model_validate`` rebuilds a state from a plain dict."""
        raw = {
            "request": {"drivers": []},
            "config_used": {"gini_threshold": 0.35},
            "decision_logs": [{"agent_name": "TEST"}],
        }

        restored = AllocationState.model_validate(raw)

        assert restored.request == {"drivers": []}
        assert restored.config_used["gini_threshold"] == 0.35
        assert len(restored.decision_logs) == 1


class TestLangGraphNodes:
    """Unit tests for the ``should_reoptimize`` routing function."""

    def test_should_reoptimize_returns_reoptimize(self):
        """A REOPTIMIZE fairness status with no second proposal routes to 'reoptimize'."""
        pending = AllocationState(
            fairness_check_1={"status": "REOPTIMIZE"},
            route_proposal_2=None,
        )

        assert should_reoptimize(pending) == "reoptimize"

    def test_should_reoptimize_returns_continue(self):
        """An ACCEPT fairness status routes to 'continue'."""
        accepted = AllocationState(fairness_check_1={"status": "ACCEPT"})

        assert should_reoptimize(accepted) == "continue"

    def test_should_reoptimize_skips_when_proposal2_exists(self):
        """With proposal 2 already present, REOPTIMIZE still routes to 'continue'."""
        already_retried = AllocationState(
            fairness_check_1={"status": "REOPTIMIZE"},
            # A second proposal is already recorded on the state.
            route_proposal_2={"allocation": []},
        )

        assert should_reoptimize(already_retried) == "continue"


class TestWorkflowGraph:
    """Tests for building and visualizing the LangGraph workflow."""

    def test_create_allocation_graph(self):
        """create_allocation_graph should return an object exposing invoke/ainvoke."""
        graph = create_allocation_graph()

        assert graph is not None
        # The returned object must be runnable, either synchronously or asynchronously.
        assert hasattr(graph, "invoke") or hasattr(graph, "ainvoke")

    def test_workflow_visualization(self):
        """get_workflow_visualization should return a Mermaid diagram naming key nodes."""
        diagram = get_workflow_visualization()

        assert "```mermaid" in diagram
        assert "ml_effort" in diagram
        assert "fairness_check_1" in diagram
        assert "explainability" in diagram

    def test_graph_with_gemini_disabled(self, monkeypatch):
        """Graph should build with enable_gemini=False and GOOGLE_API_KEY unset.

        Args:
            monkeypatch: pytest fixture used to remove the API key for this test
                only; pytest restores the original environment on teardown, so
                other tests in the same process are unaffected.
        """
        # raising=False: the variable may legitimately be absent already.
        monkeypatch.delenv("GOOGLE_API_KEY", raising=False)

        graph = create_allocation_graph(enable_gemini=False)

        assert graph is not None


class TestDecisionLogging:
    """Tests for decision log generation."""

    # Marked as skipped so the placeholder does not report as a passing test.
    @pytest.mark.skip(reason="Placeholder - requires mock drivers/routes")
    def test_ml_effort_node_creates_log(self):
        """ml_effort_node should append to decision_logs (not implemented yet)."""
        # This test would require mock drivers/routes.
        # Placeholder for full integration test.

    def test_decision_log_format(self):
        """A sample decision log entry should contain every required field.

        Note: this checks a hand-written sample entry, not the output of a
        workflow node.
        """
        required_fields = (
            "timestamp",
            "agent_name",
            "step_type",
            "input_snapshot",
            "output_snapshot",
        )
        log_entry = {
            "timestamp": "2026-02-04T10:00:00",
            "agent_name": "ML_EFFORT",
            "step_type": "MATRIX_GENERATION",
            "input_snapshot": {"num_drivers": 5},
            "output_snapshot": {"matrix_size": 25},
        }

        # Collect all absent keys so a failure names every missing field at once.
        missing = [field for field in required_fields if field not in log_entry]

        assert not missing, f"Decision log entry missing fields: {missing}"


class TestWorkflowPerformance:
    """Performance tests for the workflow."""

    @pytest.mark.slow
    def test_state_serialization_performance(self):
        """Serializing a moderately sized state 100 times should take under 1 second."""
        state = AllocationState(
            request={"packages": [{"id": f"pkg_{i}"} for i in range(100)]},
            decision_logs=[{"step": i} for i in range(50)],
        )

        # perf_counter is monotonic and high-resolution; time.time() is a wall
        # clock that can jump (NTP/DST adjustments) and skew the measurement.
        start = time.perf_counter()
        for _ in range(100):
            state.model_dump(mode="json")
        elapsed = time.perf_counter() - start

        # Should serialize 100 times in under 1 second
        assert elapsed < 1.0, f"Serialization too slow: {elapsed:.2f}s"


# Integration test placeholder
class TestWorkflowEquivalence:
    """Equivalence checks between the LangGraph workflow and the original flow."""

    @pytest.mark.skip(reason="Requires full DB setup - run manually")
    async def test_workflow_produces_identical_results(self):
        """LangGraph workflow should match the original endpoint's results.

        Not implemented yet. The intended comparison is between the original
        /allocate response and the /allocate/langgraph response:

        1. Send the same request to both endpoints.
        2. Compare the final allocations, ignoring timestamps/UUIDs.
        """
        return None