Upload folder using huggingface_hub
Browse files- .dockerignore +38 -0
- .env.example +18 -0
- .gitignore +54 -0
- Dockerfile +35 -0
- LICENSE +21 -0
- README.md +11 -5
- app/__init__.py +9 -0
- app/api/__init__.py +7 -0
- app/api/routes/__init__.py +7 -0
- app/api/routes/graph.py +644 -0
- app/api/routes/tools.py +175 -0
- app/api/routes/websocket.py +269 -0
- app/api/schemas.py +322 -0
- app/config.py +34 -0
- app/engine/__init__.py +18 -0
- app/engine/executor.py +362 -0
- app/engine/graph.py +360 -0
- app/engine/node.py +196 -0
- app/engine/state.py +154 -0
- app/main.py +141 -0
- app/storage/__init__.py +17 -0
- app/storage/memory.py +271 -0
- app/tools/__init__.py +12 -0
- app/tools/builtin.py +387 -0
- app/tools/registry.py +218 -0
- app/workflows/__init__.py +10 -0
- app/workflows/code_review.py +338 -0
- docker-compose.yml +39 -0
- pytest.ini +6 -0
- requirements.txt +16 -0
- run.py +46 -0
- tests/__init__.py +3 -0
- tests/test_api.py +207 -0
- tests/test_engine.py +413 -0
.dockerignore
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker
|
| 2 |
+
.git
|
| 3 |
+
.gitignore
|
| 4 |
+
venv/
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.pyc
|
| 7 |
+
*.pyo
|
| 8 |
+
*.pyd
|
| 9 |
+
.Python
|
| 10 |
+
env/
|
| 11 |
+
.env
|
| 12 |
+
.venv/
|
| 13 |
+
*.egg-info/
|
| 14 |
+
.eggs/
|
| 15 |
+
dist/
|
| 16 |
+
build/
|
| 17 |
+
|
| 18 |
+
# IDE
|
| 19 |
+
.idea/
|
| 20 |
+
.vscode/
|
| 21 |
+
*.swp
|
| 22 |
+
*.swo
|
| 23 |
+
|
| 24 |
+
# Testing artifacts
|
| 25 |
+
.pytest_cache/
|
| 26 |
+
.coverage
|
| 27 |
+
htmlcov/
|
| 28 |
+
.tox/
|
| 29 |
+
.nox/
|
| 30 |
+
|
| 31 |
+
# Documentation
|
| 32 |
+
*.md
|
| 33 |
+
!README.md
|
| 34 |
+
|
| 35 |
+
# Misc
|
| 36 |
+
.DS_Store
|
| 37 |
+
Thumbs.db
|
| 38 |
+
*.log
|
.env.example
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Example .env file
|
| 2 |
+
# Copy this to .env and modify as needed
|
| 3 |
+
|
| 4 |
+
# Application
|
| 5 |
+
APP_NAME="Workflow Engine"
|
| 6 |
+
APP_VERSION="1.0.0"
|
| 7 |
+
DEBUG=true
|
| 8 |
+
|
| 9 |
+
# Server
|
| 10 |
+
HOST=0.0.0.0
|
| 11 |
+
PORT=8000
|
| 12 |
+
|
| 13 |
+
# Workflow Engine
|
| 14 |
+
MAX_ITERATIONS=100
|
| 15 |
+
EXECUTION_TIMEOUT=300
|
| 16 |
+
|
| 17 |
+
# Logging
|
| 18 |
+
LOG_LEVEL=INFO
|
.gitignore
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Environment variables
|
| 2 |
+
.env
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
*.so
|
| 9 |
+
.Python
|
| 10 |
+
build/
|
| 11 |
+
develop-eggs/
|
| 12 |
+
dist/
|
| 13 |
+
downloads/
|
| 14 |
+
eggs/
|
| 15 |
+
.eggs/
|
| 16 |
+
lib/
|
| 17 |
+
lib64/
|
| 18 |
+
parts/
|
| 19 |
+
sdist/
|
| 20 |
+
var/
|
| 21 |
+
wheels/
|
| 22 |
+
*.egg-info/
|
| 23 |
+
.installed.cfg
|
| 24 |
+
*.egg
|
| 25 |
+
|
| 26 |
+
# Virtual environments
|
| 27 |
+
venv/
|
| 28 |
+
ENV/
|
| 29 |
+
env/
|
| 30 |
+
.venv/
|
| 31 |
+
|
| 32 |
+
# IDE
|
| 33 |
+
.idea/
|
| 34 |
+
.vscode/
|
| 35 |
+
*.swp
|
| 36 |
+
*.swo
|
| 37 |
+
*~
|
| 38 |
+
|
| 39 |
+
# Testing
|
| 40 |
+
.pytest_cache/
|
| 41 |
+
.coverage
|
| 42 |
+
htmlcov/
|
| 43 |
+
.tox/
|
| 44 |
+
.nox/
|
| 45 |
+
|
| 46 |
+
# Type checking
|
| 47 |
+
.mypy_cache/
|
| 48 |
+
|
| 49 |
+
# Logs
|
| 50 |
+
*.log
|
| 51 |
+
|
| 52 |
+
# OS
|
| 53 |
+
.DS_Store
|
| 54 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Use Python 3.11 slim image
|
| 2 |
+
FROM python:3.11-slim
|
| 3 |
+
|
| 4 |
+
# Set working directory
|
| 5 |
+
WORKDIR /app
|
| 6 |
+
|
| 7 |
+
# Set environment variables
|
| 8 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 9 |
+
PYTHONUNBUFFERED=1 \
|
| 10 |
+
PIP_NO_CACHE_DIR=1 \
|
| 11 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 12 |
+
|
| 13 |
+
# Install system dependencies
|
| 14 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 15 |
+
gcc \
|
| 16 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 17 |
+
|
| 18 |
+
# Copy requirements first for better caching
|
| 19 |
+
COPY requirements.txt .
|
| 20 |
+
|
| 21 |
+
# Install Python dependencies
|
| 22 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 23 |
+
|
| 24 |
+
# Copy application code
|
| 25 |
+
COPY . .
|
| 26 |
+
|
| 27 |
+
# Expose port (HuggingFace uses 7860 by default)
|
| 28 |
+
EXPOSE 7860
|
| 29 |
+
|
| 30 |
+
# Health check
|
| 31 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 32 |
+
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
|
| 33 |
+
|
| 34 |
+
# Run the application
|
| 35 |
+
CMD uvicorn app.main:app --host 0.0.0.0 --port 7860
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,10 +1,16 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: FlowGraph
|
| 3 |
+
emoji: 🔄
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
app_port: 7860
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# FlowGraph
|
| 13 |
+
|
| 14 |
+
A lightweight workflow orchestration engine for building agent pipelines.
|
| 15 |
+
|
| 16 |
+
Check out the API docs at `/docs`
|
app/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FlowGraph - A lightweight, async-first workflow orchestration engine.
|
| 3 |
+
|
| 4 |
+
Build agent pipelines with nodes, edges, conditional branching, and looping.
|
| 5 |
+
Similar to LangGraph, but minimal and focused.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
__version__ = "1.0.0"
|
| 9 |
+
__author__ = "AI Engineering Intern"
|
app/api/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API package - FastAPI routes and schemas.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.api.routes import graph, tools, websocket
|
| 6 |
+
|
| 7 |
+
__all__ = ["graph", "tools", "websocket"]
|
app/api/routes/__init__.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
API Routes package.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.api.routes import graph, tools, websocket
|
| 6 |
+
|
| 7 |
+
__all__ = ["graph", "tools", "websocket"]
|
app/api/routes/graph.py
ADDED
|
@@ -0,0 +1,644 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph API Routes.
|
| 3 |
+
|
| 4 |
+
Endpoints for creating, managing, and executing workflow graphs.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict, Optional
|
| 8 |
+
from fastapi import APIRouter, HTTPException, BackgroundTasks, status
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
import logging
|
| 11 |
+
|
| 12 |
+
from app.api.schemas import (
|
| 13 |
+
GraphCreateRequest,
|
| 14 |
+
GraphCreateResponse,
|
| 15 |
+
GraphRunRequest,
|
| 16 |
+
GraphRunResponse,
|
| 17 |
+
GraphInfoResponse,
|
| 18 |
+
GraphListResponse,
|
| 19 |
+
RunStateResponse,
|
| 20 |
+
RunListResponse,
|
| 21 |
+
ExecutionLogEntry,
|
| 22 |
+
ExecutionStatus,
|
| 23 |
+
ErrorResponse,
|
| 24 |
+
)
|
| 25 |
+
from app.engine.graph import Graph, END
|
| 26 |
+
from app.engine.node import Node, get_registered_node
|
| 27 |
+
from app.engine.executor import Executor, ExecutionResult
|
| 28 |
+
from app.storage.memory import graph_storage, run_storage
|
| 29 |
+
from app.tools.registry import tool_registry
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
router = APIRouter(prefix="/graph", tags=["Graph"])
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# ============================================================
|
| 38 |
+
# Condition Functions Registry
|
| 39 |
+
# ============================================================
|
| 40 |
+
|
| 41 |
+
# Built-in condition functions for routing
|
| 42 |
+
# ============================================================
# Condition Functions Registry
# ============================================================

# Registry mapping condition name -> routing function. A condition function
# receives the current workflow state dict and returns a route key
# (e.g. "pass"/"fail") that conditional edges use to pick the next node.
_condition_registry: Dict[str, Any] = {}


def register_condition(*names: str):
    """Decorator registering a condition function under one or more names.

    Accepting several names lets semantically-identical conditions (e.g.
    "quality_check" and "quality_meets_threshold", used by the code review
    workflow) share a single implementation instead of duplicating it.
    Backward compatible with the single-name form ``@register_condition("x")``.
    """
    def decorator(func):
        for name in names:
            _condition_registry[name] = func
        return func
    return decorator


@register_condition("quality_check", "quality_meets_threshold")
def quality_check_condition(state: Dict[str, Any]) -> str:
    """Route "pass"/"fail" based on quality_score vs quality_threshold."""
    quality_score = state.get("quality_score", 0)
    threshold = state.get("quality_threshold", 7.0)
    return "pass" if quality_score >= threshold else "fail"


# Alias kept so existing imports of this name keep working.
quality_meets_threshold = quality_check_condition


@register_condition("always_continue", "always_loop")
def always_continue(state: Dict[str, Any]) -> str:
    """Always returns 'continue' - for unconditional looping."""
    return "continue"


# Alias kept so existing imports of this name keep working.
always_loop = always_continue


@register_condition("always_end")
def always_end(state: Dict[str, Any]) -> str:
    """Always returns 'end' - for explicit termination."""
    return "end"


@register_condition("max_iterations_check")
def max_iterations_check(state: Dict[str, Any]) -> str:
    """Return 'stop' once state['_iteration'] reaches '_max_iterations' (default 3)."""
    iteration = state.get("_iteration", 0)
    max_iter = state.get("_max_iterations", 3)
    return "stop" if iteration >= max_iter else "continue"


def get_condition(name: str):
    """Look up a registered condition function by name (None if absent)."""
    return _condition_registry.get(name)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ============================================================
