"""
Task DAG Engine — Devin-style Task Graph.

Plans, tracks, and executes tasks as a Directed Acyclic Graph.
"""

import asyncio
import json
import time
import uuid
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Set

import structlog

log = structlog.get_logger()


class StepStatus(str, Enum):
    """Lifecycle state of a single TaskNode."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
    SKIPPED = "skipped"
    RETRYING = "retrying"


# A node in one of these states will not be (re)scheduled.
_TERMINAL_STATES = (StepStatus.COMPLETED, StepStatus.FAILED, StepStatus.SKIPPED)

# A dependency in one of these states can never become COMPLETED, so any
# node waiting on it is blocked for good.
_BLOCKING_STATES = (StepStatus.FAILED, StepStatus.SKIPPED)


def _duration(start: Optional[float], end: Optional[float]) -> Optional[float]:
    """Return ``end - start`` rounded to 2 decimals, or None if either is unset."""
    if start is None or end is None:
        return None
    return round(end - start, 2)


class TaskNode:
    """Single node in the task DAG.

    Args:
        node_id: Identifier of the node; other nodes reference it in ``depends_on``.
        name: Human-readable name used in events and log lines.
        description: Free-text description of the work.
        tool: Tool label (e.g. ``"shell"``, ``"code"``); interpreted by the executor.
        depends_on: Ids of nodes that must be COMPLETED before this node runs.
        retries: Extra attempts after the first failure (total attempts = retries + 1).
        timeout: Per-attempt timeout in seconds.
        metadata: Arbitrary data handed to the executor as ``context["node_metadata"]``.
    """

    def __init__(
        self,
        node_id: str,
        name: str,
        description: str = "",
        tool: str = "none",
        depends_on: Optional[List[str]] = None,
        retries: int = 2,
        timeout: int = 120,
        metadata: Optional[Dict] = None,
    ):
        self.id = node_id
        self.name = name
        self.description = description
        self.tool = tool
        self.depends_on: List[str] = depends_on or []
        self.retries = retries
        self.timeout = timeout
        self.metadata = metadata or {}
        self.status = StepStatus.PENDING
        self.result: Optional[str] = None
        # Most recent error message (timeout, exception, or skip/stall reason).
        self.error: Optional[str] = None
        # 1-based number of the most recent attempt; 0 means never attempted.
        self.attempt = 0
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None

    def to_dict(self) -> Dict:
        """Serialize the node for events/APIs; ``result`` is truncated to 300 chars."""
        return {
            "id": self.id,
            "name": self.name,
            "description": self.description,
            "tool": self.tool,
            "depends_on": self.depends_on,
            "status": self.status.value,
            "result": (self.result or "")[:300],
            "error": self.error,
            "attempt": self.attempt,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": _duration(self.started_at, self.completed_at),
        }

    def is_ready(self, completed_ids: Set[str]) -> bool:
        """Return True if every dependency id is in *completed_ids*."""
        return all(dep in completed_ids for dep in self.depends_on)


class TaskDAG:
    """
    Directed Acyclic Graph of tasks.

    Supports: parallel execution, dependency resolution, retry.
    NOTE(review): earlier docs also mentioned rollback; no rollback logic is
    visible in this module.
    """

    def __init__(self, dag_id: str, goal: str):
        self.id = dag_id
        self.goal = goal
        self.nodes: Dict[str, TaskNode] = {}
        self.created_at = time.time()
        self.started_at: Optional[float] = None
        self.completed_at: Optional[float] = None
        # One of "pending", "running", "completed", "partial_failure".
        self.status = "pending"
        self.result: Optional[str] = None

    def add_node(self, node: TaskNode) -> "TaskDAG":
        """Add (or replace, if the id already exists) a node; returns self for chaining."""
        self.nodes[node.id] = node
        return self

    def get_ready_nodes(self) -> List[TaskNode]:
        """Get PENDING nodes whose dependencies are all COMPLETED."""
        completed = {nid for nid, n in self.nodes.items() if n.status == StepStatus.COMPLETED}
        return [
            n for n in self.nodes.values()
            if n.status == StepStatus.PENDING and n.is_ready(completed)
        ]

    def skip_blocked_nodes(self) -> List[TaskNode]:
        """Mark PENDING nodes that depend on a FAILED/SKIPPED node as SKIPPED.

        Skipping propagates transitively (a node blocked by a freshly skipped
        node is skipped too). Returns the newly skipped nodes in marking order.
        """
        skipped: List[TaskNode] = []
        changed = True
        while changed:
            changed = False
            for node in self.nodes.values():
                if node.status != StepStatus.PENDING:
                    continue
                blockers = [
                    dep for dep in node.depends_on
                    if dep in self.nodes and self.nodes[dep].status in _BLOCKING_STATES
                ]
                if blockers:
                    node.status = StepStatus.SKIPPED
                    node.error = f"Skipped: dependency failed or skipped ({', '.join(blockers)})"
                    node.completed_at = time.time()
                    skipped.append(node)
                    changed = True
        return skipped

    def fail_stalled_nodes(self) -> List[TaskNode]:
        """Mark every still-PENDING node FAILED and return them.

        Intended for when nothing is running and nothing is ready: the remaining
        PENDING nodes then wait on an unknown node id or on a dependency cycle,
        and can never run.
        """
        stalled = [n for n in self.nodes.values() if n.status == StepStatus.PENDING]
        now = time.time()
        for node in stalled:
            missing = [dep for dep in node.depends_on if dep not in self.nodes]
            if missing:
                node.error = f"Unknown dependencies: {', '.join(missing)}"
            else:
                node.error = "Unsatisfiable dependencies (cycle or unknown upstream id)"
            node.status = StepStatus.FAILED
            node.completed_at = now
        return stalled

    def get_progress(self) -> Dict:
        """Count nodes per status and compute the completed percentage."""
        counts = {status: 0 for status in StepStatus}
        for n in self.nodes.values():
            counts[n.status] += 1
        total = len(self.nodes)
        completed = counts[StepStatus.COMPLETED]
        return {
            "total": total,
            "completed": completed,
            "failed": counts[StepStatus.FAILED],
            "running": counts[StepStatus.RUNNING],
            "pending": counts[StepStatus.PENDING],
            "skipped": counts[StepStatus.SKIPPED],
            "retrying": counts[StepStatus.RETRYING],
            "percent": round((completed / total * 100) if total > 0 else 0, 1),
        }

    def is_complete(self) -> bool:
        """True when every node is COMPLETED, FAILED, or SKIPPED (vacuously true if empty)."""
        return all(n.status in _TERMINAL_STATES for n in self.nodes.values())

    def has_failed(self) -> bool:
        """True if any node is FAILED."""
        return any(n.status == StepStatus.FAILED for n in self.nodes.values())

    def to_dict(self) -> Dict:
        """Serialize the DAG including progress and all nodes."""
        progress = self.get_progress()
        return {
            "id": self.id,
            "goal": self.goal,
            "status": self.status,
            "progress": progress,
            "nodes": [n.to_dict() for n in self.nodes.values()],
            "created_at": self.created_at,
            "started_at": self.started_at,
            "completed_at": self.completed_at,
            "duration": _duration(self.started_at, self.completed_at),
        }


class DAGEngine:
    """
    Executes TaskDAGs with:
    - Parallel execution of independent nodes
    - Dependency-aware scheduling
    - Per-node retry logic
    - Real-time WebSocket streaming (via the optional ``ws_manager``)

    NOTE(review): earlier docs also mentioned rollback support; no rollback
    logic is visible in this module.
    """

    def __init__(self, ws_manager=None):
        # Expected to provide async ``emit(task_id, event, data, session_id=...)``
        # and ``emit_chat(session_id, event, data)`` (see ``_emit``).
        self.ws = ws_manager
        self._active_dags: Dict[str, TaskDAG] = {}

    # ─── Build DAG from Plan ───────────────────────────────────────────────────

    def build_from_steps(self, steps: List[Dict], goal: str = "") -> TaskDAG:
        """Convert a flat step list into a DAG with sequential dependencies.

        A step whose ``depends_on`` is missing, None, or empty is chained to the
        previous step; a non-empty ``depends_on`` is used verbatim. Missing
        ``id`` values default to ``step_<1-based index>``.
        """
        dag_id = f"dag_{uuid.uuid4().hex[:8]}"
        dag = TaskDAG(dag_id, goal)

        prev_id = None
        for i, step in enumerate(steps):
            node_id = step.get("id") or f"step_{i+1}"
            deps = step.get("depends_on") or ([prev_id] if prev_id else [])
            node = TaskNode(
                node_id=node_id,
                name=step.get("name", f"Step {i+1}"),
                description=step.get("description", ""),
                tool=step.get("tool", "none"),
                depends_on=deps,
                retries=step.get("retries", 2),
                timeout=step.get("timeout", 120),
                metadata=step.get("metadata", {}),
            )
            dag.add_node(node)
            prev_id = node_id

        return dag

    def build_saas_dag(self, project_name: str) -> TaskDAG:
        """Pre-built DAG for full SaaS project scaffolding."""
        dag_id = f"saas_{uuid.uuid4().hex[:8]}"
        dag = TaskDAG(dag_id, f"Build SaaS: {project_name}")

        nodes = [
            TaskNode("plan", "Planning", "Analyze requirements and create architecture plan", "none", []),
            TaskNode("scaffold", "Scaffold Project", "Create project structure and base files", "shell", ["plan"]),
            TaskNode("backend", "Build Backend", "Generate API, routes, models", "code", ["scaffold"]),
            TaskNode("frontend", "Build Frontend", "Generate UI components and pages", "code", ["scaffold"]),
            TaskNode("db", "Setup Database", "Create DB schema, migrations", "shell", ["backend"]),
            TaskNode("auth", "Add Auth", "Implement authentication system", "code", ["backend", "db"]),
            TaskNode("tests", "Write Tests", "Generate unit and integration tests", "code", ["backend", "frontend"]),
            TaskNode("lint", "Lint & Format", "Run linters and formatters", "shell", ["backend", "frontend"]),
            TaskNode("git_init", "Init Git Repo", "Initialize git and make first commit", "github", ["scaffold"]),
            TaskNode("deploy", "Deploy", "Deploy to Vercel/Cloudflare", "shell", ["tests", "lint"]),
            TaskNode("verify", "Verify Deployment", "Check deployment URL and health", "none", ["deploy"]),
        ]
        for n in nodes:
            dag.add_node(n)

        return dag

    # ─── Execute DAG ───────────────────────────────────────────────────────────

    async def execute(
        self,
        dag: TaskDAG,
        executor: Callable,
        session_id: str = "",
        task_id: str = "",
        max_parallel: int = 3,
    ) -> Dict:
        """
        Execute a DAG with dependency-aware parallel scheduling.

        executor: async fn(node, context) -> str

        A node is launched as soon as all of its dependencies are COMPLETED and
        one of the ``max_parallel`` slots is free (values < 1 are treated as 1).
        Nodes downstream of a FAILED/SKIPPED node are marked SKIPPED; nodes whose
        dependencies can never be satisfied (unknown id or cycle) are marked
        FAILED. Execution therefore always terminates.

        Returns a dict with ``success``, ``dag_id``, ``status``, ``progress``,
        ``results`` (node id -> result of completed nodes) and ``nodes``.
        """
        self._active_dags[dag.id] = dag
        dag.status = "running"
        dag.started_at = time.time()
        results: Dict[str, str] = {}
        # Semaphore(0) would block every node forever.
        max_parallel = max(1, max_parallel)

        await self._emit(task_id, session_id, "dag_started", {
            "dag_id": dag.id,
            "goal": dag.goal,
            "total_nodes": len(dag.nodes),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        semaphore = asyncio.Semaphore(max_parallel)
        running: Dict[asyncio.Task, TaskNode] = {}

        try:
            while True:
                for node in dag.skip_blocked_nodes():
                    await self._emit(task_id, session_id, "dag_node_skipped", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "reason": node.error,
                    })

                # Launch only as many nodes as there are free slots, so RUNNING
                # status and started_at reflect actual execution, not queueing.
                free_slots = max_parallel - len(running)
                for node in dag.get_ready_nodes()[:free_slots]:
                    node.status = StepStatus.RUNNING
                    node.started_at = time.time()
                    await self._emit(task_id, session_id, "dag_node_started", {
                        "node_id": node.id,
                        "name": node.name,
                        "tool": node.tool,
                        "dag_id": dag.id,
                        "progress": dag.get_progress(),
                    })
                    task = asyncio.create_task(
                        self._execute_node(node, dag, results, executor, semaphore, session_id, task_id)
                    )
                    running[task] = node

                if not running:
                    # Nothing in flight and nothing launchable: either every node
                    # is finished, or the leftovers can never become ready.
                    for node in dag.fail_stalled_nodes():
                        await self._emit(task_id, session_id, "dag_node_failed", {
                            "node_id": node.id,
                            "name": node.name,
                            "dag_id": dag.id,
                            "error": node.error,
                            "attempts": node.attempt,
                        })
                    break

                # Wake up as soon as any node finishes so its dependents can
                # start without waiting for unrelated siblings.
                done, _ = await asyncio.wait(set(running), return_when=asyncio.FIRST_COMPLETED)
                for task in done:
                    node = running.pop(task)
                    if task.cancelled():
                        crash: Optional[str] = "Cancelled"
                    else:
                        exc = task.exception()
                        crash = None if exc is None else f"Internal error: {exc}"
                    if crash is not None and node.status not in _TERMINAL_STATES:
                        # _execute_node died without recording an outcome.
                        node.status = StepStatus.FAILED
                        node.error = crash
                        node.completed_at = time.time()
                        log.error("Node task crashed", node=node.name, error=crash)

                await self._emit(task_id, session_id, "dag_progress", {
                    "dag_id": dag.id,
                    "progress": dag.get_progress(),
                    "nodes": [n.to_dict() for n in dag.nodes.values()],
                })
        finally:
            # If execute() itself is cancelled, don't leave node tasks orphaned.
            for task in running:
                task.cancel()

        dag.completed_at = time.time()
        dag.status = "completed" if not dag.has_failed() else "partial_failure"

        # Compile final result
        completed_results = {
            nid: n.result for nid, n in dag.nodes.items()
            if n.status == StepStatus.COMPLETED and n.result
        }

        await self._emit(task_id, session_id, "dag_completed", {
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "duration": round(dag.completed_at - dag.started_at, 2),
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        })

        return {
            "success": not dag.has_failed(),
            "dag_id": dag.id,
            "status": dag.status,
            "progress": dag.get_progress(),
            "results": completed_results,
            "nodes": [n.to_dict() for n in dag.nodes.values()],
        }

    async def _execute_node(
        self,
        node: TaskNode,
        dag: TaskDAG,
        results: Dict,
        executor: Callable,
        semaphore: asyncio.Semaphore,
        session_id: str,
        task_id: str,
    ):
        """Run one node with per-attempt timeout and exponential-backoff retries.

        On success the node becomes COMPLETED and its stringified result is
        stored in ``results[node.id]``; after ``retries + 1`` failed attempts it
        becomes FAILED. Ordinary exceptions from the executor are caught.
        """
        async with semaphore:
            context = {
                "goal": dag.goal,
                # Snapshot, so later completions don't mutate what this node sees.
                "previous_results": dict(results),
                "node_metadata": node.metadata,
            }

            # Negative retries would otherwise mean zero attempts.
            max_attempts = max(0, node.retries) + 1
            for attempt in range(1, max_attempts + 1):
                node.attempt = attempt
                node.status = StepStatus.RUNNING
                try:
                    result = str(await asyncio.wait_for(
                        executor(node, context),
                        timeout=node.timeout,
                    ))
                except asyncio.TimeoutError:
                    node.error = f"Timeout after {node.timeout}s"
                    log.warning("Node timeout", node=node.name, attempt=attempt)
                except Exception as e:
                    node.error = str(e)
                    log.warning("Node error", node=node.name, attempt=attempt, error=str(e))
                else:
                    node.result = result
                    node.status = StepStatus.COMPLETED
                    node.completed_at = time.time()
                    results[node.id] = node.result

                    await self._emit(task_id, session_id, "dag_node_completed", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "result": node.result[:200],
                        "duration": round(node.completed_at - node.started_at, 2),
                        "attempt": attempt,
                        "progress": dag.get_progress(),
                    })
                    return

                if attempt < max_attempts:
                    node.status = StepStatus.RETRYING
                    await self._emit(task_id, session_id, "dag_node_retry", {
                        "node_id": node.id,
                        "name": node.name,
                        "dag_id": dag.id,
                        "attempt": attempt,
                        "max_retries": node.retries,
                        "error": node.error,
                    })
                    # Exponential backoff: 1s, 2s, 4s, ...
                    await asyncio.sleep(2 ** (attempt - 1))

            node.status = StepStatus.FAILED
            node.completed_at = time.time()
            await self._emit(task_id, session_id, "dag_node_failed", {
                "node_id": node.id,
                "name": node.name,
                "dag_id": dag.id,
                "error": node.error,
                "attempts": node.attempt,
            })

    # ─── Get Active DAG ───────────────────────────────────────────────────────

    def get_dag(self, dag_id: str) -> Optional[TaskDAG]:
        """Return a DAG previously passed to ``execute``, or None."""
        return self._active_dags.get(dag_id)

    def get_all_dags(self) -> List[Dict]:
        """Serialize every DAG previously passed to ``execute``."""
        return [dag.to_dict() for dag in self._active_dags.values()]

    # ─── Emit ─────────────────────────────────────────────────────────────────

    async def _emit(self, task_id: str, session_id: str, event: str, data: Dict):
        """Best-effort broadcast of *event* to the task and/or chat channel.

        Each channel is attempted independently; failures are logged at debug
        level and never raised, so streaming problems cannot break execution.
        """
        if not self.ws:
            return
        if task_id:
            try:
                await self.ws.emit(task_id, event, data, session_id=session_id)
            except Exception as e:
                # structlog reserves the ``event`` kwarg, hence ``dag_event``.
                log.debug("DAG event emit failed", channel="task", dag_event=event, error=str(e))
        if session_id:
            try:
                await self.ws.emit_chat(session_id, event, data)
            except Exception as e:
                log.debug("DAG event emit failed", channel="chat", dag_event=event, error=str(e))