|
|
import argparse |
|
|
import random |
|
|
import sys |
|
|
|
|
|
from .wan.configs import WAN_CONFIGS |
|
|
from .wan.utils.utils import str2bool |
|
|
|
|
|
|
|
|
def args_init(): |
|
|
parser = argparse.ArgumentParser(description="diffusers lite script") |
|
|
parser.add_argument("--seed", type=int, default=42) |
|
|
parser.add_argument("--precision", type=str, default="bf16") |
|
|
parser.add_argument("--task", type=str, default="i2v", choices=["i2v", "t2v", "flf2v"]) |
|
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--base-dir", type=str, default="") |
|
|
parser.add_argument("--transformer_path", nargs="?", const="", default="") |
|
|
|
|
|
|
|
|
parser.add_argument("--lora-path", nargs="?", const="", default="") |
|
|
parser.add_argument("--lora-alpha", type=float, default=1.0) |
|
|
|
|
|
|
|
|
parser.add_argument("--dataset-path", type=str, default="") |
|
|
parser.add_argument("--resolution", nargs="+", type=int, default=[512]) |
|
|
parser.add_argument("--num-frames", type=int, default=81) |
|
|
parser.add_argument("--batch-size", type=int, default=1) |
|
|
|
|
|
|
|
|
parser.add_argument("--cfg", type=float, default=5.0) |
|
|
parser.add_argument("--shift", type=float, default=3.0) |
|
|
parser.add_argument("--step", type=int, default=50) |
|
|
|
|
|
|
|
|
parser.add_argument("--fps", type=int, default=16) |
|
|
parser.add_argument("--save-dir", type=str, default="") |
|
|
|
|
|
|
|
|
parser.add_argument("--json-paths", nargs="+", type=str, default=[""]) |
|
|
parser.add_argument("--video-dir", type=str, default="") |
|
|
parser.add_argument("--image-dir", type=str, default="") |
|
|
parser.add_argument("--extract-fps", type=int, default=24) |
|
|
|
|
|
parser.add_argument("--model-type", type=str, default="wanx", choices=["wanx", "ltx"]) |
|
|
parser.add_argument("--vae-path", type=str, default="") |
|
|
parser.add_argument("--image-processor-path", type=str, default="") |
|
|
parser.add_argument("--image-encoder-path", type=str, default="") |
|
|
parser.add_argument("--tokenizer-path", type=str, default="") |
|
|
parser.add_argument("--text-encoder-path", type=str, default="") |
|
|
parser.add_argument("--max-sequence-length", type=int, default=512) |
|
|
parser.add_argument("--vlm-type", type=str, default="qwenvl2", choices=["qwenvl2", "qwenvl2.5", "nocap"]) |
|
|
parser.add_argument("--vlm-path", nargs="?", const="", default="") |
|
|
parser.add_argument("--max-new-tokens", type=int, default=256) |
|
|
|
|
|
parser.add_argument("--caption-template", type=str, default="") |
|
|
parser.add_argument("--instruct-sentence", type=str, default="") |
|
|
parser.add_argument("--negative-prompt", nargs="?", const="", default="") |
|
|
parser.add_argument("--null-caption-length", type=int, default=0) |
|
|
|
|
|
parser.add_argument("--save-interval", type=int, default=100) |
|
|
parser.add_argument("--save-path", type=str, default="") |
|
|
|
|
|
args = parser.parse_args() |
|
|
return args |
|
|
|
|
|
|
|
|
def args_wan_init(): |
|
|
parser = argparse.ArgumentParser( |
|
|
description="Generate a image or video from a text prompt or image using Wan" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--task", |
|
|
type=str, |
|
|
default="i2v-14B", |
|
|
choices=list(WAN_CONFIGS.keys()), |
|
|
help="The task to run.") |
|
|
parser.add_argument( |
|
|
"--size", |
|
|
type=str, |
|
|
default="1280*720", |
|
|
|
|
|
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--frame_num", |
|
|
type=int, |
|
|
default=81, |
|
|
help="How many frames to sample from a image or video. The number should be 4n+1" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--ckpt_dir", |
|
|
type=str, |
|
|
default='', |
|
|
help="The path to the checkpoint directory.") |
|
|
parser.add_argument( |
|
|
"--offload_model", |
|
|
type=str2bool, |
|
|
default=None, |
|
|
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--ulysses_size", |
|
|
type=int, |
|
|
default=1, |
|
|
help="The size of the ulysses parallelism in DiT.") |
|
|
parser.add_argument( |
|
|
"--ring_size", |
|
|
type=int, |
|
|
default=1, |
|
|
help="The size of the ring attention parallelism in DiT.") |
|
|
parser.add_argument( |
|
|
"--t5_fsdp", |
|
|
action="store_true", |
|
|
default=False, |
|
|
help="Whether to use FSDP for T5.") |
|
|
parser.add_argument( |
|
|
"--t5_cpu", |
|
|
action="store_true", |
|
|
default=False, |
|
|
help="Whether to place T5 model on CPU.") |
|
|
parser.add_argument( |
|
|
"--dit_fsdp", |
|
|
action="store_true", |
|
|
default=False, |
|
|
help="Whether to use FSDP for DiT.") |
|
|
parser.add_argument( |
|
|
"--save_folder", |
|
|
type=str, |
|
|
default=None, |
|
|
help="The folder to save the generated image or video to.") |
|
|
parser.add_argument( |
|
|
"--save_file", |
|
|
type=str, |
|
|
default=None, |
|
|
help="The file to save the generated image or video to.") |
|
|
parser.add_argument( |
|
|
"--prompt", |
|
|
type=str, |
|
|
default=None, |
|
|
help="The prompt to generate the image or video from.") |
|
|
parser.add_argument( |
|
|
"--base_seed", |
|
|
type=int, |
|
|
default=-1, |
|
|
help="The seed to use for generating the image or video.") |
|
|
parser.add_argument( |
|
|
"--image", |
|
|
type=str, |
|
|
default=None, |
|
|
help="The image to generate the video from.") |
|
|
parser.add_argument( |
|
|
"--sample_solver", |
|
|
type=str, |
|
|
default='unipc', |
|
|
choices=['unipc', 'dpm++'], |
|
|
help="The solver used to sample.") |
|
|
parser.add_argument( |
|
|
"--sample_steps", type=int, default=None, help="The sampling steps.") |
|
|
parser.add_argument( |
|
|
"--sample_shift", |
|
|
type=float, |
|
|
default=None, |
|
|
help="Sampling shift factor for flow matching schedulers.") |
|
|
parser.add_argument( |
|
|
"--sample_guide_scale", |
|
|
type=float, |
|
|
default=6.0, |
|
|
help="Classifier free guidance scale.") |
|
|
parser.add_argument( |
|
|
"--teacache_thresh", |
|
|
type=float, |
|
|
default=None, |
|
|
help="The threshold for caching diffusion model steps.") |
|
|
|
|
|
|
|
|
parser.add_argument("--ddp_mode", type=bool, default=False) |
|
|
|
|
|
parser.add_argument("--dataset_path", type=str, default=None) |
|
|
parser.add_argument("--resolution", nargs="+", type=int, default=[512]) |
|
|
parser.add_argument("--batch_size", type=int, default=1) |
|
|
parser.add_argument("--negative_prompt", nargs="?", const="", default="") |
|
|
|
|
|
parser.add_argument("--transformer_path", nargs="?", const="", default="") |
|
|
|
|
|
parser.add_argument("--lora_path", nargs="?", const="", default="") |
|
|
parser.add_argument("--lora_alpha", type=float, default=1.0) |
|
|
parser.add_argument("--distill_lora_path", nargs="?", const="", default="") |
|
|
parser.add_argument("--distill_lora_alpha", type=float, default=1.0) |
|
|
|
|
|
args = parser.parse_args() |
|
|
|
|
|
assert args.ckpt_dir is not None, "Please specify the checkpoint directory." |
|
|
assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}" |
|
|
|
|
|
|
|
|
if args.sample_steps is None: |
|
|
args.sample_steps = 40 if "i2v" in args.task else 50 |
|
|
|
|
|
if args.sample_shift is None: |
|
|
args.sample_shift = 5.0 |
|
|
if "i2v" in args.task: |
|
|
args.sample_shift = 3.0 |
|
|
|
|
|
|
|
|
if args.frame_num is None: |
|
|
args.frame_num = 1 if "t2i" in args.task else 81 |
|
|
|
|
|
|
|
|
if "t2i" in args.task: |
|
|
assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}" |
|
|
|
|
|
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint( |
|
|
0, sys.maxsize) |
|
|
|
|
|
return args |
|
|
|