JokerZhou committed
Commit 8bf6610 · 1 Parent(s): 9fd6c72

Upload files

Files changed (44)
  1. .DS_Store +0 -0
  2. LICENSE +201 -0
  3. app.py +6 -0
  4. download_model_from_hf.py +20 -0
  5. flash_head/.DS_Store +0 -0
  6. flash_head/audio_analysis/torch_utils.py +20 -0
  7. flash_head/audio_analysis/wav2vec2.py +125 -0
  8. flash_head/configs/infer_params.yaml +10 -0
  9. flash_head/inference.py +77 -0
  10. flash_head/ltx_video/.DS_Store +0 -0
  11. flash_head/ltx_video/__init__.py +0 -0
  12. flash_head/ltx_video/ltx_vae.py +42 -0
  13. flash_head/ltx_video/models/__init__.py +0 -0
  14. flash_head/ltx_video/models/autoencoders/__init__.py +0 -0
  15. flash_head/ltx_video/models/autoencoders/causal_conv3d.py +63 -0
  16. flash_head/ltx_video/models/autoencoders/causal_video_autoencoder.py +1412 -0
  17. flash_head/ltx_video/models/autoencoders/conv_nd_factory.py +90 -0
  18. flash_head/ltx_video/models/autoencoders/dual_conv3d.py +217 -0
  19. flash_head/ltx_video/models/autoencoders/pixel_norm.py +12 -0
  20. flash_head/ltx_video/models/autoencoders/vae.py +380 -0
  21. flash_head/ltx_video/models/autoencoders/vae_encode.py +256 -0
  22. flash_head/ltx_video/models/autoencoders/video_autoencoder.py +1045 -0
  23. flash_head/ltx_video/models/transformers/__init__.py +0 -0
  24. flash_head/ltx_video/models/transformers/attention.py +1265 -0
  25. flash_head/ltx_video/models/transformers/embeddings.py +129 -0
  26. flash_head/ltx_video/models/transformers/symmetric_patchifier.py +84 -0
  27. flash_head/ltx_video/models/transformers/transformer3d.py +507 -0
  28. flash_head/ltx_video/utils/__init__.py +0 -0
  29. flash_head/ltx_video/utils/diffusers_config_mapping.py +174 -0
  30. flash_head/ltx_video/utils/prompt_enhance_utils.py +226 -0
  31. flash_head/ltx_video/utils/skip_layer_strategy.py +8 -0
  32. flash_head/ltx_video/utils/torch_utils.py +25 -0
  33. flash_head/src/.DS_Store +0 -0
  34. flash_head/src/distributed/usp_device.py +35 -0
  35. flash_head/src/modules/flash_head_model.py +548 -0
  36. flash_head/src/pipeline/flash_head_pipeline.py +316 -0
  37. flash_head/utils/cpu_face_handler.py +55 -0
  38. flash_head/utils/facecrop.py +110 -0
  39. flash_head/utils/utils.py +222 -0
  40. flash_head/wan/modules/__init__.py +5 -0
  41. flash_head/wan/modules/vae.py +1598 -0
  42. generate_video.py +218 -0
  43. gradio_app_streaming.py +339 -0
  44. requirements.txt +23 -0
.DS_Store ADDED
Binary file (6.15 kB).
 
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,6 @@
+ from gradio_app_streaming import app
+ from download_model_from_hf import download_model
+
+ if __name__ == "__main__":
+     download_model("Soul-AILab/SoulX-FlashHead-1_3B", "models")
+     app.launch(share=True)
download_model_from_hf.py ADDED
@@ -0,0 +1,20 @@
+ from pathlib import Path
+ from huggingface_hub import snapshot_download
+
+ def download_model(model_name, save_dir="models"):
+     # Create the local save directory
+     save_path = Path(save_dir) / model_name.split("/")[-1]
+
+     if save_path.exists():
+         print(f"✅ Model already exists: {save_path}")
+         return str(save_path)
+
+     save_path.mkdir(parents=True, exist_ok=True)
+
+     download_path = snapshot_download(
+         repo_id=model_name,
+         local_dir=save_path,
+         local_dir_use_symlinks=False
+     )
+     print(f"✅ Download complete: {download_path}")
+     return download_path
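Note (editor sketch, not part of the commit): the helper above is normally called once from app.py at startup, but it can also be exercised on its own; the only requirements are huggingface_hub and network access, and the repo id below is the same one app.py passes.

from download_model_from_hf import download_model

if __name__ == "__main__":
    # Same repo id that app.py uses at startup; files land under models/SoulX-FlashHead-1_3B.
    local_dir = download_model("Soul-AILab/SoulX-FlashHead-1_3B", save_dir="models")
    print(f"Model available at: {local_dir}")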
flash_head/.DS_Store ADDED
Binary file (6.15 kB).
 
flash_head/audio_analysis/torch_utils.py ADDED
@@ -0,0 +1,20 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ def get_mask_from_lengths(lengths, max_len=None):
+     lengths = lengths.to(torch.long)
+     if max_len is None:
+         max_len = torch.max(lengths).item()
+
+     ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
+     mask = ids < lengths.unsqueeze(1).expand(-1, max_len)
+
+     return mask
+
+
+ def linear_interpolation(features, seq_len):
+     features = features.transpose(1, 2)
+     output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
+     return output_features.transpose(1, 2)
+
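Editor note (illustrative sketch, not part of the commit): the two helpers above are easy to sanity-check in isolation. get_mask_from_lengths builds a boolean padding mask from per-sample lengths, and linear_interpolation resamples a (batch, time, channels) feature sequence to a target frame count; the shapes below are made up for the example.

import torch
from flash_head.audio_analysis.torch_utils import get_mask_from_lengths, linear_interpolation

# Boolean mask for a padded batch: True where a position lies inside the sequence.
lengths = torch.tensor([3, 5])
mask = get_mask_from_lengths(lengths)          # shape (2, 5)
# tensor([[ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])

# Resample 50 audio-feature steps to 25 video frames; channels are untouched.
features = torch.randn(1, 50, 768)             # (batch, time, channels)
resampled = linear_interpolation(features, seq_len=25)
print(mask.shape, resampled.shape)             # torch.Size([2, 5]) torch.Size([1, 25, 768])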
flash_head/audio_analysis/wav2vec2.py ADDED
@@ -0,0 +1,125 @@
+ from transformers import Wav2Vec2Config, Wav2Vec2Model
+ from transformers.modeling_outputs import BaseModelOutput
+
+ from .torch_utils import linear_interpolation
+
+ # the implementation of Wav2Vec2Model is borrowed from
+ # https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+ # initialize our encoder with the pre-trained wav2vec 2.0 weights.
+ class Wav2Vec2Model(Wav2Vec2Model):
+     def __init__(self, config: Wav2Vec2Config):
+         super().__init__(config)
+
+     def forward(
+         self,
+         input_values,
+         seq_len,
+         attention_mask=None,
+         mask_time_indices=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         self.config.output_attentions = False
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+         extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, ) + encoder_outputs[1:]
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
+
+
+     def feature_extract(
+         self,
+         input_values,
+         seq_len,
+     ):
+         extract_features = self.feature_extractor(input_values)
+         extract_features = extract_features.transpose(1, 2)
+         extract_features = linear_interpolation(extract_features, seq_len=seq_len)
+
+         return extract_features
+
+     def encode(
+         self,
+         extract_features,
+         attention_mask=None,
+         mask_time_indices=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         self.config.output_attentions = False
+
+         output_hidden_states = (
+             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         )
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         if attention_mask is not None:
+             # compute reduced attention_mask corresponding to feature vectors
+             attention_mask = self._get_feature_vector_attention_mask(
+                 extract_features.shape[1], attention_mask, add_adapter=False
+             )
+
+
+         hidden_states, extract_features = self.feature_projection(extract_features)
+         hidden_states = self._mask_hidden_states(
+             hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
+         )
+
+         encoder_outputs = self.encoder(
+             hidden_states,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         hidden_states = encoder_outputs[0]
+
+         if self.adapter is not None:
+             hidden_states = self.adapter(hidden_states)
+
+         if not return_dict:
+             return (hidden_states, ) + encoder_outputs[1:]
+         return BaseModelOutput(
+             last_hidden_state=hidden_states,
+             hidden_states=encoder_outputs.hidden_states,
+             attentions=encoder_outputs.attentions,
+         )
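Editor note (illustrative sketch, not part of the commit): the subclass above changes the usual wav2vec 2.0 interface in one way — forward takes an extra seq_len and linearly resamples the conv features to that many steps, so the audio features line up with video frames. feature_extract/encode split the same computation into two stages. The checkpoint id below ("facebook/wav2vec2-base-960h") is only a public placeholder; in this repo the weights would come from the local wav2vec_dir passed to the pipeline.

import torch
from flash_head.audio_analysis.wav2vec2 import Wav2Vec2Model

model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

waveform = torch.randn(1, 16000)   # 1 s of 16 kHz audio
num_frames = 25                    # align features to 25 video frames (tgt_fps)

with torch.no_grad():
    # one-shot path
    out = model(waveform, seq_len=num_frames)
    # two-stage path: extract conv features once, encode them later
    feats = model.feature_extract(waveform, seq_len=num_frames)
    enc = model.encode(feats)

print(out.last_hidden_state.shape, enc.last_hidden_state.shape)  # both (1, 25, hidden_size)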
flash_head/configs/infer_params.yaml ADDED
@@ -0,0 +1,10 @@
+ frame_num: 33
+ motion_frames_latent_num: 2
+ tgt_fps: 25
+ sample_rate: 16000
+ sample_shift: 5
+ color_correction_strength: 1.0
+ cached_audio_duration: 8
+ num_heads: 12
+ height: 512
+ width: 512
flash_head/inference.py ADDED
@@ -0,0 +1,77 @@
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ import yaml
+ import torch
+ import copy
+ from loguru import logger
+
+ from flash_head.src.pipeline.flash_head_pipeline import FlashHeadPipeline
+ from flash_head.src.distributed.usp_device import get_device, get_parallel_degree
+
+ with open("flash_head/configs/infer_params.yaml", "r") as f:
+     infer_params = yaml.safe_load(f)
+
+ def get_pipeline(world_size, ckpt_dir, model_type, wav2vec_dir):
+     global infer_params
+     ulysses_degree, ring_degree = get_parallel_degree(world_size, infer_params['num_heads'])
+     device = get_device(ulysses_degree, ring_degree)
+     logger.info(f"ulysses_degree: {ulysses_degree}, ring_degree: {ring_degree}, device: {device}")
+
+     pipeline = FlashHeadPipeline(
+         checkpoint_dir=ckpt_dir,
+         model_type=model_type,
+         wav2vec_dir=wav2vec_dir,
+         device=device,
+         use_usp=(world_size > 1),
+     )
+
+     # compute motion_frames_num
+     motion_frames_latent_num = infer_params['motion_frames_latent_num']
+     motion_frames_num = (motion_frames_latent_num - 1) * pipeline.config.vae_stride[0] + 1
+     infer_params['motion_frames_num'] = motion_frames_num
+
+     # TODO: move to args
+     if model_type == "pretrained":
+         infer_params['sample_steps'] = 20
+     else:
+         infer_params['sample_steps'] = 4
+     return pipeline
+
+ def get_base_data(pipeline, cond_image_path_or_dir, base_seed, use_face_crop):
+     pipeline.prepare_params(
+         cond_image_path_or_dir=cond_image_path_or_dir,
+         target_size=(infer_params['height'], infer_params['width']),
+         frame_num=infer_params['frame_num'],
+         motion_frames_num=infer_params['motion_frames_num'],
+         sampling_steps=infer_params['sample_steps'],
+         seed=base_seed,
+         shift=infer_params['sample_shift'],
+         color_correction_strength=infer_params['color_correction_strength'],
+         use_face_crop=use_face_crop,
+     )
+
+ def get_infer_params():
+     global infer_params
+     return copy.deepcopy(infer_params)
+
+ def get_audio_embedding(pipeline, audio_array, audio_start_idx=-1, audio_end_idx=-1):
+     # audio_array = loudness_norm(audio_array, infer_params['sample_rate'])
+     audio_embedding = pipeline.preprocess_audio(audio_array, sr=infer_params['sample_rate'], fps=infer_params['tgt_fps'])
+
+     if audio_start_idx == -1 or audio_end_idx == -1:
+         audio_start_idx = 0
+         audio_end_idx = audio_embedding.shape[0]
+
+     indices = (torch.arange(2 * 2 + 1) - 2) * 1
+
+     center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0)
+     center_indices = torch.clamp(center_indices, min=0, max=audio_end_idx-1)
+
+     audio_embedding = audio_embedding[center_indices][None,...].contiguous()
+     return audio_embedding
+
+ def run_pipeline(pipeline, audio_embedding):
+     audio_embedding = audio_embedding.to(pipeline.device)
+     sample = pipeline.generate(audio_embedding)
+     sample_frames = (((sample+1)/2).permute(1,2,3,0).clip(0,1) * 255).contiguous()
+     return sample_frames
+
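Editor note (illustrative sketch, not part of the commit): the four helpers above compose into a single-GPU run roughly as follows. The checkpoint directories, the reference image path, and the silent audio array are placeholders, model_type="pretrained" is the only value the code above names explicitly, and the script must be launched from the repo root because the module loads flash_head/configs/infer_params.yaml at import time.

import numpy as np
from flash_head.inference import (
    get_pipeline, get_base_data, get_audio_embedding, run_pipeline,
)

pipeline = get_pipeline(
    world_size=1,
    ckpt_dir="models/SoulX-FlashHead-1_3B",   # placeholder path
    model_type="pretrained",
    wav2vec_dir="models/wav2vec2-base",       # placeholder path
)
get_base_data(pipeline, cond_image_path_or_dir="assets/ref_face.png",  # placeholder image
              base_seed=42, use_face_crop=True)

# A real 16 kHz speech clip would be passed here; 4 s of zeros is only a stand-in.
audio_array = np.zeros(4 * 16000, dtype=np.float32)
audio_embedding = get_audio_embedding(pipeline, audio_array)
frames = run_pipeline(pipeline, audio_embedding)  # (F, H, W, 3), values scaled to 0-255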
flash_head/ltx_video/.DS_Store ADDED
Binary file (6.15 kB).
 
flash_head/ltx_video/__init__.py ADDED
File without changes
flash_head/ltx_video/ltx_vae.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
+
+
+ class LtxVAE:
+     def __init__(
+         self,
+         pretrained_model_type_or_path,
+         dtype = torch.bfloat16,
+         device = "cuda",
+     ):
+         self.model = CausalVideoAutoencoder.from_pretrained(pretrained_model_type_or_path)
+         self.model = self.model.eval().requires_grad_(False).to(device).to(dtype)
+
+     # torch.Size([1, 3, 33, 512, 512]) -> torch.Size([128, 5, 16, 16])
+     def encode(self, video):
+         latents = self.model.encode(video, return_dict=False)[0].sample()
+         out = self.normalize_latents(latents)
+         return out[0]
+
+     # torch.Size([128, 5, 16, 16]) -> torch.Size([1, 3, 33, 512, 512])
+     def decode(self, zs):
+         latents = zs.unsqueeze(0)
+         image = self.model.decode(
+             self.un_normalize_latents(latents),
+             return_dict=False,
+             target_shape=latents.shape,
+         )[0]
+         return image
+
+     def normalize_latents(self, latents):
+         return (
+             (latents - self.model.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
+             / self.model.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
+         )
+
+
+     def un_normalize_latents(self, latents):
+         return (
+             latents * self.model.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
+             + self.model.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
+         )
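Editor note (illustrative sketch, not part of the commit): a round trip through the wrapper above, using the shapes its own comments document — 33 RGB frames at 512x512 map to a (128, 5, 16, 16) latent, i.e. 8x temporal and 32x spatial compression with 128 channels. The checkpoint path is a placeholder and a CUDA device is assumed.

import torch
from flash_head.ltx_video.ltx_vae import LtxVAE

vae = LtxVAE("models/ltx_vae", dtype=torch.bfloat16, device="cuda")  # placeholder path

video = torch.randn(1, 3, 33, 512, 512, dtype=torch.bfloat16, device="cuda")
latents = vae.encode(video)    # (128, 5, 16, 16), per-channel normalized
recon = vae.decode(latents)    # back to (1, 3, 33, 512, 512)
print(latents.shape, recon.shape)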
flash_head/ltx_video/models/__init__.py ADDED
File without changes
flash_head/ltx_video/models/autoencoders/__init__.py ADDED
File without changes
flash_head/ltx_video/models/autoencoders/causal_conv3d.py ADDED
@@ -0,0 +1,63 @@
+ from typing import Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+
+ class CausalConv3d(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size: int = 3,
+         stride: Union[int, Tuple[int]] = 1,
+         dilation: int = 1,
+         groups: int = 1,
+         spatial_padding_mode: str = "zeros",
+         **kwargs,
+     ):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+
+         kernel_size = (kernel_size, kernel_size, kernel_size)
+         self.time_kernel_size = kernel_size[0]
+
+         dilation = (dilation, 1, 1)
+
+         height_pad = kernel_size[1] // 2
+         width_pad = kernel_size[2] // 2
+         padding = (0, height_pad, width_pad)
+
+         self.conv = nn.Conv3d(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride=stride,
+             dilation=dilation,
+             padding=padding,
+             padding_mode=spatial_padding_mode,
+             groups=groups,
+         )
+
+     def forward(self, x, causal: bool = True):
+         if causal:
+             first_frame_pad = x[:, :, :1, :, :].repeat(
+                 (1, 1, self.time_kernel_size - 1, 1, 1)
+             )
+             x = torch.concatenate((first_frame_pad, x), dim=2)
+         else:
+             first_frame_pad = x[:, :, :1, :, :].repeat(
+                 (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
+             )
+             last_frame_pad = x[:, :, -1:, :, :].repeat(
+                 (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
+             )
+             x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
+         x = self.conv(x)
+         return x
+
+     @property
+     def weight(self):
+         return self.conv.weight
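Editor note (illustrative sketch, not part of the commit): the causal flag only changes how the temporal axis is padded. With causal=True the first frame is replicated time_kernel_size-1 times in front, so no output frame sees future input; with causal=False the padding is split between the first and last frame. Either way the temporal length is preserved. The channel counts and tensor sizes below are made up for the example.

import torch
from flash_head.ltx_video.models.autoencoders.causal_conv3d import CausalConv3d

conv = CausalConv3d(in_channels=4, out_channels=8, kernel_size=3)

x = torch.randn(1, 4, 9, 32, 32)        # (B, C, T, H, W)
y_causal = conv(x, causal=True)          # prepends 2 copies of frame 0 before the conv
y_sym = conv(x, causal=False)            # pads 1 copy of the first and last frame
print(y_causal.shape, y_sym.shape)       # both torch.Size([1, 8, 9, 32, 32])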
flash_head/ltx_video/models/autoencoders/causal_video_autoencoder.py ADDED
@@ -0,0 +1,1412 @@
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from types import SimpleNamespace
5
+ from typing import Any, Mapping, Optional, Tuple, Union, List
6
+ from pathlib import Path
7
+
8
+ import torch
9
+ import numpy as np
10
+ from einops import rearrange
11
+ from torch import nn
12
+ from diffusers.utils import logging
13
+ import torch.nn.functional as F
14
+ from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
15
+ from safetensors import safe_open
16
+
17
+
18
+ from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
19
+ from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm
20
+ from flash_head.ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
21
+ from flash_head.ltx_video.models.transformers.attention import Attention
22
+ from flash_head.ltx_video.utils.diffusers_config_mapping import (
23
+ diffusers_and_ours_config_mapping,
24
+ make_hashable_key,
25
+ VAE_KEYS_RENAME_DICT,
26
+ )
27
+
28
+ PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class CausalVideoAutoencoder(AutoencoderKLWrapper):
33
+ @classmethod
34
+ def from_pretrained(
35
+ cls,
36
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
37
+ *args,
38
+ **kwargs,
39
+ ):
40
+ pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
41
+ if (
42
+ pretrained_model_name_or_path.is_dir()
43
+ and (pretrained_model_name_or_path / "autoencoder.pth").exists()
44
+ ):
45
+ config_local_path = pretrained_model_name_or_path / "config.json"
46
+ config = cls.load_config(config_local_path, **kwargs)
47
+
48
+ model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
49
+ state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))
50
+
51
+ statistics_local_path = (
52
+ pretrained_model_name_or_path / "per_channel_statistics.json"
53
+ )
54
+ if statistics_local_path.exists():
55
+ with open(statistics_local_path, "r") as file:
56
+ data = json.load(file)
57
+ transposed_data = list(zip(*data["data"]))
58
+ data_dict = {
59
+ col: torch.tensor(vals)
60
+ for col, vals in zip(data["columns"], transposed_data)
61
+ }
62
+ std_of_means = data_dict["std-of-means"]
63
+ mean_of_means = data_dict.get(
64
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
65
+ )
66
+ state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
67
+ std_of_means
68
+ )
69
+ state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
70
+ mean_of_means
71
+ )
72
+
73
+ elif pretrained_model_name_or_path.is_dir():
74
+ config_path = pretrained_model_name_or_path / "config.json"
75
+ with open(config_path, "r") as f:
76
+ config = make_hashable_key(json.load(f))
77
+
78
+ assert config in diffusers_and_ours_config_mapping, (
79
+ "Provided diffusers checkpoint config for VAE is not supported. "
80
+ "We only support diffusers configs found in Lightricks/LTX-Video."
81
+ )
82
+
83
+ config = diffusers_and_ours_config_mapping[config]
84
+
85
+ state_dict_path = (
86
+ pretrained_model_name_or_path
87
+ / "diffusion_pytorch_model.safetensors"
88
+ )
89
+
90
+ state_dict = {}
91
+ with safe_open(state_dict_path, framework="pt", device="cpu") as f:
92
+ for k in f.keys():
93
+ state_dict[k] = f.get_tensor(k)
94
+ for key in list(state_dict.keys()):
95
+ new_key = key
96
+ for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
97
+ new_key = new_key.replace(replace_key, rename_key)
98
+
99
+ state_dict[new_key] = state_dict.pop(key)
100
+
101
+ elif pretrained_model_name_or_path.is_file() and str(
102
+ pretrained_model_name_or_path
103
+ ).endswith(".safetensors"):
104
+ state_dict = {}
105
+ with safe_open(
106
+ pretrained_model_name_or_path, framework="pt", device="cpu"
107
+ ) as f:
108
+ metadata = f.metadata()
109
+ for k in f.keys():
110
+ state_dict[k] = f.get_tensor(k)
111
+ configs = json.loads(metadata["config"])
112
+ config = configs["vae"]
113
+
114
+ video_vae = cls.from_config(config)
115
+ if "torch_dtype" in kwargs:
116
+ video_vae.to(kwargs["torch_dtype"])
117
+ video_vae.load_state_dict(state_dict)
118
+ return video_vae
119
+
120
+ @staticmethod
121
+ def from_config(config):
122
+ assert (
123
+ config["_class_name"] == "CausalVideoAutoencoder"
124
+ ), "config must have _class_name=CausalVideoAutoencoder"
125
+ if isinstance(config["dims"], list):
126
+ config["dims"] = tuple(config["dims"])
127
+
128
+ assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
129
+
130
+ double_z = config.get("double_z", True)
131
+ latent_log_var = config.get(
132
+ "latent_log_var", "per_channel" if double_z else "none"
133
+ )
134
+ use_quant_conv = config.get("use_quant_conv", True)
135
+ normalize_latent_channels = config.get("normalize_latent_channels", False)
136
+
137
+ if use_quant_conv and latent_log_var in ["uniform", "constant"]:
138
+ raise ValueError(
139
+ f"latent_log_var={latent_log_var} requires use_quant_conv=False"
140
+ )
141
+
142
+ encoder = Encoder(
143
+ dims=config["dims"],
144
+ in_channels=config.get("in_channels", 3),
145
+ out_channels=config["latent_channels"],
146
+ blocks=config.get("encoder_blocks", config.get("blocks")),
147
+ patch_size=config.get("patch_size", 1),
148
+ latent_log_var=latent_log_var,
149
+ norm_layer=config.get("norm_layer", "group_norm"),
150
+ base_channels=config.get("encoder_base_channels", 128),
151
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
152
+ )
153
+
154
+ decoder = Decoder(
155
+ dims=config["dims"],
156
+ in_channels=config["latent_channels"],
157
+ out_channels=config.get("out_channels", 3),
158
+ blocks=config.get("decoder_blocks", config.get("blocks")),
159
+ patch_size=config.get("patch_size", 1),
160
+ norm_layer=config.get("norm_layer", "group_norm"),
161
+ causal=config.get("causal_decoder", False),
162
+ timestep_conditioning=config.get("timestep_conditioning", False),
163
+ base_channels=config.get("decoder_base_channels", 128),
164
+ spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
165
+ )
166
+
167
+ dims = config["dims"]
168
+ return CausalVideoAutoencoder(
169
+ encoder=encoder,
170
+ decoder=decoder,
171
+ latent_channels=config["latent_channels"],
172
+ dims=dims,
173
+ use_quant_conv=use_quant_conv,
174
+ normalize_latent_channels=normalize_latent_channels,
175
+ )
176
+
177
+ @property
178
+ def config(self):
179
+ return SimpleNamespace(
180
+ _class_name="CausalVideoAutoencoder",
181
+ dims=self.dims,
182
+ in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
183
+ out_channels=self.decoder.conv_out.out_channels
184
+ // self.decoder.patch_size**2,
185
+ latent_channels=self.decoder.conv_in.in_channels,
186
+ encoder_blocks=self.encoder.blocks_desc,
187
+ decoder_blocks=self.decoder.blocks_desc,
188
+ scaling_factor=1.0,
189
+ norm_layer=self.encoder.norm_layer,
190
+ patch_size=self.encoder.patch_size,
191
+ latent_log_var=self.encoder.latent_log_var,
192
+ use_quant_conv=self.use_quant_conv,
193
+ causal_decoder=self.decoder.causal,
194
+ timestep_conditioning=self.decoder.timestep_conditioning,
195
+ normalize_latent_channels=self.normalize_latent_channels,
196
+ )
197
+
198
+ @property
199
+ def is_video_supported(self):
200
+ """
201
+ Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
202
+ """
203
+ return self.dims != 2
204
+
205
+ @property
206
+ def spatial_downscale_factor(self):
207
+ return (
208
+ 2
209
+ ** len(
210
+ [
211
+ block
212
+ for block in self.encoder.blocks_desc
213
+ if block[0]
214
+ in [
215
+ "compress_space",
216
+ "compress_all",
217
+ "compress_all_res",
218
+ "compress_space_res",
219
+ ]
220
+ ]
221
+ )
222
+ * self.encoder.patch_size
223
+ )
224
+
225
+ @property
226
+ def temporal_downscale_factor(self):
227
+ return 2 ** len(
228
+ [
229
+ block
230
+ for block in self.encoder.blocks_desc
231
+ if block[0]
232
+ in [
233
+ "compress_time",
234
+ "compress_all",
235
+ "compress_all_res",
236
+ "compress_space_res",
237
+ ]
238
+ ]
239
+ )
240
+
241
+ def to_json_string(self) -> str:
242
+ import json
243
+
244
+ return json.dumps(self.config.__dict__)
245
+
246
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
247
+ if any([key.startswith("vae.") for key in state_dict.keys()]):
248
+ state_dict = {
249
+ key.replace("vae.", ""): value
250
+ for key, value in state_dict.items()
251
+ if key.startswith("vae.")
252
+ }
253
+ ckpt_state_dict = {
254
+ key: value
255
+ for key, value in state_dict.items()
256
+ if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
257
+ }
258
+
259
+ model_keys = set(name for name, _ in self.named_modules())
260
+
261
+ key_mapping = {
262
+ ".resnets.": ".res_blocks.",
263
+ "downsamplers.0": "downsample",
264
+ "upsamplers.0": "upsample",
265
+ }
266
+ converted_state_dict = {}
267
+ for key, value in ckpt_state_dict.items():
268
+ for k, v in key_mapping.items():
269
+ key = key.replace(k, v)
270
+
271
+ key_prefix = ".".join(key.split(".")[:-1])
272
+ if "norm" in key and key_prefix not in model_keys:
273
+ logger.info(
274
+ f"Removing key {key} from state_dict as it is not present in the model"
275
+ )
276
+ continue
277
+
278
+ converted_state_dict[key] = value
279
+
280
+ super().load_state_dict(converted_state_dict, strict=strict)
281
+
282
+ data_dict = {
283
+ key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
284
+ for key, value in state_dict.items()
285
+ if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
286
+ }
287
+ if len(data_dict) > 0:
288
+ self.register_buffer("std_of_means", data_dict["std-of-means"])
289
+ self.register_buffer(
290
+ "mean_of_means",
291
+ data_dict.get(
292
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
293
+ ),
294
+ )
295
+
296
+ def last_layer(self):
297
+ if hasattr(self.decoder, "conv_out"):
298
+ if isinstance(self.decoder.conv_out, nn.Sequential):
299
+ last_layer = self.decoder.conv_out[-1]
300
+ else:
301
+ last_layer = self.decoder.conv_out
302
+ else:
303
+ last_layer = self.decoder.layers[-1]
304
+ return last_layer
305
+
306
+ def set_use_tpu_flash_attention(self):
307
+ for block in self.decoder.up_blocks:
308
+ if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
309
+ for attention_block in block.attention_blocks:
310
+ attention_block.set_use_tpu_flash_attention()
311
+
312
+
313
+ class Encoder(nn.Module):
314
+ r"""
315
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
316
+
317
+ Args:
318
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
319
+ The number of dimensions to use in convolutions.
320
+ in_channels (`int`, *optional*, defaults to 3):
321
+ The number of input channels.
322
+ out_channels (`int`, *optional*, defaults to 3):
323
+ The number of output channels.
324
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
325
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
326
+ base_channels (`int`, *optional*, defaults to 128):
327
+ The number of output channels for the first convolutional layer.
328
+ norm_num_groups (`int`, *optional*, defaults to 32):
329
+ The number of groups for normalization.
330
+ patch_size (`int`, *optional*, defaults to 1):
331
+ The patch size to use. Should be a power of 2.
332
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
333
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
334
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
335
+ The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ dims: Union[int, Tuple[int, int]] = 3,
341
+ in_channels: int = 3,
342
+ out_channels: int = 3,
343
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
344
+ base_channels: int = 128,
345
+ norm_num_groups: int = 32,
346
+ patch_size: Union[int, Tuple[int]] = 1,
347
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
348
+ latent_log_var: str = "per_channel",
349
+ spatial_padding_mode: str = "zeros",
350
+ ):
351
+ super().__init__()
352
+ self.patch_size = patch_size
353
+ self.norm_layer = norm_layer
354
+ self.latent_channels = out_channels
355
+ self.latent_log_var = latent_log_var
356
+ self.blocks_desc = blocks
357
+
358
+ in_channels = in_channels * patch_size**2
359
+ output_channel = base_channels
360
+
361
+ self.conv_in = make_conv_nd(
362
+ dims=dims,
363
+ in_channels=in_channels,
364
+ out_channels=output_channel,
365
+ kernel_size=3,
366
+ stride=1,
367
+ padding=1,
368
+ causal=True,
369
+ spatial_padding_mode=spatial_padding_mode,
370
+ )
371
+
372
+ self.down_blocks = nn.ModuleList([])
373
+
374
+ for block_name, block_params in blocks:
375
+ input_channel = output_channel
376
+ if isinstance(block_params, int):
377
+ block_params = {"num_layers": block_params}
378
+
379
+ if block_name == "res_x":
380
+ block = UNetMidBlock3D(
381
+ dims=dims,
382
+ in_channels=input_channel,
383
+ num_layers=block_params["num_layers"],
384
+ resnet_eps=1e-6,
385
+ resnet_groups=norm_num_groups,
386
+ norm_layer=norm_layer,
387
+ spatial_padding_mode=spatial_padding_mode,
388
+ )
389
+ elif block_name == "res_x_y":
390
+ output_channel = block_params.get("multiplier", 2) * output_channel
391
+ block = ResnetBlock3D(
392
+ dims=dims,
393
+ in_channels=input_channel,
394
+ out_channels=output_channel,
395
+ eps=1e-6,
396
+ groups=norm_num_groups,
397
+ norm_layer=norm_layer,
398
+ spatial_padding_mode=spatial_padding_mode,
399
+ )
400
+ elif block_name == "compress_time":
401
+ block = make_conv_nd(
402
+ dims=dims,
403
+ in_channels=input_channel,
404
+ out_channels=output_channel,
405
+ kernel_size=3,
406
+ stride=(2, 1, 1),
407
+ causal=True,
408
+ spatial_padding_mode=spatial_padding_mode,
409
+ )
410
+ elif block_name == "compress_space":
411
+ block = make_conv_nd(
412
+ dims=dims,
413
+ in_channels=input_channel,
414
+ out_channels=output_channel,
415
+ kernel_size=3,
416
+ stride=(1, 2, 2),
417
+ causal=True,
418
+ spatial_padding_mode=spatial_padding_mode,
419
+ )
420
+ elif block_name == "compress_all":
421
+ block = make_conv_nd(
422
+ dims=dims,
423
+ in_channels=input_channel,
424
+ out_channels=output_channel,
425
+ kernel_size=3,
426
+ stride=(2, 2, 2),
427
+ causal=True,
428
+ spatial_padding_mode=spatial_padding_mode,
429
+ )
430
+ elif block_name == "compress_all_x_y":
431
+ output_channel = block_params.get("multiplier", 2) * output_channel
432
+ block = make_conv_nd(
433
+ dims=dims,
434
+ in_channels=input_channel,
435
+ out_channels=output_channel,
436
+ kernel_size=3,
437
+ stride=(2, 2, 2),
438
+ causal=True,
439
+ spatial_padding_mode=spatial_padding_mode,
440
+ )
441
+ elif block_name == "compress_all_res":
442
+ output_channel = block_params.get("multiplier", 2) * output_channel
443
+ block = SpaceToDepthDownsample(
444
+ dims=dims,
445
+ in_channels=input_channel,
446
+ out_channels=output_channel,
447
+ stride=(2, 2, 2),
448
+ spatial_padding_mode=spatial_padding_mode,
449
+ )
450
+ elif block_name == "compress_space_res":
451
+ output_channel = block_params.get("multiplier", 2) * output_channel
452
+ block = SpaceToDepthDownsample(
453
+ dims=dims,
454
+ in_channels=input_channel,
455
+ out_channels=output_channel,
456
+ stride=(1, 2, 2),
457
+ spatial_padding_mode=spatial_padding_mode,
458
+ )
459
+ elif block_name == "compress_time_res":
460
+ output_channel = block_params.get("multiplier", 2) * output_channel
461
+ block = SpaceToDepthDownsample(
462
+ dims=dims,
463
+ in_channels=input_channel,
464
+ out_channels=output_channel,
465
+ stride=(2, 1, 1),
466
+ spatial_padding_mode=spatial_padding_mode,
467
+ )
468
+ else:
469
+ raise ValueError(f"unknown block: {block_name}")
470
+
471
+ self.down_blocks.append(block)
472
+
473
+ # out
474
+ if norm_layer == "group_norm":
475
+ self.conv_norm_out = nn.GroupNorm(
476
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
477
+ )
478
+ elif norm_layer == "pixel_norm":
479
+ self.conv_norm_out = PixelNorm()
480
+ elif norm_layer == "layer_norm":
481
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
482
+
483
+ self.conv_act = nn.SiLU()
484
+
485
+ conv_out_channels = out_channels
486
+ if latent_log_var == "per_channel":
487
+ conv_out_channels *= 2
488
+ elif latent_log_var == "uniform":
489
+ conv_out_channels += 1
490
+ elif latent_log_var == "constant":
491
+ conv_out_channels += 1
492
+ elif latent_log_var != "none":
493
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
494
+ self.conv_out = make_conv_nd(
495
+ dims,
496
+ output_channel,
497
+ conv_out_channels,
498
+ 3,
499
+ padding=1,
500
+ causal=True,
501
+ spatial_padding_mode=spatial_padding_mode,
502
+ )
503
+
504
+ self.gradient_checkpointing = False
505
+
506
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
507
+ r"""The forward method of the `Encoder` class."""
508
+
509
+ sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
510
+ sample = self.conv_in(sample)
511
+
512
+ checkpoint_fn = (
513
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
514
+ if self.gradient_checkpointing and self.training
515
+ else lambda x: x
516
+ )
517
+
518
+ for down_block in self.down_blocks:
519
+ sample = checkpoint_fn(down_block)(sample)
520
+
521
+ sample = self.conv_norm_out(sample)
522
+ sample = self.conv_act(sample)
523
+ sample = self.conv_out(sample)
524
+
525
+ if self.latent_log_var == "uniform":
526
+ last_channel = sample[:, -1:, ...]
527
+ num_dims = sample.dim()
528
+
529
+ if num_dims == 4:
530
+ # For shape (B, C, H, W)
531
+ repeated_last_channel = last_channel.repeat(
532
+ 1, sample.shape[1] - 2, 1, 1
533
+ )
534
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
535
+ elif num_dims == 5:
536
+ # For shape (B, C, F, H, W)
537
+ repeated_last_channel = last_channel.repeat(
538
+ 1, sample.shape[1] - 2, 1, 1, 1
539
+ )
540
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
541
+ else:
542
+ raise ValueError(f"Invalid input shape: {sample.shape}")
543
+ elif self.latent_log_var == "constant":
544
+ sample = sample[:, :-1, ...]
545
+ approx_ln_0 = (
546
+ -30
547
+ ) # this is the minimal clamp value in DiagonalGaussianDistribution objects
548
+ sample = torch.cat(
549
+ [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
550
+ dim=1,
551
+ )
552
+
553
+ return sample
554
+
555
+
556
+ class Decoder(nn.Module):
557
+ r"""
558
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
559
+
560
+ Args:
561
+ dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
562
+ The number of dimensions to use in convolutions.
563
+ in_channels (`int`, *optional*, defaults to 3):
564
+ The number of input channels.
565
+ out_channels (`int`, *optional*, defaults to 3):
566
+ The number of output channels.
567
+ blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
568
+ The blocks to use. Each block is a tuple of the block name and the number of layers.
569
+ base_channels (`int`, *optional*, defaults to 128):
570
+ The number of output channels for the first convolutional layer.
571
+ norm_num_groups (`int`, *optional*, defaults to 32):
572
+ The number of groups for normalization.
573
+ patch_size (`int`, *optional*, defaults to 1):
574
+ The patch size to use. Should be a power of 2.
575
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
576
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
577
+ causal (`bool`, *optional*, defaults to `True`):
578
+ Whether to use causal convolutions or not.
579
+ """
580
+
581
+ def __init__(
582
+ self,
583
+ dims,
584
+ in_channels: int = 3,
585
+ out_channels: int = 3,
586
+ blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
587
+ base_channels: int = 128,
588
+ layers_per_block: int = 2,
589
+ norm_num_groups: int = 32,
590
+ patch_size: int = 1,
591
+ norm_layer: str = "group_norm",
592
+ causal: bool = True,
593
+ timestep_conditioning: bool = False,
594
+ spatial_padding_mode: str = "zeros",
595
+ ):
596
+ super().__init__()
597
+ self.patch_size = patch_size
598
+ self.layers_per_block = layers_per_block
599
+ out_channels = out_channels * patch_size**2
600
+ self.causal = causal
601
+ self.blocks_desc = blocks
602
+
603
+ # Compute output channel to be product of all channel-multiplier blocks
604
+ output_channel = base_channels
605
+ for block_name, block_params in list(reversed(blocks)):
606
+ block_params = block_params if isinstance(block_params, dict) else {}
607
+ if block_name == "res_x_y":
608
+ output_channel = output_channel * block_params.get("multiplier", 2)
609
+ if block_name == "compress_all":
610
+ output_channel = output_channel * block_params.get("multiplier", 1)
611
+
612
+ self.conv_in = make_conv_nd(
613
+ dims,
614
+ in_channels,
615
+ output_channel,
616
+ kernel_size=3,
617
+ stride=1,
618
+ padding=1,
619
+ causal=True,
620
+ spatial_padding_mode=spatial_padding_mode,
621
+ )
622
+
623
+ self.up_blocks = nn.ModuleList([])
624
+
625
+ for block_name, block_params in list(reversed(blocks)):
626
+ input_channel = output_channel
627
+ if isinstance(block_params, int):
628
+ block_params = {"num_layers": block_params}
629
+
630
+ if block_name == "res_x":
631
+ block = UNetMidBlock3D(
632
+ dims=dims,
633
+ in_channels=input_channel,
634
+ num_layers=block_params["num_layers"],
635
+ resnet_eps=1e-6,
636
+ resnet_groups=norm_num_groups,
637
+ norm_layer=norm_layer,
638
+ inject_noise=block_params.get("inject_noise", False),
639
+ timestep_conditioning=timestep_conditioning,
640
+ spatial_padding_mode=spatial_padding_mode,
641
+ )
642
+ elif block_name == "attn_res_x":
643
+ block = UNetMidBlock3D(
644
+ dims=dims,
645
+ in_channels=input_channel,
646
+ num_layers=block_params["num_layers"],
647
+ resnet_groups=norm_num_groups,
648
+ norm_layer=norm_layer,
649
+ inject_noise=block_params.get("inject_noise", False),
650
+ timestep_conditioning=timestep_conditioning,
651
+ attention_head_dim=block_params["attention_head_dim"],
652
+ spatial_padding_mode=spatial_padding_mode,
653
+ )
654
+ elif block_name == "res_x_y":
655
+ output_channel = output_channel // block_params.get("multiplier", 2)
656
+ block = ResnetBlock3D(
657
+ dims=dims,
658
+ in_channels=input_channel,
659
+ out_channels=output_channel,
660
+ eps=1e-6,
661
+ groups=norm_num_groups,
662
+ norm_layer=norm_layer,
663
+ inject_noise=block_params.get("inject_noise", False),
664
+ timestep_conditioning=False,
665
+ spatial_padding_mode=spatial_padding_mode,
666
+ )
667
+ elif block_name == "compress_time":
668
+ block = DepthToSpaceUpsample(
669
+ dims=dims,
670
+ in_channels=input_channel,
671
+ stride=(2, 1, 1),
672
+ spatial_padding_mode=spatial_padding_mode,
673
+ )
674
+ elif block_name == "compress_space":
675
+ block = DepthToSpaceUpsample(
676
+ dims=dims,
677
+ in_channels=input_channel,
678
+ stride=(1, 2, 2),
679
+ spatial_padding_mode=spatial_padding_mode,
680
+ )
681
+ elif block_name == "compress_all":
682
+ output_channel = output_channel // block_params.get("multiplier", 1)
683
+ block = DepthToSpaceUpsample(
684
+ dims=dims,
685
+ in_channels=input_channel,
686
+ stride=(2, 2, 2),
687
+ residual=block_params.get("residual", False),
688
+ out_channels_reduction_factor=block_params.get("multiplier", 1),
689
+ spatial_padding_mode=spatial_padding_mode,
690
+ )
691
+ else:
692
+ raise ValueError(f"unknown layer: {block_name}")
693
+
694
+ self.up_blocks.append(block)
695
+
696
+ if norm_layer == "group_norm":
697
+ self.conv_norm_out = nn.GroupNorm(
698
+ num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
699
+ )
700
+ elif norm_layer == "pixel_norm":
701
+ self.conv_norm_out = PixelNorm()
702
+ elif norm_layer == "layer_norm":
703
+ self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
704
+
705
+ self.conv_act = nn.SiLU()
706
+ self.conv_out = make_conv_nd(
707
+ dims,
708
+ output_channel,
709
+ out_channels,
710
+ 3,
711
+ padding=1,
712
+ causal=True,
713
+ spatial_padding_mode=spatial_padding_mode,
714
+ )
715
+
716
+ self.gradient_checkpointing = False
717
+
718
+ self.timestep_conditioning = timestep_conditioning
719
+
720
+ if timestep_conditioning:
721
+ self.timestep_scale_multiplier = nn.Parameter(
722
+ torch.tensor(1000.0, dtype=torch.float32)
723
+ )
724
+ self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
725
+ output_channel * 2, 0
726
+ )
727
+ self.last_scale_shift_table = nn.Parameter(
728
+ torch.randn(2, output_channel) / output_channel**0.5
729
+ )
730
+
731
+ def forward(
732
+ self,
733
+ sample: torch.FloatTensor,
734
+ target_shape,
735
+ timestep: Optional[torch.Tensor] = None,
736
+ ) -> torch.FloatTensor:
737
+ r"""The forward method of the `Decoder` class."""
738
+ assert target_shape is not None, "target_shape must be provided"
739
+ batch_size = sample.shape[0]
740
+
741
+ sample = self.conv_in(sample, causal=self.causal)
742
+
743
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
744
+
745
+ checkpoint_fn = (
746
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
747
+ if self.gradient_checkpointing and self.training
748
+ else lambda x: x
749
+ )
750
+
751
+ sample = sample.to(upscale_dtype)
752
+
753
+ if self.timestep_conditioning:
754
+ assert (
755
+ timestep is not None
756
+ ), "should pass timestep with timestep_conditioning=True"
757
+ scaled_timestep = timestep * self.timestep_scale_multiplier
758
+
759
+ for up_block in self.up_blocks:
760
+ if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
761
+ sample = checkpoint_fn(up_block)(
762
+ sample, causal=self.causal, timestep=scaled_timestep
763
+ )
764
+ else:
765
+ sample = checkpoint_fn(up_block)(sample, causal=self.causal)
766
+
767
+ sample = self.conv_norm_out(sample)
768
+
769
+ if self.timestep_conditioning:
770
+ embedded_timestep = self.last_time_embedder(
771
+ timestep=scaled_timestep.flatten(),
772
+ resolution=None,
773
+ aspect_ratio=None,
774
+ batch_size=sample.shape[0],
775
+ hidden_dtype=sample.dtype,
776
+ )
777
+ embedded_timestep = embedded_timestep.view(
778
+ batch_size, embedded_timestep.shape[-1], 1, 1, 1
779
+ )
780
+ ada_values = self.last_scale_shift_table[
781
+ None, ..., None, None, None
782
+ ] + embedded_timestep.reshape(
783
+ batch_size,
784
+ 2,
785
+ -1,
786
+ embedded_timestep.shape[-3],
787
+ embedded_timestep.shape[-2],
788
+ embedded_timestep.shape[-1],
789
+ )
790
+ shift, scale = ada_values.unbind(dim=1)
791
+ sample = sample * (1 + scale) + shift
792
+
793
+ sample = self.conv_act(sample)
794
+ sample = self.conv_out(sample, causal=self.causal)
795
+
796
+ sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
797
+
798
+ return sample
799
+
800
+
801
+ class UNetMidBlock3D(nn.Module):
802
+ """
803
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
804
+
805
+ Args:
806
+ in_channels (`int`): The number of input channels.
807
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
808
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
809
+ resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
810
+ resnet_groups (`int`, *optional*, defaults to 32):
811
+ The number of groups to use in the group normalization layers of the resnet blocks.
812
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
813
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
814
+ inject_noise (`bool`, *optional*, defaults to `False`):
815
+ Whether to inject noise into the hidden states.
816
+ timestep_conditioning (`bool`, *optional*, defaults to `False`):
817
+ Whether to condition the hidden states on the timestep.
818
+ attention_head_dim (`int`, *optional*, defaults to -1):
819
+ The dimension of the attention head. If -1, no attention is used.
820
+
821
+ Returns:
822
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
823
+ in_channels, frames, height, width)`.
824
+
825
+ """
826
+
827
+ def __init__(
828
+ self,
829
+ dims: Union[int, Tuple[int, int]],
830
+ in_channels: int,
831
+ dropout: float = 0.0,
832
+ num_layers: int = 1,
833
+ resnet_eps: float = 1e-6,
834
+ resnet_groups: int = 32,
835
+ norm_layer: str = "group_norm",
836
+ inject_noise: bool = False,
837
+ timestep_conditioning: bool = False,
838
+ attention_head_dim: int = -1,
839
+ spatial_padding_mode: str = "zeros",
840
+ ):
841
+ super().__init__()
842
+ resnet_groups = (
843
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
844
+ )
845
+ self.timestep_conditioning = timestep_conditioning
846
+
847
+ if timestep_conditioning:
848
+ self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
849
+ in_channels * 4, 0
850
+ )
851
+
852
+ self.res_blocks = nn.ModuleList(
853
+ [
854
+ ResnetBlock3D(
855
+ dims=dims,
856
+ in_channels=in_channels,
857
+ out_channels=in_channels,
858
+ eps=resnet_eps,
859
+ groups=resnet_groups,
860
+ dropout=dropout,
861
+ norm_layer=norm_layer,
862
+ inject_noise=inject_noise,
863
+ timestep_conditioning=timestep_conditioning,
864
+ spatial_padding_mode=spatial_padding_mode,
865
+ )
866
+ for _ in range(num_layers)
867
+ ]
868
+ )
869
+
870
+ self.attention_blocks = None
871
+
872
+ if attention_head_dim > 0:
873
+ if attention_head_dim > in_channels:
874
+ raise ValueError(
875
+ "attention_head_dim must be less than or equal to in_channels"
876
+ )
877
+
878
+ self.attention_blocks = nn.ModuleList(
879
+ [
880
+ Attention(
881
+ query_dim=in_channels,
882
+ heads=in_channels // attention_head_dim,
883
+ dim_head=attention_head_dim,
884
+ bias=True,
885
+ out_bias=True,
886
+ qk_norm="rms_norm",
887
+ residual_connection=True,
888
+ )
889
+ for _ in range(num_layers)
890
+ ]
891
+ )
892
+
893
+ def forward(
894
+ self,
895
+ hidden_states: torch.FloatTensor,
896
+ causal: bool = True,
897
+ timestep: Optional[torch.Tensor] = None,
898
+ ) -> torch.FloatTensor:
899
+ timestep_embed = None
900
+ if self.timestep_conditioning:
901
+ assert (
902
+ timestep is not None
903
+ ), "should pass timestep with timestep_conditioning=True"
904
+ batch_size = hidden_states.shape[0]
905
+ timestep_embed = self.time_embedder(
906
+ timestep=timestep.flatten(),
907
+ resolution=None,
908
+ aspect_ratio=None,
909
+ batch_size=batch_size,
910
+ hidden_dtype=hidden_states.dtype,
911
+ )
912
+ timestep_embed = timestep_embed.view(
913
+ batch_size, timestep_embed.shape[-1], 1, 1, 1
914
+ )
915
+
916
+ if self.attention_blocks:
917
+ for resnet, attention in zip(self.res_blocks, self.attention_blocks):
918
+ hidden_states = resnet(
919
+ hidden_states, causal=causal, timestep=timestep_embed
920
+ )
921
+
922
+ # Reshape the hidden states to be (batch_size, frames * height * width, channel)
923
+ batch_size, channel, frames, height, width = hidden_states.shape
924
+ hidden_states = hidden_states.view(
925
+ batch_size, channel, frames * height * width
926
+ ).transpose(1, 2)
927
+
928
+ if attention.use_tpu_flash_attention:
929
+ # Pad the second dimension to be divisible by block_k_major (block in flash attention)
930
+ seq_len = hidden_states.shape[1]
931
+ block_k_major = 512
932
+ pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
933
+ if pad_len > 0:
934
+ hidden_states = F.pad(
935
+ hidden_states, (0, 0, 0, pad_len), "constant", 0
936
+ )
937
+
938
+ # Create a mask with ones for the original sequence length and zeros for the padded indexes
939
+ mask = torch.ones(
940
+ (hidden_states.shape[0], seq_len),
941
+ device=hidden_states.device,
942
+ dtype=hidden_states.dtype,
943
+ )
944
+ if pad_len > 0:
945
+ mask = F.pad(mask, (0, pad_len), "constant", 0)
946
+
947
+ hidden_states = attention(
948
+ hidden_states,
949
+ attention_mask=(
950
+ None if not attention.use_tpu_flash_attention else mask
951
+ ),
952
+ )
953
+
954
+ if attention.use_tpu_flash_attention:
955
+ # Remove the padding
956
+ if pad_len > 0:
957
+ hidden_states = hidden_states[:, :-pad_len, :]
958
+
959
+ # Reshape the hidden states back to (batch_size, channel, frames, height, width)
960
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
961
+ batch_size, channel, frames, height, width
962
+ )
963
+ else:
964
+ for resnet in self.res_blocks:
965
+ hidden_states = resnet(
966
+ hidden_states, causal=causal, timestep=timestep_embed
967
+ )
968
+
969
+ return hidden_states
970
+
971
+
972
+ class SpaceToDepthDownsample(nn.Module):
973
+ def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
974
+ super().__init__()
975
+ self.stride = stride
976
+ self.group_size = in_channels * np.prod(stride) // out_channels
977
+ self.conv = make_conv_nd(
978
+ dims=dims,
979
+ in_channels=in_channels,
980
+ out_channels=out_channels // np.prod(stride),
981
+ kernel_size=3,
982
+ stride=1,
983
+ causal=True,
984
+ spatial_padding_mode=spatial_padding_mode,
985
+ )
986
+
987
+ def forward(self, x, causal: bool = True):
988
+ if self.stride[0] == 2:
989
+ x = torch.cat(
990
+ [x[:, :, :1, :, :], x], dim=2
991
+ ) # duplicate first frames for padding
992
+
993
+ # skip connection
994
+ x_in = rearrange(
995
+ x,
996
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
997
+ p1=self.stride[0],
998
+ p2=self.stride[1],
999
+ p3=self.stride[2],
1000
+ )
1001
+ x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
1002
+ x_in = x_in.mean(dim=2)
1003
+
1004
+ # conv
1005
+ x = self.conv(x, causal=causal)
1006
+ x = rearrange(
1007
+ x,
1008
+ "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
1009
+ p1=self.stride[0],
1010
+ p2=self.stride[1],
1011
+ p3=self.stride[2],
1012
+ )
1013
+
1014
+ x = x + x_in
1015
+
1016
+ return x
1017
+
1018
+
1019
+ class DepthToSpaceUpsample(nn.Module):
1020
+ def __init__(
1021
+ self,
1022
+ dims,
1023
+ in_channels,
1024
+ stride,
1025
+ residual=False,
1026
+ out_channels_reduction_factor=1,
1027
+ spatial_padding_mode="zeros",
1028
+ ):
1029
+ super().__init__()
1030
+ self.stride = stride
1031
+ self.out_channels = (
1032
+ np.prod(stride) * in_channels // out_channels_reduction_factor
1033
+ )
1034
+ self.conv = make_conv_nd(
1035
+ dims=dims,
1036
+ in_channels=in_channels,
1037
+ out_channels=self.out_channels,
1038
+ kernel_size=3,
1039
+ stride=1,
1040
+ causal=True,
1041
+ spatial_padding_mode=spatial_padding_mode,
1042
+ )
1043
+ self.residual = residual
1044
+ self.out_channels_reduction_factor = out_channels_reduction_factor
1045
+
1046
+ def forward(self, x, causal: bool = True):
1047
+ if self.residual:
1048
+ # Reshape and duplicate the input to match the output shape
1049
+ x_in = rearrange(
1050
+ x,
1051
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
1052
+ p1=self.stride[0],
1053
+ p2=self.stride[1],
1054
+ p3=self.stride[2],
1055
+ )
1056
+ num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
1057
+ x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
1058
+ if self.stride[0] == 2:
1059
+ x_in = x_in[:, :, 1:, :, :]
1060
+ x = self.conv(x, causal=causal)
1061
+ x = rearrange(
1062
+ x,
1063
+ "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
1064
+ p1=self.stride[0],
1065
+ p2=self.stride[1],
1066
+ p3=self.stride[2],
1067
+ )
1068
+ if self.stride[0] == 2:
1069
+ x = x[:, :, 1:, :, :]
1070
+ if self.residual:
1071
+ x = x + x_in
1072
+ return x
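An illustrative shape trace (not from the repository; all tensor sizes invented) of the depth-to-space rearrange used by DepthToSpaceUpsample above: the preceding conv expands channels by prod(stride), and the rearrange folds those channels back into depth, height, and width.

import torch
from einops import rearrange

# Pretend the 3x3x3 conv has already expanded channels by prod(stride) = 8.
conv_out = torch.randn(1, 64 * 8, 4, 8, 8)          # (b, c * 8, d, h, w)
y = rearrange(
    conv_out,
    "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
    p1=2, p2=2, p3=2,
)
print(y.shape)   # torch.Size([1, 64, 8, 16, 16]) -- every spatio-temporal dimension doubled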
1073
+
1074
+
1075
+ class LayerNorm(nn.Module):
1076
+ def __init__(self, dim, eps, elementwise_affine=True) -> None:
1077
+ super().__init__()
1078
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
1079
+
1080
+ def forward(self, x):
1081
+ x = rearrange(x, "b c d h w -> b d h w c")
1082
+ x = self.norm(x)
1083
+ x = rearrange(x, "b d h w c -> b c d h w")
1084
+ return x
1085
+
1086
+
1087
+ class ResnetBlock3D(nn.Module):
1088
+ r"""
1089
+ A Resnet block.
1090
+
1091
+ Parameters:
1092
+ in_channels (`int`): The number of channels in the input.
1093
+ out_channels (`int`, *optional*, default to be `None`):
1094
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
1095
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
1096
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
1097
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
1098
+ """
1099
+
1100
+ def __init__(
1101
+ self,
1102
+ dims: Union[int, Tuple[int, int]],
1103
+ in_channels: int,
1104
+ out_channels: Optional[int] = None,
1105
+ dropout: float = 0.0,
1106
+ groups: int = 32,
1107
+ eps: float = 1e-6,
1108
+ norm_layer: str = "group_norm",
1109
+ inject_noise: bool = False,
1110
+ timestep_conditioning: bool = False,
1111
+ spatial_padding_mode: str = "zeros",
1112
+ ):
1113
+ super().__init__()
1114
+ self.in_channels = in_channels
1115
+ out_channels = in_channels if out_channels is None else out_channels
1116
+ self.out_channels = out_channels
1117
+ self.inject_noise = inject_noise
1118
+
1119
+ if norm_layer == "group_norm":
1120
+ self.norm1 = nn.GroupNorm(
1121
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
1122
+ )
1123
+ elif norm_layer == "pixel_norm":
1124
+ self.norm1 = PixelNorm()
1125
+ elif norm_layer == "layer_norm":
1126
+ self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
1127
+
1128
+ self.non_linearity = nn.SiLU()
1129
+
1130
+ self.conv1 = make_conv_nd(
1131
+ dims,
1132
+ in_channels,
1133
+ out_channels,
1134
+ kernel_size=3,
1135
+ stride=1,
1136
+ padding=1,
1137
+ causal=True,
1138
+ spatial_padding_mode=spatial_padding_mode,
1139
+ )
1140
+
1141
+ if inject_noise:
1142
+ self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
1143
+
1144
+ if norm_layer == "group_norm":
1145
+ self.norm2 = nn.GroupNorm(
1146
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
1147
+ )
1148
+ elif norm_layer == "pixel_norm":
1149
+ self.norm2 = PixelNorm()
1150
+ elif norm_layer == "layer_norm":
1151
+ self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
1152
+
1153
+ self.dropout = torch.nn.Dropout(dropout)
1154
+
1155
+ self.conv2 = make_conv_nd(
1156
+ dims,
1157
+ out_channels,
1158
+ out_channels,
1159
+ kernel_size=3,
1160
+ stride=1,
1161
+ padding=1,
1162
+ causal=True,
1163
+ spatial_padding_mode=spatial_padding_mode,
1164
+ )
1165
+
1166
+ if inject_noise:
1167
+ self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
1168
+
1169
+ self.conv_shortcut = (
1170
+ make_linear_nd(
1171
+ dims=dims, in_channels=in_channels, out_channels=out_channels
1172
+ )
1173
+ if in_channels != out_channels
1174
+ else nn.Identity()
1175
+ )
1176
+
1177
+ self.norm3 = (
1178
+ LayerNorm(in_channels, eps=eps, elementwise_affine=True)
1179
+ if in_channels != out_channels
1180
+ else nn.Identity()
1181
+ )
1182
+
1183
+ self.timestep_conditioning = timestep_conditioning
1184
+
1185
+ if timestep_conditioning:
1186
+ self.scale_shift_table = nn.Parameter(
1187
+ torch.randn(4, in_channels) / in_channels**0.5
1188
+ )
1189
+
1190
+ def _feed_spatial_noise(
1191
+ self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
1192
+ ) -> torch.FloatTensor:
1193
+ spatial_shape = hidden_states.shape[-2:]
1194
+ device = hidden_states.device
1195
+ dtype = hidden_states.dtype
1196
+
1197
+ # similar to the "explicit noise inputs" method in style-gan
1198
+ spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
1199
+ scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
1200
+ hidden_states = hidden_states + scaled_noise
1201
+
1202
+ return hidden_states
1203
+
1204
+ def forward(
1205
+ self,
1206
+ input_tensor: torch.FloatTensor,
1207
+ causal: bool = True,
1208
+ timestep: Optional[torch.Tensor] = None,
1209
+ ) -> torch.FloatTensor:
1210
+ hidden_states = input_tensor
1211
+ batch_size = hidden_states.shape[0]
1212
+
1213
+ hidden_states = self.norm1(hidden_states)
1214
+ if self.timestep_conditioning:
1215
+ assert (
1216
+ timestep is not None
1217
+ ), "should pass timestep with timestep_conditioning=True"
1218
+ ada_values = self.scale_shift_table[
1219
+ None, ..., None, None, None
1220
+ ] + timestep.reshape(
1221
+ batch_size,
1222
+ 4,
1223
+ -1,
1224
+ timestep.shape[-3],
1225
+ timestep.shape[-2],
1226
+ timestep.shape[-1],
1227
+ )
1228
+ shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
1229
+
1230
+ hidden_states = hidden_states * (1 + scale1) + shift1
1231
+
1232
+ hidden_states = self.non_linearity(hidden_states)
1233
+
1234
+ hidden_states = self.conv1(hidden_states, causal=causal)
1235
+
1236
+ if self.inject_noise:
1237
+ hidden_states = self._feed_spatial_noise(
1238
+ hidden_states, self.per_channel_scale1
1239
+ )
1240
+
1241
+ hidden_states = self.norm2(hidden_states)
1242
+
1243
+ if self.timestep_conditioning:
1244
+ hidden_states = hidden_states * (1 + scale2) + shift2
1245
+
1246
+ hidden_states = self.non_linearity(hidden_states)
1247
+
1248
+ hidden_states = self.dropout(hidden_states)
1249
+
1250
+ hidden_states = self.conv2(hidden_states, causal=causal)
1251
+
1252
+ if self.inject_noise:
1253
+ hidden_states = self._feed_spatial_noise(
1254
+ hidden_states, self.per_channel_scale2
1255
+ )
1256
+
1257
+ input_tensor = self.norm3(input_tensor)
1258
+
1259
+ batch_size = input_tensor.shape[0]
1260
+
1261
+ input_tensor = self.conv_shortcut(input_tensor)
1262
+
1263
+ output_tensor = input_tensor + hidden_states
1264
+
1265
+ return output_tensor
1266
+
1267
+
1268
+ def patchify(x, patch_size_hw, patch_size_t=1):
1269
+ if patch_size_hw == 1 and patch_size_t == 1:
1270
+ return x
1271
+ if x.dim() == 4:
1272
+ x = rearrange(
1273
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
1274
+ )
1275
+ elif x.dim() == 5:
1276
+ x = rearrange(
1277
+ x,
1278
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
1279
+ p=patch_size_t,
1280
+ q=patch_size_hw,
1281
+ r=patch_size_hw,
1282
+ )
1283
+ else:
1284
+ raise ValueError(f"Invalid input shape: {x.shape}")
1285
+
1286
+ return x
1287
+
1288
+
1289
+ def unpatchify(x, patch_size_hw, patch_size_t=1):
1290
+ if patch_size_hw == 1 and patch_size_t == 1:
1291
+ return x
1292
+
1293
+ if x.dim() == 4:
1294
+ x = rearrange(
1295
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
1296
+ )
1297
+ elif x.dim() == 5:
1298
+ x = rearrange(
1299
+ x,
1300
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
1301
+ p=patch_size_t,
1302
+ q=patch_size_hw,
1303
+ r=patch_size_hw,
1304
+ )
1305
+
1306
+ return x
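A minimal sketch (sizes invented) of what patchify/unpatchify above do to a video tensor: 4x4 spatial patches are folded into the channel dimension, and unpatchify inverts the permutation exactly.

import torch
from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import patchify, unpatchify

x = torch.randn(2, 3, 8, 64, 64)                     # (batch, channels, frames, height, width)
p = patchify(x, patch_size_hw=4, patch_size_t=1)
print(p.shape)                                        # torch.Size([2, 48, 8, 16, 16])
u = unpatchify(p, patch_size_hw=4, patch_size_t=1)
print(torch.equal(x, u))                              # True -- the mapping is a pure permutation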
1307
+
1308
+
1309
+ def create_video_autoencoder_demo_config(
1310
+ latent_channels: int = 64,
1311
+ ):
1312
+ encoder_blocks = [
1313
+ ("res_x", {"num_layers": 2}),
1314
+ ("compress_space_res", {"multiplier": 2}),
1315
+ ("res_x", {"num_layers": 2}),
1316
+ ("compress_time_res", {"multiplier": 2}),
1317
+ ("res_x", {"num_layers": 1}),
1318
+ ("compress_all_res", {"multiplier": 2}),
1319
+ ("res_x", {"num_layers": 1}),
1320
+ ("compress_all_res", {"multiplier": 2}),
1321
+ ("res_x", {"num_layers": 1}),
1322
+ ]
1323
+ decoder_blocks = [
1324
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1325
+ ("compress_all", {"residual": True, "multiplier": 2}),
1326
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1327
+ ("compress_all", {"residual": True, "multiplier": 2}),
1328
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1329
+ ("compress_all", {"residual": True, "multiplier": 2}),
1330
+ ("res_x", {"num_layers": 2, "inject_noise": False}),
1331
+ ]
1332
+ return {
1333
+ "_class_name": "CausalVideoAutoencoder",
1334
+ "dims": 3,
1335
+ "encoder_blocks": encoder_blocks,
1336
+ "decoder_blocks": decoder_blocks,
1337
+ "latent_channels": latent_channels,
1338
+ "norm_layer": "pixel_norm",
1339
+ "patch_size": 4,
1340
+ "latent_log_var": "uniform",
1341
+ "use_quant_conv": False,
1342
+ "causal_decoder": False,
1343
+ "timestep_conditioning": True,
1344
+ "spatial_padding_mode": "replicate",
1345
+ }
1346
+
1347
+
1348
+ def test_vae_patchify_unpatchify():
1349
+ import torch
1350
+
1351
+ x = torch.randn(2, 3, 8, 64, 64)
1352
+ x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
1353
+ x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
1354
+ assert torch.allclose(x, x_unpatched)
1355
+
1356
+
1357
+ def demo_video_autoencoder_forward_backward():
1358
+ # Configuration for the VideoAutoencoder
1359
+ config = create_video_autoencoder_demo_config()
1360
+
1361
+ # Instantiate the VideoAutoencoder with the specified configuration
1362
+ video_autoencoder = CausalVideoAutoencoder.from_config(config)
1363
+
1364
+ print(video_autoencoder)
1365
+ video_autoencoder.eval()
1366
+ # Print the total number of parameters in the video autoencoder
1367
+ total_params = sum(p.numel() for p in video_autoencoder.parameters())
1368
+ print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
1369
+
1370
+ # Create a mock input tensor simulating a batch of videos
1371
+ # Shape: (batch_size, channels, depth, height, width)
1372
+ # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame
1373
+ input_videos = torch.randn(2, 3, 17, 64, 64)
1374
+
1375
+ # Forward pass: encode and decode the input videos
1376
+ latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1377
+ print(f"input shape={input_videos.shape}")
1378
+ print(f"latent shape={latent.shape}")
1379
+
1380
+ timestep = torch.ones(input_videos.shape[0]) * 0.1
1381
+ reconstructed_videos = video_autoencoder.decode(
1382
+ latent, target_shape=input_videos.shape, timestep=timestep
1383
+ ).sample
1384
+
1385
+ print(f"reconstructed shape={reconstructed_videos.shape}")
1386
+
1387
+ # Validate that single image gets treated the same way as first frame
1388
+ input_image = input_videos[:, :, :1, :, :]
1389
+ image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
1390
+ _ = video_autoencoder.decode(
1391
+ image_latent, target_shape=image_latent.shape, timestep=timestep
1392
+ ).sample
1393
+
1394
+ first_frame_latent = latent[:, :, :1, :, :]
1395
+
1396
+ assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
1397
+ # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
1398
+ # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
1399
+ # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
1400
+
1401
+ # Calculate the loss (e.g., mean squared error)
1402
+ loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
1403
+
1404
+ # Perform backward pass
1405
+ loss.backward()
1406
+
1407
+ print(f"Demo completed with loss: {loss.item()}")
1408
+
1409
+
1410
+ # Ensure to call the demo function to execute the forward and backward pass
1411
+ if __name__ == "__main__":
1412
+ demo_video_autoencoder_forward_backward()
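A minimal standalone sketch (all sizes invented, no repository code required) of the AdaLN-style modulation that ResnetBlock3D and the decoder above apply when timestep_conditioning is enabled: a learned table plus a per-sample timestep embedding is split into (shift, scale) pairs and applied as x * (1 + scale) + shift.

import torch

batch, channels = 2, 8
scale_shift_table = torch.randn(2, channels) / channels**0.5        # learned parameter
embedded_timestep = torch.randn(batch, 2 * channels, 1, 1, 1)        # from the timestep embedder

ada = scale_shift_table[None, ..., None, None, None] + embedded_timestep.reshape(
    batch, 2, channels, 1, 1, 1
)
shift, scale = ada.unbind(dim=1)                                      # each (batch, channels, 1, 1, 1)

x = torch.randn(batch, channels, 3, 4, 4)                             # (b, c, f, h, w)
x = x * (1 + scale) + shift                                           # broadcast over f, h, w
print(x.shape)                                                        # torch.Size([2, 8, 3, 4, 4])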
flash_head/ltx_video/models/autoencoders/conv_nd_factory.py ADDED
@@ -0,0 +1,90 @@
1
+ from typing import Tuple, Union
2
+
3
+ import torch
4
+
5
+ from flash_head.ltx_video.models.autoencoders.dual_conv3d import DualConv3d
6
+ from flash_head.ltx_video.models.autoencoders.causal_conv3d import CausalConv3d
7
+
8
+
9
+ def make_conv_nd(
10
+ dims: Union[int, Tuple[int, int]],
11
+ in_channels: int,
12
+ out_channels: int,
13
+ kernel_size: int,
14
+ stride=1,
15
+ padding=0,
16
+ dilation=1,
17
+ groups=1,
18
+ bias=True,
19
+ causal=False,
20
+ spatial_padding_mode="zeros",
21
+ temporal_padding_mode="zeros",
22
+ ):
23
+ if not (spatial_padding_mode == temporal_padding_mode or causal):
24
+ raise NotImplementedError("spatial and temporal padding modes must be equal")
25
+ if dims == 2:
26
+ return torch.nn.Conv2d(
27
+ in_channels=in_channels,
28
+ out_channels=out_channels,
29
+ kernel_size=kernel_size,
30
+ stride=stride,
31
+ padding=padding,
32
+ dilation=dilation,
33
+ groups=groups,
34
+ bias=bias,
35
+ padding_mode=spatial_padding_mode,
36
+ )
37
+ elif dims == 3:
38
+ if causal:
39
+ return CausalConv3d(
40
+ in_channels=in_channels,
41
+ out_channels=out_channels,
42
+ kernel_size=kernel_size,
43
+ stride=stride,
44
+ padding=padding,
45
+ dilation=dilation,
46
+ groups=groups,
47
+ bias=bias,
48
+ spatial_padding_mode=spatial_padding_mode,
49
+ )
50
+ return torch.nn.Conv3d(
51
+ in_channels=in_channels,
52
+ out_channels=out_channels,
53
+ kernel_size=kernel_size,
54
+ stride=stride,
55
+ padding=padding,
56
+ dilation=dilation,
57
+ groups=groups,
58
+ bias=bias,
59
+ padding_mode=spatial_padding_mode,
60
+ )
61
+ elif dims == (2, 1):
62
+ return DualConv3d(
63
+ in_channels=in_channels,
64
+ out_channels=out_channels,
65
+ kernel_size=kernel_size,
66
+ stride=stride,
67
+ padding=padding,
68
+ bias=bias,
69
+ padding_mode=spatial_padding_mode,
70
+ )
71
+ else:
72
+ raise ValueError(f"unsupported dimensions: {dims}")
73
+
74
+
75
+ def make_linear_nd(
76
+ dims: int,
77
+ in_channels: int,
78
+ out_channels: int,
79
+ bias=True,
80
+ ):
81
+ if dims == 2:
82
+ return torch.nn.Conv2d(
83
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
84
+ )
85
+ elif dims == 3 or dims == (2, 1):
86
+ return torch.nn.Conv3d(
87
+ in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
88
+ )
89
+ else:
90
+ raise ValueError(f"unsupported dimensions: {dims}")
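A hedged usage sketch of the make_conv_nd factory above (channel counts and sizes are invented): dims selects the convolution type, with dims=2 giving nn.Conv2d, dims=3 giving nn.Conv3d (or CausalConv3d when causal=True), and dims=(2, 1) giving DualConv3d.

import torch
from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd

conv = make_conv_nd(
    dims=3,
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    padding=1,
    spatial_padding_mode="replicate",
)
x = torch.randn(1, 16, 9, 32, 32)   # (batch, channels, frames, height, width)
print(conv(x).shape)                # torch.Size([1, 32, 9, 32, 32]) -- padding=1 preserves the size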
flash_head/ltx_video/models/autoencoders/dual_conv3d.py ADDED
@@ -0,0 +1,217 @@
1
+ import math
2
+ from typing import Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+
9
+
10
+ class DualConv3d(nn.Module):
11
+ def __init__(
12
+ self,
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size,
16
+ stride: Union[int, Tuple[int, int, int]] = 1,
17
+ padding: Union[int, Tuple[int, int, int]] = 0,
18
+ dilation: Union[int, Tuple[int, int, int]] = 1,
19
+ groups=1,
20
+ bias=True,
21
+ padding_mode="zeros",
22
+ ):
23
+ super(DualConv3d, self).__init__()
24
+
25
+ self.in_channels = in_channels
26
+ self.out_channels = out_channels
27
+ self.padding_mode = padding_mode
28
+ # Ensure kernel_size, stride, padding, and dilation are tuples of length 3
29
+ if isinstance(kernel_size, int):
30
+ kernel_size = (kernel_size, kernel_size, kernel_size)
31
+ if kernel_size == (1, 1, 1):
32
+ raise ValueError(
33
+ "kernel_size must be greater than 1. Use make_linear_nd instead."
34
+ )
35
+ if isinstance(stride, int):
36
+ stride = (stride, stride, stride)
37
+ if isinstance(padding, int):
38
+ padding = (padding, padding, padding)
39
+ if isinstance(dilation, int):
40
+ dilation = (dilation, dilation, dilation)
41
+
42
+ # Set parameters for convolutions
43
+ self.groups = groups
44
+ self.bias = bias
45
+
46
+ # Define the size of the channels after the first convolution
47
+ intermediate_channels = (
48
+ out_channels if in_channels < out_channels else in_channels
49
+ )
50
+
51
+ # Define parameters for the first convolution
52
+ self.weight1 = nn.Parameter(
53
+ torch.Tensor(
54
+ intermediate_channels,
55
+ in_channels // groups,
56
+ 1,
57
+ kernel_size[1],
58
+ kernel_size[2],
59
+ )
60
+ )
61
+ self.stride1 = (1, stride[1], stride[2])
62
+ self.padding1 = (0, padding[1], padding[2])
63
+ self.dilation1 = (1, dilation[1], dilation[2])
64
+ if bias:
65
+ self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
66
+ else:
67
+ self.register_parameter("bias1", None)
68
+
69
+ # Define parameters for the second convolution
70
+ self.weight2 = nn.Parameter(
71
+ torch.Tensor(
72
+ out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
73
+ )
74
+ )
75
+ self.stride2 = (stride[0], 1, 1)
76
+ self.padding2 = (padding[0], 0, 0)
77
+ self.dilation2 = (dilation[0], 1, 1)
78
+ if bias:
79
+ self.bias2 = nn.Parameter(torch.Tensor(out_channels))
80
+ else:
81
+ self.register_parameter("bias2", None)
82
+
83
+ # Initialize weights and biases
84
+ self.reset_parameters()
85
+
86
+ def reset_parameters(self):
87
+ nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
88
+ nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
89
+ if self.bias:
90
+ fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
91
+ bound1 = 1 / math.sqrt(fan_in1)
92
+ nn.init.uniform_(self.bias1, -bound1, bound1)
93
+ fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
94
+ bound2 = 1 / math.sqrt(fan_in2)
95
+ nn.init.uniform_(self.bias2, -bound2, bound2)
96
+
97
+ def forward(self, x, use_conv3d=False, skip_time_conv=False):
98
+ if use_conv3d:
99
+ return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
100
+ else:
101
+ return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
102
+
103
+ def forward_with_3d(self, x, skip_time_conv):
104
+ # First convolution
105
+ x = F.conv3d(
106
+ x,
107
+ self.weight1,
108
+ self.bias1,
109
+ self.stride1,
110
+ self.padding1,
111
+ self.dilation1,
112
+ self.groups,
113
+ padding_mode=self.padding_mode,
114
+ )
115
+
116
+ if skip_time_conv:
117
+ return x
118
+
119
+ # Second convolution
120
+ x = F.conv3d(
121
+ x,
122
+ self.weight2,
123
+ self.bias2,
124
+ self.stride2,
125
+ self.padding2,
126
+ self.dilation2,
127
+ self.groups,
128
+ padding_mode=self.padding_mode,
129
+ )
130
+
131
+ return x
132
+
133
+ def forward_with_2d(self, x, skip_time_conv):
134
+ b, c, d, h, w = x.shape
135
+
136
+ # First 2D convolution
137
+ x = rearrange(x, "b c d h w -> (b d) c h w")
138
+ # Squeeze the depth dimension out of weight1 since it's 1
139
+ weight1 = self.weight1.squeeze(2)
140
+ # Select stride, padding, and dilation for the 2D convolution
141
+ stride1 = (self.stride1[1], self.stride1[2])
142
+ padding1 = (self.padding1[1], self.padding1[2])
143
+ dilation1 = (self.dilation1[1], self.dilation1[2])
144
+ x = F.conv2d(
145
+ x,
146
+ weight1,
147
+ self.bias1,
148
+ stride1,
149
+ padding1,
150
+ dilation1,
151
+ self.groups,
152
+ padding_mode=self.padding_mode,
153
+ )
154
+
155
+ _, _, h, w = x.shape
156
+
157
+ if skip_time_conv:
158
+ x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
159
+ return x
160
+
161
+ # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
162
+ x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
163
+
164
+ # Reshape weight2 to match the expected dimensions for conv1d
165
+ weight2 = self.weight2.squeeze(-1).squeeze(-1)
166
+ # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
167
+ stride2 = self.stride2[0]
168
+ padding2 = self.padding2[0]
169
+ dilation2 = self.dilation2[0]
170
+ x = F.conv1d(
171
+ x,
172
+ weight2,
173
+ self.bias2,
174
+ stride2,
175
+ padding2,
176
+ dilation2,
177
+ self.groups,
178
+ padding_mode=self.padding_mode,
179
+ )
180
+ x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
181
+
182
+ return x
183
+
184
+ @property
185
+ def weight(self):
186
+ return self.weight2
187
+
188
+
189
+ def test_dual_conv3d_consistency():
190
+ # Initialize parameters
191
+ in_channels = 3
192
+ out_channels = 5
193
+ kernel_size = (3, 3, 3)
194
+ stride = (2, 2, 2)
195
+ padding = (1, 1, 1)
196
+
197
+ # Create an instance of the DualConv3d class
198
+ dual_conv3d = DualConv3d(
199
+ in_channels=in_channels,
200
+ out_channels=out_channels,
201
+ kernel_size=kernel_size,
202
+ stride=stride,
203
+ padding=padding,
204
+ bias=True,
205
+ )
206
+
207
+ # Example input tensor
208
+ test_input = torch.randn(1, 3, 10, 10, 10)
209
+
210
+ # Perform forward passes with both 3D and 2D settings
211
+ output_conv3d = dual_conv3d(test_input, use_conv3d=True)
212
+ output_2d = dual_conv3d(test_input, use_conv3d=False)
213
+
214
+ # Assert that the outputs from both methods are sufficiently close
215
+ assert torch.allclose(
216
+ output_conv3d, output_2d, atol=1e-6
217
+ ), "Outputs are not consistent between 3D and 2D convolutions."
flash_head/ltx_video/models/autoencoders/pixel_norm.py ADDED
@@ -0,0 +1,12 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class PixelNorm(nn.Module):
6
+ def __init__(self, dim=1, eps=1e-8):
7
+ super(PixelNorm, self).__init__()
8
+ self.dim = dim
9
+ self.eps = eps
10
+
11
+ def forward(self, x):
12
+ return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
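A quick sanity sketch (not part of the repository) showing what PixelNorm does: every (frame, y, x) position is rescaled so that its channel vector has approximately unit RMS.

import torch
from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm

norm = PixelNorm(dim=1)
x = torch.randn(2, 16, 4, 8, 8) * 3.0
y = norm(x)
rms = torch.sqrt(torch.mean(y**2, dim=1))   # per-position RMS over channels
print(rms.mean().item())                    # ~1.0, up to the eps term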
flash_head/ltx_video/models/autoencoders/vae.py ADDED
@@ -0,0 +1,380 @@
1
+ from typing import Optional, Union
2
+
3
+ import torch
4
+ import inspect
5
+ import math
6
+ import torch.nn as nn
7
+ from diffusers import ConfigMixin, ModelMixin
8
+ from diffusers.models.autoencoders.vae import (
9
+ DecoderOutput,
10
+ DiagonalGaussianDistribution,
11
+ )
12
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
13
+ from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd
14
+
15
+
16
+ class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
17
+ """Variational Autoencoder (VAE) model with KL loss.
18
+
19
+ VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
20
+ This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.
21
+
22
+ Args:
23
+ encoder (`nn.Module`):
24
+ Encoder module.
25
+ decoder (`nn.Module`):
26
+ Decoder module.
27
+ latent_channels (`int`, *optional*, defaults to 4):
28
+ Number of latent channels.
29
+ """
30
+
31
+ def __init__(
32
+ self,
33
+ encoder: nn.Module,
34
+ decoder: nn.Module,
35
+ latent_channels: int = 4,
36
+ dims: int = 2,
37
+ sample_size=512,
38
+ use_quant_conv: bool = True,
39
+ normalize_latent_channels: bool = False,
40
+ ):
41
+ super().__init__()
42
+
43
+ # pass init params to Encoder
44
+ self.encoder = encoder
45
+ self.use_quant_conv = use_quant_conv
46
+ self.normalize_latent_channels = normalize_latent_channels
47
+
48
+ # pass init params to Decoder
49
+ quant_dims = 2 if dims == 2 else 3
50
+ self.decoder = decoder
51
+ if use_quant_conv:
52
+ self.quant_conv = make_conv_nd(
53
+ quant_dims, 2 * latent_channels, 2 * latent_channels, 1
54
+ )
55
+ self.post_quant_conv = make_conv_nd(
56
+ quant_dims, latent_channels, latent_channels, 1
57
+ )
58
+ else:
59
+ self.quant_conv = nn.Identity()
60
+ self.post_quant_conv = nn.Identity()
61
+
62
+ if normalize_latent_channels:
63
+ if dims == 2:
64
+ self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
65
+ else:
66
+ self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
67
+ else:
68
+ self.latent_norm_out = nn.Identity()
69
+ self.use_z_tiling = False
70
+ self.use_hw_tiling = False
71
+ self.dims = dims
72
+ self.z_sample_size = 1
73
+
74
+ self.decoder_params = inspect.signature(self.decoder.forward).parameters
75
+
76
+ # only relevant if vae tiling is enabled
77
+ self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)
78
+
79
+ def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
80
+ self.tile_sample_min_size = sample_size
81
+ num_blocks = len(self.encoder.down_blocks)
82
+ self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
83
+ self.tile_overlap_factor = overlap_factor
84
+
85
+ def enable_z_tiling(self, z_sample_size: int = 8):
86
+ r"""
87
+ Enable tiling during VAE decoding.
88
+
89
+ When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
90
+ steps. This is useful to save some memory and allow larger batch sizes.
91
+ """
92
+ self.use_z_tiling = z_sample_size > 1
93
+ self.z_sample_size = z_sample_size
94
+ assert (
95
+ z_sample_size % 8 == 0 or z_sample_size == 1
96
+ ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."
97
+
98
+ def disable_z_tiling(self):
99
+ r"""
100
+ Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
101
+ decoding in one step.
102
+ """
103
+ self.use_z_tiling = False
104
+
105
+ def enable_hw_tiling(self):
106
+ r"""
107
+ Enable tiling during VAE decoding along the height and width dimension.
108
+ """
109
+ self.use_hw_tiling = True
110
+
111
+ def disable_hw_tiling(self):
112
+ r"""
113
+ Disable tiling during VAE decoding along the height and width dimension.
114
+ """
115
+ self.use_hw_tiling = False
116
+
117
+ def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
118
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
119
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
120
+ row_limit = self.tile_latent_min_size - blend_extent
121
+
122
+ # Split the image into 512x512 tiles and encode them separately.
123
+ rows = []
124
+ for i in range(0, x.shape[3], overlap_size):
125
+ row = []
126
+ for j in range(0, x.shape[4], overlap_size):
127
+ tile = x[
128
+ :,
129
+ :,
130
+ :,
131
+ i : i + self.tile_sample_min_size,
132
+ j : j + self.tile_sample_min_size,
133
+ ]
134
+ tile = self.encoder(tile)
135
+ tile = self.quant_conv(tile)
136
+ row.append(tile)
137
+ rows.append(row)
138
+ result_rows = []
139
+ for i, row in enumerate(rows):
140
+ result_row = []
141
+ for j, tile in enumerate(row):
142
+ # blend the above tile and the left tile
143
+ # to the current tile and add the current tile to the result row
144
+ if i > 0:
145
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
146
+ if j > 0:
147
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
148
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
149
+ result_rows.append(torch.cat(result_row, dim=4))
150
+
151
+ moments = torch.cat(result_rows, dim=3)
152
+ return moments
153
+
154
+ def blend_z(
155
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
156
+ ) -> torch.Tensor:
157
+ blend_extent = min(a.shape[2], b.shape[2], blend_extent)
158
+ for z in range(blend_extent):
159
+ b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
160
+ 1 - z / blend_extent
161
+ ) + b[:, :, z, :, :] * (z / blend_extent)
162
+ return b
163
+
164
+ def blend_v(
165
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
166
+ ) -> torch.Tensor:
167
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
168
+ for y in range(blend_extent):
169
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
170
+ 1 - y / blend_extent
171
+ ) + b[:, :, :, y, :] * (y / blend_extent)
172
+ return b
173
+
174
+ def blend_h(
175
+ self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
176
+ ) -> torch.Tensor:
177
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
178
+ for x in range(blend_extent):
179
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
180
+ 1 - x / blend_extent
181
+ ) + b[:, :, :, :, x] * (x / blend_extent)
182
+ return b
183
+
184
+ def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
185
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
186
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
187
+ row_limit = self.tile_sample_min_size - blend_extent
188
+ tile_target_shape = (
189
+ *target_shape[:3],
190
+ self.tile_sample_min_size,
191
+ self.tile_sample_min_size,
192
+ )
193
+ # Split z into overlapping 64x64 tiles and decode them separately.
194
+ # The tiles have an overlap to avoid seams between tiles.
195
+ rows = []
196
+ for i in range(0, z.shape[3], overlap_size):
197
+ row = []
198
+ for j in range(0, z.shape[4], overlap_size):
199
+ tile = z[
200
+ :,
201
+ :,
202
+ :,
203
+ i : i + self.tile_latent_min_size,
204
+ j : j + self.tile_latent_min_size,
205
+ ]
206
+ tile = self.post_quant_conv(tile)
207
+ decoded = self.decoder(tile, target_shape=tile_target_shape)
208
+ row.append(decoded)
209
+ rows.append(row)
210
+ result_rows = []
211
+ for i, row in enumerate(rows):
212
+ result_row = []
213
+ for j, tile in enumerate(row):
214
+ # blend the above tile and the left tile
215
+ # to the current tile and add the current tile to the result row
216
+ if i > 0:
217
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
218
+ if j > 0:
219
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
220
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
221
+ result_rows.append(torch.cat(result_row, dim=4))
222
+
223
+ dec = torch.cat(result_rows, dim=3)
224
+ return dec
225
+
226
+ def encode(
227
+ self, z: torch.FloatTensor, return_dict: bool = True
228
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
229
+ if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
230
+ num_splits = z.shape[2] // self.z_sample_size
231
+ sizes = [self.z_sample_size] * num_splits
232
+ sizes = (
233
+ sizes + [z.shape[2] - sum(sizes)]
234
+ if z.shape[2] - sum(sizes) > 0
235
+ else sizes
236
+ )
237
+ tiles = z.split(sizes, dim=2)
238
+ moments_tiles = [
239
+ (
240
+ self._hw_tiled_encode(z_tile, return_dict)
241
+ if self.use_hw_tiling
242
+ else self._encode(z_tile)
243
+ )
244
+ for z_tile in tiles
245
+ ]
246
+ moments = torch.cat(moments_tiles, dim=2)
247
+
248
+ else:
249
+ moments = (
250
+ self._hw_tiled_encode(z, return_dict)
251
+ if self.use_hw_tiling
252
+ else self._encode(z)
253
+ )
254
+
255
+ posterior = DiagonalGaussianDistribution(moments)
256
+ if not return_dict:
257
+ return (posterior,)
258
+
259
+ return AutoencoderKLOutput(latent_dist=posterior)
260
+
261
+ def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
262
+ if isinstance(self.latent_norm_out, nn.BatchNorm3d):
263
+ _, c, _, _, _ = z.shape
264
+ z = torch.cat(
265
+ [
266
+ self.latent_norm_out(z[:, : c // 2, :, :, :]),
267
+ z[:, c // 2 :, :, :, :],
268
+ ],
269
+ dim=1,
270
+ )
271
+ elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
272
+ raise NotImplementedError("BatchNorm2d not supported")
273
+ return z
274
+
275
+ def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
276
+ if isinstance(self.latent_norm_out, nn.BatchNorm3d):
277
+ running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
278
+ running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
279
+ eps = self.latent_norm_out.eps
280
+
281
+ z = z * torch.sqrt(running_var + eps) + running_mean
282
+ elif isinstance(self.latent_norm_out, nn.BatchNorm3d):
283
+ raise NotImplementedError("BatchNorm2d not supported")
284
+ return z
285
+
286
+ def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
287
+ h = self.encoder(x)
288
+ moments = self.quant_conv(h)
289
+ moments = self._normalize_latent_channels(moments)
290
+ return moments
291
+
292
+ def _decode(
293
+ self,
294
+ z: torch.FloatTensor,
295
+ target_shape=None,
296
+ timestep: Optional[torch.Tensor] = None,
297
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
298
+ z = self._unnormalize_latent_channels(z)
299
+ z = self.post_quant_conv(z)
300
+ if "timestep" in self.decoder_params:
301
+ dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
302
+ else:
303
+ dec = self.decoder(z, target_shape=target_shape)
304
+ return dec
305
+
306
+ def decode(
307
+ self,
308
+ z: torch.FloatTensor,
309
+ return_dict: bool = True,
310
+ target_shape=None,
311
+ timestep: Optional[torch.Tensor] = None,
312
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
313
+ assert target_shape is not None, "target_shape must be provided for decoding"
314
+ if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
315
+ reduction_factor = int(
316
+ self.encoder.patch_size_t
317
+ * 2
318
+ ** (
319
+ len(self.encoder.down_blocks)
320
+ - 1
321
+ - math.sqrt(self.encoder.patch_size)
322
+ )
323
+ )
324
+ split_size = self.z_sample_size // reduction_factor
325
+ num_splits = z.shape[2] // split_size
326
+
327
+ # copy target shape, and divide frame dimension (=2) by the context size
328
+ target_shape_split = list(target_shape)
329
+ target_shape_split[2] = target_shape[2] // num_splits
330
+
331
+ decoded_tiles = [
332
+ (
333
+ self._hw_tiled_decode(z_tile, target_shape_split)
334
+ if self.use_hw_tiling
335
+ else self._decode(z_tile, target_shape=target_shape_split)
336
+ )
337
+ for z_tile in torch.tensor_split(z, num_splits, dim=2)
338
+ ]
339
+ decoded = torch.cat(decoded_tiles, dim=2)
340
+ else:
341
+ decoded = (
342
+ self._hw_tiled_decode(z, target_shape)
343
+ if self.use_hw_tiling
344
+ else self._decode(z, target_shape=target_shape, timestep=timestep)
345
+ )
346
+
347
+ if not return_dict:
348
+ return (decoded,)
349
+
350
+ return DecoderOutput(sample=decoded)
351
+
352
+ def forward(
353
+ self,
354
+ sample: torch.FloatTensor,
355
+ sample_posterior: bool = False,
356
+ return_dict: bool = True,
357
+ generator: Optional[torch.Generator] = None,
358
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
359
+ r"""
360
+ Args:
361
+ sample (`torch.FloatTensor`): Input sample.
362
+ sample_posterior (`bool`, *optional*, defaults to `False`):
363
+ Whether to sample from the posterior.
364
+ return_dict (`bool`, *optional*, defaults to `True`):
365
+ Whether to return a [`DecoderOutput`] instead of a plain tuple.
366
+ generator (`torch.Generator`, *optional*):
367
+ Generator used to sample from the posterior.
368
+ """
369
+ x = sample
370
+ posterior = self.encode(x).latent_dist
371
+ if sample_posterior:
372
+ z = posterior.sample(generator=generator)
373
+ else:
374
+ z = posterior.mode()
375
+ dec = self.decode(z, target_shape=sample.shape).sample
376
+
377
+ if not return_dict:
378
+ return (dec,)
379
+
380
+ return DecoderOutput(sample=dec)
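A one-dimensional sketch (all sizes invented) of the linear cross-fade used by blend_h/blend_v above: over the overlap region the previous tile fades out while the current tile fades in, which hides the seam between decoded tiles.

import torch

blend_extent = 4
a = torch.ones(10)     # right edge of the previous tile
b = torch.zeros(10)    # left edge of the current tile
for x in range(blend_extent):
    b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
print(b[:blend_extent])   # 1.00, 0.75, 0.50, 0.25 -- a smooth hand-off across the overlap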
flash_head/ltx_video/models/autoencoders/vae_encode.py ADDED
@@ -0,0 +1,256 @@
1
+ from typing import Tuple
2
+ import torch
3
+ from diffusers import AutoencoderKL
4
+ from einops import rearrange
5
+ from torch import Tensor
6
+
7
+
8
+ from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import (
9
+ CausalVideoAutoencoder,
10
+ )
11
+ from flash_head.ltx_video.models.autoencoders.video_autoencoder import (
12
+ Downsample3D,
13
+ VideoAutoencoder,
14
+ )
15
+
16
+ try:
17
+ import torch_xla.core.xla_model as xm
18
+ except ImportError:
19
+ xm = None
20
+
21
+
22
+ def vae_encode(
23
+ media_items: Tensor,
24
+ vae: AutoencoderKL,
25
+ split_size: int = 1,
26
+ vae_per_channel_normalize=False,
27
+ ) -> Tensor:
28
+ """
29
+ Encodes media items (images or videos) into latent representations using a specified VAE model.
30
+ The function supports processing batches of images or video frames and can handle the processing
31
+ in smaller sub-batches if needed.
32
+
33
+ Args:
34
+ media_items (Tensor): A torch Tensor containing the media items to encode. The expected
35
+ shape is (batch_size, channels, height, width) for images or (batch_size, channels,
36
+ frames, height, width) for videos.
37
+ vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
38
+ pre-configured and loaded with the appropriate model weights.
39
+ split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
40
+ If set to more than 1, the input media items are processed in smaller batches according to
41
+ this value. Defaults to 1, which processes all items in a single batch.
42
+
43
+ Returns:
44
+ Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
45
+ to match the input shape, scaled by the model's configuration.
46
+
47
+ Examples:
48
+ >>> import torch
49
+ >>> from diffusers import AutoencoderKL
50
+ >>> vae = AutoencoderKL.from_pretrained('your-model-name')
51
+ >>> images = torch.rand(10, 3, 8, 256, 256) # Example tensor with 10 videos of 8 frames.
52
+ >>> latents = vae_encode(images, vae)
53
+ >>> print(latents.shape) # Output shape will depend on the model's latent configuration.
54
+
55
+ Note:
56
+ In the case of a video, the function encodes the media item frame by frame.
57
+ """
58
+ is_video_shaped = media_items.dim() == 5
59
+ batch_size, channels = media_items.shape[0:2]
60
+
61
+ if channels != 3:
62
+ raise ValueError(f"Expects tensors with 3 channels, got {channels}.")
63
+
64
+ # if is_video_shaped and not isinstance(
65
+ # vae, (VideoAutoencoder, CausalVideoAutoencoder)
66
+ # ): # once the VAE is wrapped by FSDP the isinstance check can no longer identify it, so the condition below is dropped
67
+ if is_video_shaped and False: # forced to False for compatibility with FSDP-wrapped models
68
+ media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
69
+ if split_size > 1:
70
+ if len(media_items) % split_size != 0:
71
+ raise ValueError(
72
+ "Error: The batch size must be divisible by 'train.vae_bs_split'"
73
+ )
74
+ encode_bs = len(media_items) // split_size
75
+ # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
76
+ latents = []
77
+ if media_items.device.type == "xla":
78
+ xm.mark_step()
79
+ for image_batch in media_items.split(encode_bs):
80
+ latents.append(vae.encode(image_batch).latent_dist.sample())
81
+ if media_items.device.type == "xla":
82
+ xm.mark_step()
83
+ latents = torch.cat(latents, dim=0)
84
+ else:
85
+ latents = vae.encode(media_items).latent_dist.sample()
86
+
87
+ latents = normalize_latents(latents, vae, vae_per_channel_normalize)
88
+ # if is_video_shaped and not isinstance(
89
+ # vae, (VideoAutoencoder, CausalVideoAutoencoder)
90
+ # ):
91
+ if is_video_shaped and False: # forced to False for compatibility with FSDP-wrapped models
92
+ latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
93
+ return latents
94
+
95
+
96
+ def vae_decode(
97
+ latents: Tensor,
98
+ vae: AutoencoderKL,
99
+ is_video: bool = True,
100
+ split_size: int = 1,
101
+ vae_per_channel_normalize=False,
102
+ timestep=None,
103
+ ) -> Tensor:
104
+ is_video_shaped = latents.dim() == 5
105
+ batch_size = latents.shape[0]
106
+
107
+ # if is_video_shaped and not isinstance(
108
+ # vae, (VideoAutoencoder, CausalVideoAutoencoder)
109
+ # ):
110
+ if is_video_shaped and False: # forced to False for compatibility with FSDP-wrapped models
111
+ latents = rearrange(latents, "b c n h w -> (b n) c h w")
112
+ if split_size > 1:
113
+ if len(latents) % split_size != 0:
114
+ raise ValueError(
115
+ "Error: The batch size must be divisible by 'train.vae_bs_split'"
116
+ )
117
+ encode_bs = len(latents) // split_size
118
+ image_batch = [
119
+ _run_decoder(
120
+ latent_batch, vae, is_video, vae_per_channel_normalize, timestep
121
+ )
122
+ for latent_batch in latents.split(encode_bs)
123
+ ]
124
+ images = torch.cat(image_batch, dim=0)
125
+ else:
126
+ images = _run_decoder(
127
+ latents, vae, is_video, vae_per_channel_normalize, timestep
128
+ )
129
+
130
+ # if is_video_shaped and not isinstance(
131
+ # vae, (VideoAutoencoder, CausalVideoAutoencoder)
132
+ # ):
133
+ if is_video_shaped and False: # forced to False for compatibility with FSDP-wrapped models
134
+ images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
135
+ return images
136
+
137
+
138
+ def _run_decoder(
139
+ latents: Tensor,
140
+ vae: AutoencoderKL,
141
+ is_video: bool,
142
+ vae_per_channel_normalize=False,
143
+ timestep=None,
144
+ ) -> Tensor:
145
+ # if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
146
+ if False: # hard-coded so the isinstance branch is skipped, for compatibility with FSDP-wrapped models
147
+ *_, fl, hl, wl = latents.shape
148
+ temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
149
+ latents = latents.to(vae.dtype)
150
+ vae_decode_kwargs = {}
151
+ if timestep is not None:
152
+ vae_decode_kwargs["timestep"] = timestep
153
+ image = vae.decode(
154
+ un_normalize_latents(latents, vae, vae_per_channel_normalize),
155
+ return_dict=False,
156
+ target_shape=(
157
+ 1,
158
+ 3,
159
+ fl * temporal_scale if is_video else 1,
160
+ hl * spatial_scale,
161
+ wl * spatial_scale,
162
+ ),
163
+ **vae_decode_kwargs,
164
+ )[0]
165
+ else:
166
+ image = vae.decode(
167
+ un_normalize_latents(latents, vae, vae_per_channel_normalize),
168
+ return_dict=False,
169
+ target_shape=latents.shape
170
+ )[0]
171
+
172
+ return image
173
+
174
+
175
+ def get_vae_size_scale_factor(vae: AutoencoderKL) -> Tuple[int, int, int]:
176
+ # if isinstance(vae, CausalVideoAutoencoder):
177
+ if True: # forced to True for compatibility with FSDP-wrapped models (replaces the isinstance check)
178
+ spatial = vae.spatial_downscale_factor
179
+ temporal = vae.temporal_downscale_factor
180
+ else:
181
+ down_blocks = len(
182
+ [
183
+ block
184
+ for block in vae.encoder.down_blocks
185
+ if isinstance(block.downsample, Downsample3D)
186
+ ]
187
+ )
188
+ spatial = vae.config.patch_size * 2**down_blocks
189
+ temporal = (
190
+ vae.config.patch_size_t * 2**down_blocks
191
+ if isinstance(vae, VideoAutoencoder)
192
+ else 1
193
+ )
194
+
195
+ return (temporal, spatial, spatial)
196
+
197
+
198
+ def latent_to_pixel_coords(
199
+ latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
200
+ ) -> Tensor:
201
+ """
202
+ Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
203
+ configuration.
204
+
205
+ Args:
206
+ latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
207
+ containing the latent corner coordinates of each token.
208
+ vae (AutoencoderKL): The VAE model
209
+ causal_fix (bool): Whether to take into account the different temporal scale
210
+ of the first frame. Default = False for backwards compatibility.
211
+ Returns:
212
+ Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
213
+ """
214
+
215
+ scale_factors = get_vae_size_scale_factor(vae)
216
+ # causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
217
+ causal_fix = True and causal_fix # True hard-coded for compatibility with FSDP-wrapped models
218
+ pixel_coords = latent_to_pixel_coords_from_factors(
219
+ latent_coords, scale_factors, causal_fix
220
+ )
221
+ return pixel_coords
222
+
223
+
224
+ def latent_to_pixel_coords_from_factors(
225
+ latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
226
+ ) -> Tensor:
227
+ pixel_coords = (
228
+ latent_coords
229
+ * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
230
+ )
231
+ if causal_fix:
232
+ # Fix temporal scale for first frame to 1 due to causality
233
+ pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
234
+ return pixel_coords
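A small worked example (scale factors invented) of the causal fix above: with a temporal scale of 8, latent frame indices 0, 1, 2 map to pixel frames 0, 1, 9, i.e. the first latent frame covers a single pixel frame while later ones each cover eight.

import torch
from flash_head.ltx_video.models.autoencoders.vae_encode import latent_to_pixel_coords_from_factors

scale_factors = (8, 32, 32)                                          # (temporal, height, width)
latent_coords = torch.tensor([[[0, 1, 2], [0, 0, 0], [0, 0, 0]]])    # (batch, 3, num_latents)
pix = latent_to_pixel_coords_from_factors(latent_coords, scale_factors, causal_fix=True)
print(pix[0, 0])   # tensor([0, 1, 9])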
235
+
236
+
237
+ def normalize_latents(
238
+ latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
239
+ ) -> Tensor:
240
+ return (
241
+ (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
242
+ / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
243
+ if vae_per_channel_normalize
244
+ else latents * vae.config.scaling_factor
245
+ )
246
+
247
+
248
+ def un_normalize_latents(
249
+ latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
250
+ ) -> Tensor:
251
+ return (
252
+ latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
253
+ + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
254
+ if vae_per_channel_normalize
255
+ else latents / vae.config.scaling_factor
256
+ )
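A hedged round-trip sketch (dummy statistics, not the shipped ones) showing that normalize_latents and un_normalize_latents above are inverses when vae_per_channel_normalize=True: latents are standardized with the VAE's per-channel statistics and then restored.

import torch
from types import SimpleNamespace
from flash_head.ltx_video.models.autoencoders.vae_encode import normalize_latents, un_normalize_latents

vae = SimpleNamespace(
    mean_of_means=torch.randn(8),      # illustrative per-channel statistics
    std_of_means=torch.rand(8) + 0.5,
)
latents = torch.randn(1, 8, 2, 4, 4)
normed = normalize_latents(latents, vae, vae_per_channel_normalize=True)
restored = un_normalize_latents(normed, vae, vae_per_channel_normalize=True)
print(torch.allclose(latents, restored, atol=1e-6))   # True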
flash_head/ltx_video/models/autoencoders/video_autoencoder.py ADDED
@@ -0,0 +1,1045 @@
1
+ import json
2
+ import os
3
+ from functools import partial
4
+ from types import SimpleNamespace
5
+ from typing import Any, Mapping, Optional, Tuple, Union
6
+
7
+ import torch
8
+ from einops import rearrange
9
+ from torch import nn
10
+ from torch.nn import functional
11
+
12
+ from diffusers.utils import logging
13
+
14
+ from flash_head.ltx_video.utils.torch_utils import Identity
15
+ from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
16
+ from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm
17
+ from flash_head.ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
18
+
19
+ logger = logging.get_logger(__name__)
20
+
21
+
22
+ class VideoAutoencoder(AutoencoderKLWrapper):
23
+ @classmethod
24
+ def from_pretrained(
25
+ cls,
26
+ pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
27
+ *args,
28
+ **kwargs,
29
+ ):
30
+ config_local_path = pretrained_model_name_or_path / "config.json"
31
+ config = cls.load_config(config_local_path, **kwargs)
32
+ video_vae = cls.from_config(config)
33
+ video_vae.to(kwargs["torch_dtype"])
34
+
35
+ model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
36
+ ckpt_state_dict = torch.load(model_local_path)
37
+ video_vae.load_state_dict(ckpt_state_dict)
38
+
39
+ statistics_local_path = (
40
+ pretrained_model_name_or_path / "per_channel_statistics.json"
41
+ )
42
+ if statistics_local_path.exists():
43
+ with open(statistics_local_path, "r") as file:
44
+ data = json.load(file)
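+ # The statistics file is assumed to be a pandas-style "split" dump with
+ # "columns" and "data" keys, e.g. {"columns": ["std-of-means", ...], "data": [[...], ...]}.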
45
+ transposed_data = list(zip(*data["data"]))
46
+ data_dict = {
47
+ col: torch.tensor(vals)
48
+ for col, vals in zip(data["columns"], transposed_data)
49
+ }
50
+ video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
51
+ video_vae.register_buffer(
52
+ "mean_of_means",
53
+ data_dict.get(
54
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
55
+ ),
56
+ )
57
+
58
+ return video_vae
59
+
60
+ @staticmethod
61
+ def from_config(config):
62
+ assert (
63
+ config["_class_name"] == "VideoAutoencoder"
64
+ ), "config must have _class_name=VideoAutoencoder"
65
+ if isinstance(config["dims"], list):
66
+ config["dims"] = tuple(config["dims"])
67
+
68
+ assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
69
+
70
+ double_z = config.get("double_z", True)
71
+ latent_log_var = config.get(
72
+ "latent_log_var", "per_channel" if double_z else "none"
73
+ )
74
+ use_quant_conv = config.get("use_quant_conv", True)
75
+
76
+ if use_quant_conv and latent_log_var == "uniform":
77
+ raise ValueError("uniform latent_log_var requires use_quant_conv=False")
78
+
79
+ encoder = Encoder(
80
+ dims=config["dims"],
81
+ in_channels=config.get("in_channels", 3),
82
+ out_channels=config["latent_channels"],
83
+ block_out_channels=config["block_out_channels"],
84
+ patch_size=config.get("patch_size", 1),
85
+ latent_log_var=latent_log_var,
86
+ norm_layer=config.get("norm_layer", "group_norm"),
87
+ patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
88
+ add_channel_padding=config.get("add_channel_padding", False),
89
+ )
90
+
91
+ decoder = Decoder(
92
+ dims=config["dims"],
93
+ in_channels=config["latent_channels"],
94
+ out_channels=config.get("out_channels", 3),
95
+ block_out_channels=config["block_out_channels"],
96
+ patch_size=config.get("patch_size", 1),
97
+ norm_layer=config.get("norm_layer", "group_norm"),
98
+ patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
99
+ add_channel_padding=config.get("add_channel_padding", False),
100
+ )
101
+
102
+ dims = config["dims"]
103
+ return VideoAutoencoder(
104
+ encoder=encoder,
105
+ decoder=decoder,
106
+ latent_channels=config["latent_channels"],
107
+ dims=dims,
108
+ use_quant_conv=use_quant_conv,
109
+ )
110
+
111
+ @property
112
+ def config(self):
113
+ return SimpleNamespace(
114
+ _class_name="VideoAutoencoder",
115
+ dims=self.dims,
116
+ in_channels=self.encoder.conv_in.in_channels
117
+ // (self.encoder.patch_size_t * self.encoder.patch_size**2),
118
+ out_channels=self.decoder.conv_out.out_channels
119
+ // (self.decoder.patch_size_t * self.decoder.patch_size**2),
120
+ latent_channels=self.decoder.conv_in.in_channels,
121
+ block_out_channels=[
122
+ self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
123
+ for i in range(len(self.encoder.down_blocks))
124
+ ],
125
+ scaling_factor=1.0,
126
+ norm_layer=self.encoder.norm_layer,
127
+ patch_size=self.encoder.patch_size,
128
+ latent_log_var=self.encoder.latent_log_var,
129
+ use_quant_conv=self.use_quant_conv,
130
+ patch_size_t=self.encoder.patch_size_t,
131
+ add_channel_padding=self.encoder.add_channel_padding,
132
+ )
133
+
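+ # Note: scaling_factor is fixed at 1.0 here, so the plain scaling path in
+ # normalize_latents / un_normalize_latents is effectively an identity for this model;
+ # only the per-channel statistics path rescales the latents.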
134
+ @property
135
+ def is_video_supported(self):
136
+ """
137
+ Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
138
+ """
139
+ return self.dims != 2
140
+
141
+ @property
142
+ def downscale_factor(self):
143
+ return self.encoder.downsample_factor
144
+
145
+ def to_json_string(self) -> str:
146
+ import json
147
+
148
+ return json.dumps(self.config.__dict__)
149
+
150
+ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
151
+ model_keys = set(name for name, _ in self.named_parameters())
152
+
153
+ key_mapping = {
154
+ ".resnets.": ".res_blocks.",
155
+ "downsamplers.0": "downsample",
156
+ "upsamplers.0": "upsample",
157
+ }
158
+
159
+ converted_state_dict = {}
160
+ for key, value in state_dict.items():
161
+ for k, v in key_mapping.items():
162
+ key = key.replace(k, v)
163
+
164
+ if "norm" in key and key not in model_keys:
165
+ logger.info(
166
+ f"Removing key {key} from state_dict as it is not present in the model"
167
+ )
168
+ continue
169
+
170
+ converted_state_dict[key] = value
171
+
172
+ super().load_state_dict(converted_state_dict, strict=strict)
173
+
174
+ def last_layer(self):
175
+ if hasattr(self.decoder, "conv_out"):
176
+ if isinstance(self.decoder.conv_out, nn.Sequential):
177
+ last_layer = self.decoder.conv_out[-1]
178
+ else:
179
+ last_layer = self.decoder.conv_out
180
+ else:
181
+ last_layer = self.decoder.layers[-1]
182
+ return last_layer
183
+
184
+
185
+ class Encoder(nn.Module):
186
+ r"""
187
+ The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
188
+
189
+ Args:
190
+ in_channels (`int`, *optional*, defaults to 3):
191
+ The number of input channels.
192
+ out_channels (`int`, *optional*, defaults to 3):
193
+ The number of output channels.
194
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
195
+ The number of output channels for each block.
196
+ layers_per_block (`int`, *optional*, defaults to 2):
197
+ The number of layers per block.
198
+ norm_num_groups (`int`, *optional*, defaults to 32):
199
+ The number of groups for normalization.
200
+ patch_size (`int`, *optional*, defaults to 1):
201
+ The patch size to use. Should be a power of 2.
202
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
203
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
204
+ latent_log_var (`str`, *optional*, defaults to `per_channel`):
205
+ How the latent log-variance is parameterized. Can be either `per_channel`, `uniform`, or `none`.
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ dims: Union[int, Tuple[int, int]] = 3,
211
+ in_channels: int = 3,
212
+ out_channels: int = 3,
213
+ block_out_channels: Tuple[int, ...] = (64,),
214
+ layers_per_block: int = 2,
215
+ norm_num_groups: int = 32,
216
+ patch_size: Union[int, Tuple[int]] = 1,
217
+ norm_layer: str = "group_norm", # group_norm, pixel_norm
218
+ latent_log_var: str = "per_channel",
219
+ patch_size_t: Optional[int] = None,
220
+ add_channel_padding: Optional[bool] = False,
221
+ ):
222
+ super().__init__()
223
+ self.patch_size = patch_size
224
+ self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
225
+ self.add_channel_padding = add_channel_padding
226
+ self.layers_per_block = layers_per_block
227
+ self.norm_layer = norm_layer
228
+ self.latent_channels = out_channels
229
+ self.latent_log_var = latent_log_var
230
+ if add_channel_padding:
231
+ in_channels = in_channels * self.patch_size**3
232
+ else:
233
+ in_channels = in_channels * self.patch_size_t * self.patch_size**2
234
+ self.in_channels = in_channels
235
+ output_channel = block_out_channels[0]
236
+
237
+ self.conv_in = make_conv_nd(
238
+ dims=dims,
239
+ in_channels=in_channels,
240
+ out_channels=output_channel,
241
+ kernel_size=3,
242
+ stride=1,
243
+ padding=1,
244
+ )
245
+
246
+ self.down_blocks = nn.ModuleList([])
247
+
248
+ for i in range(len(block_out_channels)):
249
+ input_channel = output_channel
250
+ output_channel = block_out_channels[i]
251
+ is_final_block = i == len(block_out_channels) - 1
252
+
253
+ down_block = DownEncoderBlock3D(
254
+ dims=dims,
255
+ in_channels=input_channel,
256
+ out_channels=output_channel,
257
+ num_layers=self.layers_per_block,
258
+ add_downsample=not is_final_block and 2**i >= patch_size,
259
+ resnet_eps=1e-6,
260
+ downsample_padding=0,
261
+ resnet_groups=norm_num_groups,
262
+ norm_layer=norm_layer,
263
+ )
264
+ self.down_blocks.append(down_block)
265
+
266
+ self.mid_block = UNetMidBlock3D(
267
+ dims=dims,
268
+ in_channels=block_out_channels[-1],
269
+ num_layers=self.layers_per_block,
270
+ resnet_eps=1e-6,
271
+ resnet_groups=norm_num_groups,
272
+ norm_layer=norm_layer,
273
+ )
274
+
275
+ # out
276
+ if norm_layer == "group_norm":
277
+ self.conv_norm_out = nn.GroupNorm(
278
+ num_channels=block_out_channels[-1],
279
+ num_groups=norm_num_groups,
280
+ eps=1e-6,
281
+ )
282
+ elif norm_layer == "pixel_norm":
283
+ self.conv_norm_out = PixelNorm()
284
+ self.conv_act = nn.SiLU()
285
+
286
+ conv_out_channels = out_channels
287
+ if latent_log_var == "per_channel":
288
+ conv_out_channels *= 2
289
+ elif latent_log_var == "uniform":
290
+ conv_out_channels += 1
291
+ elif latent_log_var != "none":
292
+ raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
293
+ self.conv_out = make_conv_nd(
294
+ dims, block_out_channels[-1], conv_out_channels, 3, padding=1
295
+ )
296
+
297
+ self.gradient_checkpointing = False
298
+
299
+ @property
300
+ def downscale_factor(self):
301
+ return (
302
+ 2
303
+ ** len(
304
+ [
305
+ block
306
+ for block in self.down_blocks
307
+ if isinstance(block.downsample, Downsample3D)
308
+ ]
309
+ )
310
+ * self.patch_size
311
+ )
312
+
313
+ def forward(
314
+ self, sample: torch.FloatTensor, return_features=False
315
+ ) -> torch.FloatTensor:
316
+ r"""The forward method of the `Encoder` class."""
317
+
318
+ downsample_in_time = sample.shape[2] != 1
319
+
320
+ # patchify
321
+ patch_size_t = self.patch_size_t if downsample_in_time else 1
322
+ sample = patchify(
323
+ sample,
324
+ patch_size_hw=self.patch_size,
325
+ patch_size_t=patch_size_t,
326
+ add_channel_padding=self.add_channel_padding,
327
+ )
328
+
329
+ sample = self.conv_in(sample)
330
+
331
+ checkpoint_fn = (
332
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
333
+ if self.gradient_checkpointing and self.training
334
+ else lambda x: x
335
+ )
336
+
337
+ if return_features:
338
+ features = []
339
+ for down_block in self.down_blocks:
340
+ sample = checkpoint_fn(down_block)(
341
+ sample, downsample_in_time=downsample_in_time
342
+ )
343
+ if return_features:
344
+ features.append(sample)
345
+
346
+ sample = checkpoint_fn(self.mid_block)(sample)
347
+
348
+ # post-process
349
+ sample = self.conv_norm_out(sample)
350
+ sample = self.conv_act(sample)
351
+ sample = self.conv_out(sample)
352
+
353
+ if self.latent_log_var == "uniform":
354
+ last_channel = sample[:, -1:, ...]
355
+ num_dims = sample.dim()
356
+
357
+ if num_dims == 4:
358
+ # For shape (B, C, H, W)
359
+ repeated_last_channel = last_channel.repeat(
360
+ 1, sample.shape[1] - 2, 1, 1
361
+ )
362
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
363
+ elif num_dims == 5:
364
+ # For shape (B, C, F, H, W)
365
+ repeated_last_channel = last_channel.repeat(
366
+ 1, sample.shape[1] - 2, 1, 1, 1
367
+ )
368
+ sample = torch.cat([sample, repeated_last_channel], dim=1)
369
+ else:
370
+ raise ValueError(f"Invalid input shape: {sample.shape}")
371
+
372
+ if return_features:
373
+ features.append(sample[:, : self.latent_channels, ...])
374
+ return sample, features
375
+ return sample
376
+
377
+
378
+ class Decoder(nn.Module):
379
+ r"""
380
+ The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
381
+
382
+ Args:
383
+ in_channels (`int`, *optional*, defaults to 3):
384
+ The number of input channels.
385
+ out_channels (`int`, *optional*, defaults to 3):
386
+ The number of output channels.
387
+ block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
388
+ The number of output channels for each block.
389
+ layers_per_block (`int`, *optional*, defaults to 2):
390
+ The number of layers per block.
391
+ norm_num_groups (`int`, *optional*, defaults to 32):
392
+ The number of groups for normalization.
393
+ patch_size (`int`, *optional*, defaults to 1):
394
+ The patch size to use. Should be a power of 2.
395
+ norm_layer (`str`, *optional*, defaults to `group_norm`):
396
+ The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
397
+ """
398
+
399
+ def __init__(
400
+ self,
401
+ dims,
402
+ in_channels: int = 3,
403
+ out_channels: int = 3,
404
+ block_out_channels: Tuple[int, ...] = (64,),
405
+ layers_per_block: int = 2,
406
+ norm_num_groups: int = 32,
407
+ patch_size: int = 1,
408
+ norm_layer: str = "group_norm",
409
+ patch_size_t: Optional[int] = None,
410
+ add_channel_padding: Optional[bool] = False,
411
+ ):
412
+ super().__init__()
413
+ self.patch_size = patch_size
414
+ self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
415
+ self.add_channel_padding = add_channel_padding
416
+ self.layers_per_block = layers_per_block
417
+ if add_channel_padding:
418
+ out_channels = out_channels * self.patch_size**3
419
+ else:
420
+ out_channels = out_channels * self.patch_size_t * self.patch_size**2
421
+ self.out_channels = out_channels
422
+
423
+ self.conv_in = make_conv_nd(
424
+ dims,
425
+ in_channels,
426
+ block_out_channels[-1],
427
+ kernel_size=3,
428
+ stride=1,
429
+ padding=1,
430
+ )
431
+
432
+ self.mid_block = None
433
+ self.up_blocks = nn.ModuleList([])
434
+
435
+ self.mid_block = UNetMidBlock3D(
436
+ dims=dims,
437
+ in_channels=block_out_channels[-1],
438
+ num_layers=self.layers_per_block,
439
+ resnet_eps=1e-6,
440
+ resnet_groups=norm_num_groups,
441
+ norm_layer=norm_layer,
442
+ )
443
+
444
+ reversed_block_out_channels = list(reversed(block_out_channels))
445
+ output_channel = reversed_block_out_channels[0]
446
+ for i in range(len(reversed_block_out_channels)):
447
+ prev_output_channel = output_channel
448
+ output_channel = reversed_block_out_channels[i]
449
+
450
+ is_final_block = i == len(block_out_channels) - 1
451
+
452
+ up_block = UpDecoderBlock3D(
453
+ dims=dims,
454
+ num_layers=self.layers_per_block + 1,
455
+ in_channels=prev_output_channel,
456
+ out_channels=output_channel,
457
+ add_upsample=not is_final_block
458
+ and 2 ** (len(block_out_channels) - i - 1) > patch_size,
459
+ resnet_eps=1e-6,
460
+ resnet_groups=norm_num_groups,
461
+ norm_layer=norm_layer,
462
+ )
463
+ self.up_blocks.append(up_block)
464
+
465
+ if norm_layer == "group_norm":
466
+ self.conv_norm_out = nn.GroupNorm(
467
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
468
+ )
469
+ elif norm_layer == "pixel_norm":
470
+ self.conv_norm_out = PixelNorm()
471
+
472
+ self.conv_act = nn.SiLU()
473
+ self.conv_out = make_conv_nd(
474
+ dims, block_out_channels[0], out_channels, 3, padding=1
475
+ )
476
+
477
+ self.gradient_checkpointing = False
478
+
479
+ def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
480
+ r"""The forward method of the `Decoder` class."""
481
+ assert target_shape is not None, "target_shape must be provided"
482
+ upsample_in_time = sample.shape[2] < target_shape[2]
483
+
484
+ sample = self.conv_in(sample)
485
+
486
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
487
+
488
+ checkpoint_fn = (
489
+ partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
490
+ if self.gradient_checkpointing and self.training
491
+ else lambda x: x
492
+ )
493
+
494
+ sample = checkpoint_fn(self.mid_block)(sample)
495
+ sample = sample.to(upscale_dtype)
496
+
497
+ for up_block in self.up_blocks:
498
+ sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
499
+
500
+ # post-process
501
+ sample = self.conv_norm_out(sample)
502
+ sample = self.conv_act(sample)
503
+ sample = self.conv_out(sample)
504
+
505
+ # un-patchify
506
+ patch_size_t = self.patch_size_t if upsample_in_time else 1
507
+ sample = unpatchify(
508
+ sample,
509
+ patch_size_hw=self.patch_size,
510
+ patch_size_t=patch_size_t,
511
+ add_channel_padding=self.add_channel_padding,
512
+ )
513
+
514
+ return sample
515
+
516
+
517
+ class DownEncoderBlock3D(nn.Module):
518
+ def __init__(
519
+ self,
520
+ dims: Union[int, Tuple[int, int]],
521
+ in_channels: int,
522
+ out_channels: int,
523
+ dropout: float = 0.0,
524
+ num_layers: int = 1,
525
+ resnet_eps: float = 1e-6,
526
+ resnet_groups: int = 32,
527
+ add_downsample: bool = True,
528
+ downsample_padding: int = 1,
529
+ norm_layer: str = "group_norm",
530
+ ):
531
+ super().__init__()
532
+ res_blocks = []
533
+
534
+ for i in range(num_layers):
535
+ in_channels = in_channels if i == 0 else out_channels
536
+ res_blocks.append(
537
+ ResnetBlock3D(
538
+ dims=dims,
539
+ in_channels=in_channels,
540
+ out_channels=out_channels,
541
+ eps=resnet_eps,
542
+ groups=resnet_groups,
543
+ dropout=dropout,
544
+ norm_layer=norm_layer,
545
+ )
546
+ )
547
+
548
+ self.res_blocks = nn.ModuleList(res_blocks)
549
+
550
+ if add_downsample:
551
+ self.downsample = Downsample3D(
552
+ dims,
553
+ out_channels,
554
+ out_channels=out_channels,
555
+ padding=downsample_padding,
556
+ )
557
+ else:
558
+ self.downsample = Identity()
559
+
560
+ def forward(
561
+ self, hidden_states: torch.FloatTensor, downsample_in_time
562
+ ) -> torch.FloatTensor:
563
+ for resnet in self.res_blocks:
564
+ hidden_states = resnet(hidden_states)
565
+
566
+ hidden_states = self.downsample(
567
+ hidden_states, downsample_in_time=downsample_in_time
568
+ )
569
+
570
+ return hidden_states
571
+
572
+
573
+ class UNetMidBlock3D(nn.Module):
574
+ """
575
+ A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
576
+
577
+ Args:
578
+ in_channels (`int`): The number of input channels.
579
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
580
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
581
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
582
+ resnet_groups (`int`, *optional*, defaults to 32):
583
+ The number of groups to use in the group normalization layers of the resnet blocks.
584
+
585
+ Returns:
586
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
587
+ in_channels, height, width)`.
588
+
589
+ """
590
+
591
+ def __init__(
592
+ self,
593
+ dims: Union[int, Tuple[int, int]],
594
+ in_channels: int,
595
+ dropout: float = 0.0,
596
+ num_layers: int = 1,
597
+ resnet_eps: float = 1e-6,
598
+ resnet_groups: int = 32,
599
+ norm_layer: str = "group_norm",
600
+ ):
601
+ super().__init__()
602
+ resnet_groups = (
603
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
604
+ )
605
+
606
+ self.res_blocks = nn.ModuleList(
607
+ [
608
+ ResnetBlock3D(
609
+ dims=dims,
610
+ in_channels=in_channels,
611
+ out_channels=in_channels,
612
+ eps=resnet_eps,
613
+ groups=resnet_groups,
614
+ dropout=dropout,
615
+ norm_layer=norm_layer,
616
+ )
617
+ for _ in range(num_layers)
618
+ ]
619
+ )
620
+
621
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
622
+ for resnet in self.res_blocks:
623
+ hidden_states = resnet(hidden_states)
624
+
625
+ return hidden_states
626
+
627
+
628
+ class UpDecoderBlock3D(nn.Module):
629
+ def __init__(
630
+ self,
631
+ dims: Union[int, Tuple[int, int]],
632
+ in_channels: int,
633
+ out_channels: int,
634
+ resolution_idx: Optional[int] = None,
635
+ dropout: float = 0.0,
636
+ num_layers: int = 1,
637
+ resnet_eps: float = 1e-6,
638
+ resnet_groups: int = 32,
639
+ add_upsample: bool = True,
640
+ norm_layer: str = "group_norm",
641
+ ):
642
+ super().__init__()
643
+ res_blocks = []
644
+
645
+ for i in range(num_layers):
646
+ input_channels = in_channels if i == 0 else out_channels
647
+
648
+ res_blocks.append(
649
+ ResnetBlock3D(
650
+ dims=dims,
651
+ in_channels=input_channels,
652
+ out_channels=out_channels,
653
+ eps=resnet_eps,
654
+ groups=resnet_groups,
655
+ dropout=dropout,
656
+ norm_layer=norm_layer,
657
+ )
658
+ )
659
+
660
+ self.res_blocks = nn.ModuleList(res_blocks)
661
+
662
+ if add_upsample:
663
+ self.upsample = Upsample3D(
664
+ dims=dims, channels=out_channels, out_channels=out_channels
665
+ )
666
+ else:
667
+ self.upsample = Identity()
668
+
669
+ self.resolution_idx = resolution_idx
670
+
671
+ def forward(
672
+ self, hidden_states: torch.FloatTensor, upsample_in_time=True
673
+ ) -> torch.FloatTensor:
674
+ for resnet in self.res_blocks:
675
+ hidden_states = resnet(hidden_states)
676
+
677
+ hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)
678
+
679
+ return hidden_states
680
+
681
+
682
+ class ResnetBlock3D(nn.Module):
683
+ r"""
684
+ A Resnet block.
685
+
686
+ Parameters:
687
+ in_channels (`int`): The number of channels in the input.
688
+ out_channels (`int`, *optional*, default to be `None`):
689
+ The number of output channels for the first conv layer. If None, same as `in_channels`.
690
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
691
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
692
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
693
+ """
694
+
695
+ def __init__(
696
+ self,
697
+ dims: Union[int, Tuple[int, int]],
698
+ in_channels: int,
699
+ out_channels: Optional[int] = None,
700
+ conv_shortcut: bool = False,
701
+ dropout: float = 0.0,
702
+ groups: int = 32,
703
+ eps: float = 1e-6,
704
+ norm_layer: str = "group_norm",
705
+ ):
706
+ super().__init__()
707
+ self.in_channels = in_channels
708
+ out_channels = in_channels if out_channels is None else out_channels
709
+ self.out_channels = out_channels
710
+ self.use_conv_shortcut = conv_shortcut
711
+
712
+ if norm_layer == "group_norm":
713
+ self.norm1 = torch.nn.GroupNorm(
714
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
715
+ )
716
+ elif norm_layer == "pixel_norm":
717
+ self.norm1 = PixelNorm()
718
+
719
+ self.non_linearity = nn.SiLU()
720
+
721
+ self.conv1 = make_conv_nd(
722
+ dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
723
+ )
724
+
725
+ if norm_layer == "group_norm":
726
+ self.norm2 = torch.nn.GroupNorm(
727
+ num_groups=groups, num_channels=out_channels, eps=eps, affine=True
728
+ )
729
+ elif norm_layer == "pixel_norm":
730
+ self.norm2 = PixelNorm()
731
+
732
+ self.dropout = torch.nn.Dropout(dropout)
733
+
734
+ self.conv2 = make_conv_nd(
735
+ dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
736
+ )
737
+
738
+ self.conv_shortcut = (
739
+ make_linear_nd(
740
+ dims=dims, in_channels=in_channels, out_channels=out_channels
741
+ )
742
+ if in_channels != out_channels
743
+ else nn.Identity()
744
+ )
745
+
746
+ def forward(
747
+ self,
748
+ input_tensor: torch.FloatTensor,
749
+ ) -> torch.FloatTensor:
750
+ hidden_states = input_tensor
751
+
752
+ hidden_states = self.norm1(hidden_states)
753
+
754
+ hidden_states = self.non_linearity(hidden_states)
755
+
756
+ hidden_states = self.conv1(hidden_states)
757
+
758
+ hidden_states = self.norm2(hidden_states)
759
+
760
+ hidden_states = self.non_linearity(hidden_states)
761
+
762
+ hidden_states = self.dropout(hidden_states)
763
+
764
+ hidden_states = self.conv2(hidden_states)
765
+
766
+ input_tensor = self.conv_shortcut(input_tensor)
767
+
768
+ output_tensor = input_tensor + hidden_states
769
+
770
+ return output_tensor
771
+
772
+
773
+ class Downsample3D(nn.Module):
774
+ def __init__(
775
+ self,
776
+ dims,
777
+ in_channels: int,
778
+ out_channels: int,
779
+ kernel_size: int = 3,
780
+ padding: int = 1,
781
+ ):
782
+ super().__init__()
783
+ stride: int = 2
784
+ self.padding = padding
785
+ self.in_channels = in_channels
786
+ self.dims = dims
787
+ self.conv = make_conv_nd(
788
+ dims=dims,
789
+ in_channels=in_channels,
790
+ out_channels=out_channels,
791
+ kernel_size=kernel_size,
792
+ stride=stride,
793
+ padding=padding,
794
+ )
795
+
796
+ def forward(self, x, downsample_in_time=True):
797
+ conv = self.conv
798
+ if self.padding == 0:
799
+ if self.dims == 2:
800
+ padding = (0, 1, 0, 1)
801
+ else:
802
+ padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)
803
+
804
+ x = functional.pad(x, padding, mode="constant", value=0)
805
+
806
+ if self.dims == (2, 1) and not downsample_in_time:
807
+ return conv(x, skip_time_conv=True)
808
+
809
+ return conv(x)
810
+
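+ # Note: when constructed with padding=0, the forward pass pads only on the right,
+ # bottom and (optionally) temporal end before the stride-2 convolution, so each
+ # padded dimension is halved without an explicit "same" padding mode.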
811
+
812
+ class Upsample3D(nn.Module):
813
+ """
814
+ An upsampling layer for 3D tensors of shape (B, C, D, H, W).
815
+
816
+ :param channels: channels in the inputs and outputs.
817
+ """
818
+
819
+ def __init__(self, dims, channels, out_channels=None):
820
+ super().__init__()
821
+ self.dims = dims
822
+ self.channels = channels
823
+ self.out_channels = out_channels or channels
824
+ self.conv = make_conv_nd(
825
+ dims, channels, self.out_channels, kernel_size=3, padding=1, bias=True
826
+ )
827
+
828
+ def forward(self, x, upsample_in_time):
829
+ if self.dims == 2:
830
+ x = functional.interpolate(
831
+ x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
832
+ )
833
+ else:
834
+ time_scale_factor = 2 if upsample_in_time else 1
835
+ # print("before:", x.shape)
836
+ b, c, d, h, w = x.shape
837
+ x = rearrange(x, "b c d h w -> (b d) c h w")
838
+ # height and width interpolate
839
+ x = functional.interpolate(
840
+ x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
841
+ )
842
+ _, _, h, w = x.shape
843
+
844
+ if not upsample_in_time and self.dims == (2, 1):
845
+ x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
846
+ return self.conv(x, skip_time_conv=True)
847
+
848
+ # Second, temporal upsampling, which is essentially treated as a 1D operation across the 'd' dimension
849
+ x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)
850
+
851
+ # (b h w) c 1 d
852
+ new_d = x.shape[-1] * time_scale_factor
853
+ x = functional.interpolate(x, (1, new_d), mode="nearest")
854
+ # (b h w) c 1 new_d
855
+ x = rearrange(
856
+ x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
857
+ )
858
+ # b c d h w
859
+
860
+ # x = functional.interpolate(
861
+ # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
862
+ # )
863
+ # print("after:", x.shape)
864
+
865
+ return self.conv(x)
866
+
867
+
868
+ def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
869
+ if patch_size_hw == 1 and patch_size_t == 1:
870
+ return x
871
+ if x.dim() == 4:
872
+ x = rearrange(
873
+ x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
874
+ )
875
+ elif x.dim() == 5:
876
+ x = rearrange(
877
+ x,
878
+ "b c (f p) (h q) (w r) -> b (c p r q) f h w",
879
+ p=patch_size_t,
880
+ q=patch_size_hw,
881
+ r=patch_size_hw,
882
+ )
883
+ else:
884
+ raise ValueError(f"Invalid input shape: {x.shape}")
885
+
886
+ if (
887
+ (x.dim() == 5)
888
+ and (patch_size_hw > patch_size_t)
889
+ and (patch_size_t > 1 or add_channel_padding)
890
+ ):
891
+ channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
892
+ padding_zeros = torch.zeros(
893
+ x.shape[0],
894
+ channels_to_pad,
895
+ x.shape[2],
896
+ x.shape[3],
897
+ x.shape[4],
898
+ device=x.device,
899
+ dtype=x.dtype,
900
+ )
901
+ x = torch.cat([padding_zeros, x], dim=1)
902
+
903
+ return x
904
+
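+ # Shape example (illustrative only): patchify on a (B, C, F, H, W) = (2, 3, 8, 64, 64)
+ # tensor with patch_size_hw=4 and patch_size_t=4 folds the patches into channels and
+ # yields (2, 3 * 4 * 4 * 4, 2, 16, 16) = (2, 192, 2, 16, 16); unpatchify reverses this.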
905
+
906
+ def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
907
+ if patch_size_hw == 1 and patch_size_t == 1:
908
+ return x
909
+
910
+ if (
911
+ (x.dim() == 5)
912
+ and (patch_size_hw > patch_size_t)
913
+ and (patch_size_t > 1 or add_channel_padding)
914
+ ):
915
+ channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
916
+ x = x[:, :channels_to_keep, :, :, :]
917
+
918
+ if x.dim() == 4:
919
+ x = rearrange(
920
+ x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
921
+ )
922
+ elif x.dim() == 5:
923
+ x = rearrange(
924
+ x,
925
+ "b (c p r q) f h w -> b c (f p) (h q) (w r)",
926
+ p=patch_size_t,
927
+ q=patch_size_hw,
928
+ r=patch_size_hw,
929
+ )
930
+
931
+ return x
932
+
933
+
934
+ def create_video_autoencoder_config(
935
+ latent_channels: int = 4,
936
+ ):
937
+ config = {
938
+ "_class_name": "VideoAutoencoder",
939
+ "dims": (
940
+ 2,
941
+ 1,
942
+ ), # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
943
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
944
+ "out_channels": 3, # Number of output color channels
945
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
946
+ "block_out_channels": [
947
+ 128,
948
+ 256,
949
+ 512,
950
+ 512,
951
+ ], # Number of output channels of each encoder / decoder inner block
952
+ "patch_size": 1,
953
+ }
954
+
955
+ return config
956
+
957
+
958
+ def create_video_autoencoder_pathify4x4x4_config(
959
+ latent_channels: int = 4,
960
+ ):
961
+ config = {
962
+ "_class_name": "VideoAutoencoder",
963
+ "dims": (
964
+ 2,
965
+ 1,
966
+ ), # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
967
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
968
+ "out_channels": 3, # Number of output color channels
969
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
970
+ "block_out_channels": [512]
971
+ * 4, # Number of output channels of each encoder / decoder inner block
972
+ "patch_size": 4,
973
+ "latent_log_var": "uniform",
974
+ }
975
+
976
+ return config
977
+
978
+
979
+ def create_video_autoencoder_pathify4x4_config(
980
+ latent_channels: int = 4,
981
+ ):
982
+ config = {
983
+ "_class_name": "VideoAutoencoder",
984
+ "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
985
+ "in_channels": 3, # Number of input color channels (e.g., RGB)
986
+ "out_channels": 3, # Number of output color channels
987
+ "latent_channels": latent_channels, # Number of channels in the latent space representation
988
+ "block_out_channels": [512]
989
+ * 4, # Number of output channels of each encoder / decoder inner block
990
+ "patch_size": 4,
991
+ "norm_layer": "pixel_norm",
992
+ }
993
+
994
+ return config
995
+
996
+
997
+ def test_vae_patchify_unpatchify():
998
+ import torch
999
+
1000
+ x = torch.randn(2, 3, 8, 64, 64)
1001
+ x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
1002
+ x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
1003
+ assert torch.allclose(x, x_unpatched)
1004
+
1005
+
1006
+ def demo_video_autoencoder_forward_backward():
1007
+ # Configuration for the VideoAutoencoder
1008
+ config = create_video_autoencoder_pathify4x4x4_config()
1009
+
1010
+ # Instantiate the VideoAutoencoder with the specified configuration
1011
+ video_autoencoder = VideoAutoencoder.from_config(config)
1012
+
1013
+ print(video_autoencoder)
1014
+
1015
+ # Print the total number of parameters in the video autoencoder
1016
+ total_params = sum(p.numel() for p in video_autoencoder.parameters())
1017
+ print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
1018
+
1019
+ # Create a mock input tensor simulating a batch of videos
1020
+ # Shape: (batch_size, channels, depth, height, width)
1021
+ # E.g., 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
1022
+ input_videos = torch.randn(2, 3, 8, 64, 64)
1023
+
1024
+ # Forward pass: encode and decode the input videos
1025
+ latent = video_autoencoder.encode(input_videos).latent_dist.mode()
1026
+ print(f"input shape={input_videos.shape}")
1027
+ print(f"latent shape={latent.shape}")
1028
+ reconstructed_videos = video_autoencoder.decode(
1029
+ latent, target_shape=input_videos.shape
1030
+ ).sample
1031
+
1032
+ print(f"reconstructed shape={reconstructed_videos.shape}")
1033
+
1034
+ # Calculate the loss (e.g., mean squared error)
1035
+ loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
1036
+
1037
+ # Perform backward pass
1038
+ loss.backward()
1039
+
1040
+ print(f"Demo completed with loss: {loss.item()}")
1041
+
1042
+
1043
+ # Ensure to call the demo function to execute the forward and backward pass
1044
+ if __name__ == "__main__":
1045
+ demo_video_autoencoder_forward_backward()
flash_head/ltx_video/models/transformers/__init__.py ADDED
File without changes
flash_head/ltx_video/models/transformers/attention.py ADDED
@@ -0,0 +1,1265 @@
1
+ import inspect
2
+ from importlib import import_module
3
+ from typing import Any, Dict, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
8
+ from diffusers.models.attention import _chunked_feed_forward
9
+ from diffusers.models.attention_processor import (
10
+ LoRAAttnAddedKVProcessor,
11
+ LoRAAttnProcessor,
12
+ LoRAAttnProcessor2_0,
13
+ LoRAXFormersAttnProcessor,
14
+ SpatialNorm,
15
+ )
16
+ from diffusers.models.lora import LoRACompatibleLinear
17
+ from diffusers.models.normalization import RMSNorm
18
+ from diffusers.utils import deprecate, logging
19
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
20
+ from einops import rearrange
21
+ from torch import nn
22
+
23
+ from flash_head.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
24
+
25
+ try:
26
+ from torch_xla.experimental.custom_kernel import flash_attention
27
+ except ImportError:
28
+ # workaround for automatic tests. Currently this function is manually patched
29
+ # to the torch_xla lib on setup of container
30
+ pass
31
+
32
+ # code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
+ @maybe_allow_in_graph
38
+ class BasicTransformerBlock(nn.Module):
39
+ r"""
40
+ A basic Transformer block.
41
+
42
+ Parameters:
43
+ dim (`int`): The number of channels in the input and output.
44
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
45
+ attention_head_dim (`int`): The number of channels in each head.
46
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
47
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
49
+ num_embeds_ada_norm (:
50
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
51
+ attention_bias (:
52
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
53
+ only_cross_attention (`bool`, *optional*):
54
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
55
+ double_self_attention (`bool`, *optional*):
56
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
57
+ upcast_attention (`bool`, *optional*):
58
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
59
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
60
+ Whether to use learnable elementwise affine parameters for normalization.
61
+ qk_norm (`str`, *optional*, defaults to None):
62
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
63
+ adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
64
+ The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
65
+ standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
66
+ The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
67
+ final_dropout (`bool` *optional*, defaults to False):
68
+ Whether to apply a final dropout after the last feed-forward layer.
69
+ attention_type (`str`, *optional*, defaults to `"default"`):
70
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
71
+ positional_embeddings (`str`, *optional*, defaults to `None`):
72
+ The type of positional embeddings to apply to.
73
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
74
+ The maximum number of positional embeddings to apply.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ dim: int,
80
+ num_attention_heads: int,
81
+ attention_head_dim: int,
82
+ dropout=0.0,
83
+ cross_attention_dim: Optional[int] = None,
84
+ activation_fn: str = "geglu",
85
+ num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument
86
+ attention_bias: bool = False,
87
+ only_cross_attention: bool = False,
88
+ double_self_attention: bool = False,
89
+ upcast_attention: bool = False,
90
+ norm_elementwise_affine: bool = True,
91
+ adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none'
92
+ standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
93
+ norm_eps: float = 1e-5,
94
+ qk_norm: Optional[str] = None,
95
+ final_dropout: bool = False,
96
+ attention_type: str = "default", # pylint: disable=unused-argument
97
+ ff_inner_dim: Optional[int] = None,
98
+ ff_bias: bool = True,
99
+ attention_out_bias: bool = True,
100
+ use_tpu_flash_attention: bool = False,
101
+ use_rope: bool = False,
102
+ ):
103
+ super().__init__()
104
+ self.only_cross_attention = only_cross_attention
105
+ self.use_tpu_flash_attention = use_tpu_flash_attention
106
+ self.adaptive_norm = adaptive_norm
107
+
108
+ assert standardization_norm in ["layer_norm", "rms_norm"]
109
+ assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]
110
+
111
+ make_norm_layer = (
112
+ nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
113
+ )
114
+
115
+ # Define 3 blocks. Each block has its own normalization layer.
116
+ # 1. Self-Attn
117
+ self.norm1 = make_norm_layer(
118
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
119
+ )
120
+
121
+ self.attn1 = Attention(
122
+ query_dim=dim,
123
+ heads=num_attention_heads,
124
+ dim_head=attention_head_dim,
125
+ dropout=dropout,
126
+ bias=attention_bias,
127
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
128
+ upcast_attention=upcast_attention,
129
+ out_bias=attention_out_bias,
130
+ use_tpu_flash_attention=use_tpu_flash_attention,
131
+ qk_norm=qk_norm,
132
+ use_rope=use_rope,
133
+ )
134
+
135
+ # 2. Cross-Attn
136
+ if cross_attention_dim is not None or double_self_attention:
137
+ self.attn2 = Attention(
138
+ query_dim=dim,
139
+ cross_attention_dim=(
140
+ cross_attention_dim if not double_self_attention else None
141
+ ),
142
+ heads=num_attention_heads,
143
+ dim_head=attention_head_dim,
144
+ dropout=dropout,
145
+ bias=attention_bias,
146
+ upcast_attention=upcast_attention,
147
+ out_bias=attention_out_bias,
148
+ use_tpu_flash_attention=use_tpu_flash_attention,
149
+ qk_norm=qk_norm,
150
+ use_rope=use_rope,
151
+ ) # is self-attn if encoder_hidden_states is none
152
+
153
+ if adaptive_norm == "none":
154
+ self.attn2_norm = make_norm_layer(
155
+ dim, norm_eps, norm_elementwise_affine
156
+ )
157
+ else:
158
+ self.attn2 = None
159
+ self.attn2_norm = None
160
+
161
+ self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)
162
+
163
+ # 3. Feed-forward
164
+ self.ff = FeedForward(
165
+ dim,
166
+ dropout=dropout,
167
+ activation_fn=activation_fn,
168
+ final_dropout=final_dropout,
169
+ inner_dim=ff_inner_dim,
170
+ bias=ff_bias,
171
+ )
172
+
173
+ # 5. Scale-shift for PixArt-Alpha.
174
+ if adaptive_norm != "none":
175
+ num_ada_params = 4 if adaptive_norm == "single_scale" else 6
176
+ self.scale_shift_table = nn.Parameter(
177
+ torch.randn(num_ada_params, dim) / dim**0.5
178
+ )
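+ # The six ada parameters are shift/scale/gate for self-attention followed by
+ # shift/scale/gate for the feed-forward path; with "single_scale" only the four
+ # scale/gate entries exist. They are read back via ada_values.unbind(dim=2) in forward().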
179
+
180
+ # let chunk size default to None
181
+ self._chunk_size = None
182
+ self._chunk_dim = 0
183
+
184
+ def set_use_tpu_flash_attention(self):
185
+ r"""
186
+ Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
187
+ attention kernel.
188
+ """
189
+ self.use_tpu_flash_attention = True
190
+ self.attn1.set_use_tpu_flash_attention()
191
+ if self.attn2 is not None:
+ self.attn2.set_use_tpu_flash_attention()
192
+
193
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
194
+ # Sets chunk feed-forward
195
+ self._chunk_size = chunk_size
196
+ self._chunk_dim = dim
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.FloatTensor,
201
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
202
+ attention_mask: Optional[torch.FloatTensor] = None,
203
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
204
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
205
+ timestep: Optional[torch.LongTensor] = None,
206
+ cross_attention_kwargs: Dict[str, Any] = None,
207
+ class_labels: Optional[torch.LongTensor] = None,
208
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
209
+ skip_layer_mask: Optional[torch.Tensor] = None,
210
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
211
+ ) -> torch.FloatTensor:
212
+ if cross_attention_kwargs is not None:
213
+ if cross_attention_kwargs.get("scale", None) is not None:
214
+ logger.warning(
215
+ "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored."
216
+ )
217
+
218
+ # Notice that normalization is always applied before the real computation in the following blocks.
219
+ # 0. Self-Attention
220
+ batch_size = hidden_states.shape[0]
221
+
222
+ original_hidden_states = hidden_states
223
+
224
+ norm_hidden_states = self.norm1(hidden_states)
225
+
226
+ # Apply ada_norm_single
227
+ if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
228
+ assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim]
229
+ num_ada_params = self.scale_shift_table.shape[0]
230
+ ada_values = self.scale_shift_table[None, None] + timestep.reshape(
231
+ batch_size, timestep.shape[1], num_ada_params, -1
232
+ )
233
+ if self.adaptive_norm == "single_scale_shift":
234
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
235
+ ada_values.unbind(dim=2)
236
+ )
237
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
238
+ else:
239
+ scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
240
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa)
241
+ elif self.adaptive_norm == "none":
242
+ scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
243
+ else:
244
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
245
+
246
+ norm_hidden_states = norm_hidden_states.squeeze(
247
+ 1
248
+ ) # TODO: Check if this is needed
249
+
250
+ # 1. Prepare GLIGEN inputs
251
+ cross_attention_kwargs = (
252
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
253
+ )
254
+
255
+ attn_output = self.attn1(
256
+ norm_hidden_states,
257
+ freqs_cis=freqs_cis,
258
+ encoder_hidden_states=(
259
+ encoder_hidden_states if self.only_cross_attention else None
260
+ ),
261
+ attention_mask=attention_mask,
262
+ skip_layer_mask=skip_layer_mask,
263
+ skip_layer_strategy=skip_layer_strategy,
264
+ **cross_attention_kwargs,
265
+ )
266
+ if gate_msa is not None:
267
+ attn_output = gate_msa * attn_output
268
+
269
+ hidden_states = attn_output + hidden_states
270
+ if hidden_states.ndim == 4:
271
+ hidden_states = hidden_states.squeeze(1)
272
+
273
+ # 3. Cross-Attention
274
+ if self.attn2 is not None:
275
+ if self.adaptive_norm == "none":
276
+ attn_input = self.attn2_norm(hidden_states)
277
+ else:
278
+ attn_input = hidden_states
279
+ attn_output = self.attn2(
280
+ attn_input,
281
+ freqs_cis=freqs_cis,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ attention_mask=encoder_attention_mask,
284
+ **cross_attention_kwargs,
285
+ )
286
+ hidden_states = attn_output + hidden_states
287
+
288
+ # 4. Feed-forward
289
+ norm_hidden_states = self.norm2(hidden_states)
290
+ if self.adaptive_norm == "single_scale_shift":
291
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
292
+ elif self.adaptive_norm == "single_scale":
293
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
294
+ elif self.adaptive_norm == "none":
295
+ pass
296
+ else:
297
+ raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")
298
+
299
+ if self._chunk_size is not None:
300
+ # "feed_forward_chunk_size" can be used to save memory
301
+ ff_output = _chunked_feed_forward(
302
+ self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
303
+ )
304
+ else:
305
+ ff_output = self.ff(norm_hidden_states)
306
+ if gate_mlp is not None:
307
+ ff_output = gate_mlp * ff_output
308
+
309
+ hidden_states = ff_output + hidden_states
310
+ if hidden_states.ndim == 4:
311
+ hidden_states = hidden_states.squeeze(1)
312
+
313
+ if (
314
+ skip_layer_mask is not None
315
+ and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
316
+ ):
317
+ skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
318
+ hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
319
+ 1.0 - skip_layer_mask
320
+ )
321
+
322
+ return hidden_states
323
+
324
+
325
+ @maybe_allow_in_graph
326
+ class Attention(nn.Module):
327
+ r"""
328
+ A cross attention layer.
329
+
330
+ Parameters:
331
+ query_dim (`int`):
332
+ The number of channels in the query.
333
+ cross_attention_dim (`int`, *optional*):
334
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
335
+ heads (`int`, *optional*, defaults to 8):
336
+ The number of heads to use for multi-head attention.
337
+ dim_head (`int`, *optional*, defaults to 64):
338
+ The number of channels in each head.
339
+ dropout (`float`, *optional*, defaults to 0.0):
340
+ The dropout probability to use.
341
+ bias (`bool`, *optional*, defaults to False):
342
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
343
+ upcast_attention (`bool`, *optional*, defaults to False):
344
+ Set to `True` to upcast the attention computation to `float32`.
345
+ upcast_softmax (`bool`, *optional*, defaults to False):
346
+ Set to `True` to upcast the softmax computation to `float32`.
347
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
348
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
349
+ cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
350
+ The number of groups to use for the group norm in the cross attention.
351
+ added_kv_proj_dim (`int`, *optional*, defaults to `None`):
352
+ The number of channels to use for the added key and value projections. If `None`, no projection is used.
353
+ norm_num_groups (`int`, *optional*, defaults to `None`):
354
+ The number of groups to use for the group norm in the attention.
355
+ spatial_norm_dim (`int`, *optional*, defaults to `None`):
356
+ The number of channels to use for the spatial normalization.
357
+ out_bias (`bool`, *optional*, defaults to `True`):
358
+ Set to `True` to use a bias in the output linear layer.
359
+ scale_qk (`bool`, *optional*, defaults to `True`):
360
+ Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
361
+ qk_norm (`str`, *optional*, defaults to None):
362
+ Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
363
+ only_cross_attention (`bool`, *optional*, defaults to `False`):
364
+ Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
365
+ `added_kv_proj_dim` is not `None`.
366
+ eps (`float`, *optional*, defaults to 1e-5):
367
+ An additional value added to the denominator in group normalization that is used for numerical stability.
368
+ rescale_output_factor (`float`, *optional*, defaults to 1.0):
369
+ A factor to rescale the output by dividing it with this value.
370
+ residual_connection (`bool`, *optional*, defaults to `False`):
371
+ Set to `True` to add the residual connection to the output.
372
+ _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
373
+ Set to `True` if the attention block is loaded from a deprecated state dict.
374
+ processor (`AttnProcessor`, *optional*, defaults to `None`):
375
+ The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
376
+ `AttnProcessor` otherwise.
377
+ """
378
+
379
+ def __init__(
380
+ self,
381
+ query_dim: int,
382
+ cross_attention_dim: Optional[int] = None,
383
+ heads: int = 8,
384
+ dim_head: int = 64,
385
+ dropout: float = 0.0,
386
+ bias: bool = False,
387
+ upcast_attention: bool = False,
388
+ upcast_softmax: bool = False,
389
+ cross_attention_norm: Optional[str] = None,
390
+ cross_attention_norm_num_groups: int = 32,
391
+ added_kv_proj_dim: Optional[int] = None,
392
+ norm_num_groups: Optional[int] = None,
393
+ spatial_norm_dim: Optional[int] = None,
394
+ out_bias: bool = True,
395
+ scale_qk: bool = True,
396
+ qk_norm: Optional[str] = None,
397
+ only_cross_attention: bool = False,
398
+ eps: float = 1e-5,
399
+ rescale_output_factor: float = 1.0,
400
+ residual_connection: bool = False,
401
+ _from_deprecated_attn_block: bool = False,
402
+ processor: Optional["AttnProcessor"] = None,
403
+ out_dim: int = None,
404
+ use_tpu_flash_attention: bool = False,
405
+ use_rope: bool = False,
406
+ ):
407
+ super().__init__()
408
+ self.inner_dim = out_dim if out_dim is not None else dim_head * heads
409
+ self.query_dim = query_dim
410
+ self.use_bias = bias
411
+ self.is_cross_attention = cross_attention_dim is not None
412
+ self.cross_attention_dim = (
413
+ cross_attention_dim if cross_attention_dim is not None else query_dim
414
+ )
415
+ self.upcast_attention = upcast_attention
416
+ self.upcast_softmax = upcast_softmax
417
+ self.rescale_output_factor = rescale_output_factor
418
+ self.residual_connection = residual_connection
419
+ self.dropout = dropout
420
+ self.fused_projections = False
421
+ self.out_dim = out_dim if out_dim is not None else query_dim
422
+ self.use_tpu_flash_attention = use_tpu_flash_attention
423
+ self.use_rope = use_rope
424
+
425
+ # we make use of this private variable to know whether this class is loaded
426
+ # with a deprecated state dict so that we can convert it on the fly
427
+ self._from_deprecated_attn_block = _from_deprecated_attn_block
428
+
429
+ self.scale_qk = scale_qk
430
+ self.scale = dim_head**-0.5 if self.scale_qk else 1.0
431
+
432
+ if qk_norm is None:
433
+ self.q_norm = nn.Identity()
434
+ self.k_norm = nn.Identity()
435
+ elif qk_norm == "rms_norm":
436
+ self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
437
+ self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
438
+ elif qk_norm == "layer_norm":
439
+ self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
440
+ self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
441
+ else:
442
+ raise ValueError(f"Unsupported qk_norm method: {qk_norm}")
443
+
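+ # Note: q_norm / k_norm are built over the full projection width (dim_head * heads)
+ # rather than per attention head.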
444
+ self.heads = out_dim // dim_head if out_dim is not None else heads
445
+ # for slice_size > 0 the attention score computation
446
+ # is split across the batch axis to save memory
447
+ # You can set slice_size with `set_attention_slice`
448
+ self.sliceable_head_dim = heads
449
+
450
+ self.added_kv_proj_dim = added_kv_proj_dim
451
+ self.only_cross_attention = only_cross_attention
452
+
453
+ if self.added_kv_proj_dim is None and self.only_cross_attention:
454
+ raise ValueError(
455
+ "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
456
+ )
457
+
458
+ if norm_num_groups is not None:
459
+ self.group_norm = nn.GroupNorm(
460
+ num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
461
+ )
462
+ else:
463
+ self.group_norm = None
464
+
465
+ if spatial_norm_dim is not None:
466
+ self.spatial_norm = SpatialNorm(
467
+ f_channels=query_dim, zq_channels=spatial_norm_dim
468
+ )
469
+ else:
470
+ self.spatial_norm = None
471
+
472
+ if cross_attention_norm is None:
473
+ self.norm_cross = None
474
+ elif cross_attention_norm == "layer_norm":
475
+ self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
476
+ elif cross_attention_norm == "group_norm":
477
+ if self.added_kv_proj_dim is not None:
478
+ # The given `encoder_hidden_states` are initially of shape
479
+ # (batch_size, seq_len, added_kv_proj_dim) before being projected
480
+ # to (batch_size, seq_len, cross_attention_dim). The norm is applied
481
+ # before the projection, so we need to use `added_kv_proj_dim` as
482
+ # the number of channels for the group norm.
483
+ norm_cross_num_channels = added_kv_proj_dim
484
+ else:
485
+ norm_cross_num_channels = self.cross_attention_dim
486
+
487
+ self.norm_cross = nn.GroupNorm(
488
+ num_channels=norm_cross_num_channels,
489
+ num_groups=cross_attention_norm_num_groups,
490
+ eps=1e-5,
491
+ affine=True,
492
+ )
493
+ else:
494
+ raise ValueError(
495
+ f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
496
+ )
497
+
498
+ linear_cls = nn.Linear
499
+
500
+ self.linear_cls = linear_cls
501
+ self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)
502
+
503
+ if not self.only_cross_attention:
504
+ # only relevant for the `AddedKVProcessor` classes
505
+ self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
506
+ self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
507
+ else:
508
+ self.to_k = None
509
+ self.to_v = None
510
+
511
+ if self.added_kv_proj_dim is not None:
512
+ self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
513
+ self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
514
+
515
+ self.to_out = nn.ModuleList([])
516
+ self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
517
+ self.to_out.append(nn.Dropout(dropout))
518
+
519
+ # set attention processor
520
+ # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
521
+ # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
522
+ # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
523
+ if processor is None:
524
+ processor = AttnProcessor2_0()
525
+ self.set_processor(processor)
526
+
527
+ def set_use_tpu_flash_attention(self):
528
+ r"""
529
+        Sets the flag on this object. The flag enforces the use of the TPU flash-attention kernel.
530
+ """
531
+ self.use_tpu_flash_attention = True
532
+
533
+ def set_processor(self, processor: "AttnProcessor") -> None:
534
+ r"""
535
+ Set the attention processor to use.
536
+
537
+ Args:
538
+ processor (`AttnProcessor`):
539
+ The attention processor to use.
540
+ """
541
+ # if current processor is in `self._modules` and if passed `processor` is not, we need to
542
+ # pop `processor` from `self._modules`
543
+ if (
544
+ hasattr(self, "processor")
545
+ and isinstance(self.processor, torch.nn.Module)
546
+ and not isinstance(processor, torch.nn.Module)
547
+ ):
548
+ logger.info(
549
+ f"You are removing possibly trained weights of {self.processor} with {processor}"
550
+ )
551
+ self._modules.pop("processor")
552
+
553
+ self.processor = processor
554
+
555
+ def get_processor(
556
+ self, return_deprecated_lora: bool = False
557
+ ) -> "AttentionProcessor": # noqa: F821
558
+ r"""
559
+ Get the attention processor in use.
560
+
561
+ Args:
562
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
563
+ Set to `True` to return the deprecated LoRA attention processor.
564
+
565
+ Returns:
566
+ "AttentionProcessor": The attention processor in use.
567
+ """
568
+ if not return_deprecated_lora:
569
+ return self.processor
570
+
571
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
572
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
573
+ # with PEFT is completed.
574
+ is_lora_activated = {
575
+ name: module.lora_layer is not None
576
+ for name, module in self.named_modules()
577
+ if hasattr(module, "lora_layer")
578
+ }
579
+
580
+ # 1. if no layer has a LoRA activated we can return the processor as usual
581
+ if not any(is_lora_activated.values()):
582
+ return self.processor
583
+
584
+        # LoRA is not applied to `add_k_proj` or `add_v_proj`, so exclude them from the check
585
+ is_lora_activated.pop("add_k_proj", None)
586
+ is_lora_activated.pop("add_v_proj", None)
587
+        # 2. else it is not possible that only some layers have LoRA activated
588
+ if not all(is_lora_activated.values()):
589
+ raise ValueError(
590
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
591
+ )
592
+
593
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
594
+ non_lora_processor_cls_name = self.processor.__class__.__name__
595
+ lora_processor_cls = getattr(
596
+ import_module(__name__), "LoRA" + non_lora_processor_cls_name
597
+ )
598
+
599
+ hidden_size = self.inner_dim
600
+
601
+ # now create a LoRA attention processor from the LoRA layers
602
+ if lora_processor_cls in [
603
+ LoRAAttnProcessor,
604
+ LoRAAttnProcessor2_0,
605
+ LoRAXFormersAttnProcessor,
606
+ ]:
607
+ kwargs = {
608
+ "cross_attention_dim": self.cross_attention_dim,
609
+ "rank": self.to_q.lora_layer.rank,
610
+ "network_alpha": self.to_q.lora_layer.network_alpha,
611
+ "q_rank": self.to_q.lora_layer.rank,
612
+ "q_hidden_size": self.to_q.lora_layer.out_features,
613
+ "k_rank": self.to_k.lora_layer.rank,
614
+ "k_hidden_size": self.to_k.lora_layer.out_features,
615
+ "v_rank": self.to_v.lora_layer.rank,
616
+ "v_hidden_size": self.to_v.lora_layer.out_features,
617
+ "out_rank": self.to_out[0].lora_layer.rank,
618
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
619
+ }
620
+
621
+ if hasattr(self.processor, "attention_op"):
622
+ kwargs["attention_op"] = self.processor.attention_op
623
+
624
+ lora_processor = lora_processor_cls(hidden_size, **kwargs)
625
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
626
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
627
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
628
+ lora_processor.to_out_lora.load_state_dict(
629
+ self.to_out[0].lora_layer.state_dict()
630
+ )
631
+ elif lora_processor_cls == LoRAAttnAddedKVProcessor:
632
+ lora_processor = lora_processor_cls(
633
+ hidden_size,
634
+ cross_attention_dim=self.add_k_proj.weight.shape[0],
635
+ rank=self.to_q.lora_layer.rank,
636
+ network_alpha=self.to_q.lora_layer.network_alpha,
637
+ )
638
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
639
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
640
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
641
+ lora_processor.to_out_lora.load_state_dict(
642
+ self.to_out[0].lora_layer.state_dict()
643
+ )
644
+
645
+ # only save if used
646
+ if self.add_k_proj.lora_layer is not None:
647
+ lora_processor.add_k_proj_lora.load_state_dict(
648
+ self.add_k_proj.lora_layer.state_dict()
649
+ )
650
+ lora_processor.add_v_proj_lora.load_state_dict(
651
+ self.add_v_proj.lora_layer.state_dict()
652
+ )
653
+ else:
654
+ lora_processor.add_k_proj_lora = None
655
+ lora_processor.add_v_proj_lora = None
656
+ else:
657
+ raise ValueError(f"{lora_processor_cls} does not exist.")
658
+
659
+ return lora_processor
660
+
661
+ def forward(
662
+ self,
663
+ hidden_states: torch.FloatTensor,
664
+ freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
665
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
666
+ attention_mask: Optional[torch.FloatTensor] = None,
667
+ skip_layer_mask: Optional[torch.Tensor] = None,
668
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
669
+ **cross_attention_kwargs,
670
+ ) -> torch.Tensor:
671
+ r"""
672
+ The forward method of the `Attention` class.
673
+
674
+ Args:
675
+ hidden_states (`torch.Tensor`):
676
+ The hidden states of the query.
677
+ encoder_hidden_states (`torch.Tensor`, *optional*):
678
+ The hidden states of the encoder.
679
+ attention_mask (`torch.Tensor`, *optional*):
680
+ The attention mask to use. If `None`, no mask is applied.
681
+ skip_layer_mask (`torch.Tensor`, *optional*):
682
+ The skip layer mask to use. If `None`, no mask is applied.
683
+ skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
684
+ Controls which layers to skip for spatiotemporal guidance.
685
+ **cross_attention_kwargs:
686
+ Additional keyword arguments to pass along to the cross attention.
687
+
688
+ Returns:
689
+ `torch.Tensor`: The output of the attention layer.
690
+ """
691
+ # The `Attention` class can call different attention processors / attention functions
692
+ # here we simply pass along all tensors to the selected processor class
693
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
694
+
695
+ attn_parameters = set(
696
+ inspect.signature(self.processor.__call__).parameters.keys()
697
+ )
698
+ unused_kwargs = [
699
+ k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
700
+ ]
701
+ if len(unused_kwargs) > 0:
702
+ logger.warning(
703
+ f"cross_attention_kwargs {unused_kwargs} are not expected by"
704
+ f" {self.processor.__class__.__name__} and will be ignored."
705
+ )
706
+ cross_attention_kwargs = {
707
+ k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
708
+ }
709
+
710
+ return self.processor(
711
+ self,
712
+ hidden_states,
713
+ freqs_cis=freqs_cis,
714
+ encoder_hidden_states=encoder_hidden_states,
715
+ attention_mask=attention_mask,
716
+ skip_layer_mask=skip_layer_mask,
717
+ skip_layer_strategy=skip_layer_strategy,
718
+ **cross_attention_kwargs,
719
+ )
720
+
721
+ def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
722
+ r"""
723
+ Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
724
+ is the number of heads initialized while constructing the `Attention` class.
725
+
726
+ Args:
727
+ tensor (`torch.Tensor`): The tensor to reshape.
728
+
729
+ Returns:
730
+ `torch.Tensor`: The reshaped tensor.
731
+ """
732
+ head_size = self.heads
733
+ batch_size, seq_len, dim = tensor.shape
734
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
735
+ tensor = tensor.permute(0, 2, 1, 3).reshape(
736
+ batch_size // head_size, seq_len, dim * head_size
737
+ )
738
+ return tensor
739
+
740
+ def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
741
+ r"""
742
+        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads` is
743
+ the number of heads initialized while constructing the `Attention` class.
744
+
745
+ Args:
746
+ tensor (`torch.Tensor`): The tensor to reshape.
747
+ out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
748
+ reshaped to `[batch_size * heads, seq_len, dim // heads]`.
749
+
750
+ Returns:
751
+ `torch.Tensor`: The reshaped tensor.
752
+ """
753
+
754
+ head_size = self.heads
755
+ if tensor.ndim == 3:
756
+ batch_size, seq_len, dim = tensor.shape
757
+ extra_dim = 1
758
+ else:
759
+ batch_size, extra_dim, seq_len, dim = tensor.shape
760
+ tensor = tensor.reshape(
761
+ batch_size, seq_len * extra_dim, head_size, dim // head_size
762
+ )
763
+ tensor = tensor.permute(0, 2, 1, 3)
764
+
765
+ if out_dim == 3:
766
+ tensor = tensor.reshape(
767
+ batch_size * head_size, seq_len * extra_dim, dim // head_size
768
+ )
769
+
770
+ return tensor
771
+
772
+ def get_attention_scores(
773
+ self,
774
+ query: torch.Tensor,
775
+ key: torch.Tensor,
776
+ attention_mask: torch.Tensor = None,
777
+ ) -> torch.Tensor:
778
+ r"""
779
+ Compute the attention scores.
780
+
781
+ Args:
782
+ query (`torch.Tensor`): The query tensor.
783
+ key (`torch.Tensor`): The key tensor.
784
+ attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
785
+
786
+ Returns:
787
+ `torch.Tensor`: The attention probabilities/scores.
788
+ """
789
+ dtype = query.dtype
790
+ if self.upcast_attention:
791
+ query = query.float()
792
+ key = key.float()
793
+
794
+ if attention_mask is None:
795
+ baddbmm_input = torch.empty(
796
+ query.shape[0],
797
+ query.shape[1],
798
+ key.shape[1],
799
+ dtype=query.dtype,
800
+ device=query.device,
801
+ )
802
+ beta = 0
803
+ else:
804
+ baddbmm_input = attention_mask
805
+ beta = 1
806
+
807
+ attention_scores = torch.baddbmm(
808
+ baddbmm_input,
809
+ query,
810
+ key.transpose(-1, -2),
811
+ beta=beta,
812
+ alpha=self.scale,
813
+ )
814
+ del baddbmm_input
815
+
816
+ if self.upcast_softmax:
817
+ attention_scores = attention_scores.float()
818
+
819
+ attention_probs = attention_scores.softmax(dim=-1)
820
+ del attention_scores
821
+
822
+ attention_probs = attention_probs.to(dtype)
823
+
824
+ return attention_probs
825
+
826
+ def prepare_attention_mask(
827
+ self,
828
+ attention_mask: torch.Tensor,
829
+ target_length: int,
830
+ batch_size: int,
831
+ out_dim: int = 3,
832
+ ) -> torch.Tensor:
833
+ r"""
834
+ Prepare the attention mask for the attention computation.
835
+
836
+ Args:
837
+ attention_mask (`torch.Tensor`):
838
+ The attention mask to prepare.
839
+ target_length (`int`):
840
+ The target length of the attention mask. This is the length of the attention mask after padding.
841
+ batch_size (`int`):
842
+ The batch size, which is used to repeat the attention mask.
843
+ out_dim (`int`, *optional*, defaults to `3`):
844
+ The output dimension of the attention mask. Can be either `3` or `4`.
845
+
846
+ Returns:
847
+ `torch.Tensor`: The prepared attention mask.
848
+ """
849
+ head_size = self.heads
850
+ if attention_mask is None:
851
+ return attention_mask
852
+
853
+ current_length: int = attention_mask.shape[-1]
854
+ if current_length != target_length:
855
+ if attention_mask.device.type == "mps":
856
+ # HACK: MPS: Does not support padding by greater than dimension of input tensor.
857
+ # Instead, we can manually construct the padding tensor.
858
+ padding_shape = (
859
+ attention_mask.shape[0],
860
+ attention_mask.shape[1],
861
+ target_length,
862
+ )
863
+ padding = torch.zeros(
864
+ padding_shape,
865
+ dtype=attention_mask.dtype,
866
+ device=attention_mask.device,
867
+ )
868
+ attention_mask = torch.cat([attention_mask, padding], dim=2)
869
+ else:
870
+ # TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
871
+ # we want to instead pad by (0, remaining_length), where remaining_length is:
872
+ # remaining_length: int = target_length - current_length
873
+ # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
874
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
875
+
876
+ if out_dim == 3:
877
+ if attention_mask.shape[0] < batch_size * head_size:
878
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
879
+ elif out_dim == 4:
880
+ attention_mask = attention_mask.unsqueeze(1)
881
+ attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
882
+
883
+ return attention_mask
884
+
885
+ def norm_encoder_hidden_states(
886
+ self, encoder_hidden_states: torch.Tensor
887
+ ) -> torch.Tensor:
888
+ r"""
889
+ Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
890
+ `Attention` class.
891
+
892
+ Args:
893
+ encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
894
+
895
+ Returns:
896
+ `torch.Tensor`: The normalized encoder hidden states.
897
+ """
898
+ assert (
899
+ self.norm_cross is not None
900
+ ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
901
+
902
+ if isinstance(self.norm_cross, nn.LayerNorm):
903
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
904
+ elif isinstance(self.norm_cross, nn.GroupNorm):
905
+ # Group norm norms along the channels dimension and expects
906
+ # input to be in the shape of (N, C, *). In this case, we want
907
+ # to norm along the hidden dimension, so we need to move
908
+ # (batch_size, sequence_length, hidden_size) ->
909
+ # (batch_size, hidden_size, sequence_length)
910
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
911
+ encoder_hidden_states = self.norm_cross(encoder_hidden_states)
912
+ encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
913
+ else:
914
+ assert False
915
+
916
+ return encoder_hidden_states
917
+
918
+ @staticmethod
919
+ def apply_rotary_emb(
920
+ input_tensor: torch.Tensor,
921
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
922
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
923
+ cos_freqs = freqs_cis[0]
924
+ sin_freqs = freqs_cis[1]
925
+
926
+ t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
927
+ t1, t2 = t_dup.unbind(dim=-1)
928
+ t_dup = torch.stack((-t2, t1), dim=-1)
929
+ input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
930
+
931
+ out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
932
+
933
+ return out
934
+
935
+
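For a quick sanity check on `apply_rotary_emb` above, here is a minimal sketch (assuming only `torch`, `einops`, and the `Attention` class from this file; shapes are illustrative). With cos = 1 and sin = 0 the pairwise rotation is the identity:

    x = torch.randn(2, 128, 2048)    # (batch, tokens, inner_dim); RoPE is applied before the head split
    cos = torch.ones(2, 128, 2048)   # placeholder frequencies; real ones come from the transformer's precompute_freqs_cis
    sin = torch.zeros(2, 128, 2048)
    out = Attention.apply_rotary_emb(x, (cos, sin))
    assert torch.equal(out, x)       # x * cos + rotated(x) * sin reduces to x when cos = 1, sin = 0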
936
+ class AttnProcessor2_0:
937
+ r"""
938
+ Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
939
+ """
940
+
941
+ def __init__(self):
942
+ pass
943
+
944
+ def __call__(
945
+ self,
946
+ attn: Attention,
947
+ hidden_states: torch.FloatTensor,
948
+ freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
949
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
950
+ attention_mask: Optional[torch.FloatTensor] = None,
951
+ temb: Optional[torch.FloatTensor] = None,
952
+ skip_layer_mask: Optional[torch.FloatTensor] = None,
953
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
954
+ *args,
955
+ **kwargs,
956
+ ) -> torch.FloatTensor:
957
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
958
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
959
+ deprecate("scale", "1.0.0", deprecation_message)
960
+
961
+ residual = hidden_states
962
+ if attn.spatial_norm is not None:
963
+ hidden_states = attn.spatial_norm(hidden_states, temb)
964
+
965
+ input_ndim = hidden_states.ndim
966
+
967
+ if input_ndim == 4:
968
+ batch_size, channel, height, width = hidden_states.shape
969
+ hidden_states = hidden_states.view(
970
+ batch_size, channel, height * width
971
+ ).transpose(1, 2)
972
+
973
+ batch_size, sequence_length, _ = (
974
+ hidden_states.shape
975
+ if encoder_hidden_states is None
976
+ else encoder_hidden_states.shape
977
+ )
978
+
979
+ if skip_layer_mask is not None:
980
+ skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)
981
+
982
+ if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
983
+ attention_mask = attn.prepare_attention_mask(
984
+ attention_mask, sequence_length, batch_size
985
+ )
986
+ # scaled_dot_product_attention expects attention_mask shape to be
987
+ # (batch, heads, source_length, target_length)
988
+ attention_mask = attention_mask.view(
989
+ batch_size, attn.heads, -1, attention_mask.shape[-1]
990
+ )
991
+
992
+ if attn.group_norm is not None:
993
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
994
+ 1, 2
995
+ )
996
+
997
+ query = attn.to_q(hidden_states)
998
+ query = attn.q_norm(query)
999
+
1000
+ if encoder_hidden_states is not None:
1001
+ if attn.norm_cross:
1002
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
1003
+ encoder_hidden_states
1004
+ )
1005
+ key = attn.to_k(encoder_hidden_states)
1006
+ key = attn.k_norm(key)
1007
+ else: # if no context provided do self-attention
1008
+ encoder_hidden_states = hidden_states
1009
+ key = attn.to_k(hidden_states)
1010
+ key = attn.k_norm(key)
1011
+ if attn.use_rope:
1012
+ key = attn.apply_rotary_emb(key, freqs_cis)
1013
+ query = attn.apply_rotary_emb(query, freqs_cis)
1014
+
1015
+ value = attn.to_v(encoder_hidden_states)
1016
+ value_for_stg = value
1017
+
1018
+ inner_dim = key.shape[-1]
1019
+ head_dim = inner_dim // attn.heads
1020
+
1021
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1022
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1023
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1024
+
1025
+ # the output of sdp = (batch, num_heads, seq_len, head_dim)
1026
+
1027
+ if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
1028
+ q_segment_indexes = None
1029
+ if (
1030
+ attention_mask is not None
1031
+            ):  # if a mask is required, both segment-id fields need to be set
1032
+ # attention_mask = torch.squeeze(attention_mask).to(torch.float32)
1033
+ attention_mask = attention_mask.to(torch.float32)
1034
+ q_segment_indexes = torch.ones(
1035
+ batch_size, query.shape[2], device=query.device, dtype=torch.float32
1036
+ )
1037
+ assert (
1038
+ attention_mask.shape[1] == key.shape[2]
1039
+ ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
1040
+
1041
+ assert (
1042
+ query.shape[2] % 128 == 0
1043
+ ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
1044
+ assert (
1045
+ key.shape[2] % 128 == 0
1046
+ ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
1047
+
1048
+ # run the TPU kernel implemented in jax with pallas
1049
+ hidden_states_a = flash_attention(
1050
+ q=query,
1051
+ k=key,
1052
+ v=value,
1053
+ q_segment_ids=q_segment_indexes,
1054
+ kv_segment_ids=attention_mask,
1055
+ sm_scale=attn.scale,
1056
+ )
1057
+ else:
1058
+ hidden_states_a = F.scaled_dot_product_attention(
1059
+ query,
1060
+ key,
1061
+ value,
1062
+ attn_mask=attention_mask,
1063
+ dropout_p=0.0,
1064
+ is_causal=False,
1065
+ )
1066
+
1067
+ hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
1068
+ batch_size, -1, attn.heads * head_dim
1069
+ )
1070
+ hidden_states_a = hidden_states_a.to(query.dtype)
1071
+
1072
+ if (
1073
+ skip_layer_mask is not None
1074
+ and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
1075
+ ):
1076
+ hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
1077
+ 1.0 - skip_layer_mask
1078
+ )
1079
+ elif (
1080
+ skip_layer_mask is not None
1081
+ and skip_layer_strategy == SkipLayerStrategy.AttentionValues
1082
+ ):
1083
+ hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
1084
+ 1.0 - skip_layer_mask
1085
+ )
1086
+ else:
1087
+ hidden_states = hidden_states_a
1088
+
1089
+ # linear proj
1090
+ hidden_states = attn.to_out[0](hidden_states)
1091
+ # dropout
1092
+ hidden_states = attn.to_out[1](hidden_states)
1093
+
1094
+ if input_ndim == 4:
1095
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1096
+ batch_size, channel, height, width
1097
+ )
1098
+ if (
1099
+ skip_layer_mask is not None
1100
+ and skip_layer_strategy == SkipLayerStrategy.Residual
1101
+ ):
1102
+ skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)
1103
+
1104
+ if attn.residual_connection:
1105
+ if (
1106
+ skip_layer_mask is not None
1107
+ and skip_layer_strategy == SkipLayerStrategy.Residual
1108
+ ):
1109
+ hidden_states = hidden_states + residual * skip_layer_mask
1110
+ else:
1111
+ hidden_states = hidden_states + residual
1112
+
1113
+ hidden_states = hidden_states / attn.rescale_output_factor
1114
+
1115
+ return hidden_states
1116
+
1117
+
1118
+ class AttnProcessor:
1119
+ r"""
1120
+ Default processor for performing attention-related computations.
1121
+ """
1122
+
1123
+ def __call__(
1124
+ self,
1125
+ attn: Attention,
1126
+ hidden_states: torch.FloatTensor,
1127
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1128
+ attention_mask: Optional[torch.FloatTensor] = None,
1129
+ temb: Optional[torch.FloatTensor] = None,
1130
+ *args,
1131
+ **kwargs,
1132
+ ) -> torch.Tensor:
1133
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1134
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1135
+ deprecate("scale", "1.0.0", deprecation_message)
1136
+
1137
+ residual = hidden_states
1138
+
1139
+ if attn.spatial_norm is not None:
1140
+ hidden_states = attn.spatial_norm(hidden_states, temb)
1141
+
1142
+ input_ndim = hidden_states.ndim
1143
+
1144
+ if input_ndim == 4:
1145
+ batch_size, channel, height, width = hidden_states.shape
1146
+ hidden_states = hidden_states.view(
1147
+ batch_size, channel, height * width
1148
+ ).transpose(1, 2)
1149
+
1150
+ batch_size, sequence_length, _ = (
1151
+ hidden_states.shape
1152
+ if encoder_hidden_states is None
1153
+ else encoder_hidden_states.shape
1154
+ )
1155
+ attention_mask = attn.prepare_attention_mask(
1156
+ attention_mask, sequence_length, batch_size
1157
+ )
1158
+
1159
+ if attn.group_norm is not None:
1160
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
1161
+ 1, 2
1162
+ )
1163
+
1164
+ query = attn.to_q(hidden_states)
1165
+
1166
+ if encoder_hidden_states is None:
1167
+ encoder_hidden_states = hidden_states
1168
+ elif attn.norm_cross:
1169
+ encoder_hidden_states = attn.norm_encoder_hidden_states(
1170
+ encoder_hidden_states
1171
+ )
1172
+
1173
+ key = attn.to_k(encoder_hidden_states)
1174
+ value = attn.to_v(encoder_hidden_states)
1175
+
1176
+ query = attn.head_to_batch_dim(query)
1177
+ key = attn.head_to_batch_dim(key)
1178
+ value = attn.head_to_batch_dim(value)
1179
+
1180
+ query = attn.q_norm(query)
1181
+ key = attn.k_norm(key)
1182
+
1183
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
1184
+ hidden_states = torch.bmm(attention_probs, value)
1185
+ hidden_states = attn.batch_to_head_dim(hidden_states)
1186
+
1187
+ # linear proj
1188
+ hidden_states = attn.to_out[0](hidden_states)
1189
+ # dropout
1190
+ hidden_states = attn.to_out[1](hidden_states)
1191
+
1192
+ if input_ndim == 4:
1193
+ hidden_states = hidden_states.transpose(-1, -2).reshape(
1194
+ batch_size, channel, height, width
1195
+ )
1196
+
1197
+ if attn.residual_connection:
1198
+ hidden_states = hidden_states + residual
1199
+
1200
+ hidden_states = hidden_states / attn.rescale_output_factor
1201
+
1202
+ return hidden_states
1203
+
1204
+
1205
+ class FeedForward(nn.Module):
1206
+ r"""
1207
+ A feed-forward layer.
1208
+
1209
+ Parameters:
1210
+ dim (`int`): The number of channels in the input.
1211
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
1212
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
1213
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
1214
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
1215
+ final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
1216
+ bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
1217
+ """
1218
+
1219
+ def __init__(
1220
+ self,
1221
+ dim: int,
1222
+ dim_out: Optional[int] = None,
1223
+ mult: int = 4,
1224
+ dropout: float = 0.0,
1225
+ activation_fn: str = "geglu",
1226
+ final_dropout: bool = False,
1227
+ inner_dim=None,
1228
+ bias: bool = True,
1229
+ ):
1230
+ super().__init__()
1231
+ if inner_dim is None:
1232
+ inner_dim = int(dim * mult)
1233
+ dim_out = dim_out if dim_out is not None else dim
1234
+ linear_cls = nn.Linear
1235
+
1236
+ if activation_fn == "gelu":
1237
+ act_fn = GELU(dim, inner_dim, bias=bias)
1238
+ elif activation_fn == "gelu-approximate":
1239
+ act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
1240
+ elif activation_fn == "geglu":
1241
+ act_fn = GEGLU(dim, inner_dim, bias=bias)
1242
+ elif activation_fn == "geglu-approximate":
1243
+ act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
1244
+ else:
1245
+ raise ValueError(f"Unsupported activation function: {activation_fn}")
1246
+
1247
+ self.net = nn.ModuleList([])
1248
+ # project in
1249
+ self.net.append(act_fn)
1250
+ # project dropout
1251
+ self.net.append(nn.Dropout(dropout))
1252
+ # project out
1253
+ self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
1254
+ # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
1255
+ if final_dropout:
1256
+ self.net.append(nn.Dropout(dropout))
1257
+
1258
+ def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
1259
+ compatible_cls = (GEGLU, LoRACompatibleLinear)
1260
+ for module in self.net:
1261
+ if isinstance(module, compatible_cls):
1262
+ hidden_states = module(hidden_states, scale)
1263
+ else:
1264
+ hidden_states = module(hidden_states)
1265
+ return hidden_states
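A minimal usage sketch for the `FeedForward` block defined above (assuming the imports at the top of this file; the dimensions are illustrative, not taken from any config):

    ff = FeedForward(dim=64, mult=4, activation_fn="gelu-approximate")
    x = torch.randn(1, 16, 64)     # (batch, tokens, dim)
    y = ff(x)                      # projects to dim * mult internally, back to dim on the way out
    assert y.shape == x.shape      # (1, 16, 64)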
flash_head/ltx_video/models/transformers/embeddings.py ADDED
@@ -0,0 +1,129 @@
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import nn
8
+
9
+
10
+ def get_timestep_embedding(
11
+ timesteps: torch.Tensor,
12
+ embedding_dim: int,
13
+ flip_sin_to_cos: bool = False,
14
+ downscale_freq_shift: float = 1,
15
+ scale: float = 1,
16
+ max_period: int = 10000,
17
+ ):
18
+ """
19
+ This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
20
+
21
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
22
+ These may be fractional.
23
+    :param embedding_dim: the dimension of the output.
24
+    :param max_period: controls the minimum frequency of the embeddings.
+    :return: an [N x dim] Tensor of positional embeddings.
25
+ """
26
+ assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
27
+
28
+ half_dim = embedding_dim // 2
29
+ exponent = -math.log(max_period) * torch.arange(
30
+ start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
31
+ )
32
+ exponent = exponent / (half_dim - downscale_freq_shift)
33
+
34
+ emb = torch.exp(exponent)
35
+ emb = timesteps[:, None].float() * emb[None, :]
36
+
37
+ # scale embeddings
38
+ emb = scale * emb
39
+
40
+ # concat sine and cosine embeddings
41
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
42
+
43
+ # flip sine and cosine embeddings
44
+ if flip_sin_to_cos:
45
+ emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
46
+
47
+ # zero pad
48
+ if embedding_dim % 2 == 1:
49
+ emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
50
+ return emb
51
+
52
+
53
+ def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
54
+ """
55
+    Build 3D sin/cos positional embeddings for an (f, h, w) grid of latent positions.
56
+    Returns: pos_embed of shape [f*h*w, embed_dim].
57
+ """
58
+ grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
59
+ grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
60
+ grid = grid.reshape([3, 1, w, h, f])
61
+ pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
62
+ pos_embed = pos_embed.transpose(1, 0, 2, 3)
63
+ return rearrange(pos_embed, "h w f c -> (f h w) c")
64
+
65
+
66
+ def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
67
+ if embed_dim % 3 != 0:
68
+ raise ValueError("embed_dim must be divisible by 3")
69
+
70
+ # use half of dimensions to encode grid_h
71
+ emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
72
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
73
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
74
+
75
+ emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
76
+ return emb
77
+
78
+
79
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
80
+ """
81
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
82
+ """
83
+ if embed_dim % 2 != 0:
84
+ raise ValueError("embed_dim must be divisible by 2")
85
+
86
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
87
+ omega /= embed_dim / 2.0
88
+ omega = 1.0 / 10000**omega # (D/2,)
89
+
90
+ pos_shape = pos.shape
91
+
92
+ pos = pos.reshape(-1)
93
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
94
+ out = out.reshape([*pos_shape, -1])[0]
95
+
96
+ emb_sin = np.sin(out) # (M, D/2)
97
+ emb_cos = np.cos(out) # (M, D/2)
98
+
99
+ emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
100
+ return emb
101
+
102
+
103
+ class SinusoidalPositionalEmbedding(nn.Module):
104
+ """Apply positional information to a sequence of embeddings.
105
+
106
+ Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
107
+ them
108
+
109
+ Args:
110
+ embed_dim: (int): Dimension of the positional embedding.
111
+ max_seq_length: Maximum sequence length to apply positional embeddings
112
+
113
+ """
114
+
115
+ def __init__(self, embed_dim: int, max_seq_length: int = 32):
116
+ super().__init__()
117
+ position = torch.arange(max_seq_length).unsqueeze(1)
118
+ div_term = torch.exp(
119
+ torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
120
+ )
121
+ pe = torch.zeros(1, max_seq_length, embed_dim)
122
+ pe[0, :, 0::2] = torch.sin(position * div_term)
123
+ pe[0, :, 1::2] = torch.cos(position * div_term)
124
+ self.register_buffer("pe", pe)
125
+
126
+ def forward(self, x):
127
+ _, seq_length, _ = x.shape
128
+ x = x + self.pe[:, :seq_length]
129
+ return x
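A small sketch of calling `get_timestep_embedding` (plain `torch`, illustrative values); with the default `flip_sin_to_cos=False`, the first half of the output holds the sine terms and the second half the cosine terms:

    t = torch.tensor([0, 250, 999])                     # one (possibly fractional) timestep per batch element
    emb = get_timestep_embedding(t, embedding_dim=128)
    assert emb.shape == (3, 128)                        # 64 sine dims followed by 64 cosine dims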
flash_head/ltx_video/models/transformers/symmetric_patchifier.py ADDED
@@ -0,0 +1,84 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Tuple
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin
6
+ from einops import rearrange
7
+ from torch import Tensor
8
+
9
+
10
+ class Patchifier(ConfigMixin, ABC):
11
+ def __init__(self, patch_size: int):
12
+ super().__init__()
13
+ self._patch_size = (1, patch_size, patch_size)
14
+
15
+ @abstractmethod
16
+ def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
17
+ raise NotImplementedError("Patchify method not implemented")
18
+
19
+ @abstractmethod
20
+ def unpatchify(
21
+ self,
22
+ latents: Tensor,
23
+ output_height: int,
24
+ output_width: int,
25
+ out_channels: int,
26
+ ) -> Tuple[Tensor, Tensor]:
27
+ pass
28
+
29
+ @property
30
+ def patch_size(self):
31
+ return self._patch_size
32
+
33
+ def get_latent_coords(
34
+ self, latent_num_frames, latent_height, latent_width, batch_size, device
35
+ ):
36
+ """
37
+ Return a tensor of shape [batch_size, 3, num_patches] containing the
38
+ top-left corner latent coordinates of each latent patch.
39
+ The tensor is repeated for each batch element.
40
+ """
41
+ latent_sample_coords = torch.meshgrid(
42
+ torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
43
+ torch.arange(0, latent_height, self._patch_size[1], device=device),
44
+ torch.arange(0, latent_width, self._patch_size[2], device=device),
45
+ )
46
+ latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
47
+ latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
48
+ latent_coords = rearrange(
49
+ latent_coords, "b c f h w -> b c (f h w)", b=batch_size
50
+ )
51
+ return latent_coords
52
+
53
+
54
+ class SymmetricPatchifier(Patchifier):
55
+ def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
56
+ b, _, f, h, w = latents.shape
57
+ latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
58
+ latents = rearrange(
59
+ latents,
60
+ "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
61
+ p1=self._patch_size[0],
62
+ p2=self._patch_size[1],
63
+ p3=self._patch_size[2],
64
+ )
65
+ return latents, latent_coords
66
+
67
+ def unpatchify(
68
+ self,
69
+ latents: Tensor,
70
+ output_height: int,
71
+ output_width: int,
72
+ out_channels: int,
73
+ ) -> Tuple[Tensor, Tensor]:
74
+ output_height = output_height // self._patch_size[1]
75
+ output_width = output_width // self._patch_size[2]
76
+ latents = rearrange(
77
+ latents,
78
+ "b (f h w) (c p q) -> b c f (h p) (w q)",
79
+ h=output_height,
80
+ w=output_width,
81
+ p=self._patch_size[1],
82
+ q=self._patch_size[2],
83
+ )
84
+ return latents
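A round-trip sketch for `SymmetricPatchifier` (assuming the imports of this file; the latent shape is made up and `patch_size=1` is just one possible setting):

    patchifier = SymmetricPatchifier(patch_size=1)
    latents = torch.randn(1, 128, 8, 16, 24)             # (batch, channels, frames, height, width)
    tokens, coords = patchifier.patchify(latents)
    assert tokens.shape == (1, 8 * 16 * 24, 128)          # one token per latent position
    assert coords.shape == (1, 3, 8 * 16 * 24)            # (frame, row, col) index of each token
    restored = patchifier.unpatchify(tokens, output_height=16, output_width=24, out_channels=128)
    assert restored.shape == (1, 128, 8, 16, 24)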
flash_head/ltx_video/models/transformers/transformer3d.py ADDED
@@ -0,0 +1,507 @@
1
+ # Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Union
5
+ import os
6
+ import json
7
+ import glob
8
+ from pathlib import Path
9
+
10
+ import torch
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models.embeddings import PixArtAlphaTextProjection
13
+ from diffusers.models.modeling_utils import ModelMixin
14
+ from diffusers.models.normalization import AdaLayerNormSingle
15
+ from diffusers.utils import BaseOutput, is_torch_version
16
+ from diffusers.utils import logging
17
+ from torch import nn
18
+ from safetensors import safe_open
19
+
20
+
21
+ from flash_head.ltx_video.models.transformers.attention import BasicTransformerBlock
22
+ from flash_head.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
23
+
24
+ from flash_head.ltx_video.utils.diffusers_config_mapping import (
25
+ diffusers_and_ours_config_mapping,
26
+ make_hashable_key,
27
+ TRANSFORMER_KEYS_RENAME_DICT,
28
+ )
29
+
30
+
31
+ logger = logging.get_logger(__name__)
32
+
33
+
34
+ @dataclass
35
+ class Transformer3DModelOutput(BaseOutput):
36
+ """
37
+ The output of [`Transformer2DModel`].
38
+
39
+ Args:
40
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
41
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
42
+ distributions for the unnoised latent pixels.
43
+ """
44
+
45
+ sample: torch.FloatTensor
46
+
47
+
48
+ class Transformer3DModel(ModelMixin, ConfigMixin):
49
+ _supports_gradient_checkpointing = True
50
+
51
+ @register_to_config
52
+ def __init__(
53
+ self,
54
+ num_attention_heads: int = 16,
55
+ attention_head_dim: int = 88,
56
+ in_channels: Optional[int] = None,
57
+ out_channels: Optional[int] = None,
58
+ num_layers: int = 1,
59
+ dropout: float = 0.0,
60
+ norm_num_groups: int = 32,
61
+ cross_attention_dim: Optional[int] = None,
62
+ attention_bias: bool = False,
63
+ num_vector_embeds: Optional[int] = None,
64
+ activation_fn: str = "geglu",
65
+ num_embeds_ada_norm: Optional[int] = None,
66
+ use_linear_projection: bool = False,
67
+ only_cross_attention: bool = False,
68
+ double_self_attention: bool = False,
69
+ upcast_attention: bool = False,
70
+ adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale'
71
+ standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
72
+ norm_elementwise_affine: bool = True,
73
+ norm_eps: float = 1e-5,
74
+ attention_type: str = "default",
75
+ caption_channels: int = None,
76
+ use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention')
77
+ qk_norm: Optional[str] = None,
78
+ positional_embedding_type: str = "rope",
79
+ positional_embedding_theta: Optional[float] = None,
80
+ positional_embedding_max_pos: Optional[List[int]] = None,
81
+ timestep_scale_multiplier: Optional[float] = None,
82
+ causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated
83
+ ):
84
+ super().__init__()
85
+ self.use_tpu_flash_attention = (
86
+ use_tpu_flash_attention # FIXME: push config down to the attention modules
87
+ )
88
+ self.use_linear_projection = use_linear_projection
89
+ self.num_attention_heads = num_attention_heads
90
+ self.attention_head_dim = attention_head_dim
91
+ inner_dim = num_attention_heads * attention_head_dim
92
+ self.inner_dim = inner_dim
93
+ self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
94
+ self.positional_embedding_type = positional_embedding_type
95
+ self.positional_embedding_theta = positional_embedding_theta
96
+ self.positional_embedding_max_pos = positional_embedding_max_pos
97
+ self.use_rope = self.positional_embedding_type == "rope"
98
+ self.timestep_scale_multiplier = timestep_scale_multiplier
99
+
100
+ if self.positional_embedding_type == "absolute":
101
+ raise ValueError("Absolute positional embedding is no longer supported")
102
+ elif self.positional_embedding_type == "rope":
103
+ if positional_embedding_theta is None:
104
+ raise ValueError(
105
+ "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
106
+ )
107
+ if positional_embedding_max_pos is None:
108
+ raise ValueError(
109
+ "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
110
+ )
111
+
112
+ # 3. Define transformers blocks
113
+ self.transformer_blocks = nn.ModuleList(
114
+ [
115
+ BasicTransformerBlock(
116
+ inner_dim,
117
+ num_attention_heads,
118
+ attention_head_dim,
119
+ dropout=dropout,
120
+ cross_attention_dim=cross_attention_dim,
121
+ activation_fn=activation_fn,
122
+ num_embeds_ada_norm=num_embeds_ada_norm,
123
+ attention_bias=attention_bias,
124
+ only_cross_attention=only_cross_attention,
125
+ double_self_attention=double_self_attention,
126
+ upcast_attention=upcast_attention,
127
+ adaptive_norm=adaptive_norm,
128
+ standardization_norm=standardization_norm,
129
+ norm_elementwise_affine=norm_elementwise_affine,
130
+ norm_eps=norm_eps,
131
+ attention_type=attention_type,
132
+ use_tpu_flash_attention=use_tpu_flash_attention,
133
+ qk_norm=qk_norm,
134
+ use_rope=self.use_rope,
135
+ )
136
+ for d in range(num_layers)
137
+ ]
138
+ )
139
+
140
+ # 4. Define output layers
141
+ self.out_channels = in_channels if out_channels is None else out_channels
142
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
143
+ self.scale_shift_table = nn.Parameter(
144
+ torch.randn(2, inner_dim) / inner_dim**0.5
145
+ )
146
+ self.proj_out = nn.Linear(inner_dim, self.out_channels)
147
+
148
+ self.adaln_single = AdaLayerNormSingle(
149
+ inner_dim, use_additional_conditions=False
150
+ )
151
+ if adaptive_norm == "single_scale":
152
+ self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
153
+
154
+ self.caption_projection = None
155
+ if caption_channels is not None:
156
+ self.caption_projection = PixArtAlphaTextProjection(
157
+ in_features=caption_channels, hidden_size=inner_dim
158
+ )
159
+
160
+ self.gradient_checkpointing = False
161
+
162
+ def set_use_tpu_flash_attention(self):
163
+ r"""
164
+ Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
165
+ attention kernel.
166
+ """
167
+ logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
168
+ self.use_tpu_flash_attention = True
169
+ # push config down to the attention modules
170
+ for block in self.transformer_blocks:
171
+ block.set_use_tpu_flash_attention()
172
+
173
+ def create_skip_layer_mask(
174
+ self,
175
+ batch_size: int,
176
+ num_conds: int,
177
+ ptb_index: int,
178
+ skip_block_list: Optional[List[int]] = None,
179
+ ):
180
+ if skip_block_list is None or len(skip_block_list) == 0:
181
+ return None
182
+ num_layers = len(self.transformer_blocks)
183
+ mask = torch.ones(
184
+ (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
185
+ )
186
+ for block_idx in skip_block_list:
187
+ mask[block_idx, ptb_index::num_conds] = 0
188
+ return mask
189
+
190
+ def _set_gradient_checkpointing(self, module, value=False):
191
+ if hasattr(module, "gradient_checkpointing"):
192
+ module.gradient_checkpointing = value
193
+
194
+ def get_fractional_positions(self, indices_grid):
195
+ fractional_positions = torch.stack(
196
+ [
197
+ indices_grid[:, i] / self.positional_embedding_max_pos[i]
198
+ for i in range(3)
199
+ ],
200
+ dim=-1,
201
+ )
202
+ return fractional_positions
203
+
204
+ def precompute_freqs_cis(self, indices_grid, spacing="exp"):
205
+ dtype = torch.float32 # We need full precision in the freqs_cis computation.
206
+ dim = self.inner_dim
207
+ theta = self.positional_embedding_theta
208
+
209
+ fractional_positions = self.get_fractional_positions(indices_grid)
210
+
211
+ start = 1
212
+ end = theta
213
+ device = fractional_positions.device
214
+ if spacing == "exp":
215
+ indices = theta ** (
216
+ torch.linspace(
217
+ math.log(start, theta),
218
+ math.log(end, theta),
219
+ dim // 6,
220
+ device=device,
221
+ dtype=dtype,
222
+ )
223
+ )
224
+ indices = indices.to(dtype=dtype)
225
+ elif spacing == "exp_2":
226
+ indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
227
+ indices = indices.to(dtype=dtype)
228
+ elif spacing == "linear":
229
+ indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
230
+ elif spacing == "sqrt":
231
+ indices = torch.linspace(
232
+ start**2, end**2, dim // 6, device=device, dtype=dtype
233
+ ).sqrt()
234
+
235
+ indices = indices * math.pi / 2
236
+
237
+ if spacing == "exp_2":
238
+ freqs = (
239
+ (indices * fractional_positions.unsqueeze(-1))
240
+ .transpose(-1, -2)
241
+ .flatten(2)
242
+ )
243
+ else:
244
+ freqs = (
245
+ (indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
246
+ .transpose(-1, -2)
247
+ .flatten(2)
248
+ )
249
+
250
+ cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
251
+ sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
252
+ if dim % 6 != 0:
253
+ cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
254
+ sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
255
+ cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
256
+ sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
257
+ return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
258
+
259
+ def load_state_dict(
260
+ self,
261
+ state_dict: Dict,
262
+ *args,
263
+ **kwargs,
264
+ ):
265
+ if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
266
+ state_dict = {
267
+ key.replace("model.diffusion_model.", ""): value
268
+ for key, value in state_dict.items()
269
+ if key.startswith("model.diffusion_model.")
270
+ }
271
+ super().load_state_dict(state_dict, **kwargs)
272
+
273
+ @classmethod
274
+ def from_pretrained(
275
+ cls,
276
+ pretrained_model_path: Optional[Union[str, os.PathLike]],
277
+ *args,
278
+ **kwargs,
279
+ ):
280
+ pretrained_model_path = Path(pretrained_model_path)
281
+ if pretrained_model_path.is_dir():
282
+ config_path = pretrained_model_path / "transformer" / "config.json"
283
+ with open(config_path, "r") as f:
284
+ config = make_hashable_key(json.load(f))
285
+
286
+ assert config in diffusers_and_ours_config_mapping, (
287
+ "Provided diffusers checkpoint config for transformer is not suppported. "
288
+ "We only support diffusers configs found in Lightricks/LTX-Video."
289
+ )
290
+
291
+ config = diffusers_and_ours_config_mapping[config]
292
+ state_dict = {}
293
+ ckpt_paths = (
294
+ pretrained_model_path
295
+ / "transformer"
296
+ / "diffusion_pytorch_model*.safetensors"
297
+ )
298
+ dict_list = glob.glob(str(ckpt_paths))
299
+ for dict_path in dict_list:
300
+ part_dict = {}
301
+ with safe_open(dict_path, framework="pt", device="cpu") as f:
302
+ for k in f.keys():
303
+ part_dict[k] = f.get_tensor(k)
304
+ state_dict.update(part_dict)
305
+
306
+ for key in list(state_dict.keys()):
307
+ new_key = key
308
+ for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
309
+ new_key = new_key.replace(replace_key, rename_key)
310
+ state_dict[new_key] = state_dict.pop(key)
311
+
312
+ with torch.device("meta"):
313
+ transformer = cls.from_config(config)
314
+ transformer.load_state_dict(state_dict, assign=True, strict=True)
315
+ elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
316
+ ".safetensors"
317
+ ):
318
+ comfy_single_file_state_dict = {}
319
+ with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
320
+ metadata = f.metadata()
321
+ for k in f.keys():
322
+ comfy_single_file_state_dict[k] = f.get_tensor(k)
323
+ configs = json.loads(metadata["config"])
324
+ transformer_config = configs["transformer"]
325
+ with torch.device("meta"):
326
+ transformer = Transformer3DModel.from_config(transformer_config)
327
+ transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
328
+ return transformer
329
+
330
+ def forward(
331
+ self,
332
+ hidden_states: torch.Tensor,
333
+ indices_grid: torch.Tensor,
334
+ encoder_hidden_states: Optional[torch.Tensor] = None,
335
+ timestep: Optional[torch.LongTensor] = None,
336
+ class_labels: Optional[torch.LongTensor] = None,
337
+ cross_attention_kwargs: Dict[str, Any] = None,
338
+ attention_mask: Optional[torch.Tensor] = None,
339
+ encoder_attention_mask: Optional[torch.Tensor] = None,
340
+ skip_layer_mask: Optional[torch.Tensor] = None,
341
+ skip_layer_strategy: Optional[SkipLayerStrategy] = None,
342
+ return_dict: bool = True,
343
+ ):
344
+ """
345
+ The [`Transformer2DModel`] forward method.
346
+
347
+ Args:
348
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
349
+ Input `hidden_states`.
350
+ indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
351
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
352
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
353
+ self-attention.
354
+ timestep ( `torch.LongTensor`, *optional*):
355
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
356
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
357
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
358
+ `AdaLayerZeroNorm`.
359
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
360
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
361
+ `self.processor` in
362
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
363
+ attention_mask ( `torch.Tensor`, *optional*):
364
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
365
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
366
+ negative values to the attention scores corresponding to "discard" tokens.
367
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
368
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
369
+
370
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
371
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
372
+
373
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
374
+ above. This bias will be added to the cross-attention scores.
375
+ skip_layer_mask ( `torch.Tensor`, *optional*):
376
+ A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
377
+ `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index.
378
+ skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
379
+ Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
380
+ return_dict (`bool`, *optional*, defaults to `True`):
381
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
382
+ tuple.
383
+
384
+ Returns:
385
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
386
+ `tuple` where the first element is the sample tensor.
387
+ """
388
+ # for tpu attention offload 2d token masks are used. No need to transform.
389
+ if not self.use_tpu_flash_attention:
390
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
391
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
392
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
393
+ # expects mask of shape:
394
+ # [batch, key_tokens]
395
+ # adds singleton query_tokens dimension:
396
+ # [batch, 1, key_tokens]
397
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
398
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
399
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
400
+ if attention_mask is not None and attention_mask.ndim == 2:
401
+ # assume that mask is expressed as:
402
+ # (1 = keep, 0 = discard)
403
+ # convert mask into a bias that can be added to attention scores:
404
+ # (keep = +0, discard = -10000.0)
405
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
406
+ attention_mask = attention_mask.unsqueeze(1)
407
+
408
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
409
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
410
+ encoder_attention_mask = (
411
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
412
+ ) * -10000.0
413
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
414
+
415
+ # 1. Input
416
+ hidden_states = self.patchify_proj(hidden_states)
417
+
418
+ if self.timestep_scale_multiplier:
419
+ timestep = self.timestep_scale_multiplier * timestep
420
+
421
+ freqs_cis = self.precompute_freqs_cis(indices_grid)
422
+
423
+ batch_size = hidden_states.shape[0]
424
+ timestep, embedded_timestep = self.adaln_single(
425
+ timestep.flatten(),
426
+ {"resolution": None, "aspect_ratio": None},
427
+ batch_size=batch_size,
428
+ hidden_dtype=hidden_states.dtype,
429
+ )
430
+ # Second dimension is 1 or number of tokens (if timestep_per_token)
431
+ timestep = timestep.view(batch_size, -1, timestep.shape[-1])
432
+ embedded_timestep = embedded_timestep.view(
433
+ batch_size, -1, embedded_timestep.shape[-1]
434
+ )
435
+
436
+ # 2. Blocks
437
+ if self.caption_projection is not None:
438
+ batch_size = hidden_states.shape[0]
439
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
440
+ encoder_hidden_states = encoder_hidden_states.view(
441
+ batch_size, -1, hidden_states.shape[-1]
442
+ )
443
+
444
+ for block_idx, block in enumerate(self.transformer_blocks):
445
+ if self.training and self.gradient_checkpointing:
446
+
447
+ def create_custom_forward(module, return_dict=None):
448
+ def custom_forward(*inputs):
449
+ if return_dict is not None:
450
+ return module(*inputs, return_dict=return_dict)
451
+ else:
452
+ return module(*inputs)
453
+
454
+ return custom_forward
455
+
456
+ ckpt_kwargs: Dict[str, Any] = (
457
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
458
+ )
459
+ hidden_states = torch.utils.checkpoint.checkpoint(
460
+ create_custom_forward(block),
461
+ hidden_states,
462
+ freqs_cis,
463
+ attention_mask,
464
+ encoder_hidden_states,
465
+ encoder_attention_mask,
466
+ timestep,
467
+ cross_attention_kwargs,
468
+ class_labels,
469
+ (
470
+ skip_layer_mask[block_idx]
471
+ if skip_layer_mask is not None
472
+ else None
473
+ ),
474
+ skip_layer_strategy,
475
+ **ckpt_kwargs,
476
+ )
477
+ else:
478
+ hidden_states = block(
479
+ hidden_states,
480
+ freqs_cis=freqs_cis,
481
+ attention_mask=attention_mask,
482
+ encoder_hidden_states=encoder_hidden_states,
483
+ encoder_attention_mask=encoder_attention_mask,
484
+ timestep=timestep,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ class_labels=class_labels,
487
+ skip_layer_mask=(
488
+ skip_layer_mask[block_idx]
489
+ if skip_layer_mask is not None
490
+ else None
491
+ ),
492
+ skip_layer_strategy=skip_layer_strategy,
493
+ )
494
+
495
+ # 3. Output
496
+ scale_shift_values = (
497
+ self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
498
+ )
499
+ shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
500
+ hidden_states = self.norm_out(hidden_states)
501
+ # Modulation
502
+ hidden_states = hidden_states * (1 + scale) + shift
503
+ hidden_states = self.proj_out(hidden_states)
504
+ if not return_dict:
505
+ return (hidden_states,)
506
+
507
+ return Transformer3DModelOutput(sample=hidden_states)
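The mask handling at the top of this forward pass converts a 2D keep/discard mask into an additive attention bias. A minimal standalone sketch of that conversion (the helper name is ours; the shapes and the -10000.0 fill value mirror the code above):

```python
import torch

def mask_to_additive_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: [batch, key_tokens], 1 = keep, 0 = discard
    bias = (1 - mask.to(dtype)) * -10000.0  # keep -> 0.0, discard -> -10000.0
    return bias.unsqueeze(1)                # [batch, 1, key_tokens], broadcasts over query tokens / heads

mask = torch.tensor([[1, 1, 0, 0]])
print(mask_to_additive_bias(mask, torch.float32))  # zeros for kept tokens, -10000 for masked ones
```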
flash_head/ltx_video/utils/__init__.py ADDED
File without changes
flash_head/ltx_video/utils/diffusers_config_mapping.py ADDED
@@ -0,0 +1,174 @@
1
+ def make_hashable_key(dict_key):
2
+ def convert_value(value):
3
+ if isinstance(value, list):
4
+ return tuple(value)
5
+ elif isinstance(value, dict):
6
+ return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
7
+ else:
8
+ return value
9
+
10
+ return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))
11
+
12
+
13
+ DIFFUSERS_SCHEDULER_CONFIG = {
14
+ "_class_name": "FlowMatchEulerDiscreteScheduler",
15
+ "_diffusers_version": "0.32.0.dev0",
16
+ "base_image_seq_len": 1024,
17
+ "base_shift": 0.95,
18
+ "invert_sigmas": False,
19
+ "max_image_seq_len": 4096,
20
+ "max_shift": 2.05,
21
+ "num_train_timesteps": 1000,
22
+ "shift": 1.0,
23
+ "shift_terminal": 0.1,
24
+ "use_beta_sigmas": False,
25
+ "use_dynamic_shifting": True,
26
+ "use_exponential_sigmas": False,
27
+ "use_karras_sigmas": False,
28
+ }
29
+ DIFFUSERS_TRANSFORMER_CONFIG = {
30
+ "_class_name": "LTXVideoTransformer3DModel",
31
+ "_diffusers_version": "0.32.0.dev0",
32
+ "activation_fn": "gelu-approximate",
33
+ "attention_bias": True,
34
+ "attention_head_dim": 64,
35
+ "attention_out_bias": True,
36
+ "caption_channels": 4096,
37
+ "cross_attention_dim": 2048,
38
+ "in_channels": 128,
39
+ "norm_elementwise_affine": False,
40
+ "norm_eps": 1e-06,
41
+ "num_attention_heads": 32,
42
+ "num_layers": 28,
43
+ "out_channels": 128,
44
+ "patch_size": 1,
45
+ "patch_size_t": 1,
46
+ "qk_norm": "rms_norm_across_heads",
47
+ }
48
+ DIFFUSERS_VAE_CONFIG = {
49
+ "_class_name": "AutoencoderKLLTXVideo",
50
+ "_diffusers_version": "0.32.0.dev0",
51
+ "block_out_channels": [128, 256, 512, 512],
52
+ "decoder_causal": False,
53
+ "encoder_causal": True,
54
+ "in_channels": 3,
55
+ "latent_channels": 128,
56
+ "layers_per_block": [4, 3, 3, 3, 4],
57
+ "out_channels": 3,
58
+ "patch_size": 4,
59
+ "patch_size_t": 1,
60
+ "resnet_norm_eps": 1e-06,
61
+ "scaling_factor": 1.0,
62
+ "spatio_temporal_scaling": [True, True, True, False],
63
+ }
64
+
65
+ OURS_SCHEDULER_CONFIG = {
66
+ "_class_name": "RectifiedFlowScheduler",
67
+ "_diffusers_version": "0.25.1",
68
+ "num_train_timesteps": 1000,
69
+ "shifting": "SD3",
70
+ "base_resolution": None,
71
+ "target_shift_terminal": 0.1,
72
+ }
73
+
74
+ OURS_TRANSFORMER_CONFIG = {
75
+ "_class_name": "Transformer3DModel",
76
+ "_diffusers_version": "0.25.1",
77
+ "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
78
+ "activation_fn": "gelu-approximate",
79
+ "attention_bias": True,
80
+ "attention_head_dim": 64,
81
+ "attention_type": "default",
82
+ "caption_channels": 4096,
83
+ "cross_attention_dim": 2048,
84
+ "double_self_attention": False,
85
+ "dropout": 0.0,
86
+ "in_channels": 128,
87
+ "norm_elementwise_affine": False,
88
+ "norm_eps": 1e-06,
89
+ "norm_num_groups": 32,
90
+ "num_attention_heads": 32,
91
+ "num_embeds_ada_norm": 1000,
92
+ "num_layers": 28,
93
+ "num_vector_embeds": None,
94
+ "only_cross_attention": False,
95
+ "out_channels": 128,
96
+ "project_to_2d_pos": True,
97
+ "upcast_attention": False,
98
+ "use_linear_projection": False,
99
+ "qk_norm": "rms_norm",
100
+ "standardization_norm": "rms_norm",
101
+ "positional_embedding_type": "rope",
102
+ "positional_embedding_theta": 10000.0,
103
+ "positional_embedding_max_pos": [20, 2048, 2048],
104
+ "timestep_scale_multiplier": 1000,
105
+ }
106
+ OURS_VAE_CONFIG = {
107
+ "_class_name": "CausalVideoAutoencoder",
108
+ "dims": 3,
109
+ "in_channels": 3,
110
+ "out_channels": 3,
111
+ "latent_channels": 128,
112
+ "blocks": [
113
+ ["res_x", 4],
114
+ ["compress_all", 1],
115
+ ["res_x_y", 1],
116
+ ["res_x", 3],
117
+ ["compress_all", 1],
118
+ ["res_x_y", 1],
119
+ ["res_x", 3],
120
+ ["compress_all", 1],
121
+ ["res_x", 3],
122
+ ["res_x", 4],
123
+ ],
124
+ "scaling_factor": 1.0,
125
+ "norm_layer": "pixel_norm",
126
+ "patch_size": 4,
127
+ "latent_log_var": "uniform",
128
+ "use_quant_conv": False,
129
+ "causal_decoder": False,
130
+ }
131
+
132
+
133
+ diffusers_and_ours_config_mapping = {
134
+ make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
135
+ make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
136
+ make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
137
+ }
138
+
139
+
140
+ TRANSFORMER_KEYS_RENAME_DICT = {
141
+ "proj_in": "patchify_proj",
142
+ "time_embed": "adaln_single",
143
+ "norm_q": "q_norm",
144
+ "norm_k": "k_norm",
145
+ }
146
+
147
+
148
+ VAE_KEYS_RENAME_DICT = {
149
+ "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
150
+ "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
151
+ "decoder.up_blocks.3": "decoder.up_blocks.9",
152
+ "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
153
+ "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
154
+ "decoder.up_blocks.2": "decoder.up_blocks.6",
155
+ "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
156
+ "decoder.up_blocks.1": "decoder.up_blocks.3",
157
+ "decoder.up_blocks.0": "decoder.up_blocks.1",
158
+ "decoder.mid_block": "decoder.up_blocks.0",
159
+ "encoder.down_blocks.3": "encoder.down_blocks.8",
160
+ "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
161
+ "encoder.down_blocks.2": "encoder.down_blocks.6",
162
+ "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
163
+ "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
164
+ "encoder.down_blocks.1": "encoder.down_blocks.3",
165
+ "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
166
+ "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
167
+ "encoder.down_blocks.0": "encoder.down_blocks.0",
168
+ "encoder.mid_block": "encoder.down_blocks.9",
169
+ "conv_shortcut.conv": "conv_shortcut",
170
+ "resnets": "res_blocks",
171
+ "norm3": "norm3.norm",
172
+ "latents_mean": "per_channel_statistics.mean-of-means",
173
+ "latents_std": "per_channel_statistics.std-of-means",
174
+ }
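The mapping above is keyed on hashable versions of full Diffusers config dicts. A small sketch of the intended lookup, assuming the names defined in this file are in scope (the toy dict is illustrative only):

```python
# Lists become tuples and dicts become sorted (key, value) tuples, so the key is hashable.
toy = {"_class_name": "Example", "sizes": [128, 256], "nested": {"a": 1}}
assert isinstance(hash(make_hashable_key(toy)), int)

# A full Diffusers config maps to the corresponding native config (None if unknown).
ours = diffusers_and_ours_config_mapping.get(make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG))
assert ours is OURS_SCHEDULER_CONFIG
```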
flash_head/ltx_video/utils/prompt_enhance_utils.py ADDED
@@ -0,0 +1,226 @@
1
+ import logging
2
+ from typing import Union, List, Optional
3
+
4
+ import torch
5
+ from PIL import Image
6
+
7
+ logger = logging.getLogger(__name__) # pylint: disable=invalid-name
8
+
9
+ T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
10
+ Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
11
+ Start directly with the action, and keep descriptions literal and precise.
12
+ Think like a cinematographer describing a shot list.
13
+ Do not change the user input intent, just enhance it.
14
+ Keep within 150 words.
15
+ For best results, build your prompts using this structure:
16
+ Start with main action in a single sentence
17
+ Add specific details about movements and gestures
18
+ Describe character/object appearances precisely
19
+ Include background and environment details
20
+ Specify camera angles and movements
21
+ Describe lighting and colors
22
+ Note any changes or sudden events
23
+ Do not exceed the 150 word limit!
24
+ Output the enhanced prompt only.
25
+ """
26
+
27
+ I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
28
+ Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
29
+ Start directly with the action, and keep descriptions literal and precise.
30
+ Think like a cinematographer describing a shot list.
31
+ Keep within 150 words.
32
+ For best results, build your prompts using this structure:
33
+ Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
34
+ Start with main action in a single sentence
35
+ Add specific details about movements and gestures
36
+ Describe character/object appearances precisely
37
+ Include background and environment details
38
+ Specify camera angles and movements
39
+ Describe lighting and colors
40
+ Note any changes or sudden events
41
+ Align to the image caption if it contradicts the user text input.
42
+ Do not exceed the 150 word limit!
43
+ Output the enhanced prompt only.
44
+ """
45
+
46
+
47
+ def tensor_to_pil(tensor):
48
+ # Ensure tensor is in range [-1, 1]
49
+ assert tensor.min() >= -1 and tensor.max() <= 1
50
+
51
+ # Convert from [-1, 1] to [0, 1]
52
+ tensor = (tensor + 1) / 2
53
+
54
+ # Rearrange from [C, H, W] to [H, W, C]
55
+ tensor = tensor.permute(1, 2, 0)
56
+
57
+ # Convert to numpy array and then to uint8 range [0, 255]
58
+ numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")
59
+
60
+ # Convert to PIL Image
61
+ return Image.fromarray(numpy_image)
62
+
63
+
64
+ def generate_cinematic_prompt(
65
+ image_caption_model,
66
+ image_caption_processor,
67
+ prompt_enhancer_model,
68
+ prompt_enhancer_tokenizer,
69
+ prompt: Union[str, List[str]],
70
+ conditioning_items: Optional[List] = None,
71
+ max_new_tokens: int = 256,
72
+ ) -> List[str]:
73
+ prompts = [prompt] if isinstance(prompt, str) else prompt
74
+
75
+ if conditioning_items is None:
76
+ prompts = _generate_t2v_prompt(
77
+ prompt_enhancer_model,
78
+ prompt_enhancer_tokenizer,
79
+ prompts,
80
+ max_new_tokens,
81
+ T2V_CINEMATIC_PROMPT,
82
+ )
83
+ else:
84
+ if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
85
+ logger.warning(
86
+ "prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts"
87
+ )
88
+ return prompts
89
+
90
+ first_frame_conditioning_item = conditioning_items[0]
91
+ first_frames = _get_first_frames_from_conditioning_item(
92
+ first_frame_conditioning_item
93
+ )
94
+
95
+ assert len(first_frames) == len(
96
+ prompts
97
+ ), "Number of conditioning frames must match number of prompts"
98
+
99
+ prompts = _generate_i2v_prompt(
100
+ image_caption_model,
101
+ image_caption_processor,
102
+ prompt_enhancer_model,
103
+ prompt_enhancer_tokenizer,
104
+ prompts,
105
+ first_frames,
106
+ max_new_tokens,
107
+ I2V_CINEMATIC_PROMPT,
108
+ )
109
+
110
+ return prompts
111
+
112
+
113
+ def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
114
+ frames_tensor = conditioning_item.media_item
115
+ return [
116
+ tensor_to_pil(frames_tensor[i, :, 0, :, :])
117
+ for i in range(frames_tensor.shape[0])
118
+ ]
119
+
120
+
121
+ def _generate_t2v_prompt(
122
+ prompt_enhancer_model,
123
+ prompt_enhancer_tokenizer,
124
+ prompts: List[str],
125
+ max_new_tokens: int,
126
+ system_prompt: str,
127
+ ) -> List[str]:
128
+ messages = [
129
+ [
130
+ {"role": "system", "content": system_prompt},
131
+ {"role": "user", "content": f"user_prompt: {p}"},
132
+ ]
133
+ for p in prompts
134
+ ]
135
+
136
+ texts = [
137
+ prompt_enhancer_tokenizer.apply_chat_template(
138
+ m, tokenize=False, add_generation_prompt=True
139
+ )
140
+ for m in messages
141
+ ]
142
+ model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
143
+ prompt_enhancer_model.device
144
+ )
145
+
146
+ return _generate_and_decode_prompts(
147
+ prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
148
+ )
149
+
150
+
151
+ def _generate_i2v_prompt(
152
+ image_caption_model,
153
+ image_caption_processor,
154
+ prompt_enhancer_model,
155
+ prompt_enhancer_tokenizer,
156
+ prompts: List[str],
157
+ first_frames: List[Image.Image],
158
+ max_new_tokens: int,
159
+ system_prompt: str,
160
+ ) -> List[str]:
161
+ image_captions = _generate_image_captions(
162
+ image_caption_model, image_caption_processor, first_frames
163
+ )
164
+
165
+ messages = [
166
+ [
167
+ {"role": "system", "content": system_prompt},
168
+ {"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
169
+ ]
170
+ for p, c in zip(prompts, image_captions)
171
+ ]
172
+
173
+ texts = [
174
+ prompt_enhancer_tokenizer.apply_chat_template(
175
+ m, tokenize=False, add_generation_prompt=True
176
+ )
177
+ for m in messages
178
+ ]
179
+ model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
180
+ prompt_enhancer_model.device
181
+ )
182
+
183
+ return _generate_and_decode_prompts(
184
+ prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
185
+ )
186
+
187
+
188
+ def _generate_image_captions(
189
+ image_caption_model,
190
+ image_caption_processor,
191
+ images: List[Image.Image],
192
+ system_prompt: str = "<DETAILED_CAPTION>",
193
+ ) -> List[str]:
194
+ image_caption_prompts = [system_prompt] * len(images)
195
+ inputs = image_caption_processor(
196
+ image_caption_prompts, images, return_tensors="pt"
197
+ ).to(image_caption_model.device)
198
+
199
+ with torch.inference_mode():
200
+ generated_ids = image_caption_model.generate(
201
+ input_ids=inputs["input_ids"],
202
+ pixel_values=inputs["pixel_values"],
203
+ max_new_tokens=1024,
204
+ do_sample=False,
205
+ num_beams=3,
206
+ )
207
+
208
+ return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
209
+
210
+
211
+ def _generate_and_decode_prompts(
212
+ prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
213
+ ) -> List[str]:
214
+ with torch.inference_mode():
215
+ outputs = prompt_enhancer_model.generate(
216
+ **model_inputs, max_new_tokens=max_new_tokens
217
+ )
218
+ generated_ids = [
219
+ output_ids[len(input_ids) :]
220
+ for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
221
+ ]
222
+ decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
223
+ generated_ids, skip_special_tokens=True
224
+ )
225
+
226
+ return decoded_prompts
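The slicing in `_generate_and_decode_prompts` relies on decoder-only models echoing the prompt: the first `len(input_ids)` tokens of each output row are the prompt itself, so only the generated suffix is decoded. A toy illustration with plain lists:

```python
input_ids = [101, 7592, 2088]             # prompt tokens (illustrative values)
output_ids = [101, 7592, 2088, 345, 678]  # model output = prompt + newly generated tokens
generated = output_ids[len(input_ids):]
assert generated == [345, 678]
```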
flash_head/ltx_video/utils/skip_layer_strategy.py ADDED
@@ -0,0 +1,8 @@
1
+ from enum import Enum, auto
2
+
3
+
4
+ class SkipLayerStrategy(Enum):
5
+ AttentionSkip = auto()
6
+ AttentionValues = auto()
7
+ Residual = auto()
8
+ TransformerBlock = auto()
flash_head/ltx_video/utils/torch_utils.py ADDED
@@ -0,0 +1,25 @@
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
6
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
7
+ dims_to_append = target_dims - x.ndim
8
+ if dims_to_append < 0:
9
+ raise ValueError(
10
+ f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
11
+ )
12
+ elif dims_to_append == 0:
13
+ return x
14
+ return x[(...,) + (None,) * dims_to_append]
15
+
16
+
17
+ class Identity(nn.Module):
18
+ """A placeholder identity operator that is argument-insensitive."""
19
+
20
+ def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
21
+ super().__init__()
22
+
23
+ # pylint: disable=unused-argument
24
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
25
+ return x
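A small usage sketch for `append_dims`, e.g. to broadcast one scalar per sample against a video latent (shapes are illustrative):

```python
import torch

sigmas = torch.rand(4)                      # one value per sample, shape [4]
latents = torch.rand(4, 16, 9, 64, 64)      # [B, C, T, H, W]
sigmas = append_dims(sigmas, latents.ndim)  # shape [4, 1, 1, 1, 1]
scaled = latents * sigmas                   # broadcasts over C, T, H, W
```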
flash_head/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
flash_head/src/distributed/usp_device.py ADDED
@@ -0,0 +1,35 @@
1
+ import math
2
+ from loguru import logger
3
+ import datetime
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+ def get_parallel_degree(world_size, num_heads):
8
+ # ulysses_degree is faster, and must be a divisor of num_heads
9
+ ulysses_degree = math.gcd(world_size, num_heads)
10
+ ring_degree = world_size // ulysses_degree
11
+ return ulysses_degree, ring_degree
12
+
13
+ def get_device(ulysses_degree, ring_degree):
14
+ if ulysses_degree > 1 or ring_degree > 1:
15
+ from xfuser.core.distributed import (
16
+ init_distributed_environment,
17
+ initialize_model_parallel,
18
+ get_world_group,
19
+ )
20
+
21
+ dist.init_process_group("nccl", timeout=datetime.timedelta(hours=24*7))
22
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
23
+ initialize_model_parallel(
24
+ sequence_parallel_degree=dist.get_world_size(),
25
+ ring_degree=ring_degree,
26
+ ulysses_degree=ulysses_degree
27
+ )
28
+
29
+ device = torch.device(f"cuda:{get_world_group().rank}")
30
+ torch.cuda.set_device(get_world_group().rank)
31
+
32
+ logger.info(f'rank={get_world_group().rank} device={str(device)}')
33
+ else:
34
+ device = "cuda"
35
+ return device
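`get_parallel_degree` keeps the Ulysses degree as large as possible (it must divide the head count) and puts the remainder into the ring degree, so `ulysses_degree * ring_degree == world_size`. A sketch with the 32 attention heads used by the model configs in this repo:

```python
print(get_parallel_degree(8, 32))   # (8, 1): gcd(8, 32) = 8, pure Ulysses parallelism
print(get_parallel_degree(12, 32))  # (4, 3): gcd(12, 32) = 4 Ulysses, 3-way ring attention
```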
flash_head/src/modules/flash_head_model.py ADDED
@@ -0,0 +1,548 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from typing import Tuple, Optional
6
+ from einops import rearrange
7
+ from diffusers import ModelMixin
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ import torch.cuda.amp as amp
10
+ import torch.distributed as dist
11
+ from xfuser.core.distributed import (
12
+ get_sequence_parallel_rank,
13
+ get_sequence_parallel_world_size,
14
+ get_sp_group,
15
+ )
16
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
17
+ try:
18
+ import flash_attn_interface
19
+ FLASH_ATTN_3_AVAILABLE = True
20
+ except ModuleNotFoundError:
21
+ FLASH_ATTN_3_AVAILABLE = False
22
+
23
+ try:
24
+ import flash_attn
25
+ FLASH_ATTN_2_AVAILABLE = True
26
+ except ModuleNotFoundError:
27
+ FLASH_ATTN_2_AVAILABLE = False
28
+
29
+ try:
30
+ from sageattention import sageattn
31
+ SAGE_ATTN_AVAILABLE = True
32
+ except ModuleNotFoundError:
33
+ SAGE_ATTN_AVAILABLE = False
34
+
35
+
36
+ def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
37
+ if compatibility_mode:
38
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
39
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
40
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
41
+ x = F.scaled_dot_product_attention(q, k, v)
42
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
43
+ elif SAGE_ATTN_AVAILABLE:
44
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
45
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
46
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
47
+ x = sageattn(q, k, v)
48
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
49
+ elif FLASH_ATTN_3_AVAILABLE:
50
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
51
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
52
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
53
+ x = flash_attn_interface.flash_attn_func(q, k, v)
54
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
55
+ elif FLASH_ATTN_2_AVAILABLE:
56
+ q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
57
+ k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
58
+ v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
59
+ x = flash_attn.flash_attn_func(q, k, v)
60
+ x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
61
+ else:
62
+ q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
63
+ k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
64
+ v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
65
+ x = F.scaled_dot_product_attention(q, k, v)
66
+ x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
67
+ return x
68
+
69
+ def sinusoidal_embedding_1d(dim, position):
70
+ sinusoid = torch.outer(position.type(torch.float64), torch.pow(
71
+ 10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
72
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
73
+ return x.to(position.dtype)
74
+
75
+
76
+ def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
77
+ # 3d rope precompute
78
+ f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
79
+ h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
80
+ w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
81
+ return torch.cat([f_freqs_cis, h_freqs_cis, w_freqs_cis], dim=1)
82
+
83
+
84
+ def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
85
+ # 1d rope precompute
86
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
87
+ [: (dim // 2)].double() / dim))
88
+ freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
89
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
90
+ return freqs_cis
91
+
92
+ def pad_freqs(original_tensor, target_len):
93
+ seq_len, s1, s2 = original_tensor.shape
94
+ pad_size = target_len - seq_len
95
+ padding_tensor = torch.ones(
96
+ pad_size,
97
+ s1,
98
+ s2,
99
+ dtype=original_tensor.dtype,
100
+ device=original_tensor.device)
101
+ padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
102
+ return padded_tensor
103
+
104
+ def rope_apply(x, freqs, grid_sizes, use_usp=False, sp_size=1, sp_rank=0):
105
+ """
106
+ x: [B, L, N, C].
107
+ grid_sizes: [B, 3].
108
+ freqs: [M, C // 2].
109
+ """
110
+ s, n, c = x.size(1), x.size(2), x.size(3) // 2
111
+ # split freqs
112
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # [[N, head_dim/2], [N, head_dim/2], [N, head_dim/2]] # polar-form frequencies for T, H, W
113
+
114
+ # loop over samples
115
+
116
+ (f, h, w) = grid_sizes
117
+ seq_len = f * h * w
118
+
119
+ # precompute multipliers
120
+ x_i = torch.view_as_complex(x[0, :s].to(torch.float64).reshape(
121
+ s, n, -1, 2)) # [L, N, C/2] # as complex (polar) values
122
+ freqs_i = torch.cat([
123
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
124
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
125
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
126
+ ],
127
+ dim=-1).reshape(seq_len, 1, -1) # seq_lens, 1, 3 * dim / 2 (T H W)
128
+
129
+ if use_usp:
130
+ # apply rotary embedding
131
+ freqs_i = pad_freqs(freqs_i, s * sp_size)
132
+ s_per_rank = s
133
+ freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
134
+ s_per_rank), :, :]
135
+ x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
136
+ x_i = torch.cat([x_i, x[0, s:]])
137
+ else:
138
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
139
+ x_i = torch.cat([x_i, x[0, seq_len:]])
140
+ return x_i.unsqueeze(0).to(x.dtype)
141
+
142
+
143
+ class RMSNorm(nn.Module):
144
+ def __init__(self, dim, eps=1e-5):
145
+ super().__init__()
146
+ self.eps = eps
147
+ self.weight = nn.Parameter(torch.ones(dim))
148
+
149
+ def norm(self, x):
150
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
151
+
152
+ def forward(self, x):
153
+ dtype = x.dtype
154
+ return self.norm(x.float()).to(dtype) * self.weight
155
+
156
+ class SelfAttention(nn.Module):
157
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
158
+ super().__init__()
159
+ self.dim = dim
160
+ self.num_heads = num_heads
161
+ self.head_dim = dim // num_heads
162
+
163
+ self.q = nn.Linear(dim, dim)
164
+ self.k = nn.Linear(dim, dim)
165
+ self.v = nn.Linear(dim, dim)
166
+ self.o = nn.Linear(dim, dim)
167
+ self.norm_q = RMSNorm(dim, eps=eps)
168
+ self.norm_k = RMSNorm(dim, eps=eps)
169
+
170
+ self.use_usp = dist.is_initialized()
171
+ self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
172
+ self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
173
+
174
+ def forward(self, x, freqs, grid_sizes):
175
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
176
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
177
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
178
+ v = self.v(x)
179
+
180
+ if self.use_usp:
181
+ from yunchang.kernels import AttnType
182
+ if SAGE_ATTN_AVAILABLE:
183
+ attn_type = AttnType.SAGE_AUTO
184
+ else:
185
+ attn_type = AttnType.FA
186
+
187
+ x = xFuserLongContextAttention(attn_type=attn_type)(
188
+ None,
189
+ query=rope_apply(q, freqs, grid_sizes, self.use_usp, self.sp_size, self.sp_rank),
190
+ key=rope_apply(k, freqs, grid_sizes, self.use_usp, self.sp_size, self.sp_rank),
191
+ value=v.view(b, s, n, d),
192
+ ).flatten(2)
193
+ else:
194
+ x = flash_attention(
195
+ q=rope_apply(q, freqs, grid_sizes).flatten(2),
196
+ k=rope_apply(k, freqs, grid_sizes).flatten(2),
197
+ v=v,
198
+ num_heads=self.num_heads
199
+ )
200
+ return self.o(x)
201
+
202
+
203
+ class CrossAttention(nn.Module):
204
+ def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
205
+ super().__init__()
206
+ self.dim = dim
207
+ self.num_heads = num_heads
208
+ self.head_dim = dim // num_heads
209
+
210
+ self.q = nn.Linear(dim, dim)
211
+ self.k = nn.Linear(dim, dim)
212
+ self.v = nn.Linear(dim, dim)
213
+ self.o = nn.Linear(dim, dim)
214
+ self.norm_q = RMSNorm(dim, eps=eps)
215
+ self.norm_k = RMSNorm(dim, eps=eps)
216
+ self.has_image_input = has_image_input
217
+ if has_image_input:
218
+ self.k_img = nn.Linear(dim, dim)
219
+ self.v_img = nn.Linear(dim, dim)
220
+ self.norm_k_img = RMSNorm(dim, eps=eps)
221
+
222
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
223
+ if self.has_image_input:
224
+ img = y[:, :257]
225
+ ctx = y[:, 257:]
226
+ else:
227
+ ctx = y
228
+ q = self.norm_q(self.q(x))
229
+ k = self.norm_k(self.k(ctx))
230
+ v = self.v(ctx)
231
+ x = flash_attention(q, k, v, num_heads=self.num_heads)
232
+ if self.has_image_input:
233
+ k_img = self.norm_k_img(self.k_img(img))
234
+ v_img = self.v_img(img)
235
+ y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
236
+ x = x + y
237
+ return self.o(x)
238
+
239
+ class DiTAudioBlock(nn.Module):
240
+ def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, i=0, num_layers=0):
241
+ super().__init__()
242
+ self.dim = dim
243
+ self.num_heads = num_heads
244
+ self.ffn_dim = ffn_dim
245
+ self.i = i
246
+ self.num_layers = num_layers
247
+
248
+ self.self_attn = SelfAttention(dim, num_heads, eps)
249
+ self.cross_attn = CrossAttention(
250
+ dim, num_heads, eps, has_image_input=has_image_input)
251
+ self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
252
+ self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
253
+ self.norm3 = nn.LayerNorm(dim, eps=eps)
254
+ self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
255
+ approximate='tanh'), nn.Linear(ffn_dim, dim))
256
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
257
+
258
+ self.use_usp = dist.is_initialized()
259
+ self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
260
+ self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
261
+
262
+ def forward(self, x, context, t_mod, freqs, grid_sizes):
263
+ e = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
264
+
265
+ y = self.self_attn(
266
+ self.norm1(x) * (1 + e[1]) + e[0], freqs, grid_sizes)
267
+
268
+ x = x + y * e[2]
269
+
270
+ x_1 = rearrange(self.norm3(x), 'b (f l) c -> (b f) l c', f=context.shape[1])
271
+ context_1 = context.squeeze(0)
272
+
273
+ if self.use_usp:
274
+ context_1 = context_1.unsqueeze(1).repeat(1, self.sp_size, 1, 1).flatten(0,1)
275
+ context_1 = torch.chunk(context_1, self.sp_size, dim=0)[self.sp_rank]
276
+
277
+ x = x + self.cross_attn(x_1, context_1).flatten(0, 1).unsqueeze(0)
278
+
279
+ y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
280
+ x = x + y * e[5]
281
+
282
+ return x
283
+
284
+ class MLP(torch.nn.Module):
285
+ def __init__(self, in_dim, out_dim):
286
+ super().__init__()
287
+ self.proj = torch.nn.Sequential(
288
+ nn.LayerNorm(in_dim),
289
+ nn.Linear(in_dim, in_dim),
290
+ nn.GELU(),
291
+ nn.Linear(in_dim, out_dim),
292
+ nn.LayerNorm(out_dim)
293
+ )
294
+
295
+ def forward(self, x):
296
+ return self.proj(x)
297
+
298
+
299
+ class Head(nn.Module):
300
+ def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
301
+ super().__init__()
302
+ self.dim = dim
303
+ self.patch_size = patch_size
304
+ self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
305
+ self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
306
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
307
+
308
+ def forward(self, x, t_mod):
309
+ r"""
310
+ Args:
311
+ x(Tensor): Shape [B, L1, C]
312
+ t_mod(Tensor): Shape [B*21, C]
313
+ """
314
+ B, L, D = x.shape
315
+ F = t_mod.shape[0] // B
316
+ shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device).unsqueeze(1) + t_mod.unflatten(dim=0, sizes=(B, t_mod.shape[0]//B)).unsqueeze(2)).chunk(2, dim=2)
317
+
318
+ x = rearrange(x, 'b (f l) d -> b f l d', f=F)
319
+ x = (self.head(self.norm(x) * (1 + scale) + shift))
320
+ x = rearrange(x, 'b f l d -> b (f l) d')
321
+ return x
322
+
323
+ class WanModelAudioProject(ModelMixin, ConfigMixin):
324
+ _no_split_modules = ['DiTAudioBlock']
325
+ @register_to_config
326
+ def __init__(
327
+ self,
328
+ dim: int,
329
+ in_dim: int,
330
+ ffn_dim: int,
331
+ out_dim: int,
332
+ text_dim: int,
333
+ freq_dim: int,
334
+ eps: float,
335
+ vae_stride: Tuple[int, int, int],
336
+ patch_size: Tuple[int, int, int],
337
+ num_heads: int,
338
+ num_layers: int,
339
+ has_image_input: bool,
340
+ **kwargs,
341
+ ):
342
+ super().__init__()
343
+ self.dim = dim
344
+ self.freq_dim = freq_dim
345
+ self.has_image_input = has_image_input
346
+ self.patch_size = patch_size
347
+
348
+ self.patch_embedding = nn.Conv3d(
349
+ in_dim, dim, kernel_size=patch_size, stride=patch_size)
350
+ self.text_embedding = nn.Sequential(
351
+ nn.Linear(text_dim, dim),
352
+ nn.GELU(approximate='tanh'),
353
+ nn.Linear(dim, dim)
354
+ )
355
+ self.time_embedding = nn.Sequential(
356
+ nn.Linear(freq_dim, dim),
357
+ nn.SiLU(),
358
+ nn.Linear(dim, dim)
359
+ )
360
+ self.time_projection = nn.Sequential(
361
+ nn.SiLU(), nn.Linear(dim, dim * 6))
362
+ self.blocks = nn.ModuleList([
363
+ DiTAudioBlock(has_image_input, dim, num_heads, ffn_dim, eps, i, num_layers)
364
+ for i in range(num_layers)
365
+ ])
366
+ self.head = Head(dim, out_dim, patch_size, eps)
367
+ head_dim = dim // num_heads
368
+ self.freqs = precompute_freqs_cis_3d(head_dim)
369
+
370
+ self.audio_emb = MLP(768, dim)
371
+
372
+ if has_image_input:
373
+ self.img_emb = MLP(1280, dim)
374
+
375
+ # init audio adapter
376
+ audio_window = 5
377
+ vae_scale = vae_stride[0]
378
+ intermediate_dim = 512
379
+ output_dim = 1536
380
+ context_tokens = 32
381
+ norm_output_audio = True
382
+ self.audio_window = audio_window
383
+ self.vae_scale = vae_scale
384
+ self.audio_proj = AudioProjModel(
385
+ seq_len=audio_window,
386
+ seq_len_vf=audio_window+vae_scale-1,
387
+ intermediate_dim=intermediate_dim,
388
+ output_dim=output_dim,
389
+ context_tokens=context_tokens,
390
+ norm_output_audio=norm_output_audio,
391
+ )
392
+
393
+ self.use_usp = dist.is_initialized()
394
+ self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
395
+ self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
396
+
397
+ def patchify(self, x: torch.Tensor):
398
+ x = self.patch_embedding(x)
399
+ grid_size = x.shape[2:]
400
+ x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
401
+ return x, grid_size # x, grid_size: (f, h, w)
402
+
403
+ def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
404
+ return rearrange(
405
+ x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
406
+ f=grid_size[0], h=grid_size[1], w=grid_size[2],
407
+ x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
408
+ )
409
+
410
+ def forward(self,
411
+ x: torch.Tensor, #(1, 16, 9, 64, 64))
412
+ timestep: torch.Tensor, #(9,)
413
+ context: torch.Tensor, #(5, 33, 12, 768)
414
+ y: Optional[torch.Tensor] = None, #(1, 16, 9, 64, 64)
415
+ use_gradient_checkpointing: bool = False,
416
+ use_gradient_checkpointing_offload: bool = False,
417
+ **kwargs,
418
+ ):
419
+
420
+ if self.freqs.device != x.device:
421
+ self.freqs = self.freqs.to(x.device)
422
+
423
+ x = torch.cat([x, y], dim=1) # (1, 32, 9, 64, 64)
424
+ x, grid_sizes = self.patchify(x)
425
+ t = self.time_embedding(
426
+ sinusoidal_embedding_1d(self.freq_dim, timestep.to(dtype=x.dtype)))
427
+ t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (bsz, 6, 1536)
428
+
429
+ # ==================== Audio condition processing ====================
430
+ # Input: context (bsz, 81, 5, 12, 768)
431
+ # - 81 frames = 1 (first frame) + 80 (subsequent frames; every 4 frames map to one VAE-compressed latent frame)
432
+ # - 5 is the audio window size (audio_window)
433
+ # - 12 is the number of audio feature blocks
434
+ # - 768 is the audio feature dimension
435
+
436
+ audio_cond = context.to(device=x.device, dtype=x.dtype)
437
+
438
+ # 1. First frame: use the full 5-frame audio window directly
439
+ first_frame_audio = audio_cond[:, :1, ...] # (bsz, 1, 5, 12, 768)
440
+
441
+ # 2. Subsequent frames: pick a different audio window depending on the frame's position
442
+ # Rearrange the 32 frames into (8 VAE latents, 4 frames each)
443
+ latter_frames_audio = rearrange(
444
+ audio_cond[:, 1:, ...],
445
+ "b (n_latent n_frame) w s c -> b n_latent n_frame w s c",
446
+ n_frame=self.vae_scale # vae_scale=4
447
+ ) # (bsz, 8, 4, 5, 12, 768)
448
+
449
+ mid_idx = self.audio_window // 2 # window center index: 5 // 2 = 2
450
+
451
+ # Pick the appropriate audio window for each of the 4 frames in a latent group:
452
+ # - 1st frame (frame index 0): no past, take the leading 3-frame window [:mid_idx+1] = [:3]
453
+ # - middle frames (frame indices 1-2): take the single center frame [mid_idx:mid_idx+1] = [2:3]
454
+ # - 4th frame (frame index 3): no future, take the trailing 3-frame window [mid_idx:] = [2:]
455
+
456
+ first_of_group = latter_frames_audio[:, :, :1, :mid_idx+1, ...] # (bsz, 8, 1, 3, 12, 768)
457
+ middle_of_group = latter_frames_audio[:, :, 1:-1, mid_idx:mid_idx+1, ...] # (bsz, 8, 2, 1, 12, 768)
458
+ last_of_group = latter_frames_audio[:, :, -1:, mid_idx:, ...] # (bsz, 8, 1, 3, 12, 768)
459
+
460
+ # Concatenate and flatten the window dims: (n_frame, window) -> (n_frame * window)
461
+ latter_frames_audio_processed = torch.cat([
462
+ rearrange(first_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
463
+ rearrange(middle_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
464
+ rearrange(last_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
465
+ ], dim=2) # (bsz, 8, 1*3 + 2*1 + 1*3, 12, 768) = (bsz, 8, 8, 12, 768)
466
+
467
+ # 3. Project into the feature space the DiT expects via AudioProjModel
468
+ context = self.audio_proj(
469
+ first_frame_audio,
470
+ latter_frames_audio_processed
471
+ ).to(x.dtype) # (bsz, 9, 32, 1536)
472
+
473
+ if self.use_usp:
474
+ x = torch.chunk(x, self.sp_size, dim=1)[self.sp_rank]
475
+
476
+ for block in self.blocks:
477
+ x = block(x, context, t_mod, self.freqs, grid_sizes)
478
+ x = self.head(x, t) # (bsz, 9*32*32, 64)
479
+ if self.use_usp:
480
+ x = get_sp_group().all_gather(x, dim=1)
481
+ x = self.unpatchify(x, grid_sizes) # (bsz, 16, 21, 64, 64)
482
+ return x
483
+
484
+
485
+ class AudioProjModel(ModelMixin, ConfigMixin):
486
+ def __init__(
487
+ self,
488
+ seq_len=5,
489
+ seq_len_vf=12,
490
+ blocks=12,
491
+ channels=768,
492
+ intermediate_dim=512,
493
+ output_dim=768,
494
+ context_tokens=32,
495
+ norm_output_audio=False,
496
+ ):
497
+ super().__init__()
498
+
499
+ self.seq_len = seq_len
500
+ self.blocks = blocks
501
+ self.channels = channels
502
+ self.input_dim = seq_len * blocks * channels
503
+ self.input_dim_vf = seq_len_vf * blocks * channels
504
+ self.intermediate_dim = intermediate_dim
505
+ self.context_tokens = context_tokens
506
+ self.output_dim = output_dim
507
+
508
+ # define multiple linear layers
509
+ self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
510
+ self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
511
+ self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
512
+ self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
513
+ self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
514
+
515
+ def forward(self, audio_embeds, audio_embeds_vf):
516
+ video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
517
+ B, _, _, S, C = audio_embeds.shape
518
+
519
+ # process audio of first frame
520
+ audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
521
+ batch_size, window_size, blocks, channels = audio_embeds.shape
522
+ audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
523
+
524
+ # process audio of latter frame
525
+ audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
526
+ batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
527
+ audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
528
+
529
+ # first projection
530
+ audio_embeds = torch.relu(self.proj1(audio_embeds))
531
+ audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
532
+ audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
533
+ audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
534
+ audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
535
+ batch_size_c, N_t, C_a = audio_embeds_c.shape
536
+ audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
537
+
538
+ # second projection
539
+ audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
540
+
541
+ context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
542
+
543
+ # normalization and reshape
544
+ with amp.autocast(dtype=torch.float32):
545
+ context_tokens = self.norm(context_tokens)
546
+ context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
547
+
548
+ return context_tokens
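A shape-only sketch of the per-latent audio-window selection done in `WanModelAudioProject.forward` above, using dummy tensors (sizes follow the comments in the code):

```python
import torch
from einops import rearrange

bsz, n_latent, vae_scale, window, blocks, dim = 1, 8, 4, 5, 12, 768
mid = window // 2  # 2

latter = torch.zeros(bsz, n_latent, vae_scale, window, blocks, dim)
first  = latter[:, :, :1, :mid + 1]       # (1, 8, 1, 3, 12, 768): leading window for the 1st frame
middle = latter[:, :, 1:-1, mid:mid + 1]  # (1, 8, 2, 1, 12, 768): center frame for the middle frames
last   = latter[:, :, -1:, mid:]          # (1, 8, 1, 3, 12, 768): trailing window for the 4th frame

merged = torch.cat([
    rearrange(first,  "b n f w s c -> b n (f w) s c"),
    rearrange(middle, "b n f w s c -> b n (f w) s c"),
    rearrange(last,   "b n f w s c -> b n (f w) s c"),
], dim=2)
print(merged.shape)  # torch.Size([1, 8, 8, 12, 768]) -> 3 + 2*1 + 3 = 8 audio slices per latent
```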
flash_head/src/pipeline/flash_head_pipeline.py ADDED
@@ -0,0 +1,316 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+ import os
3
+ from PIL import Image
4
+ from loguru import logger
5
+ import time
6
+ import numpy as np
7
+ import torch
8
+ import torch.distributed as dist
9
+ from einops import rearrange
10
+
11
+ from transformers import Wav2Vec2FeatureExtractor
12
+
13
+ from flash_head.src.modules.flash_head_model import WanModelAudioProject
14
+ from flash_head.audio_analysis.wav2vec2 import Wav2Vec2Model
15
+ from flash_head.utils.utils import match_and_blend_colors_torch, resize_and_centercrop
16
+ from flash_head.utils.facecrop import process_image
17
+
18
+ # compile models to speed up inference
19
+ COMPILE_MODEL = False
20
+ COMPILE_VAE = False
21
+ # use a parallel VAE to speed up decode/encode; only supported for WanVAE
22
+ USE_PARALLEL_VAE = True
23
+
24
+ def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
25
+ def get_image(cond_image_path, use_face_crop):
26
+ if use_face_crop:
27
+ try:
28
+ image = process_image(cond_image_path)
29
+ return image
30
+ except Exception as e:
31
+ logger.error(f"Error processing {cond_image_path}: {e}")
32
+ return Image.open(cond_image_path).convert("RGB")
33
+
34
+ if os.path.isdir(cond_image_path_or_dir):
35
+ import glob
36
+ cond_image_list = glob.glob(os.path.join(cond_image_path_or_dir, "*.png"))
37
+ cond_image_list.sort()
38
+ cond_image_dict = {cond_image.split("/")[-1].split(".")[0]: get_image(cond_image, use_face_crop) for cond_image in cond_image_list}
39
+ else:
40
+ cond_image_dict = {cond_image_path_or_dir.split("/")[-1].split(".")[0]: get_image(cond_image_path_or_dir, use_face_crop)}
41
+ return cond_image_dict
42
+
43
+ def timestep_transform(
44
+ t,
45
+ shift=5.0,
46
+ num_timesteps=1000,
47
+ ):
48
+ t = t / num_timesteps
49
+ # shift the timestep based on ratio
50
+ new_t = shift * t / (1 + (shift - 1) * t)
51
+ new_t = new_t * num_timesteps
52
+ return new_t
53
+
54
+
55
+ class FlashHeadPipeline:
56
+ def __init__(
57
+ self,
58
+ checkpoint_dir,
59
+ model_type,
60
+ wav2vec_dir,
61
+ device="cuda",
62
+ param_dtype=torch.bfloat16,
63
+ use_usp=False,
64
+ num_timesteps=1000,
65
+ use_timestep_transform=True,
66
+ ):
67
+ r"""
68
+ Initializes the image-to-video generation model components.
69
+ Args:
70
+ checkpoint_dir (`str`):
71
+ Path to directory containing model checkpoints
72
+ wav2vec_dir (`str`):
73
+ Path to directory containing wav2vec checkpoints
74
+ use_usp (`bool`, *optional*, defaults to False):
75
+ Enable distribution strategy of USP.
76
+ """
77
+ self.param_dtype = param_dtype
78
+ self.device = device
79
+ self.rank = dist.get_rank() if dist.is_initialized() else 0
80
+ self.use_usp = use_usp and dist.is_initialized()
81
+ self.model_type = model_type
82
+ self.use_ltx = model_type == "lite"
83
+
84
+ if self.use_ltx:
85
+ model_dir = os.path.join(checkpoint_dir, "Model_Lite")
86
+ vae_dir = os.path.join(checkpoint_dir, "VAE_LTX")
87
+
88
+ from flash_head.ltx_video.ltx_vae import LtxVAE
89
+ self.vae = LtxVAE(
90
+ pretrained_model_type_or_path=vae_dir,
91
+ dtype=self.param_dtype,
92
+ device=self.device,
93
+ )
94
+ else:
95
+ vae_path = os.path.join(checkpoint_dir, "VAE_Wan/Wan2.1_VAE.pth")
96
+
97
+ from flash_head.wan.modules import WanVAE
98
+ self.vae = WanVAE(
99
+ vae_path=vae_path,
100
+ dtype=self.param_dtype,
101
+ device=self.device,
102
+ parallel=(USE_PARALLEL_VAE and self.use_usp),
103
+ )
104
+
105
+ if self.model_type == "pretrained":
106
+ self.audio_guide_scale = 3.0
107
+ model_dir = os.path.join(checkpoint_dir, "teacher")
108
+ elif self.model_type == "pro":
109
+ model_dir = os.path.join(checkpoint_dir, "Model_Pro")
110
+
111
+ self.model = WanModelAudioProject.from_pretrained(model_dir)
112
+ self.model.eval().requires_grad_(False)
113
+ self.model.to(device=self.device, dtype=self.param_dtype)
114
+
115
+ self.config = self.model.config
116
+
117
+ if use_usp:
118
+ from xfuser.core.distributed import get_sequence_parallel_world_size
119
+ self.sp_size = get_sequence_parallel_world_size()
120
+ else:
121
+ self.sp_size = 1
122
+
123
+ if dist.is_initialized():
124
+ dist.barrier()
125
+
126
+ self.num_timesteps = num_timesteps
127
+ self.use_timestep_transform = use_timestep_transform
128
+
129
+ if COMPILE_MODEL:
130
+ self.model = torch.compile(self.model)
131
+ if COMPILE_VAE:
132
+ if self.use_ltx:
133
+ self.vae.model.encode = torch.compile(self.vae.model.encode)
134
+ self.vae.model.decode = torch.compile(self.vae.model.decode)
135
+ else:
136
+ self.vae.encode = torch.compile(self.vae.encode)
137
+ self.vae.decode = torch.compile(self.vae.decode)
138
+
139
+ self.audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_dir, local_files_only=True).to(self.device)
140
+ self.audio_encoder.feature_extractor._freeze_parameters()
141
+ self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_dir, local_files_only=True)
142
+
143
+ @torch.no_grad()
144
+ def prepare_params(self,
145
+ cond_image_path_or_dir,
146
+ target_size,
147
+ frame_num,
148
+ motion_frames_num,
149
+ sampling_steps,
150
+ seed=None,
151
+ shift=5.0,
152
+ color_correction_strength=0.0,
153
+ use_face_crop=False,
154
+ ):
155
+ self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, use_face_crop)
156
+
157
+ self.frame_num = frame_num
158
+ self.motion_frames_num = motion_frames_num
159
+ self.color_correction_strength = color_correction_strength
160
+
161
+ self.target_h, self.target_w = target_size
162
+ self.lat_h, self.lat_w = self.target_h // self.config.vae_stride[1], self.target_w // self.config.vae_stride[2]
163
+
164
+ self.generator = torch.Generator(device=self.device).manual_seed(seed)
165
+
166
+ # prepare timesteps
167
+ if sampling_steps == 2:
168
+ timesteps = [1000, 500]
169
+ elif sampling_steps == 4:
170
+ timesteps = [1000, 750, 500, 250]
171
+ else:
172
+ timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32))
173
+
174
+ timesteps.append(0.)
175
+ timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
176
+ if self.use_timestep_transform:
177
+ timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps]
178
+ self.timesteps = timesteps
179
+
180
+ self.cond_image_tensor_dict = {}
181
+ self.ref_img_latent_dict = {}
182
+ for i, (person_name, cond_image_pil) in enumerate(self.cond_image_dict.items()):
183
+ cond_image_tensor = resize_and_centercrop(cond_image_pil, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype) # 1 C 1 H W
184
+ cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
185
+
186
+ self.cond_image_tensor_dict[person_name] = cond_image_tensor
187
+
188
+ video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
189
+ self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames) # (16, 9, 64, 64) / (128, 5, 16, 16)
190
+ if i == 0:
191
+ self.reset_person_name(person_name)
192
+
193
+ return
194
+
195
+ @torch.no_grad()
196
+ def reset_person_name(self, person_name=None):
197
+ if person_name is None or person_name not in self.cond_image_dict:
198
+ pass
199
+ else:
200
+ self.person_name = person_name
201
+ self.original_color_reference = self.cond_image_tensor_dict[self.person_name]
202
+ self.ref_img_latent = self.ref_img_latent_dict[self.person_name]
203
+ self.latent_motion_frames = self.ref_img_latent[:, :1].clone()
204
+
205
+ @torch.no_grad()
206
+ def preprocess_audio(self, speech_array, sr=16000, fps=25):
207
+ video_length = len(speech_array) * fps / sr
208
+
209
+ # wav2vec_feature_extractor
210
+ audio_feature = np.squeeze(
211
+ self.wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
212
+ )
213
+ audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
214
+ audio_feature = audio_feature.unsqueeze(0)
215
+
216
+ # audio encoder
217
+ with torch.no_grad():
218
+ embeddings = self.audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
219
+
220
+ if len(embeddings) == 0:
221
+ logger.error("Fail to extract audio embedding")
222
+ return None
223
+
224
+ audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
225
+ audio_emb = rearrange(audio_emb, "b s d -> s b d")
226
+ return audio_emb
227
+
228
+ @torch.no_grad()
229
+ def generate(self, audio_embedding):
230
+ # evaluation mode
231
+ with torch.no_grad():
232
+
233
+ # sample videos
234
+ noise = torch.randn(
235
+ self.config.out_dim,
236
+ (self.frame_num - 1) // self.config.vae_stride[0] + 1,
237
+ self.lat_h,
238
+ self.lat_w,
239
+ dtype=self.param_dtype,
240
+ device=self.device,
241
+ generator=self.generator)
242
+
243
+ for i in range(len(self.timesteps)-1):
244
+ torch.cuda.synchronize()
245
+ start_time = time.time()
246
+
247
+ noise[:, :self.latent_motion_frames.shape[1]] = self.latent_motion_frames
248
+
249
+ flow_pred = self.model(
250
+ x=noise.unsqueeze(0),
251
+ timestep=self.timesteps[i],
252
+ context=audio_embedding,
253
+ y=self.ref_img_latent.unsqueeze(0),
254
+ )[0]
255
+
256
+ if self.model_type == "pretrained":
257
+ flow_pred_drop_audio = self.model(
258
+ x=noise.unsqueeze(0),
259
+ timestep=self.timesteps[i],
260
+ context=torch.zeros_like(audio_embedding),
261
+ y=self.ref_img_latent.unsqueeze(0),
262
+ )[0]
263
+ flow_pred = flow_pred_drop_audio + self.audio_guide_scale * (flow_pred - flow_pred_drop_audio)
264
+
265
+ # update latent
266
+ dt = self.timesteps[i] - self.timesteps[i + 1]
267
+ dt = (dt / self.num_timesteps).to(self.param_dtype)
268
+ noise = noise - flow_pred * dt[:, None, None, None]
269
+
270
+ else:
271
+ # update latent
272
+ t_i = (self.timesteps[i][:, None, None, None] / self.num_timesteps).to(self.param_dtype)
273
+ t_i_1 = (self.timesteps[i+1][:, None, None, None] / self.num_timesteps).to(self.param_dtype)
274
+ x_0 = noise - flow_pred * t_i
275
+
276
+ noise = (1 - t_i_1) * x_0 + t_i_1 * torch.randn(x_0.size(), dtype=x_0.dtype, device=self.device, generator=self.generator)
277
+
278
+ torch.cuda.synchronize()
279
+ end_time = time.time()
280
+ if self.rank == 0:
281
+ print(f'[generate] model denoise per step: {end_time - start_time}s')
282
+
283
+ noise[:, :self.latent_motion_frames.shape[1]] = self.latent_motion_frames
284
+
285
+ torch.cuda.synchronize()
286
+ start_decode_time = time.time()
287
+
288
+ videos = self.vae.decode(noise)
289
+
290
+ torch.cuda.synchronize()
291
+ end_decode_time = time.time()
292
+ if self.rank == 0:
293
+ print(f'[generate] decode video frames: {end_decode_time - start_decode_time}s')
294
+
295
+ torch.cuda.synchronize()
296
+ start_color_correction_time = time.time()
297
+ if self.color_correction_strength > 0.0:
298
+ videos = match_and_blend_colors_torch(videos, self.original_color_reference, self.color_correction_strength)
299
+
300
+ cond_frame = videos[:, :, -self.motion_frames_num:].to(self.device)
301
+ torch.cuda.synchronize()
302
+ end_color_correction_time = time.time()
303
+ if self.rank == 0:
304
+ print(f'[generate] color correction: {end_color_correction_time - start_color_correction_time}s')
305
+
306
+ torch.cuda.synchronize()
307
+ start_encode_time = time.time()
308
+ self.latent_motion_frames = self.vae.encode(cond_frame)
309
+ torch.cuda.synchronize()
310
+ end_encode_time = time.time()
311
+ if self.rank == 0:
312
+ print(f'[generate] encode motion frames: {end_encode_time - start_encode_time}s')
313
+
314
+ gen_video_samples = videos #[:, :, self.motion_frames_num:]
315
+
316
+ return gen_video_samples[0].to(torch.float32)
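For intuition, `timestep_transform` above applies the SD3-style shift t' = shift * t / (1 + (shift - 1) * t) in the normalized [0, 1] range, which pushes sampled timesteps toward the high-noise end. A quick check with shift=5.0 (values rounded):

```python
import torch

for t in [1000.0, 750.0, 500.0, 250.0]:
    print(t, "->", round(timestep_transform(torch.tensor(t), shift=5.0).item(), 1))
# 1000.0 -> 1000.0, 750.0 -> 937.5, 500.0 -> 833.3, 250.0 -> 625.0
```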
flash_head/utils/cpu_face_handler.py ADDED
@@ -0,0 +1,55 @@
1
+ import mediapipe as mp
2
+ import numpy as np
3
+ from typing import Tuple, List
4
+
5
+
6
+ class CPUFaceHandler:
7
+ """Handler for CPU-based face detection using MediaPipe.
8
+ (2 ms/frame)
9
+ This handler provides a simple interface for face detection using MediaPipe's
10
+ face detection model. It's optimized for CPU usage and provides basic face
11
+ detection functionality.
12
+ """
13
+
14
+ def __init__(self, model_selection: int = 1, min_detection_confidence: float = 0.0):
15
+ """Initialize the face detection handler."""
16
+ self.detector = mp.solutions.face_detection.FaceDetection(
17
+ model_selection=model_selection,
18
+ min_detection_confidence=min_detection_confidence,
19
+ )
20
+
21
+ def detect(self, image: np.ndarray) -> Tuple[List[List[float]], List[float]]:
22
+ """Detect faces in the given image.
23
+
24
+ Args:
25
+ image (np.ndarray): RGB image array.
26
+
27
+ Returns:
28
+ Tuple[List[List[float]], List[float]]: A tuple containing:
29
+ - List of bounding boxes [x1, y1, x2, y2] in coordinates relative to the image size, one per detected face
30
+ - List of detection confidence scores, one per detected face
31
+ (both lists are empty when no face is detected)
32
+ """
33
+ bboxs, scores = [], []
34
+ results = self.detector.process(image)
35
+ detection_result = results.detections
36
+ if detection_result is None:
37
+ return bboxs, scores
38
+ for detection in detection_result:
39
+ bboxC = detection.location_data.relative_bounding_box
40
+ x, y, w, h = bboxC.xmin, bboxC.ymin, bboxC.width, bboxC.height
41
+ x1, y1, x2, y2 = x, y, x + w, y + h
42
+ bboxs.append([x1, y1, x2, y2])
43
+ scores.append(detection.score[0])
44
+ return bboxs, scores
45
+
46
+ def __call__(self, image: np.ndarray) -> Tuple[List[List[float]], List[float]]:
47
+ """Make the handler callable.
48
+
49
+ Args:
50
+ image (np.ndarray): RGB image array.
51
+
52
+ Returns:
53
+ Tuple[List[List[float]], List[float]]: Same as detect() method.
54
+ """
55
+ return self.detect(image)
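A usage sketch for `CPUFaceHandler`; MediaPipe returns boxes relative to the image size, so they are scaled by width and height before use (the input path is hypothetical):

```python
import numpy as np
from PIL import Image

detector = CPUFaceHandler()
rgb = np.array(Image.open("face.png").convert("RGB"))
boxes, scores = detector(rgb)
if boxes:
    h, w = rgb.shape[:2]
    x1, y1, x2, y2 = boxes[0]
    pixel_box = (x1 * w, y1 * h, x2 * w, y2 * h)  # relative -> absolute pixel coordinates
```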
flash_head/utils/facecrop.py ADDED
@@ -0,0 +1,110 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Face cropping script
4
+ Detects the face in a single image, then crops and resizes it to the target size
5
+ """
6
+ import os
7
+ from PIL import Image
8
+ import numpy as np
9
+
10
+ from flash_head.utils.cpu_face_handler import CPUFaceHandler
11
+
12
+ def get_scaled_bbox(
13
+ bbox, img_w, img_h, ratio: float = 1.0, face_image: Image.Image = None
14
+ ):
15
+ """
16
+ Compute the scaled crop region from a face bounding box
17
+
18
+ Args:
19
+ bbox: face bounding box [x1, y1, x2, y2]
20
+ img_w: image width
21
+ img_h: image height
22
+ ratio: scale factor; larger values make the face occupy less of the frame (more margin around it)
23
+ face_image: PIL Image object
24
+
25
+ Returns:
26
+ The cropped face image
27
+ """
28
+ x1, y1, x2, y2 = bbox
29
+
30
+ # Calculate center point
31
+ center_x = (x1 + x2) / 2
32
+ center_y = (y1 + y2) / 2
33
+
34
+ # Calculate width and height
35
+ width = x2 - x1
36
+
37
+ # Scale width and height
38
+ new_width = width * ratio
39
+ new_height = new_width
40
+
41
+ # tile pix
42
+ dis_x_left = new_width * 0.5
43
+ dis_x_right = new_width - dis_x_left # 0.5new_width
44
+ dis_y_up = new_height * 0.55
45
+ dis_y_down = new_height - dis_y_up # 0.45new_height
46
+
47
+ # Calculate new coordinates
48
+ new_x1 = int(max(0, center_x - dis_x_left))
49
+ new_y1 = int(max(0, center_y - dis_y_up))
50
+ new_x2 = int(min(img_w, center_x + dis_x_right))
51
+ new_y2 = int(min(img_h, center_y + dis_y_down))
52
+ scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
53
+ crop_face = face_image.crop(scaled_bbox)
54
+ return crop_face
55
+
56
+
57
+ def process_image(
58
+ input_path,
59
+ face_ratio=2.0,
60
+ target_size=(512, 512),
61
+ ):
62
+ """
63
+ Process a single image: detect the face and crop around it
64
+
65
+ Args:
66
+ input_path: path to the input image
67
+ face_ratio: face scale factor; recommended range 1.5-3.0, default 2.0
68
+ target_size: output image size, default (512, 512)
69
+
70
+ Returns:
71
+ image: the processed (cropped and resized) face image
72
+ """
73
+ # Initialize the face detector
74
+ face_detector = CPUFaceHandler()
75
+
76
+ # Validate the input file
77
+ if not os.path.isfile(input_path):
78
+ raise ValueError(f"File not found: {input_path}")
79
+
80
+ try:
81
+ # Load the image
82
+ image = Image.open(input_path)
83
+ image = image.convert("RGB")
84
+ image_rgb = np.array(image)
85
+ img_h, img_w = image_rgb.shape[:2]
86
+
87
+ # Detect faces
88
+ boxes, scores = face_detector(image_rgb)
89
+
90
+ if len(boxes) == 0:
91
+ raise ValueError("No face detected")
92
+
93
+ # Convert bounding-box coordinates from relative to absolute
94
+ boxes_abs = [
95
+ boxes[0][0] * img_w,
96
+ boxes[0][1] * img_h,
97
+ boxes[0][2] * img_w,
98
+ boxes[0][3] * img_h
99
+ ]
100
+
101
+ # Crop the face
102
+ crop_face = get_scaled_bbox(boxes_abs, img_w, img_h, face_ratio, image)
103
+
104
+ # Resize to the target size
105
+ crop_face = crop_face.resize(target_size)
106
+
107
+ return crop_face
108
+
109
+ except Exception as e:
110
+ raise ValueError(f"Error processing {input_path}: {e}")
flash_head/utils/utils.py ADDED
@@ -0,0 +1,222 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import math
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+ import torch.nn as nn
8
+ import pyloudnorm as pyln
9
+
10
+ def rgb_to_lab_torch(rgb: torch.Tensor) -> torch.Tensor:
11
+ """
12
+ PyTorch GPU version: convert RGB to the Lab color space (input range [0, 1], arbitrary tensor shape, channels in the last dim).
13
+ Follows the standard CIE conversion formulas.
14
+ """
15
+ # Convert to linear RGB (inverse of the sRGB gamma correction)
16
+ linear_rgb = torch.where(
17
+ rgb > 0.04045,
18
+ ((rgb + 0.055) / 1.055) ** 2.4,
19
+ rgb / 12.92
20
+ )
21
+
22
+ # Linear RGB to XYZ (sRGB matrix, D65 white point)
23
+ xyz_from_rgb = torch.tensor([
24
+ [0.4124564, 0.3575761, 0.1804375],
25
+ [0.2126729, 0.7151522, 0.0721750],
26
+ [0.0193339, 0.1191920, 0.9503041]
27
+ ], dtype=rgb.dtype, device=rgb.device)
28
+
29
+ # Shape handling: flatten to (N, 3) for the matrix multiply, then restore the spatial dims
30
+ shape = linear_rgb.shape
31
+ linear_rgb_flat = linear_rgb.reshape(-1, 3) # (N, 3), N = B*T*H*W
32
+ xyz_flat = linear_rgb_flat @ xyz_from_rgb.T # (N, 3)
33
+ xyz = xyz_flat.reshape(shape) # restore the original shape
34
+
35
+ # XYZ to Lab (D65 reference white)
36
+ xyz_ref = torch.tensor([0.95047, 1.0, 1.08883], dtype=rgb.dtype, device=rgb.device)
37
+ xyz_normalized = xyz / xyz_ref[None, None, None, None, :] # broadcast over (B, T, H, W, C)
38
+
39
+ # Apply the Lab transfer function
40
+ epsilon = 0.008856
41
+ kappa = 903.3
42
+ xyz_normalized = torch.clamp(xyz_normalized, 1e-8, 1.0) # avoid numerical issues at zero
43
+
44
+ f_xyz = torch.where(
45
+ xyz_normalized > epsilon,
46
+ xyz_normalized ** (1/3),
47
+ (kappa * xyz_normalized + 16) / 116
48
+ )
49
+
50
+ L = 116 * f_xyz[..., 1] - 16 # Y channel -> lightness
51
+ a = 500 * (f_xyz[..., 0] - f_xyz[..., 1]) # X - Y -> red/green axis
52
+ b = 200 * (f_xyz[..., 1] - f_xyz[..., 2]) # Y - Z -> blue/yellow axis
53
+
54
+ lab = torch.stack([L, a, b], dim=-1) # stack the Lab channels in the last dim
55
+ return lab
56
+
57
+ def lab_to_rgb_torch(lab: torch.Tensor) -> torch.Tensor:
58
+ """
59
+ PyTorch GPU version: convert Lab back to RGB (output range [0, 1], arbitrary tensor shape, channels in the last dim).
60
+ """
61
+ # Split the Lab channels
62
+ L = lab[..., 0]
63
+ a = lab[..., 1]
64
+ b = lab[..., 2]
65
+
66
+ # Lab to XYZ
67
+ f_y = (L + 16) / 116
68
+ f_x = (a / 500) + f_y
69
+ f_z = f_y - (b / 200)
70
+
71
+ epsilon = 0.008856
72
+ kappa = 903.3
73
+
74
+ x = torch.where(f_x ** 3 > epsilon, f_x ** 3, (116 * f_x - 16) / kappa)
75
+ y = torch.where(L > kappa * epsilon, ((L + 16) / 116) ** 3, L / kappa)
76
+ z = torch.where(f_z ** 3 > epsilon, f_z ** 3, (116 * f_z - 16) / kappa)
77
+
78
+ # Multiply by the D65 white point
79
+ xyz_ref = torch.tensor([0.95047, 1.0, 1.08883], dtype=lab.dtype, device=lab.device)
80
+ xyz = torch.stack([x, y, z], dim=-1) * xyz_ref[None, None, None, None, :]
81
+
82
+ # XYZ to linear RGB
83
+ rgb_from_xyz = torch.tensor([
84
+ [3.2404542, -1.5371385, -0.4985314],
85
+ [-0.9692660, 1.8760108, 0.0415560],
86
+ [0.0556434, -0.2040259, 1.0572252]
87
+ ], dtype=lab.dtype, device=lab.device)
88
+
89
+ # Shape handling: flatten for the matrix multiply
90
+ shape = xyz.shape
91
+ xyz_flat = xyz.reshape(-1, 3) # (N, 3)
92
+ linear_rgb_flat = xyz_flat @ rgb_from_xyz.T # (N, 3)
93
+ linear_rgb = linear_rgb_flat.reshape(shape) # restore the original shape
94
+
95
+ # Linear RGB to sRGB (gamma correction)
96
+ rgb = torch.where(
97
+ linear_rgb > 0.0031308,
98
+ 1.055 * (linear_rgb ** (1/2.4)) - 0.055,
99
+ 12.92 * linear_rgb
100
+ )
101
+
102
+ # Clamp the output to [0, 1]
103
+ rgb = torch.clamp(rgb, 0.0, 1.0)
104
+ return rgb
105
+
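Since rgb_to_lab_torch and lab_to_rgb_torch above are intended to be inverses of each other (up to clamping), a small round-trip check is a useful sanity test; a minimal sketch, assuming both functions are in scope (e.g. imported from flash_head.utils.utils):

import torch

rgb = torch.rand(1, 4, 8, 8, 3)  # (B, T, H, W, C), values in [0, 1]
restored = lab_to_rgb_torch(rgb_to_lab_torch(rgb))
print((rgb - restored).abs().max())  # expected to be small (floating-point error only)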
106
+ def match_and_blend_colors_torch(
107
+ source_chunk: torch.Tensor,
108
+ reference_image: torch.Tensor,
109
+ strength: float
110
+ ) -> torch.Tensor:
111
+ """
112
+ Fully batched GPU implementation: match the colors of a video chunk to a reference image and blend them (supports B > 1 and all T frames in parallel).
113
+
114
+ Args:
115
+ source_chunk (torch.Tensor): video chunk (B, C, T, H, W), range [-1, 1]
116
+ reference_image (torch.Tensor): reference image (B, C, 1, H, W), range [-1, 1] (B must match source_chunk)
117
+ strength (float): color-correction strength (0.0-1.0); 0.0 = no correction, 1.0 = full correction
118
+
119
+ Returns:
120
+ torch.Tensor: color-corrected video chunk (B, C, T, H, W), range [-1, 1]
121
+ """
122
+ # Strength 0: return the input unchanged
123
+ if strength <= 0.0:
124
+ return source_chunk.clone()
125
+
126
+ # Validate the strength range
127
+ if not 0.0 <= strength <= 1.0:
128
+ raise ValueError(f"strength must be between 0.0 and 1.0, got {strength}")
129
+
130
+ # Validate the input shapes (batch sizes must match; the reference image has T = 1)
131
+ B, C, T, H, W = source_chunk.shape
132
+ assert reference_image.shape == (B, C, 1, H, W), \
133
+ f"reference_image must have shape (B, C, 1, H, W), got {reference_image.shape}"
134
+ assert C == 3, f"only 3-channel RGB input is supported, got {C} channels"
135
+
136
+ # Keep device and dtype consistent
137
+ device = source_chunk.device
138
+ dtype = source_chunk.dtype
139
+ reference_image = reference_image.to(device=device, dtype=dtype)
140
+
141
+ # 1. Map from [-1, 1] to [0, 1] (directly on the GPU)
142
+ source_01 = (source_chunk + 1.0) / 2.0
143
+ ref_01 = (reference_image + 1.0) / 2.0
144
+
145
+ # 2. Reorder dims: (B, C, T, H, W) -> (B, T, H, W, C) for the color-space conversion
146
+ # Reference image: (B, C, 1, H, W) -> (B, 1, H, W, C)
147
+ source_permuted = source_01.permute(0, 2, 3, 4, 1) # move channels to the last dim
148
+ ref_permuted = ref_01.permute(0, 2, 3, 4, 1)
149
+
150
+ # 3. RGB to Lab (all frames in one batch)
151
+ source_lab = rgb_to_lab_torch(source_permuted)
152
+ ref_lab = rgb_to_lab_torch(ref_permuted) # (B, 1, H, W, 3)
153
+
154
+ # 4. Batched color transfer: match the mean and std of the L/a/b channels (core step)
155
+ # Per-channel mean and std of the reference (statistics over H and W, keeping the B dim)
156
+ ref_mean = ref_lab.mean(dim=[2, 3], keepdim=True) # (B, 1, 1, 1, 3)
157
+ ref_std = ref_lab.std(dim=[2, 3], keepdim=True, unbiased=False) # (B, 1, 1, 1, 3)
158
+
159
+ # Per-channel mean and std of the source video (statistics over H and W, keeping the B and T dims)
160
+ source_mean = source_lab.mean(dim=[2, 3], keepdim=True) # (B, T, 1, 1, 3)
161
+ source_std = source_lab.std(dim=[2, 3], keepdim=True, unbiased=False) # (B, T, 1, 1, 3)
162
+
163
+ # Avoid division by a zero std (replace zeros with 1.0)
164
+ source_std_safe = torch.where(source_std < 1e-8, torch.ones_like(source_std), source_std)
165
+
166
+ # Color-transfer formula: (source - source_mean) * (ref_std / source_std) + ref_mean
167
+ corrected_lab = (source_lab - source_mean) * (ref_std / source_std_safe) + ref_mean
168
+
169
+ # 5. Lab back to RGB (all corrected frames in one batch)
170
+ corrected_rgb_01 = lab_to_rgb_torch(corrected_lab)
171
+
172
+ # 6. Blend the original and corrected frames, weighted by strength
173
+ blended_rgb_01 = (1 - strength) * source_permuted + strength * corrected_rgb_01
174
+
175
+ # 7. Restore the dim order and value range: (B, T, H, W, C) -> (B, C, T, H, W), [0, 1] -> [-1, 1]
176
+ blended_rgb_01 = blended_rgb_01.permute(0, 4, 1, 2, 3) # move channels back to dim 1
177
+ blended_rgb_minus1_1 = (blended_rgb_01 * 2.0) - 1.0
178
+
179
+ # 8. Ensure the output is contiguous in memory
180
+ output = blended_rgb_minus1_1.contiguous().to(device=device, dtype=dtype)
181
+
182
+ return output
183
+
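A minimal usage sketch of match_and_blend_colors_torch, with random tensors standing in for a real video chunk and reference frame (assuming the function is in scope, e.g. imported from flash_head.utils.utils):

import torch

chunk = torch.rand(1, 3, 16, 64, 64) * 2 - 1      # (B, C, T, H, W) in [-1, 1]
reference = torch.rand(1, 3, 1, 64, 64) * 2 - 1   # (B, C, 1, H, W) in [-1, 1]
corrected = match_and_blend_colors_torch(chunk, reference, strength=0.5)
print(corrected.shape)  # torch.Size([1, 3, 16, 64, 64]), still in [-1, 1]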
184
+ def resize_and_centercrop(cond_image, target_size):
185
+ """
186
+ Resize image or tensor to the target size without padding.
187
+ """
188
+
189
+ # Get the original size
190
+ if isinstance(cond_image, torch.Tensor):
191
+ _, orig_h, orig_w = cond_image.shape
192
+ else:
193
+ orig_h, orig_w = cond_image.height, cond_image.width
194
+
195
+ target_h, target_w = target_size
196
+
197
+ # Calculate the scaling factor for resizing
198
+ scale_h = target_h / orig_h
199
+ scale_w = target_w / orig_w
200
+
201
+ # Compute the final size
202
+ scale = max(scale_h, scale_w)
203
+ final_h = math.ceil(scale * orig_h)
204
+ final_w = math.ceil(scale * orig_w)
205
+
206
+ # Resize
207
+ if isinstance(cond_image, torch.Tensor):
208
+ if len(cond_image.shape) == 3:
209
+ cond_image = cond_image[None]
210
+ resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
211
+ # crop
212
+ cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
213
+ cropped_tensor = cropped_tensor.squeeze(0)
214
+ else:
215
+ resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
216
+ resized_image = np.array(resized_image)
217
+ # tensor and crop
218
+ resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
219
+ cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
220
+ cropped_tensor = cropped_tensor[:, :, None, :, :]
221
+
222
+ return cropped_tensor
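A short sketch of resize_and_centercrop with a PIL input; note that for PIL images the function returns a uint8 tensor with an extra singleton time dim (the 640x360 input size is hypothetical):

from PIL import Image

img = Image.new("RGB", (640, 360))                       # hypothetical 640x360 input
out = resize_and_centercrop(img, target_size=(512, 512))
print(out.shape)  # torch.Size([1, 3, 1, 512, 512])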
flash_head/wan/modules/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .vae import WanVAE
2
+
3
+ __all__ = [
4
+ 'WanVAE',
5
+ ]
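With this export in place, the VAE defined in vae.py below can be constructed from the package. A minimal sketch, assuming a Wan VAE checkpoint is available at the default path and a video tensor in [-1, 1] whose H and W are divisible by 8 (the 17x480x832 shape is an assumption consistent with the latent sizes used below):

import torch
from flash_head.wan.modules import WanVAE

vae = WanVAE(z_dim=16, vae_path="cache/vae_step_411000.pth", dtype=torch.float, device="cuda")
video = torch.rand(1, 3, 17, 480, 832, device="cuda") * 2 - 1  # [1, C, T, H, W], assumed in [-1, 1]
latents = vae.encode(video)  # latent with H and W downsampled 8x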
flash_head/wan/modules/vae.py ADDED
@@ -0,0 +1,1598 @@
1
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange
8
+ from loguru import logger
9
+
10
+ __all__ = [
11
+ "WanVAE",
12
+ ]
13
+
14
+ CACHE_T = 2
15
+
16
+
17
+ class CausalConv3d(nn.Conv3d):
18
+ """
19
+ Causal 3D convolution.
20
+ """
21
+
22
+ def __init__(self, *args, **kwargs):
23
+ super().__init__(*args, **kwargs)
24
+ self._padding = (
25
+ self.padding[2],
26
+ self.padding[2],
27
+ self.padding[1],
28
+ self.padding[1],
29
+ 2 * self.padding[0],
30
+ 0,
31
+ )
32
+ self.padding = (0, 0, 0)
33
+
34
+ def forward(self, x, cache_x=None):
35
+ padding = list(self._padding)
36
+ if cache_x is not None and self._padding[4] > 0:
37
+ cache_x = cache_x.to(x.device)
38
+ x = torch.cat([cache_x, x], dim=2)
39
+ padding[4] -= cache_x.shape[2]
40
+ x = F.pad(x, padding)
41
+
42
+ return super().forward(x)
43
+
44
+
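The asymmetric temporal padding above (2 * padding in front, none behind) is what makes the convolution causal: each output frame depends only on the current and earlier input frames, and the time length is preserved. A quick shape-check sketch (the sizes are arbitrary):

import torch

conv = CausalConv3d(4, 4, kernel_size=3, padding=1)
x = torch.randn(1, 4, 9, 16, 16)   # (B, C, T, H, W)
print(conv(x).shape)               # torch.Size([1, 4, 9, 16, 16]) -- T unchanged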
45
+ class RMS_norm(nn.Module):
46
+ def __init__(self, dim, channel_first=True, images=True, bias=False):
47
+ super().__init__()
48
+ broadcastable_dims = (1, 1, 1) if not images else (1, 1)
49
+ shape = (dim, *broadcastable_dims) if channel_first else (dim,)
50
+
51
+ self.channel_first = channel_first
52
+ self.scale = dim**0.5
53
+ self.gamma = nn.Parameter(torch.ones(shape))
54
+ self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
55
+
56
+ def forward(self, x):
57
+ return (
58
+ F.normalize(x, dim=(1 if self.channel_first else -1))
59
+ * self.scale
60
+ * self.gamma
61
+ + self.bias
62
+ )
63
+
64
+
65
+ class Upsample(nn.Upsample):
66
+ def forward(self, x):
67
+ """
68
+ Fix bfloat16 support for nearest neighbor interpolation.
69
+ """
70
+ return super().forward(x)
71
+
72
+
73
+ class Resample(nn.Module):
74
+ def __init__(self, dim, mode):
75
+ assert mode in (
76
+ "none",
77
+ "upsample2d",
78
+ "upsample3d",
79
+ "downsample2d",
80
+ "downsample3d",
81
+ )
82
+ super().__init__()
83
+ self.dim = dim
84
+ self.mode = mode
85
+
86
+ # layers
87
+ if mode == "upsample2d":
88
+ self.resample = nn.Sequential(
89
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
90
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
91
+ )
92
+ elif mode == "upsample3d":
93
+ self.resample = nn.Sequential(
94
+ Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
95
+ nn.Conv2d(dim, dim // 2, 3, padding=1),
96
+ )
97
+ self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
98
+
99
+ elif mode == "downsample2d":
100
+ self.resample = nn.Sequential(
101
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
102
+ )
103
+ elif mode == "downsample3d":
104
+ self.resample = nn.Sequential(
105
+ nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
106
+ )
107
+ self.time_conv = CausalConv3d(
108
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
109
+ )
110
+
111
+ else:
112
+ self.resample = nn.Identity()
113
+
114
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
115
+ b, c, t, h, w = x.size()
116
+ if self.mode == "upsample3d":
117
+ if feat_cache is not None:
118
+ idx = feat_idx[0]
119
+ if feat_cache[idx] is None:
120
+ feat_cache[idx] = "Rep"
121
+ feat_idx[0] += 1
122
+ else:
123
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
124
+ if (
125
+ cache_x.shape[2] < 2
126
+ and feat_cache[idx] is not None
127
+ and feat_cache[idx] != "Rep"
128
+ ):
129
+ # cache last frame of last two chunk
130
+ cache_x = torch.cat(
131
+ [
132
+ feat_cache[idx][:, :, -1, :, :]
133
+ .unsqueeze(2)
134
+ .to(cache_x.device),
135
+ cache_x,
136
+ ],
137
+ dim=2,
138
+ )
139
+ if (
140
+ cache_x.shape[2] < 2
141
+ and feat_cache[idx] is not None
142
+ and feat_cache[idx] == "Rep"
143
+ ):
144
+ cache_x = torch.cat(
145
+ [torch.zeros_like(cache_x).to(cache_x.device), cache_x],
146
+ dim=2,
147
+ )
148
+ if feat_cache[idx] == "Rep":
149
+ x = self.time_conv(x)
150
+ else:
151
+ x = self.time_conv(x, feat_cache[idx])
152
+ feat_cache[idx] = cache_x
153
+ feat_idx[0] += 1
154
+
155
+ x = x.reshape(b, 2, c, t, h, w)
156
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
157
+ x = x.reshape(b, c, t * 2, h, w)
158
+ t = x.shape[2]
159
+ x = rearrange(x, "b c t h w -> (b t) c h w")
160
+ x = self.resample(x)
161
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
162
+
163
+ if self.mode == "downsample3d":
164
+ if feat_cache is not None:
165
+ idx = feat_idx[0]
166
+ if feat_cache[idx] is None:
167
+ feat_cache[idx] = x.clone()
168
+ feat_idx[0] += 1
169
+ else:
170
+ cache_x = x[:, :, -1:, :, :].clone()
171
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
172
+ # # cache last frame of last two chunk
173
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
174
+
175
+ x = self.time_conv(
176
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
177
+ )
178
+ feat_cache[idx] = cache_x
179
+ feat_idx[0] += 1
180
+ return x
181
+
182
+ def init_weight(self, conv):
183
+ conv_weight = conv.weight
184
+ nn.init.zeros_(conv_weight)
185
+ c1, c2, t, h, w = conv_weight.size()
186
+ one_matrix = torch.eye(c1, c2)
187
+ init_matrix = one_matrix
188
+ nn.init.zeros_(conv_weight)
189
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
190
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
191
+ conv.weight.data.copy_(conv_weight)
192
+ nn.init.zeros_(conv.bias.data)
193
+
194
+ def init_weight2(self, conv):
195
+ conv_weight = conv.weight.data
196
+ nn.init.zeros_(conv_weight)
197
+ c1, c2, t, h, w = conv_weight.size()
198
+ init_matrix = torch.eye(c1 // 2, c2)
199
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
200
+ conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
201
+ conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
202
+ conv.weight.data.copy_(conv_weight)
203
+ nn.init.zeros_(conv.bias.data)
204
+
205
+
206
+ class ResidualBlock(nn.Module):
207
+ def __init__(self, in_dim, out_dim, dropout=0.0):
208
+ super().__init__()
209
+ self.in_dim = in_dim
210
+ self.out_dim = out_dim
211
+
212
+ # layers
213
+ self.residual = nn.Sequential(
214
+ RMS_norm(in_dim, images=False),
215
+ nn.SiLU(),
216
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
217
+ RMS_norm(out_dim, images=False),
218
+ nn.SiLU(),
219
+ nn.Dropout(dropout),
220
+ CausalConv3d(out_dim, out_dim, 3, padding=1),
221
+ )
222
+ self.shortcut = (
223
+ CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
224
+ )
225
+
226
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
227
+ h = self.shortcut(x)
228
+ for layer in self.residual:
229
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
230
+ idx = feat_idx[0]
231
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
232
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
233
+ # cache last frame of last two chunk
234
+ cache_x = torch.cat(
235
+ [
236
+ feat_cache[idx][:, :, -1, :, :]
237
+ .unsqueeze(2)
238
+ .to(cache_x.device),
239
+ cache_x,
240
+ ],
241
+ dim=2,
242
+ )
243
+ x = layer(x, feat_cache[idx])
244
+ feat_cache[idx] = cache_x
245
+ feat_idx[0] += 1
246
+ else:
247
+ x = layer(x)
248
+ return x + h
249
+
250
+
251
+ class AttentionBlock(nn.Module):
252
+ """
253
+ Causal self-attention with a single head.
254
+ """
255
+
256
+ def __init__(self, dim):
257
+ super().__init__()
258
+ self.dim = dim
259
+
260
+ # layers
261
+ self.norm = RMS_norm(dim)
262
+ self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
263
+ self.proj = nn.Conv2d(dim, dim, 1)
264
+
265
+ # zero out the last layer params
266
+ nn.init.zeros_(self.proj.weight)
267
+
268
+ def forward(self, x):
269
+ identity = x
270
+ b, c, t, h, w = x.size()
271
+ x = rearrange(x, "b c t h w -> (b t) c h w")
272
+ x = self.norm(x)
273
+ # compute query, key, value
274
+ q, k, v = (
275
+ self.to_qkv(x)
276
+ .reshape(b * t, 1, c * 3, -1)
277
+ .permute(0, 1, 3, 2)
278
+ .contiguous()
279
+ .chunk(3, dim=-1)
280
+ )
281
+
282
+ # apply attention
283
+ x = F.scaled_dot_product_attention(
284
+ q,
285
+ k,
286
+ v,
287
+ )
288
+ x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
289
+
290
+ # output
291
+ x = self.proj(x)
292
+ x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
293
+ return x + identity
294
+
295
+
296
+ class Encoder3d(nn.Module):
297
+ def __init__(
298
+ self,
299
+ dim=128,
300
+ z_dim=4,
301
+ dim_mult=[1, 2, 4, 4],
302
+ num_res_blocks=2,
303
+ attn_scales=[],
304
+ temperal_downsample=[True, True, False],
305
+ dropout=0.0,
306
+ ):
307
+ super().__init__()
308
+ self.dim = dim
309
+ self.z_dim = z_dim
310
+ self.dim_mult = dim_mult
311
+ self.num_res_blocks = num_res_blocks
312
+ self.attn_scales = attn_scales
313
+ self.temperal_downsample = temperal_downsample
314
+
315
+ # dimensions
316
+ dims = [dim * u for u in [1] + dim_mult]
317
+ scale = 1.0
318
+
319
+ # init block
320
+ self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
321
+
322
+ # downsample blocks
323
+ downsamples = []
324
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
325
+ # residual (+attention) blocks
326
+ for _ in range(num_res_blocks):
327
+ downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
328
+ if scale in attn_scales:
329
+ downsamples.append(AttentionBlock(out_dim))
330
+ in_dim = out_dim
331
+
332
+ # downsample block
333
+ if i != len(dim_mult) - 1:
334
+ mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
335
+ downsamples.append(Resample(out_dim, mode=mode))
336
+ scale /= 2.0
337
+ self.downsamples = nn.Sequential(*downsamples)
338
+
339
+ # middle blocks
340
+ self.middle = nn.Sequential(
341
+ ResidualBlock(out_dim, out_dim, dropout),
342
+ AttentionBlock(out_dim),
343
+ ResidualBlock(out_dim, out_dim, dropout),
344
+ )
345
+
346
+ # output blocks
347
+ self.head = nn.Sequential(
348
+ RMS_norm(out_dim, images=False),
349
+ nn.SiLU(),
350
+ CausalConv3d(out_dim, z_dim, 3, padding=1),
351
+ )
352
+
353
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
354
+ if feat_cache is not None:
355
+ idx = feat_idx[0]
356
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
357
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
358
+ # cache last frame of last two chunk
359
+ cache_x = torch.cat(
360
+ [
361
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
362
+ cache_x,
363
+ ],
364
+ dim=2,
365
+ )
366
+ x = self.conv1(x, feat_cache[idx])
367
+ feat_cache[idx] = cache_x
368
+ feat_idx[0] += 1
369
+ else:
370
+ x = self.conv1(x)
371
+
372
+ ## downsamples
373
+ for layer in self.downsamples:
374
+ if feat_cache is not None:
375
+ x = layer(x, feat_cache, feat_idx)
376
+ else:
377
+ x = layer(x)
378
+
379
+ ## middle
380
+ for layer in self.middle:
381
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
382
+ x = layer(x, feat_cache, feat_idx)
383
+ else:
384
+ x = layer(x)
385
+
386
+ ## head
387
+ for layer in self.head:
388
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
389
+ idx = feat_idx[0]
390
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
391
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
392
+ # cache last frame of last two chunk
393
+ cache_x = torch.cat(
394
+ [
395
+ feat_cache[idx][:, :, -1, :, :]
396
+ .unsqueeze(2)
397
+ .to(cache_x.device),
398
+ cache_x,
399
+ ],
400
+ dim=2,
401
+ )
402
+ x = layer(x, feat_cache[idx])
403
+ feat_cache[idx] = cache_x
404
+ feat_idx[0] += 1
405
+ else:
406
+ x = layer(x)
407
+ return x
408
+
409
+
410
+ class Decoder3d(nn.Module):
411
+ def __init__(
412
+ self,
413
+ dim=128,
414
+ z_dim=4,
415
+ dim_mult=[1, 2, 4, 4],
416
+ num_res_blocks=2,
417
+ attn_scales=[],
418
+ temperal_upsample=[False, True, True],
419
+ dropout=0.0,
420
+ ):
421
+ super().__init__()
422
+ self.dim = dim
423
+ self.z_dim = z_dim
424
+ self.dim_mult = dim_mult
425
+ self.num_res_blocks = num_res_blocks
426
+ self.attn_scales = attn_scales
427
+ self.temperal_upsample = temperal_upsample
428
+
429
+ # dimensions
430
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
431
+
432
+ scale = 1.0 / 2 ** (len(dim_mult) - 2)
433
+
434
+ # init block
435
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
436
+
437
+ # middle blocks
438
+ self.middle = nn.Sequential(
439
+ ResidualBlock(dims[0], dims[0], dropout),
440
+ AttentionBlock(dims[0]),
441
+ ResidualBlock(dims[0], dims[0], dropout),
442
+ )
443
+
444
+ # upsample blocks
445
+ upsamples = []
446
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
447
+ # residual (+attention) blocks
448
+ if i == 1 or i == 2 or i == 3:
449
+ in_dim = in_dim // 2
450
+ for _ in range(num_res_blocks + 1):
451
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
452
+ if scale in attn_scales:
453
+ upsamples.append(AttentionBlock(out_dim))
454
+ in_dim = out_dim
455
+
456
+ # upsample block
457
+ if i != len(dim_mult) - 1:
458
+ mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
459
+ upsamples.append(Resample(out_dim, mode=mode))
460
+ scale *= 2.0
461
+ self.upsamples = nn.Sequential(*upsamples)
462
+
463
+ # output blocks
464
+ self.head = nn.Sequential(
465
+ RMS_norm(out_dim, images=False),
466
+ nn.SiLU(),
467
+ CausalConv3d(out_dim, 3, 3, padding=1),
468
+ )
469
+
470
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
471
+ ## conv1
472
+ if feat_cache is not None:
473
+ idx = feat_idx[0]
474
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
475
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
476
+ # cache last frame of last two chunk
477
+ cache_x = torch.cat(
478
+ [
479
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
480
+ cache_x,
481
+ ],
482
+ dim=2,
483
+ )
484
+ x = self.conv1(x, feat_cache[idx])
485
+ feat_cache[idx] = cache_x
486
+ feat_idx[0] += 1
487
+ else:
488
+ x = self.conv1(x)
489
+
490
+ ## middle
491
+ for layer in self.middle:
492
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
493
+ x = layer(x, feat_cache, feat_idx)
494
+ else:
495
+ x = layer(x)
496
+
497
+ ## upsamples
498
+ for layer in self.upsamples:
499
+ if feat_cache is not None:
500
+ x = layer(x, feat_cache, feat_idx)
501
+ else:
502
+ x = layer(x)
503
+
504
+ ## head
505
+ for layer in self.head:
506
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
507
+ idx = feat_idx[0]
508
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
509
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
510
+ # cache last frame of last two chunk
511
+ cache_x = torch.cat(
512
+ [
513
+ feat_cache[idx][:, :, -1, :, :]
514
+ .unsqueeze(2)
515
+ .to(cache_x.device),
516
+ cache_x,
517
+ ],
518
+ dim=2,
519
+ )
520
+ x = layer(x, feat_cache[idx])
521
+ feat_cache[idx] = cache_x
522
+ feat_idx[0] += 1
523
+ else:
524
+ x = layer(x)
525
+ return x
526
+
527
+
528
+ def count_conv3d(model):
529
+ count = 0
530
+ for m in model.modules():
531
+ if isinstance(m, CausalConv3d):
532
+ count += 1
533
+ return count
534
+
535
+
536
+ class WanVAE_(nn.Module):
537
+ def __init__(
538
+ self,
539
+ dim=128,
540
+ z_dim=4,
541
+ dim_mult=[1, 2, 4, 4],
542
+ num_res_blocks=2,
543
+ attn_scales=[],
544
+ temperal_downsample=[True, True, False],
545
+ dropout=0.0,
546
+ ):
547
+ super().__init__()
548
+ self.dim = dim
549
+ self.z_dim = z_dim
550
+ self.dim_mult = dim_mult
551
+ self.num_res_blocks = num_res_blocks
552
+ self.attn_scales = attn_scales
553
+ self.temperal_downsample = temperal_downsample
554
+ self.temperal_upsample = temperal_downsample[::-1]
555
+ self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
556
+
557
+ # The minimal tile height and width for spatial tiling to be used
558
+ self.tile_sample_min_height = 256
559
+ self.tile_sample_min_width = 256
560
+
561
+ # The minimal distance between two spatial tiles
562
+ self.tile_sample_stride_height = 192
563
+ self.tile_sample_stride_width = 192
564
+ # modules
565
+ self.encoder = Encoder3d(
566
+ dim,
567
+ z_dim * 2,
568
+ dim_mult,
569
+ num_res_blocks,
570
+ attn_scales,
571
+ self.temperal_downsample,
572
+ dropout,
573
+ )
574
+ self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
575
+ self.conv2 = CausalConv3d(z_dim, z_dim, 1)
576
+ self.decoder = Decoder3d(
577
+ dim,
578
+ z_dim,
579
+ dim_mult,
580
+ num_res_blocks,
581
+ attn_scales,
582
+ self.temperal_upsample,
583
+ dropout,
584
+ )
585
+
586
+ def forward(self, x):
587
+ mu, log_var = self.encode(x)
588
+ z = self.reparameterize(mu, log_var)
589
+ x_recon = self.decode(z)
590
+ return x_recon, mu, log_var
591
+
592
+ def blend_v(self, a, b, blend_extent):
593
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
594
+ for y in range(blend_extent):
595
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
596
+ 1 - y / blend_extent
597
+ ) + b[:, :, :, y, :] * (y / blend_extent)
598
+ return b
599
+
600
+ def blend_h(self, a, b, blend_extent):
601
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
602
+ for x in range(blend_extent):
603
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
604
+ 1 - x / blend_extent
605
+ ) + b[:, :, :, :, x] * (x / blend_extent)
606
+ return b
607
+
608
+ def tiled_encode(self, x, scale):
609
+ _, _, num_frames, height, width = x.shape
610
+ latent_height = height // self.spatial_compression_ratio
611
+ latent_width = width // self.spatial_compression_ratio
612
+
613
+ tile_latent_min_height = (
614
+ self.tile_sample_min_height // self.spatial_compression_ratio
615
+ )
616
+ tile_latent_min_width = (
617
+ self.tile_sample_min_width // self.spatial_compression_ratio
618
+ )
619
+ tile_latent_stride_height = (
620
+ self.tile_sample_stride_height // self.spatial_compression_ratio
621
+ )
622
+ tile_latent_stride_width = (
623
+ self.tile_sample_stride_width // self.spatial_compression_ratio
624
+ )
625
+
626
+ blend_height = tile_latent_min_height - tile_latent_stride_height
627
+ blend_width = tile_latent_min_width - tile_latent_stride_width
628
+
629
+ # Split x into overlapping tiles and encode them separately.
630
+ # The tiles have an overlap to avoid seams between tiles.
631
+ rows = []
632
+ for i in range(0, height, self.tile_sample_stride_height):
633
+ row = []
634
+ for j in range(0, width, self.tile_sample_stride_width):
635
+ self.clear_cache()
636
+ time = []
637
+ frame_range = 1 + (num_frames - 1) // 4
638
+ for k in range(frame_range):
639
+ self._enc_conv_idx = [0]
640
+ if k == 0:
641
+ tile = x[
642
+ :,
643
+ :,
644
+ :1,
645
+ i : i + self.tile_sample_min_height,
646
+ j : j + self.tile_sample_min_width,
647
+ ]
648
+ else:
649
+ tile = x[
650
+ :,
651
+ :,
652
+ 1 + 4 * (k - 1) : 1 + 4 * k,
653
+ i : i + self.tile_sample_min_height,
654
+ j : j + self.tile_sample_min_width,
655
+ ]
656
+ tile = self.encoder(
657
+ tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx
658
+ )
659
+ mu, log_var = self.conv1(tile).chunk(2, dim=1)
660
+ if isinstance(scale[0], torch.Tensor):
661
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[
662
+ 1
663
+ ].view(1, self.z_dim, 1, 1, 1)
664
+ else:
665
+ mu = (mu - scale[0]) * scale[1]
666
+
667
+ time.append(mu)
668
+
669
+ row.append(torch.cat(time, dim=2))
670
+ rows.append(row)
671
+ self.clear_cache()
672
+
673
+ result_rows = []
674
+ for i, row in enumerate(rows):
675
+ result_row = []
676
+ for j, tile in enumerate(row):
677
+ # blend the above tile and the left tile
678
+ # to the current tile and add the current tile to the result row
679
+ if i > 0:
680
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
681
+ if j > 0:
682
+ tile = self.blend_h(row[j - 1], tile, blend_width)
683
+ result_row.append(
684
+ tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width]
685
+ )
686
+ result_rows.append(torch.cat(result_row, dim=-1))
687
+
688
+ enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
689
+ return enc
690
+
691
+ def tiled_decode(self, z, scale):
692
+ if isinstance(scale[0], torch.Tensor):
693
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
694
+ 1, self.z_dim, 1, 1, 1
695
+ )
696
+ else:
697
+ z = z / scale[1] + scale[0]
698
+
699
+ _, _, num_frames, height, width = z.shape
700
+ sample_height = height * self.spatial_compression_ratio
701
+ sample_width = width * self.spatial_compression_ratio
702
+
703
+ tile_latent_min_height = (
704
+ self.tile_sample_min_height // self.spatial_compression_ratio
705
+ )
706
+ tile_latent_min_width = (
707
+ self.tile_sample_min_width // self.spatial_compression_ratio
708
+ )
709
+ tile_latent_stride_height = (
710
+ self.tile_sample_stride_height // self.spatial_compression_ratio
711
+ )
712
+ tile_latent_stride_width = (
713
+ self.tile_sample_stride_width // self.spatial_compression_ratio
714
+ )
715
+
716
+ blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
717
+ blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
718
+
719
+ # Split z into overlapping tiles and decode them separately.
720
+ # The tiles have an overlap to avoid seams between tiles.
721
+ rows = []
722
+ for i in range(0, height, tile_latent_stride_height):
723
+ row = []
724
+ for j in range(0, width, tile_latent_stride_width):
725
+ self.clear_cache()
726
+ time = []
727
+ for k in range(num_frames):
728
+ self._conv_idx = [0]
729
+ tile = z[
730
+ :,
731
+ :,
732
+ k : k + 1,
733
+ i : i + tile_latent_min_height,
734
+ j : j + tile_latent_min_width,
735
+ ]
736
+ tile = self.conv2(tile)
737
+ decoded = self.decoder(
738
+ tile, feat_cache=self._feat_map, feat_idx=self._conv_idx
739
+ )
740
+ time.append(decoded)
741
+ row.append(torch.cat(time, dim=2))
742
+ rows.append(row)
743
+ self.clear_cache()
744
+
745
+ result_rows = []
746
+ for i, row in enumerate(rows):
747
+ result_row = []
748
+ for j, tile in enumerate(row):
749
+ # blend the above tile and the left tile
750
+ # to the current tile and add the current tile to the result row
751
+ if i > 0:
752
+ tile = self.blend_v(rows[i - 1][j], tile, blend_height)
753
+ if j > 0:
754
+ tile = self.blend_h(row[j - 1], tile, blend_width)
755
+ result_row.append(
756
+ tile[
757
+ :,
758
+ :,
759
+ :,
760
+ : self.tile_sample_stride_height,
761
+ : self.tile_sample_stride_width,
762
+ ]
763
+ )
764
+ result_rows.append(torch.cat(result_row, dim=-1))
765
+
766
+ dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
767
+
768
+ return dec
769
+
770
+ def encode(self, x, scale, return_mu=False):
771
+ self.clear_cache()
772
+ ## cache
773
+ t = x.shape[2]
774
+ iter_ = 1 + (t - 1) // 4
775
+ for i in range(iter_):
776
+ self._enc_conv_idx = [0]
777
+ if i == 0:
778
+ out = self.encoder(
779
+ x[:, :, :1, :, :],
780
+ feat_cache=self._enc_feat_map,
781
+ feat_idx=self._enc_conv_idx,
782
+ )
783
+ else:
784
+ out_ = self.encoder(
785
+ x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :],
786
+ feat_cache=self._enc_feat_map,
787
+ feat_idx=self._enc_conv_idx,
788
+ )
789
+ out = torch.cat([out, out_], 2)
790
+ mu, log_var = self.conv1(out).chunk(2, dim=1)
791
+ if isinstance(scale[0], torch.Tensor):
792
+ mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(
793
+ 1, self.z_dim, 1, 1, 1
794
+ )
795
+ else:
796
+ mu = (mu - scale[0]) * scale[1]
797
+
798
+ self.clear_cache()
799
+ if return_mu:
800
+ return mu, log_var
801
+ else:
802
+ return mu
803
+
804
+ def decode(self, z, scale):
805
+ self.clear_cache()
806
+
807
+ # z: [b,c,t,h,w]
808
+ if isinstance(scale[0], torch.Tensor):
809
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
810
+ 1, self.z_dim, 1, 1, 1
811
+ )
812
+ else:
813
+ z = z / scale[1] + scale[0]
814
+ iter_ = z.shape[2]
815
+ x = self.conv2(z)
816
+ for i in range(iter_):
817
+ self._conv_idx = [0]
818
+ if i == 0:
819
+ out = self.decoder(
820
+ x[:, :, i : i + 1, :, :],
821
+ feat_cache=self._feat_map,
822
+ feat_idx=self._conv_idx,
823
+ )
824
+ else:
825
+ out_ = self.decoder(
826
+ x[:, :, i : i + 1, :, :],
827
+ feat_cache=self._feat_map,
828
+ feat_idx=self._conv_idx,
829
+ )
830
+ out = torch.cat([out, out_], 2)
831
+
832
+ self.clear_cache()
833
+ return out
834
+
835
+ def decode_stream(self, z, scale):
836
+ self.clear_cache()
837
+
838
+ # z: [b,c,t,h,w]
839
+ if isinstance(scale[0], torch.Tensor):
840
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
841
+ 1, self.z_dim, 1, 1, 1
842
+ )
843
+ else:
844
+ z = z / scale[1] + scale[0]
845
+ iter_ = z.shape[2]
846
+ x = self.conv2(z)
847
+ for i in range(iter_):
848
+ self._conv_idx = [0]
849
+ out = self.decoder(
850
+ x[:, :, i : i + 1, :, :],
851
+ feat_cache=self._feat_map,
852
+ feat_idx=self._conv_idx,
853
+ )
854
+ yield out
855
+
856
+ def cached_decode(self, z, scale):
857
+ # z: [b,c,t,h,w]
858
+ if isinstance(scale[0], torch.Tensor):
859
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
860
+ 1, self.z_dim, 1, 1, 1
861
+ )
862
+ else:
863
+ z = z / scale[1] + scale[0]
864
+ iter_ = z.shape[2]
865
+ x = self.conv2(z)
866
+ for i in range(iter_):
867
+ self._conv_idx = [0]
868
+ if i == 0:
869
+ out = self.decoder(
870
+ x[:, :, i : i + 1, :, :],
871
+ feat_cache=self._feat_map,
872
+ feat_idx=self._conv_idx,
873
+ )
874
+ else:
875
+ out_ = self.decoder(
876
+ x[:, :, i : i + 1, :, :],
877
+ feat_cache=self._feat_map,
878
+ feat_idx=self._conv_idx,
879
+ )
880
+ out = torch.cat([out, out_], 2)
881
+ return out
882
+
883
+ def reparameterize(self, mu, log_var):
884
+ std = torch.exp(0.5 * log_var)
885
+ eps = torch.randn_like(std)
886
+ return eps * std + mu
887
+
888
+ def sample(self, imgs, deterministic=False, scale=[0, 1]):
889
+ mu, log_var = self.encode(imgs, scale, return_mu=True)
890
+ if deterministic:
891
+ return mu
892
+ std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
893
+ return mu + std * torch.randn_like(std), mu, log_var
894
+
895
+ def clear_cache(self):
896
+ self._conv_num = count_conv3d(self.decoder)
897
+ self._conv_idx = [0]
898
+ self._feat_map = [None] * self._conv_num
899
+ # cache encode
900
+ self._enc_conv_num = count_conv3d(self.encoder)
901
+ self._enc_conv_idx = [0]
902
+ self._enc_feat_map = [None] * self._enc_conv_num
903
+
904
+ def encode_video(self, x, scale=[0, 1]):
905
+ assert x.ndim == 5 # NTCHW
906
+ assert x.shape[2] % 3 == 0
907
+ x = x.transpose(1, 2)
908
+ y = x.mul(2).sub_(1)
909
+ y, mu, log_var = self.sample(y, scale=scale)
910
+ return y.transpose(1, 2).to(x), mu, log_var
911
+
912
+ def decode_video(self, x, scale=[0, 1]):
913
+ assert x.ndim == 5 # NTCHW
914
+ assert x.shape[2] % self.z_dim == 0
915
+ x = x.transpose(1, 2)
916
+ # B, C, T, H, W
917
+ y = x
918
+ y = self.decode(y, scale).clamp_(-1, 1)
919
+ y = y.mul_(0.5).add_(0.5).clamp_(0, 1) # NCTHW
920
+ return y.transpose(1, 2).to(x)
921
+
922
+
923
+ def _video_vae(
924
+ pretrained_path=None,
925
+ z_dim=None,
926
+ device="cpu",
927
+ dtype=torch.float,
928
+ **kwargs,
929
+ ):
930
+ """
931
+ Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
932
+ """
933
+ # params
934
+ cfg = dict(
935
+ dim=96,
936
+ z_dim=z_dim,
937
+ dim_mult=[1, 2, 4, 4],
938
+ num_res_blocks=2,
939
+ attn_scales=[],
940
+ temperal_downsample=[False, True, True],
941
+ dropout=0.0,
942
+ )
943
+ cfg.update(**kwargs)
944
+
945
+ # init model
946
+ with torch.device("meta"):
947
+ model = WanVAE_(**cfg)
948
+
949
+ # load checkpoint
950
+ model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)
951
+
952
+ return model
953
+
954
+ class WanVAE:
955
+ def __init__(
956
+ self,
957
+ z_dim=16,
958
+ vae_path="cache/vae_step_411000.pth",
959
+ dtype=torch.float,
960
+ device="cuda",
961
+ parallel=False,
962
+ use_tiling=False,
963
+ use_2d_split=True,
964
+ ):
965
+ self.dtype = dtype
966
+ self.device = device
967
+ self.parallel = parallel
968
+ self.use_tiling = use_tiling
969
+ self.use_2d_split = use_2d_split
970
+
971
+ mean = [
972
+ -0.7571,
973
+ -0.7089,
974
+ -0.9113,
975
+ 0.1075,
976
+ -0.1745,
977
+ 0.9653,
978
+ -0.1517,
979
+ 1.5508,
980
+ 0.4134,
981
+ -0.0715,
982
+ 0.5517,
983
+ -0.3632,
984
+ -0.1922,
985
+ -0.9497,
986
+ 0.2503,
987
+ -0.2921,
988
+ ]
989
+ std = [
990
+ 2.8184,
991
+ 1.4541,
992
+ 2.3275,
993
+ 2.6558,
994
+ 1.2196,
995
+ 1.7708,
996
+ 2.6052,
997
+ 2.0743,
998
+ 3.2687,
999
+ 2.1526,
1000
+ 2.8652,
1001
+ 1.5579,
1002
+ 1.6382,
1003
+ 1.1253,
1004
+ 2.8251,
1005
+ 1.9160,
1006
+ ]
1007
+ self.mean = torch.tensor(mean, dtype=dtype, device=device)
1008
+ self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
1009
+ self.scale = [self.mean, self.inv_std]
1010
+
1011
+ # (height, width, world_size) -> (world_size_h, world_size_w)
1012
+ self.grid_table = {
1013
+ # world_size = 2
1014
+ (60, 104, 2): (1, 2),
1015
+ (68, 120, 2): (1, 2),
1016
+ (90, 160, 2): (1, 2),
1017
+ (60, 60, 2): (1, 2),
1018
+ (72, 72, 2): (1, 2),
1019
+ (88, 88, 2): (1, 2),
1020
+ (120, 120, 2): (1, 2),
1021
+ (104, 60, 2): (2, 1),
1022
+ (120, 68, 2): (2, 1),
1023
+ (160, 90, 2): (2, 1),
1024
+ # world_size = 4
1025
+ (60, 104, 4): (2, 2),
1026
+ (68, 120, 4): (2, 2),
1027
+ (90, 160, 4): (2, 2),
1028
+ (60, 60, 4): (2, 2),
1029
+ (72, 72, 4): (2, 2),
1030
+ (88, 88, 4): (2, 2),
1031
+ (120, 120, 4): (2, 2),
1032
+ (104, 60, 4): (2, 2),
1033
+ (120, 68, 4): (2, 2),
1034
+ (160, 90, 4): (2, 2),
1035
+ # world_size = 8
1036
+ (60, 104, 8): (2, 4),
1037
+ (68, 120, 8): (2, 4),
1038
+ (90, 160, 8): (2, 4),
1039
+ (60, 60, 8): (2, 4),
1040
+ (72, 72, 8): (2, 4),
1041
+ (88, 88, 8): (2, 4),
1042
+ (120, 120, 8): (2, 4),
1043
+ (104, 60, 8): (4, 2),
1044
+ (120, 68, 8): (4, 2),
1045
+ (160, 90, 8): (4, 2),
1046
+ }
1047
+
1048
+ # init model
1049
+ self.model = (
1050
+ _video_vae(
1051
+ pretrained_path=vae_path,
1052
+ z_dim=z_dim,
1053
+ dtype=dtype,
1054
+ )
1055
+ .eval()
1056
+ .requires_grad_(False)
1057
+ .to(device)
1058
+ .to(dtype)
1059
+ )
1060
+
1061
+ def _calculate_2d_grid(self, latent_height, latent_width, world_size):
1062
+ if (latent_height, latent_width, world_size) in self.grid_table:
1063
+ best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
1064
+ # logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
1065
+ return best_h, best_w
1066
+
1067
+ best_h, best_w = 1, world_size
1068
+ min_aspect_diff = float("inf")
1069
+
1070
+ for h in range(1, world_size + 1):
1071
+ if world_size % h == 0:
1072
+ w = world_size // h
1073
+ if latent_height % h == 0 and latent_width % w == 0:
1074
+ # Calculate how close this grid is to square
1075
+ aspect_diff = abs((latent_height / h) - (latent_width / w))
1076
+ if aspect_diff < min_aspect_diff:
1077
+ min_aspect_diff = aspect_diff
1078
+ best_h, best_w = h, w
1079
+ # logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
1080
+ self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
1081
+ return best_h, best_w
1082
+
1083
+ def current_device(self):
1084
+ return next(self.model.parameters()).device
1085
+
1086
+ def encode_dist(self, video, world_size, cur_rank, split_dim):
1087
+ spatial_ratio = 8
1088
+
1089
+ if split_dim == 3:
1090
+ total_latent_len = video.shape[3] // spatial_ratio
1091
+ elif split_dim == 4:
1092
+ total_latent_len = video.shape[4] // spatial_ratio
1093
+ else:
1094
+ raise ValueError(f"Unsupported split_dim: {split_dim}")
1095
+
1096
+ splited_chunk_len = total_latent_len // world_size
1097
+ padding_size = 1
1098
+
1099
+ video_chunk_len = splited_chunk_len * spatial_ratio
1100
+ video_padding_len = padding_size * spatial_ratio
1101
+
1102
+ if cur_rank == 0:
1103
+ if split_dim == 3:
1104
+ video_chunk = video[
1105
+ :, :, :, : video_chunk_len + 2 * video_padding_len, :
1106
+ ].contiguous()
1107
+ elif split_dim == 4:
1108
+ video_chunk = video[
1109
+ :, :, :, :, : video_chunk_len + 2 * video_padding_len
1110
+ ].contiguous()
1111
+ elif cur_rank == world_size - 1:
1112
+ if split_dim == 3:
1113
+ video_chunk = video[
1114
+ :, :, :, -(video_chunk_len + 2 * video_padding_len) :, :
1115
+ ].contiguous()
1116
+ elif split_dim == 4:
1117
+ video_chunk = video[
1118
+ :, :, :, :, -(video_chunk_len + 2 * video_padding_len) :
1119
+ ].contiguous()
1120
+ else:
1121
+ start_idx = cur_rank * video_chunk_len - video_padding_len
1122
+ end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len
1123
+ if split_dim == 3:
1124
+ video_chunk = video[:, :, :, start_idx:end_idx, :].contiguous()
1125
+ elif split_dim == 4:
1126
+ video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous()
1127
+
1128
+ if self.use_tiling:
1129
+ encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
1130
+ else:
1131
+ encoded_chunk = self.model.encode(video_chunk, self.scale)
1132
+
1133
+ if cur_rank == 0:
1134
+ if split_dim == 3:
1135
+ encoded_chunk = encoded_chunk[
1136
+ :, :, :, :splited_chunk_len, :
1137
+ ].contiguous()
1138
+ elif split_dim == 4:
1139
+ encoded_chunk = encoded_chunk[
1140
+ :, :, :, :, :splited_chunk_len
1141
+ ].contiguous()
1142
+ elif cur_rank == world_size - 1:
1143
+ if split_dim == 3:
1144
+ encoded_chunk = encoded_chunk[
1145
+ :, :, :, -splited_chunk_len:, :
1146
+ ].contiguous()
1147
+ elif split_dim == 4:
1148
+ encoded_chunk = encoded_chunk[
1149
+ :, :, :, :, -splited_chunk_len:
1150
+ ].contiguous()
1151
+ else:
1152
+ if split_dim == 3:
1153
+ encoded_chunk = encoded_chunk[
1154
+ :, :, :, padding_size:-padding_size, :
1155
+ ].contiguous()
1156
+ elif split_dim == 4:
1157
+ encoded_chunk = encoded_chunk[
1158
+ :, :, :, :, padding_size:-padding_size
1159
+ ].contiguous()
1160
+
1161
+ full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
1162
+ dist.all_gather(full_encoded, encoded_chunk)
1163
+
1164
+ torch.cuda.synchronize()
1165
+
1166
+ encoded = torch.cat(full_encoded, dim=split_dim)
1167
+
1168
+ return encoded.squeeze(0)
1169
+
1170
+ def encode_dist_2d(self, video, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
1171
+ spatial_ratio = 8
1172
+
1173
+ # Calculate chunk sizes for both dimensions
1174
+ total_latent_h = video.shape[3] // spatial_ratio
1175
+ total_latent_w = video.shape[4] // spatial_ratio
1176
+
1177
+ chunk_h = total_latent_h // world_size_h
1178
+ chunk_w = total_latent_w // world_size_w
1179
+
1180
+ padding_size = 1
1181
+ video_chunk_h = chunk_h * spatial_ratio
1182
+ video_chunk_w = chunk_w * spatial_ratio
1183
+ video_padding_h = padding_size * spatial_ratio
1184
+ video_padding_w = padding_size * spatial_ratio
1185
+
1186
+ # Calculate H dimension slice
1187
+ if cur_rank_h == 0:
1188
+ h_start = 0
1189
+ h_end = video_chunk_h + 2 * video_padding_h
1190
+ elif cur_rank_h == world_size_h - 1:
1191
+ h_start = video.shape[3] - (video_chunk_h + 2 * video_padding_h)
1192
+ h_end = video.shape[3]
1193
+ else:
1194
+ h_start = cur_rank_h * video_chunk_h - video_padding_h
1195
+ h_end = (cur_rank_h + 1) * video_chunk_h + video_padding_h
1196
+
1197
+ # Calculate W dimension slice
1198
+ if cur_rank_w == 0:
1199
+ w_start = 0
1200
+ w_end = video_chunk_w + 2 * video_padding_w
1201
+ elif cur_rank_w == world_size_w - 1:
1202
+ w_start = video.shape[4] - (video_chunk_w + 2 * video_padding_w)
1203
+ w_end = video.shape[4]
1204
+ else:
1205
+ w_start = cur_rank_w * video_chunk_w - video_padding_w
1206
+ w_end = (cur_rank_w + 1) * video_chunk_w + video_padding_w
1207
+
1208
+ # Extract the video chunk for this process
1209
+ video_chunk = video[:, :, :, h_start:h_end, w_start:w_end].contiguous()
1210
+
1211
+ # Encode the chunk
1212
+ if self.use_tiling:
1213
+ encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
1214
+ else:
1215
+ encoded_chunk = self.model.encode(video_chunk, self.scale)
1216
+
1217
+ # Remove padding from encoded chunk
1218
+ if cur_rank_h == 0:
1219
+ encoded_h_start = 0
1220
+ encoded_h_end = chunk_h
1221
+ elif cur_rank_h == world_size_h - 1:
1222
+ encoded_h_start = encoded_chunk.shape[3] - chunk_h
1223
+ encoded_h_end = encoded_chunk.shape[3]
1224
+ else:
1225
+ encoded_h_start = padding_size
1226
+ encoded_h_end = encoded_chunk.shape[3] - padding_size
1227
+
1228
+ if cur_rank_w == 0:
1229
+ encoded_w_start = 0
1230
+ encoded_w_end = chunk_w
1231
+ elif cur_rank_w == world_size_w - 1:
1232
+ encoded_w_start = encoded_chunk.shape[4] - chunk_w
1233
+ encoded_w_end = encoded_chunk.shape[4]
1234
+ else:
1235
+ encoded_w_start = padding_size
1236
+ encoded_w_end = encoded_chunk.shape[4] - padding_size
1237
+
1238
+ encoded_chunk = encoded_chunk[
1239
+ :, :, :, encoded_h_start:encoded_h_end, encoded_w_start:encoded_w_end
1240
+ ].contiguous()
1241
+
1242
+ # Gather all chunks
1243
+ total_processes = world_size_h * world_size_w
1244
+ full_encoded = [torch.empty_like(encoded_chunk) for _ in range(total_processes)]
1245
+
1246
+ dist.all_gather(full_encoded, encoded_chunk)
1247
+
1248
+ torch.cuda.synchronize()
1249
+
1250
+ # Reconstruct the full encoded tensor
1251
+ encoded_rows = []
1252
+ for h_idx in range(world_size_h):
1253
+ encoded_cols = []
1254
+ for w_idx in range(world_size_w):
1255
+ process_idx = h_idx * world_size_w + w_idx
1256
+ encoded_cols.append(full_encoded[process_idx])
1257
+ encoded_rows.append(torch.cat(encoded_cols, dim=4))
1258
+
1259
+ encoded = torch.cat(encoded_rows, dim=3)
1260
+
1261
+ return encoded.squeeze(0)
1262
+
1263
+ def encode(self, video, world_size_h=None, world_size_w=None):
1264
+ """
1265
+ video: one video with shape [1, C, T, H, W].
1266
+ """
1267
+ if self.parallel:
1268
+ world_size = dist.get_world_size()
1269
+ cur_rank = dist.get_rank()
1270
+ height, width = video.shape[3], video.shape[4]
1271
+
1272
+ if self.use_2d_split:
1273
+ if world_size_h is None or world_size_w is None:
1274
+ world_size_h, world_size_w = self._calculate_2d_grid(
1275
+ height // 8, width // 8, world_size
1276
+ )
1277
+ cur_rank_h = cur_rank // world_size_w
1278
+ cur_rank_w = cur_rank % world_size_w
1279
+ out = self.encode_dist_2d(
1280
+ video, world_size_h, world_size_w, cur_rank_h, cur_rank_w
1281
+ )
1282
+ else:
1283
+ # Original 1D splitting logic
1284
+ if width % world_size == 0:
1285
+ out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
1286
+ elif height % world_size == 0:
1287
+ out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
1288
+ else:
1289
+ logger.info("Fall back to naive encode mode")
1290
+ if self.use_tiling:
1291
+ out = self.model.tiled_encode(video, self.scale).squeeze(0)
1292
+ else:
1293
+ out = self.model.encode(video, self.scale).squeeze(0)
1294
+ else:
1295
+ if self.use_tiling:
1296
+ out = self.model.tiled_encode(video, self.scale).squeeze(0)
1297
+ else:
1298
+ out = self.model.encode(video, self.scale).squeeze(0)
1299
+
1300
+ return out
1301
+
1302
+ def decode_dist(self, zs, world_size, cur_rank, split_dim):
1303
+ splited_total_len = zs.shape[split_dim]
1304
+ splited_chunk_len = splited_total_len // world_size
1305
+ padding_size = 1
1306
+
1307
+ if cur_rank == 0:
1308
+ if split_dim == 2:
1309
+ zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
1310
+ elif split_dim == 3:
1311
+ zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
1312
+ elif cur_rank == world_size - 1:
1313
+ if split_dim == 2:
1314
+ zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
1315
+ elif split_dim == 3:
1316
+ zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
1317
+ else:
1318
+ if split_dim == 2:
1319
+ zs = zs[
1320
+ :,
1321
+ :,
1322
+ cur_rank * splited_chunk_len - padding_size : (cur_rank + 1)
1323
+ * splited_chunk_len
1324
+ + padding_size,
1325
+ :,
1326
+ ].contiguous()
1327
+ elif split_dim == 3:
1328
+ zs = zs[
1329
+ :,
1330
+ :,
1331
+ :,
1332
+ cur_rank * splited_chunk_len - padding_size : (cur_rank + 1)
1333
+ * splited_chunk_len
1334
+ + padding_size,
1335
+ ].contiguous()
1336
+
1337
+ decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
1338
+ images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
1339
+
1340
+ if cur_rank == 0:
1341
+ if split_dim == 2:
1342
+ images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
1343
+ elif split_dim == 3:
1344
+ images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
1345
+ elif cur_rank == world_size - 1:
1346
+ if split_dim == 2:
1347
+ images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
1348
+ elif split_dim == 3:
1349
+ images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
1350
+ else:
1351
+ if split_dim == 2:
1352
+ images = images[
1353
+ :, :, :, 8 * padding_size : -8 * padding_size, :
1354
+ ].contiguous()
1355
+ elif split_dim == 3:
1356
+ images = images[
1357
+ :, :, :, :, 8 * padding_size : -8 * padding_size
1358
+ ].contiguous()
1359
+
1360
+ full_images = [torch.empty_like(images) for _ in range(world_size)]
1361
+ dist.all_gather(full_images, images)
1362
+
1363
+ torch.cuda.synchronize()
1364
+
1365
+ images = torch.cat(full_images, dim=split_dim + 1)
1366
+
1367
+ return images
1368
+
1369
+ def decode_dist_2d(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
1370
+ total_h = zs.shape[2]
1371
+ total_w = zs.shape[3]
1372
+
1373
+ chunk_h = total_h // world_size_h
1374
+ chunk_w = total_w // world_size_w
1375
+
1376
+ padding_size = 2
1377
+
1378
+ # Calculate H dimension slice
1379
+ if cur_rank_h == 0:
1380
+ h_start = 0
1381
+ h_end = chunk_h + 2 * padding_size
1382
+ elif cur_rank_h == world_size_h - 1:
1383
+ h_start = total_h - (chunk_h + 2 * padding_size)
1384
+ h_end = total_h
1385
+ else:
1386
+ h_start = cur_rank_h * chunk_h - padding_size
1387
+ h_end = (cur_rank_h + 1) * chunk_h + padding_size
1388
+
1389
+ # Calculate W dimension slice
1390
+         if cur_rank_w == 0:
+             w_start = 0
+             w_end = chunk_w + 2 * padding_size
+         elif cur_rank_w == world_size_w - 1:
+             w_start = total_w - (chunk_w + 2 * padding_size)
+             w_end = total_w
+         else:
+             w_start = cur_rank_w * chunk_w - padding_size
+             w_end = (cur_rank_w + 1) * chunk_w + padding_size
+
+         # Extract the latent chunk for this process
+         zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()
+
+         # Decode the chunk
+         decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
+         images_chunk = decode_func(zs_chunk.unsqueeze(0), self.scale).clamp_(-1, 1)
+
+         # Remove padding from the decoded chunk
+         spatial_ratio = 8
+         if cur_rank_h == 0:
+             decoded_h_start = 0
+             decoded_h_end = chunk_h * spatial_ratio
+         elif cur_rank_h == world_size_h - 1:
+             decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
+             decoded_h_end = images_chunk.shape[3]
+         else:
+             decoded_h_start = padding_size * spatial_ratio
+             decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
+
+         if cur_rank_w == 0:
+             decoded_w_start = 0
+             decoded_w_end = chunk_w * spatial_ratio
+         elif cur_rank_w == world_size_w - 1:
+             decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
+             decoded_w_end = images_chunk.shape[4]
+         else:
+             decoded_w_start = padding_size * spatial_ratio
+             decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
+
+         images_chunk = images_chunk[
+             :, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end
+         ].contiguous()
+
+         # Gather all chunks
+         total_processes = world_size_h * world_size_w
+         full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]
+
+         dist.all_gather(full_images, images_chunk)
+
+         torch.cuda.synchronize()
+
+         # Reconstruct the full image tensor
+         image_rows = []
+         for h_idx in range(world_size_h):
+             image_cols = []
+             for w_idx in range(world_size_w):
+                 process_idx = h_idx * world_size_w + w_idx
+                 image_cols.append(full_images[process_idx])
+             image_rows.append(torch.cat(image_cols, dim=4))
+
+         images = torch.cat(image_rows, dim=3)
+
+         return images
+
+     def decode_dist_2d_stream(
+         self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w
+     ):
+         total_h = zs.shape[2]
+         total_w = zs.shape[3]
+
+         chunk_h = total_h // world_size_h
+         chunk_w = total_w // world_size_w
+
+         padding_size = 2
+
+         # Calculate the H dimension slice
+         if cur_rank_h == 0:
+             h_start = 0
+             h_end = chunk_h + 2 * padding_size
+         elif cur_rank_h == world_size_h - 1:
+             h_start = total_h - (chunk_h + 2 * padding_size)
+             h_end = total_h
+         else:
+             h_start = cur_rank_h * chunk_h - padding_size
+             h_end = (cur_rank_h + 1) * chunk_h + padding_size
+
+         # Calculate the W dimension slice
+         if cur_rank_w == 0:
+             w_start = 0
+             w_end = chunk_w + 2 * padding_size
+         elif cur_rank_w == world_size_w - 1:
+             w_start = total_w - (chunk_w + 2 * padding_size)
+             w_end = total_w
+         else:
+             w_start = cur_rank_w * chunk_w - padding_size
+             w_end = (cur_rank_w + 1) * chunk_w + padding_size
+
+         # Extract the latent chunk for this process
+         zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()
+
+         for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
+             images_chunk = image.clamp_(-1, 1)
+             # Remove padding from the decoded chunk
+             spatial_ratio = 8
+             if cur_rank_h == 0:
+                 decoded_h_start = 0
+                 decoded_h_end = chunk_h * spatial_ratio
+             elif cur_rank_h == world_size_h - 1:
+                 decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
+                 decoded_h_end = images_chunk.shape[3]
+             else:
+                 decoded_h_start = padding_size * spatial_ratio
+                 decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio
+
+             if cur_rank_w == 0:
+                 decoded_w_start = 0
+                 decoded_w_end = chunk_w * spatial_ratio
+             elif cur_rank_w == world_size_w - 1:
+                 decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
+                 decoded_w_end = images_chunk.shape[4]
+             else:
+                 decoded_w_start = padding_size * spatial_ratio
+                 decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio
+
+             images_chunk = images_chunk[
+                 :, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end
+             ].contiguous()
+
+             # Gather all chunks
+             total_processes = world_size_h * world_size_w
+             full_images = [
+                 torch.empty_like(images_chunk) for _ in range(total_processes)
+             ]
+
+             dist.all_gather(full_images, images_chunk)
+
+             torch.cuda.synchronize()
+
+             # Reconstruct the full image tensor
+             image_rows = []
+             for h_idx in range(world_size_h):
+                 image_cols = []
+                 for w_idx in range(world_size_w):
+                     process_idx = h_idx * world_size_w + w_idx
+                     image_cols.append(full_images[process_idx])
+                 image_rows.append(torch.cat(image_cols, dim=4))
+
+             images = torch.cat(image_rows, dim=3)
+
+             yield images
+
+     def decode(self, zs):
+         if self.parallel:
+             world_size = dist.get_world_size()
+             cur_rank = dist.get_rank()
+             latent_height, latent_width = zs.shape[2], zs.shape[3]
+
+             if self.use_2d_split:
+                 world_size_h, world_size_w = self._calculate_2d_grid(
+                     latent_height, latent_width, world_size
+                 )
+                 cur_rank_h = cur_rank // world_size_w
+                 cur_rank_w = cur_rank % world_size_w
+                 images = self.decode_dist_2d(
+                     zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w
+                 )
+             else:
+                 # Original 1D splitting logic
+                 if latent_width % world_size == 0:
+                     images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
+                 elif latent_height % world_size == 0:
+                     images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
+                 else:
+                     logger.info("Fall back to naive decode mode")
+                     images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(
+                         -1, 1
+                     )
+         else:
+             decode_func = (
+                 self.model.tiled_decode if self.use_tiling else self.model.decode
+             )
+             images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
+
+         return images
+
+     def decode_stream(self, zs):
+         if self.parallel:
+             world_size = dist.get_world_size()
+             cur_rank = dist.get_rank()
+             latent_height, latent_width = zs.shape[2], zs.shape[3]
+
+             world_size_h, world_size_w = self._calculate_2d_grid(
+                 latent_height, latent_width, world_size
+             )
+             cur_rank_h = cur_rank // world_size_w
+             cur_rank_w = cur_rank % world_size_w
+             for images in self.decode_dist_2d_stream(
+                 zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w
+             ):
+                 yield images
+         else:
+             for image in self.model.decode_stream(zs.unsqueeze(0), self.scale):
+                 yield image.clamp_(-1, 1)
+
+     def encode_video(self, vid):
+         return self.model.encode_video(vid)
+
+     def decode_video(self, vid_enc):
+         return self.model.decode_video(vid_enc)
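The two distributed decode paths above tile the latent spatially across ranks with a small halo (padding_size latents per side), let each rank decode its own tile, crop the halo back off in pixel space (spatial_ratio = 8 output pixels per latent), and reassemble the tiles in row-major rank order after an all_gather. Below is a minimal single-process sketch of just that slicing arithmetic; the helper names tile_range and crop_range are illustrative and are not part of the repository.

    # Sketch of the halo-padded tiling used by decode_dist_2d, one axis at a time.
    def tile_range(rank, world_size, total, chunk, pad):
        """Latent-space [start, end) decoded by one rank along a single axis."""
        if rank == 0:
            return 0, chunk + 2 * pad
        if rank == world_size - 1:
            return total - (chunk + 2 * pad), total
        return rank * chunk - pad, (rank + 1) * chunk + pad

    def crop_range(rank, world_size, decoded_len, chunk, pad, ratio=8):
        """Pixel-space [start, end) that removes the decoded halo again."""
        if rank == 0:
            return 0, chunk * ratio
        if rank == world_size - 1:
            return decoded_len - chunk * ratio, decoded_len
        return pad * ratio, decoded_len - pad * ratio

    if __name__ == "__main__":
        total_h, world_size_h, pad = 32, 4, 2      # e.g. 32 latent rows over 4 ranks
        chunk_h = total_h // world_size_h
        for r in range(world_size_h):
            print(r, tile_range(r, world_size_h, total_h, chunk_h, pad))
        # 0 (0, 12), 1 (6, 18), 2 (14, 26), 3 (20, 32): tiles overlap by the halo,
        # and after crop_range each output row belongs to exactly one rank.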
generate_video.py ADDED
@@ -0,0 +1,218 @@
+ # Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
+ import argparse
+ import os
+ import numpy as np
+ import time
+ import torch
+ import torch.distributed as dist
+ import subprocess
+ import imageio
+ import librosa
+ from loguru import logger
+ from collections import deque
+ from datetime import datetime
+
+ from flash_head.inference import get_pipeline, get_base_data, get_infer_params, get_audio_embedding, run_pipeline
+
+ def _validate_args(args):
+     # Basic checks
+     assert args.ckpt_dir is not None, "Please specify the FlashHead model checkpoint directory."
+     assert args.wav2vec_dir is not None, "Please specify the wav2vec checkpoint directory."
+     assert args.model_type in ("pro", "lite"), "Please specify the model type (pro or lite)."
+     assert args.cond_image_dir is not None or args.cond_image is not None, "Please specify the condition image or directory."
+     assert args.audio_path is not None, "Please specify the audio path."
+
+     args.base_seed = args.base_seed if args.base_seed >= 0 else 42
+
+ def _parse_args():
+     parser = argparse.ArgumentParser(
+         description="Generate video from one image using FlashHead"
+     )
+     parser.add_argument(
+         "--ckpt_dir",
+         type=str,
+         default=None,
+         help="The path to the FlashHead model checkpoint directory.")
+     parser.add_argument(
+         "--wav2vec_dir",
+         type=str,
+         default=None,
+         help="The path to the wav2vec checkpoint directory.")
+     parser.add_argument(
+         "--model_type",
+         type=str,
+         default=None,
+         help="Choose from pro or lite.")
+     parser.add_argument(
+         "--save_file",
+         type=str,
+         default=None,
+         help="The file to save the generated video to.")
+     parser.add_argument(
+         "--base_seed",
+         type=int,
+         default=42,
+         help="The seed to use for generating the video.")
+     parser.add_argument(
+         "--cond_image",
+         type=str,
+         default=None,
+         help="[meta file] The condition image path to generate the video.")
+     parser.add_argument(
+         "--cond_image_dir",
+         type=str,
+         default=None,
+         help="[meta directory] The directory of condition images.")
+     parser.add_argument(
+         "--audio_path",
+         type=str,
+         default=None,
+         help="[meta file] The audio path to generate the video.")
+     parser.add_argument(
+         "--audio_encode_mode",
+         type=str,
+         default="stream",
+         choices=['stream', 'once'],
+         help="stream: encode the audio chunk before every generation; once: encode the whole audio up front.")
+     parser.add_argument(
+         "--use_face_crop",
+         type=bool,
+         default=False,
+         help="Enable face detection and cropping for the condition image.")
+     args = parser.parse_args()
+
+     _validate_args(args)
+
+     return args
+
+ def save_video(frames_list, video_path, audio_path, fps):
+     temp_video_path = video_path.replace('.mp4', '_tmp.mp4')
+     with imageio.get_writer(temp_video_path, format='mp4', mode='I',
+                             fps=fps, codec='h264', ffmpeg_params=['-bf', '0']) as writer:
+         for frames in frames_list:
+             frames = frames.numpy().astype(np.uint8)
+             for i in range(frames.shape[0]):
+                 frame = frames[i, :, :, :]
+                 writer.append_data(frame)
+
+     # merge video and audio
+     cmd = ['ffmpeg', '-i', temp_video_path, '-i', audio_path, '-c:v', 'copy', '-c:a', 'aac', '-shortest', video_path, '-y']
+     subprocess.run(cmd)
+     os.remove(temp_video_path)
+
+
+ def generate(args):
+     world_size = int(os.environ.get("WORLD_SIZE", 1))
+     rank = int(os.environ.get("RANK", 0))
+
+     pipeline = get_pipeline(world_size=world_size, ckpt_dir=args.ckpt_dir, wav2vec_dir=args.wav2vec_dir, model_type=args.model_type)
+     get_base_data(pipeline, cond_image_path_or_dir=args.cond_image_dir if args.cond_image_dir is not None else args.cond_image, base_seed=args.base_seed, use_face_crop=args.use_face_crop)
+     infer_params = get_infer_params()
+
+     sample_rate = infer_params['sample_rate']
+     tgt_fps = infer_params['tgt_fps']
+     cached_audio_duration = infer_params['cached_audio_duration']
+     frame_num = infer_params['frame_num']
+     motion_frames_num = infer_params['motion_frames_num']
+     slice_len = frame_num - motion_frames_num
+
+     human_speech_array_all, _ = librosa.load(args.audio_path, sr=sample_rate, mono=True)
+     human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
+     human_speech_array_frame_num = frame_num * sample_rate // tgt_fps
+
+     if rank == 0:
+         logger.info("Data preparation done. Start to generate video...")
+
+     generated_list = []
+     if args.audio_encode_mode == 'once':
+         # pad the audio with silence to avoid truncating the last chunk
+         remainder = (len(human_speech_array_all) - human_speech_array_frame_num) % human_speech_array_slice_len
+         if remainder > 0:
+             pad_length = human_speech_array_slice_len - remainder
+             human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])
+
+         # encode the whole audio at once
+         audio_embedding_all = get_audio_embedding(pipeline, human_speech_array_all)
+
+         # split the audio embedding into chunks
+         # for the Pro model: 33, 28, 28, 28, ...; for the Lite model: 33, 24, 24, 24, ...
+         audio_embedding_chunks_list = [audio_embedding_all[:, i * slice_len: i * slice_len + frame_num].contiguous() for i in range((audio_embedding_all.shape[1] - frame_num) // slice_len)]
+
+         for chunk_idx, audio_embedding_chunk in enumerate(audio_embedding_chunks_list):
+             torch.cuda.synchronize()
+             start_time = time.time()
+
+             # inference
+             video = run_pipeline(pipeline, audio_embedding_chunk)
+
+             if chunk_idx != 0:
+                 video = video[motion_frames_num:]
+
+             torch.cuda.synchronize()
+             end_time = time.time()
+             if rank == 0:
+                 logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")
+
+             generated_list.append(video.cpu())
+
+     elif args.audio_encode_mode == 'stream':
+         cached_audio_length_sum = sample_rate * cached_audio_duration
+         audio_end_idx = cached_audio_duration * tgt_fps
+         audio_start_idx = audio_end_idx - frame_num
+
+         audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)
+
+         # pad the audio with silence to avoid truncating the last chunk
+         remainder = len(human_speech_array_all) % human_speech_array_slice_len
+         if remainder > 0:
+             pad_length = human_speech_array_slice_len - remainder
+             human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])
+
+         # split the audio into chunks
+         # for the Pro model: 28, 28, 28, 28, ...; for the Lite model: 24, 24, 24, 24, ...
+         human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)
+
+         for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
+             torch.cuda.synchronize()
+             start_time = time.time()
+
+             # streaming-encode the current audio chunk
+             audio_dq.extend(human_speech_array.tolist())
+             audio_array = np.array(audio_dq)
+             audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)
+
+             # inference
+             video = run_pipeline(pipeline, audio_embedding)
+             video = video[motion_frames_num:]
+
+             torch.cuda.synchronize()
+             end_time = time.time()
+             if rank == 0:
+                 logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")
+
+             generated_list.append(video.cpu())
+
+     if rank == 0:
+         if args.save_file is None:
+             output_dir = 'sample_results'
+             if not os.path.exists(output_dir):
+                 os.makedirs(output_dir)
+             timestamp = datetime.now().strftime("%Y%m%d-%H:%M:%S-%f")[:-3]
+             filename = f"res_{timestamp}.mp4"
+             filepath = os.path.join(output_dir, filename)
+             args.save_file = filepath
+
+         save_video(generated_list, args.save_file, args.audio_path, fps=tgt_fps)
+         logger.info(f"Saved generated video to {args.save_file}")
+         logger.info("Finished.")
+
+     if world_size > 1:
+         dist.barrier()
+         dist.destroy_process_group()
+
+ if __name__ == "__main__":
+     args = _parse_args()
+     generate(args)
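In the default 'stream' mode above, the script keeps a rolling buffer of the most recent cached_audio_duration seconds of audio in a fixed-length deque, appends each new slice of slice_len frames' worth of samples, re-encodes the whole window, and feeds the model the embeddings between audio_start_idx and audio_end_idx. A minimal sketch of that sliding-window bookkeeping follows; the numeric values are illustrative stand-ins for the real ones read from flash_head/configs/infer_params.yaml, and push_chunk is a hypothetical helper.

    # Sketch of the sliding-window audio buffering used by the 'stream' mode.
    import numpy as np
    from collections import deque

    sample_rate = 16000          # audio sampling rate (Hz), assumed
    tgt_fps = 25                 # video frame rate, assumed
    frame_num = 33               # frames generated per chunk (motion + new), assumed
    motion_frames_num = 5        # frames reused from the previous chunk, assumed
    cached_audio_duration = 4    # seconds of audio kept in the rolling buffer, assumed

    slice_len = frame_num - motion_frames_num            # new frames per chunk
    slice_samples = slice_len * sample_rate // tgt_fps   # audio samples appended per chunk
    buffer_len = sample_rate * cached_audio_duration     # rolling buffer size in samples
    audio_end_idx = cached_audio_duration * tgt_fps      # frame index at the buffer end
    audio_start_idx = audio_end_idx - frame_num          # first frame fed to the model

    audio_dq = deque([0.0] * buffer_len, maxlen=buffer_len)   # starts as silence

    def push_chunk(new_samples: np.ndarray) -> np.ndarray:
        """Append one chunk of audio and return the full window to re-encode."""
        audio_dq.extend(new_samples.tolist())
        return np.array(audio_dq)

    # each push_chunk() call would then be followed by
    # get_audio_embedding(pipeline, window, audio_start_idx, audio_end_idx)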
gradio_app_streaming.py ADDED
@@ -0,0 +1,339 @@
+ """
+ Gradio streaming video generation: generation and saving run asynchronously so that playback stays real-time.
+ """
+ import gradio as gr
+ import os
+ import torch
+ import numpy as np
+ import time
+ import wave
+ import imageio
+ import librosa
+ import subprocess
+ import queue
+ import threading
+ from datetime import datetime
+ from collections import deque
+ from loguru import logger
+
+ from flash_head.inference import (
+     get_pipeline,
+     get_base_data,
+     get_infer_params,
+     get_audio_embedding,
+     run_pipeline,
+ )
+
+ # gr.Video with streaming=True needs segments longer than 1 s; in practice close to 3 s
+ # is required for smooth playback, so every 3 chunks are merged into one video segment.
+ CHUNKS_PER_SEGMENT = 3
+
+ pipeline = None
+ loaded_ckpt_dir = None
+ loaded_wav2vec_dir = None
+ loaded_model_type = None
+
+
+ def _write_frames_to_mp4(frames_list, video_path, fps):
+     """Write a list of frame tensors to an MP4 file (video track only)."""
+     os.makedirs(os.path.dirname(video_path) or ".", exist_ok=True)
+     with imageio.get_writer(
+         video_path,
+         format="mp4",
+         mode="I",
+         fps=fps,
+         codec="h264",
+         ffmpeg_params=["-bf", "0"],
+     ) as writer:
+         for frames in frames_list:
+             frames_np = frames.numpy().astype(np.uint8)
+             for i in range(frames_np.shape[0]):
+                 writer.append_data(frames_np[i, :, :, :])
+     return video_path
+
+
+ def save_video_with_audio(frames_list, video_path, audio_path, fps):
+     """Write the full video, then mux in the audio track with ffmpeg."""
+     temp_path = video_path.replace(".mp4", "_temp.mp4")
+     _write_frames_to_mp4(frames_list, temp_path, fps)
+     try:
+         cmd = [
+             "ffmpeg", "-y",
+             "-i", temp_path,
+             "-i", audio_path,
+             "-c:v", "copy",
+             "-c:a", "aac",
+             # "-shortest",
+             video_path,
+         ]
+         subprocess.run(cmd, check=True, capture_output=True)
+     finally:
+         if os.path.exists(temp_path):
+             os.remove(temp_path)
+     return video_path
+
+ def _save_chunk_audio_to_wav(audio_array, wav_path, sample_rate=16000):
+     """Save a float32 [-1, 1] audio array as a 16-bit mono WAV file."""
+     os.makedirs(os.path.dirname(wav_path) or ".", exist_ok=True)
+     samples = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
+     with wave.open(wav_path, "wb") as wav_file:
+         wav_file.setnchannels(1)
+         wav_file.setsampwidth(2)
+         wav_file.setframerate(sample_rate)
+         wav_file.writeframes(samples.tobytes())
+     return wav_path
+
+ def run_inference_streaming(
+     ckpt_dir,
+     wav2vec_dir,
+     model_type,
+     cond_image,
+     audio_path,
+     seed,
+     use_face_crop,
+     progress=gr.Progress(),
+ ):
+     """
+     Streaming inference: the main thread watches res_queue and saves/yields segments as frames arrive;
+     inference runs in a separate thread, processing chunks in order and pushing results into res_queue.
+     """
+     global pipeline, loaded_ckpt_dir, loaded_wav2vec_dir, loaded_model_type
+
+     if (
+         pipeline is None
+         or loaded_ckpt_dir != ckpt_dir
+         or loaded_wav2vec_dir != wav2vec_dir
+         or loaded_model_type != model_type
+     ):
+         progress(0.2, desc="Loading Model...")
+         logger.info(f"Loading pipeline with ckpt_dir={ckpt_dir}, wav2vec_dir={wav2vec_dir}")
+         try:
+             pipeline = get_pipeline(
+                 world_size=1,
+                 ckpt_dir=ckpt_dir,
+                 model_type=model_type,
+                 wav2vec_dir=wav2vec_dir,
+             )
+             loaded_ckpt_dir = ckpt_dir
+             loaded_wav2vec_dir = wav2vec_dir
+             loaded_model_type = model_type
+         except Exception as e:
+             logger.error(f"Failed to load model: {e}")
+             raise gr.Error(f"Failed to load model: {e}")
+
+     progress(0.5, desc="Preparing Data...")
+     base_seed = int(seed) if seed >= 0 else 9999
+     try:
+         get_base_data(
+             pipeline,
+             cond_image_path_or_dir=cond_image,
+             base_seed=base_seed,
+             use_face_crop=use_face_crop,
+         )
+     except Exception as e:
+         logger.error(f"Error in get_base_data: {e}")
+         raise gr.Error(f"Error processing inputs: {e}")
+
+     infer_params = get_infer_params()
+     sample_rate = infer_params["sample_rate"]
+     tgt_fps = infer_params["tgt_fps"]
+     cached_audio_duration = infer_params["cached_audio_duration"]
+     frame_num = infer_params["frame_num"]
+     motion_frames_num = infer_params["motion_frames_num"]
+     slice_len = frame_num - motion_frames_num
+
+     try:
+         human_speech_array_all, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
+     except Exception as e:
+         raise gr.Error(f"Failed to load audio file: {e}")
+
+     human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
+
+     stream_dir = os.path.join("gradio_results", "stream_preview")
+     os.makedirs(stream_dir, exist_ok=True)
+     timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3]
+     accumulated = []
+
+     # Stream mode is used by default: prepare the audio chunk slices
+     cached_audio_length_sum = sample_rate * cached_audio_duration
+     audio_end_idx = cached_audio_duration * tgt_fps
+     audio_start_idx = audio_end_idx - frame_num
+     remainder = len(human_speech_array_all) % human_speech_array_slice_len
+     if remainder > 0:
+         pad_length = human_speech_array_slice_len - remainder
+         human_speech_array_all = np.concatenate(
+             [human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)]
+         )
+     human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)
+     total_chunks = len(human_speech_array_slices)
+     if total_chunks == 0:
+         raise gr.Error("Audio too short: no chunks to generate. Please use a longer audio.")
+
+     # Data preparation: merge every CHUNKS_PER_SEGMENT chunks into one WAV file
+     # (named with the timestamp and segment_id)
+     segment_audio_paths = {}
+     num_segments = (total_chunks + CHUNKS_PER_SEGMENT - 1) // CHUNKS_PER_SEGMENT
+     for segment_id in range(num_segments):
+         start = segment_id * CHUNKS_PER_SEGMENT
+         end = min(start + CHUNKS_PER_SEGMENT, total_chunks)
+         audio_concat = np.concatenate(
+             [human_speech_array_slices[i] for i in range(start, end)]
+         )
+         segment_audio_name = f"audio_{timestamp}_seg_{segment_id:04d}.wav"
+         segment_audio_path = os.path.join(stream_dir, segment_audio_name)
+         _save_chunk_audio_to_wav(
+             audio_concat,
+             segment_audio_path,
+             sample_rate=sample_rate,
+         )
+         segment_audio_paths[segment_id] = segment_audio_path
+     logger.info(
+         f"Pre-saved {num_segments} segment audios (every {CHUNKS_PER_SEGMENT} chunks) under {stream_dir}"
+     )
+
+     # Result queue: the inference thread pushes (chunk_idx, chunk_frames_np);
+     # the main thread picks the matching segment audio and muxes it in.
+     res_queue = queue.Queue()
+
+     def inference_worker():
+         """Worker thread: run inference chunk by chunk, push each result into res_queue, then move on to the next chunk."""
+         audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)
+         for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
+             audio_dq.extend(human_speech_array.tolist())
+             audio_array = np.array(audio_dq)
+             audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)
+             torch.cuda.synchronize()
+             start_time = time.time()
+             video = run_pipeline(pipeline, audio_embedding)
+             video = video[motion_frames_num:]
+             torch.cuda.synchronize()
+             logger.info(f"Infer chunk-{chunk_idx} done, cost time: {time.time() - start_time:.2f}s")
+             chunk_frames_np = video.cpu().numpy()
+             res_queue.put((chunk_idx, chunk_frames_np))
+         res_queue.put(None)  # end-of-stream sentinel
+
+     worker_thread = threading.Thread(target=inference_worker)
+     worker_thread.start()
+     logger.info("Inference worker thread started. Main will consume res_queue and yield video paths.")
+
+     # Main loop: watch res_queue; once CHUNKS_PER_SEGMENT chunks are collected,
+     # merge them into one MP4 (with the matching segment audio) and yield its path.
+     frame_buffer = []
+     while True:
+         item = res_queue.get()
+         if item is None:
+             break
+         chunk_idx, chunk_frames_np = item
+         chunk_frames = torch.from_numpy(chunk_frames_np)
+         accumulated.append(chunk_frames)
+         frame_buffer.append(chunk_frames)
+         if len(frame_buffer) == CHUNKS_PER_SEGMENT:
+             segment_id = (chunk_idx + 1 - CHUNKS_PER_SEGMENT) // CHUNKS_PER_SEGMENT
+             segment_audio_path = segment_audio_paths[segment_id]
+             segment_path = os.path.join(
+                 stream_dir, f"preview_{timestamp}_seg_{segment_id:04d}.mp4"
+             )
+             save_video_with_audio(
+                 frame_buffer,
+                 segment_path,
+                 segment_audio_path,
+                 fps=tgt_fps,
+             )
+             logger.info(
+                 f"Saved segment-{segment_id} (chunks {segment_id * CHUNKS_PER_SEGMENT}-{chunk_idx}) and yielding to frontend."
+             )
+             yield os.path.abspath(segment_path)
+             frame_buffer = []
+
+     # Merge any remaining chunks (fewer than CHUNKS_PER_SEGMENT) into a final segment
+     if frame_buffer:
+         segment_id = num_segments - 1
+         segment_audio_path = segment_audio_paths[segment_id]
+         segment_path = os.path.join(
+             stream_dir, f"preview_{timestamp}_seg_{segment_id:04d}.mp4"
+         )
+         save_video_with_audio(
+             frame_buffer,
+             segment_path,
+             segment_audio_path,
+             fps=tgt_fps,
+         )
+         logger.info(
+             f"Saved final segment-{segment_id} ({len(frame_buffer)} chunks) and yielding to frontend."
+         )
+         yield os.path.abspath(segment_path)
+
+     worker_thread.join()
+
+     if not accumulated:
+         raise gr.Error("No video frames generated. Please check inputs and try again.")
+
+     output_dir = "gradio_results"
+     os.makedirs(output_dir, exist_ok=True)
+     final_filename = f"res_{timestamp}.mp4"
+     final_path = os.path.join(output_dir, final_filename)
+     save_video_with_audio(accumulated, final_path, audio_path, fps=tgt_fps)
+     logger.info(f"Saved to {final_path}")
+
+
+ # ---------- Gradio UI ----------
+ with gr.Blocks(title="SoulX-FlashHead Streaming Video Generation", theme=gr.themes.Soft()) as app:
+     gr.Markdown("# ⚡ SoulX-FlashHead Streaming Video Generation")
+     gr.Markdown("Upload an image and an audio clip; the video plays back while it is being generated, with audio and video kept in sync. Currently only single-GPU inference is supported.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             with gr.Group():
+                 gr.Markdown("### 🎬 Generation Inputs")
+                 with gr.Row():
+                     cond_image_input = gr.Image(
+                         label="Condition Image",
+                         type="filepath",
+                         value="examples/girl.png",
+                         height=300,
+                     )
+                     audio_path_input = gr.Audio(
+                         label="Audio Input",
+                         type="filepath",
+                         value="examples/podcast_sichuan_16k.wav",
+                     )
+                 generate_btn = gr.Button("🚀 Generate Video (Streaming)", variant="primary", size="lg")
+             with gr.Accordion("⚙️ Advanced Settings", open=False):
+                 ckpt_dir_input = gr.Textbox(
+                     label="FlashHead Checkpoint Directory",
+                     value="models/SoulX-FlashHead-1_3B",
+                 )
+                 wav2vec_dir_input = gr.Textbox(
+                     label="Wav2Vec Directory",
+                     value="models/wav2vec2-base-960h",
+                 )
+                 model_type_input = gr.Dropdown(
+                     label="Model Type",
+                     choices=["pro", "lite"],
+                     value="lite",
+                 )
+                 use_face_crop_input = gr.Checkbox(label="Use Face Crop", value=False)
+                 seed_input = gr.Number(label="Random Seed", value=9999, precision=0)
+         with gr.Column(scale=1):
+             gr.Markdown("### 📺 Output Video (streamed)")
+             video_output = gr.Video(
+                 label="Generated Video",
+                 height=512,
+                 format="mp4",
+                 streaming=True,
+                 autoplay=True,
+             )
+
+     generate_btn.click(
+         fn=run_inference_streaming,
+         inputs=[
+             ckpt_dir_input,
+             wav2vec_dir_input,
+             model_type_input,
+             cond_image_input,
+             audio_path_input,
+             seed_input,
+             use_face_crop_input,
+         ],
+         outputs=video_output,
+     )
+
+ if __name__ == "__main__":
+     app.launch()
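The app above keeps playback smooth by decoupling generation from delivery: a worker thread pushes each generated chunk into a Queue, while the Gradio callback (a generator) groups every CHUNKS_PER_SEGMENT chunks into a playable MP4 segment and yields its path to the streaming gr.Video component. A stripped-down sketch of that producer/consumer shape follows; generate_chunk is a stand-in for the real run_pipeline call, and segment writing is reduced to yielding the buffered items.

    # Minimal producer/consumer sketch of the streaming pattern used by run_inference_streaming.
    import queue
    import threading
    import time

    CHUNKS_PER_SEGMENT = 3
    TOTAL_CHUNKS = 7

    def generate_chunk(idx):
        time.sleep(0.1)          # pretend to run the model
        return f"frames-{idx}"

    def stream_segments():
        res_queue = queue.Queue()

        def worker():
            for idx in range(TOTAL_CHUNKS):
                res_queue.put((idx, generate_chunk(idx)))
            res_queue.put(None)  # end-of-stream sentinel

        threading.Thread(target=worker, daemon=True).start()

        buffer = []
        while (item := res_queue.get()) is not None:
            buffer.append(item)
            if len(buffer) == CHUNKS_PER_SEGMENT:
                yield buffer     # in the app: write an MP4 segment and yield its path
                buffer = []
        if buffer:
            yield buffer         # final, possibly shorter segment

    for segment in stream_segments():
        print([idx for idx, _ in segment])   # [0, 1, 2], [3, 4, 5], [6]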
requirements.txt ADDED
@@ -0,0 +1,23 @@
+ torch==2.7.1
+ opencv-python>=4.12.0.88
+ opencv-python-headless>=4.12.0.88
+ diffusers>=0.34.0
+ transformers==4.57.3
+ tokenizers>=0.20.3
+ accelerate>=1.8.1
+ tqdm
+ imageio
+ easydict
+ ftfy
+ imageio-ffmpeg
+ scikit-image
+ loguru
+ gradio==5.50.0
+ xfuser>=0.4.3
+ pyloudnorm
+ decord
+ xformers==0.0.31
+ librosa
+ mediapipe==0.10.9
+ flask
+ huggingface_hub