velai / nodes /session.py
cansik's picture
Upload folder via script
abd08cb verified
from __future__ import annotations
import io
import itertools
import json
from dataclasses import dataclass, field
from typing import Any
from PIL import Image
from nicegui import app, ui
from dataflow.codecs import connection_to_vueflow_edge
from dataflow.connection import Connection
from dataflow.enums import NodeKind
from dataflow.graph import DataGraph
from dataflow.registry import Registry
from dataflow.ui.vueflow_canvas import VueFlowCanvas
from .controller import GraphController
from .image_data import ImageDataNodeType, ImageDataNode, ImageDataNodeRenderable
from .runtime import GraphRuntime
from .text_data import TextDataNodeType, TextDataNode, TextDataNodeRenderable
from .text_to_image import TextToImageNodeType, TextToImageNode, TextToImageNodeRenderable
from .vue_nodes import VueNodeRenderer
@dataclass(slots=True)
class GraphSession:
"""Per browser tab graph session."""
registry: Registry
graph: DataGraph
renderer: VueNodeRenderer
controller: GraphController
runtime: GraphRuntime | None = None
_node_id_counter: itertools.count = field(
default_factory=lambda: itertools.count(1)
)
@classmethod
def create_default(cls) -> "GraphSession":
"""Create a new session with default node types and starter graph."""
# Register new node types here
registry = Registry()
registry.register_node_type(
TextDataNodeType, lambda node_id: TextDataNode(node_id, TextDataNodeType)
)
registry.register_node_type(
ImageDataNodeType, lambda node_id: ImageDataNode(node_id, ImageDataNodeType)
)
registry.register_node_type(
TextToImageNodeType,
lambda node_id: TextToImageNode(node_id, TextToImageNodeType),
)
graph = DataGraph()
controller = GraphController(graph=graph)
renderer = VueNodeRenderer()
# Register new node renderables here!
renderer.register(TextDataNode, TextDataNodeRenderable())
renderer.register(ImageDataNode, ImageDataNodeRenderable())
renderer.register(TextToImageNode, TextToImageNodeRenderable())
session = cls(
registry=registry,
graph=graph,
renderer=renderer,
controller=controller,
)
# Try to restore from storage, otherwise build initial graph
if not session.restore_from_storage():
session._build_initial_graph()
return session
def _build_initial_graph(self) -> None:
"""Add starter nodes to a fresh graph."""
text_node = self.registry.create(TextDataNodeType, "t1")
text2img_node = self.registry.create(TextToImageNodeType, "n2")
text_node.x, text_node.y = 100, 100
text2img_node.x, text2img_node.y = 420, 100
self.graph.add_node(text_node)
self.graph.add_node(text2img_node)
edge = Connection(text_node, text_node.outputs["text"],
text2img_node, text2img_node.inputs["text"])
self.graph.add_connection(edge)
@property
def creatable_node_types(self) -> list[dict[str, str]]:
"""Metadata for the node types the UI may create."""
return [
{
"kind": TextDataNodeType.kind.value,
"title": TextDataNodeType.display_name,
},
{
"kind": ImageDataNodeType.kind.value,
"title": ImageDataNodeType.display_name,
},
{
"kind": TextToImageNodeType.kind.value,
"title": TextToImageNodeType.display_name,
},
]
def save_to_storage(self) -> None:
"""Persist current graph state to tab storage."""
json_str = self.to_json()
app.storage.tab["graph_json"] = json_str
def restore_from_storage(self) -> bool:
"""Try to load graph from tab storage. Returns True if successful."""
json_str = app.storage.tab.get("graph_json")
if json_str:
self.load_from_json(json_str)
return True
return False
def attach_canvas(self, canvas: VueFlowCanvas) -> None:
"""Attach a VueFlowCanvas to this session."""
self.runtime = GraphRuntime(graph=self.graph, canvas=canvas)
def initial_vue_nodes(self) -> list[dict[str, Any]]:
"""Serialize current graph nodes to Vue Flow node dicts with UI metadata."""
return [self.renderer.to_vue_node(n) for n in self.graph.nodes.values()]
def initial_vue_edges(self) -> list[dict[str, Any]]:
"""Serialize current graph connections to Vue Flow edge dicts."""
return [connection_to_vueflow_edge(c) for c in self.graph.connections]
def to_json(self) -> str:
"""Export graph to JSON."""
nodes_data = []
for node in self.graph.nodes.values():
# Extract UI values using the renderer logic
vue_data = self.renderer.to_vue_node(node)["data"].get("values", {})
nodes_data.append({
"id": node.node_id,
"kind": node.node_type.kind.value,
"x": node.x,
"y": node.y,
"values": vue_data
})
edges_data = []
for c in self.graph.connections:
edges_data.append({
"source": c.start_node.node_id,
"sourceHandle": c.start_port.name,
"target": c.end_node.node_id,
"targetHandle": c.end_port.name
})
return json.dumps({"nodes": nodes_data, "edges": edges_data}, indent=2)
def load_from_json(self, json_str: str) -> None:
"""Import graph from JSON."""
try:
data = json.loads(json_str)
except Exception:
return
self.clear_graph()
# Recreate nodes
for n_data in data.get("nodes", []):
kind_val = n_data.get("kind")
node_id = n_data.get("id")
kind = self._node_kind_from_value(kind_val)
if not kind or not node_id:
continue
node = self.registry.create(kind, node_id)
node.x = n_data.get("x", 0)
node.y = n_data.get("y", 0)
# Apply values via controller to ensure side-effects (like updating port values) run
values = n_data.get("values", {})
self.graph.add_node(node) # Add first so controller finds it
for field, value in values.items():
self.controller._on_node_field_changed({"id": node_id, "field": field, "value": value})
# Update id counter to avoid collisions with new nodes
if node_id.startswith("u"):
try:
num = int(node_id[1:])
if num >= next(self._node_id_counter) - 1:
self._node_id_counter = itertools.count(num + 1)
except:
pass
# Recreate edges
from dataflow.connection import Connection
for e_data in data.get("edges", []):
src_node = self.graph.nodes.get(e_data.get("source"))
tgt_node = self.graph.nodes.get(e_data.get("target"))
if src_node and tgt_node:
src_port = src_node.outputs.get(e_data.get("sourceHandle"))
tgt_port = tgt_node.inputs.get(e_data.get("targetHandle"))
if src_port and tgt_port:
try:
conn = Connection(src_node, src_port, tgt_node, tgt_port)
self.graph.add_connection(conn)
except ValueError:
pass
def next_ui_node_id(self) -> str:
return f"u{next(self._node_id_counter)}"
def clear_graph(self) -> None:
"""Remove all nodes and connections."""
self.graph.nodes.clear()
self.graph.connections.clear()
self._node_id_counter = itertools.count(1)
def _node_kind_from_value(self, kind_value: str | None) -> NodeKind | None:
if not kind_value:
return None
try:
return NodeKind(kind_value)
except Exception:
pass
for kind, node_type in self.registry.node_types.items():
if kind.value == kind_value or node_type.display_name == kind_value:
return kind
return None
def create_node(
self, kind_value: str | None, position: dict[str, Any] | None = None
) -> dict[str, Any] | None:
"""Create a new node in the graph and return its Vue node dict."""
kind = self._node_kind_from_value(kind_value)
if kind is None:
return None
node_id = self.next_ui_node_id()
try:
node_obj = self.registry.create(kind, node_id)
except KeyError:
return None
if position:
x = position.get("x")
y = position.get("y")
if x is not None:
try:
node_obj.x = float(x)
except (TypeError, ValueError):
pass
if y is not None:
try:
node_obj.y = float(y)
except (TypeError, ValueError):
pass
self.graph.add_node(node_obj)
return self.renderer.to_vue_node(node_obj)
def _auto_connect_new_node(
self, vue_node: dict[str, Any], pending_connection: dict[str, Any], canvas: VueFlowCanvas
) -> None:
"""Automatically connect a newly created node based on pending connection info."""
if not pending_connection or not vue_node:
return
new_node_id = vue_node.get("id")
if not new_node_id:
return
new_node_obj = self.graph.nodes.get(new_node_id)
if not new_node_obj:
return
# Extract pending connection info
source_node_id = pending_connection.get("nodeId")
source_handle_id = pending_connection.get("handleId")
handle_type = pending_connection.get("handleType")
if not source_node_id or not source_handle_id or not handle_type:
return
source_node = self.graph.nodes.get(source_node_id)
if not source_node:
return
# Parse handle ID to get port name
source_port_name = source_handle_id.split(":")[-1] if ":" in source_handle_id else source_handle_id
# Determine connection direction and find compatible port
if handle_type == "source":
# User dragged from an output, connect to first compatible input of new node
source_port = source_node.outputs.get(source_port_name)
if not source_port:
return
# Find first compatible input port on new node
target_port = None
target_port_name = None
for port_name, port in new_node_obj.inputs.items():
if port.schema.dtype.id == source_port.schema.dtype.id:
target_port = port
target_port_name = port_name
break
if target_port:
# Create connection event
connection_params = {
"source": source_node_id,
"target": new_node_id,
"sourceHandle": source_handle_id,
"targetHandle": f"{new_node_id}:{target_port_name}"
}
# Emit connect event to controller
self.controller.handle_event({
"type": "connect",
"payload": connection_params
})
# Add edge to canvas
edge_id = f"{source_node_id}-{source_handle_id}->{new_node_id}-{target_port_name}"
canvas.add_edge({
"id": edge_id,
"source": source_node_id,
"target": new_node_id,
"sourceHandle": source_handle_id,
"targetHandle": f"{new_node_id}:{target_port_name}"
})
elif handle_type == "target":
# User dragged from an input, connect from first compatible output of new node
target_port = source_node.inputs.get(source_port_name)
if not target_port:
return
# Find first compatible output port on new node
source_port = None
source_port_name_new = None
for port_name, port in new_node_obj.outputs.items():
if port.schema.dtype.id == target_port.schema.dtype.id:
source_port = port
source_port_name_new = port_name
break
if source_port:
# Create connection event
connection_params = {
"source": new_node_id,
"target": source_node_id,
"sourceHandle": f"{new_node_id}:{source_port_name_new}",
"targetHandle": source_handle_id
}
# Emit connect event to controller
self.controller.handle_event({
"type": "connect",
"payload": connection_params
})
# Add edge to canvas
edge_id = f"{new_node_id}-{source_port_name_new}->{source_node_id}-{source_handle_id}"
canvas.add_edge({
"id": edge_id,
"source": new_node_id,
"target": source_node_id,
"sourceHandle": f"{new_node_id}:{source_port_name_new}",
"targetHandle": source_handle_id
})
def duplicate_node(
self, source_node_id: str, position: dict[str, Any] | None = None
) -> dict[str, Any] | None:
"""Duplicate an existing node and return its Vue node dict."""
# Get the source node from the graph
source_node = self.graph.nodes.get(source_node_id)
if source_node is None:
return None
# Create a new node of the same type
node_id = self.next_ui_node_id()
try:
node_obj = self.registry.create(source_node.node_type.kind, node_id)
except KeyError:
return None
# Set position
if position:
x = position.get("x")
y = position.get("y")
if x is not None:
try:
node_obj.x = float(x)
except (TypeError, ValueError):
pass
if y is not None:
try:
node_obj.y = float(y)
except (TypeError, ValueError):
pass
else:
# Offset from source node if no position specified
node_obj.x = source_node.x + 50
node_obj.y = source_node.y + 50
# Copy input port values (but not outputs - those should be recomputed)
for port_name, source_port in source_node.inputs.items():
if port_name in node_obj.inputs:
# Copy the value if it's set
if source_port.value is not None:
node_obj.inputs[port_name].value = source_port.value
# Copy content based on node type
# Copy text content for TextDataNode
if isinstance(source_node, TextDataNode) and isinstance(node_obj, TextDataNode):
source_port = source_node.outputs.get("text") if source_node.outputs else None
if source_port and source_port.value is not None:
target_port = node_obj.outputs.get("text") if node_obj.outputs else None
if target_port:
target_port.value = str(source_port.value)
# Copy image content for ImageDataNode
elif isinstance(source_node, ImageDataNode) and isinstance(node_obj, ImageDataNode):
source_port = source_node.outputs.get("image") if source_node.outputs else None
if source_port and source_port.value is not None:
target_port = node_obj.outputs.get("image") if node_obj.outputs else None
if target_port:
# Copy PIL Image
if isinstance(source_port.value, Image.Image):
target_port.value = source_port.value.copy()
else:
target_port.value = source_port.value
# Copy generated image for TextToImageNode
elif isinstance(source_node, TextToImageNode) and isinstance(node_obj, TextToImageNode):
# Copy image_src
if source_node.image_src:
node_obj.image_src = source_node.image_src
# Copy decoded_image
if source_node.decoded_image is not None:
node_obj.decoded_image = source_node.decoded_image.copy()
# copy the output port value if it exists
source_port = source_node.outputs.get("image") if source_node.outputs else None
if source_port and source_port.value is not None:
target_port = node_obj.outputs.get("image") if node_obj.outputs else None
if target_port:
if isinstance(source_port.value, Image.Image):
target_port.value = source_port.value.copy()
else:
target_port.value = source_port.value
self.graph.add_node(node_obj)
return self.renderer.to_vue_node(node_obj)
async def handle_ui_event(
self, event: dict[str, Any], canvas: VueFlowCanvas
) -> None:
"""Central handler for all Vue Flow events from this tab."""
event_type = event.get("type")
payload = event.get("payload")
save_needed = False
if event_type == "execute_node":
payload_dict = payload or {}
node_id = payload_dict.get("id")
if not node_id:
return
if self.runtime is None:
self.attach_canvas(canvas)
else:
self.runtime.canvas = canvas
n = ui.notification(timeout=None)
n.message = f"Generating"
n.spinner = True
try:
await self.runtime.execute_node(node_id)
n.message = "Done!"
n.type = "positive"
n.spinner = False
except Exception as ex:
n.message = f"Error: {ex}"
n.type = "negative"
n.spinner = False
n.dismiss()
save_needed = True
elif event_type == "reset_node":
payload_dict = payload or {}
node_id = payload_dict.get("id")
if not node_id:
return
node = self.graph.nodes.get(node_id)
if node is None:
return
with ui.dialog() as dialog, ui.card():
ui.label("Are you sure?")
with ui.row():
ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
ui.button("No", on_click=lambda: dialog.submit("No"))
result = await dialog
if result == "No":
return
# only special case TextToImageNode for now
if isinstance(node, TextToImageNode):
node.reset_node()
# sync cleared state to UI
canvas.update_node_values(
node_id, {
"image": "",
"error": None,
})
save_needed = True
ui.notify(f"Node has been reset!", type="positive")
elif event_type == "download_node":
payload_dict = payload or {}
node_id = payload_dict.get("id")
if not node_id:
return
node = self.graph.nodes.get(node_id)
if node is None:
return
if isinstance(node, TextToImageNode):
# Try to get the image from decoded_image first, then from image_src
img = node.decoded_image
if img is None and node.image_src:
try:
from . import utils
img = utils.decode_image(node.image_src)
except Exception as e:
ui.notify(f"Failed to decode image: {e}", type="negative")
return
if img is None:
ui.notify("No image available to download", type="warning")
return
buf = io.BytesIO()
img.save(buf, format="PNG")
buf.seek(0)
png_bytes = buf.getvalue()
ui.download(png_bytes, filename="generated.png", media_type="image/png")
ui.notify("Image downloaded!", type="positive")
elif event_type == "create_node":
payload_dict = payload or {}
kind_value = payload_dict.get("kind")
position = payload_dict.get("position") or {}
pending_connection = payload_dict.get("pendingConnection")
vue_node = self.create_node(kind_value, position)
if vue_node is not None:
canvas.add_node(vue_node)
save_needed = True
# If there's a pending connection, auto-connect
if pending_connection:
self._auto_connect_new_node(vue_node, pending_connection, canvas)
elif event_type == "duplicate_node":
payload_dict = payload or {}
source_node_id = payload_dict.get("sourceNodeId")
position = payload_dict.get("position") or {}
if source_node_id:
vue_node = self.duplicate_node(source_node_id, position)
if vue_node is not None:
canvas.add_node(vue_node)
save_needed = True
else:
self.controller.handle_event(event)
# Assume any controller action (move, connect, delete, field change) modifies state
save_needed = True
if save_needed:
self.save_to_storage()