from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Generic, Iterable, TypeVar

from .connection import Connection
from .enums import ConnectMultiplicity, DataPortState, PortConnectionState
from .nodes import NodeInstance

T_NODE = TypeVar("T_NODE", bound=NodeInstance)


@dataclass
class DataGraph(Generic[T_NODE]):
    """Container of node instances and the port-to-port connections between them.

    ``nodes`` is keyed by each node's ``node_id``; ``connections`` keeps every
    connection in the order it was added.
    """

    nodes: dict[str, T_NODE] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)

    def add_node(self, node: T_NODE) -> None:
        """Store *node* under ``node.node_id``, overwriting any node with that id."""
        key = node.node_id
        self.nodes[key] = node

    def _refresh_port_connection_state(self, port: Any) -> None:
        """Set ``port.connection_state`` by scanning ``self.connections`` for *port*.

        The port is CONNECTED if it is the start or end port (compared by
        identity) of at least one connection, otherwise DISCONNECTED.
        """
        touched = False
        for conn in self.connections:
            if conn.start_port is port or conn.end_port is port:
                touched = True
                break
        if touched:
            port.connection_state = PortConnectionState.CONNECTED
        else:
            port.connection_state = PortConnectionState.DISCONNECTED

    def add_connection(self, c: Connection, mark_dirty: bool = True) -> None:
        """Append connection *c* and refresh the connection state of both its ports.

        Args:
            c: The connection to add.
            mark_dirty: When true (the default), call ``c.end_node.mark_dirty()``
                and set the end port's state to ``DataPortState.DIRTY``.

        Raises:
            ValueError: if the two ports' ``schema.dtype.id`` values differ; the
                graph is left unchanged in that case.
        """
        src_port, dst_port = c.start_port, c.end_port
        if dst_port.schema.dtype.id != src_port.schema.dtype.id:
            raise ValueError("datatype mismatch")

        self.connections.append(c)
        for touched_port in (src_port, dst_port):
            self._refresh_port_connection_state(touched_port)

        if not mark_dirty:
            return
        c.end_node.mark_dirty()
        dst_port.state = DataPortState.DIRTY

    def remove_connection(self, c: Connection) -> None:
        """Drop connection *c* and bring the affected ports up to date.

        Does nothing if *c* is not in ``self.connections`` (membership follows
        ``list.remove``, i.e. equality). Otherwise:

        * both ports get their ``connection_state`` recomputed;
        * a SINGLE-multiplicity end port with no remaining incoming connection
          has its ``value`` reset to ``None`` and its ``state`` set to CLEAN
          (if another connection still feeds it, the port is left as is);
        * any other end port gets ``value`` rebuilt as the list of non-``None``
          ``start_port.value``s of the connections still feeding it, in
          connection order, and ``state`` set to CLEAN;
        * finally ``c.end_node.mark_dirty()`` is called.
        """
        try:
            self.connections.remove(c)
        except ValueError:
            # Not part of this graph; nothing to update.
            return

        src_port = c.start_port
        dst_port = c.end_port
        target_node = c.end_node

        self._refresh_port_connection_state(src_port)
        self._refresh_port_connection_state(dst_port)

        # Connections that still deliver into the same input port.
        feeding = [conn for conn in self.connections if conn.end_port is dst_port]

        if dst_port.schema.multiplicity == ConnectMultiplicity.SINGLE:
            if not feeding:
                dst_port.value = None
                dst_port.state = DataPortState.CLEAN
        else:
            dst_port.value = [
                conn.start_port.value
                for conn in feeding
                if conn.start_port.value is not None
            ]
            dst_port.state = DataPortState.CLEAN

        # The destination node's inputs have changed.
        target_node.mark_dirty()

    def remove_node(self, node: NodeInstance) -> None:
        """Detach *node* from the graph, then forget it.

        Every connection whose start or end node is *node* (by identity) is
        removed through :meth:`remove_connection`, so neighbouring ports are
        updated. For connections ending at *node*, that also calls
        ``node.mark_dirty()``, which is harmless because the node is being
        dropped. Removing a node that is not registered is a no-op for the
        ``nodes`` mapping.
        """
        doomed = [
            conn
            for conn in self.connections
            if conn.start_node is node or conn.end_node is node
        ]
        for conn in doomed:
            self.remove_connection(conn)
        self.nodes.pop(node.node_id, None)

    def upstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield every connection whose end node is *node*."""
        yield from (conn for conn in self.connections if conn.end_node is node)

    def downstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield every connection whose start node is *node*."""
        yield from (conn for conn in self.connections if conn.start_node is node)

    def upstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the start node of each incoming connection (duplicates kept)."""
        yield from (conn.start_node for conn in self.upstream_of(node))

    def downstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the end node of each outgoing connection (duplicates kept)."""
        yield from (conn.end_node for conn in self.downstream_of(node))