from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar

from nicegui import ui

from velai.nodes.actions.node_action_decorator import as_action_name, node_action
from velai.nodes.actions.node_action_models import (
    AsyncNodeActionResultIterator,
    NodeActionArguments,
    NodeActionResult,
)
from velai.nodes.base_node import BaseNode, BaseNodeData
from velai.nodes.base_node_renderable import BaseNodeRenderable


@dataclass
class DownloadData:
    """Payload describing what a node offers for download.

    The three fields are forwarded, in order, to ``ui.download`` by
    ``GenerativeNodeRenderable._on_download_data_action``.
    """

    # Content to download; passed as the first argument of ``ui.download``.
    src: str | Path | bytes
    # Optional file name passed as the second argument of ``ui.download``.
    filename: str | None = None
    # Passed as the third argument of ``ui.download``; defaults to an empty string.
    mime_type: str = ""


@dataclass
class GenerativeNodeData(BaseNodeData):
    """Data container for generative nodes; adds no fields to ``BaseNodeData``."""


T_RESULT = TypeVar("T_RESULT")
T_GENERATIVE_DATA = TypeVar("T_GENERATIVE_DATA", bound=GenerativeNodeData)


class GenerativeNode(BaseNode[T_GENERATIVE_DATA], Generic[T_GENERATIVE_DATA], ABC):
    """Abstract base for nodes that can expose their result as a download."""

    data_cls = GenerativeNodeData

    @abstractmethod
    async def get_download_data(self) -> DownloadData | None:
        """Return the data to download, or ``None`` when there is nothing to download.

        Subclasses must override this method.
        """
        raise NotImplementedError()

    def _on_progress(self, value: float, message: str | None) -> None:
        """Store the latest progress value and message on the node's data.

        Args:
            value: Progress value written to ``self.data.progress_value``.
            message: Optional text written to ``self.data.progress_message``.
        """
        # NOTE(review): ``progress_value`` / ``progress_message`` are not declared
        # in this file - presumably they are fields of ``BaseNodeData``; confirm.
        self.data.progress_value = value
        self.data.progress_message = message


T_GENERATIVE_NODE = TypeVar("T_GENERATIVE_NODE", bound=GenerativeNode)


class GenerativeNodeRenderable(BaseNodeRenderable[T_GENERATIVE_NODE], Generic[T_GENERATIVE_NODE], ABC):
    """Renderable for generative nodes: adds run, reset and download header buttons."""

    @node_action
    async def _on_reset_node_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Ask for confirmation and, if confirmed, reset the node's outputs.

        Yields a single ``NodeActionResult.update_node()`` after a confirmed
        reset. Yields nothing when the dialog result is anything other than
        ``"Yes"``.
        """
        # Confirm with the user before resetting.
        # NOTE(review): a new dialog element is created on every invocation and
        # is never removed here - confirm whether it should be cleaned up.
        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure to reset the node?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))

        # Awaiting the dialog yields the value handed to ``dialog.submit``.
        result = await dialog
        if result != "Yes":
            return

        # Reset the outputs of the node and tell the caller to refresh it.
        args.node.reset_outputs()
        ui.notify("Node has been reset!", type="positive")
        yield NodeActionResult.update_node()

    @node_action
    async def _on_download_data_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Start a download of the node's data, if it has any.

        For a ``GenerativeNode`` the download data is fetched; when it is
        ``None`` a warning is shown and the action ends without yielding.
        Otherwise the download is started and ``NodeActionResult.discard()``
        is yielded. Nodes that are not ``GenerativeNode`` instances go straight
        to the final ``discard`` result.
        """
        if isinstance(args.node, GenerativeNode):
            data = await args.node.get_download_data()
            if data is None:
                ui.notify("Nothing to download!", type="warning")
                return

            ui.notify("Download started...", type="positive")
            # NOTE(review): ``get_download_data`` is annotated as returning
            # ``DownloadData | None``, so the ``str`` branch only triggers for
            # implementations that return a plain string - confirm it is needed.
            if isinstance(data, str):
                ui.navigate.to(data, new_tab=True)
            else:
                ui.download(data.src, data.filename, data.mime_type)

        yield NodeActionResult.discard()

    def get_header_buttons(self, node: T_GENERATIVE_NODE) -> list[dict[str, Any]]:
        """Return the parent's header buttons followed by run, reset and download.

        A new list is built on every call so the list object returned by the
        parent implementation is never modified in place.

        Args:
            node: The node the buttons are built for; forwarded to the parent.

        Returns:
            The parent's button definitions plus the three buttons added here.
        """
        # NOTE(review): the meaning of the dict keys and of the "execute_node"
        # action name is defined by whatever consumes these dicts, outside this file.
        return [
            *super().get_header_buttons(node),
            {
                "name": "run",
                "icon": "rocket_launch",
                "tooltip": "Generate",
                "action": "execute_node",
                "disableWhileProcessing": True,
            },
            {
                "name": "reset",
                "icon": "delete_forever",
                "tooltip": "Reset",
                "action": as_action_name(self._on_reset_node_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
            {
                "name": "download",
                "icon": "download",
                "tooltip": "Download",
                "action": as_action_name(self._on_download_data_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
        ]

    def _requires_content_fields(self) -> list[str]:
        """Return the value used for ``requiresContent`` on the reset and download buttons.

        The base implementation returns an empty list.
        """
        return []