# Listing metadata from the page this file was copied from:
# status "Sleeping"; file size 11,280 bytes; ref 4454066.
"""
Workflow executor with DAG orchestration and parallel execution
"""
from typing import Dict, Any, List, Set, Optional
from .schema import WorkflowDefinition, WorkflowTask
from .persistence import WorkflowStore
import networkx as nx
from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError
import logging
import time
import json
logger = logging.getLogger(name=__name__)  # one logger per module, named after the module
class WorkflowExecutor:
    """
    Executes workflows as DAGs with parallel task execution.

    Features:
    - Dependency resolution via topological sort
    - Parallel execution with configurable concurrency
    - Error handling and retry logic
    - Execution trace for debugging
    """

    def __init__(
        self,
        tools_registry: Dict[str, Any],
        max_parallel: int = 3,
        timeout: int = 600,
        memory=None,
        store_path: str = "./workflow_cache"
    ):
        """
        Initialize workflow executor.

        Args:
            tools_registry: Map of tool_name -> tool_instance. Tools are
                invoked as ``tool.forward(**args)``.
            max_parallel: Maximum parallel tasks. NOTE(review): stored but not
                read by ``execute``, which uses ``workflow.max_parallel``.
            timeout: Default workflow timeout. NOTE(review): stored but not
                read by ``execute``, which uses ``workflow.timeout_seconds``.
            memory: Optional agent memory for context (stored only; not read
                anywhere in this class).
            store_path: Directory for workflow persistence, passed to
                ``WorkflowStore``.
        """
        self.tools_registry = tools_registry
        self.max_parallel = max_parallel
        self.timeout = timeout
        self.memory = memory
        self.store = WorkflowStore(store_path=store_path)

    def execute(self, workflow: WorkflowDefinition) -> Dict[str, Any]:
        """
        Execute workflow and return result.

        Tasks run in batches. Each batch contains up to
        ``workflow.max_parallel`` tasks whose dependencies have all finished,
        run concurrently in a thread pool. A task that fails or times out is
        recorded with the result ``{"error": ...}``, and the remaining tasks
        keep running. Dependents of a failed task still run and see that
        error dict as the dependency's result.

        Args:
            workflow: WorkflowDefinition to execute

        Returns:
            Dict with ``success``.
            On success it also has ``result`` (output of
            ``workflow.final_task``), ``execution_time``, ``trace`` and
            ``all_results``.
            On failure it has ``error`` and ``trace``, plus
            ``partial_results`` (workflow timeout) or ``results`` (final task
            missing or failed).
            A worker abandoned after a task timeout may still append to
            ``trace`` after this method returns.
        """
        start_time = time.time()
        trace: List[Dict[str, Any]] = []

        try:
            # Build DAG (raises ValueError for unknown dependencies, which is
            # reported by the outer handler).
            graph = self._build_dag(workflow)
            logger.info(f"Built DAG with {len(graph.nodes)} nodes")

            # topological_sort reports a cycle with NetworkXUnfeasible, which
            # is NOT a subclass of NetworkXError; catch the common base class.
            try:
                execution_order = list(nx.topological_sort(graph))
            except nx.NetworkXException as e:
                return {
                    "success": False,
                    "error": f"Invalid DAG: {e}",
                    "trace": trace
                }

            results: Dict[str, Any] = {}
            task_map = {task.id: task for task in workflow.tasks}

            pending_tasks = set(execution_order)
            completed_tasks: Set[str] = set()
            failed_tasks: Set[str] = set()
            # ThreadPoolExecutor rejects max_workers < 1.
            max_batch = max(1, workflow.max_parallel)

            while pending_tasks:
                # Tasks whose dependencies have all finished. Walking
                # execution_order (not the set) keeps batch membership
                # deterministic across runs.
                ready_tasks = [
                    tid for tid in execution_order
                    if tid in pending_tasks
                    and all(dep in completed_tasks for dep in task_map[tid].depends_on)
                ]
                if not ready_tasks:
                    break  # Deadlock or error

                batch = ready_tasks[:max_batch]
                logger.info(f"Executing batch: {batch}")
                failed_tasks |= self._run_batch(batch, task_map, results, trace)
                # Every task in the batch has an entry in results now
                # (output or error), so it counts as finished.
                completed_tasks.update(batch)
                pending_tasks.difference_update(batch)

                # Check workflow timeout (only between batches)
                if time.time() - start_time > workflow.timeout_seconds:
                    return {
                        "success": False,
                        "error": "Workflow timeout exceeded",
                        "trace": trace,
                        "partial_results": results
                    }

            final_task = workflow.final_task
            # Test membership rather than `results.get(...) is None`: a tool
            # may legitimately return None (e.g. the JSON string "null").
            if final_task not in results:
                return {
                    "success": False,
                    "error": f"Final task {final_task} did not execute",
                    "trace": trace,
                    "results": results
                }
            if final_task in failed_tasks:
                return {
                    "success": False,
                    "error": f"Final task {final_task} failed: {results[final_task]['error']}",
                    "trace": trace,
                    "results": results
                }
            final_result = results[final_task]

            execution_time = time.time() - start_time
            result = {
                "success": True,
                "result": final_result,
                "execution_time": execution_time,
                "trace": trace,
                "all_results": results
            }

            # Save successful workflow execution. Persistence is a side effect
            # of a run that already succeeded, so a storage error is logged
            # instead of turning the run into a failure.
            workflow_id = f"{workflow.name}_{int(time.time())}"
            try:
                self.store.save_workflow(workflow_id, workflow, result)
            except Exception:
                logger.exception(f"Failed to save workflow execution {workflow_id}")
            return result

        except Exception as e:
            logger.error(f"Workflow execution failed: {e}", exc_info=True)
            return {
                "success": False,
                "error": str(e),
                "trace": trace
            }

    def _run_batch(
        self,
        batch: List[str],
        task_map: Dict[str, WorkflowTask],
        results: Dict[str, Any],
        trace: List[Dict[str, Any]]
    ) -> Set[str]:
        """
        Run one batch of ready tasks concurrently and collect their outcomes.

        Every task in ``batch`` gets an entry in ``results``. On success the
        entry is the task's output; on timeout or failure it is
        ``{"error": ...}``.

        Args:
            batch: Non-empty list of task IDs whose dependencies are finished.
            task_map: Map of task ID -> WorkflowTask.
            results: Shared results dict (written here, read by workers).
            trace: Shared trace list (appended to here and by workers).

        Returns:
            IDs of the tasks in ``batch`` that timed out or raised.
        """
        failed: Set[str] = set()
        timed_out = False
        # One worker per task, so every task starts immediately.
        executor = ThreadPoolExecutor(max_workers=len(batch))
        try:
            submitted_at = time.time()
            futures = {
                executor.submit(self._execute_task, task_map[tid], results, trace): tid
                for tid in batch
            }

            for future, tid in futures.items():
                task_timeout = task_map[tid].timeout_seconds
                # Count each task's timeout from submission, not from when we
                # start waiting on it. Otherwise later futures would inherit
                # the wait time of earlier ones.
                if task_timeout is None:
                    remaining = None
                else:
                    remaining = max(0.0, submitted_at + task_timeout - time.time())
                try:
                    results[tid] = future.result(timeout=remaining)
                except FutureTimeoutError:
                    timed_out = True
                    error_msg = f"Task {tid} timed out"
                    logger.error(error_msg)
                    trace.append({
                        "task_id": tid,
                        "status": "timeout",
                        "error": error_msg
                    })
                    # Mark as failed but continue with other tasks
                    results[tid] = {"error": error_msg}
                    failed.add(tid)
                except Exception as e:
                    error_msg = f"Task {tid} failed: {e}"
                    logger.error(error_msg)
                    trace.append({
                        "task_id": tid,
                        "status": "error",
                        "error": str(e)
                    })
                    results[tid] = {"error": str(e)}
                    failed.add(tid)
        finally:
            # A running thread cannot be interrupted. Joining a timed-out
            # worker (as the `with` form does) would block until it finishes
            # and defeat the timeout. In that case skip the join and let the
            # worker finish in the background.
            executor.shutdown(wait=not timed_out)
        return failed

    def _build_dag(self, workflow: WorkflowDefinition) -> nx.DiGraph:
        """
        Build NetworkX directed graph from workflow.

        Nodes are task IDs. An edge ``dep -> task`` means ``task`` depends on
        ``dep``.

        Raises:
            ValueError: if a task depends on an ID that is not a task in the
                workflow.
        """
        graph = nx.DiGraph()

        # Add nodes (in declaration order, which seeds topological order)
        for task in workflow.tasks:
            graph.add_node(task.id)

        # Add edges (dependencies). Reject unknown IDs here: add_edge would
        # silently create a phantom node with no task behind it.
        for task in workflow.tasks:
            for dep in task.depends_on:
                if dep not in graph:
                    raise ValueError(f"Task {task.id} depends on unknown task {dep}")
                graph.add_edge(dep, task.id)  # Edge from dependency to task

        return graph

    def _execute_task(
        self,
        task: WorkflowTask,
        results: Dict[str, Any],
        trace: List[Dict[str, Any]]
    ) -> Any:
        """
        Execute single task with retry logic.

        A string result from the tool is parsed as JSON when possible and
        otherwise returned unchanged. Retries happen only when
        ``task.retry_on_failure`` is set, with a 1s, 2s, 4s, ... backoff.

        Args:
            task: Task to execute
            results: Shared results dict (for accessing dependency outputs)
            trace: Shared trace list

        Returns:
            Task result

        Raises:
            ValueError: if ``task.tool`` is not in the tools registry, or an
                argument reference cannot be resolved.
            Exception: the tool's last exception once attempts are exhausted.
        """
        logger.info(f"Executing task: {task.id} (tool: {task.tool})")
        trace.append({
            "task_id": task.id,
            "tool": task.tool,
            "status": "started",
            "timestamp": time.time()
        })

        # Get tool
        tool = self.tools_registry.get(task.tool)
        if tool is None:
            error_msg = f"Tool not found: {task.tool}"
            logger.error(error_msg)
            trace.append({
                "task_id": task.id,
                "status": "error",
                "error": error_msg
            })
            raise ValueError(error_msg)

        # Resolve arguments (may reference previous task results)
        args = self._resolve_args(task.args, results)

        # Execute with retry
        last_error = None
        for attempt in range(task.max_retries + 1):
            try:
                result = tool.forward(**args)

                # Parse result if it's JSON string
                if isinstance(result, str):
                    try:
                        result = json.loads(result)
                    except json.JSONDecodeError:
                        pass  # Not JSON: keep as string

                trace.append({
                    "task_id": task.id,
                    "status": "completed",
                    "attempt": attempt + 1,
                    "timestamp": time.time()
                })
                logger.info(f"Task {task.id} completed successfully")
                return result

            except Exception as e:
                last_error = e
                logger.warning(
                    f"Task {task.id} attempt {attempt + 1}/{task.max_retries + 1} failed: {e}"
                )
                if attempt < task.max_retries and task.retry_on_failure:
                    time.sleep(1 * (2 ** attempt))  # Exponential backoff
                    continue
                trace.append({
                    "task_id": task.id,
                    "status": "failed",
                    "error": str(e),
                    "attempts": attempt + 1
                })
                raise

        # Reached only when the loop never runs (max_retries < 0).
        if last_error:
            raise last_error
        raise RuntimeError(f"Task {task.id} failed without exception")

    def _resolve_args(self, args: Dict[str, Any], results: Dict[str, Any]) -> Dict[str, Any]:
        """
        Resolve arguments that reference previous task results.

        Supports syntax: "${task_id}" or "${task_id.field}" (fields may nest,
        e.g. "${task_id.a.b}"). Only top-level string values that are a whole
        reference are substituted. All other values pass through unchanged.
        A missing field on a dict result resolves to None.

        Args:
            args: Raw arguments
            results: Previous task results

        Returns:
            Resolved arguments

        Raises:
            ValueError: if the referenced task has no result yet, or a field
                is accessed on a non-dict value.
        """
        resolved = {}
        for key, value in args.items():
            if not (isinstance(value, str) and value.startswith("${") and value.endswith("}")):
                resolved[key] = value
                continue

            # Reference to previous task result
            task_id, *fields = value[2:-1].split(".")  # strip "${" and "}"
            if task_id not in results:
                raise ValueError(f"Referenced task {task_id} not yet executed")

            current = results[task_id]
            # Navigate nested fields
            for field_name in fields:
                if not isinstance(current, dict):
                    raise ValueError(f"Cannot access field {field_name} on {type(current)}")
                current = current.get(field_name)
            resolved[key] = current
        return resolved
# End of listing.