Upload folder via script
Browse files- .github/workflows/publish-hf.yml +30 -0
- Dockerfile +3 -1
- Makefile +2 -0
- velai/async_utils.py +66 -1
- velai/nodes/base_execution_node.py +6 -6
- velai/nodes/base_node_renderable.py +4 -5
- velai/nodes/image_data.py +3 -3
- velai/nodes/image_to_mesh.py +9 -6
- velai/nodes/text_data.py +48 -3
- velai/nodes/text_to_image.py +4 -0
- velai/services/image/ImageGenerator.py +3 -0
- velai/services/mesh/DummyMeshGenerator.py +6 -4
- velai/services/mesh/MeshGenerator.py +6 -3
- velai/services/text/FalAITextGenerator.py +8 -0
- velai/services/text/TextGenerator.py +4 -1
- velai/session.py +6 -1
- velai/ui/vueflow_canvas.vue +1 -1
.github/workflows/publish-hf.yml
CHANGED
|
@@ -1,4 +1,34 @@
|
|
| 1 |
name: Publish to HuggingFace
|
|
|
|
| 2 |
on:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
jobs:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
name: Publish to HuggingFace
|
| 2 |
+
|
| 3 |
on:
|
| 4 |
+
workflow_dispatch:
|
| 5 |
+
push:
|
| 6 |
+
tags:
|
| 7 |
+
- "*"
|
| 8 |
|
| 9 |
jobs:
|
| 10 |
+
publish:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
|
| 13 |
+
steps:
|
| 14 |
+
- name: Checkout
|
| 15 |
+
uses: actions/checkout@v4
|
| 16 |
+
|
| 17 |
+
- name: Set up Python
|
| 18 |
+
uses: actions/setup-python@v5
|
| 19 |
+
with:
|
| 20 |
+
python-version: "3.12"
|
| 21 |
+
|
| 22 |
+
- name: Install uv
|
| 23 |
+
uses: astral-sh/setup-uv@v5
|
| 24 |
+
|
| 25 |
+
- name: Sync (release deps)
|
| 26 |
+
run: uv sync --frozen --group release
|
| 27 |
+
|
| 28 |
+
- name: Publish
|
| 29 |
+
env:
|
| 30 |
+
HF_REPO_ID: ${{ vars.HF_REPO_ID }}
|
| 31 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 32 |
+
run: |
|
| 33 |
+
set -euo pipefail
|
| 34 |
+
make publish HF_REPO_ID="${HF_REPO_ID}"
|
Dockerfile
CHANGED
|
@@ -26,7 +26,9 @@ RUN uv sync --frozen --no-dev
|
|
| 26 |
# Now copy the rest of the project
|
| 27 |
COPY . .
|
| 28 |
|
| 29 |
-
ENV
|
|
|
|
|
|
|
| 30 |
EXPOSE 7860
|
| 31 |
|
| 32 |
# Default command to start your app
|
|
|
|
| 26 |
# Now copy the rest of the project
|
| 27 |
COPY . .
|
| 28 |
|
| 29 |
+
ENV VELAI_HOST="0.0.0.0"
|
| 30 |
+
ENV VELAI_PORT=7860
|
| 31 |
+
|
| 32 |
EXPOSE 7860
|
| 33 |
|
| 34 |
# Default command to start your app
|
Makefile
CHANGED
|
@@ -1,5 +1,7 @@
|
|
| 1 |
.PHONY: fmt fmt-check lint autoformat build publish
|
| 2 |
|
|
|
|
|
|
|
| 3 |
# Code Style
|
| 4 |
fmt:
|
| 5 |
uv run --group lint ruff format .
|
|
|
|
| 1 |
.PHONY: fmt fmt-check lint autoformat build publish
|
| 2 |
|
| 3 |
+
default: build
|
| 4 |
+
|
| 5 |
# Code Style
|
| 6 |
fmt:
|
| 7 |
uv run --group lint ruff format .
|
velai/async_utils.py
CHANGED
|
@@ -1,8 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from dataclasses import dataclass, field
|
| 2 |
-
from typing import Awaitable, Callable,
|
| 3 |
|
| 4 |
from nicegui import ui
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
@dataclass
|
| 8 |
class AsyncDirtyTimer:
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import functools
|
| 3 |
+
import inspect
|
| 4 |
+
from concurrent.futures import Executor
|
| 5 |
from dataclasses import dataclass, field
|
| 6 |
+
from typing import Any, Awaitable, Callable, Concatenate, Generic, Optional, ParamSpec, TypeVar
|
| 7 |
|
| 8 |
from nicegui import ui
|
| 9 |
|
| 10 |
+
P = ParamSpec("P")
|
| 11 |
+
R = TypeVar("R")
|
| 12 |
+
T = TypeVar("T")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def async_wrap(fn: Callable[P, R], *, executor: Optional[Executor] = None) -> Callable[P, Awaitable[R]]:
|
| 16 |
+
@functools.wraps(fn)
|
| 17 |
+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
| 18 |
+
if executor is None:
|
| 19 |
+
return await asyncio.to_thread(fn, *args, **kwargs)
|
| 20 |
+
loop = asyncio.get_running_loop()
|
| 21 |
+
call = functools.partial(fn, *args, **kwargs)
|
| 22 |
+
return await loop.run_in_executor(executor, call)
|
| 23 |
+
|
| 24 |
+
wrapper.__isabstractmethod__ = False
|
| 25 |
+
return wrapper
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class AsyncMethod(Generic[T, P, R]):
|
| 29 |
+
__isabstractmethod__ = False
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self, sync_method: Callable[Concatenate[T, P], R], *, executor: Optional[Executor] = None, cache: bool = True
|
| 33 |
+
):
|
| 34 |
+
self._sync_name = sync_method.__name__
|
| 35 |
+
self._executor = executor
|
| 36 |
+
self._cache = cache
|
| 37 |
+
self._attr_name: Optional[str] = None
|
| 38 |
+
|
| 39 |
+
def __set_name__(self, owner, name: str) -> None:
|
| 40 |
+
self._attr_name = name
|
| 41 |
+
|
| 42 |
+
def __get__(self, obj: Optional[T], objtype=None) -> Callable[P, Awaitable[R]]:
|
| 43 |
+
if obj is None:
|
| 44 |
+
return self
|
| 45 |
+
|
| 46 |
+
if self._cache and self._attr_name is not None:
|
| 47 |
+
cached = obj.__dict__.get(self._attr_name)
|
| 48 |
+
if cached is not None:
|
| 49 |
+
return cached
|
| 50 |
+
|
| 51 |
+
sync = getattr(obj, self._sync_name)
|
| 52 |
+
|
| 53 |
+
async_fn = async_wrap(sync, executor=self._executor)
|
| 54 |
+
async_fn.__signature__ = inspect.signature(sync)
|
| 55 |
+
async_fn.__isabstractmethod__ = False
|
| 56 |
+
|
| 57 |
+
if self._cache and self._attr_name is not None:
|
| 58 |
+
obj.__dict__[self._attr_name] = async_fn
|
| 59 |
+
|
| 60 |
+
return async_fn
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def async_method_wrapper(
|
| 64 |
+
sync_method: Callable[Concatenate[T, P], R],
|
| 65 |
+
*,
|
| 66 |
+
executor: Optional[Executor] = None,
|
| 67 |
+
cache: bool = True,
|
| 68 |
+
) -> AsyncMethod[T, P, R]:
|
| 69 |
+
return AsyncMethod(sync_method, executor=executor, cache=cache)
|
| 70 |
+
|
| 71 |
|
| 72 |
@dataclass
|
| 73 |
class AsyncDirtyTimer:
|
velai/nodes/base_execution_node.py
CHANGED
|
@@ -40,7 +40,7 @@ class ExecutionNode(BaseNode[T_EXECUTION_DATA], Generic[T_EXECUTION_DATA], ABC):
|
|
| 40 |
async def on_queue_for_processing(self):
|
| 41 |
self.data.error_message = ""
|
| 42 |
self.data.progress_value = None
|
| 43 |
-
self.data.progress_message =
|
| 44 |
|
| 45 |
async def process(self) -> None:
|
| 46 |
# early exit if already processed
|
|
@@ -86,7 +86,7 @@ class ExecutionNodeRenderable(BaseNodeRenderable[T_EXECUTION_NODE], Generic[T_EX
|
|
| 86 |
self.action_registry.register("reset_node", self._on_reset_node_action)
|
| 87 |
self.action_registry.register("download_data", self._on_download_data_action)
|
| 88 |
|
| 89 |
-
async def _on_reset_node_action(self, args: NodeActionArguments) -> NodeActionResult:
|
| 90 |
# Confirm with the user before resetting
|
| 91 |
with ui.dialog() as dialog, ui.card():
|
| 92 |
ui.label("Are you sure to reset the node?")
|
|
@@ -95,7 +95,7 @@ class ExecutionNodeRenderable(BaseNodeRenderable[T_EXECUTION_NODE], Generic[T_EX
|
|
| 95 |
ui.button("No", on_click=lambda: dialog.submit("No"))
|
| 96 |
result = await dialog
|
| 97 |
if result == "No":
|
| 98 |
-
return
|
| 99 |
|
| 100 |
# Reset the outputs of a node
|
| 101 |
args.node.reset_outputs()
|
|
@@ -103,20 +103,20 @@ class ExecutionNodeRenderable(BaseNodeRenderable[T_EXECUTION_NODE], Generic[T_EX
|
|
| 103 |
ui.notify("Node has been reset!", type="positive")
|
| 104 |
return NodeActionResult.update()
|
| 105 |
|
| 106 |
-
async def _on_download_data_action(self, args: NodeActionArguments) -> NodeActionResult:
|
| 107 |
if isinstance(args.node, ExecutionNode):
|
| 108 |
data = await args.node.get_download_data()
|
| 109 |
|
| 110 |
if data is None:
|
| 111 |
ui.notify("Nothing to download!", type="warning")
|
| 112 |
-
return
|
| 113 |
|
| 114 |
ui.notify("Download started...", type="positive")
|
| 115 |
if isinstance(data, str):
|
| 116 |
ui.navigate.to(data, new_tab=True)
|
| 117 |
else:
|
| 118 |
ui.download(data.src, data.filename, data.mime_type)
|
| 119 |
-
return
|
| 120 |
|
| 121 |
def get_header_buttons(self, node: T_EXECUTION_NODE) -> list[dict[str, Any]]:
|
| 122 |
buttons = super().get_header_buttons(node)
|
|
|
|
| 40 |
async def on_queue_for_processing(self):
|
| 41 |
self.data.error_message = ""
|
| 42 |
self.data.progress_value = None
|
| 43 |
+
self.data.progress_message = None
|
| 44 |
|
| 45 |
async def process(self) -> None:
|
| 46 |
# early exit if already processed
|
|
|
|
| 86 |
self.action_registry.register("reset_node", self._on_reset_node_action)
|
| 87 |
self.action_registry.register("download_data", self._on_download_data_action)
|
| 88 |
|
| 89 |
+
async def _on_reset_node_action(self, args: NodeActionArguments) -> NodeActionResult | None:
|
| 90 |
# Confirm with the user before resetting
|
| 91 |
with ui.dialog() as dialog, ui.card():
|
| 92 |
ui.label("Are you sure to reset the node?")
|
|
|
|
| 95 |
ui.button("No", on_click=lambda: dialog.submit("No"))
|
| 96 |
result = await dialog
|
| 97 |
if result == "No":
|
| 98 |
+
return
|
| 99 |
|
| 100 |
# Reset the outputs of a node
|
| 101 |
args.node.reset_outputs()
|
|
|
|
| 103 |
ui.notify("Node has been reset!", type="positive")
|
| 104 |
return NodeActionResult.update()
|
| 105 |
|
| 106 |
+
async def _on_download_data_action(self, args: NodeActionArguments) -> NodeActionResult | None:
|
| 107 |
if isinstance(args.node, ExecutionNode):
|
| 108 |
data = await args.node.get_download_data()
|
| 109 |
|
| 110 |
if data is None:
|
| 111 |
ui.notify("Nothing to download!", type="warning")
|
| 112 |
+
return
|
| 113 |
|
| 114 |
ui.notify("Download started...", type="positive")
|
| 115 |
if isinstance(data, str):
|
| 116 |
ui.navigate.to(data, new_tab=True)
|
| 117 |
else:
|
| 118 |
ui.download(data.src, data.filename, data.mime_type)
|
| 119 |
+
return
|
| 120 |
|
| 121 |
def get_header_buttons(self, node: T_EXECUTION_NODE) -> list[dict[str, Any]]:
|
| 122 |
buttons = super().get_header_buttons(node)
|
velai/nodes/base_node_renderable.py
CHANGED
|
@@ -1,10 +1,10 @@
|
|
| 1 |
from abc import ABC
|
| 2 |
-
from typing import
|
| 3 |
|
| 4 |
-
from velai.nodes.actions.node_action_registry import
|
| 5 |
from velai.nodes.actions.node_actions_models import NodeActionArguments, NodeActionResult
|
| 6 |
from velai.nodes.base_node import BaseNode, logger
|
| 7 |
-
from velai.nodes.vue_nodes import
|
| 8 |
|
| 9 |
T_BASE_NODE = TypeVar("T_BASE_NODE", bound=BaseNode)
|
| 10 |
|
|
@@ -14,12 +14,11 @@ class BaseNodeRenderable(VueNodeRenderable[T_BASE_NODE], Generic[T_BASE_NODE], A
|
|
| 14 |
super().__init__()
|
| 15 |
self.action_registry = NodeActionRegistry()
|
| 16 |
|
| 17 |
-
async def run_custom_action(self, args: NodeActionArguments) -> NodeActionResult:
|
| 18 |
try:
|
| 19 |
return await self.action_registry.run_action(args)
|
| 20 |
except NodeActionNotFoundError:
|
| 21 |
logger.exception()
|
| 22 |
-
return NodeActionResult()
|
| 23 |
|
| 24 |
# todo: create actual renderable fields that correspond to the vue components
|
| 25 |
def get_header_buttons(self, node: T_BASE_NODE) -> list[dict[str, Any]]:
|
|
|
|
| 1 |
from abc import ABC
|
| 2 |
+
from typing import Any, Generic, TypeVar
|
| 3 |
|
| 4 |
+
from velai.nodes.actions.node_action_registry import NodeActionNotFoundError, NodeActionRegistry
|
| 5 |
from velai.nodes.actions.node_actions_models import NodeActionArguments, NodeActionResult
|
| 6 |
from velai.nodes.base_node import BaseNode, logger
|
| 7 |
+
from velai.nodes.vue_nodes import VueNodeData, VueNodeRenderable
|
| 8 |
|
| 9 |
T_BASE_NODE = TypeVar("T_BASE_NODE", bound=BaseNode)
|
| 10 |
|
|
|
|
| 14 |
super().__init__()
|
| 15 |
self.action_registry = NodeActionRegistry()
|
| 16 |
|
| 17 |
+
async def run_custom_action(self, args: NodeActionArguments) -> NodeActionResult | None:
|
| 18 |
try:
|
| 19 |
return await self.action_registry.run_action(args)
|
| 20 |
except NodeActionNotFoundError:
|
| 21 |
logger.exception()
|
|
|
|
| 22 |
|
| 23 |
# todo: create actual renderable fields that correspond to the vue components
|
| 24 |
def get_header_buttons(self, node: T_BASE_NODE) -> list[dict[str, Any]]:
|
velai/nodes/image_data.py
CHANGED
|
@@ -2,18 +2,18 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
|
|
|
| 5 |
from velai.dataflow.enums import NodeKind, PortDirection
|
| 6 |
from velai.dataflow.nodes import NodeType
|
| 7 |
from velai.dataflow.ports import PortSchema
|
| 8 |
-
from velai.data_types import ImageType
|
| 9 |
from velai.nodes.base_node import BaseNode, BaseNodeData
|
| 10 |
-
from velai.nodes.base_node_renderable import
|
| 11 |
|
| 12 |
ImageDataNodeType = NodeType(
|
| 13 |
kind=NodeKind.IMAGE_DATA,
|
| 14 |
display_name="Image",
|
| 15 |
inputs=[],
|
| 16 |
-
outputs=[PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT)],
|
| 17 |
)
|
| 18 |
|
| 19 |
|
|
|
|
| 2 |
|
| 3 |
from typing import Any
|
| 4 |
|
| 5 |
+
from velai.data_types import ImageType
|
| 6 |
from velai.dataflow.enums import NodeKind, PortDirection
|
| 7 |
from velai.dataflow.nodes import NodeType
|
| 8 |
from velai.dataflow.ports import PortSchema
|
|
|
|
| 9 |
from velai.nodes.base_node import BaseNode, BaseNodeData
|
| 10 |
+
from velai.nodes.base_node_renderable import T_BASE_NODE, BaseNodeRenderable
|
| 11 |
|
| 12 |
ImageDataNodeType = NodeType(
|
| 13 |
kind=NodeKind.IMAGE_DATA,
|
| 14 |
display_name="Image",
|
| 15 |
inputs=[],
|
| 16 |
+
outputs=[PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT, tooltip="Image")],
|
| 17 |
)
|
| 18 |
|
| 19 |
|
velai/nodes/image_to_mesh.py
CHANGED
|
@@ -6,21 +6,22 @@ from functools import partial
|
|
| 6 |
|
| 7 |
from PIL.Image import Image
|
| 8 |
|
| 9 |
-
from velai.
|
|
|
|
| 10 |
from velai.dataflow.nodes import NodeType
|
| 11 |
from velai.dataflow.ports import PortSchema
|
| 12 |
-
from velai.services.mesh import
|
| 13 |
from velai.services.registry import get_registry
|
| 14 |
from velai.services.services import TaskType
|
| 15 |
-
|
|
|
|
| 16 |
from .base_execution_node import (
|
| 17 |
-
ExecutionNode,
|
| 18 |
-
ExecutionNodeRenderable,
|
| 19 |
T_EXECUTION_NODE,
|
| 20 |
DownloadData,
|
|
|
|
| 21 |
ExecutionNodeData,
|
|
|
|
| 22 |
)
|
| 23 |
-
from ..storage import storage_endpoint
|
| 24 |
|
| 25 |
DEFAULT_MESH_MODEL_SERVICE_ID = os.getenv("DEFAULT_MESH_MODEL_SERVICE_ID", "fal_ai_mesh_sam3d")
|
| 26 |
|
|
@@ -34,6 +35,7 @@ ImageToMeshNodeType = NodeType(
|
|
| 34 |
direction=PortDirection.INPUT,
|
| 35 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 36 |
capacity=1,
|
|
|
|
| 37 |
),
|
| 38 |
PortSchema(
|
| 39 |
name="image",
|
|
@@ -41,6 +43,7 @@ ImageToMeshNodeType = NodeType(
|
|
| 41 |
direction=PortDirection.INPUT,
|
| 42 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 43 |
capacity=1,
|
|
|
|
| 44 |
),
|
| 45 |
],
|
| 46 |
outputs=[
|
|
|
|
| 6 |
|
| 7 |
from PIL.Image import Image
|
| 8 |
|
| 9 |
+
from velai.data_types import ImageType, MeshType, TextType
|
| 10 |
+
from velai.dataflow.enums import ConnectMultiplicity, NodeKind, PortDirection
|
| 11 |
from velai.dataflow.nodes import NodeType
|
| 12 |
from velai.dataflow.ports import PortSchema
|
| 13 |
+
from velai.services.mesh import MeshGenerationResult, MeshGenerator
|
| 14 |
from velai.services.registry import get_registry
|
| 15 |
from velai.services.services import TaskType
|
| 16 |
+
|
| 17 |
+
from ..storage import storage_endpoint
|
| 18 |
from .base_execution_node import (
|
|
|
|
|
|
|
| 19 |
T_EXECUTION_NODE,
|
| 20 |
DownloadData,
|
| 21 |
+
ExecutionNode,
|
| 22 |
ExecutionNodeData,
|
| 23 |
+
ExecutionNodeRenderable,
|
| 24 |
)
|
|
|
|
| 25 |
|
| 26 |
DEFAULT_MESH_MODEL_SERVICE_ID = os.getenv("DEFAULT_MESH_MODEL_SERVICE_ID", "fal_ai_mesh_sam3d")
|
| 27 |
|
|
|
|
| 35 |
direction=PortDirection.INPUT,
|
| 36 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 37 |
capacity=1,
|
| 38 |
+
tooltip="Prompt / Mask",
|
| 39 |
),
|
| 40 |
PortSchema(
|
| 41 |
name="image",
|
|
|
|
| 43 |
direction=PortDirection.INPUT,
|
| 44 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 45 |
capacity=1,
|
| 46 |
+
tooltip="Image",
|
| 47 |
),
|
| 48 |
],
|
| 49 |
outputs=[
|
velai/nodes/text_data.py
CHANGED
|
@@ -1,22 +1,30 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
|
|
|
|
|
|
| 4 |
from typing import Any
|
| 5 |
|
|
|
|
| 6 |
from velai.dataflow.enums import NodeKind, PortDirection
|
| 7 |
from velai.dataflow.nodes import NodeType
|
| 8 |
from velai.dataflow.ports import PortSchema
|
| 9 |
-
from velai.
|
| 10 |
from velai.nodes.base_node import BaseNode, BaseNodeData
|
| 11 |
-
from velai.nodes.base_node_renderable import
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
logger = logging.getLogger(__name__)
|
| 14 |
|
|
|
|
|
|
|
| 15 |
TextDataNodeType = NodeType(
|
| 16 |
kind=NodeKind.TEXT_DATA,
|
| 17 |
display_name="Text",
|
| 18 |
inputs=[],
|
| 19 |
-
outputs=[PortSchema(name="text", dtype=TextType, direction=PortDirection.OUTPUT)],
|
| 20 |
)
|
| 21 |
|
| 22 |
|
|
@@ -30,6 +38,10 @@ class TextDataNode(BaseNode[BaseNodeData]):
|
|
| 30 |
|
| 31 |
|
| 32 |
class TextDataNodeRenderable(BaseNodeRenderable[TextDataNode]):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
def get_fields(self, node: T_BASE_NODE) -> list[dict[str, Any]]:
|
| 34 |
return [
|
| 35 |
*super().get_fields(node),
|
|
@@ -40,3 +52,36 @@ class TextDataNodeRenderable(BaseNodeRenderable[TextDataNode]):
|
|
| 40 |
"placeholder": "Enter text...",
|
| 41 |
},
|
| 42 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import logging
|
| 4 |
+
import os
|
| 5 |
+
import typing
|
| 6 |
from typing import Any
|
| 7 |
|
| 8 |
+
from velai.data_types import TextType
|
| 9 |
from velai.dataflow.enums import NodeKind, PortDirection
|
| 10 |
from velai.dataflow.nodes import NodeType
|
| 11 |
from velai.dataflow.ports import PortSchema
|
| 12 |
+
from velai.nodes.actions.node_actions_models import NodeActionArguments, NodeActionResult
|
| 13 |
from velai.nodes.base_node import BaseNode, BaseNodeData
|
| 14 |
+
from velai.nodes.base_node_renderable import T_BASE_NODE, BaseNodeRenderable
|
| 15 |
+
from velai.services.registry import get_registry
|
| 16 |
+
from velai.services.services import TaskType
|
| 17 |
+
from velai.services.text.TextGenerator import TextGenerator
|
| 18 |
|
| 19 |
logger = logging.getLogger(__name__)
|
| 20 |
|
| 21 |
+
DEFAULT_TEXT_MODEL_SERVICE_ID = os.getenv("DEFAULT_TEXT_MODEL_SERVICE_ID", "fal_ai_gemini_flash_text")
|
| 22 |
+
|
| 23 |
TextDataNodeType = NodeType(
|
| 24 |
kind=NodeKind.TEXT_DATA,
|
| 25 |
display_name="Text",
|
| 26 |
inputs=[],
|
| 27 |
+
outputs=[PortSchema(name="text", dtype=TextType, direction=PortDirection.OUTPUT, tooltip="Text")],
|
| 28 |
)
|
| 29 |
|
| 30 |
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
class TextDataNodeRenderable(BaseNodeRenderable[TextDataNode]):
|
| 41 |
+
def __init__(self):
|
| 42 |
+
super().__init__()
|
| 43 |
+
self.action_registry.register("enhance_text", self._on_enhance_text_action)
|
| 44 |
+
|
| 45 |
def get_fields(self, node: T_BASE_NODE) -> list[dict[str, Any]]:
|
| 46 |
return [
|
| 47 |
*super().get_fields(node),
|
|
|
|
| 52 |
"placeholder": "Enter text...",
|
| 53 |
},
|
| 54 |
]
|
| 55 |
+
|
| 56 |
+
def get_header_buttons(self, node: T_BASE_NODE) -> list[dict[str, Any]]:
|
| 57 |
+
buttons = super().get_header_buttons(node)
|
| 58 |
+
|
| 59 |
+
buttons.append(
|
| 60 |
+
{
|
| 61 |
+
"name": "enhance",
|
| 62 |
+
"icon": "spellcheck",
|
| 63 |
+
"tooltip": "Enhance",
|
| 64 |
+
"action": "enhance_text",
|
| 65 |
+
"requireContent": False,
|
| 66 |
+
"disableWhileProcessing": False,
|
| 67 |
+
}
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
return buttons
|
| 71 |
+
|
| 72 |
+
async def _on_enhance_text_action(self, args: NodeActionArguments) -> NodeActionResult:
|
| 73 |
+
text_output = args.node.outputs["text"]
|
| 74 |
+
|
| 75 |
+
# run model
|
| 76 |
+
registry = get_registry()
|
| 77 |
+
text_service = typing.cast(TextGenerator, registry.create(TaskType.TEXT, DEFAULT_TEXT_MODEL_SERVICE_ID))
|
| 78 |
+
|
| 79 |
+
# todo: improve prompt
|
| 80 |
+
result = await text_service.generate_async(
|
| 81 |
+
f"Please enhance this image prompt for genai image generation "
|
| 82 |
+
f"and only return the resulting prompt (no enclosing formatting): "
|
| 83 |
+
f"```{text_output.value}```"
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
text_output.value = result.text
|
| 87 |
+
return NodeActionResult.update()
|
velai/nodes/text_to_image.py
CHANGED
|
@@ -43,6 +43,7 @@ TextToImageNodeType = NodeType(
|
|
| 43 |
direction=PortDirection.INPUT,
|
| 44 |
multiplicity=ConnectMultiplicity.MULTIPLE,
|
| 45 |
capacity=3,
|
|
|
|
| 46 |
),
|
| 47 |
PortSchema(
|
| 48 |
name="image1",
|
|
@@ -50,6 +51,7 @@ TextToImageNodeType = NodeType(
|
|
| 50 |
direction=PortDirection.INPUT,
|
| 51 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 52 |
capacity=1,
|
|
|
|
| 53 |
),
|
| 54 |
PortSchema(
|
| 55 |
name="image2",
|
|
@@ -57,6 +59,7 @@ TextToImageNodeType = NodeType(
|
|
| 57 |
direction=PortDirection.INPUT,
|
| 58 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 59 |
capacity=1,
|
|
|
|
| 60 |
),
|
| 61 |
PortSchema(
|
| 62 |
name="image3",
|
|
@@ -64,6 +67,7 @@ TextToImageNodeType = NodeType(
|
|
| 64 |
direction=PortDirection.INPUT,
|
| 65 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 66 |
capacity=1,
|
|
|
|
| 67 |
),
|
| 68 |
],
|
| 69 |
outputs=[PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT)],
|
|
|
|
| 43 |
direction=PortDirection.INPUT,
|
| 44 |
multiplicity=ConnectMultiplicity.MULTIPLE,
|
| 45 |
capacity=3,
|
| 46 |
+
tooltip="Prompt",
|
| 47 |
),
|
| 48 |
PortSchema(
|
| 49 |
name="image1",
|
|
|
|
| 51 |
direction=PortDirection.INPUT,
|
| 52 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 53 |
capacity=1,
|
| 54 |
+
tooltip="Image 1",
|
| 55 |
),
|
| 56 |
PortSchema(
|
| 57 |
name="image2",
|
|
|
|
| 59 |
direction=PortDirection.INPUT,
|
| 60 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 61 |
capacity=1,
|
| 62 |
+
tooltip="Image 2",
|
| 63 |
),
|
| 64 |
PortSchema(
|
| 65 |
name="image3",
|
|
|
|
| 67 |
direction=PortDirection.INPUT,
|
| 68 |
multiplicity=ConnectMultiplicity.SINGLE,
|
| 69 |
capacity=1,
|
| 70 |
+
tooltip="Image 3",
|
| 71 |
),
|
| 72 |
],
|
| 73 |
outputs=[PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT)],
|
velai/services/image/ImageGenerator.py
CHANGED
|
@@ -3,6 +3,7 @@ from typing import Any, Sequence
|
|
| 3 |
|
| 4 |
from PIL import Image
|
| 5 |
|
|
|
|
| 6 |
from velai.services.image.ImageGenerationResult import ImageGenerationResult
|
| 7 |
from velai.services.progress import ProgressCallback
|
| 8 |
from velai.services.services import GenerationService, TaskType
|
|
@@ -26,3 +27,5 @@ class ImageGenerator(GenerationService, ABC):
|
|
| 26 |
"""Generate images from a prompt and optional images."""
|
| 27 |
|
| 28 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
+
from velai.async_utils import async_method_wrapper
|
| 7 |
from velai.services.image.ImageGenerationResult import ImageGenerationResult
|
| 8 |
from velai.services.progress import ProgressCallback
|
| 9 |
from velai.services.services import GenerationService, TaskType
|
|
|
|
| 27 |
"""Generate images from a prompt and optional images."""
|
| 28 |
|
| 29 |
raise NotImplementedError
|
| 30 |
+
|
| 31 |
+
generate_async = async_method_wrapper(generate)
|
velai/services/mesh/DummyMeshGenerator.py
CHANGED
|
@@ -2,14 +2,14 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
-
from typing import
|
| 6 |
|
| 7 |
from PIL import Image
|
| 8 |
|
|
|
|
|
|
|
| 9 |
from .MeshGenerationResult import MeshGenerationResult
|
| 10 |
from .MeshGenerator import MeshGenerator
|
| 11 |
-
from ..progress import ProgressCallback
|
| 12 |
-
from ..registry import register_service
|
| 13 |
|
| 14 |
KHRONOS_GROUP_URL = "https://raw.githubusercontent.com/KhronosGroup/glTF-Sample-Models/master"
|
| 15 |
|
|
@@ -53,7 +53,9 @@ class DummyMeshGenerator(MeshGenerator):
|
|
| 53 |
jitter = random.randint(-64, 64)
|
| 54 |
mesh_url += f"?{jitter}"
|
| 55 |
|
| 56 |
-
|
|
|
|
|
|
|
| 57 |
|
| 58 |
return MeshGenerationResult(
|
| 59 |
provider="dummy",
|
|
|
|
| 2 |
|
| 3 |
import random
|
| 4 |
import time
|
| 5 |
+
from typing import Any, Sequence
|
| 6 |
|
| 7 |
from PIL import Image
|
| 8 |
|
| 9 |
+
from ..progress import ProgressCallback, call_progress
|
| 10 |
+
from ..registry import register_service
|
| 11 |
from .MeshGenerationResult import MeshGenerationResult
|
| 12 |
from .MeshGenerator import MeshGenerator
|
|
|
|
|
|
|
| 13 |
|
| 14 |
KHRONOS_GROUP_URL = "https://raw.githubusercontent.com/KhronosGroup/glTF-Sample-Models/master"
|
| 15 |
|
|
|
|
| 53 |
jitter = random.randint(-64, 64)
|
| 54 |
mesh_url += f"?{jitter}"
|
| 55 |
|
| 56 |
+
for i in range(1, 9):
|
| 57 |
+
call_progress(progress, i * 0.1, "Generating mesh")
|
| 58 |
+
time.sleep(0.1)
|
| 59 |
|
| 60 |
return MeshGenerationResult(
|
| 61 |
provider="dummy",
|
velai/services/mesh/MeshGenerator.py
CHANGED
|
@@ -1,13 +1,14 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
-
from typing import
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
-
from .
|
| 9 |
from ..progress import ProgressCallback
|
| 10 |
-
from ..services import
|
|
|
|
| 11 |
|
| 12 |
|
| 13 |
class MeshGenerator(GenerationService, ABC):
|
|
@@ -24,3 +25,5 @@ class MeshGenerator(GenerationService, ABC):
|
|
| 24 |
**kwargs: Any,
|
| 25 |
) -> MeshGenerationResult:
|
| 26 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Sequence
|
| 5 |
|
| 6 |
from PIL import Image
|
| 7 |
|
| 8 |
+
from ...async_utils import async_method_wrapper
|
| 9 |
from ..progress import ProgressCallback
|
| 10 |
+
from ..services import GenerationService, TaskType
|
| 11 |
+
from .MeshGenerationResult import MeshGenerationResult
|
| 12 |
|
| 13 |
|
| 14 |
class MeshGenerator(GenerationService, ABC):
|
|
|
|
| 25 |
**kwargs: Any,
|
| 26 |
) -> MeshGenerationResult:
|
| 27 |
raise NotImplementedError
|
| 28 |
+
|
| 29 |
+
generate_async = async_method_wrapper(generate)
|
velai/services/text/FalAITextGenerator.py
CHANGED
|
@@ -122,3 +122,11 @@ class FalAITextGenerator(TextGenerator):
|
|
| 122 |
text=text_output,
|
| 123 |
raw_response=response,
|
| 124 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
text=text_output,
|
| 123 |
raw_response=response,
|
| 124 |
)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
@register_service
|
| 128 |
+
class FalAIGeminiFlashTextGenerator(FalAITextGenerator):
|
| 129 |
+
service_id = "fal_ai_gemini_flash_text"
|
| 130 |
+
|
| 131 |
+
def __init__(self, model_name: str | None = None):
|
| 132 |
+
super().__init__(model_name=model_name, extra_arguments={"model": "google/gemini-2.5-flash"})
|
velai/services/text/TextGenerator.py
CHANGED
|
@@ -3,8 +3,9 @@ from typing import Sequence
|
|
| 3 |
|
| 4 |
from PIL import Image
|
| 5 |
|
|
|
|
| 6 |
from velai.services.progress import ProgressCallback
|
| 7 |
-
from velai.services.services import
|
| 8 |
from velai.services.text.TextGenerationResult import TextGenerationResult
|
| 9 |
|
| 10 |
|
|
@@ -23,3 +24,5 @@ class TextGenerator(GenerationService, ABC):
|
|
| 23 |
) -> TextGenerationResult:
|
| 24 |
"""Generate text from a prompt and optional images."""
|
| 25 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from PIL import Image
|
| 5 |
|
| 6 |
+
from velai.async_utils import async_method_wrapper
|
| 7 |
from velai.services.progress import ProgressCallback
|
| 8 |
+
from velai.services.services import GenerationService, TaskType
|
| 9 |
from velai.services.text.TextGenerationResult import TextGenerationResult
|
| 10 |
|
| 11 |
|
|
|
|
| 24 |
) -> TextGenerationResult:
|
| 25 |
"""Generate text from a prompt and optional images."""
|
| 26 |
raise NotImplementedError
|
| 27 |
+
|
| 28 |
+
generate_async = async_method_wrapper(generate)
|
velai/session.py
CHANGED
|
@@ -23,7 +23,7 @@ from velai.ui.vueflow_canvas import VueFlowCanvas
|
|
| 23 |
from . import app_context
|
| 24 |
from .async_utils import AsyncDirtyTimer
|
| 25 |
from .controller import GraphController
|
| 26 |
-
from .nodes.actions.node_actions_models import NodeActionArguments
|
| 27 |
from .nodes.base_node import BaseNode
|
| 28 |
from .nodes.base_node_renderable import BaseNodeRenderable
|
| 29 |
from .nodes.image_to_mesh import ImageToMeshNode, ImageToMeshNodeRenderable, ImageToMeshNodeType
|
|
@@ -480,14 +480,19 @@ class GraphSession:
|
|
| 480 |
if not isinstance(node_renderable, BaseNodeRenderable):
|
| 481 |
return
|
| 482 |
|
|
|
|
| 483 |
result = await node_renderable.run_custom_action(NodeActionArguments(action, node))
|
| 484 |
|
|
|
|
|
|
|
|
|
|
| 485 |
if result.requires_ui_sync:
|
| 486 |
await self.runtime.sync_node_to_ui(node)
|
| 487 |
|
| 488 |
if result.requires_data_save and self.autosaver is not None:
|
| 489 |
self.autosaver.mark_dirty()
|
| 490 |
|
|
|
|
| 491 |
return
|
| 492 |
|
| 493 |
elif event_type == "execute_node":
|
|
|
|
| 23 |
from . import app_context
|
| 24 |
from .async_utils import AsyncDirtyTimer
|
| 25 |
from .controller import GraphController
|
| 26 |
+
from .nodes.actions.node_actions_models import NodeActionArguments, NodeActionResult
|
| 27 |
from .nodes.base_node import BaseNode
|
| 28 |
from .nodes.base_node_renderable import BaseNodeRenderable
|
| 29 |
from .nodes.image_to_mesh import ImageToMeshNode, ImageToMeshNodeRenderable, ImageToMeshNodeType
|
|
|
|
| 480 |
if not isinstance(node_renderable, BaseNodeRenderable):
|
| 481 |
return
|
| 482 |
|
| 483 |
+
await canvas.set_node_processing(node_id, True)
|
| 484 |
result = await node_renderable.run_custom_action(NodeActionArguments(action, node))
|
| 485 |
|
| 486 |
+
if result is None:
|
| 487 |
+
result = NodeActionResult.discard()
|
| 488 |
+
|
| 489 |
if result.requires_ui_sync:
|
| 490 |
await self.runtime.sync_node_to_ui(node)
|
| 491 |
|
| 492 |
if result.requires_data_save and self.autosaver is not None:
|
| 493 |
self.autosaver.mark_dirty()
|
| 494 |
|
| 495 |
+
await canvas.set_node_processing(node_id, False)
|
| 496 |
return
|
| 497 |
|
| 498 |
elif event_type == "execute_node":
|
velai/ui/vueflow_canvas.vue
CHANGED
|
@@ -1025,7 +1025,7 @@ export default {
|
|
| 1025 |
:indeterminate="typeof data.values.progress_value !== 'number'"
|
| 1026 |
size="28px"
|
| 1027 |
rounded
|
| 1028 |
-
show-value
|
| 1029 |
/>
|
| 1030 |
<div class="vf-inner-loading-text">
|
| 1031 |
<span v-if="data.values.progress_message">
|
|
|
|
| 1025 |
:indeterminate="typeof data.values.progress_value !== 'number'"
|
| 1026 |
size="28px"
|
| 1027 |
rounded
|
| 1028 |
+
:show-value="typeof data.values.progress_value === 'number'"
|
| 1029 |
/>
|
| 1030 |
<div class="vf-inner-loading-text">
|
| 1031 |
<span v-if="data.values.progress_message">
|