kbsss committed on
Commit
7b2787b
·
verified ·
1 Parent(s): f695e34

Upload folder using huggingface_hub

Browse files
.dockerignore ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker
2
+ .git
3
+ .gitignore
4
+ venv/
5
+ __pycache__/
6
+ *.pyc
7
+ *.pyo
8
+ *.pyd
9
+ .Python
10
+ env/
11
+ .env
12
+ .venv/
13
+ *.egg-info/
14
+ .eggs/
15
+ dist/
16
+ build/
17
+
18
+ # IDE
19
+ .idea/
20
+ .vscode/
21
+ *.swp
22
+ *.swo
23
+
24
+ # Testing artifacts
25
+ .pytest_cache/
26
+ .coverage
27
+ htmlcov/
28
+ .tox/
29
+ .nox/
30
+
31
+ # Documentation
32
+ *.md
33
+ !README.md
34
+
35
+ # Misc
36
+ .DS_Store
37
+ Thumbs.db
38
+ *.log
.env.example ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Example .env file
2
+ # Copy this to .env and modify as needed
3
+
4
+ # Application
5
+ APP_NAME="Workflow Engine"
6
+ APP_VERSION="1.0.0"
7
+ DEBUG=true
8
+
9
+ # Server
10
+ HOST=0.0.0.0
11
+ PORT=8000
12
+
13
+ # Workflow Engine
14
+ MAX_ITERATIONS=100
15
+ EXECUTION_TIMEOUT=300
16
+
17
+ # Logging
18
+ LOG_LEVEL=INFO
.gitignore ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Environment variables
2
+ .env
3
+
4
+ # Python
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+ *.so
9
+ .Python
10
+ build/
11
+ develop-eggs/
12
+ dist/
13
+ downloads/
14
+ eggs/
15
+ .eggs/
16
+ lib/
17
+ lib64/
18
+ parts/
19
+ sdist/
20
+ var/
21
+ wheels/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # Virtual environments
27
+ venv/
28
+ ENV/
29
+ env/
30
+ .venv/
31
+
32
+ # IDE
33
+ .idea/
34
+ .vscode/
35
+ *.swp
36
+ *.swo
37
+ *~
38
+
39
+ # Testing
40
+ .pytest_cache/
41
+ .coverage
42
+ htmlcov/
43
+ .tox/
44
+ .nox/
45
+
46
+ # Type checking
47
+ .mypy_cache/
48
+
49
+ # Logs
50
+ *.log
51
+
52
+ # OS
53
+ .DS_Store
54
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use Python 3.11 slim image
FROM python:3.11-slim

# Set working directory
WORKDIR /app

# Set environment variables
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

# Install system dependencies (gcc for packages that must build from source)
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for better caching
COPY requirements.txt .

# Install Python dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY . .

# Expose port (HuggingFace uses 7860 by default)
EXPOSE 7860

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1

# Run the application.
# Exec (JSON-array) form so uvicorn runs as PID 1 and receives SIGTERM
# directly; the previous shell form wrapped it in /bin/sh, which does not
# forward signals and delays graceful shutdown until the kill timeout.
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,16 @@
1
  ---
2
- title: Flowgraph
3
- emoji: 🌖
4
- colorFrom: yellow
5
- colorTo: indigo
6
  sdk: docker
7
  pinned: false
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
1
  ---
2
+ title: FlowGraph
3
+ emoji: 🔄
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
+ app_port: 7860
10
  ---
11
 
12
+ # FlowGraph
13
+
14
+ A lightweight workflow orchestration engine for building agent pipelines.
15
+
16
+ Check out the API docs at `/docs`
app/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlowGraph - A lightweight, async-first workflow orchestration engine.
3
+
4
+ Build agent pipelines with nodes, edges, conditional branching, and looping.
5
+ Similar to LangGraph, but minimal and focused.
6
+ """
7
+
8
+ __version__ = "1.0.0"
9
+ __author__ = "AI Engineering Intern"
app/api/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ API package - FastAPI routes and schemas.
3
+ """
4
+
5
+ from app.api.routes import graph, tools, websocket
6
+
7
+ __all__ = ["graph", "tools", "websocket"]
app/api/routes/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ API Routes package.
3
+ """
4
+
5
+ from app.api.routes import graph, tools, websocket
6
+
7
+ __all__ = ["graph", "tools", "websocket"]
app/api/routes/graph.py ADDED
@@ -0,0 +1,644 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph API Routes.
3
+
4
+ Endpoints for creating, managing, and executing workflow graphs.
5
+ """
6
+
7
+ from typing import Any, Dict, Optional
8
+ from fastapi import APIRouter, HTTPException, BackgroundTasks, status
9
+ from uuid import uuid4
10
+ import logging
11
+
12
+ from app.api.schemas import (
13
+ GraphCreateRequest,
14
+ GraphCreateResponse,
15
+ GraphRunRequest,
16
+ GraphRunResponse,
17
+ GraphInfoResponse,
18
+ GraphListResponse,
19
+ RunStateResponse,
20
+ RunListResponse,
21
+ ExecutionLogEntry,
22
+ ExecutionStatus,
23
+ ErrorResponse,
24
+ )
25
+ from app.engine.graph import Graph, END
26
+ from app.engine.node import Node, get_registered_node
27
+ from app.engine.executor import Executor, ExecutionResult
28
+ from app.storage.memory import graph_storage, run_storage
29
+ from app.tools.registry import tool_registry
30
+
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+ router = APIRouter(prefix="/graph", tags=["Graph"])
35
+
36
+
37
+ # ============================================================
38
+ # Condition Functions Registry
39
+ # ============================================================
40
+
41
# Built-in condition functions for routing.
# Maps condition name -> callable(state) -> route key.
_condition_registry: Dict[str, Any] = {}


def register_condition(*names: str):
    """Decorator registering a condition function under one or more names.

    Backward compatible with the original single-name form
    (``@register_condition("quality_check")``); accepting several names
    removes the previous copy-pasted duplicate condition functions.
    """
    def decorator(func):
        for name in names:
            _condition_registry[name] = func
        return func
    return decorator


# "quality_meets_threshold" is the alias used by the code review workflow.
@register_condition("quality_check", "quality_meets_threshold")
def quality_check_condition(state: Dict[str, Any]) -> str:
    """Route based on quality score vs threshold."""
    quality_score = state.get("quality_score", 0)
    threshold = state.get("quality_threshold", 7.0)
    return "pass" if quality_score >= threshold else "fail"


# Module-level alias kept so existing by-name references keep working.
quality_meets_threshold = quality_check_condition


# "always_loop" is the alias used by the code review workflow.
@register_condition("always_continue", "always_loop")
def always_continue(state: Dict[str, Any]) -> str:
    """Always returns 'continue' - for unconditional looping."""
    return "continue"


# Module-level alias kept so existing by-name references keep working.
always_loop = always_continue


@register_condition("always_end")
def always_end(state: Dict[str, Any]) -> str:
    """Always returns 'end' - for explicit termination."""
    return "end"


@register_condition("max_iterations_check")
def max_iterations_check(state: Dict[str, Any]) -> str:
    """Check if max iterations reached (reads _iteration / _max_iterations)."""
    iteration = state.get("_iteration", 0)
    max_iter = state.get("_max_iterations", 3)
    return "stop" if iteration >= max_iter else "continue"


def get_condition(name: str):
    """Get a registered condition function by name (None if unknown)."""
    return _condition_registry.get(name)
100
+
101
+
102
+ # ============================================================
103
+ # Graph CRUD Endpoints
104
+ # ============================================================
105
+
106
@router.post(
    "/create",
    response_model=GraphCreateResponse,
    status_code=status.HTTP_201_CREATED,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid graph definition"},
        404: {"model": ErrorResponse, "description": "Handler not found"},
    }
)
async def create_graph(request: GraphCreateRequest) -> GraphCreateResponse:
    """
    Create a new workflow graph.

    Define nodes with their handlers, edges for flow control,
    and conditional edges for branching logic.
    """
    new_id = str(uuid4())

    workflow = Graph(
        graph_id=new_id,
        name=request.name,
        description=request.description or "",
        max_iterations=request.max_iterations,
    )

    # Resolve each node's handler: registered node functions first,
    # then the tool registry as a fallback.
    for spec in request.nodes:
        resolved = get_registered_node(spec.handler)
        if resolved is None:
            if tool_registry.get(spec.handler):
                resolved = _create_node_handler_from_tool(spec.handler)
            else:
                raise HTTPException(
                    status_code=404,
                    detail=f"Handler '{spec.handler}' not found. "
                           f"Available handlers: {list(tool_registry.list_tools())}"
                )

        workflow.add_node(
            name=spec.name,
            handler=resolved,
            description=spec.description or "",
        )

    # Direct edges ("__END__" is accepted as an alias for END).
    for src, dst in request.edges.items():
        if src not in workflow.nodes:
            raise HTTPException(
                status_code=400,
                detail=f"Edge source '{src}' is not a valid node"
            )
        if dst != END and dst != "__END__" and dst not in workflow.nodes:
            raise HTTPException(
                status_code=400,
                detail=f"Edge target '{dst}' is not a valid node"
            )
        workflow.add_edge(src, END if dst == "__END__" else dst)

    # Conditional edges: resolve the named condition, normalize routes.
    for src, branch in request.conditional_edges.items():
        if src not in workflow.nodes:
            raise HTTPException(
                status_code=400,
                detail=f"Conditional edge source '{src}' is not a valid node"
            )

        chooser = get_condition(branch.condition)
        if chooser is None:
            raise HTTPException(
                status_code=404,
                detail=f"Condition '{branch.condition}' not found. "
                       f"Available: {list(_condition_registry.keys())}"
            )

        normalized = {}
        for label, dst in branch.routes.items():
            if dst == "__END__":
                normalized[label] = END
                continue
            if dst not in workflow.nodes:
                raise HTTPException(
                    status_code=400,
                    detail=f"Conditional route target '{dst}' is not a valid node"
                )
            normalized[label] = dst

        workflow.add_conditional_edge(src, chooser, normalized)

    # Optional explicit entry point.
    if request.entry_point:
        if request.entry_point not in workflow.nodes:
            raise HTTPException(
                status_code=400,
                detail=f"Entry point '{request.entry_point}' is not a valid node"
            )
        workflow.set_entry_point(request.entry_point)

    # Structural validation before anything is persisted.
    problems = workflow.validate()
    if problems:
        raise HTTPException(
            status_code=400,
            detail=f"Graph validation failed: {problems}"
        )

    await graph_storage.save(
        graph_id=new_id,
        name=request.name,
        definition=workflow.to_dict(),
    )

    logger.info(f"Created graph: {new_id} ({request.name})")

    return GraphCreateResponse(
        graph_id=new_id,
        name=request.name,
        message="Graph created successfully",
        node_count=len(workflow.nodes),
    )
234
+
235
+
236
def _create_node_handler_from_tool(tool_name: str):
    """Wrap a registered tool as a node handler that updates the state dict."""
    def handler(state: Dict[str, Any]) -> Dict[str, Any]:
        # Resolved at call time so re-registered tools are picked up.
        tool = tool_registry.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found")

        import inspect
        params = list(inspect.signature(tool.func).parameters)

        # A lone 'state' parameter marks a node-handler-style tool that
        # receives the whole state; anything else is a regular tool fed
        # with matching keys pulled out of the state.
        if params == ["state"]:
            result = tool.func(state)
        else:
            result = tool.func(**_extract_tool_args(tool, state))

        # Dict results are merged back into the (mutated) state; a tool
        # that returned the state object itself is passed through as-is.
        if isinstance(result, dict):
            if result is state:
                return result
            state.update(result)

        return state

    handler.__name__ = f"{tool_name}_handler"
    return handler
