""" Embedded ComfyUI backend for Anima RDBT. The Gradio app calls this module from inside ``@spaces.GPU``. It imports ComfyUI node classes directly, keeps the heavy loader outputs cached per worker, and executes the same UNET + CLIP + VAE + KSampler graph described in ``pipeline.py``. """ from __future__ import annotations import importlib import importlib.machinery import os import sys import threading import types import urllib.request import zipfile from dataclasses import dataclass from io import BytesIO from pathlib import Path from typing import Any, Optional import numpy as np from PIL import Image from src import config from src.config import GenerationParams from src.errors import UserFacingError COMFYUI_COMMIT = os.environ.get( "ANIMA_COMFYUI_COMMIT", "5e3f15a830ff27d3563ef4b43e9f6a0321ea36cd", ).strip() or "5e3f15a830ff27d3563ef4b43e9f6a0321ea36cd" COMFYUI_ARCHIVE_URL = os.environ.get( "ANIMA_COMFYUI_ARCHIVE_URL", f"https://github.com/comfyanonymous/ComfyUI/archive/{COMFYUI_COMMIT}.zip", ) _lock = threading.RLock() _bootstrapped = False _runtime: "_ComfyRuntime | None" = None @dataclass class _LoadedGraph: model: Any clip: Any vae: Any @dataclass class _ComfyRuntime: node_mappings: dict[str, type] loaded: _LoadedGraph | None = None def _comfy_root() -> Path: return Path(config.comfy_root()).resolve() def _models_root() -> Path: return Path(config.model_artifacts_root()).resolve() def _report_progress(progress: Any, value: float, desc: str) -> None: if progress is None: return try: progress(value, desc=desc) except TypeError: try: progress(value) except Exception: pass except Exception: pass def _download_comfy_source_if_needed() -> None: root = _comfy_root() if (root / "nodes.py").is_file(): return if os.environ.get("ANIMA_DISABLE_COMFY_SOURCE_FETCH", "").strip() == "1": raise UserFacingError( f"ComfyUI source not found at {root}. Add ComfyUI there or unset ANIMA_DISABLE_COMFY_SOURCE_FETCH." ) root.parent.mkdir(parents=True, exist_ok=True) print(f"[startup] Fetching ComfyUI source from {COMFYUI_ARCHIVE_URL}", flush=True) try: with urllib.request.urlopen(COMFYUI_ARCHIVE_URL, timeout=600) as response: data = response.read() with zipfile.ZipFile(BytesIO(data)) as zf: members = zf.namelist() top_levels = {m.split("/", 1)[0] for m in members if "/" in m} if not top_levels: raise RuntimeError("archive did not contain a top-level directory") archive_root = sorted(top_levels)[0] tmp_root = root.parent / f".{root.name}.download" if tmp_root.exists(): import shutil shutil.rmtree(tmp_root) tmp_root.mkdir(parents=True, exist_ok=True) zf.extractall(tmp_root) extracted = tmp_root / archive_root if root.exists(): import shutil shutil.rmtree(root) extracted.replace(root) tmp_root.rmdir() except UserFacingError: raise except Exception as e: raise UserFacingError(f"Failed to fetch ComfyUI source: {e!s}") from e if not (root / "nodes.py").is_file(): raise UserFacingError(f"Downloaded ComfyUI source is missing nodes.py under {root}.") def _configure_comfy_paths() -> None: root = _comfy_root() if str(root) not in sys.path: sys.path.insert(0, str(root)) folder_paths = importlib.import_module("folder_paths") models_root = str(_models_root()) for folder_name in ("diffusion_models", "text_encoders", "vae"): folder = os.path.join(models_root, folder_name) os.makedirs(folder, exist_ok=True) if hasattr(folder_paths, "add_model_folder_path"): folder_paths.add_model_folder_path(folder_name, folder) elif hasattr(folder_paths, "folder_names_and_paths"): current = folder_paths.folder_names_and_paths.get(folder_name) if current is not None: paths, extensions = current if folder not in paths: folder_paths.folder_names_and_paths[folder_name] = ([folder, *paths], extensions) def _missing_torchaudio(*_args: Any, **_kwargs: Any) -> Any: raise RuntimeError( "torchaudio is unavailable in this Space. The Anima RDBT text-to-image workflow " "does not use Comfy audio nodes; install a PyTorch-compatible torchaudio wheel " "only if you add an audio workflow." ) def _install_torchaudio_stub_if_needed() -> None: """Allow image-only Comfy imports when optional torchaudio is missing or ABI-broken.""" if os.environ.get("ANIMA_STUB_TORCHAUDIO", "1").strip() == "0": return try: importlib.import_module("torchaudio") return except Exception: sys.modules.pop("torchaudio", None) sys.modules.pop("torchaudio.functional", None) sys.modules.pop("torchaudio.transforms", None) torchaudio = types.ModuleType("torchaudio") functional = types.ModuleType("torchaudio.functional") transforms = types.ModuleType("torchaudio.transforms") torchaudio.__spec__ = importlib.machinery.ModuleSpec("torchaudio", loader=None) functional.__spec__ = importlib.machinery.ModuleSpec("torchaudio.functional", loader=None) transforms.__spec__ = importlib.machinery.ModuleSpec("torchaudio.transforms", loader=None) functional.resample = _missing_torchaudio # type: ignore[attr-defined] functional.bass_biquad = _missing_torchaudio # type: ignore[attr-defined] functional.equalizer_biquad = _missing_torchaudio # type: ignore[attr-defined] functional.treble_biquad = _missing_torchaudio # type: ignore[attr-defined] class _MissingAudioTransform: def __init__(self, *_args: Any, **_kwargs: Any) -> None: _missing_torchaudio() transforms.MelScale = _MissingAudioTransform # type: ignore[attr-defined] transforms.MelSpectrogram = _MissingAudioTransform # type: ignore[attr-defined] torchaudio.functional = functional # type: ignore[attr-defined] torchaudio.transforms = transforms # type: ignore[attr-defined] sys.modules["torchaudio"] = torchaudio sys.modules["torchaudio.functional"] = functional sys.modules["torchaudio.transforms"] = transforms def _init_comfy_runtime() -> _ComfyRuntime: _download_comfy_source_if_needed() _configure_comfy_paths() _install_torchaudio_stub_if_needed() try: nodes = importlib.import_module("nodes") if os.environ.get("ANIMA_INIT_COMFY_EXTRA_NODES", "").strip() == "1" and hasattr(nodes, "init_extra_nodes"): nodes.init_extra_nodes() mappings = getattr(nodes, "NODE_CLASS_MAPPINGS") except Exception as e: raise UserFacingError(f"Failed to initialize ComfyUI nodes: {e!s}") from e required = ("UNETLoader", "CLIPLoader", "VAELoader", "CLIPTextEncode", "EmptyLatentImage", "KSampler", "VAEDecode") missing = [name for name in required if name not in mappings] if missing: raise UserFacingError(f"ComfyUI is missing required node classes: {', '.join(missing)}.") return _ComfyRuntime(node_mappings=mappings) def _bootstrap_files_if_needed() -> None: global _bootstrapped with _lock: if _bootstrapped: return from src import bootstrap try: bootstrap.bootstrap_model_artifacts() except UserFacingError: raise except Exception as e: raise UserFacingError(f"Model bootstrap failed: {e!s}. See logs for full traceback.") from e _bootstrapped = True def run_at_container_startup() -> None: """Run at Space import: download disk artifacts and make ComfyUI importable.""" print("[startup] Preparing ComfyUI source and Anima RDBT model files...", flush=True) _download_comfy_source_if_needed() _bootstrap_files_if_needed() print( "[startup] ComfyUI source and model files are ready. Native Comfy nodes load on first Generate.", flush=True, ) def ensure_prepared() -> None: """Ensure model files exist and Comfy node classes are registered.""" global _runtime with _lock: if _runtime is not None: return _bootstrap_files_if_needed() runtime = _init_comfy_runtime() with _lock: if _runtime is None: _runtime = runtime def _node(runtime: _ComfyRuntime, class_name: str) -> Any: try: return runtime.node_mappings[class_name]() except Exception as e: raise UserFacingError(f"Failed to create Comfy node {class_name}: {e!s}") from e def _call_node(node: Any, method_name: str, **kwargs: Any) -> Any: method = getattr(node, method_name, None) if method is None: raise UserFacingError(f"Comfy node {node.__class__.__name__} has no method {method_name!r}.") try: return method(**kwargs) except Exception as e: raise UserFacingError(f"Comfy node {node.__class__.__name__}.{method_name} failed: {e!s}") from e def _value(output: Any, index: int = 0) -> Any: if isinstance(output, dict) and "result" in output: output = output["result"] if isinstance(output, (tuple, list)): return output[index] return output def _load_graph(runtime: _ComfyRuntime, progress: Any = None) -> _LoadedGraph: with _lock: if runtime.loaded is not None: return runtime.loaded _report_progress(progress, 0.05, "Loading Comfy UNET / CLIP / VAE (first run can take several minutes)...") model = _value( _call_node( _node(runtime, "UNETLoader"), "load_unet", unet_name=config.RDBT_UNET_NAME, weight_dtype=config.UNET_WEIGHT_DTYPE, ) ) clip = _value( _call_node( _node(runtime, "CLIPLoader"), "load_clip", clip_name=config.CLIP_NAME, type=config.CLIP_TYPE, ) ) vae = _value(_call_node(_node(runtime, "VAELoader"), "load_vae", vae_name=config.VAE_NAME)) loaded = _LoadedGraph(model=model, clip=clip, vae=vae) with _lock: if runtime.loaded is None: runtime.loaded = loaded return runtime.loaded def _tensor_to_images(images: Any) -> list[Image.Image]: if isinstance(images, Image.Image): return [images.convert("RGB")] if isinstance(images, (list, tuple)) and images and isinstance(images[0], Image.Image): return [im.convert("RGB") for im in images] if hasattr(images, "detach"): images = images.detach().cpu().numpy() arr = np.asarray(images) if arr.ndim == 3: arr = arr[None, ...] if arr.ndim != 4: raise UserFacingError(f"Comfy VAEDecode returned unsupported image shape: {arr.shape!r}.") out: list[Image.Image] = [] for frame in arr: frame = np.clip(frame, 0.0, 1.0) frame = (frame * 255.0).round().astype("uint8") out.append(Image.fromarray(frame).convert("RGB")) return out def run_generation( p: GenerationParams, *, progress: Optional[Any] = None, ) -> tuple[list[Image.Image], str, str, str]: """ Execute the embedded ComfyUI graph and return images plus status details. """ _report_progress(progress, 0.0, "Preparing ComfyUI...") ensure_prepared() assert _runtime is not None runtime = _runtime loaded = _load_graph(runtime, progress=progress) _report_progress(progress, 0.20, "Encoding prompts...") positive = _value( _call_node(_node(runtime, "CLIPTextEncode"), "encode", text=p.prompt, clip=loaded.clip) ) negative = _value( _call_node(_node(runtime, "CLIPTextEncode"), "encode", text=p.negative_prompt, clip=loaded.clip) ) latent = _value( _call_node( _node(runtime, "EmptyLatentImage"), "generate", width=int(p.width), height=int(p.height), batch_size=int(p.batch_size), ) ) _report_progress(progress, 0.35, f"Running Comfy KSampler ({p.steps} steps)...") samples = _value( _call_node( _node(runtime, "KSampler"), "sample", model=loaded.model, seed=int(p.seed), steps=int(p.steps), cfg=float(p.cfg), sampler_name=p.sampler_name, scheduler=p.scheduler, positive=positive, negative=negative, latent_image=latent, denoise=float(p.denoise), ) ) _report_progress(progress, 0.95, "Decoding images...") decoded = _value(_call_node(_node(runtime, "VAEDecode"), "decode", samples=samples, vae=loaded.vae)) images = _tensor_to_images(decoded) if not images: raise UserFacingError("ComfyUI returned no images.") _report_progress(progress, 1.0, "Done.") details = ( f"seed={p.seed} | {p.width}x{p.height} | steps={p.steps} | cfg={p.cfg} | " f"batch={p.batch_size} | {p.sampler_name}/{p.scheduler} | denoise={p.denoise}" ) return images, details, p.prompt, p.negative_prompt