SCAIL-2 / app.py
fffiloni's picture
Update app.py
1d78a08 verified
Raw
History Blame Contribute Delete
38 kB
import gc
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import traceback
import uuid
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download, snapshot_download
# Configure the root logger once at import time so every later logging call
# in this module is emitted with a timestamp and level.
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
# Directory that contains this file; relative repo paths are anchored here.
ROOT = Path(__file__).resolve().parent
def _default_storage_root() -> Path:
    """Pick the directory used as the default root for caches and outputs.

    Preference order: a non-empty ``SCAIL_STORAGE_ROOT`` environment variable,
    then ``/data`` when it exists and is writable, and finally ``/tmp``.
    """
    override = os.getenv("SCAIL_STORAGE_ROOT")
    if override:
        return Path(override)
    mount = Path("/data")
    mount_usable = mount.exists() and os.access(mount, os.W_OK)
    return mount if mount_usable else Path("/tmp")
# --- Filesystem locations -------------------------------------------------
# Root for caches/outputs (see _default_storage_root for the selection order).
STORAGE_ROOT = _default_storage_root()
# Scratch location used for staging copies and conversion work; /tmp by default.
STAGING_ROOT = Path(os.getenv("SCAIL_STAGING_ROOT", "/tmp"))
# Generated videos are written here; the directory is created at import time.
OUTPUT_DIR = Path(os.getenv("SCAIL_OUTPUT_DIR", str(STORAGE_ROOT / "scail2_outputs")))
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# --- Model sources ----------------------------------------------------------
# Hub repo that hosts the base checkpoint assets and the original DiT checkpoint.
MODEL_REPO_ID = os.getenv("SCAIL_MODEL_REPO_ID", "zai-org/SCAIL-2")
# Optional Hub repo holding an already-converted safetensors file (None when unset).
SAFETENSORS_REPO_ID = os.getenv("SCAIL_SAFETENSORS_REPO_ID")
SAFETENSORS_FILENAME = os.getenv("SCAIL_SAFETENSORS_FILENAME", "SCAIL-2.safetensors")
# Key used to index the SCAIL config tables loaded by _import_runtime.
MODEL_NAME = os.getenv("SCAIL_MODEL_NAME", "SCAIL-14B")

# --- ZeroGPU settings -------------------------------------------------------
GPU_SIZE = os.getenv("SCAIL_ZEROGPU_SIZE", "xlarge")
# Requested GPU duration when the pipeline is not loaded yet (cold) or already
# loaded (warm); see _duration_for_job.
GPU_DURATION_COLD = int(os.getenv("SCAIL_GPU_DURATION_COLD", "600"))
GPU_DURATION_WARM = int(os.getenv("SCAIL_GPU_DURATION_WARM", "330"))

# --- Generation defaults shown in the UI --------------------------------------
DEFAULT_TARGET_H = int(os.getenv("SCAIL_TARGET_H", "512"))
DEFAULT_TARGET_W = int(os.getenv("SCAIL_TARGET_W", "896"))
DEFAULT_SEGMENT_LEN = int(os.getenv("SCAIL_SEGMENT_LEN", "81"))
DEFAULT_SEGMENT_OVERLAP = int(os.getenv("SCAIL_SEGMENT_OVERLAP", "5"))
DEFAULT_SHIFT = float(os.getenv("SCAIL_SAMPLE_SHIFT", "3.0"))
DEFAULT_GUIDE_SCALE = float(os.getenv("SCAIL_GUIDE_SCALE", "5.0"))
DEFAULT_SOLVER = os.getenv("SCAIL_SAMPLE_SOLVER", "unipc")

# --- Feature flags (enabled only when the variable is exactly "1") -------------
AUTO_CONVERT = os.getenv("SCAIL_AUTO_CONVERT", "1") == "1"
PRELOAD_PIPELINE = os.getenv("SCAIL_PRELOAD_PIPELINE", "1") == "1"
STAGE_SAFETENSORS_FOR_LOAD = os.getenv("SCAIL_STAGE_SAFETENSORS_FOR_LOAD", "1") == "1"
CONVERT_TO_STAGING_FIRST = os.getenv("SCAIL_CONVERT_TO_STAGING_FIRST", "1") == "1"

# --- Checkpoint file names ----------------------------------------------------
CLIP_CKPT_NAME = "models_clip_open-clip-xlm-roberta-large-vit-huge-14-onlyvisual.pth"
# Path of the original DiT checkpoint relative to the model repo root.
ORIGINAL_DIT_REL_PATH = "model/1/fsdp2_rank_0000_checkpoint.pt"
# Patterns passed to snapshot_download for the base (non-DiT) assets.
BASE_ALLOW_PATTERNS = [
    "Wan2.1_VAE.pth",
    "umt5-xxl/**",
    CLIP_CKPT_NAME,
]

# --- Mutable module state -----------------------------------------------------
# Cached pipeline instance and the key (paths + LoRA settings) it was built for.
_PIPELINE = None
_PIPELINE_KEY = None
# Human-readable status strings and captured tracebacks for the startup steps.
_ASSET_STATUS = "Assets were not prepared yet."
_ASSET_ERROR = None
_RUNTIME_STATUS = "Runtime was not prepared yet."
_RUNTIME_ERROR = None
_PIPELINE_STATUS = "Pipeline was not preloaded."
_PIPELINE_ERROR = None
# Path of the safetensors file produced (or reused) by the last conversion.
_LAST_CONVERTED_SAFETENSORS = None
# Lazily imported project objects, populated by _import_runtime.
_WAN = None
_GENERATE_VIDEO = None
_SCAIL_CONFIGS = None
_SCAIL_CONFIG_PATHS = None
@dataclass(frozen=True)
class PreparedExample:
    """Immutable record describing one bundled example and its input files.

    The path fields are stored as given; callers pass them through ``_abs``,
    which anchors relative paths at the repository root.
    """

    # Display name of the example (also used as the key in PREPARED_EXAMPLES).
    label: str
    # Reference image path.
    image: str
    # Mask image path for the reference image.
    mask_image: str
    # Driving / rendered video path.
    pose: str
    # Mask video path for the driving video.
    mask_video: str
    # Default text prompt shown for the example.
    prompt: str
    # True for replacement examples, False for animation examples.
    replace_flag: bool = False