269
+
270
+
271
+ def _extract_tool_args(tool, state: Dict[str, Any]) -> Dict[str, Any]:
272
+ """Extract arguments for a tool from state."""
273
+ import inspect
274
+ sig = inspect.signature(tool.func)
275
+ args = {}
276
+
277
+ for param_name, param in sig.parameters.items():
278
+ if param_name in state:
279
+ args[param_name] = state[param_name]
280
+ elif param.default != inspect.Parameter.empty:
281
+ pass # Use default
282
+ # Skip missing optional params
283
+
284
+ return args
285
+
286
+
287
@router.get(
    "/{graph_id}",
    response_model=GraphInfoResponse,
    responses={404: {"model": ErrorResponse}},
)
async def get_graph(graph_id: str) -> GraphInfoResponse:
    """Get information about a specific graph."""
    record = await graph_storage.get(graph_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Graph '{graph_id}' not found")

    spec = record.definition
    node_map = spec.get("nodes", {})

    return GraphInfoResponse(
        graph_id=record.graph_id,
        name=record.name,
        description=spec.get("description"),
        node_count=len(node_map),
        nodes=list(node_map.keys()),
        entry_point=spec.get("entry_point"),
        max_iterations=spec.get("max_iterations", 100),
        created_at=record.created_at.isoformat(),
        # Diagram is only rendered on this detail view, not in listings.
        mermaid_diagram=_generate_mermaid(spec),
    )
314
+
315
+
316
def _generate_mermaid(definition: Dict[str, Any]) -> str:
    """Render a graph definition as a Mermaid top-down flowchart."""
    nodes = definition.get("nodes", {})
    edges = definition.get("edges", {})
    cond_edges = definition.get("conditional_edges", {})

    out = ["graph TD"]

    # One labelled box per node.
    for name in nodes:
        pretty = name.replace("_", " ").title()
        out.append(f' {name}["{pretty}"]')

    # Draw the terminal END bubble only when something routes to it.
    terminal_used = END in edges.values() or any(
        END in branch.get("routes", {}).values() for branch in cond_edges.values()
    )
    if terminal_used:
        out.append(f' {END}(("END"))')

    # Unconditional transitions.
    for src, dst in edges.items():
        out.append(f" {src} --> {dst}")

    # Conditional transitions, labelled with their route key.
    for src, branch in cond_edges.items():
        for label, dst in branch.get("routes", {}).items():
            out.append(f" {src} -->|{label}| {dst}")

    return "\n".join(out)
348
+
349
+
350
@router.get(
    "/",
    response_model=GraphListResponse,
)
async def list_graphs() -> GraphListResponse:
    """List all available graphs."""
    records = await graph_storage.list_all()

    summaries = [
        GraphInfoResponse(
            graph_id=g.graph_id,
            name=g.name,
            description=g.definition.get("description"),
            node_count=len(g.definition.get("nodes", {})),
            nodes=list(g.definition.get("nodes", {}).keys()),
            entry_point=g.definition.get("entry_point"),
            max_iterations=g.definition.get("max_iterations", 100),
            created_at=g.created_at.isoformat(),
            mermaid_diagram=None,  # diagrams are only rendered on the detail view
        )
        for g in records
    ]

    return GraphListResponse(graphs=summaries, total=len(summaries))
374
+
375
+
376
@router.delete(
    "/{graph_id}",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={404: {"model": ErrorResponse}},
)
async def delete_graph(graph_id: str):
    """Delete a graph (204 on success, 404 when the id is unknown)."""
    if not await graph_storage.delete(graph_id):
        raise HTTPException(status_code=404, detail=f"Graph '{graph_id}' not found")
    logger.info(f"Deleted graph: {graph_id}")
387
+
388
+
389
+ # ============================================================
390
+ # Execution Endpoints
391
+ # ============================================================
392
+
393
@router.post(
    "/run",
    response_model=GraphRunResponse,
    responses={
        404: {"model": ErrorResponse},
        500: {"model": ErrorResponse, "description": "Execution failed"},
    }
)
async def run_graph(
    request: GraphRunRequest,
    background_tasks: BackgroundTasks,
) -> GraphRunResponse:
    """
    Execute a workflow graph with the given initial state.

    If `async_execution` is True, the workflow runs in the background
    and you can poll the status using GET /graph/state/{run_id}.
    """
    record = await graph_storage.get(request.graph_id)
    if not record:
        raise HTTPException(
            status_code=404,
            detail=f"Graph '{request.graph_id}' not found"
        )

    # The stored definition is plain data; turn it back into a Graph.
    workflow = await _rebuild_graph_from_definition(record.definition)

    run_id = str(uuid4())
    await run_storage.create(run_id, request.graph_id, request.initial_state)

    if request.async_execution:
        # Fire-and-poll: schedule the execution and answer immediately
        # with a PENDING shell the client can poll on.
        background_tasks.add_task(
            _execute_in_background,
            workflow,
            run_id,
            request.initial_state,
        )
        return GraphRunResponse(
            run_id=run_id,
            graph_id=request.graph_id,
            status=ExecutionStatus.PENDING,
            final_state={},
            execution_log=[],
            started_at=None,
            completed_at=None,
            total_duration_ms=None,
            iterations=0,
        )

    # Synchronous path: execute inline and persist the outcome.
    try:
        runner = Executor(
            workflow,
            run_id=run_id,
            on_step=lambda step, state: _update_run_state(run_id, step, state),
        )
        outcome = await runner.run(request.initial_state)

        if outcome.status.value == "completed":
            await run_storage.complete(
                run_id,
                outcome.final_state,
                [s.to_dict() for s in outcome.execution_log],
            )
        else:
            await run_storage.fail(run_id, outcome.error or "Unknown error", outcome.final_state)

        return _result_to_response(outcome)

    except Exception as e:
        logger.exception(f"Execution failed: {e}")
        await run_storage.fail(run_id, str(e))
        raise HTTPException(status_code=500, detail=str(e))
472
+
473
+
474
async def _rebuild_graph_from_definition(definition: Dict[str, Any]) -> Graph:
    """Rebuild a Graph object from its stored (JSON-safe) definition."""
    rebuilt = Graph(
        graph_id=definition.get("graph_id", str(uuid4())),
        name=definition.get("name", "Unnamed"),
        description=definition.get("description", ""),
        max_iterations=definition.get("max_iterations", 100),
    )

    # Handler callables are not serialized, so each node is re-wired to a
    # tool-backed handler resolved by name at execution time.
    for node_name, info in definition.get("nodes", {}).items():
        rebuilt.add_node(
            name=node_name,
            handler=_create_node_handler_from_tool(info.get("handler", node_name)),
            description=info.get("description", ""),
        )

    for src, dst in definition.get("edges", {}).items():
        rebuilt.add_edge(src, dst)

    for src, info in definition.get("conditional_edges", {}).items():
        # Unknown condition names degrade to the unconditional router.
        chooser = get_condition(info.get("condition", "always_continue"))
        if chooser is None:
            chooser = always_continue
        rebuilt.add_conditional_edge(src, chooser, info.get("routes", {}))

    if definition.get("entry_point"):
        rebuilt.set_entry_point(definition["entry_point"])

    return rebuilt
513
+
514
+
515
async def _execute_in_background(graph: Graph, run_id: str, initial_state: Dict[str, Any]):
    """Execute a workflow in the background and persist the outcome."""
    try:
        runner = Executor(
            graph,
            run_id=run_id,
            on_step=lambda step, state: _update_run_state(run_id, step, state),
        )
        outcome = await runner.run(initial_state)

        if outcome.status.value == "completed":
            await run_storage.complete(
                run_id,
                outcome.final_state,
                [s.to_dict() for s in outcome.execution_log],
            )
        else:
            await run_storage.fail(run_id, outcome.error or "Unknown error", outcome.final_state)

    except Exception as e:
        # Background tasks have no caller to surface errors to; record
        # the failure against the run instead.
        logger.exception(f"Background execution failed: {e}")
        await run_storage.fail(run_id, str(e))
537
+
538
+
539
+ def _update_run_state(run_id: str, step, state: Dict[str, Any]):
540
+ """Update run state during execution (sync callback)."""
541
+ import asyncio
542
+ try:
543
+ loop = asyncio.get_event_loop()
544
+ if loop.is_running():
545
+ asyncio.create_task(
546
+ run_storage.update_state(run_id, state, step.node, step.iteration)
547
+ )
548
+ except Exception:
549
+ pass # Ignore errors in callback
550
+
551
+
552
def _result_to_response(result: ExecutionResult) -> GraphRunResponse:
    """Convert an engine ExecutionResult into the API response model."""
    def _iso(dt):
        # Optional timestamps: serialize only when present.
        return dt.isoformat() if dt else None

    log_entries = [
        ExecutionLogEntry(
            step=entry.step,
            node=entry.node,
            started_at=entry.started_at.isoformat(),
            completed_at=_iso(entry.completed_at),
            duration_ms=entry.duration_ms,
            iteration=entry.iteration,
            result=entry.result,
            error=entry.error,
            route_taken=entry.route_taken,
        )
        for entry in result.execution_log
    ]

    return GraphRunResponse(
        run_id=result.run_id,
        graph_id=result.graph_id,
        status=ExecutionStatus(result.status.value),
        final_state=result.final_state,
        execution_log=log_entries,
        started_at=_iso(result.started_at),
        completed_at=_iso(result.completed_at),
        total_duration_ms=result.total_duration_ms,
        iterations=result.iterations,
        error=result.error,
    )
579
+
580
+
581
+ # ============================================================
582
+ # Run State Endpoints
583
+ # ============================================================
584
+
585
@router.get(
    "/state/{run_id}",
    response_model=RunStateResponse,
    responses={404: {"model": ErrorResponse}},
)
async def get_run_state(run_id: str) -> RunStateResponse:
    """
    Get the current state of a workflow run.

    Use this to poll the status of async executions.
    """
    record = await run_storage.get(run_id)
    if not record:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    return RunStateResponse(
        run_id=record.run_id,
        graph_id=record.graph_id,
        status=ExecutionStatus(record.status),
        current_node=record.current_node,
        current_state=record.current_state,
        iteration=record.iteration,
        execution_log=[ExecutionLogEntry(**entry) for entry in record.execution_log],
        started_at=record.started_at.isoformat(),
        completed_at=record.completed_at.isoformat() if record.completed_at else None,
        error=record.error,
    )
614
+
615
+
616
# NOTE(review): this route is registered AFTER GET "/{graph_id}", and
# FastAPI matches routes in registration order, so GET /graph/runs is
# captured by get_graph and answers 404 ("Graph 'runs' not found").
# Fixing it requires moving this registration above the parameterized
# route - a file-level reorder, deliberately not done in this edit.
@router.get(
    "/runs",
    response_model=RunListResponse,
)
async def list_runs(graph_id: Optional[str] = None) -> RunListResponse:
    """List all runs, optionally filtered by graph_id."""
    if graph_id:
        records = await run_storage.list_by_graph(graph_id)
    else:
        records = await run_storage.list_all()

    states = [
        RunStateResponse(
            run_id=r.run_id,
            graph_id=r.graph_id,
            status=ExecutionStatus(r.status),
            current_node=r.current_node,
            current_state=r.current_state,
            iteration=r.iteration,
            execution_log=[ExecutionLogEntry(**entry) for entry in r.execution_log],
            started_at=r.started_at.isoformat(),
            completed_at=r.completed_at.isoformat() if r.completed_at else None,
            error=r.error,
        )
        for r in records
    ]

    return RunListResponse(runs=states, total=len(states))
app/api/routes/tools.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tools API Routes.
3
+
4
+ Endpoints for listing and managing registered tools.
5
+ """
6
+
7
+ from typing import Any, Dict
8
+ from fastapi import APIRouter, HTTPException, status
9
+ import logging
10
+
11
+ from app.api.schemas import (
12
+ ToolInfo,
13
+ ToolListResponse,
14
+ ToolRegisterRequest,
15
+ ToolRegisterResponse,
16
+ ErrorResponse,
17
+ )
18
+ from app.tools.registry import tool_registry
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ router = APIRouter(prefix="/tools", tags=["Tools"])
24
+
25
+
26
@router.get(
    "/",
    response_model=ToolListResponse,
)
async def list_tools() -> ToolListResponse:
    """
    List all registered tools.

    Tools are functions that workflow nodes can use during execution.
    """
    catalog = [
        ToolInfo(
            name=entry["name"],
            description=entry["description"],
            parameters=entry["parameters"],
        )
        for entry in tool_registry.list_tools()
    ]
    return ToolListResponse(tools=catalog, total=len(catalog))
48
+
49
+
50
@router.get(
    "/{tool_name}",
    response_model=ToolInfo,
    responses={404: {"model": ErrorResponse}},
)
async def get_tool(tool_name: str) -> ToolInfo:
    """Get information about a specific tool."""
    entry = tool_registry.get(tool_name)
    if not entry:
        raise HTTPException(
            status_code=404,
            detail=f"Tool '{tool_name}' not found"
        )
    return ToolInfo(
        name=entry.name,
        description=entry.description,
        parameters=entry.parameters,
    )
69
+
70
+
71
@router.post(
    "/register",
    response_model=ToolRegisterResponse,
    status_code=status.HTTP_201_CREATED,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid tool code"},
        409: {"model": ErrorResponse, "description": "Tool already exists"},
    }
)
async def register_tool(request: ToolRegisterRequest) -> ToolRegisterResponse:
    """
    Register a new tool dynamically.

    **Warning**: This endpoint executes Python code. Use with caution
    and only in trusted environments.

    The code should define a function that:
    - Takes parameters as needed
    - Returns a dictionary with results
    """
    if tool_registry.has(request.name):
        raise HTTPException(
            status_code=409,
            detail=f"Tool '{request.name}' already exists"
        )

    # SECURITY: exec() on request-supplied code - trusted environments only.
    namespace: Dict[str, Any] = {}
    try:
        exec(request.code, namespace)
    except SyntaxError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Syntax error in tool code: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error registering tool: {e}"
        )

    # Pick the first plain function defined by the code. Bug fixes vs.
    # the original: (1) inspect.isfunction() instead of callable() so
    # classes and imported builtins are not registered by mistake, and
    # (2) the "not found" HTTPException is no longer raised inside the
    # broad try/except that re-wrapped it with a misleading message.
    import inspect
    func = next(
        (value for name, value in namespace.items()
         if inspect.isfunction(value) and not name.startswith("_")),
        None,
    )
    if func is None:
        raise HTTPException(
            status_code=400,
            detail="No callable function found in the provided code"
        )

    try:
        tool_registry.add(
            func=func,
            name=request.name,
            description=request.description,
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error registering tool: {e}"
        )

    logger.info(f"Registered dynamic tool: {request.name}")

    return ToolRegisterResponse(
        name=request.name,
        message=f"Tool '{request.name}' registered successfully",
        warning="Dynamic tool registration executes code. Use responsibly.",
    )
144
+
145
+
146
@router.delete(
    "/{tool_name}",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={404: {"model": ErrorResponse}},
)
async def delete_tool(tool_name: str):
    """Delete a registered tool (built-ins are protected)."""
    # The core analysis tools ship with the app and must stay available.
    protected = {
        "extract_functions",
        "calculate_complexity",
        "detect_issues",
        "suggest_improvements",
        "quality_check",
    }
    if tool_name in protected:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot delete built-in tool '{tool_name}'"
        )

    if not tool_registry.remove(tool_name):
        raise HTTPException(
            status_code=404,
            detail=f"Tool '{tool_name}' not found"
        )

    logger.info(f"Deleted tool: {tool_name}")
app/api/routes/websocket.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ WebSocket Routes for Real-time Execution Streaming.
3
+
4
+ Provides live updates during workflow execution.
5
+ """
6
+
7
+ from typing import Any, Dict, Set
8
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
9
+ from uuid import uuid4
10
+ import asyncio
11
+ import json
12
+ import logging
13
+
14
+ from app.engine.graph import Graph
15
+ from app.engine.executor import Executor, ExecutionStep
16
+ from app.storage.memory import graph_storage, run_storage
17
+
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+ router = APIRouter(tags=["WebSocket"])
22
+
23
+
24
class ConnectionManager:
    """Tracks open WebSocket connections, grouped by run ID."""

    def __init__(self):
        # run_id -> set of sockets currently subscribed to that run
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, run_id: str):
        """Accept a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.setdefault(run_id, set()).add(websocket)
        logger.info(f"WebSocket connected for run: {run_id}")

    def disconnect(self, websocket: WebSocket, run_id: str):
        """Remove a WebSocket connection."""
        sockets = self.active_connections.get(run_id)
        if sockets is not None:
            sockets.discard(websocket)
            # Drop the bucket entirely once its last subscriber is gone.
            if not sockets:
                del self.active_connections[run_id]
            logger.info(f"WebSocket disconnected for run: {run_id}")

    async def broadcast(self, run_id: str, message: Dict[str, Any]):
        """Broadcast a message to all connections for a run."""
        sockets = self.active_connections.get(run_id)
        if sockets is None:
            return

        failed = set()
        for ws in sockets:
            try:
                await ws.send_json(message)
            except Exception:
                failed.add(ws)

        # Prune clients whose send raised; they are presumed disconnected.
        for ws in failed:
            sockets.discard(ws)
59
+
60
+
61
+ # Global connection manager
62
+ manager = ConnectionManager()
63
+
64
+
65
@router.websocket("/ws/run/{graph_id}")
async def websocket_run(websocket: WebSocket, graph_id: str):
    """
    WebSocket endpoint for real-time workflow execution.

    Connect to this endpoint and send the initial state as JSON.
    You'll receive step-by-step updates as the workflow executes.

    Message format (client -> server):
    ```json
    {"action": "start", "initial_state": {"code": "..."}}
    ```

    Message format (server -> client):
    ```json
    {
        "type": "step",
        "step": 1,
        "node": "extract",
        "status": "completed",
        "duration_ms": 15.5,
        "state": {...}
    }
    ```
    """
    # Check if graph exists; 4004 is an application-chosen close code.
    stored = await graph_storage.get(graph_id)
    if not stored:
        await websocket.close(code=4004, reason=f"Graph '{graph_id}' not found")
        return

    # The run ID is generated server-side and reported in the "started" ack.
    run_id = str(uuid4())
    await manager.connect(websocket, run_id)

    try:
        # Wait for start message (first client frame must be the start action).
        data = await websocket.receive_json()

        if data.get("action") != "start":
            await websocket.send_json({
                "type": "error",
                "error": "Expected 'start' action"
            })
            return

        initial_state = data.get("initial_state", {})

        # Send acknowledgment
        await websocket.send_json({
            "type": "started",
            "run_id": run_id,
            "graph_id": graph_id,
        })

        # Rebuild graph from its stored definition.
        graph = await _rebuild_graph(stored.definition)

        # Create run record
        await run_storage.create(run_id, graph_id, initial_state)

        # Execute with streaming updates. Broadcasts go to every socket
        # subscribed under this run_id via the global ConnectionManager.
        async def on_step(step: ExecutionStep, state: Dict[str, Any]):
            await manager.broadcast(run_id, {
                "type": "step",
                "step": step.step,
                "node": step.node,
                "status": step.result,
                "duration_ms": step.duration_ms,
                "iteration": step.iteration,
                "route_taken": step.route_taken,
                "error": step.error,
                "state": state,
            })

        # NOTE(review): on_step is not passed to the Executor; steps are
        # replayed after execution by _run_with_streaming, so updates are
        # post-hoc rather than live — confirm this is intended.
        executor = Executor(graph, run_id=run_id)

        # Run with step notifications
        result = await _run_with_streaming(executor, initial_state, on_step)

        # Send completion
        await websocket.send_json({
            "type": "completed",
            "run_id": run_id,
            "status": result.status.value,
            "final_state": result.final_state,
            "total_duration_ms": result.total_duration_ms,
            "iterations": result.iterations,
            "error": result.error,
        })

        # Update storage: persist success with the serialized log, or the error.
        if result.status.value == "completed":
            await run_storage.complete(
                run_id,
                result.final_state,
                [s.to_dict() for s in result.execution_log],
            )
        else:
            await run_storage.fail(run_id, result.error or "Unknown error")

    except WebSocketDisconnect:
        logger.info(f"Client disconnected from run {run_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
        # Best-effort error report; the socket may already be closed.
        try:
            await websocket.send_json({
                "type": "error",
                "error": str(e),
            })
        except Exception:
            pass
    finally:
        manager.disconnect(websocket, run_id)
178
+
179
+
180
async def _rebuild_graph(definition: Dict[str, Any]) -> Graph:
    """Rebuild a Graph from its stored definition.

    Delegates to the builder in the graph routes module; the import is
    kept local to avoid a circular import between route modules.
    """
    from app.api.routes.graph import _rebuild_graph_from_definition

    rebuilt = await _rebuild_graph_from_definition(definition)
    return rebuilt
184
+
185
+
186
+ async def _run_with_streaming(
187
+ executor: Executor,
188
+ initial_state: Dict[str, Any],
189
+ on_step
190
+ ):
191
+ """Run executor with async step callbacks."""
192
+ from app.engine.graph import END
193
+ from app.engine.state import StateManager
194
+ import time
195
+ from datetime import datetime
196
+
197
+ # Execute the workflow
198
+ result = await executor.run(initial_state)
199
+
200
+ # Stream each step (already executed, but we notify)
201
+ for step in result.execution_log:
202
+ await on_step(step, result.final_state)
203
+ await asyncio.sleep(0.01) # Small delay for streaming effect
204
+
205
+ return result
206
+
207
+
208
@router.websocket("/ws/subscribe/{run_id}")
async def websocket_subscribe(websocket: WebSocket, run_id: str):
    """
    Subscribe to updates for an existing run.

    Use this to watch an async execution started via POST /graph/run.

    Sends one "current_state" snapshot on connect, then polls storage
    every 0.5s, forwarding new execution-log entries as "step" messages
    and closing with a "completed" message once the run reaches a
    terminal status.
    """
    # Check if run exists; 4004 is an application-chosen close code.
    stored = await run_storage.get(run_id)
    if not stored:
        await websocket.close(code=4004, reason=f"Run '{run_id}' not found")
        return

    await manager.connect(websocket, run_id)

    try:
        # Send current state snapshot so late subscribers catch up.
        await websocket.send_json({
            "type": "current_state",
            "run_id": run_id,
            "status": stored.status,
            "current_node": stored.current_node,
            "iteration": stored.iteration,
            "state": stored.current_state,
        })

        # Keep connection open and poll for updates.
        # Only entries beyond this high-water mark are forwarded.
        last_log_count = len(stored.execution_log)

        while True:
            await asyncio.sleep(0.5)  # Poll interval

            # Re-read the run; a vanished record ends the subscription.
            stored = await run_storage.get(run_id)
            if not stored:
                break

            # Send new log entries (each entry merged into the message).
            if len(stored.execution_log) > last_log_count:
                for entry in stored.execution_log[last_log_count:]:
                    await websocket.send_json({
                        "type": "step",
                        **entry,
                    })
                last_log_count = len(stored.execution_log)

            # Check if completed — terminal statuses end the loop.
            if stored.status in ("completed", "failed", "cancelled"):
                await websocket.send_json({
                    "type": "completed",
                    "run_id": run_id,
                    "status": stored.status,
                    "final_state": stored.final_state,
                    "error": stored.error,
                })
                break

    except WebSocketDisconnect:
        logger.info(f"Subscriber disconnected from run {run_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket, run_id)
app/api/schemas.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Pydantic Schemas for API Request/Response Models.
3
+
4
+ These schemas define the structure of data flowing through the API,
5
+ providing automatic validation and documentation.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional, Union
9
+ from pydantic import BaseModel, Field
10
+ from datetime import datetime
11
+ from enum import Enum
12
+
13
+
14
+ # ============================================================
15
+ # Enums
16
+ # ============================================================
17
+
18
+ class ExecutionStatus(str, Enum):
19
+ """Status of a workflow execution."""
20
+ PENDING = "pending"
21
+ RUNNING = "running"
22
+ COMPLETED = "completed"
23
+ FAILED = "failed"
24
+ CANCELLED = "cancelled"
25
+
26
+
27
+ # ============================================================
28
+ # Node Schemas
29
+ # ============================================================
30
+
31
+ class NodeDefinition(BaseModel):
32
+ """Definition of a node in the graph."""
33
+ name: str = Field(..., description="Unique name for the node")
34
+ handler: str = Field(..., description="Name of the handler function (must be registered)")
35
+ description: Optional[str] = Field(None, description="Human-readable description")
36
+
37
+ class Config:
38
+ json_schema_extra = {
39
+ "example": {
40
+ "name": "extract",
41
+ "handler": "extract_functions",
42
+ "description": "Extract function definitions from code"
43
+ }
44
+ }
45
+
46
+
47
+ # ============================================================
48
+ # Edge Schemas
49
+ # ============================================================
50
+
51
+ class ConditionalRoutes(BaseModel):
52
+ """Routes for a conditional edge."""
53
+ condition: str = Field(..., description="Name of the condition function")
54
+ routes: Dict[str, str] = Field(
55
+ ...,
56
+ description="Mapping of condition results to target nodes"
57
+ )
58
+
59
+ class Config:
60
+ json_schema_extra = {
61
+ "example": {
62
+ "condition": "quality_check",
63
+ "routes": {
64
+ "pass": "__END__",
65
+ "fail": "improve"
66
+ }
67
+ }
68
+ }
69
+
70
+
71
+ # ============================================================
72
+ # Graph Schemas
73
+ # ============================================================
74
+
75
+ class GraphCreateRequest(BaseModel):
76
+ """Request to create a new workflow graph."""
77
+ name: str = Field(..., description="Name of the workflow")
78
+ description: Optional[str] = Field(None, description="Description of what this workflow does")
79
+ nodes: List[NodeDefinition] = Field(..., description="List of nodes in the graph")
80
+ edges: Dict[str, str] = Field(
81
+ default_factory=dict,
82
+ description="Direct edges: source -> target"
83
+ )
84
+ conditional_edges: Dict[str, ConditionalRoutes] = Field(
85
+ default_factory=dict,
86
+ description="Conditional edges with routing logic"
87
+ )
88
+ entry_point: Optional[str] = Field(None, description="Entry node (defaults to first node)")
89
+ max_iterations: int = Field(100, description="Maximum loop iterations", ge=1, le=1000)
90
+
91
+ class Config:
92
+ json_schema_extra = {
93
+ "example": {
94
+ "name": "code_review_workflow",
95
+ "description": "Automated code review with quality checks",
96
+ "nodes": [
97
+ {"name": "extract", "handler": "extract_functions"},
98
+ {"name": "complexity", "handler": "calculate_complexity"},
99
+ {"name": "issues", "handler": "detect_issues"},
100
+ {"name": "improve", "handler": "suggest_improvements"}
101
+ ],
102
+ "edges": {
103
+ "extract": "complexity",
104
+ "complexity": "issues"
105
+ },
106
+ "conditional_edges": {
107
+ "issues": {
108
+ "condition": "quality_check",
109
+ "routes": {"pass": "__END__", "fail": "improve"}
110
+ },
111
+ "improve": {
112
+ "condition": "always_continue",
113
+ "routes": {"continue": "issues"}
114
+ }
115
+ },
116
+ "entry_point": "extract",
117
+ "max_iterations": 10
118
+ }
119
+ }
120
+
121
+
122
+ class GraphCreateResponse(BaseModel):
123
+ """Response after creating a graph."""
124
+ graph_id: str = Field(..., description="Unique identifier for the created graph")
125
+ name: str = Field(..., description="Name of the workflow")
126
+ message: str = Field(default="Graph created successfully")
127
+ node_count: int = Field(..., description="Number of nodes in the graph")
128
+
129
+ class Config:
130
+ json_schema_extra = {
131
+ "example": {
132
+ "graph_id": "abc123-def456",
133
+ "name": "code_review_workflow",
134
+ "message": "Graph created successfully",
135
+ "node_count": 4
136
+ }
137
+ }
138
+
139
+
140
+ class GraphInfoResponse(BaseModel):
141
+ """Response with graph information."""
142
+ graph_id: str
143
+ name: str
144
+ description: Optional[str]
145
+ node_count: int
146
+ nodes: List[str]
147
+ entry_point: Optional[str]
148
+ max_iterations: int
149
+ created_at: str
150
+ mermaid_diagram: Optional[str] = Field(None, description="Mermaid diagram of the graph")
151
+
152
+
153
+ class GraphListResponse(BaseModel):
154
+ """Response listing all graphs."""
155
+ graphs: List[GraphInfoResponse]
156
+ total: int
157
+
158
+
159
+ # ============================================================
160
+ # Run Schemas
161
+ # ============================================================
162
+
163
+ class GraphRunRequest(BaseModel):
164
+ """Request to run a workflow graph."""
165
+ graph_id: str = Field(..., description="ID of the graph to run")
166
+ initial_state: Dict[str, Any] = Field(
167
+ ...,
168
+ description="Initial state data for the workflow"
169
+ )
170
+ async_execution: bool = Field(
171
+ False,
172
+ description="If true, run in background and return immediately"
173
+ )
174
+
175
+ class Config:
176
+ json_schema_extra = {
177
+ "example": {
178
+ "graph_id": "abc123-def456",
179
+ "initial_state": {
180
+ "code": "def hello():\n print('world')",
181
+ "quality_threshold": 7.0
182
+ },
183
+ "async_execution": False
184
+ }
185
+ }
186
+
187
+
188
+ class ExecutionLogEntry(BaseModel):
189
+ """A single entry in the execution log."""
190
+ step: int
191
+ node: str
192
+ started_at: str
193
+ completed_at: Optional[str]
194
+ duration_ms: Optional[float]
195
+ iteration: int
196
+ result: str
197
+ error: Optional[str]
198
+ route_taken: Optional[str]
199
+
200
+
201
+ class GraphRunResponse(BaseModel):
202
+ """Response after running a graph."""
203
+ run_id: str = Field(..., description="Unique identifier for this run")
204
+ graph_id: str
205
+ status: ExecutionStatus
206
+ final_state: Dict[str, Any]
207
+ execution_log: List[ExecutionLogEntry]
208
+ started_at: Optional[str]
209
+ completed_at: Optional[str]
210
+ total_duration_ms: Optional[float]
211
+ iterations: int
212
+ error: Optional[str] = None
213
+
214
+ class Config:
215
+ json_schema_extra = {
216
+ "example": {
217
+ "run_id": "run-xyz789",
218
+ "graph_id": "abc123-def456",
219
+ "status": "completed",
220
+ "final_state": {
221
+ "code": "def hello():\n print('world')",
222
+ "functions": [{"name": "hello"}],
223
+ "quality_score": 8.5
224
+ },
225
+ "execution_log": [
226
+ {
227
+ "step": 1,
228
+ "node": "extract",
229
+ "started_at": "2024-01-01T12:00:00",
230
+ "completed_at": "2024-01-01T12:00:01",
231
+ "duration_ms": 15.5,
232
+ "iteration": 0,
233
+ "result": "success",
234
+ "error": None,
235
+ "route_taken": None
236
+ }
237
+ ],
238
+ "started_at": "2024-01-01T12:00:00",
239
+ "completed_at": "2024-01-01T12:00:05",
240
+ "total_duration_ms": 5000.0,
241
+ "iterations": 1,
242
+ "error": None
243
+ }
244
+ }
245
+
246
+
247
+ class RunStateResponse(BaseModel):
248
+ """Response with current run state."""
249
+ run_id: str
250
+ graph_id: str
251
+ status: ExecutionStatus
252
+ current_node: Optional[str]
253
+ current_state: Dict[str, Any]
254
+ iteration: int
255
+ execution_log: List[ExecutionLogEntry]
256
+ started_at: str
257
+ completed_at: Optional[str]
258
+ error: Optional[str]
259
+
260
+
261
+ class RunListResponse(BaseModel):
262
+ """Response listing runs."""
263
+ runs: List[RunStateResponse]
264
+ total: int
265
+
266
+
267
+ # ============================================================
268
+ # Tool Schemas
269
+ # ============================================================
270
+
271
+ class ToolInfo(BaseModel):
272
+ """Information about a registered tool."""
273
+ name: str
274
+ description: str
275
+ parameters: Dict[str, str]
276
+
277
+
278
+ class ToolListResponse(BaseModel):
279
+ """Response listing all registered tools."""
280
+ tools: List[ToolInfo]
281
+ total: int
282
+
283
+
284
+ class ToolRegisterRequest(BaseModel):
285
+ """Request to register a new tool (for dynamic registration)."""
286
+ name: str = Field(..., description="Unique name for the tool")
287
+ description: str = Field("", description="Description of what the tool does")
288
+ code: str = Field(..., description="Python code for the tool function")
289
+
290
+ class Config:
291
+ json_schema_extra = {
292
+ "example": {
293
+ "name": "custom_validator",
294
+ "description": "Custom validation logic",
295
+ "code": "def custom_validator(data):\n return {'valid': True}"
296
+ }
297
+ }
298
+
299
+
300
+ class ToolRegisterResponse(BaseModel):
301
+ """Response after registering a tool."""
302
+ name: str
303
+ message: str
304
+ warning: Optional[str] = None
305
+
306
+
307
+ # ============================================================
308
+ # Error Schemas
309
+ # ============================================================
310
+
311
+ class ErrorResponse(BaseModel):
312
+ """Standard error response."""
313
+ error: str
314
+ detail: Optional[str] = None
315
+ status_code: int
316
+
317
+
318
+ class ValidationErrorResponse(BaseModel):
319
+ """Validation error response."""
320
+ error: str = "Validation Error"
321
+ detail: List[Dict[str, Any]]
322
+ status_code: int = 422
app/config.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration settings for the Workflow Engine.
3
+ """
4
+
5
+ from pydantic_settings import BaseSettings
6
+ from typing import Optional
7
+
8
+
9
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Defaults below are overridden by matching (case-sensitive) environment
    variables or entries in the `.env` file.
    """

    # Application
    # NOTE(review): default differs from .env.example ("Workflow Engine") —
    # confirm which name is canonical.
    APP_NAME: str = "FlowGraph"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = True

    # Server
    HOST: str = "0.0.0.0"
    PORT: int = 8000

    # Workflow Engine
    MAX_ITERATIONS: int = 100  # Default max loop iterations
    EXECUTION_TIMEOUT: int = 300  # Seconds

    # Logging
    LOG_LEVEL: str = "INFO"

    class Config:
        env_file = ".env"
        case_sensitive = True
31
+
32
+
33
+ # Global settings instance
34
+ settings = Settings()
app/engine/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engine package - Core workflow orchestration components.
3
+ """
4
+
5
+ from app.engine.state import WorkflowState, StateManager
6
+ from app.engine.node import Node, node
7
+ from app.engine.graph import Graph
8
+ from app.engine.executor import Executor, ExecutionResult
9
+
10
+ __all__ = [
11
+ "WorkflowState",
12
+ "StateManager",
13
+ "Node",
14
+ "node",
15
+ "Graph",
16
+ "Executor",
17
+ "ExecutionResult",
18
+ ]
app/engine/executor.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Async Workflow Executor.
3
+
4
+ The executor runs a workflow graph, managing state transitions,
5
+ handling loops, and generating execution logs.
6
+ """
7
+
8
+ from typing import Any, Callable, Dict, List, Optional
9
+ from dataclasses import dataclass, field
10
+ from datetime import datetime
11
+ from enum import Enum
12
+ import asyncio
13
+ import uuid
14
+ import time
15
+ import logging
16
+
17
+ from app.engine.graph import Graph, END
18
+ from app.engine.state import WorkflowState, StateManager
19
+
20
+
21
+ # Configure logging
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class ExecutionStatus(str, Enum):
    """Status of a workflow execution.

    NOTE(review): a duplicate of this enum exists in app/api/schemas.py;
    keep the two value sets in sync (or import one from the other).
    """
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
32
+
33
+
34
@dataclass
class ExecutionStep:
    """One entry in the execution log: a single node invocation."""
    step: int                                    # 1-based step counter
    node: str                                    # node name that executed
    started_at: datetime
    completed_at: Optional[datetime] = None
    duration_ms: Optional[float] = None
    iteration: int = 0                           # loop iteration this ran in
    result: str = "success"                      # "success" or "error"
    error: Optional[str] = None
    route_taken: Optional[str] = None            # conditional-route key

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (datetimes as ISO-8601 strings)."""
        finished = self.completed_at
        return {
            "step": self.step,
            "node": self.node,
            "started_at": self.started_at.isoformat(),
            "completed_at": finished.isoformat() if finished else None,
            "duration_ms": self.duration_ms,
            "iteration": self.iteration,
            "result": self.result,
            "error": self.error,
            "route_taken": self.route_taken,
        }
