File size: 5,920 Bytes
d868fac
 
 
3025bb3
d868fac
 
 
 
 
 
 
 
 
 
 
 
3025bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d868fac
3025bb3
 
 
 
 
 
 
 
 
 
d868fac
 
 
 
 
 
 
 
 
3025bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d868fac
691f45a
3025bb3
 
 
 
 
 
 
 
 
691f45a
3025bb3
 
 
 
 
 
 
 
 
 
691f45a
 
 
 
d868fac
3025bb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d868fac
 
 
 
 
3025bb3
 
 
 
691f45a
3025bb3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

from dataflow.graph import DataGraph
from dataflow.nodes_base import NodeInstance
from dataflow.ui.vueflow_canvas import VueFlowCanvas
from .text_to_image import TextToImageNode


@dataclass(slots=True)
class GraphRuntime:
    graph: DataGraph
    canvas: VueFlowCanvas

    def _get_execution_chain(self, root: NodeInstance) -> list[NodeInstance]:
        """Return all upstream nodes that belong to the chain of root.

        Order is upstream first, root last. This is only for UI purposes
        (spinners, progress, syncing), not to force execution.
        """
        result: list[NodeInstance] = []
        visited: set[str] = set()
        connections = getattr(self.graph, "connections", []) or []

        def visit(node: NodeInstance) -> None:
            node_id = getattr(node, "node_id", None)
            if node_id is None or node_id in visited:
                return
            visited.add(node_id)

            for conn in connections:
                try:
                    if conn.end_node is node:
                        visit(conn.start_node)
                except AttributeError:
                    continue

            result.append(node)

        visit(root)
        return result

    async def execute_node(self, node: NodeInstance | str) -> None:
        """Execute a node and keep the canvas in sync with its state.

        Rules:
          - Clicked node is always reset and re run.
          - Upstream nodes are never reset here and may reuse cached data.
          - Which nodes actually execute is decided by DataGraph and node logic.
          - As soon as a node finishes, it is synced to the UI.
        """
        import asyncio

        if isinstance(node, str):
            node_id = node
            node_obj = self.graph.nodes.get(node_id)
            if node_obj is None:
                return
        else:
            node_obj = node
            node_id = node.node_id

        execution_chain = self._get_execution_chain(node_obj)

        # show spinner on all nodes in the chain
        for n in execution_chain:
            nid = getattr(n, "node_id", None)
            if nid:
                self.canvas.set_node_processing(nid, True)

        # progress polling for nodes that expose progress_value
        stop_progress = False
        progress_tasks: list[asyncio.Task] = []

        async def progress_updater(n: NodeInstance, nid: str) -> None:
            last_value: Any = None
            while not stop_progress:
                await asyncio.sleep(0.1)
                if not hasattr(n, "progress_value"):
                    continue
                current = getattr(n, "progress_value", None)
                message = getattr(n, "progress_message", None)
                if current is None:
                    continue
                if current != last_value:
                    self.canvas.update_node_progress(nid, current, message)
                    last_value = current

        for n in execution_chain:
            nid = getattr(n, "node_id", None)
            if nid and hasattr(n, "progress_value"):
                progress_tasks.append(asyncio.create_task(progress_updater(n, nid)))

        # callback from DataGraph after each node is executed
        async def on_node_executed(executed_node: NodeInstance) -> None:
            # Only nodes that actually ran will call this.
            await self._sync_node_to_ui(executed_node)

        # save previous callback so we can restore it
        previous_cb = getattr(self.graph, "_on_node_executed", None)

        try:
            print(f"Runtime: Executing {node_id}...")

            # clicked node is always reset, upstream nodes are not
            if hasattr(node_obj, "reset_node"):
                node_obj.reset_node()

            # register our per node callback
            self.graph.set_on_node_executed(on_node_executed)

            # let DataGraph drive which nodes actually execute
            await self.graph.execute(node_obj)

            # one more sync for all nodes in the chain, in case some did not run
            for n in execution_chain:
                await self._sync_node_to_ui(n)

            # nice "complete" flash for the clicked node if it has progress
            if hasattr(node_obj, "progress_value"):
                self.canvas.update_node_progress(node_id, 1.0, "Complete")
                await asyncio.sleep(0.3)

        except Exception as e:
            print(f"Runtime execution failed: {e}")
            import traceback
            traceback.print_exc()
        finally:
            # restore previous graph callback
            self.graph.set_on_node_executed(previous_cb)

            # stop progress updaters
            stop_progress = True
            for t in progress_tasks:
                t.cancel()
                try:
                    await t
                except asyncio.CancelledError:
                    pass

            # hide spinner on all nodes in the chain
            for n in execution_chain:
                nid = getattr(n, "node_id", None)
                if nid:
                    self.canvas.set_node_processing(nid, False)

            # reset progress on the clicked node
            if hasattr(node_obj, "progress_value"):
                self.canvas.update_node_progress(node_id, 0.0, None)

    async def _sync_node_to_ui(self, node: NodeInstance) -> None:
        """Push relevant node state back to the Vue nodes."""
        if isinstance(node, TextToImageNode):
            image_src = "" if node.image_src is None else str(node.image_src)
            values: dict[str, Any] = {
                "image": image_src,
                "error": node.error or None,
            }
            self.canvas.update_node_values(node.node_id, values)
        # add other node types here as needed