from __future__ import annotations

import logging
from abc import ABC
from collections import Counter
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Dict, Generic, TypeVar

from PIL import Image

from velai import app_context
from velai.dataflow.enums import DataPortState
from velai.dataflow.nodes import NodeInstance, NodeType
from velai.dataflow.ports import PortState
from velai.serialization.JsonSerializable import DataclassJsonSerializable
from velai.serialization.JsonTypeSerializer import DefaultSerializer
from velai.services.generator_service import GenerationResult

logger = logging.getLogger(__name__)


class NameConflictError(ValueError):
    """Raised when input port, output port and data field names of a node collide."""


@dataclass(slots=True)
class BaseNodeData(DataclassJsonSerializable):
    """Serializable per-node state shared by every node type.

    Subclasses add node-specific fields. The fields defined here hold the last
    execution error, progress reporting values, and an optional user-chosen
    title that overrides the node type's display name.
    """

    error_message: str | None = None
    progress_value: float | None = None
    progress_message: str | None = None
    custom_title: str | None = None


T_DATA = TypeVar("T_DATA", bound=BaseNodeData)


class BaseNode(NodeInstance, Generic[T_DATA], ABC):
    """Base class for nodes whose internal state lives in a ``data`` dataclass.

    Ports (``inputs`` / ``outputs``) are handed to ``NodeInstance``; node-local
    state is an instance of ``data_cls``. ``get_state``/``set_state`` flatten
    both into one dictionary, which is why names must be unique across ports
    and data fields (enforced in ``__init__``).
    """

    # Dataclass instantiated when no ``data`` is passed and on every reset.
    # Subclasses override this with their own BaseNodeData subclass.
    data_cls: ClassVar[type[BaseNodeData]] = BaseNodeData

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        data: T_DATA | None = None,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None] | None = None,
    ) -> None:
        """Initialise the node and its data payload.

        Args:
            node_id: Identifier of this node instance.
            node_type: Type descriptor of the node.
            data: Initial data; a fresh ``data_cls()`` is created when omitted.
            auto_process, x, y, width, height, inputs, outputs, on_process:
                Forwarded unchanged to ``NodeInstance.__init__``.

        Raises:
            NameConflictError: If a port name and/or data field name repeats.
        """
        super().__init__(
            node_id=node_id,
            node_type=node_type,
            auto_process=auto_process,
            x=x,
            y=y,
            width=width,
            height=height,
            inputs=inputs,
            outputs=outputs,
            on_process=on_process,
        )
        if data is None:
            data = self.data_cls()
        self.data = data
        # Ports and data fields share one flat namespace in get_state/set_state,
        # so duplicates would make the serialized state ambiguous.
        self._check_for_name_conflict()

    def _check_for_name_conflict(self) -> None:
        """Raise if any name occurs more than once across inputs, outputs and data fields.

        Raises:
            NameConflictError: Listing every duplicated name.
        """
        variable_names = [
            *(entry.name for entry in self.all_inputs()),
            *(entry.name for entry in self.all_outputs()),
            *self.data.to_dict().keys(),
        ]
        counts = Counter(variable_names)
        conflicts = [name for name, count in counts.items() if count > 1]
        if conflicts:
            raise NameConflictError(f"Duplicate variable names detected: {', '.join(conflicts)}")

    def get_display_title(self) -> str:
        """Return the stripped custom title if it is non-blank, else the node type's display name."""
        custom = self.data.custom_title
        if custom:
            title = str(custom).strip()
            if title:
                return title
        return self.node_type.display_name

    def get_state(self) -> Dict[str, Any]:
        """Return the node's persistable state as one flat dictionary.

        Output port values are serialized with ``DefaultSerializer`` using each
        port's declared Python type; the data fields come from ``data.to_dict()``.
        On a key collision the data field wins (it is merged last).
        """
        outputs_dict = {
            name: DefaultSerializer.serialize(port.value, source_type=port.schema.dtype.py_type)
            for name, port in (self.outputs or {}).items()
        }
        data_dict = self.data.to_dict()
        # TODO: ports and data sharing one namespace is fragile; consider
        # prefixing keys or nesting them in separate objects.
        return {**outputs_dict, **data_dict}

    def set_state(self, state: dict[str, Any]) -> None:
        """Restore state produced by ``get_state``.

        Updates the data payload, then assigns any matching output and input
        port values (marking those ports clean). An empty state is a no-op.
        """
        if not state:
            return

        logger.debug("set_state %s (%s)", self.node_id, self.node_type.kind)
        self.data.update_from_dict(state)
        self._set_port_values(self.outputs, state)
        self._set_port_values(self.inputs, state)

    def duplicate(self, new_id: str) -> BaseNode:
        """Create a copy of this node with a new identifier.

        A new instance of the same concrete class is built from
        ``(new_id, self.node_type)``; the serialisable state is transferred via
        ``get_state``/``set_state`` and every non-``None`` input port value is
        copied as well. PIL images are copied so the two nodes do not share a
        mutable image; if copying fails the original image object is shared.
        Subclasses can override this to copy additional attributes.
        """
        cls = type(self)
        # Concrete subclasses are expected to accept (node_id, node_type) positionally.
        new_node: BaseNode = cls(new_id, self.node_type)  # type: ignore[call-arg]

        new_node.set_state(self.get_state())

        target_inputs = new_node.inputs or {}
        for name, port in (self.inputs or {}).items():
            if name not in target_inputs or port.value is None:
                continue
            value = port.value
            if isinstance(value, Image.Image):
                try:
                    value = value.copy()
                except Exception:
                    # Best effort: fall back to sharing the original image.
                    logger.warning(
                        "Could not copy image for input %r of node %s; sharing the original.",
                        name,
                        self.node_id,
                        exc_info=True,
                    )
            target_inputs[name].value = value

        return new_node

    async def process(self) -> None:
        """Execute the node if any output is dirty.

        On failure the node is reset, the exception text is stored in
        ``data.error_message`` and the exception is re-raised.
        """
        if not self.has_dirty_outputs():
            return

        try:
            await self.on_node_execution()
        except Exception as e:
            logger.exception("Node execution failed.")
            self.reset_node()
            self.data.error_message = str(e)
            raise

    def reset_node(self) -> None:
        """Replace ``data`` with a fresh ``data_cls()`` and reset the outputs.

        The user-chosen ``custom_title`` is carried over so a reset (e.g. after
        a failed execution) does not rename the node.
        """
        custom_title = self.data.custom_title
        self.data = self.data_cls()
        self.data.custom_title = custom_title
        self.reset_outputs()

    @staticmethod
    def _set_port_values(ports: dict[str, PortState] | None, data_dict: dict[str, Any]) -> None:
        """Deserialize values from ``data_dict`` into same-named ports and mark them clean.

        Keys without a matching port are ignored; ``ports`` may be ``None``.
        """
        if not ports:
            return
        for field_name, value in data_dict.items():
            if field_name not in ports:
                continue
            port = ports[field_name]
            port.value = DefaultSerializer.de_serialize(value, target_type=port.schema.dtype.py_type)
            port.state = DataPortState.CLEAN

    async def on_queue_for_execution(self) -> None:
        """Clear the previous error message and progress information."""
        self.data.error_message = ""
        self.data.progress_value = None
        self.data.progress_message = None

    async def on_node_execution(self) -> None:
        """Hook for the node's actual work; the base implementation does nothing."""

    async def on_generation_result(self, result: GenerationResult) -> None:
        """Add ``result.cost`` to the current user's generation cost and save the user info."""
        ctx = await app_context.current_app_context()
        info = ctx.user_info
        info.generation.cost += result.cost
        info.save(ctx.user_storage)