# velai/dataflow/graph.py
# Author: cansik — "Upload folder via script" (revision 3025bb3, verified).
from __future__ import annotations

import inspect
import logging
from dataclasses import dataclass, field
from typing import Iterable, Callable, Awaitable

from .connection import Connection
from .enums import DataPortState, ConnectMultiplicity
from .nodes_base import NodeInstance
# Post-execution hook signature: called with the node that just ran. It may
# return None (synchronous callback) or an awaitable (async callback); the
# graph awaits the latter after the node's process() completes.
NodeExecutedCallback = Callable[[NodeInstance], Awaitable[None] | None]
@dataclass(slots=True)
class DataGraph:
nodes: dict[str, NodeInstance] = field(default_factory=dict)
connections: list[Connection] = field(default_factory=list)
_on_node_executed: NodeExecutedCallback | None = field(default=None, repr=False)
def add_node(self, node: NodeInstance) -> None:
self.nodes[node.node_id] = node
def add_connection(self, c: Connection) -> None:
# basic type safety
if c.end_port.schema.dtype.id != c.start_port.schema.dtype.id:
raise ValueError("datatype mismatch")
# multiplicity and capacity checks are enforced in the Vue side too
self.connections.append(c)
c.end_node.mark_dirty()
def upstream_of(self, node: NodeInstance) -> Iterable[Connection]:
for c in self.connections:
if c.end_node is node:
yield c
def downstream_of(self, node: NodeInstance) -> Iterable[Connection]:
for c in self.connections:
if c.start_node is node:
yield c
def set_on_node_executed(self, cb: NodeExecutedCallback | None) -> None:
"""Register a callback that is invoked after each node is executed.
The callback can be sync or async. Pass None to disable.
"""
self._on_node_executed = cb
async def _run_node(self, node: NodeInstance) -> None:
"""Internal helper that executes a single node and fires the callback."""
await node.process()
cb = self._on_node_executed
if cb is None:
return
result = cb(node)
# allow async callbacks
if hasattr(result, "__await__"):
await result # type: ignore[func-returns-value]
async def execute(self, node: NodeInstance | None = None) -> None:
if node is None:
for n in list(self.nodes.values()):
await self.execute(n)
return
# Identify incoming connections
incoming_conns = list(self.upstream_of(node))
# Recursively execute upstream if dirty
for c in incoming_conns:
if c.start_port.state == DataPortState.DIRTY:
await self.execute(c.start_node)
# Pull values from upstream to inputs
for inp in node.all_inputs():
# Find all connections feeding this port
feeds = [c for c in incoming_conns if c.end_port is inp]
if not feeds:
# No connections to this port - clear its value
inp.value = None if inp.schema.multiplicity == ConnectMultiplicity.SINGLE else []
inp.state = DataPortState.CLEAN
continue
if inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE:
# Collect all values
values = []
for c in feeds:
if c.start_port.value is not None:
values.append(c.start_port.value)
inp.value = values
else:
# Single connection (take the last one if multiple defined by mistake)
if feeds:
inp.value = feeds[-1].start_port.value
print(
f"[DEBUG Graph] Transferring value to {node.node_id}.{inp.name}: type={type(inp.value).__name__ if inp.value is not None else 'None'}")
inp.state = DataPortState.CLEAN
old_outputs = {p.name: p.value for p in node.all_outputs()}
# Process
await self._run_node(node)
# after process
for outp in node.all_outputs():
outp.state = DataPortState.CLEAN
# mark downstream dirty only if needed
for c in self.downstream_of(node):
out_name = c.start_port.name
before = old_outputs.get(out_name)
after = c.start_port.value
if before != after:
c.end_port.state = DataPortState.DIRTY