# anima-gradio-zerogpu-space / src/comfy_backend.py
# JSCPPProgrammer's picture
# chore: sync local Space files
# cd2b142 verified
"""
Embedded ComfyUI backend for Anima RDBT.
The Gradio app calls this module from inside ``@spaces.GPU``. It imports ComfyUI
node classes directly, keeps the heavy loader outputs cached per worker, and
executes the same UNET + CLIP + VAE + KSampler graph described in ``pipeline.py``.
"""
from __future__ import annotations
import importlib
import importlib.machinery
import os
import sys
import threading
import types
import urllib.request
import zipfile
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import Any, Optional
import numpy as np
from PIL import Image
from src import config
from src.config import GenerationParams
from src.errors import UserFacingError
# Pinned ComfyUI revision used when ANIMA_COMFYUI_COMMIT is unset or blank.
_DEFAULT_COMFYUI_COMMIT = "5e3f15a830ff27d3563ef4b43e9f6a0321ea36cd"

# ComfyUI revision to fetch. A blank/whitespace-only override falls back to the pin.
COMFYUI_COMMIT = os.environ.get("ANIMA_COMFYUI_COMMIT", "").strip() or _DEFAULT_COMFYUI_COMMIT

# Zip archive downloaded when the ComfyUI source is missing. A blank override
# falls back to the GitHub archive for COMFYUI_COMMIT, mirroring how the commit
# override itself is handled (an empty URL could never be fetched anyway).
COMFYUI_ARCHIVE_URL = (
    os.environ.get("ANIMA_COMFYUI_ARCHIVE_URL", "").strip()
    or f"https://github.com/comfyanonymous/ComfyUI/archive/{COMFYUI_COMMIT}.zip"
)

# Re-entrant lock guarding the module-level state below.
_lock = threading.RLock()
# Set to True once the model-artifact bootstrap has completed successfully.
_bootstrapped = False
# Comfy runtime for this process; populated by ensure_prepared().
_runtime: "_ComfyRuntime | None" = None
@dataclass
class _LoadedGraph:
model: Any
clip: Any
vae: Any
@dataclass
class _ComfyRuntime:
node_mappings: dict[str, type]
loaded: _LoadedGraph | None = None
def _comfy_root() -> Path:
    """Return the resolved path of the ComfyUI source directory from ``config``."""
    configured = config.comfy_root()
    return Path(configured).resolve()
def _models_root() -> Path:
    """Return the resolved path of the model artifacts directory from ``config``."""
    configured = config.model_artifacts_root()
    return Path(configured).resolve()
def _report_progress(progress: Any, value: float, desc: str) -> None:
if progress is None:
return
try:
progress(value, desc=desc)
except TypeError:
try:
progress(value)
except Exception:
pass
except Exception:
pass
def _download_comfy_source_if_needed() -> None:
    """Fetch and unpack the ComfyUI source archive when it is not on disk yet.

    Does nothing when ``nodes.py`` already exists under the Comfy root. Otherwise
    the archive at ``COMFYUI_ARCHIVE_URL`` is downloaded, extracted into a
    sibling staging directory, and its first (sorted) top-level directory is
    moved into place as the Comfy root, replacing any existing directory there.

    Raises:
        UserFacingError: if fetching is disabled via
            ``ANIMA_DISABLE_COMFY_SOURCE_FETCH=1``, if the download/extraction
            fails, or if the unpacked tree has no ``nodes.py``.
    """
    import shutil

    root = _comfy_root()
    if (root / "nodes.py").is_file():
        return
    if os.environ.get("ANIMA_DISABLE_COMFY_SOURCE_FETCH", "").strip() == "1":
        raise UserFacingError(
            f"ComfyUI source not found at {root}. Add ComfyUI there or unset ANIMA_DISABLE_COMFY_SOURCE_FETCH."
        )

    root.parent.mkdir(parents=True, exist_ok=True)
    print(f"[startup] Fetching ComfyUI source from {COMFYUI_ARCHIVE_URL}", flush=True)
    tmp_root = root.parent / f".{root.name}.download"
    try:
        with urllib.request.urlopen(COMFYUI_ARCHIVE_URL, timeout=600) as response:
            data = response.read()
        with zipfile.ZipFile(BytesIO(data)) as zf:
            top_levels = {m.split("/", 1)[0] for m in zf.namelist() if "/" in m}
            if not top_levels:
                raise RuntimeError("archive did not contain a top-level directory")
            archive_root = sorted(top_levels)[0]
            # Start from an empty staging directory even if a previous attempt
            # left one behind.
            if tmp_root.exists():
                shutil.rmtree(tmp_root)
            tmp_root.mkdir(parents=True, exist_ok=True)
            zf.extractall(tmp_root)
        extracted = tmp_root / archive_root
        if root.exists():
            shutil.rmtree(root)
        extracted.replace(root)
    except UserFacingError:
        raise
    except Exception as e:
        raise UserFacingError(f"Failed to fetch ComfyUI source: {e!s}") from e
    finally:
        # Always drop the staging directory. A plain rmdir() would raise when
        # the archive had extra top-level entries (turning a successful install
        # into a reported failure) and would leave debris after a failed one.
        shutil.rmtree(tmp_root, ignore_errors=True)

    if not (root / "nodes.py").is_file():
        raise UserFacingError(f"Downloaded ComfyUI source is missing nodes.py under {root}.")