59
+
60
+
61
@dataclass
class ExecutionResult:
    """Final outcome of a workflow execution, including its full step log."""
    run_id: str
    graph_id: str
    status: ExecutionStatus
    final_state: Dict[str, Any]
    execution_log: List[ExecutionStep] = field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None
    total_duration_ms: Optional[float] = None
    error: Optional[str] = None                  # populated only on failure
    iterations: int = 0

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (nested steps via their to_dict)."""
        def _iso(dt):
            # Timestamps may be unset on early failures.
            return dt.isoformat() if dt else None

        return {
            "run_id": self.run_id,
            "graph_id": self.graph_id,
            "status": self.status.value,
            "final_state": self.final_state,
            "execution_log": [entry.to_dict() for entry in self.execution_log],
            "started_at": _iso(self.started_at),
            "completed_at": _iso(self.completed_at),
            "total_duration_ms": self.total_duration_ms,
            "error": self.error,
            "iterations": self.iterations,
        }
88
+
89
+
90
class Executor:
    """
    Async workflow executor.

    Executes a graph with given initial state, handling:
    - Sequential node execution
    - Conditional branching
    - Loop iterations with max limit
    - Detailed execution logging
    - Error handling

    Usage:
        executor = Executor(graph)
        result = await executor.run({"input": "data"})
    """

    def __init__(
        self,
        graph: Graph,
        run_id: Optional[str] = None,
        on_step: Optional[Callable[[ExecutionStep, Dict[str, Any]], None]] = None
    ):
        """
        Initialize the executor.

        Args:
            graph: The workflow graph to execute
            run_id: Optional run ID (generated if not provided)
            on_step: Optional callback for each step (for WebSocket streaming)
        """
        self.graph = graph
        self.run_id = run_id or str(uuid.uuid4())
        self.on_step = on_step

        # Execution state (mutated by run(); one execution per instance)
        self._state_manager: Optional[StateManager] = None
        self._execution_log: List[ExecutionStep] = []
        self._step_counter = 0
        self._status = ExecutionStatus.PENDING
        self._cancelled = False

    @property
    def status(self) -> ExecutionStatus:
        """Get the current execution status."""
        return self._status

    @property
    def current_state(self) -> Optional[Dict[str, Any]]:
        """Get the current state data, or None before the run starts."""
        if self._state_manager and self._state_manager.current_state:
            return self._state_manager.current_state.data
        return None

    @property
    def current_node(self) -> Optional[str]:
        """Get the current node being executed, or None before the run starts."""
        if self._state_manager and self._state_manager.current_state:
            return self._state_manager.current_state.current_node
        return None

    def cancel(self) -> None:
        """Cancel the execution.

        Sets a flag checked at the top of the run loop; the in-flight node
        finishes before the loop exits.
        """
        self._cancelled = True
        self._status = ExecutionStatus.CANCELLED

    async def run(self, initial_state: Dict[str, Any]) -> ExecutionResult:
        """
        Execute the workflow with the given initial state.

        Args:
            initial_state: Initial state data

        Returns:
            ExecutionResult with final state and logs
        """
        start_time = time.time()
        self._status = ExecutionStatus.RUNNING
        self._state_manager = StateManager(self.run_id)

        # Initialize state
        state = self._state_manager.initialize(initial_state)

        # Validate graph before executing anything.
        errors = self.graph.validate()
        if errors:
            return self._create_error_result(
                f"Graph validation failed: {errors}",
                start_time
            )

        current_node = self.graph.entry_point
        iteration = 0
        # Nodes seen since the last loop boundary; used for loop detection.
        visited_in_iteration: set = set()

        try:
            while current_node and current_node != END:
                # Check cancellation
                if self._cancelled:
                    logger.info(f"Execution cancelled at node '{current_node}'")
                    break

                # Check max iterations (loop-safety limit from the graph)
                if iteration >= self.graph.max_iterations:
                    return self._create_error_result(
                        f"Max iterations ({self.graph.max_iterations}) exceeded",
                        start_time
                    )

                # Get the node
                node = self.graph.nodes.get(current_node)
                if not node:
                    return self._create_error_result(
                        f"Node '{current_node}' not found in graph",
                        start_time
                    )

                # Execute the node (logs the step and updates the manager)
                step = await self._execute_node(node, state, iteration)

                # Handle error — any node failure aborts the whole run.
                if step.result == "error":
                    return self._create_error_result(
                        step.error or "Unknown error",
                        start_time
                    )

                # Update state from state manager
                state = self._state_manager.current_state

                # Get next node
                next_node = self.graph.get_next_node(current_node, state.data)

                # Track route for conditional edges.
                # NOTE(review): the condition is evaluated a second time here
                # purely to record the route; assumes condition functions are
                # pure/deterministic — confirm.
                if current_node in self.graph.conditional_edges:
                    cond_edge = self.graph.conditional_edges[current_node]
                    route_key = cond_edge.condition(state.data)
                    step.route_taken = route_key
                    logger.debug(f"Conditional route: {route_key} -> {next_node}")

                # Detect loops and increment iteration.
                # NOTE(review): the result of increment_iteration() is held
                # only in the local `state` and is overwritten at the top of
                # the next pass by the manager's current_state without being
                # written back — verify iteration tracking in WorkflowState.
                if next_node in visited_in_iteration:
                    iteration += 1
                    visited_in_iteration.clear()
                    state = state.increment_iteration()
                    logger.debug(f"Loop detected, iteration: {iteration}")

                visited_in_iteration.add(current_node)
                current_node = next_node

            # Finalize.
            # NOTE(review): this also runs after a cancellation break and
            # overwrites the CANCELLED status with COMPLETED — confirm intent.
            self._status = ExecutionStatus.COMPLETED
            final_state = self._state_manager.finalize()

            return ExecutionResult(
                run_id=self.run_id,
                graph_id=self.graph.graph_id,
                status=self._status,
                final_state=final_state.data,
                execution_log=self._execution_log,
                started_at=final_state.started_at,
                completed_at=final_state.completed_at,
                total_duration_ms=(time.time() - start_time) * 1000,
                iterations=iteration + 1,
            )

        except Exception as e:
            logger.exception(f"Execution failed: {e}")
            return self._create_error_result(str(e), start_time)

    async def _execute_node(
        self,
        node,
        state: WorkflowState,
        iteration: int
    ) -> ExecutionStep:
        """Execute a single node, update state, log the step, and fire on_step."""
        self._step_counter += 1
        step_start = datetime.now()
        node_start_time = time.time()

        step = ExecutionStep(
            step=self._step_counter,
            node=node.name,
            started_at=step_start,
            iteration=iteration,
        )

        logger.info(f"Executing node: {node.name} (step {self._step_counter})")

        try:
            # Execute node handler
            result_data = await node.execute(state.data)

            # Update state (merge handler output, record the visit)
            new_state = state.update(result_data).mark_visited(node.name)
            self._state_manager.update(new_state, node.name)

            # Complete step
            step.completed_at = datetime.now()
            step.duration_ms = (time.time() - node_start_time) * 1000
            step.result = "success"

        except Exception as e:
            # Node failures are recorded on the step; the caller decides
            # whether to abort the run.
            logger.error(f"Node {node.name} failed: {e}")
            step.completed_at = datetime.now()
            step.duration_ms = (time.time() - node_start_time) * 1000
            step.result = "error"
            step.error = str(e)

        # Add to log
        self._execution_log.append(step)

        # Notify callback; callback errors are logged but never fatal.
        if self.on_step:
            try:
                self.on_step(step, self._state_manager.current_state.data)
            except Exception as e:
                logger.warning(f"Step callback failed: {e}")

        return step

    def _create_error_result(
        self,
        error: str,
        start_time: float
    ) -> ExecutionResult:
        """Create a FAILED ExecutionResult carrying whatever state exists.

        NOTE(review): started_at is set to now() rather than the actual run
        start time — confirm whether that is acceptable for error results.
        """
        self._status = ExecutionStatus.FAILED
        return ExecutionResult(
            run_id=self.run_id,
            graph_id=self.graph.graph_id,
            status=ExecutionStatus.FAILED,
            final_state=self.current_state or {},
            execution_log=self._execution_log,
            started_at=datetime.now(),
            completed_at=datetime.now(),
            total_duration_ms=(time.time() - start_time) * 1000,
            error=error,
        )

    def get_execution_summary(self) -> Dict[str, Any]:
        """Get a snapshot summary of the current execution (safe mid-run)."""
        return {
            "run_id": self.run_id,
            "graph_id": self.graph.graph_id,
            "status": self._status.value,
            "current_node": self.current_node,
            "current_state": self.current_state,
            "step_count": self._step_counter,
            "iteration": self._state_manager.current_state.iteration if self._state_manager and self._state_manager.current_state else 0,
        }
341
+
342
+
343
async def execute_graph(
    graph: Graph,
    initial_state: Dict[str, Any],
    run_id: Optional[str] = None,
    on_step: Optional[Callable] = None
) -> ExecutionResult:
    """
    Convenience wrapper: build an Executor and run the graph in one call.

    Args:
        graph: The workflow graph
        initial_state: Initial state data
        run_id: Optional run ID
        on_step: Optional step callback

    Returns:
        ExecutionResult
    """
    runner = Executor(graph, run_id, on_step)
    outcome = await runner.run(initial_state)
    return outcome
app/engine/graph.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Graph Definition for Workflow Engine.
3
+
4
+ The Graph is the core structure that defines the workflow - nodes, edges,
5
+ conditional routing, and execution flow.
6
+ """
7
+
8
+ from typing import Any, Callable, Dict, List, Optional, Set, Union
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ import uuid
12
+
13
+ from app.engine.node import Node, NodeType, get_registered_node, create_node_from_function
14
+
15
+
16
+ # Special node names
17
+ END = "__END__"
18
+ START = "__START__"
19
+
20
+
21
class EdgeType(str, Enum):
    """Kinds of edges that can connect two workflow nodes."""

    # Edge that is always followed.
    DIRECT = "direct"
    # Edge whose target is chosen by evaluating a condition on state.
    CONDITIONAL = "conditional"
