"""ComfyUI library-mode backend. Single-process, single-implementation. The @spaces.GPU decorator is the only divergence between local and HF Spaces deployment. """ from __future__ import annotations import asyncio import contextvars import os import pathlib import sys import threading import traceback as tb_mod from collections.abc import AsyncIterator, Iterable from dataclasses import dataclass, field from typing import Any import models @dataclass class DownloadEvent: filename: str mb_done: float mb_total: float @dataclass class ProgressEvent: stage: int stage_label: str step: int total_steps: int @dataclass class OutputEvent: video_path: str audio_path: str | None = None meta: dict = field(default_factory=dict) @dataclass class ErrorEvent: category: str # "oom" | "zerogpu_timeout" | "execution" | "interrupt" | "download" message: str stage: int | None = None traceback: str = "" def _on_spaces() -> bool: return bool(os.environ.get("SPACES_ZERO_GPU")) try: import spaces # type: ignore except ImportError: spaces = None # type: ignore[assignment] def _identity(fn): return fn # --- Per-call ZeroGPU duration estimator ----------------------------------- # `duration` is a per-call timeout. Shorter declared duration → faster queue # priority on the shared ZeroGPU pool. Estimating from (mode, preset, frames) # instead of using a one-size-fits-all 600s cap means light T2V calls jump # the queue while heavy modes (lipsync, style) reserve real headroom. _BASE_DURATION_S: dict[str, int] = { # Rough sampler+decode time at ~120 frames, balanced preset, warm cache. 
"t2v": 90, "i2v": 90, "a2v": 120, "lipsync": 240, # extra: audio encoder + audio VAE + extra LoRAs "keyframe": 180, "style": 360, # extra: preprocessor (canny/dwpose/depth) + IC-LoRAs } _PRESET_MULT: dict[str, float] = {"fast": 1.0, "balanced": 1.5, "quality": 3.0} def _frames_from_workflow(workflow: dict) -> int: """Read the frame count from the workflow's EmptyLTXVLatentVideo node.""" for node in workflow.values(): if isinstance(node, dict) and node.get("class_type") == "EmptyLTXVLatentVideo": try: return int((node.get("inputs") or {}).get("length", 121)) except (TypeError, ValueError): return 121 return 121 def _duration_for( executor: Any, workflow: dict, output_ids: list[str], mode: str, preset: str, multiplier: float = 1.0, progress: Any = None, ) -> int: """ZeroGPU duration estimator. Same signature as _execute_workflow. `progress` is a gr.Progress instance forwarded by the caller; we ignore it here (estimator doesn't emit progress) but must accept it positionally so ZeroGPU can call us with the same arg list it'll use for _execute_workflow. Estimate = (base × preset multiplier + cold-cache buffer + per-frame VAE decode time) × retry multiplier, clamped to [60s, 240s]. ZeroGPU rejects durations above the server's per-call max with "ZeroGPU illegal duration" (client.py:137); 240s is observed to work for Pro identity (~2 min runs needed for style + lipsync detailer paths). If the server rejects values in this range, the user will see a clear error and can retry. """ base = _BASE_DURATION_S.get(mode, 180) mult = _PRESET_MULT.get(preset.lower(), 1.5) frames = _frames_from_workflow(workflow) est = int((base * mult + 60 + frames * 0.3) * multiplier) return max(60, min(est, 240)) # Decorate at module load time so ZeroGPU's startup analyzer detects it. 
_GPU = (
    spaces.GPU(duration=_duration_for)
    if (spaces is not None and _on_spaces())
    else _identity
)

# File extensions recognised as video outputs, both when reading the
# executor's history and when scanning the output directory.
_VIDEO_EXTS: tuple[str, ...] = (".mp4", ".webm", ".mov")


