File size: 13,854 Bytes
7b2787b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
"""
Tests for the Workflow Engine core components.
"""

import pytest
import asyncio
from typing import Dict, Any

from app.engine.state import WorkflowState, StateManager
from app.engine.node import Node, NodeType, node, create_node_from_function
from app.engine.graph import Graph, END
from app.engine.executor import Executor, ExecutionStatus, execute_graph


# ============================================================
# State Tests
# ============================================================

class TestWorkflowState:
    """Unit tests covering construction, accessors, and serialization of WorkflowState."""

    def test_create_empty_state(self):
        """A default-constructed state has no data, iteration 0, and no visited nodes."""
        blank = WorkflowState()
        assert blank.data == {}
        assert blank.iteration == 0
        assert blank.visited_nodes == []

    def test_create_state_with_data(self):
        """Initial data is readable via ``get``; missing keys yield None or the given default."""
        seeded = WorkflowState(data={"key": "value"})
        assert seeded.get("key") == "value"
        assert seeded.get("missing") is None
        assert seeded.get("missing", "default") == "default"

    def test_state_immutability(self):
        """``set`` produces a distinct instance and leaves the original untouched."""
        original = WorkflowState(data={"a": 1})
        derived = original.set("b", 2)

        assert original.get("b") is None
        assert derived.get("b") == 2
        assert original is not derived

    def test_state_update_multiple(self):
        """``update`` merges several keys while keeping the pre-existing ones."""
        merged = WorkflowState(data={"a": 1}).update({"b": 2, "c": 3})

        for key, expected in (("a", 1), ("b", 2), ("c", 3)):
            assert merged.get(key) == expected

    def test_state_mark_visited(self):
        """Each ``mark_visited`` call records the node and makes it the current node."""
        tracked = WorkflowState()
        for node_name in ("node1", "node2"):
            tracked = tracked.mark_visited(node_name)

        assert "node1" in tracked.visited_nodes
        assert "node2" in tracked.visited_nodes
        assert tracked.current_node == "node2"

    def test_state_to_from_dict(self):
        """A ``to_dict`` / ``from_dict`` round trip preserves the stored data."""
        serialized = WorkflowState(data={"test": 123}).to_dict()

        assert "data" in serialized
        assert serialized["data"]["test"] == 123

        assert WorkflowState.from_dict(serialized).get("test") == 123


class TestStateManager:
    """Tests for StateManager: initialization and update history."""

    def test_initialize(self):
        """Test state manager initialization.

        ``initialize`` should install a current state that holds the supplied
        input and carries a non-None ``started_at``.
        """
        manager = StateManager()
        # The return value is not needed here: the assertions go through
        # ``current_state`` to verify what the manager itself holds.
        manager.initialize({"input": "test"})

        assert manager.current_state is not None
        assert manager.current_state.get("input") == "test"
        assert manager.current_state.started_at is not None

    def test_update_and_history(self):
        """Test that state updates create history.

        One update attributed to ``node1`` should leave exactly one history
        entry tagged with that node, and should advance ``current_state``.
        """
        manager = StateManager()
        state = manager.initialize({"count": 0})

        new_state = state.set("count", 1)
        manager.update(new_state, "node1")

        assert len(manager.history) == 1
        assert manager.history[0].node_name == "node1"
        assert manager.current_state.get("count") == 1


# ============================================================
# Node Tests
# ============================================================

class TestNode:
    """Unit tests for constructing, validating, and running Node objects."""

    def test_create_node(self):
        """A plain Node exposes its name and handler and defaults to the STANDARD type."""
        def passthrough(state):
            return state

        created = Node(name="test_node", handler=passthrough)

        assert created.name == "test_node"
        assert created.handler == passthrough
        assert created.node_type == NodeType.STANDARD

    def test_node_validation(self):
        """Construction rejects an empty name and a non-callable handler with ValueError."""
        with pytest.raises(ValueError, match="name cannot be empty"):
            Node(name="", handler=lambda value: value)

        with pytest.raises(ValueError, match="must be callable"):
            Node(name="test", handler="not a function")

    @pytest.mark.asyncio
    async def test_sync_node_execution(self):
        """A synchronous handler is run through the awaitable ``Node.execute``."""
        def mark_processed(state):
            state["processed"] = True
            return state

        sync_node = Node(name="test", handler=mark_processed)
        outcome = await sync_node.execute({"input": "data"})

        assert outcome["processed"] is True
        assert outcome["input"] == "data"

    @pytest.mark.asyncio
    async def test_async_node_execution(self):
        """A coroutine handler is flagged as async and awaited by ``execute``."""
        async def mark_async(state):
            await asyncio.sleep(0.01)
            state["async_processed"] = True
            return state

        async_node = Node(name="async_test", handler=mark_async)
        assert async_node.is_async is True

        outcome = await async_node.execute({"input": "data"})
        assert outcome["async_processed"] is True

    def test_node_decorator(self):
        """``@node`` attaches ``_node_metadata`` carrying the declared name."""
        @node(name="decorated_node", description="A test node")
        def decorated(state):
            return state

        assert hasattr(decorated, "_node_metadata")
        assert decorated._node_metadata["name"] == "decorated_node"


