Spaces:
Running
Running
| from __future__ import annotations | |
| import logging | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from velai.dataflow.connection import Connection | |
| from velai.dataflow.enums import DataPortState | |
| from velai.dataflow.graph import DataGraph | |
| from velai.graph_undo import GraphUndoStack | |
| from velai.nodes.base_node import BaseNode | |
| logger = logging.getLogger(__name__) | |
class GraphController:
    """Handle Vue Flow events and keep the DataGraph plus node state in sync.

    This lives in the app layer so the core dataflow package does not
    depend on your concrete node types.
    """

    # NOTE(review): the fields below are declared dataclass-style, but no
    # ``@dataclass`` decorator is visible on this class - confirm how
    # instances are expected to receive ``graph``.

    # Graph whose ``nodes`` mapping and ``connections`` list are mutated here.
    graph: DataGraph
    # Optional undo stack; when set, node deletions are recorded on it first.
    undo_stack: GraphUndoStack | None = None

    def handle_event(self, event: dict[str, Any]) -> None:
        """Dispatch one frontend event to the matching ``_on_*`` handler.

        ``event`` is read through its ``"type"`` and ``"payload"`` keys.
        A missing/falsy payload is replaced by an empty dict or list,
        depending on what the handler expects; a payload of the wrong
        container type is ignored instead of raising.
        """
        event_type = event.get("type")
        if not isinstance(event_type, str):
            return
        raw_payload = event.get("payload")

        # Events whose payload is a single JSON object.
        dict_handlers = {
            "connect": self._on_connect,
            "node_moved": self._on_node_moved,
            "node_resized": self._on_node_resized,
            "node_field_changed": self._on_node_field_changed,
            "node_title_changed": self._on_node_title_changed,
        }
        # Events whose payload is a JSON array.
        list_handlers = {
            "edges_delete": self._on_edges_delete,
            "edges_change": self._on_edges_change,
            "nodes_delete": self._on_nodes_delete,
            "nodes_change": self._on_nodes_change,
        }

        handler = dict_handlers.get(event_type)
        if handler is not None:
            payload = raw_payload or {}
            if isinstance(payload, dict):
                handler(payload)
            return

        handler = list_handlers.get(event_type)
        if handler is not None:
            items = raw_payload or []
            if isinstance(items, list):
                handler(items)
            return

        # other events (graph_cleared, create_node, etc.) are currently ignored
        # because they do not affect the DataGraph directly here

    @staticmethod
    def _split_handle(handle: str) -> tuple[str, str]:
        """Split a ``"<node_id>:<port>"`` handle into ``(node_id, port)``.

        A handle without a colon is treated as a bare port name and the
        node id is returned empty; an empty handle yields ``("", "")``.
        """
        if not handle:
            return "", ""
        if ":" in handle:
            node_id, port = handle.split(":", 1)
            return node_id, port
        return "", handle

    @staticmethod
    def _port_from_handle(handle: Any) -> str | None:
        """Return the port part of ``"<node_id>:<port>"``, or None without a colon."""
        if isinstance(handle, str) and ":" in handle:
            return handle.split(":", 1)[1]
        return None

    @staticmethod
    def _lookup_port(name: str, *port_maps: Any) -> Any:
        """Return the first non-None ``port_maps[i].get(name)``; skip None maps."""
        for ports in port_maps:
            if ports is None:
                continue
            port = ports.get(name)
            if port is not None:
                return port
        return None

    @staticmethod
    def _to_float(value: Any) -> float | None:
        """Convert ``value`` to float, returning None when missing or invalid."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    def _on_connect(self, payload: dict[str, Any]) -> None:
        """Create a Connection for a new edge and mark its end port dirty.

        Node ids come from the handle prefix when present, otherwise from
        the payload's ``"source"`` / ``"target"`` keys. The event is dropped
        silently when a node or port cannot be resolved, or when
        ``graph.add_connection`` raises ``ValueError``.
        """
        src_node_id, src_port_name = self._split_handle(payload.get("sourceHandle") or "")
        tgt_node_id, tgt_port_name = self._split_handle(payload.get("targetHandle") or "")

        if not src_node_id:
            src_node_id = payload.get("source") or ""
        if not tgt_node_id:
            tgt_node_id = payload.get("target") or ""
        if not src_node_id or not tgt_node_id:
            return

        start_node = self.graph.nodes.get(src_node_id)
        end_node = self.graph.nodes.get(tgt_node_id)
        if start_node is None or end_node is None:
            return

        # The start port is searched in outputs first, the end port in
        # inputs first; each falls back to the opposite side.
        start_port = self._lookup_port(src_port_name, start_node.outputs, start_node.inputs)
        end_port = self._lookup_port(tgt_port_name, end_node.inputs, end_node.outputs)
        if start_port is None or end_port is None:
            return

        conn = Connection(
            start_node=start_node,
            start_port=start_port,
            end_node=end_node,
            end_port=end_port,
        )
        try:
            self.graph.add_connection(conn)
        except ValueError:
            # datatype mismatch or capacity problems
            return

        # make the end port "dirty" because topology changed
        end_port.state = DataPortState.DIRTY
        end_port.value = None

    def _update_node_position(self, node_id: str, position: dict[str, Any]) -> None:
        """Copy ``position["x"]`` / ``position["y"]`` onto the node as floats.

        Unknown nodes, non-dict positions, and missing or non-numeric
        coordinates are ignored; each axis is updated independently.
        """
        node = self.graph.nodes.get(node_id)
        if node is None:
            return
        pos = position or {}
        if not isinstance(pos, dict):
            return

        x = self._to_float(pos.get("x"))
        if x is not None:
            node.x = x
        y = self._to_float(pos.get("y"))
        if y is not None:
            node.y = y

    def _on_node_moved(self, payload: dict[str, Any]) -> None:
        """Apply the payload's ``"position"`` to the node named by ``"id"``."""
        node_id = payload.get("id")
        if not node_id:
            return
        self._update_node_position(node_id, payload.get("position") or {})

    def _on_node_resized(self, payload: dict[str, Any]) -> None:
        """Store the payload's ``"width"`` / ``"height"`` on the node as floats.

        Missing or non-numeric dimensions are ignored, mirroring how
        positions are handled, so a malformed event cannot raise.
        """
        node_id = payload.get("id")
        if not node_id:
            return
        node = self.graph.nodes.get(node_id)
        if node is None:
            return

        width = self._to_float(payload.get("width"))
        if width is not None:
            node.width = width
        height = self._to_float(payload.get("height"))
        if height is not None:
            node.height = height

    def _on_node_field_changed(self, payload: dict[str, Any]) -> None:
        """Forward a single ``{field: value}`` update to ``node.set_state``.

        Only nodes that are ``BaseNode`` instances are updated.
        """
        node_id = payload.get("id")
        field = payload.get("field")
        if not node_id or not field:
            return
        node = self.graph.nodes.get(node_id)
        if not isinstance(node, BaseNode):
            return
        node.set_state({field: payload.get("value")})

    def _on_node_title_changed(self, payload: dict[str, Any]) -> None:
        """Set ``node.data.custom_title`` to the stripped title, or None if blank."""
        node_id = payload.get("id")
        if not node_id:
            return
        node = self.graph.nodes.get(node_id)
        if not isinstance(node, BaseNode):
            return
        title = payload.get("title")
        cleaned = str(title).strip() if title else ""
        node.data.custom_title = cleaned or None

    def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None:
        """Remove matching connections from the DataGraph when edges are deleted in Vue.

        Each edge is matched on whichever of ``source``, ``target`` and the
        port parts of ``sourceHandle`` / ``targetHandle`` it provides. An
        edge that provides none of them is skipped: it would otherwise
        match, and delete, every connection in the graph.
        """
        for edge in edges:
            if not isinstance(edge, dict):
                continue
            src_id = edge.get("source")
            tgt_id = edge.get("target")
            src_port = self._port_from_handle(edge.get("sourceHandle"))
            tgt_port = self._port_from_handle(edge.get("targetHandle"))

            if not src_id and not tgt_id and src_port is None and tgt_port is None:
                # Nothing identifies this edge; never treat it as a wildcard.
                continue

            self.graph.connections = [
                conn
                for conn in self.graph.connections
                if not self._connection_matches(conn, src_id, tgt_id, src_port, tgt_port)
            ]

    @staticmethod
    def _connection_matches(
        conn: Connection,
        src_id: Any,
        tgt_id: Any,
        src_port: str | None,
        tgt_port: str | None,
    ) -> bool:
        """Return True when ``conn`` agrees with every criterion that is set."""
        if src_id and conn.start_node.node_id != src_id:
            return False
        if tgt_id and conn.end_node.node_id != tgt_id:
            return False
        if src_port is not None and conn.start_port.name != src_port:
            return False
        if tgt_port is not None and conn.end_port.name != tgt_port:
            return False
        return True

    def _on_edges_change(self, changes: list[dict[str, Any]]) -> None:
        """Handle generic edge changes.

        Vue Flow sends EdgeChange objects. Only ``type == "remove"`` is
        acted on.
        """
        # this is quite ugly, the edge should be sent directly or read from
        # the list of edges from the graph: the change object itself is used
        # as the edge description (assumes it carries source/target and the
        # handles - confirm against the frontend).
        edges_to_delete = [
            change
            for change in changes
            if isinstance(change, dict) and change.get("type") == "remove"
        ]
        if edges_to_delete:
            self._on_edges_delete(edges_to_delete)

    def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None:
        """Remove nodes and all their connections when Vue deletes them."""
        if not nodes:
            return
        node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")}
        self._delete_nodes(node_ids)

    def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None:
        """Handle generic node changes.

        Currently supports:
        - type == "remove": delete node and related connections
        - type == "position": update node position like 'node_moved'
        Other change types (select, dimensions, etc.) do not affect the DataGraph.
        """
        node_ids_to_delete: set[str] = set()
        for change in changes:
            if not isinstance(change, dict):
                continue
            node_id = change.get("id")
            if not node_id:
                continue
            ctype = change.get("type")
            if ctype == "remove":
                node_ids_to_delete.add(node_id)
            elif ctype == "position":
                # Vue Flow usually sends 'position' for logical node coordinates.
                position = change.get("position") or change.get("positionAbsolute") or {}
                self._update_node_position(node_id, position)

        # Deletions are batched so the undo stack sees them as one operation.
        if node_ids_to_delete:
            self._delete_nodes(node_ids_to_delete)

    def _delete_nodes(self, node_ids: set[str]) -> None:
        """Drop the given nodes and every connection touching them.

        When an undo stack is attached, the deletion is recorded on it
        before the graph is modified.
        """
        if not node_ids:
            return
        if self.undo_stack is not None:
            self.undo_stack.record_node_deletion(self.graph, node_ids)

        self.graph.connections = [
            c
            for c in self.graph.connections
            if c.start_node.node_id not in node_ids and c.end_node.node_id not in node_ids
        ]
        for node_id in node_ids:
            self.graph.nodes.pop(node_id, None)