# Source: chmielvu, commit 4454066 (verified) - "feat: add production refinements (Phase 1-3)"
"""
Workflow executor with DAG orchestration and parallel execution
"""
from typing import Dict, Any, List, Set, Optional
from .schema import WorkflowDefinition, WorkflowTask
from .persistence import WorkflowStore
import networkx as nx
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
import logging
import time
import json
logger = logging.getLogger(__name__)
class WorkflowExecutor:
    """
    Executes workflows as DAGs with parallel task execution.

    Features:
    - Dependency resolution via topological sort
    - Parallel execution with configurable concurrency
    - Error handling and retry logic
    - Execution trace for debugging

    A task that fails or times out is recorded in the trace and its result
    slot holds ``{"error": <message>}``. Remaining tasks (including its
    dependents, which receive that error dict) keep running, but the workflow
    only reports success if the final task itself succeeded.
    """

    def __init__(
        self,
        tools_registry: Dict[str, Any],
        max_parallel: int = 3,
        timeout: int = 600,
        memory=None,
        store_path: str = "./workflow_cache"
    ):
        """
        Initialize workflow executor.

        Args:
            tools_registry: Map of tool_name -> tool instance. Each tool is
                invoked as ``tool.forward(**args)``.
            max_parallel: Fallback concurrency limit, used when the workflow
                does not set a positive ``max_parallel`` of its own.
            timeout: Fallback workflow timeout in seconds, used when the
                workflow does not set ``timeout_seconds``.
            memory: Optional agent memory for context (stored only; not read
                by this class).
            store_path: Directory for workflow persistence.
        """
        self.tools_registry = tools_registry
        self.max_parallel = max_parallel
        self.timeout = timeout
        self.memory = memory
        self.store = WorkflowStore(store_path=store_path)

    def execute(self, workflow: WorkflowDefinition) -> Dict[str, Any]:
        """
        Execute workflow and return result.

        Tasks run in batches: each batch holds up to ``max_parallel`` tasks
        whose dependencies have all finished, and the tasks of one batch run
        concurrently in a thread pool.

        Args:
            workflow: WorkflowDefinition to execute

        Returns:
            On success: dict with ``success=True``, ``result`` (final task
            output), ``execution_time``, ``trace`` and ``all_results``.
            On failure: dict with ``success=False``, ``error`` and ``trace``,
            plus ``partial_results`` (workflow timeout) or ``results``
            (final task failed or did not run). Exceptions are never
            propagated to the caller; they are reported via ``error``.
        """
        start_time = time.time()
        trace: List[Dict[str, Any]] = []
        try:
            graph = self._build_dag(workflow)
            logger.info("Built DAG with %d nodes", len(graph.nodes))

            task_map = {task.id: task for task in workflow.tasks}
            # _build_dag adds a node for every dependency id, so an id that
            # names no task would otherwise surface later as a bare KeyError.
            unknown = [node for node in graph.nodes if node not in task_map]
            if unknown:
                return {
                    "success": False,
                    "error": f"Invalid DAG: unknown dependencies {unknown}",
                    "trace": trace
                }

            # Topological sort for execution order. A cycle raises
            # NetworkXUnfeasible, which is NOT a subclass of NetworkXError.
            try:
                execution_order = list(nx.topological_sort(graph))
            except (nx.NetworkXUnfeasible, nx.NetworkXError) as e:
                return {
                    "success": False,
                    "error": f"Invalid DAG: {e}",
                    "trace": trace
                }

            # Fall back to the executor-level defaults when the workflow
            # leaves these unset (0/None would otherwise crash below).
            max_parallel = max(1, workflow.max_parallel or self.max_parallel)
            workflow_timeout = workflow.timeout_seconds or self.timeout

            results: Dict[str, Any] = {}
            failed_tasks: Set[str] = set()
            completed_tasks: Set[str] = set()
            pending_tasks = set(execution_order)

            while pending_tasks:
                # Scan in topological order so batch composition is deterministic.
                ready_tasks = [
                    tid for tid in execution_order
                    if tid in pending_tasks
                    and all(dep in completed_tasks for dep in task_map[tid].depends_on)
                ]
                if not ready_tasks:
                    break  # Deadlock; should not happen for a validated DAG

                batch = ready_tasks[:max_parallel]
                logger.info("Executing batch: %s", batch)
                self._run_batch(batch, task_map, results, failed_tasks, trace)

                # Failed tasks still count as completed so other work continues.
                completed_tasks.update(batch)
                pending_tasks.difference_update(batch)

                # Check workflow timeout
                if time.time() - start_time > workflow_timeout:
                    return {
                        "success": False,
                        "error": "Workflow timeout exceeded",
                        "trace": trace,
                        "partial_results": results
                    }

            final_task = workflow.final_task
            if final_task in failed_tasks:
                return {
                    "success": False,
                    "error": f"Final task {final_task} failed: {results[final_task]['error']}",
                    "trace": trace,
                    "results": results
                }
            # Membership, not `is None`: a task may legitimately return None.
            if final_task not in results:
                return {
                    "success": False,
                    "error": f"Final task {final_task} did not execute",
                    "trace": trace,
                    "results": results
                }

            execution_time = time.time() - start_time
            result = {
                "success": True,
                "result": results[final_task],
                "execution_time": execution_time,
                "trace": trace,
                "all_results": results
            }

            # Persist the successful run. A storage failure is logged rather
            # than discarding an otherwise successful result.
            workflow_id = f"{workflow.name}_{int(time.time())}"
            try:
                self.store.save_workflow(workflow_id, workflow, result)
            except Exception:
                logger.exception("Failed to persist workflow %s", workflow_id)
            return result

        except Exception as e:
            logger.exception("Workflow execution failed: %s", e)
            return {
                "success": False,
                "error": str(e),
                "trace": trace
            }

    def _run_batch(
        self,
        batch: List[str],
        task_map: Dict[str, WorkflowTask],
        results: Dict[str, Any],
        failed_tasks: Set[str],
        trace: List[Dict[str, Any]]
    ) -> None:
        """
        Run one batch of ready tasks concurrently and record every outcome.

        Stores each task's output (or ``{"error": msg}``) in ``results`` and
        adds failed or timed-out task ids to ``failed_tasks``.
        """
        executor = ThreadPoolExecutor(max_workers=len(batch))
        batch_start = time.time()
        try:
            futures = {
                executor.submit(self._execute_task, task_map[tid], results, trace): tid
                for tid in batch
            }
            for future, tid in futures.items():
                task_timeout = task_map[tid].timeout_seconds
                # All tasks were submitted together, so each timeout is measured
                # from the batch start, not from when we begin waiting on it.
                remaining = (
                    None if task_timeout is None
                    else max(0.0, batch_start + task_timeout - time.time())
                )
                try:
                    results[tid] = future.result(timeout=remaining)
                except FutureTimeoutError:
                    error_msg = f"Task {tid} timed out"
                    logger.error(error_msg)
                    trace.append({
                        "task_id": tid,
                        "status": "timeout",
                        "error": error_msg
                    })
                    results[tid] = {"error": error_msg}
                    failed_tasks.add(tid)
                except Exception as e:
                    error_msg = f"Task {tid} failed: {e}"
                    logger.error(error_msg)
                    trace.append({
                        "task_id": tid,
                        "status": "error",
                        "error": str(e)
                    })
                    results[tid] = {"error": str(e)}
                    failed_tasks.add(tid)
        finally:
            # Don't block on timed-out tasks (a `with` block would wait for
            # them, defeating the timeout). Their threads cannot be killed and
            # keep running in the background; they may still append to `trace`.
            executor.shutdown(wait=False)

    def _build_dag(self, workflow: WorkflowDefinition) -> nx.DiGraph:
        """
        Build a NetworkX directed graph from the workflow.

        Every task id becomes a node; each dependency becomes an edge
        ``dep -> task``. A dependency id with no matching task is added as a
        node implicitly by NetworkX.
        """
        graph = nx.DiGraph()
        graph.add_nodes_from(task.id for task in workflow.tasks)
        graph.add_edges_from(
            (dep, task.id)  # Edge from dependency to task
            for task in workflow.tasks
            for dep in task.depends_on
        )
        return graph

    def _execute_task(
        self,
        task: WorkflowTask,
        results: Dict[str, Any],
        trace: List[Dict[str, Any]]
    ) -> Any:
        """
        Execute a single task with retry logic.

        Args:
            task: Task to execute
            results: Shared results dict (read to resolve dependency outputs)
            trace: Shared trace list (appended to)

        Returns:
            The tool's return value; a string result is parsed as JSON when
            it is valid JSON, otherwise returned unchanged.

        Raises:
            ValueError: If the tool is not registered or an argument
                reference cannot be resolved.
            Exception: The tool's last exception once retries are exhausted
                (or immediately when ``retry_on_failure`` is false).
        """
        logger.info(f"Executing task: {task.id} (tool: {task.tool})")
        trace.append({
            "task_id": task.id,
            "tool": task.tool,
            "status": "started",
            "timestamp": time.time()
        })

        tool = self.tools_registry.get(task.tool)
        # `is None` so a registered tool object that happens to be falsy
        # is not mistaken for a missing one.
        if tool is None:
            error_msg = f"Tool not found: {task.tool}"
            logger.error(error_msg)
            trace.append({
                "task_id": task.id,
                "status": "error",
                "error": error_msg
            })
            raise ValueError(error_msg)

        # Resolve arguments (may reference previous task results)
        args = self._resolve_args(task.args, results)

        last_error = None
        for attempt in range(task.max_retries + 1):
            try:
                result = tool.forward(**args)
                # Parse result if it's a JSON string
                if isinstance(result, str):
                    try:
                        result = json.loads(result)
                    except json.JSONDecodeError:
                        pass  # Keep as string
                trace.append({
                    "task_id": task.id,
                    "status": "completed",
                    "attempt": attempt + 1,
                    "timestamp": time.time()
                })
                logger.info(f"Task {task.id} completed successfully")
                return result
            except Exception as e:
                last_error = e
                logger.warning(
                    f"Task {task.id} attempt {attempt + 1}/{task.max_retries + 1} failed: {e}"
                )
                if attempt < task.max_retries and task.retry_on_failure:
                    time.sleep(1 * (2 ** attempt))  # Exponential backoff: 1s, 2s, 4s, ...
                    continue
                trace.append({
                    "task_id": task.id,
                    "status": "failed",
                    "error": str(e),
                    "attempts": attempt + 1
                })
                raise

        # Only reachable when max_retries < 0 (the loop body never runs).
        if last_error:
            raise last_error
        raise RuntimeError(f"Task {task.id} failed without exception")

    def _resolve_args(self, args: Dict[str, Any], results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve arguments that reference previous task results.

        Supports syntax: "${task_id}", "${task_id.field}" and, for list/tuple
        results, integer indices such as "${task_id.items.0}". Only
        top-level string values that are exactly one reference are resolved;
        all other values pass through unchanged. A missing dict key resolves
        to None.

        Args:
            args: Raw arguments
            results: Previous task results

        Returns:
            Resolved arguments

        Raises:
            ValueError: If the referenced task has no result yet, or a path
                segment cannot be applied to the value reached so far.
        """
        resolved = {}
        for key, value in args.items():
            if not (isinstance(value, str) and value.startswith("${") and value.endswith("}")):
                resolved[key] = value
                continue

            ref = value[2:-1]  # Remove ${ and }
            task_id, *path = ref.split(".")
            if task_id not in results:
                raise ValueError(f"Referenced task {task_id} not yet executed")

            result = results[task_id]
            for part in path:
                if isinstance(result, dict):
                    result = result.get(part)
                elif isinstance(result, (list, tuple)) and part.isdigit():
                    index = int(part)
                    if index >= len(result):
                        raise ValueError(
                            f"Index {index} out of range for result of task {task_id}"
                        )
                    result = result[index]
                else:
                    raise ValueError(f"Cannot access field {part} on {type(result)}")
            resolved[key] = result
        return resolved