# ============================================================
# Graph Tests
# ============================================================

class TestGraph:
    """Tests for Graph class: construction, edges, routing, validation, rendering."""

    def test_create_graph(self):
        """Test creating a named, initially empty graph."""
        graph = Graph(name="Test Graph")
        assert graph.name == "Test Graph"
        assert len(graph.nodes) == 0

    def test_add_nodes(self):
        """Test adding nodes; the first node added becomes the entry point."""
        graph = Graph()
        graph.add_node("node1", handler=lambda s: s)
        graph.add_node("node2", handler=lambda s: s)

        assert "node1" in graph.nodes
        assert "node2" in graph.nodes
        assert graph.entry_point == "node1"  # First node is entry

    def test_add_edges(self):
        """Test that a direct edge is stored as ``edges[source] == target``."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_node("b", handler=lambda s: s)
        graph.add_edge("a", "b")

        assert graph.edges["a"] == "b"

    def test_add_edge_to_end(self):
        """Test adding an edge to the END sentinel."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_edge("a", END)

        assert graph.edges["a"] == END

    def test_invalid_edge(self):
        """Test that an edge to an unregistered node raises ValueError."""
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)

        with pytest.raises(ValueError, match="not found"):
            graph.add_edge("a", "nonexistent")

    def test_conditional_edge(self):
        """Test that conditional edges route according to the condition's result."""
        graph = Graph()
        graph.add_node("check", handler=lambda s: s)
        graph.add_node("yes", handler=lambda s: s)
        graph.add_node("no", handler=lambda s: s)

        def condition(state):
            return "yes" if state.get("value") else "no"

        graph.add_conditional_edge("check", condition, {"yes": "yes", "no": "no"})

        # Test routing
        assert graph.get_next_node("check", {"value": True}) == "yes"
        assert graph.get_next_node("check", {"value": False}) == "no"

    def test_graph_validation(self):
        """Test that validate() reports errors for an empty graph and none for a minimal valid one."""
        graph = Graph()

        # Empty graph should fail
        errors = graph.validate()
        assert len(errors) > 0

        # Valid graph: a single node wired to END
        graph.add_node("start", handler=lambda s: s)
        graph.add_edge("start", END)

        errors = graph.validate()
        assert len(errors) == 0

    def test_mermaid_generation(self):
        """Test Mermaid diagram generation.

        Node names are checked only in the text after the ``graph TD`` header.
        The header itself contains the letter ``a``, so checking the whole
        string for ``"a"`` would pass even if no node were rendered.
        """
        graph = Graph()
        graph.add_node("a", handler=lambda s: s)
        graph.add_node("b", handler=lambda s: s)
        graph.add_edge("a", "b")
        graph.add_edge("b", END)

        mermaid = graph.to_mermaid()

        assert "graph TD" in mermaid
        # Remove the first header occurrence so node names must come from the body.
        body = mermaid.replace("graph TD", "", 1)
        assert "a" in body
        assert "b" in body


# ============================================================
# Executor Tests
# ============================================================

