from __future__ import annotations

import io
import itertools
import json
from dataclasses import dataclass, field
from typing import Any

from PIL import Image
from nicegui import app, ui

from dataflow.codecs import connection_to_vueflow_edge
from dataflow.connection import Connection
from dataflow.enums import NodeKind
from dataflow.graph import DataGraph
from dataflow.registry import Registry
from dataflow.ui.vueflow_canvas import VueFlowCanvas

from .controller import GraphController
from .image_data import ImageDataNodeType, ImageDataNode, ImageDataNodeRenderable
from .runtime import GraphRuntime
from .text_data import TextDataNodeType, TextDataNode, TextDataNodeRenderable
from .text_to_image import TextToImageNodeType, TextToImageNode, TextToImageNodeRenderable
from .vue_nodes import VueNodeRenderer

# Prefix of node ids created from the UI ("u1", "u2", ...); see next_ui_node_id.
_UI_NODE_ID_PREFIX = "u"


@dataclass(slots=True)
class GraphSession:
    """Per browser tab graph session.

    Bundles the node-type registry, the graph model, the Vue Flow renderer and
    the controller for one browser tab, and persists the graph as JSON under
    the ``"graph_json"`` key of ``app.storage.tab``.
    """

    registry: Registry
    graph: DataGraph
    renderer: VueNodeRenderer
    controller: GraphController
    runtime: GraphRuntime | None = None
    # Supplies the numeric suffix of UI-created node ids.
    _node_id_counter: itertools.count = field(
        default_factory=lambda: itertools.count(1)
    )

    @classmethod
    def create_default(cls) -> GraphSession:
        """Create a new session with default node types and starter graph.

        The graph is restored from tab storage when a valid saved graph exists;
        otherwise the starter graph from ``_build_initial_graph`` is used.
        """
        # Register new node types here
        registry = Registry()
        registry.register_node_type(
            TextDataNodeType, lambda node_id: TextDataNode(node_id, TextDataNodeType)
        )
        registry.register_node_type(
            ImageDataNodeType, lambda node_id: ImageDataNode(node_id, ImageDataNodeType)
        )
        registry.register_node_type(
            TextToImageNodeType,
            lambda node_id: TextToImageNode(node_id, TextToImageNodeType),
        )

        graph = DataGraph()
        controller = GraphController(graph=graph)

        renderer = VueNodeRenderer()
        # Register new node renderables here!
        renderer.register(TextDataNode, TextDataNodeRenderable())
        renderer.register(ImageDataNode, ImageDataNodeRenderable())
        renderer.register(TextToImageNode, TextToImageNodeRenderable())

        session = cls(
            registry=registry,
            graph=graph,
            renderer=renderer,
            controller=controller,
        )

        # Try to restore from storage, otherwise build initial graph
        if not session.restore_from_storage():
            session._build_initial_graph()

        return session

    def _build_initial_graph(self) -> None:
        """Add starter nodes (a text node wired into a text-to-image node) to a fresh graph."""
        text_node = self.registry.create(TextDataNodeType, "t1")
        text2img_node = self.registry.create(TextToImageNodeType, "n2")

        text_node.x, text_node.y = 100, 100
        text2img_node.x, text2img_node.y = 420, 100

        self.graph.add_node(text_node)
        self.graph.add_node(text2img_node)

        edge = Connection(
            text_node, text_node.outputs["text"], text2img_node, text2img_node.inputs["text"]
        )
        self.graph.add_connection(edge)

    @property
    def creatable_node_types(self) -> list[dict[str, str]]:
        """Metadata (``kind`` value and display ``title``) for the node types the UI may create."""
        return [
            {
                "kind": TextDataNodeType.kind.value,
                "title": TextDataNodeType.display_name,
            },
            {
                "kind": ImageDataNodeType.kind.value,
                "title": ImageDataNodeType.display_name,
            },
            {
                "kind": TextToImageNodeType.kind.value,
                "title": TextToImageNodeType.display_name,
            },
        ]

    def save_to_storage(self) -> None:
        """Persist the current graph state (as produced by ``to_json``) to tab storage."""
        app.storage.tab["graph_json"] = self.to_json()

    def restore_from_storage(self) -> bool:
        """Try to load the graph from tab storage.

        Returns:
            True only if a stored graph was found and successfully parsed. A
            missing or corrupt entry returns False so the caller can fall back
            to building the starter graph.
        """
        json_str = app.storage.tab.get("graph_json")
        if not json_str:
            return False
        return self.load_from_json(json_str)

    def attach_canvas(self, canvas: VueFlowCanvas) -> None:
        """Attach a VueFlowCanvas to this session by creating its GraphRuntime."""
        self.runtime = GraphRuntime(graph=self.graph, canvas=canvas)

    def initial_vue_nodes(self) -> list[dict[str, Any]]:
        """Serialize current graph nodes to Vue Flow node dicts with UI metadata."""
        return [self.renderer.to_vue_node(n) for n in self.graph.nodes.values()]

    def initial_vue_edges(self) -> list[dict[str, Any]]:
        """Serialize current graph connections to Vue Flow edge dicts."""
        return [connection_to_vueflow_edge(c) for c in self.graph.connections]

    def to_json(self) -> str:
        """Export the graph (nodes with kind, position and UI values, plus edges) as JSON."""
        nodes_data = [
            {
                "id": node.node_id,
                "kind": node.node_type.kind.value,
                "x": node.x,
                "y": node.y,
                # Extract UI values using the renderer logic.
                "values": self.renderer.to_vue_node(node)["data"].get("values", {}),
            }
            for node in self.graph.nodes.values()
        ]
        edges_data = [
            {
                "source": c.start_node.node_id,
                "sourceHandle": c.start_port.name,
                "target": c.end_node.node_id,
                "targetHandle": c.end_port.name,
            }
            for c in self.graph.connections
        ]
        return json.dumps({"nodes": nodes_data, "edges": edges_data}, indent=2)

    def load_from_json(self, json_str: str) -> bool:
        """Replace the current graph with one imported from JSON.

        The expected shape is the one produced by ``to_json``. Node entries with
        an unknown or unregistered kind, a missing/non-string id, or a malformed
        layout are skipped. Edges referencing missing nodes or ports, or that
        ``Connection`` rejects with ``ValueError``, are also skipped.

        Returns:
            True if the JSON was parsed and the graph rebuilt. False if the input
            is not a JSON object; in that case the current graph is left untouched.
        """
        try:
            data = json.loads(json_str)
        except (TypeError, ValueError):
            # json.JSONDecodeError subclasses ValueError; TypeError covers non-str input.
            return False
        if not isinstance(data, dict):
            return False

        self.clear_graph()
        max_ui_id = 0

        # Recreate nodes
        for n_data in data.get("nodes", []):
            if not isinstance(n_data, dict):
                continue
            node_id = n_data.get("id")
            kind = self._node_kind_from_value(n_data.get("kind"))
            if kind is None or not isinstance(node_id, str) or not node_id:
                continue
            try:
                node = self.registry.create(kind, node_id)
            except KeyError:
                # A valid NodeKind that has no registered factory in this session.
                continue
            node.x = n_data.get("x", 0)
            node.y = n_data.get("y", 0)

            # Add first so the controller finds the node, then apply values via the
            # controller so its side effects (like updating port values) run.
            self.graph.add_node(node)
            values = n_data.get("values")
            if isinstance(values, dict):
                for field_name, value in values.items():
                    self.controller._on_node_field_changed(
                        {"id": node_id, "field": field_name, "value": value}
                    )

            if node_id.startswith(_UI_NODE_ID_PREFIX):
                try:
                    max_ui_id = max(max_ui_id, int(node_id[len(_UI_NODE_ID_PREFIX):]))
                except ValueError:
                    pass  # e.g. "user_node": prefix matches but no numeric suffix

        # Continue UI ids after the highest restored one so new nodes never collide
        # with loaded ones (clear_graph already reset the counter to 1).
        self._node_id_counter = itertools.count(max_ui_id + 1)

        # Recreate edges
        for e_data in data.get("edges", []):
            if not isinstance(e_data, dict):
                continue
            src_node = self.graph.nodes.get(e_data.get("source"))
            tgt_node = self.graph.nodes.get(e_data.get("target"))
            if src_node is None or tgt_node is None:
                continue
            src_port = src_node.outputs.get(e_data.get("sourceHandle"))
            tgt_port = tgt_node.inputs.get(e_data.get("targetHandle"))
            if src_port is None or tgt_port is None:
                continue
            try:
                self.graph.add_connection(Connection(src_node, src_port, tgt_node, tgt_port))
            except ValueError:
                continue  # connection rejected; keep loading the rest
        return True

    def next_ui_node_id(self) -> str:
        """Return a fresh id ("u1", "u2", ...) for a node created from the UI."""
        return f"{_UI_NODE_ID_PREFIX}{next(self._node_id_counter)}"

    def clear_graph(self) -> None:
        """Remove all nodes and connections and reset the UI node id counter."""
        self.graph.nodes.clear()
        self.graph.connections.clear()
        self._node_id_counter = itertools.count(1)

    def _node_kind_from_value(self, kind_value: str | None) -> NodeKind | None:
        """Resolve a NodeKind from its enum value or a registered display name; None if unknown."""
        if not kind_value:
            return None
        try:
            return NodeKind(kind_value)
        except (ValueError, TypeError):
            pass  # not an enum value; try display names below
        for kind, node_type in self.registry.node_types.items():
            if kind.value == kind_value or node_type.display_name == kind_value:
                return kind
        return None

    @staticmethod
    def _apply_position(node: Any, position: dict[str, Any]) -> None:
        """Set ``node.x``/``node.y`` from *position*; missing or non-numeric values are ignored."""
        for axis in ("x", "y"):
            raw = position.get(axis)
            if raw is None:
                continue
            try:
                setattr(node, axis, float(raw))
            except (TypeError, ValueError):
                pass  # best effort: keep the node's default coordinate

    def create_node(
        self, kind_value: str | None, position: dict[str, Any] | None = None
    ) -> dict[str, Any] | None:
        """Create a new node in the graph and return its Vue node dict.

        Returns None if *kind_value* does not resolve to a registered node kind.
        """
        kind = self._node_kind_from_value(kind_value)
        if kind is None:
            return None

        node_id = self.next_ui_node_id()
        try:
            node_obj = self.registry.create(kind, node_id)
        except KeyError:
            return None

        if position:
            self._apply_position(node_obj, position)

        self.graph.add_node(node_obj)
        return self.renderer.to_vue_node(node_obj)

    @staticmethod
    def _first_compatible_port(ports: Any, dtype_id: Any) -> str | None:
        """Name of the first port in the *ports* mapping whose dtype id equals *dtype_id*, else None."""
        return next(
            (name for name, port in ports.items() if port.schema.dtype.id == dtype_id),
            None,
        )

    def _connect_and_draw(
        self,
        canvas: VueFlowCanvas,
        edge_id: str,
        source: str,
        source_handle: str,
        target: str,
        target_handle: str,
    ) -> None:
        """Send a ``connect`` event to the controller and add the matching edge to *canvas*."""
        self.controller.handle_event({
            "type": "connect",
            "payload": {
                "source": source,
                "target": target,
                "sourceHandle": source_handle,
                "targetHandle": target_handle,
            },
        })
        canvas.add_edge({
            "id": edge_id,
            "source": source,
            "target": target,
            "sourceHandle": source_handle,
            "targetHandle": target_handle,
        })

    def _auto_connect_new_node(
        self, vue_node: dict[str, Any], pending_connection: dict[str, Any], canvas: VueFlowCanvas
    ) -> None:
        """Automatically connect a newly created node based on pending connection info.

        *pending_connection* carries ``nodeId``, ``handleId`` and ``handleType``
        ("source" when the user dragged from an output, "target" when from an
        input). The new node is wired through its first port whose dtype id
        matches; nothing happens if no such port exists.
        """
        if not pending_connection or not vue_node:
            return
        new_node_id = vue_node.get("id")
        if not new_node_id:
            return
        new_node_obj = self.graph.nodes.get(new_node_id)
        if not new_node_obj:
            return

        source_node_id = pending_connection.get("nodeId")
        source_handle_id = pending_connection.get("handleId")
        handle_type = pending_connection.get("handleType")
        if not source_node_id or not source_handle_id or not handle_type:
            return
        source_node = self.graph.nodes.get(source_node_id)
        if not source_node:
            return

        # Handle ids look like "<node_id>:<port_name>"; a bare port name also works.
        port_name = source_handle_id.split(":")[-1]

        if handle_type == "source":
            # Dragged from an output: feed the new node's first compatible input.
            existing_port = source_node.outputs.get(port_name)
            if not existing_port:
                return
            new_port_name = self._first_compatible_port(
                new_node_obj.inputs, existing_port.schema.dtype.id
            )
            if new_port_name is None:
                return
            self._connect_and_draw(
                canvas,
                edge_id=f"{source_node_id}-{source_handle_id}->{new_node_id}-{new_port_name}",
                source=source_node_id,
                source_handle=source_handle_id,
                target=new_node_id,
                target_handle=f"{new_node_id}:{new_port_name}",
            )
        elif handle_type == "target":
            # Dragged from an input: drive it from the new node's first compatible output.
            existing_port = source_node.inputs.get(port_name)
            if not existing_port:
                return
            new_port_name = self._first_compatible_port(
                new_node_obj.outputs, existing_port.schema.dtype.id
            )
            if new_port_name is None:
                return
            self._connect_and_draw(
                canvas,
                edge_id=f"{new_node_id}-{new_port_name}->{source_node_id}-{source_handle_id}",
                source=new_node_id,
                source_handle=f"{new_node_id}:{new_port_name}",
                target=source_node_id,
                target_handle=source_handle_id,
            )

    @staticmethod
    def _copied_value(value: Any) -> Any:
        """Return an independent copy of a PIL image; any other value is returned as-is."""
        return value.copy() if isinstance(value, Image.Image) else value

    @staticmethod
    def _copy_output(source_node: Any, node_obj: Any, port_name: str, convert: Any) -> None:
        """Copy output *port_name* from *source_node* to *node_obj*, transformed by *convert*.

        Does nothing when either node lacks the port or the source value is None.
        """
        source_port = source_node.outputs.get(port_name) if source_node.outputs else None
        if not source_port or source_port.value is None:
            return
        target_port = node_obj.outputs.get(port_name) if node_obj.outputs else None
        if target_port:
            target_port.value = convert(source_port.value)

    def _copy_node_content(self, source_node: Any, node_obj: Any) -> None:
        """Copy type-specific content (text, images) from *source_node* to its duplicate."""
        if isinstance(source_node, TextDataNode) and isinstance(node_obj, TextDataNode):
            self._copy_output(source_node, node_obj, "text", str)
        elif isinstance(source_node, ImageDataNode) and isinstance(node_obj, ImageDataNode):
            self._copy_output(source_node, node_obj, "image", self._copied_value)
        elif isinstance(source_node, TextToImageNode) and isinstance(node_obj, TextToImageNode):
            if source_node.image_src:
                node_obj.image_src = source_node.image_src
            if source_node.decoded_image is not None:
                node_obj.decoded_image = source_node.decoded_image.copy()
            self._copy_output(source_node, node_obj, "image", self._copied_value)

    def duplicate_node(
        self, source_node_id: str, position: dict[str, Any] | None = None
    ) -> dict[str, Any] | None:
        """Duplicate an existing node and return its Vue node dict.

        The copy gets a fresh UI id. It is placed at *position*, or 50px right and
        down from the source when no position is given. Non-None input values and
        type-specific content (text / image outputs) are copied; PIL images are
        copied rather than shared. Returns None if the source node does not exist
        or its kind is not registered.
        """
        source_node = self.graph.nodes.get(source_node_id)
        if source_node is None:
            return None

        node_id = self.next_ui_node_id()
        try:
            node_obj = self.registry.create(source_node.node_type.kind, node_id)
        except KeyError:
            return None

        if position:
            self._apply_position(node_obj, position)
        else:
            # Offset from source node so the copy is not hidden beneath it.
            node_obj.x = source_node.x + 50
            node_obj.y = source_node.y + 50

        # Copy input port values that are set; other outputs are left to be recomputed.
        for port_name, src_port in source_node.inputs.items():
            if port_name in node_obj.inputs and src_port.value is not None:
                node_obj.inputs[port_name].value = src_port.value

        self._copy_node_content(source_node, node_obj)

        self.graph.add_node(node_obj)
        return self.renderer.to_vue_node(node_obj)

    async def _on_execute_node(self, payload: dict[str, Any], canvas: VueFlowCanvas) -> bool:
        """Run the node named in *payload* and show the outcome; True when the graph should be saved."""
        node_id = payload.get("id")
        if not node_id:
            return False

        if self.runtime is None:
            self.attach_canvas(canvas)
        else:
            self.runtime.canvas = canvas

        progress = ui.notification(timeout=None)
        progress.message = "Generating"
        progress.spinner = True
        try:
            await self.runtime.execute_node(node_id)
        except Exception as ex:  # UI boundary: report any execution failure to the user
            message, kind = f"Error: {ex}", "negative"
        else:
            message, kind = "Done!", "positive"
        finally:
            progress.dismiss()
        # The progress notification is gone, so report the result in a toast that
        # stays visible long enough to read (previously it was dismissed instantly).
        ui.notify(message, type=kind)
        return True

    async def _on_reset_node(self, payload: dict[str, Any], canvas: VueFlowCanvas) -> bool:
        """Ask for confirmation, then reset the node; True when the graph should be saved."""
        node_id = payload.get("id")
        if not node_id:
            return False
        node = self.graph.nodes.get(node_id)
        if node is None:
            return False

        with ui.dialog() as dialog, ui.card():
            ui.label("Are you sure?")
            with ui.row():
                ui.button("Yes", on_click=lambda: dialog.submit("Yes"))
                ui.button("No", on_click=lambda: dialog.submit("No"))
        result = await dialog
        # Dismissing the dialog without a choice resolves to None; treat anything
        # other than an explicit "Yes" as a cancel.
        if result != "Yes":
            return False

        # only special case TextToImageNode for now
        if isinstance(node, TextToImageNode):
            node.reset_node()
            # sync cleared state to UI
            canvas.update_node_values(node_id, {"image": "", "error": None})
        ui.notify("Node has been reset!", type="positive")
        return True

    def _on_download_node(self, payload: dict[str, Any]) -> None:
        """Send a TextToImageNode's image to the browser as ``generated.png``."""
        node_id = payload.get("id")
        if not node_id:
            return
        node = self.graph.nodes.get(node_id)
        if not isinstance(node, TextToImageNode):
            return

        # Try to get the image from decoded_image first, then from image_src
        img = node.decoded_image
        if img is None and node.image_src:
            try:
                from . import utils
                img = utils.decode_image(node.image_src)
            except Exception as e:
                ui.notify(f"Failed to decode image: {e}", type="negative")
                return

        if img is None:
            ui.notify("No image available to download", type="warning")
            return

        buf = io.BytesIO()
        img.save(buf, format="PNG")
        ui.download(buf.getvalue(), filename="generated.png", media_type="image/png")
        ui.notify("Image downloaded!", type="positive")

    def _on_create_node(self, payload: dict[str, Any], canvas: VueFlowCanvas) -> bool:
        """Create a node (auto-connecting any pending drag); True if a node was added."""
        vue_node = self.create_node(payload.get("kind"), payload.get("position") or {})
        if vue_node is None:
            return False
        canvas.add_node(vue_node)

        # If there's a pending connection, auto-connect
        pending_connection = payload.get("pendingConnection")
        if pending_connection:
            self._auto_connect_new_node(vue_node, pending_connection, canvas)
        return True

    def _on_duplicate_node(self, payload: dict[str, Any], canvas: VueFlowCanvas) -> bool:
        """Duplicate the node named by ``sourceNodeId``; True if a copy was added."""
        source_node_id = payload.get("sourceNodeId")
        if not source_node_id:
            return False
        vue_node = self.duplicate_node(source_node_id, payload.get("position") or {})
        if vue_node is None:
            return False
        canvas.add_node(vue_node)
        return True

    async def handle_ui_event(
        self, event: dict[str, Any], canvas: VueFlowCanvas
    ) -> None:
        """Central handler for all Vue Flow events from this tab.

        The event types execute_node, reset_node, download_node, create_node and
        duplicate_node are handled by this session. Every other event is forwarded
        to the controller. The graph is saved to tab storage whenever the handled
        event may have changed it.
        """
        event_type = event.get("type")
        payload = event.get("payload") or {}

        if event_type == "execute_node":
            save_needed = await self._on_execute_node(payload, canvas)
        elif event_type == "reset_node":
            save_needed = await self._on_reset_node(payload, canvas)
        elif event_type == "download_node":
            self._on_download_node(payload)
            save_needed = False
        elif event_type == "create_node":
            save_needed = self._on_create_node(payload, canvas)
        elif event_type == "duplicate_node":
            save_needed = self._on_duplicate_node(payload, canvas)
        else:
            self.controller.handle_event(event)
            # Assume any controller action (move, connect, delete, field change) modifies state
            save_needed = True

        if save_needed:
            self.save_to_storage()