Spaces:
Sleeping
Sleeping
| """ | |
| Workflow executor with DAG orchestration and parallel execution | |
| """ | |
| from typing import Dict, Any, List, Set, Optional | |
| from .schema import WorkflowDefinition, WorkflowTask | |
| from .persistence import WorkflowStore | |
| import networkx as nx | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError | |
| import logging | |
| import time | |
| import json | |
| logger = logging.getLogger(__name__) | |
class WorkflowExecutor:
    """
    Runs a workflow as a directed acyclic graph of tasks.

    Capabilities:
    - Resolves task dependencies via topological sort
    - Runs independent tasks in parallel, bounded by a concurrency limit
    - Retries failing tasks and records an execution trace for debugging
    """

    def __init__(
        self,
        tools_registry: Dict[str, Any],
        max_parallel: int = 3,
        timeout: int = 600,
        memory=None,
        store_path: str = "./workflow_cache"
    ):
        """
        Set up the executor.

        Args:
            tools_registry: Mapping of tool_name -> tool_instance
            max_parallel: Upper bound on concurrently running tasks
            timeout: Default workflow timeout (seconds)
            memory: Optional agent memory providing extra context
            store_path: Directory used by the workflow persistence store
        """
        # Plain configuration capture; the persistence backend is created
        # eagerly from the given directory.
        self.memory = memory
        self.timeout = timeout
        self.max_parallel = max_parallel
        self.tools_registry = tools_registry
        self.store = WorkflowStore(store_path=store_path)
