Spaces:
Running
Running
File size: 3,811 Bytes
0f8b3a0 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar
from nicegui import ui
from velai.nodes.actions.node_action_decorator import as_action_name, node_action
from velai.nodes.actions.node_action_models import AsyncNodeActionResultIterator, NodeActionArguments, NodeActionResult
from velai.nodes.base_node import BaseNode, BaseNodeData
from velai.nodes.base_node_renderable import BaseNodeRenderable
@dataclass
class DownloadData:
src: str | Path | bytes
filename: str | None = None
mime_type: str = ""
@dataclass
class GenerativeNodeData(BaseNodeData):
    """Node data for generative nodes.

    Declares no fields beyond those inherited from ``BaseNodeData``; it serves
    as the bound of ``T_GENERATIVE_DATA`` and as ``GenerativeNode.data_cls``.
    """
# Unconstrained type variable for a result type.
T_RESULT = TypeVar("T_RESULT")
# Type variable for a generative node's data; restricted to GenerativeNodeData subclasses.
T_GENERATIVE_DATA = TypeVar("T_GENERATIVE_DATA", bound=GenerativeNodeData)
class GenerativeNode(BaseNode[T_GENERATIVE_DATA], Generic[T_GENERATIVE_DATA], ABC):
    """Abstract node that produces content the user may download.

    Concrete subclasses must implement :meth:`get_download_data`.
    """

    # Data class associated with this node type.
    data_cls = GenerativeNodeData

    @abstractmethod
    async def get_download_data(self) -> DownloadData | None:
        """Return the download description for this node, or ``None`` when there is nothing to download."""
        raise NotImplementedError

    def _on_progress(self, value: float, message: str | None):
        """Store a progress update on the node's data.

        Writes ``value`` to ``self.data.progress_value`` and ``message`` to
        ``self.data.progress_message``, in that order.
        """
        self.data.progress_value, self.data.progress_message = value, message
# Type variable for the node a renderable draws; restricted to GenerativeNode subclasses.
T_GENERATIVE_NODE = TypeVar("T_GENERATIVE_NODE", bound=GenerativeNode)
class GenerativeNodeRenderable(BaseNodeRenderable[T_GENERATIVE_NODE], Generic[T_GENERATIVE_NODE], ABC):
    """Renderable for generative nodes.

    Adds "run", "reset" and "download" buttons to the node header and
    implements the reset and download actions those buttons trigger.
    """

    @node_action
    async def _on_reset_node_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Ask the user to confirm, then reset the node's outputs.

        Yields one ``NodeActionResult.update_node()`` after a confirmed reset.
        Yields nothing when the user answers anything other than "Yes"
        (including dismissing the dialog).
        """
        # Confirm with the user before resetting.
        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure to reset the node?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))
        try:
            result = await dialog
        finally:
            # A new dialog element is created on every invocation; delete it once
            # it has been answered so repeated resets do not accumulate dialogs.
            dialog.delete()
        if result != "Yes":
            return
        # Reset the outputs of the node.
        args.node.reset_outputs()
        ui.notify("Node has been reset!", type="positive")
        yield NodeActionResult.update_node()

    @node_action
    async def _on_download_data_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Start a download of the node's data in the browser.

        Nodes that are not ``GenerativeNode`` instances are ignored. When
        ``get_download_data`` returns ``None`` a warning is shown and nothing is
        yielded; otherwise ``NodeActionResult.discard()`` is yielded.
        """
        if isinstance(args.node, GenerativeNode):
            data = await args.node.get_download_data()
            if data is None:
                ui.notify("Nothing to download!", type="warning")
                return
            ui.notify("Download started...", type="positive")
            if isinstance(data, str):
                # get_download_data is annotated to return DownloadData, but a
                # bare string is still tolerated and opened in a new tab.
                ui.navigate.to(data, new_tab=True)
            else:
                ui.download(data.src, data.filename, data.mime_type)
        yield NodeActionResult.discard()

    def get_header_buttons(self, node: T_GENERATIVE_NODE) -> list[dict[str, Any]]:
        """Return the parent's header buttons followed by run, reset and download.

        A new list is built instead of extending the parent's result in place,
        so a list owned by the base implementation is never mutated.
        """
        return [
            *super().get_header_buttons(node),
            {
                "name": "run",
                "icon": "rocket_launch",
                "tooltip": "Generate",
                "action": "execute_node",
                "disableWhileProcessing": True,
            },
            {
                "name": "reset",
                "icon": "delete_forever",
                "tooltip": "Reset",
                "action": as_action_name(self._on_reset_node_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
            {
                "name": "download",
                "icon": "download",
                "tooltip": "Download",
                "action": as_action_name(self._on_download_data_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
        ]

    def _requires_content_fields(self) -> list[str]:
        """Return field names sent as ``requiresContent`` for the reset/download buttons.

        The default is an empty list; subclasses may override it. Presumably the
        frontend enables those buttons only when these fields have content
        (confirm against the frontend).
        """
        return []
|