kratadata's picture
Upload folder via script
0f8b3a0 verified
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any, Generic, Iterable, TypeVar
from .connection import Connection
from .enums import ConnectMultiplicity, DataPortState, PortConnectionState
from .nodes import NodeInstance
# Type variable for the node type held by a DataGraph; bounded to NodeInstance,
# so it can only be bound to NodeInstance or one of its subclasses.
T_NODE = TypeVar("T_NODE", bound=NodeInstance)
@dataclass
class DataGraph(Generic[T_NODE]):
    """Graph of node instances and the connections between their ports.

    Nodes are stored by ``node_id``; connections are kept in a plain list in
    the order they were added.
    """

    # node_id -> node instance
    nodes: dict[str, T_NODE] = field(default_factory=dict)
    # every connection currently present in the graph
    connections: list[Connection] = field(default_factory=list)

    def add_node(self, node: T_NODE) -> None:
        """Store *node* under its ``node_id``, replacing any previous entry."""
        self.nodes[node.node_id] = node

    def _refresh_port_connection_state(self, port: Any) -> None:
        """Set ``port.connection_state`` from whether any connection touches *port*."""
        for link in self.connections:
            # ports are matched by identity, not equality
            if link.start_port is port or link.end_port is port:
                port.connection_state = PortConnectionState.CONNECTED
                return
        port.connection_state = PortConnectionState.DISCONNECTED

    def add_connection(self, c: Connection, mark_dirty: bool = True) -> None:
        """Append *c* to the graph and refresh the state of both of its ports.

        When *mark_dirty* is true, the destination node is marked dirty and
        the destination port's state is set to ``DataPortState.DIRTY``.

        Raises:
            ValueError: if the dtype ids of the two ports' schemas differ.
        """
        sink_dtype_id = c.end_port.schema.dtype.id
        source_dtype_id = c.start_port.schema.dtype.id
        if sink_dtype_id != source_dtype_id:
            raise ValueError("datatype mismatch")

        self.connections.append(c)
        self._refresh_port_connection_state(c.start_port)
        self._refresh_port_connection_state(c.end_port)

        if not mark_dirty:
            return
        c.end_node.mark_dirty()
        c.end_port.state = DataPortState.DIRTY

    def remove_connection(self, c: Connection) -> None:
        """Drop *c* from the graph and bring the affected ports up to date.

        Does nothing when *c* is not in ``self.connections``.
        """
        try:
            self.connections.remove(c)
        except ValueError:
            # not part of this graph: nothing to update
            return

        source_port = c.start_port
        sink_port = c.end_port
        sink_node = c.end_node

        # both endpoints may have just lost their last connection
        self._refresh_port_connection_state(source_port)
        self._refresh_port_connection_state(sink_port)

        if sink_port.schema.multiplicity == ConnectMultiplicity.SINGLE:
            # a single-input port with nothing feeding it any more is reset
            if not any(link.end_port is sink_port for link in self.connections):
                sink_port.value = None
                sink_port.state = DataPortState.CLEAN
        else:
            # multi-input port: collect the non-None values of the remaining feeds
            sink_port.value = [
                link.start_port.value
                for link in self.connections
                if link.end_port is sink_port and link.start_port.value is not None
            ]
            sink_port.state = DataPortState.CLEAN

        # the destination node lost an input, so it is marked dirty
        sink_node.mark_dirty()

    def remove_node(self, node: NodeInstance) -> None:
        """Remove every connection touching *node*, then drop *node* itself."""
        # iterate over a snapshot because remove_connection mutates the list
        for link in list(self.connections):
            if link.start_node is not node and link.end_node is not node:
                continue
            # when *node* is the destination, remove_connection also marks
            # it dirty; that is harmless since the node is dropped below
            self.remove_connection(link)

        # a node_id that is not registered is silently ignored
        self.nodes.pop(node.node_id, None)

    def upstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield the connections whose destination node is *node*."""
        yield from (link for link in self.connections if link.end_node is node)

    def downstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield the connections whose source node is *node*."""
        yield from (link for link in self.connections if link.start_node is node)

    def upstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the source node of each connection ending at *node*."""
        yield from (link.start_node for link in self.upstream_of(node))

    def downstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the destination node of each connection starting at *node*."""
        yield from (link.end_node for link in self.downstream_of(node))