cansik commited on
Commit
3025bb3
·
verified ·
1 Parent(s): 691f45a

Upload folder via script

Browse files
assets/favicon.ico ADDED
assets/icon.png ADDED
assets/logo-inverted.png ADDED
assets/logo.png ADDED
dataflow/codecs.py CHANGED
@@ -11,6 +11,7 @@ def node_to_vueflow(n: NodeInstance, data_extra: dict[str, Any] | None = None) -
11
  """Convert a NodeInstance to a plain Vue Flow node dict.
12
 
13
  data_extra can be used by higher level code to attach UI specific state.
 
14
  """
15
  data: dict[str, Any] = {
16
  "kind": n.node_type.kind.value,
@@ -20,7 +21,15 @@ def node_to_vueflow(n: NodeInstance, data_extra: dict[str, Any] | None = None) -
20
  }
21
  if data_extra:
22
  # app level UI code can attach arbitrary extra fields here
23
- data.update(data_extra)
 
 
 
 
 
 
 
 
24
 
25
  return {
26
  "id": n.node_id,
 
11
  """Convert a NodeInstance to a plain Vue Flow node dict.
12
 
13
  data_extra can be used by higher level code to attach UI specific state.
14
+ If data_extra contains 'inputs' or 'outputs', they will replace the defaults.
15
  """
16
  data: dict[str, Any] = {
17
  "kind": n.node_type.kind.value,
 
21
  }
22
  if data_extra:
23
  # app level UI code can attach arbitrary extra fields here
24
+ # If inputs/outputs are provided, they replace the defaults (allows custom positioning)
25
+ if "inputs" in data_extra:
26
+ data["inputs"] = data_extra["inputs"]
27
+ if "outputs" in data_extra:
28
+ data["outputs"] = data_extra["outputs"]
29
+ # Update other fields
30
+ for key, value in data_extra.items():
31
+ if key not in ("inputs", "outputs"):
32
+ data[key] = value
33
 
34
  return {
35
  "id": n.node_id,
dataflow/graph.py CHANGED
@@ -1,18 +1,22 @@
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
4
- from typing import Iterable
5
 
6
  from .connection import Connection
7
  from .enums import DataPortState, ConnectMultiplicity
8
  from .nodes_base import NodeInstance
9
 
 
 
10
 
11
  @dataclass(slots=True)
12
  class DataGraph:
13
  nodes: dict[str, NodeInstance] = field(default_factory=dict)
14
  connections: list[Connection] = field(default_factory=list)
15
 
 
 
16
  def add_node(self, node: NodeInstance) -> None:
17
  self.nodes[node.node_id] = node
18
 
@@ -34,6 +38,26 @@ class DataGraph:
34
  if c.start_node is node:
35
  yield c
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  async def execute(self, node: NodeInstance | None = None) -> None:
38
  if node is None:
39
  for n in list(self.nodes.values()):
@@ -54,6 +78,9 @@ class DataGraph:
54
  feeds = [c for c in incoming_conns if c.end_port is inp]
55
 
56
  if not feeds:
 
 
 
57
  continue
58
 
59
  if inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE:
@@ -67,17 +94,24 @@ class DataGraph:
67
  # Single connection (take the last one if multiple defined by mistake)
68
  if feeds:
69
  inp.value = feeds[-1].start_port.value
 
 
70
 
71
  inp.state = DataPortState.CLEAN
72
 
 
 
73
  # Process
74
- await node.process()
75
 
 
76
  for outp in node.all_outputs():
77
  outp.state = DataPortState.CLEAN
78
 
 
79
  for c in self.downstream_of(node):
80
- # We just mark downstream as dirty; the next execute call will pull the data
81
- c.end_port.state = DataPortState.DIRTY
82
- if c.end_node.auto_process:
83
- await self.execute(c.end_node)
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass, field
4
+ from typing import Iterable, Callable, Awaitable
5
 
6
  from .connection import Connection
7
  from .enums import DataPortState, ConnectMultiplicity
8
  from .nodes_base import NodeInstance
9
 
10
+ NodeExecutedCallback = Callable[[NodeInstance], Awaitable[None] | None]
11
+
12
 
13
  @dataclass(slots=True)
14
  class DataGraph:
15
  nodes: dict[str, NodeInstance] = field(default_factory=dict)
16
  connections: list[Connection] = field(default_factory=list)
17
 
18
+ _on_node_executed: NodeExecutedCallback | None = field(default=None, repr=False)
19
+
20
  def add_node(self, node: NodeInstance) -> None:
21
  self.nodes[node.node_id] = node
22
 
 
38
  if c.start_node is node:
39
  yield c
40
 
41
+ def set_on_node_executed(self, cb: NodeExecutedCallback | None) -> None:
42
+ """Register a callback that is invoked after each node is executed.
43
+
44
+ The callback can be sync or async. Pass None to disable.
45
+ """
46
+ self._on_node_executed = cb
47
+
48
+ async def _run_node(self, node: NodeInstance) -> None:
49
+ """Internal helper that executes a single node and fires the callback."""
50
+ await node.process()
51
+
52
+ cb = self._on_node_executed
53
+ if cb is None:
54
+ return
55
+
56
+ result = cb(node)
57
+ # allow async callbacks
58
+ if hasattr(result, "__await__"):
59
+ await result # type: ignore[func-returns-value]
60
+
61
  async def execute(self, node: NodeInstance | None = None) -> None:
62
  if node is None:
63
  for n in list(self.nodes.values()):
 
78
  feeds = [c for c in incoming_conns if c.end_port is inp]
79
 
80
  if not feeds:
81
+ # No connections to this port - clear its value
82
+ inp.value = None if inp.schema.multiplicity == ConnectMultiplicity.SINGLE else []
83
+ inp.state = DataPortState.CLEAN
84
  continue
85
 
86
  if inp.schema.multiplicity == ConnectMultiplicity.MULTIPLE:
 
94
  # Single connection (take the last one if multiple defined by mistake)
95
  if feeds:
96
  inp.value = feeds[-1].start_port.value
97
+ print(
98
+ f"[DEBUG Graph] Transferring value to {node.node_id}.{inp.name}: type={type(inp.value).__name__ if inp.value is not None else 'None'}")
99
 
100
  inp.state = DataPortState.CLEAN
101
 
102
+ old_outputs = {p.name: p.value for p in node.all_outputs()}
103
+
104
  # Process
105
+ await self._run_node(node)
106
 
107
+ # after process
108
  for outp in node.all_outputs():
109
  outp.state = DataPortState.CLEAN
110
 
111
+ # mark downstream dirty only if needed
112
  for c in self.downstream_of(node):
113
+ out_name = c.start_port.name
114
+ before = old_outputs.get(out_name)
115
+ after = c.start_port.value
116
+ if before != after:
117
+ c.end_port.state = DataPortState.DIRTY
dataflow/nodes_base.py CHANGED
@@ -54,3 +54,6 @@ class NodeInstance:
54
  self.on_process(self)
55
  else:
56
  pass
 
 
 
 
54
  self.on_process(self)
55
  else:
56
  pass
57
+
58
+ def reset_node(self) -> None:
59
+ pass
dataflow/ui/vueflow_canvas.py CHANGED
@@ -49,3 +49,7 @@ class VueFlowCanvas(Element, component="vueflow_canvas.vue"):
49
  def update_node_values(self, node_id: str, values: dict[str, Any]) -> None:
50
  """Patch the node.data.values dict."""
51
  self.run_method("updateNodeValues", {"id": node_id, "values": values})
 
 
 
 
 
49
  def update_node_values(self, node_id: str, values: dict[str, Any]) -> None:
50
  """Patch the node.data.values dict."""
51
  self.run_method("updateNodeValues", {"id": node_id, "values": values})
52
+
53
+ def update_node_progress(self, node_id: str, progress: float, message: str | None = None) -> None:
54
+ """Update progress percentage and message for a node."""
55
+ self.run_method("updateNodeProgress", {"id": node_id, "progress": progress, "message": message})
dataflow/ui/vueflow_canvas.vue CHANGED
@@ -3,7 +3,6 @@
3
  <div
4
  :class="['nicegui-vueflow', dark ? 'q-dark' : '']"
5
  tabindex="0"
6
- @keydown="onKeyDown"
7
  >
8
  <div v-if="!ready" class="vf-loading">Loading graph...</div>
9
  <div v-else class="vf-root">
@@ -26,7 +25,10 @@
26
  @pane-ready="onPaneReady"
27
  @pane-click="onPaneClick"
28
  @pane-context-menu="onPaneContextMenu"
 
29
  @connect="onConnect"
 
 
30
  @node-drag-stop="onNodeDragStop"
31
  @nodes-change="onNodesChange"
32
  @edges-change="onEdgesChange"
@@ -35,7 +37,11 @@
35
  >
36
  <!-- generic base node with left/right handles and field based UI -->
37
  <template #node-base="nodeProps">
38
- <div class="vf-node" :class="{ 'vf-node-selected': nodeProps.selected }">
 
 
 
 
39
  <!-- input handles on the left -->
40
  <div class="vf-node-handles vf-node-handles-left" v-if="Handle">
41
  <component
@@ -57,31 +63,71 @@
57
  {{ nodeProps.data.title || nodeProps.id }}
58
  </div>
59
  <div class="vf-node-header-actions">
60
- <a
61
- v-if="getImageResultValue(nodeProps)"
62
- :href="normalizeImageSrc(getImageResultValue(nodeProps))"
63
- download="generated.png"
64
- class="vf-exec-button vf-download-header"
65
- @click.stop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  >
67
- Download
68
- </a>
 
 
 
 
69
  <button
70
- v-if="nodeProps.data.executable"
71
- class="vf-exec-button"
72
- :disabled="nodeProps.data.processing"
73
- @click.stop="onExecuteClick(nodeProps)"
74
  >
75
- <span v-if="nodeProps.data.processing" class="vf-spinner"></span>
76
- <span v-else>Run</span>
 
 
77
  </button>
78
 
79
- <!-- simple dropdown for node actions -->
80
  <div class="vf-node-menu" @click.stop>
81
  <button
82
- type="button"
83
- class="vf-node-menu-button"
84
- @click.stop="toggleNodeMenu(nodeProps.id)"
85
  >
86
  <span class="vf-node-menu-dot"></span>
87
  <span class="vf-node-menu-dot"></span>
@@ -89,9 +135,9 @@
89
  </button>
90
  <div v-if="isNodeMenuOpen(nodeProps.id)" class="vf-node-menu-items">
91
  <button
92
- type="button"
93
- class="vf-node-menu-item"
94
- @click.stop="onNodeMenuAction(nodeProps.id, 'delete')"
95
  >
96
  Delete node
97
  </button>
@@ -102,15 +148,16 @@
102
 
103
  <!-- node level error -->
104
  <div v-if="nodeProps.data.values && nodeProps.data.values.error" class="vf-node-error">
105
- {{ nodeProps.data.values.error }}
106
- </div>
107
-
108
- <!-- processing progress bar -->
109
- <div v-if="nodeProps.data.processing" class="vf-node-progress">
110
- <div class="vf-node-progress-bar"></div>
111
  </div>
112
 
113
  <div class="vf-node-content">
 
 
 
 
 
 
114
  <template v-if="nodeProps.data.fields && nodeProps.data.fields.length">
115
  <div
116
  v-for="field in nodeProps.data.fields"
@@ -124,19 +171,27 @@
124
  type="text"
125
  v-model="nodeProps.data.values[field.name]"
126
  @blur="onFieldBlur(nodeProps, field)"
 
 
 
 
127
  :placeholder="field.placeholder || ''"
128
  />
129
  </div>
130
 
131
  <!-- textarea field -->
132
  <div v-else-if="field.kind === 'textarea'">
133
- <textarea
134
- class="vf-node-textarea"
135
- rows="5"
136
- v-model="nodeProps.data.values[field.name]"
137
- @blur="onFieldBlur(nodeProps, field)"
138
- :placeholder="field.placeholder || ''"
139
- ></textarea>
 
 
 
 
140
  </div>
141
 
142
  <!-- image upload field -->
@@ -154,14 +209,20 @@
154
  <button
155
  type="button"
156
  class="vf-exec-button vf-upload-header-btn-secondary"
157
- @click="onImageClipboard(nodeProps, field)"
158
- >
159
  Paste
160
  </button>
161
  </div>
162
  <div class="vf-image-preview-wrapper"
163
  v-if="nodeProps.data.values && nodeProps.data.values[field.name]">
164
  <img :src="normalizeImageSrc(nodeProps.data.values[field.name])"/>
 
 
 
 
 
 
 
165
  </div>
166
  <div class="vf-image-placeholder" v-else>No image</div>
167
  </div>
@@ -352,9 +413,17 @@ export default {
352
  contextMenuFlowPosition: null,
353
 
354
  windowKeyHandler: null,
 
355
 
356
  // which node menu is open
357
- nodeMenuFor: null
 
 
 
 
 
 
 
358
  };
359
  },
360
 
@@ -391,6 +460,15 @@ export default {
391
  };
392
  window.addEventListener("keydown", this.windowKeyHandler);
393
 
 
 
 
 
 
 
 
 
 
394
  this.ready = true;
395
  } catch (err) {
396
  console.error("Failed to initialize Vue Flow", err);
@@ -409,6 +487,10 @@ export default {
409
  window.removeEventListener("keydown", this.windowKeyHandler);
410
  this.windowKeyHandler = null;
411
  }
 
 
 
 
412
  },
413
 
414
  methods: {
@@ -576,6 +658,54 @@ export default {
576
  }
577
  },
578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
  emitEvent(type, payload) {
580
  this.$emit("vf_event", {type: type, payload: payload});
581
  },
@@ -587,6 +717,18 @@ export default {
587
  this.emitEvent("execute_node", {id: id});
588
  },
589
 
 
 
 
 
 
 
 
 
 
 
 
 
590
  onFieldBlur(nodeProps, field) {
591
  if (!nodeProps || !nodeProps.data || !field) return;
592
  var values = nodeProps.data.values || {};
@@ -600,40 +742,112 @@ export default {
600
  });
601
  },
602
 
603
- onImageUpload(event, nodeProps, field) {
604
- const file = event.target.files[0];
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
605
  if (!file) return;
606
- const reader = new FileReader();
607
- reader.onload = (e) => {
608
- const res = e.target.result;
609
- if (!nodeProps.data.values) nodeProps.data.values = {};
610
- nodeProps.data.values[field.name] = res;
611
- this.onFieldBlur(nodeProps, field);
612
- };
613
- reader.readAsDataURL(file);
 
 
614
  },
615
 
