PYAE1994's picture
feat: GOD MODE+ v4.0 - tools/task_dag.py
5f8b502 verified
"""
Task DAG Engine — Devin-style Task Graph
Plans, tracks, and executes tasks as a Directed Acyclic Graph
"""
import asyncio
import json
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set
import structlog
log = structlog.get_logger()
class StepStatus(str, Enum):
    """Lifecycle state of a single step (node) in a task DAG.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, e.g. ``StepStatus.FAILED == "failed"``.
    """

    # Initial state assigned to every newly created TaskNode.
    PENDING = "pending"
    # Set by the engine when the step is launched.
    RUNNING = "running"
    # The executor returned a result for the step.
    COMPLETED = "completed"
    # Every attempt (first run plus retries) errored or timed out.
    FAILED = "failed"
    # Counted as finished by TaskDAG.is_complete(), like COMPLETED and FAILED.
    SKIPPED = "skipped"
    # An attempt failed and another attempt is still allowed.
    RETRYING = "retrying"
class TaskNode:
    """Single node in the task DAG.

    Holds the static definition of a step (identity, tool, dependencies,
    retry/timeout budget) together with its mutable execution state
    (status, result, error, attempt counter and timestamps).
    """

    def __init__(
        self,
        node_id: str,
        name: str,
        description: str = "",
        tool: str = "none",
        depends_on: Optional[List[str]] = None,
        retries: int = 2,
        timeout: int = 120,
        metadata: Optional[Dict] = None,
    ):
        """Create a node in the PENDING state.

        Args:
            node_id: Identifier, stored as ``self.id``; other nodes refer to
                it through their ``depends_on`` lists.
            name: Human-readable step name.
            description: Free-text description of the step.
            tool: Name of the tool associated with the step.
            depends_on: IDs of nodes that must be completed before this one
                is ready. ``None`` or an empty list means no dependencies.
            retries: Number of additional attempts allowed after the first.
            timeout: Per-attempt timeout in seconds.
            metadata: Arbitrary extra data attached to the node.
        """
        self.id = node_id
        self.name = name
        self.description = description
        self.tool = tool
        self.depends_on: List[str] = depends_on or []
        self.retries = retries
        self.timeout = timeout
        self.metadata = metadata or {}

        # Mutable execution state, updated while the DAG runs.
        self.status = StepStatus.PENDING
        self.result: Optional[str] = None
        self.error: Optional[str] = None
        self.attempt = 0
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of the node.

        ``result`` is truncated to its first 300 characters (empty string
        when there is no result). ``duration`` is the elapsed seconds between
        ``started_at`` and ``completed_at`` rounded to two decimals, or
        ``None`` when either timestamp is missing.
        """
        # Compare against None explicitly: a timestamp of 0.0 is a valid
        # value and must not be mistaken for "not set".
        duration = None
        if self.started_at is not None and self.completed_at is not None:
            duration = round(self.completed_at - self.started_at, 2)

        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "tool": self.tool,
            "depends_on": self.depends_on,
            "status": self.status.value,
            "result": (self.result or "")[:300],
            "error": self.error,
            "attempt": self.attempt,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": duration,
        }

    def is_ready(self, completed_ids: Set[str]) -> bool:
        """Return True when every dependency ID is in ``completed_ids``.

        A node without dependencies is always ready.
        """
        return all(dep in completed_ids for dep in self.depends_on)
class TaskDAG:
    """
    Directed Acyclic Graph of tasks.

    Stores TaskNode objects keyed by their id and answers scheduling
    questions about them (which nodes are ready, overall progress, whether
    the graph is finished or has failures). Running the nodes is done by the
    caller; this class only tracks state.
    """

    def __init__(self, dag_id: str, goal: str):
        """Create an empty DAG in the "pending" state.

        Args:
            dag_id: Identifier, stored as ``self.id``.
            goal: Free-text description of what the DAG should achieve.
        """
        self.id = dag_id
        self.goal = goal
        self.nodes: Dict[str, TaskNode] = {}
        self.created_at = time.time()
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None
        self.status = "pending"
        self.result: Optional[str] = None

    def add_node(self, node: "TaskNode") -> "TaskDAG":
        """Register ``node`` under ``node.id`` and return self for chaining.

        A node added with an id that already exists replaces the earlier one.
        """
        self.nodes[node.id] = node
        return self

    def get_ready_nodes(self) -> List["TaskNode"]:
        """Get PENDING nodes whose dependencies are all COMPLETED."""
        completed = {
            nid for nid, n in self.nodes.items()
            if n.status == StepStatus.COMPLETED
        }
        return [
            n for n in self.nodes.values()
            if n.status == StepStatus.PENDING and n.is_ready(completed)
        ]

    def get_progress(self) -> Dict:
        """Return per-status node counts plus a completion percentage.

        ``percent`` is ``completed / total * 100`` rounded to one decimal,
        or 0 for an empty DAG. ``skipped`` and ``retrying`` are reported so
        that the individual buckets add up to ``total``.
        """
        total = len(self.nodes)
        completed = sum(1 for n in self.nodes.values() if n.status == StepStatus.COMPLETED)
        failed = sum(1 for n in self.nodes.values() if n.status == StepStatus.FAILED)
        running = sum(1 for n in self.nodes.values() if n.status == StepStatus.RUNNING)
        pending = sum(1 for n in self.nodes.values() if n.status == StepStatus.PENDING)
        skipped = sum(1 for n in self.nodes.values() if n.status == StepStatus.SKIPPED)
        retrying = sum(1 for n in self.nodes.values() if n.status == StepStatus.RETRYING)
        return {
            "total": total,
            "completed": completed,
            "failed": failed,
            "running": running,
            "pending": pending,
            "skipped": skipped,
            "retrying": retrying,
            "percent": round((completed / total * 100) if total > 0 else 0, 1),
        }

    def is_complete(self) -> bool:
        """Return True when every node is COMPLETED, FAILED or SKIPPED.

        An empty DAG is considered complete.
        """
        return all(
            n.status in (StepStatus.COMPLETED, StepStatus.FAILED, StepStatus.SKIPPED)
            for n in self.nodes.values()
        )

    def has_failed(self) -> bool:
        """Return True when at least one node is FAILED."""
        return any(n.status == StepStatus.FAILED for n in self.nodes.values())

    def to_dict(self) -> Dict:
        """Return a plain-dict snapshot of the DAG and all of its nodes.

        ``duration`` is the elapsed seconds between ``started_at`` and
        ``completed_at`` rounded to two decimals, or ``None`` when either
        timestamp is missing.
        """
        # Compare against None explicitly: a timestamp of 0.0 is a valid
        # value and must not be mistaken for "not set".
        duration = None
        if self.started_at is not None and self.completed_at is not None:
            duration = round(self.completed_at - self.started_at, 2)

        return {
            "id": self.id,
            "goal": self.goal,
            "status": self.status,
            "progress": self.get_progress(),
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": duration,
        }
class DAGEngine:
    """
    Executes TaskDAGs with:
    - Parallel execution of independent nodes (bounded by a semaphore)
    - Dependency-aware scheduling in waves of ready nodes
    - Per-node retry logic with exponential backoff
    - Real-time event streaming through an optional WebSocket manager
    - Skipping of nodes whose dependencies can never complete

    NOTE(review): earlier documentation advertised "rollback support", but
    no rollback logic exists in this class.
    """

    def __init__(self, ws_manager=None):
        """Create an engine.

        Args:
            ws_manager: Optional object exposing async ``emit`` and
                ``emit_chat`` methods. When falsy, no events are emitted.
        """
        self.ws = ws_manager
        # Every DAG passed to execute() is registered here; entries are
        # never removed by this class.
        self._active_dags: Dict[str, TaskDAG] = {}

    # ─── Build DAG from Plan ───

    def build_from_steps(self, steps: List[Dict], goal: str = "") -> TaskDAG:
        """Convert flat step list into DAG with sequential dependencies.

        Each step dict may carry ``id``, ``name``, ``description``, ``tool``,
        ``depends_on``, ``retries``, ``timeout`` and ``metadata``. A step
        without a (non-empty) ``depends_on`` depends on the step before it;
        the first step has no dependencies. Note that an explicit empty
        ``depends_on`` list is treated the same as a missing one.
        """
        dag_id = f"dag_{uuid.uuid4().hex[:8]}"
        dag = TaskDAG(dag_id, goal)
        prev_id = None
        for i, step in enumerate(steps):
            node_id = step.get("id") or f"step_{i+1}"
            deps = step.get("depends_on") or ([prev_id] if prev_id else [])
            node = TaskNode(
                node_id=node_id,
                name=step.get("name", f"Step {i+1}"),
                description=step.get("description", ""),
                tool=step.get("tool", "none"),
                depends_on=deps,
                retries=step.get("retries", 2),
                timeout=step.get("timeout", 120),
                metadata=step.get("metadata", {}),
            )
            dag.add_node(node)
            prev_id = node_id
        return dag

    def build_saas_dag(self, project_name: str) -> TaskDAG:
        """Pre-built DAG for full SaaS project scaffolding."""
        dag_id = f"saas_{uuid.uuid4().hex[:8]}"
        dag = TaskDAG(dag_id, f"Build SaaS: {project_name}")
        nodes = [
            TaskNode("plan", "Planning", "Analyze requirements and create architecture plan", "none", []),
            TaskNode("scaffold", "Scaffold Project", "Create project structure and base files", "shell", ["plan"]),
            TaskNode("backend", "Build Backend", "Generate API, routes, models", "code", ["scaffold"]),
            TaskNode("frontend", "Build Frontend", "Generate UI components and pages", "code", ["scaffold"]),
            TaskNode("db", "Setup Database", "Create DB schema, migrations", "shell", ["backend"]),
            TaskNode("auth", "Add Auth", "Implement authentication system", "code", ["backend", "db"]),
            TaskNode("tests", "Write Tests", "Generate unit and integration tests", "code", ["backend", "frontend"]),
            TaskNode("lint", "Lint & Format", "Run linters and formatters", "shell", ["backend", "frontend"]),
            TaskNode("git_init", "Init Git Repo", "Initialize git and make first commit", "github", ["scaffold"]),
            TaskNode("deploy", "Deploy", "Deploy to Vercel/Cloudflare", "shell", ["tests", "lint"]),
            TaskNode("verify", "Verify Deployment", "Check deployment URL and health", "none", ["deploy"]),
        ]
        for n in nodes:
            dag.add_node(n)
        return dag

    # ─── Execute DAG ───

    async def execute(
        self,
        dag: TaskDAG,
        executor: Callable,
        session_id: str = "",
        task_id: str = "",
        max_parallel: int = 3,
    ) -> Dict:
        """
        Execute a DAG with dependency-aware parallel scheduling.

        Nodes run in waves: every currently ready node is launched, the wave
        is awaited, then readiness is re-evaluated. If no node is ready and
        nothing is in flight (a dependency failed, refers to an unknown id,
        or forms a cycle), the remaining pending nodes are marked SKIPPED
        instead of waiting forever.

        Args:
            dag: The graph to run; its nodes are mutated in place.
            executor: async fn(node, context) -> str
            session_id: Chat session to stream events to (optional).
            task_id: Task channel to stream events to (optional).
            max_parallel: Maximum number of nodes executing at once; values
                below 1 are treated as 1.

        Returns:
            Dict with ``success``, ``dag_id``, ``status``, ``progress``,
            ``results`` (non-empty results of completed nodes by id) and
            ``nodes`` (node snapshots). ``success`` is False when any node
            failed or had to be skipped by the engine.
        """
        self._active_dags[dag.id] = dag
        dag.status = "running"
        dag.started_at = time.time()
        results: Dict[str, str] = {}
        stalled = False  # True once the engine had to skip blocked nodes.

        await self._emit(task_id, session_id, "dag_started", {
            "dag_id": dag.id,
            "goal": dag.goal,
            "total_nodes": len(dag.nodes),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        # Semaphore(0) would block every node forever and a negative value
        # raises ValueError, so clamp to at least one slot.
        semaphore = asyncio.Semaphore(max(1, max_parallel))

        while not dag.is_complete():
            ready = dag.get_ready_nodes()
            if not ready:
                in_flight = any(
                    n.status in (StepStatus.RUNNING, StepStatus.RETRYING)
                    for n in dag.nodes.values()
                )
                if in_flight:
                    # Nodes are being driven outside this loop — wait.
                    await asyncio.sleep(0.5)
                    continue
                # Nothing is running and nothing can start: the remaining
                # pending nodes are blocked for good.
                stalled = True
                await self._skip_blocked_nodes(dag, session_id, task_id)
                break

            # Launch all ready nodes in parallel (up to semaphore limit)
            tasks = []
            for node in ready:
                node.status = StepStatus.RUNNING
                node.started_at = time.time()
                await self._emit(task_id, session_id, "dag_node_started", {
                    "node_id": node.id,
                    "name": node.name,
                    "tool": node.tool,
                    "dag_id": dag.id,
                    "progress": dag.get_progress(),
                })
                tasks.append(asyncio.create_task(
                    self._execute_node(node, dag, results, executor, semaphore, session_id, task_id)
                ))

            outcomes = await asyncio.gather(*tasks, return_exceptions=True)

            # gather() hands back exceptions instead of raising them. A node
            # whose task died that way would otherwise stay RUNNING and keep
            # the loop alive forever, so fail it explicitly.
            for node, outcome in zip(ready, outcomes):
                if not isinstance(outcome, BaseException):
                    continue
                log.error("Node task crashed", node=node.name, error=repr(outcome))
                if node.status in (StepStatus.RUNNING, StepStatus.RETRYING):
                    node.status = StepStatus.FAILED
                    node.error = f"Unhandled error: {outcome!r}"
                    node.completed_at = time.time()
                    await self._emit(task_id, session_id, "dag_node_failed", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "error": node.error,
                        "attempts": node.attempt,
                    })

            # Check progress
            await self._emit(task_id, session_id, "dag_progress", {
                "dag_id": dag.id,
                "progress": dag.get_progress(),
                "nodes": [n.to_dict() for n in dag.nodes.values()],
            })

        dag.completed_at = time.time()
        succeeded = not dag.has_failed() and not stalled
        dag.status = "completed" if succeeded else "partial_failure"

        # Compile final result
        completed_results = {
            nid: n.result for nid, n in dag.nodes.items()
            if n.status == StepStatus.COMPLETED and n.result
        }

        await self._emit(task_id, session_id, "dag_completed", {
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "duration": round(dag.completed_at - dag.started_at, 2),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        return {
            "success": succeeded,
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "results": completed_results,
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        }

    async def _skip_blocked_nodes(self, dag: TaskDAG, session_id: str, task_id: str) -> None:
        """Mark every still-PENDING node as SKIPPED and emit an event for each."""
        for node in dag.nodes.values():
            if node.status != StepStatus.PENDING:
                continue
            unmet = [
                dep for dep in node.depends_on
                if dep not in dag.nodes or dag.nodes[dep].status != StepStatus.COMPLETED
            ]
            node.status = StepStatus.SKIPPED
            node.error = "Skipped: unmet dependencies: " + ", ".join(map(str, unmet))
            log.warning("Node skipped", node=node.name, unmet=unmet)
            await self._emit(task_id, session_id, "dag_node_skipped", {
                "node_id": node.id,
                "name": node.name,
                "dag_id": dag.id,
                "error": node.error,
                "unmet_dependencies": unmet,
                "progress": dag.get_progress(),
            })

    async def _execute_node(
        self,
        node: TaskNode,
        dag: TaskDAG,
        results: Dict,
        executor: Callable,
        semaphore: asyncio.Semaphore,
        session_id: str,
        task_id: str,
    ):
        """Run one node with timeout and retries, updating it in place.

        On success the node becomes COMPLETED and its stringified result is
        stored on the node and in ``results``. After the last failed attempt
        the node becomes FAILED. Between attempts the node is RETRYING and
        the coroutine sleeps ``2 ** (attempt - 1)`` seconds while still
        holding its semaphore slot.
        """
        async with semaphore:
            # Needed for the duration below when this method is invoked
            # without execute() having stamped the node first.
            if node.started_at is None:
                node.started_at = time.time()

            # Snapshot of results available when this node starts; reused
            # unchanged across retries.
            context = {
                "goal": dag.goal,
                "previous_results": dict(results),
                "node_metadata": node.metadata,
            }

            # Always make at least one attempt, even for a negative budget.
            max_attempts = max(0, node.retries) + 1

            for attempt in range(1, max_attempts + 1):
                node.attempt = attempt
                try:
                    # Only the executor call is guarded, so an error in the
                    # success bookkeeping cannot trigger a re-run of an
                    # executor that already succeeded.
                    result_text = str(await asyncio.wait_for(
                        executor(node, context),
                        timeout=node.timeout,
                    ))
                except asyncio.TimeoutError:
                    node.error = f"Timeout after {node.timeout}s"
                    log.warning("Node timeout", node=node.name, attempt=attempt)
                except Exception as e:
                    node.error = str(e)
                    log.warning("Node error", node=node.name, attempt=attempt, error=str(e))
                else:
                    node.result = result_text
                    node.status = StepStatus.COMPLETED
                    node.completed_at = time.time()
                    results[node.id] = node.result
                    await self._emit(task_id, session_id, "dag_node_completed", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "result": node.result[:200],
                        "duration": round(node.completed_at - node.started_at, 2),
                        "attempt": attempt,
                        "progress": dag.get_progress(),
                    })
                    return

                if attempt < max_attempts:
                    node.status = StepStatus.RETRYING
                    await self._emit(task_id, session_id, "dag_node_retry", {
                        "node_id": node.id,
                        "name": node.name,
                        "attempt": attempt,
                        "max_retries": node.retries,
                        "error": node.error,
                    })
                    await asyncio.sleep(2 ** (attempt - 1))

            node.status = StepStatus.FAILED
            node.completed_at = time.time()
            await self._emit(task_id, session_id, "dag_node_failed", {
                "node_id": node.id,
                "name": node.name,
                "dag_id": dag.id,
                "error": node.error,
                "attempts": node.attempt,
            })

    # ─── Get Active DAG ───

    def get_dag(self, dag_id: str) -> Optional[TaskDAG]:
        """Return the registered DAG with this id, or None if unknown."""
        return self._active_dags.get(dag_id)

    def get_all_dags(self) -> List[Dict]:
        """Return dict snapshots of every DAG registered with this engine."""
        return [dag.to_dict() for dag in self._active_dags.values()]

    # ─── Emit ───

    async def _emit(self, task_id: str, session_id: str, event: str, data: Dict):
        """Best-effort event delivery; never raises on delivery errors.

        Sends to the task channel when ``task_id`` is set and to the chat
        channel when ``session_id`` is set. Each channel is guarded
        separately so a failure on one does not suppress the other, and
        failures are logged rather than silently dropped.
        """
        if not self.ws:
            return
        if task_id:
            try:
                await self.ws.emit(task_id, event, data, session_id=session_id)
            except Exception as e:
                # The key is "ws_event" because structlog reserves "event"
                # for the log message itself.
                log.warning("DAG event emit failed", channel="task", ws_event=event, error=str(e))
        if session_id:
            try:
                await self.ws.emit_chat(session_id, event, data)
            except Exception as e:
                log.warning("DAG event emit failed", channel="chat", ws_event=event, error=str(e))