| """ |
| Task DAG Engine — Devin-style Task Graph |
| Plans, tracks, and executes tasks as a Directed Acyclic Graph |
| """ |
|
|
| import asyncio |
| import json |
| import time |
| import uuid |
| from enum import Enum |
| from typing import Any, Callable, Dict, List, Optional, Set |
|
|
| import structlog |
|
|
| log = structlog.get_logger() |
|
|
|
|
class StepStatus(str, Enum):
    """Lifecycle states of a single task node.

    Members subclass ``str``, so each one compares equal to its lowercase
    value and can be looked up from it (``StepStatus("failed")``).
    """

    # Initial state of every freshly constructed node.
    PENDING = "pending"
    # The node has been handed to the executor.
    RUNNING = "running"
    # The executor returned a result.
    COMPLETED = "completed"
    # Every allowed attempt raised or timed out.
    FAILED = "failed"
    # Counted as finished alongside COMPLETED and FAILED.
    SKIPPED = "skipped"
    # An attempt failed and another one is about to be made.
    RETRYING = "retrying"
|
|
|
|
class TaskNode:
    """Single node in the task DAG.

    Bundles the static definition of a step (name, tool, dependencies, retry
    and timeout budget) with its mutable execution state (status, result,
    error, attempt counter and timestamps).
    """

    def __init__(
        self,
        node_id: str,
        name: str,
        description: str = "",
        tool: str = "none",
        depends_on: Optional[List[str]] = None,
        retries: int = 2,
        timeout: int = 120,
        metadata: Optional[Dict] = None,
    ):
        """Create a node in the ``PENDING`` state.

        Args:
            node_id: Identifier of the node; stored as ``self.id``.
            name: Human-readable name.
            description: Free-form description of the step.
            tool: Name of the tool associated with the step.
            depends_on: Ids of nodes that must be completed first. ``None``
                or an empty list means no dependencies.
            retries: Retry budget for the node.
            timeout: Time budget for the node.
            metadata: Arbitrary extra data; ``None`` becomes an empty dict.
        """
        self.id = node_id
        self.name = name
        self.description = description
        self.tool = tool
        self.depends_on: List[str] = depends_on or []
        self.retries = retries
        self.timeout = timeout
        self.metadata = metadata or {}

        # Execution state, filled in while the node runs.
        self.status = StepStatus.PENDING
        self.result: Optional[str] = None
        self.error: Optional[str] = None
        self.attempt = 0
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None

    def to_dict(self) -> Dict:
        """Return a serializable snapshot of the node.

        ``result`` is truncated to its first 300 characters (``""`` when
        there is no result). ``duration`` is the rounded difference between
        ``completed_at`` and ``started_at``, or ``None`` when either is unset.
        """
        # Compare against None explicitly: a timestamp of 0.0 is a valid
        # value and must not be mistaken for "not set".
        duration = None
        if self.started_at is not None and self.completed_at is not None:
            duration = round(self.completed_at - self.started_at, 2)

        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "tool": self.tool,
            # Copy so callers cannot mutate the node through the snapshot.
            "depends_on": list(self.depends_on),
            "status": self.status.value,
            "result": (self.result or "")[:300],
            "error": self.error,
            "attempt": self.attempt,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": duration,
        }

    def is_ready(self, completed_ids: Set[str]) -> bool:
        """Check if all dependencies are met.

        Args:
            completed_ids: Ids of the nodes that have completed.

        Returns:
            ``True`` when every id in ``depends_on`` is in ``completed_ids``
            (trivially ``True`` for a node without dependencies).
        """
        return all(dep in completed_ids for dep in self.depends_on)