616
  async onImageClipboard(nodeProps, field) {
617
  try {
618
  const items = await navigator.clipboard.read();
619
  for (const item of items) {
620
- if (item.types.some(function (t) {
621
- return t.startsWith("image/");
622
- })) {
623
- const type = item.types.find(function (t) {
624
- return t.startsWith("image/");
625
- });
626
- const blob = await item.getType(type);
627
- const reader = new FileReader();
628
- reader.onload = (e) => {
629
- const res = e.target.result;
630
- if (!nodeProps.data.values) nodeProps.data.values = {};
631
- nodeProps.data.values[field.name] = res;
632
- this.onFieldBlur(nodeProps, field);
633
- };
634
- reader.readAsDataURL(blob);
635
- return;
636
- }
637
  }
638
  alert("No image found on clipboard");
639
  } catch (err) {
@@ -642,8 +856,24 @@ export default {
642
  }
643
  },
644
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
645
  normalizeImageSrc(src) {
646
  if (!src) return "";
 
647
  if (typeof src !== "string") {
648
  try {
649
  src = String(src);
@@ -651,8 +881,11 @@ export default {
651
  return "";
652
  }
653
  }
654
- var trimmed = src.trim();
 
655
  if (!trimmed) return "";
 
 
656
  if (
657
  trimmed.startsWith("http://") ||
658
  trimmed.startsWith("https://") ||
@@ -661,7 +894,38 @@ export default {
661
  ) {
662
  return trimmed;
663
  }
664
- return "data:image/png;base64," + trimmed;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
665
  },
666
 
667
  getImageResultValue(nodeProps) {
@@ -678,12 +942,19 @@ export default {
678
  },
679
 
680
  getHandleStyle(side, index, list, port) {
681
- var len = Array.isArray(list) ? list.length : 1;
682
- var step = 100 / (len + 1);
683
- var top = step * (index + 1);
684
- var style = {
685
- top: top + "%"
686
- };
 
 
 
 
 
 
 
687
  if (port && port.color) {
688
  style.backgroundColor = port.color;
689
  }
@@ -811,10 +1082,20 @@ export default {
811
  return;
812
  }
813
 
814
- this.emitEvent("create_node", {
 
815
  kind: kind,
816
  position: {x: pos.x, y: pos.y}
817
- });
 
 
 
 
 
 
 
 
 
818
 
819
  this.hideContextMenu();
820
  },
@@ -830,6 +1111,9 @@ export default {
830
  return;
831
  }
832
 
 
 
 
833
  if (key === "Tab" && event.type === "keydown") {
834
  event.preventDefault();
835
 
@@ -851,6 +1135,25 @@ export default {
851
  if (this.contextMenuVisible) {
852
  this.hideContextMenu();
853
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854
  }
855
  },
856
 
@@ -876,6 +1179,21 @@ export default {
876
  this.emitEvent("pane_click", event || {});
877
  },
878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
879
  onPaneContextMenu(event) {
880
  if (event && typeof event.preventDefault === "function") {
881
  event.preventDefault();
@@ -908,6 +1226,30 @@ export default {
908
 
909
  this.vf_api.addEdges([edge]);
910
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
911
  },
912
 
913
  onNodeDragStop(event) {
@@ -1173,6 +1515,9 @@ export default {
1173
 
1174
  .vf-handle {
1175
  pointer-events: auto;
 
 
 
1176
  }
1177
 
1178
  .vf-node-body {
@@ -1203,6 +1548,9 @@ export default {
1203
  display: flex;
1204
  flex-direction: column;
1205
  gap: 4px;
 
 
 
1206
  }
1207
 
1208
  .vf-node-text-input {
@@ -1238,7 +1586,7 @@ export default {
1238
 
1239
  .vf-exec-button {
1240
  border: none;
1241
- background: #3b82f6;
1242
  color: #fff;
1243
  border-radius: 4px;
1244
  padding: 2px 6px;
@@ -1255,54 +1603,31 @@ export default {
1255
  cursor: default;
1256
  }
1257
 
1258
- .vf-download-header {
1259
- background: #2563eb;
1260
- }
1261
-
1262
  .vf-spinner {
1263
- width: 12px;
1264
- height: 12px;
1265
- border-radius: 50%;
1266
- border: 2px solid rgba(255, 255, 255, 0.6);
1267
- border-top-color: rgba(15, 23, 42, 0.9);
1268
- animation: vf-spin 0.6s linear infinite;
1269
- }
1270
-
1271
- @keyframes vf-spin {
1272
- from {
1273
- transform: rotate(0deg);
1274
- }
1275
- to {
1276
- transform: rotate(360deg);
1277
- }
1278
  }
1279
 
1280
- /* processing progress bar */
1281
- .vf-node-progress {
1282
- position: relative;
1283
- height: 4px;
1284
- border-radius: 999px;
1285
- overflow: hidden;
1286
- background: #e5e7eb;
1287
- margin-bottom: 4px;
1288
  }
1289
 
1290
- .vf-node-progress-bar {
1291
- position: absolute;
1292
- left: -40%;
1293
- top: 0;
1294
- bottom: 0;
1295
- width: 40%;
1296
- background: #3b82f6;
1297
- animation: vf-progress 1s linear infinite;
1298
  }
1299
 
1300
- @keyframes vf-progress {
1301
  from {
1302
- transform: translateX(0);
1303
  }
1304
  to {
1305
- transform: translateX(260%);
1306
  }
1307
  }
1308
 
@@ -1310,9 +1635,19 @@ export default {
1310
  width: 160px;
1311
  height: 100px;
1312
  border-radius: 4px;
1313
- overflow: hidden;
1314
  border: 1px solid rgba(148, 163, 184, 0.7);
1315
  background: #f9fafb;
 
 
 
 
 
 
 
 
 
 
1316
  }
1317
 
1318
  .vf-image-preview {
@@ -1337,6 +1672,7 @@ export default {
1337
  height: 100%;
1338
  }
1339
 
 
1340
  /* upload controls - now small blue buttons like header */
1341
  .vf-image-upload-controls {
1342
  display: flex;
@@ -1358,6 +1694,7 @@ export default {
1358
  border: 1px solid #bfdbfe;
1359
  }
1360
 
 
1361
  /* simple node menu */
1362
 
1363
  .vf-node-menu {
@@ -1425,3 +1762,24 @@ export default {
1425
  opacity: 0.7;
1426
  }
1427
  </style>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  <div
4
  :class="['nicegui-vueflow', dark ? 'q-dark' : '']"
5
  tabindex="0"
 
6
  >
7
  <div v-if="!ready" class="vf-loading">Loading graph...</div>
8
  <div v-else class="vf-root">
 
25
  @pane-ready="onPaneReady"
26
  @pane-click="onPaneClick"
27
  @pane-context-menu="onPaneContextMenu"
28
+ @pane-mousemove="onPaneMouseMove"
29
  @connect="onConnect"
30
+ @connect-start="onConnectStart"
31
+ @connect-end="onConnectEnd"
32
  @node-drag-stop="onNodeDragStop"
33
  @nodes-change="onNodesChange"
34
  @edges-change="onEdgesChange"
 
37
  >
38
  <!-- generic base node with left/right handles and field based UI -->
39
  <template #node-base="nodeProps">
40
+ <div
41
+ class="vf-node"
42
+ :class="{ 'vf-node-selected': nodeProps.selected }"
43
+ :style="nodeProps.data.min_height ? { minHeight: nodeProps.data.min_height + 'px' } : {}"
44
+ >
45
  <!-- input handles on the left -->
46
  <div class="vf-node-handles vf-node-handles-left" v-if="Handle">
47
  <component
 
63
  {{ nodeProps.data.title || nodeProps.id }}
64
  </div>
65
  <div class="vf-node-header-actions">
66
+ <button
67
+ v-if="nodeProps.data.executable"
68
+ class="vf-exec-button"
69
+ :disabled="nodeProps.data.processing"
70
+ @click.stop="onExecuteClick(nodeProps)"
71
+ >
72
+ <span v-if="nodeProps.data.processing" class="vf-spinner">
73
+ <svg viewBox="0 0 36 36" class="vf-spinner-svg">
74
+ <circle
75
+ cx="18"
76
+ cy="18"
77
+ r="16"
78
+ fill="none"
79
+ stroke="rgba(255, 255, 255, 0.3)"
80
+ stroke-width="3"
81
+ ></circle>
82
+ <circle
83
+ cx="18"
84
+ cy="18"
85
+ r="16"
86
+ fill="none"
87
+ stroke="#3b82f6"
88
+ stroke-width="3"
89
+ stroke-dasharray="100"
90
+ :stroke-dashoffset="100 - ((nodeProps.data.progress || 0) * 100)"
91
+ class="vf-spinner-progress"
92
+ ></circle>
93
+ </svg>
94
+ </span>
95
+ <span v-else>
96
+ <q-icon name="rocket_launch" />
97
+ <q-tooltip>Run Node</q-tooltip>
98
+ </span>
99
+ </button>
100
+
101
+ <button
102
+ v-if="nodeProps.data.executable"
103
+ class="vf-exec-button"
104
+ :disabled="nodeProps.data.processing || !nodeHasContent(nodeProps)"
105
+ @click.stop="onResetClick(nodeProps)"
106
  >
107
+ <span>
108
+ <q-icon name="delete_forever" />
109
+ <q-tooltip>Reset Node</q-tooltip>
110
+ </span>
111
+ </button>
112
+
113
  <button
114
+ v-if="nodeProps.data.executable"
115
+ class="vf-exec-button"
116
+ :disabled="nodeProps.data.processing || !nodeHasContent(nodeProps)"
117
+ @click.stop="onDownloadClick(nodeProps)"
118
  >
119
+ <span>
120
+ <q-icon name="download" />
121
+ <q-tooltip>Download Image</q-tooltip>
122
+ </span>
123
  </button>
124
 
125
+ <!-- rest of header (menu) unchanged -->
126
  <div class="vf-node-menu" @click.stop>
127
  <button
128
+ type="button"
129
+ class="vf-node-menu-button"
130
+ @click.stop="toggleNodeMenu(nodeProps.id)"
131
  >
132
  <span class="vf-node-menu-dot"></span>
133
  <span class="vf-node-menu-dot"></span>
 
135
  </button>
136
  <div v-if="isNodeMenuOpen(nodeProps.id)" class="vf-node-menu-items">
137
  <button
138
+ type="button"
139
+ class="vf-node-menu-item"
140
+ @click.stop="onNodeMenuAction(nodeProps.id, 'delete')"
141
  >
142
  Delete node
143
  </button>
 
148
 
149
  <!-- node level error -->
150
  <div v-if="nodeProps.data.values && nodeProps.data.values.error" class="vf-node-error">
151
+ <span>{{ nodeProps.data.values.error }}</span>
 
 
 
 
 
152
  </div>
153
 
154
  <div class="vf-node-content">
155
+ <q-inner-loading
156
+ v-if="nodeProps.data.executable"
157
+ :showing="nodeProps.data.processing && (nodeProps.data.progress !== 1 || nodeProps.data.progress !== 0)"
158
+ >
159
+ <q-spinner-grid size="50px" color="black" />
160
+ </q-inner-loading>
161
  <template v-if="nodeProps.data.fields && nodeProps.data.fields.length">
162
  <div
163
  v-for="field in nodeProps.data.fields"
 
171
  type="text"
172
  v-model="nodeProps.data.values[field.name]"
173
  @blur="onFieldBlur(nodeProps, field)"
174
+ @keydown.stop
175
+ @wheel.stop
176
+ @mousedown.stop
177
+ @click.stop
178
  :placeholder="field.placeholder || ''"
179
  />
180
  </div>
181
 
182
  <!-- textarea field -->
183
  <div v-else-if="field.kind === 'textarea'">
184
+ <textarea
185
+ class="vf-node-textarea"
186
+ rows="5"
187
+ v-model="nodeProps.data.values[field.name]"
188
+ @blur="onFieldBlur(nodeProps, field)"
189
+ @keydown.stop
190
+ @wheel.stop
191
+ @mousedown.stop
192
+ @click.stop
193
+ :placeholder="field.placeholder || ''"
194
+ ></textarea>
195
  </div>
196
 
197
  <!-- image upload field -->
 
209
  <button
210
  type="button"
211
  class="vf-exec-button vf-upload-header-btn-secondary"
212
+ @click="onImageClipboard(nodeProps, field)">
 
213
  Paste
214
  </button>
215
  </div>
216
  <div class="vf-image-preview-wrapper"
217
  v-if="nodeProps.data.values && nodeProps.data.values[field.name]">
218
  <img :src="normalizeImageSrc(nodeProps.data.values[field.name])"/>
219
+ <button
220
+ type="button"
221
+ class="vf-image-delete-btn"
222
+ @click.stop="onImageDelete(nodeProps, field)"
223
+ >
224
+ Delete
225
+ </button>
226
  </div>
227
  <div class="vf-image-placeholder" v-else>No image</div>
228
  </div>
 
413
  contextMenuFlowPosition: null,
414
 
415
  windowKeyHandler: null,
416
+ windowMouseHandler: null,
417
 
418
  // which node menu is open
419
+ nodeMenuFor: null,
420
+
421
+ // copy/paste support
422
+ copiedNodeId: null,
423
+ mouseFlowPosition: {x: 0, y: 0},
424
+
425
+ // pending connection for Tab shortcut
426
+ pendingConnection: null
427
  };
428
  },
429
 
 
460
  };
461
  window.addEventListener("keydown", this.windowKeyHandler);
462
 
463
+ // Track mouse position globally for paste functionality
464
+ this.windowMouseHandler = (ev) => {
465
+ if (ev.clientX !== undefined && ev.clientY !== undefined) {
466
+ var flowPos = this.projectScreenToFlow(ev.clientX, ev.clientY);
467
+ this.mouseFlowPosition = flowPos;
468
+ }
469
+ };
470
+ window.addEventListener("mousemove", this.windowMouseHandler);
471
+
472
  this.ready = true;
473
  } catch (err) {
474
  console.error("Failed to initialize Vue Flow", err);
 
487
  window.removeEventListener("keydown", this.windowKeyHandler);
488
  this.windowKeyHandler = null;
489
  }
490
+ if (this.windowMouseHandler) {
491
+ window.removeEventListener("mousemove", this.windowMouseHandler);
492
+ this.windowMouseHandler = null;
493
+ }
494
  },
495
 
496
  methods: {
 
658
  }
659
  },
660
 
661
+ updateNodeProgress(payload) {
662
+ if (!payload || !payload.id) return;
663
+ var id = payload.id;
664
+ var progress = typeof payload.progress === "number" ? payload.progress : 0;
665
+ var message = payload.message || null;
666
+
667
+ var idx = this.nodes.findIndex(function (n) {
668
+ return n.id === id;
669
+ });
670
+ if (idx >= 0) {
671
+ var node = this.nodes[idx];
672
+ var data = Object.assign({}, node.data || {});
673
+ data.progress = progress;
674
+ data.progressMessage = message;
675
+ var updatedLocal = Object.assign({}, node, {data: data});
676
+ var localCopy = this.nodes.slice();
677
+ localCopy.splice(idx, 1, updatedLocal);
678
+ this.nodes = localCopy;
679
+ }
680
+
681
+ if (this.vf_api && this.vf_api.updateNode) {
682
+ this.vf_api.updateNode(id, function (node) {
683
+ var next = Object.assign({}, node || {});
684
+ var data = Object.assign({}, next.data || {});
685
+ data.progress = progress;
686
+ data.progressMessage = message;
687
+ next.data = data;
688
+ return next;
689
+ });
690
+ }
691
+ },
692
+
693
+ nodeHasContent(nodeProps) {
694
+ const data = nodeProps.data || {}
695
+ const fields = data.fields || []
696
+ const values = data.values || {}
697
+
698
+ // adjust the kinds you care about
699
+ const contentField = fields.find(field =>
700
+ field.kind === 'image_result' ||
701
+ field.kind === 'image'
702
+ )
703
+ if (!contentField) {
704
+ return false
705
+ }
706
+ return !!values[contentField.name]
707
+ },
708
+
709
  emitEvent(type, payload) {
710
  this.$emit("vf_event", {type: type, payload: payload});
711
  },
 
717
  this.emitEvent("execute_node", {id: id});
718
  },
719
 
720
+ onDownloadClick(nodeProps) {
721
+ if (!nodeProps || !nodeProps.id) return;
722
+ var id = nodeProps.id;
723
+ this.emitEvent("download_node", {id: id});
724
+ },
725
+
726
+ onResetClick(nodeProps) {
727
+ if (!nodeProps || !nodeProps.id) return;
728
+ var id = nodeProps.id;
729
+ this.emitEvent("reset_node", {id: id});
730
+ },
731
+
732
  onFieldBlur(nodeProps, field) {
733
  if (!nodeProps || !nodeProps.data || !field) return;
734
  var values = nodeProps.data.values || {};
 
742
  });
743
  },
744
 
745
+ resizeAndConvertToWebp(blob, maxSize = 1024) {
746
+ return new Promise((resolve, reject) => {
747
+ const img = new Image();
748
+ const url = URL.createObjectURL(blob);
749
+
750
+ img.onload = () => {
751
+ URL.revokeObjectURL(url);
752
+
753
+ const originalWidth = img.width || 1;
754
+ const originalHeight = img.height || 1;
755
+
756
+ const scale = Math.min(
757
+ maxSize / originalWidth,
758
+ maxSize / originalHeight,
759
+ 1
760
+ );
761
+
762
+ const targetWidth = Math.round(originalWidth * scale);
763
+ const targetHeight = Math.round(originalHeight * scale);
764
+
765
+ const canvas = document.createElement("canvas");
766
+ canvas.width = targetWidth;
767
+ canvas.height = targetHeight;
768
+
769
+ const ctx = canvas.getContext("2d");
770
+ if (!ctx) {
771
+ reject(new Error("Could not get 2d context"));
772
+ return;
773
+ }
774
+
775
+ ctx.drawImage(img, 0, 0, targetWidth, targetHeight);
776
+
777
+ // best effort: WebP if possible, fall back to PNG if needed
778
+ try {
779
+ const dataUrl = canvas.toDataURL("image/webp", 0.9);
780
+ if (dataUrl && dataUrl.startsWith("data:image/webp")) {
781
+ resolve(dataUrl);
782
+ return;
783
+ }
784
+ } catch (e) {
785
+ // ignore and fall through
786
+ }
787
+
788
+ // fallback
789
+ try {
790
+ const fallback = canvas.toDataURL("image/png");
791
+ resolve(fallback);
792
+ } catch (e) {
793
+ reject(e);
794
+ }
795
+ };
796
+
797
+ img.onerror = (err) => {
798
+ URL.revokeObjectURL(url);
799
+ reject(err);
800
+ };
801
+
802
+ img.src = url;
803
+ });
804
+ },
805
+
806
+ async setImageValueFromBlob(blob, nodeProps, field) {
807
+ const maxSize = 1024;
808
+
809
+ const dataUrl = await this.resizeAndConvertToWebp(blob, maxSize);
810
+
811
+ if (!nodeProps.data.values) {
812
+ nodeProps.data.values = {};
813
+ }
814
+
815
+ nodeProps.data.values[field.name] = dataUrl;
816
+
817
+ this.updateNodeValues({
818
+ id: nodeProps.id,
819
+ values: {[field.name]: dataUrl}
820
+ });
821
+
822
+ this.onFieldBlur(nodeProps, field);
823
+ },
824
+
825
+ async onImageUpload(event, nodeProps, field) {
826
+ const file = event.target.files && event.target.files[0];
827
  if (!file) return;
828
+
829
+ try {
830
+ await this.setImageValueFromBlob(file, nodeProps, field);
831
+ } catch (err) {
832
+ console.error(err);
833
+ alert("Failed to process image");
834
+ } finally {
835
+ // reset input so the same file can be selected again
836
+ event.target.value = "";
837
+ }
838
  },
839
 
840
  async onImageClipboard(nodeProps, field) {
841
  try {
842
  const items = await navigator.clipboard.read();
843
  for (const item of items) {
844
+ const type = item.types.find((t) => t.startsWith("image/"));
845
+ if (!type) continue;
846
+
847
+ const blob = await item.getType(type);
848
+
849
+ await this.setImageValueFromBlob(blob, nodeProps, field);
850
+ return;
 
 
 
 
 
 
 
 
 
 
851
  }
852
  alert("No image found on clipboard");
853
  } catch (err) {
 
856
  }
857
  },
858
 
859
+ onImageDelete(nodeProps, field) {
860
+ if (!nodeProps.data.values) {
861
+ nodeProps.data.values = {};
862
+ }
863
+
864
+ nodeProps.data.values[field.name] = "";
865
+
866
+ this.updateNodeValues({
867
+ id: nodeProps.id,
868
+ values: {[field.name]: ""}
869
+ });
870
+
871
+ this.onFieldBlur(nodeProps, field);
872
+ },
873
+
874
  normalizeImageSrc(src) {
875
  if (!src) return "";
876
+
877
  if (typeof src !== "string") {
878
  try {
879
  src = String(src);
 
881
  return "";
882
  }
883
  }
884
+
885
+ const trimmed = src.trim();
886
  if (!trimmed) return "";
887
+
888
+ // already a complete URL or data URI
889
  if (
890
  trimmed.startsWith("http://") ||
891
  trimmed.startsWith("https://") ||
 
894
  ) {
895
  return trimmed;
896
  }
897
+
898
+ // if Python sends only the base64 without prefix:
899
+ // try to detect type from first bytes
900
+ try {
901
+ // decode a few bytes to inspect header
902
+ const header = atob(trimmed.slice(0, 20));
903
+
904
+ // png header 89 50 4E 47
905
+ if (header.startsWith("\x89PNG")) {
906
+ return "data:image/png;base64," + trimmed;
907
+ }
908
+
909
+ // jpeg header FF D8 FF
910
+ if (header.startsWith("\xFF\xD8\xFF")) {
911
+ return "data:image/jpeg;base64," + trimmed;
912
+ }
913
+
914
+ // webp header RIFF....WEBP
915
+ if (header.startsWith("RIFF") && header.slice(8, 12) === "WEBP") {
916
+ return "data:image/webp;base64," + trimmed;
917
+ }
918
+
919
+ // gif header GIF87a / GIF89a
920
+ if (header.startsWith("GIF8")) {
921
+ return "data:image/gif;base64," + trimmed;
922
+ }
923
+ } catch (err) {
924
+ // ignore and use fallback prefix
925
+ }
926
+
927
+ // last resort fallback (safe but generic)
928
+ return "data:image/*;base64," + trimmed;
929
  },
930
 
931
  getImageResultValue(nodeProps) {
 
942
  },
943
 
944
  getHandleStyle(side, index, list, port) {
945
+ var style = {};
946
+
947
+ // Check if port has a custom top position
948
+ if (port && typeof port.top === "number") {
949
+ style.top = port.top + "%";
950
+ } else {
951
+ // Default: evenly distribute handles
952
+ var len = Array.isArray(list) ? list.length : 1;
953
+ var step = 100 / (len + 1);
954
+ var top = step * (index + 1);
955
+ style.top = top + "%";
956
+ }
957
+
958
  if (port && port.color) {
959
  style.backgroundColor = port.color;
960
  }
 
1082
  return;
1083
  }
1084
 
1085
+ // Include pending connection if exists
1086
+ var eventData = {
1087
  kind: kind,
1088
  position: {x: pos.x, y: pos.y}
1089
+ };
1090
+
1091
+ if (this.pendingConnection) {
1092
+ eventData.pendingConnection = this.pendingConnection;
1093
+ }
1094
+
1095
+ this.emitEvent("create_node", eventData);
1096
+
1097
+ // Clear pending connection
1098
+ this.pendingConnection = null;
1099
 
1100
  this.hideContextMenu();
1101
  },
 
1111
  return;
1112
  }
1113
 
1114
+ var isMac = /Mac/.test(navigator.platform || navigator.userAgent);
1115
+ var cmdOrCtrl = isMac ? event.metaKey : event.ctrlKey;
1116
+
1117
  if (key === "Tab" && event.type === "keydown") {
1118
  event.preventDefault();
1119
 
 
1135
  if (this.contextMenuVisible) {
1136
  this.hideContextMenu();
1137
  }
1138
+ } else if ((key === "c" || key === "C") && cmdOrCtrl) {
1139
+ // Copy selected node
1140
+ event.preventDefault();
1141
+ var selectedNodes = this.nodes.filter(function(n) { return n.selected; });
1142
+ if (selectedNodes.length > 0) {
1143
+ // Copy the first selected node
1144
+ this.copiedNodeId = selectedNodes[0].id;
1145
+ console.log("Copied node:", this.copiedNodeId);
1146
+ }
1147
+ } else if ((key === "v" || key === "V") && cmdOrCtrl) {
1148
+ // Paste node at current mouse position
1149
+ event.preventDefault();
1150
+ if (this.copiedNodeId) {
1151
+ console.log("Pasting node:", this.copiedNodeId, "at", this.mouseFlowPosition);
1152
+ this.emitEvent("duplicate_node", {
1153
+ sourceNodeId: this.copiedNodeId,
1154
+ position: {x: this.mouseFlowPosition.x, y: this.mouseFlowPosition.y}
1155
+ });
1156
+ }
1157
  }
1158
  },
1159
 
 
1179
  this.emitEvent("pane_click", event || {});
1180
  },
1181
 
1182
+ onPaneMouseMove(event) {
1183
+ // Track mouse position in flow coordinates for paste
1184
+ // VueFlow pane-mousemove provides the event with coordinates
1185
+ if (!event) return;
1186
+
1187
+ // Try to get flow coordinates from the event
1188
+ if (event.flowX !== undefined && event.flowY !== undefined) {
1189
+ this.mouseFlowPosition = {x: event.flowX, y: event.flowY};
1190
+ } else if (event.clientX !== undefined && event.clientY !== undefined) {
1191
+ // If flow coordinates aren't available, convert screen to flow coordinates
1192
+ var flowPos = this.projectScreenToFlow(event.clientX, event.clientY);
1193
+ this.mouseFlowPosition = flowPos;
1194
+ }
1195
+ },
1196
+
1197
  onPaneContextMenu(event) {
1198
  if (event && typeof event.preventDefault === "function") {
1199
  event.preventDefault();
 
1226
 
1227
  this.vf_api.addEdges([edge]);
1228
  }
1229
+
1230
+ // Clear pending connection after successful connection
1231
+ this.pendingConnection = null;
1232
+ },
1233
+
1234
+ onConnectStart(params) {
1235
+ // Store connection start info when user begins dragging from a handle
1236
+ console.log("Connection start:", params);
1237
+ this.pendingConnection = {
1238
+ nodeId: params.nodeId || null,
1239
+ handleId: params.handleId || null,
1240
+ handleType: params.handleType || null
1241
+ };
1242
+ },
1243
+
1244
+ onConnectEnd(event) {
1245
+ // Connection ended - if user pressed Tab, keep pendingConnection for node creation
1246
+ // Otherwise clear after a short delay
1247
+ var self = this;
1248
+ setTimeout(function() {
1249
+ if (!self.contextMenuVisible) {
1250
+ self.pendingConnection = null;
1251
+ }
1252
+ }, 100);
1253
  },
1254
 
1255
  onNodeDragStop(event) {
 
1515
 
1516
  .vf-handle {
1517
  pointer-events: auto;
1518
+ width: 8px !important;
1519
+ height: 8px !important;
1520
+ border-width: 1px !important;
1521
  }
1522
 
1523
  .vf-node-body {
 
1548
  display: flex;
1549
  flex-direction: column;
1550
  gap: 4px;
1551
+
1552
+ position: relative;
1553
+ overflow: hidden;
1554
  }
1555
 
1556
  .vf-node-text-input {
 
1586
 
1587
  .vf-exec-button {
1588
  border: none;
1589
+ background: #000000;
1590
  color: #fff;
1591
  border-radius: 4px;
1592
  padding: 2px 6px;
 
1603
  cursor: default;
1604
  }
1605
 
 
 
 
 
1606
  .vf-spinner {
1607
+ display: inline-block;
1608
+ width: 14px;
1609
+ height: 14px;
1610
+ position: relative;
 
 
 
 
 
 
 
 
 
 
 
1611
  }
1612
 
1613
+ .vf-spinner-svg {
1614
+ width: 100%;
1615
+ height: 100%;
1616
+ transform: rotate(-90deg);
1617
+ animation: vf-spin 2s linear infinite;
 
 
 
1618
  }
1619
 
1620
+ .vf-spinner-progress {
1621
+ stroke-linecap: round;
1622
+ transition: stroke-dashoffset 0.4s ease-out;
 
 
 
 
 
1623
  }
1624
 
1625
+ @keyframes vf-spin {
1626
  from {
1627
+ transform: rotate(-90deg);
1628
  }
1629
  to {
1630
+ transform: rotate(270deg);
1631
  }
1632
  }
1633
 
 
1635
  width: 160px;
1636
  height: 100px;
1637
  border-radius: 4px;
1638
+ overflow: visible;
1639
  border: 1px solid rgba(148, 163, 184, 0.7);
1640
  background: #f9fafb;
1641
+ position: relative;
1642
+ }
1643
+
1644
+ .vf-image-preview-wrapper img {
1645
+ border-radius: 4px;
1646
+ overflow: hidden;
1647
+ display: block;
1648
+ width: 100%;
1649
+ height: 100%;
1650
+ object-fit: cover;
1651
  }
1652
 
1653
  .vf-image-preview {
 
1672
  height: 100%;
1673
  }
1674
 
1675
+
1676
  /* upload controls - now small blue buttons like header */
1677
  .vf-image-upload-controls {
1678
  display: flex;
 
1694
  border: 1px solid #bfdbfe;
1695
  }
1696
 
1697
+
1698
  /* simple node menu */
1699
 
1700
  .vf-node-menu {
 
1762
  opacity: 0.7;
1763
  }
1764
  </style>
1765
+
1766
+ <style>
1767
+ .vf-image-preview-wrapper .vf-image-delete-btn {
1768
+ position: absolute ;
1769
+ top: 28px ;
1770
+ right: 12px ;
1771
+ z-index: 100 ;
1772
+ color: #ffffff ;
1773
+ background-color: #eb311c ;
1774
+ border: none;
1775
+ border-radius: 4px;
1776
+ padding: 2px 6px;
1777
+ font-size: 11px;
1778
+ cursor: pointer;
1779
+ display: inline-flex;
1780
+ align-items: center;
1781
+ justify-content: center;
1782
+ white-space: nowrap;
1783
+ }
1784
+
1785
+ </style>
main.py CHANGED
@@ -3,132 +3,80 @@ from __future__ import annotations
3
  import os
4
 
5
  import uvicorn
 
6
  from fastapi import FastAPI
 
7
  from nicegui import ui, Client, app
8
-
9
- from dataflow.ui.vueflow_canvas import VueFlowCanvas
10
  from nodes.session import GraphSession
 
 
 
 
11
 
12
  fastapi_app = FastAPI()
13
  ui.run_with(fastapi_app, storage_secret="velai-storage-secret")
14
 
 
 
 
15
  APP_PASSWORD = os.getenv("VELAI_PASSWORD") or ""
16
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  @ui.page("/")
19
  async def main(client: Client) -> None:
20
- ui.page_title("velai")
 
21
  ui.query(".nicegui-content").classes("p-0")
22
 
23
  await client.connected()
24
 
25
- # password gate: require once per user while APP_PASSWORD stays the same
26
- password_required = False
27
- if APP_PASSWORD:
28
- auth = app.storage.user.get("auth") or {}
29
- if not (auth.get("ok") and auth.get("password_version") == APP_PASSWORD):
30
- password_required = True
31
 
32
  # one GraphSession per browser tab
33
- if "graph_session" not in app.storage.tab:
34
- app.storage.tab["graph_session"] = GraphSession.create_default()
35
- session: GraphSession = app.storage.tab["graph_session"]
36
-
37
- initial_nodes = session.initial_vue_nodes()
38
- initial_edges = session.initial_vue_edges()
39
-
40
- with ui.column().classes("w-full h-screen no-wrap"):
41
- # header
42
- with ui.row().classes(
43
- "w-full items-center justify-between px-4 py-2 bg-grey-2"
44
- ):
45
- ui.label("velai").classes("text-lg font-bold")
46
-
47
- with ui.row().classes("items-center gap-2"):
48
- with ui.dropdown_button("Add Node", auto_close=True):
49
- kinds = session.creatable_node_types
50
- for k in kinds:
51
- ui.menu_item(k["title"], on_click=lambda _, kind=k["kind"]: add_node_action(kind))
52
-
53
- with ui.dropdown_button("Graph", auto_close=True):
54
- ui.menu_item("Export JSON", on_click=lambda: export_action())
55
- ui.menu_item("Import JSON", on_click=lambda: import_action())
56
-
57
- clear_button = ui.button("Clear graph").props("flat color='negative'")
58
-
59
- # canvas area fills the remaining height
60
- with ui.row().classes("w-full flex-1 no-wrap"):
61
- canvas = VueFlowCanvas(
62
- creatable_node_types=session.creatable_node_types,
63
- ).classes("w-full h-full flex-1")
64
-
65
- session.attach_canvas(canvas)
66
-
67
- async def handle_event(e) -> None:
68
- await session.handle_ui_event(e.args, canvas)
69
-
70
- canvas.on("vf_event", handle_event)
71
- canvas.set_graph(initial_nodes, initial_edges)
72
-
73
- # wire header actions now that canvas exists
74
- def add_node_action(kind_value: str) -> None:
75
- import random
76
- vue_node = session.create_node(kind_value, position={"x": 100 + random.randint(0, 50),
77
- "y": 100 + random.randint(0, 50)})
78
- if vue_node is not None:
79
- canvas.add_node(vue_node)
80
-
81
- def clear_graph_action() -> None:
82
- session.clear_graph()
83
- canvas.set_graph([], [])
84
-
85
- async def export_action() -> None:
86
- json_str = session.to_json()
87
- ui.download(json_str.encode("utf-8"), "graph.json")
88
-
89
- async def import_action() -> None:
90
- with ui.dialog() as dialog, ui.card():
91
- ui.label("Upload Graph JSON")
92
- ui.upload(on_upload=lambda e: [load_file(e), dialog.close()], auto_upload=True, max_files=1)
93
- dialog.open()
94
-
95
- def load_file(e) -> None:
96
- content = e.content.read().decode("utf-8")
97
- session.load_from_json(content)
98
- nodes = session.initial_vue_nodes()
99
- edges = session.initial_vue_edges()
100
- canvas.set_graph(nodes, edges)
101
-
102
- clear_button.on_click(clear_graph_action)
103
-
104
- # password dialog on top of everything if required
105
- if password_required:
106
- with ui.dialog() as dialog, ui.card():
107
- await dialog.props("persistent")
108
- ui.label("Enter password").classes("text-md font-bold mb-2")
109
- pwd_input = ui.input(label="Password").props('type="password"')
110
- error_label = ui.label("").style(
111
- "color: red; font-size: 0.8rem; min-height: 1rem"
112
- )
113
-
114
- def submit() -> None:
115
- value = pwd_input.value or ""
116
- if value == APP_PASSWORD:
117
- app.storage.user["auth"] = {
118
- "ok": True,
119
- "password_version": APP_PASSWORD,
120
- }
121
- error_label.text = ""
122
- dialog.close()
123
- else:
124
- error_label.text = "Wrong password"
125
-
126
- pwd_input.on("keydown.enter", lambda _: submit())
127
-
128
- with ui.row().classes("mt-2 items-center justify-end gap-2"):
129
- ui.button("Enter", on_click=submit).props("color='primary'")
130
-
131
- dialog.open()
132
 
133
 
134
  if __name__ in {"__main__", "__mp_main__"}:
 
3
  import os
4
 
5
  import uvicorn
6
+ from dotenv import load_dotenv
7
  from fastapi import FastAPI
8
+ from fastapi.staticfiles import StaticFiles
9
  from nicegui import ui, Client, app
 
 
10
  from nodes.session import GraphSession
11
+ from velai_app import VelaiApp
12
+
13
+
14
+ load_dotenv()
15
 
16
  fastapi_app = FastAPI()
17
  ui.run_with(fastapi_app, storage_secret="velai-storage-secret")
18
 
19
+ # Mount assets folder to serve static files (including favicon)
20
+ fastapi_app.mount("/assets", StaticFiles(directory="assets"), name="assets")
21
+
22
  APP_PASSWORD = os.getenv("VELAI_PASSWORD") or ""
23
 
24
 
25
+ async def require_password() -> None:
26
+ if not APP_PASSWORD:
27
+ return
28
+
29
+ auth = app.storage.user.get("auth") or {}
30
+ if auth.get("ok") and auth.get("password_version") == APP_PASSWORD:
31
+ return
32
+
33
+ # build awaitable dialog exactly like in the docs
34
+ with ui.dialog() as dialog, ui.card():
35
+ dialog.props("persistent")
36
+
37
+ ui.label("Enter password").classes("text-md font-bold mb-2")
38
+ pwd_input = ui.input(label="Password").props('type="password"')
39
+ error_label = ui.label("").style(
40
+ "color: red; font-size: 0.8rem; min-height: 1rem"
41
+ )
42
+
43
+ def submit() -> None:
44
+ value = pwd_input.value or ""
45
+ if value == APP_PASSWORD:
46
+ app.storage.user["auth"] = {
47
+ "ok": True,
48
+ "password_version": APP_PASSWORD,
49
+ }
50
+ error_label.text = ""
51
+ dialog.submit(True)
52
+ else:
53
+ error_label.text = "Wrong password"
54
+
55
+ pwd_input.on("keydown.enter", lambda _: submit())
56
+ with ui.row().classes("mt-2 items-center justify-end gap-2"):
57
+ ui.button("Enter", on_click=submit).props("color='primary'")
58
+
59
+ # this opens the dialog, waits for dialog.submit, then closes it
60
+ await dialog
61
+
62
+
63
  @ui.page("/")
64
  async def main(client: Client) -> None:
65
+ ui.page_title("VELAI")
66
+ ui.add_head_html('<link rel="icon" type="image/x-icon" href="/assets/favicon.ico">')
67
  ui.query(".nicegui-content").classes("p-0")
68
 
69
  await client.connected()
70
 
71
+ # stop here until the password is accepted (or no password is set)
72
+ await require_password()
 
 
 
 
73
 
74
  # one GraphSession per browser tab
75
+ session = GraphSession.create_default()
76
+
77
+ # build main application UI
78
+ app_ui = VelaiApp(session)
79
+ app_ui.build()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
 
82
  if __name__ in {"__main__", "__mp_main__"}:
nodes/controller.py CHANGED
@@ -6,6 +6,8 @@ from typing import Any
6
  from dataflow.connection import Connection
7
  from dataflow.enums import DataPortState
8
  from dataflow.graph import DataGraph
 
 
9
  from .text_data import TextDataNode
10
  from .text_to_image import TextToImageNode
11
 
@@ -37,11 +39,20 @@ class GraphController:
37
  edges = raw_payload or []
38
  if isinstance(edges, list):
39
  self._on_edges_delete(edges)
 
 
 
 
40
  elif event_type == "nodes_delete":
41
  nodes = raw_payload or []
42
  if isinstance(nodes, list):
43
  self._on_nodes_delete(nodes)
44
- # nodes_change, edges_change, other events can be added here later
 
 
 
 
 
45
 
46
  def _on_connect(self, payload: dict[str, Any]) -> None:
47
  source_handle = payload.get("sourceHandle") or ""
@@ -95,15 +106,22 @@ class GraphController:
95
  # datatype mismatch or capacity problems
96
  return
97
 
98
- def _on_node_moved(self, payload: dict[str, Any]) -> None:
99
- node_id = payload.get("id")
100
- if not node_id:
101
- return
 
 
 
 
 
 
 
102
  node = self.graph.nodes.get(node_id)
103
  if node is None:
104
  return
105
 
106
- pos = payload.get("position") or {}
107
  x = pos.get("x")
108
  y = pos.get("y")
109
 
@@ -118,6 +136,13 @@ class GraphController:
118
  except (TypeError, ValueError):
119
  pass
120
 
 
 
 
 
 
 
 
121
  def _on_node_field_changed(self, payload: dict[str, Any]) -> None:
122
  node_id = payload.get("id")
123
  field = payload.get("field")
@@ -136,6 +161,23 @@ class GraphController:
136
  port.value = "" if value is None else str(value)
137
  port.state = DataPortState.DIRTY
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  elif isinstance(node, TextToImageNode) and field == "image":
140
  node.image_src = "" if value is None else str(value)
141
 
@@ -169,12 +211,68 @@ class GraphController:
169
 
170
  self.graph.connections = [c for c in self.graph.connections if not should_remove(c)]
171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None:
173
  """Remove nodes and all their connections when Vue deletes them."""
174
  if not nodes:
175
  return
176
 
177
  node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  if not node_ids:
179
  return
180
 
 
6
  from dataflow.connection import Connection
7
  from dataflow.enums import DataPortState
8
  from dataflow.graph import DataGraph
9
+ from . import utils
10
+ from .image_data import ImageDataNode
11
  from .text_data import TextDataNode
12
  from .text_to_image import TextToImageNode
13
 
 
39
  edges = raw_payload or []
40
  if isinstance(edges, list):
41
  self._on_edges_delete(edges)
42
+ elif event_type == "edges_change":
43
+ changes = raw_payload or []
44
+ if isinstance(changes, list):
45
+ self._on_edges_change(changes)
46
  elif event_type == "nodes_delete":
47
  nodes = raw_payload or []
48
  if isinstance(nodes, list):
49
  self._on_nodes_delete(nodes)
50
+ elif event_type == "nodes_change":
51
+ changes = raw_payload or []
52
+ if isinstance(changes, list):
53
+ self._on_nodes_change(changes)
54
+ # other events (graph_cleared, create_node, etc.) are currently ignored
55
+ # because they do not affect the DataGraph directly here
56
 
57
  def _on_connect(self, payload: dict[str, Any]) -> None:
58
  source_handle = payload.get("sourceHandle") or ""
 
106
  # datatype mismatch or capacity problems
107
  return
108
 
109
+ # make the dataflow "dirty" because topology changed
110
+ if hasattr(start_port, "state"):
111
+ start_port.state = DataPortState.DIRTY
112
+
113
+ if hasattr(end_port, "state"):
114
+ end_port.state = DataPortState.DIRTY
115
+ # often you also want to clear the input value
116
+ if hasattr(end_port, "value"):
117
+ end_port.value = None
118
+
119
+ def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None:
120
  node = self.graph.nodes.get(node_id)
121
  if node is None:
122
  return
123
 
124
+ pos = position or {}
125
  x = pos.get("x")
126
  y = pos.get("y")
127
 
 
136
  except (TypeError, ValueError):
137
  pass
138
 
139
+ def _on_node_moved(self, payload: dict[str, Any]) -> None:
140
+ node_id = payload.get("id")
141
+ if not node_id:
142
+ return
143
+ position = payload.get("position") or {}
144
+ self._update_node_position(node_id, position)
145
+
146
  def _on_node_field_changed(self, payload: dict[str, Any]) -> None:
147
  node_id = payload.get("id")
148
  field = payload.get("field")
 
161
  port.value = "" if value is None else str(value)
162
  port.state = DataPortState.DIRTY
163
 
164
+ elif isinstance(node, ImageDataNode) and field == "image":
165
+ port = node.outputs.get("image") if node.outputs is not None else None
166
+ if port is not None:
167
+ if value:
168
+ # value is a data-uri string from the UI
169
+ try:
170
+ # We need to decode it to a PIL Image
171
+ img = utils.decode_image(str(value))
172
+ port.value = img
173
+ port.state = DataPortState.DIRTY
174
+ except Exception:
175
+ # invalid image data
176
+ port.value = None
177
+ else:
178
+ port.value = None
179
+ port.state = DataPortState.DIRTY
180
+
181
  elif isinstance(node, TextToImageNode) and field == "image":
182
  node.image_src = "" if value is None else str(value)
183
 
 
211
 
212
  self.graph.connections = [c for c in self.graph.connections if not should_remove(c)]
213
 
214
+ def _on_edges_change(self, changes: list[dict[str, Any]]) -> None:
215
+ """Handle generic edge changes.
216
+
217
+ Vue Flow sends EdgeChange objects.
218
+ """
219
+ if not changes:
220
+ return
221
+
222
+ edges_to_delete: list[dict[str, Any]] = []
223
+
224
+ for change in changes:
225
+ if not isinstance(change, dict):
226
+ continue
227
+
228
+ if change.get("type") == "remove":
229
+ # this is quite ugly, the edge should be sent directly or read from the list of edges from the graph
230
+ edge = change
231
+ if isinstance(edge, dict):
232
+ edges_to_delete.append(edge)
233
+
234
+ if edges_to_delete:
235
+ self._on_edges_delete(edges_to_delete)
236
+
237
  def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None:
238
  """Remove nodes and all their connections when Vue deletes them."""
239
  if not nodes:
240
  return
241
 
242
  node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")}
243
+ self._delete_nodes(node_ids)
244
+
245
+ def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None:
246
+ """Handle generic node changes.
247
+
248
+ Currently supports:
249
+ - type == "remove": delete node and related connections
250
+ - type == "position": update node position like 'node_moved'
251
+ Other change types (select, dimensions, etc.) do not affect the DataGraph.
252
+ """
253
+ if not changes:
254
+ return
255
+
256
+ node_ids_to_delete: set[str] = set()
257
+
258
+ for change in changes:
259
+ if not isinstance(change, dict):
260
+ continue
261
+
262
+ ctype = change.get("type")
263
+ node_id = change.get("id")
264
+
265
+ if ctype == "remove" and node_id:
266
+ node_ids_to_delete.add(node_id)
267
+ elif ctype == "position" and node_id:
268
+ # Vue Flow usually sends 'position' for logical node coordinates.
269
+ position = change.get("position") or change.get("positionAbsolute") or {}
270
+ self._update_node_position(node_id, position)
271
+
272
+ if node_ids_to_delete:
273
+ self._delete_nodes(node_ids_to_delete)
274
+
275
+ def _delete_nodes(self, node_ids: set[str]) -> None:
276
  if not node_ids:
277
  return
278
 
nodes/data_types.py CHANGED
@@ -17,7 +17,7 @@ ImageType = DataType(
17
  id=DataTypeId.IMAGE,
18
  name="Image",
19
  py_type=Image.Image,
20
- encode=utils.encode_image_png,
21
- decode=utils.decode_image_png,
22
  color="#a855f7",
23
  )
 
17
  id=DataTypeId.IMAGE,
18
  name="Image",
19
  py_type=Image.Image,
20
+ encode=utils.encode_image,
21
+ decode=utils.decode_image,
22
  color="#a855f7",
23
  )
nodes/image_data.py CHANGED
@@ -26,7 +26,7 @@ class ImageDataNode(NodeInstance):
26
  async def process(self) -> None:
27
  """Data nodes do not compute anything, they just expose their current value.
28
 
29
- The UI writes into the "text" output port. Here we simply ensure the port
30
  exists and treat the current value as the result.
31
  """
32
  port = self.outputs.get("image") if self.outputs is not None else None
@@ -51,7 +51,14 @@ class ImageDataNodeRenderable(VueNodeRenderable[ImageDataNode]):
51
  )
52
 
53
  port = node.outputs.get("image") if node.outputs is not None else None
54
- data.values["image"] = "" if port is None or port.value is None else str(port.value)
 
 
 
 
 
 
 
55
 
56
  data.executable = False
57
  return data
 
26
  async def process(self) -> None:
27
  """Data nodes do not compute anything, they just expose their current value.
28
 
29
+ The UI writes into the "image" output port. Here we simply ensure the port
30
  exists and treat the current value as the result.
31
  """
32
  port = self.outputs.get("image") if self.outputs is not None else None
 
51
  )
52
 
53
  port = node.outputs.get("image") if node.outputs is not None else None
54
+ val = ""
55
+ if port is not None and port.value is not None:
56
+ if port.schema.dtype.encode:
57
+ val = port.schema.dtype.encode(port.value)
58
+ else:
59
+ val = str(port.value)
60
+
61
+ data.values["image"] = val
62
 
63
  data.executable = False
64
  return data
nodes/runtime.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
 
4
 
5
  from dataflow.graph import DataGraph
6
  from dataflow.nodes_base import NodeInstance
@@ -10,16 +11,48 @@ from .text_to_image import TextToImageNode
10
 
11
  @dataclass(slots=True)
12
  class GraphRuntime:
13
- """Bridge between the DataGraph and the Vue canvas.
14
-
15
- Keeps execution logic here so nodes can be executed from UI or from code.
16
- """
17
-
18
  graph: DataGraph
19
  canvas: VueFlowCanvas
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  async def execute_node(self, node: NodeInstance | str) -> None:
22
- """Execute a node and keep the canvas in sync with its state."""
 
 
 
 
 
 
 
 
 
23
  if isinstance(node, str):
24
  node_id = node
25
  node_obj = self.graph.nodes.get(node_id)
@@ -29,26 +62,101 @@ class GraphRuntime:
29
  node_obj = node
30
  node_id = node.node_id
31
 
32
- # show spinner
33
- self.canvas.set_node_processing(node_id, True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  try:
35
- # run the dataflow execution chain (this will also execute upstream nodes)
36
  print(f"Runtime: Executing {node_id}...")
 
 
 
 
 
 
 
 
 
37
  await self.graph.execute(node_obj)
38
- await self._sync_node_to_ui(node_obj)
 
 
 
 
 
 
 
 
 
39
  except Exception as e:
40
  print(f"Runtime execution failed: {e}")
41
  import traceback
42
  traceback.print_exc()
43
  finally:
44
- # hide spinner
45
- self.canvas.set_node_processing(node_id, False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  async def _sync_node_to_ui(self, node: NodeInstance) -> None:
48
  """Push relevant node state back to the Vue nodes."""
49
  if isinstance(node, TextToImageNode):
50
  image_src = "" if node.image_src is None else str(node.image_src)
51
- values = {"image": image_src}
52
- if node.error:
53
- values["error"] = node.error
 
54
  self.canvas.update_node_values(node.node_id, values)
 
 
1
  from __future__ import annotations
2
 
3
  from dataclasses import dataclass
4
+ from typing import Any
5
 
6
  from dataflow.graph import DataGraph
7
  from dataflow.nodes_base import NodeInstance
 
11
 
12
  @dataclass(slots=True)
13
  class GraphRuntime:
 
 
 
 
 
14
  graph: DataGraph
15
  canvas: VueFlowCanvas
16
 
17
+ def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]:
18
+ """Return all upstream nodes that belong to the chain of root.
19
+
20
+ Order is upstream first, root last. This is only for UI purposes
21
+ (spinners, progress, syncing), not to force execution.
22
+ """
23
+ result: list[NodeInstance] = []
24
+ visited: set[str] = set()
25
+ connections = getattr(self.graph, "connections", []) or []
26
+
27
+ def visit(node: NodeInstance) -> None:
28
+ node_id = getattr(node, "node_id", None)
29
+ if node_id is None or node_id in visited:
30
+ return
31
+ visited.add(node_id)
32
+
33
+ for conn in connections:
34
+ try:
35
+ if conn.end_node is node:
36
+ visit(conn.start_node)
37
+ except AttributeError:
38
+ continue
39
+
40
+ result.append(node)
41
+
42
+ visit(root)
43
+ return result
44
+
45
  async def execute_node(self, node: NodeInstance | str) -> None:
46
+ """Execute a node and keep the canvas in sync with its state.
47
+
48
+ Rules:
49
+ - Clicked node is always reset and re run.
50
+ - Upstream nodes are never reset here and may reuse cached data.
51
+ - Which nodes actually execute is decided by DataGraph and node logic.
52
+ - As soon as a node finishes, it is synced to the UI.
53
+ """
54
+ import asyncio
55
+
56
  if isinstance(node, str):
57
  node_id = node
58
  node_obj = self.graph.nodes.get(node_id)
 
62
  node_obj = node
63
  node_id = node.node_id
64
 
65
+ execution_chain = self._get_execution_chain(node_obj)
66
+
67
+ # show spinner on all nodes in the chain
68
+ for n in execution_chain:
69
+ nid = getattr(n, "node_id", None)
70
+ if nid:
71
+ self.canvas.set_node_processing(nid, True)
72
+
73
+ # progress polling for nodes that expose progress_value
74
+ stop_progress = False
75
+ progress_tasks: list[asyncio.Task] = []
76
+
77
+ async def progress_updater(n: NodeInstance, nid: str) -> None:
78
+ last_value: Any = None
79
+ while not stop_progress:
80
+ await asyncio.sleep(0.1)
81
+ if not hasattr(n, "progress_value"):
82
+ continue
83
+ current = getattr(n, "progress_value", None)
84
+ message = getattr(n, "progress_message", None)
85
+ if current is None:
86
+ continue
87
+ if current != last_value:
88
+ self.canvas.update_node_progress(nid, current, message)
89
+ last_value = current
90
+
91
+ for n in execution_chain:
92
+ nid = getattr(n, "node_id", None)
93
+ if nid and hasattr(n, "progress_value"):
94
+ progress_tasks.append(asyncio.create_task(progress_updater(n, nid)))
95
+
96
+ # callback from DataGraph after each node is executed
97
+ async def on_node_executed(executed_node: NodeInstance) -> None:
98
+ # Only nodes that actually ran will call this.
99
+ await self._sync_node_to_ui(executed_node)
100
+
101
+ # save previous callback so we can restore it
102
+ previous_cb = getattr(self.graph, "_on_node_executed", None)
103
+
104
  try:
 
105
  print(f"Runtime: Executing {node_id}...")
106
+
107
+ # clicked node is always reset, upstream nodes are not
108
+ if hasattr(node_obj, "reset_node"):
109
+ node_obj.reset_node()
110
+
111
+ # register our per node callback
112
+ self.graph.set_on_node_executed(on_node_executed)
113
+
114
+ # let DataGraph drive which nodes actually execute
115
  await self.graph.execute(node_obj)
116
+
117
+ # one more sync for all nodes in the chain, in case some did not run
118
+ for n in execution_chain:
119
+ await self._sync_node_to_ui(n)
120
+
121
+ # nice "complete" flash for the clicked node if it has progress
122
+ if hasattr(node_obj, "progress_value"):
123
+ self.canvas.update_node_progress(node_id, 1.0, "Complete")
124
+ await asyncio.sleep(0.3)
125
+
126
  except Exception as e:
127
  print(f"Runtime execution failed: {e}")
128
  import traceback
129
  traceback.print_exc()
130
  finally:
131
+ # restore previous graph callback
132
+ self.graph.set_on_node_executed(previous_cb)
133
+
134
+ # stop progress updaters
135
+ stop_progress = True
136
+ for t in progress_tasks:
137
+ t.cancel()
138
+ try:
139
+ await t
140
+ except asyncio.CancelledError:
141
+ pass
142
+
143
+ # hide spinner on all nodes in the chain
144
+ for n in execution_chain:
145
+ nid = getattr(n, "node_id", None)
146
+ if nid:
147
+ self.canvas.set_node_processing(nid, False)
148
+
149
+ # reset progress on the clicked node
150
+ if hasattr(node_obj, "progress_value"):
151
+ self.canvas.update_node_progress(node_id, 0.0, None)
152
 
153
  async def _sync_node_to_ui(self, node: NodeInstance) -> None:
154
  """Push relevant node state back to the Vue nodes."""
155
  if isinstance(node, TextToImageNode):
156
  image_src = "" if node.image_src is None else str(node.image_src)
157
+ values: dict[str, Any] = {
158
+ "image": image_src,
159
+ "error": node.error or None,
160
+ }
161
  self.canvas.update_node_values(node.node_id, values)
162
+ # add other node types here as needed
nodes/session.py CHANGED
@@ -1,10 +1,14 @@
1
  from __future__ import annotations
2
 
 
3
  import itertools
4
  import json
5
  from dataclasses import dataclass, field
6
  from typing import Any
7
 
 
 
 
8
  from dataflow.codecs import connection_to_vueflow_edge
9
  from dataflow.connection import Connection
10
  from dataflow.enums import NodeKind
@@ -64,7 +68,11 @@ class GraphSession:
64
  renderer=renderer,
65
  controller=controller,
66
  )
67
- session._build_initial_graph()
 
 
 
 
68
  return session
69
 
70
  def _build_initial_graph(self) -> None:
@@ -101,6 +109,19 @@ class GraphSession:
101
  },
