
Training Code

Wan-Fun can be trained with or without DeepSpeed or FSDP; enabling either one can save a lot of GPU memory.

The metadata_control.json differs slightly from the normal metadata JSON in Wan-Fun: each entry needs an additional control_file_path. DWPose is suggested as a tool for generating the control files.

[
    {
      "file_path": "train/00000001.mp4",
      "control_file_path": "control/00000001.mp4",
      "text": "A group of young men in suits and sunglasses are walking down a city street.",
      "type": "video"
    },
    {
      "file_path": "train/00000002.jpg",
      "control_file_path": "control/00000002.jpg",
      "text": "A group of young men in suits and sunglasses are walking down a city street.",
      "type": "image"
    },
    .....
]
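
Checking the metadata before launching training can catch a missing control path early. The following is a minimal sketch, not part of the repository: the helper name validate_control_metadata is hypothetical, and treating all four keys from the example above as required, with "video" and "image" as the only types, is an assumption.

```python
import json

# Keys and type values taken from the example entries above (assumed to be required).
REQUIRED_KEYS = ("file_path", "control_file_path", "text", "type")
ALLOWED_TYPES = ("video", "image")


def validate_control_metadata(entries):
    """Return a list of human-readable problems; an empty list means none were found."""
    problems = []
    for index, entry in enumerate(entries):
        for key in REQUIRED_KEYS:
            if key not in entry:
                problems.append(f"entry {index}: missing '{key}'")
        if entry.get("type") not in ALLOWED_TYPES:
            problems.append(f"entry {index}: unexpected type {entry.get('type')!r}")
    return problems


# Two sample entries; the second one deliberately lacks control_file_path.
sample = json.loads("""
[
  {"file_path": "train/00000001.mp4", "control_file_path": "control/00000001.mp4",
   "text": "A group of young men in suits and sunglasses are walking down a city street.",
   "type": "video"},
  {"file_path": "train/00000002.jpg",
   "text": "A group of young men in suits and sunglasses are walking down a city street.",
   "type": "image"}
]
""")

for problem in validate_control_metadata(sample):
    print(problem)
```

To check a real file, load it with json.load and pass the resulting list to the same function.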

Some parameters in the sh file can be confusing, and they are explained in this document:

  • enable_bucket is used to enable bucket training. When enabled, images and videos are not center-cropped; instead, they are grouped into buckets by resolution and trained in full.
  • Sample size Configuration Guide
    • video_sample_size represents the resolution size of videos; when random_hw_adapt is True, it represents the minimum value between video and image resolutions.
    • image_sample_size represents the resolution size of images; when random_hw_adapt is True, it represents the maximum value between video and image resolutions.
    • token_sample_size represents the resolution corresponding to the maximum token length when training_with_video_token_length is True.
    • Due to potential confusion in configuration, if you don't require arbitrary resolution for finetuning, it is recommended to set video_sample_size, image_sample_size, and token_sample_size to the same fixed value, such as (320, 480, 512, 640, 960).
      • All set to 320 represents 240P.
      • All set to 480 represents 320P.
      • All set to 640 represents 480P.
      • All set to 960 represents 720P.
  • random_frame_crop is used for random cropping on video frames to simulate videos with different frame counts.
  • random_hw_adapt is used to enable automatic height and width scaling for images and videos. When random_hw_adapt is enabled, both training images and training videos have their height and width set to image_sample_size as the maximum and min(video_sample_size, 512) as the minimum.
    • For example, when random_hw_adapt is enabled, with video_sample_n_frames=49, video_sample_size=1024, and image_sample_size=1024, the resolution of image inputs for training is 512x512 to 1024x1024, and the resolution of video inputs for training is 512x512x49 to 1024x1024x49.
    • For example, when random_hw_adapt is enabled, with video_sample_n_frames=49, video_sample_size=256, and image_sample_size=1024, the resolution of image inputs for training is 256x256 to 1024x1024, and the resolution of video inputs for training is 256x256x49.
  • training_with_video_token_length specifies training the model according to token length. For training images and videos, the height and width will be set to image_sample_size as the maximum and video_sample_size as the minimum.
    • For example, when training_with_video_token_length is enabled, with video_sample_n_frames=49, token_sample_size=1024, video_sample_size=256, and image_sample_size=1024, the resolution of image inputs for training is 256x256 to 1024x1024, and the resolution of video inputs for training is 256x256x49 to 1024x1024x49.
    • For example, when training_with_video_token_length is enabled, with video_sample_n_frames=49, token_sample_size=512, video_sample_size=256, and image_sample_size=1024, the resolution of image inputs for training is 256x256 to 1024x1024, and the resolution of video inputs for training is 256x256x49 to 1024x1024x9.
    • The token length for a video with dimensions 512x512 and 49 frames is 13,312. To train with this token length as the maximum, set token_sample_size = 512.
      • At 512x512 resolution, the number of video frames is 49 (~= 512 * 512 * 49 / 512 / 512).
      • At 768x768 resolution, the number of video frames is 21 (~= 512 * 512 * 49 / 768 / 768).
      • At 1024x1024 resolution, the number of video frames is 9 (~= 512 * 512 * 49 / 1024 / 1024).
      • These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes.
  • resume_from_checkpoint specifies whether training should be resumed from a previous checkpoint. Use a path, or "latest" to automatically select the last available checkpoint.
  • train_mode is used to set the training mode.
    • The models named Wan2.1-Fun-*-Control are trained in the control_ref mode.
    • The models named Wan2.1-Fun-*-Control-Camera are trained in the control_ref_camera mode.
  • control_ref_image is used to specify the type of control image. The available options are first_frame and random.
    • first_frame is used in V1.0 because V1.0 supports using a specified start frame as the control image. The Control-Camera models use the first frame as the control image.
    • random is used in V1.1 because V1.1 supports both using a specified start frame and a reference image as the control image.
  • add_full_ref_image_in_self_attention determines whether to include the reference image in self-attention. This option is used in V1.1, as it supports using a reference image as the control image. It should not be used in V1.0 and Control-Camera models.
  • add_inpaint_info determines whether to incorporate inpaint information into the model training. When enabled, this allows the model to support specifying starting and ending images in the controls during generation.
  • boundary_type: The Wan2.2 series includes two distinct models that handle different noise levels, specified via the boundary_type parameter.
    • low: Corresponds to the low noise model (low_noise_model).
    • high: Corresponds to the high noise model (high_noise_model).
    • full: Corresponds to the ti2v 5B model (single mode).
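
The token-length arithmetic in the training_with_video_token_length item can be reproduced with a short sketch. It is an illustration only and is not taken from the training script. The function names are hypothetical, and it assumes an 8x spatial and 4x temporal VAE compression, a 2x2 patch size, and frame counts of the form 4k + 1. Under these assumptions it reproduces the figures quoted above (13,312 tokens, and 49, 21, and 9 frames).

```python
def token_length(height, width, frames, spatial_stride=8, temporal_stride=4, patch=2):
    """Token count of one video under the assumed VAE strides and patch size."""
    latent_frames = (frames - 1) // temporal_stride + 1
    tokens_per_side_h = height // (spatial_stride * patch)
    tokens_per_side_w = width // (spatial_stride * patch)
    return latent_frames * tokens_per_side_h * tokens_per_side_w


def max_frames_for_budget(height, width, token_sample_size=512, n_frames=49, temporal_stride=4):
    """Largest frame count of the form temporal_stride * k + 1 whose pixel volume
    stays within token_sample_size * token_sample_size * n_frames."""
    budget = token_sample_size * token_sample_size * n_frames
    raw_frames = max(1, budget // (height * width))  # never go below a single frame
    return (raw_frames - 1) // temporal_stride * temporal_stride + 1


print(token_length(512, 512, 49))
for size in (512, 768, 1024):
    print(size, max_frames_for_budget(size, size))
```

Rounding down to 4k + 1 is what turns the raw ratio at 1024x1024 (about 12 frames) into the 9 frames listed above.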

To train the 5B Wan2.2 model, set the config to config/wan2.2/wan_civitai_5b.yaml and set boundary_type to full.

When training on multiple machines, set the parameters as follows:

export MASTER_ADDR="your master address"
export MASTER_PORT=10086
export WORLD_SIZE=1 # The number of machines
export NUM_PROCESS=8 # The number of processes, such as WORLD_SIZE * 8
export RANK=0 # The rank of this machine

accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/wan2.2_fun/xxx.py
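
The relationship between these variables can be sketched as a small helper that assembles the same command on each machine. This is a hypothetical convenience, not part of the repository. The function name build_launch_command, the default of 8 GPUs per machine, and the address 10.0.0.1 are assumptions for illustration; the flags are the ones used in the command above.

```python
import os
import shlex


def build_launch_command(script, env=os.environ, gpus_per_machine=8):
    """Assemble the accelerate launch arguments from the variables described above."""
    world_size = int(env.get("WORLD_SIZE", "1"))  # number of machines
    # Total process count; defaults to machines * GPUs per machine (assumed 8).
    num_process = int(env.get("NUM_PROCESS", str(world_size * gpus_per_machine)))
    return [
        "accelerate", "launch",
        "--mixed_precision=bf16",
        f"--main_process_ip={env['MASTER_ADDR']}",
        f"--main_process_port={env.get('MASTER_PORT', '10086')}",
        f"--num_machines={world_size}",
        f"--num_processes={num_process}",
        f"--machine_rank={env.get('RANK', '0')}",
        script,
    ]


# Second machine (rank 1) of a two-machine job; the address is a placeholder.
example_env = {"MASTER_ADDR": "10.0.0.1", "WORLD_SIZE": "2", "RANK": "1"}
command = build_launch_command("scripts/wan2.2_fun/train_control.py", example_env)
print(shlex.join(command))
```

Only RANK changes between machines; MASTER_ADDR, MASTER_PORT, WORLD_SIZE, and NUM_PROCESS stay the same across the job.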

Wan-Fun-Control without DeepSpeed:

export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO

accelerate launch --mixed_precision="bf16" scripts/wan2.2_fun/train_control.py \
  --config_path="config/wan2.2/wan_civitai_i2v.yaml" \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --boundary_type="low" \
  --low_vram \
  --train_mode="control_ref" \
  --control_ref_image="random" \
  --add_inpaint_info \
  --add_full_ref_image_in_self_attention \
  --trainable_modules "."

Wan-Fun-Control with DeepSpeed:

export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO

accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_fun/train_control.py \
  --config_path="config/wan2.2/wan_civitai_i2v.yaml" \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --boundary_type="low" \
  --low_vram \
  --use_deepspeed \
  --train_mode="control_ref" \
  --control_ref_image="random" \
  --add_inpaint_info \
  --add_full_ref_image_in_self_attention \
  --trainable_modules "."

DeepSpeed Zero-3 is currently not the recommended option. In this repository, FSDP produces fewer errors and is more stable.

Wan-Fun-Control with DeepSpeed Zero-3:

Wan with DeepSpeed Zero-3 is suitable for 14B Wan at high resolutions. After training, use the following command to get the final model:

python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization

The training shell command is as follows:

export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO

accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_fun/train_control.py \
  --config_path="config/wan2.2/wan_civitai_i2v.yaml" \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --boundary_type="low" \
  --low_vram \
  --use_deepspeed \
  --train_mode="control_ref" \
  --control_ref_image="random" \
  --add_inpaint_info \
  --add_full_ref_image_in_self_attention \
  --trainable_modules "."

Wan-Fun-Control with FSDP:

Wan with FSDP is suitable for 14B Wan at high resolutions. The training shell command is as follows:

export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO

accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.2_fun/train_control.py \
  --config_path="config/wan2.2/wan_civitai_i2v.yaml" \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --boundary_type="low" \
  --low_vram \
  --train_mode="control_ref" \
  --control_ref_image="random" \
  --add_inpaint_info \
  --add_full_ref_image_in_self_attention \
  --trainable_modules "."

To train the 5B Wan2.2 model, set the config to config/wan2.2/wan_civitai_5b.yaml and set boundary_type to full. The training shell command is as follows:

export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-5B-Control"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used in multi nodes without RDMA. 
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO

accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.2_fun/train_control.py \
  --config_path="config/wan2.2/wan_civitai_5b.yaml" \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$DATASET_NAME \
  --train_data_meta=$DATASET_META_NAME \
  --image_sample_size=640 \
  --video_sample_size=640 \
  --token_sample_size=640 \
  --video_sample_stride=2 \
  --video_sample_n_frames=81 \
  --train_batch_size=1 \
  --video_repeat=1 \
  --gradient_accumulation_steps=1 \
  --dataloader_num_workers=8 \
  --num_train_epochs=100 \
  --checkpointing_steps=50 \
  --learning_rate=2e-05 \
  --lr_scheduler="constant_with_warmup" \
  --lr_warmup_steps=100 \
  --seed=42 \
  --output_dir="output_dir" \
  --gradient_checkpointing \
  --mixed_precision="bf16" \
  --adam_weight_decay=3e-2 \
  --adam_epsilon=1e-10 \
  --vae_mini_batch=1 \
  --max_grad_norm=0.05 \
  --random_hw_adapt \
  --training_with_video_token_length \
  --enable_bucket \
  --uniform_sampling \
  --boundary_type="full" \
  --low_vram \
  --train_mode="control_ref" \
  --control_ref_image="random" \
  --add_inpaint_info \
  --add_full_ref_image_in_self_attention \
  --trainable_modules "."