@_GPU
def _execute_workflow(
    executor: Any,
    workflow: dict,
    output_ids: list[str],
    mode: str,
    preset: str,
    multiplier: float = 1.0,
    progress: Any = None,
) -> str:
    """Run the workflow on GPU and return the path of the first video output.

    Returns just the video path (a plain string, picklable across the
    @spaces.GPU subprocess boundary), or "" when the executor's history holds
    no video file. The `mode`, `preset`, and `multiplier` args are consumed by
    `_duration_for` to estimate the GPU slot to reserve.

    `progress` is an optional `gr.Progress` instance. It's the only progress
    channel that crosses the @spaces.GPU subprocess boundary on HF Spaces —
    Gradio + the `spaces` library wrap it with cross-process IPC. When set,
    we mirror ComfyUI's step counter into it via the global progress hook,
    chaining to whatever hook was already installed (so the local event-based
    status banner keeps working alongside). The previously installed hook is
    put back once execution finishes, whether it succeeded or raised.
    """
    comfy_utils = None
    saved_hook = None
    if progress is not None:
        import comfy.utils as comfy_utils

        saved_hook = getattr(comfy_utils, "PROGRESS_BAR_HOOK", None)

        def _gp_hook(value, total, _preview=None, **_kw):
            # Progress reporting is cosmetic: never let it break sampling.
            try:
                v, t = int(value), int(total)
                progress(v / max(t, 1), desc=f"Sampling step {v}/{t}")
            except Exception:
                pass
            if saved_hook is not None:
                try:
                    saved_hook(value, total, _preview)
                except Exception:
                    pass

        comfy_utils.set_progress_bar_global_hook(_gp_hook)

    try:
        executor.execute(
            workflow,
            prompt_id="ltx23-aio",
            extra_data={"client_id": "ltx23-aio"},
            execute_outputs=output_ids,
        )
    finally:
        # Undo our hook so repeated calls don't stack wrappers that hold on
        # to stale gr.Progress objects.
        if comfy_utils is not None:
            comfy_utils.set_progress_bar_global_hook(saved_hook)

    hist = getattr(executor, "history_result", {}) or {}
    outs = hist.get("outputs") or {}
    for output in outs.values():
        if not isinstance(output, dict):
            continue
        for value in output.values():
            if not isinstance(value, list):
                continue
            for item in value:
                if not isinstance(item, dict):
                    continue
                fn = item.get("filename") or ""
                if fn.endswith(_VIDEO_EXTS):
                    return item.get("fullpath") or fn
    return ""


class _StubServer:
    """Minimal stub matching the surface ComfyUI's PromptExecutor expects."""

    client_id: str | None = "ltx23-aio"
    last_node_id: str | None = None

    def send_sync(self, event: str, data: dict, sid: str | None = None) -> None:
        pass

    def queue_updated(self) -> None:
        pass


class _StubPromptQueue:
    """Stub matching the surface VideoHelperSuite + others touch."""

    # NOTE: these dicts are class attributes, i.e. shared by every instance.
    currently_running: dict = {}
    history: dict = {}
    flags: dict = {}

    def get_current_queue(self) -> tuple[list, list]:
        return ([], [])

    def get_tasks_remaining(self) -> int:
        return 0

    def set_flag(self, name: str, data) -> None:
        pass

    def get_flags(self, *a, **kw) -> dict:
        return {}

    def task_done(self, *a, **kw) -> None:
        pass

    def put(self, *a, **kw) -> None:
        pass

    def wipe_queue(self) -> None:
        pass

    def delete_queue_item(self, *a, **kw) -> None:
        pass


class _StubPromptServerInstance:
    """Surface that ComfyUI's `server.PromptServer.instance` exposes to custom nodes.

    VideoHelperSuite, KJNodes, and others read this at import time. They mostly
    use it to register HTTP routes or send WS events or peek at the prompt
    queue. No-ops here are fine — we have no real server.
    """

    client_id: str | None = "ltx23-aio"
    # KJNodes' preview thread reads `last_node_id.encode('ascii')` directly.
    # ComfyUI's real server keeps it as a string per executing node and resets
    # to None at end-of-prompt — which races the preview thread. Keep it a
    # safe non-empty string so .encode() never NPEs.
    last_node_id: str = "ltx23-aio"
    web_root: str = ""

    class _Routes:
        """Route table stub: decorators hand the function back unchanged."""

        def get(self, *a, **kw):
            return lambda fn: fn

        def post(self, *a, **kw):
            return lambda fn: fn

        def static(self, *a, **kw):
            return None

    routes = _Routes()
    sockets: dict = {}
    prompt_queue = _StubPromptQueue()
    # Custom-Scripts checks PromptServer.instance.supports — claim the
    # "custom_nodes_from_web" capability so it skips its JS install path.
    supports: list[str] = ["custom_nodes_from_web"]

    def add_routes(self) -> None:
        pass

    def send_sync(self, event: str, data: dict, sid: str | None = None) -> None:
        pass

    def send_progress_text(self, text: str, node_id=None, sid=None) -> None:
        # Comfy_extras nodes call this; we just no-op since we don't have a UI
        # to surface intermediate text on.
        pass

    def queue_updated(self) -> None:
        pass

    def get_node_class_def(self, *a, **kw):
        return None

    def __getattr__(self, name):
        # Anything else our custom nodes might reach for — give them a no-op.
        # This is a deliberate liberal catch-all so the inference path doesn't
        # die on cosmetic UI hooks. Inspection-style access (hasattr) gets True
        # for ordinary names.
        #
        # Dunder names are excluded: protocol probes such as
        # `getattr(obj, "__deepcopy__", None)` must see "not implemented",
        # otherwise copy.deepcopy() would call the no-op and return None.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(name)

        def _noop(*a, **kw):
            return None

        return _noop


