File size: 6,405 Bytes
0f8b3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
from __future__ import annotations

import logging
from abc import ABC
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Counter, Dict, Generic, TypeVar

from PIL import Image

from velai import app_context
from velai.dataflow.enums import DataPortState
from velai.dataflow.nodes import NodeInstance, NodeType
from velai.dataflow.ports import PortState
from velai.serialization.JsonSerializable import DataclassJsonSerializable
from velai.serialization.JsonTypeSerializer import DefaultSerializer
from velai.services.generator_service import GenerationResult

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class NameConflictError(ValueError):
    """Raised when a node's input port, output port and data field names are not all unique."""


@dataclass(slots=True)
class BaseNodeData(DataclassJsonSerializable):
    """Serialisable per-node state shared by every ``BaseNode`` data class.

    Concrete nodes subclass this dataclass to add their own fields. Instances
    are converted with ``to_dict`` / ``update_from_dict`` (presumably provided
    by ``DataclassJsonSerializable`` — not visible here) when node state is
    captured and restored.
    """

    # Text of the last execution error; None by default, set to "" when the
    # node is queued for execution and to the exception text on failure.
    error_message: str | None = None

    # Progress of the current run; None when not reported. The scale
    # (0..1 vs. percent) is not defined in this class — TODO confirm.
    progress_value: float | None = None
    # Optional human-readable progress text; None when not reported.
    progress_message: str | None = None
    # Title override; used for display when non-blank, otherwise the node
    # type's display name is shown.
    custom_title: str | None = None


# Type variable for the concrete BaseNodeData subclass carried by a node.
T_DATA = TypeVar("T_DATA", bound=BaseNodeData)


