| import gc |
| import logging |
| import os |
| import shutil |
| import subprocess |
| import sys |
| import tempfile |
| import traceback |
| import uuid |
| from dataclasses import dataclass |
| from pathlib import Path |
| from types import SimpleNamespace |
|
|
| import gradio as gr |
| import spaces |
| import torch |
| from huggingface_hub import hf_hub_download, snapshot_download |
|
|
|
|
# Configure root logging at import time so every record carries a timestamp
# and its level name.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")


# Directory that contains this file. Relative example paths, config paths and
# the repository-layout check are all resolved against it.
ROOT = Path(__file__).resolve().parent
|
|
|
|
def _default_storage_root() -> Path:
    """Choose the base directory for persistent artifacts.

    Precedence:
      1. a non-empty ``SCAIL_STORAGE_ROOT`` environment variable,
      2. ``/data`` when it exists and is writable by this process,
      3. ``/tmp`` otherwise.
    """
    configured = os.getenv("SCAIL_STORAGE_ROOT")
    if configured:
        return Path(configured)

    mount = Path("/data")
    mount_usable = mount.exists() and os.access(mount, os.W_OK)
    return mount if mount_usable else Path("/tmp")
|
|
|
|
# --- Filesystem locations -------------------------------------------------
# Every value below is read from the environment once, at import time.
STORAGE_ROOT = _default_storage_root()
STAGING_ROOT = Path(os.getenv("SCAIL_STAGING_ROOT", "/tmp"))
OUTPUT_DIR = Path(os.getenv("SCAIL_OUTPUT_DIR", str(STORAGE_ROOT / "scail2_outputs")))
# Created eagerly (import-time side effect) so later writes never hit a missing directory.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


# --- Model / Hub identifiers ----------------------------------------------
MODEL_REPO_ID = os.getenv("SCAIL_MODEL_REPO_ID", "zai-org/SCAIL-2")
# No default: when unset, no pre-converted safetensors file is downloaded from the Hub.
SAFETENSORS_REPO_ID = os.getenv("SCAIL_SAFETENSORS_REPO_ID")
SAFETENSORS_FILENAME = os.getenv("SCAIL_SAFETENSORS_FILENAME", "SCAIL-2.safetensors")
# Key used to index the SCAIL config tables once the runtime has been imported.
MODEL_NAME = os.getenv("SCAIL_MODEL_NAME", "SCAIL-14B")

# --- ZeroGPU settings (passed to the spaces.GPU decorator) ----------------
GPU_SIZE = os.getenv("SCAIL_ZEROGPU_SIZE", "xlarge")
# Requested duration when no pipeline is loaded yet (COLD) vs. already loaded (WARM).
GPU_DURATION_COLD = int(os.getenv("SCAIL_GPU_DURATION_COLD", "600"))
GPU_DURATION_WARM = int(os.getenv("SCAIL_GPU_DURATION_WARM", "330"))

# --- Generation defaults shown in the UI ----------------------------------
# int()/float() raise ValueError at import time if an override is not numeric.
DEFAULT_TARGET_H = int(os.getenv("SCAIL_TARGET_H", "512"))
DEFAULT_TARGET_W = int(os.getenv("SCAIL_TARGET_W", "896"))
DEFAULT_SEGMENT_LEN = int(os.getenv("SCAIL_SEGMENT_LEN", "81"))
DEFAULT_SEGMENT_OVERLAP = int(os.getenv("SCAIL_SEGMENT_OVERLAP", "5"))
DEFAULT_SHIFT = float(os.getenv("SCAIL_SAMPLE_SHIFT", "3.0"))
DEFAULT_GUIDE_SCALE = float(os.getenv("SCAIL_GUIDE_SCALE", "5.0"))
DEFAULT_SOLVER = os.getenv("SCAIL_SAMPLE_SOLVER", "unipc")

# --- Feature flags ----------------------------------------------------------
# A flag is on only when its variable equals exactly "1"; any other value turns it off.
AUTO_CONVERT = os.getenv("SCAIL_AUTO_CONVERT", "1") == "1"
PRELOAD_PIPELINE = os.getenv("SCAIL_PRELOAD_PIPELINE", "1") == "1"
STAGE_SAFETENSORS_FOR_LOAD = os.getenv("SCAIL_STAGE_SAFETENSORS_FOR_LOAD", "1") == "1"
CONVERT_TO_STAGING_FIRST = os.getenv("SCAIL_CONVERT_TO_STAGING_FIRST", "1") == "1"

# --- Checkpoint file names ------------------------------------------------
CLIP_CKPT_NAME = "models_clip_open-clip-xlm-roberta-large-vit-huge-14-onlyvisual.pth"
# Path of the original DiT checkpoint relative to its download directory.
ORIGINAL_DIT_REL_PATH = "model/1/fsdp2_rank_0000_checkpoint.pt"
# Patterns handed to snapshot_download for the base (non-DiT) assets.
BASE_ALLOW_PATTERNS = [
    "Wan2.1_VAE.pth",
    "umt5-xxl/**",
    CLIP_CKPT_NAME,
]
|
|
# --- Mutable module state ---------------------------------------------------
# Loaded pipeline and the tuple of settings it was built from (set by _get_pipeline).
_PIPELINE = None
_PIPELINE_KEY = None
# Human-readable status plus formatted traceback (or None) for each startup
# phase; written by the matching _prepare_* function and shown in the UI.
_ASSET_STATUS = "Assets were not prepared yet."
_ASSET_ERROR = None
_RUNTIME_STATUS = "Runtime was not prepared yet."
_RUNTIME_ERROR = None
_PIPELINE_STATUS = "Pipeline was not preloaded."
_PIPELINE_ERROR = None
# Path produced by the most recent conversion (set by _maybe_convert_checkpoint).
_LAST_CONVERTED_SAFETENSORS = None
# Lazily imported project modules/objects (set by _import_runtime).
_WAN = None
_GENERATE_VIDEO = None
_SCAIL_CONFIGS = None
_SCAIL_CONFIG_PATHS = None
|
|
|
|
@dataclass(frozen=True)
class PreparedExample:
    """Immutable description of one bundled demo example.

    The path fields are resolved through ``_abs`` before use, so relative
    values are interpreted against the repository root.
    """

    # Human-readable name of the example.
    label: str
    # Reference image path.
    image: str
    # Mask image that accompanies the reference image.
    mask_image: str
    # Driving / rendered video path.
    pose: str
    # Mask video that accompanies the driving video.
    mask_video: str
    # Default text prompt shown for this example.
    prompt: str
    # True selects "replacement" mode; False (default) selects "animation".
    replace_flag: bool = False