102
  ]
103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  def attach_canvas(self, canvas: VueFlowCanvas) -> None:
105
  """Attach a VueFlowCanvas to this session."""
106
  self.runtime = GraphRuntime(graph=self.graph, canvas=canvas)
@@ -248,6 +269,202 @@ class GraphSession:
248
  self.graph.add_node(node_obj)
249
  return self.renderer.to_vue_node(node_obj)
250
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
251
  async def handle_ui_event(
252
  self, event: dict[str, Any], canvas: VueFlowCanvas
253
  ) -> None:
@@ -255,6 +472,8 @@ class GraphSession:
255
  event_type = event.get("type")
256
  payload = event.get("payload")
257
 
 
 
258
  if event_type == "execute_node":
259
  payload_dict = payload or {}
260
  node_id = payload_dict.get("id")
@@ -266,15 +485,108 @@ class GraphSession:
266
  else:
267
  self.runtime.canvas = canvas
268
 
269
- await self.runtime.execute_node(node_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  elif event_type == "create_node":
272
  payload_dict = payload or {}
273
  kind_value = payload_dict.get("kind")
274
  position = payload_dict.get("position") or {}
 
 
275
  vue_node = self.create_node(kind_value, position)
276
  if vue_node is not None:
277
  canvas.add_node(vue_node)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
 
279
  else:
280
  self.controller.handle_event(event)
 
 
 
 
 
 
1
  from __future__ import annotations
2
 
3
+ import io
4
  import itertools
5
  import json
6
  from dataclasses import dataclass, field
7
  from typing import Any
8
 
9
+ from PIL import Image
10
+ from nicegui import app, ui
11
+
12
  from dataflow.codecs import connection_to_vueflow_edge
13
  from dataflow.connection import Connection
14
  from dataflow.enums import NodeKind
 
68
  renderer=renderer,
69
  controller=controller,
70
  )
