File size: 3,981 Bytes
0f8b3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Generic, Iterable, TypeVar

from .connection import Connection
from .enums import ConnectMultiplicity, DataPortState, PortConnectionState
from .nodes import NodeInstance

# Type parameter for the node type a DataGraph stores; bound to NodeInstance,
# so any concrete node type must be NodeInstance or a subclass of it.
T_NODE = TypeVar("T_NODE", bound=NodeInstance)


@dataclass
class DataGraph(Generic[T_NODE]):
    """A mutable graph of nodes joined by port-to-port connections.

    ``nodes`` maps each node's ``node_id`` to the node itself, and
    ``connections`` is the ordered list of ``Connection`` objects, each linking
    ``start_port`` on ``start_node`` to ``end_port`` on ``end_node``. Adding or
    removing a connection recomputes ``connection_state`` on the ports it
    touches so that it reflects the current ``connections`` list.
    """

    nodes: dict[str, T_NODE] = field(default_factory=dict)
    connections: list[Connection] = field(default_factory=list)

    def add_node(self, node: T_NODE) -> None:
        """Store *node* under its ``node_id``, overwriting any node with that id."""
        key = node.node_id
        self.nodes[key] = node

    def _refresh_port_connection_state(self, port: Any) -> None:
        """Set ``port.connection_state`` from whether any connection uses *port*.

        A port counts as connected when it is (by identity) the start or end
        port of at least one entry in ``connections``.
        """
        state = PortConnectionState.DISCONNECTED
        for conn in self.connections:
            if conn.start_port is port or conn.end_port is port:
                state = PortConnectionState.CONNECTED
                break
        port.connection_state = state

    def add_connection(self, c: Connection, mark_dirty: bool = True) -> None:
        """Append connection *c* and update the state of both of its ports.

        Args:
            c: The connection to add.
            mark_dirty: When true (the default), the destination node is marked
                dirty and the destination port's state is set to ``DIRTY``.

        Raises:
            ValueError: If the two ports' schema dtype ids differ. In that case
                the graph is left unchanged.
        """
        dst_port = c.end_port
        src_port = c.start_port
        if dst_port.schema.dtype.id != src_port.schema.dtype.id:
            raise ValueError("datatype mismatch")

        self.connections.append(c)

        # Start port first, then end port, matching the connection's direction.
        for touched in (src_port, dst_port):
            self._refresh_port_connection_state(touched)

        if not mark_dirty:
            return
        c.end_node.mark_dirty()
        dst_port.state = DataPortState.DIRTY

    def remove_connection(self, c: Connection) -> None:
        """Detach connection *c* and resync the ports and node it affected.

        If *c* is not in ``connections`` (``list.remove`` raises ``ValueError``),
        this is a silent no-op. Otherwise:

        * both ports' ``connection_state`` is recomputed;
        * a SINGLE-multiplicity destination port with no remaining incoming
          connection gets ``value = None`` and state ``CLEAN``;
        * any other destination port gets ``value`` rebuilt as the list of
          non-``None`` values from the start ports still feeding it, and state
          ``CLEAN``;
        * the destination node is marked dirty.
        """
        try:
            self.connections.remove(c)
        except ValueError:
            return

        src_port, dst_port, dst_node = c.start_port, c.end_port, c.end_node

        for touched in (src_port, dst_port):
            self._refresh_port_connection_state(touched)

        if dst_port.schema.multiplicity == ConnectMultiplicity.SINGLE:
            # Only reset the port once nothing else in the graph ends at it.
            still_fed = any(other.end_port is dst_port for other in self.connections)
            if not still_fed:
                dst_port.value = None
                dst_port.state = DataPortState.CLEAN
        else:
            # Multi-input ports hold a list; rebuild it from the surviving feeds,
            # skipping feeds whose upstream value is None.
            gathered = []
            for feed in self.connections:
                if feed.end_port is not dst_port:
                    continue
                upstream_value = feed.start_port.value
                if upstream_value is not None:
                    gathered.append(upstream_value)
            dst_port.value = gathered
            dst_port.state = DataPortState.CLEAN

        # The destination's inputs changed, so it needs re-evaluation.
        dst_node.mark_dirty()

    def remove_node(self, node: NodeInstance) -> None:
        """Delete *node* along with every connection that starts or ends at it.

        Each such connection is removed through ``remove_connection`` so that
        neighbouring ports are resynced. The node is then popped from ``nodes``
        by its ``node_id``; an id that is not present is ignored.
        """
        attached = [
            conn
            for conn in self.connections
            if conn.start_node is node or conn.end_node is node
        ]
        # For connections ending at *node*, remove_connection also resets this
        # node's own input port and marks it dirty -- harmless, it is going away.
        for conn in attached:
            self.remove_connection(conn)

        self.nodes.pop(node.node_id, None)

    def upstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield each connection whose ``end_node`` is *node*."""
        yield from (conn for conn in self.connections if conn.end_node is node)

    def downstream_of(self, node: NodeInstance) -> Iterable[Connection]:
        """Lazily yield each connection whose ``start_node`` is *node*."""
        yield from (conn for conn in self.connections if conn.start_node is node)

    def upstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the source node of every connection feeding *node*.

        One item is yielded per connection, so a node linked by several
        connections appears several times.
        """
        yield from (conn.start_node for conn in self.upstream_of(node))

    def downstream_nodes(self, node: NodeInstance) -> Iterable[NodeInstance]:
        """Lazily yield the destination node of every connection leaving *node*.

        One item is yielded per connection, so a node linked by several
        connections appears several times.
        """
        yield from (conn.end_node for conn in self.downstream_of(node))