| from __future__ import annotations |
|
|
| from typing import TypedDict, Dict, Optional, Tuple |
| from typing_extensions import override |
| from PIL import Image |
| from enum import Enum |
| from abc import ABC |
| from tqdm import tqdm |
| from typing import TYPE_CHECKING |
| if TYPE_CHECKING: |
| from comfy_execution.graph import DynamicPrompt |
| from protocol import BinaryEventTypes |
| from comfy_api import feature_flags |
|
|
# Preview payload that may accompany a progress update: a string, a PIL image
# and an optional integer.
# NOTE(review): presumably (image format, image, max size) — confirm against
# the code that builds these tuples.
PreviewImageTuple = Tuple[str, Image.Image, Optional[int]]
|
|
class NodeState(Enum):
    """Lifecycle stage recorded for a node while its progress is tracked.

    Every member's value is a lowercase string, so ``member.value`` can be
    placed directly into a message payload.
    """

    # State given to an entry when it is first created.
    Pending = "pending"
    # Set when progress tracking starts or a progress update arrives.
    Running = "running"
    # Set when progress tracking for the node is finished.
    Finished = "finished"
    # NOTE(review): not assigned anywhere in this module; presumably set by
    # the executor when a node fails — confirm against callers.
    Error = "error"
|
|
|
|
class NodeProgressState(TypedDict):
    """Progress record kept for a single node.

    Keys:
        state: the node's current ``NodeState``.
        value: progress made so far.
        max: the value at which the node counts as complete.
    """

    state: NodeState
    value: float
    max: float
|
|
|
|
class ProgressHandler(ABC):
    """
    Base class for objects that are notified about node progress.

    Every hook here is a no-op; subclasses override the ones they need in
    order to present progress in their own way. A handler carries a ``name``
    and an ``enabled`` flag that can be toggled with ``enable()`` and
    ``disable()``.
    """

    def __init__(self, name: str):
        # Handlers start out enabled.
        self.enabled = True
        self.name = name

    def set_registry(self, registry: "ProgressRegistry"):
        """Receive the registry this handler is attached to (ignored by default)"""

    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook for a node that has started processing"""

    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Hook for a progress update of a node, optionally with a preview image"""

    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Hook for a node that has finished processing"""

    def reset(self):
        """Hook for a reset of the progress registry"""

    def enable(self):
        """Turn this handler on"""
        self.enabled = True

    def disable(self):
        """Turn this handler off"""
        self.enabled = False
|
|
|
|
class CLIProgressHandler(ProgressHandler):
    """
    Progress handler that renders one tqdm bar per node in the terminal.
    """

    def __init__(self):
        super().__init__("cli")
        # Bars that are currently open, keyed by node id.
        self.progress_bars: Dict[str, tqdm] = {}

    def _open_bar(self, node_id: str, total: float) -> tqdm:
        """Create a bar for ``node_id``, store it in ``progress_bars`` and return it"""
        bar = tqdm(
            total=total,
            desc=f"Node {node_id}",
            unit="steps",
            leave=True,
            # Computed before the new bar is stored, so it counts only the
            # bars that were already open.
            position=len(self.progress_bars),
        )
        self.progress_bars[node_id] = bar
        return bar

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Open a bar sized to ``state["max"]`` unless the node already has one"""
        if node_id in self.progress_bars:
            return
        self._open_bar(node_id, state["max"])

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Move the node's bar forward to ``value``, opening it on first use"""
        bar = self.progress_bars.get(node_id)
        if bar is None:
            # No bar yet: open one and advance it straight to the reported value.
            self._open_bar(node_id, max_value).update(value)
            return

        # Keep the bar's total in step with the reported maximum.
        if bar.total != max_value:
            bar.total = max_value

        # tqdm advances by increments, so apply only the distance still to go;
        # a value at or behind the bar's current position leaves it untouched.
        delta = value - bar.n
        if delta > 0:
            bar.update(delta)

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Fill the node's bar up to ``state["max"]``, close it and forget it"""
        bar = self.progress_bars.get(node_id)
        if bar is None:
            return

        shortfall = state["max"] - bar.n
        if shortfall > 0:
            bar.update(shortfall)
        bar.close()
        del self.progress_bars[node_id]

    @override
    def reset(self):
        """Close every open bar and drop them all"""
        for open_bar in self.progress_bars.values():
            open_bar.close()
        self.progress_bars.clear()
