File size: 5,920 Bytes
d868fac 3025bb3 d868fac 3025bb3 d868fac 3025bb3 d868fac 3025bb3 d868fac 691f45a 3025bb3 691f45a 3025bb3 691f45a d868fac 3025bb3 d868fac 3025bb3 691f45a 3025bb3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from dataflow.graph import DataGraph
from dataflow.nodes_base import NodeInstance
from dataflow.ui.vueflow_canvas import VueFlowCanvas
from .text_to_image import TextToImageNode
@dataclass(slots=True)
class GraphRuntime:
graph: DataGraph
canvas: VueFlowCanvas
def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]:
"""Return all upstream nodes that belong to the chain of root.
Order is upstream first, root last. This is only for UI purposes
(spinners, progress, syncing), not to force execution.
"""
result: list[NodeInstance] = []
visited: set[str] = set()
connections = getattr(self.graph, "connections", []) or []
def visit(node: NodeInstance) -> None:
node_id = getattr(node, "node_id", None)
if node_id is None or node_id in visited:
return
visited.add(node_id)
for conn in connections:
try:
if conn.end_node is node:
visit(conn.start_node)
except AttributeError:
continue
result.append(node)
visit(root)
return result
async def execute_node(self, node: NodeInstance | str) -> None:
"""Execute a node and keep the canvas in sync with its state.
Rules:
- Clicked node is always reset and re run.
- Upstream nodes are never reset here and may reuse cached data.
- Which nodes actually execute is decided by DataGraph and node logic.
- As soon as a node finishes, it is synced to the UI.
"""
import asyncio
if isinstance(node, str):
node_id = node
node_obj = self.graph.nodes.get(node_id)
if node_obj is None:
return
else:
node_obj = node
node_id = node.node_id
execution_chain = self._get_execution_chain(node_obj)
# show spinner on all nodes in the chain
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid:
self.canvas.set_node_processing(nid, True)
# progress polling for nodes that expose progress_value
stop_progress = False
progress_tasks: list[asyncio.Task] = []
async def progress_updater(n: NodeInstance, nid: str) -> None:
last_value: Any = None
while not stop_progress:
await asyncio.sleep(0.1)
if not hasattr(n, "progress_value"):
continue
current = getattr(n, "progress_value", None)
message = getattr(n, "progress_message", None)
if current is None:
continue
if current != last_value:
self.canvas.update_node_progress(nid, current, message)
last_value = current
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid and hasattr(n, "progress_value"):
progress_tasks.append(asyncio.create_task(progress_updater(n, nid)))
# callback from DataGraph after each node is executed
async def on_node_executed(executed_node: NodeInstance) -> None:
# Only nodes that actually ran will call this.
await self._sync_node_to_ui(executed_node)
# save previous callback so we can restore it
previous_cb = getattr(self.graph, "_on_node_executed", None)
try:
print(f"Runtime: Executing {node_id}...")
# clicked node is always reset, upstream nodes are not
if hasattr(node_obj, "reset_node"):
node_obj.reset_node()
# register our per node callback
self.graph.set_on_node_executed(on_node_executed)
# let DataGraph drive which nodes actually execute
await self.graph.execute(node_obj)
# one more sync for all nodes in the chain, in case some did not run
for n in execution_chain:
await self._sync_node_to_ui(n)
# nice "complete" flash for the clicked node if it has progress
if hasattr(node_obj, "progress_value"):
self.canvas.update_node_progress(node_id, 1.0, "Complete")
await asyncio.sleep(0.3)
except Exception as e:
print(f"Runtime execution failed: {e}")
import traceback
traceback.print_exc()
finally:
# restore previous graph callback
self.graph.set_on_node_executed(previous_cb)
# stop progress updaters
stop_progress = True
for t in progress_tasks:
t.cancel()
try:
await t
except asyncio.CancelledError:
pass
# hide spinner on all nodes in the chain
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid:
self.canvas.set_node_processing(nid, False)
# reset progress on the clicked node
if hasattr(node_obj, "progress_value"):
self.canvas.update_node_progress(node_id, 0.0, None)
async def _sync_node_to_ui(self, node: NodeInstance) -> None:
"""Push relevant node state back to the Vue nodes."""
if isinstance(node, TextToImageNode):
image_src = "" if node.image_src is None else str(node.image_src)
values: dict[str, Any] = {
"image": image_src,
"error": node.error or None,
}
self.canvas.update_node_values(node.node_id, values)
# add other node types here as needed
|