"""Gradio / Hugging Face ZeroGPU demo app for SCAIL-2 character animation.

Meant to live at the root of the SCAIL-2 repository. At startup it (optionally)
downloads checkpoint assets, converts the official DiT checkpoint to
safetensors when no converted file is available, patches the attention backend,
preloads the pipeline, and finally serves a Gradio UI with prepared examples
plus an "Advanced Uploads" tab.
"""

import gc
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download

logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")

ROOT = Path(__file__).resolve().parent


def _default_storage_root() -> Path:
    """Choose the root directory for downloaded, converted and generated artifacts.

    Order of preference: ``SCAIL_STORAGE_ROOT`` when set, then ``/data`` when it
    exists and is writable, otherwise ``/tmp``.
    """
    if os.getenv("SCAIL_STORAGE_ROOT"):
        return Path(os.environ["SCAIL_STORAGE_ROOT"])
    data_mount = Path("/data")
    if data_mount.exists() and os.access(data_mount, os.W_OK):
        return data_mount
    return Path("/tmp")


STORAGE_ROOT = _default_storage_root()
STAGING_ROOT = Path(os.getenv("SCAIL_STAGING_ROOT", "/tmp"))
OUTPUT_DIR = Path(os.getenv("SCAIL_OUTPUT_DIR", str(STORAGE_ROOT / "scail2_outputs")))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

MODEL_REPO_ID = os.getenv("SCAIL_MODEL_REPO_ID", "zai-org/SCAIL-2")
SAFETENSORS_REPO_ID = os.getenv("SCAIL_SAFETENSORS_REPO_ID")
SAFETENSORS_FILENAME = os.getenv("SCAIL_SAFETENSORS_FILENAME", "SCAIL-2.safetensors")
MODEL_NAME = os.getenv("SCAIL_MODEL_NAME", "SCAIL-14B")
GPU_SIZE = os.getenv("SCAIL_ZEROGPU_SIZE", "xlarge")
GPU_DURATION_COLD = int(os.getenv("SCAIL_GPU_DURATION_COLD", "600"))
GPU_DURATION_WARM = int(os.getenv("SCAIL_GPU_DURATION_WARM", "330"))
DEFAULT_TARGET_H = int(os.getenv("SCAIL_TARGET_H", "512"))
DEFAULT_TARGET_W = int(os.getenv("SCAIL_TARGET_W", "896"))
DEFAULT_SEGMENT_LEN = int(os.getenv("SCAIL_SEGMENT_LEN", "81"))
DEFAULT_SEGMENT_OVERLAP = int(os.getenv("SCAIL_SEGMENT_OVERLAP", "5"))
DEFAULT_SHIFT = float(os.getenv("SCAIL_SAMPLE_SHIFT", "3.0"))
DEFAULT_GUIDE_SCALE = float(os.getenv("SCAIL_GUIDE_SCALE", "5.0"))
DEFAULT_SOLVER = os.getenv("SCAIL_SAMPLE_SOLVER", "unipc")
# Boolean switches are enabled only by the exact string "1".
AUTO_CONVERT = os.getenv("SCAIL_AUTO_CONVERT", "1") == "1"
PRELOAD_PIPELINE = os.getenv("SCAIL_PRELOAD_PIPELINE", "1") == "1"
STAGE_SAFETENSORS_FOR_LOAD = os.getenv("SCAIL_STAGE_SAFETENSORS_FOR_LOAD", "1") == "1"
CONVERT_TO_STAGING_FIRST = os.getenv("SCAIL_CONVERT_TO_STAGING_FIRST", "1") == "1"

CLIP_CKPT_NAME = "models_clip_open-clip-xlm-roberta-large-vit-huge-14-onlyvisual.pth"
ORIGINAL_DIT_REL_PATH = "model/1/fsdp2_rank_0000_checkpoint.pt"
# Files fetched from MODEL_REPO_ID for the non-DiT parts of the pipeline.
BASE_ALLOW_PATTERNS = [
    "Wan2.1_VAE.pth",
    "umt5-xxl/**",
    CLIP_CKPT_NAME,
]

# Process-wide caches and human-readable startup status strings.
_PIPELINE = None
_PIPELINE_KEY = None
_ASSET_STATUS = "Assets were not prepared yet."
_ASSET_ERROR = None
_RUNTIME_STATUS = "Runtime was not prepared yet."
_RUNTIME_ERROR = None
_PIPELINE_STATUS = "Pipeline was not preloaded."
_PIPELINE_ERROR = None
_LAST_CONVERTED_SAFETENSORS = None

# Lazily imported repo modules (populated by _import_runtime).
_WAN = None
_GENERATE_VIDEO = None
_SCAIL_CONFIGS = None
_SCAIL_CONFIG_PATHS = None


@dataclass(frozen=True)
class PreparedExample:
    """Paths (relative to ROOT) and prompt for one bundled demo example."""

    label: str
    image: str
    mask_image: str
    pose: str
    mask_video: str
    prompt: str
    replace_flag: bool = False


