File size: 3,811 Bytes
0f8b3a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Generic, TypeVar

from nicegui import ui

from velai.nodes.actions.node_action_decorator import as_action_name, node_action
from velai.nodes.actions.node_action_models import AsyncNodeActionResultIterator, NodeActionArguments, NodeActionResult
from velai.nodes.base_node import BaseNode, BaseNodeData
from velai.nodes.base_node_renderable import BaseNodeRenderable


@dataclass
class DownloadData:
    src: str | Path | bytes
    filename: str | None = None
    mime_type: str = ""


@dataclass
class GenerativeNodeData(BaseNodeData):
    """Data container for generative nodes.

    Declares no fields of its own on top of ``BaseNodeData``; it serves as
    the upper bound of ``T_GENERATIVE_DATA`` and as the ``data_cls`` of
    ``GenerativeNode``.
    """


# Unconstrained type variable for a result type.
# NOTE(review): not referenced in the visible part of this file — it may be
# imported by other modules; confirm before removing.
T_RESULT = TypeVar("T_RESULT")
# Type of the data object a GenerativeNode is parameterized with; restricted
# to GenerativeNodeData and its subclasses.
T_GENERATIVE_DATA = TypeVar("T_GENERATIVE_DATA", bound=GenerativeNodeData)


class GenerativeNode(BaseNode[T_GENERATIVE_DATA], Generic[T_GENERATIVE_DATA], ABC):
    """Abstract base for nodes whose data type derives from ``GenerativeNodeData``.

    Concrete subclasses must implement ``get_download_data``.
    """

    # Data class associated with this node type.
    # NOTE(review): presumably used by BaseNode to create ``self.data`` — confirm.
    data_cls = GenerativeNodeData

    @abstractmethod
    async def get_download_data(self) -> DownloadData | None:
        """Return the node's downloadable payload, or ``None`` if there is none.

        The download action of ``GenerativeNodeRenderable`` shows a
        "Nothing to download!" warning when ``None`` is returned.
        """
        raise NotImplementedError

    def _on_progress(self, value: float, message: str | None):
        """Store a progress update on the node's data object.

        Args:
            value: Progress value written to ``self.data.progress_value``.
            message: Optional text written to ``self.data.progress_message``.
        """
        # NOTE(review): both attributes are presumably declared on
        # BaseNodeData, since GenerativeNodeData adds no fields — confirm.
        self.data.progress_value = value
        self.data.progress_message = message


# Node type handled by GenerativeNodeRenderable; restricted to GenerativeNode
# and its subclasses.
T_GENERATIVE_NODE = TypeVar("T_GENERATIVE_NODE", bound=GenerativeNode)


class GenerativeNodeRenderable(BaseNodeRenderable[T_GENERATIVE_NODE], Generic[T_GENERATIVE_NODE], ABC):
    """Renderable for generative nodes.

    Adds "run", "reset" and "download" buttons to the header buttons of the
    base renderable and provides the node actions behind reset and download.
    """

    @node_action
    async def _on_reset_node_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Ask the user for confirmation, then reset the node's outputs.

        Args:
            args: Action arguments; ``args.node`` is the node to reset.

        Yields:
            ``NodeActionResult.update_node()`` after a confirmed reset.
            Nothing is yielded when the dialog result is anything but "Yes".
        """
        # Confirm with the user before resetting
        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure to reset the node?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))
        result = await dialog
        # Only an explicit "Yes" proceeds; any other result leaves the node untouched.
        if result != "Yes":
            return

        # Reset the outputs of a node
        args.node.reset_outputs()

        ui.notify("Node has been reset!", type="positive")
        yield NodeActionResult.update_node()

    @node_action
    async def _on_download_data_action(self, args: NodeActionArguments) -> AsyncNodeActionResultIterator:
        """Start a browser download of the node's downloadable data.

        Args:
            args: Action arguments; ``args.node`` is the node to download from.

        Yields:
            ``NodeActionResult.discard()`` once the download was triggered, or
            immediately when ``args.node`` is not a ``GenerativeNode``.
            Nothing is yielded when the node has no data to download.
        """
        node = args.node
        if not isinstance(node, GenerativeNode):
            # Only generative nodes expose download data; nothing to do otherwise.
            yield NodeActionResult.discard()
            return

        data = await node.get_download_data()
        if data is None:
            ui.notify("Nothing to download!", type="warning")
            return

        ui.notify("Download started...", type="positive")
        if isinstance(data, str):
            # ``get_download_data`` is annotated to return ``DownloadData | None``,
            # but a plain string is tolerated here and opened in a new tab.
            ui.navigate.to(data, new_tab=True)
        else:
            ui.download(data.src, data.filename, data.mime_type)
        yield NodeActionResult.discard()

    def get_header_buttons(self, node: T_GENERATIVE_NODE) -> list[dict[str, Any]]:
        """Return the base header buttons followed by run, reset and download.

        Args:
            node: The node the header is built for; forwarded to the base class.

        Returns:
            A new list of button descriptions. The list returned by the base
            implementation is copied, never extended in place.
        """
        # Build a fresh list instead of using ``+=`` on the base result, so a
        # base implementation that returns a shared list is not mutated and
        # buttons cannot accumulate across calls.
        return [
            *super().get_header_buttons(node),
            {
                "name": "run",
                "icon": "rocket_launch",
                "tooltip": "Generate",
                "action": "execute_node",
                "disableWhileProcessing": True,
            },
            {
                "name": "reset",
                "icon": "delete_forever",
                "tooltip": "Reset",
                "action": as_action_name(self._on_reset_node_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
            {
                "name": "download",
                "icon": "download",
                "tooltip": "Download",
                "action": as_action_name(self._on_download_data_action),
                "requiresContent": self._requires_content_fields(),
                "disableWhileProcessing": True,
            },
        ]

    def _requires_content_fields(self) -> list[str]:
        """Return the value used for the ``requiresContent`` entry of the reset and download buttons.

        The default is an empty list; subclasses may override it.
        NOTE(review): the entries are presumably names of node data fields that
        must hold content before the buttons are enabled — confirm with the
        frontend that consumes ``requiresContent``.
        """
        return []