Upload files
- .DS_Store +0 -0
- LICENSE +201 -0
- app.py +6 -0
- download_model_from_hf.py +20 -0
- flash_head/.DS_Store +0 -0
- flash_head/audio_analysis/torch_utils.py +20 -0
- flash_head/audio_analysis/wav2vec2.py +125 -0
- flash_head/configs/infer_params.yaml +10 -0
- flash_head/inference.py +77 -0
- flash_head/ltx_video/.DS_Store +0 -0
- flash_head/ltx_video/__init__.py +0 -0
- flash_head/ltx_video/ltx_vae.py +42 -0
- flash_head/ltx_video/models/__init__.py +0 -0
- flash_head/ltx_video/models/autoencoders/__init__.py +0 -0
- flash_head/ltx_video/models/autoencoders/causal_conv3d.py +63 -0
- flash_head/ltx_video/models/autoencoders/causal_video_autoencoder.py +1412 -0
- flash_head/ltx_video/models/autoencoders/conv_nd_factory.py +90 -0
- flash_head/ltx_video/models/autoencoders/dual_conv3d.py +217 -0
- flash_head/ltx_video/models/autoencoders/pixel_norm.py +12 -0
- flash_head/ltx_video/models/autoencoders/vae.py +380 -0
- flash_head/ltx_video/models/autoencoders/vae_encode.py +256 -0
- flash_head/ltx_video/models/autoencoders/video_autoencoder.py +1045 -0
- flash_head/ltx_video/models/transformers/__init__.py +0 -0
- flash_head/ltx_video/models/transformers/attention.py +1265 -0
- flash_head/ltx_video/models/transformers/embeddings.py +129 -0
- flash_head/ltx_video/models/transformers/symmetric_patchifier.py +84 -0
- flash_head/ltx_video/models/transformers/transformer3d.py +507 -0
- flash_head/ltx_video/utils/__init__.py +0 -0
- flash_head/ltx_video/utils/diffusers_config_mapping.py +174 -0
- flash_head/ltx_video/utils/prompt_enhance_utils.py +226 -0
- flash_head/ltx_video/utils/skip_layer_strategy.py +8 -0
- flash_head/ltx_video/utils/torch_utils.py +25 -0
- flash_head/src/.DS_Store +0 -0
- flash_head/src/distributed/usp_device.py +35 -0
- flash_head/src/modules/flash_head_model.py +548 -0
- flash_head/src/pipeline/flash_head_pipeline.py +316 -0
- flash_head/utils/cpu_face_handler.py +55 -0
- flash_head/utils/facecrop.py +110 -0
- flash_head/utils/utils.py +222 -0
- flash_head/wan/modules/__init__.py +5 -0
- flash_head/wan/modules/vae.py +1598 -0
- generate_video.py +218 -0
- gradio_app_streaming.py +339 -0
- requirements.txt +23 -0
.DS_Store
ADDED
Binary file (6.15 kB)
LICENSE
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
app.py
ADDED
@@ -0,0 +1,6 @@
from gradio_app_streaming import app
from download_model_from_hf import download_model

if __name__ == "__main__":
    download_model("Soul-AILab/SoulX-FlashHead-1_3B", "models")
    app.launch(share=True)
download_model_from_hf.py
ADDED
@@ -0,0 +1,20 @@
from pathlib import Path
from huggingface_hub import snapshot_download

def download_model(model_name, save_dir="models"):
    # Create the local save directory
    save_path = Path(save_dir) / model_name.split("/")[-1]

    if save_path.exists():
        print(f"✅ Model already exists: {save_path}")
        return str(save_path)

    save_path.mkdir(parents=True, exist_ok=True)

    download_path = snapshot_download(
        repo_id=model_name,
        local_dir=save_path,
        local_dir_use_symlinks=False
    )
    print(f"✅ Download finished: {download_path}")
    return download_path
flash_head/.DS_Store
ADDED
Binary file (6.15 kB)
flash_head/audio_analysis/torch_utils.py
ADDED
@@ -0,0 +1,20 @@
import torch
import torch.nn.functional as F


def get_mask_from_lengths(lengths, max_len=None):
    lengths = lengths.to(torch.long)
    if max_len is None:
        max_len = torch.max(lengths).item()

    ids = torch.arange(0, max_len).unsqueeze(0).expand(lengths.shape[0], -1).to(lengths.device)
    mask = ids < lengths.unsqueeze(1).expand(-1, max_len)

    return mask


def linear_interpolation(features, seq_len):
    features = features.transpose(1, 2)
    output_features = F.interpolate(features, size=seq_len, align_corners=True, mode='linear')
    return output_features.transpose(1, 2)
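For reference, a minimal sketch of how these two helpers behave on dummy tensors (the shapes are illustrative only, and the import assumes the repository root is on PYTHONPATH):

import torch
from flash_head.audio_analysis.torch_utils import get_mask_from_lengths, linear_interpolation

# Boolean padding mask for a batch with sequence lengths 3 and 5: shape (2, 5).
mask = get_mask_from_lengths(torch.tensor([3, 5]))

# Resample (batch, time, channels) features from 50 time steps down to 25.
feats = torch.randn(2, 50, 768)
resampled = linear_interpolation(feats, seq_len=25)   # shape (2, 25, 768)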
flash_head/audio_analysis/wav2vec2.py
ADDED
@@ -0,0 +1,125 @@
from transformers import Wav2Vec2Config, Wav2Vec2Model
from transformers.modeling_outputs import BaseModelOutput

from .torch_utils import linear_interpolation

# the implementation of Wav2Vec2Model is borrowed from
# https://github.com/huggingface/transformers/blob/HEAD/src/transformers/models/wav2vec2/modeling_wav2vec2.py
# initialize our encoder with the pre-trained wav2vec 2.0 weights.
class Wav2Vec2Model(Wav2Vec2Model):
    def __init__(self, config: Wav2Vec2Config):
        super().__init__(config)

    def forward(
        self,
        input_values,
        seq_len,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = False

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


    def feature_extract(
        self,
        input_values,
        seq_len,
    ):
        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)
        extract_features = linear_interpolation(extract_features, seq_len=seq_len)

        return extract_features

    def encode(
        self,
        extract_features,
        attention_mask=None,
        mask_time_indices=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        self.config.output_attentions = False

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(
                extract_features.shape[1], attention_mask, add_adapter=False
            )

        hidden_states, extract_features = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(
            hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
        )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if self.adapter is not None:
            hidden_states = self.adapter(hidden_states)

        if not return_dict:
            return (hidden_states, ) + encoder_outputs[1:]
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
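A minimal usage sketch for this wrapper follows. The checkpoint name is a placeholder (the Space loads its own wav2vec weights from the wav2vec_dir passed to the pipeline); the key point is that seq_len resamples the audio features to the number of video frames:

import torch
from flash_head.audio_analysis.wav2vec2 import Wav2Vec2Model

# Placeholder checkpoint; any wav2vec 2.0 checkpoint with compatible config would do.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h").eval()

wav = torch.randn(1, 16000)            # one second of 16 kHz audio
with torch.no_grad():
    out = model(wav, seq_len=25)       # align features to 25 video frames (tgt_fps)
print(out.last_hidden_state.shape)     # (1, 25, hidden_size)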
flash_head/configs/infer_params.yaml
ADDED
@@ -0,0 +1,10 @@
frame_num: 33
motion_frames_latent_num: 2
tgt_fps: 25
sample_rate: 16000
sample_shift: 5
color_correction_strength: 1.0
cached_audio_duration: 8
num_heads: 12
height: 512
width: 512
flash_head/inference.py
ADDED
@@ -0,0 +1,77 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import yaml
import torch
import copy
from loguru import logger

from flash_head.src.pipeline.flash_head_pipeline import FlashHeadPipeline
from flash_head.src.distributed.usp_device import get_device, get_parallel_degree

with open("flash_head/configs/infer_params.yaml", "r") as f:
    infer_params = yaml.safe_load(f)

def get_pipeline(world_size, ckpt_dir, model_type, wav2vec_dir):
    global infer_params
    ulysses_degree, ring_degree = get_parallel_degree(world_size, infer_params['num_heads'])
    device = get_device(ulysses_degree, ring_degree)
    logger.info(f"ulysses_degree: {ulysses_degree}, ring_degree: {ring_degree}, device: {device}")

    pipeline = FlashHeadPipeline(
        checkpoint_dir=ckpt_dir,
        model_type=model_type,
        wav2vec_dir=wav2vec_dir,
        device=device,
        use_usp=(world_size > 1),
    )

    # compute motion_frames_num
    motion_frames_latent_num = infer_params['motion_frames_latent_num']
    motion_frames_num = (motion_frames_latent_num - 1) * pipeline.config.vae_stride[0] + 1
    infer_params['motion_frames_num'] = motion_frames_num

    # TODO: move to args
    if model_type == "pretrained":
        infer_params['sample_steps'] = 20
    else:
        infer_params['sample_steps'] = 4
    return pipeline

def get_base_data(pipeline, cond_image_path_or_dir, base_seed, use_face_crop):
    pipeline.prepare_params(
        cond_image_path_or_dir=cond_image_path_or_dir,
        target_size=(infer_params['height'], infer_params['width']),
        frame_num=infer_params['frame_num'],
        motion_frames_num=infer_params['motion_frames_num'],
        sampling_steps=infer_params['sample_steps'],
        seed=base_seed,
        shift=infer_params['sample_shift'],
        color_correction_strength=infer_params['color_correction_strength'],
        use_face_crop=use_face_crop,
    )

def get_infer_params():
    global infer_params
    return copy.deepcopy(infer_params)

def get_audio_embedding(pipeline, audio_array, audio_start_idx=-1, audio_end_idx=-1):
    # audio_array = loudness_norm(audio_array, infer_params['sample_rate'])
    audio_embedding = pipeline.preprocess_audio(audio_array, sr=infer_params['sample_rate'], fps=infer_params['tgt_fps'])

    if audio_start_idx == -1 or audio_end_idx == -1:
        audio_start_idx = 0
        audio_end_idx = audio_embedding.shape[0]

    indices = (torch.arange(2 * 2 + 1) - 2) * 1

    center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0)
    center_indices = torch.clamp(center_indices, min=0, max=audio_end_idx-1)

    audio_embedding = audio_embedding[center_indices][None,...].contiguous()
    return audio_embedding

def run_pipeline(pipeline, audio_embedding):
    audio_embedding = audio_embedding.to(pipeline.device)
    sample = pipeline.generate(audio_embedding)
    sample_frames = (((sample+1)/2).permute(1,2,3,0).clip(0,1) * 255).contiguous()
    return sample_frames
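The helpers above are meant to be chained roughly as follows. This is a single-process sketch with placeholder paths; loading audio through librosa is an assumption, and the model_type value is simply any string other than "pretrained" so the 4-step sampling branch is taken:

import librosa
from flash_head.inference import get_pipeline, get_base_data, get_audio_embedding, run_pipeline

# Placeholder checkpoint/image/audio locations.
pipeline = get_pipeline(world_size=1, ckpt_dir="models/SoulX-FlashHead-1_3B",
                        model_type="distilled", wav2vec_dir="models/wav2vec2")
get_base_data(pipeline, "examples/face.png", base_seed=42, use_face_crop=True)

audio, _ = librosa.load("examples/speech.wav", sr=16000)   # matches sample_rate in the config
audio_embedding = get_audio_embedding(pipeline, audio)
frames = run_pipeline(pipeline, audio_embedding)           # roughly (frames, H, W, 3), values in [0, 255]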
flash_head/ltx_video/.DS_Store
ADDED
Binary file (6.15 kB)
flash_head/ltx_video/__init__.py
ADDED
File without changes
flash_head/ltx_video/ltx_vae.py
ADDED
@@ -0,0 +1,42 @@
import torch
from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder


class LtxVAE:
    def __init__(
        self,
        pretrained_model_type_or_path,
        dtype = torch.bfloat16,
        device = "cuda",
    ):
        self.model = CausalVideoAutoencoder.from_pretrained(pretrained_model_type_or_path)
        self.model = self.model.eval().requires_grad_(False).to(device).to(dtype)

    # torch.Size([1, 3, 33, 512, 512]) -> torch.Size([128, 5, 16, 16])
    def encode(self, video):
        latents = self.model.encode(video, return_dict=False)[0].sample()
        out = self.normalize_latents(latents)
        return out[0]

    # torch.Size([128, 5, 16, 16]) -> torch.Size([1, 3, 33, 512, 512])
    def decode(self, zs):
        latents = zs.unsqueeze(0)
        image = self.model.decode(
            self.un_normalize_latents(latents),
            return_dict=False,
            target_shape=latents.shape,
        )[0]
        return image

    def normalize_latents(self, latents):
        return (
            (latents - self.model.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
            / self.model.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        )


    def un_normalize_latents(self,latents):
        return (
            latents * self.model.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
            + self.model.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        )
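A round-trip sketch of the shape contract noted in the comments above (the weight path is a placeholder for wherever the VAE checkpoint is stored):

import torch
from flash_head.ltx_video.ltx_vae import LtxVAE

vae = LtxVAE("models/SoulX-FlashHead-1_3B/vae", dtype=torch.bfloat16, device="cuda")

# (B, C, F, H, W) video, expected in [-1, 1]; random values used purely for the shape check.
video = torch.randn(1, 3, 33, 512, 512, dtype=torch.bfloat16, device="cuda")
latents = vae.encode(video)   # (128, 5, 16, 16) per the comment above
recon = vae.decode(latents)   # back to (1, 3, 33, 512, 512)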
flash_head/ltx_video/models/__init__.py
ADDED
File without changes
flash_head/ltx_video/models/autoencoders/__init__.py
ADDED
File without changes
flash_head/ltx_video/models/autoencoders/causal_conv3d.py
ADDED
@@ -0,0 +1,63 @@
from typing import Tuple, Union

import torch
import torch.nn as nn


class CausalConv3d(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: int = 3,
        stride: Union[int, Tuple[int]] = 1,
        dilation: int = 1,
        groups: int = 1,
        spatial_padding_mode: str = "zeros",
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        kernel_size = (kernel_size, kernel_size, kernel_size)
        self.time_kernel_size = kernel_size[0]

        dilation = (dilation, 1, 1)

        height_pad = kernel_size[1] // 2
        width_pad = kernel_size[2] // 2
        padding = (0, height_pad, width_pad)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode=spatial_padding_mode,
            groups=groups,
        )

    def forward(self, x, causal: bool = True):
        if causal:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, self.time_kernel_size - 1, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x), dim=2)
        else:
            first_frame_pad = x[:, :, :1, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            last_frame_pad = x[:, :, -1:, :, :].repeat(
                (1, 1, (self.time_kernel_size - 1) // 2, 1, 1)
            )
            x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2)
        x = self.conv(x)
        return x

    @property
    def weight(self):
        return self.conv.weight
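A small sanity-check sketch (not part of the file): with stride 1 and temporal kernel 3, both padding modes preserve the temporal length, but the causal branch repeats only the first frame, so no future frames leak into the convolution:

import torch
from flash_head.ltx_video.models.autoencoders.causal_conv3d import CausalConv3d

conv = CausalConv3d(in_channels=4, out_channels=8, kernel_size=3)
x = torch.randn(1, 4, 9, 32, 32)        # (B, C, F, H, W)

y_causal = conv(x, causal=True)         # pads the past only (first frame repeated twice)
y_sym = conv(x, causal=False)           # pads one frame at each temporal end
print(y_causal.shape, y_sym.shape)      # both (1, 8, 9, 32, 32) at stride 1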
flash_head/ltx_video/models/autoencoders/causal_video_autoencoder.py
ADDED
@@ -0,0 +1,1412 @@
import json
import os
from functools import partial
from types import SimpleNamespace
from typing import Any, Mapping, Optional, Tuple, Union, List
from pathlib import Path

import torch
import numpy as np
from einops import rearrange
from torch import nn
from diffusers.utils import logging
import torch.nn.functional as F
from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings
from safetensors import safe_open


from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm
from flash_head.ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
from flash_head.ltx_video.models.transformers.attention import Attention
from flash_head.ltx_video.utils.diffusers_config_mapping import (
    diffusers_and_ours_config_mapping,
    make_hashable_key,
    VAE_KEYS_RENAME_DICT,
)

PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics."
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class CausalVideoAutoencoder(AutoencoderKLWrapper):
    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *args,
        **kwargs,
    ):
        pretrained_model_name_or_path = Path(pretrained_model_name_or_path)
        if (
            pretrained_model_name_or_path.is_dir()
            and (pretrained_model_name_or_path / "autoencoder.pth").exists()
        ):
            config_local_path = pretrained_model_name_or_path / "config.json"
            config = cls.load_config(config_local_path, **kwargs)

            model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
            state_dict = torch.load(model_local_path, map_location=torch.device("cpu"))

            statistics_local_path = (
                pretrained_model_name_or_path / "per_channel_statistics.json"
            )
            if statistics_local_path.exists():
                with open(statistics_local_path, "r") as file:
                    data = json.load(file)
                transposed_data = list(zip(*data["data"]))
                data_dict = {
                    col: torch.tensor(vals)
                    for col, vals in zip(data["columns"], transposed_data)
                }
                std_of_means = data_dict["std-of-means"]
                mean_of_means = data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = (
                    std_of_means
                )
                state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = (
                    mean_of_means
                )

        elif pretrained_model_name_or_path.is_dir():
            config_path = pretrained_model_name_or_path / "config.json"
            with open(config_path, "r") as f:
                config = make_hashable_key(json.load(f))

            assert config in diffusers_and_ours_config_mapping, (
                "Provided diffusers checkpoint config for VAE is not suppported. "
                "We only support diffusers configs found in Lightricks/LTX-Video."
            )

            config = diffusers_and_ours_config_mapping[config]

            state_dict_path = (
                pretrained_model_name_or_path
                / "diffusion_pytorch_model.safetensors"
            )

            state_dict = {}
            with safe_open(state_dict_path, framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            for key in list(state_dict.keys()):
                new_key = key
                for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
                    new_key = new_key.replace(replace_key, rename_key)

                state_dict[new_key] = state_dict.pop(key)

        elif pretrained_model_name_or_path.is_file() and str(
            pretrained_model_name_or_path
        ).endswith(".safetensors"):
            state_dict = {}
            with safe_open(
                pretrained_model_name_or_path, framework="pt", device="cpu"
            ) as f:
                metadata = f.metadata()
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)
            configs = json.loads(metadata["config"])
            config = configs["vae"]

        video_vae = cls.from_config(config)
        if "torch_dtype" in kwargs:
            video_vae.to(kwargs["torch_dtype"])
        video_vae.load_state_dict(state_dict)
        return video_vae

    @staticmethod
    def from_config(config):
        assert (
            config["_class_name"] == "CausalVideoAutoencoder"
        ), "config must have _class_name=CausalVideoAutoencoder"
        if isinstance(config["dims"], list):
            config["dims"] = tuple(config["dims"])

        assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"

        double_z = config.get("double_z", True)
        latent_log_var = config.get(
            "latent_log_var", "per_channel" if double_z else "none"
        )
        use_quant_conv = config.get("use_quant_conv", True)
        normalize_latent_channels = config.get("normalize_latent_channels", False)

        if use_quant_conv and latent_log_var in ["uniform", "constant"]:
            raise ValueError(
                f"latent_log_var={latent_log_var} requires use_quant_conv=False"
            )

        encoder = Encoder(
            dims=config["dims"],
            in_channels=config.get("in_channels", 3),
            out_channels=config["latent_channels"],
            blocks=config.get("encoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            latent_log_var=latent_log_var,
            norm_layer=config.get("norm_layer", "group_norm"),
            base_channels=config.get("encoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        decoder = Decoder(
            dims=config["dims"],
            in_channels=config["latent_channels"],
            out_channels=config.get("out_channels", 3),
            blocks=config.get("decoder_blocks", config.get("blocks")),
            patch_size=config.get("patch_size", 1),
            norm_layer=config.get("norm_layer", "group_norm"),
            causal=config.get("causal_decoder", False),
            timestep_conditioning=config.get("timestep_conditioning", False),
            base_channels=config.get("decoder_base_channels", 128),
            spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
        )

        dims = config["dims"]
        return CausalVideoAutoencoder(
            encoder=encoder,
            decoder=decoder,
            latent_channels=config["latent_channels"],
            dims=dims,
            use_quant_conv=use_quant_conv,
            normalize_latent_channels=normalize_latent_channels,
        )

    @property
    def config(self):
        return SimpleNamespace(
            _class_name="CausalVideoAutoencoder",
            dims=self.dims,
            in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2,
            out_channels=self.decoder.conv_out.out_channels
            // self.decoder.patch_size**2,
            latent_channels=self.decoder.conv_in.in_channels,
            encoder_blocks=self.encoder.blocks_desc,
            decoder_blocks=self.decoder.blocks_desc,
            scaling_factor=1.0,
            norm_layer=self.encoder.norm_layer,
            patch_size=self.encoder.patch_size,
            latent_log_var=self.encoder.latent_log_var,
            use_quant_conv=self.use_quant_conv,
            causal_decoder=self.decoder.causal,
            timestep_conditioning=self.decoder.timestep_conditioning,
            normalize_latent_channels=self.normalize_latent_channels,
        )

    @property
    def is_video_supported(self):
        """
        Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
        """
        return self.dims != 2

    @property
    def spatial_downscale_factor(self):
        return (
            2
            ** len(
                [
                    block
                    for block in self.encoder.blocks_desc
                    if block[0]
                    in [
                        "compress_space",
                        "compress_all",
                        "compress_all_res",
                        "compress_space_res",
                    ]
                ]
            )
            * self.encoder.patch_size
        )

    @property
    def temporal_downscale_factor(self):
        return 2 ** len(
            [
                block
                for block in self.encoder.blocks_desc
                if block[0]
                in [
                    "compress_time",
                    "compress_all",
                    "compress_all_res",
                    "compress_space_res",
                ]
            ]
        )

    def to_json_string(self) -> str:
        import json

        return json.dumps(self.config.__dict__)

    def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
        if any([key.startswith("vae.") for key in state_dict.keys()]):
            state_dict = {
                key.replace("vae.", ""): value
                for key, value in state_dict.items()
                if key.startswith("vae.")
            }
        ckpt_state_dict = {
            key: value
            for key, value in state_dict.items()
            if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }

        model_keys = set(name for name, _ in self.named_modules())

        key_mapping = {
            ".resnets.": ".res_blocks.",
            "downsamplers.0": "downsample",
            "upsamplers.0": "upsample",
        }
        converted_state_dict = {}
        for key, value in ckpt_state_dict.items():
            for k, v in key_mapping.items():
                key = key.replace(k, v)

            key_prefix = ".".join(key.split(".")[:-1])
            if "norm" in key and key_prefix not in model_keys:
                logger.info(
                    f"Removing key {key} from state_dict as it is not present in the model"
                )
                continue

            converted_state_dict[key] = value

        super().load_state_dict(converted_state_dict, strict=strict)

        data_dict = {
            key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value
            for key, value in state_dict.items()
            if key.startswith(PER_CHANNEL_STATISTICS_PREFIX)
        }
        if len(data_dict) > 0:
            self.register_buffer("std_of_means", data_dict["std-of-means"])
            self.register_buffer(
                "mean_of_means",
                data_dict.get(
                    "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
                ),
            )

    def last_layer(self):
        if hasattr(self.decoder, "conv_out"):
            if isinstance(self.decoder.conv_out, nn.Sequential):
                last_layer = self.decoder.conv_out[-1]
            else:
                last_layer = self.decoder.conv_out
        else:
            last_layer = self.decoder.layers[-1]
        return last_layer

    def set_use_tpu_flash_attention(self):
        for block in self.decoder.up_blocks:
            if isinstance(block, UNetMidBlock3D) and block.attention_blocks:
                for attention_block in block.attention_blocks:
                    attention_block.set_use_tpu_flash_attention()


class Encoder(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.

    Args:
        dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
            The number of dimensions to use in convolutions.
        in_channels (`int`, *optional*, defaults to 3):
            The number of input channels.
        out_channels (`int`, *optional*, defaults to 3):
            The number of output channels.
        blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
            The blocks to use. Each block is a tuple of the block name and the number of layers.
        base_channels (`int`, *optional*, defaults to 128):
            The number of output channels for the first convolutional layer.
        norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups for normalization.
        patch_size (`int`, *optional*, defaults to 1):
            The patch size to use. Should be a power of 2.
        norm_layer (`str`, *optional*, defaults to `group_norm`):
            The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
        latent_log_var (`str`, *optional*, defaults to `per_channel`):
            The number of channels for the log variance. Can be either `per_channel`, `uniform`, `constant` or `none`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]] = 3,
        in_channels: int = 3,
        out_channels: int = 3,
        blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
        base_channels: int = 128,
        norm_num_groups: int = 32,
        patch_size: Union[int, Tuple[int]] = 1,
        norm_layer: str = "group_norm",  # group_norm, pixel_norm
        latent_log_var: str = "per_channel",
        spatial_padding_mode: str = "zeros",
    ):
        super().__init__()
        self.patch_size = patch_size
        self.norm_layer = norm_layer
        self.latent_channels = out_channels
        self.latent_log_var = latent_log_var
        self.blocks_desc = blocks

        in_channels = in_channels * patch_size**2
        output_channel = base_channels

        self.conv_in = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=output_channel,
            kernel_size=3,
            stride=1,
            padding=1,
            causal=True,
            spatial_padding_mode=spatial_padding_mode,
        )

        self.down_blocks = nn.ModuleList([])

        for block_name, block_params in blocks:
            input_channel = output_channel
            if isinstance(block_params, int):
                block_params = {"num_layers": block_params}

            if block_name == "res_x":
                block = UNetMidBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    num_layers=block_params["num_layers"],
                    resnet_eps=1e-6,
                    resnet_groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "res_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    eps=1e-6,
                    groups=norm_num_groups,
                    norm_layer=norm_layer,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 1, 1),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(1, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all":
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_x_y":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = make_conv_nd(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    kernel_size=3,
                    stride=(2, 2, 2),
                    causal=True,
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_all_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(2, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_space_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
                    out_channels=output_channel,
                    stride=(1, 2, 2),
                    spatial_padding_mode=spatial_padding_mode,
                )
            elif block_name == "compress_time_res":
                output_channel = block_params.get("multiplier", 2) * output_channel
                block = SpaceToDepthDownsample(
                    dims=dims,
                    in_channels=input_channel,
|
| 464 |
+
out_channels=output_channel,
|
| 465 |
+
stride=(2, 1, 1),
|
| 466 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 467 |
+
)
|
| 468 |
+
else:
|
| 469 |
+
raise ValueError(f"unknown block: {block_name}")
|
| 470 |
+
|
| 471 |
+
self.down_blocks.append(block)
|
| 472 |
+
|
| 473 |
+
# out
|
| 474 |
+
if norm_layer == "group_norm":
|
| 475 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 476 |
+
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
| 477 |
+
)
|
| 478 |
+
elif norm_layer == "pixel_norm":
|
| 479 |
+
self.conv_norm_out = PixelNorm()
|
| 480 |
+
elif norm_layer == "layer_norm":
|
| 481 |
+
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
| 482 |
+
|
| 483 |
+
self.conv_act = nn.SiLU()
|
| 484 |
+
|
| 485 |
+
conv_out_channels = out_channels
|
| 486 |
+
if latent_log_var == "per_channel":
|
| 487 |
+
conv_out_channels *= 2
|
| 488 |
+
elif latent_log_var == "uniform":
|
| 489 |
+
conv_out_channels += 1
|
| 490 |
+
elif latent_log_var == "constant":
|
| 491 |
+
conv_out_channels += 1
|
| 492 |
+
elif latent_log_var != "none":
|
| 493 |
+
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 494 |
+
self.conv_out = make_conv_nd(
|
| 495 |
+
dims,
|
| 496 |
+
output_channel,
|
| 497 |
+
conv_out_channels,
|
| 498 |
+
3,
|
| 499 |
+
padding=1,
|
| 500 |
+
causal=True,
|
| 501 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 502 |
+
)
|
| 503 |
+
|
| 504 |
+
self.gradient_checkpointing = False
|
| 505 |
+
|
| 506 |
+
def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
| 507 |
+
r"""The forward method of the `Encoder` class."""
|
| 508 |
+
|
| 509 |
+
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
| 510 |
+
sample = self.conv_in(sample)
|
| 511 |
+
|
| 512 |
+
checkpoint_fn = (
|
| 513 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 514 |
+
if self.gradient_checkpointing and self.training
|
| 515 |
+
else lambda x: x
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
for down_block in self.down_blocks:
|
| 519 |
+
sample = checkpoint_fn(down_block)(sample)
|
| 520 |
+
|
| 521 |
+
sample = self.conv_norm_out(sample)
|
| 522 |
+
sample = self.conv_act(sample)
|
| 523 |
+
sample = self.conv_out(sample)
|
| 524 |
+
|
| 525 |
+
if self.latent_log_var == "uniform":
|
| 526 |
+
last_channel = sample[:, -1:, ...]
|
| 527 |
+
num_dims = sample.dim()
|
| 528 |
+
|
| 529 |
+
if num_dims == 4:
|
| 530 |
+
# For shape (B, C, H, W)
|
| 531 |
+
repeated_last_channel = last_channel.repeat(
|
| 532 |
+
1, sample.shape[1] - 2, 1, 1
|
| 533 |
+
)
|
| 534 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 535 |
+
elif num_dims == 5:
|
| 536 |
+
# For shape (B, C, F, H, W)
|
| 537 |
+
repeated_last_channel = last_channel.repeat(
|
| 538 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
| 539 |
+
)
|
| 540 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 541 |
+
else:
|
| 542 |
+
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 543 |
+
elif self.latent_log_var == "constant":
|
| 544 |
+
sample = sample[:, :-1, ...]
|
| 545 |
+
approx_ln_0 = (
|
| 546 |
+
-30
|
| 547 |
+
) # this is the minimal clamp value in DiagonalGaussianDistribution objects
|
| 548 |
+
sample = torch.cat(
|
| 549 |
+
[sample, torch.ones_like(sample, device=sample.device) * approx_ln_0],
|
| 550 |
+
dim=1,
|
| 551 |
+
)
|
| 552 |
+
|
| 553 |
+
return sample
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
class Decoder(nn.Module):
|
| 557 |
+
r"""
|
| 558 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3):
|
| 562 |
+
The number of dimensions to use in convolutions.
|
| 563 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 564 |
+
The number of input channels.
|
| 565 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 566 |
+
The number of output channels.
|
| 567 |
+
blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`):
|
| 568 |
+
The blocks to use. Each block is a tuple of the block name and the number of layers.
|
| 569 |
+
base_channels (`int`, *optional*, defaults to 128):
|
| 570 |
+
The number of output channels for the first convolutional layer.
|
| 571 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 572 |
+
The number of groups for normalization.
|
| 573 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 574 |
+
The patch size to use. Should be a power of 2.
|
| 575 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 576 |
+
The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
|
| 577 |
+
causal (`bool`, *optional*, defaults to `True`):
|
| 578 |
+
Whether to use causal convolutions or not.
|
| 579 |
+
"""
|
| 580 |
+
|
| 581 |
+
def __init__(
|
| 582 |
+
self,
|
| 583 |
+
dims,
|
| 584 |
+
in_channels: int = 3,
|
| 585 |
+
out_channels: int = 3,
|
| 586 |
+
blocks: List[Tuple[str, int | dict]] = [("res_x", 1)],
|
| 587 |
+
base_channels: int = 128,
|
| 588 |
+
layers_per_block: int = 2,
|
| 589 |
+
norm_num_groups: int = 32,
|
| 590 |
+
patch_size: int = 1,
|
| 591 |
+
norm_layer: str = "group_norm",
|
| 592 |
+
causal: bool = True,
|
| 593 |
+
timestep_conditioning: bool = False,
|
| 594 |
+
spatial_padding_mode: str = "zeros",
|
| 595 |
+
):
|
| 596 |
+
super().__init__()
|
| 597 |
+
self.patch_size = patch_size
|
| 598 |
+
self.layers_per_block = layers_per_block
|
| 599 |
+
out_channels = out_channels * patch_size**2
|
| 600 |
+
self.causal = causal
|
| 601 |
+
self.blocks_desc = blocks
|
| 602 |
+
|
| 603 |
+
# Compute output channel to be product of all channel-multiplier blocks
|
| 604 |
+
output_channel = base_channels
|
| 605 |
+
for block_name, block_params in list(reversed(blocks)):
|
| 606 |
+
block_params = block_params if isinstance(block_params, dict) else {}
|
| 607 |
+
if block_name == "res_x_y":
|
| 608 |
+
output_channel = output_channel * block_params.get("multiplier", 2)
|
| 609 |
+
if block_name == "compress_all":
|
| 610 |
+
output_channel = output_channel * block_params.get("multiplier", 1)
|
| 611 |
+
|
| 612 |
+
self.conv_in = make_conv_nd(
|
| 613 |
+
dims,
|
| 614 |
+
in_channels,
|
| 615 |
+
output_channel,
|
| 616 |
+
kernel_size=3,
|
| 617 |
+
stride=1,
|
| 618 |
+
padding=1,
|
| 619 |
+
causal=True,
|
| 620 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 621 |
+
)
|
| 622 |
+
|
| 623 |
+
self.up_blocks = nn.ModuleList([])
|
| 624 |
+
|
| 625 |
+
for block_name, block_params in list(reversed(blocks)):
|
| 626 |
+
input_channel = output_channel
|
| 627 |
+
if isinstance(block_params, int):
|
| 628 |
+
block_params = {"num_layers": block_params}
|
| 629 |
+
|
| 630 |
+
if block_name == "res_x":
|
| 631 |
+
block = UNetMidBlock3D(
|
| 632 |
+
dims=dims,
|
| 633 |
+
in_channels=input_channel,
|
| 634 |
+
num_layers=block_params["num_layers"],
|
| 635 |
+
resnet_eps=1e-6,
|
| 636 |
+
resnet_groups=norm_num_groups,
|
| 637 |
+
norm_layer=norm_layer,
|
| 638 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 639 |
+
timestep_conditioning=timestep_conditioning,
|
| 640 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 641 |
+
)
|
| 642 |
+
elif block_name == "attn_res_x":
|
| 643 |
+
block = UNetMidBlock3D(
|
| 644 |
+
dims=dims,
|
| 645 |
+
in_channels=input_channel,
|
| 646 |
+
num_layers=block_params["num_layers"],
|
| 647 |
+
resnet_groups=norm_num_groups,
|
| 648 |
+
norm_layer=norm_layer,
|
| 649 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 650 |
+
timestep_conditioning=timestep_conditioning,
|
| 651 |
+
attention_head_dim=block_params["attention_head_dim"],
|
| 652 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 653 |
+
)
|
| 654 |
+
elif block_name == "res_x_y":
|
| 655 |
+
output_channel = output_channel // block_params.get("multiplier", 2)
|
| 656 |
+
block = ResnetBlock3D(
|
| 657 |
+
dims=dims,
|
| 658 |
+
in_channels=input_channel,
|
| 659 |
+
out_channels=output_channel,
|
| 660 |
+
eps=1e-6,
|
| 661 |
+
groups=norm_num_groups,
|
| 662 |
+
norm_layer=norm_layer,
|
| 663 |
+
inject_noise=block_params.get("inject_noise", False),
|
| 664 |
+
timestep_conditioning=False,
|
| 665 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 666 |
+
)
|
| 667 |
+
elif block_name == "compress_time":
|
| 668 |
+
block = DepthToSpaceUpsample(
|
| 669 |
+
dims=dims,
|
| 670 |
+
in_channels=input_channel,
|
| 671 |
+
stride=(2, 1, 1),
|
| 672 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 673 |
+
)
|
| 674 |
+
elif block_name == "compress_space":
|
| 675 |
+
block = DepthToSpaceUpsample(
|
| 676 |
+
dims=dims,
|
| 677 |
+
in_channels=input_channel,
|
| 678 |
+
stride=(1, 2, 2),
|
| 679 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 680 |
+
)
|
| 681 |
+
elif block_name == "compress_all":
|
| 682 |
+
output_channel = output_channel // block_params.get("multiplier", 1)
|
| 683 |
+
block = DepthToSpaceUpsample(
|
| 684 |
+
dims=dims,
|
| 685 |
+
in_channels=input_channel,
|
| 686 |
+
stride=(2, 2, 2),
|
| 687 |
+
residual=block_params.get("residual", False),
|
| 688 |
+
out_channels_reduction_factor=block_params.get("multiplier", 1),
|
| 689 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 690 |
+
)
|
| 691 |
+
else:
|
| 692 |
+
raise ValueError(f"unknown layer: {block_name}")
|
| 693 |
+
|
| 694 |
+
self.up_blocks.append(block)
|
| 695 |
+
|
| 696 |
+
if norm_layer == "group_norm":
|
| 697 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 698 |
+
num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6
|
| 699 |
+
)
|
| 700 |
+
elif norm_layer == "pixel_norm":
|
| 701 |
+
self.conv_norm_out = PixelNorm()
|
| 702 |
+
elif norm_layer == "layer_norm":
|
| 703 |
+
self.conv_norm_out = LayerNorm(output_channel, eps=1e-6)
|
| 704 |
+
|
| 705 |
+
self.conv_act = nn.SiLU()
|
| 706 |
+
self.conv_out = make_conv_nd(
|
| 707 |
+
dims,
|
| 708 |
+
output_channel,
|
| 709 |
+
out_channels,
|
| 710 |
+
3,
|
| 711 |
+
padding=1,
|
| 712 |
+
causal=True,
|
| 713 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 714 |
+
)
|
| 715 |
+
|
| 716 |
+
self.gradient_checkpointing = False
|
| 717 |
+
|
| 718 |
+
self.timestep_conditioning = timestep_conditioning
|
| 719 |
+
|
| 720 |
+
if timestep_conditioning:
|
| 721 |
+
self.timestep_scale_multiplier = nn.Parameter(
|
| 722 |
+
torch.tensor(1000.0, dtype=torch.float32)
|
| 723 |
+
)
|
| 724 |
+
self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
| 725 |
+
output_channel * 2, 0
|
| 726 |
+
)
|
| 727 |
+
self.last_scale_shift_table = nn.Parameter(
|
| 728 |
+
torch.randn(2, output_channel) / output_channel**0.5
|
| 729 |
+
)
|
| 730 |
+
|
| 731 |
+
def forward(
|
| 732 |
+
self,
|
| 733 |
+
sample: torch.FloatTensor,
|
| 734 |
+
target_shape,
|
| 735 |
+
timestep: Optional[torch.Tensor] = None,
|
| 736 |
+
) -> torch.FloatTensor:
|
| 737 |
+
r"""The forward method of the `Decoder` class."""
|
| 738 |
+
assert target_shape is not None, "target_shape must be provided"
|
| 739 |
+
batch_size = sample.shape[0]
|
| 740 |
+
|
| 741 |
+
sample = self.conv_in(sample, causal=self.causal)
|
| 742 |
+
|
| 743 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 744 |
+
|
| 745 |
+
checkpoint_fn = (
|
| 746 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 747 |
+
if self.gradient_checkpointing and self.training
|
| 748 |
+
else lambda x: x
|
| 749 |
+
)
|
| 750 |
+
|
| 751 |
+
sample = sample.to(upscale_dtype)
|
| 752 |
+
|
| 753 |
+
if self.timestep_conditioning:
|
| 754 |
+
assert (
|
| 755 |
+
timestep is not None
|
| 756 |
+
), "should pass timestep with timestep_conditioning=True"
|
| 757 |
+
scaled_timestep = timestep * self.timestep_scale_multiplier
|
| 758 |
+
|
| 759 |
+
for up_block in self.up_blocks:
|
| 760 |
+
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
| 761 |
+
sample = checkpoint_fn(up_block)(
|
| 762 |
+
sample, causal=self.causal, timestep=scaled_timestep
|
| 763 |
+
)
|
| 764 |
+
else:
|
| 765 |
+
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
| 766 |
+
|
| 767 |
+
sample = self.conv_norm_out(sample)
|
| 768 |
+
|
| 769 |
+
if self.timestep_conditioning:
|
| 770 |
+
embedded_timestep = self.last_time_embedder(
|
| 771 |
+
timestep=scaled_timestep.flatten(),
|
| 772 |
+
resolution=None,
|
| 773 |
+
aspect_ratio=None,
|
| 774 |
+
batch_size=sample.shape[0],
|
| 775 |
+
hidden_dtype=sample.dtype,
|
| 776 |
+
)
|
| 777 |
+
embedded_timestep = embedded_timestep.view(
|
| 778 |
+
batch_size, embedded_timestep.shape[-1], 1, 1, 1
|
| 779 |
+
)
|
| 780 |
+
ada_values = self.last_scale_shift_table[
|
| 781 |
+
None, ..., None, None, None
|
| 782 |
+
] + embedded_timestep.reshape(
|
| 783 |
+
batch_size,
|
| 784 |
+
2,
|
| 785 |
+
-1,
|
| 786 |
+
embedded_timestep.shape[-3],
|
| 787 |
+
embedded_timestep.shape[-2],
|
| 788 |
+
embedded_timestep.shape[-1],
|
| 789 |
+
)
|
| 790 |
+
shift, scale = ada_values.unbind(dim=1)
|
| 791 |
+
sample = sample * (1 + scale) + shift
|
| 792 |
+
|
| 793 |
+
sample = self.conv_act(sample)
|
| 794 |
+
sample = self.conv_out(sample, causal=self.causal)
|
| 795 |
+
|
| 796 |
+
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
| 797 |
+
|
| 798 |
+
return sample
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
class UNetMidBlock3D(nn.Module):
|
| 802 |
+
"""
|
| 803 |
+
A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.
|
| 804 |
+
|
| 805 |
+
Args:
|
| 806 |
+
in_channels (`int`): The number of input channels.
|
| 807 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
|
| 808 |
+
num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
|
| 809 |
+
resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
|
| 810 |
+
resnet_groups (`int`, *optional*, defaults to 32):
|
| 811 |
+
The number of groups to use in the group normalization layers of the resnet blocks.
|
| 812 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 813 |
+
The normalization layer to use. Can be `group_norm`, `pixel_norm` or `layer_norm`.
|
| 814 |
+
inject_noise (`bool`, *optional*, defaults to `False`):
|
| 815 |
+
Whether to inject noise into the hidden states.
|
| 816 |
+
timestep_conditioning (`bool`, *optional*, defaults to `False`):
|
| 817 |
+
Whether to condition the hidden states on the timestep.
|
| 818 |
+
attention_head_dim (`int`, *optional*, defaults to -1):
|
| 819 |
+
The dimension of the attention head. If -1, no attention is used.
|
| 820 |
+
|
| 821 |
+
Returns:
|
| 822 |
+
`torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
|
| 823 |
+
in_channels, frames, height, width)`.
|
| 824 |
+
|
| 825 |
+
"""
|
| 826 |
+
|
| 827 |
+
def __init__(
|
| 828 |
+
self,
|
| 829 |
+
dims: Union[int, Tuple[int, int]],
|
| 830 |
+
in_channels: int,
|
| 831 |
+
dropout: float = 0.0,
|
| 832 |
+
num_layers: int = 1,
|
| 833 |
+
resnet_eps: float = 1e-6,
|
| 834 |
+
resnet_groups: int = 32,
|
| 835 |
+
norm_layer: str = "group_norm",
|
| 836 |
+
inject_noise: bool = False,
|
| 837 |
+
timestep_conditioning: bool = False,
|
| 838 |
+
attention_head_dim: int = -1,
|
| 839 |
+
spatial_padding_mode: str = "zeros",
|
| 840 |
+
):
|
| 841 |
+
super().__init__()
|
| 842 |
+
resnet_groups = (
|
| 843 |
+
resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
| 844 |
+
)
|
| 845 |
+
self.timestep_conditioning = timestep_conditioning
|
| 846 |
+
|
| 847 |
+
if timestep_conditioning:
|
| 848 |
+
self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(
|
| 849 |
+
in_channels * 4, 0
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
self.res_blocks = nn.ModuleList(
|
| 853 |
+
[
|
| 854 |
+
ResnetBlock3D(
|
| 855 |
+
dims=dims,
|
| 856 |
+
in_channels=in_channels,
|
| 857 |
+
out_channels=in_channels,
|
| 858 |
+
eps=resnet_eps,
|
| 859 |
+
groups=resnet_groups,
|
| 860 |
+
dropout=dropout,
|
| 861 |
+
norm_layer=norm_layer,
|
| 862 |
+
inject_noise=inject_noise,
|
| 863 |
+
timestep_conditioning=timestep_conditioning,
|
| 864 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 865 |
+
)
|
| 866 |
+
for _ in range(num_layers)
|
| 867 |
+
]
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
self.attention_blocks = None
|
| 871 |
+
|
| 872 |
+
if attention_head_dim > 0:
|
| 873 |
+
if attention_head_dim > in_channels:
|
| 874 |
+
raise ValueError(
|
| 875 |
+
"attention_head_dim must be less than or equal to in_channels"
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
self.attention_blocks = nn.ModuleList(
|
| 879 |
+
[
|
| 880 |
+
Attention(
|
| 881 |
+
query_dim=in_channels,
|
| 882 |
+
heads=in_channels // attention_head_dim,
|
| 883 |
+
dim_head=attention_head_dim,
|
| 884 |
+
bias=True,
|
| 885 |
+
out_bias=True,
|
| 886 |
+
qk_norm="rms_norm",
|
| 887 |
+
residual_connection=True,
|
| 888 |
+
)
|
| 889 |
+
for _ in range(num_layers)
|
| 890 |
+
]
|
| 891 |
+
)
|
| 892 |
+
|
| 893 |
+
def forward(
|
| 894 |
+
self,
|
| 895 |
+
hidden_states: torch.FloatTensor,
|
| 896 |
+
causal: bool = True,
|
| 897 |
+
timestep: Optional[torch.Tensor] = None,
|
| 898 |
+
) -> torch.FloatTensor:
|
| 899 |
+
timestep_embed = None
|
| 900 |
+
if self.timestep_conditioning:
|
| 901 |
+
assert (
|
| 902 |
+
timestep is not None
|
| 903 |
+
), "should pass timestep with timestep_conditioning=True"
|
| 904 |
+
batch_size = hidden_states.shape[0]
|
| 905 |
+
timestep_embed = self.time_embedder(
|
| 906 |
+
timestep=timestep.flatten(),
|
| 907 |
+
resolution=None,
|
| 908 |
+
aspect_ratio=None,
|
| 909 |
+
batch_size=batch_size,
|
| 910 |
+
hidden_dtype=hidden_states.dtype,
|
| 911 |
+
)
|
| 912 |
+
timestep_embed = timestep_embed.view(
|
| 913 |
+
batch_size, timestep_embed.shape[-1], 1, 1, 1
|
| 914 |
+
)
|
| 915 |
+
|
| 916 |
+
if self.attention_blocks:
|
| 917 |
+
for resnet, attention in zip(self.res_blocks, self.attention_blocks):
|
| 918 |
+
hidden_states = resnet(
|
| 919 |
+
hidden_states, causal=causal, timestep=timestep_embed
|
| 920 |
+
)
|
| 921 |
+
|
| 922 |
+
# Reshape the hidden states to be (batch_size, frames * height * width, channel)
|
| 923 |
+
batch_size, channel, frames, height, width = hidden_states.shape
|
| 924 |
+
hidden_states = hidden_states.view(
|
| 925 |
+
batch_size, channel, frames * height * width
|
| 926 |
+
).transpose(1, 2)
|
| 927 |
+
|
| 928 |
+
if attention.use_tpu_flash_attention:
|
| 929 |
+
# Pad the second dimension to be divisible by block_k_major (block in flash attention)
|
| 930 |
+
seq_len = hidden_states.shape[1]
|
| 931 |
+
block_k_major = 512
|
| 932 |
+
pad_len = (block_k_major - seq_len % block_k_major) % block_k_major
|
| 933 |
+
if pad_len > 0:
|
| 934 |
+
hidden_states = F.pad(
|
| 935 |
+
hidden_states, (0, 0, 0, pad_len), "constant", 0
|
| 936 |
+
)
|
| 937 |
+
|
| 938 |
+
# Create a mask with ones for the original sequence length and zeros for the padded indexes
|
| 939 |
+
mask = torch.ones(
|
| 940 |
+
(hidden_states.shape[0], seq_len),
|
| 941 |
+
device=hidden_states.device,
|
| 942 |
+
dtype=hidden_states.dtype,
|
| 943 |
+
)
|
| 944 |
+
if pad_len > 0:
|
| 945 |
+
mask = F.pad(mask, (0, pad_len), "constant", 0)
|
| 946 |
+
|
| 947 |
+
hidden_states = attention(
|
| 948 |
+
hidden_states,
|
| 949 |
+
attention_mask=(
|
| 950 |
+
None if not attention.use_tpu_flash_attention else mask
|
| 951 |
+
),
|
| 952 |
+
)
|
| 953 |
+
|
| 954 |
+
if attention.use_tpu_flash_attention:
|
| 955 |
+
# Remove the padding
|
| 956 |
+
if pad_len > 0:
|
| 957 |
+
hidden_states = hidden_states[:, :-pad_len, :]
|
| 958 |
+
|
| 959 |
+
# Reshape the hidden states back to (batch_size, channel, frames, height, width)
|
| 960 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 961 |
+
batch_size, channel, frames, height, width
|
| 962 |
+
)
|
| 963 |
+
else:
|
| 964 |
+
for resnet in self.res_blocks:
|
| 965 |
+
hidden_states = resnet(
|
| 966 |
+
hidden_states, causal=causal, timestep=timestep_embed
|
| 967 |
+
)
|
| 968 |
+
|
| 969 |
+
return hidden_states
|
| 970 |
+
|
| 971 |
+
|
| 972 |
+
class SpaceToDepthDownsample(nn.Module):
|
| 973 |
+
def __init__(self, dims, in_channels, out_channels, stride, spatial_padding_mode):
|
| 974 |
+
super().__init__()
|
| 975 |
+
self.stride = stride
|
| 976 |
+
self.group_size = in_channels * np.prod(stride) // out_channels
|
| 977 |
+
self.conv = make_conv_nd(
|
| 978 |
+
dims=dims,
|
| 979 |
+
in_channels=in_channels,
|
| 980 |
+
out_channels=out_channels // np.prod(stride),
|
| 981 |
+
kernel_size=3,
|
| 982 |
+
stride=1,
|
| 983 |
+
causal=True,
|
| 984 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 985 |
+
)
|
| 986 |
+
|
| 987 |
+
def forward(self, x, causal: bool = True):
|
| 988 |
+
if self.stride[0] == 2:
|
| 989 |
+
x = torch.cat(
|
| 990 |
+
[x[:, :, :1, :, :], x], dim=2
|
| 991 |
+
) # duplicate first frames for padding
|
| 992 |
+
|
| 993 |
+
# skip connection
|
| 994 |
+
x_in = rearrange(
|
| 995 |
+
x,
|
| 996 |
+
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
| 997 |
+
p1=self.stride[0],
|
| 998 |
+
p2=self.stride[1],
|
| 999 |
+
p3=self.stride[2],
|
| 1000 |
+
)
|
| 1001 |
+
x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size)
|
| 1002 |
+
x_in = x_in.mean(dim=2)
|
| 1003 |
+
|
| 1004 |
+
# conv
|
| 1005 |
+
x = self.conv(x, causal=causal)
|
| 1006 |
+
x = rearrange(
|
| 1007 |
+
x,
|
| 1008 |
+
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
| 1009 |
+
p1=self.stride[0],
|
| 1010 |
+
p2=self.stride[1],
|
| 1011 |
+
p3=self.stride[2],
|
| 1012 |
+
)
|
| 1013 |
+
|
| 1014 |
+
x = x + x_in
|
| 1015 |
+
|
| 1016 |
+
return x
|
| 1017 |
+
|
| 1018 |
+
|
| 1019 |
+
class DepthToSpaceUpsample(nn.Module):
|
| 1020 |
+
def __init__(
|
| 1021 |
+
self,
|
| 1022 |
+
dims,
|
| 1023 |
+
in_channels,
|
| 1024 |
+
stride,
|
| 1025 |
+
residual=False,
|
| 1026 |
+
out_channels_reduction_factor=1,
|
| 1027 |
+
spatial_padding_mode="zeros",
|
| 1028 |
+
):
|
| 1029 |
+
super().__init__()
|
| 1030 |
+
self.stride = stride
|
| 1031 |
+
self.out_channels = (
|
| 1032 |
+
np.prod(stride) * in_channels // out_channels_reduction_factor
|
| 1033 |
+
)
|
| 1034 |
+
self.conv = make_conv_nd(
|
| 1035 |
+
dims=dims,
|
| 1036 |
+
in_channels=in_channels,
|
| 1037 |
+
out_channels=self.out_channels,
|
| 1038 |
+
kernel_size=3,
|
| 1039 |
+
stride=1,
|
| 1040 |
+
causal=True,
|
| 1041 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 1042 |
+
)
|
| 1043 |
+
self.residual = residual
|
| 1044 |
+
self.out_channels_reduction_factor = out_channels_reduction_factor
|
| 1045 |
+
|
| 1046 |
+
def forward(self, x, causal: bool = True):
|
| 1047 |
+
if self.residual:
|
| 1048 |
+
# Reshape and duplicate the input to match the output shape
|
| 1049 |
+
x_in = rearrange(
|
| 1050 |
+
x,
|
| 1051 |
+
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
| 1052 |
+
p1=self.stride[0],
|
| 1053 |
+
p2=self.stride[1],
|
| 1054 |
+
p3=self.stride[2],
|
| 1055 |
+
)
|
| 1056 |
+
num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor
|
| 1057 |
+
x_in = x_in.repeat(1, num_repeat, 1, 1, 1)
|
| 1058 |
+
if self.stride[0] == 2:
|
| 1059 |
+
x_in = x_in[:, :, 1:, :, :]
|
| 1060 |
+
x = self.conv(x, causal=causal)
|
| 1061 |
+
x = rearrange(
|
| 1062 |
+
x,
|
| 1063 |
+
"b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)",
|
| 1064 |
+
p1=self.stride[0],
|
| 1065 |
+
p2=self.stride[1],
|
| 1066 |
+
p3=self.stride[2],
|
| 1067 |
+
)
|
| 1068 |
+
if self.stride[0] == 2:
|
| 1069 |
+
x = x[:, :, 1:, :, :]
|
| 1070 |
+
if self.residual:
|
| 1071 |
+
x = x + x_in
|
| 1072 |
+
return x
|
| 1073 |
+
|
| 1074 |
+
|
| 1075 |
+
class LayerNorm(nn.Module):
|
| 1076 |
+
def __init__(self, dim, eps, elementwise_affine=True) -> None:
|
| 1077 |
+
super().__init__()
|
| 1078 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
| 1079 |
+
|
| 1080 |
+
def forward(self, x):
|
| 1081 |
+
x = rearrange(x, "b c d h w -> b d h w c")
|
| 1082 |
+
x = self.norm(x)
|
| 1083 |
+
x = rearrange(x, "b d h w c -> b c d h w")
|
| 1084 |
+
return x
|
| 1085 |
+
|
| 1086 |
+
|
| 1087 |
+
class ResnetBlock3D(nn.Module):
|
| 1088 |
+
r"""
|
| 1089 |
+
A Resnet block.
|
| 1090 |
+
|
| 1091 |
+
Parameters:
|
| 1092 |
+
in_channels (`int`): The number of channels in the input.
|
| 1093 |
+
out_channels (`int`, *optional*, default to be `None`):
|
| 1094 |
+
The number of output channels for the first conv layer. If None, same as `in_channels`.
|
| 1095 |
+
dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
|
| 1096 |
+
groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
|
| 1097 |
+
eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
|
| 1098 |
+
"""
|
| 1099 |
+
|
| 1100 |
+
def __init__(
|
| 1101 |
+
self,
|
| 1102 |
+
dims: Union[int, Tuple[int, int]],
|
| 1103 |
+
in_channels: int,
|
| 1104 |
+
out_channels: Optional[int] = None,
|
| 1105 |
+
dropout: float = 0.0,
|
| 1106 |
+
groups: int = 32,
|
| 1107 |
+
eps: float = 1e-6,
|
| 1108 |
+
norm_layer: str = "group_norm",
|
| 1109 |
+
inject_noise: bool = False,
|
| 1110 |
+
timestep_conditioning: bool = False,
|
| 1111 |
+
spatial_padding_mode: str = "zeros",
|
| 1112 |
+
):
|
| 1113 |
+
super().__init__()
|
| 1114 |
+
self.in_channels = in_channels
|
| 1115 |
+
out_channels = in_channels if out_channels is None else out_channels
|
| 1116 |
+
self.out_channels = out_channels
|
| 1117 |
+
self.inject_noise = inject_noise
|
| 1118 |
+
|
| 1119 |
+
if norm_layer == "group_norm":
|
| 1120 |
+
self.norm1 = nn.GroupNorm(
|
| 1121 |
+
num_groups=groups, num_channels=in_channels, eps=eps, affine=True
|
| 1122 |
+
)
|
| 1123 |
+
elif norm_layer == "pixel_norm":
|
| 1124 |
+
self.norm1 = PixelNorm()
|
| 1125 |
+
elif norm_layer == "layer_norm":
|
| 1126 |
+
self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
| 1127 |
+
|
| 1128 |
+
self.non_linearity = nn.SiLU()
|
| 1129 |
+
|
| 1130 |
+
self.conv1 = make_conv_nd(
|
| 1131 |
+
dims,
|
| 1132 |
+
in_channels,
|
| 1133 |
+
out_channels,
|
| 1134 |
+
kernel_size=3,
|
| 1135 |
+
stride=1,
|
| 1136 |
+
padding=1,
|
| 1137 |
+
causal=True,
|
| 1138 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 1139 |
+
)
|
| 1140 |
+
|
| 1141 |
+
if inject_noise:
|
| 1142 |
+
self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
| 1143 |
+
|
| 1144 |
+
if norm_layer == "group_norm":
|
| 1145 |
+
self.norm2 = nn.GroupNorm(
|
| 1146 |
+
num_groups=groups, num_channels=out_channels, eps=eps, affine=True
|
| 1147 |
+
)
|
| 1148 |
+
elif norm_layer == "pixel_norm":
|
| 1149 |
+
self.norm2 = PixelNorm()
|
| 1150 |
+
elif norm_layer == "layer_norm":
|
| 1151 |
+
self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True)
|
| 1152 |
+
|
| 1153 |
+
self.dropout = torch.nn.Dropout(dropout)
|
| 1154 |
+
|
| 1155 |
+
self.conv2 = make_conv_nd(
|
| 1156 |
+
dims,
|
| 1157 |
+
out_channels,
|
| 1158 |
+
out_channels,
|
| 1159 |
+
kernel_size=3,
|
| 1160 |
+
stride=1,
|
| 1161 |
+
padding=1,
|
| 1162 |
+
causal=True,
|
| 1163 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 1164 |
+
)
|
| 1165 |
+
|
| 1166 |
+
if inject_noise:
|
| 1167 |
+
self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1)))
|
| 1168 |
+
|
| 1169 |
+
self.conv_shortcut = (
|
| 1170 |
+
make_linear_nd(
|
| 1171 |
+
dims=dims, in_channels=in_channels, out_channels=out_channels
|
| 1172 |
+
)
|
| 1173 |
+
if in_channels != out_channels
|
| 1174 |
+
else nn.Identity()
|
| 1175 |
+
)
|
| 1176 |
+
|
| 1177 |
+
self.norm3 = (
|
| 1178 |
+
LayerNorm(in_channels, eps=eps, elementwise_affine=True)
|
| 1179 |
+
if in_channels != out_channels
|
| 1180 |
+
else nn.Identity()
|
| 1181 |
+
)
|
| 1182 |
+
|
| 1183 |
+
self.timestep_conditioning = timestep_conditioning
|
| 1184 |
+
|
| 1185 |
+
if timestep_conditioning:
|
| 1186 |
+
self.scale_shift_table = nn.Parameter(
|
| 1187 |
+
torch.randn(4, in_channels) / in_channels**0.5
|
| 1188 |
+
)
|
| 1189 |
+
|
| 1190 |
+
def _feed_spatial_noise(
|
| 1191 |
+
self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor
|
| 1192 |
+
) -> torch.FloatTensor:
|
| 1193 |
+
spatial_shape = hidden_states.shape[-2:]
|
| 1194 |
+
device = hidden_states.device
|
| 1195 |
+
dtype = hidden_states.dtype
|
| 1196 |
+
|
| 1197 |
+
# similar to the "explicit noise inputs" method in style-gan
|
| 1198 |
+
spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None]
|
| 1199 |
+
scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...]
|
| 1200 |
+
hidden_states = hidden_states + scaled_noise
|
| 1201 |
+
|
| 1202 |
+
return hidden_states
|
| 1203 |
+
|
| 1204 |
+
def forward(
|
| 1205 |
+
self,
|
| 1206 |
+
input_tensor: torch.FloatTensor,
|
| 1207 |
+
causal: bool = True,
|
| 1208 |
+
timestep: Optional[torch.Tensor] = None,
|
| 1209 |
+
) -> torch.FloatTensor:
|
| 1210 |
+
hidden_states = input_tensor
|
| 1211 |
+
batch_size = hidden_states.shape[0]
|
| 1212 |
+
|
| 1213 |
+
hidden_states = self.norm1(hidden_states)
|
| 1214 |
+
if self.timestep_conditioning:
|
| 1215 |
+
assert (
|
| 1216 |
+
timestep is not None
|
| 1217 |
+
), "should pass timestep with timestep_conditioning=True"
|
| 1218 |
+
ada_values = self.scale_shift_table[
|
| 1219 |
+
None, ..., None, None, None
|
| 1220 |
+
] + timestep.reshape(
|
| 1221 |
+
batch_size,
|
| 1222 |
+
4,
|
| 1223 |
+
-1,
|
| 1224 |
+
timestep.shape[-3],
|
| 1225 |
+
timestep.shape[-2],
|
| 1226 |
+
timestep.shape[-1],
|
| 1227 |
+
)
|
| 1228 |
+
shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1)
|
| 1229 |
+
|
| 1230 |
+
hidden_states = hidden_states * (1 + scale1) + shift1
|
| 1231 |
+
|
| 1232 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 1233 |
+
|
| 1234 |
+
hidden_states = self.conv1(hidden_states, causal=causal)
|
| 1235 |
+
|
| 1236 |
+
if self.inject_noise:
|
| 1237 |
+
hidden_states = self._feed_spatial_noise(
|
| 1238 |
+
hidden_states, self.per_channel_scale1
|
| 1239 |
+
)
|
| 1240 |
+
|
| 1241 |
+
hidden_states = self.norm2(hidden_states)
|
| 1242 |
+
|
| 1243 |
+
if self.timestep_conditioning:
|
| 1244 |
+
hidden_states = hidden_states * (1 + scale2) + shift2
|
| 1245 |
+
|
| 1246 |
+
hidden_states = self.non_linearity(hidden_states)
|
| 1247 |
+
|
| 1248 |
+
hidden_states = self.dropout(hidden_states)
|
| 1249 |
+
|
| 1250 |
+
hidden_states = self.conv2(hidden_states, causal=causal)
|
| 1251 |
+
|
| 1252 |
+
if self.inject_noise:
|
| 1253 |
+
hidden_states = self._feed_spatial_noise(
|
| 1254 |
+
hidden_states, self.per_channel_scale2
|
| 1255 |
+
)
|
| 1256 |
+
|
| 1257 |
+
input_tensor = self.norm3(input_tensor)
|
| 1258 |
+
|
| 1259 |
+
batch_size = input_tensor.shape[0]
|
| 1260 |
+
|
| 1261 |
+
input_tensor = self.conv_shortcut(input_tensor)
|
| 1262 |
+
|
| 1263 |
+
output_tensor = input_tensor + hidden_states
|
| 1264 |
+
|
| 1265 |
+
return output_tensor
|
| 1266 |
+
|
| 1267 |
+
|
| 1268 |
+
def patchify(x, patch_size_hw, patch_size_t=1):
|
| 1269 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 1270 |
+
return x
|
| 1271 |
+
if x.dim() == 4:
|
| 1272 |
+
x = rearrange(
|
| 1273 |
+
x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
|
| 1274 |
+
)
|
| 1275 |
+
elif x.dim() == 5:
|
| 1276 |
+
x = rearrange(
|
| 1277 |
+
x,
|
| 1278 |
+
"b c (f p) (h q) (w r) -> b (c p r q) f h w",
|
| 1279 |
+
p=patch_size_t,
|
| 1280 |
+
q=patch_size_hw,
|
| 1281 |
+
r=patch_size_hw,
|
| 1282 |
+
)
|
| 1283 |
+
else:
|
| 1284 |
+
raise ValueError(f"Invalid input shape: {x.shape}")
|
| 1285 |
+
|
| 1286 |
+
return x
|
| 1287 |
+
|
| 1288 |
+
|
| 1289 |
+
def unpatchify(x, patch_size_hw, patch_size_t=1):
|
| 1290 |
+
if patch_size_hw == 1 and patch_size_t == 1:
|
| 1291 |
+
return x
|
| 1292 |
+
|
| 1293 |
+
if x.dim() == 4:
|
| 1294 |
+
x = rearrange(
|
| 1295 |
+
x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
|
| 1296 |
+
)
|
| 1297 |
+
elif x.dim() == 5:
|
| 1298 |
+
x = rearrange(
|
| 1299 |
+
x,
|
| 1300 |
+
"b (c p r q) f h w -> b c (f p) (h q) (w r)",
|
| 1301 |
+
p=patch_size_t,
|
| 1302 |
+
q=patch_size_hw,
|
| 1303 |
+
r=patch_size_hw,
|
| 1304 |
+
)
|
| 1305 |
+
|
| 1306 |
+
return x
|
| 1307 |
+
|
| 1308 |
+
|
| 1309 |
+
def create_video_autoencoder_demo_config(
|
| 1310 |
+
latent_channels: int = 64,
|
| 1311 |
+
):
|
| 1312 |
+
encoder_blocks = [
|
| 1313 |
+
("res_x", {"num_layers": 2}),
|
| 1314 |
+
("compress_space_res", {"multiplier": 2}),
|
| 1315 |
+
("res_x", {"num_layers": 2}),
|
| 1316 |
+
("compress_time_res", {"multiplier": 2}),
|
| 1317 |
+
("res_x", {"num_layers": 1}),
|
| 1318 |
+
("compress_all_res", {"multiplier": 2}),
|
| 1319 |
+
("res_x", {"num_layers": 1}),
|
| 1320 |
+
("compress_all_res", {"multiplier": 2}),
|
| 1321 |
+
("res_x", {"num_layers": 1}),
|
| 1322 |
+
]
|
| 1323 |
+
decoder_blocks = [
|
| 1324 |
+
("res_x", {"num_layers": 2, "inject_noise": False}),
|
| 1325 |
+
("compress_all", {"residual": True, "multiplier": 2}),
|
| 1326 |
+
("res_x", {"num_layers": 2, "inject_noise": False}),
|
| 1327 |
+
("compress_all", {"residual": True, "multiplier": 2}),
|
| 1328 |
+
("res_x", {"num_layers": 2, "inject_noise": False}),
|
| 1329 |
+
("compress_all", {"residual": True, "multiplier": 2}),
|
| 1330 |
+
("res_x", {"num_layers": 2, "inject_noise": False}),
|
| 1331 |
+
]
|
| 1332 |
+
return {
|
| 1333 |
+
"_class_name": "CausalVideoAutoencoder",
|
| 1334 |
+
"dims": 3,
|
| 1335 |
+
"encoder_blocks": encoder_blocks,
|
| 1336 |
+
"decoder_blocks": decoder_blocks,
|
| 1337 |
+
"latent_channels": latent_channels,
|
| 1338 |
+
"norm_layer": "pixel_norm",
|
| 1339 |
+
"patch_size": 4,
|
| 1340 |
+
"latent_log_var": "uniform",
|
| 1341 |
+
"use_quant_conv": False,
|
| 1342 |
+
"causal_decoder": False,
|
| 1343 |
+
"timestep_conditioning": True,
|
| 1344 |
+
"spatial_padding_mode": "replicate",
|
| 1345 |
+
}
|
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
def test_vae_patchify_unpatchify():
|
| 1349 |
+
import torch
|
| 1350 |
+
|
| 1351 |
+
x = torch.randn(2, 3, 8, 64, 64)
|
| 1352 |
+
x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
|
| 1353 |
+
x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
|
| 1354 |
+
assert torch.allclose(x, x_unpatched)
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
def demo_video_autoencoder_forward_backward():
|
| 1358 |
+
# Configuration for the VideoAutoencoder
|
| 1359 |
+
config = create_video_autoencoder_demo_config()
|
| 1360 |
+
|
| 1361 |
+
# Instantiate the VideoAutoencoder with the specified configuration
|
| 1362 |
+
video_autoencoder = CausalVideoAutoencoder.from_config(config)
|
| 1363 |
+
|
| 1364 |
+
print(video_autoencoder)
|
| 1365 |
+
video_autoencoder.eval()
|
| 1366 |
+
# Print the total number of parameters in the video autoencoder
|
| 1367 |
+
total_params = sum(p.numel() for p in video_autoencoder.parameters())
|
| 1368 |
+
print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")
|
| 1369 |
+
|
| 1370 |
+
# Create a mock input tensor simulating a batch of videos
|
| 1371 |
+
# Shape: (batch_size, channels, depth, height, width)
|
| 1372 |
+
# E.g., 2 videos, each with 3 color channels, 17 frames, and 64x64 pixels per frame
|
| 1373 |
+
input_videos = torch.randn(2, 3, 17, 64, 64)
|
| 1374 |
+
|
| 1375 |
+
# Forward pass: encode and decode the input videos
|
| 1376 |
+
latent = video_autoencoder.encode(input_videos).latent_dist.mode()
|
| 1377 |
+
print(f"input shape={input_videos.shape}")
|
| 1378 |
+
print(f"latent shape={latent.shape}")
|
| 1379 |
+
|
| 1380 |
+
timestep = torch.ones(input_videos.shape[0]) * 0.1
|
| 1381 |
+
reconstructed_videos = video_autoencoder.decode(
|
| 1382 |
+
latent, target_shape=input_videos.shape, timestep=timestep
|
| 1383 |
+
).sample
|
| 1384 |
+
|
| 1385 |
+
print(f"reconstructed shape={reconstructed_videos.shape}")
|
| 1386 |
+
|
| 1387 |
+
# Validate that single image gets treated the same way as first frame
|
| 1388 |
+
input_image = input_videos[:, :, :1, :, :]
|
| 1389 |
+
image_latent = video_autoencoder.encode(input_image).latent_dist.mode()
|
| 1390 |
+
_ = video_autoencoder.decode(
|
| 1391 |
+
image_latent, target_shape=image_latent.shape, timestep=timestep
|
| 1392 |
+
).sample
|
| 1393 |
+
|
| 1394 |
+
first_frame_latent = latent[:, :, :1, :, :]
|
| 1395 |
+
|
| 1396 |
+
assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
|
| 1397 |
+
# assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6)
|
| 1398 |
+
# assert torch.allclose(image_latent, first_frame_latent, atol=1e-6)
|
| 1399 |
+
# assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all()
|
| 1400 |
+
|
| 1401 |
+
# Calculate the loss (e.g., mean squared error)
|
| 1402 |
+
loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)
|
| 1403 |
+
|
| 1404 |
+
# Perform backward pass
|
| 1405 |
+
loss.backward()
|
| 1406 |
+
|
| 1407 |
+
print(f"Demo completed with loss: {loss.item()}")
|
| 1408 |
+
|
| 1409 |
+
|
| 1410 |
+
# Ensure to call the demo function to execute the forward and backward pass
|
| 1411 |
+
if __name__ == "__main__":
|
| 1412 |
+
demo_video_autoencoder_forward_backward()
|
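As a quick orientation on the file above: each "compress_*" entry in the encoder `blocks` list halves the corresponding axes (time for `compress_time*`, height/width for `compress_space*`, all three for `compress_all*`), and `Encoder.forward`'s `patchify` call additionally folds a `patch_size` x `patch_size` spatial window into channels. The sketch below is not part of the uploaded files; the helper name is made up, and it only tallies the strides visible in the Encoder constructor to show the total downsampling of the demo config.

# Hypothetical helper: total temporal/spatial reduction implied by an encoder
# blocks description plus the patch size used by Encoder.forward's patchify call.
def encoder_downsample_factors(blocks, patch_size=1):
    temporal, spatial = 1, patch_size  # patchify reduces H/W only (patch_size_t=1)
    for name, _params in blocks:
        if name.startswith(("compress_time", "compress_all")):
            temporal *= 2
        if name.startswith(("compress_space", "compress_all")):
            spatial *= 2
    return temporal, spatial

demo_blocks = [
    ("res_x", {"num_layers": 2}),
    ("compress_space_res", {"multiplier": 2}),
    ("res_x", {"num_layers": 2}),
    ("compress_time_res", {"multiplier": 2}),
    ("res_x", {"num_layers": 1}),
    ("compress_all_res", {"multiplier": 2}),
    ("res_x", {"num_layers": 1}),
    ("compress_all_res", {"multiplier": 2}),
    ("res_x", {"num_layers": 1}),
]
print(encoder_downsample_factors(demo_blocks, patch_size=4))  # (8, 32)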
flash_head/ltx_video/models/autoencoders/conv_nd_factory.py
ADDED
|
@@ -0,0 +1,90 @@
|
| 1 |
+
from typing import Tuple, Union
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from flash_head.ltx_video.models.autoencoders.dual_conv3d import DualConv3d
|
| 6 |
+
from flash_head.ltx_video.models.autoencoders.causal_conv3d import CausalConv3d
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def make_conv_nd(
|
| 10 |
+
dims: Union[int, Tuple[int, int]],
|
| 11 |
+
in_channels: int,
|
| 12 |
+
out_channels: int,
|
| 13 |
+
kernel_size: int,
|
| 14 |
+
stride=1,
|
| 15 |
+
padding=0,
|
| 16 |
+
dilation=1,
|
| 17 |
+
groups=1,
|
| 18 |
+
bias=True,
|
| 19 |
+
causal=False,
|
| 20 |
+
spatial_padding_mode="zeros",
|
| 21 |
+
temporal_padding_mode="zeros",
|
| 22 |
+
):
|
| 23 |
+
if not (spatial_padding_mode == temporal_padding_mode or causal):
|
| 24 |
+
raise NotImplementedError("spatial and temporal padding modes must be equal")
|
| 25 |
+
if dims == 2:
|
| 26 |
+
return torch.nn.Conv2d(
|
| 27 |
+
in_channels=in_channels,
|
| 28 |
+
out_channels=out_channels,
|
| 29 |
+
kernel_size=kernel_size,
|
| 30 |
+
stride=stride,
|
| 31 |
+
padding=padding,
|
| 32 |
+
dilation=dilation,
|
| 33 |
+
groups=groups,
|
| 34 |
+
bias=bias,
|
| 35 |
+
padding_mode=spatial_padding_mode,
|
| 36 |
+
)
|
| 37 |
+
elif dims == 3:
|
| 38 |
+
if causal:
|
| 39 |
+
return CausalConv3d(
|
| 40 |
+
in_channels=in_channels,
|
| 41 |
+
out_channels=out_channels,
|
| 42 |
+
kernel_size=kernel_size,
|
| 43 |
+
stride=stride,
|
| 44 |
+
padding=padding,
|
| 45 |
+
dilation=dilation,
|
| 46 |
+
groups=groups,
|
| 47 |
+
bias=bias,
|
| 48 |
+
spatial_padding_mode=spatial_padding_mode,
|
| 49 |
+
)
|
| 50 |
+
return torch.nn.Conv3d(
|
| 51 |
+
in_channels=in_channels,
|
| 52 |
+
out_channels=out_channels,
|
| 53 |
+
kernel_size=kernel_size,
|
| 54 |
+
stride=stride,
|
| 55 |
+
padding=padding,
|
| 56 |
+
dilation=dilation,
|
| 57 |
+
groups=groups,
|
| 58 |
+
bias=bias,
|
| 59 |
+
padding_mode=spatial_padding_mode,
|
| 60 |
+
)
|
| 61 |
+
elif dims == (2, 1):
|
| 62 |
+
return DualConv3d(
|
| 63 |
+
in_channels=in_channels,
|
| 64 |
+
out_channels=out_channels,
|
| 65 |
+
kernel_size=kernel_size,
|
| 66 |
+
stride=stride,
|
| 67 |
+
padding=padding,
|
| 68 |
+
bias=bias,
|
| 69 |
+
padding_mode=spatial_padding_mode,
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def make_linear_nd(
|
| 76 |
+
dims: int,
|
| 77 |
+
in_channels: int,
|
| 78 |
+
out_channels: int,
|
| 79 |
+
bias=True,
|
| 80 |
+
):
|
| 81 |
+
if dims == 2:
|
| 82 |
+
return torch.nn.Conv2d(
|
| 83 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
| 84 |
+
)
|
| 85 |
+
elif dims == 3 or dims == (2, 1):
|
| 86 |
+
return torch.nn.Conv3d(
|
| 87 |
+
in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias
|
| 88 |
+
)
|
| 89 |
+
else:
|
| 90 |
+
raise ValueError(f"unsupported dimensions: {dims}")
|
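The factory above dispatches on `dims`: 2 builds a plain Conv2d, 3 builds either a CausalConv3d (when `causal=True`) or a Conv3d, and (2, 1) builds the DualConv3d defined in the next file. A small usage sketch, not taken from the repo: the channel counts and tensor shape are arbitrary, and it assumes CausalConv3d's forward defaults to causal padding, as the Encoder's plain `self.conv_in(sample)` call suggests.

import torch
from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd

# Causal 3D convolution over a (B, C, F, H, W) video tensor.
conv = make_conv_nd(
    dims=3, in_channels=8, out_channels=16, kernel_size=3,
    stride=1, padding=1, causal=True,
)
video = torch.randn(1, 8, 9, 32, 32)
print(conv(video).shape)  # 16 output channels; F/H/W depend on CausalConv3d's padding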
flash_head/ltx_video/models/autoencoders/dual_conv3d.py
ADDED
|
@@ -0,0 +1,217 @@
|
| 1 |
+
import math
|
| 2 |
+
from typing import Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DualConv3d(nn.Module):
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
in_channels,
|
| 14 |
+
out_channels,
|
| 15 |
+
kernel_size,
|
| 16 |
+
stride: Union[int, Tuple[int, int, int]] = 1,
|
| 17 |
+
padding: Union[int, Tuple[int, int, int]] = 0,
|
| 18 |
+
dilation: Union[int, Tuple[int, int, int]] = 1,
|
| 19 |
+
groups=1,
|
| 20 |
+
bias=True,
|
| 21 |
+
padding_mode="zeros",
|
| 22 |
+
):
|
| 23 |
+
super(DualConv3d, self).__init__()
|
| 24 |
+
|
| 25 |
+
self.in_channels = in_channels
|
| 26 |
+
self.out_channels = out_channels
|
| 27 |
+
self.padding_mode = padding_mode
|
| 28 |
+
# Ensure kernel_size, stride, padding, and dilation are tuples of length 3
|
| 29 |
+
if isinstance(kernel_size, int):
|
| 30 |
+
kernel_size = (kernel_size, kernel_size, kernel_size)
|
| 31 |
+
if kernel_size == (1, 1, 1):
|
| 32 |
+
raise ValueError(
|
| 33 |
+
"kernel_size must be greater than 1. Use make_linear_nd instead."
|
| 34 |
+
)
|
| 35 |
+
if isinstance(stride, int):
|
| 36 |
+
stride = (stride, stride, stride)
|
| 37 |
+
if isinstance(padding, int):
|
| 38 |
+
padding = (padding, padding, padding)
|
| 39 |
+
if isinstance(dilation, int):
|
| 40 |
+
dilation = (dilation, dilation, dilation)
|
| 41 |
+
|
| 42 |
+
# Set parameters for convolutions
|
| 43 |
+
self.groups = groups
|
| 44 |
+
self.bias = bias
|
| 45 |
+
|
| 46 |
+
# Define the size of the channels after the first convolution
|
| 47 |
+
intermediate_channels = (
|
| 48 |
+
out_channels if in_channels < out_channels else in_channels
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Define parameters for the first convolution
|
| 52 |
+
self.weight1 = nn.Parameter(
|
| 53 |
+
torch.Tensor(
|
| 54 |
+
intermediate_channels,
|
| 55 |
+
in_channels // groups,
|
| 56 |
+
1,
|
| 57 |
+
kernel_size[1],
|
| 58 |
+
kernel_size[2],
|
| 59 |
+
)
|
| 60 |
+
)
|
| 61 |
+
self.stride1 = (1, stride[1], stride[2])
|
| 62 |
+
self.padding1 = (0, padding[1], padding[2])
|
| 63 |
+
self.dilation1 = (1, dilation[1], dilation[2])
|
| 64 |
+
if bias:
|
| 65 |
+
self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels))
|
| 66 |
+
else:
|
| 67 |
+
self.register_parameter("bias1", None)
|
| 68 |
+
|
| 69 |
+
# Define parameters for the second convolution
|
| 70 |
+
self.weight2 = nn.Parameter(
|
| 71 |
+
torch.Tensor(
|
| 72 |
+
out_channels, intermediate_channels // groups, kernel_size[0], 1, 1
|
| 73 |
+
)
|
| 74 |
+
)
|
| 75 |
+
self.stride2 = (stride[0], 1, 1)
|
| 76 |
+
self.padding2 = (padding[0], 0, 0)
|
| 77 |
+
self.dilation2 = (dilation[0], 1, 1)
|
| 78 |
+
if bias:
|
| 79 |
+
self.bias2 = nn.Parameter(torch.Tensor(out_channels))
|
| 80 |
+
else:
|
| 81 |
+
self.register_parameter("bias2", None)
|
| 82 |
+
|
| 83 |
+
# Initialize weights and biases
|
| 84 |
+
self.reset_parameters()
|
| 85 |
+
|
| 86 |
+
def reset_parameters(self):
|
| 87 |
+
nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
|
| 88 |
+
nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
|
| 89 |
+
if self.bias:
|
| 90 |
+
fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
|
| 91 |
+
bound1 = 1 / math.sqrt(fan_in1)
|
| 92 |
+
nn.init.uniform_(self.bias1, -bound1, bound1)
|
| 93 |
+
fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
|
| 94 |
+
bound2 = 1 / math.sqrt(fan_in2)
|
| 95 |
+
nn.init.uniform_(self.bias2, -bound2, bound2)
|
| 96 |
+
|
| 97 |
+
def forward(self, x, use_conv3d=False, skip_time_conv=False):
|
| 98 |
+
if use_conv3d:
|
| 99 |
+
return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
|
| 100 |
+
else:
|
| 101 |
+
return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)
|
| 102 |
+
|
| 103 |
+
def forward_with_3d(self, x, skip_time_conv):
|
| 104 |
+
# First convolution
|
| 105 |
+
x = F.conv3d(
|
| 106 |
+
x,
|
| 107 |
+
self.weight1,
|
| 108 |
+
self.bias1,
|
| 109 |
+
self.stride1,
|
| 110 |
+
self.padding1,
|
| 111 |
+
self.dilation1,
|
| 112 |
+
self.groups,
|
| 113 |
+
padding_mode=self.padding_mode,
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if skip_time_conv:
|
| 117 |
+
return x
|
| 118 |
+
|
| 119 |
+
# Second convolution
|
| 120 |
+
x = F.conv3d(
|
| 121 |
+
x,
|
| 122 |
+
self.weight2,
|
| 123 |
+
self.bias2,
|
| 124 |
+
self.stride2,
|
| 125 |
+
self.padding2,
|
| 126 |
+
self.dilation2,
|
| 127 |
+
self.groups,
|
| 128 |
+
padding_mode=self.padding_mode,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
return x
|
| 132 |
+
|
| 133 |
+
def forward_with_2d(self, x, skip_time_conv):
|
| 134 |
+
b, c, d, h, w = x.shape
|
| 135 |
+
|
| 136 |
+
# First 2D convolution
|
| 137 |
+
x = rearrange(x, "b c d h w -> (b d) c h w")
|
| 138 |
+
# Squeeze the depth dimension out of weight1 since it's 1
|
| 139 |
+
weight1 = self.weight1.squeeze(2)
|
| 140 |
+
# Select stride, padding, and dilation for the 2D convolution
|
| 141 |
+
stride1 = (self.stride1[1], self.stride1[2])
|
| 142 |
+
padding1 = (self.padding1[1], self.padding1[2])
|
| 143 |
+
dilation1 = (self.dilation1[1], self.dilation1[2])
|
| 144 |
+
x = F.conv2d(
|
| 145 |
+
x,
|
| 146 |
+
weight1,
|
| 147 |
+
self.bias1,
|
| 148 |
+
stride1,
|
| 149 |
+
padding1,
|
| 150 |
+
dilation1,
|
| 151 |
+
self.groups,
|
| 152 |
+
padding_mode=self.padding_mode,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
_, _, h, w = x.shape
|
| 156 |
+
|
| 157 |
+
if skip_time_conv:
|
| 158 |
+
x = rearrange(x, "(b d) c h w -> b c d h w", b=b)
|
| 159 |
+
return x
|
| 160 |
+
|
| 161 |
+
# Second convolution which is essentially treated as a 1D convolution across the 'd' dimension
|
| 162 |
+
x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
|
| 163 |
+
|
| 164 |
+
# Reshape weight2 to match the expected dimensions for conv1d
|
| 165 |
+
weight2 = self.weight2.squeeze(-1).squeeze(-1)
|
| 166 |
+
# Use only the relevant dimension for stride, padding, and dilation for the 1D convolution
|
| 167 |
+
stride2 = self.stride2[0]
|
| 168 |
+
padding2 = self.padding2[0]
|
| 169 |
+
dilation2 = self.dilation2[0]
|
| 170 |
+
x = F.conv1d(
|
| 171 |
+
x,
|
| 172 |
+
weight2,
|
| 173 |
+
self.bias2,
|
| 174 |
+
stride2,
|
| 175 |
+
padding2,
|
| 176 |
+
dilation2,
|
| 177 |
+
self.groups,
|
| 178 |
+
padding_mode=self.padding_mode,
|
| 179 |
+
)
|
| 180 |
+
x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)
|
| 181 |
+
|
| 182 |
+
return x
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def weight(self):
|
| 186 |
+
return self.weight2
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def test_dual_conv3d_consistency():
|
| 190 |
+
# Initialize parameters
|
| 191 |
+
in_channels = 3
|
| 192 |
+
out_channels = 5
|
| 193 |
+
kernel_size = (3, 3, 3)
|
| 194 |
+
stride = (2, 2, 2)
|
| 195 |
+
padding = (1, 1, 1)
|
| 196 |
+
|
| 197 |
+
# Create an instance of the DualConv3d class
|
| 198 |
+
dual_conv3d = DualConv3d(
|
| 199 |
+
in_channels=in_channels,
|
| 200 |
+
out_channels=out_channels,
|
| 201 |
+
kernel_size=kernel_size,
|
| 202 |
+
stride=stride,
|
| 203 |
+
padding=padding,
|
| 204 |
+
bias=True,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Example input tensor
|
| 208 |
+
test_input = torch.randn(1, 3, 10, 10, 10)
|
| 209 |
+
|
| 210 |
+
# Perform forward passes with both 3D and 2D settings
|
| 211 |
+
output_conv3d = dual_conv3d(test_input, use_conv3d=True)
|
| 212 |
+
output_2d = dual_conv3d(test_input, use_conv3d=False)
|
| 213 |
+
|
| 214 |
+
# Assert that the outputs from both methods are sufficiently close
|
| 215 |
+
assert torch.allclose(
|
| 216 |
+
output_conv3d, output_2d, atol=1e-6
|
| 217 |
+
), "Outputs are not consistent between 3D and 2D convolutions."
|
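DualConv3d factorizes a full k x k x k convolution into a 1 x k x k spatial convolution followed by a k x 1 x 1 temporal convolution, an R(2+1)D-style split that cuts the weight count considerably. A rough comparison sketch (not from the repo; the channel sizes are chosen arbitrarily and only parameter counts are compared):

import torch.nn as nn
from flash_head.ltx_video.models.autoencoders.dual_conv3d import DualConv3d

full = nn.Conv3d(64, 64, kernel_size=3, padding=1)   # one 3x3x3 kernel per output channel
dual = DualConv3d(64, 64, kernel_size=3, padding=1)  # one 1x3x3 plus one 3x1x1 kernel
count = lambda m: sum(p.numel() for p in m.parameters())
print(count(full), count(dual))  # the factorized form has roughly 2.2x fewer parameters here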
flash_head/ltx_video/models/autoencoders/pixel_norm.py
ADDED
|
@@ -0,0 +1,12 @@
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PixelNorm(nn.Module):
|
| 6 |
+
def __init__(self, dim=1, eps=1e-8):
|
| 7 |
+
super(PixelNorm, self).__init__()
|
| 8 |
+
self.dim = dim
|
| 9 |
+
self.eps = eps
|
| 10 |
+
|
| 11 |
+
def forward(self, x):
|
| 12 |
+
return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps)
|
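PixelNorm divides every position by the RMS of its channel vector, so each (frame, height, width) location ends up with unit mean square across channels, with no learned parameters. A quick check, a sketch rather than part of the upload:

import torch
from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm

norm = PixelNorm()                 # normalizes over dim=1 (channels) by default
x = torch.randn(2, 8, 4, 16, 16)   # (B, C, F, H, W)
y = norm(x)
print(y.pow(2).mean(dim=1).mean().item())  # ~1.0: unit mean square per position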
flash_head/ltx_video/models/autoencoders/vae.py
ADDED
|
@@ -0,0 +1,380 @@
|
from typing import Optional, Union

import torch
import inspect
import math
import torch.nn as nn
from diffusers import ConfigMixin, ModelMixin
from diffusers.models.autoencoders.vae import (
    DecoderOutput,
    DiagonalGaussianDistribution,
)
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd


class AutoencoderKLWrapper(ModelMixin, ConfigMixin):
    """Variational Autoencoder (VAE) model with KL loss.

    VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling.
    This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss.

    Args:
        encoder (`nn.Module`):
            Encoder module.
        decoder (`nn.Module`):
            Decoder module.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of latent channels.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        latent_channels: int = 4,
        dims: int = 2,
        sample_size=512,
        use_quant_conv: bool = True,
        normalize_latent_channels: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = encoder
        self.use_quant_conv = use_quant_conv
        self.normalize_latent_channels = normalize_latent_channels

        # pass init params to Decoder
        quant_dims = 2 if dims == 2 else 3
        self.decoder = decoder
        if use_quant_conv:
            self.quant_conv = make_conv_nd(
                quant_dims, 2 * latent_channels, 2 * latent_channels, 1
            )
            self.post_quant_conv = make_conv_nd(
                quant_dims, latent_channels, latent_channels, 1
            )
        else:
            self.quant_conv = nn.Identity()
            self.post_quant_conv = nn.Identity()

        if normalize_latent_channels:
            if dims == 2:
                self.latent_norm_out = nn.BatchNorm2d(latent_channels, affine=False)
            else:
                self.latent_norm_out = nn.BatchNorm3d(latent_channels, affine=False)
        else:
            self.latent_norm_out = nn.Identity()
        self.use_z_tiling = False
        self.use_hw_tiling = False
        self.dims = dims
        self.z_sample_size = 1

        self.decoder_params = inspect.signature(self.decoder.forward).parameters

        # only relevant if vae tiling is enabled
        self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25)

    def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25):
        self.tile_sample_min_size = sample_size
        num_blocks = len(self.encoder.down_blocks)
        self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1)))
        self.tile_overlap_factor = overlap_factor

    def enable_z_tiling(self, z_sample_size: int = 8):
        r"""
        Enable tiling during VAE decoding.

        When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several
        steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_z_tiling = z_sample_size > 1
        self.z_sample_size = z_sample_size
        assert (
            z_sample_size % 8 == 0 or z_sample_size == 1
        ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}."

    def disable_z_tiling(self):
        r"""
        Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing
        decoding in one step.
        """
        self.use_z_tiling = False

    def enable_hw_tiling(self):
        r"""
        Enable tiling during VAE decoding along the height and width dimension.
        """
        self.use_hw_tiling = True

    def disable_hw_tiling(self):
        r"""
        Disable tiling during VAE decoding along the height and width dimension.
        """
        self.use_hw_tiling = False

    def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True):
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        # Split the image into 512x512 tiles and encode them separately.
        rows = []
        for i in range(0, x.shape[3], overlap_size):
            row = []
            for j in range(0, x.shape[4], overlap_size):
                tile = x[
                    :,
                    :,
                    :,
                    i : i + self.tile_sample_min_size,
                    j : j + self.tile_sample_min_size,
                ]
                tile = self.encoder(tile)
                tile = self.quant_conv(tile)
                row.append(tile)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        moments = torch.cat(result_rows, dim=3)
        return moments

    def blend_z(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
        for z in range(blend_extent):
            b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * (
                1 - z / blend_extent
            ) + b[:, :, z, :, :] * (z / blend_extent)
        return b

    def blend_v(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (
                1 - y / blend_extent
            ) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(
        self, a: torch.Tensor, b: torch.Tensor, blend_extent: int
    ) -> torch.Tensor:
        blend_extent = min(a.shape[4], b.shape[4], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (
                1 - x / blend_extent
            ) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape):
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent
        tile_target_shape = (
            *target_shape[:3],
            self.tile_sample_min_size,
            self.tile_sample_min_size,
        )
        # Split z into overlapping 64x64 tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, z.shape[3], overlap_size):
            row = []
            for j in range(0, z.shape[4], overlap_size):
                tile = z[
                    :,
                    :,
                    :,
                    i : i + self.tile_latent_min_size,
                    j : j + self.tile_latent_min_size,
                ]
                tile = self.post_quant_conv(tile)
                decoded = self.decoder(tile, target_shape=tile_target_shape)
                row.append(decoded)
            rows.append(row)
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                result_row.append(tile[:, :, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=4))

        dec = torch.cat(result_rows, dim=3)
        return dec

    def encode(
        self, z: torch.FloatTensor, return_dict: bool = True
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            num_splits = z.shape[2] // self.z_sample_size
            sizes = [self.z_sample_size] * num_splits
            sizes = (
                sizes + [z.shape[2] - sum(sizes)]
                if z.shape[2] - sum(sizes) > 0
                else sizes
            )
            tiles = z.split(sizes, dim=2)
            moments_tiles = [
                (
                    self._hw_tiled_encode(z_tile, return_dict)
                    if self.use_hw_tiling
                    else self._encode(z_tile)
                )
                for z_tile in tiles
            ]
            moments = torch.cat(moments_tiles, dim=2)

        else:
            moments = (
                self._hw_tiled_encode(z, return_dict)
                if self.use_hw_tiling
                else self._encode(z)
            )

        posterior = DiagonalGaussianDistribution(moments)
        if not return_dict:
            return (posterior,)

        return AutoencoderKLOutput(latent_dist=posterior)

    def _normalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            _, c, _, _, _ = z.shape
            z = torch.cat(
                [
                    self.latent_norm_out(z[:, : c // 2, :, :, :]),
                    z[:, c // 2 :, :, :, :],
                ],
                dim=1,
            )
        elif isinstance(self.latent_norm_out, nn.BatchNorm2d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _unnormalize_latent_channels(self, z: torch.FloatTensor) -> torch.FloatTensor:
        if isinstance(self.latent_norm_out, nn.BatchNorm3d):
            running_mean = self.latent_norm_out.running_mean.view(1, -1, 1, 1, 1)
            running_var = self.latent_norm_out.running_var.view(1, -1, 1, 1, 1)
            eps = self.latent_norm_out.eps

            z = z * torch.sqrt(running_var + eps) + running_mean
        elif isinstance(self.latent_norm_out, nn.BatchNorm3d):
            raise NotImplementedError("BatchNorm2d not supported")
        return z

    def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput:
        h = self.encoder(x)
        moments = self.quant_conv(h)
        moments = self._normalize_latent_channels(moments)
        return moments

    def _decode(
        self,
        z: torch.FloatTensor,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        z = self._unnormalize_latent_channels(z)
        z = self.post_quant_conv(z)
        if "timestep" in self.decoder_params:
            dec = self.decoder(z, target_shape=target_shape, timestep=timestep)
        else:
            dec = self.decoder(z, target_shape=target_shape)
        return dec

    def decode(
        self,
        z: torch.FloatTensor,
        return_dict: bool = True,
        target_shape=None,
        timestep: Optional[torch.Tensor] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        assert target_shape is not None, "target_shape must be provided for decoding"
        if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1:
            reduction_factor = int(
                self.encoder.patch_size_t
                * 2
                ** (
                    len(self.encoder.down_blocks)
                    - 1
                    - math.sqrt(self.encoder.patch_size)
                )
            )
            split_size = self.z_sample_size // reduction_factor
            num_splits = z.shape[2] // split_size

            # copy target shape, and divide frame dimension (=2) by the context size
            target_shape_split = list(target_shape)
            target_shape_split[2] = target_shape[2] // num_splits

            decoded_tiles = [
                (
                    self._hw_tiled_decode(z_tile, target_shape_split)
                    if self.use_hw_tiling
                    else self._decode(z_tile, target_shape=target_shape_split)
                )
                for z_tile in torch.tensor_split(z, num_splits, dim=2)
            ]
            decoded = torch.cat(decoded_tiles, dim=2)
        else:
            decoded = (
                self._hw_tiled_decode(z, target_shape)
                if self.use_hw_tiling
                else self._decode(z, target_shape=target_shape, timestep=timestep)
            )

        if not return_dict:
            return (decoded,)

        return DecoderOutput(sample=decoded)

    def forward(
        self,
        sample: torch.FloatTensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
    ) -> Union[DecoderOutput, torch.FloatTensor]:
        r"""
        Args:
            sample (`torch.FloatTensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator used to sample from the posterior.
        """
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        dec = self.decode(z, target_shape=sample.shape).sample

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)
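
Note: the wrapper above exposes two independent memory-saving modes: `enable_z_tiling` splits the latent video into chunks along the frame axis, and `enable_hw_tiling` decodes overlapping spatial tiles and blends the seams. A minimal sketch of how a caller might toggle them is shown below; it uses only methods defined in this file, but `latents` and `target_shape` are placeholders supplied by the caller, and the parameter values are just examples, not settings prescribed by this repo.

import torch
from flash_head.ltx_video.models.autoencoders.vae import AutoencoderKLWrapper

def decode_with_tiling(vae: AutoencoderKLWrapper, latents: torch.Tensor, target_shape):
    # Decode the latent video in chunks of 8 frames along dim 2 ...
    vae.enable_z_tiling(z_sample_size=8)
    # ... and decode each chunk in overlapping 512-pixel spatial tiles.
    vae.enable_hw_tiling()
    vae.set_tiling_params(sample_size=512, overlap_factor=0.25)

    with torch.no_grad():
        out = vae.decode(latents, target_shape=target_shape).sample

    # Restore single-pass behaviour afterwards.
    vae.disable_z_tiling()
    vae.disable_hw_tiling()
    return out
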
flash_head/ltx_video/models/autoencoders/vae_encode.py
ADDED
@@ -0,0 +1,256 @@
from typing import Tuple
import torch
from diffusers import AutoencoderKL
from einops import rearrange
from torch import Tensor


from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from flash_head.ltx_video.models.autoencoders.video_autoencoder import (
    Downsample3D,
    VideoAutoencoder,
)

try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None


def vae_encode(
    media_items: Tensor,
    vae: AutoencoderKL,
    split_size: int = 1,
    vae_per_channel_normalize=False,
) -> Tensor:
    """
    Encodes media items (images or videos) into latent representations using a specified VAE model.
    The function supports processing batches of images or video frames and can handle the processing
    in smaller sub-batches if needed.

    Args:
        media_items (Tensor): A torch Tensor containing the media items to encode. The expected
            shape is (batch_size, channels, height, width) for images or (batch_size, channels,
            frames, height, width) for videos.
        vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library,
            pre-configured and loaded with the appropriate model weights.
        split_size (int, optional): The number of sub-batches to split the input batch into for encoding.
            If set to more than 1, the input media items are processed in smaller batches according to
            this value. Defaults to 1, which processes all items in a single batch.

    Returns:
        Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted
            to match the input shape, scaled by the model's configuration.

    Examples:
        >>> import torch
        >>> from diffusers import AutoencoderKL
        >>> vae = AutoencoderKL.from_pretrained('your-model-name')
        >>> images = torch.rand(10, 3, 8, 256, 256)  # Example tensor with 10 videos of 8 frames.
        >>> latents = vae_encode(images, vae)
        >>> print(latents.shape)  # Output shape will depend on the model's latent configuration.

    Note:
        In case of a video, the function encodes the media item frame-by-frame.
    """
    is_video_shaped = media_items.dim() == 5
    batch_size, channels = media_items.shape[0:2]

    if channels != 3:
        raise ValueError(f"Expects tensors with 3 channels, got {channels}.")

    # if is_video_shaped and not isinstance(
    #     vae, (VideoAutoencoder, CausalVideoAutoencoder)
    # ):  # after FSDP wrapping the isinstance check can no longer be evaluated, so the condition is dropped below
    if is_video_shaped and False:  # False keeps compatibility with FSDP-wrapped models
        media_items = rearrange(media_items, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(media_items) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        encode_bs = len(media_items) // split_size
        # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)]
        latents = []
        if media_items.device.type == "xla":
            xm.mark_step()
        for image_batch in media_items.split(encode_bs):
            latents.append(vae.encode(image_batch).latent_dist.sample())
            if media_items.device.type == "xla":
                xm.mark_step()
        latents = torch.cat(latents, dim=0)
    else:
        latents = vae.encode(media_items).latent_dist.sample()

    latents = normalize_latents(latents, vae, vae_per_channel_normalize)
    # if is_video_shaped and not isinstance(
    #     vae, (VideoAutoencoder, CausalVideoAutoencoder)
    # ):
    if is_video_shaped and False:  # False keeps compatibility with FSDP-wrapped models
        latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size)
    return latents


def vae_decode(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool = True,
    split_size: int = 1,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    is_video_shaped = latents.dim() == 5
    batch_size = latents.shape[0]

    # if is_video_shaped and not isinstance(
    #     vae, (VideoAutoencoder, CausalVideoAutoencoder)
    # ):
    if is_video_shaped and False:  # False keeps compatibility with FSDP-wrapped models
        latents = rearrange(latents, "b c n h w -> (b n) c h w")
    if split_size > 1:
        if len(latents) % split_size != 0:
            raise ValueError(
                "Error: The batch size must be divisible by 'train.vae_bs_split"
            )
        encode_bs = len(latents) // split_size
        image_batch = [
            _run_decoder(
                latent_batch, vae, is_video, vae_per_channel_normalize, timestep
            )
            for latent_batch in latents.split(encode_bs)
        ]
        images = torch.cat(image_batch, dim=0)
    else:
        images = _run_decoder(
            latents, vae, is_video, vae_per_channel_normalize, timestep
        )

    # if is_video_shaped and not isinstance(
    #     vae, (VideoAutoencoder, CausalVideoAutoencoder)
    # ):
    if is_video_shaped and False:  # False keeps compatibility with FSDP-wrapped models
        images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size)
    return images


def _run_decoder(
    latents: Tensor,
    vae: AutoencoderKL,
    is_video: bool,
    vae_per_channel_normalize=False,
    timestep=None,
) -> Tensor:
    # if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)):
    if False:  # hard-coded for compatibility with FSDP-wrapped models
        *_, fl, hl, wl = latents.shape
        temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae)
        latents = latents.to(vae.dtype)
        vae_decode_kwargs = {}
        if timestep is not None:
            vae_decode_kwargs["timestep"] = timestep
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=(
                1,
                3,
                fl * temporal_scale if is_video else 1,
                hl * spatial_scale,
                wl * spatial_scale,
            ),
            **vae_decode_kwargs,
        )[0]
    else:
        image = vae.decode(
            un_normalize_latents(latents, vae, vae_per_channel_normalize),
            return_dict=False,
            target_shape=latents.shape
        )[0]

    return image


def get_vae_size_scale_factor(vae: AutoencoderKL) -> float:
    # if isinstance(vae, CausalVideoAutoencoder):
    if True:  # hard-coded for compatibility with FSDP-wrapped models
        spatial = vae.spatial_downscale_factor
        temporal = vae.temporal_downscale_factor
    else:
        down_blocks = len(
            [
                block
                for block in vae.encoder.down_blocks
                if isinstance(block.downsample, Downsample3D)
            ]
        )
        spatial = vae.config.patch_size * 2**down_blocks
        temporal = (
            vae.config.patch_size_t * 2**down_blocks
            if isinstance(vae, VideoAutoencoder)
            else 1
        )

    return (temporal, spatial, spatial)


def latent_to_pixel_coords(
    latent_coords: Tensor, vae: AutoencoderKL, causal_fix: bool = False
) -> Tensor:
    """
    Converts latent coordinates to pixel coordinates by scaling them according to the VAE's
    configuration.

    Args:
        latent_coords (Tensor): A tensor of shape [batch_size, 3, num_latents]
            containing the latent corner coordinates of each token.
        vae (AutoencoderKL): The VAE model
        causal_fix (bool): Whether to take into account the different temporal scale
            of the first frame. Default = False for backwards compatibility.
    Returns:
        Tensor: A tensor of pixel coordinates corresponding to the input latent coordinates.
    """

    scale_factors = get_vae_size_scale_factor(vae)
    # causal_fix = isinstance(vae, CausalVideoAutoencoder) and causal_fix
    causal_fix = True and causal_fix  # hard-coded True for compatibility with FSDP-wrapped models
    pixel_coords = latent_to_pixel_coords_from_factors(
        latent_coords, scale_factors, causal_fix
    )
    return pixel_coords


def latent_to_pixel_coords_from_factors(
    latent_coords: Tensor, scale_factors: Tuple, causal_fix: bool = False
) -> Tensor:
    pixel_coords = (
        latent_coords
        * torch.tensor(scale_factors, device=latent_coords.device)[None, :, None]
    )
    if causal_fix:
        # Fix temporal scale for first frame to 1 due to causality
        pixel_coords[:, 0] = (pixel_coords[:, 0] + 1 - scale_factors[0]).clamp(min=0)
    return pixel_coords


def normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1))
        / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents * vae.config.scaling_factor
    )


def un_normalize_latents(
    latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False
) -> Tensor:
    return (
        latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)
        if vae_per_channel_normalize
        else latents / vae.config.scaling_factor
    )
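
Note: `vae_encode` and `vae_decode` above are inverses up to the VAE reconstruction error: the encoder output is scaled by `normalize_latents` and the decoder input is rescaled by `un_normalize_latents` with the same statistics. A hedged round-trip sketch follows; "path/to/vae" is a placeholder checkpoint path (not a file shipped with this Space), and the exact output shapes depend on the loaded VAE's temporal/spatial downscale factors.

import torch
from flash_head.ltx_video.models.autoencoders.causal_video_autoencoder import (
    CausalVideoAutoencoder,
)
from flash_head.ltx_video.models.autoencoders.vae_encode import vae_encode, vae_decode

# Hypothetical checkpoint location; replace with a real LTX-style VAE directory.
vae = CausalVideoAutoencoder.from_pretrained("path/to/vae").eval()

# One video: batch 1, 3 channels, 9 frames, 256x256 pixels, values roughly in [-1, 1].
video = torch.rand(1, 3, 9, 256, 256) * 2 - 1

with torch.no_grad():
    latents = vae_encode(video, vae, vae_per_channel_normalize=True)
    recon = vae_decode(latents, vae, is_video=True, vae_per_channel_normalize=True)

print(latents.shape, recon.shape)
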
flash_head/ltx_video/models/autoencoders/video_autoencoder.py
ADDED
@@ -0,0 +1,1045 @@
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from types import SimpleNamespace
|
| 5 |
+
from typing import Any, Mapping, Optional, Tuple, Union
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn import functional
|
| 11 |
+
|
| 12 |
+
from diffusers.utils import logging
|
| 13 |
+
|
| 14 |
+
from flash_head.ltx_video.utils.torch_utils import Identity
|
| 15 |
+
from flash_head.ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
|
| 16 |
+
from flash_head.ltx_video.models.autoencoders.pixel_norm import PixelNorm
|
| 17 |
+
from flash_head.ltx_video.models.autoencoders.vae import AutoencoderKLWrapper
|
| 18 |
+
|
| 19 |
+
logger = logging.get_logger(__name__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class VideoAutoencoder(AutoencoderKLWrapper):
|
| 23 |
+
@classmethod
|
| 24 |
+
def from_pretrained(
|
| 25 |
+
cls,
|
| 26 |
+
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
|
| 27 |
+
*args,
|
| 28 |
+
**kwargs,
|
| 29 |
+
):
|
| 30 |
+
config_local_path = pretrained_model_name_or_path / "config.json"
|
| 31 |
+
config = cls.load_config(config_local_path, **kwargs)
|
| 32 |
+
video_vae = cls.from_config(config)
|
| 33 |
+
video_vae.to(kwargs["torch_dtype"])
|
| 34 |
+
|
| 35 |
+
model_local_path = pretrained_model_name_or_path / "autoencoder.pth"
|
| 36 |
+
ckpt_state_dict = torch.load(model_local_path)
|
| 37 |
+
video_vae.load_state_dict(ckpt_state_dict)
|
| 38 |
+
|
| 39 |
+
statistics_local_path = (
|
| 40 |
+
pretrained_model_name_or_path / "per_channel_statistics.json"
|
| 41 |
+
)
|
| 42 |
+
if statistics_local_path.exists():
|
| 43 |
+
with open(statistics_local_path, "r") as file:
|
| 44 |
+
data = json.load(file)
|
| 45 |
+
transposed_data = list(zip(*data["data"]))
|
| 46 |
+
data_dict = {
|
| 47 |
+
col: torch.tensor(vals)
|
| 48 |
+
for col, vals in zip(data["columns"], transposed_data)
|
| 49 |
+
}
|
| 50 |
+
video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
|
| 51 |
+
video_vae.register_buffer(
|
| 52 |
+
"mean_of_means",
|
| 53 |
+
data_dict.get(
|
| 54 |
+
"mean-of-means", torch.zeros_like(data_dict["std-of-means"])
|
| 55 |
+
),
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
return video_vae
|
| 59 |
+
|
| 60 |
+
@staticmethod
|
| 61 |
+
def from_config(config):
|
| 62 |
+
assert (
|
| 63 |
+
config["_class_name"] == "VideoAutoencoder"
|
| 64 |
+
), "config must have _class_name=VideoAutoencoder"
|
| 65 |
+
if isinstance(config["dims"], list):
|
| 66 |
+
config["dims"] = tuple(config["dims"])
|
| 67 |
+
|
| 68 |
+
assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)"
|
| 69 |
+
|
| 70 |
+
double_z = config.get("double_z", True)
|
| 71 |
+
latent_log_var = config.get(
|
| 72 |
+
"latent_log_var", "per_channel" if double_z else "none"
|
| 73 |
+
)
|
| 74 |
+
use_quant_conv = config.get("use_quant_conv", True)
|
| 75 |
+
|
| 76 |
+
if use_quant_conv and latent_log_var == "uniform":
|
| 77 |
+
raise ValueError("uniform latent_log_var requires use_quant_conv=False")
|
| 78 |
+
|
| 79 |
+
encoder = Encoder(
|
| 80 |
+
dims=config["dims"],
|
| 81 |
+
in_channels=config.get("in_channels", 3),
|
| 82 |
+
out_channels=config["latent_channels"],
|
| 83 |
+
block_out_channels=config["block_out_channels"],
|
| 84 |
+
patch_size=config.get("patch_size", 1),
|
| 85 |
+
latent_log_var=latent_log_var,
|
| 86 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 87 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 88 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
decoder = Decoder(
|
| 92 |
+
dims=config["dims"],
|
| 93 |
+
in_channels=config["latent_channels"],
|
| 94 |
+
out_channels=config.get("out_channels", 3),
|
| 95 |
+
block_out_channels=config["block_out_channels"],
|
| 96 |
+
patch_size=config.get("patch_size", 1),
|
| 97 |
+
norm_layer=config.get("norm_layer", "group_norm"),
|
| 98 |
+
patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)),
|
| 99 |
+
add_channel_padding=config.get("add_channel_padding", False),
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
dims = config["dims"]
|
| 103 |
+
return VideoAutoencoder(
|
| 104 |
+
encoder=encoder,
|
| 105 |
+
decoder=decoder,
|
| 106 |
+
latent_channels=config["latent_channels"],
|
| 107 |
+
dims=dims,
|
| 108 |
+
use_quant_conv=use_quant_conv,
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
@property
|
| 112 |
+
def config(self):
|
| 113 |
+
return SimpleNamespace(
|
| 114 |
+
_class_name="VideoAutoencoder",
|
| 115 |
+
dims=self.dims,
|
| 116 |
+
in_channels=self.encoder.conv_in.in_channels
|
| 117 |
+
// (self.encoder.patch_size_t * self.encoder.patch_size**2),
|
| 118 |
+
out_channels=self.decoder.conv_out.out_channels
|
| 119 |
+
// (self.decoder.patch_size_t * self.decoder.patch_size**2),
|
| 120 |
+
latent_channels=self.decoder.conv_in.in_channels,
|
| 121 |
+
block_out_channels=[
|
| 122 |
+
self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels
|
| 123 |
+
for i in range(len(self.encoder.down_blocks))
|
| 124 |
+
],
|
| 125 |
+
scaling_factor=1.0,
|
| 126 |
+
norm_layer=self.encoder.norm_layer,
|
| 127 |
+
patch_size=self.encoder.patch_size,
|
| 128 |
+
latent_log_var=self.encoder.latent_log_var,
|
| 129 |
+
use_quant_conv=self.use_quant_conv,
|
| 130 |
+
patch_size_t=self.encoder.patch_size_t,
|
| 131 |
+
add_channel_padding=self.encoder.add_channel_padding,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
@property
|
| 135 |
+
def is_video_supported(self):
|
| 136 |
+
"""
|
| 137 |
+
Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images.
|
| 138 |
+
"""
|
| 139 |
+
return self.dims != 2
|
| 140 |
+
|
| 141 |
+
@property
|
| 142 |
+
def downscale_factor(self):
|
| 143 |
+
return self.encoder.downsample_factor
|
| 144 |
+
|
| 145 |
+
def to_json_string(self) -> str:
|
| 146 |
+
import json
|
| 147 |
+
|
| 148 |
+
return json.dumps(self.config.__dict__)
|
| 149 |
+
|
| 150 |
+
def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
|
| 151 |
+
model_keys = set(name for name, _ in self.named_parameters())
|
| 152 |
+
|
| 153 |
+
key_mapping = {
|
| 154 |
+
".resnets.": ".res_blocks.",
|
| 155 |
+
"downsamplers.0": "downsample",
|
| 156 |
+
"upsamplers.0": "upsample",
|
| 157 |
+
}
|
| 158 |
+
|
| 159 |
+
converted_state_dict = {}
|
| 160 |
+
for key, value in state_dict.items():
|
| 161 |
+
for k, v in key_mapping.items():
|
| 162 |
+
key = key.replace(k, v)
|
| 163 |
+
|
| 164 |
+
if "norm" in key and key not in model_keys:
|
| 165 |
+
logger.info(
|
| 166 |
+
f"Removing key {key} from state_dict as it is not present in the model"
|
| 167 |
+
)
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
converted_state_dict[key] = value
|
| 171 |
+
|
| 172 |
+
super().load_state_dict(converted_state_dict, strict=strict)
|
| 173 |
+
|
| 174 |
+
def last_layer(self):
|
| 175 |
+
if hasattr(self.decoder, "conv_out"):
|
| 176 |
+
if isinstance(self.decoder.conv_out, nn.Sequential):
|
| 177 |
+
last_layer = self.decoder.conv_out[-1]
|
| 178 |
+
else:
|
| 179 |
+
last_layer = self.decoder.conv_out
|
| 180 |
+
else:
|
| 181 |
+
last_layer = self.decoder.layers[-1]
|
| 182 |
+
return last_layer
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class Encoder(nn.Module):
|
| 186 |
+
r"""
|
| 187 |
+
The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation.
|
| 188 |
+
|
| 189 |
+
Args:
|
| 190 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 191 |
+
The number of input channels.
|
| 192 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 193 |
+
The number of output channels.
|
| 194 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 195 |
+
The number of output channels for each block.
|
| 196 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 197 |
+
The number of layers per block.
|
| 198 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 199 |
+
The number of groups for normalization.
|
| 200 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 201 |
+
The patch size to use. Should be a power of 2.
|
| 202 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 203 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 204 |
+
latent_log_var (`str`, *optional*, defaults to `per_channel`):
|
| 205 |
+
The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`.
|
| 206 |
+
"""
|
| 207 |
+
|
| 208 |
+
def __init__(
|
| 209 |
+
self,
|
| 210 |
+
dims: Union[int, Tuple[int, int]] = 3,
|
| 211 |
+
in_channels: int = 3,
|
| 212 |
+
out_channels: int = 3,
|
| 213 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 214 |
+
layers_per_block: int = 2,
|
| 215 |
+
norm_num_groups: int = 32,
|
| 216 |
+
patch_size: Union[int, Tuple[int]] = 1,
|
| 217 |
+
norm_layer: str = "group_norm", # group_norm, pixel_norm
|
| 218 |
+
latent_log_var: str = "per_channel",
|
| 219 |
+
patch_size_t: Optional[int] = None,
|
| 220 |
+
add_channel_padding: Optional[bool] = False,
|
| 221 |
+
):
|
| 222 |
+
super().__init__()
|
| 223 |
+
self.patch_size = patch_size
|
| 224 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 225 |
+
self.add_channel_padding = add_channel_padding
|
| 226 |
+
self.layers_per_block = layers_per_block
|
| 227 |
+
self.norm_layer = norm_layer
|
| 228 |
+
self.latent_channels = out_channels
|
| 229 |
+
self.latent_log_var = latent_log_var
|
| 230 |
+
if add_channel_padding:
|
| 231 |
+
in_channels = in_channels * self.patch_size**3
|
| 232 |
+
else:
|
| 233 |
+
in_channels = in_channels * self.patch_size_t * self.patch_size**2
|
| 234 |
+
self.in_channels = in_channels
|
| 235 |
+
output_channel = block_out_channels[0]
|
| 236 |
+
|
| 237 |
+
self.conv_in = make_conv_nd(
|
| 238 |
+
dims=dims,
|
| 239 |
+
in_channels=in_channels,
|
| 240 |
+
out_channels=output_channel,
|
| 241 |
+
kernel_size=3,
|
| 242 |
+
stride=1,
|
| 243 |
+
padding=1,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
self.down_blocks = nn.ModuleList([])
|
| 247 |
+
|
| 248 |
+
for i in range(len(block_out_channels)):
|
| 249 |
+
input_channel = output_channel
|
| 250 |
+
output_channel = block_out_channels[i]
|
| 251 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 252 |
+
|
| 253 |
+
down_block = DownEncoderBlock3D(
|
| 254 |
+
dims=dims,
|
| 255 |
+
in_channels=input_channel,
|
| 256 |
+
out_channels=output_channel,
|
| 257 |
+
num_layers=self.layers_per_block,
|
| 258 |
+
add_downsample=not is_final_block and 2**i >= patch_size,
|
| 259 |
+
resnet_eps=1e-6,
|
| 260 |
+
downsample_padding=0,
|
| 261 |
+
resnet_groups=norm_num_groups,
|
| 262 |
+
norm_layer=norm_layer,
|
| 263 |
+
)
|
| 264 |
+
self.down_blocks.append(down_block)
|
| 265 |
+
|
| 266 |
+
self.mid_block = UNetMidBlock3D(
|
| 267 |
+
dims=dims,
|
| 268 |
+
in_channels=block_out_channels[-1],
|
| 269 |
+
num_layers=self.layers_per_block,
|
| 270 |
+
resnet_eps=1e-6,
|
| 271 |
+
resnet_groups=norm_num_groups,
|
| 272 |
+
norm_layer=norm_layer,
|
| 273 |
+
)
|
| 274 |
+
|
| 275 |
+
# out
|
| 276 |
+
if norm_layer == "group_norm":
|
| 277 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 278 |
+
num_channels=block_out_channels[-1],
|
| 279 |
+
num_groups=norm_num_groups,
|
| 280 |
+
eps=1e-6,
|
| 281 |
+
)
|
| 282 |
+
elif norm_layer == "pixel_norm":
|
| 283 |
+
self.conv_norm_out = PixelNorm()
|
| 284 |
+
self.conv_act = nn.SiLU()
|
| 285 |
+
|
| 286 |
+
conv_out_channels = out_channels
|
| 287 |
+
if latent_log_var == "per_channel":
|
| 288 |
+
conv_out_channels *= 2
|
| 289 |
+
elif latent_log_var == "uniform":
|
| 290 |
+
conv_out_channels += 1
|
| 291 |
+
elif latent_log_var != "none":
|
| 292 |
+
raise ValueError(f"Invalid latent_log_var: {latent_log_var}")
|
| 293 |
+
self.conv_out = make_conv_nd(
|
| 294 |
+
dims, block_out_channels[-1], conv_out_channels, 3, padding=1
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
self.gradient_checkpointing = False
|
| 298 |
+
|
| 299 |
+
@property
|
| 300 |
+
def downscale_factor(self):
|
| 301 |
+
return (
|
| 302 |
+
2
|
| 303 |
+
** len(
|
| 304 |
+
[
|
| 305 |
+
block
|
| 306 |
+
for block in self.down_blocks
|
| 307 |
+
if isinstance(block.downsample, Downsample3D)
|
| 308 |
+
]
|
| 309 |
+
)
|
| 310 |
+
* self.patch_size
|
| 311 |
+
)
|
| 312 |
+
|
| 313 |
+
def forward(
|
| 314 |
+
self, sample: torch.FloatTensor, return_features=False
|
| 315 |
+
) -> torch.FloatTensor:
|
| 316 |
+
r"""The forward method of the `Encoder` class."""
|
| 317 |
+
|
| 318 |
+
downsample_in_time = sample.shape[2] != 1
|
| 319 |
+
|
| 320 |
+
# patchify
|
| 321 |
+
patch_size_t = self.patch_size_t if downsample_in_time else 1
|
| 322 |
+
sample = patchify(
|
| 323 |
+
sample,
|
| 324 |
+
patch_size_hw=self.patch_size,
|
| 325 |
+
patch_size_t=patch_size_t,
|
| 326 |
+
add_channel_padding=self.add_channel_padding,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
sample = self.conv_in(sample)
|
| 330 |
+
|
| 331 |
+
checkpoint_fn = (
|
| 332 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 333 |
+
if self.gradient_checkpointing and self.training
|
| 334 |
+
else lambda x: x
|
| 335 |
+
)
|
| 336 |
+
|
| 337 |
+
if return_features:
|
| 338 |
+
features = []
|
| 339 |
+
for down_block in self.down_blocks:
|
| 340 |
+
sample = checkpoint_fn(down_block)(
|
| 341 |
+
sample, downsample_in_time=downsample_in_time
|
| 342 |
+
)
|
| 343 |
+
if return_features:
|
| 344 |
+
features.append(sample)
|
| 345 |
+
|
| 346 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 347 |
+
|
| 348 |
+
# post-process
|
| 349 |
+
sample = self.conv_norm_out(sample)
|
| 350 |
+
sample = self.conv_act(sample)
|
| 351 |
+
sample = self.conv_out(sample)
|
| 352 |
+
|
| 353 |
+
if self.latent_log_var == "uniform":
|
| 354 |
+
last_channel = sample[:, -1:, ...]
|
| 355 |
+
num_dims = sample.dim()
|
| 356 |
+
|
| 357 |
+
if num_dims == 4:
|
| 358 |
+
# For shape (B, C, H, W)
|
| 359 |
+
repeated_last_channel = last_channel.repeat(
|
| 360 |
+
1, sample.shape[1] - 2, 1, 1
|
| 361 |
+
)
|
| 362 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 363 |
+
elif num_dims == 5:
|
| 364 |
+
# For shape (B, C, F, H, W)
|
| 365 |
+
repeated_last_channel = last_channel.repeat(
|
| 366 |
+
1, sample.shape[1] - 2, 1, 1, 1
|
| 367 |
+
)
|
| 368 |
+
sample = torch.cat([sample, repeated_last_channel], dim=1)
|
| 369 |
+
else:
|
| 370 |
+
raise ValueError(f"Invalid input shape: {sample.shape}")
|
| 371 |
+
|
| 372 |
+
if return_features:
|
| 373 |
+
features.append(sample[:, : self.latent_channels, ...])
|
| 374 |
+
return sample, features
|
| 375 |
+
return sample
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class Decoder(nn.Module):
|
| 379 |
+
r"""
|
| 380 |
+
The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample.
|
| 381 |
+
|
| 382 |
+
Args:
|
| 383 |
+
in_channels (`int`, *optional*, defaults to 3):
|
| 384 |
+
The number of input channels.
|
| 385 |
+
out_channels (`int`, *optional*, defaults to 3):
|
| 386 |
+
The number of output channels.
|
| 387 |
+
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
|
| 388 |
+
The number of output channels for each block.
|
| 389 |
+
layers_per_block (`int`, *optional*, defaults to 2):
|
| 390 |
+
The number of layers per block.
|
| 391 |
+
norm_num_groups (`int`, *optional*, defaults to 32):
|
| 392 |
+
The number of groups for normalization.
|
| 393 |
+
patch_size (`int`, *optional*, defaults to 1):
|
| 394 |
+
The patch size to use. Should be a power of 2.
|
| 395 |
+
norm_layer (`str`, *optional*, defaults to `group_norm`):
|
| 396 |
+
The normalization layer to use. Can be either `group_norm` or `pixel_norm`.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
def __init__(
|
| 400 |
+
self,
|
| 401 |
+
dims,
|
| 402 |
+
in_channels: int = 3,
|
| 403 |
+
out_channels: int = 3,
|
| 404 |
+
block_out_channels: Tuple[int, ...] = (64,),
|
| 405 |
+
layers_per_block: int = 2,
|
| 406 |
+
norm_num_groups: int = 32,
|
| 407 |
+
patch_size: int = 1,
|
| 408 |
+
norm_layer: str = "group_norm",
|
| 409 |
+
patch_size_t: Optional[int] = None,
|
| 410 |
+
add_channel_padding: Optional[bool] = False,
|
| 411 |
+
):
|
| 412 |
+
super().__init__()
|
| 413 |
+
self.patch_size = patch_size
|
| 414 |
+
self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size
|
| 415 |
+
self.add_channel_padding = add_channel_padding
|
| 416 |
+
self.layers_per_block = layers_per_block
|
| 417 |
+
if add_channel_padding:
|
| 418 |
+
out_channels = out_channels * self.patch_size**3
|
| 419 |
+
else:
|
| 420 |
+
out_channels = out_channels * self.patch_size_t * self.patch_size**2
|
| 421 |
+
self.out_channels = out_channels
|
| 422 |
+
|
| 423 |
+
self.conv_in = make_conv_nd(
|
| 424 |
+
dims,
|
| 425 |
+
in_channels,
|
| 426 |
+
block_out_channels[-1],
|
| 427 |
+
kernel_size=3,
|
| 428 |
+
stride=1,
|
| 429 |
+
padding=1,
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
self.mid_block = None
|
| 433 |
+
self.up_blocks = nn.ModuleList([])
|
| 434 |
+
|
| 435 |
+
self.mid_block = UNetMidBlock3D(
|
| 436 |
+
dims=dims,
|
| 437 |
+
in_channels=block_out_channels[-1],
|
| 438 |
+
num_layers=self.layers_per_block,
|
| 439 |
+
resnet_eps=1e-6,
|
| 440 |
+
resnet_groups=norm_num_groups,
|
| 441 |
+
norm_layer=norm_layer,
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
reversed_block_out_channels = list(reversed(block_out_channels))
|
| 445 |
+
output_channel = reversed_block_out_channels[0]
|
| 446 |
+
for i in range(len(reversed_block_out_channels)):
|
| 447 |
+
prev_output_channel = output_channel
|
| 448 |
+
output_channel = reversed_block_out_channels[i]
|
| 449 |
+
|
| 450 |
+
is_final_block = i == len(block_out_channels) - 1
|
| 451 |
+
|
| 452 |
+
up_block = UpDecoderBlock3D(
|
| 453 |
+
dims=dims,
|
| 454 |
+
num_layers=self.layers_per_block + 1,
|
| 455 |
+
in_channels=prev_output_channel,
|
| 456 |
+
out_channels=output_channel,
|
| 457 |
+
add_upsample=not is_final_block
|
| 458 |
+
and 2 ** (len(block_out_channels) - i - 1) > patch_size,
|
| 459 |
+
resnet_eps=1e-6,
|
| 460 |
+
resnet_groups=norm_num_groups,
|
| 461 |
+
norm_layer=norm_layer,
|
| 462 |
+
)
|
| 463 |
+
self.up_blocks.append(up_block)
|
| 464 |
+
|
| 465 |
+
if norm_layer == "group_norm":
|
| 466 |
+
self.conv_norm_out = nn.GroupNorm(
|
| 467 |
+
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6
|
| 468 |
+
)
|
| 469 |
+
elif norm_layer == "pixel_norm":
|
| 470 |
+
self.conv_norm_out = PixelNorm()
|
| 471 |
+
|
| 472 |
+
self.conv_act = nn.SiLU()
|
| 473 |
+
self.conv_out = make_conv_nd(
|
| 474 |
+
dims, block_out_channels[0], out_channels, 3, padding=1
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
self.gradient_checkpointing = False
|
| 478 |
+
|
| 479 |
+
def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
| 480 |
+
r"""The forward method of the `Decoder` class."""
|
| 481 |
+
assert target_shape is not None, "target_shape must be provided"
|
| 482 |
+
upsample_in_time = sample.shape[2] < target_shape[2]
|
| 483 |
+
|
| 484 |
+
sample = self.conv_in(sample)
|
| 485 |
+
|
| 486 |
+
upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
|
| 487 |
+
|
| 488 |
+
checkpoint_fn = (
|
| 489 |
+
partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)
|
| 490 |
+
if self.gradient_checkpointing and self.training
|
| 491 |
+
else lambda x: x
|
| 492 |
+
)
|
| 493 |
+
|
| 494 |
+
sample = checkpoint_fn(self.mid_block)(sample)
|
| 495 |
+
sample = sample.to(upscale_dtype)
|
| 496 |
+
|
| 497 |
+
for up_block in self.up_blocks:
|
| 498 |
+
sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time)
|
| 499 |
+
|
| 500 |
+
# post-process
|
| 501 |
+
sample = self.conv_norm_out(sample)
|
| 502 |
+
sample = self.conv_act(sample)
|
| 503 |
+
sample = self.conv_out(sample)
|
| 504 |
+
|
| 505 |
+
# un-patchify
|
| 506 |
+
patch_size_t = self.patch_size_t if upsample_in_time else 1
|
| 507 |
+
sample = unpatchify(
|
| 508 |
+
sample,
|
| 509 |
+
patch_size_hw=self.patch_size,
|
| 510 |
+
patch_size_t=patch_size_t,
|
| 511 |
+
add_channel_padding=self.add_channel_padding,
|
| 512 |
+
)
|
| 513 |
+
|
| 514 |
+
return sample
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
class DownEncoderBlock3D(nn.Module):
|
| 518 |
+
def __init__(
|
| 519 |
+
self,
|
| 520 |
+
dims: Union[int, Tuple[int, int]],
|
| 521 |
+
in_channels: int,
|
| 522 |
+
out_channels: int,
|
| 523 |
+
dropout: float = 0.0,
|
| 524 |
+
num_layers: int = 1,
|
| 525 |
+
resnet_eps: float = 1e-6,
|
| 526 |
+
resnet_groups: int = 32,
|
| 527 |
+
add_downsample: bool = True,
|
| 528 |
+
downsample_padding: int = 1,
|
| 529 |
+
norm_layer: str = "group_norm",
|
| 530 |
+
):
|
| 531 |
+
super().__init__()
|
| 532 |
+
res_blocks = []
|
| 533 |
+
|
| 534 |
+
for i in range(num_layers):
|
| 535 |
+
in_channels = in_channels if i == 0 else out_channels
|
| 536 |
+
res_blocks.append(
|
| 537 |
+
ResnetBlock3D(
|
| 538 |
+
dims=dims,
|
| 539 |
+
in_channels=in_channels,
|
| 540 |
+
out_channels=out_channels,
|
| 541 |
+
eps=resnet_eps,
|
| 542 |
+
groups=resnet_groups,
|
| 543 |
+
dropout=dropout,
|
| 544 |
+
norm_layer=norm_layer,
|
| 545 |
+
)
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
self.res_blocks = nn.ModuleList(res_blocks)
|
| 549 |
+
|
| 550 |
+
if add_downsample:
|
| 551 |
+
self.downsample = Downsample3D(
|
| 552 |
+
dims,
|
| 553 |
+
out_channels,
|
| 554 |
+
out_channels=out_channels,
|
| 555 |
+
padding=downsample_padding,
|
| 556 |
+
)
|
| 557 |
+
else:
|
| 558 |
+
self.downsample = Identity()
|
| 559 |
+
|
| 560 |
+
def forward(
|
| 561 |
+
self, hidden_states: torch.FloatTensor, downsample_in_time
|
| 562 |
+
) -> torch.FloatTensor:
|
| 563 |
+
for resnet in self.res_blocks:
|
| 564 |
+
hidden_states = resnet(hidden_states)
|
| 565 |
+
|
| 566 |
+
hidden_states = self.downsample(
|
| 567 |
+
hidden_states, downsample_in_time=downsample_in_time
|
| 568 |
+
)
|
| 569 |
+
|
| 570 |
+
return hidden_states
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
class UNetMidBlock3D(nn.Module):
    """
    A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks.

    Args:
        in_channels (`int`): The number of input channels.
        dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
        num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
        resnet_groups (`int`, *optional*, defaults to 32):
            The number of groups to use in the group normalization layers of the resnet blocks.

    Returns:
        `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
        in_channels, height, width)`.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        resnet_groups = (
            resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
        )

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlock3D(
                    dims=dims,
                    in_channels=in_channels,
                    out_channels=in_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
                for _ in range(num_layers)
            ]
        )

    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        return hidden_states


class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: int,
        resolution_idx: Optional[int] = None,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_groups: int = 32,
        add_upsample: bool = True,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        res_blocks = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            res_blocks.append(
                ResnetBlock3D(
                    dims=dims,
                    in_channels=input_channels,
                    out_channels=out_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    norm_layer=norm_layer,
                )
            )

        self.res_blocks = nn.ModuleList(res_blocks)

        if add_upsample:
            self.upsample = Upsample3D(
                dims=dims, channels=out_channels, out_channels=out_channels
            )
        else:
            self.upsample = Identity()

        self.resolution_idx = resolution_idx

    def forward(
        self, hidden_states: torch.FloatTensor, upsample_in_time=True
    ) -> torch.FloatTensor:
        for resnet in self.res_blocks:
            hidden_states = resnet(hidden_states)

        hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time)

        return hidden_states


class ResnetBlock3D(nn.Module):
    r"""
    A Resnet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, default to be `None`):
            The number of output channels for the first conv layer. If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
        groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
    """

    def __init__(
        self,
        dims: Union[int, Tuple[int, int]],
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        dropout: float = 0.0,
        groups: int = 32,
        eps: float = 1e-6,
        norm_layer: str = "group_norm",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        if norm_layer == "group_norm":
            self.norm1 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=in_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm1 = PixelNorm()

        self.non_linearity = nn.SiLU()

        self.conv1 = make_conv_nd(
            dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        if norm_layer == "group_norm":
            self.norm2 = torch.nn.GroupNorm(
                num_groups=groups, num_channels=out_channels, eps=eps, affine=True
            )
        elif norm_layer == "pixel_norm":
            self.norm2 = PixelNorm()

        self.dropout = torch.nn.Dropout(dropout)

        self.conv2 = make_conv_nd(
            dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

        self.conv_shortcut = (
            make_linear_nd(
                dims=dims, in_channels=in_channels, out_channels=out_channels
            )
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(
        self,
        input_tensor: torch.FloatTensor,
    ) -> torch.FloatTensor:
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)

        hidden_states = self.non_linearity(hidden_states)

        hidden_states = self.dropout(hidden_states)

        hidden_states = self.conv2(hidden_states)

        input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden_states

        return output_tensor


class Downsample3D(nn.Module):
    def __init__(
        self,
        dims,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        padding: int = 1,
    ):
        super().__init__()
        stride: int = 2
        self.padding = padding
        self.in_channels = in_channels
        self.dims = dims
        self.conv = make_conv_nd(
            dims=dims,
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x, downsample_in_time=True):
        conv = self.conv
        if self.padding == 0:
            if self.dims == 2:
                padding = (0, 1, 0, 1)
            else:
                padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0)

            x = functional.pad(x, padding, mode="constant", value=0)

        if self.dims == (2, 1) and not downsample_in_time:
            return conv(x, skip_time_conv=True)

        return conv(x)


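# Note on Downsample3D's `padding == 0` branch above: when the convolution itself
# applies no padding, the input is padded asymmetrically (one extra column on the
# right, one extra row on the bottom, and one extra frame at the end of time when
# downsampling temporally) before the stride-2 convolution, so each dimension is
# halved cleanly without symmetric padding inside `make_conv_nd`.

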
class Upsample3D(nn.Module):
    """
    An upsampling layer for 3D tensors of shape (B, C, D, H, W).

    :param channels: channels in the inputs and outputs.
    """

    def __init__(self, dims, channels, out_channels=None):
        super().__init__()
        self.dims = dims
        self.channels = channels
        self.out_channels = out_channels or channels
        self.conv = make_conv_nd(
            dims, channels, out_channels, kernel_size=3, padding=1, bias=True
        )

    def forward(self, x, upsample_in_time):
        if self.dims == 2:
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
        else:
            time_scale_factor = 2 if upsample_in_time else 1
            # print("before:", x.shape)
            b, c, d, h, w = x.shape
            x = rearrange(x, "b c d h w -> (b d) c h w")
            # height and width interpolate
            x = functional.interpolate(
                x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest"
            )
            _, _, h, w = x.shape

            if not upsample_in_time and self.dims == (2, 1):
                x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w)
                return self.conv(x, skip_time_conv=True)

            # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension
            x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b)

            # (b h w) c 1 d
            new_d = x.shape[-1] * time_scale_factor
            x = functional.interpolate(x, (1, new_d), mode="nearest")
            # (b h w) c 1 new_d
            x = rearrange(
                x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d
            )
            # b c d h w

            # x = functional.interpolate(
            #     x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
            # )
            # print("after:", x.shape)

        return self.conv(x)


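# A minimal shape sketch (illustrative only; `_expected_upsample3d_shape` is an
# added helper, not part of the model code): for the 3D / (2, 1) case, Upsample3D
# doubles height and width and doubles the temporal dimension only when
# `upsample_in_time` is True, with the channel count assumed unchanged
# (out_channels == channels).
def _expected_upsample3d_shape(shape, upsample_in_time=True):
    b, c, d, h, w = shape
    time_scale_factor = 2 if upsample_in_time else 1
    return (b, c, d * time_scale_factor, h * 2, w * 2)

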
def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x
    if x.dim() == 4:
        x = rearrange(
            x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b c (f p) (h q) (w r) -> b (c p r q) f h w",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )
    else:
        raise ValueError(f"Invalid input shape: {x.shape}")

    if (
        (x.dim() == 5)
        and (patch_size_hw > patch_size_t)
        and (patch_size_t > 1 or add_channel_padding)
    ):
        channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1]
        padding_zeros = torch.zeros(
            x.shape[0],
            channels_to_pad,
            x.shape[2],
            x.shape[3],
            x.shape[4],
            device=x.device,
            dtype=x.dtype,
        )
        x = torch.cat([padding_zeros, x], dim=1)

    return x


def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False):
    if patch_size_hw == 1 and patch_size_t == 1:
        return x

    if (
        (x.dim() == 5)
        and (patch_size_hw > patch_size_t)
        and (patch_size_t > 1 or add_channel_padding)
    ):
        channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw))
        x = x[:, :channels_to_keep, :, :, :]

    if x.dim() == 4:
        x = rearrange(
            x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw
        )
    elif x.dim() == 5:
        x = rearrange(
            x,
            "b (c p r q) f h w -> b c (f p) (h q) (w r)",
            p=patch_size_t,
            q=patch_size_hw,
            r=patch_size_hw,
        )

    return x


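# A small worked example of the patchify shape arithmetic (illustrative;
# `_demo_patchify_shapes` is an added sketch, not used by the model): with
# patch_size_hw=4 and patch_size_t=4, a (1, 3, 8, 64, 64) video is folded into
# (1, 3 * 4 * 4 * 4, 2, 16, 16) = (1, 192, 2, 16, 16), and unpatchify restores
# the original layout exactly.
def _demo_patchify_shapes():
    x = torch.randn(1, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    assert x_patched.shape == (1, 192, 2, 16, 16)
    assert unpatchify(x_patched, patch_size_hw=4, patch_size_t=4).shape == x.shape

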
def create_video_autoencoder_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [
            128,
            256,
            512,
            512,
        ],  # Number of output channels of each encoder / decoder inner block
        "patch_size": 1,
    }

    return config


def create_video_autoencoder_pathify4x4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": (
            2,
            1,
        ),  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "latent_log_var": "uniform",
    }

    return config


def create_video_autoencoder_pathify4x4_config(
    latent_channels: int = 4,
):
    config = {
        "_class_name": "VideoAutoencoder",
        "dims": 2,  # 2 for Conv2d, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d
        "in_channels": 3,  # Number of input color channels (e.g., RGB)
        "out_channels": 3,  # Number of output color channels
        "latent_channels": latent_channels,  # Number of channels in the latent space representation
        "block_out_channels": [512]
        * 4,  # Number of output channels of each encoder / decoder inner block
        "patch_size": 4,
        "norm_layer": "pixel_norm",
    }

    return config


def test_vae_patchify_unpatchify():
    import torch

    x = torch.randn(2, 3, 8, 64, 64)
    x_patched = patchify(x, patch_size_hw=4, patch_size_t=4)
    x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4)
    assert torch.allclose(x, x_unpatched)


def demo_video_autoencoder_forward_backward():
    # Configuration for the VideoAutoencoder
    config = create_video_autoencoder_pathify4x4x4_config()

    # Instantiate the VideoAutoencoder with the specified configuration
    video_autoencoder = VideoAutoencoder.from_config(config)

    print(video_autoencoder)

    # Print the total number of parameters in the video autoencoder
    total_params = sum(p.numel() for p in video_autoencoder.parameters())
    print(f"Total number of parameters in VideoAutoencoder: {total_params:,}")

    # Create a mock input tensor simulating a batch of videos
    # Shape: (batch_size, channels, depth, height, width)
    # E.g., 2 videos, each with 3 color channels, 8 frames, and 64x64 pixels per frame
    input_videos = torch.randn(2, 3, 8, 64, 64)

    # Forward pass: encode and decode the input videos
    latent = video_autoencoder.encode(input_videos).latent_dist.mode()
    print(f"input shape={input_videos.shape}")
    print(f"latent shape={latent.shape}")
    reconstructed_videos = video_autoencoder.decode(
        latent, target_shape=input_videos.shape
    ).sample

    print(f"reconstructed shape={reconstructed_videos.shape}")

    # Calculate the loss (e.g., mean squared error)
    loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos)

    # Perform backward pass
    loss.backward()

    print(f"Demo completed with loss: {loss.item()}")


# Ensure to call the demo function to execute the forward and backward pass
if __name__ == "__main__":
    demo_video_autoencoder_forward_backward()
flash_head/ltx_video/models/transformers/__init__.py
ADDED
File without changes
flash_head/ltx_video/models/transformers/attention.py
ADDED
@@ -0,0 +1,1265 @@
import inspect
from importlib import import_module
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from diffusers.models.activations import GEGLU, GELU, ApproximateGELU
from diffusers.models.attention import _chunked_feed_forward
from diffusers.models.attention_processor import (
    LoRAAttnAddedKVProcessor,
    LoRAAttnProcessor,
    LoRAAttnProcessor2_0,
    LoRAXFormersAttnProcessor,
    SpatialNorm,
)
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.normalization import RMSNorm
from diffusers.utils import deprecate, logging
from diffusers.utils.torch_utils import maybe_allow_in_graph
from einops import rearrange
from torch import nn

from flash_head.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy

try:
    from torch_xla.experimental.custom_kernel import flash_attention
except ImportError:
    # workaround for automatic tests. Currently this function is manually patched
    # to the torch_xla lib on setup of container
    pass

# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py

logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
    r"""
    A basic Transformer block.

    Parameters:
        dim (`int`): The number of channels in the input and output.
        num_attention_heads (`int`): The number of heads to use for multi-head attention.
        attention_head_dim (`int`): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
        num_embeds_ada_norm (:
            obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
        attention_bias (:
            obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
        only_cross_attention (`bool`, *optional*):
            Whether to use only cross-attention layers. In this case two cross attention layers are used.
        double_self_attention (`bool`, *optional*):
            Whether to use two self-attention layers. In this case no cross attention layers are used.
        upcast_attention (`bool`, *optional*):
            Whether to upcast the attention computation to float32. This is useful for mixed precision training.
        norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
            Whether to use learnable elementwise affine parameters for normalization.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`):
            The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none".
        standardization_norm (`str`, *optional*, defaults to `"layer_norm"`):
            The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`.
        final_dropout (`bool` *optional*, defaults to False):
            Whether to apply a final dropout after the last feed-forward layer.
        attention_type (`str`, *optional*, defaults to `"default"`):
            The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
        positional_embeddings (`str`, *optional*, defaults to `None`):
            The type of positional embeddings to apply to.
        num_positional_embeddings (`int`, *optional*, defaults to `None`):
            The maximum number of positional embeddings to apply.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        dropout=0.0,
        cross_attention_dim: Optional[int] = None,
        activation_fn: str = "geglu",
        num_embeds_ada_norm: Optional[int] = None,  # pylint: disable=unused-argument
        attention_bias: bool = False,
        only_cross_attention: bool = False,
        double_self_attention: bool = False,
        upcast_attention: bool = False,
        norm_elementwise_affine: bool = True,
        adaptive_norm: str = "single_scale_shift",  # 'single_scale_shift', 'single_scale' or 'none'
        standardization_norm: str = "layer_norm",  # 'layer_norm' or 'rms_norm'
        norm_eps: float = 1e-5,
        qk_norm: Optional[str] = None,
        final_dropout: bool = False,
        attention_type: str = "default",  # pylint: disable=unused-argument
        ff_inner_dim: Optional[int] = None,
        ff_bias: bool = True,
        attention_out_bias: bool = True,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.only_cross_attention = only_cross_attention
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.adaptive_norm = adaptive_norm

        assert standardization_norm in ["layer_norm", "rms_norm"]
        assert adaptive_norm in ["single_scale_shift", "single_scale", "none"]

        make_norm_layer = (
            nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm
        )

        # Define 3 blocks. Each block has its own normalization layer.
        # 1. Self-Attn
        self.norm1 = make_norm_layer(
            dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
        )

        self.attn1 = Attention(
            query_dim=dim,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
            upcast_attention=upcast_attention,
            out_bias=attention_out_bias,
            use_tpu_flash_attention=use_tpu_flash_attention,
            qk_norm=qk_norm,
            use_rope=use_rope,
        )

        # 2. Cross-Attn
        if cross_attention_dim is not None or double_self_attention:
            self.attn2 = Attention(
                query_dim=dim,
                cross_attention_dim=(
                    cross_attention_dim if not double_self_attention else None
                ),
                heads=num_attention_heads,
                dim_head=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
                upcast_attention=upcast_attention,
                out_bias=attention_out_bias,
                use_tpu_flash_attention=use_tpu_flash_attention,
                qk_norm=qk_norm,
                use_rope=use_rope,
            )  # is self-attn if encoder_hidden_states is none

            if adaptive_norm == "none":
                self.attn2_norm = make_norm_layer(
                    dim, norm_eps, norm_elementwise_affine
                )
        else:
            self.attn2 = None
            self.attn2_norm = None

        self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine)

        # 3. Feed-forward
        self.ff = FeedForward(
            dim,
            dropout=dropout,
            activation_fn=activation_fn,
            final_dropout=final_dropout,
            inner_dim=ff_inner_dim,
            bias=ff_bias,
        )

        # 5. Scale-shift for PixArt-Alpha.
        if adaptive_norm != "none":
            num_ada_params = 4 if adaptive_norm == "single_scale" else 6
            self.scale_shift_table = nn.Parameter(
                torch.randn(num_ada_params, dim) / dim**0.5
            )

        # let chunk size default to None
        self._chunk_size = None
        self._chunk_dim = 0

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
        attention kernel.
        """
        self.use_tpu_flash_attention = True
        self.attn1.set_use_tpu_flash_attention()
        self.attn2.set_use_tpu_flash_attention()

    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        # Sets chunk feed-forward
        self._chunk_size = chunk_size
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        timestep: Optional[torch.LongTensor] = None,
        cross_attention_kwargs: Dict[str, Any] = None,
        class_labels: Optional[torch.LongTensor] = None,
        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
    ) -> torch.FloatTensor:
        if cross_attention_kwargs is not None:
            if cross_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored."
                )

        # Notice that normalization is always applied before the real computation in the following blocks.
        # 0. Self-Attention
        batch_size = hidden_states.shape[0]

        original_hidden_states = hidden_states

        norm_hidden_states = self.norm1(hidden_states)

        # Apply ada_norm_single
        if self.adaptive_norm in ["single_scale_shift", "single_scale"]:
            assert timestep.ndim == 3  # [batch, 1 or num_tokens, embedding_dim]
            num_ada_params = self.scale_shift_table.shape[0]
            ada_values = self.scale_shift_table[None, None] + timestep.reshape(
                batch_size, timestep.shape[1], num_ada_params, -1
            )
            if self.adaptive_norm == "single_scale_shift":
                shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                    ada_values.unbind(dim=2)
                )
                norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
            else:
                scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
                norm_hidden_states = norm_hidden_states * (1 + scale_msa)
        elif self.adaptive_norm == "none":
            scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        norm_hidden_states = norm_hidden_states.squeeze(
            1
        )  # TODO: Check if this is needed

        # 1. Prepare GLIGEN inputs
        cross_attention_kwargs = (
            cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
        )

        attn_output = self.attn1(
            norm_hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=(
                encoder_hidden_states if self.only_cross_attention else None
            ),
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )
        if gate_msa is not None:
            attn_output = gate_msa * attn_output

        hidden_states = attn_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        # 3. Cross-Attention
        if self.attn2 is not None:
            if self.adaptive_norm == "none":
                attn_input = self.attn2_norm(hidden_states)
            else:
                attn_input = hidden_states
            attn_output = self.attn2(
                attn_input,
                freqs_cis=freqs_cis,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                **cross_attention_kwargs,
            )
            hidden_states = attn_output + hidden_states

        # 4. Feed-forward
        norm_hidden_states = self.norm2(hidden_states)
        if self.adaptive_norm == "single_scale_shift":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
        elif self.adaptive_norm == "single_scale":
            norm_hidden_states = norm_hidden_states * (1 + scale_mlp)
        elif self.adaptive_norm == "none":
            pass
        else:
            raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}")

        if self._chunk_size is not None:
            # "feed_forward_chunk_size" can be used to save memory
            ff_output = _chunked_feed_forward(
                self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size
            )
        else:
            ff_output = self.ff(norm_hidden_states)
        if gate_mlp is not None:
            ff_output = gate_mlp * ff_output

        hidden_states = ff_output + hidden_states
        if hidden_states.ndim == 4:
            hidden_states = hidden_states.squeeze(1)

        if (
            skip_layer_mask is not None
            and skip_layer_strategy == SkipLayerStrategy.TransformerBlock
        ):
            skip_layer_mask = skip_layer_mask.view(-1, 1, 1)
            hidden_states = hidden_states * skip_layer_mask + original_hidden_states * (
                1.0 - skip_layer_mask
            )

        return hidden_states


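# A minimal usage sketch (illustrative only; `_example_basic_transformer_block` is
# an added helper and assumes the rest of this file, including the default
# AttnProcessor2_0 defined further below, is available). With the default
# adaptive_norm="single_scale_shift", the `timestep` embedding must have shape
# [batch, 1, 6 * dim] so it can be split into the six scale/shift/gate terms.
def _example_basic_transformer_block():
    dim, heads, head_dim = 64, 4, 16
    block = BasicTransformerBlock(
        dim=dim, num_attention_heads=heads, attention_head_dim=head_dim
    )
    hidden_states = torch.randn(2, 128, dim)  # [batch, tokens, dim]
    timestep = torch.randn(2, 1, 6 * dim)  # adaptive-norm conditioning
    return block(hidden_states, timestep=timestep).shape  # torch.Size([2, 128, 64])

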
@maybe_allow_in_graph
class Attention(nn.Module):
    r"""
    A cross attention layer.

    Parameters:
        query_dim (`int`):
            The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`, *optional*, defaults to 8):
            The number of heads to use for multi-head attention.
        dim_head (`int`, *optional*, defaults to 64):
            The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the attention computation to `float32`.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Set to `True` to upcast the softmax computation to `float32`.
        cross_attention_norm (`str`, *optional*, defaults to `None`):
            The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
        cross_attention_norm_num_groups (`int`, *optional*, defaults to 32):
            The number of groups to use for the group norm in the cross attention.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        norm_num_groups (`int`, *optional*, defaults to `None`):
            The number of groups to use for the group norm in the attention.
        spatial_norm_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the spatial normalization.
        out_bias (`bool`, *optional*, defaults to `True`):
            Set to `True` to use a bias in the output linear layer.
        scale_qk (`bool`, *optional*, defaults to `True`):
            Set to `True` to scale the query and key by `1 / sqrt(dim_head)`.
        qk_norm (`str`, *optional*, defaults to None):
            Set to 'layer_norm' or `rms_norm` to perform query and key normalization.
        only_cross_attention (`bool`, *optional*, defaults to `False`):
            Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if
            `added_kv_proj_dim` is not `None`.
        eps (`float`, *optional*, defaults to 1e-5):
            An additional value added to the denominator in group normalization that is used for numerical stability.
        rescale_output_factor (`float`, *optional*, defaults to 1.0):
            A factor to rescale the output by dividing it with this value.
        residual_connection (`bool`, *optional*, defaults to `False`):
            Set to `True` to add the residual connection to the output.
        _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`):
            Set to `True` if the attention block is loaded from a deprecated state dict.
        processor (`AttnProcessor`, *optional*, defaults to `None`):
            The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
            `AttnProcessor` otherwise.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias: bool = False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        cross_attention_norm: Optional[str] = None,
        cross_attention_norm_num_groups: int = 32,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        spatial_norm_dim: Optional[int] = None,
        out_bias: bool = True,
        scale_qk: bool = True,
        qk_norm: Optional[str] = None,
        only_cross_attention: bool = False,
        eps: float = 1e-5,
        rescale_output_factor: float = 1.0,
        residual_connection: bool = False,
        _from_deprecated_attn_block: bool = False,
        processor: Optional["AttnProcessor"] = None,
        out_dim: int = None,
        use_tpu_flash_attention: bool = False,
        use_rope: bool = False,
    ):
        super().__init__()
        self.inner_dim = out_dim if out_dim is not None else dim_head * heads
        self.query_dim = query_dim
        self.use_bias = bias
        self.is_cross_attention = cross_attention_dim is not None
        self.cross_attention_dim = (
            cross_attention_dim if cross_attention_dim is not None else query_dim
        )
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax
        self.rescale_output_factor = rescale_output_factor
        self.residual_connection = residual_connection
        self.dropout = dropout
        self.fused_projections = False
        self.out_dim = out_dim if out_dim is not None else query_dim
        self.use_tpu_flash_attention = use_tpu_flash_attention
        self.use_rope = use_rope

        # we make use of this private variable to know whether this class is loaded
        # with a deprecated state dict so that we can convert it on the fly
        self._from_deprecated_attn_block = _from_deprecated_attn_block

        self.scale_qk = scale_qk
        self.scale = dim_head**-0.5 if self.scale_qk else 1.0

        if qk_norm is None:
            self.q_norm = nn.Identity()
            self.k_norm = nn.Identity()
        elif qk_norm == "rms_norm":
            self.q_norm = RMSNorm(dim_head * heads, eps=1e-5)
            self.k_norm = RMSNorm(dim_head * heads, eps=1e-5)
        elif qk_norm == "layer_norm":
            self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
            self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5)
        else:
            raise ValueError(f"Unsupported qk_norm method: {qk_norm}")

        self.heads = out_dim // dim_head if out_dim is not None else heads
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim
        self.only_cross_attention = only_cross_attention

        if self.added_kv_proj_dim is None and self.only_cross_attention:
            raise ValueError(
                "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`."
            )

        if norm_num_groups is not None:
            self.group_norm = nn.GroupNorm(
                num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True
            )
        else:
            self.group_norm = None

        if spatial_norm_dim is not None:
            self.spatial_norm = SpatialNorm(
                f_channels=query_dim, zq_channels=spatial_norm_dim
            )
        else:
            self.spatial_norm = None

        if cross_attention_norm is None:
            self.norm_cross = None
        elif cross_attention_norm == "layer_norm":
            self.norm_cross = nn.LayerNorm(self.cross_attention_dim)
        elif cross_attention_norm == "group_norm":
            if self.added_kv_proj_dim is not None:
                # The given `encoder_hidden_states` are initially of shape
                # (batch_size, seq_len, added_kv_proj_dim) before being projected
                # to (batch_size, seq_len, cross_attention_dim). The norm is applied
                # before the projection, so we need to use `added_kv_proj_dim` as
                # the number of channels for the group norm.
                norm_cross_num_channels = added_kv_proj_dim
            else:
                norm_cross_num_channels = self.cross_attention_dim

            self.norm_cross = nn.GroupNorm(
                num_channels=norm_cross_num_channels,
                num_groups=cross_attention_norm_num_groups,
                eps=1e-5,
                affine=True,
            )
        else:
            raise ValueError(
                f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'"
            )

        linear_cls = nn.Linear

        self.linear_cls = linear_cls
        self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias)

        if not self.only_cross_attention:
            # only relevant for the `AddedKVProcessor` classes
            self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
            self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias)
        else:
            self.to_k = None
            self.to_v = None

        if self.added_kv_proj_dim is not None:
            self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim)
            self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim)

        self.to_out = nn.ModuleList([])
        self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias))
        self.to_out.append(nn.Dropout(dropout))

        # set attention processor
        # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
        # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
        # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
        if processor is None:
            processor = AttnProcessor2_0()
        self.set_processor(processor)

    def set_use_tpu_flash_attention(self):
        r"""
        Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel.
        """
        self.use_tpu_flash_attention = True

    def set_processor(self, processor: "AttnProcessor") -> None:
        r"""
        Set the attention processor to use.

        Args:
            processor (`AttnProcessor`):
                The attention processor to use.
        """
        # if current processor is in `self._modules` and if passed `processor` is not, we need to
        # pop `processor` from `self._modules`
        if (
            hasattr(self, "processor")
            and isinstance(self.processor, torch.nn.Module)
            and not isinstance(processor, torch.nn.Module)
        ):
            logger.info(
                f"You are removing possibly trained weights of {self.processor} with {processor}"
            )
            self._modules.pop("processor")

        self.processor = processor

    def get_processor(
        self, return_deprecated_lora: bool = False
    ) -> "AttentionProcessor":  # noqa: F821
        r"""
        Get the attention processor in use.

        Args:
            return_deprecated_lora (`bool`, *optional*, defaults to `False`):
                Set to `True` to return the deprecated LoRA attention processor.

        Returns:
            "AttentionProcessor": The attention processor in use.
        """
        if not return_deprecated_lora:
            return self.processor

        # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
        # serialization format for LoRA Attention Processors. It should be deleted once the integration
        # with PEFT is completed.
        is_lora_activated = {
            name: module.lora_layer is not None
            for name, module in self.named_modules()
            if hasattr(module, "lora_layer")
        }

        # 1. if no layer has a LoRA activated we can return the processor as usual
        if not any(is_lora_activated.values()):
            return self.processor

        # If doesn't apply LoRA do `add_k_proj` or `add_v_proj`
        is_lora_activated.pop("add_k_proj", None)
        is_lora_activated.pop("add_v_proj", None)
        # 2. else it is not possible that only some layers have LoRA activated
        if not all(is_lora_activated.values()):
            raise ValueError(
                f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
            )

        # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
        non_lora_processor_cls_name = self.processor.__class__.__name__
        lora_processor_cls = getattr(
            import_module(__name__), "LoRA" + non_lora_processor_cls_name
        )

        hidden_size = self.inner_dim

        # now create a LoRA attention processor from the LoRA layers
        if lora_processor_cls in [
            LoRAAttnProcessor,
            LoRAAttnProcessor2_0,
            LoRAXFormersAttnProcessor,
        ]:
            kwargs = {
                "cross_attention_dim": self.cross_attention_dim,
                "rank": self.to_q.lora_layer.rank,
                "network_alpha": self.to_q.lora_layer.network_alpha,
                "q_rank": self.to_q.lora_layer.rank,
                "q_hidden_size": self.to_q.lora_layer.out_features,
                "k_rank": self.to_k.lora_layer.rank,
                "k_hidden_size": self.to_k.lora_layer.out_features,
                "v_rank": self.to_v.lora_layer.rank,
                "v_hidden_size": self.to_v.lora_layer.out_features,
                "out_rank": self.to_out[0].lora_layer.rank,
                "out_hidden_size": self.to_out[0].lora_layer.out_features,
            }

            if hasattr(self.processor, "attention_op"):
                kwargs["attention_op"] = self.processor.attention_op

            lora_processor = lora_processor_cls(hidden_size, **kwargs)
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )
        elif lora_processor_cls == LoRAAttnAddedKVProcessor:
            lora_processor = lora_processor_cls(
                hidden_size,
                cross_attention_dim=self.add_k_proj.weight.shape[0],
                rank=self.to_q.lora_layer.rank,
                network_alpha=self.to_q.lora_layer.network_alpha,
            )
            lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
            lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
            lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
            lora_processor.to_out_lora.load_state_dict(
                self.to_out[0].lora_layer.state_dict()
            )

            # only save if used
            if self.add_k_proj.lora_layer is not None:
                lora_processor.add_k_proj_lora.load_state_dict(
                    self.add_k_proj.lora_layer.state_dict()
                )
                lora_processor.add_v_proj_lora.load_state_dict(
                    self.add_v_proj.lora_layer.state_dict()
                )
            else:
                lora_processor.add_k_proj_lora = None
                lora_processor.add_v_proj_lora = None
        else:
            raise ValueError(f"{lora_processor_cls} does not exist.")

        return lora_processor

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        skip_layer_mask: Optional[torch.Tensor] = None,
        skip_layer_strategy: Optional[SkipLayerStrategy] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.

        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            skip_layer_mask (`torch.Tensor`, *optional*):
                The skip layer mask to use. If `None`, no mask is applied.
            skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`):
                Controls which layers to skip for spatiotemporal guidance.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.

        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty

        attn_parameters = set(
            inspect.signature(self.processor.__call__).parameters.keys()
        )
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by"
                f" {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {
            k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters
        }

        return self.processor(
            self,
            hidden_states,
            freqs_cis=freqs_cis,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            skip_layer_mask=skip_layer_mask,
            skip_layer_strategy=skip_layer_strategy,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(
            batch_size // head_size, seq_len, dim * head_size
        )
        return tensor

    def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor:
        r"""
        Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]`. `heads`
        is the number of heads initialized while constructing the `Attention` class.

        Args:
            tensor (`torch.Tensor`): The tensor to reshape.
            out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is
                reshaped to `[batch_size * heads, seq_len, dim // heads]`.

        Returns:
            `torch.Tensor`: The reshaped tensor.
        """

        head_size = self.heads
        if tensor.ndim == 3:
            batch_size, seq_len, dim = tensor.shape
            extra_dim = 1
        else:
            batch_size, extra_dim, seq_len, dim = tensor.shape
        tensor = tensor.reshape(
            batch_size, seq_len * extra_dim, head_size, dim // head_size
        )
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(
                batch_size * head_size, seq_len * extra_dim, dim // head_size
            )

        return tensor
|
| 772 |
+
def get_attention_scores(
|
| 773 |
+
self,
|
| 774 |
+
query: torch.Tensor,
|
| 775 |
+
key: torch.Tensor,
|
| 776 |
+
attention_mask: torch.Tensor = None,
|
| 777 |
+
) -> torch.Tensor:
|
| 778 |
+
r"""
|
| 779 |
+
Compute the attention scores.
|
| 780 |
+
|
| 781 |
+
Args:
|
| 782 |
+
query (`torch.Tensor`): The query tensor.
|
| 783 |
+
key (`torch.Tensor`): The key tensor.
|
| 784 |
+
attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied.
|
| 785 |
+
|
| 786 |
+
Returns:
|
| 787 |
+
`torch.Tensor`: The attention probabilities/scores.
|
| 788 |
+
"""
|
| 789 |
+
dtype = query.dtype
|
| 790 |
+
if self.upcast_attention:
|
| 791 |
+
query = query.float()
|
| 792 |
+
key = key.float()
|
| 793 |
+
|
| 794 |
+
if attention_mask is None:
|
| 795 |
+
baddbmm_input = torch.empty(
|
| 796 |
+
query.shape[0],
|
| 797 |
+
query.shape[1],
|
| 798 |
+
key.shape[1],
|
| 799 |
+
dtype=query.dtype,
|
| 800 |
+
device=query.device,
|
| 801 |
+
)
|
| 802 |
+
beta = 0
|
| 803 |
+
else:
|
| 804 |
+
baddbmm_input = attention_mask
|
| 805 |
+
beta = 1
|
| 806 |
+
|
| 807 |
+
attention_scores = torch.baddbmm(
|
| 808 |
+
baddbmm_input,
|
| 809 |
+
query,
|
| 810 |
+
key.transpose(-1, -2),
|
| 811 |
+
beta=beta,
|
| 812 |
+
alpha=self.scale,
|
| 813 |
+
)
|
| 814 |
+
del baddbmm_input
|
| 815 |
+
|
| 816 |
+
if self.upcast_softmax:
|
| 817 |
+
attention_scores = attention_scores.float()
|
| 818 |
+
|
| 819 |
+
attention_probs = attention_scores.softmax(dim=-1)
|
| 820 |
+
del attention_scores
|
| 821 |
+
|
| 822 |
+
attention_probs = attention_probs.to(dtype)
|
| 823 |
+
|
| 824 |
+
return attention_probs
|
| 825 |
+
|
| 826 |
+
def prepare_attention_mask(
|
| 827 |
+
self,
|
| 828 |
+
attention_mask: torch.Tensor,
|
| 829 |
+
target_length: int,
|
| 830 |
+
batch_size: int,
|
| 831 |
+
out_dim: int = 3,
|
| 832 |
+
) -> torch.Tensor:
|
| 833 |
+
r"""
|
| 834 |
+
Prepare the attention mask for the attention computation.
|
| 835 |
+
|
| 836 |
+
Args:
|
| 837 |
+
attention_mask (`torch.Tensor`):
|
| 838 |
+
The attention mask to prepare.
|
| 839 |
+
target_length (`int`):
|
| 840 |
+
The target length of the attention mask. This is the length of the attention mask after padding.
|
| 841 |
+
batch_size (`int`):
|
| 842 |
+
The batch size, which is used to repeat the attention mask.
|
| 843 |
+
out_dim (`int`, *optional*, defaults to `3`):
|
| 844 |
+
The output dimension of the attention mask. Can be either `3` or `4`.
|
| 845 |
+
|
| 846 |
+
Returns:
|
| 847 |
+
`torch.Tensor`: The prepared attention mask.
|
| 848 |
+
"""
|
| 849 |
+
head_size = self.heads
|
| 850 |
+
if attention_mask is None:
|
| 851 |
+
return attention_mask
|
| 852 |
+
|
| 853 |
+
current_length: int = attention_mask.shape[-1]
|
| 854 |
+
if current_length != target_length:
|
| 855 |
+
if attention_mask.device.type == "mps":
|
| 856 |
+
# HACK: MPS: Does not support padding by greater than dimension of input tensor.
|
| 857 |
+
# Instead, we can manually construct the padding tensor.
|
| 858 |
+
padding_shape = (
|
| 859 |
+
attention_mask.shape[0],
|
| 860 |
+
attention_mask.shape[1],
|
| 861 |
+
target_length,
|
| 862 |
+
)
|
| 863 |
+
padding = torch.zeros(
|
| 864 |
+
padding_shape,
|
| 865 |
+
dtype=attention_mask.dtype,
|
| 866 |
+
device=attention_mask.device,
|
| 867 |
+
)
|
| 868 |
+
attention_mask = torch.cat([attention_mask, padding], dim=2)
|
| 869 |
+
else:
|
| 870 |
+
# TODO: for pipelines such as stable-diffusion, padding cross-attn mask:
|
| 871 |
+
# we want to instead pad by (0, remaining_length), where remaining_length is:
|
| 872 |
+
# remaining_length: int = target_length - current_length
|
| 873 |
+
# TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding
|
| 874 |
+
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
| 875 |
+
|
| 876 |
+
if out_dim == 3:
|
| 877 |
+
if attention_mask.shape[0] < batch_size * head_size:
|
| 878 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
|
| 879 |
+
elif out_dim == 4:
|
| 880 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 881 |
+
attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
|
| 882 |
+
|
| 883 |
+
return attention_mask
|
| 884 |
+
|
| 885 |
+
def norm_encoder_hidden_states(
|
| 886 |
+
self, encoder_hidden_states: torch.Tensor
|
| 887 |
+
) -> torch.Tensor:
|
| 888 |
+
r"""
|
| 889 |
+
Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the
|
| 890 |
+
`Attention` class.
|
| 891 |
+
|
| 892 |
+
Args:
|
| 893 |
+
encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder.
|
| 894 |
+
|
| 895 |
+
Returns:
|
| 896 |
+
`torch.Tensor`: The normalized encoder hidden states.
|
| 897 |
+
"""
|
| 898 |
+
assert (
|
| 899 |
+
self.norm_cross is not None
|
| 900 |
+
), "self.norm_cross must be defined to call self.norm_encoder_hidden_states"
|
| 901 |
+
|
| 902 |
+
if isinstance(self.norm_cross, nn.LayerNorm):
|
| 903 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 904 |
+
elif isinstance(self.norm_cross, nn.GroupNorm):
|
| 905 |
+
# Group norm norms along the channels dimension and expects
|
| 906 |
+
# input to be in the shape of (N, C, *). In this case, we want
|
| 907 |
+
# to norm along the hidden dimension, so we need to move
|
| 908 |
+
# (batch_size, sequence_length, hidden_size) ->
|
| 909 |
+
# (batch_size, hidden_size, sequence_length)
|
| 910 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 911 |
+
encoder_hidden_states = self.norm_cross(encoder_hidden_states)
|
| 912 |
+
encoder_hidden_states = encoder_hidden_states.transpose(1, 2)
|
| 913 |
+
else:
|
| 914 |
+
assert False
|
| 915 |
+
|
| 916 |
+
return encoder_hidden_states
|
| 917 |
+
|
| 918 |
+
@staticmethod
|
| 919 |
+
def apply_rotary_emb(
|
| 920 |
+
input_tensor: torch.Tensor,
|
| 921 |
+
freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
|
| 922 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 923 |
+
cos_freqs = freqs_cis[0]
|
| 924 |
+
sin_freqs = freqs_cis[1]
|
| 925 |
+
|
| 926 |
+
t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
|
| 927 |
+
t1, t2 = t_dup.unbind(dim=-1)
|
| 928 |
+
t_dup = torch.stack((-t2, t1), dim=-1)
|
| 929 |
+
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")
|
| 930 |
+
|
| 931 |
+
out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs
|
| 932 |
+
|
| 933 |
+
return out
|
| 934 |
+
|
| 935 |
+
|
| 936 |
+
class AttnProcessor2_0:
|
| 937 |
+
r"""
|
| 938 |
+
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
| 939 |
+
"""
|
| 940 |
+
|
| 941 |
+
def __init__(self):
|
| 942 |
+
pass
|
| 943 |
+
|
| 944 |
+
def __call__(
|
| 945 |
+
self,
|
| 946 |
+
attn: Attention,
|
| 947 |
+
hidden_states: torch.FloatTensor,
|
| 948 |
+
freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor],
|
| 949 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 950 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 951 |
+
temb: Optional[torch.FloatTensor] = None,
|
| 952 |
+
skip_layer_mask: Optional[torch.FloatTensor] = None,
|
| 953 |
+
skip_layer_strategy: Optional[SkipLayerStrategy] = None,
|
| 954 |
+
*args,
|
| 955 |
+
**kwargs,
|
| 956 |
+
) -> torch.FloatTensor:
|
| 957 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
| 958 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
| 959 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
| 960 |
+
|
| 961 |
+
residual = hidden_states
|
| 962 |
+
if attn.spatial_norm is not None:
|
| 963 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
| 964 |
+
|
| 965 |
+
input_ndim = hidden_states.ndim
|
| 966 |
+
|
| 967 |
+
if input_ndim == 4:
|
| 968 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 969 |
+
hidden_states = hidden_states.view(
|
| 970 |
+
batch_size, channel, height * width
|
| 971 |
+
).transpose(1, 2)
|
| 972 |
+
|
| 973 |
+
batch_size, sequence_length, _ = (
|
| 974 |
+
hidden_states.shape
|
| 975 |
+
if encoder_hidden_states is None
|
| 976 |
+
else encoder_hidden_states.shape
|
| 977 |
+
)
|
| 978 |
+
|
| 979 |
+
if skip_layer_mask is not None:
|
| 980 |
+
skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1)
|
| 981 |
+
|
| 982 |
+
if (attention_mask is not None) and (not attn.use_tpu_flash_attention):
|
| 983 |
+
attention_mask = attn.prepare_attention_mask(
|
| 984 |
+
attention_mask, sequence_length, batch_size
|
| 985 |
+
)
|
| 986 |
+
# scaled_dot_product_attention expects attention_mask shape to be
|
| 987 |
+
# (batch, heads, source_length, target_length)
|
| 988 |
+
attention_mask = attention_mask.view(
|
| 989 |
+
batch_size, attn.heads, -1, attention_mask.shape[-1]
|
| 990 |
+
)
|
| 991 |
+
|
| 992 |
+
if attn.group_norm is not None:
|
| 993 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 994 |
+
1, 2
|
| 995 |
+
)
|
| 996 |
+
|
| 997 |
+
query = attn.to_q(hidden_states)
|
| 998 |
+
query = attn.q_norm(query)
|
| 999 |
+
|
| 1000 |
+
if encoder_hidden_states is not None:
|
| 1001 |
+
if attn.norm_cross:
|
| 1002 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 1003 |
+
encoder_hidden_states
|
| 1004 |
+
)
|
| 1005 |
+
key = attn.to_k(encoder_hidden_states)
|
| 1006 |
+
key = attn.k_norm(key)
|
| 1007 |
+
else: # if no context provided do self-attention
|
| 1008 |
+
encoder_hidden_states = hidden_states
|
| 1009 |
+
key = attn.to_k(hidden_states)
|
| 1010 |
+
key = attn.k_norm(key)
|
| 1011 |
+
if attn.use_rope:
|
| 1012 |
+
key = attn.apply_rotary_emb(key, freqs_cis)
|
| 1013 |
+
query = attn.apply_rotary_emb(query, freqs_cis)
|
| 1014 |
+
|
| 1015 |
+
value = attn.to_v(encoder_hidden_states)
|
| 1016 |
+
value_for_stg = value
|
| 1017 |
+
|
| 1018 |
+
inner_dim = key.shape[-1]
|
| 1019 |
+
head_dim = inner_dim // attn.heads
|
| 1020 |
+
|
| 1021 |
+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 1022 |
+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 1023 |
+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
| 1024 |
+
|
| 1025 |
+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
| 1026 |
+
|
| 1027 |
+
if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention'
|
| 1028 |
+
q_segment_indexes = None
|
| 1029 |
+
if (
|
| 1030 |
+
attention_mask is not None
|
| 1031 |
+
): # if a mask is required, both segment_ids fields need to be set
|
| 1032 |
+
# attention_mask = torch.squeeze(attention_mask).to(torch.float32)
|
| 1033 |
+
attention_mask = attention_mask.to(torch.float32)
|
| 1034 |
+
q_segment_indexes = torch.ones(
|
| 1035 |
+
batch_size, query.shape[2], device=query.device, dtype=torch.float32
|
| 1036 |
+
)
|
| 1037 |
+
assert (
|
| 1038 |
+
attention_mask.shape[1] == key.shape[2]
|
| 1039 |
+
), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]"
|
| 1040 |
+
|
| 1041 |
+
assert (
|
| 1042 |
+
query.shape[2] % 128 == 0
|
| 1043 |
+
), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]"
|
| 1044 |
+
assert (
|
| 1045 |
+
key.shape[2] % 128 == 0
|
| 1046 |
+
), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]"
|
| 1047 |
+
|
| 1048 |
+
# run the TPU kernel implemented in jax with pallas
|
| 1049 |
+
hidden_states_a = flash_attention(
|
| 1050 |
+
q=query,
|
| 1051 |
+
k=key,
|
| 1052 |
+
v=value,
|
| 1053 |
+
q_segment_ids=q_segment_indexes,
|
| 1054 |
+
kv_segment_ids=attention_mask,
|
| 1055 |
+
sm_scale=attn.scale,
|
| 1056 |
+
)
|
| 1057 |
+
else:
|
| 1058 |
+
hidden_states_a = F.scaled_dot_product_attention(
|
| 1059 |
+
query,
|
| 1060 |
+
key,
|
| 1061 |
+
value,
|
| 1062 |
+
attn_mask=attention_mask,
|
| 1063 |
+
dropout_p=0.0,
|
| 1064 |
+
is_causal=False,
|
| 1065 |
+
)
|
| 1066 |
+
|
| 1067 |
+
hidden_states_a = hidden_states_a.transpose(1, 2).reshape(
|
| 1068 |
+
batch_size, -1, attn.heads * head_dim
|
| 1069 |
+
)
|
| 1070 |
+
hidden_states_a = hidden_states_a.to(query.dtype)
|
| 1071 |
+
|
| 1072 |
+
if (
|
| 1073 |
+
skip_layer_mask is not None
|
| 1074 |
+
and skip_layer_strategy == SkipLayerStrategy.AttentionSkip
|
| 1075 |
+
):
|
| 1076 |
+
hidden_states = hidden_states_a * skip_layer_mask + hidden_states * (
|
| 1077 |
+
1.0 - skip_layer_mask
|
| 1078 |
+
)
|
| 1079 |
+
elif (
|
| 1080 |
+
skip_layer_mask is not None
|
| 1081 |
+
and skip_layer_strategy == SkipLayerStrategy.AttentionValues
|
| 1082 |
+
):
|
| 1083 |
+
hidden_states = hidden_states_a * skip_layer_mask + value_for_stg * (
|
| 1084 |
+
1.0 - skip_layer_mask
|
| 1085 |
+
)
|
| 1086 |
+
else:
|
| 1087 |
+
hidden_states = hidden_states_a
|
| 1088 |
+
|
| 1089 |
+
# linear proj
|
| 1090 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 1091 |
+
# dropout
|
| 1092 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 1093 |
+
|
| 1094 |
+
if input_ndim == 4:
|
| 1095 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 1096 |
+
batch_size, channel, height, width
|
| 1097 |
+
)
|
| 1098 |
+
if (
|
| 1099 |
+
skip_layer_mask is not None
|
| 1100 |
+
and skip_layer_strategy == SkipLayerStrategy.Residual
|
| 1101 |
+
):
|
| 1102 |
+
skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1)
|
| 1103 |
+
|
| 1104 |
+
if attn.residual_connection:
|
| 1105 |
+
if (
|
| 1106 |
+
skip_layer_mask is not None
|
| 1107 |
+
and skip_layer_strategy == SkipLayerStrategy.Residual
|
| 1108 |
+
):
|
| 1109 |
+
hidden_states = hidden_states + residual * skip_layer_mask
|
| 1110 |
+
else:
|
| 1111 |
+
hidden_states = hidden_states + residual
|
| 1112 |
+
|
| 1113 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 1114 |
+
|
| 1115 |
+
return hidden_states
|
| 1116 |
+
|
| 1117 |
+
|
| 1118 |
+
class AttnProcessor:
|
| 1119 |
+
r"""
|
| 1120 |
+
Default processor for performing attention-related computations.
|
| 1121 |
+
"""
|
| 1122 |
+
|
| 1123 |
+
def __call__(
|
| 1124 |
+
self,
|
| 1125 |
+
attn: Attention,
|
| 1126 |
+
hidden_states: torch.FloatTensor,
|
| 1127 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
| 1128 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
| 1129 |
+
temb: Optional[torch.FloatTensor] = None,
|
| 1130 |
+
*args,
|
| 1131 |
+
**kwargs,
|
| 1132 |
+
) -> torch.Tensor:
|
| 1133 |
+
if len(args) > 0 or kwargs.get("scale", None) is not None:
|
| 1134 |
+
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
|
| 1135 |
+
deprecate("scale", "1.0.0", deprecation_message)
|
| 1136 |
+
|
| 1137 |
+
residual = hidden_states
|
| 1138 |
+
|
| 1139 |
+
if attn.spatial_norm is not None:
|
| 1140 |
+
hidden_states = attn.spatial_norm(hidden_states, temb)
|
| 1141 |
+
|
| 1142 |
+
input_ndim = hidden_states.ndim
|
| 1143 |
+
|
| 1144 |
+
if input_ndim == 4:
|
| 1145 |
+
batch_size, channel, height, width = hidden_states.shape
|
| 1146 |
+
hidden_states = hidden_states.view(
|
| 1147 |
+
batch_size, channel, height * width
|
| 1148 |
+
).transpose(1, 2)
|
| 1149 |
+
|
| 1150 |
+
batch_size, sequence_length, _ = (
|
| 1151 |
+
hidden_states.shape
|
| 1152 |
+
if encoder_hidden_states is None
|
| 1153 |
+
else encoder_hidden_states.shape
|
| 1154 |
+
)
|
| 1155 |
+
attention_mask = attn.prepare_attention_mask(
|
| 1156 |
+
attention_mask, sequence_length, batch_size
|
| 1157 |
+
)
|
| 1158 |
+
|
| 1159 |
+
if attn.group_norm is not None:
|
| 1160 |
+
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
|
| 1161 |
+
1, 2
|
| 1162 |
+
)
|
| 1163 |
+
|
| 1164 |
+
query = attn.to_q(hidden_states)
|
| 1165 |
+
|
| 1166 |
+
if encoder_hidden_states is None:
|
| 1167 |
+
encoder_hidden_states = hidden_states
|
| 1168 |
+
elif attn.norm_cross:
|
| 1169 |
+
encoder_hidden_states = attn.norm_encoder_hidden_states(
|
| 1170 |
+
encoder_hidden_states
|
| 1171 |
+
)
|
| 1172 |
+
|
| 1173 |
+
key = attn.to_k(encoder_hidden_states)
|
| 1174 |
+
value = attn.to_v(encoder_hidden_states)
|
| 1175 |
+
|
| 1176 |
+
query = attn.head_to_batch_dim(query)
|
| 1177 |
+
key = attn.head_to_batch_dim(key)
|
| 1178 |
+
value = attn.head_to_batch_dim(value)
|
| 1179 |
+
|
| 1180 |
+
query = attn.q_norm(query)
|
| 1181 |
+
key = attn.k_norm(key)
|
| 1182 |
+
|
| 1183 |
+
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
| 1184 |
+
hidden_states = torch.bmm(attention_probs, value)
|
| 1185 |
+
hidden_states = attn.batch_to_head_dim(hidden_states)
|
| 1186 |
+
|
| 1187 |
+
# linear proj
|
| 1188 |
+
hidden_states = attn.to_out[0](hidden_states)
|
| 1189 |
+
# dropout
|
| 1190 |
+
hidden_states = attn.to_out[1](hidden_states)
|
| 1191 |
+
|
| 1192 |
+
if input_ndim == 4:
|
| 1193 |
+
hidden_states = hidden_states.transpose(-1, -2).reshape(
|
| 1194 |
+
batch_size, channel, height, width
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
if attn.residual_connection:
|
| 1198 |
+
hidden_states = hidden_states + residual
|
| 1199 |
+
|
| 1200 |
+
hidden_states = hidden_states / attn.rescale_output_factor
|
| 1201 |
+
|
| 1202 |
+
return hidden_states
|
| 1203 |
+
|
| 1204 |
+
|
| 1205 |
+
class FeedForward(nn.Module):
|
| 1206 |
+
r"""
|
| 1207 |
+
A feed-forward layer.
|
| 1208 |
+
|
| 1209 |
+
Parameters:
|
| 1210 |
+
dim (`int`): The number of channels in the input.
|
| 1211 |
+
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
| 1212 |
+
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
| 1213 |
+
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
| 1214 |
+
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
| 1215 |
+
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
| 1216 |
+
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
| 1217 |
+
"""
|
| 1218 |
+
|
| 1219 |
+
def __init__(
|
| 1220 |
+
self,
|
| 1221 |
+
dim: int,
|
| 1222 |
+
dim_out: Optional[int] = None,
|
| 1223 |
+
mult: int = 4,
|
| 1224 |
+
dropout: float = 0.0,
|
| 1225 |
+
activation_fn: str = "geglu",
|
| 1226 |
+
final_dropout: bool = False,
|
| 1227 |
+
inner_dim=None,
|
| 1228 |
+
bias: bool = True,
|
| 1229 |
+
):
|
| 1230 |
+
super().__init__()
|
| 1231 |
+
if inner_dim is None:
|
| 1232 |
+
inner_dim = int(dim * mult)
|
| 1233 |
+
dim_out = dim_out if dim_out is not None else dim
|
| 1234 |
+
linear_cls = nn.Linear
|
| 1235 |
+
|
| 1236 |
+
if activation_fn == "gelu":
|
| 1237 |
+
act_fn = GELU(dim, inner_dim, bias=bias)
|
| 1238 |
+
elif activation_fn == "gelu-approximate":
|
| 1239 |
+
act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias)
|
| 1240 |
+
elif activation_fn == "geglu":
|
| 1241 |
+
act_fn = GEGLU(dim, inner_dim, bias=bias)
|
| 1242 |
+
elif activation_fn == "geglu-approximate":
|
| 1243 |
+
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
| 1244 |
+
else:
|
| 1245 |
+
raise ValueError(f"Unsupported activation function: {activation_fn}")
|
| 1246 |
+
|
| 1247 |
+
self.net = nn.ModuleList([])
|
| 1248 |
+
# project in
|
| 1249 |
+
self.net.append(act_fn)
|
| 1250 |
+
# project dropout
|
| 1251 |
+
self.net.append(nn.Dropout(dropout))
|
| 1252 |
+
# project out
|
| 1253 |
+
self.net.append(linear_cls(inner_dim, dim_out, bias=bias))
|
| 1254 |
+
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
| 1255 |
+
if final_dropout:
|
| 1256 |
+
self.net.append(nn.Dropout(dropout))
|
| 1257 |
+
|
| 1258 |
+
def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
|
| 1259 |
+
compatible_cls = (GEGLU, LoRACompatibleLinear)
|
| 1260 |
+
for module in self.net:
|
| 1261 |
+
if isinstance(module, compatible_cls):
|
| 1262 |
+
hidden_states = module(hidden_states, scale)
|
| 1263 |
+
else:
|
| 1264 |
+
hidden_states = module(hidden_states)
|
| 1265 |
+
return hidden_states
|
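
For reference, a minimal sketch (not part of the uploaded attention.py) of what `Attention.apply_rotary_emb` computes: features are grouped into interleaved pairs, rotated, and mixed with precomputed cos/sin frequency tables. Shapes and values below are assumptions chosen for illustration.

import torch
from einops import rearrange

def apply_rotary_emb(x, cos_freqs, sin_freqs):
    # x, cos_freqs, sin_freqs: [batch, seq_len, dim] with an even dim
    pairs = rearrange(x, "... (d r) -> ... d r", r=2)      # group features into (x1, x2) pairs
    x1, x2 = pairs.unbind(dim=-1)
    rotated = rearrange(torch.stack((-x2, x1), dim=-1), "... d r -> ... (d r)")
    return x * cos_freqs + rotated * sin_freqs

q = torch.randn(1, 8, 16)
cos = torch.ones(1, 8, 16)    # cos of a zero angle
sin = torch.zeros(1, 8, 16)   # sin of a zero angle
assert torch.allclose(apply_rotary_emb(q, cos, sin), q)    # a zero rotation is the identity
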
flash_head/ltx_video/models/transformers/embeddings.py
ADDED
|
@@ -0,0 +1,129 @@
| 1 |
+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch import nn
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def get_timestep_embedding(
|
| 11 |
+
timesteps: torch.Tensor,
|
| 12 |
+
embedding_dim: int,
|
| 13 |
+
flip_sin_to_cos: bool = False,
|
| 14 |
+
downscale_freq_shift: float = 1,
|
| 15 |
+
scale: float = 1,
|
| 16 |
+
max_period: int = 10000,
|
| 17 |
+
):
|
| 18 |
+
"""
|
| 19 |
+
This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
|
| 20 |
+
|
| 21 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 22 |
+
These may be fractional.
|
| 23 |
+
:param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
|
| 24 |
+
embeddings. :return: an [N x dim] Tensor of positional embeddings.
|
| 25 |
+
"""
|
| 26 |
+
assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"
|
| 27 |
+
|
| 28 |
+
half_dim = embedding_dim // 2
|
| 29 |
+
exponent = -math.log(max_period) * torch.arange(
|
| 30 |
+
start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
|
| 31 |
+
)
|
| 32 |
+
exponent = exponent / (half_dim - downscale_freq_shift)
|
| 33 |
+
|
| 34 |
+
emb = torch.exp(exponent)
|
| 35 |
+
emb = timesteps[:, None].float() * emb[None, :]
|
| 36 |
+
|
| 37 |
+
# scale embeddings
|
| 38 |
+
emb = scale * emb
|
| 39 |
+
|
| 40 |
+
# concat sine and cosine embeddings
|
| 41 |
+
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
|
| 42 |
+
|
| 43 |
+
# flip sine and cosine embeddings
|
| 44 |
+
if flip_sin_to_cos:
|
| 45 |
+
emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
|
| 46 |
+
|
| 47 |
+
# zero pad
|
| 48 |
+
if embedding_dim % 2 == 1:
|
| 49 |
+
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
| 50 |
+
return emb
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f):
|
| 54 |
+
"""
|
| 55 |
+
grid: per-token (f, h, w) coordinates flattened to shape [3, f*h*w]; w, h, f: the grid width, height and frame count.
|
| 56 |
+
return: pos_embed of shape [f*h*w, embed_dim].
|
| 57 |
+
"""
|
| 58 |
+
grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w)
|
| 59 |
+
grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w)
|
| 60 |
+
grid = grid.reshape([3, 1, w, h, f])
|
| 61 |
+
pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid)
|
| 62 |
+
pos_embed = pos_embed.transpose(1, 0, 2, 3)
|
| 63 |
+
return rearrange(pos_embed, "h w f c -> (f h w) c")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_3d_sincos_pos_embed_from_grid(embed_dim, grid):
|
| 67 |
+
if embed_dim % 3 != 0:
|
| 68 |
+
raise ValueError("embed_dim must be divisible by 3")
|
| 69 |
+
|
| 70 |
+
# use half of dimensions to encode grid_h
|
| 71 |
+
emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3)
|
| 72 |
+
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3)
|
| 73 |
+
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3)
|
| 74 |
+
|
| 75 |
+
emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D)
|
| 76 |
+
return emb
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 80 |
+
"""
|
| 81 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
| 82 |
+
"""
|
| 83 |
+
if embed_dim % 2 != 0:
|
| 84 |
+
raise ValueError("embed_dim must be divisible by 2")
|
| 85 |
+
|
| 86 |
+
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
| 87 |
+
omega /= embed_dim / 2.0
|
| 88 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 89 |
+
|
| 90 |
+
pos_shape = pos.shape
|
| 91 |
+
|
| 92 |
+
pos = pos.reshape(-1)
|
| 93 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 94 |
+
out = out.reshape([*pos_shape, -1])[0]
|
| 95 |
+
|
| 96 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 97 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 98 |
+
|
| 99 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D)
|
| 100 |
+
return emb
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class SinusoidalPositionalEmbedding(nn.Module):
|
| 104 |
+
"""Apply positional information to a sequence of embeddings.
|
| 105 |
+
|
| 106 |
+
Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
|
| 107 |
+
them
|
| 108 |
+
|
| 109 |
+
Args:
|
| 110 |
+
embed_dim: (int): Dimension of the positional embedding.
|
| 111 |
+
max_seq_length: Maximum sequence length to apply positional embeddings
|
| 112 |
+
|
| 113 |
+
"""
|
| 114 |
+
|
| 115 |
+
def __init__(self, embed_dim: int, max_seq_length: int = 32):
|
| 116 |
+
super().__init__()
|
| 117 |
+
position = torch.arange(max_seq_length).unsqueeze(1)
|
| 118 |
+
div_term = torch.exp(
|
| 119 |
+
torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim)
|
| 120 |
+
)
|
| 121 |
+
pe = torch.zeros(1, max_seq_length, embed_dim)
|
| 122 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
| 123 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
| 124 |
+
self.register_buffer("pe", pe)
|
| 125 |
+
|
| 126 |
+
def forward(self, x):
|
| 127 |
+
_, seq_length, _ = x.shape
|
| 128 |
+
x = x + self.pe[:, :seq_length]
|
| 129 |
+
return x
|
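
For reference, a minimal usage sketch (not part of the uploaded embeddings.py) of the sinusoidal timestep embedding defined above; the import path is assumed from this repository's layout.

import torch
from flash_head.ltx_video.models.transformers.embeddings import get_timestep_embedding

timesteps = torch.tensor([0.0, 250.0, 999.0])          # one scalar timestep per batch element
emb = get_timestep_embedding(timesteps, embedding_dim=128)
print(emb.shape)                                        # torch.Size([3, 128]): concatenated [sin | cos] features
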
flash_head/ltx_video/models/transformers/symmetric_patchifier.py
ADDED
|
@@ -0,0 +1,84 @@
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import Tuple
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers.configuration_utils import ConfigMixin
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Patchifier(ConfigMixin, ABC):
|
| 11 |
+
def __init__(self, patch_size: int):
|
| 12 |
+
super().__init__()
|
| 13 |
+
self._patch_size = (1, patch_size, patch_size)
|
| 14 |
+
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
|
| 17 |
+
raise NotImplementedError("Patchify method not implemented")
|
| 18 |
+
|
| 19 |
+
@abstractmethod
|
| 20 |
+
def unpatchify(
|
| 21 |
+
self,
|
| 22 |
+
latents: Tensor,
|
| 23 |
+
output_height: int,
|
| 24 |
+
output_width: int,
|
| 25 |
+
out_channels: int,
|
| 26 |
+
) -> Tuple[Tensor, Tensor]:
|
| 27 |
+
pass
|
| 28 |
+
|
| 29 |
+
@property
|
| 30 |
+
def patch_size(self):
|
| 31 |
+
return self._patch_size
|
| 32 |
+
|
| 33 |
+
def get_latent_coords(
|
| 34 |
+
self, latent_num_frames, latent_height, latent_width, batch_size, device
|
| 35 |
+
):
|
| 36 |
+
"""
|
| 37 |
+
Return a tensor of shape [batch_size, 3, num_patches] containing the
|
| 38 |
+
top-left corner latent coordinates of each latent patch.
|
| 39 |
+
The tensor is repeated for each batch element.
|
| 40 |
+
"""
|
| 41 |
+
latent_sample_coords = torch.meshgrid(
|
| 42 |
+
torch.arange(0, latent_num_frames, self._patch_size[0], device=device),
|
| 43 |
+
torch.arange(0, latent_height, self._patch_size[1], device=device),
|
| 44 |
+
torch.arange(0, latent_width, self._patch_size[2], device=device),
|
| 45 |
+
)
|
| 46 |
+
latent_sample_coords = torch.stack(latent_sample_coords, dim=0)
|
| 47 |
+
latent_coords = latent_sample_coords.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1)
|
| 48 |
+
latent_coords = rearrange(
|
| 49 |
+
latent_coords, "b c f h w -> b c (f h w)", b=batch_size
|
| 50 |
+
)
|
| 51 |
+
return latent_coords
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class SymmetricPatchifier(Patchifier):
|
| 55 |
+
def patchify(self, latents: Tensor) -> Tuple[Tensor, Tensor]:
|
| 56 |
+
b, _, f, h, w = latents.shape
|
| 57 |
+
latent_coords = self.get_latent_coords(f, h, w, b, latents.device)
|
| 58 |
+
latents = rearrange(
|
| 59 |
+
latents,
|
| 60 |
+
"b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)",
|
| 61 |
+
p1=self._patch_size[0],
|
| 62 |
+
p2=self._patch_size[1],
|
| 63 |
+
p3=self._patch_size[2],
|
| 64 |
+
)
|
| 65 |
+
return latents, latent_coords
|
| 66 |
+
|
| 67 |
+
def unpatchify(
|
| 68 |
+
self,
|
| 69 |
+
latents: Tensor,
|
| 70 |
+
output_height: int,
|
| 71 |
+
output_width: int,
|
| 72 |
+
out_channels: int,
|
| 73 |
+
) -> Tuple[Tensor, Tensor]:
|
| 74 |
+
output_height = output_height // self._patch_size[1]
|
| 75 |
+
output_width = output_width // self._patch_size[2]
|
| 76 |
+
latents = rearrange(
|
| 77 |
+
latents,
|
| 78 |
+
"b (f h w) (c p q) -> b c f (h p) (w q)",
|
| 79 |
+
h=output_height,
|
| 80 |
+
w=output_width,
|
| 81 |
+
p=self._patch_size[1],
|
| 82 |
+
q=self._patch_size[2],
|
| 83 |
+
)
|
| 84 |
+
return latents
|
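
For reference, a minimal round-trip sketch (not part of the uploaded symmetric_patchifier.py) showing how `SymmetricPatchifier` flattens a latent video into tokens plus per-token coordinates and then restores it; with patch_size=1 the round trip is exact. The import path is assumed from this repository's layout.

import torch
from flash_head.ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier

patchifier = SymmetricPatchifier(patch_size=1)
latents = torch.randn(2, 128, 8, 16, 24)                 # [batch, channels, frames, height, width]
tokens, coords = patchifier.patchify(latents)
print(tokens.shape, coords.shape)                         # [2, 3072, 128] and [2, 3, 3072]
restored = patchifier.unpatchify(tokens, output_height=16, output_width=24, out_channels=128)
assert torch.equal(restored, latents)                     # pure reshapes, so values are restored exactly
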
flash_head/ltx_video/models/transformers/transformer3d.py
ADDED
|
@@ -0,0 +1,507 @@
| 1 |
+
# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py
|
| 2 |
+
import math
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from typing import Any, Dict, List, Optional, Union
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import glob
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 12 |
+
from diffusers.models.embeddings import PixArtAlphaTextProjection
|
| 13 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 14 |
+
from diffusers.models.normalization import AdaLayerNormSingle
|
| 15 |
+
from diffusers.utils import BaseOutput, is_torch_version
|
| 16 |
+
from diffusers.utils import logging
|
| 17 |
+
from torch import nn
|
| 18 |
+
from safetensors import safe_open
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
from flash_head.ltx_video.models.transformers.attention import BasicTransformerBlock
|
| 22 |
+
from flash_head.ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
|
| 23 |
+
|
| 24 |
+
from flash_head.ltx_video.utils.diffusers_config_mapping import (
|
| 25 |
+
diffusers_and_ours_config_mapping,
|
| 26 |
+
make_hashable_key,
|
| 27 |
+
TRANSFORMER_KEYS_RENAME_DICT,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
logger = logging.get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
|
| 35 |
+
class Transformer3DModelOutput(BaseOutput):
|
| 36 |
+
"""
|
| 37 |
+
The output of [`Transformer2DModel`].
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
| 41 |
+
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
|
| 42 |
+
distributions for the unnoised latent pixels.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
sample: torch.FloatTensor
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class Transformer3DModel(ModelMixin, ConfigMixin):
|
| 49 |
+
_supports_gradient_checkpointing = True
|
| 50 |
+
|
| 51 |
+
@register_to_config
|
| 52 |
+
def __init__(
|
| 53 |
+
self,
|
| 54 |
+
num_attention_heads: int = 16,
|
| 55 |
+
attention_head_dim: int = 88,
|
| 56 |
+
in_channels: Optional[int] = None,
|
| 57 |
+
out_channels: Optional[int] = None,
|
| 58 |
+
num_layers: int = 1,
|
| 59 |
+
dropout: float = 0.0,
|
| 60 |
+
norm_num_groups: int = 32,
|
| 61 |
+
cross_attention_dim: Optional[int] = None,
|
| 62 |
+
attention_bias: bool = False,
|
| 63 |
+
num_vector_embeds: Optional[int] = None,
|
| 64 |
+
activation_fn: str = "geglu",
|
| 65 |
+
num_embeds_ada_norm: Optional[int] = None,
|
| 66 |
+
use_linear_projection: bool = False,
|
| 67 |
+
only_cross_attention: bool = False,
|
| 68 |
+
double_self_attention: bool = False,
|
| 69 |
+
upcast_attention: bool = False,
|
| 70 |
+
adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale'
|
| 71 |
+
standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm'
|
| 72 |
+
norm_elementwise_affine: bool = True,
|
| 73 |
+
norm_eps: float = 1e-5,
|
| 74 |
+
attention_type: str = "default",
|
| 75 |
+
caption_channels: int = None,
|
| 76 |
+
use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention')
|
| 77 |
+
qk_norm: Optional[str] = None,
|
| 78 |
+
positional_embedding_type: str = "rope",
|
| 79 |
+
positional_embedding_theta: Optional[float] = None,
|
| 80 |
+
positional_embedding_max_pos: Optional[List[int]] = None,
|
| 81 |
+
timestep_scale_multiplier: Optional[float] = None,
|
| 82 |
+
causal_temporal_positioning: bool = False, # For backward compatibility, will be deprecated
|
| 83 |
+
):
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.use_tpu_flash_attention = (
|
| 86 |
+
use_tpu_flash_attention # FIXME: push config down to the attention modules
|
| 87 |
+
)
|
| 88 |
+
self.use_linear_projection = use_linear_projection
|
| 89 |
+
self.num_attention_heads = num_attention_heads
|
| 90 |
+
self.attention_head_dim = attention_head_dim
|
| 91 |
+
inner_dim = num_attention_heads * attention_head_dim
|
| 92 |
+
self.inner_dim = inner_dim
|
| 93 |
+
self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True)
|
| 94 |
+
self.positional_embedding_type = positional_embedding_type
|
| 95 |
+
self.positional_embedding_theta = positional_embedding_theta
|
| 96 |
+
self.positional_embedding_max_pos = positional_embedding_max_pos
|
| 97 |
+
self.use_rope = self.positional_embedding_type == "rope"
|
| 98 |
+
self.timestep_scale_multiplier = timestep_scale_multiplier
|
| 99 |
+
|
| 100 |
+
if self.positional_embedding_type == "absolute":
|
| 101 |
+
raise ValueError("Absolute positional embedding is no longer supported")
|
| 102 |
+
elif self.positional_embedding_type == "rope":
|
| 103 |
+
if positional_embedding_theta is None:
|
| 104 |
+
raise ValueError(
|
| 105 |
+
"If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined"
|
| 106 |
+
)
|
| 107 |
+
if positional_embedding_max_pos is None:
|
| 108 |
+
raise ValueError(
|
| 109 |
+
"If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined"
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
# 3. Define transformers blocks
|
| 113 |
+
self.transformer_blocks = nn.ModuleList(
|
| 114 |
+
[
|
| 115 |
+
BasicTransformerBlock(
|
| 116 |
+
inner_dim,
|
| 117 |
+
num_attention_heads,
|
| 118 |
+
attention_head_dim,
|
| 119 |
+
dropout=dropout,
|
| 120 |
+
cross_attention_dim=cross_attention_dim,
|
| 121 |
+
activation_fn=activation_fn,
|
| 122 |
+
num_embeds_ada_norm=num_embeds_ada_norm,
|
| 123 |
+
attention_bias=attention_bias,
|
| 124 |
+
only_cross_attention=only_cross_attention,
|
| 125 |
+
double_self_attention=double_self_attention,
|
| 126 |
+
upcast_attention=upcast_attention,
|
| 127 |
+
adaptive_norm=adaptive_norm,
|
| 128 |
+
standardization_norm=standardization_norm,
|
| 129 |
+
norm_elementwise_affine=norm_elementwise_affine,
|
| 130 |
+
norm_eps=norm_eps,
|
| 131 |
+
attention_type=attention_type,
|
| 132 |
+
use_tpu_flash_attention=use_tpu_flash_attention,
|
| 133 |
+
qk_norm=qk_norm,
|
| 134 |
+
use_rope=self.use_rope,
|
| 135 |
+
)
|
| 136 |
+
for d in range(num_layers)
|
| 137 |
+
]
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
# 4. Define output layers
|
| 141 |
+
self.out_channels = in_channels if out_channels is None else out_channels
|
| 142 |
+
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
| 143 |
+
self.scale_shift_table = nn.Parameter(
|
| 144 |
+
torch.randn(2, inner_dim) / inner_dim**0.5
|
| 145 |
+
)
|
| 146 |
+
self.proj_out = nn.Linear(inner_dim, self.out_channels)
|
| 147 |
+
|
| 148 |
+
self.adaln_single = AdaLayerNormSingle(
|
| 149 |
+
inner_dim, use_additional_conditions=False
|
| 150 |
+
)
|
| 151 |
+
if adaptive_norm == "single_scale":
|
| 152 |
+
self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True)
|
| 153 |
+
|
| 154 |
+
self.caption_projection = None
|
| 155 |
+
if caption_channels is not None:
|
| 156 |
+
self.caption_projection = PixArtAlphaTextProjection(
|
| 157 |
+
in_features=caption_channels, hidden_size=inner_dim
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
self.gradient_checkpointing = False
|
| 161 |
+
|
| 162 |
+
def set_use_tpu_flash_attention(self):
|
| 163 |
+
r"""
|
| 164 |
+
Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU
|
| 165 |
+
attention kernel.
|
| 166 |
+
"""
|
| 167 |
+
logger.info("ENABLE TPU FLASH ATTENTION -> TRUE")
|
| 168 |
+
self.use_tpu_flash_attention = True
|
| 169 |
+
# push config down to the attention modules
|
| 170 |
+
for block in self.transformer_blocks:
|
| 171 |
+
block.set_use_tpu_flash_attention()
|
| 172 |
+
|
| 173 |
+
def create_skip_layer_mask(
|
| 174 |
+
self,
|
| 175 |
+
batch_size: int,
|
| 176 |
+
num_conds: int,
|
| 177 |
+
ptb_index: int,
|
| 178 |
+
skip_block_list: Optional[List[int]] = None,
|
| 179 |
+
):
|
| 180 |
+
if skip_block_list is None or len(skip_block_list) == 0:
|
| 181 |
+
return None
|
| 182 |
+
num_layers = len(self.transformer_blocks)
|
| 183 |
+
mask = torch.ones(
|
| 184 |
+
(num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype
|
| 185 |
+
)
|
| 186 |
+
for block_idx in skip_block_list:
|
| 187 |
+
mask[block_idx, ptb_index::num_conds] = 0
|
| 188 |
+
return mask
|
| 189 |
+
|
| 190 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
| 191 |
+
if hasattr(module, "gradient_checkpointing"):
|
| 192 |
+
module.gradient_checkpointing = value
|
| 193 |
+
|
| 194 |
+
def get_fractional_positions(self, indices_grid):
|
| 195 |
+
fractional_positions = torch.stack(
|
| 196 |
+
[
|
| 197 |
+
indices_grid[:, i] / self.positional_embedding_max_pos[i]
|
| 198 |
+
for i in range(3)
|
| 199 |
+
],
|
| 200 |
+
dim=-1,
|
| 201 |
+
)
|
| 202 |
+
return fractional_positions
|
| 203 |
+
|
| 204 |
+
def precompute_freqs_cis(self, indices_grid, spacing="exp"):
|
| 205 |
+
dtype = torch.float32 # We need full precision in the freqs_cis computation.
|
| 206 |
+
dim = self.inner_dim
|
| 207 |
+
theta = self.positional_embedding_theta
|
| 208 |
+
|
| 209 |
+
fractional_positions = self.get_fractional_positions(indices_grid)
|
| 210 |
+
|
| 211 |
+
start = 1
|
| 212 |
+
end = theta
|
| 213 |
+
device = fractional_positions.device
|
| 214 |
+
if spacing == "exp":
|
| 215 |
+
indices = theta ** (
|
| 216 |
+
torch.linspace(
|
| 217 |
+
math.log(start, theta),
|
| 218 |
+
math.log(end, theta),
|
| 219 |
+
dim // 6,
|
| 220 |
+
device=device,
|
| 221 |
+
dtype=dtype,
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
indices = indices.to(dtype=dtype)
|
| 225 |
+
elif spacing == "exp_2":
|
| 226 |
+
indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim)
|
| 227 |
+
indices = indices.to(dtype=dtype)
|
| 228 |
+
elif spacing == "linear":
|
| 229 |
+
indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype)
|
| 230 |
+
elif spacing == "sqrt":
|
| 231 |
+
indices = torch.linspace(
|
| 232 |
+
start**2, end**2, dim // 6, device=device, dtype=dtype
|
| 233 |
+
).sqrt()
|
| 234 |
+
|
| 235 |
+
indices = indices * math.pi / 2
|
| 236 |
+
|
| 237 |
+
if spacing == "exp_2":
|
| 238 |
+
freqs = (
|
| 239 |
+
(indices * fractional_positions.unsqueeze(-1))
|
| 240 |
+
.transpose(-1, -2)
|
| 241 |
+
.flatten(2)
|
| 242 |
+
)
|
| 243 |
+
else:
|
| 244 |
+
freqs = (
|
| 245 |
+
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
|
| 246 |
+
.transpose(-1, -2)
|
| 247 |
+
.flatten(2)
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
|
| 251 |
+
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
|
| 252 |
+
if dim % 6 != 0:
|
| 253 |
+
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
|
| 254 |
+
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
|
| 255 |
+
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
|
| 256 |
+
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
|
| 257 |
+
return cos_freq.to(self.dtype), sin_freq.to(self.dtype)
|
| 258 |
+
|
| 259 |
+
def load_state_dict(
|
| 260 |
+
self,
|
| 261 |
+
state_dict: Dict,
|
| 262 |
+
*args,
|
| 263 |
+
**kwargs,
|
| 264 |
+
):
|
| 265 |
+
if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]):
|
| 266 |
+
state_dict = {
|
| 267 |
+
key.replace("model.diffusion_model.", ""): value
|
| 268 |
+
for key, value in state_dict.items()
|
| 269 |
+
if key.startswith("model.diffusion_model.")
|
| 270 |
+
}
|
| 271 |
+
super().load_state_dict(state_dict, **kwargs)
|
| 272 |
+
|
| 273 |
+
@classmethod
|
| 274 |
+
def from_pretrained(
|
| 275 |
+
cls,
|
| 276 |
+
pretrained_model_path: Optional[Union[str, os.PathLike]],
|
| 277 |
+
*args,
|
| 278 |
+
**kwargs,
|
| 279 |
+
):
|
| 280 |
+
pretrained_model_path = Path(pretrained_model_path)
|
| 281 |
+
if pretrained_model_path.is_dir():
|
| 282 |
+
config_path = pretrained_model_path / "transformer" / "config.json"
|
| 283 |
+
with open(config_path, "r") as f:
|
| 284 |
+
config = make_hashable_key(json.load(f))
|
| 285 |
+
|
| 286 |
+
assert config in diffusers_and_ours_config_mapping, (
|
| 287 |
+
"Provided diffusers checkpoint config for transformer is not suppported. "
|
| 288 |
+
"We only support diffusers configs found in Lightricks/LTX-Video."
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
config = diffusers_and_ours_config_mapping[config]
|
| 292 |
+
state_dict = {}
|
| 293 |
+
ckpt_paths = (
|
| 294 |
+
pretrained_model_path
|
| 295 |
+
/ "transformer"
|
| 296 |
+
/ "diffusion_pytorch_model*.safetensors"
|
| 297 |
+
)
|
| 298 |
+
dict_list = glob.glob(str(ckpt_paths))
|
| 299 |
+
for dict_path in dict_list:
|
| 300 |
+
part_dict = {}
|
| 301 |
+
with safe_open(dict_path, framework="pt", device="cpu") as f:
|
| 302 |
+
for k in f.keys():
|
| 303 |
+
part_dict[k] = f.get_tensor(k)
|
| 304 |
+
state_dict.update(part_dict)
|
| 305 |
+
|
| 306 |
+
for key in list(state_dict.keys()):
|
| 307 |
+
new_key = key
|
| 308 |
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
| 309 |
+
new_key = new_key.replace(replace_key, rename_key)
|
| 310 |
+
state_dict[new_key] = state_dict.pop(key)
|
| 311 |
+
|
| 312 |
+
with torch.device("meta"):
|
| 313 |
+
transformer = cls.from_config(config)
|
| 314 |
+
transformer.load_state_dict(state_dict, assign=True, strict=True)
|
| 315 |
+
elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith(
|
| 316 |
+
".safetensors"
|
| 317 |
+
):
|
| 318 |
+
comfy_single_file_state_dict = {}
|
| 319 |
+
with safe_open(pretrained_model_path, framework="pt", device="cpu") as f:
|
| 320 |
+
metadata = f.metadata()
|
| 321 |
+
for k in f.keys():
|
| 322 |
+
comfy_single_file_state_dict[k] = f.get_tensor(k)
|
| 323 |
+
configs = json.loads(metadata["config"])
|
| 324 |
+
transformer_config = configs["transformer"]
|
| 325 |
+
with torch.device("meta"):
|
| 326 |
+
transformer = Transformer3DModel.from_config(transformer_config)
|
| 327 |
+
transformer.load_state_dict(comfy_single_file_state_dict, assign=True)
|
| 328 |
+
return transformer
|
| 329 |
+
|
| 330 |
+
def forward(
|
| 331 |
+
self,
|
| 332 |
+
hidden_states: torch.Tensor,
|
| 333 |
+
indices_grid: torch.Tensor,
|
| 334 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 335 |
+
timestep: Optional[torch.LongTensor] = None,
|
| 336 |
+
class_labels: Optional[torch.LongTensor] = None,
|
| 337 |
+
cross_attention_kwargs: Dict[str, Any] = None,
|
| 338 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 339 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 340 |
+
skip_layer_mask: Optional[torch.Tensor] = None,
|
| 341 |
+
skip_layer_strategy: Optional[SkipLayerStrategy] = None,
|
| 342 |
+
return_dict: bool = True,
|
| 343 |
+
):
|
| 344 |
+
"""
|
| 345 |
+
The [`Transformer2DModel`] forward method.
|
| 346 |
+
|
| 347 |
+
Args:
|
| 348 |
+
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
|
| 349 |
+
Input `hidden_states`.
|
| 350 |
+
indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`):
|
| 351 |
+
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
|
| 352 |
+
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
| 353 |
+
self-attention.
|
| 354 |
+
timestep ( `torch.LongTensor`, *optional*):
|
| 355 |
+
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
|
| 356 |
+
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
| 357 |
+
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
|
| 358 |
+
`AdaLayerZeroNorm`.
|
| 359 |
+
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
|
| 360 |
+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
| 361 |
+
`self.processor` in
|
| 362 |
+
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
| 363 |
+
attention_mask ( `torch.Tensor`, *optional*):
|
| 364 |
+
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
|
| 365 |
+
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
|
| 366 |
+
    negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
    Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:

        * Mask `(batch, sequence_length)` True = keep, False = discard.
        * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.

    If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
    above. This bias will be added to the cross-attention scores.
skip_layer_mask ( `torch.Tensor`, *optional*):
    A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position
    `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index.
skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`):
    Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance.
return_dict (`bool`, *optional*, defaults to `True`):
    Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
    tuple.

Returns:
    If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
    `tuple` where the first element is the sample tensor.
"""
# for tpu attention offload 2d token masks are used. No need to transform.
if not self.use_tpu_flash_attention:
    # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
    # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
    # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
    # expects mask of shape:
    #   [batch, key_tokens]
    # adds singleton query_tokens dimension:
    #   [batch, 1, key_tokens]
    # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
    #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
    #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
    if attention_mask is not None and attention_mask.ndim == 2:
        # assume that mask is expressed as:
        #   (1 = keep, 0 = discard)
        # convert mask into a bias that can be added to attention scores:
        #   (keep = +0, discard = -10000.0)
        attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
        attention_mask = attention_mask.unsqueeze(1)

    # convert encoder_attention_mask to a bias the same way we do for attention_mask
    if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
        encoder_attention_mask = (
            1 - encoder_attention_mask.to(hidden_states.dtype)
        ) * -10000.0
        encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

# 1. Input
hidden_states = self.patchify_proj(hidden_states)

if self.timestep_scale_multiplier:
    timestep = self.timestep_scale_multiplier * timestep

freqs_cis = self.precompute_freqs_cis(indices_grid)

batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
    timestep.flatten(),
    {"resolution": None, "aspect_ratio": None},
    batch_size=batch_size,
    hidden_dtype=hidden_states.dtype,
)
# Second dimension is 1 or number of tokens (if timestep_per_token)
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(
    batch_size, -1, embedded_timestep.shape[-1]
)

# 2. Blocks
if self.caption_projection is not None:
    batch_size = hidden_states.shape[0]
    encoder_hidden_states = self.caption_projection(encoder_hidden_states)
    encoder_hidden_states = encoder_hidden_states.view(
        batch_size, -1, hidden_states.shape[-1]
    )

for block_idx, block in enumerate(self.transformer_blocks):
    if self.training and self.gradient_checkpointing:

        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = (
            {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        )
        hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            freqs_cis,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            timestep,
            cross_attention_kwargs,
            class_labels,
            (
                skip_layer_mask[block_idx]
                if skip_layer_mask is not None
                else None
            ),
            skip_layer_strategy,
            **ckpt_kwargs,
        )
    else:
        hidden_states = block(
            hidden_states,
            freqs_cis=freqs_cis,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            timestep=timestep,
            cross_attention_kwargs=cross_attention_kwargs,
            class_labels=class_labels,
            skip_layer_mask=(
                skip_layer_mask[block_idx]
                if skip_layer_mask is not None
                else None
            ),
            skip_layer_strategy=skip_layer_strategy,
        )

# 3. Output
scale_shift_values = (
    self.scale_shift_table[None, None] + embedded_timestep[:, :, None]
)
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
if not return_dict:
    return (hidden_states,)

return Transformer3DModelOutput(sample=hidden_states)
|
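Note (illustration only, not part of the uploaded file): the 2-D mask handling documented above can be sanity-checked in isolation. A minimal sketch of the same keep/discard mask to additive-bias conversion used in the forward pass:

import torch

# keep/discard mask of shape [batch, key_tokens]: 1 = keep, 0 = discard
mask = torch.tensor([[1, 1, 0]])
dtype = torch.float16

# keep -> 0.0, discard -> -10000.0, then add a singleton query_tokens dimension
# so the bias broadcasts over [batch, heads, query_tokens, key_tokens] scores
bias = (1 - mask.to(dtype)) * -10000.0
bias = bias.unsqueeze(1)  # shape [1, 1, 3], values [0., 0., -10000.]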
flash_head/ltx_video/utils/__init__.py
ADDED
|
File without changes
|
flash_head/ltx_video/utils/diffusers_config_mapping.py
ADDED
|
@@ -0,0 +1,174 @@
| 1 |
+
def make_hashable_key(dict_key):
|
| 2 |
+
def convert_value(value):
|
| 3 |
+
if isinstance(value, list):
|
| 4 |
+
return tuple(value)
|
| 5 |
+
elif isinstance(value, dict):
|
| 6 |
+
return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
|
| 7 |
+
else:
|
| 8 |
+
return value
|
| 9 |
+
|
| 10 |
+
return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
DIFFUSERS_SCHEDULER_CONFIG = {
|
| 14 |
+
"_class_name": "FlowMatchEulerDiscreteScheduler",
|
| 15 |
+
"_diffusers_version": "0.32.0.dev0",
|
| 16 |
+
"base_image_seq_len": 1024,
|
| 17 |
+
"base_shift": 0.95,
|
| 18 |
+
"invert_sigmas": False,
|
| 19 |
+
"max_image_seq_len": 4096,
|
| 20 |
+
"max_shift": 2.05,
|
| 21 |
+
"num_train_timesteps": 1000,
|
| 22 |
+
"shift": 1.0,
|
| 23 |
+
"shift_terminal": 0.1,
|
| 24 |
+
"use_beta_sigmas": False,
|
| 25 |
+
"use_dynamic_shifting": True,
|
| 26 |
+
"use_exponential_sigmas": False,
|
| 27 |
+
"use_karras_sigmas": False,
|
| 28 |
+
}
|
| 29 |
+
DIFFUSERS_TRANSFORMER_CONFIG = {
|
| 30 |
+
"_class_name": "LTXVideoTransformer3DModel",
|
| 31 |
+
"_diffusers_version": "0.32.0.dev0",
|
| 32 |
+
"activation_fn": "gelu-approximate",
|
| 33 |
+
"attention_bias": True,
|
| 34 |
+
"attention_head_dim": 64,
|
| 35 |
+
"attention_out_bias": True,
|
| 36 |
+
"caption_channels": 4096,
|
| 37 |
+
"cross_attention_dim": 2048,
|
| 38 |
+
"in_channels": 128,
|
| 39 |
+
"norm_elementwise_affine": False,
|
| 40 |
+
"norm_eps": 1e-06,
|
| 41 |
+
"num_attention_heads": 32,
|
| 42 |
+
"num_layers": 28,
|
| 43 |
+
"out_channels": 128,
|
| 44 |
+
"patch_size": 1,
|
| 45 |
+
"patch_size_t": 1,
|
| 46 |
+
"qk_norm": "rms_norm_across_heads",
|
| 47 |
+
}
|
| 48 |
+
DIFFUSERS_VAE_CONFIG = {
|
| 49 |
+
"_class_name": "AutoencoderKLLTXVideo",
|
| 50 |
+
"_diffusers_version": "0.32.0.dev0",
|
| 51 |
+
"block_out_channels": [128, 256, 512, 512],
|
| 52 |
+
"decoder_causal": False,
|
| 53 |
+
"encoder_causal": True,
|
| 54 |
+
"in_channels": 3,
|
| 55 |
+
"latent_channels": 128,
|
| 56 |
+
"layers_per_block": [4, 3, 3, 3, 4],
|
| 57 |
+
"out_channels": 3,
|
| 58 |
+
"patch_size": 4,
|
| 59 |
+
"patch_size_t": 1,
|
| 60 |
+
"resnet_norm_eps": 1e-06,
|
| 61 |
+
"scaling_factor": 1.0,
|
| 62 |
+
"spatio_temporal_scaling": [True, True, True, False],
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
OURS_SCHEDULER_CONFIG = {
|
| 66 |
+
"_class_name": "RectifiedFlowScheduler",
|
| 67 |
+
"_diffusers_version": "0.25.1",
|
| 68 |
+
"num_train_timesteps": 1000,
|
| 69 |
+
"shifting": "SD3",
|
| 70 |
+
"base_resolution": None,
|
| 71 |
+
"target_shift_terminal": 0.1,
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
OURS_TRANSFORMER_CONFIG = {
|
| 75 |
+
"_class_name": "Transformer3DModel",
|
| 76 |
+
"_diffusers_version": "0.25.1",
|
| 77 |
+
"_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256",
|
| 78 |
+
"activation_fn": "gelu-approximate",
|
| 79 |
+
"attention_bias": True,
|
| 80 |
+
"attention_head_dim": 64,
|
| 81 |
+
"attention_type": "default",
|
| 82 |
+
"caption_channels": 4096,
|
| 83 |
+
"cross_attention_dim": 2048,
|
| 84 |
+
"double_self_attention": False,
|
| 85 |
+
"dropout": 0.0,
|
| 86 |
+
"in_channels": 128,
|
| 87 |
+
"norm_elementwise_affine": False,
|
| 88 |
+
"norm_eps": 1e-06,
|
| 89 |
+
"norm_num_groups": 32,
|
| 90 |
+
"num_attention_heads": 32,
|
| 91 |
+
"num_embeds_ada_norm": 1000,
|
| 92 |
+
"num_layers": 28,
|
| 93 |
+
"num_vector_embeds": None,
|
| 94 |
+
"only_cross_attention": False,
|
| 95 |
+
"out_channels": 128,
|
| 96 |
+
"project_to_2d_pos": True,
|
| 97 |
+
"upcast_attention": False,
|
| 98 |
+
"use_linear_projection": False,
|
| 99 |
+
"qk_norm": "rms_norm",
|
| 100 |
+
"standardization_norm": "rms_norm",
|
| 101 |
+
"positional_embedding_type": "rope",
|
| 102 |
+
"positional_embedding_theta": 10000.0,
|
| 103 |
+
"positional_embedding_max_pos": [20, 2048, 2048],
|
| 104 |
+
"timestep_scale_multiplier": 1000,
|
| 105 |
+
}
|
| 106 |
+
OURS_VAE_CONFIG = {
|
| 107 |
+
"_class_name": "CausalVideoAutoencoder",
|
| 108 |
+
"dims": 3,
|
| 109 |
+
"in_channels": 3,
|
| 110 |
+
"out_channels": 3,
|
| 111 |
+
"latent_channels": 128,
|
| 112 |
+
"blocks": [
|
| 113 |
+
["res_x", 4],
|
| 114 |
+
["compress_all", 1],
|
| 115 |
+
["res_x_y", 1],
|
| 116 |
+
["res_x", 3],
|
| 117 |
+
["compress_all", 1],
|
| 118 |
+
["res_x_y", 1],
|
| 119 |
+
["res_x", 3],
|
| 120 |
+
["compress_all", 1],
|
| 121 |
+
["res_x", 3],
|
| 122 |
+
["res_x", 4],
|
| 123 |
+
],
|
| 124 |
+
"scaling_factor": 1.0,
|
| 125 |
+
"norm_layer": "pixel_norm",
|
| 126 |
+
"patch_size": 4,
|
| 127 |
+
"latent_log_var": "uniform",
|
| 128 |
+
"use_quant_conv": False,
|
| 129 |
+
"causal_decoder": False,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
diffusers_and_ours_config_mapping = {
|
| 134 |
+
make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG,
|
| 135 |
+
make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG,
|
| 136 |
+
make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG,
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
TRANSFORMER_KEYS_RENAME_DICT = {
|
| 141 |
+
"proj_in": "patchify_proj",
|
| 142 |
+
"time_embed": "adaln_single",
|
| 143 |
+
"norm_q": "q_norm",
|
| 144 |
+
"norm_k": "k_norm",
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
VAE_KEYS_RENAME_DICT = {
|
| 149 |
+
"decoder.up_blocks.3.conv_in": "decoder.up_blocks.7",
|
| 150 |
+
"decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8",
|
| 151 |
+
"decoder.up_blocks.3": "decoder.up_blocks.9",
|
| 152 |
+
"decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5",
|
| 153 |
+
"decoder.up_blocks.2.conv_in": "decoder.up_blocks.4",
|
| 154 |
+
"decoder.up_blocks.2": "decoder.up_blocks.6",
|
| 155 |
+
"decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2",
|
| 156 |
+
"decoder.up_blocks.1": "decoder.up_blocks.3",
|
| 157 |
+
"decoder.up_blocks.0": "decoder.up_blocks.1",
|
| 158 |
+
"decoder.mid_block": "decoder.up_blocks.0",
|
| 159 |
+
"encoder.down_blocks.3": "encoder.down_blocks.8",
|
| 160 |
+
"encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7",
|
| 161 |
+
"encoder.down_blocks.2": "encoder.down_blocks.6",
|
| 162 |
+
"encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4",
|
| 163 |
+
"encoder.down_blocks.1.conv_out": "encoder.down_blocks.5",
|
| 164 |
+
"encoder.down_blocks.1": "encoder.down_blocks.3",
|
| 165 |
+
"encoder.down_blocks.0.conv_out": "encoder.down_blocks.2",
|
| 166 |
+
"encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1",
|
| 167 |
+
"encoder.down_blocks.0": "encoder.down_blocks.0",
|
| 168 |
+
"encoder.mid_block": "encoder.down_blocks.9",
|
| 169 |
+
"conv_shortcut.conv": "conv_shortcut",
|
| 170 |
+
"resnets": "res_blocks",
|
| 171 |
+
"norm3": "norm3.norm",
|
| 172 |
+
"latents_mean": "per_channel_statistics.mean-of-means",
|
| 173 |
+
"latents_std": "per_channel_statistics.std-of-means",
|
| 174 |
+
}
|
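Note (illustration only, not part of the upload): make_hashable_key normalizes a config dict into a sorted tuple so it can act as a dictionary key in diffusers_and_ours_config_mapping, independent of key order or list values. A small self-contained check with a toy config:

def make_hashable_key(dict_key):
    def convert_value(value):
        if isinstance(value, list):
            return tuple(value)
        elif isinstance(value, dict):
            return tuple(sorted((k, convert_value(v)) for k, v in value.items()))
        return value
    return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items()))

cfg_a = {"block_out_channels": [128, 256], "patch_size": 4}
cfg_b = {"patch_size": 4, "block_out_channels": [128, 256]}   # same content, different order
assert make_hashable_key(cfg_a) == make_hashable_key(cfg_b)

mapping = {make_hashable_key(cfg_a): "matching OURS_* config would be returned here"}
print(mapping[make_hashable_key(cfg_b)])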
flash_head/ltx_video/utils/prompt_enhance_utils.py
ADDED
|
@@ -0,0 +1,226 @@
| 1 |
+
import logging
|
| 2 |
+
from typing import Union, List, Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
| 8 |
+
|
| 9 |
+
T2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
|
| 10 |
+
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
|
| 11 |
+
Start directly with the action, and keep descriptions literal and precise.
|
| 12 |
+
Think like a cinematographer describing a shot list.
|
| 13 |
+
Do not change the user input intent, just enhance it.
|
| 14 |
+
Keep within 150 words.
|
| 15 |
+
For best results, build your prompts using this structure:
|
| 16 |
+
Start with main action in a single sentence
|
| 17 |
+
Add specific details about movements and gestures
|
| 18 |
+
Describe character/object appearances precisely
|
| 19 |
+
Include background and environment details
|
| 20 |
+
Specify camera angles and movements
|
| 21 |
+
Describe lighting and colors
|
| 22 |
+
Note any changes or sudden events
|
| 23 |
+
Do not exceed the 150 word limit!
|
| 24 |
+
Output the enhanced prompt only.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
I2V_CINEMATIC_PROMPT = """You are an expert cinematic director with many award winning movies, When writing prompts based on the user input, focus on detailed, chronological descriptions of actions and scenes.
|
| 28 |
+
Include specific movements, appearances, camera angles, and environmental details - all in a single flowing paragraph.
|
| 29 |
+
Start directly with the action, and keep descriptions literal and precise.
|
| 30 |
+
Think like a cinematographer describing a shot list.
|
| 31 |
+
Keep within 150 words.
|
| 32 |
+
For best results, build your prompts using this structure:
|
| 33 |
+
Describe the image first and then add the user input. Image description should be in first priority! Align to the image caption if it contradicts the user text input.
|
| 34 |
+
Start with main action in a single sentence
|
| 35 |
+
Add specific details about movements and gestures
|
| 36 |
+
Describe character/object appearances precisely
|
| 37 |
+
Include background and environment details
|
| 38 |
+
Specify camera angles and movements
|
| 39 |
+
Describe lighting and colors
|
| 40 |
+
Note any changes or sudden events
|
| 41 |
+
Align to the image caption if it contradicts the user text input.
|
| 42 |
+
Do not exceed the 150 word limit!
|
| 43 |
+
Output the enhanced prompt only.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def tensor_to_pil(tensor):
|
| 48 |
+
# Ensure tensor is in range [-1, 1]
|
| 49 |
+
assert tensor.min() >= -1 and tensor.max() <= 1
|
| 50 |
+
|
| 51 |
+
# Convert from [-1, 1] to [0, 1]
|
| 52 |
+
tensor = (tensor + 1) / 2
|
| 53 |
+
|
| 54 |
+
# Rearrange from [C, H, W] to [H, W, C]
|
| 55 |
+
tensor = tensor.permute(1, 2, 0)
|
| 56 |
+
|
| 57 |
+
# Convert to numpy array and then to uint8 range [0, 255]
|
| 58 |
+
numpy_image = (tensor.cpu().numpy() * 255).astype("uint8")
|
| 59 |
+
|
| 60 |
+
# Convert to PIL Image
|
| 61 |
+
return Image.fromarray(numpy_image)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def generate_cinematic_prompt(
|
| 65 |
+
image_caption_model,
|
| 66 |
+
image_caption_processor,
|
| 67 |
+
prompt_enhancer_model,
|
| 68 |
+
prompt_enhancer_tokenizer,
|
| 69 |
+
prompt: Union[str, List[str]],
|
| 70 |
+
conditioning_items: Optional[List] = None,
|
| 71 |
+
max_new_tokens: int = 256,
|
| 72 |
+
) -> List[str]:
|
| 73 |
+
prompts = [prompt] if isinstance(prompt, str) else prompt
|
| 74 |
+
|
| 75 |
+
if conditioning_items is None:
|
| 76 |
+
prompts = _generate_t2v_prompt(
|
| 77 |
+
prompt_enhancer_model,
|
| 78 |
+
prompt_enhancer_tokenizer,
|
| 79 |
+
prompts,
|
| 80 |
+
max_new_tokens,
|
| 81 |
+
T2V_CINEMATIC_PROMPT,
|
| 82 |
+
)
|
| 83 |
+
else:
|
| 84 |
+
if len(conditioning_items) > 1 or conditioning_items[0].media_frame_number != 0:
|
| 85 |
+
logger.warning(
|
| 86 |
+
"prompt enhancement does only support unconditional or first frame of conditioning items, returning original prompts"
|
| 87 |
+
)
|
| 88 |
+
return prompts
|
| 89 |
+
|
| 90 |
+
first_frame_conditioning_item = conditioning_items[0]
|
| 91 |
+
first_frames = _get_first_frames_from_conditioning_item(
|
| 92 |
+
first_frame_conditioning_item
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
assert len(first_frames) == len(
|
| 96 |
+
prompts
|
| 97 |
+
), "Number of conditioning frames must match number of prompts"
|
| 98 |
+
|
| 99 |
+
prompts = _generate_i2v_prompt(
|
| 100 |
+
image_caption_model,
|
| 101 |
+
image_caption_processor,
|
| 102 |
+
prompt_enhancer_model,
|
| 103 |
+
prompt_enhancer_tokenizer,
|
| 104 |
+
prompts,
|
| 105 |
+
first_frames,
|
| 106 |
+
max_new_tokens,
|
| 107 |
+
I2V_CINEMATIC_PROMPT,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
return prompts
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _get_first_frames_from_conditioning_item(conditioning_item) -> List[Image.Image]:
|
| 114 |
+
frames_tensor = conditioning_item.media_item
|
| 115 |
+
return [
|
| 116 |
+
tensor_to_pil(frames_tensor[i, :, 0, :, :])
|
| 117 |
+
for i in range(frames_tensor.shape[0])
|
| 118 |
+
]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _generate_t2v_prompt(
|
| 122 |
+
prompt_enhancer_model,
|
| 123 |
+
prompt_enhancer_tokenizer,
|
| 124 |
+
prompts: List[str],
|
| 125 |
+
max_new_tokens: int,
|
| 126 |
+
system_prompt: str,
|
| 127 |
+
) -> List[str]:
|
| 128 |
+
messages = [
|
| 129 |
+
[
|
| 130 |
+
{"role": "system", "content": system_prompt},
|
| 131 |
+
{"role": "user", "content": f"user_prompt: {p}"},
|
| 132 |
+
]
|
| 133 |
+
for p in prompts
|
| 134 |
+
]
|
| 135 |
+
|
| 136 |
+
texts = [
|
| 137 |
+
prompt_enhancer_tokenizer.apply_chat_template(
|
| 138 |
+
m, tokenize=False, add_generation_prompt=True
|
| 139 |
+
)
|
| 140 |
+
for m in messages
|
| 141 |
+
]
|
| 142 |
+
model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
|
| 143 |
+
prompt_enhancer_model.device
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
return _generate_and_decode_prompts(
|
| 147 |
+
prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _generate_i2v_prompt(
|
| 152 |
+
image_caption_model,
|
| 153 |
+
image_caption_processor,
|
| 154 |
+
prompt_enhancer_model,
|
| 155 |
+
prompt_enhancer_tokenizer,
|
| 156 |
+
prompts: List[str],
|
| 157 |
+
first_frames: List[Image.Image],
|
| 158 |
+
max_new_tokens: int,
|
| 159 |
+
system_prompt: str,
|
| 160 |
+
) -> List[str]:
|
| 161 |
+
image_captions = _generate_image_captions(
|
| 162 |
+
image_caption_model, image_caption_processor, first_frames
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
messages = [
|
| 166 |
+
[
|
| 167 |
+
{"role": "system", "content": system_prompt},
|
| 168 |
+
{"role": "user", "content": f"user_prompt: {p}\nimage_caption: {c}"},
|
| 169 |
+
]
|
| 170 |
+
for p, c in zip(prompts, image_captions)
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
texts = [
|
| 174 |
+
prompt_enhancer_tokenizer.apply_chat_template(
|
| 175 |
+
m, tokenize=False, add_generation_prompt=True
|
| 176 |
+
)
|
| 177 |
+
for m in messages
|
| 178 |
+
]
|
| 179 |
+
model_inputs = prompt_enhancer_tokenizer(texts, return_tensors="pt").to(
|
| 180 |
+
prompt_enhancer_model.device
|
| 181 |
+
)
|
| 182 |
+
|
| 183 |
+
return _generate_and_decode_prompts(
|
| 184 |
+
prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _generate_image_captions(
|
| 189 |
+
image_caption_model,
|
| 190 |
+
image_caption_processor,
|
| 191 |
+
images: List[Image.Image],
|
| 192 |
+
system_prompt: str = "<DETAILED_CAPTION>",
|
| 193 |
+
) -> List[str]:
|
| 194 |
+
image_caption_prompts = [system_prompt] * len(images)
|
| 195 |
+
inputs = image_caption_processor(
|
| 196 |
+
image_caption_prompts, images, return_tensors="pt"
|
| 197 |
+
).to(image_caption_model.device)
|
| 198 |
+
|
| 199 |
+
with torch.inference_mode():
|
| 200 |
+
generated_ids = image_caption_model.generate(
|
| 201 |
+
input_ids=inputs["input_ids"],
|
| 202 |
+
pixel_values=inputs["pixel_values"],
|
| 203 |
+
max_new_tokens=1024,
|
| 204 |
+
do_sample=False,
|
| 205 |
+
num_beams=3,
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return image_caption_processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _generate_and_decode_prompts(
|
| 212 |
+
prompt_enhancer_model, prompt_enhancer_tokenizer, model_inputs, max_new_tokens: int
|
| 213 |
+
) -> List[str]:
|
| 214 |
+
with torch.inference_mode():
|
| 215 |
+
outputs = prompt_enhancer_model.generate(
|
| 216 |
+
**model_inputs, max_new_tokens=max_new_tokens
|
| 217 |
+
)
|
| 218 |
+
generated_ids = [
|
| 219 |
+
output_ids[len(input_ids) :]
|
| 220 |
+
for input_ids, output_ids in zip(model_inputs.input_ids, outputs)
|
| 221 |
+
]
|
| 222 |
+
decoded_prompts = prompt_enhancer_tokenizer.batch_decode(
|
| 223 |
+
generated_ids, skip_special_tokens=True
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
return decoded_prompts
|
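Note (standalone sketch, not part of the upload): tensor_to_pil above expects a single [C, H, W] frame in [-1, 1]; a quick shape and range check with a synthetic frame:

import torch
from PIL import Image

def tensor_to_pil(tensor):
    assert tensor.min() >= -1 and tensor.max() <= 1
    tensor = (tensor + 1) / 2              # [-1, 1] -> [0, 1]
    tensor = tensor.permute(1, 2, 0)       # [C, H, W] -> [H, W, C]
    return Image.fromarray((tensor.cpu().numpy() * 255).astype("uint8"))

frame = torch.rand(3, 64, 64) * 2 - 1      # synthetic frame in [-1, 1)
img = tensor_to_pil(frame)
print(img.size, img.mode)                  # (64, 64) RGB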
flash_head/ltx_video/utils/skip_layer_strategy.py
ADDED
|
@@ -0,0 +1,8 @@
from enum import Enum, auto


class SkipLayerStrategy(Enum):
    AttentionSkip = auto()
    AttentionValues = auto()
    Residual = auto()
    TransformerBlock = auto()
flash_head/ltx_video/utils/torch_utils.py
ADDED
|
@@ -0,0 +1,25 @@
|
import torch
from torch import nn


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    elif dims_to_append == 0:
        return x
    return x[(...,) + (None,) * dims_to_append]


class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive."""

    def __init__(self, *args, **kwargs) -> None:  # pylint: disable=unused-argument
        super().__init__()

    # pylint: disable=unused-argument
    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        return x
|
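Note (standalone example, not part of the upload): append_dims is the usual helper for broadcasting a per-sample scalar (for instance a sigma or timestep) over a higher-rank latent tensor:

import torch

latents = torch.randn(2, 128, 8, 32, 32)      # [B, C, F, H, W]
sigma = torch.tensor([0.7, 0.3])              # one value per batch element
sigma_b = append_dims(sigma, latents.ndim)    # -> shape [2, 1, 1, 1, 1]
scaled = latents * sigma_b                    # broadcasts over C, F, H, W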
flash_head/src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
flash_head/src/distributed/usp_device.py
ADDED
|
@@ -0,0 +1,35 @@
import math
from loguru import logger
import datetime
import torch
import torch.distributed as dist

def get_parallel_degree(world_size, num_heads):
    # ulysses_degree is faster, and must be a divisor of num_heads
    ulysses_degree = math.gcd(world_size, num_heads)
    ring_degree = world_size // ulysses_degree
    return ulysses_degree, ring_degree

def get_device(ulysses_degree, ring_degree):
    if ulysses_degree > 1 or ring_degree > 1:
        from xfuser.core.distributed import (
            init_distributed_environment,
            initialize_model_parallel,
            get_world_group,
        )

        dist.init_process_group("nccl", timeout=datetime.timedelta(hours=24*7))
        init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
        initialize_model_parallel(
            sequence_parallel_degree=dist.get_world_size(),
            ring_degree=ring_degree,
            ulysses_degree=ulysses_degree
        )

        device = torch.device(f"cuda:{get_world_group().rank}")
        torch.cuda.set_device(get_world_group().rank)

        logger.info(f'rank={get_world_group().rank} device={str(device)}')
    else:
        device = "cuda"
    return device
|
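Note (standalone check, not part of the upload): get_parallel_degree gives Ulysses parallelism as much of the world size as divides the head count, and assigns the remainder to ring parallelism:

import math

def get_parallel_degree(world_size, num_heads):
    ulysses_degree = math.gcd(world_size, num_heads)
    return ulysses_degree, world_size // ulysses_degree

print(get_parallel_degree(8, 12))   # (4, 2): 4-way Ulysses x 2-way ring
print(get_parallel_degree(8, 32))   # (8, 1): pure Ulysses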
flash_head/src/modules/flash_head_model.py
ADDED
|
@@ -0,0 +1,548 @@
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from diffusers import ModelMixin
|
| 8 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 9 |
+
import torch.cuda.amp as amp
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
from xfuser.core.distributed import (
|
| 12 |
+
get_sequence_parallel_rank,
|
| 13 |
+
get_sequence_parallel_world_size,
|
| 14 |
+
get_sp_group,
|
| 15 |
+
)
|
| 16 |
+
from xfuser.core.long_ctx_attention import xFuserLongContextAttention
|
| 17 |
+
try:
|
| 18 |
+
import flash_attn_interface
|
| 19 |
+
FLASH_ATTN_3_AVAILABLE = True
|
| 20 |
+
except ModuleNotFoundError:
|
| 21 |
+
FLASH_ATTN_3_AVAILABLE = False
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
import flash_attn
|
| 25 |
+
FLASH_ATTN_2_AVAILABLE = True
|
| 26 |
+
except ModuleNotFoundError:
|
| 27 |
+
FLASH_ATTN_2_AVAILABLE = False
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
from sageattention import sageattn
|
| 31 |
+
SAGE_ATTN_AVAILABLE = True
|
| 32 |
+
except ModuleNotFoundError:
|
| 33 |
+
SAGE_ATTN_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False):
|
| 37 |
+
if compatibility_mode:
|
| 38 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 39 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 40 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 41 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 42 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 43 |
+
elif SAGE_ATTN_AVAILABLE:
|
| 44 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 45 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 46 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 47 |
+
x = sageattn(q, k, v)
|
| 48 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 49 |
+
elif FLASH_ATTN_3_AVAILABLE:
|
| 50 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 51 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 52 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 53 |
+
x = flash_attn_interface.flash_attn_func(q, k, v)
|
| 54 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 55 |
+
elif FLASH_ATTN_2_AVAILABLE:
|
| 56 |
+
q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
|
| 57 |
+
k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
|
| 58 |
+
v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
|
| 59 |
+
x = flash_attn.flash_attn_func(q, k, v)
|
| 60 |
+
x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
|
| 61 |
+
else:
|
| 62 |
+
q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
|
| 63 |
+
k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
|
| 64 |
+
v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
|
| 65 |
+
x = F.scaled_dot_product_attention(q, k, v)
|
| 66 |
+
x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
def sinusoidal_embedding_1d(dim, position):
|
| 70 |
+
sinusoid = torch.outer(position.type(torch.float64), torch.pow(
|
| 71 |
+
10000, -torch.arange(dim//2, dtype=torch.float64, device=position.device).div(dim//2)))
|
| 72 |
+
x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
|
| 73 |
+
return x.to(position.dtype)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def precompute_freqs_cis_3d(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 77 |
+
# 3d rope precompute
|
| 78 |
+
f_freqs_cis = precompute_freqs_cis(dim - 2 * (dim // 3), end, theta)
|
| 79 |
+
h_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 80 |
+
w_freqs_cis = precompute_freqs_cis(dim // 3, end, theta)
|
| 81 |
+
return torch.cat([f_freqs_cis, h_freqs_cis, w_freqs_cis], dim=1)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def precompute_freqs_cis(dim: int, end: int = 1024, theta: float = 10000.0):
|
| 85 |
+
# 1d rope precompute
|
| 86 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)
|
| 87 |
+
[: (dim // 2)].double() / dim))
|
| 88 |
+
freqs = torch.outer(torch.arange(end, device=freqs.device), freqs)
|
| 89 |
+
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
|
| 90 |
+
return freqs_cis
|
| 91 |
+
|
| 92 |
+
def pad_freqs(original_tensor, target_len):
|
| 93 |
+
seq_len, s1, s2 = original_tensor.shape
|
| 94 |
+
pad_size = target_len - seq_len
|
| 95 |
+
padding_tensor = torch.ones(
|
| 96 |
+
pad_size,
|
| 97 |
+
s1,
|
| 98 |
+
s2,
|
| 99 |
+
dtype=original_tensor.dtype,
|
| 100 |
+
device=original_tensor.device)
|
| 101 |
+
padded_tensor = torch.cat([original_tensor, padding_tensor], dim=0)
|
| 102 |
+
return padded_tensor
|
| 103 |
+
|
| 104 |
+
def rope_apply(x, freqs, grid_sizes, use_usp=False, sp_size=1, sp_rank=0):
|
| 105 |
+
"""
|
| 106 |
+
x: [B, L, N, C].
|
| 107 |
+
grid_sizes: [B, 3].
|
| 108 |
+
freqs: [M, C // 2].
|
| 109 |
+
"""
|
| 110 |
+
s, n, c = x.size(1), x.size(2), x.size(3) // 2
|
| 111 |
+
# split freqs
|
| 112 |
+
freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1) # [[N, head_dim/2], [N, head_dim/2], [N, head_dim/2]] for T, H, W (complex/polar frequencies)
|
| 113 |
+
|
| 114 |
+
# loop over samples
|
| 115 |
+
|
| 116 |
+
(f, h, w) = grid_sizes
|
| 117 |
+
seq_len = f * h * w
|
| 118 |
+
|
| 119 |
+
# precompute multipliers
|
| 120 |
+
x_i = torch.view_as_complex(x[0, :s].to(torch.float64).reshape(
|
| 121 |
+
s, n, -1, 2)) # [L, N, C/2], viewed as complex (polar-form) values
|
| 122 |
+
freqs_i = torch.cat([
|
| 123 |
+
freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
|
| 124 |
+
freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
|
| 125 |
+
freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1)
|
| 126 |
+
],
|
| 127 |
+
dim=-1).reshape(seq_len, 1, -1) # seq_lens, 1, 3 * dim / 2 (T H W)
|
| 128 |
+
|
| 129 |
+
if use_usp:
|
| 130 |
+
# apply rotary embedding
|
| 131 |
+
freqs_i = pad_freqs(freqs_i, s * sp_size)
|
| 132 |
+
s_per_rank = s
|
| 133 |
+
freqs_i_rank = freqs_i[(sp_rank * s_per_rank):((sp_rank + 1) *
|
| 134 |
+
s_per_rank), :, :]
|
| 135 |
+
x_i = torch.view_as_real(x_i * freqs_i_rank).flatten(2)
|
| 136 |
+
x_i = torch.cat([x_i, x[0, s:]])
|
| 137 |
+
else:
|
| 138 |
+
x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
|
| 139 |
+
x_i = torch.cat([x_i, x[0, seq_len:]])
|
| 140 |
+
return x_i.unsqueeze(0).to(x.dtype)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
class RMSNorm(nn.Module):
|
| 144 |
+
def __init__(self, dim, eps=1e-5):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.eps = eps
|
| 147 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
| 148 |
+
|
| 149 |
+
def norm(self, x):
|
| 150 |
+
return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
|
| 151 |
+
|
| 152 |
+
def forward(self, x):
|
| 153 |
+
dtype = x.dtype
|
| 154 |
+
return self.norm(x.float()).to(dtype) * self.weight
|
| 155 |
+
|
| 156 |
+
class SelfAttention(nn.Module):
|
| 157 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6):
|
| 158 |
+
super().__init__()
|
| 159 |
+
self.dim = dim
|
| 160 |
+
self.num_heads = num_heads
|
| 161 |
+
self.head_dim = dim // num_heads
|
| 162 |
+
|
| 163 |
+
self.q = nn.Linear(dim, dim)
|
| 164 |
+
self.k = nn.Linear(dim, dim)
|
| 165 |
+
self.v = nn.Linear(dim, dim)
|
| 166 |
+
self.o = nn.Linear(dim, dim)
|
| 167 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 168 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 169 |
+
|
| 170 |
+
self.use_usp = dist.is_initialized()
|
| 171 |
+
self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
|
| 172 |
+
self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
|
| 173 |
+
|
| 174 |
+
def forward(self, x, freqs, grid_sizes):
|
| 175 |
+
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
|
| 176 |
+
q = self.norm_q(self.q(x)).view(b, s, n, d)
|
| 177 |
+
k = self.norm_k(self.k(x)).view(b, s, n, d)
|
| 178 |
+
v = self.v(x)
|
| 179 |
+
|
| 180 |
+
if self.use_usp:
|
| 181 |
+
from yunchang.kernels import AttnType
|
| 182 |
+
if SAGE_ATTN_AVAILABLE:
|
| 183 |
+
attn_type = AttnType.SAGE_AUTO
|
| 184 |
+
else:
|
| 185 |
+
attn_type = AttnType.FA
|
| 186 |
+
|
| 187 |
+
x = xFuserLongContextAttention(attn_type=attn_type)(
|
| 188 |
+
None,
|
| 189 |
+
query=rope_apply(q, freqs, grid_sizes, self.use_usp, self.sp_size, self.sp_rank),
|
| 190 |
+
key=rope_apply(k, freqs, grid_sizes, self.use_usp, self.sp_size, self.sp_rank),
|
| 191 |
+
value=v.view(b, s, n, d),
|
| 192 |
+
).flatten(2)
|
| 193 |
+
else:
|
| 194 |
+
x = flash_attention(
|
| 195 |
+
q=rope_apply(q, freqs, grid_sizes).flatten(2),
|
| 196 |
+
k=rope_apply(k, freqs, grid_sizes).flatten(2),
|
| 197 |
+
v=v,
|
| 198 |
+
num_heads=self.num_heads
|
| 199 |
+
)
|
| 200 |
+
return self.o(x)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class CrossAttention(nn.Module):
|
| 204 |
+
def __init__(self, dim: int, num_heads: int, eps: float = 1e-6, has_image_input: bool = False):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.dim = dim
|
| 207 |
+
self.num_heads = num_heads
|
| 208 |
+
self.head_dim = dim // num_heads
|
| 209 |
+
|
| 210 |
+
self.q = nn.Linear(dim, dim)
|
| 211 |
+
self.k = nn.Linear(dim, dim)
|
| 212 |
+
self.v = nn.Linear(dim, dim)
|
| 213 |
+
self.o = nn.Linear(dim, dim)
|
| 214 |
+
self.norm_q = RMSNorm(dim, eps=eps)
|
| 215 |
+
self.norm_k = RMSNorm(dim, eps=eps)
|
| 216 |
+
self.has_image_input = has_image_input
|
| 217 |
+
if has_image_input:
|
| 218 |
+
self.k_img = nn.Linear(dim, dim)
|
| 219 |
+
self.v_img = nn.Linear(dim, dim)
|
| 220 |
+
self.norm_k_img = RMSNorm(dim, eps=eps)
|
| 221 |
+
|
| 222 |
+
def forward(self, x: torch.Tensor, y: torch.Tensor):
|
| 223 |
+
if self.has_image_input:
|
| 224 |
+
img = y[:, :257]
|
| 225 |
+
ctx = y[:, 257:]
|
| 226 |
+
else:
|
| 227 |
+
ctx = y
|
| 228 |
+
q = self.norm_q(self.q(x))
|
| 229 |
+
k = self.norm_k(self.k(ctx))
|
| 230 |
+
v = self.v(ctx)
|
| 231 |
+
x = flash_attention(q, k, v, num_heads=self.num_heads)
|
| 232 |
+
if self.has_image_input:
|
| 233 |
+
k_img = self.norm_k_img(self.k_img(img))
|
| 234 |
+
v_img = self.v_img(img)
|
| 235 |
+
y = flash_attention(q, k_img, v_img, num_heads=self.num_heads)
|
| 236 |
+
x = x + y
|
| 237 |
+
return self.o(x)
|
| 238 |
+
|
| 239 |
+
class DiTAudioBlock(nn.Module):
|
| 240 |
+
def __init__(self, has_image_input: bool, dim: int, num_heads: int, ffn_dim: int, eps: float = 1e-6, i=0, num_layers=0):
|
| 241 |
+
super().__init__()
|
| 242 |
+
self.dim = dim
|
| 243 |
+
self.num_heads = num_heads
|
| 244 |
+
self.ffn_dim = ffn_dim
|
| 245 |
+
self.i = i
|
| 246 |
+
self.num_layers = num_layers
|
| 247 |
+
|
| 248 |
+
self.self_attn = SelfAttention(dim, num_heads, eps)
|
| 249 |
+
self.cross_attn = CrossAttention(
|
| 250 |
+
dim, num_heads, eps, has_image_input=has_image_input)
|
| 251 |
+
self.norm1 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 252 |
+
self.norm2 = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 253 |
+
self.norm3 = nn.LayerNorm(dim, eps=eps)
|
| 254 |
+
self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(
|
| 255 |
+
approximate='tanh'), nn.Linear(ffn_dim, dim))
|
| 256 |
+
self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
| 257 |
+
|
| 258 |
+
self.use_usp = dist.is_initialized()
|
| 259 |
+
self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
|
| 260 |
+
self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
|
| 261 |
+
|
| 262 |
+
def forward(self, x, context, t_mod, freqs, grid_sizes):
|
| 263 |
+
e = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device) + t_mod).chunk(6, dim=1)
|
| 264 |
+
|
| 265 |
+
y = self.self_attn(
|
| 266 |
+
self.norm1(x) * (1 + e[1]) + e[0], freqs, grid_sizes)
|
| 267 |
+
|
| 268 |
+
x = x + y * e[2]
|
| 269 |
+
|
| 270 |
+
x_1 = rearrange(self.norm3(x), 'b (f l) c -> (b f) l c', f=context.shape[1])
|
| 271 |
+
context_1 = context.squeeze(0)
|
| 272 |
+
|
| 273 |
+
if self.use_usp:
|
| 274 |
+
context_1 = context_1.unsqueeze(1).repeat(1, self.sp_size, 1, 1).flatten(0,1)
|
| 275 |
+
context_1 = torch.chunk(context_1, self.sp_size, dim=0)[self.sp_rank]
|
| 276 |
+
|
| 277 |
+
x = x + self.cross_attn(x_1, context_1).flatten(0, 1).unsqueeze(0)
|
| 278 |
+
|
| 279 |
+
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
|
| 280 |
+
x = x + y * e[5]
|
| 281 |
+
|
| 282 |
+
return x
|
| 283 |
+
|
| 284 |
+
class MLP(torch.nn.Module):
|
| 285 |
+
def __init__(self, in_dim, out_dim):
|
| 286 |
+
super().__init__()
|
| 287 |
+
self.proj = torch.nn.Sequential(
|
| 288 |
+
nn.LayerNorm(in_dim),
|
| 289 |
+
nn.Linear(in_dim, in_dim),
|
| 290 |
+
nn.GELU(),
|
| 291 |
+
nn.Linear(in_dim, out_dim),
|
| 292 |
+
nn.LayerNorm(out_dim)
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
def forward(self, x):
|
| 296 |
+
return self.proj(x)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class Head(nn.Module):
|
| 300 |
+
def __init__(self, dim: int, out_dim: int, patch_size: Tuple[int, int, int], eps: float):
|
| 301 |
+
super().__init__()
|
| 302 |
+
self.dim = dim
|
| 303 |
+
self.patch_size = patch_size
|
| 304 |
+
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=False)
|
| 305 |
+
self.head = nn.Linear(dim, out_dim * math.prod(patch_size))
|
| 306 |
+
self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
| 307 |
+
|
| 308 |
+
def forward(self, x, t_mod):
|
| 309 |
+
r"""
|
| 310 |
+
Args:
|
| 311 |
+
x(Tensor): Shape [B, L1, C]
|
| 312 |
+
t_mod(Tensor): Shape [B*21, C]
|
| 313 |
+
"""
|
| 314 |
+
B, L, D = x.shape
|
| 315 |
+
F = t_mod.shape[0] // B
|
| 316 |
+
shift, scale = (self.modulation.to(dtype=t_mod.dtype, device=t_mod.device).unsqueeze(1) + t_mod.unflatten(dim=0, sizes=(B, t_mod.shape[0]//B)).unsqueeze(2)).chunk(2, dim=2)
|
| 317 |
+
|
| 318 |
+
x = rearrange(x, 'b (f l) d -> b f l d', f=F)
|
| 319 |
+
x = (self.head(self.norm(x) * (1 + scale) + shift))
|
| 320 |
+
x = rearrange(x, 'b f l d -> b (f l) d')
|
| 321 |
+
return x
|
| 322 |
+
|
| 323 |
+
class WanModelAudioProject(ModelMixin, ConfigMixin):
|
| 324 |
+
_no_split_modules = ['DiTAudioBlock']
|
| 325 |
+
@register_to_config
|
| 326 |
+
def __init__(
|
| 327 |
+
self,
|
| 328 |
+
dim: int,
|
| 329 |
+
in_dim: int,
|
| 330 |
+
ffn_dim: int,
|
| 331 |
+
out_dim: int,
|
| 332 |
+
text_dim: int,
|
| 333 |
+
freq_dim: int,
|
| 334 |
+
eps: float,
|
| 335 |
+
vae_stride: Tuple[int, int, int],
|
| 336 |
+
patch_size: Tuple[int, int, int],
|
| 337 |
+
num_heads: int,
|
| 338 |
+
num_layers: int,
|
| 339 |
+
has_image_input: bool,
|
| 340 |
+
**kwargs,
|
| 341 |
+
):
|
| 342 |
+
super().__init__()
|
| 343 |
+
self.dim = dim
|
| 344 |
+
self.freq_dim = freq_dim
|
| 345 |
+
self.has_image_input = has_image_input
|
| 346 |
+
self.patch_size = patch_size
|
| 347 |
+
|
| 348 |
+
self.patch_embedding = nn.Conv3d(
|
| 349 |
+
in_dim, dim, kernel_size=patch_size, stride=patch_size)
|
| 350 |
+
self.text_embedding = nn.Sequential(
|
| 351 |
+
nn.Linear(text_dim, dim),
|
| 352 |
+
nn.GELU(approximate='tanh'),
|
| 353 |
+
nn.Linear(dim, dim)
|
| 354 |
+
)
|
| 355 |
+
self.time_embedding = nn.Sequential(
|
| 356 |
+
nn.Linear(freq_dim, dim),
|
| 357 |
+
nn.SiLU(),
|
| 358 |
+
nn.Linear(dim, dim)
|
| 359 |
+
)
|
| 360 |
+
self.time_projection = nn.Sequential(
|
| 361 |
+
nn.SiLU(), nn.Linear(dim, dim * 6))
|
| 362 |
+
self.blocks = nn.ModuleList([
|
| 363 |
+
DiTAudioBlock(has_image_input, dim, num_heads, ffn_dim, eps, i, num_layers)
|
| 364 |
+
for i in range(num_layers)
|
| 365 |
+
])
|
| 366 |
+
self.head = Head(dim, out_dim, patch_size, eps)
|
| 367 |
+
head_dim = dim // num_heads
|
| 368 |
+
self.freqs = precompute_freqs_cis_3d(head_dim)
|
| 369 |
+
|
| 370 |
+
self.audio_emb = MLP(768, dim)
|
| 371 |
+
|
| 372 |
+
if has_image_input:
|
| 373 |
+
self.img_emb = MLP(1280, dim)
|
| 374 |
+
|
| 375 |
+
# init audio adapter
|
| 376 |
+
audio_window = 5
|
| 377 |
+
vae_scale = vae_stride[0]
|
| 378 |
+
intermediate_dim = 512
|
| 379 |
+
output_dim = 1536
|
| 380 |
+
context_tokens = 32
|
| 381 |
+
norm_output_audio = True
|
| 382 |
+
self.audio_window = audio_window
|
| 383 |
+
self.vae_scale = vae_scale
|
| 384 |
+
self.audio_proj = AudioProjModel(
|
| 385 |
+
seq_len=audio_window,
|
| 386 |
+
seq_len_vf=audio_window+vae_scale-1,
|
| 387 |
+
intermediate_dim=intermediate_dim,
|
| 388 |
+
output_dim=output_dim,
|
| 389 |
+
context_tokens=context_tokens,
|
| 390 |
+
norm_output_audio=norm_output_audio,
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
self.use_usp = dist.is_initialized()
|
| 394 |
+
self.sp_size = get_sequence_parallel_world_size() if self.use_usp else 1
|
| 395 |
+
self.sp_rank = get_sequence_parallel_rank() if self.use_usp else 0
|
| 396 |
+
|
| 397 |
+
def patchify(self, x: torch.Tensor):
|
| 398 |
+
x = self.patch_embedding(x)
|
| 399 |
+
grid_size = x.shape[2:]
|
| 400 |
+
x = rearrange(x, 'b c f h w -> b (f h w) c').contiguous()
|
| 401 |
+
return x, grid_size # x, grid_size: (f, h, w)
|
| 402 |
+
|
| 403 |
+
def unpatchify(self, x: torch.Tensor, grid_size: torch.Tensor):
|
| 404 |
+
return rearrange(
|
| 405 |
+
x, 'b (f h w) (x y z c) -> b c (f x) (h y) (w z)',
|
| 406 |
+
f=grid_size[0], h=grid_size[1], w=grid_size[2],
|
| 407 |
+
x=self.patch_size[0], y=self.patch_size[1], z=self.patch_size[2]
|
| 408 |
+
)
|
| 409 |
+
|
| 410 |
+
def forward(self,
|
| 411 |
+
x: torch.Tensor, #(1, 16, 9, 64, 64))
|
| 412 |
+
timestep: torch.Tensor, #(9,)
|
| 413 |
+
context: torch.Tensor, #(5, 33, 12, 768)
|
| 414 |
+
y: Optional[torch.Tensor] = None, #(1, 16, 9, 64, 64)
|
| 415 |
+
use_gradient_checkpointing: bool = False,
|
| 416 |
+
use_gradient_checkpointing_offload: bool = False,
|
| 417 |
+
**kwargs,
|
| 418 |
+
):
|
| 419 |
+
|
| 420 |
+
if self.freqs.device != x.device:
|
| 421 |
+
self.freqs = self.freqs.to(x.device)
|
| 422 |
+
|
| 423 |
+
x = torch.cat([x, y], dim=1) # (1, 32, 9, 64, 64)
|
| 424 |
+
x, grid_sizes = self.patchify(x)
|
| 425 |
+
t = self.time_embedding(
|
| 426 |
+
sinusoidal_embedding_1d(self.freq_dim, timestep.to(dtype=x.dtype)))
|
| 427 |
+
t_mod = self.time_projection(t).unflatten(1, (6, self.dim)) # (bsz, 6, 1536)
|
| 428 |
+
|
| 429 |
+
# ==================== Audio conditioning ====================
# Input: context (bsz, 81, 5, 12, 768)
# - 81 frames = 1 (first frame) + 80 (subsequent frames; every 4 frames map to one VAE-compressed latent frame)
# - 5 is the audio window size (audio_window)
# - 12 is the number of audio feature blocks
# - 768 is the audio feature dimension
|
| 435 |
+
|
| 436 |
+
audio_cond = context.to(device=x.device, dtype=x.dtype)
|
| 437 |
+
|
| 438 |
+
# 1. First frame: directly use the full 5-frame audio window
|
| 439 |
+
first_frame_audio = audio_cond[:, :1, ...] # (bsz, 1, 5, 12, 768)
|
| 440 |
+
|
| 441 |
+
# 2. Subsequent frames: select a different audio window depending on the frame position
# Rearrange the 32 frames into (8 VAE latents, 4 frames each)
|
| 443 |
+
latter_frames_audio = rearrange(
|
| 444 |
+
audio_cond[:, 1:, ...],
|
| 445 |
+
"b (n_latent n_frame) w s c -> b n_latent n_frame w s c",
|
| 446 |
+
n_frame=self.vae_scale # vae_scale=4
|
| 447 |
+
) # (bsz, 8, 4, 5, 12, 768)
|
| 448 |
+
|
| 449 |
+
mid_idx = self.audio_window // 2 # center index of the window: 5 // 2 = 2
|
| 450 |
+
|
| 451 |
+
# Select the appropriate audio window for each of the 4 frames in a latent group:
# - 1st frame (frame index 0): no past context, take the leading 3-frame window [:mid_idx+1] = [:3]
# - middle frames (frame indices 1-2): take the single center frame [mid_idx:mid_idx+1] = [2:3]
# - 4th frame (frame index 3): no future context, take the trailing 3-frame window [mid_idx:] = [2:]
|
| 455 |
+
|
| 456 |
+
first_of_group = latter_frames_audio[:, :, :1, :mid_idx+1, ...] # (bsz, 8, 1, 3, 12, 768)
|
| 457 |
+
middle_of_group = latter_frames_audio[:, :, 1:-1, mid_idx:mid_idx+1, ...] # (bsz, 8, 2, 1, 12, 768)
|
| 458 |
+
last_of_group = latter_frames_audio[:, :, -1:, mid_idx:, ...] # (bsz, 8, 1, 3, 12, 768)
|
| 459 |
+
|
| 460 |
+
# Concatenate and flatten the window dimension: (n_frame, window) -> (n_frame * window)
|
| 461 |
+
latter_frames_audio_processed = torch.cat([
|
| 462 |
+
rearrange(first_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
|
| 463 |
+
rearrange(middle_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
|
| 464 |
+
rearrange(last_of_group, "b n_latent n_f w s c -> b n_latent (n_f w) s c"),
|
| 465 |
+
], dim=2) # (bsz, 8, 1*3 + 2*1 + 1*3, 12, 768) = (bsz, 8, 8, 12, 768)
|
| 466 |
+
|
| 467 |
+
# 3. Project into the feature space required by the DiT via AudioProjModel
|
| 468 |
+
context = self.audio_proj(
|
| 469 |
+
first_frame_audio,
|
| 470 |
+
latter_frames_audio_processed
|
| 471 |
+
).to(x.dtype) # (bsz, 9, 32, 1536)
|
| 472 |
+
|
| 473 |
+
if self.use_usp:
|
| 474 |
+
x = torch.chunk(x, self.sp_size, dim=1)[self.sp_rank]
|
| 475 |
+
|
| 476 |
+
for block in self.blocks:
|
| 477 |
+
x = block(x, context, t_mod, self.freqs, grid_sizes)
|
| 478 |
+
x = self.head(x, t) # (bsz, 9*32*32, 64)
|
| 479 |
+
if self.use_usp:
|
| 480 |
+
x = get_sp_group().all_gather(x, dim=1)
|
| 481 |
+
x = self.unpatchify(x, grid_sizes) # (bsz, 16, 21, 64, 64)
|
| 482 |
+
return x
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class AudioProjModel(ModelMixin, ConfigMixin):
|
| 486 |
+
def __init__(
|
| 487 |
+
self,
|
| 488 |
+
seq_len=5,
|
| 489 |
+
seq_len_vf=12,
|
| 490 |
+
blocks=12,
|
| 491 |
+
channels=768,
|
| 492 |
+
intermediate_dim=512,
|
| 493 |
+
output_dim=768,
|
| 494 |
+
context_tokens=32,
|
| 495 |
+
norm_output_audio=False,
|
| 496 |
+
):
|
| 497 |
+
super().__init__()
|
| 498 |
+
|
| 499 |
+
self.seq_len = seq_len
|
| 500 |
+
self.blocks = blocks
|
| 501 |
+
self.channels = channels
|
| 502 |
+
self.input_dim = seq_len * blocks * channels
|
| 503 |
+
self.input_dim_vf = seq_len_vf * blocks * channels
|
| 504 |
+
self.intermediate_dim = intermediate_dim
|
| 505 |
+
self.context_tokens = context_tokens
|
| 506 |
+
self.output_dim = output_dim
|
| 507 |
+
|
| 508 |
+
# define multiple linear layers
|
| 509 |
+
self.proj1 = nn.Linear(self.input_dim, intermediate_dim)
|
| 510 |
+
self.proj1_vf = nn.Linear(self.input_dim_vf, intermediate_dim)
|
| 511 |
+
self.proj2 = nn.Linear(intermediate_dim, intermediate_dim)
|
| 512 |
+
self.proj3 = nn.Linear(intermediate_dim, context_tokens * output_dim)
|
| 513 |
+
self.norm = nn.LayerNorm(output_dim) if norm_output_audio else nn.Identity()
|
| 514 |
+
|
| 515 |
+
def forward(self, audio_embeds, audio_embeds_vf):
|
| 516 |
+
video_length = audio_embeds.shape[1] + audio_embeds_vf.shape[1]
|
| 517 |
+
B, _, _, S, C = audio_embeds.shape
|
| 518 |
+
|
| 519 |
+
# process audio of first frame
|
| 520 |
+
audio_embeds = rearrange(audio_embeds, "bz f w b c -> (bz f) w b c")
|
| 521 |
+
batch_size, window_size, blocks, channels = audio_embeds.shape
|
| 522 |
+
audio_embeds = audio_embeds.view(batch_size, window_size * blocks * channels)
|
| 523 |
+
|
| 524 |
+
# process audio of latter frame
|
| 525 |
+
audio_embeds_vf = rearrange(audio_embeds_vf, "bz f w b c -> (bz f) w b c")
|
| 526 |
+
batch_size_vf, window_size_vf, blocks_vf, channels_vf = audio_embeds_vf.shape
|
| 527 |
+
audio_embeds_vf = audio_embeds_vf.view(batch_size_vf, window_size_vf * blocks_vf * channels_vf)
|
| 528 |
+
|
| 529 |
+
# first projection
|
| 530 |
+
audio_embeds = torch.relu(self.proj1(audio_embeds))
|
| 531 |
+
audio_embeds_vf = torch.relu(self.proj1_vf(audio_embeds_vf))
|
| 532 |
+
audio_embeds = rearrange(audio_embeds, "(bz f) c -> bz f c", bz=B)
|
| 533 |
+
audio_embeds_vf = rearrange(audio_embeds_vf, "(bz f) c -> bz f c", bz=B)
|
| 534 |
+
audio_embeds_c = torch.concat([audio_embeds, audio_embeds_vf], dim=1)
|
| 535 |
+
batch_size_c, N_t, C_a = audio_embeds_c.shape
|
| 536 |
+
audio_embeds_c = audio_embeds_c.view(batch_size_c*N_t, C_a)
|
| 537 |
+
|
| 538 |
+
# second projection
|
| 539 |
+
audio_embeds_c = torch.relu(self.proj2(audio_embeds_c))
|
| 540 |
+
|
| 541 |
+
context_tokens = self.proj3(audio_embeds_c).reshape(batch_size_c*N_t, self.context_tokens, self.output_dim)
|
| 542 |
+
|
| 543 |
+
# normalization and reshape
|
| 544 |
+
with amp.autocast(dtype=torch.float32):
|
| 545 |
+
context_tokens = self.norm(context_tokens)
|
| 546 |
+
context_tokens = rearrange(context_tokens, "(bz f) m c -> bz f m c", f=video_length)
|
| 547 |
+
|
| 548 |
+
return context_tokens
|
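Note (shape check only, not part of the upload): the audio-window selection commented in WanModelAudioProject.forward (leading window for the first frame of each latent group, center frame for the middle frames, trailing window for the last frame) can be verified on dummy tensors, assuming audio_window=5 and vae_scale=4:

import torch
from einops import rearrange

bsz, n_latent, vae_scale, audio_window, blocks, channels = 1, 8, 4, 5, 12, 768
mid_idx = audio_window // 2                                      # 2
latter = torch.randn(bsz, n_latent, vae_scale, audio_window, blocks, channels)

first = latter[:, :, :1, :mid_idx + 1, ...]                      # (1, 8, 1, 3, 12, 768)
middle = latter[:, :, 1:-1, mid_idx:mid_idx + 1, ...]            # (1, 8, 2, 1, 12, 768)
last = latter[:, :, -1:, mid_idx:, ...]                          # (1, 8, 1, 3, 12, 768)

merged = torch.cat([
    rearrange(first, "b n f w s c -> b n (f w) s c"),
    rearrange(middle, "b n f w s c -> b n (f w) s c"),
    rearrange(last, "b n f w s c -> b n (f w) s c"),
], dim=2)
print(merged.shape)   # torch.Size([1, 8, 8, 12, 768]) = 3 + 2 + 3 frames per latent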
flash_head/src/pipeline/flash_head_pipeline.py
ADDED
|
@@ -0,0 +1,316 @@
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
import os
|
| 3 |
+
from PIL import Image
|
| 4 |
+
from loguru import logger
|
| 5 |
+
import time
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.distributed as dist
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
from transformers import Wav2Vec2FeatureExtractor
|
| 12 |
+
|
| 13 |
+
from flash_head.src.modules.flash_head_model import WanModelAudioProject
|
| 14 |
+
from flash_head.audio_analysis.wav2vec2 import Wav2Vec2Model
|
| 15 |
+
from flash_head.utils.utils import match_and_blend_colors_torch, resize_and_centercrop
|
| 16 |
+
from flash_head.utils.facecrop import process_image
|
| 17 |
+
|
| 18 |
+
# compile models to speedup inference
|
| 19 |
+
COMPILE_MODEL = False
|
| 20 |
+
COMPILE_VAE = False
|
| 21 |
+
# use parallel vae to speedup decode/encode, only support WanVAE
|
| 22 |
+
USE_PARALLEL_VAE = True
|
| 23 |
+
|
| 24 |
+
def get_cond_image_dict(cond_image_path_or_dir, use_face_crop):
|
| 25 |
+
def get_image(cond_image_path, use_face_crop):
|
| 26 |
+
if use_face_crop:
|
| 27 |
+
try:
|
| 28 |
+
image = process_image(cond_image_path)
|
| 29 |
+
return image
|
| 30 |
+
except Exception as e:
|
| 31 |
+
logger.error(f"Error processing {cond_image_path}: {e}")
|
| 32 |
+
return Image.open(cond_image_path).convert("RGB")
|
| 33 |
+
|
| 34 |
+
if os.path.isdir(cond_image_path_or_dir):
|
| 35 |
+
import glob
|
| 36 |
+
cond_image_list = glob.glob(os.path.join(cond_image_path_or_dir, "*.png"))
|
| 37 |
+
cond_image_list.sort()
|
| 38 |
+
cond_image_dict = {cond_image.split("/")[-1].split(".")[0]: get_image(cond_image, use_face_crop) for cond_image in cond_image_list}
|
| 39 |
+
else:
|
| 40 |
+
cond_image_dict = {cond_image_path_or_dir.split("/")[-1].split(".")[0]: get_image(cond_image_path_or_dir, use_face_crop)}
|
| 41 |
+
return cond_image_dict
|
| 42 |
+
|
| 43 |
+
def timestep_transform(
|
| 44 |
+
t,
|
| 45 |
+
shift=5.0,
|
| 46 |
+
num_timesteps=1000,
|
| 47 |
+
):
|
| 48 |
+
t = t / num_timesteps
|
| 49 |
+
# shift the timestep based on ratio
|
| 50 |
+
new_t = shift * t / (1 + (shift - 1) * t)
|
| 51 |
+
new_t = new_t * num_timesteps
|
| 52 |
+
return new_t
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class FlashHeadPipeline:
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
checkpoint_dir,
|
| 59 |
+
model_type,
|
| 60 |
+
wav2vec_dir,
|
| 61 |
+
device="cuda",
|
| 62 |
+
param_dtype=torch.bfloat16,
|
| 63 |
+
use_usp=False,
|
| 64 |
+
num_timesteps=1000,
|
| 65 |
+
use_timestep_transform=True,
|
| 66 |
+
):
|
| 67 |
+
r"""
|
| 68 |
+
Initializes the image-to-video generation model components.
|
| 69 |
+
Args:
|
| 70 |
+
checkpoint_dir (`str`):
|
| 71 |
+
Path to directory containing model checkpoints
|
| 72 |
+
wav2vec_dir (`str`):
|
| 73 |
+
Path to directory containing wav2vec checkpoints
|
| 74 |
+
use_usp (`bool`, *optional*, defaults to False):
|
| 75 |
+
Enable the USP (unified sequence parallelism) distribution strategy.
|
| 76 |
+
"""
|
| 77 |
+
self.param_dtype = param_dtype
|
| 78 |
+
self.device = device
|
| 79 |
+
self.rank = dist.get_rank() if dist.is_initialized() else 0
|
| 80 |
+
self.use_usp = use_usp and dist.is_initialized()
|
| 81 |
+
self.model_type = model_type
|
| 82 |
+
self.use_ltx = model_type == "lite"
|
| 83 |
+
|
| 84 |
+
if self.use_ltx:
|
| 85 |
+
model_dir = os.path.join(checkpoint_dir, "Model_Lite")
|
| 86 |
+
vae_dir = os.path.join(checkpoint_dir, "VAE_LTX")
|
| 87 |
+
|
| 88 |
+
from flash_head.ltx_video.ltx_vae import LtxVAE
|
| 89 |
+
self.vae = LtxVAE(
|
| 90 |
+
pretrained_model_type_or_path=vae_dir,
|
| 91 |
+
dtype=self.param_dtype,
|
| 92 |
+
device=self.device,
|
| 93 |
+
)
|
| 94 |
+
else:
|
| 95 |
+
vae_path = os.path.join(checkpoint_dir, "VAE_Wan/Wan2.1_VAE.pth")
|
| 96 |
+
|
| 97 |
+
from flash_head.wan.modules import WanVAE
|
| 98 |
+
self.vae = WanVAE(
|
| 99 |
+
vae_path=vae_path,
|
| 100 |
+
dtype=self.param_dtype,
|
| 101 |
+
device=self.device,
|
| 102 |
+
parallel=(USE_PARALLEL_VAE and self.use_usp),
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
if self.model_type == "pretrained":
|
| 106 |
+
self.audio_guide_scale = 3.0
|
| 107 |
+
model_dir = os.path.join(checkpoint_dir, "teacher")
|
| 108 |
+
elif self.model_type == "pro":
|
| 109 |
+
model_dir = os.path.join(checkpoint_dir, "Model_Pro")
|
| 110 |
+
|
| 111 |
+
self.model = WanModelAudioProject.from_pretrained(model_dir)
|
| 112 |
+
self.model.eval().requires_grad_(False)
|
| 113 |
+
self.model.to(device=self.device, dtype=self.param_dtype)
|
| 114 |
+
|
| 115 |
+
self.config = self.model.config
|
| 116 |
+
|
| 117 |
+
if use_usp:
|
| 118 |
+
from xfuser.core.distributed import get_sequence_parallel_world_size
|
| 119 |
+
self.sp_size = get_sequence_parallel_world_size()
|
| 120 |
+
else:
|
| 121 |
+
self.sp_size = 1
|
| 122 |
+
|
| 123 |
+
if dist.is_initialized():
|
| 124 |
+
dist.barrier()
|
| 125 |
+
|
| 126 |
+
self.num_timesteps = num_timesteps
|
| 127 |
+
self.use_timestep_transform = use_timestep_transform
|
| 128 |
+
|
| 129 |
+
if COMPILE_MODEL:
|
| 130 |
+
self.model = torch.compile(self.model)
|
| 131 |
+
if COMPILE_VAE:
|
| 132 |
+
if self.use_ltx:
|
| 133 |
+
self.vae.model.encode = torch.compile(self.vae.model.encode)
|
| 134 |
+
self.vae.model.decode = torch.compile(self.vae.model.decode)
|
| 135 |
+
else:
|
| 136 |
+
self.vae.encode = torch.compile(self.vae.encode)
|
| 137 |
+
self.vae.decode = torch.compile(self.vae.decode)
|
| 138 |
+
|
| 139 |
+
self.audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec_dir, local_files_only=True).to(self.device)
|
| 140 |
+
self.audio_encoder.feature_extractor._freeze_parameters()
|
| 141 |
+
self.wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec_dir, local_files_only=True)
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
|
| 144 |
+
def prepare_params(self,
|
| 145 |
+
cond_image_path_or_dir,
|
| 146 |
+
target_size,
|
| 147 |
+
frame_num,
|
| 148 |
+
motion_frames_num,
|
| 149 |
+
sampling_steps,
|
| 150 |
+
seed=None,
|
| 151 |
+
shift=5.0,
|
| 152 |
+
color_correction_strength=0.0,
|
| 153 |
+
use_face_crop=False,
|
| 154 |
+
):
|
| 155 |
+
self.cond_image_dict = get_cond_image_dict(cond_image_path_or_dir, use_face_crop)
|
| 156 |
+
|
| 157 |
+
self.frame_num = frame_num
|
| 158 |
+
self.motion_frames_num = motion_frames_num
|
| 159 |
+
self.color_correction_strength = color_correction_strength
|
| 160 |
+
|
| 161 |
+
self.target_h, self.target_w = target_size
|
| 162 |
+
self.lat_h, self.lat_w = self.target_h // self.config.vae_stride[1], self.target_w // self.config.vae_stride[2]
|
| 163 |
+
|
| 164 |
+
self.generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 165 |
+
|
| 166 |
+
# prepare timesteps
|
| 167 |
+
if sampling_steps == 2:
|
| 168 |
+
timesteps = [1000, 500]
|
| 169 |
+
elif sampling_steps == 4:
|
| 170 |
+
timesteps = [1000, 750, 500, 250]
|
| 171 |
+
else:
|
| 172 |
+
timesteps = list(np.linspace(self.num_timesteps, 1, sampling_steps, dtype=np.float32))
|
| 173 |
+
|
| 174 |
+
timesteps.append(0.)
|
| 175 |
+
timesteps = [torch.tensor([t], device=self.device) for t in timesteps]
|
| 176 |
+
if self.use_timestep_transform:
|
| 177 |
+
timesteps = [timestep_transform(t, shift=shift, num_timesteps=self.num_timesteps) for t in timesteps]
|
| 178 |
+
self.timesteps = timesteps
|
| 179 |
+
|
| 180 |
+
self.cond_image_tensor_dict = {}
|
| 181 |
+
self.ref_img_latent_dict = {}
|
| 182 |
+
for i, (person_name, cond_image_pil) in enumerate(self.cond_image_dict.items()):
|
| 183 |
+
cond_image_tensor = resize_and_centercrop(cond_image_pil, (self.target_h, self.target_w)).to(self.device, dtype=self.param_dtype) # 1 C 1 H W
|
| 184 |
+
cond_image_tensor = (cond_image_tensor / 255 - 0.5) * 2
|
| 185 |
+
|
| 186 |
+
self.cond_image_tensor_dict[person_name] = cond_image_tensor
|
| 187 |
+
|
| 188 |
+
video_frames = cond_image_tensor.repeat(1, 1, self.frame_num, 1, 1)
|
| 189 |
+
self.ref_img_latent_dict[person_name] = self.vae.encode(video_frames) # (16, 9, 64, 64) / (128, 5, 16, 16)
|
| 190 |
+
if i == 0:
|
| 191 |
+
self.reset_person_name(person_name)
|
| 192 |
+
|
| 193 |
+
return
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
|
| 196 |
+
def reset_person_name(self, person_name=None):
|
| 197 |
+
if person_name is None or person_name not in self.cond_image_dict:
|
| 198 |
+
pass
|
| 199 |
+
else:
|
| 200 |
+
self.person_name = person_name
|
| 201 |
+
self.original_color_reference = self.cond_image_tensor_dict[self.person_name]
|
| 202 |
+
self.ref_img_latent = self.ref_img_latent_dict[self.person_name]
|
| 203 |
+
self.latent_motion_frames = self.ref_img_latent[:, :1].clone()
|
| 204 |
+
|
| 205 |
+
@torch.no_grad()
|
| 206 |
+
def preprocess_audio(self, speech_array, sr=16000, fps=25):
|
| 207 |
+
video_length = len(speech_array) * fps / sr
|
| 208 |
+
|
| 209 |
+
# wav2vec_feature_extractor
|
| 210 |
+
audio_feature = np.squeeze(
|
| 211 |
+
self.wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
|
| 212 |
+
)
|
| 213 |
+
audio_feature = torch.from_numpy(audio_feature).float().to(device=self.device)
|
| 214 |
+
audio_feature = audio_feature.unsqueeze(0)
|
| 215 |
+
|
| 216 |
+
# audio encoder
|
| 217 |
+
with torch.no_grad():
|
| 218 |
+
embeddings = self.audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)
|
| 219 |
+
|
| 220 |
+
if len(embeddings) == 0:
|
| 221 |
+
logger.error("Fail to extract audio embedding")
|
| 222 |
+
return None
|
| 223 |
+
|
| 224 |
+
audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
|
| 225 |
+
audio_emb = rearrange(audio_emb, "b s d -> s b d")
|
| 226 |
+
return audio_emb
|
| 227 |
+
|
| 228 |
+
@torch.no_grad()
|
| 229 |
+
def generate(self, audio_embedding):
|
| 230 |
+
# evaluation mode
|
| 231 |
+
with torch.no_grad():
|
| 232 |
+
|
| 233 |
+
# sample videos
|
| 234 |
+
noise = torch.randn(
|
| 235 |
+
self.config.out_dim,
|
| 236 |
+
(self.frame_num - 1) // self.config.vae_stride[0] + 1,
|
| 237 |
+
self.lat_h,
|
| 238 |
+
self.lat_w,
|
| 239 |
+
dtype=self.param_dtype,
|
| 240 |
+
device=self.device,
|
| 241 |
+
generator=self.generator)
|
| 242 |
+
|
| 243 |
+
for i in range(len(self.timesteps)-1):
|
| 244 |
+
torch.cuda.synchronize()
|
| 245 |
+
start_time = time.time()
|
| 246 |
+
|
| 247 |
+
noise[:, :self.latent_motion_frames.shape[1]] = self.latent_motion_frames
|
| 248 |
+
|
| 249 |
+
flow_pred = self.model(
|
| 250 |
+
x=noise.unsqueeze(0),
|
| 251 |
+
timestep=self.timesteps[i],
|
| 252 |
+
context=audio_embedding,
|
| 253 |
+
y=self.ref_img_latent.unsqueeze(0),
|
| 254 |
+
)[0]
|
| 255 |
+
|
| 256 |
+
if self.model_type == "pretrained":
|
| 257 |
+
flow_pred_drop_audio = self.model(
|
| 258 |
+
x=noise.unsqueeze(0),
|
| 259 |
+
timestep=self.timesteps[i],
|
| 260 |
+
context=torch.zeros_like(audio_embedding),
|
| 261 |
+
y=self.ref_img_latent.unsqueeze(0),
|
| 262 |
+
)[0]
|
| 263 |
+
flow_pred = flow_pred_drop_audio + self.audio_guide_scale * (flow_pred - flow_pred_drop_audio)
|
| 264 |
+
|
| 265 |
+
# update latent
|
| 266 |
+
dt = self.timesteps[i] - self.timesteps[i + 1]
|
| 267 |
+
dt = (dt / self.num_timesteps).to(self.param_dtype)
|
| 268 |
+
noise = noise - flow_pred * dt[:, None, None, None]
|
| 269 |
+
|
| 270 |
+
else:
|
| 271 |
+
# update latent
|
| 272 |
+
t_i = (self.timesteps[i][:, None, None, None] / self.num_timesteps).to(self.param_dtype)
|
| 273 |
+
t_i_1 = (self.timesteps[i+1][:, None, None, None] / self.num_timesteps).to(self.param_dtype)
|
| 274 |
+
x_0 = noise - flow_pred * t_i
|
| 275 |
+
|
| 276 |
+
noise = (1 - t_i_1) * x_0 + t_i_1 * torch.randn(x_0.size(), dtype=x_0.dtype, device=self.device, generator=self.generator)
|
| 277 |
+
|
| 278 |
+
torch.cuda.synchronize()
|
| 279 |
+
end_time = time.time()
|
| 280 |
+
if self.rank == 0:
|
| 281 |
+
print(f'[generate] model denoise per step: {end_time - start_time}s')
|
| 282 |
+
|
| 283 |
+
noise[:, :self.latent_motion_frames.shape[1]] = self.latent_motion_frames
|
| 284 |
+
|
| 285 |
+
torch.cuda.synchronize()
|
| 286 |
+
start_decode_time = time.time()
|
| 287 |
+
|
| 288 |
+
videos = self.vae.decode(noise)
|
| 289 |
+
|
| 290 |
+
torch.cuda.synchronize()
|
| 291 |
+
end_decode_time = time.time()
|
| 292 |
+
if self.rank == 0:
|
| 293 |
+
print(f'[generate] decode video frames: {end_decode_time - start_decode_time}s')
|
| 294 |
+
|
| 295 |
+
torch.cuda.synchronize()
|
| 296 |
+
start_color_correction_time = time.time()
|
| 297 |
+
if self.color_correction_strength > 0.0:
|
| 298 |
+
videos = match_and_blend_colors_torch(videos, self.original_color_reference, self.color_correction_strength)
|
| 299 |
+
|
| 300 |
+
cond_frame = videos[:, :, -self.motion_frames_num:].to(self.device)
|
| 301 |
+
torch.cuda.synchronize()
|
| 302 |
+
end_color_correction_time = time.time()
|
| 303 |
+
if self.rank == 0:
|
| 304 |
+
print(f'[generate] color correction: {end_color_correction_time - start_color_correction_time}s')
|
| 305 |
+
|
| 306 |
+
torch.cuda.synchronize()
|
| 307 |
+
start_encode_time = time.time()
|
| 308 |
+
self.latent_motion_frames = self.vae.encode(cond_frame)
|
| 309 |
+
torch.cuda.synchronize()
|
| 310 |
+
end_encode_time = time.time()
|
| 311 |
+
if self.rank == 0:
|
| 312 |
+
print(f'[generate] encode motion frames: {end_encode_time - start_encode_time}s')
|
| 313 |
+
|
| 314 |
+
gen_video_samples = videos #[:, :, self.motion_frames_num:]
|
| 315 |
+
|
| 316 |
+
return gen_video_samples[0].to(torch.float32)
|
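Reading the denoising loop above: the `pretrained` model applies classifier-free guidance on the audio context followed by an Euler update `x <- x - v * dt`, while the `pro`/`lite` models predict `x0 = x - v * t` directly and re-noise it to the next timestep. Below is a minimal driver sketch for this pipeline; the paths, parameter values, and the assumption that a single chunk's audio embedding can be passed straight to `generate` are illustrative only (see the repository's own `generate_video.py` for the full driver):

```python
# Hedged usage sketch; all paths and parameter values below are assumptions.
import librosa  # any loader that can resample to 16 kHz works here

pipe = FlashHeadPipeline(
    checkpoint_dir="checkpoints",          # assumed to contain Model_Pro/, VAE_Wan/, ...
    model_type="pro",
    wav2vec_dir="checkpoints/wav2vec2",    # assumed local wav2vec2 checkpoint directory
)
pipe.prepare_params(
    cond_image_path_or_dir="assets/face.png",   # hypothetical reference image
    target_size=(512, 512),
    frame_num=81,
    motion_frames_num=5,
    sampling_steps=4,
    seed=42,
    use_face_crop=True,
)
speech, _ = librosa.load("assets/speech.wav", sr=16000)   # wav2vec2 expects 16 kHz audio
audio_emb = pipe.preprocess_audio(speech, sr=16000, fps=25)
# Per-chunk slicing/batching of audio_emb is model-specific; for a single chunk
# the embedding is passed as the model context:
video = pipe.generate(audio_emb)           # (C, T, H, W) float32 frames in [-1, 1]
```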
flash_head/utils/cpu_face_handler.py
ADDED
|
@@ -0,0 +1,55 @@
|
| 1 |
+
import mediapipe as mp
|
| 2 |
+
import numpy as np
|
| 3 |
+
from typing import Tuple, List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class CPUFaceHandler:
|
| 7 |
+
"""Handler for CPU-based face detection using MediaPipe.
|
| 8 |
+
(roughly 2 ms per frame on CPU)
|
| 9 |
+
This handler provides a simple interface for face detection using MediaPipe's
|
| 10 |
+
face detection model. It's optimized for CPU usage and provides basic face
|
| 11 |
+
detection functionality.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
def __init__(self, model_selection: int = 1, min_detection_confidence: float = 0.0):
|
| 15 |
+
"""Initialize the face detection handler."""
|
| 16 |
+
self.detector = mp.solutions.face_detection.FaceDetection(
|
| 17 |
+
model_selection=model_selection,
|
| 18 |
+
min_detection_confidence=min_detection_confidence,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def detect(self, image: np.ndarray) -> Tuple[List[List[float]], List[float]]:
|
| 22 |
+
"""Detect faces in the given image.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
image (np.ndarray): RGB image array.
|
| 26 |
+
|
| 27 |
+
Returns:
|
| 28 |
+
Tuple[List[List[float]], List[float]]: A tuple containing:
|
| 29 |
+
- Relative bounding boxes [x1, y1, x2, y2] in [0, 1] image coordinates, one per detected face
|
| 30 |
+
- Detection confidence scores, one per detected face
|
| 31 |
+
(both lists are empty when no face is detected)
|
| 32 |
+
"""
|
| 33 |
+
bboxs, scores = [], []
|
| 34 |
+
results = self.detector.process(image)
|
| 35 |
+
detection_result = results.detections
|
| 36 |
+
if detection_result is None:
|
| 37 |
+
return bboxs, scores
|
| 38 |
+
for detection in detection_result:
|
| 39 |
+
bboxC = detection.location_data.relative_bounding_box
|
| 40 |
+
x, y, w, h = bboxC.xmin, bboxC.ymin, bboxC.width, bboxC.height
|
| 41 |
+
x1, y1, x2, y2 = x, y, x + w, y + h
|
| 42 |
+
bboxs.append([x1, y1, x2, y2])
|
| 43 |
+
scores.append(detection.score[0])
|
| 44 |
+
return bboxs, scores
|
| 45 |
+
|
| 46 |
+
def __call__(self, image: np.ndarray) -> Tuple[List[List[float]], List[float]]:
|
| 47 |
+
"""Make the handler callable.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
image (np.ndarray): RGB image array.
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Tuple[List[List[float]], List[float]]: Same as the detect() method.
|
| 54 |
+
"""
|
| 55 |
+
return self.detect(image)
|
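A small usage sketch for the handler above. MediaPipe reports boxes in relative [0, 1] coordinates, so callers scale them by the image size, which is what `facecrop.process_image` does (the image path is hypothetical):

```python
import numpy as np
from PIL import Image

detector = CPUFaceHandler()
img = np.array(Image.open("assets/face.png").convert("RGB"))  # hypothetical path
boxes, scores = detector(img)                  # relative [x1, y1, x2, y2] per face
if boxes:
    h, w = img.shape[:2]
    x1, y1, x2, y2 = boxes[0]
    pixel_box = [x1 * w, y1 * h, x2 * w, y2 * h]   # absolute pixel coordinates
```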
flash_head/utils/facecrop.py
ADDED
|
@@ -0,0 +1,110 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Face-crop preprocessing script.
|
| 4 |
+
Detects the face in a single image, crops it, and resizes it to the target size.
|
| 5 |
+
"""
|
| 6 |
+
import os
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from flash_head.utils.cpu_face_handler import CPUFaceHandler
|
| 11 |
+
|
| 12 |
+
def get_scaled_bbox(
|
| 13 |
+
bbox, img_w, img_h, ratio: float = 1.0, face_image: Image.Image = None
|
| 14 |
+
):
|
| 15 |
+
"""
|
| 16 |
+
Compute the scaled crop region from the face bounding box.
|
| 17 |
+
|
| 18 |
+
Args:
|
| 19 |
+
bbox: face bounding box [x1, y1, x2, y2]
|
| 20 |
+
img_w: image width
|
| 21 |
+
img_h: image height
|
| 22 |
+
ratio: scale factor; larger values make the face occupy less of the frame (more margin around it)
|
| 23 |
+
face_image: PIL Image object
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
The cropped face image
|
| 27 |
+
"""
|
| 28 |
+
x1, y1, x2, y2 = bbox
|
| 29 |
+
|
| 30 |
+
# Calculate center point
|
| 31 |
+
center_x = (x1 + x2) / 2
|
| 32 |
+
center_y = (y1 + y2) / 2
|
| 33 |
+
|
| 34 |
+
# Calculate width and height
|
| 35 |
+
width = x2 - x1
|
| 36 |
+
|
| 37 |
+
# Scale width and height
|
| 38 |
+
new_width = width * ratio
|
| 39 |
+
new_height = new_width
|
| 40 |
+
|
| 41 |
+
# distances (in pixels) from the face center to each edge of the crop
|
| 42 |
+
dis_x_left = new_width * 0.5
|
| 43 |
+
dis_x_right = new_width - dis_x_left  # 0.5 * new_width
|
| 44 |
+
dis_y_up = new_height * 0.55
|
| 45 |
+
dis_y_down = new_height - dis_y_up  # 0.45 * new_height
|
| 46 |
+
|
| 47 |
+
# Calculate new coordinates
|
| 48 |
+
new_x1 = int(max(0, center_x - dis_x_left))
|
| 49 |
+
new_y1 = int(max(0, center_y - dis_y_up))
|
| 50 |
+
new_x2 = int(min(img_w, center_x + dis_x_right))
|
| 51 |
+
new_y2 = int(min(img_h, center_y + dis_y_down))
|
| 52 |
+
scaled_bbox = [new_x1, new_y1, new_x2, new_y2]
|
| 53 |
+
crop_face = face_image.crop(scaled_bbox)
|
| 54 |
+
return crop_face
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def process_image(
|
| 58 |
+
input_path,
|
| 59 |
+
face_ratio=2.0,
|
| 60 |
+
target_size=(512, 512),
|
| 61 |
+
):
|
| 62 |
+
"""
|
| 63 |
+
Process a single image: detect the face and crop it.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
input_path: path to the input image
|
| 67 |
+
face_ratio: face scale factor; recommended range 1.5-3.0, default 2.0
|
| 68 |
+
target_size: output image size, default (512, 512)
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
image: the cropped and resized face image
|
| 72 |
+
"""
|
| 73 |
+
# Initialize the face detector
|
| 74 |
+
face_detector = CPUFaceHandler()
|
| 75 |
+
|
| 76 |
+
# Validate the input file
|
| 77 |
+
if not os.path.isfile(input_path):
|
| 78 |
+
raise ValueError(f"File not found: {input_path}")
|
| 79 |
+
|
| 80 |
+
try:
|
| 81 |
+
# Load the image
|
| 82 |
+
image = Image.open(input_path)
|
| 83 |
+
image = image.convert("RGB")
|
| 84 |
+
image_rgb = np.array(image)
|
| 85 |
+
img_h, img_w = image_rgb.shape[:2]
|
| 86 |
+
|
| 87 |
+
# Detect faces
|
| 88 |
+
boxes, scores = face_detector(image_rgb)
|
| 89 |
+
|
| 90 |
+
if len(boxes) == 0:
|
| 91 |
+
raise ValueError("No face detected")
|
| 92 |
+
|
| 93 |
+
# Convert the bounding box from relative to absolute pixel coordinates
|
| 94 |
+
boxes_abs = [
|
| 95 |
+
boxes[0][0] * img_w,
|
| 96 |
+
boxes[0][1] * img_h,
|
| 97 |
+
boxes[0][2] * img_w,
|
| 98 |
+
boxes[0][3] * img_h
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
# Crop the face
|
| 102 |
+
crop_face = get_scaled_bbox(boxes_abs, img_w, img_h, face_ratio, image)
|
| 103 |
+
|
| 104 |
+
# Resize to the target size
|
| 105 |
+
crop_face = crop_face.resize(target_size)
|
| 106 |
+
|
| 107 |
+
return crop_face
|
| 108 |
+
|
| 109 |
+
except Exception as e:
|
| 110 |
+
raise ValueError(f"Error processing {input_path}: {e}")
|
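To make the crop geometry above concrete: with `face_ratio=2.0` the square crop is twice the detected face width, centered horizontally on the face, with 55% of the crop height above the face center. A worked example (clamping to the image border ignored):

```python
# Hand-worked get_scaled_bbox geometry, ratio=2.0, clamping ignored.
x1, y1, x2, y2 = 250, 200, 350, 320        # detected face box, width 100
cx, cy = (x1 + x2) / 2, (y1 + y2) / 2      # center (300, 260)
new_w = (x2 - x1) * 2.0                    # square crop of side 200
crop = [cx - 0.5 * new_w, cy - 0.55 * new_w,
        cx + 0.5 * new_w, cy + 0.45 * new_w]
print(crop)                                # [200.0, 150.0, 400.0, 350.0]
```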
flash_head/utils/utils.py
ADDED
|
@@ -0,0 +1,222 @@
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import numpy as np
|
| 4 |
+
import math
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import torchvision.transforms as transforms
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import pyloudnorm as pyln
|
| 9 |
+
|
| 10 |
+
def rgb_to_lab_torch(rgb: torch.Tensor) -> torch.Tensor:
|
| 11 |
+
"""
|
| 12 |
+
PyTorch GPU version: convert RGB to the Lab color space (input in [0, 1]; arbitrary tensor shape with channels in the last dimension).
|
| 13 |
+
Based on the CIE 1931 standard conversion formulas.
|
| 14 |
+
"""
|
| 15 |
+
# Convert to linear RGB (inverse of the sRGB gamma correction)
|
| 16 |
+
linear_rgb = torch.where(
|
| 17 |
+
rgb > 0.04045,
|
| 18 |
+
((rgb + 0.055) / 1.055) ** 2.4,
|
| 19 |
+
rgb / 12.92
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
# Linear RGB to XYZ (sRGB primaries, D65 white point)
|
| 23 |
+
xyz_from_rgb = torch.tensor([
|
| 24 |
+
[0.4124564, 0.3575761, 0.1804375],
|
| 25 |
+
[0.2126729, 0.7151522, 0.0721750],
|
| 26 |
+
[0.0193339, 0.1191920, 0.9503041]
|
| 27 |
+
], dtype=rgb.dtype, device=rgb.device)
|
| 28 |
+
|
| 29 |
+
# Shape handling: flatten to (N, 3) for the matrix multiply, then restore the spatial dims
|
| 30 |
+
shape = linear_rgb.shape
|
| 31 |
+
linear_rgb_flat = linear_rgb.reshape(-1, 3)  # (N, 3), N = B*T*H*W
|
| 32 |
+
xyz_flat = linear_rgb_flat @ xyz_from_rgb.T # (N, 3)
|
| 33 |
+
xyz = xyz_flat.reshape(shape)  # restore the original shape
|
| 34 |
+
|
| 35 |
+
# XYZ to Lab (D65 white point)
|
| 36 |
+
xyz_ref = torch.tensor([0.95047, 1.0, 1.08883], dtype=rgb.dtype, device=rgb.device)
|
| 37 |
+
xyz_normalized = xyz / xyz_ref[None, None, None, None, :]  # broadcast against the 5-D (B, T, H, W, C) input
|
| 38 |
+
|
| 39 |
+
# Apply the Lab conversion formulas
|
| 40 |
+
epsilon = 0.008856
|
| 41 |
+
kappa = 903.3
|
| 42 |
+
xyz_normalized = torch.clamp(xyz_normalized, 1e-8, 1.0)  # avoid zero/negative values before the cube root
|
| 43 |
+
|
| 44 |
+
f_xyz = torch.where(
|
| 45 |
+
xyz_normalized > epsilon,
|
| 46 |
+
xyz_normalized ** (1/3),
|
| 47 |
+
(kappa * xyz_normalized + 16) / 116
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
L = 116 * f_xyz[..., 1] - 16  # lightness from the Y channel
|
| 51 |
+
a = 500 * (f_xyz[..., 0] - f_xyz[..., 1])  # red-green axis from X - Y
|
| 52 |
+
b = 200 * (f_xyz[..., 1] - f_xyz[..., 2])  # blue-yellow axis from Y - Z
|
| 53 |
+
|
| 54 |
+
lab = torch.stack([L, a, b], dim=-1)  # stack the Lab channels in the last dimension
|
| 55 |
+
return lab
|
| 56 |
+
|
| 57 |
+
def lab_to_rgb_torch(lab: torch.Tensor) -> torch.Tensor:
|
| 58 |
+
"""
|
| 59 |
+
PyTorch GPU version: convert Lab back to RGB (output in [0, 1]; arbitrary tensor shape with channels in the last dimension).
|
| 60 |
+
"""
|
| 61 |
+
# Split the Lab channels
|
| 62 |
+
L = lab[..., 0]
|
| 63 |
+
a = lab[..., 1]
|
| 64 |
+
b = lab[..., 2]
|
| 65 |
+
|
| 66 |
+
# Lab to XYZ
|
| 67 |
+
f_y = (L + 16) / 116
|
| 68 |
+
f_x = (a / 500) + f_y
|
| 69 |
+
f_z = f_y - (b / 200)
|
| 70 |
+
|
| 71 |
+
epsilon = 0.008856
|
| 72 |
+
kappa = 903.3
|
| 73 |
+
|
| 74 |
+
x = torch.where(f_x ** 3 > epsilon, f_x ** 3, (116 * f_x - 16) / kappa)
|
| 75 |
+
y = torch.where(L > kappa * epsilon, ((L + 16) / 116) ** 3, L / kappa)
|
| 76 |
+
z = torch.where(f_z ** 3 > epsilon, f_z ** 3, (116 * f_z - 16) / kappa)
|
| 77 |
+
|
| 78 |
+
# Multiply by the D65 white point
|
| 79 |
+
xyz_ref = torch.tensor([0.95047, 1.0, 1.08883], dtype=lab.dtype, device=lab.device)
|
| 80 |
+
xyz = torch.stack([x, y, z], dim=-1) * xyz_ref[None, None, None, None, :]
|
| 81 |
+
|
| 82 |
+
# XYZ to linear RGB
|
| 83 |
+
rgb_from_xyz = torch.tensor([
|
| 84 |
+
[3.2404542, -1.5371385, -0.4985314],
|
| 85 |
+
[-0.9692660, 1.8760108, 0.0415560],
|
| 86 |
+
[0.0556434, -0.2040259, 1.0572252]
|
| 87 |
+
], dtype=lab.dtype, device=lab.device)
|
| 88 |
+
|
| 89 |
+
# Shape handling for the matrix multiply
|
| 90 |
+
shape = xyz.shape
|
| 91 |
+
xyz_flat = xyz.reshape(-1, 3) # (N, 3)
|
| 92 |
+
linear_rgb_flat = xyz_flat @ rgb_from_xyz.T # (N, 3)
|
| 93 |
+
linear_rgb = linear_rgb_flat.reshape(shape)  # restore the original shape
|
| 94 |
+
|
| 95 |
+
# Linear RGB to sRGB (gamma correction)
|
| 96 |
+
rgb = torch.where(
|
| 97 |
+
linear_rgb > 0.0031308,
|
| 98 |
+
1.055 * (linear_rgb ** (1/2.4)) - 0.055,
|
| 99 |
+
12.92 * linear_rgb
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
# Clamp the output to [0, 1]
|
| 103 |
+
rgb = torch.clamp(rgb, 0.0, 1.0)
|
| 104 |
+
return rgb
|
| 105 |
+
|
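A quick round-trip sanity check for the two conversions above; for in-gamut values an RGB -> Lab -> RGB trip should return the input up to floating-point error (a standalone sketch, not part of the module):

```python
import torch

x = torch.rand(1, 2, 8, 8, 3)               # (B, T, H, W, C) in [0, 1], channels last
y = lab_to_rgb_torch(rgb_to_lab_torch(x))
print(torch.allclose(x, y, atol=1e-3))       # expected: True for in-gamut values
```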
| 106 |
+
def match_and_blend_colors_torch(
|
| 107 |
+
source_chunk: torch.Tensor,
|
| 108 |
+
reference_image: torch.Tensor,
|
| 109 |
+
strength: float
|
| 110 |
+
) -> torch.Tensor:
|
| 111 |
+
"""
|
| 112 |
+
Fully batched GPU version: match the colors of a video chunk to a reference image and blend them (supports B > 1 and processes all T frames in parallel).
|
| 113 |
+
|
| 114 |
+
Args:
|
| 115 |
+
source_chunk (torch.Tensor): video chunk (B, C, T, H, W), range [-1, 1]
|
| 116 |
+
reference_image (torch.Tensor): reference image (B, C, 1, H, W), range [-1, 1] (B must match source_chunk)
|
| 117 |
+
strength (float): color-correction strength in [0.0, 1.0]; 0.0 means no correction, 1.0 means full correction
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
torch.Tensor: color-corrected video chunk (B, C, T, H, W), range [-1, 1]
|
| 121 |
+
"""
|
| 122 |
+
# With zero strength, return the input unchanged
|
| 123 |
+
if strength <= 0.0:
|
| 124 |
+
return source_chunk.clone()
|
| 125 |
+
|
| 126 |
+
# Validate the strength range
|
| 127 |
+
if not 0.0 <= strength <= 1.0:
|
| 128 |
+
raise ValueError(f"Strength必须在0.0-1.0之间,当前值:{strength}")
|
| 129 |
+
|
| 130 |
+
# Validate input shapes (B must match; the reference image has T=1)
|
| 131 |
+
B, C, T, H, W = source_chunk.shape
|
| 132 |
+
assert reference_image.shape == (B, C, 1, H, W), \
|
| 133 |
+
f"参考图像形状需为(B, C, 1, H, W),当前为{reference_image.shape}"
|
| 134 |
+
assert C == 3, f"Only 3-channel RGB input is supported, got {C} channels"
|
| 135 |
+
|
| 136 |
+
# Keep device and dtype consistent with the source chunk
|
| 137 |
+
device = source_chunk.device
|
| 138 |
+
dtype = source_chunk.dtype
|
| 139 |
+
reference_image = reference_image.to(device=device, dtype=dtype)
|
| 140 |
+
|
| 141 |
+
# 1. Rescale from [-1, 1] to [0, 1] (directly on the GPU)
|
| 142 |
+
source_01 = (source_chunk + 1.0) / 2.0
|
| 143 |
+
ref_01 = (reference_image + 1.0) / 2.0
|
| 144 |
+
|
| 145 |
+
# 2. Reorder dimensions: (B, C, T, H, W) -> (B, T, H, W, C) for the color-space conversion
|
| 146 |
+
# Reference image: (B, C, 1, H, W) -> (B, 1, H, W, C)
|
| 147 |
+
source_permuted = source_01.permute(0, 2, 3, 4, 1)  # move channels to the last dimension
|
| 148 |
+
ref_permuted = ref_01.permute(0, 2, 3, 4, 1)
|
| 149 |
+
|
| 150 |
+
# 3. RGB to Lab (all frames in one batch)
|
| 151 |
+
source_lab = rgb_to_lab_torch(source_permuted)
|
| 152 |
+
ref_lab = rgb_to_lab_torch(ref_permuted) # (B, 1, H, W, 3)
|
| 153 |
+
|
| 154 |
+
# 4. Batched color transfer: match the mean and std of the L/a/b channels (core logic)
|
| 155 |
+
# Per-channel mean/std of the reference (statistics over H and W, keeping B)
|
| 156 |
+
ref_mean = ref_lab.mean(dim=[2, 3], keepdim=True) # (B, 1, 1, 1, 3)
|
| 157 |
+
ref_std = ref_lab.std(dim=[2, 3], keepdim=True, unbiased=False) # (B, 1, 1, 1, 3)
|
| 158 |
+
|
| 159 |
+
# Per-channel mean/std of the source (statistics over H and W, keeping B and T)
|
| 160 |
+
source_mean = source_lab.mean(dim=[2, 3], keepdim=True) # (B, T, 1, 1, 3)
|
| 161 |
+
source_std = source_lab.std(dim=[2, 3], keepdim=True, unbiased=False) # (B, T, 1, 1, 3)
|
| 162 |
+
|
| 163 |
+
# Avoid division by a zero std (replace zeros with 1.0)
|
| 164 |
+
source_std_safe = torch.where(source_std < 1e-8, torch.ones_like(source_std), source_std)
|
| 165 |
+
|
| 166 |
+
# Color-transfer formula: (source - source_mean) * (ref_std / source_std) + ref_mean
|
| 167 |
+
corrected_lab = (source_lab - source_mean) * (ref_std / source_std_safe) + ref_mean
|
| 168 |
+
|
| 169 |
+
# 5. Lab back to RGB (all corrected frames in one batch)
|
| 170 |
+
corrected_rgb_01 = lab_to_rgb_torch(corrected_lab)
|
| 171 |
+
|
| 172 |
+
# 6. Blend the original and corrected frames, weighted by strength
|
| 173 |
+
blended_rgb_01 = (1 - strength) * source_permuted + strength * corrected_rgb_01
|
| 174 |
+
|
| 175 |
+
# 7. Restore the dimension order and value range: (B, T, H, W, C) -> (B, C, T, H, W), [0, 1] -> [-1, 1]
|
| 176 |
+
blended_rgb_01 = blended_rgb_01.permute(0, 4, 1, 2, 3)  # move channels back to the second dimension
|
| 177 |
+
blended_rgb_minus1_1 = (blended_rgb_01 * 2.0) - 1.0
|
| 178 |
+
|
| 179 |
+
# 8. Ensure the output layout is correct (contiguous memory)
|
| 180 |
+
output = blended_rgb_minus1_1.contiguous().to(device=device, dtype=dtype)
|
| 181 |
+
|
| 182 |
+
return output
|
| 183 |
+
|
| 184 |
+
def resize_and_centercrop(cond_image, target_size):
|
| 185 |
+
"""
|
| 186 |
+
Resize image or tensor to the target size without padding.
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
# Get the original size
|
| 190 |
+
if isinstance(cond_image, torch.Tensor):
|
| 191 |
+
_, orig_h, orig_w = cond_image.shape
|
| 192 |
+
else:
|
| 193 |
+
orig_h, orig_w = cond_image.height, cond_image.width
|
| 194 |
+
|
| 195 |
+
target_h, target_w = target_size
|
| 196 |
+
|
| 197 |
+
# Calculate the scaling factor for resizing
|
| 198 |
+
scale_h = target_h / orig_h
|
| 199 |
+
scale_w = target_w / orig_w
|
| 200 |
+
|
| 201 |
+
# Compute the final size
|
| 202 |
+
scale = max(scale_h, scale_w)
|
| 203 |
+
final_h = math.ceil(scale * orig_h)
|
| 204 |
+
final_w = math.ceil(scale * orig_w)
|
| 205 |
+
|
| 206 |
+
# Resize
|
| 207 |
+
if isinstance(cond_image, torch.Tensor):
|
| 208 |
+
if len(cond_image.shape) == 3:
|
| 209 |
+
cond_image = cond_image[None]
|
| 210 |
+
resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
|
| 211 |
+
# crop
|
| 212 |
+
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
| 213 |
+
cropped_tensor = cropped_tensor.squeeze(0)
|
| 214 |
+
else:
|
| 215 |
+
resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
|
| 216 |
+
resized_image = np.array(resized_image)
|
| 217 |
+
# tensor and crop
|
| 218 |
+
resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
|
| 219 |
+
cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
|
| 220 |
+
cropped_tensor = cropped_tensor[:, :, None, :, :]
|
| 221 |
+
|
| 222 |
+
return cropped_tensor
|
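A shape-level sketch of the two helpers above as the pipeline uses them (random tensors and a blank PIL image stand in for real frames):

```python
import torch
from PIL import Image

# resize_and_centercrop: PIL image -> (1, C, 1, H, W) uint8 tensor
img = Image.new("RGB", (640, 480))                   # placeholder image
cond = resize_and_centercrop(img, (512, 512))
print(cond.shape)                                    # torch.Size([1, 3, 1, 512, 512])

# match_and_blend_colors_torch: pull a chunk's colors toward the reference statistics
chunk = torch.rand(1, 3, 16, 64, 64) * 2 - 1         # (B, C, T, H, W) in [-1, 1]
ref = torch.rand(1, 3, 1, 64, 64) * 2 - 1            # (B, C, 1, H, W) in [-1, 1]
out = match_and_blend_colors_torch(chunk, ref, strength=0.5)
print(out.shape)                                     # torch.Size([1, 3, 16, 64, 64])
```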
flash_head/wan/modules/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
| 1 |
+
from .vae import WanVAE
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
'WanVAE',
|
| 5 |
+
]
|
flash_head/wan/modules/vae.py
ADDED
|
@@ -0,0 +1,1598 @@
|
| 1 |
+
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from loguru import logger
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"WanVAE",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
CACHE_T = 2
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class CausalConv3d(nn.Conv3d):
|
| 18 |
+
"""
|
| 19 |
+
Causal 3D convolution.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
def __init__(self, *args, **kwargs):
|
| 23 |
+
super().__init__(*args, **kwargs)
|
| 24 |
+
self._padding = (
|
| 25 |
+
self.padding[2],
|
| 26 |
+
self.padding[2],
|
| 27 |
+
self.padding[1],
|
| 28 |
+
self.padding[1],
|
| 29 |
+
2 * self.padding[0],
|
| 30 |
+
0,
|
| 31 |
+
)
|
| 32 |
+
self.padding = (0, 0, 0)
|
| 33 |
+
|
| 34 |
+
def forward(self, x, cache_x=None):
|
| 35 |
+
padding = list(self._padding)
|
| 36 |
+
if cache_x is not None and self._padding[4] > 0:
|
| 37 |
+
cache_x = cache_x.to(x.device)
|
| 38 |
+
x = torch.cat([cache_x, x], dim=2)
|
| 39 |
+
padding[4] -= cache_x.shape[2]
|
| 40 |
+
x = F.pad(x, padding)
|
| 41 |
+
|
| 42 |
+
return super().forward(x)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class RMS_norm(nn.Module):
|
| 46 |
+
def __init__(self, dim, channel_first=True, images=True, bias=False):
|
| 47 |
+
super().__init__()
|
| 48 |
+
broadcastable_dims = (1, 1, 1) if not images else (1, 1)
|
| 49 |
+
shape = (dim, *broadcastable_dims) if channel_first else (dim,)
|
| 50 |
+
|
| 51 |
+
self.channel_first = channel_first
|
| 52 |
+
self.scale = dim**0.5
|
| 53 |
+
self.gamma = nn.Parameter(torch.ones(shape))
|
| 54 |
+
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
|
| 55 |
+
|
| 56 |
+
def forward(self, x):
|
| 57 |
+
return (
|
| 58 |
+
F.normalize(x, dim=(1 if self.channel_first else -1))
|
| 59 |
+
* self.scale
|
| 60 |
+
* self.gamma
|
| 61 |
+
+ self.bias
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Upsample(nn.Upsample):
|
| 66 |
+
def forward(self, x):
|
| 67 |
+
"""
|
| 68 |
+
Fix bfloat16 support for nearest neighbor interpolation.
|
| 69 |
+
"""
|
| 70 |
+
return super().forward(x)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class Resample(nn.Module):
|
| 74 |
+
def __init__(self, dim, mode):
|
| 75 |
+
assert mode in (
|
| 76 |
+
"none",
|
| 77 |
+
"upsample2d",
|
| 78 |
+
"upsample3d",
|
| 79 |
+
"downsample2d",
|
| 80 |
+
"downsample3d",
|
| 81 |
+
)
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.dim = dim
|
| 84 |
+
self.mode = mode
|
| 85 |
+
|
| 86 |
+
# layers
|
| 87 |
+
if mode == "upsample2d":
|
| 88 |
+
self.resample = nn.Sequential(
|
| 89 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 90 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 91 |
+
)
|
| 92 |
+
elif mode == "upsample3d":
|
| 93 |
+
self.resample = nn.Sequential(
|
| 94 |
+
Upsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
| 95 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1),
|
| 96 |
+
)
|
| 97 |
+
self.time_conv = CausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
| 98 |
+
|
| 99 |
+
elif mode == "downsample2d":
|
| 100 |
+
self.resample = nn.Sequential(
|
| 101 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 102 |
+
)
|
| 103 |
+
elif mode == "downsample3d":
|
| 104 |
+
self.resample = nn.Sequential(
|
| 105 |
+
nn.ZeroPad2d((0, 1, 0, 1)), nn.Conv2d(dim, dim, 3, stride=(2, 2))
|
| 106 |
+
)
|
| 107 |
+
self.time_conv = CausalConv3d(
|
| 108 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0)
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
else:
|
| 112 |
+
self.resample = nn.Identity()
|
| 113 |
+
|
| 114 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 115 |
+
b, c, t, h, w = x.size()
|
| 116 |
+
if self.mode == "upsample3d":
|
| 117 |
+
if feat_cache is not None:
|
| 118 |
+
idx = feat_idx[0]
|
| 119 |
+
if feat_cache[idx] is None:
|
| 120 |
+
feat_cache[idx] = "Rep"
|
| 121 |
+
feat_idx[0] += 1
|
| 122 |
+
else:
|
| 123 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 124 |
+
if (
|
| 125 |
+
cache_x.shape[2] < 2
|
| 126 |
+
and feat_cache[idx] is not None
|
| 127 |
+
and feat_cache[idx] != "Rep"
|
| 128 |
+
):
|
| 129 |
+
# cache last frame of last two chunk
|
| 130 |
+
cache_x = torch.cat(
|
| 131 |
+
[
|
| 132 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 133 |
+
.unsqueeze(2)
|
| 134 |
+
.to(cache_x.device),
|
| 135 |
+
cache_x,
|
| 136 |
+
],
|
| 137 |
+
dim=2,
|
| 138 |
+
)
|
| 139 |
+
if (
|
| 140 |
+
cache_x.shape[2] < 2
|
| 141 |
+
and feat_cache[idx] is not None
|
| 142 |
+
and feat_cache[idx] == "Rep"
|
| 143 |
+
):
|
| 144 |
+
cache_x = torch.cat(
|
| 145 |
+
[torch.zeros_like(cache_x).to(cache_x.device), cache_x],
|
| 146 |
+
dim=2,
|
| 147 |
+
)
|
| 148 |
+
if feat_cache[idx] == "Rep":
|
| 149 |
+
x = self.time_conv(x)
|
| 150 |
+
else:
|
| 151 |
+
x = self.time_conv(x, feat_cache[idx])
|
| 152 |
+
feat_cache[idx] = cache_x
|
| 153 |
+
feat_idx[0] += 1
|
| 154 |
+
|
| 155 |
+
x = x.reshape(b, 2, c, t, h, w)
|
| 156 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]), 3)
|
| 157 |
+
x = x.reshape(b, c, t * 2, h, w)
|
| 158 |
+
t = x.shape[2]
|
| 159 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 160 |
+
x = self.resample(x)
|
| 161 |
+
x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
|
| 162 |
+
|
| 163 |
+
if self.mode == "downsample3d":
|
| 164 |
+
if feat_cache is not None:
|
| 165 |
+
idx = feat_idx[0]
|
| 166 |
+
if feat_cache[idx] is None:
|
| 167 |
+
feat_cache[idx] = x.clone()
|
| 168 |
+
feat_idx[0] += 1
|
| 169 |
+
else:
|
| 170 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
| 171 |
+
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
| 172 |
+
# # cache last frame of last two chunk
|
| 173 |
+
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
| 174 |
+
|
| 175 |
+
x = self.time_conv(
|
| 176 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)
|
| 177 |
+
)
|
| 178 |
+
feat_cache[idx] = cache_x
|
| 179 |
+
feat_idx[0] += 1
|
| 180 |
+
return x
|
| 181 |
+
|
| 182 |
+
def init_weight(self, conv):
|
| 183 |
+
conv_weight = conv.weight
|
| 184 |
+
nn.init.zeros_(conv_weight)
|
| 185 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 186 |
+
one_matrix = torch.eye(c1, c2)
|
| 187 |
+
init_matrix = one_matrix
|
| 188 |
+
nn.init.zeros_(conv_weight)
|
| 189 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
| 190 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
| 191 |
+
conv.weight.data.copy_(conv_weight)
|
| 192 |
+
nn.init.zeros_(conv.bias.data)
|
| 193 |
+
|
| 194 |
+
def init_weight2(self, conv):
|
| 195 |
+
conv_weight = conv.weight.data
|
| 196 |
+
nn.init.zeros_(conv_weight)
|
| 197 |
+
c1, c2, t, h, w = conv_weight.size()
|
| 198 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
| 199 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
| 200 |
+
conv_weight[: c1 // 2, :, -1, 0, 0] = init_matrix
|
| 201 |
+
conv_weight[c1 // 2 :, :, -1, 0, 0] = init_matrix
|
| 202 |
+
conv.weight.data.copy_(conv_weight)
|
| 203 |
+
nn.init.zeros_(conv.bias.data)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class ResidualBlock(nn.Module):
|
| 207 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.in_dim = in_dim
|
| 210 |
+
self.out_dim = out_dim
|
| 211 |
+
|
| 212 |
+
# layers
|
| 213 |
+
self.residual = nn.Sequential(
|
| 214 |
+
RMS_norm(in_dim, images=False),
|
| 215 |
+
nn.SiLU(),
|
| 216 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
| 217 |
+
RMS_norm(out_dim, images=False),
|
| 218 |
+
nn.SiLU(),
|
| 219 |
+
nn.Dropout(dropout),
|
| 220 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1),
|
| 221 |
+
)
|
| 222 |
+
self.shortcut = (
|
| 223 |
+
CausalConv3d(in_dim, out_dim, 1) if in_dim != out_dim else nn.Identity()
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 227 |
+
h = self.shortcut(x)
|
| 228 |
+
for layer in self.residual:
|
| 229 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 230 |
+
idx = feat_idx[0]
|
| 231 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 232 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 233 |
+
# cache last frame of last two chunk
|
| 234 |
+
cache_x = torch.cat(
|
| 235 |
+
[
|
| 236 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 237 |
+
.unsqueeze(2)
|
| 238 |
+
.to(cache_x.device),
|
| 239 |
+
cache_x,
|
| 240 |
+
],
|
| 241 |
+
dim=2,
|
| 242 |
+
)
|
| 243 |
+
x = layer(x, feat_cache[idx])
|
| 244 |
+
feat_cache[idx] = cache_x
|
| 245 |
+
feat_idx[0] += 1
|
| 246 |
+
else:
|
| 247 |
+
x = layer(x)
|
| 248 |
+
return x + h
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
class AttentionBlock(nn.Module):
|
| 252 |
+
"""
|
| 253 |
+
Causal self-attention with a single head.
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, dim):
|
| 257 |
+
super().__init__()
|
| 258 |
+
self.dim = dim
|
| 259 |
+
|
| 260 |
+
# layers
|
| 261 |
+
self.norm = RMS_norm(dim)
|
| 262 |
+
self.to_qkv = nn.Conv2d(dim, dim * 3, 1)
|
| 263 |
+
self.proj = nn.Conv2d(dim, dim, 1)
|
| 264 |
+
|
| 265 |
+
# zero out the last layer params
|
| 266 |
+
nn.init.zeros_(self.proj.weight)
|
| 267 |
+
|
| 268 |
+
def forward(self, x):
|
| 269 |
+
identity = x
|
| 270 |
+
b, c, t, h, w = x.size()
|
| 271 |
+
x = rearrange(x, "b c t h w -> (b t) c h w")
|
| 272 |
+
x = self.norm(x)
|
| 273 |
+
# compute query, key, value
|
| 274 |
+
q, k, v = (
|
| 275 |
+
self.to_qkv(x)
|
| 276 |
+
.reshape(b * t, 1, c * 3, -1)
|
| 277 |
+
.permute(0, 1, 3, 2)
|
| 278 |
+
.contiguous()
|
| 279 |
+
.chunk(3, dim=-1)
|
| 280 |
+
)
|
| 281 |
+
|
| 282 |
+
# apply attention
|
| 283 |
+
x = F.scaled_dot_product_attention(
|
| 284 |
+
q,
|
| 285 |
+
k,
|
| 286 |
+
v,
|
| 287 |
+
)
|
| 288 |
+
x = x.squeeze(1).permute(0, 2, 1).reshape(b * t, c, h, w)
|
| 289 |
+
|
| 290 |
+
# output
|
| 291 |
+
x = self.proj(x)
|
| 292 |
+
x = rearrange(x, "(b t) c h w-> b c t h w", t=t)
|
| 293 |
+
return x + identity
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
class Encoder3d(nn.Module):
|
| 297 |
+
def __init__(
|
| 298 |
+
self,
|
| 299 |
+
dim=128,
|
| 300 |
+
z_dim=4,
|
| 301 |
+
dim_mult=[1, 2, 4, 4],
|
| 302 |
+
num_res_blocks=2,
|
| 303 |
+
attn_scales=[],
|
| 304 |
+
temperal_downsample=[True, True, False],
|
| 305 |
+
dropout=0.0,
|
| 306 |
+
):
|
| 307 |
+
super().__init__()
|
| 308 |
+
self.dim = dim
|
| 309 |
+
self.z_dim = z_dim
|
| 310 |
+
self.dim_mult = dim_mult
|
| 311 |
+
self.num_res_blocks = num_res_blocks
|
| 312 |
+
self.attn_scales = attn_scales
|
| 313 |
+
self.temperal_downsample = temperal_downsample
|
| 314 |
+
|
| 315 |
+
# dimensions
|
| 316 |
+
dims = [dim * u for u in [1] + dim_mult]
|
| 317 |
+
scale = 1.0
|
| 318 |
+
|
| 319 |
+
# init block
|
| 320 |
+
self.conv1 = CausalConv3d(3, dims[0], 3, padding=1)
|
| 321 |
+
|
| 322 |
+
# downsample blocks
|
| 323 |
+
downsamples = []
|
| 324 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 325 |
+
# residual (+attention) blocks
|
| 326 |
+
for _ in range(num_res_blocks):
|
| 327 |
+
downsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
| 328 |
+
if scale in attn_scales:
|
| 329 |
+
downsamples.append(AttentionBlock(out_dim))
|
| 330 |
+
in_dim = out_dim
|
| 331 |
+
|
| 332 |
+
# downsample block
|
| 333 |
+
if i != len(dim_mult) - 1:
|
| 334 |
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
| 335 |
+
downsamples.append(Resample(out_dim, mode=mode))
|
| 336 |
+
scale /= 2.0
|
| 337 |
+
self.downsamples = nn.Sequential(*downsamples)
|
| 338 |
+
|
| 339 |
+
# middle blocks
|
| 340 |
+
self.middle = nn.Sequential(
|
| 341 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 342 |
+
AttentionBlock(out_dim),
|
| 343 |
+
ResidualBlock(out_dim, out_dim, dropout),
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# output blocks
|
| 347 |
+
self.head = nn.Sequential(
|
| 348 |
+
RMS_norm(out_dim, images=False),
|
| 349 |
+
nn.SiLU(),
|
| 350 |
+
CausalConv3d(out_dim, z_dim, 3, padding=1),
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
| 354 |
+
if feat_cache is not None:
|
| 355 |
+
idx = feat_idx[0]
|
| 356 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 357 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 358 |
+
# cache last frame of last two chunk
|
| 359 |
+
cache_x = torch.cat(
|
| 360 |
+
[
|
| 361 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device),
|
| 362 |
+
cache_x,
|
| 363 |
+
],
|
| 364 |
+
dim=2,
|
| 365 |
+
)
|
| 366 |
+
x = self.conv1(x, feat_cache[idx])
|
| 367 |
+
feat_cache[idx] = cache_x
|
| 368 |
+
feat_idx[0] += 1
|
| 369 |
+
else:
|
| 370 |
+
x = self.conv1(x)
|
| 371 |
+
|
| 372 |
+
## downsamples
|
| 373 |
+
for layer in self.downsamples:
|
| 374 |
+
if feat_cache is not None:
|
| 375 |
+
x = layer(x, feat_cache, feat_idx)
|
| 376 |
+
else:
|
| 377 |
+
x = layer(x)
|
| 378 |
+
|
| 379 |
+
## middle
|
| 380 |
+
for layer in self.middle:
|
| 381 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
| 382 |
+
x = layer(x, feat_cache, feat_idx)
|
| 383 |
+
else:
|
| 384 |
+
x = layer(x)
|
| 385 |
+
|
| 386 |
+
## head
|
| 387 |
+
for layer in self.head:
|
| 388 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
| 389 |
+
idx = feat_idx[0]
|
| 390 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
| 391 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
| 392 |
+
# cache last frame of last two chunk
|
| 393 |
+
cache_x = torch.cat(
|
| 394 |
+
[
|
| 395 |
+
feat_cache[idx][:, :, -1, :, :]
|
| 396 |
+
.unsqueeze(2)
|
| 397 |
+
.to(cache_x.device),
|
| 398 |
+
cache_x,
|
| 399 |
+
],
|
| 400 |
+
dim=2,
|
| 401 |
+
)
|
| 402 |
+
x = layer(x, feat_cache[idx])
|
| 403 |
+
feat_cache[idx] = cache_x
|
| 404 |
+
feat_idx[0] += 1
|
| 405 |
+
else:
|
| 406 |
+
x = layer(x)
|
| 407 |
+
return x
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
class Decoder3d(nn.Module):
|
| 411 |
+
def __init__(
|
| 412 |
+
self,
|
| 413 |
+
dim=128,
|
| 414 |
+
z_dim=4,
|
| 415 |
+
dim_mult=[1, 2, 4, 4],
|
| 416 |
+
num_res_blocks=2,
|
| 417 |
+
attn_scales=[],
|
| 418 |
+
temperal_upsample=[False, True, True],
|
| 419 |
+
dropout=0.0,
|
| 420 |
+
):
|
| 421 |
+
super().__init__()
|
| 422 |
+
self.dim = dim
|
| 423 |
+
self.z_dim = z_dim
|
| 424 |
+
self.dim_mult = dim_mult
|
| 425 |
+
self.num_res_blocks = num_res_blocks
|
| 426 |
+
self.attn_scales = attn_scales
|
| 427 |
+
self.temperal_upsample = temperal_upsample
|
| 428 |
+
|
| 429 |
+
# dimensions
|
| 430 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
| 431 |
+
|
| 432 |
+
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
| 433 |
+
|
| 434 |
+
# init block
|
| 435 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
| 436 |
+
|
| 437 |
+
# middle blocks
|
| 438 |
+
self.middle = nn.Sequential(
|
| 439 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 440 |
+
AttentionBlock(dims[0]),
|
| 441 |
+
ResidualBlock(dims[0], dims[0], dropout),
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
# upsample blocks
|
| 445 |
+
upsamples = []
|
| 446 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
| 447 |
+
# residual (+attention) blocks
|
| 448 |
+
if i == 1 or i == 2 or i == 3:
|
| 449 |
+
in_dim = in_dim // 2
|
| 450 |
+
            for _ in range(num_res_blocks + 1):
                upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
                if scale in attn_scales:
                    upsamples.append(AttentionBlock(out_dim))
                in_dim = out_dim

            # upsample block
            if i != len(dim_mult) - 1:
                mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
                upsamples.append(Resample(out_dim, mode=mode))
                scale *= 2.0
        self.upsamples = nn.Sequential(*upsamples)

        # output blocks
        self.head = nn.Sequential(
            RMS_norm(out_dim, images=False),
            nn.SiLU(),
            CausalConv3d(out_dim, 3, 3, padding=1),
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        ## conv1
        if feat_cache is not None:
            idx = feat_idx[0]
            cache_x = x[:, :, -CACHE_T:, :, :].clone()
            if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                # cache last frame of last two chunk
                cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
            x = self.conv1(x, feat_cache[idx])
            feat_cache[idx] = cache_x
            feat_idx[0] += 1
        else:
            x = self.conv1(x)

        ## middle
        for layer in self.middle:
            if isinstance(layer, ResidualBlock) and feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## upsamples
        for layer in self.upsamples:
            if feat_cache is not None:
                x = layer(x, feat_cache, feat_idx)
            else:
                x = layer(x)

        ## head
        for layer in self.head:
            if isinstance(layer, CausalConv3d) and feat_cache is not None:
                idx = feat_idx[0]
                cache_x = x[:, :, -CACHE_T:, :, :].clone()
                if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
                    # cache last frame of last two chunk
                    cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
                x = layer(x, feat_cache[idx])
                feat_cache[idx] = cache_x
                feat_idx[0] += 1
            else:
                x = layer(x)
        return x


def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count


class WanVAE_(nn.Module):
    def __init__(
        self,
        dim=128,
        z_dim=4,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[True, True, False],
        dropout=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.z_dim = z_dim
        self.dim_mult = dim_mult
        self.num_res_blocks = num_res_blocks
        self.attn_scales = attn_scales
        self.temperal_downsample = temperal_downsample
        self.temperal_upsample = temperal_downsample[::-1]
        self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)

        # The minimal tile height and width for spatial tiling to be used
        self.tile_sample_min_height = 256
        self.tile_sample_min_width = 256

        # The minimal distance between two spatial tiles
        self.tile_sample_stride_height = 192
        self.tile_sample_stride_width = 192
        # modules
        self.encoder = Encoder3d(dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout)
        self.conv1 = CausalConv3d(z_dim * 2, z_dim * 2, 1)
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)
        self.decoder = Decoder3d(dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        x_recon = self.decode(z)
        return x_recon, mu, log_var

    def blend_v(self, a, b, blend_extent):
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
        return b

    def blend_h(self, a, b, blend_extent):
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
        return b

    def tiled_encode(self, x, scale):
        _, _, num_frames, height, width = x.shape
        latent_height = height // self.spatial_compression_ratio
        latent_width = width // self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = tile_latent_min_height - tile_latent_stride_height
        blend_width = tile_latent_min_width - tile_latent_stride_width

        # Split x into overlapping tiles and encode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, self.tile_sample_stride_height):
            row = []
            for j in range(0, width, self.tile_sample_stride_width):
                self.clear_cache()
                time = []
                frame_range = 1 + (num_frames - 1) // 4
                for k in range(frame_range):
                    self._enc_conv_idx = [0]
                    if k == 0:
                        tile = x[:, :, :1, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    else:
                        tile = x[:, :, 1 + 4 * (k - 1) : 1 + 4 * k, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
                    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                    mu, log_var = self.conv1(tile).chunk(2, dim=1)
                    if isinstance(scale[0], torch.Tensor):
                        mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
                    else:
                        mu = (mu - scale[0]) * scale[1]

                    time.append(mu)

                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, :tile_latent_stride_height, :tile_latent_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        enc = torch.cat(result_rows, dim=3)[:, :, :, :latent_height, :latent_width]
        return enc

    def tiled_decode(self, z, scale):
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]

        _, _, num_frames, height, width = z.shape
        sample_height = height * self.spatial_compression_ratio
        sample_width = width * self.spatial_compression_ratio

        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio

        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width

        # Split z into overlapping tiles and decode them separately.
        # The tiles have an overlap to avoid seams between tiles.
        rows = []
        for i in range(0, height, tile_latent_stride_height):
            row = []
            for j in range(0, width, tile_latent_stride_width):
                self.clear_cache()
                time = []
                for k in range(num_frames):
                    self._conv_idx = [0]
                    tile = z[:, :, k : k + 1, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
                    tile = self.conv2(tile)
                    decoded = self.decoder(tile, feat_cache=self._feat_map, feat_idx=self._conv_idx)
                    time.append(decoded)
                row.append(torch.cat(time, dim=2))
            rows.append(row)
        self.clear_cache()

        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # blend the above tile and the left tile
                # to the current tile and add the current tile to the result row
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_width)
                result_row.append(tile[:, :, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
            result_rows.append(torch.cat(result_row, dim=-1))

        dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]

        return dec

    def encode(self, x, scale, return_mu=False):
        self.clear_cache()
        ## cache
        t = x.shape[2]
        iter_ = 1 + (t - 1) // 4
        for i in range(iter_):
            self._enc_conv_idx = [0]
            if i == 0:
                out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
            else:
                out_ = self.encoder(x[:, :, 1 + 4 * (i - 1) : 1 + 4 * i, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
                out = torch.cat([out, out_], 2)
        mu, log_var = self.conv1(out).chunk(2, dim=1)
        if isinstance(scale[0], torch.Tensor):
            mu = (mu - scale[0].view(1, self.z_dim, 1, 1, 1)) * scale[1].view(1, self.z_dim, 1, 1, 1)
        else:
            mu = (mu - scale[0]) * scale[1]

        self.clear_cache()
        if return_mu:
            return mu, log_var
        else:
            return mu

    def decode(self, z, scale):
        self.clear_cache()

        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)

        self.clear_cache()
        return out

    def decode_stream(self, z, scale):
        self.clear_cache()

        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            yield out

    def cached_decode(self, z, scale):
        # z: [b,c,t,h,w]
        if isinstance(scale[0], torch.Tensor):
            z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
        else:
            z = z / scale[1] + scale[0]
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
            else:
                out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                out = torch.cat([out, out_], 2)
        return out

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def sample(self, imgs, deterministic=False, scale=[0, 1]):
        mu, log_var = self.encode(imgs, scale, return_mu=True)
        if deterministic:
            return mu
        std = torch.exp(0.5 * log_var.clamp(-30.0, 20.0))
        return mu + std * torch.randn_like(std), mu, log_var

    def clear_cache(self):
        self._conv_num = count_conv3d(self.decoder)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        # cache encode
        self._enc_conv_num = count_conv3d(self.encoder)
        self._enc_conv_idx = [0]
        self._enc_feat_map = [None] * self._enc_conv_num

    def encode_video(self, x, scale=[0, 1]):
        assert x.ndim == 5  # NTCHW
        assert x.shape[2] % 3 == 0
        x = x.transpose(1, 2)
        y = x.mul(2).sub_(1)
        y, mu, log_var = self.sample(y, scale=scale)
        return y.transpose(1, 2).to(x), mu, log_var

    def decode_video(self, x, scale=[0, 1]):
        assert x.ndim == 5  # NTCHW
        assert x.shape[2] % self.z_dim == 0
        x = x.transpose(1, 2)
        # B, C, T, H, W
        y = x
        y = self.decode(y, scale).clamp_(-1, 1)
        y = y.mul_(0.5).add_(0.5).clamp_(0, 1)  # NCTHW
        return y.transpose(1, 2).to(x)


def _video_vae(pretrained_path=None, z_dim=None, device="cpu", dtype=torch.float, **kwargs):
    """
    Autoencoder3d adapted from Stable Diffusion 1.x, 2.x and XL.
    """
    # params
    cfg = dict(
        dim=96,
        z_dim=z_dim,
        dim_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        attn_scales=[],
        temperal_downsample=[False, True, True],
        dropout=0.0,
    )
    cfg.update(**kwargs)

    # init model
    with torch.device("meta"):
        model = WanVAE_(**cfg)

    # load checkpoint
    model.load_state_dict(torch.load(pretrained_path, map_location=device), assign=True)

    return model


class WanVAE:
    def __init__(
        self,
        z_dim=16,
        vae_path="cache/vae_step_411000.pth",
        dtype=torch.float,
        device="cuda",
        parallel=False,
        use_tiling=False,
        use_2d_split=True,
    ):
        self.dtype = dtype
        self.device = device
        self.parallel = parallel
        self.use_tiling = use_tiling
        self.use_2d_split = use_2d_split

        mean = [
            -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
            0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921,
        ]
        std = [
            2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
            3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160,
        ]
        self.mean = torch.tensor(mean, dtype=dtype, device=device)
        self.inv_std = 1.0 / torch.tensor(std, dtype=dtype, device=device)
        self.scale = [self.mean, self.inv_std]

        # (height, width, world_size) -> (world_size_h, world_size_w)
        self.grid_table = {
            # world_size = 2
            (60, 104, 2): (1, 2), (68, 120, 2): (1, 2), (90, 160, 2): (1, 2),
            (60, 60, 2): (1, 2), (72, 72, 2): (1, 2), (88, 88, 2): (1, 2), (120, 120, 2): (1, 2),
            (104, 60, 2): (2, 1), (120, 68, 2): (2, 1), (160, 90, 2): (2, 1),
            # world_size = 4
            (60, 104, 4): (2, 2), (68, 120, 4): (2, 2), (90, 160, 4): (2, 2),
            (60, 60, 4): (2, 2), (72, 72, 4): (2, 2), (88, 88, 4): (2, 2), (120, 120, 4): (2, 2),
            (104, 60, 4): (2, 2), (120, 68, 4): (2, 2), (160, 90, 4): (2, 2),
            # world_size = 8
            (60, 104, 8): (2, 4), (68, 120, 8): (2, 4), (90, 160, 8): (2, 4),
            (60, 60, 8): (2, 4), (72, 72, 8): (2, 4), (88, 88, 8): (2, 4), (120, 120, 8): (2, 4),
            (104, 60, 8): (4, 2), (120, 68, 8): (4, 2), (160, 90, 8): (4, 2),
        }

        # init model
        self.model = (
            _video_vae(pretrained_path=vae_path, z_dim=z_dim, dtype=dtype)
            .eval()
            .requires_grad_(False)
            .to(device)
            .to(dtype)
        )

    def _calculate_2d_grid(self, latent_height, latent_width, world_size):
        if (latent_height, latent_width, world_size) in self.grid_table:
            best_h, best_w = self.grid_table[(latent_height, latent_width, world_size)]
            # logger.info(f"Vae using cached 2D grid: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
            return best_h, best_w

        best_h, best_w = 1, world_size
        min_aspect_diff = float("inf")

        for h in range(1, world_size + 1):
            if world_size % h == 0:
                w = world_size // h
                if latent_height % h == 0 and latent_width % w == 0:
                    # Calculate how close this grid is to square
                    aspect_diff = abs((latent_height / h) - (latent_width / w))
                    if aspect_diff < min_aspect_diff:
                        min_aspect_diff = aspect_diff
                        best_h, best_w = h, w
        # logger.info(f"Vae using 2D grid & Update cache: {best_h}x{best_w} grid for {latent_height}x{latent_width} latent")
        self.grid_table[(latent_height, latent_width, world_size)] = (best_h, best_w)
        return best_h, best_w

    def current_device(self):
        return next(self.model.parameters()).device

    def encode_dist(self, video, world_size, cur_rank, split_dim):
        spatial_ratio = 8

        if split_dim == 3:
            total_latent_len = video.shape[3] // spatial_ratio
        elif split_dim == 4:
            total_latent_len = video.shape[4] // spatial_ratio
        else:
            raise ValueError(f"Unsupported split_dim: {split_dim}")

        splited_chunk_len = total_latent_len // world_size
        padding_size = 1

        video_chunk_len = splited_chunk_len * spatial_ratio
        video_padding_len = padding_size * spatial_ratio

        if cur_rank == 0:
            if split_dim == 3:
                video_chunk = video[:, :, :, : video_chunk_len + 2 * video_padding_len, :].contiguous()
            elif split_dim == 4:
                video_chunk = video[:, :, :, :, : video_chunk_len + 2 * video_padding_len].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 3:
                video_chunk = video[:, :, :, -(video_chunk_len + 2 * video_padding_len) :, :].contiguous()
            elif split_dim == 4:
                video_chunk = video[:, :, :, :, -(video_chunk_len + 2 * video_padding_len) :].contiguous()
        else:
            start_idx = cur_rank * video_chunk_len - video_padding_len
            end_idx = (cur_rank + 1) * video_chunk_len + video_padding_len
            if split_dim == 3:
                video_chunk = video[:, :, :, start_idx:end_idx, :].contiguous()
            elif split_dim == 4:
                video_chunk = video[:, :, :, :, start_idx:end_idx].contiguous()

        if self.use_tiling:
            encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
        else:
            encoded_chunk = self.model.encode(video_chunk, self.scale)

        if cur_rank == 0:
            if split_dim == 3:
                encoded_chunk = encoded_chunk[:, :, :, :splited_chunk_len, :].contiguous()
            elif split_dim == 4:
                encoded_chunk = encoded_chunk[:, :, :, :, :splited_chunk_len].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 3:
                encoded_chunk = encoded_chunk[:, :, :, -splited_chunk_len:, :].contiguous()
            elif split_dim == 4:
                encoded_chunk = encoded_chunk[:, :, :, :, -splited_chunk_len:].contiguous()
        else:
            if split_dim == 3:
                encoded_chunk = encoded_chunk[:, :, :, padding_size:-padding_size, :].contiguous()
            elif split_dim == 4:
                encoded_chunk = encoded_chunk[:, :, :, :, padding_size:-padding_size].contiguous()

        full_encoded = [torch.empty_like(encoded_chunk) for _ in range(world_size)]
        dist.all_gather(full_encoded, encoded_chunk)

        torch.cuda.synchronize()

        encoded = torch.cat(full_encoded, dim=split_dim)

        return encoded.squeeze(0)

    def encode_dist_2d(self, video, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
        spatial_ratio = 8

        # Calculate chunk sizes for both dimensions
        total_latent_h = video.shape[3] // spatial_ratio
        total_latent_w = video.shape[4] // spatial_ratio

        chunk_h = total_latent_h // world_size_h
        chunk_w = total_latent_w // world_size_w

        padding_size = 1
        video_chunk_h = chunk_h * spatial_ratio
        video_chunk_w = chunk_w * spatial_ratio
        video_padding_h = padding_size * spatial_ratio
        video_padding_w = padding_size * spatial_ratio

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = video_chunk_h + 2 * video_padding_h
        elif cur_rank_h == world_size_h - 1:
            h_start = video.shape[3] - (video_chunk_h + 2 * video_padding_h)
            h_end = video.shape[3]
        else:
            h_start = cur_rank_h * video_chunk_h - video_padding_h
            h_end = (cur_rank_h + 1) * video_chunk_h + video_padding_h

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = video_chunk_w + 2 * video_padding_w
        elif cur_rank_w == world_size_w - 1:
            w_start = video.shape[4] - (video_chunk_w + 2 * video_padding_w)
            w_end = video.shape[4]
        else:
            w_start = cur_rank_w * video_chunk_w - video_padding_w
            w_end = (cur_rank_w + 1) * video_chunk_w + video_padding_w

        # Extract the video chunk for this process
        video_chunk = video[:, :, :, h_start:h_end, w_start:w_end].contiguous()

        # Encode the chunk
        if self.use_tiling:
            encoded_chunk = self.model.tiled_encode(video_chunk, self.scale)
        else:
            encoded_chunk = self.model.encode(video_chunk, self.scale)

        # Remove padding from encoded chunk
        if cur_rank_h == 0:
            encoded_h_start = 0
            encoded_h_end = chunk_h
        elif cur_rank_h == world_size_h - 1:
            encoded_h_start = encoded_chunk.shape[3] - chunk_h
            encoded_h_end = encoded_chunk.shape[3]
        else:
            encoded_h_start = padding_size
            encoded_h_end = encoded_chunk.shape[3] - padding_size

        if cur_rank_w == 0:
            encoded_w_start = 0
            encoded_w_end = chunk_w
        elif cur_rank_w == world_size_w - 1:
            encoded_w_start = encoded_chunk.shape[4] - chunk_w
            encoded_w_end = encoded_chunk.shape[4]
        else:
            encoded_w_start = padding_size
            encoded_w_end = encoded_chunk.shape[4] - padding_size

        encoded_chunk = encoded_chunk[:, :, :, encoded_h_start:encoded_h_end, encoded_w_start:encoded_w_end].contiguous()

        # Gather all chunks
        total_processes = world_size_h * world_size_w
        full_encoded = [torch.empty_like(encoded_chunk) for _ in range(total_processes)]

        dist.all_gather(full_encoded, encoded_chunk)

        torch.cuda.synchronize()

        # Reconstruct the full encoded tensor
        encoded_rows = []
        for h_idx in range(world_size_h):
            encoded_cols = []
            for w_idx in range(world_size_w):
                process_idx = h_idx * world_size_w + w_idx
                encoded_cols.append(full_encoded[process_idx])
            encoded_rows.append(torch.cat(encoded_cols, dim=4))

        encoded = torch.cat(encoded_rows, dim=3)

        return encoded.squeeze(0)

    def encode(self, video, world_size_h=None, world_size_w=None):
        """
        video: one video with shape [1, C, T, H, W].
        """
        if self.parallel:
            world_size = dist.get_world_size()
            cur_rank = dist.get_rank()
            height, width = video.shape[3], video.shape[4]

            if self.use_2d_split:
                if world_size_h is None or world_size_w is None:
                    world_size_h, world_size_w = self._calculate_2d_grid(height // 8, width // 8, world_size)
                cur_rank_h = cur_rank // world_size_w
                cur_rank_w = cur_rank % world_size_w
                out = self.encode_dist_2d(video, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
            else:
                # Original 1D splitting logic
                if width % world_size == 0:
                    out = self.encode_dist(video, world_size, cur_rank, split_dim=4)
                elif height % world_size == 0:
                    out = self.encode_dist(video, world_size, cur_rank, split_dim=3)
                else:
                    logger.info("Fall back to naive encode mode")
                    if self.use_tiling:
                        out = self.model.tiled_encode(video, self.scale).squeeze(0)
                    else:
                        out = self.model.encode(video, self.scale).squeeze(0)
        else:
            if self.use_tiling:
                out = self.model.tiled_encode(video, self.scale).squeeze(0)
            else:
                out = self.model.encode(video, self.scale).squeeze(0)

        return out

    def decode_dist(self, zs, world_size, cur_rank, split_dim):
        splited_total_len = zs.shape[split_dim]
        splited_chunk_len = splited_total_len // world_size
        padding_size = 1

        if cur_rank == 0:
            if split_dim == 2:
                zs = zs[:, :, : splited_chunk_len + 2 * padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, : splited_chunk_len + 2 * padding_size].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                zs = zs[:, :, -(splited_chunk_len + 2 * padding_size) :, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, -(splited_chunk_len + 2 * padding_size) :].contiguous()
        else:
            if split_dim == 2:
                zs = zs[:, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size, :].contiguous()
            elif split_dim == 3:
                zs = zs[:, :, :, cur_rank * splited_chunk_len - padding_size : (cur_rank + 1) * splited_chunk_len + padding_size].contiguous()

        decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
        images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)

        if cur_rank == 0:
            if split_dim == 2:
                images = images[:, :, :, : splited_chunk_len * 8, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, : splited_chunk_len * 8].contiguous()
        elif cur_rank == world_size - 1:
            if split_dim == 2:
                images = images[:, :, :, -splited_chunk_len * 8 :, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, -splited_chunk_len * 8 :].contiguous()
        else:
            if split_dim == 2:
                images = images[:, :, :, 8 * padding_size : -8 * padding_size, :].contiguous()
            elif split_dim == 3:
                images = images[:, :, :, :, 8 * padding_size : -8 * padding_size].contiguous()

        full_images = [torch.empty_like(images) for _ in range(world_size)]
        dist.all_gather(full_images, images)

        torch.cuda.synchronize()

        images = torch.cat(full_images, dim=split_dim + 1)

        return images

    def decode_dist_2d(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
        total_h = zs.shape[2]
        total_w = zs.shape[3]

        chunk_h = total_h // world_size_h
        chunk_w = total_w // world_size_w

        padding_size = 2

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = chunk_h + 2 * padding_size
        elif cur_rank_h == world_size_h - 1:
            h_start = total_h - (chunk_h + 2 * padding_size)
            h_end = total_h
        else:
            h_start = cur_rank_h * chunk_h - padding_size
            h_end = (cur_rank_h + 1) * chunk_h + padding_size

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = chunk_w + 2 * padding_size
        elif cur_rank_w == world_size_w - 1:
            w_start = total_w - (chunk_w + 2 * padding_size)
            w_end = total_w
        else:
            w_start = cur_rank_w * chunk_w - padding_size
            w_end = (cur_rank_w + 1) * chunk_w + padding_size

        # Extract the latent chunk for this process
        zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()

        # Decode the chunk
        decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
        images_chunk = decode_func(zs_chunk.unsqueeze(0), self.scale).clamp_(-1, 1)

        # Remove padding from decoded chunk
        spatial_ratio = 8
        if cur_rank_h == 0:
            decoded_h_start = 0
            decoded_h_end = chunk_h * spatial_ratio
        elif cur_rank_h == world_size_h - 1:
            decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
            decoded_h_end = images_chunk.shape[3]
        else:
            decoded_h_start = padding_size * spatial_ratio
            decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

        if cur_rank_w == 0:
            decoded_w_start = 0
            decoded_w_end = chunk_w * spatial_ratio
        elif cur_rank_w == world_size_w - 1:
            decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
            decoded_w_end = images_chunk.shape[4]
        else:
            decoded_w_start = padding_size * spatial_ratio
            decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

        images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

        # Gather all chunks
        total_processes = world_size_h * world_size_w
        full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]

        dist.all_gather(full_images, images_chunk)

        torch.cuda.synchronize()

        # Reconstruct the full image tensor
        image_rows = []
        for h_idx in range(world_size_h):
            image_cols = []
            for w_idx in range(world_size_w):
                process_idx = h_idx * world_size_w + w_idx
                image_cols.append(full_images[process_idx])
            image_rows.append(torch.cat(image_cols, dim=4))

        images = torch.cat(image_rows, dim=3)

        return images

    def decode_dist_2d_stream(self, zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
        total_h = zs.shape[2]
        total_w = zs.shape[3]

        chunk_h = total_h // world_size_h
        chunk_w = total_w // world_size_w

        padding_size = 2

        # Calculate H dimension slice
        if cur_rank_h == 0:
            h_start = 0
            h_end = chunk_h + 2 * padding_size
        elif cur_rank_h == world_size_h - 1:
            h_start = total_h - (chunk_h + 2 * padding_size)
            h_end = total_h
        else:
            h_start = cur_rank_h * chunk_h - padding_size
            h_end = (cur_rank_h + 1) * chunk_h + padding_size

        # Calculate W dimension slice
        if cur_rank_w == 0:
            w_start = 0
            w_end = chunk_w + 2 * padding_size
        elif cur_rank_w == world_size_w - 1:
            w_start = total_w - (chunk_w + 2 * padding_size)
            w_end = total_w
        else:
            w_start = cur_rank_w * chunk_w - padding_size
            w_end = (cur_rank_w + 1) * chunk_w + padding_size

        # Extract the latent chunk for this process
        zs_chunk = zs[:, :, h_start:h_end, w_start:w_end].contiguous()

        for image in self.model.decode_stream(zs_chunk.unsqueeze(0), self.scale):
            images_chunk = image.clamp_(-1, 1)
            # Remove padding from decoded chunk
            spatial_ratio = 8
            if cur_rank_h == 0:
                decoded_h_start = 0
                decoded_h_end = chunk_h * spatial_ratio
            elif cur_rank_h == world_size_h - 1:
                decoded_h_start = images_chunk.shape[3] - chunk_h * spatial_ratio
                decoded_h_end = images_chunk.shape[3]
            else:
                decoded_h_start = padding_size * spatial_ratio
                decoded_h_end = images_chunk.shape[3] - padding_size * spatial_ratio

            if cur_rank_w == 0:
                decoded_w_start = 0
                decoded_w_end = chunk_w * spatial_ratio
            elif cur_rank_w == world_size_w - 1:
                decoded_w_start = images_chunk.shape[4] - chunk_w * spatial_ratio
                decoded_w_end = images_chunk.shape[4]
            else:
                decoded_w_start = padding_size * spatial_ratio
                decoded_w_end = images_chunk.shape[4] - padding_size * spatial_ratio

            images_chunk = images_chunk[:, :, :, decoded_h_start:decoded_h_end, decoded_w_start:decoded_w_end].contiguous()

            # Gather all chunks
            total_processes = world_size_h * world_size_w
            full_images = [torch.empty_like(images_chunk) for _ in range(total_processes)]

            dist.all_gather(full_images, images_chunk)

            torch.cuda.synchronize()

            # Reconstruct the full image tensor
            image_rows = []
            for h_idx in range(world_size_h):
                image_cols = []
                for w_idx in range(world_size_w):
                    process_idx = h_idx * world_size_w + w_idx
                    image_cols.append(full_images[process_idx])
                image_rows.append(torch.cat(image_cols, dim=4))

            images = torch.cat(image_rows, dim=3)

            yield images

    def decode(self, zs):
        if self.parallel:
            world_size = dist.get_world_size()
            cur_rank = dist.get_rank()
            latent_height, latent_width = zs.shape[2], zs.shape[3]

            if self.use_2d_split:
                world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size)
                cur_rank_h = cur_rank // world_size_w
                cur_rank_w = cur_rank % world_size_w
                images = self.decode_dist_2d(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w)
            else:
                # Original 1D splitting logic
                if latent_width % world_size == 0:
                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=3)
                elif latent_height % world_size == 0:
                    images = self.decode_dist(zs, world_size, cur_rank, split_dim=2)
                else:
                    logger.info("Fall back to naive decode mode")
                    images = self.model.decode(zs.unsqueeze(0), self.scale).clamp_(-1, 1)
        else:
            decode_func = self.model.tiled_decode if self.use_tiling else self.model.decode
            images = decode_func(zs.unsqueeze(0), self.scale).clamp_(-1, 1)

        return images

    def decode_stream(self, zs):
        if self.parallel:
            world_size = dist.get_world_size()
            cur_rank = dist.get_rank()
            latent_height, latent_width = zs.shape[2], zs.shape[3]

            world_size_h, world_size_w = self._calculate_2d_grid(latent_height, latent_width, world_size)
            cur_rank_h = cur_rank // world_size_w
            cur_rank_w = cur_rank % world_size_w
            for images in self.decode_dist_2d_stream(zs, world_size_h, world_size_w, cur_rank_h, cur_rank_w):
                yield images
        else:
            for image in self.model.decode_stream(zs.unsqueeze(0), self.scale):
                yield image.clamp_(-1, 1)

    def encode_video(self, vid):
        return self.model.encode_video(vid)

    def decode_video(self, vid_enc):
        return self.model.decode_video(vid_enc)
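For orientation, here is a minimal, hypothetical sketch of driving the WanVAE wrapper above on a single GPU with its default non-parallel, non-tiled settings. The checkpoint path, resolution, and frame count are illustrative assumptions (not taken from this repository's configs), and the import path assumes the repository root is on PYTHONPATH.

# Hypothetical single-GPU round trip through the WanVAE wrapper; values below are assumptions.
import torch
from flash_head.wan.modules.vae import WanVAE

vae = WanVAE(z_dim=16, vae_path="cache/vae_step_411000.pth", device="cuda")
# [1, C, T, H, W] video in [-1, 1]; T of the form 4k + 1, H and W divisible by 8.
video = torch.rand(1, 3, 17, 480, 832, device="cuda") * 2 - 1
latent = vae.encode(video)    # normalized latent, roughly [16, 1 + (T - 1) // 4, H // 8, W // 8]
frames = vae.decode(latent)   # pixel-space frames clamped to [-1, 1], shape [1, 3, T, H, W]
print(latent.shape, frames.shape)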
generate_video.py
ADDED
@@ -0,0 +1,218 @@
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import argparse
import os
import numpy as np
import time
import torch
import torch.distributed as dist
import subprocess
import imageio
import librosa
from loguru import logger
from collections import deque
from datetime import datetime

from flash_head.inference import get_pipeline, get_base_data, get_infer_params, get_audio_embedding, run_pipeline


def _validate_args(args):
    # Basic check
    assert args.ckpt_dir is not None, "Please specify FlashHead model checkpoint directory."
    assert args.wav2vec_dir is not None, "Please specify the wav2vec checkpoint directory."
    assert args.model_type == "pro" or args.model_type == "lite", "Please specify the model name (pro, lite)."
    assert args.cond_image_dir is not None or args.cond_image is not None, "Please specify the condition image or directory."
    assert args.audio_path is not None, "Please specify the audio path."

    args.base_seed = args.base_seed if args.base_seed >= 0 else 42


def _parse_args():
    parser = argparse.ArgumentParser(description="Generate video from one image using FlashHead")
    parser.add_argument("--ckpt_dir", type=str, default=None, help="The path to FlashHead model checkpoint directory.")
    parser.add_argument("--wav2vec_dir", type=str, default=None, help="The path to the wav2vec checkpoint directory.")
    parser.add_argument("--model_type", type=str, default=None, help="Choose from pro or lite.")
    parser.add_argument("--save_file", type=str, default=None, help="The file to save the generated video to.")
    parser.add_argument("--base_seed", type=int, default=42, help="The seed to use for generating the video.")
    parser.add_argument("--cond_image", type=str, default=None, help="[meta file] The condition image path to generate the video.")
    parser.add_argument("--cond_image_dir", type=str, default=None, help="[meta directory] The directory of condition images.")
    parser.add_argument("--audio_path", type=str, default=None, help="[meta file] The audio path to generate the video.")
    parser.add_argument("--audio_encode_mode", type=str, default="stream", choices=['stream', 'once'],
                        help="stream: encode audio chunk before every generation; once: encode audio together")
    parser.add_argument("--use_face_crop", type=bool, default=False, help="Enable face detection and crop for condition image")
    args = parser.parse_args()

    _validate_args(args)

    return args


def save_video(frames_list, video_path, audio_path, fps):
    temp_video_path = video_path.replace('.mp4', '_tmp.mp4')
    with imageio.get_writer(temp_video_path, format='mp4', mode='I',
                            fps=fps, codec='h264', ffmpeg_params=['-bf', '0']) as writer:
        for frames in frames_list:
            frames = frames.numpy().astype(np.uint8)
            for i in range(frames.shape[0]):
                frame = frames[i, :, :, :]
                writer.append_data(frame)

    # merge video and audio
    cmd = ['ffmpeg', '-i', temp_video_path, '-i', audio_path, '-c:v', 'copy', '-c:a', 'aac', '-shortest', video_path, '-y']
    subprocess.run(cmd)
    os.remove(temp_video_path)


def generate(args):
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    rank = int(os.environ.get("RANK", 0))

    pipeline = get_pipeline(world_size=world_size, ckpt_dir=args.ckpt_dir, wav2vec_dir=args.wav2vec_dir, model_type=args.model_type)
    get_base_data(pipeline, cond_image_path_or_dir=args.cond_image_dir if args.cond_image_dir is not None else args.cond_image, base_seed=args.base_seed, use_face_crop=args.use_face_crop)
    infer_params = get_infer_params()

    sample_rate = infer_params['sample_rate']
    tgt_fps = infer_params['tgt_fps']
    cached_audio_duration = infer_params['cached_audio_duration']
    frame_num = infer_params['frame_num']
    motion_frames_num = infer_params['motion_frames_num']
    slice_len = frame_num - motion_frames_num

    human_speech_array_all, _ = librosa.load(args.audio_path, sr=infer_params['sample_rate'], mono=True)
    human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
    human_speech_array_frame_num = frame_num * sample_rate // tgt_fps

    if rank == 0:
        logger.info("Data preparation done. Start to generate video...")

    generated_list = []
    if args.audio_encode_mode == 'once':
        # pad audio with silence to avoid truncating the last chunk
        remainder = (len(human_speech_array_all) - human_speech_array_frame_num) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])

        # encode audio together
        audio_embedding_all = get_audio_embedding(pipeline, human_speech_array_all)

        # split audio embedding into chunks
        # for Pro model: 33, 28, 28, 28, ...; For Lite model: 33, 24, 24, 24, ...
        audio_embedding_chunks_list = [audio_embedding_all[:, i * slice_len: i * slice_len + frame_num].contiguous() for i in range((audio_embedding_all.shape[1] - frame_num) // slice_len)]

        for chunk_idx, audio_embedding_chunk in enumerate(audio_embedding_chunks_list):
            torch.cuda.synchronize()
            start_time = time.time()

            # inference
            video = run_pipeline(pipeline, audio_embedding_chunk)

            if chunk_idx != 0:
                video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            if rank == 0:
                logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")

            generated_list.append(video.cpu())

    elif args.audio_encode_mode == 'stream':
        cached_audio_length_sum = sample_rate * cached_audio_duration
        audio_end_idx = cached_audio_duration * tgt_fps
        audio_start_idx = audio_end_idx - frame_num

        audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)

        # pad audio with silence to avoid truncating the last chunk
        remainder = len(human_speech_array_all) % human_speech_array_slice_len
        if remainder > 0:
            pad_length = human_speech_array_slice_len - remainder
            human_speech_array_all = np.concatenate([human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)])

        # split audio embedding into chunks
        # for Pro model: 28, 28, 28, 28, ...; For Lite model: 24, 24, 24, 24, ...
        human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)

        for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
            torch.cuda.synchronize()
            start_time = time.time()

            # streaming encode audio chunks
            audio_dq.extend(human_speech_array.tolist())
            audio_array = np.array(audio_dq)
            audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)

            # inference
            video = run_pipeline(pipeline, audio_embedding)
            video = video[motion_frames_num:]

            torch.cuda.synchronize()
            end_time = time.time()
            if rank == 0:
                logger.info(f"Generate video chunk-{chunk_idx} done, cost time: {(end_time - start_time):.3f}s")

            generated_list.append(video.cpu())

    if rank == 0:
        if args.save_file is None:
            output_dir = 'sample_results'
            if not os.path.exists(output_dir):
                os.makedirs(output_dir)
            timestamp = datetime.now().strftime("%Y%m%d-%H:%M:%S-%f")[:-3]
            filename = f"res_{timestamp}.mp4"
            filepath = os.path.join(output_dir, filename)
            args.save_file = filepath

        save_video(generated_list, args.save_file, args.audio_path, fps=tgt_fps)
        logger.info(f"Saving generated video to {args.save_file}")
        logger.info("Finished.")

    if world_size > 1:
        dist.barrier()
        dist.destroy_process_group()


if __name__ == "__main__":
    args = _parse_args()
    generate(args)
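To make the chunking arithmetic in generate_video.py concrete: each generation step sees frame_num frames of audio context but contributes only slice_len = frame_num - motion_frames_num new frames, so every stream-mode audio slice is slice_len * sample_rate // tgt_fps samples long. The worked example below uses assumed numbers chosen to match the "33, 28, 28, ..." Pro-model comment in the script; the real values come from get_infer_params() at runtime.

# Worked example of the chunk bookkeeping (illustrative assumptions only).
sample_rate = 16000          # audio sample rate in Hz (assumed)
tgt_fps = 25                 # output video frame rate (assumed)
frame_num = 33               # frames produced per forward pass (Pro-model figure from the comment above)
motion_frames_num = 5        # frames carried over from the previous chunk (implied by 33 - 28)
slice_len = frame_num - motion_frames_num                 # 28 new frames per chunk
samples_per_slice = slice_len * sample_rate // tgt_fps    # 28 * 16000 // 25 = 17920 audio samples
seconds_per_slice = slice_len / tgt_fps                   # 1.12 s of new video per chunk
print(slice_len, samples_per_slice, seconds_per_slice)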
gradio_app_streaming.py
ADDED
@@ -0,0 +1,339 @@
| 1 |
+
"""
|
| 2 |
+
Gradio 流式视频生成:视频生成&视频保存异步进行,确保实时性
|
| 3 |
+
"""
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import os
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import time
|
| 9 |
+
import wave
|
| 10 |
+
import imageio
|
| 11 |
+
import librosa
|
| 12 |
+
import subprocess
|
| 13 |
+
import queue
|
| 14 |
+
import threading
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from collections import deque
|
| 17 |
+
from loguru import logger
|
| 18 |
+
|
| 19 |
+
from flash_head.inference import (
|
| 20 |
+
get_pipeline,
|
| 21 |
+
get_base_data,
|
| 22 |
+
get_infer_params,
|
| 23 |
+
get_audio_embedding,
|
| 24 |
+
run_pipeline,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
# gr.Video 的 streaming=True 要求视频片段大于1s,实际需要接近3s才能不卡顿。
|
| 28 |
+
# 为了适配,每 3 个 chunk 合并为一段视频
|
| 29 |
+
CHUNKS_PER_SEGMENT = 3
|
| 30 |
+
|
| 31 |
+
pipeline = None
|
| 32 |
+
loaded_ckpt_dir = None
|
| 33 |
+
loaded_wav2vec_dir = None
|
| 34 |
+
loaded_model_type = None
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _write_frames_to_mp4(frames_list, video_path, fps):
|
| 38 |
+
"""将帧列表写入 MP4(仅视频轨)。"""
|
| 39 |
+
os.makedirs(os.path.dirname(video_path) or ".", exist_ok=True)
|
| 40 |
+
with imageio.get_writer(
|
| 41 |
+
video_path,
|
| 42 |
+
format="mp4",
|
| 43 |
+
mode="I",
|
| 44 |
+
fps=fps,
|
| 45 |
+
codec="h264",
|
| 46 |
+
ffmpeg_params=["-bf", "0"],
|
| 47 |
+
) as writer:
|
| 48 |
+
for frames in frames_list:
|
| 49 |
+
frames_np = frames.numpy().astype(np.uint8)
|
| 50 |
+
for i in range(frames_np.shape[0]):
|
| 51 |
+
writer.append_data(frames_np[i, :, :, :])
|
| 52 |
+
return video_path
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def save_video_with_audio(frames_list, video_path, audio_path, fps):
|
| 56 |
+
"""写入完整视频并混入完整音频(-shortest 保证音画同步,yuv420p + faststart 保证浏览器可播)。"""
|
| 57 |
+
temp_path = video_path.replace(".mp4", "_temp.mp4")
|
| 58 |
+
_write_frames_to_mp4(frames_list, temp_path, fps)
|
| 59 |
+
try:
|
| 60 |
+
cmd = [
|
| 61 |
+
"ffmpeg", "-y",
|
| 62 |
+
"-i", temp_path,
|
| 63 |
+
"-i", audio_path,
|
| 64 |
+
"-c:v", "copy",
|
| 65 |
+
"-c:a", "aac",
|
| 66 |
+
# "-shortest",
|
| 67 |
+
video_path,
|
| 68 |
+
]
|
| 69 |
+
subprocess.run(cmd, check=True, capture_output=True)
|
| 70 |
+
finally:
|
| 71 |
+
if os.path.exists(temp_path):
|
| 72 |
+
os.remove(temp_path)
|
| 73 |
+
return video_path
|
| 74 |
+
|
| 75 |
+
def _save_chunk_audio_to_wav(audio_array, wav_path, sample_rate=16000):
|
| 76 |
+
"""将一段 float32 [-1,1] 的音频数组保存为 wav 文件。"""
|
| 77 |
+
os.makedirs(os.path.dirname(wav_path) or ".", exist_ok=True)
|
| 78 |
+
samples = (np.clip(audio_array, -1.0, 1.0) * 32767).astype(np.int16)
|
| 79 |
+
with wave.open(wav_path, "wb") as wav_file:
|
| 80 |
+
wav_file.setnchannels(1)
|
| 81 |
+
wav_file.setsampwidth(2)
|
| 82 |
+
wav_file.setframerate(sample_rate)
|
| 83 |
+
wav_file.writeframes(samples.tobytes())
|
| 84 |
+
return wav_path
|
| 85 |
+


def run_inference_streaming(
    ckpt_dir,
    wav2vec_dir,
    model_type,
    cond_image,
    audio_path,
    seed,
    use_face_crop,
    progress=gr.Progress(),
):
    """
    Streaming inference: the main loop watches res_queue, saving and yielding a video
    segment whenever new frames arrive; inference runs in a separate worker thread that
    processes the chunks in order and pushes each result into res_queue.
    """
    global pipeline, loaded_ckpt_dir, loaded_wav2vec_dir, loaded_model_type

    if (
        pipeline is None
        or loaded_ckpt_dir != ckpt_dir
        or loaded_wav2vec_dir != wav2vec_dir
        or loaded_model_type != model_type
    ):
        progress(0.2, desc="Loading Model...")
        logger.info(f"Loading pipeline with ckpt_dir={ckpt_dir}, wav2vec_dir={wav2vec_dir}")
        try:
            pipeline = get_pipeline(
                world_size=1,
                ckpt_dir=ckpt_dir,
                model_type=model_type,
                wav2vec_dir=wav2vec_dir,
            )
            loaded_ckpt_dir = ckpt_dir
            loaded_wav2vec_dir = wav2vec_dir
            loaded_model_type = model_type
        except Exception as e:
            logger.error(f"Failed to load model: {e}")
            raise gr.Error(f"Failed to load model: {e}")

    progress(0.5, desc="Preparing Data...")
    base_seed = int(seed) if seed >= 0 else 9999
    try:
        get_base_data(
            pipeline,
            cond_image_path_or_dir=cond_image,
            base_seed=base_seed,
            use_face_crop=use_face_crop,
        )
    except Exception as e:
        logger.error(f"Error in get_base_data: {e}")
        raise gr.Error(f"Error processing inputs: {e}")

    infer_params = get_infer_params()
    sample_rate = infer_params["sample_rate"]
    tgt_fps = infer_params["tgt_fps"]
    cached_audio_duration = infer_params["cached_audio_duration"]
    frame_num = infer_params["frame_num"]
    motion_frames_num = infer_params["motion_frames_num"]
    slice_len = frame_num - motion_frames_num

    try:
        human_speech_array_all, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
    except Exception as e:
        raise gr.Error(f"Failed to load audio file: {e}")

    human_speech_array_slice_len = slice_len * sample_rate // tgt_fps
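    # Worked example of the slice arithmetic (illustrative values, not the actual
    # config): with slice_len = 20 new frames per chunk, sample_rate = 16000 and
    # tgt_fps = 25, each chunk consumes 20 * 16000 // 25 = 12800 audio samples (0.8 s).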

    stream_dir = os.path.join("gradio_results", "stream_preview")
    os.makedirs(stream_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3]
    accumulated = []

    # Streaming mode is used by default: prepare the per-chunk audio slices.
    cached_audio_length_sum = sample_rate * cached_audio_duration
    audio_end_idx = cached_audio_duration * tgt_fps
    audio_start_idx = audio_end_idx - frame_num
    remainder = len(human_speech_array_all) % human_speech_array_slice_len
    if remainder > 0:
        pad_length = human_speech_array_slice_len - remainder
        human_speech_array_all = np.concatenate(
            [human_speech_array_all, np.zeros(pad_length, dtype=human_speech_array_all.dtype)]
        )
    human_speech_array_slices = human_speech_array_all.reshape(-1, human_speech_array_slice_len)
    total_chunks = len(human_speech_array_slices)
    if total_chunks == 0:
        raise gr.Error("Audio too short: no chunks to generate. Please use a longer audio.")

    # Data preparation: merge every CHUNKS_PER_SEGMENT chunks into one WAV file,
    # named with the timestamp plus segment_id.
    segment_audio_paths = {}
    num_segments = (total_chunks + CHUNKS_PER_SEGMENT - 1) // CHUNKS_PER_SEGMENT
    for segment_id in range(num_segments):
        start = segment_id * CHUNKS_PER_SEGMENT
        end = min(start + CHUNKS_PER_SEGMENT, total_chunks)
        audio_concat = np.concatenate(
            [human_speech_array_slices[i] for i in range(start, end)]
        )
        segment_audio_name = f"audio_{timestamp}_seg_{segment_id:04d}.wav"
        segment_audio_path = os.path.join(stream_dir, segment_audio_name)
        _save_chunk_audio_to_wav(
            audio_concat,
            segment_audio_path,
            sample_rate=sample_rate,
        )
        segment_audio_paths[segment_id] = segment_audio_path
    logger.info(
        f"Pre-saved {num_segments} segment audios (every {CHUNKS_PER_SEGMENT} chunks) under {stream_dir}"
    )
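
    # Pre-slicing the audio per segment means each streamed MP4 carries exactly the
    # audio that matches its frames, so the browser-side preview stays in sync without
    # having to cut the original audio file again at mux time.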

    # Result queue: the inference thread puts (chunk_idx, chunk_frames_np); the main
    # thread looks up the matching segment audio by chunk id and muxes them together.
    res_queue = queue.Queue()

    def inference_worker():
        """Worker thread: run inference chunk by chunk; push each generated chunk into
        res_queue and immediately move on to the next one."""
        audio_dq = deque([0.0] * cached_audio_length_sum, maxlen=cached_audio_length_sum)
        for chunk_idx, human_speech_array in enumerate(human_speech_array_slices):
            audio_dq.extend(human_speech_array.tolist())
            audio_array = np.array(audio_dq)
            audio_embedding = get_audio_embedding(pipeline, audio_array, audio_start_idx, audio_end_idx)
            torch.cuda.synchronize()
            start_time = time.time()
            video = run_pipeline(pipeline, audio_embedding)
            video = video[motion_frames_num:]
            torch.cuda.synchronize()
            logger.info(f"Infer chunk-{chunk_idx} done, cost time: {time.time() - start_time:.2f}s")
            chunk_frames_np = video.cpu().numpy()
            res_queue.put((chunk_idx, chunk_frames_np))
        res_queue.put(None)  # end-of-stream sentinel
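
    # The deque acts as a fixed-length sliding window holding the most recent
    # cached_audio_duration seconds of audio; audio_start_idx / audio_end_idx pick the
    # frame-index window inside it that conditions the current chunk. The None sentinel
    # tells the consumer loop below that the producer has finished.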

    worker_thread = threading.Thread(target=inference_worker)
    worker_thread.start()
    logger.info("Inference worker thread started. Main will consume res_queue and yield video paths.")

    # Main loop: watch res_queue; once CHUNKS_PER_SEGMENT chunks have accumulated,
    # merge them into one MP4 (with the matching segment audio) and yield it.
    frame_buffer = []
    while True:
        item = res_queue.get()
        if item is None:
            break
        chunk_idx, chunk_frames_np = item
        chunk_frames = torch.from_numpy(chunk_frames_np)
        accumulated.append(chunk_frames)
        frame_buffer.append(chunk_frames)
        if len(frame_buffer) == CHUNKS_PER_SEGMENT:
            segment_id = (chunk_idx + 1 - CHUNKS_PER_SEGMENT) // CHUNKS_PER_SEGMENT
            segment_audio_path = segment_audio_paths[segment_id]
            segment_path = os.path.join(
                stream_dir, f"preview_{timestamp}_seg_{segment_id:04d}.mp4"
            )
            save_video_with_audio(
                frame_buffer,
                segment_path,
                segment_audio_path,
                fps=tgt_fps,
            )
            logger.info(
                f"Saved segment-{segment_id} (chunks {segment_id * CHUNKS_PER_SEGMENT}-{chunk_idx}) and yielding to frontend."
            )
            yield os.path.abspath(segment_path)
            frame_buffer = []

    # Merge any remaining chunks (fewer than CHUNKS_PER_SEGMENT) into a final segment.
    if frame_buffer:
        segment_id = num_segments - 1
        segment_audio_path = segment_audio_paths[segment_id]
        segment_path = os.path.join(
            stream_dir, f"preview_{timestamp}_seg_{segment_id:04d}.mp4"
        )
        save_video_with_audio(
            frame_buffer,
            segment_path,
            segment_audio_path,
            fps=tgt_fps,
        )
        logger.info(
            f"Saved final segment-{segment_id} ({len(frame_buffer)} chunks) and yielding to frontend."
        )
        yield os.path.abspath(segment_path)

    worker_thread.join()

    if not accumulated:
        raise gr.Error("No video frames generated. Please check inputs and try again.")

    output_dir = "gradio_results"
    os.makedirs(output_dir, exist_ok=True)
    final_filename = f"res_{timestamp}.mp4"
    final_path = os.path.join(output_dir, final_filename)
    save_video_with_audio(accumulated, final_path, audio_path, fps=tgt_fps)
    logger.info(f"Saved to {final_path}")
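
    # The full-length result (all chunks, muxed with the original audio file) is also
    # written under gradio_results/ for offline use; it is saved but not yielded, so the
    # player keeps showing the streamed segments.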


# ---------- Gradio UI ----------
with gr.Blocks(title="SoulX-FlashHead Streaming Video Generation", theme=gr.themes.Soft()) as app:
    gr.Markdown("# ⚡ SoulX-FlashHead Streaming Video Generation")
    gr.Markdown("Upload an image and an audio clip; the video plays back as it is generated, with synchronized audio. Currently only a single GPU is supported.")

    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                gr.Markdown("### 🎬 Generation Inputs")
                with gr.Row():
                    cond_image_input = gr.Image(
                        label="Condition Image",
                        type="filepath",
                        value="examples/girl.png",
                        height=300,
                    )
                    audio_path_input = gr.Audio(
                        label="Audio Input",
                        type="filepath",
                        value="examples/podcast_sichuan_16k.wav",
                    )
                generate_btn = gr.Button("🚀 Generate Video (Streaming)", variant="primary", size="lg")
            with gr.Accordion("⚙️ Advanced Settings", open=False):
                ckpt_dir_input = gr.Textbox(
                    label="FlashHead Checkpoint Directory",
                    value="models/SoulX-FlashHead-1_3B",
                )
                wav2vec_dir_input = gr.Textbox(
                    label="Wav2Vec Directory",
                    value="models/wav2vec2-base-960h",
                )
                model_type_input = gr.Dropdown(
                    label="Model Type",
                    choices=["pro", "lite"],
                    value="lite",
                )
                use_face_crop_input = gr.Checkbox(label="Use Face Crop", value=False)
                seed_input = gr.Number(label="Random Seed", value=9999, precision=0)
        with gr.Column(scale=1):
            gr.Markdown("### 📺 Output Video (updated as segments stream in)")
            video_output = gr.Video(
                label="Generated Video",
                height=512,
                format="mp4",
                streaming=True,
                autoplay=True,
            )

    generate_btn.click(
        fn=run_inference_streaming,
        inputs=[
            ckpt_dir_input,
            wav2vec_dir_input,
            model_type_input,
            cond_image_input,
            audio_path_input,
            seed_input,
            use_face_crop_input,
        ],
        outputs=video_output,
    )
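
    # run_inference_streaming is a generator, so each yielded file path is pushed to the
    # streaming gr.Video component as soon as its segment is ready, rather than waiting
    # for the whole video to finish.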

if __name__ == "__main__":
    app.launch()
requirements.txt
ADDED
@@ -0,0 +1,23 @@
torch==2.7.1
opencv-python>=4.12.0.88
opencv-python-headless>=4.12.0.88
diffusers>=0.34.0
transformers==4.57.3
tokenizers>=0.20.3
accelerate>=1.8.1
tqdm
imageio
easydict
ftfy
imageio-ffmpeg
scikit-image
loguru
gradio==5.50.0
xfuser>=0.4.3
pyloudnorm
decord
xformers==0.0.31
librosa
mediapipe==0.10.9
flask
huggingface_hub