velai-workshop / velai /controller.py
kratadata's picture
Upload folder via script
0f8b3a0 verified
from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import Any
from velai.dataflow.connection import Connection
from velai.dataflow.enums import DataPortState
from velai.dataflow.graph import DataGraph
from velai.graph_undo import GraphUndoStack
from velai.nodes.base_node import BaseNode
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
@dataclass
class GraphController:
    """Handle Vue Flow events and keep the DataGraph plus node state in sync.

    This lives in the app layer so the core dataflow package does not
    depend on your concrete node types.

    Attributes:
        graph: The DataGraph whose ``nodes`` mapping and ``connections`` list
            are mutated in response to UI events.
        undo_stack: Optional undo stack; when set, node deletions are recorded
            on it before they are applied.
    """

    graph: DataGraph
    undo_stack: GraphUndoStack | None = None

    # ------------------------------------------------------------------
    # Event dispatch
    # ------------------------------------------------------------------
    def handle_event(self, event: dict[str, Any]) -> None:
        """Dispatch a single Vue Flow event to the matching handler.

        ``event`` carries a ``"type"`` string and an optional ``"payload"``.
        Events whose payload has the wrong container type (dict vs. list) are
        ignored rather than raising. Other event types (graph_cleared,
        create_node, etc.) are currently ignored because they do not affect
        the DataGraph directly here.
        """
        event_type = event.get("type")
        if not isinstance(event_type, str):
            # Non-string types can never match a known event (and may be
            # unhashable), so ignore them like any other unknown event.
            return
        raw_payload = event.get("payload")

        # Events whose payload is a single object.
        dict_handlers = {
            "connect": self._on_connect,
            "node_moved": self._on_node_moved,
            "node_resized": self._on_node_resized,
            "node_field_changed": self._on_node_field_changed,
            "node_title_changed": self._on_node_title_changed,
        }
        # Events whose payload is a list of items.
        list_handlers = {
            "edges_delete": self._on_edges_delete,
            "edges_change": self._on_edges_change,
            "nodes_delete": self._on_nodes_delete,
            "nodes_change": self._on_nodes_change,
        }

        if event_type in dict_handlers:
            payload = raw_payload or {}
            if isinstance(payload, dict):
                dict_handlers[event_type](payload)
        elif event_type in list_handlers:
            items = raw_payload or []
            if isinstance(items, list):
                list_handlers[event_type](items)

    # ------------------------------------------------------------------
    # Small parsing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _split_handle(handle: Any) -> tuple[str, str]:
        """Split a Vue Flow handle id into ``(node_id, port_name)``.

        ``"<node_id>:<port>"`` yields both parts (split on the first ``":"``
        only); a handle without ``":"`` is a bare port name with an empty node
        id; an empty or missing handle yields ``("", "")``.
        """
        if not handle:
            return "", ""
        text = str(handle)
        node_id, sep, port = text.partition(":")
        if sep:
            return node_id, port
        return "", text

    @staticmethod
    def _to_float(value: Any) -> float | None:
        """Return ``value`` as a float, or ``None`` if it is missing or not numeric."""
        if value is None:
            return None
        try:
            return float(value)
        except (TypeError, ValueError):
            return None

    @staticmethod
    def _find_port(node: Any, port_name: str, *, prefer_outputs: bool) -> Any:
        """Look up ``port_name`` on ``node``, trying one port mapping then the other.

        With ``prefer_outputs`` the node's ``outputs`` are searched before its
        ``inputs``; otherwise the reverse. A mapping that is ``None`` is skipped.
        Returns ``None`` when the port is not found in either mapping.
        """
        order = ("outputs", "inputs") if prefer_outputs else ("inputs", "outputs")
        for attr in order:
            ports = getattr(node, attr)
            if ports is None:
                continue
            port = ports.get(port_name)
            if port is not None:
                return port
        return None

    # ------------------------------------------------------------------
    # Connections
    # ------------------------------------------------------------------
    def _on_connect(self, payload: dict[str, Any]) -> None:
        """Create a Connection for a Vue Flow ``connect`` event.

        Node ids come from the handle prefix when present, otherwise from the
        payload's ``source``/``target``. The source port is searched in the
        start node's outputs, then inputs; the target port in the end node's
        inputs, then outputs. The event is dropped when a node or port cannot
        be resolved or when the graph rejects the connection.
        """
        src_node_id, src_port_name = self._split_handle(payload.get("sourceHandle"))
        tgt_node_id, tgt_port_name = self._split_handle(payload.get("targetHandle"))
        src_node_id = src_node_id or payload.get("source") or ""
        tgt_node_id = tgt_node_id or payload.get("target") or ""
        if not src_node_id or not tgt_node_id:
            return

        start_node = self.graph.nodes.get(src_node_id)
        end_node = self.graph.nodes.get(tgt_node_id)
        if start_node is None or end_node is None:
            return

        start_port = self._find_port(start_node, src_port_name, prefer_outputs=True)
        end_port = self._find_port(end_node, tgt_port_name, prefer_outputs=False)
        if start_port is None or end_port is None:
            return

        conn = Connection(
            start_node=start_node,
            start_port=start_port,
            end_node=end_node,
            end_port=end_port,
        )
        try:
            self.graph.add_connection(conn)
        except ValueError as exc:
            # Datatype mismatch or capacity problems: the connection is
            # rejected, but leave a trace instead of dropping it silently.
            logger.debug(
                "Rejected connection %s:%s -> %s:%s: %s",
                src_node_id,
                src_port_name,
                tgt_node_id,
                tgt_port_name,
                exc,
            )
            return

        # Topology changed: invalidate the end port so it gets recomputed.
        end_port.state = DataPortState.DIRTY
        end_port.value = None

    def _on_edges_delete(self, edges: list[dict[str, Any]]) -> None:
        """Remove matching connections from the DataGraph when edges are deleted in Vue.

        Each edge dict yields up to four match criteria: source node id,
        source port name, target node id and target port name. A missing
        criterion matches anything. Node ids come from ``source``/``target``,
        falling back to the node prefix of the handle. Handles are parsed the
        same way as in ``_on_connect``.

        Edges that carry no criteria at all are skipped. Otherwise they would
        match, and delete, every connection in the graph.
        """
        if not edges:
            return
        for edge in edges:
            if not isinstance(edge, dict):
                continue

            src_handle = edge.get("sourceHandle") or ""
            tgt_handle = edge.get("targetHandle") or ""
            src_handle_node, src_handle_port = self._split_handle(src_handle)
            tgt_handle_node, tgt_handle_port = self._split_handle(tgt_handle)

            # None means "no constraint"; a non-empty handle pins the port
            # name exactly (even if that name is empty after the colon).
            src_id = edge.get("source") or src_handle_node or None
            tgt_id = edge.get("target") or tgt_handle_node or None
            src_port = src_handle_port if src_handle else None
            tgt_port = tgt_handle_port if tgt_handle else None

            if src_id is None and tgt_id is None and src_port is None and tgt_port is None:
                continue

            self.graph.connections = [
                c
                for c in self.graph.connections
                if not (
                    (src_id is None or c.start_node.node_id == src_id)
                    and (tgt_id is None or c.end_node.node_id == tgt_id)
                    and (src_port is None or c.start_port.name == src_port)
                    and (tgt_port is None or c.end_port.name == tgt_port)
                )
            ]

    def _on_edges_change(self, changes: list[dict[str, Any]]) -> None:
        """Handle generic edge changes sent by Vue Flow.

        Only ``type == "remove"`` changes are acted on. Each one is passed to
        ``_on_edges_delete`` as an edge descriptor.
        """
        if not changes:
            return
        # NOTE(review): this relies on remove changes carrying source/target/
        # handle fields. Ideally the edge would be resolved by id from the
        # graph's edge list instead; changes without those fields are skipped
        # by _on_edges_delete.
        removals = [
            change
            for change in changes
            if isinstance(change, dict) and change.get("type") == "remove"
        ]
        if removals:
            self._on_edges_delete(removals)

    # ------------------------------------------------------------------
    # Node geometry and state
    # ------------------------------------------------------------------
    def _update_node_position(self, node_id: str, position: dict[str, Any] | None) -> None:
        """Set ``node.x`` / ``node.y`` from a ``{"x": ..., "y": ...}`` mapping.

        Missing, non-numeric or non-dict values are ignored, so a partial
        update only changes the coordinates it can parse.
        """
        node = self.graph.nodes.get(node_id)
        if node is None or not isinstance(position, dict):
            return
        x = self._to_float(position.get("x"))
        y = self._to_float(position.get("y"))
        if x is not None:
            node.x = x
        if y is not None:
            node.y = y

    def _on_node_moved(self, payload: dict[str, Any]) -> None:
        """Apply a ``node_moved`` event (``{"id", "position"}``) to the node."""
        node_id = payload.get("id")
        if not node_id:
            return
        self._update_node_position(node_id, payload.get("position") or {})

    def _on_node_resized(self, payload: dict[str, Any]) -> None:
        """Apply a ``node_resized`` event (``{"id", "width", "height"}``) to the node.

        Non-numeric dimensions are ignored, matching how positions are handled.
        """
        node_id = payload.get("id")
        if not node_id:
            return
        node = self.graph.nodes.get(node_id)
        if node is None:
            return
        width = self._to_float(payload.get("width"))
        height = self._to_float(payload.get("height"))
        if width is not None:
            node.width = width
        if height is not None:
            node.height = height

    def _on_node_field_changed(self, payload: dict[str, Any]) -> None:
        """Forward a ``{"id", "field", "value"}`` change to ``BaseNode.set_state``."""
        node_id = payload.get("id")
        field = payload.get("field")
        if not node_id or not field:
            return
        node = self.graph.nodes.get(node_id)
        if not isinstance(node, BaseNode):
            return
        node.set_state({field: payload.get("value")})

    def _on_node_title_changed(self, payload: dict[str, Any]) -> None:
        """Set a node's custom title; a blank or missing title clears it (``None``)."""
        node_id = payload.get("id")
        if not node_id:
            return
        node = self.graph.nodes.get(node_id)
        if not isinstance(node, BaseNode):
            return
        title = payload.get("title")
        cleaned = str(title).strip() if title else ""
        node.data.custom_title = cleaned or None

    # ------------------------------------------------------------------
    # Node deletion
    # ------------------------------------------------------------------
    def _on_nodes_delete(self, nodes: list[dict[str, Any]]) -> None:
        """Remove nodes and all their connections when Vue deletes them."""
        if not nodes:
            return
        node_ids = {n.get("id") for n in nodes if isinstance(n, dict) and n.get("id")}
        self._delete_nodes(node_ids)

    def _on_nodes_change(self, changes: list[dict[str, Any]]) -> None:
        """Handle generic node changes.

        Currently supports:
        - type == "remove": delete node and related connections
        - type == "position": update node position like 'node_moved'
        Other change types (select, dimensions, etc.) do not affect the DataGraph.
        """
        if not changes:
            return
        node_ids_to_delete: set[str] = set()
        for change in changes:
            if not isinstance(change, dict):
                continue
            ctype = change.get("type")
            node_id = change.get("id")
            if not node_id:
                continue
            if ctype == "remove":
                node_ids_to_delete.add(node_id)
            elif ctype == "position":
                # Prefer the logical 'position'; fall back to 'positionAbsolute'.
                position = change.get("position") or change.get("positionAbsolute") or {}
                self._update_node_position(node_id, position)
        if node_ids_to_delete:
            self._delete_nodes(node_ids_to_delete)

    def _delete_nodes(self, node_ids: set[str]) -> None:
        """Remove ``node_ids`` from the graph together with every attached connection.

        Ids not present in ``graph.nodes`` are ignored. The undo stack (when
        configured) is only asked to record nodes that are actually removed.
        """
        if not node_ids:
            return
        existing = {node_id for node_id in node_ids if node_id in self.graph.nodes}
        if not existing:
            return
        if self.undo_stack is not None:
            # Record before mutating so the undo stack sees the pre-delete graph.
            self.undo_stack.record_node_deletion(self.graph, existing)
        self.graph.connections = [
            c
            for c in self.graph.connections
            if c.start_node.node_id not in existing and c.end_node.node_id not in existing
        ]
        for node_id in existing:
            self.graph.nodes.pop(node_id, None)