71
+
72
+ # Try to restore from storage, otherwise build initial graph
73
+ if not session.restore_from_storage():
74
+ session._build_initial_graph()
75
+
76
  return session
77
 
78
  def _build_initial_graph(self) -> None:
 
109
  },
110
  ]
111
 
112
+ def save_to_storage(self) -> None:
113
+ """Persist current graph state to tab storage."""
114
+ json_str = self.to_json()
115
+ app.storage.tab["graph_json"] = json_str
116
+
117
+ def restore_from_storage(self) -> bool:
118
+ """Try to load graph from tab storage. Returns True if successful."""
119
+ json_str = app.storage.tab.get("graph_json")
120
+ if json_str:
121
+ self.load_from_json(json_str)
122
+ return True
123
+ return False
124
+
125
  def attach_canvas(self, canvas: VueFlowCanvas) -> None:
126
  """Attach a VueFlowCanvas to this session."""
127
  self.runtime = GraphRuntime(graph=self.graph, canvas=canvas)
 
269
  self.graph.add_node(node_obj)
270
  return self.renderer.to_vue_node(node_obj)
271
 
272
+ def _auto_connect_new_node(
273
+ self, vue_node: dict[str, Any], pending_connection: dict[str, Any], canvas: VueFlowCanvas
274
+ ) -> None:
275
+ """Automatically connect a newly created node based on pending connection info."""
276
+ if not pending_connection or not vue_node:
277
+ return
278
+
279
+ new_node_id = vue_node.get("id")
280
+ if not new_node_id:
281
+ return
282
+
283
+ new_node_obj = self.graph.nodes.get(new_node_id)
284
+ if not new_node_obj:
285
+ return
286
+
287
+ # Extract pending connection info
288
+ source_node_id = pending_connection.get("nodeId")
289
+ source_handle_id = pending_connection.get("handleId")
290
+ handle_type = pending_connection.get("handleType")
291
+
292
+ if not source_node_id or not source_handle_id or not handle_type:
293
+ return
294
+
295
+ source_node = self.graph.nodes.get(source_node_id)
296
+ if not source_node:
297
+ return
298
+
299
+ # Parse handle ID to get port name
300
+ source_port_name = source_handle_id.split(":")[-1] if ":" in source_handle_id else source_handle_id
301
+
302
+ # Determine connection direction and find compatible port
303
+ if handle_type == "source":
304
+ # User dragged from an output, connect to first compatible input of new node
305
+ source_port = source_node.outputs.get(source_port_name)
306
+ if not source_port:
307
+ return
308
+
309
+ # Find first compatible input port on new node
310
+ target_port = None
311
+ target_port_name = None
312
+ for port_name, port in new_node_obj.inputs.items():
313
+ if port.schema.dtype.id == source_port.schema.dtype.id:
314
+ target_port = port
315
+ target_port_name = port_name
316
+ break
317
+
318
+ if target_port:
319
+ # Create connection event
320
+ connection_params = {
321
+ "source": source_node_id,
322
+ "target": new_node_id,
323
+ "sourceHandle": source_handle_id,
324
+ "targetHandle": f"{new_node_id}:{target_port_name}"
325
+ }
326
+
327
+ # Emit connect event to controller
328
+ self.controller.handle_event({
329
+ "type": "connect",
330
+ "payload": connection_params
331
+ })
332
+
333
+ # Add edge to canvas
334
+ edge_id = f"{source_node_id}-{source_handle_id}->{new_node_id}-{target_port_name}"
335
+ canvas.add_edge({
336
+ "id": edge_id,
337
+ "source": source_node_id,
338
+ "target": new_node_id,
339
+ "sourceHandle": source_handle_id,
340
+ "targetHandle": f"{new_node_id}:{target_port_name}"
341
+ })
342
+
343
+ elif handle_type == "target":
344
+ # User dragged from an input, connect from first compatible output of new node
345
+ target_port = source_node.inputs.get(source_port_name)
346
+ if not target_port:
347
+ return
348
+
349
+ # Find first compatible output port on new node
350
+ source_port = None
351
+ source_port_name_new = None
352
+ for port_name, port in new_node_obj.outputs.items():
353
+ if port.schema.dtype.id == target_port.schema.dtype.id:
354
+ source_port = port
355
+ source_port_name_new = port_name
356
+ break
357
+
358
+ if source_port:
359
+ # Create connection event
360
+ connection_params = {
361
+ "source": new_node_id,
362
+ "target": source_node_id,
363
+ "sourceHandle": f"{new_node_id}:{source_port_name_new}",
364
+ "targetHandle": source_handle_id
365
+ }
366
+
367
+ # Emit connect event to controller
368
+ self.controller.handle_event({
369
+ "type": "connect",
370
+ "payload": connection_params
371
+ })
372
+
373
+ # Add edge to canvas
374
+ edge_id = f"{new_node_id}-{source_port_name_new}->{source_node_id}-{source_handle_id}"
375
+ canvas.add_edge({
376
+ "id": edge_id,
377
+ "source": new_node_id,
378
+ "target": source_node_id,
379
+ "sourceHandle": f"{new_node_id}:{source_port_name_new}",
380
+ "targetHandle": source_handle_id
381
+ })
382
+
383
+ def duplicate_node(
384
+ self, source_node_id: str, position: dict[str, Any] | None = None
385
+ ) -> dict[str, Any] | None:
386
+ """Duplicate an existing node and return its Vue node dict."""
387
+ # Get the source node from the graph
388
+ source_node = self.graph.nodes.get(source_node_id)
389
+ if source_node is None:
390
+ return None
391
+
392
+ # Create a new node of the same type
393
+ node_id = self.next_ui_node_id()
394
+ try:
395
+ node_obj = self.registry.create(source_node.node_type.kind, node_id)
396
+ except KeyError:
397
+ return None
398
+
399
+ # Set position
400
+ if position:
401
+ x = position.get("x")
402
+ y = position.get("y")
403
+ if x is not None:
404
+ try:
405
+ node_obj.x = float(x)
406
+ except (TypeError, ValueError):
407
+ pass
408
+ if y is not None:
409
+ try:
410
+ node_obj.y = float(y)
411
+ except (TypeError, ValueError):
412
+ pass
413
+ else:
414
+ # Offset from source node if no position specified
415
+ node_obj.x = source_node.x + 50
416
+ node_obj.y = source_node.y + 50
417
+
418
+ # Copy input port values (but not outputs - those should be recomputed)
419
+ for port_name, source_port in source_node.inputs.items():
420
+ if port_name in node_obj.inputs:
421
+ # Copy the value if it's set
422
+ if source_port.value is not None:
423
+ node_obj.inputs[port_name].value = source_port.value
424
+
425
+ # Copy content based on node type
426
+
427
+ # Copy text content for TextDataNode
428
+ if isinstance(source_node, TextDataNode) and isinstance(node_obj, TextDataNode):
429
+ source_port = source_node.outputs.get("text") if source_node.outputs else None
430
+ if source_port and source_port.value is not None:
431
+ target_port = node_obj.outputs.get("text") if node_obj.outputs else None
432
+ if target_port:
433
+ target_port.value = str(source_port.value)
434
+
435
+ # Copy image content for ImageDataNode
436
+ elif isinstance(source_node, ImageDataNode) and isinstance(node_obj, ImageDataNode):
437
+ source_port = source_node.outputs.get("image") if source_node.outputs else None
438
+ if source_port and source_port.value is not None:
439
+ target_port = node_obj.outputs.get("image") if node_obj.outputs else None
440
+ if target_port:
441
+ # Copy PIL Image
442
+ if isinstance(source_port.value, Image.Image):
443
+ target_port.value = source_port.value.copy()
444
+ else:
445
+ target_port.value = source_port.value
446
+
447
+ # Copy generated image for TextToImageNode
448
+ elif isinstance(source_node, TextToImageNode) and isinstance(node_obj, TextToImageNode):
449
+ # Copy image_src
450
+ if source_node.image_src:
451
+ node_obj.image_src = source_node.image_src
452
+ # Copy decoded_image
453
+ if source_node.decoded_image is not None:
454
+ node_obj.decoded_image = source_node.decoded_image.copy()
455
+ # copy the output port value if it exists
456
+ source_port = source_node.outputs.get("image") if source_node.outputs else None
457
+ if source_port and source_port.value is not None:
458
+ target_port = node_obj.outputs.get("image") if node_obj.outputs else None
459
+ if target_port:
460
+ if isinstance(source_port.value, Image.Image):
461
+ target_port.value = source_port.value.copy()
462
+ else:
463
+ target_port.value = source_port.value
464
+
465
+ self.graph.add_node(node_obj)
466
+ return self.renderer.to_vue_node(node_obj)
467
+
468
  async def handle_ui_event(
469
  self, event: dict[str, Any], canvas: VueFlowCanvas
470
  ) -> None:
 
