from __future__ import annotations from dataclasses import dataclass from typing import Any from dataflow.connection import Connection from dataflow.enums import DataPortState from dataflow.graph import DataGraph from . import utils from .image_data import ImageDataNode from .text_data import TextDataNode from .text_to_image import TextToImageNode @dataclass(slots=True) class GraphController: """Handle Vue Flow events and keep the DataGraph plus node state in sync. This lives in the app layer so the core dataflow package does not depend on your concrete node types. """ graph: DataGraph def handle_event(self, event: dict[str, Any]) -> None: event_type = event.get("type") raw_payload = event.get("payload") if event_type == "connect": payload = raw_payload or {} self._on_connect(payload) elif event_type == "node_moved": payload = raw_payload or {} self._on_node_moved(payload) elif event_type == "node_field_changed": payload = raw_payload or {} self._on_node_field_changed(payload) elif event_type == "edges_delete": edges = raw_payload or [] if isinstance(edges, list): self._on_edges_delete(edges) elif event_type == "edges_change": changes = raw_payload or [] if isinstance(changes, list): self._on_edges_change(changes) elif event_type == "nodes_delete": nodes = raw_payload or [] if isinstance(nodes, list): self._on_nodes_delete(nodes) elif event_type == "nodes_change": changes = raw_payload or [] if isinstance(changes, list): self._on_nodes_change(changes) # other events (graph_cleared, create_node, etc.) are currently ignored # because they do not affect the DataGraph directly here def _on_connect(self, payload: dict[str, Any]) -> None: source_handle = payload.get("sourceHandle") or "" target_handle = payload.get("targetHandle") or "" def split(handle: str) -> tuple[str, str]: if not handle: return "", "" if ":" in handle: node_id, port = handle.split(":", 1) return node_id, port return "", handle src_node_id, src_port_name = split(source_handle) tgt_node_id, tgt_port_name = split(target_handle) if not src_node_id: src_node_id = payload.get("source") or "" if not tgt_node_id: tgt_node_id = payload.get("target") or "" if not src_node_id or not tgt_node_id: return start_node = self.graph.nodes.get(src_node_id) end_node = self.graph.nodes.get(tgt_node_id) if start_node is None or end_node is None: return start_port = start_node.outputs.get(src_port_name) if start_node.outputs is not None else None if start_port is None and start_node.inputs is not None: start_port = start_node.inputs.get(src_port_name) end_port = end_node.inputs.get(tgt_port_name) if end_node.inputs is not None else None if end_port is None and end_node.outputs is not None: end_port = end_node.outputs.get(tgt_port_name) if start_port is None or end_port is None: return conn = Connection( start_node=start_node, start_port=start_port, end_node=end_node, end_port=end_port, ) try: self.graph.add_connection(conn) except ValueError: # datatype mismatch or capacity problems return # make the dataflow "dirty" because topology changed if hasattr(start_port, "state"): start_port.state = DataPortState.DIRTY if hasattr(end_port, "state"): end_port.state = DataPortState.DIRTY # often you also want to clear the input value if hasattr(end_port, "value"): end_port.value = None def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None: node = self.graph.nodes.get(node_id) if node is None: return pos = position or {} x = pos.get("x") y = pos.get("y") if x is not None: try: node.x = float(x) except (TypeError, ValueError): pass if y is not None: try: node.y = float(y) except (TypeError, ValueError): pass def _on_node_moved(self, payload: dict[str, Any]) -> None: node_id = payload.get("id") if not node_id: return position = payload.get("position") or {} self._update_node_position(node_id, position) def _on_node_field_changed(self, payload: dict[str, Any]) -> None: node_id = payload.get("id") field = payload.get("field") value = payload.get("value") if not node_id or not field: return node = self.graph.nodes.get(node_id) if node is None: return if isinstance(node, TextDataNode) and field == "text": port = node.outputs.get("text") if node.outputs is not None else None if port is not None: port.value = "" if value is None else str(value) port.state = DataPortState.DIRTY elif isinstance(node, ImageDataNode) and field == "image": port = node.outputs.get("image") if node.outputs is not None else None if port is not None: if value: # value is a data-uri string from the UI try: # We need to decode it to a PIL Image img = utils.decode_image(str(value)) port.value = img port.state = DataPortState.DIRTY except Exception: # invalid image data port.value = None else: port.value = None port.state = DataPortState.DIRTY elif isinstance(node, TextToImageNode) and field == "image": node.image_src = "" if value is None else str(value) elif isinstance(node, TextToImageNode) and field == "aspect_ratio": # Parse aspect_ratio value from the dropdown selection. Thanks Flo! aspect_ratio_value = "1:1" # default print(f"[DEBUG controller] Aspect ratio set to {value}, type={type(value)}") if value is not None and isinstance(value, str): aspect_ratio_value = value.strip() else: print(f"[DEBUG controller] aspect_ratio value is None, using default 1:1") # Validate aspect ratio format (should be "W:H") # Allow common formats: "1:1", "16:9", "9:16", etc. we could add more later if needed if aspect_ratio_value and ":" in aspect_ratio_value: parts = aspect_ratio_value.split(":") if len(parts) == 2: try: # Validate that both parts are numeric float(parts[0]) float(parts[1]) old_ratio = node.aspect_ratio node.aspect_ratio = aspect_ratio_value except (ValueError, TypeError): # Invalid format, use default node.aspect_ratio = "1:1" print(f"[DEBUG Controller] Invalid numeric format, using default: {node.aspect_ratio}") def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None: """Remove matching connections from the DataGraph when edges are deleted in Vue.""" if not edges: return for edge in edges: if not isinstance(edge, dict): continue src_id = edge.get("source") tgt_id = edge.get("target") src_handle = edge.get("sourceHandle") or "" tgt_handle = edge.get("targetHandle") or "" src_port = src_handle.split(":", 1)[1] if ":" in src_handle else None tgt_port = tgt_handle.split(":", 1)[1] if ":" in tgt_handle else None def should_remove(conn: Connection) -> bool: if src_id and conn.start_node.node_id != src_id: return False if tgt_id and conn.end_node.node_id != tgt_id: return False if src_port is not None and conn.start_port.name != src_port: return False if tgt_port is not None and conn.end_port.name != tgt_port: return False return True self.graph.connections = [c for c in self.graph.connections if not should_remove(c)] def _on_edges_change(self, changes: list[dict[str, Any]]) -> None: """Handle generic edge changes. Vue Flow sends EdgeChange objects. """ if not changes: return edges_to_delete: list[dict[str, Any]] = [] for change in changes: if not isinstance(change, dict): continue if change.get("type") == "remove": # this is quite ugly, the edge should be sent directly or read from the list of edges from the graph edge = change if isinstance(edge, dict): edges_to_delete.append(edge) if edges_to_delete: self._on_edges_delete(edges_to_delete) def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None: """Remove nodes and all their connections when Vue deletes them.""" if not nodes: return node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")} self._delete_nodes(node_ids) def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None: """Handle generic node changes. Currently supports: - type == "remove": delete node and related connections - type == "position": update node position like 'node_moved' Other change types (select, dimensions, etc.) do not affect the DataGraph. """ if not changes: return node_ids_to_delete: set[str] = set() for change in changes: if not isinstance(change, dict): continue ctype = change.get("type") node_id = change.get("id") if ctype == "remove" and node_id: node_ids_to_delete.add(node_id) elif ctype == "position" and node_id: # Vue Flow usually sends 'position' for logical node coordinates. position = change.get("position") or change.get("positionAbsolute") or {} self._update_node_position(node_id, position) if node_ids_to_delete: self._delete_nodes(node_ids_to_delete) def _delete_nodes(self, node_ids: set[str]) -> None: if not node_ids: return self.graph.connections = [ c for c in self.graph.connections if c.start_node.node_id not in node_ids and c.end_node.node_id not in node_ids ] for node_id in node_ids: self.graph.nodes.pop(node_id, None)