# velai/nodes/runtime.py
# cansik's picture
# Upload folder via script
# 3025bb3 verified
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from dataflow.graph import DataGraph
from dataflow.nodes_base import NodeInstance
from dataflow.ui.vueflow_canvas import VueFlowCanvas
from .text_to_image import TextToImageNode
@dataclass(slots=True)
class GraphRuntime:
graph: DataGraph
canvas: VueFlowCanvas
def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]:
"""Return all upstream nodes that belong to the chain of root.
Order is upstream first, root last. This is only for UI purposes
(spinners, progress, syncing), not to force execution.
"""
result: list[NodeInstance] = []
visited: set[str] = set()
connections = getattr(self.graph, "connections", []) or []
def visit(node: NodeInstance) -> None:
node_id = getattr(node, "node_id", None)
if node_id is None or node_id in visited:
return
visited.add(node_id)
for conn in connections:
try:
if conn.end_node is node:
visit(conn.start_node)
except AttributeError:
continue
result.append(node)
visit(root)
return result
async def execute_node(self, node: NodeInstance | str) -> None:
"""Execute a node and keep the canvas in sync with its state.
Rules:
- Clicked node is always reset and re run.
- Upstream nodes are never reset here and may reuse cached data.
- Which nodes actually execute is decided by DataGraph and node logic.
- As soon as a node finishes, it is synced to the UI.
"""
import asyncio
if isinstance(node, str):
node_id = node
node_obj = self.graph.nodes.get(node_id)
if node_obj is None:
return
else:
node_obj = node
node_id = node.node_id
execution_chain = self._get_execution_chain(node_obj)
# show spinner on all nodes in the chain
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid:
self.canvas.set_node_processing(nid, True)
# progress polling for nodes that expose progress_value
stop_progress = False
progress_tasks: list[asyncio.Task] = []
async def progress_updater(n: NodeInstance, nid: str) -> None:
last_value: Any = None
while not stop_progress:
await asyncio.sleep(0.1)
if not hasattr(n, "progress_value"):
continue
current = getattr(n, "progress_value", None)
message = getattr(n, "progress_message", None)
if current is None:
continue
if current != last_value:
self.canvas.update_node_progress(nid, current, message)
last_value = current
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid and hasattr(n, "progress_value"):
progress_tasks.append(asyncio.create_task(progress_updater(n, nid)))
# callback from DataGraph after each node is executed
async def on_node_executed(executed_node: NodeInstance) -> None:
# Only nodes that actually ran will call this.
await self._sync_node_to_ui(executed_node)
# save previous callback so we can restore it
previous_cb = getattr(self.graph, "_on_node_executed", None)
try:
print(f"Runtime: Executing {node_id}...")
# clicked node is always reset, upstream nodes are not
if hasattr(node_obj, "reset_node"):
node_obj.reset_node()
# register our per node callback
self.graph.set_on_node_executed(on_node_executed)
# let DataGraph drive which nodes actually execute
await self.graph.execute(node_obj)
# one more sync for all nodes in the chain, in case some did not run
for n in execution_chain:
await self._sync_node_to_ui(n)
# nice "complete" flash for the clicked node if it has progress
if hasattr(node_obj, "progress_value"):
self.canvas.update_node_progress(node_id, 1.0, "Complete")
await asyncio.sleep(0.3)
except Exception as e:
print(f"Runtime execution failed: {e}")
import traceback
traceback.print_exc()
finally:
# restore previous graph callback
self.graph.set_on_node_executed(previous_cb)
# stop progress updaters
stop_progress = True
for t in progress_tasks:
t.cancel()
try:
await t
except asyncio.CancelledError:
pass
# hide spinner on all nodes in the chain
for n in execution_chain:
nid = getattr(n, "node_id", None)
if nid:
self.canvas.set_node_processing(nid, False)
# reset progress on the clicked node
if hasattr(node_obj, "progress_value"):
self.canvas.update_node_progress(node_id, 0.0, None)
async def _sync_node_to_ui(self, node: NodeInstance) -> None:
"""Push relevant node state back to the Vue nodes."""
if isinstance(node, TextToImageNode):
image_src = "" if node.image_src is None else str(node.image_src)
values: dict[str, Any] = {
"image": image_src,
"error": node.error or None,
}
self.canvas.update_node_values(node.node_id, values)
# add other node types here as needed