|
|
|
|
class TaskDAG:
    """
    Directed Acyclic Graph of tasks.

    Holds the nodes of one plan keyed by node id, plus the overall status and
    timestamps of the run. The class is a passive container: it answers
    scheduling questions (which nodes are ready, is everything finished) and
    produces snapshots, but it does not run nodes itself and does not check
    that the graph is acyclic or that dependency ids exist.
    """

    def __init__(self, dag_id: str, goal: str):
        """Create an empty DAG in the ``"pending"`` state.

        Args:
            dag_id: Identifier of the DAG; stored as ``self.id``.
            goal: Description of what the DAG is meant to achieve.
        """
        self.id = dag_id
        self.goal = goal
        self.nodes: Dict[str, "TaskNode"] = {}
        self.created_at = time.time()
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None
        self.status = "pending"
        self.result: Optional[str] = None

    def add_node(self, node: "TaskNode") -> "TaskDAG":
        """Register ``node`` under its id and return ``self`` for chaining.

        A node with an id that is already present replaces the earlier one.
        """
        self.nodes[node.id] = node
        return self

    def get_ready_nodes(self) -> List["TaskNode"]:
        """Get nodes whose dependencies are all completed.

        Only nodes that are still ``PENDING`` are returned; a dependency
        counts as met only when its node is ``COMPLETED``.
        """
        completed = {
            nid for nid, n in self.nodes.items()
            if n.status == StepStatus.COMPLETED
        }
        return [
            n for n in self.nodes.values()
            if n.status == StepStatus.PENDING and n.is_ready(completed)
        ]

    def get_progress(self) -> Dict:
        """Return per-status node counts and the completion percentage.

        Returns:
            Dict with ``total``, ``completed``, ``failed``, ``running``,
            ``pending`` and ``percent`` (completed / total * 100, rounded to
            one decimal; ``0`` for an empty DAG). Nodes in other states are
            included in ``total`` only.
        """
        completed = failed = running = pending = 0
        # Single pass over the nodes instead of one scan per status.
        for node in self.nodes.values():
            status = node.status
            if status == StepStatus.COMPLETED:
                completed += 1
            elif status == StepStatus.FAILED:
                failed += 1
            elif status == StepStatus.RUNNING:
                running += 1
            elif status == StepStatus.PENDING:
                pending += 1

        total = len(self.nodes)
        percent = round(completed / total * 100, 1) if total > 0 else 0
        return {
            "total": total,
            "completed": completed,
            "failed": failed,
            "running": running,
            "pending": pending,
            "percent": percent,
        }

    def is_complete(self) -> bool:
        """Return ``True`` when every node is COMPLETED, FAILED or SKIPPED.

        An empty DAG is considered complete.
        """
        return all(
            n.status in (StepStatus.COMPLETED, StepStatus.FAILED, StepStatus.SKIPPED)
            for n in self.nodes.values()
        )

    def has_failed(self) -> bool:
        """Return ``True`` when at least one node is ``FAILED``."""
        return any(n.status == StepStatus.FAILED for n in self.nodes.values())

    def to_dict(self) -> Dict:
        """Return a serializable snapshot of the DAG and all of its nodes.

        ``duration`` is the rounded difference between ``completed_at`` and
        ``started_at``, or ``None`` when either is unset.
        """
        # Compare against None explicitly: a timestamp of 0.0 is a valid
        # value and must not be mistaken for "not set".
        duration = None
        if self.started_at is not None and self.completed_at is not None:
            duration = round(self.completed_at - self.started_at, 2)

        return {
            "id": self.id,
            "goal": self.goal,
            "status": self.status,
            "progress": self.get_progress(),
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": duration,
        }