|
|
|
|
class WebUIProgressHandler(ProgressHandler):
    """
    Handler that sends progress updates to the WebUI via WebSockets.

    Every hook sends a "progress_state" message describing all nodes of the
    attached registry that are no longer pending. ``update_handler`` can also
    forward a preview image together with metadata about the node.

    Nothing is sent while no registry has been attached through
    ``set_registry`` or while ``server_instance`` is None.
    """

    def __init__(self, server_instance):
        super().__init__("webui")
        self.server_instance = server_instance
        # Filled in by set_registry(). Starting at None lets the hooks skip
        # their work instead of raising AttributeError when they run before a
        # registry has been attached.
        self.registry: ProgressRegistry | None = None

    @override
    def set_registry(self, registry: "ProgressRegistry"):
        """Remember the registry whose node states this handler reports"""
        self.registry = registry

    def _node_ids(self, node_id: str) -> dict:
        """Look up the display, parent and real ids of a node.

        Must only be called while a registry is attached; the ids come from
        the registry's ``dynprompt``.
        """
        dynprompt = self.registry.dynprompt
        return {
            "display_node_id": dynprompt.get_display_node_id(node_id),
            "parent_node_id": dynprompt.get_parent_node_id(node_id),
            "real_node_id": dynprompt.get_real_node_id(node_id),
        }

    def _send_progress_state(self, prompt_id: str, nodes: Dict[str, NodeProgressState]):
        """Send the current progress state to the client

        Args:
            prompt_id: id of the prompt the nodes belong to.
            nodes: node id -> progress record; pending nodes are left out of
                the message.
        """
        # The node ids are resolved through the registry, so both the server
        # and the registry are required.
        if self.server_instance is None or self.registry is None:
            return

        active_nodes = {
            node_id: {
                "value": state["value"],
                "max": state["max"],
                "state": state["state"].value,
                "node_id": node_id,
                "prompt_id": prompt_id,
                **self._node_ids(node_id),
            }
            for node_id, state in nodes.items()
            if state["state"] != NodeState.Pending
        }

        self.server_instance.send_sync(
            "progress_state",
            {"prompt_id": prompt_id, "nodes": active_nodes},
            self.server_instance.client_id,
        )

    def _send_registry_state(self, prompt_id: str):
        """Send the state of every registry node, if a registry is attached"""
        if self.registry is not None:
            self._send_progress_state(prompt_id, self.registry.nodes)

    @override
    def start_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Report the state of all registry nodes when a node starts"""
        self._send_registry_state(prompt_id)

    @override
    def update_handler(
        self,
        node_id: str,
        value: float,
        max_value: float,
        state: NodeProgressState,
        prompt_id: str,
        image: PreviewImageTuple | None = None,
    ):
        """Report the state of all registry nodes and forward a preview image.

        The preview is sent only when ``image`` is given and the connected
        client advertises the "supports_preview_metadata" feature.
        """
        server = self.server_instance
        # The preview metadata needs the registry and the send needs the
        # server, so there is nothing to do without either of them.
        if self.registry is None or server is None:
            return

        self._send_progress_state(prompt_id, self.registry.nodes)

        if not image:
            return

        if not feature_flags.supports_feature(
            server.sockets_metadata,
            server.client_id,
            "supports_preview_metadata",
        ):
            return

        metadata = {
            "node_id": node_id,
            "prompt_id": prompt_id,
            **self._node_ids(node_id),
        }
        server.send_sync(
            BinaryEventTypes.PREVIEW_IMAGE_WITH_METADATA,
            (image, metadata),
            server.client_id,
        )

    @override
    def finish_handler(self, node_id: str, state: NodeProgressState, prompt_id: str):
        """Report the state of all registry nodes when a node finishes"""
        self._send_registry_state(prompt_id)
|
|
class ProgressRegistry:
    """
    Keeps a progress record per node for one prompt and forwards every change
    to the enabled handlers that have been registered.
    """

    def __init__(self, prompt_id: str, dynprompt: "DynamicPrompt"):
        self.prompt_id = prompt_id
        self.dynprompt = dynprompt
        # node id -> progress record
        self.nodes: Dict[str, NodeProgressState] = {}
        # handler name -> handler
        self.handlers: Dict[str, ProgressHandler] = {}

    def _enabled_handlers(self):
        """Yield the registered handlers whose ``enabled`` flag is set.

        The flag is read lazily, right before each handler is yielded.
        """
        for handler in self.handlers.values():
            if handler.enabled:
                yield handler

    def register_handler(self, handler: ProgressHandler) -> None:
        """Add a handler under its name, replacing any handler with that name"""
        self.handlers[handler.name] = handler

    def unregister_handler(self, handler_name: str) -> None:
        """Reset and remove the named handler; unknown names are ignored"""
        handler = self.handlers.get(handler_name)
        if handler is None:
            return
        # Let the handler clean up before it is dropped.
        handler.reset()
        del self.handlers[handler_name]

    def enable_handler(self, handler_name: str) -> None:
        """Enable the named handler; unknown names are ignored"""
        handler = self.handlers.get(handler_name)
        if handler is not None:
            handler.enable()

    def disable_handler(self, handler_name: str) -> None:
        """Disable the named handler; unknown names are ignored"""
        handler = self.handlers.get(handler_name)
        if handler is not None:
            handler.disable()

    def ensure_entry(self, node_id: str) -> NodeProgressState:
        """Return the node's record, creating a pending 0-of-1 record if missing"""
        entry = self.nodes.get(node_id)
        if entry is None:
            entry = NodeProgressState(state=NodeState.Pending, value=0, max=1)
            self.nodes[node_id] = entry
        return entry

    def start_progress(self, node_id: str) -> None:
        """Mark a node as running at 0 of 1 and notify the enabled handlers"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = 0.0
        entry["max"] = 1.0

        for handler in self._enabled_handlers():
            handler.start_handler(node_id, entry, self.prompt_id)

    def update_progress(
        self, node_id: str, value: float, max_value: float, image: PreviewImageTuple | None = None
    ) -> None:
        """Record new progress for a node and notify the enabled handlers"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Running
        entry["value"] = value
        entry["max"] = max_value

        for handler in self._enabled_handlers():
            handler.update_handler(
                node_id, value, max_value, entry, self.prompt_id, image
            )

    def finish_progress(self, node_id: str) -> None:
        """Mark a node as finished at its maximum and notify the enabled handlers"""
        entry = self.ensure_entry(node_id)
        entry["state"] = NodeState.Finished
        entry["value"] = entry["max"]

        for handler in self._enabled_handlers():
            handler.finish_handler(node_id, entry, self.prompt_id)

    def reset_handlers(self) -> None:
        """Reset every registered handler, enabled or not"""
        for handler in self.handlers.values():
            handler.reset()
