Spaces:
Running
Running
| from abc import ABC, abstractmethod | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Generic, TypeVar | |
| from nicegui import ui | |
| from velai.nodes.actions.node_action_decorator import as_action_name, node_action | |
| from velai.nodes.actions.node_action_models import AsyncNodeActionResultIterator, NodeActionArguments, NodeActionResult | |
| from velai.nodes.base_node import BaseNode, BaseNodeData | |
| from velai.nodes.base_node_renderable import BaseNodeRenderable | |
| class DownloadData: | |
| src: str | Path | bytes | |
| filename: str | None = None | |
| mime_type: str = "" | |
class GenerativeNodeData(BaseNodeData):
    """Data model for generative nodes; declares no fields beyond ``BaseNodeData``."""
# Generic type variable, presumably for a node's result type.
# NOTE(review): not referenced in the code visible in this file; confirm it is used elsewhere before removing.
T_RESULT = TypeVar("T_RESULT")
# Type variable for a generative node's data model; bound to GenerativeNodeData and its subclasses.
T_GENERATIVE_DATA = TypeVar("T_GENERATIVE_DATA", bound=GenerativeNodeData)
class GenerativeNode(BaseNode[T_GENERATIVE_DATA], Generic[T_GENERATIVE_DATA], ABC):
    """Base class for nodes whose data model is a ``GenerativeNodeData``.

    Provides a placeholder ``get_download_data`` that concrete nodes override,
    and a helper that records progress updates on the node's data object.
    """

    data_cls = GenerativeNodeData

    async def get_download_data(self) -> DownloadData | None:
        """Return the node's downloadable payload, or ``None`` when there is nothing to download.

        Raises:
            NotImplementedError: always; this base implementation is a placeholder.
        """
        raise NotImplementedError

    def _on_progress(self, value: float, message: str | None):
        """Record a progress update on ``self.data``.

        Args:
            value: Progress value stored as ``progress_value``.
            message: Optional text stored as ``progress_message``.
        """
        # NOTE(review): presumably registered as a progress callback by subclasses — confirm.
        node_data = self.data
        node_data.progress_value = value
        node_data.progress_message = message
# Type variable for the node a GenerativeNodeRenderable renders; bound to GenerativeNode and its subclasses.
T_GENERATIVE_NODE = TypeVar("T_GENERATIVE_NODE", bound=GenerativeNode)
class GenerativeNodeRenderable(BaseNodeRenderable[T_GENERATIVE_NODE], Generic[T_GENERATIVE_NODE], ABC):
    """Renderable base class for ``GenerativeNode`` subclasses.

    Appends "run", "reset" and "download" buttons to the header buttons
    inherited from ``BaseNodeRenderable``. The reset and download buttons are
    wired to the action handlers defined on this class via ``as_action_name``.
    """

    # NOTE(review): ``node_action`` is imported by this module but not applied in the
    # visible code; confirm whether these handlers need that decorator to be registered.
    async def _on_reset_node_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Reset the outputs of ``args.node`` after the user confirms.

        Shows a Yes/No dialog. If the answer is anything other than ``"Yes"``,
        nothing is changed and nothing is yielded. On confirmation the node's
        outputs are reset, a success notification is shown, and a single
        ``NodeActionResult.update_node()`` is yielded.
        """
        # Confirm with the user before resetting.
        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure to reset the node?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))
        try:
            # Wait for the value passed to ``dialog.submit``.
            result = await dialog
        finally:
            # A new dialog element is built on every invocation. Remove it once it has
            # been answered (or the action is cancelled) so repeated resets do not
            # accumulate stale dialogs in the page.
            dialog.delete()
        if result != "Yes":
            return

        # Reset the outputs of the node.
        args.node.reset_outputs()
        ui.notify("Node has been reset!", type="positive")
        yield NodeActionResult.update_node()

    async def _on_download_data_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Trigger a download of the node's data.

        Only ``GenerativeNode`` instances are handled; for any other node
        nothing happens and nothing is yielded. If ``get_download_data``
        returns ``None``, a warning is shown and nothing is yielded. Otherwise
        the download is started and ``NodeActionResult.discard()`` is yielded.
        """
        if not isinstance(args.node, GenerativeNode):
            return
        data = await args.node.get_download_data()
        if data is None:
            ui.notify("Nothing to download!", type="warning")
            return
        ui.notify("Download started...", type="positive")
        if isinstance(data, str):
            # NOTE(review): get_download_data is annotated to return DownloadData | None,
            # so this branch only runs for overrides that return a bare URL string.
            ui.navigate.to(data, new_tab=True)
        else:
            ui.download(data.src, data.filename, data.mime_type)
        yield NodeActionResult.discard()

    def get_header_buttons(self, node: T_GENERATIVE_NODE) -> list[dict[str, Any]]:
        """Return the inherited header buttons followed by run, reset and download.

        Args:
            node: The node whose header is being rendered; forwarded to the
                parent implementation.

        Returns:
            A new list with the parent's buttons first, then this class's three
            button specifications.
        """
        inherited = super().get_header_buttons(node)
        own_buttons: list[dict[str, Any]] = [
            {
                "name": "run",
                "icon": "rocket_launch",
                "tooltip": "Generate",
                "action": "execute_node",
                "disableWhileProcessing": True,
            },
            {
                "name": "reset",
                "icon": "delete_forever",
                "tooltip": "Reset",
                "action": as_action_name(self._on_reset_node_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
            {
                "name": "download",
                "icon": "download",
                "tooltip": "Download",
                "action": as_action_name(self._on_download_data_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
        ]
        # Build a fresh list instead of using ``+=``. This ensures the list returned by the
        # parent implementation is never mutated in place; if the parent caches or shares
        # that list, ``+=`` would make it grow on every call.
        return [*inherited, *own_buttons]

    def _requires_content_fields(self) -> list[str]:
        """Return the value used as ``requiresContent`` for the reset and download buttons.

        The base implementation returns an empty list; presumably subclasses
        override it to name the fields that must have content.
        """
        return []