|
|
|
|
class DAGEngine:
    """
    Executes TaskDAGs with:
    - Parallel execution of independent nodes (bounded by ``max_parallel``)
    - Dependency-aware scheduling
    - Per-node retry logic with exponential backoff
    - Real-time event streaming through an optional WebSocket manager

    Every DAG passed to ``execute`` is stored in an in-memory registry so it
    can be looked up afterwards; nothing in this class removes entries from
    that registry. Rollback is not implemented in this class.
    """

    def __init__(self, ws_manager=None):
        """Create an engine.

        Args:
            ws_manager: Optional object exposing awaitable
                ``emit(task_id, event, data, session_id=...)`` and
                ``emit_chat(session_id, event, data)``. When falsy, no
                events are streamed.
        """
        self.ws = ws_manager
        self._active_dags: Dict[str, TaskDAG] = {}

    def build_from_steps(self, steps: List[Dict], goal: str = "") -> TaskDAG:
        """Convert flat step list into DAG with sequential dependencies.

        Each step dict may provide ``id``, ``name``, ``description``,
        ``tool``, ``depends_on``, ``retries``, ``timeout`` and ``metadata``.
        Missing ids become ``step_<n>`` (1-based position).

        Args:
            steps: Step definitions, in execution order.
            goal: Goal recorded on the resulting DAG.

        Returns:
            A new ``TaskDAG`` with a random ``dag_``-prefixed id.
        """
        dag = TaskDAG(f"dag_{uuid.uuid4().hex[:8]}", goal)
        prev_id = None
        for position, step in enumerate(steps, start=1):
            node_id = step.get("id") or f"step_{position}"
            # A missing *or empty* "depends_on" falls back to the previous
            # step, so only the first step can be a root via this builder.
            deps = step.get("depends_on") or ([prev_id] if prev_id else [])
            dag.add_node(TaskNode(
                node_id=node_id,
                name=step.get("name", f"Step {position}"),
                description=step.get("description", ""),
                tool=step.get("tool", "none"),
                depends_on=deps,
                retries=step.get("retries", 2),
                timeout=step.get("timeout", 120),
                metadata=step.get("metadata", {}),
            ))
            prev_id = node_id
        return dag

    def build_saas_dag(self, project_name: str) -> TaskDAG:
        """Pre-built DAG for full SaaS project scaffolding.

        Args:
            project_name: Name inserted into the DAG goal.

        Returns:
            A new ``TaskDAG`` with a random ``saas_``-prefixed id and a fixed
            set of eleven nodes running from planning to deployment
            verification.
        """
        dag = TaskDAG(f"saas_{uuid.uuid4().hex[:8]}", f"Build SaaS: {project_name}")
        nodes = [
            TaskNode("plan", "Planning", "Analyze requirements and create architecture plan", "none", []),
            TaskNode("scaffold", "Scaffold Project", "Create project structure and base files", "shell", ["plan"]),
            TaskNode("backend", "Build Backend", "Generate API, routes, models", "code", ["scaffold"]),
            TaskNode("frontend", "Build Frontend", "Generate UI components and pages", "code", ["scaffold"]),
            TaskNode("db", "Setup Database", "Create DB schema, migrations", "shell", ["backend"]),
            TaskNode("auth", "Add Auth", "Implement authentication system", "code", ["backend", "db"]),
            TaskNode("tests", "Write Tests", "Generate unit and integration tests", "code", ["backend", "frontend"]),
            TaskNode("lint", "Lint & Format", "Run linters and formatters", "shell", ["backend", "frontend"]),
            TaskNode("git_init", "Init Git Repo", "Initialize git and make first commit", "github", ["scaffold"]),
            TaskNode("deploy", "Deploy", "Deploy to Vercel/Cloudflare", "shell", ["tests", "lint"]),
            TaskNode("verify", "Verify Deployment", "Check deployment URL and health", "none", ["deploy"]),
        ]
        for node in nodes:
            dag.add_node(node)
        return dag

    async def execute(
        self,
        dag: TaskDAG,
        executor: Callable,
        session_id: str = "",
        task_id: str = "",
        max_parallel: int = 3,
    ) -> Dict:
        """
        Execute a DAG with dependency-aware parallel scheduling.
        executor: async fn(node, context) -> str

        Nodes are launched in waves: every node that is currently ready is
        started, and the next wave is computed only once the whole wave has
        finished. If no pending node can ever become ready (a dependency
        failed, names an unknown node, or is part of a cycle), the remaining
        pending nodes are marked ``SKIPPED`` and the run ends as
        ``"partial_failure"`` instead of waiting forever.

        Args:
            dag: The DAG to run; it is mutated in place and registered in
                the engine.
            executor: Awaitable callable invoked as ``executor(node, context)``.
            session_id: Chat session to stream events to (optional).
            task_id: Task channel to stream events to (optional).
            max_parallel: Maximum number of nodes running at once; values
                below 1 are treated as 1.

        Returns:
            Dict with ``success``, ``dag_id``, ``status``, ``progress``,
            ``results`` (non-empty results of completed nodes keyed by node
            id) and ``nodes`` (node snapshots).
        """
        self._active_dags[dag.id] = dag
        dag.status = "running"
        dag.started_at = time.time()
        results: Dict[str, str] = {}
        blocked = False

        await self._emit(task_id, session_id, "dag_started", {
            "dag_id": dag.id,
            "goal": dag.goal,
            "total_nodes": len(dag.nodes),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        # A limit of 0 would block forever and a negative one raises.
        semaphore = asyncio.Semaphore(max(1, max_parallel))
        in_flight = (StepStatus.RUNNING, StepStatus.RETRYING)

        while not dag.is_complete():
            ready = dag.get_ready_nodes()
            if not ready:
                if any(n.status in in_flight for n in dag.nodes.values()):
                    # Some node is being driven outside this loop; poll
                    # until it settles.
                    await asyncio.sleep(0.5)
                    continue
                # Nothing is running and nothing can start: the remaining
                # nodes are permanently blocked.
                blocked = True
                await self._skip_blocked_nodes(dag, session_id, task_id)
                break

            wave = []
            for node in ready:
                node.status = StepStatus.RUNNING
                node.started_at = time.time()
                await self._emit(task_id, session_id, "dag_node_started", {
                    "node_id": node.id,
                    "name": node.name,
                    "tool": node.tool,
                    "dag_id": dag.id,
                    "progress": dag.get_progress(),
                })
                task = asyncio.create_task(
                    self._execute_node(node, dag, results, executor, semaphore, session_id, task_id)
                )
                wave.append((node, task))

            outcomes = await asyncio.gather(
                *(task for _, task in wave), return_exceptions=True
            )
            # A task that ended abnormally (e.g. was cancelled) never reached
            # its own status bookkeeping; settle the node here so it cannot
            # stay RUNNING forever.
            for (node, _), outcome in zip(wave, outcomes):
                if isinstance(outcome, BaseException):
                    await self._fail_aborted_node(node, dag, outcome, session_id, task_id)

            await self._emit(task_id, session_id, "dag_progress", {
                "dag_id": dag.id,
                "progress": dag.get_progress(),
                "nodes": [n.to_dict() for n in dag.nodes.values()],
            })

        dag.completed_at = time.time()
        if blocked or dag.has_failed():
            dag.status = "partial_failure"
        else:
            dag.status = "completed"

        completed_results = {
            nid: n.result for nid, n in dag.nodes.items()
            if n.status == StepStatus.COMPLETED and n.result
        }

        await self._emit(task_id, session_id, "dag_completed", {
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "duration": round(dag.completed_at - dag.started_at, 2),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        return {
            "success": dag.status == "completed",
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "results": completed_results,
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        }

    async def _skip_blocked_nodes(self, dag: TaskDAG, session_id: str, task_id: str):
        """Mark every still-pending node as SKIPPED and report why."""
        done = {nid for nid, n in dag.nodes.items() if n.status == StepStatus.COMPLETED}
        for node in dag.nodes.values():
            if node.status != StepStatus.PENDING:
                continue
            unmet = [str(dep) for dep in node.depends_on if dep not in done]
            node.status = StepStatus.SKIPPED
            node.error = "Skipped: unmet dependencies: " + ", ".join(unmet)
            log.warning("Node skipped", node=node.name, unmet=unmet)
            await self._emit(task_id, session_id, "dag_node_skipped", {
                "node_id": node.id,
                "name": node.name,
                "dag_id": dag.id,
                "error": node.error,
                "unmet_dependencies": unmet,
            })

    async def _fail_aborted_node(
        self,
        node: TaskNode,
        dag: TaskDAG,
        outcome: BaseException,
        session_id: str,
        task_id: str,
    ):
        """Mark a node FAILED when its task ended without settling it."""
        if node.status in (StepStatus.COMPLETED, StepStatus.FAILED, StepStatus.SKIPPED):
            return
        node.status = StepStatus.FAILED
        node.error = f"Aborted: {outcome!r}"
        node.completed_at = time.time()
        log.error("Node aborted", node=node.name, error=repr(outcome))
        await self._emit(task_id, session_id, "dag_node_failed", {
            "node_id": node.id,
            "name": node.name,
            "dag_id": dag.id,
            "error": node.error,
            "attempts": node.attempt,
        })

    async def _execute_node(
        self,
        node: TaskNode,
        dag: TaskDAG,
        results: Dict,
        executor: Callable,
        semaphore: asyncio.Semaphore,
        session_id: str,
        task_id: str,
    ):
        """Run one node with timeout and retries, updating its state.

        The executor receives a context holding the DAG goal, a snapshot of
        the results gathered so far and the node metadata. Each attempt is
        limited to ``node.timeout`` seconds. After a failed attempt the node
        sleeps ``2 ** (attempt - 1)`` seconds before retrying; once the
        attempts are used up the node is marked FAILED.
        """
        async with semaphore:
            context = {
                "goal": dag.goal,
                "previous_results": dict(results),
                "node_metadata": node.metadata,
            }

            # One initial attempt plus the retries; a negative retry count
            # still gets the initial attempt.
            max_attempts = max(node.retries, 0) + 1

            for attempt in range(1, max_attempts + 1):
                node.attempt = attempt
                # Keep only the executor call inside the try block so that a
                # problem in the success bookkeeping cannot trigger a re-run
                # of a node that already succeeded.
                try:
                    result = await asyncio.wait_for(
                        executor(node, context),
                        timeout=node.timeout,
                    )
                except asyncio.TimeoutError:
                    node.error = f"Timeout after {node.timeout}s"
                    log.warning("Node timeout", node=node.name, attempt=attempt)
                except Exception as e:
                    node.error = str(e)
                    log.warning("Node error", node=node.name, attempt=attempt, error=str(e))
                else:
                    node.result = str(result)
                    node.status = StepStatus.COMPLETED
                    node.completed_at = time.time()
                    results[node.id] = node.result

                    duration = None
                    if node.started_at is not None:
                        duration = round(node.completed_at - node.started_at, 2)
                    await self._emit(task_id, session_id, "dag_node_completed", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "result": node.result[:200],
                        "duration": duration,
                        "attempt": attempt,
                        "progress": dag.get_progress(),
                    })
                    return

                if attempt < max_attempts:
                    node.status = StepStatus.RETRYING
                    await self._emit(task_id, session_id, "dag_node_retry", {
                        "node_id": node.id,
                        "name": node.name,
                        "attempt": attempt,
                        "max_retries": node.retries,
                        "error": node.error,
                    })
                    # Exponential backoff: 1s, 2s, 4s, ...
                    await asyncio.sleep(2 ** (attempt - 1))

            node.status = StepStatus.FAILED
            node.completed_at = time.time()
            await self._emit(task_id, session_id, "dag_node_failed", {
                "node_id": node.id,
                "name": node.name,
                "dag_id": dag.id,
                "error": node.error,
                "attempts": node.attempt,
            })

    def get_dag(self, dag_id: str) -> Optional[TaskDAG]:
        """Return the registered DAG with this id, or ``None`` if unknown."""
        return self._active_dags.get(dag_id)

    def get_all_dags(self) -> List[Dict]:
        """Return snapshots of every DAG this engine has executed."""
        return [dag.to_dict() for dag in self._active_dags.values()]

    async def _emit(self, task_id: str, session_id: str, event: str, data: Dict):
        """Best-effort event streaming to the task and chat channels.

        Does nothing without a WebSocket manager. Delivery failures are
        logged and never propagate, and a failure on one channel does not
        prevent delivery on the other.
        """
        if not self.ws:
            return
        if task_id:
            try:
                await self.ws.emit(task_id, event, data, session_id=session_id)
            except Exception as exc:
                # "event" is the logger's own positional argument, so the
                # event name is logged under a different key.
                log.warning("WebSocket emit failed", ws_event=event, channel="task", error=str(exc))
        if session_id:
            try:
                await self.ws.emit_chat(session_id, event, data)
            except Exception as exc:
                log.warning("WebSocket emit failed", ws_event=event, channel="chat", error=str(exc))
|
|