25
+
26
+
27
@dataclass
class Edge:
    """A single connection from a source node to a target node."""

    source: str
    target: str
    edge_type: EdgeType = EdgeType.DIRECT

    def to_dict(self) -> Dict[str, str]:
        """Serialize the edge; the type is stored as its string value."""
        serialized = {"source": self.source, "target": self.target}
        serialized["type"] = self.edge_type.value
        return serialized
40
+
41
+
42
@dataclass
class ConditionalEdge:
    """A branching edge: a condition on state selects the next node.

    The condition function receives the current state data and returns a
    route key; ``routes`` maps each route key to a target node name.
    """

    source: str
    condition: Callable[[Dict[str, Any]], str]
    routes: Dict[str, str]  # route_key -> target node name

    def evaluate(self, state_data: Dict[str, Any]) -> str:
        """Run the condition and resolve its route key to a node name.

        Raises:
            ValueError: if the condition returns a key not present in routes.
        """
        key = self.condition(state_data)
        if key in self.routes:
            return self.routes[key]
        raise ValueError(
            f"Condition returned unknown route '{key}'. "
            f"Available routes: {list(self.routes.keys())}"
        )

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for display; the condition is represented by its name."""
        if hasattr(self.condition, '__name__'):
            condition_repr = self.condition.__name__
        else:
            condition_repr = str(self.condition)
        return {
            "source": self.source,
            "condition": condition_repr,
            "routes": self.routes,
        }
70
+
71
+
72
@dataclass
class Graph:
    """A workflow graph of nodes connected by direct and conditional edges.

    The graph defines the structure of a workflow:
    - Nodes: processing units that transform state
    - Edges: unconditional connections between nodes
    - Conditional edges: branching logic driven by state

    Attributes:
        graph_id: Unique identifier for this graph.
        name: Human-readable name.
        nodes: Mapping of node name -> Node.
        edges: Mapping of source node -> target node (direct edges).
        conditional_edges: Mapping of source node -> ConditionalEdge.
        entry_point: Name of the first node to execute.
        max_iterations: Maximum loop iterations allowed during execution.
        description: Human-readable description.
        metadata: Arbitrary extra data attached to the graph.
    """

    graph_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    name: str = "Unnamed Workflow"
    nodes: Dict[str, Node] = field(default_factory=dict)
    edges: Dict[str, str] = field(default_factory=dict)  # source -> target for direct edges
    conditional_edges: Dict[str, ConditionalEdge] = field(default_factory=dict)
    entry_point: Optional[str] = None
    max_iterations: int = 100
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def add_node(
        self,
        name: str,
        handler: Optional[Callable] = None,
        node_type: NodeType = NodeType.STANDARD,
        description: str = ""
    ) -> "Graph":
        """Add a node to the graph.

        If no handler is provided, a node registered under the same name is
        used. The first node added (or any ENTRY-typed node) becomes the
        entry point.

        Args:
            name: Unique name for the node.
            handler: Function to execute (optional if registered).
            node_type: Type of node.
            description: Human-readable description.

        Returns:
            Self, for chaining.

        Raises:
            ValueError: if no handler can be resolved, or the name is taken.
        """
        if handler is None:
            # Fall back to the decorator registry.
            handler = get_registered_node(name)
            if handler is None:
                raise ValueError(
                    f"No handler provided for node '{name}' and no registered "
                    f"node found with that name"
                )

        if name in self.nodes:
            raise ValueError(f"Node '{name}' already exists in the graph")

        node = create_node_from_function(handler, name, node_type, description)
        self.nodes[name] = node

        # First node added, or an explicit ENTRY node, becomes the entry point.
        if self.entry_point is None or node_type == NodeType.ENTRY:
            self.entry_point = name

        return self

    def add_edge(self, source: str, target: str) -> "Graph":
        """Add a direct (unconditional) edge from source to target.

        Args:
            source: Source node name.
            target: Target node name, or END.

        Returns:
            Self, for chaining.

        Raises:
            ValueError: if a node is unknown, or source already has a
                conditional edge (a node may have only one kind of edge).
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")
        if target != END and target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found in graph")

        if source in self.conditional_edges:
            raise ValueError(
                f"Node '{source}' already has a conditional edge. "
                f"Cannot add a direct edge."
            )

        self.edges[source] = target
        return self

    def add_conditional_edge(
        self,
        source: str,
        condition: Callable[[Dict[str, Any]], str],
        routes: Dict[str, str]
    ) -> "Graph":
        """Add a conditional (branching) edge from a source node.

        The condition function receives state data and returns a route key,
        which is resolved via ``routes`` to the next node.

        Args:
            source: Source node name.
            condition: Function that maps state data to a route key.
            routes: Dict mapping route keys to target node names (or END).

        Returns:
            Self, for chaining.

        Raises:
            ValueError: if a node is unknown, or source already has a
                direct edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")

        # Validate every route target up front so bad graphs fail fast.
        for route_key, target in routes.items():
            if target != END and target not in self.nodes:
                raise ValueError(
                    f"Target node '{target}' for route '{route_key}' not found in graph"
                )

        if source in self.edges:
            raise ValueError(
                f"Node '{source}' already has a direct edge. "
                f"Cannot add a conditional edge."
            )

        self.conditional_edges[source] = ConditionalEdge(
            source=source,
            condition=condition,
            routes=routes
        )
        return self

    def set_entry_point(self, node_name: str) -> "Graph":
        """Set the entry point of the graph (node must already exist)."""
        if node_name not in self.nodes:
            raise ValueError(f"Node '{node_name}' not found in graph")
        self.entry_point = node_name
        return self

    def get_next_node(self, current_node: str, state_data: Dict[str, Any]) -> Optional[str]:
        """Resolve the next node to execute from edges and current state.

        Args:
            current_node: Current node name.
            state_data: Current state data (fed to conditional edges).

        Returns:
            Next node name, END, or None when no edge is defined
            (an implicit end of the workflow).
        """
        # Conditional edges take precedence; a node can only have one kind.
        if current_node in self.conditional_edges:
            return self.conditional_edges[current_node].evaluate(state_data)

        if current_node in self.edges:
            return self.edges[current_node]

        return None

    def validate(self) -> List[str]:
        """Validate the graph structure.

        Checks: at least one node, a valid entry point, and no orphan
        (unreachable) nodes. Nodes without outgoing edges are allowed —
        they act as implicit end nodes.

        Returns:
            List of validation errors (empty if valid).
        """
        errors = []

        if not self.nodes:
            errors.append("Graph must have at least one node")
            return errors

        if not self.entry_point:
            errors.append("Graph must have an entry point")
        elif self.entry_point not in self.nodes:
            errors.append(f"Entry point '{self.entry_point}' not found in nodes")

        # Orphan nodes can never execute; flag them.
        reachable = self._get_reachable_nodes()
        orphans = set(self.nodes.keys()) - reachable
        if orphans:
            errors.append(f"Orphan nodes (not reachable): {orphans}")

        return errors

    def _get_reachable_nodes(self) -> Set[str]:
        """Return all node names reachable from the entry point (DFS)."""
        if not self.entry_point:
            return set()

        reachable: Set[str] = set()
        to_visit = [self.entry_point]

        while to_visit:
            node = to_visit.pop()
            # END is a pseudo-node; never record or expand it.
            if node in reachable or node == END:
                continue

            reachable.add(node)

            if node in self.edges:
                to_visit.append(self.edges[node])

            if node in self.conditional_edges:
                to_visit.extend(self.conditional_edges[node].routes.values())

        return reachable

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the graph (structure only; handlers shown by name)."""
        return {
            "graph_id": self.graph_id,
            "name": self.name,
            "description": self.description,
            "nodes": {name: node.to_dict() for name, node in self.nodes.items()},
            "edges": self.edges,
            "conditional_edges": {
                name: edge.to_dict()
                for name, edge in self.conditional_edges.items()
            },
            "entry_point": self.entry_point,
            "max_iterations": self.max_iterations,
            "metadata": self.metadata,
        }

    def to_mermaid(self) -> str:
        """Generate a Mermaid flowchart ("graph TD") of this graph."""
        lines = ["graph TD"]

        # Declare each node, decorating entry/exit nodes.
        for name, node in self.nodes.items():
            label = name.replace("_", " ").title()
            if node.node_type == NodeType.ENTRY:
                lines.append(f'    {name}["{label} 🚀"]')
            elif node.node_type == NodeType.EXIT:
                lines.append(f'    {name}["{label} 🏁"]')
            else:
                lines.append(f'    {name}["{label}"]')

        # Declare the END pseudo-node only when some edge targets it.
        has_end = END in self.edges.values() or any(
            END in cond.routes.values()
            for cond in self.conditional_edges.values()
        )
        if has_end:
            lines.append(f'    {END}(("END"))')

        for source, target in self.edges.items():
            lines.append(f"    {source} --> {target}")

        # Conditional edges are labeled with their route key.
        for source, cond in self.conditional_edges.items():
            for route_key, target in cond.routes.items():
                lines.append(f"    {source} -->|{route_key}| {target}")

        return "\n".join(lines)

    def __repr__(self) -> str:
        return (
            f"Graph(name='{self.name}', nodes={list(self.nodes.keys())}, "
            f"entry='{self.entry_point}')"
        )
