from __future__ import annotations from dataclasses import dataclass, field from typing import Iterable, Callable, Awaitable from .connection import Connection from .enums import DataPortState, ConnectMultiplicity from .nodes_base import NodeInstance NodeExecutedCallback = Callable[[NodeInstance], Awaitable[None] | None] @dataclass(slots=True) class DataGraph: nodes: dict[str, NodeInstance] = field(default_factory=dict) connections: list[Connection] = field(default_factory=list) _on_node_executed: NodeExecutedCallback | None = field(default=None, repr=False) def add_node(self, node: NodeInstance) -> None: self.nodes[node.node_id] = node def add_connection(self, c: Connection) -> None: # basic type safety if c.end_port.schema.dtype.id != c.start_port.schema.dtype.id: raise ValueError("datatype mismatch") # multiplicity and capacity checks are enforced in the Vue side too self.connections.append(c) c.end_node.mark_dirty() def upstream_of(self, node: NodeInstance) -> Iterable[Connection]: for c in self.connections: if c.end_node is node: yield c def downstream_of(self, node: NodeInstance) -> Iterable[Connection]: for c in self.connections: if c.start_node is node: yield c def set_on_node_executed(self, cb: NodeExecutedCallback | None) -> None: """Register a callback that is invoked after each node is executed. The callback can be sync or async. Pass None to disable. """ self._on_node_executed = cb async def _run_node(self, node: NodeInstance) -> None: """Internal helper that executes a single node and fires the callback.""" await node.process() cb = self._on_node_executed if cb is None: return result = cb(node) # allow async callbacks if hasattr(result, "__await__"): await result # type: ignore[func-returns-value] async def execute(self, node: NodeInstance | None = None) -> None: if node is None: for n in list(self.nodes.values()): await self.execute(n) return # Identify incoming connections incoming_conns = list(self.upstream_of(node)) # Recursively execute upstream if dirty for c in incoming_conns: if c.start_port.state == DataPortState.DIRTY: await self.execute(c.start_node) # Pull values from upstream to inputs for inp in node.all_inputs(): # Find all connections feeding this port feeds = [c for c in incoming_conns if c.end_port is inp] if not feeds: # No connections to this port - clear its value inp.value = None if inp.schema.multiplicity == ConnectMultiplicity.SINGLE else [] inp.state = DataPortState.CLEAN continue if inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE: # Collect all values values = [] for c in feeds: if c.start_port.value is not None: values.append(c.start_port.value) inp.value = values else: # Single connection (take the last one if multiple defined by mistake) if feeds: inp.value = feeds[-1].start_port.value print( f"[DEBUG Graph] Transferring value to {node.node_id}.{inp.name}: type={type(inp.value).__name__ if inp.value is not None else 'None'}") inp.state = DataPortState.CLEAN old_outputs = {p.name: p.value for p in node.all_outputs()} # Process await self._run_node(node) # after process for outp in node.all_outputs(): outp.state = DataPortState.CLEAN # mark downstream dirty only if needed for c in self.downstream_of(node): out_name = c.start_port.name before = old_outputs.get(out_name) after = c.start_port.value if before != after: c.end_port.state = DataPortState.DIRTY