|
|
|
|
# Registry of bundled examples, keyed by the name shown in the UI dropdown.
# Entries whose files are missing from the checkout are filtered out at
# runtime by _existing_examples().
PREPARED_EXAMPLES = {
    "Animation 001 - end-to-end": PreparedExample(
        label="Animation 001 - end-to-end",
        image="examples/animation_001/ref.jpg",
        mask_image="examples/animation_001/ref_mask.jpg",
        pose="examples/animation_001/rendered_v2.mp4",
        mask_video="examples/animation_001/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 001 - pose-driven": PreparedExample(
        label="Animation 001 - pose-driven",
        image="examples/animation_001_posedriven/ref.jpg",
        mask_image="examples/animation_001_posedriven/ref_mask.jpg",
        pose="examples/animation_001_posedriven/rendered_v2.mp4",
        mask_video="examples/animation_001_posedriven/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 002 - end-to-end": PreparedExample(
        label="Animation 002 - end-to-end",
        image="examples/animation_002/ref.jpg",
        mask_image="examples/animation_002/ref_mask.jpg",
        pose="examples/animation_002/rendered_v2.mp4",
        mask_video="examples/animation_002/rendered_mask_v2.mp4",
        prompt="A character performs the motion from the driving video.",
    ),
    # The only bundled example that runs in replacement mode.
    "Replacement 001": PreparedExample(
        label="Replacement 001",
        image="examples/replace_001/ref.png",
        mask_image="examples/replace_001/ref_mask.png",
        pose="examples/replace_001/rendered_v2.mp4",
        mask_video="examples/replace_001/replace_mask.mp4",
        prompt=(
            "A blond white male wearing a black suit, trousers, and leather shoes "
            "is playing the violin on the street while pedestrians walk past him."
        ),
        replace_flag=True,
    ),
}
|
|
|
|
| def _abs(path: str | Path) -> str: |
| path = Path(path) |
| if not path.is_absolute(): |
| path = ROOT / path |
| return str(path) |
|
|
|
|
def _existing_examples() -> dict[str, PreparedExample]:
    """Return the subset of ``PREPARED_EXAMPLES`` whose four asset files exist.

    Registration order is preserved; an example is dropped as soon as any of
    its image, mask image, pose video or mask video is missing on disk.
    """

    def _complete(example: PreparedExample) -> bool:
        assets = (example.image, example.mask_image, example.pose, example.mask_video)
        return all(Path(_abs(asset)).exists() for asset in assets)

    return {
        name: example
        for name, example in PREPARED_EXAMPLES.items()
        if _complete(example)
    }
|
|
|
|
def _require_repo_layout():
    """Verify that the expected SCAIL-2 repository files sit under ``ROOT``.

    Raises:
        RuntimeError: listing every expected file that is missing.
    """
    expected = (
        "wan/scail.py",
        "wan/modules/model_scail2.py",
        "generate.py",
        "configs/config-14b.json",
    )
    absent = [rel for rel in expected if not (ROOT / rel).exists()]
    if not absent:
        return
    raise RuntimeError(
        "This app.py is meant to live at the root of the SCAIL-2 repository. "
        f"Missing: {', '.join(absent)}"
    )
|
|
|
|
def _download_safetensors_if_configured() -> Path | None:
    """Fetch the pre-converted safetensors file from the Hub, if a repo is configured.

    Returns:
        ``None`` when ``SAFETENSORS_REPO_ID`` is unset; otherwise the path of
        the file inside the cache directory (``SCAIL_SAFETENSORS_CACHE`` or
        ``STORAGE_ROOT / "scail2_safetensors"``). The file is downloaded only
        when it is not already present there.
    """
    if not SAFETENSORS_REPO_ID:
        return None

    cache_dir = Path(os.getenv("SCAIL_SAFETENSORS_CACHE", str(STORAGE_ROOT / "scail2_safetensors")))
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached_file = cache_dir / SAFETENSORS_FILENAME
    if not cached_file.exists():
        logging.info(
            "Downloading converted SCAIL-2 safetensors from %s/%s",
            SAFETENSORS_REPO_ID,
            SAFETENSORS_FILENAME,
        )
        fetched = hf_hub_download(
            repo_id=SAFETENSORS_REPO_ID,
            filename=SAFETENSORS_FILENAME,
            local_dir=str(cache_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        return Path(fetched)
    return cached_file
|
|
|
|
def _find_converted_safetensors(ckpt_dir: Path | None) -> Path | None:
    """Locate an already-converted SCAIL-2 safetensors file.

    Candidates are probed in this order and the first existing one wins:
    ``SCAIL_SAFETENSORS_PATH``, the last conversion result of this process,
    three well-known locations under ``ROOT``, the persistent converted
    directory, and finally two names inside *ckpt_dir* (when given). If none
    exists, the Hub download helper is consulted, which may return ``None``.
    """
    search_order: list[Path] = []

    explicit = os.getenv("SCAIL_SAFETENSORS_PATH")
    if explicit:
        search_order.append(Path(explicit))
    if _LAST_CONVERTED_SAFETENSORS is not None:
        search_order.append(Path(_LAST_CONVERTED_SAFETENSORS))

    converted_dir = Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted")))
    search_order.extend(
        (
            ROOT / "SCAIL-2.safetensors",
            ROOT / "models" / "SCAIL-2.safetensors",
            ROOT / "model.safetensors",
            converted_dir / "SCAIL-2.safetensors",
        )
    )
    if ckpt_dir is not None:
        search_order.extend(
            (
                ckpt_dir / "SCAIL-2.safetensors",
                ckpt_dir / "model.safetensors",
            )
        )

    found = next((path for path in search_order if path.exists()), None)
    if found is not None:
        return found
    return _download_safetensors_if_configured()
|
|
|
|
def _copy_file_with_progress(source: Path, dest: Path, description: str) -> Path:
    """Copy *source* to *dest* in chunks, resuming partial copies and logging progress.

    Data is written to ``<dest>.tmp`` and renamed onto *dest* only after the
    byte count matches, so a crash never leaves a truncated *dest*. An existing
    ``.tmp`` file no larger than the source is treated as a partial copy of the
    same source and appended to; its contents are not verified.

    Environment:
        SCAIL_STAGE_COPY_CHUNK_MB: read size in MiB (default 64, minimum 1).
        SCAIL_STAGE_COPY_LOG_GB: progress-log interval in GiB (default 1;
            0 or less disables progress lines).

    Args:
        source: Existing file to copy.
        dest: Destination path; parent directories are created.
        description: Prefix used in log messages.

    Returns:
        *dest*.

    Raises:
        RuntimeError: if the source ends early or the final size differs.
    """
    source_size = source.stat().st_size
    # Clamp to at least 1 MiB: a zero chunk makes read() return b"", which the
    # loop below would misreport as an unexpected EOF, and a negative one is
    # rejected by the buffered reader.
    chunk_mb = max(1, int(os.getenv("SCAIL_STAGE_COPY_CHUNK_MB", "64")))
    chunk_size = chunk_mb * 1024 * 1024
    log_every = int(os.getenv("SCAIL_STAGE_COPY_LOG_GB", "1")) * 1024 * 1024 * 1024
    dest.parent.mkdir(parents=True, exist_ok=True)

    # A destination of the right size is taken to be a finished earlier copy.
    if dest.exists() and dest.stat().st_size == source_size:
        return dest

    tmp_dest = dest.with_suffix(dest.suffix + ".tmp")
    copied = tmp_dest.stat().st_size if tmp_dest.exists() else 0
    if copied > source_size:
        # The partial file cannot belong to this source; start over.
        tmp_dest.unlink()
        copied = 0

    logging.info("%s: %s -> %s", description, source, dest)
    next_log = ((copied // log_every) + 1) * log_every if log_every > 0 else source_size
    if copied:
        logging.info(
            "Resuming copy at %.2f/%.2f GB",
            copied / 1024**3,
            source_size / 1024**3,
        )

    # Append mode keeps the bytes already copied; the source is positioned to match.
    with source.open("rb") as src, tmp_dest.open("ab") as dst:
        if copied:
            src.seek(copied)
        while copied < source_size:
            chunk = src.read(min(chunk_size, source_size - copied))
            if not chunk:
                raise RuntimeError(f"Unexpected EOF while copying {source}: {copied} of {source_size} bytes")
            dst.write(chunk)
            copied += len(chunk)
            if log_every > 0 and copied >= next_log:
                logging.info(
                    "%s: %.2f/%.2f GB",
                    description,
                    copied / 1024**3,
                    source_size / 1024**3,
                )
                next_log += log_every

    if tmp_dest.stat().st_size != source_size:
        raise RuntimeError(f"Copied file size mismatch: {tmp_dest.stat().st_size} != {source_size}")
    tmp_dest.replace(dest)
    logging.info("Finished %s: %s", description, dest)
    return dest
|
|
|
|
def _is_relative_to(path: Path, parent: Path) -> bool:
    """Report whether *path* lies at or beneath *parent* after both are resolved."""
    target = path.resolve()
    anchor = parent.resolve()
    # Equivalent to a successful relative_to(): same path, or anchor is an ancestor.
    return target == anchor or anchor in target.parents
|
|
|
|
def _stage_safetensors_for_load(scail_path: Path) -> Path:
    """Return the path the pipeline should load the safetensors file from.

    With ``STAGE_SAFETENSORS_FOR_LOAD`` disabled, or when the file already
    lives under ``STAGING_ROOT``, the input is returned as is. Otherwise the
    file is copied into ``SCAIL_MODEL_LOAD_CACHE`` (default
    ``STAGING_ROOT / "scail2_model_load"``) unless a same-sized copy is already
    there, and the staged path is returned.
    """
    if not STAGE_SAFETENSORS_FOR_LOAD:
        return scail_path

    original = Path(scail_path)
    if _is_relative_to(original, STAGING_ROOT):
        return original

    cache_root = Path(os.getenv("SCAIL_MODEL_LOAD_CACHE", str(STAGING_ROOT / "scail2_model_load")))
    cache_root.mkdir(parents=True, exist_ok=True)
    local_copy = cache_root / original.name

    already_staged = local_copy.exists() and local_copy.stat().st_size == original.stat().st_size
    if already_staged:
        return local_copy
    return _copy_file_with_progress(original, local_copy, "Staging SCAIL-2 safetensors for local load")
|
|
|
|
def _download_checkpoint_if_needed(include_original_dit: bool = False) -> Path:
    """Return the directory holding the base (non-DiT) checkpoint assets.

    ``SCAIL_CKPT_DIR`` wins when set and must exist. Otherwise the cache
    directory (``SCAIL_CKPT_CACHE`` or ``STORAGE_ROOT / "scail2_ckpt"``) is
    used, downloading ``BASE_ALLOW_PATTERNS`` from ``MODEL_REPO_ID`` when any
    of the three expected entries is missing.

    Args:
        include_original_dit: Deprecated; only triggers a warning when a
            download actually happens.

    Raises:
        RuntimeError: if ``SCAIL_CKPT_DIR`` points at a missing directory.
    """
    override = os.getenv("SCAIL_CKPT_DIR")
    if override:
        override_dir = Path(override)
        if override_dir.exists():
            return override_dir
        raise RuntimeError(f"SCAIL_CKPT_DIR does not exist: {override_dir}")

    cache_dir = Path(os.getenv("SCAIL_CKPT_CACHE", str(STORAGE_ROOT / "scail2_ckpt")))
    required_entries = ("Wan2.1_VAE.pth", "umt5-xxl", CLIP_CKPT_NAME)
    if all((cache_dir / entry).exists() for entry in required_entries):
        return cache_dir

    logging.info("Downloading SCAIL-2 base checkpoint assets from %s", MODEL_REPO_ID)
    if include_original_dit:
        logging.warning(
            "include_original_dit is deprecated here; original DiT conversion staging "
            "uses SCAIL_ORIGINAL_DIT_CACHE instead."
        )

    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=BASE_ALLOW_PATTERNS,
    )
    return cache_dir
|
|
|
|
def _download_original_dit_for_conversion() -> Path:
    """Return a directory that contains the original DiT checkpoint.

    ``SCAIL_ORIGINAL_DIT_DIR`` wins when set and must already contain
    ``ORIGINAL_DIT_REL_PATH``. Otherwise the checkpoint is looked up in (and
    downloaded into, if absent) ``SCAIL_ORIGINAL_DIT_CACHE``, which defaults to
    ``STAGING_ROOT / "scail2_original_dit"``.

    Raises:
        RuntimeError: if ``SCAIL_ORIGINAL_DIT_DIR`` lacks the checkpoint file.
    """
    override = os.getenv("SCAIL_ORIGINAL_DIT_DIR")
    if override:
        original_dir = Path(override)
        if (original_dir / ORIGINAL_DIT_REL_PATH).exists():
            return original_dir
        raise RuntimeError(f"SCAIL_ORIGINAL_DIT_DIR is missing {ORIGINAL_DIT_REL_PATH}: {original_dir}")

    cache_dir = Path(os.getenv("SCAIL_ORIGINAL_DIT_CACHE", str(STAGING_ROOT / "scail2_original_dit")))
    if not (cache_dir / ORIGINAL_DIT_REL_PATH).exists():
        logging.info(
            "Downloading original SCAIL-2 DiT checkpoint for one-time conversion into %s",
            cache_dir,
        )
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(cache_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=[ORIGINAL_DIT_REL_PATH],
        )
    return cache_dir
|
|
|
|
def _prepare_assets_for_runtime() -> str:
    """Download CPU-side assets before any ZeroGPU function starts.

    Updates ``_ASSET_STATUS`` / ``_ASSET_ERROR`` and never raises: failures are
    logged and reported through the returned text, which is the status line
    followed by the traceback when one was captured.
    """
    global _ASSET_STATUS, _ASSET_ERROR
    try:
        base_dir = _download_checkpoint_if_needed(include_original_dit=False)
        converted = _find_converted_safetensors(base_dir)
        if converted is None and AUTO_CONVERT:
            converted = _maybe_convert_checkpoint(_download_original_dit_for_conversion(), None)

        if converted is not None:
            _ASSET_STATUS = (
                "Assets ready. Base checkpoint: "
                f"{base_dir}. Converted DiT safetensors: {converted}."
            )
        else:
            _ASSET_STATUS = (
                "Base checkpoint assets are present, but no converted safetensors file "
                "was found. Set SCAIL_SAFETENSORS_REPO_ID plus SCAIL_SAFETENSORS_FILENAME, "
                "or set SCAIL_SAFETENSORS_PATH. Automatic startup conversion is disabled "
                "only when SCAIL_AUTO_CONVERT=0."
            )
        _ASSET_ERROR = None
    except Exception:
        _ASSET_ERROR = traceback.format_exc()
        _ASSET_STATUS = "Asset preparation failed. See the traceback below."
        logging.exception("Asset preparation failed")

    if _ASSET_ERROR is None:
        return _ASSET_STATUS
    return _ASSET_STATUS + "\n\n" + _ASSET_ERROR
|
|
|
|
def _maybe_convert_checkpoint(ckpt_dir: Path, scail_path: Path | None) -> Path:
    """Return a converted safetensors path, running ``convert.py`` when required.

    Resolution order: the given *scail_path*; an existing file in the
    persistent converted directory; an existing file in the work directory;
    finally a fresh conversion of *ckpt_dir* in a subprocess. Whenever the work
    file differs from the persistent location it is also copied there, but the
    work-file path is what gets returned and remembered in
    ``_LAST_CONVERTED_SAFETENSORS``.

    Raises:
        RuntimeError: if nothing was supplied and ``AUTO_CONVERT`` is off.
        subprocess.CalledProcessError: if the conversion script fails.
    """
    global _LAST_CONVERTED_SAFETENSORS
    if scail_path is not None:
        return scail_path

    if not AUTO_CONVERT:
        raise RuntimeError(
            "Converted SCAIL-2 safetensors file was not found. For the wan-scail2 branch, "
            "provide SCAIL_SAFETENSORS_PATH=/path/to/SCAIL-2.safetensors, or place "
            "SCAIL-2.safetensors at the repo root. Automatic startup conversion is disabled "
            "because SCAIL_AUTO_CONVERT=0."
        )

    persistent_dir = Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted")))
    persistent_path = persistent_dir / "SCAIL-2.safetensors"
    if persistent_path.exists():
        return persistent_path

    work_dir = (
        Path(os.getenv("SCAIL_CONVERSION_WORK_DIR", str(STAGING_ROOT / "scail2_converted_work")))
        if CONVERT_TO_STAGING_FIRST
        else persistent_dir
    )
    work_dir.mkdir(parents=True, exist_ok=True)
    work_path = work_dir / "SCAIL-2.safetensors"

    if not work_path.exists():
        logging.info("Converting checkpoint to safetensors: %s", work_path)
        child_env = os.environ.copy()
        # Set for the conversion subprocess only.
        child_env["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
        command = [
            sys.executable,
            str(ROOT / "convert.py"),
            "--scail-dir",
            str(ckpt_dir),
            "--save-path",
            str(work_path),
        ]
        subprocess.run(command, check=True, cwd=str(ROOT), env=child_env)

    # Shared tail for both the "found in work dir" and "just converted" cases.
    _LAST_CONVERTED_SAFETENSORS = work_path
    if work_path != persistent_path:
        _copy_file_with_progress(work_path, persistent_path, "Persisting converted SCAIL-2 safetensors to storage")
    return work_path
|
|
|
|
| def _install_attention_patch(): |
| """Use HF Kernels flash-attn2 when available; otherwise fall back to SDPA. |
| |
| SCAIL-2 imports `flash_attention` directly in model_scail2.py. This monkey patch |
| avoids requiring a locally built `flash_attn` wheel on Spaces. |
| """ |
| import wan.modules.attention as attention_mod |
|
|
| hf_flash_attn2 = None |
| try: |
| from kernels import get_kernel |
|
|
| hf_flash_attn2 = get_kernel("kernels-community/flash-attn2", version=2) |
| logging.info("Using kernels-community/flash-attn2 through HF Kernels.") |
| except Exception as exc: |
| if torch.cuda.is_available(): |
| device_name = torch.cuda.get_device_name(0) |
| capability = torch.cuda.get_device_capability(0) |
| else: |
| device_name = "no cuda" |
| capability = None |
| logging.warning("Could not initialize HF Kernels flash-attn2: %r", exc) |
| logging.warning( |
| "Attention fallback environment: torch=%s cuda=%s device=%s capability=%s", |
| torch.__version__, |
| torch.version.cuda, |
| device_name, |
| capability, |
| ) |
|
|
| def patched_flash_attention( |
| q, |
| k, |
| v, |
| q_lens=None, |
| k_lens=None, |
| dropout_p=0.0, |
| softmax_scale=None, |
| q_scale=None, |
| causal=False, |
| window_size=(-1, -1), |
| deterministic=False, |
| dtype=torch.bfloat16, |
| version=None, |
| ): |
| half_dtypes = (torch.float16, torch.bfloat16) |
| b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype |
|
|
| def half(x): |
| return x if x.dtype in half_dtypes else x.to(dtype) |
|
|
| if hf_flash_attn2 is not None and q.device.type == "cuda": |
| if q_lens is None: |
| q_var = half(q.flatten(0, 1)) |
| q_lens_t = torch.full((b,), lq, dtype=torch.int32, device=q.device) |
| else: |
| q_lens_t = q_lens.to(device=q.device, dtype=torch.int32) |
| q_var = half(torch.cat([u[: int(n)] for u, n in zip(q, q_lens_t)])) |
|
|
| if k_lens is None: |
| k_var = half(k.flatten(0, 1)) |
| v_var = half(v.flatten(0, 1)) |
| k_lens_t = torch.full((b,), lk, dtype=torch.int32, device=k.device) |
| else: |
| k_lens_t = k_lens.to(device=k.device, dtype=torch.int32) |
| k_var = half(torch.cat([u[: int(n)] for u, n in zip(k, k_lens_t)])) |
| v_var = half(torch.cat([u[: int(n)] for u, n in zip(v, k_lens_t)])) |
|
|
| q_var = q_var.to(v_var.dtype) |
| k_var = k_var.to(v_var.dtype) |
| if q_scale is not None: |
| q_var = q_var * q_scale |
|
|
| cu_q = torch.cat([q_lens_t.new_zeros([1]), q_lens_t]).cumsum(0, dtype=torch.int32) |
| cu_k = torch.cat([k_lens_t.new_zeros([1]), k_lens_t]).cumsum(0, dtype=torch.int32) |
|
|
| try: |
| out = hf_flash_attn2.flash_attn_varlen_func( |
| q=q_var, |
| k=k_var, |
| v=v_var, |
| cu_seqlens_q=cu_q, |
| cu_seqlens_k=cu_k, |
| max_seqlen_q=lq, |
| max_seqlen_k=lk, |
| dropout_p=dropout_p, |
| softmax_scale=softmax_scale, |
| causal=causal, |
| window_size=window_size, |
| deterministic=deterministic, |
| ) |
| if isinstance(out, tuple): |
| out = out[0] |
| return out.unflatten(0, (b, lq)).type(out_dtype) |
| except Exception as exc: |
| logging.warning("HF Kernels flash-attn2 failed, falling back to SDPA: %s", exc) |
|
|
| if q_lens is not None and not torch.all(q_lens == lq): |
| logging.warning("SDPA fallback ignores variable q_lens; demo batch size should stay at 1.") |
| if k_lens is not None and not torch.all(k_lens == lk): |
| logging.warning("SDPA fallback ignores variable k_lens; demo batch size should stay at 1.") |
|
|
| q_sdpa = q.transpose(1, 2).to(dtype) |
| k_sdpa = k.transpose(1, 2).to(dtype) |
| v_sdpa = v.transpose(1, 2).to(dtype) |
| out = torch.nn.functional.scaled_dot_product_attention( |
| q_sdpa, |
| k_sdpa, |
| v_sdpa, |
| attn_mask=None, |
| dropout_p=dropout_p, |
| is_causal=causal, |
| scale=softmax_scale, |
| ) |
| return out.transpose(1, 2).contiguous().type(out_dtype) |
|
|
| attention_mod.flash_attention = patched_flash_attention |
|
|
| |
| |
| for module_name in ( |
| "wan.modules.clip", |
| "wan.modules.model", |
| "wan.modules.model_scail", |
| "wan.modules.model_scail2", |
| ): |
| try: |
| module = __import__(module_name, fromlist=["flash_attention"]) |
| if hasattr(module, "flash_attention"): |
| module.flash_attention = patched_flash_attention |
| logging.info("Patched %s.flash_attention", module_name) |
| except Exception as exc: |
| logging.warning("Could not patch %s.flash_attention: %s", module_name, exc) |
|
|
|
|
def _import_runtime():
    """Import the SCAIL/Wan runtime once and cache it in module globals.

    Later calls return immediately once ``_WAN`` is set. The first call checks
    the repository layout, puts ``ROOT`` at the front of ``sys.path`` if it is
    not listed yet, imports the project modules, installs the attention patch
    and only then publishes the globals.
    """
    global _WAN, _GENERATE_VIDEO, _SCAIL_CONFIGS, _SCAIL_CONFIG_PATHS
    if _WAN is not None:
        return

    _require_repo_layout()
    root_entry = str(ROOT)
    if root_entry not in sys.path:
        sys.path.insert(0, root_entry)

    import wan
    from generate import generate_video
    from wan.configs import SCAIL_CONFIGS, SCAIL_CONFIG_PATHS

    _install_attention_patch()
    # Published last, so a failure above leaves _WAN unset and the next call retries.
    _WAN, _GENERATE_VIDEO = wan, generate_video
    _SCAIL_CONFIGS, _SCAIL_CONFIG_PATHS = SCAIL_CONFIGS, SCAIL_CONFIG_PATHS
|
|
|
|
def _prepare_runtime_for_startup() -> str:
    """Import Wan/SCAIL and resolve HF Kernels before any ZeroGPU call.

    Updates ``_RUNTIME_STATUS`` / ``_RUNTIME_ERROR`` and never raises; the
    returned text is the status line, followed by the traceback on failure.
    """
    global _RUNTIME_STATUS, _RUNTIME_ERROR
    try:
        _import_runtime()
    except Exception:
        _RUNTIME_ERROR = traceback.format_exc()
        _RUNTIME_STATUS = "Runtime preparation failed. See the traceback below."
        logging.exception("Runtime preparation failed")
    else:
        _RUNTIME_STATUS = "Runtime ready. Attention backend has been initialized at startup."
        _RUNTIME_ERROR = None

    if _RUNTIME_ERROR is None:
        return _RUNTIME_STATUS
    return _RUNTIME_STATUS + "\n\n" + _RUNTIME_ERROR
|
|
|
|
def _get_pipeline():
    """Return the SCAIL-2 pipeline, building it on first use or when its inputs change.

    The cached pipeline is reused while the tuple (checkpoint dir, staged
    safetensors path, config path, LoRA path, LoRA alpha) is unchanged.
    Resolving those inputs may download assets, convert the checkpoint and
    stage the safetensors file.
    """
    global _PIPELINE, _PIPELINE_KEY
    _import_runtime()

    base_dir = _download_checkpoint_if_needed(include_original_dit=False)
    converted = _find_converted_safetensors(base_dir)
    # Conversion reads the original DiT checkpoint; otherwise the base directory
    # is passed along with whatever was found.
    needs_conversion = converted is None and AUTO_CONVERT
    conversion_source = _download_original_dit_for_conversion() if needs_conversion else base_dir
    converted = _maybe_convert_checkpoint(conversion_source, converted)
    load_path = _stage_safetensors_for_load(converted)

    config_path = Path(os.getenv("SCAIL_CONFIG_PATH", _SCAIL_CONFIG_PATHS[MODEL_NAME]))
    if not config_path.is_absolute():
        config_path = ROOT / config_path

    lora_path = os.getenv("SCAIL_LORA_PATH") or None
    lora_alpha = float(os.getenv("SCAIL_LORA_ALPHA", "1.0"))

    cache_key = (str(base_dir), str(load_path), str(config_path), lora_path, lora_alpha)
    if _PIPELINE is not None and _PIPELINE_KEY == cache_key:
        return _PIPELINE

    logging.info("Loading SCAIL-2 pipeline.")
    pipeline_kwargs = dict(
        config=_SCAIL_CONFIGS[MODEL_NAME],
        checkpoint_dir=str(base_dir),
        scail_safetensors_path=str(load_path),
        scail_config_path=str(config_path),
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        lora_path=lora_path,
        lora_alpha=lora_alpha,
    )
    _PIPELINE = _WAN.SCAIL2Pipeline(**pipeline_kwargs)
    _PIPELINE_KEY = cache_key
    return _PIPELINE
|
|
|
|
def _prepare_pipeline_for_startup() -> str:
    """Preload the CUDA pipeline at module startup for ZeroGPU CUDA emulation.

    Updates ``_PIPELINE_STATUS`` / ``_PIPELINE_ERROR`` and never raises; the
    returned text is the status line, followed by the traceback on failure.
    """
    global _PIPELINE_STATUS, _PIPELINE_ERROR
    try:
        _get_pipeline()
    except Exception:
        _PIPELINE_ERROR = traceback.format_exc()
        _PIPELINE_STATUS = "Pipeline preload failed. See the traceback below."
        logging.exception("Pipeline preload failed")
    else:
        _PIPELINE_STATUS = "Pipeline preloaded at startup."
        _PIPELINE_ERROR = None

    if _PIPELINE_ERROR is None:
        return _PIPELINE_STATUS
    return _PIPELINE_STATUS + "\n\n" + _PIPELINE_ERROR
|
|
|
|
|
|
def _is_gradio_native_file_path(path: Path) -> bool:
    """Return True when Gradio can serve a file without explicit allowed_paths.

    Here that means the resolved path lies under ``ROOT`` or under the
    system temporary directory.
    """
    resolved = path.resolve()
    for root in (ROOT.resolve(), Path(tempfile.gettempdir()).resolve()):
        if _is_relative_to(resolved, root):
            return True
    return False
|
|
|
|
def _prepare_output_for_gradio(path: str | Path) -> str:
    """Return a display path accepted by Gradio.

    The canonical output is saved in OUTPUT_DIR, which may live on a mounted
    Bucket such as /data/scail2_outputs. Gradio can be strict about returned
    file paths, so we keep the persistent copy in the Bucket and return a
    temporary copy under /tmp for the UI component.

    Files already under a natively served root are returned unchanged. The
    copy goes to ``SCAIL_GRADIO_OUTPUT_CACHE`` (default
    ``<tempdir>/scail2_gradio_outputs``) under the same file name.

    Raises:
        RuntimeError: if *path* does not exist.
    """
    video = Path(path)
    if not video.exists():
        raise RuntimeError(f"Generated video was not found: {video}")
    if _is_gradio_native_file_path(video):
        return str(video)

    default_cache = Path(tempfile.gettempdir()) / "scail2_gradio_outputs"
    display_dir = Path(os.getenv("SCAIL_GRADIO_OUTPUT_CACHE", str(default_cache)))
    display_dir.mkdir(parents=True, exist_ok=True)

    display_copy = display_dir / video.name
    shutil.copy2(video, display_copy)
    logging.info("Copied generated video for Gradio display: %s -> %s", video, display_copy)
    return str(display_copy)
|
|
def _duration_for_job(*args, **kwargs):
    """Return the GPU duration to request for a job.

    ``SCAIL_GPU_DURATION`` overrides everything; otherwise the cold value is
    used while no pipeline is loaded and the warm value afterwards. The
    arguments are accepted and ignored.
    """
    fallback = GPU_DURATION_COLD if _PIPELINE is None else GPU_DURATION_WARM
    return int(os.getenv("SCAIL_GPU_DURATION", str(fallback)))
|
|
|
|
def _run_scail_job(
    image_path,
    mask_image_path,
    pose_path,
    mask_video_path,
    prompt,
    replace_flag,
    target_h,
    target_w,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    segment_len,
    segment_overlap,
    progress=None,
):
    """Run one generation job and return a video path the UI can display.

    The video is written to a uniquely named file in ``OUTPUT_DIR``; the
    returned path may be a temporary copy made for Gradio. Numeric arguments
    are coerced with ``int``/``float`` before being handed to the generator.
    ``progress`` is an optional callable invoked as ``progress(fraction, desc=...)``.
    """

    def _report(fraction, text):
        # Progress reporting is optional; do nothing without a callback.
        if progress is not None:
            progress(fraction, desc=text)

    _report(0.02, "Loading SCAIL-2 pipeline")
    pipeline = _get_pipeline()
    model_cfg = _SCAIL_CONFIGS[MODEL_NAME]
    output_path = OUTPUT_DIR / f"scail2_{uuid.uuid4().hex}.mp4"

    _report(0.12, "Preparing inputs")
    prompt_text = prompt or ""
    job_args = SimpleNamespace(
        target_h=int(target_h),
        target_w=int(target_w),
        sample_shift=float(sample_shift),
        sample_solver=DEFAULT_SOLVER,
        segment_len=int(segment_len),
        segment_overlap=int(segment_overlap),
        sample_steps=int(sample_steps),
        sample_guide_scale=float(guide_scale),
        base_seed=int(seed),
        offload_model=True,
        save_file=str(output_path),
        save_dir=str(OUTPUT_DIR),
        prompt=prompt_text,
    )

    _report(0.15, "Generating video")
    _GENERATE_VIDEO(
        pipeline,
        prompt_text,
        str(image_path),
        str(mask_image_path),
        str(pose_path),
        str(mask_video_path),
        job_args,
        device=0,
        rank=0,
        cfg=model_cfg,
        input_idx=None,
        replace_flag=bool(replace_flag),
        additional_task_input=None,
    )

    _report(0.95, "Finalizing output")
    # Drop Python garbage and cached CUDA blocks before returning.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    _report(0.98, "Preparing video for display")
    shown_path = _prepare_output_for_gradio(output_path)
    _report(1.0, "Done")
    return shown_path
|
|
|
|
@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_example(
    example_name,
    prompt,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from one of the prepared examples.

    ``target_size`` is a ``"<width>x<height>"`` string. Returns
    ``(video_path, "Done.")`` on success and ``(None, traceback_text)`` on any
    failure; exceptions are logged rather than propagated.
    """
    try:
        progress(0.0, desc="Checking example inputs")
        chosen = _existing_examples().get(example_name)
        if chosen is None:
            raise RuntimeError(f"Example is missing from this checkout: {example_name}")

        width, height = map(int, str(target_size).split("x"))
        asset_paths = [
            _abs(chosen.image),
            _abs(chosen.mask_image),
            _abs(chosen.pose),
            _abs(chosen.mask_video),
        ]
        video = _run_scail_job(
            *asset_paths,
            prompt,
            chosen.replace_flag,
            height,
            width,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()
    return video, "Done."
|
|
|
|
@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_uploads(
    image,
    mask_image,
    pose_video,
    mask_video,
    prompt,
    mode,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from user-uploaded, already-prepared inputs.

    All four media inputs are required. ``mode == "replacement"`` selects
    replacement; any other value selects animation. ``target_size`` is a
    ``"<width>x<height>"`` string. Returns ``(video_path, "Done.")`` on success
    and ``(None, traceback_text)`` on any failure.
    """
    try:
        progress(0.0, desc="Checking uploaded inputs")
        uploads = (
            ("reference image", image),
            ("reference mask", mask_image),
            ("driving/rendered video", pose_video),
            ("driving mask video", mask_video),
        )
        absent = [title for title, value in uploads if value is None]
        if absent:
            raise RuntimeError("Missing required input(s): " + ", ".join(absent))

        width, height = map(int, str(target_size).split("x"))
        wants_replacement = mode == "replacement"
        video = _run_scail_job(
            image,
            mask_image,
            pose_video,
            mask_video,
            prompt,
            wants_replacement,
            height,
            width,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()
    return video, "Done."
|
|
|
|
def load_example_preview(example_name):
    """Return the preview values for *example_name*.

    The tuple is ``(image, pose video, mask image, mask video, prompt, mode)``,
    with absolute paths and ``mode`` being ``"replacement"`` or ``"animation"``.
    Unknown or unavailable examples yield four ``None`` values, an empty
    prompt and ``"animation"``.
    """
    chosen = _existing_examples().get(example_name)
    if chosen is None:
        return None, None, None, None, "", "animation"

    if chosen.replace_flag:
        mode = "replacement"
    else:
        mode = "animation"
    media = tuple(_abs(p) for p in (chosen.image, chosen.pose, chosen.mask_image, chosen.mask_video))
    return (*media, chosen.prompt, mode)
|
|
|
|
def _startup_message():
    """Compose the text for the "Startup status" box.

    Never raises: a failed layout check (or any other error) is returned as
    the exception's message.
    """
    try:
        _require_repo_layout()
        found = len(_existing_examples())
        if found == 0:
            return "Repo layout detected, but no prepared examples were found."
        summary = (
            f"Ready. Found {found} prepared example(s). "
            f"Storage root: {STORAGE_ROOT}. Staging root: {STAGING_ROOT}. "
            f"Output dir: {OUTPUT_DIR}.\n\n"
        )
        phases = f"{_ASSET_STATUS}\n\n{_RUNTIME_STATUS}\n\n{_PIPELINE_STATUS}\n\n"
        backend = "Attention backend: HF Kernels flash-attn2 when available, otherwise SDPA."
        return summary + phases + backend
    except Exception as exc:
        return str(exc)
|
|
|
|
def build_ui():
    """Build and return the Gradio ``Blocks`` app.

    The UI has two tabs: "Prepared Examples", wired to
    ``generate_from_example``, and "Advanced Uploads", wired to
    ``generate_from_uploads``. The initial preview comes from the first
    available example, if any.
    """
    examples = _existing_examples()
    example_names = list(examples.keys())
    default_example = example_names[0] if example_names else None
    default_preview = (
        load_example_preview(default_example)
        if default_example
        else (None, None, None, None, "", "animation")
    )

    # The default size comes from SCAIL_TARGET_W/H, so it can fall outside the
    # presets; add it to the choices so the dropdown's initial value is always
    # one of its own options.
    default_size = f"{DEFAULT_TARGET_W}x{DEFAULT_TARGET_H}"
    size_choices = ["896x512", "512x896", "1280x704", "704x1280"]
    if default_size not in size_choices:
        size_choices.insert(0, default_size)

    with gr.Blocks(title="SCAIL-2 ZeroGPU Demo") as demo:
        gr.Markdown(
            "# SCAIL-2 Character Animation Demo\n"
            "Animate or replace characters with SCAIL-2, directly in the browser. "
            "The easiest way to try the model is to start from one of the prepared examples below.\n\n"
            "This Space runs a heavy 14B video model on Hugging Face ZeroGPU. "
            "Generation is not realtime: a short example can take a few minutes, especially "
            "when the Space has just started.\n\n"
            "For custom inputs, use the Advanced Uploads tab with already-prepared SCAIL-2 "
            "conditions: reference image, reference mask, driving video, and driving mask. "
            "Automatic SCAIL-Pose/SAM3 preprocessing is not enabled in this public demo yet."
        )
        # Created for display only; no event handler refers to it.
        gr.Textbox(value=_startup_message(), label="Startup status", interactive=False)

        with gr.Tab("Prepared Examples"):
            with gr.Row():
                example_dropdown = gr.Dropdown(
                    choices=example_names,
                    value=default_example,
                    label="Example",
                )
                mode_view = gr.Textbox(value=default_preview[5], label="Mode", interactive=False)
            with gr.Row():
                ref_preview = gr.Image(value=default_preview[0], label="Reference", interactive=False)
                driving_preview = gr.Video(value=default_preview[1], label="Driving / rendered video")
            with gr.Row():
                ref_mask_preview = gr.Image(value=default_preview[2], label="Reference mask", interactive=False)
                driving_mask_preview = gr.Video(value=default_preview[3], label="Driving mask")
            prompt = gr.Textbox(value=default_preview[4], label="Prompt", lines=3)

            with gr.Row():
                sample_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                guide_scale = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                sample_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                seed = gr.Number(value=42, precision=0, label="Seed")
                target_size = gr.Dropdown(list(size_choices), value=default_size, label="Target size")
                segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")

            run_example = gr.Button("Generate", variant="primary")
            output_video = gr.Video(label="Output")
            status = gr.Textbox(label="Run status", lines=8)

            # Output order matches the tuple returned by load_example_preview.
            example_dropdown.change(
                load_example_preview,
                inputs=[example_dropdown],
                outputs=[ref_preview, driving_preview, ref_mask_preview, driving_mask_preview, prompt, mode_view],
            )
            # Input order matches generate_from_example's positional parameters.
            run_example.click(
                generate_from_example,
                inputs=[
                    example_dropdown,
                    prompt,
                    sample_steps,
                    guide_scale,
                    sample_shift,
                    seed,
                    target_size,
                    segment_len,
                    segment_overlap,
                ],
                outputs=[output_video, status],
            )

        with gr.Tab("Advanced Uploads"):
            gr.Markdown(
                "Upload custom inputs that have already been prepared for SCAIL-2: reference image, "
                "reference mask, driving/rendered video, and driving mask video. "
                "This public demo does not run SCAIL-Pose/SAM3 preprocessing yet."
            )
            with gr.Row():
                up_image = gr.Image(type="filepath", label="Reference image")
                up_mask_image = gr.Image(type="filepath", label="Reference mask")
            with gr.Row():
                up_pose_video = gr.Video(label="Driving / rendered video")
                up_mask_video = gr.Video(label="Driving mask / replace mask")
            up_mode = gr.Radio(["animation", "replacement"], value="animation", label="Mode")
            up_prompt = gr.Textbox(label="Prompt", lines=3)
            with gr.Row():
                up_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                up_cfg = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                up_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                up_seed = gr.Number(value=42, precision=0, label="Seed")
                up_target_size = gr.Dropdown(list(size_choices), value=default_size, label="Target size")
                up_segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                up_segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")
            run_upload = gr.Button("Generate from uploads", variant="primary")
            upload_output = gr.Video(label="Output")
            upload_status = gr.Textbox(label="Run status", lines=8)
            # Input order matches generate_from_uploads' positional parameters.
            run_upload.click(
                generate_from_uploads,
                inputs=[
                    up_image,
                    up_mask_image,
                    up_pose_video,
                    up_mask_video,
                    up_prompt,
                    up_mode,
                    up_steps,
                    up_cfg,
                    up_shift,
                    up_seed,
                    up_target_size,
                    up_segment_len,
                    up_segment_overlap,
                ],
                outputs=[upload_output, upload_status],
            )

    return demo
|
|
|
|
if __name__ == "__main__":
    # Startup phases run in order: assets, then runtime import, then pipeline
    # preload. Each can be switched off through its environment flag, and each
    # records its own failure in a status string instead of raising.
    if os.getenv("SCAIL_PRELOAD_ASSETS", "1") == "1":
        _prepare_assets_for_runtime()
    if os.getenv("SCAIL_PRELOAD_RUNTIME", "1") == "1":
        _prepare_runtime_for_startup()
    if PRELOAD_PIPELINE:
        _prepare_pipeline_for_startup()
    # The UI is built after the phases above, so the startup status box shows
    # their results. OUTPUT_DIR is passed to Gradio as an allowed path.
    build_ui().queue(max_size=8).launch(
        allowed_paths=[str(OUTPUT_DIR.resolve())],
        show_error=True,
    )
|
|