| | from __future__ import annotations |
| |
|
| | from dataclasses import dataclass |
| | from typing import Any |
| |
|
| | from dataflow.graph import DataGraph |
| | from dataflow.nodes_base import NodeInstance |
| | from dataflow.ui.vueflow_canvas import VueFlowCanvas |
| | from .text_to_image import TextToImageNode |
| |
|
| |
|
| | @dataclass(slots=True) |
| | class GraphRuntime: |
| | graph: DataGraph |
| | canvas: VueFlowCanvas |
| |
|
| | def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]: |
| | """Return all upstream nodes that belong to the chain of root. |
| | |
| | Order is upstream first, root last. This is only for UI purposes |
| | (spinners, progress, syncing), not to force execution. |
| | """ |
| | result: list[NodeInstance] = [] |
| | visited: set[str] = set() |
| | connections = getattr(self.graph, "connections", []) or [] |
| |
|
| | def visit(node: NodeInstance) -> None: |
| | node_id = getattr(node, "node_id", None) |
| | if node_id is None or node_id in visited: |
| | return |
| | visited.add(node_id) |
| |
|
| | for conn in connections: |
| | try: |
| | if conn.end_node is node: |
| | visit(conn.start_node) |
| | except AttributeError: |
| | continue |
| |
|
| | result.append(node) |
| |
|
| | visit(root) |
| | return result |
| |
|
| | async def execute_node(self, node: NodeInstance | str) -> None: |
| | """Execute a node and keep the canvas in sync with its state. |
| | |
| | Rules: |
| | - Clicked node is always reset and re run. |
| | - Upstream nodes are never reset here and may reuse cached data. |
| | - Which nodes actually execute is decided by DataGraph and node logic. |
| | - As soon as a node finishes, it is synced to the UI. |
| | """ |
| | import asyncio |
| |
|
| | if isinstance(node, str): |
| | node_id = node |
| | node_obj = self.graph.nodes.get(node_id) |
| | if node_obj is None: |
| | return |
| | else: |
| | node_obj = node |
| | node_id = node.node_id |
| |
|
| | execution_chain = self._get_execution_chain(node_obj) |
| |
|
| | |
| | for n in execution_chain: |
| | nid = getattr(n, "node_id", None) |
| | if nid: |
| | self.canvas.set_node_processing(nid, True) |
| |
|
| | |
| | stop_progress = False |
| | progress_tasks: list[asyncio.Task] = [] |
| |
|
| | async def progress_updater(n: NodeInstance, nid: str) -> None: |
| | last_value: Any = None |
| | while not stop_progress: |
| | await asyncio.sleep(0.1) |
| | if not hasattr(n, "progress_value"): |
| | continue |
| | current = getattr(n, "progress_value", None) |
| | message = getattr(n, "progress_message", None) |
| | if current is None: |
| | continue |
| | if current != last_value: |
| | self.canvas.update_node_progress(nid, current, message) |
| | last_value = current |
| |
|
| | for n in execution_chain: |
| | nid = getattr(n, "node_id", None) |
| | if nid and hasattr(n, "progress_value"): |
| | progress_tasks.append(asyncio.create_task(progress_updater(n, nid))) |
| |
|
| | |
| | async def on_node_executed(executed_node: NodeInstance) -> None: |
| | |
| | await self._sync_node_to_ui(executed_node) |
| |
|
| | |
| | previous_cb = getattr(self.graph, "_on_node_executed", None) |
| |
|
| | try: |
| | print(f"Runtime: Executing {node_id}...") |
| |
|
| | |
| | if hasattr(node_obj, "reset_node"): |
| | node_obj.reset_node() |
| |
|
| | |
| | self.graph.set_on_node_executed(on_node_executed) |
| |
|
| | |
| | await self.graph.execute(node_obj) |
| |
|
| | |
| | for n in execution_chain: |
| | await self._sync_node_to_ui(n) |
| |
|
| | |
| | if hasattr(node_obj, "progress_value"): |
| | self.canvas.update_node_progress(node_id, 1.0, "Complete") |
| | await asyncio.sleep(0.3) |
| |
|
| | except Exception as e: |
| | print(f"Runtime execution failed: {e}") |
| | import traceback |
| | traceback.print_exc() |
| | finally: |
| | |
| | self.graph.set_on_node_executed(previous_cb) |
| |
|
| | |
| | stop_progress = True |
| | for t in progress_tasks: |
| | t.cancel() |
| | try: |
| | await t |
| | except asyncio.CancelledError: |
| | pass |
| |
|
| | |
| | for n in execution_chain: |
| | nid = getattr(n, "node_id", None) |
| | if nid: |
| | self.canvas.set_node_processing(nid, False) |
| |
|
| | |
| | if hasattr(node_obj, "progress_value"): |
| | self.canvas.update_node_progress(node_id, 0.0, None) |
| |
|
| | async def _sync_node_to_ui(self, node: NodeInstance) -> None: |
| | """Push relevant node state back to the Vue nodes.""" |
| | if isinstance(node, TextToImageNode): |
| | image_src = "" if node.image_src is None else str(node.image_src) |
| | values: dict[str, Any] = { |
| | "image": image_src, |
| | "error": node.error or None, |
| | } |
| | self.canvas.update_node_values(node.node_id, values) |
| | |
| |
|