Alexander Bagus commited on
Commit
95ee7de
·
1 Parent(s): cca610a
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. videox_fun/__init__.py +0 -0
  2. videox_fun/api/api.py +0 -226
  3. videox_fun/api/api_multi_nodes.py +0 -320
  4. videox_fun/data/__init__.py +0 -9
  5. videox_fun/data/bucket_sampler.py +0 -379
  6. videox_fun/data/dataset_image.py +0 -191
  7. videox_fun/data/dataset_image_video.py +0 -657
  8. videox_fun/data/dataset_video.py +0 -901
  9. videox_fun/data/utils.py +0 -347
  10. videox_fun/dist/__init__.py +0 -72
  11. videox_fun/dist/cogvideox_xfuser.py +0 -93
  12. videox_fun/dist/flux2_xfuser.py +0 -194
  13. videox_fun/dist/flux_xfuser.py +0 -165
  14. videox_fun/dist/fsdp.py +0 -44
  15. videox_fun/dist/fuser.py +0 -87
  16. videox_fun/dist/hunyuanvideo_xfuser.py +0 -166
  17. videox_fun/dist/qwen_xfuser.py +0 -176
  18. videox_fun/dist/wan_xfuser.py +0 -180
  19. videox_fun/dist/z_image_xfuser.py +0 -85
  20. videox_fun/models/__init__.py +0 -131
  21. videox_fun/models/attention_utils.py +0 -211
  22. videox_fun/models/cache_utils.py +0 -80
  23. videox_fun/models/cogvideox_transformer3d.py +0 -915
  24. videox_fun/models/cogvideox_vae.py +0 -1675
  25. videox_fun/models/fantasytalking_audio_encoder.py +0 -52
  26. videox_fun/models/fantasytalking_transformer3d.py +0 -644
  27. videox_fun/models/flux2_image_processor.py +0 -139
  28. videox_fun/models/flux2_transformer2d.py +0 -1278
  29. videox_fun/models/flux2_transformer2d_control.py +0 -312
  30. videox_fun/models/flux2_vae.py +0 -543
  31. videox_fun/models/flux_transformer2d.py +0 -832
  32. videox_fun/models/hunyuanvideo_transformer3d.py +0 -1478
  33. videox_fun/models/hunyuanvideo_vae.py +0 -1082
  34. videox_fun/models/qwenimage_transformer2d.py +0 -1118
  35. videox_fun/models/qwenimage_vae.py +0 -1087
  36. videox_fun/models/wan_animate_adapter.py +0 -397
  37. videox_fun/models/wan_animate_motion_encoder.py +0 -309
  38. videox_fun/models/wan_audio_encoder.py +0 -213
  39. videox_fun/models/wan_audio_injector.py +0 -1093
  40. videox_fun/models/wan_camera_adapter.py +0 -64
  41. videox_fun/models/wan_image_encoder.py +0 -553
  42. videox_fun/models/wan_text_encoder.py +0 -395
  43. videox_fun/models/wan_transformer3d.py +0 -1394
  44. videox_fun/models/wan_transformer3d_animate.py +0 -302
  45. videox_fun/models/wan_transformer3d_s2v.py +0 -932
  46. videox_fun/models/wan_transformer3d_vace.py +0 -394
  47. videox_fun/models/wan_vae.py +0 -860
  48. videox_fun/models/wan_vae3_8.py +0 -1091
  49. videox_fun/models/wan_xlm_roberta.py +0 -170
  50. videox_fun/models/z_image_transformer2d.py +0 -1050
videox_fun/__init__.py DELETED
File without changes
videox_fun/api/api.py DELETED
@@ -1,226 +0,0 @@
1
- import base64
2
- import gc
3
- import hashlib
4
- import io
5
- import os
6
- import tempfile
7
- from io import BytesIO
8
-
9
- import gradio as gr
10
- import requests
11
- import torch
12
- from fastapi import FastAPI
13
- from PIL import Image
14
-
15
-
16
- # Function to encode a file to Base64
17
- def encode_file_to_base64(file_path):
18
- with open(file_path, "rb") as file:
19
- # Encode the data to Base64
20
- file_base64 = base64.b64encode(file.read())
21
- return file_base64
22
-
23
- def update_diffusion_transformer_api(_: gr.Blocks, app: FastAPI, controller):
24
- @app.post("/videox_fun/update_diffusion_transformer")
25
- def _update_diffusion_transformer_api(
26
- datas: dict,
27
- ):
28
- diffusion_transformer_path = datas.get('diffusion_transformer_path', 'none')
29
-
30
- try:
31
- controller.update_diffusion_transformer(
32
- diffusion_transformer_path
33
- )
34
- comment = "Success"
35
- except Exception as e:
36
- torch.cuda.empty_cache()
37
- comment = f"Error. error information is {str(e)}"
38
-
39
- return {"message": comment}
40
-
41
- def download_from_url(url, timeout=10):
42
- try:
43
- response = requests.get(url, timeout=timeout)
44
- response.raise_for_status() # 检查请求是否成功
45
- return response.content
46
- except requests.exceptions.RequestException as e:
47
- print(f"Error downloading from {url}: {e}")
48
- return None
49
-
50
- def save_base64_video(base64_string):
51
- video_data = base64.b64decode(base64_string)
52
-
53
- md5_hash = hashlib.md5(video_data).hexdigest()
54
- filename = f"{md5_hash}.mp4"
55
-
56
- temp_dir = tempfile.gettempdir()
57
- file_path = os.path.join(temp_dir, filename)
58
-
59
- with open(file_path, 'wb') as video_file:
60
- video_file.write(video_data)
61
-
62
- return file_path
63
-
64
- def save_base64_image(base64_string):
65
- video_data = base64.b64decode(base64_string)
66
-
67
- md5_hash = hashlib.md5(video_data).hexdigest()
68
- filename = f"{md5_hash}.jpg"
69
-
70
- temp_dir = tempfile.gettempdir()
71
- file_path = os.path.join(temp_dir, filename)
72
-
73
- with open(file_path, 'wb') as video_file:
74
- video_file.write(video_data)
75
-
76
- return file_path
77
-
78
- def save_url_video(url):
79
- video_data = download_from_url(url)
80
- if video_data:
81
- return save_base64_video(base64.b64encode(video_data))
82
- return None
83
-
84
- def save_url_image(url):
85
- image_data = download_from_url(url)
86
- if image_data:
87
- return save_base64_image(base64.b64encode(image_data))
88
- return None
89
-
90
- def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
91
- @app.post("/videox_fun/infer_forward")
92
- def _infer_forward_api(
93
- datas: dict,
94
- ):
95
- base_model_path = datas.get('base_model_path', 'none')
96
- base_model_2_path = datas.get('base_model_2_path', 'none')
97
- lora_model_path = datas.get('lora_model_path', 'none')
98
- lora_model_2_path = datas.get('lora_model_2_path', 'none')
99
- lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
100
- prompt_textbox = datas.get('prompt_textbox', None)
101
- negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
102
- sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
103
- sample_step_slider = datas.get('sample_step_slider', 30)
104
- resize_method = datas.get('resize_method', "Generate by")
105
- width_slider = datas.get('width_slider', 672)
106
- height_slider = datas.get('height_slider', 384)
107
- base_resolution = datas.get('base_resolution', 512)
108
- is_image = datas.get('is_image', False)
109
- generation_method = datas.get('generation_method', False)
110
- length_slider = datas.get('length_slider', 49)
111
- overlap_video_length = datas.get('overlap_video_length', 4)
112
- partial_video_length = datas.get('partial_video_length', 72)
113
- cfg_scale_slider = datas.get('cfg_scale_slider', 6)
114
- start_image = datas.get('start_image', None)
115
- end_image = datas.get('end_image', None)
116
- validation_video = datas.get('validation_video', None)
117
- validation_video_mask = datas.get('validation_video_mask', None)
118
- control_video = datas.get('control_video', None)
119
- denoise_strength = datas.get('denoise_strength', 0.70)
120
- seed_textbox = datas.get("seed_textbox", 43)
121
-
122
- ref_image = datas.get('ref_image', None)
123
- enable_teacache = datas.get('enable_teacache', True)
124
- teacache_threshold = datas.get('teacache_threshold', 0.10)
125
- num_skip_start_steps = datas.get('num_skip_start_steps', 1)
126
- teacache_offload = datas.get('teacache_offload', False)
127
- cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
128
- enable_riflex = datas.get('enable_riflex', False)
129
- riflex_k = datas.get('riflex_k', 6)
130
- fps = datas.get('fps', None)
131
-
132
- generation_method = "Image Generation" if is_image else generation_method
133
-
134
- if start_image is not None:
135
- if start_image.startswith('http'):
136
- start_image = save_url_image(start_image)
137
- start_image = [Image.open(start_image).convert("RGB")]
138
- else:
139
- start_image = base64.b64decode(start_image)
140
- start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
141
-
142
- if end_image is not None:
143
- if end_image.startswith('http'):
144
- end_image = save_url_image(end_image)
145
- end_image = [Image.open(end_image).convert("RGB")]
146
- else:
147
- end_image = base64.b64decode(end_image)
148
- end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
149
-
150
- if validation_video is not None:
151
- if validation_video.startswith('http'):
152
- validation_video = save_url_video(validation_video)
153
- else:
154
- validation_video = save_base64_video(validation_video)
155
-
156
- if validation_video_mask is not None:
157
- if validation_video_mask.startswith('http'):
158
- validation_video_mask = save_url_image(validation_video_mask)
159
- else:
160
- validation_video_mask = save_base64_image(validation_video_mask)
161
-
162
- if control_video is not None:
163
- if control_video.startswith('http'):
164
- control_video = save_url_video(control_video)
165
- else:
166
- control_video = save_base64_video(control_video)
167
-
168
- if ref_image is not None:
169
- if ref_image.startswith('http'):
170
- ref_image = save_url_image(ref_image)
171
- ref_image = [Image.open(ref_image).convert("RGB")]
172
- else:
173
- ref_image = base64.b64decode(ref_image)
174
- ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
175
-
176
- try:
177
- save_sample_path, comment = controller.generate(
178
- "",
179
- base_model_path,
180
- lora_model_path,
181
- lora_alpha_slider,
182
- prompt_textbox,
183
- negative_prompt_textbox,
184
- sampler_dropdown,
185
- sample_step_slider,
186
- resize_method,
187
- width_slider,
188
- height_slider,
189
- base_resolution,
190
- generation_method,
191
- length_slider,
192
- overlap_video_length,
193
- partial_video_length,
194
- cfg_scale_slider,
195
- start_image,
196
- end_image,
197
- validation_video,
198
- validation_video_mask,
199
- control_video,
200
- denoise_strength,
201
- seed_textbox,
202
- ref_image = ref_image,
203
- enable_teacache = enable_teacache,
204
- teacache_threshold = teacache_threshold,
205
- num_skip_start_steps = num_skip_start_steps,
206
- teacache_offload = teacache_offload,
207
- cfg_skip_ratio = cfg_skip_ratio,
208
- enable_riflex = enable_riflex,
209
- riflex_k = riflex_k,
210
- base_model_2_dropdown = base_model_2_path,
211
- lora_model_2_dropdown = lora_model_2_path,
212
- fps = fps,
213
- is_api = True,
214
- )
215
- except Exception as e:
216
- gc.collect()
217
- torch.cuda.empty_cache()
218
- torch.cuda.ipc_collect()
219
- save_sample_path = ""
220
- comment = f"Error. error information is {str(e)}"
221
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
222
-
223
- if save_sample_path != "":
224
- return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
225
- else:
226
- return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": None}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
videox_fun/api/api_multi_nodes.py DELETED
@@ -1,320 +0,0 @@
1
- # This file is modified from https://github.com/xdit-project/xDiT/blob/main/entrypoints/launch.py
2
- import base64
3
- import gc
4
- import hashlib
5
- import io
6
- import os
7
- import tempfile
8
- from io import BytesIO
9
-
10
- import gradio as gr
11
- import requests
12
- import torch
13
- import torch.distributed as dist
14
- from fastapi import FastAPI, HTTPException
15
- from PIL import Image
16
-
17
- from .api import download_from_url, encode_file_to_base64
18
-
19
- try:
20
- import ray
21
- except:
22
- print("Ray is not installed. If you want to use multi gpus api. Please install it by running 'pip install ray'.")
23
- ray = None
24
-
25
- def save_base64_video_dist(base64_string):
26
- video_data = base64.b64decode(base64_string)
27
-
28
- md5_hash = hashlib.md5(video_data).hexdigest()
29
- filename = f"{md5_hash}.mp4"
30
-
31
- temp_dir = tempfile.gettempdir()
32
- file_path = os.path.join(temp_dir, filename)
33
-
34
- if dist.is_initialized():
35
- if dist.get_rank() == 0:
36
- with open(file_path, 'wb') as video_file:
37
- video_file.write(video_data)
38
- dist.barrier()
39
- else:
40
- with open(file_path, 'wb') as video_file:
41
- video_file.write(video_data)
42
- return file_path
43
-
44
- def save_base64_image_dist(base64_string):
45
- video_data = base64.b64decode(base64_string)
46
-
47
- md5_hash = hashlib.md5(video_data).hexdigest()
48
- filename = f"{md5_hash}.jpg"
49
-
50
- temp_dir = tempfile.gettempdir()
51
- file_path = os.path.join(temp_dir, filename)
52
-
53
- if dist.is_initialized():
54
- if dist.get_rank() == 0:
55
- with open(file_path, 'wb') as video_file:
56
- video_file.write(video_data)
57
- dist.barrier()
58
- else:
59
- with open(file_path, 'wb') as video_file:
60
- video_file.write(video_data)
61
- return file_path
62
-
63
- def save_url_video_dist(url):
64
- video_data = download_from_url(url)
65
- if video_data:
66
- return save_base64_video_dist(base64.b64encode(video_data))
67
- return None
68
-
69
- def save_url_image_dist(url):
70
- image_data = download_from_url(url)
71
- if image_data:
72
- return save_base64_image_dist(base64.b64encode(image_data))
73
- return None
74
-
75
- if ray is not None:
76
- @ray.remote(num_gpus=1)
77
- class MultiNodesGenerator:
78
- def __init__(
79
- self, rank: int, world_size: int, Controller,
80
- GPU_memory_mode, scheduler_dict, model_name=None, model_type="Inpaint",
81
- config_path=None, ulysses_degree=1, ring_degree=1,
82
- fsdp_dit=False, fsdp_text_encoder=False, compile_dit=False,
83
- weight_dtype=None, savedir_sample=None,
84
- ):
85
- # Set PyTorch distributed environment variables
86
- os.environ["RANK"] = str(rank)
87
- os.environ["WORLD_SIZE"] = str(world_size)
88
- os.environ["MASTER_ADDR"] = "127.0.0.1"
89
- os.environ["MASTER_PORT"] = "29500"
90
-
91
- self.rank = rank
92
- self.controller = Controller(
93
- GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
94
- ulysses_degree=ulysses_degree, ring_degree=ring_degree,
95
- fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
96
- weight_dtype=weight_dtype, savedir_sample=savedir_sample,
97
- )
98
-
99
- def generate(self, datas):
100
- try:
101
- base_model_path = datas.get('base_model_path', 'none')
102
- base_model_2_path = datas.get('base_model_2_path', 'none')
103
- lora_model_path = datas.get('lora_model_path', 'none')
104
- lora_model_2_path = datas.get('lora_model_2_path', 'none')
105
- lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
106
- prompt_textbox = datas.get('prompt_textbox', None)
107
- negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution. Watermark present in each frame. The background is solid. Strange body and strange trajectory. Distortion. ')
108
- sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
109
- sample_step_slider = datas.get('sample_step_slider', 30)
110
- resize_method = datas.get('resize_method', "Generate by")
111
- width_slider = datas.get('width_slider', 672)
112
- height_slider = datas.get('height_slider', 384)
113
- base_resolution = datas.get('base_resolution', 512)
114
- is_image = datas.get('is_image', False)
115
- generation_method = datas.get('generation_method', False)
116
- length_slider = datas.get('length_slider', 49)
117
- overlap_video_length = datas.get('overlap_video_length', 4)
118
- partial_video_length = datas.get('partial_video_length', 72)
119
- cfg_scale_slider = datas.get('cfg_scale_slider', 6)
120
- start_image = datas.get('start_image', None)
121
- end_image = datas.get('end_image', None)
122
- validation_video = datas.get('validation_video', None)
123
- validation_video_mask = datas.get('validation_video_mask', None)
124
- control_video = datas.get('control_video', None)
125
- denoise_strength = datas.get('denoise_strength', 0.70)
126
- seed_textbox = datas.get("seed_textbox", 43)
127
-
128
- ref_image = datas.get('ref_image', None)
129
- enable_teacache = datas.get('enable_teacache', True)
130
- teacache_threshold = datas.get('teacache_threshold', 0.10)
131
- num_skip_start_steps = datas.get('num_skip_start_steps', 1)
132
- teacache_offload = datas.get('teacache_offload', False)
133
- cfg_skip_ratio = datas.get('cfg_skip_ratio', 0)
134
- enable_riflex = datas.get('enable_riflex', False)
135
- riflex_k = datas.get('riflex_k', 6)
136
- fps = datas.get('fps', None)
137
-
138
- generation_method = "Image Generation" if is_image else generation_method
139
-
140
- if start_image is not None:
141
- if start_image.startswith('http'):
142
- start_image = save_url_image_dist(start_image)
143
- start_image = [Image.open(start_image).convert("RGB")]
144
- else:
145
- start_image = base64.b64decode(start_image)
146
- start_image = [Image.open(BytesIO(start_image)).convert("RGB")]
147
-
148
- if end_image is not None:
149
- if end_image.startswith('http'):
150
- end_image = save_url_image_dist(end_image)
151
- end_image = [Image.open(end_image).convert("RGB")]
152
- else:
153
- end_image = base64.b64decode(end_image)
154
- end_image = [Image.open(BytesIO(end_image)).convert("RGB")]
155
-
156
- if validation_video is not None:
157
- if validation_video.startswith('http'):
158
- validation_video = save_url_video_dist(validation_video)
159
- else:
160
- validation_video = save_base64_video_dist(validation_video)
161
-
162
- if validation_video_mask is not None:
163
- if validation_video_mask.startswith('http'):
164
- validation_video_mask = save_url_image_dist(validation_video_mask)
165
- else:
166
- validation_video_mask = save_base64_image_dist(validation_video_mask)
167
-
168
- if control_video is not None:
169
- if control_video.startswith('http'):
170
- control_video = save_url_video_dist(control_video)
171
- else:
172
- control_video = save_base64_video_dist(control_video)
173
-
174
- if ref_image is not None:
175
- if ref_image.startswith('http'):
176
- ref_image = save_url_image_dist(ref_image)
177
- ref_image = [Image.open(ref_image).convert("RGB")]
178
- else:
179
- ref_image = base64.b64decode(ref_image)
180
- ref_image = [Image.open(BytesIO(ref_image)).convert("RGB")]
181
-
182
- try:
183
- save_sample_path, comment = self.controller.generate(
184
- "",
185
- base_model_path,
186
- lora_model_path,
187
- lora_alpha_slider,
188
- prompt_textbox,
189
- negative_prompt_textbox,
190
- sampler_dropdown,
191
- sample_step_slider,
192
- resize_method,
193
- width_slider,
194
- height_slider,
195
- base_resolution,
196
- generation_method,
197
- length_slider,
198
- overlap_video_length,
199
- partial_video_length,
200
- cfg_scale_slider,
201
- start_image,
202
- end_image,
203
- validation_video,
204
- validation_video_mask,
205
- control_video,
206
- denoise_strength,
207
- seed_textbox,
208
- ref_image = ref_image,
209
- enable_teacache = enable_teacache,
210
- teacache_threshold = teacache_threshold,
211
- num_skip_start_steps = num_skip_start_steps,
212
- teacache_offload = teacache_offload,
213
- cfg_skip_ratio = cfg_skip_ratio,
214
- enable_riflex = enable_riflex,
215
- riflex_k = riflex_k,
216
- base_model_2_dropdown = base_model_2_path,
217
- lora_model_2_dropdown = lora_model_2_path,
218
- fps = fps,
219
- is_api = True,
220
- )
221
- except Exception as e:
222
- gc.collect()
223
- torch.cuda.empty_cache()
224
- torch.cuda.ipc_collect()
225
- save_sample_path = ""
226
- comment = f"Error. error information is {str(e)}"
227
- if dist.is_initialized():
228
- if dist.get_rank() == 0:
229
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
230
- else:
231
- return None
232
- else:
233
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
234
-
235
-
236
- if dist.is_initialized():
237
- if dist.get_rank() == 0:
238
- if save_sample_path != "":
239
- return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
240
- else:
241
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
242
- else:
243
- return None
244
- else:
245
- if save_sample_path != "":
246
- return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
247
- else:
248
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
249
-
250
- except Exception as e:
251
- print(f"Error generating: {str(e)}")
252
- comment = f"Error generating: {str(e)}"
253
- if dist.is_initialized():
254
- if dist.get_rank() == 0:
255
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
256
- else:
257
- return None
258
- else:
259
- return {"message": comment, "save_sample_path": None, "base64_encoding": None}
260
-
261
- class MultiNodesEngine:
262
- def __init__(
263
- self,
264
- world_size,
265
- Controller,
266
- GPU_memory_mode,
267
- scheduler_dict,
268
- model_name,
269
- model_type,
270
- config_path,
271
- ulysses_degree=1,
272
- ring_degree=1,
273
- fsdp_dit=False,
274
- fsdp_text_encoder=False,
275
- compile_dit=False,
276
- weight_dtype=torch.bfloat16,
277
- savedir_sample="samples"
278
- ):
279
- # Ensure Ray is initialized
280
- if not ray.is_initialized():
281
- ray.init()
282
-
283
- num_workers = world_size
284
- self.workers = [
285
- MultiNodesGenerator.remote(
286
- rank, world_size, Controller,
287
- GPU_memory_mode, scheduler_dict, model_name=model_name, model_type=model_type, config_path=config_path,
288
- ulysses_degree=ulysses_degree, ring_degree=ring_degree,
289
- fsdp_dit=fsdp_dit, fsdp_text_encoder=fsdp_text_encoder, compile_dit=compile_dit,
290
- weight_dtype=weight_dtype, savedir_sample=savedir_sample,
291
- )
292
- for rank in range(num_workers)
293
- ]
294
- print("Update workers done")
295
-
296
- async def generate(self, data):
297
- results = ray.get([
298
- worker.generate.remote(data)
299
- for worker in self.workers
300
- ])
301
-
302
- return next(path for path in results if path is not None)
303
-
304
- def multi_nodes_infer_forward_api(_: gr.Blocks, app: FastAPI, engine):
305
-
306
- @app.post("/videox_fun/infer_forward")
307
- async def _multi_nodes_infer_forward_api(
308
- datas: dict,
309
- ):
310
- try:
311
- result = await engine.generate(datas)
312
- return result
313
- except Exception as e:
314
- if isinstance(e, HTTPException):
315
- raise e
316
- raise HTTPException(status_code=500, detail=str(e))
317
- else:
318
- MultiNodesEngine = None
319
- MultiNodesGenerator = None
320
- multi_nodes_infer_forward_api = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
videox_fun/data/__init__.py DELETED
@@ -1,9 +0,0 @@
1
- from .dataset_image import CC15M, ImageEditDataset
2
- from .dataset_image_video import (ImageVideoControlDataset, ImageVideoDataset, TextDataset,
3
- ImageVideoSampler)
4
- from .dataset_video import VideoDataset, VideoSpeechDataset, VideoAnimateDataset, WebVid10M
5
- from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
6
- custom_meshgrid, get_random_mask, get_relative_pose,
7
- get_video_reader_batch, padding_image, process_pose_file,
8
- process_pose_params, ray_condition, resize_frame,
9
- resize_image_with_target_area)
 
 
 
 
 
 
 
 
 
 
videox_fun/data/bucket_sampler.py DELETED
@@ -1,379 +0,0 @@
1
- # Copyright (c) OpenMMLab. All rights reserved.
2
- import os
3
- from typing import (Generic, Iterable, Iterator, List, Optional, Sequence,
4
- Sized, TypeVar, Union)
5
-
6
- import cv2
7
- import numpy as np
8
- import torch
9
- from PIL import Image
10
- from torch.utils.data import BatchSampler, Dataset, Sampler
11
-
12
- ASPECT_RATIO_512 = {
13
- '0.25': [256.0, 1024.0], '0.26': [256.0, 992.0], '0.27': [256.0, 960.0], '0.28': [256.0, 928.0],
14
- '0.32': [288.0, 896.0], '0.33': [288.0, 864.0], '0.35': [288.0, 832.0], '0.4': [320.0, 800.0],
15
- '0.42': [320.0, 768.0], '0.48': [352.0, 736.0], '0.5': [352.0, 704.0], '0.52': [352.0, 672.0],
16
- '0.57': [384.0, 672.0], '0.6': [384.0, 640.0], '0.68': [416.0, 608.0], '0.72': [416.0, 576.0],
17
- '0.78': [448.0, 576.0], '0.82': [448.0, 544.0], '0.88': [480.0, 544.0], '0.94': [480.0, 512.0],
18
- '1.0': [512.0, 512.0], '1.07': [512.0, 480.0], '1.13': [544.0, 480.0], '1.21': [544.0, 448.0],
19
- '1.29': [576.0, 448.0], '1.38': [576.0, 416.0], '1.46': [608.0, 416.0], '1.67': [640.0, 384.0],
20
- '1.75': [672.0, 384.0], '2.0': [704.0, 352.0], '2.09': [736.0, 352.0], '2.4': [768.0, 320.0],
21
- '2.5': [800.0, 320.0], '2.89': [832.0, 288.0], '3.0': [864.0, 288.0], '3.11': [896.0, 288.0],
22
- '3.62': [928.0, 256.0], '3.75': [960.0, 256.0], '3.88': [992.0, 256.0], '4.0': [1024.0, 256.0]
23
- }
24
- ASPECT_RATIO_RANDOM_CROP_512 = {
25
- '0.42': [320.0, 768.0], '0.5': [352.0, 704.0],
26
- '0.57': [384.0, 672.0], '0.68': [416.0, 608.0], '0.78': [448.0, 576.0], '0.88': [480.0, 544.0],
27
- '0.94': [480.0, 512.0], '1.0': [512.0, 512.0], '1.07': [512.0, 480.0],
28
- '1.13': [544.0, 480.0], '1.29': [576.0, 448.0], '1.46': [608.0, 416.0], '1.75': [672.0, 384.0],
29
- '2.0': [704.0, 352.0], '2.4': [768.0, 320.0]
30
- }
31
- ASPECT_RATIO_RANDOM_CROP_PROB = [
32
- 1, 2,
33
- 4, 4, 4, 4,
34
- 8, 8, 8,
35
- 4, 4, 4, 4,
36
- 2, 1
37
- ]
38
- ASPECT_RATIO_RANDOM_CROP_PROB = np.array(ASPECT_RATIO_RANDOM_CROP_PROB) / sum(ASPECT_RATIO_RANDOM_CROP_PROB)
39
-
40
- def get_closest_ratio(height: float, width: float, ratios: dict = ASPECT_RATIO_512):
41
- aspect_ratio = height / width
42
- closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - aspect_ratio))
43
- return ratios[closest_ratio], float(closest_ratio)
44
-
45
- def get_image_size_without_loading(path):
46
- with Image.open(path) as img:
47
- return img.size # (width, height)
48
-
49
- class RandomSampler(Sampler[int]):
50
- r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
51
-
52
- If with replacement, then user can specify :attr:`num_samples` to draw.
53
-
54
- Args:
55
- data_source (Dataset): dataset to sample from
56
- replacement (bool): samples are drawn on-demand with replacement if ``True``, default=``False``
57
- num_samples (int): number of samples to draw, default=`len(dataset)`.
58
- generator (Generator): Generator used in sampling.
59
- """
60
-
61
- data_source: Sized
62
- replacement: bool
63
-
64
- def __init__(self, data_source: Sized, replacement: bool = False,
65
- num_samples: Optional[int] = None, generator=None) -> None:
66
- self.data_source = data_source
67
- self.replacement = replacement
68
- self._num_samples = num_samples
69
- self.generator = generator
70
- self._pos_start = 0
71
-
72
- if not isinstance(self.replacement, bool):
73
- raise TypeError(f"replacement should be a boolean value, but got replacement={self.replacement}")
74
-
75
- if not isinstance(self.num_samples, int) or self.num_samples <= 0:
76
- raise ValueError(f"num_samples should be a positive integer value, but got num_samples={self.num_samples}")
77
-
78
- @property
79
- def num_samples(self) -> int:
80
- # dataset size might change at runtime
81
- if self._num_samples is None:
82
- return len(self.data_source)
83
- return self._num_samples
84
-
85
- def __iter__(self) -> Iterator[int]:
86
- n = len(self.data_source)
87
- if self.generator is None:
88
- seed = int(torch.empty((), dtype=torch.int64).random_().item())
89
- generator = torch.Generator()
90
- generator.manual_seed(seed)
91
- else:
92
- generator = self.generator
93
-
94
- if self.replacement:
95
- for _ in range(self.num_samples // 32):
96
- yield from torch.randint(high=n, size=(32,), dtype=torch.int64, generator=generator).tolist()
97
- yield from torch.randint(high=n, size=(self.num_samples % 32,), dtype=torch.int64, generator=generator).tolist()
98
- else:
99
- for _ in range(self.num_samples // n):
100
- xx = torch.randperm(n, generator=generator).tolist()
101
- if self._pos_start >= n:
102
- self._pos_start = 0
103
- print("xx top 10", xx[:10], self._pos_start)
104
- for idx in range(self._pos_start, n):
105
- yield xx[idx]
106
- self._pos_start = (self._pos_start + 1) % n
107
- self._pos_start = 0
108
- yield from torch.randperm(n, generator=generator).tolist()[:self.num_samples % n]
109
-
110
- def __len__(self) -> int:
111
- return self.num_samples
112
-
113
- class AspectRatioBatchImageSampler(BatchSampler):
114
- """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
115
-
116
- Args:
117
- sampler (Sampler): Base sampler.
118
- dataset (Dataset): Dataset providing data information.
119
- batch_size (int): Size of mini-batch.
120
- drop_last (bool): If ``True``, the sampler will drop the last batch if
121
- its size would be less than ``batch_size``.
122
- aspect_ratios (dict): The predefined aspect ratios.
123
- """
124
- def __init__(
125
- self,
126
- sampler: Sampler,
127
- dataset: Dataset,
128
- batch_size: int,
129
- train_folder: str = None,
130
- aspect_ratios: dict = ASPECT_RATIO_512,
131
- drop_last: bool = False,
132
- config=None,
133
- **kwargs
134
- ) -> None:
135
- if not isinstance(sampler, Sampler):
136
- raise TypeError('sampler should be an instance of ``Sampler``, '
137
- f'but got {sampler}')
138
- if not isinstance(batch_size, int) or batch_size <= 0:
139
- raise ValueError('batch_size should be a positive integer value, '
140
- f'but got batch_size={batch_size}')
141
- self.sampler = sampler
142
- self.dataset = dataset
143
- self.train_folder = train_folder
144
- self.batch_size = batch_size
145
- self.aspect_ratios = aspect_ratios
146
- self.drop_last = drop_last
147
- self.config = config
148
- # buckets for each aspect ratio
149
- self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
150
- # [str(k) for k, v in aspect_ratios]
151
- self.current_available_bucket_keys = list(aspect_ratios.keys())
152
-
153
- def __iter__(self):
154
- for idx in self.sampler:
155
- try:
156
- image_dict = self.dataset[idx]
157
-
158
- width, height = image_dict.get("width", None), image_dict.get("height", None)
159
- if width is None or height is None:
160
- image_id, name = image_dict['file_path'], image_dict['text']
161
- if self.train_folder is None:
162
- image_dir = image_id
163
- else:
164
- image_dir = os.path.join(self.train_folder, image_id)
165
-
166
- width, height = get_image_size_without_loading(image_dir)
167
-
168
- ratio = height / width # self.dataset[idx]
169
- else:
170
- height = int(height)
171
- width = int(width)
172
- ratio = height / width # self.dataset[idx]
173
- except Exception as e:
174
- print(e)
175
- continue
176
- # find the closest aspect ratio
177
- closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
178
- if closest_ratio not in self.current_available_bucket_keys:
179
- continue
180
- bucket = self._aspect_ratio_buckets[closest_ratio]
181
- bucket.append(idx)
182
- # yield a batch of indices in the same aspect ratio group
183
- if len(bucket) == self.batch_size:
184
- yield bucket[:]
185
- del bucket[:]
186
-
187
- class AspectRatioBatchSampler(BatchSampler):
188
- """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
189
-
190
- Args:
191
- sampler (Sampler): Base sampler.
192
- dataset (Dataset): Dataset providing data information.
193
- batch_size (int): Size of mini-batch.
194
- drop_last (bool): If ``True``, the sampler will drop the last batch if
195
- its size would be less than ``batch_size``.
196
- aspect_ratios (dict): The predefined aspect ratios.
197
- """
198
- def __init__(
199
- self,
200
- sampler: Sampler,
201
- dataset: Dataset,
202
- batch_size: int,
203
- video_folder: str = None,
204
- train_data_format: str = "webvid",
205
- aspect_ratios: dict = ASPECT_RATIO_512,
206
- drop_last: bool = False,
207
- config=None,
208
- **kwargs
209
- ) -> None:
210
- if not isinstance(sampler, Sampler):
211
- raise TypeError('sampler should be an instance of ``Sampler``, '
212
- f'but got {sampler}')
213
- if not isinstance(batch_size, int) or batch_size <= 0:
214
- raise ValueError('batch_size should be a positive integer value, '
215
- f'but got batch_size={batch_size}')
216
- self.sampler = sampler
217
- self.dataset = dataset
218
- self.video_folder = video_folder
219
- self.train_data_format = train_data_format
220
- self.batch_size = batch_size
221
- self.aspect_ratios = aspect_ratios
222
- self.drop_last = drop_last
223
- self.config = config
224
- # buckets for each aspect ratio
225
- self._aspect_ratio_buckets = {ratio: [] for ratio in aspect_ratios}
226
- # [str(k) for k, v in aspect_ratios]
227
- self.current_available_bucket_keys = list(aspect_ratios.keys())
228
-
229
- def __iter__(self):
230
- for idx in self.sampler:
231
- try:
232
- video_dict = self.dataset[idx]
233
- width, more = video_dict.get("width", None), video_dict.get("height", None)
234
-
235
- if width is None or height is None:
236
- if self.train_data_format == "normal":
237
- video_id, name = video_dict['file_path'], video_dict['text']
238
- if self.video_folder is None:
239
- video_dir = video_id
240
- else:
241
- video_dir = os.path.join(self.video_folder, video_id)
242
- else:
243
- videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
244
- video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
245
- cap = cv2.VideoCapture(video_dir)
246
-
247
- # 获取视频尺寸
248
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
249
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
250
-
251
- ratio = height / width # self.dataset[idx]
252
- else:
253
- height = int(height)
254
- width = int(width)
255
- ratio = height / width # self.dataset[idx]
256
- except Exception as e:
257
- print(e, self.dataset[idx], "This item is error, please check it.")
258
- continue
259
- # find the closest aspect ratio
260
- closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
261
- if closest_ratio not in self.current_available_bucket_keys:
262
- continue
263
- bucket = self._aspect_ratio_buckets[closest_ratio]
264
- bucket.append(idx)
265
- # yield a batch of indices in the same aspect ratio group
266
- if len(bucket) == self.batch_size:
267
- yield bucket[:]
268
- del bucket[:]
269
-
270
- class AspectRatioBatchImageVideoSampler(BatchSampler):
271
- """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
272
-
273
- Args:
274
- sampler (Sampler): Base sampler.
275
- dataset (Dataset): Dataset providing data information.
276
- batch_size (int): Size of mini-batch.
277
- drop_last (bool): If ``True``, the sampler will drop the last batch if
278
- its size would be less than ``batch_size``.
279
- aspect_ratios (dict): The predefined aspect ratios.
280
- """
281
-
282
- def __init__(self,
283
- sampler: Sampler,
284
- dataset: Dataset,
285
- batch_size: int,
286
- train_folder: str = None,
287
- aspect_ratios: dict = ASPECT_RATIO_512,
288
- drop_last: bool = False
289
- ) -> None:
290
- if not isinstance(sampler, Sampler):
291
- raise TypeError('sampler should be an instance of ``Sampler``, '
292
- f'but got {sampler}')
293
- if not isinstance(batch_size, int) or batch_size <= 0:
294
- raise ValueError('batch_size should be a positive integer value, '
295
- f'but got batch_size={batch_size}')
296
- self.sampler = sampler
297
- self.dataset = dataset
298
- self.train_folder = train_folder
299
- self.batch_size = batch_size
300
- self.aspect_ratios = aspect_ratios
301
- self.drop_last = drop_last
302
-
303
- # buckets for each aspect ratio
304
- self.current_available_bucket_keys = list(aspect_ratios.keys())
305
- self.bucket = {
306
- 'image':{ratio: [] for ratio in aspect_ratios},
307
- 'video':{ratio: [] for ratio in aspect_ratios}
308
- }
309
-
310
- def __iter__(self):
311
- for idx in self.sampler:
312
- content_type = self.dataset[idx].get('type', 'image')
313
- if content_type == 'image':
314
- try:
315
- image_dict = self.dataset[idx]
316
-
317
- width, height = image_dict.get("width", None), image_dict.get("height", None)
318
- if width is None or height is None:
319
- image_id, name = image_dict['file_path'], image_dict['text']
320
- if self.train_folder is None:
321
- image_dir = image_id
322
- else:
323
- image_dir = os.path.join(self.train_folder, image_id)
324
-
325
- width, height = get_image_size_without_loading(image_dir)
326
-
327
- ratio = height / width # self.dataset[idx]
328
- else:
329
- height = int(height)
330
- width = int(width)
331
- ratio = height / width # self.dataset[idx]
332
- except Exception as e:
333
- print(e, self.dataset[idx], "This item is error, please check it.")
334
- continue
335
- # find the closest aspect ratio
336
- closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
337
- if closest_ratio not in self.current_available_bucket_keys:
338
- continue
339
- bucket = self.bucket['image'][closest_ratio]
340
- bucket.append(idx)
341
- # yield a batch of indices in the same aspect ratio group
342
- if len(bucket) == self.batch_size:
343
- yield bucket[:]
344
- del bucket[:]
345
- else:
346
- try:
347
- video_dict = self.dataset[idx]
348
- width, height = video_dict.get("width", None), video_dict.get("height", None)
349
-
350
- if width is None or height is None:
351
- video_id, name = video_dict['file_path'], video_dict['text']
352
- if self.train_folder is None:
353
- video_dir = video_id
354
- else:
355
- video_dir = os.path.join(self.train_folder, video_id)
356
- cap = cv2.VideoCapture(video_dir)
357
-
358
- # 获取视频尺寸
359
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # 浮点数转换为整数
360
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 浮点数转换为整数
361
-
362
- ratio = height / width # self.dataset[idx]
363
- else:
364
- height = int(height)
365
- width = int(width)
366
- ratio = height / width # self.dataset[idx]
367
- except Exception as e:
368
- print(e, self.dataset[idx], "This item is error, please check it.")
369
- continue
370
- # find the closest aspect ratio
371
- closest_ratio = min(self.aspect_ratios.keys(), key=lambda r: abs(float(r) - ratio))
372
- if closest_ratio not in self.current_available_bucket_keys:
373
- continue
374
- bucket = self.bucket['video'][closest_ratio]
375
- bucket.append(idx)
376
- # yield a batch of indices in the same aspect ratio group
377
- if len(bucket) == self.batch_size:
378
- yield bucket[:]
379
- del bucket[:]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
videox_fun/data/dataset_image.py DELETED
@@ -1,191 +0,0 @@
1
- import json
2
- import os
3
- import random
4
-
5
- import numpy as np
6
- import torch
7
- import torchvision.transforms as transforms
8
- from PIL import Image
9
- from torch.utils.data.dataset import Dataset
10
-
11
-
12
- class CC15M(Dataset):
13
- def __init__(
14
- self,
15
- json_path,
16
- video_folder=None,
17
- resolution=512,
18
- enable_bucket=False,
19
- ):
20
- print(f"loading annotations from {json_path} ...")
21
- self.dataset = json.load(open(json_path, 'r'))
22
- self.length = len(self.dataset)
23
- print(f"data scale: {self.length}")
24
-
25
- self.enable_bucket = enable_bucket
26
- self.video_folder = video_folder
27
-
28
- resolution = tuple(resolution) if not isinstance(resolution, int) else (resolution, resolution)
29
- self.pixel_transforms = transforms.Compose([
30
- transforms.Resize(resolution[0]),
31
- transforms.CenterCrop(resolution),
32
- transforms.ToTensor(),
33
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
34
- ])
35
-
36
- def get_batch(self, idx):
37
- video_dict = self.dataset[idx]
38
- video_id, name = video_dict['file_path'], video_dict['text']
39
-
40
- if self.video_folder is None:
41
- video_dir = video_id
42
- else:
43
- video_dir = os.path.join(self.video_folder, video_id)
44
-
45
- pixel_values = Image.open(video_dir).convert("RGB")
46
- return pixel_values, name
47
-
48
- def __len__(self):
49
- return self.length
50
-
51
- def __getitem__(self, idx):
52
- while True:
53
- try:
54
- pixel_values, name = self.get_batch(idx)
55
- break
56
- except Exception as e:
57
- print(e)
58
- idx = random.randint(0, self.length-1)
59
-
60
- if not self.enable_bucket:
61
- pixel_values = self.pixel_transforms(pixel_values)
62
- else:
63
- pixel_values = np.array(pixel_values)
64
-
65
- sample = dict(pixel_values=pixel_values, text=name)
66
- return sample
67
-
68
- class ImageEditDataset(Dataset):
69
- def __init__(
70
- self,
71
- ann_path, data_root=None,
72
- image_sample_size=512,
73
- text_drop_ratio=0.1,
74
- enable_bucket=False,
75
- enable_inpaint=False,
76
- return_file_name=False,
77
- ):
78
- # Loading annotations from files
79
- print(f"loading annotations from {ann_path} ...")
80
- if ann_path.endswith('.csv'):
81
- with open(ann_path, 'r') as csvfile:
82
- dataset = list(csv.DictReader(csvfile))
83
- elif ann_path.endswith('.json'):
84
- dataset = json.load(open(ann_path))
85
-
86
- self.data_root = data_root
87
- self.dataset = dataset
88
-
89
- self.length = len(self.dataset)
90
- print(f"data scale: {self.length}")
91
- # TODO: enable bucket training
92
- self.enable_bucket = enable_bucket
93
- self.text_drop_ratio = text_drop_ratio
94
- self.enable_inpaint = enable_inpaint
95
- self.return_file_name = return_file_name
96
-
97
- # Image params
98
- self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
99
- self.image_transforms = transforms.Compose([
100
- transforms.Resize(min(self.image_sample_size)),
101
- transforms.CenterCrop(self.image_sample_size),
102
- transforms.ToTensor(),
103
- transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
104
- ])
105
-
106
- def get_batch(self, idx):
107
- data_info = self.dataset[idx % len(self.dataset)]
108
-
109
- image_path, text = data_info['file_path'], data_info['text']
110
- if self.data_root is not None:
111
- image_path = os.path.join(self.data_root, image_path)
112
- image = Image.open(image_path).convert('RGB')
113
-
114
- if not self.enable_bucket:
115
- raise ValueError("Not enable_bucket is not supported now. ")
116
- else:
117
- image = np.expand_dims(np.array(image), 0)
118
-
119
- source_image_path = data_info.get('source_file_path', [])
120
- source_image = []
121
- if isinstance(source_image_path, list):
122
- for _source_image_path in source_image_path:
123
- if self.data_root is not None:
124
- _source_image_path = os.path.join(self.data_root, _source_image_path)
125
- _source_image = Image.open(_source_image_path).convert('RGB')
126
- source_image.append(_source_image)
127
- else:
128
- if self.data_root is not None:
129
- _source_image_path = os.path.join(self.data_root, source_image_path)
130
- _source_image = Image.open(_source_image_path).convert('RGB')
131
- source_image.append(_source_image)
132
-
133
- if not self.enable_bucket:
134
- raise ValueError("Not enable_bucket is not supported now. ")
135
- else:
136
- source_image = [np.array(_source_image) for _source_image in source_image]
137
-
138
- if random.random() < self.text_drop_ratio:
139
- text = ''
140
- return image, source_image, text, 'image', image_path
141
-
142
- def __len__(self):
143
- return self.length
144
-
145
- def __getitem__(self, idx):
146
- data_info = self.dataset[idx % len(self.dataset)]
147
- data_type = data_info.get('type', 'image')
148
- while True:
149
- sample = {}
150
- try:
151
- data_info_local = self.dataset[idx % len(self.dataset)]
152
- data_type_local = data_info_local.get('type', 'image')
153
- if data_type_local != data_type:
154
- raise ValueError("data_type_local != data_type")
155
-
156
- pixel_values, source_pixel_values, name, data_type, file_path = self.get_batch(idx)
157
- sample["pixel_values"] = pixel_values
158
- sample["source_pixel_values"] = source_pixel_values
159
- sample["text"] = name
160
- sample["data_type"] = data_type
161
- sample["idx"] = idx
162
- if self.return_file_name:
163
- sample["file_name"] = os.path.basename(file_path)
164
-
165
- if len(sample) > 0:
166
- break
167
- except Exception as e:
168
- print(e, self.dataset[idx % len(self.dataset)])
169
- idx = random.randint(0, self.length-1)
170
-
171
- if self.enable_inpaint and not self.enable_bucket:
172
- mask = get_random_mask(pixel_values.size())
173
- mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
174
- sample["mask_pixel_values"] = mask_pixel_values
175
- sample["mask"] = mask
176
-
177
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
178
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
179
- sample["clip_pixel_values"] = clip_pixel_values
180
-
181
- return sample
182
-
183
- if __name__ == "__main__":
184
- dataset = CC15M(
185
- csv_path="./cc15m_add_index.json",
186
- resolution=512,
187
- )
188
-
189
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
190
- for idx, batch in enumerate(dataloader):
191
- print(batch["pixel_values"].shape, len(batch["text"]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
videox_fun/data/dataset_image_video.py DELETED
@@ -1,657 +0,0 @@
1
- import csv
2
- import gc
3
- import io
4
- import json
5
- import math
6
- import os
7
- import random
8
- from contextlib import contextmanager
9
- from random import shuffle
10
- from threading import Thread
11
-
12
- import albumentations
13
- import cv2
14
- import numpy as np
15
- import torch
16
- import torch.nn.functional as F
17
- import torchvision.transforms as transforms
18
- from decord import VideoReader
19
- from einops import rearrange
20
- from func_timeout import FunctionTimedOut, func_timeout
21
- from packaging import version as pver
22
- from PIL import Image
23
- from safetensors.torch import load_file
24
- from torch.utils.data import BatchSampler, Sampler
25
- from torch.utils.data.dataset import Dataset
26
-
27
- from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
28
- custom_meshgrid, get_random_mask, get_relative_pose,
29
- get_video_reader_batch, padding_image, process_pose_file,
30
- process_pose_params, ray_condition, resize_frame,
31
- resize_image_with_target_area)
32
-
33
-
34
- class ImageVideoSampler(BatchSampler):
35
- """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
36
-
37
- Args:
38
- sampler (Sampler): Base sampler.
39
- dataset (Dataset): Dataset providing data information.
40
- batch_size (int): Size of mini-batch.
41
- drop_last (bool): If ``True``, the sampler will drop the last batch if
42
- its size would be less than ``batch_size``.
43
- aspect_ratios (dict): The predefined aspect ratios.
44
- """
45
-
46
- def __init__(self,
47
- sampler: Sampler,
48
- dataset: Dataset,
49
- batch_size: int,
50
- drop_last: bool = False
51
- ) -> None:
52
- if not isinstance(sampler, Sampler):
53
- raise TypeError('sampler should be an instance of ``Sampler``, '
54
- f'but got {sampler}')
55
- if not isinstance(batch_size, int) or batch_size <= 0:
56
- raise ValueError('batch_size should be a positive integer value, '
57
- f'but got batch_size={batch_size}')
58
- self.sampler = sampler
59
- self.dataset = dataset
60
- self.batch_size = batch_size
61
- self.drop_last = drop_last
62
-
63
- # buckets for each aspect ratio
64
- self.bucket = {'image':[], 'video':[]}
65
-
66
- def __iter__(self):
67
- for idx in self.sampler:
68
- content_type = self.dataset.dataset[idx].get('type', 'image')
69
- self.bucket[content_type].append(idx)
70
-
71
- # yield a batch of indices in the same aspect ratio group
72
- if len(self.bucket['video']) == self.batch_size:
73
- bucket = self.bucket['video']
74
- yield bucket[:]
75
- del bucket[:]
76
- elif len(self.bucket['image']) == self.batch_size:
77
- bucket = self.bucket['image']
78
- yield bucket[:]
79
- del bucket[:]
80
-
81
-
82
- class ImageVideoDataset(Dataset):
83
- def __init__(
84
- self,
85
- ann_path, data_root=None,
86
- video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
87
- image_sample_size=512,
88
- video_repeat=0,
89
- text_drop_ratio=0.1,
90
- enable_bucket=False,
91
- video_length_drop_start=0.0,
92
- video_length_drop_end=1.0,
93
- enable_inpaint=False,
94
- return_file_name=False,
95
- ):
96
- # Loading annotations from files
97
- print(f"loading annotations from {ann_path} ...")
98
- if ann_path.endswith('.csv'):
99
- with open(ann_path, 'r') as csvfile:
100
- dataset = list(csv.DictReader(csvfile))
101
- elif ann_path.endswith('.json'):
102
- dataset = json.load(open(ann_path))
103
-
104
- self.data_root = data_root
105
-
106
- # It's used to balance num of images and videos.
107
- if video_repeat > 0:
108
- self.dataset = []
109
- for data in dataset:
110
- if data.get('type', 'image') != 'video':
111
- self.dataset.append(data)
112
-
113
- for _ in range(video_repeat):
114
- for data in dataset:
115
- if data.get('type', 'image') == 'video':
116
- self.dataset.append(data)
117
- else:
118
- self.dataset = dataset
119
- del dataset
120
-
121
- self.length = len(self.dataset)
122
- print(f"data scale: {self.length}")
123
- # TODO: enable bucket training
124
- self.enable_bucket = enable_bucket
125
- self.text_drop_ratio = text_drop_ratio
126
- self.enable_inpaint = enable_inpaint
127
- self.return_file_name = return_file_name
128
-
129
- self.video_length_drop_start = video_length_drop_start
130
- self.video_length_drop_end = video_length_drop_end
131
-
132
- # Video params
133
- self.video_sample_stride = video_sample_stride
134
- self.video_sample_n_frames = video_sample_n_frames
135
- self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
136
- self.video_transforms = transforms.Compose(
137
- [
138
- transforms.Resize(min(self.video_sample_size)),
139
- transforms.CenterCrop(self.video_sample_size),
140
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
141
- ]
142
- )
143
-
144
- # Image params
145
- self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
146
- self.image_transforms = transforms.Compose([
147
- transforms.Resize(min(self.image_sample_size)),
148
- transforms.CenterCrop(self.image_sample_size),
149
- transforms.ToTensor(),
150
- transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
151
- ])
152
-
153
- self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
154
-
155
- def get_batch(self, idx):
156
- data_info = self.dataset[idx % len(self.dataset)]
157
-
158
- if data_info.get('type', 'image')=='video':
159
- video_id, text = data_info['file_path'], data_info['text']
160
-
161
- if self.data_root is None:
162
- video_dir = video_id
163
- else:
164
- video_dir = os.path.join(self.data_root, video_id)
165
-
166
- with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
167
- min_sample_n_frames = min(
168
- self.video_sample_n_frames,
169
- int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
170
- )
171
- if min_sample_n_frames == 0:
172
- raise ValueError(f"No Frames in video.")
173
-
174
- video_length = int(self.video_length_drop_end * len(video_reader))
175
- clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
176
- start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
177
- batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
178
-
179
- try:
180
- sample_args = (video_reader, batch_index)
181
- pixel_values = func_timeout(
182
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
183
- )
184
- resized_frames = []
185
- for i in range(len(pixel_values)):
186
- frame = pixel_values[i]
187
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
188
- resized_frames.append(resized_frame)
189
- pixel_values = np.array(resized_frames)
190
- except FunctionTimedOut:
191
- raise ValueError(f"Read {idx} timeout.")
192
- except Exception as e:
193
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
194
-
195
- if not self.enable_bucket:
196
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
197
- pixel_values = pixel_values / 255.
198
- del video_reader
199
- else:
200
- pixel_values = pixel_values
201
-
202
- if not self.enable_bucket:
203
- pixel_values = self.video_transforms(pixel_values)
204
-
205
- # Random use no text generation
206
- if random.random() < self.text_drop_ratio:
207
- text = ''
208
- return pixel_values, text, 'video', video_dir
209
- else:
210
- image_path, text = data_info['file_path'], data_info['text']
211
- if self.data_root is not None:
212
- image_path = os.path.join(self.data_root, image_path)
213
- image = Image.open(image_path).convert('RGB')
214
- if not self.enable_bucket:
215
- image = self.image_transforms(image).unsqueeze(0)
216
- else:
217
- image = np.expand_dims(np.array(image), 0)
218
- if random.random() < self.text_drop_ratio:
219
- text = ''
220
- return image, text, 'image', image_path
221
-
222
- def __len__(self):
223
- return self.length
224
-
225
- def __getitem__(self, idx):
226
- data_info = self.dataset[idx % len(self.dataset)]
227
- data_type = data_info.get('type', 'image')
228
- while True:
229
- sample = {}
230
- try:
231
- data_info_local = self.dataset[idx % len(self.dataset)]
232
- data_type_local = data_info_local.get('type', 'image')
233
- if data_type_local != data_type:
234
- raise ValueError("data_type_local != data_type")
235
-
236
- pixel_values, name, data_type, file_path = self.get_batch(idx)
237
- sample["pixel_values"] = pixel_values
238
- sample["text"] = name
239
- sample["data_type"] = data_type
240
- sample["idx"] = idx
241
- if self.return_file_name:
242
- sample["file_name"] = os.path.basename(file_path)
243
-
244
- if len(sample) > 0:
245
- break
246
- except Exception as e:
247
- print(e, self.dataset[idx % len(self.dataset)])
248
- idx = random.randint(0, self.length-1)
249
-
250
- if self.enable_inpaint and not self.enable_bucket:
251
- mask = get_random_mask(pixel_values.size())
252
- mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
253
- sample["mask_pixel_values"] = mask_pixel_values
254
- sample["mask"] = mask
255
-
256
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
257
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
258
- sample["clip_pixel_values"] = clip_pixel_values
259
-
260
- return sample
261
-
262
-
263
- class ImageVideoControlDataset(Dataset):
264
- def __init__(
265
- self,
266
- ann_path, data_root=None,
267
- video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
268
- image_sample_size=512,
269
- video_repeat=0,
270
- text_drop_ratio=0.1,
271
- enable_bucket=False,
272
- video_length_drop_start=0.1,
273
- video_length_drop_end=0.9,
274
- enable_inpaint=False,
275
- enable_camera_info=False,
276
- return_file_name=False,
277
- enable_subject_info=False,
278
- padding_subject_info=True,
279
- ):
280
- # Loading annotations from files
281
- print(f"loading annotations from {ann_path} ...")
282
- if ann_path.endswith('.csv'):
283
- with open(ann_path, 'r') as csvfile:
284
- dataset = list(csv.DictReader(csvfile))
285
- elif ann_path.endswith('.json'):
286
- dataset = json.load(open(ann_path))
287
-
288
- self.data_root = data_root
289
-
290
- # Repeat video entries to balance the number of images and videos.
291
- if video_repeat > 0:
292
- self.dataset = []
293
- for data in dataset:
294
- if data.get('type', 'image') != 'video':
295
- self.dataset.append(data)
296
-
297
- for _ in range(video_repeat):
298
- for data in dataset:
299
- if data.get('type', 'image') == 'video':
300
- self.dataset.append(data)
301
- else:
302
- self.dataset = dataset
303
- del dataset
304
-
305
- self.length = len(self.dataset)
306
- print(f"data scale: {self.length}")
307
- # TODO: enable bucket training
308
- self.enable_bucket = enable_bucket
309
- self.text_drop_ratio = text_drop_ratio
310
- self.enable_inpaint = enable_inpaint
311
- self.enable_camera_info = enable_camera_info
312
- self.enable_subject_info = enable_subject_info
313
- self.padding_subject_info = padding_subject_info
314
-
315
- self.video_length_drop_start = video_length_drop_start
316
- self.video_length_drop_end = video_length_drop_end
317
-
318
- # Video params
319
- self.video_sample_stride = video_sample_stride
320
- self.video_sample_n_frames = video_sample_n_frames
321
- self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
322
- self.video_transforms = transforms.Compose(
323
- [
324
- transforms.Resize(min(self.video_sample_size)),
325
- transforms.CenterCrop(self.video_sample_size),
326
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
327
- ]
328
- )
329
- if self.enable_camera_info:
330
- self.video_transforms_camera = transforms.Compose(
331
- [
332
- transforms.Resize(min(self.video_sample_size)),
333
- transforms.CenterCrop(self.video_sample_size)
334
- ]
335
- )
336
-
337
- # Image params
338
- self.image_sample_size = tuple(image_sample_size) if not isinstance(image_sample_size, int) else (image_sample_size, image_sample_size)
339
- self.image_transforms = transforms.Compose([
340
- transforms.Resize(min(self.image_sample_size)),
341
- transforms.CenterCrop(self.image_sample_size),
342
- transforms.ToTensor(),
343
- transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
344
- ])
345
-
346
- self.larger_side_of_image_and_video = max(min(self.image_sample_size), min(self.video_sample_size))
347
-
348
- def get_batch(self, idx):
349
- data_info = self.dataset[idx % len(self.dataset)]
350
- video_id, text = data_info['file_path'], data_info['text']
351
-
352
- if data_info.get('type', 'image')=='video':
353
- if self.data_root is None:
354
- video_dir = video_id
355
- else:
356
- video_dir = os.path.join(self.data_root, video_id)
357
-
358
- with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
359
- min_sample_n_frames = min(
360
- self.video_sample_n_frames,
361
- int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
362
- )
363
- if min_sample_n_frames == 0:
364
- raise ValueError(f"No Frames in video.")
365
-
366
- video_length = int(self.video_length_drop_end * len(video_reader))
367
- clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
368
- start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
369
- batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
370
-
371
- try:
372
- sample_args = (video_reader, batch_index)
373
- pixel_values = func_timeout(
374
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
375
- )
376
- resized_frames = []
377
- for i in range(len(pixel_values)):
378
- frame = pixel_values[i]
379
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
380
- resized_frames.append(resized_frame)
381
- pixel_values = np.array(resized_frames)
382
- except FunctionTimedOut:
383
- raise ValueError(f"Read {idx} timeout.")
384
- except Exception as e:
385
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
386
-
387
- if not self.enable_bucket:
388
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
389
- pixel_values = pixel_values / 255.
390
- del video_reader
391
- else:
392
- pixel_values = pixel_values
393
-
394
- if not self.enable_bucket:
395
- pixel_values = self.video_transforms(pixel_values)
396
-
397
- # Randomly drop the text (for classifier-free guidance)
398
- if random.random() < self.text_drop_ratio:
399
- text = ''
400
-
401
- control_video_id = data_info['control_file_path']
402
-
403
- if control_video_id is not None:
404
- if self.data_root is None:
405
- control_video_id = control_video_id
406
- else:
407
- control_video_id = os.path.join(self.data_root, control_video_id)
408
-
409
- if self.enable_camera_info:
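- # Camera control: a .txt control file is treated as a camera pose file and converted to per-frame camera features, while the pixel-space control video is zeroed out.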
410
- if control_video_id.lower().endswith('.txt'):
411
- if not self.enable_bucket:
412
- control_pixel_values = torch.zeros_like(pixel_values)
413
-
414
- control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0])
415
- control_camera_values = torch.from_numpy(control_camera_values).permute(0, 3, 1, 2).contiguous()
416
- control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)
417
- control_camera_values = self.video_transforms_camera(control_camera_values)
418
- else:
419
- control_pixel_values = np.zeros_like(pixel_values)
420
-
421
- control_camera_values = process_pose_file(control_video_id, width=self.video_sample_size[1], height=self.video_sample_size[0], return_poses=True)
422
- control_camera_values = torch.from_numpy(np.array(control_camera_values)).unsqueeze(0).unsqueeze(0)
423
- control_camera_values = F.interpolate(control_camera_values, size=(len(video_reader), control_camera_values.size(3)), mode='bilinear', align_corners=True)[0][0]
424
- control_camera_values = np.array([control_camera_values[index] for index in batch_index])
425
- else:
426
- if not self.enable_bucket:
427
- control_pixel_values = torch.zeros_like(pixel_values)
428
- control_camera_values = None
429
- else:
430
- control_pixel_values = np.zeros_like(pixel_values)
431
- control_camera_values = None
432
- else:
433
- if control_video_id is not None:
434
- with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
435
- try:
436
- sample_args = (control_video_reader, batch_index)
437
- control_pixel_values = func_timeout(
438
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
439
- )
440
- resized_frames = []
441
- for i in range(len(control_pixel_values)):
442
- frame = control_pixel_values[i]
443
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
444
- resized_frames.append(resized_frame)
445
- control_pixel_values = np.array(resized_frames)
446
- except FunctionTimedOut:
447
- raise ValueError(f"Read {idx} timeout.")
448
- except Exception as e:
449
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
450
-
451
- if not self.enable_bucket:
452
- control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
453
- control_pixel_values = control_pixel_values / 255.
454
- del control_video_reader
455
- else:
456
- control_pixel_values = control_pixel_values
457
-
458
- if not self.enable_bucket:
459
- control_pixel_values = self.video_transforms(control_pixel_values)
460
- else:
461
- if not self.enable_bucket:
462
- control_pixel_values = torch.zeros_like(pixel_values)
463
- else:
464
- control_pixel_values = np.zeros_like(pixel_values)
465
- control_camera_values = None
466
-
467
- if self.enable_subject_info:
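- # Subject references: load up to 4 shuffled reference images, pad them to the sample frame size or resize them to roughly 1 megapixel, and randomly flip horizontally for augmentation.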
468
- if not self.enable_bucket:
469
- visual_height, visual_width = pixel_values.shape[-2:]
470
- else:
471
- visual_height, visual_width = pixel_values.shape[1:3]
472
-
473
- subject_id = data_info.get('object_file_path', [])
474
- shuffle(subject_id)
475
- subject_images = []
476
- for i in range(min(len(subject_id), 4)):
477
- subject_image = Image.open(subject_id[i])
478
- width, height = subject_image.size
479
- total_pixels = width * height
480
-
481
- if self.padding_subject_info:
482
- img = padding_image(subject_image, visual_width, visual_height)
483
- else:
484
- img = resize_image_with_target_area(subject_image, 1024 * 1024)
485
-
486
- if random.random() < 0.5:
487
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
488
- subject_images.append(np.array(img))
489
- if self.padding_subject_info:
490
- subject_image = np.array(subject_images)
491
- else:
492
- subject_image = subject_images
493
- else:
494
- subject_image = None
495
-
496
- return pixel_values, control_pixel_values, subject_image, control_camera_values, text, "video"
497
- else:
498
- image_path, text = data_info['file_path'], data_info['text']
499
- if self.data_root is not None:
500
- image_path = os.path.join(self.data_root, image_path)
501
- image = Image.open(image_path).convert('RGB')
502
- if not self.enable_bucket:
503
- image = self.image_transforms(image).unsqueeze(0)
504
- else:
505
- image = np.expand_dims(np.array(image), 0)
506
-
507
- if random.random() < self.text_drop_ratio:
508
- text = ''
509
-
510
- control_image_id = data_info['control_file_path']
511
-
512
- if self.data_root is None:
513
- control_image_id = control_image_id
514
- else:
515
- control_image_id = os.path.join(self.data_root, control_image_id)
516
-
517
- control_image = Image.open(control_image_id).convert('RGB')
518
- if not self.enable_bucket:
519
- control_image = self.image_transforms(control_image).unsqueeze(0)
520
- else:
521
- control_image = np.expand_dims(np.array(control_image), 0)
522
-
523
- if self.enable_subject_info:
524
- if not self.enable_bucket:
525
- visual_height, visual_width = image.shape[-2:]
526
- else:
527
- visual_height, visual_width = image.shape[1:3]
528
-
529
- subject_id = data_info.get('object_file_path', [])
530
- shuffle(subject_id)
531
- subject_images = []
532
- for i in range(min(len(subject_id), 4)):
533
- subject_image = Image.open(subject_id[i]).convert('RGB')
534
- width, height = subject_image.size
535
- total_pixels = width * height
536
-
537
- if self.padding_subject_info:
538
- img = padding_image(subject_image, visual_width, visual_height)
539
- else:
540
- img = resize_image_with_target_area(subject_image, 1024 * 1024)
541
-
542
- if random.random() < 0.5:
543
- img = img.transpose(Image.FLIP_LEFT_RIGHT)
544
- subject_images.append(np.array(img))
545
- if self.padding_subject_info:
546
- subject_image = np.array(subject_images)
547
- else:
548
- subject_image = subject_images
549
- else:
550
- subject_image = None
551
-
552
- return image, control_image, subject_image, None, text, 'image'
553
-
554
- def __len__(self):
555
- return self.length
556
-
557
- def __getitem__(self, idx):
558
- data_info = self.dataset[idx % len(self.dataset)]
559
- data_type = data_info.get('type', 'image')
560
- while True:
561
- sample = {}
562
- try:
563
- data_info_local = self.dataset[idx % len(self.dataset)]
564
- data_type_local = data_info_local.get('type', 'image')
565
- if data_type_local != data_type:
566
- raise ValueError("data_type_local != data_type")
567
-
568
- pixel_values, control_pixel_values, subject_image, control_camera_values, name, data_type = self.get_batch(idx)
569
-
570
- sample["pixel_values"] = pixel_values
571
- sample["control_pixel_values"] = control_pixel_values
572
- sample["subject_image"] = subject_image
573
- sample["text"] = name
574
- sample["data_type"] = data_type
575
- sample["idx"] = idx
576
-
577
- if self.enable_camera_info:
578
- sample["control_camera_values"] = control_camera_values
579
-
580
- if len(sample) > 0:
581
- break
582
- except Exception as e:
583
- print(e, self.dataset[idx % len(self.dataset)])
584
- idx = random.randint(0, self.length-1)
585
-
586
- if self.enable_inpaint and not self.enable_bucket:
587
- mask = get_random_mask(pixel_values.size())
588
- mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
589
- sample["mask_pixel_values"] = mask_pixel_values
590
- sample["mask"] = mask
591
-
592
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
593
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
594
- sample["clip_pixel_values"] = clip_pixel_values
595
-
596
- return sample
597
-
598
-
599
- class ImageVideoSafetensorsDataset(Dataset):
600
- def __init__(
601
- self,
602
- ann_path,
603
- data_root=None,
604
- ):
605
- # Loading annotations from files
606
- print(f"loading annotations from {ann_path} ...")
607
- if ann_path.endswith('.json'):
608
- dataset = json.load(open(ann_path))
609
-
610
- self.data_root = data_root
611
- self.dataset = dataset
612
- self.length = len(self.dataset)
613
- print(f"data scale: {self.length}")
614
-
615
- def __len__(self):
616
- return self.length
617
-
618
- def __getitem__(self, idx):
619
- if self.data_root is None:
620
- path = self.dataset[idx]["file_path"]
621
- else:
622
- path = os.path.join(self.data_root, self.dataset[idx]["file_path"])
623
- state_dict = load_file(path)
624
- return state_dict
625
-
626
-
627
- class TextDataset(Dataset):
628
- def __init__(self, ann_path, text_drop_ratio=0.0):
629
- print(f"loading annotations from {ann_path} ...")
630
- with open(ann_path, 'r') as f:
631
- self.dataset = json.load(f)
632
- self.length = len(self.dataset)
633
- print(f"data scale: {self.length}")
634
- self.text_drop_ratio = text_drop_ratio
635
-
636
- def __len__(self):
637
- return self.length
638
-
639
- def __getitem__(self, idx):
640
- while True:
641
- try:
642
- item = self.dataset[idx]
643
- text = item['text']
644
-
645
- # Randomly drop text (for classifier-free guidance)
646
- if random.random() < self.text_drop_ratio:
647
- text = ''
648
-
649
- sample = {
650
- "text": text,
651
- "idx": idx
652
- }
653
- return sample
654
-
655
- except Exception as e:
656
- print(f"Error at index {idx}: {e}, retrying with random index...")
657
- idx = np.random.randint(0, self.length - 1)
videox_fun/data/dataset_video.py DELETED
@@ -1,901 +0,0 @@
1
- import csv
2
- import gc
3
- import io
4
- import json
5
- import math
6
- import os
7
- import random
8
- from contextlib import contextmanager
9
- from threading import Thread
10
-
11
- import albumentations
12
- import cv2
13
- import librosa
14
- import numpy as np
15
- import torch
16
- import torchvision.transforms as transforms
17
- from decord import VideoReader
18
- from einops import rearrange
19
- from func_timeout import FunctionTimedOut, func_timeout
20
- from PIL import Image
21
- from torch.utils.data import BatchSampler, Sampler
22
- from torch.utils.data.dataset import Dataset
23
-
24
- from .utils import (VIDEO_READER_TIMEOUT, Camera, VideoReader_contextmanager,
25
- custom_meshgrid, get_random_mask, get_relative_pose,
26
- get_video_reader_batch, padding_image, process_pose_file,
27
- process_pose_params, ray_condition, resize_frame,
28
- resize_image_with_target_area)
29
-
30
-
31
- class WebVid10M(Dataset):
32
- def __init__(
33
- self,
34
- csv_path, video_folder,
35
- sample_size=256, sample_stride=4, sample_n_frames=16,
36
- enable_bucket=False, enable_inpaint=False, is_image=False,
37
- ):
38
- print(f"loading annotations from {csv_path} ...")
39
- with open(csv_path, 'r') as csvfile:
40
- self.dataset = list(csv.DictReader(csvfile))
41
- self.length = len(self.dataset)
42
- print(f"data scale: {self.length}")
43
-
44
- self.video_folder = video_folder
45
- self.sample_stride = sample_stride
46
- self.sample_n_frames = sample_n_frames
47
- self.enable_bucket = enable_bucket
48
- self.enable_inpaint = enable_inpaint
49
- self.is_image = is_image
50
-
51
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
52
- self.pixel_transforms = transforms.Compose([
53
- transforms.Resize(sample_size[0]),
54
- transforms.CenterCrop(sample_size),
55
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
56
- ])
57
-
58
- def get_batch(self, idx):
59
- video_dict = self.dataset[idx]
60
- videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
61
-
62
- video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
63
- video_reader = VideoReader(video_dir)
64
- video_length = len(video_reader)
65
-
66
- if not self.is_image:
67
- clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
68
- start_idx = random.randint(0, video_length - clip_length)
69
- batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
70
- else:
71
- batch_index = [random.randint(0, video_length - 1)]
72
-
73
- if not self.enable_bucket:
74
- pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
75
- pixel_values = pixel_values / 255.
76
- del video_reader
77
- else:
78
- pixel_values = video_reader.get_batch(batch_index).asnumpy()
79
-
80
- if self.is_image:
81
- pixel_values = pixel_values[0]
82
- return pixel_values, name
83
-
84
- def __len__(self):
85
- return self.length
86
-
87
- def __getitem__(self, idx):
88
- while True:
89
- try:
90
- pixel_values, name = self.get_batch(idx)
91
- break
92
-
93
- except Exception as e:
94
- print("Error info:", e)
95
- idx = random.randint(0, self.length-1)
96
-
97
- if not self.enable_bucket:
98
- pixel_values = self.pixel_transforms(pixel_values)
99
- if self.enable_inpaint:
100
- mask = get_random_mask(pixel_values.size())
101
- mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
102
- sample = dict(pixel_values=pixel_values, mask_pixel_values=mask_pixel_values, mask=mask, text=name)
103
- else:
104
- sample = dict(pixel_values=pixel_values, text=name)
105
- return sample
106
-
107
-
108
- class VideoDataset(Dataset):
109
- def __init__(
110
- self,
111
- ann_path, data_root=None,
112
- sample_size=256, sample_stride=4, sample_n_frames=16,
113
- enable_bucket=False, enable_inpaint=False
114
- ):
115
- print(f"loading annotations from {ann_path} ...")
116
- self.dataset = json.load(open(ann_path, 'r'))
117
- self.length = len(self.dataset)
118
- print(f"data scale: {self.length}")
119
-
120
- self.data_root = data_root
121
- self.sample_stride = sample_stride
122
- self.sample_n_frames = sample_n_frames
123
- self.enable_bucket = enable_bucket
124
- self.enable_inpaint = enable_inpaint
- # get_batch below references these attribute names; the defaults here are assumptions chosen so the class is self-contained
- self.video_sample_stride = sample_stride
- self.video_sample_n_frames = sample_n_frames
- self.video_length_drop_start = 0.0
- self.video_length_drop_end = 1.0
- self.text_drop_ratio = 0.0
125
-
126
- sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
127
- self.pixel_transforms = transforms.Compose(
128
- [
129
- transforms.Resize(sample_size[0]),
130
- transforms.CenterCrop(sample_size),
131
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
132
- ]
133
- )
134
-
135
- def get_batch(self, idx):
136
- video_dict = self.dataset[idx]
137
- video_id, text = video_dict['file_path'], video_dict['text']
138
-
139
- if self.data_root is None:
140
- video_dir = video_id
141
- else:
142
- video_dir = os.path.join(self.data_root, video_id)
143
-
144
- with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
145
- min_sample_n_frames = min(
146
- self.video_sample_n_frames,
147
- int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
148
- )
149
- if min_sample_n_frames == 0:
150
- raise ValueError(f"No Frames in video.")
151
-
152
- video_length = int(self.video_length_drop_end * len(video_reader))
153
- clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
154
- start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
155
- batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
156
-
157
- try:
158
- sample_args = (video_reader, batch_index)
159
- pixel_values = func_timeout(
160
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
161
- )
162
- except FunctionTimedOut:
163
- raise ValueError(f"Read {idx} timeout.")
164
- except Exception as e:
165
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
166
-
167
- if not self.enable_bucket:
168
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
169
- pixel_values = pixel_values / 255.
170
- del video_reader
171
- else:
172
- pixel_values = pixel_values
173
-
174
- if not self.enable_bucket:
175
- pixel_values = self.pixel_transforms(pixel_values)
176
-
177
- # Randomly drop the text (for classifier-free guidance)
178
- if random.random() < self.text_drop_ratio:
179
- text = ''
180
- return pixel_values, text
181
-
182
- def __len__(self):
183
- return self.length
184
-
185
- def __getitem__(self, idx):
186
- while True:
187
- sample = {}
188
- try:
189
- pixel_values, name = self.get_batch(idx)
190
- sample["pixel_values"] = pixel_values
191
- sample["text"] = name
192
- sample["idx"] = idx
193
- if len(sample) > 0:
194
- break
195
-
196
- except Exception as e:
197
- print(e, self.dataset[idx % len(self.dataset)])
198
- idx = random.randint(0, self.length-1)
199
-
200
- if self.enable_inpaint and not self.enable_bucket:
201
- mask = get_random_mask(pixel_values.size())
202
- mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
203
- sample["mask_pixel_values"] = mask_pixel_values
204
- sample["mask"] = mask
205
-
206
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
207
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
208
- sample["clip_pixel_values"] = clip_pixel_values
209
-
210
- return sample
211
-
212
-
213
- class VideoSpeechDataset(Dataset):
214
- def __init__(
215
- self,
216
- ann_path, data_root=None,
217
- video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
218
- enable_bucket=False, enable_inpaint=False,
219
- audio_sr=16000, # Added: target audio sample rate
220
- text_drop_ratio=0.1 # Added: text drop probability
221
- ):
222
- print(f"loading annotations from {ann_path} ...")
223
- self.dataset = json.load(open(ann_path, 'r'))
224
- self.length = len(self.dataset)
225
- print(f"data scale: {self.length}")
226
-
227
- self.data_root = data_root
228
- self.video_sample_stride = video_sample_stride
229
- self.video_sample_n_frames = video_sample_n_frames
230
- self.enable_bucket = enable_bucket
231
- self.enable_inpaint = enable_inpaint
232
- self.audio_sr = audio_sr
233
- self.text_drop_ratio = text_drop_ratio
234
-
235
- video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
236
- self.pixel_transforms = transforms.Compose(
237
- [
238
- transforms.Resize(video_sample_size[0]),
239
- transforms.CenterCrop(video_sample_size),
240
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
241
- ]
242
- )
243
-
244
- def get_batch(self, idx):
245
- video_dict = self.dataset[idx]
246
- video_id, text = video_dict['file_path'], video_dict['text']
247
- audio_id = video_dict['audio_path']
248
-
249
- if self.data_root is None:
250
- video_path = video_id
251
- else:
252
- video_path = os.path.join(self.data_root, video_id)
253
-
254
- if self.data_root is None:
255
- audio_path = audio_id
256
- else:
257
- audio_path = os.path.join(self.data_root, audio_id)
258
-
259
- if not os.path.exists(audio_path):
260
- raise FileNotFoundError(f"Audio file not found for {video_path}")
261
-
262
- with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
263
- total_frames = len(video_reader)
264
- fps = video_reader.get_avg_fps() # original video frame rate
265
-
266
- # Number of frames actually sampled (bounded by the video length)
267
- max_possible_frames = (total_frames - 1) // self.video_sample_stride + 1
268
- actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
269
- if actual_n_frames <= 0:
270
- raise ValueError(f"Video too short: {video_path}")
271
-
272
- # Randomly choose the start frame
273
- max_start = total_frames - (actual_n_frames - 1) * self.video_sample_stride - 1
274
- start_frame = random.randint(0, max_start) if max_start > 0 else 0
275
- frame_indices = [start_frame + i * self.video_sample_stride for i in range(actual_n_frames)]
276
-
277
- # Read the video frames
278
- try:
279
- sample_args = (video_reader, frame_indices)
280
- pixel_values = func_timeout(
281
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
282
- )
283
- except FunctionTimedOut:
284
- raise ValueError(f"Read {idx} timeout.")
285
- except Exception as e:
286
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
287
-
288
- # Video post-processing
289
- if not self.enable_bucket:
290
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
291
- pixel_values = pixel_values / 255.
292
- pixel_values = self.pixel_transforms(pixel_values)
293
-
294
- # === Added: load and crop the matching audio segment ===
295
- # Start/end time of the video clip (in seconds)
296
- start_time = start_frame / fps
297
- end_time = (start_frame + (actual_n_frames - 1) * self.video_sample_stride) / fps
298
- duration = end_time - start_time
299
-
300
- # Load the whole audio with librosa (librosa.load cannot seek precisely, so load first, then slice)
301
- audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr) # resample to the target sr
302
-
303
- # Convert times to sample indices
304
- start_sample = int(start_time * self.audio_sr)
305
- end_sample = int(end_time * self.audio_sr)
306
-
307
- # Safe slicing
308
- if start_sample >= len(audio_input):
309
- # Audio too short: pad with zeros or truncate
310
- audio_segment = np.zeros(int(duration * self.audio_sr), dtype=np.float32)
311
- else:
312
- audio_segment = audio_input[start_sample:end_sample]
313
- # Pad with zeros if the segment is too short
314
- target_len = int(duration * self.audio_sr)
315
- if len(audio_segment) < target_len:
316
- audio_segment = np.pad(audio_segment, (0, target_len - len(audio_segment)), mode='constant')
317
-
318
- # === Random text drop ===
319
- if random.random() < self.text_drop_ratio:
320
- text = ''
321
-
322
- return pixel_values, text, audio_segment, sample_rate
323
-
324
- def __len__(self):
325
- return self.length
326
-
327
- def __getitem__(self, idx):
328
- while True:
329
- sample = {}
330
- try:
331
- pixel_values, text, audio, sample_rate = self.get_batch(idx)
332
- sample["pixel_values"] = pixel_values
333
- sample["text"] = text
334
- sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
335
- sample["sample_rate"] = sample_rate
336
- sample["idx"] = idx
337
- break
338
- except Exception as e:
339
- print(f"Error processing {idx}: {e}, retrying with random idx...")
340
- idx = random.randint(0, self.length - 1)
341
-
342
- if self.enable_inpaint and not self.enable_bucket:
343
- mask = get_random_mask(pixel_values.size(), image_start_only=True)
344
- mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
345
- sample["mask_pixel_values"] = mask_pixel_values
346
- sample["mask"] = mask
347
-
348
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
349
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
350
- sample["clip_pixel_values"] = clip_pixel_values
351
-
352
- return sample
353
-
354
-
355
- class VideoSpeechControlDataset(Dataset):
356
- def __init__(
357
- self,
358
- ann_path, data_root=None,
359
- video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
360
- enable_bucket=False, enable_inpaint=False,
361
- audio_sr=16000,
362
- text_drop_ratio=0.1,
363
- enable_motion_info=False,
364
- motion_frames=73,
365
- ):
366
- print(f"loading annotations from {ann_path} ...")
367
- self.dataset = json.load(open(ann_path, 'r'))
368
- self.length = len(self.dataset)
369
- print(f"data scale: {self.length}")
370
-
371
- self.data_root = data_root
372
- self.video_sample_stride = video_sample_stride
373
- self.video_sample_n_frames = video_sample_n_frames
374
- self.enable_bucket = enable_bucket
375
- self.enable_inpaint = enable_inpaint
376
- self.audio_sr = audio_sr
377
- self.text_drop_ratio = text_drop_ratio
378
- self.enable_motion_info = enable_motion_info
379
- self.motion_frames = motion_frames
380
-
381
- video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
382
- self.pixel_transforms = transforms.Compose(
383
- [
384
- transforms.Resize(video_sample_size[0]),
385
- transforms.CenterCrop(video_sample_size),
386
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
387
- ]
388
- )
389
-
390
- self.video_sample_size = video_sample_size
391
-
392
- def get_batch(self, idx):
393
- video_dict = self.dataset[idx]
394
- video_id, text = video_dict['file_path'], video_dict['text']
395
- audio_id = video_dict['audio_path']
396
- control_video_id = video_dict['control_file_path']
397
-
398
- if self.data_root is None:
399
- video_path = video_id
400
- else:
401
- video_path = os.path.join(self.data_root, video_id)
402
-
403
- if self.data_root is None:
404
- audio_path = audio_id
405
- else:
406
- audio_path = os.path.join(self.data_root, audio_id)
407
-
408
- if self.data_root is None:
409
- control_video_id = control_video_id
410
- else:
411
- control_video_id = os.path.join(self.data_root, control_video_id)
412
-
413
- if not os.path.exists(audio_path):
414
- raise FileNotFoundError(f"Audio file not found for {video_path}")
415
-
416
- # Video information
417
- with VideoReader_contextmanager(video_path, num_threads=2) as video_reader:
418
- total_frames = len(video_reader)
419
- fps = video_reader.get_avg_fps()
420
- if fps <= 0:
421
- raise ValueError(f"Video has negative fps: {video_path}")
422
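- # Increase the sampling stride until the effective frame rate after subsampling drops to 30 fps or below.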
- local_video_sample_stride = self.video_sample_stride
423
- new_fps = int(fps // local_video_sample_stride)
424
- while new_fps > 30:
425
- local_video_sample_stride = local_video_sample_stride + 1
426
- new_fps = int(fps // local_video_sample_stride)
427
-
428
- max_possible_frames = (total_frames - 1) // local_video_sample_stride + 1
429
- actual_n_frames = min(self.video_sample_n_frames, max_possible_frames)
430
- if actual_n_frames <= 0:
431
- raise ValueError(f"Video too short: {video_path}")
432
-
433
- max_start = total_frames - (actual_n_frames - 1) * local_video_sample_stride - 1
434
- start_frame = random.randint(0, max_start) if max_start > 0 else 0
435
- frame_indices = [start_frame + i * local_video_sample_stride for i in range(actual_n_frames)]
436
-
437
- try:
438
- sample_args = (video_reader, frame_indices)
439
- pixel_values = func_timeout(
440
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
441
- )
442
- except FunctionTimedOut:
443
- raise ValueError(f"Read {idx} timeout.")
444
- except Exception as e:
445
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
446
-
447
- _, height, width, channel = np.shape(pixel_values)
448
- if self.enable_motion_info:
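- # Motion context: strided frames preceding the sampled clip (up to self.motion_frames of them); positions with no history stay mid-gray (127.5).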
449
- motion_pixel_values = np.ones([self.motion_frames, height, width, channel]) * 127.5
450
- if start_frame > 0:
451
- motion_max_possible_frames = (start_frame - 1) // local_video_sample_stride + 1
452
- motion_frame_indices = [0 + i * local_video_sample_stride for i in range(motion_max_possible_frames)]
453
- motion_frame_indices = motion_frame_indices[-self.motion_frames:]
454
-
455
- _motion_sample_args = (video_reader, motion_frame_indices)
456
- _motion_pixel_values = func_timeout(
457
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=_motion_sample_args
458
- )
459
- motion_pixel_values[-len(motion_frame_indices):] = _motion_pixel_values
460
-
461
- if not self.enable_bucket:
462
- motion_pixel_values = torch.from_numpy(motion_pixel_values).permute(0, 3, 1, 2).contiguous()
463
- motion_pixel_values = motion_pixel_values / 255.
464
- motion_pixel_values = self.pixel_transforms(motion_pixel_values)
465
- else:
466
- motion_pixel_values = None
467
-
468
- if not self.enable_bucket:
469
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
470
- pixel_values = pixel_values / 255.
471
- pixel_values = self.pixel_transforms(pixel_values)
472
-
473
- # Audio information
474
- start_time = start_frame / fps
475
- end_time = (start_frame + (actual_n_frames - 1) * local_video_sample_stride) / fps
476
- duration = end_time - start_time
477
-
478
- audio_input, sample_rate = librosa.load(audio_path, sr=self.audio_sr)
479
- start_sample = int(start_time * self.audio_sr)
480
- end_sample = int(end_time * self.audio_sr)
481
-
482
- if start_sample >= len(audio_input):
483
- raise ValueError(f"Audio file too short: {audio_path}")
484
- else:
485
- audio_segment = audio_input[start_sample:end_sample]
486
- target_len = int(duration * self.audio_sr)
487
- if len(audio_segment) < target_len:
488
- raise ValueError(f"Audio file too short: {audio_path}")
489
-
490
- # Control information
491
- with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
492
- try:
493
- sample_args = (control_video_reader, frame_indices)
494
- control_pixel_values = func_timeout(
495
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
496
- )
497
- resized_frames = []
498
- for i in range(len(control_pixel_values)):
499
- frame = control_pixel_values[i]
500
- resized_frame = resize_frame(frame, max(self.video_sample_size))
501
- resized_frames.append(resized_frame)
502
- control_pixel_values = np.array(resized_frames)
503
- except FunctionTimedOut:
504
- raise ValueError(f"Read {idx} timeout.")
505
- except Exception as e:
506
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
507
-
508
- if not self.enable_bucket:
509
- control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
510
- control_pixel_values = control_pixel_values / 255.
511
- del control_video_reader
512
- else:
513
- control_pixel_values = control_pixel_values
514
-
515
- if not self.enable_bucket:
516
- control_pixel_values = self.pixel_transforms(control_pixel_values)
517
-
518
- if random.random() < self.text_drop_ratio:
519
- text = ''
520
-
521
- return pixel_values, motion_pixel_values, control_pixel_values, text, audio_segment, sample_rate, new_fps
522
-
523
- def __len__(self):
524
- return self.length
525
-
526
- def __getitem__(self, idx):
527
- while True:
528
- sample = {}
529
- try:
530
- pixel_values, motion_pixel_values, control_pixel_values, text, audio, sample_rate, new_fps = self.get_batch(idx)
531
- sample["pixel_values"] = pixel_values
532
- sample["motion_pixel_values"] = motion_pixel_values
533
- sample["control_pixel_values"] = control_pixel_values
534
- sample["text"] = text
535
- sample["audio"] = torch.from_numpy(audio).float() # 转为 tensor
536
- sample["sample_rate"] = sample_rate
537
- sample["fps"] = new_fps
538
- sample["idx"] = idx
539
- break
540
- except Exception as e:
541
- print(f"Error processing {idx}: {e}, retrying with random idx...")
542
- idx = random.randint(0, self.length - 1)
543
-
544
- if self.enable_inpaint and not self.enable_bucket:
545
- mask = get_random_mask(pixel_values.size(), image_start_only=True)
546
- mask_pixel_values = pixel_values * (1 - mask) + torch.zeros_like(pixel_values) * mask
547
- sample["mask_pixel_values"] = mask_pixel_values
548
- sample["mask"] = mask
549
-
550
- clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
551
- clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
552
- sample["clip_pixel_values"] = clip_pixel_values
553
-
554
- return sample
555
-
556
-
557
- class VideoAnimateDataset(Dataset):
558
- def __init__(
559
- self,
560
- ann_path, data_root=None,
561
- video_sample_size=512,
562
- video_sample_stride=4,
563
- video_sample_n_frames=16,
564
- video_repeat=0,
565
- text_drop_ratio=0.1,
566
- enable_bucket=False,
567
- video_length_drop_start=0.1,
568
- video_length_drop_end=0.9,
569
- return_file_name=False,
570
- ):
571
- # Loading annotations from files
572
- print(f"loading annotations from {ann_path} ...")
573
- if ann_path.endswith('.csv'):
574
- with open(ann_path, 'r') as csvfile:
575
- dataset = list(csv.DictReader(csvfile))
576
- elif ann_path.endswith('.json'):
577
- dataset = json.load(open(ann_path))
578
-
579
- self.data_root = data_root
580
-
581
- # Repeat video entries to balance the number of images and videos.
582
- if video_repeat > 0:
583
- self.dataset = []
584
- for data in dataset:
585
- if data.get('type', 'image') != 'video':
586
- self.dataset.append(data)
587
-
588
- for _ in range(video_repeat):
589
- for data in dataset:
590
- if data.get('type', 'image') == 'video':
591
- self.dataset.append(data)
592
- else:
593
- self.dataset = dataset
594
- del dataset
595
-
596
- self.length = len(self.dataset)
597
- print(f"data scale: {self.length}")
598
- # TODO: enable bucket training
599
- self.enable_bucket = enable_bucket
600
- self.text_drop_ratio = text_drop_ratio
601
-
602
- self.video_length_drop_start = video_length_drop_start
603
- self.video_length_drop_end = video_length_drop_end
604
-
605
- # Video params
606
- self.video_sample_stride = video_sample_stride
607
- self.video_sample_n_frames = video_sample_n_frames
608
- self.video_sample_size = tuple(video_sample_size) if not isinstance(video_sample_size, int) else (video_sample_size, video_sample_size)
609
- self.video_transforms = transforms.Compose(
610
- [
611
- transforms.Resize(min(self.video_sample_size)),
612
- transforms.CenterCrop(self.video_sample_size),
613
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
614
- ]
615
- )
616
-
617
- self.larger_side_of_image_and_video = min(self.video_sample_size)
618
-
619
- def get_batch(self, idx):
620
- data_info = self.dataset[idx % len(self.dataset)]
621
- video_id, text = data_info['file_path'], data_info['text']
622
-
623
- if self.data_root is None:
624
- video_dir = video_id
625
- else:
626
- video_dir = os.path.join(self.data_root, video_id)
627
-
628
- with VideoReader_contextmanager(video_dir, num_threads=2) as video_reader:
629
- min_sample_n_frames = min(
630
- self.video_sample_n_frames,
631
- int(len(video_reader) * (self.video_length_drop_end - self.video_length_drop_start) // self.video_sample_stride)
632
- )
633
- if min_sample_n_frames == 0:
634
- raise ValueError(f"No Frames in video.")
635
-
636
- video_length = int(self.video_length_drop_end * len(video_reader))
637
- clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
638
- start_idx = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
639
- batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
640
-
641
- try:
642
- sample_args = (video_reader, batch_index)
643
- pixel_values = func_timeout(
644
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
645
- )
646
- resized_frames = []
647
- for i in range(len(pixel_values)):
648
- frame = pixel_values[i]
649
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
650
- resized_frames.append(resized_frame)
651
- pixel_values = np.array(resized_frames)
652
- except FunctionTimedOut:
653
- raise ValueError(f"Read {idx} timeout.")
654
- except Exception as e:
655
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
656
-
657
- if not self.enable_bucket:
658
- pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous()
659
- pixel_values = pixel_values / 255.
660
- del video_reader
661
- else:
662
- pixel_values = pixel_values
663
-
664
- if not self.enable_bucket:
665
- pixel_values = self.video_transforms(pixel_values)
666
-
667
- # Randomly drop the text (for classifier-free guidance)
668
- if random.random() < self.text_drop_ratio:
669
- text = ''
670
-
671
- control_video_id = data_info['control_file_path']
672
-
673
- if control_video_id is not None:
674
- if self.data_root is None:
675
- control_video_id = control_video_id
676
- else:
677
- control_video_id = os.path.join(self.data_root, control_video_id)
678
-
679
- if control_video_id is not None:
680
- with VideoReader_contextmanager(control_video_id, num_threads=2) as control_video_reader:
681
- try:
682
- sample_args = (control_video_reader, batch_index)
683
- control_pixel_values = func_timeout(
684
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
685
- )
686
- resized_frames = []
687
- for i in range(len(control_pixel_values)):
688
- frame = control_pixel_values[i]
689
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
690
- resized_frames.append(resized_frame)
691
- control_pixel_values = np.array(resized_frames)
692
- except FunctionTimedOut:
693
- raise ValueError(f"Read {idx} timeout.")
694
- except Exception as e:
695
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
696
-
697
- if not self.enable_bucket:
698
- control_pixel_values = torch.from_numpy(control_pixel_values).permute(0, 3, 1, 2).contiguous()
699
- control_pixel_values = control_pixel_values / 255.
700
- del control_video_reader
701
- else:
702
- control_pixel_values = control_pixel_values
703
-
704
- if not self.enable_bucket:
705
- control_pixel_values = self.video_transforms(control_pixel_values)
706
- else:
707
- if not self.enable_bucket:
708
- control_pixel_values = torch.zeros_like(pixel_values)
709
- else:
710
- control_pixel_values = np.zeros_like(pixel_values)
711
-
712
- face_video_id = data_info['face_file_path']
713
-
714
- if face_video_id is not None:
715
- if self.data_root is None:
716
- face_video_id = face_video_id
717
- else:
718
- face_video_id = os.path.join(self.data_root, face_video_id)
719
-
720
- if face_video_id is not None:
721
- with VideoReader_contextmanager(face_video_id, num_threads=2) as face_video_reader:
722
- try:
723
- sample_args = (face_video_reader, batch_index)
724
- face_pixel_values = func_timeout(
725
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
726
- )
727
- resized_frames = []
728
- for i in range(len(face_pixel_values)):
729
- frame = face_pixel_values[i]
730
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
731
- resized_frames.append(resized_frame)
732
- face_pixel_values = np.array(resized_frames)
733
- except FunctionTimedOut:
734
- raise ValueError(f"Read {idx} timeout.")
735
- except Exception as e:
736
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
737
-
738
- if not self.enable_bucket:
739
- face_pixel_values = torch.from_numpy(face_pixel_values).permute(0, 3, 1, 2).contiguous()
740
- face_pixel_values = face_pixel_values / 255.
741
- del face_video_reader
742
- else:
743
- face_pixel_values = face_pixel_values
744
-
745
- if not self.enable_bucket:
746
- face_pixel_values = self.video_transforms(face_pixel_values)
747
- else:
748
- if not self.enable_bucket:
749
- face_pixel_values = torch.zeros_like(pixel_values)
750
- else:
751
- face_pixel_values = np.zeros_like(pixel_values)
752
-
753
- background_video_id = data_info.get('background_file_path', None)
754
-
755
- if background_video_id is not None:
756
- if self.data_root is None:
757
- background_video_id = background_video_id
758
- else:
759
- background_video_id = os.path.join(self.data_root, background_video_id)
760
-
761
- if background_video_id is not None:
762
- with VideoReader_contextmanager(background_video_id, num_threads=2) as background_video_reader:
763
- try:
764
- sample_args = (background_video_reader, batch_index)
765
- background_pixel_values = func_timeout(
766
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
767
- )
768
- resized_frames = []
769
- for i in range(len(background_pixel_values)):
770
- frame = background_pixel_values[i]
771
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
772
- resized_frames.append(resized_frame)
773
- background_pixel_values = np.array(resized_frames)
774
- except FunctionTimedOut:
775
- raise ValueError(f"Read {idx} timeout.")
776
- except Exception as e:
777
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
778
-
779
- if not self.enable_bucket:
780
- background_pixel_values = torch.from_numpy(background_pixel_values).permute(0, 3, 1, 2).contiguous()
781
- background_pixel_values = background_pixel_values / 255.
782
- del background_video_reader
783
- else:
784
- background_pixel_values = background_pixel_values
785
-
786
- if not self.enable_bucket:
787
- background_pixel_values = self.video_transforms(background_pixel_values)
788
- else:
789
- if not self.enable_bucket:
790
- background_pixel_values = torch.ones_like(pixel_values) * 127.5
791
- else:
792
- background_pixel_values = np.ones_like(pixel_values) * 127.5
793
-
794
- mask_video_id = data_info.get('mask_file_path', None)
795
-
796
- if mask_video_id is not None:
797
- if self.data_root is None:
798
- mask_video_id = mask_video_id
799
- else:
800
- mask_video_id = os.path.join(self.data_root, mask_video_id)
801
-
802
- if mask_video_id is not None:
803
- with VideoReader_contextmanager(mask_video_id, num_threads=2) as mask_video_reader:
804
- try:
805
- sample_args = (mask_video_reader, batch_index)
806
- mask = func_timeout(
807
- VIDEO_READER_TIMEOUT, get_video_reader_batch, args=sample_args
808
- )
809
- resized_frames = []
810
- for i in range(len(mask)):
811
- frame = mask[i]
812
- resized_frame = resize_frame(frame, self.larger_side_of_image_and_video)
813
- resized_frames.append(resized_frame)
814
- mask = np.array(resized_frames)
815
- except FunctionTimedOut:
816
- raise ValueError(f"Read {idx} timeout.")
817
- except Exception as e:
818
- raise ValueError(f"Failed to extract frames from video. Error is {e}.")
819
-
820
- if not self.enable_bucket:
821
- mask = torch.from_numpy(mask).permute(0, 3, 1, 2).contiguous()
822
- mask = mask / 255.
823
- del mask_video_reader
824
- else:
825
- mask = mask
826
- else:
827
- if not self.enable_bucket:
828
- mask = torch.ones_like(pixel_values)
829
- else:
830
- mask = np.ones_like(pixel_values) * 255
831
- mask = mask[:, :, :, :1]
832
-
833
- ref_pixel_values_path = data_info.get('ref_file_path', [])
834
- if self.data_root is not None:
835
- ref_pixel_values_path = os.path.join(self.data_root, ref_pixel_values_path)
836
- ref_pixel_values = Image.open(ref_pixel_values_path).convert('RGB')
837
-
838
- if not self.enable_bucket:
839
- raise ValueError("Not enable_bucket is not supported now. ")
840
- else:
841
- ref_pixel_values = np.array(ref_pixel_values)
842
-
843
- return pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, text, "video"
844
-
845
- def __len__(self):
846
- return self.length
847
-
848
- def __getitem__(self, idx):
849
- data_info = self.dataset[idx % len(self.dataset)]
850
- data_type = data_info.get('type', 'image')
851
- while True:
852
- sample = {}
853
- try:
854
- data_info_local = self.dataset[idx % len(self.dataset)]
855
- data_type_local = data_info_local.get('type', 'image')
856
- if data_type_local != data_type:
857
- raise ValueError("data_type_local != data_type")
858
-
859
- pixel_values, control_pixel_values, face_pixel_values, background_pixel_values, mask, ref_pixel_values, name, data_type = \
860
- self.get_batch(idx)
861
-
862
- sample["pixel_values"] = pixel_values
863
- sample["control_pixel_values"] = control_pixel_values
864
- sample["face_pixel_values"] = face_pixel_values
865
- sample["background_pixel_values"] = background_pixel_values
866
- sample["mask"] = mask
867
- sample["ref_pixel_values"] = ref_pixel_values
868
- sample["clip_pixel_values"] = ref_pixel_values
869
- sample["text"] = name
870
- sample["data_type"] = data_type
871
- sample["idx"] = idx
872
-
873
- if len(sample) > 0:
874
- break
875
- except Exception as e:
876
- print(e, self.dataset[idx % len(self.dataset)])
877
- idx = random.randint(0, self.length-1)
878
-
879
- return sample
880
-
881
-
882
- if __name__ == "__main__":
883
- if 1:
884
- dataset = VideoDataset(
885
- ann_path="./webvidval/results_2M_val.json",
886
- sample_size=256,
887
- sample_stride=4, sample_n_frames=16,
888
- )
889
-
890
- if 0:
891
- dataset = WebVid10M(
892
- csv_path="./webvid/results_2M_val.csv",
893
- video_folder="./webvid/2M_val",
894
- sample_size=256,
895
- sample_stride=4, sample_n_frames=16,
896
- is_image=False,
897
- )
898
-
899
- dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
900
- for idx, batch in enumerate(dataloader):
901
- print(batch["pixel_values"].shape, len(batch["text"]))
videox_fun/data/utils.py DELETED
@@ -1,347 +0,0 @@
1
- import csv
2
- import gc
3
- import io
4
- import json
5
- import math
6
- import os
7
- import random
8
- from contextlib import contextmanager
9
- from random import shuffle
10
- from threading import Thread
11
-
12
- import albumentations
13
- import cv2
14
- import numpy as np
15
- import torch
16
- import torch.nn.functional as F
17
- import torchvision.transforms as transforms
18
- from decord import VideoReader
19
- from einops import rearrange
20
- from func_timeout import FunctionTimedOut, func_timeout
21
- from packaging import version as pver
22
- from PIL import Image
23
- from safetensors.torch import load_file
24
- from torch.utils.data import BatchSampler, Sampler
25
- from torch.utils.data.dataset import Dataset
26
-
27
- VIDEO_READER_TIMEOUT = 20
28
-
29
- def get_random_mask(shape, image_start_only=False):
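- # Build a random spatio-temporal inpainting mask; mask_index picks one of several patterns (rectangle, full frame, temporal segments, random noise, per-frame blocks, ellipse, circle, random whole frames).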
30
- f, c, h, w = shape
31
- mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
32
-
33
- if not image_start_only:
34
- if f != 1:
35
- mask_index = np.random.choice([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], p=[0.05, 0.2, 0.2, 0.2, 0.05, 0.05, 0.05, 0.1, 0.05, 0.05])
36
- else:
37
- mask_index = np.random.choice([0, 1, 7, 8], p = [0.2, 0.7, 0.05, 0.05])
38
- if mask_index == 0:
39
- center_x = torch.randint(0, w, (1,)).item()
40
- center_y = torch.randint(0, h, (1,)).item()
41
- block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
42
- block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
43
-
44
- start_x = max(center_x - block_size_x // 2, 0)
45
- end_x = min(center_x + block_size_x // 2, w)
46
- start_y = max(center_y - block_size_y // 2, 0)
47
- end_y = min(center_y + block_size_y // 2, h)
48
- mask[:, :, start_y:end_y, start_x:end_x] = 1
49
- elif mask_index == 1:
50
- mask[:, :, :, :] = 1
51
- elif mask_index == 2:
52
- mask_frame_index = np.random.randint(1, 5)
53
- mask[mask_frame_index:, :, :, :] = 1
54
- elif mask_index == 3:
55
- mask_frame_index = np.random.randint(1, 5)
56
- mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
57
- elif mask_index == 4:
58
- center_x = torch.randint(0, w, (1,)).item()
59
- center_y = torch.randint(0, h, (1,)).item()
60
- block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item() # block width range
61
- block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item() # block height range
62
-
63
- start_x = max(center_x - block_size_x // 2, 0)
64
- end_x = min(center_x + block_size_x // 2, w)
65
- start_y = max(center_y - block_size_y // 2, 0)
66
- end_y = min(center_y + block_size_y // 2, h)
67
-
68
- mask_frame_before = np.random.randint(0, f // 2)
69
- mask_frame_after = np.random.randint(f // 2, f)
70
- mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
71
- elif mask_index == 5:
72
- mask = torch.randint(0, 2, (f, 1, h, w), dtype=torch.uint8)
73
- elif mask_index == 6:
74
- num_frames_to_mask = random.randint(1, max(f // 2, 1))
75
- frames_to_mask = random.sample(range(f), num_frames_to_mask)
76
-
77
- for i in frames_to_mask:
78
- block_height = random.randint(1, h // 4)
79
- block_width = random.randint(1, w // 4)
80
- top_left_y = random.randint(0, h - block_height)
81
- top_left_x = random.randint(0, w - block_width)
82
- mask[i, 0, top_left_y:top_left_y + block_height, top_left_x:top_left_x + block_width] = 1
83
- elif mask_index == 7:
84
- center_x = torch.randint(0, w, (1,)).item()
85
- center_y = torch.randint(0, h, (1,)).item()
86
- a = torch.randint(min(w, h) // 8, min(w, h) // 4, (1,)).item() # semi-major axis
87
- b = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item() # semi-minor axis
88
-
89
- for i in range(h):
90
- for j in range(w):
91
- if ((i - center_y) ** 2) / (b ** 2) + ((j - center_x) ** 2) / (a ** 2) < 1:
92
- mask[:, :, i, j] = 1
93
- elif mask_index == 8:
94
- center_x = torch.randint(0, w, (1,)).item()
95
- center_y = torch.randint(0, h, (1,)).item()
96
- radius = torch.randint(min(h, w) // 8, min(h, w) // 4, (1,)).item()
97
- for i in range(h):
98
- for j in range(w):
99
- if (i - center_y) ** 2 + (j - center_x) ** 2 < radius ** 2:
100
- mask[:, :, i, j] = 1
101
- elif mask_index == 9:
102
- for idx in range(f):
103
- if np.random.rand() > 0.5:
104
- mask[idx, :, :, :] = 1
105
- else:
106
- raise ValueError(f"The mask_index {mask_index} is not define")
107
- else:
108
- if f != 1:
109
- mask[1:, :, :, :] = 1
110
- else:
111
- mask[:, :, :, :] = 1
112
- return mask
113
-
114
- @contextmanager
115
- def VideoReader_contextmanager(*args, **kwargs):
116
- vr = VideoReader(*args, **kwargs)
117
- try:
118
- yield vr
119
- finally:
120
- del vr
121
- gc.collect()
122
-
123
- def get_video_reader_batch(video_reader, batch_index):
124
- frames = video_reader.get_batch(batch_index).asnumpy()
125
- return frames
126
-
127
- def resize_frame(frame, target_short_side):
128
- h, w, _ = frame.shape
129
- if h < w:
130
- if target_short_side > h:
131
- return frame
132
- new_h = target_short_side
133
- new_w = int(target_short_side * w / h)
134
- else:
135
- if target_short_side > w:
136
- return frame
137
- new_w = target_short_side
138
- new_h = int(target_short_side * h / w)
139
-
140
- resized_frame = cv2.resize(frame, (new_w, new_h))
141
- return resized_frame
142
-
143
- def padding_image(images, new_width, new_height):
144
- new_image = Image.new('RGB', (new_width, new_height), (255, 255, 255))
145
-
146
- aspect_ratio = images.width / images.height
147
- if new_width / new_height > 1:
148
- if aspect_ratio > new_width / new_height:
149
- new_img_width = new_width
150
- new_img_height = int(new_img_width / aspect_ratio)
151
- else:
152
- new_img_height = new_height
153
- new_img_width = int(new_img_height * aspect_ratio)
154
- else:
155
- if aspect_ratio > new_width / new_height:
156
- new_img_width = new_width
157
- new_img_height = int(new_img_width / aspect_ratio)
158
- else:
159
- new_img_height = new_height
160
- new_img_width = int(new_img_height * aspect_ratio)
161
-
162
- resized_img = images.resize((new_img_width, new_img_height))
163
-
164
- paste_x = (new_width - new_img_width) // 2
165
- paste_y = (new_height - new_img_height) // 2
166
-
167
- new_image.paste(resized_img, (paste_x, paste_y))
168
-
169
- return new_image
170
-
171
- def resize_image_with_target_area(img: Image.Image, target_area: int = 1024 * 1024) -> Image.Image:
172
- """
173
- 将 PIL 图像缩放到接近指定像素面积(target_area),保持原始宽高比,
174
- 并确保新宽度和高度均为 32 的整数倍。
175
-
176
- 参数:
177
- img (PIL.Image.Image): 输入图像
178
- target_area (int): 目标像素总面积,例如 1024*1024 = 1048576
179
-
180
- 返回:
181
- PIL.Image.Image: Resize 后的图像
182
- """
183
- orig_w, orig_h = img.size
184
- if orig_w == 0 or orig_h == 0:
185
- raise ValueError("Input image has zero width or height.")
186
-
187
- ratio = orig_w / orig_h
188
- ideal_width = math.sqrt(target_area * ratio)
189
- ideal_height = ideal_width / ratio
190
-
191
- new_width = round(ideal_width / 32) * 32
192
- new_height = round(ideal_height / 32) * 32
193
-
194
- new_width = max(32, new_width)
195
- new_height = max(32, new_height)
196
-
197
- new_width = int(new_width)
198
- new_height = int(new_height)
199
-
200
- resized_img = img.resize((new_width, new_height), Image.LANCZOS)
201
- return resized_img
202
-
203
- class Camera(object):
204
- """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
205
- """
206
- def __init__(self, entry):
207
- fx, fy, cx, cy = entry[1:5]
208
- self.fx = fx
209
- self.fy = fy
210
- self.cx = cx
211
- self.cy = cy
212
- w2c_mat = np.array(entry[7:]).reshape(3, 4)
213
- w2c_mat_4x4 = np.eye(4)
214
- w2c_mat_4x4[:3, :] = w2c_mat
215
- self.w2c_mat = w2c_mat_4x4
216
- self.c2w_mat = np.linalg.inv(w2c_mat_4x4)
217
-
218
- def custom_meshgrid(*args):
219
- """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
220
- """
221
- # ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
222
- if pver.parse(torch.__version__) < pver.parse('1.10'):
223
- return torch.meshgrid(*args)
224
- else:
225
- return torch.meshgrid(*args, indexing='ij')
226
-
227
- def get_relative_pose(cam_params):
228
- """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
229
- """
230
- abs_w2cs = [cam_param.w2c_mat for cam_param in cam_params]
231
- abs_c2ws = [cam_param.c2w_mat for cam_param in cam_params]
232
- cam_to_origin = 0
233
- target_cam_c2w = np.array([
234
- [1, 0, 0, 0],
235
- [0, 1, 0, -cam_to_origin],
236
- [0, 0, 1, 0],
237
- [0, 0, 0, 1]
238
- ])
239
- abs2rel = target_cam_c2w @ abs_w2cs[0]
240
- ret_poses = [target_cam_c2w, ] + [abs2rel @ abs_c2w for abs_c2w in abs_c2ws[1:]]
241
- ret_poses = np.array(ret_poses, dtype=np.float32)
242
- return ret_poses
243
-
244
- def ray_condition(K, c2w, H, W, device):
245
- """Copied from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
246
- """
247
- # c2w: B, V, 4, 4
248
- # K: B, V, 4
249
-
250
- B = K.shape[0]
251
-
252
- j, i = custom_meshgrid(
253
- torch.linspace(0, H - 1, H, device=device, dtype=c2w.dtype),
254
- torch.linspace(0, W - 1, W, device=device, dtype=c2w.dtype),
255
- )
256
- i = i.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
257
- j = j.reshape([1, 1, H * W]).expand([B, 1, H * W]) + 0.5 # [B, HxW]
258
-
259
- fx, fy, cx, cy = K.chunk(4, dim=-1) # B,V, 1
260
-
261
- zs = torch.ones_like(i) # [B, HxW]
262
- xs = (i - cx) / fx * zs
263
- ys = (j - cy) / fy * zs
264
- zs = zs.expand_as(ys)
265
-
266
- directions = torch.stack((xs, ys, zs), dim=-1) # B, V, HW, 3
267
- directions = directions / directions.norm(dim=-1, keepdim=True) # B, V, HW, 3
268
-
269
- rays_d = directions @ c2w[..., :3, :3].transpose(-1, -2) # B, V, 3, HW
270
- rays_o = c2w[..., :3, 3] # B, V, 3
271
- rays_o = rays_o[:, :, None].expand_as(rays_d) # B, V, 3, HW
272
- # c2w @ dirctions
273
- rays_dxo = torch.cross(rays_o, rays_d)
274
- plucker = torch.cat([rays_dxo, rays_d], dim=-1)
275
- plucker = plucker.reshape(B, c2w.shape[1], H, W, 6) # B, V, H, W, 6
276
- # plucker = plucker.permute(0, 1, 4, 2, 3)
277
- return plucker
278
-
279
- def process_pose_file(pose_file_path, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu', return_poses=False):
280
- """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
281
- """
282
- with open(pose_file_path, 'r') as f:
283
- poses = f.readlines()
284
-
285
- poses = [pose.strip().split(' ') for pose in poses[1:]]
286
- cam_params = [[float(x) for x in pose] for pose in poses]
287
- if return_poses:
288
- return cam_params
289
- else:
290
- cam_params = [Camera(cam_param) for cam_param in cam_params]
291
-
292
- sample_wh_ratio = width / height
293
- pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
294
-
295
- if pose_wh_ratio > sample_wh_ratio:
296
- resized_ori_w = height * pose_wh_ratio
297
- for cam_param in cam_params:
298
- cam_param.fx = resized_ori_w * cam_param.fx / width
299
- else:
300
- resized_ori_h = width / pose_wh_ratio
301
- for cam_param in cam_params:
302
- cam_param.fy = resized_ori_h * cam_param.fy / height
303
-
304
- intrinsic = np.asarray([[cam_param.fx * width,
305
- cam_param.fy * height,
306
- cam_param.cx * width,
307
- cam_param.cy * height]
308
- for cam_param in cam_params], dtype=np.float32)
309
-
310
- K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
311
- c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
312
- c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
313
- plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
314
- plucker_embedding = plucker_embedding[None]
315
- plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
316
- return plucker_embedding
317
-
318
- def process_pose_params(cam_params, width=672, height=384, original_pose_width=1280, original_pose_height=720, device='cpu'):
319
- """Modified from https://github.com/hehao13/CameraCtrl/blob/main/inference.py
320
- """
321
- cam_params = [Camera(cam_param) for cam_param in cam_params]
322
-
323
- sample_wh_ratio = width / height
324
- pose_wh_ratio = original_pose_width / original_pose_height # Assuming placeholder ratios, change as needed
325
-
326
- if pose_wh_ratio > sample_wh_ratio:
327
- resized_ori_w = height * pose_wh_ratio
328
- for cam_param in cam_params:
329
- cam_param.fx = resized_ori_w * cam_param.fx / width
330
- else:
331
- resized_ori_h = width / pose_wh_ratio
332
- for cam_param in cam_params:
333
- cam_param.fy = resized_ori_h * cam_param.fy / height
334
-
335
- intrinsic = np.asarray([[cam_param.fx * width,
336
- cam_param.fy * height,
337
- cam_param.cx * width,
338
- cam_param.cy * height]
339
- for cam_param in cam_params], dtype=np.float32)
340
-
341
- K = torch.as_tensor(intrinsic)[None] # [1, 1, 4]
342
- c2ws = get_relative_pose(cam_params) # Assuming this function is defined elsewhere
343
- c2ws = torch.as_tensor(c2ws)[None] # [1, n_frame, 4, 4]
344
- plucker_embedding = ray_condition(K, c2ws, height, width, device=device)[0].permute(0, 3, 1, 2).contiguous() # V, 6, H, W
345
- plucker_embedding = plucker_embedding[None]
346
- plucker_embedding = rearrange(plucker_embedding, "b f c h w -> b f h w c")[0]
347
- return plucker_embedding
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
videox_fun/dist/__init__.py DELETED
@@ -1,72 +0,0 @@
- import importlib.util
- 
- from .cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
- from .flux2_xfuser import Flux2MultiGPUsAttnProcessor2_0
- from .flux_xfuser import FluxMultiGPUsAttnProcessor2_0
- from .fsdp import shard_model
- from .fuser import (get_sequence_parallel_rank,
-                     get_sequence_parallel_world_size, get_sp_group,
-                     get_world_group, init_distributed_environment,
-                     initialize_model_parallel, sequence_parallel_all_gather,
-                     sequence_parallel_chunk, set_multi_gpus_devices,
-                     xFuserLongContextAttention)
- from .hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
- from .qwen_xfuser import QwenImageMultiGPUsAttnProcessor2_0
- from .wan_xfuser import usp_attn_forward, usp_attn_s2v_forward
- from .z_image_xfuser import ZMultiGPUsSingleStreamAttnProcessor
- 
- # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
- if importlib.util.find_spec("paifuser") is not None:
-     # --------------------------------------------------------------- #
-     #   The simple_wrapper is used to solve the problem
-     #   about conflicts between cython and torch.compile
-     # --------------------------------------------------------------- #
-     def simple_wrapper(func):
-         def inner(*args, **kwargs):
-             return func(*args, **kwargs)
-         return inner
- 
-     # --------------------------------------------------------------- #
-     #   Sparse Attention Kernel
-     # --------------------------------------------------------------- #
-     from paifuser.models import parallel_magvit_vae
-     from paifuser.ops import wan_usp_sparse_attention_wrapper
- 
-     from . import wan_xfuser
- 
-     # --------------------------------------------------------------- #
-     #   Sparse Attention
-     # --------------------------------------------------------------- #
-     usp_sparse_attn_wrap_forward = simple_wrapper(wan_usp_sparse_attention_wrapper()(wan_xfuser.usp_attn_forward))
-     wan_xfuser.usp_attn_forward = usp_sparse_attn_wrap_forward
-     usp_attn_forward = usp_sparse_attn_wrap_forward
-     print("Import PAI VAE Turbo and Sparse Attention")
- 
-     # --------------------------------------------------------------- #
-     #   Fast Rope Kernel
-     # --------------------------------------------------------------- #
-     import types
- 
-     import torch
-     from paifuser.ops import (ENABLE_KERNEL, usp_fast_rope_apply_qk,
-                               usp_rope_apply_real_qk)
- 
-     def deepcopy_function(f):
-         return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__, closure=f.__closure__)
- 
-     local_rope_apply_qk = deepcopy_function(wan_xfuser.rope_apply_qk)
- 
-     if ENABLE_KERNEL:
-         def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
-             if torch.is_grad_enabled():
-                 return local_rope_apply_qk(q, k, grid_sizes, freqs)
-             else:
-                 return usp_fast_rope_apply_qk(q, k, grid_sizes, freqs)
- 
-     else:
-         def adaptive_fast_usp_rope_apply_qk(q, k, grid_sizes, freqs):
-             return usp_rope_apply_real_qk(q, k, grid_sizes, freqs)
- 
-     wan_xfuser.rope_apply_qk = adaptive_fast_usp_rope_apply_qk
-     rope_apply_qk = adaptive_fast_usp_rope_apply_qk
-     print("Import PAI Fast rope")
 
videox_fun/dist/cogvideox_xfuser.py DELETED
@@ -1,93 +0,0 @@
- from typing import Optional
- 
- import torch
- import torch.nn.functional as F
- from diffusers.models.attention import Attention
- from diffusers.models.embeddings import apply_rotary_emb
- 
- from .fuser import (get_sequence_parallel_rank,
-                     get_sequence_parallel_world_size, get_sp_group,
-                     init_distributed_environment, initialize_model_parallel,
-                     xFuserLongContextAttention)
- 
- class CogVideoXMultiGPUsAttnProcessor2_0:
-     r"""
-     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-     query and key vectors, but does not include spatial normalization.
-     """
- 
-     def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- 
-     def __call__(
-         self,
-         attn: Attention,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         text_seq_length = encoder_hidden_states.size(1)
- 
-         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
- 
-         batch_size, sequence_length, _ = (
-             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-         )
- 
-         if attention_mask is not None:
-             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
- 
-         query = attn.to_q(hidden_states)
-         key = attn.to_k(hidden_states)
-         value = attn.to_v(hidden_states)
- 
-         inner_dim = key.shape[-1]
-         head_dim = inner_dim // attn.heads
- 
-         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- 
-         if attn.norm_q is not None:
-             query = attn.norm_q(query)
-         if attn.norm_k is not None:
-             key = attn.norm_k(key)
- 
-         # Apply RoPE if needed
-         if image_rotary_emb is not None:
-             query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-             if not attn.is_cross_attention:
-                 key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
- 
-         img_q = query[:, :, text_seq_length:].transpose(1, 2)
-         txt_q = query[:, :, :text_seq_length].transpose(1, 2)
-         img_k = key[:, :, text_seq_length:].transpose(1, 2)
-         txt_k = key[:, :, :text_seq_length].transpose(1, 2)
-         img_v = value[:, :, text_seq_length:].transpose(1, 2)
-         txt_v = value[:, :, :text_seq_length].transpose(1, 2)
- 
-         hidden_states = xFuserLongContextAttention()(
-             None,
-             img_q, img_k, img_v, dropout_p=0.0, causal=False,
-             joint_tensor_query=txt_q,
-             joint_tensor_key=txt_k,
-             joint_tensor_value=txt_v,
-             joint_strategy='front',
-         )
- 
-         hidden_states = hidden_states.flatten(2, 3)
-         hidden_states = hidden_states.to(query.dtype)
- 
-         # linear proj
-         hidden_states = attn.to_out[0](hidden_states)
-         # dropout
-         hidden_states = attn.to_out[1](hidden_states)
- 
-         encoder_hidden_states, hidden_states = hidden_states.split(
-             [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-         )
-         return hidden_states, encoder_hidden_states
 
videox_fun/dist/flux2_xfuser.py DELETED
@@ -1,194 +0,0 @@
- from typing import Optional, Tuple, Union
- 
- import torch
- import torch.nn.functional as F
- from diffusers.models.attention_processor import Attention
- 
- from .fuser import xFuserLongContextAttention
- 
- 
- def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
-     query = attn.to_q(hidden_states)
-     key = attn.to_k(hidden_states)
-     value = attn.to_v(hidden_states)
- 
-     encoder_query = encoder_key = encoder_value = None
-     if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-         encoder_query = attn.add_q_proj(encoder_hidden_states)
-         encoder_key = attn.add_k_proj(encoder_hidden_states)
-         encoder_value = attn.add_v_proj(encoder_hidden_states)
- 
-     return query, key, value, encoder_query, encoder_key, encoder_value
- 
- 
- def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
-     return _get_projections(attn, hidden_states, encoder_hidden_states)
- 
- 
- def apply_rotary_emb(
-     x: torch.Tensor,
-     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-     use_real: bool = True,
-     use_real_unbind_dim: int = -1,
-     sequence_dim: int = 2,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """
-     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
-     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
-     reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
-     tensors contain rotary embeddings and are returned as real tensors.
- 
-     Args:
-         x (`torch.Tensor`):
-             Query or key tensor to apply rotary embeddings. [B, H, S, D]
-         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
- 
-     Returns:
-         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-     """
-     if use_real:
-         cos, sin = freqs_cis  # [S, D]
-         if sequence_dim == 2:
-             cos = cos[None, None, :, :]
-             sin = sin[None, None, :, :]
-         elif sequence_dim == 1:
-             cos = cos[None, :, None, :]
-             sin = sin[None, :, None, :]
-         else:
-             raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
- 
-         cos, sin = cos.to(x.device), sin.to(x.device)
- 
-         if use_real_unbind_dim == -1:
-             # Used for flux, cogvideox, hunyuan-dit
-             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
-             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-         elif use_real_unbind_dim == -2:
-             # Used for Stable Audio, OmniGen, CogView4 and Cosmos
-             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
-             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
-         else:
-             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
- 
-         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
- 
-         return out
-     else:
-         # used for lumina
-         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-         freqs_cis = freqs_cis.unsqueeze(2)
-         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
- 
-         return x_out.type_as(x)
- 
- 
- class Flux2MultiGPUsAttnProcessor2_0:
-     r"""
-     Processor for implementing scaled dot-product attention for the Flux2 model. It applies a rotary embedding on
-     query and key vectors, but does not include spatial normalization.
-     """
- 
-     def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("Flux2MultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- 
-     def __call__(
-         self,
-         attn: "FluxAttention",
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-         text_seq_len: int = None,
-     ) -> torch.FloatTensor:
-         # Determine which type of attention we're processing
-         is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
- 
-         if is_parallel_self_attn:
-             # Parallel in (QKV + MLP in) projection
-             hidden_states = attn.to_qkv_mlp_proj(hidden_states)
-             qkv, mlp_hidden_states = torch.split(
-                 hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
-             )
- 
-             # Handle the attention logic
-             query, key, value = qkv.chunk(3, dim=-1)
- 
-         else:
-             query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-                 attn, hidden_states, encoder_hidden_states
-             )
- 
-         # Common processing for query, key, value
-         query = query.unflatten(-1, (attn.heads, -1))
-         key = key.unflatten(-1, (attn.heads, -1))
-         value = value.unflatten(-1, (attn.heads, -1))
- 
-         query = attn.norm_q(query)
-         key = attn.norm_k(key)
- 
-         # Handle encoder projections (only for standard attention)
-         if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
-             encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
-             encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
-             encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
- 
-             encoder_query = attn.norm_added_q(encoder_query)
-             encoder_key = attn.norm_added_k(encoder_key)
- 
-             query = torch.cat([encoder_query, query], dim=1)
-             key = torch.cat([encoder_key, key], dim=1)
-             value = torch.cat([encoder_value, value], dim=1)
- 
-         # Apply rotary embeddings
-         if image_rotary_emb is not None:
-             query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
- 
-         if not is_parallel_self_attn and attn.added_kv_proj_dim is not None and text_seq_len is None:
-             text_seq_len = encoder_query.shape[1]
- 
-         txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
-         img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
- 
-         half_dtypes = (torch.float16, torch.bfloat16)
-         def half(x):
-             return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
- 
-         hidden_states = xFuserLongContextAttention()(
-             None,
-             half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
-             joint_tensor_query=half(txt_query) if txt_query is not None else None,
-             joint_tensor_key=half(txt_key) if txt_key is not None else None,
-             joint_tensor_value=half(txt_value) if txt_value is not None else None,
-             joint_strategy='front',
-         )
-         hidden_states = hidden_states.flatten(2, 3)
-         hidden_states = hidden_states.to(query.dtype)
- 
-         if is_parallel_self_attn:
-             # Handle the feedforward (FF) logic
-             mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
- 
-             # Concatenate and parallel output projection
-             hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
-             hidden_states = attn.to_out(hidden_states)
- 
-             return hidden_states
- 
-         else:
-             # Split encoder and latent hidden states if encoder was used
-             if encoder_hidden_states is not None:
-                 encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                     [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
-                 )
-                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- 
-             # Project output
-             hidden_states = attn.to_out[0](hidden_states)
-             hidden_states = attn.to_out[1](hidden_states)
- 
-             if encoder_hidden_states is not None:
-                 return hidden_states, encoder_hidden_states
-             else:
-                 return hidden_states
 
videox_fun/dist/flux_xfuser.py DELETED
@@ -1,165 +0,0 @@
- from typing import Optional, Tuple, Union
- 
- import torch
- import torch.nn.functional as F
- from diffusers.models.attention_processor import Attention
- 
- from .fuser import xFuserLongContextAttention
- 
- 
- def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
-     query = attn.to_q(hidden_states)
-     key = attn.to_k(hidden_states)
-     value = attn.to_v(hidden_states)
- 
-     encoder_query = encoder_key = encoder_value = None
-     if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
-         encoder_query = attn.add_q_proj(encoder_hidden_states)
-         encoder_key = attn.add_k_proj(encoder_hidden_states)
-         encoder_value = attn.add_v_proj(encoder_hidden_states)
- 
-     return query, key, value, encoder_query, encoder_key, encoder_value
- 
- 
- def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
-     return _get_projections(attn, hidden_states, encoder_hidden_states)
- 
- 
- def apply_rotary_emb(
-     x: torch.Tensor,
-     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-     use_real: bool = True,
-     use_real_unbind_dim: int = -1,
-     sequence_dim: int = 2,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """
-     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
-     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
-     reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
-     tensors contain rotary embeddings and are returned as real tensors.
- 
-     Args:
-         x (`torch.Tensor`):
-             Query or key tensor to apply rotary embeddings. [B, H, S, D]
-         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
- 
-     Returns:
-         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-     """
-     if use_real:
-         cos, sin = freqs_cis  # [S, D]
-         if sequence_dim == 2:
-             cos = cos[None, None, :, :]
-             sin = sin[None, None, :, :]
-         elif sequence_dim == 1:
-             cos = cos[None, :, None, :]
-             sin = sin[None, :, None, :]
-         else:
-             raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
- 
-         cos, sin = cos.to(x.device), sin.to(x.device)
- 
-         if use_real_unbind_dim == -1:
-             # Used for flux, cogvideox, hunyuan-dit
-             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
-             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-         elif use_real_unbind_dim == -2:
-             # Used for Stable Audio, OmniGen, CogView4 and Cosmos
-             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
-             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
-         else:
-             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
- 
-         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
- 
-         return out
-     else:
-         # used for lumina
-         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-         freqs_cis = freqs_cis.unsqueeze(2)
-         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
- 
-         return x_out.type_as(x)
- 
- 
- class FluxMultiGPUsAttnProcessor2_0:
-     r"""
-     Processor for implementing scaled dot-product attention for the Flux model. It applies a rotary embedding on
-     query and key vectors, but does not include spatial normalization.
-     """
- 
-     def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("FluxMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- 
-     def __call__(
-         self,
-         attn: "FluxAttention",
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-         text_seq_len: int = None,
-     ) -> torch.FloatTensor:
-         query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
-             attn, hidden_states, encoder_hidden_states
-         )
- 
-         query = query.unflatten(-1, (attn.heads, -1))
-         key = key.unflatten(-1, (attn.heads, -1))
-         value = value.unflatten(-1, (attn.heads, -1))
- 
-         query = attn.norm_q(query)
-         key = attn.norm_k(key)
- 
-         if attn.added_kv_proj_dim is not None:
-             encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
-             encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
-             encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
- 
-             encoder_query = attn.norm_added_q(encoder_query)
-             encoder_key = attn.norm_added_k(encoder_key)
- 
-             query = torch.cat([encoder_query, query], dim=1)
-             key = torch.cat([encoder_key, key], dim=1)
-             value = torch.cat([encoder_value, value], dim=1)
- 
-         # Apply rotary embeddings
-         if image_rotary_emb is not None:
-             query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
-             key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
- 
-         if attn.added_kv_proj_dim is not None and text_seq_len is None:
-             text_seq_len = encoder_query.shape[1]
- 
-         txt_query, txt_key, txt_value = query[:, :text_seq_len], key[:, :text_seq_len], value[:, :text_seq_len]
-         img_query, img_key, img_value = query[:, text_seq_len:], key[:, text_seq_len:], value[:, text_seq_len:]
- 
-         half_dtypes = (torch.float16, torch.bfloat16)
-         def half(x):
-             return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
- 
-         hidden_states = xFuserLongContextAttention()(
-             None,
-             half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
-             joint_tensor_query=half(txt_query) if txt_query is not None else None,
-             joint_tensor_key=half(txt_key) if txt_key is not None else None,
-             joint_tensor_value=half(txt_value) if txt_value is not None else None,
-             joint_strategy='front',
-         )
- 
-         # Reshape back
-         hidden_states = hidden_states.flatten(2, 3)
-         hidden_states = hidden_states.to(img_query.dtype)
- 
-         if encoder_hidden_states is not None:
-             encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
-                 [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
-             )
-             hidden_states = attn.to_out[0](hidden_states)
-             hidden_states = attn.to_out[1](hidden_states)
-             encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- 
-             return hidden_states, encoder_hidden_states
-         else:
-             return hidden_states
 
videox_fun/dist/fsdp.py DELETED
@@ -1,44 +0,0 @@
- # Copied from https://github.com/Wan-Video/Wan2.1/blob/main/wan/distributed/fsdp.py
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
- import gc
- from functools import partial
- 
- import torch
- from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
- from torch.distributed.fsdp import MixedPrecision, ShardingStrategy
- from torch.distributed.fsdp.wrap import lambda_auto_wrap_policy
- from torch.distributed.utils import _free_storage
- 
- 
- def shard_model(
-     model,
-     device_id,
-     param_dtype=torch.bfloat16,
-     reduce_dtype=torch.float32,
-     buffer_dtype=torch.float32,
-     process_group=None,
-     sharding_strategy=ShardingStrategy.FULL_SHARD,
-     sync_module_states=True,
-     module_to_wrapper=None,
- ):
-     model = FSDP(
-         module=model,
-         process_group=process_group,
-         sharding_strategy=sharding_strategy,
-         auto_wrap_policy=partial(
-             lambda_auto_wrap_policy, lambda_fn=lambda m: m in (model.blocks if module_to_wrapper is None else module_to_wrapper)),
-         mixed_precision=MixedPrecision(
-             param_dtype=param_dtype,
-             reduce_dtype=reduce_dtype,
-             buffer_dtype=buffer_dtype),
-         device_id=device_id,
-         sync_module_states=sync_module_states)
-     return model
- 
- def free_model(model):
-     for m in model.modules():
-         if isinstance(m, FSDP):
-             _free_storage(m._handle.flat_param.data)
-     del model
-     gc.collect()
-     torch.cuda.empty_cache()
 
videox_fun/dist/fuser.py DELETED
@@ -1,87 +0,0 @@
- import importlib.util
- 
- import torch
- import torch.distributed as dist
- 
- try:
-     # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
-     if importlib.util.find_spec("paifuser") is not None:
-         import paifuser
-         from paifuser.xfuser.core.distributed import (
-             get_sequence_parallel_rank, get_sequence_parallel_world_size,
-             get_sp_group, get_world_group, init_distributed_environment,
-             initialize_model_parallel, model_parallel_is_initialized)
-         from paifuser.xfuser.core.long_ctx_attention import \
-             xFuserLongContextAttention
-         print("Import PAI DiT Turbo")
-     else:
-         import xfuser
-         from xfuser.core.distributed import (get_sequence_parallel_rank,
-                                              get_sequence_parallel_world_size,
-                                              get_sp_group, get_world_group,
-                                              init_distributed_environment,
-                                              initialize_model_parallel,
-                                              model_parallel_is_initialized)
-         from xfuser.core.long_ctx_attention import xFuserLongContextAttention
-         print("Xfuser import successful")
- except Exception as ex:
-     get_sequence_parallel_world_size = None
-     get_sequence_parallel_rank = None
-     xFuserLongContextAttention = None
-     get_sp_group = None
-     get_world_group = None
-     init_distributed_environment = None
-     initialize_model_parallel = None
- 
- def set_multi_gpus_devices(ulysses_degree, ring_degree, classifier_free_guidance_degree=1):
-     if ulysses_degree > 1 or ring_degree > 1 or classifier_free_guidance_degree > 1:
-         if get_sp_group is None:
-             raise RuntimeError("xfuser is not installed.")
-         dist.init_process_group("nccl")
-         print('parallel inference enabled: ulysses_degree=%d ring_degree=%d classifier_free_guidance_degree=%d rank=%d world_size=%d' % (
-             ulysses_degree, ring_degree, classifier_free_guidance_degree, dist.get_rank(),
-             dist.get_world_size()))
-         assert dist.get_world_size() == ring_degree * ulysses_degree * classifier_free_guidance_degree, \
-             "number of GPUs(%d) should be equal to ring_degree * ulysses_degree * classifier_free_guidance_degree." % dist.get_world_size()
-         init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
-         initialize_model_parallel(sequence_parallel_degree=ring_degree * ulysses_degree,
-                                   classifier_free_guidance_degree=classifier_free_guidance_degree,
-                                   ring_degree=ring_degree,
-                                   ulysses_degree=ulysses_degree)
-         # device = torch.device("cuda:%d" % dist.get_rank())
-         device = torch.device(f"cuda:{get_world_group().local_rank}")
-         print('rank=%d device=%s' % (get_world_group().rank, str(device)))
-     else:
-         device = "cuda"
-     return device
- 
- def sequence_parallel_chunk(x, dim=1):
-     if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
-         return x
- 
-     sp_world_size = get_sequence_parallel_world_size()
-     if sp_world_size <= 1:
-         return x
- 
-     sp_rank = get_sequence_parallel_rank()
-     sp_group = get_sp_group()
- 
-     if x.size(1) % sp_world_size != 0:
-         raise ValueError(f"Dim 1 of x ({x.size(1)}) not divisible by SP world size ({sp_world_size})")
- 
-     chunks = torch.chunk(x, sp_world_size, dim=1)
-     x = chunks[sp_rank]
- 
-     return x
- 
- def sequence_parallel_all_gather(x, dim=1):
-     if get_sequence_parallel_world_size is None or not model_parallel_is_initialized():
-         return x
- 
-     sp_world_size = get_sequence_parallel_world_size()
-     if sp_world_size <= 1:
-         return x  # No gathering needed
- 
-     sp_group = get_sp_group()
-     gathered_x = sp_group.all_gather(x, dim=dim)
-     return gathered_x
 
videox_fun/dist/hunyuanvideo_xfuser.py DELETED
@@ -1,166 +0,0 @@
- from typing import Optional
- 
- import torch
- import torch.nn.functional as F
- from diffusers.models.attention import Attention
- from diffusers.models.embeddings import apply_rotary_emb
- 
- from .fuser import (get_sequence_parallel_rank,
-                     get_sequence_parallel_world_size, get_sp_group,
-                     init_distributed_environment, initialize_model_parallel,
-                     xFuserLongContextAttention)
- 
- def extract_seqlens_from_mask(attn_mask, text_seq_length):
-     if attn_mask is None:
-         return None
- 
-     if len(attn_mask.shape) == 4:
-         bs, _, _, seq_len = attn_mask.shape
- 
-         if attn_mask.dtype == torch.bool:
-             valid_mask = attn_mask.squeeze(1).squeeze(1)
-         else:
-             valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
-     elif len(attn_mask.shape) == 3:
-         raise ValueError(
-             "attn_mask should be 2D or 4D tensor, but got {}".format(
-                 attn_mask.shape))
- 
-     seqlens = valid_mask[:, -text_seq_length:].sum(dim=1)
-     return seqlens
- 
- class HunyuanVideoMultiGPUsAttnProcessor2_0:
-     r"""
-     Processor for implementing scaled dot-product attention for the HunyuanVideo model. It applies a rotary embedding on
-     query and key vectors, but does not include spatial normalization.
-     """
- 
-     def __init__(self):
-         if xFuserLongContextAttention is not None:
-             try:
-                 self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
-             except Exception:
-                 self.hybrid_seq_parallel_attn = None
-         else:
-             self.hybrid_seq_parallel_attn = None
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("HunyuanVideoMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- 
-     def __call__(
-         self,
-         attn: Attention,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         if attn.add_q_proj is None and encoder_hidden_states is not None:
-             hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
- 
-         # 1. QKV projections
-         query = attn.to_q(hidden_states)
-         key = attn.to_k(hidden_states)
-         value = attn.to_v(hidden_states)
- 
-         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- 
-         # 2. QK normalization
-         if attn.norm_q is not None:
-             query = attn.norm_q(query)
-         if attn.norm_k is not None:
-             key = attn.norm_k(key)
- 
-         # 3. Rotational positional embeddings applied to latent stream
-         if image_rotary_emb is not None:
-             if attn.add_q_proj is None and encoder_hidden_states is not None:
-                 query = torch.cat(
-                     [
-                         apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                         query[:, :, -encoder_hidden_states.shape[1] :],
-                     ],
-                     dim=2,
-                 )
-                 key = torch.cat(
-                     [
-                         apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
-                         key[:, :, -encoder_hidden_states.shape[1] :],
-                     ],
-                     dim=2,
-                 )
-             else:
-                 query = apply_rotary_emb(query, image_rotary_emb)
-                 key = apply_rotary_emb(key, image_rotary_emb)
- 
-         # 4. Encoder condition QKV projection and normalization
-         if attn.add_q_proj is not None and encoder_hidden_states is not None:
-             encoder_query = attn.add_q_proj(encoder_hidden_states)
-             encoder_key = attn.add_k_proj(encoder_hidden_states)
-             encoder_value = attn.add_v_proj(encoder_hidden_states)
- 
-             encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-             encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
-             encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
- 
-             if attn.norm_added_q is not None:
-                 encoder_query = attn.norm_added_q(encoder_query)
-             if attn.norm_added_k is not None:
-                 encoder_key = attn.norm_added_k(encoder_key)
- 
-             query = torch.cat([query, encoder_query], dim=2)
-             key = torch.cat([key, encoder_key], dim=2)
-             value = torch.cat([value, encoder_value], dim=2)
- 
-         # 5. Attention
-         if encoder_hidden_states is not None:
-             text_seq_length = encoder_hidden_states.size(1)
- 
-             q_lens = k_lens = extract_seqlens_from_mask(attention_mask, text_seq_length)
- 
-             img_q = query[:, :, :-text_seq_length].transpose(1, 2)
-             txt_q = query[:, :, -text_seq_length:].transpose(1, 2)
-             img_k = key[:, :, :-text_seq_length].transpose(1, 2)
-             txt_k = key[:, :, -text_seq_length:].transpose(1, 2)
-             img_v = value[:, :, :-text_seq_length].transpose(1, 2)
-             txt_v = value[:, :, -text_seq_length:].transpose(1, 2)
- 
-             hidden_states = torch.zeros_like(query.transpose(1, 2))
-             local_q_length = img_q.size()[1]
-             for i in range(len(q_lens)):
-                 hidden_states[i][:local_q_length + q_lens[i]] = self.hybrid_seq_parallel_attn(
-                     None,
-                     img_q[i].unsqueeze(0), img_k[i].unsqueeze(0), img_v[i].unsqueeze(0), dropout_p=0.0, causal=False,
-                     joint_tensor_query=txt_q[i][:q_lens[i]].unsqueeze(0),
-                     joint_tensor_key=txt_k[i][:q_lens[i]].unsqueeze(0),
-                     joint_tensor_value=txt_v[i][:q_lens[i]].unsqueeze(0),
-                     joint_strategy='rear',
-                 )
-         else:
-             query = query.transpose(1, 2)
-             key = key.transpose(1, 2)
-             value = value.transpose(1, 2)
-             hidden_states = self.hybrid_seq_parallel_attn(
-                 None,
-                 query, key, value, dropout_p=0.0, causal=False
-             )
- 
-         hidden_states = hidden_states.flatten(2, 3)
-         hidden_states = hidden_states.to(query.dtype)
- 
-         # 6. Output projection
-         if encoder_hidden_states is not None:
-             hidden_states, encoder_hidden_states = (
-                 hidden_states[:, : -encoder_hidden_states.shape[1]],
-                 hidden_states[:, -encoder_hidden_states.shape[1] :],
-             )
- 
-             if getattr(attn, "to_out", None) is not None:
-                 hidden_states = attn.to_out[0](hidden_states)
-                 hidden_states = attn.to_out[1](hidden_states)
- 
-             if getattr(attn, "to_add_out", None) is not None:
-                 encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
- 
-         return hidden_states, encoder_hidden_states
 
videox_fun/dist/qwen_xfuser.py DELETED
@@ -1,176 +0,0 @@
- import functools
- import glob
- import json
- import math
- import os
- import types
- import warnings
- from typing import Any, Dict, List, Optional, Tuple, Union
- 
- import numpy as np
- import torch
- import torch.cuda.amp as amp
- import torch.nn as nn
- import torch.nn.functional as F
- from diffusers.configuration_utils import ConfigMixin, register_to_config
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
- from diffusers.models.attention import FeedForward
- from diffusers.models.attention_processor import Attention
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
- from diffusers.models.modeling_utils import ModelMixin
- from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
-                              scale_lora_layers, unscale_lora_layers)
- from diffusers.utils.torch_utils import maybe_allow_in_graph
- from torch import nn
- from .fuser import (get_sequence_parallel_rank,
-                     get_sequence_parallel_world_size, get_sp_group,
-                     init_distributed_environment, initialize_model_parallel,
-                     xFuserLongContextAttention)
- 
- def apply_rotary_emb_qwen(
-     x: torch.Tensor,
-     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
-     use_real: bool = True,
-     use_real_unbind_dim: int = -1,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
-     """
-     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
-     to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
-     reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
-     tensors contain rotary embeddings and are returned as real tensors.
- 
-     Args:
-         x (`torch.Tensor`):
-             Query or key tensor to apply rotary embeddings. [B, S, H, D]
-         freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
- 
-     Returns:
-         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
-     """
-     if use_real:
-         cos, sin = freqs_cis  # [S, D]
-         cos = cos[None, None]
-         sin = sin[None, None]
-         cos, sin = cos.to(x.device), sin.to(x.device)
- 
-         if use_real_unbind_dim == -1:
-             # Used for flux, cogvideox, hunyuan-dit
-             x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
-             x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
-         elif use_real_unbind_dim == -2:
-             # Used for Stable Audio, OmniGen, CogView4 and Cosmos
-             x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
-             x_rotated = torch.cat([-x_imag, x_real], dim=-1)
-         else:
-             raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
- 
-         out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
- 
-         return out
-     else:
-         x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-         freqs_cis = freqs_cis.unsqueeze(1)
-         x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
- 
-         return x_out.type_as(x)
- 
- 
- class QwenImageMultiGPUsAttnProcessor2_0:
-     r"""
-     Processor for implementing scaled dot-product attention for the QwenImage model. It applies a rotary embedding on
-     query and key vectors, but does not include spatial normalization.
-     """
- 
-     def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
-             raise ImportError("QwenImageMultiGPUsAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
- 
-     def __call__(
-         self,
-         attn: Attention,
-         hidden_states: torch.FloatTensor,  # Image stream
-         encoder_hidden_states: torch.FloatTensor = None,  # Text stream
-         encoder_hidden_states_mask: torch.FloatTensor = None,
-         attention_mask: Optional[torch.FloatTensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-     ) -> torch.FloatTensor:
-         if encoder_hidden_states is None:
-             raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
- 
-         seq_txt = encoder_hidden_states.shape[1]
- 
-         # Compute QKV for image stream (sample projections)
-         img_query = attn.to_q(hidden_states)
-         img_key = attn.to_k(hidden_states)
-         img_value = attn.to_v(hidden_states)
- 
-         # Compute QKV for text stream (context projections)
-         txt_query = attn.add_q_proj(encoder_hidden_states)
-         txt_key = attn.add_k_proj(encoder_hidden_states)
-         txt_value = attn.add_v_proj(encoder_hidden_states)
- 
-         # Reshape for multi-head attention
-         img_query = img_query.unflatten(-1, (attn.heads, -1))
-         img_key = img_key.unflatten(-1, (attn.heads, -1))
-         img_value = img_value.unflatten(-1, (attn.heads, -1))
- 
-         txt_query = txt_query.unflatten(-1, (attn.heads, -1))
-         txt_key = txt_key.unflatten(-1, (attn.heads, -1))
-         txt_value = txt_value.unflatten(-1, (attn.heads, -1))
- 
-         # Apply QK normalization
-         if attn.norm_q is not None:
-             img_query = attn.norm_q(img_query)
-         if attn.norm_k is not None:
-             img_key = attn.norm_k(img_key)
-         if attn.norm_added_q is not None:
-             txt_query = attn.norm_added_q(txt_query)
-         if attn.norm_added_k is not None:
-             txt_key = attn.norm_added_k(txt_key)
- 
-         # Apply RoPE
-         if image_rotary_emb is not None:
-             img_freqs, txt_freqs = image_rotary_emb
-             img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
-             img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
-             txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
-             txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
- 
-         # Concatenate for joint attention
-         # Order: [text, image]
-         # joint_query = torch.cat([txt_query, img_query], dim=1)
-         # joint_key = torch.cat([txt_key, img_key], dim=1)
-         # joint_value = torch.cat([txt_value, img_value], dim=1)
- 
-         half_dtypes = (torch.float16, torch.bfloat16)
-         def half(x):
-             return x if x.dtype in half_dtypes else x.to(torch.bfloat16)
- 
-         joint_hidden_states = xFuserLongContextAttention()(
-             None,
-             half(img_query), half(img_key), half(img_value), dropout_p=0.0, causal=False,
-             joint_tensor_query=half(txt_query),
-             joint_tensor_key=half(txt_key),
-             joint_tensor_value=half(txt_value),
-             joint_strategy='front',
-         )
- 
-         # Reshape back
-         joint_hidden_states = joint_hidden_states.flatten(2, 3)
-         joint_hidden_states = joint_hidden_states.to(img_query.dtype)
- 
-         # Split attention outputs back
-         txt_attn_output = joint_hidden_states[:, :seq_txt, :]  # Text part
-         img_attn_output = joint_hidden_states[:, seq_txt:, :]  # Image part
- 
-         # Apply output projections
-         img_attn_output = attn.to_out[0](img_attn_output)
-         if len(attn.to_out) > 1:
-             img_attn_output = attn.to_out[1](img_attn_output)  # dropout
- 
-         txt_attn_output = attn.to_add_out(txt_attn_output)
- 
-         return img_attn_output, txt_attn_output
 
videox_fun/dist/wan_xfuser.py DELETED
@@ -1,180 +0,0 @@
1
- import torch
2
- import torch.cuda.amp as amp
3
-
4
- from .fuser import (get_sequence_parallel_rank,
5
- get_sequence_parallel_world_size, get_sp_group,
6
- init_distributed_environment, initialize_model_parallel,
7
- xFuserLongContextAttention)
8
-
9
-
10
- def pad_freqs(original_tensor, target_len):
11
- seq_len, s1, s2 = original_tensor.shape
12
- pad_size = target_len - seq_len
13
- padding_tensor = torch.ones(
14
- pad_size,
15
- s1,
16
- s2,
17
- dtype=original_tensor.dtype,
18
- device=original_tensor.device)
19
- padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
20
- return padded_tensor
21
-
22
- @amp.autocast(enabled=False)
23
- @torch.compiler.disable()
24
- def rope_apply(x, grid_sizes, freqs):
25
- """
26
- x: [B, L, N, C].
27
- grid_sizes: [B, 3].
28
- freqs: [M, C // 2].
29
- """
30
- s, n, c = x.size(1), x.size(2), x.size(3) // 2
31
- # split freqs
32
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
33
-
34
- # loop over samples
35
- output = []
36
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
37
- seq_len = f * h * w
38
-
39
- # precompute multipliers
40
- x_i = torch.view_as_complex(x[i, :s].to(torch.float32).reshape(
41
- s, n, -1, 2))
42
- freqs_i = torch.cat([
43
- freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
44
- freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
45
- freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
46
- ],
47
- dim=-1).reshape(seq_len, 1, -1)
48
-
49
- # apply rotary embedding
50
- sp_size = get_sequence_parallel_world_size()
51
- sp_rank = get_sequence_parallel_rank()
52
- freqs_i = pad_freqs(freqs_i, s * sp_size)
53
- s_per_rank = s
54
- freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
55
- s_per_rank), :, :]
56
- x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
57
- x_i = torch.cat([x_i, x[i, s:]])
58
-
59
- # append to collection
60
- output.append(x_i)
61
- return torch.stack(output)
62
-
63
- def rope_apply_qk(q, k, grid_sizes, freqs):
64
- q = rope_apply(q, grid_sizes, freqs)
65
- k = rope_apply(k, grid_sizes, freqs)
66
- return q, k
67
-
68
- def usp_attn_forward(self,
69
- x,
70
- seq_lens,
71
- grid_sizes,
72
- freqs,
73
- dtype=torch.bfloat16,
74
- t=0):
75
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
76
- half_dtypes = (torch.float16, torch.bfloat16)
77
-
78
- def half(x):
79
- return x if x.dtype in half_dtypes else x.to(dtype)
80
-
81
- # query, key, value function
82
- def qkv_fn(x):
83
- q = self.norm_q(self.q(x)).view(b, s, n, d)
84
- k = self.norm_k(self.k(x)).view(b, s, n, d)
85
- v = self.v(x).view(b, s, n, d)
86
- return q, k, v
87
-
88
- q, k, v = qkv_fn(x)
89
- q, k = rope_apply_qk(q, k, grid_sizes, freqs)
90
-
91
- # TODO: We should use unpaded q,k,v for attention.
92
- # k_lens = seq_lens // get_sequence_parallel_world_size()
93
- # if k_lens is not None:
94
- # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
95
- # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
96
- # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
97
-
98
- x = xFuserLongContextAttention()(
99
- None,
100
- query=half(q),
101
- key=half(k),
102
- value=half(v),
103
- window_size=self.window_size)
104
-
105
- # TODO: padding after attention.
106
- # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
107
-
108
- # output
109
- x = x.flatten(2)
110
- x = self.o(x)
111
- return x
112
-
113
- @amp.autocast(enabled=False)
114
- @torch.compiler.disable()
115
- def s2v_rope_apply(x, grid_sizes, freqs):
116
- s, n, c = x.size(1), x.size(2), x.size(3) // 2
117
- # loop over samples
118
- output = []
119
- for i, _ in enumerate(x):
120
- s = x.size(1)
121
- # precompute multipliers
122
- x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(
123
- s, n, -1, 2))
124
- freqs_i = freqs[i]
125
- freqs_i_rank = pad_freqs(freqs_i, s)
126
- x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
127
- x_i = torch.cat([x_i, x[i, s:]])
128
- # append to collection
129
- output.append(x_i)
130
- return torch.stack(output).float()
131
-
132
- def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
133
- q = s2v_rope_apply(q, grid_sizes, freqs)
134
- k = s2v_rope_apply(k, grid_sizes, freqs)
135
- return q, k
136
-
137
- def usp_attn_s2v_forward(self,
138
- x,
139
- seq_lens,
140
- grid_sizes,
141
- freqs,
142
- dtype=torch.bfloat16,
143
- t=0):
144
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
145
- half_dtypes = (torch.float16, torch.bfloat16)
146
-
147
- def half(x):
148
- return x if x.dtype in half_dtypes else x.to(dtype)
149
-
150
- # query, key, value function
151
- def qkv_fn(x):
152
- q = self.norm_q(self.q(x)).view(b, s, n, d)
153
- k = self.norm_k(self.k(x)).view(b, s, n, d)
154
- v = self.v(x).view(b, s, n, d)
155
- return q, k, v
156
-
157
- q, k, v = qkv_fn(x)
158
- q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
159
-
160
- # TODO: We should use unpadded q, k, v for attention.
161
- # k_lens = seq_lens // get_sequence_parallel_world_size()
162
- # if k_lens is not None:
163
- # q = torch.cat([u[:l] for u, l in zip(q, k_lens)]).unsqueeze(0)
164
- # k = torch.cat([u[:l] for u, l in zip(k, k_lens)]).unsqueeze(0)
165
- # v = torch.cat([u[:l] for u, l in zip(v, k_lens)]).unsqueeze(0)
166
-
167
- x = xFuserLongContextAttention()(
168
- None,
169
- query=half(q),
170
- key=half(k),
171
- value=half(v),
172
- window_size=self.window_size)
173
-
174
- # TODO: padding after attention.
175
- # x = torch.cat([x, x.new_zeros(b, s - x.size(1), n, d)], dim=1)
176
-
177
- # output
178
- x = x.flatten(2)
179
- x = self.o(x)
180
- return x
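Before the next deleted file, a minimal standalone sketch of the complex-number rotary-embedding trick that rope_apply above relies on may help. Everything in it (the toy_rope name, the toy shapes, the way the frequencies are built) is illustrative and assumed, not part of the removed module:

import torch

def toy_rope(x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # x: [seq, heads, dim] with even dim; freqs: complex [seq, 1, dim // 2], unit magnitude
    x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # pair channels -> complex
    return torch.view_as_real(x_c * freqs).flatten(2).type_as(x)          # rotate, back to real

seq, heads, dim = 8, 2, 4
x = torch.randn(seq, heads, dim)
angles = torch.outer(torch.arange(seq, dtype=torch.float32), torch.ones(dim // 2))
freqs = torch.polar(torch.ones(seq, dim // 2), angles).unsqueeze(1)       # per-position unit rotations
print(toy_rope(x, freqs).shape)                                           # torch.Size([8, 2, 4])

The deleted rope_apply performs the same multiplication per sample, except that freqs is assembled from separate temporal/height/width tables and padded with pad_freqs so that every sequence-parallel rank receives an equally sized slice.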
videox_fun/dist/z_image_xfuser.py DELETED
@@ -1,85 +0,0 @@
1
- import torch
2
- import torch.cuda.amp as amp
3
- from typing import Optional
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- from diffusers.models.attention import Attention
8
-
9
- from .fuser import (get_sequence_parallel_rank,
10
- get_sequence_parallel_world_size, get_sp_group,
11
- init_distributed_environment, initialize_model_parallel,
12
- xFuserLongContextAttention)
13
-
14
- class ZMultiGPUsSingleStreamAttnProcessor:
15
- """
16
- Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
17
- original Z-ImageAttention module.
18
- """
19
-
20
- _attention_backend = None
21
- _parallel_config = None
22
-
23
- def __init__(self):
24
- if not hasattr(F, "scaled_dot_product_attention"):
25
- raise ImportError(
26
- "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
27
- )
28
-
29
- def __call__(
30
- self,
31
- attn: Attention,
32
- hidden_states: torch.Tensor,
33
- attention_mask: Optional[torch.Tensor] = None,
34
- freqs_cis: Optional[torch.Tensor] = None,
35
- ) -> torch.Tensor:
36
- query = attn.to_q(hidden_states)
37
- key = attn.to_k(hidden_states)
38
- value = attn.to_v(hidden_states)
39
-
40
- query = query.unflatten(-1, (attn.heads, -1))
41
- key = key.unflatten(-1, (attn.heads, -1))
42
- value = value.unflatten(-1, (attn.heads, -1))
43
-
44
- # Apply Norms
45
- if attn.norm_q is not None:
46
- query = attn.norm_q(query)
47
- if attn.norm_k is not None:
48
- key = attn.norm_k(key)
49
-
50
- # Apply RoPE
51
- def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
52
- with torch.amp.autocast("cuda", enabled=False):
53
- x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
54
- freqs_cis = freqs_cis.unsqueeze(2)
55
- x_out = torch.view_as_real(x * freqs_cis).flatten(3)
56
- return x_out.type_as(x_in) # todo
57
-
58
- if freqs_cis is not None:
59
- query = apply_rotary_emb(query, freqs_cis)
60
- key = apply_rotary_emb(key, freqs_cis)
61
-
62
- # Cast to correct dtype
63
- dtype = query.dtype
64
- query, key = query.to(dtype), key.to(dtype)
65
-
66
- # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
67
- if attention_mask is not None and attention_mask.ndim == 2:
68
- attention_mask = attention_mask[:, None, None, :]
69
-
70
- # Compute joint attention
71
- hidden_states = xFuserLongContextAttention()(
72
- query,
73
- key,
74
- value,
75
- )
76
-
77
- # Reshape back
78
- hidden_states = hidden_states.flatten(2, 3)
79
- hidden_states = hidden_states.to(dtype)
80
-
81
- output = attn.to_out[0](hidden_states)
82
- if len(attn.to_out) > 1: # dropout
83
- output = attn.to_out[1](output)
84
-
85
- return output
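Like the Wan attention forward above, the deleted Z-Image processor follows diffusers' attention-processor protocol: project to Q/K/V, apply the optional norms and RoPE, run the (sequence-parallel) attention kernel, then project out. A hedged, single-GPU sketch of that protocol is below; PassthroughProcessor and the toy sizes are made up for illustration, skip RoPE, and defer to scaled_dot_product_attention instead of xFuserLongContextAttention:

import torch
import torch.nn.functional as F
from diffusers.models.attention import Attention

class PassthroughProcessor:
    # Minimal processor with the same call shape as the deleted class: no RoPE, no xFuser.
    def __call__(self, attn: Attention, hidden_states, attention_mask=None, **kwargs):
        q = attn.to_q(hidden_states).unflatten(-1, (attn.heads, -1)).transpose(1, 2)
        k = attn.to_k(hidden_states).unflatten(-1, (attn.heads, -1)).transpose(1, 2)
        v = attn.to_v(hidden_states).unflatten(-1, (attn.heads, -1)).transpose(1, 2)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attention_mask)
        return attn.to_out[0](out.transpose(1, 2).flatten(2))

attn = Attention(query_dim=64, heads=4, dim_head=16)
x = torch.randn(1, 10, 64)
print(PassthroughProcessor()(attn, x).shape)   # torch.Size([1, 10, 64])

In the deleted code the multi-GPU variant is presumably installed via set_attn_processor, in the same way the CogVideoX transformer further below installs its own multi-GPU processor in enable_multi_gpus_inference.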
videox_fun/models/__init__.py DELETED
@@ -1,131 +0,0 @@
1
- import importlib.util
2
-
3
- from diffusers import AutoencoderKL
4
- # from transformers import (AutoProcessor, AutoTokenizer, CLIPImageProcessor,
5
- # CLIPTextModel, CLIPTokenizer,
6
- # CLIPVisionModelWithProjection, LlamaModel,
7
- # LlamaTokenizerFast, LlavaForConditionalGeneration,
8
- # Mistral3ForConditionalGeneration, PixtralProcessor,
9
- # Qwen3ForCausalLM, T5EncoderModel, T5Tokenizer,
10
- # T5TokenizerFast)
11
-
12
- # try:
13
- # from transformers import (Qwen2_5_VLConfig,
14
- # Qwen2_5_VLForConditionalGeneration,
15
- # Qwen2Tokenizer, Qwen2VLProcessor)
16
- # except:
17
- # Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer = None, None
18
- # Qwen2VLProcessor, Qwen2_5_VLConfig = None, None
19
- # print("Your transformers version is too old to load Qwen2_5_VLForConditionalGeneration and Qwen2Tokenizer. If you wish to use QwenImage, please upgrade your transformers package to the latest version.")
20
-
21
- # from .cogvideox_transformer3d import CogVideoXTransformer3DModel
22
- # from .cogvideox_vae import AutoencoderKLCogVideoX
23
- # from .fantasytalking_audio_encoder import FantasyTalkingAudioEncoder
24
- # from .fantasytalking_transformer3d import FantasyTalkingTransformer3DModel
25
- # from .flux2_image_processor import Flux2ImageProcessor
26
- # from .flux2_transformer2d import Flux2Transformer2DModel
27
- # from .flux2_transformer2d_control import Flux2ControlTransformer2DModel
28
- # from .flux2_vae import AutoencoderKLFlux2
29
- # from .flux_transformer2d import FluxTransformer2DModel
30
- # from .hunyuanvideo_transformer3d import HunyuanVideoTransformer3DModel
31
- # from .hunyuanvideo_vae import AutoencoderKLHunyuanVideo
32
- # from .qwenimage_transformer2d import QwenImageTransformer2DModel
33
- # from .qwenimage_vae import AutoencoderKLQwenImage
34
- # from .wan_audio_encoder import WanAudioEncoder
35
- # from .wan_image_encoder import CLIPModel
36
- # from .wan_text_encoder import WanT5EncoderModel
37
- # from .wan_transformer3d import (Wan2_2Transformer3DModel, WanRMSNorm,
38
- # WanSelfAttention, WanTransformer3DModel)
39
- # from .wan_transformer3d_animate import Wan2_2Transformer3DModel_Animate
40
- # from .wan_transformer3d_s2v import Wan2_2Transformer3DModel_S2V
41
- # from .wan_transformer3d_vace import VaceWanTransformer3DModel
42
- # from .wan_vae import AutoencoderKLWan, AutoencoderKLWan_
43
- # from .wan_vae3_8 import AutoencoderKLWan2_2_, AutoencoderKLWan3_8
44
- from .z_image_transformer2d import ZImageTransformer2DModel
45
- from .z_image_transformer2d_control import ZImageControlTransformer2DModel
46
-
47
- # The pai_fuser is an internally developed acceleration package, which can be used on PAI.
48
- # if importlib.util.find_spec("paifuser") is not None:
49
- # # --------------------------------------------------------------- #
50
- # # The simple_wrapper is used to solve the problem
51
- # # about conflicts between cython and torch.compile
52
- # # --------------------------------------------------------------- #
53
- # def simple_wrapper(func):
54
- # def inner(*args, **kwargs):
55
- # return func(*args, **kwargs)
56
- # return inner
57
-
58
- # # --------------------------------------------------------------- #
59
- # # VAE Parallel Kernel
60
- # # --------------------------------------------------------------- #
61
- # from ..dist import parallel_magvit_vae
62
- # AutoencoderKLWan_.decode = simple_wrapper(parallel_magvit_vae(0.4, 8)(AutoencoderKLWan_.decode))
63
- # AutoencoderKLWan2_2_.decode = simple_wrapper(parallel_magvit_vae(0.4, 16)(AutoencoderKLWan2_2_.decode))
64
-
65
- # # --------------------------------------------------------------- #
66
- # # Sparse Attention
67
- # # --------------------------------------------------------------- #
68
- # import torch
69
- # from paifuser.ops import wan_sparse_attention_wrapper
70
-
71
- # WanSelfAttention.forward = simple_wrapper(wan_sparse_attention_wrapper()(WanSelfAttention.forward))
72
- # print("Import Sparse Attention")
73
-
74
- # WanTransformer3DModel.forward = simple_wrapper(WanTransformer3DModel.forward)
75
-
76
- # # --------------------------------------------------------------- #
77
- # # CFG Skip Turbo
78
- # # --------------------------------------------------------------- #
79
- # import os
80
-
81
- # if importlib.util.find_spec("paifuser.accelerator") is not None:
82
- # from paifuser.accelerator import (cfg_skip_turbo, disable_cfg_skip,
83
- # enable_cfg_skip, share_cfg_skip)
84
- # else:
85
- # from paifuser import (cfg_skip_turbo, disable_cfg_skip,
86
- # enable_cfg_skip, share_cfg_skip)
87
-
88
- # WanTransformer3DModel.enable_cfg_skip = enable_cfg_skip()(WanTransformer3DModel.enable_cfg_skip)
89
- # WanTransformer3DModel.disable_cfg_skip = disable_cfg_skip()(WanTransformer3DModel.disable_cfg_skip)
90
- # WanTransformer3DModel.share_cfg_skip = share_cfg_skip()(WanTransformer3DModel.share_cfg_skip)
91
-
92
- # QwenImageTransformer2DModel.enable_cfg_skip = enable_cfg_skip()(QwenImageTransformer2DModel.enable_cfg_skip)
93
- # QwenImageTransformer2DModel.disable_cfg_skip = disable_cfg_skip()(QwenImageTransformer2DModel.disable_cfg_skip)
94
- # print("Import CFG Skip Turbo")
95
-
96
- # # --------------------------------------------------------------- #
97
- # # RMS Norm Kernel
98
- # # --------------------------------------------------------------- #
99
- # from paifuser.ops import rms_norm_forward
100
- # WanRMSNorm.forward = rms_norm_forward
101
- # print("Import PAI RMS Fuse")
102
-
103
- # # --------------------------------------------------------------- #
104
- # # Fast Rope Kernel
105
- # # --------------------------------------------------------------- #
106
- # import types
107
-
108
- # import torch
109
- # from paifuser.ops import (ENABLE_KERNEL, fast_rope_apply_qk,
110
- # rope_apply_real_qk)
111
-
112
- # from . import wan_transformer3d
113
-
114
- # def deepcopy_function(f):
115
- # return types.FunctionType(f.__code__, f.__globals__, name=f.__name__, argdefs=f.__defaults__,closure=f.__closure__)
116
-
117
- # local_rope_apply_qk = deepcopy_function(wan_transformer3d.rope_apply_qk)
118
-
119
- # if ENABLE_KERNEL:
120
- # def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
121
- # if torch.is_grad_enabled():
122
- # return local_rope_apply_qk(q, k, grid_sizes, freqs)
123
- # else:
124
- # return fast_rope_apply_qk(q, k, grid_sizes, freqs)
125
- # else:
126
- # def adaptive_fast_rope_apply_qk(q, k, grid_sizes, freqs):
127
- # return rope_apply_real_qk(q, k, grid_sizes, freqs)
128
-
129
- # wan_transformer3d.rope_apply_qk = adaptive_fast_rope_apply_qk
130
- # rope_apply_qk = adaptive_fast_rope_apply_qk
131
- # print("Import PAI Fast rope")
videox_fun/models/attention_utils.py DELETED
@@ -1,211 +0,0 @@
1
- import os
2
-
3
- import torch
4
- import warnings
5
-
6
- try:
7
- import flash_attn_interface
8
- FLASH_ATTN_3_AVAILABLE = True
9
- except ModuleNotFoundError:
10
- FLASH_ATTN_3_AVAILABLE = False
11
-
12
- try:
13
- import flash_attn
14
- FLASH_ATTN_2_AVAILABLE = True
15
- except ModuleNotFoundError:
16
- FLASH_ATTN_2_AVAILABLE = False
17
-
18
- try:
19
- major, minor = torch.cuda.get_device_capability(0)
20
- if f"{major}.{minor}" == "8.0":
21
- from sageattention_sm80 import sageattn
22
- SAGE_ATTENTION_AVAILABLE = True
23
- elif f"{major}.{minor}" == "8.6":
24
- from sageattention_sm86 import sageattn
25
- SAGE_ATTENTION_AVAILABLE = True
26
- elif f"{major}.{minor}" == "8.9":
27
- from sageattention_sm89 import sageattn
28
- SAGE_ATTENTION_AVAILABLE = True
29
- elif f"{major}.{minor}" == "9.0":
30
- from sageattention_sm90 import sageattn
31
- SAGE_ATTENTION_AVAILABLE = True
32
- elif major>9:
33
- from sageattention_sm120 import sageattn
34
- SAGE_ATTENTION_AVAILABLE = True
35
- except:
36
- try:
37
- from sageattention import sageattn
38
- SAGE_ATTENTION_AVAILABLE = True
39
- except:
40
- sageattn = None
41
- SAGE_ATTENTION_AVAILABLE = False
42
-
43
- def flash_attention(
44
- q,
45
- k,
46
- v,
47
- q_lens=None,
48
- k_lens=None,
49
- dropout_p=0.,
50
- softmax_scale=None,
51
- q_scale=None,
52
- causal=False,
53
- window_size=(-1, -1),
54
- deterministic=False,
55
- dtype=torch.bfloat16,
56
- version=None,
57
- ):
58
- """
59
- q: [B, Lq, Nq, C1].
60
- k: [B, Lk, Nk, C1].
61
- v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
62
- q_lens: [B].
63
- k_lens: [B].
64
- dropout_p: float. Dropout probability.
65
- softmax_scale: float. The scaling of QK^T before applying softmax.
66
- causal: bool. Whether to apply causal attention mask.
67
- window_size: (left, right). If not (-1, -1), apply sliding window local attention.
68
- deterministic: bool. If True, slightly slower and uses more memory.
69
- dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
70
- """
71
- half_dtypes = (torch.float16, torch.bfloat16)
72
- assert dtype in half_dtypes
73
- assert q.device.type == 'cuda' and q.size(-1) <= 256
74
-
75
- # params
76
- b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
77
-
78
- def half(x):
79
- return x if x.dtype in half_dtypes else x.to(dtype)
80
-
81
- # preprocess query
82
- if q_lens is None:
83
- q = half(q.flatten(0, 1))
84
- q_lens = torch.tensor(
85
- [lq] * b, dtype=torch.int32).to(
86
- device=q.device, non_blocking=True)
87
- else:
88
- q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
89
-
90
- # preprocess key, value
91
- if k_lens is None:
92
- k = half(k.flatten(0, 1))
93
- v = half(v.flatten(0, 1))
94
- k_lens = torch.tensor(
95
- [lk] * b, dtype=torch.int32).to(
96
- device=k.device, non_blocking=True)
97
- else:
98
- k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
99
- v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
100
-
101
- q = q.to(v.dtype)
102
- k = k.to(v.dtype)
103
-
104
- if q_scale is not None:
105
- q = q * q_scale
106
-
107
- if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
108
- warnings.warn(
109
- 'Flash attention 3 is not available, using flash attention 2 instead.'
110
- )
111
-
112
- # apply attention
113
- if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
114
- # Note: dropout_p, window_size are not supported in FA3 now.
115
- x = flash_attn_interface.flash_attn_varlen_func(
116
- q=q,
117
- k=k,
118
- v=v,
119
- cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
120
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
121
- cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
122
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
123
- seqused_q=None,
124
- seqused_k=None,
125
- max_seqlen_q=lq,
126
- max_seqlen_k=lk,
127
- softmax_scale=softmax_scale,
128
- causal=causal,
129
- deterministic=deterministic)[0].unflatten(0, (b, lq))
130
- else:
131
- assert FLASH_ATTN_2_AVAILABLE
132
- x = flash_attn.flash_attn_varlen_func(
133
- q=q,
134
- k=k,
135
- v=v,
136
- cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
137
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
138
- cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
139
- 0, dtype=torch.int32).to(q.device, non_blocking=True),
140
- max_seqlen_q=lq,
141
- max_seqlen_k=lk,
142
- dropout_p=dropout_p,
143
- softmax_scale=softmax_scale,
144
- causal=causal,
145
- window_size=window_size,
146
- deterministic=deterministic).unflatten(0, (b, lq))
147
-
148
- # output
149
- return x.type(out_dtype)
150
-
151
-
152
- def attention(
153
- q,
154
- k,
155
- v,
156
- q_lens=None,
157
- k_lens=None,
158
- dropout_p=0.,
159
- softmax_scale=None,
160
- q_scale=None,
161
- causal=False,
162
- window_size=(-1, -1),
163
- deterministic=False,
164
- dtype=torch.bfloat16,
165
- fa_version=None,
166
- attention_type=None,
167
- attn_mask=None,
168
- ):
169
- attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
170
- if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
171
- attention_type = "FLASH_ATTENTION"
172
-
173
- if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
174
- if q_lens is not None or k_lens is not None:
175
- warnings.warn(
176
- 'Padding mask is disabled when using SageAttention. It can have a significant impact on performance.'
177
- )
178
-
179
- out = sageattn(
180
- q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)
181
-
182
- elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
183
- return flash_attention(
184
- q=q,
185
- k=k,
186
- v=v,
187
- q_lens=q_lens,
188
- k_lens=k_lens,
189
- dropout_p=dropout_p,
190
- softmax_scale=softmax_scale,
191
- q_scale=q_scale,
192
- causal=causal,
193
- window_size=window_size,
194
- deterministic=deterministic,
195
- dtype=dtype,
196
- version=fa_version,
197
- )
198
- else:
199
- if q_lens is not None or k_lens is not None:
200
- warnings.warn(
201
- 'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
202
- )
203
- q = q.transpose(1, 2)
204
- k = k.transpose(1, 2)
205
- v = v.transpose(1, 2)
206
-
207
- out = torch.nn.functional.scaled_dot_product_attention(
208
- q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
209
-
210
- out = out.transpose(1, 2).contiguous()
211
- return out
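A hedged usage sketch for the attention() dispatcher above: when neither flash-attn nor SageAttention is installed it falls through to torch.nn.functional.scaled_dot_product_attention, so the toy call below should run on a plain PyTorch install. The import path is the one this commit removes; shapes follow the [B, L, N, C] convention documented in flash_attention():

import torch
from videox_fun.models.attention_utils import attention  # path as it existed before this commit

q = torch.randn(1, 16, 8, 64)   # [B, Lq, Nq, C1]
k = torch.randn(1, 16, 8, 64)   # [B, Lk, Nk, C1]
v = torch.randn(1, 16, 8, 64)   # [B, Lk, Nk, C2]
out = attention(q, k, v, causal=False)
print(out.shape)                # torch.Size([1, 16, 8, 64])

The VIDEOX_ATTENTION_TYPE environment variable (FLASH_ATTENTION or SAGE_ATTENTION) selects the fast kernel when it is available, and the SageAttention path is automatically swapped back to flash attention whenever gradients are enabled.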
videox_fun/models/cache_utils.py DELETED
@@ -1,80 +0,0 @@
1
- import numpy as np
2
- import torch
3
-
4
- def get_teacache_coefficients(model_name):
5
- if "wan2.1-t2v-1.3b" in model_name.lower() or "wan2.1-fun-1.3b" in model_name.lower() \
6
- or "wan2.1-fun-v1.1-1.3b" in model_name.lower() or "wan2.1-vace-1.3b" in model_name.lower():
7
- return [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02]
8
- elif "wan2.1-t2v-14b" in model_name.lower():
9
- return [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01]
10
- elif "wan2.1-i2v-14b-480p" in model_name.lower():
11
- return [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01]
12
- elif "wan2.1-i2v-14b-720p" in model_name.lower() or "wan2.1-fun-14b" in model_name.lower() or "wan2.2-fun" in model_name.lower() \
13
- or "wan2.2-i2v-a14b" in model_name.lower() or "wan2.2-t2v-a14b" in model_name.lower() or "wan2.2-ti2v-5b" in model_name.lower() \
14
- or "wan2.2-s2v" in model_name.lower() or "wan2.1-vace-14b" in model_name.lower() or "wan2.2-vace-fun" in model_name.lower() \
15
- or "wan2.2-animate" in model_name.lower():
16
- return [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02]
17
- elif "qwen-image" in model_name.lower():
18
- # Copied from https://github.com/chenpipi0807/ComfyUI-TeaCache/blob/main/nodes.py
19
- return [-4.50000000e+02, 2.80000000e+02, -4.50000000e+01, 3.20000000e+00, -2.00000000e-02]
20
- else:
21
- print(f"The model {model_name} is not supported by TeaCache.")
22
- return None
23
-
24
-
25
- class TeaCache():
26
- """
27
- Timestep Embedding Aware Cache, a training-free caching approach that estimates and leverages
28
- the fluctuating differences among model outputs across timesteps, thereby accelerating the inference.
29
- Please refer to:
30
- 1. https://github.com/ali-vilab/TeaCache.
31
- 2. Liu, Feng, et al. "Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model." arXiv preprint arXiv:2411.19108 (2024).
32
- """
33
- def __init__(
34
- self,
35
- coefficients: list[float],
36
- num_steps: int,
37
- rel_l1_thresh: float = 0.0,
38
- num_skip_start_steps: int = 0,
39
- offload: bool = True,
40
- ):
41
- if num_steps < 1:
42
- raise ValueError(f"`num_steps` must be greater than 0 but is {num_steps}.")
43
- if rel_l1_thresh < 0:
44
- raise ValueError(f"`rel_l1_thresh` must be greater than or equal to 0 but is {rel_l1_thresh}.")
45
- if num_skip_start_steps < 0 or num_skip_start_steps > num_steps:
46
- raise ValueError(
47
- "`num_skip_start_steps` must be great than or equal to 0 and "
48
- f"less than or equal to `num_steps={num_steps}` but is {num_skip_start_steps}."
49
- )
50
- self.coefficients = coefficients
51
- self.num_steps = num_steps
52
- self.rel_l1_thresh = rel_l1_thresh
53
- self.num_skip_start_steps = num_skip_start_steps
54
- self.offload = offload
55
- self.rescale_func = np.poly1d(self.coefficients)
56
-
57
- self.cnt = 0
58
- self.should_calc = True
59
- self.accumulated_rel_l1_distance = 0
60
- self.previous_modulated_input = None
61
- # Some pipelines concatenate the unconditional and text guide in forward.
62
- self.previous_residual = None
63
- # Some pipelines perform forward propagation separately on the unconditional and text guide.
64
- self.previous_residual_cond = None
65
- self.previous_residual_uncond = None
66
-
67
- @staticmethod
68
- def compute_rel_l1_distance(prev: torch.Tensor, cur: torch.Tensor) -> torch.Tensor:
69
- rel_l1_distance = (torch.abs(cur - prev).mean()) / torch.abs(prev).mean()
70
-
71
- return rel_l1_distance.cpu().item()
72
-
73
- def reset(self):
74
- self.cnt = 0
75
- self.should_calc = True
76
- self.accumulated_rel_l1_distance = 0
77
- self.previous_modulated_input = None
78
- self.previous_residual = None
79
- self.previous_residual_cond = None
80
- self.previous_residual_uncond = None
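The decision rule that consumes these fields lives in the pipelines rather than in this file, but it roughly works as sketched below (an assumed reconstruction with made-up coefficients and inputs): the relative L1 distance between consecutive modulated inputs is rescaled by the fitted polynomial and accumulated, and the transformer is only re-run once the accumulator crosses rel_l1_thresh; otherwise the cached residual is reused.

import numpy as np
import torch

rescale = np.poly1d([0.0, 0.0, 0.0, 1.0, 0.0])        # identity-like polynomial, illustration only
rel_l1_thresh, accumulated = 0.15, 0.0

prev = torch.randn(1, 64)
for step in range(4):
    cur = prev + 0.05 * torch.randn(1, 64)            # stand-in for the modulated timestep input
    rel_l1 = (torch.abs(cur - prev).mean() / torch.abs(prev).mean()).item()
    accumulated += float(rescale(rel_l1))
    should_calc = accumulated >= rel_l1_thresh
    if should_calc:
        accumulated = 0.0                             # full transformer pass, reset the accumulator
    prev = cur
    print(step, round(rel_l1, 4), should_calc)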
videox_fun/models/cogvideox_transformer3d.py DELETED
@@ -1,915 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- import json
18
- import os
19
- from typing import Any, Dict, Optional, Tuple, Union
20
-
21
- import torch
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.models.attention import Attention, FeedForward
25
- from diffusers.models.attention_processor import (
26
- AttentionProcessor, FusedCogVideoXAttnProcessor2_0)
27
- from diffusers.models.embeddings import (CogVideoXPatchEmbed,
28
- TimestepEmbedding, Timesteps,
29
- get_3d_sincos_pos_embed)
30
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
31
- from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.models.normalization import AdaLayerNorm, CogVideoXLayerNormZero
33
- from diffusers.utils import is_torch_version, logging
34
- from diffusers.utils.torch_utils import maybe_allow_in_graph
35
- from torch import nn
36
-
37
- from ..dist import (get_sequence_parallel_rank,
38
- get_sequence_parallel_world_size, get_sp_group,
39
- xFuserLongContextAttention)
40
- from ..dist.cogvideox_xfuser import CogVideoXMultiGPUsAttnProcessor2_0
41
- from .attention_utils import attention
42
-
43
-
44
- class CogVideoXAttnProcessor2_0:
45
- r"""
46
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
47
- query and key vectors, but does not include spatial normalization.
48
- """
49
-
50
- def __init__(self):
51
- if not hasattr(F, "scaled_dot_product_attention"):
52
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
53
-
54
- def __call__(
55
- self,
56
- attn,
57
- hidden_states: torch.Tensor,
58
- encoder_hidden_states: torch.Tensor,
59
- attention_mask: torch.Tensor = None,
60
- image_rotary_emb: torch.Tensor = None,
61
- ) -> torch.Tensor:
62
- text_seq_length = encoder_hidden_states.size(1)
63
-
64
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
65
-
66
- batch_size, sequence_length, _ = hidden_states.shape
67
-
68
- if attention_mask is not None:
69
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
70
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
71
-
72
- query = attn.to_q(hidden_states)
73
- key = attn.to_k(hidden_states)
74
- value = attn.to_v(hidden_states)
75
-
76
- inner_dim = key.shape[-1]
77
- head_dim = inner_dim // attn.heads
78
-
79
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
80
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
81
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
82
-
83
- if attn.norm_q is not None:
84
- query = attn.norm_q(query)
85
- if attn.norm_k is not None:
86
- key = attn.norm_k(key)
87
-
88
- # Apply RoPE if needed
89
- if image_rotary_emb is not None:
90
- from diffusers.models.embeddings import apply_rotary_emb
91
-
92
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
93
- if not attn.is_cross_attention:
94
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
95
-
96
- query = query.transpose(1, 2)
97
- key = key.transpose(1, 2)
98
- value = value.transpose(1, 2)
99
-
100
- hidden_states = attention(
101
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, causal=False
102
- )
103
- hidden_states = hidden_states.reshape(batch_size, -1, attn.heads * head_dim)
104
-
105
- # linear proj
106
- hidden_states = attn.to_out[0](hidden_states)
107
- # dropout
108
- hidden_states = attn.to_out[1](hidden_states)
109
-
110
- encoder_hidden_states, hidden_states = hidden_states.split(
111
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
112
- )
113
- return hidden_states, encoder_hidden_states
114
-
115
-
116
- class CogVideoXPatchEmbed(nn.Module):
117
- def __init__(
118
- self,
119
- patch_size: int = 2,
120
- patch_size_t: Optional[int] = None,
121
- in_channels: int = 16,
122
- embed_dim: int = 1920,
123
- text_embed_dim: int = 4096,
124
- bias: bool = True,
125
- sample_width: int = 90,
126
- sample_height: int = 60,
127
- sample_frames: int = 49,
128
- temporal_compression_ratio: int = 4,
129
- max_text_seq_length: int = 226,
130
- spatial_interpolation_scale: float = 1.875,
131
- temporal_interpolation_scale: float = 1.0,
132
- use_positional_embeddings: bool = True,
133
- use_learned_positional_embeddings: bool = True,
134
- ) -> None:
135
- super().__init__()
136
-
137
- post_patch_height = sample_height // patch_size
138
- post_patch_width = sample_width // patch_size
139
- post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
140
- self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
141
- self.post_patch_height = post_patch_height
142
- self.post_patch_width = post_patch_width
143
- self.post_time_compression_frames = post_time_compression_frames
144
- self.patch_size = patch_size
145
- self.patch_size_t = patch_size_t
146
- self.embed_dim = embed_dim
147
- self.sample_height = sample_height
148
- self.sample_width = sample_width
149
- self.sample_frames = sample_frames
150
- self.temporal_compression_ratio = temporal_compression_ratio
151
- self.max_text_seq_length = max_text_seq_length
152
- self.spatial_interpolation_scale = spatial_interpolation_scale
153
- self.temporal_interpolation_scale = temporal_interpolation_scale
154
- self.use_positional_embeddings = use_positional_embeddings
155
- self.use_learned_positional_embeddings = use_learned_positional_embeddings
156
-
157
- if patch_size_t is None:
158
- # CogVideoX 1.0 checkpoints
159
- self.proj = nn.Conv2d(
160
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
161
- )
162
- else:
163
- # CogVideoX 1.5 checkpoints
164
- self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)
165
-
166
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
167
-
168
- if use_positional_embeddings or use_learned_positional_embeddings:
169
- persistent = use_learned_positional_embeddings
170
- pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
171
- self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
172
-
173
- def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
174
- post_patch_height = sample_height // self.patch_size
175
- post_patch_width = sample_width // self.patch_size
176
- post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
177
- num_patches = post_patch_height * post_patch_width * post_time_compression_frames
178
-
179
- pos_embedding = get_3d_sincos_pos_embed(
180
- self.embed_dim,
181
- (post_patch_width, post_patch_height),
182
- post_time_compression_frames,
183
- self.spatial_interpolation_scale,
184
- self.temporal_interpolation_scale,
185
- )
186
- pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
187
- joint_pos_embedding = torch.zeros(
188
- 1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
189
- )
190
- joint_pos_embedding.data[:, self.max_text_seq_length :].copy_(pos_embedding)
191
-
192
- return joint_pos_embedding
193
-
194
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
195
- r"""
196
- Args:
197
- text_embeds (`torch.Tensor`):
198
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
199
- image_embeds (`torch.Tensor`):
200
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
201
- """
202
- text_embeds = self.text_proj(text_embeds)
203
-
204
- text_batch_size, text_seq_length, text_channels = text_embeds.shape
205
- batch_size, num_frames, channels, height, width = image_embeds.shape
206
-
207
- if self.patch_size_t is None:
208
- image_embeds = image_embeds.reshape(-1, channels, height, width)
209
- image_embeds = self.proj(image_embeds)
210
- image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
211
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
212
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
213
- else:
214
- p = self.patch_size
215
- p_t = self.patch_size_t
216
-
217
- image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
218
- # b, f, h, w, c => b, f // 2, 2, h // 2, 2, w // 2, 2, c
219
- image_embeds = image_embeds.reshape(
220
- batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
221
- )
222
- # b, f // 2, 2, h // 2, 2, w // 2, 2, c => b, f // 2, h // 2, w // 2, c, 2, 2, 2
223
- image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
224
- image_embeds = self.proj(image_embeds)
225
-
226
- embeds = torch.cat(
227
- [text_embeds, image_embeds], dim=1
228
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
229
-
230
- if self.use_positional_embeddings or self.use_learned_positional_embeddings:
231
- seq_length = height * width * num_frames // (self.patch_size**2)
232
- # pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
233
- pos_embeds = self.pos_embedding
234
- emb_size = embeds.size()[-1]
235
- pos_embeds_without_text = pos_embeds[:, text_seq_length: ].view(1, self.post_time_compression_frames, self.post_patch_height, self.post_patch_width, emb_size)
236
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 4, 1, 2, 3])
237
- pos_embeds_without_text = F.interpolate(pos_embeds_without_text,size=[self.post_time_compression_frames, height // self.patch_size, width // self.patch_size], mode='trilinear', align_corners=False)
238
- pos_embeds_without_text = pos_embeds_without_text.permute([0, 2, 3, 4, 1]).view(1, -1, emb_size)
239
- pos_embeds = torch.cat([pos_embeds[:, :text_seq_length], pos_embeds_without_text], dim = 1)
240
- pos_embeds = pos_embeds[:, : text_seq_length + seq_length]
241
- embeds = embeds + pos_embeds
242
-
243
- return embeds
244
-
245
- @maybe_allow_in_graph
246
- class CogVideoXBlock(nn.Module):
247
- r"""
248
- Transformer block used in [CogVideoX](https://github.com/THUDM/CogVideo) model.
249
-
250
- Parameters:
251
- dim (`int`):
252
- The number of channels in the input and output.
253
- num_attention_heads (`int`):
254
- The number of heads to use for multi-head attention.
255
- attention_head_dim (`int`):
256
- The number of channels in each head.
257
- time_embed_dim (`int`):
258
- The number of channels in timestep embedding.
259
- dropout (`float`, defaults to `0.0`):
260
- The dropout probability to use.
261
- activation_fn (`str`, defaults to `"gelu-approximate"`):
262
- Activation function to be used in feed-forward.
263
- attention_bias (`bool`, defaults to `False`):
264
- Whether or not to use bias in attention projection layers.
265
- qk_norm (`bool`, defaults to `True`):
266
- Whether or not to use normalization after query and key projections in Attention.
267
- norm_elementwise_affine (`bool`, defaults to `True`):
268
- Whether to use learnable elementwise affine parameters for normalization.
269
- norm_eps (`float`, defaults to `1e-5`):
270
- Epsilon value for normalization layers.
271
- final_dropout (`bool`, defaults to `True`):
272
- Whether to apply a final dropout after the last feed-forward layer.
273
- ff_inner_dim (`int`, *optional*, defaults to `None`):
274
- Custom hidden dimension of Feed-forward layer. If not provided, `4 * dim` is used.
275
- ff_bias (`bool`, defaults to `True`):
276
- Whether or not to use bias in Feed-forward layer.
277
- attention_out_bias (`bool`, defaults to `True`):
278
- Whether or not to use bias in Attention output projection layer.
279
- """
280
-
281
- def __init__(
282
- self,
283
- dim: int,
284
- num_attention_heads: int,
285
- attention_head_dim: int,
286
- time_embed_dim: int,
287
- dropout: float = 0.0,
288
- activation_fn: str = "gelu-approximate",
289
- attention_bias: bool = False,
290
- qk_norm: bool = True,
291
- norm_elementwise_affine: bool = True,
292
- norm_eps: float = 1e-5,
293
- final_dropout: bool = True,
294
- ff_inner_dim: Optional[int] = None,
295
- ff_bias: bool = True,
296
- attention_out_bias: bool = True,
297
- ):
298
- super().__init__()
299
-
300
- # 1. Self Attention
301
- self.norm1 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
302
-
303
- self.attn1 = Attention(
304
- query_dim=dim,
305
- dim_head=attention_head_dim,
306
- heads=num_attention_heads,
307
- qk_norm="layer_norm" if qk_norm else None,
308
- eps=1e-6,
309
- bias=attention_bias,
310
- out_bias=attention_out_bias,
311
- processor=CogVideoXAttnProcessor2_0(),
312
- )
313
-
314
- # 2. Feed Forward
315
- self.norm2 = CogVideoXLayerNormZero(time_embed_dim, dim, norm_elementwise_affine, norm_eps, bias=True)
316
-
317
- self.ff = FeedForward(
318
- dim,
319
- dropout=dropout,
320
- activation_fn=activation_fn,
321
- final_dropout=final_dropout,
322
- inner_dim=ff_inner_dim,
323
- bias=ff_bias,
324
- )
325
-
326
- def forward(
327
- self,
328
- hidden_states: torch.Tensor,
329
- encoder_hidden_states: torch.Tensor,
330
- temb: torch.Tensor,
331
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
332
- ) -> torch.Tensor:
333
- text_seq_length = encoder_hidden_states.size(1)
334
-
335
- # norm & modulate
336
- norm_hidden_states, norm_encoder_hidden_states, gate_msa, enc_gate_msa = self.norm1(
337
- hidden_states, encoder_hidden_states, temb
338
- )
339
-
340
- # attention
341
- attn_hidden_states, attn_encoder_hidden_states = self.attn1(
342
- hidden_states=norm_hidden_states,
343
- encoder_hidden_states=norm_encoder_hidden_states,
344
- image_rotary_emb=image_rotary_emb,
345
- )
346
-
347
- hidden_states = hidden_states + gate_msa * attn_hidden_states
348
- encoder_hidden_states = encoder_hidden_states + enc_gate_msa * attn_encoder_hidden_states
349
-
350
- # norm & modulate
351
- norm_hidden_states, norm_encoder_hidden_states, gate_ff, enc_gate_ff = self.norm2(
352
- hidden_states, encoder_hidden_states, temb
353
- )
354
-
355
- # feed-forward
356
- norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
357
- ff_output = self.ff(norm_hidden_states)
358
-
359
- hidden_states = hidden_states + gate_ff * ff_output[:, text_seq_length:]
360
- encoder_hidden_states = encoder_hidden_states + enc_gate_ff * ff_output[:, :text_seq_length]
361
-
362
- return hidden_states, encoder_hidden_states
363
-
364
-
365
- class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
366
- """
367
- A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
368
-
369
- Parameters:
370
- num_attention_heads (`int`, defaults to `30`):
371
- The number of heads to use for multi-head attention.
372
- attention_head_dim (`int`, defaults to `64`):
373
- The number of channels in each head.
374
- in_channels (`int`, defaults to `16`):
375
- The number of channels in the input.
376
- out_channels (`int`, *optional*, defaults to `16`):
377
- The number of channels in the output.
378
- flip_sin_to_cos (`bool`, defaults to `True`):
379
- Whether to flip the sin to cos in the time embedding.
380
- time_embed_dim (`int`, defaults to `512`):
381
- Output dimension of timestep embeddings.
382
- text_embed_dim (`int`, defaults to `4096`):
383
- Input dimension of text embeddings from the text encoder.
384
- num_layers (`int`, defaults to `30`):
385
- The number of layers of Transformer blocks to use.
386
- dropout (`float`, defaults to `0.0`):
387
- The dropout probability to use.
388
- attention_bias (`bool`, defaults to `True`):
389
- Whether or not to use bias in the attention projection layers.
390
- sample_width (`int`, defaults to `90`):
391
- The width of the input latents.
392
- sample_height (`int`, defaults to `60`):
393
- The height of the input latents.
394
- sample_frames (`int`, defaults to `49`):
395
- The number of frames in the input latents. Note that this parameter was incorrectly initialized to 49
396
- instead of 13 because CogVideoX processed 13 latent frames at once in its default and recommended settings,
397
- but cannot be changed to the correct value to ensure backwards compatibility. To create a transformer with
398
- K latent frames, the correct value to pass here would be: ((K - 1) * temporal_compression_ratio + 1).
399
- patch_size (`int`, defaults to `2`):
400
- The size of the patches to use in the patch embedding layer.
401
- temporal_compression_ratio (`int`, defaults to `4`):
402
- The compression ratio across the temporal dimension. See documentation for `sample_frames`.
403
- max_text_seq_length (`int`, defaults to `226`):
404
- The maximum sequence length of the input text embeddings.
405
- activation_fn (`str`, defaults to `"gelu-approximate"`):
406
- Activation function to use in feed-forward.
407
- timestep_activation_fn (`str`, defaults to `"silu"`):
408
- Activation function to use when generating the timestep embeddings.
409
- norm_elementwise_affine (`bool`, defaults to `True`):
410
- Whether or not to use elementwise affine in normalization layers.
411
- norm_eps (`float`, defaults to `1e-5`):
412
- The epsilon value to use in normalization layers.
413
- spatial_interpolation_scale (`float`, defaults to `1.875`):
414
- Scaling factor to apply in 3D positional embeddings across spatial dimensions.
415
- temporal_interpolation_scale (`float`, defaults to `1.0`):
416
- Scaling factor to apply in 3D positional embeddings across temporal dimensions.
417
- """
418
-
419
- _supports_gradient_checkpointing = True
420
-
421
- @register_to_config
422
- def __init__(
423
- self,
424
- num_attention_heads: int = 30,
425
- attention_head_dim: int = 64,
426
- in_channels: int = 16,
427
- out_channels: Optional[int] = 16,
428
- flip_sin_to_cos: bool = True,
429
- freq_shift: int = 0,
430
- time_embed_dim: int = 512,
431
- text_embed_dim: int = 4096,
432
- num_layers: int = 30,
433
- dropout: float = 0.0,
434
- attention_bias: bool = True,
435
- sample_width: int = 90,
436
- sample_height: int = 60,
437
- sample_frames: int = 49,
438
- patch_size: int = 2,
439
- patch_size_t: Optional[int] = None,
440
- temporal_compression_ratio: int = 4,
441
- max_text_seq_length: int = 226,
442
- activation_fn: str = "gelu-approximate",
443
- timestep_activation_fn: str = "silu",
444
- norm_elementwise_affine: bool = True,
445
- norm_eps: float = 1e-5,
446
- spatial_interpolation_scale: float = 1.875,
447
- temporal_interpolation_scale: float = 1.0,
448
- use_rotary_positional_embeddings: bool = False,
449
- use_learned_positional_embeddings: bool = False,
450
- patch_bias: bool = True,
451
- add_noise_in_inpaint_model: bool = False,
452
- ):
453
- super().__init__()
454
- inner_dim = num_attention_heads * attention_head_dim
455
- self.patch_size_t = patch_size_t
456
- if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
457
- raise ValueError(
458
- "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
459
- "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
460
- "issue at https://github.com/huggingface/diffusers/issues."
461
- )
462
-
463
- # 1. Patch embedding
464
- self.patch_embed = CogVideoXPatchEmbed(
465
- patch_size=patch_size,
466
- patch_size_t=patch_size_t,
467
- in_channels=in_channels,
468
- embed_dim=inner_dim,
469
- text_embed_dim=text_embed_dim,
470
- bias=patch_bias,
471
- sample_width=sample_width,
472
- sample_height=sample_height,
473
- sample_frames=sample_frames,
474
- temporal_compression_ratio=temporal_compression_ratio,
475
- max_text_seq_length=max_text_seq_length,
476
- spatial_interpolation_scale=spatial_interpolation_scale,
477
- temporal_interpolation_scale=temporal_interpolation_scale,
478
- use_positional_embeddings=not use_rotary_positional_embeddings,
479
- use_learned_positional_embeddings=use_learned_positional_embeddings,
480
- )
481
- self.embedding_dropout = nn.Dropout(dropout)
482
-
483
- # 2. Time embeddings
484
- self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
485
- self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
486
-
487
- # 3. Define spatio-temporal transformers blocks
488
- self.transformer_blocks = nn.ModuleList(
489
- [
490
- CogVideoXBlock(
491
- dim=inner_dim,
492
- num_attention_heads=num_attention_heads,
493
- attention_head_dim=attention_head_dim,
494
- time_embed_dim=time_embed_dim,
495
- dropout=dropout,
496
- activation_fn=activation_fn,
497
- attention_bias=attention_bias,
498
- norm_elementwise_affine=norm_elementwise_affine,
499
- norm_eps=norm_eps,
500
- )
501
- for _ in range(num_layers)
502
- ]
503
- )
504
- self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
505
-
506
- # 4. Output blocks
507
- self.norm_out = AdaLayerNorm(
508
- embedding_dim=time_embed_dim,
509
- output_dim=2 * inner_dim,
510
- norm_elementwise_affine=norm_elementwise_affine,
511
- norm_eps=norm_eps,
512
- chunk_dim=1,
513
- )
514
-
515
- if patch_size_t is None:
516
- # For CogVideox 1.0
517
- output_dim = patch_size * patch_size * out_channels
518
- else:
519
- # For CogVideoX 1.5
520
- output_dim = patch_size * patch_size * patch_size_t * out_channels
521
-
522
- self.proj_out = nn.Linear(inner_dim, output_dim)
523
-
524
- self.gradient_checkpointing = False
525
- self.sp_world_size = 1
526
- self.sp_world_rank = 0
527
-
528
- def _set_gradient_checkpointing(self, *args, **kwargs):
529
- if "value" in kwargs:
530
- self.gradient_checkpointing = kwargs["value"]
531
- elif "enable" in kwargs:
532
- self.gradient_checkpointing = kwargs["enable"]
533
- else:
534
- raise ValueError("Invalid set gradient checkpointing")
535
-
536
- def enable_multi_gpus_inference(self,):
537
- self.sp_world_size = get_sequence_parallel_world_size()
538
- self.sp_world_rank = get_sequence_parallel_rank()
539
- self.set_attn_processor(CogVideoXMultiGPUsAttnProcessor2_0())
540
-
541
- @property
542
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
543
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
544
- r"""
545
- Returns:
546
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
547
- indexed by its weight name.
548
- """
549
- # set recursively
550
- processors = {}
551
-
552
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
553
- if hasattr(module, "get_processor"):
554
- processors[f"{name}.processor"] = module.get_processor()
555
-
556
- for sub_name, child in module.named_children():
557
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
558
-
559
- return processors
560
-
561
- for name, module in self.named_children():
562
- fn_recursive_add_processors(name, module, processors)
563
-
564
- return processors
565
-
566
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
567
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
568
- r"""
569
- Sets the attention processor to use to compute attention.
570
-
571
- Parameters:
572
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
573
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
574
- for **all** `Attention` layers.
575
-
576
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
577
- processor. This is strongly recommended when setting trainable attention processors.
578
-
579
- """
580
- count = len(self.attn_processors.keys())
581
-
582
- if isinstance(processor, dict) and len(processor) != count:
583
- raise ValueError(
584
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
585
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
586
- )
587
-
588
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
589
- if hasattr(module, "set_processor"):
590
- if not isinstance(processor, dict):
591
- module.set_processor(processor)
592
- else:
593
- module.set_processor(processor.pop(f"{name}.processor"))
594
-
595
- for sub_name, child in module.named_children():
596
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
597
-
598
- for name, module in self.named_children():
599
- fn_recursive_attn_processor(name, module, processor)
600
-
601
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedCogVideoXAttnProcessor2_0
602
- def fuse_qkv_projections(self):
603
- """
604
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
605
- are fused. For cross-attention modules, key and value projection matrices are fused.
606
-
607
- <Tip warning={true}>
608
-
609
- This API is 🧪 experimental.
610
-
611
- </Tip>
612
- """
613
- self.original_attn_processors = None
614
-
615
- for _, attn_processor in self.attn_processors.items():
616
- if "Added" in str(attn_processor.__class__.__name__):
617
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
618
-
619
- self.original_attn_processors = self.attn_processors
620
-
621
- for module in self.modules():
622
- if isinstance(module, Attention):
623
- module.fuse_projections(fuse=True)
624
-
625
- self.set_attn_processor(FusedCogVideoXAttnProcessor2_0())
626
-
627
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
628
- def unfuse_qkv_projections(self):
629
- """Disables the fused QKV projection if enabled.
630
-
631
- <Tip warning={true}>
632
-
633
- This API is 🧪 experimental.
634
-
635
- </Tip>
636
-
637
- """
638
- if self.original_attn_processors is not None:
639
- self.set_attn_processor(self.original_attn_processors)
640
-
641
- def forward(
642
- self,
643
- hidden_states: torch.Tensor,
644
- encoder_hidden_states: torch.Tensor,
645
- timestep: Union[int, float, torch.LongTensor],
646
- timestep_cond: Optional[torch.Tensor] = None,
647
- inpaint_latents: Optional[torch.Tensor] = None,
648
- control_latents: Optional[torch.Tensor] = None,
649
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
650
- return_dict: bool = True,
651
- ):
652
- batch_size, num_frames, channels, height, width = hidden_states.shape
653
- if num_frames == 1 and self.patch_size_t is not None:
654
- hidden_states = torch.cat([hidden_states, torch.zeros_like(hidden_states)], dim=1)
655
- if inpaint_latents is not None:
656
- inpaint_latents = torch.concat([inpaint_latents, torch.zeros_like(inpaint_latents)], dim=1)
657
- if control_latents is not None:
658
- control_latents = torch.concat([control_latents, torch.zeros_like(control_latents)], dim=1)
659
- local_num_frames = num_frames + 1
660
- else:
661
- local_num_frames = num_frames
662
-
663
- # 1. Time embedding
664
- timesteps = timestep
665
- t_emb = self.time_proj(timesteps)
666
-
667
- # timesteps does not contain any weights and will always return f32 tensors
668
- # but time_embedding might actually be running in fp16. so we need to cast here.
669
- # there might be better ways to encapsulate this.
670
- t_emb = t_emb.to(dtype=hidden_states.dtype)
671
- emb = self.time_embedding(t_emb, timestep_cond)
672
-
673
- # 2. Patch embedding
674
- if inpaint_latents is not None:
675
- hidden_states = torch.concat([hidden_states, inpaint_latents], 2)
676
- if control_latents is not None:
677
- hidden_states = torch.concat([hidden_states, control_latents], 2)
678
- hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
679
- hidden_states = self.embedding_dropout(hidden_states)
680
-
681
- text_seq_length = encoder_hidden_states.shape[1]
682
- encoder_hidden_states = hidden_states[:, :text_seq_length]
683
- hidden_states = hidden_states[:, text_seq_length:]
684
-
685
- # Context Parallel
686
- if self.sp_world_size > 1:
687
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
688
- if image_rotary_emb is not None:
689
- image_rotary_emb = (
690
- torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
691
- torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
692
- )
693
-
694
- # 3. Transformer blocks
695
- for i, block in enumerate(self.transformer_blocks):
696
- if torch.is_grad_enabled() and self.gradient_checkpointing:
697
-
698
- def create_custom_forward(module):
699
- def custom_forward(*inputs):
700
- return module(*inputs)
701
-
702
- return custom_forward
703
-
704
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
705
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
706
- create_custom_forward(block),
707
- hidden_states,
708
- encoder_hidden_states,
709
- emb,
710
- image_rotary_emb,
711
- **ckpt_kwargs,
712
- )
713
- else:
714
- hidden_states, encoder_hidden_states = block(
715
- hidden_states=hidden_states,
716
- encoder_hidden_states=encoder_hidden_states,
717
- temb=emb,
718
- image_rotary_emb=image_rotary_emb,
719
- )
720
-
721
- if not self.config.use_rotary_positional_embeddings:
722
- # CogVideoX-2B
723
- hidden_states = self.norm_final(hidden_states)
724
- else:
725
- # CogVideoX-5B
726
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
727
- hidden_states = self.norm_final(hidden_states)
728
- hidden_states = hidden_states[:, text_seq_length:]
729
-
730
- # 4. Final block
731
- hidden_states = self.norm_out(hidden_states, temb=emb)
732
- hidden_states = self.proj_out(hidden_states)
733
-
734
- if self.sp_world_size > 1:
735
- hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
736
-
737
- # 5. Unpatchify
738
- p = self.config.patch_size
739
- p_t = self.config.patch_size_t
740
-
741
- if p_t is None:
742
- output = hidden_states.reshape(batch_size, local_num_frames, height // p, width // p, -1, p, p)
743
- output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
744
- else:
745
- output = hidden_states.reshape(
746
- batch_size, (local_num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
747
- )
748
- output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
749
-
750
- if num_frames == 1:
751
- output = output[:, :num_frames, :]
752
-
753
- if not return_dict:
754
- return (output,)
755
- return Transformer2DModelOutput(sample=output)
756
-
757
- @classmethod
758
- def from_pretrained(
759
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
760
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
761
- ):
762
- if subfolder is not None:
763
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
764
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
765
-
766
- config_file = os.path.join(pretrained_model_path, 'config.json')
767
- if not os.path.isfile(config_file):
768
- raise RuntimeError(f"{config_file} does not exist")
769
- with open(config_file, "r") as f:
770
- config = json.load(f)
771
-
772
- from diffusers.utils import WEIGHTS_NAME
773
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
774
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
775
-
776
- if "dict_mapping" in transformer_additional_kwargs.keys():
777
- for key in transformer_additional_kwargs["dict_mapping"]:
778
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
779
-
780
- if low_cpu_mem_usage:
781
- try:
782
- import re
783
-
784
- from diffusers import __version__ as diffusers_version
785
- if diffusers_version >= "0.33.0":
786
- from diffusers.models.model_loading_utils import \
787
- load_model_dict_into_meta
788
- else:
789
- from diffusers.models.modeling_utils import \
790
- load_model_dict_into_meta
791
- from diffusers.utils import is_accelerate_available
792
- if is_accelerate_available():
793
- import accelerate
794
-
795
- # Instantiate model with empty weights
796
- with accelerate.init_empty_weights():
797
- model = cls.from_config(config, **transformer_additional_kwargs)
798
-
799
- param_device = "cpu"
800
- if os.path.exists(model_file):
801
- state_dict = torch.load(model_file, map_location="cpu")
802
- elif os.path.exists(model_file_safetensors):
803
- from safetensors.torch import load_file, safe_open
804
- state_dict = load_file(model_file_safetensors)
805
- else:
806
- from safetensors.torch import load_file, safe_open
807
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
808
- state_dict = {}
809
- for _model_file_safetensors in model_files_safetensors:
810
- _state_dict = load_file(_model_file_safetensors)
811
- for key in _state_dict:
812
- state_dict[key] = _state_dict[key]
813
- model._convert_deprecated_attention_blocks(state_dict)
814
-
815
- if diffusers_version >= "0.33.0":
816
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
817
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
818
- load_model_dict_into_meta(
819
- model,
820
- state_dict,
821
- dtype=torch_dtype,
822
- model_name_or_path=pretrained_model_path,
823
- )
824
- else:
825
- # move the params from meta device to cpu
826
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
827
- if len(missing_keys) > 0:
828
- raise ValueError(
829
- f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
830
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
831
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
832
- " those weights or else make sure your checkpoint file is correct."
833
- )
834
-
835
- unexpected_keys = load_model_dict_into_meta(
836
- model,
837
- state_dict,
838
- device=param_device,
839
- dtype=torch_dtype,
840
- model_name_or_path=pretrained_model_path,
841
- )
842
-
843
- if cls._keys_to_ignore_on_load_unexpected is not None:
844
- for pat in cls._keys_to_ignore_on_load_unexpected:
845
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
846
-
847
- if len(unexpected_keys) > 0:
848
- print(
849
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
850
- )
851
-
852
- return model
853
- except Exception as e:
854
- print(
855
- f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
856
- )
857
-
858
- model = cls.from_config(config, **transformer_additional_kwargs)
859
- if os.path.exists(model_file):
860
- state_dict = torch.load(model_file, map_location="cpu")
861
- elif os.path.exists(model_file_safetensors):
862
- from safetensors.torch import load_file, safe_open
863
- state_dict = load_file(model_file_safetensors)
864
- else:
865
- from safetensors.torch import load_file, safe_open
866
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
867
- state_dict = {}
868
- for _model_file_safetensors in model_files_safetensors:
869
- _state_dict = load_file(_model_file_safetensors)
870
- for key in _state_dict:
871
- state_dict[key] = _state_dict[key]
872
-
873
- if model.state_dict()['patch_embed.proj.weight'].size() != state_dict['patch_embed.proj.weight'].size():
874
- new_shape = model.state_dict()['patch_embed.proj.weight'].size()
875
- if len(new_shape) == 5:
876
- state_dict['patch_embed.proj.weight'] = state_dict['patch_embed.proj.weight'].unsqueeze(2).expand(new_shape).clone()
877
- state_dict['patch_embed.proj.weight'][:, :, :-1] = 0
878
- elif len(new_shape) == 2:
879
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
880
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1]] = state_dict['patch_embed.proj.weight']
881
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:] = 0
882
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
883
- else:
884
- model.state_dict()['patch_embed.proj.weight'][:, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1]]
885
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
886
- else:
887
- if model.state_dict()['patch_embed.proj.weight'].size()[1] > state_dict['patch_embed.proj.weight'].size()[1]:
888
- model.state_dict()['patch_embed.proj.weight'][:, :state_dict['patch_embed.proj.weight'].size()[1], :, :] = state_dict['patch_embed.proj.weight']
889
- model.state_dict()['patch_embed.proj.weight'][:, state_dict['patch_embed.proj.weight'].size()[1]:, :, :] = 0
890
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
891
- else:
892
- model.state_dict()['patch_embed.proj.weight'][:, :, :, :] = state_dict['patch_embed.proj.weight'][:, :model.state_dict()['patch_embed.proj.weight'].size()[1], :, :]
893
- state_dict['patch_embed.proj.weight'] = model.state_dict()['patch_embed.proj.weight']
894
-
895
- tmp_state_dict = {}
896
- for key in state_dict:
897
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
898
- tmp_state_dict[key] = state_dict[key]
899
- else:
900
- print(key, "Size don't match, skip")
901
-
902
- state_dict = tmp_state_dict
903
-
904
- m, u = model.load_state_dict(state_dict, strict=False)
905
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
906
- print(m)
907
-
908
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
909
- print(f"### All Parameters: {sum(params) / 1e6} M")
910
-
911
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
912
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
913
-
914
- model = model.to(torch_dtype)
915
- return model
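For orientation, a minimal usage sketch of the loader above. The class name CogVideoXTransformer3DModel and the checkpoint path are assumptions for illustration only; the signature itself is the one defined in this file.

import torch
# Assumed class name; only the module path appears in this commit.
from videox_fun.models.cogvideox_transformer3d import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "path/to/CogVideoX-checkpoint",   # directory containing config.json and *.safetensors / *.bin
    subfolder="transformer",
    low_cpu_mem_usage=True,           # meta-device init + load_model_dict_into_meta when accelerate is available
    torch_dtype=torch.bfloat16,
)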
videox_fun/models/cogvideox_vae.py DELETED
@@ -1,1675 +0,0 @@
1
- # Copyright 2024 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
2
- # All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Dict, Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
- import json
23
- import os
24
-
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
27
- from diffusers.utils import logging
28
- from diffusers.utils.accelerate_utils import apply_forward_hook
29
- from diffusers.models.activations import get_activation
30
- from diffusers.models.downsampling import CogVideoXDownsample3D
31
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.upsampling import CogVideoXUpsample3D
34
- from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
35
-
36
-
37
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
38
-
39
-
40
- class CogVideoXSafeConv3d(nn.Conv3d):
41
- r"""
42
- A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
43
- """
44
-
45
- def forward(self, input: torch.Tensor) -> torch.Tensor:
46
- memory_count = (
47
- (input.shape[0] * input.shape[1] * input.shape[2] * input.shape[3] * input.shape[4]) * 2 / 1024**3
48
- )
49
-
50
- # Set to 2GB, suitable for CuDNN
51
- if memory_count > 2:
52
- kernel_size = self.kernel_size[0]
53
- part_num = int(memory_count / 2) + 1
54
- input_chunks = torch.chunk(input, part_num, dim=2)
55
-
56
- if kernel_size > 1:
57
- input_chunks = [input_chunks[0]] + [
58
- torch.cat((input_chunks[i - 1][:, :, -kernel_size + 1 :], input_chunks[i]), dim=2)
59
- for i in range(1, len(input_chunks))
60
- ]
61
-
62
- output_chunks = []
63
- for input_chunk in input_chunks:
64
- output_chunks.append(super().forward(input_chunk))
65
- output = torch.cat(output_chunks, dim=2)
66
- return output
67
- else:
68
- return super().forward(input)
69
-
70
-
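As a rough illustration of the 2 GB threshold above, a minimal sketch of how the rule turns into a chunk count; the activation shape is assumed for illustration, not taken from this file.

# Back-of-the-envelope check of the chunking rule in CogVideoXSafeConv3d.
shape = (1, 128, 25, 480, 720)            # assumed (B, C, T, H, W) activation
elements = 1
for s in shape:
    elements *= s
memory_gb = elements * 2 / 1024**3        # ~2.06 GB at 2 bytes per element
part_num = int(memory_gb / 2) + 1         # -> 2 chunks along the time axis (dim=2)
print(f"{memory_gb:.2f} GB -> {part_num} chunks")
# With kernel_size > 1, each chunk after the first is prepended with the last
# kernel_size - 1 frames of the previous chunk, so the temporal conv stays seamless.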
71
- class CogVideoXCausalConv3d(nn.Module):
72
- r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
73
-
74
- Args:
75
- in_channels (`int`): Number of channels in the input tensor.
76
- out_channels (`int`): Number of output channels produced by the convolution.
77
- kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
78
- stride (`int`, defaults to `1`): Stride of the convolution.
79
- dilation (`int`, defaults to `1`): Dilation rate of the convolution.
80
- pad_mode (`str`, defaults to `"constant"`): Padding mode.
81
- """
82
-
83
- def __init__(
84
- self,
85
- in_channels: int,
86
- out_channels: int,
87
- kernel_size: Union[int, Tuple[int, int, int]],
88
- stride: int = 1,
89
- dilation: int = 1,
90
- pad_mode: str = "constant",
91
- ):
92
- super().__init__()
93
-
94
- if isinstance(kernel_size, int):
95
- kernel_size = (kernel_size,) * 3
96
-
97
- time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
98
-
99
- # TODO(aryan): configure calculation based on stride and dilation in the future.
100
- # Since CogVideoX does not use it, it is currently tailored to "just work" with Mochi
101
- time_pad = time_kernel_size - 1
102
- height_pad = (height_kernel_size - 1) // 2
103
- width_pad = (width_kernel_size - 1) // 2
104
-
105
- self.pad_mode = pad_mode
106
- self.height_pad = height_pad
107
- self.width_pad = width_pad
108
- self.time_pad = time_pad
109
- self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
110
-
111
- self.temporal_dim = 2
112
- self.time_kernel_size = time_kernel_size
113
-
114
- stride = stride if isinstance(stride, tuple) else (stride, 1, 1)
115
- dilation = (dilation, 1, 1)
116
- self.conv = CogVideoXSafeConv3d(
117
- in_channels=in_channels,
118
- out_channels=out_channels,
119
- kernel_size=kernel_size,
120
- stride=stride,
121
- dilation=dilation,
122
- )
123
-
124
- def fake_context_parallel_forward(
125
- self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None
126
- ) -> torch.Tensor:
127
- if self.pad_mode == "replicate":
128
- inputs = F.pad(inputs, self.time_causal_padding, mode="replicate")
129
- else:
130
- kernel_size = self.time_kernel_size
131
- if kernel_size > 1:
132
- cached_inputs = [conv_cache] if conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
133
- inputs = torch.cat(cached_inputs + [inputs], dim=2)
134
- return inputs
135
-
136
- def forward(self, inputs: torch.Tensor, conv_cache: Optional[torch.Tensor] = None) -> torch.Tensor:
137
- inputs = self.fake_context_parallel_forward(inputs, conv_cache)
138
-
139
- if self.pad_mode == "replicate":
140
- conv_cache = None
141
- else:
142
- padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
143
- conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
144
- inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
145
-
146
- output = self.conv(inputs)
147
- return output, conv_cache
148
-
149
-
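The padding above is causal: only the past side of the time axis is ever padded. A standalone sketch of that idea with illustrative shapes, using a plain nn.Conv3d in place of CogVideoXSafeConv3d:

import torch

k_t = 3                                    # temporal kernel size, as in kernel_size=3
x = torch.randn(1, 16, 8, 32, 32)          # assumed (B, C, T, H, W)
front = x[:, :, :1].repeat(1, 1, k_t - 1, 1, 1)   # repeat the first frame k_t - 1 times
x_causal = torch.cat([front, x], dim=2)    # pad the past only, never the future
conv = torch.nn.Conv3d(16, 16, kernel_size=3, padding=(0, 1, 1))
print(conv(x_causal).shape)                # torch.Size([1, 16, 8, 32, 32]) -> T preserved
# Across chunked calls, the module above caches the last k_t - 1 input frames
# (conv_cache) and prepends those instead of repeating the first frame.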
150
- class CogVideoXSpatialNorm3D(nn.Module):
151
- r"""
152
- Spatially conditioned normalization as defined in https://arxiv.org/abs/2209.09002. This implementation is specific
153
- to 3D-video like data.
154
-
155
- CogVideoXSafeConv3d is used instead of nn.Conv3d to avoid OOM in CogVideoX Model.
156
-
157
- Args:
158
- f_channels (`int`):
159
- The number of channels for input to group normalization layer, and output of the spatial norm layer.
160
- zq_channels (`int`):
161
- The number of channels for the quantized vector as described in the paper.
162
- groups (`int`):
163
- Number of groups to separate the channels into for group normalization.
164
- """
165
-
166
- def __init__(
167
- self,
168
- f_channels: int,
169
- zq_channels: int,
170
- groups: int = 32,
171
- ):
172
- super().__init__()
173
- self.norm_layer = nn.GroupNorm(num_channels=f_channels, num_groups=groups, eps=1e-6, affine=True)
174
- self.conv_y = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
175
- self.conv_b = CogVideoXCausalConv3d(zq_channels, f_channels, kernel_size=1, stride=1)
176
-
177
- def forward(
178
- self, f: torch.Tensor, zq: torch.Tensor, conv_cache: Optional[Dict[str, torch.Tensor]] = None
179
- ) -> torch.Tensor:
180
- new_conv_cache = {}
181
- conv_cache = conv_cache or {}
182
-
183
- if f.shape[2] > 1 and f.shape[2] % 2 == 1:
184
- f_first, f_rest = f[:, :, :1], f[:, :, 1:]
185
- f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
186
- z_first, z_rest = zq[:, :, :1], zq[:, :, 1:]
187
- z_first = F.interpolate(z_first, size=f_first_size)
188
- z_rest = F.interpolate(z_rest, size=f_rest_size)
189
- zq = torch.cat([z_first, z_rest], dim=2)
190
- else:
191
- zq = F.interpolate(zq, size=f.shape[-3:])
192
-
193
- conv_y, new_conv_cache["conv_y"] = self.conv_y(zq, conv_cache=conv_cache.get("conv_y"))
194
- conv_b, new_conv_cache["conv_b"] = self.conv_b(zq, conv_cache=conv_cache.get("conv_b"))
195
-
196
- norm_f = self.norm_layer(f)
197
- new_f = norm_f * conv_y + conv_b
198
- return new_f, new_conv_cache
199
-
200
-
201
- class CogVideoXUpsample3D(nn.Module):
202
- r"""
203
- A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.
204
-
205
- Args:
206
- in_channels (`int`):
207
- Number of channels in the input image.
208
- out_channels (`int`):
209
- Number of channels produced by the convolution.
210
- kernel_size (`int`, defaults to `3`):
211
- Size of the convolving kernel.
212
- stride (`int`, defaults to `1`):
213
- Stride of the convolution.
214
- padding (`int`, defaults to `1`):
215
- Padding added to all four sides of the input.
216
- compress_time (`bool`, defaults to `False`):
217
- Whether or not to compress the time dimension.
218
- """
219
-
220
- def __init__(
221
- self,
222
- in_channels: int,
223
- out_channels: int,
224
- kernel_size: int = 3,
225
- stride: int = 1,
226
- padding: int = 1,
227
- compress_time: bool = False,
228
- ) -> None:
229
- super().__init__()
230
-
231
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
232
- self.compress_time = compress_time
233
-
234
- self.auto_split_process = True
235
- self.first_frame_flag = False
236
-
237
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
238
- if self.compress_time:
239
- if self.auto_split_process:
240
- if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
241
- # split first frame
242
- x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
243
-
244
- x_first = F.interpolate(x_first, scale_factor=2.0)
245
- x_rest = F.interpolate(x_rest, scale_factor=2.0)
246
- x_first = x_first[:, :, None, :, :]
247
- inputs = torch.cat([x_first, x_rest], dim=2)
248
- elif inputs.shape[2] > 1:
249
- inputs = F.interpolate(inputs, scale_factor=2.0)
250
- else:
251
- inputs = inputs.squeeze(2)
252
- inputs = F.interpolate(inputs, scale_factor=2.0)
253
- inputs = inputs[:, :, None, :, :]
254
- else:
255
- if self.first_frame_flag:
256
- inputs = inputs.squeeze(2)
257
- inputs = F.interpolate(inputs, scale_factor=2.0)
258
- inputs = inputs[:, :, None, :, :]
259
- else:
260
- inputs = F.interpolate(inputs, scale_factor=2.0)
261
- else:
262
- # only interpolate 2D
263
- b, c, t, h, w = inputs.shape
264
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
265
- inputs = F.interpolate(inputs, scale_factor=2.0)
266
- inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
267
-
268
- b, c, t, h, w = inputs.shape
269
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
270
- inputs = self.conv(inputs)
271
- inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
272
-
273
- return inputs
274
-
275
-
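A short shape walk-through of the compress_time branch above for an odd frame count (illustrative sizes, nearest-neighbour interpolation as in the F.interpolate defaults):

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 3, 60, 90)                   # assumed (B, C, T=3, H, W)
x_first, x_rest = x[:, :, 0], x[:, :, 1:]            # frame 0 vs. frames 1..2
x_first = F.interpolate(x_first, scale_factor=2.0)   # (1, 64, 120, 180): spatial only
x_rest = F.interpolate(x_rest, scale_factor=2.0)     # (1, 64, 4, 120, 180): time and space
out = torch.cat([x_first[:, :, None], x_rest], dim=2)
print(out.shape)                                     # torch.Size([1, 64, 5, 120, 180])
# T goes 3 -> 1 + 2 * (3 - 1) = 5: the first frame is never duplicated in time,
# which preserves the causal "first frame" semantics of the VAE.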
276
- class CogVideoXResnetBlock3D(nn.Module):
277
- r"""
278
- A 3D ResNet block used in the CogVideoX model.
279
-
280
- Args:
281
- in_channels (`int`):
282
- Number of input channels.
283
- out_channels (`int`, *optional*):
284
- Number of output channels. If None, defaults to `in_channels`.
285
- dropout (`float`, defaults to `0.0`):
286
- Dropout rate.
287
- temb_channels (`int`, defaults to `512`):
288
- Number of time embedding channels.
289
- groups (`int`, defaults to `32`):
290
- Number of groups to separate the channels into for group normalization.
291
- eps (`float`, defaults to `1e-6`):
292
- Epsilon value for normalization layers.
293
- non_linearity (`str`, defaults to `"swish"`):
294
- Activation function to use.
295
- conv_shortcut (bool, defaults to `False`):
296
- Whether or not to use a convolution shortcut.
297
- spatial_norm_dim (`int`, *optional*):
298
- The dimension to use for spatial norm if it is to be used instead of group norm.
299
- pad_mode (str, defaults to `"first"`):
300
- Padding mode.
301
- """
302
-
303
- def __init__(
304
- self,
305
- in_channels: int,
306
- out_channels: Optional[int] = None,
307
- dropout: float = 0.0,
308
- temb_channels: int = 512,
309
- groups: int = 32,
310
- eps: float = 1e-6,
311
- non_linearity: str = "swish",
312
- conv_shortcut: bool = False,
313
- spatial_norm_dim: Optional[int] = None,
314
- pad_mode: str = "first",
315
- ):
316
- super().__init__()
317
-
318
- out_channels = out_channels or in_channels
319
-
320
- self.in_channels = in_channels
321
- self.out_channels = out_channels
322
- self.nonlinearity = get_activation(non_linearity)
323
- self.use_conv_shortcut = conv_shortcut
324
- self.spatial_norm_dim = spatial_norm_dim
325
-
326
- if spatial_norm_dim is None:
327
- self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=groups, eps=eps)
328
- self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=groups, eps=eps)
329
- else:
330
- self.norm1 = CogVideoXSpatialNorm3D(
331
- f_channels=in_channels,
332
- zq_channels=spatial_norm_dim,
333
- groups=groups,
334
- )
335
- self.norm2 = CogVideoXSpatialNorm3D(
336
- f_channels=out_channels,
337
- zq_channels=spatial_norm_dim,
338
- groups=groups,
339
- )
340
-
341
- self.conv1 = CogVideoXCausalConv3d(
342
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
343
- )
344
-
345
- if temb_channels > 0:
346
- self.temb_proj = nn.Linear(in_features=temb_channels, out_features=out_channels)
347
-
348
- self.dropout = nn.Dropout(dropout)
349
- self.conv2 = CogVideoXCausalConv3d(
350
- in_channels=out_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
351
- )
352
-
353
- if self.in_channels != self.out_channels:
354
- if self.use_conv_shortcut:
355
- self.conv_shortcut = CogVideoXCausalConv3d(
356
- in_channels=in_channels, out_channels=out_channels, kernel_size=3, pad_mode=pad_mode
357
- )
358
- else:
359
- self.conv_shortcut = CogVideoXSafeConv3d(
360
- in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=1, padding=0
361
- )
362
-
363
- def forward(
364
- self,
365
- inputs: torch.Tensor,
366
- temb: Optional[torch.Tensor] = None,
367
- zq: Optional[torch.Tensor] = None,
368
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
369
- ) -> torch.Tensor:
370
- new_conv_cache = {}
371
- conv_cache = conv_cache or {}
372
-
373
- hidden_states = inputs
374
-
375
- if zq is not None:
376
- hidden_states, new_conv_cache["norm1"] = self.norm1(hidden_states, zq, conv_cache=conv_cache.get("norm1"))
377
- else:
378
- hidden_states = self.norm1(hidden_states)
379
-
380
- hidden_states = self.nonlinearity(hidden_states)
381
- hidden_states, new_conv_cache["conv1"] = self.conv1(hidden_states, conv_cache=conv_cache.get("conv1"))
382
-
383
- if temb is not None:
384
- hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None, None]
385
-
386
- if zq is not None:
387
- hidden_states, new_conv_cache["norm2"] = self.norm2(hidden_states, zq, conv_cache=conv_cache.get("norm2"))
388
- else:
389
- hidden_states = self.norm2(hidden_states)
390
-
391
- hidden_states = self.nonlinearity(hidden_states)
392
- hidden_states = self.dropout(hidden_states)
393
- hidden_states, new_conv_cache["conv2"] = self.conv2(hidden_states, conv_cache=conv_cache.get("conv2"))
394
-
395
- if self.in_channels != self.out_channels:
396
- if self.use_conv_shortcut:
397
- inputs, new_conv_cache["conv_shortcut"] = self.conv_shortcut(
398
- inputs, conv_cache=conv_cache.get("conv_shortcut")
399
- )
400
- else:
401
- inputs = self.conv_shortcut(inputs)
402
-
403
- hidden_states = hidden_states + inputs
404
- return hidden_states, new_conv_cache
405
-
406
-
407
- class CogVideoXDownBlock3D(nn.Module):
408
- r"""
409
- A downsampling block used in the CogVideoX model.
410
-
411
- Args:
412
- in_channels (`int`):
413
- Number of input channels.
414
- out_channels (`int`, *optional*):
415
- Number of output channels. If None, defaults to `in_channels`.
416
- temb_channels (`int`, defaults to `512`):
417
- Number of time embedding channels.
418
- num_layers (`int`, defaults to `1`):
419
- Number of resnet layers.
420
- dropout (`float`, defaults to `0.0`):
421
- Dropout rate.
422
- resnet_eps (`float`, defaults to `1e-6`):
423
- Epsilon value for normalization layers.
424
- resnet_act_fn (`str`, defaults to `"swish"`):
425
- Activation function to use.
426
- resnet_groups (`int`, defaults to `32`):
427
- Number of groups to separate the channels into for group normalization.
428
- add_downsample (`bool`, defaults to `True`):
429
- Whether or not to use a downsampling layer. If not used, the output dimension will be the same as the input dimension.
430
- compress_time (`bool`, defaults to `False`):
431
- Whether or not to downsample across the temporal dimension.
432
- pad_mode (str, defaults to `"first"`):
433
- Padding mode.
434
- """
435
-
436
- _supports_gradient_checkpointing = True
437
-
438
- def __init__(
439
- self,
440
- in_channels: int,
441
- out_channels: int,
442
- temb_channels: int,
443
- dropout: float = 0.0,
444
- num_layers: int = 1,
445
- resnet_eps: float = 1e-6,
446
- resnet_act_fn: str = "swish",
447
- resnet_groups: int = 32,
448
- add_downsample: bool = True,
449
- downsample_padding: int = 0,
450
- compress_time: bool = False,
451
- pad_mode: str = "first",
452
- ):
453
- super().__init__()
454
-
455
- resnets = []
456
- for i in range(num_layers):
457
- in_channel = in_channels if i == 0 else out_channels
458
- resnets.append(
459
- CogVideoXResnetBlock3D(
460
- in_channels=in_channel,
461
- out_channels=out_channels,
462
- dropout=dropout,
463
- temb_channels=temb_channels,
464
- groups=resnet_groups,
465
- eps=resnet_eps,
466
- non_linearity=resnet_act_fn,
467
- pad_mode=pad_mode,
468
- )
469
- )
470
-
471
- self.resnets = nn.ModuleList(resnets)
472
- self.downsamplers = None
473
-
474
- if add_downsample:
475
- self.downsamplers = nn.ModuleList(
476
- [
477
- CogVideoXDownsample3D(
478
- out_channels, out_channels, padding=downsample_padding, compress_time=compress_time
479
- )
480
- ]
481
- )
482
-
483
- self.gradient_checkpointing = False
484
-
485
- def forward(
486
- self,
487
- hidden_states: torch.Tensor,
488
- temb: Optional[torch.Tensor] = None,
489
- zq: Optional[torch.Tensor] = None,
490
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
491
- ) -> torch.Tensor:
492
- r"""Forward method of the `CogVideoXDownBlock3D` class."""
493
-
494
- new_conv_cache = {}
495
- conv_cache = conv_cache or {}
496
-
497
- for i, resnet in enumerate(self.resnets):
498
- conv_cache_key = f"resnet_{i}"
499
-
500
- if torch.is_grad_enabled() and self.gradient_checkpointing:
501
-
502
- def create_custom_forward(module):
503
- def create_forward(*inputs):
504
- return module(*inputs)
505
-
506
- return create_forward
507
-
508
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
509
- create_custom_forward(resnet),
510
- hidden_states,
511
- temb,
512
- zq,
513
- conv_cache.get(conv_cache_key),
514
- )
515
- else:
516
- hidden_states, new_conv_cache[conv_cache_key] = resnet(
517
- hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
518
- )
519
-
520
- if self.downsamplers is not None:
521
- for downsampler in self.downsamplers:
522
- hidden_states = downsampler(hidden_states)
523
-
524
- return hidden_states, new_conv_cache
525
-
526
-
527
- class CogVideoXMidBlock3D(nn.Module):
528
- r"""
529
- A middle block used in the CogVideoX model.
530
-
531
- Args:
532
- in_channels (`int`):
533
- Number of input channels.
534
- temb_channels (`int`, defaults to `512`):
535
- Number of time embedding channels.
536
- dropout (`float`, defaults to `0.0`):
537
- Dropout rate.
538
- num_layers (`int`, defaults to `1`):
539
- Number of resnet layers.
540
- resnet_eps (`float`, defaults to `1e-6`):
541
- Epsilon value for normalization layers.
542
- resnet_act_fn (`str`, defaults to `"swish"`):
543
- Activation function to use.
544
- resnet_groups (`int`, defaults to `32`):
545
- Number of groups to separate the channels into for group normalization.
546
- spatial_norm_dim (`int`, *optional*):
547
- The dimension to use for spatial norm if it is to be used instead of group norm.
548
- pad_mode (str, defaults to `"first"`):
549
- Padding mode.
550
- """
551
-
552
- _supports_gradient_checkpointing = True
553
-
554
- def __init__(
555
- self,
556
- in_channels: int,
557
- temb_channels: int,
558
- dropout: float = 0.0,
559
- num_layers: int = 1,
560
- resnet_eps: float = 1e-6,
561
- resnet_act_fn: str = "swish",
562
- resnet_groups: int = 32,
563
- spatial_norm_dim: Optional[int] = None,
564
- pad_mode: str = "first",
565
- ):
566
- super().__init__()
567
-
568
- resnets = []
569
- for _ in range(num_layers):
570
- resnets.append(
571
- CogVideoXResnetBlock3D(
572
- in_channels=in_channels,
573
- out_channels=in_channels,
574
- dropout=dropout,
575
- temb_channels=temb_channels,
576
- groups=resnet_groups,
577
- eps=resnet_eps,
578
- spatial_norm_dim=spatial_norm_dim,
579
- non_linearity=resnet_act_fn,
580
- pad_mode=pad_mode,
581
- )
582
- )
583
- self.resnets = nn.ModuleList(resnets)
584
-
585
- self.gradient_checkpointing = False
586
-
587
- def forward(
588
- self,
589
- hidden_states: torch.Tensor,
590
- temb: Optional[torch.Tensor] = None,
591
- zq: Optional[torch.Tensor] = None,
592
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
593
- ) -> torch.Tensor:
594
- r"""Forward method of the `CogVideoXMidBlock3D` class."""
595
-
596
- new_conv_cache = {}
597
- conv_cache = conv_cache or {}
598
-
599
- for i, resnet in enumerate(self.resnets):
600
- conv_cache_key = f"resnet_{i}"
601
-
602
- if torch.is_grad_enabled() and self.gradient_checkpointing:
603
-
604
- def create_custom_forward(module):
605
- def create_forward(*inputs):
606
- return module(*inputs)
607
-
608
- return create_forward
609
-
610
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
611
- create_custom_forward(resnet), hidden_states, temb, zq, conv_cache.get(conv_cache_key)
612
- )
613
- else:
614
- hidden_states, new_conv_cache[conv_cache_key] = resnet(
615
- hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
616
- )
617
-
618
- return hidden_states, new_conv_cache
619
-
620
-
621
- class CogVideoXUpBlock3D(nn.Module):
622
- r"""
623
- An upsampling block used in the CogVideoX model.
624
-
625
- Args:
626
- in_channels (`int`):
627
- Number of input channels.
628
- out_channels (`int`, *optional*):
629
- Number of output channels. If None, defaults to `in_channels`.
630
- temb_channels (`int`, defaults to `512`):
631
- Number of time embedding channels.
632
- dropout (`float`, defaults to `0.0`):
633
- Dropout rate.
634
- num_layers (`int`, defaults to `1`):
635
- Number of resnet layers.
636
- resnet_eps (`float`, defaults to `1e-6`):
637
- Epsilon value for normalization layers.
638
- resnet_act_fn (`str`, defaults to `"swish"`):
639
- Activation function to use.
640
- resnet_groups (`int`, defaults to `32`):
641
- Number of groups to separate the channels into for group normalization.
642
- spatial_norm_dim (`int`, defaults to `16`):
643
- The dimension to use for spatial norm if it is to be used instead of group norm.
644
- add_upsample (`bool`, defaults to `True`):
645
- Whether or not to use an upsampling layer. If not used, the output dimension will be the same as the input dimension.
646
- compress_time (`bool`, defaults to `False`):
647
- Whether or not to upsample across the temporal dimension.
648
- pad_mode (str, defaults to `"first"`):
649
- Padding mode.
650
- """
651
-
652
- def __init__(
653
- self,
654
- in_channels: int,
655
- out_channels: int,
656
- temb_channels: int,
657
- dropout: float = 0.0,
658
- num_layers: int = 1,
659
- resnet_eps: float = 1e-6,
660
- resnet_act_fn: str = "swish",
661
- resnet_groups: int = 32,
662
- spatial_norm_dim: int = 16,
663
- add_upsample: bool = True,
664
- upsample_padding: int = 1,
665
- compress_time: bool = False,
666
- pad_mode: str = "first",
667
- ):
668
- super().__init__()
669
-
670
- resnets = []
671
- for i in range(num_layers):
672
- in_channel = in_channels if i == 0 else out_channels
673
- resnets.append(
674
- CogVideoXResnetBlock3D(
675
- in_channels=in_channel,
676
- out_channels=out_channels,
677
- dropout=dropout,
678
- temb_channels=temb_channels,
679
- groups=resnet_groups,
680
- eps=resnet_eps,
681
- non_linearity=resnet_act_fn,
682
- spatial_norm_dim=spatial_norm_dim,
683
- pad_mode=pad_mode,
684
- )
685
- )
686
-
687
- self.resnets = nn.ModuleList(resnets)
688
- self.upsamplers = None
689
-
690
- if add_upsample:
691
- self.upsamplers = nn.ModuleList(
692
- [
693
- CogVideoXUpsample3D(
694
- out_channels, out_channels, padding=upsample_padding, compress_time=compress_time
695
- )
696
- ]
697
- )
698
-
699
- self.gradient_checkpointing = False
700
-
701
- def forward(
702
- self,
703
- hidden_states: torch.Tensor,
704
- temb: Optional[torch.Tensor] = None,
705
- zq: Optional[torch.Tensor] = None,
706
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
707
- ) -> torch.Tensor:
708
- r"""Forward method of the `CogVideoXUpBlock3D` class."""
709
-
710
- new_conv_cache = {}
711
- conv_cache = conv_cache or {}
712
-
713
- for i, resnet in enumerate(self.resnets):
714
- conv_cache_key = f"resnet_{i}"
715
-
716
- if torch.is_grad_enabled() and self.gradient_checkpointing:
717
-
718
- def create_custom_forward(module):
719
- def create_forward(*inputs):
720
- return module(*inputs)
721
-
722
- return create_forward
723
-
724
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
725
- create_custom_forward(resnet),
726
- hidden_states,
727
- temb,
728
- zq,
729
- conv_cache.get(conv_cache_key),
730
- )
731
- else:
732
- hidden_states, new_conv_cache[conv_cache_key] = resnet(
733
- hidden_states, temb, zq, conv_cache=conv_cache.get(conv_cache_key)
734
- )
735
-
736
- if self.upsamplers is not None:
737
- for upsampler in self.upsamplers:
738
- hidden_states = upsampler(hidden_states)
739
-
740
- return hidden_states, new_conv_cache
741
-
742
-
743
- class CogVideoXEncoder3D(nn.Module):
744
- r"""
745
- The `CogVideoXEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
746
-
747
- Args:
748
- in_channels (`int`, *optional*, defaults to 3):
749
- The number of input channels.
750
- out_channels (`int`, *optional*, defaults to 3):
751
- The number of output channels.
752
- down_block_types (`Tuple[str, ...]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
753
- The types of down blocks to use. See `~diffusers.models.unet_2d_blocks.get_down_block` for available
754
- options.
755
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
756
- The number of output channels for each block.
757
- act_fn (`str`, *optional*, defaults to `"silu"`):
758
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
759
- layers_per_block (`int`, *optional*, defaults to 2):
760
- The number of layers per block.
761
- norm_num_groups (`int`, *optional*, defaults to 32):
762
- The number of groups for normalization.
763
- """
764
-
765
- _supports_gradient_checkpointing = True
766
-
767
- def __init__(
768
- self,
769
- in_channels: int = 3,
770
- out_channels: int = 16,
771
- down_block_types: Tuple[str, ...] = (
772
- "CogVideoXDownBlock3D",
773
- "CogVideoXDownBlock3D",
774
- "CogVideoXDownBlock3D",
775
- "CogVideoXDownBlock3D",
776
- ),
777
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
778
- layers_per_block: int = 3,
779
- act_fn: str = "silu",
780
- norm_eps: float = 1e-6,
781
- norm_num_groups: int = 32,
782
- dropout: float = 0.0,
783
- pad_mode: str = "first",
784
- temporal_compression_ratio: float = 4,
785
- ):
786
- super().__init__()
787
-
788
- # log2 of temporal_compress_times
789
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
790
-
791
- self.conv_in = CogVideoXCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, pad_mode=pad_mode)
792
- self.down_blocks = nn.ModuleList([])
793
-
794
- # down blocks
795
- output_channel = block_out_channels[0]
796
- for i, down_block_type in enumerate(down_block_types):
797
- input_channel = output_channel
798
- output_channel = block_out_channels[i]
799
- is_final_block = i == len(block_out_channels) - 1
800
- compress_time = i < temporal_compress_level
801
-
802
- if down_block_type == "CogVideoXDownBlock3D":
803
- down_block = CogVideoXDownBlock3D(
804
- in_channels=input_channel,
805
- out_channels=output_channel,
806
- temb_channels=0,
807
- dropout=dropout,
808
- num_layers=layers_per_block,
809
- resnet_eps=norm_eps,
810
- resnet_act_fn=act_fn,
811
- resnet_groups=norm_num_groups,
812
- add_downsample=not is_final_block,
813
- compress_time=compress_time,
814
- )
815
- else:
816
- raise ValueError("Invalid `down_block_type` encountered. Must be `CogVideoXDownBlock3D`")
817
-
818
- self.down_blocks.append(down_block)
819
-
820
- # mid block
821
- self.mid_block = CogVideoXMidBlock3D(
822
- in_channels=block_out_channels[-1],
823
- temb_channels=0,
824
- dropout=dropout,
825
- num_layers=2,
826
- resnet_eps=norm_eps,
827
- resnet_act_fn=act_fn,
828
- resnet_groups=norm_num_groups,
829
- pad_mode=pad_mode,
830
- )
831
-
832
- self.norm_out = nn.GroupNorm(norm_num_groups, block_out_channels[-1], eps=1e-6)
833
- self.conv_act = nn.SiLU()
834
- self.conv_out = CogVideoXCausalConv3d(
835
- block_out_channels[-1], 2 * out_channels, kernel_size=3, pad_mode=pad_mode
836
- )
837
-
838
- self.gradient_checkpointing = False
839
-
840
- def forward(
841
- self,
842
- sample: torch.Tensor,
843
- temb: Optional[torch.Tensor] = None,
844
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
845
- ) -> torch.Tensor:
846
- r"""The forward method of the `CogVideoXEncoder3D` class."""
847
-
848
- new_conv_cache = {}
849
- conv_cache = conv_cache or {}
850
-
851
- hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
852
-
853
- if torch.is_grad_enabled() and self.gradient_checkpointing:
854
-
855
- def create_custom_forward(module):
856
- def custom_forward(*inputs):
857
- return module(*inputs)
858
-
859
- return custom_forward
860
-
861
- # 1. Down
862
- for i, down_block in enumerate(self.down_blocks):
863
- conv_cache_key = f"down_block_{i}"
864
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
865
- create_custom_forward(down_block),
866
- hidden_states,
867
- temb,
868
- None,
869
- conv_cache.get(conv_cache_key),
870
- )
871
-
872
- # 2. Mid
873
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
874
- create_custom_forward(self.mid_block),
875
- hidden_states,
876
- temb,
877
- None,
878
- conv_cache.get("mid_block"),
879
- )
880
- else:
881
- # 1. Down
882
- for i, down_block in enumerate(self.down_blocks):
883
- conv_cache_key = f"down_block_{i}"
884
- hidden_states, new_conv_cache[conv_cache_key] = down_block(
885
- hidden_states, temb, None, conv_cache=conv_cache.get(conv_cache_key)
886
- )
887
-
888
- # 2. Mid
889
- hidden_states, new_conv_cache["mid_block"] = self.mid_block(
890
- hidden_states, temb, None, conv_cache=conv_cache.get("mid_block")
891
- )
892
-
893
- # 3. Post-process
894
- hidden_states = self.norm_out(hidden_states)
895
- hidden_states = self.conv_act(hidden_states)
896
-
897
- hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
898
-
899
- return hidden_states, new_conv_cache
900
-
901
-
902
- class CogVideoXDecoder3D(nn.Module):
903
- r"""
904
- The `CogVideoXDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
905
- sample.
906
-
907
- Args:
908
- in_channels (`int`, *optional*, defaults to 3):
909
- The number of input channels.
910
- out_channels (`int`, *optional*, defaults to 3):
911
- The number of output channels.
912
- up_block_types (`Tuple[str, ...]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
913
- The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
914
- block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
915
- The number of output channels for each block.
916
- act_fn (`str`, *optional*, defaults to `"silu"`):
917
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
918
- layers_per_block (`int`, *optional*, defaults to 2):
919
- The number of layers per block.
920
- norm_num_groups (`int`, *optional*, defaults to 32):
921
- The number of groups for normalization.
922
- """
923
-
924
- _supports_gradient_checkpointing = True
925
-
926
- def __init__(
927
- self,
928
- in_channels: int = 16,
929
- out_channels: int = 3,
930
- up_block_types: Tuple[str, ...] = (
931
- "CogVideoXUpBlock3D",
932
- "CogVideoXUpBlock3D",
933
- "CogVideoXUpBlock3D",
934
- "CogVideoXUpBlock3D",
935
- ),
936
- block_out_channels: Tuple[int, ...] = (128, 256, 256, 512),
937
- layers_per_block: int = 3,
938
- act_fn: str = "silu",
939
- norm_eps: float = 1e-6,
940
- norm_num_groups: int = 32,
941
- dropout: float = 0.0,
942
- pad_mode: str = "first",
943
- temporal_compression_ratio: float = 4,
944
- ):
945
- super().__init__()
946
-
947
- reversed_block_out_channels = list(reversed(block_out_channels))
948
-
949
- self.conv_in = CogVideoXCausalConv3d(
950
- in_channels, reversed_block_out_channels[0], kernel_size=3, pad_mode=pad_mode
951
- )
952
-
953
- # mid block
954
- self.mid_block = CogVideoXMidBlock3D(
955
- in_channels=reversed_block_out_channels[0],
956
- temb_channels=0,
957
- num_layers=2,
958
- resnet_eps=norm_eps,
959
- resnet_act_fn=act_fn,
960
- resnet_groups=norm_num_groups,
961
- spatial_norm_dim=in_channels,
962
- pad_mode=pad_mode,
963
- )
964
-
965
- # up blocks
966
- self.up_blocks = nn.ModuleList([])
967
-
968
- output_channel = reversed_block_out_channels[0]
969
- temporal_compress_level = int(np.log2(temporal_compression_ratio))
970
-
971
- for i, up_block_type in enumerate(up_block_types):
972
- prev_output_channel = output_channel
973
- output_channel = reversed_block_out_channels[i]
974
- is_final_block = i == len(block_out_channels) - 1
975
- compress_time = i < temporal_compress_level
976
-
977
- if up_block_type == "CogVideoXUpBlock3D":
978
- up_block = CogVideoXUpBlock3D(
979
- in_channels=prev_output_channel,
980
- out_channels=output_channel,
981
- temb_channels=0,
982
- dropout=dropout,
983
- num_layers=layers_per_block + 1,
984
- resnet_eps=norm_eps,
985
- resnet_act_fn=act_fn,
986
- resnet_groups=norm_num_groups,
987
- spatial_norm_dim=in_channels,
988
- add_upsample=not is_final_block,
989
- compress_time=compress_time,
990
- pad_mode=pad_mode,
991
- )
992
- prev_output_channel = output_channel
993
- else:
994
- raise ValueError("Invalid `up_block_type` encountered. Must be `CogVideoXUpBlock3D`")
995
-
996
- self.up_blocks.append(up_block)
997
-
998
- self.norm_out = CogVideoXSpatialNorm3D(reversed_block_out_channels[-1], in_channels, groups=norm_num_groups)
999
- self.conv_act = nn.SiLU()
1000
- self.conv_out = CogVideoXCausalConv3d(
1001
- reversed_block_out_channels[-1], out_channels, kernel_size=3, pad_mode=pad_mode
1002
- )
1003
-
1004
- self.gradient_checkpointing = False
1005
-
1006
- def forward(
1007
- self,
1008
- sample: torch.Tensor,
1009
- temb: Optional[torch.Tensor] = None,
1010
- conv_cache: Optional[Dict[str, torch.Tensor]] = None,
1011
- ) -> torch.Tensor:
1012
- r"""The forward method of the `CogVideoXDecoder3D` class."""
1013
-
1014
- new_conv_cache = {}
1015
- conv_cache = conv_cache or {}
1016
-
1017
- hidden_states, new_conv_cache["conv_in"] = self.conv_in(sample, conv_cache=conv_cache.get("conv_in"))
1018
-
1019
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1020
-
1021
- def create_custom_forward(module):
1022
- def custom_forward(*inputs):
1023
- return module(*inputs)
1024
-
1025
- return custom_forward
1026
-
1027
- # 1. Mid
1028
- hidden_states, new_conv_cache["mid_block"] = torch.utils.checkpoint.checkpoint(
1029
- create_custom_forward(self.mid_block),
1030
- hidden_states,
1031
- temb,
1032
- sample,
1033
- conv_cache.get("mid_block"),
1034
- )
1035
-
1036
- # 2. Up
1037
- for i, up_block in enumerate(self.up_blocks):
1038
- conv_cache_key = f"up_block_{i}"
1039
- hidden_states, new_conv_cache[conv_cache_key] = torch.utils.checkpoint.checkpoint(
1040
- create_custom_forward(up_block),
1041
- hidden_states,
1042
- temb,
1043
- sample,
1044
- conv_cache.get(conv_cache_key),
1045
- )
1046
- else:
1047
- # 1. Mid
1048
- hidden_states, new_conv_cache["mid_block"] = self.mid_block(
1049
- hidden_states, temb, sample, conv_cache=conv_cache.get("mid_block")
1050
- )
1051
-
1052
- # 2. Up
1053
- for i, up_block in enumerate(self.up_blocks):
1054
- conv_cache_key = f"up_block_{i}"
1055
- hidden_states, new_conv_cache[conv_cache_key] = up_block(
1056
- hidden_states, temb, sample, conv_cache=conv_cache.get(conv_cache_key)
1057
- )
1058
-
1059
- # 3. Post-process
1060
- hidden_states, new_conv_cache["norm_out"] = self.norm_out(
1061
- hidden_states, sample, conv_cache=conv_cache.get("norm_out")
1062
- )
1063
- hidden_states = self.conv_act(hidden_states)
1064
- hidden_states, new_conv_cache["conv_out"] = self.conv_out(hidden_states, conv_cache=conv_cache.get("conv_out"))
1065
-
1066
- return hidden_states, new_conv_cache
1067
-
1068
-
1069
- class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
1070
- r"""
1071
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
1072
- [CogVideoX](https://github.com/THUDM/CogVideo).
1073
-
1074
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
1075
- for all models (such as downloading or saving).
1076
-
1077
- Parameters:
1078
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
1079
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
1080
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
1081
- Tuple of downsample block types.
1082
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
1083
- Tuple of upsample block types.
1084
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
1085
- Tuple of block output channels.
1086
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
1087
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
1088
- scaling_factor (`float`, *optional*, defaults to `1.15258426`):
1089
- The component-wise standard deviation of the trained latent space computed using the first batch of the
1090
- training set. This is used to scale the latent space to have unit variance when training the diffusion
1091
- model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
1092
- diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
1093
- / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
1094
- Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
1095
- force_upcast (`bool`, *optional*, default to `True`):
1096
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
1097
- can be fine-tuned / trained to a lower range without losing too much precision, in which case
1098
- `force_upcast` can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
1099
- """
1100
-
1101
- _supports_gradient_checkpointing = True
1102
- _no_split_modules = ["CogVideoXResnetBlock3D"]
1103
-
1104
- @register_to_config
1105
- def __init__(
1106
- self,
1107
- in_channels: int = 3,
1108
- out_channels: int = 3,
1109
- down_block_types: Tuple[str] = (
1110
- "CogVideoXDownBlock3D",
1111
- "CogVideoXDownBlock3D",
1112
- "CogVideoXDownBlock3D",
1113
- "CogVideoXDownBlock3D",
1114
- ),
1115
- up_block_types: Tuple[str] = (
1116
- "CogVideoXUpBlock3D",
1117
- "CogVideoXUpBlock3D",
1118
- "CogVideoXUpBlock3D",
1119
- "CogVideoXUpBlock3D",
1120
- ),
1121
- block_out_channels: Tuple[int] = (128, 256, 256, 512),
1122
- latent_channels: int = 16,
1123
- layers_per_block: int = 3,
1124
- act_fn: str = "silu",
1125
- norm_eps: float = 1e-6,
1126
- norm_num_groups: int = 32,
1127
- temporal_compression_ratio: float = 4,
1128
- sample_height: int = 480,
1129
- sample_width: int = 720,
1130
- scaling_factor: float = 1.15258426,
1131
- shift_factor: Optional[float] = None,
1132
- latents_mean: Optional[Tuple[float]] = None,
1133
- latents_std: Optional[Tuple[float]] = None,
1134
- force_upcast: float = True,
1135
- use_quant_conv: bool = False,
1136
- use_post_quant_conv: bool = False,
1137
- invert_scale_latents: bool = False,
1138
- ):
1139
- super().__init__()
1140
-
1141
- self.encoder = CogVideoXEncoder3D(
1142
- in_channels=in_channels,
1143
- out_channels=latent_channels,
1144
- down_block_types=down_block_types,
1145
- block_out_channels=block_out_channels,
1146
- layers_per_block=layers_per_block,
1147
- act_fn=act_fn,
1148
- norm_eps=norm_eps,
1149
- norm_num_groups=norm_num_groups,
1150
- temporal_compression_ratio=temporal_compression_ratio,
1151
- )
1152
- self.decoder = CogVideoXDecoder3D(
1153
- in_channels=latent_channels,
1154
- out_channels=out_channels,
1155
- up_block_types=up_block_types,
1156
- block_out_channels=block_out_channels,
1157
- layers_per_block=layers_per_block,
1158
- act_fn=act_fn,
1159
- norm_eps=norm_eps,
1160
- norm_num_groups=norm_num_groups,
1161
- temporal_compression_ratio=temporal_compression_ratio,
1162
- )
1163
- self.quant_conv = CogVideoXSafeConv3d(2 * out_channels, 2 * out_channels, 1) if use_quant_conv else None
1164
- self.post_quant_conv = CogVideoXSafeConv3d(out_channels, out_channels, 1) if use_post_quant_conv else None
1165
-
1166
- self.use_slicing = False
1167
- self.use_tiling = False
1168
- self.auto_split_process = False
1169
-
1170
- # Can be increased to decode more latent frames at once, but comes at a reasonable memory cost and it is not
1171
- # recommended because the temporal parts of the VAE, here, are tricky to understand.
1172
- # If you decode X latent frames together, the number of output frames is:
1173
- # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
1174
- #
1175
- # Example with num_latent_frames_batch_size = 2:
1176
- # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
1177
- # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1178
- # => 6 * 8 = 48 frames
1179
- # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
1180
- # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
1181
- # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
1182
- # => 1 * 9 + 5 * 8 = 49 frames
1183
- # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
1184
- # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
1185
- # number of temporal frames.
1186
- self.num_latent_frames_batch_size = 2
1187
- self.num_sample_frames_batch_size = 8
1188
-
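Restating the arithmetic from the comment above as a tiny helper; the +6 offset per decoded slice is taken directly from that comment, not re-derived here.

def decoded_frames(num_latent_frames: int, batch: int = 2) -> int:
    # The first slice absorbs any odd latent frame, remaining slices hold `batch` latents,
    # and each slice of X latents decodes to X + 6 output frames (per the comment above).
    first = batch + num_latent_frames % batch
    rest = (num_latent_frames - first) // batch
    return (first + 6) + rest * (batch + 6)

print(decoded_frames(12))   # 48 -> six slices of 2 latents, 8 frames each
print(decoded_frames(13))   # 49 -> one slice of 3 latents (9 frames) plus five slices of 2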
1189
- # We make the minimum height and width of sample for tiling half that of the generally supported
1190
- self.tile_sample_min_height = sample_height // 2
1191
- self.tile_sample_min_width = sample_width // 2
1192
- self.tile_latent_min_height = int(
1193
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1194
- )
1195
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1196
-
1197
- # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
1198
- # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1199
- # and so the tiling implementation has only been tested on those specific resolutions.
1200
- self.tile_overlap_factor_height = 1 / 6
1201
- self.tile_overlap_factor_width = 1 / 5
1202
-
1203
- def _set_gradient_checkpointing(self, module, value=False):
1204
- if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1205
- module.gradient_checkpointing = value
1206
-
1207
- def enable_tiling(
1208
- self,
1209
- tile_sample_min_height: Optional[int] = None,
1210
- tile_sample_min_width: Optional[int] = None,
1211
- tile_overlap_factor_height: Optional[float] = None,
1212
- tile_overlap_factor_width: Optional[float] = None,
1213
- ) -> None:
1214
- r"""
1215
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1216
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1217
- processing larger images.
1218
-
1219
- Args:
1220
- tile_sample_min_height (`int`, *optional*):
1221
- The minimum height required for a sample to be separated into tiles across the height dimension.
1222
- tile_sample_min_width (`int`, *optional*):
1223
- The minimum width required for a sample to be separated into tiles across the width dimension.
1224
- tile_overlap_factor_height (`int`, *optional*):
1225
- The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1226
- no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1227
- value might cause more tiles to be processed leading to slow down of the decoding process.
1228
- tile_overlap_factor_width (`int`, *optional*):
1229
- The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1230
- are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1231
- value might cause more tiles to be processed leading to slow down of the decoding process.
1232
- """
1233
- self.use_tiling = True
1234
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1235
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1236
- self.tile_latent_min_height = int(
1237
- self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1238
- )
1239
- self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1240
- self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1241
- self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1242
-
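A quick numeric check of the tile sizes these defaults imply (720x480 samples, four entries in block_out_channels); purely illustrative, no new behaviour.

block_out_channels = (128, 256, 256, 512)
sample_height, sample_width = 480, 720

tile_sample_min_height = sample_height // 2        # 240
tile_sample_min_width = sample_width // 2           # 360
down_factor = 2 ** (len(block_out_channels) - 1)     # 8x spatial compression
print(tile_sample_min_height // down_factor)         # 30 latent rows per tile
print(tile_sample_min_width // down_factor)          # 45 latent columns per tile
# Overlap: 1/6 of the tile height and 1/5 of the tile width, i.e. 40 and 72 pixels
# in sample space, tuned (per the comment above) for 720x480 generation.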
1243
- def disable_tiling(self) -> None:
1244
- r"""
1245
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1246
- decoding in one step.
1247
- """
1248
- self.use_tiling = False
1249
-
1250
- def enable_slicing(self) -> None:
1251
- r"""
1252
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1253
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1254
- """
1255
- self.use_slicing = True
1256
-
1257
- def disable_slicing(self) -> None:
1258
- r"""
1259
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1260
- decoding in one step.
1261
- """
1262
- self.use_slicing = False
1263
-
1264
- def _set_first_frame(self):
1265
- for name, module in self.named_modules():
1266
- if isinstance(module, CogVideoXUpsample3D):
1267
- module.auto_split_process = False
1268
- module.first_frame_flag = True
1269
-
1270
- def _set_rest_frame(self):
1271
- for name, module in self.named_modules():
1272
- if isinstance(module, CogVideoXUpsample3D):
1273
- module.auto_split_process = False
1274
- module.first_frame_flag = False
1275
-
1276
- def enable_auto_split_process(self) -> None:
1277
- self.auto_split_process = True
1278
- for name, module in self.named_modules():
1279
- if isinstance(module, CogVideoXUpsample3D):
1280
- module.auto_split_process = True
1281
-
1282
- def disable_auto_split_process(self) -> None:
1283
- self.auto_split_process = False
1284
-
1285
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
1286
- batch_size, num_channels, num_frames, height, width = x.shape
1287
-
1288
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
1289
- return self.tiled_encode(x)
1290
-
1291
- frame_batch_size = self.num_sample_frames_batch_size
1292
- # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1293
- # As the extra single frame is handled inside the loop, it is not required to round up here.
1294
- num_batches = max(num_frames // frame_batch_size, 1)
1295
- conv_cache = None
1296
- enc = []
1297
-
1298
- for i in range(num_batches):
1299
- remaining_frames = num_frames % frame_batch_size
1300
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1301
- end_frame = frame_batch_size * (i + 1) + remaining_frames
1302
- x_intermediate = x[:, :, start_frame:end_frame]
1303
- x_intermediate, conv_cache = self.encoder(x_intermediate, conv_cache=conv_cache)
1304
- if self.quant_conv is not None:
1305
- x_intermediate = self.quant_conv(x_intermediate)
1306
- enc.append(x_intermediate)
1307
-
1308
- enc = torch.cat(enc, dim=2)
1309
- return enc
1310
-
1311
- @apply_forward_hook
1312
- def encode(
1313
- self, x: torch.Tensor, return_dict: bool = True
1314
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1315
- """
1316
- Encode a batch of images into latents.
1317
-
1318
- Args:
1319
- x (`torch.Tensor`): Input batch of images.
1320
- return_dict (`bool`, *optional*, defaults to `True`):
1321
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
1322
-
1323
- Returns:
1324
- The latent representations of the encoded videos. If `return_dict` is True, a
1325
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
1326
- """
1327
- if self.use_slicing and x.shape[0] > 1:
1328
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
1329
- h = torch.cat(encoded_slices)
1330
- else:
1331
- h = self._encode(x)
1332
-
1333
- posterior = DiagonalGaussianDistribution(h)
1334
-
1335
- if not return_dict:
1336
- return (posterior,)
1337
- return AutoencoderKLOutput(latent_dist=posterior)
1338
-
1339
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1340
- batch_size, num_channels, num_frames, height, width = z.shape
1341
-
1342
- if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1343
- return self.tiled_decode(z, return_dict=return_dict)
1344
-
1345
- if self.auto_split_process:
1346
- frame_batch_size = self.num_latent_frames_batch_size
1347
- num_batches = max(num_frames // frame_batch_size, 1)
1348
- conv_cache = None
1349
- dec = []
1350
-
1351
- for i in range(num_batches):
1352
- remaining_frames = num_frames % frame_batch_size
1353
- start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1354
- end_frame = frame_batch_size * (i + 1) + remaining_frames
1355
- z_intermediate = z[:, :, start_frame:end_frame]
1356
- if self.post_quant_conv is not None:
1357
- z_intermediate = self.post_quant_conv(z_intermediate)
1358
- z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1359
- dec.append(z_intermediate)
1360
- else:
1361
- conv_cache = None
1362
- start_frame = 0
1363
- end_frame = 1
1364
- dec = []
1365
-
1366
- self._set_first_frame()
1367
- z_intermediate = z[:, :, start_frame:end_frame]
1368
- if self.post_quant_conv is not None:
1369
- z_intermediate = self.post_quant_conv(z_intermediate)
1370
- z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1371
- dec.append(z_intermediate)
1372
-
1373
- self._set_rest_frame()
1374
- start_frame = end_frame
1375
- end_frame += self.num_latent_frames_batch_size
1376
-
1377
- while start_frame < num_frames:
1378
- z_intermediate = z[:, :, start_frame:end_frame]
1379
- if self.post_quant_conv is not None:
1380
- z_intermediate = self.post_quant_conv(z_intermediate)
1381
- z_intermediate, conv_cache = self.decoder(z_intermediate, conv_cache=conv_cache)
1382
- dec.append(z_intermediate)
1383
- start_frame = end_frame
1384
- end_frame += self.num_latent_frames_batch_size
1385
-
1386
- dec = torch.cat(dec, dim=2)
1387
-
1388
- if not return_dict:
1389
- return (dec,)
1390
-
1391
- return DecoderOutput(sample=dec)
1392
-
1393
- @apply_forward_hook
1394
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1395
- """
1396
- Decode a batch of images.
1397
-
1398
- Args:
1399
- z (`torch.Tensor`): Input batch of latent vectors.
1400
- return_dict (`bool`, *optional*, defaults to `True`):
1401
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1402
-
1403
- Returns:
1404
- [`~models.vae.DecoderOutput`] or `tuple`:
1405
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1406
- returned.
1407
- """
1408
- if self.use_slicing and z.shape[0] > 1:
1409
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1410
- decoded = torch.cat(decoded_slices)
1411
- else:
1412
- decoded = self._decode(z).sample
1413
-
1414
- if not return_dict:
1415
- return (decoded,)
1416
- return DecoderOutput(sample=decoded)
1417
-
1418
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1419
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1420
- for y in range(blend_extent):
1421
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1422
- y / blend_extent
1423
- )
1424
- return b
1425
-
1426
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1427
- blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1428
- for x in range(blend_extent):
1429
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1430
- x / blend_extent
1431
- )
1432
- return b
1433
-
1434
- def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
1435
- r"""Encode a batch of images using a tiled encoder.
1436
-
1437
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
1438
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
1439
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
1440
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
1441
- output, but they should be much less noticeable.
1442
-
1443
- Args:
1444
- x (`torch.Tensor`): Input batch of videos.
1445
-
1446
- Returns:
1447
- `torch.Tensor`:
1448
- The latent representation of the encoded videos.
1449
- """
1450
- # For a rough memory estimate, take a look at the `tiled_decode` method.
1451
- batch_size, num_channels, num_frames, height, width = x.shape
1452
-
1453
- overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
1454
- overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
1455
- blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
1456
- blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
1457
- row_limit_height = self.tile_latent_min_height - blend_extent_height
1458
- row_limit_width = self.tile_latent_min_width - blend_extent_width
1459
- frame_batch_size = self.num_sample_frames_batch_size
1460
-
1461
- # Split x into overlapping tiles and encode them separately.
1462
- # The tiles have an overlap to avoid seams between tiles.
1463
- rows = []
1464
- for i in range(0, height, overlap_height):
1465
- row = []
1466
- for j in range(0, width, overlap_width):
1467
- # Note: We expect the number of frames to be either `1` or `frame_batch_size * k` or `frame_batch_size * k + 1` for some k.
1468
- # As the extra single frame is handled inside the loop, it is not required to round up here.
1469
- num_batches = max(num_frames // frame_batch_size, 1)
1470
- conv_cache = None
1471
- time = []
1472
-
1473
- for k in range(num_batches):
1474
- remaining_frames = num_frames % frame_batch_size
1475
- start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1476
- end_frame = frame_batch_size * (k + 1) + remaining_frames
1477
- tile = x[
1478
- :,
1479
- :,
1480
- start_frame:end_frame,
1481
- i : i + self.tile_sample_min_height,
1482
- j : j + self.tile_sample_min_width,
1483
- ]
1484
- tile, conv_cache = self.encoder(tile, conv_cache=conv_cache)
1485
- if self.quant_conv is not None:
1486
- tile = self.quant_conv(tile)
1487
- time.append(tile)
1488
-
1489
- row.append(torch.cat(time, dim=2))
1490
- rows.append(row)
1491
-
1492
- result_rows = []
1493
- for i, row in enumerate(rows):
1494
- result_row = []
1495
- for j, tile in enumerate(row):
1496
- # blend the above tile and the left tile
1497
- # to the current tile and add the current tile to the result row
1498
- if i > 0:
1499
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1500
- if j > 0:
1501
- tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1502
- result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1503
- result_rows.append(torch.cat(result_row, dim=4))
1504
-
1505
- enc = torch.cat(result_rows, dim=3)
1506
- return enc
1507
-
1508
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1509
- r"""
1510
- Decode a batch of images using a tiled decoder.
1511
-
1512
- Args:
1513
- z (`torch.Tensor`): Input batch of latent vectors.
1514
- return_dict (`bool`, *optional*, defaults to `True`):
1515
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1516
-
1517
- Returns:
1518
- [`~models.vae.DecoderOutput`] or `tuple`:
1519
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1520
- returned.
1521
- """
1522
- # Rough memory assessment:
1523
- # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1524
- # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1525
- # - Assume fp16 (2 bytes per value).
1526
- # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1527
- #
1528
- # Memory assessment when using tiling:
1529
- # - Assume everything as above but now HxW is 240x360 by tiling in half
1530
- # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1531
-
1532
- batch_size, num_channels, num_frames, height, width = z.shape
1533
-
1534
- overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1535
- overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1536
- blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1537
- blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1538
- row_limit_height = self.tile_sample_min_height - blend_extent_height
1539
- row_limit_width = self.tile_sample_min_width - blend_extent_width
1540
- frame_batch_size = self.num_latent_frames_batch_size
1541
-
1542
- # Split z into overlapping tiles and decode them separately.
1543
- # The tiles have an overlap to avoid seams between tiles.
1544
- rows = []
1545
- for i in range(0, height, overlap_height):
1546
- row = []
1547
- for j in range(0, width, overlap_width):
1548
- if self.auto_split_process:
1549
- num_batches = max(num_frames // frame_batch_size, 1)
1550
- conv_cache = None
1551
- time = []
1552
-
1553
- for k in range(num_batches):
1554
- remaining_frames = num_frames % frame_batch_size
1555
- start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1556
- end_frame = frame_batch_size * (k + 1) + remaining_frames
1557
- tile = z[
1558
- :,
1559
- :,
1560
- start_frame:end_frame,
1561
- i : i + self.tile_latent_min_height,
1562
- j : j + self.tile_latent_min_width,
1563
- ]
1564
- if self.post_quant_conv is not None:
1565
- tile = self.post_quant_conv(tile)
1566
- tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1567
- time.append(tile)
1568
-
1569
- row.append(torch.cat(time, dim=2))
1570
- else:
1571
- conv_cache = None
1572
- start_frame = 0
1573
- end_frame = 1
1574
- dec = []
1575
-
1576
- tile = z[
1577
- :,
1578
- :,
1579
- start_frame:end_frame,
1580
- i : i + self.tile_latent_min_height,
1581
- j : j + self.tile_latent_min_width,
1582
- ]
1583
-
1584
- self._set_first_frame()
1585
- if self.post_quant_conv is not None:
1586
- tile = self.post_quant_conv(tile)
1587
- tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1588
- dec.append(tile)
1589
-
1590
- self._set_rest_frame()
1591
- start_frame = end_frame
1592
- end_frame += self.num_latent_frames_batch_size
1593
-
1594
- while start_frame < num_frames:
1595
- tile = z[
1596
- :,
1597
- :,
1598
- start_frame:end_frame,
1599
- i : i + self.tile_latent_min_height,
1600
- j : j + self.tile_latent_min_width,
1601
- ]
1602
- if self.post_quant_conv is not None:
1603
- tile = self.post_quant_conv(tile)
1604
- tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
1605
- dec.append(tile)
1606
- start_frame = end_frame
1607
- end_frame += self.num_latent_frames_batch_size
1608
-
1609
- row.append(torch.cat(dec, dim=2))
1610
- rows.append(row)
1611
-
1612
- result_rows = []
1613
- for i, row in enumerate(rows):
1614
- result_row = []
1615
- for j, tile in enumerate(row):
1616
- # blend the above tile and the left tile
1617
- # to the current tile and add the current tile to the result row
1618
- if i > 0:
1619
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1620
- if j > 0:
1621
- tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1622
- result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1623
- result_rows.append(torch.cat(result_row, dim=4))
1624
-
1625
- dec = torch.cat(result_rows, dim=3)
1626
-
1627
- if not return_dict:
1628
- return (dec,)
1629
-
1630
- return DecoderOutput(sample=dec)
1631
-
1632
- def forward(
1633
- self,
1634
- sample: torch.Tensor,
1635
- sample_posterior: bool = False,
1636
- return_dict: bool = True,
1637
- generator: Optional[torch.Generator] = None,
1638
- ) -> Union[DecoderOutput, Tuple[DecoderOutput]]:
1639
- x = sample
1640
- posterior = self.encode(x).latent_dist
1641
- if sample_posterior:
1642
- z = posterior.sample(generator=generator)
1643
- else:
1644
- z = posterior.mode()
1645
- dec = self.decode(z)
1646
- if not return_dict:
1647
- return (dec,)
1648
- return dec
1649
-
1650
- @classmethod
1651
- def from_pretrained(cls, pretrained_model_path, subfolder=None, **vae_additional_kwargs):
1652
- if subfolder is not None:
1653
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1654
-
1655
- config_file = os.path.join(pretrained_model_path, 'config.json')
1656
- if not os.path.isfile(config_file):
1657
- raise RuntimeError(f"{config_file} does not exist")
1658
- with open(config_file, "r") as f:
1659
- config = json.load(f)
1660
-
1661
- model = cls.from_config(config, **vae_additional_kwargs)
1662
- from diffusers.utils import WEIGHTS_NAME
1663
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1664
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
1665
- if os.path.exists(model_file_safetensors):
1666
- from safetensors.torch import load_file, safe_open
1667
- state_dict = load_file(model_file_safetensors)
1668
- else:
1669
- if not os.path.isfile(model_file):
1670
- raise RuntimeError(f"{model_file} does not exist")
1671
- state_dict = torch.load(model_file, map_location="cpu")
1672
- m, u = model.load_state_dict(state_dict, strict=False)
1673
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1674
- print(m, u)
1675
- return model
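
For reference, the tiled decoding path in the deleted cogvideox_vae.py blends neighbouring tiles linearly across their overlap (blend_h along the width, blend_v along the height) before cropping each tile to its row/column limit. Below is a minimal, self-contained sketch of that horizontal blend, assuming 5D tensors shaped [B, C, T, H, W]; the helper name is illustrative and not part of the deleted module:

import torch

def blend_horizontal(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # Fade linearly from the right edge of tile `a` into the left edge of tile `b`
    # over `blend_extent` columns, mirroring the blend_h logic above.
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - w) + b[:, :, :, :, x] * w
    return b

left = torch.zeros(1, 3, 2, 8, 8)   # stand-in for an already-decoded left tile
right = torch.ones(1, 3, 2, 8, 8)   # stand-in for its right-hand neighbour
blended = blend_horizontal(left, right, blend_extent=4)
print(blended[0, 0, 0, 0, :4])      # tensor([0.0000, 0.2500, 0.5000, 0.7500])

The vertical blend works the same way, only over the height dimension (dim 3).
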
videox_fun/models/fantasytalking_audio_encoder.py DELETED
@@ -1,52 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import math
4
-
5
- import librosa
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from diffusers.configuration_utils import ConfigMixin
10
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
- from diffusers.models.modeling_utils import ModelMixin
12
- from transformers import Wav2Vec2Model, Wav2Vec2Processor
13
-
14
-
15
- class FantasyTalkingAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
16
- def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
17
- super(FantasyTalkingAudioEncoder, self).__init__()
18
- # load pretrained model
19
- self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
20
- self.model = Wav2Vec2Model.from_pretrained(pretrained_model_path)
21
- self.model = self.model.to(device)
22
-
23
- def extract_audio_feat(self, audio_path, num_frames = 81, fps = 16, sr = 16000):
24
- audio_input, sample_rate = librosa.load(audio_path, sr=sr)
25
-
26
- start_time = 0
27
- end_time = num_frames / fps
28
-
29
- start_sample = int(start_time * sr)
30
- end_sample = int(end_time * sr)
31
-
32
- try:
33
- audio_segment = audio_input[start_sample:end_sample]
34
- except Exception:
35
- audio_segment = audio_input
36
-
37
- input_values = self.processor(
38
- audio_segment, sampling_rate=sample_rate, return_tensors="pt"
39
- ).input_values.to(self.model.device, self.model.dtype)
40
-
41
- with torch.no_grad():
42
- fea = self.model(input_values).last_hidden_state
43
- return fea
44
-
45
- def extract_audio_feat_without_file_load(self, audio_segment, sample_rate):
46
- input_values = self.processor(
47
- audio_segment, sampling_rate=sample_rate, return_tensors="pt"
48
- ).input_values.to(self.model.device, self.model.dtype)
49
-
50
- with torch.no_grad():
51
- fea = self.model(input_values).last_hidden_state
52
- return fea
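
The deleted encoder above is a thin wrapper around Wav2Vec2 from transformers. A hedged, standalone sketch of the same feature-extraction flow (the checkpoint name and the 81-frame / 16-fps defaults come from the code above; "audio.wav" is only a placeholder path):

import librosa
import torch
from transformers import Wav2Vec2Model, Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

num_frames, fps, sr = 81, 16, 16000
audio, _ = librosa.load("audio.wav", sr=sr)        # placeholder file, resampled to 16 kHz
segment = audio[: int(num_frames / fps * sr)]      # keep only the audio that covers the clip

inputs = processor(segment, sampling_rate=sr, return_tensors="pt").input_values
with torch.no_grad():
    features = model(inputs).last_hidden_state     # [1, L, 768] wav2vec2 hidden states
print(features.shape)
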
videox_fun/models/fantasytalking_transformer3d.py DELETED
@@ -1,644 +0,0 @@
1
- # Modified from https://github.com/Fantasy-AMAP/fantasy-talking/blob/main/diffsynth/models
2
- # Copyright Alibaba Inc. All Rights Reserved.
3
- import math
4
- import os
5
- from typing import Any, Dict
6
-
7
- import numpy as np
8
- import torch
9
- import torch.cuda.amp as amp
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from diffusers.configuration_utils import register_to_config
13
- from diffusers.utils import is_torch_version
14
-
15
- from ..dist import sequence_parallel_all_gather, sequence_parallel_chunk
16
- from ..utils import cfg_skip
17
- from .attention_utils import attention
18
- from .wan_transformer3d import (WanAttentionBlock, WanLayerNorm, WanRMSNorm,
19
- WanSelfAttention, WanTransformer3DModel,
20
- sinusoidal_embedding_1d)
21
-
22
-
23
- class AudioProjModel(nn.Module):
24
- def __init__(self, audio_in_dim=1024, cross_attention_dim=1024):
25
- super().__init__()
26
- self.cross_attention_dim = cross_attention_dim
27
- self.proj = torch.nn.Linear(audio_in_dim, cross_attention_dim, bias=False)
28
- self.norm = torch.nn.LayerNorm(cross_attention_dim)
29
-
30
- def forward(self, audio_embeds):
31
- context_tokens = self.proj(audio_embeds)
32
- context_tokens = self.norm(context_tokens)
33
- return context_tokens # [B,L,C]
34
-
35
-
36
- class AudioCrossAttentionProcessor(nn.Module):
37
- def __init__(self, context_dim, hidden_dim):
38
- super().__init__()
39
-
40
- self.context_dim = context_dim
41
- self.hidden_dim = hidden_dim
42
-
43
- self.k_proj = nn.Linear(context_dim, hidden_dim, bias=False)
44
- self.v_proj = nn.Linear(context_dim, hidden_dim, bias=False)
45
-
46
- nn.init.zeros_(self.k_proj.weight)
47
- nn.init.zeros_(self.v_proj.weight)
48
-
49
- self.sp_world_size = 1
50
- self.sp_world_rank = 0
51
- self.all_gather = None
52
-
53
- def __call__(
54
- self,
55
- attn: nn.Module,
56
- x: torch.Tensor,
57
- context: torch.Tensor,
58
- context_lens: torch.Tensor,
59
- audio_proj: torch.Tensor,
60
- audio_context_lens: torch.Tensor,
61
- latents_num_frames: int = 21,
62
- audio_scale: float = 1.0,
63
- ) -> torch.Tensor:
64
- """
65
- x: [B, L1, C].
66
- context: [B, L2, C].
67
- context_lens: [B].
68
- audio_proj: [B, 21, L3, C]
69
- audio_context_lens: [B*21].
70
- """
71
- context_img = context[:, :257]
72
- context = context[:, 257:]
73
- b, n, d = x.size(0), attn.num_heads, attn.head_dim
74
-
75
- # Compute query, key, value
76
- q = attn.norm_q(attn.q(x)).view(b, -1, n, d)
77
- k = attn.norm_k(attn.k(context)).view(b, -1, n, d)
78
- v = attn.v(context).view(b, -1, n, d)
79
- k_img = attn.norm_k_img(attn.k_img(context_img)).view(b, -1, n, d)
80
- v_img = attn.v_img(context_img).view(b, -1, n, d)
81
- img_x = attention(q, k_img, v_img, k_lens=None)
82
- # Compute attention
83
- x = attention(q, k, v, k_lens=context_lens)
84
- x = x.flatten(2)
85
- img_x = img_x.flatten(2)
86
-
87
- if len(audio_proj.shape) == 4:
88
- if self.sp_world_size > 1:
89
- q = self.all_gather(q, dim=1)
90
-
91
- length = int(np.floor(q.size()[1] / latents_num_frames) * latents_num_frames)
92
- origin_length = q.size()[1]
93
- if origin_length > length:
94
- q_pad = q[:, length:]
95
- q = q[:, :length]
96
- audio_q = q.view(b * latents_num_frames, -1, n, d) # [b, 21, l1, n, d]
97
- ip_key = self.k_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
98
- ip_value = self.v_proj(audio_proj).view(b * latents_num_frames, -1, n, d)
99
- audio_x = attention(
100
- audio_q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL"
101
- )
102
- audio_x = audio_x.view(b, q.size(1), n, d)
103
- if self.sp_world_size > 1:
104
- if origin_length > length:
105
- audio_x = torch.cat([audio_x, q_pad], dim=1)
106
- audio_x = torch.chunk(audio_x, self.sp_world_size, dim=1)[self.sp_world_rank]
107
- audio_x = audio_x.flatten(2)
108
- elif len(audio_proj.shape) == 3:
109
- ip_key = self.k_proj(audio_proj).view(b, -1, n, d)
110
- ip_value = self.v_proj(audio_proj).view(b, -1, n, d)
111
- audio_x = attention(q, ip_key, ip_value, k_lens=audio_context_lens, attention_type="NORMAL")
112
- audio_x = audio_x.flatten(2)
113
- # Output
114
- if isinstance(audio_scale, torch.Tensor):
115
- audio_scale = audio_scale[:, None, None]
116
-
117
- x = x + img_x + audio_x * audio_scale
118
- x = attn.o(x)
119
- # print(audio_scale)
120
- return x
121
-
122
-
123
- class AudioCrossAttention(WanSelfAttention):
124
- def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
125
- super().__init__(dim, num_heads, window_size, qk_norm, eps)
126
-
127
- self.k_img = nn.Linear(dim, dim)
128
- self.v_img = nn.Linear(dim, dim)
129
-
130
- self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
131
-
132
- self.processor = AudioCrossAttentionProcessor(2048, dim)
133
-
134
- def forward(
135
- self,
136
- x,
137
- context,
138
- context_lens,
139
- audio_proj,
140
- audio_context_lens,
141
- latents_num_frames,
142
- audio_scale: float = 1.0,
143
- **kwargs,
144
- ):
145
- """
146
- x: [B, L1, C].
147
- context: [B, L2, C].
148
- context_lens: [B].
149
- """
150
- if audio_proj is None:
151
- return self.processor(self, x, context, context_lens)
152
- else:
153
- return self.processor(
154
- self,
155
- x,
156
- context,
157
- context_lens,
158
- audio_proj,
159
- audio_context_lens,
160
- latents_num_frames,
161
- audio_scale,
162
- )
163
-
164
-
165
- class AudioAttentionBlock(nn.Module):
166
- def __init__(
167
- self,
168
- cross_attn_type, # Useless
169
- dim,
170
- ffn_dim,
171
- num_heads,
172
- window_size=(-1, -1),
173
- qk_norm=True,
174
- cross_attn_norm=False,
175
- eps=1e-6,
176
- ):
177
- super().__init__()
178
- self.dim = dim
179
- self.ffn_dim = ffn_dim
180
- self.num_heads = num_heads
181
- self.window_size = window_size
182
- self.qk_norm = qk_norm
183
- self.cross_attn_norm = cross_attn_norm
184
- self.eps = eps
185
-
186
- # Layers
187
- self.norm1 = WanLayerNorm(dim, eps)
188
- self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
189
- self.norm3 = (
190
- WanLayerNorm(dim, eps, elementwise_affine=True)
191
- if cross_attn_norm
192
- else nn.Identity()
193
- )
194
- self.cross_attn = AudioCrossAttention(
195
- dim, num_heads, (-1, -1), qk_norm, eps
196
- )
197
- self.norm2 = WanLayerNorm(dim, eps)
198
- self.ffn = nn.Sequential(
199
- nn.Linear(dim, ffn_dim),
200
- nn.GELU(approximate="tanh"),
201
- nn.Linear(ffn_dim, dim),
202
- )
203
-
204
- # Modulation
205
- self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
206
-
207
- def forward(
208
- self,
209
- x,
210
- e,
211
- seq_lens,
212
- grid_sizes,
213
- freqs,
214
- context,
215
- context_lens,
216
- audio_proj=None,
217
- audio_context_lens=None,
218
- audio_scale=1,
219
- dtype=torch.bfloat16,
220
- t=0,
221
- ):
222
- assert e.dtype == torch.float32
223
- with amp.autocast(dtype=torch.float32):
224
- e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
225
- assert e[0].dtype == torch.float32
226
-
227
- # self-attention
228
- y = self.self_attn(
229
- self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs, dtype, t=t
230
- )
231
- with amp.autocast(dtype=torch.float32):
232
- x = x + y * e[2]
233
-
234
- # Cross-attention & FFN function
235
- def cross_attn_ffn(x, context, context_lens, e):
236
- x = x + self.cross_attn(
237
- self.norm3(x), context, context_lens, dtype=dtype, t=t,
238
- audio_proj=audio_proj, audio_context_lens=audio_context_lens, audio_scale=audio_scale,
239
- latents_num_frames=grid_sizes[0][0],
240
- )
241
- y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
242
- with amp.autocast(dtype=torch.float32):
243
- x = x + y * e[5]
244
- return x
245
-
246
- x = cross_attn_ffn(x, context, context_lens, e)
247
- return x
248
-
249
-
250
- class FantasyTalkingTransformer3DModel(WanTransformer3DModel):
251
- @register_to_config
252
- def __init__(self,
253
- model_type='i2v',
254
- patch_size=(1, 2, 2),
255
- text_len=512,
256
- in_dim=16,
257
- dim=2048,
258
- ffn_dim=8192,
259
- freq_dim=256,
260
- text_dim=4096,
261
- out_dim=16,
262
- num_heads=16,
263
- num_layers=32,
264
- window_size=(-1, -1),
265
- qk_norm=True,
266
- cross_attn_norm=True,
267
- eps=1e-6,
268
- cross_attn_type=None,
269
- audio_in_dim=768):
270
- super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
271
- num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
272
-
273
- if cross_attn_type is None:
274
- cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
275
- self.blocks = nn.ModuleList([
276
- AudioAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
277
- window_size, qk_norm, cross_attn_norm, eps)
278
- for _ in range(num_layers)
279
- ])
280
- for layer_idx, block in enumerate(self.blocks):
281
- block.self_attn.layer_idx = layer_idx
282
- block.self_attn.num_layers = self.num_layers
283
-
284
- self.proj_model = AudioProjModel(audio_in_dim, 2048)
285
-
286
- def split_audio_sequence(self, audio_proj_length, num_frames=81):
287
- """
288
- Map the audio feature sequence to corresponding latent frame slices.
289
-
290
- Args:
291
- audio_proj_length (int): The total length of the audio feature sequence
292
- (e.g., 173 in audio_proj[1, 173, 768]).
293
- num_frames (int): The number of video frames in the training data (default: 81).
294
-
295
- Returns:
296
- list: A list of [start_idx, end_idx] pairs. Each pair represents the index range
297
- (within the audio feature sequence) corresponding to a latent frame.
298
- """
299
- # Average number of tokens per original video frame
300
- tokens_per_frame = audio_proj_length / num_frames
301
-
302
- # Each latent frame covers 4 video frames, and we want the center
303
- tokens_per_latent_frame = tokens_per_frame * 4
304
- half_tokens = int(tokens_per_latent_frame / 2)
305
-
306
- pos_indices = []
307
- for i in range(int((num_frames - 1) / 4) + 1):
308
- if i == 0:
309
- pos_indices.append(0)
310
- else:
311
- start_token = tokens_per_frame * ((i - 1) * 4 + 1)
312
- end_token = tokens_per_frame * (i * 4 + 1)
313
- center_token = int((start_token + end_token) / 2) - 1
314
- pos_indices.append(center_token)
315
-
316
- # Build index ranges centered around each position
317
- pos_idx_ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
318
-
319
- # Adjust the first range to avoid negative start index
320
- pos_idx_ranges[0] = [
321
- -(half_tokens * 2 - pos_idx_ranges[1][0]),
322
- pos_idx_ranges[1][0],
323
- ]
324
-
325
- return pos_idx_ranges
326
-
327
- def split_tensor_with_padding(self, input_tensor, pos_idx_ranges, expand_length=0):
328
- """
329
- Split the input tensor into subsequences based on index ranges, and apply right-side zero-padding
330
- if the range exceeds the input boundaries.
331
-
332
- Args:
333
- input_tensor (Tensor): Input audio tensor of shape [1, L, 768].
334
- pos_idx_ranges (list): A list of index ranges, e.g. [[-7, 1], [1, 9], ..., [165, 173]].
335
- expand_length (int): Number of tokens to expand on both sides of each subsequence.
336
-
337
- Returns:
338
- sub_sequences (Tensor): A tensor of shape [1, F, L, 768], where L is the length after padding.
339
- Each element is a padded subsequence.
340
- k_lens (Tensor): A tensor of shape [F], representing the actual (unpadded) length of each subsequence.
341
- Useful for ignoring padding tokens in attention masks.
342
- """
343
- pos_idx_ranges = [
344
- [idx[0] - expand_length, idx[1] + expand_length] for idx in pos_idx_ranges
345
- ]
346
- sub_sequences = []
347
- seq_len = input_tensor.size(1) # 173
348
- max_valid_idx = seq_len - 1 # 172
349
- k_lens_list = []
350
- for start, end in pos_idx_ranges:
351
- # Calculate the fill amount
352
- pad_front = max(-start, 0)
353
- pad_back = max(end - max_valid_idx, 0)
354
-
355
- # Calculate the start and end indices of the valid part
356
- valid_start = max(start, 0)
357
- valid_end = min(end, max_valid_idx)
358
-
359
- # Extract the valid part
360
- if valid_start <= valid_end:
361
- valid_part = input_tensor[:, valid_start : valid_end + 1, :]
362
- else:
363
- valid_part = input_tensor.new_zeros((1, 0, input_tensor.size(2)))
364
-
365
- # Pad along the sequence dimension (the 1st dimension)
366
- padded_subseq = F.pad(
367
- valid_part,
368
- (0, 0, 0, pad_back + pad_front, 0, 0),
369
- mode="constant",
370
- value=0,
371
- )
372
- k_lens_list.append(padded_subseq.size(-2) - pad_back - pad_front)
373
-
374
- sub_sequences.append(padded_subseq)
375
- return torch.stack(sub_sequences, dim=1), torch.tensor(
376
- k_lens_list, dtype=torch.long
377
- )
378
-
379
- def enable_multi_gpus_inference(self,):
380
- super().enable_multi_gpus_inference()
381
- for name, module in self.named_modules():
382
- if module.__class__.__name__ == 'AudioCrossAttentionProcessor':
383
- module.sp_world_size = self.sp_world_size
384
- module.sp_world_rank = self.sp_world_rank
385
- module.all_gather = self.all_gather
386
-
387
- @cfg_skip()
388
- def forward(
389
- self,
390
- x,
391
- t,
392
- context,
393
- seq_len,
394
- audio_wav2vec_fea=None,
395
- clip_fea=None,
396
- y=None,
397
- audio_scale=1,
398
- cond_flag=True
399
- ):
400
- r"""
401
- Forward pass through the diffusion model
402
-
403
- Args:
404
- x (List[Tensor]):
405
- List of input video tensors, each with shape [C_in, F, H, W]
406
- t (Tensor):
407
- Diffusion timesteps tensor of shape [B]
408
- context (List[Tensor]):
409
- List of text embeddings each with shape [L, C]
410
- seq_len (`int`):
411
- Maximum sequence length for positional encoding
412
- clip_fea (Tensor, *optional*):
413
- CLIP image features for image-to-video mode
414
- y (List[Tensor], *optional*):
415
- Conditional video inputs for image-to-video mode, same shape as x
416
-
417
- Returns:
418
- List[Tensor]:
419
- List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
420
- """
421
- # Wan2.2 does not need a CLIP image encoder.
422
- # if self.model_type == 'i2v':
423
- # assert clip_fea is not None and y is not None
424
- # params
425
- device = self.patch_embedding.weight.device
426
- dtype = x.dtype
427
- if self.freqs.device != device and torch.device(type="meta") != device:
428
- self.freqs = self.freqs.to(device)
429
-
430
- if y is not None:
431
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
432
-
433
- # embeddings
434
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
435
-
436
- grid_sizes = torch.stack(
437
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
438
-
439
- x = [u.flatten(2).transpose(1, 2) for u in x]
440
-
441
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
442
- if self.sp_world_size > 1:
443
- seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
444
- assert seq_lens.max() <= seq_len
445
- x = torch.cat([
446
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
447
- dim=1) for u in x
448
- ])
449
-
450
- # time embeddings
451
- with amp.autocast(dtype=torch.float32):
452
- if t.dim() != 1:
453
- if t.size(1) < seq_len:
454
- pad_size = seq_len - t.size(1)
455
- last_elements = t[:, -1].unsqueeze(1)
456
- padding = last_elements.repeat(1, pad_size)
457
- t = torch.cat([t, padding], dim=1)
458
- bt = t.size(0)
459
- ft = t.flatten()
460
- e = self.time_embedding(
461
- sinusoidal_embedding_1d(self.freq_dim,
462
- ft).unflatten(0, (bt, seq_len)).float())
463
- e0 = self.time_projection(e).unflatten(2, (6, self.dim))
464
- else:
465
- e = self.time_embedding(
466
- sinusoidal_embedding_1d(self.freq_dim, t).float())
467
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
468
-
469
- # assert e.dtype == torch.float32 and e0.dtype == torch.float32
470
- # e0 = e0.to(dtype)
471
- # e = e.to(dtype)
472
-
473
- # context
474
- context_lens = None
475
- context = self.text_embedding(
476
- torch.stack([
477
- torch.cat(
478
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
479
- for u in context
480
- ]))
481
-
482
- if clip_fea is not None:
483
- context_clip = self.img_emb(clip_fea) # bs x 257 x dim
484
- context = torch.concat([context_clip, context], dim=1)
485
-
486
- num_frames = (grid_sizes[0][0] - 1) * 4 + 1
487
- audio_proj_fea = self.proj_model(audio_wav2vec_fea)
488
- pos_idx_ranges = self.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
489
- audio_proj, audio_context_lens = self.split_tensor_with_padding(
490
- audio_proj_fea, pos_idx_ranges, expand_length=4
491
- )
492
-
493
- # Context Parallel
494
- if self.sp_world_size > 1:
495
- x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
496
- if t.dim() != 1:
497
- e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
498
- e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
499
-
500
- # TeaCache
501
- if self.teacache is not None:
502
- if cond_flag:
503
- if t.dim() != 1:
504
- modulated_inp = e0[:, -1, :]
505
- else:
506
- modulated_inp = e0
507
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
508
- if skip_flag:
509
- self.should_calc = True
510
- self.teacache.accumulated_rel_l1_distance = 0
511
- else:
512
- if cond_flag:
513
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
514
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
515
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
516
- self.should_calc = False
517
- else:
518
- self.should_calc = True
519
- self.teacache.accumulated_rel_l1_distance = 0
520
- self.teacache.previous_modulated_input = modulated_inp
521
- self.teacache.should_calc = self.should_calc
522
- else:
523
- self.should_calc = self.teacache.should_calc
524
-
525
- # TeaCache
526
- if self.teacache is not None:
527
- if not self.should_calc:
528
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
529
- x = x + previous_residual.to(x.device)[-x.size()[0]:,]
530
- else:
531
- ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
532
-
533
- for block in self.blocks:
534
- if torch.is_grad_enabled() and self.gradient_checkpointing:
535
-
536
- def create_custom_forward(module):
537
- def custom_forward(*inputs):
538
- return module(*inputs)
539
-
540
- return custom_forward
541
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
542
- x = torch.utils.checkpoint.checkpoint(
543
- create_custom_forward(block),
544
- x,
545
- e0,
546
- seq_lens,
547
- grid_sizes,
548
- self.freqs,
549
- context,
550
- context_lens,
551
- audio_proj,
552
- audio_context_lens,
553
- audio_scale,
554
- dtype,
555
- t,
556
- **ckpt_kwargs,
557
- )
558
- else:
559
- # arguments
560
- kwargs = dict(
561
- e=e0,
562
- seq_lens=seq_lens,
563
- grid_sizes=grid_sizes,
564
- freqs=self.freqs,
565
- context=context,
566
- context_lens=context_lens,
567
- audio_proj=audio_proj,
568
- audio_context_lens=audio_context_lens,
569
- audio_scale=audio_scale,
570
- dtype=dtype,
571
- t=t
572
- )
573
- x = block(x, **kwargs)
574
-
575
- if cond_flag:
576
- self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
577
- else:
578
- self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
579
- else:
580
- for block in self.blocks:
581
- if torch.is_grad_enabled() and self.gradient_checkpointing:
582
-
583
- def create_custom_forward(module):
584
- def custom_forward(*inputs):
585
- return module(*inputs)
586
-
587
- return custom_forward
588
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
589
- x = torch.utils.checkpoint.checkpoint(
590
- create_custom_forward(block),
591
- x,
592
- e0,
593
- seq_lens,
594
- grid_sizes,
595
- self.freqs,
596
- context,
597
- context_lens,
598
- audio_proj,
599
- audio_context_lens,
600
- audio_scale,
601
- dtype,
602
- t,
603
- **ckpt_kwargs,
604
- )
605
- else:
606
- # arguments
607
- kwargs = dict(
608
- e=e0,
609
- seq_lens=seq_lens,
610
- grid_sizes=grid_sizes,
611
- freqs=self.freqs,
612
- context=context,
613
- context_lens=context_lens,
614
- audio_proj=audio_proj,
615
- audio_context_lens=audio_context_lens,
616
- audio_scale=audio_scale,
617
- dtype=dtype,
618
- t=t
619
- )
620
- x = block(x, **kwargs)
621
-
622
- # head
623
- if torch.is_grad_enabled() and self.gradient_checkpointing:
624
- def create_custom_forward(module):
625
- def custom_forward(*inputs):
626
- return module(*inputs)
627
-
628
- return custom_forward
629
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
630
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
631
- else:
632
- x = self.head(x, e)
633
-
634
- if self.sp_world_size > 1:
635
- x = self.all_gather(x, dim=1)
636
-
637
- # Unpatchify
638
- x = self.unpatchify(x, grid_sizes)
639
- x = torch.stack(x)
640
- if self.teacache is not None and cond_flag:
641
- self.teacache.cnt += 1
642
- if self.teacache.cnt == self.teacache.num_steps:
643
- self.teacache.reset()
644
- return x
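
To make the audio-to-latent-frame mapping in split_audio_sequence easier to follow, here is a small stand-alone reproduction of the index computation, using the 173-token / 81-frame example from the docstring (illustrative only, not imported from the repository):

def split_audio_sequence(audio_proj_length: int, num_frames: int = 81):
    # Average number of audio tokens per original video frame.
    tokens_per_frame = audio_proj_length / num_frames
    # Each latent frame covers 4 video frames; take a window of that many tokens,
    # centred on the middle of the covered span.
    half_tokens = int(tokens_per_frame * 4 / 2)

    pos_indices = [0]
    for i in range(1, (num_frames - 1) // 4 + 1):
        start_token = tokens_per_frame * ((i - 1) * 4 + 1)
        end_token = tokens_per_frame * (i * 4 + 1)
        pos_indices.append(int((start_token + end_token) / 2) - 1)

    ranges = [[idx - half_tokens, idx + half_tokens] for idx in pos_indices]
    # Shift the first window so that it ends where the second one begins.
    ranges[0] = [ranges[1][0] - 2 * half_tokens, ranges[1][0]]
    return ranges

print(split_audio_sequence(173, 81)[:3])  # [[-7, 1], [1, 9], [9, 17]]

Each [start, end] pair is later padded by split_tensor_with_padding and attended to by the corresponding latent frame in AudioCrossAttentionProcessor.
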
videox_fun/models/flux2_image_processor.py DELETED
@@ -1,139 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/flux2/image_processor.py
2
- # Copyright 2025 The Black Forest Labs Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import math
17
- from typing import Tuple
18
-
19
- import PIL.Image
20
-
21
- from diffusers.configuration_utils import register_to_config
22
- from diffusers.image_processor import VaeImageProcessor
23
-
24
-
25
- class Flux2ImageProcessor(VaeImageProcessor):
26
- r"""
27
- Image processor to preprocess the reference (character) image for the Flux2 model.
28
-
29
- Args:
30
- do_resize (`bool`, *optional*, defaults to `True`):
31
- Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
32
- `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
33
- vae_scale_factor (`int`, *optional*, defaults to `16`):
34
- VAE (spatial) scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of
35
- this factor.
36
- vae_latent_channels (`int`, *optional*, defaults to `32`):
37
- VAE latent channels.
38
- do_normalize (`bool`, *optional*, defaults to `True`):
39
- Whether to normalize the image to [-1,1].
40
- do_convert_rgb (`bool`, *optional*, defaults to be `True`):
41
- Whether to convert the images to RGB format.
42
- """
43
-
44
- @register_to_config
45
- def __init__(
46
- self,
47
- do_resize: bool = True,
48
- vae_scale_factor: int = 16,
49
- vae_latent_channels: int = 32,
50
- do_normalize: bool = True,
51
- do_convert_rgb: bool = True,
52
- ):
53
- super().__init__(
54
- do_resize=do_resize,
55
- vae_scale_factor=vae_scale_factor,
56
- vae_latent_channels=vae_latent_channels,
57
- do_normalize=do_normalize,
58
- do_convert_rgb=do_convert_rgb,
59
- )
60
-
61
- @staticmethod
62
- def check_image_input(
63
- image: PIL.Image.Image, max_aspect_ratio: int = 8, min_side_length: int = 64, max_area: int = 1024 * 1024
64
- ) -> PIL.Image.Image:
65
- """
66
- Check if image meets minimum size and aspect ratio requirements.
67
-
68
- Args:
69
- image: PIL Image to validate
70
- max_aspect_ratio: Maximum allowed aspect ratio (width/height or height/width)
71
- min_side_length: Minimum pixels required for width and height
72
- max_area: Maximum allowed area in pixels²
73
-
74
- Returns:
75
- The input image if valid
76
-
77
- Raises:
78
- ValueError: If image is too small or aspect ratio is too extreme
79
- """
80
- if not isinstance(image, PIL.Image.Image):
81
- raise ValueError(f"Image must be a PIL.Image.Image, got {type(image)}")
82
-
83
- width, height = image.size
84
-
85
- # Check minimum dimensions
86
- if width < min_side_length or height < min_side_length:
87
- raise ValueError(
88
- f"Image too small: {width}×{height}. Both dimensions must be at least {min_side_length}px"
89
- )
90
-
91
- # Check aspect ratio
92
- aspect_ratio = max(width / height, height / width)
93
- if aspect_ratio > max_aspect_ratio:
94
- raise ValueError(
95
- f"Aspect ratio too extreme: {width}×{height} (ratio: {aspect_ratio:.1f}:1). "
96
- f"Maximum allowed ratio is {max_aspect_ratio}:1"
97
- )
98
-
99
- return image
100
-
101
- @staticmethod
102
- def _resize_to_target_area(image: PIL.Image.Image, target_area: int = 1024 * 1024) -> Tuple[int, int]:
103
- image_width, image_height = image.size
104
-
105
- scale = math.sqrt(target_area / (image_width * image_height))
106
- width = int(image_width * scale)
107
- height = int(image_height * scale)
108
-
109
- return image.resize((width, height), PIL.Image.Resampling.LANCZOS)
110
-
111
- def _resize_and_crop(
112
- self,
113
- image: PIL.Image.Image,
114
- width: int,
115
- height: int,
116
- ) -> PIL.Image.Image:
117
- r"""
118
- Center-crop the image to the specified width and height.
119
-
120
- Args:
121
- image (`PIL.Image.Image`):
122
- The image to resize and crop.
123
- width (`int`):
124
- The width to resize the image to.
125
- height (`int`):
126
- The height to resize the image to.
127
-
128
- Returns:
129
- `PIL.Image.Image`:
130
- The resized and cropped image.
131
- """
132
- image_width, image_height = image.size
133
-
134
- left = (image_width - width) // 2
135
- top = (image_height - height) // 2
136
- right = left + width
137
- bottom = top + height
138
-
139
- return image.crop((left, top, right, bottom))
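
The validation and resizing logic in the deleted Flux2ImageProcessor boils down to an aspect-ratio check plus an area-preserving rescale. A hedged sketch of those two steps on an arbitrary PIL image (the 1024*1024 target area and the 8:1 ratio limit mirror the defaults above; the blank Image.new input is only a stand-in):

import math

import PIL.Image

MAX_ASPECT_RATIO = 8
TARGET_AREA = 1024 * 1024

def prepare_reference_image(image: PIL.Image.Image) -> PIL.Image.Image:
    width, height = image.size
    # Reject extreme aspect ratios up front, as check_image_input does.
    if max(width / height, height / width) > MAX_ASPECT_RATIO:
        raise ValueError(f"Aspect ratio of {width}x{height} exceeds {MAX_ASPECT_RATIO}:1")
    # Rescale so the pixel area is roughly TARGET_AREA while keeping the aspect ratio.
    scale = math.sqrt(TARGET_AREA / (width * height))
    return image.resize((int(width * scale), int(height * scale)), PIL.Image.Resampling.LANCZOS)

img = PIL.Image.new("RGB", (1920, 1080))
print(prepare_reference_image(img).size)  # ~(1365, 768): area close to the 1024*1024 target
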
videox_fun/models/flux2_transformer2d.py DELETED
@@ -1,1278 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux2.py
2
- # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- import inspect
18
- import json
19
- import os
20
- from typing import Any, Dict, List, Optional, Tuple, Union
21
-
22
- import torch
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
- from diffusers.configuration_utils import ConfigMixin, register_to_config
26
- from diffusers.loaders import FromOriginalModelMixin
27
- from diffusers.models.attention_processor import Attention, AttentionProcessor
28
- from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
29
- apply_rotary_emb,
30
- get_1d_rotary_pos_embed)
31
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.normalization import AdaLayerNormContinuous
34
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
35
- is_torch_version, logging, scale_lora_layers,
36
- unscale_lora_layers)
37
-
38
- from ..dist import (Flux2MultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
39
- get_sequence_parallel_world_size, get_sp_group)
40
- from .attention_utils import attention
41
-
42
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
43
-
44
-
45
- def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
46
- query = attn.to_q(hidden_states)
47
- key = attn.to_k(hidden_states)
48
- value = attn.to_v(hidden_states)
49
-
50
- encoder_query = encoder_key = encoder_value = None
51
- if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
52
- encoder_query = attn.add_q_proj(encoder_hidden_states)
53
- encoder_key = attn.add_k_proj(encoder_hidden_states)
54
- encoder_value = attn.add_v_proj(encoder_hidden_states)
55
-
56
- return query, key, value, encoder_query, encoder_key, encoder_value
57
-
58
-
59
- def _get_qkv_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
60
- return _get_projections(attn, hidden_states, encoder_hidden_states)
61
-
62
-
63
- def apply_rotary_emb(
64
- x: torch.Tensor,
65
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
66
- use_real: bool = True,
67
- use_real_unbind_dim: int = -1,
68
- sequence_dim: int = 2,
69
- ) -> Tuple[torch.Tensor, torch.Tensor]:
70
- """
71
- Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
72
- to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
73
- reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
74
- tensors contain rotary embeddings and are returned as real tensors.
75
-
76
- Args:
77
- x (`torch.Tensor`):
78
- Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
79
- freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
80
-
81
- Returns:
82
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
83
- """
84
- if use_real:
85
- cos, sin = freqs_cis # [S, D]
86
- if sequence_dim == 2:
87
- cos = cos[None, None, :, :]
88
- sin = sin[None, None, :, :]
89
- elif sequence_dim == 1:
90
- cos = cos[None, :, None, :]
91
- sin = sin[None, :, None, :]
92
- else:
93
- raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
94
-
95
- cos, sin = cos.to(x.device), sin.to(x.device)
96
-
97
- if use_real_unbind_dim == -1:
98
- # Used for flux, cogvideox, hunyuan-dit
99
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
100
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
101
- elif use_real_unbind_dim == -2:
102
- # Used for Stable Audio, OmniGen, CogView4 and Cosmos
103
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
104
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
105
- else:
106
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
107
-
108
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
109
-
110
- return out
111
- else:
112
- # used for lumina
113
- x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
114
- freqs_cis = freqs_cis.unsqueeze(2)
115
- x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
116
-
117
- return x_out.type_as(x)
118
-
119
-
120
- class Flux2SwiGLU(nn.Module):
121
- """
122
- Flux 2 uses a SwiGLU-style activation in the transformer feedforward sub-blocks, but with the linear projection
123
- layer fused into the first linear layer of the FF sub-block. Thus, this module has no trainable parameters.
124
- """
125
-
126
- def __init__(self):
127
- super().__init__()
128
- self.gate_fn = nn.SiLU()
129
-
130
- def forward(self, x: torch.Tensor) -> torch.Tensor:
131
- x1, x2 = x.chunk(2, dim=-1)
132
- x = self.gate_fn(x1) * x2
133
- return x
134
-
135
-
136
- class Flux2FeedForward(nn.Module):
137
- def __init__(
138
- self,
139
- dim: int,
140
- dim_out: Optional[int] = None,
141
- mult: float = 3.0,
142
- inner_dim: Optional[int] = None,
143
- bias: bool = False,
144
- ):
145
- super().__init__()
146
- if inner_dim is None:
147
- inner_dim = int(dim * mult)
148
- dim_out = dim_out or dim
149
-
150
- # Flux2SwiGLU will reduce the dimension by half
151
- self.linear_in = nn.Linear(dim, inner_dim * 2, bias=bias)
152
- self.act_fn = Flux2SwiGLU()
153
- self.linear_out = nn.Linear(inner_dim, dim_out, bias=bias)
154
-
155
- def forward(self, x: torch.Tensor) -> torch.Tensor:
156
- x = self.linear_in(x)
157
- x = self.act_fn(x)
158
- x = self.linear_out(x)
159
- return x
160
-
161
-
162
- class Flux2AttnProcessor:
163
- _attention_backend = None
164
- _parallel_config = None
165
-
166
- def __init__(self):
167
- if not hasattr(F, "scaled_dot_product_attention"):
168
- raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
169
-
170
- def __call__(
171
- self,
172
- attn: Union["Flux2Attention", "Flux2ParallelSelfAttention"],
173
- hidden_states: torch.Tensor,
174
- encoder_hidden_states: Optional[torch.Tensor] = None,
175
- attention_mask: Optional[torch.Tensor] = None,
176
- image_rotary_emb: Optional[torch.Tensor] = None,
177
- text_seq_len: int = None,
178
- ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
179
- """
180
- Unified processor for both Flux2Attention and Flux2ParallelSelfAttention.
181
-
182
- Args:
183
- attn: Attention module (either Flux2Attention or Flux2ParallelSelfAttention)
184
- hidden_states: Input hidden states
185
- encoder_hidden_states: Optional encoder hidden states (only for Flux2Attention)
186
- attention_mask: Optional attention mask
187
- image_rotary_emb: Optional rotary embeddings
188
-
189
- Returns:
190
- For Flux2Attention with encoder_hidden_states: (hidden_states, encoder_hidden_states)
191
- For Flux2Attention without encoder_hidden_states: hidden_states
192
- For Flux2ParallelSelfAttention: hidden_states
193
- """
194
- # Determine which type of attention we're processing
195
- is_parallel_self_attn = hasattr(attn, 'to_qkv_mlp_proj') and attn.to_qkv_mlp_proj is not None
196
-
197
- if is_parallel_self_attn:
198
- # ============================================
199
- # Parallel Self-Attention Path (with MLP)
200
- # ============================================
201
- # Parallel in (QKV + MLP in) projection
202
- hidden_states = attn.to_qkv_mlp_proj(hidden_states)
203
- qkv, mlp_hidden_states = torch.split(
204
- hidden_states, [3 * attn.inner_dim, attn.mlp_hidden_dim * attn.mlp_mult_factor], dim=-1
205
- )
206
-
207
- # Handle the attention logic
208
- query, key, value = qkv.chunk(3, dim=-1)
209
-
210
- else:
211
- # ============================================
212
- # Standard Attention Path (possibly with encoder)
213
- # ============================================
214
- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
215
- attn, hidden_states, encoder_hidden_states
216
- )
217
-
218
- # Common processing for query, key, value
219
- query = query.unflatten(-1, (attn.heads, -1))
220
- key = key.unflatten(-1, (attn.heads, -1))
221
- value = value.unflatten(-1, (attn.heads, -1))
222
-
223
- query = attn.norm_q(query)
224
- key = attn.norm_k(key)
225
-
226
- # Handle encoder projections (only for standard attention)
227
- if not is_parallel_self_attn and attn.added_kv_proj_dim is not None:
228
- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
229
- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
230
- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
231
-
232
- encoder_query = attn.norm_added_q(encoder_query)
233
- encoder_key = attn.norm_added_k(encoder_key)
234
-
235
- query = torch.cat([encoder_query, query], dim=1)
236
- key = torch.cat([encoder_key, key], dim=1)
237
- value = torch.cat([encoder_value, value], dim=1)
238
-
239
- # Apply rotary embeddings
240
- if image_rotary_emb is not None:
241
- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
242
- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
243
-
244
- # Perform attention
245
- hidden_states = attention(
246
- query, key, value, attn_mask=attention_mask,
247
- )
248
- hidden_states = hidden_states.flatten(2, 3)
249
- hidden_states = hidden_states.to(query.dtype)
250
-
251
- if is_parallel_self_attn:
252
- # ============================================
253
- # Parallel Self-Attention Output Path
254
- # ============================================
255
- # Handle the feedforward (FF) logic
256
- mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)
257
-
258
- # Concatenate and parallel output projection
259
- hidden_states = torch.cat([hidden_states, mlp_hidden_states], dim=-1)
260
- hidden_states = attn.to_out(hidden_states)
261
-
262
- return hidden_states
263
-
264
- else:
265
- # ============================================
266
- # Standard Attention Output Path
267
- # ============================================
268
- # Split encoder and latent hidden states if encoder was used
269
- if encoder_hidden_states is not None:
270
- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
271
- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
272
- )
273
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
274
-
275
- # Project output
276
- hidden_states = attn.to_out[0](hidden_states)
277
- hidden_states = attn.to_out[1](hidden_states)
278
-
279
- if encoder_hidden_states is not None:
280
- return hidden_states, encoder_hidden_states
281
- else:
282
- return hidden_states
283
-
284
-
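The processor above always runs attention over a single concatenated sequence with the text (encoder) tokens first and the image tokens second, and splits the result back afterwards. A small sketch of that bookkeeping (shapes are illustrative):

import torch

text_len, img_len, dim = 12, 64, 8
joint = torch.randn(1, text_len + img_len, dim)   # post-attention [text | image] sequence
text_part, img_part = joint.split_with_sizes([text_len, img_len], dim=1)
assert text_part.shape[1] == text_len and img_part.shape[1] == img_len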
285
- class Flux2Attention(torch.nn.Module):
286
- _default_processor_cls = Flux2AttnProcessor
287
- _available_processors = [Flux2AttnProcessor]
288
-
289
- def __init__(
290
- self,
291
- query_dim: int,
292
- heads: int = 8,
293
- dim_head: int = 64,
294
- dropout: float = 0.0,
295
- bias: bool = False,
296
- added_kv_proj_dim: Optional[int] = None,
297
- added_proj_bias: Optional[bool] = True,
298
- out_bias: bool = True,
299
- eps: float = 1e-5,
300
- out_dim: int = None,
301
- elementwise_affine: bool = True,
302
- processor=None,
303
- ):
304
- super().__init__()
305
-
306
- self.head_dim = dim_head
307
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
308
- self.query_dim = query_dim
309
- self.out_dim = out_dim if out_dim is not None else query_dim
310
- self.heads = out_dim // dim_head if out_dim is not None else heads
311
-
312
- self.use_bias = bias
313
- self.dropout = dropout
314
-
315
- self.added_kv_proj_dim = added_kv_proj_dim
316
- self.added_proj_bias = added_proj_bias
317
-
318
- self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
319
- self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
320
- self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
321
-
322
- # QK Norm
323
- self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
324
- self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
325
-
326
- self.to_out = torch.nn.ModuleList([])
327
- self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
328
- self.to_out.append(torch.nn.Dropout(dropout))
329
-
330
- if added_kv_proj_dim is not None:
331
- self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
332
- self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
333
- self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
334
- self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
335
- self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
336
- self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
337
-
338
- if processor is None:
339
- processor = self._default_processor_cls()
340
- self.set_processor(processor)
341
-
342
- def set_processor(self, processor: AttentionProcessor) -> None:
343
- """
344
- Set the attention processor to use.
345
-
346
- Args:
347
- processor (`AttnProcessor`):
348
- The attention processor to use.
349
- """
350
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
351
- # pop `processor` from `self._modules`
352
- if (
353
- hasattr(self, "processor")
354
- and isinstance(self.processor, torch.nn.Module)
355
- and not isinstance(processor, torch.nn.Module)
356
- ):
357
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
358
- self._modules.pop("processor")
359
-
360
- self.processor = processor
361
-
362
- def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
363
- """
364
- Get the attention processor in use.
365
-
366
- Args:
367
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
368
- Set to `True` to return the deprecated LoRA attention processor.
369
-
370
- Returns:
371
- "AttentionProcessor": The attention processor in use.
372
- """
373
- if not return_deprecated_lora:
374
- return self.processor
375
-
376
- def forward(
377
- self,
378
- hidden_states: torch.Tensor,
379
- encoder_hidden_states: Optional[torch.Tensor] = None,
380
- attention_mask: Optional[torch.Tensor] = None,
381
- image_rotary_emb: Optional[torch.Tensor] = None,
382
- **kwargs,
383
- ) -> torch.Tensor:
384
- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
385
- unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
386
- if len(unused_kwargs) > 0:
387
- logger.warning(
388
- f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
389
- )
390
- kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
391
- return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
392
-
393
-
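One subtlety of Flux2Attention worth noting: when out_dim is passed, the head count actually used is derived from it and the heads argument is ignored. A quick arithmetic sketch with illustrative numbers:

dim_head, out_dim = 128, 6144
effective_heads = out_dim // dim_head   # 48; the `heads` argument only applies when out_dim is None
print(effective_heads)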
394
- class Flux2ParallelSelfAttention(torch.nn.Module):
395
- """
396
- Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
397
-
398
- This implements a parallel transformer block, where the attention QKV projections are fused to the feedforward (FF)
399
- input projections, and the attention output projections are fused to the FF output projections. See the [ViT-22B
400
- paper](https://arxiv.org/abs/2302.05442) for a visual depiction of this type of transformer block.
401
- """
402
-
403
- _default_processor_cls = Flux2AttnProcessor
404
- _available_processors = [Flux2AttnProcessor]
405
- # Does not support QKV fusion as the QKV projections are always fused
406
- _supports_qkv_fusion = False
407
-
408
- def __init__(
409
- self,
410
- query_dim: int,
411
- heads: int = 8,
412
- dim_head: int = 64,
413
- dropout: float = 0.0,
414
- bias: bool = False,
415
- out_bias: bool = True,
416
- eps: float = 1e-5,
417
- out_dim: int = None,
418
- elementwise_affine: bool = True,
419
- mlp_ratio: float = 4.0,
420
- mlp_mult_factor: int = 2,
421
- processor=None,
422
- ):
423
- super().__init__()
424
-
425
- self.head_dim = dim_head
426
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
427
- self.query_dim = query_dim
428
- self.out_dim = out_dim if out_dim is not None else query_dim
429
- self.heads = out_dim // dim_head if out_dim is not None else heads
430
-
431
- self.use_bias = bias
432
- self.dropout = dropout
433
-
434
- self.mlp_ratio = mlp_ratio
435
- self.mlp_hidden_dim = int(query_dim * self.mlp_ratio)
436
- self.mlp_mult_factor = mlp_mult_factor
437
-
438
- # Fused QKV projections + MLP input projection
439
- self.to_qkv_mlp_proj = torch.nn.Linear(
440
- self.query_dim, self.inner_dim * 3 + self.mlp_hidden_dim * self.mlp_mult_factor, bias=bias
441
- )
442
- self.mlp_act_fn = Flux2SwiGLU()
443
-
444
- # QK Norm
445
- self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
446
- self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
447
-
448
- # Fused attention output projection + MLP output projection
449
- self.to_out = torch.nn.Linear(self.inner_dim + self.mlp_hidden_dim, self.out_dim, bias=out_bias)
450
-
451
- if processor is None:
452
- processor = self._default_processor_cls()
453
- self.set_processor(processor)
454
-
455
- def set_processor(self, processor: AttentionProcessor) -> None:
456
- """
457
- Set the attention processor to use.
458
-
459
- Args:
460
- processor (`AttnProcessor`):
461
- The attention processor to use.
462
- """
463
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
464
- # pop `processor` from `self._modules`
465
- if (
466
- hasattr(self, "processor")
467
- and isinstance(self.processor, torch.nn.Module)
468
- and not isinstance(processor, torch.nn.Module)
469
- ):
470
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
471
- self._modules.pop("processor")
472
-
473
- self.processor = processor
474
-
475
- def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
476
- """
477
- Get the attention processor in use.
478
-
479
- Args:
480
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
481
- Set to `True` to return the deprecated LoRA attention processor.
482
-
483
- Returns:
484
- "AttentionProcessor": The attention processor in use.
485
- """
486
- if not return_deprecated_lora:
487
- return self.processor
488
-
489
- def forward(
490
- self,
491
- hidden_states: torch.Tensor,
492
- encoder_hidden_states: Optional[torch.Tensor] = None,
493
- attention_mask: Optional[torch.Tensor] = None,
494
- image_rotary_emb: Optional[torch.Tensor] = None,
495
- **kwargs,
496
- ) -> torch.Tensor:
497
- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
498
- unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters]
499
- if len(unused_kwargs) > 0:
500
- logger.warning(
501
- f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
502
- )
503
- kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
504
- return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
505
-
506
-
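The fused projections above determine the two large linear layers of the parallel block. With the defaults used later in this file (48 heads of size 128, so dim = inner_dim = 6144, and mlp_ratio = 3.0), the widths work out as in this arithmetic sketch (the numbers are purely illustrative of those defaults):

heads, dim_head = 48, 128
dim = inner_dim = heads * dim_head                             # 6144
mlp_hidden_dim = int(dim * 3.0)                                # 18432 with mlp_ratio = 3.0
mlp_mult_factor = 2                                            # SwiGLU consumes twice its hidden width
in_width = 3 * inner_dim + mlp_hidden_dim * mlp_mult_factor    # fused QKV + MLP input: 55296
out_width = inner_dim + mlp_hidden_dim                         # fused attention + MLP output: 24576
print(in_width, out_width)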
507
- class Flux2SingleTransformerBlock(nn.Module):
508
- def __init__(
509
- self,
510
- dim: int,
511
- num_attention_heads: int,
512
- attention_head_dim: int,
513
- mlp_ratio: float = 3.0,
514
- eps: float = 1e-6,
515
- bias: bool = False,
516
- ):
517
- super().__init__()
518
-
519
- self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
520
-
521
- # Note that the MLP in/out linear layers are fused with the attention QKV/out projections, respectively; this
522
- # is often called a "parallel" transformer block. See the [ViT-22B paper](https://arxiv.org/abs/2302.05442)
523
- # for a visual depiction of this type of transformer block.
524
- self.attn = Flux2ParallelSelfAttention(
525
- query_dim=dim,
526
- dim_head=attention_head_dim,
527
- heads=num_attention_heads,
528
- out_dim=dim,
529
- bias=bias,
530
- out_bias=bias,
531
- eps=eps,
532
- mlp_ratio=mlp_ratio,
533
- mlp_mult_factor=2,
534
- processor=Flux2AttnProcessor(),
535
- )
536
-
537
- def forward(
538
- self,
539
- hidden_states: torch.Tensor,
540
- encoder_hidden_states: Optional[torch.Tensor],
541
- temb_mod_params: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
542
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
543
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
544
- ) -> Tuple[torch.Tensor, torch.Tensor]:
545
- # If encoder_hidden_states is None, hidden_states is assumed to have encoder_hidden_states already
546
- # concatenated
547
- if encoder_hidden_states is not None:
548
- text_seq_len = encoder_hidden_states.shape[1]
549
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
550
-
551
- mod_shift, mod_scale, mod_gate = temb_mod_params
552
-
553
- norm_hidden_states = self.norm(hidden_states)
554
- norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
555
-
556
- joint_attention_kwargs = joint_attention_kwargs or {}
557
- attn_output = self.attn(
558
- hidden_states=norm_hidden_states,
559
- image_rotary_emb=image_rotary_emb,
560
- text_seq_len=text_seq_len,
561
- **joint_attention_kwargs,
562
- )
563
-
564
- hidden_states = hidden_states + mod_gate * attn_output
565
- if hidden_states.dtype == torch.float16:
566
- hidden_states = hidden_states.clip(-65504, 65504)
567
-
568
- encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
569
- return encoder_hidden_states, hidden_states
570
-
571
-
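The single-stream block above follows the usual adaptive-norm pattern: LayerNorm without affine parameters, a learned shift/scale applied to the normalized activations, and a gated residual update. A minimal numeric sketch of those three steps (shapes are illustrative):

import torch
import torch.nn as nn

dim = 16
x = torch.randn(1, 10, dim)
shift, scale, gate = torch.zeros(1, 1, dim), torch.zeros(1, 1, dim), torch.ones(1, 1, dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)
normed = (1 + scale) * norm(x) + shift      # modulated pre-attention activations
attn_out = torch.randn(1, 10, dim)          # stand-in for the parallel attention output
x = x + gate * attn_out                     # gated residual update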
572
- class Flux2TransformerBlock(nn.Module):
573
- def __init__(
574
- self,
575
- dim: int,
576
- num_attention_heads: int,
577
- attention_head_dim: int,
578
- mlp_ratio: float = 3.0,
579
- eps: float = 1e-6,
580
- bias: bool = False,
581
- ):
582
- super().__init__()
583
- self.mlp_hidden_dim = int(dim * mlp_ratio)
584
-
585
- self.norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
586
- self.norm1_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
587
-
588
- self.attn = Flux2Attention(
589
- query_dim=dim,
590
- added_kv_proj_dim=dim,
591
- dim_head=attention_head_dim,
592
- heads=num_attention_heads,
593
- out_dim=dim,
594
- bias=bias,
595
- added_proj_bias=bias,
596
- out_bias=bias,
597
- eps=eps,
598
- processor=Flux2AttnProcessor(),
599
- )
600
-
601
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
602
- self.ff = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
603
-
604
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
605
- self.ff_context = Flux2FeedForward(dim=dim, dim_out=dim, mult=mlp_ratio, bias=bias)
606
-
607
- def forward(
608
- self,
609
- hidden_states: torch.Tensor,
610
- encoder_hidden_states: torch.Tensor,
611
- temb_mod_params_img: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
612
- temb_mod_params_txt: Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
613
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
614
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
615
- ) -> Tuple[torch.Tensor, torch.Tensor]:
616
- joint_attention_kwargs = joint_attention_kwargs or {}
617
-
618
- # Modulation parameters shape: [1, 1, self.dim]
619
- (shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
620
- (c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
621
-
622
- # Img stream
623
- norm_hidden_states = self.norm1(hidden_states)
624
- norm_hidden_states = (1 + scale_msa) * norm_hidden_states + shift_msa
625
-
626
- # Conditioning txt stream
627
- norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
628
- norm_encoder_hidden_states = (1 + c_scale_msa) * norm_encoder_hidden_states + c_shift_msa
629
-
630
- # Attention on concatenated img + txt stream
631
- attention_outputs = self.attn(
632
- hidden_states=norm_hidden_states,
633
- encoder_hidden_states=norm_encoder_hidden_states,
634
- image_rotary_emb=image_rotary_emb,
635
- **joint_attention_kwargs,
636
- )
637
-
638
- attn_output, context_attn_output = attention_outputs
639
-
640
- # Process attention outputs for the image stream (`hidden_states`).
641
- attn_output = gate_msa * attn_output
642
- hidden_states = hidden_states + attn_output
643
-
644
- norm_hidden_states = self.norm2(hidden_states)
645
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
646
-
647
- ff_output = self.ff(norm_hidden_states)
648
- hidden_states = hidden_states + gate_mlp * ff_output
649
-
650
- # Process attention outputs for the text stream (`encoder_hidden_states`).
651
- context_attn_output = c_gate_msa * context_attn_output
652
- encoder_hidden_states = encoder_hidden_states + context_attn_output
653
-
654
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
655
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp) + c_shift_mlp
656
-
657
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
658
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp * context_ff_output
659
- if encoder_hidden_states.dtype == torch.float16:
660
- encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
661
-
662
- return encoder_hidden_states, hidden_states
663
-
664
-
665
- class Flux2PosEmbed(nn.Module):
666
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
667
- def __init__(self, theta: int, axes_dim: List[int]):
668
- super().__init__()
669
- self.theta = theta
670
- self.axes_dim = axes_dim
671
-
672
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
673
- # Expected ids shape: [S, len(self.axes_dim)]
674
- cos_out = []
675
- sin_out = []
676
- pos = ids.float()
677
- is_mps = ids.device.type == "mps"
678
- is_npu = ids.device.type == "npu"
679
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
680
- # Unlike Flux 1, loop over len(self.axes_dim) rather than ids.shape[-1]
681
- for i in range(len(self.axes_dim)):
682
- cos, sin = get_1d_rotary_pos_embed(
683
- self.axes_dim[i],
684
- pos[..., i],
685
- theta=self.theta,
686
- repeat_interleave_real=True,
687
- use_real=True,
688
- freqs_dtype=freqs_dtype,
689
- )
690
- cos_out.append(cos)
691
- sin_out.append(sin)
692
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
693
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
694
- return freqs_cos, freqs_sin
695
-
696
-
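Flux2PosEmbed expects one positional coordinate per RoPE axis for every token. With axes_dim = (32, 32, 32, 32) the per-axis cos/sin tables concatenate to 128 features, matching the attention head size. A small sketch of the expected ids layout (values are illustrative):

import torch

axes_dim = (32, 32, 32, 32)
num_tokens = 64
ids = torch.zeros(num_tokens, len(axes_dim))   # [S, number of RoPE axes]
print(ids.shape, sum(axes_dim))                # torch.Size([64, 4]) 128 -> matches head_dim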
697
- class Flux2TimestepGuidanceEmbeddings(nn.Module):
698
- def __init__(self, in_channels: int = 256, embedding_dim: int = 6144, bias: bool = False):
699
- super().__init__()
700
-
701
- self.time_proj = Timesteps(num_channels=in_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
702
- self.timestep_embedder = TimestepEmbedding(
703
- in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
704
- )
705
-
706
- self.guidance_embedder = TimestepEmbedding(
707
- in_channels=in_channels, time_embed_dim=embedding_dim, sample_proj_bias=bias
708
- )
709
-
710
- def forward(self, timestep: torch.Tensor, guidance: torch.Tensor) -> torch.Tensor:
711
- timesteps_proj = self.time_proj(timestep)
712
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(timestep.dtype)) # (N, D)
713
-
714
- guidance_proj = self.time_proj(guidance)
715
- guidance_emb = self.guidance_embedder(guidance_proj.to(guidance.dtype)) # (N, D)
716
-
717
- time_guidance_emb = timesteps_emb + guidance_emb
718
-
719
- return time_guidance_emb
720
-
721
-
722
- class Flux2Modulation(nn.Module):
723
- def __init__(self, dim: int, mod_param_sets: int = 2, bias: bool = False):
724
- super().__init__()
725
- self.mod_param_sets = mod_param_sets
726
-
727
- self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
728
- self.act_fn = nn.SiLU()
729
-
730
- def forward(self, temb: torch.Tensor) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
731
- mod = self.act_fn(temb)
732
- mod = self.linear(mod)
733
-
734
- if mod.ndim == 2:
735
- mod = mod.unsqueeze(1)
736
- mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
737
- # Return tuple of 3-tuples of modulation params shift/scale/gate
738
- return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
739
-
740
-
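Flux2Modulation packs all shift/scale/gate triples for a block into one linear projection and then slices them apart. The slicing can be sketched without the module itself (shapes are illustrative):

import torch

dim, mod_param_sets = 8, 2
mod = torch.randn(1, 1, dim * 3 * mod_param_sets)        # output of the fused linear layer
chunks = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
sets = tuple(chunks[3 * i: 3 * (i + 1)] for i in range(mod_param_sets))
assert len(sets) == mod_param_sets and all(len(s) == 3 for s in sets)   # (shift, scale, gate) per set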
741
- class Flux2Transformer2DModel(
742
- ModelMixin,
743
- ConfigMixin,
744
- FromOriginalModelMixin,
745
- ):
746
- """
747
- The Transformer model introduced in Flux 2.
748
-
749
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
750
-
751
- Args:
752
- patch_size (`int`, defaults to `1`):
753
- Patch size to turn the input data into small patches.
754
- in_channels (`int`, defaults to `128`):
755
- The number of channels in the input.
756
- out_channels (`int`, *optional*, defaults to `None`):
757
- The number of channels in the output. If not specified, it defaults to `in_channels`.
758
- num_layers (`int`, defaults to `8`):
759
- The number of layers of dual stream DiT blocks to use.
760
- num_single_layers (`int`, defaults to `48`):
761
- The number of layers of single stream DiT blocks to use.
762
- attention_head_dim (`int`, defaults to `128`):
763
- The number of dimensions to use for each attention head.
764
- num_attention_heads (`int`, defaults to `48`):
765
- The number of attention heads to use.
766
- joint_attention_dim (`int`, defaults to `15360`):
767
- The number of dimensions to use for the joint attention (embedding/channel dimension of
768
- `encoder_hidden_states`).
769
- pooled_projection_dim (`int`, defaults to `768`):
770
- The number of dimensions to use for the pooled projection.
771
- guidance_embeds (`bool`, defaults to `True`):
772
- Whether to use guidance embeddings for guidance-distilled variant of the model.
773
- axes_dims_rope (`Tuple[int]`, defaults to `(32, 32, 32, 32)`):
774
- The dimensions to use for the rotary positional embeddings.
775
- """
776
-
777
- _supports_gradient_checkpointing = True
778
- # _no_split_modules = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
779
- # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
780
- # _repeated_blocks = ["Flux2TransformerBlock", "Flux2SingleTransformerBlock"]
781
-
782
- @register_to_config
783
- def __init__(
784
- self,
785
- patch_size: int = 1,
786
- in_channels: int = 128,
787
- out_channels: Optional[int] = None,
788
- num_layers: int = 8,
789
- num_single_layers: int = 48,
790
- attention_head_dim: int = 128,
791
- num_attention_heads: int = 48,
792
- joint_attention_dim: int = 15360,
793
- timestep_guidance_channels: int = 256,
794
- mlp_ratio: float = 3.0,
795
- axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
796
- rope_theta: int = 2000,
797
- eps: float = 1e-6,
798
- ):
799
- super().__init__()
800
- self.out_channels = out_channels or in_channels
801
- self.inner_dim = num_attention_heads * attention_head_dim
802
-
803
- # 1. Sinusoidal positional embedding for RoPE on image and text tokens
804
- self.pos_embed = Flux2PosEmbed(theta=rope_theta, axes_dim=axes_dims_rope)
805
-
806
- # 2. Combined timestep + guidance embedding
807
- self.time_guidance_embed = Flux2TimestepGuidanceEmbeddings(
808
- in_channels=timestep_guidance_channels, embedding_dim=self.inner_dim, bias=False
809
- )
810
-
811
- # 3. Modulation (double stream and single stream blocks share modulation parameters, resp.)
812
- # Two sets of shift/scale/gate modulation parameters for the double stream attn and FF sub-blocks
813
- self.double_stream_modulation_img = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
814
- self.double_stream_modulation_txt = Flux2Modulation(self.inner_dim, mod_param_sets=2, bias=False)
815
- # Only one set of modulation parameters as the attn and FF sub-blocks are run in parallel for single stream
816
- self.single_stream_modulation = Flux2Modulation(self.inner_dim, mod_param_sets=1, bias=False)
817
-
818
- # 4. Input projections
819
- self.x_embedder = nn.Linear(in_channels, self.inner_dim, bias=False)
820
- self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim, bias=False)
821
-
822
- # 5. Double Stream Transformer Blocks
823
- self.transformer_blocks = nn.ModuleList(
824
- [
825
- Flux2TransformerBlock(
826
- dim=self.inner_dim,
827
- num_attention_heads=num_attention_heads,
828
- attention_head_dim=attention_head_dim,
829
- mlp_ratio=mlp_ratio,
830
- eps=eps,
831
- bias=False,
832
- )
833
- for _ in range(num_layers)
834
- ]
835
- )
836
-
837
- # 6. Single Stream Transformer Blocks
838
- self.single_transformer_blocks = nn.ModuleList(
839
- [
840
- Flux2SingleTransformerBlock(
841
- dim=self.inner_dim,
842
- num_attention_heads=num_attention_heads,
843
- attention_head_dim=attention_head_dim,
844
- mlp_ratio=mlp_ratio,
845
- eps=eps,
846
- bias=False,
847
- )
848
- for _ in range(num_single_layers)
849
- ]
850
- )
851
-
852
- # 7. Output layers
853
- self.norm_out = AdaLayerNormContinuous(
854
- self.inner_dim, self.inner_dim, elementwise_affine=False, eps=eps, bias=False
855
- )
856
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=False)
857
-
858
- self.gradient_checkpointing = False
859
-
860
- self.sp_world_size = 1
861
- self.sp_world_rank = 0
862
-
863
- def _set_gradient_checkpointing(self, *args, **kwargs):
864
- if "value" in kwargs:
865
- self.gradient_checkpointing = kwargs["value"]
866
- elif "enable" in kwargs:
867
- self.gradient_checkpointing = kwargs["enable"]
868
- else:
869
- raise ValueError("Invalid set gradient checkpointing")
870
-
871
- def enable_multi_gpus_inference(self,):
872
- self.sp_world_size = get_sequence_parallel_world_size()
873
- self.sp_world_rank = get_sequence_parallel_rank()
874
- self.all_gather = get_sp_group().all_gather
875
- self.set_attn_processor(Flux2MultiGPUsAttnProcessor2_0())
876
-
877
- @property
878
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
879
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
880
- r"""
881
- Returns:
882
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
883
- indexed by its weight name.
884
- """
885
- # set recursively
886
- processors = {}
887
-
888
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
889
- if hasattr(module, "get_processor"):
890
- processors[f"{name}.processor"] = module.get_processor()
891
-
892
- for sub_name, child in module.named_children():
893
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
894
-
895
- return processors
896
-
897
- for name, module in self.named_children():
898
- fn_recursive_add_processors(name, module, processors)
899
-
900
- return processors
901
-
902
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
903
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
904
- r"""
905
- Sets the attention processor to use to compute attention.
906
-
907
- Parameters:
908
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
909
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
910
- for **all** `Attention` layers.
911
-
912
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
913
- processor. This is strongly recommended when setting trainable attention processors.
914
-
915
- """
916
- count = len(self.attn_processors.keys())
917
-
918
- if isinstance(processor, dict) and len(processor) != count:
919
- raise ValueError(
920
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
921
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
922
- )
923
-
924
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
925
- if hasattr(module, "set_processor"):
926
- if not isinstance(processor, dict):
927
- module.set_processor(processor)
928
- else:
929
- module.set_processor(processor.pop(f"{name}.processor"))
930
-
931
- for sub_name, child in module.named_children():
932
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
933
-
934
- for name, module in self.named_children():
935
- fn_recursive_attn_processor(name, module, processor)
936
-
937
- def forward(
938
- self,
939
- hidden_states: torch.Tensor,
940
- encoder_hidden_states: torch.Tensor = None,
941
- timestep: torch.LongTensor = None,
942
- img_ids: torch.Tensor = None,
943
- txt_ids: torch.Tensor = None,
944
- guidance: torch.Tensor = None,
945
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
946
- return_dict: bool = True,
947
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
948
- """
949
-         The [`Flux2Transformer2DModel`] forward method.
950
-
951
- Args:
952
- hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
953
- Input `hidden_states`.
954
- encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
955
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
956
- timestep ( `torch.LongTensor`):
957
- Used to indicate denoising step.
958
-             img_ids (`torch.Tensor` of shape `(image_sequence_length, len(axes_dims_rope))`):
-                 Per-token positional ids used to build the rotary embeddings for the image tokens.
-             txt_ids (`torch.Tensor` of shape `(text_sequence_length, len(axes_dims_rope))`):
-                 Per-token positional ids used to build the rotary embeddings for the text tokens.
-             guidance (`torch.Tensor`):
-                 Guidance scale input for the guidance-distilled model; embedded together with the timestep.
- joint_attention_kwargs (`dict`, *optional*):
961
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
962
- `self.processor` in
963
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
964
- return_dict (`bool`, *optional*, defaults to `True`):
965
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
966
- tuple.
967
-
968
- Returns:
969
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
970
- `tuple` where the first element is the sample tensor.
971
- """
972
- # 0. Handle input arguments
973
- if joint_attention_kwargs is not None:
974
- joint_attention_kwargs = joint_attention_kwargs.copy()
975
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
976
- else:
977
- lora_scale = 1.0
978
-
979
- num_txt_tokens = encoder_hidden_states.shape[1]
980
-
981
- # 1. Calculate timestep embedding and modulation parameters
982
- timestep = timestep.to(hidden_states.dtype) * 1000
983
- guidance = guidance.to(hidden_states.dtype) * 1000
984
-
985
- temb = self.time_guidance_embed(timestep, guidance)
986
-
987
- double_stream_mod_img = self.double_stream_modulation_img(temb)
988
- double_stream_mod_txt = self.double_stream_modulation_txt(temb)
989
- single_stream_mod = self.single_stream_modulation(temb)[0]
990
-
991
- # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
992
- hidden_states = self.x_embedder(hidden_states)
993
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
994
-
995
- # 3. Calculate RoPE embeddings from image and text tokens
996
-         # NOTE: the logic below means that we can't support batched inference with images of different resolutions or
997
-         # text prompts of different lengths. Is this a use case we want to support?
998
- if img_ids.ndim == 3:
999
- img_ids = img_ids[0]
1000
- if txt_ids.ndim == 3:
1001
- txt_ids = txt_ids[0]
1002
-
1003
- if is_torch_npu_available():
1004
- freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
1005
- image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
1006
- freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
1007
- text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
1008
- else:
1009
- image_rotary_emb = self.pos_embed(img_ids)
1010
- text_rotary_emb = self.pos_embed(txt_ids)
1011
- concat_rotary_emb = (
1012
- torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
1013
- torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
1014
- )
1015
-
1016
- # Context Parallel
1017
- if self.sp_world_size > 1:
1018
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
1019
- if concat_rotary_emb is not None:
1020
- txt_rotary_emb = (
1021
- concat_rotary_emb[0][:encoder_hidden_states.shape[1]],
1022
- concat_rotary_emb[1][:encoder_hidden_states.shape[1]]
1023
- )
1024
- concat_rotary_emb = (
1025
- torch.chunk(concat_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
1026
- torch.chunk(concat_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
1027
- )
1028
- concat_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
1029
- for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, concat_rotary_emb)]
1030
-
1031
- # 4. Double Stream Transformer Blocks
1032
- for index_block, block in enumerate(self.transformer_blocks):
1033
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1034
- def create_custom_forward(module):
1035
- def custom_forward(*inputs):
1036
- return module(*inputs)
1037
-
1038
- return custom_forward
1039
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1040
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1041
- create_custom_forward(block),
1042
- hidden_states,
1043
- encoder_hidden_states,
1044
- double_stream_mod_img,
1045
- double_stream_mod_txt,
1046
- concat_rotary_emb,
1047
- joint_attention_kwargs,
1048
- **ckpt_kwargs,
1049
- )
1050
- else:
1051
- encoder_hidden_states, hidden_states = block(
1052
- hidden_states=hidden_states,
1053
- encoder_hidden_states=encoder_hidden_states,
1054
- temb_mod_params_img=double_stream_mod_img,
1055
- temb_mod_params_txt=double_stream_mod_txt,
1056
- image_rotary_emb=concat_rotary_emb,
1057
- joint_attention_kwargs=joint_attention_kwargs,
1058
- )
1059
-
1060
- # 5. Single Stream Transformer Blocks
1061
- for index_block, block in enumerate(self.single_transformer_blocks):
1062
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1063
- def create_custom_forward(module):
1064
- def custom_forward(*inputs):
1065
- return module(*inputs)
1066
-
1067
- return custom_forward
1068
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1069
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
1070
- create_custom_forward(block),
1071
- hidden_states,
1072
- encoder_hidden_states,
1073
- single_stream_mod,
1074
- concat_rotary_emb,
1075
- joint_attention_kwargs,
1076
- **ckpt_kwargs,
1077
- )
1078
- else:
1079
- encoder_hidden_states, hidden_states = block(
1080
- hidden_states=hidden_states,
1081
- encoder_hidden_states=encoder_hidden_states,
1082
- temb_mod_params=single_stream_mod,
1083
- image_rotary_emb=concat_rotary_emb,
1084
- joint_attention_kwargs=joint_attention_kwargs,
1085
- )
1086
-
1087
- # 6. Output layers
1088
- hidden_states = self.norm_out(hidden_states, temb)
1089
- output = self.proj_out(hidden_states)
1090
-
1091
- if self.sp_world_size > 1:
1092
- output = self.all_gather(output, dim=1)
1093
-
1094
- if not return_dict:
1095
- return (output,)
1096
-
1097
- return Transformer2DModelOutput(sample=output)
1098
-
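For the sequence-parallel path in the forward above, each rank keeps only its chunk of the image tokens (the text tokens and their rotary table stay replicated) and the final projection is all-gathered along the sequence dimension. A single-process sketch of the chunking arithmetic (world size and lengths are illustrative):

import torch

sp_world_size, sp_world_rank = 4, 1
image_tokens = torch.randn(1, 4096, 64)     # [B, image sequence, C] after x_embedder
local = torch.chunk(image_tokens, sp_world_size, dim=1)[sp_world_rank]
print(local.shape)                          # torch.Size([1, 1024, 64]) handled on this rank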
1099
- @classmethod
1100
- def from_pretrained(
1101
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1102
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1103
- ):
1104
- if subfolder is not None:
1105
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1106
-         print(f"Loading pretrained 2D transformer weights from {pretrained_model_path} ...")
1107
-
1108
- config_file = os.path.join(pretrained_model_path, 'config.json')
1109
- if not os.path.isfile(config_file):
1110
- raise RuntimeError(f"{config_file} does not exist")
1111
- with open(config_file, "r") as f:
1112
- config = json.load(f)
1113
-
1114
- from diffusers.utils import WEIGHTS_NAME
1115
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1116
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
1117
-
1118
- if "dict_mapping" in transformer_additional_kwargs.keys():
1119
- for key in transformer_additional_kwargs["dict_mapping"]:
1120
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1121
-
1122
- if low_cpu_mem_usage:
1123
- try:
1124
- import re
1125
-
1126
- from diffusers import __version__ as diffusers_version
1127
- if diffusers_version >= "0.33.0":
1128
- from diffusers.models.model_loading_utils import \
1129
- load_model_dict_into_meta
1130
- else:
1131
- from diffusers.models.modeling_utils import \
1132
- load_model_dict_into_meta
1133
- from diffusers.utils import is_accelerate_available
1134
- if is_accelerate_available():
1135
- import accelerate
1136
-
1137
- # Instantiate model with empty weights
1138
- with accelerate.init_empty_weights():
1139
- model = cls.from_config(config, **transformer_additional_kwargs)
1140
-
1141
- param_device = "cpu"
1142
- if os.path.exists(model_file):
1143
- state_dict = torch.load(model_file, map_location="cpu")
1144
- elif os.path.exists(model_file_safetensors):
1145
- from safetensors.torch import load_file, safe_open
1146
- state_dict = load_file(model_file_safetensors)
1147
- else:
1148
- from safetensors.torch import load_file, safe_open
1149
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1150
- state_dict = {}
1151
- print(model_files_safetensors)
1152
- for _model_file_safetensors in model_files_safetensors:
1153
- _state_dict = load_file(_model_file_safetensors)
1154
- for key in _state_dict:
1155
- state_dict[key] = _state_dict[key]
1156
-
1157
- filtered_state_dict = {}
1158
- for key in state_dict:
1159
- if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1160
- filtered_state_dict[key] = state_dict[key]
1161
- else:
1162
- print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1163
-
1164
- model_keys = set(model.state_dict().keys())
1165
- loaded_keys = set(filtered_state_dict.keys())
1166
- missing_keys = model_keys - loaded_keys
1167
-
1168
- def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1169
- initialized_dict = {}
1170
-
1171
- with torch.no_grad():
1172
- for key in missing_keys:
1173
- param_shape = model_state_dict[key].shape
1174
- param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1175
- if 'weight' in key:
1176
- if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1177
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1178
- elif 'embedding' in key or 'embed' in key:
1179
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1180
- elif 'head' in key or 'output' in key or 'proj_out' in key:
1181
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1182
- elif len(param_shape) >= 2:
1183
- initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1184
- nn.init.xavier_uniform_(initialized_dict[key])
1185
- else:
1186
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1187
- elif 'bias' in key:
1188
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1189
- elif 'running_mean' in key:
1190
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1191
- elif 'running_var' in key:
1192
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1193
- elif 'num_batches_tracked' in key:
1194
- initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1195
- else:
1196
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1197
-
1198
- return initialized_dict
1199
-
1200
- if missing_keys:
1201
- print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1202
- initialized_params = initialize_missing_parameters(
1203
- missing_keys,
1204
- model.state_dict(),
1205
- torch_dtype
1206
- )
1207
- filtered_state_dict.update(initialized_params)
1208
-
1209
- if diffusers_version >= "0.33.0":
1210
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1211
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1212
- load_model_dict_into_meta(
1213
- model,
1214
- filtered_state_dict,
1215
- dtype=torch_dtype,
1216
- model_name_or_path=pretrained_model_path,
1217
- )
1218
- else:
1219
- model._convert_deprecated_attention_blocks(filtered_state_dict)
1220
- unexpected_keys = load_model_dict_into_meta(
1221
- model,
1222
- filtered_state_dict,
1223
- device=param_device,
1224
- dtype=torch_dtype,
1225
- model_name_or_path=pretrained_model_path,
1226
- )
1227
-
1228
- if cls._keys_to_ignore_on_load_unexpected is not None:
1229
- for pat in cls._keys_to_ignore_on_load_unexpected:
1230
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1231
-
1232
- if len(unexpected_keys) > 0:
1233
- print(
1234
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1235
- )
1236
-
1237
- return model
1238
- except Exception as e:
1239
- print(
1240
-                 f"The low_cpu_mem_usage mode does not work because {e}. Falling back to low_cpu_mem_usage=False."
1241
- )
1242
-
1243
- model = cls.from_config(config, **transformer_additional_kwargs)
1244
- if os.path.exists(model_file):
1245
- state_dict = torch.load(model_file, map_location="cpu")
1246
- elif os.path.exists(model_file_safetensors):
1247
- from safetensors.torch import load_file, safe_open
1248
- state_dict = load_file(model_file_safetensors)
1249
- else:
1250
- from safetensors.torch import load_file, safe_open
1251
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1252
- state_dict = {}
1253
- for _model_file_safetensors in model_files_safetensors:
1254
- _state_dict = load_file(_model_file_safetensors)
1255
- for key in _state_dict:
1256
- state_dict[key] = _state_dict[key]
1257
-
1258
- tmp_state_dict = {}
1259
- for key in state_dict:
1260
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1261
- tmp_state_dict[key] = state_dict[key]
1262
- else:
1263
-                 print(key, "Size doesn't match, skipping")
1264
-
1265
- state_dict = tmp_state_dict
1266
-
1267
- m, u = model.load_state_dict(state_dict, strict=False)
1268
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1269
- print(m)
1270
-
1271
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1272
- print(f"### All Parameters: {sum(params) / 1e6} M")
1273
-
1274
-         params = [p.numel() if "attn." in n else 0 for n, p in model.named_parameters()]
1275
-         print(f"### attn Parameters: {sum(params) / 1e6} M")
1276
-
1277
- model = model.to(torch_dtype)
1278
- return model
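A hedged end-to-end usage sketch of the class above. The import path and checkpoint directory are assumptions, and all shapes are illustrative; it simply mirrors the forward signature (timestep and guidance are multiplied by 1000 inside forward).

import torch
from videox_fun.models.flux2_transformer2d import Flux2Transformer2DModel  # assumed import path

model = Flux2Transformer2DModel.from_pretrained(
    "path/to/flux2_checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
)  # hypothetical checkpoint directory
bsz, img_len, txt_len = 1, 4096, 512
latents = torch.randn(bsz, img_len, 128, dtype=torch.bfloat16)        # in_channels = 128
text_embeds = torch.randn(bsz, txt_len, 15360, dtype=torch.bfloat16)  # joint_attention_dim
img_ids = torch.zeros(img_len, 4)                                     # one coordinate per RoPE axis
txt_ids = torch.zeros(txt_len, 4)
sample = model(
    hidden_states=latents,
    encoder_hidden_states=text_embeds,
    timestep=torch.tensor([0.5]),     # scaled by 1000 inside forward
    guidance=torch.tensor([4.0]),     # scaled by 1000 inside forward
    img_ids=img_ids,
    txt_ids=txt_ids,
).sample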
videox_fun/models/flux2_transformer2d_control.py DELETED
@@ -1,312 +0,0 @@
1
- # Modified from https://github.com/ali-vilab/VACE/blob/main/control/models/wan/wan_control.py
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) Alibaba, Inc. and its affiliates.
4
-
5
- import glob
6
- import inspect
7
- import json
8
- import os
9
- from typing import Any, Dict, List, Optional, Tuple, Union
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from diffusers.configuration_utils import ConfigMixin, register_to_config
15
- from diffusers.loaders import FromOriginalModelMixin
16
- from diffusers.models.attention_processor import Attention, AttentionProcessor
17
- from diffusers.models.embeddings import (TimestepEmbedding, Timesteps,
18
- apply_rotary_emb,
19
- get_1d_rotary_pos_embed)
20
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
21
- from diffusers.models.modeling_utils import ModelMixin
22
- from diffusers.models.normalization import AdaLayerNormContinuous
23
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_npu_available,
24
- is_torch_version, logging, scale_lora_layers,
25
- unscale_lora_layers)
26
-
27
- from .flux2_transformer2d import (Flux2SingleTransformerBlock,
28
- Flux2Transformer2DModel,
29
- Flux2TransformerBlock)
30
-
31
-
32
- class Flux2ControlTransformerBlock(Flux2TransformerBlock):
33
- def __init__(
34
- self,
35
- dim: int,
36
- num_attention_heads: int,
37
- attention_head_dim: int,
38
- mlp_ratio: float = 3.0,
39
- eps: float = 1e-6,
40
- bias: bool = False,
41
- block_id=0
42
- ):
43
- super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
44
- self.block_id = block_id
45
- if block_id == 0:
46
- self.before_proj = nn.Linear(dim, dim)
47
- nn.init.zeros_(self.before_proj.weight)
48
- nn.init.zeros_(self.before_proj.bias)
49
- self.after_proj = nn.Linear(dim, dim)
50
- nn.init.zeros_(self.after_proj.weight)
51
- nn.init.zeros_(self.after_proj.bias)
52
-
53
- def forward(self, c, x, **kwargs):
54
- if self.block_id == 0:
55
- c = self.before_proj(c) + x
56
- all_c = []
57
- else:
58
- all_c = list(torch.unbind(c))
59
- c = all_c.pop(-1)
60
-
61
- encoder_hidden_states, c = super().forward(c, **kwargs)
62
- c_skip = self.after_proj(c)
63
- all_c += [c_skip, c]
64
- c = torch.stack(all_c)
65
- return encoder_hidden_states, c
66
-
67
-
68
- class BaseFlux2TransformerBlock(Flux2TransformerBlock):
69
- def __init__(
70
- self,
71
- dim: int,
72
- num_attention_heads: int,
73
- attention_head_dim: int,
74
- mlp_ratio: float = 3.0,
75
- eps: float = 1e-6,
76
- bias: bool = False,
77
- block_id=0
78
- ):
79
- super().__init__(dim, num_attention_heads, attention_head_dim, mlp_ratio, eps, bias)
80
- self.block_id = block_id
81
-
82
- def forward(self, hidden_states, hints=None, context_scale=1.0, **kwargs):
83
- encoder_hidden_states, hidden_states = super().forward(hidden_states, **kwargs)
84
- if self.block_id is not None:
85
- hidden_states = hidden_states + hints[self.block_id] * context_scale
86
- return encoder_hidden_states, hidden_states
87
-
88
-
89
- class Flux2ControlTransformer2DModel(Flux2Transformer2DModel):
90
- @register_to_config
91
- def __init__(
92
- self,
93
- control_layers=None,
94
- control_in_dim=None,
95
- patch_size: int = 1,
96
- in_channels: int = 128,
97
- out_channels: Optional[int] = None,
98
- num_layers: int = 8,
99
- num_single_layers: int = 48,
100
- attention_head_dim: int = 128,
101
- num_attention_heads: int = 48,
102
- joint_attention_dim: int = 15360,
103
- timestep_guidance_channels: int = 256,
104
- mlp_ratio: float = 3.0,
105
- axes_dims_rope: Tuple[int, ...] = (32, 32, 32, 32),
106
- rope_theta: int = 2000,
107
- eps: float = 1e-6,
108
- ):
109
- super().__init__(
110
- patch_size, in_channels, out_channels, num_layers, num_single_layers, attention_head_dim,
111
- num_attention_heads, joint_attention_dim, timestep_guidance_channels, mlp_ratio, axes_dims_rope,
112
- rope_theta, eps
113
- )
114
-
115
- self.control_layers = [i for i in range(0, self.num_layers, 2)] if control_layers is None else control_layers
116
- self.control_in_dim = self.in_dim if control_in_dim is None else control_in_dim
117
-
118
- assert 0 in self.control_layers
119
- self.control_layers_mapping = {i: n for n, i in enumerate(self.control_layers)}
120
-
121
- # blocks
122
- del self.transformer_blocks
123
- self.transformer_blocks = nn.ModuleList(
124
- [
125
- BaseFlux2TransformerBlock(
126
- dim=self.inner_dim,
127
- num_attention_heads=num_attention_heads,
128
- attention_head_dim=attention_head_dim,
129
- mlp_ratio=mlp_ratio,
130
- eps=eps,
131
- block_id=self.control_layers_mapping[i] if i in self.control_layers else None
132
- )
133
- for i in range(num_layers)
134
- ]
135
- )
136
-
137
- # control blocks
138
- self.control_transformer_blocks = nn.ModuleList(
139
- [
140
- Flux2ControlTransformerBlock(
141
- dim=self.inner_dim,
142
- num_attention_heads=num_attention_heads,
143
- attention_head_dim=attention_head_dim,
144
- mlp_ratio=mlp_ratio,
145
- eps=eps,
146
- block_id=i
147
- )
148
- for i in self.control_layers
149
- ]
150
- )
151
-
152
- # control patch embeddings
153
- self.control_img_in = nn.Linear(self.control_in_dim, self.inner_dim)
154
-
155
- def forward_control(
156
- self,
157
- x,
158
- control_context,
159
- kwargs
160
- ):
161
- # embeddings
162
- c = self.control_img_in(control_context)
163
- # Context Parallel
164
- if self.sp_world_size > 1:
165
- c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
166
-
167
- # arguments
168
- new_kwargs = dict(x=x)
169
- new_kwargs.update(kwargs)
170
-
171
- for block in self.control_transformer_blocks:
172
- if torch.is_grad_enabled() and self.gradient_checkpointing:
173
- def create_custom_forward(module, **static_kwargs):
174
- def custom_forward(*inputs):
175
- return module(*inputs, **static_kwargs)
176
- return custom_forward
177
- ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
178
- encoder_hidden_states, c = torch.utils.checkpoint.checkpoint(
179
- create_custom_forward(block, **new_kwargs),
180
- c,
181
- **ckpt_kwargs,
182
- )
183
- else:
184
- encoder_hidden_states, c = block(c, **new_kwargs)
185
- new_kwargs["encoder_hidden_states"] = encoder_hidden_states
186
-
187
- hints = torch.unbind(c)[:-1]
188
- return hints
189
-
190
- def forward(
191
- self,
192
- hidden_states: torch.Tensor,
193
- encoder_hidden_states: torch.Tensor = None,
194
- timestep: torch.LongTensor = None,
195
- img_ids: torch.Tensor = None,
196
- txt_ids: torch.Tensor = None,
197
- guidance: torch.Tensor = None,
198
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
199
- control_context=None,
200
- control_context_scale=1.0,
201
- return_dict: bool = True,
202
- ):
203
- num_txt_tokens = encoder_hidden_states.shape[1]
204
-
205
- # 1. Calculate timestep embedding and modulation parameters
206
- timestep = timestep.to(hidden_states.dtype) * 1000
207
- guidance = guidance.to(hidden_states.dtype) * 1000
208
-
209
- temb = self.time_guidance_embed(timestep, guidance)
210
-
211
- double_stream_mod_img = self.double_stream_modulation_img(temb)
212
- double_stream_mod_txt = self.double_stream_modulation_txt(temb)
213
- single_stream_mod = self.single_stream_modulation(temb)[0]
214
-
215
- # 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
216
- hidden_states = self.x_embedder(hidden_states)
217
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
218
-
219
- # 3. Calculate RoPE embeddings from image and text tokens
220
- # NOTE: the below logic means that we can't support batched inference with images of different resolutions or
221
- # text prompts of differents lengths. Is this a use case we want to support?
222
- if img_ids.ndim == 3:
223
- img_ids = img_ids[0]
224
- if txt_ids.ndim == 3:
225
- txt_ids = txt_ids[0]
226
-
227
- if is_torch_npu_available():
228
- freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
229
- image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
230
- freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
231
- text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
232
- else:
233
- image_rotary_emb = self.pos_embed(img_ids)
234
- text_rotary_emb = self.pos_embed(txt_ids)
235
- concat_rotary_emb = (
236
- torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
237
- torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
238
- )
239
-
240
- # Arguments
241
- kwargs = dict(
242
- encoder_hidden_states=encoder_hidden_states,
243
- temb_mod_params_img=double_stream_mod_img,
244
- temb_mod_params_txt=double_stream_mod_txt,
245
- image_rotary_emb=concat_rotary_emb,
246
- joint_attention_kwargs=joint_attention_kwargs,
247
- )
248
- hints = self.forward_control(
249
- hidden_states, control_context, kwargs
250
- )
251
-
252
- for index_block, block in enumerate(self.transformer_blocks):
253
- # Arguments
254
- kwargs = dict(
255
- encoder_hidden_states=encoder_hidden_states,
256
- temb_mod_params_img=double_stream_mod_img,
257
- temb_mod_params_txt=double_stream_mod_txt,
258
- image_rotary_emb=concat_rotary_emb,
259
- joint_attention_kwargs=joint_attention_kwargs,
260
- hints=hints,
261
- context_scale=control_context_scale
262
- )
263
- if torch.is_grad_enabled() and self.gradient_checkpointing:
264
- def create_custom_forward(module, **static_kwargs):
265
- def custom_forward(*inputs):
266
- return module(*inputs, **static_kwargs)
267
- return custom_forward
268
-
269
- ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
270
-
271
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
272
- create_custom_forward(block, **kwargs),
273
- hidden_states,
274
- **ckpt_kwargs,
275
- )
276
- else:
277
- encoder_hidden_states, hidden_states = block(hidden_states, **kwargs)
278
-
279
- for index_block, block in enumerate(self.single_transformer_blocks):
280
- if torch.is_grad_enabled() and self.gradient_checkpointing:
281
- def create_custom_forward(module):
282
- def custom_forward(*inputs):
283
- return module(*inputs)
284
-
285
- return custom_forward
286
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
287
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
288
- create_custom_forward(block),
289
- hidden_states,
290
- encoder_hidden_states,
291
- single_stream_mod,
292
- concat_rotary_emb,
293
- joint_attention_kwargs,
294
- **ckpt_kwargs,
295
- )
296
- else:
297
- encoder_hidden_states, hidden_states = block(
298
- hidden_states=hidden_states,
299
- encoder_hidden_states=encoder_hidden_states,
300
- temb_mod_params=single_stream_mod,
301
- image_rotary_emb=concat_rotary_emb,
302
- joint_attention_kwargs=joint_attention_kwargs,
303
- )
304
-
305
- # 6. Output layers
306
- hidden_states = self.norm_out(hidden_states, temb)
307
- output = self.proj_out(hidden_states)
308
-
309
- if not return_dict:
310
- return (output,)
311
-
312
- return Transformer2DModelOutput(sample=output)
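
The forward pass above builds one joint rotary embedding by concatenating the text-token frequencies in front of the image-token frequencies, matching the `[text, image]` token order used by the double- and single-stream blocks. A minimal, self-contained sketch of that layout (all sequence lengths and dimensions below are illustrative assumptions, not values read from a checkpoint config):

```python
import torch

# Hypothetical sizes, for illustration only.
num_txt_tokens, num_img_tokens, head_dim = 77, 1024, 128

# Each rotary embedding is a (cos, sin) pair of shape [seq_len, head_dim].
text_rotary_emb = (torch.randn(num_txt_tokens, head_dim), torch.randn(num_txt_tokens, head_dim))
image_rotary_emb = (torch.randn(num_img_tokens, head_dim), torch.randn(num_img_tokens, head_dim))

# Text frequencies first, then image frequencies, matching the joint token sequence.
concat_rotary_emb = (
    torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
    torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
)
assert concat_rotary_emb[0].shape == (num_txt_tokens + num_img_tokens, head_dim)
```
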
videox_fun/models/flux2_vae.py DELETED
@@ -1,543 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_flux2.py
2
- # Copyright 2025 The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import math
16
- from typing import Dict, Optional, Tuple, Union
17
-
18
- import torch
19
- import torch.nn as nn
20
- from diffusers.configuration_utils import ConfigMixin, register_to_config
21
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
22
- from diffusers.models.attention_processor import (
23
- ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
24
- AttentionProcessor, AttnAddedKVProcessor, AttnProcessor,
25
- FusedAttnProcessor2_0)
26
- from diffusers.models.autoencoders.vae import (Decoder,
27
- DecoderOutput,
28
- DiagonalGaussianDistribution,
29
- Encoder)
30
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
31
- from diffusers.models.modeling_utils import ModelMixin
32
- from diffusers.utils import deprecate
33
- from diffusers.utils.accelerate_utils import apply_forward_hook
34
-
35
-
36
- class AutoencoderKLFlux2(ModelMixin, ConfigMixin, FromOriginalModelMixin):
37
- r"""
38
- A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
39
-
40
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
41
- for all models (such as downloading or saving).
42
-
43
- Parameters:
44
- in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
45
- out_channels (int, *optional*, defaults to 3): Number of channels in the output.
46
- down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
47
- Tuple of downsample block types.
48
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
49
- Tuple of upsample block types.
50
- block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
51
- Tuple of block output channels.
52
- act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
53
- latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
54
- sample_size (`int`, *optional*, defaults to `32`): Sample input size.
55
- force_upcast (`bool`, *optional*, default to `True`):
56
- If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
57
- can be fine-tuned / trained to a lower range without losing too much precision in which case `force_upcast`
58
- can be set to `False` - see: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix
59
- mid_block_add_attention (`bool`, *optional*, default to `True`):
60
- If enabled, the mid_block of the Encoder and Decoder will have attention blocks. If set to false, the
61
- mid_block will only have resnet blocks
62
- """
63
-
64
- _supports_gradient_checkpointing = True
65
- _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D"]
66
-
67
- @register_to_config
68
- def __init__(
69
- self,
70
- in_channels: int = 3,
71
- out_channels: int = 3,
72
- down_block_types: Tuple[str, ...] = (
73
- "DownEncoderBlock2D",
74
- "DownEncoderBlock2D",
75
- "DownEncoderBlock2D",
76
- "DownEncoderBlock2D",
77
- ),
78
- up_block_types: Tuple[str, ...] = (
79
- "UpDecoderBlock2D",
80
- "UpDecoderBlock2D",
81
- "UpDecoderBlock2D",
82
- "UpDecoderBlock2D",
83
- ),
84
- block_out_channels: Tuple[int, ...] = (
85
- 128,
86
- 256,
87
- 512,
88
- 512,
89
- ),
90
- layers_per_block: int = 2,
91
- act_fn: str = "silu",
92
- latent_channels: int = 32,
93
- norm_num_groups: int = 32,
94
- sample_size: int = 1024, # YiYi notes: not sure
95
- force_upcast: bool = True,
96
- use_quant_conv: bool = True,
97
- use_post_quant_conv: bool = True,
98
- mid_block_add_attention: bool = True,
99
- batch_norm_eps: float = 1e-4,
100
- batch_norm_momentum: float = 0.1,
101
- patch_size: Tuple[int, int] = (2, 2),
102
- ):
103
- super().__init__()
104
-
105
- # pass init params to Encoder
106
- self.encoder = Encoder(
107
- in_channels=in_channels,
108
- out_channels=latent_channels,
109
- down_block_types=down_block_types,
110
- block_out_channels=block_out_channels,
111
- layers_per_block=layers_per_block,
112
- act_fn=act_fn,
113
- norm_num_groups=norm_num_groups,
114
- double_z=True,
115
- mid_block_add_attention=mid_block_add_attention,
116
- )
117
-
118
- # pass init params to Decoder
119
- self.decoder = Decoder(
120
- in_channels=latent_channels,
121
- out_channels=out_channels,
122
- up_block_types=up_block_types,
123
- block_out_channels=block_out_channels,
124
- layers_per_block=layers_per_block,
125
- norm_num_groups=norm_num_groups,
126
- act_fn=act_fn,
127
- mid_block_add_attention=mid_block_add_attention,
128
- )
129
-
130
- self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
131
- self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None
132
-
133
- self.bn = nn.BatchNorm2d(
134
- math.prod(patch_size) * latent_channels,
135
- eps=batch_norm_eps,
136
- momentum=batch_norm_momentum,
137
- affine=False,
138
- track_running_stats=True,
139
- )
140
-
141
- self.use_slicing = False
142
- self.use_tiling = False
143
-
144
- # only relevant if vae tiling is enabled
145
- self.tile_sample_min_size = self.config.sample_size
146
- sample_size = (
147
- self.config.sample_size[0]
148
- if isinstance(self.config.sample_size, (list, tuple))
149
- else self.config.sample_size
150
- )
151
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
152
- self.tile_overlap_factor = 0.25
153
-
154
- @property
155
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
156
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
157
- r"""
158
- Returns:
159
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
160
- indexed by its weight name.
161
- """
162
- # set recursively
163
- processors = {}
164
-
165
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
166
- if hasattr(module, "get_processor"):
167
- processors[f"{name}.processor"] = module.get_processor()
168
-
169
- for sub_name, child in module.named_children():
170
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
171
-
172
- return processors
173
-
174
- for name, module in self.named_children():
175
- fn_recursive_add_processors(name, module, processors)
176
-
177
- return processors
178
-
179
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
180
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
181
- r"""
182
- Sets the attention processor to use to compute attention.
183
-
184
- Parameters:
185
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
186
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
187
- for **all** `Attention` layers.
188
-
189
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
190
- processor. This is strongly recommended when setting trainable attention processors.
191
-
192
- """
193
- count = len(self.attn_processors.keys())
194
-
195
- if isinstance(processor, dict) and len(processor) != count:
196
- raise ValueError(
197
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
198
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
199
- )
200
-
201
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
202
- if hasattr(module, "set_processor"):
203
- if not isinstance(processor, dict):
204
- module.set_processor(processor)
205
- else:
206
- module.set_processor(processor.pop(f"{name}.processor"))
207
-
208
- for sub_name, child in module.named_children():
209
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
210
-
211
- for name, module in self.named_children():
212
- fn_recursive_attn_processor(name, module, processor)
213
-
214
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
215
- def set_default_attn_processor(self):
216
- """
217
- Disables custom attention processors and sets the default attention implementation.
218
- """
219
- if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
220
- processor = AttnAddedKVProcessor()
221
- elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
222
- processor = AttnProcessor()
223
- else:
224
- raise ValueError(
225
- f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
226
- )
227
-
228
- self.set_attn_processor(processor)
229
-
230
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
231
- batch_size, num_channels, height, width = x.shape
232
-
233
- if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
234
- return self._tiled_encode(x)
235
-
236
- enc = self.encoder(x)
237
- if self.quant_conv is not None:
238
- enc = self.quant_conv(enc)
239
-
240
- return enc
241
-
242
- @apply_forward_hook
243
- def encode(
244
- self, x: torch.Tensor, return_dict: bool = True
245
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
246
- """
247
- Encode a batch of images into latents.
248
-
249
- Args:
250
- x (`torch.Tensor`): Input batch of images.
251
- return_dict (`bool`, *optional*, defaults to `True`):
252
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
253
-
254
- Returns:
255
- The latent representations of the encoded images. If `return_dict` is True, a
256
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
257
- """
258
- if self.use_slicing and x.shape[0] > 1:
259
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
260
- h = torch.cat(encoded_slices)
261
- else:
262
- h = self._encode(x)
263
-
264
- posterior = DiagonalGaussianDistribution(h)
265
-
266
- if not return_dict:
267
- return (posterior,)
268
-
269
- return AutoencoderKLOutput(latent_dist=posterior)
270
-
271
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
272
- if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
273
- return self.tiled_decode(z, return_dict=return_dict)
274
-
275
- if self.post_quant_conv is not None:
276
- z = self.post_quant_conv(z)
277
-
278
- dec = self.decoder(z)
279
-
280
- if not return_dict:
281
- return (dec,)
282
-
283
- return DecoderOutput(sample=dec)
284
-
285
- @apply_forward_hook
286
- def decode(
287
- self, z: torch.FloatTensor, return_dict: bool = True, generator=None
288
- ) -> Union[DecoderOutput, torch.FloatTensor]:
289
- """
290
- Decode a batch of images.
291
-
292
- Args:
293
- z (`torch.Tensor`): Input batch of latent vectors.
294
- return_dict (`bool`, *optional*, defaults to `True`):
295
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
296
-
297
- Returns:
298
- [`~models.vae.DecoderOutput`] or `tuple`:
299
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
300
- returned.
301
-
302
- """
303
- if self.use_slicing and z.shape[0] > 1:
304
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
305
- decoded = torch.cat(decoded_slices)
306
- else:
307
- decoded = self._decode(z).sample
308
-
309
- if not return_dict:
310
- return (decoded,)
311
-
312
- return DecoderOutput(sample=decoded)
313
-
314
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
315
- blend_extent = min(a.shape[2], b.shape[2], blend_extent)
316
- for y in range(blend_extent):
317
- b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
318
- return b
319
-
320
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
321
- blend_extent = min(a.shape[3], b.shape[3], blend_extent)
322
- for x in range(blend_extent):
323
- b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
324
- return b
325
-
326
- def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
327
- r"""Encode a batch of images using a tiled encoder.
328
-
329
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
330
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
331
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
332
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
333
- output, but they should be much less noticeable.
334
-
335
- Args:
336
- x (`torch.Tensor`): Input batch of images.
337
-
338
- Returns:
339
- `torch.Tensor`:
340
- The latent representation of the encoded videos.
341
- """
342
-
343
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
344
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
345
- row_limit = self.tile_latent_min_size - blend_extent
346
-
347
- # Split the image into 512x512 tiles and encode them separately.
348
- rows = []
349
- for i in range(0, x.shape[2], overlap_size):
350
- row = []
351
- for j in range(0, x.shape[3], overlap_size):
352
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
353
- tile = self.encoder(tile)
354
- if self.config.use_quant_conv:
355
- tile = self.quant_conv(tile)
356
- row.append(tile)
357
- rows.append(row)
358
- result_rows = []
359
- for i, row in enumerate(rows):
360
- result_row = []
361
- for j, tile in enumerate(row):
362
- # blend the above tile and the left tile
363
- # to the current tile and add the current tile to the result row
364
- if i > 0:
365
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
366
- if j > 0:
367
- tile = self.blend_h(row[j - 1], tile, blend_extent)
368
- result_row.append(tile[:, :, :row_limit, :row_limit])
369
- result_rows.append(torch.cat(result_row, dim=3))
370
-
371
- enc = torch.cat(result_rows, dim=2)
372
- return enc
373
-
374
- def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> AutoencoderKLOutput:
375
- r"""Encode a batch of images using a tiled encoder.
376
-
377
- When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
378
- steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
379
- different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
380
- tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
381
- output, but they should be much less noticeable.
382
-
383
- Args:
384
- x (`torch.Tensor`): Input batch of images.
385
- return_dict (`bool`, *optional*, defaults to `True`):
386
- Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
387
-
388
- Returns:
389
- [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
390
- If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
391
- `tuple` is returned.
392
- """
393
- deprecation_message = (
394
- "The tiled_encode implementation supporting the `return_dict` parameter is deprecated. In the future, the "
395
- "implementation of this method will be replaced with that of `_tiled_encode` and you will no longer be able "
396
- "to pass `return_dict`. You will also have to create a `DiagonalGaussianDistribution()` from the returned value."
397
- )
398
- deprecate("tiled_encode", "1.0.0", deprecation_message, standard_warn=False)
399
-
400
- overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
401
- blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
402
- row_limit = self.tile_latent_min_size - blend_extent
403
-
404
- # Split the image into 512x512 tiles and encode them separately.
405
- rows = []
406
- for i in range(0, x.shape[2], overlap_size):
407
- row = []
408
- for j in range(0, x.shape[3], overlap_size):
409
- tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
410
- tile = self.encoder(tile)
411
- if self.config.use_quant_conv:
412
- tile = self.quant_conv(tile)
413
- row.append(tile)
414
- rows.append(row)
415
- result_rows = []
416
- for i, row in enumerate(rows):
417
- result_row = []
418
- for j, tile in enumerate(row):
419
- # blend the above tile and the left tile
420
- # to the current tile and add the current tile to the result row
421
- if i > 0:
422
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
423
- if j > 0:
424
- tile = self.blend_h(row[j - 1], tile, blend_extent)
425
- result_row.append(tile[:, :, :row_limit, :row_limit])
426
- result_rows.append(torch.cat(result_row, dim=3))
427
-
428
- moments = torch.cat(result_rows, dim=2)
429
- posterior = DiagonalGaussianDistribution(moments)
430
-
431
- if not return_dict:
432
- return (posterior,)
433
-
434
- return AutoencoderKLOutput(latent_dist=posterior)
435
-
436
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
437
- r"""
438
- Decode a batch of images using a tiled decoder.
439
-
440
- Args:
441
- z (`torch.Tensor`): Input batch of latent vectors.
442
- return_dict (`bool`, *optional*, defaults to `True`):
443
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
444
-
445
- Returns:
446
- [`~models.vae.DecoderOutput`] or `tuple`:
447
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
448
- returned.
449
- """
450
- overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
451
- blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
452
- row_limit = self.tile_sample_min_size - blend_extent
453
-
454
- # Split z into overlapping 64x64 tiles and decode them separately.
455
- # The tiles have an overlap to avoid seams between tiles.
456
- rows = []
457
- for i in range(0, z.shape[2], overlap_size):
458
- row = []
459
- for j in range(0, z.shape[3], overlap_size):
460
- tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
461
- if self.config.use_post_quant_conv:
462
- tile = self.post_quant_conv(tile)
463
- decoded = self.decoder(tile)
464
- row.append(decoded)
465
- rows.append(row)
466
- result_rows = []
467
- for i, row in enumerate(rows):
468
- result_row = []
469
- for j, tile in enumerate(row):
470
- # blend the above tile and the left tile
471
- # to the current tile and add the current tile to the result row
472
- if i > 0:
473
- tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
474
- if j > 0:
475
- tile = self.blend_h(row[j - 1], tile, blend_extent)
476
- result_row.append(tile[:, :, :row_limit, :row_limit])
477
- result_rows.append(torch.cat(result_row, dim=3))
478
-
479
- dec = torch.cat(result_rows, dim=2)
480
- if not return_dict:
481
- return (dec,)
482
-
483
- return DecoderOutput(sample=dec)
484
-
485
- def forward(
486
- self,
487
- sample: torch.Tensor,
488
- sample_posterior: bool = False,
489
- return_dict: bool = True,
490
- generator: Optional[torch.Generator] = None,
491
- ) -> Union[DecoderOutput, torch.Tensor]:
492
- r"""
493
- Args:
494
- sample (`torch.Tensor`): Input sample.
495
- sample_posterior (`bool`, *optional*, defaults to `False`):
496
- Whether to sample from the posterior.
497
- return_dict (`bool`, *optional*, defaults to `True`):
498
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
499
- """
500
- x = sample
501
- posterior = self.encode(x).latent_dist
502
- if sample_posterior:
503
- z = posterior.sample(generator=generator)
504
- else:
505
- z = posterior.mode()
506
- dec = self.decode(z).sample
507
-
508
- if not return_dict:
509
- return (dec,)
510
-
511
- return DecoderOutput(sample=dec)
512
-
513
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
514
- def fuse_qkv_projections(self):
515
- """
516
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
517
- are fused. For cross-attention modules, key and value projection matrices are fused.
518
-
519
- > [!WARNING] > This API is 🧪 experimental.
520
- """
521
- self.original_attn_processors = None
522
-
523
- for _, attn_processor in self.attn_processors.items():
524
- if "Added" in str(attn_processor.__class__.__name__):
525
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
526
-
527
- self.original_attn_processors = self.attn_processors
528
-
529
- for module in self.modules():
530
- if isinstance(module, Attention):
531
- module.fuse_projections(fuse=True)
532
-
533
- self.set_attn_processor(FusedAttnProcessor2_0())
534
-
535
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
536
- def unfuse_qkv_projections(self):
537
- """Disables the fused QKV projection if enabled.
538
-
539
- > [!WARNING] > This API is 🧪 experimental.
540
-
541
- """
542
- if self.original_attn_processors is not None:
543
- self.set_attn_processor(self.original_attn_processors)
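
The tiled encode/decode paths above avoid visible seams by linearly cross-fading the overlapping border between neighbouring tiles (`blend_v` for rows, `blend_h` for columns). A standalone sketch of the horizontal blend on toy tensors (shapes chosen only for illustration):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # `a` is the tile to the left, `b` the current tile; both are [B, C, H, W].
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        # Weight ramps linearly from the left tile towards the current tile across the overlap.
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

left = torch.zeros(1, 3, 8, 8)   # pretend decoded left tile
right = torch.ones(1, 3, 8, 8)   # pretend decoded current tile
blended = blend_h(left, right, blend_extent=4)
print(blended[0, 0, 0, :4])      # the first 4 columns ramp from 0.00 towards 1.00
```
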
videox_fun/models/flux_transformer2d.py DELETED
@@ -1,832 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py
2
- # Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import inspect
17
- from typing import Any, Dict, List, Optional, Tuple, Union
18
-
19
- import numpy as np
20
- import torch
21
- import torch.nn as nn
22
- import torch.nn.functional as F
23
- from diffusers.configuration_utils import ConfigMixin, register_to_config
24
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
25
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
26
- from diffusers.models.attention import FeedForward
27
- from diffusers.models.attention_processor import AttentionProcessor
28
- from diffusers.models.embeddings import (
29
- CombinedTimestepGuidanceTextProjEmbeddings,
30
- CombinedTimestepTextProjEmbeddings, get_1d_rotary_pos_embed)
31
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
32
- from diffusers.models.modeling_utils import ModelMixin
33
- from diffusers.models.normalization import (AdaLayerNormContinuous,
34
- AdaLayerNormZero,
35
- AdaLayerNormZeroSingle)
36
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
37
- scale_lora_layers, unscale_lora_layers)
38
- from diffusers.utils.torch_utils import maybe_allow_in_graph
39
-
40
- from ..dist import (FluxMultiGPUsAttnProcessor2_0, get_sequence_parallel_rank,
41
- get_sequence_parallel_world_size, get_sp_group)
42
- from .attention_utils import attention
43
-
44
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
45
-
46
- def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
47
- query = attn.to_q(hidden_states)
48
- key = attn.to_k(hidden_states)
49
- value = attn.to_v(hidden_states)
50
-
51
- encoder_query = encoder_key = encoder_value = None
52
- if encoder_hidden_states is not None and attn.added_kv_proj_dim is not None:
53
- encoder_query = attn.add_q_proj(encoder_hidden_states)
54
- encoder_key = attn.add_k_proj(encoder_hidden_states)
55
- encoder_value = attn.add_v_proj(encoder_hidden_states)
56
-
57
- return query, key, value, encoder_query, encoder_key, encoder_value
58
-
59
- def _get_qkv_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
60
- return _get_projections(attn, hidden_states, encoder_hidden_states)
61
-
62
- def apply_rotary_emb(
63
- x: torch.Tensor,
64
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
65
- use_real: bool = True,
66
- use_real_unbind_dim: int = -1,
67
- sequence_dim: int = 2,
68
- ) -> Tuple[torch.Tensor, torch.Tensor]:
69
- """
70
- Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
71
- to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
72
- reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
73
- tensors contain rotary embeddings and are returned as real tensors.
74
-
75
- Args:
76
- x (`torch.Tensor`):
77
- Query or key tensor to apply rotary embeddings, of shape [B, H, S, D] (or [B, S, H, D] when `sequence_dim=1`).
78
- freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
79
-
80
- Returns:
81
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
82
- """
83
- if use_real:
84
- cos, sin = freqs_cis # [S, D]
85
- if sequence_dim == 2:
86
- cos = cos[None, None, :, :]
87
- sin = sin[None, None, :, :]
88
- elif sequence_dim == 1:
89
- cos = cos[None, :, None, :]
90
- sin = sin[None, :, None, :]
91
- else:
92
- raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
93
-
94
- cos, sin = cos.to(x.device), sin.to(x.device)
95
-
96
- if use_real_unbind_dim == -1:
97
- # Used for flux, cogvideox, hunyuan-dit
98
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
99
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
100
- elif use_real_unbind_dim == -2:
101
- # Used for Stable Audio, OmniGen, CogView4 and Cosmos
102
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
103
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
104
- else:
105
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
106
-
107
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
108
-
109
- return out
110
- else:
111
- # used for lumina
112
- x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
113
- freqs_cis = freqs_cis.unsqueeze(2)
114
- x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
115
-
116
- return x_out.type_as(x)
117
-
118
-
119
- class FluxAttnProcessor:
120
- _attention_backend = None
121
-
122
- def __init__(self):
123
- if not hasattr(F, "scaled_dot_product_attention"):
124
- raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")
125
-
126
- def __call__(
127
- self,
128
- attn: "FluxAttention",
129
- hidden_states: torch.Tensor,
130
- encoder_hidden_states: torch.Tensor = None,
131
- attention_mask: Optional[torch.Tensor] = None,
132
- image_rotary_emb: Optional[torch.Tensor] = None,
133
- text_seq_len: int = None,
134
- ) -> torch.Tensor:
135
- query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
136
- attn, hidden_states, encoder_hidden_states
137
- )
138
-
139
- query = query.unflatten(-1, (attn.heads, -1))
140
- key = key.unflatten(-1, (attn.heads, -1))
141
- value = value.unflatten(-1, (attn.heads, -1))
142
-
143
- query = attn.norm_q(query)
144
- key = attn.norm_k(key)
145
-
146
- if attn.added_kv_proj_dim is not None:
147
- encoder_query = encoder_query.unflatten(-1, (attn.heads, -1))
148
- encoder_key = encoder_key.unflatten(-1, (attn.heads, -1))
149
- encoder_value = encoder_value.unflatten(-1, (attn.heads, -1))
150
-
151
- encoder_query = attn.norm_added_q(encoder_query)
152
- encoder_key = attn.norm_added_k(encoder_key)
153
-
154
- query = torch.cat([encoder_query, query], dim=1)
155
- key = torch.cat([encoder_key, key], dim=1)
156
- value = torch.cat([encoder_value, value], dim=1)
157
-
158
- if image_rotary_emb is not None:
159
- query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
160
- key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)
161
-
162
- hidden_states = attention(
163
- query, key, value, attn_mask=attention_mask,
164
- )
165
- hidden_states = hidden_states.flatten(2, 3)
166
- hidden_states = hidden_states.to(query.dtype)
167
-
168
- if encoder_hidden_states is not None:
169
- encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
170
- [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1
171
- )
172
- hidden_states = attn.to_out[0](hidden_states)
173
- hidden_states = attn.to_out[1](hidden_states)
174
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
175
-
176
- return hidden_states, encoder_hidden_states
177
- else:
178
- return hidden_states
179
-
180
-
181
- class FluxAttention(torch.nn.Module):
182
- _default_processor_cls = FluxAttnProcessor
183
- _available_processors = [
184
- FluxAttnProcessor,
185
- ]
186
-
187
- def __init__(
188
- self,
189
- query_dim: int,
190
- heads: int = 8,
191
- dim_head: int = 64,
192
- dropout: float = 0.0,
193
- bias: bool = False,
194
- added_kv_proj_dim: Optional[int] = None,
195
- added_proj_bias: Optional[bool] = True,
196
- out_bias: bool = True,
197
- eps: float = 1e-5,
198
- out_dim: int = None,
199
- context_pre_only: Optional[bool] = None,
200
- pre_only: bool = False,
201
- elementwise_affine: bool = True,
202
- processor=None,
203
- ):
204
- super().__init__()
205
-
206
- self.head_dim = dim_head
207
- self.inner_dim = out_dim if out_dim is not None else dim_head * heads
208
- self.query_dim = query_dim
209
- self.use_bias = bias
210
- self.dropout = dropout
211
- self.out_dim = out_dim if out_dim is not None else query_dim
212
- self.context_pre_only = context_pre_only
213
- self.pre_only = pre_only
214
- self.heads = out_dim // dim_head if out_dim is not None else heads
215
- self.added_kv_proj_dim = added_kv_proj_dim
216
- self.added_proj_bias = added_proj_bias
217
-
218
- self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
219
- self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
220
- self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
221
- self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
222
- self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
223
-
224
- if not self.pre_only:
225
- self.to_out = torch.nn.ModuleList([])
226
- self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
227
- self.to_out.append(torch.nn.Dropout(dropout))
228
-
229
- if added_kv_proj_dim is not None:
230
- self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
231
- self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
232
- self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
233
- self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
234
- self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
235
- self.to_add_out = torch.nn.Linear(self.inner_dim, query_dim, bias=out_bias)
236
-
237
- if processor is None:
238
- self.processor = self._default_processor_cls()
239
- else:
240
- self.processor = processor
241
-
242
- def set_processor(self, processor: "AttnProcessor") -> None:
243
- r"""
244
- Set the attention processor to use.
245
-
246
- Args:
247
- processor (`AttnProcessor`):
248
- The attention processor to use.
249
- """
250
- # if current processor is in `self._modules` and if passed `processor` is not, we need to
251
- # pop `processor` from `self._modules`
252
- if (
253
- hasattr(self, "processor")
254
- and isinstance(self.processor, torch.nn.Module)
255
- and not isinstance(processor, torch.nn.Module)
256
- ):
257
- logger.info(f"You are removing possibly trained weights of {self.processor} with {processor}")
258
- self._modules.pop("processor")
259
-
260
- self.processor = processor
261
-
262
- def get_processor(self, return_deprecated_lora: bool = False) -> "AttentionProcessor":
263
- r"""
264
- Get the attention processor in use.
265
-
266
- Args:
267
- return_deprecated_lora (`bool`, *optional*, defaults to `False`):
268
- Set to `True` to return the deprecated LoRA attention processor.
269
-
270
- Returns:
271
- "AttentionProcessor": The attention processor in use.
272
- """
273
- if not return_deprecated_lora:
274
- return self.processor
275
-
276
- def forward(
277
- self,
278
- hidden_states: torch.Tensor,
279
- encoder_hidden_states: Optional[torch.Tensor] = None,
280
- attention_mask: Optional[torch.Tensor] = None,
281
- image_rotary_emb: Optional[torch.Tensor] = None,
282
- **kwargs,
283
- ) -> torch.Tensor:
284
- attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
285
- quiet_attn_parameters = {"ip_adapter_masks", "ip_hidden_states"}
286
- unused_kwargs = [k for k, _ in kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters]
287
- if len(unused_kwargs) > 0:
288
- logger.warning(
289
- f"joint_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
290
- )
291
- kwargs = {k: w for k, w in kwargs.items() if k in attn_parameters}
292
- return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
293
-
294
-
295
- @maybe_allow_in_graph
296
- class FluxSingleTransformerBlock(nn.Module):
297
- def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int, mlp_ratio: float = 4.0):
298
- super().__init__()
299
- self.mlp_hidden_dim = int(dim * mlp_ratio)
300
-
301
- self.norm = AdaLayerNormZeroSingle(dim)
302
- self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
303
- self.act_mlp = nn.GELU(approximate="tanh")
304
- self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
305
-
306
- self.attn = FluxAttention(
307
- query_dim=dim,
308
- dim_head=attention_head_dim,
309
- heads=num_attention_heads,
310
- out_dim=dim,
311
- bias=True,
312
- processor=FluxAttnProcessor(),
313
- eps=1e-6,
314
- pre_only=True,
315
- )
316
-
317
- def forward(
318
- self,
319
- hidden_states: torch.Tensor,
320
- encoder_hidden_states: torch.Tensor,
321
- temb: torch.Tensor,
322
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
323
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
324
- ) -> Tuple[torch.Tensor, torch.Tensor]:
325
- text_seq_len = encoder_hidden_states.shape[1]
326
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
327
-
328
- residual = hidden_states
329
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
330
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
331
- joint_attention_kwargs = joint_attention_kwargs or {}
332
- attn_output = self.attn(
333
- hidden_states=norm_hidden_states,
334
- image_rotary_emb=image_rotary_emb,
335
- text_seq_len=text_seq_len,
336
- **joint_attention_kwargs,
337
- )
338
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
339
- gate = gate.unsqueeze(1)
340
- hidden_states = gate * self.proj_out(hidden_states)
341
- hidden_states = residual + hidden_states
342
- if hidden_states.dtype == torch.float16:
343
- hidden_states = hidden_states.clip(-65504, 65504)
344
-
345
- encoder_hidden_states, hidden_states = hidden_states[:, :text_seq_len], hidden_states[:, text_seq_len:]
346
- return encoder_hidden_states, hidden_states
347
-
348
-
349
- @maybe_allow_in_graph
350
- class FluxTransformerBlock(nn.Module):
351
- def __init__(
352
- self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
353
- ):
354
- super().__init__()
355
-
356
- self.norm1 = AdaLayerNormZero(dim)
357
- self.norm1_context = AdaLayerNormZero(dim)
358
-
359
- self.attn = FluxAttention(
360
- query_dim=dim,
361
- added_kv_proj_dim=dim,
362
- dim_head=attention_head_dim,
363
- heads=num_attention_heads,
364
- out_dim=dim,
365
- context_pre_only=False,
366
- bias=True,
367
- processor=FluxAttnProcessor(),
368
- eps=eps,
369
- )
370
-
371
- self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
372
- self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
373
-
374
- self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
375
- self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
376
-
377
- def forward(
378
- self,
379
- hidden_states: torch.Tensor,
380
- encoder_hidden_states: torch.Tensor,
381
- temb: torch.Tensor,
382
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
383
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
384
- ) -> Tuple[torch.Tensor, torch.Tensor]:
385
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
386
-
387
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
388
- encoder_hidden_states, emb=temb
389
- )
390
- joint_attention_kwargs = joint_attention_kwargs or {}
391
-
392
- # Attention.
393
- attention_outputs = self.attn(
394
- hidden_states=norm_hidden_states,
395
- encoder_hidden_states=norm_encoder_hidden_states,
396
- image_rotary_emb=image_rotary_emb,
397
- **joint_attention_kwargs,
398
- )
399
-
400
- if len(attention_outputs) == 2:
401
- attn_output, context_attn_output = attention_outputs
402
- elif len(attention_outputs) == 3:
403
- attn_output, context_attn_output, ip_attn_output = attention_outputs
404
-
405
- # Process attention outputs for the `hidden_states`.
406
- attn_output = gate_msa.unsqueeze(1) * attn_output
407
- hidden_states = hidden_states + attn_output
408
-
409
- norm_hidden_states = self.norm2(hidden_states)
410
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
411
-
412
- ff_output = self.ff(norm_hidden_states)
413
- ff_output = gate_mlp.unsqueeze(1) * ff_output
414
-
415
- hidden_states = hidden_states + ff_output
416
- if len(attention_outputs) == 3:
417
- hidden_states = hidden_states + ip_attn_output
418
-
419
- # Process attention outputs for the `encoder_hidden_states`.
420
- context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
421
- encoder_hidden_states = encoder_hidden_states + context_attn_output
422
-
423
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
424
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
425
-
426
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
427
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
428
- if encoder_hidden_states.dtype == torch.float16:
429
- encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
430
-
431
- return encoder_hidden_states, hidden_states
432
-
433
-
434
- class FluxPosEmbed(nn.Module):
435
- # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
436
- def __init__(self, theta: int, axes_dim: List[int]):
437
- super().__init__()
438
- self.theta = theta
439
- self.axes_dim = axes_dim
440
-
441
- def forward(self, ids: torch.Tensor) -> torch.Tensor:
442
- n_axes = ids.shape[-1]
443
- cos_out = []
444
- sin_out = []
445
- pos = ids.float()
446
- is_mps = ids.device.type == "mps"
447
- is_npu = ids.device.type == "npu"
448
- freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
449
- for i in range(n_axes):
450
- cos, sin = get_1d_rotary_pos_embed(
451
- self.axes_dim[i],
452
- pos[:, i],
453
- theta=self.theta,
454
- repeat_interleave_real=True,
455
- use_real=True,
456
- freqs_dtype=freqs_dtype,
457
- )
458
- cos_out.append(cos)
459
- sin_out.append(sin)
460
- freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
461
- freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
462
- return freqs_cos, freqs_sin
463
-
464
-
465
- class FluxTransformer2DModel(
466
- ModelMixin,
467
- ConfigMixin,
468
- PeftAdapterMixin,
469
- FromOriginalModelMixin,
470
- ):
471
- """
472
- The Transformer model introduced in Flux.
473
-
474
- Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
475
-
476
- Args:
477
- patch_size (`int`, defaults to `1`):
478
- Patch size to turn the input data into small patches.
479
- in_channels (`int`, defaults to `64`):
480
- The number of channels in the input.
481
- out_channels (`int`, *optional*, defaults to `None`):
482
- The number of channels in the output. If not specified, it defaults to `in_channels`.
483
- num_layers (`int`, defaults to `19`):
484
- The number of layers of dual stream DiT blocks to use.
485
- num_single_layers (`int`, defaults to `38`):
486
- The number of layers of single stream DiT blocks to use.
487
- attention_head_dim (`int`, defaults to `128`):
488
- The number of dimensions to use for each attention head.
489
- num_attention_heads (`int`, defaults to `24`):
490
- The number of attention heads to use.
491
- joint_attention_dim (`int`, defaults to `4096`):
492
- The number of dimensions to use for the joint attention (embedding/channel dimension of
493
- `encoder_hidden_states`).
494
- pooled_projection_dim (`int`, defaults to `768`):
495
- The number of dimensions to use for the pooled projection.
496
- guidance_embeds (`bool`, defaults to `False`):
497
- Whether to use guidance embeddings for guidance-distilled variant of the model.
498
- axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
499
- The dimensions to use for the rotary positional embeddings.
500
- """
501
-
502
- _supports_gradient_checkpointing = True
503
- # _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
504
- # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
505
- # _repeated_blocks = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
506
-
507
- @register_to_config
508
- def __init__(
509
- self,
510
- patch_size: int = 1,
511
- in_channels: int = 64,
512
- out_channels: Optional[int] = None,
513
- num_layers: int = 19,
514
- num_single_layers: int = 38,
515
- attention_head_dim: int = 128,
516
- num_attention_heads: int = 24,
517
- joint_attention_dim: int = 4096,
518
- pooled_projection_dim: int = 768,
519
- guidance_embeds: bool = False,
520
- axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
521
- ):
522
- super().__init__()
523
- self.out_channels = out_channels or in_channels
524
- self.inner_dim = num_attention_heads * attention_head_dim
525
-
526
- self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
527
-
528
- text_time_guidance_cls = (
529
- CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
530
- )
531
- self.time_text_embed = text_time_guidance_cls(
532
- embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
533
- )
534
-
535
- self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
536
- self.x_embedder = nn.Linear(in_channels, self.inner_dim)
537
-
538
- self.transformer_blocks = nn.ModuleList(
539
- [
540
- FluxTransformerBlock(
541
- dim=self.inner_dim,
542
- num_attention_heads=num_attention_heads,
543
- attention_head_dim=attention_head_dim,
544
- )
545
- for _ in range(num_layers)
546
- ]
547
- )
548
-
549
- self.single_transformer_blocks = nn.ModuleList(
550
- [
551
- FluxSingleTransformerBlock(
552
- dim=self.inner_dim,
553
- num_attention_heads=num_attention_heads,
554
- attention_head_dim=attention_head_dim,
555
- )
556
- for _ in range(num_single_layers)
557
- ]
558
- )
559
-
560
- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
561
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
562
-
563
- self.gradient_checkpointing = False
564
-
565
- self.sp_world_size = 1
566
- self.sp_world_rank = 0
567
-
568
- def _set_gradient_checkpointing(self, *args, **kwargs):
569
- if "value" in kwargs:
570
- self.gradient_checkpointing = kwargs["value"]
571
- elif "enable" in kwargs:
572
- self.gradient_checkpointing = kwargs["enable"]
573
- else:
574
- raise ValueError("Invalid set gradient checkpointing")
575
-
576
- def enable_multi_gpus_inference(self,):
577
- self.sp_world_size = get_sequence_parallel_world_size()
578
- self.sp_world_rank = get_sequence_parallel_rank()
579
- self.all_gather = get_sp_group().all_gather
580
- self.set_attn_processor(FluxMultiGPUsAttnProcessor2_0())
581
-
582
- @property
583
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
584
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
585
- r"""
586
- Returns:
587
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
588
- indexed by its weight name.
589
- """
590
- # set recursively
591
- processors = {}
592
-
593
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
594
- if hasattr(module, "get_processor"):
595
- processors[f"{name}.processor"] = module.get_processor()
596
-
597
- for sub_name, child in module.named_children():
598
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
599
-
600
- return processors
601
-
602
- for name, module in self.named_children():
603
- fn_recursive_add_processors(name, module, processors)
604
-
605
- return processors
606
-
607
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
608
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
609
- r"""
610
- Sets the attention processor to use to compute attention.
611
-
612
- Parameters:
613
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
614
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
615
- for **all** `Attention` layers.
616
-
617
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
618
- processor. This is strongly recommended when setting trainable attention processors.
619
-
620
- """
621
- count = len(self.attn_processors.keys())
622
-
623
- if isinstance(processor, dict) and len(processor) != count:
624
- raise ValueError(
625
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
626
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
627
- )
628
-
629
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
630
- if hasattr(module, "set_processor"):
631
- if not isinstance(processor, dict):
632
- module.set_processor(processor)
633
- else:
634
- module.set_processor(processor.pop(f"{name}.processor"))
635
-
636
- for sub_name, child in module.named_children():
637
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
638
-
639
- for name, module in self.named_children():
640
- fn_recursive_attn_processor(name, module, processor)
641
-
642
- def forward(
643
- self,
644
- hidden_states: torch.Tensor,
645
- encoder_hidden_states: torch.Tensor = None,
646
- pooled_projections: torch.Tensor = None,
647
- timestep: torch.LongTensor = None,
648
- img_ids: torch.Tensor = None,
649
- txt_ids: torch.Tensor = None,
650
- guidance: torch.Tensor = None,
651
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
652
- controlnet_block_samples=None,
653
- controlnet_single_block_samples=None,
654
- return_dict: bool = True,
655
- controlnet_blocks_repeat: bool = False,
656
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
657
- """
658
- The [`FluxTransformer2DModel`] forward method.
659
-
660
- Args:
661
- hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
662
- Input `hidden_states`.
663
- encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
664
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
665
- pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected
666
- from the embeddings of input conditions.
667
- timestep ( `torch.LongTensor`):
668
- Used to indicate denoising step.
669
- block_controlnet_hidden_states: (`list` of `torch.Tensor`):
670
- A list of tensors that if specified are added to the residuals of transformer blocks.
671
- joint_attention_kwargs (`dict`, *optional*):
672
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
673
- `self.processor` in
674
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
675
- return_dict (`bool`, *optional*, defaults to `True`):
676
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
677
- tuple.
678
-
679
- Returns:
680
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
681
- `tuple` where the first element is the sample tensor.
682
- """
683
- if joint_attention_kwargs is not None:
684
- joint_attention_kwargs = joint_attention_kwargs.copy()
685
- lora_scale = joint_attention_kwargs.pop("scale", 1.0)
686
- else:
687
- lora_scale = 1.0
688
-
689
- if USE_PEFT_BACKEND:
690
- # weight the lora layers by setting `lora_scale` for each PEFT layer
691
- scale_lora_layers(self, lora_scale)
692
- else:
693
- if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
694
- logger.warning(
695
- "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
696
- )
697
-
698
- hidden_states = self.x_embedder(hidden_states)
699
-
700
- timestep = timestep.to(hidden_states.dtype) * 1000
701
- if guidance is not None:
702
- guidance = guidance.to(hidden_states.dtype) * 1000
703
-
704
- temb = (
705
- self.time_text_embed(timestep, pooled_projections)
706
- if guidance is None
707
- else self.time_text_embed(timestep, guidance, pooled_projections)
708
- )
709
- encoder_hidden_states = self.context_embedder(encoder_hidden_states)
710
-
711
- if txt_ids.ndim == 3:
712
- logger.warning(
713
- "Passing `txt_ids` 3d torch.Tensor is deprecated."
714
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
715
- )
716
- txt_ids = txt_ids[0]
717
- if img_ids.ndim == 3:
718
- logger.warning(
719
- "Passing `img_ids` 3d torch.Tensor is deprecated."
720
- "Please remove the batch dimension and pass it as a 2d torch Tensor"
721
- )
722
- img_ids = img_ids[0]
723
-
724
- ids = torch.cat((txt_ids, img_ids), dim=0)
725
- image_rotary_emb = self.pos_embed(ids)
726
-
727
- if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
728
- ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
729
- ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
730
- joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
731
-
732
- # Context Parallel
733
- if self.sp_world_size > 1:
734
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
735
- if image_rotary_emb is not None:
736
- txt_rotary_emb = (
737
- image_rotary_emb[0][:encoder_hidden_states.shape[1]],
738
- image_rotary_emb[1][:encoder_hidden_states.shape[1]]
739
- )
740
- image_rotary_emb = (
741
- torch.chunk(image_rotary_emb[0][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
742
- torch.chunk(image_rotary_emb[1][encoder_hidden_states.shape[1]:], self.sp_world_size, dim=0)[self.sp_world_rank],
743
- )
744
- image_rotary_emb = [torch.cat([_txt_rotary_emb, _image_rotary_emb], dim=0) \
745
- for _txt_rotary_emb, _image_rotary_emb in zip(txt_rotary_emb, image_rotary_emb)]
746
-
747
- for index_block, block in enumerate(self.transformer_blocks):
748
- if torch.is_grad_enabled() and self.gradient_checkpointing:
749
- def create_custom_forward(module):
750
- def custom_forward(*inputs):
751
- return module(*inputs)
752
-
753
- return custom_forward
754
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
755
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
756
- create_custom_forward(block),
757
- hidden_states,
758
- encoder_hidden_states,
759
- temb,
760
- image_rotary_emb,
761
- joint_attention_kwargs,
762
- **ckpt_kwargs,
763
- )
764
-
765
- else:
766
- encoder_hidden_states, hidden_states = block(
767
- hidden_states=hidden_states,
768
- encoder_hidden_states=encoder_hidden_states,
769
- temb=temb,
770
- image_rotary_emb=image_rotary_emb,
771
- joint_attention_kwargs=joint_attention_kwargs,
772
- )
773
-
774
- # controlnet residual
775
- if controlnet_block_samples is not None:
776
- interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
777
- interval_control = int(np.ceil(interval_control))
778
- # For Xlabs ControlNet.
779
- if controlnet_blocks_repeat:
780
- hidden_states = (
781
- hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
782
- )
783
- else:
784
- hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
785
-
786
- for index_block, block in enumerate(self.single_transformer_blocks):
787
- if torch.is_grad_enabled() and self.gradient_checkpointing:
788
- def create_custom_forward(module):
789
- def custom_forward(*inputs):
790
- return module(*inputs)
791
-
792
- return custom_forward
793
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
794
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
795
- create_custom_forward(block),
796
- hidden_states,
797
- encoder_hidden_states,
798
- temb,
799
- image_rotary_emb,
800
- joint_attention_kwargs,
801
- **ckpt_kwargs,
802
- )
803
-
804
- else:
805
- encoder_hidden_states, hidden_states = block(
806
- hidden_states=hidden_states,
807
- encoder_hidden_states=encoder_hidden_states,
808
- temb=temb,
809
- image_rotary_emb=image_rotary_emb,
810
- joint_attention_kwargs=joint_attention_kwargs,
811
- )
812
-
813
- # controlnet residual
814
- if controlnet_single_block_samples is not None:
815
- interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
816
- interval_control = int(np.ceil(interval_control))
817
- hidden_states = hidden_states + controlnet_single_block_samples[index_block // interval_control]
818
-
819
- hidden_states = self.norm_out(hidden_states, temb)
820
- output = self.proj_out(hidden_states)
821
-
822
- if self.sp_world_size > 1:
823
- output = self.all_gather(output, dim=1)
824
-
825
- if USE_PEFT_BACKEND:
826
- # remove `lora_scale` from each PEFT layer
827
- unscale_lora_layers(self, lora_scale)
828
-
829
- if not return_dict:
830
- return (output,)
831
-
832
- return Transformer2DModelOutput(sample=output)
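The forward pass above shards the image-token sequence across sequence-parallel ranks before the transformer blocks (the text portion of the rotary embedding is kept whole while only the image portion is chunked) and re-assembles the result with an all-gather on the output. A minimal single-process sketch of that bookkeeping, assuming illustrative shapes and an identity stand-in for the transformer blocks:

import torch

sp_world_size = 4                                   # number of sequence-parallel ranks (assumed)
text_len, image_len, dim = 64, 256, 32              # illustrative sizes

hidden_states = torch.randn(1, image_len, dim)      # image tokens only
rope_cos = torch.randn(text_len + image_len, dim)   # text rotary rows first, then image rows

outputs = []
for rank in range(sp_world_size):
    # Each rank keeps its slice of the image tokens ...
    local_hidden = torch.chunk(hidden_states, sp_world_size, dim=1)[rank]
    # ... plus the full text rotary table and its slice of the image rotary table.
    local_img_cos = torch.chunk(rope_cos[text_len:], sp_world_size, dim=0)[rank]
    local_rope_cos = torch.cat([rope_cos[:text_len], local_img_cos], dim=0)
    # The transformer blocks would run here on (local_hidden, local_rope_cos);
    # an identity stand-in keeps the sketch self-contained.
    outputs.append(local_hidden)

# self.all_gather(output, dim=1) in the model; a plain concat reproduces it on one process.
gathered = torch.cat(outputs, dim=1)
assert torch.equal(gathered, hidden_states)

The same pattern appears again in the HunyuanVideo transformer below, where only the latent tokens are chunked and the final projection is all-gathered before unpatchifying.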
videox_fun/models/hunyuanvideo_transformer3d.py DELETED
@@ -1,1478 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
2
- # Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- import glob
17
- import json
18
- import os
19
- from typing import Any, Dict, List, Optional, Tuple, Union
20
-
21
- import torch
22
- import torch.nn as nn
23
- import torch.nn.functional as F
24
- from diffusers.configuration_utils import ConfigMixin, register_to_config
25
- from diffusers.loaders import FromOriginalModelMixin
26
- from diffusers.models.attention import FeedForward
27
- from diffusers.models.attention_processor import Attention, AttentionProcessor
28
- from diffusers.models.embeddings import (CombinedTimestepTextProjEmbeddings,
29
- PixArtAlphaTextProjection,
30
- TimestepEmbedding, Timesteps,
31
- get_1d_rotary_pos_embed)
32
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
33
- from diffusers.models.modeling_utils import ModelMixin
34
- from diffusers.models.normalization import (AdaLayerNormContinuous,
35
- AdaLayerNormZero,
36
- AdaLayerNormZeroSingle,
37
- FP32LayerNorm)
38
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
39
- scale_lora_layers, unscale_lora_layers)
40
-
41
- from ..dist import (get_sequence_parallel_rank,
42
- get_sequence_parallel_world_size, get_sp_group,
43
- xFuserLongContextAttention)
44
- from ..dist.hunyuanvideo_xfuser import HunyuanVideoMultiGPUsAttnProcessor2_0
45
- from .attention_utils import attention
46
-
47
-
48
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
-
50
-
51
- def apply_rotary_emb(
52
- x: torch.Tensor,
53
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
54
- use_real: bool = True,
55
- use_real_unbind_dim: int = -1,
56
- sequence_dim: int = 2,
57
- ) -> torch.Tensor:
58
- """
59
- Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
60
- to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
61
- reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
62
- tensors contain rotary embeddings and are returned as real tensors.
63
-
64
- Args:
65
- x (`torch.Tensor`):
66
- Query or key tensor to apply rotary embeddings. [B, H, S, D] xk (torch.Tensor): Key tensor to apply
67
- freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
68
-
69
- Returns:
70
- torch.Tensor: The input tensor with rotary embeddings applied.
71
- """
72
- if use_real:
73
- cos, sin = freqs_cis # [S, D]
74
- if sequence_dim == 2:
75
- cos = cos[None, None, :, :]
76
- sin = sin[None, None, :, :]
77
- elif sequence_dim == 1:
78
- cos = cos[None, :, None, :]
79
- sin = sin[None, :, None, :]
80
- else:
81
- raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
82
-
83
- cos, sin = cos.to(x.device), sin.to(x.device)
84
-
85
- if use_real_unbind_dim == -1:
86
- # Used for flux, cogvideox, hunyuan-dit
87
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, H, S, D//2]
88
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
89
- elif use_real_unbind_dim == -2:
90
- # Used for Stable Audio, OmniGen, CogView4 and Cosmos
91
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, H, S, D//2]
92
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
93
- else:
94
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
95
-
96
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
97
-
98
- return out
99
- else:
100
- # used for lumina
101
- x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
102
- freqs_cis = freqs_cis.unsqueeze(2)
103
- x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
104
-
105
- return x_out.type_as(x)
106
-
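The use_real=True branch of apply_rotary_emb above rotates interleaved (real, imaginary) channel pairs with elementwise cos/sin tables. As a quick sanity check that this matches the complex-multiplication view described in the docstring, a small sketch with assumed shapes and hand-built tables (not part of the repository):

import torch

B, H, S, D = 1, 2, 4, 8                                       # illustrative shapes
x = torch.randn(B, H, S, D)
theta = torch.randn(S, D // 2)                                # one angle per channel pair
cos = torch.repeat_interleave(torch.cos(theta), 2, dim=-1)    # [S, D], duplicated per pair
sin = torch.repeat_interleave(torch.sin(theta), 2, dim=-1)

# Real-valued path, as in the use_real_unbind_dim=-1 branch: x * cos + rotate_half(x) * sin
x_real, x_imag = x.reshape(B, H, S, -1, 2).unbind(-1)
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
out_real = x * cos + x_rotated * sin

# Complex view: treat each channel pair as a complex number and multiply by exp(i * theta).
x_complex = torch.view_as_complex(x.reshape(B, H, S, -1, 2).contiguous())
freqs = torch.polar(torch.ones_like(theta), theta)            # cos(theta) + i * sin(theta)
out_complex = torch.view_as_real(x_complex * freqs).flatten(3)

assert torch.allclose(out_real, out_complex, atol=1e-5)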
107
- def extract_seqlens_from_mask(attn_mask):
108
- if attn_mask is None:
109
- return None
110
-
111
- if len(attn_mask.shape) == 4:
112
- bs, _, _, seq_len = attn_mask.shape
113
-
114
- if attn_mask.dtype == torch.bool:
115
- valid_mask = attn_mask.squeeze(1).squeeze(1)
116
- else:
117
- valid_mask = ~torch.isinf(attn_mask.squeeze(1).squeeze(1))
118
- else:
119
- raise ValueError(
120
- "attn_mask should be a 4D tensor, but got {}".format(
121
- attn_mask.shape))
122
-
123
- seqlens = valid_mask.sum(dim=1)
124
- return seqlens
125
-
126
- class HunyuanVideoAttnProcessor2_0:
127
- def __init__(self):
128
- if not hasattr(F, "scaled_dot_product_attention"):
129
- raise ImportError(
130
- "HunyuanVideoAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
131
- )
132
-
133
- def __call__(
134
- self,
135
- attn: Attention,
136
- hidden_states: torch.Tensor,
137
- encoder_hidden_states: Optional[torch.Tensor] = None,
138
- attention_mask: Optional[torch.Tensor] = None,
139
- image_rotary_emb: Optional[torch.Tensor] = None,
140
- ) -> torch.Tensor:
141
- if attn.add_q_proj is None and encoder_hidden_states is not None:
142
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
143
-
144
- # 1. QKV projections
145
- query = attn.to_q(hidden_states)
146
- key = attn.to_k(hidden_states)
147
- value = attn.to_v(hidden_states)
148
-
149
- query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
150
- key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
151
- value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
152
-
153
- # 2. QK normalization
154
- if attn.norm_q is not None:
155
- query = attn.norm_q(query)
156
- if attn.norm_k is not None:
157
- key = attn.norm_k(key)
158
-
159
- # 3. Rotational positional embeddings applied to latent stream
160
- if image_rotary_emb is not None:
161
- if attn.add_q_proj is None and encoder_hidden_states is not None:
162
- query = torch.cat(
163
- [
164
- apply_rotary_emb(query[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
165
- query[:, :, -encoder_hidden_states.shape[1] :],
166
- ],
167
- dim=2,
168
- )
169
- key = torch.cat(
170
- [
171
- apply_rotary_emb(key[:, :, : -encoder_hidden_states.shape[1]], image_rotary_emb),
172
- key[:, :, -encoder_hidden_states.shape[1] :],
173
- ],
174
- dim=2,
175
- )
176
- else:
177
- query = apply_rotary_emb(query, image_rotary_emb)
178
- key = apply_rotary_emb(key, image_rotary_emb)
179
-
180
- # 4. Encoder condition QKV projection and normalization
181
- if attn.add_q_proj is not None and encoder_hidden_states is not None:
182
- encoder_query = attn.add_q_proj(encoder_hidden_states)
183
- encoder_key = attn.add_k_proj(encoder_hidden_states)
184
- encoder_value = attn.add_v_proj(encoder_hidden_states)
185
-
186
- encoder_query = encoder_query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
187
- encoder_key = encoder_key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
188
- encoder_value = encoder_value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
189
-
190
- if attn.norm_added_q is not None:
191
- encoder_query = attn.norm_added_q(encoder_query)
192
- if attn.norm_added_k is not None:
193
- encoder_key = attn.norm_added_k(encoder_key)
194
-
195
- query = torch.cat([query, encoder_query], dim=2)
196
- key = torch.cat([key, encoder_key], dim=2)
197
- value = torch.cat([value, encoder_value], dim=2)
198
-
199
- # 5. Attention
200
- query = query.transpose(1, 2)
201
- key = key.transpose(1, 2)
202
- value = value.transpose(1, 2)
203
-
204
- if attention_mask is not None:
205
- q_lens = k_lens = extract_seqlens_from_mask(attention_mask)
206
-
207
- hidden_states = torch.zeros_like(query)
208
- for i in range(len(q_lens)):
209
- hidden_states[i][:q_lens[i]] = attention(
210
- query[i][:q_lens[i]].unsqueeze(0),
211
- key[i][:q_lens[i]].unsqueeze(0),
212
- value[i][:q_lens[i]].unsqueeze(0),
213
- attn_mask=None,
214
- )
215
- else:
216
- hidden_states = attention(
217
- query, key, value, attn_mask=attention_mask,
218
- )
219
- hidden_states = hidden_states.flatten(2, 3)
220
- hidden_states = hidden_states.to(query.dtype)
221
-
222
- # 6. Output projection
223
- if encoder_hidden_states is not None:
224
- hidden_states, encoder_hidden_states = (
225
- hidden_states[:, : -encoder_hidden_states.shape[1]],
226
- hidden_states[:, -encoder_hidden_states.shape[1] :],
227
- )
228
-
229
- if getattr(attn, "to_out", None) is not None:
230
- hidden_states = attn.to_out[0](hidden_states)
231
- hidden_states = attn.to_out[1](hidden_states)
232
-
233
- if getattr(attn, "to_add_out", None) is not None:
234
- encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
235
-
236
- return hidden_states, encoder_hidden_states
237
-
238
-
239
- class HunyuanVideoPatchEmbed(nn.Module):
240
- def __init__(
241
- self,
242
- patch_size: Union[int, Tuple[int, int, int]] = 16,
243
- in_chans: int = 3,
244
- embed_dim: int = 768,
245
- ) -> None:
246
- super().__init__()
247
-
248
- patch_size = (patch_size, patch_size, patch_size) if isinstance(patch_size, int) else patch_size
249
- self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
250
-
251
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
252
- hidden_states = self.proj(hidden_states)
253
- hidden_states = hidden_states.flatten(2).transpose(1, 2) # BCFHW -> BNC
254
- return hidden_states
255
-
256
-
257
- class HunyuanVideoAdaNorm(nn.Module):
258
- def __init__(self, in_features: int, out_features: Optional[int] = None) -> None:
259
- super().__init__()
260
-
261
- out_features = out_features or 2 * in_features
262
- self.linear = nn.Linear(in_features, out_features)
263
- self.nonlinearity = nn.SiLU()
264
-
265
- def forward(
266
- self, temb: torch.Tensor
267
- ) -> Tuple[torch.Tensor, torch.Tensor]:
268
- temb = self.linear(self.nonlinearity(temb))
269
- gate_msa, gate_mlp = temb.chunk(2, dim=1)
270
- gate_msa, gate_mlp = gate_msa.unsqueeze(1), gate_mlp.unsqueeze(1)
271
- return gate_msa, gate_mlp
272
-
273
-
274
- class HunyuanVideoTokenReplaceAdaLayerNormZero(nn.Module):
275
- def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
276
- super().__init__()
277
-
278
- self.silu = nn.SiLU()
279
- self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
280
-
281
- if norm_type == "layer_norm":
282
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
283
- elif norm_type == "fp32_layer_norm":
284
- self.norm = FP32LayerNorm(embedding_dim, elementwise_affine=False, bias=False)
285
- else:
286
- raise ValueError(
287
- f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
288
- )
289
-
290
- def forward(
291
- self,
292
- hidden_states: torch.Tensor,
293
- emb: torch.Tensor,
294
- token_replace_emb: torch.Tensor,
295
- first_frame_num_tokens: int,
296
- ) -> Tuple[torch.Tensor, ...]:
297
- emb = self.linear(self.silu(emb))
298
- token_replace_emb = self.linear(self.silu(token_replace_emb))
299
-
300
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
301
- tr_shift_msa, tr_scale_msa, tr_gate_msa, tr_shift_mlp, tr_scale_mlp, tr_gate_mlp = token_replace_emb.chunk(
302
- 6, dim=1
303
- )
304
-
305
- norm_hidden_states = self.norm(hidden_states)
306
- hidden_states_zero = (
307
- norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
308
- )
309
- hidden_states_orig = (
310
- norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
311
- )
312
- hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
313
-
314
- return (
315
- hidden_states,
316
- gate_msa,
317
- shift_mlp,
318
- scale_mlp,
319
- gate_mlp,
320
- tr_gate_msa,
321
- tr_shift_mlp,
322
- tr_scale_mlp,
323
- tr_gate_mlp,
324
- )
325
-
326
-
327
- class HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(nn.Module):
328
- def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True):
329
- super().__init__()
330
-
331
- self.silu = nn.SiLU()
332
- self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
333
-
334
- if norm_type == "layer_norm":
335
- self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
336
- else:
337
- raise ValueError(
338
- f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm', 'fp32_layer_norm'."
339
- )
340
-
341
- def forward(
342
- self,
343
- hidden_states: torch.Tensor,
344
- emb: torch.Tensor,
345
- token_replace_emb: torch.Tensor,
346
- first_frame_num_tokens: int,
347
- ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
348
- emb = self.linear(self.silu(emb))
349
- token_replace_emb = self.linear(self.silu(token_replace_emb))
350
-
351
- shift_msa, scale_msa, gate_msa = emb.chunk(3, dim=1)
352
- tr_shift_msa, tr_scale_msa, tr_gate_msa = token_replace_emb.chunk(3, dim=1)
353
-
354
- norm_hidden_states = self.norm(hidden_states)
355
- hidden_states_zero = (
356
- norm_hidden_states[:, :first_frame_num_tokens] * (1 + tr_scale_msa[:, None]) + tr_shift_msa[:, None]
357
- )
358
- hidden_states_orig = (
359
- norm_hidden_states[:, first_frame_num_tokens:] * (1 + scale_msa[:, None]) + shift_msa[:, None]
360
- )
361
- hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
362
-
363
- return hidden_states, gate_msa, tr_gate_msa
364
-
365
-
366
- class HunyuanVideoConditionEmbedding(nn.Module):
367
- def __init__(
368
- self,
369
- embedding_dim: int,
370
- pooled_projection_dim: int,
371
- guidance_embeds: bool,
372
- image_condition_type: Optional[str] = None,
373
- ):
374
- super().__init__()
375
-
376
- self.image_condition_type = image_condition_type
377
-
378
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
379
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
380
- self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
381
-
382
- self.guidance_embedder = None
383
- if guidance_embeds:
384
- self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
385
-
386
- def forward(
387
- self, timestep: torch.Tensor, pooled_projection: torch.Tensor, guidance: Optional[torch.Tensor] = None
388
- ) -> Tuple[torch.Tensor, torch.Tensor]:
389
- timesteps_proj = self.time_proj(timestep)
390
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
391
- pooled_projections = self.text_embedder(pooled_projection)
392
- conditioning = timesteps_emb + pooled_projections
393
-
394
- token_replace_emb = None
395
- if self.image_condition_type == "token_replace":
396
- token_replace_timestep = torch.zeros_like(timestep)
397
- token_replace_proj = self.time_proj(token_replace_timestep)
398
- token_replace_emb = self.timestep_embedder(token_replace_proj.to(dtype=pooled_projection.dtype))
399
- token_replace_emb = token_replace_emb + pooled_projections
400
-
401
- if self.guidance_embedder is not None:
402
- guidance_proj = self.time_proj(guidance)
403
- guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
404
- conditioning = conditioning + guidance_emb
405
-
406
- return conditioning, token_replace_emb
407
-
408
-
409
- class HunyuanVideoIndividualTokenRefinerBlock(nn.Module):
410
- def __init__(
411
- self,
412
- num_attention_heads: int,
413
- attention_head_dim: int,
414
- mlp_width_ratio: str = 4.0,
415
- mlp_drop_rate: float = 0.0,
416
- attention_bias: bool = True,
417
- ) -> None:
418
- super().__init__()
419
-
420
- hidden_size = num_attention_heads * attention_head_dim
421
-
422
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
423
- self.attn = Attention(
424
- query_dim=hidden_size,
425
- cross_attention_dim=None,
426
- heads=num_attention_heads,
427
- dim_head=attention_head_dim,
428
- bias=attention_bias,
429
- )
430
- self.attn.set_processor = None
431
-
432
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6)
433
- self.ff = FeedForward(hidden_size, mult=mlp_width_ratio, activation_fn="linear-silu", dropout=mlp_drop_rate)
434
-
435
- self.norm_out = HunyuanVideoAdaNorm(hidden_size, 2 * hidden_size)
436
-
437
- def forward(
438
- self,
439
- hidden_states: torch.Tensor,
440
- temb: torch.Tensor,
441
- attention_mask: Optional[torch.Tensor] = None,
442
- ) -> torch.Tensor:
443
- norm_hidden_states = self.norm1(hidden_states)
444
-
445
- attn_output = self.attn(
446
- hidden_states=norm_hidden_states,
447
- encoder_hidden_states=None,
448
- attention_mask=attention_mask,
449
- )
450
-
451
- gate_msa, gate_mlp = self.norm_out(temb)
452
- hidden_states = hidden_states + attn_output * gate_msa
453
-
454
- ff_output = self.ff(self.norm2(hidden_states))
455
- hidden_states = hidden_states + ff_output * gate_mlp
456
-
457
- return hidden_states
458
-
459
-
460
- class HunyuanVideoIndividualTokenRefiner(nn.Module):
461
- def __init__(
462
- self,
463
- num_attention_heads: int,
464
- attention_head_dim: int,
465
- num_layers: int,
466
- mlp_width_ratio: float = 4.0,
467
- mlp_drop_rate: float = 0.0,
468
- attention_bias: bool = True,
469
- ) -> None:
470
- super().__init__()
471
-
472
- self.refiner_blocks = nn.ModuleList(
473
- [
474
- HunyuanVideoIndividualTokenRefinerBlock(
475
- num_attention_heads=num_attention_heads,
476
- attention_head_dim=attention_head_dim,
477
- mlp_width_ratio=mlp_width_ratio,
478
- mlp_drop_rate=mlp_drop_rate,
479
- attention_bias=attention_bias,
480
- )
481
- for _ in range(num_layers)
482
- ]
483
- )
484
-
485
- def forward(
486
- self,
487
- hidden_states: torch.Tensor,
488
- temb: torch.Tensor,
489
- attention_mask: Optional[torch.Tensor] = None,
490
- ) -> torch.Tensor:
491
- self_attn_mask = None
492
- if attention_mask is not None:
493
- batch_size = attention_mask.shape[0]
494
- seq_len = attention_mask.shape[1]
495
- attention_mask = attention_mask.to(hidden_states.device).bool()
496
- self_attn_mask_1 = attention_mask.view(batch_size, 1, 1, seq_len).repeat(1, 1, seq_len, 1)
497
- self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
498
- self_attn_mask = (self_attn_mask_1 & self_attn_mask_2).bool()
499
- self_attn_mask[:, :, :, 0] = True
500
-
501
- for block in self.refiner_blocks:
502
- hidden_states = block(hidden_states, temb, self_attn_mask)
503
-
504
- return hidden_states
505
-
506
-
507
- class HunyuanVideoTokenRefiner(nn.Module):
508
- def __init__(
509
- self,
510
- in_channels: int,
511
- num_attention_heads: int,
512
- attention_head_dim: int,
513
- num_layers: int,
514
- mlp_ratio: float = 4.0,
515
- mlp_drop_rate: float = 0.0,
516
- attention_bias: bool = True,
517
- ) -> None:
518
- super().__init__()
519
-
520
- hidden_size = num_attention_heads * attention_head_dim
521
-
522
- self.time_text_embed = CombinedTimestepTextProjEmbeddings(
523
- embedding_dim=hidden_size, pooled_projection_dim=in_channels
524
- )
525
- self.proj_in = nn.Linear(in_channels, hidden_size, bias=True)
526
- self.token_refiner = HunyuanVideoIndividualTokenRefiner(
527
- num_attention_heads=num_attention_heads,
528
- attention_head_dim=attention_head_dim,
529
- num_layers=num_layers,
530
- mlp_width_ratio=mlp_ratio,
531
- mlp_drop_rate=mlp_drop_rate,
532
- attention_bias=attention_bias,
533
- )
534
-
535
- def forward(
536
- self,
537
- hidden_states: torch.Tensor,
538
- timestep: torch.LongTensor,
539
- attention_mask: Optional[torch.LongTensor] = None,
540
- ) -> torch.Tensor:
541
- if attention_mask is None:
542
- pooled_projections = hidden_states.mean(dim=1)
543
- else:
544
- original_dtype = hidden_states.dtype
545
- mask_float = attention_mask.float().unsqueeze(-1)
546
- pooled_projections = (hidden_states * mask_float).sum(dim=1) / mask_float.sum(dim=1)
547
- pooled_projections = pooled_projections.to(original_dtype)
548
-
549
- temb = self.time_text_embed(timestep, pooled_projections)
550
- hidden_states = self.proj_in(hidden_states)
551
- hidden_states = self.token_refiner(hidden_states, temb, attention_mask)
552
-
553
- return hidden_states
554
-
555
-
556
- class HunyuanVideoRotaryPosEmbed(nn.Module):
557
- def __init__(self, patch_size: int, patch_size_t: int, rope_dim: List[int], theta: float = 256.0) -> None:
558
- super().__init__()
559
-
560
- self.patch_size = patch_size
561
- self.patch_size_t = patch_size_t
562
- self.rope_dim = rope_dim
563
- self.theta = theta
564
-
565
- def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
566
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
567
- rope_sizes = [num_frames // self.patch_size_t, height // self.patch_size, width // self.patch_size]
568
-
569
- axes_grids = []
570
- for i in range(3):
571
- # Note: The following line diverges from original behaviour. We create the grid on the device, whereas
572
- # original implementation creates it on CPU and then moves it to device. This results in numerical
573
- # differences in layerwise debugging outputs, but visually it is the same.
574
- grid = torch.arange(0, rope_sizes[i], device=hidden_states.device, dtype=torch.float32)
575
- axes_grids.append(grid)
576
- grid = torch.meshgrid(*axes_grids, indexing="ij") # [T, H, W]
577
- grid = torch.stack(grid, dim=0) # [3, T, H, W]
578
-
579
- freqs = []
580
- for i in range(3):
581
- freq = get_1d_rotary_pos_embed(self.rope_dim[i], grid[i].reshape(-1), self.theta, use_real=True)
582
- freqs.append(freq)
583
-
584
- freqs_cos = torch.cat([f[0] for f in freqs], dim=1) # (W * H * T, D / 2)
585
- freqs_sin = torch.cat([f[1] for f in freqs], dim=1) # (W * H * T, D / 2)
586
- return freqs_cos, freqs_sin
587
-
588
-
589
- class HunyuanVideoSingleTransformerBlock(nn.Module):
590
- def __init__(
591
- self,
592
- num_attention_heads: int,
593
- attention_head_dim: int,
594
- mlp_ratio: float = 4.0,
595
- qk_norm: str = "rms_norm",
596
- ) -> None:
597
- super().__init__()
598
-
599
- hidden_size = num_attention_heads * attention_head_dim
600
- mlp_dim = int(hidden_size * mlp_ratio)
601
-
602
- self.attn = Attention(
603
- query_dim=hidden_size,
604
- cross_attention_dim=None,
605
- dim_head=attention_head_dim,
606
- heads=num_attention_heads,
607
- out_dim=hidden_size,
608
- bias=True,
609
- processor=HunyuanVideoAttnProcessor2_0(),
610
- qk_norm=qk_norm,
611
- eps=1e-6,
612
- pre_only=True,
613
- )
614
-
615
- self.norm = AdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
616
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
617
- self.act_mlp = nn.GELU(approximate="tanh")
618
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
619
-
620
- def forward(
621
- self,
622
- hidden_states: torch.Tensor,
623
- encoder_hidden_states: torch.Tensor,
624
- temb: torch.Tensor,
625
- attention_mask: Optional[torch.Tensor] = None,
626
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
627
- *args,
628
- **kwargs,
629
- ) -> Tuple[torch.Tensor, torch.Tensor]:
630
- text_seq_length = encoder_hidden_states.shape[1]
631
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
632
-
633
- residual = hidden_states
634
-
635
- # 1. Input normalization
636
- norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
637
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
638
-
639
- norm_hidden_states, norm_encoder_hidden_states = (
640
- norm_hidden_states[:, :-text_seq_length, :],
641
- norm_hidden_states[:, -text_seq_length:, :],
642
- )
643
-
644
- # 2. Attention
645
- attn_output, context_attn_output = self.attn(
646
- hidden_states=norm_hidden_states,
647
- encoder_hidden_states=norm_encoder_hidden_states,
648
- attention_mask=attention_mask,
649
- image_rotary_emb=image_rotary_emb,
650
- )
651
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
652
-
653
- # 3. Modulation and residual connection
654
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
655
- hidden_states = gate.unsqueeze(1) * self.proj_out(hidden_states)
656
- hidden_states = hidden_states + residual
657
-
658
- hidden_states, encoder_hidden_states = (
659
- hidden_states[:, :-text_seq_length, :],
660
- hidden_states[:, -text_seq_length:, :],
661
- )
662
- return hidden_states, encoder_hidden_states
663
-
664
-
665
- class HunyuanVideoTransformerBlock(nn.Module):
666
- def __init__(
667
- self,
668
- num_attention_heads: int,
669
- attention_head_dim: int,
670
- mlp_ratio: float,
671
- qk_norm: str = "rms_norm",
672
- ) -> None:
673
- super().__init__()
674
-
675
- hidden_size = num_attention_heads * attention_head_dim
676
-
677
- self.norm1 = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
678
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
679
-
680
- self.attn = Attention(
681
- query_dim=hidden_size,
682
- cross_attention_dim=None,
683
- added_kv_proj_dim=hidden_size,
684
- dim_head=attention_head_dim,
685
- heads=num_attention_heads,
686
- out_dim=hidden_size,
687
- context_pre_only=False,
688
- bias=True,
689
- processor=HunyuanVideoAttnProcessor2_0(),
690
- qk_norm=qk_norm,
691
- eps=1e-6,
692
- )
693
-
694
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
695
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
696
-
697
- self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
698
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
699
-
700
- def forward(
701
- self,
702
- hidden_states: torch.Tensor,
703
- encoder_hidden_states: torch.Tensor,
704
- temb: torch.Tensor,
705
- attention_mask: Optional[torch.Tensor] = None,
706
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
707
- *args,
708
- **kwargs,
709
- ) -> Tuple[torch.Tensor, torch.Tensor]:
710
- # 1. Input normalization
711
- norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
712
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
713
- encoder_hidden_states, emb=temb
714
- )
715
-
716
- # 2. Joint attention
717
- attn_output, context_attn_output = self.attn(
718
- hidden_states=norm_hidden_states,
719
- encoder_hidden_states=norm_encoder_hidden_states,
720
- attention_mask=attention_mask,
721
- image_rotary_emb=freqs_cis,
722
- )
723
-
724
- # 3. Modulation and residual connection
725
- hidden_states = hidden_states + attn_output * gate_msa.unsqueeze(1)
726
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
727
-
728
- norm_hidden_states = self.norm2(hidden_states)
729
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
730
-
731
- norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
732
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
733
-
734
- # 4. Feed-forward
735
- ff_output = self.ff(norm_hidden_states)
736
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
737
-
738
- hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
739
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
740
-
741
- return hidden_states, encoder_hidden_states
742
-
743
-
744
- class HunyuanVideoTokenReplaceSingleTransformerBlock(nn.Module):
745
- def __init__(
746
- self,
747
- num_attention_heads: int,
748
- attention_head_dim: int,
749
- mlp_ratio: float = 4.0,
750
- qk_norm: str = "rms_norm",
751
- ) -> None:
752
- super().__init__()
753
-
754
- hidden_size = num_attention_heads * attention_head_dim
755
- mlp_dim = int(hidden_size * mlp_ratio)
756
-
757
- self.attn = Attention(
758
- query_dim=hidden_size,
759
- cross_attention_dim=None,
760
- dim_head=attention_head_dim,
761
- heads=num_attention_heads,
762
- out_dim=hidden_size,
763
- bias=True,
764
- processor=HunyuanVideoAttnProcessor2_0(),
765
- qk_norm=qk_norm,
766
- eps=1e-6,
767
- pre_only=True,
768
- )
769
-
770
- self.norm = HunyuanVideoTokenReplaceAdaLayerNormZeroSingle(hidden_size, norm_type="layer_norm")
771
- self.proj_mlp = nn.Linear(hidden_size, mlp_dim)
772
- self.act_mlp = nn.GELU(approximate="tanh")
773
- self.proj_out = nn.Linear(hidden_size + mlp_dim, hidden_size)
774
-
775
- def forward(
776
- self,
777
- hidden_states: torch.Tensor,
778
- encoder_hidden_states: torch.Tensor,
779
- temb: torch.Tensor,
780
- attention_mask: Optional[torch.Tensor] = None,
781
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
782
- token_replace_emb: torch.Tensor = None,
783
- num_tokens: int = None,
784
- ) -> Tuple[torch.Tensor, torch.Tensor]:
785
- text_seq_length = encoder_hidden_states.shape[1]
786
- hidden_states = torch.cat([hidden_states, encoder_hidden_states], dim=1)
787
-
788
- residual = hidden_states
789
-
790
- # 1. Input normalization
791
- norm_hidden_states, gate, tr_gate = self.norm(hidden_states, temb, token_replace_emb, num_tokens)
792
- mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
793
-
794
- norm_hidden_states, norm_encoder_hidden_states = (
795
- norm_hidden_states[:, :-text_seq_length, :],
796
- norm_hidden_states[:, -text_seq_length:, :],
797
- )
798
-
799
- # 2. Attention
800
- attn_output, context_attn_output = self.attn(
801
- hidden_states=norm_hidden_states,
802
- encoder_hidden_states=norm_encoder_hidden_states,
803
- attention_mask=attention_mask,
804
- image_rotary_emb=image_rotary_emb,
805
- )
806
- attn_output = torch.cat([attn_output, context_attn_output], dim=1)
807
-
808
- # 3. Modulation and residual connection
809
- hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
810
-
811
- proj_output = self.proj_out(hidden_states)
812
- hidden_states_zero = proj_output[:, :num_tokens] * tr_gate.unsqueeze(1)
813
- hidden_states_orig = proj_output[:, num_tokens:] * gate.unsqueeze(1)
814
- hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
815
- hidden_states = hidden_states + residual
816
-
817
- hidden_states, encoder_hidden_states = (
818
- hidden_states[:, :-text_seq_length, :],
819
- hidden_states[:, -text_seq_length:, :],
820
- )
821
- return hidden_states, encoder_hidden_states
822
-
823
-
824
- class HunyuanVideoTokenReplaceTransformerBlock(nn.Module):
825
- def __init__(
826
- self,
827
- num_attention_heads: int,
828
- attention_head_dim: int,
829
- mlp_ratio: float,
830
- qk_norm: str = "rms_norm",
831
- ) -> None:
832
- super().__init__()
833
-
834
- hidden_size = num_attention_heads * attention_head_dim
835
-
836
- self.norm1 = HunyuanVideoTokenReplaceAdaLayerNormZero(hidden_size, norm_type="layer_norm")
837
- self.norm1_context = AdaLayerNormZero(hidden_size, norm_type="layer_norm")
838
-
839
- self.attn = Attention(
840
- query_dim=hidden_size,
841
- cross_attention_dim=None,
842
- added_kv_proj_dim=hidden_size,
843
- dim_head=attention_head_dim,
844
- heads=num_attention_heads,
845
- out_dim=hidden_size,
846
- context_pre_only=False,
847
- bias=True,
848
- processor=HunyuanVideoAttnProcessor2_0(),
849
- qk_norm=qk_norm,
850
- eps=1e-6,
851
- )
852
-
853
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
854
- self.ff = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
855
-
856
- self.norm2_context = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
857
- self.ff_context = FeedForward(hidden_size, mult=mlp_ratio, activation_fn="gelu-approximate")
858
-
859
- def forward(
860
- self,
861
- hidden_states: torch.Tensor,
862
- encoder_hidden_states: torch.Tensor,
863
- temb: torch.Tensor,
864
- attention_mask: Optional[torch.Tensor] = None,
865
- freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
866
- token_replace_emb: torch.Tensor = None,
867
- num_tokens: int = None,
868
- ) -> Tuple[torch.Tensor, torch.Tensor]:
869
- # 1. Input normalization
870
- (
871
- norm_hidden_states,
872
- gate_msa,
873
- shift_mlp,
874
- scale_mlp,
875
- gate_mlp,
876
- tr_gate_msa,
877
- tr_shift_mlp,
878
- tr_scale_mlp,
879
- tr_gate_mlp,
880
- ) = self.norm1(hidden_states, temb, token_replace_emb, num_tokens)
881
- norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
882
- encoder_hidden_states, emb=temb
883
- )
884
-
885
- # 2. Joint attention
886
- attn_output, context_attn_output = self.attn(
887
- hidden_states=norm_hidden_states,
888
- encoder_hidden_states=norm_encoder_hidden_states,
889
- attention_mask=attention_mask,
890
- image_rotary_emb=freqs_cis,
891
- )
892
-
893
- # 3. Modulation and residual connection
894
- hidden_states_zero = hidden_states[:, :num_tokens] + attn_output[:, :num_tokens] * tr_gate_msa.unsqueeze(1)
895
- hidden_states_orig = hidden_states[:, num_tokens:] + attn_output[:, num_tokens:] * gate_msa.unsqueeze(1)
896
- hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
897
- encoder_hidden_states = encoder_hidden_states + context_attn_output * c_gate_msa.unsqueeze(1)
898
-
899
- norm_hidden_states = self.norm2(hidden_states)
900
- norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
901
-
902
- hidden_states_zero = norm_hidden_states[:, :num_tokens] * (1 + tr_scale_mlp[:, None]) + tr_shift_mlp[:, None]
903
- hidden_states_orig = norm_hidden_states[:, num_tokens:] * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
904
- norm_hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
905
- norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
906
-
907
- # 4. Feed-forward
908
- ff_output = self.ff(norm_hidden_states)
909
- context_ff_output = self.ff_context(norm_encoder_hidden_states)
910
-
911
- hidden_states_zero = hidden_states[:, :num_tokens] + ff_output[:, :num_tokens] * tr_gate_mlp.unsqueeze(1)
912
- hidden_states_orig = hidden_states[:, num_tokens:] + ff_output[:, num_tokens:] * gate_mlp.unsqueeze(1)
913
- hidden_states = torch.cat([hidden_states_zero, hidden_states_orig], dim=1)
914
- encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
915
-
916
- return hidden_states, encoder_hidden_states
917
-
918
-
919
- class HunyuanVideoTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
920
- r"""
921
- A Transformer model for video-like data used in [HunyuanVideo](https://huggingface.co/tencent/HunyuanVideo).
922
-
923
- Args:
924
- in_channels (`int`, defaults to `16`):
925
- The number of channels in the input.
926
- out_channels (`int`, defaults to `16`):
927
- The number of channels in the output.
928
- num_attention_heads (`int`, defaults to `24`):
929
- The number of heads to use for multi-head attention.
930
- attention_head_dim (`int`, defaults to `128`):
931
- The number of channels in each head.
932
- num_layers (`int`, defaults to `20`):
933
- The number of layers of dual-stream blocks to use.
934
- num_single_layers (`int`, defaults to `40`):
935
- The number of layers of single-stream blocks to use.
936
- num_refiner_layers (`int`, defaults to `2`):
937
- The number of layers of refiner blocks to use.
938
- mlp_ratio (`float`, defaults to `4.0`):
939
- The ratio of the hidden layer size to the input size in the feedforward network.
940
- patch_size (`int`, defaults to `2`):
941
- The size of the spatial patches to use in the patch embedding layer.
942
- patch_size_t (`int`, defaults to `1`):
943
- The size of the temporal patches to use in the patch embedding layer.
944
- qk_norm (`str`, defaults to `rms_norm`):
945
- The normalization to use for the query and key projections in the attention layers.
946
- guidance_embeds (`bool`, defaults to `True`):
947
- Whether to use guidance embeddings in the model.
948
- text_embed_dim (`int`, defaults to `4096`):
949
- Input dimension of text embeddings from the text encoder.
950
- pooled_projection_dim (`int`, defaults to `768`):
951
- The dimension of the pooled projection of the text embeddings.
952
- rope_theta (`float`, defaults to `256.0`):
953
- The value of theta to use in the RoPE layer.
954
- rope_axes_dim (`Tuple[int]`, defaults to `(16, 56, 56)`):
955
- The dimensions of the axes to use in the RoPE layer.
956
- image_condition_type (`str`, *optional*, defaults to `None`):
957
- The type of image conditioning to use. If `None`, no image conditioning is used. If `latent_concat`, the
958
- image is concatenated to the latent stream. If `token_replace`, the image is used to replace first-frame
959
- tokens in the latent stream and apply conditioning.
960
- """
961
-
962
- _supports_gradient_checkpointing = True
963
- _skip_layerwise_casting_patterns = ["x_embedder", "context_embedder", "norm"]
964
- _no_split_modules = [
965
- "HunyuanVideoTransformerBlock",
966
- "HunyuanVideoSingleTransformerBlock",
967
- "HunyuanVideoPatchEmbed",
968
- "HunyuanVideoTokenRefiner",
969
- ]
970
- _repeated_blocks = [
971
- "HunyuanVideoTransformerBlock",
972
- "HunyuanVideoSingleTransformerBlock",
973
- "HunyuanVideoPatchEmbed",
974
- "HunyuanVideoTokenRefiner",
975
- ]
976
-
977
- @register_to_config
978
- def __init__(
979
- self,
980
- in_channels: int = 16,
981
- out_channels: int = 16,
982
- num_attention_heads: int = 24,
983
- attention_head_dim: int = 128,
984
- num_layers: int = 20,
985
- num_single_layers: int = 40,
986
- num_refiner_layers: int = 2,
987
- mlp_ratio: float = 4.0,
988
- patch_size: int = 2,
989
- patch_size_t: int = 1,
990
- qk_norm: str = "rms_norm",
991
- guidance_embeds: bool = True,
992
- text_embed_dim: int = 4096,
993
- pooled_projection_dim: int = 768,
994
- rope_theta: float = 256.0,
995
- rope_axes_dim: Tuple[int, ...] = (16, 56, 56),
996
- image_condition_type: Optional[str] = None,
997
- ) -> None:
998
- super().__init__()
999
-
1000
- supported_image_condition_types = ["latent_concat", "token_replace"]
1001
- if image_condition_type is not None and image_condition_type not in supported_image_condition_types:
1002
- raise ValueError(
1003
- f"Invalid `image_condition_type` ({image_condition_type}). Supported ones are: {supported_image_condition_types}"
1004
- )
1005
-
1006
- inner_dim = num_attention_heads * attention_head_dim
1007
- out_channels = out_channels or in_channels
1008
-
1009
- # 1. Latent and condition embedders
1010
- self.x_embedder = HunyuanVideoPatchEmbed((patch_size_t, patch_size, patch_size), in_channels, inner_dim)
1011
- self.context_embedder = HunyuanVideoTokenRefiner(
1012
- text_embed_dim, num_attention_heads, attention_head_dim, num_layers=num_refiner_layers
1013
- )
1014
-
1015
- self.time_text_embed = HunyuanVideoConditionEmbedding(
1016
- inner_dim, pooled_projection_dim, guidance_embeds, image_condition_type
1017
- )
1018
-
1019
- # 2. RoPE
1020
- self.rope = HunyuanVideoRotaryPosEmbed(patch_size, patch_size_t, rope_axes_dim, rope_theta)
1021
-
1022
- # 3. Dual stream transformer blocks
1023
- if image_condition_type == "token_replace":
1024
- self.transformer_blocks = nn.ModuleList(
1025
- [
1026
- HunyuanVideoTokenReplaceTransformerBlock(
1027
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1028
- )
1029
- for _ in range(num_layers)
1030
- ]
1031
- )
1032
- else:
1033
- self.transformer_blocks = nn.ModuleList(
1034
- [
1035
- HunyuanVideoTransformerBlock(
1036
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1037
- )
1038
- for _ in range(num_layers)
1039
- ]
1040
- )
1041
-
1042
- # 4. Single stream transformer blocks
1043
- if image_condition_type == "token_replace":
1044
- self.single_transformer_blocks = nn.ModuleList(
1045
- [
1046
- HunyuanVideoTokenReplaceSingleTransformerBlock(
1047
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1048
- )
1049
- for _ in range(num_single_layers)
1050
- ]
1051
- )
1052
- else:
1053
- self.single_transformer_blocks = nn.ModuleList(
1054
- [
1055
- HunyuanVideoSingleTransformerBlock(
1056
- num_attention_heads, attention_head_dim, mlp_ratio=mlp_ratio, qk_norm=qk_norm
1057
- )
1058
- for _ in range(num_single_layers)
1059
- ]
1060
- )
1061
-
1062
- # 5. Output projection
1063
- self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6)
1064
- self.proj_out = nn.Linear(inner_dim, patch_size_t * patch_size * patch_size * out_channels)
1065
-
1066
-
1067
- self.gradient_checkpointing = False
1068
- self.sp_world_size = 1
1069
- self.sp_world_rank = 0
1070
-
1071
- def _set_gradient_checkpointing(self, *args, **kwargs):
1072
- if "value" in kwargs:
1073
- self.gradient_checkpointing = kwargs["value"]
1074
- elif "enable" in kwargs:
1075
- self.gradient_checkpointing = kwargs["enable"]
1076
- else:
1077
- raise ValueError("Invalid set gradient checkpointing")
1078
-
1079
- def enable_multi_gpus_inference(self,):
1080
- self.sp_world_size = get_sequence_parallel_world_size()
1081
- self.sp_world_rank = get_sequence_parallel_rank()
1082
- self.set_attn_processor(HunyuanVideoMultiGPUsAttnProcessor2_0())
1083
-
1084
- @property
1085
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
1086
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
1087
- r"""
1088
- Returns:
1089
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
1090
- indexed by their weight names.
1091
- """
1092
- # set recursively
1093
- processors = {}
1094
-
1095
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
1096
- if hasattr(module, "get_processor"):
1097
- processors[f"{name}.processor"] = module.get_processor()
1098
-
1099
- for sub_name, child in module.named_children():
1100
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
1101
-
1102
- return processors
1103
-
1104
- for name, module in self.named_children():
1105
- fn_recursive_add_processors(name, module, processors)
1106
-
1107
- return processors
1108
-
1109
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
1110
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
1111
- r"""
1112
- Sets the attention processor to use to compute attention.
1113
-
1114
- Parameters:
1115
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
1116
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
1117
- for **all** `Attention` layers.
1118
-
1119
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
1120
- processor. This is strongly recommended when setting trainable attention processors.
1121
-
1122
- """
1123
- count = len(self.attn_processors.keys())
1124
-
1125
- if isinstance(processor, dict) and len(processor) != count:
1126
- raise ValueError(
1127
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
1128
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
1129
- )
1130
-
1131
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
1132
- if hasattr(module, "set_processor") and module.set_processor is not None:
1133
- if not isinstance(processor, dict):
1134
- module.set_processor(processor)
1135
- else:
1136
- module.set_processor(processor.pop(f"{name}.processor"))
1137
-
1138
- for sub_name, child in module.named_children():
1139
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
1140
-
1141
- for name, module in self.named_children():
1142
- fn_recursive_attn_processor(name, module, processor)
1143
-
1144
- def forward(
1145
- self,
1146
- hidden_states: torch.Tensor,
1147
- timestep: torch.LongTensor,
1148
- encoder_hidden_states: torch.Tensor,
1149
- encoder_attention_mask: torch.Tensor,
1150
- pooled_projections: torch.Tensor,
1151
- guidance: torch.Tensor = None,
1152
- attention_kwargs: Optional[Dict[str, Any]] = None,
1153
- return_dict: bool = True,
1154
- ) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
1155
- if attention_kwargs is not None:
1156
- attention_kwargs = attention_kwargs.copy()
1157
- lora_scale = attention_kwargs.pop("scale", 1.0)
1158
- else:
1159
- lora_scale = 1.0
1160
-
1161
- if USE_PEFT_BACKEND:
1162
- # weight the lora layers by setting `lora_scale` for each PEFT layer
1163
- scale_lora_layers(self, lora_scale)
1164
- else:
1165
- if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
1166
- logger.warning(
1167
- "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
1168
- )
1169
-
1170
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
1171
- p, p_t = self.config.patch_size, self.config.patch_size_t
1172
- post_patch_num_frames = num_frames // p_t
1173
- post_patch_height = height // p
1174
- post_patch_width = width // p
1175
- first_frame_num_tokens = 1 * post_patch_height * post_patch_width
1176
-
1177
- # 1. RoPE
1178
- image_rotary_emb = self.rope(hidden_states)
1179
-
1180
- # 2. Conditional embeddings
1181
- temb, token_replace_emb = self.time_text_embed(timestep, pooled_projections, guidance)
1182
-
1183
- hidden_states = self.x_embedder(hidden_states)
1184
- encoder_hidden_states = self.context_embedder(encoder_hidden_states, timestep, encoder_attention_mask)
1185
-
1186
- # 3. Attention mask preparation
1187
- latent_sequence_length = hidden_states.shape[1]
1188
- condition_sequence_length = encoder_hidden_states.shape[1]
1189
- sequence_length = latent_sequence_length + condition_sequence_length
1190
- attention_mask = torch.ones(
1191
- batch_size, sequence_length, device=hidden_states.device, dtype=torch.bool
1192
- ) # [B, N]
1193
- effective_condition_sequence_length = encoder_attention_mask.sum(dim=1, dtype=torch.int) # [B,]
1194
- effective_sequence_length = latent_sequence_length + effective_condition_sequence_length
1195
- indices = torch.arange(sequence_length, device=hidden_states.device).unsqueeze(0) # [1, N]
1196
- mask_indices = indices >= effective_sequence_length.unsqueeze(1) # [B, N]
1197
- attention_mask = attention_mask.masked_fill(mask_indices, False)
1198
- attention_mask = attention_mask.unsqueeze(1).unsqueeze(1) # [B, 1, 1, N]
1199
-
1200
- # Context Parallel
1201
- if self.sp_world_size > 1:
1202
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
1203
- if image_rotary_emb is not None:
1204
- image_rotary_emb = (
1205
- torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
1206
- torch.chunk(image_rotary_emb[1], self.sp_world_size, dim=0)[self.sp_world_rank]
1207
- )
1208
- if self.sp_world_rank >=1:
1209
- first_frame_num_tokens = 0
1210
-
1211
- # 4. Transformer blocks
1212
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1213
- for block in self.transformer_blocks:
1214
-
1215
- def create_custom_forward(module):
1216
- def custom_forward(*inputs):
1217
- return module(*inputs)
1218
-
1219
- return custom_forward
1220
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1221
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
1222
- create_custom_forward(block),
1223
- hidden_states,
1224
- encoder_hidden_states,
1225
- temb,
1226
- attention_mask,
1227
- image_rotary_emb,
1228
- token_replace_emb,
1229
- first_frame_num_tokens,
1230
- **ckpt_kwargs,
1231
- )
1232
-
1233
- for block in self.single_transformer_blocks:
1234
-
1235
- def create_custom_forward(module):
1236
- def custom_forward(*inputs):
1237
- return module(*inputs)
1238
-
1239
- return custom_forward
1240
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1241
- hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
1242
- create_custom_forward(block),
1243
- hidden_states,
1244
- encoder_hidden_states,
1245
- temb,
1246
- attention_mask,
1247
- image_rotary_emb,
1248
- token_replace_emb,
1249
- first_frame_num_tokens,
1250
- **ckpt_kwargs,
1251
- )
1252
-
1253
- else:
1254
- for block in self.transformer_blocks:
1255
- hidden_states, encoder_hidden_states = block(
1256
- hidden_states,
1257
- encoder_hidden_states,
1258
- temb,
1259
- attention_mask,
1260
- image_rotary_emb,
1261
- token_replace_emb,
1262
- first_frame_num_tokens,
1263
- )
1264
-
1265
- for block in self.single_transformer_blocks:
1266
- hidden_states, encoder_hidden_states = block(
1267
- hidden_states,
1268
- encoder_hidden_states,
1269
- temb,
1270
- attention_mask,
1271
- image_rotary_emb,
1272
- token_replace_emb,
1273
- first_frame_num_tokens,
1274
- )
1275
-
1276
- # 5. Output projection
1277
- hidden_states = self.norm_out(hidden_states, temb)
1278
- hidden_states = self.proj_out(hidden_states)
1279
-
1280
- if self.sp_world_size > 1:
1281
- hidden_states = get_sp_group().all_gather(hidden_states, dim=1)
1282
-
1283
- hidden_states = hidden_states.reshape(
1284
- batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1, p_t, p, p
1285
- )
1286
- hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
1287
- hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
1288
-
1289
- if USE_PEFT_BACKEND:
1290
- # remove `lora_scale` from each PEFT layer
1291
- unscale_lora_layers(self, lora_scale)
1292
-
1293
- if not return_dict:
1294
- return (hidden_states,)
1295
-
1296
- return Transformer2DModelOutput(sample=hidden_states)
1297
-
1298
-
1299
- @classmethod
1300
- def from_pretrained(
1301
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1302
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1303
- ):
1304
- if subfolder is not None:
1305
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1306
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1307
-
1308
- config_file = os.path.join(pretrained_model_path, 'config.json')
1309
- if not os.path.isfile(config_file):
1310
- raise RuntimeError(f"{config_file} does not exist")
1311
- with open(config_file, "r") as f:
1312
- config = json.load(f)
1313
-
1314
- from diffusers.utils import WEIGHTS_NAME
1315
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1316
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
1317
-
1318
- if "dict_mapping" in transformer_additional_kwargs.keys():
1319
- for key in transformer_additional_kwargs["dict_mapping"]:
1320
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1321
-
1322
- if low_cpu_mem_usage:
1323
- try:
1324
- import re
1325
-
1326
- from diffusers import __version__ as diffusers_version
1327
- if diffusers_version >= "0.33.0":
1328
- from diffusers.models.model_loading_utils import \
1329
- load_model_dict_into_meta
1330
- else:
1331
- from diffusers.models.modeling_utils import \
1332
- load_model_dict_into_meta
1333
- from diffusers.utils import is_accelerate_available
1334
- if is_accelerate_available():
1335
- import accelerate
1336
-
1337
- # Instantiate model with empty weights
1338
- with accelerate.init_empty_weights():
1339
- model = cls.from_config(config, **transformer_additional_kwargs)
1340
-
1341
- param_device = "cpu"
1342
- if os.path.exists(model_file):
1343
- state_dict = torch.load(model_file, map_location="cpu")
1344
- elif os.path.exists(model_file_safetensors):
1345
- from safetensors.torch import load_file, safe_open
1346
- state_dict = load_file(model_file_safetensors)
1347
- else:
1348
- from safetensors.torch import load_file, safe_open
1349
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1350
- state_dict = {}
1351
- print(model_files_safetensors)
1352
- for _model_file_safetensors in model_files_safetensors:
1353
- _state_dict = load_file(_model_file_safetensors)
1354
- for key in _state_dict:
1355
- state_dict[key] = _state_dict[key]
1356
-
1357
- filtered_state_dict = {}
1358
- for key in state_dict:
1359
- if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1360
- filtered_state_dict[key] = state_dict[key]
1361
- else:
1362
- print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1363
-
1364
- model_keys = set(model.state_dict().keys())
1365
- loaded_keys = set(filtered_state_dict.keys())
1366
- missing_keys = model_keys - loaded_keys
1367
-
1368
- def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1369
- initialized_dict = {}
1370
-
1371
- with torch.no_grad():
1372
- for key in missing_keys:
1373
- param_shape = model_state_dict[key].shape
1374
- param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1375
- if 'weight' in key:
1376
- if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1377
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1378
- elif 'embedding' in key or 'embed' in key:
1379
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1380
- elif 'head' in key or 'output' in key or 'proj_out' in key:
1381
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1382
- elif len(param_shape) >= 2:
1383
- initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1384
- nn.init.xavier_uniform_(initialized_dict[key])
1385
- else:
1386
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1387
- elif 'bias' in key:
1388
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1389
- elif 'running_mean' in key:
1390
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1391
- elif 'running_var' in key:
1392
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1393
- elif 'num_batches_tracked' in key:
1394
- initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1395
- else:
1396
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1397
-
1398
- return initialized_dict
1399
-
1400
- if missing_keys:
1401
- print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1402
- initialized_params = initialize_missing_parameters(
1403
- missing_keys,
1404
- model.state_dict(),
1405
- torch_dtype
1406
- )
1407
- filtered_state_dict.update(initialized_params)
1408
-
1409
- if diffusers_version >= "0.33.0":
1410
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1411
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1412
- load_model_dict_into_meta(
1413
- model,
1414
- filtered_state_dict,
1415
- dtype=torch_dtype,
1416
- model_name_or_path=pretrained_model_path,
1417
- )
1418
- else:
1419
- model._convert_deprecated_attention_blocks(filtered_state_dict)
1420
- unexpected_keys = load_model_dict_into_meta(
1421
- model,
1422
- filtered_state_dict,
1423
- device=param_device,
1424
- dtype=torch_dtype,
1425
- model_name_or_path=pretrained_model_path,
1426
- )
1427
-
1428
- if cls._keys_to_ignore_on_load_unexpected is not None:
1429
- for pat in cls._keys_to_ignore_on_load_unexpected:
1430
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1431
-
1432
- if len(unexpected_keys) > 0:
1433
- print(
1434
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {', '.join(unexpected_keys)}"
1435
- )
1436
-
1437
- return model
1438
- except Exception as e:
1439
- print(
1440
- f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
1441
- )
1442
-
1443
- model = cls.from_config(config, **transformer_additional_kwargs)
1444
- if os.path.exists(model_file):
1445
- state_dict = torch.load(model_file, map_location="cpu")
1446
- elif os.path.exists(model_file_safetensors):
1447
- from safetensors.torch import load_file, safe_open
1448
- state_dict = load_file(model_file_safetensors)
1449
- else:
1450
- from safetensors.torch import load_file, safe_open
1451
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1452
- state_dict = {}
1453
- for _model_file_safetensors in model_files_safetensors:
1454
- _state_dict = load_file(_model_file_safetensors)
1455
- for key in _state_dict:
1456
- state_dict[key] = _state_dict[key]
1457
-
1458
- tmp_state_dict = {}
1459
- for key in state_dict:
1460
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1461
- tmp_state_dict[key] = state_dict[key]
1462
- else:
1463
- print(key, "Size doesn't match, skip")
1464
-
1465
- state_dict = tmp_state_dict
1466
-
1467
- m, u = model.load_state_dict(state_dict, strict=False)
1468
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1469
- print(m)
1470
-
1471
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1472
- print(f"### All Parameters: {sum(params) / 1e6} M")
1473
-
1474
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1475
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1476
-
1477
- model = model.to(torch_dtype)
1478
- return model
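
The `from_pretrained` classmethod removed above loads either a single `diffusion_pytorch_model` weights file or a directory of sharded `*.safetensors`, drops keys whose shapes do not match the instantiated config, heuristically initializes any keys that are still missing, and can build the model on the meta device first when `low_cpu_mem_usage=True`. A minimal usage sketch follows; the class name, import path, checkpoint directory, and subfolder are placeholders assumed for illustration, not values taken from this repository.

# Hypothetical sketch of calling the removed from_pretrained classmethod.
# The class name, import path, and checkpoint directory are assumed placeholders.
import torch
# from videox_fun.models import HunyuanVideoTransformer3DModel  # import path assumed

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    "models/HunyuanVideo",              # placeholder directory containing config.json and weight shards
    subfolder="transformer",            # joined onto the path before the config is read
    transformer_additional_kwargs={},   # overrides merged into the loaded config (supports "dict_mapping")
    low_cpu_mem_usage=True,             # instantiate on the meta device, then load weights into it
    torch_dtype=torch.bfloat16,         # dtype the loaded weights are cast to
)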
 
videox_fun/models/hunyuanvideo_vae.py DELETED
@@ -1,1082 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
2
- # Copyright 2025 The Hunyuan Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
- from typing import Optional, Tuple, Union
17
-
18
- import numpy as np
19
- import torch
20
- import torch.nn as nn
21
- import torch.nn.functional as F
22
- from diffusers.configuration_utils import ConfigMixin, register_to_config
23
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
24
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
25
- from diffusers.models.activations import get_activation
26
- from diffusers.models.attention import FeedForward
27
- from diffusers.models.attention_processor import Attention
28
- from diffusers.models.autoencoders.vae import (DecoderOutput,
29
- DiagonalGaussianDistribution)
30
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
31
- from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
32
- Transformer2DModelOutput)
33
- from diffusers.models.modeling_utils import ModelMixin
34
- from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
35
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
36
- scale_lora_layers, unscale_lora_layers)
37
- from diffusers.utils.accelerate_utils import apply_forward_hook
38
- from diffusers.utils.torch_utils import maybe_allow_in_graph
39
-
40
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
41
-
42
-
43
- def prepare_causal_attention_mask(
44
- num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
45
- ) -> torch.Tensor:
46
- indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
47
- indices_blocks = indices.repeat_interleave(height_width)
48
- x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
49
- mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
50
-
51
- if batch_size is not None:
52
- mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
53
- return mask
54
-
55
-
56
- class HunyuanVideoCausalConv3d(nn.Module):
57
- def __init__(
58
- self,
59
- in_channels: int,
60
- out_channels: int,
61
- kernel_size: Union[int, Tuple[int, int, int]] = 3,
62
- stride: Union[int, Tuple[int, int, int]] = 1,
63
- padding: Union[int, Tuple[int, int, int]] = 0,
64
- dilation: Union[int, Tuple[int, int, int]] = 1,
65
- bias: bool = True,
66
- pad_mode: str = "replicate",
67
- ) -> None:
68
- super().__init__()
69
-
70
- kernel_size = (kernel_size, kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
71
-
72
- self.pad_mode = pad_mode
73
- self.time_causal_padding = (
74
- kernel_size[0] // 2,
75
- kernel_size[0] // 2,
76
- kernel_size[1] // 2,
77
- kernel_size[1] // 2,
78
- kernel_size[2] - 1,
79
- 0,
80
- )
81
-
82
- self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
83
-
84
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
85
- hidden_states = F.pad(hidden_states, self.time_causal_padding, mode=self.pad_mode)
86
- return self.conv(hidden_states)
87
-
88
-
89
- class HunyuanVideoUpsampleCausal3D(nn.Module):
90
- def __init__(
91
- self,
92
- in_channels: int,
93
- out_channels: Optional[int] = None,
94
- kernel_size: int = 3,
95
- stride: int = 1,
96
- bias: bool = True,
97
- upsample_factor: Tuple[float, float, float] = (2, 2, 2),
98
- ) -> None:
99
- super().__init__()
100
-
101
- out_channels = out_channels or in_channels
102
- self.upsample_factor = upsample_factor
103
-
104
- self.conv = HunyuanVideoCausalConv3d(in_channels, out_channels, kernel_size, stride, bias=bias)
105
-
106
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
107
- num_frames = hidden_states.size(2)
108
-
109
- first_frame, other_frames = hidden_states.split((1, num_frames - 1), dim=2)
110
- first_frame = F.interpolate(
111
- first_frame.squeeze(2), scale_factor=self.upsample_factor[1:], mode="nearest"
112
- ).unsqueeze(2)
113
-
114
- if num_frames > 1:
115
- # See: https://github.com/pytorch/pytorch/issues/81665
116
- # Unless you have a version of pytorch where non-contiguous implementation of F.interpolate
117
- # is fixed, this will raise either a runtime error, or fail silently with bad outputs.
118
- # If you are encountering an error here, make sure to try running encoding/decoding with
119
- # `vae.enable_tiling()` first. If that doesn't work, open an issue at:
120
- # https://github.com/huggingface/diffusers/issues
121
- other_frames = other_frames.contiguous()
122
- other_frames = F.interpolate(other_frames, scale_factor=self.upsample_factor, mode="nearest")
123
- hidden_states = torch.cat((first_frame, other_frames), dim=2)
124
- else:
125
- hidden_states = first_frame
126
-
127
- hidden_states = self.conv(hidden_states)
128
- return hidden_states
129
-
130
-
131
- class HunyuanVideoDownsampleCausal3D(nn.Module):
132
- def __init__(
133
- self,
134
- channels: int,
135
- out_channels: Optional[int] = None,
136
- padding: int = 1,
137
- kernel_size: int = 3,
138
- bias: bool = True,
139
- stride=2,
140
- ) -> None:
141
- super().__init__()
142
- out_channels = out_channels or channels
143
-
144
- self.conv = HunyuanVideoCausalConv3d(channels, out_channels, kernel_size, stride, padding, bias=bias)
145
-
146
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
147
- hidden_states = self.conv(hidden_states)
148
- return hidden_states
149
-
150
-
151
- class HunyuanVideoResnetBlockCausal3D(nn.Module):
152
- def __init__(
153
- self,
154
- in_channels: int,
155
- out_channels: Optional[int] = None,
156
- dropout: float = 0.0,
157
- groups: int = 32,
158
- eps: float = 1e-6,
159
- non_linearity: str = "swish",
160
- ) -> None:
161
- super().__init__()
162
- out_channels = out_channels or in_channels
163
-
164
- self.nonlinearity = get_activation(non_linearity)
165
-
166
- self.norm1 = nn.GroupNorm(groups, in_channels, eps=eps, affine=True)
167
- self.conv1 = HunyuanVideoCausalConv3d(in_channels, out_channels, 3, 1, 0)
168
-
169
- self.norm2 = nn.GroupNorm(groups, out_channels, eps=eps, affine=True)
170
- self.dropout = nn.Dropout(dropout)
171
- self.conv2 = HunyuanVideoCausalConv3d(out_channels, out_channels, 3, 1, 0)
172
-
173
- self.conv_shortcut = None
174
- if in_channels != out_channels:
175
- self.conv_shortcut = HunyuanVideoCausalConv3d(in_channels, out_channels, 1, 1, 0)
176
-
177
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
178
- hidden_states = hidden_states.contiguous()
179
- residual = hidden_states
180
-
181
- hidden_states = self.norm1(hidden_states)
182
- hidden_states = self.nonlinearity(hidden_states)
183
- hidden_states = self.conv1(hidden_states)
184
-
185
- hidden_states = self.norm2(hidden_states)
186
- hidden_states = self.nonlinearity(hidden_states)
187
- hidden_states = self.dropout(hidden_states)
188
- hidden_states = self.conv2(hidden_states)
189
-
190
- if self.conv_shortcut is not None:
191
- residual = self.conv_shortcut(residual)
192
-
193
- hidden_states = hidden_states + residual
194
- return hidden_states
195
-
196
-
197
- class HunyuanVideoMidBlock3D(nn.Module):
198
- def __init__(
199
- self,
200
- in_channels: int,
201
- dropout: float = 0.0,
202
- num_layers: int = 1,
203
- resnet_eps: float = 1e-6,
204
- resnet_act_fn: str = "swish",
205
- resnet_groups: int = 32,
206
- add_attention: bool = True,
207
- attention_head_dim: int = 1,
208
- ) -> None:
209
- super().__init__()
210
- resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
211
- self.add_attention = add_attention
212
-
213
- # There is always at least one resnet
214
- resnets = [
215
- HunyuanVideoResnetBlockCausal3D(
216
- in_channels=in_channels,
217
- out_channels=in_channels,
218
- eps=resnet_eps,
219
- groups=resnet_groups,
220
- dropout=dropout,
221
- non_linearity=resnet_act_fn,
222
- )
223
- ]
224
- attentions = []
225
-
226
- for _ in range(num_layers):
227
- if self.add_attention:
228
- attentions.append(
229
- Attention(
230
- in_channels,
231
- heads=in_channels // attention_head_dim,
232
- dim_head=attention_head_dim,
233
- eps=resnet_eps,
234
- norm_num_groups=resnet_groups,
235
- residual_connection=True,
236
- bias=True,
237
- upcast_softmax=True,
238
- _from_deprecated_attn_block=True,
239
- )
240
- )
241
- else:
242
- attentions.append(None)
243
-
244
- resnets.append(
245
- HunyuanVideoResnetBlockCausal3D(
246
- in_channels=in_channels,
247
- out_channels=in_channels,
248
- eps=resnet_eps,
249
- groups=resnet_groups,
250
- dropout=dropout,
251
- non_linearity=resnet_act_fn,
252
- )
253
- )
254
-
255
- self.attentions = nn.ModuleList(attentions)
256
- self.resnets = nn.ModuleList(resnets)
257
-
258
- self.gradient_checkpointing = False
259
-
260
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
261
- if torch.is_grad_enabled() and self.gradient_checkpointing:
262
- hidden_states = self._gradient_checkpointing_func(self.resnets[0], hidden_states)
263
-
264
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
265
- if attn is not None:
266
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
267
- hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
268
- attention_mask = prepare_causal_attention_mask(
269
- num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
270
- )
271
- hidden_states = attn(hidden_states, attention_mask=attention_mask)
272
- hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
273
-
274
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
275
-
276
- else:
277
- hidden_states = self.resnets[0](hidden_states)
278
-
279
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
280
- if attn is not None:
281
- batch_size, num_channels, num_frames, height, width = hidden_states.shape
282
- hidden_states = hidden_states.permute(0, 2, 3, 4, 1).flatten(1, 3)
283
- attention_mask = prepare_causal_attention_mask(
284
- num_frames, height * width, hidden_states.dtype, hidden_states.device, batch_size=batch_size
285
- )
286
- hidden_states = attn(hidden_states, attention_mask=attention_mask)
287
- hidden_states = hidden_states.unflatten(1, (num_frames, height, width)).permute(0, 4, 1, 2, 3)
288
-
289
- hidden_states = resnet(hidden_states)
290
-
291
- return hidden_states
292
-
293
-
294
- class HunyuanVideoDownBlock3D(nn.Module):
295
- def __init__(
296
- self,
297
- in_channels: int,
298
- out_channels: int,
299
- dropout: float = 0.0,
300
- num_layers: int = 1,
301
- resnet_eps: float = 1e-6,
302
- resnet_act_fn: str = "swish",
303
- resnet_groups: int = 32,
304
- add_downsample: bool = True,
305
- downsample_stride: int = 2,
306
- downsample_padding: int = 1,
307
- ) -> None:
308
- super().__init__()
309
- resnets = []
310
-
311
- for i in range(num_layers):
312
- in_channels = in_channels if i == 0 else out_channels
313
- resnets.append(
314
- HunyuanVideoResnetBlockCausal3D(
315
- in_channels=in_channels,
316
- out_channels=out_channels,
317
- eps=resnet_eps,
318
- groups=resnet_groups,
319
- dropout=dropout,
320
- non_linearity=resnet_act_fn,
321
- )
322
- )
323
-
324
- self.resnets = nn.ModuleList(resnets)
325
-
326
- if add_downsample:
327
- self.downsamplers = nn.ModuleList(
328
- [
329
- HunyuanVideoDownsampleCausal3D(
330
- out_channels,
331
- out_channels=out_channels,
332
- padding=downsample_padding,
333
- stride=downsample_stride,
334
- )
335
- ]
336
- )
337
- else:
338
- self.downsamplers = None
339
-
340
- self.gradient_checkpointing = False
341
-
342
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
343
- if torch.is_grad_enabled() and self.gradient_checkpointing:
344
- for resnet in self.resnets:
345
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
346
- else:
347
- for resnet in self.resnets:
348
- hidden_states = resnet(hidden_states)
349
-
350
- if self.downsamplers is not None:
351
- for downsampler in self.downsamplers:
352
- hidden_states = downsampler(hidden_states)
353
-
354
- return hidden_states
355
-
356
-
357
- class HunyuanVideoUpBlock3D(nn.Module):
358
- def __init__(
359
- self,
360
- in_channels: int,
361
- out_channels: int,
362
- dropout: float = 0.0,
363
- num_layers: int = 1,
364
- resnet_eps: float = 1e-6,
365
- resnet_act_fn: str = "swish",
366
- resnet_groups: int = 32,
367
- add_upsample: bool = True,
368
- upsample_scale_factor: Tuple[int, int, int] = (2, 2, 2),
369
- ) -> None:
370
- super().__init__()
371
- resnets = []
372
-
373
- for i in range(num_layers):
374
- input_channels = in_channels if i == 0 else out_channels
375
-
376
- resnets.append(
377
- HunyuanVideoResnetBlockCausal3D(
378
- in_channels=input_channels,
379
- out_channels=out_channels,
380
- eps=resnet_eps,
381
- groups=resnet_groups,
382
- dropout=dropout,
383
- non_linearity=resnet_act_fn,
384
- )
385
- )
386
-
387
- self.resnets = nn.ModuleList(resnets)
388
-
389
- if add_upsample:
390
- self.upsamplers = nn.ModuleList(
391
- [
392
- HunyuanVideoUpsampleCausal3D(
393
- out_channels,
394
- out_channels=out_channels,
395
- upsample_factor=upsample_scale_factor,
396
- )
397
- ]
398
- )
399
- else:
400
- self.upsamplers = None
401
-
402
- self.gradient_checkpointing = False
403
-
404
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
405
- if torch.is_grad_enabled() and self.gradient_checkpointing:
406
- for resnet in self.resnets:
407
- hidden_states = self._gradient_checkpointing_func(resnet, hidden_states)
408
-
409
- else:
410
- for resnet in self.resnets:
411
- hidden_states = resnet(hidden_states)
412
-
413
- if self.upsamplers is not None:
414
- for upsampler in self.upsamplers:
415
- hidden_states = upsampler(hidden_states)
416
-
417
- return hidden_states
418
-
419
-
420
- class HunyuanVideoEncoder3D(nn.Module):
421
- r"""
422
- Causal encoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
423
- """
424
-
425
- def __init__(
426
- self,
427
- in_channels: int = 3,
428
- out_channels: int = 3,
429
- down_block_types: Tuple[str, ...] = (
430
- "HunyuanVideoDownBlock3D",
431
- "HunyuanVideoDownBlock3D",
432
- "HunyuanVideoDownBlock3D",
433
- "HunyuanVideoDownBlock3D",
434
- ),
435
- block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
436
- layers_per_block: int = 2,
437
- norm_num_groups: int = 32,
438
- act_fn: str = "silu",
439
- double_z: bool = True,
440
- mid_block_add_attention=True,
441
- temporal_compression_ratio: int = 4,
442
- spatial_compression_ratio: int = 8,
443
- ) -> None:
444
- super().__init__()
445
-
446
- self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
447
- self.mid_block = None
448
- self.down_blocks = nn.ModuleList([])
449
-
450
- output_channel = block_out_channels[0]
451
- for i, down_block_type in enumerate(down_block_types):
452
- if down_block_type != "HunyuanVideoDownBlock3D":
453
- raise ValueError(f"Unsupported down_block_type: {down_block_type}")
454
-
455
- input_channel = output_channel
456
- output_channel = block_out_channels[i]
457
- is_final_block = i == len(block_out_channels) - 1
458
- num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
459
- num_time_downsample_layers = int(np.log2(temporal_compression_ratio))
460
-
461
- if temporal_compression_ratio == 4:
462
- add_spatial_downsample = bool(i < num_spatial_downsample_layers)
463
- add_time_downsample = bool(
464
- i >= (len(block_out_channels) - 1 - num_time_downsample_layers) and not is_final_block
465
- )
466
- elif temporal_compression_ratio == 8:
467
- add_spatial_downsample = bool(i < num_spatial_downsample_layers)
468
- add_time_downsample = bool(i < num_time_downsample_layers)
469
- else:
470
- raise ValueError(f"Unsupported time_compression_ratio: {temporal_compression_ratio}")
471
-
472
- downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
473
- downsample_stride_T = (2,) if add_time_downsample else (1,)
474
- downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
475
-
476
- down_block = HunyuanVideoDownBlock3D(
477
- num_layers=layers_per_block,
478
- in_channels=input_channel,
479
- out_channels=output_channel,
480
- add_downsample=bool(add_spatial_downsample or add_time_downsample),
481
- resnet_eps=1e-6,
482
- resnet_act_fn=act_fn,
483
- resnet_groups=norm_num_groups,
484
- downsample_stride=downsample_stride,
485
- downsample_padding=0,
486
- )
487
-
488
- self.down_blocks.append(down_block)
489
-
490
- self.mid_block = HunyuanVideoMidBlock3D(
491
- in_channels=block_out_channels[-1],
492
- resnet_eps=1e-6,
493
- resnet_act_fn=act_fn,
494
- attention_head_dim=block_out_channels[-1],
495
- resnet_groups=norm_num_groups,
496
- add_attention=mid_block_add_attention,
497
- )
498
-
499
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
500
- self.conv_act = nn.SiLU()
501
-
502
- conv_out_channels = 2 * out_channels if double_z else out_channels
503
- self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
504
-
505
- self.gradient_checkpointing = False
506
-
507
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
508
- hidden_states = self.conv_in(hidden_states)
509
-
510
- if torch.is_grad_enabled() and self.gradient_checkpointing:
511
- for down_block in self.down_blocks:
512
- hidden_states = self._gradient_checkpointing_func(down_block, hidden_states)
513
-
514
- hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
515
- else:
516
- for down_block in self.down_blocks:
517
- hidden_states = down_block(hidden_states)
518
-
519
- hidden_states = self.mid_block(hidden_states)
520
-
521
- hidden_states = self.conv_norm_out(hidden_states)
522
- hidden_states = self.conv_act(hidden_states)
523
- hidden_states = self.conv_out(hidden_states)
524
-
525
- return hidden_states
526
-
527
-
528
- class HunyuanVideoDecoder3D(nn.Module):
529
- r"""
530
- Causal decoder for 3D video-like data introduced in [Hunyuan Video](https://huggingface.co/papers/2412.03603).
531
- """
532
-
533
- def __init__(
534
- self,
535
- in_channels: int = 3,
536
- out_channels: int = 3,
537
- up_block_types: Tuple[str, ...] = (
538
- "HunyuanVideoUpBlock3D",
539
- "HunyuanVideoUpBlock3D",
540
- "HunyuanVideoUpBlock3D",
541
- "HunyuanVideoUpBlock3D",
542
- ),
543
- block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
544
- layers_per_block: int = 2,
545
- norm_num_groups: int = 32,
546
- act_fn: str = "silu",
547
- mid_block_add_attention=True,
548
- time_compression_ratio: int = 4,
549
- spatial_compression_ratio: int = 8,
550
- ):
551
- super().__init__()
552
- self.layers_per_block = layers_per_block
553
-
554
- self.conv_in = HunyuanVideoCausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
555
- self.up_blocks = nn.ModuleList([])
556
-
557
- # mid
558
- self.mid_block = HunyuanVideoMidBlock3D(
559
- in_channels=block_out_channels[-1],
560
- resnet_eps=1e-6,
561
- resnet_act_fn=act_fn,
562
- attention_head_dim=block_out_channels[-1],
563
- resnet_groups=norm_num_groups,
564
- add_attention=mid_block_add_attention,
565
- )
566
-
567
- # up
568
- reversed_block_out_channels = list(reversed(block_out_channels))
569
- output_channel = reversed_block_out_channels[0]
570
- for i, up_block_type in enumerate(up_block_types):
571
- if up_block_type != "HunyuanVideoUpBlock3D":
572
- raise ValueError(f"Unsupported up_block_type: {up_block_type}")
573
-
574
- prev_output_channel = output_channel
575
- output_channel = reversed_block_out_channels[i]
576
- is_final_block = i == len(block_out_channels) - 1
577
- num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
578
- num_time_upsample_layers = int(np.log2(time_compression_ratio))
579
-
580
- if time_compression_ratio == 4:
581
- add_spatial_upsample = bool(i < num_spatial_upsample_layers)
582
- add_time_upsample = bool(
583
- i >= len(block_out_channels) - 1 - num_time_upsample_layers and not is_final_block
584
- )
585
- else:
586
- raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}")
587
-
588
- upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
589
- upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
590
- upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
591
-
592
- up_block = HunyuanVideoUpBlock3D(
593
- num_layers=self.layers_per_block + 1,
594
- in_channels=prev_output_channel,
595
- out_channels=output_channel,
596
- add_upsample=bool(add_spatial_upsample or add_time_upsample),
597
- upsample_scale_factor=upsample_scale_factor,
598
- resnet_eps=1e-6,
599
- resnet_act_fn=act_fn,
600
- resnet_groups=norm_num_groups,
601
- )
602
-
603
- self.up_blocks.append(up_block)
604
- prev_output_channel = output_channel
605
-
606
- # out
607
- self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
608
- self.conv_act = nn.SiLU()
609
- self.conv_out = HunyuanVideoCausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
610
-
611
- self.gradient_checkpointing = False
612
-
613
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
614
- hidden_states = self.conv_in(hidden_states)
615
-
616
- if torch.is_grad_enabled() and self.gradient_checkpointing:
617
- hidden_states = self._gradient_checkpointing_func(self.mid_block, hidden_states)
618
-
619
- for up_block in self.up_blocks:
620
- hidden_states = self._gradient_checkpointing_func(up_block, hidden_states)
621
- else:
622
- hidden_states = self.mid_block(hidden_states)
623
-
624
- for up_block in self.up_blocks:
625
- hidden_states = up_block(hidden_states)
626
-
627
- # post-process
628
- hidden_states = self.conv_norm_out(hidden_states)
629
- hidden_states = self.conv_act(hidden_states)
630
- hidden_states = self.conv_out(hidden_states)
631
-
632
- return hidden_states
633
-
634
-
635
- class AutoencoderKLHunyuanVideo(ModelMixin, ConfigMixin, FromOriginalModelMixin):
636
- r"""
637
- A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
638
- Introduced in [HunyuanVideo](https://huggingface.co/papers/2412.03603).
639
-
640
- This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
641
- for all models (such as downloading or saving).
642
- """
643
-
644
- _supports_gradient_checkpointing = True
645
-
646
- @register_to_config
647
- def __init__(
648
- self,
649
- in_channels: int = 3,
650
- out_channels: int = 3,
651
- latent_channels: int = 16,
652
- down_block_types: Tuple[str, ...] = (
653
- "HunyuanVideoDownBlock3D",
654
- "HunyuanVideoDownBlock3D",
655
- "HunyuanVideoDownBlock3D",
656
- "HunyuanVideoDownBlock3D",
657
- ),
658
- up_block_types: Tuple[str, ...] = (
659
- "HunyuanVideoUpBlock3D",
660
- "HunyuanVideoUpBlock3D",
661
- "HunyuanVideoUpBlock3D",
662
- "HunyuanVideoUpBlock3D",
663
- ),
664
- block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
665
- layers_per_block: int = 2,
666
- act_fn: str = "silu",
667
- norm_num_groups: int = 32,
668
- scaling_factor: float = 0.476986,
669
- spatial_compression_ratio: int = 8,
670
- temporal_compression_ratio: int = 4,
671
- mid_block_add_attention: bool = True,
672
- ) -> None:
673
- super().__init__()
674
-
675
- self.time_compression_ratio = temporal_compression_ratio
676
-
677
- self.encoder = HunyuanVideoEncoder3D(
678
- in_channels=in_channels,
679
- out_channels=latent_channels,
680
- down_block_types=down_block_types,
681
- block_out_channels=block_out_channels,
682
- layers_per_block=layers_per_block,
683
- norm_num_groups=norm_num_groups,
684
- act_fn=act_fn,
685
- double_z=True,
686
- mid_block_add_attention=mid_block_add_attention,
687
- temporal_compression_ratio=temporal_compression_ratio,
688
- spatial_compression_ratio=spatial_compression_ratio,
689
- )
690
-
691
- self.decoder = HunyuanVideoDecoder3D(
692
- in_channels=latent_channels,
693
- out_channels=out_channels,
694
- up_block_types=up_block_types,
695
- block_out_channels=block_out_channels,
696
- layers_per_block=layers_per_block,
697
- norm_num_groups=norm_num_groups,
698
- act_fn=act_fn,
699
- time_compression_ratio=temporal_compression_ratio,
700
- spatial_compression_ratio=spatial_compression_ratio,
701
- mid_block_add_attention=mid_block_add_attention,
702
- )
703
-
704
- self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
705
- self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
706
-
707
- self.spatial_compression_ratio = spatial_compression_ratio
708
- self.temporal_compression_ratio = temporal_compression_ratio
709
-
710
- # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
711
- # to perform decoding of a single video latent at a time.
712
- self.use_slicing = False
713
-
714
- # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
715
- # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
716
- # intermediate tiles together, the memory requirement can be lowered.
717
- self.use_tiling = True
718
-
719
- # When decoding temporally long video latents, the memory requirement is very high. By decoding latent frames
720
- # at a fixed frame batch size (based on `self.tile_sample_min_num_frames`), the memory requirement can be lowered.
721
- self.use_framewise_encoding = True
722
- self.use_framewise_decoding = True
723
-
724
- # The minimal tile height and width for spatial tiling to be used
725
- self.tile_sample_min_height = 256
726
- self.tile_sample_min_width = 256
727
- self.tile_sample_min_num_frames = 16
728
-
729
- # The minimal distance between two spatial tiles
730
- self.tile_sample_stride_height = 192
731
- self.tile_sample_stride_width = 192
732
- self.tile_sample_stride_num_frames = 12
733
-
734
- def enable_tiling(
735
- self,
736
- tile_sample_min_height: Optional[int] = None,
737
- tile_sample_min_width: Optional[int] = None,
738
- tile_sample_min_num_frames: Optional[int] = None,
739
- tile_sample_stride_height: Optional[float] = None,
740
- tile_sample_stride_width: Optional[float] = None,
741
- tile_sample_stride_num_frames: Optional[float] = None,
742
- ) -> None:
743
- r"""
744
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
745
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
746
- processing larger images.
747
-
748
- Args:
749
- tile_sample_min_height (`int`, *optional*):
750
- The minimum height required for a sample to be separated into tiles across the height dimension.
751
- tile_sample_min_width (`int`, *optional*):
752
- The minimum width required for a sample to be separated into tiles across the width dimension.
753
- tile_sample_min_num_frames (`int`, *optional*):
754
- The minimum number of frames required for a sample to be separated into tiles across the frame
755
- dimension.
756
- tile_sample_stride_height (`int`, *optional*):
757
- The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
758
- no tiling artifacts produced across the height dimension.
759
- tile_sample_stride_width (`int`, *optional*):
760
- The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
761
- artifacts produced across the width dimension.
762
- tile_sample_stride_num_frames (`int`, *optional*):
763
- The stride between two consecutive frame tiles. This is to ensure that there are no tiling artifacts
764
- produced across the frame dimension.
765
- """
766
- self.use_tiling = True
767
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
768
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
769
- self.tile_sample_min_num_frames = tile_sample_min_num_frames or self.tile_sample_min_num_frames
770
- self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
771
- self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
772
- self.tile_sample_stride_num_frames = tile_sample_stride_num_frames or self.tile_sample_stride_num_frames
773
-
774
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
775
- batch_size, num_channels, num_frames, height, width = x.shape
776
-
777
- if self.use_framewise_encoding and num_frames > self.tile_sample_min_num_frames:
778
- return self._temporal_tiled_encode(x)
779
-
780
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
781
- return self.tiled_encode(x)
782
-
783
- x = self.encoder(x)
784
- enc = self.quant_conv(x)
785
- return enc
786
-
787
- @apply_forward_hook
788
- def encode(
789
- self, x: torch.Tensor, return_dict: bool = True
790
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
791
- r"""
792
- Encode a batch of images into latents.
793
-
794
- Args:
795
- x (`torch.Tensor`): Input batch of images.
796
- return_dict (`bool`, *optional*, defaults to `True`):
797
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
798
-
799
- Returns:
800
- The latent representations of the encoded videos. If `return_dict` is True, a
801
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
802
- """
803
- if self.use_slicing and x.shape[0] > 1:
804
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
805
- h = torch.cat(encoded_slices)
806
- else:
807
- h = self._encode(x)
808
-
809
- posterior = DiagonalGaussianDistribution(h)
810
-
811
- if not return_dict:
812
- return (posterior,)
813
- return AutoencoderKLOutput(latent_dist=posterior)
814
-
815
- def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
816
- batch_size, num_channels, num_frames, height, width = z.shape
817
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
818
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
819
- tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
820
-
821
- if self.use_framewise_decoding and num_frames > tile_latent_min_num_frames:
822
- return self._temporal_tiled_decode(z, return_dict=return_dict)
823
-
824
- if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
825
- return self.tiled_decode(z, return_dict=return_dict)
826
-
827
- z = self.post_quant_conv(z)
828
- dec = self.decoder(z)
829
-
830
- if not return_dict:
831
- return (dec,)
832
-
833
- return DecoderOutput(sample=dec)
834
-
835
- @apply_forward_hook
836
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
837
- r"""
838
- Decode a batch of images.
839
-
840
- Args:
841
- z (`torch.Tensor`): Input batch of latent vectors.
842
- return_dict (`bool`, *optional*, defaults to `True`):
843
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
844
-
845
- Returns:
846
- [`~models.vae.DecoderOutput`] or `tuple`:
847
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
848
- returned.
849
- """
850
- if self.use_slicing and z.shape[0] > 1:
851
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
852
- decoded = torch.cat(decoded_slices)
853
- else:
854
- decoded = self._decode(z).sample
855
-
856
- if not return_dict:
857
- return (decoded,)
858
-
859
- return DecoderOutput(sample=decoded)
860
-
861
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
862
- blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
863
- for y in range(blend_extent):
864
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
865
- y / blend_extent
866
- )
867
- return b
868
-
869
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
870
- blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
871
- for x in range(blend_extent):
872
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
873
- x / blend_extent
874
- )
875
- return b
876
-
877
- def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
878
- blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
879
- for x in range(blend_extent):
880
- b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (
881
- x / blend_extent
882
- )
883
- return b
884
-
885
- def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
886
- r"""Encode a batch of images using a tiled encoder.
887
-
888
- Args:
889
- x (`torch.Tensor`): Input batch of videos.
890
-
891
- Returns:
892
- `torch.Tensor`:
893
- The latent representation of the encoded videos.
894
- """
895
- batch_size, num_channels, num_frames, height, width = x.shape
896
- latent_height = height // self.spatial_compression_ratio
897
- latent_width = width // self.spatial_compression_ratio
898
-
899
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
900
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
901
- tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
902
- tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
903
-
904
- blend_height = tile_latent_min_height - tile_latent_stride_height
905
- blend_width = tile_latent_min_width - tile_latent_stride_width
906
-
907
- # Split x into overlapping tiles and encode them separately.
908
- # The tiles have an overlap to avoid seams between tiles.
909
- rows = []
910
- for i in range(0, height, self.tile_sample_stride_height):
911
- row = []
912
- for j in range(0, width, self.tile_sample_stride_width):
913
- tile = x[:, :, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
914
- tile = self.encoder(tile)
915
- tile = self.quant_conv(tile)
916
- row.append(tile)
917
- rows.append(row)
918
-
919
- result_rows = []
920
- for i, row in enumerate(rows):
921
- result_row = []
922
- for j, tile in enumerate(row):
923
- # blend the above tile and the left tile
924
- # to the current tile and add the current tile to the result row
925
- if i > 0:
926
- tile = self.blend_v(rows[i - 1][j], tile, blend_height)
927
- if j > 0:
928
- tile = self.blend_h(row[j - 1], tile, blend_width)
929
- result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
930
- result_rows.append(torch.cat(result_row, dim=4))
931
-
932
- enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
933
- return enc
934
-
935
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
936
- r"""
937
- Decode a batch of images using a tiled decoder.
938
-
939
- Args:
940
- z (`torch.Tensor`): Input batch of latent vectors.
941
- return_dict (`bool`, *optional*, defaults to `True`):
942
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
943
-
944
- Returns:
945
- [`~models.vae.DecoderOutput`] or `tuple`:
946
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
947
- returned.
948
- """
949
-
950
- batch_size, num_channels, num_frames, height, width = z.shape
951
- sample_height = height * self.spatial_compression_ratio
952
- sample_width = width * self.spatial_compression_ratio
953
-
954
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
955
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
956
- tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
957
- tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
958
-
959
- blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
960
- blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
961
-
962
- # Split z into overlapping tiles and decode them separately.
963
- # The tiles have an overlap to avoid seams between tiles.
964
- rows = []
965
- for i in range(0, height, tile_latent_stride_height):
966
- row = []
967
- for j in range(0, width, tile_latent_stride_width):
968
- tile = z[:, :, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
969
- tile = self.post_quant_conv(tile)
970
- decoded = self.decoder(tile)
971
- row.append(decoded)
972
- rows.append(row)
973
-
974
- result_rows = []
975
- for i, row in enumerate(rows):
976
- result_row = []
977
- for j, tile in enumerate(row):
978
- # blend the above tile and the left tile
979
- # to the current tile and add the current tile to the result row
980
- if i > 0:
981
- tile = self.blend_v(rows[i - 1][j], tile, blend_height)
982
- if j > 0:
983
- tile = self.blend_h(row[j - 1], tile, blend_width)
984
- result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
985
- result_rows.append(torch.cat(result_row, dim=-1))
986
-
987
- dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
988
-
989
- if not return_dict:
990
- return (dec,)
991
- return DecoderOutput(sample=dec)
992
-
993
- def _temporal_tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
994
- batch_size, num_channels, num_frames, height, width = x.shape
995
- latent_num_frames = (num_frames - 1) // self.temporal_compression_ratio + 1
996
-
997
- tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
998
- tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
999
- blend_num_frames = tile_latent_min_num_frames - tile_latent_stride_num_frames
1000
-
1001
- row = []
1002
- for i in range(0, num_frames, self.tile_sample_stride_num_frames):
1003
- tile = x[:, :, i : i + self.tile_sample_min_num_frames + 1, :, :]
1004
- if self.use_tiling and (height > self.tile_sample_min_height or width > self.tile_sample_min_width):
1005
- tile = self.tiled_encode(tile)
1006
- else:
1007
- tile = self.encoder(tile)
1008
- tile = self.quant_conv(tile)
1009
- if i > 0:
1010
- tile = tile[:, :, 1:, :, :]
1011
- row.append(tile)
1012
-
1013
- result_row = []
1014
- for i, tile in enumerate(row):
1015
- if i > 0:
1016
- tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1017
- result_row.append(tile[:, :, :tile_latent_stride_num_frames, :, :])
1018
- else:
1019
- result_row.append(tile[:, :, : tile_latent_stride_num_frames + 1, :, :])
1020
-
1021
- enc = torch.cat(result_row, dim=2)[:, :, :latent_num_frames]
1022
- return enc
1023
-
1024
- def _temporal_tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1025
- batch_size, num_channels, num_frames, height, width = z.shape
1026
- num_sample_frames = (num_frames - 1) * self.temporal_compression_ratio + 1
1027
-
1028
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1029
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1030
- tile_latent_min_num_frames = self.tile_sample_min_num_frames // self.temporal_compression_ratio
1031
- tile_latent_stride_num_frames = self.tile_sample_stride_num_frames // self.temporal_compression_ratio
1032
- blend_num_frames = self.tile_sample_min_num_frames - self.tile_sample_stride_num_frames
1033
-
1034
- row = []
1035
- for i in range(0, num_frames, tile_latent_stride_num_frames):
1036
- tile = z[:, :, i : i + tile_latent_min_num_frames + 1, :, :]
1037
- if self.use_tiling and (tile.shape[-1] > tile_latent_min_width or tile.shape[-2] > tile_latent_min_height):
1038
- decoded = self.tiled_decode(tile, return_dict=True).sample
1039
- else:
1040
- tile = self.post_quant_conv(tile)
1041
- decoded = self.decoder(tile)
1042
- if i > 0:
1043
- decoded = decoded[:, :, 1:, :, :]
1044
- row.append(decoded)
1045
-
1046
- result_row = []
1047
- for i, tile in enumerate(row):
1048
- if i > 0:
1049
- tile = self.blend_t(row[i - 1], tile, blend_num_frames)
1050
- result_row.append(tile[:, :, : self.tile_sample_stride_num_frames, :, :])
1051
- else:
1052
- result_row.append(tile[:, :, : self.tile_sample_stride_num_frames + 1, :, :])
1053
-
1054
- dec = torch.cat(result_row, dim=2)[:, :, :num_sample_frames]
1055
-
1056
- if not return_dict:
1057
- return (dec,)
1058
- return DecoderOutput(sample=dec)
1059
-
1060
- def forward(
1061
- self,
1062
- sample: torch.Tensor,
1063
- sample_posterior: bool = False,
1064
- return_dict: bool = True,
1065
- generator: Optional[torch.Generator] = None,
1066
- ) -> Union[DecoderOutput, torch.Tensor]:
1067
- r"""
1068
- Args:
1069
- sample (`torch.Tensor`): Input sample.
1070
- sample_posterior (`bool`, *optional*, defaults to `False`):
1071
- Whether to sample from the posterior.
1072
- return_dict (`bool`, *optional*, defaults to `True`):
1073
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1074
- """
1075
- x = sample
1076
- posterior = self.encode(x).latent_dist
1077
- if sample_posterior:
1078
- z = posterior.sample(generator=generator)
1079
- else:
1080
- z = posterior.mode()
1081
- dec = self.decode(z, return_dict=return_dict)
1082
- return dec
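
The `AutoencoderKLHunyuanVideo` defined above enables spatial tiling and frame-wise encoding/decoding by default, splitting large inputs into overlapping tiles and blending the overlaps with `blend_v`, `blend_h`, and `blend_t` so long or high-resolution videos stay within bounded memory. A minimal sketch of driving the removed class directly follows; the import path, tensor shape, and tile sizes are illustrative assumptions only.

# Hypothetical sketch of encoding/decoding a short clip with the removed tiled VAE.
# The import path, tensor shape, and tile sizes below are assumptions for illustration.
import torch
# from videox_fun.models import AutoencoderKLHunyuanVideo  # import path assumed

vae = AutoencoderKLHunyuanVideo()                       # default config defined above
vae.enable_tiling(tile_sample_min_height=256,           # samples taller than this are split into tiles
                  tile_sample_stride_height=192)        # stride < min height leaves a 64-pixel overlap to blend
with torch.no_grad():
    video = torch.randn(1, 3, 17, 512, 512)             # (batch, channels, frames, height, width)
    latents = vae.encode(video).latent_dist.sample()    # DiagonalGaussianDistribution -> sampled latents
    frames = vae.decode(latents).sample                 # DecoderOutput.sample, back in pixel space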
 
videox_fun/models/qwenimage_transformer2d.py DELETED
@@ -1,1118 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_qwenimage.py
2
- # Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
-
16
-
17
- import functools
18
- import inspect
19
- import glob
20
- import json
21
- import math
22
- import os
23
- import types
24
- import warnings
25
- from typing import Any, Dict, List, Optional, Tuple, Union
26
-
27
- import numpy as np
28
- import torch
29
- import torch.cuda.amp as amp
30
- import torch.nn as nn
31
- import torch.nn.functional as F
32
- from diffusers.configuration_utils import ConfigMixin, register_to_config
33
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
34
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
35
- from diffusers.models.attention import Attention, FeedForward
36
- from diffusers.models.attention_processor import (
37
- Attention, AttentionProcessor, CogVideoXAttnProcessor2_0,
38
- FusedCogVideoXAttnProcessor2_0)
39
- from diffusers.models.embeddings import (CogVideoXPatchEmbed,
40
- TimestepEmbedding, Timesteps,
41
- get_3d_sincos_pos_embed)
42
- from diffusers.models.modeling_outputs import Transformer2DModelOutput
43
- from diffusers.models.modeling_utils import ModelMixin
44
- from diffusers.models.normalization import (AdaLayerNorm,
45
- AdaLayerNormContinuous,
46
- CogVideoXLayerNormZero, RMSNorm)
47
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
48
- scale_lora_layers, unscale_lora_layers)
49
- from diffusers.utils.torch_utils import maybe_allow_in_graph
50
- from torch import nn
51
-
52
- from ..dist import (QwenImageMultiGPUsAttnProcessor2_0,
53
- get_sequence_parallel_rank,
54
- get_sequence_parallel_world_size, get_sp_group)
55
- from .attention_utils import attention
56
- from .cache_utils import TeaCache
57
- from ..utils import cfg_skip
58
-
59
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
60
-
61
-
62
- def get_timestep_embedding(
63
- timesteps: torch.Tensor,
64
- embedding_dim: int,
65
- flip_sin_to_cos: bool = False,
66
- downscale_freq_shift: float = 1,
67
- scale: float = 1,
68
- max_period: int = 10000,
69
- ) -> torch.Tensor:
70
- """
71
- This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
72
-
73
- Args
74
- timesteps (torch.Tensor):
75
- a 1-D Tensor of N indices, one per batch element. These may be fractional.
76
- embedding_dim (int):
77
- the dimension of the output.
78
- flip_sin_to_cos (bool):
79
- Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
80
- downscale_freq_shift (float):
81
- Controls the delta between frequencies between dimensions
82
- scale (float):
83
- Scaling factor applied to the embeddings.
84
- max_period (int):
85
- Controls the maximum frequency of the embeddings
86
- Returns:
87
- torch.Tensor: an [N x dim] Tensor of positional embeddings.
88
- """
89
- assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
90
-
91
- half_dim = embedding_dim // 2
92
- exponent = -math.log(max_period) * torch.arange(
93
- start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
94
- )
95
- exponent = exponent / (half_dim - downscale_freq_shift)
96
-
97
- emb = torch.exp(exponent).to(timesteps.dtype)
98
- emb = timesteps[:, None].float() * emb[None, :]
99
-
100
- # scale embeddings
101
- emb = scale * emb
102
-
103
- # concat sine and cosine embeddings
104
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
105
-
106
- # flip sine and cosine embeddings
107
- if flip_sin_to_cos:
108
- emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
109
-
110
- # zero pad
111
- if embedding_dim % 2 == 1:
112
- emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
113
- return emb
114
-
115
-
116
- def apply_rotary_emb_qwen(
117
- x: torch.Tensor,
118
- freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
119
- use_real: bool = True,
120
- use_real_unbind_dim: int = -1,
121
- ) -> Tuple[torch.Tensor, torch.Tensor]:
122
- """
123
- Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
124
- to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
125
- reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
126
- tensors contain rotary embeddings and are returned as real tensors.
127
-
128
- Args:
129
- x (`torch.Tensor`):
130
- Query or key tensor to apply rotary embeddings, of shape [B, S, H, D].
131
- freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
132
-
133
- Returns:
134
- Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
135
- """
136
- if use_real:
137
- cos, sin = freqs_cis # [S, D]
138
- cos = cos[None, None]
139
- sin = sin[None, None]
140
- cos, sin = cos.to(x.device), sin.to(x.device)
141
-
142
- if use_real_unbind_dim == -1:
143
- # Used for flux, cogvideox, hunyuan-dit
144
- x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
145
- x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
146
- elif use_real_unbind_dim == -2:
147
- # Used for Stable Audio, OmniGen, CogView4 and Cosmos
148
- x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
149
- x_rotated = torch.cat([-x_imag, x_real], dim=-1)
150
- else:
151
- raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
152
-
153
- out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
154
-
155
- return out
156
- else:
157
- x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
158
- freqs_cis = freqs_cis.unsqueeze(1)
159
- x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
160
-
161
- return x_out.type_as(x)
162
-
163
-
164
- class QwenTimestepProjEmbeddings(nn.Module):
165
- def __init__(self, embedding_dim):
166
- super().__init__()
167
-
168
- self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
169
- self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
170
-
171
- def forward(self, timestep, hidden_states):
172
- timesteps_proj = self.time_proj(timestep)
173
- timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype)) # (N, D)
174
-
175
- conditioning = timesteps_emb
176
-
177
- return conditioning
178
-
179
-
180
- class QwenEmbedRope(nn.Module):
181
- def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
182
- super().__init__()
183
- self.theta = theta
184
- self.axes_dim = axes_dim
185
- pos_index = torch.arange(4096)
186
- neg_index = torch.arange(4096).flip(0) * -1 - 1
187
- self.pos_freqs = torch.cat(
188
- [
189
- self.rope_params(pos_index, self.axes_dim[0], self.theta),
190
- self.rope_params(pos_index, self.axes_dim[1], self.theta),
191
- self.rope_params(pos_index, self.axes_dim[2], self.theta),
192
- ],
193
- dim=1,
194
- )
195
- self.neg_freqs = torch.cat(
196
- [
197
- self.rope_params(neg_index, self.axes_dim[0], self.theta),
198
- self.rope_params(neg_index, self.axes_dim[1], self.theta),
199
- self.rope_params(neg_index, self.axes_dim[2], self.theta),
200
- ],
201
- dim=1,
202
- )
203
- self.rope_cache = {}
204
-
205
- # DO NOT USE register_buffer HERE; IT WOULD CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
206
- self.scale_rope = scale_rope
207
-
208
- def rope_params(self, index, dim, theta=10000):
209
- """
210
- Args:
211
- index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
212
- """
213
- assert dim % 2 == 0
214
- freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
215
- freqs = torch.polar(torch.ones_like(freqs), freqs)
216
- return freqs
217
-
218
- def forward(self, video_fhw, txt_seq_lens, device):
219
- """
220
- Args: video_fhw: [frame, height, width], a list of 3 integers giving the shape of the packed video latents.
221
- txt_seq_lens: [bs], a list of integers giving the length of each text sequence.
222
- """
223
- if self.pos_freqs.device != device:
224
- self.pos_freqs = self.pos_freqs.to(device)
225
- self.neg_freqs = self.neg_freqs.to(device)
226
-
227
- if isinstance(video_fhw, list):
228
- video_fhw = video_fhw[0]
229
- if not isinstance(video_fhw, list):
230
- video_fhw = [video_fhw]
231
-
232
- vid_freqs = []
233
- max_vid_index = 0
234
- for idx, fhw in enumerate(video_fhw):
235
- frame, height, width = fhw
236
- rope_key = f"{idx}_{frame}_{height}_{width}"
237
- if not torch.compiler.is_compiling():
238
- if rope_key not in self.rope_cache:
239
- self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
240
- video_freq = self.rope_cache[rope_key]
241
- else:
242
- video_freq = self._compute_video_freqs(frame, height, width, idx)
243
- video_freq = video_freq.to(device)
244
- vid_freqs.append(video_freq)
245
-
246
- if self.scale_rope:
247
- max_vid_index = max(height // 2, width // 2, max_vid_index)
248
- else:
249
- max_vid_index = max(height, width, max_vid_index)
250
-
251
- max_len = max(txt_seq_lens)
252
- txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
253
- vid_freqs = torch.cat(vid_freqs, dim=0)
254
-
255
- return vid_freqs, txt_freqs
256
-
257
- @functools.lru_cache(maxsize=None)
258
- def _compute_video_freqs(self, frame, height, width, idx=0):
259
- seq_lens = frame * height * width
260
- freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
261
- freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
262
-
263
- freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
264
- if self.scale_rope:
265
- freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
266
- freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
267
- freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
268
- freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
269
- else:
270
- freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
271
- freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
272
-
273
- freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
274
- return freqs.clone().contiguous()
275
-
276
-
277
- class QwenDoubleStreamAttnProcessor2_0:
278
- """
279
- Attention processor for Qwen double-stream architecture, matching DoubleStreamLayerMegatron logic. This processor
280
- implements joint attention computation where text and image streams are processed together.
281
- """
282
-
283
- _attention_backend = None
284
-
285
- def __init__(self):
286
- if not hasattr(F, "scaled_dot_product_attention"):
287
- raise ImportError(
288
- "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
289
- )
290
-
291
- def __call__(
292
- self,
293
- attn: Attention,
294
- hidden_states: torch.FloatTensor, # Image stream
295
- encoder_hidden_states: torch.FloatTensor = None, # Text stream
296
- encoder_hidden_states_mask: torch.FloatTensor = None,
297
- attention_mask: Optional[torch.FloatTensor] = None,
298
- image_rotary_emb: Optional[torch.Tensor] = None,
299
- ) -> torch.FloatTensor:
300
- if encoder_hidden_states is None:
301
- raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")
302
-
303
- seq_txt = encoder_hidden_states.shape[1]
304
-
305
- # Compute QKV for image stream (sample projections)
306
- img_query = attn.to_q(hidden_states)
307
- img_key = attn.to_k(hidden_states)
308
- img_value = attn.to_v(hidden_states)
309
-
310
- # Compute QKV for text stream (context projections)
311
- txt_query = attn.add_q_proj(encoder_hidden_states)
312
- txt_key = attn.add_k_proj(encoder_hidden_states)
313
- txt_value = attn.add_v_proj(encoder_hidden_states)
314
-
315
- # Reshape for multi-head attention
316
- img_query = img_query.unflatten(-1, (attn.heads, -1))
317
- img_key = img_key.unflatten(-1, (attn.heads, -1))
318
- img_value = img_value.unflatten(-1, (attn.heads, -1))
319
-
320
- txt_query = txt_query.unflatten(-1, (attn.heads, -1))
321
- txt_key = txt_key.unflatten(-1, (attn.heads, -1))
322
- txt_value = txt_value.unflatten(-1, (attn.heads, -1))
323
-
324
- # Apply QK normalization
325
- if attn.norm_q is not None:
326
- img_query = attn.norm_q(img_query)
327
- if attn.norm_k is not None:
328
- img_key = attn.norm_k(img_key)
329
- if attn.norm_added_q is not None:
330
- txt_query = attn.norm_added_q(txt_query)
331
- if attn.norm_added_k is not None:
332
- txt_key = attn.norm_added_k(txt_key)
333
-
334
- # Apply RoPE
335
- if image_rotary_emb is not None:
336
- img_freqs, txt_freqs = image_rotary_emb
337
- img_query = apply_rotary_emb_qwen(img_query, img_freqs, use_real=False)
338
- img_key = apply_rotary_emb_qwen(img_key, img_freqs, use_real=False)
339
- txt_query = apply_rotary_emb_qwen(txt_query, txt_freqs, use_real=False)
340
- txt_key = apply_rotary_emb_qwen(txt_key, txt_freqs, use_real=False)
341
-
342
- # Concatenate for joint attention
343
- # Order: [text, image]
344
- joint_query = torch.cat([txt_query, img_query], dim=1)
345
- joint_key = torch.cat([txt_key, img_key], dim=1)
346
- joint_value = torch.cat([txt_value, img_value], dim=1)
347
-
348
- joint_hidden_states = attention(
349
- joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, causal=False
350
- )
351
-
352
- # Reshape back
353
- joint_hidden_states = joint_hidden_states.flatten(2, 3)
354
- joint_hidden_states = joint_hidden_states.to(joint_query.dtype)
355
-
356
- # Split attention outputs back
357
- txt_attn_output = joint_hidden_states[:, :seq_txt, :] # Text part
358
- img_attn_output = joint_hidden_states[:, seq_txt:, :] # Image part
359
-
360
- # Apply output projections
361
- img_attn_output = attn.to_out[0](img_attn_output)
362
- if len(attn.to_out) > 1:
363
- img_attn_output = attn.to_out[1](img_attn_output) # dropout
364
-
365
- txt_attn_output = attn.to_add_out(txt_attn_output)
366
-
367
- return img_attn_output, txt_attn_output
368
-
369
-
370
- @maybe_allow_in_graph
371
- class QwenImageTransformerBlock(nn.Module):
372
- def __init__(
373
- self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
374
- ):
375
- super().__init__()
376
-
377
- self.dim = dim
378
- self.num_attention_heads = num_attention_heads
379
- self.attention_head_dim = attention_head_dim
380
-
381
- # Image processing modules
382
- self.img_mod = nn.Sequential(
383
- nn.SiLU(),
384
- nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
385
- )
386
- self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
387
- self.attn = Attention(
388
- query_dim=dim,
389
- cross_attention_dim=None, # Enable cross attention for joint computation
390
- added_kv_proj_dim=dim, # Enable added KV projections for text stream
391
- dim_head=attention_head_dim,
392
- heads=num_attention_heads,
393
- out_dim=dim,
394
- context_pre_only=False,
395
- bias=True,
396
- processor=QwenDoubleStreamAttnProcessor2_0(),
397
- qk_norm=qk_norm,
398
- eps=eps,
399
- )
400
- self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
401
- self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
402
-
403
- # Text processing modules
404
- self.txt_mod = nn.Sequential(
405
- nn.SiLU(),
406
- nn.Linear(dim, 6 * dim, bias=True), # For scale, shift, gate for norm1 and norm2
407
- )
408
- self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
409
- # Text doesn't need separate attention - it's handled by img_attn joint computation
410
- self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
411
- self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
412
-
413
- def _modulate(self, x, mod_params):
414
- """Apply modulation to input tensor"""
415
- shift, scale, gate = mod_params.chunk(3, dim=-1)
416
- return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
417
-
418
- def forward(
419
- self,
420
- hidden_states: torch.Tensor,
421
- encoder_hidden_states: torch.Tensor,
422
- encoder_hidden_states_mask: torch.Tensor,
423
- temb: torch.Tensor,
424
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
425
- joint_attention_kwargs: Optional[Dict[str, Any]] = None,
426
- ) -> Tuple[torch.Tensor, torch.Tensor]:
427
- # Get modulation parameters for both streams
428
- img_mod_params = self.img_mod(temb) # [B, 6*dim]
429
- txt_mod_params = self.txt_mod(temb) # [B, 6*dim]
430
-
431
- # Split modulation parameters for norm1 and norm2
432
- img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
433
- txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) # Each [B, 3*dim]
434
-
435
- # Process image stream - norm1 + modulation
436
- img_normed = self.img_norm1(hidden_states)
437
- img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)
438
-
439
- # Process text stream - norm1 + modulation
440
- txt_normed = self.txt_norm1(encoder_hidden_states)
441
- txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)
442
-
443
- # Use QwenDoubleStreamAttnProcessor2_0 for joint attention computation
444
- # This directly implements the DoubleStreamLayerMegatron logic:
445
- # 1. Computes QKV for both streams
446
- # 2. Applies QK normalization and RoPE
447
- # 3. Concatenates and runs joint attention
448
- # 4. Splits results back to separate streams
449
- joint_attention_kwargs = joint_attention_kwargs or {}
450
- attn_output = self.attn(
451
- hidden_states=img_modulated, # Image stream (will be processed as "sample")
452
- encoder_hidden_states=txt_modulated, # Text stream (will be processed as "context")
453
- encoder_hidden_states_mask=encoder_hidden_states_mask,
454
- image_rotary_emb=image_rotary_emb,
455
- **joint_attention_kwargs,
456
- )
457
-
458
- # QwenDoubleStreamAttnProcessor2_0 returns (img_output, txt_output) when encoder_hidden_states is provided
459
- img_attn_output, txt_attn_output = attn_output
460
-
461
- # Apply attention gates and add residual (like in Megatron)
462
- hidden_states = hidden_states + img_gate1 * img_attn_output
463
- encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
464
-
465
- # Process image stream - norm2 + MLP
466
- img_normed2 = self.img_norm2(hidden_states)
467
- img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
468
- img_mlp_output = self.img_mlp(img_modulated2)
469
- hidden_states = hidden_states + img_gate2 * img_mlp_output
470
-
471
- # Process text stream - norm2 + MLP
472
- txt_normed2 = self.txt_norm2(encoder_hidden_states)
473
- txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
474
- txt_mlp_output = self.txt_mlp(txt_modulated2)
475
- encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output
476
-
477
- # Clip to prevent overflow for fp16
478
- if encoder_hidden_states.dtype == torch.float16:
479
- encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
480
- if hidden_states.dtype == torch.float16:
481
- hidden_states = hidden_states.clip(-65504, 65504)
482
-
483
- return encoder_hidden_states, hidden_states
484
-
485
-
486
- class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
487
- """
488
- The Transformer model introduced in Qwen.
489
-
490
- Args:
491
- patch_size (`int`, defaults to `2`):
492
- Patch size to turn the input data into small patches.
493
- in_channels (`int`, defaults to `64`):
494
- The number of channels in the input.
495
- out_channels (`int`, *optional*, defaults to `16`):
496
- The number of channels in the output. If not specified, it defaults to `in_channels`.
497
- num_layers (`int`, defaults to `60`):
498
- The number of layers of dual stream DiT blocks to use.
499
- attention_head_dim (`int`, defaults to `128`):
500
- The number of dimensions to use for each attention head.
501
- num_attention_heads (`int`, defaults to `24`):
502
- The number of attention heads to use.
503
- joint_attention_dim (`int`, defaults to `3584`):
504
- The number of dimensions to use for the joint attention (embedding/channel dimension of
505
- `encoder_hidden_states`).
506
- guidance_embeds (`bool`, defaults to `False`):
507
- Whether to use guidance embeddings for guidance-distilled variant of the model.
508
- axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
509
- The dimensions to use for the rotary positional embeddings.
510
- """
511
-
512
- # _supports_gradient_checkpointing = True
513
- # _no_split_modules = ["QwenImageTransformerBlock"]
514
- # _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
515
- # _repeated_blocks = ["QwenImageTransformerBlock"]
516
- _supports_gradient_checkpointing = True
517
-
518
- @register_to_config
519
- def __init__(
520
- self,
521
- patch_size: int = 2,
522
- in_channels: int = 64,
523
- out_channels: Optional[int] = 16,
524
- num_layers: int = 60,
525
- attention_head_dim: int = 128,
526
- num_attention_heads: int = 24,
527
- joint_attention_dim: int = 3584,
528
- guidance_embeds: bool = False, # TODO: this should probably be removed
529
- axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
530
- ):
531
- super().__init__()
532
- self.out_channels = out_channels or in_channels
533
- self.inner_dim = num_attention_heads * attention_head_dim
534
-
535
- self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)
536
-
537
- self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)
538
-
539
- self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
540
-
541
- self.img_in = nn.Linear(in_channels, self.inner_dim)
542
- self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)
543
-
544
- self.transformer_blocks = nn.ModuleList(
545
- [
546
- QwenImageTransformerBlock(
547
- dim=self.inner_dim,
548
- num_attention_heads=num_attention_heads,
549
- attention_head_dim=attention_head_dim,
550
- )
551
- for _ in range(num_layers)
552
- ]
553
- )
554
-
555
- self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
556
- self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
557
-
558
- self.teacache = None
559
- self.cfg_skip_ratio = None
560
- self.current_steps = 0
561
- self.num_inference_steps = None
562
- self.gradient_checkpointing = False
563
- self.sp_world_size = 1
564
- self.sp_world_rank = 0
565
-
566
- def _set_gradient_checkpointing(self, *args, **kwargs):
567
- if "value" in kwargs:
568
- self.gradient_checkpointing = kwargs["value"]
569
- elif "enable" in kwargs:
570
- self.gradient_checkpointing = kwargs["enable"]
571
- else:
572
- raise ValueError("Invalid set gradient checkpointing")
573
-
574
- def enable_multi_gpus_inference(self,):
575
- self.sp_world_size = get_sequence_parallel_world_size()
576
- self.sp_world_rank = get_sequence_parallel_rank()
577
- self.all_gather = get_sp_group().all_gather
578
- self.set_attn_processor(QwenImageMultiGPUsAttnProcessor2_0())
579
-
580
- @property
581
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
582
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
583
- r"""
584
- Returns:
585
- `dict` of attention processors: A dictionary containing all attention processors used in the model,
586
- indexed by its weight name.
587
- """
588
- # set recursively
589
- processors = {}
590
-
591
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
592
- if hasattr(module, "get_processor"):
593
- processors[f"{name}.processor"] = module.get_processor()
594
-
595
- for sub_name, child in module.named_children():
596
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
597
-
598
- return processors
599
-
600
- for name, module in self.named_children():
601
- fn_recursive_add_processors(name, module, processors)
602
-
603
- return processors
604
-
605
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
606
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
607
- r"""
608
- Sets the attention processor to use to compute attention.
609
-
610
- Parameters:
611
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
612
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
613
- for **all** `Attention` layers.
614
-
615
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
616
- processor. This is strongly recommended when setting trainable attention processors.
617
-
618
- """
619
- count = len(self.attn_processors.keys())
620
-
621
- if isinstance(processor, dict) and len(processor) != count:
622
- raise ValueError(
623
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
624
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
625
- )
626
-
627
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
628
- if hasattr(module, "set_processor"):
629
- if not isinstance(processor, dict):
630
- module.set_processor(processor)
631
- else:
632
- module.set_processor(processor.pop(f"{name}.processor"))
633
-
634
- for sub_name, child in module.named_children():
635
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
636
-
637
- for name, module in self.named_children():
638
- fn_recursive_attn_processor(name, module, processor)
639
-
640
- def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
641
- if cfg_skip_ratio != 0:
642
- self.cfg_skip_ratio = cfg_skip_ratio
643
- self.current_steps = 0
644
- self.num_inference_steps = num_steps
645
- else:
646
- self.cfg_skip_ratio = None
647
- self.current_steps = 0
648
- self.num_inference_steps = None
649
-
650
- def share_cfg_skip(
651
- self,
652
- transformer = None,
653
- ):
654
- self.cfg_skip_ratio = transformer.cfg_skip_ratio
655
- self.current_steps = transformer.current_steps
656
- self.num_inference_steps = transformer.num_inference_steps
657
-
658
- def disable_cfg_skip(self):
659
- self.cfg_skip_ratio = None
660
- self.current_steps = 0
661
- self.num_inference_steps = None
662
-
663
- def enable_teacache(
664
- self,
665
- coefficients,
666
- num_steps: int,
667
- rel_l1_thresh: float,
668
- num_skip_start_steps: int = 0,
669
- offload: bool = True,
670
- ):
671
- self.teacache = TeaCache(
672
- coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
673
- )
674
-
675
- def share_teacache(
676
- self,
677
- transformer = None,
678
- ):
679
- self.teacache = transformer.teacache
680
-
681
- def disable_teacache(self):
682
- self.teacache = None
683
-
684
- @cfg_skip()
685
- def forward_bs(self, x, *args, **kwargs):
686
- func = self.forward
687
- sig = inspect.signature(func)
688
-
689
- bs = len(x)
690
- bs_half = int(bs // 2)
691
-
692
- if bs >= 2:
693
- # cond
694
- x_i = x[bs_half:]
695
- args_i = [
696
- arg[bs_half:] if
697
- isinstance(arg,
698
- (torch.Tensor, list, tuple, np.ndarray)) and
699
- len(arg) == bs else arg for arg in args
700
- ]
701
- kwargs_i = {
702
- k: (v[bs_half:] if
703
- isinstance(v,
704
- (torch.Tensor, list, tuple,
705
- np.ndarray)) and len(v) == bs else v
706
- ) for k, v in kwargs.items()
707
- }
708
- if 'cond_flag' in sig.parameters:
709
- kwargs_i["cond_flag"] = True
710
-
711
- cond_out = func(x_i, *args_i, **kwargs_i)
712
-
713
- # uncond
714
- uncond_x_i = x[:bs_half]
715
- uncond_args_i = [
716
- arg[:bs_half] if
717
- isinstance(arg,
718
- (torch.Tensor, list, tuple, np.ndarray)) and
719
- len(arg) == bs else arg for arg in args
720
- ]
721
- uncond_kwargs_i = {
722
- k: (v[:bs_half] if
723
- isinstance(v,
724
- (torch.Tensor, list, tuple,
725
- np.ndarray)) and len(v) == bs else v
726
- ) for k, v in kwargs.items()
727
- }
728
- if 'cond_flag' in sig.parameters:
729
- uncond_kwargs_i["cond_flag"] = False
730
- uncond_out = func(uncond_x_i, *uncond_args_i,
731
- **uncond_kwargs_i)
732
-
733
- x = torch.cat([uncond_out, cond_out], dim=0)
734
- else:
735
- x = func(x, *args, **kwargs)
736
-
737
- return x
738
-
739
- def forward(
740
- self,
741
- hidden_states: torch.Tensor,
742
- encoder_hidden_states: torch.Tensor = None,
743
- encoder_hidden_states_mask: torch.Tensor = None,
744
- timestep: torch.LongTensor = None,
745
- img_shapes: Optional[List[Tuple[int, int, int]]] = None,
746
- txt_seq_lens: Optional[List[int]] = None,
747
- guidance: torch.Tensor = None, # TODO: this should probably be removed
748
- attention_kwargs: Optional[Dict[str, Any]] = None,
749
- cond_flag: bool = True,
750
- return_dict: bool = True,
751
- ) -> Union[torch.Tensor, Transformer2DModelOutput]:
752
- """
753
- The [`QwenImageTransformer2DModel`] forward method.
754
-
755
- Args:
756
- hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
757
- Input `hidden_states`.
758
- encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
759
- Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
760
- encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
761
- Mask of the input conditions.
762
- timestep ( `torch.LongTensor`):
763
- Used to indicate denoising step.
764
- attention_kwargs (`dict`, *optional*):
765
- A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
766
- `self.processor` in
767
- [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
768
- return_dict (`bool`, *optional*, defaults to `True`):
769
- Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
770
- tuple.
771
-
772
- Returns:
773
- If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
774
- `tuple` where the first element is the sample tensor.
775
- """
776
- if attention_kwargs is not None:
777
- attention_kwargs = attention_kwargs.copy()
778
- lora_scale = attention_kwargs.pop("scale", 1.0)
779
- else:
780
- lora_scale = 1.0
781
-
782
- if USE_PEFT_BACKEND:
783
- # weight the lora layers by setting `lora_scale` for each PEFT layer
784
- scale_lora_layers(self, lora_scale)
785
- else:
786
- if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
787
- logger.warning(
788
- "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
789
- )
790
-
791
- if isinstance(encoder_hidden_states, list):
792
- encoder_hidden_states = torch.stack(encoder_hidden_states)
793
- encoder_hidden_states_mask = torch.stack(encoder_hidden_states_mask)
794
-
795
- hidden_states = self.img_in(hidden_states)
796
-
797
- timestep = timestep.to(hidden_states.dtype)
798
- encoder_hidden_states = self.txt_norm(encoder_hidden_states)
799
- encoder_hidden_states = self.txt_in(encoder_hidden_states)
800
-
801
- if guidance is not None:
802
- guidance = guidance.to(hidden_states.dtype) * 1000
803
-
804
- temb = (
805
- self.time_text_embed(timestep, hidden_states)
806
- if guidance is None
807
- else self.time_text_embed(timestep, guidance, hidden_states)
808
- )
809
-
810
- image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
811
-
812
- # Context Parallel
813
- if self.sp_world_size > 1:
814
- hidden_states = torch.chunk(hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
815
- if image_rotary_emb is not None:
816
- image_rotary_emb = (
817
- torch.chunk(image_rotary_emb[0], self.sp_world_size, dim=0)[self.sp_world_rank],
818
- image_rotary_emb[1]
819
- )
820
-
821
- # TeaCache
822
- if self.teacache is not None:
823
- if cond_flag:
824
- inp = hidden_states.clone()
825
- temb_ = temb.clone()
826
- encoder_hidden_states_ = encoder_hidden_states.clone()
827
-
828
- img_mod_params_ = self.transformer_blocks[0].img_mod(temb_)
829
- img_mod1_, img_mod2_ = img_mod_params_.chunk(2, dim=-1)
830
- img_normed_ = self.transformer_blocks[0].img_norm1(inp)
831
- modulated_inp, img_gate1_ = self.transformer_blocks[0]._modulate(img_normed_, img_mod1_)
832
-
833
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
834
- if skip_flag:
835
- self.should_calc = True
836
- self.teacache.accumulated_rel_l1_distance = 0
837
- else:
838
- if cond_flag:
839
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
840
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
841
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
842
- self.should_calc = False
843
- else:
844
- self.should_calc = True
845
- self.teacache.accumulated_rel_l1_distance = 0
846
- self.teacache.previous_modulated_input = modulated_inp
847
- self.teacache.should_calc = self.should_calc
848
- else:
849
- self.should_calc = self.teacache.should_calc
850
-
851
- # TeaCache
852
- if self.teacache is not None:
853
- if not self.should_calc:
854
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
855
- hidden_states = hidden_states + previous_residual.to(hidden_states.device)[-hidden_states.size()[0]:,]
856
- else:
857
- ori_hidden_states = hidden_states.clone().cpu() if self.teacache.offload else hidden_states.clone()
858
-
859
- # 4. Transformer blocks
860
- for i, block in enumerate(self.transformer_blocks):
861
- if torch.is_grad_enabled() and self.gradient_checkpointing:
862
- def create_custom_forward(module):
863
- def custom_forward(*inputs):
864
- return module(*inputs)
865
-
866
- return custom_forward
867
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
868
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
869
- create_custom_forward(block),
870
- hidden_states,
871
- encoder_hidden_states,
872
- encoder_hidden_states_mask,
873
- temb,
874
- image_rotary_emb,
875
- **ckpt_kwargs,
876
- )
877
-
878
- else:
879
- encoder_hidden_states, hidden_states = block(
880
- hidden_states=hidden_states,
881
- encoder_hidden_states=encoder_hidden_states,
882
- encoder_hidden_states_mask=encoder_hidden_states_mask,
883
- temb=temb,
884
- image_rotary_emb=image_rotary_emb,
885
- joint_attention_kwargs=attention_kwargs,
886
- )
887
-
888
- if cond_flag:
889
- self.teacache.previous_residual_cond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
890
- else:
891
- self.teacache.previous_residual_uncond = hidden_states.cpu() - ori_hidden_states if self.teacache.offload else hidden_states - ori_hidden_states
892
- del ori_hidden_states
893
- else:
894
- for index_block, block in enumerate(self.transformer_blocks):
895
- if torch.is_grad_enabled() and self.gradient_checkpointing:
896
- def create_custom_forward(module):
897
- def custom_forward(*inputs):
898
- return module(*inputs)
899
-
900
- return custom_forward
901
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
902
- encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
903
- create_custom_forward(block),
904
- hidden_states,
905
- encoder_hidden_states,
906
- encoder_hidden_states_mask,
907
- temb,
908
- image_rotary_emb,
909
- **ckpt_kwargs,
910
- )
911
-
912
- else:
913
- encoder_hidden_states, hidden_states = block(
914
- hidden_states=hidden_states,
915
- encoder_hidden_states=encoder_hidden_states,
916
- encoder_hidden_states_mask=encoder_hidden_states_mask,
917
- temb=temb,
918
- image_rotary_emb=image_rotary_emb,
919
- joint_attention_kwargs=attention_kwargs,
920
- )
921
-
922
- # Use only the image part (hidden_states) from the dual-stream blocks
923
- hidden_states = self.norm_out(hidden_states, temb)
924
- output = self.proj_out(hidden_states)
925
-
926
- if self.sp_world_size > 1:
927
- output = self.all_gather(output, dim=1)
928
-
929
- if USE_PEFT_BACKEND:
930
- # remove `lora_scale` from each PEFT layer
931
- unscale_lora_layers(self, lora_scale)
932
-
933
- if self.teacache is not None and cond_flag:
934
- self.teacache.cnt += 1
935
- if self.teacache.cnt == self.teacache.num_steps:
936
- self.teacache.reset()
937
- return output
938
-
939
- @classmethod
940
- def from_pretrained(
941
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
942
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
943
- ):
944
- if subfolder is not None:
945
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
946
- print(f"loaded transformer's pretrained weights from {pretrained_model_path} ...")
947
-
948
- config_file = os.path.join(pretrained_model_path, 'config.json')
949
- if not os.path.isfile(config_file):
950
- raise RuntimeError(f"{config_file} does not exist")
951
- with open(config_file, "r") as f:
952
- config = json.load(f)
953
-
954
- from diffusers.utils import WEIGHTS_NAME
955
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
956
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
957
-
958
- if "dict_mapping" in transformer_additional_kwargs.keys():
959
- for key in transformer_additional_kwargs["dict_mapping"]:
960
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
961
-
962
- if low_cpu_mem_usage:
963
- try:
964
- import re
965
-
966
- from diffusers import __version__ as diffusers_version
967
- if diffusers_version >= "0.33.0":
968
- from diffusers.models.model_loading_utils import \
969
- load_model_dict_into_meta
970
- else:
971
- from diffusers.models.modeling_utils import \
972
- load_model_dict_into_meta
973
- from diffusers.utils import is_accelerate_available
974
- if is_accelerate_available():
975
- import accelerate
976
-
977
- # Instantiate model with empty weights
978
- with accelerate.init_empty_weights():
979
- model = cls.from_config(config, **transformer_additional_kwargs)
980
-
981
- param_device = "cpu"
982
- if os.path.exists(model_file):
983
- state_dict = torch.load(model_file, map_location="cpu")
984
- elif os.path.exists(model_file_safetensors):
985
- from safetensors.torch import load_file, safe_open
986
- state_dict = load_file(model_file_safetensors)
987
- else:
988
- from safetensors.torch import load_file, safe_open
989
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
990
- state_dict = {}
991
- print(model_files_safetensors)
992
- for _model_file_safetensors in model_files_safetensors:
993
- _state_dict = load_file(_model_file_safetensors)
994
- for key in _state_dict:
995
- state_dict[key] = _state_dict[key]
996
-
997
- filtered_state_dict = {}
998
- for key in state_dict:
999
- if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1000
- filtered_state_dict[key] = state_dict[key]
1001
- else:
1002
- print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1003
-
1004
- model_keys = set(model.state_dict().keys())
1005
- loaded_keys = set(filtered_state_dict.keys())
1006
- missing_keys = model_keys - loaded_keys
1007
-
1008
- def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1009
- initialized_dict = {}
1010
-
1011
- with torch.no_grad():
1012
- for key in missing_keys:
1013
- param_shape = model_state_dict[key].shape
1014
- param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1015
- if 'weight' in key:
1016
- if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1017
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1018
- elif 'embedding' in key or 'embed' in key:
1019
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1020
- elif 'head' in key or 'output' in key or 'proj_out' in key:
1021
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1022
- elif len(param_shape) >= 2:
1023
- initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1024
- nn.init.xavier_uniform_(initialized_dict[key])
1025
- else:
1026
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1027
- elif 'bias' in key:
1028
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1029
- elif 'running_mean' in key:
1030
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1031
- elif 'running_var' in key:
1032
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1033
- elif 'num_batches_tracked' in key:
1034
- initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1035
- else:
1036
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1037
-
1038
- return initialized_dict
1039
-
1040
- if missing_keys:
1041
- print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1042
- initialized_params = initialize_missing_parameters(
1043
- missing_keys,
1044
- model.state_dict(),
1045
- torch_dtype
1046
- )
1047
- filtered_state_dict.update(initialized_params)
1048
-
1049
- if diffusers_version >= "0.33.0":
1050
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1051
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1052
- load_model_dict_into_meta(
1053
- model,
1054
- filtered_state_dict,
1055
- dtype=torch_dtype,
1056
- model_name_or_path=pretrained_model_path,
1057
- )
1058
- else:
1059
- model._convert_deprecated_attention_blocks(filtered_state_dict)
1060
- unexpected_keys = load_model_dict_into_meta(
1061
- model,
1062
- filtered_state_dict,
1063
- device=param_device,
1064
- dtype=torch_dtype,
1065
- model_name_or_path=pretrained_model_path,
1066
- )
1067
-
1068
- if cls._keys_to_ignore_on_load_unexpected is not None:
1069
- for pat in cls._keys_to_ignore_on_load_unexpected:
1070
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1071
-
1072
- if len(unexpected_keys) > 0:
1073
- print(
1074
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1075
- )
1076
-
1077
- return model
1078
- except Exception as e:
1079
- print(
1080
- f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
1081
- )
1082
-
1083
- model = cls.from_config(config, **transformer_additional_kwargs)
1084
- if os.path.exists(model_file):
1085
- state_dict = torch.load(model_file, map_location="cpu")
1086
- elif os.path.exists(model_file_safetensors):
1087
- from safetensors.torch import load_file, safe_open
1088
- state_dict = load_file(model_file_safetensors)
1089
- else:
1090
- from safetensors.torch import load_file, safe_open
1091
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1092
- state_dict = {}
1093
- for _model_file_safetensors in model_files_safetensors:
1094
- _state_dict = load_file(_model_file_safetensors)
1095
- for key in _state_dict:
1096
- state_dict[key] = _state_dict[key]
1097
-
1098
- tmp_state_dict = {}
1099
- for key in state_dict:
1100
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1101
- tmp_state_dict[key] = state_dict[key]
1102
- else:
1103
- print(key, "Size doesn't match, skip")
1104
-
1105
- state_dict = tmp_state_dict
1106
-
1107
- m, u = model.load_state_dict(state_dict, strict=False)
1108
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1109
- print(m)
1110
-
1111
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1112
- print(f"### All Parameters: {sum(params) / 1e6} M")
1113
-
1114
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1115
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1116
-
1117
- model = model.to(torch_dtype)
1118
- return model
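
For reference, below is a minimal, hypothetical sketch of how the QwenImageTransformer2DModel deleted above could be exercised, assuming a repository revision in which videox_fun/models/qwenimage_transformer2d.py (and the videox_fun.dist / videox_fun.models helpers it imports) still exists. The tiny configuration and tensor shapes are illustrative only; the shipped defaults are the ones documented in the class (60 layers, 24 heads of dimension 128, joint_attention_dim 3584).

import torch
from videox_fun.models.qwenimage_transformer2d import QwenImageTransformer2DModel

# Tiny illustrative config; axes_dims_rope must sum to attention_head_dim.
model = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=16,
    num_layers=2,
    attention_head_dim=32,
    num_attention_heads=2,
    joint_attention_dim=32,
    axes_dims_rope=(8, 12, 12),
).eval()

img_shapes = [(1, 4, 4)]                       # (frame, height, width) of the packed image latents
hidden_states = torch.randn(1, 1 * 4 * 4, 16)  # (batch, image_seq_len, in_channels)
encoder_hidden_states = torch.randn(1, 7, 32)  # (batch, text_seq_len, joint_attention_dim)
encoder_hidden_states_mask = torch.ones(1, 7)

with torch.no_grad():
    out = model(
        hidden_states=hidden_states,
        encoder_hidden_states=encoder_hidden_states,
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        timestep=torch.tensor([500]),
        img_shapes=img_shapes,
        txt_seq_lens=[7],
    )
print(out.shape)  # (1, 16, patch_size**2 * out_channels) = (1, 16, 64)

Note that this forward() returns the projected latent tokens directly rather than a Transformer2DModelOutput, so the result above is a plain tensor.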
 
 
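The next deleted file, videox_fun/models/qwenimage_vae.py, builds its encoder and decoder on causal 3D convolutions (QwenImageCausalConv3d) that pad only the "past" side of the time axis and cache trailing frames for chunked inference. The snippet below is a self-contained sketch of just the causal-padding idea, not the removed implementation; the class name and shapes are made up for illustration.

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv3dSketch(nn.Conv3d):
    """Pad height/width symmetrically, but pad time only on the 'past' side."""
    def __init__(self, in_ch, out_ch, kernel_size=3):
        super().__init__(in_ch, out_ch, kernel_size, padding=0)
        kt, kh, kw = self.kernel_size
        # F.pad order: (w_left, w_right, h_top, h_bottom, t_front, t_back)
        self._causal_pad = (kw // 2, kw // 2, kh // 2, kh // 2, kt - 1, 0)

    def forward(self, x):
        # Frame t only ever sees frames <= t, which is what makes chunked
        # decoding with a small temporal feature cache possible.
        return super().forward(F.pad(x, self._causal_pad))

x = torch.randn(1, 3, 5, 8, 8)  # (batch, channels, time, height, width)
y = CausalConv3dSketch(3, 4)(x)
print(y.shape)                  # torch.Size([1, 4, 5, 8, 8])
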
videox_fun/models/qwenimage_vae.py DELETED
@@ -1,1087 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoders/autoencoder_kl_qwenimage.py
2
- # Copyright 2025 The Qwen-Image Team, Wan Team and The HuggingFace Team. All rights reserved.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- #
16
- # We gratefully acknowledge the Wan Team for their outstanding contributions.
17
- # QwenImageVAE is further fine-tuned from the Wan Video VAE to achieve improved performance.
18
- # For more information about the Wan VAE, please refer to:
19
- # - GitHub: https://github.com/Wan-Video/Wan2.1
20
- # - arXiv: https://arxiv.org/abs/2503.20314
21
-
22
- import functools
23
- import glob
24
- import json
25
- import math
26
- import os
27
- import types
28
- import warnings
29
- from typing import Any, Dict, List, Optional, Tuple, Union
30
-
31
- import numpy as np
32
- import torch
33
- import torch.cuda.amp as amp
34
- import torch.nn as nn
35
- import torch.nn.functional as F
36
- import torch.utils.checkpoint
37
- from diffusers.configuration_utils import ConfigMixin, register_to_config
38
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
39
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
40
- from diffusers.models.activations import get_activation
41
- from diffusers.models.attention import FeedForward
42
- from diffusers.models.attention_processor import Attention
43
- from diffusers.models.autoencoders.vae import (DecoderOutput,
44
- DiagonalGaussianDistribution)
45
- from diffusers.models.embeddings import TimestepEmbedding, Timesteps
46
- from diffusers.models.modeling_outputs import (AutoencoderKLOutput,
47
- Transformer2DModelOutput)
48
- from diffusers.models.modeling_utils import ModelMixin
49
- from diffusers.models.normalization import AdaLayerNormContinuous, RMSNorm
50
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
51
- scale_lora_layers, unscale_lora_layers)
52
- from diffusers.utils.accelerate_utils import apply_forward_hook
53
- from diffusers.utils.torch_utils import maybe_allow_in_graph
54
- from torch import nn
55
-
56
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
-
58
- CACHE_T = 2
59
-
60
- class QwenImageCausalConv3d(nn.Conv3d):
61
- r"""
62
- A custom 3D causal convolution layer with feature caching support.
63
-
64
- This layer extends the standard Conv3D layer by ensuring causality in the time dimension and handling feature
65
- caching for efficient inference.
66
-
67
- Args:
68
- in_channels (int): Number of channels in the input image
69
- out_channels (int): Number of channels produced by the convolution
70
- kernel_size (int or tuple): Size of the convolving kernel
71
- stride (int or tuple, optional): Stride of the convolution. Default: 1
72
- padding (int or tuple, optional): Zero-padding added to all three sides of the input. Default: 0
73
- """
74
-
75
- def __init__(
76
- self,
77
- in_channels: int,
78
- out_channels: int,
79
- kernel_size: Union[int, Tuple[int, int, int]],
80
- stride: Union[int, Tuple[int, int, int]] = 1,
81
- padding: Union[int, Tuple[int, int, int]] = 0,
82
- ) -> None:
83
- super().__init__(
84
- in_channels=in_channels,
85
- out_channels=out_channels,
86
- kernel_size=kernel_size,
87
- stride=stride,
88
- padding=padding,
89
- )
90
-
91
- # Set up causal padding
92
- self._padding = (self.padding[2], self.padding[2], self.padding[1], self.padding[1], 2 * self.padding[0], 0)
93
- self.padding = (0, 0, 0)
94
-
95
- def forward(self, x, cache_x=None):
96
- padding = list(self._padding)
97
- if cache_x is not None and self._padding[4] > 0:
98
- cache_x = cache_x.to(x.device)
99
- x = torch.cat([cache_x, x], dim=2)
100
- padding[4] -= cache_x.shape[2]
101
- x = F.pad(x, padding)
102
- return super().forward(x)
103
-
104
-
105
- class QwenImageRMS_norm(nn.Module):
106
- r"""
107
- A custom RMS normalization layer.
108
-
109
- Args:
110
- dim (int): The number of dimensions to normalize over.
111
- channel_first (bool, optional): Whether the input tensor has channels as the first dimension.
112
- Default is True.
113
- images (bool, optional): Whether the input represents image data. Default is True.
114
- bias (bool, optional): Whether to include a learnable bias term. Default is False.
115
- """
116
-
117
- def __init__(self, dim: int, channel_first: bool = True, images: bool = True, bias: bool = False) -> None:
118
- super().__init__()
119
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
120
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
121
-
122
- self.channel_first = channel_first
123
- self.scale = dim**0.5
124
- self.gamma = nn.Parameter(torch.ones(shape))
125
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
126
-
127
- def forward(self, x):
128
- return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
129
-
130
-
131
- class QwenImageUpsample(nn.Upsample):
132
- r"""
133
- Perform upsampling while ensuring the output tensor has the same data type as the input.
134
-
135
- Args:
136
- x (torch.Tensor): Input tensor to be upsampled.
137
-
138
- Returns:
139
- torch.Tensor: Upsampled tensor with the same data type as the input.
140
- """
141
-
142
- def forward(self, x):
143
- return super().forward(x.float()).type_as(x)
144
-
145
-
146
- class QwenImageResample(nn.Module):
147
- r"""
148
- A custom resampling module for 2D and 3D data.
149
-
150
- Args:
151
- dim (int): The number of input/output channels.
152
- mode (str): The resampling mode. Must be one of:
153
- - 'none': No resampling (identity operation).
154
- - 'upsample2d': 2D upsampling with nearest-exact interpolation and convolution.
155
- - 'upsample3d': 3D upsampling with nearest-exact interpolation, convolution, and causal 3D convolution.
156
- - 'downsample2d': 2D downsampling with zero-padding and convolution.
157
- - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
158
- """
159
-
160
- def __init__(self, dim: int, mode: str) -> None:
161
- super().__init__()
162
- self.dim = dim
163
- self.mode = mode
164
-
165
- # layers
166
- if mode == "upsample2d":
167
- self.resample = nn.Sequential(
168
- QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
169
- nn.Conv2d(dim, dim // 2, 3, padding=1),
170
- )
171
- elif mode == "upsample3d":
172
- self.resample = nn.Sequential(
173
- QwenImageUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
174
- nn.Conv2d(dim, dim // 2, 3, padding=1),
175
- )
176
- self.time_conv = QwenImageCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
177
-
178
- elif mode == "downsample2d":
179
- self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
180
- elif mode == "downsample3d":
181
- self.resample = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2)))
182
- self.time_conv = QwenImageCausalConv3d(dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
183
-
184
- else:
185
- self.resample = nn.Identity()
186
-
187
- def forward(self, x, feat_cache=None, feat_idx=[0]):
188
- b, c, t, h, w = x.size()
189
- if self.mode == "upsample3d":
190
- if feat_cache is not None:
191
- idx = feat_idx[0]
192
- if feat_cache[idx] is None:
193
- feat_cache[idx] = "Rep"
194
- feat_idx[0] += 1
195
- else:
196
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
197
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] != "Rep":
198
- # cache last frame of last two chunk
199
- cache_x = torch.cat(
200
- [feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2
201
- )
202
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx] == "Rep":
203
- cache_x = torch.cat([torch.zeros_like(cache_x).to(cache_x.device), cache_x], dim=2)
204
- if feat_cache[idx] == "Rep":
205
- x = self.time_conv(x)
206
- else:
207
- x = self.time_conv(x, feat_cache[idx])
208
- feat_cache[idx] = cache_x
209
- feat_idx[0] += 1
210
-
211
- x = x.reshape(b, 2, c, t, h, w)
212
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
213
- x = x.reshape(b, c, t * 2, h, w)
214
- t = x.shape[2]
215
- x = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
216
- x = self.resample(x)
217
- x = x.view(b, t, x.size(1), x.size(2), x.size(3)).permute(0, 2, 1, 3, 4)
218
-
219
- if self.mode == "downsample3d":
220
- if feat_cache is not None:
221
- idx = feat_idx[0]
222
- if feat_cache[idx] is None:
223
- feat_cache[idx] = x.clone()
224
- feat_idx[0] += 1
225
- else:
226
- cache_x = x[:, :, -1:, :, :].clone()
227
- x = self.time_conv(torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
228
- feat_cache[idx] = cache_x
229
- feat_idx[0] += 1
230
- return x
231
-
232
-
233
- class QwenImageResidualBlock(nn.Module):
234
- r"""
235
- A custom residual block module.
236
-
237
- Args:
238
- in_dim (int): Number of input channels.
239
- out_dim (int): Number of output channels.
240
- dropout (float, optional): Dropout rate for the dropout layer. Default is 0.0.
241
- non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
242
- """
243
-
244
- def __init__(
245
- self,
246
- in_dim: int,
247
- out_dim: int,
248
- dropout: float = 0.0,
249
- non_linearity: str = "silu",
250
- ) -> None:
251
- super().__init__()
252
- self.in_dim = in_dim
253
- self.out_dim = out_dim
254
- self.nonlinearity = get_activation(non_linearity)
255
-
256
- # layers
257
- self.norm1 = QwenImageRMS_norm(in_dim, images=False)
258
- self.conv1 = QwenImageCausalConv3d(in_dim, out_dim, 3, padding=1)
259
- self.norm2 = QwenImageRMS_norm(out_dim, images=False)
260
- self.dropout = nn.Dropout(dropout)
261
- self.conv2 = QwenImageCausalConv3d(out_dim, out_dim, 3, padding=1)
262
- self.conv_shortcut = QwenImageCausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
263
-
264
- def forward(self, x, feat_cache=None, feat_idx=[0]):
265
- # Apply shortcut connection
266
- h = self.conv_shortcut(x)
267
-
268
- # First normalization and activation
269
- x = self.norm1(x)
270
- x = self.nonlinearity(x)
271
-
272
- if feat_cache is not None:
273
- idx = feat_idx[0]
274
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
275
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
276
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
277
-
278
- x = self.conv1(x, feat_cache[idx])
279
- feat_cache[idx] = cache_x
280
- feat_idx[0] += 1
281
- else:
282
- x = self.conv1(x)
283
-
284
- # Second normalization and activation
285
- x = self.norm2(x)
286
- x = self.nonlinearity(x)
287
-
288
- # Dropout
289
- x = self.dropout(x)
290
-
291
- if feat_cache is not None:
292
- idx = feat_idx[0]
293
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
294
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
295
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
296
-
297
- x = self.conv2(x, feat_cache[idx])
298
- feat_cache[idx] = cache_x
299
- feat_idx[0] += 1
300
- else:
301
- x = self.conv2(x)
302
-
303
- # Add residual connection
304
- return x + h
305
-
306
-
307
- class QwenImageAttentionBlock(nn.Module):
308
- r"""
309
- Causal self-attention with a single head.
310
-
311
- Args:
312
- dim (int): The number of channels in the input tensor.
313
- """
314
-
315
- def __init__(self, dim):
316
- super().__init__()
317
- self.dim = dim
318
-
319
- # layers
320
- self.norm = QwenImageRMS_norm(dim)
321
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
322
- self.proj = nn.Conv2d(dim, dim, 1)
323
-
324
- def forward(self, x):
325
- identity = x
326
- batch_size, channels, time, height, width = x.size()
327
-
328
- x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * time, channels, height, width)
329
- x = self.norm(x)
330
-
331
- # compute query, key, value
332
- qkv = self.to_qkv(x)
333
- qkv = qkv.reshape(batch_size * time, 1, channels * 3, -1)
334
- qkv = qkv.permute(0, 1, 3, 2).contiguous()
335
- q, k, v = qkv.chunk(3, dim=-1)
336
-
337
- # apply attention
338
- x = F.scaled_dot_product_attention(q, k, v)
339
-
340
- x = x.squeeze(1).permute(0, 2, 1).reshape(batch_size * time, channels, height, width)
341
-
342
- # output projection
343
- x = self.proj(x)
344
-
345
- # Reshape back: [(b*t), c, h, w] -> [b, c, t, h, w]
346
- x = x.view(batch_size, time, channels, height, width)
347
- x = x.permute(0, 2, 1, 3, 4)
348
-
349
- return x + identity
350
-
351
-
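A small layout sketch (illustrative only, not the module itself) of the attention pattern used by `QwenImageAttentionBlock`: a single head attends over the `h*w` spatial positions of each frame independently, with the head dimension equal to the channel count:

```python
import torch
import torch.nn.functional as F

b, c, t, h, w = 1, 8, 2, 4, 4
x = torch.randn(b, c, t, h, w)

tokens = x.permute(0, 2, 1, 3, 4).reshape(b * t, c, h * w)   # [(b t), c, h*w]
q = k = v = tokens.permute(0, 2, 1).unsqueeze(1)             # [(b t), 1 head, h*w, c]
out = F.scaled_dot_product_attention(q, k, v)                # per-frame spatial attention
out = out.squeeze(1).permute(0, 2, 1).reshape(b, t, c, h, w).permute(0, 2, 1, 3, 4)
assert out.shape == x.shape                                  # back to [b, c, t, h, w]
```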
352
- class QwenImageMidBlock(nn.Module):
353
- """
354
- Middle block for QwenImageVAE encoder and decoder.
355
-
356
- Args:
357
- dim (int): Number of input/output channels.
358
- dropout (float): Dropout rate.
359
- non_linearity (str): Type of non-linearity to use.
360
- """
361
-
362
- def __init__(self, dim: int, dropout: float = 0.0, non_linearity: str = "silu", num_layers: int = 1):
363
- super().__init__()
364
- self.dim = dim
365
-
366
- # Create the components
367
- resnets = [QwenImageResidualBlock(dim, dim, dropout, non_linearity)]
368
- attentions = []
369
- for _ in range(num_layers):
370
- attentions.append(QwenImageAttentionBlock(dim))
371
- resnets.append(QwenImageResidualBlock(dim, dim, dropout, non_linearity))
372
- self.attentions = nn.ModuleList(attentions)
373
- self.resnets = nn.ModuleList(resnets)
374
-
375
- self.gradient_checkpointing = False
376
-
377
- def forward(self, x, feat_cache=None, feat_idx=[0]):
378
- # First residual block
379
- x = self.resnets[0](x, feat_cache, feat_idx)
380
-
381
- # Process through attention and residual blocks
382
- for attn, resnet in zip(self.attentions, self.resnets[1:]):
383
- if attn is not None:
384
- x = attn(x)
385
-
386
- x = resnet(x, feat_cache, feat_idx)
387
-
388
- return x
389
-
390
-
391
- class QwenImageEncoder3d(nn.Module):
392
- r"""
393
- A 3D encoder module.
394
-
395
- Args:
396
- dim (int): The base number of channels in the first layer.
397
- z_dim (int): The dimensionality of the latent space.
398
- dim_mult (list of int): Multipliers for the number of channels in each block.
399
- num_res_blocks (int): Number of residual blocks in each block.
400
- attn_scales (list of float): Scales at which to apply attention mechanisms.
401
- temperal_downsample (list of bool): Whether to downsample temporally in each block.
402
- dropout (float): Dropout rate for the dropout layers.
403
- non_linearity (str): Type of non-linearity to use.
404
- """
405
-
406
- def __init__(
407
- self,
408
- dim=128,
409
- z_dim=4,
410
- dim_mult=[1, 2, 4, 4],
411
- num_res_blocks=2,
412
- attn_scales=[],
413
- temperal_downsample=[True, True, False],
414
- dropout=0.0,
415
- non_linearity: str = "silu",
416
- ):
417
- super().__init__()
418
- self.dim = dim
419
- self.z_dim = z_dim
420
- self.dim_mult = dim_mult
421
- self.num_res_blocks = num_res_blocks
422
- self.attn_scales = attn_scales
423
- self.temperal_downsample = temperal_downsample
424
- self.nonlinearity = get_activation(non_linearity)
425
-
426
- # dimensions
427
- dims = [dim * u for u in [1] + dim_mult]
428
- scale = 1.0
429
-
430
- # init block
431
- self.conv_in = QwenImageCausalConv3d(3, dims[0], 3, padding=1)
432
-
433
- # downsample blocks
434
- self.down_blocks = nn.ModuleList([])
435
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
436
- # residual (+attention) blocks
437
- for _ in range(num_res_blocks):
438
- self.down_blocks.append(QwenImageResidualBlock(in_dim, out_dim, dropout))
439
- if scale in attn_scales:
440
- self.down_blocks.append(QwenImageAttentionBlock(out_dim))
441
- in_dim = out_dim
442
-
443
- # downsample block
444
- if i != len(dim_mult) - 1:
445
- mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
446
- self.down_blocks.append(QwenImageResample(out_dim, mode=mode))
447
- scale /= 2.0
448
-
449
- # middle blocks
450
- self.mid_block = QwenImageMidBlock(out_dim, dropout, non_linearity, num_layers=1)
451
-
452
- # output blocks
453
- self.norm_out = QwenImageRMS_norm(out_dim, images=False)
454
- self.conv_out = QwenImageCausalConv3d(out_dim, z_dim, 3, padding=1)
455
-
456
- self.gradient_checkpointing = False
457
-
458
- def forward(self, x, feat_cache=None, feat_idx=[0]):
459
- if feat_cache is not None:
460
- idx = feat_idx[0]
461
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
462
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
463
- # cache last frame of last two chunk
464
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
465
- x = self.conv_in(x, feat_cache[idx])
466
- feat_cache[idx] = cache_x
467
- feat_idx[0] += 1
468
- else:
469
- x = self.conv_in(x)
470
-
471
- ## downsamples
472
- for layer in self.down_blocks:
473
- if feat_cache is not None:
474
- x = layer(x, feat_cache, feat_idx)
475
- else:
476
- x = layer(x)
477
-
478
- ## middle
479
- x = self.mid_block(x, feat_cache, feat_idx)
480
-
481
- ## head
482
- x = self.norm_out(x)
483
- x = self.nonlinearity(x)
484
- if feat_cache is not None:
485
- idx = feat_idx[0]
486
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
487
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
488
- # cache last frame of last two chunk
489
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
490
- x = self.conv_out(x, feat_cache[idx])
491
- feat_cache[idx] = cache_x
492
- feat_idx[0] += 1
493
- else:
494
- x = self.conv_out(x)
495
- return x
496
-
497
-
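To make the constructor above easier to follow, this is the channel schedule it builds with the file's defaults `dim=128`, `dim_mult=[1, 2, 4, 4]` (shown only as an illustration of the arithmetic):

```python
dim, dim_mult = 128, [1, 2, 4, 4]
dims = [dim * u for u in [1] + dim_mult]   # [128, 128, 256, 512, 512]
stages = list(zip(dims[:-1], dims[1:]))    # (in_dim, out_dim) per downsampling stage
print(stages)                              # [(128, 128), (128, 256), (256, 512), (512, 512)]
```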
498
- class QwenImageUpBlock(nn.Module):
499
- """
500
- A block that handles upsampling for the QwenImageVAE decoder.
501
-
502
- Args:
503
- in_dim (int): Input dimension
504
- out_dim (int): Output dimension
505
- num_res_blocks (int): Number of residual blocks
506
- dropout (float): Dropout rate
507
- upsample_mode (str, optional): Mode for upsampling ('upsample2d' or 'upsample3d')
508
- non_linearity (str): Type of non-linearity to use
509
- """
510
-
511
- def __init__(
512
- self,
513
- in_dim: int,
514
- out_dim: int,
515
- num_res_blocks: int,
516
- dropout: float = 0.0,
517
- upsample_mode: Optional[str] = None,
518
- non_linearity: str = "silu",
519
- ):
520
- super().__init__()
521
- self.in_dim = in_dim
522
- self.out_dim = out_dim
523
-
524
- # Create layers list
525
- resnets = []
526
- # Add residual blocks and attention if needed
527
- current_dim = in_dim
528
- for _ in range(num_res_blocks + 1):
529
- resnets.append(QwenImageResidualBlock(current_dim, out_dim, dropout, non_linearity))
530
- current_dim = out_dim
531
-
532
- self.resnets = nn.ModuleList(resnets)
533
-
534
- # Add upsampling layer if needed
535
- self.upsamplers = None
536
- if upsample_mode is not None:
537
- self.upsamplers = nn.ModuleList([QwenImageResample(out_dim, mode=upsample_mode)])
538
-
539
- self.gradient_checkpointing = False
540
-
541
- def forward(self, x, feat_cache=None, feat_idx=[0]):
542
- """
543
- Forward pass through the upsampling block.
544
-
545
- Args:
546
- x (torch.Tensor): Input tensor
547
- feat_cache (list, optional): Feature cache for causal convolutions
548
- feat_idx (list, optional): Feature index for cache management
549
-
550
- Returns:
551
- torch.Tensor: Output tensor
552
- """
553
- for resnet in self.resnets:
554
- if feat_cache is not None:
555
- x = resnet(x, feat_cache, feat_idx)
556
- else:
557
- x = resnet(x)
558
-
559
- if self.upsamplers is not None:
560
- if feat_cache is not None:
561
- x = self.upsamplers[0](x, feat_cache, feat_idx)
562
- else:
563
- x = self.upsamplers[0](x)
564
- return x
565
-
566
-
567
- class QwenImageDecoder3d(nn.Module):
568
- r"""
569
- A 3D decoder module.
570
-
571
- Args:
572
- dim (int): The base number of channels in the first layer.
573
- z_dim (int): The dimensionality of the latent space.
574
- dim_mult (list of int): Multipliers for the number of channels in each block.
575
- num_res_blocks (int): Number of residual blocks in each block.
576
- attn_scales (list of float): Scales at which to apply attention mechanisms.
577
- temperal_upsample (list of bool): Whether to upsample temporally in each block.
578
- dropout (float): Dropout rate for the dropout layers.
579
- non_linearity (str): Type of non-linearity to use.
580
- """
581
-
582
- def __init__(
583
- self,
584
- dim=128,
585
- z_dim=4,
586
- dim_mult=[1, 2, 4, 4],
587
- num_res_blocks=2,
588
- attn_scales=[],
589
- temperal_upsample=[False, True, True],
590
- dropout=0.0,
591
- non_linearity: str = "silu",
592
- ):
593
- super().__init__()
594
- self.dim = dim
595
- self.z_dim = z_dim
596
- self.dim_mult = dim_mult
597
- self.num_res_blocks = num_res_blocks
598
- self.attn_scales = attn_scales
599
- self.temperal_upsample = temperal_upsample
600
-
601
- self.nonlinearity = get_activation(non_linearity)
602
-
603
- # dimensions
604
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
605
- scale = 1.0 / 2 ** (len(dim_mult) - 2)
606
-
607
- # init block
608
- self.conv_in = QwenImageCausalConv3d(z_dim, dims[0], 3, padding=1)
609
-
610
- # middle blocks
611
- self.mid_block = QwenImageMidBlock(dims[0], dropout, non_linearity, num_layers=1)
612
-
613
- # upsample blocks
614
- self.up_blocks = nn.ModuleList([])
615
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
616
- # residual (+attention) blocks
617
- if i > 0:
618
- in_dim = in_dim // 2
619
-
620
- # Determine if we need upsampling
621
- upsample_mode = None
622
- if i != len(dim_mult) - 1:
623
- upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
624
-
625
- # Create and add the upsampling block
626
- up_block = QwenImageUpBlock(
627
- in_dim=in_dim,
628
- out_dim=out_dim,
629
- num_res_blocks=num_res_blocks,
630
- dropout=dropout,
631
- upsample_mode=upsample_mode,
632
- non_linearity=non_linearity,
633
- )
634
- self.up_blocks.append(up_block)
635
-
636
- # Update scale for next iteration
637
- if upsample_mode is not None:
638
- scale *= 2.0
639
-
640
- # output blocks
641
- self.norm_out = QwenImageRMS_norm(out_dim, images=False)
642
- self.conv_out = QwenImageCausalConv3d(out_dim, 3, 3, padding=1)
643
-
644
- self.gradient_checkpointing = False
645
-
646
- def forward(self, x, feat_cache=None, feat_idx=[0]):
647
- ## conv1
648
- if feat_cache is not None:
649
- idx = feat_idx[0]
650
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
651
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
652
- # cache last frame of last two chunk
653
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
654
- x = self.conv_in(x, feat_cache[idx])
655
- feat_cache[idx] = cache_x
656
- feat_idx[0] += 1
657
- else:
658
- x = self.conv_in(x)
659
-
660
- ## middle
661
- x = self.mid_block(x, feat_cache, feat_idx)
662
-
663
- ## upsamples
664
- for up_block in self.up_blocks:
665
- x = up_block(x, feat_cache, feat_idx)
666
-
667
- ## head
668
- x = self.norm_out(x)
669
- x = self.nonlinearity(x)
670
- if feat_cache is not None:
671
- idx = feat_idx[0]
672
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
673
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
674
- # cache last frame of last two chunk
675
- cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
676
- x = self.conv_out(x, feat_cache[idx])
677
- feat_cache[idx] = cache_x
678
- feat_idx[0] += 1
679
- else:
680
- x = self.conv_out(x)
681
- return x
682
-
683
-
684
- class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
685
- r"""
686
- A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
687
-
688
- This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
689
- for all models (such as downloading or saving).
690
- """
691
-
692
- _supports_gradient_checkpointing = False
693
-
694
- # fmt: off
695
- @register_to_config
696
- def __init__(
697
- self,
698
- base_dim: int = 96,
699
- z_dim: int = 16,
700
- dim_mult: Tuple[int] = [1, 2, 4, 4],
701
- num_res_blocks: int = 2,
702
- attn_scales: List[float] = [],
703
- temperal_downsample: List[bool] = [False, True, True],
704
- dropout: float = 0.0,
705
- latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
706
- latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
707
- ) -> None:
708
- # fmt: on
709
- super().__init__()
710
-
711
- self.z_dim = z_dim
712
- self.temperal_downsample = temperal_downsample
713
- self.temperal_upsample = temperal_downsample[::-1]
714
-
715
- self.encoder = QwenImageEncoder3d(
716
- base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
717
- )
718
- self.quant_conv = QwenImageCausalConv3d(z_dim * 2, z_dim * 2, 1)
719
- self.post_quant_conv = QwenImageCausalConv3d(z_dim, z_dim, 1)
720
-
721
- self.decoder = QwenImageDecoder3d(
722
- base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
723
- )
724
-
725
- self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
726
-
727
- # When decoding a batch of video latents at a time, one can save memory by slicing across the batch dimension
728
- # to perform decoding of a single video latent at a time.
729
- self.use_slicing = False
730
-
731
- # When decoding spatially large video latents, the memory requirement is very high. By breaking the video latent
732
- # frames spatially into smaller tiles and performing multiple forward passes for decoding, and then blending the
733
- # intermediate tiles together, the memory requirement can be lowered.
734
- self.use_tiling = False
735
-
736
- # The minimal tile height and width for spatial tiling to be used
737
- self.tile_sample_min_height = 256
738
- self.tile_sample_min_width = 256
739
-
740
- # The minimal distance between two spatial tiles
741
- self.tile_sample_stride_height = 192
742
- self.tile_sample_stride_width = 192
743
-
744
- # Precompute and cache conv counts for encoder and decoder for clear_cache speedup
745
- self._cached_conv_counts = {
746
- "decoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.decoder.modules())
747
- if self.decoder is not None
748
- else 0,
749
- "encoder": sum(isinstance(m, QwenImageCausalConv3d) for m in self.encoder.modules())
750
- if self.encoder is not None
751
- else 0,
752
- }
753
-
754
- def enable_tiling(
755
- self,
756
- tile_sample_min_height: Optional[int] = None,
757
- tile_sample_min_width: Optional[int] = None,
758
- tile_sample_stride_height: Optional[float] = None,
759
- tile_sample_stride_width: Optional[float] = None,
760
- ) -> None:
761
- r"""
762
- Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
763
- compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
764
- processing larger images.
765
-
766
- Args:
767
- tile_sample_min_height (`int`, *optional*):
768
- The minimum height required for a sample to be separated into tiles across the height dimension.
769
- tile_sample_min_width (`int`, *optional*):
770
- The minimum width required for a sample to be separated into tiles across the width dimension.
771
- tile_sample_stride_height (`int`, *optional*):
772
- The stride between two consecutive vertical tiles. This is to ensure that there are
773
- no tiling artifacts produced across the height dimension.
774
- tile_sample_stride_width (`int`, *optional*):
775
- The stride between two consecutive horizontal tiles. This is to ensure that there are no tiling
776
- artifacts produced across the width dimension.
777
- """
778
- self.use_tiling = True
779
- self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
780
- self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
781
- self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
782
- self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
783
-
784
- def disable_tiling(self) -> None:
785
- r"""
786
- Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
787
- decoding in one step.
788
- """
789
- self.use_tiling = False
790
-
791
- def enable_slicing(self) -> None:
792
- r"""
793
- Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
794
- compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
795
- """
796
- self.use_slicing = True
797
-
798
- def disable_slicing(self) -> None:
799
- r"""
800
- Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
801
- decoding in one step.
802
- """
803
- self.use_slicing = False
804
-
805
- def clear_cache(self):
806
- def _count_conv3d(model):
807
- count = 0
808
- for m in model.modules():
809
- if isinstance(m, QwenImageCausalConv3d):
810
- count += 1
811
- return count
812
-
813
- self._conv_num = _count_conv3d(self.decoder)
814
- self._conv_idx = [0]
815
- self._feat_map = [None] * self._conv_num
816
- # cache encode
817
- self._enc_conv_num = _count_conv3d(self.encoder)
818
- self._enc_conv_idx = [0]
819
- self._enc_feat_map = [None] * self._enc_conv_num
820
-
821
- def _encode(self, x: torch.Tensor):
822
- _, _, num_frame, height, width = x.shape
823
-
824
- if self.use_tiling and (width > self.tile_sample_min_width or height > self.tile_sample_min_height):
825
- return self.tiled_encode(x)
826
-
827
- self.clear_cache()
828
- iter_ = 1 + (num_frame - 1) // 4
829
- for i in range(iter_):
830
- self._enc_conv_idx = [0]
831
- if i == 0:
832
- out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
833
- else:
834
- out_ = self.encoder(
835
- x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
836
- feat_cache=self._enc_feat_map,
837
- feat_idx=self._enc_conv_idx,
838
- )
839
- out = torch.cat([out, out_], 2)
840
-
841
- enc = self.quant_conv(out)
842
- self.clear_cache()
843
- return enc
844
-
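The temporal chunking performed by `_encode` above can be summarised with a short sketch: the first pixel frame is encoded on its own, then groups of four frames follow, so `num_frame` pixel frames (assumed to be of the form 1 + 4n) yield `1 + (num_frame - 1) // 4` latent frames:

```python
def encode_chunks(num_frame):
    chunks = [(0, 1)]                                 # the leading frame, encoded alone
    for i in range(1, 1 + (num_frame - 1) // 4):
        chunks.append((1 + 4 * (i - 1), 1 + 4 * i))   # subsequent slices of 4 frames
    return chunks

print(encode_chunks(13))  # [(0, 1), (1, 5), (5, 9), (9, 13)] -> 4 latent frames
```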
845
- @apply_forward_hook
846
- def encode(
847
- self, x: torch.Tensor, return_dict: bool = True
848
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
849
- r"""
850
- Encode a batch of images into latents.
851
-
852
- Args:
853
- x (`torch.Tensor`): Input batch of images.
854
- return_dict (`bool`, *optional*, defaults to `True`):
855
- Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
856
-
857
- Returns:
858
- The latent representations of the encoded videos. If `return_dict` is True, a
859
- [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
860
- """
861
- if self.use_slicing and x.shape[0] > 1:
862
- encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
863
- h = torch.cat(encoded_slices)
864
- else:
865
- h = self._encode(x)
866
- posterior = DiagonalGaussianDistribution(h)
867
-
868
- if not return_dict:
869
- return (posterior,)
870
- return AutoencoderKLOutput(latent_dist=posterior)
871
-
872
- def _decode(self, z: torch.Tensor, return_dict: bool = True):
873
- _, _, num_frame, height, width = z.shape
874
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
875
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
876
-
877
- if self.use_tiling and (width > tile_latent_min_width or height > tile_latent_min_height):
878
- return self.tiled_decode(z, return_dict=return_dict)
879
-
880
- self.clear_cache()
881
- x = self.post_quant_conv(z)
882
- for i in range(num_frame):
883
- self._conv_idx = [0]
884
- if i == 0:
885
- out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
886
- else:
887
- out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
888
- out = torch.cat([out, out_], 2)
889
-
890
- out = torch.clamp(out, min=-1.0, max=1.0)
891
- self.clear_cache()
892
- if not return_dict:
893
- return (out,)
894
-
895
- return DecoderOutput(sample=out)
896
-
897
- @apply_forward_hook
898
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
899
- r"""
900
- Decode a batch of images.
901
-
902
- Args:
903
- z (`torch.Tensor`): Input batch of latent vectors.
904
- return_dict (`bool`, *optional*, defaults to `True`):
905
- Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
906
-
907
- Returns:
908
- [`~models.vae.DecoderOutput`] or `tuple`:
909
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
910
- returned.
911
- """
912
- if self.use_slicing and z.shape[0] > 1:
913
- decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
914
- decoded = torch.cat(decoded_slices)
915
- else:
916
- decoded = self._decode(z).sample
917
-
918
- if not return_dict:
919
- return (decoded,)
920
- return DecoderOutput(sample=decoded)
921
-
922
- def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
923
- blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
924
- for y in range(blend_extent):
925
- b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
926
- y / blend_extent
927
- )
928
- return b
929
-
930
- def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
931
- blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
932
- for x in range(blend_extent):
933
- b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
934
- x / blend_extent
935
- )
936
- return b
937
-
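`blend_v`/`blend_h` above perform a linear crossfade over the overlapping region of two tiles; a one-dimensional sketch (illustrative only) makes the weighting explicit:

```python
import torch

def blend_1d(a, b, blend_extent):
    blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
    for x in range(blend_extent):
        w = x / blend_extent
        b[..., x] = a[..., -blend_extent + x] * (1 - w) + b[..., x] * w
    return b

a = torch.ones(8)          # trailing edge of the previous tile
b = torch.zeros(8)         # leading edge of the current tile
print(blend_1d(a, b, 4))   # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0., 0., 0., 0.])
```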
938
- def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
939
- r"""Encode a batch of images using a tiled encoder.
940
-
941
- Args:
942
- x (`torch.Tensor`): Input batch of videos.
943
-
944
- Returns:
945
- `torch.Tensor`:
946
- The latent representation of the encoded videos.
947
- """
948
- _, _, num_frames, height, width = x.shape
949
- latent_height = height // self.spatial_compression_ratio
950
- latent_width = width // self.spatial_compression_ratio
951
-
952
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
953
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
954
- tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
955
- tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
956
-
957
- blend_height = tile_latent_min_height - tile_latent_stride_height
958
- blend_width = tile_latent_min_width - tile_latent_stride_width
959
-
960
- # Split x into overlapping tiles and encode them separately.
961
- # The tiles have an overlap to avoid seams between tiles.
962
- rows = []
963
- for i in range(0, height, self.tile_sample_stride_height):
964
- row = []
965
- for j in range(0, width, self.tile_sample_stride_width):
966
- self.clear_cache()
967
- time = []
968
- frame_range = 1 + (num_frames - 1) // 4
969
- for k in range(frame_range):
970
- self._enc_conv_idx = [0]
971
- if k == 0:
972
- tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
973
- else:
974
- tile = x[
975
- :,
976
- :,
977
- 1 + 4 * (k - 1) : 1 + 4 * k,
978
- i : i + self.tile_sample_min_height,
979
- j : j + self.tile_sample_min_width,
980
- ]
981
- tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
982
- tile = self.quant_conv(tile)
983
- time.append(tile)
984
- row.append(torch.cat(time, dim=2))
985
- rows.append(row)
986
- self.clear_cache()
987
-
988
- result_rows = []
989
- for i, row in enumerate(rows):
990
- result_row = []
991
- for j, tile in enumerate(row):
992
- # blend the above tile and the left tile
993
- # to the current tile and add the current tile to the result row
994
- if i > 0:
995
- tile = self.blend_v(rows[i - 1][j], tile, blend_height)
996
- if j > 0:
997
- tile = self.blend_h(row[j - 1], tile, blend_width)
998
- result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
999
- result_rows.append(torch.cat(result_row, dim=-1))
1000
-
1001
- enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
1002
- return enc
1003
-
1004
- def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1005
- r"""
1006
- Decode a batch of images using a tiled decoder.
1007
-
1008
- Args:
1009
- z (`torch.Tensor`): Input batch of latent vectors.
1010
- return_dict (`bool`, *optional*, defaults to `True`):
1011
- Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1012
-
1013
- Returns:
1014
- [`~models.vae.DecoderOutput`] or `tuple`:
1015
- If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1016
- returned.
1017
- """
1018
- _, _, num_frames, height, width = z.shape
1019
- sample_height = height * self.spatial_compression_ratio
1020
- sample_width = width * self.spatial_compression_ratio
1021
-
1022
- tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
1023
- tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
1024
- tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
1025
- tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
1026
-
1027
- blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
1028
- blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
1029
-
1030
- # Split z into overlapping tiles and decode them separately.
1031
- # The tiles have an overlap to avoid seams between tiles.
1032
- rows = []
1033
- for i in range(0, height, tile_latent_stride_height):
1034
- row = []
1035
- for j in range(0, width, tile_latent_stride_width):
1036
- self.clear_cache()
1037
- time = []
1038
- for k in range(num_frames):
1039
- self._conv_idx = [0]
1040
- tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
1041
- tile = self.post_quant_conv(tile)
1042
- decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
1043
- time.append(decoded)
1044
- row.append(torch.cat(time, dim=2))
1045
- rows.append(row)
1046
- self.clear_cache()
1047
-
1048
- result_rows = []
1049
- for i, row in enumerate(rows):
1050
- result_row = []
1051
- for j, tile in enumerate(row):
1052
- # blend the above tile and the left tile
1053
- # to the current tile and add the current tile to the result row
1054
- if i > 0:
1055
- tile = self.blend_v(rows[i - 1][j], tile, blend_height)
1056
- if j > 0:
1057
- tile = self.blend_h(row[j - 1], tile, blend_width)
1058
- result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
1059
- result_rows.append(torch.cat(result_row, dim=-1))
1060
-
1061
- dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
1062
-
1063
- if not return_dict:
1064
- return (dec,)
1065
- return DecoderOutput(sample=dec)
1066
-
1067
- def forward(
1068
- self,
1069
- sample: torch.Tensor,
1070
- sample_posterior: bool = False,
1071
- return_dict: bool = True,
1072
- generator: Optional[torch.Generator] = None,
1073
- ) -> Union[DecoderOutput, torch.Tensor]:
1074
- """
1075
- Args:
1076
- sample (`torch.Tensor`): Input sample.
1077
- return_dict (`bool`, *optional*, defaults to `True`):
1078
- Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
1079
- """
1080
- x = sample
1081
- posterior = self.encode(x).latent_dist
1082
- if sample_posterior:
1083
- z = posterior.sample(generator=generator)
1084
- else:
1085
- z = posterior.mode()
1086
- dec = self.decode(z, return_dict=return_dict)
1087
- return dec
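For reference, a hedged usage sketch of the deleted `AutoencoderKLQwenImage` (assuming a checkout that still contains the module and its diffusers dependencies; the import path, shapes, and frame counts below are illustrative, not prescriptive):

```python
import torch
from videox_fun.models.qwenimage_vae import AutoencoderKLQwenImage  # assumed import path

vae = AutoencoderKLQwenImage()            # defaults: z_dim=16, 8x spatial compression
vae.enable_tiling()                       # optional: spatial tiling to bound memory use
video = torch.randn(1, 3, 13, 256, 256)   # [B, C, T, H, W] with T = 1 + 4n frames

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()   # [1, 16, 4, 32, 32]
    recon = vae.decode(latents).sample                 # back to [1, 3, 13, 256, 256]
```
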
videox_fun/models/wan_animate_adapter.py DELETED
@@ -1,397 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import math
3
- from typing import Optional, Tuple
4
-
5
- import torch
6
- import torch.nn.functional as F
7
- import numpy as np
8
- from einops import rearrange
9
- from torch import nn
10
-
11
- try:
12
- from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
13
- except ImportError:
14
- flash_attn_func, flash_attn_qkvpacked_func = None, None
15
-
16
-
17
- MEMORY_LAYOUT = {
18
- "flash": (
19
- lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
20
- lambda x: x,
21
- ),
22
- "torch": (
23
- lambda x: x.transpose(1, 2),
24
- lambda x: x.transpose(1, 2),
25
- ),
26
- "vanilla": (
27
- lambda x: x.transpose(1, 2),
28
- lambda x: x.transpose(1, 2),
29
- ),
30
- }
31
-
32
-
33
- def attention(
34
- q,
35
- k,
36
- v,
37
- mode="flash",
38
- drop_rate=0,
39
- attn_mask=None,
40
- causal=False,
41
- max_seqlen_q=None,
42
- batch_size=1,
43
- ):
44
- """
45
- Perform QKV self attention.
46
-
47
- Args:
48
- q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
49
- k (torch.Tensor): Key tensor with shape [b, s1, a, d]
50
- v (torch.Tensor): Value tensor with shape [b, s1, a, d]
51
- mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
52
- drop_rate (float): Dropout rate in attention map. (default: 0)
53
- attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross_attn), or [b, a, s, s1] (torch or vanilla).
54
- (default: None)
55
- causal (bool): Whether to use causal attention. (default: False)
56
- cu_seqlens_q (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
57
- used to index into q.
58
- cu_seqlens_kv (torch.Tensor): dtype torch.int32. The cumulative sequence lengths of the sequences in the batch,
59
- used to index into kv.
60
- max_seqlen_q (int): The maximum sequence length in the batch of q.
61
- max_seqlen_kv (int): The maximum sequence length in the batch of k and v.
62
-
63
- Returns:
64
- torch.Tensor: Output tensor after self attention with shape [b, s, ad]
65
- """
66
- pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
67
-
68
- if mode == "torch":
69
- if attn_mask is not None and attn_mask.dtype != torch.bool:
70
- attn_mask = attn_mask.to(q.dtype)
71
- x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal)
72
-
73
- elif mode == "flash":
74
- x = flash_attn_func(
75
- q,
76
- k,
77
- v,
78
- )
79
- x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1]) # reshape x to [b, s, a, d]
80
- elif mode == "vanilla":
81
- scale_factor = 1 / math.sqrt(q.size(-1))
82
-
83
- b, a, s, _ = q.shape
84
- s1 = k.size(2)
85
- attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
86
- if causal:
87
- # Only applied to self attention
88
- assert attn_mask is None, "Causal mask and attn_mask cannot be used together"
89
- temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(diagonal=0)
90
- attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
91
- attn_bias.to(q.dtype)
92
-
93
- if attn_mask is not None:
94
- if attn_mask.dtype == torch.bool:
95
- attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
96
- else:
97
- attn_bias += attn_mask
98
-
99
- attn = (q @ k.transpose(-2, -1)) * scale_factor
100
- attn += attn_bias
101
- attn = attn.softmax(dim=-1)
102
- attn = torch.dropout(attn, p=drop_rate, train=True)
103
- x = attn @ v
104
- else:
105
- raise NotImplementedError(f"Unsupported attention mode: {mode}")
106
-
107
- x = post_attn_layout(x)
108
- b, s, a, d = x.shape
109
- out = x.reshape(b, s, -1)
110
- return out
111
-
112
-
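A hedged usage sketch for the `attention` helper above in its "torch" branch: as written, that branch forwards `q`/`k`/`v` straight to `F.scaled_dot_product_attention`, so the tensors are assumed to already be in head-first layout `[b, heads, s, head_dim]` (the "flash" branch instead expects `[b, s, heads, head_dim]` and the flash-attn package):

```python
import torch

b, heads, s, head_dim = 2, 4, 16, 32
q = torch.randn(b, heads, s, head_dim)
k = torch.randn(b, heads, s, head_dim)
v = torch.randn(b, heads, s, head_dim)

out = attention(q, k, v, mode="torch")   # assumes the function defined above is in scope
print(out.shape)                         # torch.Size([2, 16, 128]) = [b, s, heads * head_dim]
```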
113
- class CausalConv1d(nn.Module):
114
-
115
- def __init__(self, chan_in, chan_out, kernel_size=3, stride=1, dilation=1, pad_mode="replicate", **kwargs):
116
- super().__init__()
117
-
118
- self.pad_mode = pad_mode
119
- padding = (kernel_size - 1, 0) # T
120
- self.time_causal_padding = padding
121
-
122
- self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
123
-
124
- def forward(self, x):
125
- x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
126
- return self.conv(x)
127
-
128
-
129
-
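The `CausalConv1d` above pads only on the left of the time axis, so no output step depends on future inputs; a minimal illustration of the padding (replicate mode, kernel size 3):

```python
import torch
import torch.nn.functional as F

x = torch.arange(5.0).view(1, 1, 5)            # [B, C, T]
padded = F.pad(x, (2, 0), mode="replicate")    # (kernel_size - 1, 0): pad the past only
print(padded)                                  # tensor([[[0., 0., 0., 1., 2., 3., 4.]]])
```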
130
- class FaceEncoder(nn.Module):
131
- def __init__(self, in_dim: int, hidden_dim: int, num_heads: int, dtype=None, device=None):
132
- factory_kwargs = {"dtype": dtype, "device": device}
133
- super().__init__()
134
-
135
- self.num_heads = num_heads
136
- self.conv1_local = CausalConv1d(in_dim, 1024 * num_heads, 3, stride=1)
137
- self.norm1 = nn.LayerNorm(hidden_dim // 8, elementwise_affine=False, eps=1e-6, **factory_kwargs)
138
- self.act = nn.SiLU()
139
- self.conv2 = CausalConv1d(1024, 1024, 3, stride=2)
140
- self.conv3 = CausalConv1d(1024, 1024, 3, stride=2)
141
-
142
- self.out_proj = nn.Linear(1024, hidden_dim)
143
- self.norm1 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
144
-
145
- self.norm2 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
146
-
147
- self.norm3 = nn.LayerNorm(1024, elementwise_affine=False, eps=1e-6, **factory_kwargs)
148
-
149
- self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
150
-
151
- def forward(self, x):
152
-
153
- x = rearrange(x, "b t c -> b c t")
154
- b, c, t = x.shape
155
-
156
- x = self.conv1_local(x)
157
- x = rearrange(x, "b (n c) t -> (b n) t c", n=self.num_heads)
158
-
159
- x = self.norm1(x)
160
- x = self.act(x)
161
- x = rearrange(x, "b t c -> b c t")
162
- x = self.conv2(x)
163
- x = rearrange(x, "b c t -> b t c")
164
- x = self.norm2(x)
165
- x = self.act(x)
166
- x = rearrange(x, "b t c -> b c t")
167
- x = self.conv3(x)
168
- x = rearrange(x, "b c t -> b t c")
169
- x = self.norm3(x)
170
- x = self.act(x)
171
- x = self.out_proj(x)
172
- x = rearrange(x, "(b n) t c -> b t n c", b=b)
173
- padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
174
- x = torch.cat([x, padding], dim=-2)
175
- x_local = x.clone()
176
-
177
- return x_local
178
-
179
-
180
-
181
- class RMSNorm(nn.Module):
182
- def __init__(
183
- self,
184
- dim: int,
185
- elementwise_affine=True,
186
- eps: float = 1e-6,
187
- device=None,
188
- dtype=None,
189
- ):
190
- """
191
- Initialize the RMSNorm normalization layer.
192
-
193
- Args:
194
- dim (int): The dimension of the input tensor.
195
- eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
196
-
197
- Attributes:
198
- eps (float): A small value added to the denominator for numerical stability.
199
- weight (nn.Parameter): Learnable scaling parameter.
200
-
201
- """
202
- factory_kwargs = {"device": device, "dtype": dtype}
203
- super().__init__()
204
- self.eps = eps
205
- if elementwise_affine:
206
- self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))
207
-
208
- def _norm(self, x):
209
- """
210
- Apply the RMSNorm normalization to the input tensor.
211
-
212
- Args:
213
- x (torch.Tensor): The input tensor.
214
-
215
- Returns:
216
- torch.Tensor: The normalized tensor.
217
-
218
- """
219
- return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
220
-
221
- def forward(self, x):
222
- """
223
- Forward pass through the RMSNorm layer.
224
-
225
- Args:
226
- x (torch.Tensor): The input tensor.
227
-
228
- Returns:
229
- torch.Tensor: The output tensor after applying RMSNorm.
230
-
231
- """
232
- output = self._norm(x.float()).type_as(x)
233
- if hasattr(self, "weight"):
234
- output = output * self.weight
235
- return output
236
-
237
-
238
- def get_norm_layer(norm_layer):
239
- """
240
- Get the normalization layer.
241
-
242
- Args:
243
- norm_layer (str): The type of normalization layer.
244
-
245
- Returns:
246
- norm_layer (nn.Module): The normalization layer.
247
- """
248
- if norm_layer == "layer":
249
- return nn.LayerNorm
250
- elif norm_layer == "rms":
251
- return RMSNorm
252
- else:
253
- raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")
254
-
255
-
256
- class FaceAdapter(nn.Module):
257
- def __init__(
258
- self,
259
- hidden_dim: int,
260
- heads_num: int,
261
- qk_norm: bool = True,
262
- qk_norm_type: str = "rms",
263
- num_adapter_layers: int = 1,
264
- dtype=None,
265
- device=None,
266
- ):
267
-
268
- factory_kwargs = {"dtype": dtype, "device": device}
269
- super().__init__()
270
- self.hidden_size = hidden_dim
271
- self.heads_num = heads_num
272
- self.fuser_blocks = nn.ModuleList(
273
- [
274
- FaceBlock(
275
- self.hidden_size,
276
- self.heads_num,
277
- qk_norm=qk_norm,
278
- qk_norm_type=qk_norm_type,
279
- **factory_kwargs,
280
- )
281
- for _ in range(num_adapter_layers)
282
- ]
283
- )
284
-
285
- def forward(
286
- self,
287
- x: torch.Tensor,
288
- motion_embed: torch.Tensor,
289
- idx: int,
290
- freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
291
- freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
292
- ) -> torch.Tensor:
293
-
294
- return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)
295
-
296
-
297
-
298
- class FaceBlock(nn.Module):
299
- def __init__(
300
- self,
301
- hidden_size: int,
302
- heads_num: int,
303
- qk_norm: bool = True,
304
- qk_norm_type: str = "rms",
305
- qk_scale: float = None,
306
- dtype: Optional[torch.dtype] = None,
307
- device: Optional[torch.device] = None,
308
- ):
309
- factory_kwargs = {"device": device, "dtype": dtype}
310
- super().__init__()
311
-
312
- self.deterministic = False
313
- self.hidden_size = hidden_size
314
- self.heads_num = heads_num
315
- head_dim = hidden_size // heads_num
316
- self.scale = qk_scale or head_dim**-0.5
317
-
318
- self.linear1_kv = nn.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
319
- self.linear1_q = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
320
-
321
- self.linear2 = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
322
-
323
- qk_norm_layer = get_norm_layer(qk_norm_type)
324
- self.q_norm = (
325
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
326
- )
327
- self.k_norm = (
328
- qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) if qk_norm else nn.Identity()
329
- )
330
-
331
- self.pre_norm_feat = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
332
-
333
- self.pre_norm_motion = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)
334
-
335
- def forward(
336
- self,
337
- x: torch.Tensor,
338
- motion_vec: torch.Tensor,
339
- motion_mask: Optional[torch.Tensor] = None,
340
- use_context_parallel=False,
341
- all_gather=None,
342
- sp_world_size=1,
343
- sp_world_rank=0,
344
- ) -> torch.Tensor:
345
- dtype = x.dtype
346
- B, T, N, C = motion_vec.shape
347
- T_comp = T
348
-
349
- x_motion = self.pre_norm_motion(motion_vec)
350
- x_feat = self.pre_norm_feat(x)
351
-
352
- kv = self.linear1_kv(x_motion)
353
- q = self.linear1_q(x_feat)
354
-
355
- k, v = rearrange(kv, "B L N (K H D) -> K B L N H D", K=2, H=self.heads_num)
356
- q = rearrange(q, "B S (H D) -> B S H D", H=self.heads_num)
357
-
358
- # Apply QK-Norm if needed.
359
- q = self.q_norm(q).to(v)
360
- k = self.k_norm(k).to(v)
361
-
362
- k = rearrange(k, "B L N H D -> (B L) N H D")
363
- v = rearrange(v, "B L N H D -> (B L) N H D")
364
-
365
- if use_context_parallel:
366
- q = all_gather(q, dim=1)
367
-
368
- length = int(np.floor(q.size()[1] / T_comp) * T_comp)
369
- origin_length = q.size()[1]
370
- if origin_length > length:
371
- q_pad = q[:, length:]
372
- q = q[:, :length]
373
-
374
- q = rearrange(q, "B (L S) H D -> (B L) S H D", L=T_comp)
375
- q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
376
- # Compute attention.
377
- attn = attention(
378
- q,
379
- k,
380
- v,
381
- max_seqlen_q=q.shape[1],
382
- batch_size=q.shape[0],
383
- )
384
-
385
- attn = rearrange(attn, "(B L) S C -> B (L S) C", L=T_comp)
386
- if use_context_parallel:
387
- q_pad = rearrange(q_pad, "B L H D -> B L (H D)")
388
- if origin_length > length:
389
- attn = torch.cat([attn, q_pad], dim=1)
390
- attn = torch.chunk(attn, sp_world_size, dim=1)[sp_world_rank]
391
-
392
- output = self.linear2(attn)
393
-
394
- if motion_mask is not None:
395
- output = output * rearrange(motion_mask, "B T H W -> B (T H W)").unsqueeze(-1)
396
-
397
- return output
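A hedged sketch of driving the `FaceEncoder` above (assuming the deleted module is importable from an earlier checkout; sizes are illustrative). A `[B, T, in_dim]` motion-feature sequence becomes per-frame face tokens of shape `[B, T', num_heads + 1, hidden_dim]`, where the two stride-2 causal convolutions reduce `T` to `T'`:

```python
import torch
from videox_fun.models.wan_animate_adapter import FaceEncoder  # assumed import path

enc = FaceEncoder(in_dim=512, hidden_dim=5120, num_heads=4)
motion = torch.randn(1, 81, 512)      # [B, T, in_dim]
with torch.no_grad():
    tokens = enc(motion)
print(tokens.shape)                   # torch.Size([1, 21, 5, 5120]): T'=21, num_heads+1=5
```
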
videox_fun/models/wan_animate_motion_encoder.py DELETED
@@ -1,309 +0,0 @@
1
- # Modified from ``https://github.com/wyhsirius/LIA``
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- from torch.nn import functional as F
8
-
9
-
10
- def custom_qr(input_tensor):
11
- original_dtype = input_tensor.dtype
12
- if original_dtype == torch.bfloat16:
13
- q, r = torch.linalg.qr(input_tensor.to(torch.float32))
14
- return q.to(original_dtype), r.to(original_dtype)
15
- return torch.linalg.qr(input_tensor)
16
-
17
- def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
18
- return F.leaky_relu(input + bias, negative_slope) * scale
19
-
20
-
21
- def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
22
- _, minor, in_h, in_w = input.shape
23
- kernel_h, kernel_w = kernel.shape
24
-
25
- out = input.view(-1, minor, in_h, 1, in_w, 1)
26
- out = F.pad(out, [0, up_x - 1, 0, 0, 0, up_y - 1, 0, 0])
27
- out = out.view(-1, minor, in_h * up_y, in_w * up_x)
28
-
29
- out = F.pad(out, [max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
30
- out = out[:, :, max(-pad_y0, 0): out.shape[2] - max(-pad_y1, 0),
31
- max(-pad_x0, 0): out.shape[3] - max(-pad_x1, 0), ]
32
-
33
- out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
34
- w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
35
- out = F.conv2d(out, w)
36
- out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
37
- in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, )
38
- return out[:, :, ::down_y, ::down_x]
39
-
40
-
41
- def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
42
- return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
43
-
44
-
45
- def make_kernel(k):
46
- k = torch.tensor(k, dtype=torch.float32)
47
- if k.ndim == 1:
48
- k = k[None, :] * k[:, None]
49
- k /= k.sum()
50
- return k
51
-
52
-
53
- class FusedLeakyReLU(nn.Module):
54
- def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
55
- super().__init__()
56
- self.bias = nn.Parameter(torch.zeros(1, channel, 1, 1))
57
- self.negative_slope = negative_slope
58
- self.scale = scale
59
-
60
- def forward(self, input):
61
- out = fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
62
- return out
63
-
64
-
65
- class Blur(nn.Module):
66
- def __init__(self, kernel, pad, upsample_factor=1):
67
- super().__init__()
68
-
69
- kernel = make_kernel(kernel)
70
-
71
- if upsample_factor > 1:
72
- kernel = kernel * (upsample_factor ** 2)
73
-
74
- self.register_buffer('kernel', kernel)
75
-
76
- self.pad = pad
77
-
78
- def forward(self, input):
79
- return upfirdn2d(input, self.kernel, pad=self.pad)
80
-
81
-
82
- class ScaledLeakyReLU(nn.Module):
83
- def __init__(self, negative_slope=0.2):
84
- super().__init__()
85
-
86
- self.negative_slope = negative_slope
87
-
88
- def forward(self, input):
89
- return F.leaky_relu(input, negative_slope=self.negative_slope)
90
-
91
-
92
- class EqualConv2d(nn.Module):
93
- def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True):
94
- super().__init__()
95
-
96
- self.weight = nn.Parameter(torch.randn(out_channel, in_channel, kernel_size, kernel_size))
97
- self.scale = 1 / math.sqrt(in_channel * kernel_size ** 2)
98
-
99
- self.stride = stride
100
- self.padding = padding
101
-
102
- if bias:
103
- self.bias = nn.Parameter(torch.zeros(out_channel))
104
- else:
105
- self.bias = None
106
-
107
- def forward(self, input):
108
-
109
- return F.conv2d(input, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
110
-
111
- def __repr__(self):
112
- return (
113
- f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
114
- f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
115
- )
116
-
117
-
118
- class EqualLinear(nn.Module):
119
- def __init__(self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None):
120
- super().__init__()
121
-
122
- self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
123
-
124
- if bias:
125
- self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
126
- else:
127
- self.bias = None
128
-
129
- self.activation = activation
130
-
131
- self.scale = (1 / math.sqrt(in_dim)) * lr_mul
132
- self.lr_mul = lr_mul
133
-
134
- def forward(self, input):
135
-
136
- if self.activation:
137
- out = F.linear(input, self.weight * self.scale)
138
- out = fused_leaky_relu(out, self.bias * self.lr_mul)
139
- else:
140
- out = F.linear(input, self.weight * self.scale, bias=self.bias * self.lr_mul)
141
-
142
- return out
143
-
144
- def __repr__(self):
145
- return (f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})')
146
-
147
-
148
- class ConvLayer(nn.Sequential):
149
- def __init__(
150
- self,
151
- in_channel,
152
- out_channel,
153
- kernel_size,
154
- downsample=False,
155
- blur_kernel=[1, 3, 3, 1],
156
- bias=True,
157
- activate=True,
158
- ):
159
- layers = []
160
-
161
- if downsample:
162
- factor = 2
163
- p = (len(blur_kernel) - factor) + (kernel_size - 1)
164
- pad0 = (p + 1) // 2
165
- pad1 = p // 2
166
-
167
- layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
168
-
169
- stride = 2
170
- self.padding = 0
171
-
172
- else:
173
- stride = 1
174
- self.padding = kernel_size // 2
175
-
176
- layers.append(EqualConv2d(in_channel, out_channel, kernel_size, padding=self.padding, stride=stride,
177
- bias=bias and not activate))
178
-
179
- if activate:
180
- if bias:
181
- layers.append(FusedLeakyReLU(out_channel))
182
- else:
183
- layers.append(ScaledLeakyReLU(0.2))
184
-
185
- super().__init__(*layers)
186
-
187
-
188
- class ResBlock(nn.Module):
189
- def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
190
- super().__init__()
191
-
192
- self.conv1 = ConvLayer(in_channel, in_channel, 3)
193
- self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
194
-
195
- self.skip = ConvLayer(in_channel, out_channel, 1, downsample=True, activate=False, bias=False)
196
-
197
- def forward(self, input):
198
- out = self.conv1(input)
199
- out = self.conv2(out)
200
-
201
- skip = self.skip(input)
202
- out = (out + skip) / math.sqrt(2)
203
-
204
- return out
205
-
206
-
207
- class EncoderApp(nn.Module):
208
- def __init__(self, size, w_dim=512):
209
- super(EncoderApp, self).__init__()
210
-
211
- channels = {
212
- 4: 512,
213
- 8: 512,
214
- 16: 512,
215
- 32: 512,
216
- 64: 256,
217
- 128: 128,
218
- 256: 64,
219
- 512: 32,
220
- 1024: 16
221
- }
222
-
223
- self.w_dim = w_dim
224
- log_size = int(math.log(size, 2))
225
-
226
- self.convs = nn.ModuleList()
227
- self.convs.append(ConvLayer(3, channels[size], 1))
228
-
229
- in_channel = channels[size]
230
- for i in range(log_size, 2, -1):
231
- out_channel = channels[2 ** (i - 1)]
232
- self.convs.append(ResBlock(in_channel, out_channel))
233
- in_channel = out_channel
234
-
235
- self.convs.append(EqualConv2d(in_channel, self.w_dim, 4, padding=0, bias=False))
236
-
237
- def forward(self, x):
238
-
239
- res = []
240
- h = x
241
- for conv in self.convs:
242
- h = conv(h)
243
- res.append(h)
244
-
245
- return res[-1].squeeze(-1).squeeze(-1), res[::-1][2:]
246
-
247
-
248
- class Encoder(nn.Module):
249
- def __init__(self, size, dim=512, dim_motion=20):
250
- super(Encoder, self).__init__()
251
-
252
- # appearance network
253
- self.net_app = EncoderApp(size, dim)
254
-
255
- # motion network
256
- fc = [EqualLinear(dim, dim)]
257
- for i in range(3):
258
- fc.append(EqualLinear(dim, dim))
259
-
260
- fc.append(EqualLinear(dim, dim_motion))
261
- self.fc = nn.Sequential(*fc)
262
-
263
- def enc_app(self, x):
264
- h_source = self.net_app(x)
265
- return h_source
266
-
267
- def enc_motion(self, x):
268
- h, _ = self.net_app(x)
269
- h_motion = self.fc(h)
270
- return h_motion
271
-
272
-
273
- class Direction(nn.Module):
274
- def __init__(self, motion_dim):
275
- super(Direction, self).__init__()
276
- self.weight = nn.Parameter(torch.randn(512, motion_dim))
277
-
278
- def forward(self, input):
279
-
280
- weight = self.weight + 1e-8
281
- Q, R = custom_qr(weight)
282
- if input is None:
283
- return Q
284
- else:
285
- input_diag = torch.diag_embed(input) # alpha, diagonal matrix
286
- out = torch.matmul(input_diag, Q.T)
287
- out = torch.sum(out, dim=1)
288
- return out
289
-
290
-
291
- class Synthesis(nn.Module):
292
- def __init__(self, motion_dim):
293
- super(Synthesis, self).__init__()
294
- self.direction = Direction(motion_dim)
295
-
296
-
297
- class Generator(nn.Module):
298
- def __init__(self, size, style_dim=512, motion_dim=20):
299
- super().__init__()
300
-
301
- self.enc = Encoder(size, style_dim, motion_dim)
302
- self.dec = Synthesis(motion_dim)
303
-
304
- def get_motion(self, img):
305
- #motion_feat = self.enc.enc_motion(img)
306
- motion_feat = torch.utils.checkpoint.checkpoint((self.enc.enc_motion), img, use_reentrant=True)
307
- with torch.cuda.amp.autocast(dtype=torch.float32):
308
- motion = self.dec.direction(motion_feat)
309
- return motion
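A hedged sketch of the LIA-style motion encoder above (assuming the deleted module is importable from an earlier checkout). A square face crop whose side matches `size` is mapped to a 512-dimensional motion code through the orthogonalised direction basis:

```python
import torch
from videox_fun.models.wan_animate_motion_encoder import Generator  # assumed import path

motion_encoder = Generator(size=256, style_dim=512, motion_dim=20)
face = torch.randn(1, 3, 256, 256)    # [B, 3, size, size]
with torch.no_grad():
    motion = motion_encoder.get_motion(face)
print(motion.shape)                   # torch.Size([1, 512])
```
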
videox_fun/models/wan_audio_encoder.py DELETED
@@ -1,213 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/audio_encoder.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import math
4
-
5
- import librosa
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
10
- from diffusers.configuration_utils import ConfigMixin
11
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
12
- from diffusers.models.modeling_utils import ModelMixin
13
-
14
-
15
- def get_sample_indices(original_fps,
16
- total_frames,
17
- target_fps,
18
- num_sample,
19
- fixed_start=None):
20
- required_duration = num_sample / target_fps
21
- required_origin_frames = int(np.ceil(required_duration * original_fps))
22
- if required_duration > total_frames / original_fps:
23
- raise ValueError("required_duration must be less than video length")
24
-
25
- if fixed_start is not None and fixed_start >= 0:
26
- start_frame = fixed_start
27
- else:
28
- max_start = total_frames - required_origin_frames
29
- if max_start < 0:
30
- raise ValueError("video length is too short")
31
- start_frame = np.random.randint(0, max_start + 1)
32
- start_time = start_frame / original_fps
33
-
34
- end_time = start_time + required_duration
35
- time_points = np.linspace(start_time, end_time, num_sample, endpoint=False)
36
-
37
- frame_indices = np.round(np.array(time_points) * original_fps).astype(int)
38
- frame_indices = np.clip(frame_indices, 0, total_frames - 1)
39
- return frame_indices
40
-
41
-
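A short sketch of what `get_sample_indices` above returns (the helper and numpy are assumed to be in scope): sampling 32 frames at 16 fps from a 30 fps clip starting at frame 0 covers two seconds of video:

```python
idx = get_sample_indices(original_fps=30, total_frames=90, target_fps=16,
                         num_sample=32, fixed_start=0)
print(idx[:8])   # [ 0  2  4  6  8  9 11 13] -- indices into the 30 fps frame sequence
```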
42
- def linear_interpolation(features, input_fps, output_fps, output_len=None):
43
- """
44
- features: shape=[1, T, 512]
45
- input_fps: fps for audio, f_a
46
- output_fps: fps for video, f_m
47
- output_len: video length
48
- """
49
- features = features.transpose(1, 2) # [1, 512, T]
50
- seq_len = features.shape[2] / float(input_fps) # T/f_a
51
- if output_len is None:
52
- output_len = int(seq_len * output_fps) # f_m*T/f_a
53
- output_features = F.interpolate(
54
- features, size=output_len, align_corners=True,
55
- mode='linear') # [1, 512, output_len]
56
- return output_features.transpose(1, 2) # [1, output_len, 512]
57
-
58
-
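`linear_interpolation` above resamples wav2vec2 features (50 feature-frames per second) onto the video frame rate; a small sketch (the helper is assumed to be in scope, and 30 fps is just an example rate):

```python
import torch

feats = torch.randn(1, 100, 512)                           # 2 s of audio features at 50 fps
aligned = linear_interpolation(feats, input_fps=50, output_fps=30)
print(aligned.shape)                                       # torch.Size([1, 60, 512])
```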
59
- class WanAudioEncoder(ModelMixin, ConfigMixin, FromOriginalModelMixin):
60
-
61
- def __init__(self, pretrained_model_path="facebook/wav2vec2-base-960h", device='cpu'):
62
- super(WanAudioEncoder, self).__init__()
63
- # load pretrained model
64
- self.processor = Wav2Vec2Processor.from_pretrained(pretrained_model_path)
65
- self.model = Wav2Vec2ForCTC.from_pretrained(pretrained_model_path)
66
-
67
- self.model = self.model.to(device)
68
-
69
- self.video_rate = 30
70
-
71
- def extract_audio_feat(self,
72
- audio_path,
73
- return_all_layers=False,
74
- dtype=torch.float32):
75
- audio_input, sample_rate = librosa.load(audio_path, sr=16000)
76
-
77
- input_values = self.processor(
78
- audio_input, sampling_rate=sample_rate, return_tensors="pt"
79
- ).input_values
80
-
81
- # INFERENCE
82
-
83
- # retrieve logits & take argmax
84
- res = self.model(
85
- input_values.to(self.model.device), output_hidden_states=True)
86
- if return_all_layers:
87
- feat = torch.cat(res.hidden_states)
88
- else:
89
- feat = res.hidden_states[-1]
90
- feat = linear_interpolation(
91
- feat, input_fps=50, output_fps=self.video_rate)
92
-
93
- z = feat.to(dtype) # Encoding for the motion
94
- return z
95
-
96
- def extract_audio_feat_without_file_load(self, audio_input, sample_rate, return_all_layers=False, dtype=torch.float32):
97
- input_values = self.processor(
98
- audio_input, sampling_rate=sample_rate, return_tensors="pt"
99
- ).input_values
100
-
101
- # INFERENCE
102
- # retrieve the wav2vec2 hidden states (no CTC argmax is taken here)
103
- res = self.model(
104
- input_values.to(self.model.device), output_hidden_states=True)
105
- if return_all_layers:
106
- feat = torch.cat(res.hidden_states)
107
- else:
108
- feat = res.hidden_states[-1]
109
- feat = linear_interpolation(
110
- feat, input_fps=50, output_fps=self.video_rate)
111
-
112
- z = feat.to(dtype) # Encoding for the motion
113
- return z
114
-
115
- def get_audio_embed_bucket(self,
116
- audio_embed,
117
- stride=2,
118
- batch_frames=12,
119
- m=2):
120
- num_layers, audio_frame_num, audio_dim = audio_embed.shape
121
-
122
- if num_layers > 1:
123
- return_all_layers = True
124
- else:
125
- return_all_layers = False
126
-
127
- min_batch_num = int(audio_frame_num / (batch_frames * stride)) + 1
128
-
129
- bucket_num = min_batch_num * batch_frames
130
- batch_idx = [stride * i for i in range(bucket_num)]
131
- batch_audio_eb = []
132
- for bi in batch_idx:
133
- if bi < audio_frame_num:
134
- audio_sample_stride = 2
135
- chosen_idx = list(
136
- range(bi - m * audio_sample_stride,
137
- bi + (m + 1) * audio_sample_stride,
138
- audio_sample_stride))
139
- chosen_idx = [0 if c < 0 else c for c in chosen_idx]
140
- chosen_idx = [
141
- audio_frame_num - 1 if c >= audio_frame_num else c
142
- for c in chosen_idx
143
- ]
144
-
145
- if return_all_layers:
146
- frame_audio_embed = audio_embed[:, chosen_idx].flatten(
147
- start_dim=-2, end_dim=-1)
148
- else:
149
- frame_audio_embed = audio_embed[0][chosen_idx].flatten()
150
- else:
151
- frame_audio_embed = \
152
- torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
153
- else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
154
- batch_audio_eb.append(frame_audio_embed)
155
- batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
156
- dim=0)
157
-
158
- return batch_audio_eb, min_batch_num
159
-
160
- def get_audio_embed_bucket_fps(self,
161
- audio_embed,
162
- fps=16,
163
- batch_frames=81,
164
- m=0):
165
- num_layers, audio_frame_num, audio_dim = audio_embed.shape
166
-
167
- if num_layers > 1:
168
- return_all_layers = True
169
- else:
170
- return_all_layers = False
171
-
172
- scale = self.video_rate / fps
173
-
174
- min_batch_num = int(audio_frame_num / (batch_frames * scale)) + 1
175
-
176
- bucket_num = min_batch_num * batch_frames
177
- padd_audio_num = math.ceil(min_batch_num * batch_frames / fps *
178
- self.video_rate) - audio_frame_num
179
- batch_idx = get_sample_indices(
180
- original_fps=self.video_rate,
181
- total_frames=audio_frame_num + padd_audio_num,
182
- target_fps=fps,
183
- num_sample=bucket_num,
184
- fixed_start=0)
185
- batch_audio_eb = []
186
- audio_sample_stride = int(self.video_rate / fps)
187
- for bi in batch_idx:
188
- if bi < audio_frame_num:
189
-
190
- chosen_idx = list(
191
- range(bi - m * audio_sample_stride,
192
- bi + (m + 1) * audio_sample_stride,
193
- audio_sample_stride))
194
- chosen_idx = [0 if c < 0 else c for c in chosen_idx]
195
- chosen_idx = [
196
- audio_frame_num - 1 if c >= audio_frame_num else c
197
- for c in chosen_idx
198
- ]
199
-
200
- if return_all_layers:
201
- frame_audio_embed = audio_embed[:, chosen_idx].flatten(
202
- start_dim=-2, end_dim=-1)
203
- else:
204
- frame_audio_embed = audio_embed[0][chosen_idx].flatten()
205
- else:
206
- frame_audio_embed = \
207
- torch.zeros([audio_dim * (2 * m + 1)], device=audio_embed.device) if not return_all_layers \
208
- else torch.zeros([num_layers, audio_dim * (2 * m + 1)], device=audio_embed.device)
209
- batch_audio_eb.append(frame_audio_embed)
210
- batch_audio_eb = torch.cat([c.unsqueeze(0) for c in batch_audio_eb],
211
- dim=0)
212
-
213
- return batch_audio_eb, min_batch_num
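Note: the encoder above resamples wav2vec2 hidden states (roughly 50 feature frames per second) to its 30 fps internal video rate before bucketing. A minimal standalone sketch of that resampling step, with assumed sizes (the 2 s clip and 512-dim features are illustrative, not taken from the file):

import torch
import torch.nn.functional as F

wav2vec_feats = torch.randn(1, 100, 512)   # ~2 s of audio at ~50 feature frames/s (assumed)
input_fps, output_fps = 50, 30             # wav2vec2 rate -> self.video_rate

x = wav2vec_feats.transpose(1, 2)          # [1, 512, T]
output_len = int(x.shape[2] / input_fps * output_fps)
x = F.interpolate(x, size=output_len, mode='linear', align_corners=True)
print(x.transpose(1, 2).shape)             # torch.Size([1, 60, 512])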
videox_fun/models/wan_audio_injector.py DELETED
@@ -1,1093 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/motioner.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import importlib.metadata
4
- import math
5
- from typing import Any, Dict, List, Literal, Optional, Tuple, Union
6
-
7
- import numpy as np
8
- import torch
9
- import torch.cuda.amp as amp
10
- import torch.nn as nn
11
- import torch.nn.functional as F
12
- from diffusers.configuration_utils import ConfigMixin, register_to_config
13
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
14
- from diffusers.models import ModelMixin
15
- from diffusers.models.attention import AdaLayerNorm
16
- from diffusers.utils import BaseOutput, is_torch_version, logging
17
- from einops import rearrange, repeat
18
-
19
- from .attention_utils import attention
20
- from .wan_transformer3d import WanAttentionBlock, WanCrossAttention
21
-
22
-
23
- def rope_precompute(x, grid_sizes, freqs, start=None):
24
- b, s, n, c = x.size(0), x.size(1), x.size(2), x.size(3) // 2
25
-
26
- # split freqs
27
- if type(freqs) is list:
28
- trainable_freqs = freqs[1]
29
- freqs = freqs[0]
30
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
31
-
32
- # loop over samples
33
- output = torch.view_as_complex(x.detach().reshape(b, s, n, -1,
34
- 2).to(torch.float64))
35
- seq_bucket = [0]
36
- if not type(grid_sizes) is list:
37
- grid_sizes = [grid_sizes]
38
- for g in grid_sizes:
39
- if not type(g) is list:
40
- g = [torch.zeros_like(g), g]
41
- batch_size = g[0].shape[0]
42
- for i in range(batch_size):
43
- if start is None:
44
- f_o, h_o, w_o = g[0][i]
45
- else:
46
- f_o, h_o, w_o = start[i]
47
-
48
- f, h, w = g[1][i]
49
- t_f, t_h, t_w = g[2][i]
50
- seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
51
- seq_len = int(seq_f * seq_h * seq_w)
52
- if seq_len > 0:
53
- if t_f > 0:
54
- factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
55
- t_h / seq_h).item(), (t_w / seq_w).item()
56
- # Generate a list of seq_f integers starting from f_o and ending at math.ceil(factor_f * seq_f.item() + f_o.item())
57
- if f_o >= 0:
58
- f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
59
- seq_f).astype(int).tolist()
60
- else:
61
- f_sam = np.linspace(-f_o.item(),
62
- (-t_f - f_o).item() + 1,
63
- seq_f).astype(int).tolist()
64
- h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
65
- seq_h).astype(int).tolist()
66
- w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
67
- seq_w).astype(int).tolist()
68
-
69
- assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
70
- freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
71
- f_sam].conj()
72
- freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
73
-
74
- freqs_i = torch.cat([
75
- freqs_0.expand(seq_f, seq_h, seq_w, -1),
76
- freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
77
- seq_f, seq_h, seq_w, -1),
78
- freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
79
- seq_f, seq_h, seq_w, -1),
80
- ],
81
- dim=-1).reshape(seq_len, 1, -1)
82
- elif t_f < 0:
83
- freqs_i = trainable_freqs.unsqueeze(1)
84
- # apply rotary embedding
85
- output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = freqs_i
86
- seq_bucket.append(seq_bucket[-1] + seq_len)
87
- return output
88
-
89
-
90
- def sinusoidal_embedding_1d(dim, position):
91
- # preprocess
92
- assert dim % 2 == 0
93
- half = dim // 2
94
- position = position.type(torch.float64)
95
-
96
- # calculation
97
- sinusoid = torch.outer(
98
- position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
99
- x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
100
- return x
101
-
102
-
103
- @amp.autocast(enabled=False)
104
- def rope_params(max_seq_len, dim, theta=10000):
105
- assert dim % 2 == 0
106
- freqs = torch.outer(
107
- torch.arange(max_seq_len),
108
- 1.0 / torch.pow(theta,
109
- torch.arange(0, dim, 2).to(torch.float64).div(dim)))
110
- freqs = torch.polar(torch.ones_like(freqs), freqs)
111
- return freqs
112
-
113
-
114
- @amp.autocast(enabled=False)
115
- def rope_apply(x, grid_sizes, freqs, start=None):
116
- n, c = x.size(2), x.size(3) // 2
117
-
118
- # split freqs
119
- if type(freqs) is list:
120
- trainable_freqs = freqs[1]
121
- freqs = freqs[0]
122
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
123
-
124
- # loop over samples
125
- output = []
126
- output = x.clone()
127
- seq_bucket = [0]
128
- if not type(grid_sizes) is list:
129
- grid_sizes = [grid_sizes]
130
- for g in grid_sizes:
131
- if not type(g) is list:
132
- g = [torch.zeros_like(g), g]
133
- batch_size = g[0].shape[0]
134
- for i in range(batch_size):
135
- if start is None:
136
- f_o, h_o, w_o = g[0][i]
137
- else:
138
- f_o, h_o, w_o = start[i]
139
-
140
- f, h, w = g[1][i]
141
- t_f, t_h, t_w = g[2][i]
142
- seq_f, seq_h, seq_w = f - f_o, h - h_o, w - w_o
143
- seq_len = int(seq_f * seq_h * seq_w)
144
- if seq_len > 0:
145
- if t_f > 0:
146
- factor_f, factor_h, factor_w = (t_f / seq_f).item(), (
147
- t_h / seq_h).item(), (t_w / seq_w).item()
148
-
149
- if f_o >= 0:
150
- f_sam = np.linspace(f_o.item(), (t_f + f_o).item() - 1,
151
- seq_f).astype(int).tolist()
152
- else:
153
- f_sam = np.linspace(-f_o.item(),
154
- (-t_f - f_o).item() + 1,
155
- seq_f).astype(int).tolist()
156
- h_sam = np.linspace(h_o.item(), (t_h + h_o).item() - 1,
157
- seq_h).astype(int).tolist()
158
- w_sam = np.linspace(w_o.item(), (t_w + w_o).item() - 1,
159
- seq_w).astype(int).tolist()
160
-
161
- assert f_o * f >= 0 and h_o * h >= 0 and w_o * w >= 0
162
- freqs_0 = freqs[0][f_sam] if f_o >= 0 else freqs[0][
163
- f_sam].conj()
164
- freqs_0 = freqs_0.view(seq_f, 1, 1, -1)
165
-
166
- freqs_i = torch.cat([
167
- freqs_0.expand(seq_f, seq_h, seq_w, -1),
168
- freqs[1][h_sam].view(1, seq_h, 1, -1).expand(
169
- seq_f, seq_h, seq_w, -1),
170
- freqs[2][w_sam].view(1, 1, seq_w, -1).expand(
171
- seq_f, seq_h, seq_w, -1),
172
- ],
173
- dim=-1).reshape(seq_len, 1, -1)
174
- elif t_f < 0:
175
- freqs_i = trainable_freqs.unsqueeze(1)
176
- # apply rotary embedding
177
- # precompute multipliers
178
- x_i = torch.view_as_complex(
179
- x[i, seq_bucket[-1]:seq_bucket[-1] + seq_len].to(
180
- torch.float64).reshape(seq_len, n, -1, 2))
181
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
182
- output[i, seq_bucket[-1]:seq_bucket[-1] + seq_len] = x_i
183
- seq_bucket.append(seq_bucket[-1] + seq_len)
184
- return output.float()
185
-
186
-
187
-
188
- class CausalConv1d(nn.Module):
189
-
190
- def __init__(self,
191
- chan_in,
192
- chan_out,
193
- kernel_size=3,
194
- stride=1,
195
- dilation=1,
196
- pad_mode='replicate',
197
- **kwargs):
198
- super().__init__()
199
-
200
- self.pad_mode = pad_mode
201
- padding = (kernel_size - 1, 0) # T
202
- self.time_causal_padding = padding
203
-
204
- self.conv = nn.Conv1d(
205
- chan_in,
206
- chan_out,
207
- kernel_size,
208
- stride=stride,
209
- dilation=dilation,
210
- **kwargs)
211
-
212
- def forward(self, x):
213
- x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
214
- return self.conv(x)
215
-
216
-
217
- class MotionEncoder_tc(nn.Module):
218
-
219
- def __init__(self,
220
- in_dim: int,
221
- hidden_dim: int,
222
- num_heads=int,
223
- need_global=True,
224
- dtype=None,
225
- device=None):
226
- factory_kwargs = {"dtype": dtype, "device": device}
227
- super().__init__()
228
-
229
- self.num_heads = num_heads
230
- self.need_global = need_global
231
- self.conv1_local = CausalConv1d(
232
- in_dim, hidden_dim // 4 * num_heads, 3, stride=1)
233
- if need_global:
234
- self.conv1_global = CausalConv1d(
235
- in_dim, hidden_dim // 4, 3, stride=1)
236
- self.norm1 = nn.LayerNorm(
237
- hidden_dim // 4,
238
- elementwise_affine=False,
239
- eps=1e-6,
240
- **factory_kwargs)
241
- self.act = nn.SiLU()
242
- self.conv2 = CausalConv1d(hidden_dim // 4, hidden_dim // 2, 3, stride=2)
243
- self.conv3 = CausalConv1d(hidden_dim // 2, hidden_dim, 3, stride=2)
244
-
245
- if need_global:
246
- self.final_linear = nn.Linear(hidden_dim, hidden_dim,
247
- **factory_kwargs)
248
-
249
- self.norm1 = nn.LayerNorm(
250
- hidden_dim // 4,
251
- elementwise_affine=False,
252
- eps=1e-6,
253
- **factory_kwargs)
254
-
255
- self.norm2 = nn.LayerNorm(
256
- hidden_dim // 2,
257
- elementwise_affine=False,
258
- eps=1e-6,
259
- **factory_kwargs)
260
-
261
- self.norm3 = nn.LayerNorm(
262
- hidden_dim, elementwise_affine=False, eps=1e-6, **factory_kwargs)
263
-
264
- self.padding_tokens = nn.Parameter(torch.zeros(1, 1, 1, hidden_dim))
265
-
266
- def forward(self, x):
267
- x = rearrange(x, 'b t c -> b c t')
268
- x_ori = x.clone()
269
- b, c, t = x.shape
270
- x = self.conv1_local(x)
271
- x = rearrange(x, 'b (n c) t -> (b n) t c', n=self.num_heads)
272
- x = self.norm1(x)
273
- x = self.act(x)
274
- x = rearrange(x, 'b t c -> b c t')
275
- x = self.conv2(x)
276
- x = rearrange(x, 'b c t -> b t c')
277
- x = self.norm2(x)
278
- x = self.act(x)
279
- x = rearrange(x, 'b t c -> b c t')
280
- x = self.conv3(x)
281
- x = rearrange(x, 'b c t -> b t c')
282
- x = self.norm3(x)
283
- x = self.act(x)
284
- x = rearrange(x, '(b n) t c -> b t n c', b=b)
285
- padding = self.padding_tokens.repeat(b, x.shape[1], 1, 1)
286
- x = torch.cat([x, padding], dim=-2)
287
- x_local = x.clone()
288
-
289
- if not self.need_global:
290
- return x_local
291
-
292
- x = self.conv1_global(x_ori)
293
- x = rearrange(x, 'b c t -> b t c')
294
- x = self.norm1(x)
295
- x = self.act(x)
296
- x = rearrange(x, 'b t c -> b c t')
297
- x = self.conv2(x)
298
- x = rearrange(x, 'b c t -> b t c')
299
- x = self.norm2(x)
300
- x = self.act(x)
301
- x = rearrange(x, 'b t c -> b c t')
302
- x = self.conv3(x)
303
- x = rearrange(x, 'b c t -> b t c')
304
- x = self.norm3(x)
305
- x = self.act(x)
306
- x = self.final_linear(x)
307
- x = rearrange(x, '(b n) t c -> b t n c', b=b)
308
-
309
- return x, x_local
310
-
311
-
312
- class CausalAudioEncoder(nn.Module):
313
-
314
- def __init__(self,
315
- dim=5120,
316
- num_layers=25,
317
- out_dim=2048,
318
- video_rate=8,
319
- num_token=4,
320
- need_global=False):
321
- super().__init__()
322
- self.encoder = MotionEncoder_tc(
323
- in_dim=dim,
324
- hidden_dim=out_dim,
325
- num_heads=num_token,
326
- need_global=need_global)
327
- weight = torch.ones((1, num_layers, 1, 1)) * 0.01
328
-
329
- self.weights = torch.nn.Parameter(weight)
330
- self.act = torch.nn.SiLU()
331
-
332
- def forward(self, features):
333
- with amp.autocast(dtype=torch.float32):
334
- # features B * num_layers * dim * video_length
335
- weights = self.act(self.weights)
336
- weights_sum = weights.sum(dim=1, keepdims=True)
337
- weighted_feat = ((features * weights) / weights_sum).sum(
338
- dim=1) # b dim f
339
- weighted_feat = weighted_feat.permute(0, 2, 1) # b f dim
340
- res = self.encoder(weighted_feat) # b f n dim
341
-
342
- return res # b f n dim
343
-
344
-
345
- class AudioCrossAttention(WanCrossAttention):
346
-
347
- def __init__(self, *args, **kwargs):
348
- super().__init__(*args, **kwargs)
349
-
350
- def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
351
- r"""
352
- Args:
353
- x(Tensor): Shape [B, L1, C]
354
- context(Tensor): Shape [B, L2, C]
355
- context_lens(Tensor): Shape [B]
356
- """
357
- b, n, d = x.size(0), self.num_heads, self.head_dim
358
- # compute query, key, value
359
- q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
360
- k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
361
- v = self.v(context.to(dtype)).view(b, -1, n, d)
362
- # compute attention
363
- x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens, attention_type="FLASH_ATTENTION")
364
- # output
365
- x = x.flatten(2)
366
- x = self.o(x.to(dtype))
367
- return x
368
-
369
-
370
- class AudioInjector_WAN(nn.Module):
371
-
372
- def __init__(self,
373
- all_modules,
374
- all_modules_names,
375
- dim=2048,
376
- num_heads=32,
377
- inject_layer=[0, 27],
378
- root_net=None,
379
- enable_adain=False,
380
- adain_dim=2048,
381
- need_adain_ont=False):
382
- super().__init__()
383
- num_injector_layers = len(inject_layer)
384
- self.injected_block_id = {}
385
- audio_injector_id = 0
386
- for mod_name, mod in zip(all_modules_names, all_modules):
387
- if isinstance(mod, WanAttentionBlock):
388
- for inject_id in inject_layer:
389
- if f'transformer_blocks.{inject_id}' in mod_name:
390
- self.injected_block_id[inject_id] = audio_injector_id
391
- audio_injector_id += 1
392
-
393
- self.injector = nn.ModuleList([
394
- AudioCrossAttention(
395
- dim=dim,
396
- num_heads=num_heads,
397
- qk_norm=True,
398
- ) for _ in range(audio_injector_id)
399
- ])
400
- self.injector_pre_norm_feat = nn.ModuleList([
401
- nn.LayerNorm(
402
- dim,
403
- elementwise_affine=False,
404
- eps=1e-6,
405
- ) for _ in range(audio_injector_id)
406
- ])
407
- self.injector_pre_norm_vec = nn.ModuleList([
408
- nn.LayerNorm(
409
- dim,
410
- elementwise_affine=False,
411
- eps=1e-6,
412
- ) for _ in range(audio_injector_id)
413
- ])
414
- if enable_adain:
415
- self.injector_adain_layers = nn.ModuleList([
416
- AdaLayerNorm(
417
- output_dim=dim * 2, embedding_dim=adain_dim, chunk_dim=1)
418
- for _ in range(audio_injector_id)
419
- ])
420
- if need_adain_ont:
421
- self.injector_adain_output_layers = nn.ModuleList(
422
- [nn.Linear(dim, dim) for _ in range(audio_injector_id)])
423
-
424
-
425
- class RMSNorm(nn.Module):
426
-
427
- def __init__(self, dim, eps=1e-5):
428
- super().__init__()
429
- self.dim = dim
430
- self.eps = eps
431
- self.weight = nn.Parameter(torch.ones(dim))
432
-
433
- def forward(self, x):
434
- return self._norm(x.float()).type_as(x) * self.weight
435
-
436
- def _norm(self, x):
437
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
438
-
439
-
440
- class LayerNorm(nn.LayerNorm):
441
-
442
- def __init__(self, dim, eps=1e-6, elementwise_affine=False):
443
- super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
444
-
445
- def forward(self, x):
446
- return super().forward(x.float()).type_as(x)
447
-
448
-
449
- class SelfAttention(nn.Module):
450
-
451
- def __init__(self,
452
- dim,
453
- num_heads,
454
- window_size=(-1, -1),
455
- qk_norm=True,
456
- eps=1e-6):
457
- assert dim % num_heads == 0
458
- super().__init__()
459
- self.dim = dim
460
- self.num_heads = num_heads
461
- self.head_dim = dim // num_heads
462
- self.window_size = window_size
463
- self.qk_norm = qk_norm
464
- self.eps = eps
465
-
466
- # layers
467
- self.q = nn.Linear(dim, dim)
468
- self.k = nn.Linear(dim, dim)
469
- self.v = nn.Linear(dim, dim)
470
- self.o = nn.Linear(dim, dim)
471
- self.norm_q = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
472
- self.norm_k = RMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
473
-
474
- def forward(self, x, seq_lens, grid_sizes, freqs):
475
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
476
-
477
- # query, key, value function
478
- def qkv_fn(x):
479
- q = self.norm_q(self.q(x)).view(b, s, n, d)
480
- k = self.norm_k(self.k(x)).view(b, s, n, d)
481
- v = self.v(x).view(b, s, n, d)
482
- return q, k, v
483
-
484
- q, k, v = qkv_fn(x)
485
-
486
- x = attention(
487
- q=rope_apply(q, grid_sizes, freqs),
488
- k=rope_apply(k, grid_sizes, freqs),
489
- v=v,
490
- k_lens=seq_lens,
491
- window_size=self.window_size)
492
-
493
- # output
494
- x = x.flatten(2)
495
- x = self.o(x)
496
- return x
497
-
498
-
499
- class SwinSelfAttention(SelfAttention):
500
-
501
- def forward(self, x, seq_lens, grid_sizes, freqs):
502
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
503
- assert b == 1, 'Only support batch_size 1'
504
-
505
- # query, key, value function
506
- def qkv_fn(x):
507
- q = self.norm_q(self.q(x)).view(b, s, n, d)
508
- k = self.norm_k(self.k(x)).view(b, s, n, d)
509
- v = self.v(x).view(b, s, n, d)
510
- return q, k, v
511
-
512
- q, k, v = qkv_fn(x)
513
-
514
- q = rope_apply(q, grid_sizes, freqs)
515
- k = rope_apply(k, grid_sizes, freqs)
516
- T, H, W = grid_sizes[0].tolist()
517
-
518
- q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
519
- k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
520
- v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
521
-
522
- ref_q = q[-1:]
523
- q = q[:-1]
524
-
525
- ref_k = repeat(
526
- k[-1:], "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
527
- k = k[:-1]
528
- k = torch.cat([k[:1], k, k[-1:]])
529
- k = torch.cat([k[1:-1], k[2:], k[:-2], ref_k], dim=1) # (bt) (3hw) n d
530
-
531
- ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=v.shape[0] - 1)
532
- v = v[:-1]
533
- v = torch.cat([v[:1], v, v[-1:]])
534
- v = torch.cat([v[1:-1], v[2:], v[:-2], ref_v], dim=1)
535
-
536
- # q: b (t h w) n d
537
- # k: b (t h w) n d
538
- out = attention(
539
- q=q,
540
- k=k,
541
- v=v,
542
- # k_lens=torch.tensor([k.shape[1]] * k.shape[0], device=x.device, dtype=torch.long),
543
- window_size=self.window_size)
544
- out = torch.cat([out, ref_v[:1]], axis=0)
545
- out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
546
- x = out
547
-
548
- # output
549
- x = x.flatten(2)
550
- x = self.o(x)
551
- return x
552
-
553
-
554
- #Fix the reference frame RoPE to 1,H,W.
555
- #Set the current frame RoPE to 1.
556
- #Set the previous frame RoPE to 0.
557
- class CasualSelfAttention(SelfAttention):
558
-
559
- def forward(self, x, seq_lens, grid_sizes, freqs):
560
- shifting = 3
561
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
562
- assert b == 1, 'Only support batch_size 1'
563
-
564
- # query, key, value function
565
- def qkv_fn(x):
566
- q = self.norm_q(self.q(x)).view(b, s, n, d)
567
- k = self.norm_k(self.k(x)).view(b, s, n, d)
568
- v = self.v(x).view(b, s, n, d)
569
- return q, k, v
570
-
571
- q, k, v = qkv_fn(x)
572
-
573
- T, H, W = grid_sizes[0].tolist()
574
-
575
- q = rearrange(q, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
576
- k = rearrange(k, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
577
- v = rearrange(v, 'b (t h w) n d -> (b t) (h w) n d', t=T, h=H, w=W)
578
-
579
- ref_q = q[-1:]
580
- q = q[:-1]
581
-
582
- grid_sizes = torch.tensor([[1, H, W]] * q.shape[0], dtype=torch.long)
583
- start = [[shifting, 0, 0]] * q.shape[0]
584
- q = rope_apply(q, grid_sizes, freqs, start=start)
585
-
586
- ref_k = k[-1:]
587
- grid_sizes = torch.tensor([[1, H, W]], dtype=torch.long)
588
- # start = [[shifting, H, W]]
589
-
590
- start = [[shifting + 10, 0, 0]]
591
- ref_k = rope_apply(ref_k, grid_sizes, freqs, start)
592
- ref_k = repeat(
593
- ref_k, "1 s n d -> t s n d", t=k.shape[0] - 1) # t hw n d
594
-
595
- k = k[:-1]
596
- k = torch.cat([*([k[:1]] * shifting), k])
597
- cat_k = []
598
- for i in range(shifting):
599
- cat_k.append(k[i:i - shifting])
600
- cat_k.append(k[shifting:])
601
- k = torch.cat(cat_k, dim=1) # (bt) (3hw) n d
602
-
603
- grid_sizes = torch.tensor(
604
- [[shifting + 1, H, W]] * q.shape[0], dtype=torch.long)
605
- k = rope_apply(k, grid_sizes, freqs)
606
- k = torch.cat([k, ref_k], dim=1)
607
-
608
- ref_v = repeat(v[-1:], "1 s n d -> t s n d", t=q.shape[0]) # t hw n d
609
- v = v[:-1]
610
- v = torch.cat([*([v[:1]] * shifting), v])
611
- cat_v = []
612
- for i in range(shifting):
613
- cat_v.append(v[i:i - shifting])
614
- cat_v.append(v[shifting:])
615
- v = torch.cat(cat_v, dim=1) # (bt) (3hw) n d
616
- v = torch.cat([v, ref_v], dim=1)
617
-
618
- # q: b (t h w) n d
619
- # k: b (t h w) n d
620
- outs = []
621
- for i in range(q.shape[0]):
622
- out = attention(
623
- q=q[i:i + 1],
624
- k=k[i:i + 1],
625
- v=v[i:i + 1],
626
- window_size=self.window_size)
627
- outs.append(out)
628
- out = torch.cat(outs, dim=0)
629
- out = torch.cat([out, ref_v[:1]], axis=0)
630
- out = rearrange(out, '(b t) (h w) n d -> b (t h w) n d', t=T, h=H, w=W)
631
- x = out
632
-
633
- # output
634
- x = x.flatten(2)
635
- x = self.o(x)
636
- return x
637
-
638
-
639
- class MotionerAttentionBlock(nn.Module):
640
-
641
- def __init__(self,
642
- dim,
643
- ffn_dim,
644
- num_heads,
645
- window_size=(-1, -1),
646
- qk_norm=True,
647
- cross_attn_norm=False,
648
- eps=1e-6,
649
- self_attn_block="SelfAttention"):
650
- super().__init__()
651
- self.dim = dim
652
- self.ffn_dim = ffn_dim
653
- self.num_heads = num_heads
654
- self.window_size = window_size
655
- self.qk_norm = qk_norm
656
- self.cross_attn_norm = cross_attn_norm
657
- self.eps = eps
658
-
659
- # layers
660
- self.norm1 = LayerNorm(dim, eps)
661
- if self_attn_block == "SelfAttention":
662
- self.self_attn = SelfAttention(dim, num_heads, window_size, qk_norm,
663
- eps)
664
- elif self_attn_block == "SwinSelfAttention":
665
- self.self_attn = SwinSelfAttention(dim, num_heads, window_size,
666
- qk_norm, eps)
667
- elif self_attn_block == "CasualSelfAttention":
668
- self.self_attn = CasualSelfAttention(dim, num_heads, window_size,
669
- qk_norm, eps)
670
-
671
- self.norm2 = LayerNorm(dim, eps)
672
- self.ffn = nn.Sequential(
673
- nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
674
- nn.Linear(ffn_dim, dim))
675
-
676
- def forward(
677
- self,
678
- x,
679
- seq_lens,
680
- grid_sizes,
681
- freqs,
682
- ):
683
- # self-attention
684
- y = self.self_attn(self.norm1(x).float(), seq_lens, grid_sizes, freqs)
685
- x = x + y
686
- y = self.ffn(self.norm2(x).float())
687
- x = x + y
688
- return x
689
-
690
-
691
- class Head(nn.Module):
692
-
693
- def __init__(self, dim, out_dim, patch_size, eps=1e-6):
694
- super().__init__()
695
- self.dim = dim
696
- self.out_dim = out_dim
697
- self.patch_size = patch_size
698
- self.eps = eps
699
-
700
- # layers
701
- out_dim = math.prod(patch_size) * out_dim
702
- self.norm = LayerNorm(dim, eps)
703
- self.head = nn.Linear(dim, out_dim)
704
-
705
- def forward(self, x):
706
- x = self.head(self.norm(x))
707
- return x
708
-
709
-
710
- class MotionerTransformers(nn.Module, PeftAdapterMixin):
711
-
712
- def __init__(
713
- self,
714
- patch_size=(1, 2, 2),
715
- in_dim=16,
716
- dim=2048,
717
- ffn_dim=8192,
718
- freq_dim=256,
719
- out_dim=16,
720
- num_heads=16,
721
- num_layers=32,
722
- window_size=(-1, -1),
723
- qk_norm=True,
724
- cross_attn_norm=False,
725
- eps=1e-6,
726
- self_attn_block="SelfAttention",
727
- motion_token_num=1024,
728
- enable_tsm=False,
729
- motion_stride=4,
730
- expand_ratio=2,
731
- trainable_token_pos_emb=False,
732
- ):
733
- super().__init__()
734
- self.patch_size = patch_size
735
- self.in_dim = in_dim
736
- self.dim = dim
737
- self.ffn_dim = ffn_dim
738
- self.freq_dim = freq_dim
739
- self.out_dim = out_dim
740
- self.num_heads = num_heads
741
- self.num_layers = num_layers
742
- self.window_size = window_size
743
- self.qk_norm = qk_norm
744
- self.cross_attn_norm = cross_attn_norm
745
- self.eps = eps
746
-
747
- self.enable_tsm = enable_tsm
748
- self.motion_stride = motion_stride
749
- self.expand_ratio = expand_ratio
750
- self.sample_c = self.patch_size[0]
751
-
752
- # embeddings
753
- self.patch_embedding = nn.Conv3d(
754
- in_dim, dim, kernel_size=patch_size, stride=patch_size)
755
-
756
- # blocks
757
- self.blocks = nn.ModuleList([
758
- MotionerAttentionBlock(
759
- dim,
760
- ffn_dim,
761
- num_heads,
762
- window_size,
763
- qk_norm,
764
- cross_attn_norm,
765
- eps,
766
- self_attn_block=self_attn_block) for _ in range(num_layers)
767
- ])
768
-
769
- # buffers (don't use register_buffer otherwise dtype will be changed in to())
770
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
771
- d = dim // num_heads
772
- self.freqs = torch.cat([
773
- rope_params(1024, d - 4 * (d // 6)),
774
- rope_params(1024, 2 * (d // 6)),
775
- rope_params(1024, 2 * (d // 6))
776
- ],
777
- dim=1)
778
-
779
- self.gradient_checkpointing = False
780
-
781
- self.motion_side_len = int(math.sqrt(motion_token_num))
782
- assert self.motion_side_len**2 == motion_token_num
783
- self.token = nn.Parameter(
784
- torch.zeros(1, motion_token_num, dim).contiguous())
785
-
786
- self.trainable_token_pos_emb = trainable_token_pos_emb
787
- if trainable_token_pos_emb:
788
- x = torch.zeros([1, motion_token_num, num_heads, d])
789
- x[..., ::2] = 1
790
-
791
- gride_sizes = [[
792
- torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
793
- torch.tensor([1, self.motion_side_len,
794
- self.motion_side_len]).unsqueeze(0).repeat(1, 1),
795
- torch.tensor([1, self.motion_side_len,
796
- self.motion_side_len]).unsqueeze(0).repeat(1, 1),
797
- ]]
798
- token_freqs = rope_apply(x, gride_sizes, self.freqs)
799
- token_freqs = token_freqs[0, :, 0].reshape(motion_token_num, -1, 2)
800
- token_freqs = token_freqs * 0.01
801
- self.token_freqs = torch.nn.Parameter(token_freqs)
802
-
803
- def after_patch_embedding(self, x):
804
- return x
805
-
806
- def forward(
807
- self,
808
- x,
809
- ):
810
- """
811
- x: A list of videos each with shape [C, T, H, W].
812
- t: [B].
813
- context: A list of text embeddings each with shape [L, C].
814
- """
815
- # params
816
- motion_frames = x[0].shape[1]
817
- device = self.patch_embedding.weight.device
818
- freqs = self.freqs
819
- if freqs.device != device:
820
- freqs = freqs.to(device)
821
-
822
- if self.trainable_token_pos_emb:
823
- with amp.autocast(dtype=torch.float64):
824
- token_freqs = self.token_freqs.to(torch.float64)
825
- token_freqs = token_freqs / token_freqs.norm(
826
- dim=-1, keepdim=True)
827
- freqs = [freqs, torch.view_as_complex(token_freqs)]
828
-
829
- if self.enable_tsm:
830
- sample_idx = [
831
- sample_indices(
832
- u.shape[1],
833
- stride=self.motion_stride,
834
- expand_ratio=self.expand_ratio,
835
- c=self.sample_c) for u in x
836
- ]
837
- x = [
838
- torch.flip(torch.flip(u, [1])[:, idx], [1])
839
- for idx, u in zip(sample_idx, x)
840
- ]
841
-
842
- # embeddings
843
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
844
- x = self.after_patch_embedding(x)
845
-
846
- seq_f, seq_h, seq_w = x[0].shape[-3:]
847
- batch_size = len(x)
848
- if not self.enable_tsm:
849
- grid_sizes = torch.stack(
850
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
851
- grid_sizes = [[
852
- torch.zeros_like(grid_sizes), grid_sizes, grid_sizes
853
- ]]
854
- seq_f = 0
855
- else:
856
- grid_sizes = []
857
- for idx in sample_idx[0][::-1][::self.sample_c]:
858
- tsm_frame_grid_sizes = [[
859
- torch.tensor([idx, 0,
860
- 0]).unsqueeze(0).repeat(batch_size, 1),
861
- torch.tensor([idx + 1, seq_h,
862
- seq_w]).unsqueeze(0).repeat(batch_size, 1),
863
- torch.tensor([1, seq_h,
864
- seq_w]).unsqueeze(0).repeat(batch_size, 1),
865
- ]]
866
- grid_sizes += tsm_frame_grid_sizes
867
- seq_f = sample_idx[0][-1] + 1
868
-
869
- x = [u.flatten(2).transpose(1, 2) for u in x]
870
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
871
- x = torch.cat([u for u in x])
872
-
873
- batch_size = len(x)
874
-
875
- token_grid_sizes = [[
876
- torch.tensor([seq_f, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
877
- torch.tensor(
878
- [seq_f + 1, self.motion_side_len,
879
- self.motion_side_len]).unsqueeze(0).repeat(batch_size, 1),
880
- torch.tensor(
881
- [1 if not self.trainable_token_pos_emb else -1, seq_h,
882
- seq_w]).unsqueeze(0).repeat(batch_size, 1),
883
- ] # the third row gives the range that the rope emb is meant to cover
884
- ]
885
-
886
- grid_sizes = grid_sizes + token_grid_sizes
887
- token_unpatch_grid_sizes = torch.stack([
888
- torch.tensor([1, 32, 32], dtype=torch.long)
889
- for b in range(batch_size)
890
- ])
891
- token_len = self.token.shape[1]
892
- token = self.token.clone().repeat(x.shape[0], 1, 1).contiguous()
893
- seq_lens = seq_lens + torch.tensor([t.size(0) for t in token],
894
- dtype=torch.long)
895
- x = torch.cat([x, token], dim=1)
896
- # arguments
897
- kwargs = dict(
898
- seq_lens=seq_lens,
899
- grid_sizes=grid_sizes,
900
- freqs=freqs,
901
- )
902
-
903
- # grad ckpt args
904
- def create_custom_forward(module, return_dict=None):
905
-
906
- def custom_forward(*inputs, **kwargs):
907
- if return_dict is not None:
908
- return module(*inputs, **kwargs, return_dict=return_dict)
909
- else:
910
- return module(*inputs, **kwargs)
911
-
912
- return custom_forward
913
-
914
- ckpt_kwargs: Dict[str, Any] = ({
915
- "use_reentrant": False
916
- } if is_torch_version(">=", "1.11.0") else {})
917
-
918
- for idx, block in enumerate(self.blocks):
919
- if self.training and self.gradient_checkpointing:
920
- x = torch.utils.checkpoint.checkpoint(
921
- create_custom_forward(block),
922
- x,
923
- **kwargs,
924
- **ckpt_kwargs,
925
- )
926
- else:
927
- x = block(x, **kwargs)
928
- # head
929
- out = x[:, -token_len:]
930
- return out
931
-
932
- def unpatchify(self, x, grid_sizes):
933
- c = self.out_dim
934
- out = []
935
- for u, v in zip(x, grid_sizes.tolist()):
936
- u = u[:math.prod(v)].view(*v, *self.patch_size, c)
937
- u = torch.einsum('fhwpqrc->cfphqwr', u)
938
- u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
939
- out.append(u)
940
- return out
941
-
942
- def init_weights(self):
943
- # basic init
944
- for m in self.modules():
945
- if isinstance(m, nn.Linear):
946
- nn.init.xavier_uniform_(m.weight)
947
- if m.bias is not None:
948
- nn.init.zeros_(m.bias)
949
-
950
- # init embeddings
951
- nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
952
-
953
-
954
- class FramePackMotioner(nn.Module):
955
-
956
- def __init__(
957
- self,
958
- inner_dim=1024,
959
- num_heads=16, # Used to indicate the number of heads in the backbone network; unrelated to this module's design
960
- zip_frame_buckets=[
961
- 1, 2, 16
962
- ], # Three numbers representing the number of frames sampled for patch operations from the nearest to the farthest frames
963
- drop_mode="drop", # If not "drop", it will use "padd", meaning padding instead of deletion
964
- *args,
965
- **kwargs):
966
- super().__init__(*args, **kwargs)
967
- self.proj = nn.Conv3d(
968
- 16, inner_dim, kernel_size=(1, 2, 2), stride=(1, 2, 2))
969
- self.proj_2x = nn.Conv3d(
970
- 16, inner_dim, kernel_size=(2, 4, 4), stride=(2, 4, 4))
971
- self.proj_4x = nn.Conv3d(
972
- 16, inner_dim, kernel_size=(4, 8, 8), stride=(4, 8, 8))
973
- self.zip_frame_buckets = torch.tensor(
974
- zip_frame_buckets, dtype=torch.long)
975
-
976
- self.inner_dim = inner_dim
977
- self.num_heads = num_heads
978
-
979
- assert (inner_dim %
980
- num_heads) == 0 and (inner_dim // num_heads) % 2 == 0
981
- d = inner_dim // num_heads
982
- self.freqs = torch.cat([
983
- rope_params(1024, d - 4 * (d // 6)),
984
- rope_params(1024, 2 * (d // 6)),
985
- rope_params(1024, 2 * (d // 6))
986
- ],
987
- dim=1)
988
- self.drop_mode = drop_mode
989
-
990
- def forward(self, motion_latents, add_last_motion=2):
991
- motion_frames = motion_latents[0].shape[1]
992
- mot = []
993
- mot_remb = []
994
- for m in motion_latents:
995
- lat_height, lat_width = m.shape[2], m.shape[3]
996
- padd_lat = torch.zeros(16, self.zip_frame_buckets.sum(), lat_height,
997
- lat_width).to(
998
- device=m.device, dtype=m.dtype)
999
- overlap_frame = min(padd_lat.shape[1], m.shape[1])
1000
- if overlap_frame > 0:
1001
- padd_lat[:, -overlap_frame:] = m[:, -overlap_frame:]
1002
-
1003
- if add_last_motion < 2 and self.drop_mode != "drop":
1004
- zero_end_frame = self.zip_frame_buckets[:self.zip_frame_buckets.
1005
- __len__() -
1006
- add_last_motion -
1007
- 1].sum()
1008
- padd_lat[:, -zero_end_frame:] = 0
1009
-
1010
- padd_lat = padd_lat.unsqueeze(0)
1011
- clean_latents_4x, clean_latents_2x, clean_latents_post = padd_lat[:, :, -self.zip_frame_buckets.sum(
1012
- ):, :, :].split(
1013
- list(self.zip_frame_buckets)[::-1], dim=2) # 16, 2 ,1
1014
-
1015
- # patchify
1016
- clean_latents_post = self.proj(clean_latents_post).flatten(
1017
- 2).transpose(1, 2)
1018
- clean_latents_2x = self.proj_2x(clean_latents_2x).flatten(
1019
- 2).transpose(1, 2)
1020
- clean_latents_4x = self.proj_4x(clean_latents_4x).flatten(
1021
- 2).transpose(1, 2)
1022
-
1023
- if add_last_motion < 2 and self.drop_mode == "drop":
1024
- clean_latents_post = clean_latents_post[:, :
1025
- 0] if add_last_motion < 2 else clean_latents_post
1026
- clean_latents_2x = clean_latents_2x[:, :
1027
- 0] if add_last_motion < 1 else clean_latents_2x
1028
-
1029
- motion_lat = torch.cat(
1030
- [clean_latents_post, clean_latents_2x, clean_latents_4x], dim=1)
1031
-
1032
- # rope
1033
- start_time_id = -(self.zip_frame_buckets[:1].sum())
1034
- end_time_id = start_time_id + self.zip_frame_buckets[0]
1035
- grid_sizes = [] if add_last_motion < 2 and self.drop_mode == "drop" else \
1036
- [
1037
- [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1038
- torch.tensor([end_time_id, lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1),
1039
- torch.tensor([self.zip_frame_buckets[0], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
1040
- ]
1041
-
1042
- start_time_id = -(self.zip_frame_buckets[:2].sum())
1043
- end_time_id = start_time_id + self.zip_frame_buckets[1] // 2
1044
- grid_sizes_2x = [] if add_last_motion < 1 and self.drop_mode == "drop" else \
1045
- [
1046
- [torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1047
- torch.tensor([end_time_id, lat_height // 4, lat_width // 4]).unsqueeze(0).repeat(1, 1),
1048
- torch.tensor([self.zip_frame_buckets[1], lat_height // 2, lat_width // 2]).unsqueeze(0).repeat(1, 1), ]
1049
- ]
1050
-
1051
- start_time_id = -(self.zip_frame_buckets[:3].sum())
1052
- end_time_id = start_time_id + self.zip_frame_buckets[2] // 4
1053
- grid_sizes_4x = [[
1054
- torch.tensor([start_time_id, 0, 0]).unsqueeze(0).repeat(1, 1),
1055
- torch.tensor([end_time_id, lat_height // 8,
1056
- lat_width // 8]).unsqueeze(0).repeat(1, 1),
1057
- torch.tensor([
1058
- self.zip_frame_buckets[2], lat_height // 2, lat_width // 2
1059
- ]).unsqueeze(0).repeat(1, 1),
1060
- ]]
1061
-
1062
- grid_sizes = grid_sizes + grid_sizes_2x + grid_sizes_4x
1063
-
1064
- motion_rope_emb = rope_precompute(
1065
- motion_lat.detach().view(1, motion_lat.shape[1], self.num_heads,
1066
- self.inner_dim // self.num_heads),
1067
- grid_sizes,
1068
- self.freqs,
1069
- start=None)
1070
-
1071
- mot.append(motion_lat)
1072
- mot_remb.append(motion_rope_emb)
1073
- return mot, mot_remb
1074
-
1075
-
1076
- def sample_indices(N, stride, expand_ratio, c):
1077
- indices = []
1078
- current_start = 0
1079
-
1080
- while current_start < N:
1081
- bucket_width = int(stride * (expand_ratio**(len(indices) / stride)))
1082
-
1083
- interval = int(bucket_width / stride * c)
1084
- current_end = min(N, current_start + bucket_width)
1085
- bucket_samples = []
1086
- for i in range(current_end - 1, current_start - 1, -interval):
1087
- for near in range(c):
1088
- bucket_samples.append(i - near)
1089
-
1090
- indices += bucket_samples[::-1]
1091
- current_start += bucket_width
1092
-
1093
- return indices
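Note: rope_params and rope_apply above implement rotary position embedding with complex arithmetic: each (even, odd) channel pair of q/k is rotated by a position-dependent angle. A minimal 1D sketch of the same mechanism on a toy tensor (all sizes are illustrative, not taken from the file):

import torch

seq_len, num_heads, head_dim, theta = 8, 2, 16, 10000
x = torch.randn(1, seq_len, num_heads, head_dim)

# per-position complex rotations, as built by rope_params
angles = torch.outer(torch.arange(seq_len).double(),
                     1.0 / theta ** (torch.arange(0, head_dim, 2).double() / head_dim))
freqs = torch.polar(torch.ones_like(angles), angles)          # [seq_len, head_dim // 2]

# rotate channel pairs, as done inside rope_apply
x_c = torch.view_as_complex(x.double().reshape(1, seq_len, num_heads, -1, 2))
x_rot = torch.view_as_real(x_c * freqs.view(1, seq_len, 1, -1)).flatten(3)
print(x_rot.shape)                                            # torch.Size([1, 8, 2, 16])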
videox_fun/models/wan_camera_adapter.py DELETED
@@ -1,64 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
-
4
-
5
- class SimpleAdapter(nn.Module):
6
- def __init__(self, in_dim, out_dim, kernel_size, stride, downscale_factor=8, num_residual_blocks=1):
7
- super(SimpleAdapter, self).__init__()
8
-
9
- # Pixel Unshuffle: reduce spatial dimensions by a factor of 8
10
- self.pixel_unshuffle = nn.PixelUnshuffle(downscale_factor=downscale_factor)
11
-
12
- # Convolution: reduce spatial dimensions by a factor
13
- # of 2 (without overlap)
14
- self.conv = nn.Conv2d(in_dim * downscale_factor * downscale_factor, out_dim, kernel_size=kernel_size, stride=stride, padding=0)
15
-
16
- # Residual blocks for feature extraction
17
- self.residual_blocks = nn.Sequential(
18
- *[ResidualBlock(out_dim) for _ in range(num_residual_blocks)]
19
- )
20
-
21
- def forward(self, x):
22
- # Reshape to merge the frame dimension into batch
23
- bs, c, f, h, w = x.size()
24
- x = x.permute(0, 2, 1, 3, 4).contiguous().view(bs * f, c, h, w)
25
-
26
- # Pixel Unshuffle operation
27
- x_unshuffled = self.pixel_unshuffle(x)
28
-
29
- # Convolution operation
30
- x_conv = self.conv(x_unshuffled)
31
-
32
- # Feature extraction with residual blocks
33
- out = self.residual_blocks(x_conv)
34
-
35
- # Reshape to restore original bf dimension
36
- out = out.view(bs, f, out.size(1), out.size(2), out.size(3))
37
-
38
- # Permute dimensions to reorder (if needed), e.g., swap channels and feature frames
39
- out = out.permute(0, 2, 1, 3, 4)
40
-
41
- return out
42
-
43
-
44
- class ResidualBlock(nn.Module):
45
- def __init__(self, dim):
46
- super(ResidualBlock, self).__init__()
47
- self.conv1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
48
- self.relu = nn.ReLU(inplace=True)
49
- self.conv2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)
50
-
51
- def forward(self, x):
52
- residual = x
53
- out = self.relu(self.conv1(x))
54
- out = self.conv2(out)
55
- out += residual
56
- return out
57
-
58
- # Example usage
59
- # in_dim = 3
60
- # out_dim = 64
61
- # adapter = SimpleAdapterWithReshape(in_dim, out_dim)
62
- # x = torch.randn(1, in_dim, 4, 64, 64) # e.g., batch size = 1, channels = 3, frames/features = 4
63
- # output = adapter(x)
64
- # print(output.shape) # Should reflect transformed dimensions
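Note: the commented example above omits the kernel and stride arguments. A hedged usage sketch, assuming kernel_size=stride=2 (an assumption, not recorded in the file) and assuming the SimpleAdapter class defined above is importable: PixelUnshuffle(8) maps [H, W] to [H/8, W/8] with 64x the channels, and the stride-2 convolution halves the spatial size again, so a [B, C, F, H, W] input comes out as [B, out_dim, F, H/16, W/16].

import torch

adapter = SimpleAdapter(in_dim=3, out_dim=64, kernel_size=2, stride=2)  # kernel/stride assumed
x = torch.randn(1, 3, 4, 64, 64)    # batch=1, RGB, 4 frames, 64x64
print(adapter(x).shape)             # expected: torch.Size([1, 64, 4, 4, 4])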
videox_fun/models/wan_image_encoder.py DELETED
@@ -1,553 +0,0 @@
1
- # Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import math
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- import torchvision.transforms as T
9
-
10
- from .attention_utils import attention, flash_attention
11
- from .wan_xlm_roberta import XLMRoberta
12
- from diffusers.configuration_utils import ConfigMixin
13
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
14
- from diffusers.models.modeling_utils import ModelMixin
15
-
16
-
17
- __all__ = [
18
- 'XLMRobertaCLIP',
19
- 'clip_xlm_roberta_vit_h_14',
20
- 'CLIPModel',
21
- ]
22
-
23
-
24
- def pos_interpolate(pos, seq_len):
25
- if pos.size(1) == seq_len:
26
- return pos
27
- else:
28
- src_grid = int(math.sqrt(pos.size(1)))
29
- tar_grid = int(math.sqrt(seq_len))
30
- n = pos.size(1) - src_grid * src_grid
31
- return torch.cat([
32
- pos[:, :n],
33
- F.interpolate(
34
- pos[:, n:].float().reshape(1, src_grid, src_grid, -1).permute(
35
- 0, 3, 1, 2),
36
- size=(tar_grid, tar_grid),
37
- mode='bicubic',
38
- align_corners=False).flatten(2).transpose(1, 2)
39
- ],
40
- dim=1)
41
-
42
-
43
- class QuickGELU(nn.Module):
44
-
45
- def forward(self, x):
46
- return x * torch.sigmoid(1.702 * x)
47
-
48
-
49
- class LayerNorm(nn.LayerNorm):
50
-
51
- def forward(self, x):
52
- return super().forward(x.float()).type_as(x)
53
-
54
-
55
- class SelfAttention(nn.Module):
56
-
57
- def __init__(self,
58
- dim,
59
- num_heads,
60
- causal=False,
61
- attn_dropout=0.0,
62
- proj_dropout=0.0):
63
- assert dim % num_heads == 0
64
- super().__init__()
65
- self.dim = dim
66
- self.num_heads = num_heads
67
- self.head_dim = dim // num_heads
68
- self.causal = causal
69
- self.attn_dropout = attn_dropout
70
- self.proj_dropout = proj_dropout
71
-
72
- # layers
73
- self.to_qkv = nn.Linear(dim, dim * 3)
74
- self.proj = nn.Linear(dim, dim)
75
-
76
- def forward(self, x):
77
- """
78
- x: [B, L, C].
79
- """
80
- b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
81
-
82
- # compute query, key, value
83
- q, k, v = self.to_qkv(x).view(b, s, 3, n, d).unbind(2)
84
-
85
- # compute attention
86
- p = self.attn_dropout if self.training else 0.0
87
- x = attention(q, k, v, dropout_p=p, causal=self.causal, attention_type="none")
88
- x = x.reshape(b, s, c)
89
-
90
- # output
91
- x = self.proj(x)
92
- x = F.dropout(x, self.proj_dropout, self.training)
93
- return x
94
-
95
-
96
- class SwiGLU(nn.Module):
97
-
98
- def __init__(self, dim, mid_dim):
99
- super().__init__()
100
- self.dim = dim
101
- self.mid_dim = mid_dim
102
-
103
- # layers
104
- self.fc1 = nn.Linear(dim, mid_dim)
105
- self.fc2 = nn.Linear(dim, mid_dim)
106
- self.fc3 = nn.Linear(mid_dim, dim)
107
-
108
- def forward(self, x):
109
- x = F.silu(self.fc1(x)) * self.fc2(x)
110
- x = self.fc3(x)
111
- return x
112
-
113
-
114
- class AttentionBlock(nn.Module):
115
-
116
- def __init__(self,
117
- dim,
118
- mlp_ratio,
119
- num_heads,
120
- post_norm=False,
121
- causal=False,
122
- activation='quick_gelu',
123
- attn_dropout=0.0,
124
- proj_dropout=0.0,
125
- norm_eps=1e-5):
126
- assert activation in ['quick_gelu', 'gelu', 'swi_glu']
127
- super().__init__()
128
- self.dim = dim
129
- self.mlp_ratio = mlp_ratio
130
- self.num_heads = num_heads
131
- self.post_norm = post_norm
132
- self.causal = causal
133
- self.norm_eps = norm_eps
134
-
135
- # layers
136
- self.norm1 = LayerNorm(dim, eps=norm_eps)
137
- self.attn = SelfAttention(dim, num_heads, causal, attn_dropout,
138
- proj_dropout)
139
- self.norm2 = LayerNorm(dim, eps=norm_eps)
140
- if activation == 'swi_glu':
141
- self.mlp = SwiGLU(dim, int(dim * mlp_ratio))
142
- else:
143
- self.mlp = nn.Sequential(
144
- nn.Linear(dim, int(dim * mlp_ratio)),
145
- QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
146
- nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
147
-
148
- def forward(self, x):
149
- if self.post_norm:
150
- x = x + self.norm1(self.attn(x))
151
- x = x + self.norm2(self.mlp(x))
152
- else:
153
- x = x + self.attn(self.norm1(x))
154
- x = x + self.mlp(self.norm2(x))
155
- return x
156
-
157
-
158
- class AttentionPool(nn.Module):
159
-
160
- def __init__(self,
161
- dim,
162
- mlp_ratio,
163
- num_heads,
164
- activation='gelu',
165
- proj_dropout=0.0,
166
- norm_eps=1e-5):
167
- assert dim % num_heads == 0
168
- super().__init__()
169
- self.dim = dim
170
- self.mlp_ratio = mlp_ratio
171
- self.num_heads = num_heads
172
- self.head_dim = dim // num_heads
173
- self.proj_dropout = proj_dropout
174
- self.norm_eps = norm_eps
175
-
176
- # layers
177
- gain = 1.0 / math.sqrt(dim)
178
- self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
179
- self.to_q = nn.Linear(dim, dim)
180
- self.to_kv = nn.Linear(dim, dim * 2)
181
- self.proj = nn.Linear(dim, dim)
182
- self.norm = LayerNorm(dim, eps=norm_eps)
183
- self.mlp = nn.Sequential(
184
- nn.Linear(dim, int(dim * mlp_ratio)),
185
- QuickGELU() if activation == 'quick_gelu' else nn.GELU(),
186
- nn.Linear(int(dim * mlp_ratio), dim), nn.Dropout(proj_dropout))
187
-
188
- def forward(self, x):
189
- """
190
- x: [B, L, C].
191
- """
192
- b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
193
-
194
- # compute query, key, value
195
- q = self.to_q(self.cls_embedding).view(1, 1, n, d).expand(b, -1, -1, -1)
196
- k, v = self.to_kv(x).view(b, s, 2, n, d).unbind(2)
197
-
198
- # compute attention
199
- x = flash_attention(q, k, v, version=2)
200
- x = x.reshape(b, 1, c)
201
-
202
- # output
203
- x = self.proj(x)
204
- x = F.dropout(x, self.proj_dropout, self.training)
205
-
206
- # mlp
207
- x = x + self.mlp(self.norm(x))
208
- return x[:, 0]
209
-
210
-
211
- class VisionTransformer(nn.Module):
212
-
213
- def __init__(self,
214
- image_size=224,
215
- patch_size=16,
216
- dim=768,
217
- mlp_ratio=4,
218
- out_dim=512,
219
- num_heads=12,
220
- num_layers=12,
221
- pool_type='token',
222
- pre_norm=True,
223
- post_norm=False,
224
- activation='quick_gelu',
225
- attn_dropout=0.0,
226
- proj_dropout=0.0,
227
- embedding_dropout=0.0,
228
- norm_eps=1e-5):
229
- if image_size % patch_size != 0:
230
- print(
231
- '[WARNING] image_size is not divisible by patch_size',
232
- flush=True)
233
- assert pool_type in ('token', 'token_fc', 'attn_pool')
234
- out_dim = out_dim or dim
235
- super().__init__()
236
- self.image_size = image_size
237
- self.patch_size = patch_size
238
- self.num_patches = (image_size // patch_size)**2
239
- self.dim = dim
240
- self.mlp_ratio = mlp_ratio
241
- self.out_dim = out_dim
242
- self.num_heads = num_heads
243
- self.num_layers = num_layers
244
- self.pool_type = pool_type
245
- self.post_norm = post_norm
246
- self.norm_eps = norm_eps
247
-
248
- # embeddings
249
- gain = 1.0 / math.sqrt(dim)
250
- self.patch_embedding = nn.Conv2d(
251
- 3,
252
- dim,
253
- kernel_size=patch_size,
254
- stride=patch_size,
255
- bias=not pre_norm)
256
- if pool_type in ('token', 'token_fc'):
257
- self.cls_embedding = nn.Parameter(gain * torch.randn(1, 1, dim))
258
- self.pos_embedding = nn.Parameter(gain * torch.randn(
259
- 1, self.num_patches +
260
- (1 if pool_type in ('token', 'token_fc') else 0), dim))
261
- self.dropout = nn.Dropout(embedding_dropout)
262
-
263
- # transformer
264
- self.pre_norm = LayerNorm(dim, eps=norm_eps) if pre_norm else None
265
- self.transformer = nn.Sequential(*[
266
- AttentionBlock(dim, mlp_ratio, num_heads, post_norm, False,
267
- activation, attn_dropout, proj_dropout, norm_eps)
268
- for _ in range(num_layers)
269
- ])
270
- self.post_norm = LayerNorm(dim, eps=norm_eps)
271
-
272
- # head
273
- if pool_type == 'token':
274
- self.head = nn.Parameter(gain * torch.randn(dim, out_dim))
275
- elif pool_type == 'token_fc':
276
- self.head = nn.Linear(dim, out_dim)
277
- elif pool_type == 'attn_pool':
278
- self.head = AttentionPool(dim, mlp_ratio, num_heads, activation,
279
- proj_dropout, norm_eps)
280
-
281
- def forward(self, x, interpolation=False, use_31_block=False):
282
- b = x.size(0)
283
-
284
- # embeddings
285
- x = self.patch_embedding(x).flatten(2).permute(0, 2, 1)
286
- if self.pool_type in ('token', 'token_fc'):
287
- x = torch.cat([self.cls_embedding.expand(b, -1, -1), x], dim=1)
288
- if interpolation:
289
- e = pos_interpolate(self.pos_embedding, x.size(1))
290
- else:
291
- e = self.pos_embedding
292
- x = self.dropout(x + e)
293
- if self.pre_norm is not None:
294
- x = self.pre_norm(x)
295
-
296
- # transformer
297
- if use_31_block:
298
- x = self.transformer[:-1](x)
299
- return x
300
- else:
301
- x = self.transformer(x)
302
- return x
303
-
304
-
305
- class XLMRobertaWithHead(XLMRoberta):
306
-
307
- def __init__(self, **kwargs):
308
- self.out_dim = kwargs.pop('out_dim')
309
- super().__init__(**kwargs)
310
-
311
- # head
312
- mid_dim = (self.dim + self.out_dim) // 2
313
- self.head = nn.Sequential(
314
- nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(),
315
- nn.Linear(mid_dim, self.out_dim, bias=False))
316
-
317
- def forward(self, ids):
318
- # xlm-roberta
319
- x = super().forward(ids)
320
-
321
- # average pooling
322
- mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
323
- x = (x * mask).sum(dim=1) / mask.sum(dim=1)
324
-
325
- # head
326
- x = self.head(x)
327
- return x
328
-
329
-
330
- class XLMRobertaCLIP(nn.Module):
331
-
332
- def __init__(self,
333
- embed_dim=1024,
334
- image_size=224,
335
- patch_size=14,
336
- vision_dim=1280,
337
- vision_mlp_ratio=4,
338
- vision_heads=16,
339
- vision_layers=32,
340
- vision_pool='token',
341
- vision_pre_norm=True,
342
- vision_post_norm=False,
343
- activation='gelu',
344
- vocab_size=250002,
345
- max_text_len=514,
346
- type_size=1,
347
- pad_id=1,
348
- text_dim=1024,
349
- text_heads=16,
350
- text_layers=24,
351
- text_post_norm=True,
352
- text_dropout=0.1,
353
- attn_dropout=0.0,
354
- proj_dropout=0.0,
355
- embedding_dropout=0.0,
356
- norm_eps=1e-5):
357
- super().__init__()
358
- self.embed_dim = embed_dim
359
- self.image_size = image_size
360
- self.patch_size = patch_size
361
- self.vision_dim = vision_dim
362
- self.vision_mlp_ratio = vision_mlp_ratio
363
- self.vision_heads = vision_heads
364
- self.vision_layers = vision_layers
365
- self.vision_pre_norm = vision_pre_norm
366
- self.vision_post_norm = vision_post_norm
367
- self.activation = activation
368
- self.vocab_size = vocab_size
369
- self.max_text_len = max_text_len
370
- self.type_size = type_size
371
- self.pad_id = pad_id
372
- self.text_dim = text_dim
373
- self.text_heads = text_heads
374
- self.text_layers = text_layers
375
- self.text_post_norm = text_post_norm
376
- self.norm_eps = norm_eps
377
-
378
- # models
379
- self.visual = VisionTransformer(
380
- image_size=image_size,
381
- patch_size=patch_size,
382
- dim=vision_dim,
383
- mlp_ratio=vision_mlp_ratio,
384
- out_dim=embed_dim,
385
- num_heads=vision_heads,
386
- num_layers=vision_layers,
387
- pool_type=vision_pool,
388
- pre_norm=vision_pre_norm,
389
- post_norm=vision_post_norm,
390
- activation=activation,
391
- attn_dropout=attn_dropout,
392
- proj_dropout=proj_dropout,
393
- embedding_dropout=embedding_dropout,
394
- norm_eps=norm_eps)
395
- self.textual = XLMRobertaWithHead(
396
- vocab_size=vocab_size,
397
- max_seq_len=max_text_len,
398
- type_size=type_size,
399
- pad_id=pad_id,
400
- dim=text_dim,
401
- out_dim=embed_dim,
402
- num_heads=text_heads,
403
- num_layers=text_layers,
404
- post_norm=text_post_norm,
405
- dropout=text_dropout)
406
- self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
407
-
408
- def forward(self, imgs, txt_ids):
409
- """
410
- imgs: [B, 3, H, W] of torch.float32.
411
- - mean: [0.48145466, 0.4578275, 0.40821073]
412
- - std: [0.26862954, 0.26130258, 0.27577711]
413
- txt_ids: [B, L] of torch.long.
414
- Encoded by data.CLIPTokenizer.
415
- """
416
- xi = self.visual(imgs)
417
- xt = self.textual(txt_ids)
418
- return xi, xt
419
-
420
- def param_groups(self):
421
- groups = [{
422
- 'params': [
423
- p for n, p in self.named_parameters()
424
- if 'norm' in n or n.endswith('bias')
425
- ],
426
- 'weight_decay': 0.0
427
- }, {
428
- 'params': [
429
- p for n, p in self.named_parameters()
430
- if not ('norm' in n or n.endswith('bias'))
431
- ]
432
- }]
433
- return groups
434
-
435
-
436
- def _clip(pretrained=False,
437
- pretrained_name=None,
438
- model_cls=XLMRobertaCLIP,
439
- return_transforms=False,
440
- return_tokenizer=False,
441
- tokenizer_padding='eos',
442
- dtype=torch.float32,
443
- device='cpu',
444
- **kwargs):
445
- # init a model on device
446
- with torch.device(device):
447
- model = model_cls(**kwargs)
448
-
449
- # set device
450
- model = model.to(dtype=dtype, device=device)
451
- output = (model,)
452
-
453
- # init transforms
454
- if return_transforms:
455
- # mean and std
456
- if 'siglip' in pretrained_name.lower():
457
- mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
458
- else:
459
- mean = [0.48145466, 0.4578275, 0.40821073]
460
- std = [0.26862954, 0.26130258, 0.27577711]
461
-
462
- # transforms
463
- transforms = T.Compose([
464
- T.Resize((model.image_size, model.image_size),
465
- interpolation=T.InterpolationMode.BICUBIC),
466
- T.ToTensor(),
467
- T.Normalize(mean=mean, std=std)
468
- ])
469
- output += (transforms,)
470
- return output[0] if len(output) == 1 else output
471
-
472
-
473
- def clip_xlm_roberta_vit_h_14(
474
- pretrained=False,
475
- pretrained_name='open-clip-xlm-roberta-large-vit-huge-14',
476
- **kwargs):
477
- cfg = dict(
478
- embed_dim=1024,
479
- image_size=224,
480
- patch_size=14,
481
- vision_dim=1280,
482
- vision_mlp_ratio=4,
483
- vision_heads=16,
484
- vision_layers=32,
485
- vision_pool='token',
486
- activation='gelu',
487
- vocab_size=250002,
488
- max_text_len=514,
489
- type_size=1,
490
- pad_id=1,
491
- text_dim=1024,
492
- text_heads=16,
493
- text_layers=24,
494
- text_post_norm=True,
495
- text_dropout=0.1,
496
- attn_dropout=0.0,
497
- proj_dropout=0.0,
498
- embedding_dropout=0.0)
499
- cfg.update(**kwargs)
500
- return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
501
-
502
-
503
- class CLIPModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
-
505
- def __init__(self):
506
- super(CLIPModel, self).__init__()
507
- # init model
508
- self.model, self.transforms = clip_xlm_roberta_vit_h_14(
509
- pretrained=False,
510
- return_transforms=True,
511
- return_tokenizer=False)
512
-
513
- def forward(self, videos):
514
- # preprocess
515
- size = (self.model.image_size,) * 2
516
- videos = torch.cat([
517
- F.interpolate(
518
- u.transpose(0, 1),
519
- size=size,
520
- mode='bicubic',
521
- align_corners=False) for u in videos
522
- ])
523
- videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
524
-
525
- # forward
526
- with torch.cuda.amp.autocast(dtype=self.dtype):
527
- out = self.model.visual(videos, use_31_block=True)
528
- return out
529
-
530
- @classmethod
531
- def from_pretrained(cls, pretrained_model_path, transformer_additional_kwargs={}):
532
- def filter_kwargs(cls, kwargs):
533
- import inspect
534
- sig = inspect.signature(cls.__init__)
535
- valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
536
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
537
- return filtered_kwargs
538
-
539
- model = cls(**filter_kwargs(cls, transformer_additional_kwargs))
540
- if pretrained_model_path.endswith(".safetensors"):
541
- from safetensors.torch import load_file, safe_open
542
- state_dict = load_file(pretrained_model_path)
543
- else:
544
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
545
- tmp_state_dict = {}
546
- for key in state_dict:
547
- tmp_state_dict["model." + key] = state_dict[key]
548
- state_dict = tmp_state_dict
549
- m, u = model.load_state_dict(state_dict, strict=False)
550
-
551
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
552
- print(m, u)
553
- return model
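
The CLIPModel.forward above resizes every frame to the CLIP input resolution, shifts pixel values from [-1, 1] into [0, 1], and then applies only the Normalize transform (the last entry of self.transforms). A minimal sketch of that preprocessing on a dummy clip, assuming inputs already lie in [-1, 1]; the 224x224 size and the mean/std follow the non-SigLIP branch of _clip:

import torch
import torch.nn.functional as F
import torchvision.transforms as T

mean = [0.48145466, 0.4578275, 0.40821073]   # CLIP statistics (non-SigLIP branch)
std = [0.26862954, 0.26130258, 0.27577711]
normalize = T.Normalize(mean=mean, std=std)

video = torch.rand(1, 3, 8, 64, 64) * 2 - 1  # [B, C, F, H, W], values in [-1, 1]

# [B, C, F, H, W] -> [B*F, C, 224, 224], matching the loop in CLIPModel.forward
frames = torch.cat([
    F.interpolate(u.transpose(0, 1), size=(224, 224),
                  mode='bicubic', align_corners=False)
    for u in video
])
frames = normalize(frames.mul(0.5).add(0.5))  # shift to [0, 1], then standardize
print(frames.shape)  # torch.Size([8, 3, 224, 224])
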
videox_fun/models/wan_text_encoder.py DELETED
@@ -1,395 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/t5.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import math
4
- from typing import Optional
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from diffusers.configuration_utils import ConfigMixin
10
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
- from diffusers.models.modeling_utils import ModelMixin
12
-
13
-
14
- def fp16_clamp(x):
15
- if x.dtype == torch.float16 and torch.isinf(x).any():
16
- clamp = torch.finfo(x.dtype).max - 1000
17
- x = torch.clamp(x, min=-clamp, max=clamp)
18
- return x
19
-
20
-
21
- def init_weights(m):
22
- if isinstance(m, T5LayerNorm):
23
- nn.init.ones_(m.weight)
24
- elif isinstance(m, T5FeedForward):
25
- nn.init.normal_(m.gate[0].weight, std=m.dim**-0.5)
26
- nn.init.normal_(m.fc1.weight, std=m.dim**-0.5)
27
- nn.init.normal_(m.fc2.weight, std=m.dim_ffn**-0.5)
28
- elif isinstance(m, T5Attention):
29
- nn.init.normal_(m.q.weight, std=(m.dim * m.dim_attn)**-0.5)
30
- nn.init.normal_(m.k.weight, std=m.dim**-0.5)
31
- nn.init.normal_(m.v.weight, std=m.dim**-0.5)
32
- nn.init.normal_(m.o.weight, std=(m.num_heads * m.dim_attn)**-0.5)
33
- elif isinstance(m, T5RelativeEmbedding):
34
- nn.init.normal_(
35
- m.embedding.weight, std=(2 * m.num_buckets * m.num_heads)**-0.5)
36
-
37
-
38
- class GELU(nn.Module):
39
- def forward(self, x):
40
- return 0.5 * x * (1.0 + torch.tanh(
41
- math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
42
-
43
-
44
- class T5LayerNorm(nn.Module):
45
- def __init__(self, dim, eps=1e-6):
46
- super(T5LayerNorm, self).__init__()
47
- self.dim = dim
48
- self.eps = eps
49
- self.weight = nn.Parameter(torch.ones(dim))
50
-
51
- def forward(self, x):
52
- x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) +
53
- self.eps)
54
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
55
- x = x.type_as(self.weight)
56
- return self.weight * x
57
-
58
-
59
- class T5Attention(nn.Module):
60
- def __init__(self, dim, dim_attn, num_heads, dropout=0.1):
61
- assert dim_attn % num_heads == 0
62
- super(T5Attention, self).__init__()
63
- self.dim = dim
64
- self.dim_attn = dim_attn
65
- self.num_heads = num_heads
66
- self.head_dim = dim_attn // num_heads
67
-
68
- # layers
69
- self.q = nn.Linear(dim, dim_attn, bias=False)
70
- self.k = nn.Linear(dim, dim_attn, bias=False)
71
- self.v = nn.Linear(dim, dim_attn, bias=False)
72
- self.o = nn.Linear(dim_attn, dim, bias=False)
73
- self.dropout = nn.Dropout(dropout)
74
-
75
- def forward(self, x, context=None, mask=None, pos_bias=None):
76
- """
77
- x: [B, L1, C].
78
- context: [B, L2, C] or None.
79
- mask: [B, L2] or [B, L1, L2] or None.
80
- """
81
- # check inputs
82
- context = x if context is None else context
83
- b, n, c = x.size(0), self.num_heads, self.head_dim
84
-
85
- # compute query, key, value
86
- q = self.q(x).view(b, -1, n, c)
87
- k = self.k(context).view(b, -1, n, c)
88
- v = self.v(context).view(b, -1, n, c)
89
-
90
- # attention bias
91
- attn_bias = x.new_zeros(b, n, q.size(1), k.size(1))
92
- if pos_bias is not None:
93
- attn_bias += pos_bias
94
- if mask is not None:
95
- assert mask.ndim in [2, 3]
96
- mask = mask.view(b, 1, 1,
97
- -1) if mask.ndim == 2 else mask.unsqueeze(1)
98
- attn_bias.masked_fill_(mask == 0, torch.finfo(x.dtype).min)
99
-
100
- # compute attention (T5 does not use scaling)
101
- attn = torch.einsum('binc,bjnc->bnij', q, k) + attn_bias
102
- attn = F.softmax(attn.float(), dim=-1).type_as(attn)
103
- x = torch.einsum('bnij,bjnc->binc', attn, v)
104
-
105
- # output
106
- x = x.reshape(b, -1, n * c)
107
- x = self.o(x)
108
- x = self.dropout(x)
109
- return x
110
-
111
-
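
Note that T5Attention above adds the relative-position bias to unscaled dot products (there is no 1/sqrt(d) factor, unlike standard scaled dot-product attention). A minimal sketch of that score computation with hypothetical shapes:

import torch

B, L, n, c = 2, 16, 4, 32                   # batch, length, heads, head_dim
q = torch.randn(B, L, n, c)
k = torch.randn(B, L, n, c)
v = torch.randn(B, L, n, c)
pos_bias = torch.randn(1, n, L, L)          # e.g. output of T5RelativeEmbedding

attn = torch.einsum('binc,bjnc->bnij', q, k) + pos_bias   # no scaling, bias added
attn = attn.softmax(dim=-1)
out = torch.einsum('bnij,bjnc->binc', attn, v).reshape(B, L, n * c)
print(out.shape)  # torch.Size([2, 16, 128])
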
112
- class T5FeedForward(nn.Module):
113
-
114
- def __init__(self, dim, dim_ffn, dropout=0.1):
115
- super(T5FeedForward, self).__init__()
116
- self.dim = dim
117
- self.dim_ffn = dim_ffn
118
-
119
- # layers
120
- self.gate = nn.Sequential(nn.Linear(dim, dim_ffn, bias=False), GELU())
121
- self.fc1 = nn.Linear(dim, dim_ffn, bias=False)
122
- self.fc2 = nn.Linear(dim_ffn, dim, bias=False)
123
- self.dropout = nn.Dropout(dropout)
124
-
125
- def forward(self, x):
126
- x = self.fc1(x) * self.gate(x)
127
- x = self.dropout(x)
128
- x = self.fc2(x)
129
- x = self.dropout(x)
130
- return x
131
-
132
-
133
- class T5SelfAttention(nn.Module):
134
- def __init__(self,
135
- dim,
136
- dim_attn,
137
- dim_ffn,
138
- num_heads,
139
- num_buckets,
140
- shared_pos=True,
141
- dropout=0.1):
142
- super(T5SelfAttention, self).__init__()
143
- self.dim = dim
144
- self.dim_attn = dim_attn
145
- self.dim_ffn = dim_ffn
146
- self.num_heads = num_heads
147
- self.num_buckets = num_buckets
148
- self.shared_pos = shared_pos
149
-
150
- # layers
151
- self.norm1 = T5LayerNorm(dim)
152
- self.attn = T5Attention(dim, dim_attn, num_heads, dropout)
153
- self.norm2 = T5LayerNorm(dim)
154
- self.ffn = T5FeedForward(dim, dim_ffn, dropout)
155
- self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
156
- num_buckets, num_heads, bidirectional=True)
157
-
158
- def forward(self, x, mask=None, pos_bias=None):
159
- e = pos_bias if self.shared_pos else self.pos_embedding(
160
- x.size(1), x.size(1))
161
- x = fp16_clamp(x + self.attn(self.norm1(x), mask=mask, pos_bias=e))
162
- x = fp16_clamp(x + self.ffn(self.norm2(x)))
163
- return x
164
-
165
-
166
- class T5CrossAttention(nn.Module):
167
- def __init__(self,
168
- dim,
169
- dim_attn,
170
- dim_ffn,
171
- num_heads,
172
- num_buckets,
173
- shared_pos=True,
174
- dropout=0.1):
175
- super(T5CrossAttention, self).__init__()
176
- self.dim = dim
177
- self.dim_attn = dim_attn
178
- self.dim_ffn = dim_ffn
179
- self.num_heads = num_heads
180
- self.num_buckets = num_buckets
181
- self.shared_pos = shared_pos
182
-
183
- # layers
184
- self.norm1 = T5LayerNorm(dim)
185
- self.self_attn = T5Attention(dim, dim_attn, num_heads, dropout)
186
- self.norm2 = T5LayerNorm(dim)
187
- self.cross_attn = T5Attention(dim, dim_attn, num_heads, dropout)
188
- self.norm3 = T5LayerNorm(dim)
189
- self.ffn = T5FeedForward(dim, dim_ffn, dropout)
190
- self.pos_embedding = None if shared_pos else T5RelativeEmbedding(
191
- num_buckets, num_heads, bidirectional=False)
192
-
193
- def forward(self,
194
- x,
195
- mask=None,
196
- encoder_states=None,
197
- encoder_mask=None,
198
- pos_bias=None):
199
- e = pos_bias if self.shared_pos else self.pos_embedding(
200
- x.size(1), x.size(1))
201
- x = fp16_clamp(x + self.self_attn(self.norm1(x), mask=mask, pos_bias=e))
202
- x = fp16_clamp(x + self.cross_attn(
203
- self.norm2(x), context=encoder_states, mask=encoder_mask))
204
- x = fp16_clamp(x + self.ffn(self.norm3(x)))
205
- return x
206
-
207
-
208
- class T5RelativeEmbedding(nn.Module):
209
- def __init__(self, num_buckets, num_heads, bidirectional, max_dist=128):
210
- super(T5RelativeEmbedding, self).__init__()
211
- self.num_buckets = num_buckets
212
- self.num_heads = num_heads
213
- self.bidirectional = bidirectional
214
- self.max_dist = max_dist
215
-
216
- # layers
217
- self.embedding = nn.Embedding(num_buckets, num_heads)
218
-
219
- def forward(self, lq, lk):
220
- device = self.embedding.weight.device
221
- # rel_pos = torch.arange(lk).unsqueeze(0).to(device) - \
222
- # torch.arange(lq).unsqueeze(1).to(device)
223
- if torch.device(type="meta") != device:
224
- rel_pos = torch.arange(lk, device=device).unsqueeze(0) - \
225
- torch.arange(lq, device=device).unsqueeze(1)
226
- else:
227
- rel_pos = torch.arange(lk).unsqueeze(0) - \
228
- torch.arange(lq).unsqueeze(1)
229
- rel_pos = self._relative_position_bucket(rel_pos)
230
- rel_pos_embeds = self.embedding(rel_pos)
231
- rel_pos_embeds = rel_pos_embeds.permute(2, 0, 1).unsqueeze(
232
- 0) # [1, N, Lq, Lk]
233
- return rel_pos_embeds.contiguous()
234
-
235
- def _relative_position_bucket(self, rel_pos):
236
- # preprocess
237
- if self.bidirectional:
238
- num_buckets = self.num_buckets // 2
239
- rel_buckets = (rel_pos > 0).long() * num_buckets
240
- rel_pos = torch.abs(rel_pos)
241
- else:
242
- num_buckets = self.num_buckets
243
- rel_buckets = 0
244
- rel_pos = -torch.min(rel_pos, torch.zeros_like(rel_pos))
245
-
246
- # embeddings for small and large positions
247
- max_exact = num_buckets // 2
248
- rel_pos_large = max_exact + (torch.log(rel_pos.float() / max_exact) /
249
- math.log(self.max_dist / max_exact) *
250
- (num_buckets - max_exact)).long()
251
- rel_pos_large = torch.min(
252
- rel_pos_large, torch.full_like(rel_pos_large, num_buckets - 1))
253
- rel_buckets += torch.where(rel_pos < max_exact, rel_pos, rel_pos_large)
254
- return rel_buckets
255
-
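
The bucketing above follows the T5 scheme: in the bidirectional case half of the buckets encode the sign of the offset, offsets smaller than max_exact each get their own bucket, and larger offsets are binned logarithmically up to max_dist. A standalone sketch of the same formula; num_buckets=32 and max_dist=128 are illustrative values, not necessarily the ones used by a given checkpoint:

import math
import torch

def relative_position_bucket(rel_pos, num_buckets=32, max_dist=128):
    # bidirectional: the upper half of the buckets means "key is to the right"
    num_buckets //= 2
    buckets = (rel_pos > 0).long() * num_buckets
    rel_pos = torch.abs(rel_pos)

    # exact buckets for small offsets, log-spaced buckets for large ones
    max_exact = num_buckets // 2
    # clamp only guards log(0); those entries are never selected by torch.where below
    large = max_exact + (torch.log(rel_pos.float().clamp(min=1) / max_exact) /
                         math.log(max_dist / max_exact) *
                         (num_buckets - max_exact)).long()
    large = torch.min(large, torch.full_like(large, num_buckets - 1))
    return buckets + torch.where(rel_pos < max_exact, rel_pos, large)

offsets = torch.arange(-10, 11)
print(relative_position_bucket(offsets))  # nearby offsets keep distinct buckets
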
256
- class WanT5EncoderModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
257
- def __init__(self,
258
- vocab,
259
- dim,
260
- dim_attn,
261
- dim_ffn,
262
- num_heads,
263
- num_layers,
264
- num_buckets,
265
- shared_pos=True,
266
- dropout=0.1):
267
- super(WanT5EncoderModel, self).__init__()
268
- self.dim = dim
269
- self.dim_attn = dim_attn
270
- self.dim_ffn = dim_ffn
271
- self.num_heads = num_heads
272
- self.num_layers = num_layers
273
- self.num_buckets = num_buckets
274
- self.shared_pos = shared_pos
275
-
276
- # layers
277
- self.token_embedding = vocab if isinstance(vocab, nn.Embedding) \
278
- else nn.Embedding(vocab, dim)
279
- self.pos_embedding = T5RelativeEmbedding(
280
- num_buckets, num_heads, bidirectional=True) if shared_pos else None
281
- self.dropout = nn.Dropout(dropout)
282
- self.blocks = nn.ModuleList([
283
- T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets,
284
- shared_pos, dropout) for _ in range(num_layers)
285
- ])
286
- self.norm = T5LayerNorm(dim)
287
-
288
- # initialize weights
289
- self.apply(init_weights)
290
-
291
- def forward(
292
- self,
293
- input_ids: Optional[torch.LongTensor] = None,
294
- attention_mask: Optional[torch.FloatTensor] = None,
295
- ):
296
- x = self.token_embedding(input_ids)
297
- x = self.dropout(x)
298
- e = self.pos_embedding(x.size(1),
299
- x.size(1)) if self.shared_pos else None
300
- for block in self.blocks:
301
- x = block(x, attention_mask, pos_bias=e)
302
- x = self.norm(x)
303
- x = self.dropout(x)
304
- return (x, )
305
-
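
A minimal shape check for the encoder above, assuming the module is still importable from its original path (this commit removes it) and using a deliberately tiny configuration rather than the real umT5 sizes:

import torch
from videox_fun.models.wan_text_encoder import WanT5EncoderModel

encoder = WanT5EncoderModel(
    vocab=1000, dim=64, dim_attn=64, dim_ffn=128,
    num_heads=4, num_layers=2, num_buckets=32,
)
ids = torch.randint(0, 1000, (2, 16))        # [B, L] token ids
mask = torch.ones(2, 16, dtype=torch.long)   # [B, L] attention mask
(hidden,) = encoder(input_ids=ids, attention_mask=mask)
print(hidden.shape)  # torch.Size([2, 16, 64])
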
306
- @classmethod
307
- def from_pretrained(cls, pretrained_model_path, additional_kwargs={}, low_cpu_mem_usage=False, torch_dtype=torch.bfloat16):
308
- def filter_kwargs(cls, kwargs):
309
- import inspect
310
- sig = inspect.signature(cls.__init__)
311
- valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
312
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
313
- return filtered_kwargs
314
-
315
- if low_cpu_mem_usage:
316
- try:
317
- import re
318
-
319
- from diffusers import __version__ as diffusers_version
320
- if diffusers_version >= "0.33.0":
321
- from diffusers.models.model_loading_utils import \
322
- load_model_dict_into_meta
323
- else:
324
- from diffusers.models.modeling_utils import \
325
- load_model_dict_into_meta
326
- from diffusers.utils import is_accelerate_available
327
- if is_accelerate_available():
328
- import accelerate
329
-
330
- # Instantiate model with empty weights
331
- with accelerate.init_empty_weights():
332
- model = cls(**filter_kwargs(cls, additional_kwargs))
333
-
334
- param_device = "cpu"
335
- if pretrained_model_path.endswith(".safetensors"):
336
- from safetensors.torch import load_file
337
- state_dict = load_file(pretrained_model_path)
338
- else:
339
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
340
-
341
- if diffusers_version >= "0.33.0":
342
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
343
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
344
- load_model_dict_into_meta(
345
- model,
346
- state_dict,
347
- dtype=torch_dtype,
348
- model_name_or_path=pretrained_model_path,
349
- )
350
- else:
351
- # move the params from meta device to cpu
352
- missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
353
- if len(missing_keys) > 0:
354
- raise ValueError(
355
- f"Cannot load {cls} from {pretrained_model_path} because the following keys are"
356
- f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
357
- " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
358
- " those weights or else make sure your checkpoint file is correct."
359
- )
360
-
361
- unexpected_keys = load_model_dict_into_meta(
362
- model,
363
- state_dict,
364
- device=param_device,
365
- dtype=torch_dtype,
366
- model_name_or_path=pretrained_model_path,
367
- )
368
-
369
- if cls._keys_to_ignore_on_load_unexpected is not None:
370
- for pat in cls._keys_to_ignore_on_load_unexpected:
371
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
372
-
373
- if len(unexpected_keys) > 0:
374
- print(
375
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
376
- )
377
-
378
- return model
379
- except Exception as e:
380
- print(
381
- f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
382
- )
383
-
384
- model = cls(**filter_kwargs(cls, additional_kwargs))
385
- if pretrained_model_path.endswith(".safetensors"):
386
- from safetensors.torch import load_file, safe_open
387
- state_dict = load_file(pretrained_model_path)
388
- else:
389
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
390
- m, u = model.load_state_dict(state_dict, strict=False)
391
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
392
- print(m, u)
393
-
394
- model = model.to(torch_dtype)
395
- return model
videox_fun/models/wan_transformer3d.py DELETED
@@ -1,1394 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/model.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
-
4
- import glob
5
- import json
6
- import math
7
- import os
8
- import types
9
- import warnings
10
- from typing import Any, Dict, Optional, Union
11
-
12
- import numpy as np
13
- import torch
14
- import torch.cuda.amp as amp
15
- import torch.nn as nn
16
- from diffusers.configuration_utils import ConfigMixin, register_to_config
17
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
18
- from diffusers.models.modeling_utils import ModelMixin
19
- from diffusers.utils import is_torch_version, logging
20
- from torch import nn
21
-
22
- from ..dist import (get_sequence_parallel_rank,
23
- get_sequence_parallel_world_size, get_sp_group,
24
- usp_attn_forward, xFuserLongContextAttention)
25
- from ..utils import cfg_skip
26
- from .attention_utils import attention
27
- from .cache_utils import TeaCache
28
- from .wan_camera_adapter import SimpleAdapter
29
-
30
-
31
- def sinusoidal_embedding_1d(dim, position):
32
- # preprocess
33
- assert dim % 2 == 0
34
- half = dim // 2
35
- position = position.type(torch.float64)
36
-
37
- # calculation
38
- sinusoid = torch.outer(
39
- position, torch.pow(10000, -torch.arange(half).to(position).div(half)))
40
- x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
41
- return x
42
-
43
-
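
sinusoidal_embedding_1d maps each (possibly per-token) timestep to a dim-sized vector whose first half holds cosines and second half sines over log-spaced frequencies; this is what later feeds time_embedding. A shape check using the function defined above and the model's default freq_dim of 256:

import torch

t = torch.tensor([0, 250, 500, 999])            # diffusion timesteps, shape [B]
emb = sinusoidal_embedding_1d(256, t)
print(emb.shape, emb.dtype)                     # torch.Size([4, 256]) torch.float64

# per-token timesteps are flattened first, exactly as in the transformer forward
t2 = torch.randint(0, 1000, (2, 1024))
emb2 = sinusoidal_embedding_1d(256, t2.flatten()).unflatten(0, (2, 1024))
print(emb2.shape)                               # torch.Size([2, 1024, 256])
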
44
- @amp.autocast(enabled=False)
45
- def rope_params(max_seq_len, dim, theta=10000):
46
- assert dim % 2 == 0
47
- freqs = torch.outer(
48
- torch.arange(max_seq_len),
49
- 1.0 / torch.pow(theta,
50
- torch.arange(0, dim, 2).to(torch.float64).div(dim)))
51
- freqs = torch.polar(torch.ones_like(freqs), freqs)
52
- return freqs
53
-
54
-
55
- # modified from https://github.com/thu-ml/RIFLEx/blob/main/riflex_utils.py
56
- @amp.autocast(enabled=False)
57
- def get_1d_rotary_pos_embed_riflex(
58
- pos: Union[np.ndarray, int],
59
- dim: int,
60
- theta: float = 10000.0,
61
- use_real=False,
62
- k: Optional[int] = None,
63
- L_test: Optional[int] = None,
64
- L_test_scale: Optional[int] = None,
65
- ):
66
- """
67
- RIFLEx: Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
68
-
69
- This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
70
- position indices 'pos'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64
71
- data type.
72
-
73
- Args:
74
- dim (`int`): Dimension of the frequency tensor.
75
- pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar
76
- theta (`float`, *optional*, defaults to 10000.0):
77
- Scaling factor for frequency computation. Defaults to 10000.0.
78
- use_real (`bool`, *optional*):
79
- If True, return real part and imaginary part separately. Otherwise, return complex numbers.
80
- k (`int`, *optional*, defaults to None): the index for the intrinsic frequency in RoPE
81
- L_test (`int`, *optional*, defaults to None): the number of frames for inference
82
- Returns:
83
- `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
84
- """
85
- assert dim % 2 == 0
86
-
87
- if isinstance(pos, int):
88
- pos = torch.arange(pos)
89
- if isinstance(pos, np.ndarray):
90
- pos = torch.from_numpy(pos) # type: ignore # [S]
91
-
92
- freqs = 1.0 / torch.pow(theta,
93
- torch.arange(0, dim, 2).to(torch.float64).div(dim))
94
-
95
- # === Riflex modification start ===
96
- # Reduce the intrinsic frequency to stay within a single period after extrapolation (see Eq. (8)).
97
- # Empirical observations show that a few videos may exhibit repetition in the tail frames.
98
- # To be conservative, we multiply by 0.9 to keep the extrapolated length below 90% of a single period.
99
- if k is not None:
100
- freqs[k-1] = 0.9 * 2 * torch.pi / L_test
101
- # === Riflex modification end ===
102
- if L_test_scale is not None:
103
- freqs[k-1] = freqs[k-1] / L_test_scale
104
-
105
- freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
106
- if use_real:
107
- freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
108
- freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
109
- return freqs_cos, freqs_sin
110
- else:
111
- # lumina
112
- freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
113
- return freqs_cis
114
-
115
-
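
The transformer later builds its 3D rotary table by concatenating three rope_params tables, one per axis, after splitting each head's channels as (d - 4*(d//6), 2*(d//6), 2*(d//6)). A quick sketch of that split for a hypothetical head dimension of 128 (dim // num_heads in the real configs), using rope_params from above:

import torch

d = 128                                     # hypothetical per-head channel count
freqs = torch.cat([
    rope_params(1024, d - 4 * (d // 6)),    # temporal axis: 44 channels -> 22 phases
    rope_params(1024, 2 * (d // 6)),        # height axis:   42 channels -> 21 phases
    rope_params(1024, 2 * (d // 6)),        # width axis:    42 channels -> 21 phases
], dim=1)
print(freqs.shape, freqs.dtype)             # torch.Size([1024, 64]) torch.complex128
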
116
- # Similar to diffusers.pipelines.hunyuandit.pipeline_hunyuandit.get_resize_crop_region_for_grid
117
- def get_resize_crop_region_for_grid(src, tgt_width, tgt_height):
118
- tw = tgt_width
119
- th = tgt_height
120
- h, w = src
121
- r = h / w
122
- if r > (th / tw):
123
- resize_height = th
124
- resize_width = int(round(th / h * w))
125
- else:
126
- resize_width = tw
127
- resize_height = int(round(tw / w * h))
128
-
129
- crop_top = int(round((th - resize_height) / 2.0))
130
- crop_left = int(round((tw - resize_width) / 2.0))
131
-
132
- return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
133
-
134
-
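
A quick numeric check of get_resize_crop_region_for_grid: a 480x832 source mapped onto a 1280x720 target fills the height, so the crop is centered horizontally.

# src is (h, w); the target is given as (tgt_width, tgt_height)
top_left, bottom_right = get_resize_crop_region_for_grid((480, 832), 1280, 720)
print(top_left, bottom_right)  # (0, 16) (720, 1264)
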
135
- @amp.autocast(enabled=False)
136
- @torch.compiler.disable()
137
- def rope_apply(x, grid_sizes, freqs):
138
- n, c = x.size(2), x.size(3) // 2
139
-
140
- # split freqs
141
- freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
142
-
143
- # loop over samples
144
- output = []
145
- for i, (f, h, w) in enumerate(grid_sizes.tolist()):
146
- seq_len = f * h * w
147
-
148
- # precompute multipliers
149
- x_i = torch.view_as_complex(x[i, :seq_len].to(torch.float32).reshape(
150
- seq_len, n, -1, 2))
151
- freqs_i = torch.cat([
152
- freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
153
- freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
154
- freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
155
- ],
156
- dim=-1).reshape(seq_len, 1, -1)
157
-
158
- # apply rotary embedding
159
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
160
- x_i = torch.cat([x_i, x[i, seq_len:]])
161
-
162
- # append to collection
163
- output.append(x_i)
164
- return torch.stack(output).to(x.dtype)
165
-
166
-
167
- def rope_apply_qk(q, k, grid_sizes, freqs):
168
- q = rope_apply(q, grid_sizes, freqs)
169
- k = rope_apply(k, grid_sizes, freqs)
170
- return q, k
171
-
172
-
173
- class WanRMSNorm(nn.Module):
174
-
175
- def __init__(self, dim, eps=1e-5):
176
- super().__init__()
177
- self.dim = dim
178
- self.eps = eps
179
- self.weight = nn.Parameter(torch.ones(dim))
180
-
181
- def forward(self, x):
182
- r"""
183
- Args:
184
- x(Tensor): Shape [B, L, C]
185
- """
186
- return self._norm(x) * self.weight
187
-
188
- def _norm(self, x):
189
- return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps).to(x.dtype)
190
-
191
-
192
- class WanLayerNorm(nn.LayerNorm):
193
-
194
- def __init__(self, dim, eps=1e-6, elementwise_affine=False):
195
- super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
196
-
197
- def forward(self, x):
198
- r"""
199
- Args:
200
- x(Tensor): Shape [B, L, C]
201
- """
202
- return super().forward(x)
203
-
204
-
205
- class WanSelfAttention(nn.Module):
206
-
207
- def __init__(self,
208
- dim,
209
- num_heads,
210
- window_size=(-1, -1),
211
- qk_norm=True,
212
- eps=1e-6):
213
- assert dim % num_heads == 0
214
- super().__init__()
215
- self.dim = dim
216
- self.num_heads = num_heads
217
- self.head_dim = dim // num_heads
218
- self.window_size = window_size
219
- self.qk_norm = qk_norm
220
- self.eps = eps
221
-
222
- # layers
223
- self.q = nn.Linear(dim, dim)
224
- self.k = nn.Linear(dim, dim)
225
- self.v = nn.Linear(dim, dim)
226
- self.o = nn.Linear(dim, dim)
227
- self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
228
- self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
229
-
230
- def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
231
- r"""
232
- Args:
233
- x(Tensor): Shape [B, L, num_heads, C / num_heads]
234
- seq_lens(Tensor): Shape [B]
235
- grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
236
- freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
237
- """
238
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
239
-
240
- # query, key, value function
241
- def qkv_fn(x):
242
- q = self.norm_q(self.q(x.to(dtype))).view(b, s, n, d)
243
- k = self.norm_k(self.k(x.to(dtype))).view(b, s, n, d)
244
- v = self.v(x.to(dtype)).view(b, s, n, d)
245
- return q, k, v
246
-
247
- q, k, v = qkv_fn(x)
248
-
249
- q, k = rope_apply_qk(q, k, grid_sizes, freqs)
250
-
251
- x = attention(
252
- q.to(dtype),
253
- k.to(dtype),
254
- v=v.to(dtype),
255
- k_lens=seq_lens,
256
- window_size=self.window_size)
257
- x = x.to(dtype)
258
-
259
- # output
260
- x = x.flatten(2)
261
- x = self.o(x)
262
- return x
263
-
264
-
265
- class WanT2VCrossAttention(WanSelfAttention):
266
-
267
- def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
268
- r"""
269
- Args:
270
- x(Tensor): Shape [B, L1, C]
271
- context(Tensor): Shape [B, L2, C]
272
- context_lens(Tensor): Shape [B]
273
- """
274
- b, n, d = x.size(0), self.num_heads, self.head_dim
275
-
276
- # compute query, key, value
277
- q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
278
- k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
279
- v = self.v(context.to(dtype)).view(b, -1, n, d)
280
-
281
- # compute attention
282
- x = attention(
283
- q.to(dtype),
284
- k.to(dtype),
285
- v.to(dtype),
286
- k_lens=context_lens
287
- )
288
- x = x.to(dtype)
289
-
290
- # output
291
- x = x.flatten(2)
292
- x = self.o(x)
293
- return x
294
-
295
-
296
- class WanI2VCrossAttention(WanSelfAttention):
297
-
298
- def __init__(self,
299
- dim,
300
- num_heads,
301
- window_size=(-1, -1),
302
- qk_norm=True,
303
- eps=1e-6):
304
- super().__init__(dim, num_heads, window_size, qk_norm, eps)
305
-
306
- self.k_img = nn.Linear(dim, dim)
307
- self.v_img = nn.Linear(dim, dim)
308
- # self.alpha = nn.Parameter(torch.zeros((1, )))
309
- self.norm_k_img = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
310
-
311
- def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
312
- r"""
313
- Args:
314
- x(Tensor): Shape [B, L1, C]
315
- context(Tensor): Shape [B, L2, C]
316
- context_lens(Tensor): Shape [B]
317
- """
318
- context_img = context[:, :257]
319
- context = context[:, 257:]
320
- b, n, d = x.size(0), self.num_heads, self.head_dim
321
-
322
- # compute query, key, value
323
- q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
324
- k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
325
- v = self.v(context.to(dtype)).view(b, -1, n, d)
326
- k_img = self.norm_k_img(self.k_img(context_img.to(dtype))).view(b, -1, n, d)
327
- v_img = self.v_img(context_img.to(dtype)).view(b, -1, n, d)
328
-
329
- img_x = attention(
330
- q.to(dtype),
331
- k_img.to(dtype),
332
- v_img.to(dtype),
333
- k_lens=None
334
- )
335
- img_x = img_x.to(dtype)
336
- # compute attention
337
- x = attention(
338
- q.to(dtype),
339
- k.to(dtype),
340
- v.to(dtype),
341
- k_lens=context_lens
342
- )
343
- x = x.to(dtype)
344
-
345
- # output
346
- x = x.flatten(2)
347
- img_x = img_x.flatten(2)
348
- x = x + img_x
349
- x = self.o(x)
350
- return x
351
-
352
-
353
- class WanCrossAttention(WanSelfAttention):
354
- def forward(self, x, context, context_lens, dtype=torch.bfloat16, t=0):
355
- r"""
356
- Args:
357
- x(Tensor): Shape [B, L1, C]
358
- context(Tensor): Shape [B, L2, C]
359
- context_lens(Tensor): Shape [B]
360
- """
361
- b, n, d = x.size(0), self.num_heads, self.head_dim
362
- # compute query, key, value
363
- q = self.norm_q(self.q(x.to(dtype))).view(b, -1, n, d)
364
- k = self.norm_k(self.k(context.to(dtype))).view(b, -1, n, d)
365
- v = self.v(context.to(dtype)).view(b, -1, n, d)
366
- # compute attention
367
- x = attention(q.to(dtype), k.to(dtype), v.to(dtype), k_lens=context_lens)
368
- # output
369
- x = x.flatten(2)
370
- x = self.o(x.to(dtype))
371
- return x
372
-
373
-
374
- WAN_CROSSATTENTION_CLASSES = {
375
- 't2v_cross_attn': WanT2VCrossAttention,
376
- 'i2v_cross_attn': WanI2VCrossAttention,
377
- 'cross_attn': WanCrossAttention,
378
- }
379
-
380
-
381
- class WanAttentionBlock(nn.Module):
382
-
383
- def __init__(self,
384
- cross_attn_type,
385
- dim,
386
- ffn_dim,
387
- num_heads,
388
- window_size=(-1, -1),
389
- qk_norm=True,
390
- cross_attn_norm=False,
391
- eps=1e-6):
392
- super().__init__()
393
- self.dim = dim
394
- self.ffn_dim = ffn_dim
395
- self.num_heads = num_heads
396
- self.window_size = window_size
397
- self.qk_norm = qk_norm
398
- self.cross_attn_norm = cross_attn_norm
399
- self.eps = eps
400
-
401
- # layers
402
- self.norm1 = WanLayerNorm(dim, eps)
403
- self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
404
- eps)
405
- self.norm3 = WanLayerNorm(
406
- dim, eps,
407
- elementwise_affine=True) if cross_attn_norm else nn.Identity()
408
- self.cross_attn = WAN_CROSSATTENTION_CLASSES[cross_attn_type](dim,
409
- num_heads,
410
- (-1, -1),
411
- qk_norm,
412
- eps)
413
- self.norm2 = WanLayerNorm(dim, eps)
414
- self.ffn = nn.Sequential(
415
- nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
416
- nn.Linear(ffn_dim, dim))
417
-
418
- # modulation
419
- self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
420
-
421
- def forward(
422
- self,
423
- x,
424
- e,
425
- seq_lens,
426
- grid_sizes,
427
- freqs,
428
- context,
429
- context_lens,
430
- dtype=torch.bfloat16,
431
- t=0,
432
- ):
433
- r"""
434
- Args:
435
- x(Tensor): Shape [B, L, C]
436
- e(Tensor): Shape [B, 6, C]
437
- seq_lens(Tensor): Shape [B], length of each sequence in batch
438
- grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
439
- freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
440
- """
441
- if e.dim() > 3:
442
- e = (self.modulation.unsqueeze(0) + e).chunk(6, dim=2)
443
- e = [e.squeeze(2) for e in e]
444
- else:
445
- e = (self.modulation + e).chunk(6, dim=1)
446
-
447
- # self-attention
448
- temp_x = self.norm1(x) * (1 + e[1]) + e[0]
449
- temp_x = temp_x.to(dtype)
450
-
451
- y = self.self_attn(temp_x, seq_lens, grid_sizes, freqs, dtype, t=t)
452
- x = x + y * e[2]
453
-
454
- # cross-attention & ffn function
455
- def cross_attn_ffn(x, context, context_lens, e):
456
- # cross-attention
457
- x = x + self.cross_attn(self.norm3(x), context, context_lens, dtype, t=t)
458
-
459
- # ffn function
460
- temp_x = self.norm2(x) * (1 + e[4]) + e[3]
461
- temp_x = temp_x.to(dtype)
462
-
463
- y = self.ffn(temp_x)
464
- x = x + y * e[5]
465
- return x
466
-
467
- x = cross_attn_ffn(x, context, context_lens, e)
468
- return x
469
-
470
-
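
Each WanAttentionBlock folds the timestep signal in through six modulation vectors: shift/scale/gate around self-attention (e[0..2]) and shift/scale/gate around the feed-forward (e[3..5]), added to a learned per-block modulation table. A minimal sketch of that adaptive-LayerNorm pattern with stand-in sublayers:

import torch
import torch.nn as nn

dim, B, L = 64, 2, 16
modulation = torch.randn(1, 6, dim) / dim**0.5    # per-block learned table
e = torch.randn(B, 6, dim)                        # from time_projection, [B, 6, C]
shift_sa, scale_sa, gate_sa, shift_ff, scale_ff, gate_ff = (modulation + e).chunk(6, dim=1)

norm = nn.LayerNorm(dim, elementwise_affine=False)
sublayer = nn.Identity()                          # stand-in for self_attn / ffn

x = torch.randn(B, L, dim)
x = x + sublayer(norm(x) * (1 + scale_sa) + shift_sa) * gate_sa   # attention branch
x = x + sublayer(norm(x) * (1 + scale_ff) + shift_ff) * gate_ff   # feed-forward branch
print(x.shape)  # torch.Size([2, 16, 64])
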
471
- class Head(nn.Module):
472
-
473
- def __init__(self, dim, out_dim, patch_size, eps=1e-6):
474
- super().__init__()
475
- self.dim = dim
476
- self.out_dim = out_dim
477
- self.patch_size = patch_size
478
- self.eps = eps
479
-
480
- # layers
481
- out_dim = math.prod(patch_size) * out_dim
482
- self.norm = WanLayerNorm(dim, eps)
483
- self.head = nn.Linear(dim, out_dim)
484
-
485
- # modulation
486
- self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
487
-
488
- def forward(self, x, e):
489
- r"""
490
- Args:
491
- x(Tensor): Shape [B, L1, C]
492
- e(Tensor): Shape [B, C]
493
- """
494
- if e.dim() > 2:
495
- e = (self.modulation.unsqueeze(0) + e.unsqueeze(2)).chunk(2, dim=2)
496
- e = [e.squeeze(2) for e in e]
497
- else:
498
- e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)
499
-
500
- x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
501
- return x
502
-
503
-
504
- class MLPProj(torch.nn.Module):
505
-
506
- def __init__(self, in_dim, out_dim):
507
- super().__init__()
508
-
509
- self.proj = torch.nn.Sequential(
510
- torch.nn.LayerNorm(in_dim), torch.nn.Linear(in_dim, in_dim),
511
- torch.nn.GELU(), torch.nn.Linear(in_dim, out_dim),
512
- torch.nn.LayerNorm(out_dim))
513
-
514
- def forward(self, image_embeds):
515
- clip_extra_context_tokens = self.proj(image_embeds)
516
- return clip_extra_context_tokens
517
-
518
-
519
-
520
- class WanTransformer3DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
521
- r"""
522
- Wan diffusion backbone supporting both text-to-video and image-to-video.
523
- """
524
-
525
- # ignore_for_config = [
526
- # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
527
- # ]
528
- # _no_split_modules = ['WanAttentionBlock']
529
- _supports_gradient_checkpointing = True
530
-
531
- @register_to_config
532
- def __init__(
533
- self,
534
- model_type='t2v',
535
- patch_size=(1, 2, 2),
536
- text_len=512,
537
- in_dim=16,
538
- dim=2048,
539
- ffn_dim=8192,
540
- freq_dim=256,
541
- text_dim=4096,
542
- out_dim=16,
543
- num_heads=16,
544
- num_layers=32,
545
- window_size=(-1, -1),
546
- qk_norm=True,
547
- cross_attn_norm=True,
548
- eps=1e-6,
549
- in_channels=16,
550
- hidden_size=2048,
551
- add_control_adapter=False,
552
- in_dim_control_adapter=24,
553
- downscale_factor_control_adapter=8,
554
- add_ref_conv=False,
555
- in_dim_ref_conv=16,
556
- cross_attn_type=None,
557
- ):
558
- r"""
559
- Initialize the diffusion model backbone.
560
-
561
- Args:
562
- model_type (`str`, *optional*, defaults to 't2v'):
563
- Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
564
- patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
565
- 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
566
- text_len (`int`, *optional*, defaults to 512):
567
- Fixed length for text embeddings
568
- in_dim (`int`, *optional*, defaults to 16):
569
- Input video channels (C_in)
570
- dim (`int`, *optional*, defaults to 2048):
571
- Hidden dimension of the transformer
572
- ffn_dim (`int`, *optional*, defaults to 8192):
573
- Intermediate dimension in feed-forward network
574
- freq_dim (`int`, *optional*, defaults to 256):
575
- Dimension for sinusoidal time embeddings
576
- text_dim (`int`, *optional*, defaults to 4096):
577
- Input dimension for text embeddings
578
- out_dim (`int`, *optional*, defaults to 16):
579
- Output video channels (C_out)
580
- num_heads (`int`, *optional*, defaults to 16):
581
- Number of attention heads
582
- num_layers (`int`, *optional*, defaults to 32):
583
- Number of transformer blocks
584
- window_size (`tuple`, *optional*, defaults to (-1, -1)):
585
- Window size for local attention (-1 indicates global attention)
586
- qk_norm (`bool`, *optional*, defaults to True):
587
- Enable query/key normalization
588
- cross_attn_norm (`bool`, *optional*, defaults to True):
589
- Enable cross-attention normalization
590
- eps (`float`, *optional*, defaults to 1e-6):
591
- Epsilon value for normalization layers
592
- """
593
-
594
- super().__init__()
595
-
596
- # assert model_type in ['t2v', 'i2v', 'ti2v']
597
- self.model_type = model_type
598
-
599
- self.patch_size = patch_size
600
- self.text_len = text_len
601
- self.in_dim = in_dim
602
- self.dim = dim
603
- self.ffn_dim = ffn_dim
604
- self.freq_dim = freq_dim
605
- self.text_dim = text_dim
606
- self.out_dim = out_dim
607
- self.num_heads = num_heads
608
- self.num_layers = num_layers
609
- self.window_size = window_size
610
- self.qk_norm = qk_norm
611
- self.cross_attn_norm = cross_attn_norm
612
- self.eps = eps
613
-
614
- # embeddings
615
- self.patch_embedding = nn.Conv3d(
616
- in_dim, dim, kernel_size=patch_size, stride=patch_size)
617
- self.text_embedding = nn.Sequential(
618
- nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
619
- nn.Linear(dim, dim))
620
-
621
- self.time_embedding = nn.Sequential(
622
- nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
623
- self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
624
-
625
- # blocks
626
- if cross_attn_type is None:
627
- cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
628
- self.blocks = nn.ModuleList([
629
- WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
630
- window_size, qk_norm, cross_attn_norm, eps)
631
- for _ in range(num_layers)
632
- ])
633
- for layer_idx, block in enumerate(self.blocks):
634
- block.self_attn.layer_idx = layer_idx
635
- block.self_attn.num_layers = self.num_layers
636
-
637
- # head
638
- self.head = Head(dim, out_dim, patch_size, eps)
639
-
640
- # buffers (don't use register_buffer otherwise dtype will be changed in to())
641
- assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
642
- d = dim // num_heads
643
- self.d = d
644
- self.dim = dim
645
- self.freqs = torch.cat(
646
- [
647
- rope_params(1024, d - 4 * (d // 6)),
648
- rope_params(1024, 2 * (d // 6)),
649
- rope_params(1024, 2 * (d // 6))
650
- ],
651
- dim=1
652
- )
653
-
654
- if model_type == 'i2v':
655
- self.img_emb = MLPProj(1280, dim)
656
-
657
- if add_control_adapter:
658
- self.control_adapter = SimpleAdapter(in_dim_control_adapter, dim, kernel_size=patch_size[1:], stride=patch_size[1:], downscale_factor=downscale_factor_control_adapter)
659
- else:
660
- self.control_adapter = None
661
-
662
- if add_ref_conv:
663
- self.ref_conv = nn.Conv2d(in_dim_ref_conv, dim, kernel_size=patch_size[1:], stride=patch_size[1:])
664
- else:
665
- self.ref_conv = None
666
-
667
- self.teacache = None
668
- self.cfg_skip_ratio = None
669
- self.current_steps = 0
670
- self.num_inference_steps = None
671
- self.gradient_checkpointing = False
672
- self.all_gather = None
673
- self.sp_world_size = 1
674
- self.sp_world_rank = 0
675
- self.init_weights()
676
-
677
- def _set_gradient_checkpointing(self, *args, **kwargs):
678
- if "value" in kwargs:
679
- self.gradient_checkpointing = kwargs["value"]
680
- if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
681
- self.motioner.gradient_checkpointing = kwargs["value"]
682
- elif "enable" in kwargs:
683
- self.gradient_checkpointing = kwargs["enable"]
684
- if hasattr(self, "motioner") and hasattr(self.motioner, "gradient_checkpointing"):
685
- self.motioner.gradient_checkpointing = kwargs["enable"]
686
- else:
687
- raise ValueError("Invalid set gradient checkpointing")
688
-
689
- def enable_teacache(
690
- self,
691
- coefficients,
692
- num_steps: int,
693
- rel_l1_thresh: float,
694
- num_skip_start_steps: int = 0,
695
- offload: bool = True,
696
- ):
697
- self.teacache = TeaCache(
698
- coefficients, num_steps, rel_l1_thresh=rel_l1_thresh, num_skip_start_steps=num_skip_start_steps, offload=offload
699
- )
700
-
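
enable_teacache attaches a TeaCache so whole denoising steps can be skipped whenever the accumulated change of the modulated timestep input stays under rel_l1_thresh (see the forward pass below). A usage sketch; the polynomial coefficients are fitted per checkpoint and resolution, so the values here are placeholders, not published numbers:

# `transformer` is a loaded WanTransformer3DModel; coefficients are placeholders.
coefficients = [1.0, 0.0, 0.0, 0.0, 0.0]
transformer.enable_teacache(
    coefficients,
    num_steps=50,            # total sampling steps
    rel_l1_thresh=0.1,       # skip while the accumulated rescaled distance stays below this
    num_skip_start_steps=5,  # never skip the first few steps
)
# A second expert (e.g. the low-noise transformer) can reuse the same cache:
# transformer_2.share_teacache(transformer)
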
701
- def share_teacache(
702
- self,
703
- transformer = None,
704
- ):
705
- self.teacache = transformer.teacache
706
-
707
- def disable_teacache(self):
708
- self.teacache = None
709
-
710
- def enable_cfg_skip(self, cfg_skip_ratio, num_steps):
711
- if cfg_skip_ratio != 0:
712
- self.cfg_skip_ratio = cfg_skip_ratio
713
- self.current_steps = 0
714
- self.num_inference_steps = num_steps
715
- else:
716
- self.cfg_skip_ratio = None
717
- self.current_steps = 0
718
- self.num_inference_steps = None
719
-
720
- def share_cfg_skip(
721
- self,
722
- transformer = None,
723
- ):
724
- self.cfg_skip_ratio = transformer.cfg_skip_ratio
725
- self.current_steps = transformer.current_steps
726
- self.num_inference_steps = transformer.num_inference_steps
727
-
728
- def disable_cfg_skip(self):
729
- self.cfg_skip_ratio = None
730
- self.current_steps = 0
731
- self.num_inference_steps = None
732
-
733
- def enable_riflex(
734
- self,
735
- k = 6,
736
- L_test = 66,
737
- L_test_scale = 4.886,
738
- ):
739
- device = self.freqs.device
740
- self.freqs = torch.cat(
741
- [
742
- get_1d_rotary_pos_embed_riflex(1024, self.d - 4 * (self.d // 6), use_real=False, k=k, L_test=L_test, L_test_scale=L_test_scale),
743
- rope_params(1024, 2 * (self.d // 6)),
744
- rope_params(1024, 2 * (self.d // 6))
745
- ],
746
- dim=1
747
- ).to(device)
748
-
749
- def disable_riflex(self):
750
- device = self.freqs.device
751
- self.freqs = torch.cat(
752
- [
753
- rope_params(1024, self.d - 4 * (self.d // 6)),
754
- rope_params(1024, 2 * (self.d // 6)),
755
- rope_params(1024, 2 * (self.d // 6))
756
- ],
757
- dim=1
758
- ).to(device)
759
-
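
enable_riflex swaps only the temporal rope_params table for the RIFLEx variant, so frame extrapolation beyond the training length stays within one period of the chosen intrinsic frequency; disable_riflex restores the stock table. A usage sketch with illustrative, not prescriptive, values:

# `transformer` is a loaded WanTransformer3DModel instance.
transformer.enable_riflex(k=6, L_test=66)   # e.g. targeting 66 latent frames
# ... run the long-video sampling loop ...
transformer.disable_riflex()                # back to the default rotary table
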
760
- def enable_multi_gpus_inference(self,):
761
- self.sp_world_size = get_sequence_parallel_world_size()
762
- self.sp_world_rank = get_sequence_parallel_rank()
763
- self.all_gather = get_sp_group().all_gather
764
-
765
- # For normal model.
766
- for block in self.blocks:
767
- block.self_attn.forward = types.MethodType(
768
- usp_attn_forward, block.self_attn)
769
-
770
- # For vace model.
771
- if hasattr(self, 'vace_blocks'):
772
- for block in self.vace_blocks:
773
- block.self_attn.forward = types.MethodType(
774
- usp_attn_forward, block.self_attn)
775
-
776
- @cfg_skip()
777
- def forward(
778
- self,
779
- x,
780
- t,
781
- context,
782
- seq_len,
783
- clip_fea=None,
784
- y=None,
785
- y_camera=None,
786
- full_ref=None,
787
- subject_ref=None,
788
- cond_flag=True,
789
- ):
790
- r"""
791
- Forward pass through the diffusion model
792
-
793
- Args:
794
- x (List[Tensor]):
795
- List of input video tensors, each with shape [C_in, F, H, W]
796
- t (Tensor):
797
- Diffusion timesteps tensor of shape [B]
798
- context (List[Tensor]):
799
- List of text embeddings each with shape [L, C]
800
- seq_len (`int`):
801
- Maximum sequence length for positional encoding
802
- clip_fea (Tensor, *optional*):
803
- CLIP image features for image-to-video mode
804
- y (List[Tensor], *optional*):
805
- Conditional video inputs for image-to-video mode, same shape as x
806
- cond_flag (`bool`, *optional*, defaults to True):
807
- Flag to indicate whether to forward the condition input
808
-
809
- Returns:
810
- List[Tensor]:
811
- List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
812
- """
813
- # Wan2.2 doesn't need a CLIP image encoder.
814
- # if self.model_type == 'i2v':
815
- # assert clip_fea is not None and y is not None
816
- # params
817
- device = self.patch_embedding.weight.device
818
- dtype = x.dtype
819
- if self.freqs.device != device and torch.device(type="meta") != device:
820
- self.freqs = self.freqs.to(device)
821
-
822
- if y is not None:
823
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
824
-
825
- # embeddings
826
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
827
- # add control adapter
828
- if self.control_adapter is not None and y_camera is not None:
829
- y_camera = self.control_adapter(y_camera)
830
- x = [u + v for u, v in zip(x, y_camera)]
831
-
832
- grid_sizes = torch.stack(
833
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
834
-
835
- x = [u.flatten(2).transpose(1, 2) for u in x]
836
- if self.ref_conv is not None and full_ref is not None:
837
- full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
838
- grid_sizes = torch.stack([torch.tensor([u[0] + 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
839
- seq_len += full_ref.size(1)
840
- x = [torch.concat([_full_ref.unsqueeze(0), u], dim=1) for _full_ref, u in zip(full_ref, x)]
841
- if t.dim() != 1 and t.size(1) < seq_len:
842
- pad_size = seq_len - t.size(1)
843
- last_elements = t[:, -1].unsqueeze(1)
844
- padding = last_elements.repeat(1, pad_size)
845
- t = torch.cat([padding, t], dim=1)
846
-
847
- if subject_ref is not None:
848
- subject_ref_frames = subject_ref.size(2)
849
- subject_ref = self.patch_embedding(subject_ref).flatten(2).transpose(1, 2)
850
- grid_sizes = torch.stack([torch.tensor([u[0] + subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
851
- seq_len += subject_ref.size(1)
852
- x = [torch.concat([u, _subject_ref.unsqueeze(0)], dim=1) for _subject_ref, u in zip(subject_ref, x)]
853
- if t.dim() != 1 and t.size(1) < seq_len:
854
- pad_size = seq_len - t.size(1)
855
- last_elements = t[:, -1].unsqueeze(1)
856
- padding = last_elements.repeat(1, pad_size)
857
- t = torch.cat([t, padding], dim=1)
858
-
859
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
860
- if self.sp_world_size > 1:
861
- seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
862
- assert seq_lens.max() <= seq_len
863
- x = torch.cat([
864
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
865
- dim=1) for u in x
866
- ])
867
-
868
- # time embeddings
869
- with amp.autocast(dtype=torch.float32):
870
- if t.dim() != 1:
871
- if t.size(1) < seq_len:
872
- pad_size = seq_len - t.size(1)
873
- last_elements = t[:, -1].unsqueeze(1)
874
- padding = last_elements.repeat(1, pad_size)
875
- t = torch.cat([t, padding], dim=1)
876
- bt = t.size(0)
877
- ft = t.flatten()
878
- e = self.time_embedding(
879
- sinusoidal_embedding_1d(self.freq_dim,
880
- ft).unflatten(0, (bt, seq_len)).float())
881
- e0 = self.time_projection(e).unflatten(2, (6, self.dim))
882
- else:
883
- e = self.time_embedding(
884
- sinusoidal_embedding_1d(self.freq_dim, t).float())
885
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
886
-
887
- # assert e.dtype == torch.float32 and e0.dtype == torch.float32
888
- # e0 = e0.to(dtype)
889
- # e = e.to(dtype)
890
-
891
- # context
892
- context_lens = None
893
- context = self.text_embedding(
894
- torch.stack([
895
- torch.cat(
896
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
897
- for u in context
898
- ]))
899
-
900
- if clip_fea is not None:
901
- context_clip = self.img_emb(clip_fea) # bs x 257 x dim
902
- context = torch.concat([context_clip, context], dim=1)
903
-
904
- # Context Parallel
905
- if self.sp_world_size > 1:
906
- x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
907
- if t.dim() != 1:
908
- e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
909
- e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
910
-
911
- # TeaCache
912
- if self.teacache is not None:
913
- if cond_flag:
914
- if t.dim() != 1:
915
- modulated_inp = e0[:, -1, :]
916
- else:
917
- modulated_inp = e0
918
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
919
- if skip_flag:
920
- self.should_calc = True
921
- self.teacache.accumulated_rel_l1_distance = 0
922
- else:
923
- if cond_flag:
924
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
925
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
926
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
927
- self.should_calc = False
928
- else:
929
- self.should_calc = True
930
- self.teacache.accumulated_rel_l1_distance = 0
931
- self.teacache.previous_modulated_input = modulated_inp
932
- self.teacache.should_calc = self.should_calc
933
- else:
934
- self.should_calc = self.teacache.should_calc
935
-
936
- # TeaCache
937
- if self.teacache is not None:
938
- if not self.should_calc:
939
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
940
- x = x + previous_residual.to(x.device)[-x.size()[0]:,]
941
- else:
942
- ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
943
-
944
- for block in self.blocks:
945
- if torch.is_grad_enabled() and self.gradient_checkpointing:
946
-
947
- def create_custom_forward(module):
948
- def custom_forward(*inputs):
949
- return module(*inputs)
950
-
951
- return custom_forward
952
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
953
- x = torch.utils.checkpoint.checkpoint(
954
- create_custom_forward(block),
955
- x,
956
- e0,
957
- seq_lens,
958
- grid_sizes,
959
- self.freqs,
960
- context,
961
- context_lens,
962
- dtype,
963
- t,
964
- **ckpt_kwargs,
965
- )
966
- else:
967
- # arguments
968
- kwargs = dict(
969
- e=e0,
970
- seq_lens=seq_lens,
971
- grid_sizes=grid_sizes,
972
- freqs=self.freqs,
973
- context=context,
974
- context_lens=context_lens,
975
- dtype=dtype,
976
- t=t
977
- )
978
- x = block(x, **kwargs)
979
-
980
- if cond_flag:
981
- self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
982
- else:
983
- self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
984
- else:
985
- for block in self.blocks:
986
- if torch.is_grad_enabled() and self.gradient_checkpointing:
987
-
988
- def create_custom_forward(module):
989
- def custom_forward(*inputs):
990
- return module(*inputs)
991
-
992
- return custom_forward
993
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
994
- x = torch.utils.checkpoint.checkpoint(
995
- create_custom_forward(block),
996
- x,
997
- e0,
998
- seq_lens,
999
- grid_sizes,
1000
- self.freqs,
1001
- context,
1002
- context_lens,
1003
- dtype,
1004
- t,
1005
- **ckpt_kwargs,
1006
- )
1007
- else:
1008
- # arguments
1009
- kwargs = dict(
1010
- e=e0,
1011
- seq_lens=seq_lens,
1012
- grid_sizes=grid_sizes,
1013
- freqs=self.freqs,
1014
- context=context,
1015
- context_lens=context_lens,
1016
- dtype=dtype,
1017
- t=t
1018
- )
1019
- x = block(x, **kwargs)
1020
-
1021
- # head
1022
- if torch.is_grad_enabled() and self.gradient_checkpointing:
1023
- def create_custom_forward(module):
1024
- def custom_forward(*inputs):
1025
- return module(*inputs)
1026
-
1027
- return custom_forward
1028
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1029
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
1030
- else:
1031
- x = self.head(x, e)
1032
-
1033
- if self.sp_world_size > 1:
1034
- x = self.all_gather(x, dim=1)
1035
-
1036
- if self.ref_conv is not None and full_ref is not None:
1037
- full_ref_length = full_ref.size(1)
1038
- x = x[:, full_ref_length:]
1039
- grid_sizes = torch.stack([torch.tensor([u[0] - 1, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1040
-
1041
- if subject_ref is not None:
1042
- subject_ref_length = subject_ref.size(1)
1043
- x = x[:, :-subject_ref_length]
1044
- grid_sizes = torch.stack([torch.tensor([u[0] - subject_ref_frames, u[1], u[2]]) for u in grid_sizes]).to(grid_sizes.device)
1045
-
1046
- # unpatchify
1047
- x = self.unpatchify(x, grid_sizes)
1048
- x = torch.stack(x)
1049
- if self.teacache is not None and cond_flag:
1050
- self.teacache.cnt += 1
1051
- if self.teacache.cnt == self.teacache.num_steps:
1052
- self.teacache.reset()
1053
- return x
1054
-
1055
-
1056
- def unpatchify(self, x, grid_sizes):
1057
- r"""
1058
- Reconstruct video tensors from patch embeddings.
1059
-
1060
- Args:
1061
- x (List[Tensor]):
1062
- List of patchified features, each with shape [L, C_out * prod(patch_size)]
1063
- grid_sizes (Tensor):
1064
- Original spatial-temporal grid dimensions before patching,
1065
- shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
1066
-
1067
- Returns:
1068
- List[Tensor]:
1069
- Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
1070
- """
1071
-
1072
- c = self.out_dim
1073
- out = []
1074
- for u, v in zip(x, grid_sizes.tolist()):
1075
- u = u[:math.prod(v)].view(*v, *self.patch_size, c)
1076
- u = torch.einsum('fhwpqrc->cfphqwr', u)
1077
- u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
1078
- out.append(u)
1079
- return out
1080
-
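
unpatchify undoes the Conv3d patch embedding on the output side: each token carries prod(patch_size) * C_out values that are folded back onto the (F, H, W) grid. A shape-only sketch of that fold, assuming out_dim=16 and the default (1, 2, 2) patch size:

import math
import torch

patch_size, c_out = (1, 2, 2), 16
f, h, w = 21, 30, 52                        # latent grid sizes after patching
tokens = torch.randn(f * h * w, math.prod(patch_size) * c_out)

u = tokens.view(f, h, w, *patch_size, c_out)
u = torch.einsum('fhwpqrc->cfphqwr', u)
video = u.reshape(c_out, f * patch_size[0], h * patch_size[1], w * patch_size[2])
print(video.shape)  # torch.Size([16, 21, 60, 104])
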
1081
- def init_weights(self):
1082
- r"""
1083
- Initialize model parameters using Xavier initialization.
1084
- """
1085
-
1086
- # basic init
1087
- for m in self.modules():
1088
- if isinstance(m, nn.Linear):
1089
- nn.init.xavier_uniform_(m.weight)
1090
- if m.bias is not None:
1091
- nn.init.zeros_(m.bias)
1092
-
1093
- # init embeddings
1094
- nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
1095
- for m in self.text_embedding.modules():
1096
- if isinstance(m, nn.Linear):
1097
- nn.init.normal_(m.weight, std=.02)
1098
- for m in self.time_embedding.modules():
1099
- if isinstance(m, nn.Linear):
1100
- nn.init.normal_(m.weight, std=.02)
1101
-
1102
- # init output layer
1103
- nn.init.zeros_(self.head.head.weight)
1104
-
1105
- @classmethod
1106
- def from_pretrained(
1107
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
1108
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
1109
- ):
1110
- if subfolder is not None:
1111
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
1112
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
1113
-
1114
- config_file = os.path.join(pretrained_model_path, 'config.json')
1115
- if not os.path.isfile(config_file):
1116
- raise RuntimeError(f"{config_file} does not exist")
1117
- with open(config_file, "r") as f:
1118
- config = json.load(f)
1119
-
1120
- from diffusers.utils import WEIGHTS_NAME
1121
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
1122
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
1123
-
1124
- if "dict_mapping" in transformer_additional_kwargs.keys():
1125
- for key in transformer_additional_kwargs["dict_mapping"]:
1126
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
1127
-
1128
- if low_cpu_mem_usage:
1129
- try:
1130
- import re
1131
-
1132
- from diffusers import __version__ as diffusers_version
1133
- if diffusers_version >= "0.33.0":
1134
- from diffusers.models.model_loading_utils import \
1135
- load_model_dict_into_meta
1136
- else:
1137
- from diffusers.models.modeling_utils import \
1138
- load_model_dict_into_meta
1139
- from diffusers.utils import is_accelerate_available
1140
- if is_accelerate_available():
1141
- import accelerate
1142
-
1143
- # Instantiate model with empty weights
1144
- with accelerate.init_empty_weights():
1145
- model = cls.from_config(config, **transformer_additional_kwargs)
1146
-
1147
- param_device = "cpu"
1148
- if os.path.exists(model_file):
1149
- state_dict = torch.load(model_file, map_location="cpu")
1150
- elif os.path.exists(model_file_safetensors):
1151
- from safetensors.torch import load_file, safe_open
1152
- state_dict = load_file(model_file_safetensors)
1153
- else:
1154
- from safetensors.torch import load_file, safe_open
1155
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1156
- state_dict = {}
1157
- print(model_files_safetensors)
1158
- for _model_file_safetensors in model_files_safetensors:
1159
- _state_dict = load_file(_model_file_safetensors)
1160
- for key in _state_dict:
1161
- state_dict[key] = _state_dict[key]
1162
-
1163
- if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1164
- model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
1165
- model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1166
- state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1167
-
1168
- filtered_state_dict = {}
1169
- for key in state_dict:
1170
- if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
1171
- filtered_state_dict[key] = state_dict[key]
1172
- else:
1173
- print(f"Skipping key '{key}' due to size mismatch or absence in model.")
1174
-
1175
- model_keys = set(model.state_dict().keys())
1176
- loaded_keys = set(filtered_state_dict.keys())
1177
- missing_keys = model_keys - loaded_keys
1178
-
1179
- def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
1180
- initialized_dict = {}
1181
-
1182
- with torch.no_grad():
1183
- for key in missing_keys:
1184
- param_shape = model_state_dict[key].shape
1185
- param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
1186
- if 'weight' in key:
1187
- if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
1188
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1189
- elif 'embedding' in key or 'embed' in key:
1190
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1191
- elif 'head' in key or 'output' in key or 'proj_out' in key:
1192
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1193
- elif len(param_shape) >= 2:
1194
- initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
1195
- nn.init.xavier_uniform_(initialized_dict[key])
1196
- else:
1197
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
1198
- elif 'bias' in key:
1199
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1200
- elif 'running_mean' in key:
1201
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1202
- elif 'running_var' in key:
1203
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
1204
- elif 'num_batches_tracked' in key:
1205
- initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
1206
- else:
1207
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
1208
-
1209
- return initialized_dict
1210
-
1211
- if missing_keys:
1212
- print(f"Missing keys will be initialized: {sorted(missing_keys)}")
1213
- initialized_params = initialize_missing_parameters(
1214
- missing_keys,
1215
- model.state_dict(),
1216
- torch_dtype
1217
- )
1218
- filtered_state_dict.update(initialized_params)
1219
-
1220
- if diffusers_version >= "0.33.0":
1221
- # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
1222
- # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
1223
- load_model_dict_into_meta(
1224
- model,
1225
- filtered_state_dict,
1226
- dtype=torch_dtype,
1227
- model_name_or_path=pretrained_model_path,
1228
- )
1229
- else:
1230
- model._convert_deprecated_attention_blocks(filtered_state_dict)
1231
- unexpected_keys = load_model_dict_into_meta(
1232
- model,
1233
- filtered_state_dict,
1234
- device=param_device,
1235
- dtype=torch_dtype,
1236
- model_name_or_path=pretrained_model_path,
1237
- )
1238
-
1239
- if cls._keys_to_ignore_on_load_unexpected is not None:
1240
- for pat in cls._keys_to_ignore_on_load_unexpected:
1241
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1242
-
1243
- if len(unexpected_keys) > 0:
1244
- print(
1245
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1246
- )
1247
-
1248
- return model
1249
- except Exception as e:
1250
- print(
1251
- f"The low_cpu_mem_usage mode is not work because {e}. Use low_cpu_mem_usage=False instead."
1252
- )
1253
-
1254
- model = cls.from_config(config, **transformer_additional_kwargs)
1255
- if os.path.exists(model_file):
1256
- state_dict = torch.load(model_file, map_location="cpu")
1257
- elif os.path.exists(model_file_safetensors):
1258
- from safetensors.torch import load_file, safe_open
1259
- state_dict = load_file(model_file_safetensors)
1260
- else:
1261
- from safetensors.torch import load_file, safe_open
1262
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
1263
- state_dict = {}
1264
- for _model_file_safetensors in model_files_safetensors:
1265
- _state_dict = load_file(_model_file_safetensors)
1266
- for key in _state_dict:
1267
- state_dict[key] = _state_dict[key]
1268
-
1269
- if model.state_dict()['patch_embedding.weight'].size() != state_dict['patch_embedding.weight'].size():
1270
- model.state_dict()['patch_embedding.weight'][:, :state_dict['patch_embedding.weight'].size()[1], :, :] = state_dict['patch_embedding.weight'][:, :model.state_dict()['patch_embedding.weight'].size()[1], :, :]
1271
- model.state_dict()['patch_embedding.weight'][:, state_dict['patch_embedding.weight'].size()[1]:, :, :] = 0
1272
- state_dict['patch_embedding.weight'] = model.state_dict()['patch_embedding.weight']
1273
-
1274
- tmp_state_dict = {}
1275
- for key in state_dict:
1276
- if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
1277
- tmp_state_dict[key] = state_dict[key]
1278
- else:
1279
- print(key, "Size don't match, skip")
1280
-
1281
- state_dict = tmp_state_dict
1282
-
1283
- m, u = model.load_state_dict(state_dict, strict=False)
1284
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1285
- print(m)
1286
-
1287
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
1288
- print(f"### All Parameters: {sum(params) / 1e6} M")
1289
-
1290
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
1291
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
1292
-
1293
- model = model.to(torch_dtype)
1294
- return model
1295
-
1296
-
1297
- class Wan2_2Transformer3DModel(WanTransformer3DModel):
1298
- r"""
1299
- Wan diffusion backbone supporting both text-to-video and image-to-video.
1300
- """
1301
-
1302
- # ignore_for_config = [
1303
- # 'patch_size', 'cross_attn_norm', 'qk_norm', 'text_dim', 'window_size'
1304
- # ]
1305
- # _no_split_modules = ['WanAttentionBlock']
1306
- _supports_gradient_checkpointing = True
1307
-
1308
- def __init__(
1309
- self,
1310
- model_type='t2v',
1311
- patch_size=(1, 2, 2),
1312
- text_len=512,
1313
- in_dim=16,
1314
- dim=2048,
1315
- ffn_dim=8192,
1316
- freq_dim=256,
1317
- text_dim=4096,
1318
- out_dim=16,
1319
- num_heads=16,
1320
- num_layers=32,
1321
- window_size=(-1, -1),
1322
- qk_norm=True,
1323
- cross_attn_norm=True,
1324
- eps=1e-6,
1325
- in_channels=16,
1326
- hidden_size=2048,
1327
- add_control_adapter=False,
1328
- in_dim_control_adapter=24,
1329
- downscale_factor_control_adapter=8,
1330
- add_ref_conv=False,
1331
- in_dim_ref_conv=16,
1332
- ):
1333
- r"""
1334
- Initialize the diffusion model backbone.
1335
- Args:
1336
- model_type (`str`, *optional*, defaults to 't2v'):
1337
- Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
1338
- patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
1339
- 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
1340
- text_len (`int`, *optional*, defaults to 512):
1341
- Fixed length for text embeddings
1342
- in_dim (`int`, *optional*, defaults to 16):
1343
- Input video channels (C_in)
1344
- dim (`int`, *optional*, defaults to 2048):
1345
- Hidden dimension of the transformer
1346
- ffn_dim (`int`, *optional*, defaults to 8192):
1347
- Intermediate dimension in feed-forward network
1348
- freq_dim (`int`, *optional*, defaults to 256):
1349
- Dimension for sinusoidal time embeddings
1350
- text_dim (`int`, *optional*, defaults to 4096):
1351
- Input dimension for text embeddings
1352
- out_dim (`int`, *optional*, defaults to 16):
1353
- Output video channels (C_out)
1354
- num_heads (`int`, *optional*, defaults to 16):
1355
- Number of attention heads
1356
- num_layers (`int`, *optional*, defaults to 32):
1357
- Number of transformer blocks
1358
- window_size (`tuple`, *optional*, defaults to (-1, -1)):
1359
- Window size for local attention (-1 indicates global attention)
1360
- qk_norm (`bool`, *optional*, defaults to True):
1361
- Enable query/key normalization
1362
- cross_attn_norm (`bool`, *optional*, defaults to False):
1363
- Enable cross-attention normalization
1364
- eps (`float`, *optional*, defaults to 1e-6):
1365
- Epsilon value for normalization layers
1366
- """
1367
- super().__init__(
1368
- model_type=model_type,
1369
- patch_size=patch_size,
1370
- text_len=text_len,
1371
- in_dim=in_dim,
1372
- dim=dim,
1373
- ffn_dim=ffn_dim,
1374
- freq_dim=freq_dim,
1375
- text_dim=text_dim,
1376
- out_dim=out_dim,
1377
- num_heads=num_heads,
1378
- num_layers=num_layers,
1379
- window_size=window_size,
1380
- qk_norm=qk_norm,
1381
- cross_attn_norm=cross_attn_norm,
1382
- eps=eps,
1383
- in_channels=in_channels,
1384
- hidden_size=hidden_size,
1385
- add_control_adapter=add_control_adapter,
1386
- in_dim_control_adapter=in_dim_control_adapter,
1387
- downscale_factor_control_adapter=downscale_factor_control_adapter,
1388
- add_ref_conv=add_ref_conv,
1389
- in_dim_ref_conv=in_dim_ref_conv,
1390
- cross_attn_type="cross_attn"
1391
- )
1392
-
1393
- if hasattr(self, "img_emb"):
1394
- del self.img_emb
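For reference, a minimal usage sketch of the `from_pretrained` helper removed above (this sketch is not part of the diff; the checkpoint directory and subfolder layout are assumptions, and the import reflects the pre-commit module path):

import torch
from videox_fun.models.wan_transformer3d import Wan2_2Transformer3DModel  # module removed by this commit

transformer = Wan2_2Transformer3DModel.from_pretrained(
    "models/Wan2.2-T2V",         # assumed local checkpoint dir containing config.json and *.safetensors
    subfolder="transformer",     # assumed repository layout
    low_cpu_mem_usage=True,      # meta-device init, weights filled via load_model_dict_into_meta
    torch_dtype=torch.bfloat16,
)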
videox_fun/models/wan_transformer3d_animate.py DELETED
@@ -1,302 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import math
3
- import types
4
- from copy import deepcopy
5
- from typing import Any, Dict, List
6
-
7
- import numpy as np
8
- import torch
9
- import torch.cuda.amp as amp
10
- import torch.nn as nn
11
- from diffusers.configuration_utils import ConfigMixin, register_to_config
12
- from diffusers.loaders import PeftAdapterMixin
13
- from diffusers.models.modeling_utils import ModelMixin
14
- from diffusers.utils import is_torch_version, logging
15
- from einops import rearrange
16
-
17
- from .attention_utils import attention
18
- from .wan_animate_adapter import FaceAdapter, FaceEncoder
19
- from .wan_animate_motion_encoder import Generator
20
- from .wan_transformer3d import (Head, MLPProj, WanAttentionBlock, WanLayerNorm,
21
- WanRMSNorm, WanSelfAttention,
22
- WanTransformer3DModel, rope_apply,
23
- sinusoidal_embedding_1d)
24
- from ..utils import cfg_skip
25
-
26
-
27
- class Wan2_2Transformer3DModel_Animate(WanTransformer3DModel):
28
- # _no_split_modules = ['WanAnimateAttentionBlock']
29
- _supports_gradient_checkpointing = True
30
-
31
- @register_to_config
32
- def __init__(
33
- self,
34
- patch_size=(1, 2, 2),
35
- text_len=512,
36
- in_dim=36,
37
- dim=5120,
38
- ffn_dim=13824,
39
- freq_dim=256,
40
- text_dim=4096,
41
- out_dim=16,
42
- num_heads=40,
43
- num_layers=40,
44
- window_size=(-1, -1),
45
- qk_norm=True,
46
- cross_attn_norm=True,
47
- eps=1e-6,
48
- motion_encoder_dim=512,
49
- use_img_emb=True
50
- ):
51
- model_type = "i2v" # TODO: Hard code for both preview and official versions.
52
- super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
53
- num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
54
-
55
- self.motion_encoder_dim = motion_encoder_dim
56
- self.use_img_emb = use_img_emb
57
-
58
- self.pose_patch_embedding = nn.Conv3d(
59
- 16, dim, kernel_size=patch_size, stride=patch_size
60
- )
61
-
62
- # initialize weights
63
- self.init_weights()
64
-
65
- self.motion_encoder = Generator(size=512, style_dim=512, motion_dim=20)
66
- self.face_adapter = FaceAdapter(
67
- heads_num=self.num_heads,
68
- hidden_dim=self.dim,
69
- num_adapter_layers=self.num_layers // 5,
70
- )
71
-
72
- self.face_encoder = FaceEncoder(
73
- in_dim=motion_encoder_dim,
74
- hidden_dim=self.dim,
75
- num_heads=4,
76
- )
77
-
78
- def after_patch_embedding(self, x: List[torch.Tensor], pose_latents, face_pixel_values):
79
- pose_latents = [self.pose_patch_embedding(u.unsqueeze(0)) for u in pose_latents]
80
- for x_, pose_latents_ in zip(x, pose_latents):
81
- x_[:, :, 1:] += pose_latents_
82
-
83
- b,c,T,h,w = face_pixel_values.shape
84
- face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
85
-
86
- encode_bs = 8
87
- face_pixel_values_tmp = []
88
- for i in range(math.ceil(face_pixel_values.shape[0]/encode_bs)):
89
- face_pixel_values_tmp.append(self.motion_encoder.get_motion(face_pixel_values[i*encode_bs:(i+1)*encode_bs]))
90
-
91
- motion_vec = torch.cat(face_pixel_values_tmp)
92
-
93
- motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
94
- motion_vec = self.face_encoder(motion_vec)
95
-
96
- B, L, H, C = motion_vec.shape
97
- pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
98
- motion_vec = torch.cat([pad_face, motion_vec], dim=1)
99
- return x, motion_vec
100
-
101
- def after_transformer_block(self, block_idx, x, motion_vec, motion_masks=None):
102
- if block_idx % 5 == 0:
103
- use_context_parallel = self.sp_world_size > 1
104
- adapter_args = [x, motion_vec, motion_masks, use_context_parallel, self.all_gather, self.sp_world_size, self.sp_world_rank]
105
- residual_out = self.face_adapter.fuser_blocks[block_idx // 5](*adapter_args)
106
- x = residual_out + x
107
- return x
108
-
109
- @cfg_skip()
110
- def forward(
111
- self,
112
- x,
113
- t,
114
- clip_fea,
115
- context,
116
- seq_len,
117
- y=None,
118
- pose_latents=None,
119
- face_pixel_values=None,
120
- cond_flag=True
121
- ):
122
- # params
123
- device = self.patch_embedding.weight.device
124
- dtype = x.dtype
125
- if self.freqs.device != device and torch.device(type="meta") != device:
126
- self.freqs = self.freqs.to(device)
127
-
128
- if y is not None:
129
- x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
130
-
131
- # embeddings
132
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
133
- x, motion_vec = self.after_patch_embedding(x, pose_latents, face_pixel_values)
134
-
135
- grid_sizes = torch.stack(
136
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
137
- x = [u.flatten(2).transpose(1, 2) for u in x]
138
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
139
- if self.sp_world_size > 1:
140
- seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
141
- assert seq_lens.max() <= seq_len
142
- x = torch.cat([
143
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
144
- dim=1) for u in x
145
- ])
146
-
147
- # time embeddings
148
- with amp.autocast(dtype=torch.float32):
149
- e = self.time_embedding(
150
- sinusoidal_embedding_1d(self.freq_dim, t).float()
151
- )
152
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
153
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
154
-
155
- # context
156
- context_lens = None
157
- context = self.text_embedding(
158
- torch.stack([
159
- torch.cat(
160
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
161
- for u in context
162
- ]))
163
-
164
- if self.use_img_emb:
165
- context_clip = self.img_emb(clip_fea) # bs x 257 x dim
166
- context = torch.concat([context_clip, context], dim=1)
167
-
168
- # Context Parallel
169
- if self.sp_world_size > 1:
170
- x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
171
- if t.dim() != 1:
172
- e0 = torch.chunk(e0, self.sp_world_size, dim=1)[self.sp_world_rank]
173
- e = torch.chunk(e, self.sp_world_size, dim=1)[self.sp_world_rank]
174
-
175
- # TeaCache
176
- if self.teacache is not None:
177
- if cond_flag:
178
- if t.dim() != 1:
179
- modulated_inp = e0[0][:, -1, :]
180
- else:
181
- modulated_inp = e0[0]
182
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
183
- if skip_flag:
184
- self.should_calc = True
185
- self.teacache.accumulated_rel_l1_distance = 0
186
- else:
187
- if cond_flag:
188
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
189
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
190
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
191
- self.should_calc = False
192
- else:
193
- self.should_calc = True
194
- self.teacache.accumulated_rel_l1_distance = 0
195
- self.teacache.previous_modulated_input = modulated_inp
196
- self.teacache.should_calc = self.should_calc
197
- else:
198
- self.should_calc = self.teacache.should_calc
199
-
200
- # TeaCache
201
- if self.teacache is not None:
202
- if not self.should_calc:
203
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
204
- x = x + previous_residual.to(x.device)[-x.size()[0]:,]
205
- else:
206
- ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
207
- for idx, block in enumerate(self.blocks):
208
- if torch.is_grad_enabled() and self.gradient_checkpointing:
209
-
210
- def create_custom_forward(module):
211
- def custom_forward(*inputs):
212
- return module(*inputs)
213
-
214
- return custom_forward
215
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
216
- x = torch.utils.checkpoint.checkpoint(
217
- create_custom_forward(block),
218
- x,
219
- e0,
220
- seq_lens,
221
- grid_sizes,
222
- self.freqs,
223
- context,
224
- context_lens,
225
- dtype,
226
- t,
227
- **ckpt_kwargs,
228
- )
229
- x, motion_vec = x.to(dtype), motion_vec.to(dtype)
230
- x = self.after_transformer_block(idx, x, motion_vec)
231
- else:
232
- # arguments
233
- kwargs = dict(
234
- e=e0,
235
- seq_lens=seq_lens,
236
- grid_sizes=grid_sizes,
237
- freqs=self.freqs,
238
- context=context,
239
- context_lens=context_lens,
240
- dtype=dtype,
241
- t=t
242
- )
243
- x = block(x, **kwargs)
244
- x, motion_vec = x.to(dtype), motion_vec.to(dtype)
245
- x = self.after_transformer_block(idx, x, motion_vec)
246
-
247
- if cond_flag:
248
- self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
249
- else:
250
- self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
251
- else:
252
- for idx, block in enumerate(self.blocks):
253
- if torch.is_grad_enabled() and self.gradient_checkpointing:
254
-
255
- def create_custom_forward(module):
256
- def custom_forward(*inputs):
257
- return module(*inputs)
258
-
259
- return custom_forward
260
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
261
- x = torch.utils.checkpoint.checkpoint(
262
- create_custom_forward(block),
263
- x,
264
- e0,
265
- seq_lens,
266
- grid_sizes,
267
- self.freqs,
268
- context,
269
- context_lens,
270
- dtype,
271
- t,
272
- **ckpt_kwargs,
273
- )
274
- x, motion_vec = x.to(dtype), motion_vec.to(dtype)
275
- x = self.after_transformer_block(idx, x, motion_vec)
276
- else:
277
- # arguments
278
- kwargs = dict(
279
- e=e0,
280
- seq_lens=seq_lens,
281
- grid_sizes=grid_sizes,
282
- freqs=self.freqs,
283
- context=context,
284
- context_lens=context_lens,
285
- dtype=dtype,
286
- t=t
287
- )
288
- x = block(x, **kwargs)
289
- x, motion_vec = x.to(dtype), motion_vec.to(dtype)
290
- x = self.after_transformer_block(idx, x, motion_vec)
291
-
292
- # head
293
- x = self.head(x, e)
294
-
295
- # Context Parallel
296
- if self.sp_world_size > 1:
297
- x = self.all_gather(x.contiguous(), dim=1)
298
-
299
- # unpatchify
300
- x = self.unpatchify(x, grid_sizes)
301
- x = torch.stack(x)
302
- return x
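The forward pass above interleaves the face adapter with the main transformer: every fifth block, `after_transformer_block` adds a residual computed from the motion vector. A minimal self-contained sketch of that pattern (illustrative only; the module below is not from this repository):

import torch
import torch.nn as nn

class AdapterEveryNBlocks(nn.Module):
    """Run the main blocks and fuse an adapter residual every `stride` blocks,
    mirroring after_transformer_block(block_idx % 5 == 0) above."""
    def __init__(self, dim=64, num_blocks=10, stride=5):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks))
        self.adapters = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_blocks // stride))
        self.stride = stride

    def forward(self, x, motion_vec):
        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx % self.stride == 0:
                # residual injection of the motion signal
                x = x + self.adapters[idx // self.stride](x + motion_vec)
        return x

demo = AdapterEveryNBlocks()
out = demo(torch.randn(2, 16, 64), torch.randn(2, 16, 64))  # [batch, tokens, dim]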
videox_fun/models/wan_transformer3d_s2v.py DELETED
@@ -1,932 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.2/blob/main/wan/modules/s2v/model_s2v.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
-
4
- import math
5
- import types
6
- from copy import deepcopy
7
- from typing import Any, Dict
8
-
9
- import torch
10
- import torch.cuda.amp as amp
11
- import torch.nn as nn
12
- from diffusers.configuration_utils import register_to_config
13
- from diffusers.utils import is_torch_version
14
- from einops import rearrange
15
-
16
- from ..dist import (get_sequence_parallel_rank,
17
- get_sequence_parallel_world_size, get_sp_group,
18
- usp_attn_s2v_forward)
19
- from .attention_utils import attention
20
- from .wan_audio_injector import (AudioInjector_WAN, CausalAudioEncoder,
21
- FramePackMotioner, MotionerTransformers,
22
- rope_precompute)
23
- from .wan_transformer3d import (Wan2_2Transformer3DModel, WanAttentionBlock,
24
- WanLayerNorm, WanSelfAttention,
25
- sinusoidal_embedding_1d)
26
- from ..utils import cfg_skip
27
-
28
-
29
- def zero_module(module):
30
- """
31
- Zero out the parameters of a module and return it.
32
- """
33
- for p in module.parameters():
34
- p.detach().zero_()
35
- return module
36
-
37
-
38
- def torch_dfs(model: nn.Module, parent_name='root'):
39
- module_names, modules = [], []
40
- current_name = parent_name if parent_name else 'root'
41
- module_names.append(current_name)
42
- modules.append(model)
43
-
44
- for name, child in model.named_children():
45
- if parent_name:
46
- child_name = f'{parent_name}.{name}'
47
- else:
48
- child_name = name
49
- child_modules, child_names = torch_dfs(child, child_name)
50
- module_names += child_names
51
- modules += child_modules
52
- return modules, module_names
53
-
54
-
55
- @amp.autocast(enabled=False)
56
- @torch.compiler.disable()
57
- def s2v_rope_apply(x, grid_sizes, freqs, start=None):
58
- n, c = x.size(2), x.size(3) // 2
59
- # loop over samples
60
- output = []
61
- for i, _ in enumerate(x):
62
- s = x.size(1)
63
- x_i = torch.view_as_complex(x[i, :s].to(torch.float64).reshape(s, n, -1, 2))
64
- freqs_i = freqs[i, :s]
65
- # apply rotary embedding
66
- x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
67
- x_i = torch.cat([x_i, x[i, s:]])
68
- # append to collection
69
- output.append(x_i)
70
- return torch.stack(output).float()
71
-
72
-
73
- def s2v_rope_apply_qk(q, k, grid_sizes, freqs):
74
- q = s2v_rope_apply(q, grid_sizes, freqs)
75
- k = s2v_rope_apply(k, grid_sizes, freqs)
76
- return q, k
77
-
78
-
79
- class WanS2VSelfAttention(WanSelfAttention):
80
-
81
- def forward(self, x, seq_lens, grid_sizes, freqs, dtype=torch.bfloat16, t=0):
82
- """
83
- Args:
84
- x(Tensor): Shape [B, L, num_heads, C / num_heads]
85
- seq_lens(Tensor): Shape [B]
86
- grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
87
- freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
88
- """
89
- b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
90
-
91
- # query, key, value function
92
- def qkv_fn(x):
93
- q = self.norm_q(self.q(x)).view(b, s, n, d)
94
- k = self.norm_k(self.k(x)).view(b, s, n, d)
95
- v = self.v(x).view(b, s, n, d)
96
- return q, k, v
97
-
98
- q, k, v = qkv_fn(x)
99
-
100
- q, k = s2v_rope_apply_qk(q, k, grid_sizes, freqs)
101
-
102
- x = attention(
103
- q.to(dtype),
104
- k.to(dtype),
105
- v=v.to(dtype),
106
- k_lens=seq_lens,
107
- window_size=self.window_size)
108
- x = x.to(dtype)
109
-
110
- # output
111
- x = x.flatten(2)
112
- x = self.o(x)
113
- return x
114
-
115
-
116
- class WanS2VAttentionBlock(WanAttentionBlock):
117
-
118
- def __init__(self,
119
- cross_attn_type,
120
- dim,
121
- ffn_dim,
122
- num_heads,
123
- window_size=(-1, -1),
124
- qk_norm=True,
125
- cross_attn_norm=False,
126
- eps=1e-6):
127
- super().__init__(
128
- cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps
129
- )
130
- self.self_attn = WanS2VSelfAttention(dim, num_heads, window_size,qk_norm, eps)
131
-
132
- def forward(self, x, e, seq_lens, grid_sizes, freqs, context, context_lens, dtype=torch.bfloat16, t=0):
133
- # e
134
- seg_idx = e[1].item()
135
- seg_idx = min(max(0, seg_idx), x.size(1))
136
- seg_idx = [0, seg_idx, x.size(1)]
137
- e = e[0]
138
- modulation = self.modulation.unsqueeze(2)
139
- e = (modulation + e).chunk(6, dim=1)
140
- e = [element.squeeze(1) for element in e]
141
-
142
- # norm
143
- norm_x = self.norm1(x).float()
144
- parts = []
145
- for i in range(2):
146
- parts.append(norm_x[:, seg_idx[i]:seg_idx[i + 1]] *
147
- (1 + e[1][:, i:i + 1]) + e[0][:, i:i + 1])
148
- norm_x = torch.cat(parts, dim=1)
149
- # self-attention
150
- y = self.self_attn(norm_x, seq_lens, grid_sizes, freqs)
151
- with amp.autocast(dtype=torch.float32):
152
- z = []
153
- for i in range(2):
154
- z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[2][:, i:i + 1])
155
- y = torch.cat(z, dim=1)
156
- x = x + y
157
-
158
- # cross-attention & ffn function
159
- def cross_attn_ffn(x, context, context_lens, e):
160
- x = x + self.cross_attn(self.norm3(x), context, context_lens)
161
- norm2_x = self.norm2(x).float()
162
- parts = []
163
- for i in range(2):
164
- parts.append(norm2_x[:, seg_idx[i]:seg_idx[i + 1]] *
165
- (1 + e[4][:, i:i + 1]) + e[3][:, i:i + 1])
166
- norm2_x = torch.cat(parts, dim=1)
167
- y = self.ffn(norm2_x)
168
- with amp.autocast(dtype=torch.float32):
169
- z = []
170
- for i in range(2):
171
- z.append(y[:, seg_idx[i]:seg_idx[i + 1]] * e[5][:, i:i + 1])
172
- y = torch.cat(z, dim=1)
173
- x = x + y
174
- return x
175
-
176
- x = cross_attn_ffn(x, context, context_lens, e)
177
- return x
178
-
179
-
180
- class Wan2_2Transformer3DModel_S2V(Wan2_2Transformer3DModel):
181
- # ignore_for_config = [
182
- # 'args', 'kwargs', 'patch_size', 'cross_attn_norm', 'qk_norm',
183
- # 'text_dim', 'window_size'
184
- # ]
185
- # _no_split_modules = ['WanS2VAttentionBlock']
186
-
187
- @register_to_config
188
- def __init__(
189
- self,
190
- cond_dim=0,
191
- audio_dim=5120,
192
- num_audio_token=4,
193
- enable_adain=False,
194
- adain_mode="attn_norm",
195
- audio_inject_layers=[0, 4, 8, 12, 16, 20, 24, 27],
196
- zero_init=False,
197
- zero_timestep=False,
198
- enable_motioner=True,
199
- add_last_motion=True,
200
- enable_tsm=False,
201
- trainable_token_pos_emb=False,
202
- motion_token_num=1024,
203
- enable_framepack=False, # Mutually exclusive with enable_motioner
204
- framepack_drop_mode="drop",
205
- model_type='s2v',
206
- patch_size=(1, 2, 2),
207
- text_len=512,
208
- in_dim=16,
209
- dim=2048,
210
- ffn_dim=8192,
211
- freq_dim=256,
212
- text_dim=4096,
213
- out_dim=16,
214
- num_heads=16,
215
- num_layers=32,
216
- window_size=(-1, -1),
217
- qk_norm=True,
218
- cross_attn_norm=True,
219
- eps=1e-6,
220
- in_channels=16,
221
- hidden_size=2048,
222
- *args,
223
- **kwargs
224
- ):
225
- super().__init__(
226
- model_type=model_type,
227
- patch_size=patch_size,
228
- text_len=text_len,
229
- in_dim=in_dim,
230
- dim=dim,
231
- ffn_dim=ffn_dim,
232
- freq_dim=freq_dim,
233
- text_dim=text_dim,
234
- out_dim=out_dim,
235
- num_heads=num_heads,
236
- num_layers=num_layers,
237
- window_size=window_size,
238
- qk_norm=qk_norm,
239
- cross_attn_norm=cross_attn_norm,
240
- eps=eps,
241
- in_channels=in_channels,
242
- hidden_size=hidden_size
243
- )
244
-
245
- assert model_type == 's2v'
246
- self.enbale_adain = enable_adain
247
- # Whether to assign 0 value timestep to ref/motion
248
- self.adain_mode = adain_mode
249
- self.zero_timestep = zero_timestep
250
- self.enable_motioner = enable_motioner
251
- self.add_last_motion = add_last_motion
252
- self.enable_framepack = enable_framepack
253
-
254
- # Replace blocks
255
- self.blocks = nn.ModuleList([
256
- WanS2VAttentionBlock("cross_attn", dim, ffn_dim, num_heads, window_size, qk_norm,
257
- cross_attn_norm, eps)
258
- for _ in range(num_layers)
259
- ])
260
-
261
- # init audio injector
262
- all_modules, all_modules_names = torch_dfs(self.blocks, parent_name="root.transformer_blocks")
263
- if cond_dim > 0:
264
- self.cond_encoder = nn.Conv3d(
265
- cond_dim,
266
- self.dim,
267
- kernel_size=self.patch_size,
268
- stride=self.patch_size)
269
- self.trainable_cond_mask = nn.Embedding(3, self.dim)
270
- self.casual_audio_encoder = CausalAudioEncoder(
271
- dim=audio_dim,
272
- out_dim=self.dim,
273
- num_token=num_audio_token,
274
- need_global=enable_adain)
275
- self.audio_injector = AudioInjector_WAN(
276
- all_modules,
277
- all_modules_names,
278
- dim=self.dim,
279
- num_heads=self.num_heads,
280
- inject_layer=audio_inject_layers,
281
- root_net=self,
282
- enable_adain=enable_adain,
283
- adain_dim=self.dim,
284
- need_adain_ont=adain_mode != "attn_norm",
285
- )
286
-
287
- if zero_init:
288
- self.zero_init_weights()
289
-
290
- # init motioner
291
- if enable_motioner and enable_framepack:
292
- raise ValueError(
293
- "enable_motioner and enable_framepack are mutually exclusive, please set one of them to False"
294
- )
295
- if enable_motioner:
296
- motioner_dim = 2048
297
- self.motioner = MotionerTransformers(
298
- patch_size=(2, 4, 4),
299
- dim=motioner_dim,
300
- ffn_dim=motioner_dim,
301
- freq_dim=256,
302
- out_dim=16,
303
- num_heads=16,
304
- num_layers=13,
305
- window_size=(-1, -1),
306
- qk_norm=True,
307
- cross_attn_norm=False,
308
- eps=1e-6,
309
- motion_token_num=motion_token_num,
310
- enable_tsm=enable_tsm,
311
- motion_stride=4,
312
- expand_ratio=2,
313
- trainable_token_pos_emb=trainable_token_pos_emb,
314
- )
315
- self.zip_motion_out = torch.nn.Sequential(
316
- WanLayerNorm(motioner_dim),
317
- zero_module(nn.Linear(motioner_dim, self.dim)))
318
-
319
- self.trainable_token_pos_emb = trainable_token_pos_emb
320
- if trainable_token_pos_emb:
321
- d = self.dim // self.num_heads
322
- x = torch.zeros([1, motion_token_num, self.num_heads, d])
323
- x[..., ::2] = 1
324
-
325
- gride_sizes = [[
326
- torch.tensor([0, 0, 0]).unsqueeze(0).repeat(1, 1),
327
- torch.tensor([
328
- 1, self.motioner.motion_side_len,
329
- self.motioner.motion_side_len
330
- ]).unsqueeze(0).repeat(1, 1),
331
- torch.tensor([
332
- 1, self.motioner.motion_side_len,
333
- self.motioner.motion_side_len
334
- ]).unsqueeze(0).repeat(1, 1),
335
- ]]
336
- token_freqs = s2v_rope_apply(x, gride_sizes, self.freqs)
337
- token_freqs = token_freqs[0, :,
338
- 0].reshape(motion_token_num, -1, 2)
339
- token_freqs = token_freqs * 0.01
340
- self.token_freqs = torch.nn.Parameter(token_freqs)
341
-
342
- if enable_framepack:
343
- self.frame_packer = FramePackMotioner(
344
- inner_dim=self.dim,
345
- num_heads=self.num_heads,
346
- zip_frame_buckets=[1, 2, 16],
347
- drop_mode=framepack_drop_mode)
348
-
349
- def enable_multi_gpus_inference(self,):
350
- self.sp_world_size = get_sequence_parallel_world_size()
351
- self.sp_world_rank = get_sequence_parallel_rank()
352
- self.all_gather = get_sp_group().all_gather
353
- for block in self.blocks:
354
- block.self_attn.forward = types.MethodType(
355
- usp_attn_s2v_forward, block.self_attn)
356
-
357
- def process_motion(self, motion_latents, drop_motion_frames=False):
358
- if drop_motion_frames or motion_latents[0].shape[1] == 0:
359
- return [], []
360
- self.lat_motion_frames = motion_latents[0].shape[1]
361
- mot = [self.patch_embedding(m.unsqueeze(0)) for m in motion_latents]
362
- batch_size = len(mot)
363
-
364
- mot_remb = []
365
- flattern_mot = []
366
- for bs in range(batch_size):
367
- height, width = mot[bs].shape[3], mot[bs].shape[4]
368
- flat_mot = mot[bs].flatten(2).transpose(1, 2).contiguous()
369
- motion_grid_sizes = [[
370
- torch.tensor([-self.lat_motion_frames, 0,
371
- 0]).unsqueeze(0).repeat(1, 1),
372
- torch.tensor([0, height, width]).unsqueeze(0).repeat(1, 1),
373
- torch.tensor([self.lat_motion_frames, height,
374
- width]).unsqueeze(0).repeat(1, 1)
375
- ]]
376
- motion_rope_emb = rope_precompute(
377
- flat_mot.detach().view(1, flat_mot.shape[1], self.num_heads,
378
- self.dim // self.num_heads),
379
- motion_grid_sizes,
380
- self.freqs,
381
- start=None)
382
- mot_remb.append(motion_rope_emb)
383
- flattern_mot.append(flat_mot)
384
- return flattern_mot, mot_remb
385
-
386
- def process_motion_frame_pack(self,
387
- motion_latents,
388
- drop_motion_frames=False,
389
- add_last_motion=2):
390
- flattern_mot, mot_remb = self.frame_packer(motion_latents,
391
- add_last_motion)
392
- if drop_motion_frames:
393
- return [m[:, :0] for m in flattern_mot
394
- ], [m[:, :0] for m in mot_remb]
395
- else:
396
- return flattern_mot, mot_remb
397
-
398
- def process_motion_transformer_motioner(self,
399
- motion_latents,
400
- drop_motion_frames=False,
401
- add_last_motion=True):
402
- batch_size, height, width = len(
403
- motion_latents), motion_latents[0].shape[2] // self.patch_size[
404
- 1], motion_latents[0].shape[3] // self.patch_size[2]
405
-
406
- freqs = self.freqs
407
- device = self.patch_embedding.weight.device
408
- if freqs.device != device:
409
- freqs = freqs.to(device)
410
- if self.trainable_token_pos_emb:
411
- with amp.autocast(dtype=torch.float64):
412
- token_freqs = self.token_freqs.to(torch.float64)
413
- token_freqs = token_freqs / token_freqs.norm(
414
- dim=-1, keepdim=True)
415
- freqs = [freqs, torch.view_as_complex(token_freqs)]
416
-
417
- if not drop_motion_frames and add_last_motion:
418
- last_motion_latent = [u[:, -1:] for u in motion_latents]
419
- last_mot = [
420
- self.patch_embedding(m.unsqueeze(0)) for m in last_motion_latent
421
- ]
422
- last_mot = [m.flatten(2).transpose(1, 2) for m in last_mot]
423
- last_mot = torch.cat(last_mot)
424
- gride_sizes = [[
425
- torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
426
- torch.tensor([0, height,
427
- width]).unsqueeze(0).repeat(batch_size, 1),
428
- torch.tensor([1, height,
429
- width]).unsqueeze(0).repeat(batch_size, 1)
430
- ]]
431
- else:
432
- last_mot = torch.zeros([batch_size, 0, self.dim],
433
- device=motion_latents[0].device,
434
- dtype=motion_latents[0].dtype)
435
- gride_sizes = []
436
-
437
- zip_motion = self.motioner(motion_latents)
438
- zip_motion = self.zip_motion_out(zip_motion)
439
- if drop_motion_frames:
440
- zip_motion = zip_motion * 0.0
441
- zip_motion_grid_sizes = [[
442
- torch.tensor([-1, 0, 0]).unsqueeze(0).repeat(batch_size, 1),
443
- torch.tensor([
444
- 0, self.motioner.motion_side_len, self.motioner.motion_side_len
445
- ]).unsqueeze(0).repeat(batch_size, 1),
446
- torch.tensor(
447
- [1 if not self.trainable_token_pos_emb else -1, height,
448
- width]).unsqueeze(0).repeat(batch_size, 1),
449
- ]]
450
-
451
- mot = torch.cat([last_mot, zip_motion], dim=1)
452
- gride_sizes = gride_sizes + zip_motion_grid_sizes
453
-
454
- motion_rope_emb = rope_precompute(
455
- mot.detach().view(batch_size, mot.shape[1], self.num_heads,
456
- self.dim // self.num_heads),
457
- gride_sizes,
458
- freqs,
459
- start=None)
460
- return [m.unsqueeze(0) for m in mot
461
- ], [r.unsqueeze(0) for r in motion_rope_emb]
462
-
463
- def inject_motion(self,
464
- x,
465
- seq_lens,
466
- rope_embs,
467
- mask_input,
468
- motion_latents,
469
- drop_motion_frames=False,
470
- add_last_motion=True):
471
- # Inject the motion frames token to the hidden states
472
- if self.enable_motioner:
473
- mot, mot_remb = self.process_motion_transformer_motioner(
474
- motion_latents,
475
- drop_motion_frames=drop_motion_frames,
476
- add_last_motion=add_last_motion)
477
- elif self.enable_framepack:
478
- mot, mot_remb = self.process_motion_frame_pack(
479
- motion_latents,
480
- drop_motion_frames=drop_motion_frames,
481
- add_last_motion=add_last_motion)
482
- else:
483
- mot, mot_remb = self.process_motion(
484
- motion_latents, drop_motion_frames=drop_motion_frames)
485
-
486
- if len(mot) > 0:
487
- x = [torch.cat([u, m], dim=1) for u, m in zip(x, mot)]
488
- seq_lens = seq_lens + torch.tensor([r.size(1) for r in mot],
489
- dtype=torch.long)
490
- rope_embs = [
491
- torch.cat([u, m], dim=1) for u, m in zip(rope_embs, mot_remb)
492
- ]
493
- mask_input = [
494
- torch.cat([
495
- m, 2 * torch.ones([1, u.shape[1] - m.shape[1]],
496
- device=m.device,
497
- dtype=m.dtype)
498
- ],
499
- dim=1) for m, u in zip(mask_input, x)
500
- ]
501
- return x, seq_lens, rope_embs, mask_input
502
-
503
- def after_transformer_block(self, block_idx, hidden_states):
504
- if block_idx in self.audio_injector.injected_block_id.keys():
505
- audio_attn_id = self.audio_injector.injected_block_id[block_idx]
506
- audio_emb = self.merged_audio_emb # b f n c
507
- num_frames = audio_emb.shape[1]
508
-
509
- if self.sp_world_size > 1:
510
- hidden_states = self.all_gather(hidden_states, dim=1)
511
-
512
- input_hidden_states = hidden_states[:, :self.original_seq_len].clone()
513
- input_hidden_states = rearrange(
514
- input_hidden_states, "b (t n) c -> (b t) n c", t=num_frames)
515
-
516
- if self.enbale_adain and self.adain_mode == "attn_norm":
517
- audio_emb_global = self.audio_emb_global
518
- audio_emb_global = rearrange(audio_emb_global,
519
- "b t n c -> (b t) n c")
520
- adain_hidden_states = self.audio_injector.injector_adain_layers[audio_attn_id](
521
- input_hidden_states, temb=audio_emb_global[:, 0]
522
- )
523
- attn_hidden_states = adain_hidden_states
524
- else:
525
- attn_hidden_states = self.audio_injector.injector_pre_norm_feat[audio_attn_id](
526
- input_hidden_states
527
- )
528
- audio_emb = rearrange(audio_emb, "b t n c -> (b t) n c", t=num_frames)
529
- attn_audio_emb = audio_emb
530
- context_lens = torch.ones(
531
- attn_hidden_states.shape[0], dtype=torch.long, device=attn_hidden_states.device
532
- ) * attn_audio_emb.shape[1]
533
-
534
- if torch.is_grad_enabled() and self.gradient_checkpointing:
535
- def create_custom_forward(module):
536
- def custom_forward(*inputs):
537
- return module(*inputs)
538
-
539
- return custom_forward
540
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
541
- residual_out = torch.utils.checkpoint.checkpoint(
542
- create_custom_forward(self.audio_injector.injector[audio_attn_id]),
543
- attn_hidden_states,
544
- attn_audio_emb,
545
- context_lens,
546
- **ckpt_kwargs
547
- )
548
- else:
549
- residual_out = self.audio_injector.injector[audio_attn_id](
550
- x=attn_hidden_states,
551
- context=attn_audio_emb,
552
- context_lens=context_lens)
553
- residual_out = rearrange(residual_out, "(b t) n c -> b (t n) c", t=num_frames)
554
- hidden_states[:, :self.original_seq_len] = hidden_states[:, :self.original_seq_len] + residual_out
555
-
556
- if self.sp_world_size > 1:
557
- hidden_states = torch.chunk(
558
- hidden_states, self.sp_world_size, dim=1)[self.sp_world_rank]
559
-
560
- return hidden_states
561
-
562
- @cfg_skip()
563
- def forward(
564
- self,
565
- x,
566
- t,
567
- context,
568
- seq_len,
569
- ref_latents,
570
- motion_latents,
571
- cond_states,
572
- audio_input=None,
573
- motion_frames=[17, 5],
574
- add_last_motion=2,
575
- drop_motion_frames=False,
576
- cond_flag=True,
577
- *extra_args,
578
- **extra_kwargs
579
- ):
580
- """
581
- x: A list of videos each with shape [C, T, H, W].
582
- t: [B].
583
- context: A list of text embeddings each with shape [L, C].
584
- seq_len: A list of video token lengths; not needed by this model.
585
- ref_latents A list of reference images, one per video, each with shape [C, 1, H, W].
586
- motion_latents A list of motion frames for each video with shape [C, T_m, H, W].
587
- cond_states A list of condition frames (i.e. pose) each with shape [C, T, H, W].
588
- audio_input The input audio embedding [B, num_wav2vec_layer, C_a, T_a].
589
- motion_frames The number of motion frames and motion latent frames encoded by the VAE, e.g. [17, 5]
590
- add_last_motion For the motioner, if add_last_motion > 0, it means that the most recent frame (i.e., the last frame) will be added.
591
- For frame packing, the behavior depends on the value of add_last_motion:
592
- add_last_motion = 0: Only the farthest part of the latent (i.e., clean_latents_4x) is included.
593
- add_last_motion = 1: Both clean_latents_2x and clean_latents_4x are included.
594
- add_last_motion = 2: All motion-related latents are used.
595
- drop_motion_frames Bool, whether to drop the motion frame information
596
- """
597
- device = self.patch_embedding.weight.device
598
- dtype = x.dtype
599
- if self.freqs.device != device and torch.device(type="meta") != device:
600
- self.freqs = self.freqs.to(device)
601
- add_last_motion = self.add_last_motion * add_last_motion
602
-
603
- # Embeddings
604
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
605
-
606
- if isinstance(motion_frames[0], list):
607
- motion_frames_0 = motion_frames[0][0]
608
- motion_frames_1 = motion_frames[0][1]
609
- else:
610
- motion_frames_0 = motion_frames[0]
611
- motion_frames_1 = motion_frames[1]
612
- # Audio process
613
- audio_input = torch.cat([audio_input[..., 0:1].repeat(1, 1, 1, motion_frames_0), audio_input], dim=-1)
614
- if torch.is_grad_enabled() and self.gradient_checkpointing:
615
- def create_custom_forward(module):
616
- def custom_forward(*inputs):
617
- return module(*inputs)
618
-
619
- return custom_forward
620
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
621
- audio_emb_res = torch.utils.checkpoint.checkpoint(create_custom_forward(self.casual_audio_encoder), audio_input, **ckpt_kwargs)
622
- else:
623
- audio_emb_res = self.casual_audio_encoder(audio_input)
624
- if self.enbale_adain:
625
- audio_emb_global, audio_emb = audio_emb_res
626
- self.audio_emb_global = audio_emb_global[:, motion_frames_1:].clone()
627
- else:
628
- audio_emb = audio_emb_res
629
- self.merged_audio_emb = audio_emb[:, motion_frames_1:, :]
630
-
631
- # Cond states
632
- cond = [self.cond_encoder(c.unsqueeze(0)) for c in cond_states]
633
- x = [x_ + pose for x_, pose in zip(x, cond)]
634
-
635
- grid_sizes = torch.stack(
636
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
637
- x = [u.flatten(2).transpose(1, 2) for u in x]
638
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
639
-
640
- original_grid_sizes = deepcopy(grid_sizes)
641
- grid_sizes = [[torch.zeros_like(grid_sizes), grid_sizes, grid_sizes]]
642
-
643
- # Ref latents
644
- ref = [self.patch_embedding(r.unsqueeze(0)) for r in ref_latents]
645
- batch_size = len(ref)
646
- height, width = ref[0].shape[3], ref[0].shape[4]
647
- ref = [r.flatten(2).transpose(1, 2) for r in ref] # r: 1 c f h w
648
- x = [torch.cat([u, r], dim=1) for u, r in zip(x, ref)]
649
-
650
- self.original_seq_len = seq_lens[0]
651
- seq_lens = seq_lens + torch.tensor([r.size(1) for r in ref], dtype=torch.long)
652
- ref_grid_sizes = [
653
- [
654
- torch.tensor([30, 0, 0]).unsqueeze(0).repeat(batch_size, 1), # the start index
655
- torch.tensor([31, height,width]).unsqueeze(0).repeat(batch_size, 1), # the end index
656
- torch.tensor([1, height, width]).unsqueeze(0).repeat(batch_size, 1),
657
- ] # the range
658
- ]
659
- grid_sizes = grid_sizes + ref_grid_sizes
660
-
661
- # Compute the rope embeddings for the input
662
- x = torch.cat(x)
663
- b, s, n, d = x.size(0), x.size(1), self.num_heads, self.dim // self.num_heads
664
- self.pre_compute_freqs = rope_precompute(
665
- x.detach().view(b, s, n, d), grid_sizes, self.freqs, start=None)
666
- x = [u.unsqueeze(0) for u in x]
667
- self.pre_compute_freqs = [u.unsqueeze(0) for u in self.pre_compute_freqs]
668
-
669
- # Inject Motion latents.
670
- # Initialize masks to indicate noisy latent, ref latent, and motion latent.
671
- # However, at this point, only the first two (noisy and ref latents) are marked;
672
- # the marking of motion latent will be implemented inside `inject_motion`.
673
- mask_input = [
674
- torch.zeros([1, u.shape[1]], dtype=torch.long, device=x[0].device)
675
- for u in x
676
- ]
677
- for i in range(len(mask_input)):
678
- mask_input[i][:, self.original_seq_len:] = 1
679
-
680
- self.lat_motion_frames = motion_latents[0].shape[1]
681
- x, seq_lens, self.pre_compute_freqs, mask_input = self.inject_motion(
682
- x,
683
- seq_lens,
684
- self.pre_compute_freqs,
685
- mask_input,
686
- motion_latents,
687
- drop_motion_frames=drop_motion_frames,
688
- add_last_motion=add_last_motion)
689
- x = torch.cat(x, dim=0)
690
- self.pre_compute_freqs = torch.cat(self.pre_compute_freqs, dim=0)
691
- mask_input = torch.cat(mask_input, dim=0)
692
-
693
- # Apply trainable_cond_mask
694
- x = x + self.trainable_cond_mask(mask_input).to(x.dtype)
695
-
696
- seq_len = seq_lens.max()
697
- if self.sp_world_size > 1:
698
- seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
699
- assert seq_lens.max() <= seq_len
700
- x = torch.cat([
701
- torch.cat([u.unsqueeze(0), u.new_zeros(1, seq_len - u.size(0), u.size(1))],
702
- dim=1) for u in x
703
- ])
704
-
705
- # Time embeddings
706
- if self.zero_timestep:
707
- t = torch.cat([t, torch.zeros([1], dtype=t.dtype, device=t.device)])
708
- with amp.autocast(dtype=torch.float32):
709
- e = self.time_embedding(
710
- sinusoidal_embedding_1d(self.freq_dim, t).float())
711
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
712
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
713
-
714
- if self.zero_timestep:
715
- e = e[:-1]
716
- zero_e0 = e0[-1:]
717
- e0 = e0[:-1]
718
- token_len = x.shape[1]
719
-
720
- e0 = torch.cat(
721
- [
722
- e0.unsqueeze(2),
723
- zero_e0.unsqueeze(2).repeat(e0.size(0), 1, 1, 1)
724
- ],
725
- dim=2
726
- )
727
- e0 = [e0, self.original_seq_len]
728
- else:
729
- e0 = e0.unsqueeze(2).repeat(1, 1, 2, 1)
730
- e0 = [e0, 0]
731
-
732
- # context
733
- context_lens = None
734
- context = self.text_embedding(
735
- torch.stack([
736
- torch.cat(
737
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
738
- for u in context
739
- ]))
740
-
741
- if self.sp_world_size > 1:
742
- # Sharded tensors for long context attn
743
- x = torch.chunk(x, self.sp_world_size, dim=1)
744
- sq_size = [u.shape[1] for u in x]
745
- sq_start_size = sum(sq_size[:self.sp_world_rank])
746
- x = x[self.sp_world_rank]
747
- # Confirm the application range of the time embedding in e0[0] for each sequence:
748
- # - For tokens before seg_id: apply e0[0][:, :, 0]
749
- # - For tokens after seg_id: apply e0[0][:, :, 1]
750
- sp_size = x.shape[1]
751
- seg_idx = e0[1] - sq_start_size
752
- e0[1] = seg_idx
753
-
754
- self.pre_compute_freqs = torch.chunk(self.pre_compute_freqs, self.sp_world_size, dim=1)
755
- self.pre_compute_freqs = self.pre_compute_freqs[self.sp_world_rank]
756
-
757
- # TeaCache
758
- if self.teacache is not None:
759
- if cond_flag:
760
- if t.dim() != 1:
761
- modulated_inp = e0[0][:, -1, :]
762
- else:
763
- modulated_inp = e0[0]
764
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
765
- if skip_flag:
766
- self.should_calc = True
767
- self.teacache.accumulated_rel_l1_distance = 0
768
- else:
769
- if cond_flag:
770
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
771
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
772
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
773
- self.should_calc = False
774
- else:
775
- self.should_calc = True
776
- self.teacache.accumulated_rel_l1_distance = 0
777
- self.teacache.previous_modulated_input = modulated_inp
778
- self.teacache.should_calc = self.should_calc
779
- else:
780
- self.should_calc = self.teacache.should_calc
781
-
782
- # TeaCache
783
- if self.teacache is not None:
784
- if not self.should_calc:
785
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
786
- x = x + previous_residual.to(x.device)[-x.size()[0]:,]
787
- else:
788
- ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
789
-
790
- for idx, block in enumerate(self.blocks):
791
- if torch.is_grad_enabled() and self.gradient_checkpointing:
792
-
793
- def create_custom_forward(module):
794
- def custom_forward(*inputs):
795
- return module(*inputs)
796
-
797
- return custom_forward
798
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
799
- x = torch.utils.checkpoint.checkpoint(
800
- create_custom_forward(block),
801
- x,
802
- e0,
803
- seq_lens,
804
- grid_sizes,
805
- self.pre_compute_freqs,
806
- context,
807
- context_lens,
808
- dtype,
809
- t,
810
- **ckpt_kwargs,
811
- )
812
- x = self.after_transformer_block(idx, x)
813
- else:
814
- # arguments
815
- kwargs = dict(
816
- e=e0,
817
- seq_lens=seq_lens,
818
- grid_sizes=grid_sizes,
819
- freqs=self.pre_compute_freqs,
820
- context=context,
821
- context_lens=context_lens,
822
- dtype=dtype,
823
- t=t
824
- )
825
- x = block(x, **kwargs)
826
- x = self.after_transformer_block(idx, x)
827
-
828
- if cond_flag:
829
- self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
830
- else:
831
- self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
832
- else:
833
- for idx, block in enumerate(self.blocks):
834
- if torch.is_grad_enabled() and self.gradient_checkpointing:
835
-
836
- def create_custom_forward(module):
837
- def custom_forward(*inputs):
838
- return module(*inputs)
839
-
840
- return custom_forward
841
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
842
- x = torch.utils.checkpoint.checkpoint(
843
- create_custom_forward(block),
844
- x,
845
- e0,
846
- seq_lens,
847
- grid_sizes,
848
- self.pre_compute_freqs,
849
- context,
850
- context_lens,
851
- dtype,
852
- t,
853
- **ckpt_kwargs,
854
- )
855
- x = self.after_transformer_block(idx, x)
856
- else:
857
- # arguments
858
- kwargs = dict(
859
- e=e0,
860
- seq_lens=seq_lens,
861
- grid_sizes=grid_sizes,
862
- freqs=self.pre_compute_freqs,
863
- context=context,
864
- context_lens=context_lens,
865
- dtype=dtype,
866
- t=t
867
- )
868
- x = block(x, **kwargs)
869
- x = self.after_transformer_block(idx, x)
870
-
871
- # Context Parallel
872
- if self.sp_world_size > 1:
873
- x = self.all_gather(x.contiguous(), dim=1)
874
-
875
- # Unpatchify
876
- x = x[:, :self.original_seq_len]
877
- # head
878
- if torch.is_grad_enabled() and self.gradient_checkpointing:
879
- def create_custom_forward(module):
880
- def custom_forward(*inputs):
881
- return module(*inputs)
882
-
883
- return custom_forward
884
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
885
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
886
- else:
887
- x = self.head(x, e)
888
- x = self.unpatchify(x, original_grid_sizes)
889
- x = torch.stack(x)
890
- if self.teacache is not None and cond_flag:
891
- self.teacache.cnt += 1
892
- if self.teacache.cnt == self.teacache.num_steps:
893
- self.teacache.reset()
894
- return x
895
-
896
- def unpatchify(self, x, grid_sizes):
897
- """
898
- Reconstruct video tensors from patch embeddings.
899
-
900
- Args:
901
- x (List[Tensor]):
902
- List of patchified features, each with shape [L, C_out * prod(patch_size)]
903
- grid_sizes (Tensor):
904
- Original spatial-temporal grid dimensions before patching,
905
- shape [B, 3] (3 dimensions correspond to F_patches, H_patches, W_patches)
906
-
907
- Returns:
908
- List[Tensor]:
909
- Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
910
- """
911
-
912
- c = self.out_dim
913
- out = []
914
- for u, v in zip(x, grid_sizes.tolist()):
915
- u = u[:math.prod(v)].view(*v, *self.patch_size, c)
916
- u = torch.einsum('fhwpqrc->cfphqwr', u)
917
- u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
918
- out.append(u)
919
- return out
920
-
921
- def zero_init_weights(self):
922
- with torch.no_grad():
923
- self.trainable_cond_mask = zero_module(self.trainable_cond_mask)
924
- if hasattr(self, "cond_encoder"):
925
- self.cond_encoder = zero_module(self.cond_encoder)
926
-
927
- for i in range(self.audio_injector.injector.__len__()):
928
- self.audio_injector.injector[i].o = zero_module(
929
- self.audio_injector.injector[i].o)
930
- if self.enbale_adain:
931
- self.audio_injector.injector_adain_layers[i].linear = \
932
- zero_module(self.audio_injector.injector_adain_layers[i].linear)
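Both branches of the forward pass above share the TeaCache gate: a rescaled relative-L1 distance between successive modulated inputs is accumulated, and the full block stack is skipped (reusing the cached residual) while the accumulator stays below a threshold. A simplified stand-in for that decision (the `state` dict and the fixed threshold are assumptions; the real cache also applies a polynomial rescale_func and a warm-up step count):

import torch

def teacache_should_calc(prev_mod, cur_mod, state, thresh=0.08):
    # relative L1 distance between the current and previous modulated inputs
    rel_l1 = ((cur_mod - prev_mod).abs().mean() / prev_mod.abs().mean()).item()
    state["accumulated"] += rel_l1
    if state["accumulated"] < thresh:
        return False   # skip the transformer blocks, reuse the cached residual
    state["accumulated"] = 0.0
    return True        # run all blocks and refresh the cached residual

state = {"accumulated": 0.0}
recompute = teacache_should_calc(torch.ones(8), 1.1 * torch.ones(8), state)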
videox_fun/models/wan_transformer3d_vace.py DELETED
@@ -1,394 +0,0 @@
1
- # Modified from https://github.com/ali-vilab/VACE/blob/main/vace/models/wan/wan_vace.py
2
- # -*- coding: utf-8 -*-
3
- # Copyright (c) Alibaba, Inc. and its affiliates.
4
- from typing import Any, Dict
5
-
6
- import os
7
- import math
8
- import torch
9
- import torch.cuda.amp as amp
10
- import torch.nn as nn
11
- from diffusers.configuration_utils import register_to_config
12
- from diffusers.utils import is_torch_version
13
-
14
- from .wan_transformer3d import (WanAttentionBlock, WanTransformer3DModel,
15
- sinusoidal_embedding_1d)
16
- from ..utils import cfg_skip
17
-
18
-
19
- VIDEOX_OFFLOAD_VACE_LATENTS = os.environ.get("VIDEOX_OFFLOAD_VACE_LATENTS", False)
20
-
21
- class VaceWanAttentionBlock(WanAttentionBlock):
22
- def __init__(
23
- self,
24
- cross_attn_type,
25
- dim,
26
- ffn_dim,
27
- num_heads,
28
- window_size=(-1, -1),
29
- qk_norm=True,
30
- cross_attn_norm=False,
31
- eps=1e-6,
32
- block_id=0
33
- ):
34
- super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
35
- self.block_id = block_id
36
- if block_id == 0:
37
- self.before_proj = nn.Linear(self.dim, self.dim)
38
- nn.init.zeros_(self.before_proj.weight)
39
- nn.init.zeros_(self.before_proj.bias)
40
- self.after_proj = nn.Linear(self.dim, self.dim)
41
- nn.init.zeros_(self.after_proj.weight)
42
- nn.init.zeros_(self.after_proj.bias)
43
-
44
- def forward(self, c, x, **kwargs):
45
- if self.block_id == 0:
46
- c = self.before_proj(c) + x
47
- all_c = []
48
- else:
49
- all_c = list(torch.unbind(c))
50
- c = all_c.pop(-1)
51
-
52
- if VIDEOX_OFFLOAD_VACE_LATENTS:
53
- c = c.to(x.device)
54
-
55
- c = super().forward(c, **kwargs)
56
- c_skip = self.after_proj(c)
57
-
58
- if VIDEOX_OFFLOAD_VACE_LATENTS:
59
- c_skip = c_skip.to("cpu")
60
- c = c.to("cpu")
61
-
62
- all_c += [c_skip, c]
63
- c = torch.stack(all_c)
64
- return c
65
-
66
-
67
- class BaseWanAttentionBlock(WanAttentionBlock):
68
- def __init__(
69
- self,
70
- cross_attn_type,
71
- dim,
72
- ffn_dim,
73
- num_heads,
74
- window_size=(-1, -1),
75
- qk_norm=True,
76
- cross_attn_norm=False,
77
- eps=1e-6,
78
- block_id=None
79
- ):
80
- super().__init__(cross_attn_type, dim, ffn_dim, num_heads, window_size, qk_norm, cross_attn_norm, eps)
81
- self.block_id = block_id
82
-
83
- def forward(self, x, hints, context_scale=1.0, **kwargs):
84
- x = super().forward(x, **kwargs)
85
- if self.block_id is not None:
86
- if VIDEOX_OFFLOAD_VACE_LATENTS:
87
- x = x + hints[self.block_id].to(x.device) * context_scale
88
- else:
89
- x = x + hints[self.block_id] * context_scale
90
- return x
91
-
92
-
93
- class VaceWanTransformer3DModel(WanTransformer3DModel):
94
- @register_to_config
95
- def __init__(self,
96
- vace_layers=None,
97
- vace_in_dim=None,
98
- model_type='t2v',
99
- patch_size=(1, 2, 2),
100
- text_len=512,
101
- in_dim=16,
102
- dim=2048,
103
- ffn_dim=8192,
104
- freq_dim=256,
105
- text_dim=4096,
106
- out_dim=16,
107
- num_heads=16,
108
- num_layers=32,
109
- window_size=(-1, -1),
110
- qk_norm=True,
111
- cross_attn_norm=True,
112
- eps=1e-6):
113
- model_type = "t2v" # TODO: Hard-coded for both the preview and official versions.
114
- super().__init__(model_type, patch_size, text_len, in_dim, dim, ffn_dim, freq_dim, text_dim, out_dim,
115
- num_heads, num_layers, window_size, qk_norm, cross_attn_norm, eps)
116
-
117
- self.vace_layers = [i for i in range(0, self.num_layers, 2)] if vace_layers is None else vace_layers
118
- self.vace_in_dim = self.in_dim if vace_in_dim is None else vace_in_dim
119
-
120
- assert 0 in self.vace_layers
121
- self.vace_layers_mapping = {i: n for n, i in enumerate(self.vace_layers)}
122
-
123
- # blocks
124
- self.blocks = nn.ModuleList([
125
- BaseWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
126
- self.cross_attn_norm, self.eps,
127
- block_id=self.vace_layers_mapping[i] if i in self.vace_layers else None)
128
- for i in range(self.num_layers)
129
- ])
130
-
131
- # vace blocks
132
- self.vace_blocks = nn.ModuleList([
133
- VaceWanAttentionBlock('t2v_cross_attn', self.dim, self.ffn_dim, self.num_heads, self.window_size, self.qk_norm,
134
- self.cross_attn_norm, self.eps, block_id=i)
135
- for i in self.vace_layers
136
- ])
137
-
138
- # vace patch embeddings
139
- self.vace_patch_embedding = nn.Conv3d(
140
- self.vace_in_dim, self.dim, kernel_size=self.patch_size, stride=self.patch_size
141
- )
142
-
143
- def forward_vace(
144
- self,
145
- x,
146
- vace_context,
147
- seq_len,
148
- kwargs
149
- ):
150
- # embeddings
151
- c = [self.vace_patch_embedding(u.unsqueeze(0)) for u in vace_context]
152
- c = [u.flatten(2).transpose(1, 2) for u in c]
153
- c = torch.cat([
154
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
155
- dim=1) for u in c
156
- ])
157
- # Context Parallel
158
- if self.sp_world_size > 1:
159
- c = torch.chunk(c, self.sp_world_size, dim=1)[self.sp_world_rank]
160
-
161
- # arguments
162
- new_kwargs = dict(x=x)
163
- new_kwargs.update(kwargs)
164
-
165
- for block in self.vace_blocks:
166
- if torch.is_grad_enabled() and self.gradient_checkpointing:
167
- def create_custom_forward(module, **static_kwargs):
168
- def custom_forward(*inputs):
169
- return module(*inputs, **static_kwargs)
170
- return custom_forward
171
- ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
172
- c = torch.utils.checkpoint.checkpoint(
173
- create_custom_forward(block, **new_kwargs),
174
- c,
175
- **ckpt_kwargs,
176
- )
177
- else:
178
- c = block(c, **new_kwargs)
179
- hints = torch.unbind(c)[:-1]
180
- return hints
181
-
182
- @cfg_skip()
183
- def forward(
184
- self,
185
- x,
186
- t,
187
- vace_context,
188
- context,
189
- seq_len,
190
- vace_context_scale=1.0,
191
- clip_fea=None,
192
- y=None,
193
- cond_flag=True
194
- ):
195
- r"""
196
- Forward pass through the diffusion model
197
-
198
- Args:
199
- x (List[Tensor]):
200
- List of input video tensors, each with shape [C_in, F, H, W]
201
- t (Tensor):
202
- Diffusion timesteps tensor of shape [B]
203
- context (List[Tensor]):
204
- List of text embeddings each with shape [L, C]
205
- seq_len (`int`):
206
- Maximum sequence length for positional encoding
207
- clip_fea (Tensor, *optional*):
208
- CLIP image features for image-to-video mode
209
- y (List[Tensor], *optional*):
210
- Conditional video inputs for image-to-video mode, same shape as x
211
-
212
- Returns:
213
- List[Tensor]:
214
- List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
215
- """
216
- # if self.model_type == 'i2v':
217
- # assert clip_fea is not None and y is not None
218
- # params
219
- device = self.patch_embedding.weight.device
220
- dtype = x.dtype
221
- if self.freqs.device != device and torch.device(type="meta") != device:
222
- self.freqs = self.freqs.to(device)
223
-
224
- # if y is not None:
225
- # x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
226
-
227
- # embeddings
228
- x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
229
- grid_sizes = torch.stack(
230
- [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
231
- x = [u.flatten(2).transpose(1, 2) for u in x]
232
- seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long)
233
- if self.sp_world_size > 1:
234
- seq_len = int(math.ceil(seq_len / self.sp_world_size)) * self.sp_world_size
235
- assert seq_lens.max() <= seq_len
236
- x = torch.cat([
237
- torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
238
- dim=1) for u in x
239
- ])
240
-
241
- # time embeddings
242
- with amp.autocast(dtype=torch.float32):
243
- e = self.time_embedding(
244
- sinusoidal_embedding_1d(self.freq_dim, t).float())
245
- e0 = self.time_projection(e).unflatten(1, (6, self.dim))
246
- assert e.dtype == torch.float32 and e0.dtype == torch.float32
247
-
248
- # context
249
- context_lens = None
250
- context = self.text_embedding(
251
- torch.stack([
252
- torch.cat(
253
- [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
254
- for u in context
255
- ]))
256
-
257
- # Context Parallel
258
- if self.sp_world_size > 1:
259
- x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
260
-
261
- # arguments
262
- kwargs = dict(
263
- e=e0,
264
- seq_lens=seq_lens,
265
- grid_sizes=grid_sizes,
266
- freqs=self.freqs,
267
- context=context,
268
- context_lens=context_lens,
269
- dtype=dtype,
270
- t=t)
271
- hints = self.forward_vace(x, vace_context, seq_len, kwargs)
272
-
273
- kwargs['hints'] = hints
274
- kwargs['context_scale'] = vace_context_scale
275
-
276
- # TeaCache
277
- if self.teacache is not None:
278
- if cond_flag:
279
- if t.dim() != 1:
280
- modulated_inp = e0[:, -1, :]
281
- else:
282
- modulated_inp = e0
283
- skip_flag = self.teacache.cnt < self.teacache.num_skip_start_steps
284
- if skip_flag:
285
- self.should_calc = True
286
- self.teacache.accumulated_rel_l1_distance = 0
287
- else:
288
- if cond_flag:
289
- rel_l1_distance = self.teacache.compute_rel_l1_distance(self.teacache.previous_modulated_input, modulated_inp)
290
- self.teacache.accumulated_rel_l1_distance += self.teacache.rescale_func(rel_l1_distance)
291
- if self.teacache.accumulated_rel_l1_distance < self.teacache.rel_l1_thresh:
292
- self.should_calc = False
293
- else:
294
- self.should_calc = True
295
- self.teacache.accumulated_rel_l1_distance = 0
296
- self.teacache.previous_modulated_input = modulated_inp
297
- self.teacache.should_calc = self.should_calc
298
- else:
299
- self.should_calc = self.teacache.should_calc
300
-
301
- # TeaCache
302
- if self.teacache is not None:
303
- if not self.should_calc:
304
- previous_residual = self.teacache.previous_residual_cond if cond_flag else self.teacache.previous_residual_uncond
305
- x = x + previous_residual.to(x.device)[-x.size()[0]:,]
306
- else:
307
- ori_x = x.clone().cpu() if self.teacache.offload else x.clone()
308
-
309
- for block in self.blocks:
310
- if torch.is_grad_enabled() and self.gradient_checkpointing:
311
- def create_custom_forward(module, **static_kwargs):
312
- def custom_forward(*inputs):
313
- return module(*inputs, **static_kwargs)
314
- return custom_forward
315
- extra_kwargs = {
316
- 'e': e0,
317
- 'seq_lens': seq_lens,
318
- 'grid_sizes': grid_sizes,
319
- 'freqs': self.freqs,
320
- 'context': context,
321
- 'context_lens': context_lens,
322
- 'dtype': dtype,
323
- 't': t,
324
- }
325
-
326
- ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
327
-
328
- x = torch.utils.checkpoint.checkpoint(
329
- create_custom_forward(block, **extra_kwargs),
330
- x,
331
- hints,
332
- vace_context_scale,
333
- **ckpt_kwargs,
334
- )
335
- else:
336
- x = block(x, **kwargs)
337
-
338
- if cond_flag:
339
- self.teacache.previous_residual_cond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
340
- else:
341
- self.teacache.previous_residual_uncond = x.cpu() - ori_x if self.teacache.offload else x - ori_x
342
- else:
343
- for block in self.blocks:
344
- if torch.is_grad_enabled() and self.gradient_checkpointing:
345
- def create_custom_forward(module, **static_kwargs):
346
- def custom_forward(*inputs):
347
- return module(*inputs, **static_kwargs)
348
- return custom_forward
349
- extra_kwargs = {
350
- 'e': e0,
351
- 'seq_lens': seq_lens,
352
- 'grid_sizes': grid_sizes,
353
- 'freqs': self.freqs,
354
- 'context': context,
355
- 'context_lens': context_lens,
356
- 'dtype': dtype,
357
- 't': t,
358
- }
359
-
360
- ckpt_kwargs = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
361
-
362
- x = torch.utils.checkpoint.checkpoint(
363
- create_custom_forward(block, **extra_kwargs),
364
- x,
365
- hints,
366
- vace_context_scale,
367
- **ckpt_kwargs,
368
- )
369
- else:
370
- x = block(x, **kwargs)
371
-
372
- # head
373
- if torch.is_grad_enabled() and self.gradient_checkpointing:
374
- def create_custom_forward(module):
375
- def custom_forward(*inputs):
376
- return module(*inputs)
377
-
378
- return custom_forward
379
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
380
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.head), x, e, **ckpt_kwargs)
381
- else:
382
- x = self.head(x, e)
383
-
384
- if self.sp_world_size > 1:
385
- x = self.all_gather(x, dim=1)
386
-
387
- # unpatchify
388
- x = self.unpatchify(x, grid_sizes)
389
- x = torch.stack(x)
390
- if self.teacache is not None and cond_flag:
391
- self.teacache.cnt += 1
392
- if self.teacache.cnt == self.teacache.num_steps:
393
- self.teacache.reset()
394
- return x
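The file above follows a ControlNet-style layout: VaceWanAttentionBlock layers run on the VACE context and emit per-layer hints through zero-initialized after_proj projections, and BaseWanAttentionBlock adds hints[self.block_id] * context_scale back into the main token stream. A stripped-down sketch of that injection pattern with toy blocks (the class and tensor names here are illustrative, not the deleted module's API):

import torch
import torch.nn as nn

class ToyHintBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Linear(dim, dim)
        self.after_proj = nn.Linear(dim, dim)   # zero-initialized skip projection
        nn.init.zeros_(self.after_proj.weight)
        nn.init.zeros_(self.after_proj.bias)

    def forward(self, c):
        c = self.body(c)
        return self.after_proj(c), c            # (hint for the main stream, updated context)

class ToyBaseBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.body = nn.Linear(dim, dim)

    def forward(self, x, hint, context_scale=1.0):
        return self.body(x) + hint * context_scale

dim = 8
x = torch.randn(1, 4, dim)          # main stream tokens
c = torch.randn(1, 4, dim)          # VACE context tokens
hint, c = ToyHintBlock(dim)(c)
out = ToyBaseBlock(dim)(x, hint, context_scale=0.5)

Because after_proj starts at zero, the hints are initially zero and the base transformer behaves exactly as before the VACE branch was attached; context_scale then lets the caller dial the conditioning strength at inference time.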
videox_fun/models/wan_vae.py DELETED
@@ -1,860 +0,0 @@
1
- # Modified from https://github.com/Wan-Video/Wan2.1/blob/main/wan/modules/vae.py
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- from typing import Tuple, Union
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
- from diffusers.configuration_utils import ConfigMixin, register_to_config
9
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
10
- from diffusers.models.autoencoders.vae import (DecoderOutput,
11
- DiagonalGaussianDistribution)
12
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
- from diffusers.models.modeling_utils import ModelMixin
14
- from diffusers.utils.accelerate_utils import apply_forward_hook
15
- from einops import rearrange
16
-
17
-
18
- CACHE_T = 2
19
-
20
-
21
- class CausalConv3d(nn.Conv3d):
22
- """
23
- Causal 3D convolution.
24
- """
25
-
26
- def __init__(self, *args, **kwargs):
27
- super().__init__(*args, **kwargs)
28
- self._padding = (self.padding[2], self.padding[2], self.padding[1],
29
- self.padding[1], 2 * self.padding[0], 0)
30
- self.padding = (0, 0, 0)
31
-
32
- def forward(self, x, cache_x=None):
33
- padding = list(self._padding)
34
- if cache_x is not None and self._padding[4] > 0:
35
- cache_x = cache_x.to(x.device)
36
- x = torch.cat([cache_x, x], dim=2)
37
- padding[4] -= cache_x.shape[2]
38
- x = F.pad(x, padding)
39
-
40
- return super().forward(x)
41
-
42
-
43
- class RMS_norm(nn.Module):
44
-
45
- def __init__(self, dim, channel_first=True, images=True, bias=False):
46
- super().__init__()
47
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
48
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
49
-
50
- self.channel_first = channel_first
51
- self.scale = dim**0.5
52
- self.gamma = nn.Parameter(torch.ones(shape))
53
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.
54
-
55
- def forward(self, x):
56
- return F.normalize(
57
- x, dim=(1 if self.channel_first else
58
- -1)) * self.scale * self.gamma + self.bias
59
-
60
-
61
- class Upsample(nn.Upsample):
62
-
63
- def forward(self, x):
64
- """
65
- Fix bfloat16 support for nearest neighbor interpolation.
66
- """
67
- return super().forward(x.float()).type_as(x)
68
-
69
-
70
- class Resample(nn.Module):
71
-
72
- def __init__(self, dim, mode):
73
- assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
74
- 'downsample3d')
75
- super().__init__()
76
- self.dim = dim
77
- self.mode = mode
78
-
79
- # layers
80
- if mode == 'upsample2d':
81
- self.resample = nn.Sequential(
82
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
83
- nn.Conv2d(dim, dim // 2, 3, padding=1))
84
- elif mode == 'upsample3d':
85
- self.resample = nn.Sequential(
86
- Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
87
- nn.Conv2d(dim, dim // 2, 3, padding=1))
88
- self.time_conv = CausalConv3d(
89
- dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
90
-
91
- elif mode == 'downsample2d':
92
- self.resample = nn.Sequential(
93
- nn.ZeroPad2d((0, 1, 0, 1)),
94
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
95
- elif mode == 'downsample3d':
96
- self.resample = nn.Sequential(
97
- nn.ZeroPad2d((0, 1, 0, 1)),
98
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
99
- self.time_conv = CausalConv3d(
100
- dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
101
-
102
- else:
103
- self.resample = nn.Identity()
104
-
105
- def forward(self, x, feat_cache=None, feat_idx=[0]):
106
- b, c, t, h, w = x.size()
107
- if self.mode == 'upsample3d':
108
- if feat_cache is not None:
109
- idx = feat_idx[0]
110
- if feat_cache[idx] is None:
111
- feat_cache[idx] = 'Rep'
112
- feat_idx[0] += 1
113
- else:
114
-
115
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
116
- if cache_x.shape[2] < 2 and feat_cache[
117
- idx] is not None and feat_cache[idx] != 'Rep':
118
- # cache last frame of last two chunk
119
- cache_x = torch.cat([
120
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
121
- cache_x.device), cache_x
122
- ],
123
- dim=2)
124
- if cache_x.shape[2] < 2 and feat_cache[
125
- idx] is not None and feat_cache[idx] == 'Rep':
126
- cache_x = torch.cat([
127
- torch.zeros_like(cache_x).to(cache_x.device),
128
- cache_x
129
- ],
130
- dim=2)
131
- if feat_cache[idx] == 'Rep':
132
- x = self.time_conv(x)
133
- else:
134
- x = self.time_conv(x, feat_cache[idx])
135
- feat_cache[idx] = cache_x
136
- feat_idx[0] += 1
137
-
138
- x = x.reshape(b, 2, c, t, h, w)
139
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
140
- 3)
141
- x = x.reshape(b, c, t * 2, h, w)
142
- t = x.shape[2]
143
- x = rearrange(x, 'b c t h w -> (b t) c h w')
144
- x = self.resample(x)
145
- x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
146
-
147
- if self.mode == 'downsample3d':
148
- if feat_cache is not None:
149
- idx = feat_idx[0]
150
- if feat_cache[idx] is None:
151
- feat_cache[idx] = x.clone()
152
- feat_idx[0] += 1
153
- else:
154
-
155
- cache_x = x[:, :, -1:, :, :].clone()
156
- # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
157
- # # cache last frame of last two chunk
158
- # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
159
-
160
- x = self.time_conv(
161
- torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
162
- feat_cache[idx] = cache_x
163
- feat_idx[0] += 1
164
- return x
165
-
166
- def init_weight(self, conv):
167
- conv_weight = conv.weight
168
- nn.init.zeros_(conv_weight)
169
- c1, c2, t, h, w = conv_weight.size()
170
- one_matrix = torch.eye(c1, c2)
171
- init_matrix = one_matrix
172
- nn.init.zeros_(conv_weight)
173
- #conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
174
- conv_weight.data[:, :, 1, 0, 0] = init_matrix #* 0.5
175
- conv.weight.data.copy_(conv_weight)
176
- nn.init.zeros_(conv.bias.data)
177
-
178
- def init_weight2(self, conv):
179
- conv_weight = conv.weight.data
180
- nn.init.zeros_(conv_weight)
181
- c1, c2, t, h, w = conv_weight.size()
182
- init_matrix = torch.eye(c1 // 2, c2)
183
- #init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
184
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
185
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
186
- conv.weight.data.copy_(conv_weight)
187
- nn.init.zeros_(conv.bias.data)
188
-
189
-
190
- class ResidualBlock(nn.Module):
191
-
192
- def __init__(self, in_dim, out_dim, dropout=0.0):
193
- super().__init__()
194
- self.in_dim = in_dim
195
- self.out_dim = out_dim
196
-
197
- # layers
198
- self.residual = nn.Sequential(
199
- RMS_norm(in_dim, images=False), nn.SiLU(),
200
- CausalConv3d(in_dim, out_dim, 3, padding=1),
201
- RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
202
- CausalConv3d(out_dim, out_dim, 3, padding=1))
203
- self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
204
- if in_dim != out_dim else nn.Identity()
205
-
206
- def forward(self, x, feat_cache=None, feat_idx=[0]):
207
- h = self.shortcut(x)
208
- for layer in self.residual:
209
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
210
- idx = feat_idx[0]
211
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
212
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
213
- # cache last frame of last two chunk
214
- cache_x = torch.cat([
215
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
216
- cache_x.device), cache_x
217
- ],
218
- dim=2)
219
- x = layer(x, feat_cache[idx])
220
- feat_cache[idx] = cache_x
221
- feat_idx[0] += 1
222
- else:
223
- x = layer(x)
224
- return x + h
225
-
226
-
227
- class AttentionBlock(nn.Module):
228
- """
229
- Causal self-attention with a single head.
230
- """
231
-
232
- def __init__(self, dim):
233
- super().__init__()
234
- self.dim = dim
235
-
236
- # layers
237
- self.norm = RMS_norm(dim)
238
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
239
- self.proj = nn.Conv2d(dim, dim, 1)
240
-
241
- # zero out the last layer params
242
- nn.init.zeros_(self.proj.weight)
243
-
244
- def forward(self, x):
245
- identity = x
246
- b, c, t, h, w = x.size()
247
- x = rearrange(x, 'b c t h w -> (b t) c h w')
248
- x = self.norm(x)
249
- # compute query, key, value
250
- q, k, v = self.to_qkv(x).reshape(b * t, 1, c * 3,
251
- -1).permute(0, 1, 3,
252
- 2).contiguous().chunk(
253
- 3, dim=-1)
254
-
255
- # apply attention
256
- x = F.scaled_dot_product_attention(
257
- q,
258
- k,
259
- v,
260
- )
261
- x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
262
-
263
- # output
264
- x = self.proj(x)
265
- x = rearrange(x, '(b t) c h w-> b c t h w', t=t)
266
- return x + identity
267
-
268
-
269
- class Encoder3d(nn.Module):
270
-
271
- def __init__(self,
272
- dim=128,
273
- z_dim=4,
274
- dim_mult=[1, 2, 4, 4],
275
- num_res_blocks=2,
276
- attn_scales=[],
277
- temperal_downsample=[True, True, False],
278
- dropout=0.0):
279
- super().__init__()
280
- self.dim = dim
281
- self.z_dim = z_dim
282
- self.dim_mult = dim_mult
283
- self.num_res_blocks = num_res_blocks
284
- self.attn_scales = attn_scales
285
- self.temperal_downsample = temperal_downsample
286
-
287
- # dimensions
288
- dims = [dim * u for u in [1] + dim_mult]
289
- scale = 1.0
290
-
291
- # init block
292
- self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
293
-
294
- # downsample blocks
295
- downsamples = []
296
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
297
- # residual (+attention) blocks
298
- for _ in range(num_res_blocks):
299
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
300
- if scale in attn_scales:
301
- downsamples.append(AttentionBlock(out_dim))
302
- in_dim = out_dim
303
-
304
- # downsample block
305
- if i != len(dim_mult) - 1:
306
- mode = 'downsample3d' if temperal_downsample[
307
- i] else 'downsample2d'
308
- downsamples.append(Resample(out_dim, mode=mode))
309
- scale /= 2.0
310
- self.downsamples = nn.Sequential(*downsamples)
311
-
312
- # middle blocks
313
- self.middle = nn.Sequential(
314
- ResidualBlock(out_dim, out_dim, dropout), AttentionBlock(out_dim),
315
- ResidualBlock(out_dim, out_dim, dropout))
316
-
317
- # output blocks
318
- self.head = nn.Sequential(
319
- RMS_norm(out_dim, images=False), nn.SiLU(),
320
- CausalConv3d(out_dim, z_dim, 3, padding=1))
321
-
322
- def forward(self, x, feat_cache=None, feat_idx=[0]):
323
- if feat_cache is not None:
324
- idx = feat_idx[0]
325
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
326
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
327
- # cache last frame of last two chunk
328
- cache_x = torch.cat([
329
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
330
- cache_x.device), cache_x
331
- ],
332
- dim=2)
333
- x = self.conv1(x, feat_cache[idx])
334
- feat_cache[idx] = cache_x
335
- feat_idx[0] += 1
336
- else:
337
- x = self.conv1(x)
338
-
339
- ## downsamples
340
- for layer in self.downsamples:
341
- if feat_cache is not None:
342
- x = layer(x, feat_cache, feat_idx)
343
- else:
344
- x = layer(x)
345
-
346
- ## middle
347
- for layer in self.middle:
348
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
349
- x = layer(x, feat_cache, feat_idx)
350
- else:
351
- x = layer(x)
352
-
353
- ## head
354
- for layer in self.head:
355
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
356
- idx = feat_idx[0]
357
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
358
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
359
- # cache last frame of last two chunk
360
- cache_x = torch.cat([
361
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
362
- cache_x.device), cache_x
363
- ],
364
- dim=2)
365
- x = layer(x, feat_cache[idx])
366
- feat_cache[idx] = cache_x
367
- feat_idx[0] += 1
368
- else:
369
- x = layer(x)
370
- return x
371
-
372
-
373
- class Decoder3d(nn.Module):
374
-
375
- def __init__(self,
376
- dim=128,
377
- z_dim=4,
378
- dim_mult=[1, 2, 4, 4],
379
- num_res_blocks=2,
380
- attn_scales=[],
381
- temperal_upsample=[False, True, True],
382
- dropout=0.0):
383
- super().__init__()
384
- self.dim = dim
385
- self.z_dim = z_dim
386
- self.dim_mult = dim_mult
387
- self.num_res_blocks = num_res_blocks
388
- self.attn_scales = attn_scales
389
- self.temperal_upsample = temperal_upsample
390
-
391
- # dimensions
392
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
393
- scale = 1.0 / 2**(len(dim_mult) - 2)
394
-
395
- # init block
396
- self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
397
-
398
- # middle blocks
399
- self.middle = nn.Sequential(
400
- ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
401
- ResidualBlock(dims[0], dims[0], dropout))
402
-
403
- # upsample blocks
404
- upsamples = []
405
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
406
- # residual (+attention) blocks
407
- if i == 1 or i == 2 or i == 3:
408
- in_dim = in_dim // 2
409
- for _ in range(num_res_blocks + 1):
410
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
411
- if scale in attn_scales:
412
- upsamples.append(AttentionBlock(out_dim))
413
- in_dim = out_dim
414
-
415
- # upsample block
416
- if i != len(dim_mult) - 1:
417
- mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
418
- upsamples.append(Resample(out_dim, mode=mode))
419
- scale *= 2.0
420
- self.upsamples = nn.Sequential(*upsamples)
421
-
422
- # output blocks
423
- self.head = nn.Sequential(
424
- RMS_norm(out_dim, images=False), nn.SiLU(),
425
- CausalConv3d(out_dim, 3, 3, padding=1))
426
-
427
- def forward(self, x, feat_cache=None, feat_idx=[0]):
428
- ## conv1
429
- if feat_cache is not None:
430
- idx = feat_idx[0]
431
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
432
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
433
- # cache last frame of last two chunk
434
- cache_x = torch.cat([
435
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
436
- cache_x.device), cache_x
437
- ],
438
- dim=2)
439
- x = self.conv1(x, feat_cache[idx])
440
- feat_cache[idx] = cache_x
441
- feat_idx[0] += 1
442
- else:
443
- x = self.conv1(x)
444
-
445
- ## middle
446
- for layer in self.middle:
447
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
448
- x = layer(x, feat_cache, feat_idx)
449
- else:
450
- x = layer(x)
451
-
452
- ## upsamples
453
- for layer in self.upsamples:
454
- if feat_cache is not None:
455
- x = layer(x, feat_cache, feat_idx)
456
- else:
457
- x = layer(x)
458
-
459
- ## head
460
- for layer in self.head:
461
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
462
- idx = feat_idx[0]
463
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
464
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
465
- # cache last frame of last two chunk
466
- cache_x = torch.cat([
467
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
468
- cache_x.device), cache_x
469
- ],
470
- dim=2)
471
- x = layer(x, feat_cache[idx])
472
- feat_cache[idx] = cache_x
473
- feat_idx[0] += 1
474
- else:
475
- x = layer(x)
476
- return x
477
-
478
-
479
- def count_conv3d(model):
480
- count = 0
481
- for m in model.modules():
482
- if isinstance(m, CausalConv3d):
483
- count += 1
484
- return count
485
-
486
-
487
- class AutoencoderKLWan_(nn.Module):
488
-
489
- def __init__(self,
490
- dim=128,
491
- z_dim=4,
492
- dim_mult=[1, 2, 4, 4],
493
- num_res_blocks=2,
494
- attn_scales=[],
495
- temperal_downsample=[True, True, False],
496
- dropout=0.0):
497
- super().__init__()
498
- self.dim = dim
499
- self.z_dim = z_dim
500
- self.dim_mult = dim_mult
501
- self.num_res_blocks = num_res_blocks
502
- self.attn_scales = attn_scales
503
- self.temperal_downsample = temperal_downsample
504
- self.temperal_upsample = temperal_downsample[::-1]
505
-
506
- # modules
507
- self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks,
508
- attn_scales, self.temperal_downsample, dropout)
509
- self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
510
- self.conv2 = CausalConv3d(z_dim, z_dim, 1)
511
- self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks,
512
- attn_scales, self.temperal_upsample, dropout)
513
-
514
- def forward(self, x):
515
- mu, log_var = self.encode(x)
516
- z = self.reparameterize(mu, log_var)
517
- x_recon = self.decode(z)
518
- return x_recon, mu, log_var
519
-
520
- def encode(self, x, scale=None):
521
- self.clear_cache()
522
- ## cache
523
- t = x.shape[2]
524
- iter_ = 1 + (t - 1) // 4
525
- if scale != None:
526
- scale = [item.to(x.device, x.dtype) for item in scale]
527
- ## Split the input x to encode along the time axis into chunks of 1, 4, 4, 4, ...
528
- for i in range(iter_):
529
- self._enc_conv_idx = [0]
530
- if i == 0:
531
- out = self.encoder(
532
- x[:, :, :1, :, :],
533
- feat_cache=self._enc_feat_map,
534
- feat_idx=self._enc_conv_idx)
535
- else:
536
- out_ = self.encoder(
537
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
538
- feat_cache=self._enc_feat_map,
539
- feat_idx=self._enc_conv_idx)
540
- out = torch.cat([out, out_], 2)
541
- mu, log_var = self.conv1(out).chunk(2, dim=1)
542
- if scale != None:
543
- if isinstance(scale[0], torch.Tensor):
544
- mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
545
- 1, self.z_dim, 1, 1, 1)
546
- else:
547
- mu = (mu - scale[0]) * scale[1]
548
- x = torch.cat([mu, log_var], dim = 1)
549
- self.clear_cache()
550
- return x
551
-
552
- def decode(self, z, scale=None):
553
- self.clear_cache()
554
- # z: [b,c,t,h,w]
555
- if scale != None:
556
- scale = [item.to(z.device, z.dtype) for item in scale]
557
- if isinstance(scale[0], torch.Tensor):
558
- z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
559
- 1, self.z_dim, 1, 1, 1)
560
- else:
561
- z = z / scale[1] + scale[0]
562
- iter_ = z.shape[2]
563
- x = self.conv2(z)
564
- for i in range(iter_):
565
- self._conv_idx = [0]
566
- if i == 0:
567
- out = self.decoder(
568
- x[:, :, i:i + 1, :, :],
569
- feat_cache=self._feat_map,
570
- feat_idx=self._conv_idx)
571
- else:
572
- out_ = self.decoder(
573
- x[:, :, i:i + 1, :, :],
574
- feat_cache=self._feat_map,
575
- feat_idx=self._conv_idx)
576
- out = torch.cat([out, out_], 2)
577
- self.clear_cache()
578
- return out
579
-
580
- def reparameterize(self, mu, log_var):
581
- std = torch.exp(0.5 * log_var)
582
- eps = torch.randn_like(std)
583
- return eps * std + mu
584
-
585
- def sample(self, imgs, deterministic=False):
586
- mu, log_var = self.encode(imgs)
587
- if deterministic:
588
- return mu
589
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
590
- return mu + std * torch.randn_like(std)
591
-
592
- def clear_cache(self):
593
- self._conv_num = count_conv3d(self.decoder)
594
- self._conv_idx = [0]
595
- self._feat_map = [None] * self._conv_num
596
- #cache encode
597
- self._enc_conv_num = count_conv3d(self.encoder)
598
- self._enc_conv_idx = [0]
599
- self._enc_feat_map = [None] * self._enc_conv_num
600
-
601
-
602
- def _video_vae(z_dim=None, **kwargs):
603
- """
604
- Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
605
- """
606
- # params
607
- cfg = dict(
608
- dim=96,
609
- z_dim=z_dim,
610
- dim_mult=[1, 2, 4, 4],
611
- num_res_blocks=2,
612
- attn_scales=[],
613
- temperal_downsample=[False, True, True],
614
- dropout=0.0)
615
- cfg.update(**kwargs)
616
-
617
- # init model
618
- model = AutoencoderKLWan_(**cfg)
619
-
620
- return model
621
-
622
-
623
- class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
624
- _supports_gradient_checkpointing = True
625
-
626
- @register_to_config
627
- def __init__(
628
- self,
629
- latent_channels=16,
630
- temporal_compression_ratio=4,
631
- spatial_compression_ratio=8
632
- ):
633
- super().__init__()
634
- mean = [
635
- -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
636
- 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
637
- ]
638
- std = [
639
- 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
640
- 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
641
- ]
642
- self.mean = torch.tensor(mean, dtype=torch.float32)
643
- self.std = torch.tensor(std, dtype=torch.float32)
644
- self.scale = [self.mean, 1.0 / self.std]
645
-
646
- # init model
647
- self.model = _video_vae(
648
- z_dim=latent_channels,
649
- )
650
-
651
- self.gradient_checkpointing = False
652
-
653
- def _set_gradient_checkpointing(self, *args, **kwargs):
654
- if "value" in kwargs:
655
- self.gradient_checkpointing = kwargs["value"]
656
- elif "enable" in kwargs:
657
- self.gradient_checkpointing = kwargs["enable"]
658
- else:
659
- raise ValueError("Invalid set gradient checkpointing")
660
-
661
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
662
- x = [
663
- self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
664
- for u in x
665
- ]
666
- x = torch.stack(x)
667
- return x
668
-
669
- @apply_forward_hook
670
- def encode(
671
- self, x: torch.Tensor, return_dict: bool = True
672
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
673
- h = self._encode(x)
674
-
675
- posterior = DiagonalGaussianDistribution(h)
676
-
677
- if not return_dict:
678
- return (posterior,)
679
- return AutoencoderKLOutput(latent_dist=posterior)
680
-
681
- def _decode(self, zs):
682
- dec = [
683
- self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
684
- for u in zs
685
- ]
686
- dec = torch.stack(dec)
687
-
688
- return DecoderOutput(sample=dec)
689
-
690
- @apply_forward_hook
691
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
692
- decoded = self._decode(z).sample
693
-
694
- if not return_dict:
695
- return (decoded,)
696
- return DecoderOutput(sample=decoded)
697
-
698
- @classmethod
699
- def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
700
- def filter_kwargs(cls, kwargs):
701
- import inspect
702
- sig = inspect.signature(cls.__init__)
703
- valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
704
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
705
- return filtered_kwargs
706
-
707
- model = cls(**filter_kwargs(cls, additional_kwargs))
708
- if pretrained_model_path.endswith(".safetensors"):
709
- from safetensors.torch import load_file, safe_open
710
- state_dict = load_file(pretrained_model_path)
711
- else:
712
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
713
- tmp_state_dict = {}
714
- for key in state_dict:
715
- tmp_state_dict["model." + key] = state_dict[key]
716
- state_dict = tmp_state_dict
717
- m, u = model.load_state_dict(state_dict, strict=False)
718
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
719
- print(m, u)
720
- return model
721
-
722
-
723
- class AutoencoderKLWanCompileQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
724
- @register_to_config
725
- def __init__(
726
- self,
727
- attn_scales = [],
728
- base_dim = 96,
729
- dim_mult = [
730
- 1,
731
- 2,
732
- 4,
733
- 4
734
- ],
735
- dropout = 0.0,
736
- latents_mean = [
737
- -0.7571,
738
- -0.7089,
739
- -0.9113,
740
- 0.1075,
741
- -0.1745,
742
- 0.9653,
743
- -0.1517,
744
- 1.5508,
745
- 0.4134,
746
- -0.0715,
747
- 0.5517,
748
- -0.3632,
749
- -0.1922,
750
- -0.9497,
751
- 0.2503,
752
- -0.2921
753
- ],
754
- latents_std = [
755
- 2.8184,
756
- 1.4541,
757
- 2.3275,
758
- 2.6558,
759
- 1.2196,
760
- 1.7708,
761
- 2.6052,
762
- 2.0743,
763
- 3.2687,
764
- 2.1526,
765
- 2.8652,
766
- 1.5579,
767
- 1.6382,
768
- 1.1253,
769
- 2.8251,
770
- 1.916
771
- ],
772
- num_res_blocks = 2,
773
- temperal_downsample = [
774
- False,
775
- True,
776
- True
777
- ],
778
- z_dim = 16
779
- ):
780
- super().__init__()
781
- cfg = dict(
782
- dim=base_dim,
783
- z_dim=z_dim,
784
- dim_mult=dim_mult,
785
- num_res_blocks=num_res_blocks,
786
- attn_scales=attn_scales,
787
- temperal_downsample=temperal_downsample,
788
- dropout=dropout)
789
-
790
- # init model
791
- self.model = AutoencoderKLWan_(**cfg)
792
-
793
- self.dim = base_dim
794
- self.z_dim = z_dim
795
- self.dim_mult = dim_mult
796
- self.num_res_blocks = num_res_blocks
797
- self.attn_scales = attn_scales
798
- self.temperal_downsample = temperal_downsample
799
- self.temperal_upsample = temperal_downsample[::-1]
800
-
801
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
802
- x = [
803
- self.model.encode(u.unsqueeze(0)).squeeze(0)
804
- for u in x
805
- ]
806
- x = torch.stack(x)
807
- return x
808
-
809
- @apply_forward_hook
810
- def encode(
811
- self, x: torch.Tensor, return_dict: bool = True
812
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
813
- h = self._encode(x)
814
-
815
- posterior = DiagonalGaussianDistribution(h)
816
-
817
- if not return_dict:
818
- return (posterior,)
819
- return AutoencoderKLOutput(latent_dist=posterior)
820
-
821
- def _decode(self, zs):
822
- dec = [
823
- self.model.decode(u.unsqueeze(0)).clamp_(-1, 1).squeeze(0)
824
- for u in zs
825
- ]
826
- dec = torch.stack(dec)
827
-
828
- return DecoderOutput(sample=dec)
829
-
830
- @apply_forward_hook
831
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
832
- decoded = self._decode(z).sample
833
-
834
- if not return_dict:
835
- return (decoded,)
836
- return DecoderOutput(sample=decoded)
837
-
838
- @classmethod
839
- def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
840
- def filter_kwargs(cls, kwargs):
841
- import inspect
842
- sig = inspect.signature(cls.__init__)
843
- valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
844
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
845
- return filtered_kwargs
846
-
847
- model = cls(**filter_kwargs(cls, additional_kwargs))
848
- if pretrained_model_path.endswith(".safetensors"):
849
- from safetensors.torch import load_file, safe_open
850
- state_dict = load_file(pretrained_model_path)
851
- else:
852
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
853
- tmp_state_dict = {}
854
- for key in state_dict:
855
- tmp_state_dict["model." + key] = state_dict[key]
856
- state_dict = tmp_state_dict
857
- m, u = model.load_state_dict(state_dict, strict=False)
858
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
859
- print(m, u)
860
- return model
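AutoencoderKLWan above keeps per-channel latent statistics and passes self.scale = [mean, 1.0 / std] into encode/decode, so latents are standardized channel-wise before the diffusion transformer sees them and de-standardized before decoding. A small sketch of that round trip, assuming a [B, C, T, H, W] latent with C == 16 (the tensors below are synthetic, not the real statistics):

import torch

z_dim = 16
mean = torch.randn(z_dim)          # stands in for the per-channel latents_mean
std = torch.rand(z_dim) + 0.5      # stands in for the per-channel latents_std

mu = torch.randn(2, z_dim, 4, 8, 8)                                  # raw encoder output
inv_std = (1.0 / std).view(1, z_dim, 1, 1, 1)
mu_norm = (mu - mean.view(1, z_dim, 1, 1, 1)) * inv_std              # encode-side standardization
z_back = mu_norm / inv_std + mean.view(1, z_dim, 1, 1, 1)            # decode-side inverse
assert torch.allclose(z_back, mu, atol=1e-5)

Folding the statistics into the VAE this way keeps the diffusion model's inputs roughly unit-variance per channel without requiring a separate scaling step in the sampling pipeline.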
videox_fun/models/wan_vae3_8.py DELETED
@@ -1,1091 +0,0 @@
1
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
- import logging
3
- from typing import Tuple, Union
4
-
5
- import torch
6
- import torch.cuda.amp as amp
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from diffusers.configuration_utils import ConfigMixin, register_to_config
10
- from diffusers.loaders.single_file_model import FromOriginalModelMixin
11
- from diffusers.models.autoencoders.vae import (DecoderOutput,
12
- DiagonalGaussianDistribution)
13
- from diffusers.models.modeling_outputs import AutoencoderKLOutput
14
- from diffusers.models.modeling_utils import ModelMixin
15
- from diffusers.utils.accelerate_utils import apply_forward_hook
16
- from einops import rearrange
17
-
18
-
19
- CACHE_T = 2
20
-
21
-
22
- class CausalConv3d(nn.Conv3d):
23
- """
24
- Causal 3d convolusion.
25
- """
26
-
27
- def __init__(self, *args, **kwargs):
28
- super().__init__(*args, **kwargs)
29
- self._padding = (
30
- self.padding[2],
31
- self.padding[2],
32
- self.padding[1],
33
- self.padding[1],
34
- 2 * self.padding[0],
35
- 0,
36
- )
37
- self.padding = (0, 0, 0)
38
-
39
- def forward(self, x, cache_x=None):
40
- padding = list(self._padding)
41
- if cache_x is not None and self._padding[4] > 0:
42
- cache_x = cache_x.to(x.device)
43
- x = torch.cat([cache_x, x], dim=2)
44
- padding[4] -= cache_x.shape[2]
45
- x = F.pad(x, padding)
46
-
47
- return super().forward(x)
48
-
49
-
50
- class RMS_norm(nn.Module):
51
-
52
- def __init__(self, dim, channel_first=True, images=True, bias=False):
53
- super().__init__()
54
- broadcastable_dims = (1, 1, 1) if not images else (1, 1)
55
- shape = (dim, *broadcastable_dims) if channel_first else (dim,)
56
-
57
- self.channel_first = channel_first
58
- self.scale = dim**0.5
59
- self.gamma = nn.Parameter(torch.ones(shape))
60
- self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
61
-
62
- def forward(self, x):
63
- return (F.normalize(x, dim=(1 if self.channel_first else -1)) *
64
- self.scale * self.gamma + self.bias)
65
-
66
-
67
- class Upsample(nn.Upsample):
68
-
69
- def forward(self, x):
70
- """
71
- Fix bfloat16 support for nearest neighbor interpolation.
72
- """
73
- return super().forward(x.float()).type_as(x)
74
-
75
-
76
- class Resample(nn.Module):
77
-
78
- def __init__(self, dim, mode):
79
- assert mode in (
80
- "none",
81
- "upsample2d",
82
- "upsample3d",
83
- "downsample2d",
84
- "downsample3d",
85
- )
86
- super().__init__()
87
- self.dim = dim
88
- self.mode = mode
89
-
90
- # layers
91
- if mode == "upsample2d":
92
- self.resample = nn.Sequential(
93
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
94
- nn.Conv2d(dim, dim, 3, padding=1),
95
- )
96
- elif mode == "upsample3d":
97
- self.resample = nn.Sequential(
98
- Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
99
- nn.Conv2d(dim, dim, 3, padding=1),
100
- # nn.Conv2d(dim, dim//2, 3, padding=1)
101
- )
102
- self.time_conv = CausalConv3d(
103
- dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
104
- elif mode == "downsample2d":
105
- self.resample = nn.Sequential(
106
- nn.ZeroPad2d((0, 1, 0, 1)),
107
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
108
- elif mode == "downsample3d":
109
- self.resample = nn.Sequential(
110
- nn.ZeroPad2d((0, 1, 0, 1)),
111
- nn.Conv2d(dim, dim, 3, stride=(2, 2)))
112
- self.time_conv = CausalConv3d(
113
- dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
114
- else:
115
- self.resample = nn.Identity()
116
-
117
- def forward(self, x, feat_cache=None, feat_idx=[0]):
118
- b, c, t, h, w = x.size()
119
- if self.mode == "upsample3d":
120
- if feat_cache is not None:
121
- idx = feat_idx[0]
122
- if feat_cache[idx] is None:
123
- feat_cache[idx] = "Rep"
124
- feat_idx[0] += 1
125
- else:
126
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
127
- if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
128
- feat_cache[idx] != "Rep"):
129
- # cache last frame of last two chunk
130
- cache_x = torch.cat(
131
- [
132
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
133
- cache_x.device),
134
- cache_x,
135
- ],
136
- dim=2,
137
- )
138
- if (cache_x.shape[2] < 2 and feat_cache[idx] is not None and
139
- feat_cache[idx] == "Rep"):
140
- cache_x = torch.cat(
141
- [
142
- torch.zeros_like(cache_x).to(cache_x.device),
143
- cache_x
144
- ],
145
- dim=2,
146
- )
147
- if feat_cache[idx] == "Rep":
148
- x = self.time_conv(x)
149
- else:
150
- x = self.time_conv(x, feat_cache[idx])
151
- feat_cache[idx] = cache_x
152
- feat_idx[0] += 1
153
- x = x.reshape(b, 2, c, t, h, w)
154
- x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
155
- 3)
156
- x = x.reshape(b, c, t * 2, h, w)
157
- t = x.shape[2]
158
- x = rearrange(x, "b c t h w -> (b t) c h w")
159
- x = self.resample(x)
160
- x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
161
-
162
- if self.mode == "downsample3d":
163
- if feat_cache is not None:
164
- idx = feat_idx[0]
165
- if feat_cache[idx] is None:
166
- feat_cache[idx] = x.clone()
167
- feat_idx[0] += 1
168
- else:
169
- cache_x = x[:, :, -1:, :, :].clone()
170
- x = self.time_conv(
171
- torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
172
- feat_cache[idx] = cache_x
173
- feat_idx[0] += 1
174
- return x
175
-
176
- def init_weight(self, conv):
177
- conv_weight = conv.weight.detach().clone()
178
- nn.init.zeros_(conv_weight)
179
- c1, c2, t, h, w = conv_weight.size()
180
- one_matrix = torch.eye(c1, c2)
181
- init_matrix = one_matrix
182
- nn.init.zeros_(conv_weight)
183
- conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
184
- conv.weight = nn.Parameter(conv_weight)
185
- nn.init.zeros_(conv.bias.data)
186
-
187
- def init_weight2(self, conv):
188
- conv_weight = conv.weight.data.detach().clone()
189
- nn.init.zeros_(conv_weight)
190
- c1, c2, t, h, w = conv_weight.size()
191
- init_matrix = torch.eye(c1 // 2, c2)
192
- conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
193
- conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
194
- conv.weight = nn.Parameter(conv_weight)
195
- nn.init.zeros_(conv.bias.data)
196
-
197
-
198
- class ResidualBlock(nn.Module):
199
-
200
- def __init__(self, in_dim, out_dim, dropout=0.0):
201
- super().__init__()
202
- self.in_dim = in_dim
203
- self.out_dim = out_dim
204
-
205
- # layers
206
- self.residual = nn.Sequential(
207
- RMS_norm(in_dim, images=False),
208
- nn.SiLU(),
209
- CausalConv3d(in_dim, out_dim, 3, padding=1),
210
- RMS_norm(out_dim, images=False),
211
- nn.SiLU(),
212
- nn.Dropout(dropout),
213
- CausalConv3d(out_dim, out_dim, 3, padding=1),
214
- )
215
- self.shortcut = (
216
- CausalConv3d(in_dim, out_dim, 1)
217
- if in_dim != out_dim else nn.Identity())
218
-
219
- def forward(self, x, feat_cache=None, feat_idx=[0]):
220
- h = self.shortcut(x)
221
- for layer in self.residual:
222
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
223
- idx = feat_idx[0]
224
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
225
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
226
- # cache last frame of last two chunk
227
- cache_x = torch.cat(
228
- [
229
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
230
- cache_x.device),
231
- cache_x,
232
- ],
233
- dim=2,
234
- )
235
- x = layer(x, feat_cache[idx])
236
- feat_cache[idx] = cache_x
237
- feat_idx[0] += 1
238
- else:
239
- x = layer(x)
240
- return x + h
241
-
242
-
243
- class AttentionBlock(nn.Module):
244
- """
245
- Causal self-attention with a single head.
246
- """
247
-
248
- def __init__(self, dim):
249
- super().__init__()
250
- self.dim = dim
251
-
252
- # layers
253
- self.norm = RMS_norm(dim)
254
- self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
255
- self.proj = nn.Conv2d(dim, dim, 1)
256
-
257
- # zero out the last layer params
258
- nn.init.zeros_(self.proj.weight)
259
-
260
- def forward(self, x):
261
- identity = x
262
- b, c, t, h, w = x.size()
263
- x = rearrange(x, "b c t h w -> (b t) c h w")
264
- x = self.norm(x)
265
- # compute query, key, value
266
- q, k, v = (
267
- self.to_qkv(x).reshape(b * t, 1, c * 3,
268
- -1).permute(0, 1, 3,
269
- 2).contiguous().chunk(3, dim=-1))
270
-
271
- # apply attention
272
- x = F.scaled_dot_product_attention(
273
- q,
274
- k,
275
- v,
276
- )
277
- x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
278
-
279
- # output
280
- x = self.proj(x)
281
- x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
282
- return x + identity
283
-
284
-
285
- def patchify(x, patch_size):
286
- if patch_size == 1:
287
- return x
288
- if x.dim() == 4:
289
- x = rearrange(
290
- x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size, r=patch_size)
291
- elif x.dim() == 5:
292
- x = rearrange(
293
- x,
294
- "b c f (h q) (w r) -> b (c r q) f h w",
295
- q=patch_size,
296
- r=patch_size,
297
- )
298
- else:
299
- raise ValueError(f"Invalid input shape: {x.shape}")
300
-
301
- return x
302
-
303
-
304
- def unpatchify(x, patch_size):
305
- if patch_size == 1:
306
- return x
307
-
308
- if x.dim() == 4:
309
- x = rearrange(
310
- x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size, r=patch_size)
311
- elif x.dim() == 5:
312
- x = rearrange(
313
- x,
314
- "b (c r q) f h w -> b c f (h q) (w r)",
315
- q=patch_size,
316
- r=patch_size,
317
- )
318
- return x
319
-
320
-
321
- class AvgDown3D(nn.Module):
322
-
323
- def __init__(
324
- self,
325
- in_channels,
326
- out_channels,
327
- factor_t,
328
- factor_s=1,
329
- ):
330
- super().__init__()
331
- self.in_channels = in_channels
332
- self.out_channels = out_channels
333
- self.factor_t = factor_t
334
- self.factor_s = factor_s
335
- self.factor = self.factor_t * self.factor_s * self.factor_s
336
-
337
- assert in_channels * self.factor % out_channels == 0
338
- self.group_size = in_channels * self.factor // out_channels
339
-
340
- def forward(self, x: torch.Tensor) -> torch.Tensor:
341
- pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
342
- pad = (0, 0, 0, 0, pad_t, 0)
343
- x = F.pad(x, pad)
344
- B, C, T, H, W = x.shape
345
- x = x.view(
346
- B,
347
- C,
348
- T // self.factor_t,
349
- self.factor_t,
350
- H // self.factor_s,
351
- self.factor_s,
352
- W // self.factor_s,
353
- self.factor_s,
354
- )
355
- x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
356
- x = x.view(
357
- B,
358
- C * self.factor,
359
- T // self.factor_t,
360
- H // self.factor_s,
361
- W // self.factor_s,
362
- )
363
- x = x.view(
364
- B,
365
- self.out_channels,
366
- self.group_size,
367
- T // self.factor_t,
368
- H // self.factor_s,
369
- W // self.factor_s,
370
- )
371
- x = x.mean(dim=2)
372
- return x
373
-
374
-
375
- class DupUp3D(nn.Module):
376
-
377
- def __init__(
378
- self,
379
- in_channels: int,
380
- out_channels: int,
381
- factor_t,
382
- factor_s=1,
383
- ):
384
- super().__init__()
385
- self.in_channels = in_channels
386
- self.out_channels = out_channels
387
-
388
- self.factor_t = factor_t
389
- self.factor_s = factor_s
390
- self.factor = self.factor_t * self.factor_s * self.factor_s
391
-
392
- assert out_channels * self.factor % in_channels == 0
393
- self.repeats = out_channels * self.factor // in_channels
394
-
395
- def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
396
- x = x.repeat_interleave(self.repeats, dim=1)
397
- x = x.view(
398
- x.size(0),
399
- self.out_channels,
400
- self.factor_t,
401
- self.factor_s,
402
- self.factor_s,
403
- x.size(2),
404
- x.size(3),
405
- x.size(4),
406
- )
407
- x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
408
- x = x.view(
409
- x.size(0),
410
- self.out_channels,
411
- x.size(2) * self.factor_t,
412
- x.size(4) * self.factor_s,
413
- x.size(6) * self.factor_s,
414
- )
415
- if first_chunk:
416
- x = x[:, :, self.factor_t - 1:, :, :]
417
- return x
418
-
419
-
420
- class Down_ResidualBlock(nn.Module):
421
-
422
- def __init__(self,
423
- in_dim,
424
- out_dim,
425
- dropout,
426
- mult,
427
- temperal_downsample=False,
428
- down_flag=False):
429
- super().__init__()
430
-
431
- # Shortcut path with downsample
432
- self.avg_shortcut = AvgDown3D(
433
- in_dim,
434
- out_dim,
435
- factor_t=2 if temperal_downsample else 1,
436
- factor_s=2 if down_flag else 1,
437
- )
438
-
439
- # Main path with residual blocks and downsample
440
- downsamples = []
441
- for _ in range(mult):
442
- downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
443
- in_dim = out_dim
444
-
445
- # Add the final downsample block
446
- if down_flag:
447
- mode = "downsample3d" if temperal_downsample else "downsample2d"
448
- downsamples.append(Resample(out_dim, mode=mode))
449
-
450
- self.downsamples = nn.Sequential(*downsamples)
451
-
452
- def forward(self, x, feat_cache=None, feat_idx=[0]):
453
- x_copy = x.clone()
454
- for module in self.downsamples:
455
- x = module(x, feat_cache, feat_idx)
456
-
457
- return x + self.avg_shortcut(x_copy)
458
-
459
-
460
- class Up_ResidualBlock(nn.Module):
461
-
462
- def __init__(self,
463
- in_dim,
464
- out_dim,
465
- dropout,
466
- mult,
467
- temperal_upsample=False,
468
- up_flag=False):
469
- super().__init__()
470
- # Shortcut path with upsample
471
- if up_flag:
472
- self.avg_shortcut = DupUp3D(
473
- in_dim,
474
- out_dim,
475
- factor_t=2 if temperal_upsample else 1,
476
- factor_s=2 if up_flag else 1,
477
- )
478
- else:
479
- self.avg_shortcut = None
480
-
481
- # Main path with residual blocks and upsample
482
- upsamples = []
483
- for _ in range(mult):
484
- upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
485
- in_dim = out_dim
486
-
487
- # Add the final upsample block
488
- if up_flag:
489
- mode = "upsample3d" if temperal_upsample else "upsample2d"
490
- upsamples.append(Resample(out_dim, mode=mode))
491
-
492
- self.upsamples = nn.Sequential(*upsamples)
493
-
494
- def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
495
- x_main = x.clone()
496
- for module in self.upsamples:
497
- x_main = module(x_main, feat_cache, feat_idx)
498
- if self.avg_shortcut is not None:
499
- x_shortcut = self.avg_shortcut(x, first_chunk)
500
- return x_main + x_shortcut
501
- else:
502
- return x_main
503
-
504
-
505
- class Encoder3d(nn.Module):
506
-
507
- def __init__(
508
- self,
509
- dim=128,
510
- z_dim=4,
511
- dim_mult=[1, 2, 4, 4],
512
- num_res_blocks=2,
513
- attn_scales=[],
514
- temperal_downsample=[True, True, False],
515
- dropout=0.0,
516
- ):
517
- super().__init__()
518
- self.dim = dim
519
- self.z_dim = z_dim
520
- self.dim_mult = dim_mult
521
- self.num_res_blocks = num_res_blocks
522
- self.attn_scales = attn_scales
523
- self.temperal_downsample = temperal_downsample
524
-
525
- # dimensions
526
- dims = [dim * u for u in [1] + dim_mult]
527
- scale = 1.0
528
-
529
- # init block
530
- self.conv1 = CausalConv3d(12, dims[0], 3, padding=1)
531
-
532
- # downsample blocks
533
- downsamples = []
534
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
535
- t_down_flag = (
536
- temperal_downsample[i]
537
- if i < len(temperal_downsample) else False)
538
- downsamples.append(
539
- Down_ResidualBlock(
540
- in_dim=in_dim,
541
- out_dim=out_dim,
542
- dropout=dropout,
543
- mult=num_res_blocks,
544
- temperal_downsample=t_down_flag,
545
- down_flag=i != len(dim_mult) - 1,
546
- ))
547
- scale /= 2.0
548
- self.downsamples = nn.Sequential(*downsamples)
549
-
550
- # middle blocks
551
- self.middle = nn.Sequential(
552
- ResidualBlock(out_dim, out_dim, dropout),
553
- AttentionBlock(out_dim),
554
- ResidualBlock(out_dim, out_dim, dropout),
555
- )
556
-
557
- # # output blocks
558
- self.head = nn.Sequential(
559
- RMS_norm(out_dim, images=False),
560
- nn.SiLU(),
561
- CausalConv3d(out_dim, z_dim, 3, padding=1),
562
- )
563
-
564
- def forward(self, x, feat_cache=None, feat_idx=[0]):
565
-
566
- if feat_cache is not None:
567
- idx = feat_idx[0]
568
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
569
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
570
- cache_x = torch.cat(
571
- [
572
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
573
- cache_x.device),
574
- cache_x,
575
- ],
576
- dim=2,
577
- )
578
- x = self.conv1(x, feat_cache[idx])
579
- feat_cache[idx] = cache_x
580
- feat_idx[0] += 1
581
- else:
582
- x = self.conv1(x)
583
-
584
- ## downsamples
585
- for layer in self.downsamples:
586
- if feat_cache is not None:
587
- x = layer(x, feat_cache, feat_idx)
588
- else:
589
- x = layer(x)
590
-
591
- ## middle
592
- for layer in self.middle:
593
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
594
- x = layer(x, feat_cache, feat_idx)
595
- else:
596
- x = layer(x)
597
-
598
- ## head
599
- for layer in self.head:
600
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
601
- idx = feat_idx[0]
602
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
603
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
604
- cache_x = torch.cat(
605
- [
606
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
607
- cache_x.device),
608
- cache_x,
609
- ],
610
- dim=2,
611
- )
612
- x = layer(x, feat_cache[idx])
613
- feat_cache[idx] = cache_x
614
- feat_idx[0] += 1
615
- else:
616
- x = layer(x)
617
-
618
- return x
619
-
620
-
621
- class Decoder3d(nn.Module):
622
-
623
- def __init__(
624
- self,
625
- dim=128,
626
- z_dim=4,
627
- dim_mult=[1, 2, 4, 4],
628
- num_res_blocks=2,
629
- attn_scales=[],
630
- temperal_upsample=[False, True, True],
631
- dropout=0.0,
632
- ):
633
- super().__init__()
634
- self.dim = dim
635
- self.z_dim = z_dim
636
- self.dim_mult = dim_mult
637
- self.num_res_blocks = num_res_blocks
638
- self.attn_scales = attn_scales
639
- self.temperal_upsample = temperal_upsample
640
-
641
- # dimensions
642
- dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
643
- scale = 1.0 / 2**(len(dim_mult) - 2)
644
- # init block
645
- self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
646
-
647
- # middle blocks
648
- self.middle = nn.Sequential(
649
- ResidualBlock(dims[0], dims[0], dropout),
650
- AttentionBlock(dims[0]),
651
- ResidualBlock(dims[0], dims[0], dropout),
652
- )
653
-
654
- # upsample blocks
655
- upsamples = []
656
- for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
657
- t_up_flag = temperal_upsample[i] if i < len(
658
- temperal_upsample) else False
659
- upsamples.append(
660
- Up_ResidualBlock(
661
- in_dim=in_dim,
662
- out_dim=out_dim,
663
- dropout=dropout,
664
- mult=num_res_blocks + 1,
665
- temperal_upsample=t_up_flag,
666
- up_flag=i != len(dim_mult) - 1,
667
- ))
668
- self.upsamples = nn.Sequential(*upsamples)
669
-
670
- # output blocks
671
- self.head = nn.Sequential(
672
- RMS_norm(out_dim, images=False),
673
- nn.SiLU(),
674
- CausalConv3d(out_dim, 12, 3, padding=1),
675
- )
676
-
677
- def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
678
- if feat_cache is not None:
679
- idx = feat_idx[0]
680
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
681
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
682
- cache_x = torch.cat(
683
- [
684
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
685
- cache_x.device),
686
- cache_x,
687
- ],
688
- dim=2,
689
- )
690
- x = self.conv1(x, feat_cache[idx])
691
- feat_cache[idx] = cache_x
692
- feat_idx[0] += 1
693
- else:
694
- x = self.conv1(x)
695
-
696
- for layer in self.middle:
697
- if isinstance(layer, ResidualBlock) and feat_cache is not None:
698
- x = layer(x, feat_cache, feat_idx)
699
- else:
700
- x = layer(x)
701
-
702
- ## upsamples
703
- for layer in self.upsamples:
704
- if feat_cache is not None:
705
- x = layer(x, feat_cache, feat_idx, first_chunk)
706
- else:
707
- x = layer(x)
708
-
709
- ## head
710
- for layer in self.head:
711
- if isinstance(layer, CausalConv3d) and feat_cache is not None:
712
- idx = feat_idx[0]
713
- cache_x = x[:, :, -CACHE_T:, :, :].clone()
714
- if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
715
- cache_x = torch.cat(
716
- [
717
- feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
718
- cache_x.device),
719
- cache_x,
720
- ],
721
- dim=2,
722
- )
723
- x = layer(x, feat_cache[idx])
724
- feat_cache[idx] = cache_x
725
- feat_idx[0] += 1
726
- else:
727
- x = layer(x)
728
- return x
729
-
730
-
731
- def count_conv3d(model):
732
- count = 0
733
- for m in model.modules():
734
- if isinstance(m, CausalConv3d):
735
- count += 1
736
- return count
737
-
738
-
739
- class AutoencoderKLWan2_2_(nn.Module):
740
-
741
- def __init__(
742
- self,
743
- dim=160,
744
- dec_dim=256,
745
- z_dim=16,
746
- dim_mult=[1, 2, 4, 4],
747
- num_res_blocks=2,
748
- attn_scales=[],
749
- temperal_downsample=[True, True, False],
750
- dropout=0.0,
751
- ):
752
- super().__init__()
753
- self.dim = dim
754
- self.z_dim = z_dim
755
- self.dim_mult = dim_mult
756
- self.num_res_blocks = num_res_blocks
757
- self.attn_scales = attn_scales
758
- self.temperal_downsample = temperal_downsample
759
- self.temperal_upsample = temperal_downsample[::-1]
760
-
761
- # modules
762
- self.encoder = Encoder3d(
763
- dim,
764
- z_dim * 2,
765
- dim_mult,
766
- num_res_blocks,
767
- attn_scales,
768
- self.temperal_downsample,
769
- dropout,
770
- )
771
- self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
772
- self.conv2 = CausalConv3d(z_dim, z_dim, 1)
773
- self.decoder = Decoder3d(
774
- dec_dim,
775
- z_dim,
776
- dim_mult,
777
- num_res_blocks,
778
- attn_scales,
779
- self.temperal_upsample,
780
- dropout,
781
- )
782
-
783
- def forward(self, x, scale=[0, 1]):
784
- mu = self.encode(x, scale)
785
- x_recon = self.decode(mu, scale)
786
- return x_recon, mu
787
-
788
- def encode(self, x, scale):
789
- self.clear_cache()
790
- # z: [b,c,t,h,w]
791
- scale = [item.to(x.device, x.dtype) for item in scale]
792
- x = patchify(x, patch_size=2)
793
- t = x.shape[2]
794
- iter_ = 1 + (t - 1) // 4
795
- for i in range(iter_):
796
- self._enc_conv_idx = [0]
797
- if i == 0:
798
- out = self.encoder(
799
- x[:, :, :1, :, :],
800
- feat_cache=self._enc_feat_map,
801
- feat_idx=self._enc_conv_idx,
802
- )
803
- else:
804
- out_ = self.encoder(
805
- x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
806
- feat_cache=self._enc_feat_map,
807
- feat_idx=self._enc_conv_idx,
808
- )
809
- out = torch.cat([out, out_], 2)
810
- mu, log_var = self.conv1(out).chunk(2, dim=1)
811
- if isinstance(scale[0], torch.Tensor):
812
- mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
813
- 1, self.z_dim, 1, 1, 1)
814
- else:
815
- mu = (mu - scale[0]) * scale[1]
816
- x = torch.cat([mu, log_var], dim = 1)
817
- self.clear_cache()
818
- return x
819
-
820
- def decode(self, z, scale):
821
- self.clear_cache()
822
- # z: [b,c,t,h,w]
823
- scale = [item.to(z.device, z.dtype) for item in scale]
824
- if isinstance(scale[0], torch.Tensor):
825
- z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
826
- 1, self.z_dim, 1, 1, 1)
827
- else:
828
- z = z / scale[1] + scale[0]
829
- iter_ = z.shape[2]
830
- x = self.conv2(z)
831
- for i in range(iter_):
832
- self._conv_idx = [0]
833
- if i == 0:
834
- out = self.decoder(
835
- x[:, :, i:i + 1, :, :],
836
- feat_cache=self._feat_map,
837
- feat_idx=self._conv_idx,
838
- first_chunk=True,
839
- )
840
- else:
841
- out_ = self.decoder(
842
- x[:, :, i:i + 1, :, :],
843
- feat_cache=self._feat_map,
844
- feat_idx=self._conv_idx,
845
- )
846
- out = torch.cat([out, out_], 2)
847
- out = unpatchify(out, patch_size=2)
848
- self.clear_cache()
849
- return out
850
-
851
- def reparameterize(self, mu, log_var):
852
- std = torch.exp(0.5 * log_var)
853
- eps = torch.randn_like(std)
854
- return eps * std + mu
855
-
856
- def sample(self, imgs, deterministic=False):
857
- mu, log_var = self.encode(imgs)
858
- if deterministic:
859
- return mu
860
- std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
861
- return mu + std * torch.randn_like(std)
862
-
863
- def clear_cache(self):
864
- self._conv_num = count_conv3d(self.decoder)
865
- self._conv_idx = [0]
866
- self._feat_map = [None] * self._conv_num
867
- # cache encode
868
- self._enc_conv_num = count_conv3d(self.encoder)
869
- self._enc_conv_idx = [0]
870
- self._enc_feat_map = [None] * self._enc_conv_num
871
-
872
-
873
- def _video_vae(pretrained_path=None, z_dim=16, dim=160, device="cpu", **kwargs):
874
- # params
875
- cfg = dict(
876
- dim=dim,
877
- z_dim=z_dim,
878
- dim_mult=[1, 2, 4, 4],
879
- num_res_blocks=2,
880
- attn_scales=[],
881
- temperal_downsample=[True, True, True],
882
- dropout=0.0,
883
- )
884
- cfg.update(**kwargs)
885
-
886
- # init model
887
- model = AutoencoderKLWan2_2_(**cfg)
888
-
889
- return model
890
-
891
-
892
- class AutoencoderKLWan3_8(ModelMixin, ConfigMixin, FromOriginalModelMixin):
893
- _supports_gradient_checkpointing = True
894
-
895
- @register_to_config
896
- def __init__(
897
- self,
898
- latent_channels=48,
899
- c_dim=160,
900
- vae_pth=None,
901
- dim_mult=[1, 2, 4, 4],
902
- temperal_downsample=[False, True, True],
903
- temporal_compression_ratio=4,
904
- spatial_compression_ratio=8
905
- ):
906
- super().__init__()
907
- mean = torch.tensor(
908
- [
909
- -0.2289,
910
- -0.0052,
911
- -0.1323,
912
- -0.2339,
913
- -0.2799,
914
- 0.0174,
915
- 0.1838,
916
- 0.1557,
917
- -0.1382,
918
- 0.0542,
919
- 0.2813,
920
- 0.0891,
921
- 0.1570,
922
- -0.0098,
923
- 0.0375,
924
- -0.1825,
925
- -0.2246,
926
- -0.1207,
927
- -0.0698,
928
- 0.5109,
929
- 0.2665,
930
- -0.2108,
931
- -0.2158,
932
- 0.2502,
933
- -0.2055,
934
- -0.0322,
935
- 0.1109,
936
- 0.1567,
937
- -0.0729,
938
- 0.0899,
939
- -0.2799,
940
- -0.1230,
941
- -0.0313,
942
- -0.1649,
943
- 0.0117,
944
- 0.0723,
945
- -0.2839,
946
- -0.2083,
947
- -0.0520,
948
- 0.3748,
949
- 0.0152,
950
- 0.1957,
951
- 0.1433,
952
- -0.2944,
953
- 0.3573,
954
- -0.0548,
955
- -0.1681,
956
- -0.0667,
957
- ], dtype=torch.float32
958
- )
959
- std = torch.tensor(
960
- [
961
- 0.4765,
962
- 1.0364,
963
- 0.4514,
964
- 1.1677,
965
- 0.5313,
966
- 0.4990,
967
- 0.4818,
968
- 0.5013,
969
- 0.8158,
970
- 1.0344,
971
- 0.5894,
972
- 1.0901,
973
- 0.6885,
974
- 0.6165,
975
- 0.8454,
976
- 0.4978,
977
- 0.5759,
978
- 0.3523,
979
- 0.7135,
980
- 0.6804,
981
- 0.5833,
982
- 1.4146,
983
- 0.8986,
984
- 0.5659,
985
- 0.7069,
986
- 0.5338,
987
- 0.4889,
988
- 0.4917,
989
- 0.4069,
990
- 0.4999,
991
- 0.6866,
992
- 0.4093,
993
- 0.5709,
994
- 0.6065,
995
- 0.6415,
996
- 0.4944,
997
- 0.5726,
998
- 1.2042,
999
- 0.5458,
1000
- 1.6887,
1001
- 0.3971,
1002
- 1.0600,
1003
- 0.3943,
1004
- 0.5537,
1005
- 0.5444,
1006
- 0.4089,
1007
- 0.7468,
1008
- 0.7744,
1009
- ], dtype=torch.float32
1010
- )
1011
- self.scale = [mean, 1.0 / std]
1012
-
1013
- # init model
1014
- self.model = _video_vae(
1015
- pretrained_path=vae_pth,
1016
- z_dim=latent_channels,
1017
- dim=c_dim,
1018
- dim_mult=dim_mult,
1019
- temperal_downsample=temperal_downsample,
1020
- ).eval().requires_grad_(False)
1021
-
1022
- self.gradient_checkpointing = False
1023
-
1024
- def _set_gradient_checkpointing(self, *args, **kwargs):
1025
- if "value" in kwargs:
1026
- self.gradient_checkpointing = kwargs["value"]
1027
- elif "enable" in kwargs:
1028
- self.gradient_checkpointing = kwargs["enable"]
1029
- else:
1030
- raise ValueError("Invalid set gradient checkpointing")
1031
-
1032
- def _encode(self, x: torch.Tensor) -> torch.Tensor:
1033
- x = [
1034
- self.model.encode(u.unsqueeze(0), self.scale).squeeze(0)
1035
- for u in x
1036
- ]
1037
- x = torch.stack(x)
1038
- return x
1039
-
1040
- @apply_forward_hook
1041
- def encode(
1042
- self, x: torch.Tensor, return_dict: bool = True
1043
- ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
1044
- h = self._encode(x)
1045
-
1046
- posterior = DiagonalGaussianDistribution(h)
1047
-
1048
- if not return_dict:
1049
- return (posterior,)
1050
- return AutoencoderKLOutput(latent_dist=posterior)
1051
-
1052
- def _decode(self, zs):
1053
- dec = [
1054
- self.model.decode(u.unsqueeze(0), self.scale).clamp_(-1, 1).squeeze(0)
1055
- for u in zs
1056
- ]
1057
- dec = torch.stack(dec)
1058
-
1059
- return DecoderOutput(sample=dec)
1060
-
1061
- @apply_forward_hook
1062
- def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1063
- decoded = self._decode(z).sample
1064
-
1065
- if not return_dict:
1066
- return (decoded,)
1067
- return DecoderOutput(sample=decoded)
1068
-
1069
- @classmethod
1070
- def from_pretrained(cls, pretrained_model_path, additional_kwargs={}):
1071
- def filter_kwargs(cls, kwargs):
1072
- import inspect
1073
- sig = inspect.signature(cls.__init__)
1074
- valid_params = set(sig.parameters.keys()) - {'self', 'cls'}
1075
- filtered_kwargs = {k: v for k, v in kwargs.items() if k in valid_params}
1076
- return filtered_kwargs
1077
-
1078
- model = cls(**filter_kwargs(cls, additional_kwargs))
1079
- if pretrained_model_path.endswith(".safetensors"):
1080
- from safetensors.torch import load_file, safe_open
1081
- state_dict = load_file(pretrained_model_path)
1082
- else:
1083
- state_dict = torch.load(pretrained_model_path, map_location="cpu")
1084
- tmp_state_dict = {}
1085
- for key in state_dict:
1086
- tmp_state_dict["model." + key] = state_dict[key]
1087
- state_dict = tmp_state_dict
1088
- m, u = model.load_state_dict(state_dict, strict=False)
1089
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
1090
- print(m, u)
1091
- return model
videox_fun/models/wan_xlm_roberta.py DELETED
@@ -1,170 +0,0 @@
1
- # Modified from transformers.models.xlm_roberta.modeling_xlm_roberta
2
- # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
3
- import torch
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- __all__ = ['XLMRoberta', 'xlm_roberta_large']
8
-
9
-
10
- class SelfAttention(nn.Module):
11
-
12
- def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
13
- assert dim % num_heads == 0
14
- super().__init__()
15
- self.dim = dim
16
- self.num_heads = num_heads
17
- self.head_dim = dim // num_heads
18
- self.eps = eps
19
-
20
- # layers
21
- self.q = nn.Linear(dim, dim)
22
- self.k = nn.Linear(dim, dim)
23
- self.v = nn.Linear(dim, dim)
24
- self.o = nn.Linear(dim, dim)
25
- self.dropout = nn.Dropout(dropout)
26
-
27
- def forward(self, x, mask):
28
- """
29
- x: [B, L, C].
30
- """
31
- b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
32
-
33
- # compute query, key, value
34
- q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
35
- k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
36
- v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
37
-
38
- # compute attention
39
- p = self.dropout.p if self.training else 0.0
40
- x = F.scaled_dot_product_attention(q, k, v, mask, p)
41
- x = x.permute(0, 2, 1, 3).reshape(b, s, c)
42
-
43
- # output
44
- x = self.o(x)
45
- x = self.dropout(x)
46
- return x
47
-
48
-
49
- class AttentionBlock(nn.Module):
50
-
51
- def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
52
- super().__init__()
53
- self.dim = dim
54
- self.num_heads = num_heads
55
- self.post_norm = post_norm
56
- self.eps = eps
57
-
58
- # layers
59
- self.attn = SelfAttention(dim, num_heads, dropout, eps)
60
- self.norm1 = nn.LayerNorm(dim, eps=eps)
61
- self.ffn = nn.Sequential(
62
- nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim),
63
- nn.Dropout(dropout))
64
- self.norm2 = nn.LayerNorm(dim, eps=eps)
65
-
66
- def forward(self, x, mask):
67
- if self.post_norm:
68
- x = self.norm1(x + self.attn(x, mask))
69
- x = self.norm2(x + self.ffn(x))
70
- else:
71
- x = x + self.attn(self.norm1(x), mask)
72
- x = x + self.ffn(self.norm2(x))
73
- return x
74
-
75
-
76
- class XLMRoberta(nn.Module):
77
- """
78
- XLMRobertaModel with no pooler and no LM head.
79
- """
80
-
81
- def __init__(self,
82
- vocab_size=250002,
83
- max_seq_len=514,
84
- type_size=1,
85
- pad_id=1,
86
- dim=1024,
87
- num_heads=16,
88
- num_layers=24,
89
- post_norm=True,
90
- dropout=0.1,
91
- eps=1e-5):
92
- super().__init__()
93
- self.vocab_size = vocab_size
94
- self.max_seq_len = max_seq_len
95
- self.type_size = type_size
96
- self.pad_id = pad_id
97
- self.dim = dim
98
- self.num_heads = num_heads
99
- self.num_layers = num_layers
100
- self.post_norm = post_norm
101
- self.eps = eps
102
-
103
- # embeddings
104
- self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
105
- self.type_embedding = nn.Embedding(type_size, dim)
106
- self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
107
- self.dropout = nn.Dropout(dropout)
108
-
109
- # blocks
110
- self.blocks = nn.ModuleList([
111
- AttentionBlock(dim, num_heads, post_norm, dropout, eps)
112
- for _ in range(num_layers)
113
- ])
114
-
115
- # norm layer
116
- self.norm = nn.LayerNorm(dim, eps=eps)
117
-
118
- def forward(self, ids):
119
- """
120
- ids: [B, L] of torch.LongTensor.
121
- """
122
- b, s = ids.shape
123
- mask = ids.ne(self.pad_id).long()
124
-
125
- # embeddings
126
- x = self.token_embedding(ids) + \
127
- self.type_embedding(torch.zeros_like(ids)) + \
128
- self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
129
- if self.post_norm:
130
- x = self.norm(x)
131
- x = self.dropout(x)
132
-
133
- # blocks
134
- mask = torch.where(
135
- mask.view(b, 1, 1, s).gt(0), 0.0,
136
- torch.finfo(x.dtype).min)
137
- for block in self.blocks:
138
- x = block(x, mask)
139
-
140
- # output
141
- if not self.post_norm:
142
- x = self.norm(x)
143
- return x
144
-
145
-
146
- def xlm_roberta_large(pretrained=False,
147
- return_tokenizer=False,
148
- device='cpu',
149
- **kwargs):
150
- """
151
- XLMRobertaLarge adapted from Huggingface.
152
- """
153
- # params
154
- cfg = dict(
155
- vocab_size=250002,
156
- max_seq_len=514,
157
- type_size=1,
158
- pad_id=1,
159
- dim=1024,
160
- num_heads=16,
161
- num_layers=24,
162
- post_norm=True,
163
- dropout=0.1,
164
- eps=1e-5)
165
- cfg.update(**kwargs)
166
-
167
- # init a model on device
168
- with torch.device(device):
169
- model = XLMRoberta(**cfg)
170
- return model
videox_fun/models/z_image_transformer2d.py DELETED
@@ -1,1050 +0,0 @@
1
- # Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import glob
16
- import inspect
17
- import json
18
- import os
19
- import math
20
- from typing import Any, Dict, List, Optional, Tuple, Union
21
-
22
- import torch
23
- import torch.nn as nn
24
- import torch.nn.functional as F
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
- from torch.nn.utils.rnn import pad_sequence
29
-
30
- from diffusers.configuration_utils import ConfigMixin, register_to_config
31
- from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
32
- from diffusers.models.attention_processor import Attention
33
- from diffusers.models.modeling_utils import ModelMixin
34
- from diffusers.models.normalization import RMSNorm
35
- from diffusers.utils.torch_utils import maybe_allow_in_graph
36
- from diffusers.models.attention_processor import Attention, AttentionProcessor
37
- from diffusers.utils import (USE_PEFT_BACKEND, is_torch_version, logging,
38
- scale_lora_layers, unscale_lora_layers)
39
-
40
- from .attention_utils import attention
41
- from ..dist import (ZMultiGPUsSingleStreamAttnProcessor, get_sequence_parallel_rank,
42
- get_sequence_parallel_world_size, get_sp_group)
43
-
44
-
45
- ADALN_EMBED_DIM = 256
46
- SEQ_MULTI_OF = 32
47
-
48
-
49
- class TimestepEmbedder(nn.Module):
50
- def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
51
- super().__init__()
52
- if mid_size is None:
53
- mid_size = out_size
54
- self.mlp = nn.Sequential(
55
- nn.Linear(
56
- frequency_embedding_size,
57
- mid_size,
58
- bias=True,
59
- ),
60
- nn.SiLU(),
61
- nn.Linear(
62
- mid_size,
63
- out_size,
64
- bias=True,
65
- ),
66
- )
67
-
68
- self.frequency_embedding_size = frequency_embedding_size
69
-
70
- @staticmethod
71
- def timestep_embedding(t, dim, max_period=10000):
72
- with torch.amp.autocast("cuda", enabled=False):
73
- half = dim // 2
74
- freqs = torch.exp(
75
- -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
76
- )
77
- args = t[:, None].float() * freqs[None]
78
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
79
- if dim % 2:
80
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
81
- return embedding
82
-
83
- def forward(self, t):
84
- t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
85
- weight_dtype = self.mlp[0].weight.dtype
86
- if weight_dtype.is_floating_point:
87
- t_freq = t_freq.to(weight_dtype)
88
- t_emb = self.mlp(t_freq)
89
- return t_emb
90
-
91
-
92
- class ZSingleStreamAttnProcessor:
93
- """
94
- Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
95
- original Z-ImageAttention module.
96
- """
97
-
98
- _attention_backend = None
99
- _parallel_config = None
100
-
101
- def __init__(self):
102
- if not hasattr(F, "scaled_dot_product_attention"):
103
- raise ImportError(
104
- "ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
105
- )
106
-
107
- def __call__(
108
- self,
109
- attn: Attention,
110
- hidden_states: torch.Tensor,
111
- encoder_hidden_states: Optional[torch.Tensor] = None,
112
- attention_mask: Optional[torch.Tensor] = None,
113
- freqs_cis: Optional[torch.Tensor] = None,
114
- ) -> torch.Tensor:
115
- query = attn.to_q(hidden_states)
116
- key = attn.to_k(hidden_states)
117
- value = attn.to_v(hidden_states)
118
-
119
- query = query.unflatten(-1, (attn.heads, -1))
120
- key = key.unflatten(-1, (attn.heads, -1))
121
- value = value.unflatten(-1, (attn.heads, -1))
122
-
123
- # Apply Norms
124
- if attn.norm_q is not None:
125
- query = attn.norm_q(query)
126
- if attn.norm_k is not None:
127
- key = attn.norm_k(key)
128
-
129
- # Apply RoPE
130
- def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
131
- with torch.amp.autocast("cuda", enabled=False):
132
- x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
133
- freqs_cis = freqs_cis.unsqueeze(2)
134
- x_out = torch.view_as_real(x * freqs_cis).flatten(3)
135
- return x_out.type_as(x_in) # todo
136
-
137
- if freqs_cis is not None:
138
- query = apply_rotary_emb(query, freqs_cis)
139
- key = apply_rotary_emb(key, freqs_cis)
140
-
141
- # Cast to correct dtype
142
- dtype = query.dtype
143
- query, key = query.to(dtype), key.to(dtype)
144
-
145
- # From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
146
- if attention_mask is not None and attention_mask.ndim == 2:
147
- attention_mask = attention_mask[:, None, None, :]
148
-
149
- # Compute joint attention
150
- hidden_states = attention(
151
- query,
152
- key,
153
- value,
154
- attn_mask=attention_mask
155
- )
156
-
157
- # Reshape back
158
- hidden_states = hidden_states.flatten(2, 3)
159
- hidden_states = hidden_states.to(dtype)
160
-
161
- output = attn.to_out[0](hidden_states)
162
- if len(attn.to_out) > 1: # dropout
163
- output = attn.to_out[1](output)
164
-
165
- return output
166
-
167
-
168
- class FeedForward(nn.Module):
169
- def __init__(self, dim: int, hidden_dim: int):
170
- super().__init__()
171
- self.w1 = nn.Linear(dim, hidden_dim, bias=False)
172
- self.w2 = nn.Linear(hidden_dim, dim, bias=False)
173
- self.w3 = nn.Linear(dim, hidden_dim, bias=False)
174
-
175
- def _forward_silu_gating(self, x1, x3):
176
- return F.silu(x1) * x3
177
-
178
- def forward(self, x):
179
- return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
180
-
181
-
182
- @maybe_allow_in_graph
183
- class ZImageTransformerBlock(nn.Module):
184
- def __init__(
185
- self,
186
- layer_id: int,
187
- dim: int,
188
- n_heads: int,
189
- n_kv_heads: int,
190
- norm_eps: float,
191
- qk_norm: bool,
192
- modulation=True,
193
- ):
194
- super().__init__()
195
- self.dim = dim
196
- self.head_dim = dim // n_heads
197
-
198
- # Refactored to use diffusers Attention with custom processor
199
- # Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
200
- self.attention = Attention(
201
- query_dim=dim,
202
- cross_attention_dim=None,
203
- dim_head=dim // n_heads,
204
- heads=n_heads,
205
- qk_norm="rms_norm" if qk_norm else None,
206
- eps=1e-5,
207
- bias=False,
208
- out_bias=False,
209
- processor=ZSingleStreamAttnProcessor(),
210
- )
211
-
212
- self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
213
- self.layer_id = layer_id
214
-
215
- self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
216
- self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
217
-
218
- self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
219
- self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
220
-
221
- self.modulation = modulation
222
- if modulation:
223
- self.adaLN_modulation = nn.Sequential(
224
- nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True),
225
- )
226
-
227
- def forward(
228
- self,
229
- x: torch.Tensor,
230
- attn_mask: torch.Tensor,
231
- freqs_cis: torch.Tensor,
232
- adaln_input: Optional[torch.Tensor] = None,
233
- ):
234
- if self.modulation:
235
- assert adaln_input is not None
236
- scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
237
- gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
238
- scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
239
-
240
- # Attention block
241
- attn_out = self.attention(
242
- self.attention_norm1(x) * scale_msa,
243
- attention_mask=attn_mask,
244
- freqs_cis=freqs_cis,
245
- )
246
- x = x + gate_msa * self.attention_norm2(attn_out)
247
-
248
- # FFN block
249
- x = x + gate_mlp * self.ffn_norm2(
250
- self.feed_forward(
251
- self.ffn_norm1(x) * scale_mlp,
252
- )
253
- )
254
- else:
255
- # Attention block
256
- attn_out = self.attention(
257
- self.attention_norm1(x),
258
- attention_mask=attn_mask,
259
- freqs_cis=freqs_cis,
260
- )
261
- x = x + self.attention_norm2(attn_out)
262
-
263
- # FFN block
264
- x = x + self.ffn_norm2(
265
- self.feed_forward(
266
- self.ffn_norm1(x),
267
- )
268
- )
269
-
270
- return x
271
-
272
-
273
- class FinalLayer(nn.Module):
274
- def __init__(self, hidden_size, out_channels):
275
- super().__init__()
276
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
277
- self.linear = nn.Linear(hidden_size, out_channels, bias=True)
278
-
279
- self.adaLN_modulation = nn.Sequential(
280
- nn.SiLU(),
281
- nn.Linear(min(hidden_size, ADALN_EMBED_DIM), hidden_size, bias=True),
282
- )
283
-
284
- def forward(self, x, c):
285
- scale = 1.0 + self.adaLN_modulation(c)
286
- x = self.norm_final(x) * scale.unsqueeze(1)
287
- x = self.linear(x)
288
- return x
289
-
290
-
291
- class RopeEmbedder:
292
- def __init__(
293
- self,
294
- theta: float = 256.0,
295
- axes_dims: List[int] = (16, 56, 56),
296
- axes_lens: List[int] = (64, 128, 128),
297
- ):
298
- self.theta = theta
299
- self.axes_dims = axes_dims
300
- self.axes_lens = axes_lens
301
- assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
302
- self.freqs_cis = None
303
-
304
- @staticmethod
305
- def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
306
- with torch.device("cpu"):
307
- freqs_cis = []
308
- for i, (d, e) in enumerate(zip(dim, end)):
309
- freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
310
- timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
311
- freqs = torch.outer(timestep, freqs).float()
312
- freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
313
- freqs_cis.append(freqs_cis_i)
314
-
315
- return freqs_cis
316
-
317
- def __call__(self, ids: torch.Tensor):
318
- assert ids.ndim == 2
319
- assert ids.shape[-1] == len(self.axes_dims)
320
- device = ids.device
321
-
322
- if self.freqs_cis is None:
323
- self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
324
- self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
325
- else:
326
- # Ensure freqs_cis are on the same device as ids
327
- if self.freqs_cis[0].device != device:
328
- self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
329
-
330
- result = []
331
- for i in range(len(self.axes_dims)):
332
- index = ids[:, i]
333
- result.append(self.freqs_cis[i][index])
334
- return torch.cat(result, dim=-1)
335
-
336
-
337
- class ZImageTransformer2DModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
338
- _supports_gradient_checkpointing = True
339
- # _no_split_modules = ["ZImageTransformerBlock"]
340
- # _skip_layerwise_casting_patterns = ["t_embedder", "cap_embedder"] # precision sensitive layers
341
-
342
- @register_to_config
343
- def __init__(
344
- self,
345
- all_patch_size=(2,),
346
- all_f_patch_size=(1,),
347
- in_channels=16,
348
- dim=3840,
349
- n_layers=30,
350
- n_refiner_layers=2,
351
- n_heads=30,
352
- n_kv_heads=30,
353
- norm_eps=1e-5,
354
- qk_norm=True,
355
- cap_feat_dim=2560,
356
- rope_theta=256.0,
357
- t_scale=1000.0,
358
- axes_dims=[32, 48, 48],
359
- axes_lens=[1024, 512, 512],
360
- ) -> None:
361
- super().__init__()
362
- self.in_channels = in_channels
363
- self.out_channels = in_channels
364
- self.all_patch_size = all_patch_size
365
- self.all_f_patch_size = all_f_patch_size
366
- self.dim = dim
367
- self.n_heads = n_heads
368
-
369
- self.rope_theta = rope_theta
370
- self.t_scale = t_scale
371
- self.gradient_checkpointing = False
372
-
373
- assert len(all_patch_size) == len(all_f_patch_size)
374
-
375
- all_x_embedder = {}
376
- all_final_layer = {}
377
- for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
378
- x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
379
- all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
380
-
381
- final_layer = FinalLayer(dim, patch_size * patch_size * f_patch_size * self.out_channels)
382
- all_final_layer[f"{patch_size}-{f_patch_size}"] = final_layer
383
-
384
- self.all_x_embedder = nn.ModuleDict(all_x_embedder)
385
- self.all_final_layer = nn.ModuleDict(all_final_layer)
386
- self.noise_refiner = nn.ModuleList(
387
- [
388
- ZImageTransformerBlock(
389
- 1000 + layer_id,
390
- dim,
391
- n_heads,
392
- n_kv_heads,
393
- norm_eps,
394
- qk_norm,
395
- modulation=True,
396
- )
397
- for layer_id in range(n_refiner_layers)
398
- ]
399
- )
400
- self.context_refiner = nn.ModuleList(
401
- [
402
- ZImageTransformerBlock(
403
- layer_id,
404
- dim,
405
- n_heads,
406
- n_kv_heads,
407
- norm_eps,
408
- qk_norm,
409
- modulation=False,
410
- )
411
- for layer_id in range(n_refiner_layers)
412
- ]
413
- )
414
- self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
415
- self.cap_embedder = nn.Sequential(
416
- RMSNorm(cap_feat_dim, eps=norm_eps),
417
- nn.Linear(cap_feat_dim, dim, bias=True),
418
- )
419
-
420
- self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
421
- self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
422
-
423
- self.layers = nn.ModuleList(
424
- [
425
- ZImageTransformerBlock(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm)
426
- for layer_id in range(n_layers)
427
- ]
428
- )
429
- head_dim = dim // n_heads
430
- assert head_dim == sum(axes_dims)
431
- self.axes_dims = axes_dims
432
- self.axes_lens = axes_lens
433
-
434
- self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
435
-
436
- self.sp_world_size = 1
437
- self.sp_world_rank = 0
438
-
439
- def _set_gradient_checkpointing(self, *args, **kwargs):
440
- if "value" in kwargs:
441
- self.gradient_checkpointing = kwargs["value"]
442
- elif "enable" in kwargs:
443
- self.gradient_checkpointing = kwargs["enable"]
444
- else:
445
- raise ValueError("Invalid set gradient checkpointing")
446
-
447
- def enable_multi_gpus_inference(self,):
448
- self.sp_world_size = get_sequence_parallel_world_size()
449
- self.sp_world_rank = get_sequence_parallel_rank()
450
- self.all_gather = get_sp_group().all_gather
451
- self.set_attn_processor(ZMultiGPUsSingleStreamAttnProcessor())
452
-
453
- @property
454
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
455
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
456
- r"""
457
- Returns:
458
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
459
- indexed by its weight name.
460
- """
461
- # set recursively
462
- processors = {}
463
-
464
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
465
- if hasattr(module, "get_processor"):
466
- processors[f"{name}.processor"] = module.get_processor()
467
-
468
- for sub_name, child in module.named_children():
469
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
470
-
471
- return processors
472
-
473
- for name, module in self.named_children():
474
- fn_recursive_add_processors(name, module, processors)
475
-
476
- return processors
477
-
478
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
479
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
480
- r"""
481
- Sets the attention processor to use to compute attention.
482
-
483
- Parameters:
484
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
485
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
486
- for **all** `Attention` layers.
487
-
488
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
489
- processor. This is strongly recommended when setting trainable attention processors.
490
-
491
- """
492
- count = len(self.attn_processors.keys())
493
-
494
- if isinstance(processor, dict) and len(processor) != count:
495
- raise ValueError(
496
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
497
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
498
- )
499
-
500
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
501
- if hasattr(module, "set_processor"):
502
- if not isinstance(processor, dict):
503
- module.set_processor(processor)
504
- else:
505
- module.set_processor(processor.pop(f"{name}.processor"))
506
-
507
- for sub_name, child in module.named_children():
508
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
509
-
510
- for name, module in self.named_children():
511
- fn_recursive_attn_processor(name, module, processor)
512
-
513
- def unpatchify(self, x: List[torch.Tensor], size: List[Tuple], patch_size, f_patch_size) -> List[torch.Tensor]:
514
- pH = pW = patch_size
515
- pF = f_patch_size
516
- bsz = len(x)
517
- assert len(size) == bsz
518
- for i in range(bsz):
519
- F, H, W = size[i]
520
- ori_len = (F // pF) * (H // pH) * (W // pW)
521
- # "f h w pf ph pw c -> c (f pf) (h ph) (w pw)"
522
- x[i] = (
523
- x[i][:ori_len]
524
- .view(F // pF, H // pH, W // pW, pF, pH, pW, self.out_channels)
525
- .permute(6, 0, 3, 1, 4, 2, 5)
526
- .reshape(self.out_channels, F, H, W)
527
- )
528
- return x
529
-
530
- @staticmethod
531
- def create_coordinate_grid(size, start=None, device=None):
532
- if start is None:
533
- start = (0 for _ in size)
534
-
535
- axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
536
- grids = torch.meshgrid(axes, indexing="ij")
537
- return torch.stack(grids, dim=-1)
538
-
539
- def patchify(
540
- self,
541
- all_image: List[torch.Tensor],
542
- patch_size: int,
543
- f_patch_size: int,
544
- cap_padding_len: int,
545
- ):
546
- pH = pW = patch_size
547
- pF = f_patch_size
548
- device = all_image[0].device
549
-
550
- all_image_out = []
551
- all_image_size = []
552
- all_image_pos_ids = []
553
- all_image_pad_mask = []
554
-
555
- for i, image in enumerate(all_image):
556
- ### Process Image
557
- C, F, H, W = image.size()
558
- all_image_size.append((F, H, W))
559
- F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
560
-
561
- image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
562
- # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
563
- image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
564
-
565
- image_ori_len = len(image)
566
- image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
567
-
568
- image_ori_pos_ids = self.create_coordinate_grid(
569
- size=(F_tokens, H_tokens, W_tokens),
570
- start=(cap_padding_len + 1, 0, 0),
571
- device=device,
572
- ).flatten(0, 2)
573
- image_padding_pos_ids = (
574
- self.create_coordinate_grid(
575
- size=(1, 1, 1),
576
- start=(0, 0, 0),
577
- device=device,
578
- )
579
- .flatten(0, 2)
580
- .repeat(image_padding_len, 1)
581
- )
582
- image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
583
- all_image_pos_ids.append(image_padded_pos_ids)
584
- # pad mask
585
- all_image_pad_mask.append(
586
- torch.cat(
587
- [
588
- torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
589
- torch.ones((image_padding_len,), dtype=torch.bool, device=device),
590
- ],
591
- dim=0,
592
- )
593
- )
594
- # padded feature
595
- image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
596
- all_image_out.append(image_padded_feat)
597
-
598
- return (
599
- all_image_out,
600
- all_image_size,
601
- all_image_pos_ids,
602
- all_image_pad_mask,
603
- )
604
-
605
- def patchify_and_embed(
606
- self,
607
- all_image: List[torch.Tensor],
608
- all_cap_feats: List[torch.Tensor],
609
- patch_size: int,
610
- f_patch_size: int,
611
- ):
612
- pH = pW = patch_size
613
- pF = f_patch_size
614
- device = all_image[0].device
615
-
616
- all_image_out = []
617
- all_image_size = []
618
- all_image_pos_ids = []
619
- all_image_pad_mask = []
620
- all_cap_pos_ids = []
621
- all_cap_pad_mask = []
622
- all_cap_feats_out = []
623
-
624
- for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
625
- ### Process Caption
626
- cap_ori_len = len(cap_feat)
627
- cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
628
- # padded position ids
629
- cap_padded_pos_ids = self.create_coordinate_grid(
630
- size=(cap_ori_len + cap_padding_len, 1, 1),
631
- start=(1, 0, 0),
632
- device=device,
633
- ).flatten(0, 2)
634
- all_cap_pos_ids.append(cap_padded_pos_ids)
635
- # pad mask
636
- all_cap_pad_mask.append(
637
- torch.cat(
638
- [
639
- torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
640
- torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
641
- ],
642
- dim=0,
643
- )
644
- )
645
- # padded feature
646
- cap_padded_feat = torch.cat(
647
- [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
648
- dim=0,
649
- )
650
- all_cap_feats_out.append(cap_padded_feat)
651
-
652
- ### Process Image
653
- C, F, H, W = image.size()
654
- all_image_size.append((F, H, W))
655
- F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
656
-
657
- image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
658
- # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
659
- image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
660
-
661
- image_ori_len = len(image)
662
- image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
663
-
664
- image_ori_pos_ids = self.create_coordinate_grid(
665
- size=(F_tokens, H_tokens, W_tokens),
666
- start=(cap_ori_len + cap_padding_len + 1, 0, 0),
667
- device=device,
668
- ).flatten(0, 2)
669
- image_padding_pos_ids = (
670
- self.create_coordinate_grid(
671
- size=(1, 1, 1),
672
- start=(0, 0, 0),
673
- device=device,
674
- )
675
- .flatten(0, 2)
676
- .repeat(image_padding_len, 1)
677
- )
678
- image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
679
- all_image_pos_ids.append(image_padded_pos_ids)
680
- # pad mask
681
- all_image_pad_mask.append(
682
- torch.cat(
683
- [
684
- torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
685
- torch.ones((image_padding_len,), dtype=torch.bool, device=device),
686
- ],
687
- dim=0,
688
- )
689
- )
690
- # padded feature
691
- image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
692
- all_image_out.append(image_padded_feat)
693
-
694
- return (
695
- all_image_out,
696
- all_cap_feats_out,
697
- all_image_size,
698
- all_image_pos_ids,
699
- all_cap_pos_ids,
700
- all_image_pad_mask,
701
- all_cap_pad_mask,
702
- )
703
-
704
- def forward(
705
- self,
706
- x: List[torch.Tensor],
707
- t,
708
- cap_feats: List[torch.Tensor],
709
- patch_size=2,
710
- f_patch_size=1,
711
- ):
712
- assert patch_size in self.all_patch_size
713
- assert f_patch_size in self.all_f_patch_size
714
-
715
- bsz = len(x)
716
- device = x[0].device
717
- t = t * self.t_scale
718
- t = self.t_embedder(t)
719
-
720
- (
721
- x,
722
- cap_feats,
723
- x_size,
724
- x_pos_ids,
725
- cap_pos_ids,
726
- x_inner_pad_mask,
727
- cap_inner_pad_mask,
728
- ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
729
-
730
- # x embed & refine
731
- x_item_seqlens = [len(_) for _ in x]
732
- assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
733
- x_max_item_seqlen = max(x_item_seqlens)
734
-
735
- x = torch.cat(x, dim=0)
736
- x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
737
-
738
- # Match t_embedder output dtype to x for layerwise casting compatibility
739
- adaln_input = t.type_as(x)
740
- x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
741
- x = list(x.split(x_item_seqlens, dim=0))
742
- x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
743
-
744
- x = pad_sequence(x, batch_first=True, padding_value=0.0)
745
- x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
746
- x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
747
- for i, seq_len in enumerate(x_item_seqlens):
748
- x_attn_mask[i, :seq_len] = 1
749
-
750
- # Context Parallel
751
- if self.sp_world_size > 1:
752
- x = torch.chunk(x, self.sp_world_size, dim=1)[self.sp_world_rank]
753
-
754
- if torch.is_grad_enabled() and self.gradient_checkpointing:
755
- for layer in self.noise_refiner:
756
- def create_custom_forward(module):
757
- def custom_forward(*inputs):
758
- return module(*inputs)
759
-
760
- return custom_forward
761
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
762
- x = torch.utils.checkpoint.checkpoint(
763
- create_custom_forward(layer),
764
- x, x_attn_mask, x_freqs_cis, adaln_input,
765
- **ckpt_kwargs,
766
- )
767
- else:
768
- for layer in self.noise_refiner:
769
- x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
770
-
771
- # cap embed & refine
772
- cap_item_seqlens = [len(_) for _ in cap_feats]
773
- assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
774
- cap_max_item_seqlen = max(cap_item_seqlens)
775
-
776
- cap_feats = torch.cat(cap_feats, dim=0)
777
- cap_feats = self.cap_embedder(cap_feats)
778
- cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
779
- cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
780
- cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
781
-
782
- cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
783
- cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
784
- cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
785
- for i, seq_len in enumerate(cap_item_seqlens):
786
- cap_attn_mask[i, :seq_len] = 1
787
-
788
- if torch.is_grad_enabled() and self.gradient_checkpointing:
789
- for layer in self.context_refiner:
790
- def create_custom_forward(module):
791
- def custom_forward(*inputs):
792
- return module(*inputs)
793
-
794
- return custom_forward
795
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
796
- cap_feats = torch.utils.checkpoint.checkpoint(
797
- create_custom_forward(layer),
798
- cap_feats,
799
- cap_attn_mask,
800
- cap_freqs_cis,
801
- **ckpt_kwargs,
802
- )
803
- else:
804
- for layer in self.context_refiner:
805
- cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
806
-
807
- # unified
808
- unified = []
809
- unified_freqs_cis = []
810
- for i in range(bsz):
811
- x_len = x_item_seqlens[i]
812
- cap_len = cap_item_seqlens[i]
813
- unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
814
- unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
815
- unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
816
- assert unified_item_seqlens == [len(_) for _ in unified]
817
- unified_max_item_seqlen = max(unified_item_seqlens)
818
-
819
- unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
820
- unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
821
- unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
822
- for i, seq_len in enumerate(unified_item_seqlens):
823
- unified_attn_mask[i, :seq_len] = 1
824
-
825
- if torch.is_grad_enabled() and self.gradient_checkpointing:
826
- for layer in self.layers:
827
- def create_custom_forward(module):
828
- def custom_forward(*inputs):
829
- return module(*inputs)
830
-
831
- return custom_forward
832
- ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
833
- unified = torch.utils.checkpoint.checkpoint(
834
- create_custom_forward(layer),
835
- unified,
836
- unified_attn_mask,
837
- unified_freqs_cis,
838
- adaln_input,
839
- **ckpt_kwargs,
840
- )
841
- else:
842
- for layer in self.layers:
843
- unified = layer(unified, unified_attn_mask, unified_freqs_cis, adaln_input)
844
-
845
- unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](unified, adaln_input)
846
- unified = list(unified.unbind(dim=0))
847
- x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
848
-
849
- if self.sp_world_size > 1:
850
- x = self.all_gather(x, dim=1)
851
- x = torch.stack(x)
852
- return x, {}
853
-
854
-
855
- @classmethod
856
- def from_pretrained(
857
- cls, pretrained_model_path, subfolder=None, transformer_additional_kwargs={},
858
- low_cpu_mem_usage=False, torch_dtype=torch.bfloat16
859
- ):
860
- if subfolder is not None:
861
- pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
862
- print(f"loaded 3D transformer's pretrained weights from {pretrained_model_path} ...")
863
-
864
- config_file = os.path.join(pretrained_model_path, 'config.json')
865
- if not os.path.isfile(config_file):
866
- raise RuntimeError(f"{config_file} does not exist")
867
- with open(config_file, "r") as f:
868
- config = json.load(f)
869
-
870
- from diffusers.utils import WEIGHTS_NAME
871
- model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
872
- model_file_safetensors = model_file.replace(".bin", ".safetensors")
873
-
874
- if "dict_mapping" in transformer_additional_kwargs.keys():
875
- for key in transformer_additional_kwargs["dict_mapping"]:
876
- transformer_additional_kwargs[transformer_additional_kwargs["dict_mapping"][key]] = config[key]
877
-
878
- if low_cpu_mem_usage:
879
- try:
880
- import re
881
-
882
- from diffusers import __version__ as diffusers_version
883
- if diffusers_version >= "0.33.0":
884
- from diffusers.models.model_loading_utils import \
885
- load_model_dict_into_meta
886
- else:
887
- from diffusers.models.modeling_utils import \
888
- load_model_dict_into_meta
889
- from diffusers.utils import is_accelerate_available
890
- if is_accelerate_available():
891
- import accelerate
892
-
893
- # Instantiate model with empty weights
894
- with accelerate.init_empty_weights():
895
- model = cls.from_config(config, **transformer_additional_kwargs)
896
-
897
- param_device = "cpu"
898
- if os.path.exists(model_file):
899
- state_dict = torch.load(model_file, map_location="cpu")
900
- elif os.path.exists(model_file_safetensors):
901
- from safetensors.torch import load_file, safe_open
902
- state_dict = load_file(model_file_safetensors)
903
- else:
904
- from safetensors.torch import load_file, safe_open
905
- model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
906
- state_dict = {}
907
- print(model_files_safetensors)
908
- for _model_file_safetensors in model_files_safetensors:
909
- _state_dict = load_file(_model_file_safetensors)
910
- for key in _state_dict:
911
- state_dict[key] = _state_dict[key]
912
-
913
- filtered_state_dict = {}
914
- for key in state_dict:
915
- if key in model.state_dict() and model.state_dict()[key].size() == state_dict[key].size():
916
- filtered_state_dict[key] = state_dict[key]
917
- else:
918
- print(f"Skipping key '{key}' due to size mismatch or absence in model.")
919
-
920
- model_keys = set(model.state_dict().keys())
921
- loaded_keys = set(filtered_state_dict.keys())
922
- missing_keys = model_keys - loaded_keys
923
-
924
- def initialize_missing_parameters(missing_keys, model_state_dict, torch_dtype=None):
925
- initialized_dict = {}
926
-
927
- with torch.no_grad():
928
- for key in missing_keys:
929
- param_shape = model_state_dict[key].shape
930
- param_dtype = torch_dtype if torch_dtype is not None else model_state_dict[key].dtype
931
- if "control" in key and key.replace("control_", "") in filtered_state_dict.keys():
932
- initialized_dict[key] = filtered_state_dict[key.replace("control_", "")].clone()
933
- print(f"Initializing missing parameter '{key}' with model.state_dict().")
934
- elif "after_proj" in key or "before_proj" in key:
935
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
936
- print(f"Initializing missing parameter '{key}' with zero.")
937
- elif 'weight' in key:
938
- if any(norm_type in key for norm_type in ['norm', 'ln_', 'layer_norm', 'group_norm', 'batch_norm']):
939
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
940
- elif 'embedding' in key or 'embed' in key:
941
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
942
- elif 'head' in key or 'output' in key or 'proj_out' in key:
943
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
944
- elif len(param_shape) >= 2:
945
- initialized_dict[key] = torch.empty(param_shape, dtype=param_dtype)
946
- nn.init.xavier_uniform_(initialized_dict[key])
947
- else:
948
- initialized_dict[key] = torch.randn(param_shape, dtype=param_dtype) * 0.02
949
- elif 'bias' in key:
950
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
951
- elif 'running_mean' in key:
952
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
953
- elif 'running_var' in key:
954
- initialized_dict[key] = torch.ones(param_shape, dtype=param_dtype)
955
- elif 'num_batches_tracked' in key:
956
- initialized_dict[key] = torch.zeros(param_shape, dtype=torch.long)
957
- else:
958
- initialized_dict[key] = torch.zeros(param_shape, dtype=param_dtype)
959
-
960
- return initialized_dict
961
-
-         if missing_keys:
-             print(f"Missing keys will be initialized: {sorted(missing_keys)}")
-             initialized_params = initialize_missing_parameters(
-                 missing_keys,
-                 model.state_dict(),
-                 torch_dtype
-             )
-             filtered_state_dict.update(initialized_params)
-
-         if diffusers_version >= "0.33.0":
-             # Diffusers has refactored `load_model_dict_into_meta` since version 0.33.0 in this commit:
-             # https://github.com/huggingface/diffusers/commit/f5929e03060d56063ff34b25a8308833bec7c785.
-             load_model_dict_into_meta(
-                 model,
-                 filtered_state_dict,
-                 dtype=torch_dtype,
-                 model_name_or_path=pretrained_model_path,
-             )
-         else:
-             model._convert_deprecated_attention_blocks(filtered_state_dict)
-             unexpected_keys = load_model_dict_into_meta(
-                 model,
-                 filtered_state_dict,
-                 device=param_device,
-                 dtype=torch_dtype,
-                 model_name_or_path=pretrained_model_path,
-             )
-
-             if cls._keys_to_ignore_on_load_unexpected is not None:
-                 for pat in cls._keys_to_ignore_on_load_unexpected:
-                     unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-             if len(unexpected_keys) > 0:
-                 print(
-                     f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                 )
-
-         params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
-         print(f"### All Parameters: {sum(params) / 1e6} M")
-
-         params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
-         print(f"### attn1 Parameters: {sum(params) / 1e6} M")
-         return model
-     except Exception as e:
-         print(
-             f"The low_cpu_mem_usage mode did not work because {e}. Falling back to low_cpu_mem_usage=False."
-         )
-
- model = cls.from_config(config, **transformer_additional_kwargs)
- if os.path.exists(model_file):
-     state_dict = torch.load(model_file, map_location="cpu")
- elif os.path.exists(model_file_safetensors):
-     from safetensors.torch import load_file, safe_open
-     state_dict = load_file(model_file_safetensors)
- else:
-     from safetensors.torch import load_file, safe_open
-     model_files_safetensors = glob.glob(os.path.join(pretrained_model_path, "*.safetensors"))
-     state_dict = {}
-     for _model_file_safetensors in model_files_safetensors:
-         _state_dict = load_file(_model_file_safetensors)
-         for key in _state_dict:
-             state_dict[key] = _state_dict[key]
-
- tmp_state_dict = {}
- for key in state_dict:
-     if key in model.state_dict().keys() and model.state_dict()[key].size() == state_dict[key].size():
-         tmp_state_dict[key] = state_dict[key]
-     else:
-         print(key, "size mismatch, skipping")
-
- for key in model.state_dict():
-     if "control" in key and key.replace("control_", "") in state_dict.keys() and model.state_dict()[key].size() == state_dict[key.replace("control_", "")].size():
-         tmp_state_dict[key] = state_dict[key.replace("control_", "")].clone()
-         print(f"Initializing missing parameter '{key}' with model.state_dict().")
-
- state_dict = tmp_state_dict
-
- m, u = model.load_state_dict(state_dict, strict=False)
- print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
- print(m)
-
- params = [p.numel() if "." in n else 0 for n, p in model.named_parameters()]
- print(f"### All Parameters: {sum(params) / 1e6} M")
-
- params = [p.numel() if "attn1." in n else 0 for n, p in model.named_parameters()]
- print(f"### attn1 Parameters: {sum(params) / 1e6} M")
-
- model = model.to(torch_dtype)
- return model
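
For reference, the deleted loader above follows the familiar diffusers-style `from_pretrained` flow: build the transformer from its config, collect weights from `*.bin` or `*.safetensors` files, drop keys whose shapes do not match the model, initialize or copy any still-missing keys, and load the remainder non-strictly before casting to the requested dtype. A minimal usage sketch under those assumptions follows; the class name `WanTransformer3DModel`, the checkpoint directory, and the keyword values are illustrative placeholders and are not confirmed by this commit.

import torch
from videox_fun.models import WanTransformer3DModel  # assumed export; class name is a placeholder

# Hypothetical call into the from_pretrained-style loader removed above.
# The path and keyword values are assumptions for illustration only.
transformer = WanTransformer3DModel.from_pretrained(
    "models/Wan2.1-T2V-14B",            # placeholder directory containing config.json and weight files
    transformer_additional_kwargs={},    # extra config overrides passed through to from_config
    low_cpu_mem_usage=True,              # try the accelerate meta-weights path, fall back on failure
    torch_dtype=torch.bfloat16,          # dtype the loaded weights are cast to
)
transformer.eval()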