# velai-workshop / velai/nodes/base_generative_node.py
# Author: kratadata — commit 0f8b3a0 ("Upload folder via script", verified)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar
from nicegui import ui
from velai.nodes.actions.node_action_decorator import as_action_name, node_action
from velai.nodes.actions.node_action_models import AsyncNodeActionResultIterator, NodeActionArguments, NodeActionResult
from velai.nodes.base_node import BaseNode, BaseNodeData
from velai.nodes.base_node_renderable import BaseNodeRenderable
@dataclass
class DownloadData:
src: str | Path | bytes
filename: str | None = None
mime_type: str = ""
@dataclass()
class GenerativeNodeData(BaseNodeData):
    """Data container for generative nodes; declares no fields beyond BaseNodeData."""


# Type variable for a generative node's data payload; bound to GenerativeNodeData.
T_GENERATIVE_DATA = TypeVar("T_GENERATIVE_DATA", bound=GenerativeNodeData)
# Unconstrained type variable, presumably for generation results (confirm usage).
T_RESULT = TypeVar("T_RESULT")
class GenerativeNode(BaseNode[T_GENERATIVE_DATA], Generic[T_GENERATIVE_DATA], ABC):
    """Abstract base for nodes that produce content the user can download.

    Concrete subclasses must implement :meth:`get_download_data`.
    """

    data_cls = GenerativeNodeData

    @abstractmethod
    async def get_download_data(self) -> DownloadData | None:
        """Return the node's downloadable content, or None when there is none."""
        raise NotImplementedError

    def _on_progress(self, value: float, message: str | None) -> None:
        """Store the progress value and message on the node's data object."""
        self.data.progress_value, self.data.progress_message = value, message


# Type variable for renderables; bound to GenerativeNode subclasses.
T_GENERATIVE_NODE = TypeVar("T_GENERATIVE_NODE", bound=GenerativeNode)
class GenerativeNodeRenderable(BaseNodeRenderable[T_GENERATIVE_NODE], Generic[T_GENERATIVE_NODE], ABC):
    """Renderable base for generative nodes.

    Appends "run", "reset" and "download" buttons to the node header. It also
    provides the node actions behind the reset and download buttons.
    """

    @node_action
    async def _on_reset_node_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Ask the user for confirmation, then reset the outputs of ``args.node``.

        Yields ``NodeActionResult.update_node()`` after a confirmed reset. Yields
        nothing when the dialog result is anything other than "Yes".
        """
        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure to reset the node?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))
        result = await dialog
        # A new dialog element is built on every invocation. Remove it once it has
        # been answered so repeated resets do not accumulate hidden dialogs.
        dialog.delete()
        if result != "Yes":
            return
        args.node.reset_outputs()
        ui.notify("Node has been reset!", type="positive")
        yield NodeActionResult.update_node()

    @node_action
    async def _on_download_data_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Start a download of the node's content in the browser.

        Does nothing for nodes that are not ``GenerativeNode`` instances. Shows a
        warning when ``get_download_data`` returns None. Otherwise it starts the
        download and yields ``NodeActionResult.discard()``.
        """
        node = args.node
        if not isinstance(node, GenerativeNode):
            return
        data = await node.get_download_data()
        if data is None:
            ui.notify("Nothing to download!", type="warning")
            return
        ui.notify("Download started...", type="positive")
        if isinstance(data, str):
            # Outside the declared DownloadData return type. Kept so that an
            # implementation returning a bare string still works: the string is
            # opened in a new tab.
            ui.navigate.to(data, new_tab=True)
        else:
            ui.download(data.src, data.filename, data.mime_type)
        yield NodeActionResult.discard()

    def get_header_buttons(self, node: T_GENERATIVE_NODE) -> list[dict[str, Any]]:
        """Return the base class's header buttons followed by run, reset and download.

        Builds a fresh list rather than extending the base result with ``+=``.
        This way a list owned or cached by the base implementation is never
        mutated, and buttons cannot accumulate across calls.
        """
        return [
            *super().get_header_buttons(node),
            {
                "name": "run",
                "icon": "rocket_launch",
                "tooltip": "Generate",
                "action": "execute_node",
                "disableWhileProcessing": True,
            },
            {
                "name": "reset",
                "icon": "delete_forever",
                "tooltip": "Reset",
                "action": as_action_name(self._on_reset_node_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
            {
                "name": "download",
                "icon": "download",
                "tooltip": "Download",
                "action": as_action_name(self._on_download_data_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
        ]

    def _requires_content_fields(self) -> list[str]:
        """Return the field names sent as ``requiresContent`` for reset/download.

        Empty by default; subclasses may override it. Presumably the frontend
        enables those buttons only once these fields have content (confirm).
        """
        return []