def _configure_comfy_paths() -> None:
    """Make ComfyUI importable and register this app's model folders with it.

    Prepends the Comfy root to ``sys.path`` (once), imports Comfy's
    ``folder_paths`` module, then ensures the ``diffusion_models``,
    ``text_encoders`` and ``vae`` directories exist under the models root and
    are known to ``folder_paths``.
    """
    comfy_dir = str(_comfy_root())
    if comfy_dir not in sys.path:
        sys.path.insert(0, comfy_dir)
    folder_paths = importlib.import_module("folder_paths")

    base = str(_models_root())
    for kind in ("diffusion_models", "text_encoders", "vae"):
        target = os.path.join(base, kind)
        os.makedirs(target, exist_ok=True)

        # Preferred route: the registration helper, when folder_paths has one.
        if hasattr(folder_paths, "add_model_folder_path"):
            folder_paths.add_model_folder_path(kind, target)
            continue

        # Fallback: edit the registry mapping directly, but only for kinds it
        # already knows about.
        if not hasattr(folder_paths, "folder_names_and_paths"):
            continue
        registered = folder_paths.folder_names_and_paths.get(kind)
        if registered is None:
            continue
        paths, extensions = registered
        if target not in paths:
            folder_paths.folder_names_and_paths[kind] = ([target, *paths], extensions)
def _missing_torchaudio(*_args: Any, **_kwargs: Any) -> Any:
raise RuntimeError(
"torchaudio is unavailable in this Space. The Anima RDBT text-to-image workflow "
"does not use Comfy audio nodes; install a PyTorch-compatible torchaudio wheel "
"only if you add an audio workflow."
)
def _install_torchaudio_stub_if_needed() -> None:
    """Register placeholder ``torchaudio`` modules when the real one cannot be imported.

    Skipped entirely when ``ANIMA_STUB_TORCHAUDIO`` is ``"0"``. If importing
    ``torchaudio`` succeeds nothing is changed. Otherwise ``torchaudio``,
    ``torchaudio.functional`` and ``torchaudio.transforms`` are replaced in
    ``sys.modules`` with stubs whose members raise ``RuntimeError`` when used.
    """
    if os.environ.get("ANIMA_STUB_TORCHAUDIO", "1").strip() == "0":
        return

    try:
        importlib.import_module("torchaudio")
    except Exception:
        # Drop whatever a failed import may have left registered.
        for stale in ("torchaudio", "torchaudio.functional", "torchaudio.transforms"):
            sys.modules.pop(stale, None)
    else:
        return

    def _blank_module(name: str) -> types.ModuleType:
        # Empty module carrying a loader-less spec under the given name.
        module = types.ModuleType(name)
        module.__spec__ = importlib.machinery.ModuleSpec(name, loader=None)
        return module

    class _MissingAudioTransform:
        def __init__(self, *_args: Any, **_kwargs: Any) -> None:
            _missing_torchaudio()

    functional = _blank_module("torchaudio.functional")
    for fn_name in ("resample", "bass_biquad", "equalizer_biquad", "treble_biquad"):
        setattr(functional, fn_name, _missing_torchaudio)

    transforms = _blank_module("torchaudio.transforms")
    for cls_name in ("MelScale", "MelSpectrogram"):
        setattr(transforms, cls_name, _MissingAudioTransform)

    torchaudio = _blank_module("torchaudio")
    torchaudio.functional = functional  # type: ignore[attr-defined]
    torchaudio.transforms = transforms  # type: ignore[attr-defined]

    for stub in (torchaudio, functional, transforms):
        sys.modules[stub.__name__] = stub
