""" Workflow executor with DAG orchestration and parallel execution """ from typing import Dict, Any, List, Set, Optional from .schema import WorkflowDefinition, WorkflowTask from .persistence import WorkflowStore import networkx as nx from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError import logging import time import json logger = logging.getLogger(__name__) class WorkflowExecutor: """ Executes workflows as DAGs with parallel task execution. Features: - Dependency resolution via topological sort - Parallel execution with configurable concurrency - Error handling and retry logic - Execution trace for debugging """ def __init__( self, tools_registry: Dict[str, Any], max_parallel: int = 3, timeout: int = 600, memory=None, store_path: str = "./workflow_cache" ): """ Initialize workflow executor. Args: tools_registry: Map of tool_name -> tool_instance max_parallel: Maximum parallel tasks timeout: Default workflow timeout memory: Optional agent memory for context store_path: Directory for workflow persistence """ self.tools_registry = tools_registry self.max_parallel = max_parallel self.timeout = timeout self.memory = memory self.store = WorkflowStore(store_path=store_path) def execute(self, workflow: WorkflowDefinition) -> Dict[str, Any]: """ Execute workflow and return result. Args: workflow: WorkflowDefinition to execute Returns: Dict with success status, result, and execution trace """ start_time = time.time() trace = [] try: # Build DAG graph = self._build_dag(workflow) logger.info(f"Built DAG with {len(graph.nodes)} nodes") # Topological sort for execution order try: execution_order = list(nx.topological_sort(graph)) except nx.NetworkXError as e: return { "success": False, "error": f"Invalid DAG: {e}", "trace": trace } # Execute tasks results = {} task_map = {task.id: task for task in workflow.tasks} # Process tasks in dependency order pending_tasks = set(execution_order) completed_tasks = set() while pending_tasks: # Find tasks ready to execute (all dependencies complete) ready_tasks = [ tid for tid in pending_tasks if all(dep in completed_tasks for dep in task_map[tid].depends_on) ] if not ready_tasks: break # Deadlock or error # Execute ready tasks in parallel (up to max_parallel) batch_size = min(len(ready_tasks), workflow.max_parallel) batch = ready_tasks[:batch_size] logger.info(f"Executing batch: {batch}") with ThreadPoolExecutor(max_workers=batch_size) as executor: futures = { executor.submit( self._execute_task, task_map[tid], results, trace ): tid for tid in batch } # Wait for completion for future in futures: tid = futures[future] try: task_timeout = task_map[tid].timeout_seconds result = future.result(timeout=task_timeout) results[tid] = result completed_tasks.add(tid) pending_tasks.remove(tid) except FutureTimeoutError: error_msg = f"Task {tid} timed out" logger.error(error_msg) trace.append({ "task_id": tid, "status": "timeout", "error": error_msg }) # Mark as failed but continue with other tasks results[tid] = {"error": error_msg} completed_tasks.add(tid) pending_tasks.remove(tid) except Exception as e: error_msg = f"Task {tid} failed: {e}" logger.error(error_msg) trace.append({ "task_id": tid, "status": "error", "error": str(e) }) results[tid] = {"error": str(e)} completed_tasks.add(tid) pending_tasks.remove(tid) # Check workflow timeout if time.time() - start_time > workflow.timeout_seconds: return { "success": False, "error": "Workflow timeout exceeded", "trace": trace, "partial_results": results } # Get final result final_result = results.get(workflow.final_task) if final_result is None: return { "success": False, "error": f"Final task {workflow.final_task} did not execute", "trace": trace, "results": results } execution_time = time.time() - start_time result = { "success": True, "result": final_result, "execution_time": execution_time, "trace": trace, "all_results": results } # Save successful workflow execution workflow_id = f"{workflow.name}_{int(time.time())}" self.store.save_workflow(workflow_id, workflow, result) return result except Exception as e: logger.error(f"Workflow execution failed: {e}", exc_info=True) return { "success": False, "error": str(e), "trace": trace } def _build_dag(self, workflow: WorkflowDefinition) -> nx.DiGraph: """Build NetworkX directed graph from workflow.""" graph = nx.DiGraph() # Add nodes for task in workflow.tasks: graph.add_node(task.id) # Add edges (dependencies) for task in workflow.tasks: for dep in task.depends_on: graph.add_edge(dep, task.id) # Edge from dependency to task return graph def _execute_task( self, task: WorkflowTask, results: Dict[str, Any], trace: List[Dict[str, Any]] ) -> Any: """ Execute single task with retry logic. Args: task: Task to execute results: Shared results dict (for accessing dependency outputs) trace: Shared trace list Returns: Task result """ logger.info(f"Executing task: {task.id} (tool: {task.tool})") trace.append({ "task_id": task.id, "tool": task.tool, "status": "started", "timestamp": time.time() }) # Get tool tool = self.tools_registry.get(task.tool) if not tool: error_msg = f"Tool not found: {task.tool}" logger.error(error_msg) trace.append({ "task_id": task.id, "status": "error", "error": error_msg }) raise ValueError(error_msg) # Resolve arguments (may reference previous task results) args = self._resolve_args(task.args, results) # Execute with retry last_error = None for attempt in range(task.max_retries + 1): try: result = tool.forward(**args) # Parse result if it's JSON string if isinstance(result, str): try: result = json.loads(result) except json.JSONDecodeError: pass # Keep as string trace.append({ "task_id": task.id, "status": "completed", "attempt": attempt + 1, "timestamp": time.time() }) logger.info(f"Task {task.id} completed successfully") return result except Exception as e: last_error = e logger.warning( f"Task {task.id} attempt {attempt + 1}/{task.max_retries + 1} failed: {e}" ) if attempt < task.max_retries and task.retry_on_failure: time.sleep(1 * (2 ** attempt)) # Exponential backoff continue else: trace.append({ "task_id": task.id, "status": "failed", "error": str(e), "attempts": attempt + 1 }) raise # Should not reach here, but for safety if last_error: raise last_error else: raise RuntimeError(f"Task {task.id} failed without exception") def _resolve_args(self, args: Dict[str, Any], results: Dict[str, Any]) -> Dict[str, Any]: """ Resolve arguments that reference previous task results. Supports syntax: "${task_id}" or "${task_id.field}" Args: args: Raw arguments results: Previous task results Returns: Resolved arguments """ resolved = {} for key, value in args.items(): if isinstance(value, str) and value.startswith("${") and value.endswith("}"): # Reference to previous task result ref = value[2:-1] # Remove ${ and } parts = ref.split(".") # Get task result task_id = parts[0] if task_id not in results: raise ValueError(f"Referenced task {task_id} not yet executed") result = results[task_id] # Navigate nested fields for part in parts[1:]: if isinstance(result, dict): result = result.get(part) else: raise ValueError(f"Cannot access field {part} on {type(result)}") resolved[key] = result else: resolved[key] = value return resolved