def _comfy_dir() -> pathlib.Path:
    """Return the ComfyUI checkout: ~/comfyui on Spaces, ./comfyui beside this file otherwise."""
    if _on_spaces():
        return pathlib.Path.home() / "comfyui"
    return pathlib.Path(__file__).parent / "comfyui"


class ComfyUILibraryBackend:
    """Wraps PromptExecutor for in-process workflow execution."""

    def __init__(self) -> None:
        """Import ComfyUI from disk, load its nodes, and build a PromptExecutor.

        Raises:
            RuntimeError: if the ComfyUI directory is missing, or if
                `nodes.init_extra_nodes()` raised while loading nodes.
        """
        self._comfy_dir = _comfy_dir()
        if not self._comfy_dir.exists():
            raise RuntimeError(
                f"ComfyUI not found at {self._comfy_dir}. "
                f"Local: run `bash setup.sh`. Spaces: see app.py:_bootstrap()."
            )
        if str(self._comfy_dir) not in sys.path:
            sys.path.insert(0, str(self._comfy_dir))

        # Defer comfy imports until the path is set up.
        # NOTE: ComfyUI ships PromptExecutor in the top-level `execution.py`
        # module, NOT under `comfy.execution`. Same for `nodes`. Both must be
        # imported AFTER the sys.path insert above.
        import comfy.cli_args  # noqa: F401 — side-effect: registers CLI flags
        import execution  # top-level module — provides PromptExecutor
        import nodes  # top-level module — provides init_extra_nodes (async)

        # CRITICAL ordering fix: ComfyUI's nodes.py:24 inserts `comfyui/comfy/`
        # at sys.path[0]. That dir contains a module-style `utils.py`, which
        # shadows `comfyui/utils/` (a package containing install_util.py).
        # Some custom nodes (KJNodes, VideoHelperSuite via app.frontend_management)
        # do `from utils.install_util import …` and get `comfy/utils.py` instead,
        # raising "'utils' is not a package". Rewrite sys.path so comfy_dir is
        # ahead of comfy_dir/comfy and force-clear any cached `utils` binding.
        comfy_subdir = str(self._comfy_dir / "comfy")
        sys.path = [p for p in sys.path if p not in (str(self._comfy_dir), comfy_subdir)]
        sys.path.insert(0, comfy_subdir)
        sys.path.insert(0, str(self._comfy_dir))
        if "utils" in sys.modules and not getattr(sys.modules["utils"], "__path__", None):
            del sys.modules["utils"]

        # Some custom nodes (e.g. VideoHelperSuite) read `server.PromptServer.instance`
        # at import time. We don't run a real ComfyUI server, so install a stub
        # that exposes the attributes those nodes touch (sockets, send, etc.).
        import server as comfy_server

        if getattr(comfy_server.PromptServer, "instance", None) is None:
            comfy_server.PromptServer.instance = _StubPromptServerInstance()

        # `nodes.init_extra_nodes` is async. We may be called from within a
        # running event loop (Gradio's handler) — running `asyncio.run()` there
        # raises. Run the coroutine in a fresh loop on a worker thread instead.
        #
        # An exception raised on a worker thread does not reach the thread that
        # joins it, so capture it and re-raise here. Otherwise construction
        # would "succeed" with a half-populated node registry.
        init_errors: list[Exception] = []

        def _init_in_thread() -> None:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            try:
                loop.run_until_complete(nodes.init_extra_nodes())
            except Exception as exc:
                init_errors.append(exc)
            finally:
                loop.close()

        thread = threading.Thread(target=_init_in_thread, daemon=False)
        thread.start()
        thread.join()
        if init_errors:
            raise RuntimeError(
                f"ComfyUI node initialisation failed: {init_errors[0]}"
            ) from init_errors[0]

        # PromptExecutor expects a `server` with client_id, send_sync, last_node_id,
        # queue_updated. A minimal stub no-ops all of them — we don't run a real
        # websocket server, we surface progress via comfy.utils.PROGRESS_BAR_HOOK.
        # cache_args["ram"] is read unconditionally inside execute_async even when
        # cache_type is the default false — provide a sensible default so it doesn't
        # NoneType-subscript at line 727.
        self._executor = execution.PromptExecutor(
            server=_StubServer(),
            cache_args={"ram": 16.0, "lru": 0},
        )

    def __repr__(self) -> str:
        return f"ComfyUILibraryBackend(comfy_dir={self._comfy_dir!r})"

    async def submit(
        self,
        mode: str,
        workflow: dict,
        *,
        preset: str = "balanced",
        duration_multiplier: float = 1.0,
        gpu_duration: int = 0,  # legacy, ignored (now derived from preset+frames)
        progress: Any = None,
    ) -> AsyncIterator[Any]:
        """Run a workflow end-to-end. Yields Download/Progress/Output/Error events.

        `preset` and `duration_multiplier` flow through to the @spaces.GPU
        duration estimator. The handler can re-call submit() with
        duration_multiplier=2.0 if the first attempt aborts on timeout.

        Failures are reported as a yielded ErrorEvent rather than raised: a
        "download" category for the model pre-flight, otherwise whatever
        `_classify` returns for the worker's exception.
        """
        # Pre-flight: ensure all model files exist.
        try:
            needed = models.walk_workflow_for_models(workflow)
            for download_event in models.ensure_models(needed):
                yield download_event
        except Exception as e:
            yield ErrorEvent(
                category="download",
                message=str(e),
                traceback=tb_mod.format_exc(),
            )
            return

        # Run the inference in a worker thread; pass progress events through a queue.
        queue: asyncio.Queue = asyncio.Queue()
        loop = asyncio.get_running_loop()

        def _push(event: Any) -> None:
            asyncio.run_coroutine_threadsafe(queue.put(event), loop)

        # Track stage progression. ComfyUI fires the progress hook from inside
        # samplers, so we advance the stage every time we observe a new sampler
        # starting (step==0 with a different total than before, or a "new run"
        # signal — value smaller than the running max for the same total).
        progress_state = {"stage": 0, "prev_total": -1, "max_step": -1}

        def _hook(value: int, total: int, _preview=None, **_kwargs: Any) -> None:
            v, t = int(value), int(total)
            # New sampler started (different total, or step rewound)
            if t != progress_state["prev_total"] or v < progress_state["max_step"]:
                progress_state["stage"] += 1
                progress_state["prev_total"] = t
                progress_state["max_step"] = v
            else:
                progress_state["max_step"] = max(progress_state["max_step"], v)
            _push(
                ProgressEvent(
                    stage=progress_state["stage"],
                    stage_label="diffusion",
                    step=v,
                    total_steps=t,
                )
            )

        def _worker() -> None:
            import comfy.utils

            saved_hook = getattr(comfy.utils, "PROGRESS_BAR_HOOK", None)
            try:
                # Workflow is already API-format (saved from ComfyUI editor's
                # "Save (API Format)"), so it can be handed to PromptExecutor
                # directly. The execute_outputs list pinpoints which output
                # nodes to evaluate — we let PromptExecutor walk the whole
                # graph by passing every output-class node id.
                #
                # Entries that aren't node dicts, or that lack a string
                # class_type, are skipped instead of crashing the run.
                output_prefixes = (
                    "SaveVideo",
                    "VHS_VideoCombine",
                    "PreviewAudio",
                    "CreateVideo",
                )
                output_ids = [
                    nid
                    for nid, node in workflow.items()
                    if isinstance(node, dict)
                    and str(node.get("class_type") or "").startswith(output_prefixes)
                ]
                print(
                    f"[backend] submitting workflow: {len(workflow)} nodes, "
                    f"output_ids={output_ids}",
                    file=sys.stderr,
                    flush=True,
                )
                # Use the public setter; it writes the same global the
                # ProgressBar class reads, but is the documented API.
                comfy.utils.set_progress_bar_global_hook(_hook)
                # _execute_workflow is module-level and decorated with a
                # @spaces.GPU(duration=callable) on Spaces — the callable
                # estimates per-call timeout from (mode, preset, frames) so
                # light calls get fast queue priority while heavy ones reserve
                # real headroom. Off-Spaces it's a plain call.
                video_path = _execute_workflow(
                    self._executor,
                    workflow,
                    output_ids,
                    mode,
                    preset,
                    duration_multiplier,
                    progress,
                )
                # Fallback: if history_result didn't surface a path (rare on
                # Spaces — happens when ZeroGPU's subprocess boundary drops
                # mutated state), scan the output dir for the newest mp4
                # written within the last 60 s.
                if not video_path:
                    video_path = _newest_recent_video(self._comfy_dir / "output") or ""
                print(
                    f"[backend] workflow done; video_path={video_path!r}",
                    file=sys.stderr,
                    flush=True,
                )
                _push(OutputEvent(video_path=video_path))
            except Exception as exc:
                tb_text = tb_mod.format_exc()
                print(f"[backend] worker exception:\n{tb_text}", file=sys.stderr, flush=True)
                _push(
                    ErrorEvent(
                        category=_classify(exc),
                        message=str(exc),
                        traceback=tb_text,
                    )
                )
            finally:
                comfy.utils.set_progress_bar_global_hook(saved_hook)
                _free_memory()
                _push(None)  # sentinel: stop the consumer

        # ZeroGPU's @spaces.GPU wrapper reads the user's identity from the
        # current Gradio request via gradio.context.LocalContext.request,
        # which is a contextvar. Plain threads don't inherit contextvars, so
        # without this the worker sees request=None, X-IP-Token never gets
        # read, and `client.schedule` raises "Space app has reached its GPU
        # limit" (token-is-None branch in spaces/zero/client.py:138). Copy
        # the calling task's context so the request — and therefore the Pro
        # user's quota attribution — survives the thread boundary.
        ctx = contextvars.copy_context()
        thread = threading.Thread(target=ctx.run, args=(_worker,), daemon=True)
        thread.start()

        while True:
            event = await queue.get()
            if event is None:
                return
            yield event

    def interrupt(self) -> None:
        """Cancel the currently running workflow (if any).

        Best-effort: never raises. A failure to reach ComfyUI's model
        management is logged to stderr so a dead cancel button leaves a trace.
        """
        try:
            import comfy.model_management as mm

            mm.interrupt_current_processing()
        except Exception as exc:
            print(f"[backend] interrupt failed: {exc!r}", file=sys.stderr, flush=True)


