Alexramsal commited on
Commit
e248bdb
·
1 Parent(s): f44fecc

Preparación final para despliegue con Daggr visual

Browse files
README.md CHANGED
@@ -1,14 +1,17 @@
1
  ---
2
- title: Watercolor Portrait Generator
3
  emoji: 🎨
4
- colorFrom: indigo
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: "4.19.2"
8
- python_version: "3.11"
9
- app_file: app.py
10
  pinned: false
11
- license: mit
 
 
 
 
12
  ---
13
 
14
  # Generador de Retratos Acuarela con IA
 
1
  ---
2
+ title: Acuarela Portrait Daggr
3
  emoji: 🎨
4
+ colorFrom: blue
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 4.19.2
8
+ app_file: app_daggr.py
 
9
  pinned: false
10
+ tags:
11
+ - daggr
12
+ - modal
13
+ - sdxl
14
+ - watercolor
15
  ---
16
 
17
  # Generador de Retratos Acuarela con IA
app_daggr.py CHANGED
@@ -16,18 +16,24 @@ $ python app_daggr.py
16
 
17
  import os
18
  import io
 
 
 
 
 
 
 
 
 
 
 
 
19
  import modal
20
  import gradio as gr
21
  from PIL import Image
22
- import uuid
23
- from dotenv import load_dotenv
24
-
25
- # Import Daggr components
26
  from daggr import FnNode, GradioNode, InferenceNode, Graph
27
 
28
- load_dotenv()
29
-
30
- # --- Modal Setup ---
31
  try:
32
  ImageCaptioner = modal.Cls.from_name("acuarela-portrait", "ImageCaptioner")
33
  ImageGenerator = modal.Cls.from_name("acuarela-portrait", "ImageGenerator")
@@ -43,23 +49,71 @@ except Exception as e:
43
 
44
  # --- Function Nodes ---
45
 
46
- def convert_image_to_bytes(image) -> bytes:
47
- """Convert PIL Image to bytes for Modal processing"""
48
  if image is None:
49
  return None
50
- img_byte_arr = io.BytesIO()
51
- image.save(img_byte_arr, format='PNG')
52
- return img_byte_arr.getvalue()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
 
55
- def generate_caption(img_bytes: bytes) -> str:
56
  """
57
  Step 1: Generate image caption using Modal
58
  """
59
- if not MODAL_AVAILABLE or img_bytes is None:
60
- raise ValueError("Modal not available or image is None")
61
 
62
  try:
 
 
 
 
 
 
 
 
 
63
  captioner = ImageCaptioner()
64
  caption = captioner.caption.remote(img_bytes)
65
  if not caption:
@@ -82,14 +136,23 @@ def create_artistic_prompt(caption: str) -> str:
82
  return prompt
83
 
84
 
85
- def generate_watercolor_image(img_bytes: bytes, prompt: str) -> Image.Image:
86
  """
87
- Step 3: Generate watercolor image using Modal SDXL
88
  """
89
- if not MODAL_AVAILABLE or img_bytes is None:
90
- raise ValueError("Modal not available or image is None")
91
 
92
  try:
 
 
 
 
 
 
 
 
 
93
  negative_prompt = (
94
  "photorealistic, 3d render, photograph, complex background, "
95
  "white background, dark background, messy sketch, blurry, "
@@ -99,8 +162,14 @@ def generate_watercolor_image(img_bytes: bytes, prompt: str) -> Image.Image:
99
 
100
  generator = ImageGenerator()
101
  result_bytes = generator.generate.remote(img_bytes, prompt, negative_prompt, strength=0.65)
102
- result_image = Image.open(io.BytesIO(result_bytes))
103
- return result_image
 
 
 
 
 
 
104
  except Exception as e:
105
  raise ValueError(f"Image generation failed: {str(e)}")
106
 
@@ -113,15 +182,15 @@ def create_workflow():
113
  """
114
 
115
  # Node 1: Image Input (implicit - handled by Gradio interface)
116
- # Node 2: Convert image to bytes
117
  converter = FnNode(
118
- fn=convert_image_to_bytes,
119
  name="Image Converter",
120
  inputs={
121
- "image": gr.Image(label="Upload your photo", type="pil"),
122
  },
123
  outputs={
124
- "return": gr.Textbox(visible=False), # Hidden - intermediate step
125
  },
126
  )
127
 
@@ -130,10 +199,10 @@ def create_workflow():
130
  fn=generate_caption,
131
  name="Image Analysis (BLIP Caption)",
132
  inputs={
133
- "img_bytes": converter.return_,
134
  },
135
  outputs={
136
- "return": gr.Textbox(label="Generated Caption"),
137
  },
138
  )
139
 
@@ -142,10 +211,10 @@ def create_workflow():
142
  fn=create_artistic_prompt,
143
  name="Artistic Prompt Engineering",
144
  inputs={
145
- "caption": captioner.return_,
146
  },
147
  outputs={
148
- "return": gr.Textbox(label="Artistic Prompt", lines=3),
149
  },
150
  )
151
 
@@ -154,11 +223,11 @@ def create_workflow():
154
  fn=generate_watercolor_image,
155
  name="Watercolor Generation (SDXL)",
156
  inputs={
157
- "img_bytes": converter.return_,
158
- "prompt": prompt_engineer.return_,
159
  },
160
  outputs={
161
- "return": gr.Image(label="Watercolor Portrait"),
162
  },
163
  )
164
 
@@ -166,8 +235,7 @@ def create_workflow():
166
  graph = Graph(
167
  name="🎨 Acuarela Portrait Generator - Daggr Workflow",
168
  nodes=[converter, captioner, prompt_engineer, image_generator],
169
- description="Transform your photos into watercolor artworks using AI. "
170
- "Each step is visualized and can be rerun independently!",
171
  )
172
 
173
  return graph
 
16
 
17
  import os
18
  import io
19
+ import base64
20
+ import tempfile
21
+ import uuid
22
+ import shutil
23
+ from pathlib import Path
24
+
25
+ # Fix for Windows Long Paths in Gradio (must set before importing gradio)
26
+ if os.name == 'nt':
27
+ custom_temp = os.path.join(tempfile.gettempdir(), "gr")
28
+ os.makedirs(custom_temp, exist_ok=True)
29
+ os.environ["GRADIO_TEMP_DIR"] = custom_temp
30
+
31
  import modal
32
  import gradio as gr
33
  from PIL import Image
 
 
 
 
34
  from daggr import FnNode, GradioNode, InferenceNode, Graph
35
 
36
+ # --- Inference Setup ---
 
 
37
  try:
38
  ImageCaptioner = modal.Cls.from_name("acuarela-portrait", "ImageCaptioner")
39
  ImageGenerator = modal.Cls.from_name("acuarela-portrait", "ImageGenerator")
 
49
 
50
  # --- Function Nodes ---
51
 
52
def convert_image_to_path(image) -> str:
    """Save *image* to a temporary file and return its path for Daggr nodes.

    Accepts several input shapes that Gradio/Daggr may hand us:
    a base64 ``data:image`` URI, a filesystem path (str or Path),
    a PIL image (anything with a ``.save`` method), or raw bytes.

    Args:
        image: The incoming image in any of the forms above, or None.

    Returns:
        Path to a freshly written temp file, ``str(image)`` for
        unrecognized inputs, or None when *image* is None.

    Raises:
        Exception: Re-raises whatever failed during conversion after
            logging it (original traceback preserved).
    """
    if image is None:
        return None

    try:
        # Create a unique temp file in our short "gr" folder (the short
        # folder avoids Windows MAX_PATH issues with Gradio's deep temps).
        temp_dir = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
        os.makedirs(temp_dir, exist_ok=True)  # robust if env points at a missing dir
        temp_path = os.path.join(temp_dir, f"input_{uuid.uuid4()}.png")

        # Handle Base64 Data URI from Gradio/Daggr
        if isinstance(image, str) and image.startswith('data:image'):
            header, encoded = image.split(",", 1)
            data = base64.b64decode(encoded)
            with open(temp_path, "wb") as f:
                f.write(data)
            return temp_path

        # Handle file paths (already a path, just verify and return)
        if isinstance(image, (str, Path)):
            path_str = str(image)
            # Windows Long Path fix: prefix \\?\ only for very long paths
            if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
                abs_path = os.path.abspath(path_str)
                path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path

            # If it's already a file, we can just return it, but better to copy to temp
            # to avoid permissions/lifetime issues with Gradio's deep temp folders
            shutil.copy2(path_str, temp_path)
            return temp_path

        # Handle PIL objects (duck-typed on .save so PIL stays optional here)
        if hasattr(image, 'save'):
            image.save(temp_path, format='PNG')
            return temp_path

        # If it's bytes
        if isinstance(image, bytes):
            with open(temp_path, "wb") as f:
                f.write(image)
            return temp_path

        return str(image)
    except Exception as e:
        print(f"❌ Error in convert_image_to_path: {e}")
        # Bare raise keeps the original traceback; `raise e` would reset it.
        raise
98
 
99
 
100
+ def generate_caption(img_path: str) -> str:
101
  """
102
  Step 1: Generate image caption using Modal
103
  """
104
+ if not MODAL_AVAILABLE or img_path is None:
105
+ raise ValueError("Modal not available or image path is None")
106
 
107
  try:
108
+ # Read bytes from the provided path
109
+ path_str = str(img_path)
110
+ if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
111
+ abs_path = os.path.abspath(path_str)
112
+ path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path
113
+
114
+ with open(path_str, "rb") as f:
115
+ img_bytes = f.read()
116
+
117
  captioner = ImageCaptioner()
118
  caption = captioner.caption.remote(img_bytes)
119
  if not caption:
 
136
  return prompt
137
 
138
 
139
+ def generate_watercolor_image(img_path: str, prompt: str) -> str:
140
  """
141
+ Step 3: Generate watercolor image using Modal SDXL and return the file path
142
  """
143
+ if not MODAL_AVAILABLE or img_path is None:
144
+ raise ValueError("Modal not available or image path is None")
145
 
146
  try:
147
+ # Read bytes from the provided path
148
+ path_str = str(img_path)
149
+ if os.name == 'nt' and not path_str.startswith('\\\\?\\'):
150
+ abs_path = os.path.abspath(path_str)
151
+ path_str = '\\\\?\\' + abs_path if len(abs_path) > 250 else abs_path
152
+
153
+ with open(path_str, "rb") as f:
154
+ img_bytes = f.read()
155
+
156
  negative_prompt = (
157
  "photorealistic, 3d render, photograph, complex background, "
158
  "white background, dark background, messy sketch, blurry, "
 
162
 
163
  generator = ImageGenerator()
164
  result_bytes = generator.generate.remote(img_bytes, prompt, negative_prompt, strength=0.65)
165
+
166
+ # Save result to a temp file
167
+ temp_dir = os.environ.get("GRADIO_TEMP_DIR", tempfile.gettempdir())
168
+ output_path = os.path.join(temp_dir, f"output_{uuid.uuid4()}.png")
169
+ with open(output_path, "wb") as f:
170
+ f.write(result_bytes)
171
+
172
+ return output_path
173
  except Exception as e:
174
  raise ValueError(f"Image generation failed: {str(e)}")
175
 
 
182
  """
183
 
184
  # Node 1: Image Input (implicit - handled by Gradio interface)
185
+ # Node 2: Convert image to path (using string to avoid bytes JSON error)
186
  converter = FnNode(
187
+ fn=convert_image_to_path,
188
  name="Image Converter",
189
  inputs={
190
+ "image": gr.Image(label="Upload your photo"),
191
  },
192
  outputs={
193
+ "output": gr.Textbox(visible=False),
194
  },
195
  )
196
 
 
199
  fn=generate_caption,
200
  name="Image Analysis (BLIP Caption)",
201
  inputs={
202
+ "img_path": converter.output,
203
  },
204
  outputs={
205
+ "output": gr.Textbox(label="Generated Caption"),
206
  },
207
  )
208
 
 
211
  fn=create_artistic_prompt,
212
  name="Artistic Prompt Engineering",
213
  inputs={
214
+ "caption": captioner.output,
215
  },
216
  outputs={
217
+ "output": gr.Textbox(label="Artistic Prompt", lines=3),
218
  },
219
  )
220
 
 
223
  fn=generate_watercolor_image,
224
  name="Watercolor Generation (SDXL)",
225
  inputs={
226
+ "img_path": converter.output,
227
+ "prompt": prompt_engineer.output,
228
  },
229
  outputs={
230
+ "output": gr.Image(label="Watercolor Portrait"),
231
  },
232
  )
233
 
 
235
  graph = Graph(
236
  name="🎨 Acuarela Portrait Generator - Daggr Workflow",
237
  nodes=[converter, captioner, prompt_engineer, image_generator],
238
+ persist_key=False, # Disable persistence to avoid bytes -> str serialization issues
 
239
  )
240
 
241
  return graph
