svjack commited on
Commit
7bc5051
·
verified ·
1 Parent(s): 7c3fd88

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .github/FUNDING.yml +3 -0
  2. .gitignore +8 -0
  3. .ipynb_checkpoints/README-checkpoint.md +123 -0
  4. .python-version +1 -0
  5. README.md +123 -0
  6. cache_latents.py +339 -0
  7. cache_text_encoder_outputs.py +214 -0
  8. convert_lora.py +137 -0
  9. dataset/__init__.py +0 -0
  10. dataset/config_utils.py +384 -0
  11. dataset/dataset_config.md +486 -0
  12. dataset/image_video_dataset.py +1786 -0
  13. fpack_cache_latents.py +454 -0
  14. fpack_cache_text_encoder_outputs.py +110 -0
  15. fpack_generate_video.py +1711 -0
  16. frame_pack/__init__.py +0 -0
  17. frame_pack/bucket_tools.py +30 -0
  18. frame_pack/clip_vision.py +14 -0
  19. frame_pack/framepack_utils.py +273 -0
  20. frame_pack/hunyuan.py +134 -0
  21. frame_pack/hunyuan_video_packed.py +2015 -0
  22. frame_pack/k_diffusion_hunyuan.py +128 -0
  23. frame_pack/uni_pc_fm.py +142 -0
  24. frame_pack/utils.py +617 -0
  25. frame_pack/wrapper.py +51 -0
  26. framepack_edit_output/framepack-edit-lora-000001.safetensors +3 -0
  27. framepack_edit_output/framepack-edit-lora-000002.safetensors +3 -0
  28. framepack_edit_output/framepack-edit-lora-000003.safetensors +3 -0
  29. framepack_edit_output/framepack-edit-lora-000004.safetensors +3 -0
  30. framepack_edit_output/framepack-edit-lora-000005.safetensors +3 -0
  31. framepack_edit_output/framepack-edit-lora-000006.safetensors +3 -0
  32. hunyuan_model/__init__.py +0 -0
  33. hunyuan_model/activation_layers.py +23 -0
  34. hunyuan_model/attention.py +295 -0
  35. hunyuan_model/autoencoder_kl_causal_3d.py +609 -0
  36. hunyuan_model/embed_layers.py +132 -0
  37. hunyuan_model/fp8_optimization.py +39 -0
  38. hunyuan_model/helpers.py +40 -0
  39. hunyuan_model/mlp_layers.py +118 -0
  40. hunyuan_model/models.py +1044 -0
  41. hunyuan_model/modulate_layers.py +76 -0
  42. hunyuan_model/norm_layers.py +79 -0
  43. hunyuan_model/pipeline_hunyuan_video.py +1100 -0
  44. hunyuan_model/posemb_layers.py +310 -0
  45. hunyuan_model/text_encoder.py +710 -0
  46. hunyuan_model/token_refiner.py +245 -0
  47. hunyuan_model/vae.py +446 -0
  48. hv_generate_video.py +936 -0
  49. merge_lora.py +63 -0
  50. modules/__init__.py +0 -0
.github/FUNDING.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # These are supported funding model platforms
2
+
3
+ github: kohya-ss
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ .venv
3
+ venv/
4
+ logs/
5
+ uv.lock
6
+ main.exp
7
+ main.lib
8
+ main.obj
.ipynb_checkpoints/README-checkpoint.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FramePack Image Edit Early Lora
2
+
3
+ This repository contains the necessary steps and scripts to generate an edit of an image using an image-to-video model.
4
+ The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create an edited image based on an input image and textual prompts.
5
+
6
+ ## Prerequisites
7
+
8
+ Before proceeding, ensure that you have the following installed on your system:
9
+
10
+ • **Ubuntu** (or a compatible Linux distribution)
11
+ • **Python 3.x**
12
+ • **pip** (Python package manager)
13
+ • **Git**
14
+ • **Git LFS** (Git Large File Storage)
15
+ • **FFmpeg**
16
+
17
+ ## Installation
18
+
19
+ 1. **Update and Install Dependencies**
20
+
21
+ ```bash
22
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
23
+ ```
24
+
25
+ 2. **Clone the Repository**
26
+
27
+ ```bash
28
+ git clone https://huggingface.co/svjack/FramePack_Image_Edit_Lora_Early
29
+ cd FramePack_Image_Edit_Lora_Early
30
+ ```
31
+
32
+ 3. **Install Python Dependencies**
33
+
34
+ ```bash
35
+ pip install torch torchvision
36
+ pip install -r requirements.txt
37
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
38
+ pip install moviepy==1.0.3
39
+ pip install sageattention==1.0.6
40
+ ```
41
+
42
+ 4. **Download Model Weights**
43
+
44
+ ```bash
45
+ git clone https://huggingface.co/lllyasviel/FramePackI2V_HY
46
+ git clone https://huggingface.co/hunyuanvideo-community/HunyuanVideo
47
+ git clone https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged
48
+ git clone https://huggingface.co/Comfy-Org/sigclip_vision_384
49
+ ```
50
+
51
+ ## Usage
52
+
53
+ To edit an image, use the `fpack_generate_video.py` script with the appropriate parameters. Below are examples of how to do it.
54
+
55
+
56
+ * 1 Add a cat
57
+ - Input
58
+
59
+
60
+ ```python
61
+ python fpack_generate_video.py \
62
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
63
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
64
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
65
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
66
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
67
+ --image_path xiang_image.jpg \
68
+ --prompt "add a cat into the picture" \
69
+ --video_size 512 512 --fps 30 --infer_steps 25 \
70
+ --attn_mode sdpa --fp8_scaled \
71
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
72
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
73
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
74
+ ```
75
+
76
+ - Output
77
+
78
+
79
+ * 2 Change Background
80
+ - Input
81
+
82
+
83
+ ```python
84
+ python fpack_generate_video.py \
85
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
86
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
87
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
88
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
89
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
90
+ --image_path wanye.jpg \
91
+ --prompt "Change the background into a restaurant in anime style. Keep the character's eye colors and white hair unchanged." \
92
+ --video_size 512 512 --fps 30 --infer_steps 25 \
93
+ --attn_mode sdpa --fp8_scaled \
94
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
95
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
96
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
97
+
98
+ ```
99
+
100
+ - Output
101
+
102
+
103
+ * 3 Place Train into landscape
104
+ - Input
105
+
106
+ ```python
107
+ python fpack_generate_video.py \
108
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
109
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
110
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
111
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
112
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
113
+ --image_path train.jpg \
114
+ --prompt "place the train into a beautiful landscape" \
115
+ --video_size 512 512 --fps 30 --infer_steps 25 \
116
+ --attn_mode sdpa --fp8_scaled \
117
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
118
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
119
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
120
+ ```
121
+
122
+ - Output
123
+
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.10
README.md ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FramePack Image Edit Early Lora
2
+
3
+ This repository contains the necessary steps and scripts to generate an edit of an image using an image-to-video model.
4
+ The model leverages LoRA (Low-Rank Adaptation) weights and pre-trained components to create an edited image based on an input image and textual prompts.
5
+
6
+ ## Prerequisites
7
+
8
+ Before proceeding, ensure that you have the following installed on your system:
9
+
10
+ • **Ubuntu** (or a compatible Linux distribution)
11
+ • **Python 3.x**
12
+ • **pip** (Python package manager)
13
+ • **Git**
14
+ • **Git LFS** (Git Large File Storage)
15
+ • **FFmpeg**
16
+
17
+ ## Installation
18
+
19
+ 1. **Update and Install Dependencies**
20
+
21
+ ```bash
22
+ sudo apt-get update && sudo apt-get install cbm git-lfs ffmpeg
23
+ ```
24
+
25
+ 2. **Clone the Repository**
26
+
27
+ ```bash
28
+ git clone https://huggingface.co/svjack/FramePack_Image_Edit_Lora_Early
29
+ cd FramePack_Image_Edit_Lora_Early
30
+ ```
31
+
32
+ 3. **Install Python Dependencies**
33
+
34
+ ```bash
35
+ pip install torch torchvision
36
+ pip install -r requirements.txt
37
+ pip install ascii-magic matplotlib tensorboard huggingface_hub datasets
38
+ pip install moviepy==1.0.3
39
+ pip install sageattention==1.0.6
40
+ ```
41
+
42
+ 4. **Download Model Weights**
43
+
44
+ ```bash
45
+ git clone https://huggingface.co/lllyasviel/FramePackI2V_HY
46
+ git clone https://huggingface.co/hunyuanvideo-community/HunyuanVideo
47
+ git clone https://huggingface.co/Comfy-Org/HunyuanVideo_repackaged
48
+ git clone https://huggingface.co/Comfy-Org/sigclip_vision_384
49
+ ```
50
+
51
+ ## Usage
52
+
53
+ To edit an image, use the `fpack_generate_video.py` script with the appropriate parameters. Below are examples of how to do it.
54
+
55
+
56
+ * 1 Add a cat
57
+ - Input
58
+
59
+
60
+ ```python
61
+ python fpack_generate_video.py \
62
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
63
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
64
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
65
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
66
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
67
+ --image_path xiang_image.jpg \
68
+ --prompt "add a cat into the picture" \
69
+ --video_size 512 512 --fps 30 --infer_steps 25 \
70
+ --attn_mode sdpa --fp8_scaled \
71
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
72
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
73
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
74
+ ```
75
+
76
+ - Output
77
+
78
+
79
+ * 2 Change Background
80
+ - Input
81
+
82
+
83
+ ```python
84
+ python fpack_generate_video.py \
85
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
86
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
87
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
88
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
89
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
90
+ --image_path wanye.jpg \
91
+ --prompt "Change the background into a restaurant in anime style. Keep the character's eye colors and white hair unchanged." \
92
+ --video_size 512 512 --fps 30 --infer_steps 25 \
93
+ --attn_mode sdpa --fp8_scaled \
94
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
95
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
96
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
97
+
98
+ ```
99
+
100
+ - Output
101
+
102
+
103
+ * 3 Place Train into landscape
104
+ - Input
105
+
106
+ ```python
107
+ python fpack_generate_video.py \
108
+ --dit FramePackI2V_HY/diffusion_pytorch_model-00001-of-00003.safetensors \
109
+ --vae HunyuanVideo/vae/diffusion_pytorch_model.safetensors \
110
+ --text_encoder1 HunyuanVideo_repackaged/split_files/text_encoders/llava_llama3_fp16.safetensors \
111
+ --text_encoder2 HunyuanVideo_repackaged/split_files/text_encoders/clip_l.safetensors \
112
+ --image_encoder sigclip_vision_384/sigclip_vision_patch14_384.safetensors \
113
+ --image_path train.jpg \
114
+ --prompt "place the train into a beautiful landscape" \
115
+ --video_size 512 512 --fps 30 --infer_steps 25 \
116
+ --attn_mode sdpa --fp8_scaled \
117
+ --vae_chunk_size 32 --vae_spatial_tile_sample_min_size 128 \
118
+ --save_path save --video_sections 1 --output_type latent_images --one_frame_inference zero_post \
119
+ --seed 1234 --lora_multiplier 1.0 --lora_weight framepack_edit_output/framepack-edit-lora-000005.safetensors
120
+ ```
121
+
122
+ - Output
123
+
cache_latents.py ADDED
@@ -0,0 +1,339 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import glob
4
+ from typing import Optional, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from dataset import config_utils
11
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
12
+ from PIL import Image
13
+
14
+ import logging
15
+
16
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache, ARCHITECTURE_HUNYUAN_VIDEO
17
+ from hunyuan_model.vae import load_vae
18
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
def show_image(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]]) -> int:
    """Preview an image (or the first/last frames of a sequence) in an OpenCV window.

    Returns the key code pressed in the window; callers treat ``q``/``d``
    as quit / next-dataset commands.
    """
    import cv2

    # A single frame (PIL image or HWC ndarray) is shown as-is; for a
    # sequence only the first and last frames are previewed.
    is_single = (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
    preview_frames = [image] if is_single else [image[0], image[-1]]

    if len(preview_frames) > 1:
        print(f"Number of images: {len(image)}")

    key = None
    for idx, frame in enumerate(preview_frames):
        if len(preview_frames) > 1:
            print(f"{'First' if idx == 0 else 'Last'} image: {frame.shape}")
        else:
            print(f"Image: {frame.shape}")
        bgr = np.array(frame) if isinstance(frame, Image.Image) else frame
        bgr = cv2.cvtColor(bgr, cv2.COLOR_RGB2BGR)
        cv2.imshow("image", bgr)
        key = cv2.waitKey(0)
        cv2.destroyAllWindows()
        if key in (ord("q"), ord("d")):
            return key
    return key
48
+
49
+
50
def show_console(
    image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]],
    width: int,
    back: str,
    interactive: bool = False,
) -> int:
    """Render an image (or first/last frames of a sequence) as ASCII art in the console.

    Args:
        image: a single frame or a sequence of frames (PIL Image or ndarray).
        width: terminal column count to render into.
        back: name of an ``ascii_magic.Back`` background color, or None.
        interactive: when True, prompt for a key after each rendering.

    Returns:
        ord() of the pressed key (``q`` quit, ``d`` next dataset), or ord(" ").
    """
    from ascii_magic import from_pillow_image, Back

    # Resolve the background color name to an ascii_magic.Back attribute.
    # FIX: the original unconditionally reset `back = None` before this check,
    # which made the --console_back option a no-op.
    if back is not None:
        back = getattr(Back, back.upper())

    k = None
    imgs = (
        [image]
        if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image)
        else [image[0], image[-1]]
    )
    if len(imgs) > 1:
        print(f"Number of images: {len(image)}")
    for i, img in enumerate(imgs):
        if len(imgs) > 1:
            print(f"{'First' if i == 0 else 'Last'} image: {img.shape}")
        else:
            print(f"Image: {img.shape}")
        pil_img = img if isinstance(img, Image.Image) else Image.fromarray(img)
        ascii_img = from_pillow_image(pil_img)
        ascii_img.to_terminal(columns=width, back=back)

        if interactive:
            k = input("Press q to quit, d to next dataset, other key to next: ")
            if k == "q" or k == "d":
                return ord(k)

    if not interactive:
        return ord(" ")
    return ord(k) if k else ord(" ")
87
+
88
+
89
def save_video(image: Union[list[Union[Image.Image, np.ndarray], Union[Image.Image, np.ndarray]]], cache_path: str, fps: int = 24):
    """Save a debug preview next to a latent cache path.

    A single frame is written as a .jpg; a frame sequence is encoded to an
    .mp4 (H.264, yuv420p, low bitrate) derived from *cache_path* by
    replacing the ".safetensors" suffix.
    """
    import av

    # Ensure the destination directory exists.
    directory = os.path.dirname(cache_path)
    if not os.path.exists(directory):
        os.makedirs(directory)

    if (isinstance(image, np.ndarray) and len(image.shape) == 3) or isinstance(image, Image.Image):
        # save image
        image_path = cache_path.replace(".safetensors", ".jpg")
        img = image if isinstance(image, Image.Image) else Image.fromarray(image)
        img.save(image_path)
        print(f"Saved image: {image_path}")
    else:
        imgs = image
        print(f"Number of images: {len(imgs)}")
        # save video
        video_path = cache_path.replace(".safetensors", ".mp4")
        # frames are assumed HWC arrays of identical size — TODO confirm against caller
        height, width = imgs[0].shape[0:2]

        # create output container
        container = av.open(video_path, mode="w")

        # create video stream
        codec = "libx264"
        pixel_format = "yuv420p"
        stream = container.add_stream(codec, rate=fps)
        stream.width = width
        stream.height = height
        stream.pix_fmt = pixel_format
        stream.bit_rate = 1000000  # 1Mbit/s for preview quality

        # Encode every frame and mux the resulting packets.
        for frame_img in imgs:
            if isinstance(frame_img, Image.Image):
                frame = av.VideoFrame.from_image(frame_img)
            else:
                frame = av.VideoFrame.from_ndarray(frame_img, format="rgb24")
            packets = stream.encode(frame)
            for packet in packets:
                container.mux(packet)

        # Flush the encoder: encode() with no frame drains buffered packets.
        for packet in stream.encode():
            container.mux(packet)

        container.close()

        print(f"Saved video: {video_path}")
136
+
137
+
138
def show_datasets(
    datasets: list[BaseDataset],
    debug_mode: str,
    console_width: int,
    console_back: str,
    console_num_images: Optional[int],
    fps: int = 24,
):
    """Preview dataset contents for debugging.

    *debug_mode* selects the output: "image" (OpenCV window), "console"
    (ASCII art), or "video" (write preview files next to each latent cache
    path). Interactive modes read a key per item: q quits entirely, d skips
    to the next dataset.
    """
    if debug_mode != "video":
        print(f"d: next dataset, q: quit")

    num_workers = max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        print(f"Dataset [{i}]")
        batch_index = 0
        # In non-interactive console mode this counts down the number of
        # images still to show for the current dataset.
        num_images_to_show = console_num_images
        k = None
        for key, batch in dataset.retrieve_latent_cache_batches(num_workers):
            print(f"bucket resolution: {key}, count: {len(batch)}")
            for j, item_info in enumerate(batch):
                item_info: ItemInfo
                print(f"{batch_index}-{j}: {item_info}")
                if debug_mode == "image":
                    k = show_image(item_info.content)
                elif debug_mode == "console":
                    # console mode is interactive only when no image count was given
                    k = show_console(item_info.content, console_width, console_back, console_num_images is None)
                    if num_images_to_show is not None:
                        num_images_to_show -= 1
                        if num_images_to_show == 0:
                            # simulate a "next dataset" keypress once the quota is reached
                            k = ord("d")  # next dataset
                elif debug_mode == "video":
                    save_video(item_info.content, item_info.latent_cache_path, fps)
                    k = None  # save next video

                if k == ord("q"):
                    return
                elif k == ord("d"):
                    break
            # propagate the "next dataset" request out of the batch loop
            if k == ord("d"):
                break
            batch_index += 1
179
+
180
+
181
def encode_and_save_batch(vae: AutoencoderKLCausal3D, batch: list[ItemInfo]):
    """Encode one batch of images/videos into VAE latents and cache them to disk."""
    # Stack item contents into a single tensor. An image batch (B, H, W, C)
    # gains a singleton frame axis so it flows through the video path.
    frames = torch.stack([torch.from_numpy(entry.content) for entry in batch])
    if len(frames.shape) == 4:
        frames = frames.unsqueeze(1)  # B, H, W, C -> B, F, H, W, C

    # Channel-first video layout, moved to the VAE's device/dtype, pixels
    # normalized from [0, 255] to [-1, 1].
    frames = frames.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    frames = frames.to(vae.device, dtype=vae.dtype)
    frames = frames / 127.5 - 1.0  # normalize to [-1, 1]

    height, width = frames.shape[3], frames.shape[4]
    if height < 8 or width < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # Sample latents without tracking gradients; note the VAE scaling factor
    # is intentionally NOT applied here (matches the training-side loader).
    with torch.no_grad():
        latent = vae.encode(frames).latent_dist.sample()

    # Write one latent cache entry per item.
    for entry, lat in zip(batch, latent):
        save_latent_cache(entry, lat)
217
+
218
+
219
def encode_datasets(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Encode latent-cache batches for every dataset and prune stale cache files.

    Args:
        datasets: datasets to process.
        encode: callback that encodes and saves one sub-batch of ItemInfo.
        args: parsed CLI args (uses num_workers, skip_existing, batch_size, keep_cache).
    """
    num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            all_latent_cache_paths.extend([item.latent_cache_path for item in batch])

            if args.skip_existing:
                filtered_batch = [item for item in batch if not os.path.exists(item.latent_cache_path)]
                if len(filtered_batch) == 0:
                    continue
                batch = filtered_batch

            bs = args.batch_size if args.batch_size is not None else len(batch)
            # FIX: use a distinct loop variable; the original reused `i`,
            # shadowing the dataset index of the enclosing loop.
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])

        # normalize paths so set membership is reliable across separators
        all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
        all_latent_cache_paths = set(all_latent_cache_paths)

        # remove old cache files not in the dataset
        all_cache_files = dataset.get_all_latent_cache_files()
        for cache_file in all_cache_files:
            if os.path.normpath(cache_file) not in all_latent_cache_paths:
                if args.keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")
250
+
251
+
252
def main(args):
    """Entry point: build datasets from the config file, then cache VAE latents.

    When --debug_mode is set, previews dataset contents and exits without
    encoding anything.
    """
    # Resolve compute device: CLI override > CUDA if available > CPU.
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # Debug mode: show the dataset contents instead of encoding.
    if args.debug_mode is not None:
        show_datasets(datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images)
        return

    assert args.vae is not None, "vae checkpoint is required"

    # Load VAE model: HunyuanVideo VAE model is float16
    vae_dtype = torch.float16 if args.vae_dtype is None else str_to_dtype(args.vae_dtype)
    vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()
    logger.info(f"Loaded VAE: {vae.config}, dtype: {vae.dtype}")

    # Optional memory-saving settings for the causal 3D VAE: chunked causal
    # convolution and/or spatial tiling.
    if args.vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(args.vae_chunk_size)
        logger.info(f"Set chunk_size to {args.vae_chunk_size} for CausalConv3d in VAE")
    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        # latent-space tile size mirrors the 8x spatial downscale of the VAE
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    elif args.vae_tiling:
        vae.enable_spatial_tiling(True)

    # Encode images
    def encode(one_batch: list[ItemInfo]):
        encode_and_save_batch(vae, one_batch)

    encode_datasets(datasets, encode, args)
292
+
293
+
294
def setup_parser_common() -> argparse.ArgumentParser:
    """Build the argument parser with options common to the caching scripts."""
    parser = argparse.ArgumentParser()

    # Input / model paths and device selection.
    parser.add_argument("--dataset_config", type=str, required=True, help="path to dataset config .toml file")
    parser.add_argument("--vae", type=str, required=False, default=None, help="path to vae checkpoint")
    parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    parser.add_argument("--device", type=str, default=None, help="device to use, default is cuda if available")

    # Batching and worker configuration.
    parser.add_argument(
        "--batch_size", type=int, default=None, help="batch size, override dataset config if dataset batch size > this"
    )
    parser.add_argument("--num_workers", type=int, default=None, help="number of workers for dataset. default is cpu count-1")

    # Cache-management flags.
    parser.add_argument("--skip_existing", action="store_true", help="skip existing cache files")
    parser.add_argument("--keep_cache", action="store_true", help="keep cache files not in dataset")

    # Debug / preview options.
    parser.add_argument("--debug_mode", type=str, default=None, choices=["image", "console", "video"], help="debug mode")
    parser.add_argument("--console_width", type=int, default=80, help="debug mode: console width")
    parser.add_argument(
        "--console_back", type=str, default=None, help="debug mode: console background color, one of ascii_magic.Back"
    )
    parser.add_argument(
        "--console_num_images",
        type=int,
        default=None,
        help="debug mode: not interactive, number of images to show for each dataset",
    )
    return parser
319
+
320
+
321
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add HunyuanVideo-specific VAE options to *parser* and return it."""
    # Spatial tiling trades speed for lower VRAM use during VAE encode.
    parser.add_argument(
        "--vae_tiling",
        action="store_true",
        help="enable spatial tiling for VAE, default is False. If vae_spatial_tile_sample_min_size is set, this is automatically enabled",
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    return parser
332
+
333
+
334
if __name__ == "__main__":
    # Build the CLI (common options plus HunyuanVideo VAE options) and run.
    args = hv_setup_parser(setup_parser_common()).parse_args()
    main(args)
cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ import accelerate
12
+
13
+ from dataset.image_video_dataset import ARCHITECTURE_HUNYUAN_VIDEO, BaseDataset, ItemInfo, save_text_encoder_output_cache
14
+ from hunyuan_model import text_encoder as text_encoder_module
15
+ from hunyuan_model.text_encoder import TextEncoder
16
+
17
+ import logging
18
+
19
+ from utils.model_utils import str_to_dtype
20
+
21
+ logger = logging.getLogger(__name__)
22
+ logging.basicConfig(level=logging.INFO)
23
+
24
+
25
def encode_prompt(text_encoder: TextEncoder, prompt: Union[str, list[str]]):
    """Tokenize and encode prompt(s); return (hidden_state, attention_mask)."""
    # This caching path only supports the "video" data type.
    data_type = "video"  # video only, image is not supported
    tokens = text_encoder.text2tokens(prompt, data_type=data_type)

    with torch.no_grad():
        outputs = text_encoder.encode(tokens, data_type=data_type)

    return outputs.hidden_state, outputs.attention_mask
33
+
34
+
35
def encode_and_save_batch(
    text_encoder: TextEncoder, batch: list[ItemInfo], is_llm: bool, accelerator: Optional[accelerate.Accelerator]
):
    """Encode the captions of a batch and write each result to its cache file.

    *is_llm* tags the cache entries as coming from the LLM text encoder
    (encoder 1) vs. CLIP (encoder 2).
    """
    captions = [item.caption for item in batch]

    # Encode under the accelerator's autocast context when one is supplied
    # (used for fp8/fp16 inference); otherwise encode directly.
    if accelerator is None:
        prompt_embeds, prompt_mask = encode_prompt(text_encoder, captions)
    else:
        with accelerator.autocast():
            prompt_embeds, prompt_mask = encode_prompt(text_encoder, captions)

    # Persist one cache entry per item.
    for item, embed, mask in zip(batch, prompt_embeds, prompt_mask):
        save_text_encoder_output_cache(item, embed, mask, is_llm)
55
+
56
+
57
def prepare_cache_files_and_paths(datasets: list[BaseDataset]):
    """Collect per-dataset cache bookkeeping sets.

    Returns two parallel lists: the set of existing cache files for each
    dataset, and an initially-empty set per dataset that callers fill with
    the cache paths actually referenced.
    """
    all_cache_files_for_dataset = []  # existing cache files
    all_cache_paths_for_dataset = []  # all cache paths in the dataset
    for dataset in datasets:
        existing = {os.path.normpath(f) for f in dataset.get_all_text_encoder_output_cache_files()}
        all_cache_files_for_dataset.append(existing)
        all_cache_paths_for_dataset.append(set())
    return all_cache_files_for_dataset, all_cache_paths_for_dataset
67
+
68
+
69
def process_text_encoder_batches(
    num_workers: Optional[int],
    skip_existing: bool,
    batch_size: int,
    datasets: list[BaseDataset],
    all_cache_files_for_dataset: list[set],
    all_cache_paths_for_dataset: list[set],
    encode: callable,
):
    """Run the text encoder over every dataset's caption batches.

    Args:
        num_workers: dataloader workers; defaults to cpu count - 1.
        skip_existing: skip items whose cache file already exists.
        batch_size: sub-batch size handed to *encode*; None uses the full batch.
        all_cache_files_for_dataset: per-dataset sets of existing cache files.
        all_cache_paths_for_dataset: per-dataset sets, updated in place with
            every cache path referenced (used later for pruning).
        encode: callback that encodes and saves one sub-batch of ItemInfo.
    """
    num_workers = num_workers if num_workers is not None else max(1, os.cpu_count() - 1)
    for i, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{i}]")
        all_cache_files = all_cache_files_for_dataset[i]
        all_cache_paths = all_cache_paths_for_dataset[i]
        for batch in tqdm(dataset.retrieve_text_encoder_output_cache_batches(num_workers)):
            # update cache files (it's ok if we update it multiple times)
            all_cache_paths.update([os.path.normpath(item.text_encoder_output_cache_path) for item in batch])

            # skip existing cache files
            if skip_existing:
                filtered_batch = [
                    item for item in batch if os.path.normpath(item.text_encoder_output_cache_path) not in all_cache_files
                ]
                if len(filtered_batch) == 0:
                    continue
                batch = filtered_batch

            bs = batch_size if batch_size is not None else len(batch)
            # FIX: use a distinct loop variable; the original reused `i`,
            # shadowing the dataset index of the enclosing loop.
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])
100
+
101
+
102
def post_process_cache_files(
    datasets: list[BaseDataset], all_cache_files_for_dataset: list[set], all_cache_paths_for_dataset: list[set], keep_cache: bool
):
    """Remove (or just report, when *keep_cache* is set) cache files that are
    no longer referenced by any dataset item."""
    for idx, _dataset in enumerate(datasets):
        referenced = all_cache_paths_for_dataset[idx]
        for cache_file in all_cache_files_for_dataset[idx]:
            if cache_file in referenced:
                continue
            if keep_cache:
                logger.info(f"Keep cache file not in the dataset: {cache_file}")
            else:
                os.remove(cache_file)
                logger.info(f"Removed old cache file: {cache_file}")
115
+
116
+
117
def main(args):
    """Cache Text Encoder 1 (LLM) and Text Encoder 2 outputs for all datasets in the config.

    Loads each encoder in turn (freeing the first before loading the second),
    encodes every dataset batch, then prunes cache files no longer referenced.
    """
    # Resolve the compute device: explicit --device wins, otherwise CUDA if available.
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_HUNYUAN_VIDEO)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # define accelerator for fp8 inference
    accelerator = None
    if args.fp8_llm:
        accelerator = accelerate.Accelerator(mixed_precision="fp16")

    # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = prepare_cache_files_and_paths(datasets)

    # Load Text Encoder 1
    text_encoder_dtype = torch.float16 if args.text_encoder_dtype is None else str_to_dtype(args.text_encoder_dtype)
    logger.info(f"loading text encoder 1: {args.text_encoder1}")
    text_encoder_1 = text_encoder_module.load_text_encoder_1(args.text_encoder1, device, args.fp8_llm, text_encoder_dtype)
    text_encoder_1.to(device=device)

    # Encode with Text Encoder 1 (LLM)
    logger.info("Encoding with Text Encoder 1")

    def encode_for_text_encoder_1(batch: list[ItemInfo]):
        # closure binding text_encoder_1 / accelerator for process_text_encoder_batches
        encode_and_save_batch(text_encoder_1, batch, is_llm=True, accelerator=accelerator)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_1,
    )
    # free Text Encoder 1 before loading Text Encoder 2 to reduce peak memory
    del text_encoder_1

    # Load Text Encoder 2
    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = text_encoder_module.load_text_encoder_2(args.text_encoder2, device, text_encoder_dtype)
    text_encoder_2.to(device=device)

    # Encode with Text Encoder 2
    logger.info("Encoding with Text Encoder 2")

    def encode_for_text_encoder_2(batch: list[ItemInfo]):
        # Text Encoder 2 never runs under the fp8 accelerator
        encode_and_save_batch(text_encoder_2, batch, is_llm=False, accelerator=None)

    process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder_2,
    )
    del text_encoder_2

    # remove cache files not in dataset
    post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
185
+
186
+
187
def setup_parser_common():
    """Build the argument parser with the options shared by all architectures."""
    parser = argparse.ArgumentParser()

    # (flag, kwargs) pairs for the common options
    common_options = [
        ("--dataset_config", dict(type=str, required=True, help="path to dataset config .toml file")),
        ("--device", dict(type=str, default=None, help="device to use, default is cuda if available")),
        ("--batch_size", dict(type=int, default=None, help="batch size, override dataset config if dataset batch size > this")),
        ("--num_workers", dict(type=int, default=None, help="number of workers for dataset. default is cpu count-1")),
        ("--skip_existing", dict(action="store_true", help="skip existing cache files")),
        ("--keep_cache", dict(action="store_true", help="keep cache files not in dataset")),
    ]
    for flag, kwargs in common_options:
        parser.add_argument(flag, **kwargs)
    return parser
199
+
200
+
201
def hv_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Add the HunyuanVideo-specific text-encoder options to *parser* and return it."""
    hv_options = [
        ("--text_encoder1", dict(type=str, required=True, help="Text Encoder 1 directory")),
        ("--text_encoder2", dict(type=str, required=True, help="Text Encoder 2 directory")),
        ("--text_encoder_dtype", dict(type=str, default=None, help="data type for Text Encoder, default is float16")),
        ("--fp8_llm", dict(action="store_true", help="use fp8 for Text Encoder 1 (LLM)")),
    ]
    for flag, kwargs in hv_options:
        parser.add_argument(flag, **kwargs)
    return parser
207
+
208
+
209
+ if __name__ == "__main__":
210
+ parser = setup_parser_common()
211
+ parser = hv_setup_parser(parser)
212
+
213
+ args = parser.parse_args()
214
+ main(args)
convert_lora.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import torch
4
+ from safetensors.torch import load_file, save_file
5
+ from safetensors import safe_open
6
+ from utils import model_utils
7
+
8
+ import logging
9
+
10
+
11
+ logger = logging.getLogger(__name__)
12
+ logging.basicConfig(level=logging.INFO)
13
+
14
+
15
def convert_from_diffusers(prefix, weights_sd):
    """Convert a Diffusers-style LoRA state dict to the default (sd-scripts style) format.

    Diffusers format: {"diffusion_model.<module>.lora_A.weight": w, "...lora_B.weight": w, ...}
    Default format:   {"<prefix><module_with_underscores>.lora_down.weight": w, "....lora_up.weight": w, ...}

    Diffusers checkpoints carry no alpha, so alpha is set equal to the rank.
    """
    converted = {}
    rank_by_module = {}
    for key, weight in weights_sd.items():
        diffusers_prefix, key_body = key.split(".", 1)
        if diffusers_prefix not in ("diffusion_model", "transformer"):
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        converted[new_key] = weight

        module_key = new_key.split(".")[0]  # part before the first dot
        # lora_down has shape (rank, in_features), so dim 0 is the rank
        if "lora_down" in new_key and module_key not in rank_by_module:
            rank_by_module[module_key] = weight.shape[0]

    # add alpha with rank
    for module_key, rank in rank_by_module.items():
        converted[f"{module_key}.alpha"] = torch.tensor(rank)

    return converted
41
+
42
+
43
def convert_to_diffusers(prefix, weights_sd):
    """Convert a default (sd-scripts style) LoRA state dict to the Diffusers format.

    Alpha tensors are folded into the weights: both down and up are scaled by
    sqrt(alpha / rank), so their product carries the full alpha / rank factor.
    """
    # collect the alpha tensor for each LoRA module first
    alphas = {}
    for key, weight in weights_sd.items():
        if key.startswith(prefix) and "alpha" in key:
            module_key = key.split(".", 1)[0]  # part before the first dot
            alphas.setdefault(module_key, weight)

    converted = {}
    for key, weight in weights_sd.items():
        if not key.startswith(prefix):
            continue
        if "alpha" in key:
            continue

        module_key = key.split(".", 1)[0]  # part before the first dot

        module_name = module_key[len(prefix):].replace("_", ".")  # strip prefix, "_" -> "."
        if ".cross.attn." in module_name or ".self.attn." in module_name:
            # Wan2.1 lora name to module name: ugly but works
            module_name = module_name.replace("cross.attn", "cross_attn")
            module_name = module_name.replace("self.attn", "self_attn")
            module_name = module_name.replace("k.img", "k_img")
            module_name = module_name.replace("v.img", "v_img")
        else:
            # HunyuanVideo lora name to module name: ugly but works
            module_name = module_name.replace("double.blocks.", "double_blocks.")
            module_name = module_name.replace("single.blocks.", "single_blocks.")
            module_name = module_name.replace("img.", "img_")
            module_name = module_name.replace("txt.", "txt_")
            module_name = module_name.replace("attn.", "attn_")

        diffusers_prefix = "diffusion_model"
        if "lora_down" in key:
            new_key = f"{diffusers_prefix}.{module_name}.lora_A.weight"
            rank = weight.shape[0]
        elif "lora_up" in key:
            new_key = f"{diffusers_prefix}.{module_name}.lora_B.weight"
            rank = weight.shape[1]
        else:
            logger.warning(f"unexpected key: {key} in default LoRA format")
            continue

        # scale weight by alpha
        if module_key in alphas:
            # we scale both down and up, so scale is sqrt
            weight = weight * (alphas[module_key] / rank).sqrt()
        else:
            logger.warning(f"missing alpha for {module_key}")

        converted[new_key] = weight

    return converted
101
+
102
+
103
def convert(input_file, output_file, target_format):
    """Load a LoRA safetensors file, convert it to *target_format*, and save it.

    Args:
        input_file: source .safetensors file.
        output_file: destination .safetensors file.
        target_format: "default" (sd-scripts style) or "other" (Diffusers style).

    Raises:
        ValueError: if target_format is neither "default" nor "other".
    """
    logger.info(f"loading {input_file}")
    weights_sd = load_file(input_file)
    # load_file drops metadata, so re-open the file to read it
    with safe_open(input_file, framework="pt") as f:
        metadata = f.metadata()

    logger.info(f"converting to {target_format}")
    prefix = "lora_unet_"
    if target_format == "default":
        new_weights_sd = convert_from_diffusers(prefix, weights_sd)
        metadata = metadata or {}
        # recompute model hashes so tools can identify the converted file
        model_utils.precalculate_safetensors_hashes(new_weights_sd, metadata)
    elif target_format == "other":
        new_weights_sd = convert_to_diffusers(prefix, weights_sd)
    else:
        raise ValueError(f"unknown target format: {target_format}")

    logger.info(f"saving to {output_file}")
    save_file(new_weights_sd, output_file, metadata=metadata)

    logger.info("done")
124
+
125
+
126
def parse_args(argv=None):
    """Parse command-line arguments for the LoRA conversion tool.

    Args:
        argv: optional list of argument strings; defaults to sys.argv[1:].
              Passing an explicit list makes the function testable and reusable.

    Returns:
        argparse.Namespace with input, output and target attributes.
    """
    parser = argparse.ArgumentParser(description="Convert LoRA weights between default and other formats")
    parser.add_argument("--input", type=str, required=True, help="input model file")
    parser.add_argument("--output", type=str, required=True, help="output model file")
    parser.add_argument("--target", type=str, required=True, choices=["other", "default"], help="target format")
    args = parser.parse_args(argv)
    return args
133
+
134
+
135
+ if __name__ == "__main__":
136
+ args = parse_args()
137
+ convert(args.input, args.output, args.target)
dataset/__init__.py ADDED
File without changes
dataset/config_utils.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from dataclasses import (
3
+ asdict,
4
+ dataclass,
5
+ )
6
+ import functools
7
+ import random
8
+ from textwrap import dedent, indent
9
+ import json
10
+ from pathlib import Path
11
+
12
+ # from toolz import curry
13
+ from typing import Dict, List, Optional, Sequence, Tuple, Union
14
+
15
+ import toml
16
+ import voluptuous
17
+ from voluptuous import Any, ExactSequence, MultipleInvalid, Object, Schema
18
+
19
+ from .image_video_dataset import DatasetGroup, ImageDataset, VideoDataset
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
@dataclass
class BaseDatasetParams:
    # Parameters shared by image and video datasets.
    resolution: Tuple[int, int] = (960, 544)  # target resolution as a 2-tuple
    enable_bucket: bool = False  # group items into aspect-ratio buckets
    bucket_no_upscale: bool = False  # never upscale when bucketing
    caption_extension: Optional[str] = None  # e.g. ".txt"; None when using a metadata jsonl file
    batch_size: int = 1
    num_repeats: int = 1  # repeat count, used to balance datasets of different sizes
    cache_directory: Optional[str] = None  # None falls back to the image/video directory
    debug_dataset: bool = False
    architecture: str = "no_default"  # short style like "hv" or "wan"


@dataclass
class ImageDatasetParams(BaseDatasetParams):
    # Image dataset source; one of image_directory / image_jsonl_file is expected.
    image_directory: Optional[str] = None
    image_jsonl_file: Optional[str] = None
    control_directory: Optional[str] = None  # optional control (conditioning) images


@dataclass
class VideoDatasetParams(BaseDatasetParams):
    # Video dataset source; one of video_directory / video_jsonl_file is expected.
    video_directory: Optional[str] = None
    video_jsonl_file: Optional[str] = None
    control_directory: Optional[str] = None  # optional control (conditioning) videos
    target_frames: Sequence[int] = (1,)  # frame counts to extract per clip
    frame_extraction: Optional[str] = "head"  # extraction strategy, e.g. "head" or "full"
    frame_stride: Optional[int] = 1
    frame_sample: Optional[int] = 1
    max_frames: Optional[int] = 129
    source_fps: Optional[float] = None  # source fps; None uses every frame as-is


@dataclass
class DatasetBlueprint:
    # Fully-resolved parameters for one dataset, plus its kind.
    is_image_dataset: bool
    params: Union[ImageDatasetParams, VideoDatasetParams]


@dataclass
class DatasetGroupBlueprint:
    # Ordered collection of dataset blueprints.
    datasets: Sequence[DatasetBlueprint]


@dataclass
class Blueprint:
    # Top-level result of BlueprintGenerator.generate.
    dataset_group: DatasetGroupBlueprint
74
+
75
+
76
class ConfigSanitizer:
    """Validates and normalizes the user dataset config and the argparse namespace.

    Uses voluptuous schemas; the image vs video dataset schema is chosen per
    entry based on the presence of video_directory / video_jsonl_file.
    """

    # @curry
    @staticmethod
    def __validate_and_convert_twodim(klass, value: Sequence) -> Tuple:
        # value must be exactly [klass, klass]; normalized to a tuple
        Schema(ExactSequence([klass, klass]))(value)
        return tuple(value)

    # @curry
    @staticmethod
    def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]) -> Tuple:
        # accept a scalar (expanded to (value, value)) or a two-element sequence
        Schema(Any(klass, ExactSequence([klass, klass])))(value)
        try:
            Schema(klass)(value)
            return (value, value)
        except voluptuous.Invalid:
            # not a scalar: the first Schema above guarantees it is a valid two-element sequence
            # (narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are not swallowed)
            return ConfigSanitizer.__validate_and_convert_twodim(klass, value)

    # datasets schema
    DATASET_ASCENDABLE_SCHEMA = {
        "caption_extension": str,
        "batch_size": int,
        "num_repeats": int,
        "resolution": functools.partial(__validate_and_convert_scalar_or_twodim.__func__, int),
        "enable_bucket": bool,
        "bucket_no_upscale": bool,
    }
    IMAGE_DATASET_DISTINCT_SCHEMA = {
        "image_directory": str,
        "image_jsonl_file": str,
        "cache_directory": str,
        "control_directory": str,
    }
    VIDEO_DATASET_DISTINCT_SCHEMA = {
        "video_directory": str,
        "video_jsonl_file": str,
        "control_directory": str,
        "target_frames": [int],
        "frame_extraction": str,
        "frame_stride": int,
        "frame_sample": int,
        "max_frames": int,
        "cache_directory": str,
        "source_fps": float,
    }

    # options handled by argparse but not handled by user config
    ARGPARSE_SPECIFIC_SCHEMA = {
        "debug_dataset": bool,
    }

    def __init__(self) -> None:
        self.image_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.IMAGE_DATASET_DISTINCT_SCHEMA,
        )
        self.video_dataset_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
            self.VIDEO_DATASET_DISTINCT_SCHEMA,
        )

        def validate_flex_dataset(dataset_config: dict):
            # pick the schema based on which source keys the entry provides
            if "video_directory" in dataset_config or "video_jsonl_file" in dataset_config:
                return Schema(self.video_dataset_schema)(dataset_config)
            else:
                return Schema(self.image_dataset_schema)(dataset_config)

        self.dataset_schema = validate_flex_dataset

        self.general_schema = self.__merge_dict(
            self.DATASET_ASCENDABLE_SCHEMA,
        )
        self.user_config_validator = Schema(
            {
                "general": self.general_schema,
                "datasets": [self.dataset_schema],
            }
        )
        self.argparse_schema = self.__merge_dict(
            self.ARGPARSE_SPECIFIC_SCHEMA,
        )
        self.argparse_config_validator = Schema(Object(self.argparse_schema), extra=voluptuous.ALLOW_EXTRA)

    def sanitize_user_config(self, user_config: dict) -> dict:
        """Validate the loaded TOML/JSON config; re-raises voluptuous.MultipleInvalid on error."""
        try:
            return self.user_config_validator(user_config)
        except MultipleInvalid:
            # TODO: clarify the error message
            logger.error("Invalid user config / ユーザ設定の形式が正しくないようです")
            raise

    # NOTE: In nature, argument parser result is not needed to be sanitize
    # However this will help us to detect program bug
    def sanitize_argparse_namespace(self, argparse_namespace: argparse.Namespace) -> argparse.Namespace:
        try:
            return self.argparse_config_validator(argparse_namespace)
        except MultipleInvalid:
            # XXX: this should be a bug
            logger.error(
                "Invalid cmdline parsed arguments. This should be a bug. / コマンドラインのパース結果が正しくないようです。プログラムのバグの可能性が高いです。"
            )
            raise

    # NOTE: value would be overwritten by latter dict if there is already the same key
    @staticmethod
    def __merge_dict(*dict_list: dict) -> dict:
        merged = {}
        for schema in dict_list:
            # merged |= schema
            for k, v in schema.items():
                merged[k] = v
        return merged
187
+
188
+
189
class BlueprintGenerator:
    """Turns a sanitized user config plus argparse values into a Blueprint of dataset params."""

    # maps blueprint param names to differently-named config options (currently none)
    BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME = {}

    def __init__(self, sanitizer: ConfigSanitizer):
        self.sanitizer = sanitizer

    # runtime_params is for parameters which is only configurable on runtime, such as tokenizer
    def generate(self, user_config: dict, argparse_namespace: argparse.Namespace, **runtime_params) -> Blueprint:
        """Build a Blueprint, resolving each param via dataset config > general config > argparse > runtime."""
        sanitized_user_config = self.sanitizer.sanitize_user_config(user_config)
        sanitized_argparse_namespace = self.sanitizer.sanitize_argparse_namespace(argparse_namespace)

        # drop argparse values that are None so they don't mask config values in the fallback chain
        argparse_config = {k: v for k, v in vars(sanitized_argparse_namespace).items() if v is not None}
        general_config = sanitized_user_config.get("general", {})

        dataset_blueprints = []
        for dataset_config in sanitized_user_config.get("datasets", []):
            # image source keys present -> image dataset; otherwise treated as a video dataset
            is_image_dataset = "image_directory" in dataset_config or "image_jsonl_file" in dataset_config
            if is_image_dataset:
                dataset_params_klass = ImageDatasetParams
            else:
                dataset_params_klass = VideoDatasetParams

            params = self.generate_params_by_fallbacks(
                dataset_params_klass, [dataset_config, general_config, argparse_config, runtime_params]
            )
            dataset_blueprints.append(DatasetBlueprint(is_image_dataset, params))

        dataset_group_blueprint = DatasetGroupBlueprint(dataset_blueprints)

        return Blueprint(dataset_group_blueprint)

    @staticmethod
    def generate_params_by_fallbacks(param_klass, fallbacks: Sequence[dict]):
        """Instantiate *param_klass*, taking each field from the first dict in *fallbacks* that defines it."""
        name_map = BlueprintGenerator.BLUEPRINT_PARAM_NAME_TO_CONFIG_OPTNAME
        search_value = BlueprintGenerator.search_value
        default_params = asdict(param_klass())
        param_names = default_params.keys()

        params = {name: search_value(name_map.get(name, name), fallbacks, default_params.get(name)) for name in param_names}

        return param_klass(**params)

    @staticmethod
    def search_value(key: str, fallbacks: Sequence[dict], default_value=None):
        """Return the first non-None value of *key* across *fallbacks*, else *default_value*."""
        for cand in fallbacks:
            value = cand.get(key)
            if value is not None:
                return value

        return default_value
239
+
240
+
241
# if training is True, it will return a dataset group for training, otherwise for caching
def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlueprint, training: bool = False) -> DatasetGroup:
    """Instantiate every dataset described by the blueprint and return a DatasetGroup.

    Validates cache-directory uniqueness, logs a human-readable summary of each
    dataset, and seeds all datasets identically so bucketing is reproducible.
    """
    datasets: List[Union[ImageDataset, VideoDataset]] = []

    for dataset_blueprint in dataset_group_blueprint.datasets:
        if dataset_blueprint.is_image_dataset:
            dataset_klass = ImageDataset
        else:
            dataset_klass = VideoDataset

        # dataclass fields map 1:1 onto the dataset constructor's keyword args
        dataset = dataset_klass(**asdict(dataset_blueprint.params))
        datasets.append(dataset)

    # assertion: cache directories must be unique, otherwise caches would collide
    cache_directories = [dataset.cache_directory for dataset in datasets]
    num_of_unique_cache_directories = len(set(cache_directories))
    if num_of_unique_cache_directories != len(cache_directories):
        raise ValueError(
            "cache directory should be unique for each dataset (note that cache directory is image/video directory if not specified)"
            + " / cache directory は各データセットごとに異なる必要があります(指定されていない場合はimage/video directoryが使われるので注意)"
        )

    # print info
    info = ""
    for i, dataset in enumerate(datasets):
        is_image_dataset = isinstance(dataset, ImageDataset)
        info += dedent(
            f"""\
            [Dataset {i}]
              is_image_dataset: {is_image_dataset}
              resolution: {dataset.resolution}
              batch_size: {dataset.batch_size}
              num_repeats: {dataset.num_repeats}
              caption_extension: "{dataset.caption_extension}"
              enable_bucket: {dataset.enable_bucket}
              bucket_no_upscale: {dataset.bucket_no_upscale}
              cache_directory: "{dataset.cache_directory}"
              debug_dataset: {dataset.debug_dataset}
            """
        )

        if is_image_dataset:
            info += indent(
                dedent(
                    f"""\
                    image_directory: "{dataset.image_directory}"
                    image_jsonl_file: "{dataset.image_jsonl_file}"
                    control_directory: "{dataset.control_directory}"
                    \n"""
                ),
                "    ",
            )
        else:
            info += indent(
                dedent(
                    f"""\
                    video_directory: "{dataset.video_directory}"
                    video_jsonl_file: "{dataset.video_jsonl_file}"
                    control_directory: "{dataset.control_directory}"
                    target_frames: {dataset.target_frames}
                    frame_extraction: {dataset.frame_extraction}
                    frame_stride: {dataset.frame_stride}
                    frame_sample: {dataset.frame_sample}
                    max_frames: {dataset.max_frames}
                    source_fps: {dataset.source_fps}
                    \n"""
                ),
                "    ",
            )
    logger.info(f"{info}")

    # make buckets first because it determines the length of dataset
    # and set the same seed for all datasets
    seed = random.randint(0, 2**31)  # actual seed is seed + epoch_no
    for i, dataset in enumerate(datasets):
        # logger.info(f"[Dataset {i}]")
        dataset.set_seed(seed)
        if training:
            dataset.prepare_for_training()

    return DatasetGroup(datasets)
322
+
323
+
324
def load_user_config(file: str) -> dict:
    """Load a dataset config from a .json or .toml file and return it as a dict.

    Raises ValueError for a missing file or an unsupported extension; parse
    errors from the json/toml libraries are logged and re-raised.
    """
    path = Path(file)
    if not path.is_file():
        raise ValueError(f"file not found / ファイルが見つかりません: {path}")

    lower_name = path.name.lower()
    if lower_name.endswith(".json"):
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception:
            logger.error(
                f"Error on parsing JSON config file. Please check the format. / JSON 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {path}"
            )
            raise
    if lower_name.endswith(".toml"):
        try:
            return toml.load(path)
        except Exception:
            logger.error(
                f"Error on parsing TOML config file. Please check the format. / TOML 形式の設定ファイルの読み込みに失敗しました。文法が正しいか確認してください。: {path}"
            )
            raise
    raise ValueError(f"not supported config file format / 対応していない設定ファイルの形式です: {path}")
350
+
351
+
352
# for config test
if __name__ == "__main__":
    # Standalone check: parse a dataset config path, sanitize the config, build
    # the blueprint and the dataset group, logging each intermediate result.
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset_config")
    config_args, remain = parser.parse_known_args()

    # remaining args are re-parsed for the argparse-only options
    parser = argparse.ArgumentParser()
    parser.add_argument("--debug_dataset", action="store_true")
    argparse_namespace = parser.parse_args(remain)

    logger.info("[argparse_namespace]")
    logger.info(f"{vars(argparse_namespace)}")

    user_config = load_user_config(config_args.dataset_config)

    logger.info("")
    logger.info("[user_config]")
    logger.info(f"{user_config}")

    sanitizer = ConfigSanitizer()
    sanitized_user_config = sanitizer.sanitize_user_config(user_config)

    logger.info("")
    logger.info("[sanitized_user_config]")
    logger.info(f"{sanitized_user_config}")

    blueprint = BlueprintGenerator(sanitizer).generate(user_config, argparse_namespace)

    logger.info("")
    logger.info("[blueprint]")
    logger.info(f"{blueprint}")

    dataset_group = generate_dataset_group_by_blueprint(blueprint.dataset_group)
dataset/dataset_config.md ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ > 📝 Click on the language section to expand / 言語をクリックして展開
2
+
3
+ ## Dataset Configuration
4
+
5
+ Please create a TOML file for dataset configuration.
6
+
7
+ Image and video datasets are supported. The configuration file can include multiple datasets, either image or video datasets, with caption text files or metadata JSONL files.
8
+
9
+ The cache directory must be different for each dataset.
10
+
11
+ Each video is extracted frame by frame without additional processing and used for training. It is recommended to use videos with a frame rate of 24fps for HunyuanVideo, 16fps for Wan2.1 and 30fps for FramePack. You can check the videos that will be used for training with `--debug_mode video` when caching latents (see [here](/README.md#latent-caching)).
12
+ <details>
13
+ <summary>日本語</summary>
14
+
15
+ データセットの設定を行うためのTOMLファイルを作成してください。
16
+
17
+ 画像データセットと動画データセットがサポートされています。設定ファイルには、画像または動画データセットを複数含めることができます。キャプションテキストファイルまたはメタデータJSONLファイルを使用できます。
18
+
19
+ キャッシュディレクトリは、各データセットごとに異なるディレクトリである必要があります。
20
+
21
+ 動画は追加のプロセスなしでフレームごとに抽出され、学習に用いられます。そのため、HunyuanVideoは24fps、Wan2.1は16fps、FramePackは30fpsのフレームレートの動画を使用することをお勧めします。latentキャッシュ時の`--debug_mode video`を使用すると、学習される動画を確認できます([こちら](/README.ja.md#latentの事前キャッシュ)を参照)。
22
+ </details>
23
+
24
+ ### Sample for Image Dataset with Caption Text Files
25
+
26
+ ```toml
27
+ # resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
28
+ # otherwise, the default values will be used for each item
29
+
30
+ # general configurations
31
+ [general]
32
+ resolution = [960, 544]
33
+ caption_extension = ".txt"
34
+ batch_size = 1
35
+ enable_bucket = true
36
+ bucket_no_upscale = false
37
+
38
+ [[datasets]]
39
+ image_directory = "/path/to/image_dir"
40
+ cache_directory = "/path/to/cache_directory"
41
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
42
+
43
+ # other datasets can be added here. each dataset can have different configurations
44
+ ```
45
+
46
+ `cache_directory` is optional; the default is None, which uses the same directory as the image directory. However, we recommend setting the cache directory explicitly to avoid accidentally sharing cache files between different datasets.
47
+
48
+ `num_repeats` is also available. It is optional, default is 1 (no repeat). It repeats the images (or videos) that many times to expand the dataset. For example, if `num_repeats = 2` and there are 20 images in the dataset, each image will be duplicated twice (with the same caption) to have a total of 40 images. It is useful to balance the multiple datasets with different sizes.
49
+
50
+ <details>
51
+ <summary>日本語</summary>
52
+
53
+ `cache_directory` はオプションです。デフォルトは画像ディレクトリと同じディレクトリに設定されます。ただし、異なるデータセット間でキャッシュファイルが共有されるのを防ぐために、明示的に別のキャッシュディレクトリを設定することをお勧めします。
54
+
55
+ `num_repeats` はオプションで、デフォルトは 1 です(繰り返しなし)。画像(や動画)を、その回数だけ単純に繰り返してデータセットを拡張します。たとえば`num_repeats = 2`としたとき、画像20枚のデータセットなら、各画像が2枚ずつ(同一のキャプションで)計40枚存在した場合と同じになります。異なるデータ数のデータセット間でバランスを取るために使用可能です。
56
+
57
+ resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
58
+
59
+ `[[datasets]]`以下を追加することで、他のデータセットを追加できます。各データセットには異なる設定を持てます。
60
+ </details>
61
+
62
+ ### Sample for Image Dataset with Metadata JSONL File
63
+
64
+ ```toml
65
+ # resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale should be set in either general or datasets
66
+ # caption_extension is not required for metadata jsonl file
67
+ # cache_directory is required for each dataset with metadata jsonl file
68
+
69
+ # general configurations
70
+ [general]
71
+ resolution = [960, 544]
72
+ batch_size = 1
73
+ enable_bucket = true
74
+ bucket_no_upscale = false
75
+
76
+ [[datasets]]
77
+ image_jsonl_file = "/path/to/metadata.jsonl"
78
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
79
+ num_repeats = 1 # optional, default is 1. Same as above.
80
+
81
+ # other datasets can be added here. each dataset can have different configurations
82
+ ```
83
+
84
+ JSONL file format for metadata:
85
+
86
+ ```json
87
+ {"image_path": "/path/to/image1.jpg", "caption": "A caption for image1"}
88
+ {"image_path": "/path/to/image2.jpg", "caption": "A caption for image2"}
89
+ ```
90
+
91
+ <details>
92
+ <summary>日本語</summary>
93
+
94
+ resolution, batch_size, num_repeats, enable_bucket, bucket_no_upscale は general または datasets のどちらかに設定してください。省略時は各項目のデフォルト値が使用されます。
95
+
96
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
97
+
98
+ キャプションによるデータセットと同様に、複数のデータセットを追加できます。各データセットには異なる設定を持てます。
99
+ </details>
100
+
101
+
102
+ ### Sample for Video Dataset with Caption Text Files
103
+
104
+ ```toml
105
+ # Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
106
+ # can be set in either general or datasets sections
107
+ # Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
108
+ # must be set in each datasets section
109
+
110
+ # general configurations
111
+ [general]
112
+ resolution = [960, 544]
113
+ caption_extension = ".txt"
114
+ batch_size = 1
115
+ enable_bucket = true
116
+ bucket_no_upscale = false
117
+
118
+ [[datasets]]
119
+ video_directory = "/path/to/video_dir"
120
+ cache_directory = "/path/to/cache_directory" # recommended to set cache directory
121
+ target_frames = [1, 25, 45]
122
+ frame_extraction = "head"
123
+ source_fps = 30.0 # optional, source fps for videos in the directory, decimal number
124
+
125
+ [[datasets]]
126
+ video_directory = "/path/to/video_dir2"
127
+ cache_directory = "/path/to/cache_directory2" # recommended to set cache directory
128
+ frame_extraction = "full"
129
+ max_frames = 45
130
+
131
+ # other datasets can be added here. each dataset can have different configurations
132
+ ```
133
+
134
+ __In HunyuanVideo and Wan2.1, the number of `target_frames` must be "N\*4+1" (N=0,1,2,...).__ Otherwise, it will be truncated to the nearest "N*4+1".
135
+
136
+ In FramePack, it is recommended to set `frame_extraction` to `full` and `max_frames` to a sufficiently large value, as it can handle longer videos. However, if the video is too long, an Out of Memory error may occur during VAE encoding. The videos in FramePack are trimmed to "N * latent_window_size * 4 + 1" frames (for example, 37, 73, 109... if `latent_window_size` is 9).
137
+
138
+ If the `source_fps` is specified, the videos in the directory are considered to be at this frame rate, and some frames will be skipped to match the model's frame rate (24 for HunyuanVideo and 16 for Wan2.1). __The value must be a decimal number, for example, `30.0` instead of `30`.__ The skipping is done automatically and does not consider the content of the images. Please check if the converted data is correct using `--debug_mode video`.
139
+
140
+ If `source_fps` is not specified (default), all frames of the video will be used regardless of the video's frame rate.
141
+
142
+ <details>
143
+ <summary>日本語</summary>
144
+
145
+ 共通パラメータ(resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)は、generalまたはdatasetsのいずれかに設定できます。
146
+ 動画固有のパラメータ(target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)は、各datasetsセクションに設定する必要があります。
147
+
148
+ __HunyuanVideoおよびWan2.1では、target_framesの数値は「N\*4+1」である必要があります。__ これ以外の値の場合は、最も近いN\*4+1の値に切り捨てられます。
149
+
150
+ FramePackでも同様ですが、FramePackでは動画が長くても学習可能なため、 `frame_extraction`に`full` を指定し、`max_frames`を十分に大きな値に設定することをお勧めします。ただし、あまりにも長すぎるとVAEのencodeでOut of Memoryエラーが発生する可能性があります。FramePackの動画は、「N * latent_window_size * 4 + 1」フレームにトリミングされます(latent_window_sizeが9の場合、37、73、109……)。
151
+
152
+ `source_fps`を指定した場合、ディレクトリ内の動画をこのフレームレートとみなして、モデルのフレームレートにあうようにいくつかのフレームをスキップします(HunyuanVideoは24、Wan2.1は16)。__小数点を含む数値で指定してください。__ 例:`30`ではなく`30.0`。スキップは機械的に行われ、画像の内容は考慮しません。変換後のデータが正しいか、`--debug_mode video`で確認してください。
153
+
154
+ `source_fps`を指定しない場合、動画のフレームは(動画自体のフレームレートに関係なく)すべて使用されます。
155
+
156
+ 他の注意事項は画像データセットと同様です。
157
+ </details>
158
+
159
+ ### Sample for Video Dataset with Metadata JSONL File
160
+
161
+ ```toml
162
+ # Common parameters (resolution, caption_extension, batch_size, num_repeats, enable_bucket, bucket_no_upscale)
163
+ # can be set in either general or datasets sections
164
+ # Video-specific parameters (target_frames, frame_extraction, frame_stride, frame_sample, max_frames, source_fps)
165
+ # must be set in each datasets section
166
+
167
+ # caption_extension is not required for metadata jsonl file
168
+ # cache_directory is required for each dataset with metadata jsonl file
169
+
170
+ # general configurations
171
+ [general]
172
+ resolution = [960, 544]
173
+ batch_size = 1
174
+ enable_bucket = true
175
+ bucket_no_upscale = false
176
+
177
+ [[datasets]]
178
+ video_jsonl_file = "/path/to/metadata.jsonl"
179
+ target_frames = [1, 25, 45]
180
+ frame_extraction = "head"
181
+ cache_directory = "/path/to/cache_directory_head"
182
+ source_fps = 30.0 # optional, source fps for videos in the jsonl file
183
+ # same metadata jsonl file can be used for multiple datasets
184
+ [[datasets]]
185
+ video_jsonl_file = "/path/to/metadata.jsonl"
186
+ target_frames = [1]
187
+ frame_stride = 10
188
+ cache_directory = "/path/to/cache_directory_stride"
189
+
190
+ # other datasets can be added here. each dataset can have different configurations
191
+ ```
192
+
193
+ JSONL file format for metadata:
194
+
195
+ ```json
196
+ {"video_path": "/path/to/video1.mp4", "caption": "A caption for video1"}
197
+ {"video_path": "/path/to/video2.mp4", "caption": "A caption for video2"}
198
+ ```
199
+
200
+ `video_path` can be a directory containing multiple images.
201
+
202
+ <details>
203
+ <summary>日本語</summary>
204
+ metadata jsonl ファイルを使用する場合、caption_extension は必要ありません。また、cache_directory は必須です。
205
+
206
+ `video_path`は、複数の画像を含むディレクトリのパスでも構いません。
207
+
208
+ 他の注意事項は今までのデータセットと同様です。
209
+ </details>
210
+
211
+ ### frame_extraction Options
212
+
213
+ - `head`: Extract the first N frames from the video.
214
+ - `chunk`: Extract frames by splitting the video into chunks of N frames.
215
+ - `slide`: Extract frames from the video with a stride of `frame_stride`.
216
+ - `uniform`: Extract `frame_sample` samples uniformly from the video.
217
+ - `full`: Extract all frames from the video.
218
+
219
+ In the case of `full`, the entire video is used, but it is trimmed to "N*4+1" frames. It is also trimmed to the `max_frames` if it exceeds that value. To avoid Out of Memory errors, please set `max_frames`.
220
+
221
+ The frame extraction methods other than `full` are recommended when the video contains repeated actions. `full` is recommended when each video represents a single complete motion.
222
+
223
+ For example, consider a video with 40 frames. The following diagrams illustrate each extraction:
224
+
225
+ <details>
226
+ <summary>日本語</summary>
227
+
228
+ - `head`: 動画から最初のNフレームを抽出します。
229
+ - `chunk`: 動画をNフレームずつに分割してフレームを抽出します。
230
+ - `slide`: `frame_stride`に指定したフレームごとに動画からNフレームを抽出します。
231
+ - `uniform`: 動画から一定間隔で、`frame_sample`個のNフレームを抽出します。
232
+ - `full`: 動画から全てのフレームを抽出します。
233
+
234
+ `full`の場合、各動画の全体を用いますが、「N*4+1」のフレーム数にトリミングされます。また`max_frames`を超える場合もその値にトリミングされます。Out of Memoryエラーを避けるために、`max_frames`を設定してください。
235
+
236
+ `full`以外の抽出方法は、動画が特定の動作を繰り返している場合にお勧めします。`full`はそれぞれの動画がひとつの完結したモーションの場合にお勧めします。
237
+
238
+ 例えば、40フレームの動画を例とした抽出について、以下の図で説明します。
239
+ </details>
240
+
241
+ ```
242
+ Original Video, 40 frames: x = frame, o = no frame
243
+ oooooooooooooooooooooooooooooooooooooooo
244
+
245
+ head, target_frames = [1, 13, 25] -> extract head frames:
246
+ xooooooooooooooooooooooooooooooooooooooo
247
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
248
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
249
+
250
+ chunk, target_frames = [13, 25] -> extract frames by splitting into chunks, into 13 and 25 frames:
251
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
252
+ oooooooooooooxxxxxxxxxxxxxoooooooooooooo
253
+ ooooooooooooooooooooooooooxxxxxxxxxxxxxo
254
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
255
+
256
+ NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will make the all frames to be extracted.
257
+ 注: frame_extraction "chunk" を使用する場合、target_frames に 1 を含めないでください。全てのフレームが抽出されてしまいます。
258
+
259
+ slide, target_frames = [1, 13, 25], frame_stride = 10 -> extract N frames with a stride of 10:
260
+ xooooooooooooooooooooooooooooooooooooooo
261
+ ooooooooooxooooooooooooooooooooooooooooo
262
+ ooooooooooooooooooooxooooooooooooooooooo
263
+ ooooooooooooooooooooooooooooooxooooooooo
264
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
265
+ ooooooooooxxxxxxxxxxxxxooooooooooooooooo
266
+ ooooooooooooooooooooxxxxxxxxxxxxxooooooo
267
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
268
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
269
+
270
+ uniform, target_frames = [1, 13, 25], frame_sample = 4 -> extract `frame_sample` samples uniformly, N frames each:
271
+ xooooooooooooooooooooooooooooooooooooooo
272
+ oooooooooooooxoooooooooooooooooooooooooo
273
+ oooooooooooooooooooooooooxoooooooooooooo
274
+ ooooooooooooooooooooooooooooooooooooooox
275
+ xxxxxxxxxxxxxooooooooooooooooooooooooooo
276
+ oooooooooxxxxxxxxxxxxxoooooooooooooooooo
277
+ ooooooooooooooooooxxxxxxxxxxxxxooooooooo
278
+ oooooooooooooooooooooooooooxxxxxxxxxxxxx
279
+ xxxxxxxxxxxxxxxxxxxxxxxxxooooooooooooooo
280
+ oooooxxxxxxxxxxxxxxxxxxxxxxxxxoooooooooo
281
+ ooooooooooxxxxxxxxxxxxxxxxxxxxxxxxxooooo
282
+ oooooooooooooooxxxxxxxxxxxxxxxxxxxxxxxxx
283
+
284
+ Three Original Videos, 20, 25, 35 frames: x = frame, o = no frame
285
+
286
+ full, max_frames = 31 -> extract all frames (trimmed to the maximum length):
287
+ video1: xxxxxxxxxxxxxxxxx (trimmed to 17 frames)
288
+ video2: xxxxxxxxxxxxxxxxxxxxxxxxx (25 frames)
289
+ video3: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx (trimmed to 31 frames)
290
+ ```
291
+
292
+ ### Sample for Image Dataset with Control Images
293
+
294
+ The dataset with control images is used for the single-frame training of FramePack.
295
+
296
+ The dataset configuration with caption text files is similar to the image dataset, but with an additional `control_directory` parameter.
297
+
298
+ The control images are used from the `control_directory` with the same filename (or different extension) as the image, for example, `image_dir/image1.jpg` and `control_dir/image1.png`. The images in `image_directory` should be the target images (the images to be generated during inference, the changed images). The `control_directory` should contain the starting images for inference. The captions should be stored in `image_directory`.
299
+
300
+ The metadata JSONL file format is the same as the image dataset, but with an additional `control_path` parameter.
301
+
302
+ ```json
303
+ {"image_path": "/path/to/image1.jpg", "control_path": "/path/to/control1.png", "caption": "A caption for image1"}
304
+ {"image_path": "/path/to/image2.jpg", "control_path": "/path/to/control2.png", "caption": "A caption for image2"}
305
+ ```
306
+
307
+ <details>
308
+ <summary>日本語</summary>
309
+ 制御画像を持つデータセットです。FramePackの単一フレーム学習に使用します。
310
+
311
+ キャプションファイルを用いる場合は`control_directory`を追加で指定してください。制御用画像は、画像と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある画像が使用されます(例:`image_dir/image1.jpg`と`control_dir/image1.png`)。`image_directory`の画像は学習対象の画像(推論時に生成する画像、変化後の画像)としてください。`control_directory`には推論時の開始画像を格納してください。キャプションは`image_directory`へ格納してください。
312
+
313
+ メタデータJSONLファイルを使用する場合は、`control_path`を追加してください。
314
+ </details>
315
+
316
+ ### Sample for Video Dataset with Control Images
317
+
318
+ The dataset with control videos is used for training ControlNet models.
319
+
320
+ The dataset configuration with caption text files is similar to the video dataset, but with an additional `control_directory` parameter.
321
+
322
+ The control video for a video is used from the `control_directory` with the same filename (or different extension) as the video, for example, `video_dir/video1.mp4` and `control_dir/video1.mp4` or `control_dir/video1.mov`. The control video can also be a directory without an extension, for example, `video_dir/video1.mp4` and `control_dir/video1`.
323
+
324
+ ```toml
325
+ [[datasets]]
326
+ video_directory = "/path/to/video_dir"
327
+ control_directory = "/path/to/control_dir" # required for dataset with control videos
328
+ cache_directory = "/path/to/cache_directory" # recommended to set cache directory
329
+ target_frames = [1, 25, 45]
330
+ frame_extraction = "head"
331
+ ```
332
+
333
+ The dataset configuration with metadata JSONL file is same as the video dataset, but metadata JSONL file must include the control video paths. The control video path can be a directory containing multiple images.
334
+
335
+ ```json
336
+ {"video_path": "/path/to/video1.mp4", "control_path": "/path/to/control1.mp4", "caption": "A caption for video1"}
337
+ {"video_path": "/path/to/video2.mp4", "control_path": "/path/to/control2.mp4", "caption": "A caption for video2"}
338
+ ```
339
+
340
+ <details>
341
+ <summary>日本語</summary>
342
+ 制御動画を持つデータセットです。ControlNetモデルの学習に使用します。
343
+
344
+ キャプションを用いる場合のデータセット設定は動画データセットと似ていますが、`control_directory`パラメータが追加されています。上にある例を参照してください。ある動画に対する制御用動画として、動画と同じファイル名(または拡張子のみが異なるファイル名)の、`control_directory`にある動画が使用されます(例:`video_dir/video1.mp4`と`control_dir/video1.mp4`または`control_dir/video1.mov`)。また、拡張子なしのディレクトリ内の、複数枚の画像を制御用動画として使用することもできます(例:`video_dir/video1.mp4`と`control_dir/video1`)。
345
+
346
+ データセット設定でメタデータJSONLファイルを使用する場合は、動画と制御用動画のパスを含める必要があります。制御用動画のパスは、複数枚の画像を含むディレクトリのパスでも構いません。
347
+ </details>
348
+
349
+ ## Specifications
350
+
351
+ ```toml
352
+ # general configurations
353
+ [general]
354
+ resolution = [960, 544] # optional, [W, H], default is [960, 544]. This is the default resolution for all datasets
355
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
356
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
357
+ num_repeats = 1 # optional, default is 1. Number of times to repeat the dataset. Useful to balance the multiple datasets with different sizes.
358
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
359
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
360
+
361
+ ### Image Dataset
362
+
363
+ # sample image dataset with caption text files
364
+ [[datasets]]
365
+ image_directory = "/path/to/image_dir"
366
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
367
+ resolution = [960, 544] # required if general resolution is not set
368
+ batch_size = 4 # optional, overwrite the default batch size
369
+ num_repeats = 1 # optional, overwrite the default num_repeats
370
+ enable_bucket = false # optional, overwrite the default bucketing setting
371
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
372
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
373
+ control_directory = "/path/to/control_dir" # optional, required for dataset with control images
374
+
375
+ # sample image dataset with metadata **jsonl** file
376
+ [[datasets]]
377
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
378
+ resolution = [960, 544] # required if general resolution is not set
379
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
380
+ # caption_extension is not required for metadata jsonl file
381
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
382
+
383
+ ### Video Dataset
384
+
385
+ # sample video dataset with caption text files
386
+ [[datasets]]
387
+ video_directory = "/path/to/video_dir"
388
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
389
+ resolution = [960, 544] # required if general resolution is not set
390
+
391
+ control_directory = "/path/to/control_dir" # optional, required for dataset with control images
392
+
393
+ # following configurations must be set in each [[datasets]] section for video datasets
394
+
395
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
396
+
397
+ # NOTE: Please do not include 1 in target_frames if you are using the frame_extraction "chunk". It will cause all frames to be extracted.
398
+
399
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
400
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
401
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
402
+ max_frames = 129 # optional, default is 129. Maximum number of frames to extract, available for "full" frame extraction
403
+ # batch_size, num_repeats, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
404
+
405
+ # sample video dataset with metadata jsonl file
406
+ [[datasets]]
407
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
408
+
409
+ target_frames = [1, 79]
410
+
411
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
412
+ # frame_extraction, frame_stride, frame_sample, max_frames are also available for metadata jsonl file
413
+ ```
414
+
415
+ <!--
416
+ # sample image dataset with lance
417
+ [[datasets]]
418
+ image_lance_dataset = "/path/to/lance_dataset"
419
+ resolution = [960, 544] # required if general resolution is not set
420
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
421
+ -->
422
+
423
+ The metadata with .json file will be supported in the near future.
424
+
425
+
426
+
427
+ <!--
428
+
429
+ ```toml
430
+ # general configurations
431
+ [general]
432
+ resolution = [960, 544] # optional, [W, H], default is None. This is the default resolution for all datasets
433
+ caption_extension = ".txt" # optional, default is None. This is the default caption extension for all datasets
434
+ batch_size = 1 # optional, default is 1. This is the default batch size for all datasets
435
+ enable_bucket = true # optional, default is false. Enable bucketing for datasets
436
+ bucket_no_upscale = false # optional, default is false. Disable upscaling for bucketing. Ignored if enable_bucket is false
437
+
438
+ # sample image dataset with caption text files
439
+ [[datasets]]
440
+ image_directory = "/path/to/image_dir"
441
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
442
+ resolution = [960, 544] # required if general resolution is not set
443
+ batch_size = 4 # optional, overwrite the default batch size
444
+ enable_bucket = false # optional, overwrite the default bucketing setting
445
+ bucket_no_upscale = true # optional, overwrite the default bucketing setting
446
+ cache_directory = "/path/to/cache_directory" # optional, default is None to use the same directory as the image directory. NOTE: caching is always enabled
447
+
448
+ # sample image dataset with metadata **jsonl** file
449
+ [[datasets]]
450
+ image_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of image files and captions
451
+ resolution = [960, 544] # required if general resolution is not set
452
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
453
+ # caption_extension is not required for metadata jsonl file
454
+ # batch_size, enable_bucket, bucket_no_upscale are also available for metadata jsonl file
455
+
456
+ # sample video dataset with caption text files
457
+ [[datasets]]
458
+ video_directory = "/path/to/video_dir"
459
+ caption_extension = ".txt" # required for caption text files, if general caption extension is not set
460
+ resolution = [960, 544] # required if general resolution is not set
461
+ target_frames = [1, 25, 79] # required for video dataset. list of video lengths to extract frames. each element must be N*4+1 (N=0,1,2,...)
462
+ frame_extraction = "head" # optional, "head" or "chunk", "slide", "uniform". Default is "head"
463
+ frame_stride = 1 # optional, default is 1, available for "slide" frame extraction
464
+ frame_sample = 4 # optional, default is 1 (same as "head"), available for "uniform" frame extraction
465
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for video dataset
466
+
467
+ # sample video dataset with metadata jsonl file
468
+ [[datasets]]
469
+ video_jsonl_file = "/path/to/metadata.jsonl" # includes pairs of video files and captions
470
+ target_frames = [1, 79]
471
+ cache_directory = "/path/to/cache_directory" # required for metadata jsonl file
472
+ # frame_extraction, frame_stride, frame_sample are also available for metadata jsonl file
473
+ ```
474
+
475
+ # sample image dataset with lance
476
+ [[datasets]]
477
+ image_lance_dataset = "/path/to/lance_dataset"
478
+ resolution = [960, 544] # required if general resolution is not set
479
+ # batch_size, enable_bucket, bucket_no_upscale, cache_directory are also available for lance dataset
480
+
481
+ The metadata with .json file will be supported in the near future.
482
+
483
+
484
+
485
+
486
+ -->
dataset/image_video_dataset.py ADDED
@@ -0,0 +1,1786 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from concurrent.futures import ThreadPoolExecutor
2
+ import glob
3
+ import json
4
+ import math
5
+ import os
6
+ import random
7
+ import time
8
+ from typing import Optional, Sequence, Tuple, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ from safetensors.torch import save_file, load_file
13
+ from safetensors import safe_open
14
+ from PIL import Image
15
+ import cv2
16
+ import av
17
+
18
+ from utils import safetensors_utils
19
+ from utils.model_utils import dtype_to_str
20
+
21
+ import logging
22
+
23
+ logger = logging.getLogger(__name__)
24
+ logging.basicConfig(level=logging.INFO)
25
+
26
+
27
# Supported still-image extensions. Both lower- and upper-case spellings are
# listed because glob matching is case-sensitive on some filesystems.
IMAGE_EXTENSIONS = [".png", ".jpg", ".jpeg", ".webp", ".bmp", ".PNG", ".JPG", ".JPEG", ".WEBP", ".BMP"]

# Optional plugin formats: importing each module registers a Pillow codec as a
# side effect. Catch ImportError specifically — a bare `except:` would also
# swallow unrelated failures (e.g. a broken plugin install) and hide real bugs.
try:
    import pillow_avif  # noqa: F401  # registers AVIF support in Pillow

    IMAGE_EXTENSIONS.extend([".avif", ".AVIF"])
except ImportError:
    pass

# JPEG-XL on Linux
try:
    from jxlpy import JXLImagePlugin  # noqa: F401

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass

# JPEG-XL on Windows
try:
    import pillow_jxl  # noqa: F401

    IMAGE_EXTENSIONS.extend([".jxl", ".JXL"])
except ImportError:
    pass
51
+
52
# Video extensions recognized when globbing datasets; the upper-case variants
# are derived from the lower-case list so both spellings match on
# case-sensitive filesystems. Some of them are not tested.
_video_exts_lower = [".mp4", ".webm", ".avi", ".mkv", ".mov", ".flv", ".wmv", ".m4v", ".mpg", ".mpeg"]
VIDEO_EXTENSIONS = _video_exts_lower + [ext.upper() for ext in _video_exts_lower]
74
+
75
# Architecture identifiers. The *_FULL names are written into the safetensors
# metadata "architecture" key by save_latent_cache_common(); the short names
# are presumably used in cache file naming — confirm against the cache path code.
ARCHITECTURE_HUNYUAN_VIDEO = "hv"
ARCHITECTURE_HUNYUAN_VIDEO_FULL = "hunyuan_video"
ARCHITECTURE_WAN = "wan"
ARCHITECTURE_WAN_FULL = "wan"  # NOTE: short and full identifiers coincide for Wan
ARCHITECTURE_FRAMEPACK = "fp"
ARCHITECTURE_FRAMEPACK_FULL = "framepack"
81
+
82
+
83
def glob_images(directory, base="*"):
    """Collect image files under `directory` matching `base` plus any supported extension.

    Returns a sorted, de-duplicated list of paths. With the default
    `base == "*"` only the directory part is glob-escaped (so the wildcard
    still matches); otherwise the whole path including `base` is escaped,
    i.e. `base` is treated as a literal file name.
    """
    found = set()
    for ext in IMAGE_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))  # set handles duplicate matches across extensions
    return sorted(found)
93
+
94
+
95
def glob_videos(directory, base="*"):
    """Collect video files under `directory` matching `base` plus any supported extension.

    Returns a sorted, de-duplicated list of paths. With the default
    `base == "*"` only the directory part is glob-escaped (so the wildcard
    still matches); otherwise the whole path including `base` is escaped,
    i.e. `base` is treated as a literal file name.
    """
    found = set()
    for ext in VIDEO_EXTENSIONS:
        if base == "*":
            pattern = os.path.join(glob.escape(directory), base + ext)
        else:
            pattern = glob.escape(os.path.join(directory, base + ext))
        found.update(glob.glob(pattern))  # set handles duplicate matches across extensions
    return sorted(found)
105
+
106
+
107
def divisible_by(num: int, divisor: int) -> int:
    """Round `num` down to the nearest multiple of `divisor`."""
    return (num // divisor) * divisor
109
+
110
+
111
def resize_image_to_bucket(image: Union[Image.Image, np.ndarray], bucket_reso: tuple[int, int]) -> np.ndarray:
    """Resize and center-crop `image` to the bucket resolution.

    bucket_reso: **(width, height)**

    The image is scaled so the shorter relative side covers the bucket
    (upscaling via PIL LANCZOS, downscaling via cv2 INTER_AREA), then
    center-cropped. Returns a numpy array.
    """
    from_pil = isinstance(image, Image.Image)
    if from_pil:
        src_w, src_h = image.size
    else:
        src_h, src_w = image.shape[:2]

    # Exact match: nothing to do beyond ensuring an ndarray result.
    if bucket_reso == (src_w, src_h):
        return np.array(image) if from_pil else image

    tgt_w, tgt_h = bucket_reso
    if tgt_w == src_w or tgt_h == src_h:
        # One side already matches — only the center crop below is needed.
        image = np.array(image) if from_pil else image
    else:
        # Scale to cover the bucket on both axes, rounding half up.
        scale = max(tgt_w / src_w, tgt_h / src_h)
        src_w = int(src_w * scale + 0.5)
        src_h = int(src_h * scale + 0.5)

        if scale > 1:
            # Upscale with PIL's LANCZOS filter.
            pil_img = image if from_pil else Image.fromarray(image)
            image = np.array(pil_img.resize((src_w, src_h), Image.LANCZOS))
        else:
            # Downscale with OpenCV's area interpolation.
            image = np.array(image) if from_pil else image
            image = cv2.resize(image, (src_w, src_h), interpolation=cv2.INTER_AREA)

    # Center crop to the bucket resolution.
    left = (src_w - tgt_w) // 2
    top = (src_h - tgt_h) // 2
    return image[top : top + tgt_h, left : left + tgt_w]
150
+
151
+
152
class ItemInfo:
    """Metadata for one dataset item (an image or a video clip).

    Carries the caption, original/bucket sizes, optional raw pixel content,
    and the paths of the cached latent and text-encoder-output files.
    """

    def __init__(
        self,
        item_key: str,
        caption: str,
        original_size: tuple[int, int],
        bucket_size: Optional[Union[tuple[int, int], tuple[int, int, int]]] = None,
        frame_count: Optional[int] = None,
        content: Optional[np.ndarray] = None,
        latent_cache_path: Optional[str] = None,
    ) -> None:
        self.item_key = item_key
        self.caption = caption
        self.original_size = original_size  # (width, height) of the source media
        self.bucket_size = bucket_size  # assigned bucket; presumably (W, H) or (W, H, frames) — confirm at call sites
        self.frame_count = frame_count  # number of frames for videos, None for images
        self.content = content  # raw pixel data, populated during caching
        self.latent_cache_path = latent_cache_path
        self.text_encoder_output_cache_path: Optional[str] = None  # set later by the caching pipeline
        self.control_content: Optional[np.ndarray] = None  # optional control image/video pixels

    def __str__(self) -> str:
        content_shape = self.content.shape if self.content is not None else None
        return (
            f"ItemInfo(item_key={self.item_key}, caption={self.caption}, "
            f"original_size={self.original_size}, bucket_size={self.bucket_size}, "
            f"frame_count={self.frame_count}, latent_cache_path={self.latent_cache_path}, content={content_shape})"
        )
179
+
180
+
181
+ # We use simple if-else approach to support multiple architectures.
182
+ # Maybe we can use a plugin system in the future.
183
+
184
+ # the keys of the dict are `<content_type>_FxHxW_<dtype>` for latents
185
+ # and `<content_type>_<dtype|mask>` for other tensors
186
+
187
+
188
def save_latent_cache(item_info: ItemInfo, latent: torch.Tensor):
    """Save a latent cache file for the HunyuanVideo architecture.

    HunyuanVideo doesn't support I2V and control latents, so only the plain
    latent tensor is stored.

    Args:
        item_info: item whose `latent_cache_path` receives the file.
        latent: 4D latent tensor. The cache key encodes the last three
            dimensions as FxHxW, so dim 0 must be the channel axis.
    """
    # The previous message claimed (frame, channel, ...), but the unpacking
    # below uses dim 1 as the frame count in the key — dim 0 is channels.
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"

    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    save_latent_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
197
+
198
+
199
def save_latent_cache_wan(
    item_info: ItemInfo,
    latent: torch.Tensor,
    clip_embed: Optional[torch.Tensor],
    image_latent: Optional[torch.Tensor],
    control_latent: Optional[torch.Tensor],
):
    """Save a latent cache file for the Wan architecture.

    Args:
        item_info: item whose `latent_cache_path` receives the file.
        latent: 4D latent tensor; the cache key encodes dims 1..3 as FxHxW,
            so dim 0 must be the channel axis.
        clip_embed: optional CLIP embedding (presumably for I2V — confirm).
        image_latent: optional image-conditioning latent.
        control_latent: optional control latent.
    """
    # The previous message claimed (frame, channel, ...), but the unpacking
    # below uses dim 1 as the frame count in the key — dim 0 is channels.
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"

    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu()}

    if clip_embed is not None:
        sd[f"clip_{dtype_str}"] = clip_embed.detach().cpu()

    if image_latent is not None:
        sd[f"latents_image_{F}x{H}x{W}_{dtype_str}"] = image_latent.detach().cpu()

    if control_latent is not None:
        sd[f"latents_control_{F}x{H}x{W}_{dtype_str}"] = control_latent.detach().cpu()

    save_latent_cache_common(item_info, sd, ARCHITECTURE_WAN_FULL)
223
+
224
+
225
def save_latent_cache_framepack(
    item_info: ItemInfo,
    latent: torch.Tensor,
    latent_indices: torch.Tensor,
    clean_latents: torch.Tensor,
    clean_latent_indices: torch.Tensor,
    clean_latents_2x: torch.Tensor,
    clean_latent_2x_indices: torch.Tensor,
    clean_latents_4x: torch.Tensor,
    clean_latent_4x_indices: torch.Tensor,
    image_embeddings: torch.Tensor,
):
    """Save a latent cache file for the FramePack architecture.

    Stores the target latent, the "clean" context latents at the 1x/2x/4x
    levels with their index tensors, and the image encoder embeddings.

    Note: the FxHxW suffix of every `latents_*` key comes from the main
    `latent` tensor's shape, not from the individual clean latents.
    """
    # The previous message claimed (frame, channel, ...), but the unpacking
    # below uses dim 1 as the frame count in the key — dim 0 is channels.
    assert latent.dim() == 4, "latent should be 4D tensor (channel, frame, height, width)"

    _, F, H, W = latent.shape
    dtype_str = dtype_to_str(latent.dtype)
    sd = {f"latents_{F}x{H}x{W}_{dtype_str}": latent.detach().cpu().contiguous()}

    # `latents_xxx` keys must carry the {F}x{H}x{W} suffix; index tensors only
    # carry their dtype.
    indices_dtype_str = dtype_to_str(latent_indices.dtype)
    sd[f"image_embeddings_{dtype_str}"] = image_embeddings.detach().cpu()  # image embeddings dtype is same as latents dtype
    sd[f"latent_indices_{indices_dtype_str}"] = latent_indices.detach().cpu()
    sd[f"clean_latent_indices_{indices_dtype_str}"] = clean_latent_indices.detach().cpu()
    sd[f"clean_latent_2x_indices_{indices_dtype_str}"] = clean_latent_2x_indices.detach().cpu()
    sd[f"clean_latent_4x_indices_{indices_dtype_str}"] = clean_latent_4x_indices.detach().cpu()
    sd[f"latents_clean_{F}x{H}x{W}_{dtype_str}"] = clean_latents.detach().cpu().contiguous()
    sd[f"latents_clean_2x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_2x.detach().cpu().contiguous()
    sd[f"latents_clean_4x_{F}x{H}x{W}_{dtype_str}"] = clean_latents_4x.detach().cpu().contiguous()

    save_latent_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
258
+
259
+
260
def save_latent_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Write `sd` to `item_info.latent_cache_path` as a safetensors file.

    Adds architecture/size metadata and zeroes out NaNs (mutating the tensors
    in `sd` in place). Shared by all architecture-specific save helpers.
    """
    metadata = {
        "architecture": arch_fullname,
        "width": f"{item_info.original_size[0]}",
        "height": f"{item_info.original_size[1]}",
        "format_version": "1.0.1",
    }
    if item_info.frame_count is not None:
        metadata["frame_count"] = f"{item_info.frame_count}"

    # NaNs would poison training; warn and replace them with 0 in place.
    for key, value in sd.items():
        nan_mask = torch.isnan(value)
        if nan_mask.any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[nan_mask] = 0

    # Make sure the destination directory exists before writing.
    os.makedirs(os.path.dirname(item_info.latent_cache_path), exist_ok=True)
    save_file(sd, item_info.latent_cache_path, metadata=metadata)
280
+
281
+
282
def save_text_encoder_output_cache(item_info: ItemInfo, embed: torch.Tensor, mask: Optional[torch.Tensor], is_llm: bool):
    """HunyuanVideo architecture only"""
    assert (
        embed.dim() in (1, 2)
    ), f"embed should be 2D tensor (feature, hidden_size) or (hidden_size,), got {embed.shape}"
    assert mask is None or mask.dim() == 1, f"mask should be 1D tensor (feature), got {mask.shape}"

    # key layout: "<llm|clipL>_<dtype>", plus an optional "<llm|clipL>_mask"
    prefix = "llm" if is_llm else "clipL"
    sd = {f"{prefix}_{dtype_to_str(embed.dtype)}": embed.detach().cpu()}
    if mask is not None:
        sd[f"{prefix}_mask"] = mask.detach().cpu()

    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_HUNYUAN_VIDEO_FULL)
297
+
298
+
299
def save_text_encoder_output_cache_wan(item_info: ItemInfo, embed: torch.Tensor):
    """Wan architecture only. Wan2.1 only has a single text encoder"""
    # the "varlen_" prefix marks tensors that must not be stacked at batch time
    key = f"varlen_t5_{dtype_to_str(embed.dtype)}"
    save_text_encoder_output_cache_common(item_info, {key: embed.detach().cpu()}, ARCHITECTURE_WAN_FULL)
308
+
309
+
310
def save_text_encoder_output_cache_framepack(
    item_info: ItemInfo, llama_vec: torch.Tensor, llama_attention_mask: torch.Tensor, clip_l_pooler: torch.Tensor
):
    """FramePack architecture only."""
    # embedding keys carry a dtype suffix; the attention mask key does not
    sd = {
        f"llama_vec_{dtype_to_str(llama_vec.dtype)}": llama_vec.detach().cpu(),
        "llama_attention_mask": llama_attention_mask.detach().cpu(),
        f"clip_l_pooler_{dtype_to_str(clip_l_pooler.dtype)}": clip_l_pooler.detach().cpu(),
    }
    save_text_encoder_output_cache_common(item_info, sd, ARCHITECTURE_FRAMEPACK_FULL)
322
+
323
+
324
def save_text_encoder_output_cache_common(item_info: ItemInfo, sd: dict[str, torch.Tensor], arch_fullname: str):
    """Merge `sd` into the item's existing text-encoder cache file (if any) and save it.

    New tensors in `sd` take precedence over tensors already in the file; existing
    metadata other than caption/format_version is preserved. NaNs are replaced with 0.
    """
    for key, value in sd.items():
        # NaN check and show warning, replace NaN with 0
        if torch.isnan(value).any():
            logger.warning(f"{key} tensor has NaN: {item_info.item_key}, replace NaN with 0")
            value[torch.isnan(value)] = 0

    metadata = {
        "architecture": arch_fullname,
        "caption1": item_info.caption,
        "format_version": "1.0.1",
    }

    if os.path.exists(item_info.text_encoder_output_cache_path):
        # load existing cache and update metadata
        with safetensors_utils.MemoryEfficientSafeOpen(item_info.text_encoder_output_cache_path) as f:
            existing_metadata = f.metadata()
            for key in f.keys():
                if key not in sd:  # avoid overwriting by existing cache, we keep the new one
                    sd[key] = f.get_tensor(key)

        assert existing_metadata["architecture"] == metadata["architecture"], "architecture mismatch"
        if existing_metadata["caption1"] != metadata["caption1"]:
            logger.warning(f"caption mismatch: existing={existing_metadata['caption1']}, new={metadata['caption1']}, overwrite")
        # TODO verify format_version

        existing_metadata.pop("caption1", None)
        existing_metadata.pop("format_version", None)
        metadata.update(existing_metadata)  # copy existing metadata except caption and format_version
    else:
        text_encoder_output_dir = os.path.dirname(item_info.text_encoder_output_cache_path)
        os.makedirs(text_encoder_output_dir, exist_ok=True)

    safetensors_utils.mem_eff_save_file(sd, item_info.text_encoder_output_cache_path, metadata=metadata)
358
+
359
+
360
class BucketSelector:
    """Maps an arbitrary image size to the nearest bucket resolution.

    Buckets are (width, height) pairs of approximately equal area
    (resolution[0] * resolution[1]) whose sides are divisible by the
    architecture's resolution step.
    """

    RESOLUTION_STEPS_HUNYUAN = 16
    RESOLUTION_STEPS_WAN = 16
    RESOLUTION_STEPS_FRAMEPACK = 16

    def __init__(
        self, resolution: Tuple[int, int], enable_bucket: bool = True, no_upscale: bool = False, architecture: str = "no_default"
    ):
        self.resolution = resolution
        self.bucket_area = resolution[0] * resolution[1]
        self.architecture = architecture

        # pick the side-length granularity for this architecture
        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            self.reso_steps = BucketSelector.RESOLUTION_STEPS_FRAMEPACK
        else:
            raise ValueError(f"Invalid architecture: {self.architecture}")

        if not enable_bucket:
            # only define one bucket
            self.bucket_resolutions = [resolution]
            self.no_upscale = False
        else:
            # prepare bucket resolution
            self.no_upscale = no_upscale
            sqrt_size = int(math.sqrt(self.bucket_area))
            min_size = divisible_by(sqrt_size // 2, self.reso_steps)
            self.bucket_resolutions = []
            # widths from half the square size up to the square size; each width also
            # contributes its transposed (portrait) bucket
            for w in range(min_size, sqrt_size + self.reso_steps, self.reso_steps):
                h = divisible_by(self.bucket_area // w, self.reso_steps)
                self.bucket_resolutions.append((w, h))
                self.bucket_resolutions.append((h, w))

            self.bucket_resolutions = list(set(self.bucket_resolutions))
            self.bucket_resolutions.sort()

        # calculate aspect ratio to find the nearest resolution
        self.aspect_ratios = np.array([w / h for w, h in self.bucket_resolutions])

    def get_bucket_resolution(self, image_size: tuple[int, int]) -> tuple[int, int]:
        """
        return the bucket resolution for the given image size, (width, height)
        """
        area = image_size[0] * image_size[1]
        if self.no_upscale and area <= self.bucket_area:
            # image fits the target area: keep its own size, rounded down to the step
            w, h = image_size
            w = divisible_by(w, self.reso_steps)
            h = divisible_by(h, self.reso_steps)
            return w, h

        # otherwise choose the bucket with the closest aspect ratio
        aspect_ratio = image_size[0] / image_size[1]
        ar_errors = self.aspect_ratios - aspect_ratio
        bucket_id = np.abs(ar_errors).argmin()
        return self.bucket_resolutions[bucket_id]
417
+
418
+
419
def _iter_video_frames(video_path: str):
    """Yield the frames of `video_path` as PIL images, in order.

    `video_path` may be a video file (decoded with PyAV) or a directory of image
    files (sorted by file name). The decoder/container is closed when the
    generator is closed or exhausted.
    """
    if os.path.isfile(video_path):
        container = av.open(video_path)
        try:
            for frame in container.decode(video=0):
                yield frame.to_image()
        finally:
            container.close()
    else:
        # load images in the directory
        image_files = glob_images(video_path)
        image_files.sort()
        for image_file in image_files:
            yield Image.open(image_file).convert("RGB")


def load_video(
    video_path: str,
    start_frame: Optional[int] = None,
    end_frame: Optional[int] = None,
    bucket_selector: Optional[BucketSelector] = None,
    bucket_reso: Optional[tuple[int, int]] = None,
    source_fps: Optional[float] = None,
    target_fps: Optional[float] = None,
) -> list[np.ndarray]:
    """
    bucket_reso: if given, resize the video to the bucket resolution, (width, height)

    If both source_fps and target_fps are given, frames are dropped so the output
    approximates target_fps; otherwise all frames in [start_frame, end_frame) are kept.
    start_frame/end_frame are indices in the (possibly fps-converted) frame sequence.
    """
    # With no fps conversion each source frame maps to itself (delta == 1.0), so the
    # same decimation loop handles both cases (commonized per the original TODO).
    if source_fps is None or target_fps is None:
        frame_index_delta = 1.0
    else:
        frame_index_delta = target_fps / source_fps  # example: 16 / 30 = 0.5333

    video = []
    frames = _iter_video_frames(video_path)
    try:
        frame_index_with_fraction = 0.0
        previous_frame_index = -1
        for image in frames:
            target_frame_index = int(frame_index_with_fraction)
            frame_index_with_fraction += frame_index_delta

            if target_frame_index == previous_frame_index:  # drop this frame
                continue

            # accept this frame
            previous_frame_index = target_frame_index

            if start_frame is not None and target_frame_index < start_frame:
                continue
            if end_frame is not None and target_frame_index >= end_frame:
                break

            if bucket_selector is not None and bucket_reso is None:
                bucket_reso = bucket_selector.get_bucket_resolution(image.size)  # calc resolution from first frame

            # resize_image_to_bucket accepts PIL images (the original video branch
            # already passed PIL); without a bucket we return raw ndarray frames
            if bucket_reso is not None:
                image = resize_image_to_bucket(image, bucket_reso)
            else:
                image = np.array(image)

            video.append(image)
    finally:
        frames.close()  # release the av container / file handles promptly on early break

    return video
541
+
542
+
543
class BucketBatchManager:
    """Groups cached items into fixed-size batches per bucket resolution and serves them by index."""

    def __init__(self, bucketed_item_info: dict[Union[tuple[int, int], tuple[int, int, int]], list[ItemInfo]], batch_size: int):
        self.batch_size = batch_size
        self.buckets = bucketed_item_info
        self.bucket_resos = list(self.buckets.keys())
        self.bucket_resos.sort()

        # indices for enumerating batches. each batch is reso + batch_idx. reso is (width, height) or (width, height, frames)
        self.bucket_batch_indices: list[tuple[Union[tuple[int, int], tuple[int, int, int]], int]] = []
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            num_batches = math.ceil(len(bucket) / self.batch_size)
            for i in range(num_batches):
                self.bucket_batch_indices.append((bucket_reso, i))

        # do no shuffle here to avoid multiple datasets have different order
        # self.shuffle()

    def show_bucket_info(self):
        # log the item count of each bucket and the total batch count
        for bucket_reso in self.bucket_resos:
            bucket = self.buckets[bucket_reso]
            logger.info(f"bucket: {bucket_reso}, count: {len(bucket)}")

        logger.info(f"total batches: {len(self)}")

    def shuffle(self):
        # shuffle each bucket
        for bucket in self.buckets.values():
            random.shuffle(bucket)

        # shuffle the order of batches
        random.shuffle(self.bucket_batch_indices)

    def __len__(self):
        return len(self.bucket_batch_indices)

    def __getitem__(self, idx):
        """Load cached latents and text-encoder outputs for one batch and collate them.

        Cache keys are decoded by stripping the optional "varlen_" prefix, the dtype
        suffix, and (for latents) the FxHxW suffix; "varlen_" tensors are returned as
        lists instead of being stacked.
        """
        bucket_reso, batch_idx = self.bucket_batch_indices[idx]
        bucket = self.buckets[bucket_reso]
        start = batch_idx * self.batch_size
        end = min(start + self.batch_size, len(bucket))

        batch_tensor_data = {}
        varlen_keys = set()
        for item_info in bucket[start:end]:
            sd_latent = load_file(item_info.latent_cache_path)
            sd_te = load_file(item_info.text_encoder_output_cache_path)
            sd = {**sd_latent, **sd_te}

            # TODO refactor this
            for key in sd.keys():
                is_varlen_key = key.startswith("varlen_")  # varlen keys are not stacked
                content_key = key

                if is_varlen_key:
                    content_key = content_key.replace("varlen_", "")

                if content_key.endswith("_mask"):
                    pass  # mask keys carry no dtype suffix
                else:
                    content_key = content_key.rsplit("_", 1)[0]  # remove dtype
                    if content_key.startswith("latents_"):
                        content_key = content_key.rsplit("_", 1)[0]  # remove FxHxW

                if content_key not in batch_tensor_data:
                    batch_tensor_data[content_key] = []
                batch_tensor_data[content_key].append(sd[key])

                if is_varlen_key:
                    varlen_keys.add(content_key)

        # stack everything except variable-length tensors into batch tensors
        for key in batch_tensor_data.keys():
            if key not in varlen_keys:
                batch_tensor_data[key] = torch.stack(batch_tensor_data[key])

        return batch_tensor_data
620
+
621
+
622
class ContentDatasource:
    """Abstract base for content sources (images or videos) with captions.

    Subclasses provide indexed access and/or iterator access to their items and
    report which is supported via is_indexable().
    """

    def __init__(self):
        self.caption_only = False  # set to True to only fetch caption for Text Encoder caching
        self.has_control = False  # subclasses set this to True when paired control content exists

    def set_caption_only(self, caption_only: bool):
        self.caption_only = caption_only

    def is_indexable(self):
        # default: only sequential iteration is supported
        return False

    def get_caption(self, idx: int) -> tuple[str, str]:
        """
        Returns caption. May not be called if is_indexable() returns False.
        """
        raise NotImplementedError

    def __len__(self):
        raise NotImplementedError

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
647
+
648
+
649
class ImageDatasource(ContentDatasource):
    """Base class for image sources; adds image fetching on top of ContentDatasource."""

    def __init__(self):
        super().__init__()

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str]:
        """
        Returns image data as a tuple of image path, image, and caption for the given index.
        Key must be unique and valid as a file name.
        May not be called if is_indexable() returns False.
        """
        raise NotImplementedError
660
+
661
+
662
class ImageDirectoryDatasource(ImageDatasource):
    """Image datasource backed by a directory of images, with optional caption
    files (same basename + caption_extension) and optional control images in a
    parallel directory."""

    def __init__(self, image_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        super().__init__()
        self.image_directory = image_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory
        self.current_idx = 0

        # glob images
        logger.info(f"glob images in {self.image_directory}")
        self.image_paths = glob_images(self.image_directory)
        logger.info(f"found {len(self.image_paths)} images")

        # glob control images if specified
        if self.control_directory is not None:
            logger.info(f"glob control images in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}
            for image_path in self.image_paths:
                image_basename = os.path.basename(image_path)
                control_path = os.path.join(self.control_directory, image_basename)
                if os.path.exists(control_path):
                    self.control_paths[image_path] = control_path
                else:
                    # another extension for control path
                    # for example: image_path = "img/image.png" -> control_path = "control/image.jpg"
                    image_basename_no_ext = os.path.splitext(image_basename)[0]
                    for ext in IMAGE_EXTENSIONS:
                        potential_path = os.path.join(self.control_directory, image_basename_no_ext + ext)
                        if os.path.exists(potential_path):
                            self.control_paths[image_path] = potential_path
                            break

            logger.info(f"found {len(self.control_paths)} matching control images")
            # every image must have a control counterpart; fail fast otherwise
            missing_controls = len(self.image_paths) - len(self.control_paths)
            if missing_controls > 0:
                missing_control_paths = set(self.image_paths) - set(self.control_paths.keys())
                logger.error(f"Could not find matching control images for {missing_controls} images: {missing_control_paths}")
                raise ValueError(f"Could not find matching control images for {missing_controls} images")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.image_paths)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Returns (image_path, image, caption, control image or None) for the index."""
        image_path = self.image_paths[idx]
        image = Image.open(image_path).convert("RGB")

        _, caption = self.get_caption(idx)

        control = None
        if self.has_control:
            control_path = self.control_paths[image_path]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        image_path = self.image_paths[idx]
        # NOTE(review): by conditional-expression precedence, caption_path is "" when
        # caption_extension is falsy, so open() below would fail — presumably callers
        # always set an extension for directory datasources; confirm
        caption_path = os.path.splitext(image_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        """
        Returns a fetcher function that returns image data.
        """
        # fetchers defer the actual file I/O so callers can parallelize loading
        if self.current_idx >= len(self.image_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)
        else:

            def create_image_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_image_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
754
+
755
+
756
class ImageJsonlDatasource(ImageDatasource):
    """Image datasource driven by a JSONL file; each line holds image_path, caption
    and an optional control_path."""

    def __init__(self, image_jsonl_file: str):
        super().__init__()
        self.image_jsonl_file = image_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load image jsonl from {self.image_jsonl_file}")
        self.data = []
        with open(self.image_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    # report the offending line before propagating
                    logger.error(f"failed to load json: {line} @ {self.image_jsonl_file}")
                    raise
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} images")

        # Check if there are control paths in the JSONL
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            # control paths must be all-or-nothing across the dataset
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_images = [item["image_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some images do not have control paths in JSONL data: {missing_control_images}")
                raise ValueError(f"Some images do not have control paths in JSONL data: {missing_control_images}")
            logger.info(f"found {control_count} control images in JSONL data")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_image_data(self, idx: int) -> tuple[str, Image.Image, str, Optional[Image.Image]]:
        """Returns (image_path, image, caption, control image or None) for the index."""
        data = self.data[idx]
        image_path = data["image_path"]
        image = Image.open(image_path).convert("RGB")

        caption = data["caption"]

        control = None
        if self.has_control:
            control_path = data["control_path"]
            control = Image.open(control_path).convert("RGB")

        return image_path, image, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        data = self.data[idx]
        image_path = data["image_path"]
        caption = data["caption"]
        return image_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self) -> callable:
        # returns a zero-argument fetcher so callers can defer/parallelize loading
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_image_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
835
+
836
+
837
class VideoDatasource(ContentDatasource):
    """Base class for video sources: adds frame-range, bucket and fps-conversion settings
    shared by all concrete video datasources."""

    def __init__(self):
        super().__init__()

        # None means all frames
        self.start_frame = None
        self.end_frame = None

        self.bucket_selector = None

        # when both are set, load_video drops frames to convert source_fps -> target_fps
        self.source_fps = None
        self.target_fps = None

    def __len__(self):
        raise NotImplementedError

    def get_video_data_from_path(
        self,
        video_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[Image.Image]:
        # this method can resize the video if bucket_selector is given to reduce the memory usage

        # per-call arguments override the datasource-level defaults
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        video = load_video(
            video_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )
        return video

    def get_control_data_from_path(
        self,
        control_path: str,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> list[Image.Image]:
        """Load control frames with the same range/bucket/fps rules as the main video."""
        start_frame = start_frame if start_frame is not None else self.start_frame
        end_frame = end_frame if end_frame is not None else self.end_frame
        bucket_selector = bucket_selector if bucket_selector is not None else self.bucket_selector

        control = load_video(
            control_path, start_frame, end_frame, bucket_selector, source_fps=self.source_fps, target_fps=self.target_fps
        )
        return control

    def set_start_and_end_frame(self, start_frame: Optional[int], end_frame: Optional[int]):
        self.start_frame = start_frame
        self.end_frame = end_frame

    def set_bucket_selector(self, bucket_selector: BucketSelector):
        self.bucket_selector = bucket_selector

    def set_source_and_target_fps(self, source_fps: Optional[float], target_fps: Optional[float]):
        self.source_fps = source_fps
        self.target_fps = target_fps

    def __iter__(self):
        raise NotImplementedError

    def __next__(self):
        raise NotImplementedError
904
+
905
class VideoDirectoryDatasource(VideoDatasource):
    """Video datasource backed by a directory of videos, with optional caption files
    and optional control videos/images in a parallel directory."""

    def __init__(self, video_directory: str, caption_extension: Optional[str] = None, control_directory: Optional[str] = None):
        super().__init__()
        self.video_directory = video_directory
        self.caption_extension = caption_extension
        self.control_directory = control_directory  # newly added: directory of control videos/images
        self.current_idx = 0

        # glob videos
        logger.info(f"glob videos in {self.video_directory}")
        self.video_paths = glob_videos(self.video_directory)
        logger.info(f"found {len(self.video_paths)} videos")

        # glob control images if specified
        if self.control_directory is not None:
            logger.info(f"glob control videos in {self.control_directory}")
            self.has_control = True
            self.control_paths = {}
            for video_path in self.video_paths:
                video_basename = os.path.basename(video_path)
                # construct control path from video path
                # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mp4"
                control_path = os.path.join(self.control_directory, video_basename)
                if os.path.exists(control_path):
                    self.control_paths[video_path] = control_path
                else:
                    # use the same base name for control path
                    base_name = os.path.splitext(video_basename)[0]

                    # directory with images. for example: video_path = "vid/video.mp4" -> control_path = "control/video"
                    potential_path = os.path.join(self.control_directory, base_name)  # no extension
                    if os.path.isdir(potential_path):
                        self.control_paths[video_path] = potential_path
                    else:
                        # another extension for control path
                        # for example: video_path = "vid/video.mp4" -> control_path = "control/video.mov"
                        for ext in VIDEO_EXTENSIONS:
                            potential_path = os.path.join(self.control_directory, base_name + ext)
                            if os.path.exists(potential_path):
                                self.control_paths[video_path] = potential_path
                                break

            logger.info(f"found {len(self.control_paths)} matching control videos/images")
            # check if all videos have matching control paths, if not, raise an error
            missing_controls = len(self.video_paths) - len(self.control_paths)
            if missing_controls > 0:
                # logger.warning(f"Could not find matching control videos/images for {missing_controls} videos")
                missing_controls_videos = [video_path for video_path in self.video_paths if video_path not in self.control_paths]
                logger.error(
                    f"Could not find matching control videos/images for {missing_controls} videos: {missing_controls_videos}"
                )
                raise ValueError(f"Could not find matching control videos/images for {missing_controls} videos")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.video_paths)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
        """Returns (video_path, frames, caption, control frames or None) for the index."""
        video_path = self.video_paths[idx]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        _, caption = self.get_caption(idx)

        control = None
        if self.control_directory is not None and video_path in self.control_paths:
            control_path = self.control_paths[video_path]
            control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        video_path = self.video_paths[idx]
        # NOTE(review): caption_path is "" when caption_extension is falsy (precedence of
        # the conditional expression), so open() would fail — confirm callers always set it
        caption_path = os.path.splitext(video_path)[0] + self.caption_extension if self.caption_extension else ""
        with open(caption_path, "r", encoding="utf-8") as f:
            caption = f.read().strip()
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # returns a zero-argument fetcher so callers can defer/parallelize decoding
        if self.current_idx >= len(self.video_paths):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
1014
+
1015
+
1016
class VideoJsonlDatasource(VideoDatasource):
    """Video datasource driven by a JSONL file; each line holds video_path, caption
    and an optional control_path."""

    def __init__(self, video_jsonl_file: str):
        super().__init__()
        self.video_jsonl_file = video_jsonl_file
        self.current_idx = 0

        # load jsonl
        logger.info(f"load video jsonl from {self.video_jsonl_file}")
        self.data = []
        with open(self.video_jsonl_file, "r", encoding="utf-8") as f:
            for line in f:
                # report the offending line before re-raising, consistent with ImageJsonlDatasource
                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.error(f"failed to load json: {line} @ {self.video_jsonl_file}")
                    raise
                self.data.append(data)
        logger.info(f"loaded {len(self.data)} videos")

        # Check if there are control paths in the JSONL
        self.has_control = any("control_path" in item for item in self.data)
        if self.has_control:
            # control paths must be all-or-nothing across the dataset
            control_count = sum(1 for item in self.data if "control_path" in item)
            if control_count < len(self.data):
                missing_control_videos = [item["video_path"] for item in self.data if "control_path" not in item]
                logger.error(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
                raise ValueError(f"Some videos do not have control paths in JSONL data: {missing_control_videos}")
            logger.info(f"found {control_count} control videos/images in JSONL data")

    def is_indexable(self):
        return True

    def __len__(self):
        return len(self.data)

    def get_video_data(
        self,
        idx: int,
        start_frame: Optional[int] = None,
        end_frame: Optional[int] = None,
        bucket_selector: Optional[BucketSelector] = None,
    ) -> tuple[str, list[Image.Image], str, Optional[list[Image.Image]]]:
        """Returns (video_path, frames, caption, control frames or None) for the index."""
        data = self.data[idx]
        video_path = data["video_path"]
        video = self.get_video_data_from_path(video_path, start_frame, end_frame, bucket_selector)

        caption = data["caption"]

        control = None
        if "control_path" in data and data["control_path"]:
            control_path = data["control_path"]
            control = self.get_control_data_from_path(control_path, start_frame, end_frame, bucket_selector)

        return video_path, video, caption, control

    def get_caption(self, idx: int) -> tuple[str, str]:
        """Returns (video_path, caption) for the given index."""
        data = self.data[idx]
        video_path = data["video_path"]
        caption = data["caption"]
        return video_path, caption

    def __iter__(self):
        self.current_idx = 0
        return self

    def __next__(self):
        # returns a zero-argument fetcher so callers can defer/parallelize decoding
        if self.current_idx >= len(self.data):
            raise StopIteration

        if self.caption_only:

            def create_caption_fetcher(index):
                return lambda: self.get_caption(index)

            fetcher = create_caption_fetcher(self.current_idx)

        else:

            def create_fetcher(index):
                return lambda: self.get_video_data(index)

            fetcher = create_fetcher(self.current_idx)

        self.current_idx += 1
        return fetcher
1097
+
1098
+
1099
class BaseDataset(torch.utils.data.Dataset):
    """Common base for image/video training datasets.

    Holds the shared configuration (bucketing, batch size, repeats), the
    cache-file naming scheme for latents and text-encoder outputs, and the
    epoch/step bookkeeping hooks called by the training loop. Subclasses
    implement the actual retrieval / batching methods.
    """

    def __init__(
        self,
        resolution: Tuple[int, int] = (960, 544),
        caption_extension: Optional[str] = None,
        batch_size: int = 1,
        num_repeats: int = 1,
        enable_bucket: bool = False,
        bucket_no_upscale: bool = False,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        self.resolution = resolution
        self.caption_extension = caption_extension
        self.batch_size = batch_size
        self.num_repeats = num_repeats
        self.enable_bucket = enable_bucket
        self.bucket_no_upscale = bucket_no_upscale
        self.cache_directory = cache_directory
        self.debug_dataset = debug_dataset
        self.architecture = architecture
        self.seed = None
        self.current_epoch = 0

        # bucket_no_upscale only makes sense when bucketing is enabled
        if not self.enable_bucket:
            self.bucket_no_upscale = False

    def get_metadata(self) -> dict:
        """Return a JSON-serializable summary of the dataset configuration."""
        metadata = {
            "resolution": self.resolution,
            "caption_extension": self.caption_extension,
            "batch_size_per_device": self.batch_size,
            "num_repeats": self.num_repeats,
            "enable_bucket": bool(self.enable_bucket),
            "bucket_no_upscale": bool(self.bucket_no_upscale),
        }
        return metadata

    def get_all_latent_cache_files(self):
        # latent caches are "{basename}_{WxH}_{architecture}.safetensors"
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

    def get_all_text_encoder_output_cache_files(self):
        # text-encoder caches are "{basename}_{architecture}_te.safetensors"
        return glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}_te.safetensors"))

    def get_latent_cache_path(self, item_info: ItemInfo) -> str:
        """
        Returns the cache path for the latent tensor.

        item_info: ItemInfo object

        Returns:
            str: cache path

        cache_path is based on the item_key and the resolution.
        """
        w, h = item_info.original_size
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors")

    def get_text_encoder_output_cache_path(self, item_info: ItemInfo) -> str:
        """Return the cache path for the text-encoder outputs of this item."""
        basename = os.path.splitext(os.path.basename(item_info.item_key))[0]
        assert self.cache_directory is not None, "cache_directory is required / cache_directoryは必須です"
        return os.path.join(self.cache_directory, f"{basename}_{self.architecture}_te.safetensors")

    def retrieve_latent_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        raise NotImplementedError

    def prepare_for_training(self):
        # subclasses build their batch manager here; no-op by default
        pass

    def set_seed(self, seed: int):
        self.seed = seed

    def set_current_epoch(self, epoch):
        if not self.current_epoch == epoch:  # shuffle buckets when epoch is incremented
            if epoch > self.current_epoch:
                logger.info("epoch is incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                num_epochs = epoch - self.current_epoch
                # advance one epoch at a time so each intermediate epoch gets
                # its own deterministic shuffle (shuffle seeds on current_epoch)
                for _ in range(num_epochs):
                    self.current_epoch += 1
                    self.shuffle_buckets()
                # self.current_epoch seem to be set to 0 again in the next epoch. it may be caused by skipped_dataloader?
            else:
                logger.warning("epoch is not incremented. current_epoch: {}, epoch: {}".format(self.current_epoch, epoch))
                self.current_epoch = epoch

    def set_current_step(self, step):
        self.current_step = step

    def set_max_train_steps(self, max_train_steps):
        self.max_train_steps = max_train_steps

    def shuffle_buckets(self):
        raise NotImplementedError

    def __len__(self):
        # FIX: was `return NotImplementedError`, which returned the exception
        # class and made len() fail with an unrelated TypeError; raise it
        # instead, consistent with __getitem__ and shuffle_buckets.
        raise NotImplementedError

    def __getitem__(self, idx):
        raise NotImplementedError

    def _default_retrieve_text_encoder_output_cache_batches(self, datasource: ContentDatasource, batch_size: int, num_workers: int):
        """Yield batches of ItemInfo (caption only) for text-encoder caching.

        Captions are fetched concurrently via a thread pool; completed futures
        are drained into `data` and sliced into batches of `batch_size`.
        """
        datasource.set_caption_only(True)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        data: list[ItemInfo] = []
        futures = []

        def aggregate_future(consume_all: bool = False):
            # Drain completed futures; when the pool is saturated (or we are
            # flushing), poll until at least one future finishes.
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    item_key, caption = future.result()
                    # dummy sizes: only the caption matters for TE caching
                    item_info = ItemInfo(item_key, caption, (0, 0), (0, 0))
                    item_info.text_encoder_output_cache_path = self.get_text_encoder_output_cache_path(item_info)
                    data.append(item_info)

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            # Pop a full batch from the front of `data`; with flush, also a
            # final partial batch. Returns None when nothing is ready.
            nonlocal data
            if len(data) >= batch_size or (len(data) > 0 and flush):
                batch = data[0:batch_size]
                if len(data) > batch_size:
                    data = data[batch_size:]
                else:
                    data = []
                return batch
            return None

        for fetch_op in datasource:
            future = executor.submit(fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                batch = submit_batch()
                if batch is None:
                    break
                yield batch

        # drain remaining futures and flush partial batches
        aggregate_future(consume_all=True)
        while True:
            batch = submit_batch(flush=True)
            if batch is None:
                break
            yield batch

        executor.shutdown()
1259
+
1260
+
1261
class ImageDataset(BaseDataset):
    """Dataset of still images (optionally paired with control images).

    Items come either from a directory scan (``image_directory``) or from a
    JSONL manifest (``image_jsonl_file``); exactly one must be given.
    """

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        image_directory: Optional[str] = None,
        image_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(ImageDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.image_directory = image_directory
        self.image_jsonl_file = image_jsonl_file
        self.control_directory = control_directory
        if image_directory is not None:
            self.datasource = ImageDirectoryDatasource(image_directory, caption_extension, control_directory)
        elif image_jsonl_file is not None:
            # NOTE: the JSONL datasource takes no control_directory; control
            # paths presumably come from the manifest itself — confirm.
            self.datasource = ImageJsonlDatasource(image_jsonl_file)
        else:
            raise ValueError("image_directory or image_jsonl_file must be specified")

        # caches default to living next to the images
        if self.cache_directory is None:
            self.cache_directory = self.image_directory

        self.batch_manager = None  # built later by prepare_for_training()
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with image-source and control info."""
        metadata = super().get_metadata()
        if self.image_directory is not None:
            metadata["image_directory"] = os.path.basename(self.image_directory)
        if self.image_jsonl_file is not None:
            metadata["image_jsonl_file"] = os.path.basename(self.image_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["has_control"] = self.has_control
        return metadata

    def get_total_image_count(self):
        # None when the datasource cannot be counted up-front (non-indexable)
        return len(self.datasource) if self.datasource.is_indexable() else None

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield (bucket_reso, [ItemInfo]) batches of loaded+resized images.

        Images are fetched and resized concurrently via a thread pool, grouped
        into per-bucket-resolution batches, and yielded as soon as a bucket
        has batch_size items (partial batches are flushed at the end).
        """
        buckset_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)
        executor = ThreadPoolExecutor(max_workers=num_workers)

        batches: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        futures = []

        # aggregate futures and sort by bucket resolution
        def aggregate_future(consume_all: bool = False):
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_size, item_key, image, caption, control = future.result()
                    # NOTE(review): .shape implies the resized image is an
                    # ndarray here (despite the Image.Image annotation on
                    # fetch_and_resize) — resize_image_to_bucket presumably
                    # converts; confirm.
                    bucket_height, bucket_width = image.shape[:2]
                    bucket_reso = (bucket_width, bucket_height)

                    item_info = ItemInfo(item_key, caption, original_size, bucket_reso, content=image)
                    item_info.latent_cache_path = self.get_latent_cache_path(item_info)

                    if control is not None:
                        item_info.control_content = control

                    if bucket_reso not in batches:
                        batches[bucket_reso] = []
                    batches[bucket_reso].append(item_info)

                    futures.remove(future)

        # submit batch if some bucket has enough items
        def submit_batch(flush: bool = False):
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        # safe despite iterating: we return immediately below
                        del batches[key]
                    return key, batch
            return None, None

        for fetch_op in self.datasource:

            # fetch and resize image in a separate thread
            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, Image.Image, str, Optional[Image.Image]]:
                image_key, image, caption, control = op()
                image: Image.Image
                image_size = image.size

                bucket_reso = buckset_selector.get_bucket_resolution(image_size)
                image = resize_image_to_bucket(image, bucket_reso)
                if control is not None:
                    control = resize_image_to_bucket(control, bucket_reso)
                return image_size, image_key, image, caption, control

            future = executor.submit(fetch_and_resize, fetch_op)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        # drain remaining futures and flush partial batches
        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        # images use the shared caption-only pipeline from BaseDataset
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Build the bucket batch manager from existing latent cache files.

        Only items that have BOTH a latent cache and a text-encoder output
        cache are included; each item is repeated num_repeats times.
        """
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int], list[ItemInfo]] = {}  # (width, height) -> [ItemInfo]
        for cache_file in latent_cache_files:
            # filename layout: {basename}_{WWWWxHHHH}_{architecture}.safetensors
            # (assumes architecture contains no "_" — confirm)
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            item_key = "_".join(tokens[:-2])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        # returns one whole bucketed batch, not a single item
        return self.batch_manager[idx]
1449
+
1450
+
1451
class VideoDataset(BaseDataset):
    """Dataset of videos, sliced into fixed-length frame windows for training.

    Videos come either from a directory scan (``video_directory``) or a JSONL
    manifest (``video_jsonl_file``); exactly one must be given.

    ``frame_extraction`` selects how clips are cut from each video:
      - "head":    the first N frames, for each N in ``target_frames``
      - "chunk":   consecutive non-overlapping N-frame chunks
      - "slide":   sliding windows advanced by ``frame_stride``
      - "uniform": ``frame_sample`` windows spread uniformly over the video
      - "full":    the whole video, capped at ``max_frames``
    """

    # target playback fps per architecture (source videos are resampled when
    # source_fps is provided)
    TARGET_FPS_HUNYUAN = 24.0
    TARGET_FPS_WAN = 16.0
    TARGET_FPS_FRAMEPACK = 30.0

    def __init__(
        self,
        resolution: Tuple[int, int],
        caption_extension: Optional[str],
        batch_size: int,
        num_repeats: int,
        enable_bucket: bool,
        bucket_no_upscale: bool,
        frame_extraction: Optional[str] = "head",
        frame_stride: Optional[int] = 1,
        frame_sample: Optional[int] = 1,
        target_frames: Optional[list[int]] = None,
        max_frames: Optional[int] = None,
        source_fps: Optional[float] = None,
        video_directory: Optional[str] = None,
        video_jsonl_file: Optional[str] = None,
        control_directory: Optional[str] = None,
        cache_directory: Optional[str] = None,
        debug_dataset: bool = False,
        architecture: str = "no_default",
    ):
        super(VideoDataset, self).__init__(
            resolution,
            caption_extension,
            batch_size,
            num_repeats,
            enable_bucket,
            bucket_no_upscale,
            cache_directory,
            debug_dataset,
            architecture,
        )
        self.video_directory = video_directory
        self.video_jsonl_file = video_jsonl_file
        self.control_directory = control_directory
        self.frame_extraction = frame_extraction
        self.frame_stride = frame_stride
        self.frame_sample = frame_sample
        self.max_frames = max_frames
        self.source_fps = source_fps

        if self.architecture == ARCHITECTURE_HUNYUAN_VIDEO:
            self.target_fps = VideoDataset.TARGET_FPS_HUNYUAN
        elif self.architecture == ARCHITECTURE_WAN:
            self.target_fps = VideoDataset.TARGET_FPS_WAN
        elif self.architecture == ARCHITECTURE_FRAMEPACK:
            self.target_fps = VideoDataset.TARGET_FPS_FRAMEPACK
        else:
            raise ValueError(f"Unsupported architecture: {self.architecture}")

        if target_frames is not None:
            target_frames = list(set(target_frames))
            target_frames.sort()

            # round each value to N*4+1 (temporal latent compression is 4x)
            rounded_target_frames = [(f - 1) // 4 * 4 + 1 for f in target_frames]
            # FIX: the deduplicated/sorted list was previously assigned to a
            # misspelled variable ("rouneded_...") and discarded, so duplicate
            # frame counts produced by rounding (e.g. 2 and 3 both round to 1)
            # survived into target_frames.
            rounded_target_frames = sorted(set(rounded_target_frames))

            # if value is changed, warn
            if target_frames != rounded_target_frames:
                logger.warning(f"target_frames are rounded to {rounded_target_frames}")

            target_frames = tuple(rounded_target_frames)

        self.target_frames = target_frames

        if video_directory is not None:
            self.datasource = VideoDirectoryDatasource(video_directory, caption_extension, control_directory)
        elif video_jsonl_file is not None:
            self.datasource = VideoJsonlDatasource(video_jsonl_file)
        else:
            # FIX: previously fell through silently, leaving self.datasource
            # unset and failing later with AttributeError. Raise early like
            # ImageDataset does.
            raise ValueError("video_directory or video_jsonl_file must be specified")

        if self.frame_extraction == "uniform" and self.frame_sample == 1:
            self.frame_extraction = "head"
            logger.warning("frame_sample is set to 1 for frame_extraction=uniform. frame_extraction is changed to head.")
        if self.frame_extraction == "head":
            # head extraction. we can limit the number of frames to be extracted
            self.datasource.set_start_and_end_frame(0, max(self.target_frames))

        # caches default to living next to the videos
        if self.cache_directory is None:
            self.cache_directory = self.video_directory

        self.batch_manager = None  # built later by prepare_for_training()
        self.num_train_items = 0
        self.has_control = self.datasource.has_control

    def get_metadata(self):
        """Extend the base metadata with video-source and slicing info."""
        metadata = super().get_metadata()
        if self.video_directory is not None:
            metadata["video_directory"] = os.path.basename(self.video_directory)
        if self.video_jsonl_file is not None:
            metadata["video_jsonl_file"] = os.path.basename(self.video_jsonl_file)
        if self.control_directory is not None:
            metadata["control_directory"] = os.path.basename(self.control_directory)
        metadata["frame_extraction"] = self.frame_extraction
        metadata["frame_stride"] = self.frame_stride
        metadata["frame_sample"] = self.frame_sample
        metadata["target_frames"] = self.target_frames
        metadata["max_frames"] = self.max_frames
        metadata["source_fps"] = self.source_fps
        metadata["has_control"] = self.has_control
        return metadata

    def retrieve_latent_cache_batches(self, num_workers: int):
        """Yield (bucket_key, [ItemInfo]) batches of decoded+resized clips.

        Videos are decoded and resized concurrently via a thread pool, sliced
        into windows per frame_extraction, grouped by
        (width, height, frame_count), and yielded as soon as a bucket has
        batch_size items (partial batches are flushed at the end).
        """
        bucket_selector = BucketSelector(self.resolution, architecture=self.architecture)
        self.datasource.set_bucket_selector(bucket_selector)
        if self.source_fps is not None:
            self.datasource.set_source_and_target_fps(self.source_fps, self.target_fps)
        else:
            self.datasource.set_source_and_target_fps(None, None)  # no conversion

        executor = ThreadPoolExecutor(max_workers=num_workers)

        # key: (width, height, frame_count), value: [ItemInfo]
        batches: dict[tuple[int, int, int], list[ItemInfo]] = {}
        futures = []

        def aggregate_future(consume_all: bool = False):
            # Drain completed futures, slicing each video into training
            # windows and bucketing them by (reso, frame_count).
            while len(futures) >= num_workers or (consume_all and len(futures) > 0):
                completed_futures = [future for future in futures if future.done()]
                if len(completed_futures) == 0:
                    if len(futures) >= num_workers or consume_all:  # to avoid adding too many futures
                        time.sleep(0.1)
                        continue
                    else:
                        break  # submit batch if possible

                for future in completed_futures:
                    original_frame_size, video_key, video, caption, control = future.result()

                    frame_count = len(video)
                    video = np.stack(video, axis=0)
                    height, width = video.shape[1:3]
                    bucket_reso = (width, height)  # already resized

                    # process control images if available
                    control_video = None
                    if control is not None:
                        # set frame count to the same as video
                        if len(control) > frame_count:
                            control = control[:frame_count]
                        elif len(control) < frame_count:
                            # if control is shorter than video, repeat the last frame
                            last_frame = control[-1]
                            control.extend([last_frame] * (frame_count - len(control)))
                        control_video = np.stack(control, axis=0)

                    # (start_frame, window_length) pairs to extract
                    crop_pos_and_frames = []
                    if self.frame_extraction == "head":
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                crop_pos_and_frames.append((0, target_frame))
                    elif self.frame_extraction == "chunk":
                        # split by target_frames
                        for target_frame in self.target_frames:
                            for i in range(0, frame_count, target_frame):
                                if i + target_frame <= frame_count:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "slide":
                        # slide window
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                for i in range(0, frame_count - target_frame + 1, self.frame_stride):
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "uniform":
                        # select N frames uniformly
                        for target_frame in self.target_frames:
                            if frame_count >= target_frame:
                                frame_indices = np.linspace(0, frame_count - target_frame, self.frame_sample, dtype=int)
                                for i in frame_indices:
                                    crop_pos_and_frames.append((i, target_frame))
                    elif self.frame_extraction == "full":
                        # select all frames (max_frames is required in this mode)
                        target_frame = min(frame_count, self.max_frames)
                        target_frame = (target_frame - 1) // 4 * 4 + 1  # round to N*4+1
                        crop_pos_and_frames.append((0, target_frame))
                    else:
                        raise ValueError(f"frame_extraction {self.frame_extraction} is not supported")

                    for crop_pos, target_frame in crop_pos_and_frames:
                        cropped_video = video[crop_pos : crop_pos + target_frame]
                        body, ext = os.path.splitext(video_key)
                        # encode crop position and length into the item key so
                        # each window gets its own latent cache file
                        item_key = f"{body}_{crop_pos:05d}-{target_frame:03d}{ext}"
                        batch_key = (*bucket_reso, target_frame)  # bucket_reso with frame_count

                        # crop control video if available
                        cropped_control = None
                        if control_video is not None:
                            cropped_control = control_video[crop_pos : crop_pos + target_frame]

                        item_info = ItemInfo(
                            item_key, caption, original_frame_size, batch_key, frame_count=target_frame, content=cropped_video
                        )
                        item_info.latent_cache_path = self.get_latent_cache_path(item_info)
                        item_info.control_content = cropped_control  # None is allowed

                        batch = batches.get(batch_key, [])
                        batch.append(item_info)
                        batches[batch_key] = batch

                    futures.remove(future)

        def submit_batch(flush: bool = False):
            # Pop a full batch from any bucket; with flush also partials.
            for key in batches:
                if len(batches[key]) >= self.batch_size or flush:
                    batch = batches[key][0 : self.batch_size]
                    if len(batches[key]) > self.batch_size:
                        batches[key] = batches[key][self.batch_size :]
                    else:
                        # safe despite iterating: we return immediately below
                        del batches[key]
                    return key, batch
            return None, None

        for operator in self.datasource:

            def fetch_and_resize(op: callable) -> tuple[tuple[int, int], str, list[np.ndarray], str, Optional[list[np.ndarray]]]:
                result = op()

                if len(result) == 3:  # for backward compatibility TODO remove this in the future
                    video_key, video, caption = result
                    control = None
                else:
                    video_key, video, caption, control = result

                video: list[np.ndarray]
                frame_size = (video[0].shape[1], video[0].shape[0])

                # resize if necessary
                bucket_reso = bucket_selector.get_bucket_resolution(frame_size)
                video = [resize_image_to_bucket(frame, bucket_reso) for frame in video]

                # resize control if necessary
                if control is not None:
                    control = [resize_image_to_bucket(frame, bucket_reso) for frame in control]

                return frame_size, video_key, video, caption, control

            future = executor.submit(fetch_and_resize, operator)
            futures.append(future)
            aggregate_future()
            while True:
                key, batch = submit_batch()
                if key is None:
                    break
                yield key, batch

        # drain remaining futures and flush partial batches
        aggregate_future(consume_all=True)
        while True:
            key, batch = submit_batch(flush=True)
            if key is None:
                break
            yield key, batch

        executor.shutdown()

    def retrieve_text_encoder_output_cache_batches(self, num_workers: int):
        # videos use the shared caption-only pipeline from BaseDataset
        return self._default_retrieve_text_encoder_output_cache_batches(self.datasource, self.batch_size, num_workers)

    def prepare_for_training(self):
        """Build the bucket batch manager from existing latent cache files.

        Only items that have BOTH a latent cache and a text-encoder output
        cache are included; each item is repeated num_repeats times.
        """
        bucket_selector = BucketSelector(self.resolution, self.enable_bucket, self.bucket_no_upscale, self.architecture)

        # glob cache files
        latent_cache_files = glob.glob(os.path.join(self.cache_directory, f"*_{self.architecture}.safetensors"))

        # assign cache files to item info
        bucketed_item_info: dict[tuple[int, int, int], list[ItemInfo]] = {}  # (width, height, frame_count) -> [ItemInfo]
        for cache_file in latent_cache_files:
            # layout: {basename}_{crop-frames}_{WWWWxHHHH}_{architecture}.safetensors
            tokens = os.path.basename(cache_file).split("_")

            image_size = tokens[-2]  # 0000x0000
            image_width, image_height = map(int, image_size.split("x"))
            image_size = (image_width, image_height)

            frame_pos, frame_count = tokens[-3].split("-")[:2]  # "00000-000", or optional section index "00000-000-00"
            frame_pos, frame_count = int(frame_pos), int(frame_count)

            item_key = "_".join(tokens[:-3])
            text_encoder_output_cache_file = os.path.join(self.cache_directory, f"{item_key}_{self.architecture}_te.safetensors")
            if not os.path.exists(text_encoder_output_cache_file):
                logger.warning(f"Text encoder output cache file not found: {text_encoder_output_cache_file}")
                continue

            bucket_reso = bucket_selector.get_bucket_resolution(image_size)
            bucket_reso = (*bucket_reso, frame_count)
            item_info = ItemInfo(item_key, "", image_size, bucket_reso, frame_count=frame_count, latent_cache_path=cache_file)
            item_info.text_encoder_output_cache_path = text_encoder_output_cache_file

            bucket = bucketed_item_info.get(bucket_reso, [])
            for _ in range(self.num_repeats):
                bucket.append(item_info)
            bucketed_item_info[bucket_reso] = bucket

        # prepare batch manager
        self.batch_manager = BucketBatchManager(bucketed_item_info, self.batch_size)
        self.batch_manager.show_bucket_info()

        self.num_train_items = sum([len(bucket) for bucket in bucketed_item_info.values()])

    def shuffle_buckets(self):
        # set random seed for this epoch
        random.seed(self.seed + self.current_epoch)
        self.batch_manager.shuffle()

    def __len__(self):
        if self.batch_manager is None:
            return 100  # dummy value
        return len(self.batch_manager)

    def __getitem__(self, idx):
        # returns one whole bucketed batch, not a single item
        return self.batch_manager[idx]
1766
+
1767
+
1768
class DatasetGroup(torch.utils.data.ConcatDataset):
    """A ConcatDataset over image/video datasets that also fans out the
    training-loop bookkeeping calls (epoch, step, max steps) to every member.
    """

    def __init__(self, datasets: Sequence[Union[ImageDataset, VideoDataset]]):
        super().__init__(datasets)
        self.datasets: list[Union[ImageDataset, VideoDataset]] = datasets
        # total number of training items across all member datasets
        self.num_train_items = sum(member.num_train_items for member in datasets)

    def set_current_epoch(self, epoch):
        for member in self.datasets:
            member.set_current_epoch(epoch)

    def set_current_step(self, step):
        for member in self.datasets:
            member.set_current_step(step)

    def set_max_train_steps(self, max_train_steps):
        for member in self.datasets:
            member.set_max_train_steps(max_train_steps)
fpack_cache_latents.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import math
4
+ import os
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from tqdm import tqdm
11
+ from transformers import SiglipImageProcessor, SiglipVisionModel
12
+
13
+ from dataset import config_utils
14
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
15
+ from dataset.image_video_dataset import BaseDataset, ItemInfo, save_latent_cache_framepack, ARCHITECTURE_FRAMEPACK
16
+ from frame_pack import hunyuan
17
+ from frame_pack.framepack_utils import load_image_encoders, load_vae
18
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
19
+ from frame_pack.clip_vision import hf_clip_vision_encode
20
+ import cache_latents
21
+
22
+ logger = logging.getLogger(__name__)
23
+ logging.basicConfig(level=logging.INFO)
24
+
25
+
26
+ def encode_and_save_batch(
27
+ vae: AutoencoderKLCausal3D,
28
+ feature_extractor: SiglipImageProcessor,
29
+ image_encoder: SiglipVisionModel,
30
+ batch: List[ItemInfo],
31
+ latent_window_size: int,
32
+ vanilla_sampling: bool = False,
33
+ one_frame: bool = False,
34
+ ):
35
+ """Encode a batch of original RGB videos and save FramePack section caches."""
36
+ if one_frame:
37
+ encode_and_save_batch_one_frame(vae, feature_extractor, image_encoder, batch, latent_window_size, vanilla_sampling)
38
+ return
39
+
40
+ # Stack batch into tensor (B,C,F,H,W) in RGB order
41
+ contents = torch.stack([torch.from_numpy(item.content) for item in batch])
42
+ if len(contents.shape) == 4:
43
+ contents = contents.unsqueeze(1) # B, H, W, C -> B, F, H, W, C
44
+
45
+ contents = contents.permute(0, 4, 1, 2, 3).contiguous() # B, C, F, H, W
46
+ contents = contents.to(vae.device, dtype=vae.dtype)
47
+ contents = contents / 127.5 - 1.0 # normalize to [-1, 1]
48
+
49
+ height, width = contents.shape[3], contents.shape[4]
50
+ if height < 8 or width < 8:
51
+ item = batch[0] # other items should have the same size
52
+ raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")
53
+
54
+ # calculate latent frame count from original frame count (4n+1)
55
+ latent_f = (batch[0].frame_count - 1) // 4 + 1
56
+
57
+ # calculate the total number of sections (excluding the first frame, divided by window size)
58
+ total_latent_sections = math.floor((latent_f - 1) / latent_window_size)
59
+ if total_latent_sections < 1:
60
+ min_frames_needed = latent_window_size * 4 + 1
61
+ raise ValueError(
62
+ f"Not enough frames for FramePack: {batch[0].frame_count} frames ({latent_f} latent frames), minimum required: {min_frames_needed} frames ({latent_window_size+1} latent frames)"
63
+ )
64
+
65
+ # actual latent frame count (aligned to section boundaries)
66
+ latent_f_aligned = total_latent_sections * latent_window_size + 1 if not one_frame else 1
67
+
68
+ # actual video frame count
69
+ frame_count_aligned = (latent_f_aligned - 1) * 4 + 1
70
+ if frame_count_aligned != batch[0].frame_count:
71
+ logger.info(
72
+ f"Frame count mismatch: required={frame_count_aligned} != actual={batch[0].frame_count}, trimming to {frame_count_aligned}"
73
+ )
74
+ contents = contents[:, :, :frame_count_aligned, :, :]
75
+
76
+ latent_f = latent_f_aligned # Update to the aligned value
77
+
78
+ # VAE encode (list of tensor -> stack)
79
+ latents = hunyuan.vae_encode(contents, vae) # include scaling factor
80
+ latents = latents.to("cpu") # (B, C, latent_f, H/8, W/8)
81
+
82
+ # Vision encoding per‑item (once)
83
+ images = np.stack([item.content[0] for item in batch], axis=0) # B, H, W, C
84
+
85
+ # encode image with image encoder
86
+ image_embeddings = []
87
+ with torch.no_grad():
88
+ for image in images:
89
+ image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
90
+ image_embeddings.append(image_encoder_output.last_hidden_state)
91
+ image_embeddings = torch.cat(image_embeddings, dim=0) # B, LEN, 1152
92
+ image_embeddings = image_embeddings.to("cpu") # Save memory
93
+
94
+ if not vanilla_sampling:
95
+ # padding is reversed for inference (future to past)
96
+ latent_paddings = list(reversed(range(total_latent_sections)))
97
+ # Note: The padding trick for inference. See the paper for details.
98
+ if total_latent_sections > 4:
99
+ latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]
100
+
101
+ for b, item in enumerate(batch):
102
+ original_latent_cache_path = item.latent_cache_path
103
+ video_lat = latents[b : b + 1] # keep batch dim, 1, C, F, H, W
104
+
105
+ # emulate inference step (history latents)
106
+ # Note: In inference, history_latents stores *generated* future latents.
107
+ # Here, for caching, we just need its shape and type for clean_* tensors.
108
+ # The actual content doesn't matter much as clean_* will be overwritten.
109
+ history_latents = torch.zeros(
110
+ (1, video_lat.shape[1], 1 + 2 + 16, video_lat.shape[3], video_lat.shape[4]), dtype=video_lat.dtype
111
+ ) # C=16 for HY
112
+
113
+ latent_f_index = latent_f - latent_window_size # Start from the last section
114
+ section_index = total_latent_sections - 1
115
+
116
+ for latent_padding in latent_paddings:
117
+ is_last_section = section_index == 0 # the last section in inference order == the first section in time
118
+ latent_padding_size = latent_padding * latent_window_size
119
+ if is_last_section:
120
+ assert latent_f_index == 1, "Last section should be starting from frame 1"
121
+
122
+ # indices generation (same as inference)
123
+ indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
124
+ (
125
+ clean_latent_indices_pre, # Index for start_latent
126
+ blank_indices, # Indices for padding (future context in inference)
127
+ latent_indices, # Indices for the target latents to predict
128
+ clean_latent_indices_post, # Index for the most recent history frame
129
+ clean_latent_2x_indices, # Indices for the next 2 history frames
130
+ clean_latent_4x_indices, # Indices for the next 16 history frames
131
+ ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
132
+
133
+ # Indices for clean_latents (start + recent history)
134
+ clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)
135
+
136
+ # clean latents preparation (emulating inference)
137
+ clean_latents_pre = video_lat[:, :, 0:1, :, :] # Always the first frame (start_latent)
138
+ clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
139
+ [1, 2, 16], dim=2
140
+ )
141
+ clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2) # Combine start frame + placeholder
142
+
143
+ # Target latents for this section (ground truth)
144
+ target_latents = video_lat[:, :, latent_f_index : latent_f_index + latent_window_size, :, :]
145
+
146
+ # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
147
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
148
+ save_latent_cache_framepack(
149
+ item_info=item,
150
+ latent=target_latents.squeeze(0), # Ground truth for this section
151
+ latent_indices=latent_indices.squeeze(0), # Indices for the ground truth section
152
+ clean_latents=clean_latents.squeeze(0), # Start frame + history placeholder
153
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for start frame + history placeholder
154
+ clean_latents_2x=clean_latents_2x.squeeze(0), # History placeholder
155
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for history placeholder
156
+ clean_latents_4x=clean_latents_4x.squeeze(0), # History placeholder
157
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for history placeholder
158
+ image_embeddings=image_embeddings[b],
159
+ )
160
+
161
+ if is_last_section: # If this was the first section generated in inference (time=0)
162
+ # History gets the start frame + the generated first section
163
+ generated_latents_for_history = video_lat[:, :, : latent_window_size + 1, :, :]
164
+ else:
165
+ # History gets the generated current section
166
+ generated_latents_for_history = target_latents # Use true latents as stand-in for generated
167
+
168
+ history_latents = torch.cat([generated_latents_for_history, history_latents], dim=2)
169
+
170
+ section_index -= 1
171
+ latent_f_index -= latent_window_size
172
+
173
+ else:
174
+ # Vanilla Sampling Logic
175
+ for b, item in enumerate(batch):
176
+ original_latent_cache_path = item.latent_cache_path
177
+ video_lat = latents[b : b + 1] # Keep batch dim: 1, C, F_aligned, H, W
178
+ img_emb = image_embeddings[b] # LEN, 1152
179
+
180
+ for section_index in range(total_latent_sections):
181
+ target_start_f = section_index * latent_window_size + 1
182
+ target_end_f = target_start_f + latent_window_size
183
+ target_latents = video_lat[:, :, target_start_f:target_end_f, :, :]
184
+ start_latent = video_lat[:, :, 0:1, :, :]
185
+
186
+ # Clean latents preparation (Vanilla)
187
+ clean_latents_total_count = 1 + 2 + 16
188
+ history_latents = torch.zeros(
189
+ size=(1, 16, clean_latents_total_count, video_lat.shape[-2], video_lat.shape[-1]),
190
+ device=video_lat.device,
191
+ dtype=video_lat.dtype,
192
+ )
193
+
194
+ history_start_f = 0
195
+ video_start_f = target_start_f - clean_latents_total_count
196
+ copy_count = clean_latents_total_count
197
+ if video_start_f < 0:
198
+ history_start_f = -video_start_f
199
+ copy_count = clean_latents_total_count - history_start_f
200
+ video_start_f = 0
201
+ if copy_count > 0:
202
+ history_latents[:, :, history_start_f:] = video_lat[:, :, video_start_f : video_start_f + copy_count, :, :]
203
+
204
+ # indices generation (Vanilla): copy from FramePack-F1
205
+ indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
206
+ (
207
+ clean_latent_indices_start,
208
+ clean_latent_4x_indices,
209
+ clean_latent_2x_indices,
210
+ clean_latent_1x_indices,
211
+ latent_indices,
212
+ ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
213
+ clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
214
+
215
+ clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents.split([16, 2, 1], dim=2)
216
+ clean_latents = torch.cat([start_latent, clean_latents_1x], dim=2)
217
+
218
+ # Save cache
219
+ item.latent_cache_path = append_section_idx_to_latent_cache_path(original_latent_cache_path, section_index)
220
+ save_latent_cache_framepack(
221
+ item_info=item,
222
+ latent=target_latents.squeeze(0),
223
+ latent_indices=latent_indices.squeeze(0), # Indices for target section i
224
+ clean_latents=clean_latents.squeeze(0), # Past clean frames
225
+ clean_latent_indices=clean_latent_indices.squeeze(0), # Indices for clean_latents_pre/post
226
+ clean_latents_2x=clean_latents_2x.squeeze(0), # Past clean frames (2x)
227
+ clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0), # Indices for clean_latents_2x
228
+ clean_latents_4x=clean_latents_4x.squeeze(0), # Past clean frames (4x)
229
+ clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0), # Indices for clean_latents_4x
230
+ image_embeddings=img_emb,
231
+ # Note: We don't explicitly save past_offset_indices,
232
+ # but its size influences the absolute values in other indices.
233
+ )
234
+
235
+
236
def encode_and_save_batch_one_frame(
    vae: AutoencoderKLCausal3D,
    feature_extractor: SiglipImageProcessor,
    image_encoder: SiglipVisionModel,
    batch: List[ItemInfo],
    latent_window_size: int,
    vanilla_sampling: bool = False,
):
    """Encode a batch for one-frame training and save one cache file per item.

    Each item provides a start image (``item.control_content``) and a target
    image (``item.content``); a single target frame is cached and the history
    latents are all zeros.

    Args:
        vae: causal 3D VAE used to encode images into latents.
        feature_extractor: SigLIP image processor for the vision encoder.
        image_encoder: SigLIP vision model producing image embeddings.
        batch: items to encode; all items are assumed to share the same size.
        latent_window_size: FramePack window size used to build the index layout.
        vanilla_sampling: unused here; kept for signature parity with
            ``encode_and_save_batch``.

    Raises:
        ValueError: if the image size is smaller than 8x8 pixels.
    """
    # item.content: target image (H, W, C)
    # item.control_content: start image (H, W, C)

    # Stack batch into tensor (B, F, H, W, C) in RGB order; F=2 (start, target).
    contents = torch.stack(
        [torch.stack([torch.from_numpy(item.control_content), torch.from_numpy(item.content)]) for item in batch]
    )

    contents = contents.permute(0, 4, 1, 2, 3).contiguous()  # B, C, F, H, W
    contents = contents.to(vae.device, dtype=vae.dtype)
    contents = contents / 127.5 - 1.0  # normalize uint8 pixel values to [-1, 1]

    height, width = contents.shape[3], contents.shape[4]
    if height < 8 or width < 8:
        item = batch[0]  # other items should have the same size
        raise ValueError(f"Image or video size too small: {item.item_key} and {len(batch) - 1} more, size: {item.original_size}")

    # VAE encode start frame and target frame separately (scaling factor included)
    start_latents = hunyuan.vae_encode(contents[:, :, 0:1], vae)
    start_latents = start_latents.to("cpu")  # (B, C, 1, H/8, W/8)
    latents = hunyuan.vae_encode(contents[:, :, 1:], vae)
    latents = latents.to("cpu")  # (B, C, 1, H/8, W/8)

    # Vision encoding per item (once): use control content because it is the start image
    images = [item.control_content for item in batch]  # list of [H, W, C]

    # encode image with image encoder
    image_embeddings = []
    with torch.no_grad():
        for image in images:
            image_encoder_output = hf_clip_vision_encode(image, feature_extractor, image_encoder)
            image_embeddings.append(image_encoder_output.last_hidden_state)
    image_embeddings = torch.cat(image_embeddings, dim=0)  # B, LEN, 1152
    image_embeddings = image_embeddings.to("cpu")  # Save memory

    # history latents are always zeroes for one frame training
    history_latents = torch.zeros(
        (1, latents.shape[1], 1 + 2 + 16, latents.shape[3], latents.shape[4]), dtype=latents.dtype
    )  # C=16 for HY

    # indices generation (same layout as inference)
    indices = torch.arange(0, sum([1, latent_window_size, 1, 2, 16])).unsqueeze(0)
    (
        clean_latent_indices_pre,  # Index for start_latent
        latent_indices,  # Indices for the target latents to predict
        clean_latent_indices_post,  # Index for the most recent history frame
        clean_latent_2x_indices,  # Indices for the next 2 history frames
        clean_latent_4x_indices,  # Indices for the next 16 history frames
    ) = indices.split([1, latent_window_size, 1, 2, 16], dim=1)

    # Indices for clean_latents (start + recent history)
    latent_indices = latent_indices[:, -1:]  # Only the last index is used for one frame training
    clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

    # clean latent placeholders shared by all items (history is all zeros)
    clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split([1, 2, 16], dim=2)

    for b, item in enumerate(batch):
        # NOTE(review): the original assigned item.latent_cache_path to an unused
        # local here (leftover from the multi-section variant); removed.

        # clean latents preparation (emulating inference)
        clean_latents_pre = start_latents[b : b + 1]
        clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)  # Combine start frame + placeholder

        # Target latents for this item (ground truth)
        target_latents = latents[b : b + 1]

        # save cache (file path is inside item.latent_cache_path pattern), remove batch dim
        save_latent_cache_framepack(
            item_info=item,
            latent=target_latents.squeeze(0),  # Ground truth for this section
            latent_indices=latent_indices.squeeze(0),  # Indices for the ground truth section
            clean_latents=clean_latents.squeeze(0),  # Start frame + history placeholder
            clean_latent_indices=clean_latent_indices.squeeze(0),  # Indices for start frame + history placeholder
            clean_latents_2x=clean_latents_2x.squeeze(0),  # History placeholder
            clean_latent_2x_indices=clean_latent_2x_indices.squeeze(0),  # Indices for history placeholder
            clean_latents_4x=clean_latents_4x.squeeze(0),  # History placeholder
            clean_latent_4x_indices=clean_latent_4x_indices.squeeze(0),  # Indices for history placeholder
            image_embeddings=image_embeddings[b],
        )
+
325
+
326
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register FramePack latent-caching CLI options on *parser* and return it."""
    add = parser.add_argument
    add("--image_encoder", type=str, required=True, help="Image encoder (CLIP) checkpoint path or directory")
    add("--latent_window_size", type=int, default=9, help="FramePack latent window size (default 9)")
    add(
        "--f1",
        action="store_true",
        help="Generate cache for F1 model (vanilla (autoregressive) sampling) instead of Inverted anti-drifting (plain FramePack)",
    )
    add(
        "--one_frame",
        action="store_true",
        help="Generate cache for one frame training (single frame, single section). latent_window_size is used as the index of the target frame.",
    )
    return parser
340
+
341
+
342
def main(args: argparse.Namespace):
    """Cache FramePack latents for every dataset described in the config file.

    Loads the dataset blueprint, the VAE, and the SigLIP image encoder, then
    delegates the per-dataset encode loop to ``encode_datasets_framepack``.
    With ``--debug_mode`` it only previews the datasets and exits.
    """
    # Pick the compute device: explicit --device wins, else CUDA when available.
    device = args.device if hasattr(args, "device") and args.device else ("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)

    # Load dataset config and expand it into concrete dataset objects.
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # Debug mode: visualize the datasets instead of encoding anything.
    if args.debug_mode is not None:
        cache_latents.show_datasets(
            datasets, args.debug_mode, args.console_width, args.console_back, args.console_num_images, fps=16
        )
        return

    assert args.vae is not None, "vae checkpoint is required"

    logger.info(f"Loading VAE model from {args.vae}")
    vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device=device)
    vae.to(device)

    logger.info(f"Loading image encoder from {args.image_encoder}")
    feature_extractor, image_encoder = load_image_encoders(args)
    image_encoder.eval()
    image_encoder.to(device)

    logger.info(f"Cache generation mode: {'Vanilla Sampling' if args.f1 else 'Inference Emulation'}")

    # encoding closure: captures the loaded models and CLI flags
    def encode(batch: List[ItemInfo]):
        encode_and_save_batch(vae, feature_extractor, image_encoder, batch, args.latent_window_size, args.f1, args.one_frame)

    # reuse core loop from cache_latents with no change
    encode_datasets_framepack(datasets, encode, args)
380
+
381
+
382
def append_section_idx_to_latent_cache_path(latent_cache_path: str, section_idx: int) -> str:
    """Insert a zero-padded section index into a latent cache file name.

    Cache paths look like ``{basename}_{frame_pos-count}_{WxH}_{arch}.safetensors``;
    the section index is appended to the third ``_``-separated field from the end.
    """
    parts = latent_cache_path.split("_")
    parts[-3] += f"-{section_idx:04d}"  # extend the "frame_pos-count" token
    return "_".join(parts)
386
+
387
+
388
def encode_datasets_framepack(datasets: list[BaseDataset], encode: callable, args: argparse.Namespace):
    """Run latent-cache encoding over every dataset, section by section.

    For video items the cache is split into one file per section, so expected
    cache paths are expanded with a section index before checking for existing
    files. After encoding, cache files on disk that no longer belong to the
    dataset are removed unless ``--keep_cache`` is set.

    Args:
        datasets: datasets to process.
        encode: callable taking a list of ItemInfo and writing their caches.
        args: parsed CLI namespace (num_workers, skip_existing, batch_size, ...).
    """
    num_workers = args.num_workers if args.num_workers is not None else max(1, os.cpu_count() - 1)
    # NOTE: the original reused `i` for both this loop and the batch-slicing
    # loop below; renamed to avoid shadowing.
    for ds_idx, dataset in enumerate(datasets):
        logger.info(f"Encoding dataset [{ds_idx}]")
        all_latent_cache_paths = []
        for _, batch in tqdm(dataset.retrieve_latent_cache_batches(num_workers)):
            batch: list[ItemInfo] = batch  # type: ignore

            # latent_cache_path is "{basename}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
            # For video dataset, we expand it to "{basename}_{section_idx:04d}_{w:04d}x{h:04d}_{self.architecture}.safetensors"
            filtered_batch = []
            for item in batch:
                if item.frame_count is None:
                    # image dataset: one cache file per item
                    all_latent_cache_paths.append(item.latent_cache_path)
                    all_existing = os.path.exists(item.latent_cache_path)
                else:
                    # video dataset: one cache file per section
                    latent_f = (item.frame_count - 1) // 4 + 1
                    num_sections = max(1, math.floor((latent_f - 1) / args.latent_window_size))  # min 1 section
                    all_existing = True
                    for sec in range(num_sections):
                        p = append_section_idx_to_latent_cache_path(item.latent_cache_path, sec)
                        all_latent_cache_paths.append(p)
                        all_existing = all_existing and os.path.exists(p)

                if not all_existing:  # if any section cache is missing
                    filtered_batch.append(item)

            if args.skip_existing:
                if len(filtered_batch) == 0:  # all sections exist
                    logger.info(f"All sections exist for {batch[0].item_key}, skipping")
                    continue
                batch = filtered_batch  # update batch to only missing sections

            bs = args.batch_size if args.batch_size is not None else len(batch)
            for start in range(0, len(batch), bs):
                encode(batch[start : start + bs])

        # normalize paths so on-disk files compare equal to the expected set
        all_latent_cache_paths = [os.path.normpath(p) for p in all_latent_cache_paths]
        all_latent_cache_paths = set(all_latent_cache_paths)

        # remove old cache files not in the dataset
        all_cache_files = dataset.get_all_latent_cache_files()
        for cache_file in all_cache_files:
            if os.path.normpath(cache_file) not in all_latent_cache_paths:
                if args.keep_cache:
                    logger.info(f"Keep cache file not in the dataset: {cache_file}")
                else:
                    os.remove(cache_file)
                    logger.info(f"Removed old cache file: {cache_file}")
439
+
440
+
441
if __name__ == "__main__":
    # Build the CLI: common cache-latents options + HunyuanVideo VAE options
    # + FramePack-specific options.
    parser = cache_latents.setup_parser_common()
    parser = cache_latents.hv_setup_parser(parser)  # VAE
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()

    # FramePack always uses the VAE's own dtype; an explicit override is rejected.
    if args.vae_dtype is not None:
        raise ValueError("VAE dtype is not supported in FramePack")
    # if args.batch_size != 1:
    #     args.batch_size = 1
    #     logger.info("Batch size is set to 1 for FramePack.")

    main(args)
fpack_cache_text_encoder_outputs.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ from typing import Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from tqdm import tqdm
8
+ from transformers import LlamaTokenizerFast, LlamaModel, CLIPTokenizer, CLIPTextModel
9
+ from dataset import config_utils
10
+ from dataset.config_utils import BlueprintGenerator, ConfigSanitizer
11
+ from dataset.image_video_dataset import ARCHITECTURE_FRAMEPACK, ItemInfo, save_text_encoder_output_cache_framepack
12
+ import cache_text_encoder_outputs
13
+ from frame_pack import hunyuan
14
+ from frame_pack.framepack_utils import load_text_encoder1, load_text_encoder2
15
+
16
+ import logging
17
+
18
+ from frame_pack.utils import crop_or_pad_yield_mask
19
+
20
+ logger = logging.getLogger(__name__)
21
+ logging.basicConfig(level=logging.INFO)
22
+
23
+
24
def encode_and_save_batch(
    tokenizer1: LlamaTokenizerFast,
    text_encoder1: LlamaModel,
    tokenizer2: CLIPTokenizer,
    text_encoder2: CLIPTextModel,
    batch: list[ItemInfo],
    device: torch.device,
):
    """Encode the captions of *batch* and write a text-encoder cache per item.

    FramePack's ``encode_prompt_conds`` only supports a single prompt, so each
    caption is encoded separately; the LLaMA hidden states are cropped or
    padded to 512 tokens, yielding an attention mask.
    """
    # Phase 1: encode every caption (batch dim removed from each output).
    encoded = []
    for item in batch:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(
                item.caption, text_encoder1, text_encoder2, tokenizer1, tokenizer2
            )
            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
        encoded.append((llama_vec.squeeze(0), llama_attention_mask.squeeze(0), clip_l_pooler.squeeze(0)))

    # Phase 2: save llama_vec, attention mask and clip_l_pooler to each item's cache.
    for item, (llama_vec, llama_attention_mask, clip_l_pooler) in zip(batch, encoded):
        save_text_encoder_output_cache_framepack(item, llama_vec, llama_attention_mask, clip_l_pooler)
55
+
56
+
57
def main(args):
    """Cache FramePack text-encoder outputs for every dataset in the config.

    Loads the dataset blueprint and both text encoders, encodes all captions,
    writes per-item cache files, and finally removes stale cache files unless
    ``--keep_cache`` is set.
    """
    # Pick the compute device: explicit --device wins, else CUDA when available.
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    # Load dataset config
    blueprint_generator = BlueprintGenerator(ConfigSanitizer())
    logger.info(f"Load dataset config from {args.dataset_config}")
    user_config = config_utils.load_user_config(args.dataset_config)
    blueprint = blueprint_generator.generate(user_config, args, architecture=ARCHITECTURE_FRAMEPACK)
    train_dataset_group = config_utils.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    datasets = train_dataset_group.datasets

    # prepare cache files and paths: all_cache_files_for_dataset = existing cache files, all_cache_paths_for_dataset = all cache paths in the dataset
    all_cache_files_for_dataset, all_cache_paths_for_dataset = cache_text_encoder_outputs.prepare_cache_files_and_paths(datasets)

    # load text encoder (text_encoder1 optionally in fp8 to save memory)
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    text_encoder2.to(device)

    # Encode with Text Encoders
    logger.info("Encoding with Text Encoders")

    def encode_for_text_encoder(batch: list[ItemInfo]):
        encode_and_save_batch(tokenizer1, text_encoder1, tokenizer2, text_encoder2, batch, device)

    cache_text_encoder_outputs.process_text_encoder_batches(
        args.num_workers,
        args.skip_existing,
        args.batch_size,
        datasets,
        all_cache_files_for_dataset,
        all_cache_paths_for_dataset,
        encode_for_text_encoder,
    )

    # remove cache files not in dataset
    cache_text_encoder_outputs.post_process_cache_files(datasets, all_cache_files_for_dataset, all_cache_paths_for_dataset, args.keep_cache)
96
+
97
+
98
def framepack_setup_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register FramePack text-encoder caching options on *parser* and return it."""
    options = [
        (("--text_encoder1",), dict(type=str, required=True, help="Text Encoder 1 directory")),
        (("--text_encoder2",), dict(type=str, required=True, help="Text Encoder 2 directory")),
        (("--fp8_llm",), dict(action="store_true", help="use fp8 for Text Encoder 1 (LLM)")),
    ]
    for flags, kwargs in options:
        parser.add_argument(*flags, **kwargs)
    return parser
103
+
104
+
105
if __name__ == "__main__":
    # Build the CLI: common text-encoder caching options + FramePack encoder paths.
    parser = cache_text_encoder_outputs.setup_parser_common()
    parser = framepack_setup_parser(parser)

    args = parser.parse_args()
    main(args)
fpack_generate_video.py ADDED
@@ -0,0 +1,1711 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datetime import datetime
3
+ import gc
4
+ import json
5
+ import random
6
+ import os
7
+ import re
8
+ import time
9
+ import math
10
+ import copy
11
+ from typing import Tuple, Optional, List, Union, Any, Dict
12
+
13
+ import torch
14
+ from safetensors.torch import load_file, save_file
15
+ from safetensors import safe_open
16
+ from PIL import Image
17
+ import cv2
18
+ import numpy as np
19
+ import torchvision.transforms.functional as TF
20
+ from transformers import LlamaModel
21
+ from tqdm import tqdm
22
+
23
+ from networks import lora_framepack
24
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
25
+ from frame_pack import hunyuan
26
+ from frame_pack.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked, load_packed_model
27
+ from frame_pack.utils import crop_or_pad_yield_mask, resize_and_center_crop, soft_append_bcthw
28
+ from frame_pack.bucket_tools import find_nearest_bucket
29
+ from frame_pack.clip_vision import hf_clip_vision_encode
30
+ from frame_pack.k_diffusion_hunyuan import sample_hunyuan
31
+ from dataset import image_video_dataset
32
+
33
+ try:
34
+ from lycoris.kohya import create_network_from_weights
35
+ except:
36
+ pass
37
+
38
+ from utils.device_utils import clean_memory_on_device
39
+ from hv_generate_video import save_images_grid, save_videos_grid, synchronize_device
40
+ from wan_generate_video import merge_lora_weights
41
+ from frame_pack.framepack_utils import load_vae, load_text_encoder1, load_text_encoder2, load_image_encoders
42
+ from dataset.image_video_dataset import load_video
43
+
44
+ import logging
45
+
46
+ logger = logging.getLogger(__name__)
47
+ logging.basicConfig(level=logging.INFO)
48
+
49
+
50
class GenerationSettings:
    """Simple container for per-run video-generation settings."""

    def __init__(self, device: torch.device, dit_weight_dtype: Optional[torch.dtype] = None):
        # Target device for inference (e.g. cuda / cpu).
        self.device = device
        self.dit_weight_dtype = dit_weight_dtype  # not used currently because model may be optimized
54
+
55
+
56
+ def parse_args() -> argparse.Namespace:
57
+ """parse command line arguments"""
58
+ parser = argparse.ArgumentParser(description="Wan 2.1 inference script")
59
+
60
+ # WAN arguments
61
+ # parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to the checkpoint directory (Wan 2.1 official).")
62
+ parser.add_argument(
63
+ "--sample_solver", type=str, default="unipc", choices=["unipc", "dpm++", "vanilla"], help="The solver used to sample."
64
+ )
65
+
66
+ parser.add_argument("--dit", type=str, default=None, help="DiT directory or path")
67
+ parser.add_argument("--vae", type=str, default=None, help="VAE directory or path")
68
+ parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory or path")
69
+ parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory or path")
70
+ parser.add_argument("--image_encoder", type=str, required=True, help="Image Encoder directory or path")
71
+ parser.add_argument("--f1", action="store_true", help="Use F1 sampling method")
72
+
73
+ # LoRA
74
+ parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
75
+ parser.add_argument("--lora_multiplier", type=float, nargs="*", default=1.0, help="LoRA multiplier")
76
+ parser.add_argument("--include_patterns", type=str, nargs="*", default=None, help="LoRA module include patterns")
77
+ parser.add_argument("--exclude_patterns", type=str, nargs="*", default=None, help="LoRA module exclude patterns")
78
+ parser.add_argument(
79
+ "--save_merged_model",
80
+ type=str,
81
+ default=None,
82
+ help="Save merged model to path. If specified, no inference will be performed.",
83
+ )
84
+
85
+ # inference
86
+ parser.add_argument(
87
+ "--prompt",
88
+ type=str,
89
+ default=None,
90
+ help="prompt for generation. If `;;;` is used, it will be split into sections. Example: `section_index:prompt` or "
91
+ "`section_index:prompt;;;section_index:prompt;;;...`, section_index can be `0` or `-1` or `0-2`, `-1` means last section, `0-2` means from 0 to 2 (inclusive).",
92
+ )
93
+ parser.add_argument(
94
+ "--negative_prompt",
95
+ type=str,
96
+ default=None,
97
+ help="negative prompt for generation, default is empty string. should not change.",
98
+ )
99
+ parser.add_argument(
100
+ "--custom_system_prompt",
101
+ type=str,
102
+ default=None,
103
+ help="Custom system prompt for LLM. If specified, it will override the default system prompt. See hunyuan_model/text_encoder.py for the default system prompt.",
104
+ )
105
+ parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size, height and width")
106
+ parser.add_argument("--video_seconds", type=float, default=5.0, help="video length, default is 5.0 seconds")
107
+ parser.add_argument(
108
+ "--video_sections",
109
+ type=int,
110
+ default=None,
111
+ help="number of video sections, Default is None (auto calculate from video seconds)",
112
+ )
113
+ parser.add_argument(
114
+ "--one_frame_inference",
115
+ type=str,
116
+ default=None,
117
+ help="one frame inference, default is None, comma separated values from 'zero_post', 'no_2x', 'no_4x' and 'no_post'.",
118
+ )
119
+ parser.add_argument(
120
+ "--image_mask_path",
121
+ type=str,
122
+ default=None,
123
+ help="path to image mask for one frame inference. If specified, it will be used as mask for input image.",
124
+ )
125
+ parser.add_argument(
126
+ "--end_image_mask_path",
127
+ type=str,
128
+ default=None,
129
+ nargs="*",
130
+ help="path to end (reference) image mask for one frame inference. If specified, it will be used as mask for end image.",
131
+ )
132
+ parser.add_argument("--fps", type=int, default=30, help="video fps, default is 30")
133
+ parser.add_argument("--infer_steps", type=int, default=25, help="number of inference steps, default is 25")
134
+ parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
135
+ parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
136
+ # parser.add_argument(
137
+ # "--cpu_noise", action="store_true", help="Use CPU to generate noise (compatible with ComfyUI). Default is False."
138
+ # )
139
+ parser.add_argument("--latent_window_size", type=int, default=9, help="latent window size, default is 9. should not change.")
140
+ parser.add_argument(
141
+ "--embedded_cfg_scale", type=float, default=10.0, help="Embeded CFG scale (distilled CFG Scale), default is 10.0"
142
+ )
143
+ parser.add_argument(
144
+ "--guidance_scale",
145
+ type=float,
146
+ default=1.0,
147
+ help="Guidance scale for classifier free guidance. Default is 1.0 (no guidance), should not change.",
148
+ )
149
+ parser.add_argument("--guidance_rescale", type=float, default=0.0, help="CFG Re-scale, default is 0.0. Should not change.")
150
+ # parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
151
+ parser.add_argument(
152
+ "--image_path",
153
+ type=str,
154
+ default=None,
155
+ help="path to image for image2video inference. If `;;;` is used, it will be used as section images. The notation is same as `--prompt`.",
156
+ )
157
+ parser.add_argument("--end_image_path", type=str, nargs="*", default=None, help="path to end image for image2video inference")
158
+ parser.add_argument(
159
+ "--latent_paddings",
160
+ type=str,
161
+ default=None,
162
+ help="latent paddings for each section, comma separated values. default is None (FramePack default paddings)",
163
+ )
164
+ # parser.add_argument(
165
+ # "--control_path",
166
+ # type=str,
167
+ # default=None,
168
+ # help="path to control video for inference with controlnet. video file or directory with images",
169
+ # )
170
+ # parser.add_argument("--trim_tail_frames", type=int, default=0, help="trim tail N frames from the video before saving")
171
+
172
+ # # Flow Matching
173
+ # parser.add_argument(
174
+ # "--flow_shift",
175
+ # type=float,
176
+ # default=None,
177
+ # help="Shift factor for flow matching schedulers. Default depends on task.",
178
+ # )
179
+
180
+ parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
181
+ parser.add_argument("--fp8_scaled", action="store_true", help="use scaled fp8 for DiT, only for fp8")
182
+ # parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arithmetic (RTX 4XXX+), only for fp8_scaled")
183
+ parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
184
+ parser.add_argument(
185
+ "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
186
+ )
187
+ parser.add_argument(
188
+ "--attn_mode",
189
+ type=str,
190
+ default="torch",
191
+ choices=["flash", "torch", "sageattn", "xformers", "sdpa"], # "flash2", "flash3",
192
+ help="attention mode",
193
+ )
194
+ parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
195
+ parser.add_argument(
196
+ "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
197
+ )
198
+ parser.add_argument("--bulk_decode", action="store_true", help="decode all frames at once")
199
+ parser.add_argument("--blocks_to_swap", type=int, default=0, help="number of blocks to swap in the model")
200
+ parser.add_argument(
201
+ "--output_type",
202
+ type=str,
203
+ default="video",
204
+ choices=["video", "images", "latent", "both", "latent_images"],
205
+ help="output type",
206
+ )
207
+ parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
208
+ parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
209
+ parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
210
+ # parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
211
+ # parser.add_argument(
212
+ # "--compile_args",
213
+ # nargs=4,
214
+ # metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
215
+ # default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
216
+ # help="Torch.compile settings",
217
+ # )
218
+
219
+ # New arguments for batch and interactive modes
220
+ parser.add_argument("--from_file", type=str, default=None, help="Read prompts from a file")
221
+ parser.add_argument("--interactive", action="store_true", help="Interactive mode: read prompts from console")
222
+
223
+ args = parser.parse_args()
224
+
225
+ # Validate arguments
226
+ if args.from_file and args.interactive:
227
+ raise ValueError("Cannot use both --from_file and --interactive at the same time")
228
+
229
+ if args.latent_path is None or len(args.latent_path) == 0:
230
+ if args.prompt is None and not args.from_file and not args.interactive:
231
+ raise ValueError("Either --prompt, --from_file or --interactive must be specified")
232
+
233
+ return args
234
+
235
+
236
def parse_prompt_line(line: str) -> Dict[str, Any]:
    """Parse a single prompt line into a dictionary of argument overrides.

    The line has the form ``<prompt> --opt value --opt value ...``; each short
    option name maps to one CLI argument. Unknown options are silently ignored.

    Args:
        line: Prompt line with optional ``--xx value`` settings appended.

    Returns:
        Dict[str, Any]: Dictionary of argument overrides keyed by argument name.
    """
    # TODO common function with hv_train_network.line_to_prompt_dict
    segments = line.split(" --")

    overrides: Dict[str, Any] = {"prompt": segments[0].strip()}
    # these two options may be given multiple times, so collect them in lists
    overrides["end_image_path"] = []
    overrides["end_image_mask_path"] = []

    # single-valued options: short name -> (argument name, converter)
    # note: both "g" and "l" map to guidance_scale
    scalar_options = {
        "w": ("video_size_width", int),
        "h": ("video_size_height", int),
        "f": ("video_seconds", float),
        "d": ("seed", int),
        "s": ("infer_steps", int),
        "g": ("guidance_scale", float),
        "l": ("guidance_scale", float),
        "i": ("image_path", str),
        "im": ("image_mask_path", str),
        "n": ("negative_prompt", str),
        "vs": ("video_sections", int),
        "of": ("one_frame_inference", str),
    }

    for segment in segments[1:]:
        if not segment.strip():
            continue
        pieces = segment.split(" ", 1)
        option = pieces[0].strip()
        value = pieces[1].strip() if len(pieces) > 1 else ""

        if option in scalar_options:
            arg_name, converter = scalar_options[option]
            overrides[arg_name] = converter(value)
        elif option == "ei":  # end_image_path (repeatable)
            overrides["end_image_path"].append(value)
        elif option == "eim":  # end_image_mask_path (repeatable)
            overrides["end_image_mask_path"].append(value)

    # drop the list-valued entries when nothing was collected, so the
    # base argparse values are not overridden by empty lists
    if not overrides["end_image_path"]:
        del overrides["end_image_path"]
    if not overrides["end_image_mask_path"]:
        del overrides["end_image_mask_path"]

    return overrides
301
+
302
+
303
def apply_overrides(args: argparse.Namespace, overrides: Dict[str, Any]) -> argparse.Namespace:
    """Return a deep copy of *args* with the given overrides applied.

    ``video_size_width`` / ``video_size_height`` are mapped onto the two
    elements of ``args.video_size`` (stored as [height, width]); every other
    key is assigned verbatim as an attribute.

    Args:
        args: Original arguments (left unmodified).
        overrides: Dictionary of overrides.

    Returns:
        argparse.Namespace: New arguments with overrides applied.
    """
    new_args = copy.deepcopy(args)

    # keys that address an element of the video_size list instead of a plain attribute
    video_size_index = {"video_size_width": 1, "video_size_height": 0}

    for name, value in overrides.items():
        if name in video_size_index:
            new_args.video_size[video_size_index[name]] = value
        else:
            setattr(new_args, name, value)

    return new_args
324
+
325
+
326
def check_inputs(args: argparse.Namespace) -> Tuple[int, int, float]:
    """Validate video size and compute the effective video length.

    Fix: the return annotation previously said ``Tuple[int, int, int]`` while
    both the docstring and the actual return value use a float for
    ``video_seconds`` (it comes from ``--video_seconds`` or a division by fps).

    Args:
        args: command line arguments

    Returns:
        Tuple[int, int, float]: (height, width, video_seconds)

    Raises:
        ValueError: if height or width is not divisible by 8.
    """
    height = args.video_size[0]
    width = args.video_size[1]

    video_seconds = args.video_seconds
    if args.video_sections is not None:
        # each section yields latent_window_size * 4 frames, plus one initial frame
        video_seconds = (args.video_sections * (args.latent_window_size * 4) + 1) / args.fps

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    return height, width, video_seconds
346
+
347
+
348
+ # region DiT model
349
+
350
+
351
def load_dit_model(args: argparse.Namespace, device: torch.device) -> HunyuanVideoTransformer3DModelPacked:
    """Load the FramePack DiT model.

    Weights stay on CPU whenever they still need post-processing (block swap,
    scaled-fp8 conversion, or LoRA merging); otherwise they are loaded
    directly onto the target device.

    Args:
        args: command line arguments
        device: target inference device

    Returns:
        HunyuanVideoTransformer3DModelPacked: DiT model
    """
    can_load_directly = args.blocks_to_swap == 0 and not args.fp8_scaled and args.lora_weight is None
    loading_device = device if can_load_directly else "cpu"

    # fp8 optimization is deferred here so LoRA weights can be merged first
    return load_packed_model(device, args.dit, args.attn_mode, loading_device)
370
+
371
+
372
def optimize_model(model: HunyuanVideoTransformer3DModelPacked, args: argparse.Namespace, device: torch.device) -> None:
    """Optimize the model in place (FP8 conversion, device move, block swap).

    Fix: the non-scaled ``--fp8`` path referenced ``torch.float8e4m3fn``,
    which does not exist (the attribute is ``torch.float8_e4m3fn``), so using
    ``--fp8`` without ``--fp8_scaled`` raised AttributeError.

    Args:
        model: dit model (modified in place)
        args: command line arguments
        device: device to use
    """
    if args.fp8_scaled:
        # load state dict as-is and optimize to fp8
        state_dict = model.state_dict()

        # if no blocks to swap, we can move the weights to GPU after optimization on GPU
        # (omits a redundant CPU->GPU copy)
        move_to_device = args.blocks_to_swap == 0  # if blocks_to_swap > 0, keep the model on CPU
        state_dict = model.fp8_optimization(state_dict, device, move_to_device, use_scaled_mm=False)  # args.fp8_fast)

        info = model.load_state_dict(state_dict, strict=True, assign=True)
        logger.info(f"Loaded FP8 optimized weights: {info}")

        if args.blocks_to_swap == 0:
            model.to(device)  # make sure all parameters are on the right device (e.g. RoPE etc.)
    else:
        # simple cast to the target dtype
        target_dtype = None  # None: load as-is (keep the dtype of the weights in state_dict)
        target_device = None

        if args.fp8:
            # fixed from torch.float8e4m3fn (nonexistent attribute)
            target_dtype = torch.float8_e4m3fn

        if args.blocks_to_swap == 0:
            logger.info(f"Move model to device: {device}")
            target_device = device

        if target_device is not None and target_dtype is not None:
            model.to(target_device, target_dtype)  # move and cast at once; reduces redundant copy operations

    if args.blocks_to_swap > 0:
        logger.info(f"Enable swap {args.blocks_to_swap} blocks to CPU from device: {device}")
        model.enable_block_swap(args.blocks_to_swap, device, supports_backward=False)
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()
    else:
        # make sure the model is on the right device
        model.to(device)

    model.eval().requires_grad_(False)
    clean_memory_on_device(device)
434
+
435
+
436
+ # endregion
437
+
438
+
439
def decode_latent(
    latent_window_size: int,
    total_latent_sections: int,
    bulk_decode: bool,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
    one_frame_inference_mode: bool = False,
) -> torch.Tensor:
    """Decode generated latents into pixel frames with the VAE.

    Three paths:
    - sectioned decode (default): slice the latent into per-section chunks,
      decode each, and blend consecutive chunks with soft_append_bcthw;
    - bulk decode: decode the whole latent in one VAE call;
    - one-frame mode: decode each latent frame independently and concatenate.

    Args:
        latent_window_size: latents per generation window (default 9 elsewhere)
        total_latent_sections: number of generated sections in `latent`
        bulk_decode: decode everything in a single VAE call
        vae: VAE model (moved to `device` for decoding, back to CPU after)
        latent: latent tensor, CFHW or NCFHW (batch dim added if missing)
        device: device to run the VAE on
        one_frame_inference_mode: decode frame-by-frame instead of as video

    Returns:
        torch.Tensor: decoded pixels, CTHW (batch dimension removed)
    """
    logger.info(f"Decoding video...")
    if latent.ndim == 4:
        latent = latent.unsqueeze(0)  # add batch dimension

    vae.to(device)
    if not bulk_decode and not one_frame_inference_mode:
        latent_window_size = latent_window_size  # default is 9 (no-op self-assignment, kept as-is)
        # total_latent_sections = (args.video_seconds * 30) / (latent_window_size * 4)
        # total_latent_sections = int(max(round(total_latent_sections), 1))
        num_frames = latent_window_size * 4 - 3

        latents_to_decode = []
        latent_frame_index = 0
        # NOTE(review): iterates i from total-1 down to 0, so the first slice taken
        # (at frame index 0) is treated as the "last section" and gets the extra
        # frame — this mirrors FramePack's reverse-order generation; confirm against
        # the sampling loop before changing.
        for i in range(total_latent_sections - 1, -1, -1):
            is_last_section = i == total_latent_sections - 1
            generated_latent_frames = (num_frames + 3) // 4 + (1 if is_last_section else 0)
            section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)

            section_latent = latent[:, :, latent_frame_index : latent_frame_index + section_latent_frames, :, :]
            if section_latent.shape[2] > 0:
                latents_to_decode.append(section_latent)

            latent_frame_index += generated_latent_frames

        latents_to_decode = latents_to_decode[::-1]  # reverse the order of latents to decode

        history_pixels = None
        for latent in tqdm(latents_to_decode):
            if history_pixels is None:
                history_pixels = hunyuan.vae_decode(latent, vae).cpu()
            else:
                # consecutive sections overlap by one window of frames; blend them
                overlapped_frames = latent_window_size * 4 - 3
                current_pixels = hunyuan.vae_decode(latent, vae).cpu()
                history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
            clean_memory_on_device(device)
    else:
        # bulk decode
        logger.info(f"Bulk decoding or one frame inference")
        if not one_frame_inference_mode:
            history_pixels = hunyuan.vae_decode(latent, vae).cpu()  # normal
        else:
            # one frame inference: decode each latent frame independently
            history_pixels = [hunyuan.vae_decode(latent[:, :, i : i + 1, :, :], vae).cpu() for i in range(latent.shape[2])]
            history_pixels = torch.cat(history_pixels, dim=2)

    vae.to("cpu")

    logger.info(f"Decoded. Pixel shape {history_pixels.shape}")
    return history_pixels[0]  # remove batch dimension
497
+
498
+
499
def prepare_i2v_inputs(
    args: argparse.Namespace,
    device: torch.device,
    vae: AutoencoderKLCausal3D,
    shared_models: Optional[Dict] = None,
) -> Tuple[int, int, float, dict, dict, dict, Optional[list]]:
    """Prepare all conditioning inputs for I2V generation.

    Parses section-indexed prompts/images, encodes prompts with the text
    encoders, images with the CLIP image encoder and the VAE, and packs the
    results into per-section argument dictionaries. Models are moved to
    `device` only while needed and back to CPU afterwards.

    Args:
        args: command line arguments
        device: device to use
        vae: VAE model, used for image encoding
        shared_models: optional dictionary of pre-loaded models; when given,
            models are reused (moved to CPU after use) instead of loaded/freed

    Returns:
        Tuple[int, int, float, dict, dict, dict, Optional[list]]:
            (height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latents)
            where arg_c/arg_c_img map section index -> conditioning dicts,
            arg_null is the unconditional text conditioning, and end_latents
            is a list of VAE-encoded end images or None.
    """

    height, width, video_seconds = check_inputs(args)

    # define parsing function
    # Splits ";;;"-separated entries of the form "idx:value" or "start-end:value"
    # into a {section_index: value} dict; entries without an index go to section 0.
    def parse_section_strings(input_string: str) -> dict[int, str]:
        section_strings = {}
        if ";;;" in input_string:
            split_section_strings = input_string.split(";;;")
            for section_str in split_section_strings:
                if ":" not in section_str:
                    start = end = 0
                    section_str = section_str.strip()
                else:
                    index_str, section_str = section_str.split(":", 1)
                    index_str = index_str.strip()
                    section_str = section_str.strip()

                    # index may be a single (possibly negative) number or a "start-end" range
                    m = re.match(r"^(-?\d+)(-\d+)?$", index_str)
                    if m:
                        start = int(m.group(1))
                        end = int(m.group(2)[1:]) if m.group(2) is not None else start
                    else:
                        start = end = 0
                        section_str = section_str.strip()
                for i in range(start, end + 1):
                    section_strings[i] = section_str
        else:
            section_strings[0] = input_string

        # assert 0 in section_prompts, "Section prompts must contain section 0"
        if 0 not in section_strings:
            # use smallest section index. prefer positive index over negative index
            # if all section indices are negative, use the smallest negative index
            indices = list(section_strings.keys())
            if all(i < 0 for i in indices):
                section_index = min(indices)
            else:
                section_index = min(i for i in indices if i >= 0)
            section_strings[0] = section_strings[section_index]
        return section_strings

    # prepare image
    # Loads an image, resizes it to the target bucket, and returns both the
    # normalized NCFHW tensor (for the VAE) and the resized HWC numpy array
    # (for the CLIP image encoder).
    def preprocess_image(image_path: str):
        image = Image.open(image_path).convert("RGB")

        image_np = np.array(image)  # PIL to numpy, HWC

        image_np = image_video_dataset.resize_image_to_bucket(image_np, (width, height))
        image_tensor = torch.from_numpy(image_np).float() / 127.5 - 1.0  # -1 to 1.0, HWC
        image_tensor = image_tensor.permute(2, 0, 1)[None, :, None]  # HWC -> CHW -> NCFHW, N=1, C=3, F=1
        return image_tensor, image_np

    section_image_paths = parse_section_strings(args.image_path)

    section_images = {}
    for index, image_path in section_image_paths.items():
        img_tensor, img_np = preprocess_image(image_path)
        section_images[index] = (img_tensor, img_np)

    # check end images
    if args.end_image_path is not None and len(args.end_image_path) > 0:
        end_image_tensors = []
        for end_img_path in args.end_image_path:
            end_image_tensor, _ = preprocess_image(end_img_path)
            end_image_tensors.append(end_image_tensor)
    else:
        end_image_tensors = None

    # configure negative prompt
    n_prompt = args.negative_prompt if args.negative_prompt else ""

    # parse section prompts
    section_prompts = parse_section_strings(args.prompt)

    # load text encoder
    if shared_models is not None:
        tokenizer1, text_encoder1 = shared_models["tokenizer1"], shared_models["text_encoder1"]
        tokenizer2, text_encoder2 = shared_models["tokenizer2"], shared_models["text_encoder2"]
        text_encoder1.to(device)
    else:
        tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, device)
        tokenizer2, text_encoder2 = load_text_encoder2(args)
        text_encoder2.to(device)

    logger.info(f"Encoding prompt")
    llama_vecs = {}
    llama_attention_masks = {}
    clip_l_poolers = {}
    with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
        for index, prompt in section_prompts.items():
            llama_vec, clip_l_pooler = hunyuan.encode_prompt_conds(
                prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
            )
            llama_vec = llama_vec.cpu()
            clip_l_pooler = clip_l_pooler.cpu()

            llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)

            llama_vecs[index] = llama_vec
            llama_attention_masks[index] = llama_attention_mask
            clip_l_poolers[index] = clip_l_pooler

    if args.guidance_scale == 1.0:
        # no CFG: the unconditional embedding is never used, so zeros suffice
        llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vecs[0]), torch.zeros_like(clip_l_poolers[0])
    else:
        with torch.autocast(device_type=device.type, dtype=text_encoder1.dtype), torch.no_grad():
            llama_vec_n, clip_l_pooler_n = hunyuan.encode_prompt_conds(
                n_prompt, text_encoder1, text_encoder2, tokenizer1, tokenizer2, custom_system_prompt=args.custom_system_prompt
            )
            llama_vec_n = llama_vec_n.cpu()
            clip_l_pooler_n = clip_l_pooler_n.cpu()

    llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)

    # free text encoder and clean memory
    if shared_models is not None:  # if shared models are used, do not free them but move to CPU
        text_encoder1.to("cpu")
        text_encoder2.to("cpu")
    del tokenizer1, text_encoder1, tokenizer2, text_encoder2  # do not free shared models
    clean_memory_on_device(device)

    # load image encoder
    if shared_models is not None:
        feature_extractor, image_encoder = shared_models["feature_extractor"], shared_models["image_encoder"]
    else:
        feature_extractor, image_encoder = load_image_encoders(args)
    image_encoder.to(device)

    # encode image with image encoder
    section_image_encoder_last_hidden_states = {}
    for index, (img_tensor, img_np) in section_images.items():
        with torch.no_grad():
            image_encoder_output = hf_clip_vision_encode(img_np, feature_extractor, image_encoder)
        image_encoder_last_hidden_state = image_encoder_output.last_hidden_state.cpu()
        section_image_encoder_last_hidden_states[index] = image_encoder_last_hidden_state

    # free image encoder and clean memory
    if shared_models is not None:
        image_encoder.to("cpu")
    del image_encoder, feature_extractor
    clean_memory_on_device(device)

    # VAE encoding
    logger.info(f"Encoding image to latent space")
    vae.to(device)

    section_start_latents = {}
    for index, (img_tensor, img_np) in section_images.items():
        start_latent = hunyuan.vae_encode(img_tensor, vae).cpu()
        section_start_latents[index] = start_latent

    # end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu() if end_image_tensor is not None else None
    if end_image_tensors is not None:
        end_latents = []
        for end_image_tensor in end_image_tensors:
            end_latent = hunyuan.vae_encode(end_image_tensor, vae).cpu()
            end_latents.append(end_latent)
    else:
        end_latents = None

    vae.to("cpu")  # move VAE to CPU to save memory
    clean_memory_on_device(device)

    # prepare model input arguments
    arg_c = {}
    arg_null = {}
    for index in llama_vecs.keys():
        llama_vec = llama_vecs[index]
        llama_attention_mask = llama_attention_masks[index]
        clip_l_pooler = clip_l_poolers[index]
        arg_c_i = {
            "llama_vec": llama_vec,
            "llama_attention_mask": llama_attention_mask,
            "clip_l_pooler": clip_l_pooler,
            "prompt": section_prompts[index],  # for debugging
        }
        arg_c[index] = arg_c_i

    arg_null = {
        "llama_vec": llama_vec_n,
        "llama_attention_mask": llama_attention_mask_n,
        "clip_l_pooler": clip_l_pooler_n,
    }

    arg_c_img = {}
    for index in section_images.keys():
        image_encoder_last_hidden_state = section_image_encoder_last_hidden_states[index]
        start_latent = section_start_latents[index]
        arg_c_img_i = {
            "image_encoder_last_hidden_state": image_encoder_last_hidden_state,
            "start_latent": start_latent,
            "image_path": section_image_paths[index],
        }
        arg_c_img[index] = arg_c_img_i

    return height, width, video_seconds, arg_c, arg_null, arg_c_img, end_latents
714
+
715
+
716
+ # def setup_scheduler(args: argparse.Namespace, config, device: torch.device) -> Tuple[Any, torch.Tensor]:
717
+ # """setup scheduler for sampling
718
+
719
+ # Args:
720
+ # args: command line arguments
721
+ # config: model configuration
722
+ # device: device to use
723
+
724
+ # Returns:
725
+ # Tuple[Any, torch.Tensor]: (scheduler, timesteps)
726
+ # """
727
+ # if args.sample_solver == "unipc":
728
+ # scheduler = FlowUniPCMultistepScheduler(num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False)
729
+ # scheduler.set_timesteps(args.infer_steps, device=device, shift=args.flow_shift)
730
+ # timesteps = scheduler.timesteps
731
+ # elif args.sample_solver == "dpm++":
732
+ # scheduler = FlowDPMSolverMultistepScheduler(
733
+ # num_train_timesteps=config.num_train_timesteps, shift=1, use_dynamic_shifting=False
734
+ # )
735
+ # sampling_sigmas = get_sampling_sigmas(args.infer_steps, args.flow_shift)
736
+ # timesteps, _ = retrieve_timesteps(scheduler, device=device, sigmas=sampling_sigmas)
737
+ # elif args.sample_solver == "vanilla":
738
+ # scheduler = FlowMatchDiscreteScheduler(num_train_timesteps=config.num_train_timesteps, shift=args.flow_shift)
739
+ # scheduler.set_timesteps(args.infer_steps, device=device)
740
+ # timesteps = scheduler.timesteps
741
+
742
+ # # FlowMatchDiscreteScheduler does not support generator argument in step method
743
+ # org_step = scheduler.step
744
+
745
+ # def step_wrapper(
746
+ # model_output: torch.Tensor,
747
+ # timestep: Union[int, torch.Tensor],
748
+ # sample: torch.Tensor,
749
+ # return_dict: bool = True,
750
+ # generator=None,
751
+ # ):
752
+ # return org_step(model_output, timestep, sample, return_dict=return_dict)
753
+
754
+ # scheduler.step = step_wrapper
755
+ # else:
756
+ # raise NotImplementedError("Unsupported solver.")
757
+
758
+ # return scheduler, timesteps
759
+
760
+
761
def convert_lora_for_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """Normalize a LoRA state dict into the FramePack (Musubi Tuner) format.

    Detects the source format from the key names: Musubi Tuner files are used
    as-is, diffusion-pipe style files ("lora_A"/"lora_B" keys under a
    transformer prefix) are converted to the default LoRA layout, and
    HunyuanVideo LoRAs (double_blocks/single_blocks) are remapped to
    FramePack module names.
    """
    keys = list(lora_sd.keys())

    if keys[0].startswith("lora_unet_"):
        # already in Musubi Tuner format; nothing to convert
        pass
    else:
        transformer_prefixes = ["diffusion_model", "transformer"]  # to ignore Text Encoder modules
        lora_suffix = None
        prefix = None
        for key in keys:
            if lora_suffix is None and "lora_A" in key:
                lora_suffix = "lora_A"
            if prefix is None:
                pfx = key.split(".")[0]
                if pfx in transformer_prefixes:
                    prefix = pfx
            if lora_suffix is not None and prefix is not None:
                break

        if lora_suffix == "lora_A" and prefix is not None:
            logging.info(f"Diffusion-pipe (?) LoRA detected, converting to the default LoRA format")
            lora_sd = convert_lora_from_diffusion_pipe_or_something(lora_sd, "lora_unet_")
        else:
            logging.info(f"LoRA file format not recognized. Using it as-is.")

    # HunyuanVideo LoRAs use double_blocks/single_blocks module names
    is_hunyuan = any("double_blocks" in key or "single_blocks" in key for key in lora_sd.keys())
    if is_hunyuan:
        logging.info("HunyuanVideo LoRA detected, converting to FramePack format")
        lora_sd = convert_hunyuan_to_framepack(lora_sd)

    return lora_sd
800
+
801
+
802
def convert_lora_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights to the format used by the diffusion pipeline to Musubi Tuner.
    Copy from Musubi Tuner repo.
    """
    # Diffusers(?) format: {"diffusion_model.module.name.lora_A.weight": w, "...lora_B.weight": w, ...}
    # default LoRA format: {"prefix_module_name.lora_down.weight": w, "...lora_up.weight": w, ...}
    # Diffusers carries no alpha, so alpha is set to the rank below.
    converted: dict[str, torch.Tensor] = {}
    rank_by_module: dict[str, int] = {}

    for key, weight in lora_sd.items():
        source_prefix, remainder = key.split(".", 1)
        if source_prefix not in ("diffusion_model", "transformer"):
            print(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{remainder}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        converted[new_key] = weight

        module_name = new_key.split(".")[0]  # before first dot
        if module_name not in rank_by_module and "lora_down" in new_key:
            # rank = number of rows of the lora_down matrix
            rank_by_module[module_name] = weight.shape[0]

    # add alpha with rank
    for module_name, rank in rank_by_module.items():
        converted[f"{module_name}.alpha"] = torch.tensor(rank)

    return converted
832
+
833
+
834
def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert HunyuanVideo LoRA weights to FramePack format.

    Module names are remapped with the tables below; the fused QKV (double
    blocks) and QKVM (single blocks) projections are split into separate
    q/k/v(/mlp) modules: lora_down and alpha are shared copies, lora_up is
    sliced row-wise.
    """
    double_renames = [
        ("double_blocks", "transformer_blocks"),
        ("img_mod_linear", "norm1_linear"),
        ("img_attn_qkv", "attn_to_QKV"),  # split later
        ("img_attn_proj", "attn_to_out_0"),
        ("img_mlp_fc1", "ff_net_0_proj"),
        ("img_mlp_fc2", "ff_net_2"),
        ("txt_mod_linear", "norm1_context_linear"),
        ("txt_attn_qkv", "attn_add_QKV_proj"),  # split later
        ("txt_attn_proj", "attn_to_add_out"),
        ("txt_mlp_fc1", "ff_context_net_0_proj"),
        ("txt_mlp_fc2", "ff_context_net_2"),
    ]
    single_renames = [
        ("single_blocks", "single_transformer_blocks"),
        ("linear1", "attn_to_QKVM"),  # split later
        ("linear2", "proj_out"),
        ("modulation_linear", "norm_linear"),
    ]

    new_lora_sd: dict[str, torch.Tensor] = {}
    for key, weight in lora_sd.items():
        if "double_blocks" in key:
            renames = double_renames
        elif "single_blocks" in key:
            renames = single_renames
        else:
            print(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
            continue
        for old, new in renames:
            key = key.replace(old, new)

        if "QKVM" in key:
            # fused Q/K/V/MLP projection of single blocks: split into four modules
            split_keys = [key.replace("QKVM", t) for t in ("q", "k", "v")]
            split_keys.append(key.replace("attn_to_QKVM", "proj_mlp"))
            if "_down" in key or "alpha" in key:
                # lora_down (and alpha) is shared: copy it to all four modules
                assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
                for k in split_keys:
                    new_lora_sd[k] = weight
            elif "_up" in key:
                # lora_up rows are laid out [Q|K|V|MLP]: slice them apart (21504 = 3*3072 + 12288)
                assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
                slices = (weight[:3072], weight[3072 : 3072 * 2], weight[3072 * 2 : 3072 * 3], weight[3072 * 3 :])
                for k, w in zip(split_keys, slices):
                    new_lora_sd[k] = w
            else:
                print(f"Unsupported module name: {key}")
                continue
        elif "QKV" in key:
            # fused Q/K/V projection: split into three modules
            split_keys = [key.replace("QKV", t) for t in ("q", "k", "v")]
            if "_down" in key or "alpha" in key:
                # lora_down (and alpha) is shared: copy it to all three modules
                assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
                for k in split_keys:
                    new_lora_sd[k] = weight
            elif "_up" in key:
                # lora_up rows are laid out [Q|K|V]: slice them apart
                assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
                slices = (weight[:3072], weight[3072 : 3072 * 2], weight[3072 * 2 :])
                for k, w in zip(split_keys, slices):
                    new_lora_sd[k] = w
            else:
                print(f"Unsupported module name: {key}")
                continue
        else:
            # no split needed
            new_lora_sd[key] = weight

    return new_lora_sd
909
+
910
+
911
def generate(
    args: argparse.Namespace, gen_settings: GenerationSettings, shared_models: Optional[Dict] = None
) -> tuple[AutoencoderKLCausal3D, torch.Tensor]:
    """main function for generation

    Runs FramePack section-by-section sampling: prepares conditioning inputs,
    loads (or reuses) the DiT model, then samples each latent section and
    stitches the growing latent history.

    Args:
        args: command line arguments
        gen_settings: device / DiT weight dtype settings
        shared_models: dictionary containing pre-loaded models; may be updated
            in place with the loaded DiT model under key "model"

    Returns:
        tuple: (AutoencoderKLCausal3D model (vae), torch.Tensor generated latent)
        (None, None) when args.save_merged_model is set.
    """
    device, dit_weight_dtype = (gen_settings.device, gen_settings.dit_weight_dtype)

    # prepare seed
    seed = args.seed if args.seed is not None else random.randint(0, 2**32 - 1)
    args.seed = seed  # set seed to args for saving

    # Check if we have shared models
    if shared_models is not None:
        # Use shared models and encoded data
        vae = shared_models.get("vae")
        height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(
            args, device, vae, shared_models
        )
    else:
        # prepare inputs without shared models
        vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
        height, width, video_seconds, context, context_null, context_img, end_latents = prepare_i2v_inputs(args, device, vae)

    if shared_models is None or "model" not in shared_models:
        # load DiT model
        model = load_dit_model(args, device)

        # merge LoRA weights
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            # ugly hack to common merge_lora_weights function
            merge_lora_weights(lora_framepack, model, args, device, convert_lora_for_framepack)

            # if we only want to save the model, we can skip the rest
            if args.save_merged_model:
                return None, None

        # optimize model: fp8 conversion, block swap etc.
        optimize_model(model, args, device)

        if shared_models is not None:
            shared_models["model"] = model
    else:
        # use shared model
        model: HunyuanVideoTransformer3DModelPacked = shared_models["model"]
        model.move_to_device_except_swap_blocks(device)
        model.prepare_block_swap_before_forward()

    # sampling
    latent_window_size = args.latent_window_size  # default is 9
    # ex: (5s * 30fps) / (9 * 4) = 4.16 -> 4 sections, 60s -> 1800 / 36 = 50 sections
    total_latent_sections = (video_seconds * 30) / (latent_window_size * 4)
    total_latent_sections = int(max(round(total_latent_sections), 1))

    # set random generator
    seed_g = torch.Generator(device="cpu")
    seed_g.manual_seed(seed)
    num_frames = latent_window_size * 4 - 3

    logger.info(
        f"Video size: {height}x{width}@{video_seconds} (HxW@seconds), fps: {args.fps}, num sections: {total_latent_sections}, "
        f"infer_steps: {args.infer_steps}, frames per generation: {num_frames}"
    )

    # video generation ######
    f1_mode = args.f1
    one_frame_inference = None
    if args.one_frame_inference is not None:
        # parse comma-separated one-frame-inference options into a set
        one_frame_inference = set()
        for mode in args.one_frame_inference.split(","):
            one_frame_inference.add(mode.strip())

    # prepare history latents (1 post frame + 2x + 4x context slots)
    history_latents = torch.zeros((1, 16, 1 + 2 + 16, height // 8, width // 8), dtype=torch.float32)
    if end_latents is not None and not f1_mode:
        logger.info(f"Use end image(s): {args.end_image_path}")
        for i, end_latent in enumerate(end_latents):
            history_latents[:, :, i + 1 : i + 2] = end_latent.to(history_latents)

    # prepare clean latents and indices
    if not f1_mode:
        # Inverted Anti-drifting
        total_generated_latent_frames = 0
        latent_paddings = reversed(range(total_latent_sections))

        if total_latent_sections > 4 and one_frame_inference is None:
            # In theory the latent_paddings should follow the above sequence, but it seems that duplicating some
            # items looks better than expanding it when total_latent_sections > 4
            # One can try to remove below trick and just
            # use `latent_paddings = list(reversed(range(total_latent_sections)))` to compare
            # 4 sections: 3, 2, 1, 0. 50 sections: 3, 2, 2, ... 2, 1, 0
            latent_paddings = [3] + [2] * (total_latent_sections - 3) + [1, 0]

        if args.latent_paddings is not None:
            # parse user defined latent paddings
            user_latent_paddings = [int(x) for x in args.latent_paddings.split(",")]
            if len(user_latent_paddings) < total_latent_sections:
                print(
                    f"User defined latent paddings length {len(user_latent_paddings)} does not match total sections {total_latent_sections}."
                )
                print(f"Use default paddings instead for unspecified sections.")
                # NOTE(review): if total_latent_sections <= 4, latent_paddings is still a
                # `reversed` iterator here, and slice assignment would raise TypeError — confirm
                latent_paddings[: len(user_latent_paddings)] = user_latent_paddings
            elif len(user_latent_paddings) > total_latent_sections:
                print(
                    f"User defined latent paddings length {len(user_latent_paddings)} is greater than total sections {total_latent_sections}."
                )
                print(f"Use only first {total_latent_sections} paddings instead.")
                latent_paddings = user_latent_paddings[:total_latent_sections]
            else:
                latent_paddings = user_latent_paddings
    else:
        start_latent = context_img[0]["start_latent"]
        history_latents = torch.cat([history_latents, start_latent], dim=2)
        total_generated_latent_frames = 1  # a bit hacky, but we employ the same logic as in official code
        latent_paddings = [0] * total_latent_sections  # dummy paddings for F1 mode

    latent_paddings = list(latent_paddings)  # make sure it's a list
    for loop_index in range(total_latent_sections):
        latent_padding = latent_paddings[loop_index]

        if not f1_mode:
            # Inverted Anti-drifting: sections are generated in reverse order
            section_index_reverse = loop_index  # 0, 1, 2, 3
            section_index = total_latent_sections - 1 - section_index_reverse  # 3, 2, 1, 0
            section_index_from_last = -(section_index_reverse + 1)  # -1, -2, -3, -4

            is_last_section = section_index == 0
            is_first_section = section_index_reverse == 0
            latent_padding_size = latent_padding * latent_window_size

            logger.info(f"latent_padding_size = {latent_padding_size}, is_last_section = {is_last_section}")
        else:
            section_index = loop_index  # 0, 1, 2, 3
            section_index_from_last = section_index - total_latent_sections  # -4, -3, -2, -1
            is_last_section = loop_index == total_latent_sections - 1
            is_first_section = loop_index == 0
            latent_padding_size = 0  # dummy padding for F1 mode

        # select start latent: prefer a per-section image keyed from the end, then from the start
        if section_index_from_last in context_img:
            image_index = section_index_from_last
        elif section_index in context_img:
            image_index = section_index
        else:
            image_index = 0

        start_latent = context_img[image_index]["start_latent"]
        image_path = context_img[image_index]["image_path"]
        if image_index != 0:  # use section image other than section 0
            logger.info(f"Apply experimental section image, latent_padding_size = {latent_padding_size}, image_path = {image_path}")

        if not f1_mode:
            # Inverted Anti-drifting
            indices = torch.arange(0, sum([1, latent_padding_size, latent_window_size, 1, 2, 16])).unsqueeze(0)
            (
                clean_latent_indices_pre,
                blank_indices,
                latent_indices,
                clean_latent_indices_post,
                clean_latent_2x_indices,
                clean_latent_4x_indices,
            ) = indices.split([1, latent_padding_size, latent_window_size, 1, 2, 16], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_pre, clean_latent_indices_post], dim=1)

            clean_latents_pre = start_latent.to(history_latents)
            clean_latents_post, clean_latents_2x, clean_latents_4x = history_latents[:, :, : 1 + 2 + 16, :, :].split(
                [1, 2, 16], dim=2
            )
            clean_latents = torch.cat([clean_latents_pre, clean_latents_post], dim=2)

            if end_latents is not None:
                # extend clean latents/indices with the provided end-image latents
                clean_latents = torch.cat([clean_latents_pre, history_latents[:, :, : len(end_latents)]], dim=2)
                clean_latent_indices_extended = torch.zeros(1, 1 + len(end_latents), dtype=clean_latent_indices.dtype)
                clean_latent_indices_extended[:, :2] = clean_latent_indices
                clean_latent_indices = clean_latent_indices_extended

        else:
            # F1 mode
            indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
            (
                clean_latent_indices_start,
                clean_latent_4x_indices,
                clean_latent_2x_indices,
                clean_latent_1x_indices,
                latent_indices,
            ) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
            clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)

            clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[:, :, -sum([16, 2, 1]) :, :, :].split(
                [16, 2, 1], dim=2
            )
            clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)

        # if use_teacache:
        #     transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
        # else:
        #     transformer.initialize_teacache(enable_teacache=False)

        # prepare conditioning inputs: per-section prompt keyed from the end, then the start
        if section_index_from_last in context:
            prompt_index = section_index_from_last
        elif section_index in context:
            prompt_index = section_index
        else:
            prompt_index = 0

        context_for_index = context[prompt_index]
        # if args.section_prompts is not None:
        logger.info(f"Section {section_index}: {context_for_index['prompt']}")

        llama_vec = context_for_index["llama_vec"].to(device, dtype=torch.bfloat16)
        llama_attention_mask = context_for_index["llama_attention_mask"].to(device)
        clip_l_pooler = context_for_index["clip_l_pooler"].to(device, dtype=torch.bfloat16)

        image_encoder_last_hidden_state = context_img[image_index]["image_encoder_last_hidden_state"].to(
            device, dtype=torch.bfloat16
        )

        llama_vec_n = context_null["llama_vec"].to(device, dtype=torch.bfloat16)
        llama_attention_mask_n = context_null["llama_attention_mask"].to(device)
        clip_l_pooler_n = context_null["clip_l_pooler"].to(device, dtype=torch.bfloat16)

        # call DiT model to generate latents
        sample_num_frames = num_frames
        if one_frame_inference is not None:
            # one frame inference
            latent_indices = latent_indices[:, -1:]  # only use the last frame (default)
            sample_num_frames = 1

            def get_latent_mask(mask_path: str):
                # Load a grayscale mask and resize to latent resolution, values in [0, 1].
                mask_image = Image.open(mask_path).convert("L")  # grayscale
                mask_image = mask_image.resize((width // 8, height // 8), Image.LANCZOS)
                mask_image = np.array(mask_image)  # PIL to numpy, HW (grayscale)
                mask_image = torch.from_numpy(mask_image).float() / 255.0  # 0 to 1.0
                mask_image = mask_image.squeeze(-1)  # no-op for 2D HW tensors
                mask_image = mask_image.unsqueeze(0).unsqueeze(0)  # HW -> 11HW
                mask_image = mask_image.to(clean_latents)
                return mask_image

            if args.image_mask_path is not None:
                mask_image = get_latent_mask(args.image_mask_path)
                logger.info(f"Apply mask for clean latents (start image): {args.image_mask_path}, shape: {mask_image.shape}")
                clean_latents[:, :, 0, :, :] = clean_latents[:, :, 0, :, :] * mask_image
            if args.end_image_mask_path is not None and len(args.end_image_mask_path) > 0:
                # # apply mask for clean latents 1x (end image)
                # NOTE(review): assumes end_latents is not None whenever end-image masks are given — confirm
                count = min(len(args.end_image_mask_path), len(end_latents))
                for i in range(count):
                    mask_image = get_latent_mask(args.end_image_mask_path[i])
                    logger.info(
                        f"Apply mask for clean latents 1x (end image) for {i+1}: {args.end_image_mask_path[i]}, shape: {mask_image.shape}"
                    )
                    clean_latents[:, :, i + 1 : i + 2, :, :] = clean_latents[:, :, i + 1 : i + 2, :, :] * mask_image

            for one_frame_param in one_frame_inference:
                if one_frame_param.startswith("target_index="):
                    target_index = int(one_frame_param.split("=")[1])
                    latent_indices[:, 0] = target_index
                    logger.info(f"Set index for target: {target_index}")
                elif one_frame_param.startswith("start_index="):
                    start_index = int(one_frame_param.split("=")[1])
                    clean_latent_indices[:, 0] = start_index
                    logger.info(f"Set index for clean latent pre (start image): {start_index}")
                elif one_frame_param.startswith("history_index="):
                    history_indices = one_frame_param.split("=")[1].split(";")
                    i = 0
                    while i < len(history_indices) and i < len(end_latents):
                        history_index = int(history_indices[i])
                        clean_latent_indices[:, 1 + i] = history_index
                        i += 1
                    # remaining slots reuse the last parsed history_index
                    # NOTE(review): if the first while loop runs zero iterations, history_index
                    # is unbound here and would raise NameError — confirm intended input contract
                    while i < len(end_latents):
                        clean_latent_indices[:, 1 + i] = history_index
                        i += 1
                    logger.info(f"Set index for clean latent post (end image): {history_indices}")

            if "no_2x" in one_frame_inference:
                clean_latents_2x = None
                clean_latent_2x_indices = None
                logger.info(f"No clean_latents_2x")
            if "no_4x" in one_frame_inference:
                clean_latents_4x = None
                clean_latent_4x_indices = None
                logger.info(f"No clean_latents_4x")
            if "no_post" in one_frame_inference:
                clean_latents = clean_latents[:, :, :1, :, :]
                clean_latent_indices = clean_latent_indices[:, :1]
                logger.info(f"No clean_latents post")
            elif "zero_post" in one_frame_inference:
                # zero out the history latents. this seems to prevent the images from corrupting
                clean_latents[:, :, 1:, :, :] = torch.zeros_like(clean_latents[:, :, 1:, :, :])
                logger.info(f"Zero out clean_latents post")

            logger.info(
                f"One frame inference. clean_latent: {clean_latents.shape} latent_indices: {latent_indices}, clean_latent_indices: {clean_latent_indices}, num_frames: {sample_num_frames}"
            )

        generated_latents = sample_hunyuan(
            transformer=model,
            sampler=args.sample_solver,
            width=width,
            height=height,
            frames=sample_num_frames,
            real_guidance_scale=args.guidance_scale,
            distilled_guidance_scale=args.embedded_cfg_scale,
            guidance_rescale=args.guidance_rescale,
            # shift=3.0,
            num_inference_steps=args.infer_steps,
            generator=seed_g,
            prompt_embeds=llama_vec,
            prompt_embeds_mask=llama_attention_mask,
            prompt_poolers=clip_l_pooler,
            negative_prompt_embeds=llama_vec_n,
            negative_prompt_embeds_mask=llama_attention_mask_n,
            negative_prompt_poolers=clip_l_pooler_n,
            device=device,
            dtype=torch.bfloat16,
            image_embeddings=image_encoder_last_hidden_state,
            latent_indices=latent_indices,
            clean_latents=clean_latents,
            clean_latent_indices=clean_latent_indices,
            clean_latents_2x=clean_latents_2x,
            clean_latent_2x_indices=clean_latent_2x_indices,
            clean_latents_4x=clean_latents_4x,
            clean_latent_4x_indices=clean_latent_4x_indices,
        )

        # concatenate generated latents
        total_generated_latent_frames += int(generated_latents.shape[2])
        if not f1_mode:
            # Inverted Anti-drifting: prepend generated latents to history latents
            if is_last_section:
                generated_latents = torch.cat([start_latent.to(generated_latents), generated_latents], dim=2)
                total_generated_latent_frames += 1

            history_latents = torch.cat([generated_latents.to(history_latents), history_latents], dim=2)
            real_history_latents = history_latents[:, :, :total_generated_latent_frames, :, :]
        else:
            # F1 mode: append generated latents to history latents
            history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
            real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]

        logger.info(f"Generated. Latent shape {real_history_latents.shape}")

        # # TODO support saving intermediate video
        # clean_memory_on_device(device)
        # vae.to(device)
        # if history_pixels is None:
        #     history_pixels = hunyuan.vae_decode(real_history_latents, vae).cpu()
        # else:
        #     section_latent_frames = (latent_window_size * 2 + 1) if is_last_section else (latent_window_size * 2)
        #     overlapped_frames = latent_window_size * 4 - 3
        #     current_pixels = hunyuan.vae_decode(real_history_latents[:, :, :section_latent_frames], vae).cpu()
        #     history_pixels = soft_append_bcthw(current_pixels, history_pixels, overlapped_frames)
        # vae.to("cpu")
        # # if not is_last_section:
        # #     # save intermediate video
        # #     save_video(history_pixels[0], args, total_generated_latent_frames)
        # print(f"Decoded. Current latent shape {real_history_latents.shape}; pixel shape {history_pixels.shape}")

    if one_frame_inference is not None:
        real_history_latents = real_history_latents[:, :, 1:, :, :]  # remove the first frame (start_latent)

    # Only clean up shared models if they were created within this function
    if shared_models is None:
        del model  # free memory
        synchronize_device(device)
    else:
        # move model to CPU to save memory
        model.to("cpu")

    # wait for 5 seconds until block swap is done
    logger.info("Waiting for 5 seconds to finish block swap")
    time.sleep(5)

    gc.collect()
    clean_memory_on_device(device)

    return vae, real_history_latents
1294
+
1295
+
1296
def save_latent(latent: torch.Tensor, args: argparse.Namespace, height: int, width: int) -> str:
    """Save latent to file

    Writes the latent tensor as a safetensors file, optionally embedding
    generation parameters as string metadata (disabled by --no-metadata).

    Args:
        latent: Latent tensor
        args: command line arguments
        height: height of frame
        width: width of frame

    Returns:
        str: Path to saved latent file
    """
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    seed = args.seed
    video_seconds = args.video_seconds
    latent_path = f"{save_path}/{time_flag}_{seed}_latent.safetensors"

    if args.no_metadata:
        metadata = None
    else:
        # all metadata values must be strings for safetensors
        metadata = {
            "seeds": f"{seed}",
            "prompt": f"{args.prompt}",
            "height": f"{height}",
            "width": f"{width}",
            "video_seconds": f"{video_seconds}",
            "infer_steps": f"{args.infer_steps}",
            "guidance_scale": f"{args.guidance_scale}",
            # fix: "latent_window_size" appeared twice in the original dict literal;
            # the duplicate silently overwrote the first entry
            "latent_window_size": f"{args.latent_window_size}",
            "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
            "guidance_rescale": f"{args.guidance_rescale}",
            "sample_solver": f"{args.sample_solver}",
            "fps": f"{args.fps}",
        }
        if args.negative_prompt is not None:
            metadata["negative_prompt"] = f"{args.negative_prompt}"

    sd = {"latent": latent.contiguous()}
    save_file(sd, latent_path, metadata=metadata)
    logger.info(f"Latent saved to: {latent_path}")

    return latent_path
1342
+
1343
+
1344
def save_video(
    video: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None, latent_frames: Optional[int] = None
) -> str:
    """Save video to file

    Args:
        video: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)
        latent_frames: optional latent frame count, appended to the filename

    Returns:
        str: Path to saved video file
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)
    timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    # optional filename suffixes
    name_suffix = f"_{original_base_name}" if original_base_name is not None else ""
    frames_suffix = f"_{latent_frames}" if latent_frames is not None else ""
    video_path = f"{out_dir}/{timestamp}_{args.seed}{name_suffix}{frames_suffix}.mp4"

    # add a batch dimension before writing the grid
    save_videos_grid(video.unsqueeze(0), video_path, fps=args.fps, rescale=True)
    logger.info(f"Video saved to: {video_path}")

    return video_path
1371
+
1372
+
1373
def save_images(sample: torch.Tensor, args: argparse.Namespace, original_base_name: Optional[str] = None) -> str:
    """Save images to directory

    Args:
        sample: Video tensor
        args: command line arguments
        original_base_name: Original base name (if latents are loaded from files)

    Returns:
        str: Path to saved images directory
    """
    out_dir = args.save_path
    os.makedirs(out_dir, exist_ok=True)
    timestamp = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    suffix = f"_{original_base_name}" if original_base_name is not None else ""
    image_name = f"{timestamp}_{args.seed}{suffix}"

    # one-frame mode writes directly into save_path instead of a per-run subdir
    one_frame_mode = args.one_frame_inference is not None
    save_images_grid(sample.unsqueeze(0), out_dir, image_name, rescale=True, create_subdir=not one_frame_mode)
    logger.info(f"Sample images saved to: {out_dir}/{image_name}")

    return f"{out_dir}/{image_name}"
1397
+
1398
+
1399
def save_output(
    args: argparse.Namespace,
    vae: AutoencoderKLCausal3D,
    latent: torch.Tensor,
    device: torch.device,
    original_base_names: Optional[List[str]] = None,
) -> None:
    """save output

    Saves the latent and/or the decoded video/images depending on
    args.output_type ("latent", "video", "images", "both", "latent_images").

    Args:
        args: command line arguments
        vae: VAE model
        latent: latent tensor
        device: device to use
        original_base_names: original base names (if latents are loaded from files)
    """
    # pixel dimensions are 8x the latent spatial dimensions (BCTHW layout)
    height = latent.shape[-2] * 8
    width = latent.shape[-1] * 8

    output_type = args.output_type
    if output_type in ("latent", "both", "latent_images"):
        # save latent
        save_latent(latent, args, height, width)
    if output_type == "latent":
        return

    total_latent_sections = int(max(round((args.video_seconds * 30) / (args.latent_window_size * 4)), 1))
    video = decode_latent(
        args.latent_window_size, total_latent_sections, args.bulk_decode, vae, latent, device, args.one_frame_inference is not None
    )

    original_name = "" if original_base_names is None else f"_{original_base_names[0]}"
    if output_type in ("video", "both"):
        # save video
        save_video(video, args, original_name)
    elif output_type in ("images", "latent_images"):
        # save images
        save_images(video, args, original_name)
1440
+
1441
+
1442
def preprocess_prompts_for_batch(prompt_lines: List[str], base_args: argparse.Namespace) -> List[Dict]:
    """Process multiple prompts for batch mode

    Args:
        prompt_lines: List of prompt lines
        base_args: Base command line arguments

    Returns:
        List[Dict]: List of prompt data dictionaries
    """
    parsed: List[Dict] = []

    for raw_line in prompt_lines:
        stripped = raw_line.strip()
        # skip blank lines and '#' comment lines
        if not stripped or stripped.startswith("#"):
            continue

        # turn the prompt line into an override dictionary
        prompt_data = parse_prompt_line(stripped)
        logger.info(f"Parsed prompt data: {prompt_data}")
        parsed.append(prompt_data)

    return parsed
1465
+
1466
+
1467
def load_shared_models(args: argparse.Namespace) -> Dict:
    """Load shared models for batch processing or interactive mode.
    Models are loaded to CPU to save memory.

    Args:
        args: Base command line arguments

    Returns:
        Dict: Dictionary of shared models
    """
    tokenizer1, text_encoder1 = load_text_encoder1(args, args.fp8_llm, "cpu")
    tokenizer2, text_encoder2 = load_text_encoder2(args)
    feature_extractor, image_encoder = load_image_encoders(args)
    vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, "cpu")
    return {
        "tokenizer1": tokenizer1,
        "text_encoder1": text_encoder1,
        "tokenizer2": tokenizer2,
        "text_encoder2": text_encoder2,
        "feature_extractor": feature_extractor,
        "image_encoder": image_encoder,
        "vae": vae,
    }
1491
+
1492
+
1493
def process_batch_prompts(prompts_data: List[Dict], args: argparse.Namespace) -> None:
    """Process multiple prompts with model reuse

    Generates latents for all prompts first (reusing models via shared_models),
    then frees the encoders and decodes all latents in a second pass.

    Args:
        prompts_data: List of prompt data dictionaries
        args: Base command line arguments
    """
    if not prompts_data:
        logger.warning("No valid prompts found")
        return

    # 1. Load configuration
    gen_settings = get_generation_settings(args)
    device = gen_settings.device

    # 2. Load models to CPU in advance except for VAE and DiT
    shared_models = load_shared_models(args)

    # 3. Generate for each prompt
    all_latents = []
    all_prompt_args = []

    with torch.no_grad():
        for prompt_data in prompts_data:
            prompt = prompt_data["prompt"]
            prompt_args = apply_overrides(args, prompt_data)
            logger.info(f"Processing prompt: {prompt}")

            try:
                vae, latent = generate(prompt_args, gen_settings, shared_models)

                # Save latent if needed
                if args.output_type == "latent" or args.output_type == "both" or args.output_type == "latent_images":
                    height, width = latent.shape[-2], latent.shape[-1]  # BCTHW
                    height *= 8
                    width *= 8
                    save_latent(latent, prompt_args, height, width)

                all_latents.append(latent)
                all_prompt_args.append(prompt_args)
            except Exception as e:
                # best-effort batch: a failed prompt is logged and skipped
                logger.error(f"Error processing prompt: {prompt}. Error: {e}")
                continue

    # 4. Free models (keep only the VAE for decoding)
    if "model" in shared_models:
        del shared_models["model"]
    del shared_models["tokenizer1"]
    del shared_models["text_encoder1"]
    del shared_models["tokenizer2"]
    del shared_models["text_encoder2"]
    del shared_models["feature_extractor"]
    del shared_models["image_encoder"]

    clean_memory_on_device(device)
    synchronize_device(device)

    # 5. Decode latents if needed
    if args.output_type != "latent":
        logger.info("Decoding latents to videos/images")
        # NOTE(review): `vae` comes from the last successful generate() call; if every
        # prompt failed, it is unbound here and this raises NameError — confirm
        vae.to(device)

        for i, (latent, prompt_args) in enumerate(zip(all_latents, all_prompt_args)):
            logger.info(f"Decoding output {i+1}/{len(all_latents)}")

            # avoid saving latents again (ugly hack)
            if prompt_args.output_type == "both":
                prompt_args.output_type = "video"
            elif prompt_args.output_type == "latent_images":
                prompt_args.output_type = "images"

            save_output(prompt_args, vae, latent[0], device)
1565
+
1566
+
1567
def process_interactive(args: argparse.Namespace) -> None:
    """Process prompts in interactive mode

    Reads prompt lines from stdin until EOF; Ctrl+C interrupts the current
    generation but keeps the session alive.

    Args:
        args: Base command line arguments
    """
    gen_settings = get_generation_settings(args)
    device = gen_settings.device
    shared_models = load_shared_models(args)

    print("Interactive mode. Enter prompts (Ctrl+D or Ctrl+Z (Windows) to exit):")

    try:
        while True:
            try:
                line = input("> ")
                if not line.strip():
                    continue

                # parse the prompt line and overlay its overrides on the base args
                prompt_args = apply_overrides(args, parse_prompt_line(line))

                # generate, then save latent and/or video
                vae, latent = generate(prompt_args, gen_settings, shared_models)
                save_output(prompt_args, vae, latent[0], device)

            except KeyboardInterrupt:
                print("\nInterrupted. Continue (Ctrl+D or Ctrl+Z (Windows) to exit)")

    except EOFError:
        print("\nExiting interactive mode")
1602
+
1603
+
1604
def get_generation_settings(args: argparse.Namespace) -> GenerationSettings:
    """Build GenerationSettings (device and DiT weight dtype) from CLI args.

    Args:
        args: command line arguments (uses args.device, args.fp8_scaled, args.fp8)

    Returns:
        GenerationSettings: resolved device and DiT weight dtype
    """
    device = torch.device(args.device)

    dit_weight_dtype = None  # default
    if args.fp8_scaled:
        dit_weight_dtype = None  # various precision weights, so don't cast to specific dtype
    elif args.fp8:
        dit_weight_dtype = torch.float8_e4m3fn

    # fix: log message contained a doubled word ("weight weight")
    logger.info(f"Using device: {device}, DiT weight precision: {dit_weight_dtype}")

    gen_settings = GenerationSettings(device=device, dit_weight_dtype=dit_weight_dtype)
    return gen_settings
1617
+
1618
+
1619
def main():
    """Entry point: dispatches to latent-decode, batch, interactive, or single-prompt mode."""
    # Parse arguments
    args = parse_args()

    # Check if latents are provided
    latents_mode = args.latent_path is not None and len(args.latent_path) > 0

    # Set device
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    logger.info(f"Using device: {device}")
    args.device = device

    if latents_mode:
        # Original latent decode mode: load previously saved latents and decode them
        original_base_names = []
        latents_list = []
        seeds = []

        # assert len(args.latent_path) == 1, "Only one latent path is supported for now"

        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                # restore generation parameters from safetensors metadata when present
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                if "seeds" in metadata:
                    seed = int(metadata["seeds"])
                if "height" in metadata and "width" in metadata:
                    height = int(metadata["height"])
                    width = int(metadata["width"])
                    args.video_size = [height, width]
                if "video_seconds" in metadata:
                    args.video_seconds = float(metadata["video_seconds"])

            seeds.append(seed)
            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")

            if latents.ndim == 5:  # [BCTHW]
                latents = latents.squeeze(0)  # [CTHW]

            latents_list.append(latents)

        # latent = torch.stack(latents_list, dim=0)  # [N, ...], must be same shape

        for i, latent in enumerate(latents_list):
            args.seed = seeds[i]

            # NOTE(review): the VAE is re-loaded on every iteration — presumably
            # acceptable for a few latents, but could be hoisted; confirm
            vae = load_vae(args.vae, args.vae_chunk_size, args.vae_spatial_tile_sample_min_size, device)
            save_output(args, vae, latent, device, original_base_names)

    elif args.from_file:
        # Batch mode from file

        # Read prompts from file
        with open(args.from_file, "r", encoding="utf-8") as f:
            prompt_lines = f.readlines()

        # Process prompts
        prompts_data = preprocess_prompts_for_batch(prompt_lines, args)
        process_batch_prompts(prompts_data, args)

    elif args.interactive:
        # Interactive mode
        process_interactive(args)

    else:
        # Single prompt mode (original behavior)

        # Generate latent
        gen_settings = get_generation_settings(args)
        vae, latent = generate(args, gen_settings)
        # print(f"Generated latent shape: {latent.shape}")
        if args.save_merged_model:
            return

        # Save latent and video
        save_output(args, vae, latent[0], device)

    logger.info("Done!")


if __name__ == "__main__":
    main()
frame_pack/__init__.py ADDED
File without changes
frame_pack/bucket_tools.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Predefined (height, width) bucket sizes, keyed by base resolution.
bucket_options = {
    640: [
        (416, 960),
        (448, 864),
        (480, 832),
        (512, 768),
        (544, 704),
        (576, 672),
        (608, 640),
        (640, 608),
        (672, 576),
        (704, 544),
        (768, 512),
        (832, 480),
        (864, 448),
        (960, 416),
    ],
}


def find_nearest_bucket(h, w, resolution=640):
    """Return the bucket (height, width) whose aspect ratio best matches (h, w).

    The score |h * bucket_w - w * bucket_h| is zero when aspect ratios match
    exactly. Ties are resolved in favor of the bucket listed last, matching
    the original `<=` scan; iterating the reversed list and taking min()
    reproduces that behavior.
    """
    candidates = bucket_options[resolution]
    return min(
        reversed(candidates),
        key=lambda hw: abs(h * hw[1] - w * hw[0]),
    )
+
frame_pack/clip_vision.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
def hf_clip_vision_encode(image, feature_extractor, image_encoder):
    """Encode one RGB uint8 image (H, W, 3) with a HF vision encoder.

    Returns the raw model output object from ``image_encoder``.
    """
    assert isinstance(image, np.ndarray)
    assert image.ndim == 3 and image.shape[2] == 3
    assert image.dtype == np.uint8

    inputs = feature_extractor.preprocess(images=image, return_tensors="pt")
    inputs = inputs.to(device=image_encoder.device, dtype=image_encoder.dtype)
    return image_encoder(**inputs)
frame_pack/framepack_utils.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from types import SimpleNamespace
4
+ from typing import Optional, Union
5
+
6
+ import accelerate
7
+ from accelerate import Accelerator, init_empty_weights
8
+ import torch
9
+ from safetensors.torch import load_file
10
+ from transformers import (
11
+ LlamaTokenizerFast,
12
+ LlamaConfig,
13
+ LlamaModel,
14
+ CLIPTokenizer,
15
+ CLIPTextModel,
16
+ CLIPConfig,
17
+ SiglipImageProcessor,
18
+ SiglipVisionModel,
19
+ SiglipVisionConfig,
20
+ )
21
+
22
+ from utils.safetensors_utils import load_split_weights
23
+ from hunyuan_model.vae import load_vae as hunyuan_load_vae
24
+
25
+ import logging
26
+
27
+ logger = logging.getLogger(__name__)
28
+ logging.basicConfig(level=logging.INFO)
29
+
30
+
31
def load_vae(
    vae_path: str, vae_chunk_size: Optional[int], vae_spatial_tile_sample_min_size: Optional[int], device: Union[str, torch.device]
):
    """Load the HunyuanVideo causal-3D VAE in fp16 and configure chunking/tiling.

    Args:
        vae_path: a single safetensors file, or a diffusers-style directory
            containing ``vae/diffusion_pytorch_model.safetensors``.
        vae_chunk_size: if set, chunk size applied recursively to CausalConv3d.
        vae_spatial_tile_sample_min_size: if set, enables spatial tiling with
            this sample-space tile size (latent tile size is 1/8 of it).
            When None, spatial tiling is still enabled with default sizes.
        device: target device for the VAE.

    Returns:
        The VAE in eval mode.
    """
    # single file and directory (contains 'vae') support
    # (fixed: removed a no-op `else: vae_path = vae_path` branch)
    if os.path.isdir(vae_path):
        vae_path = os.path.join(vae_path, "vae", "diffusion_pytorch_model.safetensors")

    vae_dtype = torch.float16  # if vae_dtype is None else str_to_dtype(vae_dtype)
    vae, _, s_ratio, t_ratio = hunyuan_load_vae(vae_dtype=vae_dtype, device=device, vae_path=vae_path)
    vae.eval()
    # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

    # set chunk_size to CausalConv3d recursively
    if vae_chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(vae_chunk_size)
        logger.info(f"Set chunk_size to {vae_chunk_size} for CausalConv3d")

    if vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = vae_spatial_tile_sample_min_size
        # latent resolution is 1/8 of the sample resolution
        vae.tile_latent_min_size = vae_spatial_tile_sample_min_size // 8
        logger.info(f"Enabled spatial tiling with min size {vae_spatial_tile_sample_min_size}")
    else:
        # always tile spatially to bound memory use, with the VAE's defaults
        vae.enable_spatial_tiling(True)

    return vae
61
+
62
+
63
+ # region Text Encoders
64
+
65
+ # Text Encoder configs are copied from HunyuanVideo repo
66
+
67
# Llama 3 8B configuration for text encoder 1, matching the HunyuanVideo
# release; used to rebuild the model when loading from a bare weights file
# (no config.json alongside it).
LLAMA_CONFIG = {
    "architectures": ["LlamaModel"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}

# CLIP-L text-model configuration for text encoder 2, likewise used when
# loading from a bare weights file.
CLIP_CONFIG = {
    # "_name_or_path": "/raid/aryan/llava-llama-3-8b-v1_1-extracted/text_encoder_2",
    "architectures": ["CLIPTextModel"],
    "attention_dropout": 0.0,
    "bos_token_id": 0,
    "dropout": 0.0,
    "eos_token_id": 2,
    "hidden_act": "quick_gelu",
    "hidden_size": 768,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 3072,
    "layer_norm_eps": 1e-05,
    "max_position_embeddings": 77,
    "model_type": "clip_text_model",
    "num_attention_heads": 12,
    "num_hidden_layers": 12,
    "pad_token_id": 1,
    "projection_dim": 768,
    "torch_dtype": "float16",
    "transformers_version": "4.48.0.dev0",
    "vocab_size": 49408,
}
+
119
+
120
def load_text_encoder1(
    args, fp8_llm: Optional[bool] = False, device: Optional[Union[str, torch.device]] = None
) -> tuple[LlamaTokenizerFast, LlamaModel]:
    """Load the Llama tokenizer and text encoder 1.

    Args:
        args: parsed CLI args; ``args.text_encoder1`` is either a
            diffusers-style directory or a (possibly split) safetensors file.
        fp8_llm: when True, cast the encoder to float8_e4m3fn on ``device``
            and patch Embedding/RMSNorm layers to run in the original dtype.
        device: target device for the encoder.

    Returns:
        (tokenizer, text_encoder) with the encoder in eval mode.
    """
    # single file, split file and directory (contains 'text_encoder') support
    logger.info(f"Loading text encoder 1 tokenizer")
    # tokenizer always comes from the HF hub reference repo
    tokenizer1 = LlamaTokenizerFast.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer")

    logger.info(f"Loading text encoder 1 from {args.text_encoder1}")
    if os.path.isdir(args.text_encoder1):
        # load from directory, configs are in the directory
        text_encoder1 = LlamaModel.from_pretrained(args.text_encoder1, subfolder="text_encoder", torch_dtype=torch.float16)
    else:
        # load from file, we create the model with the appropriate config
        config = LlamaConfig(**LLAMA_CONFIG)
        with init_empty_weights():
            text_encoder1 = LlamaModel._from_config(config, torch_dtype=torch.float16)

        state_dict = load_split_weights(args.text_encoder1)

        # support weights from ComfyUI: strip the "model." prefix and drop
        # non-encoder entries (tokenizer blob, LM head)
        if "model.embed_tokens.weight" in state_dict:
            for key in list(state_dict.keys()):
                if key.startswith("model."):
                    new_key = key.replace("model.", "")
                    state_dict[new_key] = state_dict[key]
                    del state_dict[key]
        if "tokenizer" in state_dict:
            state_dict.pop("tokenizer")
        if "lm_head.weight" in state_dict:
            state_dict.pop("lm_head.weight")

        text_encoder1.load_state_dict(state_dict, strict=True, assign=True)

    if fp8_llm:
        org_dtype = text_encoder1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder1.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8: fp8 tensors cannot flow through Embedding or
        # RMSNorm directly, so keep Embedding in the original dtype and
        # replace each LlamaRMSNorm.forward with an fp32 re-implementation
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                def forward(hidden_states):
                    # RMSNorm computed in fp32, weight cast to activation dtype
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder1, org_dtype)
    else:
        text_encoder1.to(device)

    text_encoder1.eval()
    return tokenizer1, text_encoder1
188
+
189
+
190
def load_text_encoder2(args) -> tuple[CLIPTokenizer, CLIPTextModel]:
    """Load the CLIP-L tokenizer (from the HF hub) and text encoder 2.

    ``args.text_encoder2`` may be a diffusers-style directory (containing
    'text_encoder_2') or a single safetensors weights file.
    """
    logger.info("Loading text encoder 2 tokenizer")
    tokenizer2 = CLIPTokenizer.from_pretrained("hunyuanvideo-community/HunyuanVideo", subfolder="tokenizer_2")

    logger.info(f"Loading text encoder 2 from {args.text_encoder2}")
    if os.path.isdir(args.text_encoder2):
        # directory layout: configs live alongside the weights
        text_encoder2 = CLIPTextModel.from_pretrained(args.text_encoder2, subfolder="text_encoder_2", torch_dtype=torch.float16)
    else:
        # single weights file: build the model shell from the known config,
        # then load the weights into it
        config = CLIPConfig(**CLIP_CONFIG)
        with init_empty_weights():
            text_encoder2 = CLIPTextModel._from_config(config, torch_dtype=torch.float16)
        text_encoder2.load_state_dict(load_file(args.text_encoder2), strict=True, assign=True)

    text_encoder2.eval()
    return tokenizer2, text_encoder2
211
+
212
+
213
+ # endregion
214
+
215
+ # region image encoder
216
+
217
+ # Siglip configs are copied from FramePack repo
218
# SigLIP image-processor settings (384x384, [-1, 1] normalization), copied
# from the FramePack repo so no hub download is needed for the processor.
FEATURE_EXTRACTOR_CONFIG = {
    "do_convert_rgb": None,
    "do_normalize": True,
    "do_rescale": True,
    "do_resize": True,
    "image_mean": [0.5, 0.5, 0.5],
    "image_processor_type": "SiglipImageProcessor",
    "image_std": [0.5, 0.5, 0.5],
    "processor_class": "SiglipProcessor",
    "resample": 3,
    "rescale_factor": 0.00392156862745098,
    "size": {"height": 384, "width": 384},
}
# SigLIP vision-model configuration (FLUX.1-Redux image encoder), used to
# rebuild the model when loading from a bare weights file.
IMAGE_ENCODER_CONFIG = {
    "_name_or_path": "/home/lvmin/.cache/huggingface/hub/models--black-forest-labs--FLUX.1-Redux-dev/snapshots/1282f955f706b5240161278f2ef261d2a29ad649/image_encoder",
    "architectures": ["SiglipVisionModel"],
    "attention_dropout": 0.0,
    "hidden_act": "gelu_pytorch_tanh",
    "hidden_size": 1152,
    "image_size": 384,
    "intermediate_size": 4304,
    "layer_norm_eps": 1e-06,
    "model_type": "siglip_vision_model",
    "num_attention_heads": 16,
    "num_channels": 3,
    "num_hidden_layers": 27,
    "patch_size": 14,
    "torch_dtype": "bfloat16",
    "transformers_version": "4.46.2",
}
248
+
249
+
250
def load_image_encoders(args):
    """Load the SigLIP feature extractor (from a fixed config) and vision model.

    ``args.image_encoder`` may be a diffusers-style directory (containing
    'image_encoder') or a single safetensors weights file.
    """
    logger.info("Loading image encoder feature extractor")
    feature_extractor = SiglipImageProcessor(**FEATURE_EXTRACTOR_CONFIG)

    # single file, split file and directory (contains 'image_encoder') support
    logger.info(f"Loading image encoder from {args.image_encoder}")
    if os.path.isdir(args.image_encoder):
        # directory layout: configs live alongside the weights
        image_encoder = SiglipVisionModel.from_pretrained(args.image_encoder, subfolder="image_encoder", torch_dtype=torch.float16)
    else:
        # single weights file: build an empty model from the known config,
        # then load the weights into it
        config = SiglipVisionConfig(**IMAGE_ENCODER_CONFIG)
        with init_empty_weights():
            image_encoder = SiglipVisionModel._from_config(config, torch_dtype=torch.float16)
        image_encoder.load_state_dict(load_file(args.image_encoder), strict=True, assign=True)

    image_encoder.eval()
    return feature_extractor, image_encoder
271
+
272
+
273
+ # endregion
frame_pack/hunyuan.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import torch
5
+
6
+ # from diffusers.pipelines.hunyuan_video.pipeline_hunyuan_video import DEFAULT_PROMPT_TEMPLATE
7
+ # from diffusers_helper.utils import crop_or_pad_yield_mask
8
+ from hunyuan_model.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
9
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
10
+
11
+
12
@torch.no_grad()
def encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2, max_length=256, custom_system_prompt=None):
    """Encode one prompt with both HunyuanVideo text encoders.

    Args:
        prompt: a single prompt string.
        text_encoder / tokenizer: Llama encoder and its tokenizer.
        text_encoder_2 / tokenizer_2: CLIP-L encoder and its tokenizer.
        max_length: maximum number of prompt tokens kept after the system
            prompt (Llama input is padded/truncated to max_length + crop_start).
        custom_system_prompt: optional replacement for the default video
            system prompt; its token count becomes the crop offset.

    Returns:
        (llama_vec, clip_l_pooler): Llama hidden states (3rd-from-last layer)
        cropped to the user-prompt tokens, and the CLIP pooled embedding.
    """
    assert isinstance(prompt, str)

    prompt = [prompt]

    # LLAMA

    # We can verify crop_start by checking the token count of the prompt:
    # custom_system_prompt = (
    #     "Describe the video by detailing the following aspects: "
    #     "1. The main content and theme of the video."
    #     "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    #     "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    #     "4. background environment, light, style and atmosphere."
    #     "5. camera angles, movements, and transitions used in the video:"
    # )
    if custom_system_prompt is None:
        prompt_llama = [PROMPT_TEMPLATE["dit-llm-encode-video"]["template"].format(p) for p in prompt]
        crop_start = PROMPT_TEMPLATE["dit-llm-encode-video"]["crop_start"]
    else:
        # count tokens for custom_system_prompt to find where the user text starts
        full_prompt = f"<|start_header_id|>system<|end_header_id|>\n\n{custom_system_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        print(f"Custom system prompt: {full_prompt}")
        system_prompt_tokens = tokenizer(full_prompt, return_tensors="pt", truncation=True).input_ids[0].shape[0]
        print(f"Custom system prompt token count: {system_prompt_tokens}")
        prompt_llama = [full_prompt + p + "<|eot_id|>" for p in prompt]
        crop_start = system_prompt_tokens

    llama_inputs = tokenizer(
        prompt_llama,
        padding="max_length",
        max_length=max_length + crop_start,
        truncation=True,
        return_tensors="pt",
        return_length=False,
        return_overflowing_tokens=False,
        return_attention_mask=True,
    )

    llama_input_ids = llama_inputs.input_ids.to(text_encoder.device)
    llama_attention_mask = llama_inputs.attention_mask.to(text_encoder.device)
    llama_attention_length = int(llama_attention_mask.sum())

    llama_outputs = text_encoder(
        input_ids=llama_input_ids,
        attention_mask=llama_attention_mask,
        output_hidden_states=True,
    )

    # take the 3rd-from-last hidden layer and crop off the system-prompt tokens
    llama_vec = llama_outputs.hidden_states[-3][:, crop_start:llama_attention_length]
    # llama_vec_remaining = llama_outputs.hidden_states[-3][:, llama_attention_length:]
    llama_attention_mask = llama_attention_mask[:, crop_start:llama_attention_length]

    # after cropping, every remaining position must be a real (unmasked) token
    assert torch.all(llama_attention_mask.bool())

    # CLIP

    clip_l_input_ids = tokenizer_2(
        prompt,
        padding="max_length",
        max_length=77,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    clip_l_pooler = text_encoder_2(clip_l_input_ids.to(text_encoder_2.device), output_hidden_states=False).pooler_output

    return llama_vec, clip_l_pooler
82
+
83
+
84
+ @torch.no_grad()
85
+ def vae_decode_fake(latents):
86
+ latent_rgb_factors = [
87
+ [-0.0395, -0.0331, 0.0445],
88
+ [0.0696, 0.0795, 0.0518],
89
+ [0.0135, -0.0945, -0.0282],
90
+ [0.0108, -0.0250, -0.0765],
91
+ [-0.0209, 0.0032, 0.0224],
92
+ [-0.0804, -0.0254, -0.0639],
93
+ [-0.0991, 0.0271, -0.0669],
94
+ [-0.0646, -0.0422, -0.0400],
95
+ [-0.0696, -0.0595, -0.0894],
96
+ [-0.0799, -0.0208, -0.0375],
97
+ [0.1166, 0.1627, 0.0962],
98
+ [0.1165, 0.0432, 0.0407],
99
+ [-0.2315, -0.1920, -0.1355],
100
+ [-0.0270, 0.0401, -0.0821],
101
+ [-0.0616, -0.0997, -0.0727],
102
+ [0.0249, -0.0469, -0.1703],
103
+ ] # From comfyui
104
+
105
+ latent_rgb_factors_bias = [0.0259, -0.0192, -0.0761]
106
+
107
+ weight = torch.tensor(latent_rgb_factors, device=latents.device, dtype=latents.dtype).transpose(0, 1)[:, :, None, None, None]
108
+ bias = torch.tensor(latent_rgb_factors_bias, device=latents.device, dtype=latents.dtype)
109
+
110
+ images = torch.nn.functional.conv3d(latents, weight, bias=bias, stride=1, padding=0, dilation=1, groups=1)
111
+ images = images.clamp(0.0, 1.0)
112
+
113
+ return images
114
+
115
+
116
@torch.no_grad()
def vae_decode(latents, vae, image_mode=False) -> torch.Tensor:
    """Decode latents to pixels with the VAE.

    When ``image_mode`` is True, frames are decoded one at a time along the
    temporal axis (dim 2) and re-concatenated, trading speed for memory.
    """
    scaled = latents / vae.config.scaling_factor

    if image_mode:
        frames = scaled.to(device=vae.device, dtype=vae.dtype).unbind(2)
        decoded = [vae.decode(frame.unsqueeze(2)).sample for frame in frames]
        return torch.cat(decoded, dim=2)

    return vae.decode(scaled.to(device=vae.device, dtype=vae.dtype)).sample
128
+
129
+
130
@torch.no_grad()
def vae_encode(image, vae: AutoencoderKLCausal3D) -> torch.Tensor:
    """Encode pixels to latents (sampled from the posterior) and scale them."""
    posterior = vae.encode(image.to(device=vae.device, dtype=vae.dtype))
    return posterior.latent_dist.sample() * vae.config.scaling_factor
frame_pack/hunyuan_video_packed.py ADDED
@@ -0,0 +1,2015 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import glob
5
+ import math
6
+ import numbers
7
+ import os
8
+ from types import SimpleNamespace
9
+ from typing import Any, Dict, List, Optional, Tuple, Union
10
+
11
+ import torch
12
+ import einops
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+
17
+ from modules.custom_offloading_utils import ModelOffloader
18
+ from utils.safetensors_utils import load_split_weights
19
+ from modules.fp8_optimization_utils import apply_fp8_monkey_patch, optimize_state_dict_with_fp8
20
+ from accelerate import init_empty_weights
21
+
22
# Optional attention backends: each is probed independently and replaced with
# None when unavailable, so callers can feature-test the module-level names.
# Fixed: bare `except:` would also swallow KeyboardInterrupt/SystemExit; the
# handlers now catch Exception only.
try:
    # raise NotImplementedError
    from xformers.ops import memory_efficient_attention as xformers_attn_func

    print("Xformers is installed!")
except Exception:
    print("Xformers is not installed!")
    xformers_attn_func = None

try:
    # raise NotImplementedError
    from flash_attn import flash_attn_varlen_func, flash_attn_func

    print("Flash Attn is installed!")
except Exception:
    print("Flash Attn is not installed!")
    flash_attn_varlen_func = None
    flash_attn_func = None

try:
    # raise NotImplementedError
    from sageattention import sageattn_varlen, sageattn

    print("Sage Attn is installed!")
except Exception:
    print("Sage Attn is not installed!")
    sageattn_varlen = None
    sageattn = None
50
+
51
+
52
+ import logging
53
+
54
+ logger = logging.getLogger(__name__)
55
+ logging.basicConfig(level=logging.INFO)
56
+
57
+ # region diffusers
58
+
59
+ # copied from diffusers with some modifications to minimize dependencies
60
+ # original code: https://github.com/huggingface/diffusers/
61
+ # original license: Apache-2.0
62
+
63
# Lower-case activation name -> nn.Module class.
ACT2CLS = {
    "swish": nn.SiLU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "gelu": nn.GELU,
    "relu": nn.ReLU,
}


def get_activation(act_fn: str) -> nn.Module:
    """Instantiate the activation module named by ``act_fn`` (case-insensitive).

    Raises:
        ValueError: if the name is not present in ``ACT2CLS``.
    """
    act_fn = act_fn.lower()
    try:
        return ACT2CLS[act_fn]()
    except KeyError:
        raise ValueError(f"activation function {act_fn} not found in ACT2FN mapping {list(ACT2CLS.keys())}") from None
87
+
88
+
89
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """Create DDPM-style sinusoidal timestep embeddings.

    Args:
        timesteps: 1-D tensor of N (possibly fractional) timestep indices.
        embedding_dim: output embedding width.
        flip_sin_to_cos: emit `cos, sin` order instead of `sin, cos`.
        downscale_freq_shift: delta between frequencies across dimensions.
        scale: multiplier applied to the raw phase values.
        max_period: controls the lowest embedding frequency.

    Returns:
        An (N, embedding_dim) tensor of positional embeddings; zero-padded on
        the right when embedding_dim is odd.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2

    # log-spaced frequencies (computed exactly as the reference implementation)
    exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device)
    exponent = exponent / (half_dim - downscale_freq_shift)
    freqs = torch.exp(exponent)

    # outer product of timesteps and frequencies, then scaled
    phases = scale * (timesteps[:, None].float() * freqs[None, :])

    # sine block first, cosine block second
    emb = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)

    if flip_sin_to_cos:
        # swap to cosine-first ordering
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    if embedding_dim % 2 == 1:
        # odd width: pad one zero column on the right
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
139
+
140
+
141
class TimestepEmbedding(nn.Module):
    """Two-layer MLP mapping a timestep embedding into the model's embed dim.

    Optionally adds a (bias-free) projection of an extra conditioning vector
    to the input first, and optionally applies a second activation to the
    output.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: int = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)
        # optional projection for an extra conditioning signal (no bias)
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim is not None else None
        self.act = get_activation(act_fn)
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim if out_dim is None else out_dim, sample_proj_bias)
        self.post_act = None if post_act_fn is None else get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)

        out = self.linear_1(sample)
        if self.act is not None:
            out = self.act(out)
        out = self.linear_2(out)

        return out if self.post_act is None else self.post_act(out)
187
+
188
+
189
class Timesteps(nn.Module):
    """Module wrapper around ``get_timestep_embedding`` with fixed settings."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        return get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
206
+
207
+
208
class FP32SiLU(nn.Module):
    r"""SiLU evaluated in float32; the result is cast back to the input dtype."""

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        out = F.silu(inputs.float(), inplace=False)
        return out.to(inputs.dtype)
218
+
219
+
220
class GELU(nn.Module):
    r"""Linear projection followed by GELU (tanh approximation optional).

    Parameters:
        dim_in (`int`): number of input channels.
        dim_out (`int`): number of output channels.
        approximate (`str`, defaults to `"none"`): `"tanh"` selects the tanh
            approximation of GELU.
        bias (`bool`, defaults to True): whether the linear layer has a bias.
    """

    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.approximate = approximate

    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
        # (an old mps fp16 workaround used to live here; plain gelu suffices now)
        return F.gelu(gate, approximate=self.approximate)

    def forward(self, hidden_states):
        return self.gelu(self.proj(hidden_states))
246
+
247
+
248
class PixArtAlphaTextProjection(nn.Module):
    """Projects caption embeddings through linear -> activation -> linear.

    Also handles dropout for classifier-free guidance. Adapted from
    https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"):
        super().__init__()
        out_features = hidden_size if out_features is None else out_features
        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
        if act_fn == "gelu_tanh":
            self.act_1 = nn.GELU(approximate="tanh")
        elif act_fn == "silu":
            self.act_1 = nn.SiLU()
        elif act_fn == "silu_fp32":
            self.act_1 = FP32SiLU()
        else:
            raise ValueError(f"Unknown activation function: {act_fn}")
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
275
+
276
+
277
+ class LayerNormFramePack(nn.LayerNorm):
278
+ # casting to dtype of input tensor is added
279
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
280
+ return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps).to(x)
281
+
282
+
283
class FP32LayerNormFramePack(nn.LayerNorm):
    # LayerNorm evaluated in float32 regardless of input dtype; the output is
    # cast back to the original dtype.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = None if self.weight is None else self.weight.float()
        b = None if self.bias is None else self.bias.float()
        y = torch.nn.functional.layer_norm(x.float(), self.normalized_shape, w, b, self.eps)
        return y.to(x.dtype)
293
+
294
+
295
class RMSNormFramePack(nn.Module):
    r"""Root-mean-square normalization (https://arxiv.org/abs/1910.07467).

    Args:
        dim (`int`): size of the normalized (last) dimension, used for the
            affine parameters. Only effective when `elementwise_affine` is True.
        eps (`float`): numerical stabilizer added to the mean square.
        elementwise_affine (`bool`, defaults to `True`): learn a per-channel
            scale when True.
        bias (`bool`, defaults to False): also allocate a per-channel bias
            parameter (kept for interface parity; the forward pass does not
            apply it, matching the reference implementation).
    """

    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
        super().__init__()

        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if isinstance(dim, numbers.Integral):
            dim = (dim,)
        self.dim = torch.Size(dim)

        self.weight = None
        self.bias = None
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            if bias:
                self.bias = nn.Parameter(torch.zeros(dim))

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        # mean square computed in fp32 for stability
        mean_sq = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        normed = hidden_states * torch.rsqrt(mean_sq + self.eps)

        if self.weight is None:
            return normed.to(orig_dtype)

        return normed.to(orig_dtype) * self.weight.to(orig_dtype)
335
+
336
+
337
class AdaLayerNormContinuousFramePack(nn.Module):
    r"""Adaptive normalization: a conditioning embedding is projected to a
    per-channel (scale, shift) pair that modulates the normalized input.

    Args:
        embedding_dim (`int`): embedding dimension used during projection.
        conditioning_embedding_dim (`int`): dimension of the input condition.
        elementwise_affine (`bool`, defaults to `True`): whether the inner norm
            layer itself carries affine parameters. NOTE: since the output is
            immediately scaled and shifted by the projected conditioning, you
            likely want this False; True is kept to match the original code.
        eps (`float`, defaults to 1e-5): epsilon factor.
        bias (`bool`, defaults to `True`): whether projections use a bias.
        norm_type (`str`, defaults to `"layer_norm"`): "layer_norm" or "rms_norm".
    """

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type == "layer_norm":
            self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)
        elif norm_type == "rms_norm":
            self.norm = RMSNormFramePack(embedding_dim, eps, elementwise_affine)
        else:
            raise ValueError(f"unknown norm_type {norm_type}")

    def forward(self, x, conditioning_embedding):
        # project conditioning to per-channel scale/shift (chunked in that order)
        scale, shift = self.linear(self.silu(conditioning_embedding)).chunk(2, dim=1)
        return self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
381
+
382
+
383
class LinearActivation(nn.Module):
    """Linear projection immediately followed by an activation function."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = True, activation: str = "silu"):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
        self.activation = get_activation(activation)

    def forward(self, hidden_states):
        return self.activation(self.proj(hidden_states))
393
+
394
+
395
class FeedForward(nn.Module):
    r"""
    Transformer feed-forward block.

    Parameters:
        dim (`int`): Number of channels in the input.
        dim_out (`int`, *optional*): Number of channels in the output; defaults to `dim`.
        mult (`int`, *optional*, defaults to 4): Hidden-dim multiplier when `inner_dim` is not given.
        dropout (`float`, *optional*, defaults to 0.0): Dropout probability.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation variant. Only
            "gelu-approximate" and "linear-silu" are kept in this trimmed copy of diffusers'
            FeedForward; other variants raise.
        final_dropout (`bool`, *optional*, defaults to False): Apply a trailing dropout.
        inner_dim (`int`, *optional*): Explicit hidden dimension; overrides `mult`.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layers.
    """

    def __init__(
        self,
        dim: int,
        dim_out: Optional[int] = None,
        mult: int = 4,
        dropout: float = 0.0,
        activation_fn: str = "geglu",
        final_dropout: bool = False,
        inner_dim=None,
        bias: bool = True,
    ):
        super().__init__()
        inner_dim = int(dim * mult) if inner_dim is None else inner_dim
        dim_out = dim if dim_out is None else dim_out

        if activation_fn == "gelu-approximate":
            act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
        elif activation_fn == "linear-silu":
            act_fn = LinearActivation(dim, inner_dim, bias=bias, activation="silu")
        else:
            raise ValueError(f"Unknown activation function: {activation_fn}")

        # project in -> dropout -> project out (module order matches diffusers' net.0/1/2).
        layers = [act_fn, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out, bias=bias)]
        # FF as used in Vision Transformer, MLP-Mixer, etc. may have a final dropout.
        if final_dropout:
            layers.append(nn.Dropout(dropout))
        self.net = nn.ModuleList(layers)

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # The deprecated `scale` argument is rejected outright.
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            raise ValueError("scale is not supported in this version. Please remove it.")
        for layer in self.net:
            hidden_states = layer(hidden_states)
        return hidden_states
459
+
460
+
461
+ # @maybe_allow_in_graph
462
# @maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    Minimal copy of Attention class from diffusers.

    Builds q/k/v projections (plus optional "added" projections for a second
    token stream via `added_kv_proj_dim`), optional per-head RMS q/k
    normalization, and delegates the actual attention computation to a
    processor object (AttnProcessor2_0 by default).
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        bias: bool = False,
        qk_norm: Optional[str] = None,
        added_kv_proj_dim: Optional[int] = None,
        eps: float = 1e-5,
        processor: Optional[any] = None,
        out_dim: int = None,
        context_pre_only=None,
        pre_only=False,
    ):
        super().__init__()
        # Width of the q/k/v projections; `out_dim` overrides heads * dim_head.
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.inner_kv_dim = self.inner_dim  # if kv_heads is None else dim_head * kv_heads
        self.query_dim = query_dim
        self.use_bias = bias
        # Falls back to self-attention when no cross-attention dim is given.
        self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.out_context_dim = query_dim
        self.context_pre_only = context_pre_only
        self.pre_only = pre_only

        self.scale = dim_head**-0.5
        self.heads = out_dim // dim_head if out_dim is not None else heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if qk_norm is None:
            self.norm_q = None
            self.norm_k = None
        elif qk_norm == "rms_norm":
            self.norm_q = RMSNormFramePack(dim_head, eps=eps)
            self.norm_k = RMSNormFramePack(dim_head, eps=eps)
        else:
            raise ValueError(
                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
            )

        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
        self.to_k = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)
        self.to_v = nn.Linear(self.cross_attention_dim, self.inner_kv_dim, bias=bias)

        self.added_proj_bias = True  # added_proj_bias
        # Optional projections for a second (context/text) token stream.
        if self.added_kv_proj_dim is not None:
            self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_kv_dim, bias=True)
            if self.context_pre_only is not None:
                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
        else:
            self.add_q_proj = None
            self.add_k_proj = None
            self.add_v_proj = None

        if not self.pre_only:
            self.to_out = nn.ModuleList([])
            self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=True))
            # self.to_out.append(nn.Dropout(dropout))
            self.to_out.append(nn.Identity())  # dropout=0.0
        else:
            self.to_out = None

        if self.context_pre_only is not None and not self.context_pre_only:
            self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=True)
        else:
            self.to_add_out = None

        if qk_norm is not None and added_kv_proj_dim is not None:
            if qk_norm == "rms_norm":
                self.norm_added_q = RMSNormFramePack(dim_head, eps=eps)
                self.norm_added_k = RMSNormFramePack(dim_head, eps=eps)
            else:
                raise ValueError(f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`")
        else:
            self.norm_added_q = None
            self.norm_added_k = None

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        if processor is None:
            processor = AttnProcessor2_0()
        self.set_processor(processor)

    def set_processor(self, processor: any) -> None:
        # Replace the callable that implements the attention computation.
        self.processor = processor

    def get_processor(self) -> any:
        return self.processor

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        # All heavy lifting is delegated to the configured processor.
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ) -> torch.Tensor:
        r"""
        Prepare the attention mask for the attention computation.

        Args:
            attention_mask (`torch.Tensor`):
                The attention mask to prepare.
            target_length (`int`):
                The target length of the attention mask. This is the length of the attention mask after padding.
            batch_size (`int`):
                The batch size, which is used to repeat the attention mask.
            out_dim (`int`, *optional*, defaults to `3`):
                The output dimension of the attention mask. Can be either `3` or `4`.

        Returns:
            `torch.Tensor`: The prepared attention mask.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            if attention_mask.device.type == "mps":
                # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                # Instead, we can manually construct the padding tensor.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.cat([attention_mask, padding], dim=2)
            else:
                # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
                # we want to instead pad by (0, remaining_length), where remaining_length is:
                # remaining_length: int = target_length - current_length
                # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            # Expand to one mask row per attention head: (B * heads, ..., L).
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1, output_size=attention_mask.shape[1] * head_size)

        return attention_mask
621
+
622
+
623
class AttnProcessor2_0:
    r"""
    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        # 4D (B, C, H, W) inputs are flattened to (B, H*W, C) and restored at the end.
        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

        query = attn.to_q(hidden_states)
        query_dtype = query.dtype  # store dtype before potentially deleting query

        # Self-attention when no separate context is provided.
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads

        # (B, L, inner) -> (B, heads, L, head_dim)
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False)
        del query, key, value, attention_mask  # free memory

        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query_dtype)  # use stored dtype

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states
693
+
694
+
695
+ # endregion diffusers
696
+
697
+
698
def pad_for_3d_conv(x, kernel_size):
    """Right-pad a (B, C, T, H, W) tensor so T/H/W become multiples of the kernel.

    Padding is applied only at the end of each spatial/temporal axis using
    replicate mode, so existing values are unchanged.
    """
    _, _, t, h, w = x.shape
    kt, kh, kw = kernel_size
    # `-n % k` equals the amount needed to round n up to a multiple of k.
    pad = (0, -w % kw, 0, -h % kh, 0, -t % kt)
    return torch.nn.functional.pad(x, pad, mode="replicate")
705
+
706
+
707
def center_down_sample_3d(x, kernel_size):
    """Downsample a (B, C, T, H, W) tensor by `kernel_size` via average pooling.

    Historically this picked the center element of each (pt, ph, pw) patch
    (see the einops-based implementation it replaced); average pooling is the
    current behavior.
    """
    return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
714
+
715
+
716
def get_cu_seqlens(text_mask, img_len):
    """Build cumulative sequence lengths for varlen attention kernels.

    Each batch item's packed sequence is [img_len image tokens, valid text
    tokens, padding] with total length text_mask.shape[1] + img_len. The
    returned int32 tensor has 2*B + 1 entries: for item i, entry 2i+1 marks
    the end of its valid tokens and entry 2i+2 the end of its padded slot.
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)  # ensure device match
    for i in range(batch_size):
        valid_end = i * max_len + text_len[i] + img_len
        padded_end = (i + 1) * max_len
        cu_seqlens[2 * i + 1] = valid_end
        cu_seqlens[2 * i + 2] = padded_end
    return cu_seqlens
731
+
732
+
733
def apply_rotary_emb_transposed(x, freqs_cis):
    """Apply rotary position embedding to x (..., seq, heads, dim).

    `freqs_cis` packs the cos table and the sin table along its last dim; a
    head axis is inserted via unsqueeze(-2) so one table broadcasts over all
    heads. The rotation is computed in float32 and cast back to x's dtype.
    """
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    # Split the last dim into (pairs, 2) and build the 90-degree-rotated partner.
    even, odd = x.unflatten(-1, (-1, 2)).unbind(-1)
    rotated = torch.stack([-odd, even], dim=-1).flatten(3)
    return (x.float() * cos + rotated.float() * sin).to(x.dtype)
740
+
741
+
742
def attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=None, split_attn=False):
    """Run attention with an automatically selected backend.

    q/k/v are in (batch, seq, heads, head_dim) layout. Three paths:
      1. No varlen metadata: one batched call to the selected backend
         (sageattn -> flash-attn -> xformers -> SDPA, in that order of
         preference when attn_mode is None).
      2. `split_attn`: the batch is processed one item at a time with a dense
         kernel. NOTE(review): cu_seqlens_* / max_seqlen_* are ignored on this
         path — it appears to assume per-item sequences without padding; confirm.
      3. Otherwise: batch is flattened into one packed sequence and handed to
         a varlen kernel (sageattn_varlen / flash_attn_varlen_func); raises if
         neither is installed.
    """
    if cu_seqlens_q is None and cu_seqlens_kv is None and max_seqlen_q is None and max_seqlen_kv is None:
        if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
            x = sageattn(q, k, v, tensor_layout="NHD")
            return x

        if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
            x = flash_attn_func(q, k, v)
            return x

        if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
            x = xformers_attn_func(q, k, v)
            return x

        # SDPA expects (batch, heads, seq, dim); transpose in and back out.
        x = torch.nn.functional.scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)).transpose(
            1, 2
        )
        return x
    if split_attn:
        # Per-item path: run a dense kernel on each batch element separately.
        if attn_mode == "sageattn" or attn_mode is None and sageattn is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = sageattn(q[i : i + 1], k[i : i + 1], v[i : i + 1], tensor_layout="NHD")
            return x

        if attn_mode == "flash" or attn_mode is None and flash_attn_func is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = flash_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
            return x

        if attn_mode == "xformers" or attn_mode is None and xformers_attn_func is not None:
            x = torch.empty_like(q)
            for i in range(q.size(0)):
                x[i : i + 1] = xformers_attn_func(q[i : i + 1], k[i : i + 1], v[i : i + 1])
            return x

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        x = torch.empty_like(q)
        for i in range(q.size(0)):
            x[i : i + 1] = torch.nn.functional.scaled_dot_product_attention(q[i : i + 1], k[i : i + 1], v[i : i + 1])
        x = x.transpose(1, 2)
        return x

    # Varlen path: flatten (B, L) into one packed dimension for the varlen kernels.
    batch_size = q.shape[0]
    q = q.view(q.shape[0] * q.shape[1], *q.shape[2:])
    k = k.view(k.shape[0] * k.shape[1], *k.shape[2:])
    v = v.view(v.shape[0] * v.shape[1], *v.shape[2:])
    if attn_mode == "sageattn" or attn_mode is None and sageattn_varlen is not None:
        x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v  # free memory
    elif attn_mode == "flash" or attn_mode is None and flash_attn_varlen_func is not None:
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v  # free memory
    else:
        raise NotImplementedError("No Attn Installed or batch_size > 1 is not supported in this configuration. Try `--split_attn`.")
    x = x.view(batch_size, max_seqlen_q, *x.shape[2:])
    return x
802
+
803
+
804
class HunyuanAttnProcessorFlashAttnDouble:
    """Attention processor for double-stream (image + text) transformer blocks.

    Image and text tokens use separate q/k/v projections; RoPE is applied to
    the image stream only; both streams are concatenated for joint attention
    and split again afterwards, each with its own output projection.

    Note: `attention_mask` here is NOT a dense mask but the 4-tuple
    (cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv) consumed by
    `attn_varlen_func`.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        image_rotary_emb,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask

        # Project image latents
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        del hidden_states  # free memory

        # (B, L, inner) -> (B, L, heads, head_dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # RoPE is applied to the image tokens only.
        query = apply_rotary_emb_transposed(query, image_rotary_emb)
        key = apply_rotary_emb_transposed(key, image_rotary_emb)
        del image_rotary_emb  # free memory

        # Project context (text/encoder) embeddings
        encoder_query = attn.add_q_proj(encoder_hidden_states)
        encoder_key = attn.add_k_proj(encoder_hidden_states)
        encoder_value = attn.add_v_proj(encoder_hidden_states)
        txt_length = encoder_hidden_states.shape[1]  # store length before deleting
        del encoder_hidden_states  # free memory

        encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
        encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
        encoder_value = encoder_value.unflatten(2, (attn.heads, -1))

        encoder_query = attn.norm_added_q(encoder_query)
        encoder_key = attn.norm_added_k(encoder_key)

        # Concatenate image and context q, k, v
        query = torch.cat([query, encoder_query], dim=1)
        key = torch.cat([key, encoder_key], dim=1)
        value = torch.cat([value, encoder_value], dim=1)
        del encoder_query, encoder_key, encoder_value  # free memory

        hidden_states_attn = attn_varlen_func(
            query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
        )
        del query, key, value  # free memory
        hidden_states_attn = hidden_states_attn.flatten(-2)

        # Split the joint result back into image / text streams.
        hidden_states, encoder_hidden_states = hidden_states_attn[:, :-txt_length], hidden_states_attn[:, -txt_length:]
        del hidden_states_attn  # free memory

        # Apply output projections
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)  # Dropout/Identity
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
869
+
870
+
871
class HunyuanAttnProcessorFlashAttnSingle:
    """Attention processor for single-stream blocks.

    Image and text tokens are concatenated BEFORE projection and share one set
    of q/k/v projections; RoPE is applied only to the leading image portion.
    `attention_mask` is the varlen 4-tuple consumed by `attn_varlen_func`,
    not a dense mask.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states,
        encoder_hidden_states,
        attention_mask,
        image_rotary_emb,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ):
        cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv = attention_mask
        txt_length = encoder_hidden_states.shape[1]  # Store text length

        # Concatenate image and context inputs
        hidden_states_cat = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        del hidden_states, encoder_hidden_states  # free memory

        # Project concatenated inputs
        query = attn.to_q(hidden_states_cat)
        key = attn.to_k(hidden_states_cat)
        value = attn.to_v(hidden_states_cat)
        del hidden_states_cat  # free memory

        # (B, L, inner) -> (B, L, heads, head_dim)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        # Rotate only the leading image tokens; text tokens pass through unchanged.
        query = torch.cat([apply_rotary_emb_transposed(query[:, :-txt_length], image_rotary_emb), query[:, -txt_length:]], dim=1)
        key = torch.cat([apply_rotary_emb_transposed(key[:, :-txt_length], image_rotary_emb), key[:, -txt_length:]], dim=1)
        del image_rotary_emb  # free memory

        hidden_states = attn_varlen_func(
            query, key, value, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, attn_mode=attn_mode, split_attn=split_attn
        )
        del query, key, value  # free memory
        hidden_states = hidden_states.flatten(-2)

        # Split joint result back into image / text streams (no output projection here).
        hidden_states, encoder_hidden_states = hidden_states[:, :-txt_length], hidden_states[:, -txt_length:]

        return hidden_states, encoder_hidden_states
915
+
916
+
917
class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Fuse timestep, guidance-scale and pooled-text embeddings into one conditioning vector."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        # Sinusoidal features are cast to the pooled-projection dtype before the MLP embedders.
        target_dtype = pooled_projection.dtype
        timestep_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=target_dtype))
        guidance_emb = self.guidance_embedder(self.time_proj(guidance).to(dtype=target_dtype))
        # Sum of (timestep + guidance) embedding and projected pooled text embedding.
        return (timestep_emb + guidance_emb) + self.text_embedder(pooled_projection)
939
+
940
+
941
class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Fuse timestep and pooled-text embeddings into one conditioning vector."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        # Sinusoidal features are cast to the pooled-projection dtype before the MLP embedder.
        timestep_emb = self.timestep_embedder(self.time_proj(timestep).to(dtype=pooled_projection.dtype))
        return timestep_emb + self.text_embedder(pooled_projection)
958
+
959
+
960
class HunyuanVideoAdaNorm(nn.Module):
    """Produce per-token gates for the attention and MLP branches from a conditioning embedding.

    Args:
        in_features: Dimension of the conditioning embedding.
        out_features: Output width of the projection; defaults to 2 * in_features
            so the result can be split into the two gates.
    """

    def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
        super().__init__()

        out_features = out_features or 2 * in_features
        self.linear = nn.Linear(in_features, out_features)
        self.nonlinearity = nn.SiLU()

    def forward(self, temb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        # Fixed return annotation: this method returns two tensors
        # (gate_msa, gate_mlp), not five as previously annotated.
        temb = self.linear(self.nonlinearity(temb))
        gate_msa, gate_mlp = temb.chunk(2, dim=-1)
        # Insert a sequence axis so the gates broadcast over tokens.
        gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
        return gate_msa, gate_mlp
973
+
974
+
975
class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
    """One token-refiner transformer block: gated self-attention plus gated feed-forward."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_size = num_attention_heads * attention_head_dim

        self.norm1 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=attention_bias,
        )
        self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=True, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
        # Produces (gate_msa, gate_mlp) from the timestep embedding.
        self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Self-attention branch, gated by the conditioning embedding.
        attn_out = self.attn(
            hidden_states=self.norm1(hidden_states),
            encoder_hidden_states=None,
            attention_mask=attention_mask,
        )
        gate_msa, gate_mlp = self.norm_out(temb)
        hidden_states = hidden_states + attn_out * gate_msa

        # Feed-forward branch, gated likewise.
        hidden_states = hidden_states + self.ff(self.norm2(hidden_states)) * gate_mlp
        return hidden_states
1027
+
1028
+
1029
class HunyuanVideoIndividualTokenRefiner(nn.Module):
    """Stack of token-refiner blocks sharing one self-attention mask derived from `attention_mask`."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()
        self.refiner_blocks = nn.ModuleList(
            [
                HunyuanVideoIndividualTokenRefinerBlock(
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    attention_bias=attention_bias,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        self_attn_mask = None
        if attention_mask is not None:
            # Build a (B, 1, S, S) boolean mask: position (i, j) is attendable
            # only if tokens i and j are both valid.
            batch_size = attention_mask.shape[0]
            seq_len = attention_mask.shape[1]
            mask_bool = attention_mask.to(hidden_states.device).bool()
            rows = mask_bool.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            cols = rows.transpose(2, 3)
            self_attn_mask = (rows & cols).bool()
            # Keep key position 0 always attendable so no row is fully masked
            # (a fully masked row would yield NaNs in softmax).
            self_attn_mask[:, :, :, 0] = True

        for block in self.refiner_blocks:
            hidden_states = block(hidden_states, temb, self_attn_mask)
        return hidden_states
1074
+
1075
+
1076
class HunyuanVideoTokenRefiner(nn.Module):
    """Project text embeddings and refine them with timestep-conditioned transformer blocks."""

    def __init__(
        self,
        in_channels: int,
        num_attention_heads: int,
        attention_head_dim: int,
        num_layers: int,
        mlp_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        attention_bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_size = num_attention_heads * attention_head_dim

        self.time_text_embed = CombinedTimestepTextProjEmbeddings(embedding_dim=hidden_size, pooled_projection_dim=in_channels)
        self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
        self.token_refiner = HunyuanVideoIndividualTokenRefiner(
            num_attention_heads=num_attention_heads,
            attention_head_dim=attention_head_dim,
            num_layers=num_layers,
            mlp_width_ratio=mlp_ratio,
            mlp_drop_rate=mlp_drop_rate,
            attention_bias=attention_bias,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        # Pool the (optionally masked) mean of the text tokens for conditioning.
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            dtype = hidden_states.dtype
            weights = attention_mask.float().unsqueeze(-1)
            pooled = ((hidden_states * weights).sum(dim=1) / weights.sum(dim=1)).to(dtype)

        temb = self.time_text_embed(timestep, pooled)
        hidden_states = self.proj_in(hidden_states)
        return self.token_refiner(hidden_states, temb, attention_mask)
1124
+
1125
+
1126
class HunyuanVideoRotaryPosEmbed(nn.Module):
    """3D rotary position embedding over (time, height, width) axes.

    `rope_dim` gives the per-axis channel budget (DT, DY, DX). Each table
    concatenates the cos channels of all three axes followed by the matching
    sin channels, giving 2 * (DT + DY + DX) channels total.
    """

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta

    @torch.no_grad()
    def get_frequency(self, dim, pos):
        """Return (cos, sin) tables of shape (dim, T, H, W) for position grid `pos`."""
        T, H, W = pos.shape
        inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32, device=pos.device)[: (dim // 2)] / dim))
        # Outer product over flattened positions, reshaped back to the grid;
        # each frequency channel is duplicated for the (even, odd) rotation pair.
        angles = torch.outer(inv_freq, pos.reshape(-1)).unflatten(-1, (T, H, W)).repeat_interleave(2, dim=0)
        return angles.cos(), angles.sin()

    @torch.no_grad()
    def forward_inner(self, frame_indices, height, width, device):
        grid_t, grid_y, grid_x = torch.meshgrid(
            frame_indices.to(device=device, dtype=torch.float32),
            torch.arange(0, height, device=device, dtype=torch.float32),
            torch.arange(0, width, device=device, dtype=torch.float32),
            indexing="ij",
        )

        cos_t, sin_t = self.get_frequency(self.DT, grid_t)
        cos_y, sin_y = self.get_frequency(self.DY, grid_y)
        cos_x, sin_x = self.get_frequency(self.DX, grid_x)

        # All cos channels first, then all sin channels: (2*(DT+DY+DX), T, H, W).
        return torch.cat([cos_t, cos_y, cos_x, sin_t, sin_y, sin_x], dim=0)

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        # One table per batch item, stacked on a leading batch axis.
        tables = [self.forward_inner(item, height, width, device) for item in frame_indices.unbind(0)]
        return torch.stack(tables, dim=0)
1167
+
1168
+
1169
class AdaLayerNormZero(nn.Module):
    """Adaptive LayerNorm Zero: a conditioning embedding yields six modulation
    tensors — shift/scale/gate for attention plus shift/scale/gate for the MLP."""

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
        if norm_type != "layer_norm":
            raise ValueError(f"unknown norm_type {norm_type}")
        self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self, x: torch.Tensor, emb: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalize ``x``, apply the attention shift/scale, and return the
        modulated tensor together with the remaining gates and MLP modulation."""
        modulation = self.linear(self.silu(emb.unsqueeze(-2)))
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = modulation.chunk(6, dim=-1)
        modulated = self.norm(x) * (scale_msa + 1) + shift_msa
        return modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp
1187
+
1188
+
1189
class AdaLayerNormZeroSingle(nn.Module):
    """Adaptive LayerNorm Zero for single-stream blocks: three modulation
    tensors (shift/scale for the norm, plus one output gate)."""

    def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
        if norm_type != "layer_norm":
            raise ValueError(f"unknown norm_type {norm_type}")
        self.norm = LayerNormFramePack(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(
        self,
        x: torch.Tensor,
        emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the shift/scale-modulated normalization of ``x`` and the gate."""
        modulation = self.linear(self.silu(emb.unsqueeze(-2)))
        shift_msa, scale_msa, gate_msa = modulation.chunk(3, dim=-1)
        modulated = self.norm(x) * (scale_msa + 1) + shift_msa
        return modulated, gate_msa
1210
+
1211
+
1212
class AdaLayerNormContinuous(nn.Module):
    """Final adaptive LayerNorm: the conditioning embedding produces one
    (scale, shift) pair that modulates the normalized input."""

    def __init__(
        self,
        embedding_dim: int,
        conditioning_embedding_dim: int,
        elementwise_affine=True,
        eps=1e-5,
        bias=True,
        norm_type="layer_norm",
    ):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
        if norm_type != "layer_norm":
            raise ValueError(f"unknown norm_type {norm_type}")
        # Arguments passed positionally: (dim, eps, elementwise_affine, bias).
        self.norm = LayerNormFramePack(embedding_dim, eps, elementwise_affine, bias)

    def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
        """Apply conditioned scale/shift on the normalized input."""
        modulation = self.linear(self.silu(emb.unsqueeze(-2)))
        # NOTE: chunk order is (scale, shift) here, unlike the Zero variants.
        scale, shift = modulation.chunk(2, dim=-1)
        del modulation  # free memory
        return self.norm(x) * (scale + 1) + shift
1237
+
1238
+
1239
class HunyuanVideoSingleTransformerBlock(nn.Module):
    """Single-stream DiT block: image and text tokens are concatenated into one
    sequence and share a fused attention + parallel-MLP pass (Flux-style)."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float = 4.0,
        qk_norm: str = "rms_norm",
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        mlp_dim = int(hidden_size * mlp_ratio)
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        # Attention layer (pre_only=True means no output projection in Attention module itself)
        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnSingle(),
            qk_norm=qk_norm,
            eps=1e-6,
            pre_only=True,  # Crucial: Attn processor will return raw attention output
        )

        self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
        self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Fuses raw attention output and MLP hidden state in one projection
        # (parallel-transformer layout).
        self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run one single-stream pass; returns (hidden_states, encoder_hidden_states)
        split back out of the fused sequence."""
        text_seq_length = encoder_hidden_states.shape[1]
        # Text tokens are appended after image tokens for the fused pass.
        hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
        del encoder_hidden_states  # free memory

        residual = hidden_states

        # 1. Input normalization
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        # Split image/text again so the processor can apply RoPE to image tokens only.
        norm_hidden_states, norm_encoder_hidden_states = (
            norm_hidden_states[:, :-text_seq_length, :],
            norm_hidden_states[:, -text_seq_length:, :],
        )

        # 2. Attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=image_rotary_emb,
            attn_mode=self.attn_mode,
            split_attn=self.split_attn,
        )
        attn_output = torch.cat([attn_output, context_attn_output], dim=1)
        del norm_hidden_states, norm_encoder_hidden_states, context_attn_output  # free memory
        del image_rotary_emb

        # 3. Modulation and residual connection
        # Concatenate along channels; proj_out mixes attention and MLP branches.
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        del attn_output, mlp_hidden_states  # free memory
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = hidden_states + residual

        hidden_states, encoder_hidden_states = (
            hidden_states[:, :-text_seq_length, :],
            hidden_states[:, -text_seq_length:, :],
        )
        return hidden_states, encoder_hidden_states
1322
+
1323
+
1324
class HunyuanVideoTransformerBlock(nn.Module):
    """Dual-stream (MM-DiT style) block: image and text tokens keep separate
    weights but attend jointly inside one shared attention operation."""

    def __init__(
        self,
        num_attention_heads: int,
        attention_head_dim: int,
        mlp_ratio: float,
        qk_norm: str = "rms_norm",
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        hidden_size = num_attention_heads * attention_head_dim
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        # Separate AdaLN modulation for the image stream and the text (context) stream.
        self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
        self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")

        self.attn = Attention(
            query_dim=hidden_size,
            cross_attention_dim=None,
            added_kv_proj_dim=hidden_size,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=hidden_size,
            context_pre_only=False,
            bias=True,
            processor=HunyuanAttnProcessorFlashAttnDouble(),
            qk_norm=qk_norm,
            eps=1e-6,
        )

        self.norm2 = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

        self.norm2_context = LayerNormFramePack(hidden_size, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """One dual-stream pass; returns updated (hidden_states, encoder_hidden_states).
        Intermediates are aggressively `del`-ed to lower peak memory."""
        # 1. Input normalization
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )

        # 2. Joint attention
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            attention_mask=attention_mask,
            image_rotary_emb=freqs_cis,
            attn_mode=self.attn_mode,
            split_attn=self.split_attn,
        )
        del norm_hidden_states, norm_encoder_hidden_states, freqs_cis  # free memory

        # 3. Modulation and residual connection
        hidden_states = hidden_states + attn_output * gate_msa
        del attn_output, gate_msa  # free memory
        encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa
        del context_attn_output, c_gate_msa  # free memory

        norm_hidden_states = self.norm2(hidden_states)
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)

        norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        del shift_mlp, scale_mlp  # free memory
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
        del c_shift_mlp, c_scale_mlp  # free memory

        # 4. Feed-forward
        ff_output = self.ff(norm_hidden_states)
        del norm_hidden_states  # free memory
        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        del norm_encoder_hidden_states  # free memory

        hidden_states = hidden_states + gate_mlp * ff_output
        del ff_output, gate_mlp  # free memory
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
        del context_ff_output, c_gate_mlp  # free memory

        return hidden_states, encoder_hidden_states
1414
+
1415
+
1416
class ClipVisionProjection(nn.Module):
    """Two-layer SiLU MLP projecting CLIP vision features to the model width."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.up = nn.Linear(in_channels, out_channels * 3)
        self.down = nn.Linear(out_channels * 3, out_channels)

    def forward(self, x):
        expanded = self.up(x)
        return self.down(nn.functional.silu(expanded))
1425
+
1426
+
1427
class HunyuanVideoPatchEmbed(nn.Module):
    """Non-overlapping 3D patchify of latents via a strided Conv3d projection."""

    def __init__(self, patch_size, in_chans, embed_dim):
        super().__init__()
        # kernel == stride gives non-overlapping patches.
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
1431
+
1432
+
1433
class HunyuanVideoPatchEmbedForCleanLatents(nn.Module):
    """Patch embedders for clean-context latents at 1x, 2x and 4x compression."""

    def __init__(self, inner_dim):
        super().__init__()
        self.proj = nn.Conv3d(16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
        self.proj_2x = nn.Conv3d(16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
        self.proj_4x = nn.Conv3d(16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))

    @torch.no_grad()
    def initialize_weight_from_another_conv3d(self, another_layer):
        """Seed all three projections from a reference (1,2,2) Conv3d by tiling
        its kernel; division keeps the output magnitude comparable."""
        weight = another_layer.weight.detach().clone()
        bias = another_layer.bias.detach().clone()

        state = {
            "proj.weight": weight.clone(),
            "proj.bias": bias.clone(),
            # Tile the kernel 2x per axis; /8 compensates for the 8x larger receptive volume.
            "proj_2x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=2, hk=2, wk=2) / 8.0,
            "proj_2x.bias": bias.clone(),
            "proj_4x.weight": einops.repeat(weight, "b c t h w -> b c (t tk) (h hk) (w wk)", tk=4, hk=4, wk=4) / 64.0,
            "proj_4x.bias": bias.clone(),
        }
        state = {key: value.clone() for key, value in state.items()}

        self.load_state_dict(state)
        return
1458
+
1459
+
1460
class HunyuanVideoTransformer3DModelPacked(nn.Module):  # (PreTrainedModelMixin, GenerationMixin,
    # ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """FramePack-packed HunyuanVideo DiT.

    Besides the noisy latents, "clean" history latents can be packed into the
    same token sequence at 1x/2x/4x compression (each with matching RoPE
    frequencies), and CLIP-vision image embeddings are prepended to the text
    context. Supports gradient checkpointing, TeaCache step skipping, and
    double/single block swapping between CPU and GPU for low-VRAM setups.
    """

    # @register_to_config
    def __init__(
        self,
        in_channels: int = 16,
        out_channels: int = 16,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 20,
        num_single_layers: int = 40,
        num_refiner_layers: int = 2,
        mlp_ratio: float = 4.0,
        patch_size: int = 2,
        patch_size_t: int = 1,
        qk_norm: str = "rms_norm",
        guidance_embeds: bool = True,
        text_embed_dim: int = 4096,
        pooled_projection_dim: int = 768,
        rope_theta: float = 256.0,
        rope_axes_dim: Tuple[int] = (16, 56, 56),
        has_image_proj=False,
        image_proj_dim=1152,
        has_clean_x_embedder=False,
        attn_mode: Optional[str] = None,
        split_attn: Optional[bool] = False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels
        self.config_patch_size = patch_size
        self.config_patch_size_t = patch_size_t

        # 1. Latent and condition embedders
        self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
        self.context_embedder = HunyuanVideoTokenRefiner(
            text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
        )
        self.time_text_embed = CombinedTimestepGuidanceTextProjEmbeddings(inner_dim, pooled_projection_dim)

        self.clean_x_embedder = None
        self.image_projection = None

        # 2. RoPE
        self.rope = HunyuanVideoRotaryPosEmbed(rope_axes_dim, rope_theta)

        # 3. Dual stream transformer blocks
        self.transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoTransformerBlock(
                    num_attention_heads,
                    attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Single stream transformer blocks
        self.single_transformer_blocks = nn.ModuleList(
            [
                HunyuanVideoSingleTransformerBlock(
                    num_attention_heads,
                    attention_head_dim,
                    mlp_ratio=mlp_ratio,
                    qk_norm=qk_norm,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                )
                for _ in range(num_single_layers)
            ]
        )

        # 5. Output projection
        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)

        self.inner_dim = inner_dim
        self.use_gradient_checkpointing = False
        self.enable_teacache = False

        # if has_image_proj:
        #     self.install_image_projection(image_proj_dim)
        # NOTE(review): the image projection and clean-x embedder are installed
        # unconditionally below; has_image_proj / has_clean_x_embedder are accepted
        # but not consulted.
        self.image_projection = ClipVisionProjection(in_channels=image_proj_dim, out_channels=self.inner_dim)
        # self.config["has_image_proj"] = True
        # self.config["image_proj_dim"] = in_channels

        # if has_clean_x_embedder:
        #     self.install_clean_x_embedder()
        self.clean_x_embedder = HunyuanVideoPatchEmbedForCleanLatents(self.inner_dim)
        # self.config["has_clean_x_embedder"] = True

        self.high_quality_fp32_output_for_inference = True  # False # change default to True

        # Block swapping attributes (initialized to None)
        self.blocks_to_swap = None
        self.offloader_double = None
        self.offloader_single = None

    @property
    def device(self):
        # Device of the first parameter; assumes all parameters share one device.
        return next(self.parameters()).device

    @property
    def dtype(self):
        # Dtype of the first parameter; assumes a uniform parameter dtype.
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = True
        print("Gradient checkpointing enabled for HunyuanVideoTransformer3DModelPacked.")  # Logging

    def disable_gradient_checkpointing(self):
        self.use_gradient_checkpointing = False
        print("Gradient checkpointing disabled for HunyuanVideoTransformer3DModelPacked.")  # Logging

    def initialize_teacache(self, enable_teacache=True, num_steps=25, rel_l1_thresh=0.15):
        """Reset TeaCache state. When enabled, forward() may reuse the previous
        step's residual instead of running the transformer blocks."""
        self.enable_teacache = enable_teacache
        self.cnt = 0
        self.num_steps = num_steps
        self.rel_l1_thresh = rel_l1_thresh  # 0.1 for 1.6x speedup, 0.15 for 2.1x speedup
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None
        # Polynomial mapping raw relative-L1 distance to the accumulated metric;
        # coefficients presumably fitted offline for this model — TODO confirm.
        self.teacache_rescale_func = np.poly1d([7.33226126e02, -4.01131952e02, 6.75869174e01, -3.14987800e00, 9.61237896e-02])
        if enable_teacache:
            print(f"TeaCache enabled: num_steps={num_steps}, rel_l1_thresh={rel_l1_thresh}")
        else:
            print("TeaCache disabled.")

    def gradient_checkpointing_method(self, block, *args):
        """Call `block(*args)`, through torch checkpointing when enabled."""
        if self.use_gradient_checkpointing:
            result = torch.utils.checkpoint.checkpoint(block, *args, use_reentrant=False)
        else:
            result = block(*args)
        return result

    def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
        """Set up CPU<->GPU offloading for transformer blocks.

        The budget is split between the two stacks: a single block is roughly
        half the size of a double block, hence the *2+1 conversion below.
        """
        self.blocks_to_swap = num_blocks
        self.num_double_blocks = len(self.transformer_blocks)
        self.num_single_blocks = len(self.single_transformer_blocks)
        double_blocks_to_swap = num_blocks // 2
        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1

        assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
            f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
            f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
        )

        self.offloader_double = ModelOffloader(
            "double",
            self.transformer_blocks,
            self.num_double_blocks,
            double_blocks_to_swap,
            supports_backward,
            device,
            # debug=True # Optional debugging
        )
        self.offloader_single = ModelOffloader(
            "single",
            self.single_transformer_blocks,
            self.num_single_blocks,
            single_blocks_to_swap,
            supports_backward,
            device,  # , debug=True
        )
        print(
            f"HunyuanVideoTransformer3DModelPacked: Block swap enabled. Swapping {num_blocks} blocks, "
            + f"double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}, supports_backward: {supports_backward}."
        )

    def switch_block_swap_for_inference(self):
        """Restrict offloading to forward passes only (inference mode)."""
        if self.blocks_to_swap and self.blocks_to_swap > 0:
            self.offloader_double.set_forward_only(True)
            self.offloader_single.set_forward_only(True)
            self.prepare_block_swap_before_forward()
            print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward only.")

    def switch_block_swap_for_training(self):
        """Re-enable offloading for both forward and backward passes."""
        if self.blocks_to_swap and self.blocks_to_swap > 0:
            self.offloader_double.set_forward_only(False)
            self.offloader_single.set_forward_only(False)
            self.prepare_block_swap_before_forward()
            print(f"HunyuanVideoTransformer3DModelPacked: Block swap set to forward and backward.")

    def move_to_device_except_swap_blocks(self, device: torch.device):
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if self.blocks_to_swap:
            # Temporarily detach the block lists so self.to() skips them.
            saved_double_blocks = self.transformer_blocks
            saved_single_blocks = self.single_transformer_blocks
            self.transformer_blocks = None
            self.single_transformer_blocks = None

        self.to(device)

        if self.blocks_to_swap:
            self.transformer_blocks = saved_double_blocks
            self.single_transformer_blocks = saved_single_blocks

    def prepare_block_swap_before_forward(self):
        # No-op when block swapping is disabled.
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self.offloader_double.prepare_block_devices_before_forward(self.transformer_blocks)
        self.offloader_single.prepare_block_devices_before_forward(self.single_transformer_blocks)

    def process_input_hidden_states(
        self,
        latents,
        latent_indices=None,
        clean_latents=None,
        clean_latent_indices=None,
        clean_latents_2x=None,
        clean_latent_2x_indices=None,
        clean_latents_4x=None,
        clean_latent_4x_indices=None,
    ):
        """Patchify noisy + clean latents into one token sequence.

        Clean latents at 2x/4x compression are padded, embedded with coarser
        conv kernels, and their RoPE frequencies downsampled to match. Clean
        tokens are prepended (4x, then 2x, then 1x) before the noisy tokens.
        Returns (hidden_states, rope_freqs), both of shape (B, seq, ...).
        """
        hidden_states = self.gradient_checkpointing_method(self.x_embedder.proj, latents)
        B, C, T, H, W = hidden_states.shape

        if latent_indices is None:
            # Default: frames are consecutive starting at 0.
            latent_indices = torch.arange(0, T).unsqueeze(0).expand(B, -1)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        rope_freqs = self.rope(frame_indices=latent_indices, height=H, width=W, device=hidden_states.device)
        rope_freqs = rope_freqs.flatten(2).transpose(1, 2)

        if clean_latents is not None and clean_latent_indices is not None:
            clean_latents = clean_latents.to(hidden_states)
            clean_latents = self.gradient_checkpointing_method(self.clean_x_embedder.proj, clean_latents)
            clean_latents = clean_latents.flatten(2).transpose(1, 2)

            clean_latent_rope_freqs = self.rope(frame_indices=clean_latent_indices, height=H, width=W, device=clean_latents.device)
            clean_latent_rope_freqs = clean_latent_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_rope_freqs, rope_freqs], dim=1)

        if clean_latents_2x is not None and clean_latent_2x_indices is not None:
            clean_latents_2x = clean_latents_2x.to(hidden_states)
            clean_latents_2x = pad_for_3d_conv(clean_latents_2x, (2, 4, 4))
            clean_latents_2x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_2x, clean_latents_2x)
            clean_latents_2x = clean_latents_2x.flatten(2).transpose(1, 2)

            clean_latent_2x_rope_freqs = self.rope(
                frame_indices=clean_latent_2x_indices, height=H, width=W, device=clean_latents_2x.device
            )
            # Downsample RoPE frequencies to match the coarser token grid.
            clean_latent_2x_rope_freqs = pad_for_3d_conv(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = center_down_sample_3d(clean_latent_2x_rope_freqs, (2, 2, 2))
            clean_latent_2x_rope_freqs = clean_latent_2x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_2x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_2x_rope_freqs, rope_freqs], dim=1)

        if clean_latents_4x is not None and clean_latent_4x_indices is not None:
            clean_latents_4x = clean_latents_4x.to(hidden_states)
            clean_latents_4x = pad_for_3d_conv(clean_latents_4x, (4, 8, 8))
            clean_latents_4x = self.gradient_checkpointing_method(self.clean_x_embedder.proj_4x, clean_latents_4x)
            clean_latents_4x = clean_latents_4x.flatten(2).transpose(1, 2)

            clean_latent_4x_rope_freqs = self.rope(
                frame_indices=clean_latent_4x_indices, height=H, width=W, device=clean_latents_4x.device
            )
            clean_latent_4x_rope_freqs = pad_for_3d_conv(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = center_down_sample_3d(clean_latent_4x_rope_freqs, (4, 4, 4))
            clean_latent_4x_rope_freqs = clean_latent_4x_rope_freqs.flatten(2).transpose(1, 2)

            hidden_states = torch.cat([clean_latents_4x, hidden_states], dim=1)
            rope_freqs = torch.cat([clean_latent_4x_rope_freqs, rope_freqs], dim=1)

        return hidden_states, rope_freqs

    def forward(
        self,
        hidden_states,
        timestep,
        encoder_hidden_states,
        encoder_attention_mask,
        pooled_projections,
        guidance,
        latent_indices=None,
        clean_latents=None,
        clean_latent_indices=None,
        clean_latents_2x=None,
        clean_latent_2x_indices=None,
        clean_latents_4x=None,
        clean_latent_4x_indices=None,
        image_embeddings=None,
        attention_kwargs=None,
        return_dict=True,
    ):
        """Denoise one step. Returns a namespace with `.sample` (or a 1-tuple),
        the predicted latents of the same spatial shape as the noisy input."""

        if attention_kwargs is None:
            attention_kwargs = {}

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p, p_t = self.config_patch_size, self.config_patch_size_t
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p
        post_patch_width = width // p
        # Token count of the noisy portion; clean-context tokens are sliced off
        # again before the output projection.
        original_context_length = post_patch_num_frames * post_patch_height * post_patch_width

        hidden_states, rope_freqs = self.process_input_hidden_states(
            hidden_states,
            latent_indices,
            clean_latents,
            clean_latent_indices,
            clean_latents_2x,
            clean_latent_2x_indices,
            clean_latents_4x,
            clean_latent_4x_indices,
        )
        del (
            latent_indices,
            clean_latents,
            clean_latent_indices,
            clean_latents_2x,
            clean_latent_2x_indices,
            clean_latents_4x,
            clean_latent_4x_indices,
        )  # free memory

        temb = self.gradient_checkpointing_method(self.time_text_embed, timestep, guidance, pooled_projections)
        encoder_hidden_states = self.gradient_checkpointing_method(
            self.context_embedder, encoder_hidden_states, timestep, encoder_attention_mask
        )

        if self.image_projection is not None:
            assert image_embeddings is not None, "You must use image embeddings!"
            extra_encoder_hidden_states = self.gradient_checkpointing_method(self.image_projection, image_embeddings)
            extra_attention_mask = torch.ones(
                (batch_size, extra_encoder_hidden_states.shape[1]),
                dtype=encoder_attention_mask.dtype,
                device=encoder_attention_mask.device,
            )

            # must cat before (not after) encoder_hidden_states, due to attn masking
            encoder_hidden_states = torch.cat([extra_encoder_hidden_states, encoder_hidden_states], dim=1)
            encoder_attention_mask = torch.cat([extra_attention_mask, encoder_attention_mask], dim=1)
            del extra_encoder_hidden_states, extra_attention_mask  # free memory

        with torch.no_grad():
            if batch_size == 1:
                # When batch size is 1, we do not need any masks or var-len funcs since cropping is mathematically same to what we want
                # If they are not same, then their impls are wrong. Ours are always the correct one.
                text_len = encoder_attention_mask.sum().item()
                encoder_hidden_states = encoder_hidden_states[:, :text_len]
                attention_mask = None, None, None, None
            else:
                img_seq_len = hidden_states.shape[1]
                txt_seq_len = encoder_hidden_states.shape[1]

                cu_seqlens_q = get_cu_seqlens(encoder_attention_mask, img_seq_len)
                cu_seqlens_kv = cu_seqlens_q
                max_seqlen_q = img_seq_len + txt_seq_len
                max_seqlen_kv = max_seqlen_q

                # Packed as a 4-tuple consumed by the attention processors.
                attention_mask = cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv
                del cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv  # free memory
            del encoder_attention_mask  # free memory

        if self.enable_teacache:
            # Use the first block's modulated input as a cheap proxy for how much
            # this step's computation would differ from the previous step's.
            modulated_inp = self.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

            if self.cnt == 0 or self.cnt == self.num_steps - 1:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
            else:
                curr_rel_l1 = (
                    ((modulated_inp - self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean())
                    .cpu()
                    .item()
                )
                self.accumulated_rel_l1_distance += self.teacache_rescale_func(curr_rel_l1)
                should_calc = self.accumulated_rel_l1_distance >= self.rel_l1_thresh

                if should_calc:
                    self.accumulated_rel_l1_distance = 0

            self.previous_modulated_input = modulated_inp
            self.cnt += 1

            if self.cnt == self.num_steps:
                self.cnt = 0

            if not should_calc:
                # Skip the transformer: replay the cached residual from the last
                # fully computed step.
                hidden_states = hidden_states + self.previous_residual
            else:
                ori_hidden_states = hidden_states.clone()

                for block_id, block in enumerate(self.transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                for block_id, block in enumerate(self.single_transformer_blocks):
                    hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                        block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                    )

                self.previous_residual = hidden_states - ori_hidden_states
                del ori_hidden_states  # free memory
        else:
            for block_id, block in enumerate(self.transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_double.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_double.submit_move_blocks_forward(self.transformer_blocks, block_id)

            for block_id, block in enumerate(self.single_transformer_blocks):
                if self.blocks_to_swap:
                    self.offloader_single.wait_for_block(block_id)

                hidden_states, encoder_hidden_states = self.gradient_checkpointing_method(
                    block, hidden_states, encoder_hidden_states, temb, attention_mask, rope_freqs
                )

                if self.blocks_to_swap:
                    self.offloader_single.submit_move_blocks_forward(self.single_transformer_blocks, block_id)

        del attention_mask, rope_freqs  # free memory
        del encoder_hidden_states  # free memory

        hidden_states = self.gradient_checkpointing_method(self.norm_out, hidden_states, temb)

        # Drop the clean-context tokens; keep only the noisy-latent tokens.
        hidden_states = hidden_states[:, -original_context_length:, :]

        if self.high_quality_fp32_output_for_inference:
            hidden_states = hidden_states.to(dtype=torch.float32)
            if self.proj_out.weight.dtype != torch.float32:
                self.proj_out.to(dtype=torch.float32)

        hidden_states = self.gradient_checkpointing_method(self.proj_out, hidden_states)

        # Un-patchify back to (B, C, T, H, W).
        hidden_states = einops.rearrange(
            hidden_states,
            "b (t h w) (c pt ph pw) -> b c (t pt) (h ph) (w pw)",
            t=post_patch_num_frames,
            h=post_patch_height,
            w=post_patch_width,
            pt=p_t,
            ph=p,
            pw=p,
        )

        if return_dict:
            # return Transformer2DModelOutput(sample=hidden_states)
            return SimpleNamespace(sample=hidden_states)

        return (hidden_states,)

    def fp8_optimization(
        self, state_dict: dict[str, torch.Tensor], device: torch.device, move_to_device: bool, use_scaled_mm: bool = False
    ) -> dict[str, torch.Tensor]:  # Return type hint added
        """
        Optimize the model state_dict with fp8.

        Args:
            state_dict (dict[str, torch.Tensor]):
                The state_dict of the model.
            device (torch.device):
                The device to calculate the weight.
            move_to_device (bool):
                Whether to move the weight to the device after optimization.
            use_scaled_mm (bool):
                Whether to use scaled matrix multiplication for FP8.
        """
        TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
        EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8

        # inplace optimization
        state_dict = optimize_state_dict_with_fp8(state_dict, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device)

        # apply monkey patching
        apply_fp8_monkey_patch(self, state_dict, use_scaled_mm=use_scaled_mm)

        return state_dict
1945
+
1946
+
1947
def load_packed_model(
    device: Union[str, torch.device],
    dit_path: str,
    attn_mode: str,
    loading_device: Union[str, torch.device],
    fp8_scaled: bool = False,
    split_attn: bool = False,
) -> HunyuanVideoTransformer3DModelPacked:
    """Instantiate the packed HunyuanVideo DiT and load its weights.

    Args:
        device: compute device used for fp8 weight optimization.
        dit_path: safetensors file, or a directory containing one.
        attn_mode: attention backend passed through to the model.
        loading_device: where the loaded weights should finally live.
        fp8_scaled: optimize weights to scaled fp8 before loading.
        split_attn: forwarded to the model (not fully supported yet).
    """
    # TODO support split_attn
    device = torch.device(device)
    loading_device = torch.device(loading_device)

    if os.path.isdir(dit_path):
        # we don't support from_pretrained for now, so loading safetensors directly
        candidates = sorted(glob.glob(os.path.join(dit_path, "*.safetensors")))
        if not candidates:
            raise ValueError(f"Cannot find safetensors file in {dit_path}")
        # sort by name and take the first one
        dit_path = candidates[0]

    with init_empty_weights():
        logger.info(f"Creating HunyuanVideoTransformer3DModelPacked")
        model = HunyuanVideoTransformer3DModelPacked(
            attention_head_dim=128,
            guidance_embeds=True,
            has_clean_x_embedder=True,
            has_image_proj=True,
            image_proj_dim=1152,
            in_channels=16,
            mlp_ratio=4.0,
            num_attention_heads=24,
            num_layers=20,
            num_refiner_layers=2,
            num_single_layers=40,
            out_channels=16,
            patch_size=2,
            patch_size_t=1,
            pooled_projection_dim=768,
            qk_norm="rms_norm",
            rope_axes_dim=(16, 56, 56),
            rope_theta=256.0,
            text_embed_dim=4096,
            attn_mode=attn_mode,
            split_attn=split_attn,
        )

    # if fp8_scaled, load model weights to CPU to reduce VRAM usage. Otherwise, load to the specified device (CPU for block swap or CUDA for others)
    dit_loading_device = torch.device("cpu") if fp8_scaled else loading_device
    logger.info(f"Loading DiT model from {dit_path}, device={dit_loading_device}")

    # load model weights with the specified dtype or as is
    sd = load_split_weights(dit_path, device=dit_loading_device, disable_mmap=True)

    if fp8_scaled:
        # fp8 optimization: calculate on CUDA, move back to CPU if loading_device is CPU (block swap)
        logger.info(f"Optimizing model weights to fp8. This may take a while.")
        sd = model.fp8_optimization(sd, device, move_to_device=loading_device.type == "cpu")

        if loading_device.type != "cpu":
            # make sure all the model weights are on the loading_device
            logger.info(f"Moving weights to {loading_device}")
            for key in sd.keys():
                sd[key] = sd[key].to(loading_device)

    info = model.load_state_dict(sd, strict=True, assign=True)
    logger.info(f"Loaded DiT model from {dit_path}, info={info}")

    return model
frame_pack/k_diffusion_hunyuan.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # original code: https://github.com/lllyasviel/FramePack
2
+ # original license: Apache-2.0
3
+
4
+ import torch
5
+ import math
6
+
7
+ # from diffusers_helper.k_diffusion.uni_pc_fm import sample_unipc
8
+ # from diffusers_helper.k_diffusion.wrapper import fm_wrapper
9
+ # from diffusers_helper.utils import repeat_to_batch_size
10
+ from frame_pack.uni_pc_fm import sample_unipc
11
+ from frame_pack.wrapper import fm_wrapper
12
+ from frame_pack.utils import repeat_to_batch_size
13
+
14
+
15
def flux_time_shift(t, mu=1.15, sigma=1.0):
    """Apply the flux time-shift warp to timestep(s) t; works on floats and tensors."""
    e_mu = math.exp(mu)
    return e_mu / (e_mu + (1 / t - 1) ** sigma)
17
+
18
+
19
def calculate_flux_mu(context_length, x1=256, y1=0.5, x2=4096, y2=1.15, exp_max=7.0):
    """Linearly interpolate mu from the context length, capped so exp(mu) <= exp_max."""
    slope = (y2 - y1) / (x2 - x1)
    intercept = y1 - slope * x1
    return min(slope * context_length + intercept, math.log(exp_max))
25
+
26
+
27
def get_flux_sigmas_from_mu(n, mu):
    """Build the n+1 descending sigma schedule on [1, 0], warped by mu."""
    uniform = torch.linspace(1, 0, steps=n + 1)
    return flux_time_shift(uniform, mu=mu)
31
+
32
+
33
# @torch.inference_mode()
def sample_hunyuan(
    transformer,
    sampler="unipc",
    initial_latent=None,
    concat_latent=None,
    strength=1.0,
    width=512,
    height=512,
    frames=16,
    real_guidance_scale=1.0,
    distilled_guidance_scale=6.0,
    guidance_rescale=0.0,
    shift=None,
    num_inference_steps=25,
    batch_size=None,
    generator=None,
    prompt_embeds=None,
    prompt_embeds_mask=None,
    prompt_poolers=None,
    negative_prompt_embeds=None,
    negative_prompt_embeds_mask=None,
    negative_prompt_poolers=None,
    dtype=torch.bfloat16,
    device=None,
    negative_kwargs=None,
    callback=None,
    **kwargs,
):
    """Run flow-matching sampling with the packed HunyuanVideo transformer.

    Builds initial noise latents, a mu-shifted sigma schedule, batches the
    conditioning tensors, and delegates to the UniPC sampler. Returns the
    final denoised latents.

    Raises:
        NotImplementedError: if `sampler` is not "unipc".
    """
    device = device or transformer.device

    if batch_size is None:
        batch_size = int(prompt_embeds.shape[0])

    # Fix: the original read `generator.device` unconditionally, which raised
    # AttributeError when `generator` is None (the default). Fall back to the
    # target device in that case.
    noise_device = generator.device if generator is not None else device
    latents = torch.randn(
        (batch_size, 16, (frames + 3) // 4, height // 8, width // 8), generator=generator, device=noise_device
    ).to(device=device, dtype=torch.float32)

    B, C, T, H, W = latents.shape
    seq_length = T * H * W // 4  # 9*80*80//4 = 14400

    if shift is None:
        mu = calculate_flux_mu(seq_length, exp_max=7.0)  # 1.9459... if seq_len is large, mu is clipped.
    else:
        mu = math.log(shift)

    sigmas = get_flux_sigmas_from_mu(num_inference_steps, mu).to(device)

    k_model = fm_wrapper(transformer)

    if initial_latent is not None:
        # img2img-style start: scale the schedule by strength and blend the
        # initial latent with noise at the first sigma.
        sigmas = sigmas * strength
        first_sigma = sigmas[0].to(device=device, dtype=torch.float32)
        initial_latent = initial_latent.to(device=device, dtype=torch.float32)
        latents = initial_latent.float() * (1.0 - first_sigma) + latents.float() * first_sigma

    if concat_latent is not None:
        concat_latent = concat_latent.to(latents)

    # Distilled guidance is passed as a scaled scalar per sample (x1000 convention).
    distilled_guidance = torch.tensor([distilled_guidance_scale * 1000.0] * batch_size).to(device=device, dtype=dtype)

    prompt_embeds = repeat_to_batch_size(prompt_embeds, batch_size)
    prompt_embeds_mask = repeat_to_batch_size(prompt_embeds_mask, batch_size)
    prompt_poolers = repeat_to_batch_size(prompt_poolers, batch_size)
    negative_prompt_embeds = repeat_to_batch_size(negative_prompt_embeds, batch_size)
    negative_prompt_embeds_mask = repeat_to_batch_size(negative_prompt_embeds_mask, batch_size)
    negative_prompt_poolers = repeat_to_batch_size(negative_prompt_poolers, batch_size)
    concat_latent = repeat_to_batch_size(concat_latent, batch_size)

    sampler_kwargs = dict(
        dtype=dtype,
        cfg_scale=real_guidance_scale,
        cfg_rescale=guidance_rescale,
        concat_latent=concat_latent,
        positive=dict(
            pooled_projections=prompt_poolers,
            encoder_hidden_states=prompt_embeds,
            encoder_attention_mask=prompt_embeds_mask,
            guidance=distilled_guidance,
            **kwargs,
        ),
        negative=dict(
            pooled_projections=negative_prompt_poolers,
            encoder_hidden_states=negative_prompt_embeds,
            encoder_attention_mask=negative_prompt_embeds_mask,
            guidance=distilled_guidance,
            **(kwargs if negative_kwargs is None else {**kwargs, **negative_kwargs}),
        ),
    )

    if sampler == "unipc":
        results = sample_unipc(k_model, latents, sigmas, extra_args=sampler_kwargs, disable=False, callback=callback)
    else:
        raise NotImplementedError(f"Sampler {sampler} is not supported.")

    return results
frame_pack/uni_pc_fm.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Better Flow Matching UniPC by Lvmin Zhang
2
+ # (c) 2025
3
+ # CC BY-SA 4.0
4
+ # Attribution-ShareAlike 4.0 International Licence
5
+
6
+
7
+ import torch
8
+
9
+ from tqdm.auto import trange
10
+
11
+
12
def expand_dims(v, dims):
    """Append dims-1 trailing singleton axes to v (v is expected to be 1-D)."""
    indexer = (Ellipsis,) + (None,) * (dims - 1)
    return v[indexer]
14
+
15
+
16
class FlowMatchUniPC:
    """UniPC predictor-corrector solver adapted to flow-matching ODEs.

    `model` is a denoiser returning an x0-style prediction; `extra_args` are
    forwarded verbatim to every model call. `variant` selects the B(h)
    scaling ('bh1' or 'bh2').
    """

    def __init__(self, model, extra_args, variant='bh1'):
        self.model = model
        self.variant = variant
        self.extra_args = extra_args

    def model_fn(self, x, t):
        # Evaluate the wrapped model at state x and (batched) time t.
        return self.model(x, t, **self.extra_args)

    def update_fn(self, x, model_prev_list, t_prev_list, t, order):
        """One UniPC step from t_prev_list[-1] down to t using `order` history points.

        Returns (x_t, model_t): the updated state and the model output at t.
        """
        assert order <= len(model_prev_list)
        dims = x.dim()

        # Work in log-time (lambda) coordinates.
        t_prev_0 = t_prev_list[-1]
        lambda_prev_0 = - torch.log(t_prev_0)
        lambda_t = - torch.log(t)
        model_prev_0 = model_prev_list[-1]

        h = lambda_t - lambda_prev_0

        # rks: normalized lambda offsets of older history points;
        # D1s: matching scaled first differences of the model outputs.
        rks = []
        D1s = []
        for i in range(1, order):
            t_prev_i = t_prev_list[-(i + 1)]
            model_prev_i = model_prev_list[-(i + 1)]
            lambda_prev_i = - torch.log(t_prev_i)
            rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
            rks.append(rk)
            D1s.append((model_prev_i - model_prev_0) / rk)

        rks.append(1.)
        rks = torch.tensor(rks, device=x.device)

        R = []
        b = []

        hh = -h[0]
        h_phi_1 = torch.expm1(hh)  # phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.variant == 'bh1':
            B_h = hh
        elif self.variant == 'bh2':
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError('Bad variant!')

        # Build the Vandermonde-style system R @ rho = b for the step weights.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= (i + 1)
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        R = torch.stack(R)
        b = torch.tensor(b, device=x.device)

        use_predictor = len(D1s) > 0

        if use_predictor:
            D1s = torch.stack(D1s, dim=1)
            if order == 2:
                # Closed-form weight for the 2nd-order case.
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
        else:
            D1s = None
            rhos_p = None

        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

        # Predictor: explicit step using only the history.
        x_t_ = expand_dims(t / t_prev_0, dims) * x - expand_dims(h_phi_1, dims) * model_prev_0

        if use_predictor:
            pred_res = torch.tensordot(D1s, rhos_p, dims=([1], [0]))
        else:
            pred_res = 0

        x_t = x_t_ - expand_dims(B_h, dims) * pred_res
        # Corrector: re-evaluate the model at the predicted point and refine.
        model_t = self.model_fn(x_t, t)

        if D1s is not None:
            corr_res = torch.tensordot(D1s, rhos_c[:-1], dims=([1], [0]))
        else:
            corr_res = 0

        D1_t = (model_t - model_prev_0)
        x_t = x_t_ - expand_dims(B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

        return x_t, model_t

    def sample(self, x, sigmas, callback=None, disable_pbar=False):
        """Integrate from sigmas[0] down to sigmas[-1]; returns the final denoised estimate."""
        order = min(3, len(sigmas) - 2)
        model_prev_list, t_prev_list = [], []
        for i in trange(len(sigmas) - 1, disable=disable_pbar):
            vec_t = sigmas[i].expand(x.shape[0])

            with torch.no_grad():
                if i == 0:
                    # Warm-up: the first model call seeds the history.
                    model_prev_list = [self.model_fn(x, vec_t)]
                    t_prev_list = [vec_t]
                elif i < order:
                    # Ramp the order up while fewer history points exist.
                    init_order = i
                    x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, init_order)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                else:
                    x, model_x = self.update_fn(x, model_prev_list, t_prev_list, vec_t, order)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)

                # Keep only the most recent `order` history entries.
                model_prev_list = model_prev_list[-order:]
                t_prev_list = t_prev_list[-order:]

            if callback is not None:
                callback({'x': x, 'i': i, 'denoised': model_prev_list[-1]})

        return model_prev_list[-1]
138
+
139
+
140
def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=False, variant='bh1'):
    """Convenience wrapper: construct a FlowMatchUniPC solver and run one sampling pass."""
    assert variant in ['bh1', 'bh2']
    solver = FlowMatchUniPC(model, extra_args=extra_args, variant=variant)
    return solver.sample(noise, sigmas=sigmas, callback=callback, disable_pbar=disable)
frame_pack/utils.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import json
4
+ import random
5
+ import glob
6
+ import torch
7
+ import einops
8
+ import numpy as np
9
+ import datetime
10
+ import torchvision
11
+
12
+ import safetensors.torch as sf
13
+ from PIL import Image
14
+
15
+
16
def min_resize(x, m):
    """Resize ndarray image x so its shorter side equals m, keeping the aspect ratio."""
    if x.shape[0] < x.shape[1]:
        new_h, new_w = m, int(float(m) / float(x.shape[0]) * float(x.shape[1]))
    else:
        new_h, new_w = int(float(m) / float(x.shape[1]) * float(x.shape[0])), m
    # AREA for shrinking, LANCZOS4 for enlarging.
    shrinking = max(new_h, new_w) < max(x.shape[0], x.shape[1])
    interp = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (new_w, new_h), interpolation=interp)
31
+
32
+
33
def d_resize(x, y):
    """Resize x to the spatial size of reference image y (HxWxC)."""
    target_h, target_w = y.shape[:2]
    # AREA for shrinking, LANCZOS4 for enlarging.
    shrinking = min(target_h, target_w) < min(x.shape[0], x.shape[1])
    interp = cv2.INTER_AREA if shrinking else cv2.INTER_LANCZOS4
    return cv2.resize(x, (target_w, target_h), interpolation=interp)
43
+
44
+
45
def resize_and_center_crop(image, target_width, target_height):
    """Scale an HxWxC array to cover the target size, then center-crop to it."""
    if image.shape[0] == target_height and image.shape[1] == target_width:
        return image

    pil = Image.fromarray(image)
    src_w, src_h = pil.size
    scale = max(target_width / src_w, target_height / src_h)
    new_w = int(round(src_w * scale))
    new_h = int(round(src_h * scale))
    resized = pil.resize((new_w, new_h), Image.LANCZOS)
    # Symmetric crop box (PIL accepts float coordinates here).
    box = (
        (new_w - target_width) / 2,
        (new_h - target_height) / 2,
        (new_w + target_width) / 2,
        (new_h + target_height) / 2,
    )
    return np.array(resized.crop(box))
61
+
62
+
63
def resize_and_center_crop_pytorch(image, target_width, target_height):
    """Torch variant of resize_and_center_crop for a (B, C, H, W) batch (bilinear)."""
    _, _, src_h, src_w = image.shape
    if (src_h, src_w) == (target_height, target_width):
        return image

    scale = max(target_width / src_w, target_height / src_h)
    new_h = int(round(src_h * scale))
    new_w = int(round(src_w * scale))

    resized = torch.nn.functional.interpolate(image, size=(new_h, new_w), mode="bilinear", align_corners=False)

    y0 = (new_h - target_height) // 2
    x0 = (new_w - target_width) // 2
    return resized[:, :, y0 : y0 + target_height, x0 : x0 + target_width]
80
+
81
+
82
def resize_without_crop(image, target_width, target_height):
    """Resize an HxWxC array directly to the target size (aspect ratio not preserved)."""
    if image.shape[0] == target_height and image.shape[1] == target_width:
        return image

    resized = Image.fromarray(image).resize((target_width, target_height), Image.LANCZOS)
    return np.array(resized)
89
+
90
+
91
def just_crop(image, w, h):
    """Center-crop image to the aspect ratio w:h without resizing."""
    if image.shape[0] == h and image.shape[1] == w:
        return image

    src_h, src_w = image.shape[:2]
    k = min(src_h / h, src_w / w)
    crop_w = int(round(w * k))
    crop_h = int(round(h * k))
    x0 = (src_w - crop_w) // 2
    y0 = (src_h - crop_h) // 2
    return image[y0 : y0 + crop_h, x0 : x0 + crop_w]
103
+
104
+
105
def write_to_json(data, file_path):
    """Atomically write data as indented JSON: write a temp file, then rename over the target."""
    tmp_path = file_path + ".tmp"
    with open(tmp_path, "wt", encoding="utf-8") as fh:
        json.dump(data, fh, indent=4)
    os.replace(tmp_path, file_path)
111
+
112
+
113
def read_from_json(file_path):
    """Load and return the JSON document stored at file_path."""
    with open(file_path, "rt", encoding="utf-8") as fh:
        return json.load(fh)
117
+
118
+
119
def get_active_parameters(m):
    """Return {name: param} for every trainable (requires_grad) parameter of m."""
    active = {}
    for name, param in m.named_parameters():
        if param.requires_grad:
            active[name] = param
    return active
121
+
122
+
123
def cast_training_params(m, dtype=torch.float32):
    """Cast all trainable parameters of m to dtype in place; return them keyed by name."""
    casted = {}
    for name, param in m.named_parameters():
        if not param.requires_grad:
            continue
        param.data = param.to(dtype)
        casted[name] = param
    return casted
130
+
131
+
132
def separate_lora_AB(parameters, B_patterns=None):
    """Split a parameter dict into (others, B-side) by substring patterns in the key."""
    if B_patterns is None:
        B_patterns = [".lora_B.", "__zero__"]

    parameters_normal = {}
    parameters_B = {}
    for key, value in parameters.items():
        target = parameters_B if any(pat in key for pat in B_patterns) else parameters_normal
        target[key] = value

    return parameters_normal, parameters_B
146
+
147
+
148
def set_attr_recursive(obj, attr, value):
    """Set a dotted attribute path, e.g. set_attr_recursive(m, "a.b.c", v)."""
    *parents, leaf = attr.split(".")
    target = obj
    for name in parents:
        target = getattr(target, name)
    setattr(target, leaf, value)
154
+
155
+
156
def print_tensor_list_size(tensors):
    """Print count, total MB, and total parameter count of a tensor list or dict."""
    if isinstance(tensors, dict):
        tensors = tensors.values()

    total_bytes = sum(t.nelement() * t.element_size() for t in tensors)
    total_elements = sum(t.nelement() for t in tensors)

    print(f"Total number of tensors: {len(tensors)}")
    print(f"Total size of tensors: {total_bytes / (1024**2):.2f} MB")
    print(f"Total number of parameters: {total_elements / 1e9:.3f} billion")
174
+
175
+
176
@torch.no_grad()
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
    """Per-sample selection between a and b: where mask_a is True take a, else b.

    When mask_a is None a random mask is drawn with P(a) = probability_a;
    when b is None zeros are used.
    """
    n = a.size(0)

    if b is None:
        b = torch.zeros_like(a)
    if mask_a is None:
        mask_a = torch.rand(n) < probability_a

    broadcast_shape = (n,) + (1,) * (a.dim() - 1)
    mask_a = mask_a.to(a.device).reshape(broadcast_shape)
    return torch.where(mask_a, a, b)
190
+
191
+
192
@torch.no_grad()
def zero_module(module):
    """Zero every parameter of `module` in place and return the module."""
    for param in module.parameters():
        param.detach().zero_()
    return module
197
+
198
+
199
@torch.no_grad()
def supress_lower_channels(m, k, alpha=0.01):
    """Scale the first k input channels of m.weight by alpha (name sic: 'suppress')."""
    scaled = m.weight.data.clone()

    assert int(scaled.shape[1]) >= k

    scaled[:, :k] *= alpha
    m.weight.data = scaled.contiguous().clone()
    return m
208
+
209
+
210
def freeze_module(m):
    """Freeze m: disable grads and run forward under torch.no_grad().

    The original forward is stashed once in `_forward_inside_frozen_module`
    so it can be recovered later.
    """
    if not hasattr(m, "_forward_inside_frozen_module"):
        m._forward_inside_frozen_module = m.forward
    m.requires_grad_(False)
    m.forward = torch.no_grad()(m.forward)
    return m
216
+
217
+
218
def get_latest_safetensors(folder_path):
    """Return the absolute real path of the most recently modified *.safetensors file.

    Raises:
        ValueError: if the folder contains no .safetensors files.
    """
    candidates = glob.glob(os.path.join(folder_path, "*.safetensors"))
    if not candidates:
        raise ValueError("No file to resume!")

    newest = max(candidates, key=os.path.getmtime)
    return os.path.abspath(os.path.realpath(newest))
227
+
228
+
229
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
    """Build a prompt from a random subset of comma-separated tags."""
    tags = tags_str.split(", ")
    count = min(random.randint(min_length, max_length), len(tags))
    return ", ".join(random.sample(tags, k=count))
234
+
235
+
236
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
    """Return n values from a to b spaced along a gamma-warped linear ramp."""
    ramp = np.linspace(0, 1, n) ** gamma
    values = a + (b - a) * ramp
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
241
+
242
+
243
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
    """Stratified sampling: one uniform draw inside each of n equal sub-intervals."""
    edges = np.linspace(0, 1, n + 1)
    draws = np.random.uniform(edges[:-1], edges[1:])
    values = inclusive + (exclusive - inclusive) * draws
    if round_to_int:
        values = np.round(values).astype(int)
    return values.tolist()
250
+
251
+
252
def soft_append_bcthw(history, current, overlap=0):
    """Concatenate two (B, C, T, H, W) clips on the time axis, cross-fading `overlap` frames."""
    if overlap <= 0:
        return torch.cat([history, current], dim=2)

    assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
    assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"

    # Linear fade: weight 1 -> 0 on history's tail, 0 -> 1 on current's head.
    fade = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
    blended = fade * history[:, :, -overlap:] + (1 - fade) * current[:, :, :overlap]
    merged = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)

    return merged.to(history)
264
+
265
+
266
def save_bcthw_as_mp4(x, output_filename, fps=10):
    """Tile a (B, C, T, H, W) tensor in [-1, 1] into a grid video and write it as mp4.

    Also dumps the converted uint8 grid tensor next to the video as a .pt file.
    Returns the rearranged uint8 tensor of shape (T, rows*H, cols*W, C).
    """
    b, c, t, h, w = x.shape

    # Grid width: largest divisor of the batch from 6 down to 2, else one row of b.
    per_row = b
    for p in [6, 5, 4, 3, 2]:
        if b % p == 0:
            per_row = p
            break

    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    # Map [-1, 1] floats to [0, 255] uint8.
    x = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
    x = x.detach().cpu().to(torch.uint8)
    x = einops.rearrange(x, "(m n) c t h w -> t (m h) (n w) c", n=per_row)
    # crf "0" requests lossless H.264 output.
    torchvision.io.write_video(output_filename, x, fps=fps, video_codec="libx264", options={"crf": "0"})

    # write tensor as .pt file
    torch.save(x, output_filename.replace(".mp4", ".pt"))

    return x
285
+
286
+
287
def save_bcthw_as_png(x, output_filename):
    """Flatten a BCTHW tensor in [-1, 1] into one PNG (batch stacked vertically, time horizontally)."""
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    img = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
    img = img.detach().cpu().to(torch.uint8)
    img = einops.rearrange(img, "b c t h w -> c (b h) (t w)")
    torchvision.io.write_png(img, output_filename)
    return output_filename
294
+
295
+
296
def save_bchw_as_png(x, output_filename):
    """Flatten a BCHW tensor in [-1, 1] into one PNG with the batch stacked horizontally."""
    os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
    img = torch.clamp(x.float(), -1.0, 1.0) * 127.5 + 127.5
    img = img.detach().cpu().to(torch.uint8)
    img = einops.rearrange(img, "b c h w -> c h (b w)")
    torchvision.io.write_png(img, output_filename)
    return output_filename
303
+
304
+
305
def add_tensors_with_padding(tensor1, tensor2):
    """Add two tensors of equal rank, zero-padding each up to their elementwise max shape.

    Fix over the original: the padded buffers are allocated with tensor1's
    dtype and device instead of `torch.zeros` defaults (float32 on CPU),
    which silently down-cast float64 inputs and failed for non-CPU tensors.
    """
    if tensor1.shape == tensor2.shape:
        return tensor1 + tensor2

    target_shape = tuple(max(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape))

    padded1 = torch.zeros(target_shape, dtype=tensor1.dtype, device=tensor1.device)
    padded2 = torch.zeros(target_shape, dtype=tensor1.dtype, device=tensor1.device)

    padded1[tuple(slice(0, s) for s in tensor1.shape)] = tensor1
    padded2[tuple(slice(0, s) for s in tensor2.shape)] = tensor2

    return padded1 + padded2
322
+
323
+
324
def print_free_mem():
    """Print free/total memory of CUDA device 0 in MB (after emptying the cache)."""
    torch.cuda.empty_cache()
    free_bytes, total_bytes = torch.cuda.mem_get_info(0)
    print(f"Free memory: {free_bytes / (1024**2):.2f} MB")
    print(f"Total memory: {total_bytes / (1024**2):.2f} MB")
332
+
333
+
334
def print_gpu_parameters(device, state_dict, log_count=1):
    """Print a short summary of a state dict: device, key count, and first values of a few params."""
    sampled = {}
    for key, tensor in list(state_dict.items())[:log_count]:
        sampled[key] = tensor.flatten()[:3].tolist()

    summary = {"device": device, "keys_count": len(state_dict), "params": sampled}
    print(str(summary))
347
+
348
+
349
def visualize_txt_as_img(width, height, text, font_path="font/DejaVuSans.ttf", size=18):
    """Render `text` word-wrapped onto a white width x height RGB canvas; returns an ndarray."""
    from PIL import Image, ImageDraw, ImageFont

    txt = Image.new("RGB", (width, height), color="white")
    draw = ImageDraw.Draw(txt)
    font = ImageFont.truetype(font_path, size=size)

    if text == "":
        return np.array(txt)

    # Split text into lines that fit within the image width
    lines = []
    words = text.split()
    current_line = words[0]

    for word in words[1:]:
        line_with_word = f"{current_line} {word}"
        # textbbox(...)[2] is the right edge of the rendered line in pixels.
        if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
            current_line = line_with_word
        else:
            lines.append(current_line)
            current_line = word

    lines.append(current_line)

    # Draw the text line by line
    y = 0
    # Line height approximated from the bounding box of a single capital letter.
    line_height = draw.textbbox((0, 0), "A", font=font)[3]

    for line in lines:
        if y + line_height > height:
            break  # stop drawing if the next line will be outside the image
        draw.text((0, y), line, fill="black", font=font)
        y += line_height

    return np.array(txt)
385
+
386
+
387
def blue_mark(x):
    """Sharpen the blue channel of a [-1, 1] HxWxC image via unsharp masking; returns a copy."""
    marked = x.copy()
    blue = marked[:, :, 2]
    blurred = cv2.blur(blue, (9, 9))
    marked[:, :, 2] = ((blue - blurred) * 16.0 + blurred).clip(-1, 1)
    return marked
393
+
394
+
395
def green_mark(x):
    """Force the red and blue channels to -1, leaving only green content; returns a copy."""
    marked = x.copy()
    marked[:, :, 0] = -1
    marked[:, :, 2] = -1
    return marked
400
+
401
+
402
def frame_mark(x):
    """Draw letterbox bars (top/bottom 64 rows -> -1) and side bars (8 cols -> 1); returns a copy."""
    marked = x.copy()
    marked[:64] = -1
    marked[-64:] = -1
    marked[:, :8] = 1
    marked[:, -8:] = 1
    return marked
409
+
410
+
411
@torch.inference_mode()
def pytorch2numpy(imgs):
    """Convert an iterable of CHW tensors in [-1, 1] to HWC uint8 numpy arrays."""
    converted = []
    for img in imgs:
        hwc = img.movedim(0, -1) * 127.5 + 127.5
        converted.append(hwc.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8))
    return converted
420
+
421
+
422
@torch.inference_mode()
def numpy2pytorch(imgs):
    """Stack HWC uint8 arrays into a BCHW float tensor scaled to [-1, 1]."""
    batch = np.stack(imgs, axis=0)
    return (torch.from_numpy(batch).float() / 127.5 - 1.0).movedim(-1, 1)
427
+
428
+
429
@torch.no_grad()
def duplicate_prefix_to_suffix(x, count, zero_out=False):
    """Append the first `count` items of x (or zeros of that shape if zero_out) to its end."""
    suffix = torch.zeros_like(x[:count]) if zero_out else x[:count]
    return torch.cat([x, suffix], dim=0)
435
+
436
+
437
def weighted_mse(a, b, weight):
    """Mean of weight * (a - b)^2, computed in float32."""
    diff = a.float() - b.float()
    return (weight.float() * diff.square()).mean()
439
+
440
+
441
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
    """Map x from [x_min, x_max] to [y_min, y_max], clamped, with gamma exponent `sigma`."""
    frac = (x - x_min) / (x_max - x_min)
    frac = min(1.0, max(0.0, frac)) ** sigma
    return y_min + frac * (y_max - y_min)
446
+
447
+
448
def expand_to_dims(x, target_dims):
    """View x with trailing singleton axes appended until it has target_dims dims (no-op if already larger)."""
    extra = max(0, target_dims - x.dim())
    return x.view(*x.shape, *((1,) * extra))
450
+
451
+
452
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
    """Tile `tensor` along dim 0 so its first dim equals batch_size (must divide evenly).

    None passes through unchanged; a tensor already at batch_size is returned as-is.
    """
    if tensor is None:
        return None

    first_dim = tensor.shape[0]
    if first_dim == batch_size:
        return tensor

    if batch_size % first_dim != 0:
        raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")

    reps = (batch_size // first_dim,) + (1,) * (tensor.dim() - 1)
    return tensor.repeat(*reps)
467
+
468
+
469
def dim5(x):
    """Pad x with trailing singleton axes up to 5 dims (see expand_to_dims)."""
    return expand_to_dims(x, 5)


def dim4(x):
    """Pad x with trailing singleton axes up to 4 dims (see expand_to_dims)."""
    return expand_to_dims(x, 4)


def dim3(x):
    """Pad x with trailing singleton axes up to 3 dims (see expand_to_dims)."""
    return expand_to_dims(x, 3)
479
+
480
+
481
def crop_or_pad_yield_mask(x, length):
    """Crop or zero-pad (B, F, C) along F to exactly `length`; also return a bool validity mask."""
    B, F, C = x.shape
    device = x.device

    if F >= length:
        return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)

    padded = torch.zeros((B, length, C), dtype=x.dtype, device=device)
    padded[:, :F, :] = x
    mask = torch.zeros((B, length), dtype=torch.bool, device=device)
    mask[:, :F] = True
    return padded, mask
494
+
495
+
496
def extend_dim(x, dim, minimal_length, zero_pad=False):
    """Grow x along `dim` to at least minimal_length, padding with zeros or the last slice."""
    current = int(x.shape[dim])
    if current >= minimal_length:
        return x

    deficit = minimal_length - current
    if zero_pad:
        pad_shape = list(x.shape)
        pad_shape[dim] = deficit
        filler = torch.zeros(pad_shape, dtype=x.dtype, device=x.device)
    else:
        # Repeat the final slice along `dim` to fill the deficit.
        selector = (slice(None),) * dim + (slice(-1, None),)
        filler = x[selector].repeat_interleave(deficit, dim=dim)

    return torch.cat([x, filler], dim=dim)
512
+
513
+
514
def lazy_positional_encoding(t, repeats=None):
    """Sinusoidal 256-dim embedding of timestep value(s) t, optionally expanded along a repeat axis."""
    from diffusers.models.embeddings import get_timestep_embedding

    values = t if isinstance(t, list) else [t]
    emb = get_timestep_embedding(
        timesteps=torch.tensor(values), embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0
    )

    if repeats is None:
        return emb
    return emb[:, None, :].expand(-1, repeats, -1)
529
+
530
+
531
def state_dict_offset_merge(A, B, C=None):
    """Per-key A + B (or A + B - C), with B/C cast to A's dtype/device first."""
    merged = {}
    for key in A.keys():
        base = A[key]
        delta = B[key].to(base)
        if C is None:
            merged[key] = base + delta
        else:
            merged[key] = base + delta - C[key].to(base)
    return merged
546
+
547
+
548
def state_dict_weighted_merge(state_dicts, weights):
    """Weighted average of several state dicts; weights are normalized to sum to 1.

    Raises:
        ValueError: on length mismatch or when the weights sum to zero.
    """
    if len(state_dicts) != len(weights):
        raise ValueError("Number of state dictionaries must match number of weights")
    if not state_dicts:
        return {}

    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("Sum of weights cannot be zero")
    normalized = [w / total_weight for w in weights]

    merged = {}
    for key in state_dicts[0].keys():
        acc = state_dicts[0][key] * normalized[0]
        for sd, w in zip(state_dicts[1:], normalized[1:]):
            acc = acc + sd[key].to(acc) * w
        merged[key] = acc
    return merged
573
+
574
+
575
def group_files_by_folder(all_files):
    """Group file paths by their immediate parent folder name; return the groups as a list of lists."""
    grouped = {}
    for path in all_files:
        parent = os.path.basename(os.path.dirname(path))
        grouped.setdefault(parent, []).append(path)
    return list(grouped.values())
586
+
587
+
588
def generate_timestamp():
    """Build a unique-ish id of the form yymmdd_HHMMSS_mmm_<random 0..9999>."""
    now = datetime.datetime.now()
    millis = int(now.microsecond / 1000)
    suffix = random.randint(0, 9999)
    return f"{now.strftime('%y%m%d_%H%M%S')}_{millis:03d}_{suffix}"
594
+
595
+
596
def write_PIL_image_with_png_info(image, metadata, path):
    """Save a PIL image as PNG with each metadata item embedded as a text chunk; returns the image."""
    from PIL.PngImagePlugin import PngInfo

    info = PngInfo()
    for key, value in metadata.items():
        info.add_text(key, value)

    image.save(path, "PNG", pnginfo=info)
    return image
605
+
606
+
607
def torch_safe_save(content, path):
    """Atomic torch.save: write to a sibling temp file, then rename over the target path."""
    tmp_path = path + "_tmp"
    torch.save(content, tmp_path)
    os.replace(tmp_path, path)
    return path
611
+
612
+
613
def move_optimizer_to_device(optimizer, device):
    """Move every tensor in the optimizer's state (e.g. momentum buffers) to `device` in place."""
    for state in optimizer.state.values():
        for key, value in state.items():
            if isinstance(value, torch.Tensor):
                state[key] = value.to(device)
frame_pack/wrapper.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def append_dims(x, target_dims):
    """Append trailing singleton dims to x until it has target_dims dims."""
    missing = target_dims - x.ndim
    return x[(Ellipsis,) + (None,) * missing]
6
+
7
+
8
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=1.0):
    """Rescale CFG output toward the text prediction's per-sample std; 0 disables, 1 is full rescale."""
    if guidance_rescale == 0:
        return noise_cfg

    reduce_dims = list(range(1, noise_pred_text.ndim))
    std_ratio = noise_pred_text.std(dim=reduce_dims, keepdim=True) / noise_cfg.std(dim=reduce_dims, keepdim=True)
    rescaled = noise_cfg * std_ratio
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * noise_cfg
17
+
18
+
19
def fm_wrapper(transformer, t_scale=1000.0):
    """Wrap `transformer` into a k-diffusion-style denoiser that returns x0 predictions.

    The returned callable applies classifier-free guidance (and optional
    guidance rescaling) and converts the flow prediction to an x0 estimate.
    """

    def k_model(x, sigma, **extra_args):
        work_dtype = extra_args['dtype']
        cfg_scale = extra_args['cfg_scale']
        cfg_rescale = extra_args['cfg_rescale']
        concat_latent = extra_args['concat_latent']

        input_dtype = x.dtype
        sigma = sigma.float()

        x = x.to(work_dtype)
        timestep = (sigma * t_scale).to(work_dtype)

        if concat_latent is None:
            hidden_states = x
        else:
            hidden_states = torch.cat([x, concat_latent.to(x)], dim=1)

        pred_positive = transformer(
            hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['positive']
        )[0].float()

        if cfg_scale == 1.0:
            # CFG disabled: the negative branch cancels out, so skip the second pass.
            pred_negative = torch.zeros_like(pred_positive)
        else:
            pred_negative = transformer(
                hidden_states=hidden_states, timestep=timestep, return_dict=False, **extra_args['negative']
            )[0].float()

        pred_cfg = pred_negative + cfg_scale * (pred_positive - pred_negative)
        pred = rescale_noise_cfg(pred_cfg, pred_positive, guidance_rescale=cfg_rescale)

        # Flow matching: x0 = x - sigma * v.
        x0 = x.float() - pred.float() * append_dims(sigma, x.ndim)

        return x0.to(dtype=input_dtype)

    return k_model
framepack_edit_output/framepack-edit-lora-000001.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5a6478224e15dd49359bb791f4d1984d4f87b2b69e858784a32266d4a9b270c
3
+ size 275426304
framepack_edit_output/framepack-edit-lora-000002.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0b24eefda91054ca54f70c9d50eb2df47a1954c4ddf2f3f12078d67e8a97a767
3
+ size 275426304
framepack_edit_output/framepack-edit-lora-000003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9c0e9747f651655dd95dd13f1c2999662a48dc3e89c84537ec7dc88ec1b307f
3
+ size 275426304
framepack_edit_output/framepack-edit-lora-000004.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7456a21c9cbf4bcf4ddcf2e8aacff7d90dc96c1dbaa1f802bb32dbd9e38bbb9b
3
+ size 275426304
framepack_edit_output/framepack-edit-lora-000005.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e1d152090aafda957a8ab146ab183aed9b2ddeed70e9fc003163d59024f7e3d6
3
+ size 275426304
framepack_edit_output/framepack-edit-lora-000006.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:30bed80789e6ea6d3b5749e86b299889a9d3282758862a5cacf69c66f49c89a5
3
+ size 275426304
hunyuan_model/__init__.py ADDED
File without changes
hunyuan_model/activation_layers.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
def get_activation_layer(act_type):
    """Look up an activation-layer factory by name.

    Args:
        act_type (str): one of "gelu", "gelu_tanh", "relu", "silu".

    Returns:
        A zero-argument callable that constructs the corresponding ``nn.Module``.

    Raises:
        ValueError: if ``act_type`` is not a known activation name.
    """
    factories = {
        "gelu": lambda: nn.GELU(),
        # Approximate `tanh` requires torch >= 1.13
        "gelu_tanh": lambda: nn.GELU(approximate="tanh"),
        "relu": nn.ReLU,
        "silu": nn.SiLU,
    }
    if act_type not in factories:
        raise ValueError(f"Unknown activation type: {act_type}")
    return factories[act_type]
hunyuan_model/attention.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.metadata
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ try:
9
+ import flash_attn
10
+ from flash_attn.flash_attn_interface import _flash_attn_forward
11
+ from flash_attn.flash_attn_interface import flash_attn_varlen_func
12
+ from flash_attn.flash_attn_interface import flash_attn_func
13
+ except ImportError:
14
+ flash_attn = None
15
+ flash_attn_varlen_func = None
16
+ _flash_attn_forward = None
17
+ flash_attn_func = None
18
+
19
+ try:
20
+ print(f"Trying to import sageattention")
21
+ from sageattention import sageattn_varlen, sageattn
22
+
23
+ print("Successfully imported sageattention")
24
+ except ImportError:
25
+ print(f"Failed to import sageattention")
26
+ sageattn_varlen = None
27
+ sageattn = None
28
+
29
+ try:
30
+ import xformers.ops as xops
31
+ except ImportError:
32
+ xops = None
33
+
34
# Per-backend (pre_layout, post_layout) transforms applied to q/k/v before the
# attention call and to its output afterwards:
# - "flash"/"sageattn" (varlen): flatten batch and sequence into one dimension.
# - "torch"/"vanilla"/"sageattn_fixlen": swap dims 1 and 2
#   ([b, s, a, d] <-> [b, a, s, d]).
# - "flash_fixlen"/"xformers": identity (inputs stay [b, s, a, d]).
MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "flash_fixlen": (
        lambda x: x,
        lambda x: x,
    ),
    "sageattn": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "sageattn_fixlen": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "xformers": (
        lambda x: x,
        lambda x: x,
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
64
+
65
+
66
def get_cu_seqlens(text_mask, img_len):
    """Calculate cu_seqlens_q, cu_seqlens_kv using text_mask and img_len

    Args:
        text_mask (torch.Tensor): the mask of text, shape (batch, text_seq_len)
        img_len (int): the length of image

    Returns:
        torch.Tensor: int32 cumulative sequence lengths for flash attention,
        shape (2 * batch + 1,), on the same device as ``text_mask``.
    """
    batch_size = text_mask.shape[0]
    text_len = text_mask.sum(dim=1)
    max_len = text_mask.shape[1] + img_len

    # Allocate on the mask's device. The previous hard-coded device="cuda"
    # broke CPU-only runs and multi-device setups.
    cu_seqlens = torch.zeros([2 * batch_size + 1], dtype=torch.int32, device=text_mask.device)

    for i in range(batch_size):
        s = text_len[i] + img_len  # number of valid (image + text) tokens in sample i
        s1 = i * max_len + s       # end offset of the valid segment
        s2 = (i + 1) * max_len     # end offset of the padded segment
        cu_seqlens[2 * i + 1] = s1
        cu_seqlens[2 * i + 2] = s2

    return cu_seqlens
90
+
91
+
92
def attention(
    q_or_qkv_list,
    k=None,
    v=None,
    mode="flash",
    drop_rate=0,
    attn_mask=None,
    total_len=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self attention with a selectable backend.

    Args:
        q_or_qkv_list: Query tensor with shape [b, s, a, d] (a = number of heads),
            or a list [q, k, v]; a list is cleared in place so the caller's
            references are released as early as possible.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d]
        v (torch.Tensor): Value tensor with shape [b, s1, a, d]
        mode (str): Attention backend: 'flash', 'torch', 'vanilla', 'sageattn',
            'xformers' (the '*_fixlen' variants are selected automatically when
            total_len is given).
        drop_rate (float): Dropout rate in attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
            (default: None)
        total_len: per-sample valid lengths; when given, padded tails are trimmed
            before attention and zero-padded back afterwards ("split" path).
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into q (varlen backends).
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
            used to index into kv (varlen backends).
        max_seqlen_q (int): The maximum sequence length in the batch of q.
        max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
        batch_size (int): batch size used to un-flatten varlen backend outputs.

    Returns:
        torch.Tensor: Output tensor after self attention with shape [b, s, a*d]
    """
    q, k, v = q_or_qkv_list if type(q_or_qkv_list) == list else (q_or_qkv_list, k, v)
    if type(q_or_qkv_list) == list:
        q_or_qkv_list.clear()
    split_attn = total_len is not None
    # Varlen backends switch to their per-sample fixed-length variants when
    # sequences are trimmed individually.
    if split_attn and mode == "sageattn":
        mode = "sageattn_fixlen"
    elif split_attn and mode == "flash":
        mode = "flash_fixlen"
    # print(f"Attention mode: {mode}, split_attn: {split_attn}")
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # trim the sequence length to the actual length instead of attn_mask
    if split_attn:
        trimmed_len = q.shape[1] - total_len
        q = [q[i : i + 1, : total_len[i]] for i in range(len(q))]
        k = [k[i : i + 1, : total_len[i]] for i in range(len(k))]
        v = [v[i : i + 1, : total_len[i]] for i in range(len(v))]
        q = [pre_attn_layout(q_i) for q_i in q]
        k = [pre_attn_layout(k_i) for k_i in k]
        v = [pre_attn_layout(v_i) for v_i in v]
        # print(
        #     f"Trimming the sequence length to {total_len},trimmed_len: {trimmed_len}, q.shape: {[q_i.shape for q_i in q]}, mode: {mode}"
        # )
    else:
        q = pre_attn_layout(q)
        k = pre_attn_layout(k)
        v = pre_attn_layout(v)

    if mode == "torch":
        if split_attn:
            x = []
            for i in range(len(q)):
                x_i = F.scaled_dot_product_attention(q[i], k[i], v[i], dropout_p=drop_rate, is_causal=causal)
                q[i], k[i], v[i] = None, None, None
                x.append(x_i)
            del q, k, v
        else:
            # SDPA requires a floating mask in q's dtype (or bool).
            if attn_mask is not None and attn_mask.dtype != torch.bool:
                attn_mask = attn_mask.to(q.dtype)
            x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
            del q, k, v
        del attn_mask

    elif mode == "xformers":
        # B, M, H, K: M is the sequence length, H is the number of heads, K is the dimension of the heads -> it is same as input dimension
        # currently only support batch_size = 1
        assert split_attn, "Xformers only supports splitting"
        x = []
        for i in range(len(q)):
            x_i = xops.memory_efficient_attention(q[i], k[i], v[i], p=drop_rate)  # , causal=causal)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v

    elif mode == "flash":
        x = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "flash_fixlen":
        x = []
        for i in range(len(q)):
            # q: (batch_size, seqlen, nheads, headdim), k: (batch_size, seqlen, nheads_k, headdim), v: (batch_size, seqlen, nheads_k, headdim)
            x_i = flash_attn_func(q[i], k[i], v[i], dropout_p=drop_rate, causal=causal)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v
    elif mode == "sageattn":
        x = sageattn_varlen(q, k, v, cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv)
        del q, k, v
        # x with shape [(bxs), a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # reshape x to [b, s, a, d]
    elif mode == "sageattn_fixlen":
        x = []
        for i in range(len(q)):
            # HND seems to cause an error
            x_i = sageattn(q[i], k[i], v[i])  # (batch_size, seq_len, head_num, head_dim)
            q[i], k[i], v[i] = None, None, None
            x.append(x_i)
        del q, k, v
    elif mode == "vanilla":
        assert not split_attn, "Vanilla attention does not support trimming"
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applied to self attention
            assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            # NOTE(review): `.to` is not in-place, so this line has no effect —
            # attn_bias is already in q's dtype; confirm before removing.
            attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to be float32 to avoid numerical overflow
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    if split_attn:
        x = [post_attn_layout(x_i) for x_i in x]
        for i in range(len(x)):
            # Zero-pad each sample back to the original padded sequence length.
            x[i] = F.pad(x[i], (0, 0, 0, 0, 0, trimmed_len[i]))
        x = torch.cat(x, dim=0)
    else:
        x = post_attn_layout(x)

    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
249
+
250
+
251
def parallel_attention(hybrid_seq_parallel_attn, q, k, v, img_q_len, img_kv_len, cu_seqlens_q, cu_seqlens_kv):
    """Sequence-parallel attention over image + first text segment, followed by
    plain flash attention over the remaining text tokens; the two results are
    concatenated along the sequence dimension and flattened to [b, s, a*d].

    Args:
        hybrid_seq_parallel_attn: callable implementing the hybrid
            sequence-parallel attention (joint query/key/value "rear" strategy).
        q, k, v (torch.Tensor): [b, s, a, d] projections.
        img_q_len, img_kv_len (int): number of image tokens in q and k/v.
        cu_seqlens_q, cu_seqlens_kv: cumulative sequence lengths; index [1]
            marks the end of the first (image + text) segment.
    """
    attn1 = hybrid_seq_parallel_attn(
        None,
        q[:, :img_q_len, :, :],
        k[:, :img_kv_len, :, :],
        v[:, :img_kv_len, :, :],
        dropout_p=0.0,
        causal=False,
        joint_tensor_query=q[:, img_q_len : cu_seqlens_q[1]],
        joint_tensor_key=k[:, img_kv_len : cu_seqlens_kv[1]],
        joint_tensor_value=v[:, img_kv_len : cu_seqlens_kv[1]],
        joint_strategy="rear",
    )
    # flash-attn >= 2.7.0 split `window_size` into window_size_left/right.
    # Compare version components numerically: the previous string comparison
    # (`__version__ >= "2.7.0"`) misorders releases such as "2.10.0".
    try:
        fa_version = tuple(int(p) for p in flash_attn.__version__.split(".")[:2])
    except (AttributeError, ValueError):
        fa_version = (0, 0)
    if fa_version >= (2, 7):
        attn2, *_ = _flash_attn_forward(
            q[:, cu_seqlens_q[1] :],
            k[:, cu_seqlens_kv[1] :],
            v[:, cu_seqlens_kv[1] :],
            dropout_p=0.0,
            softmax_scale=q.shape[-1] ** (-0.5),
            causal=False,
            window_size_left=-1,
            window_size_right=-1,
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
    else:
        attn2, *_ = _flash_attn_forward(
            q[:, cu_seqlens_q[1] :],
            k[:, cu_seqlens_kv[1] :],
            v[:, cu_seqlens_kv[1] :],
            dropout_p=0.0,
            softmax_scale=q.shape[-1] ** (-0.5),
            causal=False,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            return_softmax=False,
        )
    attn = torch.cat([attn1, attn2], dim=1)
    b, s, a, d = attn.shape
    attn = attn.reshape(b, s, -1)

    return attn
hunyuan_model/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ # try:
28
+ # # This diffusers is modified and packed in the mirror.
29
+ # from diffusers.loaders import FromOriginalVAEMixin
30
+ # except ImportError:
31
+ # # Use this to be compatible with the original diffusers.
32
+ # from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
@dataclass
class DecoderOutput2(BaseOutput):
    """Decoder output carrying the decoded sample and, optionally, the
    encoding posterior (for encode-then-decode workflows)."""

    # Decoded pixel-space sample.
    sample: torch.FloatTensor
    # Latent posterior, populated only when the caller also encoded.
    posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
        up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        sample_tsize: int = 64,
        scaling_factor: float = 0.18215,
        force_upcast: float = True,
        spatial_compression_ratio: int = 8,
        time_compression_ratio: int = 4,
        mid_block_add_attention: bool = True,
    ):
        """Build the causal-3D VAE: encoder, decoder, and 1x1x1 quant convs.

        `sample_size`/`sample_tsize` set the spatial/temporal tile sizes used
        when tiling is enabled. `scaling_factor` and `force_upcast` are stored
        on the config by @register_to_config but not read in this method.
        """
        super().__init__()

        self.time_compression_ratio = time_compression_ratio

        self.encoder = EncoderCausal3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=True,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        self.decoder = DecoderCausal3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            norm_num_groups=norm_num_groups,
            act_fn=act_fn,
            time_compression_ratio=time_compression_ratio,
            spatial_compression_ratio=spatial_compression_ratio,
            mid_block_add_attention=mid_block_add_attention,
        )

        # 1x1x1 convs mapping encoder moments / latents around the bottleneck.
        self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
        self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)

        self.use_slicing = False
        self.use_spatial_tiling = False
        self.use_temporal_tiling = False

        # only relevant if vae tiling is enabled
        self.tile_sample_min_tsize = sample_tsize
        self.tile_latent_min_tsize = sample_tsize // time_compression_ratio

        self.tile_sample_min_size = self.config.sample_size
        # sample_size may be a scalar or a (H, W) pair; tile sizing uses the first entry.
        sample_size = self.config.sample_size[0] if isinstance(self.config.sample_size, (list, tuple)) else self.config.sample_size
        self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
        self.tile_overlap_factor = 0.25
129
+
130
    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle gradient checkpointing on the encoder/decoder submodules only.
        if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
            module.gradient_checkpointing = value
133
+
134
    def enable_temporal_tiling(self, use_tiling: bool = True):
        """Enable (or disable) tiling along the temporal axis."""
        self.use_temporal_tiling = use_tiling
136
+
137
    def disable_temporal_tiling(self):
        """Disable tiling along the temporal axis."""
        self.enable_temporal_tiling(False)
139
+
140
    def enable_spatial_tiling(self, use_tiling: bool = True):
        """Enable (or disable) tiling along the spatial axes."""
        self.use_spatial_tiling = use_tiling
142
+
143
    def disable_spatial_tiling(self):
        """Disable tiling along the spatial axes."""
        self.enable_spatial_tiling(False)
145
+
146
    def enable_tiling(self, use_tiling: bool = True):
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
        processing larger videos.

        Enables both spatial and temporal tiling.
        """
        self.enable_spatial_tiling(use_tiling)
        self.enable_temporal_tiling(use_tiling)
154
+
155
    def disable_tiling(self):
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.

        Disables both spatial and temporal tiling.
        """
        self.disable_spatial_tiling()
        self.disable_temporal_tiling()
162
+
163
    def enable_slicing(self):
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True
169
+
170
    def disable_slicing(self):
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False
176
+
177
    def set_chunk_size_for_causal_conv_3d(self, chunk_size: int):
        # set chunk_size to CausalConv3d recursively
        def set_chunk_size(module):
            # Duck-typed: any submodule exposing a `chunk_size` attribute is updated.
            if hasattr(module, "chunk_size"):
                module.chunk_size = chunk_size

        self.apply(set_chunk_size)
184
+
185
    @property
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            # Collect any module exposing `get_processor` under its dotted path.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors
209
+
210
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        # A dict must supply exactly one processor per attention layer.
        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    # Dict entries are consumed (popped) as they are assigned.
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
244
+
245
    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        # Only homogeneous processor families can be reset in one call.
        if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnAddedKVProcessor()
        elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor, _remove_lora=True)
260
+
261
    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """
        Encode a batch of images/videos into latents.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos, shape (B, C, T, H, W).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
            The latent representations of the encoded images/videos. If `return_dict` is True, a
            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        assert len(x.shape) == 5, "The input tensor should have 5 dimensions."

        # Tiled paths take precedence: temporal tiling first, then pure
        # spatial tiling, before falling through to the direct encoder.
        if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
            return self.temporal_tiled_encode(x, return_dict=return_dict)

        if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
            return self.spatial_tiled_encode(x, return_dict=return_dict)

        # Batch slicing trades speed for memory: encode one sample at a time.
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self.encoder(x)

        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)

        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)
298
+
299
    def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        # Internal decode of a single (possibly tiled) latent batch; the
        # public `decode` wraps this with optional batch slicing.
        assert len(z.shape) == 5, "The input tensor should have 5 dimensions."

        if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
            return self.temporal_tiled_decode(z, return_dict=return_dict)

        if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
            return self.spatial_tiled_decode(z, return_dict=return_dict)

        z = self.post_quant_conv(z)
        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
315
+
316
    @apply_forward_hook
    def decode(self, z: torch.FloatTensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.FloatTensor]:
        """
        Decode a batch of images/videos.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
            generator: accepted for API compatibility; not used in this method.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.

        """
        # Batch slicing trades speed for memory: decode one sample at a time.
        if self.use_slicing and z.shape[0] > 1:
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
    def spatial_tiled_encode(
        self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False
    ) -> AutoencoderKLOutput:
        r"""Encode a batch of images/videos using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
        different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
        output, but they should be much less noticeable.

        Args:
            x (`torch.FloatTensor`): Input batch of images/videos.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
            return_moments (`bool`, *optional*, defaults to `False`):
                When True, return the raw blended moments tensor instead of a
                posterior (used by the temporal tiling path).

        Returns:
            [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """
        # Tiles overlap by tile_overlap_factor in sample space; the blend
        # extent and crop limit are expressed in latent space.
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split video into tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[-2], overlap_size):
            row = []
            for j in range(0, x.shape[-1], overlap_size):
                tile = x[:, :, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        moments = torch.cat(result_rows, dim=-2)
        if return_moments:
            return moments

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)
418
+
419
    def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Decode a batch of images/videos using a tiled decoder.

        Args:
            z (`torch.FloatTensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        # Mirror of spatial_tiled_encode: tiles step in latent space and the
        # blend extent / crop limit are expressed in sample space.
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[-2], overlap_size):
            row = []
            for j in range(0, z.shape[-1], overlap_size):
                tile = z[:, :, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=-2)
        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
466
+
467
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
468
+
469
+ B, C, T, H, W = x.shape
470
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
471
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
472
+ t_limit = self.tile_latent_min_tsize - blend_extent
473
+
474
+ # Split the video into tiles and encode them separately.
475
+ row = []
476
+ for i in range(0, T, overlap_size):
477
+ tile = x[:, :, i : i + self.tile_sample_min_tsize + 1, :, :]
478
+ if self.use_spatial_tiling and (
479
+ tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size
480
+ ):
481
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
482
+ else:
483
+ tile = self.encoder(tile)
484
+ tile = self.quant_conv(tile)
485
+ if i > 0:
486
+ tile = tile[:, :, 1:, :, :]
487
+ row.append(tile)
488
+ result_row = []
489
+ for i, tile in enumerate(row):
490
+ if i > 0:
491
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
492
+ result_row.append(tile[:, :, :t_limit, :, :])
493
+ else:
494
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
495
+
496
+ moments = torch.cat(result_row, dim=2)
497
+ posterior = DiagonalGaussianDistribution(moments)
498
+
499
+ if not return_dict:
500
+ return (posterior,)
501
+
502
+ return AutoencoderKLOutput(latent_dist=posterior)
503
+
504
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
505
+ # Split z into overlapping tiles and decode them separately.
506
+
507
+ B, C, T, H, W = z.shape
508
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
509
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
510
+ t_limit = self.tile_sample_min_tsize - blend_extent
511
+
512
+ row = []
513
+ for i in range(0, T, overlap_size):
514
+ tile = z[:, :, i : i + self.tile_latent_min_tsize + 1, :, :]
515
+ if self.use_spatial_tiling and (
516
+ tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size
517
+ ):
518
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
519
+ else:
520
+ tile = self.post_quant_conv(tile)
521
+ decoded = self.decoder(tile)
522
+ if i > 0:
523
+ decoded = decoded[:, :, 1:, :, :]
524
+ row.append(decoded)
525
+ result_row = []
526
+ for i, tile in enumerate(row):
527
+ if i > 0:
528
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
529
+ result_row.append(tile[:, :, :t_limit, :, :])
530
+ else:
531
+ result_row.append(tile[:, :, : t_limit + 1, :, :])
532
+
533
+ dec = torch.cat(result_row, dim=2)
534
+ if not return_dict:
535
+ return (dec,)
536
+
537
+ return DecoderOutput(sample=dec)
538
+
539
+ def forward(
540
+ self,
541
+ sample: torch.FloatTensor,
542
+ sample_posterior: bool = False,
543
+ return_dict: bool = True,
544
+ return_posterior: bool = False,
545
+ generator: Optional[torch.Generator] = None,
546
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
547
+ r"""
548
+ Args:
549
+ sample (`torch.FloatTensor`): Input sample.
550
+ sample_posterior (`bool`, *optional*, defaults to `False`):
551
+ Whether to sample from the posterior.
552
+ return_dict (`bool`, *optional*, defaults to `True`):
553
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
554
+ """
555
+ x = sample
556
+ posterior = self.encode(x).latent_dist
557
+ if sample_posterior:
558
+ z = posterior.sample(generator=generator)
559
+ else:
560
+ z = posterior.mode()
561
+ dec = self.decode(z).sample
562
+
563
+ if not return_dict:
564
+ if return_posterior:
565
+ return (dec, posterior)
566
+ else:
567
+ return (dec,)
568
+ if return_posterior:
569
+ return DecoderOutput2(sample=dec, posterior=posterior)
570
+ else:
571
+ return DecoderOutput2(sample=dec)
572
+
573
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
574
+ def fuse_qkv_projections(self):
575
+ """
576
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
577
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
578
+
579
+ <Tip warning={true}>
580
+
581
+ This API is 🧪 experimental.
582
+
583
+ </Tip>
584
+ """
585
+ self.original_attn_processors = None
586
+
587
+ for _, attn_processor in self.attn_processors.items():
588
+ if "Added" in str(attn_processor.__class__.__name__):
589
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
590
+
591
+ self.original_attn_processors = self.attn_processors
592
+
593
+ for module in self.modules():
594
+ if isinstance(module, Attention):
595
+ module.fuse_projections(fuse=True)
596
+
597
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
598
+ def unfuse_qkv_projections(self):
599
+ """Disables the fused QKV projection if enabled.
600
+
601
+ <Tip warning={true}>
602
+
603
+ This API is 🧪 experimental.
604
+
605
+ </Tip>
606
+
607
+ """
608
+ if self.original_attn_processors is not None:
609
+ self.set_attn_processor(self.original_attn_processors)
hunyuan_model/embed_layers.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import math
3
+ import torch
4
+ import torch.nn as nn
5
+ from einops import rearrange, repeat
6
+
7
+ from .helpers import to_2tuple
8
+
9
class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding

    Patchifies an input via a strided convolution and projects each patch to
    ``embed_dim``. Based on the impl in
    https://github.com/google-research/vision_transformer (hacked together by
    / Copyright 2020 Ross Wightman). The `_assert` in the forward pass was
    removed so multi-resolution inputs work.

    NOTE(review): ``proj`` is a Conv3d while ``patch_size`` goes through
    ``to_2tuple`` — callers presumably pass a 3-tuple (t, h, w), which
    ``to_2tuple`` passes through unchanged; confirm at call sites.
    """

    def __init__(
        self,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        patch_size = to_2tuple(patch_size)
        self.patch_size = patch_size
        self.flatten = flatten

        # Non-overlapping patches: kernel size equals stride.
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias, **factory_kwargs)
        # Xavier-uniform on the flattened kernel, zero bias.
        nn.init.xavier_uniform_(self.proj.weight.view(self.proj.weight.size(0), -1))
        if bias:
            nn.init.zeros_(self.proj.bias)

        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        if self.flatten:
            # (B, C, T, H, W) -> (B, N, C): one token per patch.
            x = x.flatten(2).transpose(1, 2)
        return self.norm(x)
53
+
54
+
55
class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        # Linear -> activation -> Linear, both layers sized to hidden_size.
        self.linear_1 = nn.Linear(in_features=in_channels, out_features=hidden_size, bias=True, **factory_kwargs)
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True, **factory_kwargs)

    def forward(self, caption):
        return self.linear_2(self.act_1(self.linear_1(caption)))
74
+
75
+
76
def timestep_embedding(t, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    Args:
        t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
        dim (int): the dimension of the output.
        max_period (int): controls the minimum frequency of the embeddings.

    Returns:
        embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

    .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
    """
    half_dim = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
    frequencies = torch.exp(-math.log(max_period) * exponents).to(device=t.device)
    angles = t[:, None].float() * frequencies[None]
    emb = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd dims get one zero column so the output width is exactly `dim`.
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb
97
+
98
+
99
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations via a sinusoidal
    frequency embedding followed by a two-layer MLP.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        out_size = hidden_size if out_size is None else out_size

        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True, **factory_kwargs),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        # Small normal init for both linear layers (DiT convention).
        nn.init.normal_(self.mlp[0].weight, std=0.02)
        nn.init.normal_(self.mlp[2].weight, std=0.02)

    def forward(self, t):
        # Cast the sinusoidal features to the MLP's parameter dtype before projecting.
        freq_features = timestep_embedding(t, self.frequency_embedding_size, self.max_period)
        return self.mlp(freq_features.type(self.mlp[0].weight.dtype))
hunyuan_model/fp8_optimization.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #based on ComfyUI's and MinusZoneAI's fp8_linear optimization
2
+ #further borrowed from HunyuanVideoWrapper for Musubi Tuner
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ def fp8_linear_forward(cls, original_dtype, input):
7
+ weight_dtype = cls.weight.dtype
8
+ if weight_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
9
+ if len(input.shape) == 3:
10
+ target_dtype = torch.float8_e5m2 if weight_dtype == torch.float8_e4m3fn else torch.float8_e4m3fn
11
+ inn = input.reshape(-1, input.shape[2]).to(target_dtype)
12
+ w = cls.weight.t()
13
+
14
+ scale = torch.ones((1), device=input.device, dtype=torch.float32)
15
+ bias = cls.bias.to(original_dtype) if cls.bias is not None else None
16
+
17
+ if bias is not None:
18
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, bias=bias, scale_a=scale, scale_b=scale)
19
+ else:
20
+ o = torch._scaled_mm(inn, w, out_dtype=original_dtype, scale_a=scale, scale_b=scale)
21
+
22
+ if isinstance(o, tuple):
23
+ o = o[0]
24
+
25
+ return o.reshape((-1, input.shape[1], cls.weight.shape[0]))
26
+ else:
27
+ return cls.original_forward(input.to(original_dtype))
28
+ else:
29
+ return cls.original_forward(input)
30
+
31
def convert_fp8_linear(module, original_dtype, params_to_keep=()):
    """Patch every ``nn.Linear`` under ``module`` to route through ``fp8_linear_forward``.

    Each patched layer keeps its original forward as ``original_forward`` so the
    fp8 path can fall back to it. Marks the root with ``fp8_matmul_enabled``.

    Args:
        module: root module to patch in place.
        original_dtype: activation dtype used by the fp8 matmul fallback/output.
        params_to_keep: substrings; submodules whose qualified name contains any
            of them are left untouched. (Was a mutable default ``{}``; an empty
            tuple is behaviorally identical and safe.)
    """
    setattr(module, "fp8_matmul_enabled", True)

    # Fix: the original loop reused the name `module` for submodules, shadowing
    # the root parameter — keep root and submodule distinct.
    for name, submodule in module.named_modules():
        if any(keyword in name for keyword in params_to_keep):
            continue  # explicitly excluded from fp8 patching
        if isinstance(submodule, nn.Linear):
            setattr(submodule, "original_forward", submodule.forward)
            # Bind the submodule via a default arg to avoid late-binding closures.
            setattr(submodule, "forward", lambda input, m=submodule: fp8_linear_forward(m, original_dtype, input))
hunyuan_model/helpers.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections.abc
2
+
3
+ from itertools import repeat
4
+
5
+
6
def _ntuple(n):
    """Return a converter that coerces a value into an n-tuple.

    A non-string iterable is converted to a tuple; if it has exactly one
    element, that element is repeated n times. A scalar (or string) is
    repeated n times. Note an iterable of length != 1 is returned as-is —
    its length is not validated against n.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            x = tuple(x)
            if len(x) == 1:
                x = tuple(repeat(x[0], n))
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
21
+
22
+
23
def as_tuple(x):
    """Coerce *x* into a tuple.

    Non-string iterables become ``tuple(x)``; ``None`` and scalar
    int/float/str become a 1-tuple; anything else raises ``ValueError``.
    """
    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
        return tuple(x)
    if x is None or isinstance(x, (int, float, str)):
        return (x,)
    raise ValueError(f"Unknown type {type(x)}")
30
+
31
+
32
def as_list_of_2tuple(x):
    """Coerce *x* into a list of 2-tuples by pairing consecutive elements.

    A single value is duplicated into one pair; otherwise the flattened
    tuple must have even length.
    """
    x = as_tuple(x)
    if len(x) == 1:
        x = (x[0], x[0])
    assert len(x) % 2 == 0, f"Expect even length, got {len(x)}."
    return [(x[i], x[i + 1]) for i in range(0, len(x), 2)]
hunyuan_model/mlp_layers.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified from timm library:
2
+ # https://github.com/huggingface/pytorch-image-models/blob/648aaa41233ba83eb38faf5ba9d415d574823241/timm/layers/mlp.py#L13
3
+
4
+ from functools import partial
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+
9
+ from .modulate_layers import modulate
10
+ from .helpers import to_2tuple
11
+
12
+
13
class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Layer order follows timm: fc1 -> act -> drop -> norm -> fc2 -> drop.
    ``bias`` and ``drop`` may each be a single value (applied to both layers)
    or a per-layer pair.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        # 1x1 convolutions act as per-position linears when use_conv is set.
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_channels, hidden_channels, bias=bias[0], **factory_kwargs)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_channels, **factory_kwargs) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_channels, out_features, bias=bias[1], **factory_kwargs)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.drop1(self.act(self.fc1(x)))
        x = self.norm(x)
        return self.drop2(self.fc2(x))
60
+
61
+
62
+ #
63
class MLPEmbedder(nn.Module):
    """Two-layer SiLU MLP embedder.

    copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
    """

    def __init__(self, in_dim: int, hidden_dim: int, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True, **factory_kwargs)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True, **factory_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
74
+
75
+
76
class FinalLayer(nn.Module):
    """The final layer of DiT.

    An adaLN-modulated LayerNorm followed by a zero-initialized linear
    projection from ``hidden_size`` to the per-patch pixel count.
    """

    def __init__(
        self, hidden_size, patch_size, out_channels, act_layer, device=None, dtype=None
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        # Just use LayerNorm for the final layer
        self.norm_final = nn.LayerNorm(
            hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs
        )

        # Output width: one value per pixel per patch element.
        # patch_size may be a scalar (square 2D patch) or a 3-tuple (t, h, w).
        if isinstance(patch_size, int):
            out_dim = patch_size * patch_size * out_channels
        else:
            out_dim = patch_size[0] * patch_size[1] * patch_size[2] * out_channels
        # Fix: the original omitted **factory_kwargs on the tuple branch, so
        # device/dtype were silently dropped there; pass them on both paths.
        self.linear = nn.Linear(hidden_size, out_dim, bias=True, **factory_kwargs)
        # Zero-init so the block contributes nothing at the start of training.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

        # Here we don't distinguish between the modulate types. Just use the simple one.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(self, x, c):
        # Predict (shift, scale) from conditioning `c`, modulate, then project.
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x = modulate(self.norm_final(x), shift=shift, scale=scale)
        return self.linear(x)
hunyuan_model/models.py ADDED
@@ -0,0 +1,1044 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Any, List, Tuple, Optional, Union, Dict
3
+ import accelerate
4
+ from einops import rearrange
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.checkpoint import checkpoint
9
+
10
+ from .activation_layers import get_activation_layer
11
+ from .norm_layers import get_norm_layer
12
+ from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection
13
+ from .attention import attention, parallel_attention, get_cu_seqlens
14
+ from .posemb_layers import apply_rotary_emb
15
+ from .mlp_layers import MLP, MLPEmbedder, FinalLayer
16
+ from .modulate_layers import ModulateDiT, modulate, apply_gate
17
+ from .token_refiner import SingleTokenRefiner
18
+ from modules.custom_offloading_utils import ModelOffloader, synchronize_device, clean_memory_on_device
19
+ from hunyuan_model.posemb_layers import get_nd_rotary_pos_embed
20
+
21
+ from utils.safetensors_utils import MemoryEfficientSafeOpen
22
+
23
+
24
class MMDoubleStreamBlock(nn.Module):
    """
    A multimodal DiT block with separate modulation for the text and the
    image/video stream; see (SD3) https://arxiv.org/abs/2403.03206 and
    (Flux.1) https://github.com/black-forest-labs/flux
    """

    def __init__(
        self,
        hidden_size: int,
        heads_num: int,
        mlp_width_ratio: float,
        mlp_act_type: str = "gelu_tanh",
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        qkv_bias: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attn_mode: str = "flash",
        split_attn: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.split_attn = split_attn

        self.deterministic = False
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
        qk_norm_layer = get_norm_layer(qk_norm_type)

        # ---- image/video stream ----
        self.img_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.img_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.img_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.img_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        # ---- text stream (mirrors the image stream) ----
        self.txt_mod = ModulateDiT(
            hidden_size,
            factor=6,
            act_layer=get_activation_layer("silu"),
            **factory_kwargs,
        )
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        self.txt_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.txt_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
        self.txt_mlp = MLP(
            hidden_size,
            mlp_hidden_dim,
            act_layer=get_activation_layer(mlp_act_type),
            bias=True,
            **factory_kwargs,
        )

        self.hybrid_seq_parallel_attn = None
        self.gradient_checkpointing = False

    def enable_deterministic(self):
        self.deterministic = True

    def disable_deterministic(self):
        self.deterministic = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        img: torch.Tensor,
        txt: torch.Tensor,
        vec: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        total_len: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: tuple = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Six modulation terms per stream: (shift, scale, gate) for attention
        # and for the MLP.
        (img_mod1_shift, img_mod1_scale, img_mod1_gate, img_mod2_shift, img_mod2_scale, img_mod2_gate) = self.img_mod(vec).chunk(
            6, dim=-1
        )
        (txt_mod1_shift, txt_mod1_scale, txt_mod1_gate, txt_mod2_shift, txt_mod2_scale, txt_mod2_gate) = self.txt_mod(vec).chunk(
            6, dim=-1
        )

        # Image stream: norm -> modulate -> QKV projection.
        # Intermediates are dropped (set to None) eagerly to keep peak memory low.
        img_modulated = modulate(self.img_norm1(img), shift=img_mod1_shift, scale=img_mod1_scale)
        img_qkv = self.img_attn_qkv(img_modulated)
        img_modulated = None
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        img_qkv = None
        # Apply QK-Norm if needed.
        img_q = self.img_attn_q_norm(img_q).to(img_v)
        img_k = self.img_attn_k_norm(img_k).to(img_v)

        # RoPE applies only to the image tokens.
        if freqs_cis is not None:
            img_q_shape = img_q.shape
            img_k_shape = img_k.shape
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            assert (
                img_q.shape == img_q_shape and img_k.shape == img_k_shape
            ), f"img_kk: {img_q.shape}, img_q: {img_q_shape}, img_kk: {img_k.shape}, img_k: {img_k_shape}"

        # Text stream: same norm -> modulate -> QKV pipeline.
        txt_modulated = modulate(self.txt_norm1(txt), shift=txt_mod1_shift, scale=txt_mod1_scale)
        txt_qkv = self.txt_attn_qkv(txt_modulated)
        txt_modulated = None
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        txt_qkv = None
        # Apply QK-Norm if needed.
        txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
        txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

        # Joint attention over the concatenated (image ++ text) sequence.
        img_q_len = img_q.shape[1]
        img_kv_len = img_k.shape[1]
        batch_size = img_k.shape[0]
        q = torch.cat((img_q, txt_q), dim=1)
        img_q = txt_q = None
        k = torch.cat((img_k, txt_k), dim=1)
        img_k = txt_k = None
        v = torch.cat((img_v, txt_v), dim=1)
        img_v = txt_v = None

        assert (
            cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
        ), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            qkv_list = [q, k, v]
            q = k = v = None
            attn = attention(
                qkv_list,
                mode=self.attn_mode,
                attn_mask=attn_mask,
                total_len=total_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=batch_size,
            )
        else:
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q_len,
                img_kv_len=img_kv_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )
        # attention computation end

        # Split the joint attention output back into the two streams.
        img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
        attn = None

        # Image stream residual updates: gated attention, then gated MLP.
        img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
        img_attn = None
        img = img + apply_gate(
            self.img_mlp(modulate(self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale)),
            gate=img_mod2_gate,
        )

        # Text stream residual updates: gated attention, then gated MLP.
        txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate)
        txt_attn = None
        txt = txt + apply_gate(
            self.txt_mlp(modulate(self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale)),
            gate=txt_mod2_gate,
        )

        return img, txt

    def forward(self, *args, **kwargs):
        # Route through torch.utils.checkpoint while training with gradient
        # checkpointing enabled; otherwise call the implementation directly.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        return self._forward(*args, **kwargs)
257
+
258
+
259
+ class MMSingleStreamBlock(nn.Module):
260
+ """
261
+ A DiT block with parallel linear layers as described in
262
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
263
+ Also refer to (SD3): https://arxiv.org/abs/2403.03206
264
+ (Flux.1): https://github.com/black-forest-labs/flux
265
+ """
266
+
267
+ def __init__(
268
+ self,
269
+ hidden_size: int,
270
+ heads_num: int,
271
+ mlp_width_ratio: float = 4.0,
272
+ mlp_act_type: str = "gelu_tanh",
273
+ qk_norm: bool = True,
274
+ qk_norm_type: str = "rms",
275
+ qk_scale: float = None,
276
+ dtype: Optional[torch.dtype] = None,
277
+ device: Optional[torch.device] = None,
278
+ attn_mode: str = "flash",
279
+ split_attn: bool = False,
280
+ ):
281
+ factory_kwargs = {"device": device, "dtype": dtype}
282
+ super().__init__()
283
+ self.attn_mode = attn_mode
284
+ self.split_attn = split_attn
285
+
286
+ self.deterministic = False
287
+ self.hidden_size = hidden_size
288
+ self.heads_num = heads_num
289
+ head_dim = hidden_size // heads_num
290
+ mlp_hidden_dim = int(hidden_size * mlp_width_ratio)
291
+ self.mlp_hidden_dim = mlp_hidden_dim
292
+ self.scale = qk_scale or head_dim**-0.5
293
+
294
+ # qkv and mlp_in
295
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs)
296
+ # proj and mlp_out
297
+ self.linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs)
298
+
299
+ qk_norm_layer = get_norm_layer(qk_norm_type)
300
+ self.q_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
301
+ self.k_norm = qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
302
+
303
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
304
+
305
+ self.mlp_act = get_activation_layer(mlp_act_type)()
306
+ self.modulation = ModulateDiT(hidden_size, factor=3, act_layer=get_activation_layer("silu"), **factory_kwargs)
307
+ self.hybrid_seq_parallel_attn = None
308
+
309
+ self.gradient_checkpointing = False
310
+
311
+ def enable_deterministic(self):
312
+ self.deterministic = True
313
+
314
+ def disable_deterministic(self):
315
+ self.deterministic = False
316
+
317
+ def enable_gradient_checkpointing(self):
318
+ self.gradient_checkpointing = True
319
+
320
+ def disable_gradient_checkpointing(self):
321
+ self.gradient_checkpointing = False
322
+
323
    def _forward(
        self,
        x: torch.Tensor,
        vec: torch.Tensor,
        txt_len: int,
        attn_mask: Optional[torch.Tensor] = None,
        total_len: Optional[torch.Tensor] = None,
        cu_seqlens_q: Optional[torch.Tensor] = None,
        cu_seqlens_kv: Optional[torch.Tensor] = None,
        max_seqlen_q: Optional[int] = None,
        max_seqlen_kv: Optional[int] = None,
        freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
    ) -> torch.Tensor:
        """Single-stream DiT block: one fused linear produces QKV + the MLP
        stream, attention and activation run in parallel, and a second linear
        merges both streams back into a gated residual update.

        Args:
            x: concatenated image+text token sequence; the text tokens occupy
                the trailing ``txt_len`` positions.
            vec: conditioning vector; projected to shift/scale/gate modulation.
            txt_len: number of trailing text tokens in ``x``.
            freqs_cis: optional (cos, sin) RoPE tables applied to the image
                tokens only.

        Returns:
            ``x`` plus the gated block output (same shape as ``x``).
        """
        # Modulated pre-norm, then a single fused projection for QKV and the MLP input.
        mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1)
        x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale)
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
        x_mod = None  # drop reference promptly to reduce peak memory
        # mlp = mlp.to("cpu", non_blocking=True)
        # clean_memory_on_device(x.device)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        qkv = None

        # Apply QK-Norm if needed.
        q = self.q_norm(q).to(v)
        k = self.k_norm(k).to(v)

        # Apply RoPE if needed (image tokens only; text tokens are passed through).
        if freqs_cis is not None:
            img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
            img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
            q = k = None
            img_q_shape = img_q.shape
            img_k_shape = img_k.shape
            img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
            # NOTE(review): the labels in this assert message look swapped/duplicated
            # (both "img_kk"/"img_q" print img_q). The condition itself is correct.
            assert (
                img_q.shape == img_q_shape and img_k_shape == img_k.shape
            ), f"img_kk: {img_q.shape}, img_q: {img_q.shape}, img_kk: {img_k.shape}, img_k: {img_k.shape}"
            # img_q, img_k = img_qq, img_kk
            # del img_qq, img_kk
            q = torch.cat((img_q, txt_q), dim=1)
            k = torch.cat((img_k, txt_k), dim=1)
            del img_q, txt_q, img_k, txt_k

        # Compute attention. cu_seqlens holds one (img, txt) pair of offsets per sample.
        assert cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1, f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, x.shape[0]:{x.shape[0]}"

        # attention computation start
        if not self.hybrid_seq_parallel_attn:
            # Hand q/k/v over as a list so `attention` can free them internally.
            l = [q, k, v]
            q = k = v = None
            attn = attention(
                l,
                mode=self.attn_mode,
                attn_mask=attn_mask,
                total_len=total_len,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_kv=max_seqlen_kv,
                batch_size=x.shape[0],
            )
        else:
            # NOTE(review): in this branch img_q/img_k are only defined (and are
            # deleted) inside the RoPE branch above — this path looks like it would
            # raise NameError if taken; confirm against the sequence-parallel setup.
            attn = parallel_attention(
                self.hybrid_seq_parallel_attn,
                q,
                k,
                v,
                img_q_len=img_q.shape[1],
                img_kv_len=img_k.shape[1],
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_kv,
            )
        # attention computation end

        # Compute activation in mlp stream, cat again and run second linear layer.
        # mlp = mlp.to(x.device)
        mlp = self.mlp_act(mlp)
        attn_mlp = torch.cat((attn, mlp), 2)
        attn = None
        mlp = None
        output = self.linear2(attn_mlp)
        attn_mlp = None
        return x + apply_gate(output, gate=mod_gate)
407
+
408
+ # def forward(
409
+ # self,
410
+ # x: torch.Tensor,
411
+ # vec: torch.Tensor,
412
+ # txt_len: int,
413
+ # attn_mask: Optional[torch.Tensor] = None,
414
+ # cu_seqlens_q: Optional[torch.Tensor] = None,
415
+ # cu_seqlens_kv: Optional[torch.Tensor] = None,
416
+ # max_seqlen_q: Optional[int] = None,
417
+ # max_seqlen_kv: Optional[int] = None,
418
+ # freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
419
+ # ) -> torch.Tensor:
420
+ def forward(self, *args, **kwargs):
421
+ if self.training and self.gradient_checkpointing:
422
+ return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
423
+ else:
424
+ return self._forward(*args, **kwargs)
425
+
426
+
427
class HYVideoDiffusionTransformer(nn.Module):  # ModelMixin, ConfigMixin):
    """
    HunyuanVideo Transformer backbone

    Inherited from ModelMixin and ConfigMixin for compatibility with diffusers' sampler StableDiffusionPipeline.

    Reference:
    [1] Flux.1: https://github.com/black-forest-labs/flux
    [2] MMDiT: http://arxiv.org/abs/2403.03206

    Parameters
    ----------
    args: argparse.Namespace
        The arguments parsed by argparse.
    patch_size: list
        The size of the patch.
    in_channels: int
        The number of input channels.
    out_channels: int
        The number of output channels.
    hidden_size: int
        The hidden size of the transformer backbone.
    heads_num: int
        The number of attention heads.
    mlp_width_ratio: float
        The ratio of the hidden size of the MLP in the transformer block.
    mlp_act_type: str
        The activation function of the MLP in the transformer block.
    depth_double_blocks: int
        The number of transformer blocks in the double blocks.
    depth_single_blocks: int
        The number of transformer blocks in the single blocks.
    rope_dim_list: list
        The dimension of the rotary embedding for t, h, w.
    qkv_bias: bool
        Whether to use bias in the qkv linear layer.
    qk_norm: bool
        Whether to use qk norm.
    qk_norm_type: str
        The type of qk norm.
    guidance_embed: bool
        Whether to use guidance embedding for distillation.
    text_projection: str
        The type of the text projection, default is single_refiner.
    use_attention_mask: bool
        Whether to use attention mask for text encoder.
    dtype: torch.dtype
        The dtype of the model.
    device: torch.device
        The device of the model.
    attn_mode: str
        The mode of the attention, default is flash.
    split_attn: bool
        Whether to use split attention (make attention as batch size 1).
    """

    # @register_to_config
    def __init__(
        self,
        text_states_dim: int,
        text_states_dim_2: int,
        patch_size: list = [1, 2, 2],
        in_channels: int = 4,  # Should be VAE.config.latent_channels.
        out_channels: int = None,
        hidden_size: int = 3072,
        heads_num: int = 24,
        mlp_width_ratio: float = 4.0,
        mlp_act_type: str = "gelu_tanh",
        mm_double_blocks_depth: int = 20,
        mm_single_blocks_depth: int = 40,
        rope_dim_list: List[int] = [16, 56, 56],
        qkv_bias: bool = True,
        qk_norm: bool = True,
        qk_norm_type: str = "rms",
        guidance_embed: bool = False,  # For modulation.
        text_projection: str = "single_refiner",
        use_attention_mask: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        attn_mode: str = "flash",
        split_attn: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.patch_size = patch_size
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        self.unpatchify_channels = self.out_channels
        self.guidance_embed = guidance_embed
        self.rope_dim_list = rope_dim_list

        # Text projection. Default to linear projection.
        # Alternative: TokenRefiner. See more details (LI-DiT): http://arxiv.org/abs/2406.11831
        self.use_attention_mask = use_attention_mask
        self.text_projection = text_projection

        self.text_states_dim = text_states_dim
        self.text_states_dim_2 = text_states_dim_2

        if hidden_size % heads_num != 0:
            raise ValueError(f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}")
        pe_dim = hidden_size // heads_num
        if sum(rope_dim_list) != pe_dim:
            raise ValueError(f"Got {rope_dim_list} but expected positional dim {pe_dim}")
        self.hidden_size = hidden_size
        self.heads_num = heads_num

        self.attn_mode = attn_mode
        self.split_attn = split_attn
        print(f"Using {self.attn_mode} attention mode, split_attn: {self.split_attn}")

        # image projection
        self.img_in = PatchEmbed(self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs)

        # text projection
        if self.text_projection == "linear":
            self.txt_in = TextProjection(
                self.text_states_dim,
                self.hidden_size,
                get_activation_layer("silu"),
                **factory_kwargs,
            )
        elif self.text_projection == "single_refiner":
            self.txt_in = SingleTokenRefiner(self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs)
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        # time modulation
        self.time_in = TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs)

        # text modulation
        self.vector_in = MLPEmbedder(self.text_states_dim_2, self.hidden_size, **factory_kwargs)

        # guidance modulation
        self.guidance_in = (
            TimestepEmbedder(self.hidden_size, get_activation_layer("silu"), **factory_kwargs) if guidance_embed else None
        )

        # double blocks
        self.double_blocks = nn.ModuleList(
            [
                MMDoubleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                    **factory_kwargs,
                )
                for _ in range(mm_double_blocks_depth)
            ]
        )

        # single blocks
        self.single_blocks = nn.ModuleList(
            [
                MMSingleStreamBlock(
                    self.hidden_size,
                    self.heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_act_type=mlp_act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    attn_mode=attn_mode,
                    split_attn=split_attn,
                    **factory_kwargs,
                )
                for _ in range(mm_single_blocks_depth)
            ]
        )

        self.final_layer = FinalLayer(
            self.hidden_size,
            self.patch_size,
            self.out_channels,
            get_activation_layer("silu"),
            **factory_kwargs,
        )

        # Runtime feature flags (set via the enable_* methods below).
        self.gradient_checkpointing = False
        self.blocks_to_swap = None
        self.offloader_double = None
        self.offloader_single = None
        self._enable_img_in_txt_in_offloading = False

    @property
    def device(self):
        """Device of the first parameter (assumes all parameters share a device)."""
        return next(self.parameters()).device

    @property
    def dtype(self):
        """Dtype of the first parameter (assumes all parameters share a dtype)."""
        return next(self.parameters()).dtype

    def enable_gradient_checkpointing(self):
        """Enable gradient checkpointing on the text refiner and every DiT block."""
        self.gradient_checkpointing = True

        self.txt_in.enable_gradient_checkpointing()

        for block in self.double_blocks + self.single_blocks:
            block.enable_gradient_checkpointing()

        print(f"HYVideoDiffusionTransformer: Gradient checkpointing enabled.")

    def disable_gradient_checkpointing(self):
        """Disable gradient checkpointing on the text refiner and every DiT block."""
        self.gradient_checkpointing = False

        self.txt_in.disable_gradient_checkpointing()

        for block in self.double_blocks + self.single_blocks:
            block.disable_gradient_checkpointing()

        print(f"HYVideoDiffusionTransformer: Gradient checkpointing disabled.")

    def enable_img_in_txt_in_offloading(self):
        """Offload the img_in/txt_in embedders to CPU between uses (see `forward`)."""
        self._enable_img_in_txt_in_offloading = True

    def enable_block_swap(self, num_blocks: int, device: torch.device, supports_backward: bool):
        """Enable CPU<->device block swapping for `num_blocks` blocks total.

        The budget is split between double and single blocks: each double block
        counts roughly as two single blocks (see the *2+1 conversion below).
        """
        self.blocks_to_swap = num_blocks
        self.num_double_blocks = len(self.double_blocks)
        self.num_single_blocks = len(self.single_blocks)
        double_blocks_to_swap = num_blocks // 2
        # remaining budget converted to single blocks (a double block ~ two singles)
        single_blocks_to_swap = (num_blocks - double_blocks_to_swap) * 2 + 1

        assert double_blocks_to_swap <= self.num_double_blocks - 1 and single_blocks_to_swap <= self.num_single_blocks - 1, (
            f"Cannot swap more than {self.num_double_blocks - 1} double blocks and {self.num_single_blocks - 1} single blocks. "
            f"Requested {double_blocks_to_swap} double blocks and {single_blocks_to_swap} single blocks."
        )

        self.offloader_double = ModelOffloader(
            "double", self.double_blocks, self.num_double_blocks, double_blocks_to_swap, supports_backward, device  # , debug=True
        )
        self.offloader_single = ModelOffloader(
            "single", self.single_blocks, self.num_single_blocks, single_blocks_to_swap, supports_backward, device  # , debug=True
        )
        print(
            f"HYVideoDiffusionTransformer: Block swap enabled. Swapping {num_blocks} blocks, double blocks: {double_blocks_to_swap}, single blocks: {single_blocks_to_swap}."
        )

    def switch_block_swap_for_inference(self):
        """Put the offloaders into forward-only mode (no backward bookkeeping)."""
        if self.blocks_to_swap:
            self.offloader_double.set_forward_only(True)
            self.offloader_single.set_forward_only(True)
            self.prepare_block_swap_before_forward()
            print(f"HYVideoDiffusionTransformer: Block swap set to forward only.")

    def switch_block_swap_for_training(self):
        """Put the offloaders back into forward+backward mode."""
        if self.blocks_to_swap:
            self.offloader_double.set_forward_only(False)
            self.offloader_single.set_forward_only(False)
            self.prepare_block_swap_before_forward()
            print(f"HYVideoDiffusionTransformer: Block swap set to forward and backward.")

    def move_to_device_except_swap_blocks(self, device: torch.device):
        """Move the model to `device`, leaving swapped blocks where they are.

        Temporarily detaches the block lists so `self.to(device)` does not touch
        them, then restores the references.
        """
        # assume model is on cpu. do not move blocks to device to reduce temporary memory usage
        if self.blocks_to_swap:
            save_double_blocks = self.double_blocks
            save_single_blocks = self.single_blocks
            self.double_blocks = None
            self.single_blocks = None

        self.to(device)

        if self.blocks_to_swap:
            self.double_blocks = save_double_blocks
            self.single_blocks = save_single_blocks

    def prepare_block_swap_before_forward(self):
        """Stage the correct subset of blocks on-device before a forward pass."""
        if self.blocks_to_swap is None or self.blocks_to_swap == 0:
            return
        self.offloader_double.prepare_block_devices_before_forward(self.double_blocks)
        self.offloader_single.prepare_block_devices_before_forward(self.single_blocks)

    def enable_deterministic(self):
        """Propagate deterministic-attention mode to every block."""
        for block in self.double_blocks:
            block.enable_deterministic()
        for block in self.single_blocks:
            block.enable_deterministic()

    def disable_deterministic(self):
        """Restore default attention behavior on every block."""
        for block in self.double_blocks:
            block.disable_deterministic()
        for block in self.single_blocks:
            block.disable_deterministic()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.Tensor,  # Should be in range(0, 1000).
        text_states: torch.Tensor = None,
        text_mask: torch.Tensor = None,  # Now we don't use it.
        text_states_2: Optional[torch.Tensor] = None,  # Text embedding for modulation.
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: torch.Tensor = None,  # Guidance for modulation, should be cfg_scale x 1000.
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        """Denoising forward pass.

        Args:
            x: latent video, shape (N, C, T, H, W).
            t: diffusion timestep per sample.
            text_states: token-level text embeddings (primary encoder).
            text_mask: attention mask over text tokens; also used to derive
                per-sample sequence lengths for flash/sdpa attention.
            text_states_2: pooled text embedding used for modulation.
            freqs_cos / freqs_sin: precomputed RoPE tables for the image tokens.
            guidance: guidance strength, required when ``guidance_embed`` is set.
            return_dict: when True, return ``{"x": out}``; otherwise the tensor.
        """
        out = {}
        img = x
        txt = text_states
        _, _, ot, oh, ow = x.shape
        # token-grid size after patchification
        tt, th, tw = (
            ot // self.patch_size[0],
            oh // self.patch_size[1],
            ow // self.patch_size[2],
        )

        # Prepare modulation vectors.
        vec = self.time_in(t)

        # text modulation
        vec = vec + self.vector_in(text_states_2)

        # guidance modulation
        if self.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")

            # our timestep_embedding is merged into guidance_in(TimestepEmbedder)
            vec = vec + self.guidance_in(guidance)

        # Embed image and text. Optionally keep the embedders on CPU except while in use.
        if self._enable_img_in_txt_in_offloading:
            self.img_in.to(x.device, non_blocking=True)
            self.txt_in.to(x.device, non_blocking=True)
            synchronize_device(x.device)

        img = self.img_in(img)
        if self.text_projection == "linear":
            txt = self.txt_in(txt)
        elif self.text_projection == "single_refiner":
            txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None)
        else:
            raise NotImplementedError(f"Unsupported text_projection: {self.text_projection}")

        if self._enable_img_in_txt_in_offloading:
            self.img_in.to(torch.device("cpu"), non_blocking=True)
            self.txt_in.to(torch.device("cpu"), non_blocking=True)
            synchronize_device(x.device)
            clean_memory_on_device(x.device)

        txt_seq_len = txt.shape[1]
        img_seq_len = img.shape[1]

        # Compute cu_squlens and max_seqlen for flash attention
        cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len)
        cu_seqlens_kv = cu_seqlens_q
        max_seqlen_q = img_seq_len + txt_seq_len
        max_seqlen_kv = max_seqlen_q

        attn_mask = total_len = None
        if self.split_attn or self.attn_mode == "torch":
            # calculate text length and total length
            text_len = text_mask.sum(dim=1)  # (bs, )
            total_len = img_seq_len + text_len  # (bs, )
        if self.attn_mode == "torch" and not self.split_attn:
            # initialize attention mask: bool tensor for sdpa, (b, 1, n, n)
            bs = img.shape[0]
            attn_mask = torch.zeros((bs, 1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device)

            # set attention mask with total_len
            for i in range(bs):
                attn_mask[i, :, : total_len[i], : total_len[i]] = True
            total_len = None  # means we don't use split_attn

        freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None
        # --------------------- Pass through DiT blocks ------------------------
        for block_idx, block in enumerate(self.double_blocks):
            double_block_args = [
                img,
                txt,
                vec,
                attn_mask,
                total_len,
                cu_seqlens_q,
                cu_seqlens_kv,
                max_seqlen_q,
                max_seqlen_kv,
                freqs_cis,
            ]

            if self.blocks_to_swap:
                self.offloader_double.wait_for_block(block_idx)

            img, txt = block(*double_block_args)

            if self.blocks_to_swap:
                self.offloader_double.submit_move_blocks_forward(self.double_blocks, block_idx)

        # Merge txt and img to pass through single stream blocks.
        x = torch.cat((img, txt), 1)
        if self.blocks_to_swap:
            # delete img, txt to reduce memory usage
            del img, txt
            clean_memory_on_device(x.device)

        if len(self.single_blocks) > 0:
            for block_idx, block in enumerate(self.single_blocks):
                single_block_args = [
                    x,
                    vec,
                    txt_seq_len,
                    attn_mask,
                    total_len,
                    cu_seqlens_q,
                    cu_seqlens_kv,
                    max_seqlen_q,
                    max_seqlen_kv,
                    freqs_cis,
                ]
                if self.blocks_to_swap:
                    self.offloader_single.wait_for_block(block_idx)

                x = block(*single_block_args)

                if self.blocks_to_swap:
                    self.offloader_single.submit_move_blocks_forward(self.single_blocks, block_idx)

        img = x[:, :img_seq_len, ...]
        x = None

        # ---------------------------- Final layer ------------------------------
        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)

        img = self.unpatchify(img, tt, th, tw)
        if return_dict:
            out["x"] = img
            return out
        return img

    def unpatchify(self, x, t, h, w):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.unpatchify_channels
        pt, ph, pw = self.patch_size
        assert t * h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
        x = torch.einsum("nthwcopq->nctohpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))

        return imgs

    def params_count(self):
        """Return parameter counts for the attention/MLP weights and the full model."""
        counts = {
            "double": sum(
                [
                    sum(p.numel() for p in block.img_attn_qkv.parameters())
                    + sum(p.numel() for p in block.img_attn_proj.parameters())
                    + sum(p.numel() for p in block.img_mlp.parameters())
                    + sum(p.numel() for p in block.txt_attn_qkv.parameters())
                    + sum(p.numel() for p in block.txt_attn_proj.parameters())
                    + sum(p.numel() for p in block.txt_mlp.parameters())
                    for block in self.double_blocks
                ]
            ),
            "single": sum(
                [
                    sum(p.numel() for p in block.linear1.parameters()) + sum(p.numel() for p in block.linear2.parameters())
                    for block in self.single_blocks
                ]
            ),
            "total": sum(p.numel() for p in self.parameters()),
        }
        counts["attn+mlp"] = counts["double"] + counts["single"]
        return counts
899
+
900
+
901
#################################################################################
#                             HunyuanVideo Configs                              #
#################################################################################

# Named architecture presets consumed by `load_dit_model`. The "-cfgdistill"
# variant is identical to the base layout except that it enables the guidance
# (CFG-distillation) embedding.
HUNYUAN_VIDEO_CONFIG = {
    "HYVideo-T/2": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
    },
    "HYVideo-T/2-cfgdistill": {
        "mm_double_blocks_depth": 20,
        "mm_single_blocks_depth": 40,
        "rope_dim_list": [16, 56, 56],
        "hidden_size": 3072,
        "heads_num": 24,
        "mlp_width_ratio": 4,
        "guidance_embed": True,
    },
}
924
+
925
+
926
def load_dit_model(text_states_dim, text_states_dim_2, in_channels, out_channels, factor_kwargs):
    """Build the HunyuanVideo DiT backbone.

    NOTE: Only the "HYVideo-T/2-cfgdistill" configuration is supported.

    Args:
        text_states_dim (int): dimension of the primary text-encoder states
        text_states_dim_2 (int): dimension of the secondary text-encoder states
        in_channels (int): number of input (latent) channels
        out_channels (int): number of output (latent) channels
        factor_kwargs (dict): extra constructor kwargs (device, dtype, attn_mode, split_attn)

    Returns:
        nn.Module: the instantiated HYVideoDiffusionTransformer
    """
    arch_config = HUNYUAN_VIDEO_CONFIG["HYVideo-T/2-cfgdistill"]
    return HYVideoDiffusionTransformer(
        text_states_dim=text_states_dim,
        text_states_dim_2=text_states_dim_2,
        in_channels=in_channels,
        out_channels=out_channels,
        **arch_config,
        **factor_kwargs,
    )
953
+
954
+
955
+ def load_state_dict(model, model_path):
956
+ state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
957
+
958
+ load_key = "module"
959
+ if load_key in state_dict:
960
+ state_dict = state_dict[load_key]
961
+ else:
962
+ raise KeyError(
963
+ f"Missing key: `{load_key}` in the checkpoint: {model_path}. The keys in the checkpoint "
964
+ f"are: {list(state_dict.keys())}."
965
+ )
966
+ model.load_state_dict(state_dict, strict=True, assign=True)
967
+ return model
968
+
969
+
970
def load_transformer(dit_path, attn_mode, split_attn, device, dtype, in_channels=16) -> HYVideoDiffusionTransformer:
    """Instantiate the DiT with empty (meta) weights and load them from *dit_path*.

    Safetensors checkpoints are streamed tensor-by-tensor to `device`/`dtype`
    (they may already be stored in fp8); any other extension is treated as a
    torch checkpoint with a "module" key (see `load_state_dict`).
    """
    # =========================== Build main model ===========================
    factor_kwargs = {"device": device, "dtype": dtype, "attn_mode": attn_mode, "split_attn": split_attn}
    latent_channels = 16
    out_channels = latent_channels

    # build on the meta device so no memory is allocated before weights arrive
    with accelerate.init_empty_weights():
        transformer = load_dit_model(
            text_states_dim=4096,
            text_states_dim_2=768,
            in_channels=in_channels,
            out_channels=out_channels,
            factor_kwargs=factor_kwargs,
        )

    if os.path.splitext(dit_path)[-1] == ".safetensors":
        # loading safetensors: may be already fp8
        with MemoryEfficientSafeOpen(dit_path) as f:
            state_dict = {}
            for k in f.keys():
                tensor = f.get_tensor(k)
                tensor = tensor.to(device=device, dtype=dtype)
                # TODO support comfy model
                # if k.startswith("model.model."):
                #     k = convert_comfy_model_key(k)
                state_dict[k] = tensor
            # assign=True replaces the meta tensors with the loaded ones
            transformer.load_state_dict(state_dict, strict=True, assign=True)
    else:
        transformer = load_state_dict(transformer, dit_path)

    return transformer
1001
+
1002
+
1003
def get_rotary_pos_embed_by_shape(model, latents_size):
    """Compute real-valued RoPE cos/sin tables for a latent grid.

    Args:
        model: object exposing ``patch_size``, ``hidden_size``, ``heads_num``
            and ``rope_dim_list``.
        latents_size: latent extents for the last (t, h, w) dimensions; shorter
            sequences are left-padded with 1 (time axis).

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: (freqs_cos, freqs_sin).

    Raises:
        TypeError: if ``model.patch_size`` is neither an int nor a list/tuple.
    """
    target_ndim = 3
    ndim = 5 - 2  # latents are 5-D (N, C, T, H, W); RoPE covers the last 3 axes

    if isinstance(model.patch_size, int):
        assert all(s % model.patch_size == 0 for s in latents_size), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size for s in latents_size]
    elif isinstance(model.patch_size, (list, tuple)):  # accept tuples too; was list-only
        assert all(s % model.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
            f"Latent size(last {ndim} dimensions) should be divisible by patch size({model.patch_size}), "
            f"but got {latents_size}."
        )
        rope_sizes = [s // model.patch_size[idx] for idx, s in enumerate(latents_size)]
    else:
        # Previously an unsupported patch_size type fell through both branches and
        # crashed later with an unbound-variable NameError; fail fast instead.
        raise TypeError(f"Unsupported patch_size type: {type(model.patch_size)}")

    if len(rope_sizes) != target_ndim:
        rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes  # time axis
    head_dim = model.hidden_size // model.heads_num
    rope_dim_list = model.rope_dim_list
    if rope_dim_list is None:
        # split the head dim evenly across the three axes when not configured
        rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
    assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"

    rope_theta = 256
    freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
        rope_dim_list, rope_sizes, theta=rope_theta, use_real=True, theta_rescale_factor=1
    )
    return freqs_cos, freqs_sin
1033
+
1034
+
1035
def get_rotary_pos_embed(vae_name, model, video_length, height, width):
    """Compute RoPE tables for a video, deriving the latent grid from the VAE name.

    VAE names containing "884" (resp. "888") denote 4x (resp. 8x) temporal
    compression with 8x spatial compression; any other name is treated as
    spatially compressed only.
    """
    spatial = [height // 8, width // 8]
    # 884
    if "884" in vae_name:
        frames = (video_length - 1) // 4 + 1
    elif "888" in vae_name:
        frames = (video_length - 1) // 8 + 1
    else:
        frames = video_length

    return get_rotary_pos_embed_by_shape(model, [frames] + spatial)
hunyuan_model/modulate_layers.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
class ModulateDiT(nn.Module):
    """Zero-initialized modulation head for DiT blocks.

    Projects a conditioning vector through a nonlinearity into
    ``factor * hidden_size`` modulation parameters. The linear layer's weight
    and bias start at zero, so the modulation is inert at initialization.
    """

    def __init__(
        self,
        hidden_size: int,
        factor: int,
        act_layer: Callable,
        dtype=None,
        device=None,
    ):
        super().__init__()
        factory_kwargs = {"dtype": dtype, "device": device}
        self.act = act_layer()
        self.linear = nn.Linear(hidden_size, factor * hidden_size, bias=True, **factory_kwargs)
        # Zero-initialize the modulation so it starts as a no-op.
        nn.init.zeros_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the modulation parameters for conditioning vector *x*."""
        activated = self.act(x)
        return self.linear(activated)
29
+
30
+
31
def modulate(x, shift=None, scale=None):
    """Apply an affine (shift/scale) modulation to *x*.

    ``shift`` and ``scale`` are per-sample vectors broadcast over the sequence
    dimension via ``unsqueeze(1)``; either (or both) may be omitted.

    Args:
        x (torch.Tensor): input tensor.
        shift (torch.Tensor, optional): additive term. Defaults to None.
        scale (torch.Tensor, optional): multiplicative term, applied as
            ``x * (1 + scale)``. Defaults to None.

    Returns:
        torch.Tensor: the modulated tensor (``x`` itself when both are None).
    """
    result = x
    if scale is not None:
        result = result * (1 + scale.unsqueeze(1))
    if shift is not None:
        result = result + shift.unsqueeze(1)
    return result
50
+
51
+
52
def apply_gate(x, gate=None, tanh=False):
    """Scale *x* by a per-sample gate vector.

    Args:
        x (torch.Tensor): input tensor.
        gate (torch.Tensor, optional): gate tensor, broadcast over the sequence
            dimension. When None, *x* is returned unchanged.
        tanh (bool, optional): squash the gate through ``tanh`` first.

    Returns:
        torch.Tensor: the gated tensor.
    """
    if gate is None:
        return x
    g = gate.unsqueeze(1)
    return x * (g.tanh() if tanh else g)
69
+
70
+
71
def ckpt_wrapper(module):
    """Wrap *module* in a plain positional-args forwarder.

    Useful as the callable handed to activation checkpointing, which expects a
    function taking only positional inputs.
    """

    def ckpt_forward(*inputs):
        return module(*inputs)

    return ckpt_forward
hunyuan_model/norm_layers.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class RMSNorm(nn.Module):
    """Root-mean-square layer normalization.

    Normalizes the last dimension of the input by its RMS (computed in
    float32 and cast back to the input dtype) and, when ``elementwise_affine``
    is set, applies a learnable per-channel scale.
    """

    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Args:
            dim (int): size of the normalized (last) dimension.
            elementwise_affine (bool): whether to learn a per-channel ``weight``.
            eps (float): stability term added to the mean square before rsqrt.
        """
        super().__init__()
        self.eps = eps
        # `weight` only exists in the affine configuration; `forward` probes
        # for it with hasattr.
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, device=device, dtype=dtype))

    def _norm(self, x):
        """Divide *x* by its RMS along the last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Apply RMS normalization (and the optional learned scale) to *x*."""
        normed = self._norm(x.float()).type_as(x)
        if not hasattr(self, "weight"):
            return normed
        # cast the weight instead of the activations so fp8 weights work
        return normed * self.weight.to(normed.dtype)
62
+
63
+
64
def get_norm_layer(norm_layer):
    """Resolve a normalization layer class from its short name.

    Args:
        norm_layer (str): either "layer" (nn.LayerNorm) or "rms" (RMSNorm).

    Returns:
        type: the normalization layer class.

    Raises:
        NotImplementedError: for any unrecognized name.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    if norm_layer == "rms":
        return RMSNorm
    raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
hunyuan_model/pipeline_hunyuan_video.py ADDED
@@ -0,0 +1,1100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ import inspect
20
+ from typing import Any, Callable, Dict, List, Optional, Union, Tuple
21
+ import torch
22
+ import torch.distributed as dist
23
+ import numpy as np
24
+ from dataclasses import dataclass
25
+ from packaging import version
26
+
27
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
28
+ from diffusers.configuration_utils import FrozenDict
29
+ from diffusers.image_processor import VaeImageProcessor
30
+ from diffusers.loaders import LoraLoaderMixin, TextualInversionLoaderMixin
31
+ from diffusers.models import AutoencoderKL
32
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
33
+ from diffusers.schedulers import KarrasDiffusionSchedulers
34
+ from diffusers.utils import (
35
+ USE_PEFT_BACKEND,
36
+ deprecate,
37
+ logging,
38
+ replace_example_docstring,
39
+ scale_lora_layers,
40
+ unscale_lora_layers,
41
+ )
42
+ from diffusers.utils.torch_utils import randn_tensor
43
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
44
+ from diffusers.utils import BaseOutput
45
+
46
+ from ...constants import PRECISION_TO_TYPE
47
+ from ...vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D
48
+ from ...text_encoder import TextEncoder
49
+ from ...modules import HYVideoDiffusionTransformer
50
+
51
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
52
+
53
+ EXAMPLE_DOC_STRING = """"""
54
+
55
+
56
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
    """
    # Per-sample std over all non-batch dims for both predictions.
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)
    # Rescale the guided prediction to the text prediction's std (fixes overexposure).
    rescaled = noise_cfg * (std_text / std_cfg)
    # Blend rescaled and original guidance to avoid "plain looking" images.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
72
+
73
+
74
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError(
            "Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values"
        )

    def _accepts(param_name: str) -> bool:
        # Not every scheduler's `set_timesteps` supports custom schedules; inspect its signature.
        return param_name in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: the caller-provided step count is returned unchanged.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
137
+
138
+
139
@dataclass
class HunyuanVideoPipelineOutput(BaseOutput):
    """Output of `HunyuanVideoPipeline.__call__`.

    Attributes:
        videos: Generated video frames, as a `torch.Tensor` or `np.ndarray`
            (concrete layout depends on the requested `output_type`).
    """

    videos: Union[torch.Tensor, np.ndarray]
142
+
143
+
144
+ class HunyuanVideoPipeline(DiffusionPipeline):
145
+ r"""
146
+ Pipeline for text-to-video generation using HunyuanVideo.
147
+
148
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
149
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
150
+
151
+ Args:
152
+ vae ([`AutoencoderKL`]):
153
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
154
+ text_encoder ([`TextEncoder`]):
155
+ Frozen text-encoder.
156
+ text_encoder_2 ([`TextEncoder`]):
157
+ Frozen text-encoder_2.
158
+ transformer ([`HYVideoDiffusionTransformer`]):
159
+ A `HYVideoDiffusionTransformer` to denoise the encoded video latents.
160
+ scheduler ([`SchedulerMixin`]):
161
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
162
+ """
163
+
164
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
165
+ _optional_components = ["text_encoder_2"]
166
+ _exclude_from_cpu_offload = ["transformer"]
167
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
168
+
169
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: TextEncoder,
        transformer: HYVideoDiffusionTransformer,
        scheduler: KarrasDiffusionSchedulers,
        text_encoder_2: Optional[TextEncoder] = None,
        progress_bar_config: Dict[str, Any] = None,
        args=None,
    ):
        """Register pipeline modules and patch deprecated scheduler config fields.

        Args:
            vae: VAE used to encode/decode between pixel and latent space.
            text_encoder: Primary frozen text encoder.
            transformer: Diffusion transformer that denoises the video latents.
            scheduler: Noise scheduler used in the denoising loop.
            text_encoder_2: Optional secondary text encoder (may be None).
            progress_bar_config: Extra kwargs merged into the pipeline's
                progress-bar configuration.
            args: Opaque runtime options stored on the pipeline as `self.args`.
        """
        super().__init__()

        # ==========================================================================================
        # Merge caller-supplied progress-bar options into any config DiffusionPipeline set up.
        if progress_bar_config is None:
            progress_bar_config = {}
        if not hasattr(self, "_progress_bar_config"):
            self._progress_bar_config = {}
        self._progress_bar_config.update(progress_bar_config)

        self.args = args
        # ==========================================================================================

        # Legacy scheduler configs may carry steps_offset != 1; emit a deprecation
        # warning and rewrite the frozen config in place to the expected value.
        if (
            hasattr(scheduler.config, "steps_offset")
            and scheduler.config.steps_offset != 1
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
                " file"
            )
            deprecate(
                "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["steps_offset"] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        # Likewise force clip_sample=False on schedulers whose config still enables it.
        if (
            hasattr(scheduler.config, "clip_sample")
            and scheduler.config.clip_sample is True
        ):
            deprecation_message = (
                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
            )
            deprecate(
                "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
            )
            new_config = dict(scheduler.config)
            new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            transformer=transformer,
            scheduler=scheduler,
            text_encoder_2=text_encoder_2,
        )
        # Spatial downsampling factor implied by the VAE's encoder depth.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237
+
238
    def encode_prompt(
        self,
        prompt,
        device,
        num_videos_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_attention_mask: Optional[torch.Tensor] = None,
        lora_scale: Optional[float] = None,
        clip_skip: Optional[int] = None,
        text_encoder: Optional[TextEncoder] = None,
        data_type: Optional[str] = "image",
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device: (`torch.device`):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            do_classifier_free_guidance (`bool`):
                whether to use classifier free guidance or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask matching `prompt_embeds`.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            negative_attention_mask (`torch.Tensor`, *optional*):
                Pre-generated attention mask matching `negative_prompt_embeds`.
            lora_scale (`float`, *optional*):
                A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
            text_encoder (TextEncoder, *optional*):
                Encoder to use; falls back to `self.text_encoder` when None.
            data_type (`str`, *optional*):
                Forwarded to the text encoder's tokenize/encode calls.

        Returns:
            Tuple of `(prompt_embeds, negative_prompt_embeds, attention_mask,
            negative_attention_mask)`; the negative entries are None when
            classifier-free guidance is disabled and none were supplied.
        """
        if text_encoder is None:
            text_encoder = self.text_encoder

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(self, LoraLoaderMixin):
            self._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if not USE_PEFT_BACKEND:
                adjust_lora_scale_text_encoder(text_encoder.model, lora_scale)
            else:
                scale_lora_layers(text_encoder.model, lora_scale)

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                prompt = self.maybe_convert_prompt(prompt, text_encoder.tokenizer)

            text_inputs = text_encoder.text2tokens(prompt, data_type=data_type)

            if clip_skip is None:
                prompt_outputs = text_encoder.encode(
                    text_inputs, data_type=data_type, device=device
                )
                prompt_embeds = prompt_outputs.hidden_state
            else:
                prompt_outputs = text_encoder.encode(
                    text_inputs,
                    output_hidden_states=True,
                    data_type=data_type,
                    device=device,
                )
                # Access the `hidden_states` first, that contains a tuple of
                # all the hidden states from the encoder layers. Then index into
                # the tuple to access the hidden states from the desired layer.
                prompt_embeds = prompt_outputs.hidden_states_list[-(clip_skip + 1)]
                # We also need to apply the final LayerNorm here to not mess with the
                # representations. The `last_hidden_states` that we typically use for
                # obtaining the final prompt representations passes through the LayerNorm
                # layer.
                prompt_embeds = text_encoder.model.text_model.final_layer_norm(
                    prompt_embeds
                )

            # Tile the attention mask for each generation per prompt.
            attention_mask = prompt_outputs.attention_mask
            if attention_mask is not None:
                attention_mask = attention_mask.to(device)
                bs_embed, seq_len = attention_mask.shape
                attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
                attention_mask = attention_mask.view(
                    bs_embed * num_videos_per_prompt, seq_len
                )

        if text_encoder is not None:
            prompt_embeds_dtype = text_encoder.dtype
        elif self.transformer is not None:
            prompt_embeds_dtype = self.transformer.dtype
        else:
            prompt_embeds_dtype = prompt_embeds.dtype

        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

        if prompt_embeds.ndim == 2:
            bs_embed, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
            prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, -1)
        else:
            bs_embed, seq_len, _ = prompt_embeds.shape
            # duplicate text embeddings for each generation per prompt, using mps friendly method
            prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            prompt_embeds = prompt_embeds.view(
                bs_embed * num_videos_per_prompt, seq_len, -1
            )

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [""] * batch_size
            elif prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )
            else:
                uncond_tokens = negative_prompt

            # textual inversion: process multi-vector tokens if necessary
            if isinstance(self, TextualInversionLoaderMixin):
                uncond_tokens = self.maybe_convert_prompt(
                    uncond_tokens, text_encoder.tokenizer
                )

            # max_length = prompt_embeds.shape[1]
            uncond_input = text_encoder.text2tokens(uncond_tokens, data_type=data_type)

            negative_prompt_outputs = text_encoder.encode(
                uncond_input, data_type=data_type, device=device
            )
            negative_prompt_embeds = negative_prompt_outputs.hidden_state

            negative_attention_mask = negative_prompt_outputs.attention_mask
            if negative_attention_mask is not None:
                negative_attention_mask = negative_attention_mask.to(device)
                _, seq_len = negative_attention_mask.shape
                negative_attention_mask = negative_attention_mask.repeat(
                    1, num_videos_per_prompt
                )
                negative_attention_mask = negative_attention_mask.view(
                    batch_size * num_videos_per_prompt, seq_len
                )

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(
                dtype=prompt_embeds_dtype, device=device
            )

            if negative_prompt_embeds.ndim == 2:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, -1
                )
            else:
                negative_prompt_embeds = negative_prompt_embeds.repeat(
                    1, num_videos_per_prompt, 1
                )
                negative_prompt_embeds = negative_prompt_embeds.view(
                    batch_size * num_videos_per_prompt, seq_len, -1
                )

        if text_encoder is not None:
            if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(text_encoder.model, lora_scale)

        return (
            prompt_embeds,
            negative_prompt_embeds,
            attention_mask,
            negative_attention_mask,
        )
450
+
451
+ def decode_latents(self, latents, enable_tiling=True):
452
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
453
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
454
+
455
+ latents = 1 / self.vae.config.scaling_factor * latents
456
+ if enable_tiling:
457
+ self.vae.enable_tiling()
458
+ image = self.vae.decode(latents, return_dict=False)[0]
459
+ else:
460
+ image = self.vae.decode(latents, return_dict=False)[0]
461
+ image = (image / 2 + 0.5).clamp(0, 1)
462
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
463
+ if image.ndim == 4:
464
+ image = image.cpu().permute(0, 2, 3, 1).float()
465
+ else:
466
+ image = image.cpu().float()
467
+ return image
468
+
469
+ def prepare_extra_func_kwargs(self, func, kwargs):
470
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
471
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
472
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
473
+ # and should be between [0, 1]
474
+ extra_step_kwargs = {}
475
+
476
+ for k, v in kwargs.items():
477
+ accepts = k in set(inspect.signature(func).parameters.keys())
478
+ if accepts:
479
+ extra_step_kwargs[k] = v
480
+ return extra_step_kwargs
481
+
482
    def check_inputs(
        self,
        prompt,
        height,
        width,
        video_length,
        callback_steps,
        negative_prompt=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
        vae_ver="88-4c-sd",
    ):
        """Validate `__call__` arguments; raises ValueError/TypeError on misuse.

        Checks: spatial dims divisible by 8; `video_length` compatible with
        the VAE version string; callback arguments well-formed; exactly one
        of `prompt`/`prompt_embeds` given; negative prompt/embeds consistent.
        """
        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(
                f"`height` and `width` have to be divisible by 8 but are {height} and {width}."
            )

        # The modulus presumably matches the VAE's temporal stride (4 for
        # "884" variants, 8 for "888") — TODO confirm against the VAE impl.
        if video_length is not None:
            if "884" in vae_ver:
                if video_length != 1 and (video_length - 1) % 4 != 0:
                    raise ValueError(
                        f"`video_length` has to be 1 or a multiple of 4 but is {video_length}."
                    )
            elif "888" in vae_ver:
                if video_length != 1 and (video_length - 1) % 8 != 0:
                    raise ValueError(
                        f"`video_length` has to be 1 or a multiple of 8 but is {video_length}."
                    )

        if callback_steps is not None and (
            not isinstance(callback_steps, int) or callback_steps <= 0
        ):
            raise ValueError(
                f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
                f" {type(callback_steps)}."
            )
        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs
            for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        # Exactly one of `prompt` / `prompt_embeds` must be supplied.
        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (
            not isinstance(prompt, str) and not isinstance(prompt, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )

        if negative_prompt is not None and negative_prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
            )

        if prompt_embeds is not None and negative_prompt_embeds is not None:
            if prompt_embeds.shape != negative_prompt_embeds.shape:
                raise ValueError(
                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                    f" {negative_prompt_embeds.shape}."
                )
556
+
557
+
558
+ def prepare_latents(
559
+ self,
560
+ batch_size,
561
+ num_channels_latents,
562
+ height,
563
+ width,
564
+ video_length,
565
+ dtype,
566
+ device,
567
+ generator,
568
+ latents=None,
569
+ ):
570
+ shape = (
571
+ batch_size,
572
+ num_channels_latents,
573
+ video_length,
574
+ int(height) // self.vae_scale_factor,
575
+ int(width) // self.vae_scale_factor,
576
+ )
577
+ if isinstance(generator, list) and len(generator) != batch_size:
578
+ raise ValueError(
579
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
580
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
581
+ )
582
+
583
+ if latents is None:
584
+ latents = randn_tensor(
585
+ shape, generator=generator, device=device, dtype=dtype
586
+ )
587
+ else:
588
+ latents = latents.to(device)
589
+
590
+ # Check existence to make it compatible with FlowMatchEulerDiscreteScheduler
591
+ if hasattr(self.scheduler, "init_noise_sigma"):
592
+ # scale the initial noise by the standard deviation required by the scheduler
593
+ latents = latents * self.scheduler.init_noise_sigma
594
+ return latents
595
+
596
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
597
+ def get_guidance_scale_embedding(
598
+ self,
599
+ w: torch.Tensor,
600
+ embedding_dim: int = 512,
601
+ dtype: torch.dtype = torch.float32,
602
+ ) -> torch.Tensor:
603
+ """
604
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
605
+
606
+ Args:
607
+ w (`torch.Tensor`):
608
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
609
+ embedding_dim (`int`, *optional*, defaults to 512):
610
+ Dimension of the embeddings to generate.
611
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
612
+ Data type of the generated embeddings.
613
+
614
+ Returns:
615
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
616
+ """
617
+ assert len(w.shape) == 1
618
+ w = w * 1000.0
619
+
620
+ half_dim = embedding_dim // 2
621
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
622
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
623
+ emb = w.to(dtype)[:, None] * emb[None, :]
624
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
625
+ if embedding_dim % 2 == 1: # zero pad
626
+ emb = torch.nn.functional.pad(emb, (0, 1))
627
+ assert emb.shape == (w.shape[0], embedding_dim)
628
+ return emb
629
+
630
    @property
    def guidance_scale(self):
        # Classifier-free guidance weight for the current call (backed by `_guidance_scale`).
        return self._guidance_scale

    @property
    def guidance_rescale(self):
        # Factor used by `rescale_noise_cfg`; 0.0 means no rescaling.
        return self._guidance_rescale

    @property
    def clip_skip(self):
        # Number of CLIP layers skipped when computing prompt embeddings.
        return self._clip_skip

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # return self._guidance_scale > 1 and self.transformer.config.time_cond_proj_dim is None
        return self._guidance_scale > 1

    @property
    def cross_attention_kwargs(self):
        # Extra kwargs forwarded to attention processors (backed by `_cross_attention_kwargs`).
        return self._cross_attention_kwargs

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps recorded for the current call.
        return self._num_timesteps

    @property
    def interrupt(self):
        # Cooperative interruption flag (backed by `_interrupt`).
        return self._interrupt
661
+
662
+ @torch.no_grad()
663
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
664
+ def __call__(
665
+ self,
666
+ prompt: Union[str, List[str]],
667
+ height: int,
668
+ width: int,
669
+ video_length: int,
670
+ data_type: str = "video",
671
+ num_inference_steps: int = 50,
672
+ timesteps: List[int] = None,
673
+ sigmas: List[float] = None,
674
+ guidance_scale: float = 7.5,
675
+ negative_prompt: Optional[Union[str, List[str]]] = None,
676
+ num_videos_per_prompt: Optional[int] = 1,
677
+ eta: float = 0.0,
678
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
679
+ latents: Optional[torch.Tensor] = None,
680
+ prompt_embeds: Optional[torch.Tensor] = None,
681
+ attention_mask: Optional[torch.Tensor] = None,
682
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
683
+ negative_attention_mask: Optional[torch.Tensor] = None,
684
+ output_type: Optional[str] = "pil",
685
+ return_dict: bool = True,
686
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
687
+ guidance_rescale: float = 0.0,
688
+ clip_skip: Optional[int] = None,
689
+ callback_on_step_end: Optional[
690
+ Union[
691
+ Callable[[int, int, Dict], None],
692
+ PipelineCallback,
693
+ MultiPipelineCallbacks,
694
+ ]
695
+ ] = None,
696
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
697
+ freqs_cis: Tuple[torch.Tensor, torch.Tensor] = None,
698
+ vae_ver: str = "88-4c-sd",
699
+ enable_tiling: bool = False,
700
+ n_tokens: Optional[int] = None,
701
+ embedded_guidance_scale: Optional[float] = None,
702
+ **kwargs,
703
+ ):
704
+ r"""
705
+ The call function to the pipeline for generation.
706
+
707
+ Args:
708
+ prompt (`str` or `List[str]`):
709
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
710
+ height (`int`):
711
+ The height in pixels of the generated image.
712
+ width (`int`):
713
+ The width in pixels of the generated image.
714
+ video_length (`int`):
715
+ The number of frames in the generated video.
716
+ num_inference_steps (`int`, *optional*, defaults to 50):
717
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
718
+ expense of slower inference.
719
+ timesteps (`List[int]`, *optional*):
720
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
721
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
722
+ passed will be used. Must be in descending order.
723
+ sigmas (`List[float]`, *optional*):
724
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
725
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
726
+ will be used.
727
+ guidance_scale (`float`, *optional*, defaults to 7.5):
728
+ A higher guidance scale value encourages the model to generate images closely linked to the text
729
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
730
+ negative_prompt (`str` or `List[str]`, *optional*):
731
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
732
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
733
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
734
+ The number of images to generate per prompt.
735
+ eta (`float`, *optional*, defaults to 0.0):
736
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
737
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
738
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
739
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
740
+ generation deterministic.
741
+ latents (`torch.Tensor`, *optional*):
742
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
743
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
744
+ tensor is generated by sampling using the supplied random `generator`.
745
+ prompt_embeds (`torch.Tensor`, *optional*):
746
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
747
+ provided, text embeddings are generated from the `prompt` input argument.
748
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
749
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
750
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
751
+
752
+ output_type (`str`, *optional*, defaults to `"pil"`):
753
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
754
+ return_dict (`bool`, *optional*, defaults to `True`):
755
+ Whether or not to return a [`HunyuanVideoPipelineOutput`] instead of a
756
+ plain tuple.
757
+ cross_attention_kwargs (`dict`, *optional*):
758
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
759
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
760
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
761
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
762
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
763
+ using zero terminal SNR.
764
+ clip_skip (`int`, *optional*):
765
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
766
+ the output of the pre-final layer will be used for computing the prompt embeddings.
767
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
768
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
769
+ each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
770
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
771
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
772
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
773
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
774
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
775
+ `._callback_tensor_inputs` attribute of your pipeline class.
776
+
777
+ Examples:
778
+
779
+ Returns:
780
+ [`~HunyuanVideoPipelineOutput`] or `tuple`:
781
+ If `return_dict` is `True`, [`HunyuanVideoPipelineOutput`] is returned,
782
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
783
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
784
+ "not-safe-for-work" (nsfw) content.
785
+ """
786
+ callback = kwargs.pop("callback", None)
787
+ callback_steps = kwargs.pop("callback_steps", None)
788
+
789
+ if callback is not None:
790
+ deprecate(
791
+ "callback",
792
+ "1.0.0",
793
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
794
+ )
795
+ if callback_steps is not None:
796
+ deprecate(
797
+ "callback_steps",
798
+ "1.0.0",
799
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
800
+ )
801
+
802
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
803
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
804
+
805
+ # 0. Default height and width to unet
806
+ # height = height or self.transformer.config.sample_size * self.vae_scale_factor
807
+ # width = width or self.transformer.config.sample_size * self.vae_scale_factor
808
+ # to deal with lora scaling and other possible forward hooks
809
+
810
+ # 1. Check inputs. Raise error if not correct
811
+ self.check_inputs(
812
+ prompt,
813
+ height,
814
+ width,
815
+ video_length,
816
+ callback_steps,
817
+ negative_prompt,
818
+ prompt_embeds,
819
+ negative_prompt_embeds,
820
+ callback_on_step_end_tensor_inputs,
821
+ vae_ver=vae_ver,
822
+ )
823
+
824
+ self._guidance_scale = guidance_scale
825
+ self._guidance_rescale = guidance_rescale
826
+ self._clip_skip = clip_skip
827
+ self._cross_attention_kwargs = cross_attention_kwargs
828
+ self._interrupt = False
829
+
830
+ # 2. Define call parameters
831
+ if prompt is not None and isinstance(prompt, str):
832
+ batch_size = 1
833
+ elif prompt is not None and isinstance(prompt, list):
834
+ batch_size = len(prompt)
835
+ else:
836
+ batch_size = prompt_embeds.shape[0]
837
+
838
+ device = torch.device(f"cuda:{dist.get_rank()}") if dist.is_initialized() else self._execution_device
839
+
840
+ # 3. Encode input prompt
841
+ lora_scale = (
842
+ self.cross_attention_kwargs.get("scale", None)
843
+ if self.cross_attention_kwargs is not None
844
+ else None
845
+ )
846
+
847
+ (
848
+ prompt_embeds,
849
+ negative_prompt_embeds,
850
+ prompt_mask,
851
+ negative_prompt_mask,
852
+ ) = self.encode_prompt(
853
+ prompt,
854
+ device,
855
+ num_videos_per_prompt,
856
+ self.do_classifier_free_guidance,
857
+ negative_prompt,
858
+ prompt_embeds=prompt_embeds,
859
+ attention_mask=attention_mask,
860
+ negative_prompt_embeds=negative_prompt_embeds,
861
+ negative_attention_mask=negative_attention_mask,
862
+ lora_scale=lora_scale,
863
+ clip_skip=self.clip_skip,
864
+ data_type=data_type,
865
+ )
866
+ if self.text_encoder_2 is not None:
867
+ (
868
+ prompt_embeds_2,
869
+ negative_prompt_embeds_2,
870
+ prompt_mask_2,
871
+ negative_prompt_mask_2,
872
+ ) = self.encode_prompt(
873
+ prompt,
874
+ device,
875
+ num_videos_per_prompt,
876
+ self.do_classifier_free_guidance,
877
+ negative_prompt,
878
+ prompt_embeds=None,
879
+ attention_mask=None,
880
+ negative_prompt_embeds=None,
881
+ negative_attention_mask=None,
882
+ lora_scale=lora_scale,
883
+ clip_skip=self.clip_skip,
884
+ text_encoder=self.text_encoder_2,
885
+ data_type=data_type,
886
+ )
887
+ else:
888
+ prompt_embeds_2 = None
889
+ negative_prompt_embeds_2 = None
890
+ prompt_mask_2 = None
891
+ negative_prompt_mask_2 = None
892
+
893
+ # For classifier free guidance, we need to do two forward passes.
894
+ # Here we concatenate the unconditional and text embeddings into a single batch
895
+ # to avoid doing two forward passes
896
+ if self.do_classifier_free_guidance:
897
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
898
+ if prompt_mask is not None:
899
+ prompt_mask = torch.cat([negative_prompt_mask, prompt_mask])
900
+ if prompt_embeds_2 is not None:
901
+ prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2])
902
+ if prompt_mask_2 is not None:
903
+ prompt_mask_2 = torch.cat([negative_prompt_mask_2, prompt_mask_2])
904
+
905
+
906
+ # 4. Prepare timesteps
907
+ extra_set_timesteps_kwargs = self.prepare_extra_func_kwargs(
908
+ self.scheduler.set_timesteps, {"n_tokens": n_tokens}
909
+ )
910
+ timesteps, num_inference_steps = retrieve_timesteps(
911
+ self.scheduler,
912
+ num_inference_steps,
913
+ device,
914
+ timesteps,
915
+ sigmas,
916
+ **extra_set_timesteps_kwargs,
917
+ )
918
+
919
+ if "884" in vae_ver:
920
+ video_length = (video_length - 1) // 4 + 1
921
+ elif "888" in vae_ver:
922
+ video_length = (video_length - 1) // 8 + 1
923
+ else:
924
+ video_length = video_length
925
+
926
+ # 5. Prepare latent variables
927
+ num_channels_latents = self.transformer.config.in_channels
928
+ latents = self.prepare_latents(
929
+ batch_size * num_videos_per_prompt,
930
+ num_channels_latents,
931
+ height,
932
+ width,
933
+ video_length,
934
+ prompt_embeds.dtype,
935
+ device,
936
+ generator,
937
+ latents,
938
+ )
939
+
940
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
941
+ extra_step_kwargs = self.prepare_extra_func_kwargs(
942
+ self.scheduler.step,
943
+ {"generator": generator, "eta": eta},
944
+ )
945
+
946
+ target_dtype = PRECISION_TO_TYPE[self.args.precision]
947
+ autocast_enabled = (
948
+ target_dtype != torch.float32
949
+ ) and not self.args.disable_autocast
950
+ vae_dtype = PRECISION_TO_TYPE[self.args.vae_precision]
951
+ vae_autocast_enabled = (
952
+ vae_dtype != torch.float32
953
+ ) and not self.args.disable_autocast
954
+
955
+ # 7. Denoising loop
956
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
957
+ self._num_timesteps = len(timesteps)
958
+
959
+ # if is_progress_bar:
960
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
961
+ for i, t in enumerate(timesteps):
962
+ if self.interrupt:
963
+ continue
964
+
965
+ # expand the latents if we are doing classifier free guidance
966
+ latent_model_input = (
967
+ torch.cat([latents] * 2)
968
+ if self.do_classifier_free_guidance
969
+ else latents
970
+ )
971
+ latent_model_input = self.scheduler.scale_model_input(
972
+ latent_model_input, t
973
+ )
974
+
975
+ t_expand = t.repeat(latent_model_input.shape[0])
976
+ guidance_expand = (
977
+ torch.tensor(
978
+ [embedded_guidance_scale] * latent_model_input.shape[0],
979
+ dtype=torch.float32,
980
+ device=device,
981
+ ).to(target_dtype)
982
+ * 1000.0
983
+ if embedded_guidance_scale is not None
984
+ else None
985
+ )
986
+
987
+ # predict the noise residual
988
+ with torch.autocast(
989
+ device_type="cuda", dtype=target_dtype, enabled=autocast_enabled
990
+ ):
991
+ noise_pred = self.transformer( # For an input image (129, 192, 336) (1, 256, 256)
992
+ latent_model_input, # [2, 16, 33, 24, 42]
993
+ t_expand, # [2]
994
+ text_states=prompt_embeds, # [2, 256, 4096]
995
+ text_mask=prompt_mask, # [2, 256]
996
+ text_states_2=prompt_embeds_2, # [2, 768]
997
+ freqs_cos=freqs_cis[0], # [seqlen, head_dim]
998
+ freqs_sin=freqs_cis[1], # [seqlen, head_dim]
999
+ guidance=guidance_expand,
1000
+ return_dict=True,
1001
+ )[
1002
+ "x"
1003
+ ]
1004
+
1005
+ # perform guidance
1006
+ if self.do_classifier_free_guidance:
1007
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1008
+ noise_pred = noise_pred_uncond + self.guidance_scale * (
1009
+ noise_pred_text - noise_pred_uncond
1010
+ )
1011
+
1012
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1013
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1014
+ noise_pred = rescale_noise_cfg(
1015
+ noise_pred,
1016
+ noise_pred_text,
1017
+ guidance_rescale=self.guidance_rescale,
1018
+ )
1019
+
1020
+ # compute the previous noisy sample x_t -> x_t-1
1021
+ latents = self.scheduler.step(
1022
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
1023
+ )[0]
1024
+
1025
+ if callback_on_step_end is not None:
1026
+ callback_kwargs = {}
1027
+ for k in callback_on_step_end_tensor_inputs:
1028
+ callback_kwargs[k] = locals()[k]
1029
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1030
+
1031
+ latents = callback_outputs.pop("latents", latents)
1032
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1033
+ negative_prompt_embeds = callback_outputs.pop(
1034
+ "negative_prompt_embeds", negative_prompt_embeds
1035
+ )
1036
+
1037
+ # call the callback, if provided
1038
+ if i == len(timesteps) - 1 or (
1039
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
1040
+ ):
1041
+ if progress_bar is not None:
1042
+ progress_bar.update()
1043
+ if callback is not None and i % callback_steps == 0:
1044
+ step_idx = i // getattr(self.scheduler, "order", 1)
1045
+ callback(step_idx, t, latents)
1046
+
1047
+ if not output_type == "latent":
1048
+ expand_temporal_dim = False
1049
+ if len(latents.shape) == 4:
1050
+ if isinstance(self.vae, AutoencoderKLCausal3D):
1051
+ latents = latents.unsqueeze(2)
1052
+ expand_temporal_dim = True
1053
+ elif len(latents.shape) == 5:
1054
+ pass
1055
+ else:
1056
+ raise ValueError(
1057
+ f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}."
1058
+ )
1059
+
1060
+ if (
1061
+ hasattr(self.vae.config, "shift_factor")
1062
+ and self.vae.config.shift_factor
1063
+ ):
1064
+ latents = (
1065
+ latents / self.vae.config.scaling_factor
1066
+ + self.vae.config.shift_factor
1067
+ )
1068
+ else:
1069
+ latents = latents / self.vae.config.scaling_factor
1070
+
1071
+ with torch.autocast(
1072
+ device_type="cuda", dtype=vae_dtype, enabled=vae_autocast_enabled
1073
+ ):
1074
+ if enable_tiling:
1075
+ self.vae.enable_tiling()
1076
+ image = self.vae.decode(
1077
+ latents, return_dict=False, generator=generator
1078
+ )[0]
1079
+ else:
1080
+ image = self.vae.decode(
1081
+ latents, return_dict=False, generator=generator
1082
+ )[0]
1083
+
1084
+ if expand_temporal_dim or image.shape[2] == 1:
1085
+ image = image.squeeze(2)
1086
+
1087
+ else:
1088
+ image = latents
1089
+
1090
+ image = (image / 2 + 0.5).clamp(0, 1)
1091
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
1092
+ image = image.cpu().float()
1093
+
1094
+ # Offload all models
1095
+ self.maybe_free_model_hooks()
1096
+
1097
+ if not return_dict:
1098
+ return image
1099
+
1100
+ return HunyuanVideoPipelineOutput(videos=image)
hunyuan_model/posemb_layers.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (np.ndarray): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32)[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
65
+ def reshape_for_broadcast(
66
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
67
+ x: torch.Tensor,
68
+ head_first=False,
69
+ ):
70
+ """
71
+ Reshape frequency tensor for broadcasting it with another tensor.
72
+
73
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
74
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
75
+
76
+ Notes:
77
+ When using FlashMHAModified, head_first should be False.
78
+ When using Attention, head_first should be True.
79
+
80
+ Args:
81
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
82
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
83
+ head_first (bool): head dimension first (except batch dim) or not.
84
+
85
+ Returns:
86
+ torch.Tensor: Reshaped frequency tensor.
87
+
88
+ Raises:
89
+ AssertionError: If the frequency tensor doesn't match the expected shape.
90
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
91
+ """
92
+ ndim = x.ndim
93
+ assert 0 <= 1 < ndim
94
+
95
+ if isinstance(freqs_cis, tuple):
96
+ # freqs_cis: (cos, sin) in real space
97
+ if head_first:
98
+ assert freqs_cis[0].shape == (
99
+ x.shape[-2],
100
+ x.shape[-1],
101
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
102
+ shape = [
103
+ d if i == ndim - 2 or i == ndim - 1 else 1
104
+ for i, d in enumerate(x.shape)
105
+ ]
106
+ else:
107
+ assert freqs_cis[0].shape == (
108
+ x.shape[1],
109
+ x.shape[-1],
110
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
111
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
112
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
113
+ else:
114
+ # freqs_cis: values in complex space
115
+ if head_first:
116
+ assert freqs_cis.shape == (
117
+ x.shape[-2],
118
+ x.shape[-1],
119
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
120
+ shape = [
121
+ d if i == ndim - 2 or i == ndim - 1 else 1
122
+ for i, d in enumerate(x.shape)
123
+ ]
124
+ else:
125
+ assert freqs_cis.shape == (
126
+ x.shape[1],
127
+ x.shape[-1],
128
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
129
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
130
+ return freqs_cis.view(*shape)
131
+
132
+
133
+ def rotate_half(x):
134
+ x_real, x_imag = (
135
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
136
+ ) # [B, S, H, D//2]
137
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
138
+
139
+
140
+ def apply_rotary_emb(
141
+ xq: torch.Tensor,
142
+ xk: torch.Tensor,
143
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
144
+ head_first: bool = False,
145
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
146
+ """
147
+ Apply rotary embeddings to input tensors using the given frequency tensor.
148
+
149
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
150
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
151
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
152
+ returned as real tensors.
153
+
154
+ Args:
155
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
156
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
157
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
158
+ head_first (bool): head dimension first (except batch dim) or not.
159
+
160
+ Returns:
161
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
162
+
163
+ """
164
+ xk_out = None
165
+ if isinstance(freqs_cis, tuple):
166
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
167
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
168
+ # real * cos - imag * sin
169
+ # imag * cos + real * sin
170
+ xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq)
171
+ xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk)
172
+ else:
173
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
174
+ xq_ = torch.view_as_complex(
175
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
176
+ ) # [B, S, H, D//2]
177
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
178
+ xq.device
179
+ ) # [S, D//2] --> [1, S, 1, D//2]
180
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
181
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
182
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
183
+ xk_ = torch.view_as_complex(
184
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
185
+ ) # [B, S, H, D//2]
186
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
187
+
188
+ return xq_out, xk_out
189
+
190
+
191
+ def get_nd_rotary_pos_embed(
192
+ rope_dim_list,
193
+ start,
194
+ *args,
195
+ theta=10000.0,
196
+ use_real=False,
197
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
198
+ interpolation_factor: Union[float, List[float]] = 1.0,
199
+ ):
200
+ """
201
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
202
+
203
+ Args:
204
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
205
+ sum(rope_dim_list) should equal to head_dim of attention layer.
206
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
207
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
208
+ *args: See above.
209
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
210
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
211
+ Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
212
+ part and an imaginary part separately.
213
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
214
+
215
+ Returns:
216
+ pos_embed (torch.Tensor): [HW, D/2]
217
+ """
218
+
219
+ grid = get_meshgrid_nd(
220
+ start, *args, dim=len(rope_dim_list)
221
+ ) # [3, W, H, D] / [2, W, H]
222
+
223
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
224
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
225
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
226
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
227
+ assert len(theta_rescale_factor) == len(
228
+ rope_dim_list
229
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
230
+
231
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
232
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
233
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
234
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
235
+ assert len(interpolation_factor) == len(
236
+ rope_dim_list
237
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
238
+
239
+ # use 1/ndim of dimensions to encode grid_axis
240
+ embs = []
241
+ for i in range(len(rope_dim_list)):
242
+ emb = get_1d_rotary_pos_embed(
243
+ rope_dim_list[i],
244
+ grid[i].reshape(-1),
245
+ theta,
246
+ use_real=use_real,
247
+ theta_rescale_factor=theta_rescale_factor[i],
248
+ interpolation_factor=interpolation_factor[i],
249
+ ) # 2 x [WHD, rope_dim_list[i]]
250
+ embs.append(emb)
251
+
252
+ if use_real:
253
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
254
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
255
+ return cos, sin
256
+ else:
257
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
258
+ return emb
259
+
260
+
261
+ def get_1d_rotary_pos_embed(
262
+ dim: int,
263
+ pos: Union[torch.FloatTensor, int],
264
+ theta: float = 10000.0,
265
+ use_real: bool = False,
266
+ theta_rescale_factor: float = 1.0,
267
+ interpolation_factor: float = 1.0,
268
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
269
+ """
270
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
271
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
272
+
273
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
274
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
275
+ The returned tensor contains complex values in complex64 data type.
276
+
277
+ Args:
278
+ dim (int): Dimension of the frequency tensor.
279
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
280
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
281
+ use_real (bool, optional): If True, return real part and imaginary part separately.
282
+ Otherwise, return complex numbers.
283
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
284
+
285
+ Returns:
286
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
287
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
288
+ """
289
+ if isinstance(pos, int):
290
+ pos = torch.arange(pos).float()
291
+
292
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
293
+ # has some connection to NTK literature
294
+ if theta_rescale_factor != 1.0:
295
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
296
+
297
+ freqs = 1.0 / (
298
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
299
+ ) # [D/2]
300
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
301
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
302
+ if use_real:
303
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
304
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
305
+ return freqs_cos, freqs_sin
306
+ else:
307
+ freqs_cis = torch.polar(
308
+ torch.ones_like(freqs), freqs
309
+ ) # complex64 # [S, D/2]
310
+ return freqs_cis
hunyuan_model/text_encoder.py ADDED
@@ -0,0 +1,710 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import json
3
+ import os
4
+ from typing import Optional, Tuple, Union
5
+ from copy import deepcopy
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from transformers import (
10
+ CLIPTextModel,
11
+ CLIPTokenizer,
12
+ AutoTokenizer,
13
+ AutoModel,
14
+ CLIPConfig,
15
+ LlamaForCausalLM,
16
+ LlamaConfig,
17
+ )
18
+ from transformers.utils import ModelOutput
19
+ from transformers.models.llama import LlamaModel
20
+ from safetensors.torch import load_file
21
+ from accelerate import init_empty_weights
22
+
23
+ import logging
24
+
25
+ logger = logging.getLogger(__name__)
26
+ logging.basicConfig(level=logging.INFO)
27
+
28
+
29
+ CLIP_L_HUGGINGFACE_MODEL_ID = "openai/clip-vit-large-patch14"
30
+ LLAVA_HUGGINGFACE_MODEL_ID = "xtuner/llava-llama-3-8b-v1_1-transformers"
31
+
32
+ CLIP_CONFIG = {
33
+ "_name_or_path": "clip-vit-large-patch14/",
34
+ "architectures": ["CLIPModel"],
35
+ "initializer_factor": 1.0,
36
+ "logit_scale_init_value": 2.6592,
37
+ "model_type": "clip",
38
+ "projection_dim": 768,
39
+ # "text_config": {
40
+ "_name_or_path": "",
41
+ "add_cross_attention": False,
42
+ "architectures": None,
43
+ "attention_dropout": 0.0,
44
+ "bad_words_ids": None,
45
+ "bos_token_id": 0,
46
+ "chunk_size_feed_forward": 0,
47
+ "cross_attention_hidden_size": None,
48
+ "decoder_start_token_id": None,
49
+ "diversity_penalty": 0.0,
50
+ "do_sample": False,
51
+ "dropout": 0.0,
52
+ "early_stopping": False,
53
+ "encoder_no_repeat_ngram_size": 0,
54
+ "eos_token_id": 2,
55
+ "finetuning_task": None,
56
+ "forced_bos_token_id": None,
57
+ "forced_eos_token_id": None,
58
+ "hidden_act": "quick_gelu",
59
+ "hidden_size": 768,
60
+ "id2label": {"0": "LABEL_0", "1": "LABEL_1"},
61
+ "initializer_factor": 1.0,
62
+ "initializer_range": 0.02,
63
+ "intermediate_size": 3072,
64
+ "is_decoder": False,
65
+ "is_encoder_decoder": False,
66
+ "label2id": {"LABEL_0": 0, "LABEL_1": 1},
67
+ "layer_norm_eps": 1e-05,
68
+ "length_penalty": 1.0,
69
+ "max_length": 20,
70
+ "max_position_embeddings": 77,
71
+ "min_length": 0,
72
+ "model_type": "clip_text_model",
73
+ "no_repeat_ngram_size": 0,
74
+ "num_attention_heads": 12,
75
+ "num_beam_groups": 1,
76
+ "num_beams": 1,
77
+ "num_hidden_layers": 12,
78
+ "num_return_sequences": 1,
79
+ "output_attentions": False,
80
+ "output_hidden_states": False,
81
+ "output_scores": False,
82
+ "pad_token_id": 1,
83
+ "prefix": None,
84
+ "problem_type": None,
85
+ "projection_dim": 768,
86
+ "pruned_heads": {},
87
+ "remove_invalid_values": False,
88
+ "repetition_penalty": 1.0,
89
+ "return_dict": True,
90
+ "return_dict_in_generate": False,
91
+ "sep_token_id": None,
92
+ "task_specific_params": None,
93
+ "temperature": 1.0,
94
+ "tie_encoder_decoder": False,
95
+ "tie_word_embeddings": True,
96
+ "tokenizer_class": None,
97
+ "top_k": 50,
98
+ "top_p": 1.0,
99
+ "torch_dtype": None,
100
+ "torchscript": False,
101
+ "transformers_version": "4.16.0.dev0",
102
+ "use_bfloat16": False,
103
+ "vocab_size": 49408,
104
+ # },
105
+ # "text_config_dict": {
106
+ "hidden_size": 768,
107
+ "intermediate_size": 3072,
108
+ "num_attention_heads": 12,
109
+ "num_hidden_layers": 12,
110
+ "projection_dim": 768,
111
+ # },
112
+ # "torch_dtype": "float32",
113
+ # "transformers_version": null
114
+ }
115
+
116
# Default configuration for the LLaMA-architecture LLM text encoder. Used by
# load_llm() to build an empty model skeleton when loading a single-file
# checkpoint (i.e. when no config.json is available next to the weights).
LLAMA_CONFIG = {
    "architectures": ["LlamaForCausalLM"],
    "attention_bias": False,
    "attention_dropout": 0.0,
    "bos_token_id": 128000,
    "eos_token_id": 128001,
    "head_dim": 128,
    "hidden_act": "silu",
    "hidden_size": 4096,
    "initializer_range": 0.02,
    "intermediate_size": 14336,
    "max_position_embeddings": 8192,
    "mlp_bias": False,
    "model_type": "llama",
    "num_attention_heads": 32,
    "num_hidden_layers": 32,
    "num_key_value_heads": 8,
    "pretraining_tp": 1,
    "rms_norm_eps": 1e-05,
    "rope_scaling": None,
    "rope_theta": 500000.0,
    "tie_word_embeddings": False,
    "torch_dtype": "float16",
    "transformers_version": "4.46.3",
    "use_cache": True,
    "vocab_size": 128320,
}
143
+
144
# When using decoder-only models, we must provide a prompt template to instruct the text encoder
# on how to generate the text.
# --------------------------------------------------------------------
# Image template: a system instruction followed by the user prompt, which is
# substituted into the `{}` placeholder.
PROMPT_TEMPLATE_ENCODE = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, "
    "quantity, text, spatial relationships of the objects and background:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)
# Video template: a longer system instruction that also covers temporal aspects.
PROMPT_TEMPLATE_ENCODE_VIDEO = (
    "<|start_header_id|>system<|end_header_id|>\n\nDescribe the video by detailing the following aspects: "
    "1. The main content and theme of the video."
    "2. The color, shape, size, texture, quantity, text, and spatial relationships of the objects."
    "3. Actions, events, behaviors temporal relationships, physical movement changes of the objects."
    "4. background environment, light, style and atmosphere."
    "5. camera angles, movements, and transitions used in the video:<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
)

NEGATIVE_PROMPT = "Aerial view, aerial view, overexposed, low quality, deformation, a poor composition, bad hands, bad teeth, bad eyes, bad limbs, distortion"

# `crop_start` is the number of leading (instruction) tokens that TextEncoder.encode
# strips from the encoder output so that only the user-prompt tokens remain.
PROMPT_TEMPLATE = {
    "dit-llm-encode": {
        "template": PROMPT_TEMPLATE_ENCODE,
        "crop_start": 36,
    },
    "dit-llm-encode-video": {
        "template": PROMPT_TEMPLATE_ENCODE_VIDEO,
        "crop_start": 95,
    },
}
174
+
175
+
176
def use_default(value, default):
    """Return *value* unless it is None, in which case return *default*."""
    if value is None:
        return default
    return value
178
+
179
+
180
def load_clip_l(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
    """Load the CLIP-L text encoder from a directory or a single weights file.

    Args:
        text_encoder_path: A HuggingFace-style model directory (configs included)
            or a path to a single safetensors weights file.
        dtype: Optional dtype used when instantiating the model.
    """
    if not os.path.isdir(text_encoder_path):
        # Single-file checkpoint: build an empty model from the known CLIP config,
        # then load the weights directly into it.
        clip_config = CLIPConfig(**CLIP_CONFIG)
        with init_empty_weights():
            model = CLIPTextModel._from_config(clip_config, torch_dtype=dtype)
        weights = load_file(text_encoder_path)
        model.load_state_dict(weights, strict=True, assign=True)
        return model

    # Directory checkpoint: the configuration files live next to the weights.
    return CLIPTextModel.from_pretrained(text_encoder_path, torch_dtype=dtype)
197
+
198
+
199
def load_clip_l_tokenizer(tokenizer_path: str):
    """Return the CLIP-L tokenizer from a local directory, or fall back to the hub."""
    if not os.path.isdir(tokenizer_path):
        # No local directory: fetch the reference tokenizer from Hugging Face.
        logger.info(f"Loading tokenizer from Hugging Face: {CLIP_L_HUGGINGFACE_MODEL_ID}")
        return CLIPTokenizer.from_pretrained(CLIP_L_HUGGINGFACE_MODEL_ID, max_length=77)
    return CLIPTokenizer.from_pretrained(tokenizer_path, max_length=77)
208
+
209
+
210
def load_llm(text_encoder_path: str, dtype: Optional[Union[str, torch.dtype]] = None):
    """Load the LLM text encoder from a directory or a single safetensors file.

    Single-file checkpoints are instantiated from LLAMA_CONFIG; an embedded
    "tokenizer" entry (as written by ComfyUI) is dropped before loading.
    """
    if os.path.isdir(text_encoder_path):
        # Directory checkpoint: configs are read from the directory itself.
        return AutoModel.from_pretrained(text_encoder_path, low_cpu_mem_usage=True, torch_dtype=dtype)

    # Single-file checkpoint: build an empty model skeleton and fill in the weights.
    llama_config = LlamaConfig(**LLAMA_CONFIG)
    with init_empty_weights():
        model = LlamaForCausalLM._from_config(llama_config, torch_dtype=dtype)

    weights = load_file(text_encoder_path)

    # support weights from ComfyUI
    weights.pop("tokenizer", None)

    model.load_state_dict(weights, strict=True, assign=True)
    return model
229
+
230
+
231
def load_llm_tokenizer(tokenizer_path: str, padding_side="right"):
    """Load the LLM tokenizer from a local directory or from Hugging Face.

    Args:
        tokenizer_path: Local tokenizer directory; any other value falls back to
            the reference Hugging Face model.
        padding_side: Which side the tokenizer pads on ("right" or "left").
    """
    if os.path.isdir(tokenizer_path):
        # Fix: `padding_side` was previously ignored for local directories,
        # diverging from the Hugging Face branch below.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, padding_side=padding_side)
    else:
        # load from Hugging Face
        logger.info(f"Loading tokenizer from Hugging Face: {LLAVA_HUGGINGFACE_MODEL_ID}")
        tokenizer = AutoTokenizer.from_pretrained(LLAVA_HUGGINGFACE_MODEL_ID, padding_side=padding_side)

    return tokenizer
240
+
241
+
242
def load_text_encoder(
    text_encoder_type: str,
    text_encoder_path: str,
    text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
):
    """Load a text encoder of the given type ("clipL" or "llm"), freeze it,
    and return ``(model, text_encoder_path)``.

    The loaded model gets a ``final_layer_norm`` attribute aliased onto it so
    callers can apply the final norm uniformly regardless of encoder type.
    """
    logger.info(f"Loading text encoder model ({text_encoder_type}) from: {text_encoder_path}")

    # reduce peak memory usage by specifying the dtype of the model
    dtype = text_encoder_dtype
    if text_encoder_type == "clipL":
        text_encoder = load_clip_l(text_encoder_path, dtype=dtype)
        # CLIP keeps its final norm under text_model.
        text_encoder.final_layer_norm = text_encoder.text_model.final_layer_norm
    elif text_encoder_type == "llm":
        text_encoder = load_llm(text_encoder_path, dtype=dtype)
        # The norm's attribute path differs depending on which load path was taken.
        if hasattr(text_encoder, "norm"):
            text_encoder.final_layer_norm = text_encoder.norm  # by from_pretrained
        else:
            text_encoder.final_layer_norm = text_encoder.model.norm  # by _from_config
    else:
        raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")
    # from_pretrained will ensure that the model is in eval mode.

    if dtype is not None:
        text_encoder = text_encoder.to(dtype=dtype)

    # The encoder is inference-only here; freeze all parameters.
    text_encoder.requires_grad_(False)

    logger.info(f"Text encoder to dtype: {text_encoder.dtype}")
    return text_encoder, text_encoder_path
271
+
272
+
273
def load_tokenizer(tokenizer_type, tokenizer_path=None, padding_side="right"):
    """Instantiate the tokenizer for *tokenizer_type* and return ``(tokenizer, path)``."""
    logger.info(f"Loading tokenizer ({tokenizer_type}) from: {tokenizer_path}")

    if tokenizer_type == "clipL":
        return load_clip_l_tokenizer(tokenizer_path), tokenizer_path
    if tokenizer_type == "llm":
        return load_llm_tokenizer(tokenizer_path, padding_side=padding_side), tokenizer_path
    raise ValueError(f"Unsupported tokenizer type: {tokenizer_type}")
284
+
285
+
286
@dataclass
class TextEncoderModelOutput(ModelOutput):
    """
    Base class for model's outputs that also contains a pooling of the last hidden states.

    Args:
        hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the model.
        attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``.
        hidden_states_list (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        text_outputs (`list`, *optional*, returned when `return_texts=True` is passed):
            List of decoded texts.
    """

    # Hidden states of the selected output layer (possibly cropped of template tokens).
    hidden_state: torch.FloatTensor = None
    # Attention mask aligned with hidden_state, or None when masks are disabled.
    attention_mask: Optional[torch.LongTensor] = None
    hidden_states_list: Optional[Tuple[torch.FloatTensor, ...]] = None
    text_outputs: Optional[list] = None
308
+
309
+
310
class TextEncoder(nn.Module):
    """Wrapper bundling a text encoder (CLIP-L or decoder-only LLM) with its tokenizer.

    Supports optional prompt templating — an instruction prefix whose hidden states
    are cropped off after encoding (see ``PROMPT_TEMPLATE``) — and selecting an
    intermediate hidden layer as the output via ``hidden_state_skip_layer``.
    """

    def __init__(
        self,
        text_encoder_type: str,
        max_length: int,
        text_encoder_dtype: Optional[Union[str, torch.dtype]] = None,
        text_encoder_path: Optional[str] = None,
        tokenizer_type: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        output_key: Optional[str] = None,
        use_attention_mask: bool = True,
        input_max_length: Optional[int] = None,
        prompt_template: Optional[dict] = None,
        prompt_template_video: Optional[dict] = None,
        hidden_state_skip_layer: Optional[int] = None,
        apply_final_norm: bool = False,
        reproduce: bool = False,
    ):
        super().__init__()
        self.text_encoder_type = text_encoder_type
        self.max_length = max_length
        # self.precision = text_encoder_precision
        self.model_path = text_encoder_path
        # Tokenizer type/path default to the encoder's type/path when not given.
        self.tokenizer_type = tokenizer_type if tokenizer_type is not None else text_encoder_type
        self.tokenizer_path = tokenizer_path if tokenizer_path is not None else text_encoder_path
        self.use_attention_mask = use_attention_mask
        if prompt_template_video is not None:
            assert use_attention_mask is True, "Attention mask is True required when training videos."
        self.input_max_length = input_max_length if input_max_length is not None else max_length
        self.prompt_template = prompt_template
        self.prompt_template_video = prompt_template_video
        self.hidden_state_skip_layer = hidden_state_skip_layer
        self.apply_final_norm = apply_final_norm
        self.reproduce = reproduce

        # Validate the image prompt template: a dict with a "template" string that
        # contains a `{}` placeholder for the user text.
        self.use_template = self.prompt_template is not None
        if self.use_template:
            assert (
                isinstance(self.prompt_template, dict) and "template" in self.prompt_template
            ), f"`prompt_template` must be a dictionary with a key 'template', got {self.prompt_template}"
            assert "{}" in str(self.prompt_template["template"]), (
                "`prompt_template['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template['template']}"
            )

        # Validate the video prompt template the same way. (A redundant inner
        # `is not None` check was removed; use_video_template already implies it.)
        self.use_video_template = self.prompt_template_video is not None
        if self.use_video_template:
            assert (
                isinstance(self.prompt_template_video, dict) and "template" in self.prompt_template_video
            ), f"`prompt_template_video` must be a dictionary with a key 'template', got {self.prompt_template_video}"
            assert "{}" in str(self.prompt_template_video["template"]), (
                "`prompt_template_video['template']` must contain a placeholder `{}` for the input text, "
                f"got {self.prompt_template_video['template']}"
            )

        # Which field of the HF model output to use as the encoder output.
        if "t5" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        elif "clip" in text_encoder_type:
            self.output_key = output_key or "pooler_output"
        elif "llm" in text_encoder_type or "glm" in text_encoder_type:
            self.output_key = output_key or "last_hidden_state"
        else:
            raise ValueError(f"Unsupported text encoder type: {text_encoder_type}")

        self.model, self.model_path = load_text_encoder(
            text_encoder_type=self.text_encoder_type, text_encoder_path=self.model_path, text_encoder_dtype=text_encoder_dtype
        )
        self.dtype = self.model.dtype

        self.tokenizer, self.tokenizer_path = load_tokenizer(
            tokenizer_type=self.tokenizer_type, tokenizer_path=self.tokenizer_path, padding_side="right"
        )

    def __repr__(self):
        # Fix: previously referenced self.precision, which is never assigned
        # (its assignment is commented out in __init__), so repr() raised
        # AttributeError. Report the actual dtype instead.
        return f"{self.text_encoder_type} ({self.dtype} - {self.model_path})"

    @property
    def device(self):
        return self.model.device

    @staticmethod
    def apply_text_to_template(text, template, prevent_empty_text=True):
        """
        Apply text to template.

        Args:
            text (str): Input text.
            template (str or list): Template string or list of chat conversation.
            prevent_empty_text (bool): If True, we will prevent the user text from being empty
                by adding a space. Defaults to True.
        """
        if isinstance(template, str):
            # Will send string to tokenizer. Used for llm
            return template.format(text)
        else:
            raise TypeError(f"Unsupported template type: {type(template)}")

    def text2tokens(self, text, data_type="image"):
        """
        Tokenize the input text, applying the image/video prompt template when configured.

        Args:
            text (str or list): Input text.
            data_type (str): "image" or "video"; selects which template to apply.
        """
        tokenize_input_type = "str"
        if self.use_template:
            if data_type == "image":
                prompt_template = self.prompt_template["template"]
            elif data_type == "video":
                prompt_template = self.prompt_template_video["template"]
            else:
                raise ValueError(f"Unsupported data type: {data_type}")
            if isinstance(text, (list, tuple)):
                text = [self.apply_text_to_template(one_text, prompt_template) for one_text in text]
                if isinstance(text[0], list):
                    tokenize_input_type = "list"
            elif isinstance(text, str):
                text = self.apply_text_to_template(text, prompt_template)
                if isinstance(text, list):
                    tokenize_input_type = "list"
            else:
                raise TypeError(f"Unsupported text type: {type(text)}")

        kwargs = dict(
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        if tokenize_input_type == "str":
            return self.tokenizer(
                text,
                return_length=False,
                return_overflowing_tokens=False,
                return_attention_mask=True,
                **kwargs,
            )
        elif tokenize_input_type == "list":
            # Chat-style templates are tokenized through the tokenizer's chat template.
            return self.tokenizer.apply_chat_template(
                text,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                **kwargs,
            )
        else:
            raise ValueError(f"Unsupported tokenize_input_type: {tokenize_input_type}")

    def encode(
        self,
        batch_encoding,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=None,
        hidden_state_skip_layer=None,
        return_texts=False,
        data_type="image",
        device=None,
    ):
        """
        Args:
            batch_encoding (dict): Batch encoding from tokenizer.
            use_attention_mask (bool): Whether to use attention mask. If None, use self.use_attention_mask.
                Defaults to None.
            output_hidden_states (bool): Whether to output hidden states. If False, return the value of
                self.output_key. If True, return the entire output. If set self.hidden_state_skip_layer,
                output_hidden_states will be set True. Defaults to False.
            do_sample (bool): Whether to sample from the model. Used for Decoder-Only LLMs. Defaults to None.
                When self.reproduce is False, do_sample is set to True by default.
            hidden_state_skip_layer (int): Number of hidden states to hidden_state_skip_layer. 0 means the last layer.
                If None, self.output_key will be used. Defaults to None.
            return_texts (bool): Whether to return the decoded texts. Defaults to False.
        """
        device = self.model.device if device is None else device
        use_attention_mask = use_default(use_attention_mask, self.use_attention_mask)
        hidden_state_skip_layer = use_default(hidden_state_skip_layer, self.hidden_state_skip_layer)
        do_sample = use_default(do_sample, not self.reproduce)
        attention_mask = batch_encoding["attention_mask"].to(device) if use_attention_mask else None
        outputs = self.model(
            input_ids=batch_encoding["input_ids"].to(device),
            attention_mask=attention_mask,
            output_hidden_states=output_hidden_states or hidden_state_skip_layer is not None,
        )
        if hidden_state_skip_layer is not None:
            last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
            # Real last hidden state already has layer norm applied. So here we only apply it
            # for intermediate layers.
            if hidden_state_skip_layer > 0 and self.apply_final_norm:
                last_hidden_state = self.model.final_layer_norm(last_hidden_state)
        else:
            last_hidden_state = outputs[self.output_key]

        # Remove hidden states of instruction tokens, only keep prompt tokens.
        if self.use_template:
            if data_type == "image":
                crop_start = self.prompt_template.get("crop_start", -1)
            elif data_type == "video":
                crop_start = self.prompt_template_video.get("crop_start", -1)
            else:
                raise ValueError(f"Unsupported data type: {data_type}")
            if crop_start > 0:
                last_hidden_state = last_hidden_state[:, crop_start:]
                attention_mask = attention_mask[:, crop_start:] if use_attention_mask else None

        if output_hidden_states:
            return TextEncoderModelOutput(last_hidden_state, attention_mask, outputs.hidden_states)
        return TextEncoderModelOutput(last_hidden_state, attention_mask)

    def forward(
        self,
        text,
        use_attention_mask=None,
        output_hidden_states=False,
        do_sample=False,
        hidden_state_skip_layer=None,
        return_texts=False,
    ):
        """Tokenize *text* (image template) and encode it in one step."""
        batch_encoding = self.text2tokens(text)
        return self.encode(
            batch_encoding,
            use_attention_mask=use_attention_mask,
            output_hidden_states=output_hidden_states,
            do_sample=do_sample,
            hidden_state_skip_layer=hidden_state_skip_layer,
            return_texts=return_texts,
        )
537
+
538
+
539
+ # region HunyanVideo architecture
540
+
541
+
542
def load_text_encoder_1(
    text_encoder_dir: str, device: torch.device, fp8_llm: bool, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the LLM text encoder (text encoder 1) with the HunyuanVideo settings.

    Args:
        text_encoder_dir: Model directory or single weights file.
        device: Target device.
        fp8_llm: If True, cast weights to float8_e4m3fn and patch norm/embedding
            layers so activations stay numerically stable.
        dtype: Model dtype; defaults to float16.
    """
    text_encoder_dtype = dtype or torch.float16
    text_encoder_type = "llm"
    text_len = 256
    hidden_state_skip_layer = 2
    apply_final_norm = False
    reproduce = False

    prompt_template = "dit-llm-encode"
    prompt_template = PROMPT_TEMPLATE[prompt_template]
    prompt_template_video = "dit-llm-encode-video"
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]

    # Tokenizer budget = user prompt length + template instruction tokens.
    crop_start = prompt_template_video["crop_start"]  # .get("crop_start", 0)
    max_length = text_len + crop_start

    text_encoder_1 = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=text_encoder_dir,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder_1.eval()

    if fp8_llm:
        org_dtype = text_encoder_1.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder_1.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            def forward_hook(module):
                # Replacement RMSNorm forward: compute in fp32 for stability,
                # then cast back to the activation dtype.
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # Keep embeddings at the original dtype (fp8 embeddings are not supported).
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder_1.model, org_dtype)
    else:
        text_encoder_1.to(device=device)

    return text_encoder_1
604
+
605
+
606
def load_text_encoder_2(
    text_encoder_dir: str, device: torch.device, dtype: Optional[Union[str, torch.dtype]] = None
) -> TextEncoder:
    """Build the CLIP-L text encoder (text encoder 2) on *device*.

    Args:
        text_encoder_dir: Model directory or single weights file.
        device: Target device.
        dtype: Model dtype; defaults to float16.
    """
    encoder = TextEncoder(
        text_encoder_type="clipL",
        max_length=77,  # fixed CLIP context length
        text_encoder_dtype=dtype or torch.float16,
        text_encoder_path=text_encoder_dir,
        tokenizer_type="clipL",
        reproduce=False,
    )
    encoder.eval()
    encoder.to(device=device)
    return encoder
628
+
629
+
630
+ # endregion
631
+
632
+
633
if __name__ == "__main__":
    # Smoke test: load the same encoder type from two checkpoint paths and verify
    # that tokenization and encoded outputs match between them.
    import argparse
    from utils.model_utils import str_to_dtype

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    parser.add_argument("type", type=str, help="Text Encoder type")
    parser.add_argument("path1", type=str, help="Text Encoder directory or file 1")
    parser.add_argument("path2", type=str, help="Text Encoder directory or file 2")
    parser.add_argument("--dtype", type=str, default=None, help="Data type for Text Encoder")
    args = parser.parse_args()

    dtype = str_to_dtype(args.dtype) if args.dtype is not None else torch.float16

    # NOTE: the triple-quoted block below is an earlier, lower-level comparison
    # kept as dead code; it is never executed.
    """
    if args.type == "clipL":
        text_encoder_1st = load_clip_l(args.path1, dtype=dtype)
        tokenizer_1st = load_clip_l_tokenizer(args.path1)
        text_encoder_2nd = load_clip_l(args.path2, dtype=dtype)
        tokenizer_2nd = load_clip_l_tokenizer(args.path2)
    elif args.type == "llm":
        text_encoder_1st = load_llm(args.path1, dtype=dtype)
        tokenizer_1st = load_llm_tokenizer(args.path1)
        text_encoder_2nd = load_llm(args.path2, dtype=dtype)
        tokenizer_2nd = load_llm_tokenizer(args.path2)

    print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
    print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")

    text_encoder_1st.to(device=device)
    text_encoder_2nd.to(device=device)

    test_text = "A cat sitting on a table"
    token_ids_1st = tokenizer_1st(test_text, return_tensors="pt")["input_ids"]
    token_ids_2nd = tokenizer_2nd(test_text, return_tensors="pt")["input_ids"]
    assert torch.allclose(token_ids_1st, token_ids_2nd)
    print(f"Token IDs are the same: {token_ids_1st}")

    with torch.no_grad():
        text_encoder_1st_output = text_encoder_1st(token_ids_1st.to(device), output_hidden_states=True)
        text_encoder_2nd_output = text_encoder_2nd(token_ids_2nd.to(device), output_hidden_states=True)
    print(f"1st Text Encoder output keys: {text_encoder_1st_output.keys()}")
    print(f"2nd Text Encoder output keys: {text_encoder_2nd_output.keys()}")
    for key in text_encoder_1st_output:
        print(f"Checking output: {key}")
        assert key in text_encoder_2nd_output, f"Key {key} not in 2nd Text Encoder output"
        assert torch.allclose(text_encoder_1st_output[key], text_encoder_2nd_output[key])
        print(f"Outputs are the same: {key}")
    print("All outputs are the same.")
    """

    if args.type == "clipL":
        text_encoder_1st = load_text_encoder_2(args.path1, device, dtype)
        text_encoder_2nd = load_text_encoder_2(args.path2, device, dtype)
    elif args.type == "llm":
        text_encoder_1st = load_text_encoder_1(args.path1, device, False, dtype)
        text_encoder_2nd = load_text_encoder_1(args.path2, device, False, dtype)
    print(f"1st Text Encoder dtype: {text_encoder_1st.dtype}")
    print(f"2nd Text Encoder dtype: {text_encoder_2nd.dtype}")

    prompt = "A cat sitting on a table"
    data_type = "video"  # video only, image is not supported
    text_inputs_1st = text_encoder_1st.text2tokens(prompt, data_type=data_type)
    text_inputs_2nd = text_encoder_2nd.text2tokens(prompt, data_type=data_type)
    print(text_inputs_1st)
    assert torch.allclose(text_inputs_1st["input_ids"], text_inputs_2nd["input_ids"])

    with torch.no_grad():
        prompt_outputs_1st = text_encoder_1st.encode(text_inputs_1st, data_type=data_type)
        prompt_outputs_2nd = text_encoder_2nd.encode(text_inputs_1st, data_type=data_type)

    # prompt_outputs.hidden_state, prompt_outputs.attention_mask
    assert torch.allclose(prompt_outputs_1st.hidden_state, prompt_outputs_2nd.hidden_state)
    print("Hidden states are the same.")
    assert torch.allclose(prompt_outputs_1st.attention_mask, prompt_outputs_2nd.attention_mask)
    print("Attention masks are the same.")
    print("All outputs are the same.")
hunyuan_model/token_refiner.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ from einops import rearrange
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.checkpoint import checkpoint
7
+
8
+ from .activation_layers import get_activation_layer
9
+ from .attention import attention
10
+ from .norm_layers import get_norm_layer
11
+ from .embed_layers import TimestepEmbedder, TextProjection
12
+ from .mlp_layers import MLP
13
+ from .modulate_layers import modulate, apply_gate
14
+
15
+
16
class IndividualTokenRefinerBlock(nn.Module):
    """One refiner block: self-attention + MLP, with residual branches gated by an
    adaLN-style modulation derived from the conditioning vector ``c``."""

    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        # Fused QKV projection; split into heads in _forward.
        self.self_attn_qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs)
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs)

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs)
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        # Produces (gate_msa, gate_mlp) from the conditioning vector.
        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

        self.gradient_checkpointing = False

    def enable_gradient_checkpointing(self):
        self.gradient_checkpointing = True

    def disable_gradient_checkpointing(self):
        self.gradient_checkpointing = False

    def _forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: torch.Tensor = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        # Gated residual connections for attention and MLP branches.
        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x

    def forward(self, *args, **kwargs):
        # Route through torch checkpointing only while training with it enabled.
        if self.training and self.gradient_checkpointing:
            return checkpoint(self._forward, *args, use_reentrant=False, **kwargs)
        else:
            return self._forward(*args, **kwargs)
103
+
104
+
105
class IndividualTokenRefiner(nn.Module):
    """Stack of ``depth`` IndividualTokenRefinerBlock modules sharing one
    self-attention mask built from the padding mask."""

    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def enable_gradient_checkpointing(self):
        for block in self.blocks:
            block.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        for block in self.blocks:
            block.disable_gradient_checkpointing()

    def forward(
        self,
        x: torch.Tensor,
        c: torch.LongTensor,
        mask: Optional[torch.Tensor] = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            # A pair (i, j) may attend only when both tokens are unmasked.
            self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask)
        return x
170
+
171
+
172
class SingleTokenRefiner(nn.Module):
    """
    A single token refiner block for llm text embedding refine.

    Embeds input tokens, builds a conditioning vector from the timestep plus a
    (mask-weighted) mean of the context, and runs the IndividualTokenRefiner stack.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        self.input_embedder = nn.Linear(in_channels, hidden_size, bias=True, **factory_kwargs)

        act_layer = get_activation_layer(act_type)
        # Build timestep embedding layer
        self.t_embedder = TimestepEmbedder(hidden_size, act_layer, **factory_kwargs)
        # Build context embedding layer
        self.c_embedder = TextProjection(in_channels, hidden_size, act_layer, **factory_kwargs)

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            **factory_kwargs,
        )

    def enable_gradient_checkpointing(self):
        self.individual_token_refiner.enable_gradient_checkpointing()

    def disable_gradient_checkpointing(self):
        self.individual_token_refiner.disable_gradient_checkpointing()

    def forward(
        self,
        x: torch.Tensor,
        t: torch.LongTensor,
        mask: Optional[torch.LongTensor] = None,
    ):
        timestep_aware_representations = self.t_embedder(t)

        # Context summary: plain mean, or mask-weighted mean when padding exists.
        if mask is None:
            context_aware_representations = x.mean(dim=1)
        else:
            mask_float = mask.float().unsqueeze(-1)  # [b, s1, 1]
            context_aware_representations = (x * mask_float).sum(dim=1) / mask_float.sum(dim=1)
        context_aware_representations = self.c_embedder(context_aware_representations)
        c = timestep_aware_representations + context_aware_representations

        x = self.input_embedder(x)

        x = self.individual_token_refiner(x, c, mask)

        return x
hunyuan_model/vae.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ import json
3
+ from typing import Optional, Tuple, Union
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+ from diffusers.utils import BaseOutput, is_torch_version
11
+ from diffusers.utils.torch_utils import randn_tensor
12
+ from diffusers.models.attention_processor import SpatialNorm
13
+ from modules.unet_causal_3d_blocks import CausalConv3d, UNetMidBlockCausal3D, get_down_block3d, get_up_block3d
14
+
15
+ import logging
16
+
17
+ logger = logging.getLogger(__name__)
18
+ logging.basicConfig(level=logging.INFO)
19
+
20
+
21
# Latent scaling factor of Hunyuan's 884-16c VAE (matches "scaling_factor" in the
# fixed config JSON embedded in load_vae below).
SCALING_FACTOR = 0.476986
VAE_VER = "884-16c-hy"  # We don't support other versions currently
23
+
24
+
25
def load_vae(
    vae_type: str = "884-16c-hy",
    vae_dtype: Optional[Union[str, torch.dtype]] = None,
    sample_size: Optional[tuple] = None,
    vae_path: Optional[str] = None,
    device=None,
):
    """Load Hunyuan's causal 3D VAE with a fixed, embedded config.

    Args:
        vae_type (str): the type of the 3D VAE model. Only "884-16c-hy" is supported.
        vae_dtype (str or torch.dtype, optional): dtype to cast the VAE to.
            Defaults to None (keep the dtype produced by `from_config`).
        sample_size (tuple, optional): the tiling sample size. Defaults to None.
        vae_path (str, optional): path to the VAE weights (.safetensors or a
            torch checkpoint). Must be provided.
        device (optional): device to move the VAE to. Defaults to None (CPU).

    Returns:
        Tuple of (vae, vae_path, spatial_compression_ratio, time_compression_ratio).

    Raises:
        ValueError: if `vae_path` is not given.
    """
    if vae_path is None:
        # The previous code dereferenced VAE_PATH[vae_type] here, but VAE_PATH is
        # never defined or imported in this module, so this path could only raise
        # NameError. Fail with an explicit, actionable error instead.
        raise ValueError("vae_path must be specified; there is no built-in download mapping for VAE weights.")

    logger.info(f"Loading 3D VAE model ({vae_type}) from: {vae_path}")

    # use fixed config for Hunyuan's VAE
    CONFIG_JSON = """{
    "_class_name": "AutoencoderKLCausal3D",
    "_diffusers_version": "0.4.2",
    "act_fn": "silu",
    "block_out_channels": [
        128,
        256,
        512,
        512
    ],
    "down_block_types": [
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D",
        "DownEncoderBlockCausal3D"
    ],
    "in_channels": 3,
    "latent_channels": 16,
    "layers_per_block": 2,
    "norm_num_groups": 32,
    "out_channels": 3,
    "sample_size": 256,
    "sample_tsize": 64,
    "up_block_types": [
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D",
        "UpDecoderBlockCausal3D"
    ],
    "scaling_factor": 0.476986,
    "time_compression_ratio": 4,
    "mid_block_add_attention": true
}"""

    # config = AutoencoderKLCausal3D.load_config(vae_path)
    config = json.loads(CONFIG_JSON)

    # import here to avoid circular import
    from .autoencoder_kl_causal_3d import AutoencoderKLCausal3D

    # Optionally override the tiling sample size from the caller.
    if sample_size:
        vae = AutoencoderKLCausal3D.from_config(config, sample_size=sample_size)
    else:
        vae = AutoencoderKLCausal3D.from_config(config)

    # Load weights; accept raw safetensors, torch checkpoints, and checkpoints
    # that wrap the weights in a "state_dict" entry or a "vae." key prefix.
    if vae_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        ckpt = load_file(vae_path)
    else:
        ckpt = torch.load(vae_path, map_location=vae.device, weights_only=True)
    if "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]
    if any(k.startswith("vae.") for k in ckpt.keys()):
        ckpt = {k.replace("vae.", ""): v for k, v in ckpt.items() if k.startswith("vae.")}
    vae.load_state_dict(ckpt)

    spatial_compression_ratio = vae.config.spatial_compression_ratio
    time_compression_ratio = vae.config.time_compression_ratio

    if vae_dtype is not None:
        vae = vae.to(vae_dtype)

    # Inference-only: freeze all parameters.
    vae.requires_grad_(False)

    logger.info(f"VAE to dtype: {vae.dtype}")

    if device is not None:
        vae = vae.to(device)

    vae.eval()

    return vae, vae_path, spatial_compression_ratio, time_compression_ratio
123
+
124
+
125
@dataclass
class DecoderOutput(BaseOutput):
    r"""
    Output of decoding method.

    Args:
        sample (`torch.FloatTensor`):
            The decoded output sample from the last layer of the model.
            (Docstring inherited from diffusers' 2D VAE; for this causal 3D VAE
            the tensor presumably carries an extra frame dimension,
            i.e. `(batch, channels, frames, height, width)` -- confirm against
            `DecoderCausal3D.forward`.)
    """

    sample: torch.FloatTensor
136
+
137
+
138
class EncoderCausal3D(nn.Module):
    r"""
    The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.

    Architecture: a stem conv, a chain of down blocks (spatial downsampling in
    the early stages, temporal downsampling near the end), a mid block, and a
    GroupNorm/SiLU/conv head producing `2 * out_channels` channels when
    `double_z` is set (mean + logvar for the latent distribution).
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        double_z: bool = True,
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
        self.mid_block = None
        self.down_blocks = nn.ModuleList([])

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Number of stages that must halve the spatial / temporal resolution
            # to reach the requested overall compression ratios.
            num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_downsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                # Spatial downsampling happens in the first stages; temporal
                # downsampling in the last stages, except the final block.
                add_spatial_downsample = bool(i < num_spatial_downsample_layers)
                add_time_downsample = bool(i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # Stride is laid out as (T, H, W): 2 along axes this stage downsamples.
            downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
            downsample_stride_T = (2,) if add_time_downsample else (1,)
            downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
            down_block = get_down_block3d(
                down_block_type,
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=bool(add_spatial_downsample or add_time_downsample),
                downsample_stride=downsample_stride,
                resnet_eps=1e-6,
                downsample_padding=0,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=None,  # encoder is unconditional: no time embedding
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default",
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=None,
            add_attention=mid_block_add_attention,
        )

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()

        # double_z -> emit concatenated (mean, logvar) channels for the posterior.
        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)

    def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        r"""The forward method of the `EncoderCausal3D` class.

        Expects a 5D tensor (batch, channels, frames, height, width) and returns
        the encoded feature map from `conv_out`.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"

        sample = self.conv_in(sample)

        # down
        for down_block in self.down_blocks:
            sample = down_block(sample)

        # middle
        sample = self.mid_block(sample)

        # post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
237
+
238
+
239
class DecoderCausal3D(nn.Module):
    r"""
    The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.

    Mirror image of `EncoderCausal3D`: stem conv, mid block, a chain of up
    blocks (spatial upsampling first, temporal upsampling near the end), then a
    norm/SiLU/conv head. Supports gradient checkpointing via the
    `gradient_checkpointing` flag.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        norm_num_groups: int = 32,
        act_fn: str = "silu",
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        time_compression_ratio: int = 4,
        spatial_compression_ratio: int = 8,
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
        self.mid_block = None
        self.up_blocks = nn.ModuleList([])

        # "spatial" norm conditions on the latent embedding; "group" does not.
        temb_channels = in_channels if norm_type == "spatial" else None

        # mid
        self.mid_block = UNetMidBlockCausal3D(
            in_channels=block_out_channels[-1],
            resnet_eps=1e-6,
            resnet_act_fn=act_fn,
            output_scale_factor=1,
            resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
            attention_head_dim=block_out_channels[-1],
            resnet_groups=norm_num_groups,
            temb_channels=temb_channels,
            add_attention=mid_block_add_attention,
        )

        # up
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            # Number of stages that must double the spatial / temporal resolution.
            num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
            num_time_upsample_layers = int(np.log2(time_compression_ratio))

            if time_compression_ratio == 4:
                # Same staging as the encoder, mirrored: spatial upsampling in the
                # early (deepest) stages, temporal upsampling later, none on the
                # final block.
                add_spatial_upsample = bool(i < num_spatial_upsample_layers)
                add_time_upsample = bool(i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block)
            else:
                raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")

            # Scale factor laid out as (T, H, W).
            upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
            upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
            upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
            up_block = get_up_block3d(
                up_block_type,
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                prev_output_channel=None,
                add_upsample=bool(add_spatial_upsample or add_time_upsample),
                upsample_scale_factor=upsample_scale_factor,
                resnet_eps=1e-6,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
                attention_head_dim=output_channel,
                temb_channels=temb_channels,
                resnet_time_scale_shift=norm_type,
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel  # NOTE: redundant; reassigned at loop top

        # out
        if norm_type == "spatial":
            self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
        else:
            self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
        self.conv_act = nn.SiLU()
        self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)

        self.gradient_checkpointing = False

    def forward(
        self,
        sample: torch.FloatTensor,
        latent_embeds: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        r"""The forward method of the `DecoderCausal3D` class.

        Args:
            sample: 5D latent tensor (batch, channels, frames, height, width).
            latent_embeds: optional conditioning passed to mid/up blocks and, when
                spatial norm is used, to the output norm.
        """
        assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."

        sample = self.conv_in(sample)

        # Cast mid-block output to the up-blocks' parameter dtype before upsampling.
        upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
        if self.training and self.gradient_checkpointing:

            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # use_reentrant=False is only available on torch >= 1.11
            if is_torch_version(">=", "1.11.0"):
                # middle
                sample = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.mid_block),
                    sample,
                    latent_embeds,
                    use_reentrant=False,
                )
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(up_block),
                        sample,
                        latent_embeds,
                        use_reentrant=False,
                    )
            else:
                # middle
                sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample, latent_embeds)
                sample = sample.to(upscale_dtype)

                # up
                for up_block in self.up_blocks:
                    sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
        else:
            # middle
            sample = self.mid_block(sample, latent_embeds)
            sample = sample.to(upscale_dtype)

            # up
            for up_block in self.up_blocks:
                sample = up_block(sample, latent_embeds)

        # post-process; spatial norm additionally consumes the latent embedding
        if latent_embeds is None:
            sample = self.conv_norm_out(sample)
        else:
            sample = self.conv_norm_out(sample, latent_embeds)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        return sample
390
+
391
+
392
class DiagonalGaussianDistribution(object):
    """Factorized Gaussian over latents, parameterized by concatenated mean/logvar.

    `parameters` carries mean and log-variance stacked along the channel axis:
    dim 2 for (B, L, C) inputs, dim 1 for (B, C, H, W) / (B, C, T, H, W).
    """

    def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
        # Pick the axis along which mean/logvar were concatenated.
        if parameters.ndim == 3:
            split_dim = 2  # (B, L, C)
        elif parameters.ndim in (4, 5):
            split_dim = 1  # (B, C, H, W) / (B, C, T, H, W)
        else:
            raise NotImplementedError
        self.parameters = parameters
        mean, logvar = torch.chunk(parameters, 2, dim=split_dim)
        self.mean = mean
        # Clamp logvar for numerical stability before exponentiating.
        self.logvar = torch.clamp(logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Degenerate distribution: zero variance, sample() returns the mean.
            zeros = torch.zeros_like(self.mean, device=self.parameters.device, dtype=self.parameters.dtype)
            self.var = zeros
            self.std = zeros

    def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
        """Reparameterized draw: mean + std * eps, on the parameters' device/dtype."""
        noise = randn_tensor(
            self.mean.shape,
            generator=generator,
            device=self.parameters.device,
            dtype=self.parameters.dtype,
        )
        return self.mean + self.std * noise

    def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
        """KL divergence to `other` (or to the standard normal when omitted), summed over non-batch dims."""
        if self.deterministic:
            return torch.Tensor([0.0])
        reduce_dim = list(range(1, self.mean.ndim))
        if other is None:
            return 0.5 * torch.sum(
                self.mean.pow(2) + self.var - 1.0 - self.logvar,
                dim=reduce_dim,
            )
        return 0.5 * torch.sum(
            (self.mean - other.mean).pow(2) / other.var + self.var / other.var - 1.0 - self.logvar + other.logvar,
            dim=reduce_dim,
        )

    def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = [1, 2, 3]) -> torch.Tensor:
        """Negative log-likelihood of `sample` under this Gaussian, summed over `dims`."""
        if self.deterministic:
            return torch.Tensor([0.0])
        log_two_pi = np.log(2.0 * np.pi)
        per_element = log_two_pi + self.logvar + (sample - self.mean).pow(2) / self.var
        return 0.5 * torch.sum(per_element, dim=dims)

    def mode(self) -> torch.Tensor:
        """Distribution mode (equals the mean for a Gaussian)."""
        return self.mean
hv_generate_video.py ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from datetime import datetime
3
+ from pathlib import Path
4
+ import random
5
+ import sys
6
+ import os
7
+ import time
8
+ from typing import Optional, Union
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ import accelerate
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from transformers.models.llama import LlamaModel
16
+ from tqdm import tqdm
17
+ import av
18
+ from einops import rearrange
19
+ from safetensors.torch import load_file, save_file
20
+ from safetensors import safe_open
21
+ from PIL import Image
22
+
23
+ from hunyuan_model import vae
24
+ from hunyuan_model.text_encoder import TextEncoder
25
+ from hunyuan_model.text_encoder import PROMPT_TEMPLATE
26
+ from hunyuan_model.vae import load_vae
27
+ from hunyuan_model.models import load_transformer, get_rotary_pos_embed
28
+ from hunyuan_model.fp8_optimization import convert_fp8_linear
29
+ from modules.scheduling_flow_match_discrete import FlowMatchDiscreteScheduler
30
+ from networks import lora
31
+
32
+ try:
33
+ from lycoris.kohya import create_network_from_weights
34
+ except:
35
+ pass
36
+
37
+ from utils.model_utils import str_to_dtype
38
+ from utils.safetensors_utils import mem_eff_save_file
39
+ from dataset.image_video_dataset import load_video, glob_images, resize_image_to_bucket
40
+
41
+ import logging
42
+
43
+ logger = logging.getLogger(__name__)
44
+ logging.basicConfig(level=logging.INFO)
45
+
46
+
47
def clean_memory_on_device(device):
    """Release cached allocator memory on the given device; CPU needs nothing."""
    dev_type = device.type
    if dev_type == "cuda":
        torch.cuda.empty_cache()
    elif dev_type == "mps":  # not tested
        torch.mps.empty_cache()
    # "cpu" (and any other backend) requires no explicit cache flush
54
+
55
+
56
def synchronize_device(device: torch.device):
    """Block until all queued work on the device finishes (cuda/xpu/mps); CPU is a no-op."""
    dev_type = device.type
    if dev_type == "cuda":
        torch.cuda.synchronize()
    elif dev_type == "xpu":
        torch.xpu.synchronize()
    elif dev_type == "mps":
        torch.mps.synchronize()
63
+
64
+
65
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=1, fps=24):
    """save videos by video tensor
    copy from https://github.com/guoyww/AnimateDiff/blob/e92bd5671ba62c0d774a32951453e328018b7c5b/animatediff/utils/util.py#L61

    Args:
        videos (torch.Tensor): video tensor predicted by the model, laid out as (b, c, t, h, w)
        path (str): path to save video
        rescale (bool, optional): rescale the video tensor from [-1, 1] to [0, 1]. Defaults to False.
        n_rows (int, optional): number of images per row in each frame's grid. Defaults to 1.
        fps (int, optional): video save fps. Defaults to 24.
    """
    # one grid image per timestep: (b, c, t, h, w) -> t x (b, c, h, w)
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        # (c, h, w) -> (h, w, c); squeeze drops a trailing singleton channel, if any
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = torch.clamp(x, 0, 1)
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    # NOTE(review): os.path.dirname(path) is "" for a bare filename, which makes
    # makedirs raise -- callers presumably always pass a path with a directory.
    os.makedirs(os.path.dirname(path), exist_ok=True)

    # # save video with av
    # container = av.open(path, "w")
    # stream = container.add_stream("libx264", rate=fps)
    # for x in outputs:
    #     frame = av.VideoFrame.from_ndarray(x, format="rgb24")
    #     packet = stream.encode(frame)
    #     container.mux(packet)
    # packet = stream.encode(None)
    # container.mux(packet)
    # container.close()

    height, width, _ = outputs[0].shape

    # create output container
    container = av.open(path, mode="w")

    # create video stream
    codec = "libx264"
    pixel_format = "yuv420p"
    stream = container.add_stream(codec, rate=fps)
    stream.width = width
    stream.height = height
    stream.pix_fmt = pixel_format
    stream.bit_rate = 4000000  # 4Mbit/s

    # encode every frame; each encode() call may emit zero or more packets
    for frame_array in outputs:
        frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
        packets = stream.encode(frame)
        for packet in packets:
            container.mux(packet)

    # flush the encoder's buffered frames before closing
    for packet in stream.encode():
        container.mux(packet)

    container.close()
124
+
125
+
126
def save_images_grid(
    videos: torch.Tensor, parent_dir: str, image_name: str, rescale: bool = False, n_rows: int = 1, create_subdir=True
):
    """Save every frame of a (b, c, t, h, w) video batch as a PNG grid image.

    Files are written as `{image_name}_{index:03d}.png`, either directly in
    `parent_dir` or in a `parent_dir/image_name` subdirectory.
    """
    # one grid per timestep: (b, c, t, h, w) -> t x (b, c, h, w)
    frames = rearrange(videos, "b c t h w -> t b c h w")
    grids = []
    for frame in frames:
        grid = torchvision.utils.make_grid(frame, nrow=n_rows)
        # (c, h, w) -> (h, w, c); drop a trailing singleton channel, if any
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            grid = (grid + 1.0) / 2.0  # [-1, 1] -> [0, 1]
        grid = torch.clamp(grid, 0, 1)
        grids.append((grid * 255).numpy().astype(np.uint8))

    output_dir = os.path.join(parent_dir, image_name) if create_subdir else parent_dir
    os.makedirs(output_dir, exist_ok=True)

    for index, pixels in enumerate(grids):
        image_path = os.path.join(output_dir, f"{image_name}_{index:03d}.png")
        Image.fromarray(pixels).save(image_path)
150
+
151
+
152
+ # region Encoding prompt
153
+
154
+
155
def encode_prompt(prompt: Union[str, list[str]], device: torch.device, num_videos_per_prompt: int, text_encoder: TextEncoder):
    r"""
    Encodes the prompt into text encoder hidden states.

    Args:
        prompt (`str` or `List[str]`):
            prompt to be encoded
        device: (`torch.device`):
            torch device
        num_videos_per_prompt (`int`):
            number of videos that should be generated per prompt
        text_encoder (TextEncoder):
            text encoder to be used for encoding the prompt

    Returns:
        Tuple of (prompt_embeds, attention_mask); the mask may be None when the
        encoder does not produce one. Both are tiled `num_videos_per_prompt` times.
    """
    # LoRA / Textual Inversion, negative prompts, and clip_skip are not supported here.
    data_type = "video"  # video only, image is not supported

    tokens = text_encoder.text2tokens(prompt, data_type=data_type)

    with torch.no_grad():
        encoder_outputs = text_encoder.encode(tokens, data_type=data_type, device=device)
    prompt_embeds = encoder_outputs.hidden_state

    attention_mask = encoder_outputs.attention_mask
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
        batch, seq_len = attention_mask.shape
        # tile the mask once per generated video, then flatten batch x repeats
        attention_mask = attention_mask.repeat(1, num_videos_per_prompt)
        attention_mask = attention_mask.view(batch * num_videos_per_prompt, seq_len)

    prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt (mps-friendly repeat+view)
    if prompt_embeds.ndim == 2:
        batch, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
        prompt_embeds = prompt_embeds.view(batch * num_videos_per_prompt, -1)
    else:
        batch, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch * num_videos_per_prompt, seq_len, -1)

    return prompt_embeds, attention_mask
202
+
203
+
204
def encode_input_prompt(prompt: Union[str, list[str]], args, device, fp8_llm=False, accelerator=None):
    """Encode a prompt with both HunyuanVideo text encoders and return CPU tensors.

    Loads text encoder 1 (LLM) and text encoder 2 (CLIP-L) from the paths in
    `args`, encodes `prompt` with each, and frees the encoders afterwards.

    Args:
        prompt: prompt (or list of prompts) to encode.
        args: parsed CLI args; reads `args.text_encoder1` and `args.text_encoder2`.
        device: device to run the encoders on.
        fp8_llm: if True, cast the LLM encoder to float8_e4m3fn and patch its
            Embedding / RMSNorm layers to keep them numerically usable.
        accelerator: required when `fp8_llm` is set; its autocast wraps encoding.

    Returns:
        Tuple (prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2), all on CPU.
    """
    # constants
    prompt_template_video = "dit-llm-encode-video"
    prompt_template = "dit-llm-encode"
    text_encoder_dtype = torch.float16
    text_encoder_type = "llm"
    text_len = 256
    hidden_state_skip_layer = 2
    apply_final_norm = False
    reproduce = False

    text_encoder_2_type = "clipL"
    text_len_2 = 77

    num_videos = 1

    # if args.prompt_template_video is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template_video].get("crop_start", 0)
    # elif args.prompt_template is not None:
    #     crop_start = PROMPT_TEMPLATE[args.prompt_template].get("crop_start", 0)
    # else:
    #     crop_start = 0
    # max_length includes the template prefix that is cropped from the output.
    crop_start = PROMPT_TEMPLATE[prompt_template_video].get("crop_start", 0)
    max_length = text_len + crop_start

    # prompt_template
    prompt_template = PROMPT_TEMPLATE[prompt_template]

    # prompt_template_video
    prompt_template_video = PROMPT_TEMPLATE[prompt_template_video]  # if args.prompt_template_video is not None else None

    # load text encoders
    logger.info(f"loading text encoder: {args.text_encoder1}")
    text_encoder = TextEncoder(
        text_encoder_type=text_encoder_type,
        max_length=max_length,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder1,
        tokenizer_type=text_encoder_type,
        prompt_template=prompt_template,
        prompt_template_video=prompt_template_video,
        hidden_state_skip_layer=hidden_state_skip_layer,
        apply_final_norm=apply_final_norm,
        reproduce=reproduce,
    )
    text_encoder.eval()
    if fp8_llm:
        org_dtype = text_encoder.dtype
        logger.info(f"Moving and casting text encoder to {device} and torch.float8_e4m3fn")
        text_encoder.to(device=device, dtype=torch.float8_e4m3fn)

        # prepare LLM for fp8
        def prepare_fp8(llama_model: LlamaModel, target_dtype):
            # Replaces LlamaRMSNorm.forward with an fp32 reimplementation so the
            # normalization stays accurate despite fp8 weights elsewhere.
            def forward_hook(module):
                def forward(hidden_states):
                    input_dtype = hidden_states.dtype
                    hidden_states = hidden_states.to(torch.float32)
                    variance = hidden_states.pow(2).mean(-1, keepdim=True)
                    hidden_states = hidden_states * torch.rsqrt(variance + module.variance_epsilon)
                    return module.weight.to(input_dtype) * hidden_states.to(input_dtype)

                return forward

            for module in llama_model.modules():
                if module.__class__.__name__ in ["Embedding"]:
                    # keep embeddings in the original (higher-precision) dtype
                    # print("set", module.__class__.__name__, "to", target_dtype)
                    module.to(target_dtype)
                if module.__class__.__name__ in ["LlamaRMSNorm"]:
                    # print("set", module.__class__.__name__, "hooks")
                    module.forward = forward_hook(module)

        prepare_fp8(text_encoder.model, org_dtype)

    logger.info(f"loading text encoder 2: {args.text_encoder2}")
    text_encoder_2 = TextEncoder(
        text_encoder_type=text_encoder_2_type,
        max_length=text_len_2,
        text_encoder_dtype=text_encoder_dtype,
        text_encoder_path=args.text_encoder2,
        tokenizer_type=text_encoder_2_type,
        reproduce=reproduce,
    )
    text_encoder_2.eval()

    # encode prompt
    logger.info(f"Encoding prompt with text encoder 1")
    text_encoder.to(device=device)
    if fp8_llm:
        with accelerator.autocast():
            prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
    else:
        prompt_embeds, prompt_mask = encode_prompt(prompt, device, num_videos, text_encoder)
    # drop encoder 1 before loading work onto the device for encoder 2
    text_encoder = None
    clean_memory_on_device(device)

    logger.info(f"Encoding prompt with text encoder 2")
    text_encoder_2.to(device=device)
    prompt_embeds_2, prompt_mask_2 = encode_prompt(prompt, device, num_videos, text_encoder_2)

    # move results to CPU so the device is fully freed for the DiT
    prompt_embeds = prompt_embeds.to("cpu")
    prompt_mask = prompt_mask.to("cpu")
    prompt_embeds_2 = prompt_embeds_2.to("cpu")
    prompt_mask_2 = prompt_mask_2.to("cpu")

    text_encoder_2 = None
    clean_memory_on_device(device)

    return prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2
312
+
313
+
314
+ # endregion
315
+
316
+
317
def prepare_vae(args, device):
    """Load the VAE in eval mode and configure conv chunking / spatial tiling from args.

    Returns:
        Tuple (vae, vae_dtype).
    """
    vae_dtype = str_to_dtype(args.vae_dtype) if args.vae_dtype is not None else torch.float16
    vae, _, s_ratio, t_ratio = load_vae(vae_dtype=vae_dtype, device=device, vae_path=args.vae)
    vae.eval()
    # vae_kwargs = {"s_ratio": s_ratio, "t_ratio": t_ratio}

    # propagate chunk_size to every CausalConv3d inside the model
    chunk_size = args.vae_chunk_size
    if chunk_size is not None:
        vae.set_chunk_size_for_causal_conv_3d(chunk_size)
        logger.info(f"Set chunk_size to {chunk_size} for CausalConv3d")

    if args.vae_spatial_tile_sample_min_size is not None:
        vae.enable_spatial_tiling(True)
        vae.tile_sample_min_size = args.vae_spatial_tile_sample_min_size
        # latent tiles are 8x smaller spatially than sample tiles
        vae.tile_latent_min_size = args.vae_spatial_tile_sample_min_size // 8
    else:
        # tiling is always enabled to bound peak memory use
        vae.enable_spatial_tiling(True)

    return vae, vae_dtype
338
+
339
+
340
def encode_to_latents(args, video, device):
    """Encode a pixel video in [0, 1] into scaled VAE latents."""
    vae, vae_dtype = prepare_vae(args, device)

    video = video.to(device=device, dtype=vae_dtype)
    video = video * 2 - 1  # map [0, 1] -> [-1, 1] as the VAE expects
    with torch.no_grad():
        latents = vae.encode(video).latent_dist.sample()

    # apply the config's shift/scale normalization to the sampled latents
    shift = getattr(vae.config, "shift_factor", None)
    if shift:
        latents = (latents - shift) * vae.config.scaling_factor
    else:
        latents = latents * vae.config.scaling_factor

    return latents
354
+
355
+
356
def decode_latents(args, latents, device):
    """Decode VAE latents into pixel frames in [0, 1], returned as float32 on CPU."""
    vae, vae_dtype = prepare_vae(args, device)

    # accept 4D (image) latents by adding a singleton temporal axis
    squeeze_frames = False
    if len(latents.shape) == 4:
        latents = latents.unsqueeze(2)
        squeeze_frames = True
    elif len(latents.shape) != 5:
        raise ValueError(f"Only support latents with shape (b, c, h, w) or (b, c, f, h, w), but got {latents.shape}.")

    # undo the encode-side shift/scale normalization
    shift = getattr(vae.config, "shift_factor", None)
    if shift:
        latents = latents / vae.config.scaling_factor + shift
    else:
        latents = latents / vae.config.scaling_factor

    latents = latents.to(device=device, dtype=vae_dtype)
    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]

    if squeeze_frames:
        image = image.squeeze(2)

    image = (image / 2 + 0.5).clamp(0, 1)
    # cast to float32: negligible overhead and compatible with bfloat16 outputs
    image = image.cpu().float()

    return image
385
+
386
+
387
def parse_args():
    """Build and validate CLI arguments for HunyuanVideo inference.

    Returns:
        argparse.Namespace with all generation, model, LoRA, and performance
        options.

    Raises:
        ValueError: if --fp8_fast is requested without --fp8.
        AssertionError: if --latent_path is combined with a latent output type.
    """
    parser = argparse.ArgumentParser(description="HunyuanVideo inference script")

    parser.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
    parser.add_argument(
        "--dit_in_channels",
        type=int,
        default=None,
        help="input channels for DiT, default is None (automatically detect). 32 for SkyReels-I2V, 16 for others",
    )
    parser.add_argument("--vae", type=str, required=True, help="VAE checkpoint path or directory")
    parser.add_argument("--vae_dtype", type=str, default=None, help="data type for VAE, default is float16")
    parser.add_argument("--text_encoder1", type=str, required=True, help="Text Encoder 1 directory")
    parser.add_argument("--text_encoder2", type=str, required=True, help="Text Encoder 2 directory")

    # LoRA
    parser.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    # BUGFIX: default must be a list because nargs="*" produces a list and the
    # caller does len(args.lora_multiplier) / indexes it; a bare float default
    # raised TypeError when --lora_weight was given without --lora_multiplier.
    # This also matches merge_lora.py, which already uses default=[1.0].
    parser.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier")
    parser.add_argument(
        "--save_merged_model",
        type=str,
        default=None,
        help="Save merged model to path. If specified, no inference will be performed.",
    )
    parser.add_argument("--exclude_single_blocks", action="store_true", help="Exclude single blocks when loading LoRA weights")

    # inference
    parser.add_argument("--prompt", type=str, required=True, help="prompt for generation")
    parser.add_argument("--negative_prompt", type=str, default=None, help="negative prompt for generation")
    parser.add_argument("--video_size", type=int, nargs=2, default=[256, 256], help="video size")
    parser.add_argument("--video_length", type=int, default=129, help="video length")
    parser.add_argument("--fps", type=int, default=24, help="video fps")
    parser.add_argument("--infer_steps", type=int, default=50, help="number of inference steps")
    parser.add_argument("--save_path", type=str, required=True, help="path to save generated video")
    parser.add_argument("--seed", type=int, default=None, help="Seed for evaluation.")
    parser.add_argument(
        "--guidance_scale",
        type=float,
        default=1.0,
        help="Guidance scale for classifier free guidance. Default is 1.0 (means no guidance)",
    )
    parser.add_argument("--embedded_cfg_scale", type=float, default=6.0, help="Embeded classifier free guidance scale.")
    parser.add_argument("--video_path", type=str, default=None, help="path to video for video2video inference")
    parser.add_argument(
        "--image_path", type=str, default=None, help="path to image for image2video inference, only works for SkyReels-I2V model"
    )
    parser.add_argument(
        "--split_uncond",
        action="store_true",
        help="split unconditional call for classifier free guidance, slower but less memory usage",
    )
    parser.add_argument("--strength", type=float, default=0.8, help="strength for video2video inference")

    # Flow Matching
    parser.add_argument("--flow_shift", type=float, default=7.0, help="Shift factor for flow matching schedulers.")

    parser.add_argument("--fp8", action="store_true", help="use fp8 for DiT model")
    parser.add_argument("--fp8_llm", action="store_true", help="use fp8 for Text Encoder 1 (LLM)")
    parser.add_argument(
        "--device", type=str, default=None, help="device to use for inference. If None, use CUDA if available, otherwise use CPU"
    )
    parser.add_argument(
        "--attn_mode", type=str, default="torch", choices=["flash", "torch", "sageattn", "xformers", "sdpa"], help="attention mode"
    )
    parser.add_argument(
        "--split_attn", action="store_true", help="use split attention, default is False. if True, --split_uncond becomes True"
    )
    parser.add_argument("--vae_chunk_size", type=int, default=None, help="chunk size for CausalConv3d in VAE")
    parser.add_argument(
        "--vae_spatial_tile_sample_min_size", type=int, default=None, help="spatial tile sample min size for VAE, default 256"
    )
    parser.add_argument("--blocks_to_swap", type=int, default=None, help="number of blocks to swap in the model")
    parser.add_argument("--img_in_txt_in_offloading", action="store_true", help="offload img_in and txt_in to cpu")
    parser.add_argument(
        "--output_type", type=str, default="video", choices=["video", "images", "latent", "both"], help="output type"
    )
    parser.add_argument("--no_metadata", action="store_true", help="do not save metadata")
    parser.add_argument("--latent_path", type=str, nargs="*", default=None, help="path to latent for decode. no inference")
    parser.add_argument("--lycoris", action="store_true", help="use lycoris for inference")
    parser.add_argument("--fp8_fast", action="store_true", help="Enable fast FP8 arthimetic(RTX 4XXX+)")
    parser.add_argument("--compile", action="store_true", help="Enable torch.compile")
    parser.add_argument(
        "--compile_args",
        nargs=4,
        metavar=("BACKEND", "MODE", "DYNAMIC", "FULLGRAPH"),
        default=["inductor", "max-autotune-no-cudagraphs", "False", "False"],
        help="Torch.compile settings",
    )

    args = parser.parse_args()

    # --latent_path bypasses inference entirely, so only pixel outputs make sense.
    assert (args.latent_path is None or len(args.latent_path) == 0) or (
        args.output_type == "images" or args.output_type == "video"
    ), "latent_path is only supported for images or video output"

    if args.fp8_fast and not args.fp8:
        raise ValueError("--fp8_fast requires --fp8")

    return args
488
+
489
+
490
def check_inputs(args):
    """Validate the requested generation geometry.

    Returns:
        (height, width, video_length) taken from args.

    Raises:
        ValueError: if height or width is not a multiple of 8.
    """
    height, width = args.video_size
    video_length = args.video_length

    if height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
    return height, width, video_length
498
+
499
+
500
def main():
    """CLI entry point for HunyuanVideo inference.

    Two modes:
      * ``--latent_path`` given: skip inference and only decode/save the
        provided latents (seeds recovered from safetensors metadata).
      * otherwise: encode the prompt (and optionally a reference video for
        v2v or an image for i2v), denoise with the DiT, then save latents,
        video, or images according to ``--output_type``.
    """
    args = parse_args()

    # Device / dtype setup: compute in bf16; weights optionally stored as fp8.
    device = args.device if args.device is not None else "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    dit_dtype = torch.bfloat16
    dit_weight_dtype = torch.float8_e4m3fn if args.fp8 else dit_dtype
    logger.info(f"Using device: {device}, DiT precision: {dit_dtype}, weight precision: {dit_weight_dtype}")

    original_base_names = None
    if args.latent_path is not None and len(args.latent_path) > 0:
        # Decode-only mode: load pre-computed latents from disk.
        original_base_names = []
        latents_list = []
        seeds = []
        for latent_path in args.latent_path:
            original_base_names.append(os.path.splitext(os.path.basename(latent_path))[0])
            seed = 0  # fallback when the file carries no seed metadata

            if os.path.splitext(latent_path)[1] != ".safetensors":
                latents = torch.load(latent_path, map_location="cpu")
            else:
                latents = load_file(latent_path)["latent"]
                with safe_open(latent_path, framework="pt") as f:
                    metadata = f.metadata()
                if metadata is None:
                    metadata = {}
                logger.info(f"Loaded metadata: {metadata}")

                if "seeds" in metadata:
                    seed = int(metadata["seeds"])

            seeds.append(seed)
            latents_list.append(latents)

            logger.info(f"Loaded latent from {latent_path}. Shape: {latents.shape}")
        latents = torch.stack(latents_list, dim=0)
    else:
        # prepare accelerator (provides autocast for the denoising loop)
        mixed_precision = "bf16" if dit_dtype == torch.bfloat16 else "fp16"
        accelerator = accelerate.Accelerator(mixed_precision=mixed_precision)

        # load prompt
        prompt = args.prompt  # TODO load prompts from file
        assert prompt is not None, "prompt is required"

        # check inputs: may be height, width, video_length etc will be changed for each generation in future
        height, width, video_length = check_inputs(args)

        # encode prompt with LLM and Text Encoder
        logger.info(f"Encoding prompt: {prompt}")

        do_classifier_free_guidance = args.guidance_scale != 1.0
        if do_classifier_free_guidance:
            # CFG batches [negative, positive] prompts together.
            negative_prompt = args.negative_prompt
            if negative_prompt is None:
                logger.info("Negative prompt is not provided, using empty prompt")
                negative_prompt = ""
            logger.info(f"Encoding negative prompt: {negative_prompt}")
            prompt = [negative_prompt, prompt]
        else:
            if args.negative_prompt is not None:
                logger.warning("Negative prompt is provided but guidance_scale is 1.0, negative prompt will be ignored.")

        prompt_embeds, prompt_mask, prompt_embeds_2, prompt_mask_2 = encode_input_prompt(
            prompt, args, device, args.fp8_llm, accelerator
        )

        # encode latents for video2video inference
        video_latents = None
        if args.video_path is not None:
            # v2v inference
            logger.info(f"Video2Video inference: {args.video_path}")
            video = load_video(args.video_path, 0, video_length, bucket_reso=(width, height))  # list of frames
            if len(video) < video_length:
                raise ValueError(f"Video length is less than {video_length}")
            video = np.stack(video, axis=0)  # F, H, W, C
            video = torch.from_numpy(video).permute(3, 0, 1, 2).unsqueeze(0).float()  # 1, C, F, H, W
            video = video / 255.0  # uint8 range -> [0, 1]

            logger.info(f"Encoding video to latents")
            video_latents = encode_to_latents(args, video, device)
            video_latents = video_latents.to(device=device, dtype=dit_dtype)

            clean_memory_on_device(device)

        # encode latents for image2video inference
        image_latents = None
        if args.image_path is not None:
            # i2v inference
            logger.info(f"Image2Video inference: {args.image_path}")

            image = Image.open(args.image_path)
            image = resize_image_to_bucket(image, (width, height))  # returns a numpy array
            image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).unsqueeze(2).float()  # 1, C, 1, H, W
            image = image / 255.0

            logger.info(f"Encoding image to latents")
            image_latents = encode_to_latents(args, image, device)  # 1, C, 1, H, W
            image_latents = image_latents.to(device=device, dtype=dit_dtype)

            clean_memory_on_device(device)

        # load DiT model
        blocks_to_swap = args.blocks_to_swap if args.blocks_to_swap else 0
        loading_device = "cpu"  # if blocks_to_swap > 0 else device

        logger.info(f"Loading DiT model from {args.dit}")
        if args.attn_mode == "sdpa":
            args.attn_mode = "torch"  # "sdpa" is an alias for the torch backend

        # if image_latents is given, the model should be I2V model, so the in_channels should be 32
        dit_in_channels = args.dit_in_channels if args.dit_in_channels is not None else (32 if image_latents is not None else 16)

        # if we use LoRA, weights should be bf16 instead of fp8, because merging should be done in bf16
        # the model is too large, so we load the model to cpu. in addition, the .pt file is loaded to cpu anyway
        # on the fly merging will be a solution for this issue for .safetensors files (not implemented yet)
        transformer = load_transformer(
            args.dit, args.attn_mode, args.split_attn, loading_device, dit_dtype, in_channels=dit_in_channels
        )
        transformer.eval()

        # load LoRA weights and merge them into the DiT before any dtype cast
        if args.lora_weight is not None and len(args.lora_weight) > 0:
            for i, lora_weight in enumerate(args.lora_weight):
                # pair each weight file with its multiplier; default 1.0
                if args.lora_multiplier is not None and len(args.lora_multiplier) > i:
                    lora_multiplier = args.lora_multiplier[i]
                else:
                    lora_multiplier = 1.0

                logger.info(f"Loading LoRA weights from {lora_weight} with multiplier {lora_multiplier}")
                weights_sd = load_file(lora_weight)

                # Filter to exclude keys that are part of single_blocks
                if args.exclude_single_blocks:
                    filtered_weights = {k: v for k, v in weights_sd.items() if "single_blocks" not in k}
                    weights_sd = filtered_weights

                if args.lycoris:
                    lycoris_net, _ = create_network_from_weights(
                        multiplier=lora_multiplier,
                        file=None,
                        weights_sd=weights_sd,
                        unet=transformer,
                        text_encoder=None,
                        vae=None,
                        for_inference=True,
                    )
                else:
                    network = lora.create_arch_network_from_weights(
                        lora_multiplier, weights_sd, unet=transformer, for_inference=True
                    )
                logger.info("Merging LoRA weights to DiT model")

                if args.lycoris:
                    lycoris_net.merge_to(None, transformer, weights_sd, dtype=None, device=device)
                else:
                    network.merge_to(None, transformer, weights_sd, device=device, non_blocking=True)

                synchronize_device(device)

                logger.info("LoRA weights loaded")

            # save model here before casting to dit_weight_dtype
            if args.save_merged_model:
                logger.info(f"Saving merged model to {args.save_merged_model}")
                mem_eff_save_file(transformer.state_dict(), args.save_merged_model)  # save_file needs a lot of memory
                logger.info("Merged model saved")
                return

        logger.info(f"Casting model to {dit_weight_dtype}")
        transformer.to(dtype=dit_weight_dtype)

        if args.fp8_fast:
            # keep numerically sensitive modules in bf16; the rest stays fp8
            logger.info("Enabling FP8 acceleration")
            params_to_keep = {"norm", "bias", "time_in", "vector_in", "guidance_in", "txt_in", "img_in"}
            for name, param in transformer.named_parameters():
                dtype_to_use = dit_dtype if any(keyword in name for keyword in params_to_keep) else dit_weight_dtype
                param.to(dtype=dtype_to_use)
            convert_fp8_linear(transformer, dit_dtype, params_to_keep=params_to_keep)

        if args.compile:
            compile_backend, compile_mode, compile_dynamic, compile_fullgraph = args.compile_args
            logger.info(
                f"Torch Compiling[Backend: {compile_backend}; Mode: {compile_mode}; Dynamic: {compile_dynamic}; Fullgraph: {compile_fullgraph}]"
            )
            torch._dynamo.config.cache_size_limit = 32
            # NOTE(review): `x.lower() in "true"` is a substring test, not a bool
            # parse — e.g. "t" would also be truthy. Confirm intended semantics.
            for i, block in enumerate(transformer.single_blocks):
                compiled_block = torch.compile(
                    block,
                    backend=compile_backend,
                    mode=compile_mode,
                    dynamic=compile_dynamic.lower() in "true",
                    fullgraph=compile_fullgraph.lower() in "true",
                )
                transformer.single_blocks[i] = compiled_block
            for i, block in enumerate(transformer.double_blocks):
                compiled_block = torch.compile(
                    block,
                    backend=compile_backend,
                    mode=compile_mode,
                    dynamic=compile_dynamic.lower() in "true",
                    fullgraph=compile_fullgraph.lower() in "true",
                )
                transformer.double_blocks[i] = compiled_block

        if blocks_to_swap > 0:
            logger.info(f"Enable swap {blocks_to_swap} blocks to CPU from device: {device}")
            transformer.enable_block_swap(blocks_to_swap, device, supports_backward=False)
            transformer.move_to_device_except_swap_blocks(device)
            transformer.prepare_block_swap_before_forward()
        else:
            logger.info(f"Moving model to {device}")
            transformer.to(device=device)
        if args.img_in_txt_in_offloading:
            logger.info("Enable offloading img_in and txt_in to CPU")
            transformer.enable_img_in_txt_in_offloading()

        # load scheduler
        logger.info(f"Loading scheduler")
        scheduler = FlowMatchDiscreteScheduler(shift=args.flow_shift, reverse=True, solver="euler")

        # Prepare timesteps
        num_inference_steps = args.infer_steps
        scheduler.set_timesteps(num_inference_steps, device=device)  # n_tokens is not used in FlowMatchDiscreteScheduler
        timesteps = scheduler.timesteps

        # Prepare generator
        num_videos_per_prompt = 1  # currently only support 1 video per prompt, this is a batch size
        seed = args.seed
        if seed is None:
            seeds = [random.randint(0, 2**32 - 1) for _ in range(num_videos_per_prompt)]
        elif isinstance(seed, int):
            seeds = [seed + i for i in range(num_videos_per_prompt)]
        else:
            raise ValueError(f"Seed must be an integer or None, got {seed}.")
        generator = [torch.Generator(device).manual_seed(seed) for seed in seeds]

        # Prepare noisy latents
        num_channels_latents = 16  # transformer.config.in_channels
        vae_scale_factor = 2 ** (4 - 1)  # len(self.vae.config.block_out_channels) == 4

        # NOTE(review): `vae` here is presumably the VAE *module* imported at
        # file top (VAE_VER is a module-level constant) — confirm upstream.
        vae_ver = vae.VAE_VER
        if "884" in vae_ver:
            latent_video_length = (video_length - 1) // 4 + 1
        elif "888" in vae_ver:
            latent_video_length = (video_length - 1) // 8 + 1
        else:
            latent_video_length = video_length

        # make first N frames to be the same if the given seed is same: noise is
        # drawn frame-by-frame so a shorter video shares its prefix with a longer one
        shape_of_frame = (num_videos_per_prompt, num_channels_latents, 1, height // vae_scale_factor, width // vae_scale_factor)
        latents = []
        for i in range(latent_video_length):
            latents.append(randn_tensor(shape_of_frame, generator=generator, device=device, dtype=dit_dtype))
        latents = torch.cat(latents, dim=2)

        # pad image_latents to match the length of video_latents (i2v: only frame 0 is real)
        if image_latents is not None:
            zero_latents = torch.zeros_like(latents)
            zero_latents[:, :, :1, :, :] = image_latents
            image_latents = zero_latents

        if args.video_path is not None:
            # v2v inference: mix noise and source latents, then denoise the tail
            # of the schedule only (larger strength => more steps, more change).
            noise = latents
            assert noise.shape == video_latents.shape, f"noise shape {noise.shape} != video_latents shape {video_latents.shape}"

            num_inference_steps = int(num_inference_steps * args.strength)
            timestep_start = scheduler.timesteps[-num_inference_steps]  # larger strength, less inference steps and more start time
            t = timestep_start / 1000.0
            latents = noise * t + video_latents * (1 - t)

            timesteps = timesteps[-num_inference_steps:]

            logger.info(f"strength: {args.strength}, num_inference_steps: {num_inference_steps}, timestep_start: {timestep_start}")

        # FlowMatchDiscreteScheduler does not have init_noise_sigma

        # Denoising loop
        embedded_guidance_scale = args.embedded_cfg_scale
        if embedded_guidance_scale is not None:
            guidance_expand = torch.tensor([embedded_guidance_scale * 1000.0] * latents.shape[0], dtype=torch.float32, device="cpu")
            guidance_expand = guidance_expand.to(device=device, dtype=dit_dtype)
            if do_classifier_free_guidance:
                guidance_expand = torch.cat([guidance_expand, guidance_expand], dim=0)
        else:
            # NOTE(review): the transformer call below slices guidance_expand
            # unconditionally; None only works if embedded_cfg_scale is never
            # None in practice (its argparse default is 6.0) — confirm.
            guidance_expand = None
        freqs_cos, freqs_sin = get_rotary_pos_embed(vae_ver, transformer, video_length, height, width)

        # move and cast all inputs to the correct device and dtype
        prompt_embeds = prompt_embeds.to(device=device, dtype=dit_dtype)
        prompt_mask = prompt_mask.to(device=device)
        prompt_embeds_2 = prompt_embeds_2.to(device=device, dtype=dit_dtype)
        prompt_mask_2 = prompt_mask_2.to(device=device)

        freqs_cos = freqs_cos.to(device=device, dtype=dit_dtype)
        freqs_sin = freqs_sin.to(device=device, dtype=dit_dtype)

        num_warmup_steps = len(timesteps) - num_inference_steps * scheduler.order  # this should be 0 in v2v inference

        # split_attn implies split_uncond: attention is computed per sample
        if args.split_attn and do_classifier_free_guidance and not args.split_uncond:
            logger.warning("split_attn is enabled, split_uncond will be enabled as well.")
            args.split_uncond = True

        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = scheduler.scale_model_input(latents, t)

                # predict the noise residual
                with torch.no_grad(), accelerator.autocast():
                    latents_input = latents if not do_classifier_free_guidance else torch.cat([latents, latents], dim=0)
                    if image_latents is not None:
                        latents_image_input = (
                            image_latents if not do_classifier_free_guidance else torch.cat([image_latents, image_latents], dim=0)
                        )
                        latents_input = torch.cat([latents_input, latents_image_input], dim=1)  # 1 or 2, C*2, F, H, W

                    # split_uncond runs cond/uncond one at a time to save memory
                    batch_size = 1 if args.split_uncond else latents_input.shape[0]

                    noise_pred_list = []
                    for j in range(0, latents_input.shape[0], batch_size):
                        noise_pred = transformer(
                            latents_input[j : j + batch_size],
                            t.repeat(batch_size).to(device=device, dtype=dit_dtype),
                            text_states=prompt_embeds[j : j + batch_size],
                            text_mask=prompt_mask[j : j + batch_size],
                            text_states_2=prompt_embeds_2[j : j + batch_size],
                            freqs_cos=freqs_cos,  # [seqlen, head_dim]
                            freqs_sin=freqs_sin,  # [seqlen, head_dim]
                            guidance=guidance_expand[j : j + batch_size],
                            return_dict=True,
                        )["x"]
                        noise_pred_list.append(noise_pred)
                    noise_pred = torch.cat(noise_pred_list, dim=0)

                # perform classifier free guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # SkyReels' rescale noise config is omitted for now

                # compute the previous noisy sample x_t -> x_t-1
                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

                # update progress bar
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % scheduler.order == 0):
                    if progress_bar is not None:
                        progress_bar.update()

        # free the transformer before decoding to keep peak memory down
        latents = latents.detach().cpu()
        transformer = None
        clean_memory_on_device(device)

    # Save samples
    output_type = args.output_type
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    time_flag = datetime.fromtimestamp(time.time()).strftime("%Y%m%d-%H%M%S")

    if output_type == "latent" or output_type == "both":
        # save latent (one .safetensors per batch item, with generation metadata)
        for i, latent in enumerate(latents):
            latent_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}_latent.safetensors"

            if args.no_metadata:
                metadata = None
            else:
                metadata = {
                    "seeds": f"{seeds[i]}",
                    "prompt": f"{args.prompt}",
                    "height": f"{height}",
                    "width": f"{width}",
                    "video_length": f"{video_length}",
                    "infer_steps": f"{num_inference_steps}",
                    "guidance_scale": f"{args.guidance_scale}",
                    "embedded_cfg_scale": f"{args.embedded_cfg_scale}",
                }
                if args.negative_prompt is not None:
                    metadata["negative_prompt"] = f"{args.negative_prompt}"
            sd = {"latent": latent}
            save_file(sd, latent_path, metadata=metadata)

            logger.info(f"Latent save to: {latent_path}")
    if output_type == "video" or output_type == "both":
        # save video
        videos = decode_latents(args, latents, device)
        for i, sample in enumerate(videos):
            original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
            sample = sample.unsqueeze(0)
            video_path = f"{save_path}/{time_flag}_{i}_{seeds[i]}{original_name}.mp4"
            save_videos_grid(sample, video_path, fps=args.fps)
            logger.info(f"Sample save to: {video_path}")
    elif output_type == "images":
        # save images
        videos = decode_latents(args, latents, device)
        for i, sample in enumerate(videos):
            original_name = "" if original_base_names is None else f"_{original_base_names[i]}"
            sample = sample.unsqueeze(0)
            image_name = f"{time_flag}_{i}_{seeds[i]}{original_name}"
            save_images_grid(sample, save_path, image_name)
            logger.info(f"Sample images save to: {save_path}/{image_name}")

    logger.info("Done!")
933
+
934
+
935
# CLI entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()
merge_lora.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import torch
4
+ from safetensors.torch import load_file
5
+ from networks import lora
6
+ from utils.safetensors_utils import mem_eff_save_file
7
+ from hunyuan_model.models import load_transformer
8
+
9
# Module-level logger; basicConfig configures root logging for CLI use.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
11
+
12
+
13
def parse_args():
    """Parse CLI arguments for merging LoRA weights into a DiT checkpoint.

    Returns:
        argparse.Namespace with checkpoint paths, per-LoRA multipliers, and
        the merge device.
    """
    ap = argparse.ArgumentParser(description="HunyuanVideo model merger script")

    ap.add_argument("--dit", type=str, required=True, help="DiT checkpoint path or directory")
    ap.add_argument("--dit_in_channels", type=int, default=16, help="input channels for DiT, default is 16, skyreels I2V is 32")
    ap.add_argument("--lora_weight", type=str, nargs="*", required=False, default=None, help="LoRA weight path")
    ap.add_argument("--lora_multiplier", type=float, nargs="*", default=[1.0], help="LoRA multiplier (can specify multiple values)")
    ap.add_argument("--save_merged_model", type=str, required=True, help="Path to save the merged model")
    ap.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use for merging")

    return ap.parse_args()
24
+
25
+
26
def main():
    """Load a DiT checkpoint, merge LoRA weight files into it, save the result."""
    args = parse_args()

    device = torch.device(args.device)
    logger.info(f"Using device: {device}")

    # Base model is loaded on CPU in bf16; LoRA merging runs on `device`.
    logger.info(f"Loading DiT model from {args.dit}")
    transformer = load_transformer(args.dit, "torch", False, "cpu", torch.bfloat16, in_channels=args.dit_in_channels)
    transformer.eval()

    # Merge every LoRA file, pairing each with its multiplier (default 1.0).
    if args.lora_weight:
        multipliers = args.lora_multiplier
        for idx, weight_path in enumerate(args.lora_weight):
            multiplier = multipliers[idx] if multipliers is not None and idx < len(multipliers) else 1.0

            logger.info(f"Loading LoRA weights from {weight_path} with multiplier {multiplier}")
            state_dict = load_file(weight_path)
            network = lora.create_arch_network_from_weights(
                multiplier, state_dict, unet=transformer, for_inference=True
            )
            logger.info("Merging LoRA weights to DiT model")
            network.merge_to(None, transformer, state_dict, device=device, non_blocking=True)

            logger.info("LoRA weights loaded")

    # Persist the merged weights (memory-efficient writer).
    logger.info(f"Saving merged model to {args.save_merged_model}")
    mem_eff_save_file(transformer.state_dict(), args.save_merged_model)
    logger.info("Merged model saved")
61
+
62
# Script entry point: run the merge only when executed directly.
if __name__ == "__main__":
    main()
modules/__init__.py ADDED
File without changes