def _init_comfy_runtime() -> _ComfyRuntime:
    """Import ComfyUI's ``nodes`` module and wrap its node registry.

    Ensures the source is present, paths are configured and the torchaudio stub
    is installed first. ``nodes.init_extra_nodes()`` is only called when
    ``ANIMA_INIT_COMFY_EXTRA_NODES`` is ``"1"`` and the function exists.

    Raises:
        UserFacingError: if importing/initializing ``nodes`` fails or any node
            class this backend needs is absent from the registry.
    """
    _download_comfy_source_if_needed()
    _configure_comfy_paths()
    _install_torchaudio_stub_if_needed()

    try:
        nodes = importlib.import_module("nodes")
        extras_requested = os.environ.get("ANIMA_INIT_COMFY_EXTRA_NODES", "").strip() == "1"
        if extras_requested and hasattr(nodes, "init_extra_nodes"):
            nodes.init_extra_nodes()
        mappings = nodes.NODE_CLASS_MAPPINGS
    except Exception as exc:
        raise UserFacingError(f"Failed to initialize ComfyUI nodes: {exc!s}") from exc

    missing = []
    for name in (
        "UNETLoader",
        "CLIPLoader",
        "VAELoader",
        "CLIPTextEncode",
        "EmptyLatentImage",
        "KSampler",
        "VAEDecode",
    ):
        if name not in mappings:
            missing.append(name)
    if missing:
        raise UserFacingError(f"ComfyUI is missing required node classes: {', '.join(missing)}.")
    return _ComfyRuntime(node_mappings=mappings)
def _bootstrap_files_if_needed() -> None:
    """Run the model-artifact bootstrap once per process, under the module lock.

    Raises:
        UserFacingError: re-raised unchanged from the bootstrap, or wrapping any
            other exception it throws. The "done" flag is only set on success,
            so a failed bootstrap is retried on the next call.
    """
    global _bootstrapped
    with _lock:
        if not _bootstrapped:
            # Imported lazily, only when a bootstrap is actually required.
            from src import bootstrap

            try:
                bootstrap.bootstrap_model_artifacts()
            except UserFacingError:
                raise
            except Exception as exc:
                raise UserFacingError(
                    f"Model bootstrap failed: {exc!s}. See logs for full traceback."
                ) from exc
            _bootstrapped = True
def run_at_container_startup() -> None:
    """Prepare on-disk prerequisites: the ComfyUI source tree and the model files.

    Logs a start and a completion line to stdout. Comfy node classes are not
    imported here.
    """
    print("[startup] Preparing ComfyUI source and Anima RDBT model files...", flush=True)
    for step in (_download_comfy_source_if_needed, _bootstrap_files_if_needed):
        step()
    ready_message = (
        "[startup] ComfyUI source and model files are ready. Native Comfy nodes load on first Generate."
    )
    print(ready_message, flush=True)
def ensure_prepared() -> None:
    """Make sure model files are bootstrapped and the Comfy runtime is created.

    Returns immediately when a runtime is already stored. Otherwise the runtime
    is built without holding the lock and published only if no other caller
    stored one in the meantime.
    """
    global _runtime
    with _lock:
        already_ready = _runtime is not None
    if already_ready:
        return

    _bootstrap_files_if_needed()
    candidate = _init_comfy_runtime()

    with _lock:
        # First writer wins; a later candidate is simply discarded.
        if _runtime is None:
            _runtime = candidate
def _node(runtime: _ComfyRuntime, class_name: str) -> Any:
try:
return runtime.node_mappings[class_name]()
except Exception as e:
raise UserFacingError(f"Failed to create Comfy node {class_name}: {e!s}") from e
def _call_node(node: Any, method_name: str, **kwargs: Any) -> Any:
method = getattr(node, method_name, None)
if method is None:
raise UserFacingError(f"Comfy node {node.__class__.__name__} has no method {method_name!r}.")
try:
return method(**kwargs)
except Exception as e:
raise UserFacingError(f"Comfy node {node.__class__.__name__}.{method_name} failed: {e!s}") from e
def _value(output: Any, index: int = 0) -> Any:
if isinstance(output, dict) and "result" in output:
output = output["result"]
if isinstance(output, (tuple, list)):
return output[index]
return output
def _load_graph(runtime: _ComfyRuntime, progress: Any = None) -> _LoadedGraph:
    """Return the cached loader outputs, loading UNET, CLIP and VAE on first use.

    The loaders run without holding the lock; the result is stored on
    ``runtime`` only if nothing was stored meanwhile, and whatever is stored is
    what gets returned.
    """
    with _lock:
        cached = runtime.loaded
    if cached is not None:
        return cached

    _report_progress(progress, 0.05, "Loading Comfy UNET / CLIP / VAE (first run can take several minutes)...")

    unet_node = _node(runtime, "UNETLoader")
    model = _value(
        _call_node(
            unet_node,
            "load_unet",
            unet_name=config.RDBT_UNET_NAME,
            weight_dtype=config.UNET_WEIGHT_DTYPE,
        )
    )

    clip_node = _node(runtime, "CLIPLoader")
    clip = _value(
        _call_node(
            clip_node,
            "load_clip",
            clip_name=config.CLIP_NAME,
            type=config.CLIP_TYPE,
        )
    )

    vae_node = _node(runtime, "VAELoader")
    vae = _value(_call_node(vae_node, "load_vae", vae_name=config.VAE_NAME))

    fresh = _LoadedGraph(model=model, clip=clip, vae=vae)
    with _lock:
        if runtime.loaded is None:
            runtime.loaded = fresh
        return runtime.loaded
