Camellia997's picture
Upload folder using huggingface_hub
e14f899 verified
import argparse
import random
import sys
from .wan.configs import WAN_CONFIGS
from .wan.utils.utils import str2bool
def args_init():
parser = argparse.ArgumentParser(description="diffusers lite script")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--precision", type=str, default="bf16")
parser.add_argument("--task", type=str, default="i2v", choices=["i2v", "t2v", "flf2v"])
### Inference ###
# model
parser.add_argument("--base-dir", type=str, default="")
parser.add_argument("--transformer_path", nargs="?", const="", default="")
# lora
parser.add_argument("--lora-path", nargs="?", const="", default="")
parser.add_argument("--lora-alpha", type=float, default=1.0)
# dataset
parser.add_argument("--dataset-path", type=str, default="")
parser.add_argument("--resolution", nargs="+", type=int, default=[512])
parser.add_argument("--num-frames", type=int, default=81)
parser.add_argument("--batch-size", type=int, default=1)
# inference
parser.add_argument("--cfg", type=float, default=5.0)
parser.add_argument("--shift", type=float, default=3.0)
parser.add_argument("--step", type=int, default=50)
# save
parser.add_argument("--fps", type=int, default=16)
parser.add_argument("--save-dir", type=str, default="")
### Preprocess dataset ###
parser.add_argument("--json-paths", nargs="+", type=str, default=[""])
parser.add_argument("--video-dir", type=str, default="")
parser.add_argument("--image-dir", type=str, default="")
parser.add_argument("--extract-fps", type=int, default=24)
# model
parser.add_argument("--model-type", type=str, default="wanx", choices=["wanx", "ltx"])
parser.add_argument("--vae-path", type=str, default="")
parser.add_argument("--image-processor-path", type=str, default="")
parser.add_argument("--image-encoder-path", type=str, default="")
parser.add_argument("--tokenizer-path", type=str, default="")
parser.add_argument("--text-encoder-path", type=str, default="")
parser.add_argument("--max-sequence-length", type=int, default=512)
parser.add_argument("--vlm-type", type=str, default="qwenvl2", choices=["qwenvl2", "qwenvl2.5", "nocap"])
parser.add_argument("--vlm-path", nargs="?", const="", default="")
parser.add_argument("--max-new-tokens", type=int, default=256)
# prompt
parser.add_argument("--caption-template", type=str, default="")
parser.add_argument("--instruct-sentence", type=str, default="")
parser.add_argument("--negative-prompt", nargs="?", const="", default="")
parser.add_argument("--null-caption-length", type=int, default=0)
# save
parser.add_argument("--save-interval", type=int, default=100)
parser.add_argument("--save-path", type=str, default="")
args = parser.parse_args()
return args
def args_wan_init():
parser = argparse.ArgumentParser(
description="Generate a image or video from a text prompt or image using Wan"
)
parser.add_argument(
"--task",
type=str,
default="i2v-14B",
choices=list(WAN_CONFIGS.keys()),
help="The task to run.")
parser.add_argument(
"--size",
type=str,
default="1280*720",
# choices=list(SIZE_CONFIGS.keys()),
help="The area (width*height) of the generated video. For the I2V task, the aspect ratio of the output video will follow that of the input image."
)
parser.add_argument(
"--frame_num",
type=int,
default=81,
help="How many frames to sample from a image or video. The number should be 4n+1"
)
parser.add_argument(
"--ckpt_dir",
type=str,
default='',
help="The path to the checkpoint directory.")
parser.add_argument(
"--offload_model",
type=str2bool,
default=None,
help="Whether to offload the model to CPU after each model forward, reducing GPU memory usage."
)
parser.add_argument(
"--ulysses_size",
type=int,
default=1,
help="The size of the ulysses parallelism in DiT.")
parser.add_argument(
"--ring_size",
type=int,
default=1,
help="The size of the ring attention parallelism in DiT.")
parser.add_argument(
"--t5_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for T5.")
parser.add_argument(
"--t5_cpu",
action="store_true",
default=False,
help="Whether to place T5 model on CPU.")
parser.add_argument(
"--dit_fsdp",
action="store_true",
default=False,
help="Whether to use FSDP for DiT.")
parser.add_argument(
"--save_folder",
type=str,
default=None,
help="The folder to save the generated image or video to.")
parser.add_argument(
"--save_file",
type=str,
default=None,
help="The file to save the generated image or video to.")
parser.add_argument(
"--prompt",
type=str,
default=None,
help="The prompt to generate the image or video from.")
parser.add_argument(
"--base_seed",
type=int,
default=-1,
help="The seed to use for generating the image or video.")
parser.add_argument(
"--image",
type=str,
default=None,
help="The image to generate the video from.")
parser.add_argument(
"--sample_solver",
type=str,
default='unipc',
choices=['unipc', 'dpm++'],
help="The solver used to sample.")
parser.add_argument(
"--sample_steps", type=int, default=None, help="The sampling steps.")
parser.add_argument(
"--sample_shift",
type=float,
default=None,
help="Sampling shift factor for flow matching schedulers.")
parser.add_argument(
"--sample_guide_scale",
type=float,
default=6.0,
help="Classifier free guidance scale.")
parser.add_argument(
"--teacache_thresh",
type=float,
default=None,
help="The threshold for caching diffusion model steps.")
# NOTE: add by diffusers-lite to fill in blank args
parser.add_argument("--ddp_mode", type=bool, default=False)
# dataset
parser.add_argument("--dataset_path", type=str, default=None)
parser.add_argument("--resolution", nargs="+", type=int, default=[512])
parser.add_argument("--batch_size", type=int, default=1)
parser.add_argument("--negative_prompt", nargs="?", const="", default="")
# transformer
parser.add_argument("--transformer_path", nargs="?", const="", default="")
# lora
parser.add_argument("--lora_path", nargs="?", const="", default="")
parser.add_argument("--lora_alpha", type=float, default=1.0)
parser.add_argument("--distill_lora_path", nargs="?", const="", default="")
parser.add_argument("--distill_lora_alpha", type=float, default=1.0)
args = parser.parse_args()
assert args.ckpt_dir is not None, "Please specify the checkpoint directory."
assert args.task in WAN_CONFIGS, f"Unsupport task: {args.task}"
# The default sampling steps are 40 for image-to-video tasks and 50 for text-to-video tasks.
if args.sample_steps is None:
args.sample_steps = 40 if "i2v" in args.task else 50
if args.sample_shift is None:
args.sample_shift = 5.0
if "i2v" in args.task:# and args.size in ["832*480", "480*832"]
args.sample_shift = 3.0
# The default number of frames are 1 for text-to-image tasks and 81 for other tasks.
if args.frame_num is None:
args.frame_num = 1 if "t2i" in args.task else 81
# T2I frame_num check
if "t2i" in args.task:
assert args.frame_num == 1, f"Unsupport frame_num {args.frame_num} for task {args.task}"
args.base_seed = args.base_seed if args.base_seed >= 0 else random.randint(
0, sys.maxsize)
return args