def _classify(exc: Exception) -> str:
    """Map an exception to an ErrorEvent category by name/message substrings.

    Checks run in order and the first match wins; anything unrecognised is
    "execution".
    """
    name = type(exc).__name__.lower()
    msg = str(exc).lower()
    if "outofmemory" in name or "cuda out of memory" in msg:
        return "oom"
    # Also covers the literal "expired zerogpu proxy token".
    if "expired" in msg and "token" in msg:
        return "expired_token"
    if "illegal duration" in msg:
        return "illegal_duration"
    if "unlogged user" in msg:
        return "unlogged"
    if "exceeded your" in msg and "gpu" in msg:
        return "quota_exceeded"
    # ZeroGPU enforces the @spaces.GPU(duration=N) cap and re-raises as
    # gradio.exceptions.Error('GPU task aborted').
    if "gpu" in msg and "aborted" in msg:
        return "gpu_timeout"
    if "interrupt" in name:
        return "interrupt"
    return "execution"


def _free_memory() -> None:
    """Free VRAM after a workflow finishes (success or failure).

    Every step is independent and best-effort: a missing backend (no comfy,
    no torch, no MPS, no CUDA) must not stop the remaining steps, and cleanup
    must never mask the workflow's own result.
    """
    try:
        import comfy.model_management as mm

        mm.unload_all_models()
    except Exception:
        pass
    try:
        import torch

        if torch.backends.mps.is_available():
            torch.mps.empty_cache()
    except Exception:
        pass
    try:
        import torch

        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    except Exception:
        pass


def _newest_recent_video(output_root: pathlib.Path, within_seconds: float = 60.0) -> str | None:
    """Filesystem fallback: return the newest .mp4/.webm/.mov under
    *output_root* that was modified within the last *within_seconds* seconds.

    Used when the executor's history_result didn't surface a path — typically
    happens when ZeroGPU's subprocess boundary drops the mutation. The disk
    is shared, so the file is there even when the in-memory state isn't.

    Only regular files are considered. Returns None when the directory is
    missing or holds no recent enough video.
    """
    import stat as stat_mod
    import time

    if not output_root.exists():
        return None

    cutoff = time.time() - within_seconds
    candidates: list[tuple[float, pathlib.Path]] = []
    for ext in _VIDEO_EXTS:
        for p in output_root.rglob(f"*{ext}"):
            try:
                st = p.stat()
            except OSError:
                # File vanished or is unreadable; ignore it.
                continue
            if not stat_mod.S_ISREG(st.st_mode):
                # A directory named like a video is not a usable output.
                continue
            if st.st_mtime >= cutoff:
                candidates.append((st.st_mtime, p))

    if not candidates:
        return None
    return str(max(candidates)[1])