class BaseNode(NodeInstance, Generic[T_DATA], ABC):
    """Dataflow node that pairs ``NodeInstance`` ports with a data object.

    Each node owns a ``BaseNodeData`` instance in ``self.data``. Subclasses
    set ``data_cls`` to their own data dataclass and override
    ``on_node_execution`` to perform the actual work.

    Output port values and data fields are flattened into one dictionary by
    ``get_state``, so input port names, output port names and data field
    names must all be distinct; ``__init__`` enforces this.
    """

    # Data class instantiated (without arguments) when no ``data`` is given
    # to ``__init__`` and whenever the node is reset.
    data_cls: ClassVar[type[BaseNodeData]] = BaseNodeData

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        data: T_DATA | None = None,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None] | None = None,
    ) -> None:
        """Create the node and validate its naming.

        Args:
            node_id, node_type, auto_process, x, y, width, height, inputs,
            outputs, on_process: forwarded unchanged to ``NodeInstance``.
            data: Initial node data; a fresh ``data_cls()`` is created when
                None.

        Raises:
            NameConflictError: if an input port, output port or data field
                name occurs more than once.
        """
        super().__init__(
            node_id=node_id,
            node_type=node_type,
            auto_process=auto_process,
            x=x,
            y=y,
            width=width,
            height=height,
            inputs=inputs,
            outputs=outputs,
            on_process=on_process,
        )

        if data is None:
            data = self.data_cls()
        self.data = data

        # Ports and data share one namespace in get_state(), so duplicates
        # would silently overwrite each other when persisted.
        self._check_for_name_conflict()

    def _check_for_name_conflict(self) -> None:
        """Raise ``NameConflictError`` if any port or data field name repeats."""
        variable_names = [
            *(port.name for port in self.all_inputs()),
            *(port.name for port in self.all_outputs()),
            *self.data.to_dict().keys(),
        ]

        counts = Counter(variable_names)
        conflicts = [name for name, count in counts.items() if count > 1]

        if conflicts:
            raise NameConflictError(f"Duplicate variable names detected: {', '.join(conflicts)}")

    def get_display_title(self) -> str:
        """Return the stripped custom title if non-blank, else the type's display name."""
        custom = self.data.custom_title
        if custom:
            title = str(custom).strip()
            if title:
                return title
        return self.node_type.display_name

    def get_state(self) -> Dict[str, Any]:
        """Return the persistable node state as a flat dictionary.

        The result maps each output port name to its serialised value and
        merges in ``self.data.to_dict()``. On a key collision the data value
        would win, but ``__init__`` rejects colliding names.
        """
        outputs_dict = {
            name: DefaultSerializer.serialize(port.value, source_type=port.schema.dtype.py_type)
            for name, port in (self.outputs or {}).items()
        }

        # todo: inputs, outputs and data dict share the names is problematic
        #   idea: use different attribute-prefixes or objects
        return {**outputs_dict, **self.data.to_dict()}

    def set_state(self, state: dict[str, Any]) -> None:
        """Restore state produced by ``get_state`` (or a compatible dict).

        The data object is updated from ``state``. Every key that matches an
        output or input port is deserialised into that port, which is then
        marked clean. An empty or None ``state`` is a no-op.
        """
        if not state:
            return

        logger.debug("set_state %s (%s)", self.node_id, self.node_type.kind)

        self.data.update_from_dict(state)

        # Port dicts may be None (get_state guards for this as well).
        self._set_port_values(self.outputs, state)
        self._set_port_values(self.inputs, state)

    def duplicate(self, new_id: str) -> "BaseNode":
        """Create a copy of this node with a new identifier.

        A new instance of the same concrete class is built with
        ``cls(new_id, self.node_type)``. The persisted state is copied via
        ``get_state``/``set_state``, and non-None input port values are copied
        for ports that exist on the new node. PIL images are copied; other
        values are shared by reference. Position, size and ``auto_process``
        are not copied. Subclasses may override this to copy additional
        attributes or to supply extra constructor arguments.
        """
        # Runs BaseNode.__init__, which creates default data and lets
        # NodeInstance set up the ports.
        new_node: BaseNode = type(self)(new_id, self.node_type)  # type: ignore[call-arg]
        new_node.set_state(self.get_state())

        target_inputs = new_node.inputs or {}
        for name, port in (self.inputs or {}).items():
            if name not in target_inputs or port.value is None:
                continue
            value = port.value
            if isinstance(value, Image.Image):
                try:
                    value = value.copy()
                except Exception:
                    # Best effort: share the original image rather than fail
                    # the whole duplication, but make the problem visible.
                    logger.warning(
                        "Could not copy image for input '%s' while duplicating node %s; sharing original.",
                        name,
                        self.node_id,
                        exc_info=True,
                    )
            target_inputs[name].value = value
        return new_node

    async def process(self) -> None:
        """Run ``on_node_execution`` if any output is dirty.

        On failure the node is reset (keeping its custom title), the error
        text is stored in ``data.error_message``, and the exception is
        re-raised.
        """
        # Early exit if already processed.
        if not self.has_dirty_outputs():
            return

        try:
            await self.on_node_execution()
        except Exception as e:
            logger.exception("Node execution failed.")

            # The custom title is display configuration, not execution state;
            # do not lose it just because a run failed.
            custom_title = self.data.custom_title
            self.reset_node()
            self.data.custom_title = custom_title

            # Some exceptions (e.g. TimeoutError()) stringify to "", which is
            # the same value on_queue_for_execution uses for "no error".
            self.data.error_message = str(e) or type(e).__name__
            raise

    def reset_node(self) -> None:
        """Replace ``self.data`` with a fresh ``data_cls()`` and reset all outputs."""
        self.data = self.data_cls()
        self.reset_outputs()

    @staticmethod
    def _set_port_values(ports: dict[str, PortState] | None, data_dict: dict[str, Any]) -> None:
        """Deserialise matching ``data_dict`` entries into ``ports`` and mark them clean.

        Keys without a corresponding port are ignored. ``ports`` may be None,
        in which case nothing is done.
        """
        if not ports:
            return

        for field_name, value in data_dict.items():
            if field_name not in ports:
                continue

            port = ports[field_name]
            port.value = DefaultSerializer.de_serialize(value, target_type=port.schema.dtype.py_type)
            port.state = DataPortState.CLEAN

    async def on_queue_for_execution(self) -> None:
        """Clear the error text and progress before the node is executed."""
        self.data.error_message = ""
        self.data.progress_value = None
        self.data.progress_message = None

    async def on_node_execution(self) -> None:
        """Hook for subclasses to perform the node's work; the base does nothing."""
        pass

    async def on_generation_result(self, result: GenerationResult) -> None:
        """Add ``result.cost`` to the current user's generation cost and save it."""
        ctx = await app_context.current_app_context()
        info = ctx.user_info

        info.generation.cost += result.cost
        info.save(ctx.user_storage)