#!/usr/bin/env python3 """ Gradio front-end wrapper for SeedVR2's official inference_cli.py This script is the user's app.py enhanced to stream subprocess logs in real-time into the Gradio logs textbox. It runs the CLI as subprocesses and streams stdout/stderr lines as they arrive using a queue and reader threads. The Gradio handler `ui_upscale` is implemented as a generator so the frontend receives incremental updates. This script provides a simple web UI for single-image upscaling using the official ComfyUI-SeedVR2_VideoUpscaler `inference_cli.py` script. It calls the official CLI as a subprocess, and will automatically download model weights from Hugging Face (numz/SeedVR2_comfyUI) if they are missing. If the ComfyUI-SeedVR2_VideoUpscaler repository is not present, the script will attempt to `git clone` it automatically into ./ComfyUI-SeedVR2_VideoUpscaler. Run: python app.py Requirements - Python 3.10+ - Gradio (pip install gradio) - Git available in PATH (for automatic cloning) or clone the repo manually - PyTorch + CUDA (if using GPU) Notes - This wrapper calls the repo's `inference_cli.py` as a subprocess so the CLI's memory/optimization features (BlockSwap, VAE tiling, etc.) remain available. - Models will be downloaded to the cloned repo's ./models/SeedVR2 directory if missing. Use HUGGINGFACE_HUB_TOKEN env var if required for private access. """ import os import sys import cv2 import time import torch import queue import shutil import zipfile import threading import subprocess import numpy as np import gradio as gr from pathlib import Path from typing import Optional, Tuple, Generator, List # huggingface helper (used for model auto-download) from huggingface_hub import hf_hub_download def imreadUTF8(path, flags=cv2.IMREAD_COLOR): """ OpenCV's cv2.imread cannot handle non-ASCII paths. This function reads an image from a path that may contain UTF-8 characters. """ try: # Use NumPy to read from the file, which correctly handles UTF-8 paths with open(path, "rb") as stream: bytes_data = bytearray(stream.read()) numpyarray = np.asarray(bytes_data, dtype=np.uint8) # Use cv2.imdecode to decode the image from the memory buffer img = cv2.imdecode(numpyarray, flags) return img except Exception as e: # If reading fails, print the error message and return None print(f"ERROR: Failed to read image with UTF-8 path: {path}") print(f" Details: {e}") return None def imwriteUTF8(save_path, image): """ OpenCV's cv2.imwrite cannot handle non-ASCII paths. This function writes an image to a path that may contain UTF-8 characters. """ try: img_name = os.path.basename(save_path) _, extension = os.path.splitext(img_name) # Encode the image into the specified format (determined by the file extension) is_success, im_buf_arr = cv2.imencode(extension, image) if is_success: # Write the image data from memory to the file im_buf_arr.tofile(save_path) return True else: print(f"ERROR: Failed to encode image for path: {save_path}") return False except Exception as e: print(f"ERROR: Failed to write image with UTF-8 path: {save_path}") print(f" Details: {e}") return False # Apply Monkey Patch to cv2 (for app.py usage) print("[SeedVR2 Gradio] Applying UTF-8 patch to OpenCV (Frontend)...") cv2.imread = imreadUTF8 cv2.imwrite = imwriteUTF8 # ---------------- # Config / paths # ---------------- REPO_URL = "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler.git" CLONE_DIR = Path(__file__).resolve().parent / "ComfyUI-SeedVR2_VideoUpscaler" INFERENCE_CLI = CLONE_DIR / "inference_cli.py" PY_EXE = sys.executable # Use same Python executable to run CLI # Path to the custom improved blockswap file IMPROVED_BLOCKSWAP_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "blockswap.py" IMPROVED_MEMORY_MANAGER_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "memory_manager.py" # Default HF repo for VAE (VAE is usually static and comes from the official repo) DEFAULT_VAE_REPO_ID = "numz/SeedVR2_comfyUI" # Models are now stored in a fixed top-level directory, independent of the clone dir DEFAULT_MODEL_DIR = Path(__file__).resolve().parent / "models" / "SeedVR2" # ---------------- # Model Definitions (RepoID / Filename) # ---------------- # Standard Models (Safetensors) MODEL_CHOICES = [ "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp8_e4m3fn.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp16.safetensors", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_7b_fp16.safetensors", # sharp variants "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp16.safetensors", ] # GGUF / alternate model support GGUF_CHOICES = [ "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q4_K_M.gguf", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q8_0.gguf", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b-Q4_K_M.gguf", # sharp variants "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp-Q4_K_M.gguf", # custom GGUF from cmeka "cmeka/SeedVR2-GGUF/seedvr2_ema_7b-Q8_0.gguf", "cmeka/SeedVR2-GGUF/seedvr2_ema_7b_sharp-Q8_0.gguf", ] # # Model registry with metadata # MODEL_REGISTRY = { # # 3B models # "seedvr2_ema_3b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q4_K_M", sha256="e665e3909de1a8c88a69c609bca9d43ff5a134647face2ce4497640cc3597f0e"), # "seedvr2_ema_3b-Q8_0.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q8_0", sha256="be0d60083a2051a265eb4b77f28edf494e6db67ffc250216f32b72292e5cbd96"), # "seedvr2_ema_3b_fp8_e4m3fn.safetensors": ModelInfo(size="3B", precision="fp8_e4m3fn", sha256="3bf1e43ebedd570e7e7a0b1b60d6a02e105978f505c8128a241cde99a8240cff"), # "seedvr2_ema_3b_fp16.safetensors": ModelInfo(size="3B", precision="fp16", sha256="2fd0e03a3dad24e07086750360727ca437de4ecd456f769856e960ae93e2b304"), # # 7B models # "seedvr2_ema_7b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", sha256="db9cb2ad90ebd40d2e8c29da2b3fc6fd03ba87cd58cbadceccca13ad27162789"), # "seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", sha256="3d68b5ec0b295ae28092e355c8cad870edd00b817b26587d0cb8f9dd2df19bb2"), # "seedvr2_ema_7b_fp16.safetensors": ModelInfo(size="7B", precision="fp16", sha256="7b8241aa957606ab6cfb66edabc96d43234f9819c5392b44d2492d9f0b0bbe4a"), # # 7B sharp variants # "seedvr2_ema_7b_sharp-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", variant="sharp", sha256="7aed800ac4eb8e0d18569a954c0ff35f5a1caa3ed5d920e66cc31405f75b6e69"), # "seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", variant="sharp", sha256="0d2c5b8be0fda94351149c5115da26aef4f4932a7a2a928c6f184dda9186e0be"), # "seedvr2_ema_7b_sharp_fp16.safetensors": ModelInfo(size="7B", precision="fp16", variant="sharp", sha256="20a93e01ff24beaeebc5de4e4e5be924359606c356c9c51509fba245bd2d77dd"), # # VAE models # "ema_vae_fp16.safetensors": ModelInfo(category="vae", precision="fp16", sha256="20678548f420d98d26f11442d3528f8b8c94e57ee046ef93dbb7633da8612ca1"), # } # Detect Hardware Availability CUDA_AVAILABLE = torch.cuda.is_available() # Detect MPS availability (for Apple Silicon) MPS_AVAILABLE = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built() # Check for any hardware acceleration ACCELERATOR_AVAILABLE = CUDA_AVAILABLE or MPS_AVAILABLE # ----------------- # Repo / model helpers # ----------------- def ensure_repo_cloned( repo_url: str = REPO_URL, clone_dir_name: str = "ComfyUI-SeedVR2_VideoUpscaler", repo_branch: str = "", force_update: bool = False ) -> Path: """ Ensure the repository is cloned locally into the specified directory name. Supports specific branches/tags via checkout. Returns the resolved Path object to the cloned directory. """ # Resolve the physical path based on the script's parent location target_clone_dir = Path(__file__).resolve().parent / clone_dir_name target_cli = target_clone_dir / "inference_cli.py" # Helper function to handle detached/orphaned commits def _smart_checkout(cwd, ref): print(f"[SeedVR2 Gradio] Checking out '{ref}' in {cwd} ...") try: # Try standard checkout first (fastest if ref exists locally) subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True) except subprocess.CalledProcessError: # Fallback: If ref is not found (e.g. orphaned commit hash), fetch it explicitly print(f"[SeedVR2 Gradio] Standard checkout failed. Attempting to fetch specific ref '{ref}' from origin...") try: subprocess.run(["git", "-C", str(cwd), "fetch", "origin", ref], check=True) subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True) except Exception as e: raise RuntimeError(f"Failed to fetch/checkout specific ref '{ref}': {e}") if target_clone_dir.exists() and (target_clone_dir / ".git").exists(): # Repo exists if force_update: try: print(f"[SeedVR2 Gradio] Updating {target_clone_dir} ...") subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True) # If a specific branch/hash is requested if repo_branch: _smart_checkout(target_clone_dir, repo_branch) # If it's a branch name (not a detached hash), we might want to pull latest # But checking if it's a branch vs hash is complex, generally strictly checking out the ref is safer for reproducibility else: subprocess.run(["git", "-C", str(target_clone_dir), "pull"], check=True) except Exception as e: raise RuntimeError(f"Failed to update repository {target_clone_dir}: {e}") # If not forcing update, but a branch is specified, ensure we are on it elif repo_branch: try: subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True) _smart_checkout(target_clone_dir, repo_branch) except Exception as e: raise RuntimeError(f"Failed to switch to branch {repo_branch}: {e}") # Ensure inference_cli present if not target_cli.exists(): raise RuntimeError(f"Repository found at {target_clone_dir} but inference_cli.py is missing.") return target_clone_dir # Clone repo if not exists try: print(f"[SeedVR2 Gradio] Cloning {repo_url} into {target_clone_dir} ...") # Standard clone (fetches default branch) subprocess.run(["git", "clone", repo_url, str(target_clone_dir)], check=True) if repo_branch: _smart_checkout(target_clone_dir, repo_branch) except FileNotFoundError: raise RuntimeError("git not found: please install Git or clone the repository manually.") except Exception as e: raise RuntimeError(f"Failed to clone repository: {e}") if not target_cli.exists(): raise RuntimeError(f"Clone completed but inference_cli.py not found in {target_clone_dir}.") return target_clone_dir def apply_inference_cli_patch(cli_path: Path): """ Injects UTF-8 compatible imread/imwrite wrappers directly into inference_cli.py. This modifies the physical file so the subprocess (even on Windows spawn) uses the patch. """ if not cli_path.exists(): return try: with open(cli_path, "r", encoding="utf-8") as f: content = f.read() # Check if already patched to avoid duplicates if "def imreadUTF8" in content: return # The patch content to inject. # Note: We ensure 'import numpy as np' and 'import os' are available or re-imported. # inference_cli.py typically has 'import cv2', we inject right after that. patch_code = r''' # ============================================================================= # GRADIO APP PATCH: UTF-8 Support for Windows (Auto-Injected) # ============================================================================= import numpy as np import os def imreadUTF8(path, flags=cv2.IMREAD_COLOR): try: with open(path, "rb") as stream: bytes_data = bytearray(stream.read()) numpyarray = np.asarray(bytes_data, dtype=np.uint8) return cv2.imdecode(numpyarray, flags) except Exception as e: print(f"Error reading image {path}: {e}") return None def imwriteUTF8(save_path, image): try: img_name = os.path.basename(save_path) _, extension = os.path.splitext(img_name) is_success, im_buf_arr = cv2.imencode(extension, image) if is_success: im_buf_arr.tofile(save_path) return True else: return False except Exception as e: print(f"Error writing image {save_path}: {e}") return False # Override cv2 methods cv2.imread = imreadUTF8 cv2.imwrite = imwriteUTF8 # ============================================================================= ''' # Inject after 'import cv2' if "import cv2" in content: print(f"[SeedVR2 Gradio] Patching {cli_path} for UTF-8 subprocess support...") new_content = content.replace("import cv2", "import cv2" + patch_code, 1) with open(cli_path, "w", encoding="utf-8") as f: f.write(new_content) else: print("[SeedVR2 Gradio] WARNING: Could not find 'import cv2' in inference_cli.py. UTF-8 patch skipped.") except Exception as e: print(f"[SeedVR2 Gradio] ERROR applying UTF-8 patch to inference_cli: {e}") def patch_model_registry(repo_root: Path): """ Appends custom model definitions to src/utils/model_registry.py. This allows the CLI to recognize new GGUF models that aren't in the official registry. """ registry_path = repo_root / "src" / "utils" / "model_registry.py" if not registry_path.exists(): print(f"[SeedVR2 Gradio] WARN: Could not find model_registry.py at {registry_path}") return try: with open(registry_path, "r", encoding="utf-8") as f: content = f.read() # Check if already patched if "seedvr2_ema_7b-Q8_0.gguf" in content: return print(f"[SeedVR2 Gradio] Patching {registry_path} with custom GGUF models...") # Code to append to the end of the file. # Since ModelInfo and MODEL_REGISTRY are defined in the file, we can use them directly. patch_code = r''' # ============================================================================= # GRADIO APP PATCH: Custom Model Registry Entries # ============================================================================= try: # Update registry with custom GGUF models requested by user MODEL_REGISTRY.update({ "seedvr2_ema_7b-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", sha256="669788655e8f15f306284f267a444e9766c8a421869577b16a961e43029c737b"), "seedvr2_ema_7b_sharp-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", variant="sharp", sha256="b1f81cb5700b0b1f432f2c785528356c952c41c74d03d205c6f14b0bd6da303d"), }) print("[Internal] Custom GGUF models injected into MODEL_REGISTRY successfully.") except Exception as e: print(f"[Internal] Failed to inject custom models: {e}") # ============================================================================= ''' with open(registry_path, "a", encoding="utf-8") as f: f.write(patch_code) except Exception as e: print(f"[SeedVR2 Gradio] ERROR patching model_registry.py: {e}") # ----------------- # BlockSwap Management # ----------------- def manage_blockswap_file(use_improved: bool, repo_root: Path) -> str: """ Manages the blockswap.py file in the specified cloned repository. Accepted `repo_root` path to ensure we modify the correct repo. """ target_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py" backup_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py.bak" target_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py" backup_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py.bak" msg = "" # Ensure src/optimization exists (some forks might differ in structure) if not target_path_blockswap.parent.exists(): return f"[WARN] Optimization folder not found at {target_path_blockswap.parent}. Skipping blockswap patch.\n" if use_improved: if not IMPROVED_BLOCKSWAP_SOURCE.exists(): return f"[WARN] Improved blockswap source not found at {IMPROVED_BLOCKSWAP_SOURCE}. Keeping current version.\n" # 1. Check if we need to backup the original blockswap (only if backup doesn't exist yet) if target_path_blockswap.exists() and not backup_path_blockswap.exists(): try: shutil.move(str(target_path_blockswap), str(backup_path_blockswap)) msg += f"[INFO] Backed up original blockswap to {backup_path_blockswap.name}.\n" except Exception as e: return f"[ERROR] Failed to backup blockswap: {e}\n" # 2. Copy the improved file to target blockswap try: shutil.copy(str(IMPROVED_BLOCKSWAP_SOURCE), str(target_path_blockswap)) msg += "[INFO] Switched to Improved BlockSwap (Nunchaku implementation).\n" except Exception as e: return f"[ERROR] Failed to install improved blockswap: {e}\n" # Memory Manager Handling if not IMPROVED_MEMORY_MANAGER_SOURCE.exists(): return f"[WARN] Improved memory_manager source not found at {IMPROVED_MEMORY_MANAGER_SOURCE}. Keeping current version.\n" # 3. Check if we need to backup the original memory_manager (only if backup doesn't exist yet) if target_path_memory_manager.exists() and not backup_path_memory_manager.exists(): try: shutil.move(str(target_path_memory_manager), str(backup_path_memory_manager)) msg += f"[INFO] Backed up original memory_manager to {backup_path_memory_manager.name}.\n" except Exception as e: return f"[ERROR] Failed to backup memory_manager: {e}\n" # 4. Copy the improved file to target memory_manager try: shutil.copy(str(IMPROVED_MEMORY_MANAGER_SOURCE), str(target_path_memory_manager)) msg += "[INFO] Switched to Improved memory_manager (Nunchaku implementation).\n" except Exception as e: return f"[ERROR] Failed to install improved memory_manager: {e}\n" return msg else: # Restore original blockswap if available if backup_path_blockswap.exists(): try: # Remove current target blockswap if it exists (which might be the improved one) if target_path_blockswap.exists(): os.remove(target_path_blockswap) # Restore backup shutil.move(str(backup_path_blockswap), str(target_path_blockswap)) msg += "[INFO] Restored Original BlockSwap from backup.\n" except Exception as e: return f"[ERROR] Failed to restore original blockswap: {e}\n" else: # Backup doesn't exist, assume we are already on original or clean install msg += "[INFO] Using Original BlockSwap (No backup found/needed).\n" # Restore original memory_manager if available if backup_path_memory_manager.exists(): try: # Remove current target memory_manager if it exists (which might be the improved one) if target_path_memory_manager.exists(): os.remove(target_path_memory_manager) # Restore backup shutil.move(str(backup_path_memory_manager), str(target_path_memory_manager)) msg += "[INFO] Restored Original memory_manager from backup.\n" except Exception as e: return f"[ERROR] Failed to restore original memory_manager: {e}\n" else: # Backup doesn't exist, assume we are already on original or clean install msg += "[INFO] Using Original memory_manager (No backup found/needed).\n" return msg # ----------------- # Model download # ----------------- def ensure_models_available( selected_model_filename: str, model_dir: Optional[Path] = None, repo_id: str = DEFAULT_VAE_REPO_ID ) -> None: """ Ensure the selected DiT model and the VAE file exist locally. If missing, download from the specified Hugging Face repo directly into model_dir using 'local_dir' to avoid nested cache structures. """ if model_dir is None: model_dir = DEFAULT_MODEL_DIR else: model_dir = Path(model_dir) model_dir.mkdir(parents=True, exist_ok=True) # Items to check: VAE + selected DiT model required = ["ema_vae_fp16.safetensors", selected_model_filename] # Check if files physically exist at the target location missing = [_f for _f in required if not (model_dir / _f).exists()] if not missing: return # Optional: silence HF symlink warning if desired os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") # Attempt download for each missing file hf_token = os.environ.get("HF_ACCESS_TOKEN") for fname in missing: target_path = model_dir / fname # If file already somehow exists at target, skip if target_path.exists(): continue # Decide repo for this filename: # - VAE must always come from the official DEFAULT_VAE_REPO_ID (numz/SeedVR2_comfyUI) # - Dit model uses the provided repo_id (which comes from the dropdown selection) repo_for_fname = DEFAULT_VAE_REPO_ID if fname == "ema_vae_fp16.safetensors" else repo_id try: print(f"[SeedVR2 Gradio] Downloading {fname} from {repo_for_fname} directly to {model_dir} ...") # Use local_dir instead of cache_dir. # This forces the file to be saved exactly at {model_dir}/{fname} # local_dir_use_symlinks=False ensures we get a real file, not a symlink, # which prevents issues where the CLI subprocess cannot resolve the path. downloaded_path = hf_hub_download( repo_id=repo_for_fname, filename=fname, local_dir=str(model_dir), # Download directly to target folder repo_type="model", token=hf_token, ) print(f"[SeedVR2 Gradio] Download completed: {downloaded_path}") except Exception as e: raise RuntimeError(f"Failed to download {fname} from Hugging Face repo {repo_for_fname}: {e}") # ---------------- # Subprocess streaming helpers # ---------------- def _start_process_stream(cmd_args, cwd: str, env: dict) -> Tuple[Optional[subprocess.Popen], queue.Queue, Optional[threading.Thread], Optional[threading.Thread]]: """Start subprocess and return (proc, q, t_out, t_err). The returned queue will receive text lines as they arrive. Lines are simple strings (already newline-terminated). stderr lines are prefixed with "stderr: ". """ q = queue.Queue() try: proc = subprocess.Popen( cmd_args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', errors='replace', bufsize=1, env=env ) except Exception as e: # Put error to queue and return a dummy proc q.put(f"[FAILED TO LAUNCH] {e}\n") return None, q, None, None def _reader(fh, prefix: str): try: while True: line = fh.readline() if not line: break if not line.endswith("\n"): line = line + "\n" q.put(prefix + line) except Exception as e: q.put(f"[reader error] {e}\n") t_out = threading.Thread(target=_reader, args=(proc.stdout, ""), daemon=True) t_err = threading.Thread(target=_reader, args=(proc.stderr, "stderr: "), daemon=True) t_out.start() t_err.start() return proc, q, t_out, t_err # ----------------- # CLI runner (streaming) # ----------------- def expected_upscaled_path(input_path: str, output_format: str = "png") -> str: """Calculates the expected output path based on the input path and requested format.""" p = Path(input_path) stem = p.stem parent = p.parent suffix = "_upscaled" # if output_format == "mp4": # return str((parent / f"{stem}{suffix}.mp4").resolve()) # else: # return str((parent / f"{stem}{suffix}.png").resolve()) return str((parent / f"{stem}{suffix}.{output_format}").resolve()) # Single-image/video CLI runner (generator) def run_cli_upscale_stream( input_path: str, resolution: int = 1080, max_resolution: int = 0, dit_model_filename: Optional[str] = None, # Receives just the filename cuda_device: Optional[str] = None, # Compilation & Performance compile_dit: bool = False, compile_vae: bool = False, compile_backend: str = "inductor", compile_mode: str = "default", compile_fullgraph: bool = False, compile_dynamic: bool = False, compile_dynamo_cache_size_limit: int = 64, compile_dynamo_recompile_limit: int = 128, attention_mode: str = "sdpa", # Tiling (Split Encode/Decode) vae_encode_tiled: bool = False, vae_encode_tile_size: int = 1024, vae_encode_tile_overlap: int = 128, vae_decode_tiled: bool = False, vae_decode_tile_size: int = 1024, vae_decode_tile_overlap: int = 128, tile_debug: str = "false", # Processing batch_size: int = 1, uniform_batch_size: bool = False, seed: int = 42, skip_first_frames: int = 0, load_cap: int = 0, # Quality & Color color_correction: str = "lab", input_noise_scale: float = 0.0, latent_noise_scale: float = 0.0, # Memory & Offload blocks_to_swap: int = 0, swap_io_components: bool = False, dit_offload_device: str = "none", vae_offload_device: str = "none", tensor_offload_device: str = "cpu", cache_dit: bool = False, cache_vae: bool = False, extra_args: str = "", model_dir: Optional[str] = None, repo_id: str = DEFAULT_VAE_REPO_ID, # Receives the specific Repo ID for DiT repo_path: Optional[Path] = None, timeout: int = 3600, pre_downscale: bool = False, # for artifact removal downscale_rate: float = 0.5, output_format: str = "png", # Can now be "mp4" use_improved_blockswap: bool = False, # New argument for switching blockswap version # Video Args chunk_size: int = 0, temporal_overlap: int = 0, prepend_frames: int = 0, video_backend: str = "opencv", use_10bit: bool = False, # Debug Arg debug: bool = False ) -> Generator[Tuple[Optional[str], str], None, None]: """ Generator yields (out_path_or_None, logs_so_far) while streaming CLI logs. Includes Phase-Aware Dynamic Fallback logic. """ # Defaults if repo_path is None: repo_path = CLONE_DIR current_inference_cli = repo_path / "inference_cli.py" # 1. Repo Check if not current_inference_cli.exists(): yield None, f"[ERROR] inference_cli.py not found in {repo_path}\n" return # Patch inference_cli.py with UTF-8 support try: apply_inference_cli_patch(current_inference_cli) except Exception as e: yield None, f"[WARN] Failed to patch inference_cli: {e}\n" # Patch model_registry.py with custom models try: patch_model_registry(repo_path) except Exception as e: yield None, f"[WARN] Failed to patch model_registry: {e}\n" # Handle BlockSwap File Replacement Logic try: swap_log = manage_blockswap_file(use_improved_blockswap, repo_root=repo_path) # Yield the log about blockswap immediately yield None, swap_log except Exception as e: yield None, f"[ERROR] BlockSwap management failed: {e}\n" # Use the global default if not provided if model_dir is None: model_dir = str(DEFAULT_MODEL_DIR) # Ensure model files present if dit_model_filename: try: ensure_models_available( dit_model_filename, model_dir=Path(model_dir), repo_id=repo_id, ) except Exception as e: yield None, f"[ERROR] Model download failed: {e}\n" return safe_input_path = input_path temp_copy = None # Pre-downscale logic (Artifact Removal Trick) - Only applies to Images in this implementation # We skip this for MP4 files to avoid complex video processing in python before CLI is_video = input_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')) # Pre-downscale (Images only) if pre_downscale and not is_video: try: filename = os.path.basename(input_path) # Load original image using OpenCV img_obj = cv2.imread(input_path, cv2.IMREAD_UNCHANGED) if img_obj is None: raise ValueError(f"Failed to load image: {input_path}") # Calculate new dimensions (OpenCV shape is [height, width]) h, w = img_obj.shape[:2] if (max(w, h) > 250): new_w = int(w * downscale_rate) new_h = int(h * downscale_rate) # Resize # Use INTER_AREA for downscaling (better quality/less aliasing for shrinking) # Use INTER_LANCZOS4 if scaling up (though this block is specifically for downscaling) interpolation_method = cv2.INTER_AREA if downscale_rate < 1.0 else cv2.INTER_LANCZOS4 img_resized = cv2.resize(img_obj, (new_w, new_h), interpolation=interpolation_method) # Prepare temp directory tmp_dir = CLONE_DIR / "tmp_inputs" tmp_dir.mkdir(parents=True, exist_ok=True) # Save to a unique temp file (forces .png for intermediate input) new_name = f"{filename}_downscaled.png" temp_copy = str(tmp_dir / new_name) # Use patched cv2.imwrite cv2.imwrite(temp_copy, img_resized) # Use this temp file as the input for CLI safe_input_path = temp_copy yield None, f"[INFO] Pre-downscaled input by factor {downscale_rate} (Size: {w}x{h} -> {new_w}x{new_h}) to reduce artifacts.\n" except Exception as e: yield None, f"[ERROR] Failed to pre-downscale image: {e}\n" return # 2. Command Builder def _build_cmd(curr_tile_size, curr_batch_size): # Determine strict output format if is_video: # If input is video, force mp4 output for CLI unless user explicitly wants png sequence? # Usually users want mp4 back. cmd_format = "mp4" else: # For images, use png (CLI handles webp/jpg conversion internally if modified, # but standard CLI outputs png/mp4). We force png here, app.py handles conversion later. cmd_format = "png" cmd = [PY_EXE, str(current_inference_cli), safe_input_path, "--resolution", str(resolution), "--output_format", cmd_format, "--batch_size", str(curr_batch_size), "--color_correction", color_correction, "--model_dir", str(model_dir), "--seed", str(seed), "--attention_mode", str(attention_mode)] if max_resolution and int(max_resolution) > 0: cmd += ["--max_resolution", str(int(max_resolution))] if dit_model_filename: # CLI just needs the filename relative to --model_dir (or absolute path) cmd += ["--dit_model", str(dit_model_filename)] # Only add --cuda_device if CUDA available and user provided a value if CUDA_AVAILABLE and cuda_device: cmd += ["--cuda_device", str(cuda_device)] # --- Compilation Options --- if compile_dit: cmd += ["--compile_dit"] if compile_vae: cmd += ["--compile_vae"] if compile_dit or compile_vae: cmd += [ "--compile_backend", str(compile_backend), "--compile_mode", str(compile_mode), "--compile_dynamo_cache_size_limit", str(compile_dynamo_cache_size_limit), "--compile_dynamo_recompile_limit", str(compile_dynamo_recompile_limit) ] if compile_fullgraph: cmd += ["--compile_fullgraph"] if compile_dynamic: cmd += ["--compile_dynamic"] # --- Tiling Options --- # Note: curr_tile_size comes from the loop strategy (Phase Fallback), # normally we use the user provided vae_encode_tile_size. if vae_encode_tiled: cmd += ["--vae_encode_tiled", "--vae_encode_tile_size", str(curr_tile_size), "--vae_encode_tile_overlap", str(vae_encode_tile_overlap)] if vae_decode_tiled: cmd += ["--vae_decode_tiled", "--vae_decode_tile_size", str(vae_decode_tile_size), "--vae_decode_tile_overlap", str(vae_decode_tile_overlap)] if tile_debug != "false": cmd += ["--tile_debug", str(tile_debug)] # --- Processing & Quality --- if uniform_batch_size: cmd += ["--uniform_batch_size"] if skip_first_frames > 0: cmd += ["--skip_first_frames", str(int(skip_first_frames))] if load_cap > 0: cmd += ["--load_cap", str(int(load_cap))] if input_noise_scale > 0: cmd += ["--input_noise_scale", str(input_noise_scale)] if latent_noise_scale > 0: cmd += ["--latent_noise_scale", str(latent_noise_scale)] # --- BlockSwap / Offload / Caching --- if blocks_to_swap and int(blocks_to_swap) > 0: cmd += ["--blocks_to_swap", str(int(blocks_to_swap))] if swap_io_components: cmd += ["--swap_io_components"] # Offload flags: note these are strings like "none"/"cpu"/"cuda:0" if dit_offload_device and dit_offload_device != "none": # Ensure we don't pass a cuda device offload when cuda isn't available if not (dit_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--dit_offload_device", str(dit_offload_device)] if vae_offload_device and vae_offload_device != "none": if not (vae_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--vae_offload_device", str(vae_offload_device)] if tensor_offload_device and tensor_offload_device != "none": if not (tensor_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--tensor_offload_device", str(tensor_offload_device)] if cache_dit: cmd += ["--cache_dit"] if cache_vae: cmd += ["--cache_vae"] # --- Video Specific Flags --- if chunk_size > 0: cmd += ["--chunk_size", str(int(chunk_size))] if temporal_overlap > 0: cmd += ["--temporal_overlap", str(int(temporal_overlap))] if prepend_frames > 0: cmd += ["--prepend_frames", str(int(prepend_frames))] if video_backend and video_backend != "opencv": cmd += ["--video_backend", str(video_backend)] if use_10bit: cmd += ["--10bit"] # Debug Flag if debug: cmd += ["--debug"] if extra_args: # Allow advanced users to type additional flags (space separated) cmd += extra_args.split() return cmd # 3. Dynamic Strategy Loop # Use encode tile size as the dynamic variable for fallback current_tile_size = int(vae_encode_tile_size) current_batch_size = int(batch_size) # Initialize log tracking logs_buf = "" # Add previous swap logs to buf logs_buf += swap_log if 'swap_log' in locals() else "" max_attempts = 5 # Prevent infinite loops attempt_count = 0 idx = 0 while attempt_count < max_attempts: attempt_count += 1 note = f"Tile: {current_tile_size}, Batch: {current_batch_size}" header = f"\n\n=== ATTEMPT {attempt_count}/{max_attempts} ({note}) ===\n" logs_buf += header # yield immediate header yield None, logs_buf cmd = _build_cmd(current_tile_size, current_batch_size) logs_buf += f"[CMD] {' '.join(cmd)}\n" yield None, logs_buf # start streaming process env = os.environ.copy() # Make Python in child process print using UTF-8 (avoids cp950 UnicodeEncodeError on Windows) env['PYTHONIOENCODING'] = 'utf-8' env['PYTHONUTF8'] = '1' # # help fragmentation/alloc issues; user may tune # env.setdefault('PYTORCH_ALLOC_CONF', os.environ.get('PYTORCH_ALLOC_CONF', 'max_split_size_mb:128')) proc, q, t_out, t_err = _start_process_stream(cmd, cwd=str(CLONE_DIR), env=env) if proc is None: logs_buf += "[ERROR] Failed to launch subprocess.\n" yield None, logs_buf break # try next strategy? here treat as fatal # State tracking for this run current_phase = "init" # init, vae_enc, dit, vae_dec, post oom_detected = False start_time = time.time() # poll queue while True: try: # wait up to 0.5s for a line line = q.get(timeout=0.5) logs_buf += line yield None, logs_buf lower_line = line.lower() # Track Phase if "phase 1: vae encoding" in lower_line: current_phase = "vae_enc" elif "phase 2: dit upscaling" in lower_line: current_phase = "dit" elif "phase 3: vae decoding" in lower_line: current_phase = "vae_dec" elif "saving" in lower_line or "converting" in lower_line: current_phase = "post" # Check for OOM oom_indicators = ["outofmemory", "out of memory", "allocation on device", "oom", "cuda out of memory"] if any(k in lower_line for k in oom_indicators): logs_buf += f"\n[WARN] OOM detected during phase: {current_phase.upper()}\n" yield None, logs_buf oom_detected = True try: proc.kill() # Kill immediately to recover VRAM except: pass break except queue.Empty: # no new line - check process status if proc.poll() is not None: break # still running - continue polling continue # Flush remaining while True: try: line = q.get_nowait() logs_buf += line yield None, logs_buf except queue.Empty: break # Wait for reader threads to exit try: if t_out: t_out.join(timeout=1) if t_err: t_err.join(timeout=1) except Exception: pass runtime = time.time() - start_time logs_buf += f"[Attempt {idx} finished in {runtime:.2f}s] returncode={proc.returncode}\n" idx += 1 yield None, logs_buf # 4. Success Check using safe_input_path # CLI generates output relative to the actual input file used (which might be the temp one) # For video, strict output detection logic out_fmt_check = "mp4" if is_video else "png" out_path = expected_upscaled_path(safe_input_path, output_format=out_fmt_check) if Path(out_path).exists(): logs_buf += f"[SUCCESS] Intermediate Output: {out_path}\n" # Cleanup temp file if we created one if temp_copy: try: os.remove(temp_copy) except Exception: pass yield out_path, logs_buf return # 5. Failure Analysis & Parameter Adjustment if oom_detected or proc.returncode != 0: logs_buf += f"\n[INFO] Attempt {attempt_count} failed. Analyzing OOM Phase: {current_phase.upper()}...\n" # --- INTELLIGENT ADJUSTMENT LOGIC --- # Case A: VAE OOM (Phase 1 or 3) -> Reduce Tile Size if current_phase in ["vae_enc", "vae_dec"]: if current_tile_size > 256: new_tile = max(256, current_tile_size // 2) logs_buf += f"[STRATEGY] VAE OOM detected. Reducing Tile Size: {current_tile_size} -> {new_tile}\n" current_tile_size = new_tile else: # Tile size already min, try reducing batch size as a last resort new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] VAE OOM but Tile Size is min. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch # Case B: DiT OOM (Phase 2) -> Reduce Batch Size elif current_phase == "dit": if current_batch_size > 1: # For video consistency, try to keep 4n+1 if possible, or just halve it new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] DiT OOM detected. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch else: logs_buf += f"[FAIL] DiT OOM with Batch Size 1. Cannot reduce further.\n" break # Case C: Post-Process OOM (Phase 4) -> Reduce Batch Size elif current_phase == "post": if current_batch_size > 1: new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] Post-Process OOM detected. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch else: logs_buf += "[FAIL] Post-Process OOM with Batch Size 1.\n" break # Case D: Unknown/Init OOM -> Reduce both safely else: current_tile_size = max(256, current_tile_size // 2) current_batch_size = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] Early OOM. Reducing both Tile ({current_tile_size}) and Batch ({current_batch_size}).\n" # Check if we are just retrying same settings (infinite loop prevention) if attempt_count >= max_attempts: logs_buf += "[FAIL] Max attempts reached.\n" break yield None, logs_buf # Loop continues with new settings else: # Non-OOM fatal error logs_buf += f"[ERROR] Non-zero return code (not OOM) - stopping.\n" yield None, logs_buf return # all strategies exhausted logs_buf += "[FAILED] No output produced after all strategies.\n" if temp_copy: try: os.remove(temp_copy) except: pass yield None, logs_buf return # --- Preset change handler (considers CUDA & MPS availability) --- def preset_changed(preset_value): # Updated Tuple Order: # 0: compile_dit, 1: compile_vae, # 2: vae_encode_tiled, 3: vae_encode_tile_size, # 4: vae_decode_tiled, 5: vae_decode_tile_size, # 6: max_resolution, 7: blocks_to_swap, 8: swap_io_components # 9: dit_offload_device, 10: vae_offload_device, 11: tensor_offload_device, # 12: extra_args, 13: chunk_size, 14: temporal_overlap if preset_value == "Recommended (low VRAM)": return ( False, # compile_dit False, # compile_vae True, # vae_encode_tiled 512, # vae_encode_tile_size True, # vae_decode_tiled 512, # vae_decode_tile_size (sync with encode for safety) 1920, # max_resolution 32, # blocks_to_swap True, # swap_io_components "cpu", # dit_offload_device "none", # vae_offload_device (Keep VAE on device if possible) "cpu", # tensor_offload_device (Offload tensors to save VRAM) "--blocks_to_swap 0", # extra_args 0, # chunk_size 0 # temporal_overlap (0=auto/disabled) ) elif preset_value == "Offload (very slow)": return ( False, # compile_dit False, # compile_vae True, # vae_encode_tiled 256, # vae_encode_tile_size True, # vae_decode_tiled 256, # vae_decode_tile_size 1440, # max_resolution 99, # blocks_to_swap True, # swap_io_components "cpu", # dit_offload_device "cpu", # vae_offload_device "cpu", # tensor_offload_device "--blocks_to_swap 99 --swap_io_components --dit_offload_device cpu --vae_offload_device cpu --tensor_offload_device cpu", 0, 0 ) elif preset_value == "High quality (fast if lots of VRAM)": return ( True, True, False, 512, False, 512, 0, 0, False, "none", "none", "none", # Keep tensors on GPU/MPS "", 0, 0 ) # fallback return (False, False, True, 256, True, 256, 1920, 0, False, "none", "none", "cpu", "--blocks_to_swap 0", 0, 0) # ---------------- Paste JS (attach to gallery elem) ---------------- paste_js = """ function initPaste() { document.addEventListener('paste', function(e) { const gallery = document.getElementById('input_gallery'); if (!gallery) return; if (!gallery.matches(':hover')) return; const clipboardData = e.clipboardData || e.originalEvent.clipboardData; if (!clipboardData) return; const items = clipboardData.items; const files = []; for (let i = 0; i < items.length; i++) { if (items[i].kind === 'file' && items[i].type.startsWith('image/')) { files.push(items[i].getAsFile()); } } if (files.length === 0 && clipboardData.files.length > 0) { for (let i = 0; i < clipboardData.files.length; i++) { if (clipboardData.files[i].type.startsWith('image/')) { files.push(clipboardData.files[i]); } } } if (files.length === 0) return; const uploadInput = gallery.querySelector('input[type="file"]'); if (uploadInput) { e.preventDefault(); e.stopPropagation(); const dataTransfer = new DataTransfer(); files.forEach(file => dataTransfer.items.add(file)); uploadInput.files = dataTransfer.files; uploadInput.dispatchEvent(new Event('change', { bubbles: true })); } }); } """ # ---------------- # Gradio layout # ---------------- # Helper function: Generate progress bar HTML def make_progress_html(current, total, step_desc): if total == 0: percent = 0 else: percent = min(max(current / total * 100, 0), 100) # Use Gradio's CSS variables to automatically adapt to dark/light modes # var(--background-fill-secondary): Container background color # var(--border-color-primary): Border color # var(--body-text-color): Main text color # var(--color-accent): Progress bar color (follows theme accent) # var(--border-color-primary): Progress bar track color (ensures visibility in dark mode) return f"""
{step_desc} {percent:.1f}%
""" # ---------------- UI: main ---------------- def ui_upscale_main( gallery_input, # Image list video_input, # Video path resolution, max_resolution, preset_mode, dit_model_combo, use_gguf, cuda_device, # Compile compile_dit, compile_vae, compile_backend, compile_mode, compile_fullgraph, compile_dynamic, compile_dynamo_cache_size_limit, compile_dynamo_recompile_limit, attention_mode, # Tiling vae_encode_tiled, vae_encode_tile_size, vae_encode_tile_overlap, vae_decode_tiled, vae_decode_tile_size, vae_decode_tile_overlap, tile_debug, # Processing batch_size, uniform_batch_size, seed, skip_first_frames, load_cap, # Color/Quality color_correction, input_noise_scale, latent_noise_scale, # Memory blocks_to_swap, swap_io_components, dit_offload_device, vae_offload_device, tensor_offload_device, cache_dit, cache_vae, extra_args, # General pre_downscale, downscale_rate, repetition_count, output_format, use_improved_blockswap, # Video chunk_size, temporal_overlap, prepend_frames, video_backend, use_10bit, debug, # Repo Config Inputs custom_repo_url, custom_branch, custom_clone_name ): # Initialize empty progress bar HTML empty_progress = make_progress_html(0, 100, "Waiting to start...") # DETERMINE INPUT SOURCE target_inputs = [] is_video_mode = False if video_input is not None: # Video takes precedence if provided (or user can clear it) target_inputs = [video_input] is_video_mode = True # Force mp4 format for internal logic if video output_format = "mp4" elif gallery_input: # gallery is expected to be a list; each item may be: # - str filepath (depending on Gradio version) OR # - an object/tuple where first item is filepath (some Gradio variants). # Normalize gallery entries to file paths for entry in gallery_input: # Gradio versions vary — entry may be: # - str (path) # - list/tuple where first element is path if isinstance(entry, (list, tuple)): # sometimes gallery entries are [path, caption...] path = entry[0] else: path = entry # If path is a dict with 'name' depending on gradio, try common keys if isinstance(path, dict) and 'name' in path: path = path['name'] target_inputs.append(str(path)) else: yield None, "No images or video provided.\n", empty_progress return # # apply presets # if preset_mode == "Recommended (low VRAM)": # compile_dit = False # compile_vae = False # vae_encode_tiled = True # if not vae_tile_size: # vae_tile_size = 256 # if max_resolution is None: # max_resolution = 1920 # # keep blocks_to_swap = 0 by default # elif preset_mode == "Offload (very slow)": # compile_dit = False # compile_vae = False # vae_encode_tiled = True # if not vae_tile_size: # vae_tile_size = 256 # if max_resolution is None: # max_resolution = 1440 # if not blocks_to_swap: # blocks_to_swap = 32 # swap_io_components = True # dit_offload_device = "cpu" # vae_offload_device = "cpu" # tensor_offload_device = "cpu" # elif preset_mode == "High quality (fast if lots of VRAM)": # compile_dit = True # compile_vae = True # vae_encode_tiled = False # if max_resolution is None: # max_resolution = 0 # no limit # Dynamic Repo Handling current_repo_path = CLONE_DIR # Fallback # Model directory is now fixed and independent of the repo location current_model_dir = DEFAULT_MODEL_DIR # Default values if empty target_repo_url = custom_repo_url.strip() if custom_repo_url and custom_repo_url.strip() else REPO_URL target_clone_name = custom_clone_name.strip() if custom_clone_name and custom_clone_name.strip() else "ComfyUI-SeedVR2_VideoUpscaler" target_branch = custom_branch.strip() # Ensure repo and model exist (downloads/clone if missing) try: yield None, f"Checking Repository ({target_clone_name})...", make_progress_html(5, 100, "Checking Repo...") # Call the updated ensure_repo_cloned current_repo_path = ensure_repo_cloned( repo_url=target_repo_url, clone_dir_name=target_clone_name, repo_branch=target_branch, force_update=False ) except Exception as e: yield None, f"Repo clone/check failed: {e}\n", make_progress_html(0, 100, "Repo Error") return # Parse the selected combo "RepoID/Filename" selected_repo_id = DEFAULT_VAE_REPO_ID # Default fallback selected_filename = None if dit_model_combo: # Check if the string contains a slash indicating Repo/File structure if "/" in dit_model_combo: # Split from the right, as filename is the last part parts = dit_model_combo.split("/") selected_filename = parts[-1] # Join the rest as the repo ID (e.g. "owner/repo" or "owner/sub/repo") selected_repo_id = "/".join(parts[:-1]) else: # Fallback for simple filenames (assumes default repo) selected_filename = dit_model_combo if selected_filename: try: yield None, f"Checking Model {selected_filename}...", make_progress_html(10, 100, "Checking Models...") # Pass new paths to ensure download happens in the custom repo folder ensure_models_available( selected_filename, model_dir=current_model_dir, repo_id=selected_repo_id, ) except Exception as e: yield None, f"Model download failed: {e}\n", make_progress_html(0, 100, "Model Error") return # Stores log history for all completed images full_logs_history = "" # successful_outputs will now store tuples (physical_path, archive_name) successful_outputs = [] total_files = len(target_inputs) # Ensure repetition is at least 1 safe_repetition = max(1, int(repetition_count)) if is_video_mode: safe_repetition = 1 # Force 1 pass for video to avoid endless waits total_operations = total_files * safe_repetition # Process sequentially for idx, img_path in enumerate(target_inputs, start=1): filename = os.path.basename(img_path) original_stem = Path(img_path).stem # This variable tracks the input for the current pass # Initially it is the original file, in subsequent loops it becomes the output of the previous pass current_input_path = img_path final_output_for_image = None # Loop for Repetitions for loop_idx in range(1, safe_repetition + 1): # Calculate global progress index # (File 1 Pass 1 = 0, File 1 Pass 2 = 1 ... File 2 Pass 1 = N) global_op_index = (idx - 1) * safe_repetition + (loop_idx - 1) # Prepare header pass_info = f" (Pass {loop_idx}/{safe_repetition})" if safe_repetition > 1 else "" # Prepare header for this file header_log = f"\n\n=== FILE {idx}/{total_files}: {filename}{pass_info} ===\n" # Progress calculation # Calculate base progress (e.g., 2nd image of 4, base progress is 25%) # Reserve 10% for preparation, allocate remaining 90% to images # start_pct = 10 + ((idx - 1) / total_images) * 90 start_pct = 10 + (global_op_index / total_operations) * 90 progress_html = make_progress_html(start_pct, 100, f"File {idx}/{total_files} - Pass {loop_idx}: Preparing...") yield None, full_logs_history + header_log, progress_html # Call generator # Note: pre_downscale is passed every time. # If enabled, it will downscale 'current_input_path' before upscaling. gen = run_cli_upscale_stream( input_path=current_input_path, resolution=int(resolution), max_resolution=int(max_resolution) if max_resolution is not None else 0, dit_model_filename=selected_filename if selected_filename else None, cuda_device=(cuda_device if CUDA_AVAILABLE else None), # New Compile Args compile_dit=bool(compile_dit), compile_vae=bool(compile_vae), compile_backend=compile_backend, compile_mode=compile_mode, compile_fullgraph=bool(compile_fullgraph), compile_dynamic=bool(compile_dynamic), compile_dynamo_cache_size_limit=int(compile_dynamo_cache_size_limit), compile_dynamo_recompile_limit=int(compile_dynamo_recompile_limit), attention_mode=attention_mode, # New Tiling Args vae_encode_tiled=bool(vae_encode_tiled), vae_encode_tile_size=int(vae_encode_tile_size), vae_encode_tile_overlap=int(vae_encode_tile_overlap), vae_decode_tiled=bool(vae_decode_tiled), vae_decode_tile_size=int(vae_decode_tile_size), vae_decode_tile_overlap=int(vae_decode_tile_overlap), tile_debug=tile_debug, # New Processing Args batch_size=int(batch_size), uniform_batch_size=bool(uniform_batch_size), seed=int(seed), skip_first_frames=int(skip_first_frames), load_cap=int(load_cap), # New Quality Args color_correction=color_correction, input_noise_scale=float(input_noise_scale), latent_noise_scale=float(latent_noise_scale), # Memory & Caching blocks_to_swap=int(blocks_to_swap), swap_io_components=bool(swap_io_components), dit_offload_device=str(dit_offload_device), vae_offload_device=str(vae_offload_device), tensor_offload_device=str(tensor_offload_device), cache_dit=bool(cache_dit), cache_vae=bool(cache_vae), extra_args=extra_args or "", model_dir=str(current_model_dir), # Use current_model_dir repo_id=selected_repo_id, # Pass the extracted Repo ID repo_path=current_repo_path, # Pass dynamic repo path pre_downscale=pre_downscale, downscale_rate=downscale_rate, output_format=output_format, use_improved_blockswap=use_improved_blockswap, # Pass Video Args chunk_size=int(chunk_size), temporal_overlap=int(temporal_overlap), prepend_frames=int(prepend_frames), video_backend=video_backend, use_10bit=use_10bit, debug=bool(debug) ) out_for_this_pass = None current_stream_logs = "" try: for out_path, logs in gen: # logs is the complete log from start to now for this image (behavior of run_cli_upscale_stream) current_stream_logs = logs # Progress Logic per pass # Simply parse log content to determine stage, allocating this image's ratio in total progress # Single image takes up (90 / total_files)% of total progress per_pass_slice = 90 / total_operations local_ratio = 0.1 status_text = f"Img {idx}/{total_files} (Pass {loop_idx}): Init" if "Phase 1: VAE encoding" in logs: local_ratio = 0.2 status_text = f"Img {idx}/{total_files} (Pass {loop_idx}): Encoding" if "Phase 2: DiT upscaling" in logs: local_ratio = 0.45 status_text = f"Img {idx}/{total_files} (Pass {loop_idx}): Upscaling" if "Phase 3: Decode" in logs: local_ratio = 0.8 status_text = f"Img {idx}/{total_files} (Pass {loop_idx}): Decoding" if "Phase 4: Post-process" in logs: local_ratio = 0.95 status_text = f"Img {idx}/{total_files} (Pass {loop_idx}): Post-proc" if "Saving" in logs: local_ratio = 0.98 status_text = f"File {idx}/{total_files} (Pass {loop_idx}): Saving" # Calculate current total progress current_total_pct = start_pct + (per_pass_slice * local_ratio) progress_html = make_progress_html(current_total_pct, 100, status_text) # Combine historical log + current image header + current streaming log yield None, full_logs_history + header_log + current_stream_logs, progress_html if out_path: out_for_this_pass = out_path except Exception as e: error_msg = f"[ERROR] Exception: {e}\n" current_stream_logs += error_msg yield None, full_logs_history + header_log + current_stream_logs, make_progress_html(current_total_pct, 100, "Error") # If error, break the repetition loop for this image break full_logs_history += header_log + current_stream_logs # After image processing completes if out_for_this_pass and os.path.exists(out_for_this_pass): # Success for this pass current_input_path = out_for_this_pass # Update input for next pass final_output_for_image = out_for_this_pass else: # Failure in this pass, stop repeating full_logs_history += f"\n[WARN] Pass {loop_idx} failed, stopping.\n" break # End Repetitions if final_output_for_image and os.path.exists(final_output_for_image): # Rename physical file back to original name + timestamp before adding to ZIP # Consider output format conversion if necessary output_dir = Path(final_output_for_image).parent # Generate timestamp ts = int(time.time()) # Logic for Image format conversion vs Video if is_video_mode: # Keep as mp4 target_filename = f"{original_stem}_{ts}.mp4" target_path = output_dir / target_filename try: shutil.move(final_output_for_image, target_path) final_output_for_image = str(target_path) full_logs_history += f"[INFO] Renamed output to: {target_filename}\n" except Exception as e: full_logs_history += f"[WARN] Rename failed: {e}\n" target_filename = os.path.basename(final_output_for_image) else: # Image Logic (PNG/JPG/WEBP conversion) # Set restored filename (original_stem_{timestamp}.{ext}) target_filename = f"{original_stem}_{ts}.{output_format}" target_path = output_dir / target_filename # Delete target if exists to avoid collision if target_path.exists(): os.remove(target_path) if output_format == "png": # Just rename shutil.move(final_output_for_image, target_path) full_logs_history += f"[INFO] Renamed output to: {target_filename}\n" else: # Convert (jpg, webp, etc.) try: # Read the image img = cv2.imread(final_output_for_image, cv2.IMREAD_UNCHANGED) if img is None: raise ValueError("Result image could not be loaded via cv2.") # Convert BGRA (OpenCV default for alpha) to BGR if saving as JPEG if output_format in ["jpg", "jpeg"]: # Check if image has 4 channels if len(img.shape) == 3 and img.shape[2] == 4: img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) # Save to new format with quality control quality_val = 95 cv2.imwrite(str(target_path), img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_val]) else: # Save directly (handles webp, etc.) cv2.imwrite(str(target_path), img) full_logs_history += f"[INFO] Converted PNG to {output_format}: {target_filename}\n" # Remove original png os.remove(final_output_for_image) except Exception as e: full_logs_history += f"[ERROR] Conversion failed: {e}\n" # Fallback: if conversion failed, try to keep the original file if possible target_path = Path(final_output_for_image) # Fallback to png target_filename = target_path.name # Update variable to point to the new path final_output_for_image = str(target_path) # Store both the physical path and the intended ZIP name successful_outputs.append((final_output_for_image, target_filename)) full_logs_history += f"[INFO] Item {idx} completed: {final_output_for_image}\n" # Update progress bar to completed state for this file end_pct = 10 + (idx / total_files ) * 90 # Rough estimate for completion of this image block yield None, full_logs_history, make_progress_html(end_pct, 100, f"Item {idx} Done") # Final Output Logic # After all images processed, logic to handle output (Single file vs ZIP) if successful_outputs: # Check output count if len(successful_outputs) == 1: # If there is only one file, return the image path directly; do not compress. # successful_outputs stores tuples: (physical_path, archive_name) single_file_path = successful_outputs[0][0] full_logs_history += f"\n[DONE] Single image processed. Returning: {single_file_path}\n" # Yield the single file path directly yield single_file_path, full_logs_history, make_progress_html(100, 100, "Processing Complete!") else: # If more than one file, execute the standard ZIP packaging logic. # Use ZIP_STORED (store only, no compression) for speed to avoid CPU bottlenecks. # PNG is already a compressed format; re-compressing via Deflate offers little benefit and is extremely slow. compression_method = zipfile.ZIP_STORED # Use parent dir of the first output for the zip location out_dir = Path(successful_outputs[0][0]).parent # Derive model tag for filename (fall back to "output" when unknown) model_tag = selected_filename or "output" # sanitize model_tag to be filesystem-safe (allow alnum, dot, dash, underscore) sanitized = "".join(c if (c.isalnum() or c in "._-") else "_" for c in model_tag) zip_name = f"seedvr2_{sanitized}_{int(time.time())}.zip" zip_path = out_dir / zip_name full_logs_history += f"\n[INFO] Zipping {len(successful_outputs)} items...\n" yield None, full_logs_history, make_progress_html(100, 100, "Packaging...") # Add allowZip64=True to support files larger than 4GB with zipfile.ZipFile(zip_path, "w", compression=compression_method, allowZip64=True) as zf: total_files = len(successful_outputs) # Yield progress inside the packaging loop to prevent Gradio disconnects due to long periods of unresponsiveness for i, (physical_path, archive_name) in enumerate(successful_outputs): try: zf.write(physical_path, arcname=archive_name) except Exception as e: full_logs_history += f"\n[WARN] Failed to pack {archive_name}: {e}\n" # Regularly yield progress updates # Although ZIP_STORED is fast, writing 400 images to disk still takes time. # Update UI every 10 images here to let the frontend know the connection is still alive. if i % 10 == 0 or i == total_files - 1: # pct = int((i + 1) / total_files * 100) status_msg = f"Packaging {i+1}/{total_files}..." # Here we only update the progress bar, not the full log history, to avoid excessive data transmission yield None, full_logs_history, make_progress_html(100, 100, status_msg) final_msg = f"\n[DONE] Successfully packaged {len(successful_outputs)} items into {zip_path}\n" full_logs_history += final_msg # Return the ZIP file path yield str(zip_path), full_logs_history, make_progress_html(100, 100, "Complete!") else: full_logs_history += "\n[DONE] No outputs were generated.\n" yield None, full_logs_history, make_progress_html(100, 100, "Failed / No Output") # ---------------- UI layout ---------------- def main(): css = """ /* small UI tweaks */ #input_gallery:hover { border-color: var(--color-accent) !important; box-shadow: 0 0 8px rgba(0,0,0,0.08); } """ is_low_vram = False # Print CUDA/MPS availability, useful when running on CPU-only server if CUDA_AVAILABLE: try: torch.cuda.set_per_process_memory_fraction(0.95, device='cuda:0') except Exception: pass # set torch options to avoid get black image for RTX16xx card # https://github.com/CompVis/stable-diffusion/issues/69#issuecomment-1260722801 torch.backends.cudnn.enabled = True torch.backends.cudnn.benchmark = True print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}") try: # Check if VRAM is less than 6.5GB (6 * 1024^3 bytes) # If so, default to Offload mode to prevent OOM on entry if torch.cuda.get_device_properties(0).total_memory <= (6.5 * 1024**3): is_low_vram = True except Exception: # Fallback if device property read fails pass elif MPS_AVAILABLE: print("MPS (Apple Silicon) is available. Using Metal Performance Shaders.") else: print("Neither CUDA nor MPS detected. GPU-related UI controls hidden.") # Automatic defaults logic: # 1. If no accelerator (CPU only) -> Force "Offload" # 2. If MPS (Mac) -> Default "Recommended" (Unified Memory handles this well), MPS users default to Recommended (usually 8GB+ Unified Memory is sufficient for this preset) # 3. If CUDA -> Check VRAM, if low use "Offload", else "Recommended" DEFAULT_PRESET = "Offload (very slow)" if not ACCELERATOR_AVAILABLE or is_low_vram else "Recommended (low VRAM)" # Unpack default values from the calculated preset immediately. # This ensures that all sliders and checkboxes match the Dropdown's initial value. ( init_compile_dit, init_compile_vae, init_vae_encode_tiled, init_vae_encode_tile_size, init_vae_decode_tiled, init_vae_decode_tile_size, init_max_resolution, init_blocks_to_swap, init_swap_io_components, init_dit_offload_device, init_vae_offload_device, init_tensor_offload_device, init_extra_args, init_chunk_size, init_temporal_overlap ) = preset_changed(DEFAULT_PRESET) # Default to GGUF only if NO accelerator is found or is low vram. # MPS/CUDA users usually prefer standard Safetensors unless extremely VRAM constrained. DEFAULT_USE_GGUF = not ACCELERATOR_AVAILABLE or is_low_vram # initial model choices depend on DEFAULT_USE_GGUF initial_model_choices = GGUF_CHOICES if DEFAULT_USE_GGUF else MODEL_CHOICES initial_model_value = initial_model_choices[0] if initial_model_choices else (MODEL_CHOICES[0] if MODEL_CHOICES else None) with gr.Blocks(title="SeedVR2 Image/Video Upscaler", css=css) as demo: gr.Markdown("# SeedVR2 Upscaler — Image & Video\nSupport for single image, batch images, and MP4 video upscaling.") gr.Markdown("This application utilizes the [ComfyUI-SeedVR2_VideoUpscaler](https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler) backend logic for inference. ") if MPS_AVAILABLE: gr.HTML( """