PREPARED_EXAMPLES = {
    "Animation 001 - end-to-end": PreparedExample(
        label="Animation 001 - end-to-end",
        image="examples/animation_001/ref.jpg",
        mask_image="examples/animation_001/ref_mask.jpg",
        pose="examples/animation_001/rendered_v2.mp4",
        mask_video="examples/animation_001/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 001 - pose-driven": PreparedExample(
        label="Animation 001 - pose-driven",
        image="examples/animation_001_posedriven/ref.jpg",
        mask_image="examples/animation_001_posedriven/ref_mask.jpg",
        pose="examples/animation_001_posedriven/rendered_v2.mp4",
        mask_video="examples/animation_001_posedriven/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 002 - end-to-end": PreparedExample(
        label="Animation 002 - end-to-end",
        image="examples/animation_002/ref.jpg",
        mask_image="examples/animation_002/ref_mask.jpg",
        pose="examples/animation_002/rendered_v2.mp4",
        mask_video="examples/animation_002/rendered_mask_v2.mp4",
        prompt="A character performs the motion from the driving video.",
    ),
    "Replacement 001": PreparedExample(
        label="Replacement 001",
        image="examples/replace_001/ref.png",
        mask_image="examples/replace_001/ref_mask.png",
        pose="examples/replace_001/rendered_v2.mp4",
        mask_video="examples/replace_001/replace_mask.mp4",
        prompt=(
            "A blond white male wearing a black suit, trousers, and leather shoes "
            "is playing the violin on the street while pedestrians walk past him."
        ),
        replace_flag=True,
    ),
}


def _abs(path: str | Path) -> str:
    """Return *path* as an absolute string, resolving relative paths against ROOT."""
    path = Path(path)
    if not path.is_absolute():
        path = ROOT / path
    return str(path)


def _existing_examples() -> dict[str, PreparedExample]:
    """Return the prepared examples whose four input files all exist on disk."""
    available = {}
    for name, example in PREPARED_EXAMPLES.items():
        paths = [example.image, example.mask_image, example.pose, example.mask_video]
        if all(Path(_abs(p)).exists() for p in paths):
            available[name] = example
    return available


def _require_repo_layout():
    """Raise RuntimeError unless the SCAIL-2 repository files are present next to this file."""
    missing = []
    for rel in ("wan/scail.py", "wan/modules/model_scail2.py", "generate.py", "configs/config-14b.json"):
        if not (ROOT / rel).exists():
            missing.append(rel)
    if missing:
        raise RuntimeError(
            "This app.py is meant to live at the root of the SCAIL-2 repository. "
            f"Missing: {', '.join(missing)}"
        )


def _download_safetensors_if_configured() -> Path | None:
    """Download the pre-converted safetensors when SCAIL_SAFETENSORS_REPO_ID is set.

    Returns:
        The local file path, or ``None`` when no repo id is configured.
    """
    if not SAFETENSORS_REPO_ID:
        return None
    local_dir = Path(os.getenv("SCAIL_SAFETENSORS_CACHE", str(STORAGE_ROOT / "scail2_safetensors")))
    local_dir.mkdir(parents=True, exist_ok=True)
    local_path = local_dir / SAFETENSORS_FILENAME
    if local_path.exists():
        return local_path
    logging.info(
        "Downloading converted SCAIL-2 safetensors from %s/%s",
        SAFETENSORS_REPO_ID,
        SAFETENSORS_FILENAME,
    )
    downloaded = hf_hub_download(
        repo_id=SAFETENSORS_REPO_ID,
        filename=SAFETENSORS_FILENAME,
        local_dir=str(local_dir),
        local_dir_use_symlinks=False,
        resume_download=True,
    )
    return Path(downloaded)


def _find_converted_safetensors(ckpt_dir: Path | None) -> Path | None:
    """Return the first existing converted SCAIL-2 safetensors file.

    Candidates are checked in priority order: ``SCAIL_SAFETENSORS_PATH``, the
    file converted earlier in this process, well-known names under ROOT, the
    persistent conversion directory, and files inside *ckpt_dir*. When none
    exists, falls back to :func:`_download_safetensors_if_configured`.
    """
    candidates = []
    env_path = os.getenv("SCAIL_SAFETENSORS_PATH")
    if env_path:
        candidates.append(Path(env_path))
    if _LAST_CONVERTED_SAFETENSORS is not None:
        candidates.append(Path(_LAST_CONVERTED_SAFETENSORS))
    candidates += [
        ROOT / "SCAIL-2.safetensors",
        ROOT / "models" / "SCAIL-2.safetensors",
        ROOT / "model.safetensors",
        Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted"))) / "SCAIL-2.safetensors",
    ]
    if ckpt_dir is not None:
        candidates += [
            ckpt_dir / "SCAIL-2.safetensors",
            ckpt_dir / "model.safetensors",
        ]
    for candidate in candidates:
        if candidate.exists():
            return candidate
    return _download_safetensors_if_configured()


def _copy_file_with_progress(source: Path, dest: Path, description: str) -> Path:
    """Copy *source* to *dest* in chunks, logging progress and resuming partial copies.

    Data is appended to ``<dest>.tmp`` and renamed onto *dest* only after the
    full size has been written, so *dest* never holds a truncated copy. If
    *dest* already exists with the source's size, the copy is skipped.

    Raises:
        RuntimeError: on a premature EOF or a final size mismatch.
    """
    source_size = source.stat().st_size
    chunk_size = int(os.getenv("SCAIL_STAGE_COPY_CHUNK_MB", "64")) * 1024 * 1024
    log_every = int(os.getenv("SCAIL_STAGE_COPY_LOG_GB", "1")) * 1024 * 1024 * 1024
    dest.parent.mkdir(parents=True, exist_ok=True)
    if dest.exists() and dest.stat().st_size == source_size:
        return dest

    tmp_dest = dest.with_suffix(dest.suffix + ".tmp")
    copied = tmp_dest.stat().st_size if tmp_dest.exists() else 0
    if copied > source_size:
        # A leftover temp file larger than the source cannot be a valid prefix.
        tmp_dest.unlink()
        copied = 0

    logging.info("%s: %s -> %s", description, source, dest)
    next_log = ((copied // log_every) + 1) * log_every if log_every > 0 else source_size
    if copied:
        logging.info(
            "Resuming copy at %.2f/%.2f GB",
            copied / 1024**3,
            source_size / 1024**3,
        )
    with source.open("rb") as src, tmp_dest.open("ab") as dst:
        if copied:
            src.seek(copied)
        while copied < source_size:
            chunk = src.read(min(chunk_size, source_size - copied))
            if not chunk:
                raise RuntimeError(f"Unexpected EOF while copying {source}: {copied} of {source_size} bytes")
            dst.write(chunk)
            copied += len(chunk)
            if log_every > 0 and copied >= next_log:
                logging.info(
                    "%s: %.2f/%.2f GB",
                    description,
                    copied / 1024**3,
                    source_size / 1024**3,
                )
                next_log += log_every

    if tmp_dest.stat().st_size != source_size:
        raise RuntimeError(f"Copied file size mismatch: {tmp_dest.stat().st_size} != {source_size}")
    tmp_dest.replace(dest)
    logging.info("Finished %s: %s", description, dest)
    return dest


def _is_relative_to(path: Path, parent: Path) -> bool:
    """Return True when the resolved *path* lies inside the resolved *parent*."""
    try:
        path.resolve().relative_to(parent.resolve())
        return True
    except ValueError:
        return False


def _stage_safetensors_for_load(scail_path: Path) -> Path:
    """Copy the safetensors file under STAGING_ROOT before loading, when enabled.

    Returns *scail_path* unchanged when staging is disabled or the file already
    lives under STAGING_ROOT.
    """
    if not STAGE_SAFETENSORS_FOR_LOAD:
        return scail_path
    source = Path(scail_path)
    if _is_relative_to(source, STAGING_ROOT):
        return source
    stage_dir = Path(os.getenv("SCAIL_MODEL_LOAD_CACHE", str(STAGING_ROOT / "scail2_model_load")))
    stage_dir.mkdir(parents=True, exist_ok=True)
    staged = stage_dir / source.name
    if staged.exists() and staged.stat().st_size == source.stat().st_size:
        return staged
    return _copy_file_with_progress(source, staged, "Staging SCAIL-2 safetensors for local load")


def _download_checkpoint_if_needed(include_original_dit: bool = False) -> Path:
    """Return a directory holding the VAE, UMT5 and CLIP assets, downloading them if missing.

    ``SCAIL_CKPT_DIR`` overrides the download entirely. *include_original_dit*
    is deprecated and only triggers a warning; the original DiT checkpoint is
    fetched separately by :func:`_download_original_dit_for_conversion`.

    Raises:
        RuntimeError: when ``SCAIL_CKPT_DIR`` points to a missing directory.
    """
    env_dir = os.getenv("SCAIL_CKPT_DIR")
    if env_dir:
        ckpt_dir = Path(env_dir)
        if not ckpt_dir.exists():
            raise RuntimeError(f"SCAIL_CKPT_DIR does not exist: {ckpt_dir}")
        return ckpt_dir

    local_dir = Path(os.getenv("SCAIL_CKPT_CACHE", str(STORAGE_ROOT / "scail2_ckpt")))
    has_base_assets = (
        (local_dir / "Wan2.1_VAE.pth").exists()
        and (local_dir / "umt5-xxl").exists()
        and (local_dir / CLIP_CKPT_NAME).exists()
    )
    if has_base_assets:
        return local_dir

    logging.info("Downloading SCAIL-2 base checkpoint assets from %s", MODEL_REPO_ID)
    if include_original_dit:
        logging.warning(
            "include_original_dit is deprecated here; original DiT conversion staging "
            "uses SCAIL_ORIGINAL_DIT_CACHE instead."
        )
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(local_dir),
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=BASE_ALLOW_PATTERNS,
    )
    return local_dir


def _download_original_dit_for_conversion() -> Path:
    """Return a directory containing ORIGINAL_DIT_REL_PATH, downloading it if needed.

    Raises:
        RuntimeError: when ``SCAIL_ORIGINAL_DIT_DIR`` is set but lacks the checkpoint.
    """
    env_dir = os.getenv("SCAIL_ORIGINAL_DIT_DIR")
    if env_dir:
        original_dir = Path(env_dir)
        original_path = original_dir / ORIGINAL_DIT_REL_PATH
        if not original_path.exists():
            raise RuntimeError(f"SCAIL_ORIGINAL_DIT_DIR is missing {ORIGINAL_DIT_REL_PATH}: {original_dir}")
        return original_dir

    local_dir = Path(os.getenv("SCAIL_ORIGINAL_DIT_CACHE", str(STAGING_ROOT / "scail2_original_dit")))
    original_path = local_dir / ORIGINAL_DIT_REL_PATH
    if original_path.exists():
        return local_dir

    logging.info(
        "Downloading original SCAIL-2 DiT checkpoint for one-time conversion into %s",
        local_dir,
    )
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(local_dir),
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=[ORIGINAL_DIT_REL_PATH],
    )
    return local_dir


def _resolve_scail_safetensors(ckpt_dir: Path) -> Path | None:
    """Find the converted DiT safetensors, converting the original checkpoint if allowed.

    Shared by startup asset preparation and pipeline loading so both follow
    the same lookup order.

    Returns:
        The safetensors path, or ``None`` only when no converted file exists
        and AUTO_CONVERT is disabled.
    """
    scail_path = _find_converted_safetensors(ckpt_dir)
    if scail_path is None and AUTO_CONVERT:
        original_dir = _download_original_dit_for_conversion()
        scail_path = _maybe_convert_checkpoint(original_dir, None)
    return scail_path


def _prepare_assets_for_runtime() -> str:
    """Download CPU-side assets before any ZeroGPU function starts."""
    global _ASSET_STATUS, _ASSET_ERROR
    try:
        ckpt_dir = _download_checkpoint_if_needed(include_original_dit=False)
        scail_path = _resolve_scail_safetensors(ckpt_dir)
        if scail_path is None:
            _ASSET_STATUS = (
                "Base checkpoint assets are present, but no converted safetensors file "
                "was found. Set SCAIL_SAFETENSORS_REPO_ID plus SCAIL_SAFETENSORS_FILENAME, "
                "or set SCAIL_SAFETENSORS_PATH. Automatic startup conversion is disabled "
                "because SCAIL_AUTO_CONVERT=0."
            )
        else:
            _ASSET_STATUS = (
                "Assets ready. Base checkpoint: "
                f"{ckpt_dir}. Converted DiT safetensors: {scail_path}."
            )
        _ASSET_ERROR = None
    except Exception:
        _ASSET_ERROR = traceback.format_exc()
        _ASSET_STATUS = "Asset preparation failed. See the traceback below."
        logging.exception("Asset preparation failed")
    return _ASSET_STATUS if _ASSET_ERROR is None else _ASSET_STATUS + "\n\n" + _ASSET_ERROR


def _maybe_convert_checkpoint(ckpt_dir: Path, scail_path: Path | None) -> Path:
    """Return *scail_path*, or convert the original checkpoint in *ckpt_dir* to safetensors.

    Conversion runs ``convert.py`` in a subprocess. With CONVERT_TO_STAGING_FIRST
    the output is written under STAGING_ROOT and then copied to the persistent
    SCAIL_CONVERTED_DIR.

    Raises:
        RuntimeError: when no converted file is available and AUTO_CONVERT is off.
        subprocess.CalledProcessError: when ``convert.py`` exits non-zero.
    """
    global _LAST_CONVERTED_SAFETENSORS
    if scail_path is not None:
        return scail_path
    if not AUTO_CONVERT:
        raise RuntimeError(
            "Converted SCAIL-2 safetensors file was not found. For the wan-scail2 branch, "
            "provide SCAIL_SAFETENSORS_PATH=/path/to/SCAIL-2.safetensors, or place "
            "SCAIL-2.safetensors at the repo root. Automatic startup conversion is disabled "
            "because SCAIL_AUTO_CONVERT=0."
        )

    persistent_dir = Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted")))
    persistent_path = persistent_dir / "SCAIL-2.safetensors"
    if persistent_path.exists():
        return persistent_path

    if CONVERT_TO_STAGING_FIRST:
        save_dir = Path(os.getenv("SCAIL_CONVERSION_WORK_DIR", str(STAGING_ROOT / "scail2_converted_work")))
    else:
        save_dir = persistent_dir
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / "SCAIL-2.safetensors"

    if save_path.exists():
        _LAST_CONVERTED_SAFETENSORS = save_path
        if save_path != persistent_path:
            _copy_file_with_progress(save_path, persistent_path, "Persisting converted SCAIL-2 safetensors to storage")
        return save_path

    logging.info("Converting checkpoint to safetensors: %s", save_path)
    # Convert into a sibling work file and rename it on success. The existence
    # checks above only look at the final name, so an interrupted conversion
    # must never leave a truncated SCAIL-2.safetensors behind.
    work_path = save_path.with_name(f"{save_path.stem}.partial{save_path.suffix}")
    work_path.unlink(missing_ok=True)
    convert_env = os.environ.copy()
    # PyTorch >= 2.6 changed torch.load's default to weights_only=True.
    # The official SCAIL-2 FSDP checkpoint needs the legacy trusted-pickle path
    # during conversion. Only do this for the official checkpoint conversion step.
    convert_env["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
    subprocess.run(
        [
            sys.executable,
            str(ROOT / "convert.py"),
            "--scail-dir",
            str(ckpt_dir),
            "--save-path",
            str(work_path),
        ],
        check=True,
        cwd=str(ROOT),
        env=convert_env,
    )
    work_path.replace(save_path)
    _LAST_CONVERTED_SAFETENSORS = save_path
    if save_path != persistent_path:
        _copy_file_with_progress(save_path, persistent_path, "Persisting converted SCAIL-2 safetensors to storage")
    return save_path


def _install_attention_patch():
    """Use HF Kernels flash-attn2 when available; otherwise fall back to SDPA.

    SCAIL-2 imports `flash_attention` directly in model_scail2.py. This monkey
    patch avoids requiring a locally built `flash_attn` wheel on Spaces.
    """
    import wan.modules.attention as attention_mod

    hf_flash_attn2 = None
    try:
        from kernels import get_kernel

        hf_flash_attn2 = get_kernel("kernels-community/flash-attn2", version=2)
        logging.info("Using kernels-community/flash-attn2 through HF Kernels.")
    except Exception as exc:
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            capability = torch.cuda.get_device_capability(0)
        else:
            device_name = "no cuda"
            capability = None
        logging.warning("Could not initialize HF Kernels flash-attn2: %r", exc)
        logging.warning(
            "Attention fallback environment: torch=%s cuda=%s device=%s capability=%s",
            torch.__version__,
            torch.version.cuda,
            device_name,
            capability,
        )

    # The patched function runs on every attention call; warn about a kernel
    # failure once and keep later repeats at DEBUG to avoid flooding the logs.
    kernel_failure_logged = False

    def patched_flash_attention(
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        q_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
        dtype=torch.bfloat16,
        version=None,
    ):
        # q/k/v are laid out as (batch, seq, ...): flatten(0, 1) below packs them
        # for the varlen kernel, and transpose(1, 2) moves seq behind the heads
        # axis for SDPA.
        nonlocal kernel_failure_logged
        half_dtypes = (torch.float16, torch.bfloat16)
        b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

        def half(x):
            return x if x.dtype in half_dtypes else x.to(dtype)

        if hf_flash_attn2 is not None and q.device.type == "cuda":
            if q_lens is None:
                q_var = half(q.flatten(0, 1))
                q_lens_t = torch.full((b,), lq, dtype=torch.int32, device=q.device)
            else:
                q_lens_t = q_lens.to(device=q.device, dtype=torch.int32)
                q_var = half(torch.cat([u[: int(n)] for u, n in zip(q, q_lens_t)]))

            if k_lens is None:
                k_var = half(k.flatten(0, 1))
                v_var = half(v.flatten(0, 1))
                k_lens_t = torch.full((b,), lk, dtype=torch.int32, device=k.device)
            else:
                k_lens_t = k_lens.to(device=k.device, dtype=torch.int32)
                k_var = half(torch.cat([u[: int(n)] for u, n in zip(k, k_lens_t)]))
                v_var = half(torch.cat([u[: int(n)] for u, n in zip(v, k_lens_t)]))

            q_var = q_var.to(v_var.dtype)
            k_var = k_var.to(v_var.dtype)
            if q_scale is not None:
                q_var = q_var * q_scale

            # Cumulative sequence offsets expected by the varlen kernel.
            cu_q = torch.cat([q_lens_t.new_zeros([1]), q_lens_t]).cumsum(0, dtype=torch.int32)
            cu_k = torch.cat([k_lens_t.new_zeros([1]), k_lens_t]).cumsum(0, dtype=torch.int32)
            try:
                out = hf_flash_attn2.flash_attn_varlen_func(
                    q=q_var,
                    k=k_var,
                    v=v_var,
                    cu_seqlens_q=cu_q,
                    cu_seqlens_k=cu_k,
                    max_seqlen_q=lq,
                    max_seqlen_k=lk,
                    dropout_p=dropout_p,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=window_size,
                    deterministic=deterministic,
                )
                if isinstance(out, tuple):
                    out = out[0]
                return out.unflatten(0, (b, lq)).type(out_dtype)
            except Exception as exc:
                if not kernel_failure_logged:
                    kernel_failure_logged = True
                    logging.warning("HF Kernels flash-attn2 failed, falling back to SDPA: %s", exc)
                else:
                    logging.debug("HF Kernels flash-attn2 failed, falling back to SDPA: %s", exc)

        # --- SDPA fallback ---
        if q_lens is not None and not torch.all(q_lens == lq):
            logging.warning("SDPA fallback ignores variable q_lens; demo batch size should stay at 1.")

        q_sdpa = q.transpose(1, 2).to(dtype)
        if q_scale is not None:
            # Keep parity with the flash path, which multiplies the query by q_scale.
            q_sdpa = q_sdpa * q_scale
        k_sdpa = k.transpose(1, 2).to(dtype)
        v_sdpa = v.transpose(1, 2).to(dtype)

        attn_mask = None
        is_causal = causal
        if k_lens is not None:
            k_lens_dev = k_lens.to(device=q.device)
            if not torch.all(k_lens_dev == lk):
                # Key-padding mask (True = key participates), matching how the
                # varlen path drops keys beyond k_lens. Shape (b, 1, 1, lk)
                # broadcasts over heads and query positions.
                # NOTE(review): a row with k_lens == 0 would be fully masked -> NaN.
                key_idx = torch.arange(lk, device=q.device)
                attn_mask = (key_idx[None, :] < k_lens_dev[:, None])[:, None, None, :]
                if causal:
                    # SDPA rejects attn_mask together with is_causal=True, so fold
                    # the (top-left aligned) causal constraint into the mask.
                    causal_mask = torch.ones(lq, lk, dtype=torch.bool, device=q.device).tril()
                    attn_mask = attn_mask & causal_mask
                    is_causal = False

        out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa,
            k_sdpa,
            v_sdpa,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=softmax_scale,
        )
        return out.transpose(1, 2).contiguous().type(out_dtype)

    attention_mod.flash_attention = patched_flash_attention
    # Several modules import `flash_attention` directly at import time
    # (`from .attention import flash_attention`), so patch those bound names too.
    for module_name in (
        "wan.modules.clip",
        "wan.modules.model",
        "wan.modules.model_scail",
        "wan.modules.model_scail2",
    ):
        try:
            module = __import__(module_name, fromlist=["flash_attention"])
            if hasattr(module, "flash_attention"):
                module.flash_attention = patched_flash_attention
                logging.info("Patched %s.flash_attention", module_name)
        except Exception as exc:
            logging.warning("Could not patch %s.flash_attention: %s", module_name, exc)


def _import_runtime():
    """Import the repo's wan/generate modules once and install the attention patch."""
    global _WAN, _GENERATE_VIDEO, _SCAIL_CONFIGS, _SCAIL_CONFIG_PATHS
    if _WAN is not None:
        return
    _require_repo_layout()
    if str(ROOT) not in sys.path:
        sys.path.insert(0, str(ROOT))
    import wan
    from generate import generate_video
    from wan.configs import SCAIL_CONFIGS, SCAIL_CONFIG_PATHS

    _install_attention_patch()
    _WAN = wan
    _GENERATE_VIDEO = generate_video
    _SCAIL_CONFIGS = SCAIL_CONFIGS
    _SCAIL_CONFIG_PATHS = SCAIL_CONFIG_PATHS


def _prepare_runtime_for_startup() -> str:
    """Import Wan/SCAIL and resolve HF Kernels before any ZeroGPU call."""
    global _RUNTIME_STATUS, _RUNTIME_ERROR
    try:
        _import_runtime()
        _RUNTIME_STATUS = "Runtime ready. Attention backend has been initialized at startup."
        _RUNTIME_ERROR = None
    except Exception:
        _RUNTIME_ERROR = traceback.format_exc()
        _RUNTIME_STATUS = "Runtime preparation failed. See the traceback below."
        logging.exception("Runtime preparation failed")
    return _RUNTIME_STATUS if _RUNTIME_ERROR is None else _RUNTIME_STATUS + "\n\n" + _RUNTIME_ERROR


def _get_pipeline():
    """Return the cached SCAIL2Pipeline, (re)building it when its inputs changed.

    The cache key covers the checkpoint dir, staged safetensors path, config
    path and LoRA settings.

    Raises:
        RuntimeError: when no converted safetensors is available and AUTO_CONVERT is off.
    """
    global _PIPELINE, _PIPELINE_KEY
    _import_runtime()
    ckpt_dir = _download_checkpoint_if_needed(include_original_dit=False)
    scail_path = _resolve_scail_safetensors(ckpt_dir)
    if scail_path is None:
        # Only reachable with AUTO_CONVERT disabled; raises the descriptive error.
        scail_path = _maybe_convert_checkpoint(ckpt_dir, None)
    scail_load_path = _stage_safetensors_for_load(scail_path)

    config_path = Path(os.getenv("SCAIL_CONFIG_PATH", _SCAIL_CONFIG_PATHS[MODEL_NAME]))
    if not config_path.is_absolute():
        config_path = ROOT / config_path
    lora_path = os.getenv("SCAIL_LORA_PATH") or None
    lora_alpha = float(os.getenv("SCAIL_LORA_ALPHA", "1.0"))

    key = (str(ckpt_dir), str(scail_load_path), str(config_path), lora_path, lora_alpha)
    if _PIPELINE is not None and _PIPELINE_KEY == key:
        return _PIPELINE

    logging.info("Loading SCAIL-2 pipeline.")
    cfg = _SCAIL_CONFIGS[MODEL_NAME]
    _PIPELINE = _WAN.SCAIL2Pipeline(
        config=cfg,
        checkpoint_dir=str(ckpt_dir),
        scail_safetensors_path=str(scail_load_path),
        scail_config_path=str(config_path),
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        lora_path=lora_path,
        lora_alpha=lora_alpha,
    )
    _PIPELINE_KEY = key
    return _PIPELINE


def _prepare_pipeline_for_startup() -> str:
    """Preload the CUDA pipeline at module startup for ZeroGPU CUDA emulation."""
    global _PIPELINE_STATUS, _PIPELINE_ERROR
    try:
        _get_pipeline()
        _PIPELINE_STATUS = "Pipeline preloaded at startup."
        _PIPELINE_ERROR = None
    except Exception:
        _PIPELINE_ERROR = traceback.format_exc()
        _PIPELINE_STATUS = "Pipeline preload failed. See the traceback below."
        logging.exception("Pipeline preload failed")
    return _PIPELINE_STATUS if _PIPELINE_ERROR is None else _PIPELINE_STATUS + "\n\n" + _PIPELINE_ERROR


def _is_gradio_native_file_path(path: Path) -> bool:
    """Return True when Gradio can serve a file without explicit allowed_paths."""
    path = path.resolve()
    native_roots = [ROOT.resolve(), Path(tempfile.gettempdir()).resolve()]
    return any(_is_relative_to(path, root) for root in native_roots)


def _prepare_output_for_gradio(path: str | Path) -> str:
    """Return a display path accepted by Gradio.

    The canonical output is saved in OUTPUT_DIR, which may live on a mounted
    Bucket such as /data/scail2_outputs. Gradio can be strict about returned
    file paths, so we keep the persistent copy in the Bucket and return a
    temporary copy under /tmp for the UI component.
    """
    source = Path(path)
    if not source.exists():
        raise RuntimeError(f"Generated video was not found: {source}")
    if _is_gradio_native_file_path(source):
        return str(source)

    gradio_dir = Path(
        os.getenv(
            "SCAIL_GRADIO_OUTPUT_CACHE",
            str(Path(tempfile.gettempdir()) / "scail2_gradio_outputs"),
        )
    )
    gradio_dir.mkdir(parents=True, exist_ok=True)
    dest = gradio_dir / source.name
    shutil.copy2(source, dest)
    logging.info("Copied generated video for Gradio display: %s -> %s", source, dest)
    return str(dest)


def _duration_for_job(*args, **kwargs):
    """ZeroGPU duration callback: a longer budget while the pipeline is not loaded yet.

    ``SCAIL_GPU_DURATION`` overrides both the cold and warm defaults.
    """
    if _PIPELINE is None:
        return int(os.getenv("SCAIL_GPU_DURATION", str(GPU_DURATION_COLD)))
    return int(os.getenv("SCAIL_GPU_DURATION", str(GPU_DURATION_WARM)))


def _parse_target_size(target_size) -> tuple[int, int]:
    """Parse a ``"WIDTHxHEIGHT"`` dropdown value into ``(width, height)`` ints."""
    target_w, target_h = [int(v) for v in str(target_size).split("x")]
    return target_w, target_h


def _run_scail_job(
    image_path,
    mask_image_path,
    pose_path,
    mask_video_path,
    prompt,
    replace_flag,
    target_h,
    target_w,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    segment_len,
    segment_overlap,
    progress=None,
):
    """Run one SCAIL-2 generation and return a Gradio-displayable video path.

    The video is written to OUTPUT_DIR under a unique name. Python garbage and
    cached CUDA memory are released afterwards, whether or not generation succeeds.
    """
    if progress is not None:
        progress(0.02, desc="Loading SCAIL-2 pipeline")
    pipeline = _get_pipeline()
    cfg = _SCAIL_CONFIGS[MODEL_NAME]
    save_file = OUTPUT_DIR / f"scail2_{uuid.uuid4().hex}.mp4"

    if progress is not None:
        progress(0.12, desc="Preparing inputs")
    args = SimpleNamespace(
        target_h=int(target_h),
        target_w=int(target_w),
        sample_shift=float(sample_shift),
        sample_solver=DEFAULT_SOLVER,
        segment_len=int(segment_len),
        segment_overlap=int(segment_overlap),
        sample_steps=int(sample_steps),
        sample_guide_scale=float(guide_scale),
        base_seed=int(seed),
        offload_model=True,
        save_file=str(save_file),
        save_dir=str(OUTPUT_DIR),
        prompt=prompt or "",
    )

    if progress is not None:
        progress(0.15, desc="Generating video")
    try:
        _GENERATE_VIDEO(
            pipeline,
            prompt or "",
            str(image_path),
            str(mask_image_path),
            str(pose_path),
            str(mask_video_path),
            args,
            device=0,
            rank=0,
            cfg=cfg,
            input_idx=None,
            replace_flag=bool(replace_flag),
            additional_task_input=None,
        )
        if progress is not None:
            progress(0.95, desc="Finalizing output")
    finally:
        # Release memory even when generation raises, so a failed run does not
        # leave cached allocations behind for the next request in this process.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if progress is not None:
        progress(0.98, desc="Preparing video for display")
    display_file = _prepare_output_for_gradio(save_file)
    if progress is not None:
        progress(1.0, desc="Done")
    return display_file


@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_example(
    example_name,
    prompt,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """GPU entry point for the Prepared Examples tab.

    Returns:
        ``(video_path, "Done.")`` on success, ``(None, traceback_text)`` on failure.
    """
    try:
        progress(0.0, desc="Checking example inputs")
        examples = _existing_examples()
        if example_name not in examples:
            raise RuntimeError(f"Example is missing from this checkout: {example_name}")
        example = examples[example_name]
        target_w, target_h = _parse_target_size(target_size)
        output = _run_scail_job(
            _abs(example.image),
            _abs(example.mask_image),
            _abs(example.pose),
            _abs(example.mask_video),
            prompt,
            example.replace_flag,
            target_h,
            target_w,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
        return output, "Done."
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()


@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_uploads(
    image,
    mask_image,
    pose_video,
    mask_video,
    prompt,
    mode,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """GPU entry point for the Advanced Uploads tab.

    *mode* ``"replacement"`` enables character replacement; any other value
    runs animation.

    Returns:
        ``(video_path, "Done.")`` on success, ``(None, traceback_text)`` on failure.
    """
    try:
        progress(0.0, desc="Checking uploaded inputs")
        required = {
            "reference image": image,
            "reference mask": mask_image,
            "driving/rendered video": pose_video,
            "driving mask video": mask_video,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise RuntimeError("Missing required input(s): " + ", ".join(missing))
        target_w, target_h = _parse_target_size(target_size)
        output = _run_scail_job(
            image,
            mask_image,
            pose_video,
            mask_video,
            prompt,
            mode == "replacement",
            target_h,
            target_w,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
        return output, "Done."
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()


def load_example_preview(example_name):
    """Return preview values for an example.

    The tuple is ``(image, pose, mask_image, mask_video, prompt, mode)``. An
    unknown example yields empty previews and ``"animation"``.
    """
    examples = _existing_examples()
    if example_name not in examples:
        return None, None, None, None, "", "animation"
    example = examples[example_name]
    mode = "replacement" if example.replace_flag else "animation"
    return (
        _abs(example.image),
        _abs(example.pose),
        _abs(example.mask_image),
        _abs(example.mask_video),
        example.prompt,
        mode,
    )


def _startup_message():
    """Build the status text shown in the UI, including the asset/runtime/pipeline status."""
    try:
        _require_repo_layout()
        examples = _existing_examples()
        if not examples:
            return "Repo layout detected, but no prepared examples were found."
        return (
            f"Ready. Found {len(examples)} prepared example(s). "
            f"Storage root: {STORAGE_ROOT}. Staging root: {STAGING_ROOT}. "
            f"Output dir: {OUTPUT_DIR}.\n\n"
            f"{_ASSET_STATUS}\n\n{_RUNTIME_STATUS}\n\n{_PIPELINE_STATUS}\n\n"
            "Attention backend: HF Kernels flash-attn2 when available, otherwise SDPA."
        )
    except Exception as exc:
        return str(exc)


def build_ui():
    """Construct the Gradio Blocks app with the Prepared Examples and Advanced Uploads tabs."""
    examples = _existing_examples()
    example_names = list(examples.keys())
    default_example = example_names[0] if example_names else None
    default_preview = (
        load_example_preview(default_example)
        if default_example
        else (None, None, None, None, "", "animation")
    )

    with gr.Blocks(title="SCAIL-2 ZeroGPU Demo") as demo:
        gr.Markdown(
            "# SCAIL-2 Character Animation Demo\n"
            "Animate or replace characters with SCAIL-2, directly in the browser. "
            "The easiest way to try the model is to start from one of the prepared examples below.\n\n"
            "This Space runs a heavy 14B video model on Hugging Face ZeroGPU. "
            "Generation is not realtime: a short example can take a few minutes, especially "
            "when the Space has just started.\n\n"
            "For custom inputs, use the Advanced Uploads tab with already-prepared SCAIL-2 "
            "conditions: reference image, reference mask, driving video, and driving mask. "
            "Automatic SCAIL-Pose/SAM3 preprocessing is not enabled in this public demo yet."
        )
        startup = gr.Textbox(value=_startup_message(), label="Startup status", interactive=False)

        with gr.Tab("Prepared Examples"):
            with gr.Row():
                example_dropdown = gr.Dropdown(
                    choices=example_names,
                    value=default_example,
                    label="Example",
                )
                mode_view = gr.Textbox(value=default_preview[5], label="Mode", interactive=False)
            with gr.Row():
                ref_preview = gr.Image(value=default_preview[0], label="Reference", interactive=False)
                driving_preview = gr.Video(value=default_preview[1], label="Driving / rendered video")
            with gr.Row():
                ref_mask_preview = gr.Image(value=default_preview[2], label="Reference mask", interactive=False)
                driving_mask_preview = gr.Video(value=default_preview[3], label="Driving mask")
            prompt = gr.Textbox(value=default_preview[4], label="Prompt", lines=3)
            with gr.Row():
                sample_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                guide_scale = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                sample_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                seed = gr.Number(value=42, precision=0, label="Seed")
                target_size = gr.Dropdown(
                    ["896x512", "512x896", "1280x704", "704x1280"],
                    value=f"{DEFAULT_TARGET_W}x{DEFAULT_TARGET_H}",
                    label="Target size",
                )
                segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")
            run_example = gr.Button("Generate", variant="primary")
            output_video = gr.Video(label="Output")
            status = gr.Textbox(label="Run status", lines=8)

            example_dropdown.change(
                load_example_preview,
                inputs=[example_dropdown],
                outputs=[ref_preview, driving_preview, ref_mask_preview, driving_mask_preview, prompt, mode_view],
            )
            run_example.click(
                generate_from_example,
                inputs=[
                    example_dropdown,
                    prompt,
                    sample_steps,
                    guide_scale,
                    sample_shift,
                    seed,
                    target_size,
                    segment_len,
                    segment_overlap,
                ],
                outputs=[output_video, status],
            )

        with gr.Tab("Advanced Uploads"):
            gr.Markdown(
                "Upload custom inputs that have already been prepared for SCAIL-2: reference image, "
                "reference mask, driving/rendered video, and driving mask video. "
                "This public demo does not run SCAIL-Pose/SAM3 preprocessing yet."
            )
            with gr.Row():
                up_image = gr.Image(type="filepath", label="Reference image")
                up_mask_image = gr.Image(type="filepath", label="Reference mask")
            with gr.Row():
                up_pose_video = gr.Video(label="Driving / rendered video")
                up_mask_video = gr.Video(label="Driving mask / replace mask")
            up_mode = gr.Radio(["animation", "replacement"], value="animation", label="Mode")
            up_prompt = gr.Textbox(label="Prompt", lines=3)
            with gr.Row():
                up_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                up_cfg = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                up_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                up_seed = gr.Number(value=42, precision=0, label="Seed")
                up_target_size = gr.Dropdown(
                    ["896x512", "512x896", "1280x704", "704x1280"],
                    value=f"{DEFAULT_TARGET_W}x{DEFAULT_TARGET_H}",
                    label="Target size",
                )
                up_segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                up_segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")
            run_upload = gr.Button("Generate from uploads", variant="primary")
            upload_output = gr.Video(label="Output")
            upload_status = gr.Textbox(label="Run status", lines=8)

            run_upload.click(
                generate_from_uploads,
                inputs=[
                    up_image,
                    up_mask_image,
                    up_pose_video,
                    up_mask_video,
                    up_prompt,
                    up_mode,
                    up_steps,
                    up_cfg,
                    up_shift,
                    up_seed,
                    up_target_size,
                    up_segment_len,
                    up_segment_overlap,
                ],
                outputs=[upload_output, upload_status],
            )

    return demo


if __name__ == "__main__":
    if os.getenv("SCAIL_PRELOAD_ASSETS", "1") == "1":
        _prepare_assets_for_runtime()
    if os.getenv("SCAIL_PRELOAD_RUNTIME", "1") == "1":
        _prepare_runtime_for_startup()
    if PRELOAD_PIPELINE:
        _prepare_pipeline_for_startup()
    build_ui().queue(max_size=8).launch(
        allowed_paths=[str(OUTPUT_DIR.resolve())],
        show_error=True,
    )