Spaces:
Running
Running
| from __future__ import annotations | |
| import asyncio | |
| import logging | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Set | |
| from velai.dataflow.enums import ConnectMultiplicity, DataPortState, NodeState, PortConnectionState | |
| from velai.dataflow.graph import DataGraph | |
| from velai.dataflow.nodes import NodeInstance | |
| from velai.nodes.base_node import BaseNode, BaseNodeData | |
| from velai.ui.node_progress import ProgressStateUpdater | |
| from velai.ui.vueflow_canvas import VueFlowCanvas | |
# Module-wide logger, named after this module so log output can be filtered per module.
logger: logging.Logger = logging.getLogger(__name__)
| class ExecutionGraph: | |
| root_node: NodeInstance | |
| nodes: Dict[str, NodeInstance] = field(default_factory=dict) | |
| # node_id -> set of upstream node_ids that must finish before this node starts | |
| dependencies: Dict[str, Set[str]] = field(default_factory=dict) | |
| nodes_to_execute: Set[str] = field(default_factory=set) | |
| class _ProgressManager: | |
| def __init__(self, canvas: VueFlowCanvas, updater: ProgressStateUpdater) -> None: | |
| self._canvas = canvas | |
| self._updater = updater | |
| self._stop_events: dict[str, asyncio.Event] = {} | |
| self._tasks: dict[str, asyncio.Task[None]] = {} | |
| def start(self, node: BaseNode[BaseNodeData]) -> None: | |
| node_id = node.node_id | |
| if node_id in self._tasks: | |
| return | |
| stop_event = asyncio.Event() | |
| self._stop_events[node_id] = stop_event | |
| self._tasks[node_id] = asyncio.create_task(self._poll(node, stop_event)) | |
| async def stop(self, node_id: str) -> None: | |
| stop_event = self._stop_events.pop(node_id, None) | |
| if stop_event is not None: | |
| stop_event.set() | |
| task = self._tasks.pop(node_id, None) | |
| if task is None: | |
| return | |
| task.cancel() | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| async def stop_all(self) -> None: | |
| for event in self._stop_events.values(): | |
| event.set() | |
| self._stop_events.clear() | |
| tasks = list(self._tasks.values()) | |
| self._tasks.clear() | |
| for task in tasks: | |
| task.cancel() | |
| for task in tasks: | |
| try: | |
| await task | |
| except asyncio.CancelledError: | |
| pass | |
| async def _poll(self, node: BaseNode[BaseNodeData], stop_event: asyncio.Event) -> None: | |
| last_value: Any = object() | |
| last_message: Any = object() | |
| while not stop_event.is_set(): | |
| await asyncio.sleep(0.1) | |
| current_value = node.data.progress_value | |
| current_message = node.data.progress_message | |
| if current_value != last_value or current_message != last_message: | |
| await self._updater.update(node) | |
| last_value = current_value | |
| last_message = current_message | |
@dataclass
class GraphRuntime:
    """Execute nodes of a ``DataGraph`` and mirror their state onto the canvas.

    ``execute_node`` collects the requested node plus everything upstream of it,
    decides which of those nodes must (re)run, and runs them in dependency
    order, with independent branches running concurrently. ``_node_tasks`` is
    shared by all concurrent ``execute_node`` calls so that a node is never
    processed by two runs at the same time.
    """

    graph: DataGraph
    canvas: VueFlowCanvas
    # node_id -> task currently executing that node, shared across concurrent runs
    _node_tasks: Dict[str, asyncio.Task[None]] = field(default_factory=dict, init=False)
    _progress: _ProgressManager = field(init=False)
    _progress_updater: ProgressStateUpdater = field(init=False)

    def __post_init__(self) -> None:
        # The progress manager pushes through the same updater used for state changes.
        self._progress_updater = ProgressStateUpdater(self.canvas)
        self._progress = _ProgressManager(self.canvas, self._progress_updater)

    def _resolve_node(self, node: NodeInstance | str) -> NodeInstance | None:
        """Return the node for *node*; ids are looked up in the graph (None if unknown)."""
        if isinstance(node, str):
            return self.graph.nodes.get(node)
        return node

    def _create_execution_graph(self, root: NodeInstance) -> ExecutionGraph:
        """Collect *root* and every node upstream of it into an ExecutionGraph.

        ``dependencies[node_id]`` receives the ids of the nodes feeding that node.

        Raises:
            ValueError: if a circular dependency is reachable from *root*.
        """
        exe_graph = ExecutionGraph(root_node=root)
        # node_id -> 0 while the node is on the DFS path, 1 once fully explored
        visited: Dict[str, int] = {}

        def visit(node: NodeInstance) -> None:
            node_id = node.node_id
            if node_id in visited:
                # Meeting a node that is still on the current DFS path means a cycle.
                if visited[node_id] == 0:
                    raise ValueError(f"Circular dependency detected at node {node_id}")
                return
            visited[node_id] = 0
            exe_graph.nodes[node_id] = node
            exe_graph.dependencies[node_id] = set()
            for conn in self.graph.upstream_of(node):
                upstream_node = conn.start_node
                exe_graph.dependencies[node_id].add(upstream_node.node_id)
                visit(upstream_node)
            visited[node_id] = 1

        visit(root)
        return exe_graph

    def _compute_execution_plan(self, exe_graph: ExecutionGraph) -> None:
        """Fill ``exe_graph.nodes_to_execute`` with the nodes that must run.

        A node is a seed when one of its DIRTY output ports is connected to
        another node inside *exe_graph*. Dirty outputs that feed only nodes
        outside the sub-graph are ignored. Every node downstream of a seed,
        within the sub-graph, is added as well.
        """
        # 1. Seed nodes: a dirty output that feeds a node inside this sub-graph.
        dirty_nodes: Set[str] = set()
        for node in exe_graph.nodes.values():
            # Identity set: equivalent to the `conn.start_port is output` test.
            dirty_port_ids = {id(o) for o in node.all_outputs() if o.state == DataPortState.DIRTY}
            if not dirty_port_ids:
                continue
            if any(
                id(conn.start_port) in dirty_port_ids and conn.end_node.node_id in exe_graph.nodes
                for conn in self.graph.downstream_of(node)
            ):
                dirty_nodes.add(node.node_id)

        # 2. Propagate downstream inside the sub-graph (invert the dependency map).
        downstream: Dict[str, Set[str]] = {nid: set() for nid in exe_graph.nodes}
        for nid, deps in exe_graph.dependencies.items():
            for dep in deps:
                if dep in downstream:
                    downstream[dep].add(nid)

        to_execute = set(dirty_nodes)
        # Only the reachable set matters, so a LIFO stack (O(1) pop) is enough.
        pending = list(dirty_nodes)
        while pending:
            nid = pending.pop()
            for child_id in downstream.get(nid, ()):
                if child_id not in to_execute:
                    to_execute.add(child_id)
                    pending.append(child_id)
        exe_graph.nodes_to_execute = to_execute

    async def execute_node(self, node: NodeInstance | str) -> None:
        """Run *node*, first (re)running whichever of its upstream nodes are stale.

        If the node is currently running or queued elsewhere, this waits for it.
        Afterwards it returns without re-running if nothing is left to execute.
        A node that needed no waiting is executed even when it is clean
        (explicit request). Exceptions other than cancellation are logged,
        not raised.
        """
        node_obj = self._resolve_node(node)
        if node_obj is None:
            return
        root_id = node_obj.node_id
        logger.info("Runtime: Preparing execution for %s...", root_id)

        waited = await self._wait_until_idle(node_obj)

        # Bound before `try` so cleanup can run even if graph construction fails.
        exe_graph: ExecutionGraph | None = None
        # Nodes this call moved to QUEUED; released again in `finally` if never run.
        queued_here: Set[str] = set()
        try:
            # 1. Create execution graph (raises ValueError on cycles)
            exe_graph = self._create_execution_graph(node_obj)
            # 2. Compute execution plan
            self._compute_execution_plan(exe_graph)
            # 3. Decide whether anything must run
            if not exe_graph.nodes_to_execute:
                if waited:
                    # The run we waited for left everything clean: request satisfied.
                    return
                # User explicitly asked to execute a clean node.
                exe_graph.nodes_to_execute.add(root_id)
            # 4. Reset the root's outputs when it is going to run
            if root_id in exe_graph.nodes_to_execute:
                node_obj.reset_outputs()
            # 5. Mark every planned node as QUEUED (unless another run owns it)
            for queued_id in exe_graph.nodes_to_execute:
                exe_node = exe_graph.nodes[queued_id]
                if exe_node.state in (NodeState.PROCESSING, NodeState.QUEUED):
                    continue
                exe_node.state = NodeState.QUEUED
                queued_here.add(queued_id)
                if isinstance(exe_node, BaseNode):
                    await exe_node.on_queue_for_execution()
                await self._progress_updater.update(exe_node)
            # 6. Run the execution graph
            await self._run_execution_graph(exe_graph)
        except Exception as e:
            logger.exception("Runtime execution failed: %s", e)
        finally:
            # Runs on success, failure and cancellation alike. A planned node can be
            # skipped ("became clean") and would otherwise stay QUEUED forever, which
            # makes later execute_node() calls wait on it indefinitely.
            if exe_graph is not None:
                await self._release_unexecuted(exe_graph, queued_here)

    async def _wait_until_idle(self, node: NodeInstance) -> bool:
        """Wait while *node* is running or queued; return True if any waiting happened."""
        waited = False
        while node.node_id in self._node_tasks or node.state in (NodeState.QUEUED, NodeState.PROCESSING):
            running = self._node_tasks.get(node.node_id)
            if running is not None:
                # asyncio.wait does not re-raise the task's exception. That failure
                # belongs to, and is reported by, the run that owns the task.
                await asyncio.wait({running})
            else:
                await asyncio.sleep(0.1)
            waited = True
        return waited

    async def _release_unexecuted(self, exe_graph: ExecutionGraph, node_ids: Set[str]) -> None:
        """Move nodes in *node_ids* that are still QUEUED and not running back to IDLE."""
        for node_id in node_ids:
            exe_node = exe_graph.nodes[node_id]
            if exe_node.state != NodeState.QUEUED or node_id in self._node_tasks:
                continue
            exe_node.state = NodeState.IDLE
            try:
                await self._progress_updater.update(exe_node)
            except Exception:
                logger.exception("Failed to refresh UI state for node %s", node_id)

    async def _run_execution_graph(self, exe_graph: ExecutionGraph) -> None:
        """Execute the sub-graph topologically, running independent branches in parallel.

        Each node gets exactly one local task, which awaits its dependencies'
        tasks and then hands the node to ``_ensure_node_executed`` (which also
        coordinates with runs started by other ``execute_node`` calls).
        """
        local_tasks: Dict[str, asyncio.Task[None]] = {}

        def schedule(node_id: str) -> asyncio.Task[None]:
            # Synchronous lookup-or-create: no await between the check and the
            # insert, so a node shared by several paths (diamond) is scheduled once.
            task = local_tasks.get(node_id)
            if task is None:
                task = asyncio.create_task(run(node_id))
                local_tasks[node_id] = task
            return task

        async def run(node_id: str) -> None:
            dep_tasks = [schedule(dep_id) for dep_id in exe_graph.dependencies[node_id]]
            if dep_tasks:
                await asyncio.gather(*dep_tasks)
            await self._ensure_node_executed(exe_graph.nodes[node_id], exe_graph)

        await schedule(exe_graph.root_node.node_id)

    async def _ensure_node_executed(self, node: NodeInstance, exe_graph: ExecutionGraph) -> None:
        """Execute *node* unless it is already running, unplanned, or no longer dirty."""
        node_id = node.node_id
        # Already being executed by some run: just wait for it.
        if node_id in self._node_tasks:
            logger.debug("Node %s is already executing, waiting...", node_id)
            await self._node_tasks[node_id]
            return
        if node_id not in exe_graph.nodes_to_execute:
            logger.debug("Node %s not in execution plan, skipping.", node_id)
            return
        # Another run may have cleaned this node meanwhile; the root is always run
        # because it may have been force-added to the plan.
        if node_id != exe_graph.root_node.node_id and not node.has_dirty_outputs():
            logger.debug("Node %s became clean, skipping.", node_id)
            return
        task = asyncio.create_task(self._execute_single_node(node))
        self._node_tasks[node_id] = task
        try:
            await task
        finally:
            self._node_tasks.pop(node_id, None)

    async def _execute_single_node(self, node: NodeInstance) -> None:
        """Run one node: QUEUED -> pull inputs -> PROCESSING -> process -> push -> IDLE.

        Every step after the node enters QUEUED is inside ``try``/``finally``. So a
        failure anywhere (including input pulling or UI updates) still returns the
        node to IDLE and stops its progress poller. For BaseNode instances the
        error text is stored in ``data.error_message``. The exception is re-raised.
        """
        node_id = node.node_id
        logger.info("Runtime: Executing node %s", node_id)
        node.state = NodeState.QUEUED
        try:
            if isinstance(node, BaseNode):
                await node.on_queue_for_execution()
            await self._progress_updater.update(node)

            self._pull_upstream_values(node)

            node.state = NodeState.PROCESSING
            if isinstance(node, BaseNode):
                self._progress.start(node)
            await self._progress_updater.update(node)

            await node.process()
            # Successful run: outputs become clean and consumers become dirty.
            self._push_downstream_values(node)
        except Exception as e:
            logger.exception("Node %s failed: %s", node_id, e)
            if isinstance(node, BaseNode):
                node.data.error_message = str(e)
            raise
        finally:
            node.state = NodeState.IDLE
            await self._progress.stop(node_id)
            await self._progress_updater.update(node)
            await self.sync_node_to_ui(node)

    def _pull_upstream_values(self, node: NodeInstance) -> None:
        """Copy upstream output values into *node*'s input ports and mark them CLEAN.

        - Unconnected inputs get ``None`` for SINGLE multiplicity, else ``[]``.
        - MULTIPLE inputs get the list of non-None upstream values.
        - Other connected inputs take the value from the last feeding connection.
        """
        incoming_conns = list(self.graph.upstream_of(node))
        for inp in node.all_inputs():
            feeds = [c for c in incoming_conns if c.end_port is inp]
            if not feeds:
                inp.value = None if inp.schema.multiplicity == ConnectMultiplicity.SINGLE else []
            elif inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE:
                inp.value = [c.start_port.value for c in feeds if c.start_port.value is not None]
            else:
                inp.value = feeds[-1].start_port.value
            inp.state = DataPortState.CLEAN
            inp.connection_state = PortConnectionState.CONNECTED if feeds else PortConnectionState.DISCONNECTED

    def _push_downstream_values(self, node: NodeInstance) -> None:
        """Mark *node*'s outputs CLEAN and every connected consumer port/node dirty."""
        for outp in node.all_outputs():
            outp.state = DataPortState.CLEAN
        for c in self.graph.downstream_of(node):
            # Consumers are dirtied unconditionally. Comparing old vs. new output
            # values could avoid needless downstream re-runs.
            c.end_port.state = DataPortState.DIRTY
            c.end_node.mark_dirty()

    async def sync_node_to_ui(self, node: NodeInstance) -> None:
        """Push ``node.get_state()`` to the canvas for BaseNode instances (if non-empty)."""
        if not isinstance(node, BaseNode):
            return
        values = node.get_state()
        if values:
            await self.canvas.update_node_values(node.node_id, values)