472
  event_type = event.get("type")
473
  payload = event.get("payload")
474
 
475
+ save_needed = False
476
+
477
  if event_type == "execute_node":
478
  payload_dict = payload or {}
479
  node_id = payload_dict.get("id")
 
485
  else:
486
  self.runtime.canvas = canvas
487
 
488
+ n = ui.notification(timeout=None)
489
+ n.message = f"Generating"
490
+ n.spinner = True
491
+
492
+ try:
493
+ await self.runtime.execute_node(node_id)
494
+
495
+ n.message = "Done!"
496
+ n.type = "positive"
497
+ n.spinner = False
498
+ except Exception as ex:
499
+ n.message = f"Error: {ex}"
500
+ n.type = "negative"
501
+ n.spinner = False
502
+
503
+ n.dismiss()
504
+
505
+ save_needed = True
506
+
507
+ elif event_type == "reset_node":
508
+ payload_dict = payload or {}
509
+ node_id = payload_dict.get("id")
510
+ if not node_id:
511
+ return
512
+
513
+ node = self.graph.nodes.get(node_id)
514
+ if node is None:
515
+ return
516
+
517
+ with ui.dialog() as dialog, ui.card():
518
+ ui.label("Are you sure?")
519
+ with ui.row():
520
+ ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
521
+ ui.button("No", on_click=lambda: dialog.submit("No"))
522
+ result = await dialog
523
+
524
+ if result == "No":
525
+ return
526
+
527
+ # only special case TextToImageNode for now
528
+ if isinstance(node, TextToImageNode):
529
+ node.reset_node()
530
+
531
+ # sync cleared state to UI
532
+ canvas.update_node_values(
533
+ node_id, {
534
+ "image": "",
535
+ "error": None,
536
+ })
537
+
538
+ save_needed = True
539
+ ui.notify(f"Node has been reset!", type="positive")
540
+
541
+ elif event_type == "download_node":
542
+ payload_dict = payload or {}
543
+ node_id = payload_dict.get("id")
544
+ if not node_id:
545
+ return
546
+
547
+ node = self.graph.nodes.get(node_id)
548
+ if node is None:
549
+ return
550
+
551
+ if isinstance(node, TextToImageNode):
552
+ buf = io.BytesIO()
553
+ node.decoded_image.save(buf, format="PNG")
554
+ buf.seek(0)
555
+
556
+ png_bytes = buf.getvalue()
557
+ ui.download(png_bytes, filename="generated.png", media_type="image/png")
558
+ ui.notify("Image downloaded!", type="positive")
559
 