def _tensor_to_images(images: Any) -> list[Image.Image]:
    """Convert decoder output into a list of RGB PIL images.

    PIL images (single, or a non-empty list/tuple whose first item is one) are
    converted to RGB directly. Anything else is turned into a NumPy array
    (via ``detach().cpu().numpy()`` when available), treated as a batch of
    frames, clipped to ``[0, 1]`` and scaled to 8-bit.

    Raises:
        UserFacingError: if the array is not 3- or 4-dimensional.
    """
    if isinstance(images, Image.Image):
        return [images.convert("RGB")]

    is_sequence = isinstance(images, (list, tuple))
    if is_sequence and images and isinstance(images[0], Image.Image):
        return [item.convert("RGB") for item in images]

    raw = images.detach().cpu().numpy() if hasattr(images, "detach") else images
    batch = np.asarray(raw)
    if batch.ndim == 3:
        # A single frame: add the leading batch axis.
        batch = batch[None, ...]
    if batch.ndim != 4:
        raise UserFacingError(f"Comfy VAEDecode returned unsupported image shape: {batch.shape!r}.")

    def _frame_to_rgb(frame: Any) -> Image.Image:
        scaled = np.clip(frame, 0.0, 1.0) * 255.0
        return Image.fromarray(scaled.round().astype("uint8")).convert("RGB")

    return [_frame_to_rgb(frame) for frame in batch]
def run_generation(
    p: GenerationParams,
    *,
    progress: Optional[Any] = None,
) -> tuple[list[Image.Image], str, str, str]:
    """
    Execute the embedded ComfyUI graph and return images plus status details.

    Steps: prepare the runtime, load (or reuse) the UNET/CLIP/VAE, encode the
    positive and negative prompts, create an empty latent, run ``KSampler`` and
    decode the result with ``VAEDecode``.

    Args:
        p: Generation settings (prompts, size, batch size, seed, steps, cfg,
            sampler, scheduler, denoise).
        progress: Optional callback invoked as ``progress(value, desc=...)``;
            failures inside it are ignored.

    Returns:
        ``(images, details, prompt, negative_prompt)`` where ``details`` is a
        one-line summary of the settings used.

    Raises:
        UserFacingError: if preparation fails, the runtime is unavailable, any
            node call fails, or no images are produced.
    """
    _report_progress(progress, 0.0, "Preparing ComfyUI...")
    ensure_prepared()
    runtime = _runtime
    # Explicit check instead of ``assert``: asserts are stripped under ``-O``,
    # which would defer the failure to an opaque AttributeError further down.
    if runtime is None:
        raise UserFacingError("ComfyUI runtime is not initialized.")
    loaded = _load_graph(runtime, progress=progress)

    _report_progress(progress, 0.20, "Encoding prompts...")
    positive = _value(
        _call_node(_node(runtime, "CLIPTextEncode"), "encode", text=p.prompt, clip=loaded.clip)
    )
    negative = _value(
        _call_node(_node(runtime, "CLIPTextEncode"), "encode", text=p.negative_prompt, clip=loaded.clip)
    )
    latent = _value(
        _call_node(
            _node(runtime, "EmptyLatentImage"),
            "generate",
            width=int(p.width),
            height=int(p.height),
            batch_size=int(p.batch_size),
        )
    )

    _report_progress(progress, 0.35, f"Running Comfy KSampler ({p.steps} steps)...")
    samples = _value(
        _call_node(
            _node(runtime, "KSampler"),
            "sample",
            model=loaded.model,
            seed=int(p.seed),
            steps=int(p.steps),
            cfg=float(p.cfg),
            sampler_name=p.sampler_name,
            scheduler=p.scheduler,
            positive=positive,
            negative=negative,
            latent_image=latent,
            denoise=float(p.denoise),
        )
    )

    _report_progress(progress, 0.95, "Decoding images...")
    decoded = _value(_call_node(_node(runtime, "VAEDecode"), "decode", samples=samples, vae=loaded.vae))
    images = _tensor_to_images(decoded)
    if not images:
        raise UserFacingError("ComfyUI returned no images.")

    _report_progress(progress, 1.0, "Done.")
    details = (
        f"seed={p.seed} | {p.width}x{p.height} | steps={p.steps} | cfg={p.cfg} | "
        f"batch={p.batch_size} | {p.sampler_name}/{p.scheduler} | denoise={p.denoise}"
    )
    return images, details, p.prompt, p.negative_prompt