Spaces:
Sleeping
Sleeping
Commit ·
e248bdb
1
Parent(s): f44fecc
Preparación final para despliegue con Daggr visual
Browse files- README.md +10 -7
- app_daggr.py +102 -34
- daggr/CHANGELOG.md +65 -0
- daggr/__init__.py +53 -0
- daggr/_client_cache.py +200 -0
- daggr/_utils.py +19 -0
- daggr/assets/hf-logo-pirate.png +3 -0
- daggr/assets/logo_dark.png +3 -0
- daggr/assets/logo_dark_small.png +3 -0
- daggr/assets/logo_light.png +3 -0
- daggr/cli.py +689 -0
- daggr/edge.py +60 -0
- daggr/executor.py +846 -0
- daggr/graph.py +767 -0
- daggr/local_space.py +503 -0
- daggr/node.py +772 -0
- daggr/ops.py +57 -0
- daggr/package.json +6 -0
- daggr/port.py +158 -0
- daggr/py.typed +0 -0
- daggr/server.py +1946 -0
- daggr/session.py +114 -0
- daggr/state.py +457 -0
- pyproject.toml +3 -3
- requirements.txt +2 -4
- test_daggr_init.py +28 -0
- uv.lock +0 -0
README.md
CHANGED
|
@@ -1,14 +1,17 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
emoji: 🎨
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
|
| 9 |
-
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
---
|
| 13 |
|
| 14 |
# Generador de Retratos Acuarela con IA
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Acuarela Portrait Daggr
|
| 3 |
emoji: 🎨
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: purple
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 4.19.2
|
| 8 |
+
app_file: app_daggr.py
|
|
|
|
| 9 |
pinned: false
|
| 10 |
+
tags:
|
| 11 |
+
- daggr
|
| 12 |
+
- modal
|
| 13 |
+
- sdxl
|
| 14 |
+
- watercolor
|
| 15 |
---
|
| 16 |
|
| 17 |
# Generador de Retratos Acuarela con IA
|
app_daggr.py
CHANGED
|
@@ -16,18 +16,24 @@ $ python app_daggr.py
|
|
| 16 |
|
| 17 |
import os
|
| 18 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
import modal
|
| 20 |
import gradio as gr
|
| 21 |
from PIL import Image
|
| 22 |
-
import uuid
|
| 23 |
-
from dotenv import load_dotenv
|
| 24 |
-
|
| 25 |
-
# Import Daggr components
|
| 26 |
from daggr import FnNode, GradioNode, InferenceNode, Graph
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
# --- Modal Setup ---
|
| 31 |
try:
|
| 32 |
ImageCaptioner = modal.Cls.from_name("acuarela-portrait", "ImageCaptioner")
|
| 33 |
ImageGenerator = modal.Cls.from_name("acuarela-portrait", "ImageGenerator")
|
|
@@ -43,23 +49,71 @@ except Exception as e:
|
|
| 43 |
|
| 44 |
# --- Function Nodes ---
|
| 45 |
|
| 46 |
-
def
|
| 47 |
-
"""
|
| 48 |
if image is None:
|
| 49 |
return None
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
|
| 55 |
-
def generate_caption(
|
| 56 |
"""
|
| 57 |
Step 1: Generate image caption using Modal
|
| 58 |
"""
|
| 59 |
-
if not MODAL_AVAILABLE or
|
| 60 |
-
raise ValueError("Modal not available or image is None")
|
| 61 |
|
| 62 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
captioner = ImageCaptioner()
|
| 64 |
caption = captioner.caption.remote(img_bytes)
|
| 65 |
if not caption:
|
|
@@ -82,14 +136,23 @@ def create_artistic_prompt(caption: str) -> str:
|
|
| 82 |
return prompt
|
| 83 |
|
| 84 |
|
| 85 |
-
def generate_watercolor_image(
|
| 86 |
"""
|
| 87 |
-
Step 3: Generate watercolor image using Modal SDXL
|
| 88 |
"""
|
| 89 |
-
if not MODAL_AVAILABLE or
|
| 90 |
-
raise ValueError("Modal not available or image is None")
|
| 91 |
|
| 92 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
negative_prompt = (
|
| 94 |
"photorealistic, 3d render, photograph, complex background, "
|
| 95 |
"white background, dark background, messy sketch, blurry, "
|
|
@@ -99,8 +162,14 @@ def generate_watercolor_image(img_bytes: bytes, prompt: str) -> Image.Image:
|
|
| 99 |
|
| 100 |
generator = ImageGenerator()
|
| 101 |
result_bytes = generator.generate.remote(img_bytes, prompt, negative_prompt, strength=0.65)
|
| 102 |
-
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
except Exception as e:
|
| 105 |
raise ValueError(f"Image generation failed: {str(e)}")
|
| 106 |
|
|
@@ -113,15 +182,15 @@ def create_workflow():
|
|
| 113 |
"""
|
| 114 |
|
| 115 |
# Node 1: Image Input (implicit - handled by Gradio interface)
|
| 116 |
-
# Node 2: Convert image to bytes
|
| 117 |
converter = FnNode(
|
| 118 |
-
fn=
|
| 119 |
name="Image Converter",
|
| 120 |
inputs={
|
| 121 |
-
"image": gr.Image(label="Upload your photo"
|
| 122 |
},
|
| 123 |
outputs={
|
| 124 |
-
"
|
| 125 |
},
|
| 126 |
)
|
| 127 |
|
|
@@ -130,10 +199,10 @@ def create_workflow():
|
|
| 130 |
fn=generate_caption,
|
| 131 |
name="Image Analysis (BLIP Caption)",
|
| 132 |
inputs={
|
| 133 |
-
"
|
| 134 |
},
|
| 135 |
outputs={
|
| 136 |
-
"
|
| 137 |
},
|
| 138 |
)
|
| 139 |
|
|
@@ -142,10 +211,10 @@ def create_workflow():
|
|
| 142 |
fn=create_artistic_prompt,
|
| 143 |
name="Artistic Prompt Engineering",
|
| 144 |
inputs={
|
| 145 |
-
"caption": captioner.
|
| 146 |
},
|
| 147 |
outputs={
|
| 148 |
-
"
|
| 149 |
},
|
| 150 |
)
|
| 151 |
|
|
@@ -154,11 +223,11 @@ def create_workflow():
|
|
| 154 |
fn=generate_watercolor_image,
|
| 155 |
name="Watercolor Generation (SDXL)",
|
| 156 |
inputs={
|
| 157 |
-
"
|
| 158 |
-
"prompt": prompt_engineer.
|
| 159 |
},
|
| 160 |
outputs={
|
| 161 |
-
"
|
| 162 |
},
|
| 163 |
)
|
| 164 |
|
|
@@ -166,8 +235,7 @@ def create_workflow():
|
|
| 166 |
graph = Graph(
|
| 167 |
name="🎨 Acuarela Portrait Generator - Daggr Workflow",
|
| 168 |
nodes=[converter, captioner, prompt_engineer, image_generator],
|
| 169 |
-
|
| 170 |
-
"Each step is visualized and can be rerun independently!",
|
| 171 |
)
|
| 172 |
|
| 173 |
return graph
|
|
|
|
| 16 |
|
| 17 |
import os
|
| 18 |
import io
|
| 19 |
+
import base64
|
| 20 |
+
import tempfile
|
| 21 |
+
import uuid
|
| 22 |
+
import shutil
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
# Fix for Windows Long Paths in Gradio (must set before importing gradio)
|
| 26 |
+
if os.name == 'nt':
|
| 27 |
+
custom_temp = os.path.join(tempfile.gettempdir(), "gr")
|
| 28 |
+
os.makedirs(custom_temp, exist_ok=True)
|
| 29 |
+
os.environ["GRADIO_TEMP_DIR"] = custom_temp
|
| 30 |
+
|
| 31 |
import modal
|
| 32 |
import gradio as gr
|
| 33 |
from PIL import Image
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
from daggr import FnNode, GradioNode, InferenceNode, Graph
|
| 35 |
|
| 36 |
+
# --- Inference Setup ---
|
|
|
|
|
|
|
| 37 |
try:
|
| 38 |
ImageCaptioner = modal.Cls.from_name("acuarela-portrait", "ImageCaptioner")
|
| 39 |
ImageGenerator = modal.Cls.from_name("acuarela-portrait", "ImageGenerator")
|
|
|
|
| 49 |
|
| 50 |
# --- Function Nodes ---
|
| 51 |
|
| 52 |
+
def convert_image_to_path(image) -> str:
|
| 53 |
+
"""Save image to a temporary file and return the path for Daggr nodes"""
|
| 54 |
if image is None:
|
| 55 |
return None
|
| 56 |
+
|
| 57 |
+
try:
|
| 58 |
+
# Create a unique temp file in our short gr folder
|
| 59 |
+
temp_dir = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
|
| 60 |
+
temp_path = os.path.join(temp_dir, f"input_{uuid.uuid4()}.png")
|
| 61 |
+
|
| 62 |
+
# Handle Base64 Data URI from Gradio/Daggr
|
| 63 |
+
if isinstance(image, str) and image.startswith('data:image'):
|
| 64 |
+
header, encoded = image.split(",", 1)
|
| 65 |
+
data = base64.b64decode(encoded)
|
| 66 |
+
with open(temp_path, "wb") as f:
|
| 67 |
+
f.write(data)
|
| 68 |
+
return temp_path
|
| 69 |
+
|
| 70 |
+
# Handle file paths (already a path, just verify and return)
|
| 71 |
+
if isinstance(image, (str, Path)):
|
| 72 |
+
path_str = str(image)
|
| 73 |
+
# Windows Long Path fix
|
| 74 |
+
if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
|
| 75 |
+
abs_path = os.path.abspath(path_str)
|
| 76 |
+
path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path
|
| 77 |
+
|
| 78 |
+
# If it's already a file, we can just return it, but better to copy to temp
|
| 79 |
+
# to avoid permissions/lifetime issues with Gradio's deep temp folders
|
| 80 |
+
shutil.copy2(path_str, temp_path)
|
| 81 |
+
return temp_path
|
| 82 |
+
|
| 83 |
+
# Handle PIL objects
|
| 84 |
+
if hasattr(image, 'save'):
|
| 85 |
+
image.save(temp_path, format='PNG')
|
| 86 |
+
return temp_path
|
| 87 |
+
|
| 88 |
+
# If it's bytes
|
| 89 |
+
if isinstance(image, bytes):
|
| 90 |
+
with open(temp_path, "wb") as f:
|
| 91 |
+
f.write(image)
|
| 92 |
+
return temp_path
|
| 93 |
+
|
| 94 |
+
return str(image)
|
| 95 |
+
except Exception as e:
|
| 96 |
+
print(f"❌ Error in convert_image_to_path: {e}")
|
| 97 |
+
raise e
|
| 98 |
|
| 99 |
|
| 100 |
+
def generate_caption(img_path: str) -> str:
|
| 101 |
"""
|
| 102 |
Step 1: Generate image caption using Modal
|
| 103 |
"""
|
| 104 |
+
if not MODAL_AVAILABLE or img_path is None:
|
| 105 |
+
raise ValueError("Modal not available or image path is None")
|
| 106 |
|
| 107 |
try:
|
| 108 |
+
# Read bytes from the provided path
|
| 109 |
+
path_str = str(img_path)
|
| 110 |
+
if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
|
| 111 |
+
abs_path = os.path.abspath(path_str)
|
| 112 |
+
path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path
|
| 113 |
+
|
| 114 |
+
with open(path_str, "rb") as f:
|
| 115 |
+
img_bytes = f.read()
|
| 116 |
+
|
| 117 |
captioner = ImageCaptioner()
|
| 118 |
caption = captioner.caption.remote(img_bytes)
|
| 119 |
if not caption:
|
|
|
|
| 136 |
return prompt
|
| 137 |
|
| 138 |
|
| 139 |
+
def generate_watercolor_image(img_path: str, prompt: str) -> str:
|
| 140 |
"""
|
| 141 |
+
Step 3: Generate watercolor image using Modal SDXL and return the file path
|
| 142 |
"""
|
| 143 |
+
if not MODAL_AVAILABLE or img_path is None:
|
| 144 |
+
raise ValueError("Modal not available or image path is None")
|
| 145 |
|
| 146 |
try:
|
| 147 |
+
# Read bytes from the provided path
|
| 148 |
+
path_str = str(img_path)
|
| 149 |
+
if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
|
| 150 |
+
abs_path = os.path.abspath(path_str)
|
| 151 |
+
path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path
|
| 152 |
+
|
| 153 |
+
with open(path_str, "rb") as f:
|
| 154 |
+
img_bytes = f.read()
|
| 155 |
+
|
| 156 |
negative_prompt = (
|
| 157 |
"photorealistic, 3d render, photograph, complex background, "
|
| 158 |
"white background, dark background, messy sketch, blurry, "
|
|
|
|
| 162 |
|
| 163 |
generator = ImageGenerator()
|
| 164 |
result_bytes = generator.generate.remote(img_bytes, prompt, negative_prompt, strength=0.65)
|
| 165 |
+
|
| 166 |
+
# Save result to a temp file
|
| 167 |
+
temp_dir = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
|
| 168 |
+
output_path = os.path.join(temp_dir, f"output_{uuid.uuid4()}.png")
|
| 169 |
+
with open(output_path, "wb") as f:
|
| 170 |
+
f.write(result_bytes)
|
| 171 |
+
|
| 172 |
+
return output_path
|
| 173 |
except Exception as e:
|
| 174 |
raise ValueError(f"Image generation failed: {str(e)}")
|
| 175 |
|
|
|
|
| 182 |
"""
|
| 183 |
|
| 184 |
# Node 1: Image Input (implicit - handled by Gradio interface)
|
| 185 |
+
# Node 2: Convert image to path (using string to avoid bytes JSON error)
|
| 186 |
converter = FnNode(
|
| 187 |
+
fn=convert_image_to_path,
|
| 188 |
name="Image Converter",
|
| 189 |
inputs={
|
| 190 |
+
"image": gr.Image(label="Upload your photo"),
|
| 191 |
},
|
| 192 |
outputs={
|
| 193 |
+
"output": gr.Textbox(visible=False),
|
| 194 |
},
|
| 195 |
)
|
| 196 |
|
|
|
|
| 199 |
fn=generate_caption,
|
| 200 |
name="Image Analysis (BLIP Caption)",
|
| 201 |
inputs={
|
| 202 |
+
"img_path": converter.output,
|
| 203 |
},
|
| 204 |
outputs={
|
| 205 |
+
"output": gr.Textbox(label="Generated Caption"),
|
| 206 |
},
|
| 207 |
)
|
| 208 |
|
|
|
|
| 211 |
fn=create_artistic_prompt,
|
| 212 |
name="Artistic Prompt Engineering",
|
| 213 |
inputs={
|
| 214 |
+
"caption": captioner.output,
|
| 215 |
},
|
| 216 |
outputs={
|
| 217 |
+
"output": gr.Textbox(label="Artistic Prompt", lines=3),
|
| 218 |
},
|
| 219 |
)
|
| 220 |
|
|
|
|
| 223 |
fn=generate_watercolor_image,
|
| 224 |
name="Watercolor Generation (SDXL)",
|
| 225 |
inputs={
|
| 226 |
+
"img_path": converter.output,
|
| 227 |
+
"prompt": prompt_engineer.output,
|
| 228 |
},
|
| 229 |
outputs={
|
| 230 |
+
"output": gr.Image(label="Watercolor Portrait"),
|
| 231 |
},
|
| 232 |
)
|
| 233 |
|
|
|
|
| 235 |
graph = Graph(
|
| 236 |
name="🎨 Acuarela Portrait Generator - Daggr Workflow",
|
| 237 |
nodes=[converter, captioner, prompt_engineer, image_generator],
|
| 238 |
+
persist_key=False, # Disable persistence to avoid bytes -> str serialization issues
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
return graph
|
daggr/CHANGELOG.md
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# daggr
|
| 2 |
+
|
| 3 |
+
## 0.7.0
|
| 4 |
+
|
| 5 |
+
### Features
|
| 6 |
+
|
| 7 |
+
- [#69](https://github.com/gradio-app/daggr/pull/69) [`297d104`](https://github.com/gradio-app/daggr/commit/297d104d0e3fbd6b59dfcd1f69c9a478de81bc3d) - Add stop button to cancel running nodes. Thanks @abidlabs!
|
| 8 |
+
- [#64](https://github.com/gradio-app/daggr/pull/64) [`a9e53c6`](https://github.com/gradio-app/daggr/commit/a9e53c64db6b3beede0b19b3876a3e50ab572233) - Fix ChoiceNode: public .name property and respect explicit names. Thanks @abidlabs!
|
| 9 |
+
- [#62](https://github.com/gradio-app/daggr/pull/62) [`695411c`](https://github.com/gradio-app/daggr/commit/695411ce94bc8fedd3320ac941ab41233bf8f887) - Standardize file handling: all files are path strings. Thanks @abidlabs!
|
| 10 |
+
- [#66](https://github.com/gradio-app/daggr/pull/66) [`1ed16f8`](https://github.com/gradio-app/daggr/commit/1ed16f806d535413c3718fc47d81d79d93d73ee0) - Fix gr.JSON rendering (use @render snippet syntax). Thanks @abidlabs!
|
| 11 |
+
- [#63](https://github.com/gradio-app/daggr/pull/63) [`00b05ac`](https://github.com/gradio-app/daggr/commit/00b05ac526642c91560cbebb258626a4511082c2) - Fix file downloads from private HF Spaces. Thanks @abidlabs!
|
| 12 |
+
- [#70](https://github.com/gradio-app/daggr/pull/70) [`33ccb74`](https://github.com/gradio-app/daggr/commit/33ccb7470a482ee1b09fddf1d51de81b1e2c4a40) - Fix gr.Image not rendering with initial value or None input. Thanks @abidlabs!
|
| 13 |
+
- [#68](https://github.com/gradio-app/daggr/pull/68) [`4b76dca`](https://github.com/gradio-app/daggr/commit/4b76dca815e2802b70ac03bb95fb03f21d81a8fa) - Add dependency hash tracking for upstream Spaces and models. Thanks @abidlabs!
|
| 14 |
+
|
| 15 |
+
## 0.6.0
|
| 16 |
+
|
| 17 |
+
### Features
|
| 18 |
+
|
| 19 |
+
- [#54](https://github.com/gradio-app/daggr/pull/54) [`c1abb26`](https://github.com/gradio-app/daggr/commit/c1abb260b254af6ca2060292232049ea89f0f944) - Fix cache. Thanks @abidlabs!
|
| 20 |
+
- [#56](https://github.com/gradio-app/daggr/pull/56) [`6e3dfc0`](https://github.com/gradio-app/daggr/commit/6e3dfc0a585b673adb77bb11ab1dcfd80d01da5a) - Add paste from clipboard button to Image component. Thanks @abidlabs!
|
| 21 |
+
- [#57](https://github.com/gradio-app/daggr/pull/57) [`76855ba`](https://github.com/gradio-app/daggr/commit/76855ba967e3f3132e8ec0590ae037d3151af310) - Fix dropdown options being clipped inside node. Thanks @abidlabs!
|
| 22 |
+
- [#58](https://github.com/gradio-app/daggr/pull/58) [`eb52b72`](https://github.com/gradio-app/daggr/commit/eb52b725b17d277e85f6eac1cc9d07f8068b011b) - Add theme support to daggr. Thanks @abidlabs!
|
| 23 |
+
- [#59](https://github.com/gradio-app/daggr/pull/59) [`78189a4`](https://github.com/gradio-app/daggr/commit/78189a4163b4041c814e52110b65754dc4dbf863) - Add run mode dropdown to control node execution scope. Thanks @abidlabs!
|
| 24 |
+
- [#39](https://github.com/gradio-app/daggr/pull/39) [`e8792ad`](https://github.com/gradio-app/daggr/commit/e8792ad1b5818ff8d13660b0b156f329bbc1c33a) - feat: add --state-db-path CLI arg and DAGGR_DB_PATH env var support. Thanks @leith-bartrich!
|
| 25 |
+
|
| 26 |
+
## 0.5.4
|
| 27 |
+
|
| 28 |
+
### Features
|
| 29 |
+
|
| 30 |
+
- [#27](https://github.com/gradio-app/daggr/pull/27) [`3952b2c`](https://github.com/gradio-app/daggr/commit/3952b2ccf30e7d18994f23049c2a2e84b323cfd6) - changes. Thanks @abidlabs!
|
| 31 |
+
|
| 32 |
+
## 0.5.3
|
| 33 |
+
|
| 34 |
+
### Features
|
| 35 |
+
|
| 36 |
+
- [#19](https://github.com/gradio-app/daggr/pull/19) [`cd956fe`](https://github.com/gradio-app/daggr/commit/cd956fe29945bdfd31bbe76fcb80d3f9c97cc301) - Add daggr tag to deployed Spaces. Thanks @gary149!
|
| 37 |
+
|
| 38 |
+
## 0.5.2
|
| 39 |
+
|
| 40 |
+
### Features
|
| 41 |
+
|
| 42 |
+
- [#14](https://github.com/gradio-app/daggr/pull/14) [`3fa412d`](https://github.com/gradio-app/daggr/commit/3fa412d678988608d49d46d99d193a05469892d2) - Fixes. Thanks @abidlabs!
|
| 43 |
+
|
| 44 |
+
## 0.5.1
|
| 45 |
+
|
| 46 |
+
### Features
|
| 47 |
+
|
| 48 |
+
- [#11](https://github.com/gradio-app/daggr/pull/11) [`ce1d5f4`](https://github.com/gradio-app/daggr/commit/ce1d5f4deaac60d95d9a021b0aa057bc2941b018) - Fixes. Thanks @abidlabs!
|
| 49 |
+
- [#13](https://github.com/gradio-app/daggr/pull/13) [`3246921`](https://github.com/gradio-app/daggr/commit/32469213dad5fd29a7ac85938dffbd976e2c6643) - fixes. Thanks @abidlabs!
|
| 50 |
+
|
| 51 |
+
## 0.5.0
|
| 52 |
+
|
| 53 |
+
### Features
|
| 54 |
+
|
| 55 |
+
- [#8](https://github.com/gradio-app/daggr/pull/8) [`e480065`](https://github.com/gradio-app/daggr/commit/e480065dd058dbf19053a80956dbfc90cf3e3caf) - Improving security around executor and various bug fixes. Thanks @abidlabs!
|
| 56 |
+
|
| 57 |
+
## 0.4.0
|
| 58 |
+
|
| 59 |
+
### Features
|
| 60 |
+
|
| 61 |
+
- [#1](https://github.com/gradio-app/daggr/pull/1) [`23538c8`](https://github.com/gradio-app/daggr/commit/23538c884fb3f2d84bbe4bf14f475dc85fa17c79) - Refactor files, add Dialogue component, and implement fully working podcast example. Thanks @abidlabs!
|
| 62 |
+
|
| 63 |
+
## 0.1.0
|
| 64 |
+
|
| 65 |
+
Initial release
|
daggr/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""daggr - Build visual, node-based AI pipelines with Gradio Spaces.
|
| 2 |
+
|
| 3 |
+
daggr lets you create DAG (directed acyclic graph) pipelines that connect
|
| 4 |
+
Gradio Spaces, Hugging Face models, and Python functions into interactive
|
| 5 |
+
applications.
|
| 6 |
+
|
| 7 |
+
Example:
|
| 8 |
+
>>> from daggr import Graph, GradioNode, FnNode
|
| 9 |
+
>>> import gradio as gr
|
| 10 |
+
>>>
|
| 11 |
+
>>> tts = GradioNode(
|
| 12 |
+
... "mrfakename/MeloTTS",
|
| 13 |
+
... inputs={"text": gr.Textbox()},
|
| 14 |
+
... outputs={"audio": gr.Audio()},
|
| 15 |
+
... )
|
| 16 |
+
>>> graph = Graph("TTS Demo", nodes=[tts])
|
| 17 |
+
>>> graph.launch()
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import json
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
__version__ = json.loads((Path(__file__).parent / "package.json").read_text())[
|
| 24 |
+
"version"
|
| 25 |
+
]
|
| 26 |
+
|
| 27 |
+
from daggr.edge import Edge
|
| 28 |
+
from daggr.graph import Graph
|
| 29 |
+
from daggr.node import (
|
| 30 |
+
ChoiceNode,
|
| 31 |
+
FnNode,
|
| 32 |
+
GradioNode,
|
| 33 |
+
InferenceNode,
|
| 34 |
+
InteractionNode,
|
| 35 |
+
Node,
|
| 36 |
+
)
|
| 37 |
+
from daggr.port import ItemList, Port
|
| 38 |
+
from daggr.server import DaggrServer
|
| 39 |
+
|
| 40 |
+
__all__ = [
|
| 41 |
+
"__version__",
|
| 42 |
+
"ChoiceNode",
|
| 43 |
+
"Edge",
|
| 44 |
+
"Graph",
|
| 45 |
+
"Node",
|
| 46 |
+
"FnNode",
|
| 47 |
+
"GradioNode",
|
| 48 |
+
"InferenceNode",
|
| 49 |
+
"InteractionNode",
|
| 50 |
+
"ItemList",
|
| 51 |
+
"Port",
|
| 52 |
+
"DaggrServer",
|
| 53 |
+
]
|
daggr/_client_cache.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
from daggr.state import get_daggr_cache_dir
|
| 10 |
+
|
| 11 |
+
_client_cache: dict[str, Any] = {}
|
| 12 |
+
_api_memory_cache: dict[str, dict] = {}
|
| 13 |
+
_validated_set: set[str] = set()
|
| 14 |
+
_model_task_cache: dict[str, str] = {}
|
| 15 |
+
_dependency_hash_cache: dict[str, str] = {}
|
| 16 |
+
_dependency_hash_loaded: bool = False
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _is_hot_reload() -> bool:
|
| 20 |
+
return os.environ.get("DAGGR_HOT_RELOAD") == "1"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _get_cache_path(src: str) -> Path:
|
| 24 |
+
src_hash = hashlib.md5(src.encode()).hexdigest()[:16]
|
| 25 |
+
return get_daggr_cache_dir() / f"{src_hash}.json"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _get_validated_file() -> Path:
|
| 29 |
+
return get_daggr_cache_dir() / "_validated.json"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _load_validated_set() -> None:
|
| 33 |
+
global _validated_set
|
| 34 |
+
if _validated_set:
|
| 35 |
+
return
|
| 36 |
+
if not _is_hot_reload():
|
| 37 |
+
return
|
| 38 |
+
validated_file = _get_validated_file()
|
| 39 |
+
if validated_file.exists():
|
| 40 |
+
try:
|
| 41 |
+
_validated_set = set(json.loads(validated_file.read_text()))
|
| 42 |
+
except (json.JSONDecodeError, OSError):
|
| 43 |
+
_validated_set = set()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _save_validated_set() -> None:
|
| 47 |
+
if not _is_hot_reload():
|
| 48 |
+
return
|
| 49 |
+
try:
|
| 50 |
+
get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 51 |
+
_get_validated_file().write_text(json.dumps(list(_validated_set)))
|
| 52 |
+
except OSError:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def is_validated(cache_key: tuple) -> bool:
|
| 57 |
+
if not _is_hot_reload():
|
| 58 |
+
return False
|
| 59 |
+
_load_validated_set()
|
| 60 |
+
return str(cache_key) in _validated_set
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def mark_validated(cache_key: tuple) -> None:
|
| 64 |
+
if not _is_hot_reload():
|
| 65 |
+
return
|
| 66 |
+
_load_validated_set()
|
| 67 |
+
_validated_set.add(str(cache_key))
|
| 68 |
+
_save_validated_set()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def get_api_info(src: str) -> dict | None:
|
| 72 |
+
if src in _api_memory_cache:
|
| 73 |
+
return _api_memory_cache[src]
|
| 74 |
+
|
| 75 |
+
if not _is_hot_reload():
|
| 76 |
+
return None
|
| 77 |
+
|
| 78 |
+
cache_path = _get_cache_path(src)
|
| 79 |
+
if cache_path.exists():
|
| 80 |
+
try:
|
| 81 |
+
data = json.loads(cache_path.read_text())
|
| 82 |
+
_api_memory_cache[src] = data
|
| 83 |
+
return data
|
| 84 |
+
except (json.JSONDecodeError, OSError):
|
| 85 |
+
pass
|
| 86 |
+
return None
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def set_api_info(src: str, info: dict) -> None:
|
| 90 |
+
_api_memory_cache[src] = info
|
| 91 |
+
if not _is_hot_reload():
|
| 92 |
+
return
|
| 93 |
+
try:
|
| 94 |
+
get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 95 |
+
cache_path = _get_cache_path(src)
|
| 96 |
+
cache_path.write_text(json.dumps(info))
|
| 97 |
+
except OSError:
|
| 98 |
+
pass
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def get_client(src: str):
|
| 102 |
+
return _client_cache.get(src)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def set_client(src: str, client) -> None:
|
| 106 |
+
_client_cache[src] = client
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _get_model_task_cache_path() -> Path:
|
| 110 |
+
return get_daggr_cache_dir() / "_model_tasks.json"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _load_model_task_cache() -> None:
|
| 114 |
+
global _model_task_cache
|
| 115 |
+
if _model_task_cache:
|
| 116 |
+
return
|
| 117 |
+
if not _is_hot_reload():
|
| 118 |
+
return
|
| 119 |
+
cache_path = _get_model_task_cache_path()
|
| 120 |
+
if cache_path.exists():
|
| 121 |
+
try:
|
| 122 |
+
_model_task_cache = json.loads(cache_path.read_text())
|
| 123 |
+
except (json.JSONDecodeError, OSError):
|
| 124 |
+
_model_task_cache = {}
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _save_model_task_cache() -> None:
|
| 128 |
+
if not _is_hot_reload():
|
| 129 |
+
return
|
| 130 |
+
try:
|
| 131 |
+
get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 132 |
+
_get_model_task_cache_path().write_text(json.dumps(_model_task_cache))
|
| 133 |
+
except OSError:
|
| 134 |
+
pass
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_model_task(model: str) -> tuple[bool, str | None]:
|
| 138 |
+
"""Get cached task for a model.
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
(found_in_cache, task) where:
|
| 142 |
+
- found_in_cache is True if we have cached info for this model
|
| 143 |
+
- task is the pipeline_tag (can be None if model has no task, or "__NOT_FOUND__" if model doesn't exist)
|
| 144 |
+
"""
|
| 145 |
+
if model in _model_task_cache:
|
| 146 |
+
return True, _model_task_cache[model]
|
| 147 |
+
|
| 148 |
+
if not _is_hot_reload():
|
| 149 |
+
return False, None
|
| 150 |
+
|
| 151 |
+
_load_model_task_cache()
|
| 152 |
+
if model in _model_task_cache:
|
| 153 |
+
return True, _model_task_cache[model]
|
| 154 |
+
return False, None
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def set_model_task(model: str, task: str | None) -> None:
|
| 158 |
+
_model_task_cache[model] = task
|
| 159 |
+
_save_model_task_cache()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def set_model_not_found(model: str) -> None:
|
| 163 |
+
_model_task_cache[model] = "__NOT_FOUND__"
|
| 164 |
+
_save_model_task_cache()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def _get_dependency_hash_path() -> Path:
|
| 168 |
+
return get_daggr_cache_dir() / "_dependency_hashes.json"
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def _load_dependency_hash_cache() -> None:
|
| 172 |
+
global _dependency_hash_cache, _dependency_hash_loaded
|
| 173 |
+
if _dependency_hash_loaded:
|
| 174 |
+
return
|
| 175 |
+
cache_path = _get_dependency_hash_path()
|
| 176 |
+
if cache_path.exists():
|
| 177 |
+
try:
|
| 178 |
+
_dependency_hash_cache = json.loads(cache_path.read_text())
|
| 179 |
+
except (json.JSONDecodeError, OSError):
|
| 180 |
+
_dependency_hash_cache = {}
|
| 181 |
+
_dependency_hash_loaded = True
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def _save_dependency_hash_cache() -> None:
|
| 185 |
+
try:
|
| 186 |
+
get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
|
| 187 |
+
_get_dependency_hash_path().write_text(json.dumps(_dependency_hash_cache))
|
| 188 |
+
except OSError:
|
| 189 |
+
pass
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def get_dependency_hash(src: str) -> str | None:
|
| 193 |
+
_load_dependency_hash_cache()
|
| 194 |
+
return _dependency_hash_cache.get(src)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def set_dependency_hash(src: str, sha: str) -> None:
|
| 198 |
+
_load_dependency_hash_cache()
|
| 199 |
+
_dependency_hash_cache[src] = sha
|
| 200 |
+
_save_dependency_hash_cache()
|
daggr/_utils.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Internal utilities for daggr."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import difflib
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def suggest_similar(invalid: str, valid_options: set[str]) -> str | None:
|
| 9 |
+
"""Find a similar string from valid_options using fuzzy matching.
|
| 10 |
+
|
| 11 |
+
Args:
|
| 12 |
+
invalid: The invalid string to find matches for.
|
| 13 |
+
valid_options: Set of valid options to search through.
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
The closest matching string if found with >= 60% similarity, else None.
|
| 17 |
+
"""
|
| 18 |
+
matches = difflib.get_close_matches(invalid, valid_options, n=1, cutoff=0.6)
|
| 19 |
+
return matches[0] if matches else None
|
daggr/assets/hf-logo-pirate.png
ADDED
|
Git LFS Details
|
daggr/assets/logo_dark.png
ADDED
|
Git LFS Details
|
daggr/assets/logo_dark_small.png
ADDED
|
Git LFS Details
|
daggr/assets/logo_light.png
ADDED
|
Git LFS Details
|
daggr/cli.py
ADDED
|
@@ -0,0 +1,689 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import ast
|
| 5 |
+
import importlib.util
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import shutil
|
| 9 |
+
import socket
|
| 10 |
+
import sqlite3
|
| 11 |
+
import sys
|
| 12 |
+
import tempfile
|
| 13 |
+
import threading
|
| 14 |
+
import time
|
| 15 |
+
import webbrowser
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
INITIAL_PORT_VALUE = int(os.getenv("DAGGR_SERVER_PORT", "7860"))
|
| 19 |
+
TRY_NUM_PORTS = int(os.getenv("DAGGR_NUM_PORTS", "100"))
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _find_available_port(host: str, start_port: int) -> int:
|
| 23 |
+
"""Find an available port starting from start_port."""
|
| 24 |
+
for port in range(start_port, start_port + TRY_NUM_PORTS):
|
| 25 |
+
try:
|
| 26 |
+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 27 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 28 |
+
s.bind((host if host != "0.0.0.0" else "127.0.0.1", port))
|
| 29 |
+
s.close()
|
| 30 |
+
return port
|
| 31 |
+
except OSError:
|
| 32 |
+
continue
|
| 33 |
+
raise OSError(
|
| 34 |
+
f"Cannot find empty port in range: {start_port}-{start_port + TRY_NUM_PORTS - 1}. "
|
| 35 |
+
f"You can specify a different port by setting the DAGGR_SERVER_PORT environment variable "
|
| 36 |
+
f"or passing the --port parameter."
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def find_python_imports(file_path: Path) -> list[Path]:
|
| 41 |
+
"""Find local Python files imported by the given file."""
|
| 42 |
+
imports = []
|
| 43 |
+
try:
|
| 44 |
+
with open(file_path) as f:
|
| 45 |
+
content = f.read()
|
| 46 |
+
|
| 47 |
+
tree = ast.parse(content)
|
| 48 |
+
|
| 49 |
+
file_dir = file_path.parent
|
| 50 |
+
|
| 51 |
+
for node in ast.walk(tree):
|
| 52 |
+
if isinstance(node, ast.Import):
|
| 53 |
+
for alias in node.names:
|
| 54 |
+
module_path = file_dir / f"{alias.name.replace('.', '/')}.py"
|
| 55 |
+
if module_path.exists():
|
| 56 |
+
imports.append(module_path)
|
| 57 |
+
elif isinstance(node, ast.ImportFrom):
|
| 58 |
+
if node.module:
|
| 59 |
+
module_path = file_dir / f"{node.module.replace('.', '/')}.py"
|
| 60 |
+
if module_path.exists():
|
| 61 |
+
imports.append(module_path)
|
| 62 |
+
package_init = (
|
| 63 |
+
file_dir / node.module.replace(".", "/") / "__init__.py"
|
| 64 |
+
)
|
| 65 |
+
if package_init.exists():
|
| 66 |
+
imports.append(package_init.parent)
|
| 67 |
+
except Exception:
|
| 68 |
+
pass
|
| 69 |
+
return imports
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def main():
|
| 73 |
+
if len(sys.argv) > 1 and sys.argv[1] == "deploy":
|
| 74 |
+
_deploy_main()
|
| 75 |
+
return
|
| 76 |
+
|
| 77 |
+
parser = argparse.ArgumentParser(
|
| 78 |
+
prog="daggr",
|
| 79 |
+
description="Run a daggr app with hot reload",
|
| 80 |
+
)
|
| 81 |
+
parser.add_argument(
|
| 82 |
+
"script",
|
| 83 |
+
help="Path to the Python script containing the daggr Graph",
|
| 84 |
+
)
|
| 85 |
+
parser.add_argument(
|
| 86 |
+
"--host",
|
| 87 |
+
default="127.0.0.1",
|
| 88 |
+
help="Host to bind to (default: 127.0.0.1)",
|
| 89 |
+
)
|
| 90 |
+
parser.add_argument(
|
| 91 |
+
"--port",
|
| 92 |
+
type=int,
|
| 93 |
+
default=7860,
|
| 94 |
+
help="Port to bind to (default: 7860)",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--no-reload",
|
| 98 |
+
action="store_true",
|
| 99 |
+
help="Disable auto-reload",
|
| 100 |
+
)
|
| 101 |
+
parser.add_argument(
|
| 102 |
+
"--watch-daggr",
|
| 103 |
+
action="store_true",
|
| 104 |
+
default=True,
|
| 105 |
+
help="Watch daggr source for changes (default: True, useful for development)",
|
| 106 |
+
)
|
| 107 |
+
parser.add_argument(
|
| 108 |
+
"--no-watch-daggr",
|
| 109 |
+
action="store_true",
|
| 110 |
+
help="Don't watch daggr source for changes",
|
| 111 |
+
)
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--delete-sheets",
|
| 114 |
+
action="store_true",
|
| 115 |
+
help="Delete all cached data (sheets, results, downloaded files) for this project and exit",
|
| 116 |
+
)
|
| 117 |
+
parser.add_argument(
|
| 118 |
+
"--force",
|
| 119 |
+
"-f",
|
| 120 |
+
action="store_true",
|
| 121 |
+
help="Skip confirmation prompts (use with --delete-sheets)",
|
| 122 |
+
)
|
| 123 |
+
parser.add_argument(
|
| 124 |
+
"--state-db-path",
|
| 125 |
+
help="Optional path to SQLite state database. Overrides DAGGR_DB_PATH env var. Defaults to HuggingFace cache.",
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
args = parser.parse_args()
|
| 129 |
+
|
| 130 |
+
script_path = Path(args.script).resolve()
|
| 131 |
+
if not script_path.exists():
|
| 132 |
+
print(f"Error: Script not found: {script_path}")
|
| 133 |
+
sys.exit(1)
|
| 134 |
+
|
| 135 |
+
if not script_path.suffix == ".py":
|
| 136 |
+
print(f"Error: Script must be a Python file: {script_path}")
|
| 137 |
+
sys.exit(1)
|
| 138 |
+
|
| 139 |
+
if args.delete_sheets:
|
| 140 |
+
_delete_sheets(script_path, force=args.force)
|
| 141 |
+
sys.exit(0)
|
| 142 |
+
|
| 143 |
+
watch_daggr = args.watch_daggr and not args.no_watch_daggr
|
| 144 |
+
|
| 145 |
+
os.environ["DAGGR_SCRIPT_PATH"] = str(script_path)
|
| 146 |
+
os.environ["DAGGR_HOST"] = args.host
|
| 147 |
+
os.environ["DAGGR_PORT"] = str(args.port)
|
| 148 |
+
if args.state_db_path:
|
| 149 |
+
os.environ["DAGGR_DB_PATH"] = str(Path(args.state_db_path).resolve())
|
| 150 |
+
|
| 151 |
+
if args.no_reload:
|
| 152 |
+
_run_script(script_path, args.host, args.port)
|
| 153 |
+
else:
|
| 154 |
+
os.environ["DAGGR_HOT_RELOAD"] = "1"
|
| 155 |
+
_run_with_reload(script_path, args.host, args.port, watch_daggr)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _deploy_main():
|
| 159 |
+
"""Entry point for the deploy subcommand."""
|
| 160 |
+
parser = argparse.ArgumentParser(
|
| 161 |
+
prog="daggr deploy",
|
| 162 |
+
description="Deploy a daggr app to Hugging Face Spaces",
|
| 163 |
+
)
|
| 164 |
+
parser.add_argument(
|
| 165 |
+
"script",
|
| 166 |
+
help="Path to the Python script containing the daggr Graph",
|
| 167 |
+
)
|
| 168 |
+
parser.add_argument(
|
| 169 |
+
"--name",
|
| 170 |
+
"-n",
|
| 171 |
+
help="Space name (default: derived from Graph name)",
|
| 172 |
+
)
|
| 173 |
+
parser.add_argument(
|
| 174 |
+
"--title",
|
| 175 |
+
"-t",
|
| 176 |
+
help="Display title for the Space (default: Graph name)",
|
| 177 |
+
)
|
| 178 |
+
parser.add_argument(
|
| 179 |
+
"--org",
|
| 180 |
+
"-o",
|
| 181 |
+
help="Organization or username to deploy under (default: your HF account)",
|
| 182 |
+
)
|
| 183 |
+
parser.add_argument(
|
| 184 |
+
"--private",
|
| 185 |
+
"-p",
|
| 186 |
+
action="store_true",
|
| 187 |
+
help="Make the Space private",
|
| 188 |
+
)
|
| 189 |
+
parser.add_argument(
|
| 190 |
+
"--hardware",
|
| 191 |
+
default="cpu-basic",
|
| 192 |
+
help="Hardware tier (default: cpu-basic). Options: cpu-basic, cpu-upgrade, t4-small, t4-medium, a10g-small, etc.",
|
| 193 |
+
)
|
| 194 |
+
parser.add_argument(
|
| 195 |
+
"--secret",
|
| 196 |
+
"-s",
|
| 197 |
+
action="append",
|
| 198 |
+
dest="secrets",
|
| 199 |
+
metavar="KEY=VALUE",
|
| 200 |
+
help="Add a secret (can be repeated). Example: --secret HF_TOKEN=xxx",
|
| 201 |
+
)
|
| 202 |
+
parser.add_argument(
|
| 203 |
+
"--requirements",
|
| 204 |
+
"-r",
|
| 205 |
+
help="Path to requirements.txt (default: auto-detect or generate)",
|
| 206 |
+
)
|
| 207 |
+
parser.add_argument(
|
| 208 |
+
"--dry-run",
|
| 209 |
+
action="store_true",
|
| 210 |
+
help="Preview what would be deployed without actually deploying",
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
args = parser.parse_args(sys.argv[2:])
|
| 214 |
+
|
| 215 |
+
script_path = Path(args.script).resolve()
|
| 216 |
+
if not script_path.exists():
|
| 217 |
+
print(f"Error: Script not found: {script_path}")
|
| 218 |
+
sys.exit(1)
|
| 219 |
+
|
| 220 |
+
if not script_path.suffix == ".py":
|
| 221 |
+
print(f"Error: Script must be a Python file: {script_path}")
|
| 222 |
+
sys.exit(1)
|
| 223 |
+
|
| 224 |
+
secrets = {}
|
| 225 |
+
if args.secrets:
|
| 226 |
+
for secret in args.secrets:
|
| 227 |
+
if "=" not in secret:
|
| 228 |
+
print(f"Error: Invalid secret format '{secret}'. Use KEY=VALUE")
|
| 229 |
+
sys.exit(1)
|
| 230 |
+
key, value = secret.split("=", 1)
|
| 231 |
+
secrets[key] = value
|
| 232 |
+
|
| 233 |
+
_deploy(
|
| 234 |
+
script_path=script_path,
|
| 235 |
+
name=args.name,
|
| 236 |
+
title=args.title,
|
| 237 |
+
org=args.org,
|
| 238 |
+
private=args.private,
|
| 239 |
+
hardware=args.hardware,
|
| 240 |
+
secrets=secrets,
|
| 241 |
+
requirements_path=args.requirements,
|
| 242 |
+
dry_run=args.dry_run,
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def _extract_graph(script_path: Path):
|
| 247 |
+
"""Extract the Graph object from a script without running it."""
|
| 248 |
+
from daggr.graph import Graph
|
| 249 |
+
|
| 250 |
+
sys.path.insert(0, str(script_path.parent))
|
| 251 |
+
|
| 252 |
+
original_launch = Graph.launch
|
| 253 |
+
captured_graph = None
|
| 254 |
+
|
| 255 |
+
def capture_launch(self, **kwargs):
|
| 256 |
+
nonlocal captured_graph
|
| 257 |
+
captured_graph = self
|
| 258 |
+
|
| 259 |
+
Graph.launch = capture_launch
|
| 260 |
+
|
| 261 |
+
try:
|
| 262 |
+
spec = importlib.util.spec_from_file_location("__daggr_deploy__", script_path)
|
| 263 |
+
if spec is None or spec.loader is None:
|
| 264 |
+
print(f"Error: Could not load script: {script_path}")
|
| 265 |
+
sys.exit(1)
|
| 266 |
+
|
| 267 |
+
module = importlib.util.module_from_spec(spec)
|
| 268 |
+
sys.modules["__daggr_deploy__"] = module
|
| 269 |
+
spec.loader.exec_module(module)
|
| 270 |
+
finally:
|
| 271 |
+
Graph.launch = original_launch
|
| 272 |
+
|
| 273 |
+
if captured_graph is None:
|
| 274 |
+
for name in dir(module):
|
| 275 |
+
obj = getattr(module, name)
|
| 276 |
+
if isinstance(obj, Graph):
|
| 277 |
+
captured_graph = obj
|
| 278 |
+
break
|
| 279 |
+
|
| 280 |
+
if captured_graph is None:
|
| 281 |
+
print(f"Error: No Graph found in {script_path}")
|
| 282 |
+
sys.exit(1)
|
| 283 |
+
|
| 284 |
+
return captured_graph
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _sanitize_space_name(name: str) -> str:
|
| 288 |
+
"""Convert a Graph name to a valid HF Space name."""
|
| 289 |
+
sanitized = re.sub(r"[^a-zA-Z0-9\s-]", "", name)
|
| 290 |
+
sanitized = re.sub(r"[\s_]+", "-", sanitized)
|
| 291 |
+
sanitized = sanitized.lower().strip("-")
|
| 292 |
+
return sanitized or "daggr-app"
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def _deploy(
|
| 296 |
+
script_path: Path,
|
| 297 |
+
name: str | None,
|
| 298 |
+
title: str | None,
|
| 299 |
+
org: str | None,
|
| 300 |
+
private: bool,
|
| 301 |
+
hardware: str,
|
| 302 |
+
secrets: dict[str, str],
|
| 303 |
+
requirements_path: str | None,
|
| 304 |
+
dry_run: bool,
|
| 305 |
+
):
|
| 306 |
+
"""Deploy a daggr app to Hugging Face Spaces."""
|
| 307 |
+
import huggingface_hub
|
| 308 |
+
from huggingface_hub import HfApi
|
| 309 |
+
|
| 310 |
+
import daggr
|
| 311 |
+
|
| 312 |
+
print("\n Extracting Graph from script...")
|
| 313 |
+
graph = _extract_graph(script_path)
|
| 314 |
+
|
| 315 |
+
space_name = name or _sanitize_space_name(graph.name)
|
| 316 |
+
space_title = title or graph.name
|
| 317 |
+
|
| 318 |
+
print(f" Graph name: {graph.name}")
|
| 319 |
+
print(f" Space name: {space_name}")
|
| 320 |
+
print(f" Space title: {space_title}")
|
| 321 |
+
|
| 322 |
+
hf_api = HfApi()
|
| 323 |
+
whoami = None
|
| 324 |
+
login_needed = False
|
| 325 |
+
|
| 326 |
+
try:
|
| 327 |
+
whoami = hf_api.whoami()
|
| 328 |
+
if whoami["auth"]["accessToken"]["role"] != "write":
|
| 329 |
+
login_needed = True
|
| 330 |
+
except Exception:
|
| 331 |
+
login_needed = True
|
| 332 |
+
|
| 333 |
+
if login_needed:
|
| 334 |
+
print("\n Need 'write' access token to create a Spaces repo.")
|
| 335 |
+
huggingface_hub.login(add_to_git_credential=False)
|
| 336 |
+
whoami = hf_api.whoami()
|
| 337 |
+
|
| 338 |
+
username = whoami["name"]
|
| 339 |
+
namespace = org or username
|
| 340 |
+
repo_id = f"{namespace}/{space_name}"
|
| 341 |
+
|
| 342 |
+
print(f"\n Target: https://huggingface.co/spaces/{repo_id}")
|
| 343 |
+
print(f" Hardware: {hardware}")
|
| 344 |
+
print(f" Private: {private}")
|
| 345 |
+
if secrets:
|
| 346 |
+
print(f" Secrets: {list(secrets.keys())}")
|
| 347 |
+
|
| 348 |
+
local_imports = find_python_imports(script_path)
|
| 349 |
+
print("\n Files to upload:")
|
| 350 |
+
print(f" • app.py (from {script_path.name})")
|
| 351 |
+
print(" • requirements.txt")
|
| 352 |
+
print(" • README.md")
|
| 353 |
+
for imp in local_imports:
|
| 354 |
+
if imp.is_file():
|
| 355 |
+
print(f" • {imp.name}")
|
| 356 |
+
else:
|
| 357 |
+
print(f" • {imp.name}/ (package)")
|
| 358 |
+
|
| 359 |
+
if dry_run:
|
| 360 |
+
print("\n [Dry run] No changes made.")
|
| 361 |
+
return
|
| 362 |
+
|
| 363 |
+
with tempfile.TemporaryDirectory() as tmpdir:
|
| 364 |
+
tmpdir = Path(tmpdir)
|
| 365 |
+
|
| 366 |
+
shutil.copy(script_path, tmpdir / "app.py")
|
| 367 |
+
|
| 368 |
+
for imp in local_imports:
|
| 369 |
+
if imp.is_file():
|
| 370 |
+
shutil.copy(imp, tmpdir / imp.name)
|
| 371 |
+
else:
|
| 372 |
+
shutil.copytree(imp, tmpdir / imp.name)
|
| 373 |
+
|
| 374 |
+
if requirements_path:
|
| 375 |
+
req_path = Path(requirements_path)
|
| 376 |
+
if not req_path.exists():
|
| 377 |
+
print(f"Error: Requirements file not found: {req_path}")
|
| 378 |
+
sys.exit(1)
|
| 379 |
+
shutil.copy(req_path, tmpdir / "requirements.txt")
|
| 380 |
+
|
| 381 |
+
with open(tmpdir / "requirements.txt", "r") as f:
|
| 382 |
+
req_content = f.read()
|
| 383 |
+
if "daggr" not in req_content:
|
| 384 |
+
with open(tmpdir / "requirements.txt", "a") as f:
|
| 385 |
+
f.write(f"\ndaggr>={daggr.__version__}\n")
|
| 386 |
+
else:
|
| 387 |
+
script_dir = script_path.parent
|
| 388 |
+
existing_req = script_dir / "requirements.txt"
|
| 389 |
+
if existing_req.exists():
|
| 390 |
+
shutil.copy(existing_req, tmpdir / "requirements.txt")
|
| 391 |
+
with open(tmpdir / "requirements.txt", "r") as f:
|
| 392 |
+
req_content = f.read()
|
| 393 |
+
if "daggr" not in req_content:
|
| 394 |
+
with open(tmpdir / "requirements.txt", "a") as f:
|
| 395 |
+
f.write(f"\ndaggr>={daggr.__version__}\n")
|
| 396 |
+
else:
|
| 397 |
+
with open(tmpdir / "requirements.txt", "w") as f:
|
| 398 |
+
f.write(f"daggr>={daggr.__version__}\n")
|
| 399 |
+
|
| 400 |
+
readme_content = f"""---
|
| 401 |
+
title: {space_title}
|
| 402 |
+
emoji: 🔀
|
| 403 |
+
colorFrom: blue
|
| 404 |
+
colorTo: purple
|
| 405 |
+
sdk: gradio
|
| 406 |
+
sdk_version: "{_get_gradio_version()}"
|
| 407 |
+
app_file: app.py
|
| 408 |
+
pinned: false
|
| 409 |
+
tags:
|
| 410 |
+
- daggr
|
| 411 |
+
---
|
| 412 |
+
|
| 413 |
+
# {space_title}
|
| 414 |
+
|
| 415 |
+
This Space was deployed using [daggr](https://github.com/gradio-app/daggr).
|
| 416 |
+
"""
|
| 417 |
+
with open(tmpdir / "README.md", "w") as f:
|
| 418 |
+
f.write(readme_content)
|
| 419 |
+
|
| 420 |
+
print("\n Creating Space repository...")
|
| 421 |
+
try:
|
| 422 |
+
hf_api.create_repo(
|
| 423 |
+
repo_id=repo_id,
|
| 424 |
+
repo_type="space",
|
| 425 |
+
space_sdk="gradio",
|
| 426 |
+
space_hardware=hardware,
|
| 427 |
+
private=private,
|
| 428 |
+
exist_ok=True,
|
| 429 |
+
)
|
| 430 |
+
except Exception as e:
|
| 431 |
+
print(f"Error creating repository: {e}")
|
| 432 |
+
sys.exit(1)
|
| 433 |
+
|
| 434 |
+
print(" Uploading files...")
|
| 435 |
+
try:
|
| 436 |
+
hf_api.upload_folder(
|
| 437 |
+
repo_id=repo_id,
|
| 438 |
+
repo_type="space",
|
| 439 |
+
folder_path=str(tmpdir),
|
| 440 |
+
)
|
| 441 |
+
except Exception as e:
|
| 442 |
+
print(f"Error uploading files: {e}")
|
| 443 |
+
sys.exit(1)
|
| 444 |
+
|
| 445 |
+
if secrets:
|
| 446 |
+
print(" Adding secrets...")
|
| 447 |
+
for secret_name, secret_value in secrets.items():
|
| 448 |
+
try:
|
| 449 |
+
hf_api.add_space_secret(repo_id, secret_name, secret_value)
|
| 450 |
+
except Exception as e:
|
| 451 |
+
print(f" Warning: Could not add secret '{secret_name}': {e}")
|
| 452 |
+
|
| 453 |
+
print(f"\n ✓ Deployed to https://huggingface.co/spaces/{repo_id}")
|
| 454 |
+
print(" The Space may take a few minutes to build and start.\n")
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
def _get_gradio_version() -> str:
|
| 458 |
+
"""Get the installed Gradio version."""
|
| 459 |
+
try:
|
| 460 |
+
import gradio
|
| 461 |
+
|
| 462 |
+
return gradio.__version__
|
| 463 |
+
except ImportError:
|
| 464 |
+
return "5.0.0"
|
| 465 |
+
|
| 466 |
+
|
| 467 |
+
def _delete_sheets(script_path: Path, force: bool = False):
|
| 468 |
+
"""Delete all cached data for the project defined in the script."""
|
| 469 |
+
from daggr.graph import Graph
|
| 470 |
+
from daggr.state import get_daggr_cache_dir
|
| 471 |
+
|
| 472 |
+
sys.path.insert(0, str(script_path.parent))
|
| 473 |
+
|
| 474 |
+
original_launch = Graph.launch
|
| 475 |
+
captured_graph = None
|
| 476 |
+
|
| 477 |
+
def capture_launch(self, **kwargs):
|
| 478 |
+
nonlocal captured_graph
|
| 479 |
+
captured_graph = self
|
| 480 |
+
|
| 481 |
+
Graph.launch = capture_launch
|
| 482 |
+
|
| 483 |
+
try:
|
| 484 |
+
spec = importlib.util.spec_from_file_location("__daggr_reset__", script_path)
|
| 485 |
+
if spec is None or spec.loader is None:
|
| 486 |
+
print(f"Error: Could not load script: {script_path}")
|
| 487 |
+
sys.exit(1)
|
| 488 |
+
|
| 489 |
+
module = importlib.util.module_from_spec(spec)
|
| 490 |
+
sys.modules["__daggr_reset__"] = module
|
| 491 |
+
spec.loader.exec_module(module)
|
| 492 |
+
finally:
|
| 493 |
+
Graph.launch = original_launch
|
| 494 |
+
|
| 495 |
+
if captured_graph is None:
|
| 496 |
+
for name in dir(module):
|
| 497 |
+
obj = getattr(module, name)
|
| 498 |
+
if isinstance(obj, Graph):
|
| 499 |
+
captured_graph = obj
|
| 500 |
+
break
|
| 501 |
+
|
| 502 |
+
if captured_graph is None:
|
| 503 |
+
print(f"Error: No Graph found in {script_path}")
|
| 504 |
+
sys.exit(1)
|
| 505 |
+
|
| 506 |
+
persist_key = captured_graph.persist_key
|
| 507 |
+
if not persist_key:
|
| 508 |
+
print("Error: Graph has no persist_key (persistence is disabled)")
|
| 509 |
+
sys.exit(1)
|
| 510 |
+
|
| 511 |
+
cache_dir = get_daggr_cache_dir()
|
| 512 |
+
db_path = cache_dir / "sessions.db"
|
| 513 |
+
|
| 514 |
+
if not db_path.exists():
|
| 515 |
+
print(f"No cache found for project '{persist_key}'")
|
| 516 |
+
return
|
| 517 |
+
|
| 518 |
+
conn = sqlite3.connect(str(db_path))
|
| 519 |
+
cursor = conn.cursor()
|
| 520 |
+
|
| 521 |
+
cursor.execute(
|
| 522 |
+
"SELECT sheet_id FROM sheets WHERE graph_name = ?",
|
| 523 |
+
(persist_key,),
|
| 524 |
+
)
|
| 525 |
+
sheet_ids = [row[0] for row in cursor.fetchall()]
|
| 526 |
+
|
| 527 |
+
if not sheet_ids:
|
| 528 |
+
print(f"No cached data found for project '{persist_key}'")
|
| 529 |
+
conn.close()
|
| 530 |
+
return
|
| 531 |
+
|
| 532 |
+
print(f"\nProject: {persist_key}")
|
| 533 |
+
print(f"This will delete {len(sheet_ids)} sheet(s) and all associated data.")
|
| 534 |
+
print(f"Cache location: {cache_dir}\n")
|
| 535 |
+
|
| 536 |
+
if not force:
|
| 537 |
+
try:
|
| 538 |
+
response = (
|
| 539 |
+
input("Are you sure you want to continue? [y/N] ").strip().lower()
|
| 540 |
+
)
|
| 541 |
+
except (EOFError, KeyboardInterrupt):
|
| 542 |
+
print("\nAborted.")
|
| 543 |
+
conn.close()
|
| 544 |
+
return
|
| 545 |
+
|
| 546 |
+
if response not in ("y", "yes"):
|
| 547 |
+
print("Aborted.")
|
| 548 |
+
conn.close()
|
| 549 |
+
return
|
| 550 |
+
|
| 551 |
+
for sheet_id in sheet_ids:
|
| 552 |
+
cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
|
| 553 |
+
cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
|
| 554 |
+
cursor.execute("DELETE FROM sheets WHERE sheet_id = ?", (sheet_id,))
|
| 555 |
+
|
| 556 |
+
conn.commit()
|
| 557 |
+
conn.close()
|
| 558 |
+
|
| 559 |
+
print(f"\n✓ Deleted {len(sheet_ids)} sheet(s) for project '{persist_key}'")
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def _run_script(script_path: Path, host: str, port: int):
|
| 563 |
+
"""Run the script directly without reload."""
|
| 564 |
+
spec = importlib.util.spec_from_file_location("__daggr_main__", script_path)
|
| 565 |
+
if spec is None or spec.loader is None:
|
| 566 |
+
print(f"Error: Could not load script: {script_path}")
|
| 567 |
+
sys.exit(1)
|
| 568 |
+
|
| 569 |
+
sys.path.insert(0, str(script_path.parent))
|
| 570 |
+
|
| 571 |
+
module = importlib.util.module_from_spec(spec)
|
| 572 |
+
sys.modules["__daggr_main__"] = module
|
| 573 |
+
spec.loader.exec_module(module)
|
| 574 |
+
|
| 575 |
+
|
| 576 |
+
def _run_with_reload(script_path: Path, host: str, port: int, watch_daggr: bool):
|
| 577 |
+
"""Run the script with uvicorn hot reload."""
|
| 578 |
+
import uvicorn
|
| 579 |
+
|
| 580 |
+
actual_port = _find_available_port(host, port)
|
| 581 |
+
if actual_port != port:
|
| 582 |
+
print(f"\n Port {port} is in use, using {actual_port} instead.")
|
| 583 |
+
|
| 584 |
+
reload_dirs = [str(script_path.parent)]
|
| 585 |
+
|
| 586 |
+
local_imports = find_python_imports(script_path)
|
| 587 |
+
for imp in local_imports:
|
| 588 |
+
imp_dir = str(imp if imp.is_dir() else imp.parent)
|
| 589 |
+
if imp_dir not in reload_dirs:
|
| 590 |
+
reload_dirs.append(imp_dir)
|
| 591 |
+
|
| 592 |
+
if watch_daggr:
|
| 593 |
+
daggr_dir = Path(__file__).parent
|
| 594 |
+
daggr_src = str(daggr_dir)
|
| 595 |
+
if daggr_src not in reload_dirs:
|
| 596 |
+
reload_dirs.append(daggr_src)
|
| 597 |
+
|
| 598 |
+
reload_includes = ["*.py"]
|
| 599 |
+
|
| 600 |
+
print("\n daggr dev server starting...")
|
| 601 |
+
print(" Watching for changes in:")
|
| 602 |
+
for d in reload_dirs:
|
| 603 |
+
print(f" • {d}")
|
| 604 |
+
print()
|
| 605 |
+
|
| 606 |
+
os.environ["DAGGR_PORT"] = str(actual_port)
|
| 607 |
+
|
| 608 |
+
def open_browser():
|
| 609 |
+
time.sleep(1.0)
|
| 610 |
+
webbrowser.open_new_tab(f"http://{host}:{actual_port}")
|
| 611 |
+
|
| 612 |
+
threading.Thread(target=open_browser, daemon=True).start()
|
| 613 |
+
|
| 614 |
+
uvicorn.run(
|
| 615 |
+
"daggr.cli:_create_app",
|
| 616 |
+
factory=True,
|
| 617 |
+
host=host,
|
| 618 |
+
port=actual_port,
|
| 619 |
+
reload=True,
|
| 620 |
+
reload_dirs=reload_dirs,
|
| 621 |
+
reload_includes=reload_includes,
|
| 622 |
+
log_level="warning",
|
| 623 |
+
)
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
def _create_app():
|
| 627 |
+
"""Factory function for uvicorn to create the FastAPI app."""
|
| 628 |
+
from daggr.graph import Graph
|
| 629 |
+
from daggr.server import DaggrServer
|
| 630 |
+
|
| 631 |
+
script_path = Path(os.environ["DAGGR_SCRIPT_PATH"])
|
| 632 |
+
|
| 633 |
+
if str(script_path.parent) not in sys.path:
|
| 634 |
+
sys.path.insert(0, str(script_path.parent))
|
| 635 |
+
|
| 636 |
+
modules_to_remove = [m for m in sys.modules if m.startswith("__daggr_user_script_")]
|
| 637 |
+
for m in modules_to_remove:
|
| 638 |
+
del sys.modules[m]
|
| 639 |
+
|
| 640 |
+
module_name = f"__daggr_user_script_{id(script_path)}__"
|
| 641 |
+
|
| 642 |
+
spec = importlib.util.spec_from_file_location(module_name, script_path)
|
| 643 |
+
if spec is None or spec.loader is None:
|
| 644 |
+
raise RuntimeError(f"Could not load script: {script_path}")
|
| 645 |
+
|
| 646 |
+
original_launch = Graph.launch
|
| 647 |
+
captured_graph = None
|
| 648 |
+
launch_kwargs = {}
|
| 649 |
+
|
| 650 |
+
def capture_launch(self, **kwargs):
|
| 651 |
+
nonlocal captured_graph, launch_kwargs
|
| 652 |
+
captured_graph = self
|
| 653 |
+
launch_kwargs = kwargs
|
| 654 |
+
|
| 655 |
+
Graph.launch = capture_launch
|
| 656 |
+
|
| 657 |
+
try:
|
| 658 |
+
module = importlib.util.module_from_spec(spec)
|
| 659 |
+
sys.modules[module_name] = module
|
| 660 |
+
spec.loader.exec_module(module)
|
| 661 |
+
finally:
|
| 662 |
+
Graph.launch = original_launch
|
| 663 |
+
|
| 664 |
+
if captured_graph is None:
|
| 665 |
+
for name in dir(module):
|
| 666 |
+
obj = getattr(module, name)
|
| 667 |
+
if isinstance(obj, Graph):
|
| 668 |
+
captured_graph = obj
|
| 669 |
+
break
|
| 670 |
+
|
| 671 |
+
if captured_graph is None:
|
| 672 |
+
raise RuntimeError(
|
| 673 |
+
f"No Graph found in {script_path}. "
|
| 674 |
+
"Make sure your script defines a Graph and calls graph.launch() "
|
| 675 |
+
"or has a Graph instance at module level."
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
captured_graph._validate_edges()
|
| 679 |
+
server = DaggrServer(captured_graph)
|
| 680 |
+
|
| 681 |
+
base_url = f"http://{os.environ['DAGGR_HOST']}:{os.environ['DAGGR_PORT']}"
|
| 682 |
+
print(f"\n UI running at: {base_url}")
|
| 683 |
+
print(f" API server at: {base_url}/api\n")
|
| 684 |
+
|
| 685 |
+
return server.app
|
| 686 |
+
|
| 687 |
+
|
| 688 |
+
if __name__ == "__main__":
|
| 689 |
+
main()
|
daggr/edge.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Edge module for connecting ports between nodes."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from daggr.port import GatheredPort, ScatteredPort
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from daggr.port import PortLike
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Edge:
|
| 14 |
+
"""Represents a connection between two ports in a graph.
|
| 15 |
+
|
| 16 |
+
Edges connect an output port of one node to an input port of another,
|
| 17 |
+
defining how data flows through the graph.
|
| 18 |
+
|
| 19 |
+
Attributes:
|
| 20 |
+
source_node: The node providing the output.
|
| 21 |
+
source_port: Name of the output port.
|
| 22 |
+
target_node: The node receiving the input.
|
| 23 |
+
target_port: Name of the input port.
|
| 24 |
+
is_scattered: True if this edge scatters a list to multiple executions.
|
| 25 |
+
is_gathered: True if this edge gathers results back into a list.
|
| 26 |
+
item_key: For scattered edges, the key to extract from each item.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(self, source: PortLike, target: PortLike):
|
| 30 |
+
self.is_scattered = isinstance(source, ScatteredPort)
|
| 31 |
+
self.is_gathered = isinstance(source, GatheredPort)
|
| 32 |
+
self.item_key: str | None = None
|
| 33 |
+
|
| 34 |
+
if self.is_scattered:
|
| 35 |
+
self.item_key = source.item_key
|
| 36 |
+
|
| 37 |
+
self.source_node = source.node
|
| 38 |
+
self.source_port = source.name
|
| 39 |
+
self.target_node = target.node
|
| 40 |
+
self.target_port = target.name
|
| 41 |
+
|
| 42 |
+
def __repr__(self):
|
| 43 |
+
prefix = ""
|
| 44 |
+
if self.is_scattered:
|
| 45 |
+
key_info = f"['{self.item_key}']" if self.item_key else ""
|
| 46 |
+
prefix = f"scatter{key_info}:"
|
| 47 |
+
elif self.is_gathered:
|
| 48 |
+
prefix = "gather:"
|
| 49 |
+
return (
|
| 50 |
+
f"Edge({prefix}{self.source_node._name}.{self.source_port} -> "
|
| 51 |
+
f"{self.target_node._name}.{self.target_port})"
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def as_tuple(self) -> tuple[str, str, str, str]:
|
| 55 |
+
return (
|
| 56 |
+
self.source_node._name,
|
| 57 |
+
self.source_port,
|
| 58 |
+
self.target_node._name,
|
| 59 |
+
self.target_port,
|
| 60 |
+
)
|
daggr/executor.py
ADDED
|
@@ -0,0 +1,846 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Executor for daggr graphs.
|
| 2 |
+
|
| 3 |
+
This module provides the AsyncExecutor for running graph nodes with proper
|
| 4 |
+
concurrency control and session isolation.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import asyncio
|
| 10 |
+
import base64
|
| 11 |
+
import hashlib
|
| 12 |
+
import uuid
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import TYPE_CHECKING, Any
|
| 15 |
+
from urllib.parse import urlparse
|
| 16 |
+
|
| 17 |
+
from gradio_client.utils import is_file_obj_with_meta, traverse
|
| 18 |
+
|
| 19 |
+
from daggr.node import (
|
| 20 |
+
ChoiceNode,
|
| 21 |
+
FnNode,
|
| 22 |
+
GradioNode,
|
| 23 |
+
InferenceNode,
|
| 24 |
+
InteractionNode,
|
| 25 |
+
)
|
| 26 |
+
from daggr.session import ExecutionSession
|
| 27 |
+
from daggr.state import get_daggr_files_dir
|
| 28 |
+
|
| 29 |
+
if TYPE_CHECKING:
|
| 30 |
+
from daggr.graph import Graph
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FileValue(str):
|
| 34 |
+
"""A string subclass that marks a value as a file URL/path from Gradio output."""
|
| 35 |
+
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _download_file(url: str, hf_token: str | None = None) -> str:
|
| 40 |
+
import httpx
|
| 41 |
+
|
| 42 |
+
parsed = urlparse(url)
|
| 43 |
+
ext = Path(parsed.path).suffix or ".bin"
|
| 44 |
+
url_hash = hashlib.md5(url.encode()).hexdigest()[:16]
|
| 45 |
+
filename = f"{url_hash}{ext}"
|
| 46 |
+
|
| 47 |
+
files_dir = get_daggr_files_dir()
|
| 48 |
+
local_path = files_dir / filename
|
| 49 |
+
|
| 50 |
+
if not local_path.exists():
|
| 51 |
+
headers = {}
|
| 52 |
+
if hf_token:
|
| 53 |
+
headers["Authorization"] = f"Bearer {hf_token}"
|
| 54 |
+
with httpx.Client(follow_redirects=True) as client:
|
| 55 |
+
response = client.get(url, headers=headers)
|
| 56 |
+
response.raise_for_status()
|
| 57 |
+
local_path.write_bytes(response.content)
|
| 58 |
+
|
| 59 |
+
return str(local_path)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _postprocess_inference_result(task: str | None, result: Any) -> Any:
|
| 63 |
+
"""Unwrap HF Inference Client result objects to get the actual data."""
|
| 64 |
+
if result is None:
|
| 65 |
+
return None
|
| 66 |
+
|
| 67 |
+
if task == "automatic-speech-recognition":
|
| 68 |
+
return getattr(result, "text", result)
|
| 69 |
+
elif task == "translation":
|
| 70 |
+
return getattr(result, "translation_text", result)
|
| 71 |
+
elif task == "summarization":
|
| 72 |
+
return getattr(result, "summary_text", result)
|
| 73 |
+
elif task in (
|
| 74 |
+
"audio-classification",
|
| 75 |
+
"image-classification",
|
| 76 |
+
"text-classification",
|
| 77 |
+
):
|
| 78 |
+
if isinstance(result, list) and result:
|
| 79 |
+
return {item.label: item.score for item in result if hasattr(item, "label")}
|
| 80 |
+
return result
|
| 81 |
+
elif task == "image-to-text":
|
| 82 |
+
return getattr(result, "generated_text", result)
|
| 83 |
+
elif task == "question-answering":
|
| 84 |
+
if hasattr(result, "answer"):
|
| 85 |
+
return result.answer
|
| 86 |
+
return result
|
| 87 |
+
elif task in ("text-to-speech", "text-to-audio"):
|
| 88 |
+
if isinstance(result, bytes):
|
| 89 |
+
file_path = get_daggr_files_dir() / f"{uuid.uuid4()}.wav"
|
| 90 |
+
file_path.write_bytes(result)
|
| 91 |
+
return str(file_path)
|
| 92 |
+
return result
|
| 93 |
+
elif task in ("text-to-image", "image-to-image"):
|
| 94 |
+
if isinstance(result, dict):
|
| 95 |
+
if "images" in result:
|
| 96 |
+
result = result["images"][0] if result["images"] else result
|
| 97 |
+
elif "image" in result:
|
| 98 |
+
result = result["image"]
|
| 99 |
+
if hasattr(result, "save"):
|
| 100 |
+
file_path = get_daggr_files_dir() / f"{uuid.uuid4()}.png"
|
| 101 |
+
result.save(file_path)
|
| 102 |
+
return str(file_path)
|
| 103 |
+
return result
|
| 104 |
+
|
| 105 |
+
return result
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def _call_inference_task(client: Any, task: str | None, inputs: dict[str, Any]) -> Any:
|
| 109 |
+
primary_input = None
|
| 110 |
+
if task in (
|
| 111 |
+
"image-to-image",
|
| 112 |
+
"image-classification",
|
| 113 |
+
"image-to-text",
|
| 114 |
+
"object-detection",
|
| 115 |
+
"image-segmentation",
|
| 116 |
+
"visual-question-answering",
|
| 117 |
+
"document-question-answering",
|
| 118 |
+
):
|
| 119 |
+
primary_input = inputs.get("image")
|
| 120 |
+
elif task in (
|
| 121 |
+
"automatic-speech-recognition",
|
| 122 |
+
"audio-classification",
|
| 123 |
+
"audio-to-audio",
|
| 124 |
+
):
|
| 125 |
+
primary_input = inputs.get("audio")
|
| 126 |
+
|
| 127 |
+
if primary_input is None:
|
| 128 |
+
primary_input = next(iter(inputs.values()), None) if inputs else None
|
| 129 |
+
|
| 130 |
+
if primary_input is None:
|
| 131 |
+
return None
|
| 132 |
+
|
| 133 |
+
task_method_map = {
|
| 134 |
+
"text-generation": "text_generation",
|
| 135 |
+
"text2text-generation": "text_generation",
|
| 136 |
+
"text-to-image": "text_to_image",
|
| 137 |
+
"image-to-image": "image_to_image",
|
| 138 |
+
"image-to-text": "image_to_text",
|
| 139 |
+
"image-to-video": "image_to_video",
|
| 140 |
+
"text-to-video": "text_to_video",
|
| 141 |
+
"text-to-speech": "text_to_speech",
|
| 142 |
+
"text-to-audio": "text_to_audio",
|
| 143 |
+
"automatic-speech-recognition": "automatic_speech_recognition",
|
| 144 |
+
"audio-to-audio": "audio_to_audio",
|
| 145 |
+
"audio-classification": "audio_classification",
|
| 146 |
+
"image-classification": "image_classification",
|
| 147 |
+
"object-detection": "object_detection",
|
| 148 |
+
"image-segmentation": "image_segmentation",
|
| 149 |
+
"translation": "translation",
|
| 150 |
+
"summarization": "summarization",
|
| 151 |
+
"feature-extraction": "feature_extraction",
|
| 152 |
+
"fill-mask": "fill_mask",
|
| 153 |
+
"question-answering": "question_answering",
|
| 154 |
+
"table-question-answering": "table_question_answering",
|
| 155 |
+
"sentence-similarity": "sentence_similarity",
|
| 156 |
+
"zero-shot-classification": "zero_shot_classification",
|
| 157 |
+
"zero-shot-image-classification": "zero_shot_image_classification",
|
| 158 |
+
"document-question-answering": "document_question_answering",
|
| 159 |
+
"visual-question-answering": "visual_question_answering",
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
method_name = (
|
| 163 |
+
task_method_map.get(task, "text_generation") if task else "text_generation"
|
| 164 |
+
)
|
| 165 |
+
method = getattr(client, method_name, None)
|
| 166 |
+
|
| 167 |
+
file_input_tasks = {
|
| 168 |
+
"image-to-image",
|
| 169 |
+
"image-classification",
|
| 170 |
+
"image-to-text",
|
| 171 |
+
"object-detection",
|
| 172 |
+
"image-segmentation",
|
| 173 |
+
"visual-question-answering",
|
| 174 |
+
"document-question-answering",
|
| 175 |
+
"automatic-speech-recognition",
|
| 176 |
+
"audio-classification",
|
| 177 |
+
"audio-to-audio",
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
if task in file_input_tasks and isinstance(primary_input, str):
|
| 181 |
+
primary_input = _read_file_as_bytes(primary_input)
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
if method is None:
|
| 185 |
+
result = client.text_generation(primary_input)
|
| 186 |
+
elif task in ("image-to-image",):
|
| 187 |
+
prompt = inputs.get("prompt", "")
|
| 188 |
+
result = method(primary_input, prompt=prompt)
|
| 189 |
+
elif task in ("visual-question-answering", "document-question-answering"):
|
| 190 |
+
question = inputs.get("question", inputs.get("prompt", ""))
|
| 191 |
+
result = method(primary_input, question=question)
|
| 192 |
+
else:
|
| 193 |
+
result = method(primary_input)
|
| 194 |
+
except KeyError as e:
|
| 195 |
+
raise RuntimeError(
|
| 196 |
+
f"Provider returned unexpected response format for task '{task}'. "
|
| 197 |
+
f"Missing key: {e}. This model may require a specific provider "
|
| 198 |
+
f"(e.g., 'model_name:fal-ai' or 'model_name:replicate')."
|
| 199 |
+
) from e
|
| 200 |
+
|
| 201 |
+
return _postprocess_inference_result(task, result)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _read_file_as_bytes(file_path: str) -> bytes:
|
| 205 |
+
"""Read a file path or data URL as bytes."""
|
| 206 |
+
if file_path.startswith("data:"):
|
| 207 |
+
try:
|
| 208 |
+
_, encoded = file_path.split(",", 1)
|
| 209 |
+
return base64.b64decode(encoded)
|
| 210 |
+
except Exception:
|
| 211 |
+
pass
|
| 212 |
+
|
| 213 |
+
path = Path(file_path)
|
| 214 |
+
if path.exists():
|
| 215 |
+
return path.read_bytes()
|
| 216 |
+
|
| 217 |
+
return file_path
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class AsyncExecutor:
|
| 221 |
+
"""Async executor for graph nodes.
|
| 222 |
+
|
| 223 |
+
This executor is stateless - all state is held in the ExecutionSession.
|
| 224 |
+
It handles concurrency control:
|
| 225 |
+
- GradioNode/InferenceNode: run concurrently (external API calls)
|
| 226 |
+
- FnNode: sequential by default, configurable via concurrent/concurrency_group
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
def __init__(self, graph: Graph):
|
| 230 |
+
self.graph = graph
|
| 231 |
+
|
| 232 |
+
def _get_client_for_gradio_node(
|
| 233 |
+
self, session: ExecutionSession, gradio_node, cache_key: str
|
| 234 |
+
):
|
| 235 |
+
from daggr import _client_cache
|
| 236 |
+
|
| 237 |
+
token_cache_key = f"{cache_key}__token_{hash(session.hf_token or '')}"
|
| 238 |
+
if token_cache_key in session.clients:
|
| 239 |
+
return session.clients[token_cache_key]
|
| 240 |
+
|
| 241 |
+
if gradio_node._run_locally:
|
| 242 |
+
from daggr.local_space import get_local_client
|
| 243 |
+
|
| 244 |
+
client = get_local_client(gradio_node)
|
| 245 |
+
if client is not None:
|
| 246 |
+
session.clients[token_cache_key] = client
|
| 247 |
+
return client
|
| 248 |
+
|
| 249 |
+
if session.hf_token:
|
| 250 |
+
from gradio_client import Client
|
| 251 |
+
|
| 252 |
+
client = Client(
|
| 253 |
+
gradio_node._src,
|
| 254 |
+
download_files=False,
|
| 255 |
+
verbose=False,
|
| 256 |
+
token=session.hf_token,
|
| 257 |
+
)
|
| 258 |
+
else:
|
| 259 |
+
client = _client_cache.get_client(gradio_node._src)
|
| 260 |
+
if client is None:
|
| 261 |
+
from gradio_client import Client
|
| 262 |
+
|
| 263 |
+
client = Client(
|
| 264 |
+
gradio_node._src,
|
| 265 |
+
download_files=False,
|
| 266 |
+
verbose=False,
|
| 267 |
+
)
|
| 268 |
+
_client_cache.set_client(gradio_node._src, client)
|
| 269 |
+
|
| 270 |
+
session.clients[token_cache_key] = client
|
| 271 |
+
return client
|
| 272 |
+
|
| 273 |
+
def _get_client(self, session: ExecutionSession, node_name: str):
|
| 274 |
+
node = self.graph.nodes[node_name]
|
| 275 |
+
|
| 276 |
+
if isinstance(node, ChoiceNode):
|
| 277 |
+
variant_idx = session.selected_variants.get(node_name, 0)
|
| 278 |
+
variant = node._variants[variant_idx]
|
| 279 |
+
if isinstance(variant, GradioNode):
|
| 280 |
+
cache_key = f"{node_name}__variant_{variant_idx}"
|
| 281 |
+
return self._get_client_for_gradio_node(session, variant, cache_key)
|
| 282 |
+
return None
|
| 283 |
+
|
| 284 |
+
if not isinstance(node, GradioNode):
|
| 285 |
+
return None
|
| 286 |
+
|
| 287 |
+
return self._get_client_for_gradio_node(session, node, node_name)
|
| 288 |
+
|
| 289 |
+
def _get_scattered_input_edges(self, node_name: str) -> list:
|
| 290 |
+
scattered = []
|
| 291 |
+
for edge in self.graph._edges:
|
| 292 |
+
if edge.target_node._name == node_name and edge.is_scattered:
|
| 293 |
+
scattered.append(edge)
|
| 294 |
+
return scattered
|
| 295 |
+
|
| 296 |
+
def _get_gathered_input_edges(self, node_name: str) -> list:
|
| 297 |
+
gathered = []
|
| 298 |
+
for edge in self.graph._edges:
|
| 299 |
+
if edge.target_node._name == node_name and edge.is_gathered:
|
| 300 |
+
gathered.append(edge)
|
| 301 |
+
return gathered
|
| 302 |
+
|
| 303 |
+
def _prepare_inputs(
|
| 304 |
+
self, session: ExecutionSession, node_name: str, skip_scattered: bool = False
|
| 305 |
+
) -> dict[str, Any]:
|
| 306 |
+
inputs = {}
|
| 307 |
+
|
| 308 |
+
for edge in self.graph._edges:
|
| 309 |
+
if edge.target_node._name == node_name:
|
| 310 |
+
if skip_scattered and edge.is_scattered:
|
| 311 |
+
continue
|
| 312 |
+
|
| 313 |
+
source_name = edge.source_node._name
|
| 314 |
+
source_output = edge.source_port
|
| 315 |
+
target_input = edge.target_port
|
| 316 |
+
|
| 317 |
+
if source_name in session.results:
|
| 318 |
+
source_result = session.results[source_name]
|
| 319 |
+
|
| 320 |
+
if (
|
| 321 |
+
edge.is_gathered
|
| 322 |
+
and isinstance(source_result, dict)
|
| 323 |
+
and "_scattered_results" in source_result
|
| 324 |
+
):
|
| 325 |
+
scattered_results = source_result["_scattered_results"]
|
| 326 |
+
extracted = []
|
| 327 |
+
for item_result in scattered_results:
|
| 328 |
+
if (
|
| 329 |
+
isinstance(item_result, dict)
|
| 330 |
+
and source_output in item_result
|
| 331 |
+
):
|
| 332 |
+
extracted.append(item_result[source_output])
|
| 333 |
+
else:
|
| 334 |
+
extracted.append(item_result)
|
| 335 |
+
inputs[target_input] = extracted
|
| 336 |
+
elif (
|
| 337 |
+
isinstance(source_result, dict)
|
| 338 |
+
and source_output in source_result
|
| 339 |
+
):
|
| 340 |
+
inputs[target_input] = source_result[source_output]
|
| 341 |
+
elif isinstance(source_result, (list, tuple)):
|
| 342 |
+
try:
|
| 343 |
+
output_idx = int(
|
| 344 |
+
source_output.replace("output_", "").replace(
|
| 345 |
+
"output", "0"
|
| 346 |
+
)
|
| 347 |
+
)
|
| 348 |
+
if 0 <= output_idx < len(source_result):
|
| 349 |
+
inputs[target_input] = source_result[output_idx]
|
| 350 |
+
except (ValueError, TypeError):
|
| 351 |
+
if len(source_result) > 0:
|
| 352 |
+
inputs[target_input] = source_result[0]
|
| 353 |
+
else:
|
| 354 |
+
inputs[target_input] = source_result
|
| 355 |
+
|
| 356 |
+
return inputs
|
| 357 |
+
|
| 358 |
+
def _execute_single_node_sync(
|
| 359 |
+
self, session: ExecutionSession, node_name: str, inputs: dict[str, Any]
|
| 360 |
+
) -> Any:
|
| 361 |
+
"""Synchronous node execution (called from thread pool for FnNode)."""
|
| 362 |
+
node = self.graph.nodes[node_name]
|
| 363 |
+
|
| 364 |
+
if isinstance(node, ChoiceNode):
|
| 365 |
+
variant_idx = session.selected_variants.get(node_name, 0)
|
| 366 |
+
variant = node._variants[variant_idx]
|
| 367 |
+
return self._execute_variant_node_sync(session, node_name, variant, inputs)
|
| 368 |
+
|
| 369 |
+
all_inputs = {}
|
| 370 |
+
for port_name, value in node._fixed_inputs.items():
|
| 371 |
+
all_inputs[port_name] = value() if callable(value) else value
|
| 372 |
+
for port_name, component in node._input_components.items():
|
| 373 |
+
if hasattr(component, "value"):
|
| 374 |
+
val = component.value
|
| 375 |
+
if is_file_obj_with_meta(val):
|
| 376 |
+
val = val["path"]
|
| 377 |
+
all_inputs[port_name] = val
|
| 378 |
+
all_inputs.update(inputs)
|
| 379 |
+
|
| 380 |
+
if isinstance(node, GradioNode):
|
| 381 |
+
client = self._get_client(session, node_name)
|
| 382 |
+
if client:
|
| 383 |
+
api_name = node._api_name or "/predict"
|
| 384 |
+
if not api_name.startswith("/"):
|
| 385 |
+
api_name = "/" + api_name
|
| 386 |
+
call_inputs = {
|
| 387 |
+
k: self._wrap_file_input(v)
|
| 388 |
+
for k, v in all_inputs.items()
|
| 389 |
+
if k in node._input_ports
|
| 390 |
+
}
|
| 391 |
+
if node._preprocess:
|
| 392 |
+
call_inputs = node._preprocess(call_inputs)
|
| 393 |
+
raw_result = client.predict(api_name=api_name, **call_inputs)
|
| 394 |
+
if node._postprocess:
|
| 395 |
+
raw_result = self._apply_postprocess(node._postprocess, raw_result)
|
| 396 |
+
result = self._map_gradio_result(
|
| 397 |
+
node, raw_result, hf_token=session.hf_token
|
| 398 |
+
)
|
| 399 |
+
else:
|
| 400 |
+
result = None
|
| 401 |
+
|
| 402 |
+
elif isinstance(node, FnNode):
|
| 403 |
+
fn_kwargs = {}
|
| 404 |
+
for port_name in node._input_ports:
|
| 405 |
+
if port_name in all_inputs:
|
| 406 |
+
fn_kwargs[port_name] = all_inputs[port_name]
|
| 407 |
+
if node._preprocess:
|
| 408 |
+
fn_kwargs = node._preprocess(fn_kwargs)
|
| 409 |
+
raw_result = node._fn(**fn_kwargs)
|
| 410 |
+
if node._postprocess:
|
| 411 |
+
raw_result = self._apply_postprocess(node._postprocess, raw_result)
|
| 412 |
+
result = self._map_fn_result(node, raw_result)
|
| 413 |
+
|
| 414 |
+
elif isinstance(node, InferenceNode):
|
| 415 |
+
from huggingface_hub import InferenceClient
|
| 416 |
+
|
| 417 |
+
if not node._task_fetched:
|
| 418 |
+
node._fetch_model_info()
|
| 419 |
+
client = InferenceClient(
|
| 420 |
+
model=node._model_name_for_hub,
|
| 421 |
+
provider=node._provider,
|
| 422 |
+
token=session.hf_token,
|
| 423 |
+
)
|
| 424 |
+
inference_inputs = {
|
| 425 |
+
k: v for k, v in all_inputs.items() if k in node._input_ports
|
| 426 |
+
}
|
| 427 |
+
if node._preprocess:
|
| 428 |
+
inference_inputs = node._preprocess(inference_inputs)
|
| 429 |
+
raw_result = _call_inference_task(client, node._task, inference_inputs)
|
| 430 |
+
if node._postprocess:
|
| 431 |
+
raw_result = self._apply_postprocess(node._postprocess, raw_result)
|
| 432 |
+
result = self._map_inference_result(node, raw_result)
|
| 433 |
+
|
| 434 |
+
elif isinstance(node, InteractionNode):
|
| 435 |
+
result = all_inputs.get(
|
| 436 |
+
"input",
|
| 437 |
+
all_inputs.get(node._input_ports[0]) if node._input_ports else None,
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
else:
|
| 441 |
+
result = None
|
| 442 |
+
|
| 443 |
+
return result
|
| 444 |
+
|
| 445 |
+
def _execute_variant_node_sync(
|
| 446 |
+
self,
|
| 447 |
+
session: ExecutionSession,
|
| 448 |
+
node_name: str,
|
| 449 |
+
variant,
|
| 450 |
+
inputs: dict[str, Any],
|
| 451 |
+
) -> Any:
|
| 452 |
+
all_inputs = {}
|
| 453 |
+
for port_name, value in variant._fixed_inputs.items():
|
| 454 |
+
all_inputs[port_name] = value() if callable(value) else value
|
| 455 |
+
for port_name, component in variant._input_components.items():
|
| 456 |
+
if hasattr(component, "value"):
|
| 457 |
+
val = component.value
|
| 458 |
+
if is_file_obj_with_meta(val):
|
| 459 |
+
val = val["path"]
|
| 460 |
+
all_inputs[port_name] = val
|
| 461 |
+
all_inputs.update(inputs)
|
| 462 |
+
|
| 463 |
+
if isinstance(variant, GradioNode):
|
| 464 |
+
client = self._get_client(session, node_name)
|
| 465 |
+
if client:
|
| 466 |
+
api_name = variant._api_name or "/predict"
|
| 467 |
+
if not api_name.startswith("/"):
|
| 468 |
+
api_name = "/" + api_name
|
| 469 |
+
call_inputs = {
|
| 470 |
+
k: self._wrap_file_input(v)
|
| 471 |
+
for k, v in all_inputs.items()
|
| 472 |
+
if k in variant._input_ports
|
| 473 |
+
}
|
| 474 |
+
if variant._preprocess:
|
| 475 |
+
call_inputs = variant._preprocess(call_inputs)
|
| 476 |
+
raw_result = client.predict(api_name=api_name, **call_inputs)
|
| 477 |
+
if variant._postprocess:
|
| 478 |
+
raw_result = self._apply_postprocess(
|
| 479 |
+
variant._postprocess, raw_result
|
| 480 |
+
)
|
| 481 |
+
result = self._map_gradio_result(
|
| 482 |
+
variant, raw_result, hf_token=session.hf_token
|
| 483 |
+
)
|
| 484 |
+
else:
|
| 485 |
+
result = None
|
| 486 |
+
|
| 487 |
+
elif isinstance(variant, FnNode):
|
| 488 |
+
fn_kwargs = {}
|
| 489 |
+
for port_name in variant._input_ports:
|
| 490 |
+
if port_name in all_inputs:
|
| 491 |
+
fn_kwargs[port_name] = all_inputs[port_name]
|
| 492 |
+
if variant._preprocess:
|
| 493 |
+
fn_kwargs = variant._preprocess(fn_kwargs)
|
| 494 |
+
raw_result = variant._fn(**fn_kwargs)
|
| 495 |
+
if variant._postprocess:
|
| 496 |
+
raw_result = self._apply_postprocess(variant._postprocess, raw_result)
|
| 497 |
+
result = self._map_fn_result(variant, raw_result)
|
| 498 |
+
|
| 499 |
+
elif isinstance(variant, InferenceNode):
|
| 500 |
+
from huggingface_hub import InferenceClient
|
| 501 |
+
|
| 502 |
+
if not variant._task_fetched:
|
| 503 |
+
variant._fetch_model_info()
|
| 504 |
+
client = InferenceClient(
|
| 505 |
+
model=variant._model_name_for_hub,
|
| 506 |
+
provider=variant._provider,
|
| 507 |
+
token=session.hf_token,
|
| 508 |
+
)
|
| 509 |
+
inference_inputs = {
|
| 510 |
+
k: v for k, v in all_inputs.items() if k in variant._input_ports
|
| 511 |
+
}
|
| 512 |
+
if variant._preprocess:
|
| 513 |
+
inference_inputs = variant._preprocess(inference_inputs)
|
| 514 |
+
raw_result = _call_inference_task(client, variant._task, inference_inputs)
|
| 515 |
+
if variant._postprocess:
|
| 516 |
+
raw_result = self._apply_postprocess(variant._postprocess, raw_result)
|
| 517 |
+
result = self._map_inference_result(variant, raw_result)
|
| 518 |
+
|
| 519 |
+
else:
|
| 520 |
+
result = None
|
| 521 |
+
|
| 522 |
+
return result
|
| 523 |
+
|
| 524 |
+
async def execute_node(
|
| 525 |
+
self,
|
| 526 |
+
session: ExecutionSession,
|
| 527 |
+
node_name: str,
|
| 528 |
+
user_inputs: dict[str, Any] | None = None,
|
| 529 |
+
) -> Any:
|
| 530 |
+
"""Execute a single node with proper concurrency control."""
|
| 531 |
+
node = self.graph.nodes[node_name]
|
| 532 |
+
scattered_edges = self._get_scattered_input_edges(node_name)
|
| 533 |
+
|
| 534 |
+
if scattered_edges:
|
| 535 |
+
result = await self._execute_scattered_node(
|
| 536 |
+
session, node_name, scattered_edges, user_inputs
|
| 537 |
+
)
|
| 538 |
+
else:
|
| 539 |
+
inputs = self._prepare_inputs(session, node_name)
|
| 540 |
+
if user_inputs:
|
| 541 |
+
if isinstance(user_inputs, dict):
|
| 542 |
+
inputs.update(user_inputs)
|
| 543 |
+
else:
|
| 544 |
+
if node._input_ports:
|
| 545 |
+
inputs[node._input_ports[0]] = user_inputs
|
| 546 |
+
else:
|
| 547 |
+
inputs["input"] = user_inputs
|
| 548 |
+
|
| 549 |
+
try:
|
| 550 |
+
if isinstance(node, (GradioNode, InferenceNode)):
|
| 551 |
+
result = await asyncio.to_thread(
|
| 552 |
+
self._execute_single_node_sync, session, node_name, inputs
|
| 553 |
+
)
|
| 554 |
+
elif isinstance(node, FnNode):
|
| 555 |
+
semaphore = await session.concurrency.get_semaphore(
|
| 556 |
+
node._concurrent,
|
| 557 |
+
node._concurrency_group,
|
| 558 |
+
node._max_concurrent,
|
| 559 |
+
)
|
| 560 |
+
if semaphore:
|
| 561 |
+
async with semaphore:
|
| 562 |
+
result = await asyncio.to_thread(
|
| 563 |
+
self._execute_single_node_sync,
|
| 564 |
+
session,
|
| 565 |
+
node_name,
|
| 566 |
+
inputs,
|
| 567 |
+
)
|
| 568 |
+
else:
|
| 569 |
+
result = await asyncio.to_thread(
|
| 570 |
+
self._execute_single_node_sync, session, node_name, inputs
|
| 571 |
+
)
|
| 572 |
+
else:
|
| 573 |
+
result = await asyncio.to_thread(
|
| 574 |
+
self._execute_single_node_sync, session, node_name, inputs
|
| 575 |
+
)
|
| 576 |
+
except Exception as e:
|
| 577 |
+
raise RuntimeError(f"Error executing node '{node_name}': {e}")
|
| 578 |
+
|
| 579 |
+
session.results[node_name] = result
|
| 580 |
+
return result
|
| 581 |
+
|
| 582 |
+
async def _execute_scattered_node(
|
| 583 |
+
self,
|
| 584 |
+
session: ExecutionSession,
|
| 585 |
+
node_name: str,
|
| 586 |
+
scattered_edges: list,
|
| 587 |
+
user_inputs: dict[str, Any] | None = None,
|
| 588 |
+
) -> dict[str, list[Any]]:
|
| 589 |
+
first_edge = scattered_edges[0]
|
| 590 |
+
source_name = first_edge.source_node._name
|
| 591 |
+
source_port = first_edge.source_port
|
| 592 |
+
|
| 593 |
+
source_result = session.results.get(source_name)
|
| 594 |
+
if source_result is None:
|
| 595 |
+
items = []
|
| 596 |
+
elif isinstance(source_result, dict) and source_port in source_result:
|
| 597 |
+
items = source_result[source_port]
|
| 598 |
+
else:
|
| 599 |
+
items = source_result
|
| 600 |
+
|
| 601 |
+
if not isinstance(items, list):
|
| 602 |
+
items = [items]
|
| 603 |
+
|
| 604 |
+
context_inputs = self._prepare_inputs(session, node_name, skip_scattered=True)
|
| 605 |
+
if user_inputs:
|
| 606 |
+
context_inputs.update(user_inputs)
|
| 607 |
+
|
| 608 |
+
node = self.graph.nodes[node_name]
|
| 609 |
+
|
| 610 |
+
async def execute_item(item, idx):
|
| 611 |
+
item_inputs = dict(context_inputs)
|
| 612 |
+
for edge in scattered_edges:
|
| 613 |
+
target_port = edge.target_port
|
| 614 |
+
item_key = edge.item_key
|
| 615 |
+
if item_key and isinstance(item, dict):
|
| 616 |
+
item_inputs[target_port] = item.get(item_key)
|
| 617 |
+
else:
|
| 618 |
+
item_inputs[target_port] = item
|
| 619 |
+
|
| 620 |
+
try:
|
| 621 |
+
if isinstance(node, (GradioNode, InferenceNode)):
|
| 622 |
+
return await asyncio.to_thread(
|
| 623 |
+
self._execute_single_node_sync, session, node_name, item_inputs
|
| 624 |
+
)
|
| 625 |
+
elif isinstance(node, FnNode):
|
| 626 |
+
semaphore = await session.concurrency.get_semaphore(
|
| 627 |
+
node._concurrent,
|
| 628 |
+
node._concurrency_group,
|
| 629 |
+
node._max_concurrent,
|
| 630 |
+
)
|
| 631 |
+
if semaphore:
|
| 632 |
+
async with semaphore:
|
| 633 |
+
return await asyncio.to_thread(
|
| 634 |
+
self._execute_single_node_sync,
|
| 635 |
+
session,
|
| 636 |
+
node_name,
|
| 637 |
+
item_inputs,
|
| 638 |
+
)
|
| 639 |
+
else:
|
| 640 |
+
return await asyncio.to_thread(
|
| 641 |
+
self._execute_single_node_sync,
|
| 642 |
+
session,
|
| 643 |
+
node_name,
|
| 644 |
+
item_inputs,
|
| 645 |
+
)
|
| 646 |
+
else:
|
| 647 |
+
return await asyncio.to_thread(
|
| 648 |
+
self._execute_single_node_sync, session, node_name, item_inputs
|
| 649 |
+
)
|
| 650 |
+
except Exception as e:
|
| 651 |
+
return {"error": str(e)}
|
| 652 |
+
|
| 653 |
+
if isinstance(node, (GradioNode, InferenceNode)):
|
| 654 |
+
tasks = [execute_item(item, i) for i, item in enumerate(items)]
|
| 655 |
+
results = await asyncio.gather(*tasks)
|
| 656 |
+
else:
|
| 657 |
+
results = []
|
| 658 |
+
for i, item in enumerate(items):
|
| 659 |
+
result = await execute_item(item, i)
|
| 660 |
+
results.append(result)
|
| 661 |
+
|
| 662 |
+
session.scattered_results[node_name] = list(results)
|
| 663 |
+
return {"_scattered_results": list(results), "_items": items}
|
| 664 |
+
|
| 665 |
+
def _wrap_file_input(self, value: Any) -> Any:
|
| 666 |
+
from gradio_client import handle_file
|
| 667 |
+
|
| 668 |
+
if isinstance(value, FileValue):
|
| 669 |
+
return handle_file(str(value))
|
| 670 |
+
|
| 671 |
+
if isinstance(value, str):
|
| 672 |
+
if value.startswith("data:"):
|
| 673 |
+
file_path = self._save_data_url_to_file(value)
|
| 674 |
+
if file_path:
|
| 675 |
+
return handle_file(file_path)
|
| 676 |
+
elif Path(value).exists():
|
| 677 |
+
return handle_file(value)
|
| 678 |
+
|
| 679 |
+
return value
|
| 680 |
+
|
| 681 |
+
def _save_data_url_to_file(self, data_url: str) -> str | None:
|
| 682 |
+
"""Convert a base64 data URL to a file and return the path."""
|
| 683 |
+
if not data_url.startswith("data:"):
|
| 684 |
+
return None
|
| 685 |
+
|
| 686 |
+
try:
|
| 687 |
+
header, encoded = data_url.split(",", 1)
|
| 688 |
+
media_type = header.split(":")[1].split(";")[0]
|
| 689 |
+
ext_map = {
|
| 690 |
+
"image/png": ".png",
|
| 691 |
+
"image/jpeg": ".jpg",
|
| 692 |
+
"image/jpg": ".jpg",
|
| 693 |
+
"image/gif": ".gif",
|
| 694 |
+
"image/webp": ".webp",
|
| 695 |
+
"audio/wav": ".wav",
|
| 696 |
+
"audio/mpeg": ".mp3",
|
| 697 |
+
"audio/mp3": ".mp3",
|
| 698 |
+
"audio/ogg": ".ogg",
|
| 699 |
+
"audio/webm": ".webm",
|
| 700 |
+
"video/mp4": ".mp4",
|
| 701 |
+
"video/webm": ".webm",
|
| 702 |
+
}
|
| 703 |
+
ext = ext_map.get(media_type, ".bin")
|
| 704 |
+
data = base64.b64decode(encoded)
|
| 705 |
+
file_path = get_daggr_files_dir() / f"{uuid.uuid4()}{ext}"
|
| 706 |
+
file_path.write_bytes(data)
|
| 707 |
+
return str(file_path)
|
| 708 |
+
except Exception:
|
| 709 |
+
return None
|
| 710 |
+
|
| 711 |
+
def _apply_postprocess(self, postprocess, raw_result: Any) -> Any:
|
| 712 |
+
if isinstance(raw_result, (list, tuple)):
|
| 713 |
+
return postprocess(*raw_result)
|
| 714 |
+
return postprocess(raw_result)
|
| 715 |
+
|
| 716 |
+
def _extract_file_urls(self, data: Any, hf_token: str | None = None) -> Any:
|
| 717 |
+
def download_and_wrap(file_obj: dict) -> FileValue:
|
| 718 |
+
url = file_obj.get("url")
|
| 719 |
+
if url:
|
| 720 |
+
local_path = _download_file(url, hf_token=hf_token)
|
| 721 |
+
return FileValue(local_path)
|
| 722 |
+
path = file_obj.get("path", "")
|
| 723 |
+
return FileValue(path)
|
| 724 |
+
|
| 725 |
+
return traverse(data, download_and_wrap, is_file_obj_with_meta)
|
| 726 |
+
|
| 727 |
+
def _map_gradio_result(
|
| 728 |
+
self, node, raw_result: Any, hf_token: str | None = None
|
| 729 |
+
) -> dict[str, Any]:
|
| 730 |
+
if raw_result is None:
|
| 731 |
+
return {}
|
| 732 |
+
|
| 733 |
+
raw_result = self._extract_file_urls(raw_result, hf_token=hf_token)
|
| 734 |
+
|
| 735 |
+
output_ports = node._output_ports
|
| 736 |
+
if not output_ports:
|
| 737 |
+
return {"output": raw_result}
|
| 738 |
+
|
| 739 |
+
if isinstance(raw_result, (list, tuple)):
|
| 740 |
+
result = {}
|
| 741 |
+
for i, port_name in enumerate(output_ports):
|
| 742 |
+
if i < len(raw_result):
|
| 743 |
+
result[port_name] = raw_result[i]
|
| 744 |
+
else:
|
| 745 |
+
result[port_name] = None
|
| 746 |
+
return result
|
| 747 |
+
elif len(output_ports) == 1:
|
| 748 |
+
return {output_ports[0]: raw_result}
|
| 749 |
+
else:
|
| 750 |
+
return {output_ports[0]: raw_result}
|
| 751 |
+
|
| 752 |
+
def _map_fn_result(self, node, raw_result: Any) -> dict[str, Any]:
|
| 753 |
+
if raw_result is None:
|
| 754 |
+
return {}
|
| 755 |
+
|
| 756 |
+
output_ports = node._output_ports
|
| 757 |
+
if not output_ports:
|
| 758 |
+
return {"output": raw_result}
|
| 759 |
+
|
| 760 |
+
if isinstance(raw_result, tuple):
|
| 761 |
+
result = {}
|
| 762 |
+
for i, port_name in enumerate(output_ports):
|
| 763 |
+
if i < len(raw_result):
|
| 764 |
+
result[port_name] = raw_result[i]
|
| 765 |
+
else:
|
| 766 |
+
result[port_name] = None
|
| 767 |
+
return result
|
| 768 |
+
else:
|
| 769 |
+
return {output_ports[0]: raw_result}
|
| 770 |
+
|
| 771 |
+
def _map_inference_result(self, node, raw_result: Any) -> dict[str, Any]:
|
| 772 |
+
"""Map inference API result to output ports."""
|
| 773 |
+
if raw_result is None:
|
| 774 |
+
return {}
|
| 775 |
+
|
| 776 |
+
output_ports = node._output_ports
|
| 777 |
+
if not output_ports:
|
| 778 |
+
return {"output": raw_result}
|
| 779 |
+
|
| 780 |
+
return {output_ports[0]: raw_result}
|
| 781 |
+
|
| 782 |
+
async def execute_all(
|
| 783 |
+
self, session: ExecutionSession, entry_inputs: dict[str, dict[str, Any]]
|
| 784 |
+
) -> dict[str, Any]:
|
| 785 |
+
execution_order = self.graph.get_execution_order()
|
| 786 |
+
session.results = {}
|
| 787 |
+
|
| 788 |
+
for node_name in execution_order:
|
| 789 |
+
user_input = entry_inputs.get(node_name, {})
|
| 790 |
+
await self.execute_node(session, node_name, user_input)
|
| 791 |
+
|
| 792 |
+
return session.results
|
| 793 |
+
|
| 794 |
+
|
| 795 |
+
class SequentialExecutor:
|
| 796 |
+
"""Legacy synchronous executor for backwards compatibility.
|
| 797 |
+
|
| 798 |
+
This wraps the AsyncExecutor for use in synchronous contexts like node.test().
|
| 799 |
+
For production use, prefer AsyncExecutor with proper session management.
|
| 800 |
+
"""
|
| 801 |
+
|
| 802 |
+
def __init__(self, graph: Graph, hf_token: str | None = None):
|
| 803 |
+
self.graph = graph
|
| 804 |
+
self._async_executor = AsyncExecutor(graph)
|
| 805 |
+
self._session = ExecutionSession(graph, hf_token)
|
| 806 |
+
|
| 807 |
+
@property
|
| 808 |
+
def results(self) -> dict[str, Any]:
|
| 809 |
+
return self._session.results
|
| 810 |
+
|
| 811 |
+
@results.setter
|
| 812 |
+
def results(self, value: dict[str, Any]):
|
| 813 |
+
self._session.results = value
|
| 814 |
+
|
| 815 |
+
@property
|
| 816 |
+
def selected_variants(self) -> dict[str, int]:
|
| 817 |
+
return self._session.selected_variants
|
| 818 |
+
|
| 819 |
+
@selected_variants.setter
|
| 820 |
+
def selected_variants(self, value: dict[str, int]):
|
| 821 |
+
self._session.selected_variants = value
|
| 822 |
+
|
| 823 |
+
def set_hf_token(self, token: str | None):
|
| 824 |
+
self._session.set_hf_token(token)
|
| 825 |
+
|
| 826 |
+
def execute_node(
|
| 827 |
+
self, node_name: str, user_inputs: dict[str, Any] | None = None
|
| 828 |
+
) -> Any:
|
| 829 |
+
"""Synchronous wrapper around async execute_node."""
|
| 830 |
+
loop = asyncio.new_event_loop()
|
| 831 |
+
try:
|
| 832 |
+
return loop.run_until_complete(
|
| 833 |
+
self._async_executor.execute_node(self._session, node_name, user_inputs)
|
| 834 |
+
)
|
| 835 |
+
finally:
|
| 836 |
+
loop.close()
|
| 837 |
+
|
| 838 |
+
def execute_all(self, entry_inputs: dict[str, dict[str, Any]]) -> dict[str, Any]:
|
| 839 |
+
"""Synchronous wrapper around async execute_all."""
|
| 840 |
+
loop = asyncio.new_event_loop()
|
| 841 |
+
try:
|
| 842 |
+
return loop.run_until_complete(
|
| 843 |
+
self._async_executor.execute_all(self._session, entry_inputs)
|
| 844 |
+
)
|
| 845 |
+
finally:
|
| 846 |
+
loop.close()
|
daggr/graph.py
ADDED
|
@@ -0,0 +1,767 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graph module for daggr.
|
| 2 |
+
|
| 3 |
+
A Graph represents a directed acyclic graph (DAG) of nodes that can be
|
| 4 |
+
executed to process data through a pipeline.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import itertools
|
| 10 |
+
import os
|
| 11 |
+
import re
|
| 12 |
+
import sys
|
| 13 |
+
import threading
|
| 14 |
+
from collections.abc import Sequence
|
| 15 |
+
from typing import TYPE_CHECKING, Any
|
| 16 |
+
|
| 17 |
+
import networkx as nx
|
| 18 |
+
|
| 19 |
+
from daggr._utils import suggest_similar
|
| 20 |
+
from daggr.edge import Edge
|
| 21 |
+
from daggr.local_space import prepare_local_node
|
| 22 |
+
from daggr.node import ChoiceNode, GradioNode, InferenceNode, Node
|
| 23 |
+
from daggr.port import Port
|
| 24 |
+
|
| 25 |
+
if TYPE_CHECKING:
|
| 26 |
+
from gradio.themes import ThemeClass as Theme
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _parse_space_id(src: str) -> str | None:
|
| 30 |
+
if src.startswith("http://") or src.startswith("https://"):
|
| 31 |
+
match = re.match(r"https?://huggingface\.co/spaces/([^/]+/[^/?#]+)", src)
|
| 32 |
+
if match:
|
| 33 |
+
return match.group(1)
|
| 34 |
+
return None
|
| 35 |
+
if "/" in src:
|
| 36 |
+
return src
|
| 37 |
+
return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _get_dependency_id(node) -> tuple[str | None, str]:
|
| 41 |
+
if isinstance(node, GradioNode):
|
| 42 |
+
space_id = _parse_space_id(node._src)
|
| 43 |
+
return space_id, "space"
|
| 44 |
+
elif isinstance(node, InferenceNode):
|
| 45 |
+
return node._model_name_for_hub, "model"
|
| 46 |
+
return None, ""
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _fetch_current_sha(dep_id: str, dep_type: str) -> str | None:
|
| 50 |
+
try:
|
| 51 |
+
if dep_type == "space":
|
| 52 |
+
from huggingface_hub import space_info
|
| 53 |
+
|
| 54 |
+
info = space_info(dep_id)
|
| 55 |
+
return info.sha
|
| 56 |
+
elif dep_type == "model":
|
| 57 |
+
from huggingface_hub import model_info
|
| 58 |
+
|
| 59 |
+
info = model_info(dep_id)
|
| 60 |
+
return info.sha
|
| 61 |
+
except Exception:
|
| 62 |
+
return None
|
| 63 |
+
return None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _duplicate_space_at_revision(
|
| 67 |
+
space_id: str, revision: str, username: str
|
| 68 |
+
) -> str | None:
|
| 69 |
+
try:
|
| 70 |
+
from huggingface_hub import (
|
| 71 |
+
create_repo,
|
| 72 |
+
snapshot_download,
|
| 73 |
+
upload_folder,
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
space_name = space_id.split("/")[-1]
|
| 77 |
+
new_repo_id = f"{username}/{space_name}"
|
| 78 |
+
|
| 79 |
+
local_dir = snapshot_download(
|
| 80 |
+
repo_id=space_id,
|
| 81 |
+
repo_type="space",
|
| 82 |
+
revision=revision,
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
create_repo(
|
| 86 |
+
repo_id=new_repo_id,
|
| 87 |
+
repo_type="space",
|
| 88 |
+
space_sdk="gradio",
|
| 89 |
+
exist_ok=True,
|
| 90 |
+
)
|
| 91 |
+
|
| 92 |
+
upload_folder(
|
| 93 |
+
repo_id=new_repo_id,
|
| 94 |
+
repo_type="space",
|
| 95 |
+
folder_path=local_dir,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
return new_repo_id
|
| 99 |
+
except Exception as e:
|
| 100 |
+
print(f" [daggr] Failed to duplicate Space: {e}")
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _prompt_dependency_changes(changed: list[dict]) -> None:
|
| 105 |
+
from daggr import _client_cache
|
| 106 |
+
|
| 107 |
+
is_tty = hasattr(sys.stdin, "isatty") and sys.stdin.isatty()
|
| 108 |
+
|
| 109 |
+
print("\n ⚠️ Upstream dependency changes detected:\n")
|
| 110 |
+
for item in changed:
|
| 111 |
+
print(
|
| 112 |
+
f" • {item['type']} '{item['id']}' (node: {item['node']._name})\n"
|
| 113 |
+
f" cached: {item['cached_sha'][:12]}\n"
|
| 114 |
+
f" current: {item['current_sha'][:12]}"
|
| 115 |
+
)
|
| 116 |
+
print()
|
| 117 |
+
|
| 118 |
+
if not is_tty:
|
| 119 |
+
for item in changed:
|
| 120 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 121 |
+
print(
|
| 122 |
+
" [daggr] Non-interactive mode: auto-updated all hashes.\n"
|
| 123 |
+
" Set DAGGR_DEPENDENCY_CHECK=skip to suppress this warning.\n"
|
| 124 |
+
)
|
| 125 |
+
return
|
| 126 |
+
|
| 127 |
+
for item in changed:
|
| 128 |
+
is_space = item["type"] == "space"
|
| 129 |
+
if is_space:
|
| 130 |
+
print(
|
| 131 |
+
f" How would you like to handle '{item['id']}'?\n"
|
| 132 |
+
f" [1] Duplicate the original version under your namespace (safer)\n"
|
| 133 |
+
f" [2] Update to the latest version"
|
| 134 |
+
)
|
| 135 |
+
else:
|
| 136 |
+
print(
|
| 137 |
+
f" How would you like to handle '{item['id']}'?\n"
|
| 138 |
+
f" [1] Update to the latest version"
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
try:
|
| 142 |
+
choice = input(" Choice [1]: ").strip() or "1"
|
| 143 |
+
except (EOFError, KeyboardInterrupt):
|
| 144 |
+
choice = "1"
|
| 145 |
+
|
| 146 |
+
if is_space and choice == "1":
|
| 147 |
+
username = _get_hf_username()
|
| 148 |
+
if username is None:
|
| 149 |
+
print(
|
| 150 |
+
" [daggr] Not logged in to Hugging Face. "
|
| 151 |
+
"Updating hash instead.\n"
|
| 152 |
+
" Run `huggingface-cli login` to enable Space duplication."
|
| 153 |
+
)
|
| 154 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 155 |
+
else:
|
| 156 |
+
print(
|
| 157 |
+
f" [daggr] Duplicating '{item['id']}' at revision "
|
| 158 |
+
f"{item['cached_sha'][:12]} under {username}/..."
|
| 159 |
+
)
|
| 160 |
+
new_id = _duplicate_space_at_revision(
|
| 161 |
+
item["id"], item["cached_sha"], username
|
| 162 |
+
)
|
| 163 |
+
if new_id:
|
| 164 |
+
item["node"]._src = new_id
|
| 165 |
+
_client_cache.set_dependency_hash(new_id, item["cached_sha"])
|
| 166 |
+
print(
|
| 167 |
+
f" [daggr] Duplicated → '{new_id}'. "
|
| 168 |
+
f"Node now points to duplicated Space."
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
print(
|
| 172 |
+
" [daggr] Duplication failed (revision may have been "
|
| 173 |
+
"squashed). Updating hash instead."
|
| 174 |
+
)
|
| 175 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 176 |
+
else:
|
| 177 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 178 |
+
print(f" [daggr] Updated hash for '{item['id']}'.")
|
| 179 |
+
|
| 180 |
+
print()
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _get_hf_username() -> str | None:
|
| 184 |
+
try:
|
| 185 |
+
from huggingface_hub import get_token, whoami
|
| 186 |
+
|
| 187 |
+
token = get_token()
|
| 188 |
+
if not token:
|
| 189 |
+
return None
|
| 190 |
+
info = whoami(cache=True)
|
| 191 |
+
return info.get("name")
|
| 192 |
+
except Exception:
|
| 193 |
+
return None
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
class _Spinner:
|
| 197 |
+
_CHARS = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
|
| 198 |
+
|
| 199 |
+
def __init__(self, message: str):
|
| 200 |
+
self._message = message
|
| 201 |
+
self._is_tty = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
|
| 202 |
+
if self._is_tty:
|
| 203 |
+
self._stop = threading.Event()
|
| 204 |
+
self._thread = threading.Thread(target=self._spin, daemon=True)
|
| 205 |
+
self._thread.start()
|
| 206 |
+
|
| 207 |
+
def _spin(self):
|
| 208 |
+
frames = itertools.cycle(self._CHARS)
|
| 209 |
+
while not self._stop.is_set():
|
| 210 |
+
sys.stdout.write(f"\r {next(frames)} {self._message}")
|
| 211 |
+
sys.stdout.flush()
|
| 212 |
+
self._stop.wait(0.08)
|
| 213 |
+
|
| 214 |
+
def _finish(self, symbol: str, suffix: str = ""):
|
| 215 |
+
line = f" {symbol} {self._message}"
|
| 216 |
+
if suffix:
|
| 217 |
+
line += f" — {suffix}"
|
| 218 |
+
if self._is_tty:
|
| 219 |
+
self._stop.set()
|
| 220 |
+
self._thread.join()
|
| 221 |
+
sys.stdout.write(f"\r{line}\033[K\n")
|
| 222 |
+
else:
|
| 223 |
+
sys.stdout.write(f"{line}\n")
|
| 224 |
+
sys.stdout.flush()
|
| 225 |
+
|
| 226 |
+
def succeed(self, suffix: str = ""):
|
| 227 |
+
self._finish("✓", suffix)
|
| 228 |
+
|
| 229 |
+
def warn(self, suffix: str = ""):
|
| 230 |
+
self._finish("⚠", suffix)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _get_node_display_label(node) -> str:
|
| 234 |
+
if isinstance(node, GradioNode):
|
| 235 |
+
label = node._src
|
| 236 |
+
if node._api_name:
|
| 237 |
+
label += f" ({node._api_name})"
|
| 238 |
+
return label
|
| 239 |
+
elif isinstance(node, InferenceNode):
|
| 240 |
+
return node._model_name_for_hub
|
| 241 |
+
return node._name
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class Graph:
|
| 245 |
+
"""A directed acyclic graph (DAG) of nodes for data processing.
|
| 246 |
+
|
| 247 |
+
A Graph connects nodes together to form a pipeline. Data flows from entry
|
| 248 |
+
nodes (nodes with no inputs) through the graph to output nodes.
|
| 249 |
+
|
| 250 |
+
Example:
|
| 251 |
+
>>> from daggr import Graph, FnNode
|
| 252 |
+
>>> def step1(x): return {"out": x * 2}
|
| 253 |
+
>>> def step2(y): return {"out": y + 1}
|
| 254 |
+
>>> n1 = FnNode(step1)
|
| 255 |
+
>>> n2 = FnNode(step2, inputs={"y": n1.out})
|
| 256 |
+
>>> graph = Graph("My Pipeline", nodes=[n2])
|
| 257 |
+
>>> graph.launch()
|
| 258 |
+
"""
|
| 259 |
+
|
| 260 |
+
def __init__(
|
| 261 |
+
self,
|
| 262 |
+
name: str,
|
| 263 |
+
nodes: Sequence[Node] | None = None,
|
| 264 |
+
persist_key: str | bool | None = None,
|
| 265 |
+
):
|
| 266 |
+
"""Create a new Graph.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
name: Display name for this graph shown in the UI.
|
| 270 |
+
nodes: Optional list of nodes to add to the graph.
|
| 271 |
+
persist_key: Unique key used to store this graph's data in the database.
|
| 272 |
+
If not provided, derived from name by converting to lowercase
|
| 273 |
+
and replacing spaces/special chars with underscores.
|
| 274 |
+
Set to False to disable persistence entirely.
|
| 275 |
+
Use a custom string to ensure persistence works correctly
|
| 276 |
+
if you change the display name later.
|
| 277 |
+
"""
|
| 278 |
+
if not name or not isinstance(name, str):
|
| 279 |
+
raise ValueError(
|
| 280 |
+
"Graph requires a 'name' parameter. "
|
| 281 |
+
"Example: Graph(name='My Podcast Generator', nodes=[...])"
|
| 282 |
+
)
|
| 283 |
+
self.name = name
|
| 284 |
+
if persist_key is False:
|
| 285 |
+
self.persist_key = None
|
| 286 |
+
elif persist_key:
|
| 287 |
+
self.persist_key = persist_key
|
| 288 |
+
else:
|
| 289 |
+
self.persist_key = re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")
|
| 290 |
+
self.nodes: dict[str, Node] = {}
|
| 291 |
+
self._nx_graph = nx.DiGraph()
|
| 292 |
+
self._edges: list[Edge] = []
|
| 293 |
+
|
| 294 |
+
if nodes:
|
| 295 |
+
for node in nodes:
|
| 296 |
+
self.add(node)
|
| 297 |
+
|
| 298 |
+
def add(self, node: Node) -> Graph:
|
| 299 |
+
"""Add a node to the graph.
|
| 300 |
+
|
| 301 |
+
Also adds any upstream nodes connected via the node's port connections.
|
| 302 |
+
|
| 303 |
+
Args:
|
| 304 |
+
node: The node to add.
|
| 305 |
+
|
| 306 |
+
Returns:
|
| 307 |
+
self, for method chaining.
|
| 308 |
+
"""
|
| 309 |
+
self._add_node(node)
|
| 310 |
+
self._create_edges_from_port_connections(node)
|
| 311 |
+
return self
|
| 312 |
+
|
| 313 |
+
def edge(self, source: Port, target: Port) -> Graph:
|
| 314 |
+
"""Create an edge connecting two ports.
|
| 315 |
+
|
| 316 |
+
Args:
|
| 317 |
+
source: The source port (output of a node).
|
| 318 |
+
target: The target port (input of a node).
|
| 319 |
+
|
| 320 |
+
Returns:
|
| 321 |
+
self, for method chaining.
|
| 322 |
+
|
| 323 |
+
Raises:
|
| 324 |
+
ValueError: If the edge would create a cycle.
|
| 325 |
+
"""
|
| 326 |
+
edge = Edge(source, target)
|
| 327 |
+
self._add_edge(edge)
|
| 328 |
+
return self
|
| 329 |
+
|
| 330 |
+
def _add_node(self, node: Node) -> None:
|
| 331 |
+
if node._name in self.nodes:
|
| 332 |
+
if self.nodes[node._name] is not node:
|
| 333 |
+
raise ValueError(f"Node with name '{node._name}' already exists")
|
| 334 |
+
return
|
| 335 |
+
self.nodes[node._name] = node
|
| 336 |
+
self._nx_graph.add_node(node._name)
|
| 337 |
+
|
| 338 |
+
def _create_edges_from_port_connections(self, node: Node) -> None:
|
| 339 |
+
for target_port_name, source_port in node._port_connections.items():
|
| 340 |
+
source_node = source_port.node
|
| 341 |
+
source_port_name = source_port.name
|
| 342 |
+
|
| 343 |
+
if source_port_name not in source_node._output_ports:
|
| 344 |
+
available = set(source_node._output_ports)
|
| 345 |
+
suggestion = suggest_similar(source_port_name, available)
|
| 346 |
+
available_str = ", ".join(available) or "(none)"
|
| 347 |
+
msg = (
|
| 348 |
+
f"Output port '{source_port_name}' not found on node "
|
| 349 |
+
f"'{source_node._name}'. Available outputs: {available_str}"
|
| 350 |
+
)
|
| 351 |
+
if suggestion:
|
| 352 |
+
msg += f" Did you mean '{suggestion}'?"
|
| 353 |
+
raise ValueError(msg)
|
| 354 |
+
|
| 355 |
+
is_new_node = source_node._name not in self.nodes
|
| 356 |
+
self._add_node(source_node)
|
| 357 |
+
if is_new_node:
|
| 358 |
+
self._create_edges_from_port_connections(source_node)
|
| 359 |
+
target_port = Port(node, target_port_name)
|
| 360 |
+
edge = Edge(source_port, target_port)
|
| 361 |
+
self._add_edge(edge)
|
| 362 |
+
|
| 363 |
+
def _add_edge(self, edge: Edge) -> None:
|
| 364 |
+
self._add_node(edge.source_node)
|
| 365 |
+
self._add_node(edge.target_node)
|
| 366 |
+
|
| 367 |
+
self._edges.append(edge)
|
| 368 |
+
self._nx_graph.add_edge(edge.source_node._name, edge.target_node._name)
|
| 369 |
+
|
| 370 |
+
if not nx.is_directed_acyclic_graph(self._nx_graph):
|
| 371 |
+
self._nx_graph.remove_edge(edge.source_node._name, edge.target_node._name)
|
| 372 |
+
self._edges.pop()
|
| 373 |
+
raise ValueError("Connection would create a cycle in the DAG")
|
| 374 |
+
|
| 375 |
+
def get_entry_nodes(self) -> list[Node]:
|
| 376 |
+
"""Get all nodes with no incoming edges (entry points of the graph)."""
|
| 377 |
+
entry_nodes = []
|
| 378 |
+
for node_name in self.nodes:
|
| 379 |
+
if self._nx_graph.in_degree(node_name) == 0:
|
| 380 |
+
entry_nodes.append(self.nodes[node_name])
|
| 381 |
+
return entry_nodes
|
| 382 |
+
|
| 383 |
+
def get_execution_order(self) -> list[str]:
|
| 384 |
+
"""Get the topologically sorted order of node names for execution."""
|
| 385 |
+
return list(nx.topological_sort(self._nx_graph))
|
| 386 |
+
|
| 387 |
+
def get_connections(self) -> list[tuple]:
|
| 388 |
+
"""Get all edges as tuples of (source_node, source_port, target_node, target_port)."""
|
| 389 |
+
return [edge.as_tuple() for edge in self._edges]
|
| 390 |
+
|
| 391 |
+
def _validate_edges(self) -> None:
|
| 392 |
+
errors = []
|
| 393 |
+
for edge in self._edges:
|
| 394 |
+
source_node = edge.source_node
|
| 395 |
+
target_node = edge.target_node
|
| 396 |
+
source_port = edge.source_port
|
| 397 |
+
target_port = edge.target_port
|
| 398 |
+
|
| 399 |
+
if source_port not in source_node._output_ports:
|
| 400 |
+
available = set(source_node._output_ports)
|
| 401 |
+
available_str = ", ".join(available) or "(none)"
|
| 402 |
+
suggestion = suggest_similar(source_port, available)
|
| 403 |
+
msg = (
|
| 404 |
+
f"Output port '{source_port}' not found on node "
|
| 405 |
+
f"'{source_node._name}'. Available outputs: {available_str}"
|
| 406 |
+
)
|
| 407 |
+
if suggestion:
|
| 408 |
+
msg += f" Did you mean '{suggestion}'?"
|
| 409 |
+
errors.append(msg)
|
| 410 |
+
|
| 411 |
+
if target_port not in target_node._input_ports:
|
| 412 |
+
available = set(target_node._input_ports)
|
| 413 |
+
available_str = ", ".join(available) or "(none)"
|
| 414 |
+
suggestion = suggest_similar(target_port, available)
|
| 415 |
+
msg = (
|
| 416 |
+
f"Input port '{target_port}' not found on node "
|
| 417 |
+
f"'{target_node._name}'. Available inputs: {available_str}"
|
| 418 |
+
)
|
| 419 |
+
if suggestion:
|
| 420 |
+
msg += f" Did you mean '{suggestion}'?"
|
| 421 |
+
errors.append(msg)
|
| 422 |
+
|
| 423 |
+
if errors:
|
| 424 |
+
raise ValueError("Invalid port connections:\n - " + "\n - ".join(errors))
|
| 425 |
+
|
| 426 |
+
def launch(
|
| 427 |
+
self,
|
| 428 |
+
host: str | None = None,
|
| 429 |
+
port: int | None = None,
|
| 430 |
+
share: bool | None = None,
|
| 431 |
+
open_browser: bool = True,
|
| 432 |
+
theme: Theme | str | None = None,
|
| 433 |
+
api_server: bool = True,
|
| 434 |
+
**kwargs,
|
| 435 |
+
):
|
| 436 |
+
"""Launch the graph as an interactive web application.
|
| 437 |
+
|
| 438 |
+
Starts a web server that displays the graph and allows users to
|
| 439 |
+
execute nodes and view results.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
host: Host to bind to. Defaults to GRADIO_SERVER_NAME env var,
|
| 443 |
+
or "127.0.0.1" if not set. Set to "0.0.0.0" to make
|
| 444 |
+
accessible on a network or when deploying to Hugging Face Spaces.
|
| 445 |
+
port: Port to bind to. Defaults to GRADIO_SERVER_PORT env var,
|
| 446 |
+
or 7860 if not set.
|
| 447 |
+
share: If True, create a public share link. Defaults to True in
|
| 448 |
+
Colab/Kaggle environments, False otherwise.
|
| 449 |
+
open_browser: If True, automatically open the app in the default
|
| 450 |
+
web browser. Defaults to True.
|
| 451 |
+
theme: A Gradio theme to use for styling. Can be a Gradio `Theme` instance,
|
| 452 |
+
a string name like "default", "soft", "monochrome", "glass",
|
| 453 |
+
or a Hub theme like "gradio/seafoam". Defaults to the Gradio
|
| 454 |
+
default theme.
|
| 455 |
+
api_server: If True, expose the programmatic API endpoints
|
| 456 |
+
(/api/call, /api/schema). Defaults to True.
|
| 457 |
+
**kwargs: Additional arguments passed to uvicorn.
|
| 458 |
+
"""
|
| 459 |
+
from daggr.server import DaggrServer
|
| 460 |
+
|
| 461 |
+
if host is None:
|
| 462 |
+
host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
|
| 463 |
+
if port is None:
|
| 464 |
+
port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
| 465 |
+
|
| 466 |
+
self._startup_display()
|
| 467 |
+
server = DaggrServer(self, theme=theme, api_server=api_server)
|
| 468 |
+
server.run(
|
| 469 |
+
host=host, port=port, share=share, open_browser=open_browser, **kwargs
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def _prepare_local_nodes(self) -> None:
|
| 473 |
+
for node in self.nodes.values():
|
| 474 |
+
if isinstance(node, ChoiceNode):
|
| 475 |
+
for variant in node._variants:
|
| 476 |
+
if isinstance(variant, GradioNode) and variant._run_locally:
|
| 477 |
+
prepare_local_node(variant)
|
| 478 |
+
elif isinstance(node, GradioNode) and node._run_locally:
|
| 479 |
+
prepare_local_node(node)
|
| 480 |
+
|
| 481 |
+
def _check_dependency_hashes(self) -> None:
|
| 482 |
+
mode = os.environ.get("DAGGR_DEPENDENCY_CHECK", "").lower()
|
| 483 |
+
if mode == "skip":
|
| 484 |
+
return
|
| 485 |
+
|
| 486 |
+
from daggr import _client_cache
|
| 487 |
+
|
| 488 |
+
nodes_to_check: list[GradioNode | InferenceNode] = []
|
| 489 |
+
for node in self.nodes.values():
|
| 490 |
+
if isinstance(node, ChoiceNode):
|
| 491 |
+
for variant in node._variants:
|
| 492 |
+
if isinstance(variant, (GradioNode, InferenceNode)):
|
| 493 |
+
nodes_to_check.append(variant)
|
| 494 |
+
elif isinstance(node, (GradioNode, InferenceNode)):
|
| 495 |
+
nodes_to_check.append(node)
|
| 496 |
+
|
| 497 |
+
if not nodes_to_check:
|
| 498 |
+
return
|
| 499 |
+
|
| 500 |
+
changed: list[dict[str, Any]] = []
|
| 501 |
+
for node in nodes_to_check:
|
| 502 |
+
dep_id, dep_type = _get_dependency_id(node)
|
| 503 |
+
if dep_id is None:
|
| 504 |
+
continue
|
| 505 |
+
|
| 506 |
+
current_sha = _fetch_current_sha(dep_id, dep_type)
|
| 507 |
+
if current_sha is None:
|
| 508 |
+
continue
|
| 509 |
+
|
| 510 |
+
cached_sha = _client_cache.get_dependency_hash(dep_id)
|
| 511 |
+
if cached_sha is None:
|
| 512 |
+
_client_cache.set_dependency_hash(dep_id, current_sha)
|
| 513 |
+
elif cached_sha != current_sha:
|
| 514 |
+
changed.append(
|
| 515 |
+
{
|
| 516 |
+
"type": dep_type,
|
| 517 |
+
"id": dep_id,
|
| 518 |
+
"node": node,
|
| 519 |
+
"cached_sha": cached_sha,
|
| 520 |
+
"current_sha": current_sha,
|
| 521 |
+
}
|
| 522 |
+
)
|
| 523 |
+
|
| 524 |
+
if not changed:
|
| 525 |
+
return
|
| 526 |
+
|
| 527 |
+
if mode == "update":
|
| 528 |
+
for item in changed:
|
| 529 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 530 |
+
print(
|
| 531 |
+
f" [daggr] Auto-updated hash for {item['type']} "
|
| 532 |
+
f"'{item['id']}' → {item['current_sha'][:12]}"
|
| 533 |
+
)
|
| 534 |
+
return
|
| 535 |
+
|
| 536 |
+
if mode == "error":
|
| 537 |
+
descs = [
|
| 538 |
+
f" • {item['type']} '{item['id']}': "
|
| 539 |
+
f"{item['cached_sha'][:12]} → {item['current_sha'][:12]}"
|
| 540 |
+
for item in changed
|
| 541 |
+
]
|
| 542 |
+
raise RuntimeError(
|
| 543 |
+
"Upstream dependencies have changed:\n"
|
| 544 |
+
+ "\n".join(descs)
|
| 545 |
+
+ "\nSet DAGGR_DEPENDENCY_CHECK=update to accept changes."
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
_prompt_dependency_changes(changed)
|
| 549 |
+
|
| 550 |
+
def _startup_display(self) -> None:
|
| 551 |
+
mode = os.environ.get("DAGGR_DEPENDENCY_CHECK", "").lower()
|
| 552 |
+
skip_hashes = mode == "skip"
|
| 553 |
+
|
| 554 |
+
node_count = len(self.nodes)
|
| 555 |
+
noun = "node" if node_count == 1 else "nodes"
|
| 556 |
+
print(f"\n Launching Daggr ({self.name}) with {node_count} {noun}:\n")
|
| 557 |
+
|
| 558 |
+
from daggr import _client_cache
|
| 559 |
+
|
| 560 |
+
changed: list[dict[str, Any]] = []
|
| 561 |
+
|
| 562 |
+
def _check_hash(node):
|
| 563 |
+
dep_id, dep_type = _get_dependency_id(node)
|
| 564 |
+
if dep_id is None:
|
| 565 |
+
return None
|
| 566 |
+
|
| 567 |
+
current_sha = _fetch_current_sha(dep_id, dep_type)
|
| 568 |
+
if current_sha is None:
|
| 569 |
+
return None
|
| 570 |
+
|
| 571 |
+
cached_sha = _client_cache.get_dependency_hash(dep_id)
|
| 572 |
+
if cached_sha is None:
|
| 573 |
+
_client_cache.set_dependency_hash(dep_id, current_sha)
|
| 574 |
+
return ("recorded", f"hash {current_sha[:7]} recorded")
|
| 575 |
+
elif cached_sha == current_sha:
|
| 576 |
+
return ("matches", f"hash {current_sha[:7]} matches")
|
| 577 |
+
else:
|
| 578 |
+
changed.append(
|
| 579 |
+
{
|
| 580 |
+
"type": dep_type,
|
| 581 |
+
"id": dep_id,
|
| 582 |
+
"node": node,
|
| 583 |
+
"cached_sha": cached_sha,
|
| 584 |
+
"current_sha": current_sha,
|
| 585 |
+
}
|
| 586 |
+
)
|
| 587 |
+
return ("changed", "hash changed")
|
| 588 |
+
|
| 589 |
+
for node in self.nodes.values():
|
| 590 |
+
if isinstance(node, ChoiceNode):
|
| 591 |
+
spinner = _Spinner(node._name)
|
| 592 |
+
for variant in node._variants:
|
| 593 |
+
if isinstance(variant, GradioNode) and variant._run_locally:
|
| 594 |
+
prepare_local_node(variant)
|
| 595 |
+
results = []
|
| 596 |
+
if not skip_hashes:
|
| 597 |
+
for variant in node._variants:
|
| 598 |
+
if isinstance(variant, (GradioNode, InferenceNode)):
|
| 599 |
+
result = _check_hash(variant)
|
| 600 |
+
if result:
|
| 601 |
+
results.append(result)
|
| 602 |
+
if any(r[0] == "changed" for r in results):
|
| 603 |
+
spinner.warn("hash changed")
|
| 604 |
+
elif results:
|
| 605 |
+
spinner.succeed(results[-1][1])
|
| 606 |
+
else:
|
| 607 |
+
spinner.succeed()
|
| 608 |
+
continue
|
| 609 |
+
|
| 610 |
+
if isinstance(node, GradioNode) and node._run_locally:
|
| 611 |
+
prepare_local_node(node)
|
| 612 |
+
|
| 613 |
+
label = _get_node_display_label(node)
|
| 614 |
+
|
| 615 |
+
if isinstance(node, (GradioNode, InferenceNode)) and not skip_hashes:
|
| 616 |
+
spinner = _Spinner(label)
|
| 617 |
+
result = _check_hash(node)
|
| 618 |
+
if result and result[0] == "changed":
|
| 619 |
+
spinner.warn(result[1])
|
| 620 |
+
elif result:
|
| 621 |
+
spinner.succeed(result[1])
|
| 622 |
+
else:
|
| 623 |
+
spinner.succeed()
|
| 624 |
+
else:
|
| 625 |
+
sys.stdout.write(f" ✓ {label}\n")
|
| 626 |
+
sys.stdout.flush()
|
| 627 |
+
|
| 628 |
+
print()
|
| 629 |
+
|
| 630 |
+
if not changed:
|
| 631 |
+
return
|
| 632 |
+
|
| 633 |
+
if mode == "update":
|
| 634 |
+
for item in changed:
|
| 635 |
+
_client_cache.set_dependency_hash(item["id"], item["current_sha"])
|
| 636 |
+
print(
|
| 637 |
+
f" [daggr] Auto-updated hash for {item['type']} "
|
| 638 |
+
f"'{item['id']}' → {item['current_sha'][:12]}"
|
| 639 |
+
)
|
| 640 |
+
return
|
| 641 |
+
|
| 642 |
+
if mode == "error":
|
| 643 |
+
descs = [
|
| 644 |
+
f" • {item['type']} '{item['id']}': "
|
| 645 |
+
f"{item['cached_sha'][:12]} → {item['current_sha'][:12]}"
|
| 646 |
+
for item in changed
|
| 647 |
+
]
|
| 648 |
+
raise RuntimeError(
|
| 649 |
+
"Upstream dependencies have changed:\n"
|
| 650 |
+
+ "\n".join(descs)
|
| 651 |
+
+ "\nSet DAGGR_DEPENDENCY_CHECK=update to accept changes."
|
| 652 |
+
)
|
| 653 |
+
|
| 654 |
+
_prompt_dependency_changes(changed)
|
| 655 |
+
|
| 656 |
+
def get_subgraphs(self) -> list[set[str]]:
|
| 657 |
+
"""Get all weakly connected components of the graph.
|
| 658 |
+
|
| 659 |
+
Returns a list of sets, where each set contains the node names
|
| 660 |
+
belonging to a connected subgraph. If the graph is fully connected,
|
| 661 |
+
returns a single set with all node names.
|
| 662 |
+
"""
|
| 663 |
+
return [set(c) for c in nx.weakly_connected_components(self._nx_graph)]
|
| 664 |
+
|
| 665 |
+
def get_output_nodes(self) -> list[str]:
|
| 666 |
+
"""Get all nodes with no outgoing edges (output/leaf nodes)."""
|
| 667 |
+
return [
|
| 668 |
+
node_name
|
| 669 |
+
for node_name in self.nodes
|
| 670 |
+
if self._nx_graph.out_degree(node_name) == 0
|
| 671 |
+
]
|
| 672 |
+
|
| 673 |
+
def get_api_schema(self) -> dict:
|
| 674 |
+
"""Get the API schema describing inputs and outputs for each subgraph.
|
| 675 |
+
|
| 676 |
+
Returns a dict with:
|
| 677 |
+
- subgraphs: list of subgraph info, each containing:
|
| 678 |
+
- id: subgraph identifier (e.g., "main" or "subgraph_0")
|
| 679 |
+
- inputs: list of {node, port, type, component} for each input
|
| 680 |
+
- outputs: list of {node, port, type, component} for each output
|
| 681 |
+
"""
|
| 682 |
+
subgraphs = self.get_subgraphs()
|
| 683 |
+
output_nodes = set(self.get_output_nodes())
|
| 684 |
+
result = {"subgraphs": []}
|
| 685 |
+
|
| 686 |
+
for idx, subgraph_nodes in enumerate(subgraphs):
|
| 687 |
+
subgraph_id = "main" if len(subgraphs) == 1 else f"subgraph_{idx}"
|
| 688 |
+
|
| 689 |
+
inputs = []
|
| 690 |
+
outputs = []
|
| 691 |
+
|
| 692 |
+
for node_name in subgraph_nodes:
|
| 693 |
+
node = self.nodes[node_name]
|
| 694 |
+
|
| 695 |
+
if isinstance(node, ChoiceNode):
|
| 696 |
+
continue
|
| 697 |
+
|
| 698 |
+
if node._input_components:
|
| 699 |
+
for port_name, comp in node._input_components.items():
|
| 700 |
+
comp_type = self._get_component_type(comp)
|
| 701 |
+
inputs.append(
|
| 702 |
+
{
|
| 703 |
+
"node": node_name,
|
| 704 |
+
"port": port_name,
|
| 705 |
+
"type": comp_type,
|
| 706 |
+
"id": f"{node_name}__{port_name}".replace(
|
| 707 |
+
" ", "_"
|
| 708 |
+
).replace("-", "_"),
|
| 709 |
+
}
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
if node_name in output_nodes and node._output_components:
|
| 713 |
+
for port_name, comp in node._output_components.items():
|
| 714 |
+
if comp is None:
|
| 715 |
+
continue
|
| 716 |
+
comp_type = self._get_component_type(comp)
|
| 717 |
+
outputs.append(
|
| 718 |
+
{
|
| 719 |
+
"node": node_name,
|
| 720 |
+
"port": port_name,
|
| 721 |
+
"type": comp_type,
|
| 722 |
+
}
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
result["subgraphs"].append(
|
| 726 |
+
{
|
| 727 |
+
"id": subgraph_id,
|
| 728 |
+
"inputs": inputs,
|
| 729 |
+
"outputs": outputs,
|
| 730 |
+
}
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
return result
|
| 734 |
+
|
| 735 |
+
def _get_component_type(self, component) -> str:
|
| 736 |
+
"""Get the type string for a Gradio component."""
|
| 737 |
+
class_name = component.__class__.__name__
|
| 738 |
+
type_map = {
|
| 739 |
+
"Audio": "audio",
|
| 740 |
+
"Textbox": "textbox",
|
| 741 |
+
"TextArea": "textarea",
|
| 742 |
+
"JSON": "json",
|
| 743 |
+
"Chatbot": "json",
|
| 744 |
+
"Image": "image",
|
| 745 |
+
"Number": "number",
|
| 746 |
+
"Markdown": "markdown",
|
| 747 |
+
"Text": "text",
|
| 748 |
+
"Dropdown": "dropdown",
|
| 749 |
+
"Video": "video",
|
| 750 |
+
"File": "file",
|
| 751 |
+
"Model3D": "model3d",
|
| 752 |
+
"Gallery": "gallery",
|
| 753 |
+
"Slider": "slider",
|
| 754 |
+
"Radio": "radio",
|
| 755 |
+
"Checkbox": "checkbox",
|
| 756 |
+
"CheckboxGroup": "checkboxgroup",
|
| 757 |
+
"ColorPicker": "colorpicker",
|
| 758 |
+
"Label": "label",
|
| 759 |
+
"HighlightedText": "highlightedtext",
|
| 760 |
+
"Code": "code",
|
| 761 |
+
"HTML": "html",
|
| 762 |
+
"Dataframe": "dataframe",
|
| 763 |
+
}
|
| 764 |
+
return type_map.get(class_name, "text")
|
| 765 |
+
|
| 766 |
+
def __repr__(self):
|
| 767 |
+
return f"Graph(name={self.name}, nodes={len(self.nodes)}, edges={len(self._edges)})"
|
daggr/local_space.py
ADDED
|
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import atexit
|
| 4 |
+
import hashlib
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import select
|
| 9 |
+
import shutil
|
| 10 |
+
import socket
|
| 11 |
+
import subprocess
|
| 12 |
+
import sys
|
| 13 |
+
import time
|
| 14 |
+
import urllib.error
|
| 15 |
+
import urllib.request
|
| 16 |
+
from datetime import datetime
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import TYPE_CHECKING, Any
|
| 19 |
+
|
| 20 |
+
if TYPE_CHECKING:
|
| 21 |
+
from daggr.node import GradioNode
|
| 22 |
+
|
| 23 |
+
from daggr.state import get_daggr_cache_dir
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _get_spaces_cache_dir() -> Path:
|
| 27 |
+
return get_daggr_cache_dir() / "spaces"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _get_logs_dir() -> Path:
|
| 31 |
+
return get_daggr_cache_dir() / "logs"
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
_running_processes: dict[str, subprocess.Popen] = {}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _get_space_dir(space_id: str) -> Path:
|
| 38 |
+
spaces_dir = _get_spaces_cache_dir()
|
| 39 |
+
parts = space_id.split("/")
|
| 40 |
+
if len(parts) == 2:
|
| 41 |
+
owner, name = parts
|
| 42 |
+
return spaces_dir / owner / name
|
| 43 |
+
return spaces_dir / space_id.replace("/", "_")
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def _get_metadata_path(space_dir: Path) -> Path:
|
| 47 |
+
return space_dir / ".daggr_metadata.json"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _hash_file(file_path: Path) -> str:
|
| 51 |
+
if not file_path.exists():
|
| 52 |
+
return ""
|
| 53 |
+
return hashlib.sha256(file_path.read_bytes()).hexdigest()[:16]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _find_free_port(start: int = 7861, end: int = 7960) -> int:
|
| 57 |
+
for port in range(start, end):
|
| 58 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
| 59 |
+
try:
|
| 60 |
+
s.bind(("127.0.0.1", port))
|
| 61 |
+
return port
|
| 62 |
+
except OSError:
|
| 63 |
+
continue
|
| 64 |
+
raise RuntimeError(f"No free ports available in range {start}-{end}")
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def _is_space_id(src: str) -> bool:
|
| 68 |
+
if src.startswith("http://") or src.startswith("https://"):
|
| 69 |
+
return False
|
| 70 |
+
return "/" in src and not src.startswith("/")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LocalSpaceManager:
|
| 74 |
+
def __init__(self, node: GradioNode):
|
| 75 |
+
self.node = node
|
| 76 |
+
self.space_id = node._src
|
| 77 |
+
self.space_dir = _get_space_dir(self.space_id)
|
| 78 |
+
self.repo_dir = self.space_dir / "repo"
|
| 79 |
+
self.venv_dir = self.space_dir / ".venv"
|
| 80 |
+
self.metadata_path = _get_metadata_path(self.space_dir)
|
| 81 |
+
self.process: subprocess.Popen | None = None
|
| 82 |
+
self.local_url: str | None = None
|
| 83 |
+
|
| 84 |
+
def ensure_ready(self) -> str:
|
| 85 |
+
if not _is_space_id(self.space_id):
|
| 86 |
+
raise ValueError(
|
| 87 |
+
f"Cannot run locally: '{self.space_id}' is not a valid Space ID. "
|
| 88 |
+
"Local mode only works with Hugging Face Spaces (format: 'owner/space-name')."
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
try:
|
| 92 |
+
self._ensure_cloned()
|
| 93 |
+
self._ensure_venv()
|
| 94 |
+
url = self._launch_app()
|
| 95 |
+
return url
|
| 96 |
+
except Exception as e:
|
| 97 |
+
self._log_error(e)
|
| 98 |
+
raise
|
| 99 |
+
|
| 100 |
+
def _ensure_cloned(self) -> None:
|
| 101 |
+
metadata = self._load_metadata()
|
| 102 |
+
|
| 103 |
+
if self.repo_dir.exists() and metadata:
|
| 104 |
+
should_update = os.environ.get("DAGGR_UPDATE_SPACES") == "1"
|
| 105 |
+
if not should_update:
|
| 106 |
+
return
|
| 107 |
+
|
| 108 |
+
self.space_dir.mkdir(parents=True, exist_ok=True)
|
| 109 |
+
|
| 110 |
+
from huggingface_hub import snapshot_download
|
| 111 |
+
|
| 112 |
+
print(f" Cloning Space '{self.space_id}'...")
|
| 113 |
+
|
| 114 |
+
if self.repo_dir.exists():
|
| 115 |
+
shutil.rmtree(self.repo_dir)
|
| 116 |
+
|
| 117 |
+
snapshot_download(
|
| 118 |
+
repo_id=self.space_id,
|
| 119 |
+
repo_type="space",
|
| 120 |
+
local_dir=self.repo_dir,
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
requirements_path = self.repo_dir / "requirements.txt"
|
| 124 |
+
metadata = {
|
| 125 |
+
"cloned_at": datetime.now().isoformat(),
|
| 126 |
+
"space_id": self.space_id,
|
| 127 |
+
"requirements_hash": _hash_file(requirements_path),
|
| 128 |
+
"python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
|
| 129 |
+
}
|
| 130 |
+
self._save_metadata(metadata)
|
| 131 |
+
print(f" Cloned to {self.repo_dir}")
|
| 132 |
+
|
| 133 |
+
def _get_sdk_version(self) -> str | None:
|
| 134 |
+
readme_path = self.repo_dir / "README.md"
|
| 135 |
+
if not readme_path.exists():
|
| 136 |
+
return None
|
| 137 |
+
|
| 138 |
+
try:
|
| 139 |
+
content = readme_path.read_text()
|
| 140 |
+
if not content.startswith("---"):
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
parts = content.split("---", 2)
|
| 144 |
+
if len(parts) < 3:
|
| 145 |
+
return None
|
| 146 |
+
|
| 147 |
+
match = re.search(r"sdk_version:\s*['\"]?([^\s'\"]+)", parts[1])
|
| 148 |
+
if match:
|
| 149 |
+
return match.group(1)
|
| 150 |
+
except Exception:
|
| 151 |
+
pass
|
| 152 |
+
|
| 153 |
+
return None
|
| 154 |
+
|
| 155 |
+
def _ensure_venv(self) -> None:
|
| 156 |
+
requirements_path = self.repo_dir / "requirements.txt"
|
| 157 |
+
current_hash = _hash_file(requirements_path)
|
| 158 |
+
metadata = self._load_metadata()
|
| 159 |
+
|
| 160 |
+
venv_python = self.venv_dir / "bin" / "python"
|
| 161 |
+
if sys.platform == "win32":
|
| 162 |
+
venv_python = self.venv_dir / "Scripts" / "python.exe"
|
| 163 |
+
|
| 164 |
+
needs_reinstall = False
|
| 165 |
+
if not self.venv_dir.exists() or not venv_python.exists():
|
| 166 |
+
needs_reinstall = True
|
| 167 |
+
elif metadata and metadata.get("requirements_hash") != current_hash:
|
| 168 |
+
needs_reinstall = True
|
| 169 |
+
|
| 170 |
+
if not needs_reinstall:
|
| 171 |
+
return
|
| 172 |
+
|
| 173 |
+
print(f" Setting up virtual environment for '{self.space_id}'...")
|
| 174 |
+
|
| 175 |
+
if self.venv_dir.exists():
|
| 176 |
+
shutil.rmtree(self.venv_dir)
|
| 177 |
+
|
| 178 |
+
subprocess.run(
|
| 179 |
+
[sys.executable, "-m", "venv", str(self.venv_dir)],
|
| 180 |
+
check=True,
|
| 181 |
+
capture_output=True,
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
pip_path = self.venv_dir / "bin" / "pip"
|
| 185 |
+
if sys.platform == "win32":
|
| 186 |
+
pip_path = self.venv_dir / "Scripts" / "pip.exe"
|
| 187 |
+
|
| 188 |
+
subprocess.run(
|
| 189 |
+
[str(pip_path), "install", "--upgrade", "pip"],
|
| 190 |
+
check=True,
|
| 191 |
+
capture_output=True,
|
| 192 |
+
)
|
| 193 |
+
|
| 194 |
+
sdk_version = self._get_sdk_version()
|
| 195 |
+
if sdk_version:
|
| 196 |
+
gradio_pkg = f"gradio=={sdk_version}"
|
| 197 |
+
print(f" Installing {gradio_pkg}...")
|
| 198 |
+
else:
|
| 199 |
+
gradio_pkg = "gradio"
|
| 200 |
+
print(" Installing gradio (latest)...")
|
| 201 |
+
|
| 202 |
+
result = subprocess.run(
|
| 203 |
+
[str(pip_path), "install", gradio_pkg],
|
| 204 |
+
capture_output=True,
|
| 205 |
+
text=True,
|
| 206 |
+
)
|
| 207 |
+
if result.returncode != 0:
|
| 208 |
+
error_msg = result.stderr or result.stdout
|
| 209 |
+
self._log_to_file("pip_install_gradio", error_msg)
|
| 210 |
+
print(f" Warning: Failed to install {gradio_pkg}")
|
| 211 |
+
|
| 212 |
+
if requirements_path.exists():
|
| 213 |
+
print(f" Installing dependencies from {requirements_path}...")
|
| 214 |
+
print(" (this may take a few minutes)")
|
| 215 |
+
|
| 216 |
+
process = subprocess.Popen(
|
| 217 |
+
[str(pip_path), "install", "-r", str(requirements_path)],
|
| 218 |
+
stdout=subprocess.PIPE,
|
| 219 |
+
stderr=subprocess.STDOUT,
|
| 220 |
+
text=True,
|
| 221 |
+
bufsize=1,
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
output_lines = []
|
| 225 |
+
for line in iter(process.stdout.readline, ""):
|
| 226 |
+
output_lines.append(line)
|
| 227 |
+
line_stripped = line.strip()
|
| 228 |
+
if line_stripped.startswith("Collecting "):
|
| 229 |
+
pkg = line_stripped.replace("Collecting ", "").split()[0]
|
| 230 |
+
print(f" Installing {pkg}...")
|
| 231 |
+
elif (
|
| 232 |
+
line_stripped.startswith("ERROR:")
|
| 233 |
+
or "error" in line_stripped.lower()
|
| 234 |
+
):
|
| 235 |
+
print(f" {line_stripped}")
|
| 236 |
+
|
| 237 |
+
process.wait()
|
| 238 |
+
|
| 239 |
+
if process.returncode != 0:
|
| 240 |
+
error_msg = "".join(output_lines)
|
| 241 |
+
self._log_to_file("pip_install", error_msg)
|
| 242 |
+
print("\n ❌ Dependency installation failed!")
|
| 243 |
+
print(f" Full log: {self._get_log_path('pip_install')}")
|
| 244 |
+
raise RuntimeError(
|
| 245 |
+
f"Failed to install dependencies for '{self.space_id}'.\n"
|
| 246 |
+
f"See logs at: {self._get_log_path('pip_install')}\n"
|
| 247 |
+
f"You can try installing manually:\n"
|
| 248 |
+
f" {pip_path} install -r {requirements_path}"
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
if metadata:
|
| 252 |
+
metadata["requirements_hash"] = current_hash
|
| 253 |
+
self._save_metadata(metadata)
|
| 254 |
+
|
| 255 |
+
print(" Virtual environment ready")
|
| 256 |
+
|
| 257 |
+
def _launch_app(self) -> str:
|
| 258 |
+
global _running_processes
|
| 259 |
+
|
| 260 |
+
if self.space_id in _running_processes:
|
| 261 |
+
proc = _running_processes[self.space_id]
|
| 262 |
+
if proc.poll() is None:
|
| 263 |
+
metadata = self._load_metadata()
|
| 264 |
+
if metadata and metadata.get("local_url"):
|
| 265 |
+
return metadata["local_url"]
|
| 266 |
+
|
| 267 |
+
app_file = self._find_app_file()
|
| 268 |
+
if not app_file:
|
| 269 |
+
raise RuntimeError(
|
| 270 |
+
f"No app.py or main.py found in '{self.space_id}'. "
|
| 271 |
+
"Cannot determine how to launch this Space."
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
port = _find_free_port()
|
| 275 |
+
local_url = f"http://127.0.0.1:{port}"
|
| 276 |
+
|
| 277 |
+
venv_python = self.venv_dir / "bin" / "python"
|
| 278 |
+
if sys.platform == "win32":
|
| 279 |
+
venv_python = self.venv_dir / "Scripts" / "python.exe"
|
| 280 |
+
|
| 281 |
+
timeout = int(os.environ.get("DAGGR_LOCAL_TIMEOUT", "120"))
|
| 282 |
+
|
| 283 |
+
env = os.environ.copy()
|
| 284 |
+
env["GRADIO_SERVER_PORT"] = str(port)
|
| 285 |
+
env["GRADIO_SERVER_NAME"] = "127.0.0.1"
|
| 286 |
+
env["PYTHONUNBUFFERED"] = "1"
|
| 287 |
+
|
| 288 |
+
print(f" Launching '{self.space_id}' on port {port}...")
|
| 289 |
+
print(f" Waiting for app to start (timeout: {timeout}s)...")
|
| 290 |
+
|
| 291 |
+
log_file = self._get_log_path("launch")
|
| 292 |
+
log_file.parent.mkdir(parents=True, exist_ok=True)
|
| 293 |
+
|
| 294 |
+
self.process = subprocess.Popen(
|
| 295 |
+
[str(venv_python), str(app_file)],
|
| 296 |
+
cwd=str(self.repo_dir),
|
| 297 |
+
env=env,
|
| 298 |
+
stdout=subprocess.PIPE,
|
| 299 |
+
stderr=subprocess.STDOUT,
|
| 300 |
+
text=True,
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
_running_processes[self.space_id] = self.process
|
| 304 |
+
|
| 305 |
+
ready, error_output = self._wait_for_ready(local_url, timeout, verbose=True)
|
| 306 |
+
if not ready:
|
| 307 |
+
self._log_to_file("launch", error_output)
|
| 308 |
+
if self.process.poll() is None:
|
| 309 |
+
self.process.terminate()
|
| 310 |
+
|
| 311 |
+
print("\n ❌ Space failed to start!")
|
| 312 |
+
if error_output:
|
| 313 |
+
error_lines = error_output.strip().split("\n")
|
| 314 |
+
relevant_lines = [ln for ln in error_lines if ln.strip()][-10:]
|
| 315 |
+
if relevant_lines:
|
| 316 |
+
print(" Last output:")
|
| 317 |
+
for line in relevant_lines:
|
| 318 |
+
print(f" {line}")
|
| 319 |
+
|
| 320 |
+
print(f" Full log: {log_file}")
|
| 321 |
+
raise RuntimeError(
|
| 322 |
+
f"Space '{self.space_id}' failed to start.\n"
|
| 323 |
+
f"See logs at: {log_file}\n"
|
| 324 |
+
"Suggestions:\n"
|
| 325 |
+
" 1. Some Spaces require GPU hardware\n"
|
| 326 |
+
" 2. Check the Space's README for requirements\n"
|
| 327 |
+
" 3. Set DAGGR_LOCAL_VERBOSE=1 to see all output"
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
metadata = self._load_metadata() or {}
|
| 331 |
+
metadata["local_url"] = local_url
|
| 332 |
+
metadata["last_successful_launch"] = datetime.now().isoformat()
|
| 333 |
+
self._save_metadata(metadata)
|
| 334 |
+
|
| 335 |
+
print(f" Space running at {local_url}")
|
| 336 |
+
return local_url
|
| 337 |
+
|
| 338 |
+
def _find_app_file(self) -> Path | None:
|
| 339 |
+
for name in ["app.py", "main.py", "demo.py"]:
|
| 340 |
+
path = self.repo_dir / name
|
| 341 |
+
if path.exists():
|
| 342 |
+
return path
|
| 343 |
+
return None
|
| 344 |
+
|
| 345 |
+
def _wait_for_ready(
|
| 346 |
+
self, url: str, timeout: int, verbose: bool = False
|
| 347 |
+
) -> tuple[bool, str]:
|
| 348 |
+
output_lines: list[str] = []
|
| 349 |
+
start = time.time()
|
| 350 |
+
last_status_time = start
|
| 351 |
+
saw_error = False
|
| 352 |
+
|
| 353 |
+
while time.time() - start < timeout:
|
| 354 |
+
if self.process and self.process.stdout:
|
| 355 |
+
while True:
|
| 356 |
+
if sys.platform == "win32":
|
| 357 |
+
line = self.process.stdout.readline()
|
| 358 |
+
if not line:
|
| 359 |
+
break
|
| 360 |
+
else:
|
| 361 |
+
ready, _, _ = select.select([self.process.stdout], [], [], 0)
|
| 362 |
+
if not ready:
|
| 363 |
+
break
|
| 364 |
+
line = self.process.stdout.readline()
|
| 365 |
+
|
| 366 |
+
if line:
|
| 367 |
+
output_lines.append(line)
|
| 368 |
+
line_lower = line.lower()
|
| 369 |
+
if (
|
| 370 |
+
"traceback" in line_lower
|
| 371 |
+
or "modulenotfounderror" in line_lower
|
| 372 |
+
):
|
| 373 |
+
saw_error = True
|
| 374 |
+
if verbose:
|
| 375 |
+
print(f" [app] {line.rstrip()}")
|
| 376 |
+
|
| 377 |
+
exit_code = self.process.poll() if self.process else None
|
| 378 |
+
if exit_code is not None:
|
| 379 |
+
if self.process and self.process.stdout:
|
| 380 |
+
remaining = self.process.stdout.read()
|
| 381 |
+
if remaining:
|
| 382 |
+
output_lines.append(remaining)
|
| 383 |
+
if verbose:
|
| 384 |
+
for rem_line in remaining.strip().split("\n"):
|
| 385 |
+
if rem_line.strip():
|
| 386 |
+
print(f" [app] {rem_line}")
|
| 387 |
+
print(f" App process exited with code {exit_code}")
|
| 388 |
+
return False, "".join(output_lines)
|
| 389 |
+
|
| 390 |
+
if saw_error:
|
| 391 |
+
time.sleep(0.5)
|
| 392 |
+
if self.process and self.process.poll() is not None:
|
| 393 |
+
if self.process.stdout:
|
| 394 |
+
remaining = self.process.stdout.read()
|
| 395 |
+
if remaining:
|
| 396 |
+
output_lines.append(remaining)
|
| 397 |
+
print(" App crashed during startup")
|
| 398 |
+
return False, "".join(output_lines)
|
| 399 |
+
|
| 400 |
+
elapsed = time.time() - start
|
| 401 |
+
if elapsed - (last_status_time - start) >= 10:
|
| 402 |
+
print(f" Still waiting... ({int(elapsed)}s elapsed)")
|
| 403 |
+
last_status_time = time.time()
|
| 404 |
+
|
| 405 |
+
try:
|
| 406 |
+
with urllib.request.urlopen(url, timeout=2) as response:
|
| 407 |
+
if response.status == 200:
|
| 408 |
+
return True, "".join(output_lines)
|
| 409 |
+
except (urllib.error.URLError, OSError):
|
| 410 |
+
pass
|
| 411 |
+
|
| 412 |
+
time.sleep(0.3)
|
| 413 |
+
|
| 414 |
+
return False, "".join(output_lines)
|
| 415 |
+
|
| 416 |
+
def _load_metadata(self) -> dict[str, Any] | None:
|
| 417 |
+
if not self.metadata_path.exists():
|
| 418 |
+
return None
|
| 419 |
+
try:
|
| 420 |
+
return json.loads(self.metadata_path.read_text())
|
| 421 |
+
except (json.JSONDecodeError, OSError):
|
| 422 |
+
return None
|
| 423 |
+
|
| 424 |
+
def _save_metadata(self, metadata: dict[str, Any]) -> None:
|
| 425 |
+
self.metadata_path.parent.mkdir(parents=True, exist_ok=True)
|
| 426 |
+
self.metadata_path.write_text(json.dumps(metadata, indent=2))
|
| 427 |
+
|
| 428 |
+
def _get_log_path(self, log_type: str) -> Path:
|
| 429 |
+
logs_dir = _get_logs_dir()
|
| 430 |
+
logs_dir.mkdir(parents=True, exist_ok=True)
|
| 431 |
+
safe_name = self.space_id.replace("/", "_")
|
| 432 |
+
timestamp = datetime.now().strftime("%Y-%m-%d")
|
| 433 |
+
return logs_dir / f"{safe_name}_{log_type}_{timestamp}.log"
|
| 434 |
+
|
| 435 |
+
def _log_to_file(self, log_type: str, content: str) -> None:
|
| 436 |
+
log_path = self._get_log_path(log_type)
|
| 437 |
+
log_path.parent.mkdir(parents=True, exist_ok=True)
|
| 438 |
+
with open(log_path, "w") as f:
|
| 439 |
+
f.write(f"Timestamp: {datetime.now().isoformat()}\n")
|
| 440 |
+
f.write(f"Space: {self.space_id}\n")
|
| 441 |
+
f.write(f"Type: {log_type}\n")
|
| 442 |
+
f.write("=" * 50 + "\n")
|
| 443 |
+
f.write(content)
|
| 444 |
+
|
| 445 |
+
def _log_error(self, error: Exception) -> None:
|
| 446 |
+
self._log_to_file("error", str(error))
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def prepare_local_node(node: GradioNode) -> None:
|
| 450 |
+
if node._local_failed or node._local_url:
|
| 451 |
+
return
|
| 452 |
+
|
| 453 |
+
if not _is_space_id(node._src):
|
| 454 |
+
return
|
| 455 |
+
|
| 456 |
+
no_fallback = os.environ.get("DAGGR_LOCAL_NO_FALLBACK") == "1"
|
| 457 |
+
|
| 458 |
+
try:
|
| 459 |
+
manager = LocalSpaceManager(node)
|
| 460 |
+
url = manager.ensure_ready()
|
| 461 |
+
node._local_url = url
|
| 462 |
+
except Exception as e:
|
| 463 |
+
node._local_failed = True
|
| 464 |
+
safe_name = node._src.replace("/", "_")
|
| 465 |
+
|
| 466 |
+
print(f"\n ⚠️ Local setup failed for '{node._src}'")
|
| 467 |
+
print(f" Reason: {e}")
|
| 468 |
+
print(f" Logs: {_get_logs_dir()}/{safe_name}_*.log")
|
| 469 |
+
|
| 470 |
+
if no_fallback:
|
| 471 |
+
raise RuntimeError(
|
| 472 |
+
f"Local execution failed for '{node._src}' and fallback is disabled. "
|
| 473 |
+
f"Error: {e}"
|
| 474 |
+
) from e
|
| 475 |
+
|
| 476 |
+
print(" Will fall back to remote API at execution time.\n")
|
| 477 |
+
|
| 478 |
+
|
| 479 |
+
def get_local_client(node: GradioNode) -> Any:
|
| 480 |
+
if node._local_failed:
|
| 481 |
+
return None
|
| 482 |
+
|
| 483 |
+
if node._local_url:
|
| 484 |
+
from gradio_client import Client
|
| 485 |
+
|
| 486 |
+
return Client(node._local_url, download_files=False, verbose=False)
|
| 487 |
+
|
| 488 |
+
return None
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def cleanup_local_processes() -> None:
|
| 492 |
+
global _running_processes
|
| 493 |
+
for space_id, proc in list(_running_processes.items()):
|
| 494 |
+
if proc.poll() is None:
|
| 495 |
+
proc.terminate()
|
| 496 |
+
try:
|
| 497 |
+
proc.wait(timeout=5)
|
| 498 |
+
except subprocess.TimeoutExpired:
|
| 499 |
+
proc.kill()
|
| 500 |
+
_running_processes.clear()
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
atexit.register(cleanup_local_processes)
|
daggr/node.py
ADDED
|
@@ -0,0 +1,772 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Node types for daggr graphs.
|
| 2 |
+
|
| 3 |
+
This module defines the various node types that can be used in a daggr graph:
|
| 4 |
+
- Node: Abstract base class for all nodes
|
| 5 |
+
- GradioNode: Wraps a Gradio Space or endpoint
|
| 6 |
+
- InferenceNode: Wraps a Hugging Face Inference API model
|
| 7 |
+
- FnNode: Wraps a Python function
|
| 8 |
+
- InteractionNode: Represents user interaction points
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import inspect
|
| 14 |
+
import warnings
|
| 15 |
+
from abc import ABC
|
| 16 |
+
from collections.abc import Callable
|
| 17 |
+
from typing import Any
|
| 18 |
+
|
| 19 |
+
from daggr._utils import suggest_similar
|
| 20 |
+
from daggr.port import ItemList, Port, PortNamespace, is_port
|
| 21 |
+
|
| 22 |
+
_FILE_TYPE_COMPONENTS = {
|
| 23 |
+
"Image",
|
| 24 |
+
"Audio",
|
| 25 |
+
"Video",
|
| 26 |
+
"File",
|
| 27 |
+
"Gallery",
|
| 28 |
+
"ImageEditor",
|
| 29 |
+
"ImageSlider",
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _warn_if_type_set(component: Any, port_name: str) -> None:
|
| 34 |
+
constructor_args = getattr(component, "_constructor_args", None)
|
| 35 |
+
if not constructor_args:
|
| 36 |
+
return
|
| 37 |
+
comp_type = constructor_args[0].get("type")
|
| 38 |
+
if comp_type is None:
|
| 39 |
+
return
|
| 40 |
+
class_name = type(component).__name__
|
| 41 |
+
if class_name not in _FILE_TYPE_COMPONENTS:
|
| 42 |
+
return
|
| 43 |
+
if comp_type != "filepath":
|
| 44 |
+
warnings.warn(
|
| 45 |
+
f"Gradio component {class_name}(type={comp_type!r}) on port '{port_name}': "
|
| 46 |
+
f"daggr ignores the `type` parameter. All file data is passed as file path "
|
| 47 |
+
f"strings regardless of this setting.",
|
| 48 |
+
stacklevel=4,
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _is_gradio_component(obj: Any) -> bool:
|
| 53 |
+
if obj is None:
|
| 54 |
+
return False
|
| 55 |
+
class_name = obj.__class__.__name__
|
| 56 |
+
module = getattr(obj.__class__, "__module__", "")
|
| 57 |
+
return "gradio" in module or class_name in (
|
| 58 |
+
"Textbox",
|
| 59 |
+
"TextArea",
|
| 60 |
+
"Audio",
|
| 61 |
+
"Image",
|
| 62 |
+
"JSON",
|
| 63 |
+
"Markdown",
|
| 64 |
+
"Number",
|
| 65 |
+
"Checkbox",
|
| 66 |
+
"Dropdown",
|
| 67 |
+
"Radio",
|
| 68 |
+
"Slider",
|
| 69 |
+
"File",
|
| 70 |
+
"Video",
|
| 71 |
+
"Gallery",
|
| 72 |
+
"Chatbot",
|
| 73 |
+
"Text",
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class Node(ABC):
|
| 78 |
+
"""Abstract base class for all nodes in a daggr graph.
|
| 79 |
+
|
| 80 |
+
Nodes represent processing steps in a DAG. Each node has named input and
|
| 81 |
+
output ports that can be connected to form a data processing pipeline.
|
| 82 |
+
|
| 83 |
+
Ports can be accessed as attributes: `node.port_name` returns a Port object.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
name: Optional display name for the node. If not provided, a name will
|
| 87 |
+
be auto-generated based on the node type.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
_id_counter = 0
|
| 91 |
+
|
| 92 |
+
def __init__(self, name: str | None = None):
|
| 93 |
+
self._id = Node._id_counter
|
| 94 |
+
Node._id_counter += 1
|
| 95 |
+
self._name = name or ""
|
| 96 |
+
self._name_explicitly_set = bool(name)
|
| 97 |
+
self._input_ports: list[str] = []
|
| 98 |
+
self._output_ports: list[str] = []
|
| 99 |
+
self._input_components: dict[str, Any] = {}
|
| 100 |
+
self._output_components: dict[str, Any] = {}
|
| 101 |
+
self._item_list_schemas: dict[str, dict[str, Any]] = {}
|
| 102 |
+
self._fixed_inputs: dict[str, Any] = {}
|
| 103 |
+
self._port_connections: dict[str, Any] = {}
|
| 104 |
+
|
| 105 |
+
@property
|
| 106 |
+
def name(self) -> str:
|
| 107 |
+
return self._name
|
| 108 |
+
|
| 109 |
+
@name.setter
|
| 110 |
+
def name(self, value: str) -> None:
|
| 111 |
+
self._name = value
|
| 112 |
+
self._name_explicitly_set = True
|
| 113 |
+
|
| 114 |
+
def __getattr__(self, name: str) -> Port:
|
| 115 |
+
if name.startswith("_"):
|
| 116 |
+
raise AttributeError(name)
|
| 117 |
+
return Port(self, name)
|
| 118 |
+
|
| 119 |
+
def __dir__(self) -> list[str]:
|
| 120 |
+
base = ["_name", "_inputs", "_outputs", "_input_ports", "_output_ports"]
|
| 121 |
+
return base + self._input_ports + self._output_ports
|
| 122 |
+
|
| 123 |
+
def __or__(self, other: Node) -> ChoiceNode:
|
| 124 |
+
"""Combine two nodes as alternatives using the | operator.
|
| 125 |
+
|
| 126 |
+
Returns a ChoiceNode that lets users pick which variant to run.
|
| 127 |
+
|
| 128 |
+
Example:
|
| 129 |
+
>>> tts = GradioNode("space1/tts", ...) | GradioNode("space2/tts", ...)
|
| 130 |
+
>>> # tts.audio works regardless of which variant is selected
|
| 131 |
+
"""
|
| 132 |
+
if isinstance(other, ChoiceNode):
|
| 133 |
+
return ChoiceNode([self] + other._variants, name=self._name)
|
| 134 |
+
return ChoiceNode([self, other], name=self._name)
|
| 135 |
+
|
| 136 |
+
@property
|
| 137 |
+
def _inputs(self) -> PortNamespace:
|
| 138 |
+
return PortNamespace(self, self._input_ports)
|
| 139 |
+
|
| 140 |
+
@property
|
| 141 |
+
def _outputs(self) -> PortNamespace:
|
| 142 |
+
return PortNamespace(self, self._output_ports)
|
| 143 |
+
|
| 144 |
+
def _default_output_port(self) -> Port:
|
| 145 |
+
if self._output_ports:
|
| 146 |
+
return Port(self, self._output_ports[0])
|
| 147 |
+
return Port(self, "output")
|
| 148 |
+
|
| 149 |
+
def _default_input_port(self) -> Port:
|
| 150 |
+
if self._input_ports:
|
| 151 |
+
return Port(self, self._input_ports[0])
|
| 152 |
+
return Port(self, "input")
|
| 153 |
+
|
| 154 |
+
def _validate_ports(self):
|
| 155 |
+
all_ports = set(self._input_ports + self._output_ports)
|
| 156 |
+
underscore_ports = [p for p in all_ports if p.startswith("_")]
|
| 157 |
+
if underscore_ports:
|
| 158 |
+
warnings.warn(
|
| 159 |
+
f"Port names {underscore_ports} start with underscore. "
|
| 160 |
+
f"Use node._inputs.{underscore_ports[0]} or node._outputs.{underscore_ports[0]} to access."
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
def _process_inputs(self, inputs: dict[str, Any]) -> None:
|
| 164 |
+
for port_name, value in inputs.items():
|
| 165 |
+
self._input_ports.append(port_name)
|
| 166 |
+
if is_port(value):
|
| 167 |
+
self._port_connections[port_name] = value
|
| 168 |
+
elif _is_gradio_component(value):
|
| 169 |
+
_warn_if_type_set(value, port_name)
|
| 170 |
+
self._input_components[port_name] = value
|
| 171 |
+
else:
|
| 172 |
+
self._fixed_inputs[port_name] = value
|
| 173 |
+
|
| 174 |
+
def _process_outputs(self, outputs: dict[str, Any]) -> None:
|
| 175 |
+
for port_name, component in outputs.items():
|
| 176 |
+
self._output_ports.append(port_name)
|
| 177 |
+
if component is not None and _is_gradio_component(component):
|
| 178 |
+
_warn_if_type_set(component, port_name)
|
| 179 |
+
self._output_components[port_name] = component
|
| 180 |
+
|
| 181 |
+
def test(self, **inputs) -> dict[str, Any]:
|
| 182 |
+
"""Test-run this node in isolation and return the raw result.
|
| 183 |
+
|
| 184 |
+
If no inputs are provided, auto-generates example values using:
|
| 185 |
+
- Gradio component's .example_value() method
|
| 186 |
+
- Port's associated output component's .example_value()
|
| 187 |
+
- Callable inputs are called
|
| 188 |
+
- Fixed values are used directly
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
**inputs: Override inputs for the test run.
|
| 192 |
+
|
| 193 |
+
Returns:
|
| 194 |
+
Dict mapping output port names to their values.
|
| 195 |
+
|
| 196 |
+
Example:
|
| 197 |
+
>>> tts = GradioNode("mrfakename/MeloTTS", api_name="/synthesize", ...)
|
| 198 |
+
>>> result = tts.test(text="Hello world", speaker="EN-US")
|
| 199 |
+
>>> # Returns: {"audio": "/path/to/audio.wav"}
|
| 200 |
+
>>>
|
| 201 |
+
>>> # Or with auto-generated example values:
|
| 202 |
+
>>> result = tts.test()
|
| 203 |
+
"""
|
| 204 |
+
from daggr import Graph
|
| 205 |
+
from daggr.executor import SequentialExecutor
|
| 206 |
+
|
| 207 |
+
if not inputs:
|
| 208 |
+
inputs = self._generate_example_inputs()
|
| 209 |
+
|
| 210 |
+
graph = Graph("_test", nodes=[self], persist_key=False)
|
| 211 |
+
executor = SequentialExecutor(graph)
|
| 212 |
+
return executor.execute_node(self._name, inputs)
|
| 213 |
+
|
| 214 |
+
def _generate_example_inputs(self) -> dict[str, Any]:
|
| 215 |
+
"""Generate example values for all input ports."""
|
| 216 |
+
examples = {}
|
| 217 |
+
|
| 218 |
+
# From input components (Gradio components)
|
| 219 |
+
for port_name, comp in self._input_components.items():
|
| 220 |
+
if hasattr(comp, "example_value"):
|
| 221 |
+
examples[port_name] = comp.example_value()
|
| 222 |
+
|
| 223 |
+
# From fixed inputs (constants, callables, or port connections)
|
| 224 |
+
for port_name, source in self._fixed_inputs.items():
|
| 225 |
+
if callable(source):
|
| 226 |
+
examples[port_name] = source()
|
| 227 |
+
else:
|
| 228 |
+
examples[port_name] = source
|
| 229 |
+
|
| 230 |
+
# From port connections (use the connected port's output component)
|
| 231 |
+
for port_name, port in self._port_connections.items():
|
| 232 |
+
if is_port(port):
|
| 233 |
+
comp = port._node._output_components.get(port._port_name)
|
| 234 |
+
if comp and hasattr(comp, "example_value"):
|
| 235 |
+
examples[port_name] = comp.example_value()
|
| 236 |
+
|
| 237 |
+
return examples
|
| 238 |
+
|
| 239 |
+
def __repr__(self):
|
| 240 |
+
return f"{self.__class__.__name__}(name={self._name})"
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class ChoiceNode(Node):
|
| 244 |
+
"""A node that wraps multiple alternative nodes.
|
| 245 |
+
|
| 246 |
+
ChoiceNode allows users to select which variant to run from a set of
|
| 247 |
+
alternatives. Created using the | operator between nodes.
|
| 248 |
+
|
| 249 |
+
The output ports are the union of all variants' output ports, so downstream
|
| 250 |
+
nodes can connect to any output that exists in at least one variant.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
variants: List of Node objects that serve as alternatives.
|
| 254 |
+
name: Optional display name. Defaults to the first variant's name.
|
| 255 |
+
|
| 256 |
+
Example:
|
| 257 |
+
>>> tts = GradioNode("space1/tts", ...) | GradioNode("space2/tts", ...)
|
| 258 |
+
>>> # tts is a ChoiceNode with two variants
|
| 259 |
+
>>> # tts.audio works regardless of which variant is selected
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
def __init__(
|
| 263 |
+
self,
|
| 264 |
+
variants: list[Node],
|
| 265 |
+
name: str | None = None,
|
| 266 |
+
):
|
| 267 |
+
if not variants:
|
| 268 |
+
raise ValueError("ChoiceNode requires at least one variant")
|
| 269 |
+
|
| 270 |
+
super().__init__(name)
|
| 271 |
+
self._variants = variants
|
| 272 |
+
self._selected_variant = 0
|
| 273 |
+
|
| 274 |
+
if not self._name:
|
| 275 |
+
self._name = variants[0]._name
|
| 276 |
+
|
| 277 |
+
self._output_ports = self._compute_union_output_ports()
|
| 278 |
+
self._output_components = self._compute_union_output_components()
|
| 279 |
+
|
| 280 |
+
for variant in variants:
|
| 281 |
+
for port_name, port in variant._port_connections.items():
|
| 282 |
+
if port_name not in self._port_connections:
|
| 283 |
+
self._port_connections[port_name] = port
|
| 284 |
+
|
| 285 |
+
def _compute_union_output_ports(self) -> list[str]:
|
| 286 |
+
seen = set()
|
| 287 |
+
ports = []
|
| 288 |
+
for variant in self._variants:
|
| 289 |
+
for port in variant._output_ports:
|
| 290 |
+
if port not in seen:
|
| 291 |
+
seen.add(port)
|
| 292 |
+
ports.append(port)
|
| 293 |
+
return ports
|
| 294 |
+
|
| 295 |
+
def _compute_union_output_components(self) -> dict[str, Any]:
|
| 296 |
+
components = {}
|
| 297 |
+
for variant in self._variants:
|
| 298 |
+
for port_name, comp in variant._output_components.items():
|
| 299 |
+
if port_name not in components:
|
| 300 |
+
components[port_name] = comp
|
| 301 |
+
return components
|
| 302 |
+
|
| 303 |
+
def __or__(self, other: Node) -> ChoiceNode:
|
| 304 |
+
if isinstance(other, ChoiceNode):
|
| 305 |
+
return ChoiceNode(self._variants + other._variants, name=self._name)
|
| 306 |
+
return ChoiceNode(self._variants + [other], name=self._name)
|
| 307 |
+
|
| 308 |
+
def __repr__(self):
|
| 309 |
+
variant_names = [v._name for v in self._variants]
|
| 310 |
+
return f"ChoiceNode(name={self._name}, variants={variant_names})"
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class GradioNode(Node):
|
| 314 |
+
"""A node that wraps a Gradio Space or endpoint.
|
| 315 |
+
|
| 316 |
+
GradioNode connects to a Hugging Face Space or any Gradio app and exposes
|
| 317 |
+
its API as a node in the graph.
|
| 318 |
+
|
| 319 |
+
Args:
|
| 320 |
+
space_or_url: Hugging Face Space ID (e.g., "username/space-name") or
|
| 321 |
+
a full URL to a Gradio app.
|
| 322 |
+
api_name: The API endpoint to call (e.g., "/predict"). Defaults to "/predict".
|
| 323 |
+
name: Optional display name for the node.
|
| 324 |
+
inputs: Dict mapping input port names to Gradio components, Port connections,
|
| 325 |
+
or fixed values.
|
| 326 |
+
outputs: Dict mapping output port names to Gradio components for display.
|
| 327 |
+
validate: Whether to validate the Space exists and has the specified endpoint.
|
| 328 |
+
run_locally: If True, clone and run the Space locally instead of using the
|
| 329 |
+
remote API.
|
| 330 |
+
|
| 331 |
+
Example:
|
| 332 |
+
>>> tts = GradioNode(
|
| 333 |
+
... "mrfakename/MeloTTS",
|
| 334 |
+
... api_name="/synthesize",
|
| 335 |
+
... inputs={"text": gr.Textbox(), "speaker": "EN-US"},
|
| 336 |
+
... outputs={"audio": gr.Audio()},
|
| 337 |
+
... )
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
_name_counters: dict[str, int] = {}
|
| 341 |
+
|
| 342 |
+
def __init__(
|
| 343 |
+
self,
|
| 344 |
+
space_or_url: str,
|
| 345 |
+
api_name: str | None = None,
|
| 346 |
+
name: str | None = None,
|
| 347 |
+
inputs: dict[str, Any] | None = None,
|
| 348 |
+
outputs: dict[str, Any] | None = None,
|
| 349 |
+
validate: bool = True,
|
| 350 |
+
run_locally: bool = False,
|
| 351 |
+
preprocess: Callable[[dict], dict] | None = None,
|
| 352 |
+
postprocess: Callable[..., Any] | None = None,
|
| 353 |
+
):
|
| 354 |
+
super().__init__(name)
|
| 355 |
+
self._src = space_or_url
|
| 356 |
+
self._api_name = api_name
|
| 357 |
+
self._run_locally = run_locally
|
| 358 |
+
self._local_url: str | None = None
|
| 359 |
+
self._local_failed = False
|
| 360 |
+
self._preprocess = preprocess
|
| 361 |
+
self._postprocess = postprocess
|
| 362 |
+
|
| 363 |
+
if validate:
|
| 364 |
+
self._validate_space_format()
|
| 365 |
+
|
| 366 |
+
if not self._name:
|
| 367 |
+
base_name = self._src.split("/")[-1]
|
| 368 |
+
if base_name not in GradioNode._name_counters:
|
| 369 |
+
GradioNode._name_counters[base_name] = 0
|
| 370 |
+
self._name = base_name
|
| 371 |
+
else:
|
| 372 |
+
GradioNode._name_counters[base_name] += 1
|
| 373 |
+
self._name = f"{base_name}_{GradioNode._name_counters[base_name]}"
|
| 374 |
+
|
| 375 |
+
self._process_inputs(inputs or {})
|
| 376 |
+
self._process_outputs(outputs or {})
|
| 377 |
+
self._validate_ports()
|
| 378 |
+
|
| 379 |
+
if validate and not run_locally:
|
| 380 |
+
self._validate_gradio_api(inputs or {}, outputs or {})
|
| 381 |
+
|
| 382 |
+
def _validate_space_format(self) -> None:
|
| 383 |
+
src = self._src
|
| 384 |
+
if not ("/" in src or src.startswith("http://") or src.startswith("https://")):
|
| 385 |
+
raise ValueError(
|
| 386 |
+
f"Invalid space_or_url '{src}'. Expected format: 'username/space-name' "
|
| 387 |
+
f"or a full URL like 'https://...'"
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
def _get_api_info(self) -> dict:
|
| 391 |
+
from daggr import _client_cache
|
| 392 |
+
|
| 393 |
+
cached = _client_cache.get_api_info(self._src)
|
| 394 |
+
if cached is not None:
|
| 395 |
+
return cached
|
| 396 |
+
|
| 397 |
+
from gradio_client import Client
|
| 398 |
+
|
| 399 |
+
client = _client_cache.get_client(self._src)
|
| 400 |
+
if client is None:
|
| 401 |
+
client = Client(self._src, download_files=False, verbose=False)
|
| 402 |
+
_client_cache.set_client(self._src, client)
|
| 403 |
+
|
| 404 |
+
api_info = client.view_api(return_format="dict", print_info=False)
|
| 405 |
+
_client_cache.set_api_info(self._src, api_info)
|
| 406 |
+
return api_info
|
| 407 |
+
|
| 408 |
+
def _validate_gradio_api(
|
| 409 |
+
self, inputs: dict[str, Any], outputs: dict[str, Any]
|
| 410 |
+
) -> None:
|
| 411 |
+
from daggr import _client_cache
|
| 412 |
+
|
| 413 |
+
api_name = self._api_name or "/predict"
|
| 414 |
+
if not api_name.startswith("/"):
|
| 415 |
+
api_name = "/" + api_name
|
| 416 |
+
|
| 417 |
+
cache_key = (
|
| 418 |
+
self._src,
|
| 419 |
+
api_name,
|
| 420 |
+
tuple(sorted(inputs.keys())),
|
| 421 |
+
tuple(sorted(outputs.keys())) if outputs else (),
|
| 422 |
+
)
|
| 423 |
+
if _client_cache.is_validated(cache_key):
|
| 424 |
+
return
|
| 425 |
+
|
| 426 |
+
api_info = self._get_api_info()
|
| 427 |
+
|
| 428 |
+
named_endpoints = api_info.get("named_endpoints", {})
|
| 429 |
+
unnamed_endpoints = api_info.get("unnamed_endpoints", {})
|
| 430 |
+
|
| 431 |
+
endpoint_info = None
|
| 432 |
+
if api_name in named_endpoints:
|
| 433 |
+
endpoint_info = named_endpoints[api_name]
|
| 434 |
+
else:
|
| 435 |
+
try:
|
| 436 |
+
fn_index = int(api_name.lstrip("/"))
|
| 437 |
+
if fn_index in unnamed_endpoints or str(fn_index) in unnamed_endpoints:
|
| 438 |
+
endpoint_info = unnamed_endpoints.get(
|
| 439 |
+
fn_index, unnamed_endpoints.get(str(fn_index))
|
| 440 |
+
)
|
| 441 |
+
except ValueError:
|
| 442 |
+
pass
|
| 443 |
+
|
| 444 |
+
if endpoint_info is None:
|
| 445 |
+
available = list(named_endpoints.keys())
|
| 446 |
+
if unnamed_endpoints:
|
| 447 |
+
available.extend([f"/{k}" for k in unnamed_endpoints.keys()])
|
| 448 |
+
suggested = suggest_similar(api_name, set(available))
|
| 449 |
+
msg = (
|
| 450 |
+
f"API endpoint '{api_name}' not found in '{self._src}'. "
|
| 451 |
+
f"Available endpoints: {available}"
|
| 452 |
+
)
|
| 453 |
+
if suggested:
|
| 454 |
+
msg += f" Did you mean '{suggested}'?"
|
| 455 |
+
raise ValueError(msg)
|
| 456 |
+
|
| 457 |
+
params_info = endpoint_info.get("parameters", [])
|
| 458 |
+
valid_params = {p.get("parameter_name", p["label"]) for p in params_info}
|
| 459 |
+
input_params = set(inputs.keys())
|
| 460 |
+
invalid_params = input_params - valid_params
|
| 461 |
+
|
| 462 |
+
if invalid_params:
|
| 463 |
+
suggestions = {}
|
| 464 |
+
for inv in invalid_params:
|
| 465 |
+
suggestion = suggest_similar(inv, valid_params)
|
| 466 |
+
if suggestion:
|
| 467 |
+
suggestions[inv] = suggestion
|
| 468 |
+
msg = (
|
| 469 |
+
f"Invalid parameter(s) {invalid_params} for endpoint '{api_name}' "
|
| 470 |
+
f"in '{self._src}'."
|
| 471 |
+
)
|
| 472 |
+
if suggestions:
|
| 473 |
+
suggestion_str = ", ".join(
|
| 474 |
+
f"'{k}' -> '{v}'" for k, v in suggestions.items()
|
| 475 |
+
)
|
| 476 |
+
msg += f" Did you mean: {suggestion_str}?"
|
| 477 |
+
msg += f" Valid parameters: {valid_params}"
|
| 478 |
+
raise ValueError(msg)
|
| 479 |
+
|
| 480 |
+
required_params = {
|
| 481 |
+
p.get("parameter_name", p["label"])
|
| 482 |
+
for p in params_info
|
| 483 |
+
if not p.get("parameter_has_default", False)
|
| 484 |
+
}
|
| 485 |
+
provided_params = set(inputs.keys())
|
| 486 |
+
missing_required = required_params - provided_params
|
| 487 |
+
|
| 488 |
+
if missing_required:
|
| 489 |
+
raise ValueError(
|
| 490 |
+
f"Missing required parameter(s) {missing_required} for endpoint "
|
| 491 |
+
f"'{api_name}' in '{self._src}'. These parameters have no default values."
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
api_returns = endpoint_info.get("returns", [])
|
| 495 |
+
if outputs and api_returns and not self._postprocess:
|
| 496 |
+
num_returns = len(api_returns)
|
| 497 |
+
num_outputs = len(outputs)
|
| 498 |
+
if num_outputs > num_returns:
|
| 499 |
+
warnings.warn(
|
| 500 |
+
f"GradioNode '{self._name}' defines {num_outputs} outputs but "
|
| 501 |
+
f"endpoint '{api_name}' only returns {num_returns} value(s). "
|
| 502 |
+
f"Extra outputs will be None."
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
_client_cache.mark_validated(cache_key)
|
| 506 |
+
|
| 507 |
+
|
| 508 |
+
class InferenceNode(Node):
|
| 509 |
+
"""A node that wraps a Hugging Face Inference API model.
|
| 510 |
+
|
| 511 |
+
InferenceNode uses the Hugging Face Inference API to run models without
|
| 512 |
+
needing to download them locally. The task type (text-generation, text-to-image,
|
| 513 |
+
etc.) is automatically determined from the model's pipeline_tag on the Hub.
|
| 514 |
+
|
| 515 |
+
Args:
|
| 516 |
+
model: The Hugging Face model ID (e.g., "meta-llama/Llama-2-7b-chat-hf").
|
| 517 |
+
name: Optional display name for the node.
|
| 518 |
+
inputs: Dict mapping input port names to values or components.
|
| 519 |
+
outputs: Dict mapping output port names to components.
|
| 520 |
+
validate: Whether to validate the model exists on the Hub.
|
| 521 |
+
preprocess: Optional function that receives the input dict and returns a
|
| 522 |
+
modified dict before the inference call.
|
| 523 |
+
postprocess: Optional function that receives the raw inference result and
|
| 524 |
+
returns a transformed value before it is mapped to output ports.
|
| 525 |
+
|
| 526 |
+
Example:
|
| 527 |
+
>>> llm = InferenceNode("meta-llama/Llama-2-7b-chat-hf")
|
| 528 |
+
"""
|
| 529 |
+
|
| 530 |
+
def __init__(
|
| 531 |
+
self,
|
| 532 |
+
model: str,
|
| 533 |
+
name: str | None = None,
|
| 534 |
+
inputs: dict[str, Any] | None = None,
|
| 535 |
+
outputs: dict[str, Any] | None = None,
|
| 536 |
+
validate: bool = True,
|
| 537 |
+
preprocess: Callable[[dict], dict] | None = None,
|
| 538 |
+
postprocess: Callable[..., Any] | None = None,
|
| 539 |
+
):
|
| 540 |
+
super().__init__(name)
|
| 541 |
+
self._model = model
|
| 542 |
+
self._task: str | None = None
|
| 543 |
+
self._task_fetched: bool = False
|
| 544 |
+
self._preprocess = preprocess
|
| 545 |
+
self._postprocess = postprocess
|
| 546 |
+
|
| 547 |
+
if not self._name:
|
| 548 |
+
# Strip provider tag (e.g., ":replicate") for display name
|
| 549 |
+
self._name = self._model_name_for_hub.split("/")[-1]
|
| 550 |
+
|
| 551 |
+
if inputs:
|
| 552 |
+
self._process_inputs(inputs)
|
| 553 |
+
else:
|
| 554 |
+
self._input_ports = ["input"]
|
| 555 |
+
|
| 556 |
+
if outputs:
|
| 557 |
+
self._process_outputs(outputs)
|
| 558 |
+
else:
|
| 559 |
+
self._output_ports = ["output"]
|
| 560 |
+
|
| 561 |
+
self._validate_ports()
|
| 562 |
+
|
| 563 |
+
if validate:
|
| 564 |
+
self._fetch_model_info()
|
| 565 |
+
|
| 566 |
+
@property
|
| 567 |
+
def _model_name_for_hub(self) -> str:
|
| 568 |
+
"""Return the model name without provider tags (e.g., ':replicate')."""
|
| 569 |
+
# HF Inference Client allows tags like "model:provider" for routing
|
| 570 |
+
# Strip these for Hub API calls and display
|
| 571 |
+
return self._model.split(":")[0]
|
| 572 |
+
|
| 573 |
+
@property
|
| 574 |
+
def _provider(self) -> str | None:
|
| 575 |
+
"""Return the provider tag if specified (e.g., 'replicate' from 'model:replicate')."""
|
| 576 |
+
parts = self._model.split(":")
|
| 577 |
+
return parts[1] if len(parts) > 1 else None
|
| 578 |
+
|
| 579 |
+
def _fetch_model_info(self) -> None:
|
| 580 |
+
if self._task_fetched:
|
| 581 |
+
return
|
| 582 |
+
|
| 583 |
+
from daggr import _client_cache
|
| 584 |
+
|
| 585 |
+
# Use model name without provider tag for Hub lookups
|
| 586 |
+
hub_model = self._model_name_for_hub
|
| 587 |
+
|
| 588 |
+
found_in_cache, cached = _client_cache.get_model_task(hub_model)
|
| 589 |
+
if found_in_cache:
|
| 590 |
+
if cached == "__NOT_FOUND__":
|
| 591 |
+
raise ValueError(f"Model '{hub_model}' not found on Hugging Face Hub.")
|
| 592 |
+
self._task = cached
|
| 593 |
+
self._task_fetched = True
|
| 594 |
+
return
|
| 595 |
+
|
| 596 |
+
from huggingface_hub import model_info
|
| 597 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 598 |
+
|
| 599 |
+
try:
|
| 600 |
+
info = model_info(hub_model)
|
| 601 |
+
self._task = info.pipeline_tag
|
| 602 |
+
_client_cache.set_model_task(hub_model, self._task)
|
| 603 |
+
self._task_fetched = True
|
| 604 |
+
except RepositoryNotFoundError:
|
| 605 |
+
_client_cache.set_model_not_found(hub_model)
|
| 606 |
+
raise ValueError(
|
| 607 |
+
f"Model '{hub_model}' not found on Hugging Face Hub. "
|
| 608 |
+
f"Please check the model name is correct (format: 'username/model-name')."
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
|
| 612 |
+
class FnNode(Node):
|
| 613 |
+
"""A node that wraps a Python function.
|
| 614 |
+
|
| 615 |
+
FnNode allows you to use any Python function as a node in the graph.
|
| 616 |
+
Input ports are automatically discovered from the function signature.
|
| 617 |
+
|
| 618 |
+
Return values are mapped to output ports in order, just like GradioNode:
|
| 619 |
+
- Single value: maps to the first output port
|
| 620 |
+
- Tuple: each element maps to the corresponding output port in order
|
| 621 |
+
|
| 622 |
+
Concurrency:
|
| 623 |
+
By default, FnNodes execute sequentially (one at a time per session)
|
| 624 |
+
to prevent resource contention. Use the concurrency parameters to
|
| 625 |
+
allow parallel execution:
|
| 626 |
+
|
| 627 |
+
- concurrent=True: Allow this node to run in parallel with others
|
| 628 |
+
- concurrency_group: Group nodes that share a resource (e.g., GPU)
|
| 629 |
+
- max_concurrent: Max parallel executions within a group (default: 1)
|
| 630 |
+
|
| 631 |
+
Note: GradioNode and InferenceNode always run concurrently since they
|
| 632 |
+
are external API calls. Prefer these over FnNode when possible.
|
| 633 |
+
|
| 634 |
+
Args:
|
| 635 |
+
fn: The Python function to wrap.
|
| 636 |
+
name: Optional display name. Defaults to the function name.
|
| 637 |
+
inputs: Optional dict to explicitly define input ports and their
|
| 638 |
+
connections or UI components.
|
| 639 |
+
outputs: Optional dict mapping output port names to UI components
|
| 640 |
+
or ItemList schemas.
|
| 641 |
+
concurrent: If True, allow parallel execution. Default: False.
|
| 642 |
+
concurrency_group: Name of a group sharing a concurrency limit.
|
| 643 |
+
max_concurrent: Max parallel executions in the group. Default: 1.
|
| 644 |
+
|
| 645 |
+
Example:
|
| 646 |
+
>>> def process_text(text: str) -> tuple[str, int]:
|
| 647 |
+
... return text.upper(), len(text)
|
| 648 |
+
>>> node = FnNode(
|
| 649 |
+
... process_text,
|
| 650 |
+
... outputs={"uppercase": gr.Textbox(), "length": gr.Number()}
|
| 651 |
+
... )
|
| 652 |
+
|
| 653 |
+
>>> # Allow parallel execution
|
| 654 |
+
>>> node = FnNode(my_func, concurrent=True)
|
| 655 |
+
|
| 656 |
+
>>> # Share GPU with other nodes (max 2 concurrent)
|
| 657 |
+
>>> node = FnNode(gpu_func, concurrency_group="gpu", max_concurrent=2)
|
| 658 |
+
"""
|
| 659 |
+
|
| 660 |
+
def __init__(
|
| 661 |
+
self,
|
| 662 |
+
fn: Callable,
|
| 663 |
+
name: str | None = None,
|
| 664 |
+
inputs: dict[str, Any] | None = None,
|
| 665 |
+
outputs: dict[str, Any] | None = None,
|
| 666 |
+
preprocess: Callable[[dict], dict] | None = None,
|
| 667 |
+
postprocess: Callable[..., Any] | None = None,
|
| 668 |
+
concurrent: bool = False,
|
| 669 |
+
concurrency_group: str | None = None,
|
| 670 |
+
max_concurrent: int = 1,
|
| 671 |
+
):
|
| 672 |
+
super().__init__(name)
|
| 673 |
+
self._fn = fn
|
| 674 |
+
self._preprocess = preprocess
|
| 675 |
+
self._postprocess = postprocess
|
| 676 |
+
self._concurrent = concurrent
|
| 677 |
+
self._concurrency_group = concurrency_group
|
| 678 |
+
self._max_concurrent = max_concurrent
|
| 679 |
+
|
| 680 |
+
if not self._name:
|
| 681 |
+
self._name = self._fn.__name__
|
| 682 |
+
|
| 683 |
+
if inputs:
|
| 684 |
+
self._validate_fn_inputs(inputs)
|
| 685 |
+
self._process_inputs(inputs)
|
| 686 |
+
else:
|
| 687 |
+
self._discover_signature()
|
| 688 |
+
|
| 689 |
+
if outputs:
|
| 690 |
+
self._process_outputs(outputs)
|
| 691 |
+
else:
|
| 692 |
+
self._output_ports = ["output"]
|
| 693 |
+
|
| 694 |
+
self._validate_ports()
|
| 695 |
+
|
| 696 |
+
def _discover_signature(self):
|
| 697 |
+
sig = inspect.signature(self._fn)
|
| 698 |
+
self._input_ports = list(sig.parameters.keys())
|
| 699 |
+
|
| 700 |
+
def _validate_fn_inputs(self, inputs: dict[str, Any]) -> None:
|
| 701 |
+
sig = inspect.signature(self._fn)
|
| 702 |
+
valid_params = set(sig.parameters.keys())
|
| 703 |
+
provided_params = set(inputs.keys())
|
| 704 |
+
invalid_params = provided_params - valid_params
|
| 705 |
+
|
| 706 |
+
if invalid_params:
|
| 707 |
+
suggestions = {}
|
| 708 |
+
for inv in invalid_params:
|
| 709 |
+
suggestion = suggest_similar(inv, valid_params)
|
| 710 |
+
if suggestion:
|
| 711 |
+
suggestions[inv] = suggestion
|
| 712 |
+
|
| 713 |
+
msg = (
|
| 714 |
+
f"Invalid input(s) {invalid_params} for function '{self._fn.__name__}'."
|
| 715 |
+
)
|
| 716 |
+
if suggestions:
|
| 717 |
+
suggestion_str = ", ".join(
|
| 718 |
+
f"'{k}' -> '{v}'" for k, v in suggestions.items()
|
| 719 |
+
)
|
| 720 |
+
msg += f" Did you mean: {suggestion_str}?"
|
| 721 |
+
msg += f" Valid parameters: {valid_params}"
|
| 722 |
+
raise ValueError(msg)
|
| 723 |
+
|
| 724 |
+
def _process_outputs(self, outputs: dict[str, Any]) -> None:
|
| 725 |
+
for port_name, component in outputs.items():
|
| 726 |
+
self._output_ports.append(port_name)
|
| 727 |
+
if component is None:
|
| 728 |
+
continue
|
| 729 |
+
if isinstance(component, ItemList):
|
| 730 |
+
self._item_list_schemas[port_name] = component.schema
|
| 731 |
+
elif _is_gradio_component(component):
|
| 732 |
+
self._output_components[port_name] = component
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
class InteractionNode(Node):
|
| 736 |
+
"""A node representing a user interaction point in the graph.
|
| 737 |
+
|
| 738 |
+
InteractionNodes pause execution and wait for user input before continuing.
|
| 739 |
+
They are used for approval steps, selections, or other human-in-the-loop
|
| 740 |
+
interactions.
|
| 741 |
+
|
| 742 |
+
Args:
|
| 743 |
+
name: Optional display name for the node.
|
| 744 |
+
interaction_type: Type of interaction (e.g., "generic", "approve", "choose_one").
|
| 745 |
+
inputs: Dict mapping input port names to components or connections.
|
| 746 |
+
outputs: Dict mapping output port names to components.
|
| 747 |
+
"""
|
| 748 |
+
|
| 749 |
+
def __init__(
|
| 750 |
+
self,
|
| 751 |
+
name: str | None = None,
|
| 752 |
+
interaction_type: str = "generic",
|
| 753 |
+
inputs: dict[str, Any] | None = None,
|
| 754 |
+
outputs: dict[str, Any] | None = None,
|
| 755 |
+
):
|
| 756 |
+
super().__init__(name)
|
| 757 |
+
self._interaction_type = interaction_type
|
| 758 |
+
|
| 759 |
+
if inputs:
|
| 760 |
+
self._process_inputs(inputs)
|
| 761 |
+
else:
|
| 762 |
+
self._input_ports = ["input"]
|
| 763 |
+
|
| 764 |
+
if outputs:
|
| 765 |
+
self._process_outputs(outputs)
|
| 766 |
+
else:
|
| 767 |
+
self._output_ports = ["output"]
|
| 768 |
+
|
| 769 |
+
if not self._name:
|
| 770 |
+
self._name = f"interaction_{self._id}"
|
| 771 |
+
|
| 772 |
+
self._validate_ports()
|
daggr/ops.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from daggr.node import InteractionNode
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ChooseOne(InteractionNode):
|
| 7 |
+
_instance_counter = 0
|
| 8 |
+
|
| 9 |
+
def __init__(self, name: str | None = None):
|
| 10 |
+
ChooseOne._instance_counter += 1
|
| 11 |
+
super().__init__(
|
| 12 |
+
name=name or f"choose_one_{ChooseOne._instance_counter}",
|
| 13 |
+
interaction_type="choose_one",
|
| 14 |
+
)
|
| 15 |
+
self._input_ports = ["options"]
|
| 16 |
+
self._output_ports = ["selected"]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class Approve(InteractionNode):
|
| 20 |
+
_instance_counter = 0
|
| 21 |
+
|
| 22 |
+
def __init__(self, name: str | None = None):
|
| 23 |
+
Approve._instance_counter += 1
|
| 24 |
+
super().__init__(
|
| 25 |
+
name=name or f"approve_{Approve._instance_counter}",
|
| 26 |
+
interaction_type="approve",
|
| 27 |
+
)
|
| 28 |
+
self._input_ports = ["input"]
|
| 29 |
+
self._output_ports = ["output"]
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TextInput(InteractionNode):
|
| 33 |
+
_instance_counter = 0
|
| 34 |
+
|
| 35 |
+
def __init__(self, name: str | None = None, label: str = "Input"):
|
| 36 |
+
TextInput._instance_counter += 1
|
| 37 |
+
super().__init__(
|
| 38 |
+
name=name or f"text_input_{TextInput._instance_counter}",
|
| 39 |
+
interaction_type="text_input",
|
| 40 |
+
)
|
| 41 |
+
self._label = label
|
| 42 |
+
self._input_ports = []
|
| 43 |
+
self._output_ports = ["text"]
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ImageInput(InteractionNode):
|
| 47 |
+
_instance_counter = 0
|
| 48 |
+
|
| 49 |
+
def __init__(self, name: str | None = None, label: str = "Image"):
|
| 50 |
+
ImageInput._instance_counter += 1
|
| 51 |
+
super().__init__(
|
| 52 |
+
name=name or f"image_input_{ImageInput._instance_counter}",
|
| 53 |
+
interaction_type="image_input",
|
| 54 |
+
)
|
| 55 |
+
self._label = label
|
| 56 |
+
self._input_ports = []
|
| 57 |
+
self._output_ports = ["image"]
|
daggr/package.json
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "daggr",
|
| 3 |
+
"version": "0.7.0",
|
| 4 |
+
"description": "",
|
| 5 |
+
"python": "true"
|
| 6 |
+
}
|
daggr/port.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Port module for node input/output definitions.
|
| 2 |
+
|
| 3 |
+
Ports are named connection points on nodes. Output ports can be connected
|
| 4 |
+
to input ports to form edges in the graph.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import TYPE_CHECKING, Any
|
| 10 |
+
|
| 11 |
+
if TYPE_CHECKING:
|
| 12 |
+
from daggr.node import Node
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class Port:
|
| 16 |
+
"""A named connection point on a node.
|
| 17 |
+
|
| 18 |
+
Ports represent inputs or outputs of a node. Access them as attributes
|
| 19 |
+
on a node: `node.port_name`.
|
| 20 |
+
|
| 21 |
+
Attributes:
|
| 22 |
+
node: The node this port belongs to.
|
| 23 |
+
name: The name of the port.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, node: Node, name: str):
|
| 27 |
+
self.node = node
|
| 28 |
+
self.name = name
|
| 29 |
+
|
| 30 |
+
def __repr__(self):
|
| 31 |
+
return f"Port({self.node._name}.{self.name})"
|
| 32 |
+
|
| 33 |
+
def _as_source(self) -> tuple[Node, str]:
|
| 34 |
+
return (self.node, self.name)
|
| 35 |
+
|
| 36 |
+
def _as_target(self) -> tuple[Node, str]:
|
| 37 |
+
return (self.node, self.name)
|
| 38 |
+
|
| 39 |
+
def __getattr__(self, attr: str) -> ScatteredPort:
|
| 40 |
+
if attr.startswith("_"):
|
| 41 |
+
raise AttributeError(attr)
|
| 42 |
+
if (
|
| 43 |
+
hasattr(self.node, "_item_list_schemas")
|
| 44 |
+
and self.name in self.node._item_list_schemas
|
| 45 |
+
):
|
| 46 |
+
schema = self.node._item_list_schemas[self.name]
|
| 47 |
+
if attr in schema:
|
| 48 |
+
return ScatteredPort(self, attr)
|
| 49 |
+
raise AttributeError(f"Port '{self.name}' has no attribute '{attr}'")
|
| 50 |
+
|
| 51 |
+
@property
|
| 52 |
+
def each(self) -> ScatteredPort:
|
| 53 |
+
"""Scatter this port's output - run the downstream node once per item in the list."""
|
| 54 |
+
return ScatteredPort(self)
|
| 55 |
+
|
| 56 |
+
def all(self) -> GatheredPort:
|
| 57 |
+
"""Gather outputs from a scattered node back into a list."""
|
| 58 |
+
return GatheredPort(self)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class ScatteredPort:
|
| 62 |
+
"""A port that scatters its list output to run downstream nodes per-item.
|
| 63 |
+
|
| 64 |
+
Created by accessing `.each` on a port. When connected to a downstream
|
| 65 |
+
node, that node will be executed once for each item in the list.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
def __init__(self, port: Port, item_key: str | None = None):
|
| 69 |
+
self.port = port
|
| 70 |
+
self.item_key = item_key
|
| 71 |
+
|
| 72 |
+
@property
|
| 73 |
+
def node(self):
|
| 74 |
+
return self.port.node
|
| 75 |
+
|
| 76 |
+
@property
|
| 77 |
+
def name(self):
|
| 78 |
+
return self.port.name
|
| 79 |
+
|
| 80 |
+
def __getitem__(self, key: str) -> ScatteredPort:
|
| 81 |
+
"""Access a specific field from each scattered item (e.g., dialogue.json.each["text"])."""
|
| 82 |
+
return ScatteredPort(self.port, key)
|
| 83 |
+
|
| 84 |
+
def __repr__(self):
|
| 85 |
+
if self.item_key:
|
| 86 |
+
return f"ScatteredPort({self.port}['{self.item_key}'])"
|
| 87 |
+
return f"ScatteredPort({self.port})"
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
class GatheredPort:
|
| 91 |
+
"""A port that gathers scattered results back into a list.
|
| 92 |
+
|
| 93 |
+
Created by calling `.all()` on a port. Collects results from all
|
| 94 |
+
scattered executions back into a single list.
|
| 95 |
+
"""
|
| 96 |
+
|
| 97 |
+
def __init__(self, port: Port):
|
| 98 |
+
self.port = port
|
| 99 |
+
|
| 100 |
+
@property
|
| 101 |
+
def node(self):
|
| 102 |
+
return self.port.node
|
| 103 |
+
|
| 104 |
+
@property
|
| 105 |
+
def name(self):
|
| 106 |
+
return self.port.name
|
| 107 |
+
|
| 108 |
+
def __repr__(self):
|
| 109 |
+
return f"GatheredPort({self.port})"
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
PortLike = Port | ScatteredPort | GatheredPort
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def is_port(obj: Any) -> bool:
|
| 116 |
+
"""Check if an object is a Port, ScatteredPort, or GatheredPort."""
|
| 117 |
+
return isinstance(obj, (Port, ScatteredPort, GatheredPort))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class PortNamespace:
|
| 121 |
+
"""A namespace for accessing ports that start with underscores.
|
| 122 |
+
|
| 123 |
+
Used via `node._inputs` or `node._outputs` to access ports whose names
|
| 124 |
+
start with underscores (which can't be accessed directly as attributes).
|
| 125 |
+
"""
|
| 126 |
+
|
| 127 |
+
def __init__(self, node: Node, port_names: list[str]):
|
| 128 |
+
self._node = node
|
| 129 |
+
self._names = set(port_names)
|
| 130 |
+
|
| 131 |
+
def __getattr__(self, name: str) -> Port:
|
| 132 |
+
if name.startswith("_"):
|
| 133 |
+
raise AttributeError(name)
|
| 134 |
+
return Port(self._node, name)
|
| 135 |
+
|
| 136 |
+
def __dir__(self) -> list[str]:
|
| 137 |
+
return list(self._names)
|
| 138 |
+
|
| 139 |
+
def __repr__(self):
|
| 140 |
+
return f"PortNamespace({list(self._names)})"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class ItemList:
|
| 144 |
+
"""Define an editable list output with per-item schema.
|
| 145 |
+
|
| 146 |
+
Example:
|
| 147 |
+
outputs={
|
| 148 |
+
"items": ItemList(
|
| 149 |
+
speaker=gr.Dropdown(choices=["Host", "Guest"]),
|
| 150 |
+
text=gr.Textbox(lines=2),
|
| 151 |
+
),
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
The function should return a list of dicts matching the schema keys.
|
| 155 |
+
"""
|
| 156 |
+
|
| 157 |
+
def __init__(self, **schema):
|
| 158 |
+
self.schema = schema
|
daggr/py.typed
ADDED
|
File without changes
|
daggr/server.py
ADDED
|
@@ -0,0 +1,1946 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import base64
|
| 5 |
+
import json
|
| 6 |
+
import mimetypes
|
| 7 |
+
import os
|
| 8 |
+
import secrets
|
| 9 |
+
import socket
|
| 10 |
+
import tempfile
|
| 11 |
+
import threading
|
| 12 |
+
import time
|
| 13 |
+
import traceback
|
| 14 |
+
import uuid
|
| 15 |
+
import webbrowser
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from typing import TYPE_CHECKING, Any
|
| 18 |
+
|
| 19 |
+
import uvicorn
|
| 20 |
+
from fastapi import FastAPI, Header, Request, WebSocket, WebSocketDisconnect
|
| 21 |
+
from fastapi.responses import (
|
| 22 |
+
FileResponse,
|
| 23 |
+
HTMLResponse,
|
| 24 |
+
JSONResponse,
|
| 25 |
+
PlainTextResponse,
|
| 26 |
+
Response,
|
| 27 |
+
)
|
| 28 |
+
from gradio_client.utils import is_file_obj_with_meta
|
| 29 |
+
|
| 30 |
+
from daggr.executor import AsyncExecutor, FileValue
|
| 31 |
+
from daggr.node import (
|
| 32 |
+
_FILE_TYPE_COMPONENTS,
|
| 33 |
+
ChoiceNode,
|
| 34 |
+
GradioNode,
|
| 35 |
+
InferenceNode,
|
| 36 |
+
InteractionNode,
|
| 37 |
+
)
|
| 38 |
+
from daggr.session import ExecutionSession
|
| 39 |
+
from daggr.state import SessionState, get_daggr_cache_dir
|
| 40 |
+
|
| 41 |
+
_FILE_COMP_TYPES = {c.lower() for c in _FILE_TYPE_COMPONENTS}
|
| 42 |
+
|
| 43 |
+
if TYPE_CHECKING:
|
| 44 |
+
from gradio.themes import Base as Theme
|
| 45 |
+
|
| 46 |
+
from daggr.graph import Graph
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
INITIAL_PORT_VALUE = int(os.getenv("DAGGR_SERVER_PORT", "7860"))
|
| 50 |
+
TRY_NUM_PORTS = int(os.getenv("DAGGR_NUM_PORTS", "100"))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def _find_available_port(host: str, start_port: int) -> int:
|
| 54 |
+
"""Find an available port starting from start_port."""
|
| 55 |
+
for port in range(start_port, start_port + TRY_NUM_PORTS):
|
| 56 |
+
try:
|
| 57 |
+
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
| 58 |
+
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
| 59 |
+
s.bind((host if host != "0.0.0.0" else "127.0.0.1", port))
|
| 60 |
+
s.close()
|
| 61 |
+
return port
|
| 62 |
+
except OSError:
|
| 63 |
+
continue
|
| 64 |
+
raise OSError(
|
| 65 |
+
f"Cannot find empty port in range: {start_port}-{start_port + TRY_NUM_PORTS - 1}. "
|
| 66 |
+
f"You can specify a different port by setting the DAGGR_SERVER_PORT environment variable "
|
| 67 |
+
f"or passing the port parameter to launch()."
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def _get_theme(theme: "Theme | str | None") -> "Theme":
|
| 72 |
+
"""Get a Gradio theme instance from a theme specification.
|
| 73 |
+
|
| 74 |
+
Args:
|
| 75 |
+
theme: Can be a Theme instance, a string name like "default", "soft",
|
| 76 |
+
"monochrome", "glass", or a Hub theme like "gradio/seafoam".
|
| 77 |
+
|
| 78 |
+
Returns:
|
| 79 |
+
A Theme instance.
|
| 80 |
+
"""
|
| 81 |
+
from gradio.themes import Default
|
| 82 |
+
|
| 83 |
+
if theme is None:
|
| 84 |
+
return Default()
|
| 85 |
+
|
| 86 |
+
if isinstance(theme, str):
|
| 87 |
+
from gradio.themes import Base, Default, Glass, Monochrome, Soft
|
| 88 |
+
|
| 89 |
+
theme_mapping = {
|
| 90 |
+
"default": Default,
|
| 91 |
+
"soft": Soft,
|
| 92 |
+
"monochrome": Monochrome,
|
| 93 |
+
"glass": Glass,
|
| 94 |
+
"base": Base,
|
| 95 |
+
}
|
| 96 |
+
theme_lower = theme.lower()
|
| 97 |
+
if theme_lower in theme_mapping:
|
| 98 |
+
return theme_mapping[theme_lower]()
|
| 99 |
+
# Try loading from Hub
|
| 100 |
+
try:
|
| 101 |
+
return Base.from_hub(theme)
|
| 102 |
+
except Exception:
|
| 103 |
+
return Default()
|
| 104 |
+
|
| 105 |
+
return theme
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class DaggrServer:
|
| 109 |
+
def __init__(
|
| 110 |
+
self,
|
| 111 |
+
graph: Graph,
|
| 112 |
+
theme: "Theme | str | None" = None,
|
| 113 |
+
api_server: bool = True,
|
| 114 |
+
):
|
| 115 |
+
self.graph = graph
|
| 116 |
+
self.api_server = api_server
|
| 117 |
+
self.executor = AsyncExecutor(graph)
|
| 118 |
+
self.state = SessionState(db_path=os.environ.get("DAGGR_DB_PATH"))
|
| 119 |
+
self.app = FastAPI(title=graph.name)
|
| 120 |
+
self.connections: dict[str, WebSocket] = {}
|
| 121 |
+
self.theme = _get_theme(theme)
|
| 122 |
+
self.theme_css = self.theme._get_theme_css()
|
| 123 |
+
self._setup_routes()
|
| 124 |
+
|
| 125 |
+
def _extract_token_from_header(self, authorization: str | None) -> str | None:
|
| 126 |
+
if authorization and authorization.startswith("Bearer "):
|
| 127 |
+
return authorization[7:]
|
| 128 |
+
return None
|
| 129 |
+
|
| 130 |
+
def _validate_hf_token(self, token: str) -> dict | None:
|
| 131 |
+
try:
|
| 132 |
+
from huggingface_hub import whoami
|
| 133 |
+
|
| 134 |
+
info = whoami(token=token, cache=True)
|
| 135 |
+
return {
|
| 136 |
+
"username": info.get("name"),
|
| 137 |
+
"fullname": info.get("fullname"),
|
| 138 |
+
"avatar_url": info.get("avatarUrl"),
|
| 139 |
+
}
|
| 140 |
+
except Exception:
|
| 141 |
+
return None
|
| 142 |
+
|
| 143 |
+
def _setup_routes(self):
|
| 144 |
+
frontend_dir = Path(__file__).parent / "frontend" / "dist"
|
| 145 |
+
if not frontend_dir.exists():
|
| 146 |
+
raise RuntimeError(
|
| 147 |
+
f"Frontend not found at {frontend_dir}. "
|
| 148 |
+
"If developing, run 'npm run build' in daggr/frontend/"
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
@self.app.get("/theme.css", response_class=PlainTextResponse)
|
| 152 |
+
async def get_theme_css():
|
| 153 |
+
return PlainTextResponse(self.theme_css, media_type="text/css")
|
| 154 |
+
|
| 155 |
+
@self.app.get("/api/graph")
|
| 156 |
+
async def get_graph():
|
| 157 |
+
return self._build_graph_data()
|
| 158 |
+
|
| 159 |
+
@self.app.get("/api/hf_user")
|
| 160 |
+
async def get_hf_user():
|
| 161 |
+
return self._get_hf_user_info()
|
| 162 |
+
|
| 163 |
+
@self.app.get("/api/user_info")
|
| 164 |
+
async def get_user_info(authorization: str | None = Header(default=None)):
|
| 165 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 166 |
+
if browser_token:
|
| 167 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 168 |
+
else:
|
| 169 |
+
hf_user = self._get_hf_user_info()
|
| 170 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 171 |
+
is_on_spaces = os.environ.get("SPACE_ID") is not None
|
| 172 |
+
persistence_enabled = self.graph.persist_key is not None
|
| 173 |
+
return {
|
| 174 |
+
"hf_user": hf_user,
|
| 175 |
+
"user_id": user_id,
|
| 176 |
+
"is_on_spaces": is_on_spaces,
|
| 177 |
+
"can_persist": user_id is not None and persistence_enabled,
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
@self.app.post("/api/auth/login")
|
| 181 |
+
async def auth_login(request: Request):
|
| 182 |
+
try:
|
| 183 |
+
body = await request.json()
|
| 184 |
+
token = body.get("token")
|
| 185 |
+
if not token:
|
| 186 |
+
return JSONResponse({"error": "Token is required"}, status_code=400)
|
| 187 |
+
hf_user = self._validate_hf_token(token)
|
| 188 |
+
if not hf_user:
|
| 189 |
+
return JSONResponse({"error": "Invalid token"}, status_code=401)
|
| 190 |
+
return {"hf_user": hf_user, "success": True}
|
| 191 |
+
except Exception as e:
|
| 192 |
+
return JSONResponse({"error": str(e)}, status_code=500)
|
| 193 |
+
|
| 194 |
+
@self.app.post("/api/auth/logout")
|
| 195 |
+
async def auth_logout():
|
| 196 |
+
return {"success": True}
|
| 197 |
+
|
| 198 |
+
@self.app.get("/api/sheets")
|
| 199 |
+
async def list_sheets(authorization: str | None = Header(default=None)):
|
| 200 |
+
if not self.graph.persist_key:
|
| 201 |
+
return {"sheets": [], "user_id": None}
|
| 202 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 203 |
+
if browser_token:
|
| 204 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 205 |
+
else:
|
| 206 |
+
hf_user = self._get_hf_user_info()
|
| 207 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 208 |
+
if not user_id:
|
| 209 |
+
return JSONResponse(
|
| 210 |
+
{"error": "Login required to access sheets on Spaces"},
|
| 211 |
+
status_code=401,
|
| 212 |
+
)
|
| 213 |
+
sheets = self.state.list_sheets(user_id, self.graph.persist_key)
|
| 214 |
+
return {"sheets": sheets, "user_id": user_id}
|
| 215 |
+
|
| 216 |
+
@self.app.post("/api/sheets")
|
| 217 |
+
async def create_sheet(
|
| 218 |
+
request: Request, authorization: str | None = Header(default=None)
|
| 219 |
+
):
|
| 220 |
+
if not self.graph.persist_key:
|
| 221 |
+
return JSONResponse(
|
| 222 |
+
{"error": "Persistence is disabled for this graph"},
|
| 223 |
+
status_code=400,
|
| 224 |
+
)
|
| 225 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 226 |
+
if browser_token:
|
| 227 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 228 |
+
else:
|
| 229 |
+
hf_user = self._get_hf_user_info()
|
| 230 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 231 |
+
if not user_id:
|
| 232 |
+
return JSONResponse(
|
| 233 |
+
{"error": "Login required to create sheets on Spaces"},
|
| 234 |
+
status_code=401,
|
| 235 |
+
)
|
| 236 |
+
body = await request.json()
|
| 237 |
+
name = body.get("name")
|
| 238 |
+
sheet_id = self.state.create_sheet(user_id, self.graph.persist_key, name)
|
| 239 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 240 |
+
return {"sheet": sheet}
|
| 241 |
+
|
| 242 |
+
@self.app.patch("/api/sheets/{sheet_id}")
|
| 243 |
+
async def rename_sheet(
|
| 244 |
+
sheet_id: str,
|
| 245 |
+
request: Request,
|
| 246 |
+
authorization: str | None = Header(default=None),
|
| 247 |
+
):
|
| 248 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 249 |
+
if browser_token:
|
| 250 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 251 |
+
else:
|
| 252 |
+
hf_user = self._get_hf_user_info()
|
| 253 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 254 |
+
if not user_id:
|
| 255 |
+
return JSONResponse({"error": "Login required"}, status_code=401)
|
| 256 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 257 |
+
if not sheet:
|
| 258 |
+
return JSONResponse({"error": "Sheet not found"}, status_code=404)
|
| 259 |
+
if sheet["user_id"] != user_id:
|
| 260 |
+
return JSONResponse({"error": "Access denied"}, status_code=403)
|
| 261 |
+
body = await request.json()
|
| 262 |
+
new_name = body.get("name")
|
| 263 |
+
if not new_name:
|
| 264 |
+
return JSONResponse({"error": "Name required"}, status_code=400)
|
| 265 |
+
self.state.rename_sheet(sheet_id, new_name)
|
| 266 |
+
return {"success": True, "sheet": self.state.get_sheet(sheet_id)}
|
| 267 |
+
|
| 268 |
+
@self.app.delete("/api/sheets/{sheet_id}")
|
| 269 |
+
async def delete_sheet(
|
| 270 |
+
sheet_id: str, authorization: str | None = Header(default=None)
|
| 271 |
+
):
|
| 272 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 273 |
+
if browser_token:
|
| 274 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 275 |
+
else:
|
| 276 |
+
hf_user = self._get_hf_user_info()
|
| 277 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 278 |
+
if not user_id:
|
| 279 |
+
return JSONResponse({"error": "Login required"}, status_code=401)
|
| 280 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 281 |
+
if not sheet:
|
| 282 |
+
return JSONResponse({"error": "Sheet not found"}, status_code=404)
|
| 283 |
+
if sheet["user_id"] != user_id:
|
| 284 |
+
return JSONResponse({"error": "Access denied"}, status_code=403)
|
| 285 |
+
self.state.delete_sheet(sheet_id)
|
| 286 |
+
return {"success": True}
|
| 287 |
+
|
| 288 |
+
@self.app.get("/api/sheets/{sheet_id}/state")
|
| 289 |
+
async def get_sheet_state(
|
| 290 |
+
sheet_id: str, authorization: str | None = Header(default=None)
|
| 291 |
+
):
|
| 292 |
+
browser_token = self._extract_token_from_header(authorization)
|
| 293 |
+
if browser_token:
|
| 294 |
+
hf_user = self._validate_hf_token(browser_token)
|
| 295 |
+
else:
|
| 296 |
+
hf_user = self._get_hf_user_info()
|
| 297 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 298 |
+
if not user_id:
|
| 299 |
+
return JSONResponse({"error": "Login required"}, status_code=401)
|
| 300 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 301 |
+
if not sheet:
|
| 302 |
+
return JSONResponse({"error": "Sheet not found"}, status_code=404)
|
| 303 |
+
if sheet["user_id"] != user_id:
|
| 304 |
+
return JSONResponse({"error": "Access denied"}, status_code=403)
|
| 305 |
+
state = self.state.get_sheet_state(sheet_id)
|
| 306 |
+
return {"sheet": sheet, "state": state}
|
| 307 |
+
|
| 308 |
+
@self.app.post("/api/run/{node_name}")
|
| 309 |
+
async def run_to_node(node_name: str, data: dict):
|
| 310 |
+
session = ExecutionSession(self.graph)
|
| 311 |
+
session_id = data.get("session_id")
|
| 312 |
+
input_values = data.get("inputs", {})
|
| 313 |
+
selected_results = data.get("selected_results", {})
|
| 314 |
+
return await self._execute_to_node(
|
| 315 |
+
session, node_name, session_id, input_values, selected_results
|
| 316 |
+
)
|
| 317 |
+
|
| 318 |
+
if self.api_server:
|
| 319 |
+
|
| 320 |
+
@self.app.get("/api/schema")
|
| 321 |
+
async def get_api_schema():
|
| 322 |
+
return self.graph.get_api_schema()
|
| 323 |
+
|
| 324 |
+
@self.app.post("/api/call")
|
| 325 |
+
async def call_workflow(request: Request):
|
| 326 |
+
return await self._execute_workflow_api(request, subgraph_id=None)
|
| 327 |
+
|
| 328 |
+
@self.app.post("/api/call/{subgraph_id}")
|
| 329 |
+
async def call_subgraph(subgraph_id: str, request: Request):
|
| 330 |
+
return await self._execute_workflow_api(
|
| 331 |
+
request, subgraph_id=subgraph_id
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
@self.app.websocket("/ws/{session_id}")
|
| 335 |
+
async def websocket_endpoint(websocket: WebSocket, session_id: str):
|
| 336 |
+
await websocket.accept()
|
| 337 |
+
self.connections[session_id] = websocket
|
| 338 |
+
|
| 339 |
+
hf_user = self._get_hf_user_info()
|
| 340 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 341 |
+
current_sheet_id: str | None = None
|
| 342 |
+
|
| 343 |
+
session = ExecutionSession(self.graph)
|
| 344 |
+
running_tasks: dict[str, asyncio.Task] = {}
|
| 345 |
+
|
| 346 |
+
async def run_node_execution(
|
| 347 |
+
node_name: str,
|
| 348 |
+
sheet_id: str | None,
|
| 349 |
+
input_values: dict,
|
| 350 |
+
item_list_values: dict,
|
| 351 |
+
selected_results: dict,
|
| 352 |
+
run_id: str,
|
| 353 |
+
user_id: str | None,
|
| 354 |
+
run_ancestors: bool = True,
|
| 355 |
+
):
|
| 356 |
+
try:
|
| 357 |
+
async for result in self._execute_to_node_streaming(
|
| 358 |
+
session,
|
| 359 |
+
node_name,
|
| 360 |
+
sheet_id,
|
| 361 |
+
input_values,
|
| 362 |
+
item_list_values,
|
| 363 |
+
selected_results,
|
| 364 |
+
run_id,
|
| 365 |
+
user_id,
|
| 366 |
+
run_ancestors,
|
| 367 |
+
):
|
| 368 |
+
await websocket.send_json(result)
|
| 369 |
+
except asyncio.CancelledError:
|
| 370 |
+
pass
|
| 371 |
+
except Exception as e:
|
| 372 |
+
await websocket.send_json(
|
| 373 |
+
{
|
| 374 |
+
"type": "error",
|
| 375 |
+
"run_id": run_id,
|
| 376 |
+
"error": str(e),
|
| 377 |
+
"node": node_name,
|
| 378 |
+
}
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
try:
|
| 382 |
+
while True:
|
| 383 |
+
data = await websocket.receive_json()
|
| 384 |
+
action = data.get("action")
|
| 385 |
+
|
| 386 |
+
if "hf_token" in data:
|
| 387 |
+
browser_hf_token = data.get("hf_token")
|
| 388 |
+
old_user_id = user_id
|
| 389 |
+
if browser_hf_token:
|
| 390 |
+
hf_user = self._validate_hf_token(browser_hf_token)
|
| 391 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 392 |
+
session.set_hf_token(browser_hf_token)
|
| 393 |
+
else:
|
| 394 |
+
hf_user = self._get_hf_user_info()
|
| 395 |
+
user_id = self.state.get_effective_user_id(hf_user)
|
| 396 |
+
session.set_hf_token(None)
|
| 397 |
+
if old_user_id != user_id:
|
| 398 |
+
session.clear_results()
|
| 399 |
+
current_sheet_id = None
|
| 400 |
+
|
| 401 |
+
if action == "run":
|
| 402 |
+
node_name = data.get("node_name")
|
| 403 |
+
input_values = data.get("inputs", {})
|
| 404 |
+
item_list_values = data.get("item_list_values", {})
|
| 405 |
+
selected_results = data.get("selected_results", {})
|
| 406 |
+
run_id = data.get("run_id")
|
| 407 |
+
sheet_id = data.get("sheet_id") or current_sheet_id
|
| 408 |
+
run_ancestors = data.get("run_ancestors", True)
|
| 409 |
+
|
| 410 |
+
task = asyncio.create_task(
|
| 411 |
+
run_node_execution(
|
| 412 |
+
node_name,
|
| 413 |
+
sheet_id,
|
| 414 |
+
input_values,
|
| 415 |
+
item_list_values,
|
| 416 |
+
selected_results,
|
| 417 |
+
run_id,
|
| 418 |
+
user_id,
|
| 419 |
+
run_ancestors,
|
| 420 |
+
)
|
| 421 |
+
)
|
| 422 |
+
running_tasks[run_id] = task
|
| 423 |
+
task.add_done_callback(
|
| 424 |
+
lambda t, rid=run_id: running_tasks.pop(rid, None)
|
| 425 |
+
)
|
| 426 |
+
|
| 427 |
+
elif action == "cancel":
|
| 428 |
+
cancel_run_id = data.get("run_id")
|
| 429 |
+
cancel_node = data.get("node_name")
|
| 430 |
+
task = running_tasks.get(cancel_run_id)
|
| 431 |
+
if task:
|
| 432 |
+
task.cancel()
|
| 433 |
+
await websocket.send_json(
|
| 434 |
+
{
|
| 435 |
+
"type": "cancelled",
|
| 436 |
+
"run_id": cancel_run_id,
|
| 437 |
+
"node": cancel_node,
|
| 438 |
+
}
|
| 439 |
+
)
|
| 440 |
+
|
| 441 |
+
elif action == "get_graph":
|
| 442 |
+
try:
|
| 443 |
+
sheet_id = data.get("sheet_id")
|
| 444 |
+
|
| 445 |
+
persisted_inputs = {}
|
| 446 |
+
persisted_results: dict[str, list[Any]] = {}
|
| 447 |
+
persisted_transform = None
|
| 448 |
+
|
| 449 |
+
if user_id and sheet_id:
|
| 450 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 451 |
+
if sheet and sheet["user_id"] == user_id:
|
| 452 |
+
current_sheet_id = sheet_id
|
| 453 |
+
state = self.state.get_sheet_state(sheet_id)
|
| 454 |
+
persisted_inputs = state.get("inputs", {})
|
| 455 |
+
persisted_results = state.get("results", {})
|
| 456 |
+
persisted_transform = sheet.get("transform")
|
| 457 |
+
|
| 458 |
+
node_results = {}
|
| 459 |
+
for node_name, results_list in persisted_results.items():
|
| 460 |
+
if results_list:
|
| 461 |
+
last_entry = results_list[-1]
|
| 462 |
+
if (
|
| 463 |
+
isinstance(last_entry, dict)
|
| 464 |
+
and "result" in last_entry
|
| 465 |
+
):
|
| 466 |
+
node_results[node_name] = last_entry["result"]
|
| 467 |
+
else:
|
| 468 |
+
node_results[node_name] = last_entry
|
| 469 |
+
|
| 470 |
+
graph_data = self._build_graph_data(
|
| 471 |
+
node_results=node_results,
|
| 472 |
+
input_values=persisted_inputs,
|
| 473 |
+
)
|
| 474 |
+
graph_data["session_id"] = session_id
|
| 475 |
+
graph_data["sheet_id"] = current_sheet_id
|
| 476 |
+
graph_data["user_id"] = user_id
|
| 477 |
+
graph_data["persisted_results"] = (
|
| 478 |
+
self._transform_persisted_results(persisted_results)
|
| 479 |
+
)
|
| 480 |
+
graph_data["transform"] = persisted_transform
|
| 481 |
+
|
| 482 |
+
await websocket.send_json(
|
| 483 |
+
{"type": "graph", "data": graph_data}
|
| 484 |
+
)
|
| 485 |
+
except Exception as e:
|
| 486 |
+
print(f"[ERROR] get_graph failed: {e}")
|
| 487 |
+
traceback.print_exc()
|
| 488 |
+
await websocket.send_json(
|
| 489 |
+
{"type": "error", "error": str(e)}
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
elif action == "save_input":
|
| 493 |
+
if user_id and current_sheet_id:
|
| 494 |
+
node_id = data.get("node_id")
|
| 495 |
+
port_name = data.get("port_name")
|
| 496 |
+
value = data.get("value")
|
| 497 |
+
if node_id and port_name is not None:
|
| 498 |
+
self.state.save_input(
|
| 499 |
+
current_sheet_id, node_id, port_name, value
|
| 500 |
+
)
|
| 501 |
+
await websocket.send_json(
|
| 502 |
+
{"type": "input_saved", "node_id": node_id}
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
elif action == "save_transform":
|
| 506 |
+
if user_id and current_sheet_id:
|
| 507 |
+
x = data.get("x", 0)
|
| 508 |
+
y = data.get("y", 0)
|
| 509 |
+
scale = data.get("scale", 1)
|
| 510 |
+
self.state.save_transform(current_sheet_id, x, y, scale)
|
| 511 |
+
|
| 512 |
+
elif action == "set_sheet":
|
| 513 |
+
sheet_id = data.get("sheet_id")
|
| 514 |
+
if user_id and sheet_id:
|
| 515 |
+
sheet = self.state.get_sheet(sheet_id)
|
| 516 |
+
if sheet and sheet["user_id"] == user_id:
|
| 517 |
+
current_sheet_id = sheet_id
|
| 518 |
+
session.clear_results()
|
| 519 |
+
await websocket.send_json(
|
| 520 |
+
{"type": "sheet_set", "sheet_id": sheet_id}
|
| 521 |
+
)
|
| 522 |
+
|
| 523 |
+
elif action == "save_variant_selection":
|
| 524 |
+
node_id = data.get("node_id")
|
| 525 |
+
variant_index = data.get("variant_index", 0)
|
| 526 |
+
if user_id and current_sheet_id and node_id is not None:
|
| 527 |
+
self.state.save_input(
|
| 528 |
+
current_sheet_id,
|
| 529 |
+
node_id,
|
| 530 |
+
"_selected_variant",
|
| 531 |
+
variant_index,
|
| 532 |
+
)
|
| 533 |
+
await websocket.send_json(
|
| 534 |
+
{
|
| 535 |
+
"type": "variant_selection_saved",
|
| 536 |
+
"node_id": node_id,
|
| 537 |
+
"variant_index": variant_index,
|
| 538 |
+
}
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
elif action == "clear_sheet":
|
| 542 |
+
if user_id and current_sheet_id:
|
| 543 |
+
self.state.clear_sheet_data(current_sheet_id)
|
| 544 |
+
await websocket.send_json({"type": "sheet_cleared"})
|
| 545 |
+
|
| 546 |
+
except WebSocketDisconnect:
|
| 547 |
+
for task in running_tasks.values():
|
| 548 |
+
task.cancel()
|
| 549 |
+
if session_id in self.connections:
|
| 550 |
+
del self.connections[session_id]
|
| 551 |
+
except Exception as e:
|
| 552 |
+
for task in running_tasks.values():
|
| 553 |
+
task.cancel()
|
| 554 |
+
print(f"[ERROR] WebSocket error: {e}")
|
| 555 |
+
traceback.print_exc()
|
| 556 |
+
|
| 557 |
+
@self.app.get("/")
|
| 558 |
+
async def serve_index():
|
| 559 |
+
index_path = frontend_dir / "index.html"
|
| 560 |
+
if index_path.exists():
|
| 561 |
+
return FileResponse(index_path)
|
| 562 |
+
return HTMLResponse(self._get_dev_html())
|
| 563 |
+
|
| 564 |
+
@self.app.get("/assets/{path:path}")
|
| 565 |
+
async def serve_assets(path: str):
|
| 566 |
+
file_path = frontend_dir / "assets" / path
|
| 567 |
+
if file_path.exists():
|
| 568 |
+
content_type, _ = mimetypes.guess_type(str(file_path))
|
| 569 |
+
return FileResponse(file_path, media_type=content_type)
|
| 570 |
+
return Response(status_code=404)
|
| 571 |
+
|
| 572 |
+
@self.app.get("/daggr-assets/{path:path}")
|
| 573 |
+
async def serve_daggr_assets(path: str):
|
| 574 |
+
assets_dir = Path(__file__).parent / "assets"
|
| 575 |
+
file_path = assets_dir / path
|
| 576 |
+
if file_path.exists():
|
| 577 |
+
content_type, _ = mimetypes.guess_type(str(file_path))
|
| 578 |
+
return FileResponse(file_path, media_type=content_type)
|
| 579 |
+
return Response(status_code=404)
|
| 580 |
+
|
| 581 |
+
@self.app.get("/file/{path:path}")
|
| 582 |
+
async def serve_local_file(path: str):
|
| 583 |
+
if len(path) >= 2 and path[1] == ":":
|
| 584 |
+
file_path = Path(path)
|
| 585 |
+
else:
|
| 586 |
+
file_path = Path("/") / path
|
| 587 |
+
temp_dir = Path(tempfile.gettempdir()).resolve()
|
| 588 |
+
daggr_cache = get_daggr_cache_dir().resolve()
|
| 589 |
+
|
| 590 |
+
try:
|
| 591 |
+
resolved = file_path.resolve()
|
| 592 |
+
is_allowed = str(resolved).startswith(str(temp_dir)) or str(
|
| 593 |
+
resolved
|
| 594 |
+
).startswith(str(daggr_cache))
|
| 595 |
+
if not is_allowed:
|
| 596 |
+
return Response(status_code=403)
|
| 597 |
+
except (ValueError, OSError):
|
| 598 |
+
return Response(status_code=403)
|
| 599 |
+
if resolved.exists() and resolved.is_file():
|
| 600 |
+
content_type, _ = mimetypes.guess_type(str(resolved))
|
| 601 |
+
return FileResponse(
|
| 602 |
+
resolved, media_type=content_type or "application/octet-stream"
|
| 603 |
+
)
|
| 604 |
+
return Response(status_code=404)
|
| 605 |
+
|
| 606 |
+
@self.app.get("/{path:path}")
|
| 607 |
+
async def serve_static(path: str):
|
| 608 |
+
if path.startswith("api/") or path.startswith("ws/"):
|
| 609 |
+
return Response(status_code=404)
|
| 610 |
+
file_path = frontend_dir / path
|
| 611 |
+
if file_path.exists() and file_path.is_file():
|
| 612 |
+
return FileResponse(file_path)
|
| 613 |
+
index_path = frontend_dir / "index.html"
|
| 614 |
+
if index_path.exists():
|
| 615 |
+
return FileResponse(index_path)
|
| 616 |
+
return HTMLResponse(self._get_dev_html())
|
| 617 |
+
|
| 618 |
+
def _get_dev_html(self) -> str:
|
| 619 |
+
return f"""<!DOCTYPE html>
|
| 620 |
+
<html lang="en">
|
| 621 |
+
<head>
|
| 622 |
+
<meta charset="UTF-8">
|
| 623 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 624 |
+
<title>{self.graph.name}</title>
|
| 625 |
+
<link rel="stylesheet" href="/theme.css">
|
| 626 |
+
<style>
|
| 627 |
+
* {{ margin: 0; box-sizing: border-box; }}
|
| 628 |
+
body {{
|
| 629 |
+
background: var(--body-background-fill, #000);
|
| 630 |
+
min-height: 100vh;
|
| 631 |
+
font-family: 'Space Grotesk', -apple-system, BlinkMacSystemFont, sans-serif;
|
| 632 |
+
overflow: hidden;
|
| 633 |
+
color: var(--body-text-color, #fff);
|
| 634 |
+
}}
|
| 635 |
+
</style>
|
| 636 |
+
<script type="module" src="http://localhost:5173/src/main.ts"></script>
|
| 637 |
+
</head>
|
| 638 |
+
<body class="dark">
|
| 639 |
+
<div id="app"></div>
|
| 640 |
+
</body>
|
| 641 |
+
</html>"""
|
| 642 |
+
|
| 643 |
+
def _get_node_url(self, node) -> str | None:
|
| 644 |
+
if isinstance(node, GradioNode):
|
| 645 |
+
src = node._src
|
| 646 |
+
if src.startswith("http://") or src.startswith("https://"):
|
| 647 |
+
return src
|
| 648 |
+
elif "/" in src:
|
| 649 |
+
return f"https://huggingface.co/spaces/{src}"
|
| 650 |
+
elif isinstance(node, InferenceNode):
|
| 651 |
+
return f"https://huggingface.co/{node._model_name_for_hub}"
|
| 652 |
+
return None
|
| 653 |
+
|
| 654 |
+
def _get_node_type(self, node, node_name: str) -> str:
|
| 655 |
+
type_map = {
|
| 656 |
+
"FnNode": "FN",
|
| 657 |
+
"TextInput": "INPUT",
|
| 658 |
+
"ImageInput": "IMAGE",
|
| 659 |
+
"ChooseOne": "SELECT",
|
| 660 |
+
"Approve": "APPROVE",
|
| 661 |
+
"GradioNode": "GRADIO",
|
| 662 |
+
"InferenceNode": "MODEL",
|
| 663 |
+
"InteractionNode": "ACTION",
|
| 664 |
+
"ChoiceNode": "CHOICE",
|
| 665 |
+
}
|
| 666 |
+
if isinstance(node, ChoiceNode):
|
| 667 |
+
return "CHOICE"
|
| 668 |
+
class_name = node.__class__.__name__
|
| 669 |
+
return type_map.get(class_name, class_name.upper())
|
| 670 |
+
|
| 671 |
+
def _has_scattered_input(self, node_name: str) -> bool:
|
| 672 |
+
for edge in self.graph._edges:
|
| 673 |
+
if edge.target_node._name == node_name and edge.is_scattered:
|
| 674 |
+
return True
|
| 675 |
+
return False
|
| 676 |
+
|
| 677 |
+
def _get_scattered_edge(self, node_name: str):
|
| 678 |
+
for edge in self.graph._edges:
|
| 679 |
+
if edge.target_node._name == node_name and edge.is_scattered:
|
| 680 |
+
return edge
|
| 681 |
+
return None
|
| 682 |
+
|
| 683 |
+
def _is_output_node(self, node_name: str) -> bool:
|
| 684 |
+
return self.graph._nx_graph.out_degree(node_name) == 0
|
| 685 |
+
|
| 686 |
+
def _is_running_locally(self, node) -> bool:
|
| 687 |
+
if not isinstance(node, GradioNode):
|
| 688 |
+
return False
|
| 689 |
+
return bool(node._run_locally and node._local_url and not node._local_failed)
|
| 690 |
+
|
| 691 |
+
def _build_variant_data(self, variant, input_values: dict) -> dict[str, Any]:
|
| 692 |
+
variant_name = variant._name
|
| 693 |
+
if isinstance(variant, GradioNode) and not variant._name_explicitly_set:
|
| 694 |
+
variant_name = f"{variant._src}"
|
| 695 |
+
if variant._api_name:
|
| 696 |
+
variant_name += f" ({variant._api_name})"
|
| 697 |
+
|
| 698 |
+
input_components = []
|
| 699 |
+
for port_name, comp in variant._input_components.items():
|
| 700 |
+
comp_data = self._serialize_component(comp, port_name)
|
| 701 |
+
input_components.append(comp_data)
|
| 702 |
+
|
| 703 |
+
output_components = []
|
| 704 |
+
for port_name, comp in variant._output_components.items():
|
| 705 |
+
if comp is None:
|
| 706 |
+
continue
|
| 707 |
+
visible = getattr(comp, "visible", True)
|
| 708 |
+
if visible is False:
|
| 709 |
+
continue
|
| 710 |
+
comp_data = self._serialize_component(comp, port_name)
|
| 711 |
+
output_components.append(comp_data)
|
| 712 |
+
|
| 713 |
+
return {
|
| 714 |
+
"name": variant_name,
|
| 715 |
+
"input_components": input_components,
|
| 716 |
+
"output_components": output_components,
|
| 717 |
+
}
|
| 718 |
+
|
| 719 |
+
def _get_component_type(self, component) -> str:
|
| 720 |
+
class_name = component.__class__.__name__
|
| 721 |
+
type_map = {
|
| 722 |
+
"Audio": "audio",
|
| 723 |
+
"Textbox": "textbox",
|
| 724 |
+
"TextArea": "textarea",
|
| 725 |
+
"JSON": "json",
|
| 726 |
+
"Chatbot": "json",
|
| 727 |
+
"Image": "image",
|
| 728 |
+
"Number": "number",
|
| 729 |
+
"Markdown": "markdown",
|
| 730 |
+
"Text": "text",
|
| 731 |
+
"Dropdown": "dropdown",
|
| 732 |
+
"Video": "video",
|
| 733 |
+
"File": "file",
|
| 734 |
+
"Model3D": "model3d",
|
| 735 |
+
"Gallery": "gallery",
|
| 736 |
+
"Slider": "slider",
|
| 737 |
+
"Radio": "radio",
|
| 738 |
+
"Checkbox": "checkbox",
|
| 739 |
+
"CheckboxGroup": "checkboxgroup",
|
| 740 |
+
"ColorPicker": "colorpicker",
|
| 741 |
+
"Label": "label",
|
| 742 |
+
"HighlightedText": "highlightedtext",
|
| 743 |
+
"Code": "code",
|
| 744 |
+
"HTML": "html",
|
| 745 |
+
"Dataframe": "dataframe",
|
| 746 |
+
}
|
| 747 |
+
return type_map.get(class_name, "text")
|
| 748 |
+
|
| 749 |
+
def _serialize_component(self, comp, port_name: str) -> dict[str, Any]:
|
| 750 |
+
comp_type = self._get_component_type(comp)
|
| 751 |
+
comp_class = comp.__class__.__name__
|
| 752 |
+
|
| 753 |
+
props = {
|
| 754 |
+
"label": getattr(comp, "label", "") or port_name,
|
| 755 |
+
"show_label": bool(getattr(comp, "label", "")),
|
| 756 |
+
"interactive": getattr(comp, "interactive", True),
|
| 757 |
+
"visible": getattr(comp, "visible", True),
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
if hasattr(comp, "placeholder"):
|
| 761 |
+
props["placeholder"] = comp.placeholder
|
| 762 |
+
if hasattr(comp, "lines"):
|
| 763 |
+
props["lines"] = comp.lines
|
| 764 |
+
if hasattr(comp, "max_lines"):
|
| 765 |
+
props["max_lines"] = comp.max_lines
|
| 766 |
+
if hasattr(comp, "type"):
|
| 767 |
+
props["type"] = comp.type
|
| 768 |
+
if hasattr(comp, "choices") and comp.choices:
|
| 769 |
+
choices = []
|
| 770 |
+
for c in comp.choices:
|
| 771 |
+
if isinstance(c, (tuple, list)) and len(c) >= 2:
|
| 772 |
+
choices.append([c[0], c[1]])
|
| 773 |
+
else:
|
| 774 |
+
choices.append([str(c), c])
|
| 775 |
+
props["choices"] = choices
|
| 776 |
+
if hasattr(comp, "minimum"):
|
| 777 |
+
props["minimum"] = comp.minimum
|
| 778 |
+
if hasattr(comp, "maximum"):
|
| 779 |
+
props["maximum"] = comp.maximum
|
| 780 |
+
if hasattr(comp, "step"):
|
| 781 |
+
props["step"] = comp.step
|
| 782 |
+
|
| 783 |
+
value = getattr(comp, "value", None)
|
| 784 |
+
if is_file_obj_with_meta(value):
|
| 785 |
+
value = self._file_to_url(value["path"])
|
| 786 |
+
|
| 787 |
+
return {
|
| 788 |
+
"component": comp_class.lower(),
|
| 789 |
+
"type": comp_type,
|
| 790 |
+
"port_name": port_name,
|
| 791 |
+
"props": props,
|
| 792 |
+
"value": value,
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
def _file_to_url(self, value: Any) -> Any:
|
| 796 |
+
if isinstance(value, str) and not value.startswith("/file/"):
|
| 797 |
+
path = Path(value)
|
| 798 |
+
if path.is_absolute() and path.exists():
|
| 799 |
+
normalized = value.replace("\\", "/")
|
| 800 |
+
if normalized.startswith("/"):
|
| 801 |
+
return f"/file{normalized}"
|
| 802 |
+
return f"/file/{normalized}"
|
| 803 |
+
return value
|
| 804 |
+
|
| 805 |
+
def _validate_file_value(self, value: Any, comp_type: str) -> str | None:
|
| 806 |
+
"""Validate that a value is appropriate for a file-type component.
|
| 807 |
+
Returns an error message if invalid, None if valid."""
|
| 808 |
+
if value is None:
|
| 809 |
+
return None
|
| 810 |
+
if isinstance(value, str):
|
| 811 |
+
return None
|
| 812 |
+
if isinstance(value, dict):
|
| 813 |
+
if "url" in value or "path" in value:
|
| 814 |
+
return None
|
| 815 |
+
keys = list(value.keys())
|
| 816 |
+
if keys:
|
| 817 |
+
return (
|
| 818 |
+
f"Expected a file path string for {comp_type}, but got a dict "
|
| 819 |
+
f"with keys {keys}. If using postprocess, extract the path: "
|
| 820 |
+
f"e.g., `postprocess=lambda x: x['{keys[0]}']`"
|
| 821 |
+
)
|
| 822 |
+
return (
|
| 823 |
+
f"Expected a file path string for {comp_type}, but got an empty dict."
|
| 824 |
+
)
|
| 825 |
+
return f"Expected a file path string for {comp_type}, but got {type(value).__name__}."
|
| 826 |
+
|
| 827 |
+
def _transform_file_paths(self, data: Any) -> Any:
|
| 828 |
+
if isinstance(data, str):
|
| 829 |
+
return self._file_to_url(data)
|
| 830 |
+
elif isinstance(data, dict):
|
| 831 |
+
return {k: self._transform_file_paths(v) for k, v in data.items()}
|
| 832 |
+
elif isinstance(data, list):
|
| 833 |
+
return [self._transform_file_paths(item) for item in data]
|
| 834 |
+
return data
|
| 835 |
+
|
| 836 |
+
def _transform_persisted_results(
|
| 837 |
+
self, persisted_results: dict[str, list[Any]]
|
| 838 |
+
) -> dict[str, list[Any]]:
|
| 839 |
+
"""Transform persisted results, handling both old format (just result)
|
| 840 |
+
and new format (dict with result and inputs_snapshot)."""
|
| 841 |
+
transformed: dict[str, list[Any]] = {}
|
| 842 |
+
for node_name, results_list in persisted_results.items():
|
| 843 |
+
transformed[node_name] = []
|
| 844 |
+
for entry in results_list:
|
| 845 |
+
if isinstance(entry, dict) and "result" in entry:
|
| 846 |
+
transformed[node_name].append(
|
| 847 |
+
{
|
| 848 |
+
"result": self._transform_file_paths(entry["result"]),
|
| 849 |
+
"inputs_snapshot": entry.get("inputs_snapshot"),
|
| 850 |
+
}
|
| 851 |
+
)
|
| 852 |
+
else:
|
| 853 |
+
transformed[node_name].append(self._transform_file_paths(entry))
|
| 854 |
+
return transformed
|
| 855 |
+
|
| 856 |
+
def _build_input_components(self, node) -> list[dict[str, Any]]:
|
| 857 |
+
if not node._input_components:
|
| 858 |
+
return []
|
| 859 |
+
return [
|
| 860 |
+
self._serialize_component(comp, port_name)
|
| 861 |
+
for port_name, comp in node._input_components.items()
|
| 862 |
+
]
|
| 863 |
+
|
| 864 |
+
def _build_output_components(
|
| 865 |
+
self, node, result: Any = None
|
| 866 |
+
) -> tuple[list[dict[str, Any]], str | None]:
|
| 867 |
+
if not node._output_components:
|
| 868 |
+
return [], None
|
| 869 |
+
|
| 870 |
+
components = []
|
| 871 |
+
validation_error = None
|
| 872 |
+
for port_name, comp in node._output_components.items():
|
| 873 |
+
if comp is None:
|
| 874 |
+
continue
|
| 875 |
+
|
| 876 |
+
visible = getattr(comp, "visible", True)
|
| 877 |
+
if visible is False:
|
| 878 |
+
continue
|
| 879 |
+
|
| 880 |
+
comp_data = self._serialize_component(comp, port_name)
|
| 881 |
+
comp_type = self._get_component_type(comp)
|
| 882 |
+
if result is not None:
|
| 883 |
+
if isinstance(result, dict):
|
| 884 |
+
value = result.get(
|
| 885 |
+
port_name, result.get(comp_data["props"]["label"])
|
| 886 |
+
)
|
| 887 |
+
else:
|
| 888 |
+
value = result
|
| 889 |
+
if comp_type in _FILE_COMP_TYPES:
|
| 890 |
+
error = self._validate_file_value(value, comp_type)
|
| 891 |
+
if error and validation_error is None:
|
| 892 |
+
validation_error = error
|
| 893 |
+
value = self._file_to_url(value)
|
| 894 |
+
comp_data["value"] = value
|
| 895 |
+
components.append(comp_data)
|
| 896 |
+
return components, validation_error
|
| 897 |
+
|
| 898 |
+
def _build_scattered_items(
|
| 899 |
+
self, node_name: str, result: Any = None
|
| 900 |
+
) -> list[dict[str, Any]]:
|
| 901 |
+
scattered_edge = self._get_scattered_edge(node_name)
|
| 902 |
+
if not scattered_edge:
|
| 903 |
+
return []
|
| 904 |
+
|
| 905 |
+
node = self.graph.nodes[node_name]
|
| 906 |
+
item_output_type = "text"
|
| 907 |
+
for comp in node._output_components.values():
|
| 908 |
+
if comp is None:
|
| 909 |
+
continue
|
| 910 |
+
comp_type = self._get_component_type(comp)
|
| 911 |
+
if comp_type == "audio":
|
| 912 |
+
item_output_type = "audio"
|
| 913 |
+
break
|
| 914 |
+
|
| 915 |
+
items = []
|
| 916 |
+
if result and isinstance(result, dict) and "_scattered_results" in result:
|
| 917 |
+
results = result["_scattered_results"]
|
| 918 |
+
source_items = result.get("_items", [])
|
| 919 |
+
for i, item_result in enumerate(results):
|
| 920 |
+
source_item = source_items[i] if i < len(source_items) else None
|
| 921 |
+
preview = ""
|
| 922 |
+
output = None
|
| 923 |
+
|
| 924 |
+
if isinstance(source_item, dict):
|
| 925 |
+
preview_parts = [
|
| 926 |
+
f"{k}: {str(v)[:20]}" for k, v in list(source_item.items())[:2]
|
| 927 |
+
]
|
| 928 |
+
preview = ", ".join(preview_parts)
|
| 929 |
+
elif source_item:
|
| 930 |
+
preview = str(source_item)[:50]
|
| 931 |
+
|
| 932 |
+
if isinstance(item_result, dict):
|
| 933 |
+
first_key = list(item_result.keys())[0] if item_result else None
|
| 934 |
+
if first_key:
|
| 935 |
+
output = item_result[first_key]
|
| 936 |
+
else:
|
| 937 |
+
output = item_result
|
| 938 |
+
|
| 939 |
+
if output:
|
| 940 |
+
output = str(output)
|
| 941 |
+
|
| 942 |
+
items.append(
|
| 943 |
+
{
|
| 944 |
+
"index": i + 1,
|
| 945 |
+
"preview": preview or f"Item {i + 1}",
|
| 946 |
+
"output": output,
|
| 947 |
+
"is_audio_output": item_output_type == "audio",
|
| 948 |
+
}
|
| 949 |
+
)
|
| 950 |
+
return items
|
| 951 |
+
|
| 952 |
+
def _serialize_item_list_schema(
|
| 953 |
+
self, schema: dict[str, Any]
|
| 954 |
+
) -> list[dict[str, Any]]:
|
| 955 |
+
serialized = []
|
| 956 |
+
for field_name, comp in schema.items():
|
| 957 |
+
comp_data = self._serialize_component(comp, field_name)
|
| 958 |
+
serialized.append(comp_data)
|
| 959 |
+
return serialized
|
| 960 |
+
|
| 961 |
+
def _build_item_list_items(
|
| 962 |
+
self, node, port_name: str, result: Any = None
|
| 963 |
+
) -> list[dict[str, Any]]:
|
| 964 |
+
schema = node._item_list_schemas.get(port_name, {})
|
| 965 |
+
if not schema:
|
| 966 |
+
return []
|
| 967 |
+
|
| 968 |
+
items = []
|
| 969 |
+
if result and isinstance(result, dict) and port_name in result:
|
| 970 |
+
item_list = result[port_name]
|
| 971 |
+
if isinstance(item_list, list):
|
| 972 |
+
for i, item_data in enumerate(item_list):
|
| 973 |
+
item = {"index": i, "fields": {}}
|
| 974 |
+
if isinstance(item_data, dict):
|
| 975 |
+
for field_name in schema:
|
| 976 |
+
item["fields"][field_name] = item_data.get(field_name)
|
| 977 |
+
items.append(item)
|
| 978 |
+
return items
|
| 979 |
+
|
| 980 |
+
def _apply_item_list_edits(
|
| 981 |
+
self, node_name: str, result: Any, item_list_values: dict
|
| 982 |
+
) -> Any:
|
| 983 |
+
node = self.graph.nodes[node_name]
|
| 984 |
+
if not node._item_list_schemas:
|
| 985 |
+
return result
|
| 986 |
+
|
| 987 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 988 |
+
edits = item_list_values.get(node_id, {})
|
| 989 |
+
if not edits:
|
| 990 |
+
return result
|
| 991 |
+
|
| 992 |
+
first_port = list(node._item_list_schemas.keys())[0]
|
| 993 |
+
if isinstance(result, dict) and first_port in result:
|
| 994 |
+
items = result[first_port]
|
| 995 |
+
if isinstance(items, list):
|
| 996 |
+
for idx_str, field_edits in edits.items():
|
| 997 |
+
idx = int(idx_str)
|
| 998 |
+
if 0 <= idx < len(items) and isinstance(items[idx], dict):
|
| 999 |
+
items[idx].update(field_edits)
|
| 1000 |
+
return result
|
| 1001 |
+
|
| 1002 |
+
def _compute_node_depths(self) -> dict[str, int]:
|
| 1003 |
+
depths: dict[str, int] = {}
|
| 1004 |
+
connections = self.graph.get_connections()
|
| 1005 |
+
|
| 1006 |
+
for node_name in self.graph.nodes:
|
| 1007 |
+
if self.graph._nx_graph.in_degree(node_name) == 0:
|
| 1008 |
+
depths[node_name] = 0
|
| 1009 |
+
|
| 1010 |
+
changed = True
|
| 1011 |
+
while changed:
|
| 1012 |
+
changed = False
|
| 1013 |
+
for source, _, target, _ in connections:
|
| 1014 |
+
if source in depths:
|
| 1015 |
+
new_depth = depths[source] + 1
|
| 1016 |
+
if target not in depths or depths[target] < new_depth:
|
| 1017 |
+
depths[target] = new_depth
|
| 1018 |
+
changed = True
|
| 1019 |
+
|
| 1020 |
+
for node_name in self.graph.nodes:
|
| 1021 |
+
if node_name not in depths:
|
| 1022 |
+
depths[node_name] = 0
|
| 1023 |
+
|
| 1024 |
+
return depths
|
| 1025 |
+
|
| 1026 |
+
def _get_hf_user_info(self) -> dict | None:
|
| 1027 |
+
try:
|
| 1028 |
+
from huggingface_hub import get_token, whoami
|
| 1029 |
+
|
| 1030 |
+
token = get_token()
|
| 1031 |
+
if not token:
|
| 1032 |
+
return None
|
| 1033 |
+
|
| 1034 |
+
info = whoami(cache=True)
|
| 1035 |
+
return {
|
| 1036 |
+
"username": info.get("name"),
|
| 1037 |
+
"fullname": info.get("fullname"),
|
| 1038 |
+
"avatar_url": info.get("avatarUrl"),
|
| 1039 |
+
}
|
| 1040 |
+
except Exception:
|
| 1041 |
+
return None
|
| 1042 |
+
|
| 1043 |
+
def _build_graph_data(
|
| 1044 |
+
self,
|
| 1045 |
+
node_results: dict[str, Any] | None = None,
|
| 1046 |
+
node_statuses: dict[str, str] | None = None,
|
| 1047 |
+
input_values: dict[str, Any] | None = None,
|
| 1048 |
+
history: dict[str, dict[str, list[dict]]] | None = None,
|
| 1049 |
+
session_id: str | None = None,
|
| 1050 |
+
selected_results: dict[str, int] | None = None,
|
| 1051 |
+
) -> dict:
|
| 1052 |
+
node_results = node_results or {}
|
| 1053 |
+
node_statuses = node_statuses or {}
|
| 1054 |
+
input_values = input_values or {}
|
| 1055 |
+
history = history or {}
|
| 1056 |
+
selected_results = selected_results or {}
|
| 1057 |
+
|
| 1058 |
+
depths = self._compute_node_depths()
|
| 1059 |
+
|
| 1060 |
+
synthetic_input_nodes: list[dict[str, Any]] = []
|
| 1061 |
+
synthetic_edges: list[dict[str, Any]] = []
|
| 1062 |
+
input_node_positions: dict[str, tuple] = {}
|
| 1063 |
+
|
| 1064 |
+
component_to_input_node: dict[int, str] = {}
|
| 1065 |
+
creation_order = 0
|
| 1066 |
+
for node_name in self.graph.nodes:
|
| 1067 |
+
node = self.graph.nodes[node_name]
|
| 1068 |
+
|
| 1069 |
+
if isinstance(node, ChoiceNode):
|
| 1070 |
+
continue
|
| 1071 |
+
|
| 1072 |
+
if node._input_components:
|
| 1073 |
+
for idx, (port_name, comp) in enumerate(node._input_components.items()):
|
| 1074 |
+
comp_id = id(comp)
|
| 1075 |
+
|
| 1076 |
+
if comp_id in component_to_input_node:
|
| 1077 |
+
existing_input_node = component_to_input_node[comp_id]
|
| 1078 |
+
existing_input_id = existing_input_node.replace(
|
| 1079 |
+
" ", "_"
|
| 1080 |
+
).replace("-", "_")
|
| 1081 |
+
synthetic_edges.append(
|
| 1082 |
+
{
|
| 1083 |
+
"from_node": existing_input_id,
|
| 1084 |
+
"from_port": "value",
|
| 1085 |
+
"to_node": node_name.replace(" ", "_").replace(
|
| 1086 |
+
"-", "_"
|
| 1087 |
+
),
|
| 1088 |
+
"to_port": port_name,
|
| 1089 |
+
}
|
| 1090 |
+
)
|
| 1091 |
+
continue
|
| 1092 |
+
|
| 1093 |
+
input_node_name = f"{node_name}__{port_name}"
|
| 1094 |
+
input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
|
| 1095 |
+
component_to_input_node[comp_id] = input_node_name
|
| 1096 |
+
|
| 1097 |
+
comp_data = self._serialize_component(comp, "value")
|
| 1098 |
+
label = comp_data["props"].get("label") or port_name
|
| 1099 |
+
|
| 1100 |
+
if input_node_id in input_values:
|
| 1101 |
+
comp_data["value"] = input_values[input_node_id].get(
|
| 1102 |
+
"value", comp_data["value"]
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
synthetic_input_nodes.append(
|
| 1106 |
+
{
|
| 1107 |
+
"node_name": input_node_name,
|
| 1108 |
+
"display_name": label,
|
| 1109 |
+
"target_node": node_name,
|
| 1110 |
+
"target_port": port_name,
|
| 1111 |
+
"component": comp_data,
|
| 1112 |
+
"index": idx,
|
| 1113 |
+
"creation_order": creation_order,
|
| 1114 |
+
}
|
| 1115 |
+
)
|
| 1116 |
+
creation_order += 1
|
| 1117 |
+
|
| 1118 |
+
synthetic_edges.append(
|
| 1119 |
+
{
|
| 1120 |
+
"from_node": input_node_id,
|
| 1121 |
+
"from_port": "value",
|
| 1122 |
+
"to_node": node_name.replace(" ", "_").replace("-", "_"),
|
| 1123 |
+
"to_port": port_name,
|
| 1124 |
+
}
|
| 1125 |
+
)
|
| 1126 |
+
|
| 1127 |
+
max_depth = max(depths.values()) if depths else 0
|
| 1128 |
+
|
| 1129 |
+
nodes_by_depth: dict[int, list[str]] = {}
|
| 1130 |
+
for node_name, depth in depths.items():
|
| 1131 |
+
if depth not in nodes_by_depth:
|
| 1132 |
+
nodes_by_depth[depth] = []
|
| 1133 |
+
nodes_by_depth[depth].append(node_name)
|
| 1134 |
+
|
| 1135 |
+
x_spacing = 350
|
| 1136 |
+
input_column_x = 50
|
| 1137 |
+
x_start = 400
|
| 1138 |
+
y_start = 120
|
| 1139 |
+
y_gap = 30
|
| 1140 |
+
base_node_height = 100
|
| 1141 |
+
component_base_height = 60
|
| 1142 |
+
line_height = 18
|
| 1143 |
+
|
| 1144 |
+
def calc_component_height(comp_data: dict) -> int:
|
| 1145 |
+
lines = comp_data.get("props", {}).get("lines", 1)
|
| 1146 |
+
lines = min(lines, 6)
|
| 1147 |
+
return component_base_height + max(0, lines - 1) * line_height
|
| 1148 |
+
|
| 1149 |
+
def calc_node_height(components: list[dict], num_ports: int = 1) -> int:
|
| 1150 |
+
comp_height = sum(calc_component_height(c) for c in components)
|
| 1151 |
+
port_height = max(num_ports, 1) * 22
|
| 1152 |
+
return base_node_height + port_height + comp_height
|
| 1153 |
+
|
| 1154 |
+
all_input_nodes_sorted: list[dict] = []
|
| 1155 |
+
for syn_node in synthetic_input_nodes:
|
| 1156 |
+
target_depth = depths.get(syn_node["target_node"], 0)
|
| 1157 |
+
all_input_nodes_sorted.append({**syn_node, "target_depth": target_depth})
|
| 1158 |
+
all_input_nodes_sorted.sort(key=lambda x: x["creation_order"])
|
| 1159 |
+
|
| 1160 |
+
current_input_y = y_start
|
| 1161 |
+
for syn_node in all_input_nodes_sorted:
|
| 1162 |
+
input_node_positions[syn_node["node_name"]] = (
|
| 1163 |
+
input_column_x,
|
| 1164 |
+
current_input_y,
|
| 1165 |
+
)
|
| 1166 |
+
node_height = calc_node_height([syn_node["component"]], 1)
|
| 1167 |
+
current_input_y += node_height + y_gap
|
| 1168 |
+
|
| 1169 |
+
node_positions: dict[str, tuple] = {}
|
| 1170 |
+
for depth in range(max_depth + 1):
|
| 1171 |
+
depth_nodes = nodes_by_depth.get(depth, [])
|
| 1172 |
+
current_y = y_start
|
| 1173 |
+
for node_name in depth_nodes:
|
| 1174 |
+
node = self.graph.nodes[node_name]
|
| 1175 |
+
output_comps, _ = self._build_output_components(node)
|
| 1176 |
+
num_ports = max(
|
| 1177 |
+
len(node._input_ports or []), len(node._output_ports or [])
|
| 1178 |
+
)
|
| 1179 |
+
node_height = calc_node_height(output_comps, num_ports)
|
| 1180 |
+
x = x_start + depth * x_spacing
|
| 1181 |
+
node_positions[node_name] = (x, current_y)
|
| 1182 |
+
current_y += node_height + y_gap
|
| 1183 |
+
|
| 1184 |
+
nodes = []
|
| 1185 |
+
|
| 1186 |
+
for syn_node in synthetic_input_nodes:
|
| 1187 |
+
node_name = syn_node["node_name"]
|
| 1188 |
+
display_name = syn_node["display_name"]
|
| 1189 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1190 |
+
x, y = input_node_positions.get(node_name, (50, 50))
|
| 1191 |
+
comp = syn_node["component"]
|
| 1192 |
+
|
| 1193 |
+
nodes.append(
|
| 1194 |
+
{
|
| 1195 |
+
"id": node_id,
|
| 1196 |
+
"name": display_name,
|
| 1197 |
+
"type": "INPUT",
|
| 1198 |
+
"inputs": [],
|
| 1199 |
+
"outputs": ["value"],
|
| 1200 |
+
"x": x,
|
| 1201 |
+
"y": y,
|
| 1202 |
+
"has_input": False,
|
| 1203 |
+
"input_value": "",
|
| 1204 |
+
"input_components": [comp],
|
| 1205 |
+
"output_components": [],
|
| 1206 |
+
"is_map_node": False,
|
| 1207 |
+
"map_items": [],
|
| 1208 |
+
"map_item_count": 0,
|
| 1209 |
+
"item_output_type": "text",
|
| 1210 |
+
"status": "pending",
|
| 1211 |
+
"result": "",
|
| 1212 |
+
"is_output_node": False,
|
| 1213 |
+
"is_input_node": True,
|
| 1214 |
+
}
|
| 1215 |
+
)
|
| 1216 |
+
|
| 1217 |
+
for node_name in self.graph.nodes:
|
| 1218 |
+
node = self.graph.nodes[node_name]
|
| 1219 |
+
x, y = node_positions.get(node_name, (50, 50))
|
| 1220 |
+
|
| 1221 |
+
result = node_results.get(node_name)
|
| 1222 |
+
result_str = ""
|
| 1223 |
+
is_scattered = self._has_scattered_input(node_name)
|
| 1224 |
+
if result is not None and not node._output_components and not is_scattered:
|
| 1225 |
+
if isinstance(result, dict):
|
| 1226 |
+
display_result = {
|
| 1227 |
+
k: v for k, v in result.items() if not k.startswith("_")
|
| 1228 |
+
}
|
| 1229 |
+
result_str = json.dumps(display_result, indent=2, default=str)[:300]
|
| 1230 |
+
elif isinstance(result, (list, tuple)):
|
| 1231 |
+
result_str = json.dumps(list(result)[:5], default=str)
|
| 1232 |
+
else:
|
| 1233 |
+
result_str = str(result)[:300]
|
| 1234 |
+
|
| 1235 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1236 |
+
|
| 1237 |
+
input_ports_data = []
|
| 1238 |
+
for port in node._input_ports or []:
|
| 1239 |
+
if port in node._fixed_inputs:
|
| 1240 |
+
continue
|
| 1241 |
+
port_history = history.get(node_name, {}).get(port, [])
|
| 1242 |
+
input_ports_data.append(
|
| 1243 |
+
{
|
| 1244 |
+
"name": port,
|
| 1245 |
+
"history_count": len(port_history) if port_history else 0,
|
| 1246 |
+
}
|
| 1247 |
+
)
|
| 1248 |
+
|
| 1249 |
+
output_components, validation_error = self._build_output_components(
|
| 1250 |
+
node, result
|
| 1251 |
+
)
|
| 1252 |
+
scattered_items = (
|
| 1253 |
+
self._build_scattered_items(node_name, result) if is_scattered else []
|
| 1254 |
+
)
|
| 1255 |
+
|
| 1256 |
+
item_output_type = "text"
|
| 1257 |
+
if is_scattered:
|
| 1258 |
+
for comp in node._output_components.values():
|
| 1259 |
+
if comp is None:
|
| 1260 |
+
continue
|
| 1261 |
+
comp_type = self._get_component_type(comp)
|
| 1262 |
+
if comp_type == "audio":
|
| 1263 |
+
item_output_type = "audio"
|
| 1264 |
+
break
|
| 1265 |
+
|
| 1266 |
+
item_list_schema = None
|
| 1267 |
+
item_list_items = []
|
| 1268 |
+
if node._item_list_schemas:
|
| 1269 |
+
first_port = list(node._item_list_schemas.keys())[0]
|
| 1270 |
+
item_list_schema = self._serialize_item_list_schema(
|
| 1271 |
+
node._item_list_schemas[first_port]
|
| 1272 |
+
)
|
| 1273 |
+
item_list_items = self._build_item_list_items(node, first_port, result)
|
| 1274 |
+
|
| 1275 |
+
output_ports = []
|
| 1276 |
+
for port_name in node._output_ports or []:
|
| 1277 |
+
if port_name in node._item_list_schemas:
|
| 1278 |
+
schema = node._item_list_schemas[port_name]
|
| 1279 |
+
for field_name in schema:
|
| 1280 |
+
output_ports.append(f"{port_name}.{field_name}")
|
| 1281 |
+
elif port_name in node._output_components:
|
| 1282 |
+
output_ports.append(port_name)
|
| 1283 |
+
|
| 1284 |
+
is_output = self._is_output_node(node_name)
|
| 1285 |
+
is_local = self._is_running_locally(node)
|
| 1286 |
+
|
| 1287 |
+
variants = None
|
| 1288 |
+
selected_variant = None
|
| 1289 |
+
if isinstance(node, ChoiceNode):
|
| 1290 |
+
variants = [
|
| 1291 |
+
self._build_variant_data(v, input_values) for v in node._variants
|
| 1292 |
+
]
|
| 1293 |
+
selected_variant = input_values.get(node_id, {}).get(
|
| 1294 |
+
"_selected_variant", 0
|
| 1295 |
+
)
|
| 1296 |
+
|
| 1297 |
+
nodes.append(
|
| 1298 |
+
{
|
| 1299 |
+
"id": node_id,
|
| 1300 |
+
"name": node_name,
|
| 1301 |
+
"type": self._get_node_type(node, node_name),
|
| 1302 |
+
"url": self._get_node_url(node),
|
| 1303 |
+
"inputs": input_ports_data,
|
| 1304 |
+
"outputs": output_ports,
|
| 1305 |
+
"x": x,
|
| 1306 |
+
"y": y,
|
| 1307 |
+
"has_input": False,
|
| 1308 |
+
"input_value": input_values.get(node_name, ""),
|
| 1309 |
+
"input_components": [],
|
| 1310 |
+
"output_components": output_components,
|
| 1311 |
+
"is_map_node": is_scattered,
|
| 1312 |
+
"map_items": scattered_items,
|
| 1313 |
+
"map_item_count": len(scattered_items),
|
| 1314 |
+
"item_output_type": item_output_type,
|
| 1315 |
+
"item_list_schema": item_list_schema,
|
| 1316 |
+
"item_list_items": item_list_items,
|
| 1317 |
+
"status": node_statuses.get(node_name, "pending"),
|
| 1318 |
+
"result": result_str,
|
| 1319 |
+
"is_output_node": is_output,
|
| 1320 |
+
"is_input_node": False,
|
| 1321 |
+
"is_local": is_local,
|
| 1322 |
+
"variants": variants,
|
| 1323 |
+
"selected_variant": selected_variant,
|
| 1324 |
+
"validation_error": validation_error,
|
| 1325 |
+
}
|
| 1326 |
+
)
|
| 1327 |
+
|
| 1328 |
+
edges = []
|
| 1329 |
+
for i, edge in enumerate(self.graph._edges):
|
| 1330 |
+
from_port = edge.source_port
|
| 1331 |
+
if edge.item_key:
|
| 1332 |
+
from_port = f"{edge.source_port}.{edge.item_key}"
|
| 1333 |
+
edges.append(
|
| 1334 |
+
{
|
| 1335 |
+
"id": f"edge_{i}",
|
| 1336 |
+
"from_node": edge.source_node._name.replace(" ", "_").replace(
|
| 1337 |
+
"-", "_"
|
| 1338 |
+
),
|
| 1339 |
+
"from_port": from_port,
|
| 1340 |
+
"to_node": edge.target_node._name.replace(" ", "_").replace(
|
| 1341 |
+
"-", "_"
|
| 1342 |
+
),
|
| 1343 |
+
"to_port": edge.target_port,
|
| 1344 |
+
"is_scattered": edge.is_scattered,
|
| 1345 |
+
"is_gathered": edge.is_gathered,
|
| 1346 |
+
}
|
| 1347 |
+
)
|
| 1348 |
+
|
| 1349 |
+
for i, syn_edge in enumerate(synthetic_edges):
|
| 1350 |
+
edges.append(
|
| 1351 |
+
{
|
| 1352 |
+
"id": f"input_edge_{i}",
|
| 1353 |
+
"from_node": syn_edge["from_node"],
|
| 1354 |
+
"from_port": syn_edge["from_port"],
|
| 1355 |
+
"to_node": syn_edge["to_node"],
|
| 1356 |
+
"to_port": syn_edge["to_port"],
|
| 1357 |
+
}
|
| 1358 |
+
)
|
| 1359 |
+
|
| 1360 |
+
return {
|
| 1361 |
+
"name": self.graph.name,
|
| 1362 |
+
"nodes": nodes,
|
| 1363 |
+
"edges": edges,
|
| 1364 |
+
"inputs": input_values,
|
| 1365 |
+
"selected_results": selected_results,
|
| 1366 |
+
"history": history,
|
| 1367 |
+
"session_id": session_id,
|
| 1368 |
+
}
|
| 1369 |
+
|
| 1370 |
+
def _get_ancestors(self, node_name: str) -> list[str]:
|
| 1371 |
+
ancestors = set()
|
| 1372 |
+
to_visit = [node_name]
|
| 1373 |
+
while to_visit:
|
| 1374 |
+
current = to_visit.pop()
|
| 1375 |
+
for source, _, target, _ in self.graph.get_connections():
|
| 1376 |
+
if target == current and source not in ancestors:
|
| 1377 |
+
ancestors.add(source)
|
| 1378 |
+
to_visit.append(source)
|
| 1379 |
+
return list(ancestors)
|
| 1380 |
+
|
| 1381 |
+
def _get_user_provided_output(
|
| 1382 |
+
self, node, node_id: str, input_values: dict[str, Any]
|
| 1383 |
+
) -> dict[str, Any] | None:
|
| 1384 |
+
if not node._output_components:
|
| 1385 |
+
return None
|
| 1386 |
+
|
| 1387 |
+
node_inputs = input_values.get(node_id, {})
|
| 1388 |
+
if not node_inputs:
|
| 1389 |
+
return None
|
| 1390 |
+
|
| 1391 |
+
result = {}
|
| 1392 |
+
has_user_value = False
|
| 1393 |
+
for port_name, comp in node._output_components.items():
|
| 1394 |
+
if comp is None:
|
| 1395 |
+
continue
|
| 1396 |
+
if port_name in node_inputs:
|
| 1397 |
+
value = node_inputs[port_name]
|
| 1398 |
+
if value is not None:
|
| 1399 |
+
if isinstance(value, str) and value.startswith("data:"):
|
| 1400 |
+
value = self._save_data_url_as_gradio_file(value)
|
| 1401 |
+
result[port_name] = value
|
| 1402 |
+
has_user_value = True
|
| 1403 |
+
|
| 1404 |
+
return result if has_user_value else None
|
| 1405 |
+
|
| 1406 |
+
def _save_data_url_as_gradio_file(self, data_url: str):
|
| 1407 |
+
try:
|
| 1408 |
+
header, data = data_url.split(",", 1)
|
| 1409 |
+
mime_type = header.split(":")[1].split(";")[0]
|
| 1410 |
+
ext_map = {
|
| 1411 |
+
"image/png": ".png",
|
| 1412 |
+
"image/jpeg": ".jpg",
|
| 1413 |
+
"image/gif": ".gif",
|
| 1414 |
+
"image/webp": ".webp",
|
| 1415 |
+
"audio/webm": ".webm",
|
| 1416 |
+
"audio/wav": ".wav",
|
| 1417 |
+
"audio/mp3": ".mp3",
|
| 1418 |
+
"audio/mpeg": ".mp3",
|
| 1419 |
+
}
|
| 1420 |
+
ext = ext_map.get(mime_type, ".bin")
|
| 1421 |
+
file_data = base64.b64decode(data)
|
| 1422 |
+
temp_dir = Path(tempfile.gettempdir()) / "daggr_uploads"
|
| 1423 |
+
temp_dir.mkdir(exist_ok=True)
|
| 1424 |
+
file_path = temp_dir / f"{uuid.uuid4()}{ext}"
|
| 1425 |
+
file_path.write_bytes(file_data)
|
| 1426 |
+
return FileValue(str(file_path))
|
| 1427 |
+
except Exception as e:
|
| 1428 |
+
print(f"[ERROR] Failed to save data URL: {e}")
|
| 1429 |
+
return data_url
|
| 1430 |
+
|
| 1431 |
+
def _convert_urls_to_file_values(self, data: Any) -> Any:
|
| 1432 |
+
if isinstance(data, str):
|
| 1433 |
+
if data.startswith(("http://", "https://", "/")) and any(
|
| 1434 |
+
data.lower().endswith(ext)
|
| 1435 |
+
for ext in (
|
| 1436 |
+
".png",
|
| 1437 |
+
".jpg",
|
| 1438 |
+
".jpeg",
|
| 1439 |
+
".gif",
|
| 1440 |
+
".webp",
|
| 1441 |
+
".wav",
|
| 1442 |
+
".mp3",
|
| 1443 |
+
".webm",
|
| 1444 |
+
".mp4",
|
| 1445 |
+
".ogg",
|
| 1446 |
+
)
|
| 1447 |
+
):
|
| 1448 |
+
return FileValue(data)
|
| 1449 |
+
return data
|
| 1450 |
+
elif isinstance(data, dict):
|
| 1451 |
+
return {k: self._convert_urls_to_file_values(v) for k, v in data.items()}
|
| 1452 |
+
elif isinstance(data, list):
|
| 1453 |
+
return [self._convert_urls_to_file_values(item) for item in data]
|
| 1454 |
+
return data
|
| 1455 |
+
|
| 1456 |
+
async def _execute_to_node(
|
| 1457 |
+
self,
|
| 1458 |
+
session: ExecutionSession,
|
| 1459 |
+
target_node: str,
|
| 1460 |
+
session_id: str | None,
|
| 1461 |
+
input_values: dict[str, Any],
|
| 1462 |
+
selected_results: dict[str, int],
|
| 1463 |
+
) -> dict:
|
| 1464 |
+
if not session_id:
|
| 1465 |
+
session_id = self.state.create_session(self.graph.persist_key)
|
| 1466 |
+
|
| 1467 |
+
for node_name, node in self.graph.nodes.items():
|
| 1468 |
+
if isinstance(node, ChoiceNode):
|
| 1469 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1470 |
+
variant_idx = input_values.get(node_id, {}).get("_selected_variant", 0)
|
| 1471 |
+
session.selected_variants[node_name] = variant_idx
|
| 1472 |
+
|
| 1473 |
+
ancestors = self._get_ancestors(target_node)
|
| 1474 |
+
nodes_to_run = ancestors + [target_node]
|
| 1475 |
+
execution_order = self.graph.get_execution_order()
|
| 1476 |
+
nodes_to_execute = [n for n in execution_order if n in nodes_to_run]
|
| 1477 |
+
|
| 1478 |
+
entry_inputs: dict[str, dict[str, Any]] = {}
|
| 1479 |
+
for node_name in nodes_to_execute:
|
| 1480 |
+
node = self.graph.nodes[node_name]
|
| 1481 |
+
if node._input_components:
|
| 1482 |
+
node_inputs = {}
|
| 1483 |
+
for port_name in node._input_components:
|
| 1484 |
+
input_node_name = f"{node_name}__{port_name}"
|
| 1485 |
+
input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
|
| 1486 |
+
if input_node_id in input_values:
|
| 1487 |
+
value = input_values[input_node_id].get("value")
|
| 1488 |
+
if value is not None:
|
| 1489 |
+
node_inputs[port_name] = value
|
| 1490 |
+
if node_inputs:
|
| 1491 |
+
entry_inputs[node_name] = node_inputs
|
| 1492 |
+
elif isinstance(node, InteractionNode):
|
| 1493 |
+
value = input_values.get(node_name, "")
|
| 1494 |
+
port = node._input_ports[0] if node._input_ports else "input"
|
| 1495 |
+
entry_inputs[node_name] = {port: value}
|
| 1496 |
+
|
| 1497 |
+
existing_results = {}
|
| 1498 |
+
if session_id:
|
| 1499 |
+
for node_name in nodes_to_execute:
|
| 1500 |
+
if node_name in selected_results:
|
| 1501 |
+
cached = self.state.get_result_by_index(
|
| 1502 |
+
session_id, node_name, selected_results[node_name]
|
| 1503 |
+
)
|
| 1504 |
+
else:
|
| 1505 |
+
cached = self.state.get_latest_result(session_id, node_name)
|
| 1506 |
+
if cached is not None:
|
| 1507 |
+
existing_results[node_name] = self._convert_urls_to_file_values(
|
| 1508 |
+
cached
|
| 1509 |
+
)
|
| 1510 |
+
|
| 1511 |
+
for k, v in existing_results.items():
|
| 1512 |
+
if k not in session.results:
|
| 1513 |
+
session.results[k] = v
|
| 1514 |
+
|
| 1515 |
+
if target_node in session.results:
|
| 1516 |
+
del session.results[target_node]
|
| 1517 |
+
|
| 1518 |
+
node_results = {}
|
| 1519 |
+
node_statuses = {}
|
| 1520 |
+
|
| 1521 |
+
for node_name in nodes_to_execute:
|
| 1522 |
+
if node_name in existing_results:
|
| 1523 |
+
node_results[node_name] = existing_results[node_name]
|
| 1524 |
+
node_statuses[node_name] = "completed"
|
| 1525 |
+
continue
|
| 1526 |
+
|
| 1527 |
+
if node_name in session.results:
|
| 1528 |
+
node_results[node_name] = session.results[node_name]
|
| 1529 |
+
node_statuses[node_name] = "completed"
|
| 1530 |
+
continue
|
| 1531 |
+
|
| 1532 |
+
node_statuses[node_name] = "running"
|
| 1533 |
+
user_input = entry_inputs.get(node_name, {})
|
| 1534 |
+
result = await self.executor.execute_node(session, node_name, user_input)
|
| 1535 |
+
node_results[node_name] = result
|
| 1536 |
+
node_statuses[node_name] = "completed"
|
| 1537 |
+
self.state.save_result(session_id, node_name, result)
|
| 1538 |
+
|
| 1539 |
+
return self._build_graph_data(
|
| 1540 |
+
node_results, node_statuses, input_values, {}, session_id, selected_results
|
| 1541 |
+
)
|
| 1542 |
+
|
| 1543 |
+
async def _execute_to_node_streaming(
|
| 1544 |
+
self,
|
| 1545 |
+
session: ExecutionSession,
|
| 1546 |
+
target_node: str,
|
| 1547 |
+
sheet_id: str | None,
|
| 1548 |
+
input_values: dict[str, Any],
|
| 1549 |
+
item_list_values: dict[str, Any],
|
| 1550 |
+
selected_results: dict[str, int],
|
| 1551 |
+
run_id: str,
|
| 1552 |
+
user_id: str | None = None,
|
| 1553 |
+
run_ancestors: bool = True,
|
| 1554 |
+
):
|
| 1555 |
+
can_persist = (
|
| 1556 |
+
user_id is not None
|
| 1557 |
+
and sheet_id is not None
|
| 1558 |
+
and self.graph.persist_key is not None
|
| 1559 |
+
)
|
| 1560 |
+
|
| 1561 |
+
for node_name, node in self.graph.nodes.items():
|
| 1562 |
+
if isinstance(node, ChoiceNode):
|
| 1563 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1564 |
+
variant_idx = input_values.get(node_id, {}).get("_selected_variant", 0)
|
| 1565 |
+
session.selected_variants[node_name] = variant_idx
|
| 1566 |
+
|
| 1567 |
+
if run_ancestors:
|
| 1568 |
+
ancestors = self._get_ancestors(target_node)
|
| 1569 |
+
nodes_to_run = ancestors + [target_node]
|
| 1570 |
+
else:
|
| 1571 |
+
nodes_to_run = [target_node]
|
| 1572 |
+
execution_order = self.graph.get_execution_order()
|
| 1573 |
+
nodes_to_execute = [n for n in execution_order if n in nodes_to_run]
|
| 1574 |
+
|
| 1575 |
+
entry_inputs: dict[str, dict[str, Any]] = {}
|
| 1576 |
+
for node_name in nodes_to_execute:
|
| 1577 |
+
node = self.graph.nodes[node_name]
|
| 1578 |
+
if node._input_components:
|
| 1579 |
+
node_inputs = {}
|
| 1580 |
+
for port_name in node._input_components:
|
| 1581 |
+
input_node_name = f"{node_name}__{port_name}"
|
| 1582 |
+
input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
|
| 1583 |
+
if input_node_id in input_values:
|
| 1584 |
+
value = input_values[input_node_id].get("value")
|
| 1585 |
+
if value is not None:
|
| 1586 |
+
node_inputs[port_name] = value
|
| 1587 |
+
if node_inputs:
|
| 1588 |
+
entry_inputs[node_name] = node_inputs
|
| 1589 |
+
elif isinstance(node, InteractionNode):
|
| 1590 |
+
value = input_values.get(node_name, "")
|
| 1591 |
+
port = node._input_ports[0] if node._input_ports else "input"
|
| 1592 |
+
entry_inputs[node_name] = {port: value}
|
| 1593 |
+
|
| 1594 |
+
existing_results = {}
|
| 1595 |
+
for node_name in nodes_to_execute:
|
| 1596 |
+
node = self.graph.nodes[node_name]
|
| 1597 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1598 |
+
user_output = self._get_user_provided_output(node, node_id, input_values)
|
| 1599 |
+
if user_output is not None:
|
| 1600 |
+
existing_results[node_name] = user_output
|
| 1601 |
+
if can_persist:
|
| 1602 |
+
snapshot = {
|
| 1603 |
+
"inputs": input_values,
|
| 1604 |
+
"selected_results": selected_results,
|
| 1605 |
+
}
|
| 1606 |
+
self.state.save_result(sheet_id, node_name, user_output, snapshot)
|
| 1607 |
+
continue
|
| 1608 |
+
|
| 1609 |
+
if node_name == target_node:
|
| 1610 |
+
continue
|
| 1611 |
+
|
| 1612 |
+
if can_persist:
|
| 1613 |
+
if node_name in selected_results:
|
| 1614 |
+
cached = self.state.get_result_by_index(
|
| 1615 |
+
sheet_id, node_name, selected_results[node_name]
|
| 1616 |
+
)
|
| 1617 |
+
else:
|
| 1618 |
+
cached = self.state.get_latest_result(sheet_id, node_name)
|
| 1619 |
+
if cached is not None:
|
| 1620 |
+
existing_results[node_name] = self._convert_urls_to_file_values(
|
| 1621 |
+
cached
|
| 1622 |
+
)
|
| 1623 |
+
|
| 1624 |
+
for k, v in existing_results.items():
|
| 1625 |
+
if k not in session.results:
|
| 1626 |
+
session.results[k] = v
|
| 1627 |
+
|
| 1628 |
+
if target_node in session.results:
|
| 1629 |
+
del session.results[target_node]
|
| 1630 |
+
|
| 1631 |
+
node_results = {}
|
| 1632 |
+
node_statuses = {}
|
| 1633 |
+
|
| 1634 |
+
try:
|
| 1635 |
+
for node_name in nodes_to_execute:
|
| 1636 |
+
if node_name in existing_results:
|
| 1637 |
+
result = existing_results[node_name]
|
| 1638 |
+
result = self._apply_item_list_edits(
|
| 1639 |
+
node_name, result, item_list_values
|
| 1640 |
+
)
|
| 1641 |
+
node_results[node_name] = result
|
| 1642 |
+
session.results[node_name] = result
|
| 1643 |
+
node_statuses[node_name] = "completed"
|
| 1644 |
+
continue
|
| 1645 |
+
|
| 1646 |
+
if node_name in session.results:
|
| 1647 |
+
result = session.results[node_name]
|
| 1648 |
+
result = self._apply_item_list_edits(
|
| 1649 |
+
node_name, result, item_list_values
|
| 1650 |
+
)
|
| 1651 |
+
node_results[node_name] = result
|
| 1652 |
+
node_statuses[node_name] = "completed"
|
| 1653 |
+
continue
|
| 1654 |
+
|
| 1655 |
+
can_execute = await session.start_node_execution(node_name)
|
| 1656 |
+
if not can_execute:
|
| 1657 |
+
if node_name == target_node:
|
| 1658 |
+
return
|
| 1659 |
+
await session.wait_for_node(node_name)
|
| 1660 |
+
if node_name in session.results:
|
| 1661 |
+
result = session.results[node_name]
|
| 1662 |
+
result = self._apply_item_list_edits(
|
| 1663 |
+
node_name, result, item_list_values
|
| 1664 |
+
)
|
| 1665 |
+
node_results[node_name] = result
|
| 1666 |
+
node_statuses[node_name] = "completed"
|
| 1667 |
+
continue
|
| 1668 |
+
|
| 1669 |
+
try:
|
| 1670 |
+
node_statuses[node_name] = "running"
|
| 1671 |
+
user_input = entry_inputs.get(node_name, {})
|
| 1672 |
+
|
| 1673 |
+
yield {
|
| 1674 |
+
"type": "node_started",
|
| 1675 |
+
"started_node": node_name,
|
| 1676 |
+
"run_id": run_id,
|
| 1677 |
+
}
|
| 1678 |
+
|
| 1679 |
+
start_time = time.time()
|
| 1680 |
+
result = await self.executor.execute_node(
|
| 1681 |
+
session, node_name, user_input
|
| 1682 |
+
)
|
| 1683 |
+
elapsed_ms = (time.time() - start_time) * 1000
|
| 1684 |
+
|
| 1685 |
+
result = self._apply_item_list_edits(
|
| 1686 |
+
node_name, result, item_list_values
|
| 1687 |
+
)
|
| 1688 |
+
session.results[node_name] = result
|
| 1689 |
+
node_results[node_name] = result
|
| 1690 |
+
node_statuses[node_name] = "completed"
|
| 1691 |
+
|
| 1692 |
+
if can_persist:
|
| 1693 |
+
current_count = self.state.get_result_count(sheet_id, node_name)
|
| 1694 |
+
snapshot = {
|
| 1695 |
+
"inputs": input_values,
|
| 1696 |
+
"selected_results": selected_results,
|
| 1697 |
+
}
|
| 1698 |
+
self.state.save_result(sheet_id, node_name, result, snapshot)
|
| 1699 |
+
selected_results[node_name] = current_count
|
| 1700 |
+
|
| 1701 |
+
graph_data = self._build_graph_data(
|
| 1702 |
+
node_results,
|
| 1703 |
+
node_statuses,
|
| 1704 |
+
input_values,
|
| 1705 |
+
{},
|
| 1706 |
+
sheet_id,
|
| 1707 |
+
selected_results,
|
| 1708 |
+
)
|
| 1709 |
+
graph_data["type"] = "node_complete"
|
| 1710 |
+
graph_data["completed_node"] = node_name
|
| 1711 |
+
graph_data["run_id"] = run_id
|
| 1712 |
+
graph_data["execution_time_ms"] = elapsed_ms
|
| 1713 |
+
finally:
|
| 1714 |
+
await session.finish_node_execution(node_name)
|
| 1715 |
+
yield graph_data
|
| 1716 |
+
|
| 1717 |
+
except Exception as e:
|
| 1718 |
+
error_node = None
|
| 1719 |
+
if nodes_to_execute:
|
| 1720 |
+
current_idx = len(node_results)
|
| 1721 |
+
if current_idx < len(nodes_to_execute):
|
| 1722 |
+
error_node = nodes_to_execute[current_idx]
|
| 1723 |
+
node_statuses[error_node] = "error"
|
| 1724 |
+
node_results[error_node] = {"error": str(e)}
|
| 1725 |
+
|
| 1726 |
+
graph_data = self._build_graph_data(
|
| 1727 |
+
node_results,
|
| 1728 |
+
node_statuses,
|
| 1729 |
+
input_values,
|
| 1730 |
+
{},
|
| 1731 |
+
sheet_id,
|
| 1732 |
+
selected_results,
|
| 1733 |
+
)
|
| 1734 |
+
graph_data["type"] = "error"
|
| 1735 |
+
graph_data["run_id"] = run_id
|
| 1736 |
+
graph_data["error"] = str(e)
|
| 1737 |
+
graph_data["nodes_to_clear"] = nodes_to_execute
|
| 1738 |
+
if error_node:
|
| 1739 |
+
graph_data["node"] = error_node
|
| 1740 |
+
graph_data["completed_node"] = error_node
|
| 1741 |
+
yield graph_data
|
| 1742 |
+
|
| 1743 |
+
async def _execute_workflow_api(
|
| 1744 |
+
self, request: Request, subgraph_id: str | None = None
|
| 1745 |
+
) -> JSONResponse:
|
| 1746 |
+
try:
|
| 1747 |
+
body = await request.json()
|
| 1748 |
+
except Exception:
|
| 1749 |
+
body = {}
|
| 1750 |
+
|
| 1751 |
+
input_values = body.get("inputs", {})
|
| 1752 |
+
session = ExecutionSession(self.graph)
|
| 1753 |
+
|
| 1754 |
+
subgraphs = self.graph.get_subgraphs()
|
| 1755 |
+
output_node_names = set(self.graph.get_output_nodes())
|
| 1756 |
+
|
| 1757 |
+
if subgraph_id is None:
|
| 1758 |
+
if len(subgraphs) > 1:
|
| 1759 |
+
return JSONResponse(
|
| 1760 |
+
{
|
| 1761 |
+
"error": "Multiple subgraphs detected. Please specify a subgraph_id.",
|
| 1762 |
+
"available_subgraphs": [
|
| 1763 |
+
f"subgraph_{i}" for i in range(len(subgraphs))
|
| 1764 |
+
],
|
| 1765 |
+
},
|
| 1766 |
+
status_code=400,
|
| 1767 |
+
)
|
| 1768 |
+
target_nodes = subgraphs[0] if subgraphs else set(self.graph.nodes.keys())
|
| 1769 |
+
else:
|
| 1770 |
+
if subgraph_id == "main" and len(subgraphs) == 1:
|
| 1771 |
+
target_nodes = subgraphs[0]
|
| 1772 |
+
elif subgraph_id.startswith("subgraph_"):
|
| 1773 |
+
try:
|
| 1774 |
+
idx = int(subgraph_id.split("_")[1])
|
| 1775 |
+
if idx < 0 or idx >= len(subgraphs):
|
| 1776 |
+
return JSONResponse(
|
| 1777 |
+
{"error": f"Subgraph '{subgraph_id}' not found"},
|
| 1778 |
+
status_code=404,
|
| 1779 |
+
)
|
| 1780 |
+
target_nodes = subgraphs[idx]
|
| 1781 |
+
except (ValueError, IndexError):
|
| 1782 |
+
return JSONResponse(
|
| 1783 |
+
{"error": f"Invalid subgraph_id '{subgraph_id}'"},
|
| 1784 |
+
status_code=400,
|
| 1785 |
+
)
|
| 1786 |
+
else:
|
| 1787 |
+
return JSONResponse(
|
| 1788 |
+
{"error": f"Subgraph '{subgraph_id}' not found"},
|
| 1789 |
+
status_code=404,
|
| 1790 |
+
)
|
| 1791 |
+
|
| 1792 |
+
for node_name, node in self.graph.nodes.items():
|
| 1793 |
+
if isinstance(node, ChoiceNode):
|
| 1794 |
+
node_id = node_name.replace(" ", "_").replace("-", "_")
|
| 1795 |
+
variant_idx = input_values.get(f"{node_id}___selected_variant", 0)
|
| 1796 |
+
session.selected_variants[node_name] = variant_idx
|
| 1797 |
+
|
| 1798 |
+
execution_order = self.graph.get_execution_order()
|
| 1799 |
+
nodes_to_execute = [n for n in execution_order if n in target_nodes]
|
| 1800 |
+
|
| 1801 |
+
entry_inputs: dict[str, dict[str, Any]] = {}
|
| 1802 |
+
for node_name in nodes_to_execute:
|
| 1803 |
+
node = self.graph.nodes[node_name]
|
| 1804 |
+
if node._input_components:
|
| 1805 |
+
node_inputs = {}
|
| 1806 |
+
for port_name in node._input_components:
|
| 1807 |
+
input_node_id = f"{node_name}__{port_name}".replace(
|
| 1808 |
+
" ", "_"
|
| 1809 |
+
).replace("-", "_")
|
| 1810 |
+
if input_node_id in input_values:
|
| 1811 |
+
node_inputs[port_name] = input_values[input_node_id]
|
| 1812 |
+
if node_inputs:
|
| 1813 |
+
entry_inputs[node_name] = node_inputs
|
| 1814 |
+
|
| 1815 |
+
session.results = {}
|
| 1816 |
+
node_results = {}
|
| 1817 |
+
|
| 1818 |
+
try:
|
| 1819 |
+
for node_name in nodes_to_execute:
|
| 1820 |
+
user_input = entry_inputs.get(node_name, {})
|
| 1821 |
+
result = await self.executor.execute_node(
|
| 1822 |
+
session, node_name, user_input
|
| 1823 |
+
)
|
| 1824 |
+
node_results[node_name] = result
|
| 1825 |
+
except Exception as e:
|
| 1826 |
+
return JSONResponse(
|
| 1827 |
+
{"error": f"Execution error in node '{node_name}': {str(e)}"},
|
| 1828 |
+
status_code=500,
|
| 1829 |
+
)
|
| 1830 |
+
|
| 1831 |
+
outputs = {}
|
| 1832 |
+
for node_name in nodes_to_execute:
|
| 1833 |
+
if node_name in output_node_names and node_name in node_results:
|
| 1834 |
+
result = node_results[node_name]
|
| 1835 |
+
result = self._transform_file_paths(result)
|
| 1836 |
+
outputs[node_name] = result
|
| 1837 |
+
|
| 1838 |
+
return JSONResponse({"outputs": outputs})
|
| 1839 |
+
|
| 1840 |
+
def run(
|
| 1841 |
+
self,
|
| 1842 |
+
host: str | None = None,
|
| 1843 |
+
port: int | None = None,
|
| 1844 |
+
share: bool | None = None,
|
| 1845 |
+
open_browser: bool = True,
|
| 1846 |
+
**kwargs,
|
| 1847 |
+
):
|
| 1848 |
+
from gradio.utils import colab_check, ipython_check
|
| 1849 |
+
|
| 1850 |
+
if host is None:
|
| 1851 |
+
host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
|
| 1852 |
+
if port is None:
|
| 1853 |
+
port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
|
| 1854 |
+
|
| 1855 |
+
actual_port = _find_available_port(host, port)
|
| 1856 |
+
if actual_port != port:
|
| 1857 |
+
print(f"\n Port {port} is in use, using {actual_port} instead.")
|
| 1858 |
+
|
| 1859 |
+
self.graph._validate_edges()
|
| 1860 |
+
|
| 1861 |
+
is_colab = colab_check()
|
| 1862 |
+
is_kaggle = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
|
| 1863 |
+
is_notebook = is_colab or is_kaggle or ipython_check()
|
| 1864 |
+
|
| 1865 |
+
if share is None:
|
| 1866 |
+
share = is_colab or is_kaggle
|
| 1867 |
+
|
| 1868 |
+
if is_notebook or share:
|
| 1869 |
+
config = uvicorn.Config(
|
| 1870 |
+
app=self.app,
|
| 1871 |
+
host=host,
|
| 1872 |
+
port=actual_port,
|
| 1873 |
+
log_level="warning",
|
| 1874 |
+
)
|
| 1875 |
+
server = _Server(config)
|
| 1876 |
+
server.run_in_thread()
|
| 1877 |
+
|
| 1878 |
+
local_url = f"http://{host}:{actual_port}"
|
| 1879 |
+
print(f"\n UI running at: {local_url}")
|
| 1880 |
+
if self.api_server:
|
| 1881 |
+
print(f" API server at: {local_url}/api")
|
| 1882 |
+
|
| 1883 |
+
share_url = None
|
| 1884 |
+
if share:
|
| 1885 |
+
from gradio.networking import setup_tunnel
|
| 1886 |
+
|
| 1887 |
+
share_token = secrets.token_urlsafe(32)
|
| 1888 |
+
share_url = setup_tunnel(
|
| 1889 |
+
local_host=host,
|
| 1890 |
+
local_port=actual_port,
|
| 1891 |
+
share_token=share_token,
|
| 1892 |
+
share_server_address=None,
|
| 1893 |
+
share_server_tls_certificate=None,
|
| 1894 |
+
)
|
| 1895 |
+
print(f" Public URL: {share_url}")
|
| 1896 |
+
print(
|
| 1897 |
+
"\n This share link expires in 1 week. For permanent hosting, deploy to Hugging Face Spaces.\n"
|
| 1898 |
+
)
|
| 1899 |
+
|
| 1900 |
+
if is_colab or is_kaggle:
|
| 1901 |
+
from IPython.display import HTML, display
|
| 1902 |
+
|
| 1903 |
+
url = share_url or local_url
|
| 1904 |
+
display(
|
| 1905 |
+
HTML(f'<a href="{url}" target="_blank">Open daggr app: {url}</a>')
|
| 1906 |
+
)
|
| 1907 |
+
elif open_browser:
|
| 1908 |
+
webbrowser.open_new_tab(share_url or local_url)
|
| 1909 |
+
|
| 1910 |
+
try:
|
| 1911 |
+
while True:
|
| 1912 |
+
time.sleep(1)
|
| 1913 |
+
except KeyboardInterrupt:
|
| 1914 |
+
print("\nShutting down...")
|
| 1915 |
+
server.close()
|
| 1916 |
+
else:
|
| 1917 |
+
local_url = f"http://{host}:{actual_port}"
|
| 1918 |
+
print(f"\n UI running at: {local_url}")
|
| 1919 |
+
if self.api_server:
|
| 1920 |
+
print(f" API server at: {local_url}/api")
|
| 1921 |
+
print()
|
| 1922 |
+
if open_browser:
|
| 1923 |
+
threading.Timer(0.5, lambda: webbrowser.open_new_tab(local_url)).start()
|
| 1924 |
+
uvicorn.run(
|
| 1925 |
+
self.app, host=host, port=actual_port, log_level="warning", **kwargs
|
| 1926 |
+
)
|
| 1927 |
+
|
| 1928 |
+
|
| 1929 |
+
class _Server(uvicorn.Server):
|
| 1930 |
+
def install_signal_handlers(self):
|
| 1931 |
+
pass
|
| 1932 |
+
|
| 1933 |
+
def run_in_thread(self):
|
| 1934 |
+
self.thread = threading.Thread(target=self.run, daemon=True)
|
| 1935 |
+
self.thread.start()
|
| 1936 |
+
start = time.time()
|
| 1937 |
+
while not self.started:
|
| 1938 |
+
time.sleep(1e-3)
|
| 1939 |
+
if time.time() - start > 5:
|
| 1940 |
+
raise RuntimeError(
|
| 1941 |
+
"Server failed to start. Please check that the port is available."
|
| 1942 |
+
)
|
| 1943 |
+
|
| 1944 |
+
def close(self):
|
| 1945 |
+
self.should_exit = True
|
| 1946 |
+
self.thread.join(timeout=5)
|
daggr/session.py
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Session management for daggr, including per-session execution contexts for security isolation and concurrency management."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import asyncio
|
| 6 |
+
from typing import TYPE_CHECKING, Any
|
| 7 |
+
|
| 8 |
+
if TYPE_CHECKING:
|
| 9 |
+
from daggr.graph import Graph
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ConcurrencyManager:
|
| 13 |
+
"""Manages concurrency limits for FnNode execution within a session.
|
| 14 |
+
|
| 15 |
+
By default, only one FnNode runs at a time per session. FnNodes can opt
|
| 16 |
+
into concurrent execution via the `concurrent` parameter, and can share
|
| 17 |
+
limits via `concurrency_group`.
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
def __init__(self):
|
| 21 |
+
self._default_semaphore = asyncio.Semaphore(1)
|
| 22 |
+
self._group_semaphores: dict[str, asyncio.Semaphore] = {}
|
| 23 |
+
self._lock = asyncio.Lock()
|
| 24 |
+
|
| 25 |
+
async def get_semaphore(
|
| 26 |
+
self,
|
| 27 |
+
concurrent: bool,
|
| 28 |
+
concurrency_group: str | None,
|
| 29 |
+
max_concurrent: int,
|
| 30 |
+
) -> asyncio.Semaphore | None:
|
| 31 |
+
"""Get the appropriate semaphore for a FnNode.
|
| 32 |
+
|
| 33 |
+
Returns None if the node should run without concurrency limits
|
| 34 |
+
(concurrent=True with no group).
|
| 35 |
+
"""
|
| 36 |
+
if not concurrent:
|
| 37 |
+
return self._default_semaphore
|
| 38 |
+
|
| 39 |
+
if concurrency_group:
|
| 40 |
+
async with self._lock:
|
| 41 |
+
if concurrency_group not in self._group_semaphores:
|
| 42 |
+
self._group_semaphores[concurrency_group] = asyncio.Semaphore(
|
| 43 |
+
max_concurrent
|
| 44 |
+
)
|
| 45 |
+
return self._group_semaphores[concurrency_group]
|
| 46 |
+
|
| 47 |
+
return None
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class ExecutionSession:
|
| 51 |
+
"""Per-session execution context.
|
| 52 |
+
|
| 53 |
+
Each WebSocket connection gets its own ExecutionSession, providing:
|
| 54 |
+
- Isolated HF token
|
| 55 |
+
- Isolated results cache
|
| 56 |
+
- Isolated Gradio client cache
|
| 57 |
+
- Per-session concurrency management
|
| 58 |
+
- Node execution coordination (wait for dependencies)
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
def __init__(self, graph: Graph, hf_token: str | None = None):
|
| 62 |
+
self.graph = graph
|
| 63 |
+
self.hf_token = hf_token
|
| 64 |
+
self.results: dict[str, Any] = {}
|
| 65 |
+
self.scattered_results: dict[str, list[Any]] = {}
|
| 66 |
+
self.selected_variants: dict[str, int] = {}
|
| 67 |
+
self.clients: dict[str, Any] = {}
|
| 68 |
+
self.concurrency = ConcurrencyManager()
|
| 69 |
+
|
| 70 |
+
self._executing_nodes: dict[str, asyncio.Event] = {}
|
| 71 |
+
self._execution_lock = asyncio.Lock()
|
| 72 |
+
|
| 73 |
+
def set_hf_token(self, token: str | None):
|
| 74 |
+
"""Update the HF token and clear cached clients."""
|
| 75 |
+
if token != self.hf_token:
|
| 76 |
+
self.hf_token = token
|
| 77 |
+
self.clients = {}
|
| 78 |
+
|
| 79 |
+
def clear_results(self):
|
| 80 |
+
"""Clear cached results for a fresh execution."""
|
| 81 |
+
self.results = {}
|
| 82 |
+
self.scattered_results = {}
|
| 83 |
+
|
| 84 |
+
async def wait_for_node(self, node_name: str) -> bool:
|
| 85 |
+
"""Wait for a node to finish executing if it's currently running.
|
| 86 |
+
|
| 87 |
+
Returns True if we waited (node was executing), False otherwise.
|
| 88 |
+
"""
|
| 89 |
+
async with self._execution_lock:
|
| 90 |
+
event = self._executing_nodes.get(node_name)
|
| 91 |
+
|
| 92 |
+
if event:
|
| 93 |
+
await event.wait()
|
| 94 |
+
return True
|
| 95 |
+
return False
|
| 96 |
+
|
| 97 |
+
async def start_node_execution(self, node_name: str) -> bool:
|
| 98 |
+
"""Mark a node as starting execution.
|
| 99 |
+
|
| 100 |
+
Returns True if we can start (no one else is executing it).
|
| 101 |
+
Returns False if someone else is already executing it.
|
| 102 |
+
"""
|
| 103 |
+
async with self._execution_lock:
|
| 104 |
+
if node_name in self._executing_nodes:
|
| 105 |
+
return False
|
| 106 |
+
self._executing_nodes[node_name] = asyncio.Event()
|
| 107 |
+
return True
|
| 108 |
+
|
| 109 |
+
async def finish_node_execution(self, node_name: str):
|
| 110 |
+
"""Mark a node as finished executing and notify waiters."""
|
| 111 |
+
async with self._execution_lock:
|
| 112 |
+
event = self._executing_nodes.pop(node_name, None)
|
| 113 |
+
if event:
|
| 114 |
+
event.set()
|
daggr/state.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import sqlite3
|
| 6 |
+
import uuid
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
from huggingface_hub import constants
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_daggr_cache_dir() -> Path:
|
| 15 |
+
"""Get the daggr cache directory, respecting HF_HOME env var."""
|
| 16 |
+
cache_dir = Path(constants.HF_HOME) / "daggr"
|
| 17 |
+
cache_dir.mkdir(parents=True, exist_ok=True)
|
| 18 |
+
return cache_dir
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_daggr_files_dir() -> Path:
|
| 22 |
+
files_dir = get_daggr_cache_dir() / "files"
|
| 23 |
+
files_dir.mkdir(parents=True, exist_ok=True)
|
| 24 |
+
return files_dir
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class SessionState:
|
| 28 |
+
def __init__(self, db_path: str | None = None):
|
| 29 |
+
if db_path is None:
|
| 30 |
+
db_path = str(get_daggr_cache_dir() / "sessions.db")
|
| 31 |
+
self.db_path = db_path
|
| 32 |
+
self._init_db()
|
| 33 |
+
|
| 34 |
+
def _init_db(self):
|
| 35 |
+
conn = sqlite3.connect(self.db_path)
|
| 36 |
+
cursor = conn.cursor()
|
| 37 |
+
|
| 38 |
+
self._migrate_legacy_schema(cursor)
|
| 39 |
+
|
| 40 |
+
cursor.execute("""
|
| 41 |
+
CREATE TABLE IF NOT EXISTS sheets (
|
| 42 |
+
sheet_id TEXT PRIMARY KEY,
|
| 43 |
+
user_id TEXT NOT NULL,
|
| 44 |
+
graph_name TEXT NOT NULL,
|
| 45 |
+
name TEXT,
|
| 46 |
+
transform TEXT,
|
| 47 |
+
created_at TEXT,
|
| 48 |
+
updated_at TEXT
|
| 49 |
+
)
|
| 50 |
+
""")
|
| 51 |
+
|
| 52 |
+
cursor.execute("PRAGMA table_info(sheets)")
|
| 53 |
+
columns = [col[1] for col in cursor.fetchall()]
|
| 54 |
+
if "transform" not in columns:
|
| 55 |
+
cursor.execute("ALTER TABLE sheets ADD COLUMN transform TEXT")
|
| 56 |
+
|
| 57 |
+
cursor.execute("""
|
| 58 |
+
CREATE INDEX IF NOT EXISTS idx_sheets_user_graph
|
| 59 |
+
ON sheets(user_id, graph_name)
|
| 60 |
+
""")
|
| 61 |
+
|
| 62 |
+
cursor.execute("""
|
| 63 |
+
CREATE TABLE IF NOT EXISTS node_inputs (
|
| 64 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 65 |
+
sheet_id TEXT,
|
| 66 |
+
node_name TEXT,
|
| 67 |
+
port_name TEXT,
|
| 68 |
+
value TEXT,
|
| 69 |
+
updated_at TEXT,
|
| 70 |
+
FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE,
|
| 71 |
+
UNIQUE(sheet_id, node_name, port_name)
|
| 72 |
+
)
|
| 73 |
+
""")
|
| 74 |
+
|
| 75 |
+
cursor.execute("""
|
| 76 |
+
CREATE INDEX IF NOT EXISTS idx_node_inputs_sheet
|
| 77 |
+
ON node_inputs(sheet_id)
|
| 78 |
+
""")
|
| 79 |
+
|
| 80 |
+
cursor.execute("""
|
| 81 |
+
CREATE TABLE IF NOT EXISTS node_results (
|
| 82 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 83 |
+
sheet_id TEXT,
|
| 84 |
+
node_name TEXT,
|
| 85 |
+
result TEXT,
|
| 86 |
+
inputs_snapshot TEXT,
|
| 87 |
+
created_at TEXT,
|
| 88 |
+
FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE
|
| 89 |
+
)
|
| 90 |
+
""")
|
| 91 |
+
|
| 92 |
+
cursor.execute("PRAGMA table_info(node_results)")
|
| 93 |
+
result_columns = [col[1] for col in cursor.fetchall()]
|
| 94 |
+
if "inputs_snapshot" not in result_columns:
|
| 95 |
+
cursor.execute("ALTER TABLE node_results ADD COLUMN inputs_snapshot TEXT")
|
| 96 |
+
|
| 97 |
+
cursor.execute("""
|
| 98 |
+
CREATE INDEX IF NOT EXISTS idx_node_results_sheet_node
|
| 99 |
+
ON node_results(sheet_id, node_name)
|
| 100 |
+
""")
|
| 101 |
+
|
| 102 |
+
conn.commit()
|
| 103 |
+
conn.close()
|
| 104 |
+
|
| 105 |
+
def _migrate_legacy_schema(self, cursor):
|
| 106 |
+
cursor.execute(
|
| 107 |
+
"SELECT name FROM sqlite_master WHERE type='table' AND name='node_inputs'"
|
| 108 |
+
)
|
| 109 |
+
if cursor.fetchone():
|
| 110 |
+
cursor.execute("PRAGMA table_info(node_inputs)")
|
| 111 |
+
columns = [col[1] for col in cursor.fetchall()]
|
| 112 |
+
if "session_id" in columns and "sheet_id" not in columns:
|
| 113 |
+
cursor.execute("ALTER TABLE node_inputs RENAME TO _node_inputs_old")
|
| 114 |
+
cursor.execute("ALTER TABLE node_results RENAME TO _node_results_old")
|
| 115 |
+
cursor.execute("ALTER TABLE sessions RENAME TO _sessions_old")
|
| 116 |
+
|
| 117 |
+
cursor.execute("""
|
| 118 |
+
CREATE TABLE sheets (
|
| 119 |
+
sheet_id TEXT PRIMARY KEY,
|
| 120 |
+
user_id TEXT NOT NULL,
|
| 121 |
+
graph_name TEXT NOT NULL,
|
| 122 |
+
name TEXT,
|
| 123 |
+
created_at TEXT,
|
| 124 |
+
updated_at TEXT
|
| 125 |
+
)
|
| 126 |
+
""")
|
| 127 |
+
cursor.execute("""
|
| 128 |
+
CREATE TABLE node_inputs (
|
| 129 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 130 |
+
sheet_id TEXT,
|
| 131 |
+
node_name TEXT,
|
| 132 |
+
port_name TEXT,
|
| 133 |
+
value TEXT,
|
| 134 |
+
updated_at TEXT,
|
| 135 |
+
FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE,
|
| 136 |
+
UNIQUE(sheet_id, node_name, port_name)
|
| 137 |
+
)
|
| 138 |
+
""")
|
| 139 |
+
cursor.execute("""
|
| 140 |
+
CREATE TABLE node_results (
|
| 141 |
+
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 142 |
+
sheet_id TEXT,
|
| 143 |
+
node_name TEXT,
|
| 144 |
+
result TEXT,
|
| 145 |
+
created_at TEXT,
|
| 146 |
+
FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE
|
| 147 |
+
)
|
| 148 |
+
""")
|
| 149 |
+
|
| 150 |
+
cursor.execute("""
|
| 151 |
+
INSERT INTO sheets (sheet_id, user_id, graph_name, name, created_at, updated_at)
|
| 152 |
+
SELECT session_id, 'local', graph_name, 'Migrated Sheet', created_at, updated_at
|
| 153 |
+
FROM _sessions_old
|
| 154 |
+
""")
|
| 155 |
+
cursor.execute("""
|
| 156 |
+
INSERT INTO node_inputs (sheet_id, node_name, port_name, value, updated_at)
|
| 157 |
+
SELECT session_id, node_name, port_name, value, updated_at
|
| 158 |
+
FROM _node_inputs_old
|
| 159 |
+
""")
|
| 160 |
+
cursor.execute("""
|
| 161 |
+
INSERT INTO node_results (sheet_id, node_name, result, created_at)
|
| 162 |
+
SELECT session_id, node_name, result, created_at
|
| 163 |
+
FROM _node_results_old
|
| 164 |
+
""")
|
| 165 |
+
|
| 166 |
+
cursor.execute("DROP TABLE _sessions_old")
|
| 167 |
+
cursor.execute("DROP TABLE _node_inputs_old")
|
| 168 |
+
cursor.execute("DROP TABLE _node_results_old")
|
| 169 |
+
|
| 170 |
+
def get_effective_user_id(self, hf_user: dict | None = None) -> str | None:
|
| 171 |
+
is_on_spaces = os.environ.get("SPACE_ID") is not None
|
| 172 |
+
if hf_user and hf_user.get("username"):
|
| 173 |
+
return hf_user["username"]
|
| 174 |
+
if is_on_spaces:
|
| 175 |
+
return None
|
| 176 |
+
return "local"
|
| 177 |
+
|
| 178 |
+
def create_sheet(
|
| 179 |
+
self, user_id: str, graph_name: str, name: str | None = None
|
| 180 |
+
) -> str:
|
| 181 |
+
sheet_id = str(uuid.uuid4())
|
| 182 |
+
now = datetime.now().isoformat()
|
| 183 |
+
|
| 184 |
+
if not name:
|
| 185 |
+
count = self.get_sheet_count(user_id, graph_name)
|
| 186 |
+
name = f"Sheet {count + 1}"
|
| 187 |
+
|
| 188 |
+
conn = sqlite3.connect(self.db_path)
|
| 189 |
+
cursor = conn.cursor()
|
| 190 |
+
cursor.execute(
|
| 191 |
+
"""INSERT INTO sheets (sheet_id, user_id, graph_name, name, created_at, updated_at)
|
| 192 |
+
VALUES (?, ?, ?, ?, ?, ?)""",
|
| 193 |
+
(sheet_id, user_id, graph_name, name, now, now),
|
| 194 |
+
)
|
| 195 |
+
conn.commit()
|
| 196 |
+
conn.close()
|
| 197 |
+
return sheet_id
|
| 198 |
+
|
| 199 |
+
def get_sheet_count(self, user_id: str, graph_name: str) -> int:
|
| 200 |
+
conn = sqlite3.connect(self.db_path)
|
| 201 |
+
cursor = conn.cursor()
|
| 202 |
+
cursor.execute(
|
| 203 |
+
"SELECT COUNT(*) FROM sheets WHERE user_id = ? AND graph_name = ?",
|
| 204 |
+
(user_id, graph_name),
|
| 205 |
+
)
|
| 206 |
+
count = cursor.fetchone()[0]
|
| 207 |
+
conn.close()
|
| 208 |
+
return count
|
| 209 |
+
|
| 210 |
+
def list_sheets(self, user_id: str, graph_name: str) -> list[dict[str, Any]]:
|
| 211 |
+
conn = sqlite3.connect(self.db_path)
|
| 212 |
+
cursor = conn.cursor()
|
| 213 |
+
cursor.execute(
|
| 214 |
+
"""SELECT sheet_id, name, created_at, updated_at
|
| 215 |
+
FROM sheets
|
| 216 |
+
WHERE user_id = ? AND graph_name = ?
|
| 217 |
+
ORDER BY updated_at DESC""",
|
| 218 |
+
(user_id, graph_name),
|
| 219 |
+
)
|
| 220 |
+
rows = cursor.fetchall()
|
| 221 |
+
conn.close()
|
| 222 |
+
return [
|
| 223 |
+
{
|
| 224 |
+
"sheet_id": row[0],
|
| 225 |
+
"name": row[1],
|
| 226 |
+
"created_at": row[2],
|
| 227 |
+
"updated_at": row[3],
|
| 228 |
+
}
|
| 229 |
+
for row in rows
|
| 230 |
+
]
|
| 231 |
+
|
| 232 |
+
def get_sheet(self, sheet_id: str) -> dict[str, Any] | None:
|
| 233 |
+
conn = sqlite3.connect(self.db_path)
|
| 234 |
+
cursor = conn.cursor()
|
| 235 |
+
cursor.execute(
|
| 236 |
+
"""SELECT sheet_id, user_id, graph_name, name, transform, created_at, updated_at
|
| 237 |
+
FROM sheets WHERE sheet_id = ?""",
|
| 238 |
+
(sheet_id,),
|
| 239 |
+
)
|
| 240 |
+
row = cursor.fetchone()
|
| 241 |
+
conn.close()
|
| 242 |
+
if row:
|
| 243 |
+
transform = None
|
| 244 |
+
if row[4]:
|
| 245 |
+
try:
|
| 246 |
+
transform = json.loads(row[4])
|
| 247 |
+
except (json.JSONDecodeError, TypeError):
|
| 248 |
+
pass
|
| 249 |
+
return {
|
| 250 |
+
"sheet_id": row[0],
|
| 251 |
+
"user_id": row[1],
|
| 252 |
+
"graph_name": row[2],
|
| 253 |
+
"name": row[3],
|
| 254 |
+
"transform": transform,
|
| 255 |
+
"created_at": row[5],
|
| 256 |
+
"updated_at": row[6],
|
| 257 |
+
}
|
| 258 |
+
return None
|
| 259 |
+
|
| 260 |
+
def save_transform(self, sheet_id: str, x: float, y: float, scale: float) -> bool:
|
| 261 |
+
now = datetime.now().isoformat()
|
| 262 |
+
transform = json.dumps({"x": x, "y": y, "scale": scale})
|
| 263 |
+
conn = sqlite3.connect(self.db_path)
|
| 264 |
+
cursor = conn.cursor()
|
| 265 |
+
cursor.execute(
|
| 266 |
+
"UPDATE sheets SET transform = ?, updated_at = ? WHERE sheet_id = ?",
|
| 267 |
+
(transform, now, sheet_id),
|
| 268 |
+
)
|
| 269 |
+
updated = cursor.rowcount > 0
|
| 270 |
+
conn.commit()
|
| 271 |
+
conn.close()
|
| 272 |
+
return updated
|
| 273 |
+
|
| 274 |
+
def rename_sheet(self, sheet_id: str, new_name: str) -> bool:
|
| 275 |
+
now = datetime.now().isoformat()
|
| 276 |
+
conn = sqlite3.connect(self.db_path)
|
| 277 |
+
cursor = conn.cursor()
|
| 278 |
+
cursor.execute(
|
| 279 |
+
"UPDATE sheets SET name = ?, updated_at = ? WHERE sheet_id = ?",
|
| 280 |
+
(new_name, now, sheet_id),
|
| 281 |
+
)
|
| 282 |
+
updated = cursor.rowcount > 0
|
| 283 |
+
conn.commit()
|
| 284 |
+
conn.close()
|
| 285 |
+
return updated
|
| 286 |
+
|
| 287 |
+
def delete_sheet(self, sheet_id: str) -> bool:
|
| 288 |
+
conn = sqlite3.connect(self.db_path)
|
| 289 |
+
cursor = conn.cursor()
|
| 290 |
+
cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
|
| 291 |
+
cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
|
| 292 |
+
cursor.execute("DELETE FROM sheets WHERE sheet_id = ?", (sheet_id,))
|
| 293 |
+
deleted = cursor.rowcount > 0
|
| 294 |
+
conn.commit()
|
| 295 |
+
conn.close()
|
| 296 |
+
return deleted
|
| 297 |
+
|
| 298 |
+
def get_or_create_sheet(
|
| 299 |
+
self, user_id: str, graph_name: str, sheet_id: str | None = None
|
| 300 |
+
) -> str:
|
| 301 |
+
if sheet_id:
|
| 302 |
+
sheet = self.get_sheet(sheet_id)
|
| 303 |
+
if sheet and sheet["user_id"] == user_id:
|
| 304 |
+
return sheet_id
|
| 305 |
+
|
| 306 |
+
sheets = self.list_sheets(user_id, graph_name)
|
| 307 |
+
if sheets:
|
| 308 |
+
return sheets[0]["sheet_id"]
|
| 309 |
+
|
| 310 |
+
return self.create_sheet(user_id, graph_name)
|
| 311 |
+
|
| 312 |
+
def save_input(self, sheet_id: str, node_name: str, port_name: str, value: Any):
|
| 313 |
+
now = datetime.now().isoformat()
|
| 314 |
+
value_json = json.dumps(value, default=str)
|
| 315 |
+
conn = sqlite3.connect(self.db_path)
|
| 316 |
+
cursor = conn.cursor()
|
| 317 |
+
cursor.execute(
|
| 318 |
+
"""INSERT INTO node_inputs (sheet_id, node_name, port_name, value, updated_at)
|
| 319 |
+
VALUES (?, ?, ?, ?, ?)
|
| 320 |
+
ON CONFLICT(sheet_id, node_name, port_name)
|
| 321 |
+
DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at""",
|
| 322 |
+
(sheet_id, node_name, port_name, value_json, now),
|
| 323 |
+
)
|
| 324 |
+
cursor.execute(
|
| 325 |
+
"UPDATE sheets SET updated_at = ? WHERE sheet_id = ?",
|
| 326 |
+
(now, sheet_id),
|
| 327 |
+
)
|
| 328 |
+
conn.commit()
|
| 329 |
+
conn.close()
|
| 330 |
+
|
| 331 |
+
def get_inputs(self, sheet_id: str) -> dict[str, dict[str, Any]]:
|
| 332 |
+
conn = sqlite3.connect(self.db_path)
|
| 333 |
+
cursor = conn.cursor()
|
| 334 |
+
cursor.execute(
|
| 335 |
+
"SELECT node_name, port_name, value FROM node_inputs WHERE sheet_id = ?",
|
| 336 |
+
(sheet_id,),
|
| 337 |
+
)
|
| 338 |
+
results = cursor.fetchall()
|
| 339 |
+
conn.close()
|
| 340 |
+
inputs: dict[str, dict[str, Any]] = {}
|
| 341 |
+
for node_name, port_name, value_json in results:
|
| 342 |
+
if node_name not in inputs:
|
| 343 |
+
inputs[node_name] = {}
|
| 344 |
+
inputs[node_name][port_name] = json.loads(value_json)
|
| 345 |
+
return inputs
|
| 346 |
+
|
| 347 |
+
def save_result(
|
| 348 |
+
self,
|
| 349 |
+
sheet_id: str,
|
| 350 |
+
node_name: str,
|
| 351 |
+
result: Any,
|
| 352 |
+
inputs_snapshot: dict[str, Any] | None = None,
|
| 353 |
+
):
|
| 354 |
+
now = datetime.now().isoformat()
|
| 355 |
+
result_json = json.dumps(result, default=str)
|
| 356 |
+
inputs_json = (
|
| 357 |
+
json.dumps(inputs_snapshot, default=str) if inputs_snapshot else None
|
| 358 |
+
)
|
| 359 |
+
conn = sqlite3.connect(self.db_path)
|
| 360 |
+
cursor = conn.cursor()
|
| 361 |
+
cursor.execute(
|
| 362 |
+
"INSERT INTO node_results (sheet_id, node_name, result, inputs_snapshot, created_at) VALUES (?, ?, ?, ?, ?)",
|
| 363 |
+
(sheet_id, node_name, result_json, inputs_json, now),
|
| 364 |
+
)
|
| 365 |
+
cursor.execute(
|
| 366 |
+
"UPDATE sheets SET updated_at = ? WHERE sheet_id = ?",
|
| 367 |
+
(now, sheet_id),
|
| 368 |
+
)
|
| 369 |
+
conn.commit()
|
| 370 |
+
conn.close()
|
| 371 |
+
|
| 372 |
+
def get_latest_result(self, sheet_id: str, node_name: str) -> Any | None:
|
| 373 |
+
conn = sqlite3.connect(self.db_path)
|
| 374 |
+
cursor = conn.cursor()
|
| 375 |
+
cursor.execute(
|
| 376 |
+
"""SELECT result FROM node_results
|
| 377 |
+
WHERE sheet_id = ? AND node_name = ?
|
| 378 |
+
ORDER BY created_at DESC LIMIT 1""",
|
| 379 |
+
(sheet_id, node_name),
|
| 380 |
+
)
|
| 381 |
+
result = cursor.fetchone()
|
| 382 |
+
conn.close()
|
| 383 |
+
if result:
|
| 384 |
+
return json.loads(result[0])
|
| 385 |
+
return None
|
| 386 |
+
|
| 387 |
+
def get_result_count(self, sheet_id: str, node_name: str) -> int:
|
| 388 |
+
conn = sqlite3.connect(self.db_path)
|
| 389 |
+
cursor = conn.cursor()
|
| 390 |
+
cursor.execute(
|
| 391 |
+
"SELECT COUNT(*) FROM node_results WHERE sheet_id = ? AND node_name = ?",
|
| 392 |
+
(sheet_id, node_name),
|
| 393 |
+
)
|
| 394 |
+
count = cursor.fetchone()[0]
|
| 395 |
+
conn.close()
|
| 396 |
+
return count
|
| 397 |
+
|
| 398 |
+
def get_result_by_index(
|
| 399 |
+
self, sheet_id: str, node_name: str, index: int
|
| 400 |
+
) -> Any | None:
|
| 401 |
+
conn = sqlite3.connect(self.db_path)
|
| 402 |
+
cursor = conn.cursor()
|
| 403 |
+
cursor.execute(
|
| 404 |
+
"""SELECT result FROM node_results
|
| 405 |
+
WHERE sheet_id = ? AND node_name = ?
|
| 406 |
+
ORDER BY created_at ASC""",
|
| 407 |
+
(sheet_id, node_name),
|
| 408 |
+
)
|
| 409 |
+
results = cursor.fetchall()
|
| 410 |
+
conn.close()
|
| 411 |
+
if results and 0 <= index < len(results):
|
| 412 |
+
return json.loads(results[index][0])
|
| 413 |
+
elif results:
|
| 414 |
+
return json.loads(results[-1][0])
|
| 415 |
+
return None
|
| 416 |
+
|
| 417 |
+
def get_all_results(self, sheet_id: str) -> dict[str, list[Any]]:
|
| 418 |
+
conn = sqlite3.connect(self.db_path)
|
| 419 |
+
cursor = conn.cursor()
|
| 420 |
+
cursor.execute(
|
| 421 |
+
"""SELECT node_name, result, inputs_snapshot FROM node_results
|
| 422 |
+
WHERE sheet_id = ?
|
| 423 |
+
ORDER BY created_at ASC""",
|
| 424 |
+
(sheet_id,),
|
| 425 |
+
)
|
| 426 |
+
results = cursor.fetchall()
|
| 427 |
+
conn.close()
|
| 428 |
+
all_results: dict[str, list[Any]] = {}
|
| 429 |
+
for node_name, result_json, inputs_json in results:
|
| 430 |
+
if node_name not in all_results:
|
| 431 |
+
all_results[node_name] = []
|
| 432 |
+
result_data = {
|
| 433 |
+
"result": json.loads(result_json),
|
| 434 |
+
"inputs_snapshot": json.loads(inputs_json) if inputs_json else None,
|
| 435 |
+
}
|
| 436 |
+
all_results[node_name].append(result_data)
|
| 437 |
+
return all_results
|
| 438 |
+
|
| 439 |
+
def get_sheet_state(self, sheet_id: str) -> dict[str, Any]:
|
| 440 |
+
return {
|
| 441 |
+
"inputs": self.get_inputs(sheet_id),
|
| 442 |
+
"results": self.get_all_results(sheet_id),
|
| 443 |
+
}
|
| 444 |
+
|
| 445 |
+
def clear_sheet_data(self, sheet_id: str):
|
| 446 |
+
conn = sqlite3.connect(self.db_path)
|
| 447 |
+
cursor = conn.cursor()
|
| 448 |
+
cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
|
| 449 |
+
cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
|
| 450 |
+
conn.commit()
|
| 451 |
+
conn.close()
|
| 452 |
+
|
| 453 |
+
def create_session(self, graph_name: str) -> str:
|
| 454 |
+
return self.create_sheet("local", graph_name)
|
| 455 |
+
|
| 456 |
+
def get_or_create_session(self, session_id: str | None, graph_name: str) -> str:
|
| 457 |
+
return self.get_or_create_sheet("local", graph_name, session_id)
|
pyproject.toml
CHANGED
|
@@ -6,13 +6,13 @@ readme = "README.md"
|
|
| 6 |
license = {text = "MIT"}
|
| 7 |
requires-python = ">=3.11"
|
| 8 |
dependencies = [
|
| 9 |
-
"gradio=
|
| 10 |
-
"huggingface_hub=
|
| 11 |
"modal",
|
| 12 |
"typer",
|
| 13 |
"pillow",
|
| 14 |
"python-dotenv",
|
| 15 |
-
"daggr",
|
| 16 |
]
|
| 17 |
|
| 18 |
[project.optional-dependencies]
|
|
|
|
| 6 |
license = {text = "MIT"}
|
| 7 |
requires-python = ">=3.11"
|
| 8 |
dependencies = [
|
| 9 |
+
"gradio>=4.19.2",
|
| 10 |
+
"huggingface_hub>=0.20.3",
|
| 11 |
"modal",
|
| 12 |
"typer",
|
| 13 |
"pillow",
|
| 14 |
"python-dotenv",
|
| 15 |
+
"daggr @ git+https://github.com/gradio-app/daggr.git",
|
| 16 |
]
|
| 17 |
|
| 18 |
[project.optional-dependencies]
|
requirements.txt
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
gradio==4.19.2
|
| 2 |
-
huggingface_hub==0.20.3
|
| 3 |
modal
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
daggr
|
|
|
|
| 1 |
gradio==4.19.2
|
|
|
|
| 2 |
modal
|
| 3 |
+
Pillow
|
| 4 |
+
./daggr
|
|
|
test_daggr_init.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import os
|
| 3 |
+
import io
|
| 4 |
+
import modal
|
| 5 |
+
import gradio as gr
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from daggr import FnNode, Graph
|
| 8 |
+
|
| 9 |
+
def convert_image_to_bytes(image) -> bytes:
|
| 10 |
+
return b"test"
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
print("Attempting to create FnNode...")
|
| 14 |
+
converter = FnNode(
|
| 15 |
+
fn=convert_image_to_bytes,
|
| 16 |
+
name="Image Converter",
|
| 17 |
+
inputs={
|
| 18 |
+
"image": gr.Image(label="Upload your photo"),
|
| 19 |
+
},
|
| 20 |
+
outputs={
|
| 21 |
+
"output": gr.Textbox(visible=False),
|
| 22 |
+
},
|
| 23 |
+
)
|
| 24 |
+
print("Success!")
|
| 25 |
+
print(f"Output port: {converter.output}")
|
| 26 |
+
except Exception as e:
|
| 27 |
+
import traceback
|
| 28 |
+
traceback.print_exc()
|
uv.lock
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|