| def execute(self, workflow: WorkflowDefinition) -> Dict[str, Any]: | |
| """ | |
| Execute workflow and return result. | |
| Args: | |
| workflow: WorkflowDefinition to execute | |
| Returns: | |
| Dict with success status, result, and execution trace | |
| """ | |
| start_time = time.time() | |
| trace = [] | |
| try: | |
| # Build DAG | |
| graph = self._build_dag(workflow) | |
| logger.info(f"Built DAG with {len(graph.nodes)} nodes") | |
| # Topological sort for execution order | |
| try: | |
| execution_order = list(nx.topological_sort(graph)) | |
| except nx.NetworkXError as e: | |
| return { | |
| "success": False, | |
| "error": f"Invalid DAG: {e}", | |
| "trace": trace | |
| } | |
| # Execute tasks | |
| results = {} | |
| task_map = {task.id: task for task in workflow.tasks} | |
| # Process tasks in dependency order | |
| pending_tasks = set(execution_order) | |
| completed_tasks = set() | |
| while pending_tasks: | |
| # Find tasks ready to execute (all dependencies complete) | |
| ready_tasks = [ | |
| tid for tid in pending_tasks | |
| if all(dep in completed_tasks for dep in task_map[tid].depends_on) | |
| ] | |
| if not ready_tasks: | |
| break # Deadlock or error | |
| # Execute ready tasks in parallel (up to max_parallel) | |
| batch_size = min(len(ready_tasks), workflow.max_parallel) | |
| batch = ready_tasks[:batch_size] | |
| logger.info(f"Executing batch: {batch}") | |
| with ThreadPoolExecutor(max_workers=batch_size) as executor: | |
| futures = { | |
| executor.submit( | |
| self._execute_task, | |
| task_map[tid], | |
| results, | |
| trace | |
| ): tid | |
| for tid in batch | |
| } | |
| # Wait for completion | |
| for future in futures: | |
| tid = futures[future] | |
| try: | |
| task_timeout = task_map[tid].timeout_seconds | |
| result = future.result(timeout=task_timeout) | |
| results[tid] = result | |
| completed_tasks.add(tid) | |
| pending_tasks.remove(tid) | |
| except FutureTimeoutError: | |
| error_msg = f"Task {tid} timed out" | |
| logger.error(error_msg) | |
| trace.append({ | |
| "task_id": tid, | |
| "status": "timeout", | |
| "error": error_msg | |
| }) | |
| # Mark as failed but continue with other tasks | |
| results[tid] = {"error": error_msg} | |
| completed_tasks.add(tid) | |
| pending_tasks.remove(tid) | |
| except Exception as e: | |
| error_msg = f"Task {tid} failed: {e}" | |
| logger.error(error_msg) | |
| trace.append({ | |
| "task_id": tid, | |
| "status": "error", | |
| "error": str(e) | |
| }) | |
| results[tid] = {"error": str(e)} | |
| completed_tasks.add(tid) | |
| pending_tasks.remove(tid) | |
| # Check workflow timeout | |
| if time.time() - start_time > workflow.timeout_seconds: | |
| return { | |
| "success": False, | |
| "error": "Workflow timeout exceeded", | |
| "trace": trace, | |
| "partial_results": results | |
| } | |
| # Get final result | |
| final_result = results.get(workflow.final_task) | |
| if final_result is None: | |
| return { | |
| "success": False, | |
| "error": f"Final task {workflow.final_task} did not execute", | |
| "trace": trace, | |
| "results": results | |
| } | |
| execution_time = time.time() - start_time | |
| result = { | |
| "success": True, | |
| "result": final_result, | |
| "execution_time": execution_time, | |
| "trace": trace, | |
| "all_results": results | |
| } | |
| # Save successful workflow execution | |
| workflow_id = f"{workflow.name}_{int(time.time())}" | |
| self.store.save_workflow(workflow_id, workflow, result) | |
| return result | |
| except Exception as e: | |
| logger.error(f"Workflow execution failed: {e}", exc_info=True) | |
| return { | |
| "success": False, | |
| "error": str(e), | |
| "trace": trace | |
| } | |
| def _build_dag(self, workflow: WorkflowDefinition) -> nx.DiGraph: | |
| """Build NetworkX directed graph from workflow.""" | |
| graph = nx.DiGraph() | |
| # Add nodes | |
| for task in workflow.tasks: | |
| graph.add_node(task.id) | |
| # Add edges (dependencies) | |
| for task in workflow.tasks: | |
| for dep in task.depends_on: | |
| graph.add_edge(dep, task.id) # Edge from dependency to task | |
| return graph | |
| def _execute_task( | |
| self, | |
| task: WorkflowTask, | |
| results: Dict[str, Any], | |
| trace: List[Dict[str, Any]] | |
| ) -> Any: | |
| """ | |
| Execute single task with retry logic. | |
| Args: | |
| task: Task to execute | |
| results: Shared results dict (for accessing dependency outputs) | |
| trace: Shared trace list | |
| Returns: | |
| Task result | |
| """ | |
| logger.info(f"Executing task: {task.id} (tool: {task.tool})") | |
| trace.append({ | |
| "task_id": task.id, | |
| "tool": task.tool, | |
| "status": "started", | |
| "timestamp": time.time() | |
| }) | |
| # Get tool | |
| tool = self.tools_registry.get(task.tool) | |
| if not tool: | |
| error_msg = f"Tool not found: {task.tool}" | |
| logger.error(error_msg) | |
| trace.append({ | |
| "task_id": task.id, | |
| "status": "error", | |
| "error": error_msg | |
| }) | |
| raise ValueError(error_msg) | |
| # Resolve arguments (may reference previous task results) | |
| args = self._resolve_args(task.args, results) | |
| # Execute with retry | |
| last_error = None | |
| for attempt in range(task.max_retries + 1): | |
| try: | |
| result = tool.forward(**args) | |
| # Parse result if it's JSON string | |
| if isinstance(result, str): | |
| try: | |
| result = json.loads(result) | |
| except json.JSONDecodeError: | |
| pass # Keep as string | |
| trace.append({ | |
| "task_id": task.id, | |
| "status": "completed", | |
| "attempt": attempt + 1, | |
| "timestamp": time.time() | |
| }) | |
| logger.info(f"Task {task.id} completed successfully") | |
| return result | |
| except Exception as e: | |
| last_error = e | |
| logger.warning( | |
| f"Task {task.id} attempt {attempt + 1}/{task.max_retries + 1} failed: {e}" | |
| ) | |
| if attempt < task.max_retries and task.retry_on_failure: | |
| time.sleep(1 * (2 ** attempt)) # Exponential backoff | |
| continue | |
| else: | |
| trace.append({ | |
| "task_id": task.id, | |
| "status": "failed", | |
| "error": str(e), | |
| "attempts": attempt + 1 | |
| }) | |
| raise | |
| # Should not reach here, but for safety | |
| if last_error: | |
| raise last_error | |
| else: | |
| raise RuntimeError(f"Task {task.id} failed without exception") | |
| def _resolve_args(self, args: Dict[str, Any], results: Dict[str, Any]) -> Dict[str, Any]: | |
| """ | |
| Resolve arguments that reference previous task results. | |
| Supports syntax: "${task_id}" or "${task_id.field}" | |
| Args: | |
| args: Raw arguments | |
| results: Previous task results | |
| Returns: | |
| Resolved arguments | |
| """ | |
| resolved = {} | |
| for key, value in args.items(): | |
| if isinstance(value, str) and value.startswith("${") and value.endswith("}"): | |
| # Reference to previous task result | |
| ref = value[2:-1] # Remove ${ and } | |
| parts = ref.split(".") | |
| # Get task result | |
| task_id = parts[0] | |
| if task_id not in results: | |
| raise ValueError(f"Referenced task {task_id} not yet executed") | |
| result = results[task_id] | |
| # Navigate nested fields | |
| for part in parts[1:]: | |
| if isinstance(result, dict): | |
| result = result.get(part) | |
| else: | |
| raise ValueError(f"Cannot access field {part} on {type(result)}") | |
| resolved[key] = result | |
| else: | |
| resolved[key] = value | |
| return resolved | |