"""
Tests for LangGraph allocation workflow.
Verifies workflow equivalence, decision logging, and performance.
"""
import pytest
import time
from datetime import date
from uuid import uuid4
from app.schemas.allocation_state import AllocationState
from app.services.langgraph_nodes import (
ml_effort_node,
route_planner_node,
fairness_check_node,
should_reoptimize,
)
from app.services.langgraph_workflow import (
create_allocation_graph,
get_workflow_visualization,
)
class TestAllocationState:
    """Schema-level checks for AllocationState: defaults and dict round-trips."""

    def test_allocation_state_defaults(self):
        """A state built with no arguments exposes empty / None defaults."""
        fresh = AllocationState()

        # Container-typed fields start out empty.
        assert fresh.request == {}
        assert fresh.decision_logs == []
        assert fresh.explanations == {}

        # Optional fields start out unset.
        assert fresh.config_used is None
        assert fresh.effort_matrix is None

    def test_allocation_state_serialization(self):
        """model_dump(mode="json") returns the constructor inputs as plain data."""
        dumped = AllocationState(
            request={"test": "data"},
            driver_models=[{"id": "d1", "name": "Driver 1"}],
        ).model_dump(mode="json")

        assert dumped["request"] == {"test": "data"}
        assert dumped["driver_models"] == [{"id": "d1", "name": "Driver 1"}]

    def test_allocation_state_deserialization(self):
        """model_validate rebuilds a state from a plain dict."""
        raw = {
            "request": {"drivers": []},
            "config_used": {"gini_threshold": 0.35},
            "decision_logs": [{"agent_name": "TEST"}],
        }

        restored = AllocationState.model_validate(raw)

        assert restored.request == {"drivers": []}
        assert restored.config_used["gini_threshold"] == 0.35
        assert len(restored.decision_logs) == 1
class TestLangGraphNodes:
    """Checks the routing decision made by should_reoptimize."""

    def test_should_reoptimize_returns_reoptimize(self):
        """A REOPTIMIZE fairness status with no second proposal routes to 'reoptimize'."""
        needs_second_pass = AllocationState(
            fairness_check_1={"status": "REOPTIMIZE"},
            route_proposal_2=None,
        )
        assert should_reoptimize(needs_second_pass) == "reoptimize"

    def test_should_reoptimize_returns_continue(self):
        """An ACCEPT fairness status routes to 'continue'."""
        accepted = AllocationState(fairness_check_1={"status": "ACCEPT"})
        assert should_reoptimize(accepted) == "continue"

    def test_should_reoptimize_skips_when_proposal2_exists(self):
        """A REOPTIMIZE status is ignored once a second proposal is already present."""
        already_reoptimized = AllocationState(
            fairness_check_1={"status": "REOPTIMIZE"},
            route_proposal_2={"allocation": []},  # Already have proposal 2
        )
        assert should_reoptimize(already_reoptimized) == "continue"
class TestWorkflowGraph:
    """Tests for the LangGraph workflow."""

    def test_create_allocation_graph(self):
        """create_allocation_graph should return an object that can be invoked."""
        graph = create_allocation_graph()
        assert graph is not None
        # The returned object must expose at least one invocation entry point
        # (sync or async).
        assert any(hasattr(graph, method) for method in ("invoke", "ainvoke"))

    def test_workflow_visualization(self):
        """get_workflow_visualization should return a Mermaid diagram."""
        diagram = get_workflow_visualization()
        expected_fragments = (
            "```mermaid",
            "ml_effort",
            "fairness_check_1",
            "explainability",
        )
        for fragment in expected_fragments:
            assert fragment in diagram, f"diagram is missing {fragment!r}"

    def test_graph_with_gemini_disabled(self, monkeypatch):
        """Graph should compile without Gemini node.

        The API key is removed through the ``monkeypatch`` fixture so the
        original environment is restored after the test; popping it from
        ``os.environ`` directly would leak into later tests in the same
        process.
        """
        monkeypatch.delenv("GOOGLE_API_KEY", raising=False)
        graph = create_allocation_graph(enable_gemini=False)
        assert graph is not None
class TestDecisionLogging:
"""Tests for decision log generation."""
def test_ml_effort_node_creates_log(self):
"""ml_effort_node should append to decision_logs."""
# This test would require mock drivers/routes
# Placeholder for full integration test
pass
def test_decision_log_format(self):
"""Decision logs should have required fields."""
log_entry = {
"timestamp": "2026-02-04T10:00:00",
"agent_name": "ML_EFFORT",
"step_type": "MATRIX_GENERATION",
"input_snapshot": {"num_drivers": 5},
"output_snapshot": {"matrix_size": 25},
}
assert "timestamp" in log_entry
assert "agent_name" in log_entry
assert "step_type" in log_entry
assert "input_snapshot" in log_entry
assert "output_snapshot" in log_entry
class TestWorkflowPerformance:
    """Performance tests for the workflow."""

    @pytest.mark.slow
    def test_state_serialization_performance(self):
        """State serialization should be fast.

        Dumps a state holding 100 packages and 50 decision-log entries 100
        times and requires the whole loop to finish in under one second.
        """
        state = AllocationState(
            request={"packages": [{"id": f"pkg_{i}"} for i in range(100)]},
            decision_logs=[{"step": i} for i in range(50)],
        )

        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump and corrupt the measured interval.
        start = time.perf_counter()
        for _ in range(100):
            state.model_dump(mode="json")
        elapsed = time.perf_counter() - start

        # Should serialize 100 times in under 1 second
        assert elapsed < 1.0, f"Serialization too slow: {elapsed:.2f}s"
# Integration test placeholder
class TestWorkflowEquivalence:
    """Equivalence checks between the LangGraph workflow and the original endpoint."""

    @pytest.mark.skip(reason="Requires full DB setup - run manually")
    async def test_workflow_produces_identical_results(self):
        """The LangGraph workflow should match the original endpoint's results.

        Currently a skipped placeholder with no body.
        """
        # Intended comparison: original /allocate response vs /allocate/langgraph.
        # Steps this test would need:
        #   1. Send the same request to both endpoints.
        #   2. Compare the final allocations, ignoring timestamps/UUIDs.
|