|
| 103 |
+
# Graph CRUD Endpoints
|
| 104 |
+
# ============================================================
|
| 105 |
+
|
| 106 |
+
@router.post(
|
| 107 |
+
"/create",
|
| 108 |
+
response_model=GraphCreateResponse,
|
| 109 |
+
status_code=status.HTTP_201_CREATED,
|
| 110 |
+
responses={
|
| 111 |
+
400: {"model": ErrorResponse, "description": "Invalid graph definition"},
|
| 112 |
+
404: {"model": ErrorResponse, "description": "Handler not found"},
|
| 113 |
+
}
|
| 114 |
+
)
|
| 115 |
+
async def create_graph(request: GraphCreateRequest) -> GraphCreateResponse:
|
| 116 |
+
"""
|
| 117 |
+
Create a new workflow graph.
|
| 118 |
+
|
| 119 |
+
Define nodes with their handlers, edges for flow control,
|
| 120 |
+
and conditional edges for branching logic.
|
| 121 |
+
"""
|
| 122 |
+
graph_id = str(uuid4())
|
| 123 |
+
|
| 124 |
+
# Build the graph
|
| 125 |
+
graph = Graph(
|
| 126 |
+
graph_id=graph_id,
|
| 127 |
+
name=request.name,
|
| 128 |
+
description=request.description or "",
|
| 129 |
+
max_iterations=request.max_iterations,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Add nodes
|
| 133 |
+
for node_def in request.nodes:
|
| 134 |
+
# Find the handler function
|
| 135 |
+
handler = get_registered_node(node_def.handler)
|
| 136 |
+
if handler is None:
|
| 137 |
+
# Check tool registry as fallback
|
| 138 |
+
tool = tool_registry.get(node_def.handler)
|
| 139 |
+
if tool:
|
| 140 |
+
handler = _create_node_handler_from_tool(node_def.handler)
|
| 141 |
+
else:
|
| 142 |
+
raise HTTPException(
|
| 143 |
+
status_code=404,
|
| 144 |
+
detail=f"Handler '{node_def.handler}' not found. "
|
| 145 |
+
f"Available handlers: {list(tool_registry.list_tools())}"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
graph.add_node(
|
| 149 |
+
name=node_def.name,
|
| 150 |
+
handler=handler,
|
| 151 |
+
description=node_def.description or "",
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Add direct edges
|
| 155 |
+
for source, target in request.edges.items():
|
| 156 |
+
if source not in graph.nodes:
|
| 157 |
+
raise HTTPException(
|
| 158 |
+
status_code=400,
|
| 159 |
+
detail=f"Edge source '{source}' is not a valid node"
|
| 160 |
+
)
|
| 161 |
+
if target != END and target != "__END__" and target not in graph.nodes:
|
| 162 |
+
raise HTTPException(
|
| 163 |
+
status_code=400,
|
| 164 |
+
detail=f"Edge target '{target}' is not a valid node"
|
| 165 |
+
)
|
| 166 |
+
# Normalize END
|
| 167 |
+
target = END if target == "__END__" else target
|
| 168 |
+
graph.add_edge(source, target)
|
| 169 |
+
|
| 170 |
+
# Add conditional edges
|
| 171 |
+
for source, cond_routes in request.conditional_edges.items():
|
| 172 |
+
if source not in graph.nodes:
|
| 173 |
+
raise HTTPException(
|
| 174 |
+
status_code=400,
|
| 175 |
+
detail=f"Conditional edge source '{source}' is not a valid node"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Get condition function
|
| 179 |
+
condition_func = get_condition(cond_routes.condition)
|
| 180 |
+
if condition_func is None:
|
| 181 |
+
raise HTTPException(
|
| 182 |
+
status_code=404,
|
| 183 |
+
detail=f"Condition '{cond_routes.condition}' not found. "
|
| 184 |
+
f"Available: {list(_condition_registry.keys())}"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Normalize routes (handle __END__)
|
| 188 |
+
routes = {}
|
| 189 |
+
for key, target in cond_routes.routes.items():
|
| 190 |
+
if target == "__END__":
|
| 191 |
+
routes[key] = END
|
| 192 |
+
else:
|
| 193 |
+
if target not in graph.nodes:
|
| 194 |
+
raise HTTPException(
|
| 195 |
+
status_code=400,
|
| 196 |
+
detail=f"Conditional route target '{target}' is not a valid node"
|
| 197 |
+
)
|
| 198 |
+
routes[key] = target
|
| 199 |
+
|
| 200 |
+
graph.add_conditional_edge(source, condition_func, routes)
|
| 201 |
+
|
| 202 |
+
# Set entry point
|
| 203 |
+
if request.entry_point:
|
| 204 |
+
if request.entry_point not in graph.nodes:
|
| 205 |
+
raise HTTPException(
|
| 206 |
+
status_code=400,
|
| 207 |
+
detail=f"Entry point '{request.entry_point}' is not a valid node"
|
| 208 |
+
)
|
| 209 |
+
graph.set_entry_point(request.entry_point)
|
| 210 |
+
|
| 211 |
+
# Validate graph
|
| 212 |
+
errors = graph.validate()
|
| 213 |
+
if errors:
|
| 214 |
+
raise HTTPException(
|
| 215 |
+
status_code=400,
|
| 216 |
+
detail=f"Graph validation failed: {errors}"
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
# Store the graph
|
| 220 |
+
await graph_storage.save(
|
| 221 |
+
graph_id=graph_id,
|
| 222 |
+
name=request.name,
|
| 223 |
+
definition=graph.to_dict(),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
logger.info(f"Created graph: {graph_id} ({request.name})")
|
| 227 |
+
|
| 228 |
+
return GraphCreateResponse(
|
| 229 |
+
graph_id=graph_id,
|
| 230 |
+
name=request.name,
|
| 231 |
+
message="Graph created successfully",
|
| 232 |
+
node_count=len(graph.nodes),
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def _create_node_handler_from_tool(tool_name: str):
    """Wrap a registered tool as a graph node handler.

    The returned handler resolves the tool from the registry at call time,
    invokes it in one of two styles depending on its signature, folds any
    dict result back into the state, and returns the state.
    """
    def handler(state: Dict[str, Any]) -> Dict[str, Any]:
        tool = tool_registry.get(tool_name)
        if not tool:
            raise ValueError(f"Tool '{tool_name}' not found")

        # A single parameter literally named 'state' marks a node-handler
        # style tool; anything else is a regular tool whose arguments are
        # pulled out of the state by name.
        import inspect
        params = list(inspect.signature(tool.func).parameters.keys())

        if params == ["state"]:
            result = tool.func(state)
        else:
            result = tool.func(**_extract_tool_args(tool, state))

        if isinstance(result, dict):
            if result is state:
                # The tool worked on (and returned) the state object itself.
                return result
            # Otherwise treat the returned dict as a partial state update.
            state.update(result)

        return state

    handler.__name__ = f"{tool_name}_handler"
    return handler
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _extract_tool_args(tool, state: Dict[str, Any]) -> Dict[str, Any]:
|
| 272 |
+
"""Extract arguments for a tool from state."""
|
| 273 |
+
import inspect
|
| 274 |
+
sig = inspect.signature(tool.func)
|
| 275 |
+
args = {}
|
| 276 |
+
|
| 277 |
+
for param_name, param in sig.parameters.items():
|
| 278 |
+
if param_name in state:
|
| 279 |
+
args[param_name] = state[param_name]
|
| 280 |
+
elif param.default != inspect.Parameter.empty:
|
| 281 |
+
pass # Use default
|
| 282 |
+
# Skip missing optional params
|
| 283 |
+
|
| 284 |
+
return args
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
@router.get(
|
| 288 |
+
"/{graph_id}",
|
| 289 |
+
response_model=GraphInfoResponse,
|
| 290 |
+
responses={404: {"model": ErrorResponse}},
|
| 291 |
+
)
|
| 292 |
+
async def get_graph(graph_id: str) -> GraphInfoResponse:
|
| 293 |
+
"""Get information about a specific graph."""
|
| 294 |
+
stored = await graph_storage.get(graph_id)
|
| 295 |
+
if not stored:
|
| 296 |
+
raise HTTPException(status_code=404, detail=f"Graph '{graph_id}' not found")
|
| 297 |
+
|
| 298 |
+
definition = stored.definition
|
| 299 |
+
|
| 300 |
+
# Generate mermaid diagram
|
| 301 |
+
mermaid = _generate_mermaid(definition)
|
| 302 |
+
|
| 303 |
+
return GraphInfoResponse(
|
| 304 |
+
graph_id=stored.graph_id,
|
| 305 |
+
name=stored.name,
|
| 306 |
+
description=definition.get("description"),
|
| 307 |
+
node_count=len(definition.get("nodes", {})),
|
| 308 |
+
nodes=list(definition.get("nodes", {}).keys()),
|
| 309 |
+
entry_point=definition.get("entry_point"),
|
| 310 |
+
max_iterations=definition.get("max_iterations", 100),
|
| 311 |
+
created_at=stored.created_at.isoformat(),
|
| 312 |
+
mermaid_diagram=mermaid,
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def _generate_mermaid(definition: Dict[str, Any]) -> str:
    """Render a stored graph definition as a Mermaid top-down flowchart."""
    nodes = definition.get("nodes", {})
    edges = definition.get("edges", {})
    cond_edges = definition.get("conditional_edges", {})

    lines = ["graph TD"]

    # One labelled box per node.
    for name in nodes:
        label = name.replace("_", " ").title()
        lines.append(f'  {name}["{label}"]')

    # Only declare the END terminator when some edge actually reaches it.
    has_end = END in edges.values() or any(
        END in cond.get("routes", {}).values() for cond in cond_edges.values()
    )
    if has_end:
        lines.append(f'  {END}(("END"))')

    # Unconditional transitions.
    for source, target in edges.items():
        lines.append(f"  {source} --> {target}")

    # Conditional transitions, labelled with their route key.
    for source, cond in cond_edges.items():
        for route_key, target in cond.get("routes", {}).items():
            lines.append(f"  {source} -->|{route_key}| {target}")

    return "\n".join(lines)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
@router.get(
|
| 351 |
+
"/",
|
| 352 |
+
response_model=GraphListResponse,
|
| 353 |
+
)
|
| 354 |
+
async def list_graphs() -> GraphListResponse:
|
| 355 |
+
"""List all available graphs."""
|
| 356 |
+
graphs = await graph_storage.list_all()
|
| 357 |
+
|
| 358 |
+
graph_infos = []
|
| 359 |
+
for stored in graphs:
|
| 360 |
+
definition = stored.definition
|
| 361 |
+
graph_infos.append(GraphInfoResponse(
|
| 362 |
+
graph_id=stored.graph_id,
|
| 363 |
+
name=stored.name,
|
| 364 |
+
description=definition.get("description"),
|
| 365 |
+
node_count=len(definition.get("nodes", {})),
|
| 366 |
+
nodes=list(definition.get("nodes", {}).keys()),
|
| 367 |
+
entry_point=definition.get("entry_point"),
|
| 368 |
+
max_iterations=definition.get("max_iterations", 100),
|
| 369 |
+
created_at=stored.created_at.isoformat(),
|
| 370 |
+
mermaid_diagram=None, # Skip for list view
|
| 371 |
+
))
|
| 372 |
+
|
| 373 |
+
return GraphListResponse(graphs=graph_infos, total=len(graph_infos))
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
@router.delete(
|
| 377 |
+
"/{graph_id}",
|
| 378 |
+
status_code=status.HTTP_204_NO_CONTENT,
|
| 379 |
+
responses={404: {"model": ErrorResponse}},
|
| 380 |
+
)
|
| 381 |
+
async def delete_graph(graph_id: str):
|
| 382 |
+
"""Delete a graph."""
|
| 383 |
+
deleted = await graph_storage.delete(graph_id)
|
| 384 |
+
if not deleted:
|
| 385 |
+
raise HTTPException(status_code=404, detail=f"Graph '{graph_id}' not found")
|
| 386 |
+
logger.info(f"Deleted graph: {graph_id}")
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
# ============================================================
|
| 390 |
+
# Execution Endpoints
|
| 391 |
+
# ============================================================
|
| 392 |
+
|
| 393 |
+
@router.post(
|
| 394 |
+
"/run",
|
| 395 |
+
response_model=GraphRunResponse,
|
| 396 |
+
responses={
|
| 397 |
+
404: {"model": ErrorResponse},
|
| 398 |
+
500: {"model": ErrorResponse, "description": "Execution failed"},
|
| 399 |
+
}
|
| 400 |
+
)
|
| 401 |
+
async def run_graph(
|
| 402 |
+
request: GraphRunRequest,
|
| 403 |
+
background_tasks: BackgroundTasks,
|
| 404 |
+
) -> GraphRunResponse:
|
| 405 |
+
"""
|
| 406 |
+
Execute a workflow graph with the given initial state.
|
| 407 |
+
|
| 408 |
+
If `async_execution` is True, the workflow runs in the background
|
| 409 |
+
and you can poll the status using GET /graph/state/{run_id}.
|
| 410 |
+
"""
|
| 411 |
+
# Get the graph
|
| 412 |
+
stored = await graph_storage.get(request.graph_id)
|
| 413 |
+
if not stored:
|
| 414 |
+
raise HTTPException(
|
| 415 |
+
status_code=404,
|
| 416 |
+
detail=f"Graph '{request.graph_id}' not found"
|
| 417 |
+
)
|
| 418 |
+
|
| 419 |
+
# Rebuild the graph from definition
|
| 420 |
+
graph = await _rebuild_graph_from_definition(stored.definition)
|
| 421 |
+
|
| 422 |
+
# Create run
|
| 423 |
+
run_id = str(uuid4())
|
| 424 |
+
await run_storage.create(run_id, request.graph_id, request.initial_state)
|
| 425 |
+
|
| 426 |
+
if request.async_execution:
|
| 427 |
+
# Run in background
|
| 428 |
+
background_tasks.add_task(
|
| 429 |
+
_execute_in_background,
|
| 430 |
+
graph,
|
| 431 |
+
run_id,
|
| 432 |
+
request.initial_state,
|
| 433 |
+
)
|
| 434 |
+
|
| 435 |
+
return GraphRunResponse(
|
| 436 |
+
run_id=run_id,
|
| 437 |
+
graph_id=request.graph_id,
|
| 438 |
+
status=ExecutionStatus.PENDING,
|
| 439 |
+
final_state={},
|
| 440 |
+
execution_log=[],
|
| 441 |
+
started_at=None,
|
| 442 |
+
completed_at=None,
|
| 443 |
+
total_duration_ms=None,
|
| 444 |
+
iterations=0,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
# Execute synchronously
|
| 448 |
+
try:
|
| 449 |
+
executor = Executor(
|
| 450 |
+
graph,
|
| 451 |
+
run_id=run_id,
|
| 452 |
+
on_step=lambda step, state: _update_run_state(run_id, step, state),
|
| 453 |
+
)
|
| 454 |
+
result = await executor.run(request.initial_state)
|
| 455 |
+
|
| 456 |
+
# Update storage
|
| 457 |
+
if result.status.value == "completed":
|
| 458 |
+
await run_storage.complete(
|
| 459 |
+
run_id,
|
| 460 |
+
result.final_state,
|
| 461 |
+
[s.to_dict() for s in result.execution_log],
|
| 462 |
+
)
|
| 463 |
+
else:
|
| 464 |
+
await run_storage.fail(run_id, result.error or "Unknown error", result.final_state)
|
| 465 |
+
|
| 466 |
+
return _result_to_response(result)
|
| 467 |
+
|
| 468 |
+
except Exception as e:
|
| 469 |
+
logger.exception(f"Execution failed: {e}")
|
| 470 |
+
await run_storage.fail(run_id, str(e))
|
| 471 |
+
raise HTTPException(status_code=500, detail=str(e))
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
async def _rebuild_graph_from_definition(definition: Dict[str, Any]) -> Graph:
|
| 475 |
+
"""Rebuild a Graph object from its stored definition."""
|
| 476 |
+
graph = Graph(
|
| 477 |
+
graph_id=definition.get("graph_id", str(uuid4())),
|
| 478 |
+
name=definition.get("name", "Unnamed"),
|
| 479 |
+
description=definition.get("description", ""),
|
| 480 |
+
max_iterations=definition.get("max_iterations", 100),
|
| 481 |
+
)
|
| 482 |
+
|
| 483 |
+
# Add nodes
|
| 484 |
+
nodes_def = definition.get("nodes", {})
|
| 485 |
+
for node_name, node_info in nodes_def.items():
|
| 486 |
+
handler_name = node_info.get("handler", node_name)
|
| 487 |
+
handler = _create_node_handler_from_tool(handler_name)
|
| 488 |
+
graph.add_node(
|
| 489 |
+
name=node_name,
|
| 490 |
+
handler=handler,
|
| 491 |
+
description=node_info.get("description", ""),
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
# Add direct edges
|
| 495 |
+
for source, target in definition.get("edges", {}).items():
|
| 496 |
+
graph.add_edge(source, target)
|
| 497 |
+
|
| 498 |
+
# Add conditional edges
|
| 499 |
+
for source, cond_info in definition.get("conditional_edges", {}).items():
|
| 500 |
+
condition_name = cond_info.get("condition", "always_continue")
|
| 501 |
+
condition_func = get_condition(condition_name)
|
| 502 |
+
if condition_func is None:
|
| 503 |
+
condition_func = always_continue
|
| 504 |
+
|
| 505 |
+
routes = cond_info.get("routes", {})
|
| 506 |
+
graph.add_conditional_edge(source, condition_func, routes)
|
| 507 |
+
|
| 508 |
+
# Set entry point
|
| 509 |
+
if definition.get("entry_point"):
|
| 510 |
+
graph.set_entry_point(definition["entry_point"])
|
| 511 |
+
|
| 512 |
+
return graph
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
async def _execute_in_background(graph: Graph, run_id: str, initial_state: Dict[str, Any]):
    """Execute a workflow asynchronously and persist the outcome.

    Runs *graph* from *initial_state*, forwarding per-step progress to
    _update_run_state, then records success or failure for *run_id* in
    the run storage.  Any unexpected exception marks the run as failed
    rather than propagating.
    """
    try:
        worker = Executor(
            graph,
            run_id=run_id,
            on_step=lambda step, state: _update_run_state(run_id, step, state),
        )
        outcome = await worker.run(initial_state)

        if outcome.status.value == "completed":
            # Persist the final state together with the full step log.
            await run_storage.complete(
                run_id,
                outcome.final_state,
                [entry.to_dict() for entry in outcome.execution_log],
            )
        else:
            await run_storage.fail(run_id, outcome.error or "Unknown error", outcome.final_state)

    except Exception as exc:
        # A crash anywhere above (executor or storage) still leaves the
        # run in a terminal "failed" state for pollers.
        logger.exception(f"Background execution failed: {exc}")
        await run_storage.fail(run_id, str(exc))
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
def _update_run_state(run_id: str, step, state: Dict[str, Any]):
    """Persist intermediate run progress from a synchronous executor callback.

    The executor invokes this hook synchronously after each step, but the
    storage layer is async, so the update is scheduled on the running event
    loop without being awaited (fire-and-forget).  Failures are logged and
    suppressed so a storage hiccup never aborts workflow execution.

    Args:
        run_id: Identifier of the run being updated.
        step: Execution step object; only ``node`` and ``iteration`` are read.
        state: Current workflow state snapshot to persist.
    """
    import asyncio

    try:
        # get_running_loop() is the non-deprecated way to reach the loop from
        # code already executing inside it (get_event_loop() is deprecated for
        # this use since Python 3.10 and may create a stray loop in a thread).
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # Called outside any event loop (e.g. from a plain worker thread):
        # drop this progress update rather than crash the executor.
        return

    try:
        # NOTE: the task reference is intentionally not retained; the update
        # is best-effort and the storage call is expected to be short-lived.
        loop.create_task(
            run_storage.update_state(run_id, state, step.node, step.iteration)
        )
    except Exception:
        # Scheduling failed — log it instead of silently swallowing, but do
        # not let a progress update break the execution itself.
        logger.exception("Failed to schedule run-state update for %s", run_id)
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
def _result_to_response(result: ExecutionResult) -> GraphRunResponse:
    """Translate an engine ExecutionResult into the public API schema."""

    def _log_entry(entry) -> ExecutionLogEntry:
        # One API log row per executed step; timestamps become ISO strings.
        return ExecutionLogEntry(
            step=entry.step,
            node=entry.node,
            started_at=entry.started_at.isoformat(),
            completed_at=entry.completed_at.isoformat() if entry.completed_at else None,
            duration_ms=entry.duration_ms,
            iteration=entry.iteration,
            result=entry.result,
            error=entry.error,
            route_taken=entry.route_taken,
        )

    return GraphRunResponse(
        run_id=result.run_id,
        graph_id=result.graph_id,
        status=ExecutionStatus(result.status.value),
        final_state=result.final_state,
        execution_log=[_log_entry(entry) for entry in result.execution_log],
        started_at=result.started_at.isoformat() if result.started_at else None,
        completed_at=result.completed_at.isoformat() if result.completed_at else None,
        total_duration_ms=result.total_duration_ms,
        iterations=result.iterations,
        error=result.error,
    )
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# ============================================================
|
| 582 |
+
# Run State Endpoints
|
| 583 |
+
# ============================================================
|
| 584 |
+
|
| 585 |
+
def _stored_run_to_response(stored) -> RunStateResponse:
    """Build the API representation of one stored run record.

    Shared by ``get_run_state`` and ``list_runs`` so the two endpoints
    cannot drift apart field-by-field.
    """
    return RunStateResponse(
        run_id=stored.run_id,
        graph_id=stored.graph_id,
        status=ExecutionStatus(stored.status),
        current_node=stored.current_node,
        current_state=stored.current_state,
        iteration=stored.iteration,
        execution_log=[
            ExecutionLogEntry(**entry) for entry in stored.execution_log
        ],
        started_at=stored.started_at.isoformat(),
        completed_at=stored.completed_at.isoformat() if stored.completed_at else None,
        error=stored.error,
    )


@router.get(
    "/state/{run_id}",
    response_model=RunStateResponse,
    responses={404: {"model": ErrorResponse}},
)
async def get_run_state(run_id: str) -> RunStateResponse:
    """
    Get the current state of a workflow run.

    Use this to poll the status of async executions.

    Raises:
        HTTPException: 404 if no run with *run_id* exists.
    """
    stored = await run_storage.get(run_id)
    if not stored:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")

    return _stored_run_to_response(stored)


@router.get(
    "/runs",
    response_model=RunListResponse,
)
async def list_runs(graph_id: Optional[str] = None) -> RunListResponse:
    """List all runs, optionally filtered by graph_id."""
    if graph_id:
        runs = await run_storage.list_by_graph(graph_id)
    else:
        runs = await run_storage.list_all()

    run_states = [_stored_run_to_response(stored) for stored in runs]
    return RunListResponse(runs=run_states, total=len(run_states))
|
app/api/routes/tools.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tools API Routes.
|
| 3 |
+
|
| 4 |
+
Endpoints for listing and managing registered tools.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict
|
| 8 |
+
from fastapi import APIRouter, HTTPException, status
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
from app.api.schemas import (
|
| 12 |
+
ToolInfo,
|
| 13 |
+
ToolListResponse,
|
| 14 |
+
ToolRegisterRequest,
|
| 15 |
+
ToolRegisterResponse,
|
| 16 |
+
ErrorResponse,
|
| 17 |
+
)
|
| 18 |
+
from app.tools.registry import tool_registry
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
router = APIRouter(prefix="/tools", tags=["Tools"])
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@router.get(
    "/",
    response_model=ToolListResponse,
)
async def list_tools() -> ToolListResponse:
    """
    List all registered tools.

    Tools are functions that workflow nodes can use during execution.
    """
    # Project each registry entry onto the public schema.
    infos = [
        ToolInfo(
            name=entry["name"],
            description=entry["description"],
            parameters=entry["parameters"],
        )
        for entry in tool_registry.list_tools()
    ]

    return ToolListResponse(tools=infos, total=len(infos))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@router.get(
    "/{tool_name}",
    response_model=ToolInfo,
    responses={404: {"model": ErrorResponse}},
)
async def get_tool(tool_name: str) -> ToolInfo:
    """Get information about a specific tool.

    Raises:
        HTTPException: 404 when no tool is registered under *tool_name*.
    """
    found = tool_registry.get(tool_name)
    if not found:
        raise HTTPException(
            status_code=404,
            detail=f"Tool '{tool_name}' not found"
        )

    return ToolInfo(
        name=found.name,
        description=found.description,
        parameters=found.parameters,
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@router.post(
    "/register",
    response_model=ToolRegisterResponse,
    status_code=status.HTTP_201_CREATED,
    responses={
        400: {"model": ErrorResponse, "description": "Invalid tool code"},
        409: {"model": ErrorResponse, "description": "Tool already exists"},
    }
)
async def register_tool(request: ToolRegisterRequest) -> ToolRegisterResponse:
    """
    Register a new tool dynamically.

    **Warning**: This endpoint executes Python code. Use with caution
    and only in trusted environments.

    The code should define a function that:
    - Takes parameters as needed
    - Returns a dictionary with results

    Raises:
        HTTPException: 409 if the tool name is already taken; 400 if the
            code has a syntax error, defines no public callable, or fails
            during registration.
    """
    # Check if tool already exists
    if tool_registry.has(request.name):
        raise HTTPException(
            status_code=409,
            detail=f"Tool '{request.name}' already exists"
        )

    # Try to compile and execute the code
    try:
        # SECURITY: exec() runs arbitrary attacker-supplied code in-process
        # with full interpreter privileges; this endpoint must never be
        # exposed to untrusted clients.
        namespace: Dict[str, Any] = {}
        exec(request.code, namespace)

        # Pick the first public (non-underscore) callable the code defined.
        func = None
        for name, value in namespace.items():
            if callable(value) and not name.startswith("_"):
                func = value
                break

        if func is None:
            raise HTTPException(
                status_code=400,
                detail="No callable function found in the provided code"
            )

        # Register the tool
        tool_registry.add(
            func=func,
            name=request.name,
            description=request.description,
        )

        logger.info(f"Registered dynamic tool: {request.name}")

        return ToolRegisterResponse(
            name=request.name,
            message=f"Tool '{request.name}' registered successfully",
            warning="Dynamic tool registration executes code. Use responsibly.",
        )

    except HTTPException:
        # Bug fix: previously the generic Exception handler below caught our
        # own "No callable function found" HTTPException and re-wrapped it as
        # "Error registering tool: ...", mangling the intended detail message.
        raise
    except SyntaxError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Syntax error in tool code: {e}"
        )
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Error registering tool: {e}"
        )
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@router.delete(
    "/{tool_name}",
    status_code=status.HTTP_204_NO_CONTENT,
    responses={404: {"model": ErrorResponse}},
)
async def delete_tool(tool_name: str):
    """Delete a registered tool.

    Built-in tools are protected and cannot be removed.

    Raises:
        HTTPException: 400 for a protected built-in, 404 when the tool
            does not exist.
    """
    # Names of the tools shipped with the service; removing them would
    # break the bundled workflows that reference them by name.
    protected = {
        "extract_functions",
        "calculate_complexity",
        "detect_issues",
        "suggest_improvements",
        "quality_check",
    }

    if tool_name in protected:
        raise HTTPException(
            status_code=400,
            detail=f"Cannot delete built-in tool '{tool_name}'"
        )

    if not tool_registry.remove(tool_name):
        raise HTTPException(
            status_code=404,
            detail=f"Tool '{tool_name}' not found"
        )

    logger.info(f"Deleted tool: {tool_name}")
|
app/api/routes/websocket.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
WebSocket Routes for Real-time Execution Streaming.
|
| 3 |
+
|
| 4 |
+
Provides live updates during workflow execution.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from typing import Any, Dict, Set
|
| 8 |
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
| 9 |
+
from uuid import uuid4
|
| 10 |
+
import asyncio
|
| 11 |
+
import json
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
from app.engine.graph import Graph
|
| 15 |
+
from app.engine.executor import Executor, ExecutionStep
|
| 16 |
+
from app.storage.memory import graph_storage, run_storage
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger(__name__)
|
| 20 |
+
|
| 21 |
+
router = APIRouter(tags=["WebSocket"])
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ConnectionManager:
    """Tracks open WebSocket connections, grouped by run id."""

    def __init__(self):
        # run_id -> set of sockets watching that run.
        self.active_connections: Dict[str, Set[WebSocket]] = {}

    async def connect(self, websocket: WebSocket, run_id: str):
        """Accept a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.setdefault(run_id, set()).add(websocket)
        logger.info(f"WebSocket connected for run: {run_id}")

    def disconnect(self, websocket: WebSocket, run_id: str):
        """Remove a WebSocket connection."""
        watchers = self.active_connections.get(run_id)
        if watchers is not None:
            watchers.discard(websocket)
            if not watchers:
                # Drop the empty bucket so the mapping doesn't grow forever.
                del self.active_connections[run_id]
        logger.info(f"WebSocket disconnected for run: {run_id}")

    async def broadcast(self, run_id: str, message: Dict[str, Any]):
        """Broadcast a message to all connections for a run."""
        watchers = self.active_connections.get(run_id)
        if watchers is None:
            return

        dead = set()
        for ws in watchers:
            try:
                await ws.send_json(message)
            except Exception:
                # Remember failed sockets; pruning happens after the loop
                # because mutating the set while iterating is unsafe.
                dead.add(ws)

        for ws in dead:
            watchers.discard(ws)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Global connection manager
|
| 62 |
+
manager = ConnectionManager()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@router.websocket("/ws/run/{graph_id}")
async def websocket_run(websocket: WebSocket, graph_id: str):
    """
    WebSocket endpoint for real-time workflow execution.

    Connect to this endpoint and send the initial state as JSON.
    You'll receive step-by-step updates as the workflow executes.

    Message format (client -> server):
    ```json
    {"action": "start", "initial_state": {"code": "..."}}
    ```

    Message format (server -> client):
    ```json
    {
        "type": "step",
        "step": 1,
        "node": "extract",
        "status": "completed",
        "duration_ms": 15.5,
        "state": {...}
    }
    ```

    Lifecycle: validate graph -> accept socket -> await "start" message ->
    ack with "started" -> execute -> emit one "step" message per recorded
    step -> emit "completed" -> persist the run outcome.
    """
    # Check if graph exists
    stored = await graph_storage.get(graph_id)
    if not stored:
        # 4004 is an application-chosen close code used here for "not found".
        await websocket.close(code=4004, reason=f"Graph '{graph_id}' not found")
        return

    # Each connection starts a fresh run; the id doubles as the broadcast key.
    run_id = str(uuid4())
    await manager.connect(websocket, run_id)

    try:
        # Wait for start message
        data = await websocket.receive_json()

        if data.get("action") != "start":
            await websocket.send_json({
                "type": "error",
                "error": "Expected 'start' action"
            })
            return

        initial_state = data.get("initial_state", {})

        # Send acknowledgment
        await websocket.send_json({
            "type": "started",
            "run_id": run_id,
            "graph_id": graph_id,
        })

        # Rebuild graph
        graph = await _rebuild_graph(stored.definition)

        # Create run record
        await run_storage.create(run_id, graph_id, initial_state)

        # Execute with streaming updates.
        # NOTE(review): _run_with_streaming replays steps only AFTER the run
        # finishes, so "state" here is the final state for every step, not a
        # per-step snapshot — confirm against the Executor's capabilities.
        async def on_step(step: ExecutionStep, state: Dict[str, Any]):
            await manager.broadcast(run_id, {
                "type": "step",
                "step": step.step,
                "node": step.node,
                "status": step.result,
                "duration_ms": step.duration_ms,
                "iteration": step.iteration,
                "route_taken": step.route_taken,
                "error": step.error,
                "state": state,
            })

        executor = Executor(graph, run_id=run_id)

        # Run with step notifications
        result = await _run_with_streaming(executor, initial_state, on_step)

        # Send completion
        await websocket.send_json({
            "type": "completed",
            "run_id": run_id,
            "status": result.status.value,
            "final_state": result.final_state,
            "total_duration_ms": result.total_duration_ms,
            "iterations": result.iterations,
            "error": result.error,
        })

        # Update storage
        if result.status.value == "completed":
            await run_storage.complete(
                run_id,
                result.final_state,
                [s.to_dict() for s in result.execution_log],
            )
        else:
            await run_storage.fail(run_id, result.error or "Unknown error")

    except WebSocketDisconnect:
        logger.info(f"Client disconnected from run {run_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
        try:
            # Best-effort error report — the socket may already be closed.
            await websocket.send_json({
                "type": "error",
                "error": str(e),
            })
        except Exception:
            pass
    finally:
        manager.disconnect(websocket, run_id)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
async def _rebuild_graph(definition: Dict[str, Any]) -> Graph:
    """Rebuild graph from definition (copied from graph.py to avoid circular import)."""
    # Imported lazily so this module and graph.py can each import the
    # other's router without a cycle at module load time.
    from app.api.routes.graph import _rebuild_graph_from_definition

    return await _rebuild_graph_from_definition(definition)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
async def _run_with_streaming(
    executor: Executor,
    initial_state: Dict[str, Any],
    on_step
):
    """Run the executor, then replay each recorded step through *on_step*.

    NOTE: the workflow is executed to completion first; the per-step
    callbacks are emitted afterwards with a small artificial delay to give
    a streaming effect.  True live streaming would require the Executor
    itself to invoke the callback during execution.

    Args:
        executor: Configured Executor for the target graph.
        initial_state: State dict handed to the workflow's entry node.
        on_step: Async callable invoked as ``on_step(step, final_state)``
            once per entry in the execution log.

    Returns:
        The execution result produced by ``executor.run``.
    """
    # (Fixed: removed dead function-local imports of END, StateManager,
    # time and datetime that were never used by this helper.)
    result = await executor.run(initial_state)

    # Replay the recorded log; each step is paired with the FINAL state
    # because intermediate states are not retained on the result object.
    for step in result.execution_log:
        await on_step(step, result.final_state)
        await asyncio.sleep(0.01)  # Small delay for streaming effect

    return result
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@router.websocket("/ws/subscribe/{run_id}")
async def websocket_subscribe(websocket: WebSocket, run_id: str):
    """
    Subscribe to updates for an existing run.

    Use this to watch an async execution started via POST /graph/run.

    Protocol: sends one "current_state" snapshot on connect, then polls
    storage every 0.5 s, forwarding newly appended execution-log entries
    as "step" messages until the run reaches a terminal status, at which
    point a final "completed" message is sent and the loop exits.
    """
    # Check if run exists
    stored = await run_storage.get(run_id)
    if not stored:
        # 4004 is an application-chosen close code used here for "not found".
        await websocket.close(code=4004, reason=f"Run '{run_id}' not found")
        return

    await manager.connect(websocket, run_id)

    try:
        # Send current state
        await websocket.send_json({
            "type": "current_state",
            "run_id": run_id,
            "status": stored.status,
            "current_node": stored.current_node,
            "iteration": stored.iteration,
            "state": stored.current_state,
        })

        # Keep connection open and poll for updates.
        # Track how many log entries have already been forwarded so only
        # the delta is sent on each poll.
        last_log_count = len(stored.execution_log)

        while True:
            await asyncio.sleep(0.5)  # Poll interval

            # Re-fetch: the run record may have been deleted mid-watch.
            stored = await run_storage.get(run_id)
            if not stored:
                break

            # Send new log entries
            if len(stored.execution_log) > last_log_count:
                for entry in stored.execution_log[last_log_count:]:
                    await websocket.send_json({
                        "type": "step",
                        **entry,
                    })
                last_log_count = len(stored.execution_log)

            # Check if completed — any terminal status ends the watch loop.
            if stored.status in ("completed", "failed", "cancelled"):
                await websocket.send_json({
                    "type": "completed",
                    "run_id": run_id,
                    "status": stored.status,
                    "final_state": stored.final_state,
                    "error": stored.error,
                })
                break

    except WebSocketDisconnect:
        logger.info(f"Subscriber disconnected from run {run_id}")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
    finally:
        manager.disconnect(websocket, run_id)
|
app/api/schemas.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Pydantic Schemas for API Request/Response Models.
|
| 3 |
+
|
| 4 |
+
These schemas define the structure of data flowing through the API,
|
| 5 |
+
providing automatic validation and documentation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, List, Optional, Union
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from enum import Enum
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ============================================================
|
| 15 |
+
# Enums
|
| 16 |
+
# ============================================================
|
| 17 |
+
|
| 18 |
+
class ExecutionStatus(str, Enum):
    """Lifecycle status of a workflow execution.

    Subclasses ``str`` so members compare equal to their plain-string
    values and serialize cleanly in JSON API responses.
    """

    PENDING = "pending"      # created, not yet started
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished successfully
    FAILED = "failed"        # terminated with an error
    CANCELLED = "cancelled"  # stopped before completion
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ============================================================
|
| 28 |
+
# Node Schemas
|
| 29 |
+
# ============================================================
|
| 30 |
+
|
| 31 |
+
class NodeDefinition(BaseModel):
    """Definition of a node in the graph.

    The ``handler`` must name a function already registered in the tool
    registry at run time; it is resolved when the graph is rebuilt.
    """
    name: str = Field(..., description="Unique name for the node")
    handler: str = Field(..., description="Name of the handler function (must be registered)")
    description: Optional[str] = Field(None, description="Human-readable description")

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "name": "extract",
                "handler": "extract_functions",
                "description": "Extract function definitions from code"
            }
        }
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ============================================================
|
| 48 |
+
# Edge Schemas
|
| 49 |
+
# ============================================================
|
| 50 |
+
|
| 51 |
+
class ConditionalRoutes(BaseModel):
    """Routes for a conditional edge.

    ``condition`` names a registered condition function; its return value
    is looked up in ``routes`` to pick the next node ("__END__" terminates).
    """
    condition: str = Field(..., description="Name of the condition function")
    routes: Dict[str, str] = Field(
        ...,
        description="Mapping of condition results to target nodes"
    )

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "condition": "quality_check",
                "routes": {
                    "pass": "__END__",
                    "fail": "improve"
                }
            }
        }
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ============================================================
|
| 72 |
+
# Graph Schemas
|
| 73 |
+
# ============================================================
|
| 74 |
+
|
| 75 |
+
class GraphCreateRequest(BaseModel):
    """Request to create a new workflow graph.

    Nodes are declared first; plain edges and conditional edges then wire
    them together by name.  ``max_iterations`` bounds loops formed by
    conditional edges.
    """
    name: str = Field(..., description="Name of the workflow")
    description: Optional[str] = Field(None, description="Description of what this workflow does")
    nodes: List[NodeDefinition] = Field(..., description="List of nodes in the graph")
    edges: Dict[str, str] = Field(
        default_factory=dict,
        description="Direct edges: source -> target"
    )
    conditional_edges: Dict[str, ConditionalRoutes] = Field(
        default_factory=dict,
        description="Conditional edges with routing logic"
    )
    entry_point: Optional[str] = Field(None, description="Entry node (defaults to first node)")
    max_iterations: int = Field(100, description="Maximum loop iterations", ge=1, le=1000)

    class Config:
        # Full worked example (a looping code-review pipeline) for the docs.
        json_schema_extra = {
            "example": {
                "name": "code_review_workflow",
                "description": "Automated code review with quality checks",
                "nodes": [
                    {"name": "extract", "handler": "extract_functions"},
                    {"name": "complexity", "handler": "calculate_complexity"},
                    {"name": "issues", "handler": "detect_issues"},
                    {"name": "improve", "handler": "suggest_improvements"}
                ],
                "edges": {
                    "extract": "complexity",
                    "complexity": "issues"
                },
                "conditional_edges": {
                    "issues": {
                        "condition": "quality_check",
                        "routes": {"pass": "__END__", "fail": "improve"}
                    },
                    "improve": {
                        "condition": "always_continue",
                        "routes": {"continue": "issues"}
                    }
                },
                "entry_point": "extract",
                "max_iterations": 10
            }
        }
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class GraphCreateResponse(BaseModel):
    """Response after creating a graph."""
    graph_id: str = Field(..., description="Unique identifier for the created graph")
    name: str = Field(..., description="Name of the workflow")
    message: str = Field(default="Graph created successfully")
    node_count: int = Field(..., description="Number of nodes in the graph")

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "graph_id": "abc123-def456",
                "name": "code_review_workflow",
                "message": "Graph created successfully",
                "node_count": 4
            }
        }
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class GraphInfoResponse(BaseModel):
    """Response with graph information."""
    graph_id: str
    name: str
    description: Optional[str]
    node_count: int
    nodes: List[str]  # node names only (definitions are not echoed back)
    entry_point: Optional[str]
    max_iterations: int
    created_at: str  # timestamp string — presumably ISO-8601; confirm against storage layer
    mermaid_diagram: Optional[str] = Field(None, description="Mermaid diagram of the graph")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class GraphListResponse(BaseModel):
    """Response listing all graphs."""
    graphs: List[GraphInfoResponse]
    total: int  # convenience count; equals len(graphs)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ============================================================
|
| 160 |
+
# Run Schemas
|
| 161 |
+
# ============================================================
|
| 162 |
+
|
| 163 |
+
class GraphRunRequest(BaseModel):
    """Request to run a workflow graph.

    With ``async_execution`` the endpoint returns immediately and the run
    must be polled (or subscribed to) separately by its run id.
    """
    graph_id: str = Field(..., description="ID of the graph to run")
    initial_state: Dict[str, Any] = Field(
        ...,
        description="Initial state data for the workflow"
    )
    async_execution: bool = Field(
        False,
        description="If true, run in background and return immediately"
    )

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "graph_id": "abc123-def456",
                "initial_state": {
                    "code": "def hello():\n    print('world')",
                    "quality_threshold": 7.0
                },
                "async_execution": False
            }
        }
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class ExecutionLogEntry(BaseModel):
    """A single entry in the execution log.

    API-facing mirror of the engine's ExecutionStep (timestamps are
    serialized to strings).
    """
    step: int  # 1-based step counter across the whole run
    node: str  # name of the node executed in this step
    started_at: str  # step start timestamp (ISO string)
    completed_at: Optional[str]  # None if the step never finished
    duration_ms: Optional[float]  # wall-clock duration of the node handler
    iteration: int  # loop iteration this step belongs to (0 for first pass)
    result: str  # "success" or "error"
    error: Optional[str]  # error message when result == "error"
    route_taken: Optional[str]  # route key chosen by a conditional edge, if any
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class GraphRunResponse(BaseModel):
    """Response after running a graph.

    Returned for synchronous runs; contains the final state plus the
    full step-by-step execution log.
    """
    run_id: str = Field(..., description="Unique identifier for this run")
    graph_id: str  # graph that was executed
    status: ExecutionStatus  # terminal status (completed / failed / cancelled)
    final_state: Dict[str, Any]  # state data after the last executed node
    execution_log: List[ExecutionLogEntry]  # ordered per-step log
    started_at: Optional[str]  # run start timestamp (ISO string)
    completed_at: Optional[str]  # run end timestamp (ISO string)
    total_duration_ms: Optional[float]  # end-to-end wall-clock time
    iterations: int  # number of loop iterations performed
    error: Optional[str] = None  # populated only when status is "failed"

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "run_id": "run-xyz789",
                "graph_id": "abc123-def456",
                "status": "completed",
                "final_state": {
                    "code": "def hello():\n print('world')",
                    "functions": [{"name": "hello"}],
                    "quality_score": 8.5
                },
                "execution_log": [
                    {
                        "step": 1,
                        "node": "extract",
                        "started_at": "2024-01-01T12:00:00",
                        "completed_at": "2024-01-01T12:00:01",
                        "duration_ms": 15.5,
                        "iteration": 0,
                        "result": "success",
                        "error": None,
                        "route_taken": None
                    }
                ],
                "started_at": "2024-01-01T12:00:00",
                "completed_at": "2024-01-01T12:00:05",
                "total_duration_ms": 5000.0,
                "iterations": 1,
                "error": None
            }
        }
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class RunStateResponse(BaseModel):
    """Response with current run state.

    Snapshot of an in-progress (or finished) run; unlike GraphRunResponse
    it exposes the *current* node and state rather than only final values.
    """
    run_id: str  # identifier of the run being inspected
    graph_id: str  # graph being executed
    status: ExecutionStatus  # may still be "running"
    current_node: Optional[str]  # node currently executing, if any
    current_state: Dict[str, Any]  # latest state data snapshot
    iteration: int  # current loop iteration
    execution_log: List[ExecutionLogEntry]  # steps completed so far
    started_at: str  # run start timestamp (ISO string)
    completed_at: Optional[str]  # None while the run is still active
    error: Optional[str]  # populated when the run has failed
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class RunListResponse(BaseModel):
    """Response listing runs.

    ``total`` is carried separately from ``len(runs)`` for pagination.
    """
    runs: List[RunStateResponse]  # one state snapshot per run
    total: int  # total number of runs
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
# ============================================================
|
| 268 |
+
# Tool Schemas
|
| 269 |
+
# ============================================================
|
| 270 |
+
|
| 271 |
+
class ToolInfo(BaseModel):
    """Information about a registered tool."""
    name: str  # unique tool name used to invoke it
    description: str  # human-readable summary of what the tool does
    parameters: Dict[str, str]  # parameter name -> description/type string
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
class ToolListResponse(BaseModel):
    """Response listing all registered tools."""
    tools: List[ToolInfo]  # every tool currently in the registry
    total: int  # total number of registered tools
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class ToolRegisterRequest(BaseModel):
    """Request to register a new tool (for dynamic registration).

    NOTE(review): ``code`` is raw Python source supplied by the client;
    presumably the server executes it to define the tool. That is an
    arbitrary-code-execution surface — confirm how the registry sandboxes
    or restricts it.
    """
    name: str = Field(..., description="Unique name for the tool")
    description: str = Field("", description="Description of what the tool does")
    code: str = Field(..., description="Python code for the tool function")

    class Config:
        # Example payload surfaced in the generated OpenAPI docs.
        json_schema_extra = {
            "example": {
                "name": "custom_validator",
                "description": "Custom validation logic",
                "code": "def custom_validator(data):\n return {'valid': True}"
            }
        }
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
class ToolRegisterResponse(BaseModel):
    """Response after registering a tool."""
    name: str  # name the tool was registered under
    message: str  # human-readable confirmation message
    warning: Optional[str] = None  # non-fatal issue noticed during registration
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# ============================================================
|
| 308 |
+
# Error Schemas
|
| 309 |
+
# ============================================================
|
| 310 |
+
|
| 311 |
+
class ErrorResponse(BaseModel):
    """Standard error response."""
    error: str  # short error title
    detail: Optional[str] = None  # optional longer explanation
    status_code: int  # HTTP status code echoed in the body
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class ValidationErrorResponse(BaseModel):
    """Validation error response.

    Shape matches FastAPI's 422 payload: ``detail`` is the list of
    per-field validation error dicts.
    """
    error: str = "Validation Error"  # fixed title for this error class
    detail: List[Dict[str, Any]]  # per-field validation errors
    status_code: int = 422  # Unprocessable Entity
|
app/config.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Configuration settings for the Workflow Engine.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from pydantic_settings import BaseSettings
|
| 6 |
+
from typing import Optional
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Settings(BaseSettings):
    """Application settings with environment variable support.

    Each field can be overridden by an environment variable of the same
    (case-sensitive) name, or via a local ``.env`` file.
    """

    # Application
    APP_NAME: str = "FlowGraph"
    APP_VERSION: str = "1.0.0"
    DEBUG: bool = True  # NOTE(review): defaults to True — should be off in production

    # Server
    HOST: str = "0.0.0.0"  # bind on all interfaces (container-friendly)
    PORT: int = 8000

    # Workflow Engine
    MAX_ITERATIONS: int = 100  # Default max loop iterations
    EXECUTION_TIMEOUT: int = 300  # Seconds

    # Logging
    LOG_LEVEL: str = "INFO"

    class Config:
        env_file = ".env"  # optional dotenv file for overrides
        case_sensitive = True  # env var names must match field names exactly


# Global settings instance
# Instantiated once at import time; other modules import this singleton.
settings = Settings()
|
app/engine/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engine package - Core workflow orchestration components.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.engine.state import WorkflowState, StateManager
|
| 6 |
+
from app.engine.node import Node, node
|
| 7 |
+
from app.engine.graph import Graph
|
| 8 |
+
from app.engine.executor import Executor, ExecutionResult
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"WorkflowState",
|
| 12 |
+
"StateManager",
|
| 13 |
+
"Node",
|
| 14 |
+
"node",
|
| 15 |
+
"Graph",
|
| 16 |
+
"Executor",
|
| 17 |
+
"ExecutionResult",
|
| 18 |
+
]
|
app/engine/executor.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Async Workflow Executor.
|
| 3 |
+
|
| 4 |
+
The executor runs a workflow graph, managing state transitions,
|
| 5 |
+
handling loops, and generating execution logs.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from enum import Enum
|
| 12 |
+
import asyncio
|
| 13 |
+
import uuid
|
| 14 |
+
import time
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
from app.engine.graph import Graph, END
|
| 18 |
+
from app.engine.state import WorkflowState, StateManager
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class ExecutionStatus(str, Enum):
    """Status of a workflow execution.

    Inherits from ``str`` so members serialize directly as their string
    values (e.g. in JSON API responses).
    """
    PENDING = "pending"      # created but not yet started
    RUNNING = "running"      # currently executing
    COMPLETED = "completed"  # finished without error
    FAILED = "failed"        # aborted by an error or validation failure
    CANCELLED = "cancelled"  # stopped via Executor.cancel()
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class ExecutionStep:
    """A single step in the execution log.

    Created when a node starts; ``completed_at``/``duration_ms``/``result``
    are filled in by the executor once the node handler returns or raises.
    """
    step: int  # 1-based counter across the run
    node: str  # node name executed in this step
    started_at: datetime  # wall-clock start of the node handler
    completed_at: Optional[datetime] = None  # None until the handler finishes
    duration_ms: Optional[float] = None  # handler wall-clock duration
    iteration: int = 0  # loop iteration this step belongs to
    result: str = "success"  # "success" or "error"
    error: Optional[str] = None  # message when result == "error"
    route_taken: Optional[str] = None  # conditional-edge route key, if any

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (datetimes as ISO strings)."""
        return {
            "step": self.step,
            "node": self.node,
            "started_at": self.started_at.isoformat(),
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "duration_ms": self.duration_ms,
            "iteration": self.iteration,
            "result": self.result,
            "error": self.error,
            "route_taken": self.route_taken,
        }
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
@dataclass
class ExecutionResult:
    """Result of a workflow execution.

    Terminal summary returned by ``Executor.run`` — final state, the full
    step log, timing, and the error message for failed runs.
    """
    run_id: str  # unique identifier of this run
    graph_id: str  # graph that was executed
    status: ExecutionStatus  # terminal status
    final_state: Dict[str, Any]  # state data after the last executed node
    execution_log: List[ExecutionStep] = field(default_factory=list)  # ordered step log
    started_at: Optional[datetime] = None  # run start time
    completed_at: Optional[datetime] = None  # run end time
    total_duration_ms: Optional[float] = None  # end-to-end wall-clock time
    error: Optional[str] = None  # populated only on failure
    iterations: int = 0  # loop iterations performed

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (enums/datetimes flattened)."""
        return {
            "run_id": self.run_id,
            "graph_id": self.graph_id,
            "status": self.status.value,
            "final_state": self.final_state,
            "execution_log": [step.to_dict() for step in self.execution_log],
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
            "total_duration_ms": self.total_duration_ms,
            "error": self.error,
            "iterations": self.iterations,
        }
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class Executor:
    """
    Async workflow executor.

    Executes a graph with given initial state, handling:
    - Sequential node execution
    - Conditional branching
    - Loop iterations with max limit
    - Detailed execution logging
    - Error handling

    Usage:
        executor = Executor(graph)
        result = await executor.run({"input": "data"})
    """

    def __init__(
        self,
        graph: Graph,
        run_id: Optional[str] = None,
        on_step: Optional[Callable[[ExecutionStep, Dict[str, Any]], None]] = None
    ):
        """
        Initialize the executor.

        Args:
            graph: The workflow graph to execute
            run_id: Optional run ID (generated if not provided)
            on_step: Optional callback for each step (for WebSocket streaming)
        """
        self.graph = graph
        self.run_id = run_id or str(uuid.uuid4())
        self.on_step = on_step

        # Execution state
        self._state_manager: Optional[StateManager] = None  # created lazily in run()
        self._execution_log: List[ExecutionStep] = []  # accumulated step log
        self._step_counter = 0  # 1-based step numbering across the run
        self._status = ExecutionStatus.PENDING
        self._cancelled = False  # cooperative cancellation flag, checked per step

    @property
    def status(self) -> ExecutionStatus:
        """Get the current execution status."""
        return self._status

    @property
    def current_state(self) -> Optional[Dict[str, Any]]:
        """Get the current state data (None before run() initializes state)."""
        if self._state_manager and self._state_manager.current_state:
            return self._state_manager.current_state.data
        return None

    @property
    def current_node(self) -> Optional[str]:
        """Get the current node being executed."""
        if self._state_manager and self._state_manager.current_state:
            return self._state_manager.current_state.current_node
        return None

    def cancel(self) -> None:
        """Cancel the execution.

        Cooperative: the run loop checks the flag before each step, so the
        node currently executing still finishes.
        """
        self._cancelled = True
        self._status = ExecutionStatus.CANCELLED

    async def run(self, initial_state: Dict[str, Any]) -> ExecutionResult:
        """
        Execute the workflow with the given initial state.

        Walks the graph from the entry point, executing one node per step
        until END, no outgoing edge, cancellation, an error, or the
        max-iteration cap. Loop iterations are counted by detecting when
        the next node was already visited in the current pass.

        Args:
            initial_state: Initial state data

        Returns:
            ExecutionResult with final state and logs
        """
        start_time = time.time()
        self._status = ExecutionStatus.RUNNING
        self._state_manager = StateManager(self.run_id)

        # Initialize state
        state = self._state_manager.initialize(initial_state)

        # Validate graph before executing anything
        errors = self.graph.validate()
        if errors:
            return self._create_error_result(
                f"Graph validation failed: {errors}",
                start_time
            )

        current_node = self.graph.entry_point
        iteration = 0
        # Nodes already executed in the current pass; a revisit marks a loop.
        visited_in_iteration: set = set()

        try:
            while current_node and current_node != END:
                # Check cancellation (cooperative; see cancel())
                if self._cancelled:
                    logger.info(f"Execution cancelled at node '{current_node}'")
                    break

                # Check max iterations (guards against infinite loops)
                if iteration >= self.graph.max_iterations:
                    return self._create_error_result(
                        f"Max iterations ({self.graph.max_iterations}) exceeded",
                        start_time
                    )

                # Get the node
                node = self.graph.nodes.get(current_node)
                if not node:
                    return self._create_error_result(
                        f"Node '{current_node}' not found in graph",
                        start_time
                    )

                # Execute the node (errors are captured inside the step)
                step = await self._execute_node(node, state, iteration)

                # Handle error: any node failure aborts the whole run
                if step.result == "error":
                    return self._create_error_result(
                        step.error or "Unknown error",
                        start_time
                    )

                # Update state from state manager
                state = self._state_manager.current_state

                # Get next node (conditional edges take precedence inside the graph)
                next_node = self.graph.get_next_node(current_node, state.data)

                # Track route for conditional edges.
                # NOTE(review): this re-invokes the condition a second time
                # (get_next_node already evaluated it); a side-effecting or
                # non-deterministic condition could record a route that differs
                # from the one actually taken — confirm conditions are pure.
                if current_node in self.graph.conditional_edges:
                    cond_edge = self.graph.conditional_edges[current_node]
                    route_key = cond_edge.condition(state.data)
                    step.route_taken = route_key
                    logger.debug(f"Conditional route: {route_key} -> {next_node}")

                # Detect loops and increment iteration: revisiting a node from
                # the current pass starts a new iteration and resets the set.
                if next_node in visited_in_iteration:
                    iteration += 1
                    visited_in_iteration.clear()
                    state = state.increment_iteration()
                    logger.debug(f"Loop detected, iteration: {iteration}")

                visited_in_iteration.add(current_node)
                current_node = next_node

            # Finalize (also reached on cancellation via the break above;
            # status was already set to CANCELLED in that case but is
            # overwritten to COMPLETED here — NOTE(review): confirm intended)
            self._status = ExecutionStatus.COMPLETED
            final_state = self._state_manager.finalize()

            return ExecutionResult(
                run_id=self.run_id,
                graph_id=self.graph.graph_id,
                status=self._status,
                final_state=final_state.data,
                execution_log=self._execution_log,
                started_at=final_state.started_at,
                completed_at=final_state.completed_at,
                total_duration_ms=(time.time() - start_time) * 1000,
                iterations=iteration + 1,
            )

        except Exception as e:
            # Catch-all boundary: any unexpected failure becomes a FAILED result
            logger.exception(f"Execution failed: {e}")
            return self._create_error_result(str(e), start_time)

    async def _execute_node(
        self,
        node,
        state: WorkflowState,
        iteration: int
    ) -> ExecutionStep:
        """Execute a single node and update state.

        Never raises: handler exceptions are recorded on the returned step
        (result == "error"). Also fires the on_step callback, whose own
        failures are logged and swallowed.
        """
        self._step_counter += 1
        step_start = datetime.now()
        node_start_time = time.time()

        step = ExecutionStep(
            step=self._step_counter,
            node=node.name,
            started_at=step_start,
            iteration=iteration,
        )

        logger.info(f"Executing node: {node.name} (step {self._step_counter})")

        try:
            # Execute node handler
            result_data = await node.execute(state.data)

            # Update state: merge handler output, record the visit
            new_state = state.update(result_data).mark_visited(node.name)
            self._state_manager.update(new_state, node.name)

            # Complete step
            step.completed_at = datetime.now()
            step.duration_ms = (time.time() - node_start_time) * 1000
            step.result = "success"

        except Exception as e:
            # Record the failure on the step; the caller decides to abort
            logger.error(f"Node {node.name} failed: {e}")
            step.completed_at = datetime.now()
            step.duration_ms = (time.time() - node_start_time) * 1000
            step.result = "error"
            step.error = str(e)

        # Add to log
        self._execution_log.append(step)

        # Notify callback (best-effort; used for WebSocket streaming)
        if self.on_step:
            try:
                self.on_step(step, self._state_manager.current_state.data)
            except Exception as e:
                logger.warning(f"Step callback failed: {e}")

        return step

    def _create_error_result(
        self,
        error: str,
        start_time: float
    ) -> ExecutionResult:
        """Create an error result and mark the executor FAILED.

        NOTE(review): started_at/completed_at are both set to *now*, so a
        failed result does not report the actual run start time (duration
        is still computed from start_time) — confirm this is acceptable.
        """
        self._status = ExecutionStatus.FAILED
        return ExecutionResult(
            run_id=self.run_id,
            graph_id=self.graph.graph_id,
            status=ExecutionStatus.FAILED,
            final_state=self.current_state or {},
            execution_log=self._execution_log,
            started_at=datetime.now(),
            completed_at=datetime.now(),
            total_duration_ms=(time.time() - start_time) * 1000,
            error=error,
        )

    def get_execution_summary(self) -> Dict[str, Any]:
        """Get a summary of the current execution (safe to call mid-run)."""
        return {
            "run_id": self.run_id,
            "graph_id": self.graph.graph_id,
            "status": self._status.value,
            "current_node": self.current_node,
            "current_state": self.current_state,
            "step_count": self._step_counter,
            "iteration": self._state_manager.current_state.iteration if self._state_manager and self._state_manager.current_state else 0,
        }
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
async def execute_graph(
    graph: Graph,
    initial_state: Dict[str, Any],
    run_id: Optional[str] = None,
    on_step: Optional[Callable] = None
) -> ExecutionResult:
    """
    Convenience wrapper: build an Executor for *graph* and run it once.

    Args:
        graph: The workflow graph to execute
        initial_state: Initial state data handed to the entry node
        run_id: Optional run ID (generated when omitted)
        on_step: Optional per-step callback (e.g. for streaming)

    Returns:
        ExecutionResult describing the finished run
    """
    return await Executor(graph, run_id, on_step).run(initial_state)
|
app/engine/graph.py
ADDED
|
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Graph Definition for Workflow Engine.
|
| 3 |
+
|
| 4 |
+
The Graph is the core structure that defines the workflow - nodes, edges,
|
| 5 |
+
conditional routing, and execution flow.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Set, Union
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import uuid
|
| 12 |
+
|
| 13 |
+
from app.engine.node import Node, NodeType, get_registered_node, create_node_from_function
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# Special node names
|
| 17 |
+
END = "__END__"
|
| 18 |
+
START = "__START__"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class EdgeType(str, Enum):
    """Types of edges between nodes.

    String-valued so edge types serialize directly in API payloads.
    """
    DIRECT = "direct"  # Always follow this edge
    CONDITIONAL = "conditional"  # Choose based on condition
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class Edge:
    """An edge connecting two nodes.

    Plain data record; routing itself is handled by Graph.get_next_node.
    """
    source: str  # name of the node the edge leaves from
    target: str  # name of the node the edge points to (may be END)
    edge_type: EdgeType = EdgeType.DIRECT  # direct unless stated otherwise

    def to_dict(self) -> Dict[str, str]:
        """Serialize to a JSON-friendly dict (enum flattened to its value)."""
        return {
            "source": self.source,
            "target": self.target,
            "type": self.edge_type.value
        }
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
class ConditionalEdge:
    """
    A branching edge: a condition inspects the current state and returns a
    route key, which is looked up in ``routes`` to pick the target node.
    """
    source: str  # node the edge leaves from
    condition: Callable[[Dict[str, Any]], str]  # state -> route key
    routes: Dict[str, str]  # route_key -> target_node_name

    def evaluate(self, state_data: Dict[str, Any]) -> str:
        """Run the condition on *state_data* and resolve the target node name.

        Raises:
            ValueError: if the condition returns a key not present in routes.
        """
        chosen = self.condition(state_data)
        try:
            return self.routes[chosen]
        except KeyError:
            raise ValueError(
                f"Condition returned unknown route '{chosen}'. "
                f"Available routes: {list(self.routes.keys())}"
            ) from None

    def to_dict(self) -> Dict[str, Any]:
        """Serializable summary; the condition is represented by its name."""
        return {
            "source": self.source,
            "condition": getattr(self.condition, "__name__", str(self.condition)),
            "routes": self.routes
        }
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
|
| 73 |
+
class Graph:
|
| 74 |
+
"""
|
| 75 |
+
A workflow graph consisting of nodes and edges.
|
| 76 |
+
|
| 77 |
+
The graph defines the structure of a workflow:
|
| 78 |
+
- Nodes: Processing units that transform state
|
| 79 |
+
- Edges: Connections between nodes
|
| 80 |
+
- Conditional Edges: Branching logic based on state
|
| 81 |
+
|
| 82 |
+
Attributes:
|
| 83 |
+
graph_id: Unique identifier for this graph
|
| 84 |
+
name: Human-readable name
|
| 85 |
+
nodes: Dict of node_name -> Node
|
| 86 |
+
edges: List of direct edges
|
| 87 |
+
conditional_edges: Dict of source_node -> ConditionalEdge
|
| 88 |
+
entry_point: Name of the first node to execute
|
| 89 |
+
max_iterations: Maximum loop iterations allowed
|
| 90 |
+
"""
|
| 91 |
+
|
| 92 |
+
graph_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 93 |
+
name: str = "Unnamed Workflow"
|
| 94 |
+
nodes: Dict[str, Node] = field(default_factory=dict)
|
| 95 |
+
edges: Dict[str, str] = field(default_factory=dict) # source -> target for direct edges
|
| 96 |
+
conditional_edges: Dict[str, ConditionalEdge] = field(default_factory=dict)
|
| 97 |
+
entry_point: Optional[str] = None
|
| 98 |
+
max_iterations: int = 100
|
| 99 |
+
description: str = ""
|
| 100 |
+
metadata: Dict[str, Any] = field(default_factory=dict)
|
| 101 |
+
|
| 102 |
+
    def add_node(
        self,
        name: str,
        handler: Optional[Callable] = None,
        node_type: NodeType = NodeType.STANDARD,
        description: str = ""
    ) -> "Graph":
        """
        Add a node to the graph.

        If handler is not provided, attempts to find a registered node
        with the given name.

        Args:
            name: Unique name for the node
            handler: Function to execute (optional if registered)
            node_type: Type of node
            description: Human-readable description

        Returns:
            Self for chaining

        Raises:
            ValueError: if no handler can be resolved, or the name is taken.
        """
        if handler is None:
            # Try to find a registered handler (global node registry)
            handler = get_registered_node(name)
            if handler is None:
                raise ValueError(
                    f"No handler provided for node '{name}' and no registered "
                    f"node found with that name"
                )

        # Node names must be unique within a graph
        if name in self.nodes:
            raise ValueError(f"Node '{name}' already exists in the graph")

        node = create_node_from_function(handler, name, node_type, description)
        self.nodes[name] = node

        # Set as entry point if it's the first node or marked as entry
        if self.entry_point is None or node_type == NodeType.ENTRY:
            self.entry_point = name

        return self
|
| 144 |
+
|
| 145 |
+
    def add_edge(self, source: str, target: str) -> "Graph":
        """
        Add a direct edge from source to target.

        A node may have either one direct edge or one conditional edge,
        never both.

        Args:
            source: Source node name
            target: Target node name (or END)

        Returns:
            Self for chaining

        Raises:
            ValueError: if either node is unknown, or source already has
                a conditional edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")
        # END is a valid target even though it is not a node
        if target != END and target not in self.nodes:
            raise ValueError(f"Target node '{target}' not found in graph")

        # Check for conflicts with conditional edges
        if source in self.conditional_edges:
            raise ValueError(
                f"Node '{source}' already has a conditional edge. "
                f"Cannot add a direct edge."
            )

        # Direct edges are stored as a simple source -> target map,
        # so adding a second edge from the same source overwrites the first.
        self.edges[source] = target
        return self
|
| 170 |
+
|
| 171 |
+
    def add_conditional_edge(
        self,
        source: str,
        condition: Callable[[Dict[str, Any]], str],
        routes: Dict[str, str]
    ) -> "Graph":
        """
        Add a conditional edge from source node.

        The condition function receives state and returns a route key,
        which must be one of the keys in *routes*.

        Args:
            source: Source node name
            condition: Function that returns route key
            routes: Dict mapping route keys to target nodes (END allowed)

        Returns:
            Self for chaining

        Raises:
            ValueError: if source or any route target is unknown, or source
                already has a direct edge.
        """
        if source not in self.nodes:
            raise ValueError(f"Source node '{source}' not found in graph")

        # Validate all targets up front so bad routes fail at build time
        for route_key, target in routes.items():
            if target != END and target not in self.nodes:
                raise ValueError(
                    f"Target node '{target}' for route '{route_key}' not found in graph"
                )

        # Check for conflicts with direct edges (a node gets one edge kind)
        if source in self.edges:
            raise ValueError(
                f"Node '{source}' already has a direct edge. "
                f"Cannot add a conditional edge."
            )

        self.conditional_edges[source] = ConditionalEdge(
            source=source,
            condition=condition,
            routes=routes
        )
        return self
|
| 213 |
+
|
| 214 |
+
def set_entry_point(self, node_name: str) -> "Graph":
    """Declare *node_name* as the starting node; returns self for chaining."""
    known = self.nodes
    if node_name not in known:
        raise ValueError(f"Node '{node_name}' not found in graph")
    self.entry_point = node_name
    return self
|
| 220 |
+
|
| 221 |
+
def get_next_node(self, current_node: str, state_data: Dict[str, Any]) -> Optional[str]:
    """Resolve the successor of *current_node* for the given state.

    Conditional edges take precedence over direct edges. A node with
    neither outgoing edge is treated as an implicit end and yields None.

    Args:
        current_node: Current node name.
        state_data: Current state data used to evaluate branch conditions.

    Returns:
        Next node name, END, or None when no edge is defined.
    """
    branch = self.conditional_edges.get(current_node)
    if branch is not None:
        return branch.evaluate(state_data)
    # dict.get returns None for a missing key — the implicit-end case.
    return self.edges.get(current_node)
|
| 243 |
+
|
| 244 |
+
def validate(self) -> List[str]:
    """Check the structural sanity of the graph.

    Returns:
        A list of human-readable problems; empty when the graph is valid.
    """
    problems: List[str] = []

    # Nothing else is meaningful without nodes.
    if not self.nodes:
        problems.append("Graph must have at least one node")
        return problems

    if not self.entry_point:
        problems.append("Graph must have an entry point")
    elif self.entry_point not in self.nodes:
        problems.append(f"Entry point '{self.entry_point}' not found in nodes")

    # Nodes that can never run are almost certainly a mistake.
    orphans = set(self.nodes.keys()) - self._get_reachable_nodes()
    if orphans:
        problems.append(f"Orphan nodes (not reachable): {orphans}")

    # Nodes without outgoing edges are implicit terminals — allowed.
    return problems
|
| 277 |
+
|
| 278 |
+
def _get_reachable_nodes(self) -> Set[str]:
    """Walk edges from the entry point and collect every reachable node."""
    start = self.entry_point
    if not start:
        return set()

    seen: Set[str] = set()
    stack = [start]

    while stack:
        current = stack.pop()
        # END is a sink, not a node; the seen-check also stops cycles.
        if current == END or current in seen:
            continue
        seen.add(current)

        direct = self.edges.get(current)
        if direct is not None:
            stack.append(direct)

        branch = self.conditional_edges.get(current)
        if branch is not None:
            stack.extend(branch.routes.values())

    return seen
|
| 303 |
+
|
| 304 |
+
def to_dict(self) -> Dict[str, Any]:
    """Serialize the graph definition to a JSON-friendly dictionary."""
    serialized_nodes = {name: node.to_dict() for name, node in self.nodes.items()}
    serialized_branches = {
        name: edge.to_dict()
        for name, edge in self.conditional_edges.items()
    }
    return {
        "graph_id": self.graph_id,
        "name": self.name,
        "description": self.description,
        "nodes": serialized_nodes,
        "edges": self.edges,
        "conditional_edges": serialized_branches,
        "entry_point": self.entry_point,
        "max_iterations": self.max_iterations,
        "metadata": self.metadata,
    }
|
| 320 |
+
|
| 321 |
+
def to_mermaid(self) -> str:
    """Render the graph as a Mermaid flowchart ('graph TD') definition."""
    out = ["graph TD"]

    # One shape per node; entry/exit nodes get marker emoji in the label.
    for name, node in self.nodes.items():
        label = name.replace("_", " ").title()
        if node.node_type == NodeType.ENTRY:
            out.append(f'    {name}["{label} 🚀"]')
        elif node.node_type == NodeType.EXIT:
            out.append(f'    {name}["{label} 🏁"]')
        else:
            out.append(f'    {name}["{label}"]')

    # Only draw the END sink when some edge actually points at it.
    uses_end = END in self.edges.values() or any(
        END in cond.routes.values() for cond in self.conditional_edges.values()
    )
    if uses_end:
        out.append(f'    {END}(("END"))')

    for source, target in self.edges.items():
        out.append(f"    {source} --> {target}")

    for source, cond in self.conditional_edges.items():
        for route_key, target in cond.routes.items():
            out.append(f"    {source} -->|{route_key}| {target}")

    return "\n".join(out)
|
| 355 |
+
|
| 356 |
+
def __repr__(self) -> str:
    node_names = list(self.nodes.keys())
    return f"Graph(name='{self.name}', nodes={node_names}, entry='{self.entry_point}')"
|
app/engine/node.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Node Definition for Workflow Engine.
|
| 3 |
+
|
| 4 |
+
Nodes are the building blocks of a workflow. Each node is a function
|
| 5 |
+
that receives state, performs some operation, and returns modified state.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Callable, Dict, Optional, Union
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from enum import Enum
|
| 11 |
+
import asyncio
|
| 12 |
+
import inspect
|
| 13 |
+
import functools
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class NodeType(str, Enum):
    """Types of nodes in the workflow."""
    STANDARD = "standard"  # Regular processing node
    CONDITIONAL = "conditional"  # Branching decision node
    ENTRY = "entry"  # Entry point
    EXIT = "exit"  # Exit point


@dataclass
class Node:
    """
    A node in the workflow graph.

    Each node has a name and a handler function. The handler receives
    the current state data (as a dict) and returns modified state data.

    Attributes:
        name: Unique identifier for the node
        handler: Function that processes state (sync or async)
        node_type: Type of node (standard, conditional, etc.)
        description: Human-readable description
        metadata: Additional node metadata
    """

    name: str
    handler: Callable[[Dict[str, Any]], Union[Dict[str, Any], Any]]
    node_type: NodeType = NodeType.STANDARD
    description: str = ""
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Validate the node after initialization."""
        if not self.name:
            raise ValueError("Node name cannot be empty")
        if not callable(self.handler):
            raise ValueError(f"Handler for node '{self.name}' must be callable")

    @property
    def is_async(self) -> bool:
        """Check if the handler is an async function."""
        return asyncio.iscoroutinefunction(self.handler)

    async def execute(self, state_data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Execute the node handler with the given state data.

        Handles both sync and async handlers transparently. A handler
        that returns None is treated as "no change" and the input state
        is passed through unchanged.

        Args:
            state_data: The current state data dictionary

        Returns:
            Modified state data dictionary

        Raises:
            RuntimeError: Wrapping any error raised by the handler, or a
                handler return value that is neither a dict nor None.
        """
        try:
            if self.is_async:
                result = await self.handler(state_data)
            else:
                # Run the sync handler in the default executor so it does
                # not block the event loop. get_running_loop() is the
                # correct call inside a coroutine; get_event_loop() here
                # has been deprecated since Python 3.10.
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(
                    None,
                    functools.partial(self.handler, state_data)
                )

            # If handler returns None, return original state
            if result is None:
                return state_data

            # If handler returns a dict, use it as the new state
            if isinstance(result, dict):
                return result

            # Otherwise, something unexpected happened
            raise ValueError(
                f"Node '{self.name}' handler must return a dict or None, "
                f"got {type(result).__name__}"
            )

        except Exception as e:
            # Add node context to the error; original cause is chained.
            raise RuntimeError(f"Error in node '{self.name}': {str(e)}") from e

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the node to a dictionary."""
        return {
            "name": self.name,
            "type": self.node_type.value,
            "description": self.description,
            "handler": self.handler.__name__ if hasattr(self.handler, '__name__') else str(self.handler),
            "metadata": self.metadata,
        }
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# Registry to hold decorated node functions
_node_registry: Dict[str, Callable] = {}


def node(
    name: Optional[str] = None,
    node_type: NodeType = NodeType.STANDARD,
    description: str = ""
) -> Callable:
    """
    Decorator that registers a function as a workflow node.

    Usage:
        @node(name="extract_functions", description="Extract functions from code")
        def extract_functions(state: dict) -> dict:
            # ... process state
            return state

    Args:
        name: Node name (defaults to the function name)
        node_type: Type of node
        description: Human-readable description (falls back to the docstring)

    Returns:
        The decorated function, carrying `_node_metadata`.
    """
    def decorator(func: Callable) -> Callable:
        label = name or func.__name__

        # Metadata lives on the function so graph builders can inspect it.
        meta = {
            "name": label,
            "type": node_type,
            "description": description or func.__doc__ or "",
        }
        func._node_metadata = meta

        # The registry holds the undecorated function, keyed by node name.
        _node_registry[label] = func

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            return func(*args, **kwargs)

        wrapper._node_metadata = meta
        return wrapper

    return decorator
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def get_registered_node(name: str) -> Optional[Callable]:
    """Look up a node function previously registered via @node, or None."""
    return _node_registry.get(name)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def list_registered_nodes() -> Dict[str, Dict[str, Any]]:
    """Return metadata for every registered node, keyed by node name."""
    listing: Dict[str, Dict[str, Any]] = {}
    for name, func in _node_registry.items():
        # Entries registered outside the @node decorator may lack metadata.
        if hasattr(func, '_node_metadata'):
            listing[name] = func._node_metadata
    return listing
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def create_node_from_function(
    func: Callable,
    name: Optional[str] = None,
    node_type: NodeType = NodeType.STANDARD,
    description: str = ""
) -> Node:
    """
    Wrap a plain callable in a Node instance.

    Falls back to the function's own name and docstring when *name*
    or *description* are not supplied.

    Args:
        func: The handler function
        name: Node name (defaults to function name)
        node_type: Type of node
        description: Human-readable description

    Returns:
        A Node instance
    """
    resolved_name = name or func.__name__
    resolved_description = description or func.__doc__ or ""
    return Node(
        name=resolved_name,
        handler=func,
        node_type=node_type,
        description=resolved_description,
    )
|
app/engine/state.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
State Management for Workflow Engine.
|
| 3 |
+
|
| 4 |
+
This module provides the state management system that flows through the workflow.
|
| 5 |
+
State is immutable - each node receives state and returns a new modified state.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
from pydantic import BaseModel, Field
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
import uuid
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class StateSnapshot(BaseModel):
    """A snapshot of state at a specific point in execution."""

    # Wall-clock time the snapshot was taken (naive datetime.now()).
    timestamp: datetime = Field(default_factory=datetime.now)
    # Name of the node that produced this state.
    node_name: str
    # Workflow data captured at this point in the run.
    state_data: Dict[str, Any]
    # Loop iteration counter at snapshot time.
    iteration: int = 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class WorkflowState(BaseModel):
    """
    The shared state that flows through the workflow.

    This is a flexible container that holds all data being processed
    by the workflow nodes. Each node can read from and write to this state.
    Mutators follow an immutable pattern: they return a new WorkflowState
    rather than modifying this one in place.

    Attributes:
        data: The actual workflow data (flexible dictionary)
        current_node: Name of the node currently executing, if any
        iteration: Number of completed loop iterations
        visited_nodes: Ordered list of node names executed so far
        started_at: When execution began (None until initialized)
        completed_at: When execution finished (None while running)
    """

    # The actual data being processed
    data: Dict[str, Any] = Field(default_factory=dict)

    # Execution metadata
    current_node: Optional[str] = None
    iteration: int = 0
    visited_nodes: List[str] = Field(default_factory=list)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Pydantic v2 configuration. The v1-style `class Config:` block is
    # deprecated alongside the v2 `model_copy` API used below; a plain
    # model_config dict keeps the class consistently on the v2 API.
    model_config = {"arbitrary_types_allowed": True}

    def get(self, key: str, default: Any = None) -> Any:
        """Get a value from the state data."""
        return self.data.get(key, default)

    def set(self, key: str, value: Any) -> "WorkflowState":
        """Set a value in state data and return a new state (immutable pattern)."""
        new_data = deepcopy(self.data)
        new_data[key] = value
        return self.model_copy(update={"data": new_data})

    def update(self, updates: Dict[str, Any]) -> "WorkflowState":
        """Update multiple values and return a new state."""
        new_data = deepcopy(self.data)
        new_data.update(updates)
        return self.model_copy(update={"data": new_data})

    def mark_visited(self, node_name: str) -> "WorkflowState":
        """Return a new state with *node_name* appended to the visit log."""
        new_visited = self.visited_nodes + [node_name]
        return self.model_copy(update={
            "visited_nodes": new_visited,
            "current_node": node_name
        })

    def increment_iteration(self) -> "WorkflowState":
        """Return a new state with the iteration counter advanced by one."""
        return self.model_copy(update={"iteration": self.iteration + 1})

    def to_dict(self) -> Dict[str, Any]:
        """Convert state to a plain dictionary (datetimes as ISO strings)."""
        return {
            "data": self.data,
            "current_node": self.current_node,
            "iteration": self.iteration,
            "visited_nodes": self.visited_nodes,
            "started_at": self.started_at.isoformat() if self.started_at else None,
            "completed_at": self.completed_at.isoformat() if self.completed_at else None,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "WorkflowState":
        """Create a WorkflowState from a dictionary.

        Accepts either a full serialized state (containing a "data" key)
        or a bare data payload, which is wrapped as the initial data.
        """
        if "data" in data:
            return cls(**data)
        # If it's just raw data, wrap it
        return cls(data=data)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
class StateManager:
    """
    Tracks the current state and a snapshot history for one workflow run.

    Recording a snapshot after every node gives cheap time-travel
    debugging over the run's execution.
    """

    def __init__(self, run_id: Optional[str] = None):
        self.run_id = run_id or str(uuid.uuid4())
        self.history: List[StateSnapshot] = []
        self._current_state: Optional[WorkflowState] = None

    @property
    def current_state(self) -> Optional[WorkflowState]:
        """The most recent state, or None before initialize() is called."""
        return self._current_state

    def initialize(self, initial_data: Dict[str, Any]) -> WorkflowState:
        """Create the initial state from raw data and start the clock."""
        state = WorkflowState(
            data=initial_data,
            started_at=datetime.now()
        )
        self._current_state = state
        return state

    def update(self, new_state: WorkflowState, node_name: str) -> None:
        """Record a snapshot for *node_name* and adopt *new_state*."""
        # Deep-copy the data so later mutations cannot rewrite history.
        self.history.append(
            StateSnapshot(
                node_name=node_name,
                state_data=deepcopy(new_state.data),
                iteration=new_state.iteration,
            )
        )
        self._current_state = new_state

    def finalize(self) -> WorkflowState:
        """Stamp the completion time on the current state and return it."""
        if self._current_state:
            self._current_state = self._current_state.model_copy(
                update={"completed_at": datetime.now()}
            )
        return self._current_state

    def get_history(self) -> List[Dict[str, Any]]:
        """Render the snapshot history as plain dictionaries."""
        rendered = []
        for snap in self.history:
            rendered.append({
                "timestamp": snap.timestamp.isoformat(),
                "node": snap.node_name,
                "iteration": snap.iteration,
                "state": snap.state_data,
            })
        return rendered
|
app/main.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FlowGraph - FastAPI Application Entry Point.
|
| 3 |
+
|
| 4 |
+
A lightweight, async-first workflow orchestration engine for building agent pipelines.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from contextlib import asynccontextmanager
|
| 8 |
+
from fastapi import FastAPI
|
| 9 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 10 |
+
from fastapi.responses import JSONResponse
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from app.config import settings
|
| 14 |
+
from app.api.routes import graph, tools, websocket
|
| 15 |
+
from app.workflows.code_review import register_code_review_workflow
|
| 16 |
+
|
| 17 |
+
# Import builtin tools to register them
|
| 18 |
+
import app.tools.builtin # noqa: F401
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Configure logging
|
| 22 |
+
logging.basicConfig(
|
| 23 |
+
level=getattr(logging, settings.LOG_LEVEL),
|
| 24 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
| 25 |
+
)
|
| 26 |
+
logger = logging.getLogger(__name__)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler.

    Code before `yield` runs at startup, code after it at shutdown.
    """
    # Startup
    logger.info(f"Starting {settings.APP_NAME} v{settings.APP_VERSION}")

    # Register the demo workflow so it is available immediately.
    await register_code_review_workflow()

    yield

    # Shutdown
    logger.info("Shutting down...")
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Create FastAPI application
app = FastAPI(
    title=settings.APP_NAME,
    description="""
## Workflow Engine API

A minimal but powerful workflow/graph engine for building agent workflows.

### Features
- **Nodes**: Python functions that read and modify shared state
- **Edges**: Define execution flow between nodes
- **Branching**: Conditional routing based on state values
- **Looping**: Support for iterative workflows
- **Real-time Updates**: WebSocket support for live execution streaming

### Quick Start
1. List available tools: `GET /tools`
2. Create a graph: `POST /graph/create`
3. Run the graph: `POST /graph/run`
4. Check execution state: `GET /graph/state/{run_id}`

### Demo Workflow
A pre-registered Code Review workflow is available with ID: `code-review-demo`
""",
    version=settings.APP_VERSION,
    docs_url="/docs",
    redoc_url="/redoc",
    lifespan=lifespan,
)


# Add CORS middleware.
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests under the CORS spec —
# confirm whether credentials are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Include routers for graph CRUD/run, tool listing, and websocket streaming.
app.include_router(graph.router)
app.include_router(tools.router)
app.include_router(websocket.router)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
# ============================================================
|
| 92 |
+
# Root Endpoints
|
| 93 |
+
# ============================================================
|
| 94 |
+
|
| 95 |
+
@app.get("/", tags=["Root"])
async def root():
    """API root - returns basic info and links."""
    endpoint_map = {
        "graphs": "/graph",
        "tools": "/tools",
        "websocket_run": "/ws/run/{graph_id}",
        "websocket_subscribe": "/ws/subscribe/{run_id}",
    }
    return {
        "name": settings.APP_NAME,
        "version": settings.APP_VERSION,
        "description": "A minimal workflow/graph engine for agent workflows",
        "docs": "/docs",
        "redoc": "/redoc",
        "endpoints": endpoint_map,
        "demo_workflow": "code-review-demo",
    }
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.get("/health", tags=["Root"])
async def health():
    """Health check endpoint reporting basic storage statistics."""
    # Imported locally — presumably to avoid an import cycle at module
    # load time; confirm before hoisting to the top of the file.
    from app.storage.memory import graph_storage, run_storage

    payload = {
        "status": "healthy",
        "version": settings.APP_VERSION,
        "graphs_count": len(graph_storage),
        "runs_count": len(run_storage),
    }
    return payload
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
# ============================================================
|
| 128 |
+
# Error Handlers
|
| 129 |
+
# ============================================================
|
| 130 |
+
|
| 131 |
+
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Global exception handler for unhandled errors."""
    logger.exception(f"Unhandled error: {exc}")
    # Only leak exception details when running in debug mode.
    detail = str(exc) if settings.DEBUG else "An unexpected error occurred"
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "detail": detail,
        },
    )
|
app/storage/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Storage package - In-memory storage for graphs and runs.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.storage.memory import (
|
| 6 |
+
GraphStorage,
|
| 7 |
+
RunStorage,
|
| 8 |
+
graph_storage,
|
| 9 |
+
run_storage,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
__all__ = [
|
| 13 |
+
"GraphStorage",
|
| 14 |
+
"RunStorage",
|
| 15 |
+
"graph_storage",
|
| 16 |
+
"run_storage",
|
| 17 |
+
]
|
app/storage/memory.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
In-Memory Storage for Workflow Engine.
|
| 3 |
+
|
| 4 |
+
Provides thread-safe storage for graphs and execution runs.
|
| 5 |
+
Can be easily replaced with a database implementation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Any, Dict, List, Optional
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
import asyncio
|
| 11 |
+
from dataclasses import dataclass, field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class StoredGraph:
    """A stored graph definition."""
    graph_id: str
    name: str
    definition: Dict[str, Any]
    created_at: datetime = field(default_factory=datetime.now)
    updated_at: datetime = field(default_factory=datetime.now)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a JSON-friendly dict (timestamps as ISO-8601 strings)."""
        payload = {
            "graph_id": self.graph_id,
            "name": self.name,
            "definition": self.definition,
        }
        payload["created_at"] = self.created_at.isoformat()
        payload["updated_at"] = self.updated_at.isoformat()
        return payload
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@dataclass
class StoredRun:
    """A stored execution run."""
    run_id: str
    graph_id: str
    status: str
    initial_state: Dict[str, Any]
    current_state: Dict[str, Any] = field(default_factory=dict)
    final_state: Optional[Dict[str, Any]] = None
    execution_log: List[Dict[str, Any]] = field(default_factory=list)
    current_node: Optional[str] = None
    iteration: int = 0
    started_at: datetime = field(default_factory=datetime.now)
    completed_at: Optional[datetime] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the run, rendering datetimes as ISO-8601 strings."""
        finished = self.completed_at.isoformat() if self.completed_at else None
        return {
            "run_id": self.run_id,
            "graph_id": self.graph_id,
            "status": self.status,
            "initial_state": self.initial_state,
            "current_state": self.current_state,
            "final_state": self.final_state,
            "execution_log": self.execution_log,
            "current_node": self.current_node,
            "iteration": self.iteration,
            "started_at": self.started_at.isoformat(),
            "completed_at": finished,
            "error": self.error,
        }
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class GraphStorage:
    """
    Thread-safe in-memory storage for workflow graphs.

    Stores graph definitions by their ID, allowing creation,
    retrieval, update, and deletion. Every coroutine serializes on a
    single asyncio lock so concurrent tasks cannot interleave mutations.
    """

    def __init__(self):
        self._graphs: Dict[str, StoredGraph] = {}
        self._lock = asyncio.Lock()

    async def save(self, graph_id: str, name: str, definition: Dict[str, Any]) -> StoredGraph:
        """
        Save a graph definition, overwriting any existing entry.

        Args:
            graph_id: Unique graph identifier
            name: Graph name
            definition: Graph definition dict

        Returns:
            The stored graph record
        """
        record = StoredGraph(
            graph_id=graph_id,
            name=name,
            definition=definition,
        )
        async with self._lock:
            self._graphs[graph_id] = record
        return record

    async def get(self, graph_id: str) -> Optional[StoredGraph]:
        """Fetch a graph by ID, or None when absent."""
        async with self._lock:
            return self._graphs.get(graph_id)

    async def update(self, graph_id: str, definition: Dict[str, Any]) -> Optional[StoredGraph]:
        """Replace an existing graph's definition; None when the ID is unknown."""
        async with self._lock:
            record = self._graphs.get(graph_id)
            if record is None:
                return None
            record.definition = definition
            record.updated_at = datetime.now()
            return record

    async def delete(self, graph_id: str) -> bool:
        """Remove a graph; True when an entry was actually deleted."""
        async with self._lock:
            return self._graphs.pop(graph_id, None) is not None

    async def list_all(self) -> List[StoredGraph]:
        """Return all stored graphs."""
        async with self._lock:
            return list(self._graphs.values())

    async def exists(self, graph_id: str) -> bool:
        """Check whether a graph ID is present."""
        async with self._lock:
            return graph_id in self._graphs

    def __len__(self) -> int:
        return len(self._graphs)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class RunStorage:
    """
    Thread-safe in-memory storage for execution runs.

    Tracks the lifecycle of each run (pending -> running -> completed/failed)
    and supports real-time state updates and queries. All mutations of the
    backing dict happen under a single asyncio lock.
    """

    def __init__(self):
        # run_id -> StoredRun; guarded by self._lock
        self._records: Dict[str, StoredRun] = {}
        self._lock = asyncio.Lock()

    async def create(
        self,
        run_id: str,
        graph_id: str,
        initial_state: Dict[str, Any]
    ) -> StoredRun:
        """
        Create a new run in the "pending" status.

        Args:
            run_id: Unique run identifier
            graph_id: Associated graph ID
            initial_state: Initial state data

        Returns:
            The stored run record
        """
        # current_state starts as a shallow copy so later mutations do not
        # leak back into the caller's initial_state dict.
        record = StoredRun(
            run_id=run_id,
            graph_id=graph_id,
            status="pending",
            initial_state=initial_state,
            current_state=initial_state.copy(),
        )
        async with self._lock:
            self._records[run_id] = record
        return record

    async def get(self, run_id: str) -> Optional[StoredRun]:
        """Return the run with the given ID, or None when absent."""
        async with self._lock:
            return self._records.get(run_id)

    async def update_state(
        self,
        run_id: str,
        current_state: Dict[str, Any],
        current_node: Optional[str] = None,
        iteration: Optional[int] = None
    ) -> Optional[StoredRun]:
        """Update a run's current state and mark it as running."""
        async with self._lock:
            record = self._records.get(run_id)
            if record is None:
                return None
            record.current_state = current_state
            record.status = "running"
            if current_node is not None:
                record.current_node = current_node
            if iteration is not None:
                record.iteration = iteration
            return record

    async def add_log_entry(
        self,
        run_id: str,
        entry: Dict[str, Any]
    ) -> Optional[StoredRun]:
        """Append one entry to a run's execution log."""
        async with self._lock:
            record = self._records.get(run_id)
            if record is None:
                return None
            record.execution_log.append(entry)
            return record

    async def complete(
        self,
        run_id: str,
        final_state: Dict[str, Any],
        execution_log: List[Dict[str, Any]]
    ) -> Optional[StoredRun]:
        """Mark a run as completed and record its final state and log."""
        async with self._lock:
            record = self._records.get(run_id)
            if record is None:
                return None
            record.status = "completed"
            record.final_state = final_state
            record.execution_log = execution_log
            record.completed_at = datetime.now()
            return record

    async def fail(
        self,
        run_id: str,
        error: str,
        final_state: Optional[Dict[str, Any]] = None
    ) -> Optional[StoredRun]:
        """Mark a run as failed, storing the error message."""
        async with self._lock:
            record = self._records.get(run_id)
            if record is None:
                return None
            record.status = "failed"
            record.error = error
            record.final_state = final_state
            record.completed_at = datetime.now()
            return record

    async def list_all(self) -> List[StoredRun]:
        """Return a snapshot list of all runs."""
        async with self._lock:
            return list(self._records.values())

    async def list_by_graph(self, graph_id: str) -> List[StoredRun]:
        """Return all runs belonging to one graph."""
        async with self._lock:
            return [
                record for record in self._records.values()
                if record.graph_id == graph_id
            ]

    async def delete(self, run_id: str) -> bool:
        """Delete a run; return True if it existed."""
        async with self._lock:
            return self._records.pop(run_id, None) is not None

    def __len__(self) -> int:
        # Size snapshot; a plain len() does not need the lock.
        return len(self._records)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
# Module-level singleton stores shared by the API layer and executor.
graph_storage = GraphStorage()
run_storage = RunStorage()
|
app/tools/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tools package - Tool registry and built-in tools.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.tools.registry import ToolRegistry, tool_registry, register_tool, get_tool
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"ToolRegistry",
|
| 9 |
+
"tool_registry",
|
| 10 |
+
"register_tool",
|
| 11 |
+
"get_tool",
|
| 12 |
+
]
|
app/tools/builtin.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Built-in Tools for the Code Review Workflow.
|
| 3 |
+
|
| 4 |
+
These tools implement the functionality needed for the sample
|
| 5 |
+
Code Review workflow demonstration.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import re
|
| 9 |
+
import ast
|
| 10 |
+
from typing import Any, Dict, List, Optional
|
| 11 |
+
from app.tools.registry import register_tool
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@register_tool(
    name="extract_functions",
    description="Extract function definitions from Python code"
)
def extract_functions(code: str) -> Dict[str, Any]:
    """
    Extract function names and basic info from Python code.

    Both synchronous and ``async def`` functions are reported (the codebase
    is asyncio-heavy, so async functions must not be skipped). ``ast.walk``
    visits the whole tree, so nested functions and methods are included.

    Args:
        code: Python source code string

    Returns:
        Dict with a 'functions' list (name, lineno, args, has_docstring,
        decorators, line_count), 'function_count', and 'parse_success'.
        On a syntax error: empty list, an 'error' message, and
        parse_success=False.
    """
    try:
        tree = ast.parse(code)
    except SyntaxError as e:
        return {
            "functions": [],
            "error": f"Syntax error in code: {e}",
            "parse_success": False,
        }

    functions: List[Dict[str, Any]] = []
    for node in ast.walk(tree):
        # Fix: also match AsyncFunctionDef — checking only FunctionDef
        # silently ignored every `async def`.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            functions.append({
                "name": node.name,
                "lineno": node.lineno,
                "args": [arg.arg for arg in node.args.args],
                "has_docstring": ast.get_docstring(node) is not None,
                "decorators": [
                    # ast.unparse exists on Python 3.9+; fall back to repr
                    ast.unparse(d) if hasattr(ast, 'unparse') else str(d)
                    for d in node.decorator_list
                ],
                "line_count": (
                    node.end_lineno - node.lineno + 1
                    if getattr(node, 'end_lineno', None)
                    else 0
                ),
            })

    return {
        "functions": functions,
        "function_count": len(functions),
        "parse_success": True,
    }
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@register_tool(
    name="calculate_complexity",
    description="Calculate complexity metrics for code"
)
def calculate_complexity(code: str, functions: Optional[List[Dict]] = None) -> Dict[str, Any]:
    """
    Calculate simple complexity metrics for Python code.

    Metrics:
        - Lines of code (LOC, excluding blanks and pure comments)
        - Cyclomatic complexity (simplified, keyword-based)
        - Max nesting depth (assumes 4-space indentation)
        - Function count
        - A 1-10 complexity score (higher is better)

    Args:
        code: Python source code
        functions: Optional pre-extracted function list (from
            extract_functions); used for function count and long-function
            penalties when provided.

    Returns:
        Dict with complexity metrics.
    """
    lines = code.split('\n')
    loc = len([l for l in lines if l.strip() and not l.strip().startswith('#')])

    # Simplified cyclomatic complexity: +1 per line that contains a
    # decision keyword. Patterns are compiled once and *searched* anywhere
    # in the line — the previous re.match() anchored at the line start, so
    # mid-line keywords like 'and'/'or' were never counted at all.
    complexity_keywords = ['if', 'elif', 'for', 'while', 'and', 'or', 'except', 'with']
    patterns = [re.compile(rf'\b{keyword}\b') for keyword in complexity_keywords]
    complexity = 1  # base complexity of straight-line code
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith('#'):
            continue  # comments add no decision points (matches the LOC rule)
        for pattern in patterns:
            if pattern.search(stripped):
                complexity += 1

    # Max nesting depth, derived from leading whitespace.
    max_depth = 0
    for line in lines:
        if line.strip():
            indent = len(line) - len(line.lstrip())
            depth = indent // 4  # assume 4-space indentation
            max_depth = max(max_depth, depth)

    # Function count: prefer the structured list, fall back to a text scan.
    func_count = len(functions) if functions else code.count('def ')

    # Heuristic 1-10 score; each threshold breach costs points.
    score = 10
    if complexity > 10:
        score -= 2
    if complexity > 20:
        score -= 2
    if max_depth > 4:
        score -= 1
    if max_depth > 6:
        score -= 1
    if loc > 200:
        score -= 1
    if func_count > 10:
        score -= 1
    if functions:
        long_funcs = [f for f in functions if f.get('line_count', 0) > 50]
        score -= len(long_funcs)

    score = max(1, score)  # never drop below 1

    return {
        "lines_of_code": loc,
        "cyclomatic_complexity": complexity,
        "max_nesting_depth": max_depth,
        "function_count": func_count,
        "complexity_score": score,
    }
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@register_tool(
    name="detect_issues",
    description="Detect code quality issues and smells"
)
def detect_issues(
    code: str,
    functions: Optional[List[Dict]] = None,
    complexity_score: Optional[int] = None
) -> Dict[str, Any]:
    """
    Detect common code quality issues.

    Checks for:
        - Missing docstrings (needs the pre-extracted function list)
        - Long functions (> 50 lines)
        - TODO/FIXME/XXX comments
        - Print statements (suggest logging instead)
        - Magic numbers (2+ digit literals outside a small allow-list)

    Args:
        code: Python source code
        functions: Optional pre-extracted functions
        complexity_score: Optional pre-calculated complexity (1-10);
            averaged into the quality score when provided.

    Returns:
        Dict with the issues list, counts by severity, and a 1-10
        quality score.
    """
    findings: List[Dict[str, Any]] = []
    source_lines = code.split('\n')

    # Missing docstrings.
    for func in functions or []:
        if not func.get('has_docstring'):
            findings.append({
                "type": "missing_docstring",
                "severity": "warning",
                "message": f"Function '{func['name']}' lacks a docstring",
                "line": func.get('lineno'),
            })

    # Overly long functions.
    for func in functions or []:
        length = func.get('line_count', 0)
        if length > 50:
            findings.append({
                "type": "long_function",
                "severity": "warning",
                "message": f"Function '{func['name']}' is too long ({length} lines)",
                "line": func.get('lineno'),
            })

    # TODO/FIXME markers.
    for lineno, text in enumerate(source_lines, 1):
        if any(marker in text for marker in ('TODO', 'FIXME', 'XXX')):
            findings.append({
                "type": "todo_comment",
                "severity": "info",
                "message": "Found TODO/FIXME comment",
                "line": lineno,
            })

    # Print statements (anywhere on the line).
    for lineno, text in enumerate(source_lines, 1):
        if 'print(' in text:
            findings.append({
                "type": "print_statement",
                "severity": "info",
                "message": "Print statement found (consider using logging)",
                "line": lineno,
            })

    # Magic numbers: 2+ digit literals not adjacent to quotes/dots.
    number_re = re.compile(r'\b(?<![\'".])\d{2,}\b(?![\'"])')
    acceptable = (0, 1, 2, 100)  # commonly acceptable values
    for lineno, text in enumerate(source_lines, 1):
        if text.strip().startswith('#'):
            continue  # skip comment lines
        for token in number_re.findall(text):
            if int(token) not in acceptable:
                findings.append({
                    "type": "magic_number",
                    "severity": "info",
                    "message": f"Magic number {token} found (consider using a constant)",
                    "line": lineno,
                })
                break  # flag at most one magic number per line

    # Quality score: start at 10 and subtract per issue by severity.
    score = 10
    for finding in findings:
        score -= {"error": 2, "warning": 1}.get(finding['severity'], 0.5)

    # Blend in the complexity score when available.
    if complexity_score:
        score = (score + complexity_score) / 2

    score = max(1, min(10, score))

    return {
        "issues": findings,
        "issue_count": len(findings),
        "quality_score": round(score, 1),
        "issues_by_severity": {
            level: len([f for f in findings if f['severity'] == level])
            for level in ("error", "warning", "info")
        },
    }
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
@register_tool(
    name="suggest_improvements",
    description="Generate improvement suggestions based on detected issues"
)
def suggest_improvements(
    issues: List[Dict],
    functions: Optional[List[Dict]] = None,
    quality_score: Optional[float] = None
) -> Dict[str, Any]:
    """
    Generate actionable improvement suggestions from detected issues.

    Args:
        issues: List of detected issues (from detect_issues)
        functions: Optional function info (currently unused directly)
        quality_score: Current quality score (1-10)

    Returns:
        Dict with suggestions sorted by priority, the current and a
        projected quality score, and the set of affected categories.
    """
    # Bucket issues by type so each suggestion can report a count.
    grouped: Dict[str, List[Dict]] = {}
    for item in issues:
        grouped.setdefault(item.get('type', 'unknown'), []).append(item)

    # (issue_type, priority, category, suggestion template, details),
    # listed in the order suggestions should be generated.
    templates = [
        ("missing_docstring", "high", "documentation",
         "Add docstrings to {n} function(s)",
         "Good docstrings improve code maintainability and enable automatic documentation generation."),
        ("long_function", "high", "refactoring",
         "Refactor {n} long function(s) into smaller units",
         "Functions over 50 lines are harder to understand and test. Consider extracting helper functions."),
        ("print_statement", "medium", "logging",
         "Replace {n} print statement(s) with proper logging",
         "Use the logging module for better control over log levels and output."),
        ("magic_number", "medium", "readability",
         "Extract {n} magic number(s) into named constants",
         "Named constants improve readability and make the code easier to modify."),
        ("todo_comment", "low", "maintenance",
         "Address {n} TODO/FIXME comment(s)",
         "Consider creating issues or tasks to track these items."),
    ]

    advice: List[Dict[str, Any]] = []
    for issue_type, priority, category, template, details in templates:
        if issue_type not in grouped:
            continue
        entry: Dict[str, Any] = {
            "priority": priority,
            "category": category,
            "suggestion": template.format(n=len(grouped[issue_type])),
            "details": details,
        }
        if issue_type == "missing_docstring":
            # Pull the function name back out of the quoted message text.
            entry["affected_functions"] = [
                msg.split("'")[1]
                for msg in (item.get('message', '') for item in grouped[issue_type])
                if "'" in msg
            ]
        advice.append(entry)

    # General advice when overall quality is poor.
    if quality_score and quality_score < 5:
        advice.append({
            "priority": "high",
            "category": "general",
            "suggestion": "Consider a comprehensive code review",
            "details": "The overall quality score is low. A thorough review may reveal structural improvements.",
        })

    # Stable sort keeps generation order within each priority tier.
    rank = {"high": 0, "medium": 1, "low": 2}
    advice.sort(key=lambda entry: rank.get(entry['priority'], 3))

    # Each addressed suggestion is assumed to be worth half a point.
    projected = min(10, (quality_score or 5) + len(advice) * 0.5)

    return {
        "suggestions": advice,
        "suggestion_count": len(advice),
        "current_quality_score": quality_score,
        "potential_quality_score": round(projected, 1),
        "categories": list({entry['category'] for entry in advice}),
    }
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
@register_tool(
    name="quality_check",
    description="Check if quality meets the threshold"
)
def quality_check(quality_score: float, quality_threshold: float = 7.0) -> str:
    """
    Routing predicate: compare a quality score against a threshold.

    Args:
        quality_score: Current quality score (1-10)
        quality_threshold: Minimum acceptable score

    Returns:
        "pass" if quality meets the threshold, "fail" otherwise
    """
    return "pass" if quality_score >= quality_threshold else "fail"
|
app/tools/registry.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tool Registry for Workflow Engine.
|
| 3 |
+
|
| 4 |
+
The tool registry maintains a collection of callable tools that
|
| 5 |
+
workflow nodes can use. Tools are simple Python functions that
|
| 6 |
+
perform specific operations.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
import functools
|
| 12 |
+
import inspect
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class Tool:
    """
    A registered tool: a named callable plus descriptive metadata.

    Attributes:
        name: Unique identifier for the tool
        func: The callable function
        description: Human-readable description
        parameters: Parameter name -> description/type mapping
    """
    name: str
    func: Callable
    description: str = ""
    parameters: Dict[str, str] = field(default_factory=dict)

    def __call__(self, *args, **kwargs) -> Any:
        """Invoke the underlying function, forwarding all arguments."""
        return self.func(*args, **kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize tool metadata (the callable itself is omitted)."""
        keys = ("name", "description", "parameters")
        return {key: getattr(self, key) for key in keys}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ToolRegistry:
    """
    Registry for workflow tools.

    Tools are simple Python functions that nodes can call to perform
    specific operations. The registry allows dynamic registration
    and lookup of tools.

    Usage:
        registry = ToolRegistry()

        @registry.register("my_tool")
        def my_tool(data: str) -> dict:
            return {"result": data.upper()}

        # Later
        tool = registry.get("my_tool")
        result = tool("hello")
    """

    def __init__(self):
        # name -> Tool
        self._tools: Dict[str, Tool] = {}

    @staticmethod
    def _describe_parameters(func: Callable) -> Dict[str, str]:
        """Derive parameter descriptions from a function's signature.

        Uses the annotation string when present, "Any" otherwise;
        ``self``/``cls`` are skipped.
        """
        described: Dict[str, str] = {}
        sig = inspect.signature(func)
        for param_name, param in sig.parameters.items():
            if param_name in ("self", "cls"):
                continue
            if param.annotation is inspect.Parameter.empty:
                described[param_name] = "Any"
            else:
                described[param_name] = str(param.annotation)
        return described

    def register(
        self,
        name: Optional[str] = None,
        description: str = "",
        parameters: Optional[Dict[str, str]] = None
    ) -> Callable:
        """
        Decorator to register a function as a tool.

        Args:
            name: Tool name (defaults to function name)
            description: Tool description (defaults to docstring)
            parameters: Parameter descriptions (defaults to the
                signature-derived descriptions)

        Returns:
            Decorator that registers the function and returns it
            unchanged. (Previously a needless functools.wraps wrapper
            was returned, adding call overhead and obscuring the
            function's identity; registration decorators should return
            the original function.)
        """
        def decorator(func: Callable) -> Callable:
            tool_name = name or func.__name__
            tool_desc = description or func.__doc__ or ""
            params = parameters or self._describe_parameters(func)

            self._tools[tool_name] = Tool(
                name=tool_name,
                func=func,
                description=tool_desc.strip(),
                parameters=params,
            )
            logger.debug(f"Registered tool: {tool_name}")

            return func

        return decorator

    def add(
        self,
        func: Callable,
        name: Optional[str] = None,
        description: str = "",
        parameters: Optional[Dict[str, str]] = None
    ) -> None:
        """
        Directly add a function as a tool (non-decorator version).

        Args:
            func: The function to register
            name: Tool name (defaults to function name)
            description: Tool description (defaults to docstring)
            parameters: Parameter descriptions
        """
        tool_name = name or func.__name__
        tool_desc = description or func.__doc__ or ""

        self._tools[tool_name] = Tool(
            name=tool_name,
            func=func,
            description=tool_desc.strip(),
            parameters=parameters or {},
        )
        logger.debug(f"Added tool: {tool_name}")

    def get(self, name: str) -> Optional[Tool]:
        """Get a tool by name, or None when not registered."""
        return self._tools.get(name)

    def call(self, name: str, *args, **kwargs) -> Any:
        """
        Call a tool by name.

        Args:
            name: Tool name
            *args: Positional arguments
            **kwargs: Keyword arguments

        Returns:
            Tool result

        Raises:
            KeyError: If the tool is not registered
        """
        tool = self.get(name)
        if not tool:
            raise KeyError(f"Tool '{name}' not found in registry")
        return tool(*args, **kwargs)

    def remove(self, name: str) -> bool:
        """Remove a tool from the registry; return True if it existed."""
        if name in self._tools:
            del self._tools[name]
            return True
        return False

    def list_tools(self) -> List[Dict[str, Any]]:
        """List all registered tools with their metadata."""
        return [tool.to_dict() for tool in self._tools.values()]

    def has(self, name: str) -> bool:
        """Check if a tool is registered."""
        return name in self._tools

    def __contains__(self, name: str) -> bool:
        return self.has(name)

    def __len__(self) -> int:
        return len(self._tools)

    def __iter__(self):
        return iter(self._tools.values())
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
# Module-level singleton registry shared across the application.
tool_registry = ToolRegistry()
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def register_tool(
    name: Optional[str] = None,
    description: str = "",
    parameters: Optional[Dict[str, str]] = None
) -> Callable:
    """
    Convenience decorator that registers a tool in the global registry.

    Usage:
        @register_tool("my_tool", description="Does something cool")
        def my_tool(data: str) -> dict:
            return {"result": data}
    """
    return tool_registry.register(name=name, description=description, parameters=parameters)
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def get_tool(name: str) -> Optional[Tool]:
    """Look up *name* in the global registry; None when not registered."""
    return tool_registry.get(name)
|
app/workflows/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Workflows package - Sample workflow implementations.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from app.workflows.code_review import create_code_review_workflow, register_code_review_workflow
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"create_code_review_workflow",
|
| 9 |
+
"register_code_review_workflow",
|
| 10 |
+
]
|
app/workflows/code_review.py
ADDED
|
@@ -0,0 +1,338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Code Review Workflow Implementation.
|
| 3 |
+
|
| 4 |
+
This is the sample workflow demonstrating the workflow engine capabilities:
|
| 5 |
+
1. Extract functions from code
|
| 6 |
+
2. Check complexity
|
| 7 |
+
3. Detect issues
|
| 8 |
+
4. Suggest improvements
|
| 9 |
+
5. Loop until quality_score >= threshold
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import Any, Dict
|
| 13 |
+
import logging
|
| 14 |
+
|
| 15 |
+
from app.engine.graph import Graph, END
|
| 16 |
+
from app.engine.node import node, NodeType
|
| 17 |
+
from app.tools.builtin import (
|
| 18 |
+
extract_functions,
|
| 19 |
+
calculate_complexity,
|
| 20 |
+
detect_issues,
|
| 21 |
+
suggest_improvements,
|
| 22 |
+
quality_check,
|
| 23 |
+
)
|
| 24 |
+
from app.tools.registry import tool_registry
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ============================================================
|
| 31 |
+
# Node Handlers (using the @node decorator)
|
| 32 |
+
# ============================================================
|
| 33 |
+
|
| 34 |
+
@node(name="extract_node", description="Extract functions from the input code")
def extract_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Populate the state with the functions found in ``state['code']``.

    Reads:
        code (str): Python source to analyze (defaults to "").

    Writes (merged into *state* in place):
        functions (List[dict]): extracted function information.
        function_count (int): number of functions found.

    Returns:
        The same state dict, mutated.
    """
    analysis = extract_functions(state.get("code", ""))
    state.update(analysis)
    logger.info(f"Extracted {analysis.get('function_count', 0)} functions")
    return state
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@node(name="complexity_node", description="Calculate code complexity metrics")
def complexity_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Merge complexity metrics for ``state['code']`` into the state.

    Reads:
        code (str): source code (defaults to "").
        functions (List[dict]): previously extracted functions (defaults to []).

    Writes (merged into *state* in place):
        lines_of_code (int), cyclomatic_complexity (int),
        complexity_score (int, 1-10).

    Returns:
        The same state dict, mutated.
    """
    metrics = calculate_complexity(
        state.get("code", ""),
        state.get("functions", []),
    )
    state.update(metrics)
    logger.info(f"Complexity score: {metrics.get('complexity_score', 0)}")
    return state
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@node(name="issues_node", description="Detect code quality issues")
def issues_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Run issue detection and merge the findings into the state.

    Reads:
        code (str), functions (List[dict]),
        complexity_score (int | None): None when the complexity step has
        not run — detect_issues is passed the value unchanged.

    Writes (merged into *state* in place):
        issues (List[dict]), issue_count (int), quality_score (float, 1-10).

    Returns:
        The same state dict, mutated.
    """
    findings = detect_issues(
        state.get("code", ""),
        state.get("functions", []),
        state.get("complexity_score"),
    )
    state.update(findings)
    logger.info(
        f"Found {findings.get('issue_count', 0)} issues, "
        f"quality score: {findings.get('quality_score', 0)}"
    )
    return state
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@node(name="improve_node", description="Generate improvement suggestions")
def improve_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Generate suggestions for the detected issues and nudge the score up.

    Reads:
        issues (List[dict]), functions (List[dict]),
        quality_score (float, default 5.0).

    Writes (merged into *state* in place):
        suggestions (List[dict]), suggestion_count (int),
        potential_quality_score (float), plus an updated quality_score:
        +0.2 per suggestion, capped at +0.5 per pass and 10 overall.

    Returns:
        The same state dict, mutated.
    """
    current_score = state.get("quality_score", 5.0)
    outcome = suggest_improvements(
        state.get("issues", []),
        state.get("functions", []),
        current_score,
    )
    state.update(outcome)

    # No real code rewrite happens here; the simulated score bump is what
    # lets the review loop converge on the threshold.
    gain = min(0.5, outcome.get("suggestion_count", 0) * 0.2)
    state["quality_score"] = min(10, current_score + gain)

    logger.info(
        f"Generated {outcome.get('suggestion_count', 0)} suggestions, "
        f"quality improved to {state['quality_score']}"
    )
    return state
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# Register node handlers as tools so they can be retrieved when rebuilding from storage
|
| 139 |
+
def _wrapper_handler(handler_func):
    """Wrap *handler_func* so it can be stored in the tool registry.

    The wrapper simply forwards the state dict to the handler. It uses
    functools.wraps instead of manually copying ``__name__``/``__doc__``,
    so ``__module__``, ``__qualname__``, ``__wrapped__`` and annotations
    are preserved as well — the manual copy silently dropped those.

    Args:
        handler_func: A node handler taking and returning a state dict.

    Returns:
        A metadata-preserving wrapper around *handler_func*.
    """
    import functools

    @functools.wraps(handler_func)
    def wrapper(state: Dict[str, Any]) -> Dict[str, Any]:
        return handler_func(state)

    return wrapper
|
| 146 |
+
|
| 147 |
+
# Store each wrapped node handler in the global tool registry under its node
# name, so graphs rebuilt from storage can resolve handlers by string id.
tool_registry.add(_wrapper_handler(extract_node), name="extract_node", description="Extract functions from code")
tool_registry.add(_wrapper_handler(complexity_node), name="complexity_node", description="Calculate complexity")
tool_registry.add(_wrapper_handler(issues_node), name="issues_node", description="Detect quality issues")
tool_registry.add(_wrapper_handler(improve_node), name="improve_node", description="Suggest improvements")
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ============================================================
|
| 154 |
+
# Condition Functions
|
| 155 |
+
# ============================================================
|
| 156 |
+
|
| 157 |
+
def quality_meets_threshold(state: Dict[str, Any]) -> str:
    """Routing condition for the quality gate.

    Compares ``state['quality_score']`` (default 0) against
    ``state['quality_threshold']`` (default 7.0).

    Returns:
        "pass" when the score meets or exceeds the threshold,
        "fail" when more improvement is needed.
    """
    score = state.get("quality_score", 0)
    threshold = state.get("quality_threshold", 7.0)

    if score < threshold:
        logger.info(f"Quality {score} below threshold {threshold}")
        return "fail"

    logger.info(f"Quality {score} meets threshold {threshold}")
    return "pass"
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def always_loop(state: Dict[str, Any]) -> str:
    """Unconditional edge label: after an improvement pass, always loop
    back to the issues check. The state is intentionally ignored."""
    return "continue"
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# ============================================================
|
| 182 |
+
# Workflow Factory
|
| 183 |
+
# ============================================================
|
| 184 |
+
|
| 185 |
+
def create_code_review_workflow(
    max_iterations: int = 5,
    quality_threshold: float = 7.0
) -> Graph:
    """Assemble and return the Code Review workflow graph.

    Topology:
        extract -> complexity -> issues
        issues --"pass"--> END
        issues --"fail"--> improve --"continue"--> issues

    Args:
        max_iterations: Hard cap on improvement loops.
        quality_threshold: Score at which the review is accepted. Note it
            only appears in the graph description here; the runtime gate
            reads ``quality_threshold`` from the execution state.

    Returns:
        A fully wired Graph with "extract" as the entry point.
    """
    review_graph = Graph(
        name="Code Review Workflow",
        description=(
            "Analyzes Python code for quality issues and suggests improvements. "
            f"Loops until quality score >= {quality_threshold} or max {max_iterations} iterations."
        ),
        max_iterations=max_iterations,
    )

    # Linear analysis pipeline.
    review_graph.add_node("extract", handler=extract_node, description="Extract functions from code")
    review_graph.add_node("complexity", handler=complexity_node, description="Calculate complexity")
    review_graph.add_node("issues", handler=issues_node, description="Detect quality issues")
    review_graph.add_node("improve", handler=improve_node, description="Suggest improvements")

    review_graph.add_edge("extract", "complexity")
    review_graph.add_edge("complexity", "issues")

    # Quality gate: finish on "pass", keep improving on "fail".
    review_graph.add_conditional_edge(
        "issues",
        quality_meets_threshold,
        {"pass": END, "fail": "improve"}
    )

    # Every improvement pass feeds back into the issues check.
    review_graph.add_conditional_edge(
        "improve",
        always_loop,
        {"continue": "issues"}
    )

    review_graph.set_entry_point("extract")

    return review_graph
|
| 243 |
+
|
| 244 |
+
|
| 245 |
+
async def register_code_review_workflow():
    """Build the demo Code Review workflow and persist it in graph storage.

    Saves the graph under the fixed id ``code-review-demo`` so the API can
    serve it immediately, without an explicit create step.

    Returns:
        The Graph instance that was registered.
    """
    from app.storage.memory import graph_storage

    demo_graph = create_code_review_workflow()

    await graph_storage.save(
        graph_id="code-review-demo",
        name="Code Review Demo",
        definition=demo_graph.to_dict(),
    )

    logger.info("Registered Code Review workflow with ID: code-review-demo")
    return demo_graph
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# ============================================================
|
| 267 |
+
# Example Usage
|
| 268 |
+
# ============================================================
|
| 269 |
+
|
| 270 |
+
async def run_code_review_demo():
    """
    Demo function showing how to run the code review workflow.

    Builds the workflow, executes it against a small hard-coded sample, and
    prints a summary of the execution result to stdout.

    Usage:
        import asyncio
        from app.workflows.code_review import run_code_review_demo
        asyncio.run(run_code_review_demo())
    """
    from app.engine.executor import execute_graph

    # Sample code to review — deliberately imperfect (deep nesting,
    # index-based loop, debug print) so the workflow finds issues.
    sample_code = '''
def calculate_total(items):
    total = 0
    for item in items:
        if item.price > 0:
            if item.quantity > 0:
                if item.discount:
                    total += item.price * item.quantity * (1 - item.discount)
                else:
                    total += item.price * item.quantity
    return total

def process_data(data):
    result = []
    for i in range(len(data)):
        if data[i] > 100:
            result.append(data[i] * 2)
        else:
            result.append(data[i])
    print(result)
    return result


def helper():
    x = 42
    return x * 1000
'''

    # Create workflow with a lowered bar so the demo converges quickly.
    workflow = create_code_review_workflow(max_iterations=3, quality_threshold=6.0)

    # Initial state — the threshold must also live in the state because the
    # routing condition reads it from there.
    initial_state = {
        "code": sample_code,
        "quality_threshold": 6.0,
    }

    # Execute
    print("Starting Code Review...")
    result = await execute_graph(workflow, initial_state)

    # Print results (result fields per app.engine.executor's result object).
    print(f"\nExecution Status: {result.status.value}")
    print(f"Total Duration: {result.total_duration_ms:.2f}ms")
    print(f"Iterations: {result.iterations}")
    print(f"\nFinal Quality Score: {result.final_state.get('quality_score', 'N/A')}")
    print(f"Issues Found: {result.final_state.get('issue_count', 'N/A')}")
    print(f"\nSuggestions:")
    for suggestion in result.final_state.get("suggestions", []):
        print(f"  - [{suggestion['priority']}] {suggestion['suggestion']}")

    return result
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
if __name__ == "__main__":
|
| 337 |
+
import asyncio
|
| 338 |
+
asyncio.run(run_code_review_demo())
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
workflow-engine:
|
| 3 |
+
build:
|
| 4 |
+
context: .
|
| 5 |
+
dockerfile: Dockerfile
|
| 6 |
+
container_name: workflow-engine
|
| 7 |
+
ports:
|
| 8 |
+
- "8000:8000"
|
| 9 |
+
environment:
|
| 10 |
+
- APP_NAME=FlowGraph
|
| 11 |
+
- APP_VERSION=1.0.0
|
| 12 |
+
- DEBUG=true
|
| 13 |
+
- HOST=0.0.0.0
|
| 14 |
+
- PORT=8000
|
| 15 |
+
- MAX_ITERATIONS=100
|
| 16 |
+
- LOG_LEVEL=INFO
|
| 17 |
+
volumes:
|
| 18 |
+
# Mount for development (hot reload)
|
| 19 |
+
- .:/app
|
| 20 |
+
command: uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload
|
| 21 |
+
healthcheck:
|
| 22 |
+
test: [ "CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" ]
|
| 23 |
+
interval: 30s
|
| 24 |
+
timeout: 10s
|
| 25 |
+
retries: 3
|
| 26 |
+
start_period: 10s
|
| 27 |
+
restart: unless-stopped
|
| 28 |
+
|
| 29 |
+
# Optional: Run tests in a separate container
|
| 30 |
+
tests:
|
| 31 |
+
build:
|
| 32 |
+
context: .
|
| 33 |
+
dockerfile: Dockerfile
|
| 34 |
+
container_name: workflow-engine-tests
|
| 35 |
+
command: pytest tests/ -v
|
| 36 |
+
profiles:
|
| 37 |
+
- test
|
| 38 |
+
depends_on:
|
| 39 |
+
- workflow-engine
|
pytest.ini
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
| 3 |
+
testpaths = tests
|
| 4 |
+
python_files = test_*.py
|
| 5 |
+
python_functions = test_*
|
| 6 |
+
addopts = -v --tb=short
|
requirements.txt
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core
|
| 2 |
+
fastapi>=0.104.0
|
| 3 |
+
uvicorn[standard]>=0.24.0
|
| 4 |
+
pydantic>=2.5.0
|
| 5 |
+
pydantic-settings>=2.1.0
|
| 6 |
+
|
| 7 |
+
# Async support
|
| 8 |
+
asyncio-throttle>=1.0.2
|
| 9 |
+
|
| 10 |
+
# Testing
|
| 11 |
+
pytest>=7.4.0
|
| 12 |
+
pytest-asyncio>=0.21.0
|
| 13 |
+
httpx>=0.25.0
|
| 14 |
+
|
| 15 |
+
# Optional - for better logging
|
| 16 |
+
rich>=13.7.0
|
run.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Simple run script for the Workflow Engine.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python run.py
|
| 7 |
+
|
| 8 |
+
Or with custom settings:
|
| 9 |
+
HOST=127.0.0.1 PORT=8080 python run.py
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import uvicorn
|
| 13 |
+
import os
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main():
    """Run the FastAPI application.

    Reads HOST, PORT and RELOAD from the environment (defaults: 0.0.0.0,
    8000, reload enabled), prints a startup banner, then starts uvicorn
    against the "app.main:app" import path.
    """
    host = os.getenv("HOST", "0.0.0.0")
    # May raise ValueError if PORT is set to a non-numeric value.
    port = int(os.getenv("PORT", "8000"))
    # Any value other than (case-insensitive) "true" disables auto-reload.
    reload = os.getenv("RELOAD", "true").lower() == "true"

    # NOTE(review): the box borders assume default-length host/port values;
    # longer values will misalign the right edge (cosmetic only).
    print(f"""
╔═══════════════════════════════════════════════════════════════╗
║                      FlowGraph 🔄                             ║
║                                                               ║
║        A lightweight workflow orchestration engine            ║
╠═══════════════════════════════════════════════════════════════╣
║  Server:    http://{host}:{port}                              ║
║  API Docs:  http://{host}:{port}/docs                         ║
║  ReDoc:     http://{host}:{port}/redoc                        ║
╠═══════════════════════════════════════════════════════════════╣
║  Demo workflow ID: code-review-demo                           ║
╚═══════════════════════════════════════════════════════════════╝
""")

    uvicorn.run(
        "app.main:app",
        host=host,
        port=port,
        reload=reload,
        log_level="info",
    )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
main()
|
tests/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests package.
|
| 3 |
+
"""
|
tests/test_api.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the FastAPI endpoints.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
from fastapi.testclient import TestClient
|
| 7 |
+
from httpx import AsyncClient, ASGITransport
|
| 8 |
+
|
| 9 |
+
from app.main import app
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# ============================================================
|
| 13 |
+
# Sync Test Client (for simple tests)
|
| 14 |
+
# ============================================================
|
| 15 |
+
|
| 16 |
+
client = TestClient(app)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestRootEndpoints:
    """Tests for root endpoints."""

    def test_root(self):
        """The index route reports service metadata."""
        resp = client.get("/")
        assert resp.status_code == 200
        payload = resp.json()
        for key in ("name", "version", "endpoints"):
            assert key in payload

    def test_health(self):
        """The health probe answers healthy."""
        resp = client.get("/health")
        assert resp.status_code == 200
        assert resp.json()["status"] == "healthy"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestToolsEndpoints:
    """Tests for tools endpoints."""

    def test_list_tools(self):
        """Listing is non-empty and includes the built-in analysis tools."""
        resp = client.get("/tools/")
        assert resp.status_code == 200
        payload = resp.json()
        assert "tools" in payload
        assert "total" in payload
        assert payload["total"] > 0
        names = {entry["name"] for entry in payload["tools"]}
        assert "extract_functions" in names
        assert "calculate_complexity" in names

    def test_get_tool(self):
        """A known tool is returned with its metadata."""
        resp = client.get("/tools/extract_functions")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["name"] == "extract_functions"
        assert "description" in payload

    def test_get_nonexistent_tool(self):
        """An unknown tool name yields 404."""
        assert client.get("/tools/nonexistent_tool").status_code == 404
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class TestGraphEndpoints:
    """Tests for graph endpoints."""

    def test_list_graphs(self):
        """Listing graphs returns the collection envelope."""
        resp = client.get("/graph/")
        assert resp.status_code == 200
        payload = resp.json()
        assert "graphs" in payload
        assert "total" in payload

    def test_get_demo_workflow(self):
        """The pre-registered demo workflow is retrievable by id."""
        resp = client.get("/graph/code-review-demo")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["graph_id"] == "code-review-demo"
        assert payload["name"] == "Code Review Demo"
        assert "mermaid_diagram" in payload

    def test_create_graph(self):
        """Creating a two-node graph succeeds with 201."""
        request_body = {
            "name": "test_workflow",
            "description": "A test workflow",
            "nodes": [
                {"name": "start", "handler": "extract_functions"},
                {"name": "end", "handler": "calculate_complexity"},
            ],
            "edges": {"start": "end"},
            "entry_point": "start",
        }
        resp = client.post("/graph/create", json=request_body)
        assert resp.status_code == 201
        payload = resp.json()
        assert "graph_id" in payload
        assert payload["name"] == "test_workflow"
        assert payload["node_count"] == 2

    def test_create_graph_invalid_handler(self):
        """Referencing an unregistered handler yields 404."""
        request_body = {
            "name": "invalid_workflow",
            "nodes": [
                {"name": "bad", "handler": "nonexistent_handler"},
            ],
            "edges": {},
        }
        assert client.post("/graph/create", json=request_body).status_code == 404
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
# ============================================================
|
| 134 |
+
# Async Tests (for async endpoints)
|
| 135 |
+
# ============================================================
|
| 136 |
+
|
| 137 |
+
@pytest.fixture
def anyio_backend():
    # Pin anyio-driven async tests to the asyncio backend.
    return "asyncio"
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@pytest.mark.asyncio
async def test_run_demo_workflow():
    """Synchronous execution of the demo workflow returns a full result."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        payload = {
            "graph_id": "code-review-demo",
            "initial_state": {
                "code": "def hello():\n print('world')",
                "quality_threshold": 5.0
            },
            "async_execution": False
        }

        resp = await http.post("/graph/run", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "run_id" in body
        assert body["status"] in ("completed", "failed")
        assert "execution_log" in body
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@pytest.mark.asyncio
async def test_async_execution():
    """Async mode returns a pending run whose state is immediately queryable."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        payload = {
            "graph_id": "code-review-demo",
            "initial_state": {
                "code": "def test(): pass",
                "quality_threshold": 5.0
            },
            "async_execution": True
        }

        resp = await http.post("/graph/run", json=payload)
        assert resp.status_code == 200

        body = resp.json()
        assert "run_id" in body
        assert body["status"] == "pending"

        # The run state endpoint should already know about this run.
        state_resp = await http.get(f"/graph/state/{body['run_id']}")
        assert state_resp.status_code == 200
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@pytest.mark.asyncio
async def test_run_nonexistent_graph():
    """Submitting a run for an unknown graph id yields 404."""
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as http:
        resp = await http.post(
            "/graph/run",
            json={"graph_id": "nonexistent-graph", "initial_state": {}},
        )
        assert resp.status_code == 404
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
if __name__ == "__main__":
|
| 207 |
+
pytest.main([__file__, "-v"])
|
tests/test_engine.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Tests for the Workflow Engine core components.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
import asyncio
|
| 7 |
+
from typing import Dict, Any
|
| 8 |
+
|
| 9 |
+
from app.engine.state import WorkflowState, StateManager
|
| 10 |
+
from app.engine.node import Node, NodeType, node, create_node_from_function
|
| 11 |
+
from app.engine.graph import Graph, END
|
| 12 |
+
from app.engine.executor import Executor, ExecutionStatus, execute_graph
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ============================================================
|
| 16 |
+
# State Tests
|
| 17 |
+
# ============================================================
|
| 18 |
+
|
| 19 |
+
class TestWorkflowState:
    """Tests for WorkflowState."""

    def test_create_empty_state(self):
        """A fresh state is empty and unvisited."""
        empty = WorkflowState()
        assert empty.data == {}
        assert empty.iteration == 0
        assert empty.visited_nodes == []

    def test_create_state_with_data(self):
        """Initial data is readable via get(), including default fallback."""
        seeded = WorkflowState(data={"key": "value"})
        assert seeded.get("key") == "value"
        assert seeded.get("missing") is None
        assert seeded.get("missing", "default") == "default"

    def test_state_immutability(self):
        """set() returns a new instance and leaves the original untouched."""
        before = WorkflowState(data={"a": 1})
        after = before.set("b", 2)

        assert before.get("b") is None
        assert after.get("b") == 2
        assert before is not after

    def test_state_update_multiple(self):
        """update() merges several keys at once while keeping existing ones."""
        merged = WorkflowState(data={"a": 1}).update({"b": 2, "c": 3})

        assert merged.get("a") == 1
        assert merged.get("b") == 2
        assert merged.get("c") == 3

    def test_state_mark_visited(self):
        """mark_visited() records history and tracks the current node."""
        tracked = WorkflowState().mark_visited("node1").mark_visited("node2")

        assert "node1" in tracked.visited_nodes
        assert "node2" in tracked.visited_nodes
        assert tracked.current_node == "node2"

    def test_state_to_from_dict(self):
        """State round-trips through to_dict()/from_dict()."""
        original = WorkflowState(data={"test": 123})
        snapshot = original.to_dict()

        assert "data" in snapshot
        assert snapshot["data"]["test"] == 123

        assert WorkflowState.from_dict(snapshot).get("test") == 123
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TestStateManager:
    """Tests for StateManager."""

    def test_initialize(self):
        """Test state manager initialization."""
        mgr = StateManager()
        mgr.initialize({"input": "test"})

        current = mgr.current_state
        assert current is not None
        assert current.get("input") == "test"
        assert current.started_at is not None

    def test_update_and_history(self):
        """Test state updates create history."""
        mgr = StateManager()
        initial = mgr.initialize({"count": 0})

        mgr.update(initial.set("count", 1), "node1")

        assert len(mgr.history) == 1
        assert mgr.history[0].node_name == "node1"
        assert mgr.current_state.get("count") == 1
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ============================================================
|
| 102 |
+
# Node Tests
|
| 103 |
+
# ============================================================
|
| 104 |
+
|
| 105 |
+
class TestNode:
    """Tests for Node class."""

    def test_create_node(self):
        """Test creating a node."""
        def passthrough(state):
            return state

        built = Node(name="test_node", handler=passthrough)

        assert built.name == "test_node"
        assert built.handler == passthrough
        assert built.node_type == NodeType.STANDARD

    def test_node_validation(self):
        """Test node validation."""
        with pytest.raises(ValueError, match="name cannot be empty"):
            Node(name="", handler=lambda x: x)

        # A non-callable handler must be rejected as well.
        with pytest.raises(ValueError, match="must be callable"):
            Node(name="test", handler="not a function")

    @pytest.mark.asyncio
    async def test_sync_node_execution(self):
        """Test executing a sync node."""
        def mark_processed(state):
            state["processed"] = True
            return state

        wrapped = Node(name="test", handler=mark_processed)
        outcome = await wrapped.execute({"input": "data"})

        assert outcome["input"] == "data"
        assert outcome["processed"] is True

    @pytest.mark.asyncio
    async def test_async_node_execution(self):
        """Test executing an async node."""
        async def mark_async(state):
            await asyncio.sleep(0.01)
            state["async_processed"] = True
            return state

        wrapped = Node(name="async_test", handler=mark_async)
        assert wrapped.is_async is True

        outcome = await wrapped.execute({"input": "data"})
        assert outcome["async_processed"] is True

    def test_node_decorator(self):
        """Test the @node decorator."""
        @node(name="decorated_node", description="A test node")
        def my_handler(state):
            return state

        assert hasattr(my_handler, "_node_metadata")
        assert my_handler._node_metadata["name"] == "decorated_node"
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ============================================================
|
| 165 |
+
# Graph Tests
|
| 166 |
+
# ============================================================
|
| 167 |
+
|
| 168 |
+
class TestGraph:
    """Tests for Graph class."""

    def test_create_graph(self):
        """Test creating a graph."""
        g = Graph(name="Test Graph")
        assert g.name == "Test Graph"
        assert len(g.nodes) == 0

    def test_add_nodes(self):
        """Test adding nodes to a graph."""
        g = Graph()
        for label in ("node1", "node2"):
            g.add_node(label, handler=lambda s: s)

        assert "node1" in g.nodes
        assert "node2" in g.nodes
        # The first node registered becomes the entry point.
        assert g.entry_point == "node1"

    def test_add_edges(self):
        """Test adding edges."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_node("b", handler=lambda s: s)
        g.add_edge("a", "b")

        assert g.edges["a"] == "b"

    def test_add_edge_to_end(self):
        """Test adding edge to END."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_edge("a", END)

        assert g.edges["a"] == END

    def test_invalid_edge(self):
        """Test adding invalid edges raises error."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)

        with pytest.raises(ValueError, match="not found"):
            g.add_edge("a", "nonexistent")

    def test_conditional_edge(self):
        """Test conditional edges."""
        g = Graph()
        for label in ("check", "yes", "no"):
            g.add_node(label, handler=lambda s: s)

        def pick_branch(state):
            return "yes" if state.get("value") else "no"

        g.add_conditional_edge("check", pick_branch, {"yes": "yes", "no": "no"})

        # Routing follows the condition's return value.
        assert g.get_next_node("check", {"value": True}) == "yes"
        assert g.get_next_node("check", {"value": False}) == "no"

    def test_graph_validation(self):
        """Test graph validation."""
        g = Graph()

        # An empty graph must produce at least one validation error.
        assert len(g.validate()) > 0

        g.add_node("start", handler=lambda s: s)
        g.add_edge("start", END)

        # A single node wired to END is a valid graph.
        assert len(g.validate()) == 0

    def test_mermaid_generation(self):
        """Test Mermaid diagram generation."""
        g = Graph()
        g.add_node("a", handler=lambda s: s)
        g.add_node("b", handler=lambda s: s)
        g.add_edge("a", "b")
        g.add_edge("b", END)

        diagram = g.to_mermaid()

        for fragment in ("graph TD", "a", "b"):
            assert fragment in diagram
| 256 |
+
|
| 257 |
+
|
| 258 |
+
# ============================================================
|
| 259 |
+
# Executor Tests
|
| 260 |
+
# ============================================================
|
| 261 |
+
|
| 262 |
+
class TestExecutor:
    """Tests for the Executor: linear runs, branching, loops, limits, and errors."""

    @pytest.mark.asyncio
    async def test_simple_execution(self):
        """A single node that doubles a value, then terminates."""
        graph = Graph()
        graph.add_node("double", handler=lambda s: {**s, "value": s["value"] * 2})
        graph.add_edge("double", END)

        result = await execute_graph(graph, {"value": 5})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["value"] == 10

    @pytest.mark.asyncio
    async def test_multi_node_execution(self):
        """Two chained nodes; state flows through both in order."""
        graph = Graph()
        graph.add_node("add1", handler=lambda s: {**s, "value": s["value"] + 1})
        graph.add_node("add2", handler=lambda s: {**s, "value": s["value"] + 2})
        graph.add_edge("add1", "add2")
        graph.add_edge("add2", END)

        result = await execute_graph(graph, {"value": 0})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["value"] == 3
        assert len(result.execution_log) == 2

    @pytest.mark.asyncio
    async def test_conditional_execution(self):
        """Conditional branching routes to the node selected by the condition."""
        graph = Graph()
        graph.add_node("start", handler=lambda s: s)
        graph.add_node("high", handler=lambda s: {**s, "path": "high"})
        graph.add_node("low", handler=lambda s: {**s, "path": "low"})

        def route(state):
            return "high" if state["value"] > 5 else "low"

        graph.add_conditional_edge("start", route, {"high": "high", "low": "low"})
        graph.add_edge("high", END)
        graph.add_edge("low", END)

        # Test high path
        result = await execute_graph(graph, {"value": 10})
        assert result.final_state["path"] == "high"

        # Test low path
        result = await execute_graph(graph, {"value": 2})
        assert result.final_state["path"] == "low"

    @pytest.mark.asyncio
    async def test_loop_execution(self):
        """A self-loop driven by a conditional edge terminates when done."""
        graph = Graph(max_iterations=10)

        def increment(state):
            return {**state, "count": state["count"] + 1}

        def check_count(state):
            return "done" if state["count"] >= 3 else "continue"

        graph.add_node("increment", handler=increment)
        graph.add_conditional_edge("increment", check_count, {"done": END, "continue": "increment"})

        result = await execute_graph(graph, {"count": 0})

        assert result.status == ExecutionStatus.COMPLETED
        assert result.final_state["count"] == 3

    @pytest.mark.asyncio
    async def test_max_iterations(self):
        """An intentionally infinite loop is cut off at max_iterations."""
        graph = Graph(max_iterations=3)

        # Infinite loop: the condition always routes back to the same node.
        graph.add_node("loop", handler=lambda s: s)
        graph.add_conditional_edge("loop", lambda s: "continue", {"continue": "loop"})

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        assert "Max iterations" in result.error

    @pytest.mark.asyncio
    async def test_error_handling(self):
        """A handler that raises produces a FAILED result carrying the message."""
        def failing_handler(state):
            raise ValueError("Intentional error")

        graph = Graph()
        graph.add_node("fail", handler=failing_handler)

        result = await execute_graph(graph, {})

        assert result.status == ExecutionStatus.FAILED
        assert "Intentional error" in result.error

    @pytest.mark.asyncio
    async def test_execution_log(self):
        """The execution log records every visited node in order."""
        graph = Graph()
        graph.add_node("step1", handler=lambda s: s)
        graph.add_node("step2", handler=lambda s: s)
        graph.add_edge("step1", "step2")
        graph.add_edge("step2", END)

        result = await execute_graph(graph, {})

        assert len(result.execution_log) == 2
        assert result.execution_log[0].node == "step1"
        assert result.execution_log[1].node == "step2"
        # FIX: a trivial `lambda s: s` handler can complete within the clock's
        # resolution and legitimately record a 0.0 ms duration, making a strict
        # `> 0` assertion flaky. Non-negative duration is the reliable invariant.
        assert all(s.duration_ms >= 0 for s in result.execution_log)
| 377 |
+
|
| 378 |
+
|
| 379 |
+
# ============================================================
|
| 380 |
+
# Integration Tests
|
| 381 |
+
# ============================================================
|
| 382 |
+
|
| 383 |
+
class TestCodeReviewWorkflow:
    """Integration tests for the Code Review workflow."""

    @pytest.mark.asyncio
    async def test_code_review_workflow(self):
        """Test the full code review workflow."""
        from app.workflows.code_review import create_code_review_workflow

        sample_code = '''
def hello():
    """Says hello."""
    print("Hello, World!")

def add(a, b):
    return a + b
'''

        workflow = create_code_review_workflow(max_iterations=3, quality_threshold=5.0)
        outcome = await execute_graph(
            workflow,
            {"code": sample_code, "quality_threshold": 5.0},
        )

        assert outcome.status == ExecutionStatus.COMPLETED
        assert "functions" in outcome.final_state
        assert "quality_score" in outcome.final_state
        assert len(outcome.execution_log) > 0
| 410 |
+
|
| 411 |
+
|
| 412 |
+
if __name__ == "__main__":
    # Convenience entry point: run this test module directly with verbose output.
    pytest.main([__file__, "-v"])