# velai-workshop / velai/nodes/base_node.py
# kratadata's picture
# Upload folder via script
# 0f8b3a0 verified
from __future__ import annotations
import logging
from abc import ABC
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, Counter, Dict, Generic, TypeVar
from PIL import Image
from velai import app_context
from velai.dataflow.enums import DataPortState
from velai.dataflow.nodes import NodeInstance, NodeType
from velai.dataflow.ports import PortState
from velai.serialization.JsonSerializable import DataclassJsonSerializable
from velai.serialization.JsonTypeSerializer import DefaultSerializer
from velai.services.generator_service import GenerationResult
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class NameConflictError(ValueError):
    """Raised when input, output and data field names of a node are not unique."""
@dataclass(slots=True)
class BaseNodeData(DataclassJsonSerializable):
    """Serialisable per-node payload shared by every node type.

    All fields are optional and default to ``None``.
    """

    # Text of the last execution failure; BaseNode.process stores str(exc) here.
    error_message: str | None = None
    # Progress fields; cleared when the node is queued for execution.
    # NOTE(review): value range / units are not defined in this module - confirm.
    progress_value: float | None = None
    progress_message: str | None = None
    # User-supplied title; when non-blank it replaces the node type's display name.
    custom_title: str | None = None
# Type variable for a node's payload class; bound to BaseNodeData and its subclasses.
T_DATA = TypeVar("T_DATA", bound=BaseNodeData)
class BaseNode(NodeInstance, Generic[T_DATA], ABC):
    """Abstract node that pairs its ports with a serialisable ``data`` payload.

    Subclasses set ``data_cls`` to their own ``BaseNodeData`` subclass and
    override ``on_node_execution`` to do the actual work. Input names, output
    names and data field names must all be distinct because ``get_state`` /
    ``set_state`` store them in one flat dictionary.
    """

    # Payload class instantiated when no ``data`` is passed and on reset.
    data_cls: ClassVar[type[BaseNodeData]] = BaseNodeData

    def __init__(
        self,
        node_id: str,
        node_type: NodeType,
        data: T_DATA | None = None,
        auto_process: bool = False,
        x: float = 0.0,
        y: float = 0.0,
        width: float = 250,
        height: float = 200,
        inputs: dict[str, PortState] | None = None,
        outputs: dict[str, PortState] | None = None,
        on_process: Callable[[NodeInstance], None] | None = None,
    ) -> None:
        """Initialise the node and verify that its field names are unique.

        Args:
            node_id: Identifier of this node instance.
            node_type: Type descriptor forwarded to ``NodeInstance``.
            data: Initial payload; a fresh ``data_cls()`` is used when ``None``.
            auto_process: Forwarded to ``NodeInstance``.
            x: Forwarded to ``NodeInstance``.
            y: Forwarded to ``NodeInstance``.
            width: Forwarded to ``NodeInstance``.
            height: Forwarded to ``NodeInstance``.
            inputs: Optional input port mapping, forwarded to ``NodeInstance``.
            outputs: Optional output port mapping, forwarded to ``NodeInstance``.
            on_process: Optional callback, forwarded to ``NodeInstance``.

        Raises:
            NameConflictError: If an input, output or data field name occurs
                more than once.
        """
        super().__init__(
            node_id=node_id,
            node_type=node_type,
            auto_process=auto_process,
            x=x,
            y=y,
            width=width,
            height=height,
            inputs=inputs,
            outputs=outputs,
            on_process=on_process,
        )
        if data is None:
            data = self.data_cls()
        self.data = data
        # Inputs, outputs and data fields share one namespace in the state dict.
        self._check_for_name_conflict()

    def _check_for_name_conflict(self) -> None:
        """Raise ``NameConflictError`` if any port or data field name repeats."""
        variable_names = [
            *[e.name for e in self.all_inputs()],
            *[e.name for e in self.all_outputs()],
            *self.data.to_dict().keys(),
        ]
        counts = Counter(variable_names)
        conflicts = [name for name, count in counts.items() if count > 1]
        if conflicts:
            raise NameConflictError(f"Duplicate variable names detected: {', '.join(conflicts)}")

    def get_display_title(self) -> str:
        """Return the stripped custom title, or the node type's display name if blank."""
        custom = self.data.custom_title
        if custom:
            title = str(custom).strip()
            if title:
                return title
        return self.node_type.display_name

    def get_state(self) -> Dict[str, Any]:
        """Return the serialised output port values merged with the data payload.

        Data fields are merged last, so on a shared key the data value wins.
        """
        # capture values of output ports
        outputs_dict = {
            name: DefaultSerializer.serialize(port.value, source_type=port.schema.dtype.py_type)
            for name, port in (self.outputs or {}).items()
        }
        # add internal node state
        data_dict = self.data.to_dict()
        # todo: inputs, outputs and data dict share the names is problematic
        # idea: use different attribute-prefixes or objects
        state: Dict[str, Any] = {**outputs_dict, **data_dict}
        return state

    def set_state(self, state: dict[str, Any]) -> None:
        """Restore the data payload and matching port values from ``state``.

        An empty or ``None`` state is ignored. Every port whose name appears in
        ``state`` receives the deserialised value and is marked clean.
        """
        if not state:
            return
        logger.debug("set_state %s (%s)", self.node_id, self.node_type.kind)
        self.data.update_from_dict(state)
        self._set_port_values(self.outputs, state)
        self._set_port_values(self.inputs, state)

    def duplicate(self, new_id: str) -> "BaseNode":
        """Create a copy of this node with a new identifier.

        A new node of the same concrete class is created with ``new_id`` and
        the same ``node_type``; all other constructor arguments take their
        defaults. The serialisable state is transferred through ``get_state``
        and ``set_state``, then every non-``None`` input value is carried over.
        PIL images are copied; if copying fails the original image object is
        shared and a warning is logged. Subclasses can override this to copy
        additional attributes.

        Args:
            new_id: Identifier for the new node.

        Returns:
            The newly created node.
        """
        new_node: BaseNode = type(self)(new_id, self.node_type)  # type: ignore[call-arg]
        # copy over persisted state
        new_node.set_state(self.get_state())
        # copy input port values; either side may have no input mapping
        target_inputs = new_node.inputs or {}
        for name, port in (self.inputs or {}).items():
            if name not in target_inputs or port.value is None:
                continue
            new_val = port.value
            if isinstance(new_val, Image.Image):
                # Best effort: prefer an independent image, fall back to sharing.
                try:
                    new_val = new_val.copy()
                except Exception:
                    logger.warning(
                        "Could not copy image on input '%s' of node %s; sharing the original.",
                        name,
                        self.node_id,
                        exc_info=True,
                    )
            target_inputs[name].value = new_val
        return new_node

    async def process(self) -> None:
        """Run ``on_node_execution`` unless no output is dirty.

        Raises:
            Exception: Whatever ``on_node_execution`` raised, re-raised after
                the node was reset and the error text stored on ``data``.
        """
        # early exit if already processed
        if not self.has_dirty_outputs():
            return
        try:
            # run actual execution of the node
            await self.on_node_execution()
        except Exception as e:
            logger.exception("Node execution failed.")
            # NOTE: reset_node replaces ``self.data``, so the message has to be
            # written afterwards; other payload fields are back to defaults.
            self.reset_node()
            self.data.error_message = str(e)
            raise

    def reset_node(self) -> None:
        """Replace the payload with a fresh ``data_cls()`` and reset the outputs."""
        # reset internal state
        self.data = self.data_cls()
        self.reset_outputs()

    @staticmethod
    def _set_port_values(ports: dict[str, PortState] | None, data_dict: dict[str, Any]) -> None:
        """Write deserialised ``data_dict`` values into same-named ``ports``.

        Keys without a matching port are skipped; updated ports are marked
        ``DataPortState.CLEAN``. A missing or empty port mapping is a no-op.
        """
        if not ports:
            return
        for field_name, value in data_dict.items():
            port = ports.get(field_name)
            if port is None:
                continue
            port.value = DefaultSerializer.de_serialize(value, target_type=port.schema.dtype.py_type)
            port.state = DataPortState.CLEAN

    async def on_queue_for_execution(self) -> None:
        """Clear error and progress information when the node is queued."""
        # NOTE(review): the error is cleared to "" although the field defaults
        # to None - confirm whether consumers distinguish the two.
        self.data.error_message = ""
        self.data.progress_value = None
        self.data.progress_message = None

    async def on_node_execution(self) -> None:
        """Hook awaited by ``process``; does nothing in the base class."""

    async def on_generation_result(self, result: GenerationResult) -> None:
        """Add ``result.cost`` to the current user's generation cost and save it."""
        ctx = await app_context.current_app_context()
        info = ctx.user_info
        info.generation.cost += result.cost
        info.save(ctx.user_storage)