| | from __future__ import annotations |
| |
|
| | import io |
| | import itertools |
| | import json |
| | from dataclasses import dataclass, field |
| | from typing import Any |
| |
|
| | from PIL import Image |
| | from nicegui import app, ui |
| |
|
| | from dataflow.codecs import connection_to_vueflow_edge |
| | from dataflow.connection import Connection |
| | from dataflow.enums import NodeKind |
| | from dataflow.graph import DataGraph |
| | from dataflow.registry import Registry |
| | from dataflow.ui.vueflow_canvas import VueFlowCanvas |
| | from .controller import GraphController |
| | from .image_data import ImageDataNodeType, ImageDataNode, ImageDataNodeRenderable |
| | from .runtime import GraphRuntime |
| | from .text_data import TextDataNodeType, TextDataNode, TextDataNodeRenderable |
| | from .text_to_image import TextToImageNodeType, TextToImageNode, TextToImageNodeRenderable |
| | from .vue_nodes import VueNodeRenderer |
| |
|
| |
|
| | @dataclass(slots=True) |
| | class GraphSession: |
| | """Per browser tab graph session.""" |
| |
|
| | registry: Registry |
| | graph: DataGraph |
| | renderer: VueNodeRenderer |
| | controller: GraphController |
| | runtime: GraphRuntime | None = None |
| | _node_id_counter: itertools.count = field( |
| | default_factory=lambda: itertools.count(1) |
| | ) |
| |
|
| | @classmethod |
| | def create_default(cls) -> "GraphSession": |
| | """Create a new session with default node types and starter graph.""" |
| | |
| | registry = Registry() |
| | registry.register_node_type( |
| | TextDataNodeType, lambda node_id: TextDataNode(node_id, TextDataNodeType) |
| | ) |
| | registry.register_node_type( |
| | ImageDataNodeType, lambda node_id: ImageDataNode(node_id, ImageDataNodeType) |
| | ) |
| | registry.register_node_type( |
| | TextToImageNodeType, |
| | lambda node_id: TextToImageNode(node_id, TextToImageNodeType), |
| | ) |
| |
|
| | graph = DataGraph() |
| | controller = GraphController(graph=graph) |
| |
|
| | renderer = VueNodeRenderer() |
| |
|
| | |
| | renderer.register(TextDataNode, TextDataNodeRenderable()) |
| | renderer.register(ImageDataNode, ImageDataNodeRenderable()) |
| | renderer.register(TextToImageNode, TextToImageNodeRenderable()) |
| |
|
| | session = cls( |
| | registry=registry, |
| | graph=graph, |
| | renderer=renderer, |
| | controller=controller, |
| | ) |
| |
|
| | |
| | if not session.restore_from_storage(): |
| | session._build_initial_graph() |
| |
|
| | return session |
| |
|
| | def _build_initial_graph(self) -> None: |
| | """Add starter nodes to a fresh graph.""" |
| | text_node = self.registry.create(TextDataNodeType, "t1") |
| | text2img_node = self.registry.create(TextToImageNodeType, "n2") |
| |
|
| | text_node.x, text_node.y = 100, 100 |
| | text2img_node.x, text2img_node.y = 420, 100 |
| |
|
| | self.graph.add_node(text_node) |
| | self.graph.add_node(text2img_node) |
| |
|
| | edge = Connection(text_node, text_node.outputs["text"], |
| | text2img_node, text2img_node.inputs["text"]) |
| |
|
| | self.graph.add_connection(edge) |
| |
|
| | @property |
| | def creatable_node_types(self) -> list[dict[str, str]]: |
| | """Metadata for the node types the UI may create.""" |
| | return [ |
| | { |
| | "kind": TextDataNodeType.kind.value, |
| | "title": TextDataNodeType.display_name, |
| | }, |
| | { |
| | "kind": ImageDataNodeType.kind.value, |
| | "title": ImageDataNodeType.display_name, |
| | }, |
| | { |
| | "kind": TextToImageNodeType.kind.value, |
| | "title": TextToImageNodeType.display_name, |
| | }, |
| | ] |
| |
|
| | def save_to_storage(self) -> None: |
| | """Persist current graph state to tab storage.""" |
| | json_str = self.to_json() |
| | app.storage.tab["graph_json"] = json_str |
| |
|
| | def restore_from_storage(self) -> bool: |
| | """Try to load graph from tab storage. Returns True if successful.""" |
| | json_str = app.storage.tab.get("graph_json") |
| | if json_str: |
| | self.load_from_json(json_str) |
| | return True |
| | return False |
| |
|
| | def attach_canvas(self, canvas: VueFlowCanvas) -> None: |
| | """Attach a VueFlowCanvas to this session.""" |
| | self.runtime = GraphRuntime(graph=self.graph, canvas=canvas) |
| |
|
| | def initial_vue_nodes(self) -> list[dict[str, Any]]: |
| | """Serialize current graph nodes to Vue Flow node dicts with UI metadata.""" |
| | return [self.renderer.to_vue_node(n) for n in self.graph.nodes.values()] |
| |
|
| | def initial_vue_edges(self) -> list[dict[str, Any]]: |
| | """Serialize current graph connections to Vue Flow edge dicts.""" |
| | return [connection_to_vueflow_edge(c) for c in self.graph.connections] |
| |
|
| | def to_json(self) -> str: |
| | """Export graph to JSON.""" |
| | nodes_data = [] |
| | for node in self.graph.nodes.values(): |
| | |
| | vue_data = self.renderer.to_vue_node(node)["data"].get("values", {}) |
| |
|
| | nodes_data.append({ |
| | "id": node.node_id, |
| | "kind": node.node_type.kind.value, |
| | "x": node.x, |
| | "y": node.y, |
| | "values": vue_data |
| | }) |
| |
|
| | edges_data = [] |
| | for c in self.graph.connections: |
| | edges_data.append({ |
| | "source": c.start_node.node_id, |
| | "sourceHandle": c.start_port.name, |
| | "target": c.end_node.node_id, |
| | "targetHandle": c.end_port.name |
| | }) |
| |
|
| | return json.dumps({"nodes": nodes_data, "edges": edges_data}, indent=2) |
| |
|
| | def load_from_json(self, json_str: str) -> None: |
| | """Import graph from JSON.""" |
| | try: |
| | data = json.loads(json_str) |
| | except Exception: |
| | return |
| |
|
| | self.clear_graph() |
| |
|
| | |
| | for n_data in data.get("nodes", []): |
| | kind_val = n_data.get("kind") |
| | node_id = n_data.get("id") |
| |
|
| | kind = self._node_kind_from_value(kind_val) |
| | if not kind or not node_id: |
| | continue |
| |
|
| | node = self.registry.create(kind, node_id) |
| | node.x = n_data.get("x", 0) |
| | node.y = n_data.get("y", 0) |
| |
|
| | |
| | values = n_data.get("values", {}) |
| | self.graph.add_node(node) |
| |
|
| | for field, value in values.items(): |
| | self.controller._on_node_field_changed({"id": node_id, "field": field, "value": value}) |
| |
|
| | |
| | if node_id.startswith("u"): |
| | try: |
| | num = int(node_id[1:]) |
| | if num >= next(self._node_id_counter) - 1: |
| | self._node_id_counter = itertools.count(num + 1) |
| | except: |
| | pass |
| |
|
| | |
| | from dataflow.connection import Connection |
| | for e_data in data.get("edges", []): |
| | src_node = self.graph.nodes.get(e_data.get("source")) |
| | tgt_node = self.graph.nodes.get(e_data.get("target")) |
| | if src_node and tgt_node: |
| | src_port = src_node.outputs.get(e_data.get("sourceHandle")) |
| | tgt_port = tgt_node.inputs.get(e_data.get("targetHandle")) |
| |
|
| | if src_port and tgt_port: |
| | try: |
| | conn = Connection(src_node, src_port, tgt_node, tgt_port) |
| | self.graph.add_connection(conn) |
| | except ValueError: |
| | pass |
| |
|
| | def next_ui_node_id(self) -> str: |
| | return f"u{next(self._node_id_counter)}" |
| |
|
| | def clear_graph(self) -> None: |
| | """Remove all nodes and connections.""" |
| | self.graph.nodes.clear() |
| | self.graph.connections.clear() |
| | self._node_id_counter = itertools.count(1) |
| |
|
| | def _node_kind_from_value(self, kind_value: str | None) -> NodeKind | None: |
| | if not kind_value: |
| | return None |
| |
|
| | try: |
| | return NodeKind(kind_value) |
| | except Exception: |
| | pass |
| |
|
| | for kind, node_type in self.registry.node_types.items(): |
| | if kind.value == kind_value or node_type.display_name == kind_value: |
| | return kind |
| |
|
| | return None |
| |
|
| | def create_node( |
| | self, kind_value: str | None, position: dict[str, Any] | None = None |
| | ) -> dict[str, Any] | None: |
| | """Create a new node in the graph and return its Vue node dict.""" |
| | kind = self._node_kind_from_value(kind_value) |
| | if kind is None: |
| | return None |
| |
|
| | node_id = self.next_ui_node_id() |
| | try: |
| | node_obj = self.registry.create(kind, node_id) |
| | except KeyError: |
| | return None |
| |
|
| | if position: |
| | x = position.get("x") |
| | y = position.get("y") |
| | if x is not None: |
| | try: |
| | node_obj.x = float(x) |
| | except (TypeError, ValueError): |
| | pass |
| | if y is not None: |
| | try: |
| | node_obj.y = float(y) |
| | except (TypeError, ValueError): |
| | pass |
| |
|
| | self.graph.add_node(node_obj) |
| | return self.renderer.to_vue_node(node_obj) |
| |
|
| | def _auto_connect_new_node( |
| | self, vue_node: dict[str, Any], pending_connection: dict[str, Any], canvas: VueFlowCanvas |
| | ) -> None: |
| | """Automatically connect a newly created node based on pending connection info.""" |
| | if not pending_connection or not vue_node: |
| | return |
| |
|
| | new_node_id = vue_node.get("id") |
| | if not new_node_id: |
| | return |
| |
|
| | new_node_obj = self.graph.nodes.get(new_node_id) |
| | if not new_node_obj: |
| | return |
| |
|
| | |
| | source_node_id = pending_connection.get("nodeId") |
| | source_handle_id = pending_connection.get("handleId") |
| | handle_type = pending_connection.get("handleType") |
| |
|
| | if not source_node_id or not source_handle_id or not handle_type: |
| | return |
| |
|
| | source_node = self.graph.nodes.get(source_node_id) |
| | if not source_node: |
| | return |
| |
|
| | |
| | source_port_name = source_handle_id.split(":")[-1] if ":" in source_handle_id else source_handle_id |
| |
|
| | |
| | if handle_type == "source": |
| | |
| | source_port = source_node.outputs.get(source_port_name) |
| | if not source_port: |
| | return |
| |
|
| | |
| | target_port = None |
| | target_port_name = None |
| | for port_name, port in new_node_obj.inputs.items(): |
| | if port.schema.dtype.id == source_port.schema.dtype.id: |
| | target_port = port |
| | target_port_name = port_name |
| | break |
| |
|
| | if target_port: |
| | |
| | connection_params = { |
| | "source": source_node_id, |
| | "target": new_node_id, |
| | "sourceHandle": source_handle_id, |
| | "targetHandle": f"{new_node_id}:{target_port_name}" |
| | } |
| |
|
| | |
| | self.controller.handle_event({ |
| | "type": "connect", |
| | "payload": connection_params |
| | }) |
| |
|
| | |
| | edge_id = f"{source_node_id}-{source_handle_id}->{new_node_id}-{target_port_name}" |
| | canvas.add_edge({ |
| | "id": edge_id, |
| | "source": source_node_id, |
| | "target": new_node_id, |
| | "sourceHandle": source_handle_id, |
| | "targetHandle": f"{new_node_id}:{target_port_name}" |
| | }) |
| |
|
| | elif handle_type == "target": |
| | |
| | target_port = source_node.inputs.get(source_port_name) |
| | if not target_port: |
| | return |
| |
|
| | |
| | source_port = None |
| | source_port_name_new = None |
| | for port_name, port in new_node_obj.outputs.items(): |
| | if port.schema.dtype.id == target_port.schema.dtype.id: |
| | source_port = port |
| | source_port_name_new = port_name |
| | break |
| |
|
| | if source_port: |
| | |
| | connection_params = { |
| | "source": new_node_id, |
| | "target": source_node_id, |
| | "sourceHandle": f"{new_node_id}:{source_port_name_new}", |
| | "targetHandle": source_handle_id |
| | } |
| |
|
| | |
| | self.controller.handle_event({ |
| | "type": "connect", |
| | "payload": connection_params |
| | }) |
| |
|
| | |
| | edge_id = f"{new_node_id}-{source_port_name_new}->{source_node_id}-{source_handle_id}" |
| | canvas.add_edge({ |
| | "id": edge_id, |
| | "source": new_node_id, |
| | "target": source_node_id, |
| | "sourceHandle": f"{new_node_id}:{source_port_name_new}", |
| | "targetHandle": source_handle_id |
| | }) |
| |
|
| | def duplicate_node( |
| | self, source_node_id: str, position: dict[str, Any] | None = None |
| | ) -> dict[str, Any] | None: |
| | """Duplicate an existing node and return its Vue node dict.""" |
| | |
| | source_node = self.graph.nodes.get(source_node_id) |
| | if source_node is None: |
| | return None |
| |
|
| | |
| | node_id = self.next_ui_node_id() |
| | try: |
| | node_obj = self.registry.create(source_node.node_type.kind, node_id) |
| | except KeyError: |
| | return None |
| |
|
| | |
| | if position: |
| | x = position.get("x") |
| | y = position.get("y") |
| | if x is not None: |
| | try: |
| | node_obj.x = float(x) |
| | except (TypeError, ValueError): |
| | pass |
| | if y is not None: |
| | try: |
| | node_obj.y = float(y) |
| | except (TypeError, ValueError): |
| | pass |
| | else: |
| | |
| | node_obj.x = source_node.x + 50 |
| | node_obj.y = source_node.y + 50 |
| |
|
| | |
| | for port_name, source_port in source_node.inputs.items(): |
| | if port_name in node_obj.inputs: |
| | |
| | if source_port.value is not None: |
| | node_obj.inputs[port_name].value = source_port.value |
| |
|
| | |
| |
|
| | |
| | if isinstance(source_node, TextDataNode) and isinstance(node_obj, TextDataNode): |
| | source_port = source_node.outputs.get("text") if source_node.outputs else None |
| | if source_port and source_port.value is not None: |
| | target_port = node_obj.outputs.get("text") if node_obj.outputs else None |
| | if target_port: |
| | target_port.value = str(source_port.value) |
| |
|
| | |
| | elif isinstance(source_node, ImageDataNode) and isinstance(node_obj, ImageDataNode): |
| | source_port = source_node.outputs.get("image") if source_node.outputs else None |
| | if source_port and source_port.value is not None: |
| | target_port = node_obj.outputs.get("image") if node_obj.outputs else None |
| | if target_port: |
| | |
| | if isinstance(source_port.value, Image.Image): |
| | target_port.value = source_port.value.copy() |
| | else: |
| | target_port.value = source_port.value |
| |
|
| | |
| | elif isinstance(source_node, TextToImageNode) and isinstance(node_obj, TextToImageNode): |
| | |
| | if source_node.image_src: |
| | node_obj.image_src = source_node.image_src |
| | |
| | if source_node.decoded_image is not None: |
| | node_obj.decoded_image = source_node.decoded_image.copy() |
| | |
| | source_port = source_node.outputs.get("image") if source_node.outputs else None |
| | if source_port and source_port.value is not None: |
| | target_port = node_obj.outputs.get("image") if node_obj.outputs else None |
| | if target_port: |
| | if isinstance(source_port.value, Image.Image): |
| | target_port.value = source_port.value.copy() |
| | else: |
| | target_port.value = source_port.value |
| |
|
| | self.graph.add_node(node_obj) |
| | return self.renderer.to_vue_node(node_obj) |
| |
|
| | async def handle_ui_event( |
| | self, event: dict[str, Any], canvas: VueFlowCanvas |
| | ) -> None: |
| | """Central handler for all Vue Flow events from this tab.""" |
| | event_type = event.get("type") |
| | payload = event.get("payload") |
| |
|
| | save_needed = False |
| |
|
| | if event_type == "execute_node": |
| | payload_dict = payload or {} |
| | node_id = payload_dict.get("id") |
| | if not node_id: |
| | return |
| |
|
| | if self.runtime is None: |
| | self.attach_canvas(canvas) |
| | else: |
| | self.runtime.canvas = canvas |
| |
|
| | n = ui.notification(timeout=None) |
| | n.message = f"Generating" |
| | n.spinner = True |
| |
|
| | try: |
| | await self.runtime.execute_node(node_id) |
| |
|
| | n.message = "Done!" |
| | n.type = "positive" |
| | n.spinner = False |
| | except Exception as ex: |
| | n.message = f"Error: {ex}" |
| | n.type = "negative" |
| | n.spinner = False |
| |
|
| | n.dismiss() |
| |
|
| | save_needed = True |
| |
|
| | elif event_type == "reset_node": |
| | payload_dict = payload or {} |
| | node_id = payload_dict.get("id") |
| | if not node_id: |
| | return |
| |
|
| | node = self.graph.nodes.get(node_id) |
| | if node is None: |
| | return |
| |
|
| | with ui.dialog() as dialog, ui.card(): |
| | ui.label("Are you sure?") |
| | with ui.row(): |
| | ui.button("Yes", on_click=lambda: dialog.submit("Yes")) |
| | ui.button("No", on_click=lambda: dialog.submit("No")) |
| | result = await dialog |
| |
|
| | if result == "No": |
| | return |
| |
|
| | |
| | if isinstance(node, TextToImageNode): |
| | node.reset_node() |
| |
|
| | |
| | canvas.update_node_values( |
| | node_id, { |
| | "image": "", |
| | "error": None, |
| | }) |
| |
|
| | save_needed = True |
| | ui.notify(f"Node has been reset!", type="positive") |
| |
|
| | elif event_type == "download_node": |
| | payload_dict = payload or {} |
| | node_id = payload_dict.get("id") |
| | if not node_id: |
| | return |
| |
|
| | node = self.graph.nodes.get(node_id) |
| | if node is None: |
| | return |
| |
|
| | if isinstance(node, TextToImageNode): |
| | |
| | img = node.decoded_image |
| | if img is None and node.image_src: |
| | try: |
| | from . import utils |
| | img = utils.decode_image(node.image_src) |
| | except Exception as e: |
| | ui.notify(f"Failed to decode image: {e}", type="negative") |
| | return |
| | |
| | if img is None: |
| | ui.notify("No image available to download", type="warning") |
| | return |
| | |
| | buf = io.BytesIO() |
| | img.save(buf, format="PNG") |
| | buf.seek(0) |
| |
|
| | png_bytes = buf.getvalue() |
| | ui.download(png_bytes, filename="generated.png", media_type="image/png") |
| | ui.notify("Image downloaded!", type="positive") |
| |
|
| | elif event_type == "create_node": |
| | payload_dict = payload or {} |
| | kind_value = payload_dict.get("kind") |
| | position = payload_dict.get("position") or {} |
| | pending_connection = payload_dict.get("pendingConnection") |
| |
|
| | vue_node = self.create_node(kind_value, position) |
| | if vue_node is not None: |
| | canvas.add_node(vue_node) |
| | save_needed = True |
| |
|
| | |
| | if pending_connection: |
| | self._auto_connect_new_node(vue_node, pending_connection, canvas) |
| |
|
| |
|
| | elif event_type == "duplicate_node": |
| | payload_dict = payload or {} |
| | source_node_id = payload_dict.get("sourceNodeId") |
| | position = payload_dict.get("position") or {} |
| | if source_node_id: |
| | vue_node = self.duplicate_node(source_node_id, position) |
| | if vue_node is not None: |
| | canvas.add_node(vue_node) |
| | save_needed = True |
| |
|
| | else: |
| | self.controller.handle_event(event) |
| | |
| | save_needed = True |
| |
|
| | if save_needed: |
| | self.save_to_storage() |
| |
|