| | |
| | |
| | import argparse |
| |
|
| |
|
| | def str2bool(v): |
| | if isinstance(v, bool): |
| | return v |
| | if v.lower() in ('true'): |
| | return True |
| | elif v.lower() in ('false'): |
| | return False |
| | else: |
| | raise argparse.ArgumentTypeError('Boolean value expected.') |
| |
|
| | def add_model_specific_args(args, parser): |
| | from infinity.models.videovae.models import CVIVIT_VQGAN, CNN_VQGAN, FLUX_VAE, MS_VAE, CogVAE, SlowFastVAE, HunYuanVAE, CogVAEL |
| | |
| | if args.tokenizer == "cvivit": |
| | parser = CVIVIT_VQGAN.add_model_specific_args(parser) |
| | vae_model = CVIVIT_VQGAN |
| | elif args.tokenizer == "cnn": |
| | parser = CNN_VQGAN.add_model_specific_args(parser) |
| | vae_model = CNN_VQGAN |
| | elif args.tokenizer in ["flux"]: |
| | parser = CNN_VQGAN.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = FLUX_VAE |
| | elif args.tokenizer == "ms": |
| | parser = CNN_VQGAN.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = MS_VAE |
| | elif args.tokenizer in ["sd", "sd-vq", "mar", "cogvideox_origin", "vidtok", "open-sora-plan", "step-fun", "hunyuan_origin"]: |
| | vae_model = None |
| | pass |
| | elif args.tokenizer in ["cogvideox"]: |
| | parser = CogVAE.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = CogVAE |
| | elif args.tokenizer in ["cogvideoxl"]: |
| | parser = CogVAEL.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = CogVAEL |
| | elif args.tokenizer in ["slow-fast"]: |
| | parser = SlowFastVAE.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = SlowFastVAE |
| | elif args.tokenizer in ["hunyuan"]: |
| | parser = HunYuanVAE.add_model_specific_args(parser) |
| | parser = FLUX_VAE.add_model_specific_args(parser) |
| | vae_model = HunYuanVAE |
| | else: |
| | raise NotImplementedError |
| | return args, parser, vae_model |
| |
|
| | class MainArgs: |
| | @staticmethod |
| | def add_main_args(parser): |
| | |
| | parser.add_argument('--max_steps', type=int, default=1e6) |
| | parser.add_argument('--log_every', type=int, default=1) |
| | parser.add_argument('--ckpt_every', type=int, default=1000) |
| | parser.add_argument('--default_root_dir', type=str, required=True) |
| | parser.add_argument('--compile', type=str, default="no", choices=["no", "yes"]) |
| | parser.add_argument('--ema', type=str, default="no", choices=["no", "yes"]) |
| | parser.add_argument('--mfu_logging', type=str, default="no", choices=["no", "yes"]) |
| | parser.add_argument('--dataloader_init_epoch', type=int, default=-1) |
| | parser.add_argument('--context_parallel_size', type=int, default=0) |
| |
|
| | |
| | parser.add_argument('--lr', type=float, default=1e-4) |
| | parser.add_argument('--beta1', type=float, default=0.9) |
| | parser.add_argument('--beta2', type=float, default=0.95) |
| | parser.add_argument('--optim_type', type=str, default="Adam", choices=["Adam", "AdamW"]) |
| | parser.add_argument('--disc_optim_type', type=str, default=None, choices=[None, "rmsprop"]) |
| | parser.add_argument('--max_grad_norm', type=float, default=1.0) |
| | parser.add_argument('--max_grad_norm_disc', type=float, default=1.0) |
| | parser.add_argument('--disable_sch', action="store_true") |
| | parser.add_argument('--scheduler', type=str, default="no", choices=["no", "linear"]) |
| | parser.add_argument('--warmup_steps', type=int, default=0) |
| | parser.add_argument('--lr_min', type=float, default=0.) |
| | parser.add_argument('--warmup_lr_init', type=float, default=0.) |
| |
|
| | |
| | parser.add_argument('--patch_size', type=int, default=8) |
| | parser.add_argument('--temporal_patch_size', type=int, default=4) |
| | parser.add_argument('--embedding_dim', type=int, default=256) |
| | parser.add_argument('--codebook_dim', type=int, default=16) |
| | parser.add_argument('--use_vae', action="store_true") |
| | parser.add_argument('--fix_model', type=str, default='no', choices=['no', 'encoder', 'encoder_decoder']) |
| |
|
| | |
| | parser.add_argument('--use_stochastic_depth', action="store_true") |
| | parser.add_argument("--drop_rate", type=float, default=0.0) |
| | parser.add_argument('--schedule_mode', type=str, default="original", choices=["original", "dynamic", "dense", "same1", "same2", "same3", "half", "dense_f8", "dense_f8_double"]) |
| | parser.add_argument('--lr_drop', nargs='*', type=int, default=None, help="A list of numeric values. Example: --values 270 300") |
| | parser.add_argument('--lr_drop_rate', type=float, default=0.1) |
| | parser.add_argument('--keep_first_quant', action="store_true") |
| | parser.add_argument('--keep_last_quant', action="store_true") |
| | parser.add_argument('--remove_residual_detach', action="store_true") |
| | parser.add_argument('--use_out_phi', action="store_true") |
| | parser.add_argument('--use_out_phi_res', action="store_true") |
| | parser.add_argument('--use_lecam_reg', action="store_true") |
| | parser.add_argument('--lecam_weight', type=float, default=0.05) |
| | parser.add_argument('--perceptual_model', type=str, default="vgg16", choices=["vgg16", "resnet50", "resnet50_v2"]) |
| | parser.add_argument('--base_ch_disc', type=int, default=64) |
| | parser.add_argument('--random_flip', action="store_true") |
| | parser.add_argument('--flip_prob', type=float, default=0.5) |
| | parser.add_argument('--flip_mode', type=str, default="stochastic", choices=["stochastic", "deterministic", "stochastic_dynamic"]) |
| | parser.add_argument('--max_flip_lvl', type=int, default=1) |
| | parser.add_argument('--not_load_optimizer', action="store_true") |
| | parser.add_argument('--use_lecam_reg_zero', action="store_true") |
| | parser.add_argument('--freeze_encoder', action="store_true") |
| | parser.add_argument('--rm_downsample', action="store_true") |
| | parser.add_argument('--random_flip_1lvl', action="store_true") |
| | parser.add_argument('--flip_lvl_idx', type=int, default=0) |
| | parser.add_argument('--drop_when_test', action="store_true") |
| | parser.add_argument('--drop_lvl_idx', type=int, default=None) |
| | parser.add_argument('--drop_lvl_num', type=int, default=0) |
| | parser.add_argument('--compute_all_commitment', action="store_true") |
| | parser.add_argument('--disable_codebook_usage', action="store_true") |
| | parser.add_argument('--freeze_enc_main', action="store_true") |
| | parser.add_argument('--freeze_dec_main', action="store_true") |
| | parser.add_argument('--random_short_schedule', action="store_true") |
| | parser.add_argument('--short_schedule_prob', type=float, default=0.5) |
| | parser.add_argument('--use_bernoulli', action="store_true") |
| | parser.add_argument('--use_rot_trick', action="store_true") |
| | parser.add_argument('--disable_flip_prob', type=float, default=0.0) |
| | parser.add_argument('--dino_disc', action="store_true") |
| | parser.add_argument('--quantizer_type', type=str, default='MultiScaleBSQ') |
| | parser.add_argument('--lfq_weight', type=float, default=0.) |
| | parser.add_argument('--entropy_loss_weight', type=float, default=0.1) |
| | parser.add_argument('--visu_every', type=int, default=1000) |
| | parser.add_argument('--commitment_loss_weight', type=float, default=0.25) |
| | parser.add_argument('--bsq_version', type=str, default="v1", choices=["v1", "v2"]) |
| | parser.add_argument('--diversity_gamma', type=float, default=1) |
| | parser.add_argument('--bs1_for1024', action="store_true") |
| | parser.add_argument('--casual_multi_scale', action="store_true") |
| | parser.add_argument('--double_compress_t', action="store_true") |
| | parser.add_argument('--temporal_slicing', action="store_true") |
| | parser.add_argument('--latent_adjust_type', type=str, default=None) |
| | parser.add_argument('--compute_latent_loss', action="store_true") |
| | parser.add_argument('--latent_loss_weight', type=float, default=0.0) |
| |
|
| | |
| | parser.add_argument('--disc_version', type=str, default="v1") |
| | parser.add_argument('--magvit_disc', action="store_true") |
| | parser.add_argument('--disc_type', type=str, default="patchgan", choices=["patchgan", "stylegan"]) |
| | parser.add_argument('--sigmoid_in_disc', action="store_true") |
| | parser.add_argument('--activation_in_disc', type=str, default="leaky_relu") |
| | parser.add_argument('--apply_blur', action="store_true") |
| | parser.add_argument('--apply_noise', action="store_true") |
| | parser.add_argument('--dis_warmup_steps', type=int, default=0) |
| | parser.add_argument('--dis_lr_multiplier', type=float, default=1.) |
| | parser.add_argument('--dis_minlr_multiplier', action="store_true") |
| | parser.add_argument('--disc_channels', type=int, default=64) |
| | parser.add_argument('--disc_layers', type=int, default=3) |
| | parser.add_argument('--discriminator_iter_start', type=int, default=0) |
| | parser.add_argument('--disc_pretrain_iter', type=int, default=0) |
| | parser.add_argument('--disc_optim_steps', type=int, default=1) |
| | parser.add_argument('--disc_warmup', type=int, default=0) |
| | parser.add_argument('--disc_pool', type=str, default="no", choices=["no", "yes"]) |
| | parser.add_argument('--disc_pool_size', type=int, default=100) |
| | parser.add_argument('--disc_temporal_compress', type=str, default="yes", choices=["no", "yes"]) |
| | parser.add_argument('--disc_use_blur', type=str, default="yes", choices=["no", "yes"]) |
| | parser.add_argument('--disc_stylegan_downsample_base', type=int, default=2) |
| |
|
| | parser = MainArgs.add_loss_args(parser) |
| | parser = MainArgs.add_accelerate_args(parser) |
| |
|
| | |
| | parser.add_argument('--tokenizer', type=str, required=True) |
| | parser.add_argument('--pretrained', type=str, default=None) |
| | parser.add_argument('--pretrained_mode', type=str, default="full") |
| | parser.add_argument('--pretrained_ema', type=str, default="no") |
| | parser.add_argument('--inflation_pe', action="store_true") |
| | parser.add_argument('--init_vgen', type=str, default='no', choices=['no', 'keep', 'average']) |
| | parser.add_argument('--no_init_idis', action="store_true") |
| | parser.add_argument('--init_idis', type=str, default='keep', choices=['no', 'keep']) |
| | parser.add_argument('--init_vdis', type=str, default="no") |
| |
|
| | |
| | parser.add_argument('--enable_nan_detector', action='store_true') |
| | parser.add_argument('--turn_on_profiler', action='store_true') |
| | parser.add_argument('--profiler_scheduler_wait_steps', type=int, default=10) |
| | parser.add_argument('--debug', action='store_true') |
| | parser.add_argument('--video_logger', action='store_true') |
| | parser.add_argument('--bytenas', type=str, default="sg") |
| | parser.add_argument('--username', type=str, default="zhufengda") |
| | parser.add_argument('--seed', type=int, default=1234) |
| | parser.add_argument('--vq_to_vae', action='store_true') |
| | parser.add_argument('--load_not_strict', action='store_true') |
| | parser.add_argument('--zero', type=int, default=0, choices=[0, 1, 2, 3]) |
| | parser.add_argument('--bucket_cap_mb', type=int, default=40) |
| | parser.add_argument('--manual_gc_interval', type=int, default=10000) |
| |
|
| | return parser |
| | |
| | @staticmethod |
| | def add_loss_args(parser): |
| | parser.add_argument("--recon_loss_type", type=str, default='l1', choices=['l1', 'l2']) |
| | parser.add_argument('--video_perceptual_weight', type=float, default=0.) |
| | parser.add_argument('--image_gan_weight', type=float, default=1.0) |
| | parser.add_argument('--video_gan_weight', type=float, default=1.0) |
| | parser.add_argument('--image_disc_weight', type=float, default=0.) |
| | parser.add_argument('--video_disc_weight', type=float, default=0.) |
| | parser.add_argument('--l1_weight', type=float, default=4.0) |
| | parser.add_argument('--gan_feat_weight', type=float, default=0.0) |
| | parser.add_argument('--lpips_model', type=str, default='vgg', choices=['vgg', 'resnet50']) |
| | parser.add_argument('--perceptual_weight', type=float, default=0.0) |
| | parser.add_argument('--kl_weight', type=float, default=0.) |
| | parser.add_argument('--norm_type', type=str, default='group', choices=['batch', 'group', "no"]) |
| | parser.add_argument('--disc_loss_type', type=str, default='hinge', choices=['hinge', 'vanilla']) |
| | parser.add_argument('--gan_image4video', type=str, default='yes', choices=['no', 'yes']) |
| | return parser |
| |
|
| | @staticmethod |
| | def add_accelerate_args(parser): |
| | parser.add_argument('--use_checkpoint', action="store_true") |
| | parser.add_argument('--precision', type=str, default="fp32", choices=['fp32', 'bf16']) |
| | parser.add_argument('--encoder_dtype', type=str, default="fp32", choices=['fp32', 'bf16']) |
| | parser.add_argument('--decoder_dtype', type=str, default="fp32", choices=['fp32', 'bf16']) |
| | parser.add_argument('--upcast_attention', type=str, default="", choices=["qk", "qkv"]) |
| | parser.add_argument('--upcast_tf32', action="store_true") |
| | return parser |
| |
|
| | def format_args(args): |
| | |
| | script_content = "#!/bin/bash\n\n" |
| | script_content += "torchrun \\\n" |
| | script_content += " --nproc_per_node=$ARNOLD_WORKER_GPU \\\n" |
| | script_content += " --nnodes=$ARNOLD_WORKER_NUM --master_addr=$ARNOLD_WORKER_0_HOST \\\n" |
| | script_content += " --node_rank=$ARNOLD_ID --master_port=$port \\\n" |
| | script_content += " train.py \\\n" |
| |
|
| | |
| | for k, v in args.__dict__.items(): |
| | script_content += f" --{k} {v} \\\n" |
| |
|
| | |
| | script_content = script_content.rstrip(" \\\n") + "\n" |
| | return script_content |
| |
|
| | def init_resolution(resolution, num_datasets): |
| | if len(resolution) == 1: |
| | resolution = [(resolution[0], resolution[0])] * num_datasets |
| | elif len(resolution) == num_datasets: |
| | resolution = [(resolution[i], resolution[i]) for i in range(len(resolution))] |
| | elif len(resolution) == num_datasets * 2: |
| | resolution = [(resolution[i], resolution[i+1]) for i in range(0, len(resolution), 2)] |
| | else: |
| | raise NotImplementedError |
| | return resolution |
| |
|