from __future__ import annotations from dataclasses import dataclass from typing import Any from dataflow.graph import DataGraph from dataflow.nodes_base import NodeInstance from dataflow.ui.vueflow_canvas import VueFlowCanvas from .text_to_image import TextToImageNode @dataclass(slots=True) class GraphRuntime: graph: DataGraph canvas: VueFlowCanvas def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]: """Return all upstream nodes that belong to the chain of root. Order is upstream first, root last. This is only for UI purposes (spinners, progress, syncing), not to force execution. """ result: list[NodeInstance] = [] visited: set[str] = set() connections = getattr(self.graph, "connections", []) or [] def visit(node: NodeInstance) -> None: node_id = getattr(node, "node_id", None) if node_id is None or node_id in visited: return visited.add(node_id) for conn in connections: try: if conn.end_node is node: visit(conn.start_node) except AttributeError: continue result.append(node) visit(root) return result async def execute_node(self, node: NodeInstance | str) -> None: """Execute a node and keep the canvas in sync with its state. Rules: - Clicked node is always reset and re run. - Upstream nodes are never reset here and may reuse cached data. - Which nodes actually execute is decided by DataGraph and node logic. - As soon as a node finishes, it is synced to the UI. """ import asyncio if isinstance(node, str): node_id = node node_obj = self.graph.nodes.get(node_id) if node_obj is None: return else: node_obj = node node_id = node.node_id execution_chain = self._get_execution_chain(node_obj) # show spinner on all nodes in the chain for n in execution_chain: nid = getattr(n, "node_id", None) if nid: self.canvas.set_node_processing(nid, True) # progress polling for nodes that expose progress_value stop_progress = False progress_tasks: list[asyncio.Task] = [] async def progress_updater(n: NodeInstance, nid: str) -> None: last_value: Any = None while not stop_progress: await asyncio.sleep(0.1) if not hasattr(n, "progress_value"): continue current = getattr(n, "progress_value", None) message = getattr(n, "progress_message", None) if current is None: continue if current != last_value: self.canvas.update_node_progress(nid, current, message) last_value = current for n in execution_chain: nid = getattr(n, "node_id", None) if nid and hasattr(n, "progress_value"): progress_tasks.append(asyncio.create_task(progress_updater(n, nid))) # callback from DataGraph after each node is executed async def on_node_executed(executed_node: NodeInstance) -> None: # Only nodes that actually ran will call this. await self._sync_node_to_ui(executed_node) # save previous callback so we can restore it previous_cb = getattr(self.graph, "_on_node_executed", None) try: print(f"Runtime: Executing {node_id}...") # clicked node is always reset, upstream nodes are not if hasattr(node_obj, "reset_node"): node_obj.reset_node() # register our per node callback self.graph.set_on_node_executed(on_node_executed) # let DataGraph drive which nodes actually execute await self.graph.execute(node_obj) # one more sync for all nodes in the chain, in case some did not run for n in execution_chain: await self._sync_node_to_ui(n) # nice "complete" flash for the clicked node if it has progress if hasattr(node_obj, "progress_value"): self.canvas.update_node_progress(node_id, 1.0, "Complete") await asyncio.sleep(0.3) except Exception as e: print(f"Runtime execution failed: {e}") import traceback traceback.print_exc() finally: # restore previous graph callback self.graph.set_on_node_executed(previous_cb) # stop progress updaters stop_progress = True for t in progress_tasks: t.cancel() try: await t except asyncio.CancelledError: pass # hide spinner on all nodes in the chain for n in execution_chain: nid = getattr(n, "node_id", None) if nid: self.canvas.set_node_processing(nid, False) # reset progress on the clicked node if hasattr(node_obj, "progress_value"): self.canvas.update_node_progress(node_id, 0.0, None) async def _sync_node_to_ui(self, node: NodeInstance) -> None: """Push relevant node state back to the Vue nodes.""" if isinstance(node, TextToImageNode): image_src = "" if node.image_src is None else str(node.image_src) values: dict[str, Any] = { "image": image_src, "error": node.error or None, } self.canvas.update_node_values(node.node_id, values) # add other node types here as needed