PYAE1994 commited on
Commit
5f8b502
Β·
verified Β·
1 Parent(s): e2cfcd7

feat: GOD MODE+ v4.0 - tools/task_dag.py

Browse files
Files changed (1) hide show
  1. tools/task_dag.py +378 -0
tools/task_dag.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Task DAG Engine β€” Devin-style Task Graph
3
+ Plans, tracks, and executes tasks as a Directed Acyclic Graph
4
+ """
5
+
6
+ import asyncio
7
+ import json
8
+ import time
9
+ import uuid
10
+ from enum import Enum
11
+ from typing import Any, Callable, Dict, List, Optional, Set
12
+
13
+ import structlog
14
+
15
+ log = structlog.get_logger()
16
+
17
+
18
+ class StepStatus(str, Enum):
19
+ PENDING = "pending"
20
+ RUNNING = "running"
21
+ COMPLETED = "completed"
22
+ FAILED = "failed"
23
+ SKIPPED = "skipped"
24
+ RETRYING = "retrying"
25
+
26
+
27
+ class TaskNode:
28
+ """Single node in the task DAG."""
29
+
30
+ def __init__(
31
+ self,
32
+ node_id: str,
33
+ name: str,
34
+ description: str = "",
35
+ tool: str = "none",
36
+ depends_on: Optional[List[str]] = None,
37
+ retries: int = 2,
38
+ timeout: int = 120,
39
+ metadata: Optional[Dict] = None,
40
+ ):
41
+ self.id = node_id
42
+ self.name = name
43
+ self.description = description
44
+ self.tool = tool
45
+ self.depends_on: List[str] = depends_on or []
46
+ self.retries = retries
47
+ self.timeout = timeout
48
+ self.metadata = metadata or {}
49
+ self.status = StepStatus.PENDING
50
+ self.result: Optional[str] = None
51
+ self.error: Optional[str] = None
52
+ self.attempt = 0
53
+ self.started_at: Optional[float] = None
54
+ self.completed_at: Optional[float] = None
55
+
56
+ def to_dict(self) -> Dict:
57
+ return {
58
+ "id": self.id,
59
+ "name": self.name,
60
+ "description": self.description,
61
+ "tool": self.tool,
62
+ "depends_on": self.depends_on,
63
+ "status": self.status.value,
64
+ "result": (self.result or "")[:300],
65
+ "error": self.error,
66
+ "attempt": self.attempt,
67
+ "started_at": self.started_at,
68
+ "completed_at": self.completed_at,
69
+ "duration": round(self.completed_at - self.started_at, 2) if self.started_at and self.completed_at else None,
70
+ }
71
+
72
+ def is_ready(self, completed_ids: Set[str]) -> bool:
73
+ """Check if all dependencies are met."""
74
+ return all(dep in completed_ids for dep in self.depends_on)
75
+
76
+
77
+ class TaskDAG:
78
+ """
79
+ Directed Acyclic Graph of tasks.
80
+ Supports: parallel execution, dependency resolution, retry, rollback.
81
+ """
82
+
83
+ def __init__(self, dag_id: str, goal: str):
84
+ self.id = dag_id
85
+ self.goal = goal
86
+ self.nodes: Dict[str, TaskNode] = {}
87
+ self.created_at = time.time()
88
+ self.started_at: Optional[float] = None
89
+ self.completed_at: Optional[float] = None
90
+ self.status = "pending"
91
+ self.result: Optional[str] = None
92
+
93
+ def add_node(self, node: TaskNode) -> "TaskDAG":
94
+ self.nodes[node.id] = node
95
+ return self
96
+
97
+ def get_ready_nodes(self) -> List[TaskNode]:
98
+ """Get nodes whose dependencies are all completed."""
99
+ completed = {nid for nid, n in self.nodes.items() if n.status == StepStatus.COMPLETED}
100
+ return [
101
+ n for n in self.nodes.values()
102
+ if n.status == StepStatus.PENDING and n.is_ready(completed)
103
+ ]
104
+
105
+ def get_progress(self) -> Dict:
106
+ total = len(self.nodes)
107
+ completed = sum(1 for n in self.nodes.values() if n.status == StepStatus.COMPLETED)
108
+ failed = sum(1 for n in self.nodes.values() if n.status == StepStatus.FAILED)
109
+ running = sum(1 for n in self.nodes.values() if n.status == StepStatus.RUNNING)
110
+ pending = sum(1 for n in self.nodes.values() if n.status == StepStatus.PENDING)
111
+ return {
112
+ "total": total,
113
+ "completed": completed,
114
+ "failed": failed,
115
+ "running": running,
116
+ "pending": pending,
117
+ "percent": round((completed / total * 100) if total > 0 else 0, 1),
118
+ }
119
+
120
+ def is_complete(self) -> bool:
121
+ return all(
122
+ n.status in (StepStatus.COMPLETED, StepStatus.FAILED, StepStatus.SKIPPED)
123
+ for n in self.nodes.values()
124
+ )
125
+
126
+ def has_failed(self) -> bool:
127
+ return any(n.status == StepStatus.FAILED for n in self.nodes.values())
128
+
129
+ def to_dict(self) -> Dict:
130
+ progress = self.get_progress()
131
+ return {
132
+ "id": self.id,
133
+ "goal": self.goal,
134
+ "status": self.status,
135
+ "progress": progress,
136
+ "nodes": [n.to_dict() for n in self.nodes.values()],
137
+ "created_at": self.created_at,
138
+ "started_at": self.started_at,
139
+ "completed_at": self.completed_at,
140
+ "duration": round(self.completed_at - self.started_at, 2) if self.started_at and self.completed_at else None,
141
+ }
142
+
143
+
144
+ class DAGEngine:
145
+ """
146
+ Executes TaskDAGs with:
147
+ - Parallel execution of independent nodes
148
+ - Dependency-aware scheduling
149
+ - Per-node retry logic
150
+ - Real-time WebSocket streaming
151
+ - Rollback support
152
+ """
153
+
154
+ def __init__(self, ws_manager=None):
155
+ self.ws = ws_manager
156
+ self._active_dags: Dict[str, TaskDAG] = {}
157
+
158
+ # ─── Build DAG from Plan ───────────────────────────────────────────────────
159
+
160
+ def build_from_steps(self, steps: List[Dict], goal: str = "") -> TaskDAG:
161
+ """Convert flat step list into DAG with sequential dependencies."""
162
+ dag_id = f"dag_{uuid.uuid4().hex[:8]}"
163
+ dag = TaskDAG(dag_id, goal)
164
+ prev_id = None
165
+ for i, step in enumerate(steps):
166
+ node_id = step.get("id") or f"step_{i+1}"
167
+ deps = step.get("depends_on") or ([prev_id] if prev_id else [])
168
+ node = TaskNode(
169
+ node_id=node_id,
170
+ name=step.get("name", f"Step {i+1}"),
171
+ description=step.get("description", ""),
172
+ tool=step.get("tool", "none"),
173
+ depends_on=deps,
174
+ retries=step.get("retries", 2),
175
+ timeout=step.get("timeout", 120),
176
+ metadata=step.get("metadata", {}),
177
+ )
178
+ dag.add_node(node)
179
+ prev_id = node_id
180
+ return dag
181
+
182
+ def build_saas_dag(self, project_name: str) -> TaskDAG:
183
+ """Pre-built DAG for full SaaS project scaffolding."""
184
+ dag_id = f"saas_{uuid.uuid4().hex[:8]}"
185
+ dag = TaskDAG(dag_id, f"Build SaaS: {project_name}")
186
+ nodes = [
187
+ TaskNode("plan", "Planning", "Analyze requirements and create architecture plan", "none", []),
188
+ TaskNode("scaffold", "Scaffold Project", "Create project structure and base files", "shell", ["plan"]),
189
+ TaskNode("backend", "Build Backend", "Generate API, routes, models", "code", ["scaffold"]),
190
+ TaskNode("frontend", "Build Frontend", "Generate UI components and pages", "code", ["scaffold"]),
191
+ TaskNode("db", "Setup Database", "Create DB schema, migrations", "shell", ["backend"]),
192
+ TaskNode("auth", "Add Auth", "Implement authentication system", "code", ["backend", "db"]),
193
+ TaskNode("tests", "Write Tests", "Generate unit and integration tests", "code", ["backend", "frontend"]),
194
+ TaskNode("lint", "Lint & Format", "Run linters and formatters", "shell", ["backend", "frontend"]),
195
+ TaskNode("git_init", "Init Git Repo", "Initialize git and make first commit", "github", ["scaffold"]),
196
+ TaskNode("deploy", "Deploy", "Deploy to Vercel/Cloudflare", "shell", ["tests", "lint"]),
197
+ TaskNode("verify", "Verify Deployment", "Check deployment URL and health", "none", ["deploy"]),
198
+ ]
199
+ for n in nodes:
200
+ dag.add_node(n)
201
+ return dag
202
+
203
+ # ─── Execute DAG ───────────────────────────────────────────────────────────
204
+
205
+ async def execute(
206
+ self,
207
+ dag: TaskDAG,
208
+ executor: Callable,
209
+ session_id: str = "",
210
+ task_id: str = "",
211
+ max_parallel: int = 3,
212
+ ) -> Dict:
213
+ """
214
+ Execute a DAG with dependency-aware parallel scheduling.
215
+ executor: async fn(node, context) -> str
216
+ """
217
+ self._active_dags[dag.id] = dag
218
+ dag.status = "running"
219
+ dag.started_at = time.time()
220
+ results: Dict[str, str] = {}
221
+
222
+ await self._emit(task_id, session_id, "dag_started", {
223
+ "dag_id": dag.id,
224
+ "goal": dag.goal,
225
+ "total_nodes": len(dag.nodes),
226
+ "nodes": [n.to_dict() for n in dag.nodes.values()],
227
+ })
228
+
229
+ semaphore = asyncio.Semaphore(max_parallel)
230
+
231
+ while not dag.is_complete():
232
+ ready = dag.get_ready_nodes()
233
+ if not ready:
234
+ # All ready nodes are running β€” wait
235
+ await asyncio.sleep(0.5)
236
+ continue
237
+
238
+ # Launch all ready nodes in parallel (up to semaphore limit)
239
+ tasks = []
240
+ for node in ready:
241
+ node.status = StepStatus.RUNNING
242
+ node.started_at = time.time()
243
+ await self._emit(task_id, session_id, "dag_node_started", {
244
+ "node_id": node.id,
245
+ "name": node.name,
246
+ "tool": node.tool,
247
+ "dag_id": dag.id,
248
+ "progress": dag.get_progress(),
249
+ })
250
+ t = asyncio.create_task(
251
+ self._execute_node(node, dag, results, executor, semaphore, session_id, task_id)
252
+ )
253
+ tasks.append(t)
254
+
255
+ if tasks:
256
+ await asyncio.gather(*tasks, return_exceptions=True)
257
+
258
+ # Check progress
259
+ await self._emit(task_id, session_id, "dag_progress", {
260
+ "dag_id": dag.id,
261
+ "progress": dag.get_progress(),
262
+ "nodes": [n.to_dict() for n in dag.nodes.values()],
263
+ })
264
+
265
+ dag.completed_at = time.time()
266
+ dag.status = "completed" if not dag.has_failed() else "partial_failure"
267
+
268
+ # Compile final result
269
+ completed_results = {
270
+ nid: n.result for nid, n in dag.nodes.items()
271
+ if n.status == StepStatus.COMPLETED and n.result
272
+ }
273
+
274
+ await self._emit(task_id, session_id, "dag_completed", {
275
+ "dag_id": dag.id,
276
+ "status": dag.status,
277
+ "progress": dag.get_progress(),
278
+ "duration": round(dag.completed_at - dag.started_at, 2),
279
+ "nodes": [n.to_dict() for n in dag.nodes.values()],
280
+ })
281
+
282
+ return {
283
+ "success": not dag.has_failed(),
284
+ "dag_id": dag.id,
285
+ "status": dag.status,
286
+ "progress": dag.get_progress(),
287
+ "results": completed_results,
288
+ "nodes": [n.to_dict() for n in dag.nodes.values()],
289
+ }
290
+
291
+ async def _execute_node(
292
+ self,
293
+ node: TaskNode,
294
+ dag: TaskDAG,
295
+ results: Dict,
296
+ executor: Callable,
297
+ semaphore: asyncio.Semaphore,
298
+ session_id: str,
299
+ task_id: str,
300
+ ):
301
+ async with semaphore:
302
+ context = {
303
+ "goal": dag.goal,
304
+ "previous_results": {k: v for k, v in results.items()},
305
+ "node_metadata": node.metadata,
306
+ }
307
+
308
+ for attempt in range(1, node.retries + 2):
309
+ node.attempt = attempt
310
+ try:
311
+ result = await asyncio.wait_for(
312
+ executor(node, context),
313
+ timeout=node.timeout,
314
+ )
315
+ node.result = str(result)
316
+ node.status = StepStatus.COMPLETED
317
+ node.completed_at = time.time()
318
+ results[node.id] = node.result
319
+
320
+ await self._emit(task_id, session_id, "dag_node_completed", {
321
+ "node_id": node.id,
322
+ "name": node.name,
323
+ "dag_id": dag.id,
324
+ "result": node.result[:200],
325
+ "duration": round(node.completed_at - node.started_at, 2),
326
+ "attempt": attempt,
327
+ "progress": dag.get_progress(),
328
+ })
329
+ return
330
+
331
+ except asyncio.TimeoutError:
332
+ node.error = f"Timeout after {node.timeout}s"
333
+ log.warning("Node timeout", node=node.name, attempt=attempt)
334
+ except Exception as e:
335
+ node.error = str(e)
336
+ log.warning("Node error", node=node.name, attempt=attempt, error=str(e))
337
+
338
+ if attempt <= node.retries:
339
+ node.status = StepStatus.RETRYING
340
+ await self._emit(task_id, session_id, "dag_node_retry", {
341
+ "node_id": node.id,
342
+ "name": node.name,
343
+ "attempt": attempt,
344
+ "max_retries": node.retries,
345
+ "error": node.error,
346
+ })
347
+ await asyncio.sleep(2 ** (attempt - 1))
348
+
349
+ node.status = StepStatus.FAILED
350
+ node.completed_at = time.time()
351
+ await self._emit(task_id, session_id, "dag_node_failed", {
352
+ "node_id": node.id,
353
+ "name": node.name,
354
+ "dag_id": dag.id,
355
+ "error": node.error,
356
+ "attempts": node.attempt,
357
+ })
358
+
359
+ # ─── Get Active DAG ───────────────────────────────────────────────────────
360
+
361
+ def get_dag(self, dag_id: str) -> Optional[TaskDAG]:
362
+ return self._active_dags.get(dag_id)
363
+
364
+ def get_all_dags(self) -> List[Dict]:
365
+ return [dag.to_dict() for dag in self._active_dags.values()]
366
+
367
+ # ─── Emit ─────────────────────────────────────────────────────────────────
368
+
369
+ async def _emit(self, task_id: str, session_id: str, event: str, data: Dict):
370
+ if not self.ws:
371
+ return
372
+ try:
373
+ if task_id:
374
+ await self.ws.emit(task_id, event, data, session_id=session_id)
375
+ if session_id:
376
+ await self.ws.emit_chat(session_id, event, data)
377
+ except Exception:
378
+ pass