from __future__ import annotations

import logging
from dataclasses import dataclass
from typing import Any

from dataflow.connection import Connection
from dataflow.enums import DataPortState
from dataflow.graph import DataGraph

from . import utils
from .image_data import ImageDataNode
from .text_data import TextDataNode
from .text_to_image import TextToImageNode


| | @dataclass(slots=True) |
| | class GraphController: |
| | """Handle Vue Flow events and keep the DataGraph plus node state in sync. |
| | |
| | This lives in the app layer so the core dataflow package does not |
| | depend on your concrete node types. |
| | """ |
| |
|
| | graph: DataGraph |
| |
|
| | def handle_event(self, event: dict[str, Any]) -> None: |
| | event_type = event.get("type") |
| | raw_payload = event.get("payload") |
| |
|
| | if event_type == "connect": |
| | payload = raw_payload or {} |
| | self._on_connect(payload) |
| | elif event_type == "node_moved": |
| | payload = raw_payload or {} |
| | self._on_node_moved(payload) |
| | elif event_type == "node_field_changed": |
| | payload = raw_payload or {} |
| | self._on_node_field_changed(payload) |
| | elif event_type == "edges_delete": |
| | edges = raw_payload or [] |
| | if isinstance(edges, list): |
| | self._on_edges_delete(edges) |
| | elif event_type == "edges_change": |
| | changes = raw_payload or [] |
| | if isinstance(changes, list): |
| | self._on_edges_change(changes) |
| | elif event_type == "nodes_delete": |
| | nodes = raw_payload or [] |
| | if isinstance(nodes, list): |
| | self._on_nodes_delete(nodes) |
| | elif event_type == "nodes_change": |
| | changes = raw_payload or [] |
| | if isinstance(changes, list): |
| | self._on_nodes_change(changes) |
| | |
| | |
| |
|
| | def _on_connect(self, payload: dict[str, Any]) -> None: |
| | source_handle = payload.get("sourceHandle") or "" |
| | target_handle = payload.get("targetHandle") or "" |
| |
|
| | def split(handle: str) -> tuple[str, str]: |
| | if not handle: |
| | return "", "" |
| | if ":" in handle: |
| | node_id, port = handle.split(":", 1) |
| | return node_id, port |
| | return "", handle |
| |
|
| | src_node_id, src_port_name = split(source_handle) |
| | tgt_node_id, tgt_port_name = split(target_handle) |
| |
|
| | if not src_node_id: |
| | src_node_id = payload.get("source") or "" |
| | if not tgt_node_id: |
| | tgt_node_id = payload.get("target") or "" |
| |
|
| | if not src_node_id or not tgt_node_id: |
| | return |
| |
|
| | start_node = self.graph.nodes.get(src_node_id) |
| | end_node = self.graph.nodes.get(tgt_node_id) |
| | if start_node is None or end_node is None: |
| | return |
| |
|
| | start_port = start_node.outputs.get(src_port_name) if start_node.outputs is not None else None |
| | if start_port is None and start_node.inputs is not None: |
| | start_port = start_node.inputs.get(src_port_name) |
| |
|
| | end_port = end_node.inputs.get(tgt_port_name) if end_node.inputs is not None else None |
| | if end_port is None and end_node.outputs is not None: |
| | end_port = end_node.outputs.get(tgt_port_name) |
| |
|
| | if start_port is None or end_port is None: |
| | return |
| |
|
| | conn = Connection( |
| | start_node=start_node, |
| | start_port=start_port, |
| | end_node=end_node, |
| | end_port=end_port, |
| | ) |
| |
|
| | try: |
| | self.graph.add_connection(conn) |
| | except ValueError: |
| | |
| | return |
| |
|
| | |
| | if hasattr(start_port, "state"): |
| | start_port.state = DataPortState.DIRTY |
| |
|
| | if hasattr(end_port, "state"): |
| | end_port.state = DataPortState.DIRTY |
| | |
| | if hasattr(end_port, "value"): |
| | end_port.value = None |
| |
|
| | def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None: |
| | node = self.graph.nodes.get(node_id) |
| | if node is None: |
| | return |
| |
|
| | pos = position or {} |
| | x = pos.get("x") |
| | y = pos.get("y") |
| |
|
| | if x is not None: |
| | try: |
| | node.x = float(x) |
| | except (TypeError, ValueError): |
| | pass |
| | if y is not None: |
| | try: |
| | node.y = float(y) |
| | except (TypeError, ValueError): |
| | pass |
| |
|
| | def _on_node_moved(self, payload: dict[str, Any]) -> None: |
| | node_id = payload.get("id") |
| | if not node_id: |
| | return |
| | position = payload.get("position") or {} |
| | self._update_node_position(node_id, position) |
| |
|
| | def _on_node_field_changed(self, payload: dict[str, Any]) -> None: |
| | node_id = payload.get("id") |
| | field = payload.get("field") |
| | value = payload.get("value") |
| |
|
| | if not node_id or not field: |
| | return |
| |
|
| | node = self.graph.nodes.get(node_id) |
| | if node is None: |
| | return |
| |
|
| | if isinstance(node, TextDataNode) and field == "text": |
| | port = node.outputs.get("text") if node.outputs is not None else None |
| | if port is not None: |
| | port.value = "" if value is None else str(value) |
| | port.state = DataPortState.DIRTY |
| |
|
| | elif isinstance(node, ImageDataNode) and field == "image": |
| | port = node.outputs.get("image") if node.outputs is not None else None |
| | if port is not None: |
| | if value: |
| | |
| | try: |
| | |
| | img = utils.decode_image(str(value)) |
| | port.value = img |
| | port.state = DataPortState.DIRTY |
| | except Exception: |
| | |
| | port.value = None |
| | else: |
| | port.value = None |
| | port.state = DataPortState.DIRTY |
| |
|
| | elif isinstance(node, TextToImageNode) and field == "image": |
| | node.image_src = "" if value is None else str(value) |
| |
|
| | elif isinstance(node, TextToImageNode) and field == "aspect_ratio": |
| | |
| | aspect_ratio_value = "1:1" |
| |
|
| | print(f"[DEBUG controller] Aspect ratio set to {value}, type={type(value)}") |
| |
|
| | if value is not None and isinstance(value, str): |
| | aspect_ratio_value = value.strip() |
| | else: |
| | print(f"[DEBUG controller] aspect_ratio value is None, using default 1:1") |
| |
|
| | |
| | |
| | if aspect_ratio_value and ":" in aspect_ratio_value: |
| | parts = aspect_ratio_value.split(":") |
| | if len(parts) == 2: |
| | try: |
| | |
| | float(parts[0]) |
| | float(parts[1]) |
| | old_ratio = node.aspect_ratio |
| | node.aspect_ratio = aspect_ratio_value |
| |
|
| | except (ValueError, TypeError): |
| | |
| | node.aspect_ratio = "1:1" |
| | print(f"[DEBUG Controller] Invalid numeric format, using default: {node.aspect_ratio}") |
| |
|
| | def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None: |
| | """Remove matching connections from the DataGraph when edges are deleted in Vue.""" |
| | if not edges: |
| | return |
| |
|
| | for edge in edges: |
| | if not isinstance(edge, dict): |
| | continue |
| |
|
| | src_id = edge.get("source") |
| | tgt_id = edge.get("target") |
| | src_handle = edge.get("sourceHandle") or "" |
| | tgt_handle = edge.get("targetHandle") or "" |
| |
|
| | src_port = src_handle.split(":", 1)[1] if ":" in src_handle else None |
| | tgt_port = tgt_handle.split(":", 1)[1] if ":" in tgt_handle else None |
| |
|
| | def should_remove(conn: Connection) -> bool: |
| | if src_id and conn.start_node.node_id != src_id: |
| | return False |
| | if tgt_id and conn.end_node.node_id != tgt_id: |
| | return False |
| | if src_port is not None and conn.start_port.name != src_port: |
| | return False |
| | if tgt_port is not None and conn.end_port.name != tgt_port: |
| | return False |
| | return True |
| |
|
| | self.graph.connections = [c for c in self.graph.connections if not should_remove(c)] |
| |
|
| | def _on_edges_change(self, changes: list[dict[str, Any]]) -> None: |
| | """Handle generic edge changes. |
| | |
| | Vue Flow sends EdgeChange objects. |
| | """ |
| | if not changes: |
| | return |
| |
|
| | edges_to_delete: list[dict[str, Any]] = [] |
| |
|
| | for change in changes: |
| | if not isinstance(change, dict): |
| | continue |
| |
|
| | if change.get("type") == "remove": |
| | |
| | edge = change |
| | if isinstance(edge, dict): |
| | edges_to_delete.append(edge) |
| |
|
| | if edges_to_delete: |
| | self._on_edges_delete(edges_to_delete) |
| |
|
| | def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None: |
| | """Remove nodes and all their connections when Vue deletes them.""" |
| | if not nodes: |
| | return |
| |
|
| | node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")} |
| | self._delete_nodes(node_ids) |
| |
|
| | def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None: |
| | """Handle generic node changes. |
| | |
| | Currently supports: |
| | - type == "remove": delete node and related connections |
| | - type == "position": update node position like 'node_moved' |
| | Other change types (select, dimensions, etc.) do not affect the DataGraph. |
| | """ |
| | if not changes: |
| | return |
| |
|
| | node_ids_to_delete: set[str] = set() |
| |
|
| | for change in changes: |
| | if not isinstance(change, dict): |
| | continue |
| |
|
| | ctype = change.get("type") |
| | node_id = change.get("id") |
| |
|
| | if ctype == "remove" and node_id: |
| | node_ids_to_delete.add(node_id) |
| | elif ctype == "position" and node_id: |
| | |
| | position = change.get("position") or change.get("positionAbsolute") or {} |
| | self._update_node_position(node_id, position) |
| |
|
| | if node_ids_to_delete: |
| | self._delete_nodes(node_ids_to_delete) |
| |
|
| | def _delete_nodes(self, node_ids: set[str]) -> None: |
| | if not node_ids: |
| | return |
| |
|
| | self.graph.connections = [ |
| | c |
| | for c in self.graph.connections |
| | if c.start_node.node_id not in node_ids and c.end_node.node_id not in node_ids |
| | ] |
| |
|
| | for node_id in node_ids: |
| | self.graph.nodes.pop(node_id, None) |
| |
|