🍎 macOS MPS Detected

Running on Metal Performance Shaders (MPS). Performance is better than CPU.

""" ) elif not CUDA_AVAILABLE: gr.HTML( """

⚠️ No GPU Detected (CPU Mode)

Neither CUDA (NVIDIA) nor MPS (macOS) was detected. Processing will be extremely slow.

""" ) with gr.Row(): with gr.Column(scale=1): submit = gr.Button("Start Upscale Processing", variant="primary", size="lg") # TABS for Input with gr.Tabs(): with gr.TabItem("🖼️ Image Gallery"): gallery = gr.Gallery( label="Input Images (Batch Support)", elem_id="input_gallery", columns=4, rows=3, show_label=False, interactive=True, height=350 ) with gr.TabItem("🎥 Video Input"): video_input = gr.Video( label="Input Video (MP4/AVI)", sources=["upload"], format="mp4" ) # Group 0: Repo Settings (New) with gr.Accordion("🛠️ Repository Settings (Advanced)", open=False): gr.Markdown("Configure a custom GitHub repository to test different versions or forks.") with gr.Row(): custom_repo_url = gr.Textbox( label="Repository URL", value="", placeholder="https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler.git" ) custom_clone_name = gr.Textbox( label="Clone Directory Name", value="", placeholder="ComfyUI-SeedVR2_VideoUpscaler", info="Folder name inside the app directory. Change this to avoid overwriting default." ) custom_branch = gr.Textbox( label="Branch / Tag / Commit Hash", value="", placeholder="e.g. main, dev, or hash like d69b65f...", info="Leave empty for default branch. If changing repo, use a new directory name." ) # Group 1: General Settings (Resolution & Presets) with gr.Accordion("### ⚙️ General Settings", open=True): preset_mode = gr.Dropdown( choices=["Recommended (low VRAM)", "Offload (very slow)", "High quality (fast if lots of VRAM)"], value=DEFAULT_PRESET, label="Preset mode", info="Automatically adjusts compilation, tiling, and offload settings based on your hardware capabilities." ) with gr.Row(): resolution = gr.Slider( minimum=256, maximum=4096, step=64, value=1920, label="Target Resolution (Short Side)", info="Target short-side resolution in pixels. The aspect ratio is preserved." ) max_resolution = gr.Number( value=init_max_resolution, label="Max resolution (0=unlimited)", info="Maximum resolution for any edge. Scales down if exceeded. 0 = no limit." ) # Output Format selection with gr.Row(): output_format = gr.Dropdown( choices=["webp", "png", "jpg"], value="webp", label="Output Format (Default: webp)", info="Format for saved images. For video input, the CLI produces MP4 (or PNG sequence), and this app converts the final result if needed." ) seed = gr.Number( value=42, label="Seed", precision=0, info="Random seed for reproducibility." ) repetition_count = gr.Slider( minimum=1, maximum=5, step=1, value=1, label="Loop Count (Images only, Repeat Upscale)", info="Run SeedVR2 N times per image. If downscale is checked, it applies before EACH run to progressively upscale/refine." ) # Group 2: Video Specific with gr.Accordion("### 🎥 Video Settings", open=False): with gr.Row(): chunk_size = gr.Number( value=init_chunk_size, label="Chunk Size (Frames)", info="Frames per chunk for streaming mode. 0 = load all frames at once. Set to specific amount (e.g. 100) to limit VRAM usage on long videos." ) temporal_overlap = gr.Number( value=init_temporal_overlap, label="Temporal Overlap", info="Frames to overlap between chunks/batches for smooth blending and to prevent seams." ) with gr.Row(): prepend_frames = gr.Number( value=0, label="Prepend Frames", info="Prepend N reversed frames to reduce start artifacts. These are automatically removed from the output." ) skip_first_frames = gr.Number( value=0, label="Skip First Frames", precision=0, info="Skip N initial frames of the video." ) load_cap = gr.Number( value=0, label="Load Cap (Max Frames)", precision=0, info="Load maximum N frames from video. 0 = load all." ) with gr.Row(): video_backend = gr.Dropdown( choices=["opencv", "ffmpeg"], value="opencv", label="Video Backend", info="Video encoder backend. 'ffmpeg' requires ffmpeg in system PATH but supports advanced features like 10-bit." ) use_10bit = gr.Checkbox( label="10-bit Output (ffmpeg only)", value=False, info="Use x265 10-bit encoding (reduces banding). Requires ffmpeg backend." ) # Group 3: Model & Quality with gr.Accordion("### 🤖 Model & Quality", open=True): use_gguf = gr.Checkbox( label="Use GGUF-quantized models (gguf)", value=DEFAULT_USE_GGUF, info="When checked, the DiT model dropdown will show GGUF models from cmeka/SeedVR2-GGUF. Efficient for lower VRAM." ) dit_model = gr.Dropdown( choices=initial_model_choices, value=initial_model_value, label="DiT model (Format: RepoID/Filename)", info="DiT transformer model. 7B models have higher quality but require more memory than 3B models." ) # Callback: model choices (gguf <-> safetensors) def _toggle_model_list(gguf_enabled: bool): if gguf_enabled: # set to GGUF list, default the first gguf file return gr.update(choices=GGUF_CHOICES, value=GGUF_CHOICES[0]) else: return gr.update(choices=MODEL_CHOICES, value=MODEL_CHOICES[0]) use_gguf.change(fn=_toggle_model_list, inputs=[use_gguf], outputs=[dit_model]) # Show CUDA device textbox only if CUDA available cuda_device = gr.Textbox( label="CUDA device", value="0" if CUDA_AVAILABLE else "", visible=CUDA_AVAILABLE, info="CUDA device IDs (e.g. '0' or '0,1'). Leave blank for default." ) with gr.Row(): color_correction = gr.Dropdown( choices=["lab", "wavelet", "wavelet_adaptive", "hsv", "adain", "none"], value="lab", label="Color correction", info="Method to match colors. 'lab' (perceptual, recommended), 'wavelet' (frequency-based), 'adain' (statistical), etc." ) input_noise_scale = gr.Slider( minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Input Noise Scale", info="Input noise injection scale (0.0-1.0). Adds variation to input images." ) latent_noise_scale = gr.Slider( minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Latent Noise Scale", info="Latent noise injection scale (0.0-1.0). Adds variation to latent space." ) with gr.Row(): pre_downscale = gr.Checkbox( label="Pre-downscale image (Images only, removes noise/artifacts)", value=False, info="Reduces image size before upscaling. Helps remove JPEG artifacts or noise as noted in community tips." ) downscale_rate = gr.Slider( minimum=0.1, maximum=0.9, step=0.1, value=0.5, label="Downscale factor", info="0.5 means the input is resized to 50% size before being upscaled to target resolution." ) # Group 4: Performance & Memory (Advanced) with gr.Accordion("### ⚡ Optimization & Memory", open=True): with gr.Row(): batch_size = gr.Slider( minimum=1, maximum=65, step=4, value=1, label="Batch size (4n+1 recommended)", info="Frames per batch. 4n+1 (1, 5, 9, 13...) is optimized for temporal consistency. Higher values use more VRAM." ) uniform_batch_size = gr.Checkbox( label="Uniform Batch Size (Pad final batch)", value=False, info="Pad final batch to match batch_size. Prevents temporal artifacts caused by small final batches. Adds extra compute." ) with gr.Accordion("Memory & Offload / Caching", open=False): use_improved_blockswap = gr.Checkbox( label="Use Improved BlockSwap (Nunchaku Ping-Pong CPUOffload)", value=False, info="Replaces the standard blockswap logic with the improved version from Nunchaku. Useful for faster offloading." ) with gr.Row(): blocks_to_swap = gr.Number( value=init_blocks_to_swap, label="Blocks to swap", info="Transformer blocks to swap to RAM. 0=disabled. Use large value like 99 for auto-detection of max blocks. Requires Offload Device." ) swap_io_components = gr.Checkbox( label="Swap I/O components", value=init_swap_io_components, info="Offload DiT I/O layers for extra VRAM savings. Requires Offload Device." ) # Offload device choices adapt to CUDA availability offload_choices = ["none", "cpu"] + (["cuda:0"] if CUDA_AVAILABLE else []) with gr.Row(): dit_offload_device = gr.Dropdown( choices=offload_choices, value=init_dit_offload_device, label="DiT Offload device", info="Device to move DiT to when idle. 'cpu' frees VRAM between phases." ) vae_offload_device = gr.Dropdown( choices=offload_choices, value=init_vae_offload_device, label="VAE Offload device", info="Device to move VAE to when idle. 'cpu' frees VRAM between phases." ) tensor_offload_device = gr.Dropdown( choices=offload_choices, value=init_tensor_offload_device, label="Tensor Offload device", info="Where to store intermediate tensors. 'cpu' is recommended to save VRAM." ) with gr.Row(): cache_dit = gr.Checkbox( label="Cache DiT", value=False, info="Keep DiT model in memory between generations. Useful for batch/directory mode or streaming." ) cache_vae = gr.Checkbox( label="Cache VAE", value=False, info="Keep VAE model in memory between generations. Useful for batch/directory mode or streaming." ) with gr.Accordion("Advanced Tiling (VRAM Saving)", open=False): with gr.Row(): vae_encode_tiled = gr.Checkbox( label="Enable VAE Encode tiling", value=init_vae_encode_tiled, info="Process VAE encoding in tiles to reduce VRAM usage (good for large inputs)." ) vae_encode_tile_size = gr.Number( value=init_vae_encode_tile_size, label="Encode Tile Size", info="Tile size in pixels for encoding." ) vae_encode_tile_overlap = gr.Number( value=64, label="Encode Overlap", info="Overlap in pixels to reduce visible seams." ) with gr.Row(): vae_decode_tiled = gr.Checkbox( label="Enable Decode Tiling", value=init_vae_decode_tiled, info="Process VAE decoding in tiles to reduce VRAM usage." ) vae_decode_tile_size = gr.Number( value=init_vae_decode_tile_size, label="Decode Tile Size", info="Tile size in pixels for decoding." ) vae_decode_tile_overlap = gr.Number( value=64, label="Decode Overlap", info="Overlap in pixels to reduce visible seams." ) tile_debug = gr.Dropdown( choices=["false", "encode", "decode"], value="false", label="Tile Debug Visualization", info="Visualizes the tiling process for debugging purposes." ) with gr.Accordion("Compilation & Backend (Torch 2.0+)", open=False): with gr.Row(): compile_dit = gr.Checkbox( label="Enable torch.compile for DiT", value=init_compile_dit, info="20-40% speedup. Requires PyTorch 2.0+ and Triton. May increase memory usage." ) compile_vae = gr.Checkbox( label="Enable torch.compile for VAE", value=init_compile_vae, info="15-25% speedup for VAE encoding/decoding." ) with gr.Row(): compile_backend = gr.Dropdown( choices=["inductor", "cudagraphs"], value="inductor", label="Backend", info="'inductor' (full optimization) or 'cudagraphs' (lightweight)." ) compile_mode = gr.Dropdown( choices=["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"], value="default", label="Mode", info="Optimization level: 'default' (fast compile), 'max-autotune' (best speed, slow compile), etc." ) with gr.Row(): attention_mode = gr.Dropdown( choices=["sdpa", "flash_attn_2", "flash_attn_3", "sageattn_2", "sageattn_3"], value="sdpa", label="Attention Mode", info="Attention backend. 'sdpa' (default), 'flash_attn' (faster), or 'sageattn' (Blackwell)." ) compile_fullgraph = gr.Checkbox( label="Fullgraph", value=False, info="Compile entire model as single graph. Faster but less flexible." ) compile_dynamic = gr.Checkbox( label="Dynamic Shapes", value=False, info="Handle varying input shapes without recompilation." ) with gr.Row(): compile_dynamo_cache_size_limit = gr.Number( value=64, label="Dynamo Cache Limit", info="Max cached compiled versions per function." ) compile_dynamo_recompile_limit = gr.Number( value=128, label="Dynamo Recompile Limit", info="Max recompilation attempts before fallback to eager mode." ) with gr.Row(): debug_mode = gr.Checkbox( label="Enable Debug Logs", value=True, info="Show verbose output in CLI logs." ) extra_args = gr.Textbox( label="Extra CLI args", value=init_extra_args, info="Manually pass additional flags to the CLI (e.g. --custom_flag value)." ) # Bind the preset change callback (outputs updated to match new UI elements) preset_mode.change( fn=preset_changed, inputs=[preset_mode], outputs=[ compile_dit, compile_vae, vae_encode_tiled, vae_encode_tile_size, vae_decode_tiled, vae_decode_tile_size, max_resolution, blocks_to_swap, swap_io_components, dit_offload_device, vae_offload_device, tensor_offload_device, extra_args, chunk_size, temporal_overlap ] ) with gr.Column(scale=1, variant="panel"): # Custom progress bar HTML component progress_display = gr.HTML(label="Progress", value=make_progress_html(0, 100, "Ready")) download_zip = gr.File(label="Download Result") logs = gr.Textbox(label="CLI logs (streaming)", lines=25, autoscroll=True) clear = gr.ClearButton(components=[gallery, video_input, download_zip, logs, progress_display], variant="secondary") submit.click( fn=ui_upscale_main, inputs=[ gallery, video_input, resolution, max_resolution, preset_mode, dit_model, use_gguf, cuda_device, # Compiled Inputs compile_dit, compile_vae, compile_backend, compile_mode, compile_fullgraph, compile_dynamic, compile_dynamo_cache_size_limit, compile_dynamo_recompile_limit, attention_mode, # Tiling Inputs vae_encode_tiled, vae_encode_tile_size, vae_encode_tile_overlap, vae_decode_tiled, vae_decode_tile_size, vae_decode_tile_overlap, tile_debug, # Processing Inputs batch_size, uniform_batch_size, seed, skip_first_frames, load_cap, # Quality Inputs color_correction, input_noise_scale, latent_noise_scale, # Memory Inputs blocks_to_swap, swap_io_components, dit_offload_device, vae_offload_device, tensor_offload_device, cache_dit, cache_vae, extra_args, # General Inputs pre_downscale, downscale_rate, repetition_count, output_format, use_improved_blockswap, # Video Inputs chunk_size, temporal_overlap, prepend_frames, video_backend, use_10bit, debug_mode, custom_repo_url, custom_branch, custom_clone_name ], outputs=[download_zip, logs, progress_display] ) # load paste JS demo.load(None, None, None, js=paste_js) demo.queue(max_size=1) demo.launch(inbrowser=True) if __name__ == "__main__": main()