560
  elif event_type == "create_node":
561
  payload_dict = payload or {}
562
  kind_value = payload_dict.get("kind")
563
  position = payload_dict.get("position") or {}
564
+ pending_connection = payload_dict.get("pendingConnection")
565
+
566
  vue_node = self.create_node(kind_value, position)
567
  if vue_node is not None:
568
  canvas.add_node(vue_node)
569
+ save_needed = True
570
+
571
+ # If there's a pending connection, auto-connect
572
+ if pending_connection:
573
+ self._auto_connect_new_node(vue_node, pending_connection, canvas)
574
+
575
+
576
+ elif event_type == "duplicate_node":
577
+ payload_dict = payload or {}
578
+ source_node_id = payload_dict.get("sourceNodeId")
579
+ position = payload_dict.get("position") or {}
580
+ if source_node_id:
581
+ vue_node = self.duplicate_node(source_node_id, position)
582
+ if vue_node is not None:
583
+ canvas.add_node(vue_node)
584
+ save_needed = True
585
 
586
  else:
587
  self.controller.handle_event(event)
588
+ # Assume any controller action (move, connect, delete, field change) modifies state
589
+ save_needed = True
590
+
591
+ if save_needed:
592
+ self.save_to_storage()
nodes/text_to_image.py CHANGED
@@ -1,9 +1,15 @@
1
  from __future__ import annotations
2
 
 
 
3
  import typing
4
  from dataclasses import dataclass
 
5
 
6
- from dataflow.enums import PortDirection, NodeKind, ConnectMultiplicity
 
 
 
7
  from dataflow.nodes_base import NodeType, NodeInstance
8
  from dataflow.ports import PortSchema
9
  from services.image.ImageGenerator import ImageGenerator
@@ -13,14 +19,23 @@ from . import utils
13
  from .data_types import ImageType, TextType
14
  from .vue_nodes import VueNodeRenderable, VueNodeData
