Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| from abc import ABC | |
| from dataclasses import dataclass | |
| from typing import Any, Callable, ClassVar, Counter, Dict, Generic, TypeVar | |
| from PIL import Image | |
| from velai import app_context | |
| from velai.dataflow.enums import DataPortState | |
| from velai.dataflow.nodes import NodeInstance, NodeType | |
| from velai.dataflow.ports import PortState | |
| from velai.serialization.JsonSerializable import DataclassJsonSerializable | |
| from velai.serialization.JsonTypeSerializer import DefaultSerializer | |
| from velai.services.generator_service import GenerationResult | |
# Logger for this module.
logger = logging.getLogger(__name__)


class NameConflictError(ValueError):
    """Raised when a node's input, output and data field names are not unique."""
@dataclass
class BaseNodeData(DataclassJsonSerializable):
    """JSON-serialisable state shared by every node's ``data`` object.

    Node-specific data classes subclass this and add their own fields. The
    class is a dataclass so that ``DataclassJsonSerializable`` (``to_dict`` /
    ``update_from_dict``) can discover these fields; ``dataclass`` is imported
    by this module for exactly this purpose.
    """

    # Message of the last failed execution; None or "" means "no error".
    error_message: str | None = None
    # Progress of the current execution. Presumably a fraction in [0, 1] —
    # NOTE(review): confirm the expected range with the UI code.
    progress_value: float | None = None
    # Optional status text accompanying progress_value.
    progress_message: str | None = None
    # User-chosen title; overrides the node type's display name when non-blank.
    custom_title: str | None = None