app/engine/node.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Node Definition for Workflow Engine.
3
+
4
+ Nodes are the building blocks of a workflow. Each node is a function
5
+ that receives state, performs some operation, and returns modified state.
6
+ """
7
+
8
+ from typing import Any, Callable, Dict, Optional, Union
9
+ from dataclasses import dataclass, field
10
+ from enum import Enum
11
+ import asyncio
12
+ import inspect
13
+ import functools
14
+
15
+
16
class NodeType(str, Enum):
    """Roles a node can play in the workflow."""

    # Regular processing node.
    STANDARD = "standard"
    # Branching decision node.
    CONDITIONAL = "conditional"
    # Entry point of the workflow.
    ENTRY = "entry"
    # Exit point of the workflow.
    EXIT = "exit"
22
+
23
+
24
@dataclass
class Node:
    """A node in the workflow graph.

    Each node has a name and a handler function. The handler receives the
    current state data (a dict) and returns modified state data, or None to
    leave the state unchanged.

    Attributes:
        name: Unique identifier for the node.
        handler: Function that processes state (sync or async).
        node_type: Type of node (standard, conditional, etc.).
        description: Human-readable description.
        metadata: Additional node metadata.
    """

    name: str
    handler: Callable[[Dict[str, Any]], Union[Dict[str, Any], Any]]
    node_type: NodeType = NodeType.STANDARD
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate the node after initialization."""
        if not self.name:
            raise ValueError("Node name cannot be empty")
        if not callable(self.handler):
            raise ValueError(f"Handler for node '{self.name}' must be callable")

    @property
    def is_async(self) -> bool:
        """True if the handler is a coroutine function."""
        return asyncio.iscoroutinefunction(self.handler)

    async def execute(self, state_data: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the node handler with the given state data.

        Handles both sync and async handlers transparently: async handlers
        are awaited directly, sync handlers run in the default thread-pool
        executor so they do not block the event loop.

        Args:
            state_data: The current state data dictionary.

        Returns:
            Modified state data dictionary (the original dict when the
            handler returns None).

        Raises:
            RuntimeError: wrapping any handler error with the node name.
        """
        try:
            if self.is_async:
                result = await self.handler(state_data)
            else:
                # execute() is always awaited inside a running loop, so use
                # get_running_loop(); get_event_loop() is deprecated for this
                # use since Python 3.10 and can misbehave outside the main
                # thread.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    functools.partial(self.handler, state_data)
                )

            # A None return means "no changes": pass the input state through.
            if result is None:
                return state_data

            if isinstance(result, dict):
                return result

            raise ValueError(
                f"Node '{self.name}' handler must return a dict or None, "
                f"got {type(result).__name__}"
            )

        except Exception as e:
            # Re-raise with node context so executor logs are actionable.
            raise RuntimeError(f"Error in node '{self.name}': {str(e)}") from e

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the node to a dictionary (handler shown by name)."""
        return {
            "name": self.name,
            "type": self.node_type.value,
            "description": self.description,
            "handler": self.handler.__name__ if hasattr(self.handler, '__name__') else str(self.handler),
            "metadata": self.metadata,
        }
108
+
109
+
110
+ # Registry to hold decorated node functions
111
+ _node_registry: Dict[str, Callable] = {}
112
+
113
+
114
+ def node(
115
+ name: Optional[str] = None,
116
+ node_type: NodeType = NodeType.STANDARD,
117
+ description: str = ""
118
+ ) -> Callable:
119
+ """
120
+ Decorator to register a function as a workflow node.
121
+
122
+ Usage:
123
+ @node(name="extract_functions", description="Extract functions from code")
124
+ def extract_functions(state: dict) -> dict:
125
+ # ... process state
126
+ return state
127
+
128
+ Args:
129
+ name: Node name (defaults to function name)
130
+ node_type: Type of node
131
+ description: Human-readable description
132
+
133
+ Returns:
134
+ Decorated function
135
+ """
136
+ def decorator(func: Callable) -> Callable:
137
+ node_name = name or func.__name__
138
+
139
+ # Store metadata on the function
140
+ func._node_metadata = {
141
+ "name": node_name,
142
+ "type": node_type,
143
+ "description": description or func.__doc__ or "",
144
+ }
145
+
146
+ # Register in global registry
147
+ _node_registry[node_name] = func
148
+
149
+ @functools.wraps(func)
150
+ def wrapper(*args, **kwargs):
151
+ return func(*args, **kwargs)
152
+
153
+ wrapper._node_metadata = func._node_metadata
154
+ return wrapper
155
+
156
+ return decorator
157
+
158
+
159
def get_registered_node(name: str) -> Optional[Callable]:
    """Return the node function registered under *name*, or None if unknown."""
    try:
        return _node_registry[name]
    except KeyError:
        return None
162
+
163
+
164
def list_registered_nodes() -> Dict[str, Dict[str, Any]]:
    """Return {node_name: metadata} for every registered node."""
    listing: Dict[str, Dict[str, Any]] = {}
    for node_name, fn in _node_registry.items():
        # Entries registered outside @node may lack metadata; skip those.
        if hasattr(fn, '_node_metadata'):
            listing[node_name] = fn._node_metadata
    return listing
171
+
172
+
173
def create_node_from_function(
    func: Callable,
    name: Optional[str] = None,
    node_type: NodeType = NodeType.STANDARD,
    description: str = ""
) -> Node:
    """Wrap a plain function in a Node instance.

    Args:
        func: The handler function.
        name: Node name (defaults to the function's __name__).
        node_type: Type of node.
        description: Description (defaults to the function's docstring).

    Returns:
        A Node instance.
    """
    node_name = name if name else func.__name__
    node_description = description if description else (func.__doc__ or "")
    return Node(
        name=node_name,
        handler=func,
        node_type=node_type,
        description=node_description,
    )
app/engine/state.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ State Management for Workflow Engine.
3
+
4
+ This module provides the state management system that flows through the workflow.
5
+ State is immutable - each node receives state and returns a new modified state.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+ from pydantic import BaseModel, Field
10
+ from datetime import datetime
11
+ from copy import deepcopy
12
+ import uuid
13
+
14
+
15
class StateSnapshot(BaseModel):
    """A snapshot of state at a specific point in execution."""

    # When the snapshot was taken.
    timestamp: datetime = Field(default_factory=datetime.now)
    # Name of the node that produced this state.
    node_name: str
    # The workflow data at this point (StateManager stores a deep copy here).
    state_data: Dict[str, Any]
    # Loop iteration counter at snapshot time.
    iteration: int = 0
22
+
23
+
24
class WorkflowState(BaseModel):
    """
    The shared state that flows through the workflow.

    A flexible container holding all data processed by workflow nodes.
    Mutating helpers follow an immutable pattern: they return a new
    WorkflowState and leave the receiver untouched.

    Attributes:
        data: The actual workflow data (flexible dictionary)
        current_node: Name of the node currently/last executed
        iteration: Loop iteration counter
        visited_nodes: Ordered list of node names visited so far
        started_at: When execution began
        completed_at: When execution finished
    """

    # The actual data being processed
    data: Dict[str, Any] = Field(default_factory=dict)

    # Execution metadata
    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Pydantic v2 configuration (the file already uses v2 model_copy);
    # replaces the deprecated v1-style inner `class Config`.
    model_config = {"arbitrary_types_allowed": True}

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the state data."""
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """Return a new state with data[key] = value (immutable pattern)."""
        new_data = deepcopy(self.data)
        new_data[key] = value
        return self.model_copy(update={"data": new_data})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """Return a new state with multiple data keys updated."""
        new_data = deepcopy(self.data)
        new_data.update(updates)
        return self.model_copy(update={"data": new_data})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """Return a new state with node_name appended to the visit history."""
        new_visited = self.visited_nodes + [node_name]
        return self.model_copy(update={
            "visited_nodes": new_visited,
            "current_node": node_name
        })

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state with the iteration counter incremented."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a plain dictionary (datetimes as ISO strings)."""
        return {
            "data": self.data,
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": self.visited_nodes,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """Create a WorkflowState from a dictionary.

        A dict with a "data" key is treated as a full serialized state;
        anything else is wrapped as raw workflow data.
        """
        if "data" in data:
            return cls(**data)
        return cls(data=data)
95
+
96
+
97
class StateManager:
    """
    Tracks the current state plus a snapshot history for one workflow run.

    The snapshot history provides debugging capabilities by recording every
    state change made during execution.
    """

    def __init__(self, run_id: Optional[str] = None):
        # Generate a run id when the caller does not supply one.
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List[StateSnapshot] = []
        self._current_state: Optional[WorkflowState] = None

    @property
    def current_state(self) -> Optional[WorkflowState]:
        """The most recent WorkflowState, or None before initialize()."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> WorkflowState:
        """Create and store the initial WorkflowState, stamping started_at."""
        state = WorkflowState(
            data=initial_data,
            started_at=datetime.now()
        )
        self._current_state = state
        return state

    def update(self, new_state: WorkflowState, node_name: str) -> None:
        """Record a snapshot of *new_state* and make it the current state."""
        # Deep-copy the data so later mutations cannot rewrite history.
        snapshot = StateSnapshot(
            node_name=node_name,
            state_data=deepcopy(new_state.data),
            iteration=new_state.iteration
        )
        self.history.append(snapshot)
        self._current_state = new_state

    def finalize(self) -> WorkflowState:
        """Stamp completed_at on the current state and return it."""
        if self._current_state:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """Return the snapshot history as JSON-friendly dictionaries."""
        entries: List[Dict[str, Any]] = []
        for snap in self.history:
            entries.append({
                "timestamp": snap.timestamp.isoformat(),
                "node": snap.node_name,
                "iteration": snap.iteration,
                "state": snap.state_data,
            })
        return entries
app/main.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FlowGraph - FastAPI Application Entry Point.
3
+
4
+ A lightweight, async-first workflow orchestration engine for building agent pipelines.
5
+ """
6
+
7
+ from contextlib import asynccontextmanager
8
+ from fastapi import FastAPI
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import JSONResponse
11
+ import logging
12
+
13
+ from app.config import settings
14
+ from app.api.routes import graph, tools, websocket
15
+ from app.workflows.code_review import register_code_review_workflow
16
+
17
+ # Import builtin tools to register them
18
+ import app.tools.builtin # noqa: F401
19
+
20
+
21
+ # Configure logging
22
+ logging.basicConfig(
23
+ level=getattr(logging, settings.LOG_LEVEL),
24
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
25
+ )
26
+ logger = logging.getLogger(__name__)
27
+
28
+
29
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: startup work before yield, shutdown after."""
    # Startup
    logger.info("Starting %s v%s", settings.APP_NAME, settings.APP_VERSION)

    # Pre-register the demo workflow so it is available before serving.
    await register_code_review_workflow()

    yield

    # Shutdown
    logger.info("Shutting down...")
+
43
+
44
# Create FastAPI application.
# The description below is rendered as Markdown on the /docs landing page.
app = FastAPI(
    title=settings.APP_NAME,
    description="""
    ## Workflow Engine API

    A minimal but powerful workflow/graph engine for building agent workflows.

    ### Features
    - **Nodes**: Python functions that read and modify shared state
    - **Edges**: Define execution flow between nodes
    - **Branching**: Conditional routing based on state values
    - **Looping**: Support for iterative workflows
    - **Real-time Updates**: WebSocket support for live execution streaming

    ### Quick Start
    1. List available tools: `GET /tools`
    2. Create a graph: `POST /graph/create`
    3. Run the graph: `POST /graph/run`
    4. Check execution state: `GET /graph/state/{run_id}`

    ### Demo Workflow
    A pre-registered Code Review workflow is available with ID: `code-review-demo`
    """,
    version=settings.APP_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)


# Add CORS middleware.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by the CORS spec (browsers will not send credentials to a wildcard
# origin) and is overly permissive for production — confirm intended origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Include routers for the graph, tools, and websocket endpoint groups.
app.include_router(graph.router)
app.include_router(tools.router)
app.include_router(websocket.router)
89
+
90
+
91
+ # ============================================================
92
+ # Root Endpoints
93
+ # ============================================================
94
+
95
@app.get("/", tags=["Root"])
async def root():
    """API root - returns basic info and links."""
    endpoints = {
        "graphs": "/graph",
        "tools": "/tools",
        "websocket_run": "/ws/run/{graph_id}",
        "websocket_subscribe": "/ws/subscribe/{run_id}",
    }
    return {
        "name": settings.APP_NAME,
        "version": settings.APP_VERSION,
        "description": "A minimal workflow/graph engine for agent workflows",
        "docs": "/docs",
        "redoc": "/redoc",
        "endpoints": endpoints,
        "demo_workflow": "code-review-demo",
    }
112
+
113
+
114
@app.get("/health", tags=["Root"])
async def health():
    """Health check endpoint reporting storage counts."""
    # Function-level import, matching the original — presumably to avoid
    # import-order issues at module load time; confirm before hoisting.
    from app.storage.memory import graph_storage, run_storage

    payload = {
        "status": "healthy",
        "version": settings.APP_VERSION,
        "graphs_count": len(graph_storage),
        "runs_count": len(run_storage),
    }
    return payload
+
126
+
127
+ # ============================================================
128
+ # Error Handlers
129
+ # ============================================================
130
+
131
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Catch-all handler: log the error and return a generic 500 response."""
    logger.exception(f"Unhandled error: {exc}")
    # Only leak exception details when running in debug mode.
    detail = str(exc) if settings.DEBUG else "An unexpected error occurred"
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "detail": detail,
        },
    )
app/storage/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Storage package - In-memory storage for graphs and runs.
3
+ """
4
+
5
+ from app.storage.memory import (
6
+ GraphStorage,
7
+ RunStorage,
8
+ graph_storage,
9
+ run_storage,
10
+ )
11
+
12
+ __all__ = [
13
+ "GraphStorage",
14
+ "RunStorage",
15
+ "graph_storage",
16
+ "run_storage",
17
+ ]
app/storage/memory.py ADDED
@@ -0,0 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ In-Memory Storage for Workflow Engine.
3
+
4
+ Provides thread-safe storage for graphs and execution runs.
5
+ Can be easily replaced with a database implementation.
6
+ """
7
+
8
+ from typing import Any, Dict, List, Optional
9
+ from datetime import datetime
10
+ import asyncio
11
+ from dataclasses import dataclass, field
12
+
13
+
14
@dataclass
class StoredGraph:
    """A persisted workflow graph definition."""

    graph_id: str
    name: str
    definition: Dict[str, Any]
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses; timestamps become ISO-8601 strings."""
        payload = {
            "graph_id": self.graph_id,
            "name": self.name,
            "definition": self.definition,
        }
        payload["created_at"] = self.created_at.isoformat()
        payload["updated_at"] = self.updated_at.isoformat()
        return payload
31
+
32
+
33
@dataclass
class StoredRun:
    """A persisted record of one workflow execution run."""

    run_id: str
    graph_id: str
    status: str
    initial_state: Dict[str, Any]
    current_state: Dict[str, Any] = field(default_factory=dict)
    final_state: Optional[Dict[str, Any]] = None
    execution_log: List[Dict[str, Any]] = field(default_factory=list)
    current_node: Optional[str] = None
    iteration: int = 0
    started_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize for API responses; datetimes become ISO strings or None."""
        completed = self.completed_at.isoformat() if self.completed_at else None
        return {
            "run_id": self.run_id,
            "graph_id": self.graph_id,
            "status": self.status,
            "initial_state": self.initial_state,
            "current_state": self.current_state,
            "final_state": self.final_state,
            "execution_log": self.execution_log,
            "current_node": self.current_node,
            "iteration": self.iteration,
            "started_at": self.started_at.isoformat(),
            "completed_at": completed,
            "error": self.error,
        }
64
+
65
+
66
class GraphStorage:
    """
    Thread-safe in-memory storage for workflow graphs.

    Guards a graph_id -> StoredGraph mapping with an asyncio.Lock so that
    create, retrieve, update, and delete operations are safe under
    concurrent coroutines.
    """

    def __init__(self):
        self._graphs: Dict[str, StoredGraph] = {}
        self._lock = asyncio.Lock()

    async def save(self, graph_id: str, name: str, definition: Dict[str, Any]) -> StoredGraph:
        """Store (or overwrite) a graph definition.

        Args:
            graph_id: Unique graph identifier.
            name: Graph name.
            definition: Graph definition dict.

        Returns:
            The stored graph record.
        """
        async with self._lock:
            record = StoredGraph(
                graph_id=graph_id,
                name=name,
                definition=definition,
            )
            self._graphs[graph_id] = record
            return record

    async def get(self, graph_id: str) -> Optional[StoredGraph]:
        """Fetch a graph by ID, or None if unknown."""
        async with self._lock:
            return self._graphs.get(graph_id)

    async def update(self, graph_id: str, definition: Dict[str, Any]) -> Optional[StoredGraph]:
        """Replace a graph's definition and bump updated_at; None if unknown."""
        async with self._lock:
            record = self._graphs.get(graph_id)
            if record is None:
                return None
            record.definition = definition
            record.updated_at = datetime.now()
            return record

    async def delete(self, graph_id: str) -> bool:
        """Remove a graph; True if it existed."""
        async with self._lock:
            return self._graphs.pop(graph_id, None) is not None

    async def list_all(self) -> List[StoredGraph]:
        """Return all stored graph records."""
        async with self._lock:
            return list(self._graphs.values())

    async def exists(self, graph_id: str) -> bool:
        """True if a graph with this ID is stored."""
        async with self._lock:
            return graph_id in self._graphs

    def __len__(self) -> int:
        # Unlocked read of the dict size (len() is atomic in CPython).
        return len(self._graphs)
+
135
+
136
class RunStorage:
    """
    Thread-safe in-memory store of workflow execution runs.

    Serializes all access on one asyncio lock so concurrent coroutines
    can create, update, and query runs (both ongoing and finished) safely.
    """

    def __init__(self):
        # run_id -> StoredRun
        self._runs: Dict[str, StoredRun] = {}
        self._lock = asyncio.Lock()

    async def create(
        self,
        run_id: str,
        graph_id: str,
        initial_state: Dict[str, Any]
    ) -> StoredRun:
        """
        Create and store a new run in the "pending" status.

        Args:
            run_id: Unique run identifier
            graph_id: Associated graph ID
            initial_state: Initial state data (copied into current_state)

        Returns:
            The stored run
        """
        async with self._lock:
            run = StoredRun(
                run_id=run_id,
                graph_id=graph_id,
                status="pending",
                initial_state=initial_state,
                current_state=initial_state.copy(),
            )
            self._runs[run_id] = run
            return run

    async def get(self, run_id: str) -> Optional[StoredRun]:
        """Fetch a run by ID; None when unknown."""
        async with self._lock:
            return self._runs.get(run_id)

    async def update_state(
        self,
        run_id: str,
        current_state: Dict[str, Any],
        current_node: Optional[str] = None,
        iteration: Optional[int] = None
    ) -> Optional[StoredRun]:
        """Record execution progress; always marks the run as "running"."""
        async with self._lock:
            run = self._runs.get(run_id)
            if run is None:
                return None
            run.current_state = current_state
            run.status = "running"
            if current_node is not None:
                run.current_node = current_node
            if iteration is not None:
                run.iteration = iteration
            return run

    async def add_log_entry(
        self,
        run_id: str,
        entry: Dict[str, Any]
    ) -> Optional[StoredRun]:
        """Append *entry* to the run's execution log; None when the run is unknown."""
        async with self._lock:
            run = self._runs.get(run_id)
            if run is None:
                return None
            run.execution_log.append(entry)
            return run

    async def complete(
        self,
        run_id: str,
        final_state: Dict[str, Any],
        execution_log: List[Dict[str, Any]]
    ) -> Optional[StoredRun]:
        """Finalize a run as "completed", recording results and end time."""
        async with self._lock:
            run = self._runs.get(run_id)
            if run is None:
                return None
            run.status = "completed"
            run.final_state = final_state
            run.execution_log = execution_log
            run.completed_at = datetime.now()
            return run

    async def fail(
        self,
        run_id: str,
        error: str,
        final_state: Optional[Dict[str, Any]] = None
    ) -> Optional[StoredRun]:
        """Finalize a run as "failed", recording the error message and end time."""
        async with self._lock:
            run = self._runs.get(run_id)
            if run is None:
                return None
            run.status = "failed"
            run.error = error
            run.final_state = final_state
            run.completed_at = datetime.now()
            return run

    async def list_all(self) -> List[StoredRun]:
        """Return a snapshot list of every stored run."""
        async with self._lock:
            return list(self._runs.values())

    async def list_by_graph(self, graph_id: str) -> List[StoredRun]:
        """Return all runs whose graph_id matches."""
        async with self._lock:
            return [run for run in self._runs.values() if run.graph_id == graph_id]

    async def delete(self, run_id: str) -> bool:
        """Remove a run; True when something was actually deleted."""
        async with self._lock:
            return self._runs.pop(run_id, None) is not None

    def __len__(self) -> int:
        # Unlocked read: len() of a dict is a single atomic operation in CPython.
        return len(self._runs)
267
+
268
+
269
# Module-level singleton stores shared by the rest of the application.
graph_storage = GraphStorage()
run_storage = RunStorage()
app/tools/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tools package - Tool registry and built-in tools.
3
+ """
4
+
5
+ from app.tools.registry import ToolRegistry, tool_registry, register_tool, get_tool
6
+
7
+ __all__ = [
8
+ "ToolRegistry",
9
+ "tool_registry",
10
+ "register_tool",
11
+ "get_tool",
12
+ ]
app/tools/builtin.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Built-in Tools for the Code Review Workflow.
3
+
4
+ These tools implement the functionality needed for the sample
5
+ Code Review workflow demonstration.
6
+ """
7
+
8
+ import re
9
+ import ast
10
+ from typing import Any, Dict, List, Optional
11
+ from app.tools.registry import register_tool
12
+
13
+
14
@register_tool(
    name="extract_functions",
    description="Extract function definitions from Python code"
)
def extract_functions(code: str) -> Dict[str, Any]:
    """
    Extract function names and basic info from Python code.

    Both ``def`` and ``async def`` definitions are reported.

    Args:
        code: Python source code string

    Returns:
        Dict with a 'functions' list (name, lineno, args, has_docstring,
        decorators, line_count), 'function_count', and 'parse_success'.
        On a syntax error, returns an empty list plus an 'error' message
        and parse_success=False.
    """
    functions: List[Dict[str, Any]] = []

    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return {
            "functions": [],
            "error": f"Syntax error in code: {e}",
            "parse_success": False,
        }

    for node in ast.walk(tree):
        # FIX: previously only ast.FunctionDef was matched, which silently
        # skipped `async def` functions; include ast.AsyncFunctionDef too.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            func_info = {
                "name": node.name,
                "lineno": node.lineno,
                "args": [arg.arg for arg in node.args.args],
                "has_docstring": ast.get_docstring(node) is not None,
                "decorators": [
                    # ast.unparse exists from Python 3.9; fall back gracefully.
                    ast.unparse(d) if hasattr(ast, 'unparse') else str(d)
                    for d in node.decorator_list
                ],
                "line_count": (
                    node.end_lineno - node.lineno + 1
                    if getattr(node, 'end_lineno', None)
                    else 0
                ),
            }
            functions.append(func_info)

    return {
        "functions": functions,
        "function_count": len(functions),
        "parse_success": True,
    }
66
+
67
+
68
@register_tool(
    name="calculate_complexity",
    description="Calculate complexity metrics for code"
)
def calculate_complexity(code: str, functions: Optional[List[Dict]] = None) -> Dict[str, Any]:
    """
    Calculate simple complexity metrics for Python code.

    Metrics:
        - Lines of code (LOC, excluding blanks and comment-only lines)
        - Cyclomatic complexity (simplified keyword count)
        - Max nesting depth (assumes 4-space indentation)
        - Function count

    Args:
        code: Python source code
        functions: Optional pre-extracted function list

    Returns:
        Dict with complexity metrics and a 1-10 'complexity_score'
        (higher is better).
    """
    lines = code.split('\n')
    loc = len([l for l in lines if l.strip() and not l.strip().startswith('#')])

    # Simplified cyclomatic complexity: count decision points.
    # BUG FIX: re.match() only tests the START of the line, so keywords like
    # 'and'/'or' (or an inline 'if') appearing mid-line were never counted.
    # Use search() over precompiled word-boundary patterns; still at most one
    # count per keyword per line, as before.
    complexity_keywords = ['if', 'elif', 'for', 'while', 'and', 'or', 'except', 'with']
    keyword_patterns = [re.compile(rf'\b{kw}\b') for kw in complexity_keywords]
    complexity = 1  # Base complexity

    for line in lines:
        stripped = line.strip()
        for pattern in keyword_patterns:
            if pattern.search(stripped):
                complexity += 1

    # Max nesting depth, assuming 4-space indentation.
    max_depth = 0
    for line in lines:
        if line.strip():
            indent = len(line) - len(line.lstrip())
            max_depth = max(max_depth, indent // 4)

    # Function count: prefer the pre-parsed list, fall back to a text scan.
    func_count = len(functions) if functions else code.count('def ')

    # Simple 1-10 score; each structural smell deducts points.
    score = 10
    if complexity > 10:
        score -= 2
    if complexity > 20:
        score -= 2
    if max_depth > 4:
        score -= 1
    if max_depth > 6:
        score -= 1
    if loc > 200:
        score -= 1
    if func_count > 10:
        score -= 1
    if functions:
        long_funcs = [f for f in functions if f.get('line_count', 0) > 50]
        score -= len(long_funcs)

    score = max(1, score)  # Minimum score of 1

    return {
        "lines_of_code": loc,
        "cyclomatic_complexity": complexity,
        "max_nesting_depth": max_depth,
        "function_count": func_count,
        "complexity_score": score,
    }
144
+
145
+
146
@register_tool(
    name="detect_issues",
    description="Detect code quality issues and smells"
)
def detect_issues(
    code: str,
    functions: Optional[List[Dict]] = None,
    complexity_score: Optional[int] = None
) -> Dict[str, Any]:
    """
    Detect common code quality issues.

    Checks for:
        - Missing docstrings
        - Long functions (> 50 lines)
        - TODO/FIXME/XXX comments
        - Print statements (in production code)
        - Magic numbers

    Args:
        code: Python source code
        functions: Optional pre-extracted functions
        complexity_score: Optional pre-calculated complexity score (1-10)

    Returns:
        Dict with 'issues' list, 'issue_count', 'quality_score' (1-10),
        and a per-severity breakdown.
    """
    issues: List[Dict[str, Any]] = []
    lines = code.split('\n')

    # Missing docstrings
    if functions:
        for func in functions:
            if not func.get('has_docstring'):
                issues.append({
                    "type": "missing_docstring",
                    "severity": "warning",
                    "message": f"Function '{func['name']}' lacks a docstring",
                    "line": func.get('lineno'),
                })

    # Long functions
    if functions:
        for func in functions:
            line_count = func.get('line_count', 0)
            if line_count > 50:
                issues.append({
                    "type": "long_function",
                    "severity": "warning",
                    "message": f"Function '{func['name']}' is too long ({line_count} lines)",
                    "line": func.get('lineno'),
                })

    # TODO/FIXME markers
    for i, line in enumerate(lines, 1):
        if 'TODO' in line or 'FIXME' in line or 'XXX' in line:
            issues.append({
                "type": "todo_comment",
                "severity": "info",
                "message": "Found TODO/FIXME comment",
                "line": i,
            })

    # Print statements.
    # BUG FIX: the old substring test ('print(' in line) also flagged calls
    # like pprint( or sprint(; require a word boundary before 'print'.
    print_call = re.compile(r'\bprint\(')
    for i, line in enumerate(lines, 1):
        if print_call.search(line.strip()):
            issues.append({
                "type": "print_statement",
                "severity": "info",
                "message": "Print statement found (consider using logging)",
                "line": i,
            })

    # Magic numbers (two or more digits, not obviously inside a string)
    magic_number_pattern = r'\b(?<![\'".])\d{2,}\b(?![\'"])'
    for i, line in enumerate(lines, 1):
        stripped = line.strip()
        if not stripped.startswith('#'):  # skip comment-only lines
            matches = re.findall(magic_number_pattern, line)
            for match in matches:
                if int(match) not in (0, 1, 2, 100):  # common acceptable values
                    issues.append({
                        "type": "magic_number",
                        "severity": "info",
                        "message": f"Magic number {match} found (consider using a constant)",
                        "line": i,
                    })
                    break  # One per line is enough

    # Quality score: start at 10, deduct per issue by severity.
    quality_score = 10
    for issue in issues:
        if issue['severity'] == 'error':
            quality_score -= 2
        elif issue['severity'] == 'warning':
            quality_score -= 1
        else:
            quality_score -= 0.5

    # Blend in the complexity score when supplied.
    # FIX: explicit None-check instead of truthiness, so a score of 0 counts.
    if complexity_score is not None:
        quality_score = (quality_score + complexity_score) / 2

    quality_score = max(1, min(10, quality_score))

    return {
        "issues": issues,
        "issue_count": len(issues),
        "quality_score": round(quality_score, 1),
        "issues_by_severity": {
            "error": len([i for i in issues if i['severity'] == 'error']),
            "warning": len([i for i in issues if i['severity'] == 'warning']),
            "info": len([i for i in issues if i['severity'] == 'info']),
        }
    }
265
+
266
+
267
@register_tool(
    name="suggest_improvements",
    description="Generate improvement suggestions based on detected issues"
)
def suggest_improvements(
    issues: List[Dict],
    functions: Optional[List[Dict]] = None,
    quality_score: Optional[float] = None
) -> Dict[str, Any]:
    """
    Generate actionable improvement suggestions from detected issues.

    Args:
        issues: List of detected issues (as produced by detect_issues)
        functions: Optional function info (currently unused directly)
        quality_score: Current quality score (1-10)

    Returns:
        Dict with prioritized 'suggestions', counts, the current and
        potential quality scores, and the set of suggestion categories.
    """
    suggestions: List[Dict[str, Any]] = []

    # Group issues by type so each type yields at most one suggestion.
    issue_types: Dict[str, List[Dict]] = {}
    for issue in issues:
        issue_type = issue.get('type', 'unknown')
        if issue_type not in issue_types:
            issue_types[issue_type] = []
        issue_types[issue_type].append(issue)

    if 'missing_docstring' in issue_types:
        count = len(issue_types['missing_docstring'])
        suggestions.append({
            "priority": "high",
            "category": "documentation",
            "suggestion": f"Add docstrings to {count} function(s)",
            "details": "Good docstrings improve code maintainability and enable automatic documentation generation.",
            # Function names are recovered from the quoted name in each message.
            "affected_functions": [i.get('message', '').split("'")[1] for i in issue_types['missing_docstring'] if "'" in i.get('message', '')],
        })

    if 'long_function' in issue_types:
        count = len(issue_types['long_function'])
        suggestions.append({
            "priority": "high",
            "category": "refactoring",
            "suggestion": f"Refactor {count} long function(s) into smaller units",
            "details": "Functions over 50 lines are harder to understand and test. Consider extracting helper functions.",
        })

    if 'print_statement' in issue_types:
        count = len(issue_types['print_statement'])
        suggestions.append({
            "priority": "medium",
            "category": "logging",
            "suggestion": f"Replace {count} print statement(s) with proper logging",
            "details": "Use the logging module for better control over log levels and output.",
        })

    if 'magic_number' in issue_types:
        count = len(issue_types['magic_number'])
        suggestions.append({
            "priority": "medium",
            "category": "readability",
            "suggestion": f"Extract {count} magic number(s) into named constants",
            "details": "Named constants improve readability and make the code easier to modify.",
        })

    if 'todo_comment' in issue_types:
        count = len(issue_types['todo_comment'])
        suggestions.append({
            "priority": "low",
            "category": "maintenance",
            "suggestion": f"Address {count} TODO/FIXME comment(s)",
            "details": "Consider creating issues or tasks to track these items.",
        })

    # General advice for low-quality code.
    # FIX: explicit None-check; the old truthiness test skipped a score of 0.
    if quality_score is not None and quality_score < 5:
        suggestions.append({
            "priority": "high",
            "category": "general",
            "suggestion": "Consider a comprehensive code review",
            "details": "The overall quality score is low. A thorough review may reveal structural improvements.",
        })

    # Sort by priority (high, medium, low).
    priority_order = {"high": 0, "medium": 1, "low": 2}
    suggestions.sort(key=lambda x: priority_order.get(x['priority'], 3))

    # Estimate the score after improvements: +0.5 per suggestion, capped at 10.
    # FIX: explicit None-check; `quality_score or 5` treated 0 as missing.
    baseline = 5.0 if quality_score is None else quality_score
    potential_improvement = len(suggestions) * 0.5
    new_quality_score = min(10, baseline + potential_improvement)

    return {
        "suggestions": suggestions,
        "suggestion_count": len(suggestions),
        "current_quality_score": quality_score,
        "potential_quality_score": round(new_quality_score, 1),
        "categories": list(set(s['category'] for s in suggestions)),
    }
368
+
369
+
370
@register_tool(
    name="quality_check",
    description="Check if quality meets the threshold"
)
def quality_check(quality_score: float, quality_threshold: float = 7.0) -> str:
    """
    Routing helper: decide whether the quality score clears the threshold.

    Args:
        quality_score: Current quality score (1-10)
        quality_threshold: Minimum acceptable score

    Returns:
        "pass" when quality_score >= quality_threshold, else "fail"
    """
    return "pass" if quality_score >= quality_threshold else "fail"
app/tools/registry.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool Registry for Workflow Engine.
3
+
4
+ The tool registry maintains a collection of callable tools that
5
+ workflow nodes can use. Tools are simple Python functions that
6
+ perform specific operations.
7
+ """
8
+
9
+ from typing import Any, Callable, Dict, List, Optional
10
+ from dataclasses import dataclass, field
11
+ import functools
12
+ import inspect
13
+ import logging
14
+
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
@dataclass
class Tool:
    """
    A callable registered in the tool registry.

    Attributes:
        name: Unique identifier for the tool
        func: The wrapped callable
        description: Human-readable description
        parameters: Mapping of parameter name to its description/type string
    """
    name: str
    func: Callable
    description: str = ""
    parameters: Dict[str, str] = field(default_factory=dict)

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the underlying callable, forwarding all arguments."""
        return self.func(*args, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Return JSON-serializable metadata (the callable itself is omitted)."""
        return {
            "name": self.name,
            "description": self.description,
            "parameters": self.parameters,
        }
46
+
47
+
48
class ToolRegistry:
    """
    Registry of workflow tools.

    Tools are plain Python functions that nodes call to perform specific
    operations. The registry supports dynamic registration (decorator or
    direct) and lookup by name.

    Usage:
        registry = ToolRegistry()

        @registry.register("my_tool")
        def my_tool(data: str) -> dict:
            return {"result": data.upper()}

        # Later
        tool = registry.get("my_tool")
        result = tool("hello")
    """

    def __init__(self):
        self._tools: Dict[str, Tool] = {}

    def register(
        self,
        name: Optional[str] = None,
        description: str = "",
        parameters: Optional[Dict[str, str]] = None
    ) -> Callable:
        """
        Decorator that registers a function as a tool.

        Args:
            name: Tool name (defaults to the function's __name__)
            description: Tool description (defaults to the docstring)
            parameters: Parameter descriptions (defaults to the signature)

        Returns:
            The decorator function
        """
        def decorator(func: Callable) -> Callable:
            resolved_name = name or func.__name__
            resolved_desc = (description or func.__doc__ or "").strip()

            # Fall back to the function signature for parameter metadata.
            params = parameters or {}
            if not params:
                signature = inspect.signature(func)
                params = {
                    p_name: (
                        str(p.annotation)
                        if p.annotation is not inspect.Parameter.empty
                        else "Any"
                    )
                    for p_name, p in signature.parameters.items()
                    if p_name not in ("self", "cls")
                }

            self._tools[resolved_name] = Tool(
                name=resolved_name,
                func=func,
                description=resolved_desc,
                parameters=params,
            )
            logger.debug(f"Registered tool: {resolved_name}")

            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)

            return wrapper

        return decorator

    def add(
        self,
        func: Callable,
        name: Optional[str] = None,
        description: str = "",
        parameters: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Register *func* directly (non-decorator form).

        Args:
            func: The function to register
            name: Tool name (defaults to the function's __name__)
            description: Tool description (defaults to the docstring)
            parameters: Parameter descriptions
        """
        resolved_name = name or func.__name__
        resolved_desc = (description or func.__doc__ or "").strip()

        self._tools[resolved_name] = Tool(
            name=resolved_name,
            func=func,
            description=resolved_desc,
            parameters=parameters or {},
        )
        logger.debug(f"Added tool: {resolved_name}")

    def get(self, name: str) -> Optional[Tool]:
        """Look up a tool by name; None when unregistered."""
        return self._tools.get(name)

    def call(self, name: str, *args, **kwargs) -> Any:
        """
        Look up and invoke a tool in one step.

        Args:
            name: Tool name
            *args: Positional arguments forwarded to the tool
            **kwargs: Keyword arguments forwarded to the tool

        Returns:
            The tool's result

        Raises:
            KeyError: If no tool with this name is registered
        """
        tool = self.get(name)
        if not tool:
            raise KeyError(f"Tool '{name}' not found in registry")
        return tool(*args, **kwargs)

    def remove(self, name: str) -> bool:
        """Unregister a tool; True when something was removed."""
        if name in self._tools:
            del self._tools[name]
            return True
        return False

    def list_tools(self) -> List[Dict[str, Any]]:
        """Return metadata dicts for every registered tool."""
        return [tool.to_dict() for tool in self._tools.values()]

    def has(self, name: str) -> bool:
        """True when a tool with this name is registered."""
        return name in self._tools

    def __contains__(self, name: str) -> bool:
        return self.has(name)

    def __len__(self) -> int:
        return len(self._tools)

    def __iter__(self):
        return iter(self._tools.values())
+
196
# Global tool registry instance shared by the module-level convenience helpers.
tool_registry = ToolRegistry()
198
+
199
+
200
def register_tool(
    name: Optional[str] = None,
    description: str = "",
    parameters: Optional[Dict[str, str]] = None
) -> Callable:
    """
    Register a tool in the global registry (convenience decorator).

    Thin wrapper around ``tool_registry.register``.

    Usage:
        @register_tool("my_tool", description="Does something cool")
        def my_tool(data: str) -> dict:
            return {"result": data}
    """
    return tool_registry.register(name=name, description=description, parameters=parameters)
214
+
215
+
216
def get_tool(name: str) -> Optional[Tool]:
    """Look up *name* in the global registry; None when unregistered."""
    return tool_registry.get(name)
app/workflows/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Workflows package - Sample workflow implementations.
3
+ """
4
+
5
+ from app.workflows.code_review import create_code_review_workflow, register_code_review_workflow
6
+
7
+ __all__ = [
8
+ "create_code_review_workflow",
9
+ "register_code_review_workflow",
10
+ ]
app/workflows/code_review.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Code Review Workflow Implementation.
3
+
4
+ This is the sample workflow demonstrating the workflow engine capabilities:
5
+ 1. Extract functions from code
6
+ 2. Check complexity
7
+ 3. Detect issues
8
+ 4. Suggest improvements
9
+ 5. Loop until quality_score >= threshold
10
+ """
11
+
12
+ from typing import Any, Dict
13
+ import logging
14
+
15
+ from app.engine.graph import Graph, END
16
+ from app.engine.node import node, NodeType
17
+ from app.tools.builtin import (
18
+ extract_functions,
19
+ calculate_complexity,
20
+ detect_issues,
21
+ suggest_improvements,
22
+ quality_check,
23
+ )
24
+ from app.tools.registry import tool_registry
25
+
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ # ============================================================
31
+ # Node Handlers (using the @node decorator)
32
+ # ============================================================
33
+
34
@node(name="extract_node", description="Extract functions from the input code")
def extract_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Pull function definitions out of the code under review.

    Reads from state:
        code: str - the Python source to analyze

    Writes into state (merged from the extract_functions result):
        functions: List[dict] - extracted function information
        function_count: int - number of functions found
    """
    analysis = extract_functions(state.get("code", ""))
    state.update(analysis)
    logger.info(f"Extracted {analysis.get('function_count', 0)} functions")
    return state
51
+
52
+
53
@node(name="complexity_node", description="Calculate code complexity metrics")
def complexity_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Compute complexity metrics for the code under review.

    Reads from state:
        code: str - source code
        functions: List[dict] - functions extracted by the previous node

    Writes into state (merged from the calculate_complexity result):
        lines_of_code: int
        cyclomatic_complexity: int
        complexity_score: int (1-10)
    """
    metrics = calculate_complexity(state.get("code", ""), state.get("functions", []))
    state.update(metrics)
    logger.info(f"Complexity score: {metrics.get('complexity_score', 0)}")
    return state
73
+
74
+
75
@node(name="issues_node", description="Detect code quality issues")
def issues_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Detect quality issues and derive the current quality score.

    Reads from state:
        code: str - source code
        functions: List[dict] - extracted functions
        complexity_score: int - from the complexity node

    Writes into state (merged from the detect_issues result):
        issues: List[dict] - detected issues
        issue_count: int
        quality_score: float (1-10)
    """
    findings = detect_issues(
        state.get("code", ""),
        state.get("functions", []),
        state.get("complexity_score"),
    )
    state.update(findings)

    logger.info(
        f"Found {findings.get('issue_count', 0)} issues, "
        f"quality score: {findings.get('quality_score', 0)}"
    )
    return state
102
+
103
+
104
@node(name="improve_node", description="Generate improvement suggestions")
def improve_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """
    Produce improvement suggestions and nudge the quality score upward.

    Reads from state:
        issues: List[dict] - detected issues
        functions: List[dict] - extracted functions
        quality_score: float - current quality score (defaults to 5.0)

    Writes into state (merged from the suggest_improvements result):
        suggestions: List[dict]
        suggestion_count: int
        potential_quality_score: float
        quality_score: float - bumped to simulate applying the advice
    """
    current_quality = state.get("quality_score", 5.0)
    advice = suggest_improvements(
        state.get("issues", []),
        state.get("functions", []),
        current_quality,
    )
    state.update(advice)

    # Simulate the effect of the suggestions by slightly raising the score
    # (at most +0.5 per pass); a real implementation would edit the code
    # and re-measure.
    bump = min(0.5, advice.get("suggestion_count", 0) * 0.2)
    state["quality_score"] = min(10, current_quality + bump)

    logger.info(
        f"Generated {advice.get('suggestion_count', 0)} suggestions, "
        f"quality improved to {state['quality_score']}"
    )
    return state
136
+
137
+
138
+ # Register node handlers as tools so they can be retrieved when rebuilding from storage
139
+ def _wrapper_handler(handler_func):
140
+ """Create a wrapper that works with tool registry."""
141
+ def wrapper(state: Dict[str, Any]) -> Dict[str, Any]:
142
+ return handler_func(state)
143
+ wrapper.__name__ = handler_func.__name__
144
+ wrapper.__doc__ = handler_func.__doc__
145
+ return wrapper
146
+
147
+ tool_registry.add(_wrapper_handler(extract_node), name="extract_node", description="Extract functions from code")
148
+ tool_registry.add(_wrapper_handler(complexity_node), name="complexity_node", description="Calculate complexity")
149
+ tool_registry.add(_wrapper_handler(issues_node), name="issues_node", description="Detect quality issues")
150
+ tool_registry.add(_wrapper_handler(improve_node), name="improve_node", description="Suggest improvements")
151
+
152
+
153
+ # ============================================================
154
+ # Condition Functions
155
+ # ============================================================
156
+
157
def quality_meets_threshold(state: Dict[str, Any]) -> str:
    """
    Conditional-edge router: compare the quality score against the threshold.

    Returns:
        - "pass" when state["quality_score"] >= state["quality_threshold"]
          (defaults: 0 and 7.0 respectively)
        - "fail" when more improvement is needed
    """
    score = state.get("quality_score", 0)
    threshold = state.get("quality_threshold", 7.0)

    if score >= threshold:
        logger.info(f"Quality {score} meets threshold {threshold}")
        return "pass"
    logger.info(f"Quality {score} below threshold {threshold}")
    return "fail"
174
+
175
+
176
def always_loop(state: Dict[str, Any]) -> str:
    """Unconditional router: always send execution back to the issues check."""
    return "continue"
179
+
180
+
181
+ # ============================================================
182
+ # Workflow Factory
183
+ # ============================================================
184
+
185
def create_code_review_workflow(
    max_iterations: int = 5,
    quality_threshold: float = 7.0
) -> Graph:
    """
    Build the Code Review workflow graph.

    Flow:
        extract -> complexity -> issues --(pass)--> END
                                 ^    |
                                 |   (fail)
                                 +-- improve

    Args:
        max_iterations: Maximum improvement loops
        quality_threshold: Minimum quality score to pass

    Returns:
        Configured Graph instance
    """
    review = Graph(
        name="Code Review Workflow",
        description=(
            "Analyzes Python code for quality issues and suggests improvements. "
            f"Loops until quality score >= {quality_threshold} or max {max_iterations} iterations."
        ),
        max_iterations=max_iterations,
    )

    # Analysis pipeline nodes.
    review.add_node("extract", handler=extract_node, description="Extract functions from code")
    review.add_node("complexity", handler=complexity_node, description="Calculate complexity")
    review.add_node("issues", handler=issues_node, description="Detect quality issues")
    review.add_node("improve", handler=improve_node, description="Suggest improvements")

    # Linear portion of the flow.
    review.add_edge("extract", "complexity")
    review.add_edge("complexity", "issues")

    # From "issues": finish when quality passes, otherwise go improve.
    review.add_conditional_edge(
        "issues",
        quality_meets_threshold,
        {"pass": END, "fail": "improve"}
    )

    # "improve" always loops back for another issues check.
    review.add_conditional_edge(
        "improve",
        always_loop,
        {"continue": "issues"}
    )

    review.set_entry_point("extract")

    return review
243
+
244
+
245
async def register_code_review_workflow():
    """
    Persist a pre-built Code Review workflow under the fixed ID
    ``code-review-demo`` so it is usable through the API immediately,
    without a create call.

    Returns:
        The Graph instance that was registered.
    """
    from app.storage.memory import graph_storage

    demo_graph = create_code_review_workflow()
    await graph_storage.save(
        graph_id="code-review-demo",
        name="Code Review Demo",
        definition=demo_graph.to_dict(),
    )
    logger.info("Registered Code Review workflow with ID: code-review-demo")
    return demo_graph
264
+
265
+
266
+ # ============================================================
267
+ # Example Usage
268
+ # ============================================================
269
+
270
async def run_code_review_demo():
    """
    Demo driver: build the code-review workflow, run it over a small
    hard-coded sample, and print a summary to the console.

    Usage:
        import asyncio
        from app.workflows.code_review import run_code_review_demo
        asyncio.run(run_code_review_demo())
    """
    from app.engine.executor import execute_graph

    # Deliberately smelly sample: deep nesting, index-loop, print, magic number.
    sample_code = '''
def calculate_total(items):
    total = 0
    for item in items:
        if item.price > 0:
            if item.quantity > 0:
                if item.discount:
                    total += item.price * item.quantity * (1 - item.discount)
                else:
                    total += item.price * item.quantity
    return total

def process_data(data):
    result = []
    for i in range(len(data)):
        if data[i] > 100:
            result.append(data[i] * 2)
        else:
            result.append(data[i])
    print(result)
    return result


def helper():
    x = 42
    return x * 1000
'''

    demo_workflow = create_code_review_workflow(max_iterations=3, quality_threshold=6.0)
    seed_state = {
        "code": sample_code,
        "quality_threshold": 6.0,
    }

    print("Starting Code Review...")
    result = await execute_graph(demo_workflow, seed_state)

    # Console summary of the run.
    print(f"\nExecution Status: {result.status.value}")
    print(f"Total Duration: {result.total_duration_ms:.2f}ms")
    print(f"Iterations: {result.iterations}")
    print(f"\nFinal Quality Score: {result.final_state.get('quality_score', 'N/A')}")
    print(f"Issues Found: {result.final_state.get('issue_count', 'N/A')}")
    print("\nSuggestions:")
    for item in result.final_state.get("suggestions", []):
        print(f"  - [{item['priority']}] {item['suggestion']}")

    return result


if __name__ == "__main__":
    import asyncio
    asyncio.run(run_code_review_demo())
docker-compose.yml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ workflow-engine:
3
+ build:
4
+ context: .
5
+ dockerfile: Dockerfile
6
+ container_name: workflow-engine
7
+ ports:
8
+ - "8000:8000"
9
+ environment:
10
+ - APP_NAME=FlowGraph
11
+ - APP_VERSION=1.0.0
12
+ - DEBUG=true
13
+ - HOST=0.0.0.0
14
+ - PORT=8000
15
+ - MAX_ITERATIONS=100
16
+ - LOG_LEVEL=INFO
17
+ volumes:
18
+ # Mount for development (hot reload)
19
+ - .:/app
20
+ command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
21
+ healthcheck:
22
+ test: [ "CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" ]
23
+ interval: 30s
24
+ timeout: 10s
25
+ retries: 3
26
+ start_period: 10s
27
+ restart: unless-stopped
28
+
29
+ # Optional: Run tests in a separate container
30
+ tests:
31
+ build:
32
+ context: .
33
+ dockerfile: Dockerfile
34
+ container_name: workflow-engine-tests
35
+ command: pytest tests/ -v
36
+ profiles:
37
+ - test
38
+ depends_on:
39
+ - workflow-engine
pytest.ini ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [pytest]
2
+ asyncio_mode = auto
3
+ testpaths = tests
4
+ python_files = test_*.py
5
+ python_functions = test_*
6
+ addopts = -v --tb=short
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core
2
+ fastapi>=0.104.0
3
+ uvicorn[standard]>=0.24.0
4
+ pydantic>=2.5.0
5
+ pydantic-settings>=2.1.0
6
+
7
+ # Async support
8
+ asyncio-throttle>=1.0.2
9
+
10
+ # Testing
11
+ pytest>=7.4.0
12
+ pytest-asyncio>=0.21.0
13
+ httpx>=0.25.0
14
+
15
+ # Optional - for better logging
16
+ rich>=13.7.0
run.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Simple run script for the Workflow Engine.
4
+
5
+ Usage:
6
+ python run.py
7
+
8
+ Or with custom settings:
9
+ HOST=127.0.0.1 PORT=8080 python run.py
10
+ """
11
+
12
+ import uvicorn
13
+ import os
14
+
15
+
16
def main():
    """Run the FastAPI application under uvicorn.

    Configuration is read from environment variables:
        HOST   -- bind address (default "0.0.0.0")
        PORT   -- bind port (default 8000)
        RELOAD -- "true" (default) enables auto-reload; anything else disables

    Blocks until the server exits.
    """
    host = os.getenv("HOST", "0.0.0.0")
    port = int(os.getenv("PORT", "8000"))
    reload = os.getenv("RELOAD", "true").lower() == "true"

    # Build the banner from padded rows so the right-hand border stays
    # aligned for any HOST/PORT value.  The previous fixed-width f-string
    # only lined up for the default host:port and broke otherwise.
    width = 63

    def row(text: str = "") -> str:
        return "║ " + text.ljust(width - 2) + " ║"

    bar = "═" * width
    print("\n".join([
        "",
        "╔" + bar + "╗",
        row("FlowGraph 🔄".center(width - 2)),
        row(),
        row("A lightweight workflow orchestration engine".center(width - 2)),
        "╠" + bar + "╣",
        row(f"Server:    http://{host}:{port}"),
        row(f"API Docs:  http://{host}:{port}/docs"),
        row(f"ReDoc:     http://{host}:{port}/redoc"),
        "╠" + bar + "╣",
        row("Demo workflow ID: code-review-demo"),
        "╚" + bar + "╝",
        "",
    ]))

    uvicorn.run(
        "app.main:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
    )


if __name__ == "__main__":
    main()
tests/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """
2
+ Tests package.
3
+ """
tests/test_api.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the FastAPI endpoints.
3
+ """
4
+
5
+ import pytest
6
+ from fastapi.testclient import TestClient
7
+ from httpx import AsyncClient, ASGITransport
8
+
9
+ from app.main import app
10
+
11
+
12
+ # ============================================================
13
+ # Sync Test Client (for simple tests)
14
+ # ============================================================
15
+
16
+ client = TestClient(app)
17
+
18
+
19
class TestRootEndpoints:
    """Tests for the service's root-level endpoints."""

    def test_root(self):
        """The root endpoint reports the name, version and endpoint map."""
        resp = client.get("/")
        assert resp.status_code == 200

        body = resp.json()
        for key in ("name", "version", "endpoints"):
            assert key in body

    def test_health(self):
        """The health endpoint reports a healthy status."""
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json()["status"] == "healthy"
40
+
41
class TestToolsEndpoints:
    """Tests for the /tools endpoints."""

    def test_list_tools(self):
        """Listing tools returns a non-empty registry with the built-ins."""
        resp = client.get("/tools/")
        assert resp.status_code == 200

        body = resp.json()
        assert "tools" in body
        assert "total" in body
        assert body["total"] > 0

        # Built-in handlers must be registered.
        registered = {entry["name"] for entry in body["tools"]}
        assert "extract_functions" in registered
        assert "calculate_complexity" in registered

    def test_get_tool(self):
        """Fetching a known tool returns its name and a description."""
        resp = client.get("/tools/extract_functions")
        assert resp.status_code == 200

        body = resp.json()
        assert body["name"] == "extract_functions"
        assert "description" in body

    def test_get_nonexistent_tool(self):
        """Fetching an unknown tool yields a 404."""
        assert client.get("/tools/nonexistent_tool").status_code == 404
72
+
73
+
74
class TestGraphEndpoints:
    """Tests for the /graph endpoints."""

    def test_list_graphs(self):
        """Listing graphs returns the collection plus a total count."""
        resp = client.get("/graph/")
        assert resp.status_code == 200

        body = resp.json()
        assert "graphs" in body
        assert "total" in body

    def test_get_demo_workflow(self):
        """The pre-registered demo workflow is retrievable by its fixed ID."""
        resp = client.get("/graph/code-review-demo")
        assert resp.status_code == 200

        body = resp.json()
        assert body["graph_id"] == "code-review-demo"
        assert body["name"] == "Code Review Demo"
        assert "mermaid_diagram" in body

    def test_create_graph(self):
        """Creating a two-node graph succeeds and echoes its metadata."""
        payload = {
            "name": "test_workflow",
            "description": "A test workflow",
            "nodes": [
                {"name": "start", "handler": "extract_functions"},
                {"name": "end", "handler": "calculate_complexity"},
            ],
            "edges": {"start": "end"},
            "entry_point": "start",
        }

        resp = client.post("/graph/create", json=payload)
        assert resp.status_code == 201

        body = resp.json()
        assert "graph_id" in body
        assert body["name"] == "test_workflow"
        assert body["node_count"] == 2

    def test_create_graph_invalid_handler(self):
        """An unknown handler name is rejected with a 404."""
        payload = {
            "name": "invalid_workflow",
            "nodes": [{"name": "bad", "handler": "nonexistent_handler"}],
            "edges": {},
        }
        assert client.post("/graph/create", json=payload).status_code == 404
131
+
132
+
133
+ # ============================================================
134
+ # Async Tests (for async endpoints)
135
+ # ============================================================
136
+
137
@pytest.fixture
def anyio_backend():
    """Force the asyncio backend for anyio-driven async tests."""
    return "asyncio"
140
+
141
+
142
@pytest.mark.asyncio
async def test_run_demo_workflow():
    """Run the demo workflow synchronously end to end via the API."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        payload = {
            "graph_id": "code-review-demo",
            "initial_state": {
                "code": "def hello():\n print('world')",
                "quality_threshold": 5.0
            },
            "async_execution": False
        }

        resp = await ac.post("/graph/run", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "run_id" in body
        assert body["status"] in ["completed", "failed"]
        assert "execution_log" in body
164
+
165
@pytest.mark.asyncio
async def test_async_execution():
    """Async execution returns pending and exposes queryable run state."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        payload = {
            "graph_id": "code-review-demo",
            "initial_state": {
                "code": "def test(): pass",
                "quality_threshold": 5.0
            },
            "async_execution": True
        }

        resp = await ac.post("/graph/run", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "run_id" in body
        assert body["status"] == "pending"

        # The run must be immediately queryable by its ID.
        state_resp = await ac.get(f"/graph/state/{body['run_id']}")
        assert state_resp.status_code == 200
190
+
191
+
192
@pytest.mark.asyncio
async def test_run_nonexistent_graph():
    """Running an unknown graph ID returns a 404."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        resp = await ac.post(
            "/graph/run",
            json={"graph_id": "nonexistent-graph", "initial_state": {}},
        )
        assert resp.status_code == 404


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
tests/test_engine.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the Workflow Engine core components.
3
+ """
4
+
5
+ import pytest
6
+ import asyncio
7
+ from typing import Dict, Any
8
+
9
+ from app.engine.state import WorkflowState, StateManager
10
+ from app.engine.node import Node, NodeType, node, create_node_from_function
11
+ from app.engine.graph import Graph, END
12
+ from app.engine.executor import Executor, ExecutionStatus, execute_graph
13
+
14
+
15
+ # ============================================================
16
+ # State Tests
17
+ # ============================================================
18
+
19
class TestWorkflowState:
    """Tests for WorkflowState."""

    def test_create_empty_state(self):
        """A fresh state holds no data, iterations, or visits."""
        blank = WorkflowState()
        assert blank.data == {}
        assert blank.iteration == 0
        assert blank.visited_nodes == []

    def test_create_state_with_data(self):
        """Initial data is readable via get(), with a default fallback."""
        seeded = WorkflowState(data={"key": "value"})
        assert seeded.get("key") == "value"
        assert seeded.get("missing") is None
        assert seeded.get("missing", "default") == "default"

    def test_state_immutability(self):
        """set() returns a new instance and leaves the old one untouched."""
        before = WorkflowState(data={"a": 1})
        after = before.set("b", 2)

        assert before.get("b") is None
        assert after.get("b") == 2
        assert before is not after

    def test_state_update_multiple(self):
        """update() merges several keys in one call."""
        merged = WorkflowState(data={"a": 1}).update({"b": 2, "c": 3})
        assert merged.get("a") == 1
        assert merged.get("b") == 2
        assert merged.get("c") == 3

    def test_state_mark_visited(self):
        """mark_visited() records nodes and tracks the current one."""
        tracked = WorkflowState().mark_visited("node1").mark_visited("node2")

        assert "node1" in tracked.visited_nodes
        assert "node2" in tracked.visited_nodes
        assert tracked.current_node == "node2"

    def test_state_to_from_dict(self):
        """to_dict()/from_dict() round-trips the state data."""
        as_dict = WorkflowState(data={"test": 123}).to_dict()
        assert "data" in as_dict
        assert as_dict["data"]["test"] == 123

        assert WorkflowState.from_dict(as_dict).get("test") == 123
74
+
75
+
76
class TestStateManager:
    """Tests for StateManager."""

    def test_initialize(self):
        """Initialization seeds current_state and stamps started_at."""
        manager = StateManager()
        manager.initialize({"input": "test"})

        assert manager.current_state is not None
        assert manager.current_state.get("input") == "test"
        assert manager.current_state.started_at is not None

    def test_update_and_history(self):
        """Each update() appends a history entry tagged with its node."""
        manager = StateManager()
        initial = manager.initialize({"count": 0})
        manager.update(initial.set("count", 1), "node1")

        assert len(manager.history) == 1
        assert manager.history[0].node_name == "node1"
        assert manager.current_state.get("count") == 1
99
+
100
+
101
+ # ============================================================
102
+ # Node Tests
103
+ # ============================================================
104
+
105
class TestNode:
    """Tests for Node class."""

    def test_create_node(self):
        """A node keeps its name, handler and default type."""
        def passthrough(state):
            return state

        built = Node(name="test_node", handler=passthrough)
        assert built.name == "test_node"
        assert built.handler == passthrough
        assert built.node_type == NodeType.STANDARD

    def test_node_validation(self):
        """Empty names and non-callable handlers are rejected."""
        with pytest.raises(ValueError, match="name cannot be empty"):
            Node(name="", handler=lambda x: x)

        with pytest.raises(ValueError, match="must be callable"):
            Node(name="test", handler="not a function")

    @pytest.mark.asyncio
    async def test_sync_node_execution(self):
        """A synchronous handler runs and can mutate the state."""
        def mark(state):
            state["processed"] = True
            return state

        outcome = await Node(name="test", handler=mark).execute({"input": "data"})
        assert outcome["processed"] is True
        assert outcome["input"] == "data"

    @pytest.mark.asyncio
    async def test_async_node_execution(self):
        """A coroutine handler is detected and awaited."""
        async def mark_async(state):
            await asyncio.sleep(0.01)
            state["async_processed"] = True
            return state

        built = Node(name="async_test", handler=mark_async)
        assert built.is_async is True

        outcome = await built.execute({"input": "data"})
        assert outcome["async_processed"] is True

    def test_node_decorator(self):
        """The @node decorator attaches metadata to the handler."""
        @node(name="decorated_node", description="A test node")
        def my_handler(state):
            return state

        assert hasattr(my_handler, "_node_metadata")
        assert my_handler._node_metadata["name"] == "decorated_node"
162
+
163
+
164
+ # ============================================================
165
+ # Graph Tests
166
+ # ============================================================
167
+
168
class TestGraph:
    """Tests for Graph class."""

    def test_create_graph(self):
        """A new graph keeps its name and starts with no nodes."""
        g = Graph(name="Test Graph")
        assert g.name == "Test Graph"
        assert len(g.nodes) == 0

    def test_add_nodes(self):
        """Added nodes register; the first becomes the entry point."""
        g = Graph()
        g.add_node("node1", handler=lambda s: s)
        g.add_node("node2", handler=lambda s: s)

        assert "node1" in g.nodes
        assert "node2" in g.nodes
        assert g.entry_point == "node1"  # first node is the entry

    def test_add_edges(self):
        """add_edge records a direct transition."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_node("b", handler=lambda s: s)
        g.add_edge("a", "b")

        assert g.edges["a"] == "b"

    def test_add_edge_to_end(self):
        """Edges may terminate at the END sentinel."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_edge("a", END)

        assert g.edges["a"] == END

    def test_invalid_edge(self):
        """An edge to an unknown node raises ValueError."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)

        with pytest.raises(ValueError, match="not found"):
            g.add_edge("a", "nonexistent")

    def test_conditional_edge(self):
        """Conditional edges route per the condition's return value."""
        g = Graph()
        g.add_node("check", handler=lambda s: s)
        g.add_node("yes", handler=lambda s: s)
        g.add_node("no", handler=lambda s: s)

        g.add_conditional_edge(
            "check",
            lambda state: "yes" if state.get("value") else "no",
            {"yes": "yes", "no": "no"},
        )

        assert g.get_next_node("check", {"value": True}) == "yes"
        assert g.get_next_node("check", {"value": False}) == "no"

    def test_graph_validation(self):
        """validate() rejects an empty graph and accepts a wired one."""
        g = Graph()
        assert len(g.validate()) > 0  # empty graph must fail

        g.add_node("start", handler=lambda s: s)
        g.add_edge("start", END)
        assert len(g.validate()) == 0

    def test_mermaid_generation(self):
        """to_mermaid() emits a top-down diagram naming each node."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_node("b", handler=lambda s: s)
        g.add_edge("a", "b")
        g.add_edge("b", END)

        diagram = g.to_mermaid()
        assert "graph TD" in diagram
        assert "a" in diagram
        assert "b" in diagram
256
+
257
+
258
+ # ============================================================
259
+ # Executor Tests
260
+ # ============================================================
261
+
262
class TestExecutor:
    """Tests for the Executor."""

    @pytest.mark.asyncio
    async def test_simple_execution(self):
        """A one-node graph completes and transforms the state."""
        g = Graph()
        g.add_node("double", handler=lambda s: {**s, "value": s["value"] * 2})
        g.add_edge("double", END)

        outcome = await execute_graph(g, {"value": 5})
        assert outcome.status == ExecutionStatus.COMPLETED
        assert outcome.final_state["value"] == 10

    @pytest.mark.asyncio
    async def test_multi_node_execution(self):
        """Chained nodes each contribute to the final state."""
        g = Graph()
        g.add_node("add1", handler=lambda s: {**s, "value": s["value"] + 1})
        g.add_node("add2", handler=lambda s: {**s, "value": s["value"] + 2})
        g.add_edge("add1", "add2")
        g.add_edge("add2", END)

        outcome = await execute_graph(g, {"value": 0})
        assert outcome.status == ExecutionStatus.COMPLETED
        assert outcome.final_state["value"] == 3
        assert len(outcome.execution_log) == 2

    @pytest.mark.asyncio
    async def test_conditional_execution(self):
        """Routing follows the branch chosen by the condition."""
        g = Graph()
        g.add_node("start", handler=lambda s: s)
        g.add_node("high", handler=lambda s: {**s, "path": "high"})
        g.add_node("low", handler=lambda s: {**s, "path": "low"})

        g.add_conditional_edge(
            "start",
            lambda state: "high" if state["value"] > 5 else "low",
            {"high": "high", "low": "low"},
        )
        g.add_edge("high", END)
        g.add_edge("low", END)

        high_run = await execute_graph(g, {"value": 10})
        assert high_run.final_state["path"] == "high"

        low_run = await execute_graph(g, {"value": 2})
        assert low_run.final_state["path"] == "low"

    @pytest.mark.asyncio
    async def test_loop_execution(self):
        """A conditional self-edge loops until the counter reaches 3."""
        g = Graph(max_iterations=10)
        g.add_node("increment", handler=lambda s: {**s, "count": s["count"] + 1})
        g.add_conditional_edge(
            "increment",
            lambda s: "done" if s["count"] >= 3 else "continue",
            {"done": END, "continue": "increment"},
        )

        outcome = await execute_graph(g, {"count": 0})
        assert outcome.status == ExecutionStatus.COMPLETED
        assert outcome.final_state["count"] == 3

    @pytest.mark.asyncio
    async def test_max_iterations(self):
        """An infinite loop is cut off at max_iterations and fails."""
        g = Graph(max_iterations=3)
        g.add_node("loop", handler=lambda s: s)
        g.add_conditional_edge("loop", lambda s: "continue", {"continue": "loop"})

        outcome = await execute_graph(g, {})
        assert outcome.status == ExecutionStatus.FAILED
        assert "Max iterations" in outcome.error

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """A raising handler marks the run as failed with its message."""
        def explode(state):
            raise ValueError("Intentional error")

        g = Graph()
        g.add_node("fail", handler=explode)

        outcome = await execute_graph(g, {})
        assert outcome.status == ExecutionStatus.FAILED
        assert "Intentional error" in outcome.error

    @pytest.mark.asyncio
    async def test_execution_log(self):
        """Every executed node is logged in order with a duration."""
        g = Graph()
        g.add_node("step1", handler=lambda s: s)
        g.add_node("step2", handler=lambda s: s)
        g.add_edge("step1", "step2")
        g.add_edge("step2", END)

        outcome = await execute_graph(g, {})
        assert len(outcome.execution_log) == 2
        assert outcome.execution_log[0].node == "step1"
        assert outcome.execution_log[1].node == "step2"
        assert all(entry.duration_ms > 0 for entry in outcome.execution_log)
377
+
378
+
379
+ # ============================================================
380
+ # Integration Tests
381
+ # ============================================================
382
+
383
class TestCodeReviewWorkflow:
    """Integration tests for the Code Review workflow."""

    @pytest.mark.asyncio
    async def test_code_review_workflow(self):
        """The full pipeline completes and produces the analysis keys."""
        from app.workflows.code_review import create_code_review_workflow

        sample_code = '''
def hello():
    """Says hello."""
    print("Hello, World!")

def add(a, b):
    return a + b
'''

        workflow = create_code_review_workflow(max_iterations=3, quality_threshold=5.0)
        outcome = await execute_graph(workflow, {
            "code": sample_code,
            "quality_threshold": 5.0,
        })

        assert outcome.status == ExecutionStatus.COMPLETED
        assert "functions" in outcome.final_state
        assert "quality_score" in outcome.final_state
        assert len(outcome.execution_log) > 0


if __name__ == "__main__":
    pytest.main([__file__, "-v"])