# Catalogue of bundled examples keyed by display name. Entries whose files are
# not present on disk are filtered out at runtime by _existing_examples.
PREPARED_EXAMPLES = {
    "Animation 001 - end-to-end": PreparedExample(
        label="Animation 001 - end-to-end",
        image="examples/animation_001/ref.jpg",
        mask_image="examples/animation_001/ref_mask.jpg",
        pose="examples/animation_001/rendered_v2.mp4",
        mask_video="examples/animation_001/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 001 - pose-driven": PreparedExample(
        label="Animation 001 - pose-driven",
        image="examples/animation_001_posedriven/ref.jpg",
        mask_image="examples/animation_001_posedriven/ref_mask.jpg",
        pose="examples/animation_001_posedriven/rendered_v2.mp4",
        mask_video="examples/animation_001_posedriven/rendered_mask_v2.mp4",
        prompt="A young woman is dancing with energetic body movement.",
    ),
    "Animation 002 - end-to-end": PreparedExample(
        label="Animation 002 - end-to-end",
        image="examples/animation_002/ref.jpg",
        mask_image="examples/animation_002/ref_mask.jpg",
        pose="examples/animation_002/rendered_v2.mp4",
        mask_video="examples/animation_002/rendered_mask_v2.mp4",
        prompt="A character performs the motion from the driving video.",
    ),
    # The only bundled replacement-mode example (replace_flag=True).
    "Replacement 001": PreparedExample(
        label="Replacement 001",
        image="examples/replace_001/ref.png",
        mask_image="examples/replace_001/ref_mask.png",
        pose="examples/replace_001/rendered_v2.mp4",
        mask_video="examples/replace_001/replace_mask.mp4",
        prompt=(
            "A blond white male wearing a black suit, trousers, and leather shoes "
            "is playing the violin on the street while pedestrians walk past him."
        ),
        replace_flag=True,
    ),
}
def _abs(path: str | Path) -> str:
    """Return *path* as a string, anchoring relative paths at ``ROOT``.

    Absolute paths are returned unchanged (apart from the str conversion).
    """
    candidate = Path(path)
    anchored = candidate if candidate.is_absolute() else ROOT / candidate
    return str(anchored)
def _existing_examples() -> dict[str, PreparedExample]:
    """Return the prepared examples whose four input files all exist on disk.

    The result preserves the ordering of ``PREPARED_EXAMPLES``.
    """

    def _all_files_present(entry: PreparedExample) -> bool:
        # An example is usable only when every one of its inputs is present.
        needed = (entry.image, entry.mask_image, entry.pose, entry.mask_video)
        return all(Path(_abs(item)).exists() for item in needed)

    return {
        title: entry
        for title, entry in PREPARED_EXAMPLES.items()
        if _all_files_present(entry)
    }
def _require_repo_layout():
    """Verify that this file sits at the root of a SCAIL-2 checkout.

    Raises:
        RuntimeError: when any of the expected repository files is missing;
            the message lists the missing relative paths.
    """
    expected = (
        "wan/scail.py",
        "wan/modules/model_scail2.py",
        "generate.py",
        "configs/config-14b.json",
    )
    absent = [rel for rel in expected if not (ROOT / rel).exists()]
    if not absent:
        return
    raise RuntimeError(
        "This app.py is meant to live at the root of the SCAIL-2 repository. "
        f"Missing: {', '.join(absent)}"
    )
def _download_safetensors_if_configured() -> Path | None:
    """Fetch the converted safetensors file from the Hub when a repo is configured.

    Returns:
        ``None`` when ``SAFETENSORS_REPO_ID`` is unset; otherwise the local path
        of the file, reusing an existing copy in the cache directory
        (``SCAIL_SAFETENSORS_CACHE`` or ``STORAGE_ROOT/scail2_safetensors``).
    """
    if not SAFETENSORS_REPO_ID:
        return None

    cache_dir = Path(os.getenv("SCAIL_SAFETENSORS_CACHE", str(STORAGE_ROOT / "scail2_safetensors")))
    cache_dir.mkdir(parents=True, exist_ok=True)
    cached_file = cache_dir / SAFETENSORS_FILENAME
    if not cached_file.exists():
        logging.info(
            "Downloading converted SCAIL-2 safetensors from %s/%s",
            SAFETENSORS_REPO_ID,
            SAFETENSORS_FILENAME,
        )
        fetched = hf_hub_download(
            repo_id=SAFETENSORS_REPO_ID,
            filename=SAFETENSORS_FILENAME,
            local_dir=str(cache_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
        )
        return Path(fetched)
    return cached_file
def _find_converted_safetensors(ckpt_dir: Path | None) -> Path | None:
    """Locate an already-converted SCAIL-2 safetensors file.

    Candidates are probed in this order: ``SCAIL_SAFETENSORS_PATH``, the path
    recorded by the last conversion, well-known locations under ``ROOT``, the
    persistent converted directory, and finally files inside *ckpt_dir* (when
    given). If none exists, falls back to downloading from the configured
    safetensors repo, which may itself return ``None``.
    """
    search_order: list[Path] = []

    explicit = os.getenv("SCAIL_SAFETENSORS_PATH")
    if explicit:
        search_order.append(Path(explicit))
    if _LAST_CONVERTED_SAFETENSORS is not None:
        search_order.append(Path(_LAST_CONVERTED_SAFETENSORS))

    converted_dir = Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted")))
    search_order.extend(
        (
            ROOT / "SCAIL-2.safetensors",
            ROOT / "models" / "SCAIL-2.safetensors",
            ROOT / "model.safetensors",
            converted_dir / "SCAIL-2.safetensors",
        )
    )
    if ckpt_dir is not None:
        search_order.extend(
            (
                ckpt_dir / "SCAIL-2.safetensors",
                ckpt_dir / "model.safetensors",
            )
        )

    # First existing candidate wins.
    hit = next((option for option in search_order if option.exists()), None)
    if hit is not None:
        return hit
    return _download_safetensors_if_configured()
def _copy_file_with_progress(source: Path, dest: Path, description: str) -> Path:
    """Copy *source* to *dest* in chunks, logging progress and resuming partial copies.

    Data is written to ``<dest>.tmp`` and renamed onto *dest* only after the
    size has been verified, so *dest* never holds a partial file. An existing
    ``.tmp`` file that is not larger than the source is treated as a partial
    copy and appended to; its contents are not re-validated.

    Environment:
        SCAIL_STAGE_COPY_CHUNK_MB: chunk size in MiB (default 64). Values below
            1 are clamped to 1 so a zero or negative setting cannot stall the
            copy or read the whole remainder into memory at once.
        SCAIL_STAGE_COPY_LOG_GB: progress log interval in GiB (default 1);
            values <= 0 disable the periodic progress lines.

    Args:
        source: Existing file to copy.
        dest: Destination path; parent directories are created as needed.
        description: Label used as a prefix in log messages.

    Returns:
        *dest*.

    Raises:
        RuntimeError: if the source ends before the expected number of bytes
            has been read, or if the final size does not match the source size.
    """
    source_size = source.stat().st_size
    # Clamp to >= 1 MiB: read(0) returns b"" (spurious "Unexpected EOF") and a
    # negative size would make read() slurp the rest of the file in one go.
    chunk_size = max(int(os.getenv("SCAIL_STAGE_COPY_CHUNK_MB", "64")), 1) * 1024 * 1024
    log_every = int(os.getenv("SCAIL_STAGE_COPY_LOG_GB", "1")) * 1024 * 1024 * 1024
    dest.parent.mkdir(parents=True, exist_ok=True)

    # Nothing to do when a same-sized destination is already in place.
    if dest.exists() and dest.stat().st_size == source_size:
        return dest

    tmp_dest = dest.with_suffix(dest.suffix + ".tmp")
    copied = tmp_dest.stat().st_size if tmp_dest.exists() else 0
    if copied > source_size:
        # A partial file larger than the source cannot be a prefix of it.
        tmp_dest.unlink()
        copied = 0

    logging.info("%s: %s -> %s", description, source, dest)
    next_log = ((copied // log_every) + 1) * log_every if log_every > 0 else source_size
    if copied:
        logging.info(
            "Resuming copy at %.2f/%.2f GB",
            copied / 1024**3,
            source_size / 1024**3,
        )

    with source.open("rb") as src, tmp_dest.open("ab") as dst:
        if copied:
            src.seek(copied)
        while copied < source_size:
            chunk = src.read(min(chunk_size, source_size - copied))
            if not chunk:
                raise RuntimeError(f"Unexpected EOF while copying {source}: {copied} of {source_size} bytes")
            dst.write(chunk)
            copied += len(chunk)
            if log_every > 0 and copied >= next_log:
                logging.info(
                    "%s: %.2f/%.2f GB",
                    description,
                    copied / 1024**3,
                    source_size / 1024**3,
                )
                next_log += log_every

    # Verify after the handles are closed so buffered bytes are accounted for.
    final_size = tmp_dest.stat().st_size
    if final_size != source_size:
        raise RuntimeError(f"Copied file size mismatch: {final_size} != {source_size}")
    tmp_dest.replace(dest)
    logging.info("Finished %s: %s", description, dest)
    return dest
def _is_relative_to(path: Path, parent: Path) -> bool:
    """Report whether *path*, once resolved, lies at or below resolved *parent*.

    A ``ValueError`` raised while resolving or relating the paths is treated
    as "not relative".
    """
    try:
        path.resolve().relative_to(parent.resolve())
    except ValueError:
        return False
    else:
        return True
def _stage_safetensors_for_load(scail_path: Path) -> Path:
    """Return the path the pipeline should load the safetensors file from.

    When staging is enabled and the file is not already under ``STAGING_ROOT``,
    it is copied into the staging cache (``SCAIL_MODEL_LOAD_CACHE`` or
    ``STAGING_ROOT/scail2_model_load``) and the staged path is returned. A
    previously staged copy of the same size is reused.
    """
    if not STAGE_SAFETENSORS_FOR_LOAD:
        return scail_path

    origin = Path(scail_path)
    if _is_relative_to(origin, STAGING_ROOT):
        return origin

    cache_dir = Path(os.getenv("SCAIL_MODEL_LOAD_CACHE", str(STAGING_ROOT / "scail2_model_load")))
    cache_dir.mkdir(parents=True, exist_ok=True)
    target = cache_dir / origin.name

    already_staged = target.exists() and target.stat().st_size == origin.stat().st_size
    if already_staged:
        return target
    return _copy_file_with_progress(origin, target, "Staging SCAIL-2 safetensors for local load")
def _download_checkpoint_if_needed(include_original_dit: bool = False) -> Path:
    """Return the directory holding the base checkpoint assets, downloading if absent.

    ``SCAIL_CKPT_DIR`` takes precedence and must already exist. Otherwise the
    cache directory (``SCAIL_CKPT_CACHE`` or ``STORAGE_ROOT/scail2_ckpt``) is
    used, and the assets matching ``BASE_ALLOW_PATTERNS`` are downloaded from
    ``MODEL_REPO_ID`` when any of them is missing.

    Args:
        include_original_dit: Deprecated; when true a warning is logged and the
            download is otherwise unchanged.

    Raises:
        RuntimeError: if ``SCAIL_CKPT_DIR`` is set but does not exist.
    """
    override = os.getenv("SCAIL_CKPT_DIR")
    if override:
        ckpt_dir = Path(override)
        if ckpt_dir.exists():
            return ckpt_dir
        raise RuntimeError(f"SCAIL_CKPT_DIR does not exist: {ckpt_dir}")

    cache_dir = Path(os.getenv("SCAIL_CKPT_CACHE", str(STORAGE_ROOT / "scail2_ckpt")))
    required_entries = ("Wan2.1_VAE.pth", "umt5-xxl", CLIP_CKPT_NAME)
    if all((cache_dir / entry).exists() for entry in required_entries):
        return cache_dir

    logging.info("Downloading SCAIL-2 base checkpoint assets from %s", MODEL_REPO_ID)
    if include_original_dit:
        logging.warning(
            "include_original_dit is deprecated here; original DiT conversion staging "
            "uses SCAIL_ORIGINAL_DIT_CACHE instead."
        )
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=str(cache_dir),
        local_dir_use_symlinks=False,
        resume_download=True,
        allow_patterns=BASE_ALLOW_PATTERNS,
    )
    return cache_dir
def _download_original_dit_for_conversion() -> Path:
    """Return a directory that contains the original DiT checkpoint file.

    ``SCAIL_ORIGINAL_DIT_DIR`` takes precedence and must already contain
    ``ORIGINAL_DIT_REL_PATH``. Otherwise the cache directory
    (``SCAIL_ORIGINAL_DIT_CACHE`` or ``STAGING_ROOT/scail2_original_dit``) is
    used and the checkpoint is downloaded from ``MODEL_REPO_ID`` when missing.

    Raises:
        RuntimeError: if ``SCAIL_ORIGINAL_DIT_DIR`` is set but lacks the file.
    """
    override = os.getenv("SCAIL_ORIGINAL_DIT_DIR")
    if override:
        original_dir = Path(override)
        if (original_dir / ORIGINAL_DIT_REL_PATH).exists():
            return original_dir
        raise RuntimeError(f"SCAIL_ORIGINAL_DIT_DIR is missing {ORIGINAL_DIT_REL_PATH}: {original_dir}")

    cache_dir = Path(os.getenv("SCAIL_ORIGINAL_DIT_CACHE", str(STAGING_ROOT / "scail2_original_dit")))
    if not (cache_dir / ORIGINAL_DIT_REL_PATH).exists():
        logging.info(
            "Downloading original SCAIL-2 DiT checkpoint for one-time conversion into %s",
            cache_dir,
        )
        snapshot_download(
            repo_id=MODEL_REPO_ID,
            local_dir=str(cache_dir),
            local_dir_use_symlinks=False,
            resume_download=True,
            allow_patterns=[ORIGINAL_DIT_REL_PATH],
        )
    return cache_dir
def _prepare_assets_for_runtime() -> str:
    """Download CPU-side assets before any ZeroGPU function starts.

    Updates ``_ASSET_STATUS`` / ``_ASSET_ERROR`` and returns the status text,
    followed by the captured traceback when preparation failed. Exceptions are
    recorded and logged, not propagated.
    """
    global _ASSET_STATUS, _ASSET_ERROR
    try:
        ckpt_dir = _download_checkpoint_if_needed(include_original_dit=False)
        scail_path = _find_converted_safetensors(ckpt_dir)
        if scail_path is None and AUTO_CONVERT:
            scail_path = _maybe_convert_checkpoint(_download_original_dit_for_conversion(), None)

        if scail_path is not None:
            summary = (
                "Assets ready. Base checkpoint: "
                f"{ckpt_dir}. Converted DiT safetensors: {scail_path}."
            )
        else:
            summary = (
                "Base checkpoint assets are present, but no converted safetensors file "
                "was found. Set SCAIL_SAFETENSORS_REPO_ID plus SCAIL_SAFETENSORS_FILENAME, "
                "or set SCAIL_SAFETENSORS_PATH. Automatic startup conversion is disabled "
                "only when SCAIL_AUTO_CONVERT=0."
            )
        _ASSET_STATUS = summary
        _ASSET_ERROR = None
    except Exception:
        _ASSET_ERROR = traceback.format_exc()
        _ASSET_STATUS = "Asset preparation failed. See the traceback below."
        logging.exception("Asset preparation failed")

    if _ASSET_ERROR is None:
        return _ASSET_STATUS
    return _ASSET_STATUS + "\n\n" + _ASSET_ERROR
def _maybe_convert_checkpoint(ckpt_dir: Path, scail_path: Path | None) -> Path:
    """Return a usable safetensors path, converting the original checkpoint if needed.

    When *scail_path* is given it is returned as-is. Otherwise an existing
    persistent conversion is reused; failing that, ``convert.py`` is run as a
    subprocess against *ckpt_dir* (unless a converted file already sits in the
    work directory) and the result is copied to the persistent location.

    Raises:
        RuntimeError: if no safetensors path was given and ``AUTO_CONVERT`` is off.
        subprocess.CalledProcessError: if the conversion script exits non-zero.
    """
    global _LAST_CONVERTED_SAFETENSORS

    if scail_path is not None:
        return scail_path
    if not AUTO_CONVERT:
        raise RuntimeError(
            "Converted SCAIL-2 safetensors file was not found. For the wan-scail2 branch, "
            "provide SCAIL_SAFETENSORS_PATH=/path/to/SCAIL-2.safetensors, or place "
            "SCAIL-2.safetensors at the repo root. Automatic startup conversion is disabled "
            "because SCAIL_AUTO_CONVERT=0."
        )

    persistent_dir = Path(os.getenv("SCAIL_CONVERTED_DIR", str(STORAGE_ROOT / "scail2_converted")))
    persistent_path = persistent_dir / "SCAIL-2.safetensors"
    if persistent_path.exists():
        return persistent_path

    work_dir = (
        Path(os.getenv("SCAIL_CONVERSION_WORK_DIR", str(STAGING_ROOT / "scail2_converted_work")))
        if CONVERT_TO_STAGING_FIRST
        else persistent_dir
    )
    work_dir.mkdir(parents=True, exist_ok=True)
    converted = work_dir / "SCAIL-2.safetensors"

    if not converted.exists():
        logging.info("Converting checkpoint to safetensors: %s", converted)
        # PyTorch >= 2.6 changed torch.load's default to weights_only=True.
        # The official SCAIL-2 FSDP checkpoint needs the legacy trusted-pickle path
        # during conversion. Only do this for the official checkpoint conversion step.
        child_env = {**os.environ, "TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD": "1"}
        command = [
            sys.executable,
            str(ROOT / "convert.py"),
            "--scail-dir",
            str(ckpt_dir),
            "--save-path",
            str(converted),
        ]
        subprocess.run(command, check=True, cwd=str(ROOT), env=child_env)

    # Shared tail for both the "freshly converted" and "found in work dir" cases.
    _LAST_CONVERTED_SAFETENSORS = converted
    if converted != persistent_path:
        _copy_file_with_progress(converted, persistent_path, "Persisting converted SCAIL-2 safetensors to storage")
    return converted
def _install_attention_patch():
    """Use HF Kernels flash-attn2 when available; otherwise fall back to SDPA.

    SCAIL-2 imports `flash_attention` directly in model_scail2.py. This monkey patch
    avoids requiring a locally built `flash_attn` wheel on Spaces.

    The replacement is installed on ``wan.modules.attention`` and on every
    listed module that already holds its own ``flash_attention`` name.
    """
    import wan.modules.attention as attention_mod

    hf_flash_attn2 = None
    try:
        from kernels import get_kernel

        hf_flash_attn2 = get_kernel("kernels-community/flash-attn2", version=2)
        logging.info("Using kernels-community/flash-attn2 through HF Kernels.")
    except Exception as exc:
        # Kernel loading is best-effort: log the environment and keep going with SDPA.
        if torch.cuda.is_available():
            device_name = torch.cuda.get_device_name(0)
            capability = torch.cuda.get_device_capability(0)
        else:
            device_name = "no cuda"
            capability = None
        logging.warning("Could not initialize HF Kernels flash-attn2: %r", exc)
        logging.warning(
            "Attention fallback environment: torch=%s cuda=%s device=%s capability=%s",
            torch.__version__,
            torch.version.cuda,
            device_name,
            capability,
        )

    def patched_flash_attention(
        q,
        k,
        v,
        q_lens=None,
        k_lens=None,
        dropout_p=0.0,
        softmax_scale=None,
        q_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
        dtype=torch.bfloat16,
        version=None,
    ):
        """Drop-in replacement for ``flash_attention``.

        Tries the HF Kernels varlen flash-attn2 kernel on CUDA inputs and falls
        back to ``torch.nn.functional.scaled_dot_product_attention`` when the
        kernel is unavailable or raises. The output is cast back to ``q``'s
        original dtype. ``version`` is accepted for signature compatibility
        and unused. The fallback ignores ``q_lens``/``k_lens``, ``window_size``
        and ``deterministic``.
        NOTE(review): presumably q/k/v are (batch, seq, heads, head_dim) -
        confirm against the callers in wan.modules.
        """
        half_dtypes = (torch.float16, torch.bfloat16)
        b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype

        def half(x):
            # Keep fp16/bf16 tensors as they are; cast everything else to `dtype`.
            return x if x.dtype in half_dtypes else x.to(dtype)

        if hf_flash_attn2 is not None and q.device.type == "cuda":
            # Pack the batch into one flat sequence, trimming each sample to its length.
            if q_lens is None:
                q_var = half(q.flatten(0, 1))
                q_lens_t = torch.full((b,), lq, dtype=torch.int32, device=q.device)
            else:
                q_lens_t = q_lens.to(device=q.device, dtype=torch.int32)
                q_var = half(torch.cat([u[: int(n)] for u, n in zip(q, q_lens_t)]))
            if k_lens is None:
                k_var = half(k.flatten(0, 1))
                v_var = half(v.flatten(0, 1))
                k_lens_t = torch.full((b,), lk, dtype=torch.int32, device=k.device)
            else:
                k_lens_t = k_lens.to(device=k.device, dtype=torch.int32)
                k_var = half(torch.cat([u[: int(n)] for u, n in zip(k, k_lens_t)]))
                v_var = half(torch.cat([u[: int(n)] for u, n in zip(v, k_lens_t)]))
            q_var = q_var.to(v_var.dtype)
            k_var = k_var.to(v_var.dtype)
            if q_scale is not None:
                q_var = q_var * q_scale
            # Cumulative sequence offsets with a leading zero.
            cu_q = torch.cat([q_lens_t.new_zeros([1]), q_lens_t]).cumsum(0, dtype=torch.int32)
            cu_k = torch.cat([k_lens_t.new_zeros([1]), k_lens_t]).cumsum(0, dtype=torch.int32)
            try:
                out = hf_flash_attn2.flash_attn_varlen_func(
                    q=q_var,
                    k=k_var,
                    v=v_var,
                    cu_seqlens_q=cu_q,
                    cu_seqlens_k=cu_k,
                    max_seqlen_q=lq,
                    max_seqlen_k=lk,
                    dropout_p=dropout_p,
                    softmax_scale=softmax_scale,
                    causal=causal,
                    window_size=window_size,
                    deterministic=deterministic,
                )
                if isinstance(out, tuple):
                    out = out[0]
                return out.unflatten(0, (b, lq)).type(out_dtype)
            except Exception as exc:
                # Fall through to the SDPA path below.
                logging.warning("HF Kernels flash-attn2 failed, falling back to SDPA: %s", exc)

        if q_lens is not None and not torch.all(q_lens == lq):
            logging.warning("SDPA fallback ignores variable q_lens; demo batch size should stay at 1.")
        if k_lens is not None and not torch.all(k_lens == lk):
            logging.warning("SDPA fallback ignores variable k_lens; demo batch size should stay at 1.")
        # Swap dims 1 and 2 for the SDPA call; swapped back on return.
        q_sdpa = q.transpose(1, 2).to(dtype)
        k_sdpa = k.transpose(1, 2).to(dtype)
        v_sdpa = v.transpose(1, 2).to(dtype)
        if q_scale is not None:
            # Mirror the flash-attn path, which scales q before the kernel call;
            # without this the fallback silently dropped q_scale.
            q_sdpa = q_sdpa * q_scale
        out = torch.nn.functional.scaled_dot_product_attention(
            q_sdpa,
            k_sdpa,
            v_sdpa,
            attn_mask=None,
            dropout_p=dropout_p,
            is_causal=causal,
            scale=softmax_scale,
        )
        return out.transpose(1, 2).contiguous().type(out_dtype)

    attention_mod.flash_attention = patched_flash_attention
    # Several modules import `flash_attention` directly at import time
    # (`from .attention import flash_attention`), so patch those bound names too.
    for module_name in (
        "wan.modules.clip",
        "wan.modules.model",
        "wan.modules.model_scail",
        "wan.modules.model_scail2",
    ):
        try:
            module = __import__(module_name, fromlist=["flash_attention"])
            if hasattr(module, "flash_attention"):
                module.flash_attention = patched_flash_attention
                logging.info("Patched %s.flash_attention", module_name)
        except Exception as exc:
            logging.warning("Could not patch %s.flash_attention: %s", module_name, exc)
def _import_runtime():
    """Import the SCAIL/Wan runtime once and cache its entry points in module globals.

    Subsequent calls are no-ops once ``_WAN`` has been set. The repository
    layout is validated first, ``ROOT`` is put at the front of ``sys.path`` if
    absent, and the attention patch is installed before the globals are set.
    """
    global _WAN, _GENERATE_VIDEO, _SCAIL_CONFIGS, _SCAIL_CONFIG_PATHS

    if _WAN is not None:
        return

    _require_repo_layout()
    repo_entry = str(ROOT)
    if repo_entry not in sys.path:
        sys.path.insert(0, repo_entry)

    import wan
    from generate import generate_video
    from wan.configs import SCAIL_CONFIGS, SCAIL_CONFIG_PATHS

    _install_attention_patch()

    _WAN, _GENERATE_VIDEO = wan, generate_video
    _SCAIL_CONFIGS, _SCAIL_CONFIG_PATHS = SCAIL_CONFIGS, SCAIL_CONFIG_PATHS
def _prepare_runtime_for_startup() -> str:
    """Import Wan/SCAIL and resolve HF Kernels before any ZeroGPU call.

    Updates ``_RUNTIME_STATUS`` / ``_RUNTIME_ERROR`` and returns the status
    text, followed by the captured traceback on failure. Exceptions are
    recorded and logged, not propagated.
    """
    global _RUNTIME_STATUS, _RUNTIME_ERROR
    try:
        _import_runtime()
    except Exception:
        _RUNTIME_ERROR = traceback.format_exc()
        _RUNTIME_STATUS = "Runtime preparation failed. See the traceback below."
        logging.exception("Runtime preparation failed")
    else:
        _RUNTIME_STATUS = "Runtime ready. Attention backend has been initialized at startup."
        _RUNTIME_ERROR = None

    if _RUNTIME_ERROR is None:
        return _RUNTIME_STATUS
    return _RUNTIME_STATUS + "\n\n" + _RUNTIME_ERROR
def _get_pipeline():
    """Return the cached SCAIL-2 pipeline, building it when inputs have changed.

    Resolves the base checkpoint directory, the converted safetensors file
    (converting or staging it as configured), the config path and the LoRA
    settings. The pipeline is rebuilt only when that combination differs from
    the one used for the cached instance.
    """
    global _PIPELINE, _PIPELINE_KEY

    _import_runtime()
    ckpt_dir = _download_checkpoint_if_needed(include_original_dit=False)
    located = _find_converted_safetensors(ckpt_dir)
    if located is None and AUTO_CONVERT:
        scail_path = _maybe_convert_checkpoint(_download_original_dit_for_conversion(), None)
    else:
        scail_path = _maybe_convert_checkpoint(ckpt_dir, located)
    scail_load_path = _stage_safetensors_for_load(scail_path)

    config_path = Path(os.getenv("SCAIL_CONFIG_PATH", _SCAIL_CONFIG_PATHS[MODEL_NAME]))
    if not config_path.is_absolute():
        config_path = ROOT / config_path

    lora_path = os.getenv("SCAIL_LORA_PATH") or None
    lora_alpha = float(os.getenv("SCAIL_LORA_ALPHA", "1.0"))

    cache_key = (str(ckpt_dir), str(scail_load_path), str(config_path), lora_path, lora_alpha)
    if _PIPELINE is not None and _PIPELINE_KEY == cache_key:
        return _PIPELINE

    logging.info("Loading SCAIL-2 pipeline.")
    build_options = dict(
        config=_SCAIL_CONFIGS[MODEL_NAME],
        checkpoint_dir=str(ckpt_dir),
        scail_safetensors_path=str(scail_load_path),
        scail_config_path=str(config_path),
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_usp=False,
        t5_cpu=False,
        lora_path=lora_path,
        lora_alpha=lora_alpha,
    )
    _PIPELINE = _WAN.SCAIL2Pipeline(**build_options)
    _PIPELINE_KEY = cache_key
    return _PIPELINE
def _prepare_pipeline_for_startup() -> str:
    """Preload the CUDA pipeline at module startup for ZeroGPU CUDA emulation.

    Updates ``_PIPELINE_STATUS`` / ``_PIPELINE_ERROR`` and returns the status
    text, followed by the captured traceback on failure. Exceptions are
    recorded and logged, not propagated.
    """
    global _PIPELINE_STATUS, _PIPELINE_ERROR
    try:
        _get_pipeline()
    except Exception:
        _PIPELINE_ERROR = traceback.format_exc()
        _PIPELINE_STATUS = "Pipeline preload failed. See the traceback below."
        logging.exception("Pipeline preload failed")
    else:
        _PIPELINE_STATUS = "Pipeline preloaded at startup."
        _PIPELINE_ERROR = None

    if _PIPELINE_ERROR is None:
        return _PIPELINE_STATUS
    return _PIPELINE_STATUS + "\n\n" + _PIPELINE_ERROR
def _is_gradio_native_file_path(path: Path) -> bool:
    """Return True when Gradio can serve a file without explicit allowed_paths.

    That is the case here when the resolved path lies under ``ROOT`` or under
    the system temporary directory.
    """
    target = path.resolve()
    servable_roots = (ROOT.resolve(), Path(tempfile.gettempdir()).resolve())
    for root in servable_roots:
        if _is_relative_to(target, root):
            return True
    return False
def _prepare_output_for_gradio(path: str | Path) -> str:
    """Return a display path accepted by Gradio.

    The canonical output is saved in OUTPUT_DIR, which may live on a mounted
    Bucket such as /data/scail2_outputs. Gradio can be strict about returned
    file paths, so we keep the persistent copy in the Bucket and return a
    temporary copy under /tmp for the UI component.

    Raises:
        RuntimeError: if the generated file does not exist.
    """
    produced = Path(path)
    if not produced.exists():
        raise RuntimeError(f"Generated video was not found: {produced}")
    if _is_gradio_native_file_path(produced):
        return str(produced)

    fallback_cache = str(Path(tempfile.gettempdir()) / "scail2_gradio_outputs")
    display_dir = Path(os.getenv("SCAIL_GRADIO_OUTPUT_CACHE", fallback_cache))
    display_dir.mkdir(parents=True, exist_ok=True)

    served = display_dir / produced.name
    shutil.copy2(produced, served)
    logging.info("Copied generated video for Gradio display: %s -> %s", produced, served)
    return str(served)
def _duration_for_job(*args, **kwargs):
    """Return the GPU duration (seconds) to request for a generation call.

    ``SCAIL_GPU_DURATION`` overrides everything; otherwise the cold duration is
    used while no pipeline is cached and the warm duration afterwards. The
    call arguments are accepted and ignored.
    """
    fallback = GPU_DURATION_COLD if _PIPELINE is None else GPU_DURATION_WARM
    return int(os.getenv("SCAIL_GPU_DURATION", str(fallback)))
def _run_scail_job(
    image_path,
    mask_image_path,
    pose_path,
    mask_video_path,
    prompt,
    replace_flag,
    target_h,
    target_w,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    segment_len,
    segment_overlap,
    progress=None,
):
    """Run one SCAIL-2 generation and return a path Gradio can display.

    Loads (or reuses) the pipeline, builds the argument namespace passed to
    the project's ``generate_video``, writes the result to a uniquely named
    file in ``OUTPUT_DIR``, frees cached memory, and returns the display path
    from ``_prepare_output_for_gradio``. *progress*, when given, is called
    with a fraction and a ``desc`` label at each stage.
    """

    def report(fraction, label):
        # Progress reporting is optional; skip silently when no callback is given.
        if progress is not None:
            progress(fraction, desc=label)

    report(0.02, "Loading SCAIL-2 pipeline")
    pipeline = _get_pipeline()
    cfg = _SCAIL_CONFIGS[MODEL_NAME]
    save_file = OUTPUT_DIR / f"scail2_{uuid.uuid4().hex}.mp4"

    report(0.12, "Preparing inputs")
    prompt_text = prompt or ""
    args = SimpleNamespace(
        target_h=int(target_h),
        target_w=int(target_w),
        sample_shift=float(sample_shift),
        sample_solver=DEFAULT_SOLVER,
        segment_len=int(segment_len),
        segment_overlap=int(segment_overlap),
        sample_steps=int(sample_steps),
        sample_guide_scale=float(guide_scale),
        base_seed=int(seed),
        offload_model=True,
        save_file=str(save_file),
        save_dir=str(OUTPUT_DIR),
        prompt=prompt_text,
    )

    report(0.15, "Generating video")
    _GENERATE_VIDEO(
        pipeline,
        prompt_text,
        str(image_path),
        str(mask_image_path),
        str(pose_path),
        str(mask_video_path),
        args,
        device=0,
        rank=0,
        cfg=cfg,
        input_idx=None,
        replace_flag=bool(replace_flag),
        additional_task_input=None,
    )

    report(0.95, "Finalizing output")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    report(0.98, "Preparing video for display")
    display_file = _prepare_output_for_gradio(save_file)

    report(1.0, "Done")
    return display_file
@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_example(
    example_name,
    prompt,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from one of the prepared examples.

    *target_size* is a ``"<width>x<height>"`` string.

    Returns:
        ``(video_path, "Done.")`` on success, or ``(None, traceback_text)``
        when anything fails; failures are logged and never raised.
    """
    try:
        progress(0.0, desc="Checking example inputs")
        example = _existing_examples().get(example_name)
        if example is None:
            raise RuntimeError(f"Example is missing from this checkout: {example_name}")

        target_w, target_h = map(int, str(target_size).split("x"))
        input_paths = [
            _abs(example.image),
            _abs(example.mask_image),
            _abs(example.pose),
            _abs(example.mask_video),
        ]
        output = _run_scail_job(
            *input_paths,
            prompt,
            example.replace_flag,
            target_h,
            target_w,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()
    return output, "Done."
@spaces.GPU(duration=_duration_for_job, size=GPU_SIZE)
def generate_from_uploads(
    image,
    mask_image,
    pose_video,
    mask_video,
    prompt,
    mode,
    sample_steps,
    guide_scale,
    sample_shift,
    seed,
    target_size,
    segment_len,
    segment_overlap,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video from user-uploaded, already-prepared inputs.

    All four media inputs are required. *mode* selects replacement when it
    equals ``"replacement"``. *target_size* is a ``"<width>x<height>"`` string.

    Returns:
        ``(video_path, "Done.")`` on success, or ``(None, traceback_text)``
        when anything fails; failures are logged and never raised.
    """
    try:
        progress(0.0, desc="Checking uploaded inputs")
        labelled_inputs = (
            ("reference image", image),
            ("reference mask", mask_image),
            ("driving/rendered video", pose_video),
            ("driving mask video", mask_video),
        )
        absent = [label for label, value in labelled_inputs if value is None]
        if absent:
            raise RuntimeError("Missing required input(s): " + ", ".join(absent))

        target_w, target_h = map(int, str(target_size).split("x"))
        use_replacement = mode == "replacement"
        output = _run_scail_job(
            image,
            mask_image,
            pose_video,
            mask_video,
            prompt,
            use_replacement,
            target_h,
            target_w,
            sample_steps,
            guide_scale,
            sample_shift,
            seed,
            segment_len,
            segment_overlap,
            progress=progress,
        )
    except Exception:
        logging.exception("Generation failed")
        return None, traceback.format_exc()
    return output, "Done."
def load_example_preview(example_name):
    """Return the preview values for the selected example.

    The tuple is ``(image, driving video, mask image, mask video, prompt,
    mode)``. When the example is unknown or its files are missing, the media
    slots are ``None``, the prompt is empty and the mode is ``"animation"``.
    """
    example = _existing_examples().get(example_name)
    if example is None:
        return None, None, None, None, "", "animation"

    if example.replace_flag:
        mode = "replacement"
    else:
        mode = "animation"
    media = (
        _abs(example.image),
        _abs(example.pose),
        _abs(example.mask_image),
        _abs(example.mask_video),
    )
    return (*media, example.prompt, mode)
def _startup_message():
    """Compose the text shown in the "Startup status" box.

    Reports the number of usable examples, the storage locations and the
    current asset/runtime/pipeline status strings. If the layout check or
    example scan raises, the exception's message is returned instead.
    """
    try:
        _require_repo_layout()
        usable_count = len(_existing_examples())
        if usable_count == 0:
            return "Repo layout detected, but no prepared examples were found."
        return (
            f"Ready. Found {usable_count} prepared example(s). "
            f"Storage root: {STORAGE_ROOT}. Staging root: {STAGING_ROOT}. "
            f"Output dir: {OUTPUT_DIR}.\n\n"
            f"{_ASSET_STATUS}\n\n{_RUNTIME_STATUS}\n\n{_PIPELINE_STATUS}\n\n"
            "Attention backend: HF Kernels flash-attn2 when available, otherwise SDPA."
        )
    except Exception as exc:
        return str(exc)
def build_ui():
    """Build and return the Gradio Blocks app.

    The UI has a startup-status box and two tabs: "Prepared Examples", which
    previews a bundled example and runs ``generate_from_example``, and
    "Advanced Uploads", which runs ``generate_from_uploads`` on user files.
    NOTE(review): the nesting of Rows/Tabs below was reconstructed from a
    source view without indentation - confirm the layout against the original.
    """
    examples = _existing_examples()
    example_names = list(examples.keys())
    # Preselect the first available example, if any, and preload its preview.
    default_example = example_names[0] if example_names else None
    default_preview = load_example_preview(default_example) if default_example else (None, None, None, None, "", "animation")
    with gr.Blocks(title="SCAIL-2 ZeroGPU Demo") as demo:
        gr.Markdown(
            "# SCAIL-2 Character Animation Demo\n"
            "Animate or replace characters with SCAIL-2, directly in the browser. "
            "The easiest way to try the model is to start from one of the prepared examples below.\n\n"
            "This Space runs a heavy 14B video model on Hugging Face ZeroGPU. "
            "Generation is not realtime: a short example can take a few minutes, especially "
            "when the Space has just started.\n\n"
            "For custom inputs, use the Advanced Uploads tab with already-prepared SCAIL-2 "
            "conditions: reference image, reference mask, driving video, and driving mask. "
            "Automatic SCAIL-Pose/SAM3 preprocessing is not enabled in this public demo yet."
        )
        # The status text is computed once, when the UI is built.
        startup = gr.Textbox(value=_startup_message(), label="Startup status", interactive=False)
        with gr.Tab("Prepared Examples"):
            with gr.Row():
                example_dropdown = gr.Dropdown(
                    choices=example_names,
                    value=default_example,
                    label="Example",
                )
                mode_view = gr.Textbox(value=default_preview[5], label="Mode", interactive=False)
            # default_preview order: image, driving video, mask image, mask video, prompt, mode.
            with gr.Row():
                ref_preview = gr.Image(value=default_preview[0], label="Reference", interactive=False)
                driving_preview = gr.Video(value=default_preview[1], label="Driving / rendered video")
            with gr.Row():
                ref_mask_preview = gr.Image(value=default_preview[2], label="Reference mask", interactive=False)
                driving_mask_preview = gr.Video(value=default_preview[3], label="Driving mask")
            prompt = gr.Textbox(value=default_preview[4], label="Prompt", lines=3)
            with gr.Row():
                sample_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                guide_scale = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                sample_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                seed = gr.Number(value=42, precision=0, label="Seed")
                # Sizes are "<width>x<height>", matching the parsing in the generate_* handlers.
                target_size = gr.Dropdown(["896x512", "512x896", "1280x704", "704x1280"], value=f"{DEFAULT_TARGET_W}x{DEFAULT_TARGET_H}", label="Target size")
                segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")
            run_example = gr.Button("Generate", variant="primary")
            output_video = gr.Video(label="Output")
            status = gr.Textbox(label="Run status", lines=8)
            # Changing the example refreshes the previews, the prompt and the mode label.
            example_dropdown.change(
                load_example_preview,
                inputs=[example_dropdown],
                outputs=[ref_preview, driving_preview, ref_mask_preview, driving_mask_preview, prompt, mode_view],
            )
            # Input order must match generate_from_example's positional parameters.
            run_example.click(
                generate_from_example,
                inputs=[
                    example_dropdown,
                    prompt,
                    sample_steps,
                    guide_scale,
                    sample_shift,
                    seed,
                    target_size,
                    segment_len,
                    segment_overlap,
                ],
                outputs=[output_video, status],
            )
        with gr.Tab("Advanced Uploads"):
            gr.Markdown(
                "Upload custom inputs that have already been prepared for SCAIL-2: reference image, "
                "reference mask, driving/rendered video, and driving mask video. "
                "This public demo does not run SCAIL-Pose/SAM3 preprocessing yet."
            )
            with gr.Row():
                up_image = gr.Image(type="filepath", label="Reference image")
                up_mask_image = gr.Image(type="filepath", label="Reference mask")
            with gr.Row():
                up_pose_video = gr.Video(label="Driving / rendered video")
                up_mask_video = gr.Video(label="Driving mask / replace mask")
            up_mode = gr.Radio(["animation", "replacement"], value="animation", label="Mode")
            up_prompt = gr.Textbox(label="Prompt", lines=3)
            with gr.Row():
                up_steps = gr.Slider(4, 40, value=8, step=1, label="Steps")
                up_cfg = gr.Slider(1.0, 8.0, value=DEFAULT_GUIDE_SCALE, step=0.1, label="CFG")
                up_shift = gr.Slider(1.0, 6.0, value=DEFAULT_SHIFT, step=0.1, label="Shift")
            with gr.Row():
                up_seed = gr.Number(value=42, precision=0, label="Seed")
                up_target_size = gr.Dropdown(["896x512", "512x896", "1280x704", "704x1280"], value=f"{DEFAULT_TARGET_W}x{DEFAULT_TARGET_H}", label="Target size")
                up_segment_len = gr.Number(value=DEFAULT_SEGMENT_LEN, precision=0, label="Segment length")
                up_segment_overlap = gr.Number(value=DEFAULT_SEGMENT_OVERLAP, precision=0, label="Segment overlap")
            run_upload = gr.Button("Generate from uploads", variant="primary")
            upload_output = gr.Video(label="Output")
            upload_status = gr.Textbox(label="Run status", lines=8)
            # Input order must match generate_from_uploads' positional parameters.
            run_upload.click(
                generate_from_uploads,
                inputs=[
                    up_image,
                    up_mask_image,
                    up_pose_video,
                    up_mask_video,
                    up_prompt,
                    up_mode,
                    up_steps,
                    up_cfg,
                    up_shift,
                    up_seed,
                    up_target_size,
                    up_segment_len,
                    up_segment_overlap,
                ],
                outputs=[upload_output, upload_status],
            )
    return demo
if __name__ == "__main__":
    # Startup steps run in order (assets, runtime import, pipeline preload).
    # Each is skipped unless its environment flag is "1" (the default); each
    # records its own failure in a status string and does not raise.
    if os.getenv("SCAIL_PRELOAD_ASSETS", "1") == "1":
        _prepare_assets_for_runtime()
    if os.getenv("SCAIL_PRELOAD_RUNTIME", "1") == "1":
        _prepare_runtime_for_startup()
    if PRELOAD_PIPELINE:
        _prepare_pipeline_for_startup()
    # The UI is built after the startup steps so the status box reflects them.
    # OUTPUT_DIR is passed as an allowed path for the launched app.
    build_ui().queue(max_size=8).launch(
        allowed_paths=[str(OUTPUT_DIR.resolve())],
        show_error=True,
    )