## Training Code
Wan can be trained with or without DeepSpeed or FSDP; both options can save a large amount of GPU memory.
Some parameters in the shell scripts can be confusing; they are explained below:
- `enable_bucket` is used to enable bucket training. When enabled, the model does not center-crop the images and videos; instead, it groups them into buckets by resolution and trains on the full images and videos.
- Sample size configuration guide
  - `video_sample_size` specifies the training resolution of videos; when `random_hw_adapt` is True, it acts as the minimum of the video/image resolution range.
  - `image_sample_size` specifies the training resolution of images; when `random_hw_adapt` is True, it acts as the maximum of the video/image resolution range.
  - `token_sample_size` specifies the resolution corresponding to the maximum token length when `training_with_video_token_length` is True.
  - Since this configuration is easy to get wrong, **if you don't need arbitrary-resolution finetuning**, it is recommended to set `video_sample_size`, `image_sample_size`, and `token_sample_size` to the same fixed value, such as one of **(320, 480, 512, 640, 960)**.
- **All set to 320** represents **240P**.
- **All set to 480** represents **320P**.
- **All set to 640** represents **480P**.
- **All set to 960** represents **720P**.
- `random_frame_crop` performs random cropping along the temporal axis of videos to simulate videos with different frame counts.
- `random_hw_adapt` enables automatic height and width scaling for images and videos. When `random_hw_adapt` is enabled, training images use `image_sample_size` as the maximum height/width and `min(video_sample_size, 512)` as the minimum; training videos use `video_sample_size` as the maximum and `min(video_sample_size, 512)` as the minimum.
  - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=1024`, and `image_sample_size=1024`, image inputs for training range from `512x512` to `1024x1024`, and video inputs range from `512x512x49` to `1024x1024x49`.
  - For example, when `random_hw_adapt` is enabled, with `video_sample_n_frames=49`, `video_sample_size=256`, and `image_sample_size=1024`, image inputs for training range from `256x256` to `1024x1024`, while video inputs are fixed at `256x256x49`.
- `training_with_video_token_length` trains the model according to token length. Both training images and videos use `image_sample_size` as the maximum height/width and `video_sample_size` as the minimum.
  - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=1024`, `video_sample_size=256`, and `image_sample_size=1024`, image inputs for training range from `256x256` to `1024x1024`, and video inputs range from `256x256x49` to `1024x1024x49`.
  - For example, when `training_with_video_token_length` is enabled, with `video_sample_n_frames=49`, `token_sample_size=512`, `video_sample_size=256`, and `image_sample_size=1024`, image inputs for training range from `256x256` to `1024x1024`, and video inputs range from `256x256x49` to `1024x1024x9`.
  - A 512x512 video with 49 frames corresponds to a token length of 13,312, so we set `token_sample_size=512` (see the quick check after this list).
  - At 512x512 resolution, the number of video frames is 49 (= 512 * 512 * 49 / 512 / 512).
  - At 768x768 resolution, the number of video frames is 21 (≈ 512 * 512 * 49 / 768 / 768).
  - At 1024x1024 resolution, the number of video frames is 9; the linear approximation 512 * 512 * 49 / 1024 / 1024 ≈ 12 overestimates here because usable frame counts are quantized by the VAE's temporal compression.
- These resolutions combined with their corresponding lengths allow the model to generate videos of different sizes.
- `train_mode` specifies the training mode, either `normal` or `inpaint`. Since Wan uses the inpaint model to achieve image-to-video generation, the default is `inpaint`. If you only want text-to-video generation, remove this line and it will default to text-to-video mode.
- `resume_from_checkpoint` specifies whether training should resume from a previous checkpoint. Pass a checkpoint path, or `"latest"` to automatically select the most recent available checkpoint.
- `boundary_type`: the Wan2.2 series includes two distinct models that handle different noise levels, selected via this parameter.
  - `low`: the **low noise model** (low_noise_model).
  - `high`: the **high noise model** (high_noise_model).
  - `full`: the ti2v 5B model (a single model, no low/high split).
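As a quick check of the 13,312 figure above, here is a back-of-the-envelope computation. It assumes Wan's VAE downsamples 8x spatially and 4x temporally and that the DiT uses a 2x2 spatial patch size; verify these factors against your config.
```sh
# Rough token-length check for a 512x512, 49-frame clip.
# Assumed factors (verify against your config): VAE downsamples 8x in
# space and 4x in time; the DiT uses a 2x2 spatial patch size.
H=512; W=512; F=49
latent_frames=$(( (F - 1) / 4 + 1 ))               # 13 latent frames
tokens_per_frame=$(( (H / 8 / 2) * (W / 8 / 2) ))  # 32 * 32 = 1024 tokens
echo $(( latent_frames * tokens_per_frame ))       # 13312
```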
If you want to train the 5B Wan2.2 model, set the config to `config/wan2.2/wan_civitai_5b.yaml` and set `boundary_type` to `full`.
When training with multiple machines, set the parameters as follows:
```sh
export MASTER_ADDR="your master address"
export MASTER_PORT=10086
export WORLD_SIZE=1 # The number of machines
export NUM_PROCESS=8 # The total number of processes, i.e. WORLD_SIZE * 8 with 8 GPUs per machine
export RANK=0 # The rank of this machine, from 0 to WORLD_SIZE - 1
accelerate launch --mixed_precision="bf16" --main_process_ip=$MASTER_ADDR --main_process_port=$MASTER_PORT --num_machines=$WORLD_SIZE --num_processes=$NUM_PROCESS --machine_rank=$RANK scripts/wan2.2_fun/xxx.py
```
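For example, a hypothetical two-machine run with 8 GPUs per machine would set the variables like this (the address is a placeholder):
```sh
# Machine 0 (the master):
export MASTER_ADDR="192.0.2.1"  # placeholder address of machine 0
export MASTER_PORT=10086
export WORLD_SIZE=2             # two machines
export NUM_PROCESS=16           # 2 machines * 8 GPUs
export RANK=0

# Machine 1 runs the same launch command with only the rank changed:
export RANK=1
```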
Wan T2V without DeepSpeed:
Training the 14B Wan2.2 model without DeepSpeed may result in insufficient GPU memory.
```sh
export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used for multi-node setups without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch --mixed_precision="bf16" scripts/wan2.2_fun/train.py \
--config_path="config/wan2.2/wan_civitai_i2v.yaml" \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--image_sample_size=640 \
--video_sample_size=640 \
--token_sample_size=640 \
--video_sample_stride=2 \
--video_sample_n_frames=81 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=50 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--gradient_checkpointing \
--mixed_precision="bf16" \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--vae_mini_batch=1 \
--max_grad_norm=0.05 \
--random_hw_adapt \
--training_with_video_token_length \
--enable_bucket \
--uniform_sampling \
--boundary_type="low" \
--low_vram \
--train_mode="normal" \
--trainable_modules "."
```
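If a run is interrupted, you can rerun the same command with `--resume_from_checkpoint` (described above) to continue from the most recent checkpoint; for example, append the following flag before `--trainable_modules`:
```sh
  --resume_from_checkpoint="latest" \
```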
Wan T2V with DeepSpeed Zero-2:
DeepSpeed Zero-2 is suitable for training the 14B Wan model at low resolutions; training at high resolutions may still run out of GPU memory.
```sh
export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used for multi-node setups without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch --use_deepspeed --deepspeed_config_file config/zero_stage2_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_fun/train.py \
--config_path="config/wan2.2/wan_civitai_i2v.yaml" \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--image_sample_size=640 \
--video_sample_size=640 \
--token_sample_size=640 \
--video_sample_stride=2 \
--video_sample_n_frames=81 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=50 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--gradient_checkpointing \
--mixed_precision="bf16" \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--vae_mini_batch=1 \
--max_grad_norm=0.05 \
--random_hw_adapt \
--training_with_video_token_length \
--enable_bucket \
--uniform_sampling \
--boundary_type="low" \
--low_vram \
--use_deepspeed \
--train_mode="inpaint" \
--trainable_modules "."
```
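The repository ships its own `config/zero_stage2_config.json`. For readers unfamiliar with DeepSpeed, a minimal stage-2 config might look like the sketch below; the field values are illustrative assumptions, not the repository's actual file.
```sh
# Illustrative DeepSpeed stage-2 config (assumed values, not the
# repository's actual config/zero_stage2_config.json):
cat > /tmp/zero_stage2_example.json <<'EOF'
{
  "bf16": { "enabled": true },
  "zero_optimization": {
    "stage": 2,
    "overlap_comm": true,
    "contiguous_gradients": true
  },
  "gradient_accumulation_steps": "auto",
  "train_micro_batch_size_per_gpu": "auto"
}
EOF
```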
DeepSpeed Zero-3 is not currently recommended; in this repository, FSDP produces fewer errors and is more stable.
Wan T2V with DeepSpeed Zero-3:
DeepSpeed Zero-3 is suitable for training the 14B Wan model at high resolutions. After training, you can use the following command to convert the Zero-3 checkpoint into the final bf16 model:
```sh
python scripts/zero_to_bf16.py output_dir/checkpoint-{your-num-steps} output_dir/checkpoint-{your-num-steps}-outputs --max_shard_size 80GB --safe_serialization
```
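For example, with `checkpointing_steps=50` as in the command below, converting the checkpoint saved at step 100 looks like this (the step count is illustrative):
```sh
python scripts/zero_to_bf16.py output_dir/checkpoint-100 output_dir/checkpoint-100-outputs \
  --max_shard_size 80GB --safe_serialization
```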
The training shell command is as follows:
```sh
export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used for multi-node setups without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch --zero_stage 3 --zero3_save_16bit_model true --zero3_init_flag true --use_deepspeed --deepspeed_config_file config/zero_stage3_config.json --deepspeed_multinode_launcher standard scripts/wan2.2_fun/train.py \
--config_path="config/wan2.2/wan_civitai_i2v.yaml" \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--image_sample_size=640 \
--video_sample_size=640 \
--token_sample_size=640 \
--video_sample_stride=2 \
--video_sample_n_frames=81 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=50 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--gradient_checkpointing \
--mixed_precision="bf16" \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--vae_mini_batch=1 \
--max_grad_norm=0.05 \
--random_hw_adapt \
--training_with_video_token_length \
--enable_bucket \
--uniform_sampling \
--boundary_type="low" \
--low_vram \
--use_deepspeed \
--train_mode="inpaint" \
--trainable_modules "."
```
Wan T2V with FSDP:
FSDP is suitable for training the 14B Wan model at high resolutions. The training shell command is as follows:
```sh
export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-A14B-InP"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used for multi-node setups without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.2_fun/train.py \
--config_path="config/wan2.2/wan_civitai_i2v.yaml" \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--image_sample_size=640 \
--video_sample_size=640 \
--token_sample_size=640 \
--video_sample_stride=2 \
--video_sample_n_frames=81 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=50 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--gradient_checkpointing \
--mixed_precision="bf16" \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--vae_mini_batch=1 \
--max_grad_norm=0.05 \
--random_hw_adapt \
--training_with_video_token_length \
--enable_bucket \
--uniform_sampling \
--boundary_type="low" \
--low_vram \
--train_mode="inpaint" \
--trainable_modules "."
```
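Because these FSDP commands save with `--fsdp_state_dict_type=SHARDED_STATE_DICT`, each checkpoint is written as per-rank shards rather than a single file. A sketch of consolidating them with accelerate's `merge-weights` utility (available in recent accelerate versions; the shard directory name is an assumption about your checkpoint layout):
```sh
# Merge FSDP sharded weights into a single state dict.
# Paths are illustrative; check your actual checkpoint directory layout.
accelerate merge-weights output_dir/checkpoint-100/pytorch_model_fsdp_0 output_dir/checkpoint-100-merged
```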
If you want to train the 5B Wan2.2 model, set the config to `config/wan2.2/wan_civitai_5b.yaml` and set `boundary_type` to `full`. The training shell command is as follows:
```sh
export MODEL_NAME="models/Diffusion_Transformer/Wan2.2-Fun-5B-InP"
export DATASET_NAME="datasets/internal_datasets/"
export DATASET_META_NAME="datasets/internal_datasets/metadata.json"
# NCCL_IB_DISABLE=1 and NCCL_P2P_DISABLE=1 are used for multi-node setups without RDMA.
# export NCCL_IB_DISABLE=1
# export NCCL_P2P_DISABLE=1
export NCCL_DEBUG=INFO
accelerate launch --mixed_precision="bf16" --use_fsdp --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap=WanAttentionBlock --fsdp_sharding_strategy "FULL_SHARD" --fsdp_state_dict_type=SHARDED_STATE_DICT --fsdp_backward_prefetch "BACKWARD_PRE" --fsdp_cpu_ram_efficient_loading False scripts/wan2.2_fun/train.py \
--config_path="config/wan2.2/wan_civitai_5b.yaml" \
--pretrained_model_name_or_path=$MODEL_NAME \
--train_data_dir=$DATASET_NAME \
--train_data_meta=$DATASET_META_NAME \
--image_sample_size=640 \
--video_sample_size=640 \
--token_sample_size=640 \
--video_sample_stride=2 \
--video_sample_n_frames=81 \
--train_batch_size=1 \
--video_repeat=1 \
--gradient_accumulation_steps=1 \
--dataloader_num_workers=8 \
--num_train_epochs=100 \
--checkpointing_steps=50 \
--learning_rate=2e-05 \
--lr_scheduler="constant_with_warmup" \
--lr_warmup_steps=100 \
--seed=42 \
--output_dir="output_dir" \
--gradient_checkpointing \
--mixed_precision="bf16" \
--adam_weight_decay=3e-2 \
--adam_epsilon=1e-10 \
--vae_mini_batch=1 \
--max_grad_norm=0.05 \
--random_hw_adapt \
--training_with_video_token_length \
--enable_bucket \
--uniform_sampling \
--boundary_type="full" \
--low_vram \
--train_mode="inpaint" \
--trainable_modules "."
``` |