15
 
 
 
 
 
 
16
  TextToImageNodeType = NodeType(
17
  kind=NodeKind.TEXT_TO_IMAGE,
18
  display_name="Image Generation",
19
  inputs=[
20
  PortSchema(name="text", dtype=TextType, direction=PortDirection.INPUT,
21
  multiplicity=ConnectMultiplicity.MULTIPLE, capacity=3),
22
- PortSchema(name="image", dtype=ImageType, direction=PortDirection.INPUT,
23
- multiplicity=ConnectMultiplicity.MULTIPLE, capacity=3)
 
 
 
 
24
  ],
25
  outputs=[
26
  PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT)
@@ -32,16 +47,62 @@ TextToImageNodeType = NodeType(
32
  class TextToImageNode(NodeInstance):
33
  # string that UI can bind to; may be url or base64 data
34
  image_src: str | None = None
 
35
  error: str | None = None
 
 
36
 
37
  async def process(self) -> None:
 
38
  self.error = None
39
 
40
  text_input = self.inputs["text"]
41
- image_input = self.inputs["image"]
42
  image_output = self.outputs["image"]
43
 
44
- # Handle multiple inputs (concatenation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  prompt = ""
46
  if isinstance(text_input.value, list):
47
  prompt = "\n\n".join([str(v) for v in text_input.value if v])
@@ -52,53 +113,114 @@ class TextToImageNode(NodeInstance):
52
  self.error = "Could not generate image! No text input."
53
  return
54
 
55
- # optional image inputs
56
- images = None
57
- if isinstance(image_input.value, list):
58
- images = image_input.value
 
 
 
59
 
60
  registry = get_registry()
61
- image_service = typing.cast(ImageGenerator, registry.create(TaskType.IMAGE, "google_gemini_image"))
62
 
63
  def on_progress(value: float, message: str | None):
64
- # In a fuller implementation, we would emit events to UI here
65
- pass
 
66
 
67
  try:
68
- import asyncio
69
- from functools import partial
70
-
71
  # Run non-blocking
72
  loop = asyncio.get_running_loop()
73
  result = await loop.run_in_executor(
74
  None,
75
- partial(image_service.generate, prompt, images=images, progress=on_progress)
76
  )
77
 
78
- url = utils.encode_image_png(result.image)
79
  self.image_src = url
80
 
 
81
  image_output.value = result.image
82
  except Exception as e:
 
83
  self.error = f"Error: {str(e)}"
84
- self.image_src = None
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
 
87
  class TextToImageNodeRenderable(VueNodeRenderable[TextToImageNode]):
88
  def to_vue_node_data(self, node: TextToImageNode) -> VueNodeData:
89
  data = VueNodeData()
90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
  data.fields.append(
92
  {
93
  "name": "image",
94
  "kind": "image_result",
95
  "label": "Image",
96
  "editable": False,
97
- "placeholder": "No image yet",
98
  }
99
  )
100
 
101
- data.values["image"] = "" if node.image_src is None else str(node.image_src)
 
 
 
 
 
 
 
 
 
 
102
  if node.error:
103
  data.values["error"] = node.error
104
 
 
1
  from __future__ import annotations
2
 
3
+ import asyncio
4
+ import os
5
  import typing
6
  from dataclasses import dataclass
7
+ from functools import partial
8
 
9
+ from PIL import Image
10
+
11
+ from dataflow.codecs import _schema_json
12
+ from dataflow.enums import PortDirection, NodeKind, ConnectMultiplicity, DataPortState
13
  from dataflow.nodes_base import NodeType, NodeInstance
14
  from dataflow.ports import PortSchema
15
  from services.image.ImageGenerator import ImageGenerator
 
19
  from .data_types import ImageType, TextType
20
  from .vue_nodes import VueNodeRenderable, VueNodeData
21
 
22
+ DEFAULT_IMAGE_MODEL_SERVICE_ID = os.getenv("DEFAULT_IMAGE_MODEL_SERVICE_ID", "fal_ai_image_nano_banana")
23
+
24
+ # number of image inputs
25
+ NUM_IMAGE_INPUTS = 3
26
+
27
  TextToImageNodeType = NodeType(
28
  kind=NodeKind.TEXT_TO_IMAGE,
29
  display_name="Image Generation",
30
  inputs=[
31
  PortSchema(name="text", dtype=TextType, direction=PortDirection.INPUT,
32
  multiplicity=ConnectMultiplicity.MULTIPLE, capacity=3),
33
+ PortSchema(name="image1", dtype=ImageType, direction=PortDirection.INPUT,
34
+ multiplicity=ConnectMultiplicity.SINGLE, capacity=1),
35
+ PortSchema(name="image2", dtype=ImageType, direction=PortDirection.INPUT,
36
+ multiplicity=ConnectMultiplicity.SINGLE, capacity=1),
37
+ PortSchema(name="image3", dtype=ImageType, direction=PortDirection.INPUT,
38
+ multiplicity=ConnectMultiplicity.SINGLE, capacity=1),
39
  ],
40
  outputs=[
41
  PortSchema(name="image", dtype=ImageType, direction=PortDirection.OUTPUT)
 
47
  class TextToImageNode(NodeInstance):
48
  # string that UI can bind to; may be url or base64 data
49
  image_src: str | None = None
50
+ decoded_image: Image.Image | None = None
51
  error: str | None = None
52
+ progress_value: float = 0.0
53
+ progress_message: str | None = None
54
 
55
  async def process(self) -> None:
56
+ # reset error message (but is not directly updated)
57
  self.error = None
58
 
59
  text_input = self.inputs["text"]
 
60
  image_output = self.outputs["image"]
61
 
62
+ print(f"\n[DEBUG TextToImageNode] ===== STARTING PROCESS =====")
63
+ print(f"[DEBUG] Text input value: {text_input.value is not None}")
64
+ print(f"[DEBUG] Text input state: {text_input.state}")
65
+
66
+ # Collect images from ordered input ports (image1, image2, image3, ...)
67
+ # Only include images that are actually connected (not None)
68
+ images: list[Image.Image] = []
69
+ for i in range(1, NUM_IMAGE_INPUTS + 1):
70
+ image_port_name = f"image{i}"
71
+ if image_port_name in self.inputs:
72
+ image_port = self.inputs[image_port_name]
73
+ print(
74
+ f"[DEBUG] Port {image_port_name}: value={type(image_port.value).__name__ if image_port.value is not None else 'None'}, state={image_port.state}")
75
+ if image_port.value is not None:
76
+ if isinstance(image_port.value, Image.Image):
77
+ images.append(image_port.value)
78
+ print(f"[DEBUG] -> Added PIL Image directly (size: {image_port.value.size})")
79
+ else:
80
+ try:
81
+ decoded = utils.decode_image(str(image_port.value))
82
+ images.append(decoded)
83
+ print(f"[DEBUG] -> Decoded and added image (size: {decoded.size})")
84
+ except Exception as e:
85
+ print(f"[DEBUG] -> FAILED to decode: {e}")
86
+ pass
87
+
88
+ # Check if inputs are clean - only consider ports that have actual values (are connected)
89
+ all_inputs_clean = text_input.state == DataPortState.CLEAN
90
+ for i in range(1, NUM_IMAGE_INPUTS + 1):
91
+ image_port_name = f"image{i}"
92
+ if image_port_name in self.inputs:
93
+ image_port = self.inputs[image_port_name]
94
+ # Only check state if the port has a value (is connected)
95
+ if image_port.value is not None and image_port.state != DataPortState.CLEAN:
96
+ all_inputs_clean = False
97
+ break
98
+
99
+ if self.image_src is not None and all_inputs_clean:
100
+ if self.decoded_image is None:
101
+ self.decoded_image = utils.decode_image(self.image_src)
102
+ image_output.value = self.decoded_image
103
+ return
104
+
105
+ # Handle multiple text inputs (concatenation)
106
  prompt = ""
107
  if isinstance(text_input.value, list):
108
  prompt = "\n\n".join([str(v) for v in text_input.value if v])
 
113
  self.error = "Could not generate image! No text input."
114
  return
115
 
116
+ # Pass ordered images list (empty list if none connected)
117
+ # image1 is first, image2 is second, etc.
118
+ images_list = images if images else None
119
+
120
+ print(f"[DEBUG] Collected {len(images)} total images")
121
+ print(f"[DEBUG] images_list is None: {images_list is None}")
122
+ print(f"[DEBUG] Will call generator with prompt and {len(images_list) if images_list else 0} images")
123
 
124
  registry = get_registry()
125
+ image_service = typing.cast(ImageGenerator, registry.create(TaskType.IMAGE, DEFAULT_IMAGE_MODEL_SERVICE_ID))
126
 
127
  def on_progress(value: float, message: str | None):
128
+ # Store progress in the node so runtime can update UI
129
+ self.progress_value = value
130
+ self.progress_message = message
131
 
132
  try:
 
 
 
133
  # Run non-blocking
134
  loop = asyncio.get_running_loop()
135
  result = await loop.run_in_executor(
136
  None,
137
+ partial(image_service.generate, prompt, images=images_list, progress=on_progress)
138
  )
139
 
140
+ url = utils.encode_image(result.image)
141
  self.image_src = url
142
 
143
+ self.decoded_image = result.image
144
  image_output.value = result.image
145
  except Exception as e:
146
+ self.reset_node()
147
  self.error = f"Error: {str(e)}"
148
+
149
+ def reset_node(self) -> None:
150
+ self.image_src = None
151
+ self.decoded_image = None
152
+ self.error = None
153
+ self.progress_value = 0.0
154
+ self.progress_message = None
155
+
156
+ # clear output port value and mark it dirty
157
+ out = self.outputs.get("image") if self.outputs is not None else None
158
+ if out is not None:
159
+ out.value = None
160
+ out.state = DataPortState.DIRTY
161
 
162
 
163
  class TextToImageNodeRenderable(VueNodeRenderable[TextToImageNode]):
164
  def to_vue_node_data(self, node: TextToImageNode) -> VueNodeData:
165
  data = VueNodeData()
166
 
167
+ # Create custom inputs with even spacing
168
+ custom_inputs = []
169
+
170
+ # Collect all input ports in order
171
+ input_ports = []
172
+ for port_schema in node.node_type.inputs:
173
+ port_dict = _schema_json(port_schema)
174
+ input_ports.append((port_schema.name, port_dict))
175
+
176
+ # Evenly space all inputs across the node height
177
+ num_inputs = len(input_ports)
178
+ if num_inputs > 0:
179
+ # Start the first input at 30% of the node height, end at 85%, evenly distributed
180
+ # It's eyeballed, so it might not be perfect!
181
+ start_percent = 30
182
+ end_percent = 85
183
+ if num_inputs == 1:
184
+ positions = [50] # Center if only one
185
+ else:
186
+ step = (end_percent - start_percent) / (num_inputs - 1)
187
+ positions = [start_percent + i * step for i in range(num_inputs)]
188
+
189
+ # Assign positions to ports
190
+ for idx, (port_name, port_dict) in enumerate(input_ports):
191
+ port_dict["top"] = positions[idx]
192
+ custom_inputs.append(port_dict)
193
+ else:
194
+ # Fallback: use default positioning
195
+ for _, port_dict in input_ports:
196
+ custom_inputs.append(port_dict)
197
+
198
+ data.inputs = custom_inputs
199
+
200
+ # Added min-height (px) to make the node taller
201
+ data.min_height = 100
202
+
203
  data.fields.append(
204
  {
205
  "name": "image",
206
  "kind": "image_result",
207
  "label": "Image",
208
  "editable": False,
209
+ "placeholder": "No image",
210
  }
211
  )
212
 
213
+ # Use image_src if available (which is already encoded url)
214
+ val = ""
215
+ if node.image_src:
216
+ val = node.image_src
217
+ else:
218
+ port = node.outputs.get("image") if node.outputs is not None else None
219
+ if port is not None and port.value is not None:
220
+ # We know ImageType has an encoder
221
+ val = utils.encode_image(port.value)
222
+
223
+ data.values["image"] = val
224
  if node.error:
225
  data.values["error"] = node.error
226
 
nodes/utils.py CHANGED
@@ -4,18 +4,56 @@ import io
4
  from PIL import Image
5
 
6
 
7
- def encode_image_png(img: Image.Image) -> str:
 
 
 
 
 
 
 
 
 
8
  buffer = io.BytesIO()
9
- img.save(buffer, format="PNG")
 
 
 
 
10
  raw = buffer.getvalue()
11
  b64 = base64.b64encode(raw).decode("ascii")
12
- return f"data:image/png;base64,{b64}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
 
 
 
14
 
15
- def decode_image_png(data: str) -> Image.Image:
16
- prefix = "data:image/png;base64,"
17
- if data.startswith(prefix):
18
- data = data[len(prefix):]
19
- raw = base64.b64decode(data.encode("ascii"))
20
  buffer = io.BytesIO(raw)
21
- return Image.open(buffer)
 
 
 
 
 
 
4
  from PIL import Image
5
 
6
 
7
+ def encode_image(img: Image.Image, image_format: str | None = None) -> str:
8
+ """Encode a PIL image to a data URL in the specified format.
9
+
10
+ image_format examples: "PNG", "JPEG", "WEBP"
11
+ If not provided, fall back to WEBP.
12
+ """
13
+
14
+ if image_format is None:
15
+ image_format = "WEBP"
16
+
17
  buffer = io.BytesIO()
18
+ try:
19
+ img.save(buffer, format=image_format)
20
+ except Exception as e:
21
+ raise ValueError(f"Failed to encode image as {image_format}: {e}")
22
+
23
  raw = buffer.getvalue()
24
  b64 = base64.b64encode(raw).decode("ascii")
25
+ mime = f"image/{image_format.lower()}"
26
+ return f"data:{mime};base64,{b64}"
27
+
28
+
29
+ def decode_image(data: str) -> Image.Image:
30
+ """Decode a data URL or base64 string into a PIL.Image.Image.
31
+
32
+ Supports any format that Pillow supports (png, jpg, webp, etc).
33
+ """
34
+
35
+ if not isinstance(data, str):
36
+ raise TypeError("decode_image expects a string")
37
+
38
+ # strip data URL prefix if present
39
+ if data.startswith("data:"):
40
+ try:
41
+ _, _, b64 = data.partition(",")
42
+ data = b64
43
+ except Exception:
44
+ raise ValueError("Invalid data URL")
45
 
46
+ # base64 decode
47
+ try:
48
+ raw = base64.b64decode(data.encode("ascii"))
49
+ except Exception as e:
50
+ raise ValueError(f"Base64 decode error: {e}")
51
 
52
+ # open with Pillow (auto detects format)
 
 
 
 
53
  buffer = io.BytesIO(raw)
54
+ try:
55
+ img = Image.open(buffer)
56
+ img.load() # ensure fully loaded
57
+ return img
58
+ except Exception as e:
59
+ raise ValueError(f"Failed to decode image: {e}")
nodes/vue_nodes.py CHANGED
@@ -21,6 +21,13 @@ class VueNodeData:
21
  # Form like content
22
  fields: list[dict[str, Any]] = field(default_factory=list)
23
  values: dict[str, Any] = field(default_factory=dict)
 
 
 
 
 
 
 
24
 
25
  def to_extra_dict(self) -> dict[str, Any]:
26
  """Convert to the data_extra dict expected by node_to_vueflow."""
@@ -34,6 +41,12 @@ class VueNodeData:
34
  data["fields"] = self.fields
35
  if self.values:
36
  data["values"] = self.values
 
 
 
 
 
 
37
 
38
  return data
39
 
 
21
  # Form like content
22
  fields: list[dict[str, Any]] = field(default_factory=list)
23
  values: dict[str, Any] = field(default_factory=dict)
24
+
25
+ # Custom port positioning (optional)
26
+ inputs: list[dict[str, Any]] | None = None
27
+ outputs: list[dict[str, Any]] | None = None
28
+
29
+ # Custom node styling (optional)
30
+ min_height: int | None = None
31
 
32
  def to_extra_dict(self) -> dict[str, Any]:
33
  """Convert to the data_extra dict expected by node_to_vueflow."""
 
41
  data["fields"] = self.fields
42
  if self.values:
43
  data["values"] = self.values
44
+ if self.inputs is not None:
45
+ data["inputs"] = self.inputs
46
+ if self.outputs is not None:
47
+ data["outputs"] = self.outputs
48
+ if self.min_height is not None:
49
+ data["min_height"] = self.min_height
50
 
51
  return data
52
 
pyproject.toml CHANGED
@@ -4,14 +4,16 @@ version = "0.1.0"
4
  description = "Add your description here"
5
  requires-python = ">=3.12"
6
  dependencies = [
 
7
  "google-genai>=1.51.0",
8
  "nicegui>=3",
9
  "pillow>=12.0.0",
 
10
  "requests>=2.32.5",
11
  "uvicorn<=0.35.0",
12
  ]
13
 
14
  [dependency-groups]
15
  dev = [
16
- "huggingface-hub[cli]>=1.1.4",
17
  ]
 
4
  description = "Add your description here"
5
  requires-python = ">=3.12"
6
  dependencies = [
7
+ "fal-client>=0.9.1",
8
  "google-genai>=1.51.0",
9
  "nicegui>=3",
10
  "pillow>=12.0.0",
11
+ "python-dotenv>=1.2.1",
12
  "requests>=2.32.5",
13
  "uvicorn<=0.35.0",
14
  ]
15
 
16
  [dependency-groups]
17
  dev = [
18
+ "huggingface-hub>=1.1.4",
19
  ]
services/image/DummyImageGenerator.py CHANGED
@@ -2,6 +2,7 @@ from __future__ import annotations
2
 
3
  import io
4
  import random
 
5
  from collections.abc import Sequence
6
 
7
  import requests
@@ -43,6 +44,8 @@ class DummyImageGenerator(ImageGenerator):
43
 
44
  call_progress(progress, 0.4, f"Fetching random image from {url}")
45
 
 
 
46
  try:
47
  response = requests.get(url, timeout=10)
48
  except requests.RequestException as exc:
 
2
 
3
  import io
4
  import random
5
+ import time
6
  from collections.abc import Sequence
7
 
8
  import requests
 
44
 
45
  call_progress(progress, 0.4, f"Fetching random image from {url}")
46
 
47
+ time.sleep(0.5)
48
+
49
  try:
50
  response = requests.get(url, timeout=10)
51
  except requests.RequestException as exc:
services/image/FalAIImageGenerator.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import io
4
+ from abc import ABC, abstractmethod
5
+ from typing import Any, Sequence
6
+
7
+ import requests
8
+ from PIL import Image
9
+
10
+ from nodes.utils import encode_image
11
+ from services.exceptions import GenerationError
12
+ from services.image.ImageGenerationResult import ImageGenerationResult
13
+ from services.image.ImageGenerator import ImageGenerator
14
+ from services.progress import ProgressCallback, call_progress
15
+ from services.registry import register_service
16
+ from services.utils.fal_service import run_fal
17
+
18
+
19
class FalAIImageGenerator(ImageGenerator, ABC):
    """Shared base for image generators backed by fal.ai.

    Implements the common pipeline:

    - pick a fal model slug for the request
    - assemble the input payload
    - invoke fal.ai via ``run_fal``
    - download every returned image URL and decode it with PIL

    Concrete subclasses only decide which slug to call
    (``_select_model``) and how to shape the payload
    (``_build_arguments``); they also store any slugs they need.
    """

    service_id = "fal_ai_image"

    def __init__(
        self,
        *,
        service_model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        # service_model_name only feeds metadata in GenerationService.
        super().__init__(model_name=service_model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}

    def close(self) -> None:
        # Nothing to release: each request is a one-shot HTTP call.
        return

    # --- model-agnostic helpers (no model names known here) ----------

    def _encode_images(self, images: Sequence[Image.Image]) -> list[str]:
        """Return every image encoded via ``encode_image``."""
        return [encode_image(image) for image in images]

    def _attach_image_list_argument(
        self,
        arguments: dict[str, Any],
        images: Sequence[Image.Image],
        arg_name: str,
    ) -> None:
        """Store encoded *images* under ``arguments[arg_name]``.

        If the key already holds a list, the encoded images are appended
        to the existing entries; otherwise the key is set to the new list.
        """
        encoded = self._encode_images(images)
        current = arguments.get(arg_name)
        arguments[arg_name] = current + encoded if isinstance(current, list) else encoded

    def _base_arguments(self, **kwargs: Any) -> dict[str, Any]:
        """Copy the configured extra arguments, then apply non-None kwargs as overrides."""
        merged = dict(self._extra_arguments)
        merged.update({key: value for key, value in kwargs.items() if value is not None})
        return merged

    # --- abstract hooks for concrete implementations -----------------

    @abstractmethod
    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Return the fal model slug to call for this request."""
        raise NotImplementedError

    @abstractmethod
    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Return the arguments payload for fal.ai."""
        raise NotImplementedError

    # --- shared generation pipeline -----------------------------------

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResult:
        """Run the full fal.ai pipeline and return the decoded images.

        Raises GenerationError when the response contains no images, an
        image cannot be downloaded, or the downloaded bytes are not a
        decodable image.
        """
        model = self._select_model(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.1, "Encoding inputs for fal.ai image model")
        payload = self._build_arguments(prompt=prompt, images=images, **kwargs)

        call_progress(progress, 0.4, "Calling fal.ai image model")
        response = run_fal(model=model, arguments=payload, api_key=self._api_key)

        call_progress(progress, 0.7, "Downloading images from fal.ai")
        raw_images = response.get("images")
        if not isinstance(raw_images, list) or not raw_images:
            raise GenerationError(
                "fal.ai image model did not return any images in the response."
            )

        decoded_images: list[Image.Image] = []
        for entry in raw_images:
            # Entries without a usable URL are skipped silently; only a
            # fully empty result is treated as an error below.
            if not isinstance(entry, dict):
                continue
            url = entry.get("url")
            if not isinstance(url, str) or not url:
                continue
            decoded_images.append(self._fetch_image(url))

        if not decoded_images:
            raise GenerationError(
                "fal.ai image model did not yield any decodable images."
            )

        call_progress(progress, 0.95, "Preparing fal.ai image result")
        return ImageGenerationResult(
            provider="fal.ai",
            model=model,
            images=decoded_images,
            raw_response=response,
        )

    @staticmethod
    def _fetch_image(url: str) -> Image.Image:
        """Download *url* and decode it into a fully loaded RGBA PIL image."""
        try:
            resp = requests.get(url, timeout=30)
        except requests.RequestException as exc:
            raise GenerationError(
                f"Failed to download image from fal.ai URL {url!r}."
            ) from exc

        if resp.status_code != 200:
            raise GenerationError(
                f"fal.ai image URL {url!r} returned status code {resp.status_code}."
            )

        try:
            image = Image.open(io.BytesIO(resp.content)).convert("RGBA")
        except OSError as exc:
            raise GenerationError(
                "Received invalid image data from fal.ai."
            ) from exc

        # Force decoding now so failures surface here, not at first use.
        image.load()
        return image
183
+
184
+
185
+ # concrete model combinations
186
+
187
+
188
@register_service
class FalAINanoBananaGenerator(FalAIImageGenerator):
    """fal-ai/nano-banana text and edit combination.

    Calls the plain text-to-image slug when no input images are given,
    and the ``/edit`` slug otherwise.
    """

    service_id = "fal_ai_image_nano_banana"

    def __init__(self, api_key: str | None = None) -> None:
        super().__init__(
            service_model_name="fal-ai/nano-banana",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/nano-banana"
        self._edit_model: str = "fal-ai/nano-banana/edit"
        # nano banana edit expects input images under "image_urls"
        self._image_argument: str = "image_urls"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Pick the edit slug when images are supplied, else text-to-image."""
        return self._edit_model if images else self._text_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the payload: prompt plus (for edits) the encoded images."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt
        if images:
            self._attach_image_list_argument(arguments, images, self._image_argument)
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        return "fal-ai/nano-banana"
235
+
236
+
237
@register_service
class FalAIReveGenerator(FalAIImageGenerator):
    """fal-ai/reve combination:
    - text to image when no images are given
    - fast edit for exactly one image
    - fast remix for multiple images
    """

    service_id = "fal_ai_image_reve"

    def __init__(self, api_key: str | None = None) -> None:
        super().__init__(
            service_model_name="fal-ai/reve",
            api_key=api_key,
            extra_arguments={},
        )
        self._text_model: str = "fal-ai/reve/text-to-image"
        self._edit_model: str = "fal-ai/reve/fast/edit"
        self._remix_model: str = "fal-ai/reve/fast/remix"

    def _select_model(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> str:
        """Dispatch on the number of input images (0 / 1 / many)."""
        image_count = 0 if images is None else len(images)
        if image_count == 0:
            return self._text_model
        return self._edit_model if image_count == 1 else self._remix_model

    def _build_arguments(
        self,
        *,
        prompt: str,
        images: Sequence[Image.Image] | None,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Assemble the payload matching the slug chosen by ``_select_model``."""
        arguments = self._base_arguments(**kwargs)
        arguments["prompt"] = prompt

        if not images:
            # text to image ignores image inputs
            return arguments

        if len(images) == 1:
            # fast edit expects a single "image_url"
            arguments["image_url"] = encode_image(images[0])
        else:
            # fast remix expects "image_urls" as a list
            self._attach_image_list_argument(arguments, images, "image_urls")
        return arguments

    @classmethod
    def default_model_name(cls) -> str:
        return "fal-ai/reve"
services/image/ImageGenerator.py CHANGED
@@ -1,5 +1,5 @@
1
  from abc import ABC, abstractmethod
2
- from typing import Sequence
3
 
4
  from PIL import Image
5
 
@@ -20,6 +20,7 @@ class ImageGenerator(GenerationService, ABC):
20
  images: Sequence[Image.Image] | None = None,
21
  *,
22
  progress: ProgressCallback | None = None,
 
23
  ) -> ImageGenerationResult:
24
  """Generate images from a prompt and optional images."""
25
  raise NotImplementedError
 
1
  from abc import ABC, abstractmethod
2
+ from typing import Any, Sequence
3
 
4
  from PIL import Image
5
 
 
20
  images: Sequence[Image.Image] | None = None,
21
  *,
22
  progress: ProgressCallback | None = None,
23
+ **kwargs: Any,
24
  ) -> ImageGenerationResult:
25
  """Generate images from a prompt and optional images."""
26
  raise NotImplementedError
services/image/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from services.image.DummyImageGenerator import DummyImageGenerator
2
- from services.image.GoogleImageGenerator import GoogleImageGenerator
 
 
1
  from services.image.DummyImageGenerator import DummyImageGenerator
2
+ from services.image.GoogleImageGenerator import GoogleImageGenerator
3
+ from services.image.FalAIImageGenerator import FalAIImageGenerator
services/text/FalAITextGenerator.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any, Sequence
5
+
6
+ from PIL import Image
7
+
8
+ from services.exceptions import GenerationError
9
+ from services.progress import ProgressCallback, call_progress
10
+ from services.registry import register_service
11
+ from services.text.TextGenerationResult import TextGenerationResult
12
+ from services.text.TextGenerator import TextGenerator
13
+ from services.utils.fal_service import DEFAULT_FAL_TEXT_MODEL, run_fal
14
+
15
+
16
@register_service
class FalAITextGenerator(TextGenerator):
    """Text generator backed by fal.ai OpenRouter "any LLM" endpoint.

    This uses the fal model "openrouter/router" by default and expects
    an underlying LLM name in the "model" field of the input payload.

    Example extra_arguments for Gemini:

        extra_arguments={"model": "google/gemini-2.5-flash"}
    """

    service_id = "fal_ai_text"

    def __init__(
        self,
        model_name: str | None = None,
        api_key: str | None = None,
        extra_arguments: dict[str, Any] | None = None,
    ) -> None:
        """
        Parameters
        ----------
        model_name:
            fal model slug, typically "openrouter/router".
        api_key:
            Optional fal.ai API key. If omitted, FAL_KEY from the environment
            will be used.
        extra_arguments:
            Extra fields for the OpenRouter router input, for example:
            - {"model": "google/gemini-2.5-flash"}
            - {"system_prompt": "...", "temperature": 0.7, "max_tokens": 1024}
        """
        super().__init__(model_name=model_name)
        self._api_key = api_key
        self._extra_arguments: dict[str, Any] = extra_arguments or {}

    @classmethod
    def default_model_name(cls) -> str:
        # This is the fal model id, not the underlying LLM.
        return DEFAULT_FAL_TEXT_MODEL

    def close(self) -> None:
        # Nothing to release: each request is a one-shot call.
        return

    @staticmethod
    def _extract_text(response: dict[str, Any]) -> str:
        """Pull the generated text out of an OpenRouter router response.

        Checks "output" first (the router schema), then falls back to
        "text"; raises GenerationError when neither holds non-blank text.
        """
        for field in ("output", "text"):
            candidate = response.get(field)
            if isinstance(candidate, str) and candidate.strip():
                return candidate

        raise GenerationError(
            "fal.ai text model did not return any text output."
        )

    def _build_arguments(self, prompt: str) -> dict[str, Any]:
        """Build the router payload; resolve the target LLM name if missing."""
        arguments: dict[str, Any] = dict(self._extra_arguments)
        arguments["prompt"] = prompt

        if "model" not in arguments:
            env_model = os.getenv("OPENROUTER_DEFAULT_MODEL")
            if not env_model:
                raise GenerationError(
                    "FalAITextGenerator requires a target OpenRouter model name. "
                    "Provide it via extra_arguments['model'] or set the "
                    "OPENROUTER_DEFAULT_MODEL environment variable."
                )
            arguments["model"] = env_model

        return arguments

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image] | None = None,
        *,
        progress: ProgressCallback | None = None,
    ) -> TextGenerationResult:
        """Generate text for *prompt* via the fal.ai OpenRouter router."""
        call_progress(progress, 0.1, "Encoding inputs for fal.ai text model")

        if images:
            # The OpenRouter router endpoint is text only.
            raise GenerationError(
                "FalAITextGenerator does not currently support image inputs."
            )

        payload = self._build_arguments(prompt=prompt)

        call_progress(progress, 0.4, "Calling fal.ai OpenRouter router model")
        response = run_fal(
            model=self.model_name,
            arguments=payload,
            api_key=self._api_key,
        )

        call_progress(progress, 0.7, "Decoding text from fal.ai response")
        text_output = self._extract_text(response)

        call_progress(progress, 0.95, "Preparing fal.ai text result")
        return TextGenerationResult(
            provider="fal.ai",
            model=self.model_name,
            text=text_output,
            raw_response=response,
        )
+ )
services/text/__init__.py CHANGED
@@ -1,2 +1,3 @@
1
  from services.text.DummyTextGenerator import DummyTextGenerator
2
  from services.text.GoogleTextGenerator import GoogleTextGenerator
 
 
1
  from services.text.DummyTextGenerator import DummyTextGenerator
2
  from services.text.GoogleTextGenerator import GoogleTextGenerator
3
+ from services.text.FalAITextGenerator import FalAITextGenerator
services/utils/fal_service.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import Any, Mapping
5
+
6
+ from services.exceptions import GenerationError
7
+
8
+ try:
9
+ import fal_client # type: ignore[import-not-found]
10
+ except ModuleNotFoundError: # pragma: no cover
11
+ fal_client = None # type: ignore[assignment]
12
+
13
+ # Fal text uses the OpenRouter "any LLM" router endpoint
14
+ DEFAULT_FAL_TEXT_MODEL = os.getenv("FAL_TEXT_MODEL", "openrouter/router")
15
+
16
+ # Default image model: nano banana as main image backend
17
+ DEFAULT_FAL_IMAGE_MODEL = os.getenv("FAL_IMAGE_MODEL", "fal-ai/nano-banana")
18
+
19
+
20
def _ensure_fal_client(api_key: str | None = None) -> Any:
    """Return the imported ``fal_client`` module, configuring auth if needed.

    Parameters
    ----------
    api_key:
        Optional fal.ai API key. When given, it is exported as FAL_KEY so
        that fal_client can authenticate.

    Raises
    ------
    GenerationError
        If the optional 'fal-client' dependency is not installed.
    """
    if fal_client is None:
        raise GenerationError(
            "The 'fal-client' package is required to use the fal.ai backends."
        )

    if api_key:
        # fal_client uses the FAL_KEY environment variable for authentication.
        # Fix: an explicitly supplied api_key must take precedence over any
        # value already in the environment; the previous setdefault() call
        # silently ignored the caller's key whenever FAL_KEY was already set.
        os.environ["FAL_KEY"] = api_key

    return fal_client
31
+
32
+
33
def run_fal(
    model: str,
    arguments: Mapping[str, Any],
    api_key: str | None = None,
) -> dict[str, Any]:
    """Run a fal.ai model synchronously and return the raw response.

    Parameters
    ----------
    model:
        The fal model slug to invoke.
    arguments:
        Input payload for the model; copied into a plain dict before the call.
    api_key:
        Optional fal.ai API key, forwarded to client setup.

    Raises
    ------
    GenerationError
        If FAL_KEY is unavailable, the request fails, or the response is
        not a dict.
    """
    client = _ensure_fal_client(api_key)

    # Fail early with a clear message when no credentials are configured.
    if not os.getenv("FAL_KEY"):
        raise GenerationError(
            "FAL_KEY environment variable is not set. "
            "Please set it in your .env file or environment. "
            "Get your API key from https://fal.ai/dashboard"
        )

    try:
        result = client.run(model, arguments=dict(arguments))
    except Exception as exc:  # pragma: no cover
        # Include the original error message for better debugging;
        # fall back to the exception class name when it is empty.
        detail = str(exc) or type(exc).__name__
        raise GenerationError(
            f"fal.ai request to {model!r} failed: {detail}"
        ) from exc

    if not isinstance(result, dict):
        raise GenerationError(
            f"fal.ai model {model!r} returned an unexpected response type "
            f"{type(result).__name__}."
        )

    return result
uv.lock CHANGED
@@ -1,5 +1,5 @@
1
  version = 1
2
- revision = 2
3
  requires-python = ">=3.12"
4
 
5
  [[package]]
@@ -273,6 +273,19 @@ wheels = [
273
  { url = "https://files.pythonhosted.org/packages/66/dd/f95350e853a4468ec37478414fc04ae2d61dad7a947b3015c3dcc51a09b9/docutils-0.22.2-py3-none-any.whl", hash = "sha256:b0e98d679283fc3bb0ead8a5da7f501baa632654e7056e9c5846842213d674d8", size = 632667, upload-time = "2025-09-20T17:55:43.052Z" },
274
  ]
275
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
  [[package]]
277
  name = "fastapi"
278
  version = "0.120.2"
@@ -523,6 +536,15 @@ wheels = [
523
  { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
524
  ]
525
 
 
 
 
 
 
 
 
 
 
526
  [[package]]
527
  name = "huggingface-hub"
528
  version = "1.1.4"
@@ -790,9 +812,11 @@ name = "nicegui-graph"
790
  version = "0.1.0"
791
  source = { virtual = "." }
792
  dependencies = [
 
793
  { name = "google-genai" },
794
  { name = "nicegui" },
795
  { name = "pillow" },
 
796
  { name = "requests" },
797
  { name = "uvicorn" },
798
  ]
@@ -804,15 +828,17 @@ dev = [
804
 
805
  [package.metadata]
806
  requires-dist = [
 
807
  { name = "google-genai", specifier = ">=1.51.0" },
808
  { name = "nicegui", specifier = ">=3" },
809
  { name = "pillow", specifier = ">=12.0.0" },
 
810
  { name = "requests", specifier = ">=2.32.5" },
811
  { name = "uvicorn", specifier = "<=0.35.0" },
812
  ]
813
 
814
  [package.metadata.requires-dev]
815
- dev = [{ name = "huggingface-hub", extras = ["cli"], specifier = ">=1.1.4" }]
816
 
817
  [[package]]
818
  name = "orjson"
 
1
  version = 1
2
+ revision = 3
3
  requires-python = ">=3.12"
4
 
5
  [[package]]
 
273
  { url = "https://files.pythonhosted.org/packages/66/dd/f95350e853a4468ec37478414fc04ae2d61dad7a947b3015c3dcc51a09b9/docutils-0.22.2-py3-none-any.whl", hash = "sha256:b0e98d679283fc3bb0ead8a5da7f501baa632654e7056e9c5846842213d674d8", size = 632667, upload-time = "2025-09-20T17:55:43.052Z" },
274
  ]
275
 
276
+ [[package]]
277
+ name = "fal-client"
278
+ version = "0.9.1"
279
+ source = { registry = "https://pypi.org/simple" }
280
+ dependencies = [
281
+ { name = "httpx" },
282
+ { name = "httpx-sse" },
283
+ ]
284
+ sdist = { url = "https://files.pythonhosted.org/packages/b8/a1/98ab1cea4c2424ee612292bc92b07905e1a15a05584f6c263cde38e6a3a2/fal_client-0.9.1.tar.gz", hash = "sha256:c8f7f88f79c4b4c4f069be9f571be924dc7c4a6bf07c252fe0b75f3c46c8d66d", size = 17085, upload-time = "2025-11-13T18:15:09.911Z" }
285
+ wheels = [
286
+ { url = "https://files.pythonhosted.org/packages/8b/57/775821a71459f2b83bbaa59452a4b1e4772f7c770de88a6f591c9d43c7c8/fal_client-0.9.1-py3-none-any.whl", hash = "sha256:8eba86c947299852c8306f685eee883ce01856543bf4344b87f65abd4b7d7622", size = 11157, upload-time = "2025-11-13T18:15:08.528Z" },
287
+ ]
288
+
289
  [[package]]
290
  name = "fastapi"
291
  version = "0.120.2"
 
536
  { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517, upload-time = "2024-12-06T15:37:21.509Z" },
537
  ]
538
 
539
+ [[package]]
540
+ name = "httpx-sse"
541
+ version = "0.4.3"
542
+ source = { registry = "https://pypi.org/simple" }
543
+ sdist = { url = "https://files.pythonhosted.org/packages/0f/4c/751061ffa58615a32c31b2d82e8482be8dd4a89154f003147acee90f2be9/httpx_sse-0.4.3.tar.gz", hash = "sha256:9b1ed0127459a66014aec3c56bebd93da3c1bc8bb6618c8082039a44889a755d", size = 15943, upload-time = "2025-10-10T21:48:22.271Z" }
544
+ wheels = [
545
+ { url = "https://files.pythonhosted.org/packages/d2/fd/6668e5aec43ab844de6fc74927e155a3b37bf40d7c3790e49fc0406b6578/httpx_sse-0.4.3-py3-none-any.whl", hash = "sha256:0ac1c9fe3c0afad2e0ebb25a934a59f4c7823b60792691f779fad2c5568830fc", size = 8960, upload-time = "2025-10-10T21:48:21.158Z" },
546
+ ]
547
+
548
  [[package]]
549
  name = "huggingface-hub"
550
  version = "1.1.4"
 
812
  version = "0.1.0"
813
  source = { virtual = "." }
814
  dependencies = [
815
+ { name = "fal-client" },
816
  { name = "google-genai" },
817
  { name = "nicegui" },
818
  { name = "pillow" },
819
+ { name = "python-dotenv" },
820
  { name = "requests" },
821
  { name = "uvicorn" },
822
  ]
 
828
 
829
  [package.metadata]
830
  requires-dist = [
831
+ { name = "fal-client", specifier = ">=0.9.1" },
832
  { name = "google-genai", specifier = ">=1.51.0" },
833
  { name = "nicegui", specifier = ">=3" },
834
  { name = "pillow", specifier = ">=12.0.0" },
835
+ { name = "python-dotenv", specifier = ">=1.2.1" },
836
  { name = "requests", specifier = ">=2.32.5" },
837
  { name = "uvicorn", specifier = "<=0.35.0" },
838
  ]
839
 
840
  [package.metadata.requires-dev]
841
+ dev = [{ name = "huggingface-hub", specifier = ">=1.1.4" }]
842
 
843
  [[package]]
844
  name = "orjson"
velai_app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional
4
+
5
+ from nicegui import ui, events
6
+
7
+ from dataflow.ui.vueflow_canvas import VueFlowCanvas
8
+ from nodes.session import GraphSession
9
+
10
+
11
class VelaiApp:
    """NiceGUI front end for the node graph editor.

    Builds the header (info dialog, node/project menus) and the Vue Flow
    canvas, and wires canvas events back into the GraphSession.
    """

    def __init__(self, session: GraphSession) -> None:
        # The graph session holding node/edge state; the canvas is created
        # lazily in build().
        self.session = session
        self.canvas: Optional[VueFlowCanvas] = None

    def build(self) -> None:
        """Construct the page layout and connect the canvas to the session."""
        session = self.session

        initial_nodes = session.initial_vue_nodes()
        initial_edges = session.initial_vue_edges()

        with ui.column().classes("w-full h-screen no-wrap"):
            # header
            with ui.row().classes("w-full items-center justify-between px-4 py-2 bg-grey-2"):
                # ui.label("velai").classes("text-lg font-bold")
                ui.image("assets/logo.png").classes("w-16").style("margin-right: 10px;")
                with ui.row().classes("items-center gap-2"):

                    # Info button; the dialog it opens is created inside the
                    # button context, and the lambda resolves `dialog` at
                    # click time (after the `with` below has bound it).
                    with ui.button("Info", on_click=lambda: dialog.open()).props("color='black'"):
                        # show a popup with the information about the app
                        with ui.dialog() as dialog:
                            with ui.card():
                                ui.button(icon='close', on_click=dialog.close).props("color='black'").classes('absolute top-2 right-2')
                                ui.label("Welcome to VELAI!").classes("text-lg font-bold")
                                ui.label("This is a node-based environment for generating images.\nIt uses the Gemini 2.5 Flash Preview (Nano Banana) and Reve models.")
                                ui.label("As it's a work in progress, there might be some bugs and issues. Please report them to us :)")

                    # One menu entry per creatable node kind; `kind=k["kind"]`
                    # binds the loop variable as a default to avoid the
                    # late-binding closure pitfall.
                    with ui.dropdown_button("Add Node", auto_close=True).props("color='black'"):
                        kinds = session.creatable_node_types
                        for k in kinds:
                            ui.menu_item(
                                k["title"],
                                on_click=lambda _, kind=k["kind"]: self.add_node_action(kind),
                            )

                    with ui.dropdown_button("Project", auto_close=True).props("color='black'"):
                        ui.menu_item("Import", on_click=self.import_action)
                        ui.menu_item("Export", on_click=self.export_action)
                        ui.menu_item("Clear", on_click=self.clear_graph_action).props("flat color='negative'")

            # canvas area fills the remaining height
            with ui.row().classes("w-full flex-1 no-wrap"):
                self.canvas = VueFlowCanvas(
                    creatable_node_types=session.creatable_node_types,
                ).classes("w-full h-full flex-1")

        session.attach_canvas(self.canvas)

        async def handle_event(e) -> None:
            # Forward raw Vue Flow events to the session for handling.
            await session.handle_ui_event(e.args, self.canvas)

        self.canvas.on("vf_event", handle_event)
        self.canvas.set_graph(initial_nodes, initial_edges)

    def add_node_action(self, kind_value: str) -> None:
        """Create a node of *kind_value* and place it on the canvas."""
        if self.canvas is None:
            return

        import random

        # Slightly randomized spawn position so stacked nodes stay visible.
        vue_node = self.session.create_node(
            kind_value,
            position={
                "x": 100 + random.randint(0, 50),
                "y": 100 + random.randint(0, 50),
            },
        )
        if vue_node is not None:
            self.canvas.add_node(vue_node)

    def export_action(self) -> None:
        """Serialize the session graph to JSON and trigger a browser download."""
        json_str = self.session.to_json()
        ui.download(json_str.encode("utf-8"), "graph.json")

    def import_action(self) -> None:
        """Open a dialog to upload a graph JSON file and load it into the session."""
        if self.canvas is None:
            return

        session = self.session
        canvas = self.canvas

        with ui.dialog() as dialog:
            with ui.card():
                ui.label("Upload graph JSON")

                async def handle_upload(e: events.UploadEventArguments) -> None:
                    # NOTE(review): assumes the NiceGUI 3 upload event exposes
                    # the file as `e.file` with an async read() — confirm
                    # against the installed NiceGUI version.
                    data = await e.file.read()
                    content = data.decode("utf-8")

                    session.load_from_json(content)
                    session.save_to_storage()

                    # Rebuild the canvas from the freshly loaded session state.
                    nodes = session.initial_vue_nodes()
                    edges = session.initial_vue_edges()
                    canvas.set_graph(nodes, edges)

                    dialog.close()

                ui.upload(
                    label="Choose JSON file",
                    on_upload=handle_upload,
                    max_files=1,
                ).props("accept=.json auto-upload")

                ui.button("Cancel", on_click=dialog.close).props("color='black'")

        dialog.open()

    def clear_graph_action(self) -> None:
        """Empty the session graph, persist the change, and reset the canvas."""
        if self.canvas is None:
            return

        self.session.clear_graph()
        self.session.save_to_storage()
        self.canvas.set_graph([], [])