class TestExecutor:
    """Tests for the Executor, driven through the ``execute_graph`` helper."""

    @pytest.mark.asyncio
    async def test_simple_execution(self):
        """Test executing a single-node graph that doubles ``value``."""
        graph = Graph()
        graph.add_node("double", handler=lambda s: {**s, "value": s["value"] * 2})
        graph.add_edge("double", END)

        result = await execute_graph(graph, {"value": 5})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["value"] == 10

    @pytest.mark.asyncio
    async def test_multi_node_execution(self):
        """Test executing two chained nodes; each contributes one log entry."""
        graph = Graph()
        graph.add_node("add1", handler=lambda s: {**s, "value": s["value"] + 1})
        graph.add_node("add2", handler=lambda s: {**s, "value": s["value"] + 2})
        graph.add_edge("add1", "add2")
        graph.add_edge("add2", END)

        result = await execute_graph(graph, {"value": 0})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["value"] == 3
        assert len(result.execution_log) == 2

    @pytest.mark.asyncio
    async def test_conditional_execution(self):
        """Test conditional branching on both sides of the threshold.

        Each run must complete successfully and take the expected branch.
        """
        graph = Graph()
        graph.add_node("start", handler=lambda s: s)
        graph.add_node("high", handler=lambda s: {**s, "path": "high"})
        graph.add_node("low", handler=lambda s: {**s, "path": "low"})

        def route(state):
            return "high" if state["value"] > 5 else "low"

        graph.add_conditional_edge("start", route, {"high": "high", "low": "low"})
        graph.add_edge("high", END)
        graph.add_edge("low", END)

        # Test high path
        result = await execute_graph(graph, {"value": 10})
        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["path"] == "high"

        # Test low path
        result = await execute_graph(graph, {"value": 2})
        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["path"] == "low"

    @pytest.mark.asyncio
    async def test_loop_execution(self):
        """Test a self-loop that increments ``count`` until it reaches 3."""
        graph = Graph(max_iterations=10)

        def increment(state):
            return {**state, "count": state["count"] + 1}

        def check_count(state):
            return "done" if state["count"] >= 3 else "continue"

        graph.add_node("increment", handler=increment)
        graph.add_conditional_edge("increment", check_count, {"done": END, "continue": "increment"})

        result = await execute_graph(graph, {"count": 0})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["count"] == 3

    @pytest.mark.asyncio
    async def test_max_iterations(self):
        """Test that an unbounded loop is stopped by the ``max_iterations`` limit."""
        graph = Graph(max_iterations=3)

        # Infinite loop: the condition always routes back to the same node.
        graph.add_node("loop", handler=lambda s: s)
        graph.add_conditional_edge("loop", lambda s: "continue", {"continue": "loop"})

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        # Check presence first so a missing message fails clearly instead of raising TypeError.
        assert result.error is not None
        assert "Max iterations" in result.error

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """Test that an exception raised by a handler yields a FAILED result carrying its message."""
        def failing_handler(state):
            raise ValueError("Intentional error")

        graph = Graph()
        graph.add_node("fail", handler=failing_handler)

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        assert result.error is not None
        assert "Intentional error" in result.error

    @pytest.mark.asyncio
    async def test_execution_log(self):
        """Test that the execution log records each node, in order, with a duration."""
        graph = Graph()
        graph.add_node("step1", handler=lambda s: s)
        graph.add_node("step2", handler=lambda s: s)
        graph.add_edge("step1", "step2")
        graph.add_edge("step2", END)

        result = await execute_graph(graph, {})

        assert len(result.execution_log) == 2
        assert result.execution_log[0].node == "step1"
        assert result.execution_log[1].node == "step2"
        # Pass-through handlers can finish below the timer's resolution. Depending
        # on how the executor measures time (coarse clock or ms rounding), a
        # duration of 0 is legitimate, so only require a non-negative value.
        assert all(s.duration_ms >= 0 for s in result.execution_log)


# ============================================================
# Integration Tests
# ============================================================

class TestCodeReviewWorkflow:
    """End-to-end check that the bundled code-review workflow runs to completion."""

    @pytest.mark.asyncio
    async def test_code_review_workflow(self):
        """Run the complete code-review graph over a tiny two-function sample."""
        # Imported inside the test, so the workflow module is only loaded when this test runs.
        from app.workflows.code_review import create_code_review_workflow

        threshold = 5.0
        snippet = '''
def hello():
    """Says hello."""
    print("Hello, World!")

def add(a, b):
    return a + b
'''

        review_graph = create_code_review_workflow(max_iterations=3, quality_threshold=threshold)
        initial_input = {
            "code": snippet,
            "quality_threshold": threshold,
        }
        outcome = await execute_graph(review_graph, initial_input)

        assert outcome.status == ExecutionStatus.COMPLETED
        for expected_key in ("functions", "quality_score"):
            assert expected_key in outcome.final_state
        assert len(outcome.execution_log) > 0


if __name__ == "__main__":
    # Propagate pytest's exit status so that running this file directly reports
    # test failures to the shell/CI, instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))