| |
| import argparse |
| import logging |
| import os |
| import sys |
| import warnings |
| from datetime import datetime |
|
|
# Installed before the heavy third-party imports below so that warnings
# emitted while those modules are imported are suppressed as well.
warnings.filterwarnings(action="ignore")
|
|
| import random |
|
|
| import torch |
| import torch.distributed as dist |
|
|
| from einops import rearrange |
| from PIL import Image |
|
|
| import wan |
| from wan.configs import SCAIL_CONFIGS, SCAIL_CONFIG_PATHS |
| from wan.utils.utils import cache_video, str2bool |
| from wan.utils.scail_utils import load_image_to_tensor_chw_normalized, load_video_for_pose_sample, resize_for_rectangle_crop, get_tasks_from_txt |
|
|
|
|
| def _validate_args(args): |
| assert args.ckpt_dir is not None, "Please specify the checkpoint directory." |
| if args.txt is None: |
| assert args.pose is not None, "Please specify the pose video." |
| assert args.image is not None, "Please specify the reference image." |
| assert str(args.model).upper() in SCAIL_CONFIGS |
|
|
| args.model = str(args.model).upper() |
|
|
| if args.scail_config_path is None: |
| args.scail_config_path = SCAIL_CONFIG_PATHS[args.model] |
|
|
| if args.sample_steps is None: |
| args.sample_steps = 40 |
|
|
| if args.sample_shift is None: |
| args.sample_shift = 3.0 |
|
|
| if args.additional_ref_image is not None and args.additional_ref_mask_image is None: |
| raise ValueError("Please specify --additional_ref_mask_image when using --additional_ref_image.") |
| if args.additional_ref_image is None and args.additional_ref_mask_image is not None: |
| raise ValueError("--additional_ref_mask_image requires --additional_ref_image.") |
| if args.additional_ref_image is not None and len(args.additional_ref_image) != len(args.additional_ref_mask_image): |
| raise ValueError( |
| f"--additional_ref_image and --additional_ref_mask_image must have the same number of paths, " |
| f"got {len(args.additional_ref_image)} and {len(args.additional_ref_mask_image)}.") |
|
|
| args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(0, sys.maxsize) |
|
|
|
|
def _parse_args():
    """Build the command-line parser, parse the process arguments and validate.

    Returns:
        The parsed ``argparse.Namespace`` after ``_validate_args`` has checked
        it and filled in derived defaults.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Model selection and memory placement.
    add("--model", type=str, default="SCAIL-14B",
        help="Type of SCAIL model. Choices: [SCAIL-14B, SCAIL-1.3B]")
    add("--ckpt_dir", type=str, default="./SCAIL-Preview/",
        help="The path to the checkpoint directory.")
    add("--offload_model", type=str2bool, default=None,
        help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage.")

    # Parallelism.
    add("--ulysses_size", type=int, default=1,
        help="The size of the ulysses parallelism in DiT.")
    add("--ring_size", type=int, default=1,
        help="The size of the ring attention parallelism in DiT.")
    add("--t5_fsdp", action="store_true", default=False,
        help="Whether to use FSDP for T5.")
    add("--t5_cpu", action="store_true", default=False,
        help="Whether to place T5 model on CPU.")
    add("--dit_fsdp", action="store_true", default=False,
        help="Whether to use FSDP for DiT.")

    # Output location.
    add("--save_dir", type=str, default="samples",
        help="The directory to save the generated videos when --txt is not None.")
    add("--save_file", type=str, default=None,
        help="The file to save the generated video to.")

    # Prompt, seed and task list.
    add("--prompt", type=str, default=None,
        help="The prompt to generate the video from.")
    add("--base_seed", type=int, default=-1,
        help="The seed to use for generating the video.")
    add("--txt", type=str, default=None,
        help="Path to txt file. Default: None")

    # Reference and driving inputs.
    add("--image", type=str, default=None,
        help="The reference image to generate the video from.")
    add("--additional_ref_image", "--additional_image",
        dest="additional_ref_image", type=str, nargs="+", default=None,
        help="Additional reference image paths (beta).")
    add("--additional_ref_mask_image", "--additional_mask_image",
        dest="additional_ref_mask_image", type=str, nargs="+", default=None,
        help="Mask image paths for the additional reference images (beta).")
    add("--mask_image", type=str, default=None,
        help="The mask of reference image.")
    add("--pose", type=str, default=None,
        help="The rendered pose video to generate the video from.")
    add("--mask_video", type=str, default=None,
        help="The mask of driving video.")
    add("--replace_flag", action="store_true", default=False,
        help="Pass --replace_flag to run in replacement mode. Default: False (animation mode).")

    # Output resolution.
    add("--target_h", type=int, default=512,
        help="The target height of the generated video.")
    add("--target_w", type=int, default=896,
        help="The target width of the generated video.")

    # SCAIL weights and configuration.
    add("--scail_path", type=str, default=None,
        help="Path to converted SCAIL.safetensors")
    add("--scail_config_path", type=str, default=None,
        help="Path to config.json of SCAIL")

    # Sampling.
    add("--sample_solver", type=str, default='unipc', choices=['unipc', 'dpm++'],
        help="The solver used to sample.")
    add("--sample_steps", type=int, default=None,
        help="The sampling steps.")
    add("--sample_shift", type=float, default=None,
        help="Sampling shift factor for flow matching schedulers.")
    add("--sample_guide_scale", type=float, default=5.0,
        help="Classifier free guidance scale.")

    # Long-video segmentation.
    add("--segment_len", type=int, default=81,
        help="The number of pixel frames to sample per segment for long-video inference.")
    add("--segment_overlap", type=int, default=5,
        help="The number of pixel frames reused as clean history between adjacent segments.")

    # LoRA.
    add("--lora_path", type=str, default=None,
        help="Path to safetensors of LoRA.")
    add("--lora_alpha", type=float, default=1.0,
        help="Strength of LoRA. Default: 1.0")

    parsed = cli.parse_args()
    _validate_args(parsed)
    return parsed
|
|
|
|
| def _init_logging(rank): |
| |
| if rank == 0: |
| |
| logging.basicConfig( |
| level=logging.INFO, |
| format="[%(asctime)s] %(levelname)s: %(message)s", |
| handlers=[logging.StreamHandler(stream=sys.stdout)]) |
| else: |
| logging.basicConfig(level=logging.ERROR) |
|
|
| def _check_input_path(path, name): |
| if path is None: |
| raise ValueError(f"Please specify {name}.") |
| if not os.path.exists(path): |
| raise FileNotFoundError(f"{name} does not exist: {path}") |
| if not os.path.isfile(path): |
| raise FileNotFoundError(f"{name} is not a file: {path}") |
|
|
|
|
def _load_rgb_tensor(path, device):
    """Open ``path`` as an RGB image and return its normalized tensor on ``device``."""
    # ``convert`` decodes the pixels into a new in-memory image, so the file
    # handle can be released as soon as the ``with`` block exits.
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return load_image_to_tensor_chw_normalized(rgb).to(device)


def _crop_reference(tensor, size):
    """Center-crop/resize an image tensor to ``size`` and squeeze dim 0."""
    return resize_for_rectangle_crop(tensor, size, reshape_mode="center").squeeze(0)


def _load_driving_frames(path, size):
    """Load a video, center-crop it to ``size`` and rescale as (x - 127.5) / 127.5.

    The loaded frames are permuted with ``(0, 3, 1, 2)`` before cropping, i.e.
    the loader presumably returns channel-last frames (t, h, w, c) -- TODO confirm.
    """
    frames = load_video_for_pose_sample(path)
    frames = frames.permute(0, 3, 1, 2)
    frames = resize_for_rectangle_crop(frames, size, reshape_mode="center")
    return (frames - 127.5) / 127.5


def generate_video(pipeline: wan.SCAIL2Pipeline, prompt: str, image_path: str, image_mask_path: str, pose_path: str, driving_mask_path: str, args, device, rank, cfg, input_idx, replace_flag, additional_task_input=None):
    """Run one generation task and, on rank 0, save the resulting video.

    All inputs are center-cropped/resized to the target resolution. The target
    comes from ``args.target_h`` / ``args.target_w`` (falling back to the
    reference image size when either is ``None``) and is swapped when its
    orientation (portrait/landscape) disagrees with the reference image.

    Args:
        pipeline: Pipeline whose ``generate`` method produces the video.
        prompt: Text prompt for this task.
        image_path: Path of the reference image.
        image_mask_path: Path of the reference image mask.
        pose_path: Path of the rendered pose video.
        driving_mask_path: Path of the driving mask video.
        args: Parsed CLI namespace (sampling options, target size, save paths).
        device: Device the image tensors are moved to.
        rank: Process rank; only rank 0 writes the output file.
        cfg: Model config; only ``cfg.sample_fps`` is read here.
        input_idx: Task index, or ``None``. When set, the video is written to
            ``<args.save_dir>/<input_idx zero-padded to 7 digits>/``.
        replace_flag: ``True`` for replacement mode, ``False`` for animation.
        additional_task_input: Optional dict with the keys
            ``"additional_ref_image_paths"`` and
            ``"additional_ref_mask_image_paths"``.

    Raises:
        ValueError: If a required input path is ``None``.
        FileNotFoundError: If an input path does not exist or is not a file.
    """
    _check_input_path(image_path, "input image")
    _check_input_path(image_mask_path, "input mask image")
    _check_input_path(pose_path, "input pose video")
    _check_input_path(driving_mask_path, "input mask video")

    additional_task_input = additional_task_input or {}
    additional_input = {}

    logging.info(f"Input prompt: {prompt}")
    logging.info(f"Input image: {image_path}")
    img_uncropped = _load_rgb_tensor(image_path, device)

    target_h, target_w = args.target_h, args.target_w
    _, _, h, w = img_uncropped.shape
    if target_h is None or target_w is None:
        target_h, target_w = h, w
    # Keep the requested resolution but match the reference image orientation.
    if (h < w and target_h > target_w) or (h > w and target_h < target_w):
        target_h, target_w = target_w, target_h
    target_size = (target_h, target_w)
    img = _crop_reference(img_uncropped, target_size)

    logging.info(f"Input mask image: {image_mask_path}")
    mask_img = _crop_reference(_load_rgb_tensor(image_mask_path, device), target_size)

    extra_image_paths = additional_task_input.get("additional_ref_image_paths")
    if extra_image_paths is not None:
        extra_mask_paths = additional_task_input["additional_ref_mask_image_paths"]
        additional_imgs = []
        additional_mask_imgs = []
        for idx, (ref_path, ref_mask_path) in enumerate(zip(extra_image_paths, extra_mask_paths)):
            _check_input_path(ref_path, f"additional ref image {idx}")
            _check_input_path(ref_mask_path, f"additional ref mask image {idx}")
            logging.info(f"Input additional reference image {idx}: {ref_path}")
            additional_imgs.append(
                _crop_reference(_load_rgb_tensor(ref_path, device), target_size))
            logging.info(f"Input additional reference mask image {idx}: {ref_mask_path}")
            additional_mask_imgs.append(
                _crop_reference(_load_rgb_tensor(ref_mask_path, device), target_size))
        additional_input["additional_ref_imgs"] = additional_imgs
        additional_input["additional_ref_mask_imgs"] = additional_mask_imgs

    logging.info(f"Input pose video: {pose_path}")
    # NOTE(review): the pose video stays in 't c h w' order while the mask
    # video below is rearranged to 'c t h w' -- presumably what the pipeline
    # expects; confirm against ``pipeline.generate``.
    pose_video = _load_driving_frames(pose_path, target_size)

    logging.info(f"Input mask video: {driving_mask_path}")
    driving_mask_video = _load_driving_frames(driving_mask_path, target_size)
    driving_mask_video = rearrange(driving_mask_video, 't c h w -> c t h w')

    logging.info(f"Mode: {'Replacement' if replace_flag else 'Animation'}")

    logging.info("Generating video ...")
    video = pipeline.generate(
        prompt,
        img,
        ref_mask_img=mask_img,
        pose_video=pose_video,
        driving_mask_video=driving_mask_video,
        replace_flag=replace_flag,
        shift=args.sample_shift,
        sample_solver=args.sample_solver,
        segment_len=args.segment_len,
        segment_overlap=args.segment_overlap,
        sampling_steps=args.sample_steps,
        guide_scale=args.sample_guide_scale,
        seed=args.base_seed,
        offload_model=args.offload_model,
        **additional_input
    )

    if rank != 0:
        return

    save_file = args.save_file
    if save_file is None:
        # Build the default name per task from this task's own prompt, and do
        # not store it back on ``args``: caching it there would make every
        # later task reuse the first task's file name and timestamp.
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        prompt_tag = (prompt or "").replace(" ", "_").replace("/", "_")[:50]
        # '*' is not a valid file-name character on Windows, hence 'x' there.
        size_sep = 'x' if sys.platform == 'win32' else '*'
        save_file = (
            f"SCAIL2_{args.target_w}{size_sep}{args.target_h}_"
            f"{args.ring_size}_{prompt_tag}_{timestamp}.mp4")
    if input_idx is not None:
        save_dir = os.path.join(args.save_dir, f"{input_idx:07}")
        os.makedirs(save_dir, exist_ok=True)
        save_file = os.path.join(save_dir, save_file)

    logging.info(f"Saving generated video to {save_file}")
    cache_video(
        tensor=video[None],
        save_file=save_file,
        fps=cfg.sample_fps,
        nrow=1,
        normalize=True,
        value_range=(-1, 1))
|
|
def generate(args):
    """Set up the (optionally distributed) runtime and run all generation tasks.

    Reads ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK`` from the environment,
    configures logging, initializes ``torch.distributed`` (and xfuser sequence
    parallelism when ``ulysses_size`` or ``ring_size`` exceeds 1) in
    multi-process runs, broadcasts rank 0's seed to all ranks, builds the
    SCAIL-2 pipeline and calls ``generate_video`` once per task.

    Args:
        args: Parsed and validated CLI namespace. ``offload_model``,
            ``base_seed`` and ``prompt`` may be updated in place.

    Raises:
        NotImplementedError: If ``args.txt`` is set.
        AssertionError: If the parallelism options do not fit the
            environment (FSDP or context parallelism in a single-process run,
            ``ulysses_size * ring_size != WORLD_SIZE``, or a head count not
            divisible by ``ulysses_size``).
    """
    rank = int(os.getenv("RANK", 0))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    device = local_rank
    _init_logging(rank)

    if args.offload_model is None:
        # Offload by default only in single-process runs.
        args.offload_model = world_size <= 1
        logging.info(
            f"offload_model is not specified, set to {args.offload_model}.")

    if world_size > 1:
        torch.cuda.set_device(local_rank)
        # The default process group must exist before dist.get_rank(),
        # dist.get_world_size() and the seed broadcast below can be used.
        # Guarded so a group created elsewhere is left untouched.
        if not dist.is_initialized():
            dist.init_process_group(
                backend="nccl",
                init_method="env://",
                rank=rank,
                world_size=world_size)
    else:
        assert not (
            args.t5_fsdp or args.dit_fsdp
        ), "t5_fsdp and dit_fsdp are not supported in non-distributed environments."
        assert not (
            args.ulysses_size > 1 or args.ring_size > 1
        ), "context parallel are not supported in non-distributed environments."

    use_context_parallel = args.ulysses_size > 1 or args.ring_size > 1
    if use_context_parallel:
        assert args.ulysses_size * args.ring_size == world_size, "The number of ulysses_size and ring_size should be equal to the world size."
        # Imported lazily so xfuser is only required for context parallelism.
        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
        )
        init_distributed_environment(
            rank=dist.get_rank(), world_size=dist.get_world_size())

        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=args.ring_size,
            ulysses_degree=args.ulysses_size,
        )

    cfg = SCAIL_CONFIGS[args.model]
    if args.ulysses_size > 1:
        assert cfg.num_heads % args.ulysses_size == 0, f"`{cfg.num_heads=}` cannot be divided evenly by `{args.ulysses_size=}`."

    logging.info(f"Generation job args: {args}")

    if dist.is_initialized():
        # Every rank must sample with the same seed; use rank 0's.
        base_seed = [args.base_seed] if rank == 0 else [None]
        dist.broadcast_object_list(base_seed, src=0)
        args.base_seed = base_seed[0]

    if args.prompt is None:
        args.prompt = ""

    additional_task_input = {}
    if args.additional_ref_image is not None:
        additional_task_input["additional_ref_image_paths"] = args.additional_ref_image
        additional_task_input["additional_ref_mask_image_paths"] = args.additional_ref_mask_image

    if args.txt is not None:
        # Task lists from a txt file are not supported; the code that used to
        # follow this raise was unreachable and has been removed.
        raise NotImplementedError("Generation from a --txt task file is not implemented.")
    tasks = [(args.prompt, args.image, args.mask_image, args.pose, args.mask_video, None, additional_task_input)]

    logging.info("Creating SCAIL-2 pipeline.")
    scail_pipeline = wan.SCAIL2Pipeline(
        config=cfg,
        checkpoint_dir=args.ckpt_dir,
        scail_safetensors_path=args.scail_path,
        scail_config_path=args.scail_config_path,
        device_id=device,
        rank=rank,
        t5_fsdp=args.t5_fsdp,
        dit_fsdp=args.dit_fsdp,
        use_usp=use_context_parallel,
        t5_cpu=args.t5_cpu,
        lora_path=args.lora_path,
        lora_alpha=args.lora_alpha,
    )

    for task in tasks:
        prompt, image_path, image_mask_path, pose_path, driving_mask_path, input_idx, task_additional_input = task
        generate_video(scail_pipeline, prompt, image_path, image_mask_path, pose_path, driving_mask_path, args, device, rank, cfg, input_idx, args.replace_flag, task_additional_input)

    logging.info("Finished.")
|
|
if __name__ == "__main__":
    # Script entry point: parse and validate the command line, then generate.
    generate(_parse_args())
|
|