#!/usr/bin/env python3 """ Gradio front-end wrapper for SeedVR2's official inference_cli.py This script is the user's app.py enhanced to stream subprocess logs in real-time into the Gradio logs textbox. It runs the CLI as subprocesses and streams stdout/stderr lines as they arrive using a queue and reader threads. The Gradio handler `ui_upscale` is implemented as a generator so the frontend receives incremental updates. This script provides a simple web UI for single-image upscaling using the official ComfyUI-SeedVR2_VideoUpscaler `inference_cli.py` script. It calls the official CLI as a subprocess, and will automatically download model weights from Hugging Face (numz/SeedVR2_comfyUI) if they are missing. If the ComfyUI-SeedVR2_VideoUpscaler repository is not present, the script will attempt to `git clone` it automatically into ./ComfyUI-SeedVR2_VideoUpscaler. Run: python app.py Requirements - Python 3.10+ - Gradio (pip install gradio) - Git available in PATH (for automatic cloning) or clone the repo manually - PyTorch + CUDA (if using GPU) Notes - This wrapper calls the repo's `inference_cli.py` as a subprocess so the CLI's memory/optimization features (BlockSwap, VAE tiling, etc.) remain available. - Models will be downloaded to the cloned repo's ./models/SeedVR2 directory if missing. Use HUGGINGFACE_HUB_TOKEN env var if required for private access. """ import os import sys import cv2 import time import torch import queue import shutil import zipfile import threading import subprocess import numpy as np import gradio as gr from pathlib import Path from typing import Optional, Tuple, Generator, List # huggingface helper (used for model auto-download) from huggingface_hub import hf_hub_download def imreadUTF8(path, flags=cv2.IMREAD_COLOR): """ OpenCV's cv2.imread cannot handle non-ASCII paths. This function reads an image from a path that may contain UTF-8 characters. """ try: # Use NumPy to read from the file, which correctly handles UTF-8 paths with open(path, "rb") as stream: bytes_data = bytearray(stream.read()) numpyarray = np.asarray(bytes_data, dtype=np.uint8) # Use cv2.imdecode to decode the image from the memory buffer img = cv2.imdecode(numpyarray, flags) return img except Exception as e: # If reading fails, print the error message and return None print(f"ERROR: Failed to read image with UTF-8 path: {path}") print(f" Details: {e}") return None def imwriteUTF8(save_path, image): """ OpenCV's cv2.imwrite cannot handle non-ASCII paths. This function writes an image to a path that may contain UTF-8 characters. """ try: img_name = os.path.basename(save_path) _, extension = os.path.splitext(img_name) # Encode the image into the specified format (determined by the file extension) is_success, im_buf_arr = cv2.imencode(extension, image) if is_success: # Write the image data from memory to the file im_buf_arr.tofile(save_path) return True else: print(f"ERROR: Failed to encode image for path: {save_path}") return False except Exception as e: print(f"ERROR: Failed to write image with UTF-8 path: {save_path}") print(f" Details: {e}") return False # Apply Monkey Patch to cv2 (for app.py usage) print("[SeedVR2 Gradio] Applying UTF-8 patch to OpenCV (Frontend)...") cv2.imread = imreadUTF8 cv2.imwrite = imwriteUTF8 # ---------------- # Config / paths # ---------------- REPO_URL = "https://github.com/numz/ComfyUI-SeedVR2_VideoUpscaler.git" CLONE_DIR = Path(__file__).resolve().parent / "ComfyUI-SeedVR2_VideoUpscaler" INFERENCE_CLI = CLONE_DIR / "inference_cli.py" PY_EXE = sys.executable # Use same Python executable to run CLI # Path to the custom improved blockswap file IMPROVED_BLOCKSWAP_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "blockswap.py" IMPROVED_MEMORY_MANAGER_SOURCE = Path(__file__).resolve().parent / "src" / "optimization" / "memory_manager.py" # Default HF repo for VAE (VAE is usually static and comes from the official repo) DEFAULT_VAE_REPO_ID = "numz/SeedVR2_comfyUI" # Models are now stored in a fixed top-level directory, independent of the clone dir DEFAULT_MODEL_DIR = Path(__file__).resolve().parent / "models" / "SeedVR2" # ---------------- # Model Definitions (RepoID / Filename) # ---------------- # Standard Models (Safetensors) MODEL_CHOICES = [ "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp8_e4m3fn.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_3b_fp16.safetensors", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_7b_fp16.safetensors", # sharp variants "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors", "numz/SeedVR2_comfyUI/seedvr2_ema_7b_sharp_fp16.safetensors", ] # GGUF / alternate model support GGUF_CHOICES = [ "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q4_K_M.gguf", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_3b-Q8_0.gguf", "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b-Q4_K_M.gguf", # sharp variants "AInVFX/SeedVR2_comfyUI/seedvr2_ema_7b_sharp-Q4_K_M.gguf", # custom GGUF from cmeka "cmeka/SeedVR2-GGUF/seedvr2_ema_7b-Q8_0.gguf", "cmeka/SeedVR2-GGUF/seedvr2_ema_7b_sharp-Q8_0.gguf", ] # # Model registry with metadata # MODEL_REGISTRY = { # # 3B models # "seedvr2_ema_3b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q4_K_M", sha256="e665e3909de1a8c88a69c609bca9d43ff5a134647face2ce4497640cc3597f0e"), # "seedvr2_ema_3b-Q8_0.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="3B", precision="Q8_0", sha256="be0d60083a2051a265eb4b77f28edf494e6db67ffc250216f32b72292e5cbd96"), # "seedvr2_ema_3b_fp8_e4m3fn.safetensors": ModelInfo(size="3B", precision="fp8_e4m3fn", sha256="3bf1e43ebedd570e7e7a0b1b60d6a02e105978f505c8128a241cde99a8240cff"), # "seedvr2_ema_3b_fp16.safetensors": ModelInfo(size="3B", precision="fp16", sha256="2fd0e03a3dad24e07086750360727ca437de4ecd456f769856e960ae93e2b304"), # # 7B models # "seedvr2_ema_7b-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", sha256="db9cb2ad90ebd40d2e8c29da2b3fc6fd03ba87cd58cbadceccca13ad27162789"), # "seedvr2_ema_7b_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", sha256="3d68b5ec0b295ae28092e355c8cad870edd00b817b26587d0cb8f9dd2df19bb2"), # "seedvr2_ema_7b_fp16.safetensors": ModelInfo(size="7B", precision="fp16", sha256="7b8241aa957606ab6cfb66edabc96d43234f9819c5392b44d2492d9f0b0bbe4a"), # # 7B sharp variants # "seedvr2_ema_7b_sharp-Q4_K_M.gguf": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="Q4_K_M", variant="sharp", sha256="7aed800ac4eb8e0d18569a954c0ff35f5a1caa3ed5d920e66cc31405f75b6e69"), # "seedvr2_ema_7b_sharp_fp8_e4m3fn_mixed_block35_fp16.safetensors": ModelInfo(repo="AInVFX/SeedVR2_comfyUI", size="7B", precision="fp8_e4m3fn_mixed_block35_fp16", variant="sharp", sha256="0d2c5b8be0fda94351149c5115da26aef4f4932a7a2a928c6f184dda9186e0be"), # "seedvr2_ema_7b_sharp_fp16.safetensors": ModelInfo(size="7B", precision="fp16", variant="sharp", sha256="20a93e01ff24beaeebc5de4e4e5be924359606c356c9c51509fba245bd2d77dd"), # # VAE models # "ema_vae_fp16.safetensors": ModelInfo(category="vae", precision="fp16", sha256="20678548f420d98d26f11442d3528f8b8c94e57ee046ef93dbb7633da8612ca1"), # } # Detect Hardware Availability CUDA_AVAILABLE = torch.cuda.is_available() # Detect MPS availability (for Apple Silicon) MPS_AVAILABLE = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and torch.backends.mps.is_built() # Check for any hardware acceleration ACCELERATOR_AVAILABLE = CUDA_AVAILABLE or MPS_AVAILABLE # ----------------- # Repo / model helpers # ----------------- def ensure_repo_cloned( repo_url: str = REPO_URL, clone_dir_name: str = "ComfyUI-SeedVR2_VideoUpscaler", repo_branch: str = "", force_update: bool = False ) -> Path: """ Ensure the repository is cloned locally into the specified directory name. Supports specific branches/tags via checkout. Returns the resolved Path object to the cloned directory. """ # Resolve the physical path based on the script's parent location target_clone_dir = Path(__file__).resolve().parent / clone_dir_name target_cli = target_clone_dir / "inference_cli.py" # Helper function to handle detached/orphaned commits def _smart_checkout(cwd, ref): print(f"[SeedVR2 Gradio] Checking out '{ref}' in {cwd} ...") try: # Try standard checkout first (fastest if ref exists locally) subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True) except subprocess.CalledProcessError: # Fallback: If ref is not found (e.g. orphaned commit hash), fetch it explicitly print(f"[SeedVR2 Gradio] Standard checkout failed. Attempting to fetch specific ref '{ref}' from origin...") try: subprocess.run(["git", "-C", str(cwd), "fetch", "origin", ref], check=True) subprocess.run(["git", "-C", str(cwd), "checkout", ref], check=True) except Exception as e: raise RuntimeError(f"Failed to fetch/checkout specific ref '{ref}': {e}") if target_clone_dir.exists() and (target_clone_dir / ".git").exists(): # Repo exists if force_update: try: print(f"[SeedVR2 Gradio] Updating {target_clone_dir} ...") subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True) # If a specific branch/hash is requested if repo_branch: _smart_checkout(target_clone_dir, repo_branch) # If it's a branch name (not a detached hash), we might want to pull latest # But checking if it's a branch vs hash is complex, generally strictly checking out the ref is safer for reproducibility else: subprocess.run(["git", "-C", str(target_clone_dir), "pull"], check=True) except Exception as e: raise RuntimeError(f"Failed to update repository {target_clone_dir}: {e}") # If not forcing update, but a branch is specified, ensure we are on it elif repo_branch: try: subprocess.run(["git", "-C", str(target_clone_dir), "fetch", "--all"], check=True) _smart_checkout(target_clone_dir, repo_branch) except Exception as e: raise RuntimeError(f"Failed to switch to branch {repo_branch}: {e}") # Ensure inference_cli present if not target_cli.exists(): raise RuntimeError(f"Repository found at {target_clone_dir} but inference_cli.py is missing.") return target_clone_dir # Clone repo if not exists try: print(f"[SeedVR2 Gradio] Cloning {repo_url} into {target_clone_dir} ...") # Standard clone (fetches default branch) subprocess.run(["git", "clone", repo_url, str(target_clone_dir)], check=True) if repo_branch: _smart_checkout(target_clone_dir, repo_branch) except FileNotFoundError: raise RuntimeError("git not found: please install Git or clone the repository manually.") except Exception as e: raise RuntimeError(f"Failed to clone repository: {e}") if not target_cli.exists(): raise RuntimeError(f"Clone completed but inference_cli.py not found in {target_clone_dir}.") return target_clone_dir def apply_inference_cli_patch(cli_path: Path): """ Injects UTF-8 compatible imread/imwrite wrappers directly into inference_cli.py. This modifies the physical file so the subprocess (even on Windows spawn) uses the patch. """ if not cli_path.exists(): return try: with open(cli_path, "r", encoding="utf-8") as f: content = f.read() # Check if already patched to avoid duplicates if "def imreadUTF8" in content: return # The patch content to inject. # Note: We ensure 'import numpy as np' and 'import os' are available or re-imported. # inference_cli.py typically has 'import cv2', we inject right after that. patch_code = r''' # ============================================================================= # GRADIO APP PATCH: UTF-8 Support for Windows (Auto-Injected) # ============================================================================= import numpy as np import os def imreadUTF8(path, flags=cv2.IMREAD_COLOR): try: with open(path, "rb") as stream: bytes_data = bytearray(stream.read()) numpyarray = np.asarray(bytes_data, dtype=np.uint8) return cv2.imdecode(numpyarray, flags) except Exception as e: print(f"Error reading image {path}: {e}") return None def imwriteUTF8(save_path, image): try: img_name = os.path.basename(save_path) _, extension = os.path.splitext(img_name) is_success, im_buf_arr = cv2.imencode(extension, image) if is_success: im_buf_arr.tofile(save_path) return True else: return False except Exception as e: print(f"Error writing image {save_path}: {e}") return False # Override cv2 methods cv2.imread = imreadUTF8 cv2.imwrite = imwriteUTF8 # ============================================================================= ''' # Inject after 'import cv2' if "import cv2" in content: print(f"[SeedVR2 Gradio] Patching {cli_path} for UTF-8 subprocess support...") new_content = content.replace("import cv2", "import cv2" + patch_code, 1) with open(cli_path, "w", encoding="utf-8") as f: f.write(new_content) else: print("[SeedVR2 Gradio] WARNING: Could not find 'import cv2' in inference_cli.py. UTF-8 patch skipped.") except Exception as e: print(f"[SeedVR2 Gradio] ERROR applying UTF-8 patch to inference_cli: {e}") def patch_model_registry(repo_root: Path): """ Appends custom model definitions to src/utils/model_registry.py. This allows the CLI to recognize new GGUF models that aren't in the official registry. """ registry_path = repo_root / "src" / "utils" / "model_registry.py" if not registry_path.exists(): print(f"[SeedVR2 Gradio] WARN: Could not find model_registry.py at {registry_path}") return try: with open(registry_path, "r", encoding="utf-8") as f: content = f.read() # Check if already patched if "seedvr2_ema_7b-Q8_0.gguf" in content: return print(f"[SeedVR2 Gradio] Patching {registry_path} with custom GGUF models...") # Code to append to the end of the file. # Since ModelInfo and MODEL_REGISTRY are defined in the file, we can use them directly. patch_code = r''' # ============================================================================= # GRADIO APP PATCH: Custom Model Registry Entries # ============================================================================= try: # Update registry with custom GGUF models requested by user MODEL_REGISTRY.update({ "seedvr2_ema_7b-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", sha256="669788655e8f15f306284f267a444e9766c8a421869577b16a961e43029c737b"), "seedvr2_ema_7b_sharp-Q8_0.gguf": ModelInfo(repo="cmeka/SeedVR2-GGUF", size="7B", precision="Q8_0", variant="sharp", sha256="b1f81cb5700b0b1f432f2c785528356c952c41c74d03d205c6f14b0bd6da303d"), }) print("[Internal] Custom GGUF models injected into MODEL_REGISTRY successfully.") except Exception as e: print(f"[Internal] Failed to inject custom models: {e}") # ============================================================================= ''' with open(registry_path, "a", encoding="utf-8") as f: f.write(patch_code) except Exception as e: print(f"[SeedVR2 Gradio] ERROR patching model_registry.py: {e}") # ----------------- # BlockSwap Management # ----------------- def manage_blockswap_file(use_improved: bool, repo_root: Path) -> str: """ Manages the blockswap.py file in the specified cloned repository. Accepted `repo_root` path to ensure we modify the correct repo. """ target_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py" backup_path_blockswap = repo_root / "src" / "optimization" / "blockswap.py.bak" target_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py" backup_path_memory_manager = repo_root / "src" / "optimization" / "memory_manager.py.bak" msg = "" # Ensure src/optimization exists (some forks might differ in structure) if not target_path_blockswap.parent.exists(): return f"[WARN] Optimization folder not found at {target_path_blockswap.parent}. Skipping blockswap patch.\n" if use_improved: if not IMPROVED_BLOCKSWAP_SOURCE.exists(): return f"[WARN] Improved blockswap source not found at {IMPROVED_BLOCKSWAP_SOURCE}. Keeping current version.\n" # 1. Check if we need to backup the original blockswap (only if backup doesn't exist yet) if target_path_blockswap.exists() and not backup_path_blockswap.exists(): try: shutil.move(str(target_path_blockswap), str(backup_path_blockswap)) msg += f"[INFO] Backed up original blockswap to {backup_path_blockswap.name}.\n" except Exception as e: return f"[ERROR] Failed to backup blockswap: {e}\n" # 2. Copy the improved file to target blockswap try: shutil.copy(str(IMPROVED_BLOCKSWAP_SOURCE), str(target_path_blockswap)) msg += "[INFO] Switched to Improved BlockSwap (Nunchaku implementation).\n" except Exception as e: return f"[ERROR] Failed to install improved blockswap: {e}\n" # Memory Manager Handling if not IMPROVED_MEMORY_MANAGER_SOURCE.exists(): return f"[WARN] Improved memory_manager source not found at {IMPROVED_MEMORY_MANAGER_SOURCE}. Keeping current version.\n" # 3. Check if we need to backup the original memory_manager (only if backup doesn't exist yet) if target_path_memory_manager.exists() and not backup_path_memory_manager.exists(): try: shutil.move(str(target_path_memory_manager), str(backup_path_memory_manager)) msg += f"[INFO] Backed up original memory_manager to {backup_path_memory_manager.name}.\n" except Exception as e: return f"[ERROR] Failed to backup memory_manager: {e}\n" # 4. Copy the improved file to target memory_manager try: shutil.copy(str(IMPROVED_MEMORY_MANAGER_SOURCE), str(target_path_memory_manager)) msg += "[INFO] Switched to Improved memory_manager (Nunchaku implementation).\n" except Exception as e: return f"[ERROR] Failed to install improved memory_manager: {e}\n" return msg else: # Restore original blockswap if available if backup_path_blockswap.exists(): try: # Remove current target blockswap if it exists (which might be the improved one) if target_path_blockswap.exists(): os.remove(target_path_blockswap) # Restore backup shutil.move(str(backup_path_blockswap), str(target_path_blockswap)) msg += "[INFO] Restored Original BlockSwap from backup.\n" except Exception as e: return f"[ERROR] Failed to restore original blockswap: {e}\n" else: # Backup doesn't exist, assume we are already on original or clean install msg += "[INFO] Using Original BlockSwap (No backup found/needed).\n" # Restore original memory_manager if available if backup_path_memory_manager.exists(): try: # Remove current target memory_manager if it exists (which might be the improved one) if target_path_memory_manager.exists(): os.remove(target_path_memory_manager) # Restore backup shutil.move(str(backup_path_memory_manager), str(target_path_memory_manager)) msg += "[INFO] Restored Original memory_manager from backup.\n" except Exception as e: return f"[ERROR] Failed to restore original memory_manager: {e}\n" else: # Backup doesn't exist, assume we are already on original or clean install msg += "[INFO] Using Original memory_manager (No backup found/needed).\n" return msg # ----------------- # Model download # ----------------- def ensure_models_available( selected_model_filename: str, model_dir: Optional[Path] = None, repo_id: str = DEFAULT_VAE_REPO_ID ) -> None: """ Ensure the selected DiT model and the VAE file exist locally. If missing, download from the specified Hugging Face repo directly into model_dir using 'local_dir' to avoid nested cache structures. """ if model_dir is None: model_dir = DEFAULT_MODEL_DIR else: model_dir = Path(model_dir) model_dir.mkdir(parents=True, exist_ok=True) # Items to check: VAE + selected DiT model required = ["ema_vae_fp16.safetensors", selected_model_filename] # Check if files physically exist at the target location missing = [_f for _f in required if not (model_dir / _f).exists()] if not missing: return # Optional: silence HF symlink warning if desired os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1") # Attempt download for each missing file hf_token = os.environ.get("HF_ACCESS_TOKEN") for fname in missing: target_path = model_dir / fname # If file already somehow exists at target, skip if target_path.exists(): continue # Decide repo for this filename: # - VAE must always come from the official DEFAULT_VAE_REPO_ID (numz/SeedVR2_comfyUI) # - Dit model uses the provided repo_id (which comes from the dropdown selection) repo_for_fname = DEFAULT_VAE_REPO_ID if fname == "ema_vae_fp16.safetensors" else repo_id try: print(f"[SeedVR2 Gradio] Downloading {fname} from {repo_for_fname} directly to {model_dir} ...") # Use local_dir instead of cache_dir. # This forces the file to be saved exactly at {model_dir}/{fname} # local_dir_use_symlinks=False ensures we get a real file, not a symlink, # which prevents issues where the CLI subprocess cannot resolve the path. downloaded_path = hf_hub_download( repo_id=repo_for_fname, filename=fname, local_dir=str(model_dir), # Download directly to target folder repo_type="model", token=hf_token, ) print(f"[SeedVR2 Gradio] Download completed: {downloaded_path}") except Exception as e: raise RuntimeError(f"Failed to download {fname} from Hugging Face repo {repo_for_fname}: {e}") # ---------------- # Subprocess streaming helpers # ---------------- def _start_process_stream(cmd_args, cwd: str, env: dict) -> Tuple[Optional[subprocess.Popen], queue.Queue, Optional[threading.Thread], Optional[threading.Thread]]: """Start subprocess and return (proc, q, t_out, t_err). The returned queue will receive text lines as they arrive. Lines are simple strings (already newline-terminated). stderr lines are prefixed with "stderr: ". """ q = queue.Queue() try: proc = subprocess.Popen( cmd_args, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, encoding='utf-8', errors='replace', bufsize=1, env=env ) except Exception as e: # Put error to queue and return a dummy proc q.put(f"[FAILED TO LAUNCH] {e}\n") return None, q, None, None def _reader(fh, prefix: str): try: while True: line = fh.readline() if not line: break if not line.endswith("\n"): line = line + "\n" q.put(prefix + line) except Exception as e: q.put(f"[reader error] {e}\n") t_out = threading.Thread(target=_reader, args=(proc.stdout, ""), daemon=True) t_err = threading.Thread(target=_reader, args=(proc.stderr, "stderr: "), daemon=True) t_out.start() t_err.start() return proc, q, t_out, t_err # ----------------- # CLI runner (streaming) # ----------------- def expected_upscaled_path(input_path: str, output_format: str = "png") -> str: """Calculates the expected output path based on the input path and requested format.""" p = Path(input_path) stem = p.stem parent = p.parent suffix = "_upscaled" # if output_format == "mp4": # return str((parent / f"{stem}{suffix}.mp4").resolve()) # else: # return str((parent / f"{stem}{suffix}.png").resolve()) return str((parent / f"{stem}{suffix}.{output_format}").resolve()) # Single-image/video CLI runner (generator) def run_cli_upscale_stream( input_path: str, resolution: int = 1080, max_resolution: int = 0, dit_model_filename: Optional[str] = None, # Receives just the filename cuda_device: Optional[str] = None, # Compilation & Performance compile_dit: bool = False, compile_vae: bool = False, compile_backend: str = "inductor", compile_mode: str = "default", compile_fullgraph: bool = False, compile_dynamic: bool = False, compile_dynamo_cache_size_limit: int = 64, compile_dynamo_recompile_limit: int = 128, attention_mode: str = "sdpa", # Tiling (Split Encode/Decode) vae_encode_tiled: bool = False, vae_encode_tile_size: int = 1024, vae_encode_tile_overlap: int = 128, vae_decode_tiled: bool = False, vae_decode_tile_size: int = 1024, vae_decode_tile_overlap: int = 128, tile_debug: str = "false", # Processing batch_size: int = 1, uniform_batch_size: bool = False, seed: int = 42, skip_first_frames: int = 0, load_cap: int = 0, # Quality & Color color_correction: str = "lab", input_noise_scale: float = 0.0, latent_noise_scale: float = 0.0, # Memory & Offload blocks_to_swap: int = 0, swap_io_components: bool = False, dit_offload_device: str = "none", vae_offload_device: str = "none", tensor_offload_device: str = "cpu", cache_dit: bool = False, cache_vae: bool = False, extra_args: str = "", model_dir: Optional[str] = None, repo_id: str = DEFAULT_VAE_REPO_ID, # Receives the specific Repo ID for DiT repo_path: Optional[Path] = None, timeout: int = 3600, pre_downscale: bool = False, # for artifact removal downscale_rate: float = 0.5, output_format: str = "png", # Can now be "mp4" use_improved_blockswap: bool = False, # New argument for switching blockswap version # Video Args chunk_size: int = 0, temporal_overlap: int = 0, prepend_frames: int = 0, video_backend: str = "opencv", use_10bit: bool = False, # Debug Arg debug: bool = False ) -> Generator[Tuple[Optional[str], str], None, None]: """ Generator yields (out_path_or_None, logs_so_far) while streaming CLI logs. Includes Phase-Aware Dynamic Fallback logic. """ # Defaults if repo_path is None: repo_path = CLONE_DIR current_inference_cli = repo_path / "inference_cli.py" # 1. Repo Check if not current_inference_cli.exists(): yield None, f"[ERROR] inference_cli.py not found in {repo_path}\n" return # Patch inference_cli.py with UTF-8 support try: apply_inference_cli_patch(current_inference_cli) except Exception as e: yield None, f"[WARN] Failed to patch inference_cli: {e}\n" # Patch model_registry.py with custom models try: patch_model_registry(repo_path) except Exception as e: yield None, f"[WARN] Failed to patch model_registry: {e}\n" # Handle BlockSwap File Replacement Logic try: swap_log = manage_blockswap_file(use_improved_blockswap, repo_root=repo_path) # Yield the log about blockswap immediately yield None, swap_log except Exception as e: yield None, f"[ERROR] BlockSwap management failed: {e}\n" # Use the global default if not provided if model_dir is None: model_dir = str(DEFAULT_MODEL_DIR) # Ensure model files present if dit_model_filename: try: ensure_models_available( dit_model_filename, model_dir=Path(model_dir), repo_id=repo_id, ) except Exception as e: yield None, f"[ERROR] Model download failed: {e}\n" return safe_input_path = input_path temp_copy = None # Pre-downscale logic (Artifact Removal Trick) - Only applies to Images in this implementation # We skip this for MP4 files to avoid complex video processing in python before CLI is_video = input_path.lower().endswith(('.mp4', '.avi', '.mov', '.mkv')) # Pre-downscale (Images only) if pre_downscale and not is_video: try: filename = os.path.basename(input_path) # Load original image using OpenCV img_obj = cv2.imread(input_path, cv2.IMREAD_UNCHANGED) if img_obj is None: raise ValueError(f"Failed to load image: {input_path}") # Calculate new dimensions (OpenCV shape is [height, width]) h, w = img_obj.shape[:2] if (max(w, h) > 250): new_w = int(w * downscale_rate) new_h = int(h * downscale_rate) # Resize # Use INTER_AREA for downscaling (better quality/less aliasing for shrinking) # Use INTER_LANCZOS4 if scaling up (though this block is specifically for downscaling) interpolation_method = cv2.INTER_AREA if downscale_rate < 1.0 else cv2.INTER_LANCZOS4 img_resized = cv2.resize(img_obj, (new_w, new_h), interpolation=interpolation_method) # Prepare temp directory tmp_dir = CLONE_DIR / "tmp_inputs" tmp_dir.mkdir(parents=True, exist_ok=True) # Save to a unique temp file (forces .png for intermediate input) new_name = f"{filename}_downscaled.png" temp_copy = str(tmp_dir / new_name) # Use patched cv2.imwrite cv2.imwrite(temp_copy, img_resized) # Use this temp file as the input for CLI safe_input_path = temp_copy yield None, f"[INFO] Pre-downscaled input by factor {downscale_rate} (Size: {w}x{h} -> {new_w}x{new_h}) to reduce artifacts.\n" except Exception as e: yield None, f"[ERROR] Failed to pre-downscale image: {e}\n" return # 2. Command Builder def _build_cmd(curr_tile_size, curr_batch_size): # Determine strict output format if is_video: # If input is video, force mp4 output for CLI unless user explicitly wants png sequence? # Usually users want mp4 back. cmd_format = "mp4" else: # For images, use png (CLI handles webp/jpg conversion internally if modified, # but standard CLI outputs png/mp4). We force png here, app.py handles conversion later. cmd_format = "png" cmd = [PY_EXE, str(current_inference_cli), safe_input_path, "--resolution", str(resolution), "--output_format", cmd_format, "--batch_size", str(curr_batch_size), "--color_correction", color_correction, "--model_dir", str(model_dir), "--seed", str(seed), "--attention_mode", str(attention_mode)] if max_resolution and int(max_resolution) > 0: cmd += ["--max_resolution", str(int(max_resolution))] if dit_model_filename: # CLI just needs the filename relative to --model_dir (or absolute path) cmd += ["--dit_model", str(dit_model_filename)] # Only add --cuda_device if CUDA available and user provided a value if CUDA_AVAILABLE and cuda_device: cmd += ["--cuda_device", str(cuda_device)] # --- Compilation Options --- if compile_dit: cmd += ["--compile_dit"] if compile_vae: cmd += ["--compile_vae"] if compile_dit or compile_vae: cmd += [ "--compile_backend", str(compile_backend), "--compile_mode", str(compile_mode), "--compile_dynamo_cache_size_limit", str(compile_dynamo_cache_size_limit), "--compile_dynamo_recompile_limit", str(compile_dynamo_recompile_limit) ] if compile_fullgraph: cmd += ["--compile_fullgraph"] if compile_dynamic: cmd += ["--compile_dynamic"] # --- Tiling Options --- # Note: curr_tile_size comes from the loop strategy (Phase Fallback), # normally we use the user provided vae_encode_tile_size. if vae_encode_tiled: cmd += ["--vae_encode_tiled", "--vae_encode_tile_size", str(curr_tile_size), "--vae_encode_tile_overlap", str(vae_encode_tile_overlap)] if vae_decode_tiled: cmd += ["--vae_decode_tiled", "--vae_decode_tile_size", str(vae_decode_tile_size), "--vae_decode_tile_overlap", str(vae_decode_tile_overlap)] if tile_debug != "false": cmd += ["--tile_debug", str(tile_debug)] # --- Processing & Quality --- if uniform_batch_size: cmd += ["--uniform_batch_size"] if skip_first_frames > 0: cmd += ["--skip_first_frames", str(int(skip_first_frames))] if load_cap > 0: cmd += ["--load_cap", str(int(load_cap))] if input_noise_scale > 0: cmd += ["--input_noise_scale", str(input_noise_scale)] if latent_noise_scale > 0: cmd += ["--latent_noise_scale", str(latent_noise_scale)] # --- BlockSwap / Offload / Caching --- if blocks_to_swap and int(blocks_to_swap) > 0: cmd += ["--blocks_to_swap", str(int(blocks_to_swap))] if swap_io_components: cmd += ["--swap_io_components"] # Offload flags: note these are strings like "none"/"cpu"/"cuda:0" if dit_offload_device and dit_offload_device != "none": # Ensure we don't pass a cuda device offload when cuda isn't available if not (dit_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--dit_offload_device", str(dit_offload_device)] if vae_offload_device and vae_offload_device != "none": if not (vae_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--vae_offload_device", str(vae_offload_device)] if tensor_offload_device and tensor_offload_device != "none": if not (tensor_offload_device.startswith("cuda") and not CUDA_AVAILABLE): cmd += ["--tensor_offload_device", str(tensor_offload_device)] if cache_dit: cmd += ["--cache_dit"] if cache_vae: cmd += ["--cache_vae"] # --- Video Specific Flags --- if chunk_size > 0: cmd += ["--chunk_size", str(int(chunk_size))] if temporal_overlap > 0: cmd += ["--temporal_overlap", str(int(temporal_overlap))] if prepend_frames > 0: cmd += ["--prepend_frames", str(int(prepend_frames))] if video_backend and video_backend != "opencv": cmd += ["--video_backend", str(video_backend)] if use_10bit: cmd += ["--10bit"] # Debug Flag if debug: cmd += ["--debug"] if extra_args: # Allow advanced users to type additional flags (space separated) cmd += extra_args.split() return cmd # 3. Dynamic Strategy Loop # Use encode tile size as the dynamic variable for fallback current_tile_size = int(vae_encode_tile_size) current_batch_size = int(batch_size) # Initialize log tracking logs_buf = "" # Add previous swap logs to buf logs_buf += swap_log if 'swap_log' in locals() else "" max_attempts = 5 # Prevent infinite loops attempt_count = 0 idx = 0 while attempt_count < max_attempts: attempt_count += 1 note = f"Tile: {current_tile_size}, Batch: {current_batch_size}" header = f"\n\n=== ATTEMPT {attempt_count}/{max_attempts} ({note}) ===\n" logs_buf += header # yield immediate header yield None, logs_buf cmd = _build_cmd(current_tile_size, current_batch_size) logs_buf += f"[CMD] {' '.join(cmd)}\n" yield None, logs_buf # start streaming process env = os.environ.copy() # Make Python in child process print using UTF-8 (avoids cp950 UnicodeEncodeError on Windows) env['PYTHONIOENCODING'] = 'utf-8' env['PYTHONUTF8'] = '1' # # help fragmentation/alloc issues; user may tune # env.setdefault('PYTORCH_ALLOC_CONF', os.environ.get('PYTORCH_ALLOC_CONF', 'max_split_size_mb:128')) proc, q, t_out, t_err = _start_process_stream(cmd, cwd=str(CLONE_DIR), env=env) if proc is None: logs_buf += "[ERROR] Failed to launch subprocess.\n" yield None, logs_buf break # try next strategy? here treat as fatal # State tracking for this run current_phase = "init" # init, vae_enc, dit, vae_dec, post oom_detected = False start_time = time.time() # poll queue while True: try: # wait up to 0.5s for a line line = q.get(timeout=0.5) logs_buf += line yield None, logs_buf lower_line = line.lower() # Track Phase if "phase 1: vae encoding" in lower_line: current_phase = "vae_enc" elif "phase 2: dit upscaling" in lower_line: current_phase = "dit" elif "phase 3: vae decoding" in lower_line: current_phase = "vae_dec" elif "saving" in lower_line or "converting" in lower_line: current_phase = "post" # Check for OOM oom_indicators = ["outofmemory", "out of memory", "allocation on device", "oom", "cuda out of memory"] if any(k in lower_line for k in oom_indicators): logs_buf += f"\n[WARN] OOM detected during phase: {current_phase.upper()}\n" yield None, logs_buf oom_detected = True try: proc.kill() # Kill immediately to recover VRAM except: pass break except queue.Empty: # no new line - check process status if proc.poll() is not None: break # still running - continue polling continue # Flush remaining while True: try: line = q.get_nowait() logs_buf += line yield None, logs_buf except queue.Empty: break # Wait for reader threads to exit try: if t_out: t_out.join(timeout=1) if t_err: t_err.join(timeout=1) except Exception: pass runtime = time.time() - start_time logs_buf += f"[Attempt {idx} finished in {runtime:.2f}s] returncode={proc.returncode}\n" idx += 1 yield None, logs_buf # 4. Success Check using safe_input_path # CLI generates output relative to the actual input file used (which might be the temp one) # For video, strict output detection logic out_fmt_check = "mp4" if is_video else "png" out_path = expected_upscaled_path(safe_input_path, output_format=out_fmt_check) if Path(out_path).exists(): logs_buf += f"[SUCCESS] Intermediate Output: {out_path}\n" # Cleanup temp file if we created one if temp_copy: try: os.remove(temp_copy) except Exception: pass yield out_path, logs_buf return # 5. Failure Analysis & Parameter Adjustment if oom_detected or proc.returncode != 0: logs_buf += f"\n[INFO] Attempt {attempt_count} failed. Analyzing OOM Phase: {current_phase.upper()}...\n" # --- INTELLIGENT ADJUSTMENT LOGIC --- # Case A: VAE OOM (Phase 1 or 3) -> Reduce Tile Size if current_phase in ["vae_enc", "vae_dec"]: if current_tile_size > 256: new_tile = max(256, current_tile_size // 2) logs_buf += f"[STRATEGY] VAE OOM detected. Reducing Tile Size: {current_tile_size} -> {new_tile}\n" current_tile_size = new_tile else: # Tile size already min, try reducing batch size as a last resort new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] VAE OOM but Tile Size is min. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch # Case B: DiT OOM (Phase 2) -> Reduce Batch Size elif current_phase == "dit": if current_batch_size > 1: # For video consistency, try to keep 4n+1 if possible, or just halve it new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] DiT OOM detected. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch else: logs_buf += f"[FAIL] DiT OOM with Batch Size 1. Cannot reduce further.\n" break # Case C: Post-Process OOM (Phase 4) -> Reduce Batch Size elif current_phase == "post": if current_batch_size > 1: new_batch = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] Post-Process OOM detected. Reducing Batch Size: {current_batch_size} -> {new_batch}\n" current_batch_size = new_batch else: logs_buf += "[FAIL] Post-Process OOM with Batch Size 1.\n" break # Case D: Unknown/Init OOM -> Reduce both safely else: current_tile_size = max(256, current_tile_size // 2) current_batch_size = max(1, current_batch_size // 2) logs_buf += f"[STRATEGY] Early OOM. Reducing both Tile ({current_tile_size}) and Batch ({current_batch_size}).\n" # Check if we are just retrying same settings (infinite loop prevention) if attempt_count >= max_attempts: logs_buf += "[FAIL] Max attempts reached.\n" break yield None, logs_buf # Loop continues with new settings else: # Non-OOM fatal error logs_buf += f"[ERROR] Non-zero return code (not OOM) - stopping.\n" yield None, logs_buf return # all strategies exhausted logs_buf += "[FAILED] No output produced after all strategies.\n" if temp_copy: try: os.remove(temp_copy) except: pass yield None, logs_buf return # --- Preset change handler (considers CUDA & MPS availability) --- def preset_changed(preset_value): # Updated Tuple Order: # 0: compile_dit, 1: compile_vae, # 2: vae_encode_tiled, 3: vae_encode_tile_size, # 4: vae_decode_tiled, 5: vae_decode_tile_size, # 6: max_resolution, 7: blocks_to_swap, 8: swap_io_components # 9: dit_offload_device, 10: vae_offload_device, 11: tensor_offload_device, # 12: extra_args, 13: chunk_size, 14: temporal_overlap if preset_value == "Recommended (low VRAM)": return ( False, # compile_dit False, # compile_vae True, # vae_encode_tiled 512, # vae_encode_tile_size True, # vae_decode_tiled 512, # vae_decode_tile_size (sync with encode for safety) 1920, # max_resolution 32, # blocks_to_swap True, # swap_io_components "cpu", # dit_offload_device "none", # vae_offload_device (Keep VAE on device if possible) "cpu", # tensor_offload_device (Offload tensors to save VRAM) "--blocks_to_swap 0", # extra_args 0, # chunk_size 0 # temporal_overlap (0=auto/disabled) ) elif preset_value == "Offload (very slow)": return ( False, # compile_dit False, # compile_vae True, # vae_encode_tiled 256, # vae_encode_tile_size True, # vae_decode_tiled 256, # vae_decode_tile_size 1440, # max_resolution 99, # blocks_to_swap True, # swap_io_components "cpu", # dit_offload_device "cpu", # vae_offload_device "cpu", # tensor_offload_device "--blocks_to_swap 99 --swap_io_components --dit_offload_device cpu --vae_offload_device cpu --tensor_offload_device cpu", 0, 0 ) elif preset_value == "High quality (fast if lots of VRAM)": return ( True, True, False, 512, False, 512, 0, 0, False, "none", "none", "none", # Keep tensors on GPU/MPS "", 0, 0 ) # fallback return (False, False, True, 256, True, 256, 1920, 0, False, "none", "none", "cpu", "--blocks_to_swap 0", 0, 0) # ---------------- Paste JS (attach to gallery elem) ---------------- paste_js = """ function initPaste() { document.addEventListener('paste', function(e) { const gallery = document.getElementById('input_gallery'); if (!gallery) return; if (!gallery.matches(':hover')) return; const clipboardData = e.clipboardData || e.originalEvent.clipboardData; if (!clipboardData) return; const items = clipboardData.items; const files = []; for (let i = 0; i < items.length; i++) { if (items[i].kind === 'file' && items[i].type.startsWith('image/')) { files.push(items[i].getAsFile()); } } if (files.length === 0 && clipboardData.files.length > 0) { for (let i = 0; i < clipboardData.files.length; i++) { if (clipboardData.files[i].type.startsWith('image/')) { files.push(clipboardData.files[i]); } } } if (files.length === 0) return; const uploadInput = gallery.querySelector('input[type="file"]'); if (uploadInput) { e.preventDefault(); e.stopPropagation(); const dataTransfer = new DataTransfer(); files.forEach(file => dataTransfer.items.add(file)); uploadInput.files = dataTransfer.files; uploadInput.dispatchEvent(new Event('change', { bubbles: true })); } }); } """ # ---------------- # Gradio layout # ---------------- # Helper function: Generate progress bar HTML def make_progress_html(current, total, step_desc): if total == 0: percent = 0 else: percent = min(max(current / total * 100, 0), 100) # Use Gradio's CSS variables to automatically adapt to dark/light modes # var(--background-fill-secondary): Container background color # var(--border-color-primary): Border color # var(--body-text-color): Main text color # var(--color-accent): Progress bar color (follows theme accent) # var(--border-color-primary): Progress bar track color (ensures visibility in dark mode) return f"""
Running on Metal Performance Shaders (MPS). Performance is better than CPU.
Neither CUDA (NVIDIA) nor MPS (macOS) was detected. Processing will be extremely slow.