File size: 11,615 Bytes
d868fac 3025bb3 d868fac 670c7ff d868fac 670c7ff d868fac 670c7ff d868fac 670c7ff d868fac 670c7ff 3025bb3 670c7ff 3025bb3 d868fac 3025bb3 d868fac 3025bb3 d868fac 3025bb3 d868fac 3025bb3 d868fac 670c7ff abd08cb 95d3973 abd08cb 95d3973 abd08cb 95d3973 abd08cb 670c7ff 3025bb3 670c7ff 3025bb3 670c7ff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 | from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any

from dataflow.connection import Connection
from dataflow.enums import DataPortState
from dataflow.graph import DataGraph

from . import utils
from .image_data import ImageDataNode
from .text_data import TextDataNode
from .text_to_image import TextToImageNode
@dataclass(slots=True)
class GraphController:
"""Handle Vue Flow events and keep the DataGraph plus node state in sync.
This lives in the app layer so the core dataflow package does not
depend on your concrete node types.
"""
graph: DataGraph
def handle_event(self, event: dict[str, Any]) -> None:
event_type = event.get("type")
raw_payload = event.get("payload")
if event_type == "connect":
payload = raw_payload or {}
self._on_connect(payload)
elif event_type == "node_moved":
payload = raw_payload or {}
self._on_node_moved(payload)
elif event_type == "node_field_changed":
payload = raw_payload or {}
self._on_node_field_changed(payload)
elif event_type == "edges_delete":
edges = raw_payload or []
if isinstance(edges, list):
self._on_edges_delete(edges)
elif event_type == "edges_change":
changes = raw_payload or []
if isinstance(changes, list):
self._on_edges_change(changes)
elif event_type == "nodes_delete":
nodes = raw_payload or []
if isinstance(nodes, list):
self._on_nodes_delete(nodes)
elif event_type == "nodes_change":
changes = raw_payload or []
if isinstance(changes, list):
self._on_nodes_change(changes)
# other events (graph_cleared, create_node, etc.) are currently ignored
# because they do not affect the DataGraph directly here
def _on_connect(self, payload: dict[str, Any]) -> None:
source_handle = payload.get("sourceHandle") or ""
target_handle = payload.get("targetHandle") or ""
def split(handle: str) -> tuple[str, str]:
if not handle:
return "", ""
if ":" in handle:
node_id, port = handle.split(":", 1)
return node_id, port
return "", handle
src_node_id, src_port_name = split(source_handle)
tgt_node_id, tgt_port_name = split(target_handle)
if not src_node_id:
src_node_id = payload.get("source") or ""
if not tgt_node_id:
tgt_node_id = payload.get("target") or ""
if not src_node_id or not tgt_node_id:
return
start_node = self.graph.nodes.get(src_node_id)
end_node = self.graph.nodes.get(tgt_node_id)
if start_node is None or end_node is None:
return
start_port = start_node.outputs.get(src_port_name) if start_node.outputs is not None else None
if start_port is None and start_node.inputs is not None:
start_port = start_node.inputs.get(src_port_name)
end_port = end_node.inputs.get(tgt_port_name) if end_node.inputs is not None else None
if end_port is None and end_node.outputs is not None:
end_port = end_node.outputs.get(tgt_port_name)
if start_port is None or end_port is None:
return
conn = Connection(
start_node=start_node,
start_port=start_port,
end_node=end_node,
end_port=end_port,
)
try:
self.graph.add_connection(conn)
except ValueError:
# datatype mismatch or capacity problems
return
# make the dataflow "dirty" because topology changed
if hasattr(start_port, "state"):
start_port.state = DataPortState.DIRTY
if hasattr(end_port, "state"):
end_port.state = DataPortState.DIRTY
# often you also want to clear the input value
if hasattr(end_port, "value"):
end_port.value = None
def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None:
node = self.graph.nodes.get(node_id)
if node is None:
return
pos = position or {}
x = pos.get("x")
y = pos.get("y")
if x is not None:
try:
node.x = float(x)
except (TypeError, ValueError):
pass
if y is not None:
try:
node.y = float(y)
except (TypeError, ValueError):
pass
def _on_node_moved(self, payload: dict[str, Any]) -> None:
node_id = payload.get("id")
if not node_id:
return
position = payload.get("position") or {}
self._update_node_position(node_id, position)
def _on_node_field_changed(self, payload: dict[str, Any]) -> None:
node_id = payload.get("id")
field = payload.get("field")
value = payload.get("value")
if not node_id or not field:
return
node = self.graph.nodes.get(node_id)
if node is None:
return
if isinstance(node, TextDataNode) and field == "text":
port = node.outputs.get("text") if node.outputs is not None else None
if port is not None:
port.value = "" if value is None else str(value)
port.state = DataPortState.DIRTY
elif isinstance(node, ImageDataNode) and field == "image":
port = node.outputs.get("image") if node.outputs is not None else None
if port is not None:
if value:
# value is a data-uri string from the UI
try:
# We need to decode it to a PIL Image
img = utils.decode_image(str(value))
port.value = img
port.state = DataPortState.DIRTY
except Exception:
# invalid image data
port.value = None
else:
port.value = None
port.state = DataPortState.DIRTY
elif isinstance(node, TextToImageNode) and field == "image":
node.image_src = "" if value is None else str(value)
elif isinstance(node, TextToImageNode) and field == "aspect_ratio":
# Parse aspect_ratio value from the dropdown selection. Thanks Flo!
aspect_ratio_value = "1:1" # default
print(f"[DEBUG controller] Aspect ratio set to {value}, type={type(value)}")
if value is not None and isinstance(value, str):
aspect_ratio_value = value.strip()
else:
print(f"[DEBUG controller] aspect_ratio value is None, using default 1:1")
# Validate aspect ratio format (should be "W:H")
# Allow common formats: "1:1", "16:9", "9:16", etc. we could add more later if needed
if aspect_ratio_value and ":" in aspect_ratio_value:
parts = aspect_ratio_value.split(":")
if len(parts) == 2:
try:
# Validate that both parts are numeric
float(parts[0])
float(parts[1])
old_ratio = node.aspect_ratio
node.aspect_ratio = aspect_ratio_value
except (ValueError, TypeError):
# Invalid format, use default
node.aspect_ratio = "1:1"
print(f"[DEBUG Controller] Invalid numeric format, using default: {node.aspect_ratio}")
def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None:
"""Remove matching connections from the DataGraph when edges are deleted in Vue."""
if not edges:
return
for edge in edges:
if not isinstance(edge, dict):
continue
src_id = edge.get("source")
tgt_id = edge.get("target")
src_handle = edge.get("sourceHandle") or ""
tgt_handle = edge.get("targetHandle") or ""
src_port = src_handle.split(":", 1)[1] if ":" in src_handle else None
tgt_port = tgt_handle.split(":", 1)[1] if ":" in tgt_handle else None
def should_remove(conn: Connection) -> bool:
if src_id and conn.start_node.node_id != src_id:
return False
if tgt_id and conn.end_node.node_id != tgt_id:
return False
if src_port is not None and conn.start_port.name != src_port:
return False
if tgt_port is not None and conn.end_port.name != tgt_port:
return False
return True
self.graph.connections = [c for c in self.graph.connections if not should_remove(c)]
def _on_edges_change(self, changes: list[dict[str, Any]]) -> None:
"""Handle generic edge changes.
Vue Flow sends EdgeChange objects.
"""
if not changes:
return
edges_to_delete: list[dict[str, Any]] = []
for change in changes:
if not isinstance(change, dict):
continue
if change.get("type") == "remove":
# this is quite ugly, the edge should be sent directly or read from the list of edges from the graph
edge = change
if isinstance(edge, dict):
edges_to_delete.append(edge)
if edges_to_delete:
self._on_edges_delete(edges_to_delete)
def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None:
"""Remove nodes and all their connections when Vue deletes them."""
if not nodes:
return
node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")}
self._delete_nodes(node_ids)
def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None:
"""Handle generic node changes.
Currently supports:
- type == "remove": delete node and related connections
- type == "position": update node position like 'node_moved'
Other change types (select, dimensions, etc.) do not affect the DataGraph.
"""
if not changes:
return
node_ids_to_delete: set[str] = set()
for change in changes:
if not isinstance(change, dict):
continue
ctype = change.get("type")
node_id = change.get("id")
if ctype == "remove" and node_id:
node_ids_to_delete.add(node_id)
elif ctype == "position" and node_id:
# Vue Flow usually sends 'position' for logical node coordinates.
position = change.get("position") or change.get("positionAbsolute") or {}
self._update_node_position(node_id, position)
if node_ids_to_delete:
self._delete_nodes(node_ids_to_delete)
def _delete_nodes(self, node_ids: set[str]) -> None:
if not node_ids:
return
self.graph.connections = [
c
for c in self.graph.connections
if c.start_node.node_id not in node_ids and c.end_node.node_id not in node_ids
]
for node_id in node_ids:
self.graph.nodes.pop(node_id, None)
|