Spaces:
Running
Running
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from typing import Any, Generic, Iterable, TypeVar | |
| from .connection import Connection | |
| from .enums import ConnectMultiplicity, DataPortState, PortConnectionState | |
| from .nodes import NodeInstance | |
# Type variable for the concrete node type a DataGraph stores; constrained to
# NodeInstance subclasses so graph methods can rely on NodeInstance members.
T_NODE = TypeVar("T_NODE", bound=NodeInstance)
@dataclass(eq=False)
class DataGraph(Generic[T_NODE]):
    """Graph of node instances and the port-to-port connections between them.

    Nodes are stored by ``node_id``; connections are kept in a list in the
    order they were added. Adding or removing a connection recomputes the
    ``connection_state`` of both ports it touches.
    """

    # ``@dataclass`` is required: without it, ``field(default_factory=...)``
    # leaves a bare ``dataclasses.Field`` object as the class attribute, so
    # ``self.nodes[...] = ...`` / ``self.connections.append(...)`` fail and no
    # per-instance containers are ever created.
    # ``eq=False`` keeps identity-based ``==`` and ``hash`` (what the
    # undecorated class had) instead of field-by-field comparison.

    # Nodes keyed by their ``node_id``.
    nodes: dict[str, T_NODE] = field(default_factory=dict)
    # All connections in the graph, in insertion order.
    connections: list[Connection] = field(default_factory=list)

    def add_node(self, node: T_NODE) -> None:
        """Register ``node`` under its ``node_id``.

        A node already stored under the same ``node_id`` is replaced.
        """
        self.nodes[node.node_id] = node

    def _refresh_port_connection_state(self, port: Any) -> None:
        """Recompute ``connection_state`` for a single port from ``self.connections``.

        The port becomes CONNECTED if any connection uses it (compared by
        identity) as its start or end port, and DISCONNECTED otherwise.
        """
        is_connected = any(
            c.start_port is port or c.end_port is port for c in self.connections
        )
        port.connection_state = (
            PortConnectionState.CONNECTED
            if is_connected
            else PortConnectionState.DISCONNECTED
        )

    def add_connection(self, c: Connection, mark_dirty: bool = True) -> None:
        """Add connection ``c`` after checking that both ports share a datatype.

        Args:
            c: The connection to add.
            mark_dirty: When true, mark the destination node dirty and set the
                destination port's state to DIRTY.

        Raises:
            ValueError: If the start and end ports' ``schema.dtype.id`` differ.
                The graph is left unchanged in that case.
        """
        # Validate before mutating so a rejected connection leaves no trace.
        if c.end_port.schema.dtype.id != c.start_port.schema.dtype.id:
            raise ValueError("datatype mismatch")
        self.connections.append(c)
        self._refresh_port_connection_state(c.start_port)
        self._refresh_port_connection_state(c.end_port)
        if mark_dirty:
            c.end_node.mark_dirty()
            c.end_port.state = DataPortState.DIRTY

    def remove_connection(self, c: Connection) -> None:
        """Remove a connection and update port states accordingly.

        Does nothing if ``c`` is not in the graph (membership follows
        ``list.remove``, i.e. ``==`` comparison).

        After removal:
        - both ports' ``connection_state`` is recomputed;
        - a SINGLE-multiplicity destination port with no remaining incoming
          connection has its value cleared to ``None`` and is marked CLEAN;
        - any other destination port has its value rebuilt as the list of
          non-``None`` start-port values of the remaining incoming
          connections (in connection order) and is marked CLEAN;
        - the destination node is marked dirty.
        """
        try:
            self.connections.remove(c)
        except ValueError:
            # Not part of this graph: nothing to update.
            return

        start_port = c.start_port
        end_port = c.end_port
        end_node = c.end_node

        self._refresh_port_connection_state(start_port)
        self._refresh_port_connection_state(end_port)

        # Connections still feeding the same destination port.
        remaining_feeds = [
            other for other in self.connections if other.end_port is end_port
        ]

        if end_port.schema.multiplicity == ConnectMultiplicity.SINGLE:
            # Only clear the value once the port has no incoming feed left.
            if not remaining_feeds:
                end_port.value = None
                end_port.state = DataPortState.CLEAN
        else:
            # Multi-input ports hold one value per feed; rebuild from survivors.
            end_port.value = [
                other.start_port.value
                for other in remaining_feeds
                if other.start_port.value is not None
            ]
            end_port.state = DataPortState.CLEAN

        # The destination node's inputs changed, so it must be recomputed.
        end_node.mark_dirty()

    def remove_node(self, node: NodeInstance) -> None:
        """Remove ``node`` and every connection touching it.

        Each touching connection goes through ``remove_connection``, so
        neighbor ports and destination nodes are updated as documented there.
        The node is then dropped from ``self.nodes``; a node that is not
        registered is ignored.
        """
        # Iterate over a snapshot because remove_connection mutates the list.
        for c in list(self.connections):
            if c.start_node is node or c.end_node is node:
                # When ``node`` is the destination, remove_connection marks it
                # dirty too; that is harmless since it is about to be dropped.
                self.remove_connection(c)
        self.nodes.pop(node.node_id, None)

    def upstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield every connection whose end node is ``node``."""
        for c in self.connections:
            if c.end_node is node:
                yield c

    def downstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield every connection whose start node is ``node``."""
        for c in self.connections:
            if c.start_node is node:
                yield c

    def upstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the start node of each connection into ``node``.

        A node feeding ``node`` through several connections is yielded once
        per connection.
        """
        for c in self.upstream_of(node):
            yield c.start_node

    def downstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the end node of each connection out of ``node``.

        A node fed by ``node`` through several connections is yielded once
        per connection.
        """
        for c in self.downstream_of(node):
            yield c.end_node