daggr/CHANGELOG.md ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # daggr
2
+
3
+ ## 0.7.0
4
+
5
+ ### Features
6
+
7
+ - [#69](https://github.com/gradio-app/daggr/pull/69) [`297d104`](https://github.com/gradio-app/daggr/commit/297d104d0e3fbd6b59dfcd1f69c9a478de81bc3d) - Add stop button to cancel running nodes. Thanks @abidlabs!
8
+ - [#64](https://github.com/gradio-app/daggr/pull/64) [`a9e53c6`](https://github.com/gradio-app/daggr/commit/a9e53c64db6b3beede0b19b3876a3e50ab572233) - Fix ChoiceNode: public .name property and respect explicit names. Thanks @abidlabs!
9
+ - [#62](https://github.com/gradio-app/daggr/pull/62) [`695411c`](https://github.com/gradio-app/daggr/commit/695411ce94bc8fedd3320ac941ab41233bf8f887) - Standardize file handling: all files are path strings. Thanks @abidlabs!
10
+ - [#66](https://github.com/gradio-app/daggr/pull/66) [`1ed16f8`](https://github.com/gradio-app/daggr/commit/1ed16f806d535413c3718fc47d81d79d93d73ee0) - Fix gr.JSON rendering (use @render snippet syntax). Thanks @abidlabs!
11
+ - [#63](https://github.com/gradio-app/daggr/pull/63) [`00b05ac`](https://github.com/gradio-app/daggr/commit/00b05ac526642c91560cbebb258626a4511082c2) - Fix file downloads from private HF Spaces. Thanks @abidlabs!
12
+ - [#70](https://github.com/gradio-app/daggr/pull/70) [`33ccb74`](https://github.com/gradio-app/daggr/commit/33ccb7470a482ee1b09fddf1d51de81b1e2c4a40) - Fix gr.Image not rendering with initial value or None input. Thanks @abidlabs!
13
+ - [#68](https://github.com/gradio-app/daggr/pull/68) [`4b76dca`](https://github.com/gradio-app/daggr/commit/4b76dca815e2802b70ac03bb95fb03f21d81a8fa) - Add dependency hash tracking for upstream Spaces and models. Thanks @abidlabs!
14
+
15
+ ## 0.6.0
16
+
17
+ ### Features
18
+
19
+ - [#54](https://github.com/gradio-app/daggr/pull/54) [`c1abb26`](https://github.com/gradio-app/daggr/commit/c1abb260b254af6ca2060292232049ea89f0f944) - Fix cache. Thanks @abidlabs!
20
+ - [#56](https://github.com/gradio-app/daggr/pull/56) [`6e3dfc0`](https://github.com/gradio-app/daggr/commit/6e3dfc0a585b673adb77bb11ab1dcfd80d01da5a) - Add paste from clipboard button to Image component. Thanks @abidlabs!
21
+ - [#57](https://github.com/gradio-app/daggr/pull/57) [`76855ba`](https://github.com/gradio-app/daggr/commit/76855ba967e3f3132e8ec0590ae037d3151af310) - Fix dropdown options being clipped inside node. Thanks @abidlabs!
22
+ - [#58](https://github.com/gradio-app/daggr/pull/58) [`eb52b72`](https://github.com/gradio-app/daggr/commit/eb52b725b17d277e85f6eac1cc9d07f8068b011b) - Add theme support to daggr. Thanks @abidlabs!
23
+ - [#59](https://github.com/gradio-app/daggr/pull/59) [`78189a4`](https://github.com/gradio-app/daggr/commit/78189a4163b4041c814e52110b65754dc4dbf863) - Add run mode dropdown to control node execution scope. Thanks @abidlabs!
24
+ - [#39](https://github.com/gradio-app/daggr/pull/39) [`e8792ad`](https://github.com/gradio-app/daggr/commit/e8792ad1b5818ff8d13660b0b156f329bbc1c33a) - feat: add --state-db-path CLI arg and DAGGR_DB_PATH env var support. Thanks @leith-bartrich!
25
+
26
+ ## 0.5.4
27
+
28
+ ### Features
29
+
30
+ - [#27](https://github.com/gradio-app/daggr/pull/27) [`3952b2c`](https://github.com/gradio-app/daggr/commit/3952b2ccf30e7d18994f23049c2a2e84b323cfd6) - changes. Thanks @abidlabs!
31
+
32
+ ## 0.5.3
33
+
34
+ ### Features
35
+
36
+ - [#19](https://github.com/gradio-app/daggr/pull/19) [`cd956fe`](https://github.com/gradio-app/daggr/commit/cd956fe29945bdfd31bbe76fcb80d3f9c97cc301) - Add daggr tag to deployed Spaces. Thanks @gary149!
37
+
38
+ ## 0.5.2
39
+
40
+ ### Features
41
+
42
+ - [#14](https://github.com/gradio-app/daggr/pull/14) [`3fa412d`](https://github.com/gradio-app/daggr/commit/3fa412d678988608d49d46d99d193a05469892d2) - Fixes. Thanks @abidlabs!
43
+
44
+ ## 0.5.1
45
+
46
+ ### Features
47
+
48
+ - [#11](https://github.com/gradio-app/daggr/pull/11) [`ce1d5f4`](https://github.com/gradio-app/daggr/commit/ce1d5f4deaac60d95d9a021b0aa057bc2941b018) - Fixes. Thanks @abidlabs!
49
+ - [#13](https://github.com/gradio-app/daggr/pull/13) [`3246921`](https://github.com/gradio-app/daggr/commit/32469213dad5fd29a7ac85938dffbd976e2c6643) - fixes. Thanks @abidlabs!
50
+
51
+ ## 0.5.0
52
+
53
+ ### Features
54
+
55
+ - [#8](https://github.com/gradio-app/daggr/pull/8) [`e480065`](https://github.com/gradio-app/daggr/commit/e480065dd058dbf19053a80956dbfc90cf3e3caf) - Improving security around executor and various bug fixes. Thanks @abidlabs!
56
+
57
+ ## 0.4.0
58
+
59
+ ### Features
60
+
61
+ - [#1](https://github.com/gradio-app/daggr/pull/1) [`23538c8`](https://github.com/gradio-app/daggr/commit/23538c884fb3f2d84bbe4bf14f475dc85fa17c79) - Refactor files, add Dialogue component, and implement fully working podcast example. Thanks @abidlabs!
62
+
63
+ ## 0.1.0
64
+
65
+ Initial release
daggr/__init__.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""daggr - Build visual, node-based AI pipelines with Gradio Spaces.

daggr lets you create DAG (directed acyclic graph) pipelines that connect
Gradio Spaces, Hugging Face models, and Python functions into interactive
applications.

Example:
    >>> from daggr import Graph, GradioNode, FnNode
    >>> import gradio as gr
    >>>
    >>> tts = GradioNode(
    ...     "mrfakename/MeloTTS",
    ...     inputs={"text": gr.Textbox()},
    ...     outputs={"audio": gr.Audio()},
    ... )
    >>> graph = Graph("TTS Demo", nodes=[tts])
    >>> graph.launch()
"""

import json
from pathlib import Path

# The canonical package version lives in package.json next to this module.
_package_json = Path(__file__).parent / "package.json"
__version__ = json.loads(_package_json.read_text())["version"]

from daggr.edge import Edge
from daggr.graph import Graph
from daggr.node import (
    ChoiceNode,
    FnNode,
    GradioNode,
    InferenceNode,
    InteractionNode,
    Node,
)
from daggr.port import ItemList, Port
from daggr.server import DaggrServer

__all__ = [
    "__version__",
    "ChoiceNode",
    "Edge",
    "Graph",
    "Node",
    "FnNode",
    "GradioNode",
    "InferenceNode",
    "InteractionNode",
    "ItemList",
    "Port",
    "DaggrServer",
]
daggr/_client_cache.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import hashlib
4
+ import json
5
+ import os
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ from daggr.state import get_daggr_cache_dir
10
+
11
+ _client_cache: dict[str, Any] = {}
12
+ _api_memory_cache: dict[str, dict] = {}
13
+ _validated_set: set[str] = set()
14
+ _model_task_cache: dict[str, str] = {}
15
+ _dependency_hash_cache: dict[str, str] = {}
16
+ _dependency_hash_loaded: bool = False
17
+
18
+
19
def _is_hot_reload() -> bool:
    """Return True when hot reload is active (DAGGR_HOT_RELOAD env var is "1")."""
    flag = os.environ.get("DAGGR_HOT_RELOAD")
    return flag == "1"
21
+
22
+
23
+ def _get_cache_path(src: str) -> Path:
24
+ src_hash = hashlib.md5(src.encode()).hexdigest()[:16]
25
+ return get_daggr_cache_dir() / f"{src_hash}.json"
26
+
27
+
28
+ def _get_validated_file() -> Path:
29
+ return get_daggr_cache_dir() / "_validated.json"
30
+
31
+
32
+ def _load_validated_set() -> None:
33
+ global _validated_set
34
+ if _validated_set:
35
+ return
36
+ if not _is_hot_reload():
37
+ return
38
+ validated_file = _get_validated_file()
39
+ if validated_file.exists():
40
+ try:
41
+ _validated_set = set(json.loads(validated_file.read_text()))
42
+ except (json.JSONDecodeError, OSError):
43
+ _validated_set = set()
44
+
45
+
46
+ def _save_validated_set() -> None:
47
+ if not _is_hot_reload():
48
+ return
49
+ try:
50
+ get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
51
+ _get_validated_file().write_text(json.dumps(list(_validated_set)))
52
+ except OSError:
53
+ pass
54
+
55
+
56
+ def is_validated(cache_key: tuple) -> bool:
57
+ if not _is_hot_reload():
58
+ return False
59
+ _load_validated_set()
60
+ return str(cache_key) in _validated_set
61
+
62
+
63
+ def mark_validated(cache_key: tuple) -> None:
64
+ if not _is_hot_reload():
65
+ return
66
+ _load_validated_set()
67
+ _validated_set.add(str(cache_key))
68
+ _save_validated_set()
69
+
70
+
71
+ def get_api_info(src: str) -> dict | None:
72
+ if src in _api_memory_cache:
73
+ return _api_memory_cache[src]
74
+
75
+ if not _is_hot_reload():
76
+ return None
77
+
78
+ cache_path = _get_cache_path(src)
79
+ if cache_path.exists():
80
+ try:
81
+ data = json.loads(cache_path.read_text())
82
+ _api_memory_cache[src] = data
83
+ return data
84
+ except (json.JSONDecodeError, OSError):
85
+ pass
86
+ return None
87
+
88
+
89
+ def set_api_info(src: str, info: dict) -> None:
90
+ _api_memory_cache[src] = info
91
+ if not _is_hot_reload():
92
+ return
93
+ try:
94
+ get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
95
+ cache_path = _get_cache_path(src)
96
+ cache_path.write_text(json.dumps(info))
97
+ except OSError:
98
+ pass
99
+
100
+
101
+ def get_client(src: str):
102
+ return _client_cache.get(src)
103
+
104
+
105
+ def set_client(src: str, client) -> None:
106
+ _client_cache[src] = client
107
+
108
+
109
+ def _get_model_task_cache_path() -> Path:
110
+ return get_daggr_cache_dir() / "_model_tasks.json"
111
+
112
+
113
+ def _load_model_task_cache() -> None:
114
+ global _model_task_cache
115
+ if _model_task_cache:
116
+ return
117
+ if not _is_hot_reload():
118
+ return
119
+ cache_path = _get_model_task_cache_path()
120
+ if cache_path.exists():
121
+ try:
122
+ _model_task_cache = json.loads(cache_path.read_text())
123
+ except (json.JSONDecodeError, OSError):
124
+ _model_task_cache = {}
125
+
126
+
127
+ def _save_model_task_cache() -> None:
128
+ if not _is_hot_reload():
129
+ return
130
+ try:
131
+ get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
132
+ _get_model_task_cache_path().write_text(json.dumps(_model_task_cache))
133
+ except OSError:
134
+ pass
135
+
136
+
137
+ def get_model_task(model: str) -> tuple[bool, str | None]:
138
+ """Get cached task for a model.
139
+
140
+ Returns:
141
+ (found_in_cache, task) where:
142
+ - found_in_cache is True if we have cached info for this model
143
+ - task is the pipeline_tag (can be None if model has no task, or "__NOT_FOUND__" if model doesn't exist)
144
+ """
145
+ if model in _model_task_cache:
146
+ return True, _model_task_cache[model]
147
+
148
+ if not _is_hot_reload():
149
+ return False, None
150
+
151
+ _load_model_task_cache()
152
+ if model in _model_task_cache:
153
+ return True, _model_task_cache[model]
154
+ return False, None
155
+
156
+
157
+ def set_model_task(model: str, task: str | None) -> None:
158
+ _model_task_cache[model] = task
159
+ _save_model_task_cache()
160
+
161
+
162
+ def set_model_not_found(model: str) -> None:
163
+ _model_task_cache[model] = "__NOT_FOUND__"
164
+ _save_model_task_cache()
165
+
166
+
167
+ def _get_dependency_hash_path() -> Path:
168
+ return get_daggr_cache_dir() / "_dependency_hashes.json"
169
+
170
+
171
+ def _load_dependency_hash_cache() -> None:
172
+ global _dependency_hash_cache, _dependency_hash_loaded
173
+ if _dependency_hash_loaded:
174
+ return
175
+ cache_path = _get_dependency_hash_path()
176
+ if cache_path.exists():
177
+ try:
178
+ _dependency_hash_cache = json.loads(cache_path.read_text())
179
+ except (json.JSONDecodeError, OSError):
180
+ _dependency_hash_cache = {}
181
+ _dependency_hash_loaded = True
182
+
183
+
184
+ def _save_dependency_hash_cache() -> None:
185
+ try:
186
+ get_daggr_cache_dir().mkdir(parents=True, exist_ok=True)
187
+ _get_dependency_hash_path().write_text(json.dumps(_dependency_hash_cache))
188
+ except OSError:
189
+ pass
190
+
191
+
192
+ def get_dependency_hash(src: str) -> str | None:
193
+ _load_dependency_hash_cache()
194
+ return _dependency_hash_cache.get(src)
195
+
196
+
197
+ def set_dependency_hash(src: str, sha: str) -> None:
198
+ _load_dependency_hash_cache()
199
+ _dependency_hash_cache[src] = sha
200
+ _save_dependency_hash_cache()
daggr/_utils.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Internal utilities for daggr."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import difflib
6
+
7
+
8
def suggest_similar(invalid: str, valid_options: set[str]) -> str | None:
    """Return the valid option that most closely resembles *invalid*.

    Args:
        invalid: The invalid string to find matches for.
        valid_options: Set of valid options to search through.

    Returns:
        The single closest match when it reaches at least 60% similarity
        (difflib ratio), otherwise None.
    """
    close = difflib.get_close_matches(invalid, valid_options, n=1, cutoff=0.6)
    if not close:
        return None
    return close[0]
daggr/assets/hf-logo-pirate.png ADDED

Git LFS Details

  • SHA256: 0df04e76c7d8d5becb1d7d554a471e72c0286150cd008ee83fbc3d133aa87f7d
  • Pointer size: 130 Bytes
  • Size of remote file: 23.3 kB
daggr/assets/logo_dark.png ADDED

Git LFS Details

  • SHA256: 62e9a658dbd7e689117bf751019e6b2907b06df337194e5d8a041b44e0c6fa32
  • Pointer size: 131 Bytes
  • Size of remote file: 150 kB
daggr/assets/logo_dark_small.png ADDED

Git LFS Details

  • SHA256: 5e5ee1e6baf6f68a3304663b125649a7f344388d175870951f84a7e5243c9d49
  • Pointer size: 129 Bytes
  • Size of remote file: 6.81 kB
daggr/assets/logo_light.png ADDED

Git LFS Details

  • SHA256: 7876fab9fbebe3fd6a05c6ff0f1f094149f71eb60823993e88247ab440b399bf
  • Pointer size: 131 Bytes
  • Size of remote file: 141 kB
daggr/cli.py ADDED
@@ -0,0 +1,689 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import ast
5
+ import importlib.util
6
+ import os
7
+ import re
8
+ import shutil
9
+ import socket
10
+ import sqlite3
11
+ import sys
12
+ import tempfile
13
+ import threading
14
+ import time
15
+ import webbrowser
16
+ from pathlib import Path
17
+
18
+ INITIAL_PORT_VALUE = int(os.getenv("DAGGR_SERVER_PORT", "7860"))
19
+ TRY_NUM_PORTS = int(os.getenv("DAGGR_NUM_PORTS", "100"))
20
+
21
+
22
def _find_available_port(host: str, start_port: int) -> int:
    """Find an available TCP port, scanning up from *start_port*.

    Probes up to TRY_NUM_PORTS consecutive ports by attempting to bind;
    a host of "0.0.0.0" is probed via 127.0.0.1.

    Args:
        host: Host/interface the server intends to bind to.
        start_port: First port number to try.

    Returns:
        The first port in the range that could be bound.

    Raises:
        OSError: If every port in the probed range is in use.
    """
    probe_host = host if host != "0.0.0.0" else "127.0.0.1"
    for port in range(start_port, start_port + TRY_NUM_PORTS):
        try:
            # Context manager guarantees the probe socket is closed even
            # when bind() raises; the original leaked the fd on failure.
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind((probe_host, port))
            return port
        except OSError:
            continue
    raise OSError(
        f"Cannot find empty port in range: {start_port}-{start_port + TRY_NUM_PORTS - 1}. "
        f"You can specify a different port by setting the DAGGR_SERVER_PORT environment variable "
        f"or passing the --port parameter."
    )
38
+
39
+
40
def find_python_imports(file_path: Path) -> list[Path]:
    """Find local Python files imported by the given file.

    Parses the file's AST and resolves every import against the file's own
    directory: a module that exists as a sibling ``.py`` file is returned,
    and a ``from``-import whose target is a local package contributes the
    package directory itself. Any parse/IO problem is swallowed and whatever
    was collected so far is returned (best-effort).
    """
    found: list[Path] = []
    try:
        base_dir = file_path.parent
        with open(file_path) as fh:
            tree = ast.parse(fh.read())

        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                for alias in node.names:
                    candidate = base_dir / f"{alias.name.replace('.', '/')}.py"
                    if candidate.exists():
                        found.append(candidate)
            elif isinstance(node, ast.ImportFrom) and node.module:
                rel = node.module.replace(".", "/")
                candidate = base_dir / f"{rel}.py"
                if candidate.exists():
                    found.append(candidate)
                init_file = base_dir / rel / "__init__.py"
                if init_file.exists():
                    # Local package: record the package directory itself.
                    found.append(init_file.parent)
    except Exception:
        # Best-effort: unreadable or unparsable files yield what we have.
        pass
    return found
70
+
71
+
72
def _build_run_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for running a daggr app."""
    parser = argparse.ArgumentParser(
        prog="daggr",
        description="Run a daggr app with hot reload",
    )
    parser.add_argument(
        "script",
        help="Path to the Python script containing the daggr Graph",
    )
    parser.add_argument(
        "--host",
        default="127.0.0.1",
        help="Host to bind to (default: 127.0.0.1)",
    )
    parser.add_argument(
        "--port",
        type=int,
        default=7860,
        help="Port to bind to (default: 7860)",
    )
    parser.add_argument(
        "--no-reload",
        action="store_true",
        help="Disable auto-reload",
    )
    parser.add_argument(
        "--watch-daggr",
        action="store_true",
        default=True,
        help="Watch daggr source for changes (default: True, useful for development)",
    )
    parser.add_argument(
        "--no-watch-daggr",
        action="store_true",
        help="Don't watch daggr source for changes",
    )
    parser.add_argument(
        "--delete-sheets",
        action="store_true",
        help="Delete all cached data (sheets, results, downloaded files) for this project and exit",
    )
    parser.add_argument(
        "--force",
        "-f",
        action="store_true",
        help="Skip confirmation prompts (use with --delete-sheets)",
    )
    parser.add_argument(
        "--state-db-path",
        help="Optional path to SQLite state database. Overrides DAGGR_DB_PATH env var. Defaults to HuggingFace cache.",
    )
    return parser


def main():
    """CLI entry point: dispatch `daggr deploy` or run a daggr script."""
    # The deploy subcommand has its own parser and entry point.
    if len(sys.argv) > 1 and sys.argv[1] == "deploy":
        _deploy_main()
        return

    args = _build_run_parser().parse_args()

    script = Path(args.script).resolve()
    if not script.exists():
        print(f"Error: Script not found: {script}")
        sys.exit(1)
    if script.suffix != ".py":
        print(f"Error: Script must be a Python file: {script}")
        sys.exit(1)

    # --delete-sheets is a maintenance action: wipe caches and exit.
    if args.delete_sheets:
        _delete_sheets(script, force=args.force)
        sys.exit(0)

    # --no-watch-daggr overrides the default-on --watch-daggr flag.
    watch_daggr = args.watch_daggr and not args.no_watch_daggr

    # The (re)spawned server process reads its configuration from env vars.
    os.environ["DAGGR_SCRIPT_PATH"] = str(script)
    os.environ["DAGGR_HOST"] = args.host
    os.environ["DAGGR_PORT"] = str(args.port)
    if args.state_db_path:
        os.environ["DAGGR_DB_PATH"] = str(Path(args.state_db_path).resolve())

    if args.no_reload:
        _run_script(script, args.host, args.port)
    else:
        os.environ["DAGGR_HOT_RELOAD"] = "1"
        _run_with_reload(script, args.host, args.port, watch_daggr)
156
+
157
+
158
def _deploy_main():
    """Entry point for the deploy subcommand (`daggr deploy <script> ...`)."""
    parser = argparse.ArgumentParser(
        prog="daggr deploy",
        description="Deploy a daggr app to Hugging Face Spaces",
    )
    parser.add_argument(
        "script",
        help="Path to the Python script containing the daggr Graph",
    )
    parser.add_argument(
        "--name",
        "-n",
        help="Space name (default: derived from Graph name)",
    )
    parser.add_argument(
        "--title",
        "-t",
        help="Display title for the Space (default: Graph name)",
    )
    parser.add_argument(
        "--org",
        "-o",
        help="Organization or username to deploy under (default: your HF account)",
    )
    parser.add_argument(
        "--private",
        "-p",
        action="store_true",
        help="Make the Space private",
    )
    parser.add_argument(
        "--hardware",
        default="cpu-basic",
        help="Hardware tier (default: cpu-basic). Options: cpu-basic, cpu-upgrade, t4-small, t4-medium, a10g-small, etc.",
    )
    parser.add_argument(
        "--secret",
        "-s",
        action="append",
        dest="secrets",
        metavar="KEY=VALUE",
        help="Add a secret (can be repeated). Example: --secret HF_TOKEN=xxx",
    )
    parser.add_argument(
        "--requirements",
        "-r",
        help="Path to requirements.txt (default: auto-detect or generate)",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Preview what would be deployed without actually deploying",
    )

    # argv[0] is the program, argv[1] is "deploy": parse everything after.
    args = parser.parse_args(sys.argv[2:])

    script = Path(args.script).resolve()
    if not script.exists():
        print(f"Error: Script not found: {script}")
        sys.exit(1)
    if script.suffix != ".py":
        print(f"Error: Script must be a Python file: {script}")
        sys.exit(1)

    # Parse repeated --secret KEY=VALUE flags into a dict.
    secrets = {}
    for entry in args.secrets or []:
        if "=" not in entry:
            print(f"Error: Invalid secret format '{entry}'. Use KEY=VALUE")
            sys.exit(1)
        key, _, value = entry.partition("=")
        secrets[key] = value

    _deploy(
        script_path=script,
        name=args.name,
        title=args.title,
        org=args.org,
        private=args.private,
        hardware=args.hardware,
        secrets=secrets,
        requirements_path=args.requirements,
        dry_run=args.dry_run,
    )
244
+
245
+
246
def _extract_graph(script_path: Path):
    """Extract the Graph object from a script without running it.

    Temporarily replaces ``Graph.launch`` with a capture stub, then
    executes the user script: if the script calls ``launch()``, the Graph
    instance is captured instead of starting a server. If it never calls
    ``launch()``, falls back to scanning the module's attributes for any
    Graph instance. Exits the process (status 1) if the script cannot be
    loaded or no Graph is found.

    Note: this intentionally has side effects — it prepends the script's
    directory to ``sys.path`` and registers the module under
    ``__daggr_deploy__`` in ``sys.modules``.
    """
    from daggr.graph import Graph

    sys.path.insert(0, str(script_path.parent))

    original_launch = Graph.launch
    captured_graph = None

    def capture_launch(self, **kwargs):
        # Stub that records the Graph instead of launching a server.
        nonlocal captured_graph
        captured_graph = self

    Graph.launch = capture_launch

    try:
        spec = importlib.util.spec_from_file_location("__daggr_deploy__", script_path)
        if spec is None or spec.loader is None:
            print(f"Error: Could not load script: {script_path}")
            sys.exit(1)

        module = importlib.util.module_from_spec(spec)
        sys.modules["__daggr_deploy__"] = module
        spec.loader.exec_module(module)
    finally:
        # Always restore the real launch, even if the user script raised.
        Graph.launch = original_launch

    if captured_graph is None:
        # Script defined a Graph but never called launch(); find it by scan.
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, Graph):
                captured_graph = obj
                break

    if captured_graph is None:
        print(f"Error: No Graph found in {script_path}")
        sys.exit(1)

    return captured_graph
285
+
286
+
287
+ def _sanitize_space_name(name: str) -> str:
288
+ """Convert a Graph name to a valid HF Space name."""
289
+ sanitized = re.sub(r"[^a-zA-Z0-9\s-]", "", name)
290
+ sanitized = re.sub(r"[\s_]+", "-", sanitized)
291
+ sanitized = sanitized.lower().strip("-")
292
+ return sanitized or "daggr-app"
293
+
294
+
295
def _deploy(
    script_path: Path,
    name: str | None,
    title: str | None,
    org: str | None,
    private: bool,
    hardware: str,
    secrets: dict[str, str],
    requirements_path: str | None,
    dry_run: bool,
):
    """Deploy a daggr app to Hugging Face Spaces.

    Steps: extract the Graph from the script (without launching it),
    derive the Space name/title, ensure a write-capable HF login, stage
    the files (script as ``app.py``, locally imported modules, a
    requirements.txt that is guaranteed to include daggr, and a generated
    README with Spaces front-matter) in a temp dir, then create/update the
    Space repo, upload the folder, and attach any secrets.

    Args:
        script_path: Resolved path to the user's Graph script.
        name: Override for the Space repo name (else slug of Graph name).
        title: Display title for the Space (else the Graph name).
        org: Namespace to deploy under (else the logged-in user).
        private: Whether to create the Space as private.
        hardware: Spaces hardware tier (e.g. "cpu-basic").
        secrets: Mapping of secret name -> value to set on the Space.
        requirements_path: Explicit requirements.txt to ship, if any.
        dry_run: When True, only print what would be uploaded.
    """
    import huggingface_hub
    from huggingface_hub import HfApi

    import daggr

    print("\n Extracting Graph from script...")
    graph = _extract_graph(script_path)

    space_name = name or _sanitize_space_name(graph.name)
    space_title = title or graph.name

    print(f" Graph name: {graph.name}")
    print(f" Space name: {space_name}")
    print(f" Space title: {space_title}")

    hf_api = HfApi()
    whoami = None
    login_needed = False

    # Creating a Space requires a token with the "write" role; detect a
    # missing or read-only token and prompt for an interactive login.
    try:
        whoami = hf_api.whoami()
        if whoami["auth"]["accessToken"]["role"] != "write":
            login_needed = True
    except Exception:
        login_needed = True

    if login_needed:
        print("\n Need 'write' access token to create a Spaces repo.")
        huggingface_hub.login(add_to_git_credential=False)
        whoami = hf_api.whoami()

    username = whoami["name"]
    namespace = org or username
    repo_id = f"{namespace}/{space_name}"

    print(f"\n Target: https://huggingface.co/spaces/{repo_id}")
    print(f" Hardware: {hardware}")
    print(f" Private: {private}")
    if secrets:
        print(f" Secrets: {list(secrets.keys())}")

    # Modules/packages imported locally by the script must ship with it.
    local_imports = find_python_imports(script_path)
    print("\n Files to upload:")
    print(f" • app.py (from {script_path.name})")
    print(" • requirements.txt")
    print(" • README.md")
    for imp in local_imports:
        if imp.is_file():
            print(f" • {imp.name}")
        else:
            print(f" • {imp.name}/ (package)")

    if dry_run:
        print("\n [Dry run] No changes made.")
        return

    with tempfile.TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)

        # Spaces expects the entry point to be named app.py.
        shutil.copy(script_path, tmpdir / "app.py")

        for imp in local_imports:
            if imp.is_file():
                shutil.copy(imp, tmpdir / imp.name)
            else:
                shutil.copytree(imp, tmpdir / imp.name)

        # requirements.txt resolution order: explicit --requirements file,
        # then a requirements.txt next to the script, else generate one.
        # In every case, ensure daggr itself is pinned as a dependency.
        # NOTE(review): the `"daggr" not in req_content` check is a plain
        # substring test — a line like "daggr-extras" would also satisfy
        # it; confirm this looseness is acceptable.
        if requirements_path:
            req_path = Path(requirements_path)
            if not req_path.exists():
                print(f"Error: Requirements file not found: {req_path}")
                sys.exit(1)
            shutil.copy(req_path, tmpdir / "requirements.txt")

            with open(tmpdir / "requirements.txt", "r") as f:
                req_content = f.read()
            if "daggr" not in req_content:
                with open(tmpdir / "requirements.txt", "a") as f:
                    f.write(f"\ndaggr>={daggr.__version__}\n")
        else:
            script_dir = script_path.parent
            existing_req = script_dir / "requirements.txt"
            if existing_req.exists():
                shutil.copy(existing_req, tmpdir / "requirements.txt")
                with open(tmpdir / "requirements.txt", "r") as f:
                    req_content = f.read()
                if "daggr" not in req_content:
                    with open(tmpdir / "requirements.txt", "a") as f:
                        f.write(f"\ndaggr>={daggr.__version__}\n")
            else:
                with open(tmpdir / "requirements.txt", "w") as f:
                    f.write(f"daggr>={daggr.__version__}\n")

        # Generated README carries the Spaces YAML front-matter (title,
        # sdk, sdk_version, app_file) that the Hub reads to build the app.
        readme_content = f"""---
title: {space_title}
emoji: 🔀
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: "{_get_gradio_version()}"
app_file: app.py
pinned: false
tags:
- daggr
---

# {space_title}

This Space was deployed using [daggr](https://github.com/gradio-app/daggr).
"""
        with open(tmpdir / "README.md", "w") as f:
            f.write(readme_content)

        print("\n Creating Space repository...")
        try:
            hf_api.create_repo(
                repo_id=repo_id,
                repo_type="space",
                space_sdk="gradio",
                space_hardware=hardware,
                private=private,
                exist_ok=True,
            )
        except Exception as e:
            print(f"Error creating repository: {e}")
            sys.exit(1)

        print(" Uploading files...")
        try:
            hf_api.upload_folder(
                repo_id=repo_id,
                repo_type="space",
                folder_path=str(tmpdir),
            )
        except Exception as e:
            print(f"Error uploading files: {e}")
            sys.exit(1)

        if secrets:
            print(" Adding secrets...")
            for secret_name, secret_value in secrets.items():
                try:
                    hf_api.add_space_secret(repo_id, secret_name, secret_value)
                except Exception as e:
                    # Non-fatal: the Space exists; the user can add the
                    # secret manually in the Space settings.
                    print(f" Warning: Could not add secret '{secret_name}': {e}")

        print(f"\n ✓ Deployed to https://huggingface.co/spaces/{repo_id}")
        print(" The Space may take a few minutes to build and start.\n")
455
+
456
+
457
+ def _get_gradio_version() -> str:
458
+ """Get the installed Gradio version."""
459
+ try:
460
+ import gradio
461
+
462
+ return gradio.__version__
463
+ except ImportError:
464
+ return "5.0.0"
465
+
466
+
467
def _delete_sheets(script_path: Path, force: bool = False):
    """Delete all cached data for the project defined in the script.

    Loads the user script to discover its Graph (without launching a
    server), then removes every sheet — and the sheet's node inputs and
    results — stored under the Graph's ``persist_key`` in the local
    sessions database. Prompts for confirmation unless *force* is set.

    Args:
        script_path: Path to the Python script defining the Graph.
        force: When True, skip the interactive y/N confirmation.
    """
    from daggr.state import get_daggr_cache_dir

    # Reuse the shared Graph-discovery helper instead of duplicating the
    # launch-capture import dance that used to live here verbatim.
    captured_graph = _extract_graph(script_path)

    persist_key = captured_graph.persist_key
    if not persist_key:
        print("Error: Graph has no persist_key (persistence is disabled)")
        sys.exit(1)

    cache_dir = get_daggr_cache_dir()
    db_path = cache_dir / "sessions.db"

    if not db_path.exists():
        print(f"No cache found for project '{persist_key}'")
        return

    conn = sqlite3.connect(str(db_path))
    cursor = conn.cursor()

    cursor.execute(
        "SELECT sheet_id FROM sheets WHERE graph_name = ?",
        (persist_key,),
    )
    sheet_ids = [row[0] for row in cursor.fetchall()]

    if not sheet_ids:
        print(f"No cached data found for project '{persist_key}'")
        conn.close()
        return

    print(f"\nProject: {persist_key}")
    print(f"This will delete {len(sheet_ids)} sheet(s) and all associated data.")
    print(f"Cache location: {cache_dir}\n")

    if not force:
        try:
            response = (
                input("Are you sure you want to continue? [y/N] ").strip().lower()
            )
        except (EOFError, KeyboardInterrupt):
            # Treat ^C / closed stdin as a "no".
            print("\nAborted.")
            conn.close()
            return

        if response not in ("y", "yes"):
            print("Aborted.")
            conn.close()
            return

    # Delete children before the parent sheet rows.
    for sheet_id in sheet_ids:
        cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
        cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
        cursor.execute("DELETE FROM sheets WHERE sheet_id = ?", (sheet_id,))

    conn.commit()
    conn.close()

    print(f"\n✓ Deleted {len(sheet_ids)} sheet(s) for project '{persist_key}'")
560
+
561
+
562
+ def _run_script(script_path: Path, host: str, port: int):
563
+ """Run the script directly without reload."""
564
+ spec = importlib.util.spec_from_file_location("__daggr_main__", script_path)
565
+ if spec is None or spec.loader is None:
566
+ print(f"Error: Could not load script: {script_path}")
567
+ sys.exit(1)
568
+
569
+ sys.path.insert(0, str(script_path.parent))
570
+
571
+ module = importlib.util.module_from_spec(spec)
572
+ sys.modules["__daggr_main__"] = module
573
+ spec.loader.exec_module(module)
574
+
575
+
576
def _run_with_reload(script_path: Path, host: str, port: int, watch_daggr: bool):
    """Run the script with uvicorn hot reload.

    Serves the app through the ``daggr.cli:_create_app`` factory so that
    every reload re-imports the user script from scratch. Watches the
    script's directory, the directories of its locally imported modules,
    and (optionally) the daggr package source itself for ``*.py`` changes.
    Also opens a browser tab shortly after startup.

    Args:
        script_path: Path to the user script defining the Graph.
        host: Interface to bind uvicorn to.
        port: Preferred port; a nearby free port is used if it's taken.
        watch_daggr: When True, also reload on daggr's own source changes.
    """
    import uvicorn

    actual_port = _find_available_port(host, port)
    if actual_port != port:
        print(f"\n Port {port} is in use, using {actual_port} instead.")

    reload_dirs = [str(script_path.parent)]

    # Watch the directories of every module the script imports locally,
    # deduplicated against what we already watch.
    local_imports = find_python_imports(script_path)
    for imp in local_imports:
        imp_dir = str(imp if imp.is_dir() else imp.parent)
        if imp_dir not in reload_dirs:
            reload_dirs.append(imp_dir)

    if watch_daggr:
        # Developer convenience: reload when daggr's own package changes.
        daggr_dir = Path(__file__).parent
        daggr_src = str(daggr_dir)
        if daggr_src not in reload_dirs:
            reload_dirs.append(daggr_src)

    reload_includes = ["*.py"]

    print("\n daggr dev server starting...")
    print(" Watching for changes in:")
    for d in reload_dirs:
        print(f" • {d}")
    print()

    # _create_app runs in uvicorn's (re)loaded worker process and reads
    # its configuration from the environment, so publish the possibly
    # adjusted port before starting the server.
    os.environ["DAGGR_PORT"] = str(actual_port)

    def open_browser():
        # Give the server a moment to come up before opening the UI.
        time.sleep(1.0)
        webbrowser.open_new_tab(f"http://{host}:{actual_port}")

    threading.Thread(target=open_browser, daemon=True).start()

    uvicorn.run(
        "daggr.cli:_create_app",
        factory=True,
        host=host,
        port=actual_port,
        reload=True,
        reload_dirs=reload_dirs,
        reload_includes=reload_includes,
        log_level="warning",
    )
624
+
625
+
626
def _create_app():
    """Factory function for uvicorn to create the FastAPI app.

    Called (re-)entrantly by uvicorn on every hot reload: it purges any
    previously imported copies of the user script from ``sys.modules``,
    re-imports the script under a fresh module name, captures its Graph
    via a temporary ``Graph.launch`` monkeypatch (falling back to a scan
    of module attributes), validates the graph's edges, and returns the
    FastAPI app from a ``DaggrServer`` wrapping it.

    Configuration (script path, host, port) arrives through the DAGGR_*
    environment variables set by the CLI before uvicorn was started.

    Raises:
        RuntimeError: If the script cannot be loaded or defines no Graph.
    """
    from daggr.graph import Graph
    from daggr.server import DaggrServer

    script_path = Path(os.environ["DAGGR_SCRIPT_PATH"])

    if str(script_path.parent) not in sys.path:
        sys.path.insert(0, str(script_path.parent))

    # Drop stale copies of the user script so the reload re-executes it.
    modules_to_remove = [m for m in sys.modules if m.startswith("__daggr_user_script_")]
    for m in modules_to_remove:
        del sys.modules[m]

    module_name = f"__daggr_user_script_{id(script_path)}__"

    spec = importlib.util.spec_from_file_location(module_name, script_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Could not load script: {script_path}")

    original_launch = Graph.launch
    captured_graph = None
    launch_kwargs = {}

    def capture_launch(self, **kwargs):
        # Record the Graph (and launch kwargs) instead of starting a server.
        nonlocal captured_graph, launch_kwargs
        captured_graph = self
        launch_kwargs = kwargs

    Graph.launch = capture_launch

    try:
        module = importlib.util.module_from_spec(spec)
        sys.modules[module_name] = module
        spec.loader.exec_module(module)
    finally:
        # Always restore the real launch, even if the user script raised.
        Graph.launch = original_launch

    if captured_graph is None:
        # Script never called launch(); look for a module-level Graph.
        for name in dir(module):
            obj = getattr(module, name)
            if isinstance(obj, Graph):
                captured_graph = obj
                break

    if captured_graph is None:
        raise RuntimeError(
            f"No Graph found in {script_path}. "
            "Make sure your script defines a Graph and calls graph.launch() "
            "or has a Graph instance at module level."
        )

    captured_graph._validate_edges()
    server = DaggrServer(captured_graph)

    base_url = f"http://{os.environ['DAGGR_HOST']}:{os.environ['DAGGR_PORT']}"
    print(f"\n UI running at: {base_url}")
    print(f" API server at: {base_url}/api\n")

    return server.app
686
+
687
+
688
# Allow running this module directly (e.g. `python -m daggr.cli`).
if __name__ == "__main__":
    main()
daggr/edge.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Edge module for connecting ports between nodes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ from daggr.port import GatheredPort, ScatteredPort
8
+
9
+ if TYPE_CHECKING:
10
+ from daggr.port import PortLike
11
+
12
+
13
class Edge:
    """Represents a connection between two ports in a graph.

    Edges connect an output port of one node to an input port of another,
    defining how data flows through the graph.

    Attributes:
        source_node: The node providing the output.
        source_port: Name of the output port.
        target_node: The node receiving the input.
        target_port: Name of the input port.
        is_scattered: True if this edge scatters a list to multiple executions.
        is_gathered: True if this edge gathers results back into a list.
        item_key: For scattered edges, the key to extract from each item.
    """

    def __init__(self, source: PortLike, target: PortLike):
        # Scatter/gather behavior is encoded in the *type* of the source
        # port wrapper, not in a flag on the node itself.
        self.is_scattered = isinstance(source, ScatteredPort)
        self.is_gathered = isinstance(source, GatheredPort)
        self.item_key: str | None = None

        if self.is_scattered:
            # Only ScatteredPort carries an item_key (dict key to pull
            # out of each scattered item); plain/gathered edges have None.
            self.item_key = source.item_key

        self.source_node = source.node
        self.source_port = source.name
        self.target_node = target.node
        self.target_port = target.name

    def __repr__(self):
        # Prefix the repr with the edge's scatter/gather mode, if any.
        prefix = ""
        if self.is_scattered:
            key_info = f"['{self.item_key}']" if self.item_key else ""
            prefix = f"scatter{key_info}:"
        elif self.is_gathered:
            prefix = "gather:"
        return (
            f"Edge({prefix}{self.source_node._name}.{self.source_port} -> "
            f"{self.target_node._name}.{self.target_port})"
        )

    def as_tuple(self) -> tuple[str, str, str, str]:
        """Return (source node name, source port, target node name, target port)."""
        return (
            self.source_node._name,
            self.source_port,
            self.target_node._name,
            self.target_port,
        )
daggr/executor.py ADDED
@@ -0,0 +1,846 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Executor for daggr graphs.
2
+
3
+ This module provides the AsyncExecutor for running graph nodes with proper
4
+ concurrency control and session isolation.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import asyncio
10
+ import base64
11
+ import hashlib
12
+ import uuid
13
+ from pathlib import Path
14
+ from typing import TYPE_CHECKING, Any
15
+ from urllib.parse import urlparse
16
+
17
+ from gradio_client.utils import is_file_obj_with_meta, traverse
18
+
19
+ from daggr.node import (
20
+ ChoiceNode,
21
+ FnNode,
22
+ GradioNode,
23
+ InferenceNode,
24
+ InteractionNode,
25
+ )
26
+ from daggr.session import ExecutionSession
27
+ from daggr.state import get_daggr_files_dir
28
+
29
+ if TYPE_CHECKING:
30
+ from daggr.graph import Graph
31
+
32
+
33
class FileValue(str):
    """Marker ``str`` subclass for file URLs/paths produced by Gradio
    output, letting downstream code tell them apart from plain text."""
37
+
38
+
39
def _download_file(url: str, hf_token: str | None = None) -> str:
    """Download *url* into the daggr files cache and return the local path.

    The cache filename is an MD5 hash of the URL (non-cryptographic use —
    just a stable cache key) plus the URL path's original extension
    (``.bin`` if none), so a repeated call for the same URL reuses the
    cached copy without re-downloading.

    Args:
        url: Remote file URL to fetch.
        hf_token: Optional Hugging Face token, sent as a Bearer header.

    Returns:
        The cached local file path, as a string.
    """
    import httpx

    parsed = urlparse(url)
    ext = Path(parsed.path).suffix or ".bin"
    url_hash = hashlib.md5(url.encode()).hexdigest()[:16]
    filename = f"{url_hash}{ext}"

    files_dir = get_daggr_files_dir()
    local_path = files_dir / filename

    if not local_path.exists():
        headers = {}
        if hf_token:
            headers["Authorization"] = f"Bearer {hf_token}"
        with httpx.Client(follow_redirects=True) as client:
            response = client.get(url, headers=headers)
            response.raise_for_status()
            local_path.write_bytes(response.content)

    return str(local_path)
60
+
61
+
62
+ def _postprocess_inference_result(task: str | None, result: Any) -> Any:
63
+ """Unwrap HF Inference Client result objects to get the actual data."""
64
+ if result is None:
65
+ return None
66
+
67
+ if task == "automatic-speech-recognition":
68
+ return getattr(result, "text", result)
69
+ elif task == "translation":
70
+ return getattr(result, "translation_text", result)
71
+ elif task == "summarization":
72
+ return getattr(result, "summary_text", result)
73
+ elif task in (
74
+ "audio-classification",
75
+ "image-classification",
76
+ "text-classification",
77
+ ):
78
+ if isinstance(result, list) and result:
79
+ return {item.label: item.score for item in result if hasattr(item, "label")}
80
+ return result
81
+ elif task == "image-to-text":
82
+ return getattr(result, "generated_text", result)
83
+ elif task == "question-answering":
84
+ if hasattr(result, "answer"):
85
+ return result.answer
86
+ return result
87
+ elif task in ("text-to-speech", "text-to-audio"):
88
+ if isinstance(result, bytes):
89
+ file_path = get_daggr_files_dir() / f"{uuid.uuid4()}.wav"
90
+ file_path.write_bytes(result)
91
+ return str(file_path)
92
+ return result
93
+ elif task in ("text-to-image", "image-to-image"):
94
+ if isinstance(result, dict):
95
+ if "images" in result:
96
+ result = result["images"][0] if result["images"] else result
97
+ elif "image" in result:
98
+ result = result["image"]
99
+ if hasattr(result, "save"):
100
+ file_path = get_daggr_files_dir() / f"{uuid.uuid4()}.png"
101
+ result.save(file_path)
102
+ return str(file_path)
103
+ return result
104
+
105
+ return result
106
+
107
+
108
def _call_inference_task(client: Any, task: str | None, inputs: dict[str, Any]) -> Any:
    """Dispatch an inference call on an HF InferenceClient for *task*.

    Selects a "primary" input from *inputs* (task-appropriate key such as
    "image" or "audio", else the first value), maps the task string to the
    matching client method, converts file-path/data-URL inputs to bytes
    for file-based tasks, calls the method (with task-specific extra
    kwargs where needed), and post-processes the raw result via
    :func:`_postprocess_inference_result`.

    Returns None when no usable input was provided.
    """
    # Pick the task-appropriate primary input key, if present.
    primary_input = None
    if task in (
        "image-to-image",
        "image-classification",
        "image-to-text",
        "object-detection",
        "image-segmentation",
        "visual-question-answering",
        "document-question-answering",
    ):
        primary_input = inputs.get("image")
    elif task in (
        "automatic-speech-recognition",
        "audio-classification",
        "audio-to-audio",
    ):
        primary_input = inputs.get("audio")

    if primary_input is None:
        # Fall back to the first provided value, whatever its key is.
        primary_input = next(iter(inputs.values()), None) if inputs else None

    if primary_input is None:
        return None

    # Hub task name -> InferenceClient method name.
    task_method_map = {
        "text-generation": "text_generation",
        "text2text-generation": "text_generation",
        "text-to-image": "text_to_image",
        "image-to-image": "image_to_image",
        "image-to-text": "image_to_text",
        "image-to-video": "image_to_video",
        "text-to-video": "text_to_video",
        "text-to-speech": "text_to_speech",
        "text-to-audio": "text_to_audio",
        "automatic-speech-recognition": "automatic_speech_recognition",
        "audio-to-audio": "audio_to_audio",
        "audio-classification": "audio_classification",
        "image-classification": "image_classification",
        "object-detection": "object_detection",
        "image-segmentation": "image_segmentation",
        "translation": "translation",
        "summarization": "summarization",
        "feature-extraction": "feature_extraction",
        "fill-mask": "fill_mask",
        "question-answering": "question_answering",
        "table-question-answering": "table_question_answering",
        "sentence-similarity": "sentence_similarity",
        "zero-shot-classification": "zero_shot_classification",
        "zero-shot-image-classification": "zero_shot_image_classification",
        "document-question-answering": "document_question_answering",
        "visual-question-answering": "visual_question_answering",
    }

    # Unknown or missing task falls back to text_generation.
    method_name = (
        task_method_map.get(task, "text_generation") if task else "text_generation"
    )
    method = getattr(client, method_name, None)

    # Tasks whose primary input is a file: string inputs are converted to
    # bytes (local path contents or decoded data: URL) before the call.
    file_input_tasks = {
        "image-to-image",
        "image-classification",
        "image-to-text",
        "object-detection",
        "image-segmentation",
        "visual-question-answering",
        "document-question-answering",
        "automatic-speech-recognition",
        "audio-classification",
        "audio-to-audio",
    }

    if task in file_input_tasks and isinstance(primary_input, str):
        primary_input = _read_file_as_bytes(primary_input)

    try:
        if method is None:
            # Client lacks the mapped method entirely; last-ditch fallback.
            result = client.text_generation(primary_input)
        elif task in ("image-to-image",):
            prompt = inputs.get("prompt", "")
            result = method(primary_input, prompt=prompt)
        elif task in ("visual-question-answering", "document-question-answering"):
            question = inputs.get("question", inputs.get("prompt", ""))
            result = method(primary_input, question=question)
        else:
            result = method(primary_input)
    except KeyError as e:
        # Some providers return payloads the client can't parse for this
        # task; surface an actionable hint about choosing a provider.
        raise RuntimeError(
            f"Provider returned unexpected response format for task '{task}'. "
            f"Missing key: {e}. This model may require a specific provider "
            f"(e.g. 'model_name:fal-ai' or 'model_name:replicate')."
        ) from e

    return _postprocess_inference_result(task, result)
202
+
203
+
204
+ def _read_file_as_bytes(file_path: str) -> bytes:
205
+ """Read a file path or data URL as bytes."""
206
+ if file_path.startswith("data:"):
207
+ try:
208
+ _, encoded = file_path.split(",", 1)
209
+ return base64.b64decode(encoded)
210
+ except Exception:
211
+ pass
212
+
213
+ path = Path(file_path)
214
+ if path.exists():
215
+ return path.read_bytes()
216
+
217
+ return file_path
218
+
219
+
220
+ class AsyncExecutor:
221
+ """Async executor for graph nodes.
222
+
223
+ This executor is stateless - all state is held in the ExecutionSession.
224
+ It handles concurrency control:
225
+ - GradioNode/InferenceNode: run concurrently (external API calls)
226
+ - FnNode: sequential by default, configurable via concurrent/concurrency_group
227
+ """
228
+
229
    def __init__(self, graph: Graph):
        # The executor itself is stateless: it only holds the graph whose
        # nodes it executes; per-run state lives in the ExecutionSession
        # passed to each method.
        self.graph = graph
231
+
232
    def _get_client_for_gradio_node(
        self, session: ExecutionSession, gradio_node, cache_key: str
    ):
        """Return (and cache) a gradio_client Client for *gradio_node*.

        Clients are cached per-session under a key that includes a hash of
        the session's HF token, so switching tokens never reuses a client
        authenticated differently. Resolution order: session cache, then a
        locally launched Space (when the node requests one), then a fresh
        authenticated client, then the process-wide anonymous cache.
        """
        from daggr import _client_cache

        # Token hash in the key keeps authenticated vs anonymous clients
        # for the same source apart within one session.
        token_cache_key = f"{cache_key}__token_{hash(session.hf_token or '')}"
        if token_cache_key in session.clients:
            return session.clients[token_cache_key]

        if gradio_node._run_locally:
            from daggr.local_space import get_local_client

            # Prefer the local Space when available; fall through to a
            # remote client otherwise.
            client = get_local_client(gradio_node)
            if client is not None:
                session.clients[token_cache_key] = client
                return client

        if session.hf_token:
            from gradio_client import Client

            # Authenticated clients stay session-scoped and are never put
            # in the shared process-wide cache.
            client = Client(
                gradio_node._src,
                download_files=False,
                verbose=False,
                token=session.hf_token,
            )
        else:
            # Anonymous clients are shared across sessions via the
            # module-level cache to avoid repeated handshakes.
            client = _client_cache.get_client(gradio_node._src)
            if client is None:
                from gradio_client import Client

                client = Client(
                    gradio_node._src,
                    download_files=False,
                    verbose=False,
                )
                _client_cache.set_client(gradio_node._src, client)

        session.clients[token_cache_key] = client
        return client
272
+
273
    def _get_client(self, session: ExecutionSession, node_name: str):
        """Resolve the Gradio client for *node_name*; None for non-Gradio nodes.

        ChoiceNodes delegate to their currently selected variant (index
        taken from the session, defaulting to 0); the variant index is
        baked into the client cache key so different variants never share
        a client.
        """
        node = self.graph.nodes[node_name]

        if isinstance(node, ChoiceNode):
            variant_idx = session.selected_variants.get(node_name, 0)
            variant = node._variants[variant_idx]
            if isinstance(variant, GradioNode):
                cache_key = f"{node_name}__variant_{variant_idx}"
                return self._get_client_for_gradio_node(session, variant, cache_key)
            # Selected variant is not a Gradio node — nothing to connect to.
            return None

        if not isinstance(node, GradioNode):
            return None

        return self._get_client_for_gradio_node(session, node, node_name)
288
+
289
+ def _get_scattered_input_edges(self, node_name: str) -> list:
290
+ scattered = []
291
+ for edge in self.graph._edges:
292
+ if edge.target_node._name == node_name and edge.is_scattered:
293
+ scattered.append(edge)
294
+ return scattered
295
+
296
+ def _get_gathered_input_edges(self, node_name: str) -> list:
297
+ gathered = []
298
+ for edge in self.graph._edges:
299
+ if edge.target_node._name == node_name and edge.is_gathered:
300
+ gathered.append(edge)
301
+ return gathered
302
+
303
    def _prepare_inputs(
        self, session: ExecutionSession, node_name: str, skip_scattered: bool = False
    ) -> dict[str, Any]:
        """Assemble input values for *node_name* from upstream results.

        Walks every edge targeting the node and maps the source node's
        stored result onto the target's input port. Three result shapes
        are handled: gathered scatter results (dicts carrying the
        ``_scattered_results`` marker), dict results keyed by output port
        name, and positional list/tuple results addressed via port names
        like ``output_<idx>``. Edges whose source has no stored result yet
        contribute nothing.

        Args:
            session: Holds per-run ``results`` keyed by node name.
            node_name: Target node whose inputs are being assembled.
            skip_scattered: When True, scattered edges are ignored (the
                scatter machinery supplies those values per item).
        """
        inputs = {}

        for edge in self.graph._edges:
            if edge.target_node._name == node_name:
                if skip_scattered and edge.is_scattered:
                    continue

                source_name = edge.source_node._name
                source_output = edge.source_port
                target_input = edge.target_port

                if source_name in session.results:
                    source_result = session.results[source_name]

                    if (
                        edge.is_gathered
                        and isinstance(source_result, dict)
                        and "_scattered_results" in source_result
                    ):
                        # Gather: pull the named output out of each
                        # per-item result; fall back to the raw item when
                        # it isn't a dict with that key.
                        scattered_results = source_result["_scattered_results"]
                        extracted = []
                        for item_result in scattered_results:
                            if (
                                isinstance(item_result, dict)
                                and source_output in item_result
                            ):
                                extracted.append(item_result[source_output])
                            else:
                                extracted.append(item_result)
                        inputs[target_input] = extracted
                    elif (
                        isinstance(source_result, dict)
                        and source_output in source_result
                    ):
                        inputs[target_input] = source_result[source_output]
                    elif isinstance(source_result, (list, tuple)):
                        # Positional outputs: "output_2" selects index 2,
                        # bare "output" maps to index 0; an unparseable
                        # port name falls back to the first element.
                        try:
                            output_idx = int(
                                source_output.replace("output_", "").replace(
                                    "output", "0"
                                )
                            )
                            if 0 <= output_idx < len(source_result):
                                inputs[target_input] = source_result[output_idx]
                        except (ValueError, TypeError):
                            if len(source_result) > 0:
                                inputs[target_input] = source_result[0]
                    else:
                        # Scalar/opaque result: pass it through as-is.
                        inputs[target_input] = source_result

        return inputs
357
+
358
    def _execute_single_node_sync(
        self, session: ExecutionSession, node_name: str, inputs: dict[str, Any]
    ) -> Any:
        """Synchronous node execution (called from thread pool for FnNode).

        Builds the effective input dict from three layers, lowest to
        highest precedence: the node's fixed inputs (callables evaluated
        lazily), current values of its bound input components, and the
        upstream *inputs* from resolved edges. Then dispatches on node
        type: GradioNode -> remote ``predict`` call, FnNode -> local
        function call, InferenceNode -> HF InferenceClient call,
        InteractionNode -> passthrough of its input. ChoiceNodes delegate
        to the currently selected variant. Returns None for unknown node
        types or when no Gradio client could be obtained.
        """
        node = self.graph.nodes[node_name]

        if isinstance(node, ChoiceNode):
            variant_idx = session.selected_variants.get(node_name, 0)
            variant = node._variants[variant_idx]
            return self._execute_variant_node_sync(session, node_name, variant, inputs)

        all_inputs = {}
        for port_name, value in node._fixed_inputs.items():
            # Fixed inputs may be callables evaluated fresh on each run.
            all_inputs[port_name] = value() if callable(value) else value
        for port_name, component in node._input_components.items():
            if hasattr(component, "value"):
                val = component.value
                if is_file_obj_with_meta(val):
                    # Unwrap Gradio file-data dicts to a bare path string.
                    val = val["path"]
                all_inputs[port_name] = val
        # Upstream edge values override fixed/component defaults.
        all_inputs.update(inputs)

        if isinstance(node, GradioNode):
            client = self._get_client(session, node_name)
            if client:
                # Normalize the endpoint name to a leading-slash api_name.
                api_name = node._api_name or "/predict"
                if not api_name.startswith("/"):
                    api_name = "/" + api_name
                # Only declared input ports are forwarded to the Space.
                call_inputs = {
                    k: self._wrap_file_input(v)
                    for k, v in all_inputs.items()
                    if k in node._input_ports
                }
                if node._preprocess:
                    call_inputs = node._preprocess(call_inputs)
                raw_result = client.predict(api_name=api_name, **call_inputs)
                if node._postprocess:
                    raw_result = self._apply_postprocess(node._postprocess, raw_result)
                result = self._map_gradio_result(
                    node, raw_result, hf_token=session.hf_token
                )
            else:
                result = None

        elif isinstance(node, FnNode):
            # Pass only the ports the function declares, in kwargs form.
            fn_kwargs = {}
            for port_name in node._input_ports:
                if port_name in all_inputs:
                    fn_kwargs[port_name] = all_inputs[port_name]
            if node._preprocess:
                fn_kwargs = node._preprocess(fn_kwargs)
            raw_result = node._fn(**fn_kwargs)
            if node._postprocess:
                raw_result = self._apply_postprocess(node._postprocess, raw_result)
            result = self._map_fn_result(node, raw_result)

        elif isinstance(node, InferenceNode):
            from huggingface_hub import InferenceClient

            if not node._task_fetched:
                # Lazily resolve the model's task/provider metadata once.
                node._fetch_model_info()
            client = InferenceClient(
                model=node._model_name_for_hub,
                provider=node._provider,
                token=session.hf_token,
            )
            inference_inputs = {
                k: v for k, v in all_inputs.items() if k in node._input_ports
            }
            if node._preprocess:
                inference_inputs = node._preprocess(inference_inputs)
            raw_result = _call_inference_task(client, node._task, inference_inputs)
            if node._postprocess:
                raw_result = self._apply_postprocess(node._postprocess, raw_result)
            result = self._map_inference_result(node, raw_result)

        elif isinstance(node, InteractionNode):
            # Interaction nodes just forward their (single) input value:
            # prefer the conventional "input" port, else the first
            # declared port, else None.
            result = all_inputs.get(
                "input",
                all_inputs.get(node._input_ports[0]) if node._input_ports else None,
            )

        else:
            result = None

        return result
444
+
445
    def _execute_variant_node_sync(
        self,
        session: ExecutionSession,
        node_name: str,
        variant,
        inputs: dict[str, Any],
    ) -> Any:
        """Synchronously execute one selected variant of a ChoiceNode.

        Mirrors the per-type dispatch of ``_execute_single_node_sync`` but
        reads configuration (fixed inputs, components, pre/postprocess hooks)
        from the *variant* rather than the node itself.

        Args:
            session: Per-run execution state (clients, tokens, results).
            node_name: Name of the ChoiceNode in the graph (used for client
                lookup so each choice node reuses its own client).
            variant: The selected variant node (GradioNode, FnNode, or
                InferenceNode).
            inputs: Values resolved from upstream edges; these override any
                fixed or component-provided inputs.

        Returns:
            The variant's mapped result dict, or None when the variant type is
            unsupported or no client could be obtained.
        """
        # Merge input sources in increasing priority: fixed values (possibly
        # lazily computed), UI component values, then upstream edge values.
        all_inputs = {}
        for port_name, value in variant._fixed_inputs.items():
            all_inputs[port_name] = value() if callable(value) else value
        for port_name, component in variant._input_components.items():
            if hasattr(component, "value"):
                val = component.value
                if is_file_obj_with_meta(val):
                    # Gradio file dicts are reduced to their local path.
                    val = val["path"]
                all_inputs[port_name] = val
        all_inputs.update(inputs)

        if isinstance(variant, GradioNode):
            client = self._get_client(session, node_name)
            if client:
                # Normalize the API endpoint name to the "/name" form.
                api_name = variant._api_name or "/predict"
                if not api_name.startswith("/"):
                    api_name = "/" + api_name
                # Only pass declared input ports; wrap local files for upload.
                call_inputs = {
                    k: self._wrap_file_input(v)
                    for k, v in all_inputs.items()
                    if k in variant._input_ports
                }
                if variant._preprocess:
                    call_inputs = variant._preprocess(call_inputs)
                raw_result = client.predict(api_name=api_name, **call_inputs)
                if variant._postprocess:
                    raw_result = self._apply_postprocess(
                        variant._postprocess, raw_result
                    )
                result = self._map_gradio_result(
                    variant, raw_result, hf_token=session.hf_token
                )
            else:
                result = None

        elif isinstance(variant, FnNode):
            # Call the user function with only its declared keyword inputs.
            fn_kwargs = {}
            for port_name in variant._input_ports:
                if port_name in all_inputs:
                    fn_kwargs[port_name] = all_inputs[port_name]
            if variant._preprocess:
                fn_kwargs = variant._preprocess(fn_kwargs)
            raw_result = variant._fn(**fn_kwargs)
            if variant._postprocess:
                raw_result = self._apply_postprocess(variant._postprocess, raw_result)
            result = self._map_fn_result(variant, raw_result)

        elif isinstance(variant, InferenceNode):
            from huggingface_hub import InferenceClient

            # Lazily resolve the model's task the first time it is needed.
            if not variant._task_fetched:
                variant._fetch_model_info()
            client = InferenceClient(
                model=variant._model_name_for_hub,
                provider=variant._provider,
                token=session.hf_token,
            )
            inference_inputs = {
                k: v for k, v in all_inputs.items() if k in variant._input_ports
            }
            if variant._preprocess:
                inference_inputs = variant._preprocess(inference_inputs)
            raw_result = _call_inference_task(client, variant._task, inference_inputs)
            if variant._postprocess:
                raw_result = self._apply_postprocess(variant._postprocess, raw_result)
            result = self._map_inference_result(variant, raw_result)

        else:
            # Unknown variant type: nothing to execute.
            result = None

        return result
523
+
524
+ async def execute_node(
525
+ self,
526
+ session: ExecutionSession,
527
+ node_name: str,
528
+ user_inputs: dict[str, Any] | None = None,
529
+ ) -> Any:
530
+ """Execute a single node with proper concurrency control."""
531
+ node = self.graph.nodes[node_name]
532
+ scattered_edges = self._get_scattered_input_edges(node_name)
533
+
534
+ if scattered_edges:
535
+ result = await self._execute_scattered_node(
536
+ session, node_name, scattered_edges, user_inputs
537
+ )
538
+ else:
539
+ inputs = self._prepare_inputs(session, node_name)
540
+ if user_inputs:
541
+ if isinstance(user_inputs, dict):
542
+ inputs.update(user_inputs)
543
+ else:
544
+ if node._input_ports:
545
+ inputs[node._input_ports[0]] = user_inputs
546
+ else:
547
+ inputs["input"] = user_inputs
548
+
549
+ try:
550
+ if isinstance(node, (GradioNode, InferenceNode)):
551
+ result = await asyncio.to_thread(
552
+ self._execute_single_node_sync, session, node_name, inputs
553
+ )
554
+ elif isinstance(node, FnNode):
555
+ semaphore = await session.concurrency.get_semaphore(
556
+ node._concurrent,
557
+ node._concurrency_group,
558
+ node._max_concurrent,
559
+ )
560
+ if semaphore:
561
+ async with semaphore:
562
+ result = await asyncio.to_thread(
563
+ self._execute_single_node_sync,
564
+ session,
565
+ node_name,
566
+ inputs,
567
+ )
568
+ else:
569
+ result = await asyncio.to_thread(
570
+ self._execute_single_node_sync, session, node_name, inputs
571
+ )
572
+ else:
573
+ result = await asyncio.to_thread(
574
+ self._execute_single_node_sync, session, node_name, inputs
575
+ )
576
+ except Exception as e:
577
+ raise RuntimeError(f"Error executing node '{node_name}': {e}")
578
+
579
+ session.results[node_name] = result
580
+ return result
581
+
582
+ async def _execute_scattered_node(
583
+ self,
584
+ session: ExecutionSession,
585
+ node_name: str,
586
+ scattered_edges: list,
587
+ user_inputs: dict[str, Any] | None = None,
588
+ ) -> dict[str, list[Any]]:
589
+ first_edge = scattered_edges[0]
590
+ source_name = first_edge.source_node._name
591
+ source_port = first_edge.source_port
592
+
593
+ source_result = session.results.get(source_name)
594
+ if source_result is None:
595
+ items = []
596
+ elif isinstance(source_result, dict) and source_port in source_result:
597
+ items = source_result[source_port]
598
+ else:
599
+ items = source_result
600
+
601
+ if not isinstance(items, list):
602
+ items = [items]
603
+
604
+ context_inputs = self._prepare_inputs(session, node_name, skip_scattered=True)
605
+ if user_inputs:
606
+ context_inputs.update(user_inputs)
607
+
608
+ node = self.graph.nodes[node_name]
609
+
610
+ async def execute_item(item, idx):
611
+ item_inputs = dict(context_inputs)
612
+ for edge in scattered_edges:
613
+ target_port = edge.target_port
614
+ item_key = edge.item_key
615
+ if item_key and isinstance(item, dict):
616
+ item_inputs[target_port] = item.get(item_key)
617
+ else:
618
+ item_inputs[target_port] = item
619
+
620
+ try:
621
+ if isinstance(node, (GradioNode, InferenceNode)):
622
+ return await asyncio.to_thread(
623
+ self._execute_single_node_sync, session, node_name, item_inputs
624
+ )
625
+ elif isinstance(node, FnNode):
626
+ semaphore = await session.concurrency.get_semaphore(
627
+ node._concurrent,
628
+ node._concurrency_group,
629
+ node._max_concurrent,
630
+ )
631
+ if semaphore:
632
+ async with semaphore:
633
+ return await asyncio.to_thread(
634
+ self._execute_single_node_sync,
635
+ session,
636
+ node_name,
637
+ item_inputs,
638
+ )
639
+ else:
640
+ return await asyncio.to_thread(
641
+ self._execute_single_node_sync,
642
+ session,
643
+ node_name,
644
+ item_inputs,
645
+ )
646
+ else:
647
+ return await asyncio.to_thread(
648
+ self._execute_single_node_sync, session, node_name, item_inputs
649
+ )
650
+ except Exception as e:
651
+ return {"error": str(e)}
652
+
653
+ if isinstance(node, (GradioNode, InferenceNode)):
654
+ tasks = [execute_item(item, i) for i, item in enumerate(items)]
655
+ results = await asyncio.gather(*tasks)
656
+ else:
657
+ results = []
658
+ for i, item in enumerate(items):
659
+ result = await execute_item(item, i)
660
+ results.append(result)
661
+
662
+ session.scattered_results[node_name] = list(results)
663
+ return {"_scattered_results": list(results), "_items": items}
664
+
665
+ def _wrap_file_input(self, value: Any) -> Any:
666
+ from gradio_client import handle_file
667
+
668
+ if isinstance(value, FileValue):
669
+ return handle_file(str(value))
670
+
671
+ if isinstance(value, str):
672
+ if value.startswith("data:"):
673
+ file_path = self._save_data_url_to_file(value)
674
+ if file_path:
675
+ return handle_file(file_path)
676
+ elif Path(value).exists():
677
+ return handle_file(value)
678
+
679
+ return value
680
+
681
+ def _save_data_url_to_file(self, data_url: str) -> str | None:
682
+ """Convert a base64 data URL to a file and return the path."""
683
+ if not data_url.startswith("data:"):
684
+ return None
685
+
686
+ try:
687
+ header, encoded = data_url.split(",", 1)
688
+ media_type = header.split(":")[1].split(";")[0]
689
+ ext_map = {
690
+ "image/png": ".png",
691
+ "image/jpeg": ".jpg",
692
+ "image/jpg": ".jpg",
693
+ "image/gif": ".gif",
694
+ "image/webp": ".webp",
695
+ "audio/wav": ".wav",
696
+ "audio/mpeg": ".mp3",
697
+ "audio/mp3": ".mp3",
698
+ "audio/ogg": ".ogg",
699
+ "audio/webm": ".webm",
700
+ "video/mp4": ".mp4",
701
+ "video/webm": ".webm",
702
+ }
703
+ ext = ext_map.get(media_type, ".bin")
704
+ data = base64.b64decode(encoded)
705
+ file_path = get_daggr_files_dir() / f"{uuid.uuid4()}{ext}"
706
+ file_path.write_bytes(data)
707
+ return str(file_path)
708
+ except Exception:
709
+ return None
710
+
711
+ def _apply_postprocess(self, postprocess, raw_result: Any) -> Any:
712
+ if isinstance(raw_result, (list, tuple)):
713
+ return postprocess(*raw_result)
714
+ return postprocess(raw_result)
715
+
716
+ def _extract_file_urls(self, data: Any, hf_token: str | None = None) -> Any:
717
+ def download_and_wrap(file_obj: dict) -> FileValue:
718
+ url = file_obj.get("url")
719
+ if url:
720
+ local_path = _download_file(url, hf_token=hf_token)
721
+ return FileValue(local_path)
722
+ path = file_obj.get("path", "")
723
+ return FileValue(path)
724
+
725
+ return traverse(data, download_and_wrap, is_file_obj_with_meta)
726
+
727
+ def _map_gradio_result(
728
+ self, node, raw_result: Any, hf_token: str | None = None
729
+ ) -> dict[str, Any]:
730
+ if raw_result is None:
731
+ return {}
732
+
733
+ raw_result = self._extract_file_urls(raw_result, hf_token=hf_token)
734
+
735
+ output_ports = node._output_ports
736
+ if not output_ports:
737
+ return {"output": raw_result}
738
+
739
+ if isinstance(raw_result, (list, tuple)):
740
+ result = {}
741
+ for i, port_name in enumerate(output_ports):
742
+ if i < len(raw_result):
743
+ result[port_name] = raw_result[i]
744
+ else:
745
+ result[port_name] = None
746
+ return result
747
+ elif len(output_ports) == 1:
748
+ return {output_ports[0]: raw_result}
749
+ else:
750
+ return {output_ports[0]: raw_result}
751
+
752
+ def _map_fn_result(self, node, raw_result: Any) -> dict[str, Any]:
753
+ if raw_result is None:
754
+ return {}
755
+
756
+ output_ports = node._output_ports
757
+ if not output_ports:
758
+ return {"output": raw_result}
759
+
760
+ if isinstance(raw_result, tuple):
761
+ result = {}
762
+ for i, port_name in enumerate(output_ports):
763
+ if i < len(raw_result):
764
+ result[port_name] = raw_result[i]
765
+ else:
766
+ result[port_name] = None
767
+ return result
768
+ else:
769
+ return {output_ports[0]: raw_result}
770
+
771
+ def _map_inference_result(self, node, raw_result: Any) -> dict[str, Any]:
772
+ """Map inference API result to output ports."""
773
+ if raw_result is None:
774
+ return {}
775
+
776
+ output_ports = node._output_ports
777
+ if not output_ports:
778
+ return {"output": raw_result}
779
+
780
+ return {output_ports[0]: raw_result}
781
+
782
+ async def execute_all(
783
+ self, session: ExecutionSession, entry_inputs: dict[str, dict[str, Any]]
784
+ ) -> dict[str, Any]:
785
+ execution_order = self.graph.get_execution_order()
786
+ session.results = {}
787
+
788
+ for node_name in execution_order:
789
+ user_input = entry_inputs.get(node_name, {})
790
+ await self.execute_node(session, node_name, user_input)
791
+
792
+ return session.results
793
+
794
+
795
class SequentialExecutor:
    """Legacy synchronous executor kept for backwards compatibility.

    Wraps an AsyncExecutor so synchronous call sites (e.g. ``node.test()``)
    can drive the graph without managing an event loop. For production use,
    prefer AsyncExecutor with explicit session management.
    """

    def __init__(self, graph: Graph, hf_token: str | None = None):
        self.graph = graph
        self._async_executor = AsyncExecutor(graph)
        self._session = ExecutionSession(graph, hf_token)

    def _run(self, coro):
        # A fresh event loop per call keeps repeated invocations independent.
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(coro)
        finally:
            loop.close()

    @property
    def results(self) -> dict[str, Any]:
        return self._session.results

    @results.setter
    def results(self, value: dict[str, Any]):
        self._session.results = value

    @property
    def selected_variants(self) -> dict[str, int]:
        return self._session.selected_variants

    @selected_variants.setter
    def selected_variants(self, value: dict[str, int]):
        self._session.selected_variants = value

    def set_hf_token(self, token: str | None):
        self._session.set_hf_token(token)

    def execute_node(
        self, node_name: str, user_inputs: dict[str, Any] | None = None
    ) -> Any:
        """Blocking wrapper around ``AsyncExecutor.execute_node``."""
        return self._run(
            self._async_executor.execute_node(self._session, node_name, user_inputs)
        )

    def execute_all(self, entry_inputs: dict[str, dict[str, Any]]) -> dict[str, Any]:
        """Blocking wrapper around ``AsyncExecutor.execute_all``."""
        return self._run(
            self._async_executor.execute_all(self._session, entry_inputs)
        )
daggr/graph.py ADDED
@@ -0,0 +1,767 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graph module for daggr.
2
+
3
+ A Graph represents a directed acyclic graph (DAG) of nodes that can be
4
+ executed to process data through a pipeline.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import itertools
10
+ import os
11
+ import re
12
+ import sys
13
+ import threading
14
+ from collections.abc import Sequence
15
+ from typing import TYPE_CHECKING, Any
16
+
17
+ import networkx as nx
18
+
19
+ from daggr._utils import suggest_similar
20
+ from daggr.edge import Edge
21
+ from daggr.local_space import prepare_local_node
22
+ from daggr.node import ChoiceNode, GradioNode, InferenceNode, Node
23
+ from daggr.port import Port
24
+
25
+ if TYPE_CHECKING:
26
+ from gradio.themes import ThemeClass as Theme
27
+
28
+
29
+ def _parse_space_id(src: str) -> str | None:
30
+ if src.startswith("http://") or src.startswith("https://"):
31
+ match = re.match(r"https?://huggingface\.co/spaces/([^/]+/[^/?#]+)", src)
32
+ if match:
33
+ return match.group(1)
34
+ return None
35
+ if "/" in src:
36
+ return src
37
+ return None
38
+
39
+
40
def _get_dependency_id(node) -> tuple[str | None, str]:
    """Return the (hub id, kind) pair tracked for a node's upstream dependency.

    GradioNodes resolve to a Space id, InferenceNodes to a model id; every
    other node type yields ``(None, "")``.
    """
    if isinstance(node, GradioNode):
        return _parse_space_id(node._src), "space"
    if isinstance(node, InferenceNode):
        return node._model_name_for_hub, "model"
    return None, ""
47
+
48
+
49
def _fetch_current_sha(dep_id: str, dep_type: str) -> str | None:
    """Fetch the current git SHA of a Hub Space or model.

    Returns None for unknown dependency types and on any network/auth
    failure — the dependency check is advisory, never fatal.
    """
    try:
        if dep_type == "space":
            from huggingface_hub import space_info

            return space_info(dep_id).sha
        if dep_type == "model":
            from huggingface_hub import model_info

            return model_info(dep_id).sha
    except Exception:
        return None
    return None
64
+
65
+
66
def _duplicate_space_at_revision(
    space_id: str, revision: str, username: str
) -> str | None:
    """Duplicate a Space at a pinned revision under the user's namespace.

    Downloads a snapshot of *space_id* at *revision*, creates (or reuses) a
    Space named after it under *username*, and uploads the snapshot there.

    Args:
        space_id: Source Space in ``owner/name`` form.
        revision: Git SHA to snapshot.
        username: Hub namespace to duplicate into.

    Returns:
        The new ``username/name`` repo id, or None on any failure (e.g. the
        revision was squashed away or the user lacks permissions).
    """
    try:
        from huggingface_hub import (
            create_repo,
            snapshot_download,
            upload_folder,
        )

        space_name = space_id.split("/")[-1]
        new_repo_id = f"{username}/{space_name}"

        # Materialize the pinned revision locally first.
        local_dir = snapshot_download(
            repo_id=space_id,
            repo_type="space",
            revision=revision,
        )

        # exist_ok=True lets repeated duplications reuse the same repo.
        create_repo(
            repo_id=new_repo_id,
            repo_type="space",
            space_sdk="gradio",
            exist_ok=True,
        )

        upload_folder(
            repo_id=new_repo_id,
            repo_type="space",
            folder_path=local_dir,
        )

        return new_repo_id
    except Exception as e:
        print(f" [daggr] Failed to duplicate Space: {e}")
        return None
102
+
103
+
104
def _prompt_dependency_changes(changed: list[dict]) -> None:
    """Interactively resolve upstream dependency (Space/model) changes.

    For each changed dependency the user chooses to either duplicate the
    pinned original Space under their own namespace or accept the latest
    version by updating the cached hash. In non-interactive contexts all
    hashes are auto-updated.

    Args:
        changed: Items with keys ``type``, ``id``, ``node``, ``cached_sha``,
            ``current_sha`` (as built by the dependency check).
    """
    from daggr import _client_cache

    is_tty = hasattr(sys.stdin, "isatty") and sys.stdin.isatty()

    print("\n ⚠️ Upstream dependency changes detected:\n")
    for item in changed:
        print(
            f" • {item['type']} '{item['id']}' (node: {item['node']._name})\n"
            f" cached: {item['cached_sha'][:12]}\n"
            f" current: {item['current_sha'][:12]}"
        )
    print()

    if not is_tty:
        # No terminal to prompt on: accept the new versions silently.
        for item in changed:
            _client_cache.set_dependency_hash(item["id"], item["current_sha"])
        print(
            " [daggr] Non-interactive mode: auto-updated all hashes.\n"
            " Set DAGGR_DEPENDENCY_CHECK=skip to suppress this warning.\n"
        )
        return

    for item in changed:
        is_space = item["type"] == "space"
        # Only Spaces can be duplicated; models can only be accepted.
        if is_space:
            print(
                f" How would you like to handle '{item['id']}'?\n"
                f" [1] Duplicate the original version under your namespace (safer)\n"
                f" [2] Update to the latest version"
            )
        else:
            print(
                f" How would you like to handle '{item['id']}'?\n"
                f" [1] Update to the latest version"
            )

        try:
            choice = input(" Choice [1]: ").strip() or "1"
        except (EOFError, KeyboardInterrupt):
            # Treat an aborted prompt as the default choice.
            choice = "1"

        if is_space and choice == "1":
            username = _get_hf_username()
            if username is None:
                # Duplication needs an authenticated user; fall back to
                # accepting the new revision.
                print(
                    " [daggr] Not logged in to Hugging Face. "
                    "Updating hash instead.\n"
                    " Run `huggingface-cli login` to enable Space duplication."
                )
                _client_cache.set_dependency_hash(item["id"], item["current_sha"])
            else:
                print(
                    f" [daggr] Duplicating '{item['id']}' at revision "
                    f"{item['cached_sha'][:12]} under {username}/..."
                )
                new_id = _duplicate_space_at_revision(
                    item["id"], item["cached_sha"], username
                )
                if new_id:
                    # Repoint the node at the duplicate and pin its hash.
                    item["node"]._src = new_id
                    _client_cache.set_dependency_hash(new_id, item["cached_sha"])
                    print(
                        f" [daggr] Duplicated → '{new_id}'. "
                        f"Node now points to duplicated Space."
                    )
                else:
                    print(
                        " [daggr] Duplication failed (revision may have been "
                        "squashed). Updating hash instead."
                    )
                    _client_cache.set_dependency_hash(item["id"], item["current_sha"])
        else:
            # Accept the latest version of the dependency.
            _client_cache.set_dependency_hash(item["id"], item["current_sha"])
            print(f" [daggr] Updated hash for '{item['id']}'.")

    print()
181
+
182
+
183
def _get_hf_username() -> str | None:
    """Return the logged-in Hugging Face username, or None if unavailable.

    Returns None when no token is configured or the whoami call fails.
    """
    try:
        from huggingface_hub import get_token, whoami

        token = get_token()
        if not token:
            return None
        # BUG FIX: whoami() accepts a `token` argument, not `cache`;
        # whoami(cache=True) raised TypeError, which the broad except
        # swallowed, so this function always returned None.
        info = whoami(token=token)
        return info.get("name")
    except Exception:
        # Network/auth failures are non-fatal; callers treat None as
        # "not logged in".
        return None
194
+
195
+
196
class _Spinner:
    """Minimal terminal spinner shown while a startup check runs.

    On a TTY a daemon thread animates Braille frames in place; on non-TTY
    output nothing is animated and only the final status line is printed.
    """

    # Braille frames cycled to animate the spinner.
    _CHARS = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"

    def __init__(self, message: str):
        """Start spinning *message* immediately (only when stdout is a TTY)."""
        self._message = message
        self._is_tty = hasattr(sys.stdout, "isatty") and sys.stdout.isatty()
        if self._is_tty:
            self._stop = threading.Event()
            # Daemon thread so an abandoned spinner never blocks interpreter exit.
            self._thread = threading.Thread(target=self._spin, daemon=True)
            self._thread.start()

    def _spin(self):
        # Animation loop: redraw the same line (\r) roughly every 80 ms.
        frames = itertools.cycle(self._CHARS)
        while not self._stop.is_set():
            sys.stdout.write(f"\r {next(frames)} {self._message}")
            sys.stdout.flush()
            self._stop.wait(0.08)

    def _finish(self, symbol: str, suffix: str = ""):
        # Stop the animation (if any) and print the final status line.
        line = f" {symbol} {self._message}"
        if suffix:
            line += f" — {suffix}"
        if self._is_tty:
            self._stop.set()
            self._thread.join()
            # \033[K clears any leftover spinner characters on the line.
            sys.stdout.write(f"\r{line}\033[K\n")
        else:
            sys.stdout.write(f"{line}\n")
        sys.stdout.flush()

    def succeed(self, suffix: str = ""):
        """Finish with a success mark and optional detail suffix."""
        self._finish("✓", suffix)

    def warn(self, suffix: str = ""):
        """Finish with a warning mark and optional detail suffix."""
        self._finish("⚠", suffix)
231
+
232
+
233
def _get_node_display_label(node) -> str:
    """Human-readable label for a node in startup output."""
    if isinstance(node, GradioNode):
        suffix = f" ({node._api_name})" if node._api_name else ""
        return f"{node._src}{suffix}"
    if isinstance(node, InferenceNode):
        return node._model_name_for_hub
    return node._name
242
+
243
+
244
+ class Graph:
245
+ """A directed acyclic graph (DAG) of nodes for data processing.
246
+
247
+ A Graph connects nodes together to form a pipeline. Data flows from entry
248
+ nodes (nodes with no inputs) through the graph to output nodes.
249
+
250
+ Example:
251
+ >>> from daggr import Graph, FnNode
252
+ >>> def step1(x): return {"out": x * 2}
253
+ >>> def step2(y): return {"out": y + 1}
254
+ >>> n1 = FnNode(step1)
255
+ >>> n2 = FnNode(step2, inputs={"y": n1.out})
256
+ >>> graph = Graph("My Pipeline", nodes=[n2])
257
+ >>> graph.launch()
258
+ """
259
+
260
+ def __init__(
261
+ self,
262
+ name: str,
263
+ nodes: Sequence[Node] | None = None,
264
+ persist_key: str | bool | None = None,
265
+ ):
266
+ """Create a new Graph.
267
+
268
+ Args:
269
+ name: Display name for this graph shown in the UI.
270
+ nodes: Optional list of nodes to add to the graph.
271
+ persist_key: Unique key used to store this graph's data in the database.
272
+ If not provided, derived from name by converting to lowercase
273
+ and replacing spaces/special chars with underscores.
274
+ Set to False to disable persistence entirely.
275
+ Use a custom string to ensure persistence works correctly
276
+ if you change the display name later.
277
+ """
278
+ if not name or not isinstance(name, str):
279
+ raise ValueError(
280
+ "Graph requires a 'name' parameter. "
281
+ "Example: Graph(name='My Podcast Generator', nodes=[...])"
282
+ )
283
+ self.name = name
284
+ if persist_key is False:
285
+ self.persist_key = None
286
+ elif persist_key:
287
+ self.persist_key = persist_key
288
+ else:
289
+ self.persist_key = re.sub(r"[^a-z0-9]+", "_", name.lower()).strip("_")
290
+ self.nodes: dict[str, Node] = {}
291
+ self._nx_graph = nx.DiGraph()
292
+ self._edges: list[Edge] = []
293
+
294
+ if nodes:
295
+ for node in nodes:
296
+ self.add(node)
297
+
298
+ def add(self, node: Node) -> Graph:
299
+ """Add a node to the graph.
300
+
301
+ Also adds any upstream nodes connected via the node's port connections.
302
+
303
+ Args:
304
+ node: The node to add.
305
+
306
+ Returns:
307
+ self, for method chaining.
308
+ """
309
+ self._add_node(node)
310
+ self._create_edges_from_port_connections(node)
311
+ return self
312
+
313
+ def edge(self, source: Port, target: Port) -> Graph:
314
+ """Create an edge connecting two ports.
315
+
316
+ Args:
317
+ source: The source port (output of a node).
318
+ target: The target port (input of a node).
319
+
320
+ Returns:
321
+ self, for method chaining.
322
+
323
+ Raises:
324
+ ValueError: If the edge would create a cycle.
325
+ """
326
+ edge = Edge(source, target)
327
+ self._add_edge(edge)
328
+ return self
329
+
330
+ def _add_node(self, node: Node) -> None:
331
+ if node._name in self.nodes:
332
+ if self.nodes[node._name] is not node:
333
+ raise ValueError(f"Node with name '{node._name}' already exists")
334
+ return
335
+ self.nodes[node._name] = node
336
+ self._nx_graph.add_node(node._name)
337
+
338
    def _create_edges_from_port_connections(self, node: Node) -> None:
        """Materialize edges declared via a node's ``inputs={...}`` wiring.

        For each (target input port → source output port) pair declared on
        *node*, validates the source port exists, pulls the source node (and,
        recursively, its own upstream nodes) into the graph, and records the
        edge.

        Raises:
            ValueError: If a referenced output port does not exist on the
                source node (with a did-you-mean suggestion when close).
        """
        for target_port_name, source_port in node._port_connections.items():
            source_node = source_port.node
            source_port_name = source_port.name

            if source_port_name not in source_node._output_ports:
                available = set(source_node._output_ports)
                suggestion = suggest_similar(source_port_name, available)
                available_str = ", ".join(available) or "(none)"
                msg = (
                    f"Output port '{source_port_name}' not found on node "
                    f"'{source_node._name}'. Available outputs: {available_str}"
                )
                if suggestion:
                    msg += f" Did you mean '{suggestion}'?"
                raise ValueError(msg)

            # Recurse only for nodes not seen before, so shared upstream
            # nodes are processed once and cycles of re-adding are avoided.
            is_new_node = source_node._name not in self.nodes
            self._add_node(source_node)
            if is_new_node:
                self._create_edges_from_port_connections(source_node)
            target_port = Port(node, target_port_name)
            edge = Edge(source_port, target_port)
            self._add_edge(edge)
362
+
363
    def _add_edge(self, edge: Edge) -> None:
        """Add an edge, rolling back if it would make the graph cyclic.

        Raises:
            ValueError: If inserting the edge creates a cycle; the edge is
                removed again so the graph is left unchanged.
        """
        # Endpoints may not have been registered yet (idempotent otherwise).
        self._add_node(edge.source_node)
        self._add_node(edge.target_node)

        self._edges.append(edge)
        self._nx_graph.add_edge(edge.source_node._name, edge.target_node._name)

        # Validate acyclicity after insertion; undo both bookkeeping entries
        # on failure.
        if not nx.is_directed_acyclic_graph(self._nx_graph):
            self._nx_graph.remove_edge(edge.source_node._name, edge.target_node._name)
            self._edges.pop()
            raise ValueError("Connection would create a cycle in the DAG")
374
+
375
+ def get_entry_nodes(self) -> list[Node]:
376
+ """Get all nodes with no incoming edges (entry points of the graph)."""
377
+ entry_nodes = []
378
+ for node_name in self.nodes:
379
+ if self._nx_graph.in_degree(node_name) == 0:
380
+ entry_nodes.append(self.nodes[node_name])
381
+ return entry_nodes
382
+
383
+ def get_execution_order(self) -> list[str]:
384
+ """Get the topologically sorted order of node names for execution."""
385
+ return list(nx.topological_sort(self._nx_graph))
386
+
387
+ def get_connections(self) -> list[tuple]:
388
+ """Get all edges as tuples of (source_node, source_port, target_node, target_port)."""
389
+ return [edge.as_tuple() for edge in self._edges]
390
+
391
+ def _validate_edges(self) -> None:
392
+ errors = []
393
+ for edge in self._edges:
394
+ source_node = edge.source_node
395
+ target_node = edge.target_node
396
+ source_port = edge.source_port
397
+ target_port = edge.target_port
398
+
399
+ if source_port not in source_node._output_ports:
400
+ available = set(source_node._output_ports)
401
+ available_str = ", ".join(available) or "(none)"
402
+ suggestion = suggest_similar(source_port, available)
403
+ msg = (
404
+ f"Output port '{source_port}' not found on node "
405
+ f"'{source_node._name}'. Available outputs: {available_str}"
406
+ )
407
+ if suggestion:
408
+ msg += f" Did you mean '{suggestion}'?"
409
+ errors.append(msg)
410
+
411
+ if target_port not in target_node._input_ports:
412
+ available = set(target_node._input_ports)
413
+ available_str = ", ".join(available) or "(none)"
414
+ suggestion = suggest_similar(target_port, available)
415
+ msg = (
416
+ f"Input port '{target_port}' not found on node "
417
+ f"'{target_node._name}'. Available inputs: {available_str}"
418
+ )
419
+ if suggestion:
420
+ msg += f" Did you mean '{suggestion}'?"
421
+ errors.append(msg)
422
+
423
+ if errors:
424
+ raise ValueError("Invalid port connections:\n - " + "\n - ".join(errors))
425
+
426
+ def launch(
427
+ self,
428
+ host: str | None = None,
429
+ port: int | None = None,
430
+ share: bool | None = None,
431
+ open_browser: bool = True,
432
+ theme: Theme | str | None = None,
433
+ api_server: bool = True,
434
+ **kwargs,
435
+ ):
436
+ """Launch the graph as an interactive web application.
437
+
438
+ Starts a web server that displays the graph and allows users to
439
+ execute nodes and view results.
440
+
441
+ Args:
442
+ host: Host to bind to. Defaults to GRADIO_SERVER_NAME env var,
443
+ or "127.0.0.1" if not set. Set to "0.0.0.0" to make
444
+ accessible on a network or when deploying to Hugging Face Spaces.
445
+ port: Port to bind to. Defaults to GRADIO_SERVER_PORT env var,
446
+ or 7860 if not set.
447
+ share: If True, create a public share link. Defaults to True in
448
+ Colab/Kaggle environments, False otherwise.
449
+ open_browser: If True, automatically open the app in the default
450
+ web browser. Defaults to True.
451
+ theme: A Gradio theme to use for styling. Can be a Gradio `Theme` instance,
452
+ a string name like "default", "soft", "monochrome", "glass",
453
+ or a Hub theme like "gradio/seafoam". Defaults to the Gradio
454
+ default theme.
455
+ api_server: If True, expose the programmatic API endpoints
456
+ (/api/call, /api/schema). Defaults to True.
457
+ **kwargs: Additional arguments passed to uvicorn.
458
+ """
459
+ from daggr.server import DaggrServer
460
+
461
+ if host is None:
462
+ host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
463
+ if port is None:
464
+ port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
465
+
466
+ self._startup_display()
467
+ server = DaggrServer(self, theme=theme, api_server=api_server)
468
+ server.run(
469
+ host=host, port=port, share=share, open_browser=open_browser, **kwargs
470
+ )
471
+
472
+ def _prepare_local_nodes(self) -> None:
473
+ for node in self.nodes.values():
474
+ if isinstance(node, ChoiceNode):
475
+ for variant in node._variants:
476
+ if isinstance(variant, GradioNode) and variant._run_locally:
477
+ prepare_local_node(variant)
478
+ elif isinstance(node, GradioNode) and node._run_locally:
479
+ prepare_local_node(node)
480
+
481
    def _check_dependency_hashes(self) -> None:
        """Detect upstream Space/model revision changes since the last run.

        Behavior is controlled by the DAGGR_DEPENDENCY_CHECK env var:
        ``skip`` disables the check, ``update`` silently accepts all new
        revisions, ``error`` raises on any change, and anything else prompts
        the user interactively.

        Raises:
            RuntimeError: In ``error`` mode, when any dependency changed.
        """
        mode = os.environ.get("DAGGR_DEPENDENCY_CHECK", "").lower()
        if mode == "skip":
            return

        from daggr import _client_cache

        # Collect every hub-backed node, unpacking ChoiceNode variants.
        nodes_to_check: list[GradioNode | InferenceNode] = []
        for node in self.nodes.values():
            if isinstance(node, ChoiceNode):
                for variant in node._variants:
                    if isinstance(variant, (GradioNode, InferenceNode)):
                        nodes_to_check.append(variant)
            elif isinstance(node, (GradioNode, InferenceNode)):
                nodes_to_check.append(node)

        if not nodes_to_check:
            return

        # Compare each dependency's current SHA against the cached one.
        changed: list[dict[str, Any]] = []
        for node in nodes_to_check:
            dep_id, dep_type = _get_dependency_id(node)
            if dep_id is None:
                continue

            current_sha = _fetch_current_sha(dep_id, dep_type)
            if current_sha is None:
                # Unreachable dependency: skip silently (check is advisory).
                continue

            cached_sha = _client_cache.get_dependency_hash(dep_id)
            if cached_sha is None:
                # First sighting: record the current revision as the baseline.
                _client_cache.set_dependency_hash(dep_id, current_sha)
            elif cached_sha != current_sha:
                changed.append(
                    {
                        "type": dep_type,
                        "id": dep_id,
                        "node": node,
                        "cached_sha": cached_sha,
                        "current_sha": current_sha,
                    }
                )

        if not changed:
            return

        if mode == "update":
            # Auto-accept every new revision without prompting.
            for item in changed:
                _client_cache.set_dependency_hash(item["id"], item["current_sha"])
                print(
                    f" [daggr] Auto-updated hash for {item['type']} "
                    f"'{item['id']}' → {item['current_sha'][:12]}"
                )
            return

        if mode == "error":
            descs = [
                f" • {item['type']} '{item['id']}': "
                f"{item['cached_sha'][:12]} → {item['current_sha'][:12]}"
                for item in changed
            ]
            raise RuntimeError(
                "Upstream dependencies have changed:\n"
                + "\n".join(descs)
                + "\nSet DAGGR_DEPENDENCY_CHECK=update to accept changes."
            )

        # Default: ask the user what to do about each change.
        _prompt_dependency_changes(changed)
549
+
550
def _startup_display(self) -> None:
    """Print the startup summary, prepare local nodes, and verify hashes.

    For each node this prints a status line (spinner for remote nodes,
    plain checkmark otherwise), launches locally-flagged GradioNodes, and
    performs the same dependency-hash check as _check_dependency_hashes,
    reporting per-node results inline. Hash handling at the end follows
    the DAGGR_DEPENDENCY_CHECK modes: skip / update / error / prompt.
    """
    mode = os.environ.get("DAGGR_DEPENDENCY_CHECK", "").lower()
    skip_hashes = mode == "skip"

    node_count = len(self.nodes)
    noun = "node" if node_count == 1 else "nodes"
    print(f"\n Launching Daggr ({self.name}) with {node_count} {noun}:\n")

    # Imported lazily to avoid import cycles at module load time.
    from daggr import _client_cache

    changed: list[dict[str, Any]] = []

    def _check_hash(node):
        # Returns a (status, message) tuple for display, or None when the
        # node has no checkable dependency / the remote SHA is unavailable.
        # Side effect: appends to `changed` when a mismatch is found.
        dep_id, dep_type = _get_dependency_id(node)
        if dep_id is None:
            return None

        current_sha = _fetch_current_sha(dep_id, dep_type)
        if current_sha is None:
            return None

        cached_sha = _client_cache.get_dependency_hash(dep_id)
        if cached_sha is None:
            _client_cache.set_dependency_hash(dep_id, current_sha)
            return ("recorded", f"hash {current_sha[:7]} recorded")
        elif cached_sha == current_sha:
            return ("matches", f"hash {current_sha[:7]} matches")
        else:
            changed.append(
                {
                    "type": dep_type,
                    "id": dep_id,
                    "node": node,
                    "cached_sha": cached_sha,
                    "current_sha": current_sha,
                }
            )
            return ("changed", "hash changed")

    for node in self.nodes.values():
        if isinstance(node, ChoiceNode):
            # One spinner covers all variants; the worst status wins.
            spinner = _Spinner(node._name)
            for variant in node._variants:
                if isinstance(variant, GradioNode) and variant._run_locally:
                    prepare_local_node(variant)
            results = []
            if not skip_hashes:
                for variant in node._variants:
                    if isinstance(variant, (GradioNode, InferenceNode)):
                        result = _check_hash(variant)
                        if result:
                            results.append(result)
            if any(r[0] == "changed" for r in results):
                spinner.warn("hash changed")
            elif results:
                spinner.succeed(results[-1][1])
            else:
                spinner.succeed()
            continue

        if isinstance(node, GradioNode) and node._run_locally:
            prepare_local_node(node)

        label = _get_node_display_label(node)

        if isinstance(node, (GradioNode, InferenceNode)) and not skip_hashes:
            spinner = _Spinner(label)
            result = _check_hash(node)
            if result and result[0] == "changed":
                spinner.warn(result[1])
            elif result:
                spinner.succeed(result[1])
            else:
                spinner.succeed()
        else:
            # Local / function nodes: no hash to check, plain checkmark.
            sys.stdout.write(f" ✓ {label}\n")
            sys.stdout.flush()

    print()

    if not changed:
        return

    if mode == "update":
        for item in changed:
            _client_cache.set_dependency_hash(item["id"], item["current_sha"])
            print(
                f" [daggr] Auto-updated hash for {item['type']} "
                f"'{item['id']}' → {item['current_sha'][:12]}"
            )
        return

    if mode == "error":
        descs = [
            f" • {item['type']} '{item['id']}': "
            f"{item['cached_sha'][:12]} → {item['current_sha'][:12]}"
            for item in changed
        ]
        raise RuntimeError(
            "Upstream dependencies have changed:\n"
            + "\n".join(descs)
            + "\nSet DAGGR_DEPENDENCY_CHECK=update to accept changes."
        )

    # Default mode: ask the user what to do.
    _prompt_dependency_changes(changed)
655
+
656
def get_subgraphs(self) -> list[set[str]]:
    """Return the weakly connected components of the graph.

    Each element is a set of node names that form one connected subgraph.
    A fully connected graph yields a single set with every node name.
    """
    components = nx.weakly_connected_components(self._nx_graph)
    return [set(component) for component in components]
664
+
665
def get_output_nodes(self) -> list[str]:
    """Return names of all leaf nodes (nodes with no outgoing edges)."""
    leaves: list[str] = []
    for name in self.nodes:
        if self._nx_graph.out_degree(name) == 0:
            leaves.append(name)
    return leaves
672
+
673
def get_api_schema(self) -> dict:
    """Get the API schema describing inputs and outputs for each subgraph.

    Returns a dict with:
    - subgraphs: list of subgraph info, each containing:
        - id: subgraph identifier (e.g., "main" or "subgraph_0")
        - inputs: list of {node, port, type, id} for each input
        - outputs: list of {node, port, type} for each output

    Only nodes with Gradio input components contribute inputs; only leaf
    nodes (see get_output_nodes) with output components contribute
    outputs. ChoiceNode wrappers are skipped entirely.
    """
    subgraphs = self.get_subgraphs()
    output_nodes = set(self.get_output_nodes())
    result = {"subgraphs": []}

    for idx, subgraph_nodes in enumerate(subgraphs):
        # A single connected graph is simply called "main".
        subgraph_id = "main" if len(subgraphs) == 1 else f"subgraph_{idx}"

        inputs = []
        outputs = []

        for node_name in subgraph_nodes:
            node = self.nodes[node_name]

            if isinstance(node, ChoiceNode):
                continue

            if node._input_components:
                for port_name, comp in node._input_components.items():
                    comp_type = self._get_component_type(comp)
                    inputs.append(
                        {
                            "node": node_name,
                            "port": port_name,
                            "type": comp_type,
                            # Identifier safe for use as an API field name:
                            # spaces and dashes become underscores.
                            "id": f"{node_name}__{port_name}".replace(
                                " ", "_"
                            ).replace("-", "_"),
                        }
                    )

            if node_name in output_nodes and node._output_components:
                for port_name, comp in node._output_components.items():
                    if comp is None:
                        continue
                    comp_type = self._get_component_type(comp)
                    outputs.append(
                        {
                            "node": node_name,
                            "port": port_name,
                            "type": comp_type,
                        }
                    )

        result["subgraphs"].append(
            {
                "id": subgraph_id,
                "inputs": inputs,
                "outputs": outputs,
            }
        )

    return result
734
+
735
+ def _get_component_type(self, component) -> str:
736
+ """Get the type string for a Gradio component."""
737
+ class_name = component.__class__.__name__
738
+ type_map = {
739
+ "Audio": "audio",
740
+ "Textbox": "textbox",
741
+ "TextArea": "textarea",
742
+ "JSON": "json",
743
+ "Chatbot": "json",
744
+ "Image": "image",
745
+ "Number": "number",
746
+ "Markdown": "markdown",
747
+ "Text": "text",
748
+ "Dropdown": "dropdown",
749
+ "Video": "video",
750
+ "File": "file",
751
+ "Model3D": "model3d",
752
+ "Gallery": "gallery",
753
+ "Slider": "slider",
754
+ "Radio": "radio",
755
+ "Checkbox": "checkbox",
756
+ "CheckboxGroup": "checkboxgroup",
757
+ "ColorPicker": "colorpicker",
758
+ "Label": "label",
759
+ "HighlightedText": "highlightedtext",
760
+ "Code": "code",
761
+ "HTML": "html",
762
+ "Dataframe": "dataframe",
763
+ }
764
+ return type_map.get(class_name, "text")
765
+
766
+ def __repr__(self):
767
+ return f"Graph(name={self.name}, nodes={len(self.nodes)}, edges={len(self._edges)})"
daggr/local_space.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import atexit
4
+ import hashlib
5
+ import json
6
+ import os
7
+ import re
8
+ import select
9
+ import shutil
10
+ import socket
11
+ import subprocess
12
+ import sys
13
+ import time
14
+ import urllib.error
15
+ import urllib.request
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+ from typing import TYPE_CHECKING, Any
19
+
20
+ if TYPE_CHECKING:
21
+ from daggr.node import GradioNode
22
+
23
+ from daggr.state import get_daggr_cache_dir
24
+
25
+
26
def _get_spaces_cache_dir() -> Path:
    """Directory under the daggr cache where Space repos are cloned."""
    cache_root = get_daggr_cache_dir()
    return cache_root / "spaces"
28
+
29
+
30
def _get_logs_dir() -> Path:
    """Directory under the daggr cache where setup/launch logs are written."""
    cache_root = get_daggr_cache_dir()
    return cache_root / "logs"
32
+
33
+
34
# Registry of Space subprocesses launched this session, keyed by Space ID.
# Used to reuse an already-running app and torn down at interpreter exit.
_running_processes: dict[str, subprocess.Popen] = {}
35
+
36
+
37
def _get_space_dir(space_id: str) -> Path:
    """Local cache directory for a given Space ID.

    'owner/name' IDs nest as spaces/owner/name; any other shape is
    flattened with '/' replaced by '_'.
    """
    base = _get_spaces_cache_dir()
    segments = space_id.split("/")
    if len(segments) != 2:
        return base / space_id.replace("/", "_")
    owner, name = segments
    return base / owner / name
44
+
45
+
46
+ def _get_metadata_path(space_dir: Path) -> Path:
47
+ return space_dir / ".daggr_metadata.json"
48
+
49
+
50
+ def _hash_file(file_path: Path) -> str:
51
+ if not file_path.exists():
52
+ return ""
53
+ return hashlib.sha256(file_path.read_bytes()).hexdigest()[:16]
54
+
55
+
56
+ def _find_free_port(start: int = 7861, end: int = 7960) -> int:
57
+ for port in range(start, end):
58
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
59
+ try:
60
+ s.bind(("127.0.0.1", port))
61
+ return port
62
+ except OSError:
63
+ continue
64
+ raise RuntimeError(f"No free ports available in range {start}-{end}")
65
+
66
+
67
+ def _is_space_id(src: str) -> bool:
68
+ if src.startswith("http://") or src.startswith("https://"):
69
+ return False
70
+ return "/" in src and not src.startswith("/")
71
+
72
+
73
class LocalSpaceManager:
    """Clones a Hugging Face Space, builds it a venv, and runs it locally.

    Per-Space state lives under the daggr cache: the cloned repo, a private
    virtualenv, and a .daggr_metadata.json tracking the requirements hash
    and last known local URL. Failures are written to the shared logs dir.
    """

    def __init__(self, node: GradioNode):
        self.node = node
        # node._src is expected to be an 'owner/space' Hub ID (validated
        # in ensure_ready).
        self.space_id = node._src
        self.space_dir = _get_space_dir(self.space_id)
        self.repo_dir = self.space_dir / "repo"
        self.venv_dir = self.space_dir / ".venv"
        self.metadata_path = _get_metadata_path(self.space_dir)
        # Populated by _launch_app.
        self.process: subprocess.Popen | None = None
        self.local_url: str | None = None

    def ensure_ready(self) -> str:
        """Clone, install, launch; return the local URL.

        Raises ValueError for non-Space sources; any setup failure is
        logged to the error log before being re-raised.
        """
        if not _is_space_id(self.space_id):
            raise ValueError(
                f"Cannot run locally: '{self.space_id}' is not a valid Space ID. "
                "Local mode only works with Hugging Face Spaces (format: 'owner/space-name')."
            )

        try:
            self._ensure_cloned()
            self._ensure_venv()
            url = self._launch_app()
            return url
        except Exception as e:
            self._log_error(e)
            raise

    def _ensure_cloned(self) -> None:
        """Download the Space repo unless a cached clone exists.

        An existing clone (with metadata) is reused unless
        DAGGR_UPDATE_SPACES=1 forces a fresh download.
        """
        metadata = self._load_metadata()

        if self.repo_dir.exists() and metadata:
            should_update = os.environ.get("DAGGR_UPDATE_SPACES") == "1"
            if not should_update:
                return

        self.space_dir.mkdir(parents=True, exist_ok=True)

        # Imported lazily: only needed when a clone actually happens.
        from huggingface_hub import snapshot_download

        print(f" Cloning Space '{self.space_id}'...")

        # Re-cloning replaces the old repo wholesale.
        if self.repo_dir.exists():
            shutil.rmtree(self.repo_dir)

        snapshot_download(
            repo_id=self.space_id,
            repo_type="space",
            local_dir=self.repo_dir,
        )

        requirements_path = self.repo_dir / "requirements.txt"
        metadata = {
            "cloned_at": datetime.now().isoformat(),
            "space_id": self.space_id,
            "requirements_hash": _hash_file(requirements_path),
            "python_version": f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
        }
        self._save_metadata(metadata)
        print(f" Cloned to {self.repo_dir}")

    def _get_sdk_version(self) -> str | None:
        """Extract `sdk_version:` from the README's YAML front matter, if any."""
        readme_path = self.repo_dir / "README.md"
        if not readme_path.exists():
            return None

        try:
            content = readme_path.read_text()
            if not content.startswith("---"):
                return None

            # Front matter is the text between the first two '---' fences.
            parts = content.split("---", 2)
            if len(parts) < 3:
                return None

            match = re.search(r"sdk_version:\s*['\"]?([^\s'\"]+)", parts[1])
            if match:
                return match.group(1)
        except Exception:
            # Best-effort parse; any failure just means "unknown version".
            pass

        return None

    def _ensure_venv(self) -> None:
        """Create/refresh the Space's virtualenv and install dependencies.

        Rebuilds from scratch when the venv is missing or requirements.txt
        has changed since the last install (compared via _hash_file).
        Raises RuntimeError when requirements installation fails.
        """
        requirements_path = self.repo_dir / "requirements.txt"
        current_hash = _hash_file(requirements_path)
        metadata = self._load_metadata()

        venv_python = self.venv_dir / "bin" / "python"
        if sys.platform == "win32":
            venv_python = self.venv_dir / "Scripts" / "python.exe"

        needs_reinstall = False
        if not self.venv_dir.exists() or not venv_python.exists():
            needs_reinstall = True
        elif metadata and metadata.get("requirements_hash") != current_hash:
            needs_reinstall = True

        if not needs_reinstall:
            return

        print(f" Setting up virtual environment for '{self.space_id}'...")

        if self.venv_dir.exists():
            shutil.rmtree(self.venv_dir)

        subprocess.run(
            [sys.executable, "-m", "venv", str(self.venv_dir)],
            check=True,
            capture_output=True,
        )

        pip_path = self.venv_dir / "bin" / "pip"
        if sys.platform == "win32":
            pip_path = self.venv_dir / "Scripts" / "pip.exe"

        subprocess.run(
            [str(pip_path), "install", "--upgrade", "pip"],
            check=True,
            capture_output=True,
        )

        # Pin gradio to the Space's declared sdk_version when available.
        sdk_version = self._get_sdk_version()
        if sdk_version:
            gradio_pkg = f"gradio=={sdk_version}"
            print(f" Installing {gradio_pkg}...")
        else:
            gradio_pkg = "gradio"
            print(" Installing gradio (latest)...")

        result = subprocess.run(
            [str(pip_path), "install", gradio_pkg],
            capture_output=True,
            text=True,
        )
        if result.returncode != 0:
            # Non-fatal: requirements.txt may still install a working gradio.
            error_msg = result.stderr or result.stdout
            self._log_to_file("pip_install_gradio", error_msg)
            print(f" Warning: Failed to install {gradio_pkg}")

        if requirements_path.exists():
            print(f" Installing dependencies from {requirements_path}...")
            print(" (this may take a few minutes)")

            # Stream pip output line by line to show install progress.
            process = subprocess.Popen(
                [str(pip_path), "install", "-r", str(requirements_path)],
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
            )

            output_lines = []
            for line in iter(process.stdout.readline, ""):
                output_lines.append(line)
                line_stripped = line.strip()
                if line_stripped.startswith("Collecting "):
                    pkg = line_stripped.replace("Collecting ", "").split()[0]
                    print(f" Installing {pkg}...")
                elif (
                    line_stripped.startswith("ERROR:")
                    or "error" in line_stripped.lower()
                ):
                    print(f" {line_stripped}")

            process.wait()

            if process.returncode != 0:
                error_msg = "".join(output_lines)
                self._log_to_file("pip_install", error_msg)
                print("\n ❌ Dependency installation failed!")
                print(f" Full log: {self._get_log_path('pip_install')}")
                raise RuntimeError(
                    f"Failed to install dependencies for '{self.space_id}'.\n"
                    f"See logs at: {self._get_log_path('pip_install')}\n"
                    f"You can try installing manually:\n"
                    f" {pip_path} install -r {requirements_path}"
                )

        # Record the hash only after a successful install so a failed
        # attempt triggers a rebuild next time.
        if metadata:
            metadata["requirements_hash"] = current_hash
            self._save_metadata(metadata)

        print(" Virtual environment ready")

    def _launch_app(self) -> str:
        """Start the Space's app in its venv and return the local URL.

        Reuses an already-running process for this Space when its cached
        URL is known. Raises RuntimeError when no entry-point file exists
        or the app fails to become reachable within DAGGR_LOCAL_TIMEOUT
        seconds (default 120).
        """
        global _running_processes

        # Reuse a live process from a previous launch in this session.
        if self.space_id in _running_processes:
            proc = _running_processes[self.space_id]
            if proc.poll() is None:
                metadata = self._load_metadata()
                if metadata and metadata.get("local_url"):
                    return metadata["local_url"]

        app_file = self._find_app_file()
        if not app_file:
            raise RuntimeError(
                f"No app.py or main.py found in '{self.space_id}'. "
                "Cannot determine how to launch this Space."
            )

        port = _find_free_port()
        local_url = f"http://127.0.0.1:{port}"

        venv_python = self.venv_dir / "bin" / "python"
        if sys.platform == "win32":
            venv_python = self.venv_dir / "Scripts" / "python.exe"

        timeout = int(os.environ.get("DAGGR_LOCAL_TIMEOUT", "120"))

        # Force the Gradio app onto our chosen port/host and keep its
        # output unbuffered so _wait_for_ready can stream it.
        env = os.environ.copy()
        env["GRADIO_SERVER_PORT"] = str(port)
        env["GRADIO_SERVER_NAME"] = "127.0.0.1"
        env["PYTHONUNBUFFERED"] = "1"

        print(f" Launching '{self.space_id}' on port {port}...")
        print(f" Waiting for app to start (timeout: {timeout}s)...")

        log_file = self._get_log_path("launch")
        log_file.parent.mkdir(parents=True, exist_ok=True)

        self.process = subprocess.Popen(
            [str(venv_python), str(app_file)],
            cwd=str(self.repo_dir),
            env=env,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
        )

        _running_processes[self.space_id] = self.process

        ready, error_output = self._wait_for_ready(local_url, timeout, verbose=True)
        if not ready:
            self._log_to_file("launch", error_output)
            if self.process.poll() is None:
                self.process.terminate()

            print("\n ❌ Space failed to start!")
            if error_output:
                # Show the last few non-empty lines for quick diagnosis.
                error_lines = error_output.strip().split("\n")
                relevant_lines = [ln for ln in error_lines if ln.strip()][-10:]
                if relevant_lines:
                    print(" Last output:")
                    for line in relevant_lines:
                        print(f" {line}")

            print(f" Full log: {log_file}")
            raise RuntimeError(
                f"Space '{self.space_id}' failed to start.\n"
                f"See logs at: {log_file}\n"
                "Suggestions:\n"
                " 1. Some Spaces require GPU hardware\n"
                " 2. Check the Space's README for requirements\n"
                " 3. Set DAGGR_LOCAL_VERBOSE=1 to see all output"
            )

        metadata = self._load_metadata() or {}
        metadata["local_url"] = local_url
        metadata["last_successful_launch"] = datetime.now().isoformat()
        self._save_metadata(metadata)

        print(f" Space running at {local_url}")
        return local_url

    def _find_app_file(self) -> Path | None:
        """Return the Space's entry-point script, trying common names in order."""
        for name in ["app.py", "main.py", "demo.py"]:
            path = self.repo_dir / name
            if path.exists():
                return path
        return None

    def _wait_for_ready(
        self, url: str, timeout: int, verbose: bool = False
    ) -> tuple[bool, str]:
        """Poll *url* until it answers 200, streaming the app's output.

        Returns (ready, captured_output). Detects early crashes by
        watching for tracebacks in the output and for process exit.
        """
        output_lines: list[str] = []
        start = time.time()
        last_status_time = start
        saw_error = False

        while time.time() - start < timeout:
            # Drain whatever output the app has produced so far.
            if self.process and self.process.stdout:
                while True:
                    if sys.platform == "win32":
                        # select() does not work on pipes on Windows, so
                        # this readline may block briefly.
                        line = self.process.stdout.readline()
                        if not line:
                            break
                    else:
                        ready, _, _ = select.select([self.process.stdout], [], [], 0)
                        if not ready:
                            break
                        line = self.process.stdout.readline()

                    if line:
                        output_lines.append(line)
                        line_lower = line.lower()
                        if (
                            "traceback" in line_lower
                            or "modulenotfounderror" in line_lower
                        ):
                            saw_error = True
                        if verbose:
                            print(f" [app] {line.rstrip()}")

            exit_code = self.process.poll() if self.process else None
            if exit_code is not None:
                # Process died: collect remaining output and give up.
                if self.process and self.process.stdout:
                    remaining = self.process.stdout.read()
                    if remaining:
                        output_lines.append(remaining)
                        if verbose:
                            for rem_line in remaining.strip().split("\n"):
                                if rem_line.strip():
                                    print(f" [app] {rem_line}")
                print(f" App process exited with code {exit_code}")
                return False, "".join(output_lines)

            if saw_error:
                # A traceback was seen; give the process a moment to die
                # so we can distinguish a fatal crash from a caught error.
                time.sleep(0.5)
                if self.process and self.process.poll() is not None:
                    if self.process.stdout:
                        remaining = self.process.stdout.read()
                        if remaining:
                            output_lines.append(remaining)
                    print(" App crashed during startup")
                    return False, "".join(output_lines)

            elapsed = time.time() - start
            if elapsed - (last_status_time - start) >= 10:
                print(f" Still waiting... ({int(elapsed)}s elapsed)")
                last_status_time = time.time()

            try:
                with urllib.request.urlopen(url, timeout=2) as response:
                    if response.status == 200:
                        return True, "".join(output_lines)
            except (urllib.error.URLError, OSError):
                pass

            time.sleep(0.3)

        return False, "".join(output_lines)

    def _load_metadata(self) -> dict[str, Any] | None:
        """Read the cached metadata JSON; None when missing or unreadable."""
        if not self.metadata_path.exists():
            return None
        try:
            return json.loads(self.metadata_path.read_text())
        except (json.JSONDecodeError, OSError):
            return None

    def _save_metadata(self, metadata: dict[str, Any]) -> None:
        """Write *metadata* as pretty-printed JSON, creating parents as needed."""
        self.metadata_path.parent.mkdir(parents=True, exist_ok=True)
        self.metadata_path.write_text(json.dumps(metadata, indent=2))

    def _get_log_path(self, log_type: str) -> Path:
        """Per-Space, per-day log file path for the given log type."""
        logs_dir = _get_logs_dir()
        logs_dir.mkdir(parents=True, exist_ok=True)
        safe_name = self.space_id.replace("/", "_")
        timestamp = datetime.now().strftime("%Y-%m-%d")
        return logs_dir / f"{safe_name}_{log_type}_{timestamp}.log"

    def _log_to_file(self, log_type: str, content: str) -> None:
        """Overwrite the log file for *log_type* with a stamped header + content."""
        log_path = self._get_log_path(log_type)
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with open(log_path, "w") as f:
            f.write(f"Timestamp: {datetime.now().isoformat()}\n")
            f.write(f"Space: {self.space_id}\n")
            f.write(f"Type: {log_type}\n")
            f.write("=" * 50 + "\n")
            f.write(content)

    def _log_error(self, error: Exception) -> None:
        """Persist an exception's message to the 'error' log."""
        self._log_to_file("error", str(error))
447
+
448
+
449
def prepare_local_node(node: GradioNode) -> None:
    """Clone, install, and launch *node*'s Space locally, recording its URL.

    No-op when the node already has a local URL, has already failed
    locally, or its source is not a Hub Space ID. On failure the node is
    marked failed; unless DAGGR_LOCAL_NO_FALLBACK=1 is set, execution
    will later fall back to the remote API.
    """
    if node._local_failed or node._local_url:
        return
    if not _is_space_id(node._src):
        return

    strict = os.environ.get("DAGGR_LOCAL_NO_FALLBACK") == "1"

    try:
        node._local_url = LocalSpaceManager(node).ensure_ready()
    except Exception as e:
        node._local_failed = True
        safe_name = node._src.replace("/", "_")

        print(f"\n ⚠️ Local setup failed for '{node._src}'")
        print(f" Reason: {e}")
        print(f" Logs: {_get_logs_dir()}/{safe_name}_*.log")

        if strict:
            raise RuntimeError(
                f"Local execution failed for '{node._src}' and fallback is disabled. "
                f"Error: {e}"
            ) from e

        print(" Will fall back to remote API at execution time.\n")
477
+
478
+
479
def get_local_client(node: GradioNode) -> Any:
    """Return a gradio_client.Client bound to the node's local URL.

    Returns None when local execution failed or no local URL is recorded.
    """
    if node._local_failed or not node._local_url:
        return None

    from gradio_client import Client

    return Client(node._local_url, download_files=False, verbose=False)
489
+
490
+
491
def cleanup_local_processes() -> None:
    """Terminate every launched Space process and empty the registry.

    Gives each process five seconds to exit after terminate() before
    resorting to kill().
    """
    global _running_processes
    for _space_id, proc in list(_running_processes.items()):
        if proc.poll() is not None:
            continue  # already exited
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
    _running_processes.clear()
501
+
502
+
503
+ atexit.register(cleanup_local_processes)
daggr/node.py ADDED
@@ -0,0 +1,772 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Node types for daggr graphs.
2
+
3
+ This module defines the various node types that can be used in a daggr graph:
4
+ - Node: Abstract base class for all nodes
5
+ - GradioNode: Wraps a Gradio Space or endpoint
6
+ - InferenceNode: Wraps a Hugging Face Inference API model
7
+ - FnNode: Wraps a Python function
8
+ - InteractionNode: Represents user interaction points
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import inspect
14
+ import warnings
15
+ from abc import ABC
16
+ from collections.abc import Callable
17
+ from typing import Any
18
+
19
+ from daggr._utils import suggest_similar
20
+ from daggr.port import ItemList, Port, PortNamespace, is_port
21
+
22
# Gradio component class names whose payloads are files; used by
# _warn_if_type_set to flag `type=` settings that daggr ignores.
_FILE_TYPE_COMPONENTS = {
    "Image",
    "Audio",
    "Video",
    "File",
    "Gallery",
    "ImageEditor",
    "ImageSlider",
}
31
+
32
+
33
+ def _warn_if_type_set(component: Any, port_name: str) -> None:
34
+ constructor_args = getattr(component, "_constructor_args", None)
35
+ if not constructor_args:
36
+ return
37
+ comp_type = constructor_args[0].get("type")
38
+ if comp_type is None:
39
+ return
40
+ class_name = type(component).__name__
41
+ if class_name not in _FILE_TYPE_COMPONENTS:
42
+ return
43
+ if comp_type != "filepath":
44
+ warnings.warn(
45
+ f"Gradio component {class_name}(type={comp_type!r}) on port '{port_name}': "
46
+ f"daggr ignores the `type` parameter. All file data is passed as file path "
47
+ f"strings regardless of this setting.",
48
+ stacklevel=4,
49
+ )
50
+
51
+
52
+ def _is_gradio_component(obj: Any) -> bool:
53
+ if obj is None:
54
+ return False
55
+ class_name = obj.__class__.__name__
56
+ module = getattr(obj.__class__, "__module__", "")
57
+ return "gradio" in module or class_name in (
58
+ "Textbox",
59
+ "TextArea",
60
+ "Audio",
61
+ "Image",
62
+ "JSON",
63
+ "Markdown",
64
+ "Number",
65
+ "Checkbox",
66
+ "Dropdown",
67
+ "Radio",
68
+ "Slider",
69
+ "File",
70
+ "Video",
71
+ "Gallery",
72
+ "Chatbot",
73
+ "Text",
74
+ )
75
+
76
+
77
class Node(ABC):
    """Abstract base class for all nodes in a daggr graph.

    Nodes represent processing steps in a DAG. Each node has named input and
    output ports that can be connected to form a data processing pipeline.

    Ports can be accessed as attributes: `node.port_name` returns a Port object.

    Args:
        name: Optional display name for the node. If not provided, a name will
            be auto-generated based on the node type.
    """

    # Monotonically increasing counter used to give every instance a unique _id.
    _id_counter = 0

    def __init__(self, name: str | None = None):
        self._id = Node._id_counter
        Node._id_counter += 1
        self._name = name or ""
        # Tracks whether the user supplied the name (vs. auto-generated later).
        self._name_explicitly_set = bool(name)
        self._input_ports: list[str] = []
        self._output_ports: list[str] = []
        # Gradio components rendered for input/output ports.
        self._input_components: dict[str, Any] = {}
        self._output_components: dict[str, Any] = {}
        self._item_list_schemas: dict[str, dict[str, Any]] = {}
        # Constant values bound directly to input ports.
        self._fixed_inputs: dict[str, Any] = {}
        # Upstream Port objects wired into input ports.
        self._port_connections: dict[str, Any] = {}

    @property
    def name(self) -> str:
        return self._name

    @name.setter
    def name(self, value: str) -> None:
        self._name = value
        self._name_explicitly_set = True

    def __getattr__(self, name: str) -> Port:
        # Only non-underscore attributes resolve to Ports; internal
        # attribute misses must raise normally so pickling/copying work.
        if name.startswith("_"):
            raise AttributeError(name)
        return Port(self, name)

    def __dir__(self) -> list[str]:
        # Expose port names for tab-completion alongside the core attributes.
        base = ["_name", "_inputs", "_outputs", "_input_ports", "_output_ports"]
        return base + self._input_ports + self._output_ports

    def __or__(self, other: Node) -> ChoiceNode:
        """Combine two nodes as alternatives using the | operator.

        Returns a ChoiceNode that lets users pick which variant to run.

        Example:
            >>> tts = GradioNode("space1/tts", ...) | GradioNode("space2/tts", ...)
            >>> # tts.audio works regardless of which variant is selected
        """
        # Chaining a | b | c flattens into a single ChoiceNode.
        if isinstance(other, ChoiceNode):
            return ChoiceNode([self] + other._variants, name=self._name)
        return ChoiceNode([self, other], name=self._name)

    @property
    def _inputs(self) -> PortNamespace:
        return PortNamespace(self, self._input_ports)

    @property
    def _outputs(self) -> PortNamespace:
        return PortNamespace(self, self._output_ports)

    def _default_output_port(self) -> Port:
        # First declared output port, or a synthetic "output" port.
        if self._output_ports:
            return Port(self, self._output_ports[0])
        return Port(self, "output")

    def _default_input_port(self) -> Port:
        # First declared input port, or a synthetic "input" port.
        if self._input_ports:
            return Port(self, self._input_ports[0])
        return Port(self, "input")

    def _validate_ports(self):
        # Underscore-prefixed port names collide with internal attributes,
        # so attribute access won't reach __getattr__; warn with the workaround.
        all_ports = set(self._input_ports + self._output_ports)
        underscore_ports = [p for p in all_ports if p.startswith("_")]
        if underscore_ports:
            warnings.warn(
                f"Port names {underscore_ports} start with underscore. "
                f"Use node._inputs.{underscore_ports[0]} or node._outputs.{underscore_ports[0]} to access."
            )

    def _process_inputs(self, inputs: dict[str, Any]) -> None:
        # Classify each input value: Port -> connection, Gradio component ->
        # rendered input component, anything else -> fixed constant.
        for port_name, value in inputs.items():
            self._input_ports.append(port_name)
            if is_port(value):
                self._port_connections[port_name] = value
            elif _is_gradio_component(value):
                _warn_if_type_set(value, port_name)
                self._input_components[port_name] = value
            else:
                self._fixed_inputs[port_name] = value

    def _process_outputs(self, outputs: dict[str, Any]) -> None:
        # Register output ports; only Gradio components get stored for display.
        for port_name, component in outputs.items():
            self._output_ports.append(port_name)
            if component is not None and _is_gradio_component(component):
                _warn_if_type_set(component, port_name)
                self._output_components[port_name] = component

    def test(self, **inputs) -> dict[str, Any]:
        """Test-run this node in isolation and return the raw result.

        If no inputs are provided, auto-generates example values using:
        - Gradio component's .example_value() method
        - Port's associated output component's .example_value()
        - Callable inputs are called
        - Fixed values are used directly

        Args:
            **inputs: Override inputs for the test run.

        Returns:
            Dict mapping output port names to their values.

        Example:
            >>> tts = GradioNode("mrfakename/MeloTTS", api_name="/synthesize", ...)
            >>> result = tts.test(text="Hello world", speaker="EN-US")
            >>> # Returns: {"audio": "/path/to/audio.wav"}
            >>>
            >>> # Or with auto-generated example values:
            >>> result = tts.test()
        """
        # Imported here to avoid a circular import at module load time.
        from daggr import Graph
        from daggr.executor import SequentialExecutor

        if not inputs:
            inputs = self._generate_example_inputs()

        # A throwaway single-node graph; persist_key=False keeps it out of
        # any persisted state.
        graph = Graph("_test", nodes=[self], persist_key=False)
        executor = SequentialExecutor(graph)
        return executor.execute_node(self._name, inputs)

    def _generate_example_inputs(self) -> dict[str, Any]:
        """Generate example values for all input ports."""
        examples = {}

        # From input components (Gradio components)
        for port_name, comp in self._input_components.items():
            if hasattr(comp, "example_value"):
                examples[port_name] = comp.example_value()

        # From fixed inputs (constants, callables, or port connections)
        for port_name, source in self._fixed_inputs.items():
            if callable(source):
                examples[port_name] = source()
            else:
                examples[port_name] = source

        # From port connections (use the connected port's output component)
        for port_name, port in self._port_connections.items():
            if is_port(port):
                comp = port._node._output_components.get(port._port_name)
                if comp and hasattr(comp, "example_value"):
                    examples[port_name] = comp.example_value()

        return examples

    def __repr__(self):
        return f"{self.__class__.__name__}(name={self._name})"
241
+
242
+
243
class ChoiceNode(Node):
    """A node that wraps multiple alternative nodes.

    ChoiceNode allows users to select which variant to run from a set of
    alternatives. Created using the | operator between nodes.

    The output ports are the union of all variants' output ports, so downstream
    nodes can connect to any output that exists in at least one variant.

    Args:
        variants: List of Node objects that serve as alternatives.
        name: Optional display name. Defaults to the first variant's name.

    Example:
        >>> tts = GradioNode("space1/tts", ...) | GradioNode("space2/tts", ...)
        >>> # tts is a ChoiceNode with two variants
        >>> # tts.audio works regardless of which variant is selected
    """

    def __init__(
        self,
        variants: list[Node],
        name: str | None = None,
    ):
        if not variants:
            raise ValueError("ChoiceNode requires at least one variant")

        super().__init__(name)
        self._variants = variants
        # Index into self._variants of the currently selected alternative.
        self._selected_variant = 0

        if not self._name:
            self._name = variants[0]._name

        # Expose the union of all variants' outputs so downstream nodes can
        # connect regardless of which variant ends up selected.
        self._output_ports = self._compute_union_output_ports()
        self._output_components = self._compute_union_output_components()

        # Lift the variants' upstream connections onto the ChoiceNode itself;
        # the first variant declaring a port wins on name clashes.
        for variant in variants:
            for port_name, port in variant._port_connections.items():
                if port_name not in self._port_connections:
                    self._port_connections[port_name] = port

    def _compute_union_output_ports(self) -> list[str]:
        # Preserve first-seen order across variants (a bare set would not).
        seen = set()
        ports = []
        for variant in self._variants:
            for port in variant._output_ports:
                if port not in seen:
                    seen.add(port)
                    ports.append(port)
        return ports

    def _compute_union_output_components(self) -> dict[str, Any]:
        # First variant that declares a component for a port wins.
        components = {}
        for variant in self._variants:
            for port_name, comp in variant._output_components.items():
                if port_name not in components:
                    components[port_name] = comp
        return components

    def __or__(self, other: Node) -> ChoiceNode:
        # Flatten chained `|` so variants stay a single flat list.
        if isinstance(other, ChoiceNode):
            return ChoiceNode(self._variants + other._variants, name=self._name)
        return ChoiceNode(self._variants + [other], name=self._name)

    def __repr__(self):
        variant_names = [v._name for v in self._variants]
        return f"ChoiceNode(name={self._name}, variants={variant_names})"
311
+
312
+
313
class GradioNode(Node):
    """A node that wraps a Gradio Space or endpoint.

    GradioNode connects to a Hugging Face Space or any Gradio app and exposes
    its API as a node in the graph.

    Args:
        space_or_url: Hugging Face Space ID (e.g., "username/space-name") or
            a full URL to a Gradio app.
        api_name: The API endpoint to call (e.g., "/predict"). Defaults to "/predict".
        name: Optional display name for the node.
        inputs: Dict mapping input port names to Gradio components, Port connections,
            or fixed values.
        outputs: Dict mapping output port names to Gradio components for display.
        validate: Whether to validate the Space exists and has the specified endpoint.
        run_locally: If True, clone and run the Space locally instead of using the
            remote API.
        preprocess: Optional function applied to the input dict before the call.
        postprocess: Optional function applied to the raw result before it is
            mapped to output ports.

    Example:
        >>> tts = GradioNode(
        ...     "mrfakename/MeloTTS",
        ...     api_name="/synthesize",
        ...     inputs={"text": gr.Textbox(), "speaker": "EN-US"},
        ...     outputs={"audio": gr.Audio()},
        ... )
    """

    # Class-level registry used to derive unique default node names per Space.
    _name_counters: dict[str, int] = {}

    def __init__(
        self,
        space_or_url: str,
        api_name: str | None = None,
        name: str | None = None,
        inputs: dict[str, Any] | None = None,
        outputs: dict[str, Any] | None = None,
        validate: bool = True,
        run_locally: bool = False,
        preprocess: Callable[[dict], dict] | None = None,
        postprocess: Callable[..., Any] | None = None,
    ):
        super().__init__(name)
        self._src = space_or_url
        self._api_name = api_name
        self._run_locally = run_locally
        # Local-run state; populated by machinery outside this class —
        # None / False until a local clone is attempted.
        self._local_url: str | None = None
        self._local_failed = False
        self._preprocess = preprocess
        self._postprocess = postprocess

        if validate:
            self._validate_space_format()

        if not self._name:
            # First node for a Space keeps the bare name; later ones get a
            # numeric suffix ("space", "space_1", "space_2", ...).
            base_name = self._src.split("/")[-1]
            if base_name not in GradioNode._name_counters:
                GradioNode._name_counters[base_name] = 0
                self._name = base_name
            else:
                GradioNode._name_counters[base_name] += 1
                self._name = f"{base_name}_{GradioNode._name_counters[base_name]}"

        self._process_inputs(inputs or {})
        self._process_outputs(outputs or {})
        self._validate_ports()

        # Remote API validation is skipped when running the Space locally.
        if validate and not run_locally:
            self._validate_gradio_api(inputs or {}, outputs or {})

    def _validate_space_format(self) -> None:
        """Reject sources that are neither 'user/space' IDs nor URLs."""
        src = self._src
        if not ("/" in src or src.startswith("http://") or src.startswith("https://")):
            raise ValueError(
                f"Invalid space_or_url '{src}'. Expected format: 'username/space-name' "
                f"or a full URL like 'https://...'"
            )

    def _get_api_info(self) -> dict:
        """Return the Space's API schema, using the module-level cache."""
        from daggr import _client_cache

        cached = _client_cache.get_api_info(self._src)
        if cached is not None:
            return cached

        from gradio_client import Client

        # Reuse a cached client when possible; constructing one hits the network.
        client = _client_cache.get_client(self._src)
        if client is None:
            client = Client(self._src, download_files=False, verbose=False)
            _client_cache.set_client(self._src, client)

        api_info = client.view_api(return_format="dict", print_info=False)
        _client_cache.set_api_info(self._src, api_info)
        return api_info

    def _validate_gradio_api(
        self, inputs: dict[str, Any], outputs: dict[str, Any]
    ) -> None:
        """Validate the endpoint and parameter names against the Space's API.

        Raises ValueError (with "did you mean" suggestions) on unknown
        endpoints, unknown parameters, or missing required parameters.
        Successful validations are memoized per (space, endpoint, inputs,
        outputs) so repeated construction does not re-hit the network.
        """
        from daggr import _client_cache

        api_name = self._api_name or "/predict"
        if not api_name.startswith("/"):
            api_name = "/" + api_name

        cache_key = (
            self._src,
            api_name,
            tuple(sorted(inputs.keys())),
            tuple(sorted(outputs.keys())) if outputs else (),
        )
        if _client_cache.is_validated(cache_key):
            return

        api_info = self._get_api_info()

        named_endpoints = api_info.get("named_endpoints", {})
        unnamed_endpoints = api_info.get("unnamed_endpoints", {})

        endpoint_info = None
        if api_name in named_endpoints:
            endpoint_info = named_endpoints[api_name]
        else:
            # "/3"-style names address unnamed endpoints by function index;
            # the index may be keyed as an int or a str depending on the API.
            try:
                fn_index = int(api_name.lstrip("/"))
                if fn_index in unnamed_endpoints or str(fn_index) in unnamed_endpoints:
                    endpoint_info = unnamed_endpoints.get(
                        fn_index, unnamed_endpoints.get(str(fn_index))
                    )
            except ValueError:
                pass

        if endpoint_info is None:
            available = list(named_endpoints.keys())
            if unnamed_endpoints:
                available.extend([f"/{k}" for k in unnamed_endpoints.keys()])
            suggested = suggest_similar(api_name, set(available))
            msg = (
                f"API endpoint '{api_name}' not found in '{self._src}'. "
                f"Available endpoints: {available}"
            )
            if suggested:
                msg += f" Did you mean '{suggested}'?"
            raise ValueError(msg)

        # Reject input names the endpoint does not accept.
        params_info = endpoint_info.get("parameters", [])
        valid_params = {p.get("parameter_name", p["label"]) for p in params_info}
        input_params = set(inputs.keys())
        invalid_params = input_params - valid_params

        if invalid_params:
            suggestions = {}
            for inv in invalid_params:
                suggestion = suggest_similar(inv, valid_params)
                if suggestion:
                    suggestions[inv] = suggestion
            msg = (
                f"Invalid parameter(s) {invalid_params} for endpoint '{api_name}' "
                f"in '{self._src}'."
            )
            if suggestions:
                suggestion_str = ", ".join(
                    f"'{k}' -> '{v}'" for k, v in suggestions.items()
                )
                msg += f" Did you mean: {suggestion_str}?"
            msg += f" Valid parameters: {valid_params}"
            raise ValueError(msg)

        # Require every parameter that has no server-side default.
        required_params = {
            p.get("parameter_name", p["label"])
            for p in params_info
            if not p.get("parameter_has_default", False)
        }
        provided_params = set(inputs.keys())
        missing_required = required_params - provided_params

        if missing_required:
            raise ValueError(
                f"Missing required parameter(s) {missing_required} for endpoint "
                f"'{api_name}' in '{self._src}'. These parameters have no default values."
            )

        # Warn (not error) when more outputs are declared than the endpoint
        # returns; a custom postprocess may legitimately reshape the result.
        api_returns = endpoint_info.get("returns", [])
        if outputs and api_returns and not self._postprocess:
            num_returns = len(api_returns)
            num_outputs = len(outputs)
            if num_outputs > num_returns:
                warnings.warn(
                    f"GradioNode '{self._name}' defines {num_outputs} outputs but "
                    f"endpoint '{api_name}' only returns {num_returns} value(s). "
                    f"Extra outputs will be None."
                )

        _client_cache.mark_validated(cache_key)
506
+
507
+
508
class InferenceNode(Node):
    """A node that wraps a Hugging Face Inference API model.

    InferenceNode uses the Hugging Face Inference API to run models without
    needing to download them locally. The task type (text-generation, text-to-image,
    etc.) is automatically determined from the model's pipeline_tag on the Hub.

    Args:
        model: The Hugging Face model ID (e.g., "meta-llama/Llama-2-7b-chat-hf").
            May carry a ":provider" suffix (e.g., "model:replicate") for routing.
        name: Optional display name for the node.
        inputs: Dict mapping input port names to values or components.
        outputs: Dict mapping output port names to components.
        validate: Whether to validate the model exists on the Hub.
        preprocess: Optional function that receives the input dict and returns a
            modified dict before the inference call.
        postprocess: Optional function that receives the raw inference result and
            returns a transformed value before it is mapped to output ports.

    Example:
        >>> llm = InferenceNode("meta-llama/Llama-2-7b-chat-hf")
    """

    def __init__(
        self,
        model: str,
        name: str | None = None,
        inputs: dict[str, Any] | None = None,
        outputs: dict[str, Any] | None = None,
        validate: bool = True,
        preprocess: Callable[[dict], dict] | None = None,
        postprocess: Callable[..., Any] | None = None,
    ):
        super().__init__(name)
        self._model = model
        # pipeline_tag from the Hub, resolved lazily by _fetch_model_info.
        self._task: str | None = None
        # Guards repeated Hub lookups once the task has been resolved.
        self._task_fetched: bool = False
        self._preprocess = preprocess
        self._postprocess = postprocess

        if not self._name:
            # Strip provider tag (e.g., ":replicate") for display name
            self._name = self._model_name_for_hub.split("/")[-1]

        # Default to a single generic input/output port when none are given.
        if inputs:
            self._process_inputs(inputs)
        else:
            self._input_ports = ["input"]

        if outputs:
            self._process_outputs(outputs)
        else:
            self._output_ports = ["output"]

        self._validate_ports()

        if validate:
            self._fetch_model_info()

    @property
    def _model_name_for_hub(self) -> str:
        """Return the model name without provider tags (e.g., ':replicate')."""
        # HF Inference Client allows tags like "model:provider" for routing
        # Strip these for Hub API calls and display
        return self._model.split(":")[0]

    @property
    def _provider(self) -> str | None:
        """Return the provider tag if specified (e.g., 'replicate' from 'model:replicate')."""
        parts = self._model.split(":")
        return parts[1] if len(parts) > 1 else None

    def _fetch_model_info(self) -> None:
        """Resolve and cache the model's pipeline task from the Hub.

        Negative results are cached too (as a "__NOT_FOUND__" sentinel) so a
        bad model name fails fast on subsequent constructions.

        Raises:
            ValueError: If the model does not exist on the Hugging Face Hub.
        """
        if self._task_fetched:
            return

        from daggr import _client_cache

        # Use model name without provider tag for Hub lookups
        hub_model = self._model_name_for_hub

        found_in_cache, cached = _client_cache.get_model_task(hub_model)
        if found_in_cache:
            if cached == "__NOT_FOUND__":
                raise ValueError(f"Model '{hub_model}' not found on Hugging Face Hub.")
            self._task = cached
            self._task_fetched = True
            return

        from huggingface_hub import model_info
        from huggingface_hub.utils import RepositoryNotFoundError

        try:
            info = model_info(hub_model)
            self._task = info.pipeline_tag
            _client_cache.set_model_task(hub_model, self._task)
            self._task_fetched = True
        except RepositoryNotFoundError:
            _client_cache.set_model_not_found(hub_model)
            raise ValueError(
                f"Model '{hub_model}' not found on Hugging Face Hub. "
                f"Please check the model name is correct (format: 'username/model-name')."
            )
610
+
611
+
612
class FnNode(Node):
    """A node that wraps a Python function.

    FnNode allows you to use any Python function as a node in the graph.
    Input ports are automatically discovered from the function signature.

    Return values are mapped to output ports in order, just like GradioNode:
    - Single value: maps to the first output port
    - Tuple: each element maps to the corresponding output port in order

    Concurrency:
        By default, FnNodes execute sequentially (one at a time per session)
        to prevent resource contention. Use the concurrency parameters to
        allow parallel execution:

        - concurrent=True: Allow this node to run in parallel with others
        - concurrency_group: Group nodes that share a resource (e.g., GPU)
        - max_concurrent: Max parallel executions within a group (default: 1)

        Note: GradioNode and InferenceNode always run concurrently since they
        are external API calls. Prefer these over FnNode when possible.

    Args:
        fn: The Python function to wrap.
        name: Optional display name. Defaults to the function name.
        inputs: Optional dict to explicitly define input ports and their
            connections or UI components.
        outputs: Optional dict mapping output port names to UI components
            or ItemList schemas.
        preprocess: Optional function applied to the input dict before the call.
        postprocess: Optional function applied to the raw result before it is
            mapped to output ports.
        concurrent: If True, allow parallel execution. Default: False.
        concurrency_group: Name of a group sharing a concurrency limit.
        max_concurrent: Max parallel executions in the group. Default: 1.

    Example:
        >>> def process_text(text: str) -> tuple[str, int]:
        ...     return text.upper(), len(text)
        >>> node = FnNode(
        ...     process_text,
        ...     outputs={"uppercase": gr.Textbox(), "length": gr.Number()}
        ... )

        >>> # Allow parallel execution
        >>> node = FnNode(my_func, concurrent=True)

        >>> # Share GPU with other nodes (max 2 concurrent)
        >>> node = FnNode(gpu_func, concurrency_group="gpu", max_concurrent=2)
    """

    def __init__(
        self,
        fn: Callable,
        name: str | None = None,
        inputs: dict[str, Any] | None = None,
        outputs: dict[str, Any] | None = None,
        preprocess: Callable[[dict], dict] | None = None,
        postprocess: Callable[..., Any] | None = None,
        concurrent: bool = False,
        concurrency_group: str | None = None,
        max_concurrent: int = 1,
    ):
        super().__init__(name)
        self._fn = fn
        self._preprocess = preprocess
        self._postprocess = postprocess
        self._concurrent = concurrent
        self._concurrency_group = concurrency_group
        self._max_concurrent = max_concurrent

        if not self._name:
            self._name = self._fn.__name__

        # Explicit inputs are checked against the function signature;
        # otherwise input ports are discovered from the signature itself.
        if inputs:
            self._validate_fn_inputs(inputs)
            self._process_inputs(inputs)
        else:
            self._discover_signature()

        if outputs:
            self._process_outputs(outputs)
        else:
            self._output_ports = ["output"]

        self._validate_ports()

    def _discover_signature(self):
        """Derive input port names from the wrapped function's parameters."""
        sig = inspect.signature(self._fn)
        self._input_ports = list(sig.parameters.keys())

    def _validate_fn_inputs(self, inputs: dict[str, Any]) -> None:
        """Raise ValueError (with suggestions) for inputs not in the signature."""
        sig = inspect.signature(self._fn)
        valid_params = set(sig.parameters.keys())
        provided_params = set(inputs.keys())
        invalid_params = provided_params - valid_params

        if invalid_params:
            suggestions = {}
            for inv in invalid_params:
                suggestion = suggest_similar(inv, valid_params)
                if suggestion:
                    suggestions[inv] = suggestion

            msg = (
                f"Invalid input(s) {invalid_params} for function '{self._fn.__name__}'."
            )
            if suggestions:
                suggestion_str = ", ".join(
                    f"'{k}' -> '{v}'" for k, v in suggestions.items()
                )
                msg += f" Did you mean: {suggestion_str}?"
            msg += f" Valid parameters: {valid_params}"
            raise ValueError(msg)

    def _process_outputs(self, outputs: dict[str, Any]) -> None:
        """Register output ports.

        Unlike the base Node implementation, this override also accepts
        ItemList schemas (for editable list outputs) alongside plain Gradio
        components; a ``None`` component registers the port with neither.
        """
        for port_name, component in outputs.items():
            self._output_ports.append(port_name)
            if component is None:
                continue
            if isinstance(component, ItemList):
                self._item_list_schemas[port_name] = component.schema
            elif _is_gradio_component(component):
                self._output_components[port_name] = component
733
+
734
+
735
class InteractionNode(Node):
    """A node representing a user interaction point in the graph.

    InteractionNodes pause execution and wait for user input before continuing.
    They are used for approval steps, selections, or other human-in-the-loop
    interactions.

    Args:
        name: Optional display name for the node.
        interaction_type: Type of interaction (e.g., "generic", "approve", "choose_one").
        inputs: Dict mapping input port names to components or connections.
        outputs: Dict mapping output port names to components.
    """

    def __init__(
        self,
        name: str | None = None,
        interaction_type: str = "generic",
        inputs: dict[str, Any] | None = None,
        outputs: dict[str, Any] | None = None,
    ):
        super().__init__(name)
        self._interaction_type = interaction_type

        # Default to a single generic "input"/"output" port pair when no
        # explicit port mapping is given.
        if inputs:
            self._process_inputs(inputs)
        else:
            self._input_ports = ["input"]

        if outputs:
            self._process_outputs(outputs)
        else:
            self._output_ports = ["output"]

        # Fall back to an id-based name; _id is presumably assigned by the
        # Node base class — not visible here, confirm.
        if not self._name:
            self._name = f"interaction_{self._id}"

        self._validate_ports()
daggr/ops.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from daggr.node import InteractionNode
4
+
5
+
6
class ChooseOne(InteractionNode):
    """Interaction asking the user to pick one item from a list.

    Ports: takes "options" as input and emits the user's choice on "selected".
    """

    # Class-level counter used to generate unique default names.
    _instance_counter = 0

    def __init__(self, name: str | None = None):
        # NOTE(review): the counter advances even when an explicit name is
        # supplied, so default names may skip numbers.
        ChooseOne._instance_counter += 1
        super().__init__(
            name=name or f"choose_one_{ChooseOne._instance_counter}",
            interaction_type="choose_one",
        )
        # Replace the generic ports installed by InteractionNode.__init__.
        self._input_ports = ["options"]
        self._output_ports = ["selected"]
17
+
18
+
19
class Approve(InteractionNode):
    """Interaction asking the user to approve a value before it continues.

    Ports: takes "input" and emits "output".
    """

    # Class-level counter used to generate unique default names.
    _instance_counter = 0

    def __init__(self, name: str | None = None):
        Approve._instance_counter += 1
        super().__init__(
            name=name or f"approve_{Approve._instance_counter}",
            interaction_type="approve",
        )
        # Same port names as the InteractionNode defaults, set explicitly.
        self._input_ports = ["input"]
        self._output_ports = ["output"]
30
+
31
+
32
class TextInput(InteractionNode):
    """Interaction that collects free-form text from the user.

    Source node: has no input ports and emits the entered text on "text".

    Args:
        name: Optional display name for the node.
        label: Label shown for the text field in the UI.
    """

    # Class-level counter used to generate unique default names.
    _instance_counter = 0

    def __init__(self, name: str | None = None, label: str = "Input"):
        TextInput._instance_counter += 1
        super().__init__(
            name=name or f"text_input_{TextInput._instance_counter}",
            interaction_type="text_input",
        )
        self._label = label
        # A pure source: clear the default input port, keep a single output.
        self._input_ports = []
        self._output_ports = ["text"]
44
+
45
+
46
class ImageInput(InteractionNode):
    """Interaction that collects an image upload from the user.

    Source node: has no input ports and emits the image on "image".

    Args:
        name: Optional display name for the node.
        label: Label shown for the image field in the UI.
    """

    # Class-level counter used to generate unique default names.
    _instance_counter = 0

    def __init__(self, name: str | None = None, label: str = "Image"):
        ImageInput._instance_counter += 1
        super().__init__(
            name=name or f"image_input_{ImageInput._instance_counter}",
            interaction_type="image_input",
        )
        self._label = label
        # A pure source: clear the default input port, keep a single output.
        self._input_ports = []
        self._output_ports = ["image"]
daggr/package.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "name": "daggr",
3
+ "version": "0.7.0",
4
+ "description": "",
5
+ "python": "true"
6
+ }
daggr/port.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Port module for node input/output definitions.
2
+
3
+ Ports are named connection points on nodes. Output ports can be connected
4
+ to input ports to form edges in the graph.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import TYPE_CHECKING, Any
10
+
11
+ if TYPE_CHECKING:
12
+ from daggr.node import Node
13
+
14
+
15
class Port:
    """A named input/output connection point on a node.

    Obtained via attribute access on a node (``node.port_name``) and used to
    wire nodes together into graph edges.

    Attributes:
        node: The node that owns this port.
        name: The port's name.
    """

    def __init__(self, node: Node, name: str):
        self.node = node
        self.name = name

    def __repr__(self):
        return "Port({}.{})".format(self.node._name, self.name)

    def _as_source(self) -> tuple[Node, str]:
        # A port identifies itself the same way on either end of an edge.
        return (self.node, self.name)

    def _as_target(self) -> tuple[Node, str]:
        return (self.node, self.name)

    def __getattr__(self, attr: str) -> ScatteredPort:
        # Underscore names never resolve to item-list fields.
        if attr.startswith("_"):
            raise AttributeError(attr)
        # An attribute matching a field of this port's ItemList schema
        # yields a scattered view of that field.
        schemas = getattr(self.node, "_item_list_schemas", None)
        if schemas is not None and self.name in schemas:
            if attr in schemas[self.name]:
                return ScatteredPort(self, attr)
        raise AttributeError(f"Port '{self.name}' has no attribute '{attr}'")

    @property
    def each(self) -> ScatteredPort:
        """Scatter this port's output - run the downstream node once per item in the list."""
        return ScatteredPort(self)

    def all(self) -> GatheredPort:
        """Gather outputs from a scattered node back into a list."""
        return GatheredPort(self)
59
+
60
+
61
class ScatteredPort:
    """A scattered view of a list-valued port.

    Obtained via ``.each`` on a Port (optionally narrowed to one field via
    ``[...]``). Connecting it to a downstream node makes that node execute
    once per item of the list.
    """

    def __init__(self, port: Port, item_key: str | None = None):
        self.port = port
        self.item_key = item_key

    @property
    def node(self):
        # Delegate identity to the wrapped port.
        return self.port.node

    @property
    def name(self):
        return self.port.name

    def __getitem__(self, key: str) -> ScatteredPort:
        """Narrow the scatter to one field of each item (e.g. dialogue.json.each["text"])."""
        return ScatteredPort(self.port, key)

    def __repr__(self):
        rendered = f"ScatteredPort({self.port}"
        suffix = f"['{self.item_key}'])" if self.item_key else ")"
        return rendered + suffix
89
+
90
class GatheredPort:
    """Collects the results of scattered executions back into one list.

    Obtained by calling ``.all()`` on a Port.
    """

    def __init__(self, port: Port):
        self.port = port

    @property
    def node(self):
        # Delegate identity to the wrapped port.
        return self.port.node

    @property
    def name(self):
        return self.port.name

    def __repr__(self):
        return "GatheredPort({})".format(self.port)
110
+
111
+
112
# Anything that may appear on either end of a graph edge.
PortLike = Port | ScatteredPort | GatheredPort


def is_port(obj: Any) -> bool:
    """Check if an object is a Port, ScatteredPort, or GatheredPort."""
    return isinstance(obj, (Port, ScatteredPort, GatheredPort))
118
+
119
+
120
class PortNamespace:
    """A namespace for accessing ports that start with underscores.

    Used via ``node._inputs`` or ``node._outputs`` to access ports whose
    names start with underscores (which cannot be reached as attributes on
    the node itself).
    """

    def __init__(self, node: Node, port_names: list[str]):
        self._node = node
        self._names = set(port_names)

    def __getattr__(self, name: str) -> Port:
        # Own attributes must fail normally (this also prevents recursion
        # if they are probed before __init__ has run), and dunder lookups
        # (pickle/copy protocol probes) must not fabricate ports. Every
        # other name -- including single-underscore port names, which are
        # the whole point of this namespace -- resolves to a Port. The
        # previous blanket `startswith("_")` guard made underscore-prefixed
        # ports unreachable, defeating the class's purpose.
        if name in ("_node", "_names") or (
            name.startswith("__") and name.endswith("__")
        ):
            raise AttributeError(name)
        return Port(self._node, name)

    def __dir__(self) -> list[str]:
        return list(self._names)

    def __repr__(self):
        return f"PortNamespace({list(self._names)})"
141
+
142
+
143
class ItemList:
    """Schema declaration for an editable list output.

    Each keyword argument maps an item field name to the component used to
    edit it. The producing function should return a list of dicts whose keys
    match the schema.

    Example:
        outputs={
            "items": ItemList(
                speaker=gr.Dropdown(choices=["Host", "Guest"]),
                text=gr.Textbox(lines=2),
            ),
        }
    """

    def __init__(self, **schema):
        # Per-field component mapping for each list item.
        self.schema = dict(schema)
daggr/py.typed ADDED
File without changes
daggr/server.py ADDED
@@ -0,0 +1,1946 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import base64
5
+ import json
6
+ import mimetypes
7
+ import os
8
+ import secrets
9
+ import socket
10
+ import tempfile
11
+ import threading
12
+ import time
13
+ import traceback
14
+ import uuid
15
+ import webbrowser
16
+ from pathlib import Path
17
+ from typing import TYPE_CHECKING, Any
18
+
19
+ import uvicorn
20
+ from fastapi import FastAPI, Header, Request, WebSocket, WebSocketDisconnect
21
+ from fastapi.responses import (
22
+ FileResponse,
23
+ HTMLResponse,
24
+ JSONResponse,
25
+ PlainTextResponse,
26
+ Response,
27
+ )
28
+ from gradio_client.utils import is_file_obj_with_meta
29
+
30
+ from daggr.executor import AsyncExecutor, FileValue
31
+ from daggr.node import (
32
+ _FILE_TYPE_COMPONENTS,
33
+ ChoiceNode,
34
+ GradioNode,
35
+ InferenceNode,
36
+ InteractionNode,
37
+ )
38
+ from daggr.session import ExecutionSession
39
+ from daggr.state import SessionState, get_daggr_cache_dir
40
+
41
+ _FILE_COMP_TYPES = {c.lower() for c in _FILE_TYPE_COMPONENTS}
42
+
43
+ if TYPE_CHECKING:
44
+ from gradio.themes import Base as Theme
45
+
46
+ from daggr.graph import Graph
47
+
48
+
# Default server port and how many consecutive ports to scan for a free one;
# both are overridable via environment variables.
INITIAL_PORT_VALUE = int(os.getenv("DAGGR_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.getenv("DAGGR_NUM_PORTS", "100"))
51
+
52
+
def _find_available_port(host: str, start_port: int) -> int:
    """Return the first bindable TCP port in [start_port, start_port + TRY_NUM_PORTS).

    Args:
        host: Host to probe; "0.0.0.0" is probed via 127.0.0.1.
        start_port: First port number to try.

    Raises:
        OSError: If no port in the scanned range can be bound.
    """
    probe_host = host if host != "0.0.0.0" else "127.0.0.1"
    for port in range(start_port, start_port + TRY_NUM_PORTS):
        # Context manager guarantees the probe socket is closed even when
        # bind() raises; the original leaked the fd on every failed probe.
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind((probe_host, port))
            except OSError:
                continue
            return port
    raise OSError(
        f"Cannot find empty port in range: {start_port}-{start_port + TRY_NUM_PORTS - 1}. "
        f"You can specify a different port by setting the DAGGR_SERVER_PORT environment variable "
        f"or passing the port parameter to launch()."
    )
69
+
70
+
71
+ def _get_theme(theme: "Theme | str | None") -> "Theme":
72
+ """Get a Gradio theme instance from a theme specification.
73
+
74
+ Args:
75
+ theme: Can be a Theme instance, a string name like "default", "soft",
76
+ "monochrome", "glass", or a Hub theme like "gradio/seafoam".
77
+
78
+ Returns:
79
+ A Theme instance.
80
+ """
81
+ from gradio.themes import Default
82
+
83
+ if theme is None:
84
+ return Default()
85
+
86
+ if isinstance(theme, str):
87
+ from gradio.themes import Base, Default, Glass, Monochrome, Soft
88
+
89
+ theme_mapping = {
90
+ "default": Default,
91
+ "soft": Soft,
92
+ "monochrome": Monochrome,
93
+ "glass": Glass,
94
+ "base": Base,
95
+ }
96
+ theme_lower = theme.lower()
97
+ if theme_lower in theme_mapping:
98
+ return theme_mapping[theme_lower]()
99
+ # Try loading from Hub
100
+ try:
101
+ return Base.from_hub(theme)
102
+ except Exception:
103
+ return Default()
104
+
105
+ return theme
106
+
107
+
108
+ class DaggrServer:
    def __init__(
        self,
        graph: Graph,
        theme: "Theme | str | None" = None,
        api_server: bool = True,
    ):
        """Create a FastAPI server that serves and executes *graph*.

        Args:
            graph: The daggr Graph to expose.
            theme: Gradio theme (instance or name) used to style the frontend.
            api_server: When True, also expose the programmatic
                /api/schema and /api/call endpoints.
        """
        self.graph = graph
        self.api_server = api_server
        self.executor = AsyncExecutor(graph)
        # Sheet/session persistence; DB location overridable for tests.
        self.state = SessionState(db_path=os.environ.get("DAGGR_DB_PATH"))
        self.app = FastAPI(title=graph.name)
        # Active WebSocket connections keyed by session id.
        self.connections: dict[str, WebSocket] = {}
        self.theme = _get_theme(theme)
        # CSS is rendered once at startup and served from /theme.css.
        self.theme_css = self.theme._get_theme_css()
        self._setup_routes()
124
+
125
+ def _extract_token_from_header(self, authorization: str | None) -> str | None:
126
+ if authorization and authorization.startswith("Bearer "):
127
+ return authorization[7:]
128
+ return None
129
+
130
+ def _validate_hf_token(self, token: str) -> dict | None:
131
+ try:
132
+ from huggingface_hub import whoami
133
+
134
+ info = whoami(token=token, cache=True)
135
+ return {
136
+ "username": info.get("name"),
137
+ "fullname": info.get("fullname"),
138
+ "avatar_url": info.get("avatarUrl"),
139
+ }
140
+ except Exception:
141
+ return None
142
+
    def _setup_routes(self):
        """Register all HTTP, WebSocket, and static-file routes on the app."""
        frontend_dir = Path(__file__).parent / "frontend" / "dist"
        if not frontend_dir.exists():
            # The built SPA must ship with the package; fail fast otherwise.
            raise RuntimeError(
                f"Frontend not found at {frontend_dir}. "
                "If developing, run 'npm run build' in daggr/frontend/"
            )

        @self.app.get("/theme.css", response_class=PlainTextResponse)
        async def get_theme_css():
            # CSS generated once from the configured Gradio theme.
            return PlainTextResponse(self.theme_css, media_type="text/css")

        @self.app.get("/api/graph")
        async def get_graph():
            # Static graph structure, without any session results.
            return self._build_graph_data()

        @self.app.get("/api/hf_user")
        async def get_hf_user():
            # Hugging Face user from the server-side token, if any.
            return self._get_hf_user_info()
162
+
        @self.app.get("/api/user_info")
        async def get_user_info(authorization: str | None = Header(default=None)):
            # A browser-supplied bearer token takes precedence over the
            # server-side Hugging Face login.
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            is_on_spaces = os.environ.get("SPACE_ID") is not None
            persistence_enabled = self.graph.persist_key is not None
            return {
                "hf_user": hf_user,
                "user_id": user_id,
                "is_on_spaces": is_on_spaces,
                "can_persist": user_id is not None and persistence_enabled,
            }

        @self.app.post("/api/auth/login")
        async def auth_login(request: Request):
            # Validate a user-pasted HF token; the token itself is not
            # persisted server-side.
            try:
                body = await request.json()
                token = body.get("token")
                if not token:
                    return JSONResponse({"error": "Token is required"}, status_code=400)
                hf_user = self._validate_hf_token(token)
                if not hf_user:
                    return JSONResponse({"error": "Invalid token"}, status_code=401)
                return {"hf_user": hf_user, "success": True}
            except Exception as e:
                return JSONResponse({"error": str(e)}, status_code=500)

        @self.app.post("/api/auth/logout")
        async def auth_logout():
            # Stateless: the client discards its token; nothing to clear here.
            return {"success": True}
197
+
        @self.app.get("/api/sheets")
        async def list_sheets(authorization: str | None = Header(default=None)):
            # Sheets only exist when the graph opted into persistence.
            if not self.graph.persist_key:
                return {"sheets": [], "user_id": None}
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            if not user_id:
                return JSONResponse(
                    {"error": "Login required to access sheets on Spaces"},
                    status_code=401,
                )
            sheets = self.state.list_sheets(user_id, self.graph.persist_key)
            return {"sheets": sheets, "user_id": user_id}

        @self.app.post("/api/sheets")
        async def create_sheet(
            request: Request, authorization: str | None = Header(default=None)
        ):
            if not self.graph.persist_key:
                return JSONResponse(
                    {"error": "Persistence is disabled for this graph"},
                    status_code=400,
                )
            # Resolve acting user: browser token wins over server login.
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            if not user_id:
                return JSONResponse(
                    {"error": "Login required to create sheets on Spaces"},
                    status_code=401,
                )
            body = await request.json()
            name = body.get("name")
            sheet_id = self.state.create_sheet(user_id, self.graph.persist_key, name)
            sheet = self.state.get_sheet(sheet_id)
            return {"sheet": sheet}
241
+
        @self.app.patch("/api/sheets/{sheet_id}")
        async def rename_sheet(
            sheet_id: str,
            request: Request,
            authorization: str | None = Header(default=None),
        ):
            # Resolve acting user: browser token wins over server login.
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            if not user_id:
                return JSONResponse({"error": "Login required"}, status_code=401)
            sheet = self.state.get_sheet(sheet_id)
            if not sheet:
                return JSONResponse({"error": "Sheet not found"}, status_code=404)
            if sheet["user_id"] != user_id:
                # Sheets are private to their owner.
                return JSONResponse({"error": "Access denied"}, status_code=403)
            body = await request.json()
            new_name = body.get("name")
            if not new_name:
                return JSONResponse({"error": "Name required"}, status_code=400)
            self.state.rename_sheet(sheet_id, new_name)
            return {"success": True, "sheet": self.state.get_sheet(sheet_id)}

        @self.app.delete("/api/sheets/{sheet_id}")
        async def delete_sheet(
            sheet_id: str, authorization: str | None = Header(default=None)
        ):
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            if not user_id:
                return JSONResponse({"error": "Login required"}, status_code=401)
            sheet = self.state.get_sheet(sheet_id)
            if not sheet:
                return JSONResponse({"error": "Sheet not found"}, status_code=404)
            if sheet["user_id"] != user_id:
                return JSONResponse({"error": "Access denied"}, status_code=403)
            self.state.delete_sheet(sheet_id)
            return {"success": True}

        @self.app.get("/api/sheets/{sheet_id}/state")
        async def get_sheet_state(
            sheet_id: str, authorization: str | None = Header(default=None)
        ):
            browser_token = self._extract_token_from_header(authorization)
            if browser_token:
                hf_user = self._validate_hf_token(browser_token)
            else:
                hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            if not user_id:
                return JSONResponse({"error": "Login required"}, status_code=401)
            sheet = self.state.get_sheet(sheet_id)
            if not sheet:
                return JSONResponse({"error": "Sheet not found"}, status_code=404)
            if sheet["user_id"] != user_id:
                return JSONResponse({"error": "Access denied"}, status_code=403)
            # Saved inputs + results for restoring the canvas client-side.
            state = self.state.get_sheet_state(sheet_id)
            return {"sheet": sheet, "state": state}

        @self.app.post("/api/run/{node_name}")
        async def run_to_node(node_name: str, data: dict):
            # One-shot (non-streaming) execution up to *node_name*.
            session = ExecutionSession(self.graph)
            session_id = data.get("session_id")
            input_values = data.get("inputs", {})
            selected_results = data.get("selected_results", {})
            return await self._execute_to_node(
                session, node_name, session_id, input_values, selected_results
            )
317
+
        if self.api_server:
            # Optional programmatic REST surface mirroring the UI workflow.

            @self.app.get("/api/schema")
            async def get_api_schema():
                return self.graph.get_api_schema()

            @self.app.post("/api/call")
            async def call_workflow(request: Request):
                return await self._execute_workflow_api(request, subgraph_id=None)

            @self.app.post("/api/call/{subgraph_id}")
            async def call_subgraph(subgraph_id: str, request: Request):
                return await self._execute_workflow_api(
                    request, subgraph_id=subgraph_id
                )

        @self.app.websocket("/ws/{session_id}")
        async def websocket_endpoint(websocket: WebSocket, session_id: str):
            # One execution session per socket; all UI actions arrive as JSON
            # messages with an "action" field, and results stream back.
            await websocket.accept()
            self.connections[session_id] = websocket

            hf_user = self._get_hf_user_info()
            user_id = self.state.get_effective_user_id(hf_user)
            current_sheet_id: str | None = None

            session = ExecutionSession(self.graph)
            # In-flight node runs keyed by run_id so they can be cancelled.
            running_tasks: dict[str, asyncio.Task] = {}

            async def run_node_execution(
                node_name: str,
                sheet_id: str | None,
                input_values: dict,
                item_list_values: dict,
                selected_results: dict,
                run_id: str,
                user_id: str | None,
                run_ancestors: bool = True,
            ):
                # Stream incremental execution updates back to the client.
                try:
                    async for result in self._execute_to_node_streaming(
                        session,
                        node_name,
                        sheet_id,
                        input_values,
                        item_list_values,
                        selected_results,
                        run_id,
                        user_id,
                        run_ancestors,
                    ):
                        await websocket.send_json(result)
                except asyncio.CancelledError:
                    # User-initiated cancel; the ack is sent by the
                    # "cancel" action handler below.
                    pass
                except Exception as e:
                    await websocket.send_json(
                        {
                            "type": "error",
                            "run_id": run_id,
                            "error": str(e),
                            "node": node_name,
                        }
                    )

            try:
                while True:
                    data = await websocket.receive_json()
                    action = data.get("action")

                    # Any message may carry a refreshed browser HF token.
                    if "hf_token" in data:
                        browser_hf_token = data.get("hf_token")
                        old_user_id = user_id
                        if browser_hf_token:
                            hf_user = self._validate_hf_token(browser_hf_token)
                            user_id = self.state.get_effective_user_id(hf_user)
                            session.set_hf_token(browser_hf_token)
                        else:
                            hf_user = self._get_hf_user_info()
                            user_id = self.state.get_effective_user_id(hf_user)
                            session.set_hf_token(None)
                        if old_user_id != user_id:
                            # Identity changed: drop per-user cached state.
                            session.clear_results()
                            current_sheet_id = None

                    if action == "run":
                        node_name = data.get("node_name")
                        input_values = data.get("inputs", {})
                        item_list_values = data.get("item_list_values", {})
                        selected_results = data.get("selected_results", {})
                        run_id = data.get("run_id")
                        sheet_id = data.get("sheet_id") or current_sheet_id
                        run_ancestors = data.get("run_ancestors", True)

                        task = asyncio.create_task(
                            run_node_execution(
                                node_name,
                                sheet_id,
                                input_values,
                                item_list_values,
                                selected_results,
                                run_id,
                                user_id,
                                run_ancestors,
                            )
                        )
                        running_tasks[run_id] = task
                        # Default arg pins run_id per task, avoiding the
                        # late-binding closure pitfall.
                        task.add_done_callback(
                            lambda t, rid=run_id: running_tasks.pop(rid, None)
                        )

                    elif action == "cancel":
                        cancel_run_id = data.get("run_id")
                        cancel_node = data.get("node_name")
                        task = running_tasks.get(cancel_run_id)
                        if task:
                            task.cancel()
                            await websocket.send_json(
                                {
                                    "type": "cancelled",
                                    "run_id": cancel_run_id,
                                    "node": cancel_node,
                                }
                            )

                    elif action == "get_graph":
                        try:
                            sheet_id = data.get("sheet_id")

                            persisted_inputs = {}
                            persisted_results: dict[str, list[Any]] = {}
                            persisted_transform = None

                            # Restore sheet state only for its owner.
                            if user_id and sheet_id:
                                sheet = self.state.get_sheet(sheet_id)
                                if sheet and sheet["user_id"] == user_id:
                                    current_sheet_id = sheet_id
                                    state = self.state.get_sheet_state(sheet_id)
                                    persisted_inputs = state.get("inputs", {})
                                    persisted_results = state.get("results", {})
                                    persisted_transform = sheet.get("transform")

                            # Latest result per node; supports both old (bare)
                            # and new ({"result": ...}) persisted formats.
                            node_results = {}
                            for node_name, results_list in persisted_results.items():
                                if results_list:
                                    last_entry = results_list[-1]
                                    if (
                                        isinstance(last_entry, dict)
                                        and "result" in last_entry
                                    ):
                                        node_results[node_name] = last_entry["result"]
                                    else:
                                        node_results[node_name] = last_entry

                            graph_data = self._build_graph_data(
                                node_results=node_results,
                                input_values=persisted_inputs,
                            )
                            graph_data["session_id"] = session_id
                            graph_data["sheet_id"] = current_sheet_id
                            graph_data["user_id"] = user_id
                            graph_data["persisted_results"] = (
                                self._transform_persisted_results(persisted_results)
                            )
                            graph_data["transform"] = persisted_transform

                            await websocket.send_json(
                                {"type": "graph", "data": graph_data}
                            )
                        except Exception as e:
                            print(f"[ERROR] get_graph failed: {e}")
                            traceback.print_exc()
                            await websocket.send_json(
                                {"type": "error", "error": str(e)}
                            )

                    elif action == "save_input":
                        # Persist a single input value; silently ignored when
                        # no authenticated user/sheet is active.
                        if user_id and current_sheet_id:
                            node_id = data.get("node_id")
                            port_name = data.get("port_name")
                            value = data.get("value")
                            if node_id and port_name is not None:
                                self.state.save_input(
                                    current_sheet_id, node_id, port_name, value
                                )
                                await websocket.send_json(
                                    {"type": "input_saved", "node_id": node_id}
                                )

                    elif action == "save_transform":
                        # Persist the canvas pan/zoom state; no ack needed.
                        if user_id and current_sheet_id:
                            x = data.get("x", 0)
                            y = data.get("y", 0)
                            scale = data.get("scale", 1)
                            self.state.save_transform(current_sheet_id, x, y, scale)

                    elif action == "set_sheet":
                        sheet_id = data.get("sheet_id")
                        if user_id and sheet_id:
                            sheet = self.state.get_sheet(sheet_id)
                            if sheet and sheet["user_id"] == user_id:
                                current_sheet_id = sheet_id
                                # Switching sheets invalidates session results.
                                session.clear_results()
                                await websocket.send_json(
                                    {"type": "sheet_set", "sheet_id": sheet_id}
                                )

                    elif action == "save_variant_selection":
                        node_id = data.get("node_id")
                        variant_index = data.get("variant_index", 0)
                        if user_id and current_sheet_id and node_id is not None:
                            # Stored alongside inputs under a reserved key.
                            self.state.save_input(
                                current_sheet_id,
                                node_id,
                                "_selected_variant",
                                variant_index,
                            )
                            await websocket.send_json(
                                {
                                    "type": "variant_selection_saved",
                                    "node_id": node_id,
                                    "variant_index": variant_index,
                                }
                            )

                    elif action == "clear_sheet":
                        if user_id and current_sheet_id:
                            self.state.clear_sheet_data(current_sheet_id)
                            await websocket.send_json({"type": "sheet_cleared"})

            except WebSocketDisconnect:
                # Client went away: stop any in-flight runs and deregister.
                for task in running_tasks.values():
                    task.cancel()
                if session_id in self.connections:
                    del self.connections[session_id]
            except Exception as e:
                for task in running_tasks.values():
                    task.cancel()
                print(f"[ERROR] WebSocket error: {e}")
                traceback.print_exc()
556
+
        @self.app.get("/")
        async def serve_index():
            index_path = frontend_dir / "index.html"
            if index_path.exists():
                return FileResponse(index_path)
            # Dev fallback: a shell that loads the Vite dev server.
            return HTMLResponse(self._get_dev_html())

        @self.app.get("/assets/{path:path}")
        async def serve_assets(path: str):
            # Built SPA assets (js/css/images) from the dist bundle.
            file_path = frontend_dir / "assets" / path
            if file_path.exists():
                content_type, _ = mimetypes.guess_type(str(file_path))
                return FileResponse(file_path, media_type=content_type)
            return Response(status_code=404)

        @self.app.get("/daggr-assets/{path:path}")
        async def serve_daggr_assets(path: str):
            # Static assets bundled with the daggr package itself.
            assets_dir = Path(__file__).parent / "assets"
            file_path = assets_dir / path
            if file_path.exists():
                content_type, _ = mimetypes.guess_type(str(file_path))
                return FileResponse(file_path, media_type=content_type)
            return Response(status_code=404)
580
+
581
+ @self.app.get("/file/{path:path}")
582
+ async def serve_local_file(path: str):
583
+ if len(path) >= 2 and path[1] == ":":
584
+ file_path = Path(path)
585
+ else:
586
+ file_path = Path("/") / path
587
+ temp_dir = Path(tempfile.gettempdir()).resolve()
588
+ daggr_cache = get_daggr_cache_dir().resolve()
589
+
590
+ try:
591
+ resolved = file_path.resolve()
592
+ is_allowed = str(resolved).startswith(str(temp_dir)) or str(
593
+ resolved
594
+ ).startswith(str(daggr_cache))
595
+ if not is_allowed:
596
+ return Response(status_code=403)
597
+ except (ValueError, OSError):
598
+ return Response(status_code=403)
599
+ if resolved.exists() and resolved.is_file():
600
+ content_type, _ = mimetypes.guess_type(str(resolved))
601
+ return FileResponse(
602
+ resolved, media_type=content_type or "application/octet-stream"
603
+ )
604
+ return Response(status_code=404)
605
+
        @self.app.get("/{path:path}")
        async def serve_static(path: str):
            # Catch-all: real files from the dist bundle, otherwise SPA
            # fallback to index.html; API/WS prefixes 404 explicitly.
            if path.startswith("api/") or path.startswith("ws/"):
                return Response(status_code=404)
            # NOTE(review): *path* is joined unsanitized; confirm Starlette's
            # path-parameter normalization prevents ".." escapes here.
            file_path = frontend_dir / path
            if file_path.exists() and file_path.is_file():
                return FileResponse(file_path)
            index_path = frontend_dir / "index.html"
            if index_path.exists():
                return FileResponse(index_path)
            return HTMLResponse(self._get_dev_html())
617
+
    def _get_dev_html(self) -> str:
        """Fallback HTML shell that loads the Vite dev server (localhost:5173)."""
        return f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>{self.graph.name}</title>
    <link rel="stylesheet" href="/theme.css">
    <style>
        * {{ margin: 0; box-sizing: border-box; }}
        body {{
            background: var(--body-background-fill, #000);
            min-height: 100vh;
            font-family: 'Space Grotesk', -apple-system, BlinkMacSystemFont, sans-serif;
            overflow: hidden;
            color: var(--body-text-color, #fff);
        }}
    </style>
    <script type="module" src="http://localhost:5173/src/main.ts"></script>
</head>
<body class="dark">
    <div id="app"></div>
</body>
</html>"""
642
+
643
+ def _get_node_url(self, node) -> str | None:
644
+ if isinstance(node, GradioNode):
645
+ src = node._src
646
+ if src.startswith("http://") or src.startswith("https://"):
647
+ return src
648
+ elif "/" in src:
649
+ return f"https://huggingface.co/spaces/{src}"
650
+ elif isinstance(node, InferenceNode):
651
+ return f"https://huggingface.co/{node._model_name_for_hub}"
652
+ return None
653
+
654
+ def _get_node_type(self, node, node_name: str) -> str:
655
+ type_map = {
656
+ "FnNode": "FN",
657
+ "TextInput": "INPUT",
658
+ "ImageInput": "IMAGE",
659
+ "ChooseOne": "SELECT",
660
+ "Approve": "APPROVE",
661
+ "GradioNode": "GRADIO",
662
+ "InferenceNode": "MODEL",
663
+ "InteractionNode": "ACTION",
664
+ "ChoiceNode": "CHOICE",
665
+ }
666
+ if isinstance(node, ChoiceNode):
667
+ return "CHOICE"
668
+ class_name = node.__class__.__name__
669
+ return type_map.get(class_name, class_name.upper())
670
+
671
+ def _has_scattered_input(self, node_name: str) -> bool:
672
+ for edge in self.graph._edges:
673
+ if edge.target_node._name == node_name and edge.is_scattered:
674
+ return True
675
+ return False
676
+
677
+ def _get_scattered_edge(self, node_name: str):
678
+ for edge in self.graph._edges:
679
+ if edge.target_node._name == node_name and edge.is_scattered:
680
+ return edge
681
+ return None
682
+
683
+ def _is_output_node(self, node_name: str) -> bool:
684
+ return self.graph._nx_graph.out_degree(node_name) == 0
685
+
686
+ def _is_running_locally(self, node) -> bool:
687
+ if not isinstance(node, GradioNode):
688
+ return False
689
+ return bool(node._run_locally and node._local_url and not node._local_failed)
690
+
691
+ def _build_variant_data(self, variant, input_values: dict) -> dict[str, Any]:
692
+ variant_name = variant._name
693
+ if isinstance(variant, GradioNode) and not variant._name_explicitly_set:
694
+ variant_name = f"{variant._src}"
695
+ if variant._api_name:
696
+ variant_name += f" ({variant._api_name})"
697
+
698
+ input_components = []
699
+ for port_name, comp in variant._input_components.items():
700
+ comp_data = self._serialize_component(comp, port_name)
701
+ input_components.append(comp_data)
702
+
703
+ output_components = []
704
+ for port_name, comp in variant._output_components.items():
705
+ if comp is None:
706
+ continue
707
+ visible = getattr(comp, "visible", True)
708
+ if visible is False:
709
+ continue
710
+ comp_data = self._serialize_component(comp, port_name)
711
+ output_components.append(comp_data)
712
+
713
+ return {
714
+ "name": variant_name,
715
+ "input_components": input_components,
716
+ "output_components": output_components,
717
+ }
718
+
719
+ def _get_component_type(self, component) -> str:
720
+ class_name = component.__class__.__name__
721
+ type_map = {
722
+ "Audio": "audio",
723
+ "Textbox": "textbox",
724
+ "TextArea": "textarea",
725
+ "JSON": "json",
726
+ "Chatbot": "json",
727
+ "Image": "image",
728
+ "Number": "number",
729
+ "Markdown": "markdown",
730
+ "Text": "text",
731
+ "Dropdown": "dropdown",
732
+ "Video": "video",
733
+ "File": "file",
734
+ "Model3D": "model3d",
735
+ "Gallery": "gallery",
736
+ "Slider": "slider",
737
+ "Radio": "radio",
738
+ "Checkbox": "checkbox",
739
+ "CheckboxGroup": "checkboxgroup",
740
+ "ColorPicker": "colorpicker",
741
+ "Label": "label",
742
+ "HighlightedText": "highlightedtext",
743
+ "Code": "code",
744
+ "HTML": "html",
745
+ "Dataframe": "dataframe",
746
+ }
747
+ return type_map.get(class_name, "text")
748
+
    def _serialize_component(self, comp, port_name: str) -> dict[str, Any]:
        """Serialize a Gradio component into a JSON-safe descriptor.

        Captures the renderer type, display props, and current value for the
        frontend; file-backed default values are rewritten to /file URLs.
        """
        comp_type = self._get_component_type(comp)
        comp_class = comp.__class__.__name__

        props = {
            "label": getattr(comp, "label", "") or port_name,
            "show_label": bool(getattr(comp, "label", "")),
            "interactive": getattr(comp, "interactive", True),
            "visible": getattr(comp, "visible", True),
        }

        # Optional props are copied only when the component defines them.
        if hasattr(comp, "placeholder"):
            props["placeholder"] = comp.placeholder
        if hasattr(comp, "lines"):
            props["lines"] = comp.lines
        if hasattr(comp, "max_lines"):
            props["max_lines"] = comp.max_lines
        if hasattr(comp, "type"):
            props["type"] = comp.type
        if hasattr(comp, "choices") and comp.choices:
            # Normalize choices to [label, value] pairs.
            choices = []
            for c in comp.choices:
                if isinstance(c, (tuple, list)) and len(c) >= 2:
                    choices.append([c[0], c[1]])
                else:
                    choices.append([str(c), c])
            props["choices"] = choices
        if hasattr(comp, "minimum"):
            props["minimum"] = comp.minimum
        if hasattr(comp, "maximum"):
            props["maximum"] = comp.maximum
        if hasattr(comp, "step"):
            props["step"] = comp.step

        value = getattr(comp, "value", None)
        if is_file_obj_with_meta(value):
            # Gradio FileData dicts are served back through /file/ URLs.
            value = self._file_to_url(value["path"])

        return {
            "component": comp_class.lower(),
            "type": comp_type,
            "port_name": port_name,
            "props": props,
            "value": value,
        }
794
+
795
+ def _file_to_url(self, value: Any) -> Any:
796
+ if isinstance(value, str) and not value.startswith("/file/"):
797
+ path = Path(value)
798
+ if path.is_absolute() and path.exists():
799
+ normalized = value.replace("\\", "/")
800
+ if normalized.startswith("/"):
801
+ return f"/file{normalized}"
802
+ return f"/file/{normalized}"
803
+ return value
804
+
805
+ def _validate_file_value(self, value: Any, comp_type: str) -> str | None:
806
+ """Validate that a value is appropriate for a file-type component.
807
+ Returns an error message if invalid, None if valid."""
808
+ if value is None:
809
+ return None
810
+ if isinstance(value, str):
811
+ return None
812
+ if isinstance(value, dict):
813
+ if "url" in value or "path" in value:
814
+ return None
815
+ keys = list(value.keys())
816
+ if keys:
817
+ return (
818
+ f"Expected a file path string for {comp_type}, but got a dict "
819
+ f"with keys {keys}. If using postprocess, extract the path: "
820
+ f"e.g., `postprocess=lambda x: x['{keys[0]}']`"
821
+ )
822
+ return (
823
+ f"Expected a file path string for {comp_type}, but got an empty dict."
824
+ )
825
+ return f"Expected a file path string for {comp_type}, but got {type(value).__name__}."
826
+
827
+ def _transform_file_paths(self, data: Any) -> Any:
828
+ if isinstance(data, str):
829
+ return self._file_to_url(data)
830
+ elif isinstance(data, dict):
831
+ return {k: self._transform_file_paths(v) for k, v in data.items()}
832
+ elif isinstance(data, list):
833
+ return [self._transform_file_paths(item) for item in data]
834
+ return data
835
+
836
+ def _transform_persisted_results(
837
+ self, persisted_results: dict[str, list[Any]]
838
+ ) -> dict[str, list[Any]]:
839
+ """Transform persisted results, handling both old format (just result)
840
+ and new format (dict with result and inputs_snapshot)."""
841
+ transformed: dict[str, list[Any]] = {}
842
+ for node_name, results_list in persisted_results.items():
843
+ transformed[node_name] = []
844
+ for entry in results_list:
845
+ if isinstance(entry, dict) and "result" in entry:
846
+ transformed[node_name].append(
847
+ {
848
+ "result": self._transform_file_paths(entry["result"]),
849
+ "inputs_snapshot": entry.get("inputs_snapshot"),
850
+ }
851
+ )
852
+ else:
853
+ transformed[node_name].append(self._transform_file_paths(entry))
854
+ return transformed
855
+
856
+ def _build_input_components(self, node) -> list[dict[str, Any]]:
857
+ if not node._input_components:
858
+ return []
859
+ return [
860
+ self._serialize_component(comp, port_name)
861
+ for port_name, comp in node._input_components.items()
862
+ ]
863
+
    def _build_output_components(
        self, node, result: Any = None
    ) -> tuple[list[dict[str, Any]], str | None]:
        """Serialize visible output components, filling in values from *result*.

        Returns (components, validation_error) where validation_error is the
        first file-type mismatch message encountered, or None.
        """
        if not node._output_components:
            return [], None

        components = []
        validation_error = None
        for port_name, comp in node._output_components.items():
            if comp is None:
                continue

            # Explicitly hidden components are not sent to the frontend.
            visible = getattr(comp, "visible", True)
            if visible is False:
                continue

            comp_data = self._serialize_component(comp, port_name)
            comp_type = self._get_component_type(comp)
            if result is not None:
                if isinstance(result, dict):
                    # Look up by port name first, then by display label.
                    value = result.get(
                        port_name, result.get(comp_data["props"]["label"])
                    )
                else:
                    # Scalar results map onto every output component.
                    value = result
                if comp_type in _FILE_COMP_TYPES:
                    # Only the first validation error is reported.
                    error = self._validate_file_value(value, comp_type)
                    if error and validation_error is None:
                        validation_error = error
                    value = self._file_to_url(value)
                comp_data["value"] = value
            components.append(comp_data)
        return components, validation_error
897
+
    def _build_scattered_items(
        self, node_name: str, result: Any = None
    ) -> list[dict[str, Any]]:
        """Build per-item previews for a node fed by a scattered edge.

        Returns [] when the node has no scattered input or no scattered
        results; otherwise one {index, preview, output, is_audio_output}
        entry per scattered item.
        """
        scattered_edge = self._get_scattered_edge(node_name)
        if not scattered_edge:
            return []

        node = self.graph.nodes[node_name]
        # Frontend renders items differently when the output is audio.
        item_output_type = "text"
        for comp in node._output_components.values():
            if comp is None:
                continue
            comp_type = self._get_component_type(comp)
            if comp_type == "audio":
                item_output_type = "audio"
                break

        items = []
        if result and isinstance(result, dict) and "_scattered_results" in result:
            results = result["_scattered_results"]
            source_items = result.get("_items", [])
            for i, item_result in enumerate(results):
                # Source item may be missing when results outnumber inputs.
                source_item = source_items[i] if i < len(source_items) else None
                preview = ""
                output = None

                if isinstance(source_item, dict):
                    # Preview: first two key/value pairs, values truncated.
                    preview_parts = [
                        f"{k}: {str(v)[:20]}" for k, v in list(source_item.items())[:2]
                    ]
                    preview = ", ".join(preview_parts)
                elif source_item:
                    preview = str(source_item)[:50]

                if isinstance(item_result, dict):
                    # Dict results display their first value only.
                    first_key = list(item_result.keys())[0] if item_result else None
                    if first_key:
                        output = item_result[first_key]
                else:
                    output = item_result

                if output:
                    output = str(output)

                items.append(
                    {
                        "index": i + 1,
                        "preview": preview or f"Item {i + 1}",
                        "output": output,
                        "is_audio_output": item_output_type == "audio",
                    }
                )
        return items
951
+
952
+ def _serialize_item_list_schema(
953
+ self, schema: dict[str, Any]
954
+ ) -> list[dict[str, Any]]:
955
+ serialized = []
956
+ for field_name, comp in schema.items():
957
+ comp_data = self._serialize_component(comp, field_name)
958
+ serialized.append(comp_data)
959
+ return serialized
960
+
961
+ def _build_item_list_items(
962
+ self, node, port_name: str, result: Any = None
963
+ ) -> list[dict[str, Any]]:
964
+ schema = node._item_list_schemas.get(port_name, {})
965
+ if not schema:
966
+ return []
967
+
968
+ items = []
969
+ if result and isinstance(result, dict) and port_name in result:
970
+ item_list = result[port_name]
971
+ if isinstance(item_list, list):
972
+ for i, item_data in enumerate(item_list):
973
+ item = {"index": i, "fields": {}}
974
+ if isinstance(item_data, dict):
975
+ for field_name in schema:
976
+ item["fields"][field_name] = item_data.get(field_name)
977
+ items.append(item)
978
+ return items
979
+
980
+ def _apply_item_list_edits(
981
+ self, node_name: str, result: Any, item_list_values: dict
982
+ ) -> Any:
983
+ node = self.graph.nodes[node_name]
984
+ if not node._item_list_schemas:
985
+ return result
986
+
987
+ node_id = node_name.replace(" ", "_").replace("-", "_")
988
+ edits = item_list_values.get(node_id, {})
989
+ if not edits:
990
+ return result
991
+
992
+ first_port = list(node._item_list_schemas.keys())[0]
993
+ if isinstance(result, dict) and first_port in result:
994
+ items = result[first_port]
995
+ if isinstance(items, list):
996
+ for idx_str, field_edits in edits.items():
997
+ idx = int(idx_str)
998
+ if 0 <= idx < len(items) and isinstance(items[idx], dict):
999
+ items[idx].update(field_edits)
1000
+ return result
1001
+
1002
+ def _compute_node_depths(self) -> dict[str, int]:
1003
+ depths: dict[str, int] = {}
1004
+ connections = self.graph.get_connections()
1005
+
1006
+ for node_name in self.graph.nodes:
1007
+ if self.graph._nx_graph.in_degree(node_name) == 0:
1008
+ depths[node_name] = 0
1009
+
1010
+ changed = True
1011
+ while changed:
1012
+ changed = False
1013
+ for source, _, target, _ in connections:
1014
+ if source in depths:
1015
+ new_depth = depths[source] + 1
1016
+ if target not in depths or depths[target] < new_depth:
1017
+ depths[target] = new_depth
1018
+ changed = True
1019
+
1020
+ for node_name in self.graph.nodes:
1021
+ if node_name not in depths:
1022
+ depths[node_name] = 0
1023
+
1024
+ return depths
1025
+
1026
+ def _get_hf_user_info(self) -> dict | None:
1027
+ try:
1028
+ from huggingface_hub import get_token, whoami
1029
+
1030
+ token = get_token()
1031
+ if not token:
1032
+ return None
1033
+
1034
+ info = whoami(cache=True)
1035
+ return {
1036
+ "username": info.get("name"),
1037
+ "fullname": info.get("fullname"),
1038
+ "avatar_url": info.get("avatarUrl"),
1039
+ }
1040
+ except Exception:
1041
+ return None
1042
+
1043
def _build_graph_data(
    self,
    node_results: dict[str, Any] | None = None,
    node_statuses: dict[str, str] | None = None,
    input_values: dict[str, Any] | None = None,
    history: dict[str, dict[str, list[dict]]] | None = None,
    session_id: str | None = None,
    selected_results: dict[str, int] | None = None,
) -> dict:
    """Serialize the whole graph (nodes, edges, layout, results) for the frontend.

    Builds a JSON-safe dict with:
    - synthetic "INPUT" nodes, one per unique input component, placed in a
      left-hand column, plus synthetic edges wiring them to their targets;
    - one entry per real graph node with layout coordinates, ports,
      components, scatter/item-list data, status and a truncated result
      preview;
    - real edges from ``self.graph._edges`` followed by the synthetic ones.
    """
    node_results = node_results or {}
    node_statuses = node_statuses or {}
    input_values = input_values or {}
    history = history or {}
    selected_results = selected_results or {}

    # Column index per node (longest path from a source).
    depths = self._compute_node_depths()

    synthetic_input_nodes: list[dict[str, Any]] = []
    synthetic_edges: list[dict[str, Any]] = []
    input_node_positions: dict[str, tuple] = {}

    # A component object shared by several nodes gets ONE input node,
    # fanned out with extra synthetic edges (keyed by id(component)).
    component_to_input_node: dict[int, str] = {}
    creation_order = 0
    for node_name in self.graph.nodes:
        node = self.graph.nodes[node_name]

        # ChoiceNode inputs are handled via variants, not synthetic nodes.
        if isinstance(node, ChoiceNode):
            continue

        if node._input_components:
            for idx, (port_name, comp) in enumerate(node._input_components.items()):
                comp_id = id(comp)

                if comp_id in component_to_input_node:
                    # Reuse the already-created input node: only add an edge.
                    existing_input_node = component_to_input_node[comp_id]
                    existing_input_id = existing_input_node.replace(
                        " ", "_"
                    ).replace("-", "_")
                    synthetic_edges.append(
                        {
                            "from_node": existing_input_id,
                            "from_port": "value",
                            "to_node": node_name.replace(" ", "_").replace(
                                "-", "_"
                            ),
                            "to_port": port_name,
                        }
                    )
                    continue

                input_node_name = f"{node_name}__{port_name}"
                input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
                component_to_input_node[comp_id] = input_node_name

                comp_data = self._serialize_component(comp, "value")
                label = comp_data["props"].get("label") or port_name

                # Overlay any value the user already typed into this input.
                if input_node_id in input_values:
                    comp_data["value"] = input_values[input_node_id].get(
                        "value", comp_data["value"]
                    )

                synthetic_input_nodes.append(
                    {
                        "node_name": input_node_name,
                        "display_name": label,
                        "target_node": node_name,
                        "target_port": port_name,
                        "component": comp_data,
                        "index": idx,
                        "creation_order": creation_order,
                    }
                )
                creation_order += 1

                synthetic_edges.append(
                    {
                        "from_node": input_node_id,
                        "from_port": "value",
                        "to_node": node_name.replace(" ", "_").replace("-", "_"),
                        "to_port": port_name,
                    }
                )

    max_depth = max(depths.values()) if depths else 0

    # Group node names by column for vertical stacking.
    nodes_by_depth: dict[int, list[str]] = {}
    for node_name, depth in depths.items():
        if depth not in nodes_by_depth:
            nodes_by_depth[depth] = []
        nodes_by_depth[depth].append(node_name)

    # Layout constants (pixels): column spacing, margins, per-element sizes.
    x_spacing = 350
    input_column_x = 50
    x_start = 400
    y_start = 120
    y_gap = 30
    base_node_height = 100
    component_base_height = 60
    line_height = 18

    def calc_component_height(comp_data: dict) -> int:
        # Height grows with the component's line count, capped at 6 lines.
        lines = comp_data.get("props", {}).get("lines", 1)
        lines = min(lines, 6)
        return component_base_height + max(0, lines - 1) * line_height

    def calc_node_height(components: list[dict], num_ports: int = 1) -> int:
        comp_height = sum(calc_component_height(c) for c in components)
        port_height = max(num_ports, 1) * 22
        return base_node_height + port_height + comp_height

    # NOTE(review): "target_depth" is attached here but the sort below uses
    # creation_order only — target_depth appears unused downstream.
    all_input_nodes_sorted: list[dict] = []
    for syn_node in synthetic_input_nodes:
        target_depth = depths.get(syn_node["target_node"], 0)
        all_input_nodes_sorted.append({**syn_node, "target_depth": target_depth})
    all_input_nodes_sorted.sort(key=lambda x: x["creation_order"])

    # Stack input nodes in a single left column.
    current_input_y = y_start
    for syn_node in all_input_nodes_sorted:
        input_node_positions[syn_node["node_name"]] = (
            input_column_x,
            current_input_y,
        )
        node_height = calc_node_height([syn_node["component"]], 1)
        current_input_y += node_height + y_gap

    # Stack real nodes per depth column, sized by their output components.
    node_positions: dict[str, tuple] = {}
    for depth in range(max_depth + 1):
        depth_nodes = nodes_by_depth.get(depth, [])
        current_y = y_start
        for node_name in depth_nodes:
            node = self.graph.nodes[node_name]
            output_comps, _ = self._build_output_components(node)
            num_ports = max(
                len(node._input_ports or []), len(node._output_ports or [])
            )
            node_height = calc_node_height(output_comps, num_ports)
            x = x_start + depth * x_spacing
            node_positions[node_name] = (x, current_y)
            current_y += node_height + y_gap

    nodes = []

    # Emit synthetic INPUT nodes first.
    for syn_node in synthetic_input_nodes:
        node_name = syn_node["node_name"]
        display_name = syn_node["display_name"]
        node_id = node_name.replace(" ", "_").replace("-", "_")
        x, y = input_node_positions.get(node_name, (50, 50))
        comp = syn_node["component"]

        nodes.append(
            {
                "id": node_id,
                "name": display_name,
                "type": "INPUT",
                "inputs": [],
                "outputs": ["value"],
                "x": x,
                "y": y,
                "has_input": False,
                "input_value": "",
                "input_components": [comp],
                "output_components": [],
                "is_map_node": False,
                "map_items": [],
                "map_item_count": 0,
                "item_output_type": "text",
                "status": "pending",
                "result": "",
                "is_output_node": False,
                "is_input_node": True,
            }
        )

    # Emit the real graph nodes.
    for node_name in self.graph.nodes:
        node = self.graph.nodes[node_name]
        x, y = node_positions.get(node_name, (50, 50))

        result = node_results.get(node_name)
        result_str = ""
        is_scattered = self._has_scattered_input(node_name)
        # Plain-text preview only when the node has no output components
        # and is not scattered; truncated to keep the payload small.
        if result is not None and not node._output_components and not is_scattered:
            if isinstance(result, dict):
                # Hide private/underscore-prefixed keys from the preview.
                display_result = {
                    k: v for k, v in result.items() if not k.startswith("_")
                }
                result_str = json.dumps(display_result, indent=2, default=str)[:300]
            elif isinstance(result, (list, tuple)):
                result_str = json.dumps(list(result)[:5], default=str)
            else:
                result_str = str(result)[:300]

        node_id = node_name.replace(" ", "_").replace("-", "_")

        # Visible input ports (fixed inputs are wired internally, not shown).
        input_ports_data = []
        for port in node._input_ports or []:
            if port in node._fixed_inputs:
                continue
            port_history = history.get(node_name, {}).get(port, [])
            input_ports_data.append(
                {
                    "name": port,
                    "history_count": len(port_history) if port_history else 0,
                }
            )

        output_components, validation_error = self._build_output_components(
            node, result
        )
        scattered_items = (
            self._build_scattered_items(node_name, result) if is_scattered else []
        )

        # Scattered nodes render per-item previews; audio components switch
        # the item renderer to an audio player.
        item_output_type = "text"
        if is_scattered:
            for comp in node._output_components.values():
                if comp is None:
                    continue
                comp_type = self._get_component_type(comp)
                if comp_type == "audio":
                    item_output_type = "audio"
                    break

        # Item-list ports: serialize the schema of the FIRST port only.
        item_list_schema = None
        item_list_items = []
        if node._item_list_schemas:
            first_port = list(node._item_list_schemas.keys())[0]
            item_list_schema = self._serialize_item_list_schema(
                node._item_list_schemas[first_port]
            )
            item_list_items = self._build_item_list_items(node, first_port, result)

        # Item-list ports are expanded to one sub-port per schema field
        # ("port.field"); other ports appear only if they have a component.
        output_ports = []
        for port_name in node._output_ports or []:
            if port_name in node._item_list_schemas:
                schema = node._item_list_schemas[port_name]
                for field_name in schema:
                    output_ports.append(f"{port_name}.{field_name}")
            elif port_name in node._output_components:
                output_ports.append(port_name)

        is_output = self._is_output_node(node_name)
        is_local = self._is_running_locally(node)

        variants = None
        selected_variant = None
        if isinstance(node, ChoiceNode):
            variants = [
                self._build_variant_data(v, input_values) for v in node._variants
            ]
            selected_variant = input_values.get(node_id, {}).get(
                "_selected_variant", 0
            )

        nodes.append(
            {
                "id": node_id,
                "name": node_name,
                "type": self._get_node_type(node, node_name),
                "url": self._get_node_url(node),
                "inputs": input_ports_data,
                "outputs": output_ports,
                "x": x,
                "y": y,
                "has_input": False,
                "input_value": input_values.get(node_name, ""),
                "input_components": [],
                "output_components": output_components,
                "is_map_node": is_scattered,
                "map_items": scattered_items,
                "map_item_count": len(scattered_items),
                "item_output_type": item_output_type,
                "item_list_schema": item_list_schema,
                "item_list_items": item_list_items,
                "status": node_statuses.get(node_name, "pending"),
                "result": result_str,
                "is_output_node": is_output,
                "is_input_node": False,
                "is_local": is_local,
                "variants": variants,
                "selected_variant": selected_variant,
                "validation_error": validation_error,
            }
        )

    # Real edges; item-key edges address a field sub-port ("port.key").
    edges = []
    for i, edge in enumerate(self.graph._edges):
        from_port = edge.source_port
        if edge.item_key:
            from_port = f"{edge.source_port}.{edge.item_key}"
        edges.append(
            {
                "id": f"edge_{i}",
                "from_node": edge.source_node._name.replace(" ", "_").replace(
                    "-", "_"
                ),
                "from_port": from_port,
                "to_node": edge.target_node._name.replace(" ", "_").replace(
                    "-", "_"
                ),
                "to_port": edge.target_port,
                "is_scattered": edge.is_scattered,
                "is_gathered": edge.is_gathered,
            }
        )

    # Synthetic edges from INPUT nodes, appended after real ones.
    for i, syn_edge in enumerate(synthetic_edges):
        edges.append(
            {
                "id": f"input_edge_{i}",
                "from_node": syn_edge["from_node"],
                "from_port": syn_edge["from_port"],
                "to_node": syn_edge["to_node"],
                "to_port": syn_edge["to_port"],
            }
        )

    return {
        "name": self.graph.name,
        "nodes": nodes,
        "edges": edges,
        "inputs": input_values,
        "selected_results": selected_results,
        "history": history,
        "session_id": session_id,
    }
1369
+
1370
+ def _get_ancestors(self, node_name: str) -> list[str]:
1371
+ ancestors = set()
1372
+ to_visit = [node_name]
1373
+ while to_visit:
1374
+ current = to_visit.pop()
1375
+ for source, _, target, _ in self.graph.get_connections():
1376
+ if target == current and source not in ancestors:
1377
+ ancestors.add(source)
1378
+ to_visit.append(source)
1379
+ return list(ancestors)
1380
+
1381
+ def _get_user_provided_output(
1382
+ self, node, node_id: str, input_values: dict[str, Any]
1383
+ ) -> dict[str, Any] | None:
1384
+ if not node._output_components:
1385
+ return None
1386
+
1387
+ node_inputs = input_values.get(node_id, {})
1388
+ if not node_inputs:
1389
+ return None
1390
+
1391
+ result = {}
1392
+ has_user_value = False
1393
+ for port_name, comp in node._output_components.items():
1394
+ if comp is None:
1395
+ continue
1396
+ if port_name in node_inputs:
1397
+ value = node_inputs[port_name]
1398
+ if value is not None:
1399
+ if isinstance(value, str) and value.startswith("data:"):
1400
+ value = self._save_data_url_as_gradio_file(value)
1401
+ result[port_name] = value
1402
+ has_user_value = True
1403
+
1404
+ return result if has_user_value else None
1405
+
1406
+ def _save_data_url_as_gradio_file(self, data_url: str):
1407
+ try:
1408
+ header, data = data_url.split(",", 1)
1409
+ mime_type = header.split(":")[1].split(";")[0]
1410
+ ext_map = {
1411
+ "image/png": ".png",
1412
+ "image/jpeg": ".jpg",
1413
+ "image/gif": ".gif",
1414
+ "image/webp": ".webp",
1415
+ "audio/webm": ".webm",
1416
+ "audio/wav": ".wav",
1417
+ "audio/mp3": ".mp3",
1418
+ "audio/mpeg": ".mp3",
1419
+ }
1420
+ ext = ext_map.get(mime_type, ".bin")
1421
+ file_data = base64.b64decode(data)
1422
+ temp_dir = Path(tempfile.gettempdir()) / "daggr_uploads"
1423
+ temp_dir.mkdir(exist_ok=True)
1424
+ file_path = temp_dir / f"{uuid.uuid4()}{ext}"
1425
+ file_path.write_bytes(file_data)
1426
+ return FileValue(str(file_path))
1427
+ except Exception as e:
1428
+ print(f"[ERROR] Failed to save data URL: {e}")
1429
+ return data_url
1430
+
1431
+ def _convert_urls_to_file_values(self, data: Any) -> Any:
1432
+ if isinstance(data, str):
1433
+ if data.startswith(("http://", "https://", "/")) and any(
1434
+ data.lower().endswith(ext)
1435
+ for ext in (
1436
+ ".png",
1437
+ ".jpg",
1438
+ ".jpeg",
1439
+ ".gif",
1440
+ ".webp",
1441
+ ".wav",
1442
+ ".mp3",
1443
+ ".webm",
1444
+ ".mp4",
1445
+ ".ogg",
1446
+ )
1447
+ ):
1448
+ return FileValue(data)
1449
+ return data
1450
+ elif isinstance(data, dict):
1451
+ return {k: self._convert_urls_to_file_values(v) for k, v in data.items()}
1452
+ elif isinstance(data, list):
1453
+ return [self._convert_urls_to_file_values(item) for item in data]
1454
+ return data
1455
+
1456
async def _execute_to_node(
    self,
    session: ExecutionSession,
    target_node: str,
    session_id: str | None,
    input_values: dict[str, Any],
    selected_results: dict[str, int],
) -> dict:
    """Run *target_node* and all its ancestors (non-streaming), reusing caches.

    Creates a persistence session when none is given, seeds ChoiceNode
    variant selections, gathers user-entered values per node, pulls cached
    results from state, then executes the remaining nodes in topological
    order. Returns the serialized graph payload for the frontend.
    """
    if not session_id:
        session_id = self.state.create_session(self.graph.persist_key)

    # Record which variant each ChoiceNode should use for this run.
    for node_name, node in self.graph.nodes.items():
        if isinstance(node, ChoiceNode):
            node_id = node_name.replace(" ", "_").replace("-", "_")
            variant_idx = input_values.get(node_id, {}).get("_selected_variant", 0)
            session.selected_variants[node_name] = variant_idx

    # Execute ancestors + target, filtered through topological order.
    ancestors = self._get_ancestors(target_node)
    nodes_to_run = ancestors + [target_node]
    execution_order = self.graph.get_execution_order()
    nodes_to_execute = [n for n in execution_order if n in nodes_to_run]

    # Collect user-entered values per node: synthetic input-node values for
    # nodes with input components; raw text for InteractionNodes.
    entry_inputs: dict[str, dict[str, Any]] = {}
    for node_name in nodes_to_execute:
        node = self.graph.nodes[node_name]
        if node._input_components:
            node_inputs = {}
            for port_name in node._input_components:
                input_node_name = f"{node_name}__{port_name}"
                input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
                if input_node_id in input_values:
                    value = input_values[input_node_id].get("value")
                    if value is not None:
                        node_inputs[port_name] = value
            if node_inputs:
                entry_inputs[node_name] = node_inputs
        elif isinstance(node, InteractionNode):
            value = input_values.get(node_name, "")
            port = node._input_ports[0] if node._input_ports else "input"
            entry_inputs[node_name] = {port: value}

    # Load cached results from persistent state; an explicit selection in
    # selected_results picks a specific historical result.
    existing_results = {}
    if session_id:
        for node_name in nodes_to_execute:
            if node_name in selected_results:
                cached = self.state.get_result_by_index(
                    session_id, node_name, selected_results[node_name]
                )
            else:
                cached = self.state.get_latest_result(session_id, node_name)
            if cached is not None:
                existing_results[node_name] = self._convert_urls_to_file_values(
                    cached
                )

    # Seed the in-memory session cache without clobbering fresher values.
    for k, v in existing_results.items():
        if k not in session.results:
            session.results[k] = v

    # The target itself is always re-executed.
    if target_node in session.results:
        del session.results[target_node]

    node_results = {}
    node_statuses = {}

    for node_name in nodes_to_execute:
        # Cached (persistent) result: reuse without executing.
        if node_name in existing_results:
            node_results[node_name] = existing_results[node_name]
            node_statuses[node_name] = "completed"
            continue

        # Already computed earlier in this session: reuse as well.
        if node_name in session.results:
            node_results[node_name] = session.results[node_name]
            node_statuses[node_name] = "completed"
            continue

        node_statuses[node_name] = "running"
        user_input = entry_inputs.get(node_name, {})
        result = await self.executor.execute_node(session, node_name, user_input)
        node_results[node_name] = result
        node_statuses[node_name] = "completed"
        self.state.save_result(session_id, node_name, result)

    return self._build_graph_data(
        node_results, node_statuses, input_values, {}, session_id, selected_results
    )
1542
+
1543
async def _execute_to_node_streaming(
    self,
    session: ExecutionSession,
    target_node: str,
    sheet_id: str | None,
    input_values: dict[str, Any],
    item_list_values: dict[str, Any],
    selected_results: dict[str, int],
    run_id: str,
    user_id: str | None = None,
    run_ancestors: bool = True,
):
    """Async generator: run *target_node* (optionally with ancestors), yielding
    one event dict per state change.

    Yields ``node_started`` before each node executes, a full graph payload
    tagged ``node_complete`` after it finishes, and a single ``error``
    payload if anything raises. Persistence to ``self.state`` happens only
    when a user, sheet and persist key are all present.
    """
    can_persist = (
        user_id is not None
        and sheet_id is not None
        and self.graph.persist_key is not None
    )

    # Record the chosen variant for every ChoiceNode.
    for node_name, node in self.graph.nodes.items():
        if isinstance(node, ChoiceNode):
            node_id = node_name.replace(" ", "_").replace("-", "_")
            variant_idx = input_values.get(node_id, {}).get("_selected_variant", 0)
            session.selected_variants[node_name] = variant_idx

    if run_ancestors:
        ancestors = self._get_ancestors(target_node)
        nodes_to_run = ancestors + [target_node]
    else:
        nodes_to_run = [target_node]
    execution_order = self.graph.get_execution_order()
    nodes_to_execute = [n for n in execution_order if n in nodes_to_run]

    # Collect user-entered values (same shape as _execute_to_node).
    entry_inputs: dict[str, dict[str, Any]] = {}
    for node_name in nodes_to_execute:
        node = self.graph.nodes[node_name]
        if node._input_components:
            node_inputs = {}
            for port_name in node._input_components:
                input_node_name = f"{node_name}__{port_name}"
                input_node_id = input_node_name.replace(" ", "_").replace("-", "_")
                if input_node_id in input_values:
                    value = input_values[input_node_id].get("value")
                    if value is not None:
                        node_inputs[port_name] = value
            if node_inputs:
                entry_inputs[node_name] = node_inputs
        elif isinstance(node, InteractionNode):
            value = input_values.get(node_name, "")
            port = node._input_ports[0] if node._input_ports else "input"
            entry_inputs[node_name] = {port: value}

    # Resolve results that need no execution: values typed directly on a
    # node's output components take priority; otherwise fall back to the
    # persisted cache. The target node is never taken from cache.
    existing_results = {}
    for node_name in nodes_to_execute:
        node = self.graph.nodes[node_name]
        node_id = node_name.replace(" ", "_").replace("-", "_")
        user_output = self._get_user_provided_output(node, node_id, input_values)
        if user_output is not None:
            existing_results[node_name] = user_output
            if can_persist:
                snapshot = {
                    "inputs": input_values,
                    "selected_results": selected_results,
                }
                self.state.save_result(sheet_id, node_name, user_output, snapshot)
            continue

        if node_name == target_node:
            continue

        if can_persist:
            if node_name in selected_results:
                cached = self.state.get_result_by_index(
                    sheet_id, node_name, selected_results[node_name]
                )
            else:
                cached = self.state.get_latest_result(sheet_id, node_name)
            if cached is not None:
                existing_results[node_name] = self._convert_urls_to_file_values(
                    cached
                )

    # Seed session cache; never overwrite fresher in-session values.
    for k, v in existing_results.items():
        if k not in session.results:
            session.results[k] = v

    # Force re-execution of the target itself.
    if target_node in session.results:
        del session.results[target_node]

    node_results = {}
    node_statuses = {}

    try:
        for node_name in nodes_to_execute:
            if node_name in existing_results:
                result = existing_results[node_name]
                result = self._apply_item_list_edits(
                    node_name, result, item_list_values
                )
                node_results[node_name] = result
                session.results[node_name] = result
                node_statuses[node_name] = "completed"
                continue

            if node_name in session.results:
                result = session.results[node_name]
                result = self._apply_item_list_edits(
                    node_name, result, item_list_values
                )
                node_results[node_name] = result
                node_statuses[node_name] = "completed"
                continue

            # Cooperative locking: if another coroutine already runs this
            # node, wait for it instead of executing twice.
            can_execute = await session.start_node_execution(node_name)
            if not can_execute:
                if node_name == target_node:
                    # Someone else is producing the target; nothing to do.
                    return
                await session.wait_for_node(node_name)
                if node_name in session.results:
                    result = session.results[node_name]
                    result = self._apply_item_list_edits(
                        node_name, result, item_list_values
                    )
                    node_results[node_name] = result
                    node_statuses[node_name] = "completed"
                continue

            try:
                node_statuses[node_name] = "running"
                user_input = entry_inputs.get(node_name, {})

                yield {
                    "type": "node_started",
                    "started_node": node_name,
                    "run_id": run_id,
                }

                start_time = time.time()
                result = await self.executor.execute_node(
                    session, node_name, user_input
                )
                elapsed_ms = (time.time() - start_time) * 1000

                result = self._apply_item_list_edits(
                    node_name, result, item_list_values
                )
                session.results[node_name] = result
                node_results[node_name] = result
                node_statuses[node_name] = "completed"

                if can_persist:
                    # Index of the result we are about to append; recorded in
                    # selected_results so the UI shows the fresh value.
                    current_count = self.state.get_result_count(sheet_id, node_name)
                    snapshot = {
                        "inputs": input_values,
                        "selected_results": selected_results,
                    }
                    self.state.save_result(sheet_id, node_name, result, snapshot)
                    selected_results[node_name] = current_count

                graph_data = self._build_graph_data(
                    node_results,
                    node_statuses,
                    input_values,
                    {},
                    sheet_id,
                    selected_results,
                )
                graph_data["type"] = "node_complete"
                graph_data["completed_node"] = node_name
                graph_data["run_id"] = run_id
                graph_data["execution_time_ms"] = elapsed_ms
            finally:
                # Always release the node lock, even on failure.
                await session.finish_node_execution(node_name)
            yield graph_data

    except Exception as e:
        # The failing node is the next one that has no recorded result.
        error_node = None
        if nodes_to_execute:
            current_idx = len(node_results)
            if current_idx < len(nodes_to_execute):
                error_node = nodes_to_execute[current_idx]
                node_statuses[error_node] = "error"
                node_results[error_node] = {"error": str(e)}

        graph_data = self._build_graph_data(
            node_results,
            node_statuses,
            input_values,
            {},
            sheet_id,
            selected_results,
        )
        graph_data["type"] = "error"
        graph_data["run_id"] = run_id
        graph_data["error"] = str(e)
        graph_data["nodes_to_clear"] = nodes_to_execute
        if error_node:
            graph_data["node"] = error_node
            graph_data["completed_node"] = error_node
        yield graph_data
1742
+
1743
async def _execute_workflow_api(
    self, request: Request, subgraph_id: str | None = None
) -> JSONResponse:
    """HTTP API: run a whole (sub)graph from posted inputs and return outputs.

    Request body: ``{"inputs": {...}}`` keyed by sanitized input-node ids.
    When the graph has multiple disconnected subgraphs, *subgraph_id*
    ("subgraph_<i>" or "main") must identify which one to run; otherwise a
    400/404 is returned. Responds with ``{"outputs": {node: result}}`` for
    the graph's designated output nodes only.
    """
    try:
        body = await request.json()
    except Exception:
        # Missing/invalid JSON body: treat as no inputs.
        body = {}

    input_values = body.get("inputs", {})
    # Fresh session per API call: no cross-request result sharing.
    session = ExecutionSession(self.graph)

    subgraphs = self.graph.get_subgraphs()
    output_node_names = set(self.graph.get_output_nodes())

    # Resolve which set of nodes to run.
    if subgraph_id is None:
        if len(subgraphs) > 1:
            return JSONResponse(
                {
                    "error": "Multiple subgraphs detected. Please specify a subgraph_id.",
                    "available_subgraphs": [
                        f"subgraph_{i}" for i in range(len(subgraphs))
                    ],
                },
                status_code=400,
            )
        target_nodes = subgraphs[0] if subgraphs else set(self.graph.nodes.keys())
    else:
        if subgraph_id == "main" and len(subgraphs) == 1:
            target_nodes = subgraphs[0]
        elif subgraph_id.startswith("subgraph_"):
            try:
                idx = int(subgraph_id.split("_")[1])
                if idx < 0 or idx >= len(subgraphs):
                    return JSONResponse(
                        {"error": f"Subgraph '{subgraph_id}' not found"},
                        status_code=404,
                    )
                target_nodes = subgraphs[idx]
            except (ValueError, IndexError):
                return JSONResponse(
                    {"error": f"Invalid subgraph_id '{subgraph_id}'"},
                    status_code=400,
                )
        else:
            return JSONResponse(
                {"error": f"Subgraph '{subgraph_id}' not found"},
                status_code=404,
            )

    # ChoiceNode variant selection comes in as "<node_id>___selected_variant".
    for node_name, node in self.graph.nodes.items():
        if isinstance(node, ChoiceNode):
            node_id = node_name.replace(" ", "_").replace("-", "_")
            variant_idx = input_values.get(f"{node_id}___selected_variant", 0)
            session.selected_variants[node_name] = variant_idx

    execution_order = self.graph.get_execution_order()
    nodes_to_execute = [n for n in execution_order if n in target_nodes]

    # Map posted input values onto each node's input-component ports.
    entry_inputs: dict[str, dict[str, Any]] = {}
    for node_name in nodes_to_execute:
        node = self.graph.nodes[node_name]
        if node._input_components:
            node_inputs = {}
            for port_name in node._input_components:
                input_node_id = f"{node_name}__{port_name}".replace(
                    " ", "_"
                ).replace("-", "_")
                if input_node_id in input_values:
                    node_inputs[port_name] = input_values[input_node_id]
            if node_inputs:
                entry_inputs[node_name] = node_inputs

    session.results = {}
    node_results = {}

    try:
        for node_name in nodes_to_execute:
            user_input = entry_inputs.get(node_name, {})
            result = await self.executor.execute_node(
                session, node_name, user_input
            )
            node_results[node_name] = result
    except Exception as e:
        # NOTE(review): relies on the for-loop variable `node_name` leaking
        # out of the loop to name the failing node — works in Python but is
        # fragile if the loop is ever restructured.
        return JSONResponse(
            {"error": f"Execution error in node '{node_name}': {str(e)}"},
            status_code=500,
        )

    # Only designated output nodes are returned; file paths are rewritten
    # into client-servable form.
    outputs = {}
    for node_name in nodes_to_execute:
        if node_name in output_node_names and node_name in node_results:
            result = node_results[node_name]
            result = self._transform_file_paths(result)
            outputs[node_name] = result

    return JSONResponse({"outputs": outputs})
1839
+
1840
def run(
    self,
    host: str | None = None,
    port: int | None = None,
    share: bool | None = None,
    open_browser: bool = True,
    **kwargs,
):
    """Serve the daggr UI (and API) with uvicorn.

    Host/port default to the GRADIO_SERVER_NAME / GRADIO_SERVER_PORT env
    vars, falling back to 127.0.0.1:7860; if the port is busy the next free
    one is used. In notebooks (Colab/Kaggle/IPython) or when *share* is
    requested, the server runs on a background thread — optionally behind a
    Gradio share tunnel — and this call blocks until Ctrl-C. Otherwise
    uvicorn runs in the foreground; extra *kwargs* are passed to
    ``uvicorn.run``.
    """
    from gradio.utils import colab_check, ipython_check

    if host is None:
        host = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")
    if port is None:
        port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    actual_port = _find_available_port(host, port)
    if actual_port != port:
        print(f"\n Port {port} is in use, using {actual_port} instead.")

    # Fail fast on wiring mistakes before binding a socket.
    self.graph._validate_edges()

    is_colab = colab_check()
    is_kaggle = os.environ.get("KAGGLE_KERNEL_RUN_TYPE") is not None
    is_notebook = is_colab or is_kaggle or ipython_check()

    if share is None:
        # Hosted notebooks can't expose localhost, so default share on.
        share = is_colab or is_kaggle

    if is_notebook or share:
        # Background-thread mode: keep the main thread free (notebook cell)
        # or available to manage the tunnel.
        config = uvicorn.Config(
            app=self.app,
            host=host,
            port=actual_port,
            log_level="warning",
        )
        server = _Server(config)
        server.run_in_thread()

        local_url = f"http://{host}:{actual_port}"
        print(f"\n UI running at: {local_url}")
        if self.api_server:
            print(f" API server at: {local_url}/api")

        share_url = None
        if share:
            from gradio.networking import setup_tunnel

            share_token = secrets.token_urlsafe(32)
            share_url = setup_tunnel(
                local_host=host,
                local_port=actual_port,
                share_token=share_token,
                share_server_address=None,
                share_server_tls_certificate=None,
            )
            print(f" Public URL: {share_url}")
            print(
                "\n This share link expires in 1 week. For permanent hosting, deploy to Hugging Face Spaces.\n"
            )

        if is_colab or is_kaggle:
            # Render a clickable link inline in the notebook output.
            from IPython.display import HTML, display

            url = share_url or local_url
            display(
                HTML(f'<a href="{url}" target="_blank">Open daggr app: {url}</a>')
            )
        elif open_browser:
            webbrowser.open_new_tab(share_url or local_url)

        # Keep the main thread alive until interrupted.
        try:
            while True:
                time.sleep(1)
        except KeyboardInterrupt:
            print("\nShutting down...")
            server.close()
    else:
        # Foreground mode: uvicorn owns the main thread.
        local_url = f"http://{host}:{actual_port}"
        print(f"\n UI running at: {local_url}")
        if self.api_server:
            print(f" API server at: {local_url}/api")
        print()
        if open_browser:
            # Delay so the server is (likely) listening before the tab opens.
            threading.Timer(0.5, lambda: webbrowser.open_new_tab(local_url)).start()
        uvicorn.run(
            self.app, host=host, port=actual_port, log_level="warning", **kwargs
        )
1927
+
1928
+
1929
class _Server(uvicorn.Server):
    """Uvicorn server that runs on a background daemon thread.

    Signal handling is disabled so the hosting process — not uvicorn —
    decides when to shut down.
    """

    def install_signal_handlers(self):
        # Intentionally a no-op: the caller owns Ctrl-C / termination.
        pass

    def run_in_thread(self):
        """Start serving on a daemon thread and block until startup (max ~5 s)."""
        self.thread = threading.Thread(target=self.run, daemon=True)
        self.thread.start()
        deadline = time.time() + 5
        while not self.started:
            if time.time() > deadline:
                raise RuntimeError(
                    "Server failed to start. Please check that the port is available."
                )
            time.sleep(1e-3)

    def close(self):
        """Ask the server to exit and wait briefly for the thread to finish."""
        self.should_exit = True
        self.thread.join(timeout=5)
daggr/session.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Session management for daggr, including per-session execution contexts for security isolation and concurrency management."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import asyncio
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ if TYPE_CHECKING:
9
+ from daggr.graph import Graph
10
+
11
+
12
class ConcurrencyManager:
    """Tracks per-session concurrency limits for FnNode execution.

    The default policy is strictly serial: all FnNodes share one single-slot
    semaphore. A node that opts in with ``concurrent=True`` runs without any
    limit, unless it names a ``concurrency_group``, in which case it shares a
    bounded semaphore with the other members of that group.
    """

    def __init__(self):
        self._default_semaphore = asyncio.Semaphore(1)
        self._group_semaphores: dict[str, asyncio.Semaphore] = {}
        self._lock = asyncio.Lock()

    async def get_semaphore(
        self,
        concurrent: bool,
        concurrency_group: str | None,
        max_concurrent: int,
    ) -> asyncio.Semaphore | None:
        """Return the semaphore governing a FnNode, or ``None`` for unlimited.

        Serial nodes share the default single-slot semaphore. Grouped
        concurrent nodes share a per-group semaphore, created lazily on first
        use with ``max_concurrent`` slots. Fully concurrent, ungrouped nodes
        get ``None`` (no limit).
        """
        if not concurrent:
            return self._default_semaphore

        if not concurrency_group:
            # concurrent=True with no group: run without limits.
            return None

        # Lazily create each group's semaphore; the lock makes creation
        # race-free when several nodes start at once. Note that the group's
        # limit is fixed by the first caller's max_concurrent.
        async with self._lock:
            sem = self._group_semaphores.get(concurrency_group)
            if sem is None:
                sem = asyncio.Semaphore(max_concurrent)
                self._group_semaphores[concurrency_group] = sem
            return sem
48
+
49
+
50
class ExecutionSession:
    """Per-connection execution context.

    Each WebSocket connection gets its own instance, which isolates:
      - the HF token used for downstream calls
      - the results cache (including scattered/variant results)
      - the Gradio client cache
      - concurrency limits (via :class:`ConcurrencyManager`)
      - in-flight node coordination, so concurrent requests for the same
        node wait on a single execution instead of duplicating work
    """

    def __init__(self, graph: Graph, hf_token: str | None = None):
        self.graph = graph
        self.hf_token = hf_token
        self.results: dict[str, Any] = {}
        self.scattered_results: dict[str, list[Any]] = {}
        self.selected_variants: dict[str, int] = {}
        self.clients: dict[str, Any] = {}
        self.concurrency = ConcurrencyManager()

        self._executing_nodes: dict[str, asyncio.Event] = {}
        self._execution_lock = asyncio.Lock()

    def set_hf_token(self, token: str | None):
        """Swap in a new HF token, invalidating clients built with the old one."""
        if token == self.hf_token:
            return
        self.hf_token = token
        self.clients = {}

    def clear_results(self):
        """Drop all cached results ahead of a fresh execution."""
        self.results = {}
        self.scattered_results = {}

    async def wait_for_node(self, node_name: str) -> bool:
        """Block until any in-flight execution of ``node_name`` completes.

        Returns True if the node was executing (we waited), False otherwise.
        """
        async with self._execution_lock:
            pending = self._executing_nodes.get(node_name)

        if pending is None:
            return False
        await pending.wait()
        return True

    async def start_node_execution(self, node_name: str) -> bool:
        """Attempt to claim ``node_name`` for execution.

        Returns True if this caller now owns the execution, False if another
        caller is already running it.
        """
        async with self._execution_lock:
            if node_name in self._executing_nodes:
                return False
            self._executing_nodes[node_name] = asyncio.Event()
        return True

    async def finish_node_execution(self, node_name: str):
        """Release ``node_name`` and wake any coroutines waiting on it."""
        async with self._execution_lock:
            done = self._executing_nodes.pop(node_name, None)
        if done is not None:
            done.set()
daggr/state.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import sqlite3
6
+ import uuid
7
+ from datetime import datetime
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+ from huggingface_hub import constants
12
+
13
+
14
def get_daggr_cache_dir() -> Path:
    """Return (and create if needed) daggr's cache directory under HF_HOME."""
    path = Path(constants.HF_HOME) / "daggr"
    path.mkdir(parents=True, exist_ok=True)
    return path
19
+
20
+
21
def get_daggr_files_dir() -> Path:
    """Return (and create if needed) the subdirectory for daggr file payloads."""
    path = get_daggr_cache_dir() / "files"
    path.mkdir(parents=True, exist_ok=True)
    return path
25
+
26
+
27
class SessionState:
    """SQLite-backed persistence for daggr sheets, node inputs, and results.

    A "sheet" is one saved workspace of a graph for a given user. Every
    public method opens its own short-lived connection and always closes it
    (even when a query raises), so an instance can be shared across requests.
    """

    def __init__(self, db_path: str | None = None):
        """Open (or create) the state database.

        Args:
            db_path: Path to the SQLite file; defaults to
                ``<daggr cache dir>/sessions.db``.
        """
        if db_path is None:
            db_path = str(get_daggr_cache_dir() / "sessions.db")
        self.db_path = db_path
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Return a fresh connection; callers must close it in a finally."""
        return sqlite3.connect(self.db_path)

    def _init_db(self):
        """Create tables/indexes and apply in-place schema migrations."""
        conn = self._connect()
        try:
            cursor = conn.cursor()

            self._migrate_legacy_schema(cursor)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sheets (
                    sheet_id TEXT PRIMARY KEY,
                    user_id TEXT NOT NULL,
                    graph_name TEXT NOT NULL,
                    name TEXT,
                    transform TEXT,
                    created_at TEXT,
                    updated_at TEXT
                )
            """)

            # Databases created before the transform column existed get it
            # added lazily here.
            cursor.execute("PRAGMA table_info(sheets)")
            columns = [col[1] for col in cursor.fetchall()]
            if "transform" not in columns:
                cursor.execute("ALTER TABLE sheets ADD COLUMN transform TEXT")

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_sheets_user_graph
                ON sheets(user_id, graph_name)
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS node_inputs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sheet_id TEXT,
                    node_name TEXT,
                    port_name TEXT,
                    value TEXT,
                    updated_at TEXT,
                    FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE,
                    UNIQUE(sheet_id, node_name, port_name)
                )
            """)

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_node_inputs_sheet
                ON node_inputs(sheet_id)
            """)

            cursor.execute("""
                CREATE TABLE IF NOT EXISTS node_results (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    sheet_id TEXT,
                    node_name TEXT,
                    result TEXT,
                    inputs_snapshot TEXT,
                    created_at TEXT,
                    FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE
                )
            """)

            # Same lazy migration for the inputs_snapshot column.
            cursor.execute("PRAGMA table_info(node_results)")
            result_columns = [col[1] for col in cursor.fetchall()]
            if "inputs_snapshot" not in result_columns:
                cursor.execute(
                    "ALTER TABLE node_results ADD COLUMN inputs_snapshot TEXT"
                )

            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_node_results_sheet_node
                ON node_results(sheet_id, node_name)
            """)

            conn.commit()
        finally:
            conn.close()

    def _migrate_legacy_schema(self, cursor):
        """Migrate a pre-sheets schema (session_id-keyed tables) in place.

        Detects the legacy layout by the presence of a ``session_id`` column
        on ``node_inputs``, then renames the old tables, recreates the new
        schema, copies the data across (legacy sessions become sheets owned
        by the "local" user), and drops the old tables.
        """
        cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='node_inputs'"
        )
        if cursor.fetchone():
            cursor.execute("PRAGMA table_info(node_inputs)")
            columns = [col[1] for col in cursor.fetchall()]
            if "session_id" in columns and "sheet_id" not in columns:
                cursor.execute("ALTER TABLE node_inputs RENAME TO _node_inputs_old")
                cursor.execute("ALTER TABLE node_results RENAME TO _node_results_old")
                cursor.execute("ALTER TABLE sessions RENAME TO _sessions_old")

                cursor.execute("""
                    CREATE TABLE sheets (
                        sheet_id TEXT PRIMARY KEY,
                        user_id TEXT NOT NULL,
                        graph_name TEXT NOT NULL,
                        name TEXT,
                        created_at TEXT,
                        updated_at TEXT
                    )
                """)
                cursor.execute("""
                    CREATE TABLE node_inputs (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        sheet_id TEXT,
                        node_name TEXT,
                        port_name TEXT,
                        value TEXT,
                        updated_at TEXT,
                        FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE,
                        UNIQUE(sheet_id, node_name, port_name)
                    )
                """)
                cursor.execute("""
                    CREATE TABLE node_results (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        sheet_id TEXT,
                        node_name TEXT,
                        result TEXT,
                        created_at TEXT,
                        FOREIGN KEY (sheet_id) REFERENCES sheets(sheet_id) ON DELETE CASCADE
                    )
                """)

                cursor.execute("""
                    INSERT INTO sheets (sheet_id, user_id, graph_name, name, created_at, updated_at)
                    SELECT session_id, 'local', graph_name, 'Migrated Sheet', created_at, updated_at
                    FROM _sessions_old
                """)
                cursor.execute("""
                    INSERT INTO node_inputs (sheet_id, node_name, port_name, value, updated_at)
                    SELECT session_id, node_name, port_name, value, updated_at
                    FROM _node_inputs_old
                """)
                cursor.execute("""
                    INSERT INTO node_results (sheet_id, node_name, result, created_at)
                    SELECT session_id, node_name, result, created_at
                    FROM _node_results_old
                """)

                cursor.execute("DROP TABLE _sessions_old")
                cursor.execute("DROP TABLE _node_inputs_old")
                cursor.execute("DROP TABLE _node_results_old")

    def get_effective_user_id(self, hf_user: dict | None = None) -> str | None:
        """Resolve the storage user id.

        Logged-in HF users get their username. On Spaces (SPACE_ID set),
        anonymous users get None (no persistence); locally everyone shares
        the "local" user.
        """
        if hf_user and hf_user.get("username"):
            return hf_user["username"]
        if os.environ.get("SPACE_ID") is not None:
            return None
        return "local"

    def create_sheet(
        self, user_id: str, graph_name: str, name: str | None = None
    ) -> str:
        """Create a new sheet and return its id; auto-named "Sheet N"."""
        sheet_id = str(uuid.uuid4())
        now = datetime.now().isoformat()

        if not name:
            count = self.get_sheet_count(user_id, graph_name)
            name = f"Sheet {count + 1}"

        conn = self._connect()
        try:
            conn.execute(
                """INSERT INTO sheets (sheet_id, user_id, graph_name, name, created_at, updated_at)
                   VALUES (?, ?, ?, ?, ?, ?)""",
                (sheet_id, user_id, graph_name, name, now, now),
            )
            conn.commit()
        finally:
            conn.close()
        return sheet_id

    def get_sheet_count(self, user_id: str, graph_name: str) -> int:
        """Return how many sheets the user has for this graph."""
        conn = self._connect()
        try:
            row = conn.execute(
                "SELECT COUNT(*) FROM sheets WHERE user_id = ? AND graph_name = ?",
                (user_id, graph_name),
            ).fetchone()
        finally:
            conn.close()
        return row[0]

    def list_sheets(self, user_id: str, graph_name: str) -> list[dict[str, Any]]:
        """Return the user's sheets for a graph, most recently updated first."""
        conn = self._connect()
        try:
            rows = conn.execute(
                """SELECT sheet_id, name, created_at, updated_at
                   FROM sheets
                   WHERE user_id = ? AND graph_name = ?
                   ORDER BY updated_at DESC""",
                (user_id, graph_name),
            ).fetchall()
        finally:
            conn.close()
        return [
            {
                "sheet_id": row[0],
                "name": row[1],
                "created_at": row[2],
                "updated_at": row[3],
            }
            for row in rows
        ]

    def get_sheet(self, sheet_id: str) -> dict[str, Any] | None:
        """Return a sheet's metadata (transform JSON-decoded), or None."""
        conn = self._connect()
        try:
            row = conn.execute(
                """SELECT sheet_id, user_id, graph_name, name, transform, created_at, updated_at
                   FROM sheets WHERE sheet_id = ?""",
                (sheet_id,),
            ).fetchone()
        finally:
            conn.close()
        if row is None:
            return None
        transform = None
        if row[4]:
            # A corrupt transform blob degrades to None rather than failing.
            try:
                transform = json.loads(row[4])
            except (json.JSONDecodeError, TypeError):
                pass
        return {
            "sheet_id": row[0],
            "user_id": row[1],
            "graph_name": row[2],
            "name": row[3],
            "transform": transform,
            "created_at": row[5],
            "updated_at": row[6],
        }

    def save_transform(self, sheet_id: str, x: float, y: float, scale: float) -> bool:
        """Persist the canvas pan/zoom state; returns True if the sheet exists."""
        now = datetime.now().isoformat()
        transform = json.dumps({"x": x, "y": y, "scale": scale})
        conn = self._connect()
        try:
            cursor = conn.execute(
                "UPDATE sheets SET transform = ?, updated_at = ? WHERE sheet_id = ?",
                (transform, now, sheet_id),
            )
            updated = cursor.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        return updated

    def rename_sheet(self, sheet_id: str, new_name: str) -> bool:
        """Rename a sheet; returns True if the sheet exists."""
        now = datetime.now().isoformat()
        conn = self._connect()
        try:
            cursor = conn.execute(
                "UPDATE sheets SET name = ?, updated_at = ? WHERE sheet_id = ?",
                (new_name, now, sheet_id),
            )
            updated = cursor.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        return updated

    def delete_sheet(self, sheet_id: str) -> bool:
        """Delete a sheet and its data; returns True if the sheet existed.

        Child rows are removed explicitly because foreign keys (and thus
        ON DELETE CASCADE) are not enabled on these connections.
        """
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
            cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
            cursor.execute("DELETE FROM sheets WHERE sheet_id = ?", (sheet_id,))
            deleted = cursor.rowcount > 0
            conn.commit()
        finally:
            conn.close()
        return deleted

    def get_or_create_sheet(
        self, user_id: str, graph_name: str, sheet_id: str | None = None
    ) -> str:
        """Return a usable sheet id: the requested one (if owned by the
        user), else the most recently updated existing sheet, else a new one.
        """
        if sheet_id:
            sheet = self.get_sheet(sheet_id)
            if sheet and sheet["user_id"] == user_id:
                return sheet_id

        sheets = self.list_sheets(user_id, graph_name)
        if sheets:
            return sheets[0]["sheet_id"]

        return self.create_sheet(user_id, graph_name)

    def save_input(self, sheet_id: str, node_name: str, port_name: str, value: Any):
        """Upsert one node-input value (JSON-encoded; non-JSON types
        stringified via default=str) and touch the sheet's updated_at."""
        now = datetime.now().isoformat()
        value_json = json.dumps(value, default=str)
        conn = self._connect()
        try:
            conn.execute(
                """INSERT INTO node_inputs (sheet_id, node_name, port_name, value, updated_at)
                   VALUES (?, ?, ?, ?, ?)
                   ON CONFLICT(sheet_id, node_name, port_name)
                   DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at""",
                (sheet_id, node_name, port_name, value_json, now),
            )
            conn.execute(
                "UPDATE sheets SET updated_at = ? WHERE sheet_id = ?",
                (now, sheet_id),
            )
            conn.commit()
        finally:
            conn.close()

    def get_inputs(self, sheet_id: str) -> dict[str, dict[str, Any]]:
        """Return saved inputs as {node_name: {port_name: value}}."""
        conn = self._connect()
        try:
            rows = conn.execute(
                "SELECT node_name, port_name, value FROM node_inputs WHERE sheet_id = ?",
                (sheet_id,),
            ).fetchall()
        finally:
            conn.close()
        inputs: dict[str, dict[str, Any]] = {}
        for node_name, port_name, value_json in rows:
            inputs.setdefault(node_name, {})[port_name] = json.loads(value_json)
        return inputs

    def save_result(
        self,
        sheet_id: str,
        node_name: str,
        result: Any,
        inputs_snapshot: dict[str, Any] | None = None,
    ):
        """Append a node result (with optional snapshot of the inputs that
        produced it) and touch the sheet's updated_at."""
        now = datetime.now().isoformat()
        result_json = json.dumps(result, default=str)
        inputs_json = (
            json.dumps(inputs_snapshot, default=str) if inputs_snapshot else None
        )
        conn = self._connect()
        try:
            conn.execute(
                "INSERT INTO node_results (sheet_id, node_name, result, inputs_snapshot, created_at) VALUES (?, ?, ?, ?, ?)",
                (sheet_id, node_name, result_json, inputs_json, now),
            )
            conn.execute(
                "UPDATE sheets SET updated_at = ? WHERE sheet_id = ?",
                (now, sheet_id),
            )
            conn.commit()
        finally:
            conn.close()

    def get_latest_result(self, sheet_id: str, node_name: str) -> Any | None:
        """Return the most recent result for a node, or None."""
        conn = self._connect()
        try:
            row = conn.execute(
                """SELECT result FROM node_results
                   WHERE sheet_id = ? AND node_name = ?
                   ORDER BY created_at DESC LIMIT 1""",
                (sheet_id, node_name),
            ).fetchone()
        finally:
            conn.close()
        return json.loads(row[0]) if row else None

    def get_result_count(self, sheet_id: str, node_name: str) -> int:
        """Return how many results are stored for a node."""
        conn = self._connect()
        try:
            row = conn.execute(
                "SELECT COUNT(*) FROM node_results WHERE sheet_id = ? AND node_name = ?",
                (sheet_id, node_name),
            ).fetchone()
        finally:
            conn.close()
        return row[0]

    def get_result_by_index(
        self, sheet_id: str, node_name: str, index: int
    ) -> Any | None:
        """Return the index-th result (oldest first).

        An out-of-range index falls back to the latest result; None only
        when the node has no results at all.
        """
        conn = self._connect()
        try:
            rows = conn.execute(
                """SELECT result FROM node_results
                   WHERE sheet_id = ? AND node_name = ?
                   ORDER BY created_at ASC""",
                (sheet_id, node_name),
            ).fetchall()
        finally:
            conn.close()
        if not rows:
            return None
        if 0 <= index < len(rows):
            return json.loads(rows[index][0])
        return json.loads(rows[-1][0])

    def get_all_results(self, sheet_id: str) -> dict[str, list[Any]]:
        """Return all results as {node_name: [{"result", "inputs_snapshot"}]},
        oldest first within each node."""
        conn = self._connect()
        try:
            rows = conn.execute(
                """SELECT node_name, result, inputs_snapshot FROM node_results
                   WHERE sheet_id = ?
                   ORDER BY created_at ASC""",
                (sheet_id,),
            ).fetchall()
        finally:
            conn.close()
        all_results: dict[str, list[Any]] = {}
        for node_name, result_json, inputs_json in rows:
            all_results.setdefault(node_name, []).append(
                {
                    "result": json.loads(result_json),
                    "inputs_snapshot": json.loads(inputs_json) if inputs_json else None,
                }
            )
        return all_results

    def get_sheet_state(self, sheet_id: str) -> dict[str, Any]:
        """Return the sheet's full persisted state (inputs + results)."""
        return {
            "inputs": self.get_inputs(sheet_id),
            "results": self.get_all_results(sheet_id),
        }

    def clear_sheet_data(self, sheet_id: str):
        """Delete all inputs and results for a sheet (keeps the sheet row)."""
        conn = self._connect()
        try:
            cursor = conn.cursor()
            cursor.execute("DELETE FROM node_inputs WHERE sheet_id = ?", (sheet_id,))
            cursor.execute("DELETE FROM node_results WHERE sheet_id = ?", (sheet_id,))
            conn.commit()
        finally:
            conn.close()

    def create_session(self, graph_name: str) -> str:
        """Legacy alias: create a sheet owned by the "local" user."""
        return self.create_sheet("local", graph_name)

    def get_or_create_session(self, session_id: str | None, graph_name: str) -> str:
        """Legacy alias: resolve a sheet for the "local" user."""
        return self.get_or_create_sheet("local", graph_name, session_id)
pyproject.toml CHANGED
@@ -6,13 +6,13 @@ readme = "README.md"
6
  license = {text = "MIT"}
7
  requires-python = ">=3.11"
8
  dependencies = [
9
- "gradio==4.19.2",
10
- "huggingface_hub==0.20.3",
11
  "modal",
12
  "typer",
13
  "pillow",
14
  "python-dotenv",
15
- "daggr",
16
  ]
17
 
18
  [project.optional-dependencies]
 
6
  license = {text = "MIT"}
7
  requires-python = ">=3.11"
8
  dependencies = [
9
+ "gradio>=4.19.2",
10
+ "huggingface_hub>=0.20.3",
11
  "modal",
12
  "typer",
13
  "pillow",
14
  "python-dotenv",
15
+ "daggr @ git+https://github.com/gradio-app/daggr.git",
16
  ]
17
 
18
  [project.optional-dependencies]
requirements.txt CHANGED
@@ -1,6 +1,4 @@
1
  gradio==4.19.2
2
- huggingface_hub==0.20.3
3
  modal
4
- pillow
5
- python-dotenv
6
- daggr
 
1
  gradio==4.19.2
 
2
  modal
3
+ Pillow
4
+ ./daggr
 
test_daggr_init.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
"""Smoke test: verify a daggr FnNode can be constructed with Gradio ports."""

import gradio as gr
from daggr import FnNode


def convert_image_to_bytes(image) -> bytes:
    """Stub converter; only exists to give the FnNode a callable."""
    return b"test"


try:
    print("Attempting to create FnNode...")
    converter = FnNode(
        fn=convert_image_to_bytes,
        name="Image Converter",
        inputs={
            "image": gr.Image(label="Upload your photo"),
        },
        outputs={
            "output": gr.Textbox(visible=False),
        },
    )
    print("Success!")
    print(f"Output port: {converter.output}")
except Exception:
    # Print the full traceback so constructor failures are diagnosable.
    import traceback

    traceback.print_exc()
uv.lock CHANGED
The diff for this file is too large to render. See raw diff