Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .github/FUNDING.yml +3 -0
- .gitignore +8 -0
- .ipynb_checkpoints/README-checkpoint.md +123 -0
- .python-version +1 -0
- README.md +123 -0
- cache_latents.py +339 -0
- cache_text_encoder_outputs.py +214 -0
- convert_lora.py +137 -0
- dataset/__init__.py +0 -0
- dataset/config_utils.py +384 -0
- dataset/dataset_config.md +486 -0
- dataset/image_video_dataset.py +1786 -0
- fpack_cache_latents.py +454 -0
- fpack_cache_text_encoder_outputs.py +110 -0
- fpack_generate_video.py +1711 -0
- frame_pack/__init__.py +0 -0
- frame_pack/bucket_tools.py +30 -0
- frame_pack/clip_vision.py +14 -0
- frame_pack/framepack_utils.py +273 -0
- frame_pack/hunyuan.py +134 -0
- frame_pack/hunyuan_video_packed.py +2015 -0
- frame_pack/k_diffusion_hunyuan.py +128 -0
- frame_pack/uni_pc_fm.py +142 -0
- frame_pack/utils.py +617 -0
- frame_pack/wrapper.py +51 -0
- framepack_edit_output/framepack-edit-lora-000001.safetensors +3 -0
- framepack_edit_output/framepack-edit-lora-000002.safetensors +3 -0
- framepack_edit_output/framepack-edit-lora-000003.safetensors +3 -0
- framepack_edit_output/framepack-edit-lora-000004.safetensors +3 -0
- framepack_edit_output/framepack-edit-lora-000005.safetensors +3 -0
- framepack_edit_output/framepack-edit-lora-000006.safetensors +3 -0
- hunyuan_model/__init__.py +0 -0
- hunyuan_model/activation_layers.py +23 -0
- hunyuan_model/attention.py +295 -0
- hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
- hunyuan_model/embed_layers.py +132 -0
- hunyuan_model/fp8_optimization.py +39 -0
- hunyuan_model/helpers.py +40 -0
- hunyuan_model/mlp_layers.py +118 -0
- hunyuan_model/models.py +1044 -0
- hunyuan_model/modulate_layers.py +76 -0
- hunyuan_model/norm_layers.py +79 -0
- hunyuan_model/pipeline_hunyuan_video.py +1100 -0
- hunyuan_model/posemb_layers.py +310 -0
- hunyuan_model/text_encoder.py +710 -0
- hunyuan_model/token_refiner.py +245 -0
- hunyuan_model/vae.py +446 -0
- hv_generate_video.py +936 -0
- merge_lora.py +63 -0
- modules/__init__.py +0 -0
.github/FUNDING.yml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# These are supported funding model platforms
|
| 2 |
+
|
| 3 |
+
github: kohya-ss
|
.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.venv
|
| 3 |
+
venv/
|
| 4 |
+
logs/
|
| 5 |
+
uv.lock
|
| 6 |
+
main.exp
|
| 7 |
+
main.lib
|
| 8 |
+
main.obj
|
.ipynb_checkpoints/README-checkpoint.md
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FramePack Image Edit Early Lora
|
| 2 |
+
|
| 3 |
+
This repository contains the necessary steps and scripts to generate an edit of an image using an image-to-video model.
|
| 4 |
+
The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create an edited image based on an input image and textual prompts.
|
| 5 |
+
|
| 6 |
+
## Prerequisites
|
| 7 |
+
|
| 8 |
+
Before proceeding, ensure that you have the following installed on your system:
|
| 9 |
+
|
| 10 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
| 11 |
+
• **Python 3.x**
|
| 12 |
+
• **pip** (Python package manager)
|
| 13 |
+
• **Git**
|
| 14 |
+
• **Git LFS** (Git Large File Storage)
|
| 15 |
+
• **FFmpeg**
|
| 16 |
+
|
| 17 |
+
## Installation
|
| 18 |
+
|
| 19 |
+
1. **Update and Install Dependencies**
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
2. **Clone the Repository**
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
git clone https://huggingface.co/svjack/FramePack_Image_Edit_Lora_Early
|
| 29 |
+
cd FramePack_Image_Edit_Lora_Early
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
3. **Install Python Dependencies**
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
pip install torch torchvision
|
| 36 |
+
pip install -r requirements.txt
|
| 37 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
| 38 |
+
pip install moviepy==1.0.3
|
| 39 |
+
pip install sageattention==1.0.6
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
4. **Download Model Weights**
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
git clone https://huggingface.co/lllyasviel/FramePackI2V_HY
|
| 46 |
+
git clone https://huggingface.co/hunyuanvideo-community/HunyuanVideo
|
| 47 |
+
git clone https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged
|
| 48 |
+
git clone https://huggingface.co/Comfy-Org/sigclip_vision_384
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
|
| 53 |
+
To edit an image, use the `fpack_generate_video.py` script with the appropriate parameters. Below are examples of how to do it.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
* 1 Add a cat
|
| 57 |
+
- Input
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python fpack_generate_video.py \
|
| 62 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 63 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 64 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 65 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 66 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 67 |
+
--image_path xiang_image.jpg \
|
| 68 |
+
--prompt "add a cat into the picture" \
|
| 69 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 70 |
+
--attn_mode sdpa --fp8_scaled \
|
| 71 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 72 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 73 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
- Output
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
* 2 Change Background
|
| 80 |
+
- Input
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python fpack_generate_video.py \
|
| 85 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 86 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 87 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 88 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 89 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 90 |
+
--image_path wanye.jpg \
|
| 91 |
+
--prompt "Change the background into a restaurant in anime style. Keep the character's eye colors and white hair unchanged." \
|
| 92 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 93 |
+
--attn_mode sdpa --fp8_scaled \
|
| 94 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 95 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 96 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
- Output
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
* 3 Place Train into landscape
|
| 104 |
+
- Input
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python fpack_generate_video.py \
|
| 108 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 109 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 110 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 111 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 112 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 113 |
+
--image_path train.jpg \
|
| 114 |
+
--prompt "place the train into a beautiful landscape" \
|
| 115 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 116 |
+
--attn_mode sdpa --fp8_scaled \
|
| 117 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 118 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 119 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
- Output
|
| 123 |
+
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.10
|
README.md
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FramePack Image Edit Early Lora
|
| 2 |
+
|
| 3 |
+
This repository contains the necessary steps and scripts to generate an edit of an image using an image-to-video model.
|
| 4 |
+
The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create an edited image based on an input image and textual prompts.
|
| 5 |
+
|
| 6 |
+
## Prerequisites
|
| 7 |
+
|
| 8 |
+
Before proceeding, ensure that you have the following installed on your system:
|
| 9 |
+
|
| 10 |
+
• **Ubuntu** (or a compatible Linux distribution)
|
| 11 |
+
• **Python 3.x**
|
| 12 |
+
• **pip** (Python package manager)
|
| 13 |
+
• **Git**
|
| 14 |
+
• **Git LFS** (Git Large File Storage)
|
| 15 |
+
• **FFmpeg**
|
| 16 |
+
|
| 17 |
+
## Installation
|
| 18 |
+
|
| 19 |
+
1. **Update and Install Dependencies**
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
2. **Clone the Repository**
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
git clone https://huggingface.co/svjack/FramePack_Image_Edit_Lora_Early
|
| 29 |
+
cd FramePack_Image_Edit_Lora_Early
|
| 30 |
+
```
|
| 31 |
+
|
| 32 |
+
3. **Install Python Dependencies**
|
| 33 |
+
|
| 34 |
+
```bash
|
| 35 |
+
pip install torch torchvision
|
| 36 |
+
pip install -r requirements.txt
|
| 37 |
+
pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
|
| 38 |
+
pip install moviepy==1.0.3
|
| 39 |
+
pip install sageattention==1.0.6
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
4. **Download Model Weights**
|
| 43 |
+
|
| 44 |
+
```bash
|
| 45 |
+
git clone https://huggingface.co/lllyasviel/FramePackI2V_HY
|
| 46 |
+
git clone https://huggingface.co/hunyuanvideo-community/HunyuanVideo
|
| 47 |
+
git clone https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged
|
| 48 |
+
git clone https://huggingface.co/Comfy-Org/sigclip_vision_384
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
## Usage
|
| 52 |
+
|
| 53 |
+
To edit an image, use the `fpack_generate_video.py` script with the appropriate parameters. Below are examples of how to do it.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
* 1 Add a cat
|
| 57 |
+
- Input
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
```bash
|
| 61 |
+
python fpack_generate_video.py \
|
| 62 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 63 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 64 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 65 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 66 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 67 |
+
--image_path xiang_image.jpg \
|
| 68 |
+
--prompt "add a cat into the picture" \
|
| 69 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 70 |
+
--attn_mode sdpa --fp8_scaled \
|
| 71 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 72 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 73 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
- Output
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
* 2 Change Background
|
| 80 |
+
- Input
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python fpack_generate_video.py \
|
| 85 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 86 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 87 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 88 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 89 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 90 |
+
--image_path wanye.jpg \
|
| 91 |
+
--prompt "Change the background into a restaurant in anime style. Keep the character's eye colors and white hair unchanged." \
|
| 92 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 93 |
+
--attn_mode sdpa --fp8_scaled \
|
| 94 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 95 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 96 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 97 |
+
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
- Output
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
* 3 Place Train into landscape
|
| 104 |
+
- Input
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python fpack_generate_video.py \
|
| 108 |
+
--dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
|
| 109 |
+
--vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
|
| 110 |
+
--text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
|
| 111 |
+
--text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
|
| 112 |
+
--image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
|
| 113 |
+
--image_path train.jpg \
|
| 114 |
+
--prompt "place the train into a beautiful landscape" \
|
| 115 |
+
--video_size 512 512 --fps 30 --infer_steps 25 \
|
| 116 |
+
--attn_mode sdpa --fp8_scaled \
|
| 117 |
+
--vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
|
| 118 |
+
--save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
|
| 119 |
+
--seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
- Output
|
| 123 |
+
|
cache_latents.py
ADDED
|
@@ -0,0 +1,339 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
import glob
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
|
| 10 |
+
from dataset import config_utils
|
| 11 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
| 12 |
+
from PIL import Image
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
|
| 17 |
+
from hunyuan_model.vae import load_vae
|
| 18 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
| 19 |
+
from utils.model_utils import str_to_dtype
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def show_image(image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray]) -> int:
    """Preview an image (or the first/last frames of a video) in an OpenCV window.

    Waits for a key press after each frame. Pressing "q" (quit) or "d"
    (next dataset) returns immediately so the caller can abort the preview.

    Args:
        image: a single PIL image / HWC uint8 ndarray, or a list of frames.
            (The original annotation was malformed: ``list[Union[...], Union[...]]``
            passed two type arguments to ``list``; fixed to a proper union.)

    Returns:
        The OpenCV key code of the last key pressed.
    """
    import cv2

    # For a video (list of frames), preview only the first and last frames.
    imgs = (
        [image]
        if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
        else [image[0], image[-1]]
    )
    if len(imgs) > 1:
        # Report the total frame count of the source, not just the 2 previewed.
        print(f"Number of images: {len(image)}")
    for i, img in enumerate(imgs):
        if len(imgs) > 1:
            print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
        else:
            print(f"Image: {img.shape}")
        cv2_img = np.array(img) if isinstance(img, Image.Image) else img
        # Dataset frames are RGB; OpenCV displays BGR.
        cv2_img = cv2.cvtColor(cv2_img, cv2.COLOR_RGB2BGR)
        cv2.imshow("image", cv2_img)
        k = cv2.waitKey(0)
        cv2.destroyAllWindows()
        if k == ord("q") or k == ord("d"):
            return k
    return k
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def show_console(
    image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray],
    width: int,
    back: str,
    interactive: bool = False,
) -> int:
    """Render image(s) as ASCII art in the terminal.

    For a video (list of frames) only the first and last frames are shown.
    In interactive mode the user is prompted once after all frames are shown;
    "q" quits and "d" skips to the next dataset.

    Args:
        image: a single PIL image / HWC uint8 ndarray, or a list of frames.
        width: terminal width in columns for the ASCII rendering.
        back: background color name matching an ``ascii_magic.Back`` attribute,
            or None for the default background.
        interactive: when True, prompt for a key after showing the images.

    Returns:
        The key code of the key the user pressed, or ``ord(" ")`` otherwise.
    """
    from ascii_magic import from_pillow_image, Back

    # BUGFIX: the original unconditionally set `back = None` here, which made
    # the `--console_back` option a silent no-op. Resolve the name instead.
    if back is not None:
        back = getattr(Back, back.upper())

    k = None
    imgs = (
        [image]
        if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
        else [image[0], image[-1]]
    )
    if len(imgs) > 1:
        # Report the total frame count of the source, not just the 2 previewed.
        print(f"Number of images: {len(image)}")
    for i, img in enumerate(imgs):
        if len(imgs) > 1:
            print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
        else:
            print(f"Image: {img.shape}")
        pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
        ascii_img = from_pillow_image(pil_img)
        ascii_img.to_terminal(columns=width, back=back)

    if interactive:
        k = input("Press q to quit, d to next dataset, other key to next: ")
        if k == "q" or k == "d":
            return ord(k)

    if not interactive:
        return ord(" ")
    # Use only the first character: `ord()` raises on multi-character input.
    return ord(k[0]) if k else ord(" ")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def save_video(image: Union[list[Union[Image.Image, np.ndarray]], Image.Image, np.ndarray], cache_path: str, fps: int = 24):
    """Save a debug preview of dataset content next to its latent cache path.

    A single image is written as a ``.jpg``; a list of frames is encoded as an
    ``.mp4`` (H.264, yuv420p, ~1 Mbit/s preview quality). The output file name
    is derived from ``cache_path`` by replacing the ``.safetensors`` suffix.

    Args:
        image: a single PIL image / HWC uint8 ndarray, or a list of frames.
        cache_path: latent cache path used to derive the output file name.
        fps: frame rate of the encoded preview video.
    """
    import av

    directory = os.path.dirname(cache_path)
    # BUGFIX: the original called os.makedirs(directory) behind an exists()
    # check, which raises for a bare filename (directory == "") and is
    # race-prone with multiple workers. exist_ok=True handles both.
    if directory:
        os.makedirs(directory, exist_ok=True)

    if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image):
        # save image
        image_path = cache_path.replace(".safetensors", ".jpg")
        img = image if isinstance(image, Image.Image) else Image.fromarray(image)
        img.save(image_path)
        print(f"Saved image: {image_path}")
    else:
        imgs = image
        print(f"Number of images: {len(imgs)}")
        # save video
        video_path = cache_path.replace(".safetensors", ".mp4")
        height, width = imgs[0].shape[0:2]

        # create output container
        container = av.open(video_path, mode="w")

        # create video stream
        codec = "libx264"
        pixel_format = "yuv420p"
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = pixel_format
        stream.bit_rate = 1000000  # 1Mbit/s for preview quality

        for frame_img in imgs:
            if isinstance(frame_img, Image.Image):
                frame = av.VideoFrame.from_image(frame_img)
            else:
                frame = av.VideoFrame.from_ndarray(frame_img, format="rgb24")
            packets = stream.encode(frame)
            for packet in packets:
                container.mux(packet)

        # flush any buffered packets before closing the container
        for packet in stream.encode():
            container.mux(packet)

        container.close()

        print(f"Saved video: {video_path}")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def show_datasets(
    datasets: list[BaseDataset],
    debug_mode: str,
    console_width: int,
    console_back: str,
    console_num_images: Optional[int],
    fps: int = 24,
):
    """Debug helper: preview every dataset's batches before caching.

    Depending on ``debug_mode``:
      * "image"   - show each item in an OpenCV window (interactive keys),
      * "console" - render each item as ASCII art in the terminal,
      * "video"   - write each item as a .jpg/.mp4 next to its cache path.

    Key handling (image/console modes): "q" aborts everything, "d" skips to
    the next dataset. In console mode with ``console_num_images`` set, a
    synthetic "d" is injected after that many images per dataset.
    """
    if debug_mode != "video":
        print(f"d: next dataset, q: quit")

    num_workers = max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        print(f"Dataset [{i}]")
        batch_index = 0
        num_images_to_show = console_num_images
        k = None  # last key code; carried across items to drive the breaks below
        for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
            print(f"bucket resolution: {key}, count: {len(batch)}")
            for j, item_info in enumerate(batch):
                item_info: ItemInfo
                print(f"{batch_index}-{j}: {item_info}")
                if debug_mode == "image":
                    k = show_image(item_info.content)
                elif debug_mode == "console":
                    # interactive only when no image-count limit was given
                    k = show_console(item_info.content, console_width, console_back, console_num_images is None)
                    if num_images_to_show is not None:
                        num_images_to_show -= 1
                        if num_images_to_show == 0:
                            k = ord("d")  # next dataset
                elif debug_mode == "video":
                    save_video(item_info.content, item_info.latent_cache_path, fps)
                    k = None  # save next video

                if k == ord("q"):
                    return
                elif k == ord("d"):
                    break
            # propagate "next dataset" out of the batch loop as well
            if k == ord("d"):
                break
            batch_index += 1
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
    """Encode one bucket batch of pixel data with the VAE and cache the latents.

    Each item's uint8 content is stacked into a batch, moved to the VAE's
    device/dtype, normalized to [-1, 1], encoded, and the sampled latent is
    written to the item's latent cache path.

    Raises:
        ValueError: if the (shared) spatial size of the batch is below 8x8.
    """
    pixels = torch.stack([torch.from_numpy(entry.content) for entry in batch])
    # Single images arrive as (B, H, W, C); insert a frame axis of length 1.
    if pixels.ndim == 4:
        pixels = pixels.unsqueeze(1)

    # (B, F, H, W, C) -> (B, C, F, H, W), then match the VAE's device/dtype.
    pixels = pixels.permute(0, 4, 1, 2, 3).contiguous()
    pixels = pixels.to(vae.device, dtype=vae.dtype)
    # uint8 [0, 255] -> float [-1, 1]
    pixels = pixels / 127.5 - 1.0

    height, width = pixels.shape[3], pixels.shape[4]
    if height < 8 or width < 8:
        first = batch[0]  # all items in a bucket batch share the same size
        raise ValueError(f"Image or video size too small: {first.item_key} and {len(batch) - 1} more, size: {first.original_size}")

    # Inference only; no gradients needed. Scaling by vae.config.scaling_factor
    # is deliberately NOT applied here (matches downstream expectations).
    with torch.no_grad():
        latents = vae.encode(pixels).latent_dist.sample()

    for entry, latent in zip(batch, latents):
        save_latent_cache(entry, latent)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Encode every dataset's batches and prune stale latent cache files.

    Args:
        datasets: datasets whose latent cache batches will be encoded.
        encode: callback invoked with each ``list[ItemInfo]`` sub-batch; it is
            expected to encode and persist the latents.
        args: CLI namespace; reads ``num_workers``, ``skip_existing``,
            ``batch_size`` and ``keep_cache``.
    """
    num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            all_latent_cache_paths.extend([item.latent_cache_path for item in batch])

            if args.skip_existing:
                filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
                if len(filtered_batch) == 0:
                    continue
                batch = filtered_batch

            # Sub-batch to the requested size (default: the whole bucket batch).
            bs = args.batch_size if args.batch_size is not None else len(batch)
            # BUGFIX: the original reused `i` here, shadowing the dataset index
            # from the outer loop; renamed to `start`.
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])

        # normalize paths so the set-membership test below is reliable
        all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
        all_latent_cache_paths = set(all_latent_cache_paths)

        # remove old cache files not in the dataset
        all_cache_files = dataset.get_all_latent_cache_files()
        for cache_file in all_cache_files:
            if os.path.normpath(cache_file) not in all_latent_cache_paths:
                if args.keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def main(args):
    """Entry point: build datasets from the config, then cache VAE latents.

    When ``--debug_mode`` is given, only previews/saves the dataset contents
    and returns without loading the VAE.
    """
    # Pick an explicit device if given, else CUDA when available.
    device = torch.device(args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu")

    # Build the dataset group from the TOML blueprint.
    generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)
    datasets = train_dataset_group.datasets

    # Debug modes only inspect the datasets; no encoding happens.
    if args.debug_mode is not None:
        show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
        return

    assert args.vae is not None, "vae checkpoint is required"

    # Load VAE model: HunyuanVideo VAE model is float16
    vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()
    logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")

    # Optional memory-saving settings for the causal 3D VAE.
    if args.vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
        logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    elif args.vae_tiling:
        vae.enable_spatial_tiling(True)

    def encode(one_batch: list[ItemInfo]):
        # Bind the shared VAE instance into the per-batch encode callback.
        encode_and_save_batch(vae, one_batch)

    encode_datasets(datasets, encode, args)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def setup_parser_common() -> argparse.ArgumentParser:
    """Build the argument parser with options shared by all caching scripts."""
    p = argparse.ArgumentParser()

    # required input
    p.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")

    # model / device
    p.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
    p.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    p.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")

    # batching / caching behavior
    p.add_argument("--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this")
    p.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
    p.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
    p.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")

    # debug options
    p.add_argument("--debug_mode", type=str, default=None, choices=["image", "console", "video"], help="debug mode")
    p.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
    p.add_argument("--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back")
    p.add_argument(
        "--console_num_images",
        type=int,
        default=None,
        help="debug mode: not interactive, number of images to show for each dataset",
    )
    return p
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add HunyuanVideo-specific VAE options to an existing parser."""
    add = parser.add_argument
    add(
        "--vae_tiling",
        action="store_true",
        help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
    )
    add("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    add("--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256")
    return parser
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
if __name__ == "__main__":
    # Assemble the CLI: common caching options first, then HunyuanVideo-specific ones.
    cli = hv_setup_parser(setup_parser_common())
    main(cli.parse_args())
|
cache_text_encoder_outputs.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
from dataset import config_utils
|
| 10 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
| 11 |
+
import accelerate
|
| 12 |
+
|
| 13 |
+
from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
|
| 14 |
+
from hunyuan_model import text_encoder as text_encoder_module
|
| 15 |
+
from hunyuan_model.text_encoder import TextEncoder
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
from utils.model_utils import str_to_dtype
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
logging.basicConfig(level=logging.INFO)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
    """Tokenize *prompt* and encode it with *text_encoder* without gradients.

    Returns the (hidden_state, attention_mask) pair from the encoder output.
    """
    # only the "video" data type is supported here; image is not supported
    tokens = text_encoder.text2tokens(prompt, data_type="video")

    with torch.no_grad():
        outputs = text_encoder.encode(tokens, data_type="video")

    return outputs.hidden_state, outputs.attention_mask
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def encode_and_save_batch(
    text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
):
    """Encode the captions of *batch* with *text_encoder* and cache the outputs.

    When *accelerator* is given, encoding runs under its autocast context
    (used for fp8 LLM inference); otherwise it runs as-is.
    """
    captions = [item.caption for item in batch]

    # encode all captions in a single forward pass
    if accelerator is None:
        embeds, masks = encode_prompt(text_encoder, captions)
    else:
        with accelerator.autocast():
            embeds, masks = encode_prompt(text_encoder, captions)

    # write one cache entry per item
    for item, embed, mask in zip(batch, embeds, masks):
        save_text_encoder_output_cache(item, embed, mask, is_llm)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
    """Collect, per dataset, the existing cache files and an empty set that will
    accumulate every cache path referenced during encoding.

    Returns (all_cache_files_for_dataset, all_cache_paths_for_dataset).
    """
    existing_sets = []  # normalized paths of cache files already on disk
    seen_sets = []  # cache paths encountered while iterating the dataset
    for ds in datasets:
        existing = {os.path.normpath(p) for p in ds.get_all_text_encoder_output_cache_files()}
        existing_sets.append(existing)
        seen_sets.append(set())
    return existing_sets, seen_sets
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def process_text_encoder_batches(
    num_workers: Optional[int],
    skip_existing: bool,
    batch_size: int,
    datasets: list[BaseDataset],
    all_cache_files_for_dataset: list[set],
    all_cache_paths_for_dataset: list[set],
    encode: callable,
):
    """Iterate all datasets and feed their caption batches to *encode*.

    Args:
        num_workers: worker count for batch retrieval; defaults to cpu count - 1.
        skip_existing: drop items whose cache file already exists on disk.
        batch_size: maximum sub-batch size passed to *encode*; None means the
            whole retrieved batch at once.
        datasets: datasets to process.
        all_cache_files_for_dataset: per-dataset sets of existing cache files.
        all_cache_paths_for_dataset: per-dataset sets, updated in place with
            every cache path seen (used later to prune stale files).
        encode: callback receiving a list[ItemInfo] to encode and save.
    """
    num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_cache_files = all_cache_files_for_dataset[i]
        all_cache_paths = all_cache_paths_for_dataset[i]
        for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
            # update cache paths (it's ok if we update them multiple times)
            all_cache_paths.update(os.path.normpath(item.text_encoder_output_cache_path) for item in batch)

            # skip existing cache files
            if skip_existing:
                batch = [
                    item for item in batch if os.path.normpath(item.text_encoder_output_cache_path) not in all_cache_files
                ]
                if not batch:
                    continue

            # split into sub-batches; `start` instead of `i` avoids shadowing the dataset index
            bs = batch_size if batch_size is not None else len(batch)
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def post_process_cache_files(
    datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set], keep_cache: bool
):
    """Remove (or, with *keep_cache*, just report) cache files that were not
    referenced by any dataset item."""
    for idx in range(len(datasets)):
        existing = all_cache_files_for_dataset[idx]
        referenced = all_cache_paths_for_dataset[idx]
        # files present on disk but never referenced are stale
        for stale in existing - referenced:
            if keep_cache:
                logger.info(f"Keep cache file not in the dataset: {stale}")
            else:
                os.remove(stale)
                logger.info(f"Removed old cache file: {stale}")
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def main(args):
    """Cache text-encoder outputs for every dataset in the config.

    Encodes all captions twice — once with Text Encoder 1 (the LLM, optionally
    under fp8/autocast) and once with Text Encoder 2 — saving each result to the
    dataset cache, then prunes cache files no longer backed by a dataset item.
    """
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # define accelerator for fp8 inference (autocast context for the LLM encoder)
    accelerator = None
    if args.fp8_llm:
        accelerator = accelerate.Accelerator(mixed_precision="fp16")

    # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)

    # Load Text Encoder 1
    text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
    logger.info(f"loading text encoder 1: {args.text_encoder1}")
    text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
    text_encoder_1.to(device=device)

    # Encode with Text Encoder 1 (LLM)
    logger.info("Encoding with Text Encoder 1")

    def encode_for_text_encoder_1(batch: list[ItemInfo]):
        # closure over text_encoder_1 / accelerator for the generic batch driver
        encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_1,
    )
    # release encoder 1 before loading encoder 2 to reduce peak memory
    del text_encoder_1

    # Load Text Encoder 2
    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
    text_encoder_2.to(device=device)

    # Encode with Text Encoder 2
    logger.info("Encoding with Text Encoder 2")

    def encode_for_text_encoder_2(batch: list[ItemInfo]):
        # encoder 2 never uses the fp8 autocast path
        encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_2,
    )
    del text_encoder_2

    # remove cache files not in dataset
    post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def setup_parser_common():
    """Build the argument parser shared by the text-encoder caching scripts."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
    add("--device", type=str, default=None, help="device to use, default is cuda if available")
    add("--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this")
    add("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")
    add("--skip_existing", action="store_true", help="skip existing cache files")
    add("--keep_cache", action="store_true", help="keep cache files not in dataset")
    return parser
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add HunyuanVideo text-encoder options to an existing parser."""
    add = parser.add_argument
    add("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    add("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")
    add("--text_encoder_dtype", type=str, default=None, help="data type for Text Encoder, default is float16")
    add("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    return parser
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
if __name__ == "__main__":
    # Build the CLI: common caching options plus HunyuanVideo encoder options.
    cli = hv_setup_parser(setup_parser_common())
    main(cli.parse_args())
|
convert_lora.py
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from safetensors.torch import load_file, save_file
|
| 5 |
+
from safetensors import safe_open
|
| 6 |
+
from utils import model_utils
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
logging.basicConfig(level=logging.INFO)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def convert_from_diffusers(prefix, weights_sd):
    """Convert a Diffusers-style LoRA state dict to the default LoRA format.

    Diffusers keys look like "diffusion_model.module.name.lora_A.weight";
    default keys look like "{prefix}module_name.lora_down.weight".  Diffusers
    stores no alpha, so an alpha equal to the rank is synthesized per module.
    """
    converted = {}
    rank_by_module = {}  # lora module name -> rank (lora_down's first dim)
    for key, weight in weights_sd.items():
        top_level, remainder = key.split(".", 1)
        if top_level not in ("diffusion_model", "transformer"):
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{remainder}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        converted[new_key] = weight

        module_name = new_key.split(".")[0]  # before first dot
        if "lora_down" in new_key and module_name not in rank_by_module:
            rank_by_module[module_name] = weight.shape[0]

    # Diffusers has no alpha, so alpha is set to the rank
    for module_name, rank in rank_by_module.items():
        converted[f"{module_name}.alpha"] = torch.tensor(rank)

    return converted
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def convert_to_diffusers(prefix, weights_sd):
    """Convert a default-format LoRA state dict to the Diffusers format.

    Keys "{prefix}module_name.lora_down/up.weight" become
    "diffusion_model.module.name.lora_A/B.weight".  Because Diffusers has no
    alpha, each weight is pre-scaled by sqrt(alpha / rank) so that applying
    down and up together reproduces the original alpha/rank scaling.
    """
    # get alphas
    lora_alphas = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            lora_name = key.split(".", 1)[0]  # before first dot
            if lora_name not in lora_alphas and "alpha" in key:
                lora_alphas[lora_name] = weight

    new_weights_sd = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix):
            if "alpha" in key:
                # alpha is folded into the weights below, not emitted as a key
                continue

            lora_name = key.split(".", 1)[0]  # before first dot

            module_name = lora_name[len(prefix) :]  # remove "lora_unet_"
            module_name = module_name.replace("_", ".")  # replace "_" with "."
            # NOTE: replacement order below matters — each fix undoes a specific
            # over-replacement introduced by the blanket "_" -> "." above.
            if ".cross.attn." in module_name or ".self.attn." in module_name:
                # Wan2.1 lora name to module name: ugly but works
                module_name = module_name.replace("cross.attn", "cross_attn")  # fix cross attn
                module_name = module_name.replace("self.attn", "self_attn")  # fix self attn
                module_name = module_name.replace("k.img", "k_img")  # fix k img
                module_name = module_name.replace("v.img", "v_img")  # fix v img
            else:
                # HunyuanVideo lora name to module name: ugly but works
                module_name = module_name.replace("double.blocks.", "double_blocks.")  # fix double blocks
                module_name = module_name.replace("single.blocks.", "single_blocks.")  # fix single blocks
                module_name = module_name.replace("img.", "img_")  # fix img
                module_name = module_name.replace("txt.", "txt_")  # fix txt
                module_name = module_name.replace("attn.", "attn_")  # fix attn

            diffusers_prefix = "diffusion_model"
            if "lora_down" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
                dim = weight.shape[0]  # rank is the first dim of lora_down
            elif "lora_up" in key:
                new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
                dim = weight.shape[1]  # rank is the second dim of lora_up
            else:
                logger.warning(f"unexpected key: {key} in default LoRA format")
                continue

            # scale weight by alpha
            if lora_name in lora_alphas:
                # we scale both down and up, so scale is sqrt
                scale = lora_alphas[lora_name] / dim
                scale = scale.sqrt()
                weight = weight * scale
            else:
                logger.warning(f"missing alpha for {lora_name}")

            new_weights_sd[new_key] = weight

    return new_weights_sd
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def convert(input_file, output_file, target_format):
    """Load LoRA weights from *input_file*, convert them, and save to *output_file*.

    *target_format* is "default" (Diffusers -> default, adding safetensors
    hashes to the metadata) or "other" (default -> Diffusers).
    """
    logger.info(f"loading {input_file}")
    weights_sd = load_file(input_file)
    # read the metadata separately; load_file only returns tensors
    with safe_open(input_file, framework="pt") as handle:
        metadata = handle.metadata()

    logger.info(f"converting to {target_format}")
    prefix = "lora_unet_"
    if target_format == "default":
        converted = convert_from_diffusers(prefix, weights_sd)
        metadata = metadata or {}
        model_utils.precalculate_safetensors_hashes(converted, metadata)
    elif target_format == "other":
        converted = convert_to_diffusers(prefix, weights_sd)
    else:
        raise ValueError(f"unknown target format: {target_format}")

    logger.info(f"saving to {output_file}")
    save_file(converted, output_file, metadata=metadata)

    logger.info("done")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def parse_args():
    """Parse command-line arguments for the LoRA conversion tool."""
    parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
    add = parser.add_argument
    add("--input", type=str, required=True, help="input model file")
    add("--output", type=str, required=True, help="output model file")
    add("--target", type=str, required=True, choices=["other", "default"], help="target format")
    return parser.parse_args()
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the conversion.
    cli_args = parse_args()
    convert(cli_args.input, cli_args.output, cli_args.target)
|
dataset/__init__.py
ADDED
|
File without changes
|
dataset/config_utils.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from dataclasses import (
|
| 3 |
+
asdict,
|
| 4 |
+
dataclass,
|
| 5 |
+
)
|
| 6 |
+
import functools
|
| 7 |
+
import random
|
| 8 |
+
from textwrap import dedent, indent
|
| 9 |
+
import json
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
# from toolz import curry
|
| 13 |
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
| 14 |
+
|
| 15 |
+
import toml
|
| 16 |
+
import voluptuous
|
| 17 |
+
from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
|
| 18 |
+
|
| 19 |
+
from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class BaseDatasetParams:
    """Configuration shared by image and video datasets."""

    resolution: Tuple[int, int] = (960, 544)  # (width, height) target resolution
    enable_bucket: bool = False  # enable aspect-ratio bucketing
    bucket_no_upscale: bool = False  # when bucketing, never upscale sources
    caption_extension: Optional[str] = None  # extension of caption sidecar files
    batch_size: int = 1
    num_repeats: int = 1
    cache_directory: Optional[str] = None  # defaults to the image/video directory if unset
    debug_dataset: bool = False
    architecture: str = "no_default"  # short style like "hv" or "wan"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
@dataclass
class ImageDatasetParams(BaseDatasetParams):
    """Configuration specific to image datasets."""

    image_directory: Optional[str] = None  # directory of images
    image_jsonl_file: Optional[str] = None  # JSONL metadata file, alternative to a directory
    control_directory: Optional[str] = None  # directory of control images, if any
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class VideoDatasetParams(BaseDatasetParams):
    """Configuration specific to video datasets."""

    video_directory: Optional[str] = None  # directory of videos
    video_jsonl_file: Optional[str] = None  # JSONL metadata file, alternative to a directory
    control_directory: Optional[str] = None  # directory of control videos, if any
    target_frames: Sequence[int] = (1,)  # frame counts to extract per clip
    frame_extraction: Optional[str] = "head"  # extraction strategy (e.g. "head")
    frame_stride: Optional[int] = 1
    frame_sample: Optional[int] = 1
    max_frames: Optional[int] = 129
    source_fps: Optional[float] = None  # fps of the source videos, if known
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@dataclass
class DatasetBlueprint:
    """A single dataset's resolved parameters plus its image/video kind."""

    is_image_dataset: bool  # True -> ImageDatasetParams, False -> VideoDatasetParams
    params: Union[ImageDatasetParams, VideoDatasetParams]
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
@dataclass
class DatasetGroupBlueprint:
    """An ordered collection of dataset blueprints."""

    datasets: Sequence[DatasetBlueprint]
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclass
class Blueprint:
    """Top-level result of config generation: the full dataset group."""

    dataset_group: DatasetGroupBlueprint
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class ConfigSanitizer:
    """Validates user dataset configs and argparse results against voluptuous
    schemas, normalizing scalar resolutions into (width, height) tuples."""

    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        """Validate that *value* is an exact [klass, klass] pair; return it as a tuple."""
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        """Accept a scalar or a [klass, klass] pair; a scalar x is expanded to (x, x)."""
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except voluptuous.Invalid:
            # not a scalar of klass (narrowed from a bare except, which would
            # also have swallowed unrelated errors) — must be a valid pair
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)

    # options that may appear in [general] or be inherited by any dataset section
    DATASET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "batch_size": int,
        "num_repeats": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "enable_bucket": bool,
        "bucket_no_upscale": bool,
    }
    # options only valid for image datasets
    IMAGE_DATASET_DISTINCT_SCHEMA = {
        "image_directory": str,
        "image_jsonl_file": str,
        "cache_directory": str,
        "control_directory": str,
    }
    # options only valid for video datasets
    VIDEO_DATASET_DISTINCT_SCHEMA = {
        "video_directory": str,
        "video_jsonl_file": str,
        "control_directory": str,
        "target_frames": [int],
        "frame_extraction": str,
        "frame_stride": int,
        "frame_sample": int,
        "max_frames": int,
        "cache_directory": str,
        "source_fps": float,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
    }

    def __init__(self) -> None:
        self.image_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.IMAGE_DATASET_DISTINCT_SCHEMA,
        )
        self.video_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.VIDEO_DATASET_DISTINCT_SCHEMA,
        )

        def validate_flex_dataset(dataset_config: dict):
            # choose the schema by the keys present: video keys select the video schema
            if "video_directory" in dataset_config or "video_jsonl_file" in dataset_config:
                return Schema(self.video_dataset_schema)(dataset_config)
            else:
                return Schema(self.image_dataset_schema)(dataset_config)

        self.dataset_schema = validate_flex_dataset

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
        )
        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )
        self.argparse_schema = self.__merge_dict(
            self.ARGPARSE_SPECIFIC_SCHEMA,
        )
        # ALLOW_EXTRA: the namespace carries many options the schema doesn't cover
        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        """Validate and normalize *user_config*; raises MultipleInvalid on bad input."""
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: clarify the error message
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    # However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        """Validate the argparse namespace; a failure here indicates a program bug."""
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        """Merge dicts left to right; later dicts overwrite earlier keys."""
        merged = {}
        for schema in dict_list:
            merged.update(schema)
        return merged
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class BlueprintGenerator:
    """Builds a Blueprint from user config, argparse results, and runtime params."""

    # maps blueprint param name -> config option name when they differ (currently none)
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        """Sanitize the inputs and assemble one DatasetBlueprint per dataset section.

        Resolution order for each field: per-dataset config, then [general]
        config, then argparse values, then runtime_params, then the dataclass
        default.
        """
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # drop argparse entries that are None so they don't mask later fallbacks
        argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # an image dataset is identified by the presence of image-specific keys
            is_image_dataset = "image_directory" in dataset_config or "image_jsonl_file" in dataset_config
            if is_image_dataset:
                dataset_params_klass = ImageDatasetParams
            else:
                dataset_params_klass = VideoDatasetParams

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        """Instantiate *param_klass*, resolving each field through *fallbacks* in order."""
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        """Return the first non-None value of *key* among *fallbacks*, else *default_value*."""
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# if training is True, it will return a dataset group for training, otherwise for caching
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
    """Instantiate every dataset described by the blueprint and wrap them in a DatasetGroup.

    Args:
        dataset_group_blueprint: blueprint carrying one DatasetBlueprint per dataset.
        training: when True, each dataset is additionally prepared for training
            after seeding; when False the group is intended for latent/text caching.

    Returns:
        DatasetGroup containing the instantiated Image/Video datasets.

    Raises:
        ValueError: if two datasets resolve to the same cache directory.
    """
    datasets: List[Union[ImageDataset, VideoDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        # choose the concrete dataset class from the blueprint flag
        if dataset_blueprint.is_image_dataset:
            dataset_klass = ImageDataset
        else:
            dataset_klass = VideoDataset

        # blueprint params are a dataclass; expand the fields as constructor kwargs
        dataset = dataset_klass(**asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # assertion: cache directories must be unique or cache files from
    # different datasets would overwrite each other
    cache_directories = [dataset.cache_directory for dataset in datasets]
    num_of_unique_cache_directories = len(set(cache_directories))
    if num_of_unique_cache_directories != len(cache_directories):
        raise ValueError(
            "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
            + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
        )

    # print info: one human-readable summary section per dataset
    info = ""
    for i, dataset in enumerate(datasets):
        is_image_dataset = isinstance(dataset, ImageDataset)
        info += dedent(
            f"""\
            [Dataset {i}]
              is_image_dataset: {is_image_dataset}
              resolution: {dataset.resolution}
              batch_size: {dataset.batch_size}
              num_repeats: {dataset.num_repeats}
              caption_extension: "{dataset.caption_extension}"
              enable_bucket: {dataset.enable_bucket}
              bucket_no_upscale: {dataset.bucket_no_upscale}
              cache_directory: "{dataset.cache_directory}"
              debug_dataset: {dataset.debug_dataset}
            """
        )

        # append the type-specific fields, indented under the common section
        if is_image_dataset:
            info += indent(
                dedent(
                    f"""\
                    image_directory: "{dataset.image_directory}"
                    image_jsonl_file: "{dataset.image_jsonl_file}"
                    control_directory: "{dataset.control_directory}"
                    \n"""
                ),
                "    ",
            )
        else:
            info += indent(
                dedent(
                    f"""\
                    video_directory: "{dataset.video_directory}"
                    video_jsonl_file: "{dataset.video_jsonl_file}"
                    control_directory: "{dataset.control_directory}"
                    target_frames: {dataset.target_frames}
                    frame_extraction: {dataset.frame_extraction}
                    frame_stride: {dataset.frame_stride}
                    frame_sample: {dataset.frame_sample}
                    max_frames: {dataset.max_frames}
                    source_fps: {dataset.source_fps}
                    \n"""
                ),
                "    ",
            )
    logger.info(f"{info}")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        # logger.info(f"[Dataset {i}]")
        dataset.set_seed(seed)
        if training:
            # NOTE(review): presumably builds buckets / shuffling state — confirm
            # against the dataset implementation in image_video_dataset.py
            dataset.prepare_for_training()

    return DatasetGroup(datasets)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def load_user_config(file: str) -> dict:
    """Load a dataset configuration from a JSON or TOML file.

    Args:
        file: path to the configuration file; the extension (case-insensitive)
            decides the parser.

    Returns:
        The parsed configuration as a dict.

    Raises:
        ValueError: if the file does not exist or has an unsupported extension.
        Exception: parse errors are logged and re-raised unchanged.
    """
    file: Path = Path(file)
    if not file.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {file}")

    lowered = file.name.lower()

    if lowered.endswith(".json"):
        try:
            with open(file, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            # log a bilingual hint before propagating the parser error
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise

    if lowered.endswith(".toml"):
        try:
            return toml.load(file)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {file}"
            )
            raise

    raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {file}")
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# for config test
if __name__ == "__main__":
    # first parser consumes only the positional config path; everything else
    # is passed through to the second parser via parse_known_args
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    parser = argparse.ArgumentParser()
    parser.add_argument("--debug_dataset", action="store_true")
    argparse_namespace = parser.parse_args(remain)

    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    # raw config as read from the JSON/TOML file
    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    # validate/normalize the raw config before generating the blueprint
    sanitizer = ConfigSanitizer()
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    # merge user config, CLI args, and defaults into a dataset blueprint
    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")

    # instantiate the datasets (caching mode: training flag defaults to False)
    dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
|
dataset/dataset_config.md
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
> 📝 Click on the language section to expand / 言語をクリックして展開
|
| 2 |
+
|
| 3 |
+
## Dataset Configuration
|
| 4 |
+
|
| 5 |
+
Please create a TOML file for dataset configuration.
|
| 6 |
+
|
| 7 |
+
Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
|
| 8 |
+
|
| 9 |
+
The cache directory must be different for each dataset.
|
| 10 |
+
|
| 11 |
+
Each video is extracted frame by frame without additional processing and used for training. It is recommended to use videos with a frame rate of 24fps for HunyuanVideo, 16fps for Wan2.1 and 30fps for FramePack. You can check the videos that will be trained using `--debug_mode video` when caching latent (see [here](/README.md#latent-caching)).
|
| 12 |
+
<details>
|
| 13 |
+
<summary>日本語</summary>
|
| 14 |
+
|
| 15 |
+
データセットの設定を行うためのTOMLファイルを作成してください。
|
| 16 |
+
|
| 17 |
+
画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
|
| 18 |
+
|
| 19 |
+
キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
|
| 20 |
+
|
| 21 |
+
動画は追加のプロセスなしでフレームごとに抽出され、学習に用いられます。そのため、HunyuanVideoは24fps、Wan2.1は16fps、FramePackは30fpsのフレームレートの動画を使用することをお勧めします。latentキャッシュ時の`--debug_mode video`を使用すると、学習される動画を確認できます([こちら](/README.ja.md#latentの事前キャッシュ)を参照)。
|
| 22 |
+
</details>
|
| 23 |
+
|
| 24 |
+
### Sample for Image Dataset with Caption Text Files
|
| 25 |
+
|
| 26 |
+
```toml
|
| 27 |
+
# resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
| 28 |
+
# otherwise, the default values will be used for each item
|
| 29 |
+
|
| 30 |
+
# general configurations
|
| 31 |
+
[general]
|
| 32 |
+
resolution = [960, 544]
|
| 33 |
+
caption_extension = ".txt"
|
| 34 |
+
batch_size = 1
|
| 35 |
+
enable_bucket = true
|
| 36 |
+
bucket_no_upscale = false
|
| 37 |
+
|
| 38 |
+
[[datasets]]
|
| 39 |
+
image_directory = "/path/to/image_dir"
|
| 40 |
+
cache_directory = "/path/to/cache_directory"
|
| 41 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
| 42 |
+
|
| 43 |
+
# other datasets can be added here. each dataset can have different configurations
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
`cache_directory` is optional, default is None to use the same directory as the image directory. However, we recommend to set the cache directory to avoid accidental sharing of the cache files between different datasets.
|
| 47 |
+
|
| 48 |
+
`num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.
|
| 49 |
+
|
| 50 |
+
<details>
|
| 51 |
+
<summary>日本語</summary>
|
| 52 |
+
|
| 53 |
+
`cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
|
| 54 |
+
|
| 55 |
+
`num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
|
| 56 |
+
|
| 57 |
+
resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
| 58 |
+
|
| 59 |
+
`[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
|
| 60 |
+
</details>
|
| 61 |
+
|
| 62 |
+
### Sample for Image Dataset with Metadata JSONL File
|
| 63 |
+
|
| 64 |
+
```toml
|
| 65 |
+
# resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
|
| 66 |
+
# caption_extension is not required for metadata jsonl file
|
| 67 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
| 68 |
+
|
| 69 |
+
# general configurations
|
| 70 |
+
[general]
|
| 71 |
+
resolution = [960, 544]
|
| 72 |
+
batch_size = 1
|
| 73 |
+
enable_bucket = true
|
| 74 |
+
bucket_no_upscale = false
|
| 75 |
+
|
| 76 |
+
[[datasets]]
|
| 77 |
+
image_jsonl_file = "/path/to/metadata.jsonl"
|
| 78 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
| 79 |
+
num_repeats = 1 # optional, default is 1. Same as above.
|
| 80 |
+
|
| 81 |
+
# other datasets can be added here. each dataset can have different configurations
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
JSONL file format for metadata:
|
| 85 |
+
|
| 86 |
+
```json
|
| 87 |
+
{"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
|
| 88 |
+
{"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
<details>
|
| 92 |
+
<summary>日本語</summary>
|
| 93 |
+
|
| 94 |
+
resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
|
| 95 |
+
|
| 96 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
|
| 97 |
+
|
| 98 |
+
キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
|
| 99 |
+
</details>
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
### Sample for Video Dataset with Caption Text Files
|
| 103 |
+
|
| 104 |
+
```toml
|
| 105 |
+
# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
|
| 106 |
+
# can be set in either general or datasets sections
|
| 107 |
+
# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
|
| 108 |
+
# must be set in each datasets section
|
| 109 |
+
|
| 110 |
+
# general configurations
|
| 111 |
+
[general]
|
| 112 |
+
resolution = [960, 544]
|
| 113 |
+
caption_extension = ".txt"
|
| 114 |
+
batch_size = 1
|
| 115 |
+
enable_bucket = true
|
| 116 |
+
bucket_no_upscale = false
|
| 117 |
+
|
| 118 |
+
[[datasets]]
|
| 119 |
+
video_directory = "/path/to/video_dir"
|
| 120 |
+
cache_directory = "/path/to/cache_directory" # recommended to set cache directory
|
| 121 |
+
target_frames = [1, 25, 45]
|
| 122 |
+
frame_extraction = "head"
|
| 123 |
+
source_fps = 30.0 # optional, source fps for videos in the directory, decimal number
|
| 124 |
+
|
| 125 |
+
[[datasets]]
|
| 126 |
+
video_directory = "/path/to/video_dir2"
|
| 127 |
+
cache_directory = "/path/to/cache_directory2" # recommended to set cache directory
|
| 128 |
+
frame_extraction = "full"
|
| 129 |
+
max_frames = 45
|
| 130 |
+
|
| 131 |
+
# other datasets can be added here. each dataset can have different configurations
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
__In HunyuanVideo and Wan2.1, each value in `target_frames` must be "N\*4+1" (N=0,1,2,...).__ Any other value will be truncated to the nearest "N\*4+1".
|
| 135 |
+
|
| 136 |
+
In FramePack, it is recommended to set `frame_extraction` to `full` and `max_frames` to a sufficiently large value, as it can handle longer videos. However, if the video is too long, an Out of Memory error may occur during VAE encoding. The videos in FramePack are trimmed to "N * latent_window_size * 4 + 1" frames (for example, 37, 73, 109... if `latent_window_size` is 9).
|
| 137 |
+
|
| 138 |
+
If the `source_fps` is specified, the videos in the directory are considered to be at this frame rate, and some frames will be skipped to match the model's frame rate (24 for HunyuanVideo and 16 for Wan2.1). __The value must be a decimal number, for example, `30.0` instead of `30`.__ The skipping is done automatically and does not consider the content of the images. Please check if the converted data is correct using `--debug_mode video`.
|
| 139 |
+
|
| 140 |
+
If `source_fps` is not specified (default), all frames of the video will be used regardless of the video's frame rate.
|
| 141 |
+
|
| 142 |
+
<details>
|
| 143 |
+
<summary>日本語</summary>
|
| 144 |
+
|
| 145 |
+
共通パラメータ(resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)は、generalまたはdatasetsのいずれかに設定できます。
|
| 146 |
+
動画固有のパラメータ(target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)は、各datasetsセクションに設定する必要があります。
|
| 147 |
+
|
| 148 |
+
__HunyuanVideoおよびWan2.1では、target_framesの数値は「N\*4+1」である必要があります。__ これ以外の値の場合は、最も近いN\*4+1の値に切り捨てられます。
|
| 149 |
+
|
| 150 |
+
FramePackでも同様ですが、FramePackでは動画が長くても学習可能なため、 `frame_extraction`に`full` を指定し、`max_frames`を十分に大きな値に設定することをお勧めします。ただし、あまりにも長すぎるとVAEのencodeでOut of Memoryエラーが発生する可能性があります。FramePackの動画は、「N * latent_window_size * 4 + 1」フレームにトリミングされます(latent_window_sizeが9の場合、37、73、109……)。
|
| 151 |
+
|
| 152 |
+
`source_fps`を指定した場合、ディレクトリ内の動画をこのフレームレートとみなして、モデルのフレームレートにあうようにいくつかのフレームをスキップします(HunyuanVideoは24、Wan2.1は16)。__小数点を含む数値で指定してください。__ 例:`30`ではなく`30.0`。スキップは機械的に行われ、画像の内容は考慮しません。変換後のデータが正しいか、`--debug_mode video`で確認してください。
|
| 153 |
+
|
| 154 |
+
`source_fps`を指定しない場合、動画のフレームは(動画自体のフレームレートに関係なく)すべて使用されます。
|
| 155 |
+
|
| 156 |
+
他の注意事項は画像データセットと同様です。
|
| 157 |
+
</details>
|
| 158 |
+
|
| 159 |
+
### Sample for Video Dataset with Metadata JSONL File
|
| 160 |
+
|
| 161 |
+
```toml
|
| 162 |
+
# Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
|
| 163 |
+
# can be set in either general or datasets sections
|
| 164 |
+
# Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
|
| 165 |
+
# must be set in each datasets section
|
| 166 |
+
|
| 167 |
+
# caption_extension is not required for metadata jsonl file
|
| 168 |
+
# cache_directory is required for each dataset with metadata jsonl file
|
| 169 |
+
|
| 170 |
+
# general configurations
|
| 171 |
+
[general]
|
| 172 |
+
resolution = [960, 544]
|
| 173 |
+
batch_size = 1
|
| 174 |
+
enable_bucket = true
|
| 175 |
+
bucket_no_upscale = false
|
| 176 |
+
|
| 177 |
+
[[datasets]]
|
| 178 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
| 179 |
+
target_frames = [1, 25, 45]
|
| 180 |
+
frame_extraction = "head"
|
| 181 |
+
cache_directory = "/path/to/cache_directory_head"
|
| 182 |
+
source_fps = 30.0 # optional, source fps for videos in the jsonl file
|
| 183 |
+
# same metadata jsonl file can be used for multiple datasets
|
| 184 |
+
[[datasets]]
|
| 185 |
+
video_jsonl_file = "/path/to/metadata.jsonl"
|
| 186 |
+
target_frames = [1]
|
| 187 |
+
frame_stride = 10
|
| 188 |
+
cache_directory = "/path/to/cache_directory_stride"
|
| 189 |
+
|
| 190 |
+
# other datasets can be added here. each dataset can have different configurations
|
| 191 |
+
```
|
| 192 |
+
|
| 193 |
+
JSONL file format for metadata:
|
| 194 |
+
|
| 195 |
+
```json
|
| 196 |
+
{"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
|
| 197 |
+
{"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
|
| 198 |
+
```
|
| 199 |
+
|
| 200 |
+
`video_path` can be a directory containing multiple images.
|
| 201 |
+
|
| 202 |
+
<details>
|
| 203 |
+
<summary>日本語</summary>
|
| 204 |
+
metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
|
| 205 |
+
|
| 206 |
+
`video_path`は、複数の画像を含むディレクトリのパスでも構いません。
|
| 207 |
+
|
| 208 |
+
他の注意事項は今までのデータセットと同様です。
|
| 209 |
+
</details>
|
| 210 |
+
|
| 211 |
+
### frame_extraction Options
|
| 212 |
+
|
| 213 |
+
- `head`: Extract the first N frames from the video.
|
| 214 |
+
- `chunk`: Extract frames by splitting the video into chunks of N frames.
|
| 215 |
+
- `slide`: Extract frames from the video with a stride of `frame_stride`.
|
| 216 |
+
- `uniform`: Extract `frame_sample` samples uniformly from the video.
|
| 217 |
+
- `full`: Extract all frames from the video.
|
| 218 |
+
|
| 219 |
+
In the case of `full`, the entire video is used, but it is trimmed to "N*4+1" frames. It is also trimmed to the `max_frames` if it exceeds that value. To avoid Out of Memory errors, please set `max_frames`.
|
| 220 |
+
|
| 221 |
+
The frame extraction methods other than `full` are recommended when the video contains repeated actions. `full` is recommended when each video represents a single complete motion.
|
| 222 |
+
|
| 223 |
+
For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
|
| 224 |
+
|
| 225 |
+
<details>
|
| 226 |
+
<summary>日本語</summary>
|
| 227 |
+
|
| 228 |
+
- `head`: 動画から最初のNフレームを抽出します。
|
| 229 |
+
- `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
|
| 230 |
+
- `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
|
| 231 |
+
- `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
|
| 232 |
+
- `full`: 動画から全てのフレームを抽出します。
|
| 233 |
+
|
| 234 |
+
`full`の場合、各動画の全体を用いますが、「N*4+1」のフレーム数にトリミングされます。また`max_frames`を超える場合もその値にトリミングされます。Out of Memoryエラーを避けるために、`max_frames`を設定してください。
|
| 235 |
+
|
| 236 |
+
`full`以外の抽出方法は、動画が特定の動作を繰り返している場合にお勧めします。`full`はそれぞれの動画がひとつの完結したモーションの場合にお勧めします。
|
| 237 |
+
|
| 238 |
+
例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
|
| 239 |
+
</details>
|
| 240 |
+
|
| 241 |
+
```
|
| 242 |
+
Original Video, 40 frames: x = frame, o = no frame
|
| 243 |
+
oooooooooooooooooooooooooooooooooooooooo
|
| 244 |
+
|
| 245 |
+
head, target_frames = [1, 13, 25] -> extract head frames:
|
| 246 |
+
xooooooooooooooooooooooooooooooooooooooo
|
| 247 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
| 248 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
| 249 |
+
|
| 250 |
+
chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
|
| 251 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
| 252 |
+
oooooooooooooxxxxxxxxxxxxxoooooooooooooo
|
| 253 |
+
ooooooooooooooooooooooooooxxxxxxxxxxxxxo
|
| 254 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
| 255 |
+
|
| 256 |
+
NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It would cause every frame to be extracted.
|
| 257 |
+
注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
|
| 258 |
+
|
| 259 |
+
slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
|
| 260 |
+
xooooooooooooooooooooooooooooooooooooooo
|
| 261 |
+
ooooooooooxooooooooooooooooooooooooooooo
|
| 262 |
+
ooooooooooooooooooooxooooooooooooooooooo
|
| 263 |
+
ooooooooooooooooooooooooooooooxooooooooo
|
| 264 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
| 265 |
+
ooooooooooxxxxxxxxxxxxxooooooooooooooooo
|
| 266 |
+
ooooooooooooooooooooxxxxxxxxxxxxxooooooo
|
| 267 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
| 268 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
| 269 |
+
|
| 270 |
+
uniform, target_frames =[1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
|
| 271 |
+
xooooooooooooooooooooooooooooooooooooooo
|
| 272 |
+
oooooooooooooxoooooooooooooooooooooooooo
|
| 273 |
+
oooooooooooooooooooooooooxoooooooooooooo
|
| 274 |
+
ooooooooooooooooooooooooooooooooooooooox
|
| 275 |
+
xxxxxxxxxxxxxooooooooooooooooooooooooooo
|
| 276 |
+
oooooooooxxxxxxxxxxxxxoooooooooooooooooo
|
| 277 |
+
ooooooooooooooooooxxxxxxxxxxxxxooooooooo
|
| 278 |
+
oooooooooooooooooooooooooooxxxxxxxxxxxxx
|
| 279 |
+
xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
|
| 280 |
+
oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
|
| 281 |
+
ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
|
| 282 |
+
oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
|
| 283 |
+
|
| 284 |
+
Three Original Videos, 20, 25, 35 frames: x = frame, o = no frame
|
| 285 |
+
|
| 286 |
+
full, max_frames = 31 -> extract all frames (trimmed to the maximum length):
|
| 287 |
+
video1: xxxxxxxxxxxxxxxxx (trimmed to 17 frames)
|
| 288 |
+
video2: xxxxxxxxxxxxxxxxxxxxxxxxx (25 frames)
|
| 289 |
+
video3: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx (trimmed to 31 frames)
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
### Sample for Image Dataset with Control Images
|
| 293 |
+
|
| 294 |
+
The dataset with control images is used for training the single frame training for FramePack.
|
| 295 |
+
|
| 296 |
+
The dataset configuration with caption text files is similar to the image dataset, but with an additional `control_directory` parameter.
|
| 297 |
+
|
| 298 |
+
The control images are used from the `control_directory` with the same filename (or different extension) as the image, for example, `image_dir/image1.jpg` and `control_dir/image1.png`. The images in `image_directory` should be the target images (the images to be generated during inference, the changed images). The `control_directory` should contain the starting images for inference. The captions should be stored in `image_directory`.
|
| 299 |
+
|
| 300 |
+
The metadata JSONL file format is the same as the image dataset, but with an additional `control_path` parameter.
|
| 301 |
+
|
| 302 |
+
```json
|
| 303 |
+
{"image_path": "/path/to/image1.jpg", "control_path": "/path/to/control1.png", "caption": "A caption for image1"}
|
| 304 |
+
{"image_path": "/path/to/image2.jpg", "control_path": "/path/to/control2.png", "caption": "A caption for image2"}
|
| 305 |
+
```
|
| 306 |
+
|
| 307 |
+
<details>
|
| 308 |
+
<summary>日本語</summary>
|
| 309 |
+
制御画像を持つデータセットです。FramePackの単一フレーム学習に使用します。
|
| 310 |
+
|
| 311 |
+
キャプションファイルを用いる場合は`control_directory`を追加で指定してください。制御用画像は、画像と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある画像が使用されます(例:`image_dir/image1.jpg`と`control_dir/image1.png`)。`image_directory`の画像は学習対象の画像(推論時に生成する画像、変化後の画像)としてください。`control_directory`には推論時の開始画像を格納してください。キャプションは`image_directory`へ格納してください。
|
| 312 |
+
|
| 313 |
+
メタデータJSONLファイルを使用する場合は、`control_path`を追加してください。
|
| 314 |
+
</details>
|
| 315 |
+
|
| 316 |
+
### Sample for Video Dataset with Control Images
|
| 317 |
+
|
| 318 |
+
The dataset with control videos is used for training ControlNet models.
|
| 319 |
+
|
| 320 |
+
The dataset configuration with caption text files is similar to the video dataset, but with an additional `control_directory` parameter.
|
| 321 |
+
|
| 322 |
+
The control video for a video is used from the `control_directory` with the same filename (or different extension) as the video, for example, `video_dir/video1.mp4` and `control_dir/video1.mp4` or `control_dir/video1.mov`. The control video can also be a directory without an extension, for example, `video_dir/video1.mp4` and `control_dir/video1`.
|
| 323 |
+
|
| 324 |
+
```toml
|
| 325 |
+
[[datasets]]
|
| 326 |
+
video_directory = "/path/to/video_dir"
|
| 327 |
+
control_directory = "/path/to/control_dir" # required for dataset with control videos
|
| 328 |
+
cache_directory = "/path/to/cache_directory" # recommended to set cache directory
|
| 329 |
+
target_frames = [1, 25, 45]
|
| 330 |
+
frame_extraction = "head"
|
| 331 |
+
```
|
| 332 |
+
|
| 333 |
+
The dataset configuration with metadata JSONL file is same as the video dataset, but metadata JSONL file must include the control video paths. The control video path can be a directory containing multiple images.
|
| 334 |
+
|
| 335 |
+
```json
|
| 336 |
+
{"video_path": "/path/to/video1.mp4", "control_path": "/path/to/control1.mp4", "caption": "A caption for video1"}
|
| 337 |
+
{"video_path": "/path/to/video2.mp4", "control_path": "/path/to/control2.mp4", "caption": "A caption for video2"}
|
| 338 |
+
```
|
| 339 |
+
|
| 340 |
+
<details>
|
| 341 |
+
<summary>日本語</summary>
|
| 342 |
+
制御動画を持つデータセットです。ControlNetモデルの学習に使用します。
|
| 343 |
+
|
| 344 |
+
キャプションを用いる場合のデータセット設定は動画データセットと似ていますが、`control_directory`パラメータが追加されています。上にある例を参照してください。ある動画に対する制御用動画として、動画と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある動画が使用されます(例:`video_dir/video1.mp4`と`control_dir/video1.mp4`または`control_dir/video1.mov`)。また、拡張子なしのディレクトリ内の、複数枚の画像を制御用動画として使用することもできます(例:`video_dir/video1.mp4`と`control_dir/video1`)。
|
| 345 |
+
|
| 346 |
+
データセット設定でメタデータJSONLファイルを使用する場合は、動画と制御用動画のパスを含める必要があります。制御用動画のパスは、複数枚の画像を含むディレクトリのパスでも構いません。
|
| 347 |
+
</details>
|
| 348 |
+
|
| 349 |
+
## Specifications
|
| 350 |
+
|
| 351 |
+
```toml
|
| 352 |
+
# general configurations
|
| 353 |
+
[general]
|
| 354 |
+
resolution = [960, 544] # optional, [W, H], default is [960, 544]. This is the default resolution for all datasets
|
| 355 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
| 356 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
| 357 |
+
num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
|
| 358 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
| 359 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
| 360 |
+
|
| 361 |
+
### Image Dataset
|
| 362 |
+
|
| 363 |
+
# sample image dataset with caption text files
|
| 364 |
+
[[datasets]]
|
| 365 |
+
image_directory = "/path/to/image_dir"
|
| 366 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
| 367 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 368 |
+
batch_size = 4 # optional, overwrite the default batch size
|
| 369 |
+
num_repeats = 1 # optional, overwrite the default num_repeats
|
| 370 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
| 371 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
| 372 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
| 373 |
+
control_directory = "/path/to/control_dir" # optional, required for dataset with control images
|
| 374 |
+
|
| 375 |
+
# sample image dataset with metadata **jsonl** file
|
| 376 |
+
[[datasets]]
|
| 377 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
| 378 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 379 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
| 380 |
+
# caption_extension is not required for metadata jsonl file
|
| 381 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
| 382 |
+
|
| 383 |
+
### Video Dataset
|
| 384 |
+
|
| 385 |
+
# sample video dataset with caption text files
|
| 386 |
+
[[datasets]]
|
| 387 |
+
video_directory = "/path/to/video_dir"
|
| 388 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
| 389 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 390 |
+
|
| 391 |
+
control_directory = "/path/to/control_dir" # optional, required for dataset with control images
|
| 392 |
+
|
| 393 |
+
# following configurations must be set in each [[datasets]] section for video datasets
|
| 394 |
+
|
| 395 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
| 396 |
+
|
| 397 |
+
# NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It would cause every frame to be extracted.
|
| 398 |
+
|
| 399 |
+
frame_extraction = "head" # optional, one of "head", "chunk", "slide", "uniform", "full". Default is "head"
|
| 400 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
| 401 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
| 402 |
+
max_frames = 129 # optional, default is 129. Maximum number of frames to extract, available for "full" frame extraction
|
| 403 |
+
# batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
| 404 |
+
|
| 405 |
+
# sample video dataset with metadata jsonl file
|
| 406 |
+
[[datasets]]
|
| 407 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
| 408 |
+
|
| 409 |
+
target_frames = [1, 79]
|
| 410 |
+
|
| 411 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
| 412 |
+
# frame_extraction, frame_stride, frame_sample, max_frames are also available for metadata jsonl file
|
| 413 |
+
```
|
| 414 |
+
|
| 415 |
+
<!--
|
| 416 |
+
# sample image dataset with lance
|
| 417 |
+
[[datasets]]
|
| 418 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
| 419 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 420 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
| 421 |
+
-->
|
| 422 |
+
|
| 423 |
+
Metadata in a .json file will be supported in the near future.
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
<!--
|
| 428 |
+
|
| 429 |
+
```toml
|
| 430 |
+
# general configurations
|
| 431 |
+
[general]
|
| 432 |
+
resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
|
| 433 |
+
caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
|
| 434 |
+
batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
|
| 435 |
+
enable_bucket = true # optional, default is false. Enable bucketing for datasets
|
| 436 |
+
bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
|
| 437 |
+
|
| 438 |
+
# sample image dataset with caption text files
|
| 439 |
+
[[datasets]]
|
| 440 |
+
image_directory = "/path/to/image_dir"
|
| 441 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
| 442 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 443 |
+
batch_size = 4 # optional, overwrite the default batch size
|
| 444 |
+
enable_bucket = false # optional, overwrite the default bucketing setting
|
| 445 |
+
bucket_no_upscale = true # optional, overwrite the default bucketing setting
|
| 446 |
+
cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
|
| 447 |
+
|
| 448 |
+
# sample image dataset with metadata **jsonl** file
|
| 449 |
+
[[datasets]]
|
| 450 |
+
image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
|
| 451 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 452 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
| 453 |
+
# caption_extension is not required for metadata jsonl file
|
| 454 |
+
# batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
|
| 455 |
+
|
| 456 |
+
# sample video dataset with caption text files
|
| 457 |
+
[[datasets]]
|
| 458 |
+
video_directory = "/path/to/video_dir"
|
| 459 |
+
caption_extension = ".txt" # required for caption text files, if general caption extension is not set
|
| 460 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 461 |
+
target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
|
| 462 |
+
frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
|
| 463 |
+
frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
|
| 464 |
+
frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
|
| 465 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
|
| 466 |
+
|
| 467 |
+
# sample video dataset with metadata jsonl file
|
| 468 |
+
[[datasets]]
|
| 469 |
+
video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
|
| 470 |
+
target_frames = [1, 79]
|
| 471 |
+
cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
|
| 472 |
+
# frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
|
| 473 |
+
```
|
| 474 |
+
|
| 475 |
+
# sample image dataset with lance
|
| 476 |
+
[[datasets]]
|
| 477 |
+
image_lance_dataset = "/path/to/lance_dataset"
|
| 478 |
+
resolution = [960, 544] # required if general resolution is not set
|
| 479 |
+
# batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
|
| 480 |
+
|
| 481 |
+
The metadata with .json file will be supported in the near future.
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
-->
|
dataset/image_video_dataset.py
ADDED
|
@@ -0,0 +1,1786 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 2 |
+
import glob
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import time
|
| 8 |
+
from typing import Optional, Sequence, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
from safetensors.torch import save_file, load_file
|
| 13 |
+
from safetensors import safe_open
|
| 14 |
+
from PIL import Image
|
| 15 |
+
import cv2
|
| 16 |
+
import av
|
| 17 |
+
|
| 18 |
+
from utils import safetensors_utils
|
| 19 |
+
from utils.model_utils import dtype_to_str
|
| 20 |
+
|
| 21 |
+
import logging
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
logging.basicConfig(level=logging.INFO)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Supported image file extensions. Both lower- and upper-case variants are
# listed because glob matching is case-sensitive on most filesystems.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional AVIF support via the pillow-avif plugin.
try:
    import pillow_avif

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except Exception:  # was a bare `except:` — never swallow SystemExit/KeyboardInterrupt
    pass

# JPEG-XL on Linux
try:
    from jxlpy import JXLImagePlugin

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:  # plugin is optional; keep best-effort behavior
    pass

# JPEG-XL on Windows
# NOTE: if both jxlpy and pillow_jxl are installed, ".jxl" appears twice;
# the glob helpers de-duplicate, so this is harmless.
try:
    import pillow_jxl

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except Exception:  # plugin is optional; keep best-effort behavior
    pass

# Supported video container extensions (lower- and upper-case).
VIDEO_EXTENSIONS = [
    ".mp4",
    ".webm",
    ".avi",
    ".mkv",
    ".mov",
    ".flv",
    ".wmv",
    ".m4v",
    ".mpg",
    ".mpeg",
    ".MP4",
    ".WEBM",
    ".AVI",
    ".MKV",
    ".MOV",
    ".FLV",
    ".WMV",
    ".M4V",
    ".MPG",
    ".MPEG",
]  # some of them are not tested

# Architecture identifiers: short form (used in keys/flags) and full form
# (stored in cache-file metadata).
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
ARCHITECTURE_WAN = "wan"
ARCHITECTURE_WAN_FULL = "wan"
ARCHITECTURE_FRAMEPACK = "fp"
ARCHITECTURE_FRAMEPACK_FULL = "framepack"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def glob_images(directory, base="*"):
    """Collect image files under *directory* whose stem matches *base*.

    One glob per known image extension; duplicates are removed and the
    result is returned sorted.
    """
    found = []
    for extension in IMAGE_EXTENSIONS:
        if base == "*":
            # wildcard stem: escape only the directory so `*` still globs
            pattern = os.path.join(glob.escape(directory), base + extension)
        else:
            # literal stem: escape the full path so it matches the exact name
            pattern = glob.escape(os.path.join(directory, base + extension))
        found.extend(glob.glob(pattern))
    return sorted(set(found))
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def glob_videos(directory, base="*"):
    """Collect video files under *directory* whose stem matches *base*.

    One glob per known video extension; duplicates are removed and the
    result is returned sorted.
    """
    found = []
    for extension in VIDEO_EXTENSIONS:
        if base == "*":
            # wildcard stem: escape only the directory so `*` still globs
            pattern = os.path.join(glob.escape(directory), base + extension)
        else:
            # literal stem: escape the full path so it matches the exact name
            pattern = glob.escape(os.path.join(directory, base + extension))
        found.extend(glob.glob(pattern))
    return sorted(set(found))
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def divisible_by(num: int, divisor: int) -> int:
    """Round *num* down to the nearest multiple of *divisor* (floor semantics)."""
    return (num // divisor) * divisor
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """
    Resize the image to the bucket resolution, center-cropping the overflow.

    bucket_reso: **(width, height)**
    """
    from_pil = isinstance(image, Image.Image)
    if from_pil:
        width, height = image.size
    else:
        height, width = image.shape[:2]

    # exact match: no resizing or cropping needed, just ensure ndarray output
    if bucket_reso == (width, height):
        return np.array(image) if from_pil else image

    bucket_width, bucket_height = bucket_reso

    if bucket_width == width or bucket_height == height:
        # one side already matches the bucket: cropping alone is enough
        image = np.array(image) if from_pil else image
    else:
        # scale so the short side matches the bucket, then crop the long side
        scale = max(bucket_width / width, bucket_height / height)
        width = int(width * scale + 0.5)
        height = int(height * scale + 0.5)

        if scale > 1:
            # upscaling: PIL LANCZOS
            pil_image = image if from_pil else Image.fromarray(image)
            image = np.array(pil_image.resize((width, height), Image.LANCZOS))
        else:
            # downscaling: cv2 INTER_AREA
            source = np.array(image) if from_pil else image
            image = cv2.resize(source, (width, height), interpolation=cv2.INTER_AREA)

    # center-crop to the bucket resolution
    crop_left = (width - bucket_width) // 2
    crop_top = (height - bucket_height) // 2
    return image[crop_top : crop_top + bucket_height, crop_left : crop_left + bucket_width]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class ItemInfo:
    """Per-item record for a dataset entry: key, caption, sizes, optional
    in-memory content and cache-file locations."""

    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        self.item_key = item_key
        self.caption = caption
        self.original_size = original_size  # (width, height)
        self.bucket_size = bucket_size  # 2- or 3-tuple per the annotation
        self.frame_count = frame_count  # None for still images
        self.content = content  # decoded pixel data, if kept in memory
        self.latent_cache_path = latent_cache_path
        # Filled in later by the caching pipeline, not by the constructor.
        self.text_encoder_output_cache_path: Optional[str] = None
        self.control_content: Optional[np.ndarray] = None

    def __str__(self) -> str:
        content_shape = self.content.shape if self.content is not None else None
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            + f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            + f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path}, content={content_shape})"
        )
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
# We use simple if-else approach to support multiple architectures.
|
| 182 |
+
# Maybe we can use a plugin system in the future.
|
| 183 |
+
|
| 184 |
+
# the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
|
| 185 |
+
# and `<content_type>_<dtype|mask>` for other tensors
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    """HunyuanVideo architecture only. HunyuanVideo doesn't support I2V and control latents"""
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    _, F, H, W = latent.shape
    key = f"latents_{F}x{H}x{W}_{dtype_to_str(latent.dtype)}"
    save_latent_cache_common(item_info, {key: latent.detach().cpu()}, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def save_latent_cache_wan(
    item_info: ItemInfo,
    latent: torch.Tensor,
    clip_embed: Optional[torch.Tensor],
    image_latent: Optional[torch.Tensor],
    control_latent: Optional[torch.Tensor],
):
    """Wan architecture only"""
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    # Optional tensors get dedicated keys; absent ones are simply omitted.
    optional_entries = (
        (f"clip_{dtype_str}", clip_embed),
        (f"latents_image_{F}x{H}x{W}_{dtype_str}", image_latent),
        (f"latents_control_{F}x{H}x{W}_{dtype_str}", control_latent),
    )
    for key, tensor in optional_entries:
        if tensor is not None:
            sd[key] = tensor.detach().cpu()

    save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def save_latent_cache_framepack(
    item_info: ItemInfo,
    latent: torch.Tensor,
    latent_indices: torch.Tensor,
    clean_latents: torch.Tensor,
    clean_latent_indices: torch.Tensor,
    clean_latents_2x: torch.Tensor,
    clean_latent_2x_indices: torch.Tensor,
    clean_latents_4x: torch.Tensor,
    clean_latent_4x_indices: torch.Tensor,
    image_embeddings: torch.Tensor,
):
    """FramePack architecture only"""
    assert latent.dim() == 4, "latent should be 4D tensor (frame, channel, height, width)"

    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    indices_dtype_str = dtype_to_str(latent_indices.dtype)

    # `latents_xxx` keys must carry the {F}x{H}x{W} suffix.
    # Key insertion order matches the original implementation.
    sd = {
        f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu().contiguous(),
        # image embeddings dtype is same as latents dtype
        f"image_embeddings_{dtype_str}": image_embeddings.detach().cpu(),
        f"latent_indices_{indices_dtype_str}": latent_indices.detach().cpu(),
        f"clean_latent_indices_{indices_dtype_str}": clean_latent_indices.detach().cpu(),
        f"clean_latent_2x_indices_{indices_dtype_str}": clean_latent_2x_indices.detach().cpu(),
        f"clean_latent_4x_indices_{indices_dtype_str}": clean_latent_4x_indices.detach().cpu(),
        f"latents_clean_{F}x{H}x{W}_{dtype_str}": clean_latents.detach().cpu().contiguous(),
        f"latents_clean_2x_{F}x{H}x{W}_{dtype_str}": clean_latents_2x.detach().cpu().contiguous(),
        f"latents_clean_4x_{F}x{H}x{W}_{dtype_str}": clean_latents_4x.detach().cpu().contiguous(),
    }

    save_latent_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Sanitize tensors, build file metadata and write the latent cache for any architecture."""
    metadata = {
        "architecture": arch_fullname,
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.1",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    # NaN check and show warning, replace NaN with 0
    for key, value in sd.items():
        nan_mask = torch.isnan(value)
        if nan_mask.any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[nan_mask] = 0

    os.makedirs(os.path.dirname(item_info.latent_cache_path), exist_ok=True)
    save_file(sd, item_info.latent_cache_path, metadata=metadata)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    """HunyuanVideo architecture only"""
    assert (
        embed.dim() == 1 or embed.dim() == 2
    ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"

    # Key prefix distinguishes the two HunyuanVideo text encoders.
    text_encoder_type = "llm" if is_llm else "clipL"
    sd = {f"{text_encoder_type}_{dtype_to_str(embed.dtype)}": embed.detach().cpu()}
    if mask is not None:
        sd[f"{text_encoder_type}_mask"] = mask.detach().cpu()

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
    """Wan architecture only. Wan2.1 only has a single text encoder"""
    key = f"varlen_t5_{dtype_to_str(embed.dtype)}"
    save_text_encoder_output_cache_common(item_info, {key: embed.detach().cpu()}, ARCHITECTURE_WAN_FULL)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def save_text_encoder_output_cache_framepack(
    item_info: ItemInfo, llama_vec: torch.Tensor, llama_attention_mask: torch.Tensor, clip_l_pooler: torch.Tensor
):
    """FramePack architecture only.

    Caches the LLaMA text embedding plus its attention mask and the CLIP-L
    pooled embedding for one item; dtype is encoded in the tensor keys.
    """
    sd = {}
    dtype_str = dtype_to_str(llama_vec.dtype)
    sd[f"llama_vec_{dtype_str}"] = llama_vec.detach().cpu()
    # fix: was `f"llama_attention_mask"` — an f-string with no placeholders (same key value)
    sd["llama_attention_mask"] = llama_attention_mask.detach().cpu()
    dtype_str = dtype_to_str(clip_l_pooler.dtype)
    sd[f"clip_l_pooler_{dtype_str}"] = clip_l_pooler.detach().cpu()

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Write (or merge into) the text-encoder output cache file for one item.

    If the cache file already exists, tensors absent from ``sd`` are carried
    over from it (new tensors win on key collision), and its metadata (minus
    caption1/format_version) is merged into the new metadata. Tensors are
    NaN-sanitized in place before writing.
    """
    for key, value in sd.items():
        # NaN check and show warning, replace NaN with 0
        if torch.isnan(value).any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[torch.isnan(value)] = 0

    metadata = {
        "architecture": arch_fullname,
        "caption1": item_info.caption,
        "format_version": "1.0.1",
    }

    if os.path.exists(item_info.text_encoder_output_cache_path):
        # load existing cache and update metadata
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                if key not in sd:  # avoid overwriting by existing cache, we keep the new one
                    sd[key] = f.get_tensor(key)

        # An existing cache for a different architecture is a hard error;
        # a differing caption is only warned about (the new caption wins).
        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
        # TODO verify format_version

        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)  # copy existing metadata except caption and format_version
    else:
        # first write for this item: make sure the target directory exists
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
class BucketSelector:
    """Computes the set of bucket resolutions for a target area and maps image sizes to the nearest bucket."""

    RESOLUTION_STEPS_HUNYUAN = 16
    RESOLUTION_STEPS_WAN = 16
    RESOLUTION_STEPS_FRAMEPACK = 16

    def __init__(
        self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
    ):
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.architecture = architecture

        # resolution step per architecture (currently all 16, kept separate per architecture)
        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            steps = BucketSelector.RESOLUTION_STEPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            steps = BucketSelector.RESOLUTION_STEPS_FRAMEPACK
        else:
            raise ValueError(f"Invalid architecture: {self.architecture}")
        self.reso_steps = steps

        if not enable_bucket:
            # bucketing disabled: a single bucket at the requested resolution
            self.bucket_resolutions = [resolution]
            self.no_upscale = False
        else:
            # enumerate (w, h) pairs of roughly constant area, snapped to the step grid,
            # plus their transposes, deduplicated and sorted
            self.no_upscale = no_upscale
            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, steps)

            resolutions = set()
            for w in range(min_size, sqrt_size + steps, steps):
                h = divisible_by(self.bucket_area // w, steps)
                resolutions.add((w, h))
                resolutions.add((h, w))
            self.bucket_resolutions = sorted(resolutions)

        # aspect ratio (w/h) per bucket, used for nearest-bucket lookup
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        area = image_size[0] * image_size[1]
        if self.no_upscale and area <= self.bucket_area:
            # already small enough: just snap both sides down to the step grid
            w, h = image_size
            return divisible_by(w, self.reso_steps), divisible_by(h, self.reso_steps)

        aspect_ratio = image_size[0] / image_size[1]
        bucket_id = np.abs(self.aspect_ratios - aspect_ratio).argmin()
        return self.bucket_resolutions[bucket_id]
def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
    bucket_reso: Optional[tuple[int, int]] = None,
    source_fps: Optional[float] = None,
    target_fps: Optional[float] = None,
) -> list[np.ndarray]:
    """
    Load frames from a video file or a directory of image frames as a list of RGB numpy arrays.

    video_path: path to a video file, or a directory containing image frames
    start_frame / end_frame: half-open frame range [start_frame, end_frame) to keep; None means unbounded
    bucket_selector: if given (and bucket_reso is not), the bucket resolution is computed from the first frame
    bucket_reso: if given, resize the video to the bucket resolution, (width, height)
    source_fps / target_fps: if both are given, frames are dropped to convert source_fps to target_fps

    This unifies the previously duplicated file/directory and fps/no-fps branches (the old TODO):
    the no-conversion case is the fps case with a frame-index delta of exactly 1.0, which never
    drops a frame and makes the effective index equal to the source index.
    """
    is_video_file = os.path.isfile(video_path)

    # effective index advances by this per source frame; when it truncates to the same
    # integer as the previously kept frame, the frame is dropped (fps downsampling)
    if source_fps is None or target_fps is None:
        frame_index_delta = 1.0  # keep every frame
    else:
        frame_index_delta = target_fps / source_fps  # example: 16 / 30 = 0.5333

    def iter_frame_loaders():
        # Yield zero-argument loaders so that skipped/dropped frames are never decoded or opened,
        # matching the original code's behavior of filtering before to_image()/Image.open().
        if is_video_file:
            container = av.open(video_path)
            try:
                for av_frame in container.decode(video=0):
                    yield lambda f=av_frame: f.to_image()
            finally:
                container.close()
        else:
            image_files = glob_images(video_path)
            image_files.sort()
            for image_file in image_files:
                yield lambda p=image_file: Image.open(p).convert("RGB")

    video = []
    frame_index_with_fraction = 0.0
    previous_frame_index = -1
    loaders = iter_frame_loaders()
    try:
        for loader in loaders:
            target_frame_index = int(frame_index_with_fraction)
            frame_index_with_fraction += frame_index_delta

            if target_frame_index == previous_frame_index:  # drop this frame
                continue
            # accept this frame
            previous_frame_index = target_frame_index

            if start_frame is not None and target_frame_index < start_frame:
                continue
            if end_frame is not None and target_frame_index >= end_frame:
                break

            frame = loader()  # PIL.Image in RGB

            if bucket_selector is not None and bucket_reso is None:
                bucket_reso = bucket_selector.get_bucket_resolution(frame.size)  # calc resolution from first frame

            # preserve the original per-source ordering of PIL->numpy conversion vs. resize
            if is_video_file:
                frame = resize_image_to_bucket(frame, bucket_reso) if bucket_reso is not None else np.array(frame)
            else:
                frame = np.array(frame)
                if bucket_reso is not None:
                    frame = resize_image_to_bucket(frame, bucket_reso)

            video.append(frame)
    finally:
        # closing the generator runs its finally-block, releasing the av container promptly
        loaders.close()

    return video
class BucketBatchManager:
    """Groups cached items by bucket resolution and serves them as batches of stacked tensors."""

    def __init__(self, bucketed_item_info: dict[Union[tuple[int, int], tuple[int, int, int]], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        # indices for enumerating batches. each batch is reso + batch_idx. reso is (width, height) or (width, height, frames)
        self.bucket_batch_indices: list[tuple[Union[tuple[int, int], tuple[int, int, int]], int]] = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        # do no shuffle here to avoid multiple datasets have different order
        # self.shuffle()

    def show_bucket_info(self):
        # log item count per bucket and the total number of batches
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        # shuffle each bucket
        for bucket in self.buckets.values():
            random.shuffle(bucket)

        # shuffle the order of batches
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        """Load one batch: latent + text-encoder caches for each item, grouped and stacked per content key."""
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        batch_tensor_data = {}
        varlen_keys = set()
        for item_info in bucket[start:end]:
            sd_latent = load_file(item_info.latent_cache_path)
            sd_te = load_file(item_info.text_encoder_output_cache_path)
            sd = {**sd_latent, **sd_te}

            # TODO refactor this
            # cache keys encode metadata as suffixes (dtype, FxHxW); strip them to get the content key
            for key in sd.keys():
                is_varlen_key = key.startswith("varlen_")  # varlen keys are not stacked
                content_key = key

                if is_varlen_key:
                    content_key = content_key.replace("varlen_", "")

                if content_key.endswith("_mask"):
                    pass
                else:
                    content_key = content_key.rsplit("_", 1)[0]  # remove dtype
                    if content_key.startswith("latents_"):
                        content_key = content_key.rsplit("_", 1)[0]  # remove FxHxW

                if content_key not in batch_tensor_data:
                    batch_tensor_data[content_key] = []
                batch_tensor_data[content_key].append(sd[key])

                if is_varlen_key:
                    varlen_keys.add(content_key)

        # stack fixed-length tensors; variable-length ("varlen_") entries stay as lists
        for key in batch_tensor_data.keys():
            if key not in varlen_keys:
                batch_tensor_data[key] = torch.stack(batch_tensor_data[key])

        return batch_tensor_data
class ContentDatasource:
    """Abstract base for caption/image/video data sources."""

    def __init__(self):
        # when True, only captions are fetched (used for Text Encoder output caching)
        self.caption_only = False
        # subclasses set this to True when control (conditioning) data is available
        self.has_control = False

    def set_caption_only(self, caption_only: bool):
        self.caption_only = caption_only

    def is_indexable(self):
        # subclasses that support random access override this to return True
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
class ImageDatasource(ContentDatasource):
    """Abstract base for image data sources."""

    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError
class ImageDirectoryDatasource(ImageDatasource):
    """Image datasource that globs images from a directory, with optional sidecar captions and control images."""

    def __init__(self, image_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory
        self.current_idx = 0  # cursor for the iterator protocol

        # glob images
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

        # glob control images if specified
        if self.control_directory is not None:
            logger.info(f"glob control images in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}  # image path -> control image path
            for image_path in self.image_paths:
                image_basename = os.path.basename(image_path)
                control_path = os.path.join(self.control_directory, image_basename)
                if os.path.exists(control_path):
                    self.control_paths[image_path] = control_path
                else:
                    # another extension for control path
                    # for example: image_path = "img/image.png" -> control_path = "control/image.jpg"
                    image_basename_no_ext = os.path.splitext(image_basename)[0]
                    for ext in IMAGE_EXTENSIONS:
                        potential_path = os.path.join(self.control_directory, image_basename_no_ext + ext)
                        if os.path.exists(potential_path):
                            self.control_paths[image_path] = potential_path
                            break

            logger.info(f"found {len(self.control_paths)} matching control images")
            # every image must have a control counterpart; fail fast otherwise
            missing_controls = len(self.image_paths) - len(self.control_paths)
            if missing_controls > 0:
                missing_control_paths = set(self.image_paths) - set(self.control_paths.keys())
                logger.error(f"Could not find matching control images for {missing_controls} images: {missing_control_paths}")
                raise ValueError(f"Could not find matching control images for {missing_controls} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Return (image_path, RGB image, caption, control image or None) for the given index."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        control = None
        if self.has_control:
            control_path = self.control_paths[image_path]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image_path, caption) read from the sidecar caption file."""
        image_path = self.image_paths[idx]
        # NOTE(review): if caption_extension is None/empty this opens "" and raises — presumably a
        # caption extension is always configured for this datasource; confirm with callers.
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.
        """
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        # factory functions bind the current index, avoiding the late-binding closure pitfall
        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
class ImageJsonlDatasource(ImageDatasource):
    """Image datasource backed by a JSONL file with image_path/caption (and optional control_path) entries."""

    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        self.current_idx = 0  # cursor for the iterator protocol

        # load jsonl
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # surface the offending line before re-raising so bad JSONL is easy to locate
                    logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
                    raise
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

        # Check if there are control paths in the JSONL
        # control paths must be all-or-nothing: raise if only some entries have them
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_images = [item["image_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some images do not have control paths in JSONL data: {missing_control_images}")
                raise ValueError(f"Some images do not have control paths in JSONL data: {missing_control_images}")
            logger.info(f"found {control_count} control images in JSONL data")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Return (image_path, RGB image, caption, control image or None) for the given index."""
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        control = None
        if self.has_control:
            control_path = data["control_path"]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (image_path, caption) from the parsed JSONL entry."""
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        # returns a zero-argument fetcher; factory functions bind the current index
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
class VideoDatasource(ContentDatasource):
    """Abstract base for video data sources with shared frame-range / bucket / fps settings."""

    def __init__(self):
        super().__init__()

        # frame range to load; None means all frames
        self.start_frame = None
        self.end_frame = None

        # optional BucketSelector used to resize frames while loading
        self.bucket_selector = None

        # fps conversion settings; None disables frame dropping
        self.source_fps = None
        self.target_fps = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str]:
        """Load the video at video_path; unset arguments fall back to the datasource-wide settings."""
        # passing a bucket_selector lets load_video resize frames early, reducing memory usage
        if start_frame is None:
            start_frame = self.start_frame
        if end_frame is None:
            end_frame = self.end_frame
        if bucket_selector is None:
            bucket_selector = self.bucket_selector

        return load_video(
            video_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )

    def get_control_data_from_path(
        self,
        control_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[Image.Image]:
        """Load control (conditioning) frames; unset arguments fall back to the datasource-wide settings."""
        if start_frame is None:
            start_frame = self.start_frame
        if end_frame is None:
            end_frame = self.end_frame
        if bucket_selector is None:
            bucket_selector = self.bucket_selector

        return load_video(
            control_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        self.bucket_selector = bucket_selector

    def set_source_and_target_fps(self, source_fps: Optional[float], target_fps: Optional[float]):
        self.source_fps = source_fps
        self.target_fps = target_fps

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
class VideoDirectoryDatasource(VideoDatasource):
    """Video datasource that globs videos from a directory, with optional sidecar captions and control videos/images."""

    def __init__(self, video_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory  # newly added: control image directory
        self.current_idx = 0  # cursor for the iterator protocol

        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

        # glob control images if specified
        if self.control_directory is not None:
            logger.info(f"glob control videos in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}  # video path -> control video (or image-directory) path
            for video_path in self.video_paths:
                video_basename = os.path.basename(video_path)
                # construct control path from video path
                # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mp4"
                control_path = os.path.join(self.control_directory, video_basename)
                if os.path.exists(control_path):
                    self.control_paths[video_path] = control_path
                else:
                    # use the same base name for control path
                    base_name = os.path.splitext(video_basename)[0]

                    # directory with images. for example: video_path = "vid/video.mp4" -> control_path = "control/video"
                    potential_path = os.path.join(self.control_directory, base_name)  # no extension
                    if os.path.isdir(potential_path):
                        self.control_paths[video_path] = potential_path
                    else:
                        # another extension for control path
                        # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mov"
                        for ext in VIDEO_EXTENSIONS:
                            potential_path = os.path.join(self.control_directory, base_name + ext)
                            if os.path.exists(potential_path):
                                self.control_paths[video_path] = potential_path
                                break

            logger.info(f"found {len(self.control_paths)} matching control videos/images")
            # check if all videos have matching control paths, if not, raise an error
            missing_controls = len(self.video_paths) - len(self.control_paths)
            if missing_controls > 0:
                # logger.warning(f"Could not find matching control videos/images for {missing_controls} videos")
                missing_controls_videos = [video_path for video_path in self.video_paths if video_path not in self.control_paths]
                logger.error(
                    f"Could not find matching control videos/images for {missing_controls} videos: {missing_controls_videos}"
                )
                raise ValueError(f"Could not find matching control videos/images for {missing_controls} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
        """Return (video_path, frames, caption, control frames or None) for the given index."""
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        control = None
        if self.control_directory is not None and video_path in self.control_paths:
            control_path = self.control_paths[video_path]
            control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video_path, caption) read from the sidecar caption file."""
        video_path = self.video_paths[idx]
        # NOTE(review): if caption_extension is None/empty this opens "" and raises — presumably a
        # caption extension is always configured for this datasource; confirm with callers.
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # returns a zero-argument fetcher; factory functions bind the current index
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
class VideoJsonlDatasource(VideoDatasource):
    """Video datasource backed by a JSONL file with video_path/caption (and optional control_path) entries."""

    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0  # cursor for the iterator protocol

        # load jsonl
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

        # Check if there are control paths in the JSONL
        # control paths must be all-or-nothing: raise if only some entries have them
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_videos = [item["video_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
                raise ValueError(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
            logger.info(f"found {control_count} control videos/images in JSONL data")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
        """Return (video_path, frames, caption, control frames or None) for the given index."""
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        control = None
        if "control_path" in data and data["control_path"]:
            control_path = data["control_path"]
            control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Return (video_path, caption) from the parsed JSONL entry."""
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # returns a zero-argument fetcher; factory functions bind the current index
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
class BaseDataset(torch.utils.data.Dataset):
    """Common base class for the caching datasets (images and videos).

    Holds configuration shared by ImageDataset / VideoDataset (resolution,
    bucketing flags, cache directory, architecture tag), defines the on-disk
    naming convention for latent and text-encoder caches, and tracks
    epoch/step state used for bucket shuffling.
    """

    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        num_repeats: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.num_repeats = num_repeats
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.architecture = architecture
        self.seed = None
        self.current_epoch = 0

        if not self.enable_bucket:
            # "no upscale" is meaningless without bucketing; normalize it
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        """Return a JSON-serializable summary of this dataset's configuration."""
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "num_repeats": self.num_repeats,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_all_latent_cache_files(self):
        """Glob every latent cache file for this architecture in the cache directory."""
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

    def get_all_text_encoder_output_cache_files(self):
        """Glob every text-encoder output cache file for this architecture."""
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        """
        Returns the cache path for the latent tensor.

        item_info: ItemInfo object

        Returns:
            str: cache path

        cache_path is based on the item_key and the resolution.
        """
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        """Return the cache path for the text-encoder outputs of *item_info*."""
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield batches of items to latent-cache. Implemented by subclasses."""
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        """Yield batches of items to text-encoder-cache. Implemented by subclasses."""
        raise NotImplementedError

    def prepare_for_training(self):
        """Hook for subclasses to build their bucket/batch structures before training."""
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        if not self.current_epoch == epoch:  # shuffle buckets when epoch is incremented
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                # advance one epoch at a time so every epoch gets its own shuffle
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
                # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        # FIX: was `return NotImplementedError`, which returned the exception
        # class itself and made len() fail with an unrelated TypeError.
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        """Shared producer for text-encoder caching.

        Fetches captions through *datasource* on a thread pool and yields
        lists of ItemInfo of size *batch_size* (last batch may be smaller).
        """
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        # Drain completed futures into `data`; spins (with a short sleep)
        # while the pool is saturated, or until empty when consume_all is set.
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    item_key, caption = future.result()
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)

                    futures.remove(future)

        # Pop up to batch_size items off `data`; returns None when there is
        # nothing to emit (or the remainder is withheld until flush).
        def submit_batch(flush: bool = False):
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()
|
| 1259 |
+
|
| 1260 |
+
|
| 1261 |
+
class ImageDataset(BaseDataset):
    """Dataset over still images, sourced from a directory or a JSONL manifest.

    Used both to produce latent / text-encoder caches (via the retrieve_*
    generators) and to serve cached, bucketed batches during training.
    """

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(ImageDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        self.control_directory = control_directory
        # exactly one of image_directory / image_jsonl_file selects the datasource
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension, control_directory)
        elif image_jsonl_file is not None:
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        # default: store caches next to the images themselves
        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None  # built later by prepare_for_training()
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with image-source and control information."""
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["has_control"] = self.has_control
        return metadata

    def get_total_image_count(self):
        """Return the number of source images, or None if the datasource cannot be indexed."""
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield ((width, height), [ItemInfo]) batches for latent caching.

        Images are fetched and resized to their bucket resolution on a thread
        pool; completed items are grouped by bucket and emitted as soon as a
        bucket reaches batch_size (remainders are flushed at the end).
        """
        # NOTE: "buckset" is a long-standing typo for "bucket" kept as-is
        buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        batches: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        futures = []

        # aggregate futures and sort by bucket resolution
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_size, item_key, image, caption, control = future.result()
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if control is not None:
                        item_info.control_content = control

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        # submit batch if some bucket has enough items
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:

            # fetch and resize image in a separate thread
            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str, Optional[Image.Image]]:
                image_key, image, caption, control = op()
                image: Image.Image
                image_size = image.size

                bucket_reso = buckset_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)
                if control is not None:
                    control = resize_image_to_bucket(control, bucket_reso)
                return image_size, image_key, image, caption, control

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        """Yield ItemInfo batches for text-encoder caching (shared base implementation)."""
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Scan existing latent caches, pair them with text-encoder caches and
        build the bucketed batch manager used by __getitem__ during training."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        for cache_file in latent_cache_files:
            # filename layout: {item_key}_{W:04d}x{H:04d}_{architecture}.safetensors
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            # num_repeats duplicates the same ItemInfo to weight this item
            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        # length is only meaningful after prepare_for_training()
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
|
| 1449 |
+
|
| 1450 |
+
|
| 1451 |
+
class VideoDataset(BaseDataset):
    """Dataset over videos, sourced from a directory or a JSONL manifest.

    Videos are cut into fixed-length clips according to ``frame_extraction``
    ("head", "chunk", "slide", "uniform" or "full"); clips are bucketed by
    (width, height, frame_count) both for caching and for training.
    """

    # target FPS per architecture; source videos are resampled to this rate
    TARGET_FPS_HUNYUAN = 24.0
    TARGET_FPS_WAN = 16.0
    TARGET_FPS_FRAMEPACK = 30.0

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        max_frames: Optional[int] = None,
        source_fps: Optional[float] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(VideoDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.control_directory = control_directory
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample
        self.max_frames = max_frames
        self.source_fps = source_fps

        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            self.target_fps = VideoDataset.TARGET_FPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            self.target_fps = VideoDataset.TARGET_FPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            self.target_fps = VideoDataset.TARGET_FPS_FRAMEPACK
        else:
            raise ValueError(f"Unsupported architecture: {self.architecture}")

        if target_frames is not None:
            target_frames = list(set(target_frames))
            target_frames.sort()

            # round each value to N*4+1 (the frame-count granularity required by the VAE)
            rounded_target_frames = [(f - 1) // 4 * 4 + 1 for f in target_frames]

            # if value is changed, warn
            if target_frames != rounded_target_frames:
                logger.warning(f"target_frames are rounded to {rounded_target_frames}")

            # FIX: deduplicate AFTER rounding — distinct inputs (e.g. 2 and 3)
            # can round to the same value; the old code built the deduplicated
            # list into a misspelled variable and then discarded it.
            rounded_target_frames = sorted(set(rounded_target_frames))

            target_frames = tuple(rounded_target_frames)

        self.target_frames = target_frames

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension, control_directory)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            # FIX: fail fast like ImageDataset instead of leaving self.datasource
            # unset and crashing later with an AttributeError
            raise ValueError("video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            # head extraction. we can limit the number of frames to be extracted
            # NOTE(review): assumes target_frames is provided for "head" mode — confirm with callers
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        # default: store caches next to the videos themselves
        if self.cache_directory is None:
            self.cache_directory = self.video_directory

        self.batch_manager = None  # built later by prepare_for_training()
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with video-source and frame-extraction settings."""
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        metadata["max_frames"] = self.max_frames
        metadata["source_fps"] = self.source_fps
        metadata["has_control"] = self.has_control
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield ((width, height, frame_count), [ItemInfo]) batches for latent caching.

        Videos are fetched/resized on a thread pool, then cut into clips
        according to ``frame_extraction``; clips are grouped by
        (bucket_reso, frame_count) and emitted once a group reaches
        batch_size (remainders are flushed at the end).
        """
        bucket_selector = BucketSelector(self.resolution, architecture=self.architecture)
        self.datasource.set_bucket_selector(bucket_selector)
        if self.source_fps is not None:
            self.datasource.set_source_and_target_fps(self.source_fps, self.target_fps)
        else:
            self.datasource.set_source_and_target_fps(None, None)  # no conversion

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: (width, height, frame_count), value: [ItemInfo]
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        # Drain completed futures: split each video into clips and bucket them.
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_frame_size, video_key, video, caption, control = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # already resized

                    # process control images if available
                    control_video = None
                    if control is not None:
                        # set frame count to the same as video
                        if len(control) > frame_count:
                            control = control[:frame_count]
                        elif len(control) < frame_count:
                            # if control is shorter than video, repeat the last frame
                            last_frame = control[-1]
                            control.extend([last_frame] * (frame_count - len(control)))
                        control_video = np.stack(control, axis=0)

                    # determine (start, length) pairs for the clips to cut
                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # split by target_frames
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # slide window
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # select N frames uniformly
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "full":
                        # select all frames, capped at max_frames and rounded to N*4+1
                        target_frame = min(frame_count, self.max_frames)
                        target_frame = (target_frame - 1) // 4 * 4 + 1  # round to N*4+1
                        crop_pos_and_frames.append((0, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        # encode clip position/length into the item key so cache
                        # filenames are unique per clip
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket_reso with frame_count

                        # crop control video if available
                        cropped_control = None
                        if control_video is not None:
                            cropped_control = control_video[crop_pos : crop_pos + target_frame]

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)
                        item_info.control_content = cropped_control  # None is allowed

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        # Pop one full (or flushed) batch from any bucket, or (None, None).
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        del batches[key]
                    return key, batch
            return None, None

        for operator in self.datasource:

            # fetch and resize the video (and its control frames) in a worker thread
            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
                result = op()

                if len(result) == 3:  # for backward compatibility TODO remove this in the future
                    video_key, video, caption = result
                    control = None
                else:
                    video_key, video, caption, control = result

                video: list[np.ndarray]
                frame_size = (video[0].shape[1], video[0].shape[0])

                # resize if necessary
                bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
                video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

                # resize control if necessary
                if control is not None:
                    control = [resize_image_to_bucket(frame, bucket_reso) for frame in control]

                return frame_size, video_key, video, caption, control

            future = executor.submit(fetch_and_resize, operator)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        """Yield ItemInfo batches for text-encoder caching (shared base implementation)."""
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Scan existing latent caches, pair them with text-encoder caches and
        build the bucketed batch manager used by __getitem__ during training."""
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}  # (width, height, frame_count) -> [ItemInfo]
        for cache_file in latent_cache_files:
            # filename layout: {item_key}_{pos:05d}-{frames:03d}_{W:04d}x{H:04d}_{architecture}.safetensors
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            frame_pos, frame_count = tokens[-3].split("-")[:2]  # "00000-000", or optional section index "00000-000-00"
            frame_pos, frame_count = int(frame_pos), int(frame_count)

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            # num_repeats duplicates the same ItemInfo to weight this item
            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        # length is only meaningful after prepare_for_training()
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        return self.batch_manager[idx]
|
| 1766 |
+
|
| 1767 |
+
|
| 1768 |
+
class DatasetGroup(torch.utils.data.ConcatDataset):
    """A ConcatDataset over image/video datasets that fans epoch/step
    notifications out to every member dataset."""

    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
        # total number of training items across all member datasets
        self.num_train_items = sum(ds.num_train_items for ds in self.datasets)

    def set_current_epoch(self, epoch):
        # propagate so each member can shuffle its buckets for the new epoch
        for ds in self.datasets:
            ds.set_current_epoch(epoch)

    def set_current_step(self, step):
        for ds in self.datasets:
            ds.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for ds in self.datasets:
            ds.set_max_train_steps(max_train_steps)
|
fpack_cache_latents.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import math
|
| 4 |
+
import os
|
| 5 |
+
from typing import List
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from transformers import SiglipImageProcessor, SiglipVisionModel
|
| 12 |
+
|
| 13 |
+
from dataset import config_utils
|
| 14 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
| 15 |
+
from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK
|
| 16 |
+
from frame_pack import hunyuan
|
| 17 |
+
from frame_pack.framepack_utils import load_image_encoders, load_vae
|
| 18 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
| 19 |
+
from frame_pack.clip_vision import hf_clip_vision_encode
|
| 20 |
+
import cache_latents
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
logging.basicConfig(level=logging.INFO)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def encode_and_save_batch(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    latent_window_size: int,
    vanilla_sampling: bool = False,
    one_frame: bool = False,
):
    """Encode a batch of original RGB videos and save FramePack section caches.

    Each item's video is VAE-encoded once, its first frame is embedded with the
    SigLIP vision encoder, and then one latent cache file is written per
    FramePack section, emulating either the inference order (inverted
    anti-drifting, future to past) or the F1 vanilla (autoregressive,
    past to future) order.

    Args:
        vae: causal 3D VAE used to encode pixels into latents (scaling included
            by hunyuan.vae_encode).
        feature_extractor: SigLIP image processor for the vision encoder.
        image_encoder: SigLIP vision model that embeds the first frame of each item.
        batch: items to encode; all items are expected to share the same spatial
            size (only batch[0] is checked/reported).
        latent_window_size: number of latent frames per section.
        vanilla_sampling: if True, use the F1 (vanilla/autoregressive) section
            layout; otherwise emulate inference (inverted anti-drifting).
        one_frame: if True, delegate to the single-frame caching path and return.

    Raises:
        ValueError: if frames are smaller than 8x8 pixels or too few for one section.
    """
    if one_frame:
        # Single-frame training uses a dedicated path; nothing else runs here.
        encode_and_save_batch_one_frame(vae, feature_extractor, image_encoder, batch, latent_window_size, vanilla_sampling)
        return

    # Stack batch into tensor (B,C,F,H,W) in RGB order
    contents = torch.stack([torch.from_numpy(item.content) for item in batch])
    if len(contents.shape) == 4:
        contents = contents.unsqueeze(1)  # B, H, W, C -> B, F, H, W, C (single image becomes 1-frame video)

    contents = contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0  # normalize uint8 [0,255] to [-1, 1]

    height, width = contents.shape[3], contents.shape[4]
    if height < 8 or width < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # calculate latent frame count from original frame count (4n+1 pixel frames -> n+1 latent frames)
    latent_f = (batch[0].frame_count - 1) // 4 + 1

    # calculate the total number of sections (excluding the first frame, divided by window size)
    total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
    if total_latent_sections < 1:
        min_frames_needed = latent_window_size * 4 + 1
        raise ValueError(
            f"Not enough frames for FramePack: {batch[0].frame_count} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size+1} latent frames)"
        )

    # actual latent frame count (aligned to section boundaries)
    # NOTE(review): one_frame is always False here (early return above), so the
    # conditional always takes the first arm.
    latent_f_aligned = total_latent_sections * latent_window_size + 1 if not one_frame else 1

    # actual video frame count corresponding to the aligned latent count
    frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
    if frame_count_aligned != batch[0].frame_count:
        logger.info(
            f"Frame count mismatch: required={frame_count_aligned} != actual={batch[0].frame_count}, trimming to {frame_count_aligned}"
        )
        contents = contents[:, :, :frame_count_aligned, :, :]

    latent_f = latent_f_aligned  # Update to the aligned value

    # VAE encode (list of tensor -> stack)
    latents = hunyuan.vae_encode(contents, vae)  # include scaling factor
    latents = latents.to("cpu")  # (B, C, latent_f, H/8, W/8)

    # Vision encoding per-item (once): only the first frame is embedded
    images = np.stack([item.content[0] for item in batch], axis=0)  # B, H, W, C

    # encode image with image encoder
    image_embeddings = []
    with torch.no_grad():
        for image in images:
            image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
            image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)  # B, LEN, 1152
    image_embeddings = image_embeddings.to("cpu")  # Save memory

    if not vanilla_sampling:
        # padding is reversed for inference (future to past)
        latent_paddings = list(reversed(range(total_latent_sections)))
        # Note: The padding trick for inference. See the paper for details.
        if total_latent_sections > 4:
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        for b, item in enumerate(batch):
            original_latent_cache_path = item.latent_cache_path
            video_lat = latents[b : b + 1]  # keep batch dim, 1, C, F, H, W

            # emulate inference step (history latents)
            # Note: In inference, history_latents stores *generated* future latents.
            # Here, for caching, we just need its shape and type for clean_* tensors.
            # The actual content doesn't matter much as clean_* will be overwritten.
            history_latents = torch.zeros(
                (1, video_lat.shape[1], 1 + 2 + 16, video_lat.shape[3], video_lat.shape[4]), dtype=video_lat.dtype
            )  # C=16 for HY

            latent_f_index = latent_f - latent_window_size  # Start from the last section
            section_index = total_latent_sections - 1

            for latent_padding in latent_paddings:
                is_last_section = section_index == 0  # the last section in inference order == the first section in time
                latent_padding_size = latent_padding * latent_window_size
                if is_last_section:
                    assert latent_f_index == 1, "Last section should be starting from frame 1"

                # indices generation (same as inference)
                indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
                (
                    clean_latent_indices_pre,  # Index for start_latent
                    blank_indices,  # Indices for padding (future context in inference)
                    latent_indices,  # Indices for the target latents to predict
                    clean_latent_indices_post,  # Index for the most recent history frame
                    clean_latent_2x_indices,  # Indices for the next 2 history frames
                    clean_latent_4x_indices,  # Indices for the next 16 history frames
                ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)

                # Indices for clean_latents (start + recent history)
                clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

                # clean latents preparation (emulating inference)
                clean_latents_pre = video_lat[:, :, 0:1, :, :]  # Always the first frame (start_latent)
                clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
                    [1, 2, 16], dim=2
                )
                clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)  # Combine start frame + placeholder

                # Target latents for this section (ground truth)
                target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]

                # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
                item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
                save_latent_cache_framepack(
                    item_info=item,
                    latent=target_latents.squeeze(0),  # Ground truth for this section
                    latent_indices=latent_indices.squeeze(0),  # Indices for the ground truth section
                    clean_latents=clean_latents.squeeze(0),  # Start frame + history placeholder
                    clean_latent_indices=clean_latent_indices.squeeze(0),  # Indices for start frame + history placeholder
                    clean_latents_2x=clean_latents_2x.squeeze(0),  # History placeholder
                    clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),  # Indices for history placeholder
                    clean_latents_4x=clean_latents_4x.squeeze(0),  # History placeholder
                    clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),  # Indices for history placeholder
                    image_embeddings=image_embeddings[b],
                )

                if is_last_section:  # If this was the first section generated in inference (time=0)
                    # History gets the start frame + the generated first section
                    generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
                else:
                    # History gets the generated current section
                    generated_latents_for_history = target_latents  # Use true latents as stand-in for generated

                # Prepend newest history (inference walks future -> past)
                history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)

                section_index -= 1
                latent_f_index -= latent_window_size

    else:
        # Vanilla Sampling Logic (F1: sections are processed past -> future)
        for b, item in enumerate(batch):
            original_latent_cache_path = item.latent_cache_path
            video_lat = latents[b : b + 1]  # Keep batch dim: 1, C, F_aligned, H, W
            img_emb = image_embeddings[b]  # LEN, 1152

            for section_index in range(total_latent_sections):
                target_start_f = section_index * latent_window_size + 1
                target_end_f = target_start_f + latent_window_size
                target_latents = video_lat[:, :, target_start_f:target_end_f, :, :]
                start_latent = video_lat[:, :, 0:1, :, :]

                # Clean latents preparation (Vanilla)
                clean_latents_total_count = 1 + 2 + 16
                history_latents = torch.zeros(
                    size=(1, 16, clean_latents_total_count, video_lat.shape[-2], video_lat.shape[-1]),
                    device=video_lat.device,
                    dtype=video_lat.dtype,
                )

                # Copy the available past latents into the tail of history_latents;
                # the front stays zero-padded when the section is near the video start.
                history_start_f = 0
                video_start_f = target_start_f - clean_latents_total_count
                copy_count = clean_latents_total_count
                if video_start_f < 0:
                    history_start_f = -video_start_f
                    copy_count = clean_latents_total_count - history_start_f
                    video_start_f = 0
                if copy_count > 0:
                    history_latents[:, :, history_start_f:] = video_lat[:, :, video_start_f : video_start_f + copy_count, :, :]

                # indices generation (Vanilla): copy from FramePack-F1
                indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
                (
                    clean_latent_indices_start,
                    clean_latent_4x_indices,
                    clean_latent_2x_indices,
                    clean_latent_1x_indices,
                    latent_indices,
                ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
                clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

                clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
                clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)

                # Save cache
                item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
                save_latent_cache_framepack(
                    item_info=item,
                    latent=target_latents.squeeze(0),
                    latent_indices=latent_indices.squeeze(0),  # Indices for target section i
                    clean_latents=clean_latents.squeeze(0),  # Past clean frames
                    clean_latent_indices=clean_latent_indices.squeeze(0),  # Indices for clean_latents_pre/post
                    clean_latents_2x=clean_latents_2x.squeeze(0),  # Past clean frames (2x)
                    clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),  # Indices for clean_latents_2x
                    clean_latents_4x=clean_latents_4x.squeeze(0),  # Past clean frames (4x)
                    clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),  # Indices for clean_latents_4x
                    image_embeddings=img_emb,
                    # Note: We don't explicitly save past_offset_indices,
                    # but its size influences the absolute values in other indices.
                )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def encode_and_save_batch_one_frame(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    latent_window_size: int,
    vanilla_sampling: bool = False,
):
    """Encode (start image, target image) pairs and save one-frame FramePack caches.

    Args:
        vae: causal 3D VAE used to encode pixels into latents.
        feature_extractor: SigLIP image processor for the vision encoder.
        image_encoder: SigLIP vision model; embeds the start (control) image.
        batch: items where item.content is the target image (H, W, C) and
            item.control_content is the start image (H, W, C).
        latent_window_size: window length used when building index tensors; only
            the last latent index of the window is kept as the training target.
        vanilla_sampling: accepted for signature parity with the multi-section
            path but not used in this function.

    Raises:
        ValueError: if images are smaller than 8x8 pixels.
    """
    # item.content: target image (H, W, C)
    # item.control_content: start image (H, W, C)

    # Stack batch into tensor (B,F,H,W,C) in RGB order; F=2 (start, target).
    contents = torch.stack(
        [torch.stack([torch.from_numpy(item.control_content), torch.from_numpy(item.content)]) for item in batch]
    )

    contents = contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0  # normalize uint8 [0,255] to [-1, 1]

    height, width = contents.shape[3], contents.shape[4]
    if height < 8 or width < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # VAE encode: start frame and target frame are encoded separately
    start_latents = hunyuan.vae_encode(contents[:, :, 0:1], vae)  # include scaling factor
    start_latents = start_latents.to("cpu")  # (B, C, 1, H/8, W/8)
    latents = hunyuan.vae_encode(contents[:, :, 1:], vae)  # include scaling factor
    latents = latents.to("cpu")  # (B, C, 1, H/8, W/8)

    # Vision encoding per-item (once): use control content because it is the start image
    images = [item.control_content for item in batch]  # list of [H, W, C]

    # encode image with image encoder
    image_embeddings = []
    with torch.no_grad():
        for image in images:
            image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
            image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)  # B, LEN, 1152
    image_embeddings = image_embeddings.to("cpu")  # Save memory

    # history latents is always zeroes for one frame training
    history_latents = torch.zeros(
        (1, latents.shape[1], 1 + 2 + 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype
    )  # C=16 for HY

    # indices generation (same as inference); shared by all items in the batch
    indices = torch.arange(0, sum([1, latent_window_size, 1, 2, 16])).unsqueeze(0)
    (
        clean_latent_indices_pre,  # Index for start_latent
        latent_indices,  # Indices for the target latents to predict
        clean_latent_indices_post,  # Index for the most recent history frame
        clean_latent_2x_indices,  # Indices for the next 2 history frames
        clean_latent_4x_indices,  # Indices for the next 16 history frames
    ) = indices.split([1, latent_window_size, 1, 2, 16], dim=1)

    # Indices for clean_latents (start + recent history)
    latent_indices = latent_indices[:, -1:]  # Only the last index is used for one frame training
    clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

    # clean latents preparation for all items (emulating inference)
    clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split([1, 2, 16], dim=2)

    for b, item in enumerate(batch):
        # NOTE(review): kept for symmetry with the multi-section path, but the
        # cache path is not rewritten per section here, so this is unused.
        original_latent_cache_path = item.latent_cache_path

        # clean latents preparation (emulating inference)
        clean_latents_pre = start_latents[b : b + 1]
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)  # Combine start frame + placeholder

        # Target latents for this section (ground truth)
        target_latents = latents[b : b + 1]

        # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
        save_latent_cache_framepack(
            item_info=item,
            latent=target_latents.squeeze(0),  # Ground truth for this section
            latent_indices=latent_indices.squeeze(0),  # Indices for the ground truth section
            clean_latents=clean_latents.squeeze(0),  # Start frame + history placeholder
            clean_latent_indices=clean_latent_indices.squeeze(0),  # Indices for start frame + history placeholder
            clean_latents_2x=clean_latents_2x.squeeze(0),  # History placeholder
            clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),  # Indices for history placeholder
            clean_latents_4x=clean_latents_4x.squeeze(0),  # History placeholder
            clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),  # Indices for history placeholder
            image_embeddings=image_embeddings[b],
        )
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register FramePack-specific latent-caching CLI options on *parser*.

    Adds --image_encoder (required), --latent_window_size (default 9), and the
    --f1 / --one_frame mode flags, then returns the same parser for chaining.
    """
    parser.add_argument("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
    parser.add_argument("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")

    # Both mode switches are plain boolean flags; register them uniformly.
    mode_flags = (
        (
            "--f1",
            "Generate cache for F1 model (vanilla (autoregressive) sampling) instead of Inverted anti-drifting (plain FramePack)",
        ),
        (
            "--one_frame",
            "Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
        ),
    )
    for flag, description in mode_flags:
        parser.add_argument(flag, action="store_true", help=description)
    return parser
|
| 340 |
+
|
| 341 |
+
|
| 342 |
+
def main(args: argparse.Namespace):
    """Entry point: build datasets from config, load models, and cache FramePack latents."""
    # Resolve compute device: explicit --device wins, otherwise CUDA if available.
    device = args.device if hasattr(args, "device") and args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    # Load dataset config and materialize the dataset group for the FramePack architecture.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # Debug mode: just visualize the datasets and exit without encoding.
    if args.debug_mode is not None:
        cache_latents.show_datasets(
            datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    assert args.vae is not None, "vae checkpoint is required"

    logger.info(f"Loading VAE model from {args.vae}")
    vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
    vae.to(device)

    logger.info(f"Loading image encoder from {args.image_encoder}")
    feature_extractor, image_encoder = load_image_encoders(args)
    image_encoder.eval()
    image_encoder.to(device)

    logger.info(f"Cache generation mode: {'Vanilla Sampling' if args.f1 else 'Inference Emulation'}")

    # encoding closure: captures the loaded models and CLI options
    def encode(batch: List[ItemInfo]):
        encode_and_save_batch(vae, feature_extractor, image_encoder, batch, args.latent_window_size, args.f1, args.one_frame)

    # reuse core loop from cache_latents with no change
    encode_datasets_framepack(datasets, encode, args)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
    """Insert a zero-padded section index into an underscore-delimited cache file name.

    The index is appended to the third-from-last "_"-separated token, i.e. the
    "frame_pos-count" portion of the cache file name.
    """
    parts = latent_cache_path.split("_")
    parts[-3] += f"-{section_idx:04d}"
    return "_".join(parts)
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Drive latent-cache encoding over all datasets, expanding per-section cache paths.

    For video items the single latent cache path is expanded into one path per
    FramePack section. When --skip_existing is set, items whose section caches
    all exist are dropped from the batch. After encoding, cache files that no
    longer belong to the dataset are removed (or kept with --keep_cache).

    Args:
        datasets: datasets whose items should be encoded.
        encode: callable taking a list[ItemInfo] and writing the caches.
        args: parsed CLI namespace (num_workers, latent_window_size,
            skip_existing, batch_size, keep_cache).
    """
    num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            batch: list[ItemInfo] = batch  # type: ignore

            # latent_cache_path is "{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
            # For video dataset, we expand it to "{basename}_{section_idx:04d}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
            filtered_batch = []
            for item in batch:
                if item.frame_count is None:
                    # image dataset: one cache file per item
                    all_latent_cache_paths.append(item.latent_cache_path)
                    all_existing = os.path.exists(item.latent_cache_path)
                else:
                    # video dataset: one cache file per FramePack section
                    latent_f = (item.frame_count - 1) // 4 + 1
                    num_sections = max(1, math.floor((latent_f - 1) / args.latent_window_size))  # min 1 section
                    all_existing = True
                    for sec in range(num_sections):
                        p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
                        all_latent_cache_paths.append(p)
                        all_existing = all_existing and os.path.exists(p)

                if not all_existing:  # if any section cache is missing
                    filtered_batch.append(item)

            if args.skip_existing:
                if len(filtered_batch) == 0:  # all sections exist
                    logger.info(f"All sections exist for {batch[0].item_key}, skipping")
                    continue
                batch = filtered_batch  # update batch to only missing sections

            bs = args.batch_size if args.batch_size is not None else len(batch)
            # Fix: use a dedicated index name so the dataset index `i` from the
            # outer enumerate loop is not shadowed.
            for start_idx in range(0, len(batch), bs):
                encode(batch[start_idx : start_idx + bs])

        # normalize paths so set membership checks below are reliable
        all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
        all_latent_cache_paths = set(all_latent_cache_paths)

        # remove old cache files not in the dataset
        all_cache_files = dataset.get_all_latent_cache_files()
        for cache_file in all_cache_files:
            if os.path.normpath(cache_file) not in all_latent_cache_paths:
                if args.keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
if __name__ == "__main__":
    # Build the CLI from the shared cache_latents parsers, then add FramePack options.
    parser = cache_latents.setup_parser_common()
    parser = cache_latents.hv_setup_parser(parser)  # VAE
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()

    # FramePack's VAE runs in its own fixed dtype; an explicit override is rejected.
    if args.vae_dtype is not None:
        raise ValueError("VAE dtype is not supported in FramePack")
    # if args.batch_size != 1:
    #     args.batch_size = 1
    #     logger.info("Batch size is set to 1 for FramePack.")

    main(args)
|
fpack_cache_text_encoder_outputs.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel
|
| 9 |
+
from dataset import config_utils
|
| 10 |
+
from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
|
| 11 |
+
from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
|
| 12 |
+
import cache_text_encoder_outputs
|
| 13 |
+
from frame_pack import hunyuan
|
| 14 |
+
from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
|
| 15 |
+
|
| 16 |
+
import logging
|
| 17 |
+
|
| 18 |
+
from frame_pack.utils import crop_or_pad_yield_mask
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger(__name__)
|
| 21 |
+
logging.basicConfig(level=logging.INFO)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def encode_and_save_batch(
    tokenizer1: LlamaTokenizerFast,
    text_encoder1: LlamaModel,
    tokenizer2: CLIPTokenizer,
    text_encoder2: CLIPTextModel,
    batch: list[ItemInfo],
    device: torch.device,
):
    """Encode each item's caption with both text encoders and save per-item caches.

    Args:
        tokenizer1: tokenizer for the LLaMA text encoder.
        text_encoder1: LLaMA text encoder producing the hidden-state sequence.
        tokenizer2: tokenizer for the CLIP text encoder.
        text_encoder2: CLIP text encoder producing the pooled embedding.
        batch: items whose .caption is encoded; results are written to each
            item's text-encoder-output cache file.
        device: device used for the autocast context during encoding.
    """
    prompts = [item.caption for item in batch]

    # encode prompt
    # FramePack's encode_prompt_conds only supports single prompt, so we need to encode each prompt separately
    list_of_llama_vec = []
    list_of_llama_attention_mask = []
    list_of_clip_l_pooler = []
    for prompt in prompts:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            # llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompts, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2)
            # Pad/crop the LLaMA hidden states to a fixed length of 512 and get the valid-token mask.
            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

        # drop the per-prompt batch dim before caching
        list_of_llama_vec.append(llama_vec.squeeze(0))
        list_of_llama_attention_mask.append(llama_attention_mask.squeeze(0))
        list_of_clip_l_pooler.append(clip_l_pooler.squeeze(0))

    # save prompt cache
    for item, llama_vec, llama_attention_mask, clip_l_pooler in zip(
        batch, list_of_llama_vec, list_of_llama_attention_mask, list_of_clip_l_pooler
    ):
        # save llama_vec and clip_l_pooler to cache
        save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def main(args):
    """Entry point: cache FramePack text-encoder outputs for all dataset items."""
    # Resolve compute device: explicit --device wins, otherwise CUDA if available.
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config and materialize the dataset group for the FramePack architecture.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # prepare cache files and paths: all_cache_files_for_dataset = exisiting cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    # load text encoders (LLaMA, optionally fp8, and CLIP)
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    text_encoder2.to(device)

    # Encode with Text Encoders
    logger.info("Encoding with Text Encoders")

    def encode_for_text_encoder(batch: list[ItemInfo]):
        encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder,
    )

    # remove cache files not in dataset
    cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register FramePack text-encoder CLI options on *parser* and return it."""
    option_specs = (
        ("--text_encoder1", dict(type=str, required=True, help="Text Encoder 1 directory")),
        ("--text_encoder2", dict(type=str, required=True, help="Text Encoder 2 directory")),
        ("--fp8_llm", dict(action="store_true", help="use fp8 for Text Encoder 1 (LLM)")),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
if __name__ == "__main__":
    # Build the CLI from the shared text-encoder caching parser, then add FramePack options.
    parser = cache_text_encoder_outputs.setup_parser_common()
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()
    main(args)
|
fpack_generate_video.py
ADDED
|
@@ -0,0 +1,1711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
import gc
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import time
|
| 9 |
+
import math
|
| 10 |
+
import copy
|
| 11 |
+
from typing import Tuple, Optional, List, Union, Any, Dict
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from safetensors.torch import load_file, save_file
|
| 15 |
+
from safetensors import safe_open
|
| 16 |
+
from PIL import Image
|
| 17 |
+
import cv2
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torchvision.transforms.functional as TF
|
| 20 |
+
from transformers import LlamaModel
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
|
| 23 |
+
from networks import lora_framepack
|
| 24 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
| 25 |
+
from frame_pack import hunyuan
|
| 26 |
+
from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
|
| 27 |
+
from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw
|
| 28 |
+
from frame_pack.bucket_tools import find_nearest_bucket
|
| 29 |
+
from frame_pack.clip_vision import hf_clip_vision_encode
|
| 30 |
+
from frame_pack.k_diffusion_hunyuan import sample_hunyuan
|
| 31 |
+
from dataset import image_video_dataset
|
| 32 |
+
|
| 33 |
+
# lycoris is an optional dependency (used only for `--lycoris` inference);
# tolerate its absence instead of failing at import time.
try:
    from lycoris.kohya import create_network_from_weights
except ImportError:
    # Was a bare `except:`, which would also swallow unrelated errors such as
    # KeyboardInterrupt/SystemExit; only a missing lycoris install is expected here.
    pass
|
| 37 |
+
|
| 38 |
+
from utils.device_utils import clean_memory_on_device
|
| 39 |
+
from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
|
| 40 |
+
from wan_generate_video import merge_lora_weights
|
| 41 |
+
from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders
|
| 42 |
+
from dataset.image_video_dataset import load_video
|
| 43 |
+
|
| 44 |
+
import logging
|
| 45 |
+
|
| 46 |
+
logger = logging.getLogger(__name__)
|
| 47 |
+
logging.basicConfig(level=logging.INFO)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class GenerationSettings:
    """Per-run generation configuration: target device and optional DiT weight dtype."""

    def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
        # Device on which generation is executed.
        self.device = device
        # Requested DiT weight dtype; not used currently because the model may be optimized.
        self.dit_weight_dtype = dit_weight_dtype
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def parse_args() -> argparse.Namespace:
    """Parse command line arguments for FramePack video generation.

    Returns:
        argparse.Namespace: parsed arguments.

    Raises:
        ValueError: if ``--from_file`` and ``--interactive`` are combined, or if
            no prompt source is provided while not in latent-decode mode.
    """
    # NOTE(review): description says "Wan 2.1" but this script drives FramePack —
    # the string is kept as-is here since it is runtime-visible text.
    parser = argparse.ArgumentParser(description="Wan 2.1 inference script")

    # WAN arguments
    # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
    parser.add_argument(
        "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
    )

    # model component paths
    parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
    parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
    parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")
    parser.add_argument("--f1", action="store_true", help="Use F1 sampling method")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
    parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )

    # inference
    parser.add_argument(
        "--prompt",
        type=str,
        default=None,
        help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
        "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=None,
        help="negative prompt for generation, default is empty string. should not change.",
    )
    parser.add_argument(
        "--custom_system_prompt",
        type=str,
        default=None,
        help="Custom system prompt for LLM. If specified, it will override the default system prompt. See hunyuan_model/text_encoder.py for the default system prompt.",
    )
    # video_size is [height, width] — note the order
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
    parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, default is 5.0 seconds")
    parser.add_argument(
        "--video_sections",
        type=int,
        default=None,
        help="number of video sections, Default is None (auto calculate from video seconds)",
    )
    parser.add_argument(
        "--one_frame_inference",
        type=str,
        default=None,
        help="one frame inference, default is None, comma separated values from 'zero_post', 'no_2x', 'no_4x' and 'no_post'.",
    )
    parser.add_argument(
        "--image_mask_path",
        type=str,
        default=None,
        help="path to image mask for one frame inference. If specified, it will be used as mask for input image.",
    )
    parser.add_argument(
        "--end_image_mask_path",
        type=str,
        default=None,
        nargs="*",
        help="path to end (reference) image mask for one frame inference. If specified, it will be used as mask for end image.",
    )
    parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
    parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    # parser.add_argument(
    #     "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
    # )
    parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
    parser.add_argument(
        "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0"
    )
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier free guidance. Default is 1.0 (no guidance), should not change.",
    )
    parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")
    # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
    parser.add_argument(
        "--image_path",
        type=str,
        default=None,
        help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
    )
    parser.add_argument("--end_image_path", type=str, nargs="*", default=None, help="path to end image for image2video inference")
    parser.add_argument(
        "--latent_paddings",
        type=str,
        default=None,
        help="latent paddings for each section, comma separated values. default is None (FramePack default paddings)",
    )
    # parser.add_argument(
    #     "--control_path",
    #     type=str,
    #     default=None,
    #     help="path to control video for inference with controlnet. video file or directory with images",
    # )
    # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")

    # # Flow Matching
    # parser.add_argument(
    #     "--flow_shift",
    #     type=float,
    #     default=None,
    #     help="Shift factor for flow matching schedulers. Default depends on task.",
    # )

    # precision / device options
    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
    # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode",
        type=str,
        default="torch",
        choices=["flash", "torch", "sageattn", "xformers", "sdpa"],  # "flash2", "flash3",
        help="attention mode",
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
    parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
    parser.add_argument(
        "--output_type",
        type=str,
        default="video",
        choices=["video", "images", "latent", "both", "latent_images"],
        help="output type",
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
    # parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    # parser.add_argument(
    #     "--compile_args",
    #     nargs=4,
    #     metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
    #     default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
    #     help="Torch.compile settings",
    # )

    # New arguments for batch and interactive modes
    parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")

    args = parser.parse_args()

    # Validate arguments: the prompt-source modes are mutually exclusive.
    if args.from_file and args.interactive:
        raise ValueError("Cannot use both --from_file and --interactive at the same time")

    # When decoding pre-computed latents no prompt is required; otherwise one
    # of the three prompt sources must be supplied.
    if args.latent_path is None or len(args.latent_path) == 0:
        if args.prompt is None and not args.from_file and not args.interactive:
            raise ValueError("Either --prompt, --from_file or --interactive must be specified")

    return args
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def parse_prompt_line(line: str) -> Dict[str, Any]:
    """Parse one batch-file line ("prompt --opt value --opt value ...") into overrides.

    Args:
        line: Prompt line with options

    Returns:
        Dict[str, Any]: Dictionary of argument overrides. The keys
        ``end_image_path`` and ``end_image_mask_path`` collect repeated
        ``--ei`` / ``--eim`` options into lists and are omitted when unused.
    """
    # TODO common function with hv_train_network.line_to_prompt_dict
    chunks = line.split(" --")

    overrides: Dict[str, Any] = {"prompt": chunks[0].strip()}
    # Accumulators for options that may appear multiple times.
    overrides["end_image_path"] = []
    overrides["end_image_mask_path"] = []

    # short option -> (destination argument name, value converter)
    scalar_options = {
        "w": ("video_size_width", int),
        "h": ("video_size_height", int),
        "f": ("video_seconds", float),
        "d": ("seed", int),
        "s": ("infer_steps", int),
        "g": ("guidance_scale", float),
        "l": ("guidance_scale", float),  # alias of "g"
        "i": ("image_path", str),
        "im": ("image_mask_path", str),
        "n": ("negative_prompt", str),
        "vs": ("video_sections", int),
        "of": ("one_frame_inference", str),
    }
    list_options = {"ei": "end_image_path", "eim": "end_image_mask_path"}

    for chunk in chunks[1:]:
        if not chunk.strip():
            continue
        option, _, raw_value = chunk.partition(" ")
        option = option.strip()
        value = raw_value.strip()

        if option in scalar_options:
            key, convert = scalar_options[option]
            overrides[key] = convert(value)
        elif option in list_options:
            overrides[list_options[option]].append(value)
        # unknown options are silently ignored, matching the original behavior

    # Drop accumulator keys that received no values.
    for key in ("end_image_path", "end_image_mask_path"):
        if not overrides[key]:
            del overrides[key]

    return overrides
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
    """Return a deep copy of *args* with *overrides* applied.

    Args:
        args: Original arguments
        overrides: Dictionary of overrides

    Returns:
        argparse.Namespace: New arguments with overrides applied. The pseudo-keys
        ``video_size_width`` / ``video_size_height`` update the corresponding
        entry of ``video_size`` ([height, width]); all other keys are set as
        attributes verbatim. The original namespace is never mutated.
    """
    patched = copy.deepcopy(args)

    for name, value in overrides.items():
        if name == "video_size_width":
            patched.video_size[1] = value  # video_size is [height, width]
        elif name == "video_size_height":
            patched.video_size[0] = value
        else:
            setattr(patched, name, value)

    return patched
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]:
    """Validate video size and length.

    Args:
        args: command line arguments (reads video_size, video_seconds,
            video_sections, latent_window_size and fps)

    Returns:
        Tuple[int, int, float]: (height, width, video_seconds)

    Raises:
        ValueError: if height or width is not divisible by 8.
    """
    # FIX: the return annotation previously said Tuple[int, int, int] while the
    # docstring and the actual value use a float for video_seconds.
    height = args.video_size[0]
    width = args.video_size[1]

    video_seconds = args.video_seconds
    if args.video_sections is not None:
        # Each section contributes latent_window_size * 4 frames, plus one initial frame.
        video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    return height, width, video_seconds
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
# region DiT model
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
    """Load the packed FramePack DiT model.

    Args:
        args: command line arguments
        device: device to use

    Returns:
        HunyuanVideoTransformer3DModelPacked: DiT model

    Weights are loaded directly onto *device* only when no block swapping, no
    scaled-fp8 optimization and no LoRA weights are requested; otherwise they
    are first loaded on CPU.
    """
    can_load_on_device = args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None
    loading_device = device if can_load_on_device else "cpu"

    # fp8 optimization is not applied here because LoRA weights may be merged afterwards.
    return load_packed_model(device, args.dit, args.attn_mode, loading_device)
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
    """Optimize the model in place (FP8 conversion, device placement, block swap).

    Args:
        model: dit model
        args: command line arguments (reads fp8, fp8_scaled, blocks_to_swap)
        device: device to use

    After this call the model is in eval mode with gradients disabled.
    """
    if args.fp8_scaled:
        # load state dict as-is and optimize to fp8
        state_dict = model.state_dict()

        # if no blocks to swap, we can move the weights to GPU after optimization on GPU (omit redundant CPU->GPU copy)
        move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, we will keep the model on CPU
        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False)  # args.fp8_fast)

        info = model.load_state_dict(state_dict, strict=True, assign=True)
        logger.info(f"Loaded FP8 optimized weights: {info}")

        if args.blocks_to_swap == 0:
            model.to(device)  # make sure all parameters are on the right device (e.g. RoPE etc.)
    else:
        # simple cast to dit_dtype; None means load as-is (keep the dtype of the state_dict weights)
        target_dtype = None
        target_device = None

        if args.fp8:
            # FIX: was `torch.float8e4m3fn`, which does not exist and raised
            # AttributeError; the correct attribute name is `torch.float8_e4m3fn`.
            target_dtype = torch.float8_e4m3fn

        if args.blocks_to_swap == 0:
            logger.info(f"Move model to device: {device}")
            target_device = device

        if target_device is not None and target_dtype is not None:
            model.to(target_device, target_dtype)  # move and cast at the same time. this reduces redundant copy operations

    if args.blocks_to_swap > 0:
        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()
    else:
        # make sure the model is on the right device
        model.to(device)

    model.eval().requires_grad_(False)
    clean_memory_on_device(device)
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
# endregion
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def decode_latent(
    latent_window_size: int,
    total_latent_sections: int,
    bulk_decode: bool,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
    one_frame_inference_mode: bool = False,
) -> torch.Tensor:
    """Decode generated latents to pixel space with the VAE.

    Args:
        latent_window_size: number of latent frames per generation section (default 9)
        total_latent_sections: number of sections that were generated
        bulk_decode: if True, decode the whole latent in a single VAE call
        vae: causal 3D VAE used for decoding
        latent: latent tensor, NCFHW (a batch dim is added if it is CFHW)
        device: device to run the VAE on
        one_frame_inference_mode: if True, decode each latent frame independently

    Returns:
        torch.Tensor: decoded pixels with the batch dimension removed
    """
    logger.info("Decoding video...")
    if latent.ndim == 4:
        latent = latent.unsqueeze(0)  # add batch dimension

    vae.to(device)
    if not bulk_decode and not one_frame_inference_mode:
        # Section-by-section decode, soft-blending the overlapped frames between sections.
        num_frames = latent_window_size * 4 - 3

        # Split the full latent into per-section chunks. Sections were generated in
        # reverse (inverted anti-drifting) order, so walk from the last section back.
        latents_to_decode = []
        latent_frame_index = 0
        for i in range(total_latent_sections - 1, -1, -1):
            is_last_section = i == total_latent_sections - 1
            generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
            section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)

            section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
            if section_latent.shape[2] > 0:
                latents_to_decode.append(section_latent)

            latent_frame_index += generated_latent_frames

        latents_to_decode = latents_to_decode[::-1]  # reverse the order of latents to decode

        history_pixels = None
        # NOTE: loop variable renamed from `latent` — the original shadowed the parameter.
        for section_latent in tqdm(latents_to_decode):
            if history_pixels is None:
                history_pixels = hunyuan.vae_decode(section_latent, vae).cpu()
            else:
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = hunyuan.vae_decode(section_latent, vae).cpu()
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
            clean_memory_on_device(device)
    else:
        logger.info("Bulk decoding or one frame inference")
        if not one_frame_inference_mode:
            history_pixels = hunyuan.vae_decode(latent, vae).cpu()  # normal bulk decode
        else:
            # one frame inference: decode every latent frame independently, then concat
            history_pixels = [hunyuan.vae_decode(latent[:, :, i : i + 1, :, :], vae).cpu() for i in range(latent.shape[2])]
            history_pixels = torch.cat(history_pixels, dim=2)
    vae.to("cpu")  # release GPU memory held by the VAE

    logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
    return history_pixels[0]  # remove batch dimension
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def prepare_i2v_inputs(
    args: argparse.Namespace,
    device: torch.device,
    vae: AutoencoderKLCausal3D,
    shared_models: Optional[Dict] = None,
) -> Tuple[int, int, float, dict, dict, dict, Optional[list]]:
    """Prepare all conditioning inputs for I2V generation.

    Encodes prompts with the text encoders, section images with the CLIP vision
    encoder, and start/end images with the VAE. Models are loaded on demand
    unless `shared_models` provides pre-loaded ones.

    Args:
        args: command line arguments
        device: device to use
        vae: VAE model, used for image encoding
        shared_models: dictionary containing pre-loaded models (optional)

    Returns:
        Tuple: (height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latents)
        where arg_c maps section index -> prompt conditioning, arg_null is the
        negative conditioning, arg_c_img maps section index -> image conditioning,
        and end_latents is a list of VAE-encoded end images (or None).
    """

    height, width, video_seconds = check_inputs(args)

    # define parsing function
    def parse_section_strings(input_string: str) -> dict[int, str]:
        # Parse "index:value;;;index2:value2" into {index: value}. An index may be a
        # range "a-b" or negative (counted from the last section); a missing/bad
        # index falls back to section 0.
        section_strings = {}
        if ";;;" in input_string:
            split_section_strings = input_string.split(";;;")
            for section_str in split_section_strings:
                if ":" not in section_str:
                    start = end = 0
                    section_str = section_str.strip()
                else:
                    index_str, section_str = section_str.split(":", 1)
                    index_str = index_str.strip()
                    section_str = section_str.strip()

                    m = re.match(r"^(-?\d+)(-\d+)?$", index_str)
                    if m:
                        start = int(m.group(1))
                        end = int(m.group(2)[1:]) if m.group(2) is not None else start
                    else:
                        start = end = 0
                        section_str = section_str.strip()
                for i in range(start, end + 1):
                    section_strings[i] = section_str
        else:
            section_strings[0] = input_string

        # ensure section 0 exists: prefer the smallest non-negative index, else the
        # smallest negative one
        if 0 not in section_strings:
            indices = list(section_strings.keys())
            if all(i < 0 for i in indices):
                section_index = min(indices)
            else:
                section_index = min(i for i in indices if i >= 0)
            section_strings[0] = section_strings[section_index]
        return section_strings

    # prepare image
    def preprocess_image(image_path: str):
        # Load an image, fit it to the (width, height) bucket, and normalize to [-1, 1].
        image = Image.open(image_path).convert("RGB")

        image_np = np.array(image)  # PIL to numpy, HWC

        image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
        image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0  # -1 to 1.0, HWC
        image_tensor = image_tensor.permute(2, 0, 1)[None, :, None]  # HWC -> CHW -> NCFHW, N=1, C=3, F=1
        return image_tensor, image_np

    section_image_paths = parse_section_strings(args.image_path)

    section_images = {}
    for index, image_path in section_image_paths.items():
        img_tensor, img_np = preprocess_image(image_path)
        section_images[index] = (img_tensor, img_np)

    # check end images
    if args.end_image_path is not None and len(args.end_image_path) > 0:
        end_image_tensors = []
        for end_img_path in args.end_image_path:
            end_image_tensor, _ = preprocess_image(end_img_path)
            end_image_tensors.append(end_image_tensor)
    else:
        end_image_tensors = None

    # configure negative prompt
    n_prompt = args.negative_prompt if args.negative_prompt else ""

    # parse section prompts
    section_prompts = parse_section_strings(args.prompt)

    # load text encoder
    if shared_models is not None:
        tokenizer1, text_encoder1 = shared_models["tokenizer1"], shared_models["text_encoder1"]
        tokenizer2, text_encoder2 = shared_models["tokenizer2"], shared_models["text_encoder2"]
        # NOTE(review): only text_encoder1 is moved here; the non-shared path loads it
        # directly onto `device` via load_text_encoder1 — confirm te2 handling matches.
        text_encoder1.to(device)
    else:
        tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
        tokenizer2, text_encoder2 = load_text_encoder2(args)
    text_encoder2.to(device)

    logger.info(f"Encoding prompt")
    llama_vecs = {}
    llama_attention_masks = {}
    clip_l_poolers = {}
    with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
        for index, prompt in section_prompts.items():
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(
                prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
            )
            llama_vec = llama_vec.cpu()
            clip_l_pooler = clip_l_pooler.cpu()

            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

            llama_vecs[index] = llama_vec
            llama_attention_masks[index] = llama_attention_mask
            clip_l_poolers[index] = clip_l_pooler

    # negative conditioning: with CFG scale 1.0 the negative branch is never used,
    # so zeros are a cheap stand-in
    if args.guidance_scale == 1.0:
        llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
    else:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(
                n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
            )
        llama_vec_n = llama_vec_n.cpu()
        clip_l_pooler_n = clip_l_pooler_n.cpu()

    llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

    # free text encoder and clean memory
    if shared_models is not None:  # if shared models are used, do not free them but move to CPU
        text_encoder1.to("cpu")
        text_encoder2.to("cpu")
    del tokenizer1, text_encoder1, tokenizer2, text_encoder2  # drops local refs only; shared dict keeps them alive
    clean_memory_on_device(device)

    # load image encoder
    if shared_models is not None:
        feature_extractor, image_encoder = shared_models["feature_extractor"], shared_models["image_encoder"]
    else:
        feature_extractor, image_encoder = load_image_encoders(args)
    image_encoder.to(device)

    # encode image with image encoder
    section_image_encoder_last_hidden_states = {}
    for index, (img_tensor, img_np) in section_images.items():
        with torch.no_grad():
            image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu()
        section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state

    # free image encoder and clean memory
    if shared_models is not None:
        image_encoder.to("cpu")
    del image_encoder, feature_extractor
    clean_memory_on_device(device)

    # VAE encoding
    logger.info(f"Encoding image to latent space")
    vae.to(device)

    section_start_latents = {}
    for index, (img_tensor, img_np) in section_images.items():
        start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
        section_start_latents[index] = start_latent

    if end_image_tensors is not None:
        end_latents = []
        for end_image_tensor in end_image_tensors:
            end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu()
            end_latents.append(end_latent)
    else:
        end_latents = None

    vae.to("cpu")  # move VAE to CPU to save memory
    clean_memory_on_device(device)

    # prepare model input arguments
    arg_c = {}
    arg_null = {}
    for index in llama_vecs.keys():
        llama_vec = llama_vecs[index]
        llama_attention_mask = llama_attention_masks[index]
        clip_l_pooler = clip_l_poolers[index]
        arg_c_i = {
            "llama_vec": llama_vec,
            "llama_attention_mask": llama_attention_mask,
            "clip_l_pooler": clip_l_pooler,
            "prompt": section_prompts[index],  # for debugging
        }
        arg_c[index] = arg_c_i

    arg_null = {
        "llama_vec": llama_vec_n,
        "llama_attention_mask": llama_attention_mask_n,
        "clip_l_pooler": clip_l_pooler_n,
    }

    arg_c_img = {}
    for index in section_images.keys():
        image_encoder_last_hidden_state = section_image_encoder_last_hidden_states[index]
        start_latent = section_start_latents[index]
        arg_c_img_i = {
            "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
            "start_latent": start_latent,
            "image_path": section_image_paths[index],
        }
        arg_c_img[index] = arg_c_img_i

    return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latents
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
# def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
|
| 717 |
+
# """setup scheduler for sampling
|
| 718 |
+
|
| 719 |
+
# Args:
|
| 720 |
+
# args: command line arguments
|
| 721 |
+
# config: model configuration
|
| 722 |
+
# device: device to use
|
| 723 |
+
|
| 724 |
+
# Returns:
|
| 725 |
+
# Tuple[Any, torch.Tensor]: (scheduler, timesteps)
|
| 726 |
+
# """
|
| 727 |
+
# if args.sample_solver == "unipc":
|
| 728 |
+
# scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
|
| 729 |
+
# scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
|
| 730 |
+
# timesteps = scheduler.timesteps
|
| 731 |
+
# elif args.sample_solver == "dpm++":
|
| 732 |
+
# scheduler = FlowDPMSolverMultistepScheduler(
|
| 733 |
+
# num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
|
| 734 |
+
# )
|
| 735 |
+
# sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
|
| 736 |
+
# timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
|
| 737 |
+
# elif args.sample_solver == "vanilla":
|
| 738 |
+
# scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
|
| 739 |
+
# scheduler.set_timesteps(args.infer_steps, device=device)
|
| 740 |
+
# timesteps = scheduler.timesteps
|
| 741 |
+
|
| 742 |
+
# # FlowMatchDiscreteScheduler does not support generator argument in step method
|
| 743 |
+
# org_step = scheduler.step
|
| 744 |
+
|
| 745 |
+
# def step_wrapper(
|
| 746 |
+
# model_output: torch.Tensor,
|
| 747 |
+
# timestep: Union[int, torch.Tensor],
|
| 748 |
+
# sample: torch.Tensor,
|
| 749 |
+
# return_dict: bool = True,
|
| 750 |
+
# generator=None,
|
| 751 |
+
# ):
|
| 752 |
+
# return org_step(model_output, timestep, sample, return_dict=return_dict)
|
| 753 |
+
|
| 754 |
+
# scheduler.step = step_wrapper
|
| 755 |
+
# else:
|
| 756 |
+
# raise NotImplementedError("Unsupported solver.")
|
| 757 |
+
|
| 758 |
+
# return scheduler, timesteps
|
| 759 |
+
|
| 760 |
+
|
| 761 |
+
def convert_lora_for_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Normalize a LoRA state dict to the FramePack (Musubi Tuner) key format.

    The incoming format is detected from the keys:
      - keys starting with "lora_unet_": already Musubi Tuner format, kept as-is
      - "diffusion_model."/"transformer." prefixed keys with "lora_A"/"lora_B":
        diffusion-pipe style, converted via convert_lora_from_diffusion_pipe_or_something
      - anything else is used unchanged (logged)
    Afterwards, HunyuanVideo LoRAs (keys containing double_blocks/single_blocks)
    are converted to FramePack module names.
    """
    if not lora_sd:
        # empty state dict: nothing to convert (also avoids IndexError on keys[0])
        return lora_sd

    # Check the format of the LoRA file
    keys = list(lora_sd.keys())
    if keys[0].startswith("lora_unet_"):
        # Musubi Tuner LoRA: already in the expected format
        pass
    else:
        transformer_prefixes = ["diffusion_model", "transformer"]  # to ignore Text Encoder modules
        lora_suffix = None
        prefix = None
        for key in keys:
            if lora_suffix is None and "lora_A" in key:
                lora_suffix = "lora_A"
            if prefix is None:
                pfx = key.split(".")[0]
                if pfx in transformer_prefixes:
                    prefix = pfx
            if lora_suffix is not None and prefix is not None:
                break

        if lora_suffix == "lora_A" and prefix is not None:
            logging.info(f"Diffusion-pipe (?) LoRA detected, converting to the default LoRA format")
            lora_sd = convert_lora_from_diffusion_pipe_or_something(lora_sd, "lora_unet_")
        else:
            logging.info(f"LoRA file format not recognized. Using it as-is.")

    # Check LoRA is for FramePack or for HunyuanVideo (same arch, different module names)
    is_hunyuan = any("double_blocks" in key or "single_blocks" in key for key in lora_sd.keys())
    if is_hunyuan:
        logging.info("HunyuanVideo LoRA detected, converting to FramePack format")
        lora_sd = convert_hunyuan_to_framepack(lora_sd)

    return lora_sd
|
| 800 |
+
|
| 801 |
+
|
| 802 |
+
def convert_lora_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner.
    Copy from Musubi Tuner repo.
    """
    # Diffusers-style keys look like "diffusion_model.module.name.lora_A.weight";
    # the default LoRA format is "{prefix}module_name.lora_down.weight" /
    # ".lora_up.weight" plus a per-module ".alpha" entry (Diffusers has no alpha,
    # so alpha is set to the rank).
    converted: dict[str, torch.Tensor] = {}
    rank_by_module: dict[str, int] = {}

    for original_key, weight in lora_sd.items():
        pipe_prefix, remainder = original_key.split(".", 1)
        if pipe_prefix not in ("diffusion_model", "transformer"):
            print(f"unexpected key: {original_key} in diffusers format")
            continue

        renamed = (
            f"{prefix}{remainder}"
            .replace(".", "_")
            .replace("_lora_A_", ".lora_down.")
            .replace("_lora_B_", ".lora_up.")
        )
        converted[renamed] = weight

        # record the rank (rows of the down matrix) once per module for the alpha entry
        module_name = renamed.split(".")[0]
        if "lora_down" in renamed and module_name not in rank_by_module:
            rank_by_module[module_name] = weight.shape[0]

    # add alpha with rank
    for module_name, rank in rank_by_module.items():
        converted[f"{module_name}.alpha"] = torch.tensor(rank)

    return converted
|
| 832 |
+
|
| 833 |
+
|
| 834 |
+
def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert HunyuanVideo LoRA weights to FramePack format.
    """
    # module-name rewrites for double (img/txt) blocks; fused QKV is split afterwards
    double_block_renames = (
        ("double_blocks", "transformer_blocks"),
        ("img_mod_linear", "norm1_linear"),
        ("img_attn_qkv", "attn_to_QKV"),  # split later
        ("img_attn_proj", "attn_to_out_0"),
        ("img_mlp_fc1", "ff_net_0_proj"),
        ("img_mlp_fc2", "ff_net_2"),
        ("txt_mod_linear", "norm1_context_linear"),
        ("txt_attn_qkv", "attn_add_QKV_proj"),  # split later
        ("txt_attn_proj", "attn_to_add_out"),
        ("txt_mlp_fc1", "ff_context_net_0_proj"),
        ("txt_mlp_fc2", "ff_context_net_2"),
    )
    # module-name rewrites for single blocks; fused QKV+MLP (QKVM) is split afterwards
    single_block_renames = (
        ("single_blocks", "single_transformer_blocks"),
        ("linear1", "attn_to_QKVM"),  # split later
        ("linear2", "proj_out"),
        ("modulation_linear", "norm_linear"),
    )

    new_lora_sd: dict[str, torch.Tensor] = {}
    for key, weight in lora_sd.items():
        if "double_blocks" in key:
            for old, new in double_block_renames:
                key = key.replace(old, new)
        elif "single_blocks" in key:
            for old, new in single_block_renames:
                key = key.replace(old, new)
        else:
            print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
            continue

        if "QKVM" in key:
            # split fused QKVM into Q, K, V and the MLP projection
            targets = [key.replace("QKVM", c) for c in ("q", "k", "v")]
            targets.append(key.replace("attn_to_QKVM", "proj_mlp"))
            if "_down" in key or "alpha" in key:
                # down weights and alphas are shared verbatim across the four targets
                assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
                for target in targets:
                    new_lora_sd[target] = weight
            elif "_up" in key:
                # up weight rows are partitioned: 3x3072 for Q/K/V, the rest (12288) for MLP
                assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
                pieces = (weight[:3072], weight[3072 : 3072 * 2], weight[3072 * 2 : 3072 * 3], weight[3072 * 3 :])
                for target, piece in zip(targets, pieces):
                    new_lora_sd[target] = piece
            else:
                print(f"Unsupported module name: {key}")
                continue
        elif "QKV" in key:
            # split fused QKV into Q, K, V
            targets = [key.replace("QKV", c) for c in ("q", "k", "v")]
            if "_down" in key or "alpha" in key:
                # down weights and alphas are shared verbatim across the three targets
                assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
                for target in targets:
                    new_lora_sd[target] = weight
            elif "_up" in key:
                # up weight rows split evenly into three 3072-row chunks
                assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
                pieces = (weight[:3072], weight[3072 : 3072 * 2], weight[3072 * 2 :])
                for target, piece in zip(targets, pieces):
                    new_lora_sd[target] = piece
            else:
                print(f"Unsupported module name: {key}")
                continue
        else:
            # no split needed
            new_lora_sd[key] = weight

    return new_lora_sd
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def generate(
|
| 912 |
+
args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None
|
| 913 |
+
) -> tuple[AutoencoderKLCausal3D, torch.Tensor]:
|
| 914 |
+
"""main function for generation
|
| 915 |
+
|
| 916 |
+
Args:
|
| 917 |
+
args: command line arguments
|
| 918 |
+
shared_models: dictionary containing pre-loaded models
|
| 919 |
+
|
| 920 |
+
Returns:
|
| 921 |
+
tuple: (AutoencoderKLCausal3D model (vae), torch.Tensor generated latent)
|
| 922 |
+
"""
|
| 923 |
+
device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)
|
| 924 |
+
|
| 925 |
+
# prepare seed
|
| 926 |
+
seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
|
| 927 |
+
args.seed = seed # set seed to args for saving
|
| 928 |
+
|
| 929 |
+
# Check if we have shared models
|
| 930 |
+
if shared_models is not None:
|
| 931 |
+
# Use shared models and encoded data
|
| 932 |
+
vae = shared_models.get("vae")
|
| 933 |
+
height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(
|
| 934 |
+
args, device, vae, shared_models
|
| 935 |
+
)
|
| 936 |
+
else:
|
| 937 |
+
# prepare inputs without shared models
|
| 938 |
+
vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
|
| 939 |
+
height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(args, device, vae)
|
| 940 |
+
|
| 941 |
+
if shared_models is None or "model" not in shared_models:
|
| 942 |
+
# load DiT model
|
| 943 |
+
model = load_dit_model(args, device)
|
| 944 |
+
|
| 945 |
+
# merge LoRA weights
|
| 946 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
| 947 |
+
# ugly hack to common merge_lora_weights function
|
| 948 |
+
merge_lora_weights(lora_framepack, model, args, device, convert_lora_for_framepack)
|
| 949 |
+
|
| 950 |
+
# if we only want to save the model, we can skip the rest
|
| 951 |
+
if args.save_merged_model:
|
| 952 |
+
return None, None
|
| 953 |
+
|
| 954 |
+
# optimize model: fp8 conversion, block swap etc.
|
| 955 |
+
optimize_model(model, args, device)
|
| 956 |
+
|
| 957 |
+
if shared_models is not None:
|
| 958 |
+
shared_models["model"] = model
|
| 959 |
+
else:
|
| 960 |
+
# use shared model
|
| 961 |
+
model: HunyuanVideoTransformer3DModelPacked = shared_models["model"]
|
| 962 |
+
model.move_to_device_except_swap_blocks(device)
|
| 963 |
+
model.prepare_block_swap_before_forward()
|
| 964 |
+
|
| 965 |
+
# sampling
|
| 966 |
+
latent_window_size = args.latent_window_size # default is 9
|
| 967 |
+
# ex: (5s * 30fps) / (9 * 4) = 4.16 -> 4 sections, 60s -> 1800 / 36 = 50 sections
|
| 968 |
+
total_latent_sections = (video_seconds * 30) / (latent_window_size * 4)
|
| 969 |
+
total_latent_sections = int(max(round(total_latent_sections), 1))
|
| 970 |
+
|
| 971 |
+
# set random generator
|
| 972 |
+
seed_g = torch.Generator(device="cpu")
|
| 973 |
+
seed_g.manual_seed(seed)
|
| 974 |
+
num_frames = latent_window_size * 4 - 3
|
| 975 |
+
|
| 976 |
+
logger.info(
|
| 977 |
+
f"Video size: {height}x{width}@{video_seconds} (HxW@seconds), fps: {args.fps}, num sections: {total_latent_sections}, "
|
| 978 |
+
f"infer_steps: {args.infer_steps}, frames per generation: {num_frames}"
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
# video generation ######
|
| 982 |
+
f1_mode = args.f1
|
| 983 |
+
one_frame_inference = None
|
| 984 |
+
if args.one_frame_inference is not None:
|
| 985 |
+
one_frame_inference = set()
|
| 986 |
+
for mode in args.one_frame_inference.split(","):
|
| 987 |
+
one_frame_inference.add(mode.strip())
|
| 988 |
+
|
| 989 |
+
# prepare history latents
|
| 990 |
+
history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
|
| 991 |
+
if end_latents is not None and not f1_mode:
|
| 992 |
+
logger.info(f"Use end image(s): {args.end_image_path}")
|
| 993 |
+
for i, end_latent in enumerate(end_latents):
|
| 994 |
+
history_latents[:, :, i + 1 : i + 2] = end_latent.to(history_latents)
|
| 995 |
+
|
| 996 |
+
# prepare clean latents and indices
|
| 997 |
+
if not f1_mode:
|
| 998 |
+
# Inverted Anti-drifting
|
| 999 |
+
total_generated_latent_frames = 0
|
| 1000 |
+
latent_paddings = reversed(range(total_latent_sections))
|
| 1001 |
+
|
| 1002 |
+
if total_latent_sections > 4 and one_frame_inference is None:
|
| 1003 |
+
# In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
|
| 1004 |
+
# items looks better than expanding it when total_latent_sections > 4
|
| 1005 |
+
# One can try to remove below trick and just
|
| 1006 |
+
# use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
|
| 1007 |
+
# 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
|
| 1008 |
+
latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
|
| 1009 |
+
|
| 1010 |
+
if args.latent_paddings is not None:
|
| 1011 |
+
# parse user defined latent paddings
|
| 1012 |
+
user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
|
| 1013 |
+
if len(user_latent_paddings) < total_latent_sections:
|
| 1014 |
+
print(
|
| 1015 |
+
f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
|
| 1016 |
+
)
|
| 1017 |
+
print(f"Use default paddings instead for unspecified sections.")
|
| 1018 |
+
latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
|
| 1019 |
+
elif len(user_latent_paddings) > total_latent_sections:
|
| 1020 |
+
print(
|
| 1021 |
+
f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
|
| 1022 |
+
)
|
| 1023 |
+
print(f"Use only first {total_latent_sections} paddings instead.")
|
| 1024 |
+
latent_paddings = user_latent_paddings[:total_latent_sections]
|
| 1025 |
+
else:
|
| 1026 |
+
latent_paddings = user_latent_paddings
|
| 1027 |
+
else:
|
| 1028 |
+
start_latent = context_img[0]["start_latent"]
|
| 1029 |
+
history_latents = torch.cat([history_latents, start_latent], dim=2)
|
| 1030 |
+
total_generated_latent_frames = 1 # a bit hacky, but we employ the same logic as in official code
|
| 1031 |
+
latent_paddings = [0] * total_latent_sections # dummy paddings for F1 mode
|
| 1032 |
+
|
| 1033 |
+
latent_paddings = list(latent_paddings) # make sure it's a list
|
| 1034 |
+
for loop_index in range(total_latent_sections):
|
| 1035 |
+
latent_padding = latent_paddings[loop_index]
|
| 1036 |
+
|
| 1037 |
+
if not f1_mode:
|
| 1038 |
+
# Inverted Anti-drifting
|
| 1039 |
+
section_index_reverse = loop_index # 0, 1, 2, 3
|
| 1040 |
+
section_index = total_latent_sections - 1 - section_index_reverse # 3, 2, 1, 0
|
| 1041 |
+
section_index_from_last = -(section_index_reverse + 1) # -1, -2, -3, -4
|
| 1042 |
+
|
| 1043 |
+
is_last_section = section_index == 0
|
| 1044 |
+
is_first_section = section_index_reverse == 0
|
| 1045 |
+
latent_padding_size = latent_padding * latent_window_size
|
| 1046 |
+
|
| 1047 |
+
logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
|
| 1048 |
+
else:
|
| 1049 |
+
section_index = loop_index # 0, 1, 2, 3
|
| 1050 |
+
section_index_from_last = section_index - total_latent_sections # -4, -3, -2, -1
|
| 1051 |
+
is_last_section = loop_index == total_latent_sections - 1
|
| 1052 |
+
is_first_section = loop_index == 0
|
| 1053 |
+
latent_padding_size = 0 # dummy padding for F1 mode
|
| 1054 |
+
|
| 1055 |
+
# select start latent
|
| 1056 |
+
if section_index_from_last in context_img:
|
| 1057 |
+
image_index = section_index_from_last
|
| 1058 |
+
elif section_index in context_img:
|
| 1059 |
+
image_index = section_index
|
| 1060 |
+
else:
|
| 1061 |
+
image_index = 0
|
| 1062 |
+
|
| 1063 |
+
start_latent = context_img[image_index]["start_latent"]
|
| 1064 |
+
image_path = context_img[image_index]["image_path"]
|
| 1065 |
+
if image_index != 0: # use section image other than section 0
|
| 1066 |
+
logger.info(f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}")
|
| 1067 |
+
|
| 1068 |
+
if not f1_mode:
|
| 1069 |
+
# Inverted Anti-drifting
|
| 1070 |
+
indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
|
| 1071 |
+
(
|
| 1072 |
+
clean_latent_indices_pre,
|
| 1073 |
+
blank_indices,
|
| 1074 |
+
latent_indices,
|
| 1075 |
+
clean_latent_indices_post,
|
| 1076 |
+
clean_latent_2x_indices,
|
| 1077 |
+
clean_latent_4x_indices,
|
| 1078 |
+
) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
|
| 1079 |
+
clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
|
| 1080 |
+
|
| 1081 |
+
clean_latents_pre = start_latent.to(history_latents)
|
| 1082 |
+
clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
|
| 1083 |
+
[1, 2, 16], dim=2
|
| 1084 |
+
)
|
| 1085 |
+
clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)
|
| 1086 |
+
|
| 1087 |
+
if end_latents is not None:
|
| 1088 |
+
clean_latents = torch.cat([clean_latents_pre, history_latents[:, :, : len(end_latents)]], dim=2)
|
| 1089 |
+
clean_latent_indices_extended = torch.zeros(1, 1 + len(end_latents), dtype=clean_latent_indices.dtype)
|
| 1090 |
+
clean_latent_indices_extended[:, :2] = clean_latent_indices
|
| 1091 |
+
clean_latent_indices = clean_latent_indices_extended
|
| 1092 |
+
|
| 1093 |
+
else:
|
| 1094 |
+
# F1 mode
|
| 1095 |
+
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
|
| 1096 |
+
(
|
| 1097 |
+
clean_latent_indices_start,
|
| 1098 |
+
clean_latent_4x_indices,
|
| 1099 |
+
clean_latent_2x_indices,
|
| 1100 |
+
clean_latent_1x_indices,
|
| 1101 |
+
latent_indices,
|
| 1102 |
+
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
|
| 1103 |
+
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
|
| 1104 |
+
|
| 1105 |
+
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
|
| 1106 |
+
[16, 2, 1], dim=2
|
| 1107 |
+
)
|
| 1108 |
+
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
|
| 1109 |
+
|
| 1110 |
+
# if use_teacache:
|
| 1111 |
+
# transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
|
| 1112 |
+
# else:
|
| 1113 |
+
# transformer.initialize_teacache(enable_teacache=False)
|
| 1114 |
+
|
| 1115 |
+
# prepare conditioning inputs
|
| 1116 |
+
if section_index_from_last in context:
|
| 1117 |
+
prompt_index = section_index_from_last
|
| 1118 |
+
elif section_index in context:
|
| 1119 |
+
prompt_index = section_index
|
| 1120 |
+
else:
|
| 1121 |
+
prompt_index = 0
|
| 1122 |
+
|
| 1123 |
+
context_for_index = context[prompt_index]
|
| 1124 |
+
# if args.section_prompts is not None:
|
| 1125 |
+
logger.info(f"Section {section_index}: {context_for_index['prompt']}")
|
| 1126 |
+
|
| 1127 |
+
llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
|
| 1128 |
+
llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
|
| 1129 |
+
clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
| 1130 |
+
|
| 1131 |
+
image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
|
| 1132 |
+
device, dtype=torch.bfloat16
|
| 1133 |
+
)
|
| 1134 |
+
|
| 1135 |
+
llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
|
| 1136 |
+
llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
|
| 1137 |
+
clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)
|
| 1138 |
+
|
| 1139 |
+
# call DiT model to generate latents
|
| 1140 |
+
sample_num_frames = num_frames
|
| 1141 |
+
if one_frame_inference is not None:
|
| 1142 |
+
# one frame inference
|
| 1143 |
+
latent_indices = latent_indices[:, -1:] # only use the last frame (default)
|
| 1144 |
+
sample_num_frames = 1
|
| 1145 |
+
|
| 1146 |
+
def get_latent_mask(mask_path: str):
|
| 1147 |
+
mask_image = Image.open(mask_path).convert("L") # grayscale
|
| 1148 |
+
mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
|
| 1149 |
+
mask_image = np.array(mask_image) # PIL to numpy, HWC
|
| 1150 |
+
mask_image = torch.from_numpy(mask_image).float() / 255.0 # 0 to 1.0, HWC
|
| 1151 |
+
mask_image = mask_image.squeeze(-1) # HWC -> HW
|
| 1152 |
+
mask_image = mask_image.unsqueeze(0).unsqueeze(0) # HW -> 11HW
|
| 1153 |
+
mask_image = mask_image.to(clean_latents)
|
| 1154 |
+
return mask_image
|
| 1155 |
+
|
| 1156 |
+
if args.image_mask_path is not None:
|
| 1157 |
+
mask_image = get_latent_mask(args.image_mask_path)
|
| 1158 |
+
logger.info(f"Apply mask for clean latents (start image): {args.image_mask_path}, shape: {mask_image.shape}")
|
| 1159 |
+
clean_latents[:, :, 0, :, :] = clean_latents[:, :, 0, :, :] * mask_image
|
| 1160 |
+
if args.end_image_mask_path is not None and len(args.end_image_mask_path) > 0:
|
| 1161 |
+
# # apply mask for clean latents 1x (end image)
|
| 1162 |
+
count = min(len(args.end_image_mask_path), len(end_latents))
|
| 1163 |
+
for i in range(count):
|
| 1164 |
+
mask_image = get_latent_mask(args.end_image_mask_path[i])
|
| 1165 |
+
logger.info(
|
| 1166 |
+
f"Apply mask for clean latents 1x (end image) for {i+1}: {args.end_image_mask_path[i]}, shape: {mask_image.shape}"
|
| 1167 |
+
)
|
| 1168 |
+
clean_latents[:, :, i + 1 : i + 2, :, :] = clean_latents[:, :, i + 1 : i + 2, :, :] * mask_image
|
| 1169 |
+
|
| 1170 |
+
for one_frame_param in one_frame_inference:
|
| 1171 |
+
if one_frame_param.startswith("target_index="):
|
| 1172 |
+
target_index = int(one_frame_param.split("=")[1])
|
| 1173 |
+
latent_indices[:, 0] = target_index
|
| 1174 |
+
logger.info(f"Set index for target: {target_index}")
|
| 1175 |
+
elif one_frame_param.startswith("start_index="):
|
| 1176 |
+
start_index = int(one_frame_param.split("=")[1])
|
| 1177 |
+
clean_latent_indices[:, 0] = start_index
|
| 1178 |
+
logger.info(f"Set index for clean latent pre (start image): {start_index}")
|
| 1179 |
+
elif one_frame_param.startswith("history_index="):
|
| 1180 |
+
history_indices = one_frame_param.split("=")[1].split(";")
|
| 1181 |
+
i = 0
|
| 1182 |
+
while i < len(history_indices) and i < len(end_latents):
|
| 1183 |
+
history_index = int(history_indices[i])
|
| 1184 |
+
clean_latent_indices[:, 1 + i] = history_index
|
| 1185 |
+
i += 1
|
| 1186 |
+
while i < len(end_latents):
|
| 1187 |
+
clean_latent_indices[:, 1 + i] = history_index
|
| 1188 |
+
i += 1
|
| 1189 |
+
logger.info(f"Set index for clean latent post (end image): {history_indices}")
|
| 1190 |
+
|
| 1191 |
+
if "no_2x" in one_frame_inference:
|
| 1192 |
+
clean_latents_2x = None
|
| 1193 |
+
clean_latent_2x_indices = None
|
| 1194 |
+
logger.info(f"No clean_latents_2x")
|
| 1195 |
+
if "no_4x" in one_frame_inference:
|
| 1196 |
+
clean_latents_4x = None
|
| 1197 |
+
clean_latent_4x_indices = None
|
| 1198 |
+
logger.info(f"No clean_latents_4x")
|
| 1199 |
+
if "no_post" in one_frame_inference:
|
| 1200 |
+
clean_latents = clean_latents[:, :, :1, :, :]
|
| 1201 |
+
clean_latent_indices = clean_latent_indices[:, :1]
|
| 1202 |
+
logger.info(f"No clean_latents post")
|
| 1203 |
+
elif "zero_post" in one_frame_inference:
|
| 1204 |
+
# zero out the history latents. this seems to prevent the images from corrupting
|
| 1205 |
+
clean_latents[:, :, 1:, :, :] = torch.zeros_like(clean_latents[:, :, 1:, :, :])
|
| 1206 |
+
logger.info(f"Zero out clean_latents post")
|
| 1207 |
+
|
| 1208 |
+
logger.info(
|
| 1209 |
+
f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
|
| 1210 |
+
)
|
| 1211 |
+
|
| 1212 |
+
generated_latents = sample_hunyuan(
|
| 1213 |
+
transformer=model,
|
| 1214 |
+
sampler=args.sample_solver,
|
| 1215 |
+
width=width,
|
| 1216 |
+
height=height,
|
| 1217 |
+
frames=sample_num_frames,
|
| 1218 |
+
real_guidance_scale=args.guidance_scale,
|
| 1219 |
+
distilled_guidance_scale=args.embedded_cfg_scale,
|
| 1220 |
+
guidance_rescale=args.guidance_rescale,
|
| 1221 |
+
# shift=3.0,
|
| 1222 |
+
num_inference_steps=args.infer_steps,
|
| 1223 |
+
generator=seed_g,
|
| 1224 |
+
prompt_embeds=llama_vec,
|
| 1225 |
+
prompt_embeds_mask=llama_attention_mask,
|
| 1226 |
+
prompt_poolers=clip_l_pooler,
|
| 1227 |
+
negative_prompt_embeds=llama_vec_n,
|
| 1228 |
+
negative_prompt_embeds_mask=llama_attention_mask_n,
|
| 1229 |
+
negative_prompt_poolers=clip_l_pooler_n,
|
| 1230 |
+
device=device,
|
| 1231 |
+
dtype=torch.bfloat16,
|
| 1232 |
+
image_embeddings=image_encoder_last_hidden_state,
|
| 1233 |
+
latent_indices=latent_indices,
|
| 1234 |
+
clean_latents=clean_latents,
|
| 1235 |
+
clean_latent_indices=clean_latent_indices,
|
| 1236 |
+
clean_latents_2x=clean_latents_2x,
|
| 1237 |
+
clean_latent_2x_indices=clean_latent_2x_indices,
|
| 1238 |
+
clean_latents_4x=clean_latents_4x,
|
| 1239 |
+
clean_latent_4x_indices=clean_latent_4x_indices,
|
| 1240 |
+
)
|
| 1241 |
+
|
| 1242 |
+
# concatenate generated latents
|
| 1243 |
+
total_generated_latent_frames += int(generated_latents.shape[2])
|
| 1244 |
+
if not f1_mode:
|
| 1245 |
+
# Inverted Anti-drifting: prepend generated latents to history latents
|
| 1246 |
+
if is_last_section:
|
| 1247 |
+
generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
|
| 1248 |
+
total_generated_latent_frames += 1
|
| 1249 |
+
|
| 1250 |
+
history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
|
| 1251 |
+
real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
|
| 1252 |
+
else:
|
| 1253 |
+
# F1 mode: append generated latents to history latents
|
| 1254 |
+
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
|
| 1255 |
+
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
|
| 1256 |
+
|
| 1257 |
+
logger.info(f"Generated. Latent shape {real_history_latents.shape}")
|
| 1258 |
+
|
| 1259 |
+
# # TODO support saving intermediate video
|
| 1260 |
+
# clean_memory_on_device(device)
|
| 1261 |
+
# vae.to(device)
|
| 1262 |
+
# if history_pixels is None:
|
| 1263 |
+
# history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
|
| 1264 |
+
# else:
|
| 1265 |
+
# section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
|
| 1266 |
+
# overlapped_frames = latent_window_size * 4 - 3
|
| 1267 |
+
# current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
|
| 1268 |
+
# history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
|
| 1269 |
+
# vae.to("cpu")
|
| 1270 |
+
# # if not is_last_section:
|
| 1271 |
+
# # # save intermediate video
|
| 1272 |
+
# # save_video(history_pixels[0], args, total_generated_latent_frames)
|
| 1273 |
+
# print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")
|
| 1274 |
+
|
| 1275 |
+
if one_frame_inference is not None:
|
| 1276 |
+
real_history_latents = real_history_latents[:, :, 1:, :, :] # remove the first frame (start_latent)
|
| 1277 |
+
|
| 1278 |
+
# Only clean up shared models if they were created within this function
|
| 1279 |
+
if shared_models is None:
|
| 1280 |
+
del model # free memory
|
| 1281 |
+
synchronize_device(device)
|
| 1282 |
+
else:
|
| 1283 |
+
# move model to CPU to save memory
|
| 1284 |
+
model.to("cpu")
|
| 1285 |
+
|
| 1286 |
+
# wait for 5 seconds until block swap is done
|
| 1287 |
+
logger.info("Waiting for 5 seconds to finish block swap")
|
| 1288 |
+
time.sleep(5)
|
| 1289 |
+
|
| 1290 |
+
gc.collect()
|
| 1291 |
+
clean_memory_on_device(device)
|
| 1292 |
+
|
| 1293 |
+
return vae, real_history_latents
|
| 1294 |
+
|
| 1295 |
+
|
| 1296 |
+
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
    """Save a latent tensor to a safetensors file.

    Args:
        latent: Latent tensor (stored under the "latent" key)
        args: command line arguments (save_path, seed, generation parameters)
        height: height of frame in pixels
        width: width of frame in pixels

    Returns:
        str: Path to saved latent file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    video_seconds = args.video_seconds
    latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"

    if args.no_metadata:
        metadata = None
    else:
        # generation parameters are embedded so the latent can be decoded later
        # with consistent settings (see the latents_mode branch in main())
        metadata = {
            "seeds": f"{seed}",
            "prompt": f"{args.prompt}",
            "height": f"{height}",
            "width": f"{width}",
            "video_seconds": f"{video_seconds}",
            "infer_steps": f"{args.infer_steps}",
            "guidance_scale": f"{args.guidance_scale}",
            "latent_window_size": f"{args.latent_window_size}",
            "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
            "guidance_rescale": f"{args.guidance_rescale}",
            "sample_solver": f"{args.sample_solver}",
            "fps": f"{args.fps}",
        }
        if args.negative_prompt is not None:
            metadata["negative_prompt"] = f"{args.negative_prompt}"

    sd = {"latent": latent.contiguous()}
    save_file(sd, latent_path, metadata=metadata)
    logger.info(f"Latent saved to: {latent_path}")

    return latent_path
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
def save_video(
    video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
) -> str:
    """Write a video tensor to an mp4 file under ``args.save_path``.

    Args:
        video: video tensor; a batch dimension is added before saving
        args: command line arguments (save_path, seed, fps)
        original_base_name: original base name (if latents are loaded from files)
        latent_frames: latent frame count, appended to the file name when given

    Returns:
        str: Path to saved video file
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)

    timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    name_suffix = f"_{original_base_name}" if original_base_name is not None else ""
    frames_suffix = f"_{latent_frames}" if latent_frames is not None else ""
    video_path = f"{out_dir}/{timestamp}_{args.seed}{name_suffix}{frames_suffix}.mp4"

    save_videos_grid(video.unsqueeze(0), video_path, fps=args.fps, rescale=True)
    logger.info(f"Video saved to: {video_path}")

    return video_path
|
| 1371 |
+
|
| 1372 |
+
|
| 1373 |
+
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Save a video tensor as a grid of image files.

    Args:
        sample: video tensor; a batch dimension is added before saving
        args: command line arguments
        original_base_name: original base name (if latents are loaded from files)

    Returns:
        str: Path to saved images directory
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)

    timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
    suffix = f"_{original_base_name}" if original_base_name is not None else ""
    image_name = f"{timestamp}_{args.seed}{suffix}"

    # in one-frame mode images are written directly into save_path, without a subdirectory
    one_frame_mode = args.one_frame_inference is not None
    save_images_grid(sample.unsqueeze(0), out_dir, image_name, rescale=True, create_subdir=not one_frame_mode)
    logger.info(f"Sample images saved to: {out_dir}/{image_name}")

    return f"{out_dir}/{image_name}"
|
| 1397 |
+
|
| 1398 |
+
|
| 1399 |
+
def save_output(
    args: argparse.Namespace,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
    original_base_names: Optional[List[str]] = None,
) -> None:
    """Save the generated latent and/or its decoded video/images.

    Args:
        args: command line arguments (``output_type`` selects what is written)
        vae: VAE model used for decoding
        latent: latent tensor
        device: device to use for decoding
        original_base_names: original base names (if latents are loaded from files)
    """
    height, width = latent.shape[-2], latent.shape[-1]  # BCTHW
    # latent spatial size -> pixel size (VAE downscales by a factor of 8)
    height *= 8
    width *= 8
    # print(f"Saving output. Latent shape {latent.shape}; pixel shape {height}x{width}")
    if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
        # save latent
        save_latent(latent, args, height, width)
        if args.output_type == "latent":
            return

    total_latent_sections = (args.video_seconds * 30) / (args.latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))
    video = decode_latent(
        args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent, device, args.one_frame_inference is not None
    )

    # pass the bare base name: save_video/save_images add the "_" separator themselves.
    # (Previously a pre-formatted "_name" was passed, producing "__name" in file names,
    # and an empty string instead of None produced a stray trailing "_".)
    original_name = original_base_names[0] if original_base_names else None

    if args.output_type == "video" or args.output_type == "both":
        # save video
        save_video(video, args, original_name)

    elif args.output_type == "images" or args.output_type == "latent_images":
        # save images
        save_images(video, args, original_name)
|
| 1440 |
+
|
| 1441 |
+
|
| 1442 |
+
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
    """Process multiple prompts for batch mode.

    Args:
        prompt_lines: List of prompt lines
        base_args: Base command line arguments

    Returns:
        List[Dict]: List of prompt data dictionaries
    """
    prompts_data = []

    for raw_line in prompt_lines:
        stripped = raw_line.strip()

        # blank lines and '#' comments are ignored
        if not stripped or stripped.startswith("#"):
            continue

        # parse the line into an override dictionary
        prompt_data = parse_prompt_line(stripped)
        logger.info(f"Parsed prompt data: {prompt_data}")
        prompts_data.append(prompt_data)

    return prompts_data
|
| 1465 |
+
|
| 1466 |
+
|
| 1467 |
+
def load_shared_models(args: argparse.Namespace) -> Dict:
    """Load shared models for batch processing or interactive mode.
    Models are loaded to CPU to save memory.

    Args:
        args: Base command line arguments

    Returns:
        Dict: Dictionary of shared models
    """
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, "cpu")
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    feature_extractor, image_encoder = load_image_encoders(args)
    vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")

    return {
        "tokenizer1": tokenizer1,
        "text_encoder1": text_encoder1,
        "tokenizer2": tokenizer2,
        "text_encoder2": text_encoder2,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "vae": vae,
    }
|
| 1491 |
+
|
| 1492 |
+
|
| 1493 |
+
def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
    """Process multiple prompts with model reuse.

    Args:
        prompts_data: List of prompt data dictionaries
        args: Base command line arguments
    """
    if not prompts_data:
        logger.warning("No valid prompts found")
        return

    # 1. Load configuration
    gen_settings = get_generation_settings(args)
    device = gen_settings.device

    # 2. Load models to CPU in advance except for VAE and DiT
    shared_models = load_shared_models(args)

    # 3. Generate for each prompt
    all_latents = []
    all_prompt_args = []
    vae = None  # set by the first successful generate(); stays None if every prompt fails

    with torch.no_grad():
        for prompt_data in prompts_data:
            prompt = prompt_data["prompt"]
            prompt_args = apply_overrides(args, prompt_data)
            logger.info(f"Processing prompt: {prompt}")

            try:
                vae, latent = generate(prompt_args, gen_settings, shared_models)

                # Save latent if needed
                if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
                    height, width = latent.shape[-2], latent.shape[-1]  # BCTHW
                    height *= 8
                    width *= 8
                    save_latent(latent, prompt_args, height, width)

                all_latents.append(latent)
                all_prompt_args.append(prompt_args)
            except Exception as e:
                logger.error(f"Error processing prompt: {prompt}. Error: {e}")
                continue

    # 4. Free models
    if "model" in shared_models:
        del shared_models["model"]
    del shared_models["tokenizer1"]
    del shared_models["text_encoder1"]
    del shared_models["tokenizer2"]
    del shared_models["text_encoder2"]
    del shared_models["feature_extractor"]
    del shared_models["image_encoder"]

    clean_memory_on_device(device)
    synchronize_device(device)

    # 5. Decode latents if needed. Skip when there is nothing to decode
    # (previously, if every prompt failed, `vae` was unbound and this crashed).
    if args.output_type != "latent" and all_latents:
        logger.info("Decoding latents to videos/images")
        vae.to(device)

        for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
            logger.info(f"Decoding output {i+1}/{len(all_latents)}")

            # avoid saving latents again (ugly hack)
            if prompt_args.output_type == "both":
                prompt_args.output_type = "video"
            elif prompt_args.output_type == "latent_images":
                prompt_args.output_type = "images"

            save_output(prompt_args, vae, latent[0], device)
|
| 1565 |
+
|
| 1566 |
+
|
| 1567 |
+
def process_interactive(args: argparse.Namespace) -> None:
    """Process prompts in interactive mode.

    Reads prompt lines from stdin until EOF; Ctrl+C aborts the current
    generation but keeps the session alive.

    Args:
        args: Base command line arguments
    """
    settings = get_generation_settings(args)
    shared = load_shared_models(args)

    print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")

    try:
        while True:
            try:
                line = input("> ")
                if not line.strip():
                    continue

                # parse the line and apply per-prompt overrides
                overrides = parse_prompt_line(line)
                prompt_args = apply_overrides(args, overrides)

                # generate, then write latent/video according to output_type
                vae, latent = generate(prompt_args, settings, shared)
                save_output(prompt_args, vae, latent[0], settings.device)

            except KeyboardInterrupt:
                print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")
                continue

    except EOFError:
        print("\nExiting interactive mode")
|
| 1602 |
+
|
| 1603 |
+
|
| 1604 |
+
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
    """Build GenerationSettings (device and DiT weight dtype) from CLI args.

    Args:
        args: command line arguments (device, fp8_scaled, fp8)

    Returns:
        GenerationSettings: device and DiT weight dtype to use
    """
    device = torch.device(args.device)

    dit_weight_dtype = None  # default
    if args.fp8_scaled:
        dit_weight_dtype = None  # various precision weights, so don't cast to specific dtype
    elif args.fp8:
        dit_weight_dtype = torch.float8_e4m3fn

    # fixed duplicated word in the log message ("weight weight")
    logger.info(f"Using device: {device}, DiT weight precision: {dit_weight_dtype}")

    gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
    return gen_settings
|
| 1617 |
+
|
| 1618 |
+
|
| 1619 |
+
def main():
    """Entry point: dispatch to latent-decode, batch, interactive, or single-prompt mode."""
    # Parse arguments
    args = parse_args()

    # Check if latents are provided
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0

    # Set device
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    args.device = device

    if latents_mode:
        # Latent decode mode: load previously saved latents (and their metadata) and decode
        original_base_names = []
        latents_list = []
        seeds = []

        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                # restore generation parameters embedded by save_latent()
                if "seeds" in metadata:
                    seed = int(metadata["seeds"])
                if "height" in metadata and "width" in metadata:
                    height = int(metadata["height"])
                    width = int(metadata["width"])
                    args.video_size = [height, width]
                if "video_seconds" in metadata:
                    args.video_seconds = float(metadata["video_seconds"])

            seeds.append(seed)
            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")

            if latents.ndim == 5:  # [BCTHW]
                latents = latents.squeeze(0)  # [CTHW]

            latents_list.append(latents)

        # the VAE is identical for every latent file, so load it once
        # (previously it was reloaded on every loop iteration)
        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        for i, latent in enumerate(latents_list):
            args.seed = seeds[i]
            save_output(args, vae, latent, device, original_base_names)

    elif args.from_file:
        # Batch mode from file

        # Read prompts from file
        with open(args.from_file, "r", encoding="utf-8") as f:
            prompt_lines = f.readlines()

        # Process prompts
        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
        process_batch_prompts(prompts_data, args)

    elif args.interactive:
        # Interactive mode
        process_interactive(args)

    else:
        # Single prompt mode (original behavior)

        # Generate latent
        gen_settings = get_generation_settings(args)
        vae, latent = generate(args, gen_settings)
        if args.save_merged_model:
            return

        # Save latent and video
        save_output(args, vae, latent[0], device)

    logger.info("Done!")


if __name__ == "__main__":
    main()
|
frame_pack/__init__.py
ADDED
|
File without changes
|
frame_pack/bucket_tools.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# resolution -> candidate (height, width) pairs sharing roughly the same pixel budget
bucket_options = {
    640: [
        (416, 960),
        (448, 864),
        (480, 832),
        (512, 768),
        (544, 704),
        (576, 672),
        (608, 640),
        (640, 608),
        (672, 576),
        (704, 544),
        (768, 512),
        (832, 480),
        (864, 448),
        (960, 416),
    ],
}


def find_nearest_bucket(h, w, resolution=640):
    """Return the bucket (height, width) whose aspect ratio best matches h x w.

    The metric |h * bucket_w - w * bucket_h| is zero when aspect ratios match
    exactly; ties go to the bucket listed last for the given resolution.
    """
    candidates = bucket_options[resolution]
    # min() keeps the first minimum it sees; iterating in reverse reproduces the
    # original last-wins tie-breaking of the `<=` comparison.
    return min(reversed(candidates), key=lambda bucket: abs(h * bucket[1] - w * bucket[0]))
|
| 30 |
+
|
frame_pack/clip_vision.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def hf_clip_vision_encode(image, feature_extractor, image_encoder):
    """Run a Hugging Face vision encoder on a single uint8 RGB image.

    Args:
        image: numpy uint8 array of shape (H, W, 3)
        feature_extractor: HF image processor used for preprocessing
        image_encoder: HF vision model producing the embedding

    Returns:
        the vision model's output object
    """
    assert isinstance(image, np.ndarray)
    assert image.ndim == 3 and image.shape[2] == 3
    assert image.dtype == np.uint8

    # preprocess to tensors, then move onto the encoder's device/dtype
    inputs = feature_extractor.preprocess(images=image, return_tensors="pt")
    inputs = inputs.to(device=image_encoder.device, dtype=image_encoder.dtype)

    return image_encoder(**inputs)
|
frame_pack/framepack_utils.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
from types import SimpleNamespace
|
| 4 |
+
from typing import Optional, Union
|
| 5 |
+
|
| 6 |
+
import accelerate
|
| 7 |
+
from accelerate import Accelerator, init_empty_weights
|
| 8 |
+
import torch
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
from transformers import (
|
| 11 |
+
LlamaTokenizerFast,
|
| 12 |
+
LlamaConfig,
|
| 13 |
+
LlamaModel,
|
| 14 |
+
CLIPTokenizer,
|
| 15 |
+
CLIPTextModel,
|
| 16 |
+
CLIPConfig,
|
| 17 |
+
SiglipImageProcessor,
|
| 18 |
+
SiglipVisionModel,
|
| 19 |
+
SiglipVisionConfig,
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
from utils.safetensors_utils import load_split_weights
|
| 23 |
+
from hunyuan_model.vae import load_vae as hunyuan_load_vae
|
| 24 |
+
|
| 25 |
+
import logging
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
logging.basicConfig(level=logging.INFO)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_vae(
    vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
):
    """Load the HunyuanVideo VAE and configure chunking / spatial tiling.

    Args:
        vae_path: a safetensors file, or a directory containing
            'vae/diffusion_pytorch_model.safetensors'
        vae_chunk_size: chunk size for CausalConv3d; None leaves the default
        vae_spatial_tile_sample_min_size: min sample size for spatial tiling;
            when None, spatial tiling is still enabled with the model defaults
        device: device to load the VAE onto

    Returns:
        the VAE module in eval mode
    """
    # single file and directory (contains 'vae') support
    if os.path.isdir(vae_path):
        vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")

    vae_dtype = torch.float16  # if vae_dtype is None else str_to_dtype(vae_dtype)
    # spatial/temporal ratios are not needed here
    vae, _, _, _ = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
    vae.eval()
    # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

    # set chunk_size to CausalConv3d recursively
    if vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(vae_chunk_size)
        logger.info(f"Set chunk_size to {vae_chunk_size} for CausalConv3d")

    if vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
        logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
    else:
        # tiling stays enabled by default to limit decode memory use
        vae.enable_spatial_tiling(True)

    return vae
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# region Text Encoders
|
| 64 |
+
|
| 65 |
+
# Text Encoder configs are copied from HunyuanVideo repo
|
| 66 |
+
|
| 67 |
+
# LLaMA (text encoder 1) configuration, copied from the HunyuanVideo repo.
# Used to build the model in-memory when loading weights from a single
# safetensors file (no config.json available on disk).
LLAMA_CONFIG = {
    "architectures": ["LlamaModel"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}
|
| 94 |
+
|
| 95 |
+
# CLIP text model (text encoder 2) configuration, copied from the HunyuanVideo
# repo. Used to build the model in-memory when loading weights from a single
# safetensors file (no config.json available on disk).
CLIP_CONFIG = {
    # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
    "architectures": ["CLIPTextModel"],
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "dropout": 0.0,
    "eos_token_id": 2,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "projection_dim": 768,
    "torch_dtype": "float16",
    "transformers_version": "4.48.0.dev0",
    "vocab_size": 49408,
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def load_text_encoder1(
    args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
) -> tuple[LlamaTokenizerFast, LlamaModel]:
    """Load the LLaMA tokenizer and text encoder 1.

    `args.text_encoder1` may be a directory (containing a 'text_encoder'
    subfolder with configs) or a single/split weight file, in which case the
    model is built from the bundled LLAMA_CONFIG.

    Args:
        args: namespace with a `text_encoder1` path attribute.
        fp8_llm: if True, move/cast the encoder to float8_e4m3fn on `device`
            and patch Embedding / RMSNorm modules for numerical stability.
        device: target device for the encoder.

    Returns:
        (tokenizer, text_encoder) pair; the encoder is left in eval mode.
    """
    # single file, split file and directory (contains 'text_encoder') support
    logger.info("Loading text encoder 1 tokenizer")
    tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")

    logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
    if os.path.isdir(args.text_encoder1):
        # load from directory, configs are in the directory
        text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
    else:
        # load from file, we create the model with the appropriate config
        config = LlamaConfig(**LLAMA_CONFIG)
        with init_empty_weights():
            text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_split_weights(args.text_encoder1)

        # support weights from ComfyUI: strip the leading "model." prefix and
        # drop keys that LlamaModel does not have.
        if "model.embed_tokens.weight" in state_dict:
            for key in list(state_dict.keys()):
                if key.startswith("model."):
                    # removeprefix (not str.replace) so only the leading
                    # "model." is stripped, never a later occurrence in the key
                    state_dict[key.removeprefix("model.")] = state_dict.pop(key)
        if "tokenizer" in state_dict:
            state_dict.pop("tokenizer")
        if "lm_head.weight" in state_dict:
            state_dict.pop("lm_head.weight")

        text_encoder1.load_state_dict(state_dict, strict=True, assign=True)

    if fp8_llm:
        org_dtype = text_encoder1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8: keep embeddings in the original dtype and run
        # RMSNorm's reduction in float32 to avoid fp8 overflow/precision loss.
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # replace forward so normalization happens in fp32
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder1, org_dtype)
    else:
        text_encoder1.to(device)

    text_encoder1.eval()
    return tokenizer1, text_encoder1
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
    """Load the CLIP tokenizer and text encoder 2.

    `args.text_encoder2` is either a directory (with a 'text_encoder_2'
    subfolder) or a single weight file built against CLIP_CONFIG.
    """
    # single file and directory (contains 'text_encoder_2') support
    logger.info(f"Loading text encoder 2 tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")

    logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
    if not os.path.isdir(args.text_encoder2):
        # Single weight file: build an empty model shell from the bundled
        # config, then materialize it from the state dict.
        clip_config = CLIPConfig(**CLIP_CONFIG)
        with init_empty_weights():
            text_encoder2 = CLIPTextModel._from_config(clip_config, torch_dtype=torch.float16)
        text_encoder2.load_state_dict(load_file(args.text_encoder2), strict=True, assign=True)
    else:
        # Directory layout: configs live next to the weights.
        text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)

    text_encoder2.eval()
    return tokenizer2, text_encoder2
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
# endregion
|
| 214 |
+
|
| 215 |
+
# region image encoder

# Siglip configs are copied from FramePack repo.
# Used to build the image feature extractor / vision encoder when loading
# from a single weight file.

# Preprocessing config for the SigLIP image processor (384x384, [-1, 1] range).
FEATURE_EXTRACTOR_CONFIG = {
    "do_convert_rgb": None,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,  # PIL resampling code (3 = bicubic)
    "rescale_factor": 0.00392156862745098,  # 1/255
    "size": {"height": 384, "width": 384},
}
# Model config for the SigLIP vision tower.
IMAGE_ENCODER_CONFIG = {
    "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
    "architectures": ["SiglipVisionModel"],
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.46.2",
}
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def load_image_encoders(args):
    """Load the SigLIP feature extractor and vision encoder.

    `args.image_encoder` is either a directory (with an 'image_encoder'
    subfolder) or a single weight file built against IMAGE_ENCODER_CONFIG.
    """
    logger.info(f"Loading image encoder feature extractor")
    feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

    # single file, split file and directory (contains 'image_encoder') support
    logger.info(f"Loading image encoder from {args.image_encoder}")
    if not os.path.isdir(args.image_encoder):
        # Single file: build the model shell from the hard-coded config,
        # then load the weights into it.
        vision_config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
        with init_empty_weights():
            image_encoder = SiglipVisionModel._from_config(vision_config, torch_dtype=torch.float16)
        image_encoder.load_state_dict(load_file(args.image_encoder), strict=True, assign=True)
    else:
        # Directory layout: configs are stored alongside the weights.
        image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)

    image_encoder.eval()
    return feature_extractor, image_encoder
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
# endregion
|
frame_pack/hunyuan.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original code: https://github.com/lllyasviel/FramePack
|
| 2 |
+
# original license: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
# from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
|
| 7 |
+
# from diffusers_helper.utils import crop_or_pad_yield_mask
|
| 8 |
+
from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
| 9 |
+
from hunyuan_model.text_encoder import PROMPT_TEMPLATE
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@torch.no_grad()
def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256, custom_system_prompt=None):
    """Encode a single prompt with both HunyuanVideo text encoders.

    The prompt is wrapped in a chat template (either the built-in
    "dit-llm-encode-video" template or `custom_system_prompt`), run through
    the LLaMA encoder, and the system-prompt tokens are cropped off the
    hidden states. CLIP-L provides a pooled embedding of the bare prompt.

    Returns:
        llama_vec: hidden states of the third-to-last LLaMA layer,
            cropped to the prompt tokens (shape [1, <= max_length, dim]).
        clip_l_pooler: pooled CLIP-L output for the same prompt.
    """
    assert isinstance(prompt, str)

    prompt = [prompt]  # both tokenizers expect a batch (list) of strings

    # LLAMA

    # We can verify crop_start by checking the token count of the prompt:
    # custom_system_prompt = (
    #     "Describe the video by detailing the following aspects: "
    #     "1. The main content and theme of the video."
    #     "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    #     "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    #     "4. background environment, light, style and atmosphere."
    #     "5. camera angles, movements, and transitions used in the video:"
    # )
    if custom_system_prompt is None:
        # Built-in template: crop_start is precomputed alongside the template.
        prompt_llama = [PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format(p) for p in prompt]
        crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"]["crop_start"]
    else:
        # count tokens for custom_system_prompt to determine crop_start
        full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{custom_system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        print(f"Custom system prompt: {full_prompt}")
        # NOTE(review): assumes the tokenizer adds no trailing tokens after the
        # template header, so the prompt starts exactly at this index — verify.
        system_prompt_tokens = tokenizer(full_prompt, return_tensors="pt", truncation=True).input_ids[0].shape[0]
        print(f"Custom system prompt token count: {system_prompt_tokens}")
        prompt_llama = [full_prompt + p + "<|eot_id|>" for p in prompt]
        crop_start = system_prompt_tokens

    # Pad/truncate to crop_start + max_length so that after cropping the
    # system prompt, at most max_length prompt tokens remain.
    llama_inputs = tokenizer(
        prompt_llama,
        padding="max_length",
        max_length=max_length + crop_start,
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )

    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
    llama_attention_length = int(llama_attention_mask.sum())  # number of real (non-pad) tokens

    llama_outputs = text_encoder(
        input_ids=llama_input_ids,
        attention_mask=llama_attention_mask,
        output_hidden_states=True,
    )

    # Take the third-to-last hidden layer and crop: system prompt off the
    # front, padding off the back.
    llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
    # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
    llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]

    # After cropping, every remaining position must be a real token.
    assert torch.all(llama_attention_mask.bool())

    # CLIP

    clip_l_input_ids = tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output

    return llama_vec, clip_l_pooler
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
def vae_decode_fake(latents):
    """Cheap approximate RGB preview of latents without running the real VAE.

    Projects the 16 latent channels to RGB with a fixed 16x3 linear map
    (implemented as a pointwise conv3d) plus a bias, then clamps to [0, 1].
    """
    latent_rgb_factors = [
        [-0.0395, -0.0331, 0.0445],
        [0.0696, 0.0795, 0.0518],
        [0.0135, -0.0945, -0.0282],
        [0.0108, -0.0250, -0.0765],
        [-0.0209, 0.0032, 0.0224],
        [-0.0804, -0.0254, -0.0639],
        [-0.0991, 0.0271, -0.0669],
        [-0.0646, -0.0422, -0.0400],
        [-0.0696, -0.0595, -0.0894],
        [-0.0799, -0.0208, -0.0375],
        [0.1166, 0.1627, 0.0962],
        [0.1165, 0.0432, 0.0407],
        [-0.2315, -0.1920, -0.1355],
        [-0.0270, 0.0401, -0.0821],
        [-0.0616, -0.0997, -0.0727],
        [0.0249, -0.0469, -0.1703],
    ]  # From comfyui

    latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]

    # (16, 3) table -> (3, 16, 1, 1, 1) pointwise conv kernel.
    table = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype)
    kernel = table.t()[:, :, None, None, None]
    offset = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)

    preview = torch.nn.functional.conv3d(latents, kernel, bias=offset, stride=1, padding=0, dilation=1, groups=1)
    return preview.clamp(0.0, 1.0)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@torch.no_grad()
def vae_decode(latents, vae, image_mode=False) -> torch.Tensor:
    """Decode latents with the causal 3D VAE.

    With image_mode=True each temporal slice is decoded independently and the
    per-frame results are concatenated along the time axis; otherwise the
    whole clip is decoded in a single call.
    """
    scaled = latents / vae.config.scaling_factor

    if image_mode:
        frames = scaled.to(device=vae.device, dtype=vae.dtype).unbind(2)
        decoded = [vae.decode(frame.unsqueeze(2)).sample for frame in frames]
        return torch.cat(decoded, dim=2)

    return vae.decode(scaled.to(device=vae.device, dtype=vae.dtype)).sample
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@torch.no_grad()
def vae_encode(image, vae: AutoencoderKLCausal3D) -> torch.Tensor:
    """Encode pixels to latents, scaled by the VAE's scaling factor."""
    posterior = vae.encode(image.to(device=vae.device, dtype=vae.dtype))
    return posterior.latent_dist.sample() * vae.config.scaling_factor
|
frame_pack/hunyuan_video_packed.py
ADDED
|
@@ -0,0 +1,2015 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original code: https://github.com/lllyasviel/FramePack
|
| 2 |
+
# original license: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import glob
|
| 5 |
+
import math
|
| 6 |
+
import numbers
|
| 7 |
+
import os
|
| 8 |
+
from types import SimpleNamespace
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import einops
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
from modules.custom_offloading_utils import ModelOffloader
|
| 18 |
+
from utils.safetensors_utils import load_split_weights
|
| 19 |
+
from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
|
| 20 |
+
from accelerate import init_empty_weights
|
| 21 |
+
|
| 22 |
+
# Optional attention backends: probe each in turn and fall back to None so the
# rest of the module can feature-detect at call time.
try:
    # raise NotImplementedError
    from xformers.ops import memory_efficient_attention as xformers_attn_func

    print("Xformers is installed!")
except Exception:  # was a bare `except:` — don't swallow KeyboardInterrupt/SystemExit
    print("Xformers is not installed!")
    xformers_attn_func = None

try:
    # raise NotImplementedError
    from flash_attn import flash_attn_varlen_func, flash_attn_func

    print("Flash Attn is installed!")
except Exception:
    print("Flash Attn is not installed!")
    flash_attn_varlen_func = None
    flash_attn_func = None

try:
    # raise NotImplementedError
    from sageattention import sageattn_varlen, sageattn

    print("Sage Attn is installed!")
except Exception:
    print("Sage Attn is not installed!")
    sageattn_varlen = None
    sageattn = None
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
import logging
|
| 53 |
+
|
| 54 |
+
logger = logging.getLogger(__name__)
|
| 55 |
+
logging.basicConfig(level=logging.INFO)
|
| 56 |
+
|
| 57 |
+
# region diffusers
|
| 58 |
+
|
| 59 |
+
# copied from diffusers with some modifications to minimize dependencies
|
| 60 |
+
# original code: https://github.com/huggingface/diffusers/
|
| 61 |
+
# original license: Apache-2.0
|
| 62 |
+
|
| 63 |
+
# Mapping of activation-function names (lowercase) to nn.Module classes.
ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Helper function to get activation function from string.

    Args:
        act_fn (str): Name of activation function (case-insensitive).

    Returns:
        nn.Module: Activation function instance.

    Raises:
        ValueError: If `act_fn` is not a key of ACT2CLS.
    """
    act_fn = act_fn.lower()
    if act_fn in ACT2CLS:
        return ACT2CLS[act_fn]()
    # fixed message: the mapping here is named ACT2CLS, not diffusers' ACT2FN
    raise ValueError(f"activation function {act_fn} not found in ACT2CLS mapping {list(ACT2CLS.keys())}")
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Sinusoidal timestep embeddings as in Denoising Diffusion Probabilistic Models.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions
        scale (float):
            Scaling factor applied to the embeddings.
        max_period (int):
            Controls the frequency range of the embeddings
    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half = embedding_dim // 2

    # geometric frequency ladder: exp(-log(max_period) * k / (half - shift))
    freq_exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=timesteps.device)
    freq_exponent = freq_exponent / (half - downscale_freq_shift)
    frequencies = torch.exp(freq_exponent)

    # outer product of timesteps and frequencies, then apply the user scale
    angles = timesteps[:, None].float() * frequencies[None, :]
    angles = scale * angles

    # sin block followed by cos block
    embedding = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    if flip_sin_to_cos:
        # swap the halves so cos comes first
        embedding = torch.cat([embedding[:, half:], embedding[:, :half]], dim=-1)

    if embedding_dim % 2 == 1:
        # odd target dim: zero-pad the last column
        embedding = torch.nn.functional.pad(embedding, (0, 1, 0, 0))
    return embedding
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TimestepEmbedding(nn.Module):
    """Two-layer MLP that maps a sinusoidal timestep embedding to `time_embed_dim`.

    Optionally projects an extra `condition` vector into the input space and
    adds it before the MLP; an optional post-activation can follow linear_2.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        # bias-free projection of the conditioning vector into the input space
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim is not None else None

        self.act = get_activation(act_fn)

        self.linear_2 = nn.Linear(time_embed_dim, out_dim if out_dim is not None else time_embed_dim, sample_proj_bias)

        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        hidden = self.linear_1(sample)
        if self.act is not None:
            hidden = self.act(hidden)

        hidden = self.linear_2(hidden)
        if self.post_act is not None:
            hidden = self.post_act(hidden)
        return hidden
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
class Timesteps(nn.Module):
    """Module wrapper around `get_timestep_embedding` with a fixed configuration."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        # Delegate to the functional implementation using the stored settings.
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
class FP32SiLU(nn.Module):
    r"""SiLU computed in float32, with the result cast back to the input dtype."""

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Upcast for numerical stability, then restore the original dtype.
        result = F.silu(inputs.float(), inplace=False)
        return result.to(inputs.dtype)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class GELU(nn.Module):
    r"""
    Linear projection followed by GELU, with optional tanh approximation.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        # project, then activate
        return self.gelu(self.proj(hidden_states))
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class PixArtAlphaTextProjection(nn.Module):
    """
    Projects caption embeddings through a two-layer MLP. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        # Select the middle activation by name.
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        # linear -> activation -> linear
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class LayerNormFramePack(nn.LayerNorm):
    # Identical to nn.LayerNorm except the output is cast back to match the input tensor.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        normed = torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return normed.to(x)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class FP32LayerNormFramePack(nn.LayerNorm):
    # LayerNorm computed entirely in float32, then cast back to the input dtype.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = None if self.weight is None else self.weight.float()
        bias = None if self.bias is None else self.bias.float()
        out = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
class RMSNormFramePack(nn.Module):
    r"""
    RMS Norm as introduced in https://arxiv.org/abs/1910.07467 by Zhang et al.

    Args:
        dim (`int`): Number of dimensions to use for `weights`. Only effective when `elementwise_affine` is True.
        eps (`float`): Small value to use when calculating the reciprocal of the square-root.
        elementwise_affine (`bool`, defaults to `True`):
            Boolean flag to denote if affine transformation should be applied.
        bias (`bool`, defaults to False): If also training the `bias` param.
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)

        self.weight = None
        self.bias = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # Mean of squares computed in float32 for numerical stability.
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(mean_sq + self.eps)

        # NOTE: `self.bias`, even when created, is not applied here (matches the original implementation).
        if self.weight is None:
            return hidden_states.to(orig_dtype)
        return hidden_states.to(orig_dtype) * self.weight.to(orig_dtype)
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class AdaLayerNormContinuousFramePack(nn.Module):
    r"""
    Adaptive normalization: a norm layer modulated by scale/shift projected from a conditioning embedding.

    Args:
        embedding_dim (`int`): Embedding dimension to use during projection.
        conditioning_embedding_dim (`int`): Dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`):
            Whether the inner norm layer has affine parameters. Since the output is immediately
            scaled and shifted by the projected condition, callers likely want False here.
        eps (`float`, defaults to 1e-5): Epsilon factor.
        bias (`bool`, defaults to `True`): Whether the projection (and layer norm) use a bias.
        norm_type (`str`, defaults to `"layer_norm"`): One of "layer_norm" or "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        # Projects the condition to per-channel (scale, shift).
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x, conditioning_embedding):
        # Condition -> SiLU -> linear -> (scale, shift), broadcast over the sequence dimension.
        scale, shift = self.linear(self.silu(conditioning_embedding)).chunk(2, dim=1)
        return self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
| 381 |
+
|
| 382 |
+
|
| 383 |
+
class LinearActivation(nn.Module):
    """Linear projection immediately followed by an activation function."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.activation = get_activation(activation)

    def forward(self, hidden_states):
        return self.activation(self.proj(hidden_states))
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class FeedForward(nn.Module):
    r"""
    A feed-forward layer: activation projection, dropout, output projection.

    Parameters:
        dim (`int`): The number of channels in the input.
        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
            This trimmed copy only supports "gelu-approximate" and "linear-silu".
        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
        inner_dim (`int`, *optional*): Hidden dimension; overrides `dim * mult` when given.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult) if inner_dim is None else inner_dim
        dim_out = dim if dim_out is None else dim_out

        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
        else:
            raise ValueError(f"Unknown activation function: {activation_fn}")

        # project in -> dropout -> project out
        layers = [act_fn, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out, bias=bias)]
        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
        if final_dropout:
            layers.append(nn.Dropout(dropout))
        self.net = nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # The deprecated `scale` argument is rejected explicitly.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            raise ValueError("scale is not supported in this version. Please remove it.")
        for layer in self.net:
            hidden_states = layer(hidden_states)
        return hidden_states
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
# @maybe_allow_in_graph
|
| 462 |
+
class Attention(nn.Module):
    r"""
    Minimal copy of Attention class from diffusers.

    Holds the Q/K/V (and optional added K/V for a second context stream)
    projections plus output projections; the actual attention computation is
    delegated to a pluggable `processor` callable.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        eps: float = 1e-5,
        processor: Optional[any] = None,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
    ):
        super().__init__()
        # Projected (inner) width; defaults to heads * dim_head.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim  # if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        # Self-attention when cross_attention_dim is None.
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only

        self.scale = dim_head**-0.5
        # When out_dim is given, the head count is derived from it.
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.added_kv_proj_dim = added_kv_proj_dim

        # Optional per-head Q/K normalization; only rms_norm is supported in this copy.
        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "rms_norm":
            self.norm_q = RMSNormFramePack(dim_head, eps=eps)
            self.norm_k = RMSNormFramePack(dim_head, eps=eps)
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
            )

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)

        self.added_proj_bias = True  # added_proj_bias
        # Extra projections for a second (context/text) stream, if configured.
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True))
            # self.to_out.append(nn.Dropout(dropout))
            self.to_out.append(nn.Identity())  # dropout=0.0
        else:
            self.to_out = None

        # Output projection for the context stream (only when it is processed past this block).
        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True)
        else:
            self.to_add_out = None

        if qk_norm is not None and added_kv_proj_dim is not None:
            if qk_norm == "rms_norm":
                self.norm_added_q = RMSNormFramePack(dim_head, eps=eps)
                self.norm_added_k = RMSNormFramePack(dim_head, eps=eps)
            else:
                raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`")
        else:
            self.norm_added_q = None
            self.norm_added_k = None

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        if processor is None:
            processor = AttnProcessor2_0()
        self.set_processor(processor)

    def set_processor(self, processor: any) -> None:
        """Replace the attention processor (the callable that implements the computation)."""
        self.processor = processor

    def get_processor(self) -> any:
        """Return the currently installed attention processor."""
        return self.processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        # Delegate to the configured processor; extra kwargs are forwarded as-is.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                # we want to instead pad by (0, remaining_length), where remaining_length is:
                # remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            # Repeat along the batch dim so each head gets a copy.
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size)
        elif out_dim == 4:
            # Insert a heads dimension instead of folding heads into the batch.
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size)

        return attention_mask
|
| 621 |
+
|
| 622 |
+
|
| 623 |
+
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # Accepts either (B, C, H, W) spatial input or (B, L, C) sequence input.
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        query_dtype = query.dtype  # store dtype before potentially deleting query

        # Self-attention when no encoder states are provided.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, L, H*D) -> (B, H, L, D)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        del query, key, value, attention_mask  # free memory

        # (B, H, L, D) -> (B, L, H*D)
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query_dtype)  # use stored dtype

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            # Restore the original (B, C, H, W) layout.
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
# endregion diffusers
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def pad_for_3d_conv(x, kernel_size):
    """Replicate-pad the (t, h, w) dims of a (b, c, t, h, w) tensor to multiples of `kernel_size`."""
    _, _, t, h, w = x.shape
    kt, kh, kw = kernel_size
    # `-n % k` gives the amount needed to round n up to the next multiple of k.
    extra_t = -t % kt
    extra_h = -h % kh
    extra_w = -w % kw
    # F.pad's tuple is ordered (w_left, w_right, h_left, h_right, t_left, t_right).
    return torch.nn.functional.pad(x, (0, extra_w, 0, extra_h, 0, extra_t), mode="replicate")
|
| 705 |
+
|
| 706 |
+
|
| 707 |
+
def center_down_sample_3d(x, kernel_size):
    """Downsample a (b, c, t, h, w) tensor by average pooling with stride == kernel_size.

    Despite the name, each window is averaged rather than its center element
    picked (a commented-out center-pick variant existed upstream).
    """
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
def get_cu_seqlens(text_mask, img_len):
    """Build cumulative sequence lengths for varlen attention.

    Args:
        text_mask: (batch, text_len) mask of valid text tokens; each row is summed
            to get the per-sample valid text length.
        img_len: number of image tokens prepended to every sample.

    Returns:
        torch.Tensor: int32 tensor of shape (2 * batch + 1), starting at 0. For
        sample i, entry 2i+1 marks the end of its valid (img + text) tokens and
        entry 2i+2 marks the end of its padded slot of size
        `text_mask.shape[1] + img_len`.
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)

    # Vectorized over the batch instead of a Python loop of per-element tensor
    # writes (each of which is a separate kernel launch / sync on GPU tensors).
    idx = torch.arange(batch_size, device=text_mask.device)
    cu_seqlens[1::2] = (idx * max_len + text_len + img_len).to(torch.int32)
    cu_seqlens[2::2] = ((idx + 1) * max_len).to(torch.int32)

    return cu_seqlens
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
def apply_rotary_emb_transposed(x, freqs_cis):
    """Apply rotary position embedding to `x`.

    `freqs_cis` carries the cos and sin tables concatenated along its last dim;
    it is broadcast over the heads dimension (x assumed (batch, seq, heads, dim)
    -- TODO confirm with callers).
    """
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    del freqs_cis
    # Pair up adjacent channels and build the 90-degree-rotated counterpart.
    real, imag = x.unflatten(-1, (-1, 2)).unbind(-1)
    rotated = torch.stack([-imag, real], dim=-1).flatten(3)
    del real, imag
    # Compute in float32 and cast back to the input dtype.
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)
|
| 740 |
+
|
| 741 |
+
|
| 742 |
+
def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False):
    """Dispatch attention to sageattn / flash-attn / xformers / PyTorch SDPA.

    q, k, v are in "NHD" layout, i.e. (batch, seq, heads, head_dim). When all
    cu_seqlens/max_seqlen arguments are None, plain batched attention is used;
    otherwise the batch is flattened into the sequence dimension and a varlen
    kernel is required. `split_attn` runs the batch one sample at a time.
    """
    if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
        # Batched (non-varlen) path. Backend priority when attn_mode is None:
        # sageattn -> flash-attn -> xformers -> SDPA (whichever is installed).
        if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
            x = sageattn(q, k, v, tensor_layout="NHD")
            return x

        if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
            x = flash_attn_func(q, k, v)
            return x

        if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
            x = xformers_attn_func(q, k, v)
            return x

        # SDPA expects (batch, heads, seq, dim), hence the transposes in and out.
        x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(
            1, 2
        )
        return x
    if split_attn:
        # Per-sample processing; the cu_seqlens arguments are not used on this
        # path (NOTE(review): assumes each row is already trimmed/padded
        # appropriately by the caller -- confirm).
        if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD")
            return x

        if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
            return x

        if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
            return x

        # SDPA fallback, again transposing to (batch, heads, seq, dim).
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        x = torch.empty_like(q)
        for i in range(q.size(0)):
            x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1])
        x = x.transpose(1, 2)
        return x

    # Varlen path: flatten batch into the sequence dim as the varlen kernels expect.
    batch_size = q.shape[0]
    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
    if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None:
        x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v  # free memory
    elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None:
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v  # free memory
    else:
        raise NotImplementedError("No Attn Installed or batch_size > 1 is not supported in this configuration. Try `--split_attn`.")
    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
    return x
|
| 802 |
+
|
| 803 |
+
|
| 804 |
+
class HunyuanAttnProcessorFlashAttnDouble:
    """Attention processor for the double-stream (image + text) blocks.

    `attention_mask` is a 4-tuple (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q,
    max_seqlen_kv) forwarded to the varlen attention dispatcher. Returns the
    updated (hidden_states, encoder_hidden_states) pair.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        image_rotary_emb,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        # Project image latents
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        del hidden_states  # free memory

        # (B, L, H*D) -> (B, L, H, D)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # RoPE is applied to the image-stream q/k only (text q/k below are not rotated).
        query = apply_rotary_emb_transposed(query, image_rotary_emb)
        key = apply_rotary_emb_transposed(key, image_rotary_emb)
        del image_rotary_emb  # free memory

        # Project context (text/encoder) embeddings
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)
        txt_length = encoder_hidden_states.shape[1]  # store length before deleting
        del encoder_hidden_states  # free memory

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        encoder_query = attn.norm_added_q(encoder_query)
        encoder_key = attn.norm_added_k(encoder_key)

        # Concatenate image and context q, k, v (text tokens go at the end)
        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
        value = torch.cat([value, encoder_value], dim=1)
        del encoder_query, encoder_key, encoder_value  # free memory

        hidden_states_attn = attn_varlen_func(
            query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
        )
        del query, key, value  # free memory
        hidden_states_attn = hidden_states_attn.flatten(-2)

        # Split back into image and text parts (text is the trailing txt_length tokens).
        hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:]
        del hidden_states_attn  # free memory

        # Apply output projections
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)  # Dropout/Identity
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
|
| 869 |
+
|
| 870 |
+
|
| 871 |
+
class HunyuanAttnProcessorFlashAttnSingle:
    """Attention processor for single-stream transformer blocks.

    Image tokens (``hidden_states``) and text tokens (``encoder_hidden_states``)
    share one set of QKV projections: the two sequences are concatenated,
    projected together, RoPE is applied to the image tokens only, and the
    joint result is split back into (image, text) after attention.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        image_rotary_emb,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ):
        # ``attention_mask`` is a 4-tuple of varlen-attention metadata, not a boolean mask.
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
        txt_length = encoder_hidden_states.shape[1]  # Store text length

        # Concatenate image and context inputs
        hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        del hidden_states, encoder_hidden_states  # free memory

        # Project concatenated inputs
        query = attn.to_q(hidden_states_cat)
        key = attn.to_k(hidden_states_cat)
        value = attn.to_v(hidden_states_cat)
        del hidden_states_cat  # free memory

        # (B, S, heads*dim) -> (B, S, heads, dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        # QK normalization before RoPE
        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # RoPE is applied to the leading image tokens only; the trailing text
        # tokens pass through unrotated.
        query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
        key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
        del image_rotary_emb  # free memory

        hidden_states = attn_varlen_func(
            query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
        )
        del query, key, value  # free memory
        hidden_states = hidden_states.flatten(-2)  # merge (heads, dim) back into channels

        # Split the joint sequence back into image / text parts.
        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        return hidden_states, encoder_hidden_states
|
| 915 |
+
|
| 916 |
+
|
| 917 |
+
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Fuse timestep, guidance-scale and pooled-text conditioning into one
    vector of size ``embedding_dim`` by summing the three embeddings."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        # Sinusoidal features for both scalar conditions, cast to the text dtype
        # before the MLP embedders.
        target_dtype = pooled_projection.dtype
        t_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=target_dtype))
        g_emb = self.guidance_embedder(self.time_proj(guidance).to(dtype=target_dtype))

        # Sum all three conditioning signals.
        return (t_emb + g_emb) + self.text_embedder(pooled_projection)
|
| 939 |
+
|
| 940 |
+
|
| 941 |
+
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Fuse timestep and pooled-text conditioning into one vector of size
    ``embedding_dim`` by summing the two embeddings."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        # Sinusoidal timestep features, cast to the text dtype before embedding.
        t_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=pooled_projection.dtype))

        # Sum the two conditioning signals.
        return t_emb + self.text_embedder(pooled_projection)
|
| 958 |
+
|
| 959 |
+
|
| 960 |
+
class HunyuanVideoAdaNorm(nn.Module):
    """Adaptive gate generator for the token-refiner blocks.

    Projects a conditioning embedding through SiLU + Linear and splits the
    result into two per-sample gates (attention gate, MLP gate), each with a
    singleton token axis so they broadcast over the sequence dimension.

    Args:
        in_features: Size of the conditioning embedding.
        out_features: Size of the projection output; defaults to
            ``2 * in_features`` so it can be chunked into two gates.
    """

    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()

        out_features = out_features or 2 * in_features
        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Fixed: the previous annotation declared a 5-tuple return, but the
        # method returns exactly two gate tensors.
        temb = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = temb.chunk(2, dim=-1)
        # Insert a token axis so the gates broadcast over the sequence dim.
        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
        return gate_msa, gate_mlp
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    """One token-refiner layer: self-attention + feed-forward, each gated by
    scalars predicted from the conditioning embedding ``temb``."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )

        self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)

        # Produces (gate_msa, gate_mlp) from temb; 2 * hidden_size output.
        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        norm_hidden_states = self.norm1(hidden_states)

        # Self-attention
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )
        del norm_hidden_states  # free memory

        # Gated residual connections: temb-derived gates scale each branch.
        gate_msa, gate_mlp = self.norm_out(temb)
        hidden_states = hidden_states + attn_output * gate_msa
        del attn_output, gate_msa  # free memory

        ff_output = self.ff(self.norm2(hidden_states))
        hidden_states = hidden_states + ff_output * gate_mlp
        del ff_output, gate_mlp  # free memory

        return hidden_states
|
| 1027 |
+
|
| 1028 |
+
|
| 1029 |
+
class HunyuanVideoIndividualTokenRefiner(nn.Module):
    """Stack of :class:`HunyuanVideoIndividualTokenRefinerBlock` layers sharing
    a single self-attention mask derived from the padding mask."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        self.refiner_blocks = nn.ModuleList(
            [
                HunyuanVideoIndividualTokenRefinerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    attention_bias=attention_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        self_attn_mask = None
        if attention_mask is not None:
            # Build a pairwise (B, 1, S, S) mask: position (i, j) is attendable
            # only when both token i and token j are valid per the padding mask.
            batch_size = attention_mask.shape[0]
            seq_len = attention_mask.shape[1]
            attention_mask = attention_mask.to(hidden_states.device).bool()
            self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # Keep key 0 attendable for every query so no row is fully masked
            # (presumably to avoid NaNs from an all-masked softmax — NOTE(review): confirm).
            self_attn_mask[:, :, :, 0] = True

        for block in self.refiner_blocks:
            hidden_states = block(hidden_states, temb, self_attn_mask)

        return hidden_states
|
| 1074 |
+
|
| 1075 |
+
|
| 1076 |
+
class HunyuanVideoTokenRefiner(nn.Module):
    """Refine raw text-encoder states with a small timestep-conditioned
    transformer before they enter the video transformer."""

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim

        self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels)
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_layers=num_layers,
            mlp_width_ratio=mlp_ratio,
            mlp_drop_rate=mlp_drop_rate,
            attention_bias=attention_bias,
        )

    def _pooled_text(self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor]) -> torch.Tensor:
        # Pool tokens for the conditioning vector: plain mean when no mask is
        # given, otherwise a mask-weighted mean over valid tokens only.
        if attention_mask is None:
            return hidden_states.mean(dim=1)
        keep_dtype = hidden_states.dtype
        weights = attention_mask.float().unsqueeze(-1)
        pooled = (hidden_states * weights).sum(dim=1) / weights.sum(dim=1)
        return pooled.to(keep_dtype)

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        pooled_projections = self._pooled_text(hidden_states, attention_mask)
        temb = self.time_text_embed(timestep, pooled_projections)
        del pooled_projections  # free memory

        hidden_states = self.proj_in(hidden_states)
        hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
        del temb, attention_mask  # free memory

        return hidden_states
|
| 1124 |
+
|
| 1125 |
+
|
| 1126 |
+
class HunyuanVideoRotaryPosEmbed(nn.Module):
    """Three-axis rotary position embedding over a (time, height, width) grid.

    ``rope_dim`` gives the per-axis channel counts (DT, DY, DX); the output
    concatenates cosines for all three axes followed by sines for all three.
    """

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

    @torch.no_grad()
    def get_frequency(self, dim, pos):
        # pos: (T, H, W) grid of positions for one axis.
        T, H, W = pos.shape
        half = dim // 2
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[:half] / dim))
        # Outer product -> (half, T*H*W); restore grid and duplicate each
        # frequency row so cos/sin channels come in interleaved pairs.
        angles = torch.outer(inv_freq, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
        return angles.cos(), angles.sin()

    @torch.no_grad()
    def forward_inner(self, frame_indices, height, width, device):
        grid_t, grid_y, grid_x = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, height, device=device, dtype=torch.float32),
            torch.arange(0, width, device=device, dtype=torch.float32),
            indexing="ij",
        )

        cos_t, sin_t = self.get_frequency(self.DT, grid_t)
        cos_y, sin_y = self.get_frequency(self.DY, grid_y)
        cos_x, sin_x = self.get_frequency(self.DX, grid_x)
        del grid_t, grid_y, grid_x  # free memory

        # All cosines first, then all sines, along the channel axis:
        # shape (DT + DY + DX + DT + DY + DX, T, H, W).
        return torch.cat([cos_t, cos_y, cos_x, sin_t, sin_y, sin_x], dim=0)

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        # One RoPE grid per batch row, stacked into (B, C, T, H, W).
        per_sample = [self.forward_inner(row, height, width, device) for row in frame_indices.unbind(0)]
        return torch.stack(per_sample, dim=0)
|
| 1167 |
+
|
| 1168 |
+
|
| 1169 |
+
class AdaLayerNormZero(nn.Module):
    """adaLN-Zero: LayerNorm whose shift/scale plus four extra modulation terms
    (MSA gate, MLP shift/scale/gate) are predicted from a conditioning
    embedding via SiLU + Linear."""

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        # 6 modulation terms: shift/scale/gate for attention and for the MLP.
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self, x: torch.Tensor, emb: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # NOTE(review): despite the Optional annotation, emb=None would fail at
        # self.silu(emb); callers always pass a conditioning embedding.
        emb = emb.unsqueeze(-2)  # add a token axis so terms broadcast over the sequence
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
| 1187 |
+
|
| 1188 |
+
|
| 1189 |
+
class AdaLayerNormZeroSingle(nn.Module):
    """adaLN-Zero for single-stream blocks: predicts only three modulation
    terms (shift, scale, gate) from the conditioning embedding."""

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()

        self.silu = nn.SiLU()
        # 3 modulation terms: shift, scale and gate.
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # NOTE(review): despite the Optional annotation, emb=None would fail at
        # self.silu(emb); callers always pass a conditioning embedding.
        emb = emb.unsqueeze(-2)  # token axis for broadcasting over the sequence
        emb = self.linear(self.silu(emb))
        shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=-1)
        x = self.norm(x) * (1 + scale_msa) + shift_msa
        return x, gate_msa
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
class AdaLayerNormContinuous(nn.Module):
    """LayerNorm modulated continuously by a conditioning embedding: scale and
    shift are recomputed from ``emb`` on every forward pass."""

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            # NOTE(review): positional args — assumes LayerNormFramePack's
            # signature is (dim, eps, elementwise_affine, bias); confirm.
            self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        emb = emb.unsqueeze(-2)  # token axis so scale/shift broadcast over the sequence
        emb = self.linear(self.silu(emb))
        # Chunk order is (scale, shift) here.
        scale, shift = emb.chunk(2, dim=-1)
        del emb  # free memory
        x = self.norm(x) * (1 + scale) + shift
        return x
|
| 1237 |
+
|
| 1238 |
+
|
| 1239 |
+
class HunyuanVideoSingleTransformerBlock(nn.Module):
    """Single-stream block: image and text tokens are concatenated and share
    one attention + MLP path; attention and MLP outputs are fused by a single
    output projection and gated residual (FLUX-style single block)."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        # Attention layer (pre_only=True means no output projection in Attention module itself)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnSingle(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,  # Crucial: Attn processor will return raw attention output
        )

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Fuses concatenated [attention, MLP] features back to hidden_size.
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # (annotation fixed: this method returns an (image, text) pair)
        text_seq_length = encoder_hidden_states.shape[1]
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        del encoder_hidden_states  # free memory

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        # Split normalized sequence back into image / text for the processor,
        # which re-concatenates internally after applying RoPE to image tokens.
        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            attn_mode=self.attn_mode,
            split_attn=self.split_attn,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
        del norm_hidden_states, norm_encoder_hidden_states, context_attn_output  # free memory
        del image_rotary_emb

        # 3. Modulation and residual connection
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        del attn_output, mlp_hidden_states  # free memory
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states
|
| 1322 |
+
|
| 1323 |
+
|
| 1324 |
+
class HunyuanVideoTransformerBlock(nn.Module):
    """Dual-stream (MMDiT-style) block: image and text tokens keep separate
    adaLN/FF parameters but attend jointly through one Attention module."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        # Separate adaLN-Zero modulation for the image and text streams.
        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            added_kv_proj_dim=hidden_size,  # separate projections for the text stream
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            context_pre_only=False,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnDouble(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

        self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # 2. Joint attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
            attn_mode=self.attn_mode,
            split_attn=self.split_attn,
        )
        del norm_hidden_states, norm_encoder_hidden_states, freqs_cis  # free memory

        # 3. Modulation and residual connection
        hidden_states = hidden_states + attn_output * gate_msa
        del attn_output, gate_msa  # free memory
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
        del context_attn_output, c_gate_msa  # free memory

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)

        # Pre-FF modulation with the MLP shift/scale predicted in step 1.
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        del shift_mlp, scale_mlp  # free memory
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
        del c_shift_mlp, c_scale_mlp  # free memory

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)
        del norm_hidden_states  # free memory
        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        del norm_encoder_hidden_states  # free memory

        hidden_states = hidden_states + gate_mlp * ff_output
        del ff_output, gate_mlp  # free memory
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
        del context_ff_output, c_gate_mlp  # free memory

        return hidden_states, encoder_hidden_states
|
| 1414 |
+
|
| 1415 |
+
|
| 1416 |
+
class ClipVisionProjection(nn.Module):
    """Two-layer SiLU MLP projecting CLIP vision features from ``in_channels``
    to ``out_channels`` (hidden width is ``3 * out_channels``)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Linear(in_channels, out_channels * 3)
        self.down = nn.Linear(out_channels * 3, out_channels)

    def forward(self, x):
        expanded = self.up(x)
        activated = nn.functional.silu(expanded)
        return self.down(activated)
|
| 1425 |
+
|
| 1426 |
+
|
| 1427 |
+
class HunyuanVideoPatchEmbed(nn.Module):
    """Patchify a 3D latent with a single strided Conv3d (kernel == stride),
    mapping ``in_chans`` channels to ``embed_dim`` per patch."""

    def __init__(self, patch_size, in_chans, embed_dim):
        super().__init__()
        self.proj = nn.Conv3d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
|
| 1431 |
+
|
| 1432 |
+
|
| 1433 |
+
class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
    """Patch embedders for clean (conditioning) latents at three temporal/spatial
    scales: 1x (1,2,2), 2x (2,4,4) and 4x (4,8,8) downsampling."""

    def __init__(self, inner_dim):
        super().__init__()
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    @torch.no_grad()
    def initialize_weight_from_another_conv3d(self, another_layer):
        """Initialize all three scales from a pretrained (1,2,2) Conv3d.

        The 2x/4x kernels tile the base kernel over the larger receptive field
        and divide by the tiling factor (8 = 2*2*2, 64 = 4*4*4) so the output
        magnitude matches an average over the tiled positions.
        """
        weight = another_layer.weight.detach().clone()
        bias = another_layer.bias.detach().clone()

        sd = {
            "proj.weight": weight.clone(),
            "proj.bias": bias.clone(),
            "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0,
            "proj_2x.bias": bias.clone(),
            "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0,
            "proj_4x.bias": bias.clone(),
        }

        sd = {k: v.clone() for k, v in sd.items()}

        self.load_state_dict(sd)
        return
|
| 1458 |
+
|
| 1459 |
+
|
| 1460 |
+
class HunyuanVideoTransformer3DModelPacked(nn.Module): # (PreTrainedModelMixin, GenerationMixin,
|
| 1461 |
+
# ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
| 1462 |
+
# @register_to_config
|
| 1463 |
+
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
        has_image_proj=False,
        image_proj_dim=1152,
        has_clean_x_embedder=False,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels
        self.config_patch_size = patch_size
        self.config_patch_size_t = patch_size_t

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)

        self.clean_x_embedder = None
        self.image_projection = None

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoTransformerBlock(
                    num_attention_heads,
                    attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoSingleTransformerBlock(
                    num_attention_heads,
                    attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                )
                for _ in range(num_single_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.inner_dim = inner_dim
        self.use_gradient_checkpointing = False
        self.enable_teacache = False

        # NOTE(review): has_image_proj / has_clean_x_embedder parameters are
        # currently ignored — both modules are installed unconditionally below.
        # if has_image_proj:
        #     self.install_image_projection(image_proj_dim)
        self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim)
        # self.config["has_image_proj"] = True
        # self.config["image_proj_dim"] = in_channels

        # if has_clean_x_embedder:
        #     self.install_clean_x_embedder()
        self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
        # self.config["has_clean_x_embedder"] = True

        self.high_quality_fp32_output_for_inference = True  # False # change default to True

        # Block swapping attributes (initialized to None)
        self.blocks_to_swap = None
        self.offloader_double = None
        self.offloader_single = None
|
| 1562 |
+
|
| 1563 |
+
@property
|
| 1564 |
+
def device(self):
|
| 1565 |
+
return next(self.parameters()).device
|
| 1566 |
+
|
| 1567 |
+
@property
|
| 1568 |
+
def dtype(self):
|
| 1569 |
+
return next(self.parameters()).dtype
|
| 1570 |
+
|
| 1571 |
+
def enable_gradient_checkpointing(self):
|
| 1572 |
+
self.use_gradient_checkpointing = True
|
| 1573 |
+
print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.") # Logging
|
| 1574 |
+
|
| 1575 |
+
def disable_gradient_checkpointing(self):
|
| 1576 |
+
self.use_gradient_checkpointing = False
|
| 1577 |
+
print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.") # Logging
|
| 1578 |
+
|
| 1579 |
+
def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
|
| 1580 |
+
self.enable_teacache = enable_teacache
|
| 1581 |
+
self.cnt = 0
|
| 1582 |
+
self.num_steps = num_steps
|
| 1583 |
+
self.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
|
| 1584 |
+
self.accumulated_rel_l1_distance = 0
|
| 1585 |
+
self.previous_modulated_input = None
|
| 1586 |
+
self.previous_residual = None
|
| 1587 |
+
self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02])
|
| 1588 |
+
if enable_teacache:
|
| 1589 |
+
print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}")
|
| 1590 |
+
else:
|
| 1591 |
+
print("TeaCache disabled.")
|
| 1592 |
+
|
| 1593 |
+
def gradient_checkpointing_method(self, block, *args):
|
| 1594 |
+
if self.use_gradient_checkpointing:
|
| 1595 |
+
result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
|
| 1596 |
+
else:
|
| 1597 |
+
result = block(*args)
|
| 1598 |
+
return result
|
| 1599 |
+
|
| 1600 |
+
    def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
        """Enable CPU<->device block swapping for `num_blocks` transformer blocks.

        The budget is split between the two block groups: half goes to double
        (joint) blocks, and the remainder is doubled (+1) for single blocks —
        a single block is treated as roughly half the size of a double block.
        One ModelOffloader is created per group.

        Args:
            num_blocks: Total swap budget (in double-block-equivalent units).
            device: Compute device the blocks are swapped onto.
            supports_backward: Whether the offloaders must also handle backward.
        """
        self.blocks_to_swap = num_blocks
        self.num_double_blocks = len(self.transformer_blocks)
        self.num_single_blocks = len(self.single_transformer_blocks)
        double_blocks_to_swap = num_blocks // 2
        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1

        # At least one block per group must stay resident on the compute device.
        assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
            f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
            f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
        )

        self.offloader_double = ModelOffloader(
            "double",
            self.transformer_blocks,
            self.num_double_blocks,
            double_blocks_to_swap,
            supports_backward,
            device,
            # debug=True # Optional debugging
        )
        self.offloader_single = ModelOffloader(
            "single",
            self.single_transformer_blocks,
            self.num_single_blocks,
            single_blocks_to_swap,
            supports_backward,
            device,  # , debug=True
        )
        print(
            f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, "
            + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}."
        )
|
| 1633 |
+
|
| 1634 |
+
def switch_block_swap_for_inference(self):
|
| 1635 |
+
if self.blocks_to_swap and self.blocks_to_swap > 0:
|
| 1636 |
+
self.offloader_double.set_forward_only(True)
|
| 1637 |
+
self.offloader_single.set_forward_only(True)
|
| 1638 |
+
self.prepare_block_swap_before_forward()
|
| 1639 |
+
print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.")
|
| 1640 |
+
|
| 1641 |
+
def switch_block_swap_for_training(self):
|
| 1642 |
+
if self.blocks_to_swap and self.blocks_to_swap > 0:
|
| 1643 |
+
self.offloader_double.set_forward_only(False)
|
| 1644 |
+
self.offloader_single.set_forward_only(False)
|
| 1645 |
+
self.prepare_block_swap_before_forward()
|
| 1646 |
+
print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.")
|
| 1647 |
+
|
| 1648 |
+
    def move_to_device_except_swap_blocks(self, device: torch.device):
        """Move the model to `device`, leaving the swapped block lists behind.

        Temporarily detaches `transformer_blocks` / `single_transformer_blocks`
        from the module before `self.to(device)` so they are not moved, then
        re-attaches them. When block swapping is disabled this is a plain
        `self.to(device)`.
        """
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if self.blocks_to_swap:
            saved_double_blocks = self.transformer_blocks
            saved_single_blocks = self.single_transformer_blocks
            # Setting the attributes to None hides the blocks from self.to().
            self.transformer_blocks = None
            self.single_transformer_blocks = None

        self.to(device)

        if self.blocks_to_swap:
            # Restore the original (still un-moved) block lists.
            self.transformer_blocks = saved_double_blocks
            self.single_transformer_blocks = saved_single_blocks
|
| 1661 |
+
|
| 1662 |
+
def prepare_block_swap_before_forward(self):
|
| 1663 |
+
if self.blocks_to_swap is None or self.blocks_to_swap == 0:
|
| 1664 |
+
return
|
| 1665 |
+
self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks)
|
| 1666 |
+
self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks)
|
| 1667 |
+
|
| 1668 |
+
    def process_input_hidden_states(
        self,
        latents,
        latent_indices=None,
        clean_latents=None,
        clean_latent_indices=None,
        clean_latents_2x=None,
        clean_latent_2x_indices=None,
        clean_latents_4x=None,
        clean_latent_4x_indices=None,
    ):
        """Embed noisy latents plus optional clean-context latents into one token sequence.

        Each context group (1x, 2x-downsampled, 4x-downsampled) is embedded with
        its own projection, given matching RoPE frequencies from its frame
        indices, and *prepended* to the sequence — so the noisy tokens always
        occupy the tail of `hidden_states`.

        Returns:
            (hidden_states, rope_freqs) with matching sequence lengths along dim 1.
        """
        # Patchify the noisy latents; output is (B, C, T, H, W) in patch units.
        hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
        B, C, T, H, W = hidden_states.shape

        if latent_indices is None:
            # Default frame indices: 0..T-1 for every batch item.
            latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)

        # (B, C, T*H*W) -> (B, T*H*W, C): flatten spatial dims into a token axis.
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
        rope_freqs = rope_freqs.flatten(2).transpose(1, 2)

        if clean_latents is not None and clean_latent_indices is not None:
            clean_latents = clean_latents.to(hidden_states)
            clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
            clean_latents = clean_latents.flatten(2).transpose(1, 2)

            clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
            clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)

            # Prepend clean-context tokens (and their RoPE freqs) to the sequence.
            hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)

        if clean_latents_2x is not None and clean_latent_2x_indices is not None:
            clean_latents_2x = clean_latents_2x.to(hidden_states)
            # Pad so the 2x projection kernel (2, 4, 4) divides the input evenly.
            clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
            clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
            clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)

            clean_latent_2x_rope_freqs = self.rope(
                frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device
            )
            # Downsample the RoPE grid to match the 2x-coarsened token grid.
            clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)

        if clean_latents_4x is not None and clean_latent_4x_indices is not None:
            clean_latents_4x = clean_latents_4x.to(hidden_states)
            # Pad for the 4x projection kernel (4, 8, 8).
            clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
            clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
            clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)

            clean_latent_4x_rope_freqs = self.rope(
                frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device
            )
            # Downsample the RoPE grid to match the 4x-coarsened token grid.
            clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)

        return hidden_states, rope_freqs
|
| 1734 |
+
|
| 1735 |
+
    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask,
        pooled_projections,
        guidance,
        latent_indices=None,
        clean_latents=None,
        clean_latent_indices=None,
        clean_latents_2x=None,
        clean_latent_2x_indices=None,
        clean_latents_4x=None,
        clean_latent_4x_indices=None,
        image_embeddings=None,
        attention_kwargs=None,
        return_dict=True,
    ):
        """Denoising forward pass of the packed HunyuanVideo transformer.

        Packs noisy + clean-context latents into a token sequence, runs the
        double and single transformer stacks (optionally with TeaCache
        step-skipping or block swapping), and unpatchifies the last
        `original_context_length` tokens back to (B, C, T, H, W).

        Returns:
            SimpleNamespace(sample=tensor) when `return_dict`, else (tensor,).
        """

        if attention_kwargs is None:
            attention_kwargs = {}

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config_patch_size, self.config_patch_size_t
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        # Number of tokens belonging to the *noisy* latents; clean-context
        # tokens are prepended, so the output is cropped to this tail length.
        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

        hidden_states, rope_freqs = self.process_input_hidden_states(
            hidden_states,
            latent_indices,
            clean_latents,
            clean_latent_indices,
            clean_latents_2x,
            clean_latent_2x_indices,
            clean_latents_4x,
            clean_latent_4x_indices,
        )
        del (
            latent_indices,
            clean_latents,
            clean_latent_indices,
            clean_latents_2x,
            clean_latent_2x_indices,
            clean_latents_4x,
            clean_latent_4x_indices,
        )  # free memory

        temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
        encoder_hidden_states = self.gradient_checkpointing_method(
            self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask
        )

        if self.image_projection is not None:
            assert image_embeddings is not None, "You must use image embeddings!"
            extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
            # Image tokens are always valid -> all-ones mask.
            extra_attention_mask = torch.ones(
                (batch_size, extra_encoder_hidden_states.shape[1]),
                dtype=encoder_attention_mask.dtype,
                device=encoder_attention_mask.device,
            )

            # must cat before (not after) encoder_hidden_states, due to attn masking
            encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
            encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
            del extra_encoder_hidden_states, extra_attention_mask  # free memory

        with torch.no_grad():
            if batch_size == 1:
                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
                # If they are not same, then their impls are wrong. Ours are always the correct one.
                text_len = encoder_attention_mask.sum().item()
                encoder_hidden_states = encoder_hidden_states[:, :text_len]
                attention_mask = None, None, None, None
            else:
                # Variable-length attention metadata (cu_seqlens/max_seqlen tuple).
                img_seq_len = hidden_states.shape[1]
                txt_seq_len = encoder_hidden_states.shape[1]

                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
                cu_seqlens_kv = cu_seqlens_q
                max_seqlen_q = img_seq_len + txt_seq_len
                max_seqlen_kv = max_seqlen_q

                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
                del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv  # free memory
            del encoder_attention_mask  # free memory

        if self.enable_teacache:
            # TeaCache: estimate how much the modulated input changed since the
            # previous step; if the accumulated (rescaled) distance is below the
            # threshold, reuse the cached residual instead of running the blocks.
            modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

            if self.cnt == 0 or self.cnt == self.num_steps - 1:
                # Always compute on the first and last step.
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                curr_rel_l1 = (
                    ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())
                    .cpu()
                    .item()
                )
                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
                should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

                if should_calc:
                    self.accumulated_rel_l1_distance = 0

            self.previous_modulated_input = modulated_inp
            self.cnt += 1

            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                # Skip the transformer stacks; apply the cached residual.
                hidden_states = hidden_states + self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()

                for block_id, block in enumerate(self.transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                for block_id, block in enumerate(self.single_transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                # Cache the residual for potential reuse on the next step(s).
                self.previous_residual = hidden_states - ori_hidden_states
                del ori_hidden_states  # free memory
        else:
            # Plain path with optional block swapping around each block.
            for block_id, block in enumerate(self.transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_double.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)

            for block_id, block in enumerate(self.single_transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_single.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)

        del attention_mask, rope_freqs  # free memory
        del encoder_hidden_states  # free memory

        hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

        # Drop the prepended clean-context tokens; keep only the noisy-latent tail.
        hidden_states = hidden_states[:, -original_context_length:, :]

        if self.high_quality_fp32_output_for_inference:
            # Run the final projection in fp32 for output quality.
            hidden_states = hidden_states.to(dtype=torch.float32)
            if self.proj_out.weight.dtype != torch.float32:
                self.proj_out.to(dtype=torch.float32)

        hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

        # Unpatchify: tokens -> (B, C, T, H, W).
        hidden_states = einops.rearrange(
            hidden_states,
            "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)",
            t=post_patch_num_frames,
            h=post_patch_height,
            w=post_patch_width,
            pt=p_t,
            ph=p,
            pw=p,
        )

        if return_dict:
            # return Transformer2DModelOutput(sample=hidden_states)
            return SimpleNamespace(sample=hidden_states)

        return (hidden_states,)
|
| 1918 |
+
|
| 1919 |
+
    def fp8_optimization(
        self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
    ) -> dict[str, torch.Tensor]:  # Return type hint added
        """
        Optimize the model state_dict with fp8.

        Only weights under the transformer block prefixes are converted; norm
        layers are excluded. The model itself is monkey-patched so the fp8
        weights can be used at runtime.

        Args:
            state_dict (dict[str, torch.Tensor]):
                The state_dict of the model.
            device (torch.device):
                The device to calculate the weight.
            move_to_device (bool):
                Whether to move the weight to the device after optimization.
            use_scaled_mm (bool):
                Whether to use scaled matrix multiplication for FP8.

        Returns:
            dict[str, torch.Tensor]: The (in-place) optimized state_dict.
        """
        TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
        EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8

        # inplace optimization
        state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)

        # apply monkey patching
        apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)

        return state_dict
|
| 1945 |
+
|
| 1946 |
+
|
| 1947 |
+
def load_packed_model(
    device: Union[str, torch.device],
    dit_path: str,
    attn_mode: str,
    loading_device: Union[str, torch.device],
    fp8_scaled: bool = False,
    split_attn: bool = False,
) -> HunyuanVideoTransformer3DModelPacked:
    """Create a HunyuanVideoTransformer3DModelPacked and load its safetensors weights.

    Args:
        device: Compute device used for fp8 optimization math.
        dit_path: Path to a .safetensors file, or a directory containing one
            (the first file in sorted order is used).
        attn_mode: Attention backend passed through to the model.
        loading_device: Device the weights finally live on (CPU for block swap).
        fp8_scaled: If True, weights are loaded on CPU and converted to fp8.
        split_attn: Currently unsupported (see TODO below); passed through.

    Returns:
        The model with weights loaded (strict, assigned in place).
    """
    # TODO support split_attn
    device = torch.device(device)
    loading_device = torch.device(loading_device)

    if os.path.isdir(dit_path):
        # we don't support from_pretrained for now, so loading safetensors directly
        safetensor_files = glob.glob(os.path.join(dit_path, "*.safetensors"))
        if len(safetensor_files) == 0:
            raise ValueError(f"Cannot find safetensors file in {dit_path}")
        # sort by name and take the first one
        safetensor_files.sort()
        dit_path = safetensor_files[0]

    # Build the module skeleton without allocating real weight memory.
    with init_empty_weights():
        logger.info(f"Creating HunyuanVideoTransformer3DModelPacked")
        model = HunyuanVideoTransformer3DModelPacked(
            attention_head_dim=128,
            guidance_embeds=True,
            has_clean_x_embedder=True,
            has_image_proj=True,
            image_proj_dim=1152,
            in_channels=16,
            mlp_ratio=4.0,
            num_attention_heads=24,
            num_layers=20,
            num_refiner_layers=2,
            num_single_layers=40,
            out_channels=16,
            patch_size=2,
            patch_size_t=1,
            pooled_projection_dim=768,
            qk_norm="rms_norm",
            rope_axes_dim=(16, 56, 56),
            rope_theta=256.0,
            text_embed_dim=4096,
            attn_mode=attn_mode,
            split_attn=split_attn,
        )

    # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
    dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device
    logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}")

    # load model weights with the specified dtype or as is
    sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True)

    if fp8_scaled:
        # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
        logger.info(f"Optimizing model weights to fp8. This may take a while.")
        sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")

        if loading_device.type != "cpu":
            # make sure all the model weights are on the loading_device
            logger.info(f"Moving weights to {loading_device}")
            for key in sd.keys():
                sd[key] = sd[key].to(loading_device)

    # assign=True places the loaded tensors directly into the meta-initialized model.
    info = model.load_state_dict(sd, strict=True, assign=True)
    logger.info(f"Loaded DiT model from {dit_path}, info={info}")

    return model
|
frame_pack/k_diffusion_hunyuan.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# original code: https://github.com/lllyasviel/FramePack
|
| 2 |
+
# original license: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
# from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
|
| 8 |
+
# from diffusers_helper.k_diffusion.wrapper import fm_wrapper
|
| 9 |
+
# from diffusers_helper.utils import repeat_to_batch_size
|
| 10 |
+
from frame_pack.uni_pc_fm import sample_unipc
|
| 11 |
+
from frame_pack.wrapper import fm_wrapper
|
| 12 |
+
from frame_pack.utils import repeat_to_batch_size
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def flux_time_shift(t, mu=1.15, sigma=1.0):
    """Flux-style timestep warp: maps t in (0, 1] through a sigmoid-like curve
    controlled by `mu` (shift strength) and `sigma` (exponent)."""
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1 / t - 1) ** sigma)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    """Linearly interpolate mu from the token context length through the points
    (x1, y1) and (x2, y2), clamped so that exp(mu) never exceeds `exp_max`."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    mu = slope * context_length + intercept
    return min(mu, math.log(exp_max))
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_flux_sigmas_from_mu(n, mu):
    """Return n+1 sigmas descending from 1 to 0, warped by the flux time shift."""
    linear_sigmas = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(linear_sigmas, mu=mu)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# @torch.inference_mode()
|
| 34 |
+
# @torch.inference_mode()
def sample_hunyuan(
    transformer,
    sampler="unipc",
    initial_latent=None,
    concat_latent=None,
    strength=1.0,
    width=512,
    height=512,
    frames=16,
    real_guidance_scale=1.0,
    distilled_guidance_scale=6.0,
    guidance_rescale=0.0,
    shift=None,
    num_inference_steps=25,
    batch_size=None,
    generator=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    prompt_poolers=None,
    negative_prompt_embeds=None,
    negative_prompt_embeds_mask=None,
    negative_prompt_poolers=None,
    dtype=torch.bfloat16,
    device=None,
    negative_kwargs=None,
    callback=None,
    **kwargs,
):
    """Sample a HunyuanVideo latent with the flow-matching UniPC sampler.

    Builds initial noise, the flux sigma schedule (mu from sequence length
    unless `shift` is given), and positive/negative conditioning dicts, then
    runs `sample_unipc` over the wrapped transformer.

    Returns:
        The final denoised latent tensor from the sampler.

    NOTE(review): `generator.device` is accessed unconditionally — this
    assumes the caller always passes a torch.Generator; confirm at call sites.
    """
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Latent grid: 16 channels, (frames+3)//4 temporal, /8 spatial downsampling.
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=generator.device
    ).to(device=device, dtype=torch.float32)

    B, C, T, H, W = latents.shape
    seq_length = T * H * W // 4  # 9*80*80//4 = 14400

    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)  # 1.9459... if seq_len is large, mu is clipped.
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # Image-to-video style start: blend the initial latent with noise
        # according to the first (strength-scaled) sigma.
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    # Distilled guidance is passed as an embedding scaled by 1000.
    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        ),
    )

    if sampler == "unipc":
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f"Sampler {sampler} is not supported.")

    return results
|
frame_pack/uni_pc_fm.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Better Flow Matching UniPC by Lvmin Zhang
|
| 2 |
+
# (c) 2025
|
| 3 |
+
# CC BY-SA 4.0
|
| 4 |
+
# Attribution-ShareAlike 4.0 International Licence
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
from tqdm.auto import trange
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def expand_dims(v, dims):
    """Append trailing singleton axes to `v` until it has `dims` total dimensions
    beyond the existing ones (indexing with `...` keeps all current axes)."""
    trailing_axes = (None,) * (dims - 1)
    return v[(...,) + trailing_axes]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FlowMatchUniPC:
    """Flow-matching UniPC sampler: a multistep predictor-corrector ODE solver
    operating on sigmas in (0, 1], with lambda = -log(sigma) as the log-SNR-like
    time variable."""

    def __init__(self, model, extra_args, variant='bh1'):
        # `variant` picks the B(h) normalization: 'bh1' -> h, 'bh2' -> expm1(h).
        self.model = model
        self.variant = variant
        self.extra_args = extra_args

    def model_fn(self, x, t):
        # Evaluate the wrapped model at (x, t) with the fixed extra arguments.
        return self.model(x, t, **self.extra_args)

    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
        """One UniPC step from t_prev_list[-1] down to t.

        Uses up to `order` previous model outputs to build divided differences
        (D1s), solves the small Vandermonde-like system for predictor and
        corrector coefficients, then applies predict -> evaluate -> correct.
        Returns (x_t, model_t).
        """
        assert order <= len(model_prev_list)
        dims = x.dim()

        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = - torch.log(t_prev_0)
        lambda_t = - torch.log(t)
        model_prev_0 = model_prev_list[-1]

        # Step size in lambda space.
        h = lambda_t - lambda_prev_0

        # rks: relative lambda offsets of history points; D1s: scaled differences.
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = - torch.log(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0]
        h_phi_1 = torch.expm1(hh)  # phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError('Bad variant!')

        # Build the linear system rows (powers of rks) and rhs (phi terms).
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        # Predictor uses only the history differences; unavailable at order 1.
        use_predictor = len(D1s) > 0

        if use_predictor:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
            rhos_p = None

        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        # Base (Euler-like) update in flow-matching form.
        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0

        if use_predictor:
            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
        else:
            pred_res = 0

        # Predict, evaluate the model at the predicted point, then correct.
        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
        model_t = self.model_fn(x_t, t)

        if D1s is not None:
            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
        else:
            corr_res = 0

        D1_t = (model_t - model_prev_0)
        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

        return x_t, model_t

    def sample(self, x, sigmas, callback=None, disable_pbar=False):
        """Integrate x along `sigmas` (descending); ramps the order up over the
        first steps, keeps a sliding window of history, and returns the final
        model output (denoised estimate), not x itself."""
        order = min(3, len(sigmas) - 2)
        model_prev_list, t_prev_list = [], []
        for i in trange(len(sigmas) - 1, disable=disable_pbar):
            vec_t = sigmas[i].expand(x.shape[0])

            with torch.no_grad():
                if i == 0:
                    # Warm-up: seed the history with the first model evaluation.
                    model_prev_list = [self.model_fn(x, vec_t)]
                    t_prev_list = [vec_t]
                elif i < order:
                    # Ramp the multistep order up while history is short.
                    init_order = i
                    x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                else:
                    x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)

                # Keep only the `order` most recent history entries.
                model_prev_list = model_prev_list[-order:]
                t_prev_list = t_prev_list[-order:]

            if callback is not None:
                callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})

        return model_prev_list[-1]
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Convenience entry point: build a FlowMatchUniPC solver and sample.

    Args:
        model: denoiser callable passed to the solver.
        noise: initial latent noise tensor.
        sigmas: 1D schedule of noise levels.
        extra_args: forwarded to the model on every call.
        callback: per-step callback (see FlowMatchUniPC.sample).
        disable: disables the progress bar.
        variant: UniPC B(h) variant, 'bh1' or 'bh2'.
    """
    assert variant in ['bh1', 'bh2']
    return FlowMatchUniPC(model, extra_args=extra_args, variant=variant).sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
|
frame_pack/utils.py
ADDED
|
@@ -0,0 +1,617 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import cv2
|
| 3 |
+
import json
|
| 4 |
+
import random
|
| 5 |
+
import glob
|
| 6 |
+
import torch
|
| 7 |
+
import einops
|
| 8 |
+
import numpy as np
|
| 9 |
+
import datetime
|
| 10 |
+
import torchvision
|
| 11 |
+
|
| 12 |
+
import safetensors.torch as sf
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def min_resize(x, m):
    """Resize HWC image `x` so its shorter side becomes `m` pixels, keeping aspect.

    Downscaling uses INTER_AREA (better for shrinking); upscaling uses
    INTER_LANCZOS4.
    """
    if x.shape[0] < x.shape[1]:
        # Height is the short side.
        s0 = m
        s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        # Width is the short side.
        s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
        s1 = m
    new_max = max(s1, s0)
    raw_max = max(x.shape[0], x.shape[1])
    if new_max < raw_max:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    # cv2.resize takes (width, height) order.
    y = cv2.resize(x, (s1, s0), interpolation=interpolation)
    return y
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def d_resize(x, y):
    """Resize image `x` to the spatial size of reference image `y` (HWC).

    Downscaling uses INTER_AREA; upscaling uses INTER_LANCZOS4.
    """
    H, W, C = y.shape
    new_min = min(H, W)
    raw_min = min(x.shape[0], x.shape[1])
    if new_min < raw_min:
        interpolation = cv2.INTER_AREA
    else:
        interpolation = cv2.INTER_LANCZOS4
    # Note: the reference `y` is shadowed by the result on purpose.
    y = cv2.resize(x, (W, H), interpolation=interpolation)
    return y
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def resize_and_center_crop(image, target_width, target_height):
    """Scale an HWC ndarray to cover the target size, then center-crop to it.

    Uses LANCZOS resampling via PIL. Returns the input unchanged if it already
    has the target size.
    """
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    original_width, original_height = pil_image.size
    # 'Cover' scaling: the smaller of the two scale ratios would leave gaps.
    scale_factor = max(target_width / original_width, target_height / original_height)
    resized_width = int(round(original_width * scale_factor))
    resized_height = int(round(original_height * scale_factor))
    resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
    # PIL.crop accepts float box coordinates; this centers the crop window.
    left = (resized_width - target_width) / 2
    top = (resized_height - target_height) / 2
    right = (resized_width + target_width) / 2
    bottom = (resized_height + target_height) / 2
    cropped_image = resized_image.crop((left, top, right, bottom))
    return np.array(cropped_image)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Resize a (B, C, H, W) tensor to cover the target size, then center-crop.

    Bilinear interpolation; returns the input unchanged if already the right size.
    """
    batch, channels, height, width = image.shape

    if (height, width) == (target_height, target_width):
        return image

    # 'Cover' scaling so both target dimensions are filled.
    scale = max(target_width / width, target_height / height)
    new_w = int(round(width * scale))
    new_h = int(round(height * scale))

    upsampled = torch.nn.functional.interpolate(
        image, size=(new_h, new_w), mode="bilinear", align_corners=False
    )

    y0 = (new_h - target_height) // 2
    x0 = (new_w - target_width) // 2
    return upsampled[:, :, y0 : y0 + target_height, x0 : x0 + target_width]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def resize_without_crop(image, target_width, target_height):
    """Resize an HWC ndarray to exactly the target size (aspect NOT preserved).

    Uses PIL LANCZOS resampling; returns the input unchanged if already sized.
    """
    if target_height == image.shape[0] and target_width == image.shape[1]:
        return image

    pil_image = Image.fromarray(image)
    resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized_image)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def just_crop(image, w, h):
    """Center-crop `image` to aspect ratio w:h without any resizing.

    The crop is the largest w:h-proportioned window that fits inside the image.
    """
    if (image.shape[0], image.shape[1]) == (h, w):
        return image

    src_h, src_w = image.shape[:2]
    scale = min(src_h / h, src_w / w)
    crop_w = int(round(w * scale))
    crop_h = int(round(h * scale))
    x0 = (src_w - crop_w) // 2
    y0 = (src_h - crop_h) // 2
    return image[y0 : y0 + crop_h, x0 : x0 + crop_w]
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def write_to_json(data, file_path):
    """Atomically serialize `data` as pretty-printed JSON to `file_path`.

    Writes to a sibling .tmp file first and renames it into place so readers
    never observe a half-written document.
    """
    tmp_path = file_path + ".tmp"
    with open(tmp_path, "wt", encoding="utf-8") as handle:
        json.dump(data, handle, indent=4)
    os.replace(tmp_path, file_path)
    return
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def read_from_json(file_path):
    """Load and return the JSON document stored at `file_path` (UTF-8)."""
    with open(file_path, "rt", encoding="utf-8") as handle:
        return json.load(handle)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_active_parameters(m):
    """Return a dict of the module's trainable (requires_grad) named parameters."""
    active = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            active[name] = param
    return active
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def cast_training_params(m, dtype=torch.float32):
    """Cast all trainable parameters of `m` to `dtype` in place.

    Returns the trainable parameters keyed by name (same Parameter objects).
    """
    trainable = {name: p for name, p in m.named_parameters() if p.requires_grad}
    for param in trainable.values():
        param.data = param.to(dtype)
    return trainable
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def separate_lora_AB(parameters, B_patterns=None):
    """Split a parameter dict into (non-B, B) groups by key-substring patterns.

    A key lands in the B group when any pattern occurs in it; defaults target
    LoRA 'B' matrices and zero-initialized modules.
    """
    if B_patterns is None:
        B_patterns = [".lora_B.", "__zero__"]

    parameters_normal, parameters_B = {}, {}
    for key, value in parameters.items():
        bucket = parameters_B if any(p in key for p in B_patterns) else parameters_normal
        bucket[key] = value

    return parameters_normal, parameters_B
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def set_attr_recursive(obj, attr, value):
    """Set a dotted attribute path: attr='a.b.c' performs obj.a.b.c = value."""
    *parents, leaf = attr.split(".")
    for name in parents:
        obj = getattr(obj, name)
    setattr(obj, leaf, value)
    return
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def print_tensor_list_size(tensors):
    """Print count, total byte size (MB), and element count (billions) of tensors.

    Accepts a list of tensors or a dict whose values are tensors.
    """
    total_size = 0
    total_elements = 0

    if isinstance(tensors, dict):
        tensors = tensors.values()

    for tensor in tensors:
        total_size += tensor.nelement() * tensor.element_size()
        total_elements += tensor.nelement()

    total_size_MB = total_size / (1024**2)
    total_elements_B = total_elements / 1e9

    # NOTE(review): relies on `tensors` supporting len() and re-iteration — a
    # generator input would break/produce zero totals; confirm callers pass
    # lists or dicts only.
    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_size_MB:.2f} MB")
    print(f"Total number of parameters: {total_elements_B:.3f} billion")
    return
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Per-sample mixture of two batches: take a[i] where mask_a[i], else b[i].

    When `b` is omitted it defaults to zeros; when `mask_a` is omitted it is
    sampled i.i.d. Bernoulli(probability_a) per batch element.
    """
    batch_size = a.size(0)

    if b is None:
        b = torch.zeros_like(a)

    if mask_a is None:
        mask_a = torch.rand(batch_size) < probability_a

    # Broadcast the per-sample mask across all trailing dimensions.
    broadcast_shape = (batch_size,) + (1,) * (a.dim() - 1)
    mask = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(mask, a, b)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
@torch.no_grad()
def zero_module(module):
    """Zero out every parameter of `module` in place and return the module."""
    for param in module.parameters():
        param.detach().zero_()
    return module
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first `k` input channels of m.weight by `alpha`, in place.

    Requires the weight's second dimension (input channels) to be >= k.
    """
    weight = m.weight.data.clone()

    assert int(weight.shape[1]) >= k

    weight[:, :k] *= alpha
    m.weight.data = weight.contiguous().clone()
    return m
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def freeze_module(m):
    """Freeze module `m`: disable grads and run its forward under no_grad.

    Idempotent: the original forward is stashed once in
    `_forward_inside_frozen_module`, and the no_grad wrapper is applied only on
    the first call. (The previous version re-wrapped `m.forward` on every call,
    stacking decorators even though the stash guard implied freeze-once.)
    """
    if not hasattr(m, "_forward_inside_frozen_module"):
        m._forward_inside_frozen_module = m.forward
        # Wrap only on first freeze so repeated calls don't stack wrappers.
        m.forward = torch.no_grad()(m.forward)
    m.requires_grad_(False)
    return m
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_latest_safetensors(folder_path):
    """Return the absolute real path of the newest .safetensors file in a folder.

    Raises ValueError when the folder contains no .safetensors files.
    """
    candidates = glob.glob(os.path.join(folder_path, "*.safetensors"))

    if not candidates:
        raise ValueError("No file to resume!")

    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of ', '-separated tags.

    Picks between min_length and max_length tags (capped at the number
    available), without repetition, and joins them back with ', '.
    """
    all_tags = tags_str.split(", ")
    count = min(random.randint(min_length, max_length), len(all_tags))
    chosen = random.sample(all_tags, k=count)
    return ", ".join(chosen)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """Return n values interpolated from a to b with optional gamma warping.

    gamma != 1 bends the spacing (gamma > 1 clusters values toward `a`).
    """
    warped = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * warped
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Stratified sampling: one uniform draw per equal sub-interval of the range.

    Splits [inclusive, exclusive) into n equal intervals and samples once from
    each, so the results are spread across the whole range.
    """
    bounds = np.linspace(0, 1, n + 1)
    unit_samples = np.random.uniform(bounds[:-1], bounds[1:])
    values = inclusive + (exclusive - inclusive) * unit_samples
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def soft_append_bcthw(history, current, overlap=0):
    """Concatenate two (B, C, T, H, W) clips on the time axis with a cross-fade.

    The last `overlap` frames of `history` are linearly blended with the first
    `overlap` frames of `current`; overlap <= 0 falls back to plain concat.
    """
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    # Ramp: history fades 1 -> 0 while current fades 0 -> 1 across the overlap.
    fade = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    mixed = fade * history[:, :, -overlap:] + (1 - fade) * current[:, :, :overlap]
    merged = torch.cat([history[:, :, :-overlap], mixed, current[:, :, overlap:]], dim=2)

    return merged.to(history)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Save a (B, C, T, H, W) tensor in [-1, 1] as a tiled, lossless MP4.

    Batch items are tiled into a grid whose width is the largest divisor of B
    in 6..2 (else one row), values are mapped to uint8, and the video is
    encoded with libx264 at CRF 0. The tiled uint8 tensor is also dumped next
    to the video with a .pt extension, and returned.
    """
    b, c, t, h, w = x.shape

    # Grid width: largest p in 6..2 dividing the batch size evenly.
    per_row = b
    for p in [6, 5, 4, 3, 2]:
        if b % p == 0:
            per_row = p
            break

    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    # [-1, 1] float -> [0, 255] uint8.
    x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, "(m n) c t h w -> t (m h) (n w) c", n=per_row)
    torchvision.io.write_video(output_filename, x, fps=fps, video_codec="libx264", options={"crf": "0"})

    # write tensor as .pt file
    # NOTE(review): if output_filename does not end in ".mp4", replace() is a
    # no-op and torch.save overwrites the video just written — confirm callers
    # always pass .mp4 paths.
    torch.save(x, output_filename.replace(".mp4", ".pt"))

    return x
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def save_bcthw_as_png(x, output_filename):
    """Save a (B, C, T, H, W) tensor in [-1, 1] as one PNG contact sheet.

    Batch items are tiled vertically and time steps horizontally; returns the
    output path.
    """
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5  # [-1, 1] -> [0, 255]
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, "b c t h w -> c (b h) (t w)")
    torchvision.io.write_png(x, output_filename)
    return output_filename
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def save_bchw_as_png(x, output_filename):
    """Save a (B, C, H, W) tensor in [-1, 1] as one PNG, batch tiled horizontally.

    Returns the output path.
    """
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5  # [-1, 1] -> [0, 255]
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, "b c h w -> c h (b w)")
    torchvision.io.write_png(x, output_filename)
    return output_filename
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
def add_tensors_with_padding(tensor1, tensor2):
    """Add two same-rank tensors, zero-padding each up to the per-dim max shape.

    The result has shape max(shape1, shape2) elementwise; regions covered by
    only one tensor take that tensor's values (the other contributes zeros).
    Equal-shaped inputs skip the padding entirely.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    shape1 = tensor1.shape
    shape2 = tensor2.shape

    new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))

    # Match tensor1's dtype/device: bare torch.zeros() would create float32 CPU
    # buffers and silently promote half-precision inputs (or fail on CUDA).
    padded_tensor1 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)
    padded_tensor2 = torch.zeros(new_shape, dtype=tensor1.dtype, device=tensor1.device)

    padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
    padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2

    return padded_tensor1 + padded_tensor2
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def print_free_mem():
    """Print free and total memory (MB) of CUDA device 0 after emptying the cache.

    Requires a CUDA-capable runtime; mem_get_info queries the driver directly.
    """
    torch.cuda.empty_cache()
    free_mem, total_mem = torch.cuda.mem_get_info(0)
    free_mem_mb = free_mem / (1024**2)
    total_mem_mb = total_mem / (1024**2)
    print(f"Free memory: {free_mem_mb:.2f} MB")
    print(f"Total memory: {total_mem_mb:.2f} MB")
    return
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def print_gpu_parameters(device, state_dict, log_count=1):
    """Print a one-line summary of a state dict for debugging.

    Shows the device label, the number of keys, and the first three values of
    up to `log_count` parameters.
    """
    shown = list(state_dict.items())[:log_count]
    logged_params = {key: tensor.flatten()[:3].tolist() for key, tensor in shown}

    summary = {"device": device, "keys_count": len(state_dict)}
    summary["params"] = logged_params

    print(str(summary))
    return
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def visualize_txt_as_img(width, height, text, font_path="font/DejaVuSans.ttf", size=18):
    """Render `text` word-wrapped in black on a white RGB canvas; return ndarray.

    Lines that would overflow the bottom of the canvas are dropped. Requires
    the TrueType font at `font_path` to exist.
    """
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    words = text.split()
    # Guard against empty AND whitespace-only text: the previous `text == ""`
    # check let strings like "   " through, crashing on words[0].
    if not words:
        return np.array(txt)

    # Greedy word-wrap: extend the current line while it still fits the width.
    lines = []
    current_line = words[0]
    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word
    lines.append(current_line)

    # Draw the text line by line until vertical space runs out.
    y = 0
    line_height = draw.textbbox((0, 0), "A", font=font)[3]
    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def blue_mark(x):
    """Return a copy of HWC image x (values in [-1, 1]) with the blue channel
    sharpened: the high-frequency part (channel minus a 9x9 box blur) is
    boosted 16x and the result clipped back to [-1, 1]."""
    x = x.copy()
    c = x[:, :, 2]
    b = cv2.blur(c, (9, 9))
    x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
    return x
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def green_mark(x):
    """Return a copy of HWC image x with red and blue channels forced to -1,
    leaving only the green channel (a visual debug marker)."""
    marked = x.copy()
    marked[:, :, 0] = -1
    marked[:, :, 2] = -1
    return marked
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def frame_mark(x):
    """Return a copy of image x with 64-px black top/bottom bars and 8-px white
    side bars (a visual debug frame). Side bars overwrite the corners."""
    marked = x.copy()
    marked[:64] = -1
    marked[-64:] = -1
    marked[:, :8] = 1
    marked[:, -8:] = 1
    return marked
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert an iterable of CHW tensors in [-1, 1] to HWC uint8 arrays."""
    converted = []
    for img in imgs:
        hwc = img.movedim(0, -1)
        scaled = hwc * 127.5 + 127.5
        arr = scaled.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
        converted.append(arr)
    return converted
|
| 420 |
+
|
| 421 |
+
|
| 422 |
+
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC uint8 arrays in [0, 255] into a BCHW float tensor in [-1, 1]."""
    stacked = np.stack(imgs, axis=0)
    batch = torch.from_numpy(stacked).float() / 127.5 - 1.0
    return batch.movedim(-1, 1)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append the first `count` slices of x to its end along dim 0.

    With zero_out=True the appended slices are zeros of the same shape instead.
    """
    suffix = torch.zeros_like(x[:count]) if zero_out else x[:count]
    return torch.cat([x, suffix], dim=0)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def weighted_mse(a, b, weight):
    """Mean of weight * (a - b)^2, computed in float32."""
    diff = a.float() - b.float()
    return torch.mean(weight.float() * diff * diff)
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map x from [x_min, x_max] to [y_min, y_max], clamped, with power warping.

    sigma != 1 bends the normalized position before interpolation.
    """
    t = (x - x_min) / (x_max - x_min)
    t = min(max(t, 0.0), 1.0)
    return y_min + (t**sigma) * (y_max - y_min)
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def expand_to_dims(x, target_dims):
    """View x with trailing singleton dims so it has at least `target_dims` dims."""
    extra = max(0, target_dims - x.dim())
    return x.view(*x.shape, *([1] * extra))
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile `tensor` along dim 0 so its first dimension equals `batch_size`.

    None passes through unchanged; a tensor already at `batch_size` is returned
    as-is. Raises ValueError when batch_size is not a multiple of the current
    first dimension.
    """
    if tensor is None:
        return None

    first_dim = tensor.shape[0]

    if first_dim == batch_size:
        return tensor

    if batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    reps = [batch_size // first_dim] + [1] * (tensor.dim() - 1)
    return tensor.repeat(*reps)
|
| 467 |
+
|
| 468 |
+
|
| 469 |
+
def dim5(x):
    """Shorthand: view x with trailing singleton dims up to 5 dimensions."""
    return expand_to_dims(x, 5)


def dim4(x):
    """Shorthand: view x with trailing singleton dims up to 4 dimensions."""
    return expand_to_dims(x, 4)


def dim3(x):
    """Shorthand: view x with trailing singleton dims up to 3 dimensions."""
    return expand_to_dims(x, 3)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def crop_or_pad_yield_mask(x, length):
    """Force (B, F, C) sequence x to exactly `length` frames.

    Returns (sequence, mask) where mask is True for frames that came from the
    input (padding frames are zero and masked False).
    """
    B, F, C = x.shape
    device = x.device
    dtype = x.dtype

    if F < length:
        # Too short: right-pad with zeros and mark only the real frames valid.
        padded = torch.zeros((B, length, C), dtype=dtype, device=device)
        padded[:, :F, :] = x
        mask = torch.zeros((B, length), dtype=torch.bool, device=device)
        mask[:, :F] = True
        return padded, mask

    # Long enough: crop; every kept frame is valid.
    return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
|
| 494 |
+
|
| 495 |
+
|
| 496 |
+
def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Pad x along `dim` up to `minimal_length` frames.

    Padding is zeros when zero_pad=True, otherwise the last slice along `dim`
    repeated. Inputs already long enough are returned unchanged.
    """
    current = int(x.shape[dim])

    if current >= minimal_length:
        return x

    missing = minimal_length - current
    if zero_pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = missing
        padding = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    else:
        # Grab the final slice along `dim` and repeat it to fill the gap.
        selector = (slice(None),) * dim + (slice(-1, None),)
        padding = x[selector].repeat_interleave(missing, dim=dim)

    return torch.cat([x, padding], dim=dim)
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def lazy_positional_encoding(t, repeats=None):
    """Sinusoidal timestep embedding (dim 256) for scalar(s) `t`.

    Returns shape (len(t), 256), or (len(t), repeats, 256) when `repeats` is
    given. The diffusers import is deferred so the module loads without it.
    """
    if not isinstance(t, list):
        t = [t]

    from diffusers.models.embeddings import get_timestep_embedding

    te = torch.tensor(t)
    te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)

    if repeats is None:
        return te

    # expand() broadcasts (no copy) each embedding `repeats` times along dim 1.
    te = te[:, None, :].expand(-1, repeats, -1)

    return te
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def state_dict_offset_merge(A, B, C=None):
    """Elementwise A + B (or A + B - C when C is given) over A's keys.

    B and C values are cast to A's dtype/device before the arithmetic.
    """
    merged = {}

    for key in A.keys():
        a_val = A[key]
        b_val = B[key].to(a_val)

        if C is None:
            merged[key] = a_val + b_val
        else:
            merged[key] = a_val + b_val - C[key].to(a_val)

    return merged
|
| 546 |
+
|
| 547 |
+
|
| 548 |
+
def state_dict_weighted_merge(state_dicts, weights):
    """Weighted average of several state dicts (weights normalized to sum 1).

    Raises ValueError on length mismatch or zero weight sum; an empty list
    yields an empty dict. Later dicts are cast to the first dict's dtype/device.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")

    if not state_dicts:
        return {}

    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")

    normalized = [w / total_weight for w in weights]

    merged = {}
    for key in state_dicts[0].keys():
        acc = state_dicts[0][key] * normalized[0]
        for sd, w in zip(state_dicts[1:], normalized[1:]):
            acc += sd[key].to(acc) * w
        merged[key] = acc

    return merged
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def group_files_by_folder(all_files):
    """Group file paths by their immediate parent folder name.

    Returns a list of path lists, one per distinct parent folder, preserving
    first-seen folder order and input order within each group.
    """
    grouped_files = {}

    for path in all_files:
        folder_name = os.path.basename(os.path.dirname(path))
        grouped_files.setdefault(folder_name, []).append(path)

    return list(grouped_files.values())
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def generate_timestamp():
    """Return a unique-ish id string: 'yymmdd_HHMMSS_mmm_r' with r in 0..9999."""
    now = datetime.datetime.now()
    stamp = now.strftime("%y%m%d_%H%M%S")
    millis = f"{int(now.microsecond / 1000):03d}"
    return f"{stamp}_{millis}_{random.randint(0, 9999)}"
|
| 594 |
+
|
| 595 |
+
|
| 596 |
+
def write_PIL_image_with_png_info(image, metadata, path):
    """Save a PIL image as PNG with each metadata key/value stored as a text chunk.

    Returns the image for call chaining.
    """
    from PIL.PngImagePlugin import PngInfo

    png_info = PngInfo()
    for key, value in metadata.items():
        png_info.add_text(key, value)

    image.save(path, "PNG", pnginfo=png_info)
    return image
|
| 605 |
+
|
| 606 |
+
|
| 607 |
+
def torch_safe_save(content, path):
    """Atomically torch.save `content` to `path` via a temp file plus rename."""
    tmp = path + "_tmp"
    torch.save(content, tmp)
    os.replace(tmp, path)
    return path
|
| 611 |
+
|
| 612 |
+
|
| 613 |
+
def move_optimizer_to_device(optimizer, device):
    """Move every tensor in the optimizer's state (e.g. momentum buffers) to
    `device`, in place."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)
|
frame_pack/wrapper.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def append_dims(x, target_dims):
    """Right-pad x with singleton dims until it has `target_dims` dimensions."""
    missing = target_dims - x.ndim
    return x[(...,) + (None,) * missing]
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
    """Rescale the CFG output toward the text prediction's per-sample std.

    Implements the overexposure fix from 'Common Diffusion Noise Schedules and
    Sample Steps are Flawed' (sec. 3.4); guidance_rescale=0 is a no-op.
    """
    if guidance_rescale == 0:
        return noise_cfg

    # Per-sample std over all non-batch dimensions.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend between the raw CFG result and the std-matched one.
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def fm_wrapper(transformer, t_scale=1000.0):
    """Wrap a diffusion transformer as a k-diffusion-style denoiser.

    The returned k_model(x, sigma, **extra_args) runs positive (and, when
    cfg_scale != 1, negative) conditioned passes, applies classifier-free
    guidance with optional rescaling, and converts the flow-matching velocity
    prediction to an x0 estimate: x0 = x - pred * sigma.
    """
    def k_model(x, sigma, **extra_args):
        dtype = extra_args['dtype']
        cfg_scale = extra_args['cfg_scale']
        cfg_rescale = extra_args['cfg_rescale']
        concat_latent = extra_args['concat_latent']

        original_dtype = x.dtype
        sigma = sigma.float()

        x = x.to(dtype)
        # Transformer expects timesteps scaled to [0, t_scale] rather than [0, 1].
        timestep = (sigma * t_scale).to(dtype)

        if concat_latent is None:
            hidden_states = x
        else:
            # Channel-wise conditioning latents concatenated onto the input.
            hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)

        pred_positive = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive'])[0].float()

        if cfg_scale == 1.0:
            # Skip the unconditional pass; the zeros cancel in the CFG mix below.
            pred_negative = torch.zeros_like(pred_positive)
        else:
            pred_negative = transformer(hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative'])[0].float()

        # Classifier-free guidance, then optional std rescaling.
        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)

        # Flow matching: x0 = x - v * sigma (sigma broadcast over x's dims).
        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)

        return x0.to(dtype=original_dtype)

    return k_model
|
framepack_edit_output/framepack-edit-lora-000001.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c5a6478224e15dd49359bb791f4d1984d4f87b2b69e858784a32266d4a9b270c
|
| 3 |
+
size 275426304
|
framepack_edit_output/framepack-edit-lora-000002.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0b24eefda91054ca54f70c9d50eb2df47a1954c4ddf2f3f12078d67e8a97a767
|
| 3 |
+
size 275426304
|
framepack_edit_output/framepack-edit-lora-000003.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b9c0e9747f651655dd95dd13f1c2999662a48dc3e89c84537ec7dc88ec1b307f
|
| 3 |
+
size 275426304
|
framepack_edit_output/framepack-edit-lora-000004.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7456a21c9cbf4bcf4ddcf2e8aacff7d90dc96c1dbaa1f802bb32dbd9e38bbb9b
|
| 3 |
+
size 275426304
|
framepack_edit_output/framepack-edit-lora-000005.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e1d152090aafda957a8ab146ab183aed9b2ddeed70e9fc003163d59024f7e3d6
|
| 3 |
+
size 275426304
|
framepack_edit_output/framepack-edit-lora-000006.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:30bed80789e6ea6d3b5749e86b299889a9d3282758862a5cacf69c66f49c89a5
|
| 3 |
+
size 275426304
|
hunyuan_model/__init__.py
ADDED
|
File without changes
|
hunyuan_model/activation_layers.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_activation_layer(act_type):
    """Map an activation-type name to a zero-argument factory for the layer.

    Args:
        act_type (str): one of "gelu", "gelu_tanh", "relu", "silu".

    Returns:
        A callable that constructs the corresponding torch.nn activation module.

    Raises:
        ValueError: for an unknown activation name.
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    if act_type == "gelu_tanh":
        # Approximate `tanh` requires torch >= 1.13
        return lambda: nn.GELU(approximate="tanh")
    if act_type == "relu":
        return nn.ReLU
    if act_type == "silu":
        return nn.SiLU
    raise ValueError(f"Unknown activation type: {act_type}")
|
hunyuan_model/attention.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.metadata
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
try:
|
| 9 |
+
import flash_attn
|
| 10 |
+
from flash_attn.flash_attn_interface import _flash_attn_forward
|
| 11 |
+
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
| 12 |
+
from flash_attn.flash_attn_interface import flash_attn_func
|
| 13 |
+
except ImportError:
|
| 14 |
+
flash_attn = None
|
| 15 |
+
flash_attn_varlen_func = None
|
| 16 |
+
_flash_attn_forward = None
|
| 17 |
+
flash_attn_func = None
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
print(f"Trying to import sageattention")
|
| 21 |
+
from sageattention import sageattn_varlen, sageattn
|
| 22 |
+
|
| 23 |
+
print("Successfully imported sageattention")
|
| 24 |
+
except ImportError:
|
| 25 |
+
print(f"Failed to import sageattention")
|
| 26 |
+
sageattn_varlen = None
|
| 27 |
+
sageattn = None
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import xformers.ops as xops
|
| 31 |
+
except ImportError:
|
| 32 |
+
xops = None
|
| 33 |
+
|
| 34 |
+
# Per-backend (pre, post) layout adapters. The first callable is applied to
# q/k/v before the attention kernel runs, the second to the output afterwards.
# - "flash" / "sageattn": flatten (b, s, ...) -> (b*s, ...) for the varlen
#   kernels; output is left as-is (reshaped later by the caller).
# - "sageattn_fixlen" / "torch" / "vanilla": swap the seq and head dims, since
#   these backends expect (b, heads, seq, dim) rather than (b, seq, heads, dim).
# - "flash_fixlen" / "xformers": identity on both sides.
MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "flash_fixlen": (
        lambda x: x,
        lambda x: x,
    ),
    "sageattn": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "sageattn_fixlen": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "xformers": (
        lambda x: x,
        lambda x: x,
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_cu_seqlens(text_mask, img_len):
    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len.

    Args:
        text_mask (torch.Tensor): (batch, text_len) mask; nonzero entries are
            counted as valid text tokens.
        img_len (int): the length of image

    Returns:
        torch.Tensor: int32 tensor of shape [2 * batch + 1] — the cumulative
        sequence lengths for flash attention, on the same device as text_mask.
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    # Fix: the device was hard-coded to "cuda", which broke CPU runs and
    # non-default CUDA devices; follow the mask's device instead.
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)

    for i in range(batch_size):
        s = text_len[i] + img_len
        s1 = i * max_len + s      # end of sample i's valid (img + text) tokens
        s2 = (i + 1) * max_len    # end of sample i's padded slot
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def attention(
    q_or_qkv_list,
    k=None,
    v=None,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    total_len=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self attention.

    Args:
        q_or_qkv_list: Query tensor with shape [b, s, a, d] (a = number of heads),
            or a list [q, k, v]. A list is cleared in place after unpacking so the
            caller's references are released as early as possible.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention backend: 'flash', 'torch', 'vanilla', 'xformers' or 'sageattn'.
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        total_len (torch.Tensor): Per-sample valid sequence lengths. When given, each
            sample is trimmed to its valid length and attended individually (the
            "split" path); outputs are re-padded to the common length afterwards.
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q.
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv.
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
        batch_size (int): Batch size used to un-flatten the varlen kernel output.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, ad]
    """
    # isinstance instead of `type(...) == list` (idiomatic; also accepts subclasses).
    if isinstance(q_or_qkv_list, list):
        q, k, v = q_or_qkv_list
        q_or_qkv_list.clear()  # drop the caller's references so tensors can be freed
    else:
        q = q_or_qkv_list
    split_attn = total_len is not None
    # The varlen kernels are not used on the per-sample path; switch to the
    # fixed-length variants of the same backend.
    if split_attn and mode == "sageattn":
        mode = "sageattn_fixlen"
    elif split_attn and mode == "flash":
        mode = "flash_fixlen"
    # print(f"Attention mode: {mode}, split_attn: {split_attn}")
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # trim the sequence length to the actual length instead of attn_mask
    if split_attn:
        trimmed_len = q.shape[1] - total_len
        q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
        k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
        v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
        q = [pre_attn_layout(q_i) for q_i in q]
        k = [pre_attn_layout(k_i) for k_i in k]
        v = [pre_attn_layout(v_i) for v_i in v]
    else:
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if mode == "torch":
        if split_attn:
            x = []
            for i in range(len(q)):
                x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
                q[i], k[i], v[i] = None, None, None
                x.append(x_i)
            del q, k, v
        else:
            # SDPA requires a floating-point mask to match the query dtype.
            if attn_mask is not None and attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q.dtype)
            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
            del q, k, v
        del attn_mask

    elif mode == "xformers":
        # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
        # currently only support batch_size = 1
        assert split_attn, "Xformers only supports splitting"
        x = []
        for i in range(len(q)):
            x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)  # , causal=causal)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v

    elif mode == "flash":
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "flash_fixlen":
        x = []
        for i in range(len(q)):
            # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
            x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v
    elif mode == "sageattn":
        x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "sageattn_fixlen":
        x = []
        for i in range(len(q)):
            # HND seems to cause an error
            x_i = sageattn(q[i], k[i], v[i])  # (batch_size, seq_len, head_num, head_dim)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v
    elif mode == "vanilla":
        assert not split_attn, "Vanilla attention does not support trimming"
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            # Fix: the original called `attn_bias.to(q.dtype)` and discarded the
            # result (Tensor.to is not in-place); keep the converted tensor.
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    if split_attn:
        x = [post_attn_layout(x_i) for x_i in x]
        for i in range(len(x)):
            # Re-pad each sample back to the common (padded) sequence length.
            x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
        x = torch.cat(x, dim=0)
    else:
        x = post_attn_layout(x)

    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
    """Hybrid sequence-parallel attention.

    Runs the distributed attention over the image tokens (with the first text
    segment appended, joint_strategy="rear") and plain flash attention over the
    remaining tokens, then concatenates the two along the sequence dimension.

    Args:
        hybrid_seq_parallel_attn: callable implementing the sequence-parallel
            attention (signature inferred from this call site — confirm against
            the provider).
        q, k, v (torch.Tensor): [b, s, a, d] projections.
        img_q_len, img_kv_len (int): number of image tokens in q and in k/v.
        cu_seqlens_q, cu_seqlens_kv: cumulative lengths; index 1 marks the end
            of the first (image + text) segment.

    Returns:
        torch.Tensor: attention output reshaped to [b, s, a*d].
    """
    attn1 = hybrid_seq_parallel_attn(
        None,
        q[:, :img_q_len, :, :],
        k[:, :img_kv_len, :, :],
        v[:, :img_kv_len, :, :],
        dropout_p=0.0,
        causal=False,
        joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
        joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
        joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
        joint_strategy="rear",
    )
    # flash-attn >= 2.7.0 split `window_size` into `window_size_left/right`.
    # Fix: compare the version numerically — the original lexicographic string
    # comparison (`__version__ >= "2.7.0"`) misorders e.g. "2.10.0" < "2.7.0".
    flash_ver = tuple(int(p) for p in flash_attn.__version__.split(".")[:3] if p.isdigit())
    if flash_ver >= (2, 7, 0):
        attn2, *_ = _flash_attn_forward(
            q[:, cu_seqlens_q[1] :],
            k[:, cu_seqlens_kv[1] :],
            v[:, cu_seqlens_kv[1] :],
            dropout_p=0.0,
            softmax_scale=q.shape[-1] ** (-0.5),
            causal=False,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
    else:
        attn2, *_ = _flash_attn_forward(
            q[:, cu_seqlens_q[1] :],
            k[:, cu_seqlens_kv[1] :],
            v[:, cu_seqlens_kv[1] :],
            dropout_p=0.0,
            softmax_scale=q.shape[-1] ** (-0.5),
            causal=False,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
    attn = torch.cat([attn1, attn2], dim=1)
    b, s, a, d = attn.shape
    attn = attn.reshape(b, s, -1)

    return attn
hunyuan_model/autoencoder_kl_causal_3d.py
ADDED
|
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
#
|
| 16 |
+
# Modified from diffusers==0.29.2
|
| 17 |
+
#
|
| 18 |
+
# ==============================================================================
|
| 19 |
+
from typing import Dict, Optional, Tuple, Union
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
|
| 25 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
|
| 27 |
+
# try:
|
| 28 |
+
# # This diffusers is modified and packed in the mirror.
|
| 29 |
+
# from diffusers.loaders import FromOriginalVAEMixin
|
| 30 |
+
# except ImportError:
|
| 31 |
+
# # Use this to be compatible with the original diffusers.
|
| 32 |
+
# from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
|
| 33 |
+
from diffusers.utils.accelerate_utils import apply_forward_hook
|
| 34 |
+
from diffusers.models.attention_processor import (
|
| 35 |
+
ADDED_KV_ATTENTION_PROCESSORS,
|
| 36 |
+
CROSS_ATTENTION_PROCESSORS,
|
| 37 |
+
Attention,
|
| 38 |
+
AttentionProcessor,
|
| 39 |
+
AttnAddedKVProcessor,
|
| 40 |
+
AttnProcessor,
|
| 41 |
+
)
|
| 42 |
+
from diffusers.models.modeling_outputs import AutoencoderKLOutput
|
| 43 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 44 |
+
from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class DecoderOutput2(BaseOutput):
    """Decoder output that can additionally carry the encoder posterior."""

    # Decoded images/videos produced by the VAE decoder.
    sample: torch.FloatTensor
    # Latent distribution of the input; presumably populated only by call paths
    # that also return the posterior (those paths are outside this view — confirm).
    posterior: Optional[DiagonalGaussianDistribution] = None
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
|
| 54 |
+
r"""
|
| 55 |
+
A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
|
| 56 |
+
|
| 57 |
+
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
|
| 58 |
+
for all models (such as downloading or saving).
|
| 59 |
+
"""
|
| 60 |
+
|
| 61 |
+
_supports_gradient_checkpointing = True
|
| 62 |
+
|
| 63 |
+
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
        up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        sample_tsize: int = 64,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
        spatial_compression_ratio: int = 8,
        time_compression_ratio: int = 4,
        mid_block_add_attention: bool = True,
    ):
        """Build the causal-3D VAE: encoder, decoder, and the 1x1x1 quant convs.

        Slicing and spatial/temporal tiling are disabled by default; use the
        enable_*/disable_* methods to toggle them. Arguments are captured into
        `self.config` by the `@register_to_config` decorator.
        """
        super().__init__()

        self.time_compression_ratio = time_compression_ratio

        # Encoder emits 2 * latent_channels channels (double_z=True): the
        # moments consumed by DiagonalGaussianDistribution in `encode`.
        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = DecoderCausal3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        # 1x1x1 convs applied to the encoder moments and to the decoder input.
        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.use_slicing = False
        self.use_spatial_tiling = False
        self.use_temporal_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_tsize = sample_tsize
        self.tile_latent_min_tsize = sample_tsize // time_compression_ratio

        self.tile_sample_min_size = self.config.sample_size
        sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
        # Latent tile size follows the spatial downsampling implied by the
        # number of block_out_channels stages (factor 2 per stage).
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25
| 129 |
+
|
| 130 |
+
    def _set_gradient_checkpointing(self, module, value=False):
        # Diffusers hook: only the encoder/decoder submodules support
        # gradient checkpointing; other modules are left untouched.
        if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
            module.gradient_checkpointing = value
| 133 |
+
|
| 134 |
+
    def enable_temporal_tiling(self, use_tiling: bool = True):
        # Toggle tiled encode/decode along the frame (time) axis;
        # consumed by the `encode`/`_decode` dispatch checks.
        self.use_temporal_tiling = use_tiling
| 136 |
+
|
| 137 |
+
    def disable_temporal_tiling(self):
        # Convenience inverse of enable_temporal_tiling.
        self.enable_temporal_tiling(False)
| 139 |
+
|
| 140 |
+
    def enable_spatial_tiling(self, use_tiling: bool = True):
        # Toggle tiled encode/decode along the height/width axes;
        # consumed by the `encode`/`_decode` dispatch checks.
        self.use_spatial_tiling = use_tiling
| 142 |
+
|
| 143 |
+
    def disable_spatial_tiling(self):
        # Convenience inverse of enable_spatial_tiling.
        self.enable_spatial_tiling(False)
| 145 |
+
|
| 146 |
+
    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger videos.
        """
        # Flips both the spatial and the temporal flags together.
        self.enable_spatial_tiling(use_tiling)
        self.enable_temporal_tiling(use_tiling)
| 154 |
+
|
| 155 |
+
    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        # Clears both tiling flags.
        self.disable_spatial_tiling()
        self.disable_temporal_tiling()
| 162 |
+
|
| 163 |
+
    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        # Consumed by `encode`/`decode`: batch elements are processed one at a time.
        self.use_slicing = True
| 169 |
+
|
| 170 |
+
    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False
| 176 |
+
|
| 177 |
+
def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
|
| 178 |
+
# set chunk_size to CausalConv3d recursively
|
| 179 |
+
def set_chunk_size(module):
|
| 180 |
+
if hasattr(module, "chunk_size"):
|
| 181 |
+
module.chunk_size = chunk_size
|
| 182 |
+
|
| 183 |
+
self.apply(set_chunk_size)
|
| 184 |
+
|
| 185 |
+
    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Any module exposing `get_processor` counts as an attention layer;
            # its processor is recorded under a dotted weight-style key.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
| 209 |
+
|
| 210 |
+
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        # A dict must cover every attention layer exactly once.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    # Single processor: shared by every attention layer.
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    # Dict: pop the entry keyed by this layer's dotted path.
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
| 244 |
+
|
| 245 |
+
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # Pick the default processor type matching the processors currently in
        # use; mixed processor families cannot be reset automatically.
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)
| 260 |
+
|
| 261 |
+
    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images/videos into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos. Must be 5D; from the
                tiling checks below, dim 2 is time and the last two dims are spatial
                (confirm exact axis semantics against callers).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images/videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        assert len(x.shape) == 5, "The input tensor should have 5 dimensions."

        # Dispatch order: temporal tiling first, then spatial tiling; each tiled
        # path returns its own result directly.
        if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
            return self.temporal_tiled_encode(x, return_dict=return_dict)

        if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.spatial_tiled_encode(x, return_dict=return_dict)

        # Slicing: encode one batch element at a time to bound peak memory.
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        # quant_conv maps the encoder output to the Gaussian moments.
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)
| 298 |
+
|
| 299 |
+
def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    """Decode a single latent batch, dispatching to the tiled decoders when enabled.

    Called by `decode`, which handles optional batch slicing on top of this.
    """
    # Latents must be 5-D; indexing below treats it as (batch, channels, frames, height, width).
    assert len(z.shape) == 5, "The input tensor should have 5 dimensions."

    # Temporal tiling is checked first, mirroring the dispatch order in `encode`.
    if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
        return self.temporal_tiled_decode(z, return_dict=return_dict)

    if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
        return self.spatial_tiled_decode(z, return_dict=return_dict)

    z = self.post_quant_conv(z)
    dec = self.decoder(z)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
|
| 315 |
+
|
| 316 |
+
@apply_forward_hook
def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
    """
    Decode a batch of images/videos.

    Args:
        z (`torch.FloatTensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
        generator (*optional*):
            Accepted for API compatibility with diffusers' `decode`; not used by this method.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.

    """
    # Optional slicing: decode one latent at a time to reduce peak memory,
    # then re-concatenate along the batch dimension.
    if self.use_slicing and z.shape[0] > 1:
        decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
        decoded = torch.cat(decoded_slices)
    else:
        decoded = self._decode(z).sample

    if not return_dict:
        return (decoded,)

    return DecoderOutput(sample=decoded)
|
| 342 |
+
|
| 343 |
+
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the bottom rows of `a` into the top rows of `b` along dim -2.

    Mutates `b` in place and returns it. The fade is linear: row 0 of `b` is
    entirely from `a`, and the weight of `b` grows with the row index.
    """
    extent = min(a.shape[-2], b.shape[-2], blend_extent)
    for row in range(extent):
        weight = row / extent
        b[:, :, :, row, :] = a[:, :, :, row - extent, :] * (1 - weight) + b[:, :, :, row, :] * weight
    return b
|
| 348 |
+
|
| 349 |
+
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the rightmost columns of `a` into the leftmost columns of `b` along dim -1.

    Mutates `b` in place and returns it; the blend weight ramps linearly from
    all-`a` at column 0 to mostly-`b` at column `blend_extent - 1`.
    """
    extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for col in range(extent):
        weight = col / extent
        b[:, :, :, :, col] = a[:, :, :, :, col - extent] * (1 - weight) + b[:, :, :, :, col] * weight
    return b
|
| 354 |
+
|
| 355 |
+
def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    """Cross-fade the trailing frames of `a` into the leading frames of `b` along dim 2.

    Mutates `b` in place and returns it; same linear ramp as `blend_v`/`blend_h`
    but applied on the temporal axis.
    """
    extent = min(a.shape[-3], b.shape[-3], blend_extent)
    for frame in range(extent):
        weight = frame / extent
        b[:, :, frame, :, :] = a[:, :, frame - extent, :, :] * (1 - weight) + b[:, :, frame, :, :] * weight
    return b
|
| 360 |
+
|
| 361 |
+
def spatial_tiled_encode(
    self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
) -> AutoencoderKLOutput:
    r"""Encode a batch of images/videos using a tiled encoder.

    When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
    steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
    different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
    tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
    output, but they should be much less noticeable.

    Args:
        x (`torch.FloatTensor`): Input batch of images/videos.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
        return_moments (`bool`, *optional*, defaults to `False`):
            If True, return the raw concatenated moments tensor instead of a
            distribution (used internally by `temporal_tiled_encode`).

    Returns:
        [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
            If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
            `tuple` is returned.
    """
    # Tiles are taken `overlap_size` apart in pixel space, so consecutive tiles
    # overlap; the overlap (in latent space) is blended away below.
    overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
    # Width/height (in latent units) each tile contributes to the final output.
    row_limit = self.tile_latent_min_size - blend_extent

    # Split video into tiles and encode them separately.
    rows = []
    for i in range(0, x.shape[-2], overlap_size):
        row = []
        for j in range(0, x.shape[-1], overlap_size):
            tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
            tile = self.encoder(tile)
            tile = self.quant_conv(tile)
            row.append(tile)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            # Crop each blended tile to its non-overlapping contribution.
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))

    moments = torch.cat(result_rows, dim=-2)
    if return_moments:
        return moments

    posterior = DiagonalGaussianDistribution(moments)
    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)
|
| 418 |
+
|
| 419 |
+
def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    r"""
    Decode a batch of images/videos using a tiled decoder.

    Args:
        z (`torch.FloatTensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # Mirror image of spatial_tiled_encode: tiles are taken in latent space,
    # decoded, and the overlap is blended away in pixel space.
    overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
    # Width/height (in pixels) each decoded tile contributes to the output.
    row_limit = self.tile_sample_min_size - blend_extent

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, z.shape[-2], overlap_size):
        row = []
        for j in range(0, z.shape[-1], overlap_size):
            tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
            tile = self.post_quant_conv(tile)
            decoded = self.decoder(tile)
            row.append(decoded)
        rows.append(row)
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # blend the above tile and the left tile
            # to the current tile and add the current tile to the result row
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent)
            # Crop each blended tile to its non-overlapping contribution.
            result_row.append(tile[:, :, :, :row_limit, :row_limit])
        result_rows.append(torch.cat(result_row, dim=-1))

    dec = torch.cat(result_rows, dim=-2)
    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
|
| 466 |
+
|
| 467 |
+
def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
    """Encode with tiling along the temporal axis; overlapping tiles are blended.

    Falls back to `spatial_tiled_encode` per temporal tile when the spatial
    extent also exceeds the tile size.
    """
    B, C, T, H, W = x.shape
    overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
    # Number of latent frames each tile contributes to the final output.
    t_limit = self.tile_latent_min_tsize - blend_extent

    # Split the video into tiles and encode them separately.
    row = []
    for i in range(0, T, overlap_size):
        # +1 frame of context per tile; presumably for the causal temporal
        # convolution boundary -- TODO confirm against the encoder definition.
        tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
        if self.use_spatial_tiling and (
            tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
        ):
            # return_moments=True gives the raw moments so we can blend and
            # build a single distribution over the full sequence below.
            tile = self.spatial_tiled_encode(tile, return_moments=True)
        else:
            tile = self.encoder(tile)
            tile = self.quant_conv(tile)
        if i > 0:
            # Drop the duplicated leading latent frame on all but the first tile.
            tile = tile[:, :, 1:, :, :]
        row.append(tile)
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            tile = self.blend_t(row[i - 1], tile, blend_extent)
            result_row.append(tile[:, :, :t_limit, :, :])
        else:
            # The first tile keeps one extra frame (matches the +1 context above).
            result_row.append(tile[:, :, : t_limit + 1, :, :])

    moments = torch.cat(result_row, dim=2)
    posterior = DiagonalGaussianDistribution(moments)

    if not return_dict:
        return (posterior,)

    return AutoencoderKLOutput(latent_dist=posterior)
|
| 503 |
+
|
| 504 |
+
def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
    """Decode with tiling along the temporal axis; overlapping tiles are blended.

    Falls back to `spatial_tiled_decode` per temporal tile when the spatial
    extent also exceeds the latent tile size.
    """
    # Split z into overlapping tiles and decode them separately.

    B, C, T, H, W = z.shape
    overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
    blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
    # Number of pixel frames each decoded tile contributes to the final output.
    t_limit = self.tile_sample_min_tsize - blend_extent

    row = []
    for i in range(0, T, overlap_size):
        # +1 latent frame of context per tile; presumably for the causal
        # temporal convolution boundary -- TODO confirm against the decoder.
        tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
        if self.use_spatial_tiling and (
            tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
        ):
            decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
        else:
            tile = self.post_quant_conv(tile)
            decoded = self.decoder(tile)
        if i > 0:
            # Drop the duplicated leading frame on all but the first tile.
            decoded = decoded[:, :, 1:, :, :]
        row.append(decoded)
    result_row = []
    for i, tile in enumerate(row):
        if i > 0:
            tile = self.blend_t(row[i - 1], tile, blend_extent)
            result_row.append(tile[:, :, :t_limit, :, :])
        else:
            # The first tile keeps one extra frame (matches the +1 context above).
            result_row.append(tile[:, :, : t_limit + 1, :, :])

    dec = torch.cat(result_row, dim=2)
    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
|
| 538 |
+
|
| 539 |
+
def forward(
    self,
    sample: torch.FloatTensor,
    sample_posterior: bool = False,
    return_dict: bool = True,
    return_posterior: bool = False,
    generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput2, torch.FloatTensor]:
    r"""
    Encode `sample` and immediately decode the latent (autoencoder round-trip).

    Args:
        sample (`torch.FloatTensor`): Input sample.
        sample_posterior (`bool`, *optional*, defaults to `False`):
            Whether to sample from the posterior.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        return_posterior (`bool`, *optional*, defaults to `False`):
            Whether to also return the encoder's latent distribution.
        generator (`torch.Generator`, *optional*):
            RNG used when sampling from the posterior.
    """
    x = sample
    posterior = self.encode(x).latent_dist
    if sample_posterior:
        z = posterior.sample(generator=generator)
    else:
        # Deterministic path: use the distribution's mode instead of sampling.
        z = posterior.mode()
    dec = self.decode(z).sample

    if not return_dict:
        if return_posterior:
            return (dec, posterior)
        else:
            return (dec,)
    if return_posterior:
        return DecoderOutput2(sample=dec, posterior=posterior)
    else:
        return DecoderOutput2(sample=dec)
|
| 572 |
+
|
| 573 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
    key, value) are fused. For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    # Reset first so a failure below leaves the attribute in a known state.
    self.original_attn_processors = None

    for _, attn_processor in self.attn_processors.items():
        # "Added KV" processors carry extra projections that fusing would break.
        if "Added" in str(attn_processor.__class__.__name__):
            raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    # Keep the unfused processors so `unfuse_qkv_projections` can restore them.
    self.original_attn_processors = self.attn_processors

    for module in self.modules():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)
|
| 596 |
+
|
| 597 |
+
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
    """Disables the fused QKV projection if enabled.

    Restores the attention processors saved by `fuse_qkv_projections`; a no-op
    when fusion was never enabled.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>

    """
    # Guard clause: nothing to restore if fusion was never applied.
    if self.original_attn_processors is None:
        return
    self.set_attn_processor(self.original_attn_processors)
|
hunyuan_model/embed_layers.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
|
| 7 |
+
from .helpers import to_2tuple
|
| 8 |
+
|
| 9 |
+
class PatchEmbed(nn.Module):
    """Video to Patch Embedding using a strided Conv3d.

    A convolution based approach to patchifying an input with an embedding
    projection.

    Based on the impl in https://github.com/google-research/vision_transformer

    Hacked together by / Copyright 2020 Ross Wightman

    Remove the _assert function in forward function to be compatible with multi-resolution images.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        # Normalize scalar patch sizes into a tuple (tuples pass through unchanged).
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        # kernel_size == stride: non-overlapping patches, one embedding per patch.
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs
        )
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            # Collapse every patch axis into a single token axis: (B, C, ...) -> (B, N, C).
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        # Two-layer projection: in_channels -> hidden_size -> hidden_size.
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        # linear -> activation -> linear
        return self.linear_2(self.act_1(self.linear_1(caption)))
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponent = -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=t.device)
    angles = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Pad one zero column so the output width matches an odd `dim`.
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    return embedding
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        # Default the output width to the hidden width.
        out_size = hidden_size if out_size is None else out_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        # DiT-style small-normal init for both linear layers.
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    def forward(self, t):
        # Sinusoidal features, cast to the MLP's parameter dtype before projecting.
        freqs = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        return self.mlp(freqs.type(self.mlp[0].weight.dtype))
|
hunyuan_model/fp8_optimization.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#based on ComfyUI's and MinusZoneAI's fp8_linear optimization
|
| 2 |
+
#further borrowed from HunyuanVideoWrapper for Musubi Tuner
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
def fp8_linear_forward(cls, original_dtype, input):
    """Replacement forward for a patched `nn.Linear` (`cls` is the Linear instance).

    When the weight is stored in fp8 and the input is 3-D, multiply via
    `torch._scaled_mm` with unit scales and return the result in
    `original_dtype`; otherwise fall back to the saved `original_forward`.
    """
    weight_dtype = cls.weight.dtype
    if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
        if len(input.shape) == 3:
            # Cast the input to the fp8 format the weight does NOT use
            # (e4m3 pairs with e5m2) -- presumably a _scaled_mm requirement;
            # TODO confirm against the installed PyTorch version.
            target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
            # Flatten (batch, seq, features) -> (batch*seq, features) for the matmul.
            inn = input.reshape(-1, input.shape[2]).to(target_dtype)
            w = cls.weight.t()

            # Unit scales: assumes the fp8 weights are already pre-scaled -- TODO confirm.
            scale = torch.ones((1), device=input.device, dtype=torch.float32)
            bias = cls.bias.to(original_dtype) if cls.bias is not None else None

            if bias is not None:
                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
            else:
                o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)

            # Some PyTorch versions return (output, amax) from _scaled_mm.
            if isinstance(o, tuple):
                o = o[0]

            # Restore the (batch, seq, out_features) shape.
            return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
        else:
            # Non-3D input: use the unpatched forward, upcast to original_dtype.
            return cls.original_forward(input.to(original_dtype))
    else:
        # Weight not fp8 (e.g. layer excluded via params_to_keep): unpatched path.
        return cls.original_forward(input)
|
| 30 |
+
|
| 31 |
+
def convert_fp8_linear(module, original_dtype, params_to_keep=()):
    """Patch every `nn.Linear` inside `module` to route through `fp8_linear_forward`.

    Args:
        module: Root module whose submodules are scanned and patched in place.
        original_dtype: dtype the fp8 matmul should produce its outputs in.
        params_to_keep: Collection of name substrings; any submodule whose
            qualified name contains one of them is left unpatched.
            (Default changed from a mutable ``{}`` to an immutable ``()``;
            behavior is identical since it is only iterated.)

    Side effects:
        Sets ``module.fp8_matmul_enabled = True`` and, on each patched Linear,
        saves the original forward as ``original_forward`` before replacing it.
    """
    setattr(module, "fp8_matmul_enabled", True)

    # BUGFIX: use a distinct loop variable; the original shadowed the `module`
    # parameter inside this loop, losing the reference to the root module.
    for name, submodule in module.named_modules():
        if any(keyword in name for keyword in params_to_keep):
            continue
        if isinstance(submodule, nn.Linear):
            setattr(submodule, "original_forward", submodule.forward)
            # Bind the submodule via a default argument so every lambda keeps
            # its own layer (avoids the late-binding closure pitfall).
            setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
|
hunyuan_model/helpers.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections.abc
|
| 2 |
+
|
| 3 |
+
from itertools import repeat
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def _ntuple(n):
    """Return a parser that coerces its argument into an n-tuple.

    Non-string iterables are converted to a tuple (a 1-element iterable is
    broadcast to length n); scalars and strings are repeated n times.
    """
    def parse(value):
        if not isinstance(value, collections.abc.Iterable) or isinstance(value, str):
            return (value,) * n
        items = tuple(value)
        if len(items) == 1:
            return (items[0],) * n
        return items

    return parse


# Common fixed-arity coercers.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def as_tuple(x):
    """Coerce `x` into a tuple.

    Non-string iterables become tuples of their elements; `None`, numbers, and
    strings become 1-tuples; anything else raises ValueError.
    """
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    if x is None or isinstance(x, (int, float, str)):
        return (x,)
    raise ValueError(f"Unknown type {type(x)}")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def as_list_of_2tuple(x):
    """Coerce `x` into a list of 2-tuples by pairing consecutive elements.

    A scalar (or 1-element input) is duplicated into a single pair; an even-length
    sequence is chunked pairwise; odd lengths > 1 trip the assertion.
    """
    # Inlined tuple coercion (same rules as `as_tuple`).
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        values = tuple(x)
    elif x is None or isinstance(x, (int, float, str)):
        values = (x,)
    else:
        raise ValueError(f"Unknown type {type(x)}")

    if len(values) == 1:
        values = (values[0], values[0])
    assert len(values) % 2 == 0, f"Expect even length, got {len(values)}."
    return [(values[i], values[i + 1]) for i in range(0, len(values), 2)]
|
hunyuan_model/mlp_layers.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Modified from timm library:
|
| 2 |
+
# https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
|
| 3 |
+
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
|
| 9 |
+
from .modulate_layers import modulate
|
| 10 |
+
from .helpers import to_2tuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        # Default hidden/output widths to the input width.
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        # `bias` and `drop` may be scalars or per-layer (fc1, fc2) pairs.
        bias_pair = to_2tuple(bias)
        drop_pair = to_2tuple(drop)
        # A 1x1 conv acts as a per-position linear layer when use_conv is set.
        layer_cls = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = layer_cls(in_channels, hidden_channels, bias=bias_pair[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_pair[0])
        self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
        self.fc2 = layer_cls(hidden_channels, out_features, bias=bias_pair[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_pair[1])

    def forward(self, x):
        """fc1 -> activation -> dropout -> norm -> fc2 -> dropout."""
        hidden = self.drop1(self.act(self.fc1(x)))
        return self.drop2(self.fc2(self.norm(hidden)))
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
#
|
| 63 |
+
class MLPEmbedder(nn.Module):
    """Two-layer SiLU MLP mapping `in_dim` features to `hidden_dim`.

    copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
    """

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class FinalLayer(nn.Module):
    """The final layer of DiT.

    Applies an (affine-free) LayerNorm modulated by shift/scale derived from the
    conditioning vector, then a zero-initialized linear projection to patch
    pixels, so the layer starts as a no-op contribution at init.
    """

    def __init__(self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Just use LayerNorm for the final layer
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        # Output width is the product of the (possibly 3-D) patch dims and channels.
        if isinstance(patch_size, int):
            out_dim = patch_size * patch_size * out_channels
        else:
            out_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
        # BUGFIX: the tuple-patch-size branch previously omitted **factory_kwargs,
        # so the projection ignored the requested device/dtype. Both branches now
        # share one construction that honors them.
        self.linear = nn.Linear(hidden_size, out_dim, bias=True, **factory_kwargs)
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Here we don't distinguish between the modulate types. Just use the simple one.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation so shift/scale start at zero.
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        """Modulate normalized `x` with shift/scale from conditioning `c`, then project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        x = self.linear(x)
        return x
|
hunyuan_model/models.py
ADDED
|
@@ -0,0 +1,1044 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, List, Tuple, Optional, Union, Dict
|
| 3 |
+
import accelerate
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
|
| 10 |
+
from .activation_layers import get_activation_layer
|
| 11 |
+
from .norm_layers import get_norm_layer
|
| 12 |
+
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
|
| 13 |
+
from .attention import attention, parallel_attention, get_cu_seqlens
|
| 14 |
+
from .posemb_layers import apply_rotary_emb
|
| 15 |
+
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
|
| 16 |
+
from .modulate_layers import ModulateDiT, modulate, apply_gate
|
| 17 |
+
from .token_refiner import SingleTokenRefiner
|
| 18 |
+
from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
|
| 19 |
+
from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
|
| 20 |
+
|
| 21 |
+
from utils.safetensors_utils import MemoryEfficientSafeOpen
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for
    text and image/video; see (SD3): https://arxiv.org/abs/2403.03206
    and (Flux.1): https://github.com/black-forest-labs/flux

    The image/video and text token streams each have their own modulation,
    QKV projection, and MLP, but share one joint attention over the
    concatenated [img; txt] sequence.

    Parameters
    ----------
    hidden_size : int
        Transformer width; must be divisible by ``heads_num``.
    heads_num : int
        Number of attention heads.
    mlp_width_ratio : float
        MLP hidden width as a multiple of ``hidden_size``.
    mlp_act_type : str
        Activation for both MLPs (resolved via ``get_activation_layer``).
    qk_norm : bool
        Whether to apply per-head normalization to Q and K.
    qk_norm_type : str
        Norm type for QK-norm (resolved via ``get_norm_layer``).
    qkv_bias : bool
        Whether the QKV and output projections carry a bias.
    attn_mode : str
        Attention backend passed through to ``attention`` (default "flash").
    split_attn : bool
        Whether attention is run as batch size 1 (consumed by the backend).
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attn_mode: str = "flash",
        split_attn: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        # --- image/video stream ---
        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,  # shift/scale/gate for attention and for MLP (2 x 3)
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # --- text stream (mirrors the image stream) ---
        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )
        # Set externally when sequence-parallel attention is enabled.
        self.hybrid_seq_parallel_attn = None

        self.gradient_checkpointing = False

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        total_len: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: tuple = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one double-stream block and return the updated (img, txt) tokens.

        ``vec`` is the conditioning vector that drives the per-stream
        modulation; ``freqs_cis``, when given, applies RoPE to the image
        tokens only.  Intermediate tensors are dropped (set to None) as soon
        as they are consumed to keep peak memory down.
        """
        (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
            6, dim=-1
        )
        (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
            6, dim=-1
        )

        # Prepare image for attention.
        img_modulated = self.img_norm1(img)
        img_modulated = modulate(img_modulated, shift=img_mod1_shift, scale=img_mod1_scale)
        img_qkv = self.img_attn_qkv(img_modulated)
        img_modulated = None
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        img_qkv = None
        # Apply QK-Norm if needed (cast back to V's dtype after the norm).
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # Apply RoPE if needed.
        if freqs_cis is not None:
            img_q_shape = img_q.shape
            img_k_shape = img_k.shape
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            # FIX: the previous message printed img_q.shape twice (labeled
            # "img_kk"/"img_q") and img_k.shape twice, making the diagnostic
            # useless; report actual vs. expected shapes instead.
            assert (
                img_q.shape == img_q_shape and img_k.shape == img_k_shape
            ), f"img_q: {img_q.shape} (expected {img_q_shape}), img_k: {img_k.shape} (expected {img_k_shape})"

        # Prepare txt for attention.
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = modulate(txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_modulated = None
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        txt_qkv = None
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Run joint attention over the concatenated [img; txt] sequence.
        img_q_len = img_q.shape[1]
        img_kv_len = img_k.shape[1]
        batch_size = img_k.shape[0]
        q = torch.cat((img_q, txt_q), dim=1)
        img_q = txt_q = None
        k = torch.cat((img_k, txt_k), dim=1)
        img_k = txt_k = None
        v = torch.cat((img_v, txt_v), dim=1)
        img_v = txt_v = None

        # Two cumulative-length entries per sample (img and txt) plus the leading 0.
        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            l = [q, k, v]
            q = k = v = None
            attn = attention(
                l,
                mode=self.attn_mode,
                attn_mask=attn_mask,
                total_len=total_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=batch_size,
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q_len,
                img_kv_len=img_kv_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )
        # attention computation end

        # Split the joint attention output back into the two streams.
        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        attn = None

        # Calculate the img blocks (gated residual: attention, then MLP).
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img_attn = None
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Calculate the txt blocks.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt_attn = None
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt

    def forward(self, *args, **kwargs):
        """Dispatch to ``_forward``, optionally under activation checkpointing."""
        if self.training and self.gradient_checkpointing:
            # Non-reentrant checkpointing trades recompute for activation memory.
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
class MMSingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    Also refer to (SD3): https://arxiv.org/abs/2403.03206
    and (Flux.1): https://github.com/black-forest-labs/flux

    A single fused ``linear1`` produces QKV and the MLP input in one matmul;
    ``linear2`` consumes the concatenated attention output and activated MLP
    stream.  Text tokens sit at the tail of the sequence (last ``txt_len``
    positions).

    Parameters
    ----------
    hidden_size : int
        Transformer width; must be divisible by ``heads_num``.
    heads_num : int
        Number of attention heads.
    mlp_width_ratio : float
        MLP hidden width as a multiple of ``hidden_size``.
    mlp_act_type : str
        Activation for the MLP stream.
    qk_norm : bool
        Whether to apply per-head normalization to Q and K.
    qk_norm_type : str
        Norm type for QK-norm (resolved via ``get_norm_layer``).
    qk_scale : float, optional
        Attention scale; defaults to ``head_dim ** -0.5``.
    attn_mode : str
        Attention backend passed through to ``attention`` (default "flash").
    split_attn : bool
        Whether attention is run as batch size 1 (consumed by the backend).
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qk_scale: float = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attn_mode: str = "flash",
        split_attn: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        self.deterministic = False
        self.hidden_size = hidden_size
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        self.mlp_hidden_dim = mlp_hidden_dim
        self.scale = qk_scale or head_dim**-0.5

        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)

        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()

        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

        self.mlp_act = get_activation_layer(mlp_act_type)()
        self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
        # Set externally when sequence-parallel attention is enabled.
        self.hybrid_seq_parallel_attn = None

        self.gradient_checkpointing = False

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        attn_mask: Optional[torch.Tensor] = None,
        total_len: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run one single-stream block over ``x`` (img tokens + trailing
        ``txt_len`` text tokens) and return the gated residual output."""
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        # One fused projection yields QKV and the MLP stream.
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        x_mod = None

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        qkv = None

        # Apply QK-Norm if needed (cast back to V's dtype after the norm).
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply RoPE to the image tokens only; text tokens are the tail.
        if freqs_cis is not None:
            img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
            img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
            q = k = None
            img_q_shape = img_q.shape
            img_k_shape = img_k.shape
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            # FIX: the previous message printed img_q.shape twice and
            # img_k.shape twice; report actual vs. expected shapes instead.
            assert (
                img_q.shape == img_q_shape and img_k.shape == img_k_shape
            ), f"img_q: {img_q.shape} (expected {img_q_shape}), img_k: {img_k.shape} (expected {img_k_shape})"
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)
            del img_q, txt_q, img_k, txt_k

        # Compute attention.  Two cumulative-length entries per sample plus the leading 0.
        assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"

        # BUGFIX: derive the image-token lengths from q/k here, before the
        # attention call.  The previous code read img_q.shape[1] and
        # img_k.shape[1] inside the sequence-parallel branch, but img_q/img_k
        # are deleted at the end of the RoPE branch above and are never bound
        # at all when freqs_cis is None, so that branch raised NameError.
        img_q_len = q.shape[1] - txt_len
        img_kv_len = k.shape[1] - txt_len

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            l = [q, k, v]
            q = k = v = None
            attn = attention(
                l,
                mode=self.attn_mode,
                attn_mask=attn_mask,
                total_len=total_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q_len,
                img_kv_len=img_kv_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )
        # attention computation end

        # Compute activation in mlp stream, cat again and run second linear layer.
        mlp = self.mlp_act(mlp)
        attn_mlp = torch.cat((attn, mlp), 2)
        attn = None
        mlp = None
        output = self.linear2(attn_mlp)
        attn_mlp = None
        return x + apply_gate(output, gate=mod_gate)

    def forward(self, *args, **kwargs):
        """Dispatch to ``_forward``, optionally under activation checkpointing."""
        if self.training and self.gradient_checkpointing:
            # Non-reentrant checkpointing trades recompute for activation memory.
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
class HYVideoDiffusionTransformer(nn.Module): # ModelMixin, ConfigMixin):
|
| 428 |
+
"""
|
| 429 |
+
HunyuanVideo Transformer backbone
|
| 430 |
+
|
| 431 |
+
Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.
|
| 432 |
+
|
| 433 |
+
Reference:
|
| 434 |
+
[1] Flux.1: https://github.com/black-forest-labs/flux
|
| 435 |
+
[2] MMDiT: http://arxiv.org/abs/2403.03206
|
| 436 |
+
|
| 437 |
+
Parameters
|
| 438 |
+
----------
|
| 439 |
+
args: argparse.Namespace
|
| 440 |
+
The arguments parsed by argparse.
|
| 441 |
+
patch_size: list
|
| 442 |
+
The size of the patch.
|
| 443 |
+
in_channels: int
|
| 444 |
+
The number of input channels.
|
| 445 |
+
out_channels: int
|
| 446 |
+
The number of output channels.
|
| 447 |
+
hidden_size: int
|
| 448 |
+
The hidden size of the transformer backbone.
|
| 449 |
+
heads_num: int
|
| 450 |
+
The number of attention heads.
|
| 451 |
+
mlp_width_ratio: float
|
| 452 |
+
The ratio of the hidden size of the MLP in the transformer block.
|
| 453 |
+
mlp_act_type: str
|
| 454 |
+
The activation function of the MLP in the transformer block.
|
| 455 |
+
depth_double_blocks: int
|
| 456 |
+
The number of transformer blocks in the double blocks.
|
| 457 |
+
depth_single_blocks: int
|
| 458 |
+
The number of transformer blocks in the single blocks.
|
| 459 |
+
rope_dim_list: list
|
| 460 |
+
The dimension of the rotary embedding for t, h, w.
|
| 461 |
+
qkv_bias: bool
|
| 462 |
+
Whether to use bias in the qkv linear layer.
|
| 463 |
+
qk_norm: bool
|
| 464 |
+
Whether to use qk norm.
|
| 465 |
+
qk_norm_type: str
|
| 466 |
+
The type of qk norm.
|
| 467 |
+
guidance_embed: bool
|
| 468 |
+
Whether to use guidance embedding for distillation.
|
| 469 |
+
text_projection: str
|
| 470 |
+
The type of the text projection, default is single_refiner.
|
| 471 |
+
use_attention_mask: bool
|
| 472 |
+
Whether to use attention mask for text encoder.
|
| 473 |
+
dtype: torch.dtype
|
| 474 |
+
The dtype of the model.
|
| 475 |
+
device: torch.device
|
| 476 |
+
The device of the model.
|
| 477 |
+
attn_mode: str
|
| 478 |
+
The mode of the attention, default is flash.
|
| 479 |
+
split_attn: bool
|
| 480 |
+
Whether to use split attention (make attention as batch size 1).
|
| 481 |
+
"""
|
| 482 |
+
|
| 483 |
+
# @register_to_config
|
| 484 |
+
def __init__(
|
| 485 |
+
self,
|
| 486 |
+
text_states_dim: int,
|
| 487 |
+
text_states_dim_2: int,
|
| 488 |
+
patch_size: list = [1, 2, 2],
|
| 489 |
+
in_channels: int = 4, # Should be VAE.config.latent_channels.
|
| 490 |
+
out_channels: int = None,
|
| 491 |
+
hidden_size: int = 3072,
|
| 492 |
+
heads_num: int = 24,
|
| 493 |
+
mlp_width_ratio: float = 4.0,
|
| 494 |
+
mlp_act_type: str = "gelu_tanh",
|
| 495 |
+
mm_double_blocks_depth: int = 20,
|
| 496 |
+
mm_single_blocks_depth: int = 40,
|
| 497 |
+
rope_dim_list: List[int] = [16, 56, 56],
|
| 498 |
+
qkv_bias: bool = True,
|
| 499 |
+
qk_norm: bool = True,
|
| 500 |
+
qk_norm_type: str = "rms",
|
| 501 |
+
guidance_embed: bool = False, # For modulation.
|
| 502 |
+
text_projection: str = "single_refiner",
|
| 503 |
+
use_attention_mask: bool = True,
|
| 504 |
+
dtype: Optional[torch.dtype] = None,
|
| 505 |
+
device: Optional[torch.device] = None,
|
| 506 |
+
attn_mode: str = "flash",
|
| 507 |
+
split_attn: bool = False,
|
| 508 |
+
):
|
| 509 |
+
factory_kwargs = {"device": device, "dtype": dtype}
|
| 510 |
+
super().__init__()
|
| 511 |
+
|
| 512 |
+
self.patch_size = patch_size
|
| 513 |
+
self.in_channels = in_channels
|
| 514 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
| 515 |
+
self.unpatchify_channels = self.out_channels
|
| 516 |
+
self.guidance_embed = guidance_embed
|
| 517 |
+
self.rope_dim_list = rope_dim_list
|
| 518 |
+
|
| 519 |
+
# Text projection. Default to linear projection.
|
| 520 |
+
# Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
|
| 521 |
+
self.use_attention_mask = use_attention_mask
|
| 522 |
+
self.text_projection = text_projection
|
| 523 |
+
|
| 524 |
+
self.text_states_dim = text_states_dim
|
| 525 |
+
self.text_states_dim_2 = text_states_dim_2
|
| 526 |
+
|
| 527 |
+
if hidden_size % heads_num != 0:
|
| 528 |
+
raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
|
| 529 |
+
pe_dim = hidden_size // heads_num
|
| 530 |
+
if sum(rope_dim_list) != pe_dim:
|
| 531 |
+
raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
|
| 532 |
+
self.hidden_size = hidden_size
|
| 533 |
+
self.heads_num = heads_num
|
| 534 |
+
|
| 535 |
+
self.attn_mode = attn_mode
|
| 536 |
+
self.split_attn = split_attn
|
| 537 |
+
print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")
|
| 538 |
+
|
| 539 |
+
# image projection
|
| 540 |
+
self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)
|
| 541 |
+
|
| 542 |
+
# text projection
|
| 543 |
+
if self.text_projection == "linear":
|
| 544 |
+
self.txt_in = TextProjection(
|
| 545 |
+
self.text_states_dim,
|
| 546 |
+
self.hidden_size,
|
| 547 |
+
get_activation_layer("silu"),
|
| 548 |
+
**factory_kwargs,
|
| 549 |
+
)
|
| 550 |
+
elif self.text_projection == "single_refiner":
|
| 551 |
+
self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
|
| 552 |
+
else:
|
| 553 |
+
raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")
|
| 554 |
+
|
| 555 |
+
# time modulation
|
| 556 |
+
self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)
|
| 557 |
+
|
| 558 |
+
# text modulation
|
| 559 |
+
self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)
|
| 560 |
+
|
| 561 |
+
# guidance modulation
|
| 562 |
+
self.guidance_in = (
|
| 563 |
+
TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
# double blocks
|
| 567 |
+
self.double_blocks = nn.ModuleList(
|
| 568 |
+
[
|
| 569 |
+
MMDoubleStreamBlock(
|
| 570 |
+
self.hidden_size,
|
| 571 |
+
self.heads_num,
|
| 572 |
+
mlp_width_ratio=mlp_width_ratio,
|
| 573 |
+
mlp_act_type=mlp_act_type,
|
| 574 |
+
qk_norm=qk_norm,
|
| 575 |
+
qk_norm_type=qk_norm_type,
|
| 576 |
+
qkv_bias=qkv_bias,
|
| 577 |
+
attn_mode=attn_mode,
|
| 578 |
+
split_attn=split_attn,
|
| 579 |
+
**factory_kwargs,
|
| 580 |
+
)
|
| 581 |
+
for _ in range(mm_double_blocks_depth)
|
| 582 |
+
]
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
# single blocks
|
| 586 |
+
self.single_blocks = nn.ModuleList(
|
| 587 |
+
[
|
| 588 |
+
MMSingleStreamBlock(
|
| 589 |
+
self.hidden_size,
|
| 590 |
+
self.heads_num,
|
| 591 |
+
mlp_width_ratio=mlp_width_ratio,
|
| 592 |
+
mlp_act_type=mlp_act_type,
|
| 593 |
+
qk_norm=qk_norm,
|
| 594 |
+
qk_norm_type=qk_norm_type,
|
| 595 |
+
attn_mode=attn_mode,
|
| 596 |
+
split_attn=split_attn,
|
| 597 |
+
**factory_kwargs,
|
| 598 |
+
)
|
| 599 |
+
for _ in range(mm_single_blocks_depth)
|
| 600 |
+
]
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
self.final_layer = FinalLayer(
|
| 604 |
+
self.hidden_size,
|
| 605 |
+
self.patch_size,
|
| 606 |
+
self.out_channels,
|
| 607 |
+
get_activation_layer("silu"),
|
| 608 |
+
**factory_kwargs,
|
| 609 |
+
)
|
| 610 |
+
|
| 611 |
+
self.gradient_checkpointing = False
|
| 612 |
+
self.blocks_to_swap = None
|
| 613 |
+
self.offloader_double = None
|
| 614 |
+
self.offloader_single = None
|
| 615 |
+
self._enable_img_in_txt_in_offloading = False
|
| 616 |
+
|
| 617 |
+
@property
|
| 618 |
+
def device(self):
|
| 619 |
+
return next(self.parameters()).device
|
| 620 |
+
|
| 621 |
+
@property
|
| 622 |
+
def dtype(self):
|
| 623 |
+
return next(self.parameters()).dtype
|
| 624 |
+
|
| 625 |
+
def enable_gradient_checkpointing(self):
|
| 626 |
+
self.gradient_checkpointing = True
|
| 627 |
+
|
| 628 |
+
self.txt_in.enable_gradient_checkpointing()
|
| 629 |
+
|
| 630 |
+
for block in self.double_blocks + self.single_blocks:
|
| 631 |
+
block.enable_gradient_checkpointing()
|
| 632 |
+
|
| 633 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")
|
| 634 |
+
|
| 635 |
+
def disable_gradient_checkpointing(self):
|
| 636 |
+
self.gradient_checkpointing = False
|
| 637 |
+
|
| 638 |
+
self.txt_in.disable_gradient_checkpointing()
|
| 639 |
+
|
| 640 |
+
for block in self.double_blocks + self.single_blocks:
|
| 641 |
+
block.disable_gradient_checkpointing()
|
| 642 |
+
|
| 643 |
+
print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")
|
| 644 |
+
|
| 645 |
+
def enable_img_in_txt_in_offloading(self):
|
| 646 |
+
self._enable_img_in_txt_in_offloading = True
|
| 647 |
+
|
| 648 |
+
def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
|
| 649 |
+
self.blocks_to_swap = num_blocks
|
| 650 |
+
self.num_double_blocks = len(self.double_blocks)
|
| 651 |
+
self.num_single_blocks = len(self.single_blocks)
|
| 652 |
+
double_blocks_to_swap = num_blocks // 2
|
| 653 |
+
single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1
|
| 654 |
+
|
| 655 |
+
assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
|
| 656 |
+
f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
|
| 657 |
+
f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
|
| 658 |
+
)
|
| 659 |
+
|
| 660 |
+
self.offloader_double = ModelOffloader(
|
| 661 |
+
"double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device # , debug=True
|
| 662 |
+
)
|
| 663 |
+
self.offloader_single = ModelOffloader(
|
| 664 |
+
"single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device # , debug=True
|
| 665 |
+
)
|
| 666 |
+
print(
|
| 667 |
+
f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
def switch_block_swap_for_inference(self):
    """Put both offloaders into forward-only mode and stage blocks for the next pass."""
    if not self.blocks_to_swap:
        return
    for offloader in (self.offloader_double, self.offloader_single):
        offloader.set_forward_only(True)
    self.prepare_block_swap_before_forward()
    print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")
|
| 676 |
+
|
| 677 |
+
def switch_block_swap_for_training(self):
    """Put both offloaders back into forward+backward mode and stage blocks."""
    if not self.blocks_to_swap:
        return
    for offloader in (self.offloader_double, self.offloader_single):
        offloader.set_forward_only(False)
    self.prepare_block_swap_before_forward()
    print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")
|
| 683 |
+
|
| 684 |
+
def move_to_device_except_swap_blocks(self, device: torch.device):
    """Move the model to `device`, temporarily detaching the swappable blocks.

    The model is assumed to start on CPU; keeping the transformer blocks out
    of the `.to()` call reduces temporary device memory usage.
    """
    swapping = bool(self.blocks_to_swap)
    if swapping:
        saved_blocks = (self.double_blocks, self.single_blocks)
        self.double_blocks = None
        self.single_blocks = None

    self.to(device)

    if swapping:
        self.double_blocks, self.single_blocks = saved_blocks
|
| 697 |
+
|
| 698 |
+
def prepare_block_swap_before_forward(self):
    """Stage block devices ahead of the next forward pass; no-op when swapping is off."""
    if not self.blocks_to_swap:
        return
    self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
    self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)
|
| 703 |
+
|
| 704 |
+
def enable_deterministic(self):
    """Switch every double- and single-stream block into deterministic mode."""
    for group in (self.double_blocks, self.single_blocks):
        for blk in group:
            blk.enable_deterministic()
|
| 709 |
+
|
| 710 |
+
def disable_deterministic(self):
    """Switch every double- and single-stream block out of deterministic mode."""
    for group in (self.double_blocks, self.single_blocks):
        for blk in group:
            blk.disable_deterministic()
|
| 715 |
+
|
| 716 |
+
def forward(
    self,
    x: torch.Tensor,
    t: torch.Tensor,  # Should be in range(0, 1000).
    text_states: torch.Tensor = None,
    text_mask: torch.Tensor = None,  # Now we don't use it.
    text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
    freqs_cos: Optional[torch.Tensor] = None,
    freqs_sin: Optional[torch.Tensor] = None,
    guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
    return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
    """Denoise latent video `x` at timestep `t`, conditioned on text embeddings.

    Args:
        x: latent video; unpacked here as (N, C, T, H, W).
        t: diffusion timestep(s).
        text_states: per-token text embeddings fed to txt_in.
        text_mask: text attention mask; used for seqlen bookkeeping, for the
            sdpa attention mask, and (when use_attention_mask is set) passed to
            the single_refiner text projection.
        text_states_2: pooled text embedding added to the modulation vector.
        freqs_cos / freqs_sin: precomputed rotary embedding tables; forwarded
            to the blocks as a (cos, sin) pair, or None when absent.
        guidance: guidance-strength input, required when self.guidance_embed.
        return_dict: if True return {"x": out_tensor}, otherwise the tensor.

    Raises:
        ValueError: guidance_embed is on but `guidance` is None.
        NotImplementedError: unknown self.text_projection mode.
    """
    out = {}
    img = x
    txt = text_states
    _, _, ot, oh, ow = x.shape
    # token-grid sizes along (time, height, width) after patchification
    tt, th, tw = (
        ot // self.patch_size[0],
        oh // self.patch_size[1],
        ow // self.patch_size[2],
    )

    # Prepare modulation vectors: timestep embedding ...
    vec = self.time_in(t)

    # ... plus pooled-text modulation ...
    vec = vec + self.vector_in(text_states_2)

    # ... plus guidance modulation (CFG-distilled variant only)
    if self.guidance_embed:
        if guidance is None:
            raise ValueError("Didn't get guidance strength for guidance distilled model.")

        # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
        vec = vec + self.guidance_in(guidance)

    # Embed image and text. Optionally bring img_in/txt_in onto the compute
    # device only for this step (they live on CPU otherwise).
    if self._enable_img_in_txt_in_offloading:
        self.img_in.to(x.device, non_blocking=True)
        self.txt_in.to(x.device, non_blocking=True)
        synchronize_device(x.device)

    img = self.img_in(img)
    if self.text_projection == "linear":
        txt = self.txt_in(txt)
    elif self.text_projection == "single_refiner":
        txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
    else:
        raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

    if self._enable_img_in_txt_in_offloading:
        # send the embedders back to CPU and reclaim device memory
        self.img_in.to(torch.device("cpu"), non_blocking=True)
        self.txt_in.to(torch.device("cpu"), non_blocking=True)
        synchronize_device(x.device)
        clean_memory_on_device(x.device)

    txt_seq_len = txt.shape[1]
    img_seq_len = img.shape[1]

    # Compute cu_seqlens and max_seqlen for flash attention
    cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
    cu_seqlens_kv = cu_seqlens_q
    max_seqlen_q = img_seq_len + txt_seq_len
    max_seqlen_kv = max_seqlen_q

    attn_mask = total_len = None
    if self.split_attn or self.attn_mode == "torch":
        # calculate per-sample text length and total (img + text) length
        text_len = text_mask.sum(dim=1)  # (bs, )
        total_len = img_seq_len + text_len  # (bs, )
    if self.attn_mode == "torch" and not self.split_attn:
        # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
        bs = img.shape[0]
        attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)

        # allow attention only within each sample's valid (unpadded) length
        for i in range(bs):
            attn_mask[i, :, : total_len[i], : total_len[i]] = True
        total_len = None  # means we don't use split_attn

    freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
    # --------------------- Pass through DiT blocks ------------------------
    for block_idx, block in enumerate(self.double_blocks):
        double_block_args = [
            img,
            txt,
            vec,
            attn_mask,
            total_len,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
            freqs_cis,
        ]

        if self.blocks_to_swap:
            # block weights must be on-device before running it
            self.offloader_double.wait_for_block(block_idx)

        img, txt = block(*double_block_args)

        if self.blocks_to_swap:
            # schedule background moves now that this block has run
            self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)

    # Merge txt and img to pass through single stream blocks.
    x = torch.cat((img, txt), 1)
    if self.blocks_to_swap:
        # delete img, txt to reduce memory usage
        del img, txt
        clean_memory_on_device(x.device)

    if len(self.single_blocks) > 0:
        for block_idx, block in enumerate(self.single_blocks):
            single_block_args = [
                x,
                vec,
                txt_seq_len,
                attn_mask,
                total_len,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                freqs_cis,
            ]
            if self.blocks_to_swap:
                self.offloader_single.wait_for_block(block_idx)

            x = block(*single_block_args)

            if self.blocks_to_swap:
                self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)

    # keep only the image tokens; drop the text tail
    img = x[:, :img_seq_len, ...]
    x = None

    # ---------------------------- Final layer ------------------------------
    img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

    img = self.unpatchify(img, tt, th, tw)
    if return_dict:
        out["x"] = img
        return out
    return img
|
| 860 |
+
|
| 861 |
+
def unpatchify(self, x, t, h, w):
    """Reassemble patch tokens into a video tensor.

    Args:
        x: patch tokens, shape (N, t*h*w, pt*ph*pw * C).
        t, h, w: token-grid sizes along time, height and width.

    Returns:
        torch.Tensor of shape (N, C, t*pt, h*ph, w*pw).
    """
    channels = self.unpatchify_channels
    pt, ph, pw = self.patch_size
    assert t * h * w == x.shape[1]

    batch = x.shape[0]
    patches = x.reshape(shape=(batch, t, h, w, channels, pt, ph, pw))
    patches = torch.einsum("nthwcopq->nctohpwq", patches)
    return patches.reshape(shape=(batch, channels, t * pt, h * ph, w * pw))
|
| 875 |
+
|
| 876 |
+
def params_count(self):
    """Return parameter counts for the attention/MLP weights and the whole model.

    Returns:
        dict with keys:
            "double": params of the double-block attention + MLP submodules,
            "single": params of the single-block linear1/linear2 submodules,
            "total": all model parameters,
            "attn+mlp": "double" + "single".
    """

    def _numel(module):
        # total number of elements across a submodule's parameters
        return sum(p.numel() for p in module.parameters())

    # generators instead of the original sum([...]) lists; also removes the
    # six-fold repetition of the per-submodule counting expression
    counts = {
        "double": sum(
            _numel(block.img_attn_qkv)
            + _numel(block.img_attn_proj)
            + _numel(block.img_mlp)
            + _numel(block.txt_attn_qkv)
            + _numel(block.txt_attn_proj)
            + _numel(block.txt_mlp)
            for block in self.double_blocks
        ),
        "single": sum(_numel(block.linear1) + _numel(block.linear2) for block in self.single_blocks),
        "total": sum(p.numel() for p in self.parameters()),
    }
    counts["attn+mlp"] = counts["double"] + counts["single"]
    return counts
|
| 899 |
+
|
| 900 |
+
|
| 901 |
+
#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

# Architecture hyperparameters keyed by model variant name.
HUNYUAN_VIDEO_CONFIG = {
    "HYVideo-T/2": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],  # RoPE channels per (t, h, w) axis; sums to head_dim 3072/24 = 128
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
    # same backbone, but with an extra guidance (CFG-scale) embedding input
    "HYVideo-T/2-cfgdistill": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "guidance_embed": True,
    },
}
|
| 924 |
+
|
| 925 |
+
|
| 926 |
+
def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
    """Instantiate the HunyuanVideo diffusion transformer.

    NOTE: Only the "HYVideo-T/2-cfgdistill" configuration is supported.

    Args:
        text_states_dim (int): text state dimension
        text_states_dim_2 (int): text state dimension 2
        in_channels (int): input channels number
        out_channels (int): output channels number
        factor_kwargs (dict): extra constructor kwargs (device, dtype, attn mode, ...)

    Returns:
        nn.Module: the HunyuanVideo diffusion transformer
    """
    return HYVideoDiffusionTransformer(
        text_states_dim=text_states_dim,
        text_states_dim_2=text_states_dim_2,
        in_channels=in_channels,
        out_channels=out_channels,
        **HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"],
        **factor_kwargs,
    )
|
| 953 |
+
|
| 954 |
+
|
| 955 |
+
def load_state_dict(model, model_path):
|
| 956 |
+
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
|
| 957 |
+
|
| 958 |
+
load_key = "module"
|
| 959 |
+
if load_key in state_dict:
|
| 960 |
+
state_dict = state_dict[load_key]
|
| 961 |
+
else:
|
| 962 |
+
raise KeyError(
|
| 963 |
+
f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
|
| 964 |
+
f"are: {list(state_dict.keys())}."
|
| 965 |
+
)
|
| 966 |
+
model.load_state_dict(state_dict, strict=True, assign=True)
|
| 967 |
+
return model
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
    """Build the DiT with empty (meta) weights, then load them from `dit_path`.

    Supports `.safetensors` checkpoints (read tensor-by-tensor and cast to
    `dtype`) and torch checkpoints with a top-level "module" key.

    Args:
        dit_path: path to the DiT checkpoint (.safetensors or torch format).
        attn_mode: attention backend selector, forwarded to the model.
        split_attn: whether to use split attention, forwarded to the model.
        device: device the weights are loaded onto.
        dtype: dtype weights are cast to (safetensors path only).
        in_channels: latent input channels (default 16).

    Returns:
        HYVideoDiffusionTransformer with loaded weights.
    """
    # =========================== Build main model ===========================
    factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
    latent_channels = 16
    out_channels = latent_channels

    # allocate no real storage; load_state_dict(assign=True) supplies tensors later
    with accelerate.init_empty_weights():
        transformer = load_dit_model(
            text_states_dim=4096,
            text_states_dim_2=768,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

    if os.path.splitext(dit_path)[-1] == ".safetensors":
        # loading safetensors: may be already fp8
        with MemoryEfficientSafeOpen(dit_path) as f:
            state_dict = {}
            for k in f.keys():
                tensor = f.get_tensor(k)
                tensor = tensor.to(device=device, dtype=dtype)
                # TODO support comfy model
                # if k.startswith("model.model."):
                #     k = convert_comfy_model_key(k)
                state_dict[k] = tensor
            transformer.load_state_dict(state_dict, strict=True, assign=True)
    else:
        transformer = load_state_dict(transformer, dit_path)

    return transformer
|
| 1001 |
+
|
| 1002 |
+
|
| 1003 |
+
def get_rotary_pos_embed_by_shape(model, latents_size):
    """Compute rotary position embedding tables for a latent grid.

    Args:
        model: transformer exposing patch_size, hidden_size, heads_num and
            rope_dim_list.
        latents_size: latent extents (t, h, w) before patchification.

    Returns:
        (freqs_cos, freqs_sin): real-valued RoPE tables from
        get_nd_rotary_pos_embed.

    Raises:
        TypeError: if model.patch_size is neither an int nor a list/tuple.
    """
    target_ndim = 3
    ndim = 5 - 2  # latents are 5D (N, C, T, H, W); positions span the last 3 dims

    if isinstance(model.patch_size, int):
        assert all(s % model.patch_size == 0 for s in latents_size), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size for s in latents_size]
    elif isinstance(model.patch_size, (list, tuple)):  # tuple now accepted too
        assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
    else:
        # previously fell through silently and crashed later with a NameError on rope_sizes
        raise TypeError(f"Unsupported patch_size type: {type(model.patch_size)}")

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # pad the time axis

    head_dim = model.hidden_size // model.heads_num
    rope_dim_list = model.rope_dim_list
    if rope_dim_list is None:
        # default: split the head dim evenly over the three axes
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

    rope_theta = 256
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
    )
    return freqs_cos, freqs_sin
|
| 1033 |
+
|
| 1034 |
+
|
| 1035 |
+
def get_rotary_pos_embed(vae_name, model, video_length, height, width):
    """Compute RoPE tables for a video, accounting for the VAE's compression.

    VAE names containing "884" compress time 4x; "888" compress time 8x; both
    compress space 8x. Other names get no temporal compression.
    """
    if "884" in vae_name:
        frames = (video_length - 1) // 4 + 1
    elif "888" in vae_name:
        frames = (video_length - 1) // 8 + 1
    else:
        frames = video_length

    latents_size = [frames, height // 8, width // 8]
    return get_rotary_pos_embed_by_shape(model, latents_size)
|
hunyuan_model/modulate_layers.py
ADDED
|
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ModulateDiT(nn.Module):
    """Zero-initialized modulation projection used by DiT blocks.

    Applies an activation followed by a linear layer that expands the hidden
    size by `factor`. Both weight and bias start at zero, so the modulation
    contributes nothing until trained.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero-initialize the modulation so it starts as a no-op contribution.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Activation then projection: linear(act(x))."""
        return self.linear(self.act(x))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def modulate(x, shift=None, scale=None):
    """Apply an affine (scale/shift) modulation, broadcasting over dim 1.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): per-sample shift. Defaults to None.
        scale (torch.Tensor, optional): per-sample scale. Defaults to None.

    Returns:
        torch.Tensor: x * (1 + scale) + shift, skipping whichever term is None;
        x itself when both are None.
    """
    if shift is None and scale is None:
        return x
    if shift is None:
        return x * (1 + scale.unsqueeze(1))
    if scale is None:
        return x + shift.unsqueeze(1)
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_gate(x, gate=None, tanh=False):
    """Scale x by a per-sample gate, broadcasting over dim 1.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate values; None is a no-op.
        tanh (bool, optional): squash the gate through tanh first.

    Returns:
        torch.Tensor: the gated tensor.
    """
    if gate is None:
        return x
    g = gate.unsqueeze(1)
    return x * (g.tanh() if tanh else g)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def ckpt_wrapper(module):
    """Wrap `module` in a plain positional-args callable (checkpointing helper)."""

    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
|
hunyuan_model/norm_layers.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Normalizes the last dimension by its RMS (computed in float32), then
    optionally applies a learned per-channel scale.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Args:
            dim (int): size of the normalized (last) dimension.
            elementwise_affine: if True, register a learnable scale `weight`.
            eps (float): numerical-stability term added before the rsqrt.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        # `weight` only exists in the affine case; forward() checks via hasattr
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """Divide x by its RMS over the last dimension."""
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """Normalize in float32, cast back to x's dtype, apply optional scale."""
        normed = self._norm(x.float()).type_as(x)
        if not hasattr(self, "weight"):
            return normed
        # cast the weight (not the activations) so fp8-stored weights still work
        return normed * self.weight.to(normed.dtype)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def get_norm_layer(norm_layer):
    """Resolve a normalization-layer name to its class.

    Args:
        norm_layer (str): "layer" for nn.LayerNorm or "rms" for RMSNorm.

    Returns:
        type: the normalization layer class.

    Raises:
        NotImplementedError: for any other name.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
|
hunyuan_model/pipeline_hunyuan_video.py
ADDED
|
@@ -0,0 +1,1100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
#
|
| 16 |
+
# Modified from diffusers==0.29.2
|
| 17 |
+
#
|
| 18 |
+
# ==============================================================================
|
| 19 |
+
import inspect
|
| 20 |
+
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
|
| 21 |
+
import torch
|
| 22 |
+
import torch.distributed as dist
|
| 23 |
+
import numpy as np
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from packaging import version
|
| 26 |
+
|
| 27 |
+
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
| 28 |
+
from diffusers.configuration_utils import FrozenDict
|
| 29 |
+
from diffusers.image_processor import VaeImageProcessor
|
| 30 |
+
from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
|
| 31 |
+
from diffusers.models import AutoencoderKL
|
| 32 |
+
from diffusers.models.lora import adjust_lora_scale_text_encoder
|
| 33 |
+
from diffusers.schedulers import KarrasDiffusionSchedulers
|
| 34 |
+
from diffusers.utils import (
|
| 35 |
+
USE_PEFT_BACKEND,
|
| 36 |
+
deprecate,
|
| 37 |
+
logging,
|
| 38 |
+
replace_example_docstring,
|
| 39 |
+
scale_lora_layers,
|
| 40 |
+
unscale_lora_layers,
|
| 41 |
+
)
|
| 42 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 43 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 44 |
+
from diffusers.utils import BaseOutput
|
| 45 |
+
|
| 46 |
+
from ...constants import PRECISION_TO_TYPE
|
| 47 |
+
from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
|
| 48 |
+
from ...text_encoder import TextEncoder
|
| 49 |
+
from ...modules import HYVideoDiffusionTransformer
|
| 50 |
+
|
| 51 |
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

# Usage example injected into docstrings via @replace_example_docstring; empty here.
EXAMPLE_DOC_STRING = """"""
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    # per-sample std over all non-batch dims
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    # rescale the guided result toward the text-branch std (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    # blend with the unrescaled result to avoid "plain looking" images
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Call the scheduler's `set_timesteps` and return the resulting schedule.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the
    schedule; custom `timesteps`/`sigmas` are only forwarded when the scheduler's
    `set_timesteps` signature accepts them. Extra `kwargs` are passed through.

    Args:
        scheduler (`SchedulerMixin`): scheduler to get timesteps from.
        num_inference_steps (`int`, *optional*): number of diffusion steps; must be
            `None` when `timesteps` is given.
        device (`str` or `torch.device`, *optional*): device for the timesteps.
        timesteps (`List[int]`, *optional*): custom timestep schedule; mutually
            exclusive with `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*): custom sigma schedule; mutually
            exclusive with `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the timestep schedule and the number of
        inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    # Remember whether a custom schedule was requested so the step count can be
    # recomputed from the scheduler afterwards.
    custom_schedule = timesteps is not None or sigmas is not None

    if timesteps is not None:
        accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "timesteps" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    elif sigmas is not None:
        accepted = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if "sigmas" not in accepted:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    timesteps = scheduler.timesteps
    if custom_schedule:
        # The scheduler may have produced a different count than requested.
        num_inference_steps = len(timesteps)
    return timesteps, num_inference_steps
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
    """Output of `HunyuanVideoPipeline.__call__`."""

    # Generated video frames; a torch tensor or numpy array depending on the
    # `output_type` requested from the pipeline.
    videos: Union[torch.Tensor, np.ndarray]
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class HunyuanVideoPipeline(DiffusionPipeline):
|
| 145 |
+
r"""
|
| 146 |
+
Pipeline for text-to-video generation using HunyuanVideo.
|
| 147 |
+
|
| 148 |
+
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
| 149 |
+
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
vae ([`AutoencoderKL`]):
|
| 153 |
+
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
|
| 154 |
+
text_encoder ([`TextEncoder`]):
|
| 155 |
+
Frozen text-encoder.
|
| 156 |
+
text_encoder_2 ([`TextEncoder`]):
|
| 157 |
+
Frozen text-encoder_2.
|
| 158 |
+
transformer ([`HYVideoDiffusionTransformer`]):
|
| 159 |
+
A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
|
| 160 |
+
scheduler ([`SchedulerMixin`]):
|
| 161 |
+
A scheduler to be used in combination with `unet` to denoise the encoded image latents.
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
|
| 165 |
+
_optional_components = ["text_encoder_2"]
|
| 166 |
+
_exclude_from_cpu_offload = ["transformer"]
|
| 167 |
+
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
| 168 |
+
|
| 169 |
+
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: TextEncoder,
        transformer: HYVideoDiffusionTransformer,
        scheduler: KarrasDiffusionSchedulers,
        text_encoder_2: Optional[TextEncoder] = None,
        progress_bar_config: Dict[str, Any] = None,
        args=None,
    ):
        """Construct the pipeline, patch outdated scheduler configs, and register submodules.

        Args:
            vae: VAE used to encode/decode latents.
            text_encoder: primary frozen text encoder.
            transformer: denoising transformer backbone.
            scheduler: noise scheduler; its config may be patched in place (see below).
            text_encoder_2: optional secondary text encoder (e.g. CLIP pooled features).
            progress_bar_config: kwargs merged into the diffusers progress-bar config.
            args: opaque bag of CLI/runtime options stored on the pipeline as-is.
        """
        super().__init__()

        # ==========================================================================================
        # Merge the caller-supplied progress bar config into whatever the base
        # DiffusionPipeline may already have set up.
        if progress_bar_config is None:
            progress_bar_config = {}
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        self._progress_bar_config.update(progress_bar_config)

        self.args = args
        # ==========================================================================================

        # Legacy-config fixups copied from the standard diffusers pipelines: old
        # checkpoints may ship `steps_offset != 1` or `clip_sample=True`; both are
        # deprecated and silently corrected on the scheduler's frozen config.
        if (
            hasattr(scheduler.config, "steps_offset")
            and scheduler.config.steps_offset != 1
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            # NOTE(review): writing `_internal_dict` directly relies on diffusers'
            # ConfigMixin internals — confirm against the pinned diffusers version.
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if (
            hasattr(scheduler.config, "clip_sample")
            and scheduler.config.clip_sample is True
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
        )
        # Spatial downsampling factor of the VAE (2 per down-block).
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
| 237 |
+
|
| 238 |
+
    def encode_prompt(
        self,
        prompt,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        text_encoder: Optional[TextEncoder] = None,
        data_type: Optional[str] = "image",
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `prompt_embeds`. Only derived here when embeddings
                are computed from `prompt`; if you pass `prompt_embeds`, pass this too.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_attention_mask (`torch.Tensor`, *optional*):
                Attention mask matching `negative_prompt_embeds`; same caveat as `attention_mask`.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            text_encoder (TextEncoder, *optional*):
                Encoder to use; defaults to `self.text_encoder`.
            data_type (`str`, *optional*):
                Forwarded to the text encoder's tokenization/encoding ("image" or "video").
        """
        if text_encoder is None:
            text_encoder = self.text_encoder

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
            else:
                scale_lora_layers(text_encoder.model, lora_scale)

        # Batch size comes from the prompt when given, else from the precomputed embeds.
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)

            text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

            if clip_skip is None:
                prompt_outputs = text_encoder.encode(
                    text_inputs, data_type=data_type, device=device
                )
                prompt_embeds = prompt_outputs.hidden_state
            else:
                prompt_outputs = text_encoder.encode(
                    text_inputs,
                    output_hidden_states=True,
                    data_type=data_type,
                    device=device,
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = text_encoder.model.text_model.final_layer_norm(
                    prompt_embeds
                )

            # NOTE(review): the mask (and the per-prompt duplication below) is only
            # produced on this path; a caller passing `prompt_embeds` must also pass a
            # pre-duplicated `attention_mask` — confirm against call sites.
            attention_mask = prompt_outputs.attention_mask
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                bs_embed, seq_len = attention_mask.shape
                # Duplicate per generated video, mps-friendly (repeat + view).
                attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
                attention_mask = attention_mask.view(
                    bs_embed * num_videos_per_prompt, seq_len
                )

        # Pick the dtype the transformer expects for the embeddings.
        if text_encoder is not None:
            prompt_embeds_dtype = text_encoder.dtype
        elif self.transformer is not None:
            prompt_embeds_dtype = self.transformer.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        # Pooled (2-D) vs sequence (3-D) embeddings are duplicated differently.
        if prompt_embeds.ndim == 2:
            bs_embed, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
            prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
        else:
            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                bs_embed * num_videos_per_prompt, seq_len, -1
            )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(
                    uncond_tokens, text_encoder.tokenizer
                )

            # max_length = prompt_embeds.shape[1]
            uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)

            negative_prompt_outputs = text_encoder.encode(
                uncond_input, data_type=data_type, device=device
            )
            negative_prompt_embeds = negative_prompt_outputs.hidden_state

            negative_attention_mask = negative_prompt_outputs.attention_mask
            if negative_attention_mask is not None:
                negative_attention_mask = negative_attention_mask.to(device)
                _, seq_len = negative_attention_mask.shape
                negative_attention_mask = negative_attention_mask.repeat(
                    1, num_videos_per_prompt
                )
                negative_attention_mask = negative_attention_mask.view(
                    batch_size * num_videos_per_prompt, seq_len
                )

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=prompt_embeds_dtype, device=device
            )

            if negative_prompt_embeds.ndim == 2:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, -1
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt, 1
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, seq_len, -1
                )

        if text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(text_encoder.model, lora_scale)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )
|
| 450 |
+
|
| 451 |
+
def decode_latents(self, latents, enable_tiling=True):
|
| 452 |
+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
|
| 453 |
+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
|
| 454 |
+
|
| 455 |
+
latents = 1 / self.vae.config.scaling_factor * latents
|
| 456 |
+
if enable_tiling:
|
| 457 |
+
self.vae.enable_tiling()
|
| 458 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 459 |
+
else:
|
| 460 |
+
image = self.vae.decode(latents, return_dict=False)[0]
|
| 461 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 462 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
|
| 463 |
+
if image.ndim == 4:
|
| 464 |
+
image = image.cpu().permute(0, 2, 3, 1).float()
|
| 465 |
+
else:
|
| 466 |
+
image = image.cpu().float()
|
| 467 |
+
return image
|
| 468 |
+
|
| 469 |
+
def prepare_extra_func_kwargs(self, func, kwargs):
|
| 470 |
+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
| 471 |
+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
|
| 472 |
+
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
|
| 473 |
+
# and should be between [0, 1]
|
| 474 |
+
extra_step_kwargs = {}
|
| 475 |
+
|
| 476 |
+
for k, v in kwargs.items():
|
| 477 |
+
accepts = k in set(inspect.signature(func).parameters.keys())
|
| 478 |
+
if accepts:
|
| 479 |
+
extra_step_kwargs[k] = v
|
| 480 |
+
return extra_step_kwargs
|
| 481 |
+
|
| 482 |
+
def check_inputs(
|
| 483 |
+
self,
|
| 484 |
+
prompt,
|
| 485 |
+
height,
|
| 486 |
+
width,
|
| 487 |
+
video_length,
|
| 488 |
+
callback_steps,
|
| 489 |
+
negative_prompt=None,
|
| 490 |
+
prompt_embeds=None,
|
| 491 |
+
negative_prompt_embeds=None,
|
| 492 |
+
callback_on_step_end_tensor_inputs=None,
|
| 493 |
+
vae_ver="88-4c-sd",
|
| 494 |
+
):
|
| 495 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 496 |
+
raise ValueError(
|
| 497 |
+
f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
|
| 498 |
+
)
|
| 499 |
+
|
| 500 |
+
if video_length is not None:
|
| 501 |
+
if "884" in vae_ver:
|
| 502 |
+
if video_length != 1 and (video_length - 1) % 4 != 0:
|
| 503 |
+
raise ValueError(
|
| 504 |
+
f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
|
| 505 |
+
)
|
| 506 |
+
elif "888" in vae_ver:
|
| 507 |
+
if video_length != 1 and (video_length - 1) % 8 != 0:
|
| 508 |
+
raise ValueError(
|
| 509 |
+
f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
|
| 510 |
+
)
|
| 511 |
+
|
| 512 |
+
if callback_steps is not None and (
|
| 513 |
+
not isinstance(callback_steps, int) or callback_steps <= 0
|
| 514 |
+
):
|
| 515 |
+
raise ValueError(
|
| 516 |
+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
| 517 |
+
f" {type(callback_steps)}."
|
| 518 |
+
)
|
| 519 |
+
if callback_on_step_end_tensor_inputs is not None and not all(
|
| 520 |
+
k in self._callback_tensor_inputs
|
| 521 |
+
for k in callback_on_step_end_tensor_inputs
|
| 522 |
+
):
|
| 523 |
+
raise ValueError(
|
| 524 |
+
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
| 525 |
+
)
|
| 526 |
+
|
| 527 |
+
if prompt is not None and prompt_embeds is not None:
|
| 528 |
+
raise ValueError(
|
| 529 |
+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
| 530 |
+
" only forward one of the two."
|
| 531 |
+
)
|
| 532 |
+
elif prompt is None and prompt_embeds is None:
|
| 533 |
+
raise ValueError(
|
| 534 |
+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
| 535 |
+
)
|
| 536 |
+
elif prompt is not None and (
|
| 537 |
+
not isinstance(prompt, str) and not isinstance(prompt, list)
|
| 538 |
+
):
|
| 539 |
+
raise ValueError(
|
| 540 |
+
f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
|
| 541 |
+
)
|
| 542 |
+
|
| 543 |
+
if negative_prompt is not None and negative_prompt_embeds is not None:
|
| 544 |
+
raise ValueError(
|
| 545 |
+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
| 546 |
+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
| 550 |
+
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
| 551 |
+
raise ValueError(
|
| 552 |
+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
| 553 |
+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
| 554 |
+
f" {negative_prompt_embeds.shape}."
|
| 555 |
+
)
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def prepare_latents(
|
| 559 |
+
self,
|
| 560 |
+
batch_size,
|
| 561 |
+
num_channels_latents,
|
| 562 |
+
height,
|
| 563 |
+
width,
|
| 564 |
+
video_length,
|
| 565 |
+
dtype,
|
| 566 |
+
device,
|
| 567 |
+
generator,
|
| 568 |
+
latents=None,
|
| 569 |
+
):
|
| 570 |
+
shape = (
|
| 571 |
+
batch_size,
|
| 572 |
+
num_channels_latents,
|
| 573 |
+
video_length,
|
| 574 |
+
int(height) // self.vae_scale_factor,
|
| 575 |
+
int(width) // self.vae_scale_factor,
|
| 576 |
+
)
|
| 577 |
+
if isinstance(generator, list) and len(generator) != batch_size:
|
| 578 |
+
raise ValueError(
|
| 579 |
+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
| 580 |
+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
if latents is None:
|
| 584 |
+
latents = randn_tensor(
|
| 585 |
+
shape, generator=generator, device=device, dtype=dtype
|
| 586 |
+
)
|
| 587 |
+
else:
|
| 588 |
+
latents = latents.to(device)
|
| 589 |
+
|
| 590 |
+
# Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
|
| 591 |
+
if hasattr(self.scheduler, "init_noise_sigma"):
|
| 592 |
+
# scale the initial noise by the standard deviation required by the scheduler
|
| 593 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 594 |
+
return latents
|
| 595 |
+
|
| 596 |
+
# Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
|
| 597 |
+
def get_guidance_scale_embedding(
|
| 598 |
+
self,
|
| 599 |
+
w: torch.Tensor,
|
| 600 |
+
embedding_dim: int = 512,
|
| 601 |
+
dtype: torch.dtype = torch.float32,
|
| 602 |
+
) -> torch.Tensor:
|
| 603 |
+
"""
|
| 604 |
+
See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
|
| 605 |
+
|
| 606 |
+
Args:
|
| 607 |
+
w (`torch.Tensor`):
|
| 608 |
+
Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
|
| 609 |
+
embedding_dim (`int`, *optional*, defaults to 512):
|
| 610 |
+
Dimension of the embeddings to generate.
|
| 611 |
+
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
|
| 612 |
+
Data type of the generated embeddings.
|
| 613 |
+
|
| 614 |
+
Returns:
|
| 615 |
+
`torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
|
| 616 |
+
"""
|
| 617 |
+
assert len(w.shape) == 1
|
| 618 |
+
w = w * 1000.0
|
| 619 |
+
|
| 620 |
+
half_dim = embedding_dim // 2
|
| 621 |
+
emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
|
| 622 |
+
emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
|
| 623 |
+
emb = w.to(dtype)[:, None] * emb[None, :]
|
| 624 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
| 625 |
+
if embedding_dim % 2 == 1: # zero pad
|
| 626 |
+
emb = torch.nn.functional.pad(emb, (0, 1))
|
| 627 |
+
assert emb.shape == (w.shape[0], embedding_dim)
|
| 628 |
+
return emb
|
| 629 |
+
|
| 630 |
+
    @property
    def guidance_scale(self):
        # CFG weight for the current run; set by `__call__` before denoising.
        return self._guidance_scale
|
| 633 |
+
|
| 634 |
+
    @property
    def guidance_rescale(self):
        # Rescale factor consumed by `rescale_noise_cfg` (0 disables rescaling).
        return self._guidance_rescale
|
| 637 |
+
|
| 638 |
+
    @property
    def clip_skip(self):
        # Number of final CLIP layers to skip when encoding prompts; None = use last.
        return self._clip_skip
|
| 641 |
+
|
| 642 |
+
    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
        # CFG is active whenever the configured scale exceeds 1 (the
        # time_cond_proj_dim condition from upstream diffusers is disabled here).
        return self._guidance_scale > 1
|
| 649 |
+
|
| 650 |
+
    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to attention processors; set by `__call__`.
        return self._cross_attention_kwargs
|
| 653 |
+
|
| 654 |
+
    @property
    def num_timesteps(self):
        # Length of the timestep schedule used for the current run.
        return self._num_timesteps
|
| 657 |
+
|
| 658 |
+
    @property
    def interrupt(self):
        # Cooperative cancellation flag checked inside the denoising loop.
        return self._interrupt
|
| 661 |
+
|
| 662 |
+
@torch.no_grad()
|
| 663 |
+
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
| 664 |
+
def __call__(
|
| 665 |
+
self,
|
| 666 |
+
prompt: Union[str, List[str]],
|
| 667 |
+
height: int,
|
| 668 |
+
width: int,
|
| 669 |
+
video_length: int,
|
| 670 |
+
data_type: str = "video",
|
| 671 |
+
num_inference_steps: int = 50,
|
| 672 |
+
timesteps: List[int] = None,
|
| 673 |
+
sigmas: List[float] = None,
|
| 674 |
+
guidance_scale: float = 7.5,
|
| 675 |
+
negative_prompt: Optional[Union[str, List[str]]] = None,
|
| 676 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 677 |
+
eta: float = 0.0,
|
| 678 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 679 |
+
latents: Optional[torch.Tensor] = None,
|
| 680 |
+
prompt_embeds: Optional[torch.Tensor] = None,
|
| 681 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 682 |
+
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
| 683 |
+
negative_attention_mask: Optional[torch.Tensor] = None,
|
| 684 |
+
output_type: Optional[str] = "pil",
|
| 685 |
+
return_dict: bool = True,
|
| 686 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
| 687 |
+
guidance_rescale: float = 0.0,
|
| 688 |
+
clip_skip: Optional[int] = None,
|
| 689 |
+
callback_on_step_end: Optional[
|
| 690 |
+
Union[
|
| 691 |
+
Callable[[int, int, Dict], None],
|
| 692 |
+
PipelineCallback,
|
| 693 |
+
MultiPipelineCallbacks,
|
| 694 |
+
]
|
| 695 |
+
] = None,
|
| 696 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 697 |
+
freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
|
| 698 |
+
vae_ver: str = "88-4c-sd",
|
| 699 |
+
enable_tiling: bool = False,
|
| 700 |
+
n_tokens: Optional[int] = None,
|
| 701 |
+
embedded_guidance_scale: Optional[float] = None,
|
| 702 |
+
**kwargs,
|
| 703 |
+
):
|
| 704 |
+
r"""
|
| 705 |
+
The call function to the pipeline for generation.
|
| 706 |
+
|
| 707 |
+
Args:
|
| 708 |
+
prompt (`str` or `List[str]`):
|
| 709 |
+
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
|
| 710 |
+
height (`int`):
|
| 711 |
+
The height in pixels of the generated image.
|
| 712 |
+
width (`int`):
|
| 713 |
+
The width in pixels of the generated image.
|
| 714 |
+
video_length (`int`):
|
| 715 |
+
The number of frames in the generated video.
|
| 716 |
+
num_inference_steps (`int`, *optional*, defaults to 50):
|
| 717 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 718 |
+
expense of slower inference.
|
| 719 |
+
timesteps (`List[int]`, *optional*):
|
| 720 |
+
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
| 721 |
+
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
| 722 |
+
passed will be used. Must be in descending order.
|
| 723 |
+
sigmas (`List[float]`, *optional*):
|
| 724 |
+
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
| 725 |
+
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
| 726 |
+
will be used.
|
| 727 |
+
guidance_scale (`float`, *optional*, defaults to 7.5):
|
| 728 |
+
A higher guidance scale value encourages the model to generate images closely linked to the text
|
| 729 |
+
`prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
|
| 730 |
+
negative_prompt (`str` or `List[str]`, *optional*):
|
| 731 |
+
The prompt or prompts to guide what to not include in image generation. If not defined, you need to
|
| 732 |
+
pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
|
| 733 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 734 |
+
The number of images to generate per prompt.
|
| 735 |
+
eta (`float`, *optional*, defaults to 0.0):
|
| 736 |
+
Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
|
| 737 |
+
to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
|
| 738 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 739 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 740 |
+
generation deterministic.
|
| 741 |
+
latents (`torch.Tensor`, *optional*):
|
| 742 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 743 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 744 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 745 |
+
prompt_embeds (`torch.Tensor`, *optional*):
|
| 746 |
+
Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
|
| 747 |
+
provided, text embeddings are generated from the `prompt` input argument.
|
| 748 |
+
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
| 749 |
+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
|
| 750 |
+
not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
|
| 751 |
+
|
| 752 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 753 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 754 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 755 |
+
Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
|
| 756 |
+
plain tuple.
|
| 757 |
+
cross_attention_kwargs (`dict`, *optional*):
|
| 758 |
+
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
|
| 759 |
+
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 760 |
+
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
| 761 |
+
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
|
| 762 |
+
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
|
| 763 |
+
using zero terminal SNR.
|
| 764 |
+
clip_skip (`int`, *optional*):
|
| 765 |
+
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
|
| 766 |
+
the output of the pre-final layer will be used for computing the prompt embeddings.
|
| 767 |
+
callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
|
| 768 |
+
A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
|
| 769 |
+
each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
|
| 770 |
+
DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
|
| 771 |
+
list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
|
| 772 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 773 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 774 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 775 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 776 |
+
|
| 777 |
+
Examples:
|
| 778 |
+
|
| 779 |
+
Returns:
|
| 780 |
+
[`~HunyuanVideoPipelineOutput`] or `tuple`:
|
| 781 |
+
If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
|
| 782 |
+
otherwise a `tuple` is returned where the first element is a list with the generated images and the
|
| 783 |
+
second element is a list of `bool`s indicating whether the corresponding generated image contains
|
| 784 |
+
"not-safe-for-work" (nsfw) content.
|
| 785 |
+
"""
|
| 786 |
+
callback = kwargs.pop("callback", None)
|
| 787 |
+
callback_steps = kwargs.pop("callback_steps", None)
|
| 788 |
+
|
| 789 |
+
if callback is not None:
|
| 790 |
+
deprecate(
|
| 791 |
+
"callback",
|
| 792 |
+
"1.0.0",
|
| 793 |
+
"Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 794 |
+
)
|
| 795 |
+
if callback_steps is not None:
|
| 796 |
+
deprecate(
|
| 797 |
+
"callback_steps",
|
| 798 |
+
"1.0.0",
|
| 799 |
+
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
|
| 800 |
+
)
|
| 801 |
+
|
| 802 |
+
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
| 803 |
+
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
| 804 |
+
|
| 805 |
+
# 0. Default height and width to unet
|
| 806 |
+
# height = height or self.transformer.config.sample_size * self.vae_scale_factor
|
| 807 |
+
# width = width or self.transformer.config.sample_size * self.vae_scale_factor
|
| 808 |
+
# to deal with lora scaling and other possible forward hooks
|
| 809 |
+
|
| 810 |
+
# 1. Check inputs. Raise error if not correct
|
| 811 |
+
self.check_inputs(
|
| 812 |
+
prompt,
|
| 813 |
+
height,
|
| 814 |
+
width,
|
| 815 |
+
video_length,
|
| 816 |
+
callback_steps,
|
| 817 |
+
negative_prompt,
|
| 818 |
+
prompt_embeds,
|
| 819 |
+
negative_prompt_embeds,
|
| 820 |
+
callback_on_step_end_tensor_inputs,
|
| 821 |
+
vae_ver=vae_ver,
|
| 822 |
+
)
|
| 823 |
+
|
| 824 |
+
self._guidance_scale = guidance_scale
|
| 825 |
+
self._guidance_rescale = guidance_rescale
|
| 826 |
+
self._clip_skip = clip_skip
|
| 827 |
+
self._cross_attention_kwargs = cross_attention_kwargs
|
| 828 |
+
self._interrupt = False
|
| 829 |
+
|
| 830 |
+
# 2. Define call parameters
|
| 831 |
+
if prompt is not None and isinstance(prompt, str):
|
| 832 |
+
batch_size = 1
|
| 833 |
+
elif prompt is not None and isinstance(prompt, list):
|
| 834 |
+
batch_size = len(prompt)
|
| 835 |
+
else:
|
| 836 |
+
batch_size = prompt_embeds.shape[0]
|
| 837 |
+
|
| 838 |
+
device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
|
| 839 |
+
|
| 840 |
+
# 3. Encode input prompt
|
| 841 |
+
lora_scale = (
|
| 842 |
+
self.cross_attention_kwargs.get("scale", None)
|
| 843 |
+
if self.cross_attention_kwargs is not None
|
| 844 |
+
else None
|
| 845 |
+
)
|
| 846 |
+
|
| 847 |
+
(
|
| 848 |
+
prompt_embeds,
|
| 849 |
+
negative_prompt_embeds,
|
| 850 |
+
prompt_mask,
|
| 851 |
+
negative_prompt_mask,
|
| 852 |
+
) = self.encode_prompt(
|
| 853 |
+
prompt,
|
| 854 |
+
device,
|
| 855 |
+
num_videos_per_prompt,
|
| 856 |
+
self.do_classifier_free_guidance,
|
| 857 |
+
negative_prompt,
|
| 858 |
+
prompt_embeds=prompt_embeds,
|
| 859 |
+
attention_mask=attention_mask,
|
| 860 |
+
negative_prompt_embeds=negative_prompt_embeds,
|
| 861 |
+
negative_attention_mask=negative_attention_mask,
|
| 862 |
+
lora_scale=lora_scale,
|
| 863 |
+
clip_skip=self.clip_skip,
|
| 864 |
+
data_type=data_type,
|
| 865 |
+
)
|
| 866 |
+
if self.text_encoder_2 is not None:
|
| 867 |
+
(
|
| 868 |
+
prompt_embeds_2,
|
| 869 |
+
negative_prompt_embeds_2,
|
| 870 |
+
prompt_mask_2,
|
| 871 |
+
negative_prompt_mask_2,
|
| 872 |
+
) = self.encode_prompt(
|
| 873 |
+
prompt,
|
| 874 |
+
device,
|
| 875 |
+
num_videos_per_prompt,
|
| 876 |
+
self.do_classifier_free_guidance,
|
| 877 |
+
negative_prompt,
|
| 878 |
+
prompt_embeds=None,
|
| 879 |
+
attention_mask=None,
|
| 880 |
+
negative_prompt_embeds=None,
|
| 881 |
+
negative_attention_mask=None,
|
| 882 |
+
lora_scale=lora_scale,
|
| 883 |
+
clip_skip=self.clip_skip,
|
| 884 |
+
text_encoder=self.text_encoder_2,
|
| 885 |
+
data_type=data_type,
|
| 886 |
+
)
|
| 887 |
+
else:
|
| 888 |
+
prompt_embeds_2 = None
|
| 889 |
+
negative_prompt_embeds_2 = None
|
| 890 |
+
prompt_mask_2 = None
|
| 891 |
+
negative_prompt_mask_2 = None
|
| 892 |
+
|
| 893 |
+
# For classifier free guidance, we need to do two forward passes.
|
| 894 |
+
# Here we concatenate the unconditional and text embeddings into a single batch
|
| 895 |
+
# to avoid doing two forward passes
|
| 896 |
+
if self.do_classifier_free_guidance:
|
| 897 |
+
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
|
| 898 |
+
if prompt_mask is not None:
|
| 899 |
+
prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
|
| 900 |
+
if prompt_embeds_2 is not None:
|
| 901 |
+
prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
|
| 902 |
+
if prompt_mask_2 is not None:
|
| 903 |
+
prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
# 4. Prepare timesteps
|
| 907 |
+
extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
|
| 908 |
+
self.scheduler.set_timesteps, {"n_tokens": n_tokens}
|
| 909 |
+
)
|
| 910 |
+
timesteps, num_inference_steps = retrieve_timesteps(
|
| 911 |
+
self.scheduler,
|
| 912 |
+
num_inference_steps,
|
| 913 |
+
device,
|
| 914 |
+
timesteps,
|
| 915 |
+
sigmas,
|
| 916 |
+
**extra_set_timesteps_kwargs,
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
if "884" in vae_ver:
|
| 920 |
+
video_length = (video_length - 1) // 4 + 1
|
| 921 |
+
elif "888" in vae_ver:
|
| 922 |
+
video_length = (video_length - 1) // 8 + 1
|
| 923 |
+
else:
|
| 924 |
+
video_length = video_length
|
| 925 |
+
|
| 926 |
+
# 5. Prepare latent variables
|
| 927 |
+
num_channels_latents = self.transformer.config.in_channels
|
| 928 |
+
latents = self.prepare_latents(
|
| 929 |
+
batch_size * num_videos_per_prompt,
|
| 930 |
+
num_channels_latents,
|
| 931 |
+
height,
|
| 932 |
+
width,
|
| 933 |
+
video_length,
|
| 934 |
+
prompt_embeds.dtype,
|
| 935 |
+
device,
|
| 936 |
+
generator,
|
| 937 |
+
latents,
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
| 941 |
+
extra_step_kwargs = self.prepare_extra_func_kwargs(
|
| 942 |
+
self.scheduler.step,
|
| 943 |
+
{"generator": generator, "eta": eta},
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
target_dtype = PRECISION_TO_TYPE[self.args.precision]
|
| 947 |
+
autocast_enabled = (
|
| 948 |
+
target_dtype != torch.float32
|
| 949 |
+
) and not self.args.disable_autocast
|
| 950 |
+
vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
|
| 951 |
+
vae_autocast_enabled = (
|
| 952 |
+
vae_dtype != torch.float32
|
| 953 |
+
) and not self.args.disable_autocast
|
| 954 |
+
|
| 955 |
+
# 7. Denoising loop
|
| 956 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 957 |
+
self._num_timesteps = len(timesteps)
|
| 958 |
+
|
| 959 |
+
# if is_progress_bar:
|
| 960 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 961 |
+
for i, t in enumerate(timesteps):
|
| 962 |
+
if self.interrupt:
|
| 963 |
+
continue
|
| 964 |
+
|
| 965 |
+
# expand the latents if we are doing classifier free guidance
|
| 966 |
+
latent_model_input = (
|
| 967 |
+
torch.cat([latents] * 2)
|
| 968 |
+
if self.do_classifier_free_guidance
|
| 969 |
+
else latents
|
| 970 |
+
)
|
| 971 |
+
latent_model_input = self.scheduler.scale_model_input(
|
| 972 |
+
latent_model_input, t
|
| 973 |
+
)
|
| 974 |
+
|
| 975 |
+
t_expand = t.repeat(latent_model_input.shape[0])
|
| 976 |
+
guidance_expand = (
|
| 977 |
+
torch.tensor(
|
| 978 |
+
[embedded_guidance_scale] * latent_model_input.shape[0],
|
| 979 |
+
dtype=torch.float32,
|
| 980 |
+
device=device,
|
| 981 |
+
).to(target_dtype)
|
| 982 |
+
* 1000.0
|
| 983 |
+
if embedded_guidance_scale is not None
|
| 984 |
+
else None
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
# predict the noise residual
|
| 988 |
+
with torch.autocast(
|
| 989 |
+
device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
|
| 990 |
+
):
|
| 991 |
+
noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
| 992 |
+
latent_model_input, # [2, 16, 33, 24, 42]
|
| 993 |
+
t_expand, # [2]
|
| 994 |
+
text_states=prompt_embeds, # [2, 256, 4096]
|
| 995 |
+
text_mask=prompt_mask, # [2, 256]
|
| 996 |
+
text_states_2=prompt_embeds_2, # [2, 768]
|
| 997 |
+
freqs_cos=freqs_cis[0], # [seqlen, head_dim]
|
| 998 |
+
freqs_sin=freqs_cis[1], # [seqlen, head_dim]
|
| 999 |
+
guidance=guidance_expand,
|
| 1000 |
+
return_dict=True,
|
| 1001 |
+
)[
|
| 1002 |
+
"x"
|
| 1003 |
+
]
|
| 1004 |
+
|
| 1005 |
+
# perform guidance
|
| 1006 |
+
if self.do_classifier_free_guidance:
|
| 1007 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 1008 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (
|
| 1009 |
+
noise_pred_text - noise_pred_uncond
|
| 1010 |
+
)
|
| 1011 |
+
|
| 1012 |
+
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
|
| 1013 |
+
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 1014 |
+
noise_pred = rescale_noise_cfg(
|
| 1015 |
+
noise_pred,
|
| 1016 |
+
noise_pred_text,
|
| 1017 |
+
guidance_rescale=self.guidance_rescale,
|
| 1018 |
+
)
|
| 1019 |
+
|
| 1020 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 1021 |
+
latents = self.scheduler.step(
|
| 1022 |
+
noise_pred, t, latents, **extra_step_kwargs, return_dict=False
|
| 1023 |
+
)[0]
|
| 1024 |
+
|
| 1025 |
+
if callback_on_step_end is not None:
|
| 1026 |
+
callback_kwargs = {}
|
| 1027 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 1028 |
+
callback_kwargs[k] = locals()[k]
|
| 1029 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 1030 |
+
|
| 1031 |
+
latents = callback_outputs.pop("latents", latents)
|
| 1032 |
+
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
| 1033 |
+
negative_prompt_embeds = callback_outputs.pop(
|
| 1034 |
+
"negative_prompt_embeds", negative_prompt_embeds
|
| 1035 |
+
)
|
| 1036 |
+
|
| 1037 |
+
# call the callback, if provided
|
| 1038 |
+
if i == len(timesteps) - 1 or (
|
| 1039 |
+
(i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
|
| 1040 |
+
):
|
| 1041 |
+
if progress_bar is not None:
|
| 1042 |
+
progress_bar.update()
|
| 1043 |
+
if callback is not None and i % callback_steps == 0:
|
| 1044 |
+
step_idx = i // getattr(self.scheduler, "order", 1)
|
| 1045 |
+
callback(step_idx, t, latents)
|
| 1046 |
+
|
| 1047 |
+
if not output_type == "latent":
|
| 1048 |
+
expand_temporal_dim = False
|
| 1049 |
+
if len(latents.shape) == 4:
|
| 1050 |
+
if isinstance(self.vae, AutoencoderKLCausal3D):
|
| 1051 |
+
latents = latents.unsqueeze(2)
|
| 1052 |
+
expand_temporal_dim = True
|
| 1053 |
+
elif len(latents.shape) == 5:
|
| 1054 |
+
pass
|
| 1055 |
+
else:
|
| 1056 |
+
raise ValueError(
|
| 1057 |
+
f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
|
| 1058 |
+
)
|
| 1059 |
+
|
| 1060 |
+
if (
|
| 1061 |
+
hasattr(self.vae.config, "shift_factor")
|
| 1062 |
+
and self.vae.config.shift_factor
|
| 1063 |
+
):
|
| 1064 |
+
latents = (
|
| 1065 |
+
latents / self.vae.config.scaling_factor
|
| 1066 |
+
+ self.vae.config.shift_factor
|
| 1067 |
+
)
|
| 1068 |
+
else:
|
| 1069 |
+
latents = latents / self.vae.config.scaling_factor
|
| 1070 |
+
|
| 1071 |
+
with torch.autocast(
|
| 1072 |
+
device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
|
| 1073 |
+
):
|
| 1074 |
+
if enable_tiling:
|
| 1075 |
+
self.vae.enable_tiling()
|
| 1076 |
+
image = self.vae.decode(
|
| 1077 |
+
latents, return_dict=False, generator=generator
|
| 1078 |
+
)[0]
|
| 1079 |
+
else:
|
| 1080 |
+
image = self.vae.decode(
|
| 1081 |
+
latents, return_dict=False, generator=generator
|
| 1082 |
+
)[0]
|
| 1083 |
+
|
| 1084 |
+
if expand_temporal_dim or image.shape[2] == 1:
|
| 1085 |
+
image = image.squeeze(2)
|
| 1086 |
+
|
| 1087 |
+
else:
|
| 1088 |
+
image = latents
|
| 1089 |
+
|
| 1090 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
| 1091 |
+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
|
| 1092 |
+
image = image.cpu().float()
|
| 1093 |
+
|
| 1094 |
+
# Offload all models
|
| 1095 |
+
self.maybe_free_model_hooks()
|
| 1096 |
+
|
| 1097 |
+
if not return_dict:
|
| 1098 |
+
return image
|
| 1099 |
+
|
| 1100 |
+
return HunyuanVideoPipelineOutput(videos=image)
|
hunyuan_model/posemb_layers.py
ADDED
|
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from typing import Union, Tuple, List
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def _to_tuple(x, dim=2):
|
| 6 |
+
if isinstance(x, int):
|
| 7 |
+
return (x,) * dim
|
| 8 |
+
elif len(x) == dim:
|
| 9 |
+
return x
|
| 10 |
+
else:
|
| 11 |
+
raise ValueError(f"Expected length {dim} or int, but got {x}")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_meshgrid_nd(start, *args, dim=2):
    """Build an n-dimensional float meshgrid with numpy-linspace-like semantics.

    Call patterns (mirroring ``np.linspace(start, stop, num, endpoint=False)``
    per axis; ints are broadcast to ``dim``-tuples):
        get_meshgrid_nd(num, dim=d)                  -> axes run 0..num (exclusive)
        get_meshgrid_nd(start, stop, dim=d)          -> step of 1 per axis
        get_meshgrid_nd(start, stop, num, dim=d)     -> explicit sample count

    Args:
        start: See call patterns above; int or ``dim``-tuple.
        *args: Optional stop / num values, each int or ``dim``-tuple.
        dim: Number of grid axes. Defaults to 2.

    Returns:
        torch.Tensor: Stacked grid of shape ``[dim, *axis_lengths]`` (float32).

    Raises:
        ValueError: If more than two positional args are given.
    """
    n_extra = len(args)
    if n_extra == 0:
        # Only a size was given: axes run from 0 up to (but excluding) num.
        counts = _to_tuple(start, dim=dim)
        lows = (0,) * dim
        highs = counts
    elif n_extra == 1:
        # start/stop with implicit step of 1.
        lows = _to_tuple(start, dim=dim)
        highs = _to_tuple(args[0], dim=dim)
        counts = [highs[i] - lows[i] for i in range(dim)]
    elif n_extra == 2:
        # Fully explicit start / stop / num per axis.
        lows = _to_tuple(start, dim=dim)    # Left-Top, e.g. 12,0
        highs = _to_tuple(args[0], dim=dim) # Right-Bottom, e.g. 20,32
        counts = _to_tuple(args[1], dim=dim)  # Target size, e.g. 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # torch analogue of np.linspace(low, high, n, endpoint=False):
    # sample n+1 points then drop the endpoint.
    axis_points = [
        torch.linspace(lows[i], highs[i], counts[i] + 1, dtype=torch.float32)[: counts[i]]
        for i in range(dim)
    ]
    mesh = torch.meshgrid(*axis_points, indexing="ij")  # dim x [W, H, D]
    return torch.stack(mesh, dim=0)  # [dim, W, H, D]
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
#################################################################################
|
| 60 |
+
# Rotary Positional Embedding Functions #
|
| 61 |
+
#################################################################################
|
| 62 |
+
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """Reshape a RoPE frequency tensor so it broadcasts against ``x``.

    All axes of ``x`` except the sequence axis and the last (feature) axis are
    collapsed to size 1 in the returned view.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis: Either a ``(cos, sin)`` tuple of real tensors or a single
            complex-valued tensor, shaped ``[seq, feat]``.
        x: Target tensor the frequencies will be combined with.
        head_first: If True the sequence axis of ``x`` is ``ndim - 2``;
            otherwise it is axis 1.

    Returns:
        A reshaped tensor (or ``(cos, sin)`` pair) broadcastable with ``x``.

    Raises:
        AssertionError: If the frequency shape does not match ``x``.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    # The two axes of x the [seq, feat] frequency tensor must line up with.
    if head_first:
        seq_axis = ndim - 2
        expected = (x.shape[-2], x.shape[-1])
    else:
        seq_axis = 1
        expected = (x.shape[1], x.shape[-1])
    bshape = [d if i == seq_axis or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]

    if isinstance(freqs_cis, tuple):
        # freqs_cis is a (cos, sin) pair in real space.
        assert freqs_cis[0].shape == expected, f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
        return freqs_cis[0].view(*bshape), freqs_cis[1].view(*bshape)

    # freqs_cis holds values in complex space.
    assert freqs_cis.shape == expected, f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
    return freqs_cis.view(*bshape)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def rotate_half(x):
    """Rotate consecutive pairs ``(a, b) -> (-b, a)`` along the last dimension.

    Treats the last dimension as interleaved (real, imag) pairs, returning the
    90-degree-rotated counterpart used by the real-valued RoPE formulation.
    NOTE: the trailing ``flatten(3)`` assumes a 4-D input, e.g. [B, S, H, D].
    """
    pairs = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
    real, imag = pairs.unbind(-1)
    rotated = torch.stack((-imag, real), dim=-1)
    return rotated.flatten(3)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Apply rotary position embeddings to query and key tensors.

    Supports two frequency formats: a real-valued ``(cos, sin)`` pair, or a
    single complex tensor. Computation is done in float32 and the results are
    cast back to the input dtypes.

    Args:
        xq: Query tensor, e.g. [B, S, H, D].
        xk: Key tensor, e.g. [B, S, H, D].
        freqs_cis: Precomputed frequencies — ``(cos, sin)`` tuple or complex tensor.
        head_first: Head dimension first (except batch dim) or not; forwarded
            to ``reshape_for_broadcast``.

    Returns:
        Tuple of rotated (query, key) tensors with the original dtypes.
    """
    if isinstance(freqs_cis, tuple):
        # Real-valued path:  out = x * cos + rotate_half(x) * sin
        # (equivalent to complex multiply: real*cos - imag*sin, imag*cos + real*sin)
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos = cos.to(xq.device)
        sin = sin.to(xq.device)
        xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
        xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
        return xq_out, xk_out

    # Complex path: view_as_complex packs [..., D/2, 2] (real) into [..., D/2] (complex).
    xq_c = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
    rot = reshape_for_broadcast(freqs_cis, xq_c, head_first).to(xq.device)  # [S, D//2] -> [1, S, 1, D//2]
    # view_as_real expands [..., D/2] (complex) back to [..., D/2, 2] (real).
    xq_out = torch.view_as_real(xq_c * rot).flatten(3).type_as(xq)
    xk_c = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
    xk_out = torch.view_as_real(xk_c * rot).flatten(3).type_as(xk)
    return xq_out, xk_out
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """n-D RoPE: precompute rotary frequencies for tokens with n-D structure.

    Args:
        rope_dim_list: Per-axis feature widths; ``len(rope_dim_list)`` is the
            number of grid axes and ``sum(rope_dim_list)`` should equal the
            attention head_dim.
        start: Grid specification (see ``get_meshgrid_nd`` — num, start/stop,
            or start/stop/num forms).
        *args: See above.
        theta: Base frequency scaling factor. Defaults to 10000.0.
        use_real: If True return ``(cos, sin)`` real tensors (useful where
            complex64 is unsupported, e.g. TensorRT); otherwise return one
            complex tensor.
        theta_rescale_factor: Per-axis (or scalar) rescale factor for theta.
        interpolation_factor: Per-axis (or scalar) position interpolation factor.

    Returns:
        ``(cos, sin)`` each of shape [prod(grid), sum(rope_dim_list)] when
        ``use_real``, else a complex tensor [prod(grid), sum(rope_dim_list)/2].
    """
    n_axes = len(rope_dim_list)
    grid = get_meshgrid_nd(start, *args, dim=n_axes)  # [3, W, H, D] / [2, W, H]

    def _per_axis(factor, msg):
        # Broadcast a scalar or singleton list to one value per grid axis.
        if isinstance(factor, (int, float)):
            return [factor] * n_axes
        if isinstance(factor, list) and len(factor) == 1:
            return [factor[0]] * n_axes
        assert len(factor) == n_axes, msg
        return factor

    theta_rescale_factor = _per_axis(
        theta_rescale_factor, "len(theta_rescale_factor) should equal to len(rope_dim_list)"
    )
    interpolation_factor = _per_axis(
        interpolation_factor, "len(interpolation_factor) should equal to len(rope_dim_list)"
    )

    # Each axis uses its own slice of the head dimension (1/ndim of the dims).
    embs = [
        get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )  # 2 x [WHD, rope_dim_list[i]]
        for i in range(n_axes)
    ]

    if use_real:
        cos = torch.cat([e[0] for e in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([e[1] for e in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    return torch.cat(embs, dim=1)  # (WHD, D/2)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """1-D RoPE: precompute the complex-exponential (cis) frequency tensor.

    (``cis`` means ``cos + i * sin``, where i is the imaginary unit.)

    Args:
        dim: Feature dimension of the frequency tensor.
        pos: Position indices [S], or an int meaning ``arange(pos)``.
        theta: Base scaling factor for frequency computation. Defaults to 10000.0.
        use_real: If True, return (cos, sin) real tensors instead of a
            complex tensor.
        theta_rescale_factor: NTK-style rescale factor for theta. Defaults to 1.0.
        interpolation_factor: Multiplier applied to positions. Defaults to 1.0.

    Returns:
        Complex tensor [S, D/2] when ``use_real`` is False, otherwise a
        ``(cos, sin)`` pair each of shape [S, D] (pairs repeated along dim 1).
    """
    if isinstance(pos, int):
        pos = torch.arange(pos).float()

    # NTK-aware rescaling (proposed by reddit user bloc97): stretch theta so
    # rotary embeddings extrapolate to longer sequences without fine-tuning.
    if theta_rescale_factor != 1.0:
        theta = theta * theta_rescale_factor ** (dim / (dim - 2))

    inv_freq = 1.0 / (
        theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
    )  # [D/2]
    angles = torch.outer(pos * interpolation_factor, inv_freq)  # [S, D/2]

    if use_real:
        cos = angles.cos().repeat_interleave(2, dim=1)  # [S, D]
        sin = angles.sin().repeat_interleave(2, dim=1)  # [S, D]
        return cos, sin
    # polar(1, angle) = cos(angle) + i*sin(angle); complex64, [S, D/2]
    return torch.polar(torch.ones_like(angles), angles)
|
hunyuan_model/text_encoder.py
ADDED
|
@@ -0,0 +1,710 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import Optional, Tuple, Union
|
| 5 |
+
from copy import deepcopy
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from transformers import (
|
| 10 |
+
CLIPTextModel,
|
| 11 |
+
CLIPTokenizer,
|
| 12 |
+
AutoTokenizer,
|
| 13 |
+
AutoModel,
|
| 14 |
+
CLIPConfig,
|
| 15 |
+
LlamaForCausalLM,
|
| 16 |
+
LlamaConfig,
|
| 17 |
+
)
|
| 18 |
+
from transformers.utils import ModelOutput
|
| 19 |
+
from transformers.models.llama import LlamaModel
|
| 20 |
+
from safetensors.torch import load_file
|
| 21 |
+
from accelerate import init_empty_weights
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# Hugging Face fallbacks used when a local tokenizer directory is not provided.
CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14"
LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers"

# Flattened CLIP-L config used when loading weights from a single safetensors file
# (no config.json next to the weights).
# NOTE(review): the entries after the "# text_config" / "# text_config_dict" markers
# repeat earlier keys with identical values; in a dict literal the later occurrence
# overrides the earlier one, so this flattening is harmless here.
CLIP_CONFIG = {
    "_name_or_path": "clip-vit-large-patch14/",
    "architectures": ["CLIPModel"],
    "initializer_factor": 1.0,
    "logit_scale_init_value": 2.6592,
    "model_type": "clip",
    "projection_dim": 768,
    # "text_config": {
    "_name_or_path": "",
    "add_cross_attention": False,
    "architectures": None,
    "attention_dropout": 0.0,
    "bad_words_ids": None,
    "bos_token_id": 0,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": None,
    "decoder_start_token_id": None,
    "diversity_penalty": 0.0,
    "do_sample": False,
    "dropout": 0.0,
    "early_stopping": False,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": 2,
    "finetuning_task": None,
    "forced_bos_token_id": None,
    "forced_eos_token_id": None,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "is_decoder": False,
    "is_encoder_decoder": False,
    "label2id": {"LABEL_0": 0, "LABEL_1": 1},
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "max_position_embeddings": 77,
    "min_length": 0,
    "model_type": "clip_text_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 12,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_hidden_layers": 12,
    "num_return_sequences": 1,
    "output_attentions": False,
    "output_hidden_states": False,
    "output_scores": False,
    "pad_token_id": 1,
    "prefix": None,
    "problem_type": None,
    "projection_dim": 768,
    "pruned_heads": {},
    "remove_invalid_values": False,
    "repetition_penalty": 1.0,
    "return_dict": True,
    "return_dict_in_generate": False,
    "sep_token_id": None,
    "task_specific_params": None,
    "temperature": 1.0,
    "tie_encoder_decoder": False,
    "tie_word_embeddings": True,
    "tokenizer_class": None,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": None,
    "torchscript": False,
    "transformers_version": "4.16.0.dev0",
    "use_bfloat16": False,
    "vocab_size": 49408,
    # },
    # "text_config_dict": {
    "hidden_size": 768,
    "intermediate_size": 3072,
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "projection_dim": 768,
    # },
    # "torch_dtype": "float32",
    # "transformers_version": null
}

# LLaMA config used when loading LLM weights from a single safetensors file.
LLAMA_CONFIG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}

# When using decoder-only models, we must provide a prompt template to instruct the text encoder
# on how to generate the text.
# --------------------------------------------------------------------
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)

NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"

# Maps a template name to its template text plus "crop_start": the number of
# instruction tokens to crop from the encoder output (see TextEncoder.encode).
PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-video": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def use_default(value, default):
    """Return *value* unless it is None, in which case return *default*."""
    if value is None:
        return default
    return value
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
    """Load the CLIP-L text encoder.

    Accepts either a Hugging Face model directory (configs bundled) or a single
    safetensors weights file, in which case the model is instantiated from the
    baked-in ``CLIP_CONFIG`` and the weights are loaded into it.
    """
    if os.path.isdir(text_encoder_path):
        # Directory layout: configs live next to the weights.
        return CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)

    # Single-file layout: create an empty-weight model from the static config,
    # then materialize the weights from the safetensors file.
    config = CLIPConfig(**CLIP_CONFIG)
    with init_empty_weights():
        model = CLIPTextModel._from_config(config, torch_dtype=dtype)

    weights = load_file(text_encoder_path)
    model.load_state_dict(weights, strict=True, assign=True)
    return model
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def load_clip_l_tokenizer(tokenizer_path: str):
    """Load the CLIP tokenizer from a local directory, falling back to the hub."""
    if not os.path.isdir(tokenizer_path):
        # No local directory: fetch the canonical tokenizer from Hugging Face.
        logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}")
        return CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77)
    return CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
    """Load the LLM text encoder.

    Accepts either a Hugging Face model directory or a single safetensors file;
    in the latter case the model is built from the baked-in ``LLAMA_CONFIG``.
    """
    if os.path.isdir(text_encoder_path):
        # Directory layout: configs ship with the weights.
        return AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)

    # Single-file layout: empty-weight model from static config, then load weights.
    config = LlamaConfig(**LLAMA_CONFIG)
    with init_empty_weights():
        model = LlamaForCausalLM._from_config(config, torch_dtype=dtype)

    state_dict = load_file(text_encoder_path)

    # ComfyUI checkpoints may carry a non-weight "tokenizer" entry; drop it.
    state_dict.pop("tokenizer", None)

    model.load_state_dict(state_dict, strict=True, assign=True)
    return model
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def load_llm_tokenizer(tokenizer_path: str, padding_side="right"):
|
| 232 |
+
if os.path.isdir(tokenizer_path):
|
| 233 |
+
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
| 234 |
+
else:
|
| 235 |
+
# load from Hugging Face
|
| 236 |
+
logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}")
|
| 237 |
+
tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side)
|
| 238 |
+
|
| 239 |
+
return tokenizer
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def load_text_encoder(
    text_encoder_type: str,
    text_encoder_path: str,
    text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
):
    """Load a text encoder ("clipL" or "llm") and alias its final layer norm.

    The final layer norm is exposed as ``text_encoder.final_layer_norm`` so
    callers can apply it uniformly regardless of encoder architecture.

    Returns:
        Tuple of (model, text_encoder_path). The model is frozen
        (``requires_grad_(False)``) and cast to ``text_encoder_dtype`` when given.

    Raises:
        ValueError: if ``text_encoder_type`` is not "clipL" or "llm".
    """
    logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")

    # reduce peak memory usage by specifying the dtype of the model
    dtype = text_encoder_dtype
    if text_encoder_type == "clipL":
        text_encoder = load_clip_l(text_encoder_path, dtype=dtype)
        # CLIP keeps the final norm under text_model; surface it at the top level.
        text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
    elif text_encoder_type == "llm":
        text_encoder = load_llm(text_encoder_path, dtype=dtype)
        # The norm attribute lives in different places depending on how the
        # model was constructed (directory vs single-file load path).
        if hasattr(text_encoder, "norm"):
            text_encoder.final_layer_norm = text_encoder.norm  # by from_pretrained
        else:
            text_encoder.final_layer_norm = text_encoder.model.norm  # by _from_config
    else:
        raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
    # from_pretrained will ensure that the model is in eval mode.

    if dtype is not None:
        text_encoder = text_encoder.to(dtype=dtype)

    text_encoder.requires_grad_(False)

    logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
    return text_encoder, text_encoder_path
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
    """Dispatch tokenizer loading by type.

    Returns:
        Tuple of (tokenizer, tokenizer_path).

    Raises:
        ValueError: for an unsupported ``tokenizer_type``.
    """
    logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")

    if tokenizer_type == "clipL":
        loaded = load_clip_l_tokenizer(tokenizer_path)
    elif tokenizer_type == "llm":
        loaded = load_llm_tokenizer(tokenizer_path, padding_side=padding_side)
    else:
        raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")

    return loaded, tokenizer_path
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
@dataclass
class TextEncoderModelOutput(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
        hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
            List of decoded texts.
    """

    # Selected hidden state (last layer or skip-layer), after any template cropping.
    hidden_state: torch.FloatTensor = None
    # Attention mask aligned with hidden_state after cropping; None when unused.
    attention_mask: Optional[torch.LongTensor] = None
    # All per-layer hidden states; populated only when output_hidden_states=True.
    hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
    # Decoded texts; populated only when return_texts=True.
    text_outputs: Optional[list] = None
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class TextEncoder(nn.Module):
    """A text encoder (CLIP-L or LLM) bundled with its tokenizer and optional
    HunyuanVideo-style prompt templating.

    When a template is configured, the raw prompt is embedded into an
    instruction template before tokenization, and the instruction tokens are
    cropped from the encoder output afterwards (the template's ``crop_start``).
    """

    def __init__(
        self,
        text_encoder_type: str,
        max_length: int,
        text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
        text_encoder_path: Optional[str] = None,
        tokenizer_type: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        output_key: Optional[str] = None,
        use_attention_mask: bool = True,
        input_max_length: Optional[int] = None,
        prompt_template: Optional[dict] = None,
        prompt_template_video: Optional[dict] = None,
        hidden_state_skip_layer: Optional[int] = None,
        apply_final_norm: bool = False,
        reproduce: bool = False,
    ):
        super().__init__()
        self.text_encoder_type = text_encoder_type
        self.max_length = max_length
        self.model_path = text_encoder_path
        # Tokenizer settings default to the encoder's own type/path.
        self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
        self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
        self.use_attention_mask = use_attention_mask
        if prompt_template_video is not None:
            assert use_attention_mask is True, "Attention mask is True required when training videos."
        self.input_max_length = input_max_length if input_max_length is not None else max_length
        self.prompt_template = prompt_template
        self.prompt_template_video = prompt_template_video
        self.hidden_state_skip_layer = hidden_state_skip_layer
        self.apply_final_norm = apply_final_norm
        self.reproduce = reproduce

        self.use_template = self.prompt_template is not None
        if self.use_template:
            assert (
                isinstance(self.prompt_template, dict) and "template" in self.prompt_template
            ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
            assert "{}" in str(self.prompt_template["template"]), (
                "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template['template']}"
            )

        self.use_video_template = self.prompt_template_video is not None
        if self.use_video_template:
            # (Removed a redundant inner `is not None` re-check of
            # prompt_template_video that was implied by use_video_template.)
            assert (
                isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
            ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
            assert "{}" in str(self.prompt_template_video["template"]), (
                "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template_video['template']}"
            )

        # Which field of the model output carries the embedding to return.
        if "t5" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        elif "clip" in text_encoder_type:
            self.output_key = output_key or "pooler_output"
        elif "llm" in text_encoder_type or "glm" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        else:
            raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")

        self.model, self.model_path = load_text_encoder(
            text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
        )
        self.dtype = self.model.dtype

        self.tokenizer, self.tokenizer_path = load_tokenizer(
            tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
        )

    def __repr__(self):
        # Bug fix: the original formatted `self.precision`, an attribute that is
        # never set (its assignment is commented out in __init__; only
        # `self.dtype` exists), so repr() raised AttributeError.
        return f"{self.text_encoder_type} ({self.dtype} - {self.model_path})"

    @property
    def device(self):
        # Device of the wrapped model.
        return self.model.device

    @staticmethod
    def apply_text_to_template(text, template, prevent_empty_text=True):
        """
        Apply text to template.

        Args:
            text (str): Input text.
            template (str or list): Template string or list of chat conversation.
            prevent_empty_text (bool): If True, we will prevent the user text from being empty
                by adding a space. Defaults to True.
        """
        if isinstance(template, str):
            # Will send string to tokenizer. Used for llm
            return template.format(text)
        else:
            raise TypeError(f"Unsupported template type: {type(template)}")

    def text2tokens(self, text, data_type="image"):
        """
        Tokenize the input text, applying the configured prompt template first.

        Args:
            text (str or list): Input text.
            data_type (str): "image" or "video"; selects which template to use.
        """
        tokenize_input_type = "str"
        if self.use_template:
            if data_type == "image":
                prompt_template = self.prompt_template["template"]
            elif data_type == "video":
                prompt_template = self.prompt_template_video["template"]
            else:
                raise ValueError(f"Unsupported data type: {data_type}")
            if isinstance(text, (list, tuple)):
                text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
                if isinstance(text[0], list):
                    tokenize_input_type = "list"
            elif isinstance(text, str):
                text = self.apply_text_to_template(text, prompt_template)
                if isinstance(text, list):
                    tokenize_input_type = "list"
            else:
                raise TypeError(f"Unsupported text type: {type(text)}")

        kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        if tokenize_input_type == "str":
            return self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                **kwargs,
            )
        elif tokenize_input_type == "list":
            return self.tokenizer.apply_chat_template(
                text,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )
        else:
            raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")

    def encode(
        self,
        batch_encoding,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=None,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
        device=None,
    ):
        """
        Args:
            batch_encoding (dict): Batch encoding from tokenizer.
            use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
                Defaults to None.
            output_hidden_states (bool): Whether to output hidden states. If False, return the value of
                self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
                output_hidden_states will be set True. Defaults to False.
            do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
                When self.reproduce is False, do_sample is set to True by default.
            hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
                If None, self.output_key will be used. Defaults to None.
            return_texts (bool): Whether to return the decoded texts. Defaults to False.
            data_type (str): "image" or "video"; selects which template crop to apply.
            device: Device to run on; defaults to the model's device.
        """
        device = self.model.device if device is None else device
        use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
        hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
        # NOTE(review): do_sample is resolved but never forwarded to the model
        # call below; kept for interface compatibility.
        do_sample = use_default(do_sample, not self.reproduce)
        attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
        outputs = self.model(
            input_ids=batch_encoding["input_ids"].to(device),
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
        )
        if hidden_state_skip_layer is not None:
            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
            # Real last hidden state already has layer norm applied. So here we only apply it
            # for intermediate layers.
            if hidden_state_skip_layer > 0 and self.apply_final_norm:
                last_hidden_state = self.model.final_layer_norm(last_hidden_state)
        else:
            last_hidden_state = outputs[self.output_key]

        # Remove hidden states of instruction tokens, only keep prompt tokens.
        if self.use_template:
            if data_type == "image":
                crop_start = self.prompt_template.get("crop_start", -1)
            elif data_type == "video":
                crop_start = self.prompt_template_video.get("crop_start", -1)
            else:
                raise ValueError(f"Unsupported data type: {data_type}")
            if crop_start > 0:
                last_hidden_state = last_hidden_state[:, crop_start:]
                attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None

        if output_hidden_states:
            return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
        return TextEncoderModelOutput(last_hidden_state, attention_mask)

    def forward(
        self,
        text,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=False,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
    ):
        """Tokenize and encode *text* in one call.

        Generalization (backward-compatible): `data_type` is now accepted and
        forwarded, so the video template is reachable through forward(); the
        previous behavior (always "image") is the default.
        """
        batch_encoding = self.text2tokens(text, data_type=data_type)
        return self.encode(
            batch_encoding,
            use_attention_mask=use_attention_mask,
            output_hidden_states=output_hidden_states,
            do_sample=do_sample,
            hidden_state_skip_layer=hidden_state_skip_layer,
            return_texts=return_texts,
            data_type=data_type,
        )
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# region HunyanVideo architecture
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
def load_text_encoder_1(
    text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the primary (LLM) text encoder for HunyuanVideo.

    Args:
        text_encoder_dir: Directory or safetensors file containing the LLM weights.
        device: Device the encoder is moved to.
        fp8_llm: If True, store weights as float8_e4m3fn and patch Embedding and
            RMSNorm layers so those operations still run in the original dtype.
        dtype: Model dtype; defaults to torch.float16.

    Returns:
        The configured TextEncoder in eval mode on ``device``.
    """
    text_encoder_dtype = dtype or torch.float16
    text_encoder_type = "llm"
    text_len = 256
    hidden_state_skip_layer = 2
    apply_final_norm = False
    reproduce = False

    prompt_template = "dit-llm-encode"
    prompt_template = PROMPT_TEMPLATE[prompt_template]
    prompt_template_video = "dit-llm-encode-video"
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]

    # The template's instruction tokens are cropped from the output, so the
    # tokenizer budget must cover the prompt text plus the template prefix.
    crop_start = prompt_template_video["crop_start"]  # .get("crop_start", 0)
    max_length = text_len + crop_start

    text_encoder_1 = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=text_encoder_dir,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder_1.eval()

    if fp8_llm:
        org_dtype = text_encoder_1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                # Replacement forward for LlamaRMSNorm: normalize in float32 and
                # cast the (fp8-stored) weight back to the activation dtype, so
                # fp8 storage does not degrade normalization accuracy. `module`
                # is captured in the closure; the function replaces the bound
                # method, so it receives only `hidden_states`.
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    # Embedding tables are kept in the original dtype, not fp8.
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder_1.model, org_dtype)
    else:
        text_encoder_1.to(device=device)

    return text_encoder_1
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
def load_text_encoder_2(
    text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the secondary (CLIP-L) text encoder, frozen in eval mode on *device*."""
    encoder_type = "clipL"
    encoder = TextEncoder(
        text_encoder_type=encoder_type,
        max_length=77,  # CLIP's fixed context length
        text_encoder_dtype=dtype or torch.float16,
        text_encoder_path=text_encoder_dir,
        tokenizer_type=encoder_type,
        reproduce=False,
    )
    encoder.eval()

    encoder.to(device=device)

    return encoder
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
# endregion
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
if __name__ == "__main__":
    # Smoke test: load the same text encoder from two different locations and
    # verify that tokenization and encoder outputs match exactly.
    #
    # Usage: python text_encoder.py {clipL|llm} PATH1 PATH2 [--dtype DTYPE]
    import argparse
    from utils.model_utils import str_to_dtype

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    parser.add_argument("type", type=str, help="Text Encoder type")
    parser.add_argument("path1", type=str, help="Text Encoder directory or file 1")
    parser.add_argument("path2", type=str, help="Text Encoder directory or file 2")
    parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder")
    args = parser.parse_args()

    dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16

    if args.type == "clipL":
        text_encoder_1st = load_text_encoder_2(args.path1, device, dtype)
        text_encoder_2nd = load_text_encoder_2(args.path2, device, dtype)
    elif args.type == "llm":
        text_encoder_1st = load_text_encoder_1(args.path1, device, False, dtype)
        text_encoder_2nd = load_text_encoder_1(args.path2, device, False, dtype)
    else:
        # Previously an unknown type fell through and crashed with NameError below.
        parser.error(f"Unknown Text Encoder type: {args.type!r} (expected 'clipL' or 'llm')")
    print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
    print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")

    prompt = "A cat sitting on a table"
    data_type = "video"  # video only, image is not supported
    text_inputs_1st = text_encoder_1st.text2tokens(prompt, data_type=data_type)
    text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type)
    print(text_inputs_1st)
    assert torch.allclose(text_inputs_1st["input_ids"], text_inputs_2nd["input_ids"])

    with torch.no_grad():
        prompt_outputs_1st = text_encoder_1st.encode(text_inputs_1st, data_type=data_type)
        # BUGFIX: was `encode(text_inputs_1st, ...)` — a copy-paste that only
        # worked because the input_ids were asserted identical above.
        prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_2nd, data_type=data_type)

    # prompt_outputs.hidden_state, prompt_outputs.attention_mask
    assert torch.allclose(prompt_outputs_1st.hidden_state, prompt_outputs_2nd.hidden_state)
    print("Hidden states are the same.")
    assert torch.allclose(prompt_outputs_1st.attention_mask, prompt_outputs_2nd.attention_mask)
    print("Attention masks are the same.")
    print("All outputs are the same.")
|
hunyuan_model/token_refiner.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from torch.utils.checkpoint import checkpoint
|
| 7 |
+
|
| 8 |
+
from .activation_layers import get_activation_layer
|
| 9 |
+
from .attention import attention
|
| 10 |
+
from .norm_layers import get_norm_layer
|
| 11 |
+
from .embed_layers import TimestepEmbedder, TextProjection
|
| 12 |
+
from .mlp_layers import MLP
|
| 13 |
+
from .modulate_layers import modulate, apply_gate
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class IndividualTokenRefinerBlock(nn.Module):
    """One pre-norm self-attention + MLP block with adaLN-style gating.

    The conditioning vector ``c`` produces two scalar-per-channel gates (one
    for the attention branch, one for the MLP branch); each residual branch
    is scaled by its gate before being added back to the input, and the gate
    projection is zero-initialized so the block starts as an identity map.
    """

    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,  # fixed annotation: was `str`, value is a float (matches IndividualTokenRefiner)
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        # Optional per-head Q/K normalization; Identity when disabled.
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation so both gates start at zero
        # (the block initially passes x through unchanged).
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
    ):
        # Two gates from the conditioning vector: attention branch / MLP branch.
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed; cast back to v's dtype (norm may upcast).
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x

    def forward(self, *args, **kwargs):
        # Use non-reentrant activation checkpointing during training when enabled.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class IndividualTokenRefiner(nn.Module):
    """A stack of ``IndividualTokenRefinerBlock`` layers sharing one pairwise attention mask."""

    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        refiner_blocks = [
            IndividualTokenRefinerBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(refiner_blocks)

    def enable_gradient_checkpointing(self):
        for block in self.blocks:
            block.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        for block in self.blocks:
            block.disable_gradient_checkpointing()

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
    ):
        pairwise_mask = None
        if mask is not None:
            bsz, seq_len = mask.shape[0], mask.shape[1]
            mask = mask.to(x.device)
            # Outer product of the per-token validity mask with itself:
            # position (i, j) is attendable iff tokens i and j are both valid.
            # Shapes: batch_size x 1 x seq_len x seq_len (1 broadcasts over heads).
            rows = mask.view(bsz, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            cols = rows.transpose(2, 3)
            pairwise_mask = (rows & cols).bool()
            # Keep column 0 attendable everywhere so self-attention weights of
            # fully-padded rows don't become NaN.
            pairwise_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, pairwise_mask)
        return x
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class SingleTokenRefiner(nn.Module):
    """
    A single token refiner block for llm text embedding refine.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        # Only the PyTorch attention path is supported here (see `attention` usage downstream).
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        # Projects raw embeddings (last dim = in_channels) into the refiner width.
        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            **factory_kwargs,
        )

    def enable_gradient_checkpointing(self):
        self.individual_token_refiner.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        self.individual_token_refiner.disable_gradient_checkpointing()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
    ):
        """Refine token embeddings conditioned on the timestep and a pooled context vector.

        ``x`` is presumably (batch, seq, in_channels): its last dim feeds the
        in_channels Linear and pooling reduces dim=1 — confirm against callers.
        ``mask`` (optional) marks valid tokens, shape [b, s1] per the comment below.
        """
        timestep_aware_representations = self.t_embedder(t)

        if mask is None:
            # No padding info: plain mean over the sequence dimension.
            context_aware_representations = x.mean(dim=1)
        else:
            # Masked mean: average only over valid (mask != 0) positions.
            # NOTE(review): an all-zero mask would divide by zero — presumably
            # callers guarantee at least one valid token.
            mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        # Single conditioning vector shared by every refiner block.
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        x = self.individual_token_refiner(x, c, mask)

        return x
|
hunyuan_model/vae.py
ADDED
|
@@ -0,0 +1,446 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
import json
|
| 3 |
+
from typing import Optional, Tuple, Union
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
| 11 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 12 |
+
from diffusers.models.attention_processor import SpatialNorm
|
| 13 |
+
from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
logging.basicConfig(level=logging.INFO)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
SCALING_FACTOR = 0.476986
|
| 22 |
+
VAE_VER = "884-16c-hy" # We don't support other versions currently
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def load_vae(
    vae_type: str = "884-16c-hy",
    vae_dtype: Optional[Union[str, torch.dtype]] = None,
    sample_size: tuple = None,
    vae_path: str = None,
    device=None,
):
    """Load the causal 3D VAE model for HunyuanVideo.

    Args:
        vae_type (str): the type of the 3D VAE model. Defaults to "884-16c-hy".
        vae_dtype (str or torch.dtype, optional): dtype to cast the VAE to. Defaults to None (keep as loaded).
        sample_size (tuple, optional): the tiling size. Defaults to None.
        vae_path (str, optional): path to a .safetensors or torch checkpoint. Defaults to None.
        device (optional): device to move the VAE to. Defaults to None (leave where created).

    Returns:
        tuple: (vae, vae_path, spatial_compression_ratio, time_compression_ratio)
    """
    if vae_path is None:
        # NOTE(review): VAE_PATH is not defined in the visible part of this module —
        # confirm it exists elsewhere, otherwise this branch raises NameError.
        vae_path = VAE_PATH[vae_type]

    logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    # use fixed config for Hunyuan's VAE (no config.json needed next to the weights)
    CONFIG_JSON = """{
    "_class_name": "AutoencoderKLCausal3D",
    "_diffusers_version": "0.4.2",
    "act_fn": "silu",
    "block_out_channels": [
        128,
        256,
        512,
        512
    ],
    "down_block_types": [
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D"
    ],
    "in_channels": 3,
    "latent_channels": 16,
    "layers_per_block": 2,
    "norm_num_groups": 32,
    "out_channels": 3,
    "sample_size": 256,
    "sample_tsize": 64,
    "up_block_types": [
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D"
    ],
    "scaling_factor": 0.476986,
    "time_compression_ratio": 4,
    "mid_block_add_attention": true
}"""

    # config = AutoencoderKLCausal3D.load_config(vae_path)
    config = json.loads(CONFIG_JSON)

    # import here to avoid circular import
    from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D

    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    # vae_ckpt = Path(vae_path) / "pytorch_model.pt"
    # assert vae_ckpt.exists(), f"VAE checkpoint not found: {vae_ckpt}"

    if vae_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        ckpt = load_file(vae_path)
    else:
        ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
    # Some checkpoints nest weights under "state_dict" and/or prefix keys with "vae.".
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    if any(k.startswith("vae.") for k in ckpt.keys()):
        ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_dtype is not None:
        vae = vae.to(vae_dtype)

    # Inference only: freeze all parameters.
    vae.requires_grad_(False)

    logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            The decoded output sample from the last layer of the model.
            NOTE(review): for this causal-3D VAE the decoded sample is presumably
            5-D `(batch, channels, frames, height, width)` — confirm against the decoder.
    """

    # Decoded sample tensor produced by the VAE decoder.
    sample: torch.FloatTensor
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,  # when True, conv_out emits mean+logvar (2x out_channels)
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # number of x2 downsampling stages needed per dimension
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                # spatial downsampling happens in the earliest stages; temporal
                # downsampling in the last stages (but never in the final block)
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # per-dimension stride, ordered (T, H, W)
            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class.

        Expects a 5-D input (enforced by the assert below); presumably
        (batch, channels, frames, height, width) — confirm against callers.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # SpatialNorm conditions on the latent embedding; plain GroupNorm does not.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # number of x2 upsampling stages needed per dimension
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                # spatial upsampling in the earliest stages; temporal upsampling
                # in the last stages (but never in the final block)
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # per-dimension scale factor, ordered (T, H, W)
            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class.

        ``latent_embeds`` is the optional conditioning tensor consumed by
        SpatialNorm layers (only meaningful when built with norm_type="spatial").
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        # Cast the mid-block output to the up-blocks' parameter dtype (they may differ
        # under mixed precision).
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:
            # Activation checkpointing: trade compute for memory during training.

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            if is_torch_version(">=", "1.11.0"):
                # Non-reentrant checkpointing requires torch >= 1.11.
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class DiagonalGaussianDistribution(object):
    """Diagonal Gaussian distribution parameterized by a tensor that
    concatenates the mean and log-variance halves along the channel axis.

    Typically used as a VAE posterior: `parameters` is the encoder output,
    split into `mean` and `logvar`.  With ``deterministic=True`` the variance
    is forced to zero so sampling degenerates to the mean.
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        # Pick the axis that holds the concatenated (mean, logvar) halves.
        if parameters.ndim == 3:
            dim = 2  # (B, L, C)
        elif parameters.ndim == 5 or parameters.ndim == 4:
            dim = 1  # (B, C, T, H ,W) / (B, C, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
        # Clamp log-variance for numerical stability before exponentiation.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero variance: sample() returns exactly the mean.
            self.var = self.std = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Draw a reparameterized sample: mean + std * eps."""
        # make sure sample is on the same device as the parameters and has same dtype
        sample = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        x = self.mean + self.std * sample
        return x

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to `other`, or to the standard normal when `other` is None.

        Reduced over all non-batch dimensions, yielding one value per batch
        item (a single zero for deterministic distributions).
        """
        if self.deterministic:
            return torch.tensor([0.0])
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            torch.pow(self.mean - other.mean, 2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
        """Negative log-likelihood of `sample` under this Gaussian, summed over `dims`."""
        if self.deterministic:
            return torch.tensor([0.0])
        logtwopi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims,
        )

    def mode(self) -> torch.Tensor:
        """Most likely value of the distribution (its mean)."""
        return self.mean
|
hv_generate_video.py
ADDED
|
@@ -0,0 +1,936 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
from datetime import datetime
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
import random
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
from typing import Optional, Union
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torchvision
|
| 13 |
+
import accelerate
|
| 14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 15 |
+
from transformers.models.llama import LlamaModel
|
| 16 |
+
from tqdm import tqdm
|
| 17 |
+
import av
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from safetensors.torch import load_file, save_file
|
| 20 |
+
from safetensors import safe_open
|
| 21 |
+
from PIL import Image
|
| 22 |
+
|
| 23 |
+
from hunyuan_model import vae
|
| 24 |
+
from hunyuan_model.text_encoder import TextEncoder
|
| 25 |
+
from hunyuan_model.text_encoder import PROMPT_TEMPLATE
|
| 26 |
+
from hunyuan_model.vae import load_vae
|
| 27 |
+
from hunyuan_model.models import load_transformer, get_rotary_pos_embed
|
| 28 |
+
from hunyuan_model.fp8_optimization import convert_fp8_linear
|
| 29 |
+
from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
|
| 30 |
+
from networks import lora
|
| 31 |
+
|
| 32 |
+
try:
|
| 33 |
+
from lycoris.kohya import create_network_from_weights
|
| 34 |
+
except:
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
from utils.model_utils import str_to_dtype
|
| 38 |
+
from utils.safetensors_utils import mem_eff_save_file
|
| 39 |
+
from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
|
| 40 |
+
|
| 41 |
+
import logging
|
| 42 |
+
|
| 43 |
+
logger = logging.getLogger(__name__)
|
| 44 |
+
logging.basicConfig(level=logging.INFO)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def clean_memory_on_device(device):
    """Release cached allocator memory on `device`, if the backend supports it."""
    dev_type = device.type
    if dev_type == "cuda":
        torch.cuda.empty_cache()
    elif dev_type == "mps":  # not tested
        torch.mps.empty_cache()
    # "cpu" (and any other backend) has no allocator cache to flush.
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def synchronize_device(device: torch.device):
    """Block until all work queued on `device` has completed (no-op on CPU)."""
    dev_type = device.type
    if dev_type == "cuda":
        torch.cuda.synchronize()
    elif dev_type == "xpu":
        # Only touch torch.xpu on the xpu branch; the attribute may be absent
        # in builds without Intel GPU support.
        torch.xpu.synchronize()
    elif dev_type == "mps":
        torch.mps.synchronize()
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
    """save videos by video tensor
    copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61

    Args:
        videos (torch.Tensor): video tensor predicted by the model, shape (b, c, t, h, w)
        path (str): path to save video
        rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
        n_rows (int, optional): Defaults to 1.
        fps (int, optional): video save fps. Defaults to 24.
    """
    # iterate frame-by-frame; each step yields one batch to tile into a grid
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)

    height, width, _ = outputs[0].shape

    # create output container
    container = av.open(path, mode="w")

    # create video stream
    codec = "libx264"
    pixel_format = "yuv420p"
    stream = container.add_stream(codec, rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = pixel_format
    stream.bit_rate = 4000000  # 4Mbit/s

    for frame_array in outputs:
        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
        packets = stream.encode(frame)
        for packet in packets:
            container.mux(packet)

    # flush any packets still buffered in the encoder before closing
    for packet in stream.encode():
        container.mux(packet)

    container.close()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def save_images_grid(
    videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
):
    """Save each temporal frame of a video batch as a numbered PNG grid image."""
    frames = rearrange(videos, "b c t h w -> t b c h w")
    grids = []
    for frame in frames:
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)  # (c, h, w) -> (h, w, c)
        if rescale:
            grid = (grid + 1.0) / 2.0  # map [-1, 1] -> [0, 1]
        grid = torch.clamp(grid, 0, 1)
        grids.append((grid * 255).numpy().astype(np.uint8))

    target_dir = os.path.join(parent_dir, image_name) if create_subdir else parent_dir
    os.makedirs(target_dir, exist_ok=True)

    for idx, arr in enumerate(grids):
        Image.fromarray(arr).save(os.path.join(target_dir, f"{image_name}_{idx:03d}.png"))
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
# region Encoding prompt
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_videos_per_prompt (`int`):
            number of videos that should be generated per prompt
        text_encoder (TextEncoder):
            text encoder to be used for encoding the prompt
    """
    # LoRA and Textual Inversion are not supported in this script
    # negative prompt and prompt embedding are not supported in this script
    # clip_skip is not supported in this script because it is not used in the original script
    data_type = "video"  # video only, image is not supported

    tokens = text_encoder.text2tokens(prompt, data_type=data_type)

    with torch.no_grad():
        enc_out = text_encoder.encode(tokens, data_type=data_type, device=device)
    prompt_embeds = enc_out.hidden_state

    attention_mask = enc_out.attention_mask
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
        n_prompts, seq_len = attention_mask.shape
        # tile the mask so each generated video gets its own copy
        attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
        attention_mask = attention_mask.view(n_prompts * num_videos_per_prompt, seq_len)

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    if prompt_embeds.ndim == 2:
        n_prompts = prompt_embeds.shape[0]
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
        prompt_embeds = prompt_embeds.view(n_prompts * num_videos_per_prompt, -1)
    else:
        n_prompts, seq_len = prompt_embeds.shape[:2]
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(n_prompts * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds, attention_mask
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
    """Encode `prompt` with both HunyuanVideo text encoders (LLM then CLIP-L).

    Each encoder is loaded, used, and released before the next step to keep
    peak device memory low; all results are returned on CPU.

    Args:
        prompt: prompt string or list of prompt strings.
        args: parsed CLI args; reads args.text_encoder1 and args.text_encoder2.
        device: device to run the encoders on.
        fp8_llm: cast the LLM encoder weights to float8_e4m3fn, keeping
            embeddings and RMSNorm in higher precision via hooks.
        accelerator: required when fp8_llm is True; provides the autocast context.

    Returns:
        (prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2), all on CPU.
    """
    # constants
    prompt_template_video = "dit-llm-encode-video"
    prompt_template = "dit-llm-encode"
    text_encoder_dtype = torch.float16
    text_encoder_type = "llm"
    text_len = 256
    hidden_state_skip_layer = 2
    apply_final_norm = False
    reproduce = False

    text_encoder_2_type = "clipL"
    text_len_2 = 77

    num_videos = 1

    # if args.prompt_template_video is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
    # elif args.prompt_template is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
    # else:
    #     crop_start = 0
    # reserve room for the instruction tokens the video template prepends
    crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
    max_length = text_len + crop_start

    # prompt_template
    prompt_template = PROMPT_TEMPLATE[prompt_template]

    # prompt_template_video
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]  # if args.prompt_template_video is not None else None

    # load text encoders
    logger.info(f"loading text encoder: {args.text_encoder1}")
    text_encoder = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder1,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder.eval()
    if fp8_llm:
        org_dtype = text_encoder.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            # RMSNorm in fp8 loses too much precision: replace its forward with
            # one that normalizes in float32, then casts back to the input dtype.
            def forward_hook(module):
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # keep embeddings in the original (higher-precision) dtype
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # monkey-patch the norm's forward with the fp32 version above
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder.model, org_dtype)

    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = TextEncoder(
        text_encoder_type=text_encoder_2_type,
        max_length=text_len_2,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder2,
        tokenizer_type=text_encoder_2_type,
        reproduce=reproduce,
    )
    text_encoder_2.eval()

    # encode prompt
    logger.info(f"Encoding prompt with text encoder 1")
    text_encoder.to(device=device)
    if fp8_llm:
        with accelerator.autocast():
            prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
    else:
        prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
    # drop encoder 1 before moving encoder 2 onto the device
    text_encoder = None
    clean_memory_on_device(device)

    logger.info(f"Encoding prompt with text encoder 2")
    text_encoder_2.to(device=device)
    prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)

    # move results to CPU so the caller can free the device for the DiT
    prompt_embeds = prompt_embeds.to("cpu")
    prompt_mask = prompt_mask.to("cpu")
    prompt_embeds_2 = prompt_embeds_2.to("cpu")
    prompt_mask_2 = prompt_mask_2.to("cpu")

    text_encoder_2 = None
    clean_memory_on_device(device)

    return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# endregion
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def prepare_vae(args, device):
    """Load the VAE in eval mode and apply chunking/tiling options from args.

    Returns:
        (vae, vae_dtype): the configured VAE module and the dtype it was loaded in.
    """
    vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.float16
    vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()

    # Optional chunked CausalConv3d to bound peak memory during conv.
    chunk_size = args.vae_chunk_size
    if chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(chunk_size)
        logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")

    # Spatial tiling is always enabled; a min-size argument also overrides
    # the sample/latent tile sizes (latent is 1/8 of the sample size).
    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    else:
        vae.enable_spatial_tiling(True)

    return vae, vae_dtype
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
def encode_to_latents(args, video, device):
    """Encode a pixel-space video in [0, 1] into scaled VAE latents."""
    vae, vae_dtype = prepare_vae(args, device)

    video = video.to(device=device, dtype=vae_dtype)
    video = video * 2 - 1  # 0, 1 -> -1, 1
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()

    # Apply the VAE's shift/scale normalization (shift only if configured).
    shift = getattr(vae.config, "shift_factor", None)
    if shift:
        latents = (latents - shift) * vae.config.scaling_factor
    else:
        latents = latents * vae.config.scaling_factor

    return latents
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
def decode_latents(args, latents, device):
    """Decode VAE latents back to pixel space in [0, 1], returned as CPU float32."""
    vae, vae_dtype = prepare_vae(args, device)

    # Accept 4-D (b, c, h, w) input by inserting a singleton temporal axis.
    squeeze_time = False
    if latents.ndim == 4:
        latents = latents.unsqueeze(2)
        squeeze_time = True
    elif latents.ndim != 5:
        raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")

    # Undo the VAE's scale/shift normalization (shift only if configured).
    shift = getattr(vae.config, "shift_factor", None)
    if shift:
        latents = latents / vae.config.scaling_factor + shift
    else:
        latents = latents / vae.config.scaling_factor

    latents = latents.to(device=device, dtype=vae_dtype)
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]

    if squeeze_time:
        image = image.squeeze(2)

    image = (image / 2 + 0.5).clamp(0, 1)
    # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
    image = image.cpu().float()

    return image
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
def parse_args():
    """Parse and validate command-line arguments for HunyuanVideo inference.

    Returns:
        argparse.Namespace: the validated arguments.

    Raises:
        ValueError: if --fp8_fast is given without --fp8.
    """
    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    # model paths
    parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
    parser.add_argument(
        "--dit_in_channels",
        type=int,
        default=None,
        help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
    )
    parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
    parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )
    parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")

    # inference
    parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
    parser.add_argument("--video_length", type=int, default=129, help="video length")
    parser.add_argument("--fps", type=int, default=24, help="video fps")
    parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
    )
    parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embedded classifier free guidance scale.")
    parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
    parser.add_argument(
        "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
    )
    parser.add_argument(
        "--split_uncond",
        action="store_true",
        help="split unconditional call for classifier free guidance, slower but less memory usage",
    )
    parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")

    # Flow Matching
    parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")

    # precision / device / attention
    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
    )
    parser.add_argument(
        "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
    parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
    parser.add_argument(
        "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
    parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+)")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument(
        "--compile_args",
        nargs=4,
        metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
        default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
        help="Torch.compile settings",
    )

    args = parser.parse_args()

    # latents can only be re-decoded to images or video output types
    assert (args.latent_path is None or len(args.latent_path) == 0) or (
        args.output_type == "images" or args.output_type == "video"
    ), "latent_path is only supported for images or video output"

    # update dit_weight based on model_base if not exists

    if args.fp8_fast and not args.fp8:
        raise ValueError("--fp8_fast requires --fp8")

    return args
|
| 488 |
+
|
| 489 |
+
|
| 490 |
+
def check_inputs(args):
|
| 491 |
+
height = args.video_size[0]
|
| 492 |
+
width = args.video_size[1]
|
| 493 |
+
video_length = args.video_length
|
| 494 |
+
|
| 495 |
+
if height % 8 != 0 or width % 8 != 0:
|
| 496 |
+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
|
| 497 |
+
return height, width, video_length
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def main():
|
| 501 |
+
args = parse_args()
|
| 502 |
+
|
| 503 |
+
device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
|
| 504 |
+
device = torch.device(device)
|
| 505 |
+
dit_dtype = torch.bfloat16
|
| 506 |
+
dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
|
| 507 |
+
logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")
|
| 508 |
+
|
| 509 |
+
original_base_names = None
|
| 510 |
+
if args.latent_path is not None and len(args.latent_path) > 0:
|
| 511 |
+
original_base_names = []
|
| 512 |
+
latents_list = []
|
| 513 |
+
seeds = []
|
| 514 |
+
for latent_path in args.latent_path:
|
| 515 |
+
original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
|
| 516 |
+
seed = 0
|
| 517 |
+
|
| 518 |
+
if os.path.splitext(latent_path)[1] != ".safetensors":
|
| 519 |
+
latents = torch.load(latent_path, map_location="cpu")
|
| 520 |
+
else:
|
| 521 |
+
latents = load_file(latent_path)["latent"]
|
| 522 |
+
with safe_open(latent_path, framework="pt") as f:
|
| 523 |
+
metadata = f.metadata()
|
| 524 |
+
if metadata is None:
|
| 525 |
+
metadata = {}
|
| 526 |
+
logger.info(f"Loaded metadata: {metadata}")
|
| 527 |
+
|
| 528 |
+
if "seeds" in metadata:
|
| 529 |
+
seed = int(metadata["seeds"])
|
| 530 |
+
|
| 531 |
+
seeds.append(seed)
|
| 532 |
+
latents_list.append(latents)
|
| 533 |
+
|
| 534 |
+
logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
|
| 535 |
+
latents = torch.stack(latents_list, dim=0)
|
| 536 |
+
else:
|
| 537 |
+
# prepare accelerator
|
| 538 |
+
mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
|
| 539 |
+
accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)
|
| 540 |
+
|
| 541 |
+
# load prompt
|
| 542 |
+
prompt = args.prompt # TODO load prompts from file
|
| 543 |
+
assert prompt is not None, "prompt is required"
|
| 544 |
+
|
| 545 |
+
# check inputs: may be height, width, video_length etc will be changed for each generation in future
|
| 546 |
+
height, width, video_length = check_inputs(args)
|
| 547 |
+
|
| 548 |
+
# encode prompt with LLM and Text Encoder
|
| 549 |
+
logger.info(f"Encoding prompt: {prompt}")
|
| 550 |
+
|
| 551 |
+
do_classifier_free_guidance = args.guidance_scale != 1.0
|
| 552 |
+
if do_classifier_free_guidance:
|
| 553 |
+
negative_prompt = args.negative_prompt
|
| 554 |
+
if negative_prompt is None:
|
| 555 |
+
logger.info("Negative prompt is not provided, using empty prompt")
|
| 556 |
+
negative_prompt = ""
|
| 557 |
+
logger.info(f"Encoding negative prompt: {negative_prompt}")
|
| 558 |
+
prompt = [negative_prompt, prompt]
|
| 559 |
+
else:
|
| 560 |
+
if args.negative_prompt is not None:
|
| 561 |
+
logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")
|
| 562 |
+
|
| 563 |
+
prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
|
| 564 |
+
prompt, args, device, args.fp8_llm, accelerator
|
| 565 |
+
)
|
| 566 |
+
|
| 567 |
+
# encode latents for video2video inference
|
| 568 |
+
video_latents = None
|
| 569 |
+
if args.video_path is not None:
|
| 570 |
+
# v2v inference
|
| 571 |
+
logger.info(f"Video2Video inference: {args.video_path}")
|
| 572 |
+
video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height)) # list of frames
|
| 573 |
+
if len(video) < video_length:
|
| 574 |
+
raise ValueError(f"Video length is less than {video_length}")
|
| 575 |
+
video = np.stack(video, axis=0) # F, H, W, C
|
| 576 |
+
video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float() # 1, C, F, H, W
|
| 577 |
+
video = video / 255.0
|
| 578 |
+
|
| 579 |
+
logger.info(f"Encoding video to latents")
|
| 580 |
+
video_latents = encode_to_latents(args, video, device)
|
| 581 |
+
video_latents = video_latents.to(device=device, dtype=dit_dtype)
|
| 582 |
+
|
| 583 |
+
clean_memory_on_device(device)
|
| 584 |
+
|
| 585 |
+
# encode latents for image2video inference
|
| 586 |
+
image_latents = None
|
| 587 |
+
if args.image_path is not None:
|
| 588 |
+
# i2v inference
|
| 589 |
+
logger.info(f"Image2Video inference: {args.image_path}")
|
| 590 |
+
|
| 591 |
+
image = Image.open(args.image_path)
|
| 592 |
+
image = resize_image_to_bucket(image, (width, height)) # returns a numpy array
|
| 593 |
+
image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float() # 1, C, 1, H, W
|
| 594 |
+
image = image / 255.0
|
| 595 |
+
|
| 596 |
+
logger.info(f"Encoding image to latents")
|
| 597 |
+
image_latents = encode_to_latents(args, image, device) # 1, C, 1, H, W
|
| 598 |
+
image_latents = image_latents.to(device=device, dtype=dit_dtype)
|
| 599 |
+
|
| 600 |
+
clean_memory_on_device(device)
|
| 601 |
+
|
| 602 |
+
# load DiT model
|
| 603 |
+
blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
|
| 604 |
+
loading_device = "cpu" # if blocks_to_swap > 0 else device
|
| 605 |
+
|
| 606 |
+
logger.info(f"Loading DiT model from {args.dit}")
|
| 607 |
+
if args.attn_mode == "sdpa":
|
| 608 |
+
args.attn_mode = "torch"
|
| 609 |
+
|
| 610 |
+
# if image_latents is given, the model should be I2V model, so the in_channels should be 32
|
| 611 |
+
dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)
|
| 612 |
+
|
| 613 |
+
# if we use LoRA, weigths should be bf16 instead of fp8, because merging should be done in bf16
|
| 614 |
+
# the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
|
| 615 |
+
# on the fly merging will be a solution for this issue for .safetenors files (not implemented yet)
|
| 616 |
+
transformer = load_transformer(
|
| 617 |
+
args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
|
| 618 |
+
)
|
| 619 |
+
transformer.eval()
|
| 620 |
+
|
| 621 |
+
# load LoRA weights
|
| 622 |
+
if args.lora_weight is not None and len(args.lora_weight) > 0:
|
| 623 |
+
for i, lora_weight in enumerate(args.lora_weight):
|
| 624 |
+
if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
|
| 625 |
+
lora_multiplier = args.lora_multiplier[i]
|
| 626 |
+
else:
|
| 627 |
+
lora_multiplier = 1.0
|
| 628 |
+
|
| 629 |
+
logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
|
| 630 |
+
weights_sd = load_file(lora_weight)
|
| 631 |
+
|
| 632 |
+
# Filter to exclude keys that are part of single_blocks
|
| 633 |
+
if args.exclude_single_blocks:
|
| 634 |
+
filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
|
| 635 |
+
weights_sd = filtered_weights
|
| 636 |
+
|
| 637 |
+
if args.lycoris:
|
| 638 |
+
lycoris_net, _ = create_network_from_weights(
|
| 639 |
+
multiplier=lora_multiplier,
|
| 640 |
+
file=None,
|
| 641 |
+
weights_sd=weights_sd,
|
| 642 |
+
unet=transformer,
|
| 643 |
+
text_encoder=None,
|
| 644 |
+
vae=None,
|
| 645 |
+
for_inference=True,
|
| 646 |
+
)
|
| 647 |
+
else:
|
| 648 |
+
network = lora.create_arch_network_from_weights(
|
| 649 |
+
lora_multiplier, weights_sd, unet=transformer, for_inference=True
|
| 650 |
+
)
|
| 651 |
+
logger.info("Merging LoRA weights to DiT model")
|
| 652 |
+
|
| 653 |
+
# try:
|
| 654 |
+
# network.apply_to(None, transformer, apply_text_encoder=False, apply_unet=True)
|
| 655 |
+
# info = network.load_state_dict(weights_sd, strict=True)
|
| 656 |
+
# logger.info(f"Loaded LoRA weights from {weights_file}: {info}")
|
| 657 |
+
# network.eval()
|
| 658 |
+
# network.to(device)
|
| 659 |
+
# except Exception as e:
|
| 660 |
+
if args.lycoris:
|
| 661 |
+
lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
|
| 662 |
+
else:
|
| 663 |
+
network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)
|
| 664 |
+
|
| 665 |
+
synchronize_device(device)
|
| 666 |
+
|
| 667 |
+
logger.info("LoRA weights loaded")
|
| 668 |
+
|
| 669 |
+
# save model here before casting to dit_weight_dtype
|
| 670 |
+
if args.save_merged_model:
|
| 671 |
+
logger.info(f"Saving merged model to {args.save_merged_model}")
|
| 672 |
+
mem_eff_save_file(transformer.state_dict(), args.save_merged_model) # save_file needs a lot of memory
|
| 673 |
+
logger.info("Merged model saved")
|
| 674 |
+
return
|
| 675 |
+
|
| 676 |
+
logger.info(f"Casting model to {dit_weight_dtype}")
|
| 677 |
+
transformer.to(dtype=dit_weight_dtype)
|
| 678 |
+
|
| 679 |
+
if args.fp8_fast:
|
| 680 |
+
logger.info("Enabling FP8 acceleration")
|
| 681 |
+
params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
|
| 682 |
+
for name, param in transformer.named_parameters():
|
| 683 |
+
dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
|
| 684 |
+
param.to(dtype=dtype_to_use)
|
| 685 |
+
convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)
|
| 686 |
+
|
| 687 |
+
if args.compile:
|
| 688 |
+
compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
|
| 689 |
+
logger.info(
|
| 690 |
+
f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
|
| 691 |
+
)
|
| 692 |
+
torch._dynamo.config.cache_size_limit = 32
|
| 693 |
+
for i, block in enumerate(transformer.single_blocks):
|
| 694 |
+
compiled_block = torch.compile(
|
| 695 |
+
block,
|
| 696 |
+
backend=compile_backend,
|
| 697 |
+
mode=compile_mode,
|
| 698 |
+
dynamic=compile_dynamic.lower() in "true",
|
| 699 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
| 700 |
+
)
|
| 701 |
+
transformer.single_blocks[i] = compiled_block
|
| 702 |
+
for i, block in enumerate(transformer.double_blocks):
|
| 703 |
+
compiled_block = torch.compile(
|
| 704 |
+
block,
|
| 705 |
+
backend=compile_backend,
|
| 706 |
+
mode=compile_mode,
|
| 707 |
+
dynamic=compile_dynamic.lower() in "true",
|
| 708 |
+
fullgraph=compile_fullgraph.lower() in "true",
|
| 709 |
+
)
|
| 710 |
+
transformer.double_blocks[i] = compiled_block
|
| 711 |
+
|
| 712 |
+
if blocks_to_swap > 0:
|
| 713 |
+
logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
|
| 714 |
+
transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
|
| 715 |
+
transformer.move_to_device_except_swap_blocks(device)
|
| 716 |
+
transformer.prepare_block_swap_before_forward()
|
| 717 |
+
else:
|
| 718 |
+
logger.info(f"Moving model to {device}")
|
| 719 |
+
transformer.to(device=device)
|
| 720 |
+
if args.img_in_txt_in_offloading:
|
| 721 |
+
logger.info("Enable offloading img_in and txt_in to CPU")
|
| 722 |
+
transformer.enable_img_in_txt_in_offloading()
|
| 723 |
+
|
| 724 |
+
# load scheduler
|
| 725 |
+
logger.info(f"Loading scheduler")
|
| 726 |
+
scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")
|
| 727 |
+
|
| 728 |
+
# Prepare timesteps
|
| 729 |
+
num_inference_steps = args.infer_steps
|
| 730 |
+
scheduler.set_timesteps(num_inference_steps, device=device) # n_tokens is not used in FlowMatchDiscreteScheduler
|
| 731 |
+
timesteps = scheduler.timesteps
|
| 732 |
+
|
| 733 |
+
# Prepare generator
|
| 734 |
+
num_videos_per_prompt = 1 # args.num_videos # currently only support 1 video per prompt, this is a batch size
|
| 735 |
+
seed = args.seed
|
| 736 |
+
if seed is None:
|
| 737 |
+
seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
|
| 738 |
+
elif isinstance(seed, int):
|
| 739 |
+
seeds = [seed + i for i in range(num_videos_per_prompt)]
|
| 740 |
+
else:
|
| 741 |
+
raise ValueError(f"Seed must be an integer or None, got {seed}.")
|
| 742 |
+
generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]
|
| 743 |
+
|
| 744 |
+
# Prepare noisy latents
|
| 745 |
+
num_channels_latents = 16 # transformer.config.in_channels
|
| 746 |
+
vae_scale_factor = 2 ** (4 - 1) # len(self.vae.config.block_out_channels) == 4
|
| 747 |
+
|
| 748 |
+
vae_ver = vae.VAE_VER
|
| 749 |
+
if "884" in vae_ver:
|
| 750 |
+
latent_video_length = (video_length - 1) // 4 + 1
|
| 751 |
+
elif "888" in vae_ver:
|
| 752 |
+
latent_video_length = (video_length - 1) // 8 + 1
|
| 753 |
+
else:
|
| 754 |
+
latent_video_length = video_length
|
| 755 |
+
|
| 756 |
+
# shape = (
|
| 757 |
+
# num_videos_per_prompt,
|
| 758 |
+
# num_channels_latents,
|
| 759 |
+
# latent_video_length,
|
| 760 |
+
# height // vae_scale_factor,
|
| 761 |
+
# width // vae_scale_factor,
|
| 762 |
+
# )
|
| 763 |
+
# latents = randn_tensor(shape, generator=generator, device=device, dtype=dit_dtype)
|
| 764 |
+
|
| 765 |
+
# make first N frames to be the same if the given seed is same
|
| 766 |
+
shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
|
| 767 |
+
latents = []
|
| 768 |
+
for i in range(latent_video_length):
|
| 769 |
+
latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
|
| 770 |
+
latents = torch.cat(latents, dim=2)
|
| 771 |
+
|
| 772 |
+
# pad image_latents to match the length of video_latents
|
| 773 |
+
if image_latents is not None:
|
| 774 |
+
zero_latents = torch.zeros_like(latents)
|
| 775 |
+
zero_latents[:, :, :1, :, :] = image_latents
|
| 776 |
+
image_latents = zero_latents
|
| 777 |
+
|
| 778 |
+
if args.video_path is not None:
|
| 779 |
+
# v2v inference
|
| 780 |
+
noise = latents
|
| 781 |
+
assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"
|
| 782 |
+
|
| 783 |
+
num_inference_steps = int(num_inference_steps * args.strength)
|
| 784 |
+
timestep_start = scheduler.timesteps[-num_inference_steps] # larger strength, less inference steps and more start time
|
| 785 |
+
t = timestep_start / 1000.0
|
| 786 |
+
latents = noise * t + video_latents * (1 - t)
|
| 787 |
+
|
| 788 |
+
timesteps = timesteps[-num_inference_steps:]
|
| 789 |
+
|
| 790 |
+
logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")
|
| 791 |
+
|
| 792 |
+
# FlowMatchDiscreteScheduler does not have init_noise_sigma
|
| 793 |
+
|
| 794 |
+
# Denoising loop
|
| 795 |
+
embedded_guidance_scale = args.embedded_cfg_scale
|
| 796 |
+
if embedded_guidance_scale is not None:
|
| 797 |
+
guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
|
| 798 |
+
guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
|
| 799 |
+
if do_classifier_free_guidance:
|
| 800 |
+
guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
|
| 801 |
+
else:
|
| 802 |
+
guidance_expand = None
|
| 803 |
+
freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)
|
| 804 |
+
# n_tokens = freqs_cos.shape[0]
|
| 805 |
+
|
| 806 |
+
# move and cast all inputs to the correct device and dtype
|
| 807 |
+
prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
|
| 808 |
+
prompt_mask = prompt_mask.to(device=device)
|
| 809 |
+
prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
|
| 810 |
+
prompt_mask_2 = prompt_mask_2.to(device=device)
|
| 811 |
+
|
| 812 |
+
freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
|
| 813 |
+
freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)
|
| 814 |
+
|
| 815 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order # this should be 0 in v2v inference
|
| 816 |
+
|
| 817 |
+
# assert split_uncond and split_attn
|
| 818 |
+
if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
|
| 819 |
+
logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
|
| 820 |
+
args.split_uncond = True
|
| 821 |
+
|
| 822 |
+
# with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA]) as p:
|
| 823 |
+
with tqdm(total=num_inference_steps) as progress_bar:
|
| 824 |
+
for i, t in enumerate(timesteps):
|
| 825 |
+
latents = scheduler.scale_model_input(latents, t)
|
| 826 |
+
|
| 827 |
+
# predict the noise residual
|
| 828 |
+
with torch.no_grad(), accelerator.autocast():
|
| 829 |
+
latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
|
| 830 |
+
if image_latents is not None:
|
| 831 |
+
latents_image_input = (
|
| 832 |
+
image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
|
| 833 |
+
)
|
| 834 |
+
latents_input = torch.cat([latents_input, latents_image_input], dim=1) # 1 or 2, C*2, F, H, W
|
| 835 |
+
|
| 836 |
+
batch_size = 1 if args.split_uncond else latents_input.shape[0]
|
| 837 |
+
|
| 838 |
+
noise_pred_list = []
|
| 839 |
+
for j in range(0, latents_input.shape[0], batch_size):
|
| 840 |
+
noise_pred = transformer( # For an input image (129, 192, 336) (1, 256, 256)
|
| 841 |
+
latents_input[j : j + batch_size], # [1, 16, 33, 24, 42]
|
| 842 |
+
t.repeat(batch_size).to(device=device, dtype=dit_dtype), # [1]
|
| 843 |
+
text_states=prompt_embeds[j : j + batch_size], # [1, 256, 4096]
|
| 844 |
+
text_mask=prompt_mask[j : j + batch_size], # [1, 256]
|
| 845 |
+
text_states_2=prompt_embeds_2[j : j + batch_size], # [1, 768]
|
| 846 |
+
freqs_cos=freqs_cos, # [seqlen, head_dim]
|
| 847 |
+
freqs_sin=freqs_sin, # [seqlen, head_dim]
|
| 848 |
+
guidance=guidance_expand[j : j + batch_size], # [1]
|
| 849 |
+
return_dict=True,
|
| 850 |
+
)["x"]
|
| 851 |
+
noise_pred_list.append(noise_pred)
|
| 852 |
+
noise_pred = torch.cat(noise_pred_list, dim=0)
|
| 853 |
+
|
| 854 |
+
# perform classifier free guidance
|
| 855 |
+
if do_classifier_free_guidance:
|
| 856 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 857 |
+
noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 858 |
+
|
| 859 |
+
# # SkyReels' rescale noise config is omitted for now
|
| 860 |
+
# if guidance_rescale > 0.0:
|
| 861 |
+
# # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
|
| 862 |
+
# noise_pred = rescale_noise_cfg(
|
| 863 |
+
# noise_pred,
|
| 864 |
+
# noise_pred_cond,
|
| 865 |
+
# guidance_rescale=self.guidance_rescale,
|
| 866 |
+
# )
|
| 867 |
+
|
| 868 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 869 |
+
latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
| 870 |
+
|
| 871 |
+
# update progress bar
|
| 872 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
|
| 873 |
+
if progress_bar is not None:
|
| 874 |
+
progress_bar.update()
|
| 875 |
+
|
| 876 |
+
# print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=-1))
|
| 877 |
+
# print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
| 878 |
+
|
| 879 |
+
latents = latents.detach().cpu()
|
| 880 |
+
transformer = None
|
| 881 |
+
clean_memory_on_device(device)
|
| 882 |
+
|
| 883 |
+
# Save samples
|
| 884 |
+
output_type = args.output_type
|
| 885 |
+
save_path = args.save_path # if args.save_path_suffix == "" else f"{args.save_path}_{args.save_path_suffix}"
|
| 886 |
+
os.makedirs(save_path, exist_ok=True)
|
| 887 |
+
time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")
|
| 888 |
+
|
| 889 |
+
if output_type == "latent" or output_type == "both":
|
| 890 |
+
# save latent
|
| 891 |
+
for i, latent in enumerate(latents):
|
| 892 |
+
latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"
|
| 893 |
+
|
| 894 |
+
if args.no_metadata:
|
| 895 |
+
metadata = None
|
| 896 |
+
else:
|
| 897 |
+
metadata = {
|
| 898 |
+
"seeds": f"{seeds[i]}",
|
| 899 |
+
"prompt": f"{args.prompt}",
|
| 900 |
+
"height": f"{height}",
|
| 901 |
+
"width": f"{width}",
|
| 902 |
+
"video_length": f"{video_length}",
|
| 903 |
+
"infer_steps": f"{num_inference_steps}",
|
| 904 |
+
"guidance_scale": f"{args.guidance_scale}",
|
| 905 |
+
"embedded_cfg_scale": f"{args.embedded_cfg_scale}",
|
| 906 |
+
}
|
| 907 |
+
if args.negative_prompt is not None:
|
| 908 |
+
metadata["negative_prompt"] = f"{args.negative_prompt}"
|
| 909 |
+
sd = {"latent": latent}
|
| 910 |
+
save_file(sd, latent_path, metadata=metadata)
|
| 911 |
+
|
| 912 |
+
logger.info(f"Latent save to: {latent_path}")
|
| 913 |
+
if output_type == "video" or output_type == "both":
|
| 914 |
+
# save video
|
| 915 |
+
videos = decode_latents(args, latents, device)
|
| 916 |
+
for i, sample in enumerate(videos):
|
| 917 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
| 918 |
+
sample = sample.unsqueeze(0)
|
| 919 |
+
video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
|
| 920 |
+
save_videos_grid(sample, video_path, fps=args.fps)
|
| 921 |
+
logger.info(f"Sample save to: {video_path}")
|
| 922 |
+
elif output_type == "images":
|
| 923 |
+
# save images
|
| 924 |
+
videos = decode_latents(args, latents, device)
|
| 925 |
+
for i, sample in enumerate(videos):
|
| 926 |
+
original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
|
| 927 |
+
sample = sample.unsqueeze(0)
|
| 928 |
+
image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
|
| 929 |
+
save_images_grid(sample, save_path, image_name)
|
| 930 |
+
logger.info(f"Sample images save to: {save_path}/{image_name}")
|
| 931 |
+
|
| 932 |
+
logger.info("Done!")
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
if __name__ == "__main__":
    # Script entry point: run generation only when executed directly.
    main()
|
merge_lora.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import logging
|
| 3 |
+
import torch
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
from networks import lora
|
| 6 |
+
from utils.safetensors_utils import mem_eff_save_file
|
| 7 |
+
from hunyuan_model.models import load_transformer
|
| 8 |
+
|
| 9 |
+
logger = logging.getLogger(__name__)
|
| 10 |
+
logging.basicConfig(level=logging.INFO)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def parse_args():
    """Build and run the command-line parser for the LoRA merge script.

    Returns:
        argparse.Namespace: parsed command-line options.
    """
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="HunyuanVideo model merger script")
    parser.add_argument(
        "--dit", type=str, required=True, help="DiT checkpoint path or directory"
    )
    parser.add_argument(
        "--dit_in_channels",
        type=int,
        default=16,
        help="input channels for DiT, default is 16, skyreels I2V is 32",
    )
    parser.add_argument(
        "--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path"
    )
    parser.add_argument(
        "--lora_multiplier",
        type=float,
        nargs="*",
        default=[1.0],
        help="LoRA multiplier (can specify multiple values)",
    )
    parser.add_argument(
        "--save_merged_model", type=str, required=True, help="Path to save the merged model"
    )
    parser.add_argument(
        "--device", type=str, default=default_device, help="Device to use for merging"
    )

    args = parser.parse_args()
    return args
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
    """Merge LoRA weight files into a HunyuanVideo DiT checkpoint and save the result."""
    args = parse_args()

    device = torch.device(args.device)
    logger.info(f"Using device: {device}")

    # Load DiT model
    # The base model is loaded to CPU in bf16; per-weight merging happens on `device`.
    logger.info(f"Loading DiT model from {args.dit}")
    transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16, in_channels=args.dit_in_channels)
    transformer.eval()

    # Load LoRA weights and merge
    weight_paths = args.lora_weight if args.lora_weight is not None else []
    if len(weight_paths) > 0:
        multipliers = args.lora_multiplier
        for idx, weight_path in enumerate(weight_paths):
            # Use the corresponding lora_multiplier or default to 1.0
            if multipliers is not None and idx < len(multipliers):
                lora_multiplier = multipliers[idx]
            else:
                lora_multiplier = 1.0

            logger.info(f"Loading LoRA weights from {weight_path} with multiplier {lora_multiplier}")
            weights_sd = load_file(weight_path)
            network = lora.create_arch_network_from_weights(
                lora_multiplier, weights_sd, unet=transformer, for_inference=True
            )
            logger.info("Merging LoRA weights to DiT model")
            network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)

        logger.info("LoRA weights loaded")

    # Save the merged model
    logger.info(f"Saving merged model to {args.save_merged_model}")
    mem_eff_save_file(transformer.state_dict(), args.save_merged_model)
    logger.info("Merged model saved")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if __name__ == "__main__":
    # Script entry point: run the merge only when executed directly.
    main()
|
modules/__init__.py
ADDED
|
File without changes
|