# Type of the concrete data object carried by a node (see BaseNode.data_cls).
T_DATA = TypeVar("T_DATA", bound=BaseNodeData)
class BaseNode(NodeInstance, Generic[T_DATA], ABC):
    """Dataflow node that carries a serialisable ``data`` object.

    Port handling comes from ``NodeInstance``. On top of that, each node owns
    a ``BaseNodeData`` instance (``self.data``) with its internal state, its
    error message and its progress information. Subclasses set ``data_cls``
    to their own data class and override ``on_node_execution``.
    """

    # Data class instantiated when no ``data`` is passed to ``__init__`` and
    # whenever the node is reset via ``reset_node``.
    data_cls: ClassVar[type[BaseNodeData]] = BaseNodeData

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        data: T_DATA | None = None,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None] | None = None,
    ) -> None:
        """Initialise ports via ``NodeInstance`` and attach the data object.

        Args:
            node_id: Identifier of this node instance.
            node_type: Type descriptor of the node.
            data: Initial data object; a fresh ``data_cls()`` when ``None``.
            auto_process, x, y, width, height, inputs, outputs, on_process:
                Forwarded unchanged to ``NodeInstance.__init__``.

        Raises:
            NameConflictError: If input, output and data field names overlap.
        """
        super().__init__(
            node_id=node_id,
            node_type=node_type,
            auto_process=auto_process,
            x=x,
            y=y,
            width=width,
            height=height,
            inputs=inputs,
            outputs=outputs,
            on_process=on_process,
        )
        if data is None:
            data = self.data_cls()
        self.data = data
        # get_state/set_state keep port values and data fields in one flat
        # dict, so every name must be unique across the three groups.
        self._check_for_name_conflict()

    def _check_for_name_conflict(self) -> None:
        """Raise ``NameConflictError`` if any input/output/data name repeats."""
        # collections.Counter rather than the deprecated typing.Counter alias.
        from collections import Counter

        variable_names = [
            *(port.name for port in self.all_inputs()),
            *(port.name for port in self.all_outputs()),
            *self.data.to_dict().keys(),
        ]
        counts = Counter(variable_names)
        conflicts = [name for name, count in counts.items() if count > 1]
        if conflicts:
            raise NameConflictError(f"Duplicate variable names detected: {', '.join(conflicts)}")

    def get_display_title(self) -> str:
        """Return the custom title when it is non-blank, else the type's display name."""
        custom = self.data.custom_title
        title = str(custom).strip() if custom else ""
        return title or self.node_type.display_name

    def get_state(self) -> Dict[str, Any]:
        """Return the node's persistable state as one flat dict.

        The dict holds the serialised value of every output port plus every
        field of ``self.data``. On a key clash, the data field wins; names are
        checked for uniqueness at construction time.
        """
        outputs_dict = {
            name: DefaultSerializer.serialize(port.value, source_type=port.schema.dtype.py_type)
            for name, port in (self.outputs or {}).items()
        }
        data_dict = self.data.to_dict()
        # TODO: ports and data share one key namespace, which is fragile.
        # Consider attribute prefixes or separate sub-objects.
        state: Dict[str, Any] = {**outputs_dict, **data_dict}
        return state

    def set_state(self, state: dict[str, Any]) -> None:
        """Restore the node from a flat state dict such as ``get_state`` returns.

        ``state`` is passed to ``self.data.update_from_dict``. Keys that match
        an output or input port are then deserialised into that port, and the
        port is marked clean. An empty or ``None`` state is ignored.
        """
        if not state:
            return
        logger.debug("set_state %s (%s)", self.node_id, self.node_type.kind)
        self.data.update_from_dict(state)
        self._set_port_values(self.outputs, state)
        self._set_port_values(self.inputs, state)

    def duplicate(self, new_id: str) -> "BaseNode":
        """Create a copy of this node with a new identifier.

        A new node of the same concrete class is instantiated with the same
        ``node_type``. The serialisable state is copied via ``get_state`` and
        ``set_state``, and all non-``None`` input port values are copied as
        well. PIL images are duplicated where possible; otherwise the original
        image is shared. Subclasses can override this to copy additional
        attributes.
        """
        # Assumes the concrete class accepts (node_id, node_type) positionally.
        new_node: BaseNode = type(self)(new_id, self.node_type)  # type: ignore[call-arg]
        new_node.set_state(self.get_state())

        target_inputs = new_node.inputs or {}
        for name, port in (self.inputs or {}).items():
            if name not in target_inputs or port.value is None:
                continue
            value = port.value
            # PIL images are mutable; give the duplicate its own copy.
            if isinstance(value, Image.Image):
                try:
                    value = value.copy()
                except Exception:
                    # Best effort: fall back to sharing the image, but leave a trace.
                    logger.warning(
                        "Could not copy image on input %r of node %s; sharing it instead.",
                        name,
                        self.node_id,
                        exc_info=True,
                    )
            target_inputs[name].value = value
        return new_node

    async def process(self) -> None:
        """Run ``on_node_execution`` if any output is dirty.

        On failure, the exception is logged and the node is reset. The error
        text is then stored in ``self.data.error_message`` and the exception
        is re-raised.
        """
        if not self.has_dirty_outputs():
            return  # nothing to recompute
        try:
            await self.on_node_execution()
        except Exception as e:
            logger.exception("Node execution failed.")
            # NOTE(review): reset_node replaces the whole data object, which also
            # discards user-set fields such as custom_title. Confirm this is intended.
            self.reset_node()
            # Some exceptions (e.g. a bare TimeoutError()) stringify to "", which
            # would look like the "no error" value set in on_queue_for_execution.
            self.data.error_message = str(e) or type(e).__name__
            raise

    def reset_node(self) -> None:
        """Replace ``self.data`` with a fresh ``data_cls()`` and reset all outputs."""
        self.data = self.data_cls()
        self.reset_outputs()

    @staticmethod
    def _set_port_values(ports: dict[str, PortState] | None, data_dict: dict[str, Any]) -> None:
        """Deserialise matching ``data_dict`` entries into ``ports`` and mark them clean.

        Keys without a matching port are ignored. ``ports`` may be ``None``
        (no ports), mirroring the ``self.outputs or {}`` guard in ``get_state``.
        """
        if not ports:
            return
        for field_name, value in data_dict.items():
            if field_name not in ports:
                continue
            port = ports[field_name]
            port.value = DefaultSerializer.de_serialize(value, target_type=port.schema.dtype.py_type)
            port.state = DataPortState.CLEAN

    async def on_queue_for_execution(self) -> None:
        """Clear the error and progress fields before the node is executed."""
        self.data.error_message = ""
        self.data.progress_value = None
        self.data.progress_message = None

    async def on_node_execution(self) -> None:
        """Hook for subclasses to perform the node's work; does nothing by default."""
        pass

    async def on_generation_result(self, result: GenerationResult) -> None:
        """Add ``result.cost`` to the current user's generation cost and save it."""
        ctx = await app_context.current_app_context()
        info = ctx.user_info
        info.generation.cost += result.cost
        info.save(ctx.user_storage)