from __future__ import annotations

import inspect
import logging
from dataclasses import dataclass, field
from typing import Awaitable, Callable, Iterable

from .connection import Connection
from .enums import ConnectMultiplicity, DataPortState
from .nodes_base import NodeInstance

# Type of the hook registered with DataGraph.set_on_node_executed(). It is
# called with the node that has just finished processing. It may be a plain
# function returning None, or return an awaitable (e.g. an ``async def``
# callback), in which case the graph awaits it.
NodeExecutedCallback = Callable[[NodeInstance], Awaitable[None] | None]


| | @dataclass(slots=True) |
| | class DataGraph: |
| | nodes: dict[str, NodeInstance] = field(default_factory=dict) |
| | connections: list[Connection] = field(default_factory=list) |
| |
|
| | _on_node_executed: NodeExecutedCallback | None = field(default=None, repr=False) |
| |
|
| | def add_node(self, node: NodeInstance) -> None: |
| | self.nodes[node.node_id] = node |
| |
|
| | def add_connection(self, c: Connection) -> None: |
| | |
| | if c.end_port.schema.dtype.id != c.start_port.schema.dtype.id: |
| | raise ValueError("datatype mismatch") |
| | |
| | self.connections.append(c) |
| | c.end_node.mark_dirty() |
| |
|
| | def upstream_of(self, node: NodeInstance) -> Iterable[Connection]: |
| | for c in self.connections: |
| | if c.end_node is node: |
| | yield c |
| |
|
| | def downstream_of(self, node: NodeInstance) -> Iterable[Connection]: |
| | for c in self.connections: |
| | if c.start_node is node: |
| | yield c |
| |
|
| | def set_on_node_executed(self, cb: NodeExecutedCallback | None) -> None: |
| | """Register a callback that is invoked after each node is executed. |
| | |
| | The callback can be sync or async. Pass None to disable. |
| | """ |
| | self._on_node_executed = cb |
| |
|
| | async def _run_node(self, node: NodeInstance) -> None: |
| | """Internal helper that executes a single node and fires the callback.""" |
| | await node.process() |
| |
|
| | cb = self._on_node_executed |
| | if cb is None: |
| | return |
| |
|
| | result = cb(node) |
| | |
| | if hasattr(result, "__await__"): |
| | await result |
| |
|
| | async def execute(self, node: NodeInstance | None = None) -> None: |
| | if node is None: |
| | for n in list(self.nodes.values()): |
| | await self.execute(n) |
| | return |
| |
|
| | |
| | incoming_conns = list(self.upstream_of(node)) |
| |
|
| | |
| | for c in incoming_conns: |
| | if c.start_port.state == DataPortState.DIRTY: |
| | await self.execute(c.start_node) |
| |
|
| | |
| | for inp in node.all_inputs(): |
| | |
| | feeds = [c for c in incoming_conns if c.end_port is inp] |
| |
|
| | if not feeds: |
| | |
| | inp.value = None if inp.schema.multiplicity == ConnectMultiplicity.SINGLE else [] |
| | inp.state = DataPortState.CLEAN |
| | continue |
| |
|
| | if inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE: |
| | |
| | values = [] |
| | for c in feeds: |
| | if c.start_port.value is not None: |
| | values.append(c.start_port.value) |
| | inp.value = values |
| | else: |
| | |
| | if feeds: |
| | inp.value = feeds[-1].start_port.value |
| | print( |
| | f"[DEBUG Graph] Transferring value to {node.node_id}.{inp.name}: type={type(inp.value).__name__ if inp.value is not None else 'None'}") |
| |
|
| | inp.state = DataPortState.CLEAN |
| |
|
| | old_outputs = {p.name: p.value for p in node.all_outputs()} |
| |
|
| | |
| | await self._run_node(node) |
| |
|
| | |
| | for outp in node.all_outputs(): |
| | outp.state = DataPortState.CLEAN |
| |
|
| | |
| | for c in self.downstream_of(node): |
| | out_name = c.start_port.name |
| | before = old_outputs.get(out_name) |
| | after = c.start_port.value |
| | if before != after: |
| | c.end_port.state = DataPortState.DIRTY |
| |
|