|
|
| |
# Module-wide registry shared by the helper functions below; stays None until
# reset_progress_state() or get_progress_state() creates one.
global_progress_registry: ProgressRegistry | None = None
|
|
def reset_progress_state(prompt_id: str, dynprompt: "DynamicPrompt") -> None:
    """Replace the global registry with a fresh one for ``prompt_id``.

    The handlers of the outgoing registry, if there is one, are reset first.
    The new registry starts without any handlers.
    """
    global global_progress_registry

    outgoing = global_progress_registry
    if outgoing is not None:
        outgoing.reset_handlers()

    global_progress_registry = ProgressRegistry(prompt_id, dynprompt)
|
|
|
|
def add_progress_handler(handler: ProgressHandler) -> None:
    """Attach ``handler`` to the current global registry.

    The handler is told about the registry first and is then registered
    under its name.
    """
    current = get_progress_state()
    handler.set_registry(current)
    current.register_handler(handler)
|
|
|
|
def get_progress_state() -> ProgressRegistry:
    """Return the global registry, creating a placeholder one on first use.

    The placeholder has an empty prompt id and a ``DynamicPrompt`` built from
    an empty dict.
    """
    global global_progress_registry

    if global_progress_registry is not None:
        return global_progress_registry

    # Imported here rather than at module level, where the name is only
    # imported for type checking.
    from comfy_execution.graph import DynamicPrompt

    global_progress_registry = ProgressRegistry(
        prompt_id="", dynprompt=DynamicPrompt({})
    )
    return global_progress_registry
|
|