svjack committed on
Commit b2331d4 · verified · 1 Parent(s): 061fd50

Upload folder using huggingface_hub

Files changed (47)
  1. .gitattributes +2 -0
  2. .gitignore +8 -0
  3. .ipynb_checkpoints/Untitled-checkpoint.ipynb +6 -0
  4. LICENSE +107 -0
  5. README.md +235 -0
  6. Untitled.ipynb +0 -0
  7. assets/images/ref.png +0 -0
  8. assets/poses/align/img_ref_video_dance.mp4 +0 -0
  9. assets/poses/align_demo/img_ref_video_dance.mp4 +3 -0
  10. assets/videos/dance.mp4 +3 -0
  11. configs/.ipynb_checkpoints/test_stage_2-checkpoint.yaml +21 -0
  12. configs/inference_v2.yaml +35 -0
  13. configs/test_stage_1.yaml +26 -0
  14. configs/test_stage_2.yaml +21 -0
  15. downloading_weights.py +38 -0
  16. draw_dwpose.py +112 -0
  17. musepose/__init__.py +0 -0
  18. musepose/dataset/dance_image.py +130 -0
  19. musepose/dataset/dance_video.py +150 -0
  20. musepose/models/attention.py +443 -0
  21. musepose/models/motion_module.py +388 -0
  22. musepose/models/mutual_self_attention.py +363 -0
  23. musepose/models/pose_guider.py +57 -0
  24. musepose/models/resnet.py +252 -0
  25. musepose/models/transformer_2d.py +395 -0
  26. musepose/models/transformer_3d.py +169 -0
  27. musepose/models/unet_2d_blocks.py +1074 -0
  28. musepose/models/unet_2d_condition.py +1307 -0
  29. musepose/models/unet_3d.py +675 -0
  30. musepose/models/unet_3d_blocks.py +871 -0
  31. musepose/pipelines/__init__.py +0 -0
  32. musepose/pipelines/context.py +76 -0
  33. musepose/pipelines/pipeline_pose2img.py +360 -0
  34. musepose/pipelines/pipeline_pose2vid.py +458 -0
  35. musepose/pipelines/pipeline_pose2vid_long.py +571 -0
  36. musepose/pipelines/utils.py +29 -0
  37. musepose/utils/util.py +133 -0
  38. pose/config/dwpose-l_384x288.py +257 -0
  39. pose/config/yolox_l_8xb8-300e_coco.py +245 -0
  40. pose/script/dwpose.py +143 -0
  41. pose/script/tool.py +130 -0
  42. pose/script/util.py +153 -0
  43. pose/script/wholebody.py +121 -0
  44. pose_align.py +556 -0
  45. requirements.txt +25 -0
  46. test_stage_1.py +192 -0
  47. test_stage_2.py +237 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/poses/align_demo/img_ref_video_dance.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ assets/videos/dance.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,8 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ .DS_Store
7
+ pretrained_weights
8
+ output
.ipynb_checkpoints/Untitled-checkpoint.ipynb ADDED
@@ -0,0 +1,6 @@
1
+ {
2
+ "cells": [],
3
+ "metadata": {},
4
+ "nbformat": 4,
5
+ "nbformat_minor": 5
6
+ }
LICENSE ADDED
@@ -0,0 +1,107 @@
1
+
2
+ MIT License
3
+
4
+ Copyright (c) 2024 Tencent Music Entertainment Group
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
23
+
24
+
25
+ Other dependencies and licenses:
26
+
27
+
28
+ Open Source Software Licensed under the MIT License:
29
+ --------------------------------------------------------------------
30
+ 1. sd-vae-ft-mse
31
+ Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main
32
+ License:MIT license
33
+ For details:https://choosealicense.com/licenses/mit/
34
+
35
+
36
+
37
+
38
+ Open Source Software Licensed under the Apache License Version 2.0:
39
+ --------------------------------------------------------------------
40
+ 1. DWpose
41
+ Files:https://huggingface.co/yzd-v/DWPose/tree/main
42
+ License:Apache-2.0
43
+ For details:https://choosealicense.com/licenses/apache-2.0/
44
+
45
+ 2. Moore-AnimateAnyone
46
+ Files:https://github.com/MooreThreads/Moore-AnimateAnyone
47
+ License:Apache-2.0
48
+ For details:https://github.com/MooreThreads/Moore-AnimateAnyone/blob/master/LICENSE
49
+
50
+ Terms of the Apache License Version 2.0:
51
+ --------------------------------------------------------------------
52
+ Apache License
53
+
54
+ Version 2.0, January 2004
55
+
56
+ http://www.apache.org/licenses/
57
+
58
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
59
+ 1. Definitions.
60
+
61
+ "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
62
+
63
+ "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
64
+
65
+ "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
66
+
67
+ "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
68
+
69
+ "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
70
+
71
+ "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
72
+
73
+ "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
74
+
75
+ "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
76
+
77
+ "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
78
+
79
+ "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
80
+
81
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
82
+
83
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
84
+
85
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
86
+
87
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
88
+
89
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
90
+
91
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
92
+
93
+ If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
94
+
95
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
96
+
97
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
98
+
99
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
100
+
101
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
102
+
103
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
104
+
105
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
106
+
107
+ END OF TERMS AND CONDITIONS
README.md ADDED
@@ -0,0 +1,235 @@
1
+ # MusePose
2
+
3
+ MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation.
4
+
5
+ Zhengyan Tong,
6
+ Chao Li,
7
+ Zhaokang Chen,
8
+ Bin Wu<sup>†</sup>,
9
+ Wenjiang Zhou
10
+ (<sup>†</sup>Corresponding Author, benbinwu@tencent.com)
11
+
12
+ Lyra Lab, Tencent Music Entertainment
13
+
14
+
15
+ **[github](https://github.com/TMElyralab/MusePose)** **[huggingface](https://huggingface.co/TMElyralab/MusePose)** **space (coming soon)** **Project (coming soon)** **Technical report (coming soon)**
16
+
17
+ [MusePose](https://github.com/TMElyralab/MusePose) is an image-to-video generation framework for virtual humans under control signals such as pose. The currently released model is an implementation of [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone), built by optimizing [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone).
18
+
19
+ `MusePose` is the last building block of **the Muse open-source series**. Together with [MuseV](https://github.com/TMElyralab/MuseV) and [MuseTalk](https://github.com/TMElyralab/MuseTalk), we hope the community can join us and march towards the vision where a virtual human can be generated end-to-end with native abilities of full-body movement and interaction. Please stay tuned for our next milestone!
20
+
21
+ We really appreciate [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their academic paper and [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) for their code base, which have significantly expedited the development of the AIGC community and [MusePose](https://github.com/TMElyralab/MusePose).
22
+
23
+ Update:
24
+ 1. We support [Comfyui-MusePose](https://github.com/TMElyralab/Comfyui-MusePose) now!
25
+
26
+ ## Recruitment
27
+ Join Lyra Lab, Tencent Music Entertainment!
28
+
29
+ We are currently seeking AIGC researchers, including interns, new graduates, and senior hires (internship, campus, and experienced recruitment).
30
+
31
+ Please find details in the following two links or contact zkangchen@tencent.com
32
+
33
+ - AI Researcher (https://join.tencentmusic.com/social/post-details/?id=13488, https://join.tencentmusic.com/social/post-details/?id=13502)
34
+
35
+ ## Overview
36
+ [MusePose](https://github.com/TMElyralab/MusePose) is a diffusion-based and pose-guided virtual human video generation framework.
37
+ Our main contributions can be summarized as follows:
38
+ 1. The released model can generate dance videos of the human character in a reference image under a given pose sequence. The result quality exceeds almost all current open-source models on the same task.
39
+ 2. We release the `pose align` algorithm so that users can align arbitrary dance videos to arbitrary reference images, which **SIGNIFICANTLY** improves inference performance and model usability.
40
+ 3. We have fixed several important bugs and made some improvements based on the code of [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone).
41
+
42
+ ## Demos
43
+ <table class="center">
44
+
45
+ <tr>
46
+ <td width=50% style="border: none">
47
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/bb52ca3e-8a5c-405a-8575-7ab42abca248" muted="false"></video>
48
+ </td>
49
+ <td width=50% style="border: none">
50
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/6667c9ae-8417-49a1-bbbb-fe1695404c23" muted="false"></video>
51
+ </td>
52
+ </tr>
53
+
54
+ <tr>
55
+ <td width=50% style="border: none">
56
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/7f7a3aaf-2720-4b50-8bca-3257acce4733" muted="false"></video>
57
+ </td>
58
+ <td width=50% style="border: none">
59
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/c56f7e9c-d94d-494e-88e6-62a4a3c1e016" muted="false"></video>
60
+ </td>
61
+ </tr>
62
+
63
+
64
+ <tr>
65
+ <td width=50% style="border: none">
66
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/00a9faec-2453-4834-ad1f-44eb0ec8247d" muted="false"></video>
67
+ </td>
68
+ <td width=50% style="border: none">
69
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/41ad26b3-d477-4975-bf29-73a3c9ed0380" muted="false"></video>
70
+ </td>
71
+ </tr>
72
+
73
+ <tr>
74
+ <td width=50% style="border: none">
75
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/2bbebf98-6805-4f1b-b769-537f69cc0e4b" muted="false"></video>
76
+ </td>
77
+ <td width=50% style="border: none">
78
+ <video controls autoplay loop src="https://github.com/TMElyralab/MusePose/assets/47803475/1b2b97d0-0ae9-49a6-83ba-b3024ae64f08" muted="false"></video>
79
+ </td>
80
+ </tr>
81
+
82
+ </table>
83
+
84
+
85
+ ## News
86
+ - [05/27/2024] Release `MusePose` and pretrained models.
87
+ - [05/31/2024] Support [Comfyui-MusePose](https://github.com/TMElyralab/Comfyui-MusePose)
88
+ - [06/14/2024] Fixed a bug in `inference_v2.yaml`.
89
+
90
+
91
+ ## Todo:
92
+ - [x] release our trained models and inference codes of MusePose.
93
+ - [x] release pose align algorithm.
94
+ - [x] Comfyui-MusePose
95
+ - [ ] training guidelines.
96
+ - [ ] Huggingface Gradio demo.
97
+ - [ ] an improved architecture and model (may take longer).
98
+
99
+
100
+ # Getting Started
101
+ We provide a detailed tutorial covering the installation and basic usage of MusePose for new users:
102
+
103
+ ## Installation
104
+ To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
105
+
106
+ ### Build environment
107
+
108
+ We recommend Python >= 3.10 and CUDA 11.7. Then build the environment as follows:
109
+
110
+ ```shell
111
+ pip install -r requirements.txt
112
+ ```
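+ 
+ If you use conda, the requirements can be installed into a fresh environment first (a minimal sketch; the environment name `musepose` is arbitrary and conda availability is assumed):
+ ```shell
+ conda create -n musepose python=3.10 -y
+ conda activate musepose
+ pip install -r requirements.txt
+ ```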
113
+
114
+ ### mmlab packages
115
+ ```bash
116
+ pip install --no-cache-dir -U openmim
117
+ mim install mmengine
118
+ mim install "mmcv>=2.0.1"
119
+ mim install "mmdet>=3.1.0"
120
+ mim install "mmpose>=1.1.0"
121
+ ```
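+ 
+ As an optional sanity check that the mmlab stack installed correctly (assuming the commands above finished without error), you can print the installed versions:
+ ```bash
+ python -c "import mmcv, mmdet, mmpose; print(mmcv.__version__, mmdet.__version__, mmpose.__version__)"
+ ```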
122
+
123
+
124
+ ### Download weights
125
+ You can download the weights manually as follows (a scripted alternative is shown right after the directory layout below):
126
+
127
+ 1. Download our trained [weights](https://huggingface.co/TMElyralab/MusePose).
128
+
129
+ 2. Download the weights of other components:
130
+ - [sd-image-variations-diffusers](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/unet)
131
+ - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
132
+ - [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
133
+ - [yolox](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth) - Make sure to rename it to `yolox_l_8x8_300e_coco.pth`
134
+ - [image_encoder](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/tree/main/image_encoder)
135
+
136
+ Finally, these weights should be organized in `pretrained_weights` as follows:
137
+ ```
138
+ ./pretrained_weights/
139
+ |-- MusePose
140
+ | |-- denoising_unet.pth
141
+ | |-- motion_module.pth
142
+ | |-- pose_guider.pth
143
+ | └── reference_unet.pth
144
+ |-- dwpose
145
+ | |-- dw-ll_ucoco_384.pth
146
+ | └── yolox_l_8x8_300e_coco.pth
147
+ |-- sd-image-variations-diffusers
148
+ | └── unet
149
+ | |-- config.json
150
+ | └── diffusion_pytorch_model.bin
151
+ |-- image_encoder
152
+ | |-- config.json
153
+ | └── pytorch_model.bin
154
+ └── sd-vae-ft-mse
155
+ |-- config.json
156
+ └── diffusion_pytorch_model.bin
157
+
158
+ ```
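+ 
+ Alternatively, this repository ships `downloading_weights.py`, which fetches the same checkpoints with `wget` into `pretrained_weights/` and renames the yolox checkpoint for you (it assumes the `wget` and `tqdm` Python packages are available in your environment):
+ ```shell
+ pip install wget tqdm
+ python downloading_weights.py
+ ```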
159
+ ## Quickstart
160
+ ### Inference
161
+ #### Preparation
162
+ Prepare your reference images and dance videos in the folder ```./assets```, organized as in the example:
163
+ ```
164
+ ./assets/
165
+ |-- images
166
+ | └── ref.png
167
+ └── videos
168
+ └── dance.mp4
169
+ ```
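+ 
+ For example, assuming your own files live elsewhere (the source paths below are placeholders):
+ ```shell
+ mkdir -p ./assets/images ./assets/videos
+ cp /path/to/your_reference.png ./assets/images/ref.png
+ cp /path/to/your_dance_video.mp4 ./assets/videos/dance.mp4
+ ```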
170
+
171
+ #### Pose Alignment
172
+ Get the aligned dwpose of the reference image:
173
+ ```
174
+ python pose_align.py --imgfn_refer ./assets/images/ref.png --vidfn ./assets/videos/dance.mp4
175
+ ```
176
+ After this, you can see the pose-alignment results in ```./assets/poses```, where ```./assets/poses/align/img_ref_video_dance.mp4``` is the aligned dwpose video and ```./assets/poses/align_demo/img_ref_video_dance.mp4``` is for debugging.
177
+
178
+ #### Inferring MusePose
179
+ Add the path of the reference image and the aligned dwpose video to the test config file ```./configs/test_stage_2.yaml``` as in the example:
180
+ ```
181
+ test_cases:
182
+ "./assets/images/ref.png":
183
+ - "./assets/poses/align/img_ref_video_dance.mp4"
184
+ ```
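+ 
+ The mapping accepts a list of aligned pose videos per reference image, so several dance sequences can presumably be queued for one reference in a single run (the second path below is a hypothetical placeholder):
+ ```
+ test_cases:
+   "./assets/images/ref.png":
+     - "./assets/poses/align/img_ref_video_dance.mp4"
+     - "./assets/poses/align/img_ref_video_dance2.mp4"  # hypothetical second aligned pose video
+ ```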
185
+
186
+ Then, simply run
187
+ ```
188
+ python test_stage_2.py --config ./configs/test_stage_2.yaml
189
+ ```
190
+ ```./configs/test_stage_2.yaml``` is the path to the inference configuration file.
191
+
192
+ Finally, you can see the output results in ```./output/```
193
+
194
+ ##### Reducing VRAM cost
195
+ If you want to reduce the VRAM cost, you could set the width and height for inference. For example,
196
+ ```
197
+ python test_stage_2.py --config ./configs/test_stage_2.yaml -W 512 -H 512
198
+ ```
199
+ It will generate the video at 512 x 512 first, and then resize it back to the original size of the pose video.
200
+
201
+ Currently, it takes 16GB of VRAM to run at 512 x 512 x 48 and 28GB of VRAM to run at 768 x 768 x 48. Note, however, that the inference resolution affects the final results (especially the face region).
202
+
203
+ #### Face Enhancement
204
+
205
+ If you want better consistency of the face region, you can use [FaceFusion](https://github.com/facefusion/facefusion): its `face-swap` function swaps the face from the reference image into the generated video.
206
+
207
+ ### Training
208
+
209
+
210
+
211
+ # Acknowledgement
212
+ 1. We thank [AnimateAnyone](https://github.com/HumanAIGC/AnimateAnyone) for their technical report, and have referred extensively to [Moore-AnimateAnyone](https://github.com/MooreThreads/Moore-AnimateAnyone) and [diffusers](https://github.com/huggingface/diffusers).
213
+ 1. We thank open-source components like [AnimateDiff](https://animatediff.github.io/), [dwpose](https://github.com/IDEA-Research/DWPose), [Stable Diffusion](https://github.com/CompVis/stable-diffusion), etc.
214
+
215
+ Thanks for open-sourcing!
216
+
217
+ # Limitations
218
+ - Detail consistency: some details of the original character are not well preserved (e.g., the face region and complex clothing).
219
+ - Noise and flickering: we observe noise and flickering in complex backgrounds.
220
+
221
+ # Citation
222
+ ```bib
223
+ @article{musepose,
224
+ title={MusePose: a Pose-Driven Image-to-Video Framework for Virtual Human Generation},
225
+ author={Tong, Zhengyan and Li, Chao and Chen, Zhaokang and Wu, Bin and Zhou, Wenjiang},
226
+ journal={arxiv},
227
+ year={2024}
228
+ }
229
+ ```
230
+ # Disclaimer/License
231
+ 1. `code`: The code of MusePose is released under the MIT License. There is no limitation on either academic or commercial usage.
232
+ 1. `model`: The trained models are available for non-commercial research purposes only.
233
+ 1. `other opensource model`: Other open-source models used must comply with their licenses, such as `ft-mse-vae` and `dwpose`.
234
+ 1. The test data are collected from the internet and are available for non-commercial research purposes only.
235
+ 1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
Untitled.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
assets/images/ref.png ADDED
assets/poses/align/img_ref_video_dance.mp4 ADDED
Binary file (458 kB). View file
 
assets/poses/align_demo/img_ref_video_dance.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d82dbea870955be98d731d8361894b98901e081d0f1a44913f545fc518d4342c
3
+ size 1599273
assets/videos/dance.mp4 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ef159b9e8e3768d91903267c00a5c9a01f7d8e5a0575010aeb7bc242c33f84d
3
+ size 2692364
configs/.ipynb_checkpoints/test_stage_2-checkpoint.yaml ADDED
@@ -0,0 +1,21 @@
1
+ pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2
+ pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse'
3
+ image_encoder_path: './pretrained_weights/image_encoder'
4
+
5
+
6
+
7
+ denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth"
8
+ reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth"
9
+ pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth"
10
+ motion_module_path: "./pretrained_weights/MusePose/motion_module.pth"
11
+
12
+
13
+
14
+ inference_config: "./configs/inference_v2.yaml"
15
+ weight_dtype: 'fp16'
16
+
17
+
18
+
19
+ test_cases:
20
+ "./assets/images/ref.png":
21
+ - "./assets/poses/align/img_ref_video_dance.mp4"
configs/inference_v2.yaml ADDED
@@ -0,0 +1,35 @@
1
+ unet_additional_kwargs:
2
+ use_inflated_groupnorm: true
3
+ unet_use_cross_frame_attention: false
4
+ unet_use_temporal_attention: false
5
+ use_motion_module: true
6
+ motion_module_resolutions:
7
+ - 1
8
+ - 2
9
+ - 4
10
+ - 8
11
+ motion_module_mid_block: true
12
+ motion_module_decoder_only: false
13
+ motion_module_type: Vanilla
14
+ motion_module_kwargs:
15
+ num_attention_heads: 8
16
+ num_transformer_block: 1
17
+ attention_block_types:
18
+ - Temporal_Self
19
+ - Temporal_Self
20
+ temporal_position_encoding: true
21
+ temporal_position_encoding_max_len: 128
22
+ temporal_attention_dim_div: 1
23
+
24
+ noise_scheduler_kwargs:
25
+ beta_start: 0.00085
26
+ beta_end: 0.012
27
+ beta_schedule: "scaled_linear"
28
+ clip_sample: false
29
+ steps_offset: 1
30
+ ### Zero-SNR params
31
+ prediction_type: "v_prediction"
32
+ rescale_betas_zero_snr: True
33
+ timestep_spacing: "trailing"
34
+
35
+ sampler: DDIM
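
A hedged sketch of how the `noise_scheduler_kwargs` above are typically consumed (not necessarily the exact loading code in `test_stage_2.py`; assumes the `omegaconf` and `diffusers` packages from requirements.txt):

```python
from omegaconf import OmegaConf
from diffusers import DDIMScheduler

# Load the inference config and build the DDIM sampler; the Zero-SNR keys
# (v_prediction, rescale_betas_zero_snr, "trailing" timestep spacing) are all
# standard DDIMScheduler arguments.
cfg = OmegaConf.load("./configs/inference_v2.yaml")
scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
print(scheduler.config.prediction_type)  # "v_prediction"
```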
configs/test_stage_1.yaml ADDED
@@ -0,0 +1,26 @@
1
+ pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2
+ pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse'
3
+ image_encoder_path: './pretrained_weights/image_encoder'
4
+
5
+
6
+
7
+ denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth"
8
+ reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth"
9
+ pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth"
10
+
11
+
12
+
13
+
14
+ inference_config: "./configs/inference_v2.yaml"
15
+ weight_dtype: 'fp16'
16
+
17
+
18
+
19
+ test_cases:
20
+ "./assets/images/ref.png":
21
+ - "./assets/poses/align/img_ref_video_dance.mp4"
22
+
23
+
24
+
25
+
26
+
configs/test_stage_2.yaml ADDED
@@ -0,0 +1,21 @@
1
+ pretrained_base_model_path: './pretrained_weights/sd-image-variations-diffusers'
2
+ pretrained_vae_path: './pretrained_weights/sd-vae-ft-mse'
3
+ image_encoder_path: './pretrained_weights/image_encoder'
4
+
5
+
6
+
7
+ denoising_unet_path: "./pretrained_weights/MusePose/denoising_unet.pth"
8
+ reference_unet_path: "./pretrained_weights/MusePose/reference_unet.pth"
9
+ pose_guider_path: "./pretrained_weights/MusePose/pose_guider.pth"
10
+ motion_module_path: "./pretrained_weights/MusePose/motion_module.pth"
11
+
12
+
13
+
14
+ inference_config: "./configs/inference_v2.yaml"
15
+ weight_dtype: 'fp16'
16
+
17
+
18
+
19
+ test_cases:
20
+ "./assets/images/ref.png":
21
+ - "./assets/poses/align/img_ref_video_dance.mp4"
downloading_weights.py ADDED
@@ -0,0 +1,38 @@
1
+ import os
2
+ import wget
3
+ from tqdm import tqdm
4
+
5
+ os.makedirs('pretrained_weights', exist_ok=True)
6
+
7
+ urls = ['https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
8
+ 'https://huggingface.co/yzd-v/DWPose/resolve/main/dw-ll_ucoco_384.pth',
9
+ 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/denoising_unet.pth',
10
+ 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/motion_module.pth',
11
+ 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/pose_guider.pth',
12
+ 'https://huggingface.co/TMElyralab/MusePose/resolve/main/MusePose/reference_unet.pth',
13
+ 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/diffusion_pytorch_model.bin',
14
+ 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/pytorch_model.bin',
15
+ 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/diffusion_pytorch_model.bin'
16
+ ]
17
+
18
+ paths = ['dwpose', 'dwpose', 'MusePose', 'MusePose', 'MusePose', 'MusePose', 'sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse']
19
+
20
+ for path in paths:
21
+ os.makedirs(f'pretrained_weights/{path}', exist_ok=True)
22
+
23
+ # saving weights
24
+ for url, path in tqdm(zip(urls, paths)):
25
+ filename = wget.download(url, f'pretrained_weights/{path}')
26
+
27
+ config_urls = ['https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/unet/config.json',
28
+ 'https://huggingface.co/lambdalabs/sd-image-variations-diffusers/resolve/main/image_encoder/config.json',
29
+ 'https://huggingface.co/stabilityai/sd-vae-ft-mse/resolve/main/config.json']
30
+
31
+ config_paths = ['sd-image-variations-diffusers/unet', 'image_encoder', 'sd-vae-ft-mse']
32
+
33
+ # saving config files
34
+ for url, path in tqdm(zip(config_urls, config_paths)):
35
+ filename = wget.download(url, f'pretrained_weights/{path}')
36
+
37
+ # renaming model name as given in readme
38
+ os.rename('pretrained_weights/dwpose/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth', 'pretrained_weights/dwpose/yolox_l_8x8_300e_coco.pth')
draw_dwpose.py ADDED
@@ -0,0 +1,112 @@
1
+ import os
2
+ import cv2
3
+ import argparse
4
+ import numpy as np
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+
8
+ from pose.script.tool import save_videos_from_pil
9
+ from pose.script.dwpose import draw_pose
10
+
11
+
12
+
13
+ def draw_dwpose(video_path, pose_path, out_path, draw_face):
14
+
15
+ # capture video info
16
+ cap = cv2.VideoCapture(video_path)
17
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
18
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
19
+ fps = cap.get(cv2.CAP_PROP_FPS)
20
+ fps = int(np.around(fps))
21
+ # fps = get_fps(video_path)
22
+ cap.release()
23
+
24
+ # render resolution, short edge = 1024
25
+ k = float(1024) / min(width, height)
26
+ h_render = int(k*height//2 * 2)
27
+ w_render = int(k*width//2 * 2)
28
+
29
+ # save resolution, short edge = 768
30
+ k = float(768) / min(width, height)
31
+ h_save = int(k*height//2 * 2)
32
+ w_save = int(k*width//2 * 2)
33
+
34
+ poses = np.load(pose_path, allow_pickle=True)
35
+ poses = poses.tolist()
36
+
37
+ frames = []
38
+ for pose in tqdm(poses):
39
+ detected_map = draw_pose(pose, h_render, w_render, draw_face)
40
+ detected_map = cv2.resize(detected_map, (w_save, h_save), interpolation=cv2.INTER_AREA)
41
+ # cv2.imshow('', detected_map)
42
+ # cv2.waitKey(0)
43
+ detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
44
+ detected_map = Image.fromarray(detected_map)
45
+ frames.append(detected_map)
46
+
47
+ save_videos_from_pil(frames, out_path, fps)
48
+
49
+
50
+
51
+ if __name__ == "__main__":
52
+
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument("--video_dir", type=str, default="./UBC_fashion/test", help='dance video dir')
55
+ parser.add_argument("--pose_dir", type=str, default=None, help='auto makedir')
56
+ parser.add_argument("--save_dir", type=str, default=None, help='auto makedir')
57
+ parser.add_argument("--draw_face", type=bool, default=False, help='whether draw face or not')
58
+ args = parser.parse_args()
59
+
60
+
61
+ # video dir
62
+ video_dir = args.video_dir
63
+
64
+ # pose dir
65
+ if args.pose_dir is None:
66
+ pose_dir = args.video_dir + "_dwpose_keypoints"
67
+ else:
68
+ pose_dir = args.pose_dir
69
+
70
+ # save dir
71
+ if args.save_dir is None:
72
+ if args.draw_face == True:
73
+ save_dir = args.video_dir + "_dwpose"
74
+ else:
75
+ save_dir = args.video_dir + "_dwpose_without_face"
76
+ else:
77
+ save_dir = args.save_dir
78
+ if not os.path.exists(save_dir):
79
+ os.makedirs(save_dir)
80
+
81
+
82
+ # collect all video_folder paths
83
+ video_mp4_paths = set()
84
+ for root, dirs, files in os.walk(args.video_dir):
85
+ for name in files:
86
+ if name.endswith(".mp4"):
87
+ video_mp4_paths.add(os.path.join(root, name))
88
+ video_mp4_paths = list(video_mp4_paths)
89
+ # random.shuffle(video_mp4_paths)
90
+ video_mp4_paths.sort()
91
+ print("Num of videos:", len(video_mp4_paths))
92
+
93
+
94
+ # draw dwpose
95
+ for i in range(len(video_mp4_paths)):
96
+ video_path = video_mp4_paths[i]
97
+ video_name = os.path.relpath(video_path, video_dir)
98
+ base_name = os.path.splitext(video_name)[0]
99
+
100
+ pose_path = os.path.join(pose_dir, base_name + '.npy')
101
+ if not os.path.exists(pose_path):
102
+ print('no keypoint file:', pose_path)
103
+
104
+ out_path = os.path.join(save_dir, base_name + '.mp4')
105
+ if os.path.exists(out_path):
106
+ print('already have rendered pose:', out_path)
107
+ continue
108
+
109
+ draw_dwpose(video_path, pose_path, out_path, args.draw_face)
110
+ print(f"Process {i+1}/{len(video_mp4_paths)} video")
111
+
112
+ print('all done!')
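
A hedged usage sketch for the script above (the directory is a placeholder; per the argparse defaults, keypoint `.npy` files are expected in `<video_dir>_dwpose_keypoints` unless `--pose_dir` is given, and faces are not drawn by default):

```shell
python draw_dwpose.py --video_dir ./my_videos
```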
musepose/__init__.py ADDED
File without changes
musepose/dataset/dance_image.py ADDED
@@ -0,0 +1,130 @@
1
+ import json
2
+ import random
3
+
4
+ import torch
5
+ import torchvision.transforms as transforms
6
+ from decord import VideoReader
7
+ from PIL import Image
8
+ from torch.utils.data import Dataset
9
+ from transformers import CLIPImageProcessor
10
+
11
+
12
+ class HumanDanceDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ img_size,
16
+ img_scale=(1.0, 1.0),
17
+ img_ratio=(0.9, 1.0),
18
+ drop_ratio=0.1,
19
+ data_meta_paths=["./data/fahsion_meta.json"],
20
+ sample_margin=30,
21
+ ):
22
+ super().__init__()
23
+
24
+ self.img_size = img_size
25
+ self.img_scale = img_scale
26
+ self.img_ratio = img_ratio
27
+ self.sample_margin = sample_margin
28
+
29
+ # -----
30
+ # vid_meta format:
31
+ # [{'video_path': , 'kps_path': , 'other':},
32
+ # {'video_path': , 'kps_path': , 'other':}]
33
+ # -----
34
+ vid_meta = []
35
+ for data_meta_path in data_meta_paths:
36
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
37
+ self.vid_meta = vid_meta
38
+
39
+ self.clip_image_processor = CLIPImageProcessor()
40
+
41
+ self.transform = transforms.Compose(
42
+ [
43
+ # transforms.RandomResizedCrop(
44
+ # self.img_size,
45
+ # scale=self.img_scale,
46
+ # ratio=self.img_ratio,
47
+ # interpolation=transforms.InterpolationMode.BILINEAR,
48
+ # ),
49
+ transforms.Resize(
50
+ self.img_size,
51
+ ),
52
+ transforms.ToTensor(),
53
+ transforms.Normalize([0.5], [0.5]),
54
+ ]
55
+ )
56
+
57
+ self.cond_transform = transforms.Compose(
58
+ [
59
+ # transforms.RandomResizedCrop(
60
+ # self.img_size,
61
+ # scale=self.img_scale,
62
+ # ratio=self.img_ratio,
63
+ # interpolation=transforms.InterpolationMode.BILINEAR,
64
+ # ),
65
+ transforms.Resize(
66
+ self.img_size,
67
+ ),
68
+ transforms.ToTensor(),
69
+ ]
70
+ )
71
+
72
+ self.drop_ratio = drop_ratio
73
+
74
+ def augmentation(self, image, transform, state=None):
75
+ if state is not None:
76
+ torch.set_rng_state(state)
77
+ return transform(image)
78
+
79
+ def __getitem__(self, index):
80
+ video_meta = self.vid_meta[index]
81
+ video_path = video_meta["video_path"]
82
+ kps_path = video_meta["kps_path"]
83
+
84
+ video_reader = VideoReader(video_path)
85
+ kps_reader = VideoReader(kps_path)
86
+
87
+ assert len(video_reader) == len(
88
+ kps_reader
89
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
90
+
91
+ video_length = len(video_reader)
92
+
93
+ margin = min(self.sample_margin, video_length)
94
+
95
+ ref_img_idx = random.randint(0, video_length - 1)
96
+ if ref_img_idx + margin < video_length:
97
+ tgt_img_idx = random.randint(ref_img_idx + margin, video_length - 1)
98
+ elif ref_img_idx - margin > 0:
99
+ tgt_img_idx = random.randint(0, ref_img_idx - margin)
100
+ else:
101
+ tgt_img_idx = random.randint(0, video_length - 1)
102
+
103
+ ref_img = video_reader[ref_img_idx]
104
+ ref_img_pil = Image.fromarray(ref_img.asnumpy())
105
+ tgt_img = video_reader[tgt_img_idx]
106
+ tgt_img_pil = Image.fromarray(tgt_img.asnumpy())
107
+
108
+ tgt_pose = kps_reader[tgt_img_idx]
109
+ tgt_pose_pil = Image.fromarray(tgt_pose.asnumpy())
110
+
111
+ state = torch.get_rng_state()
112
+ tgt_img = self.augmentation(tgt_img_pil, self.transform, state)
113
+ tgt_pose_img = self.augmentation(tgt_pose_pil, self.cond_transform, state)
114
+ ref_img_vae = self.augmentation(ref_img_pil, self.transform, state)
115
+ clip_image = self.clip_image_processor(
116
+ images=ref_img_pil, return_tensors="pt"
117
+ ).pixel_values[0]
118
+
119
+ sample = dict(
120
+ video_dir=video_path,
121
+ img=tgt_img,
122
+ tgt_pose=tgt_pose_img,
123
+ ref_img=ref_img_vae,
124
+ clip_images=clip_image,
125
+ )
126
+
127
+ return sample
128
+
129
+ def __len__(self):
130
+ return len(self.vid_meta)
musepose/dataset/dance_video.py ADDED
@@ -0,0 +1,150 @@
1
+ import json
2
+ import random
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ from decord import VideoReader
10
+ from PIL import Image
11
+ from torch.utils.data import Dataset
12
+ from transformers import CLIPImageProcessor
13
+
14
+
15
+ class HumanDanceVideoDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ sample_rate,
19
+ n_sample_frames,
20
+ width,
21
+ height,
22
+ img_scale=(1.0, 1.0),
23
+ img_ratio=(0.9, 1.0),
24
+ drop_ratio=0.1,
25
+ data_meta_paths=["./data/fashion_meta.json"],
26
+ ):
27
+ super().__init__()
28
+ self.sample_rate = sample_rate
29
+ self.n_sample_frames = n_sample_frames
30
+ self.width = width
31
+ self.height = height
32
+ self.img_scale = img_scale
33
+ self.img_ratio = img_ratio
34
+
35
+ vid_meta = []
36
+ for data_meta_path in data_meta_paths:
37
+ vid_meta.extend(json.load(open(data_meta_path, "r")))
38
+ self.vid_meta = vid_meta
39
+
40
+ self.clip_image_processor = CLIPImageProcessor()
41
+
42
+ self.pixel_transform = transforms.Compose(
43
+ [
44
+ # transforms.RandomResizedCrop(
45
+ # (height, width),
46
+ # scale=self.img_scale,
47
+ # ratio=self.img_ratio,
48
+ # interpolation=transforms.InterpolationMode.BILINEAR,
49
+ # ),
50
+ transforms.Resize(
51
+ (height, width),
52
+ ),
53
+ transforms.ToTensor(),
54
+ transforms.Normalize([0.5], [0.5]),
55
+ ]
56
+ )
57
+
58
+ self.cond_transform = transforms.Compose(
59
+ [
60
+ # transforms.RandomResizedCrop(
61
+ # (height, width),
62
+ # scale=self.img_scale,
63
+ # ratio=self.img_ratio,
64
+ # interpolation=transforms.InterpolationMode.BILINEAR,
65
+ # ),
66
+ transforms.Resize(
67
+ (height, width),
68
+ ),
69
+ transforms.ToTensor(),
70
+ ]
71
+ )
72
+
73
+ self.drop_ratio = drop_ratio
74
+
75
+ def augmentation(self, images, transform, state=None):
76
+ if state is not None:
77
+ torch.set_rng_state(state)
78
+ if isinstance(images, List):
79
+ transformed_images = [transform(img) for img in images]
80
+ ret_tensor = torch.stack(transformed_images, dim=0) # (f, c, h, w)
81
+ else:
82
+ ret_tensor = transform(images) # (c, h, w)
83
+ return ret_tensor
84
+
85
+ def __getitem__(self, index):
86
+ video_meta = self.vid_meta[index]
87
+ video_path = video_meta["video_path"]
88
+ kps_path = video_meta["kps_path"]
89
+
90
+ video_reader = VideoReader(video_path)
91
+ kps_reader = VideoReader(kps_path)
92
+
93
+ assert len(video_reader) == len(
94
+ kps_reader
95
+ ), f"{len(video_reader) = } != {len(kps_reader) = } in {video_path}"
96
+
97
+ video_length = len(video_reader)
98
+ video_fps = video_reader.get_avg_fps()
99
+ # print("fps", video_fps)
100
+ if video_fps > 30: # 30-60
101
+ sample_rate = self.sample_rate*2
102
+ else:
103
+ sample_rate = self.sample_rate
104
+
105
+
106
+ clip_length = min(
107
+ video_length, (self.n_sample_frames - 1) * sample_rate + 1
108
+ )
109
+ start_idx = random.randint(0, video_length - clip_length)
110
+ batch_index = np.linspace(
111
+ start_idx, start_idx + clip_length - 1, self.n_sample_frames, dtype=int
112
+ ).tolist()
113
+
114
+ # read frames and kps
115
+ vid_pil_image_list = []
116
+ pose_pil_image_list = []
117
+ for index in batch_index:
118
+ img = video_reader[index]
119
+ vid_pil_image_list.append(Image.fromarray(img.asnumpy()))
120
+ img = kps_reader[index]
121
+ pose_pil_image_list.append(Image.fromarray(img.asnumpy()))
122
+
123
+ ref_img_idx = random.randint(0, video_length - 1)
124
+ ref_img = Image.fromarray(video_reader[ref_img_idx].asnumpy())
125
+
126
+ # transform
127
+ state = torch.get_rng_state()
128
+ pixel_values_vid = self.augmentation(
129
+ vid_pil_image_list, self.pixel_transform, state
130
+ )
131
+ pixel_values_pose = self.augmentation(
132
+ pose_pil_image_list, self.cond_transform, state
133
+ )
134
+ pixel_values_ref_img = self.augmentation(ref_img, self.pixel_transform, state)
135
+ clip_ref_img = self.clip_image_processor(
136
+ images=ref_img, return_tensors="pt"
137
+ ).pixel_values[0]
138
+
139
+ sample = dict(
140
+ video_dir=video_path,
141
+ pixel_values_vid=pixel_values_vid,
142
+ pixel_values_pose=pixel_values_pose,
143
+ pixel_values_ref_img=pixel_values_ref_img,
144
+ clip_ref_img=clip_ref_img,
145
+ )
146
+
147
+ return sample
148
+
149
+ def __len__(self):
150
+ return len(self.vid_meta)
musepose/models/attention.py ADDED
@@ -0,0 +1,443 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.models.attention import AdaLayerNorm, Attention, FeedForward
7
+ from diffusers.models.embeddings import SinusoidalPositionalEmbedding
8
+ from einops import rearrange
9
+ from torch import nn
10
+
11
+
12
+ class BasicTransformerBlock(nn.Module):
13
+ r"""
14
+ A basic Transformer block.
15
+
16
+ Parameters:
17
+ dim (`int`): The number of channels in the input and output.
18
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
19
+ attention_head_dim (`int`): The number of channels in each head.
20
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
21
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
22
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
23
+ num_embeds_ada_norm (:
24
+ obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
25
+ attention_bias (:
26
+ obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
27
+ only_cross_attention (`bool`, *optional*):
28
+ Whether to use only cross-attention layers. In this case two cross attention layers are used.
29
+ double_self_attention (`bool`, *optional*):
30
+ Whether to use two self-attention layers. In this case no cross attention layers are used.
31
+ upcast_attention (`bool`, *optional*):
32
+ Whether to upcast the attention computation to float32. This is useful for mixed precision training.
33
+ norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
34
+ Whether to use learnable elementwise affine parameters for normalization.
35
+ norm_type (`str`, *optional*, defaults to `"layer_norm"`):
36
+ The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`.
37
+ final_dropout (`bool` *optional*, defaults to False):
38
+ Whether to apply a final dropout after the last feed-forward layer.
39
+ attention_type (`str`, *optional*, defaults to `"default"`):
40
+ The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`.
41
+ positional_embeddings (`str`, *optional*, defaults to `None`):
42
+ The type of positional embeddings to apply to.
43
+ num_positional_embeddings (`int`, *optional*, defaults to `None`):
44
+ The maximum number of positional embeddings to apply.
45
+ """
46
+
47
+ def __init__(
48
+ self,
49
+ dim: int,
50
+ num_attention_heads: int,
51
+ attention_head_dim: int,
52
+ dropout=0.0,
53
+ cross_attention_dim: Optional[int] = None,
54
+ activation_fn: str = "geglu",
55
+ num_embeds_ada_norm: Optional[int] = None,
56
+ attention_bias: bool = False,
57
+ only_cross_attention: bool = False,
58
+ double_self_attention: bool = False,
59
+ upcast_attention: bool = False,
60
+ norm_elementwise_affine: bool = True,
61
+ norm_type: str = "layer_norm", # 'layer_norm', 'ada_norm', 'ada_norm_zero', 'ada_norm_single'
62
+ norm_eps: float = 1e-5,
63
+ final_dropout: bool = False,
64
+ attention_type: str = "default",
65
+ positional_embeddings: Optional[str] = None,
66
+ num_positional_embeddings: Optional[int] = None,
67
+ ):
68
+ super().__init__()
69
+ self.only_cross_attention = only_cross_attention
70
+
71
+ self.use_ada_layer_norm_zero = (
72
+ num_embeds_ada_norm is not None
73
+ ) and norm_type == "ada_norm_zero"
74
+ self.use_ada_layer_norm = (
75
+ num_embeds_ada_norm is not None
76
+ ) and norm_type == "ada_norm"
77
+ self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
78
+ self.use_layer_norm = norm_type == "layer_norm"
79
+
80
+ if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
81
+ raise ValueError(
82
+ f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
83
+ f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
84
+ )
85
+
86
+ if positional_embeddings and (num_positional_embeddings is None):
87
+ raise ValueError(
88
+ "If `positional_embedding` type is defined, `num_positition_embeddings` must also be defined."
89
+ )
90
+
91
+ if positional_embeddings == "sinusoidal":
92
+ self.pos_embed = SinusoidalPositionalEmbedding(
93
+ dim, max_seq_length=num_positional_embeddings
94
+ )
95
+ else:
96
+ self.pos_embed = None
97
+
98
+ # Define 3 blocks. Each block has its own normalization layer.
99
+ # 1. Self-Attn
100
+ if self.use_ada_layer_norm:
101
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
102
+ elif self.use_ada_layer_norm_zero:
103
+ self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
104
+ else:
105
+ self.norm1 = nn.LayerNorm(
106
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
107
+ )
108
+
109
+ self.attn1 = Attention(
110
+ query_dim=dim,
111
+ heads=num_attention_heads,
112
+ dim_head=attention_head_dim,
113
+ dropout=dropout,
114
+ bias=attention_bias,
115
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
116
+ upcast_attention=upcast_attention,
117
+ )
118
+
119
+ # 2. Cross-Attn
120
+ if cross_attention_dim is not None or double_self_attention:
121
+ # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
122
+ # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
123
+ # the second cross attention block.
124
+ self.norm2 = (
125
+ AdaLayerNorm(dim, num_embeds_ada_norm)
126
+ if self.use_ada_layer_norm
127
+ else nn.LayerNorm(
128
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
129
+ )
130
+ )
131
+ self.attn2 = Attention(
132
+ query_dim=dim,
133
+ cross_attention_dim=cross_attention_dim
134
+ if not double_self_attention
135
+ else None,
136
+ heads=num_attention_heads,
137
+ dim_head=attention_head_dim,
138
+ dropout=dropout,
139
+ bias=attention_bias,
140
+ upcast_attention=upcast_attention,
141
+ ) # is self-attn if encoder_hidden_states is none
142
+ else:
143
+ self.norm2 = None
144
+ self.attn2 = None
145
+
146
+ # 3. Feed-forward
147
+ if not self.use_ada_layer_norm_single:
148
+ self.norm3 = nn.LayerNorm(
149
+ dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps
150
+ )
151
+
152
+ self.ff = FeedForward(
153
+ dim,
154
+ dropout=dropout,
155
+ activation_fn=activation_fn,
156
+ final_dropout=final_dropout,
157
+ )
158
+
159
+ # 4. Fuser
160
+ if attention_type == "gated" or attention_type == "gated-text-image":
161
+ self.fuser = GatedSelfAttentionDense(
162
+ dim, cross_attention_dim, num_attention_heads, attention_head_dim
163
+ )
164
+
165
+ # 5. Scale-shift for PixArt-Alpha.
166
+ if self.use_ada_layer_norm_single:
167
+ self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
168
+
169
+ # let chunk size default to None
170
+ self._chunk_size = None
171
+ self._chunk_dim = 0
172
+
173
+ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
174
+ # Sets chunk feed-forward
175
+ self._chunk_size = chunk_size
176
+ self._chunk_dim = dim
177
+
178
+ def forward(
179
+ self,
180
+ hidden_states: torch.FloatTensor,
181
+ attention_mask: Optional[torch.FloatTensor] = None,
182
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
183
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
184
+ timestep: Optional[torch.LongTensor] = None,
185
+ cross_attention_kwargs: Dict[str, Any] = None,
186
+ class_labels: Optional[torch.LongTensor] = None,
187
+ ) -> torch.FloatTensor:
188
+ # Notice that normalization is always applied before the real computation in the following blocks.
189
+ # 0. Self-Attention
190
+ batch_size = hidden_states.shape[0]
191
+
192
+ if self.use_ada_layer_norm:
193
+ norm_hidden_states = self.norm1(hidden_states, timestep)
194
+ elif self.use_ada_layer_norm_zero:
195
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
196
+ hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
197
+ )
198
+ elif self.use_layer_norm:
199
+ norm_hidden_states = self.norm1(hidden_states)
200
+ elif self.use_ada_layer_norm_single:
201
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
202
+ self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
203
+ ).chunk(6, dim=1)
204
+ norm_hidden_states = self.norm1(hidden_states)
205
+ norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
206
+ norm_hidden_states = norm_hidden_states.squeeze(1)
207
+ else:
208
+ raise ValueError("Incorrect norm used")
209
+
210
+ if self.pos_embed is not None:
211
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
212
+
213
+ # 1. Retrieve lora scale.
214
+ lora_scale = (
215
+ cross_attention_kwargs.get("scale", 1.0)
216
+ if cross_attention_kwargs is not None
217
+ else 1.0
218
+ )
219
+
220
+ # 2. Prepare GLIGEN inputs
221
+ cross_attention_kwargs = (
222
+ cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
223
+ )
224
+ gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
225
+
226
+ attn_output = self.attn1(
227
+ norm_hidden_states,
228
+ encoder_hidden_states=encoder_hidden_states
229
+ if self.only_cross_attention
230
+ else None,
231
+ attention_mask=attention_mask,
232
+ **cross_attention_kwargs,
233
+ )
234
+ if self.use_ada_layer_norm_zero:
235
+ attn_output = gate_msa.unsqueeze(1) * attn_output
236
+ elif self.use_ada_layer_norm_single:
237
+ attn_output = gate_msa * attn_output
238
+
239
+ hidden_states = attn_output + hidden_states
240
+ if hidden_states.ndim == 4:
241
+ hidden_states = hidden_states.squeeze(1)
242
+
243
+ # 2.5 GLIGEN Control
244
+ if gligen_kwargs is not None:
245
+ hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
246
+
247
+ # 3. Cross-Attention
248
+ if self.attn2 is not None:
249
+ if self.use_ada_layer_norm:
250
+ norm_hidden_states = self.norm2(hidden_states, timestep)
251
+ elif self.use_ada_layer_norm_zero or self.use_layer_norm:
252
+ norm_hidden_states = self.norm2(hidden_states)
253
+ elif self.use_ada_layer_norm_single:
254
+ # For PixArt norm2 isn't applied here:
255
+ # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
256
+ norm_hidden_states = hidden_states
257
+ else:
258
+ raise ValueError("Incorrect norm")
259
+
260
+ if self.pos_embed is not None and self.use_ada_layer_norm_single is False:
261
+ norm_hidden_states = self.pos_embed(norm_hidden_states)
262
+
263
+ attn_output = self.attn2(
264
+ norm_hidden_states,
265
+ encoder_hidden_states=encoder_hidden_states,
266
+ attention_mask=encoder_attention_mask,
267
+ **cross_attention_kwargs,
268
+ )
269
+ hidden_states = attn_output + hidden_states
270
+
271
+ # 4. Feed-forward
272
+ if not self.use_ada_layer_norm_single:
273
+ norm_hidden_states = self.norm3(hidden_states)
274
+
275
+ if self.use_ada_layer_norm_zero:
276
+ norm_hidden_states = (
277
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
278
+ )
279
+
280
+ if self.use_ada_layer_norm_single:
281
+ norm_hidden_states = self.norm2(hidden_states)
282
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
283
+
284
+ ff_output = self.ff(norm_hidden_states, scale=lora_scale)
285
+
286
+ if self.use_ada_layer_norm_zero:
287
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
288
+ elif self.use_ada_layer_norm_single:
289
+ ff_output = gate_mlp * ff_output
290
+
291
+ hidden_states = ff_output + hidden_states
292
+ if hidden_states.ndim == 4:
293
+ hidden_states = hidden_states.squeeze(1)
294
+
295
+ return hidden_states
296
+
297
+
298
+ class TemporalBasicTransformerBlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim: int,
302
+ num_attention_heads: int,
303
+ attention_head_dim: int,
304
+ dropout=0.0,
305
+ cross_attention_dim: Optional[int] = None,
306
+ activation_fn: str = "geglu",
307
+ num_embeds_ada_norm: Optional[int] = None,
308
+ attention_bias: bool = False,
309
+ only_cross_attention: bool = False,
310
+ upcast_attention: bool = False,
311
+ unet_use_cross_frame_attention=None,
312
+ unet_use_temporal_attention=None,
313
+ ):
314
+ super().__init__()
315
+ self.only_cross_attention = only_cross_attention
316
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
317
+ self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
318
+ self.unet_use_temporal_attention = unet_use_temporal_attention
319
+
320
+ # SC-Attn
321
+ self.attn1 = Attention(
322
+ query_dim=dim,
323
+ heads=num_attention_heads,
324
+ dim_head=attention_head_dim,
325
+ dropout=dropout,
326
+ bias=attention_bias,
327
+ upcast_attention=upcast_attention,
328
+ )
329
+ self.norm1 = (
330
+ AdaLayerNorm(dim, num_embeds_ada_norm)
331
+ if self.use_ada_layer_norm
332
+ else nn.LayerNorm(dim)
333
+ )
334
+
335
+ # Cross-Attn
336
+ if cross_attention_dim is not None:
337
+ self.attn2 = Attention(
338
+ query_dim=dim,
339
+ cross_attention_dim=cross_attention_dim,
340
+ heads=num_attention_heads,
341
+ dim_head=attention_head_dim,
342
+ dropout=dropout,
343
+ bias=attention_bias,
344
+ upcast_attention=upcast_attention,
345
+ )
346
+ else:
347
+ self.attn2 = None
348
+
349
+ if cross_attention_dim is not None:
350
+ self.norm2 = (
351
+ AdaLayerNorm(dim, num_embeds_ada_norm)
352
+ if self.use_ada_layer_norm
353
+ else nn.LayerNorm(dim)
354
+ )
355
+ else:
356
+ self.norm2 = None
357
+
358
+ # Feed-forward
359
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
360
+ self.norm3 = nn.LayerNorm(dim)
361
+ self.use_ada_layer_norm_zero = False
362
+
363
+ # Temp-Attn
364
+ assert unet_use_temporal_attention is not None
365
+ if unet_use_temporal_attention:
366
+ self.attn_temp = Attention(
367
+ query_dim=dim,
368
+ heads=num_attention_heads,
369
+ dim_head=attention_head_dim,
370
+ dropout=dropout,
371
+ bias=attention_bias,
372
+ upcast_attention=upcast_attention,
373
+ )
374
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
375
+ self.norm_temp = (
376
+ AdaLayerNorm(dim, num_embeds_ada_norm)
377
+ if self.use_ada_layer_norm
378
+ else nn.LayerNorm(dim)
379
+ )
380
+
381
+ def forward(
382
+ self,
383
+ hidden_states,
384
+ encoder_hidden_states=None,
385
+ timestep=None,
386
+ attention_mask=None,
387
+ video_length=None,
388
+ ):
389
+ norm_hidden_states = (
390
+ self.norm1(hidden_states, timestep)
391
+ if self.use_ada_layer_norm
392
+ else self.norm1(hidden_states)
393
+ )
394
+
395
+ if self.unet_use_cross_frame_attention:
396
+ hidden_states = (
397
+ self.attn1(
398
+ norm_hidden_states,
399
+ attention_mask=attention_mask,
400
+ video_length=video_length,
401
+ )
402
+ + hidden_states
403
+ )
404
+ else:
405
+ hidden_states = (
406
+ self.attn1(norm_hidden_states, attention_mask=attention_mask)
407
+ + hidden_states
408
+ )
409
+
410
+ if self.attn2 is not None:
411
+ # Cross-Attention
412
+ norm_hidden_states = (
413
+ self.norm2(hidden_states, timestep)
414
+ if self.use_ada_layer_norm
415
+ else self.norm2(hidden_states)
416
+ )
417
+ hidden_states = (
418
+ self.attn2(
419
+ norm_hidden_states,
420
+ encoder_hidden_states=encoder_hidden_states,
421
+ attention_mask=attention_mask,
422
+ )
423
+ + hidden_states
424
+ )
425
+
426
+ # Feed-forward
427
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
428
+
429
+ # Temporal-Attention
430
+ if self.unet_use_temporal_attention:
431
+ d = hidden_states.shape[1]
432
+ hidden_states = rearrange(
433
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
434
+ )
435
+ norm_hidden_states = (
436
+ self.norm_temp(hidden_states, timestep)
437
+ if self.use_ada_layer_norm
438
+ else self.norm_temp(hidden_states)
439
+ )
440
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
441
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
442
+
443
+ return hidden_states
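The temporal block above is exercised per frame inside the 3D UNet. A minimal smoke-test sketch follows (not part of this upload; the sizes are illustrative and it assumes diffusers and einops are installed):

```python
# Hypothetical smoke test for TemporalBasicTransformerBlock (illustrative sizes).
import torch
from musepose.models.attention import TemporalBasicTransformerBlock

block = TemporalBasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
    unet_use_cross_frame_attention=False,
    unet_use_temporal_attention=True,  # must be given explicitly; __init__ asserts it is not None
)
b, f, tokens = 2, 16, 64                # batch, frames, spatial tokens (h * w)
hidden = torch.randn(b * f, tokens, 320)
context = torch.randn(b * f, 77, 768)   # per-frame cross-attention context
out = block(hidden, encoder_hidden_states=context, video_length=f)
print(out.shape)  # torch.Size([32, 64, 320]), same layout as the input
```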
musepose/models/motion_module.py ADDED
@@ -0,0 +1,388 @@
1
+ # Adapt from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
2
+ import math
3
+ from dataclasses import dataclass
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from diffusers.models.attention import FeedForward
8
+ from diffusers.models.attention_processor import Attention, AttnProcessor
9
+ from diffusers.utils import BaseOutput
10
+ from diffusers.utils.import_utils import is_xformers_available
11
+ from einops import rearrange, repeat
12
+ from torch import nn
13
+
14
+
15
+ def zero_module(module):
16
+ # Zero out the parameters of a module and return it.
17
+ for p in module.parameters():
18
+ p.detach().zero_()
19
+ return module
20
+
21
+
22
+ @dataclass
23
+ class TemporalTransformer3DModelOutput(BaseOutput):
24
+ sample: torch.FloatTensor
25
+
26
+
27
+ if is_xformers_available():
28
+ import xformers
29
+ import xformers.ops
30
+ else:
31
+ xformers = None
32
+
33
+
34
+ def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
35
+ if motion_module_type == "Vanilla":
36
+ return VanillaTemporalModule(
37
+ in_channels=in_channels,
38
+ **motion_module_kwargs,
39
+ )
40
+ else:
41
+ raise ValueError
42
+
43
+
44
+ class VanillaTemporalModule(nn.Module):
45
+ def __init__(
46
+ self,
47
+ in_channels,
48
+ num_attention_heads=8,
49
+ num_transformer_block=2,
50
+ attention_block_types=("Temporal_Self", "Temporal_Self"),
51
+ cross_frame_attention_mode=None,
52
+ temporal_position_encoding=False,
53
+ temporal_position_encoding_max_len=24,
54
+ temporal_attention_dim_div=1,
55
+ zero_initialize=True,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.temporal_transformer = TemporalTransformer3DModel(
60
+ in_channels=in_channels,
61
+ num_attention_heads=num_attention_heads,
62
+ attention_head_dim=in_channels
63
+ // num_attention_heads
64
+ // temporal_attention_dim_div,
65
+ num_layers=num_transformer_block,
66
+ attention_block_types=attention_block_types,
67
+ cross_frame_attention_mode=cross_frame_attention_mode,
68
+ temporal_position_encoding=temporal_position_encoding,
69
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
70
+ )
71
+
72
+ if zero_initialize:
73
+ self.temporal_transformer.proj_out = zero_module(
74
+ self.temporal_transformer.proj_out
75
+ )
76
+
77
+ def forward(
78
+ self,
79
+ input_tensor,
80
+ temb,
81
+ encoder_hidden_states,
82
+ attention_mask=None,
83
+ anchor_frame_idx=None,
84
+ ):
85
+ hidden_states = input_tensor
86
+ hidden_states = self.temporal_transformer(
87
+ hidden_states, encoder_hidden_states, attention_mask
88
+ )
89
+
90
+ output = hidden_states
91
+ return output
92
+
93
+
94
+ class TemporalTransformer3DModel(nn.Module):
95
+ def __init__(
96
+ self,
97
+ in_channels,
98
+ num_attention_heads,
99
+ attention_head_dim,
100
+ num_layers,
101
+ attention_block_types=(
102
+ "Temporal_Self",
103
+ "Temporal_Self",
104
+ ),
105
+ dropout=0.0,
106
+ norm_num_groups=32,
107
+ cross_attention_dim=768,
108
+ activation_fn="geglu",
109
+ attention_bias=False,
110
+ upcast_attention=False,
111
+ cross_frame_attention_mode=None,
112
+ temporal_position_encoding=False,
113
+ temporal_position_encoding_max_len=24,
114
+ ):
115
+ super().__init__()
116
+
117
+ inner_dim = num_attention_heads * attention_head_dim
118
+
119
+ self.norm = torch.nn.GroupNorm(
120
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
121
+ )
122
+ self.proj_in = nn.Linear(in_channels, inner_dim)
123
+
124
+ self.transformer_blocks = nn.ModuleList(
125
+ [
126
+ TemporalTransformerBlock(
127
+ dim=inner_dim,
128
+ num_attention_heads=num_attention_heads,
129
+ attention_head_dim=attention_head_dim,
130
+ attention_block_types=attention_block_types,
131
+ dropout=dropout,
132
+ norm_num_groups=norm_num_groups,
133
+ cross_attention_dim=cross_attention_dim,
134
+ activation_fn=activation_fn,
135
+ attention_bias=attention_bias,
136
+ upcast_attention=upcast_attention,
137
+ cross_frame_attention_mode=cross_frame_attention_mode,
138
+ temporal_position_encoding=temporal_position_encoding,
139
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
140
+ )
141
+ for d in range(num_layers)
142
+ ]
143
+ )
144
+ self.proj_out = nn.Linear(inner_dim, in_channels)
145
+
146
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
147
+ assert (
148
+ hidden_states.dim() == 5
149
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
150
+ video_length = hidden_states.shape[2]
151
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
152
+
153
+ batch, channel, height, weight = hidden_states.shape
154
+ residual = hidden_states
155
+
156
+ hidden_states = self.norm(hidden_states)
157
+ inner_dim = hidden_states.shape[1]
158
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
159
+ batch, height * weight, inner_dim
160
+ )
161
+ hidden_states = self.proj_in(hidden_states)
162
+
163
+ # Transformer Blocks
164
+ for block in self.transformer_blocks:
165
+ hidden_states = block(
166
+ hidden_states,
167
+ encoder_hidden_states=encoder_hidden_states,
168
+ video_length=video_length,
169
+ )
170
+
171
+ # output
172
+ hidden_states = self.proj_out(hidden_states)
173
+ hidden_states = (
174
+ hidden_states.reshape(batch, height, weight, inner_dim)
175
+ .permute(0, 3, 1, 2)
176
+ .contiguous()
177
+ )
178
+
179
+ output = hidden_states + residual
180
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
181
+
182
+ return output
183
+
184
+
185
+ class TemporalTransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim,
189
+ num_attention_heads,
190
+ attention_head_dim,
191
+ attention_block_types=(
192
+ "Temporal_Self",
193
+ "Temporal_Self",
194
+ ),
195
+ dropout=0.0,
196
+ norm_num_groups=32,
197
+ cross_attention_dim=768,
198
+ activation_fn="geglu",
199
+ attention_bias=False,
200
+ upcast_attention=False,
201
+ cross_frame_attention_mode=None,
202
+ temporal_position_encoding=False,
203
+ temporal_position_encoding_max_len=24,
204
+ ):
205
+ super().__init__()
206
+
207
+ attention_blocks = []
208
+ norms = []
209
+
210
+ for block_name in attention_block_types:
211
+ attention_blocks.append(
212
+ VersatileAttention(
213
+ attention_mode=block_name.split("_")[0],
214
+ cross_attention_dim=cross_attention_dim
215
+ if block_name.endswith("_Cross")
216
+ else None,
217
+ query_dim=dim,
218
+ heads=num_attention_heads,
219
+ dim_head=attention_head_dim,
220
+ dropout=dropout,
221
+ bias=attention_bias,
222
+ upcast_attention=upcast_attention,
223
+ cross_frame_attention_mode=cross_frame_attention_mode,
224
+ temporal_position_encoding=temporal_position_encoding,
225
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
226
+ )
227
+ )
228
+ norms.append(nn.LayerNorm(dim))
229
+
230
+ self.attention_blocks = nn.ModuleList(attention_blocks)
231
+ self.norms = nn.ModuleList(norms)
232
+
233
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
234
+ self.ff_norm = nn.LayerNorm(dim)
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states,
239
+ encoder_hidden_states=None,
240
+ attention_mask=None,
241
+ video_length=None,
242
+ ):
243
+ for attention_block, norm in zip(self.attention_blocks, self.norms):
244
+ norm_hidden_states = norm(hidden_states)
245
+ hidden_states = (
246
+ attention_block(
247
+ norm_hidden_states,
248
+ encoder_hidden_states=encoder_hidden_states
249
+ if attention_block.is_cross_attention
250
+ else None,
251
+ video_length=video_length,
252
+ )
253
+ + hidden_states
254
+ )
255
+
256
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
257
+
258
+ output = hidden_states
259
+ return output
260
+
261
+
262
+ class PositionalEncoding(nn.Module):
263
+ def __init__(self, d_model, dropout=0.0, max_len=24):
264
+ super().__init__()
265
+ self.dropout = nn.Dropout(p=dropout)
266
+ position = torch.arange(max_len).unsqueeze(1)
267
+ div_term = torch.exp(
268
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
269
+ )
270
+ pe = torch.zeros(1, max_len, d_model)
271
+ pe[0, :, 0::2] = torch.sin(position * div_term)
272
+ pe[0, :, 1::2] = torch.cos(position * div_term)
273
+ self.register_buffer("pe", pe)
274
+
275
+ def forward(self, x):
276
+ x = x + self.pe[:, : x.size(1)]
277
+ return self.dropout(x)
278
+
279
+
280
+ class VersatileAttention(Attention):
281
+ def __init__(
282
+ self,
283
+ attention_mode=None,
284
+ cross_frame_attention_mode=None,
285
+ temporal_position_encoding=False,
286
+ temporal_position_encoding_max_len=24,
287
+ *args,
288
+ **kwargs,
289
+ ):
290
+ super().__init__(*args, **kwargs)
291
+ assert attention_mode == "Temporal"
292
+
293
+ self.attention_mode = attention_mode
294
+ self.is_cross_attention = kwargs["cross_attention_dim"] is not None
295
+
296
+ self.pos_encoder = (
297
+ PositionalEncoding(
298
+ kwargs["query_dim"],
299
+ dropout=0.0,
300
+ max_len=temporal_position_encoding_max_len,
301
+ )
302
+ if (temporal_position_encoding and attention_mode == "Temporal")
303
+ else None
304
+ )
305
+
306
+ def extra_repr(self):
307
+ return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
308
+
309
+ def set_use_memory_efficient_attention_xformers(
310
+ self,
311
+ use_memory_efficient_attention_xformers: bool,
312
+ attention_op: Optional[Callable] = None,
313
+ ):
314
+ if use_memory_efficient_attention_xformers:
315
+ if not is_xformers_available():
316
+ raise ModuleNotFoundError(
317
+ (
318
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
319
+ " xformers"
320
+ ),
321
+ name="xformers",
322
+ )
323
+ elif not torch.cuda.is_available():
324
+ raise ValueError(
325
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
326
+ " only available for GPU "
327
+ )
328
+ else:
329
+ try:
330
+ # Make sure we can run the memory efficient attention
331
+ _ = xformers.ops.memory_efficient_attention(
332
+ torch.randn((1, 2, 40), device="cuda"),
333
+ torch.randn((1, 2, 40), device="cuda"),
334
+ torch.randn((1, 2, 40), device="cuda"),
335
+ )
336
+ except Exception as e:
337
+ raise e
338
+
339
+ # XFormersAttnProcessor corrupts video generation and work with Pytorch 1.13.
340
+ # Pytorch 2.0.1 AttnProcessor works the same as XFormersAttnProcessor in Pytorch 1.13.
341
+ # You don't need XFormersAttnProcessor here.
342
+ # processor = XFormersAttnProcessor(
343
+ # attention_op=attention_op,
344
+ # )
345
+ processor = AttnProcessor()
346
+ else:
347
+ processor = AttnProcessor()
348
+
349
+ self.set_processor(processor)
350
+
351
+ def forward(
352
+ self,
353
+ hidden_states,
354
+ encoder_hidden_states=None,
355
+ attention_mask=None,
356
+ video_length=None,
357
+ **cross_attention_kwargs,
358
+ ):
359
+ if self.attention_mode == "Temporal":
360
+ d = hidden_states.shape[1] # d means HxW
361
+ hidden_states = rearrange(
362
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
363
+ )
364
+
365
+ if self.pos_encoder is not None:
366
+ hidden_states = self.pos_encoder(hidden_states)
367
+
368
+ encoder_hidden_states = (
369
+ repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d)
370
+ if encoder_hidden_states is not None
371
+ else encoder_hidden_states
372
+ )
373
+
374
+ else:
375
+ raise NotImplementedError
376
+
377
+ hidden_states = self.processor(
378
+ self,
379
+ hidden_states,
380
+ encoder_hidden_states=encoder_hidden_states,
381
+ attention_mask=attention_mask,
382
+ **cross_attention_kwargs,
383
+ )
384
+
385
+ if self.attention_mode == "Temporal":
386
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
387
+
388
+ return hidden_states
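As a rough usage sketch (assumed, not taken from this repo's training or inference scripts): the motion module takes a 5D latent shaped (batch, channels, frames, height, width) and returns a tensor of the same shape; because `proj_out` is zero-initialised, it behaves as an identity mapping before training.

```python
# Hypothetical sketch: apply a "Vanilla" motion module to a 5D latent.
import torch
from musepose.models.motion_module import get_motion_module

mm = get_motion_module(
    in_channels=320,  # should be divisible by norm_num_groups (32)
    motion_module_type="Vanilla",
    motion_module_kwargs={"num_attention_heads": 8, "num_transformer_block": 1},
)
latent = torch.randn(1, 320, 8, 32, 32)  # b c f h w
out = mm(latent, temb=None, encoder_hidden_states=None)
print(out.shape)  # torch.Size([1, 320, 8, 32, 32]), unchanged shape
```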
musepose/models/mutual_self_attention.py ADDED
@@ -0,0 +1,363 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/models/mutual_self_attention.py
2
+ from typing import Any, Dict, Optional
3
+
4
+ import torch
5
+ from einops import rearrange
6
+
7
+ from musepose.models.attention import TemporalBasicTransformerBlock
8
+
9
+ from .attention import BasicTransformerBlock
10
+
11
+
12
+ def torch_dfs(model: torch.nn.Module):
13
+ result = [model]
14
+ for child in model.children():
15
+ result += torch_dfs(child)
16
+ return result
17
+
18
+
19
+ class ReferenceAttentionControl:
20
+ def __init__(
21
+ self,
22
+ unet,
23
+ mode="write",
24
+ do_classifier_free_guidance=False,
25
+ attention_auto_machine_weight=float("inf"),
26
+ gn_auto_machine_weight=1.0,
27
+ style_fidelity=1.0,
28
+ reference_attn=True,
29
+ reference_adain=False,
30
+ fusion_blocks="midup",
31
+ batch_size=1,
32
+ ) -> None:
33
+ # 10. Modify self attention and group norm
34
+ self.unet = unet
35
+ assert mode in ["read", "write"]
36
+ assert fusion_blocks in ["midup", "full"]
37
+ self.reference_attn = reference_attn
38
+ self.reference_adain = reference_adain
39
+ self.fusion_blocks = fusion_blocks
40
+ self.register_reference_hooks(
41
+ mode,
42
+ do_classifier_free_guidance,
43
+ attention_auto_machine_weight,
44
+ gn_auto_machine_weight,
45
+ style_fidelity,
46
+ reference_attn,
47
+ reference_adain,
48
+ fusion_blocks,
49
+ batch_size=batch_size,
50
+ )
51
+
52
+ def register_reference_hooks(
53
+ self,
54
+ mode,
55
+ do_classifier_free_guidance,
56
+ attention_auto_machine_weight,
57
+ gn_auto_machine_weight,
58
+ style_fidelity,
59
+ reference_attn,
60
+ reference_adain,
61
+ dtype=torch.float16,
62
+ batch_size=1,
63
+ num_images_per_prompt=1,
64
+ device=torch.device("cpu"),
65
+ fusion_blocks="midup",
66
+ ):
67
+ MODE = mode
68
+ do_classifier_free_guidance = do_classifier_free_guidance
69
+ attention_auto_machine_weight = attention_auto_machine_weight
70
+ gn_auto_machine_weight = gn_auto_machine_weight
71
+ style_fidelity = style_fidelity
72
+ reference_attn = reference_attn
73
+ reference_adain = reference_adain
74
+ fusion_blocks = fusion_blocks
75
+ num_images_per_prompt = num_images_per_prompt
76
+ dtype = dtype
77
+ if do_classifier_free_guidance:
78
+ uc_mask = (
79
+ torch.Tensor(
80
+ [1] * batch_size * num_images_per_prompt * 16
81
+ + [0] * batch_size * num_images_per_prompt * 16
82
+ )
83
+ .to(device)
84
+ .bool()
85
+ )
86
+ else:
87
+ uc_mask = (
88
+ torch.Tensor([0] * batch_size * num_images_per_prompt * 2)
89
+ .to(device)
90
+ .bool()
91
+ )
92
+
93
+ def hacked_basic_transformer_inner_forward(
94
+ self,
95
+ hidden_states: torch.FloatTensor,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
98
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
99
+ timestep: Optional[torch.LongTensor] = None,
100
+ cross_attention_kwargs: Dict[str, Any] = None,
101
+ class_labels: Optional[torch.LongTensor] = None,
102
+ video_length=None,
103
+ ):
104
+ if self.use_ada_layer_norm: # False
105
+ norm_hidden_states = self.norm1(hidden_states, timestep)
106
+ elif self.use_ada_layer_norm_zero:
107
+ (
108
+ norm_hidden_states,
109
+ gate_msa,
110
+ shift_mlp,
111
+ scale_mlp,
112
+ gate_mlp,
113
+ ) = self.norm1(
114
+ hidden_states,
115
+ timestep,
116
+ class_labels,
117
+ hidden_dtype=hidden_states.dtype,
118
+ )
119
+ else:
120
+ norm_hidden_states = self.norm1(hidden_states)
121
+
122
+ # 1. Self-Attention
123
+ # self.only_cross_attention = False
124
+ cross_attention_kwargs = (
125
+ cross_attention_kwargs if cross_attention_kwargs is not None else {}
126
+ )
127
+ if self.only_cross_attention:
128
+ attn_output = self.attn1(
129
+ norm_hidden_states,
130
+ encoder_hidden_states=encoder_hidden_states
131
+ if self.only_cross_attention
132
+ else None,
133
+ attention_mask=attention_mask,
134
+ **cross_attention_kwargs,
135
+ )
136
+ else:
137
+ if MODE == "write":
138
+ self.bank.append(norm_hidden_states.clone())
139
+ attn_output = self.attn1(
140
+ norm_hidden_states,
141
+ encoder_hidden_states=encoder_hidden_states
142
+ if self.only_cross_attention
143
+ else None,
144
+ attention_mask=attention_mask,
145
+ **cross_attention_kwargs,
146
+ )
147
+ if MODE == "read":
148
+ bank_fea = [
149
+ rearrange(
150
+ d.unsqueeze(1).repeat(1, video_length, 1, 1),
151
+ "b t l c -> (b t) l c",
152
+ )
153
+ for d in self.bank
154
+ ]
155
+ modify_norm_hidden_states = torch.cat(
156
+ [norm_hidden_states] + bank_fea, dim=1
157
+ )
158
+ hidden_states_uc = (
159
+ self.attn1(
160
+ norm_hidden_states,
161
+ encoder_hidden_states=modify_norm_hidden_states,
162
+ attention_mask=attention_mask,
163
+ )
164
+ + hidden_states
165
+ )
166
+ if do_classifier_free_guidance:
167
+ hidden_states_c = hidden_states_uc.clone()
168
+ _uc_mask = uc_mask.clone()
169
+ if hidden_states.shape[0] != _uc_mask.shape[0]:
170
+ _uc_mask = (
171
+ torch.Tensor(
172
+ [1] * (hidden_states.shape[0] // 2)
173
+ + [0] * (hidden_states.shape[0] // 2)
174
+ )
175
+ .to(device)
176
+ .bool()
177
+ )
178
+ hidden_states_c[_uc_mask] = (
179
+ self.attn1(
180
+ norm_hidden_states[_uc_mask],
181
+ encoder_hidden_states=norm_hidden_states[_uc_mask],
182
+ attention_mask=attention_mask,
183
+ )
184
+ + hidden_states[_uc_mask]
185
+ )
186
+ hidden_states = hidden_states_c.clone()
187
+ else:
188
+ hidden_states = hidden_states_uc
189
+
190
+ # self.bank.clear()
191
+ if self.attn2 is not None:
192
+ # Cross-Attention
193
+ norm_hidden_states = (
194
+ self.norm2(hidden_states, timestep)
195
+ if self.use_ada_layer_norm
196
+ else self.norm2(hidden_states)
197
+ )
198
+ hidden_states = (
199
+ self.attn2(
200
+ norm_hidden_states,
201
+ encoder_hidden_states=encoder_hidden_states,
202
+ attention_mask=attention_mask,
203
+ )
204
+ + hidden_states
205
+ )
206
+
207
+ # Feed-forward
208
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
209
+
210
+ # Temporal-Attention
211
+ if self.unet_use_temporal_attention:
212
+ d = hidden_states.shape[1]
213
+ hidden_states = rearrange(
214
+ hidden_states, "(b f) d c -> (b d) f c", f=video_length
215
+ )
216
+ norm_hidden_states = (
217
+ self.norm_temp(hidden_states, timestep)
218
+ if self.use_ada_layer_norm
219
+ else self.norm_temp(hidden_states)
220
+ )
221
+ hidden_states = (
222
+ self.attn_temp(norm_hidden_states) + hidden_states
223
+ )
224
+ hidden_states = rearrange(
225
+ hidden_states, "(b d) f c -> (b f) d c", d=d
226
+ )
227
+
228
+ return hidden_states
229
+
230
+ if self.use_ada_layer_norm_zero:
231
+ attn_output = gate_msa.unsqueeze(1) * attn_output
232
+ hidden_states = attn_output + hidden_states
233
+
234
+ if self.attn2 is not None:
235
+ norm_hidden_states = (
236
+ self.norm2(hidden_states, timestep)
237
+ if self.use_ada_layer_norm
238
+ else self.norm2(hidden_states)
239
+ )
240
+
241
+ # 2. Cross-Attention
242
+ attn_output = self.attn2(
243
+ norm_hidden_states,
244
+ encoder_hidden_states=encoder_hidden_states,
245
+ attention_mask=encoder_attention_mask,
246
+ **cross_attention_kwargs,
247
+ )
248
+ hidden_states = attn_output + hidden_states
249
+
250
+ # 3. Feed-forward
251
+ norm_hidden_states = self.norm3(hidden_states)
252
+
253
+ if self.use_ada_layer_norm_zero:
254
+ norm_hidden_states = (
255
+ norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
256
+ )
257
+
258
+ ff_output = self.ff(norm_hidden_states)
259
+
260
+ if self.use_ada_layer_norm_zero:
261
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
262
+
263
+ hidden_states = ff_output + hidden_states
264
+
265
+ return hidden_states
266
+
267
+ if self.reference_attn:
268
+ if self.fusion_blocks == "midup":
269
+ attn_modules = [
270
+ module
271
+ for module in (
272
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
273
+ )
274
+ if isinstance(module, BasicTransformerBlock)
275
+ or isinstance(module, TemporalBasicTransformerBlock)
276
+ ]
277
+ elif self.fusion_blocks == "full":
278
+ attn_modules = [
279
+ module
280
+ for module in torch_dfs(self.unet)
281
+ if isinstance(module, BasicTransformerBlock)
282
+ or isinstance(module, TemporalBasicTransformerBlock)
283
+ ]
284
+ attn_modules = sorted(
285
+ attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
286
+ )
287
+
288
+ for i, module in enumerate(attn_modules):
289
+ module._original_inner_forward = module.forward
290
+ if isinstance(module, BasicTransformerBlock):
291
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
292
+ module, BasicTransformerBlock
293
+ )
294
+ if isinstance(module, TemporalBasicTransformerBlock):
295
+ module.forward = hacked_basic_transformer_inner_forward.__get__(
296
+ module, TemporalBasicTransformerBlock
297
+ )
298
+
299
+ module.bank = []
300
+ module.attn_weight = float(i) / float(len(attn_modules))
301
+
302
+ def update(self, writer, dtype=torch.float16):
303
+ if self.reference_attn:
304
+ if self.fusion_blocks == "midup":
305
+ reader_attn_modules = [
306
+ module
307
+ for module in (
308
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
309
+ )
310
+ if isinstance(module, TemporalBasicTransformerBlock)
311
+ ]
312
+ writer_attn_modules = [
313
+ module
314
+ for module in (
315
+ torch_dfs(writer.unet.mid_block)
316
+ + torch_dfs(writer.unet.up_blocks)
317
+ )
318
+ if isinstance(module, BasicTransformerBlock)
319
+ ]
320
+ elif self.fusion_blocks == "full":
321
+ reader_attn_modules = [
322
+ module
323
+ for module in torch_dfs(self.unet)
324
+ if isinstance(module, TemporalBasicTransformerBlock)
325
+ ]
326
+ writer_attn_modules = [
327
+ module
328
+ for module in torch_dfs(writer.unet)
329
+ if isinstance(module, BasicTransformerBlock)
330
+ ]
331
+ reader_attn_modules = sorted(
332
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
333
+ )
334
+ writer_attn_modules = sorted(
335
+ writer_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
336
+ )
337
+ for r, w in zip(reader_attn_modules, writer_attn_modules):
338
+ r.bank = [v.clone().to(dtype) for v in w.bank]
339
+ # w.bank.clear()
340
+
341
+ def clear(self):
342
+ if self.reference_attn:
343
+ if self.fusion_blocks == "midup":
344
+ reader_attn_modules = [
345
+ module
346
+ for module in (
347
+ torch_dfs(self.unet.mid_block) + torch_dfs(self.unet.up_blocks)
348
+ )
349
+ if isinstance(module, BasicTransformerBlock)
350
+ or isinstance(module, TemporalBasicTransformerBlock)
351
+ ]
352
+ elif self.fusion_blocks == "full":
353
+ reader_attn_modules = [
354
+ module
355
+ for module in torch_dfs(self.unet)
356
+ if isinstance(module, BasicTransformerBlock)
357
+ or isinstance(module, TemporalBasicTransformerBlock)
358
+ ]
359
+ reader_attn_modules = sorted(
360
+ reader_attn_modules, key=lambda x: -x.norm1.normalized_shape[0]
361
+ )
362
+ for r in reader_attn_modules:
363
+ r.bank.clear()
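The control flow is easiest to see end to end: a "write" controller hooks the 2D reference UNet so each `BasicTransformerBlock` caches its normalized hidden states in `bank`, `update` copies those caches into the 3D denoising UNet, and the "read" hooks concatenate them into self-attention. The wiring sketch below is hypothetical; it assumes the reference and denoising UNets from this repo are already constructed and that `ref_latents`, `timestep`, and `clip_embeds` are prepared by the caller.

```python
# Hypothetical wiring sketch for ReferenceAttentionControl (UNets assumed to exist already).
from musepose.models.mutual_self_attention import ReferenceAttentionControl

writer = ReferenceAttentionControl(reference_unet, mode="write", fusion_blocks="full", batch_size=1)
reader = ReferenceAttentionControl(denoising_unet, mode="read", fusion_blocks="full", batch_size=1)

# 1. One forward pass of the reference UNet fills each hooked block's `bank` with features.
reference_unet(ref_latents, timestep, encoder_hidden_states=clip_embeds)
# 2. Copy the cached features into the denoising UNet's temporal transformer blocks.
reader.update(writer)
# 3. Denoising steps now attend to the reference features through the hooked forward.
#    ... run denoising_unet inside the sampling loop ...
# 4. Drop the cached features before processing the next sample.
reader.clear()
```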
musepose/models/pose_guider.py ADDED
@@ -0,0 +1,57 @@
1
+ from typing import Tuple
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+
8
+ from musepose.models.motion_module import zero_module
9
+ from musepose.models.resnet import InflatedConv3d
10
+
11
+
12
+ class PoseGuider(ModelMixin):
13
+ def __init__(
14
+ self,
15
+ conditioning_embedding_channels: int,
16
+ conditioning_channels: int = 3,
17
+ block_out_channels: Tuple[int] = (16, 32, 64, 128),
18
+ ):
19
+ super().__init__()
20
+ self.conv_in = InflatedConv3d(
21
+ conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
22
+ )
23
+
24
+ self.blocks = nn.ModuleList([])
25
+
26
+ for i in range(len(block_out_channels) - 1):
27
+ channel_in = block_out_channels[i]
28
+ channel_out = block_out_channels[i + 1]
29
+ self.blocks.append(
30
+ InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
31
+ )
32
+ self.blocks.append(
33
+ InflatedConv3d(
34
+ channel_in, channel_out, kernel_size=3, padding=1, stride=2
35
+ )
36
+ )
37
+
38
+ self.conv_out = zero_module(
39
+ InflatedConv3d(
40
+ block_out_channels[-1],
41
+ conditioning_embedding_channels,
42
+ kernel_size=3,
43
+ padding=1,
44
+ )
45
+ )
46
+
47
+ def forward(self, conditioning):
48
+ embedding = self.conv_in(conditioning)
49
+ embedding = F.silu(embedding)
50
+
51
+ for block in self.blocks:
52
+ embedding = block(embedding)
53
+ embedding = F.silu(embedding)
54
+
55
+ embedding = self.conv_out(embedding)
56
+
57
+ return embedding
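A minimal sketch of how the pose guider maps rendered pose frames to conditioning features (assumed usage; the channel count and resolution are illustrative): the default `block_out_channels` gives three stride-2 stages, so the spatial resolution drops by 8x, and the zero-initialised output convolution means the guidance signal starts at zero.

```python
# Hypothetical sketch: project a stack of pose frames into UNet conditioning features.
import torch
from musepose.models.pose_guider import PoseGuider

guider = PoseGuider(conditioning_embedding_channels=320)
poses = torch.randn(1, 3, 8, 512, 512)  # b c f h w, e.g. 8 rendered DWPose frames
cond = guider(poses)
print(cond.shape)  # torch.Size([1, 320, 8, 64, 64])
```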
musepose/models/resnet.py ADDED
@@ -0,0 +1,252 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+
8
+
9
+ class InflatedConv3d(nn.Conv2d):
10
+ def forward(self, x):
11
+ video_length = x.shape[2]
12
+
13
+ x = rearrange(x, "b c f h w -> (b f) c h w")
14
+ x = super().forward(x)
15
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
16
+
17
+ return x
18
+
19
+
20
+ class InflatedGroupNorm(nn.GroupNorm):
21
+ def forward(self, x):
22
+ video_length = x.shape[2]
23
+
24
+ x = rearrange(x, "b c f h w -> (b f) c h w")
25
+ x = super().forward(x)
26
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
27
+
28
+ return x
29
+
30
+
31
+ class Upsample3D(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ use_conv=False,
36
+ use_conv_transpose=False,
37
+ out_channels=None,
38
+ name="conv",
39
+ ):
40
+ super().__init__()
41
+ self.channels = channels
42
+ self.out_channels = out_channels or channels
43
+ self.use_conv = use_conv
44
+ self.use_conv_transpose = use_conv_transpose
45
+ self.name = name
46
+
47
+ conv = None
48
+ if use_conv_transpose:
49
+ raise NotImplementedError
50
+ elif use_conv:
51
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
52
+
53
+ def forward(self, hidden_states, output_size=None):
54
+ assert hidden_states.shape[1] == self.channels
55
+
56
+ if self.use_conv_transpose:
57
+ raise NotImplementedError
58
+
59
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
60
+ dtype = hidden_states.dtype
61
+ if dtype == torch.bfloat16:
62
+ hidden_states = hidden_states.to(torch.float32)
63
+
64
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
65
+ if hidden_states.shape[0] >= 64:
66
+ hidden_states = hidden_states.contiguous()
67
+
68
+ # if `output_size` is passed we force the interpolation output
69
+ # size and do not make use of `scale_factor=2`
70
+ if output_size is None:
71
+ hidden_states = F.interpolate(
72
+ hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest"
73
+ )
74
+ else:
75
+ hidden_states = F.interpolate(
76
+ hidden_states, size=output_size, mode="nearest"
77
+ )
78
+
79
+ # If the input is bfloat16, we cast back to bfloat16
80
+ if dtype == torch.bfloat16:
81
+ hidden_states = hidden_states.to(dtype)
82
+
83
+ # if self.use_conv:
84
+ # if self.name == "conv":
85
+ # hidden_states = self.conv(hidden_states)
86
+ # else:
87
+ # hidden_states = self.Conv2d_0(hidden_states)
88
+ hidden_states = self.conv(hidden_states)
89
+
90
+ return hidden_states
91
+
92
+
93
+ class Downsample3D(nn.Module):
94
+ def __init__(
95
+ self, channels, use_conv=False, out_channels=None, padding=1, name="conv"
96
+ ):
97
+ super().__init__()
98
+ self.channels = channels
99
+ self.out_channels = out_channels or channels
100
+ self.use_conv = use_conv
101
+ self.padding = padding
102
+ stride = 2
103
+ self.name = name
104
+
105
+ if use_conv:
106
+ self.conv = InflatedConv3d(
107
+ self.channels, self.out_channels, 3, stride=stride, padding=padding
108
+ )
109
+ else:
110
+ raise NotImplementedError
111
+
112
+ def forward(self, hidden_states):
113
+ assert hidden_states.shape[1] == self.channels
114
+ if self.use_conv and self.padding == 0:
115
+ raise NotImplementedError
116
+
117
+ assert hidden_states.shape[1] == self.channels
118
+ hidden_states = self.conv(hidden_states)
119
+
120
+ return hidden_states
121
+
122
+
123
+ class ResnetBlock3D(nn.Module):
124
+ def __init__(
125
+ self,
126
+ *,
127
+ in_channels,
128
+ out_channels=None,
129
+ conv_shortcut=False,
130
+ dropout=0.0,
131
+ temb_channels=512,
132
+ groups=32,
133
+ groups_out=None,
134
+ pre_norm=True,
135
+ eps=1e-6,
136
+ non_linearity="swish",
137
+ time_embedding_norm="default",
138
+ output_scale_factor=1.0,
139
+ use_in_shortcut=None,
140
+ use_inflated_groupnorm=None,
141
+ ):
142
+ super().__init__()
143
+ self.pre_norm = pre_norm
144
+ self.pre_norm = True
145
+ self.in_channels = in_channels
146
+ out_channels = in_channels if out_channels is None else out_channels
147
+ self.out_channels = out_channels
148
+ self.use_conv_shortcut = conv_shortcut
149
+ self.time_embedding_norm = time_embedding_norm
150
+ self.output_scale_factor = output_scale_factor
151
+
152
+ if groups_out is None:
153
+ groups_out = groups
154
+
155
+ assert use_inflated_groupnorm is not None
156
+ if use_inflated_groupnorm:
157
+ self.norm1 = InflatedGroupNorm(
158
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
159
+ )
160
+ else:
161
+ self.norm1 = torch.nn.GroupNorm(
162
+ num_groups=groups, num_channels=in_channels, eps=eps, affine=True
163
+ )
164
+
165
+ self.conv1 = InflatedConv3d(
166
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
167
+ )
168
+
169
+ if temb_channels is not None:
170
+ if self.time_embedding_norm == "default":
171
+ time_emb_proj_out_channels = out_channels
172
+ elif self.time_embedding_norm == "scale_shift":
173
+ time_emb_proj_out_channels = out_channels * 2
174
+ else:
175
+ raise ValueError(
176
+ f"unknown time_embedding_norm : {self.time_embedding_norm} "
177
+ )
178
+
179
+ self.time_emb_proj = torch.nn.Linear(
180
+ temb_channels, time_emb_proj_out_channels
181
+ )
182
+ else:
183
+ self.time_emb_proj = None
184
+
185
+ if use_inflated_groupnorm:
186
+ self.norm2 = InflatedGroupNorm(
187
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
188
+ )
189
+ else:
190
+ self.norm2 = torch.nn.GroupNorm(
191
+ num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True
192
+ )
193
+ self.dropout = torch.nn.Dropout(dropout)
194
+ self.conv2 = InflatedConv3d(
195
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
196
+ )
197
+
198
+ if non_linearity == "swish":
199
+ self.nonlinearity = lambda x: F.silu(x)
200
+ elif non_linearity == "mish":
201
+ self.nonlinearity = Mish()
202
+ elif non_linearity == "silu":
203
+ self.nonlinearity = nn.SiLU()
204
+
205
+ self.use_in_shortcut = (
206
+ self.in_channels != self.out_channels
207
+ if use_in_shortcut is None
208
+ else use_in_shortcut
209
+ )
210
+
211
+ self.conv_shortcut = None
212
+ if self.use_in_shortcut:
213
+ self.conv_shortcut = InflatedConv3d(
214
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
215
+ )
216
+
217
+ def forward(self, input_tensor, temb):
218
+ hidden_states = input_tensor
219
+
220
+ hidden_states = self.norm1(hidden_states)
221
+ hidden_states = self.nonlinearity(hidden_states)
222
+
223
+ hidden_states = self.conv1(hidden_states)
224
+
225
+ if temb is not None:
226
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
227
+
228
+ if temb is not None and self.time_embedding_norm == "default":
229
+ hidden_states = hidden_states + temb
230
+
231
+ hidden_states = self.norm2(hidden_states)
232
+
233
+ if temb is not None and self.time_embedding_norm == "scale_shift":
234
+ scale, shift = torch.chunk(temb, 2, dim=1)
235
+ hidden_states = hidden_states * (1 + scale) + shift
236
+
237
+ hidden_states = self.nonlinearity(hidden_states)
238
+
239
+ hidden_states = self.dropout(hidden_states)
240
+ hidden_states = self.conv2(hidden_states)
241
+
242
+ if self.conv_shortcut is not None:
243
+ input_tensor = self.conv_shortcut(input_tensor)
244
+
245
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
246
+
247
+ return output_tensor
248
+
249
+
250
+ class Mish(torch.nn.Module):
251
+ def forward(self, hidden_states):
252
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
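These "inflated" layers apply 2D convolution and normalization per frame by folding the frame axis into the batch dimension. A minimal sketch of one residual block on a 5D tensor follows (illustrative sizes, not from a MusePose config); note that `use_inflated_groupnorm` must be passed explicitly because of the assert in `__init__`.

```python
# Hypothetical sketch: run one inflated residual block on a (b, c, f, h, w) tensor.
import torch
from musepose.models.resnet import ResnetBlock3D

block = ResnetBlock3D(
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    use_inflated_groupnorm=False,  # True swaps nn.GroupNorm for InflatedGroupNorm
)
x = torch.randn(1, 320, 8, 32, 32)  # b c f h w
temb = torch.randn(1, 1280)         # diffusion time embedding
y = block(x, temb)
print(y.shape)  # torch.Size([1, 640, 8, 32, 32])
```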
musepose/models/transformer_2d.py ADDED
@@ -0,0 +1,395 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformer_2d.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, Optional
4
+
5
+ import torch
6
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
7
+ from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
8
+ from diffusers.models.modeling_utils import ModelMixin
9
+ from diffusers.models.normalization import AdaLayerNormSingle
10
+ from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
11
+ from torch import nn
12
+
13
+ from .attention import BasicTransformerBlock
14
+
15
+
16
+ @dataclass
17
+ class Transformer2DModelOutput(BaseOutput):
18
+ """
19
+ The output of [`Transformer2DModel`].
20
+
21
+ Args:
22
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
23
+ The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
24
+ distributions for the unnoised latent pixels.
25
+ """
26
+
27
+ sample: torch.FloatTensor
28
+ ref_feature: torch.FloatTensor
29
+
30
+
31
+ class Transformer2DModel(ModelMixin, ConfigMixin):
32
+ """
33
+ A 2D Transformer model for image-like data.
34
+
35
+ Parameters:
36
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
37
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
38
+ in_channels (`int`, *optional*):
39
+ The number of channels in the input and output (specify if the input is **continuous**).
40
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
41
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
42
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
43
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
44
+ This is fixed during training since it is used to learn a number of position embeddings.
45
+ num_vector_embeds (`int`, *optional*):
46
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
47
+ Includes the class for the masked latent pixel.
48
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
49
+ num_embeds_ada_norm ( `int`, *optional*):
50
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
51
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
52
+ added to the hidden states.
53
+
54
+ During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
55
+ attention_bias (`bool`, *optional*):
56
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
57
+ """
58
+
59
+ _supports_gradient_checkpointing = True
60
+
61
+ @register_to_config
62
+ def __init__(
63
+ self,
64
+ num_attention_heads: int = 16,
65
+ attention_head_dim: int = 88,
66
+ in_channels: Optional[int] = None,
67
+ out_channels: Optional[int] = None,
68
+ num_layers: int = 1,
69
+ dropout: float = 0.0,
70
+ norm_num_groups: int = 32,
71
+ cross_attention_dim: Optional[int] = None,
72
+ attention_bias: bool = False,
73
+ sample_size: Optional[int] = None,
74
+ num_vector_embeds: Optional[int] = None,
75
+ patch_size: Optional[int] = None,
76
+ activation_fn: str = "geglu",
77
+ num_embeds_ada_norm: Optional[int] = None,
78
+ use_linear_projection: bool = False,
79
+ only_cross_attention: bool = False,
80
+ double_self_attention: bool = False,
81
+ upcast_attention: bool = False,
82
+ norm_type: str = "layer_norm",
83
+ norm_elementwise_affine: bool = True,
84
+ norm_eps: float = 1e-5,
85
+ attention_type: str = "default",
86
+ caption_channels: int = None,
87
+ ):
88
+ super().__init__()
89
+ self.use_linear_projection = use_linear_projection
90
+ self.num_attention_heads = num_attention_heads
91
+ self.attention_head_dim = attention_head_dim
92
+ inner_dim = num_attention_heads * attention_head_dim
93
+
94
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
95
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
96
+
97
+ # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
98
+ # Define whether input is continuous or discrete depending on configuration
99
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
100
+ self.is_input_vectorized = num_vector_embeds is not None
101
+ self.is_input_patches = in_channels is not None and patch_size is not None
102
+
103
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
104
+ deprecation_message = (
105
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
106
+ " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
107
+ " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
108
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
109
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
110
+ )
111
+ deprecate(
112
+ "norm_type!=num_embeds_ada_norm",
113
+ "1.0.0",
114
+ deprecation_message,
115
+ standard_warn=False,
116
+ )
117
+ norm_type = "ada_norm"
118
+
119
+ if self.is_input_continuous and self.is_input_vectorized:
120
+ raise ValueError(
121
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
122
+ " sure that either `in_channels` or `num_vector_embeds` is None."
123
+ )
124
+ elif self.is_input_vectorized and self.is_input_patches:
125
+ raise ValueError(
126
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
127
+ " sure that either `num_vector_embeds` or `num_patches` is None."
128
+ )
129
+ elif (
130
+ not self.is_input_continuous
131
+ and not self.is_input_vectorized
132
+ and not self.is_input_patches
133
+ ):
134
+ raise ValueError(
135
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
136
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
137
+ )
138
+
139
+ # 2. Define input layers
140
+ self.in_channels = in_channels
141
+
142
+ self.norm = torch.nn.GroupNorm(
143
+ num_groups=norm_num_groups,
144
+ num_channels=in_channels,
145
+ eps=1e-6,
146
+ affine=True,
147
+ )
148
+ if use_linear_projection:
149
+ self.proj_in = linear_cls(in_channels, inner_dim)
150
+ else:
151
+ self.proj_in = conv_cls(
152
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
153
+ )
154
+
155
+ # 3. Define transformers blocks
156
+ self.transformer_blocks = nn.ModuleList(
157
+ [
158
+ BasicTransformerBlock(
159
+ inner_dim,
160
+ num_attention_heads,
161
+ attention_head_dim,
162
+ dropout=dropout,
163
+ cross_attention_dim=cross_attention_dim,
164
+ activation_fn=activation_fn,
165
+ num_embeds_ada_norm=num_embeds_ada_norm,
166
+ attention_bias=attention_bias,
167
+ only_cross_attention=only_cross_attention,
168
+ double_self_attention=double_self_attention,
169
+ upcast_attention=upcast_attention,
170
+ norm_type=norm_type,
171
+ norm_elementwise_affine=norm_elementwise_affine,
172
+ norm_eps=norm_eps,
173
+ attention_type=attention_type,
174
+ )
175
+ for d in range(num_layers)
176
+ ]
177
+ )
178
+
179
+ # 4. Define output layers
180
+ self.out_channels = in_channels if out_channels is None else out_channels
181
+ # TODO: should use out_channels for continuous projections
182
+ if use_linear_projection:
183
+ self.proj_out = linear_cls(inner_dim, in_channels)
184
+ else:
185
+ self.proj_out = conv_cls(
186
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
187
+ )
188
+
189
+ # 5. PixArt-Alpha blocks.
190
+ self.adaln_single = None
191
+ self.use_additional_conditions = False
192
+ if norm_type == "ada_norm_single":
193
+ self.use_additional_conditions = self.config.sample_size == 128
194
+ # TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
195
+ # additional conditions until we find better name
196
+ self.adaln_single = AdaLayerNormSingle(
197
+ inner_dim, use_additional_conditions=self.use_additional_conditions
198
+ )
199
+
200
+ self.caption_projection = None
201
+ if caption_channels is not None:
202
+ self.caption_projection = CaptionProjection(
203
+ in_features=caption_channels, hidden_size=inner_dim
204
+ )
205
+
206
+ self.gradient_checkpointing = False
207
+
208
+ def _set_gradient_checkpointing(self, module, value=False):
209
+ if hasattr(module, "gradient_checkpointing"):
210
+ module.gradient_checkpointing = value
211
+
212
+ def forward(
213
+ self,
214
+ hidden_states: torch.Tensor,
215
+ encoder_hidden_states: Optional[torch.Tensor] = None,
216
+ timestep: Optional[torch.LongTensor] = None,
217
+ added_cond_kwargs: Dict[str, torch.Tensor] = None,
218
+ class_labels: Optional[torch.LongTensor] = None,
219
+ cross_attention_kwargs: Dict[str, Any] = None,
220
+ attention_mask: Optional[torch.Tensor] = None,
221
+ encoder_attention_mask: Optional[torch.Tensor] = None,
222
+ return_dict: bool = True,
223
+ ):
224
+ """
225
+ The [`Transformer2DModel`] forward method.
226
+
227
+ Args:
228
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
229
+ Input `hidden_states`.
230
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
231
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
232
+ self-attention.
233
+ timestep ( `torch.LongTensor`, *optional*):
234
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
235
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
236
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
237
+ `AdaLayerZeroNorm`.
238
+ cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
239
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
240
+ `self.processor` in
241
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
242
+ attention_mask ( `torch.Tensor`, *optional*):
243
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
244
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
245
+ negative values to the attention scores corresponding to "discard" tokens.
246
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
247
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
248
+
249
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
250
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
251
+
252
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
253
+ above. This bias will be added to the cross-attention scores.
254
+ return_dict (`bool`, *optional*, defaults to `True`):
255
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
256
+ tuple.
257
+
258
+ Returns:
259
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
260
+ `tuple` where the first element is the sample tensor.
261
+ """
262
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
263
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
264
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
265
+ # expects mask of shape:
266
+ # [batch, key_tokens]
267
+ # adds singleton query_tokens dimension:
268
+ # [batch, 1, key_tokens]
269
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
270
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
271
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
272
+ if attention_mask is not None and attention_mask.ndim == 2:
273
+ # assume that mask is expressed as:
274
+ # (1 = keep, 0 = discard)
275
+ # convert mask into a bias that can be added to attention scores:
276
+ # (keep = +0, discard = -10000.0)
277
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
278
+ attention_mask = attention_mask.unsqueeze(1)
279
+
280
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
281
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
282
+ encoder_attention_mask = (
283
+ 1 - encoder_attention_mask.to(hidden_states.dtype)
284
+ ) * -10000.0
285
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
286
+
287
+ # Retrieve lora scale.
288
+ lora_scale = (
289
+ cross_attention_kwargs.get("scale", 1.0)
290
+ if cross_attention_kwargs is not None
291
+ else 1.0
292
+ )
293
+
294
+ # 1. Input
295
+ batch, _, height, width = hidden_states.shape
296
+ residual = hidden_states
297
+
298
+ hidden_states = self.norm(hidden_states)
299
+ if not self.use_linear_projection:
300
+ hidden_states = (
301
+ self.proj_in(hidden_states, scale=lora_scale)
302
+ if not USE_PEFT_BACKEND
303
+ else self.proj_in(hidden_states)
304
+ )
305
+ inner_dim = hidden_states.shape[1]
306
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
307
+ batch, height * width, inner_dim
308
+ )
309
+ else:
310
+ inner_dim = hidden_states.shape[1]
311
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
312
+ batch, height * width, inner_dim
313
+ )
314
+ hidden_states = (
315
+ self.proj_in(hidden_states, scale=lora_scale)
316
+ if not USE_PEFT_BACKEND
317
+ else self.proj_in(hidden_states)
318
+ )
319
+
320
+ # 2. Blocks
321
+ if self.caption_projection is not None:
322
+ batch_size = hidden_states.shape[0]
323
+ encoder_hidden_states = self.caption_projection(encoder_hidden_states)
324
+ encoder_hidden_states = encoder_hidden_states.view(
325
+ batch_size, -1, hidden_states.shape[-1]
326
+ )
327
+
328
+ ref_feature = hidden_states.reshape(batch, height, width, inner_dim)
329
+ for block in self.transformer_blocks:
330
+ if self.training and self.gradient_checkpointing:
331
+
332
+ def create_custom_forward(module, return_dict=None):
333
+ def custom_forward(*inputs):
334
+ if return_dict is not None:
335
+ return module(*inputs, return_dict=return_dict)
336
+ else:
337
+ return module(*inputs)
338
+
339
+ return custom_forward
340
+
341
+ ckpt_kwargs: Dict[str, Any] = (
342
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
343
+ )
344
+ hidden_states = torch.utils.checkpoint.checkpoint(
345
+ create_custom_forward(block),
346
+ hidden_states,
347
+ attention_mask,
348
+ encoder_hidden_states,
349
+ encoder_attention_mask,
350
+ timestep,
351
+ cross_attention_kwargs,
352
+ class_labels,
353
+ **ckpt_kwargs,
354
+ )
355
+ else:
356
+ hidden_states = block(
357
+ hidden_states,
358
+ attention_mask=attention_mask,
359
+ encoder_hidden_states=encoder_hidden_states,
360
+ encoder_attention_mask=encoder_attention_mask,
361
+ timestep=timestep,
362
+ cross_attention_kwargs=cross_attention_kwargs,
363
+ class_labels=class_labels,
364
+ )
365
+
366
+ # 3. Output
367
+ if self.is_input_continuous:
368
+ if not self.use_linear_projection:
369
+ hidden_states = (
370
+ hidden_states.reshape(batch, height, width, inner_dim)
371
+ .permute(0, 3, 1, 2)
372
+ .contiguous()
373
+ )
374
+ hidden_states = (
375
+ self.proj_out(hidden_states, scale=lora_scale)
376
+ if not USE_PEFT_BACKEND
377
+ else self.proj_out(hidden_states)
378
+ )
379
+ else:
380
+ hidden_states = (
381
+ self.proj_out(hidden_states, scale=lora_scale)
382
+ if not USE_PEFT_BACKEND
383
+ else self.proj_out(hidden_states)
384
+ )
385
+ hidden_states = (
386
+ hidden_states.reshape(batch, height, width, inner_dim)
387
+ .permute(0, 3, 1, 2)
388
+ .contiguous()
389
+ )
390
+
391
+ output = hidden_states + residual
392
+ if not return_dict:
393
+ return (output, ref_feature)
394
+
395
+ return Transformer2DModelOutput(sample=output, ref_feature=ref_feature)
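Note (not part of the commit): unlike the stock diffusers module, the forward pass above also returns `ref_feature`, the spatial feature map captured before the transformer blocks and reshaped to `(batch, height, width, inner_dim)`. Below is a minimal, hedged sketch of calling it; the constructor arguments are assumptions inferred from how `unet_2d_blocks.py` instantiates `Transformer2DModel` later in this commit.

import torch
from musepose.models.transformer_2d import Transformer2DModel

# illustrative sizes only: inner_dim = num_attention_heads * attention_head_dim = 320
model = Transformer2DModel(
    num_attention_heads=8,
    attention_head_dim=40,
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
    norm_num_groups=32,
)
hidden_states = torch.randn(2, 320, 32, 32)   # (batch, channels, height, width)
context = torch.randn(2, 77, 768)             # e.g. CLIP text/image embeddings
sample, ref_feature = model(
    hidden_states,
    encoder_hidden_states=context,
    return_dict=False,
)
# sample keeps the input shape; ref_feature is (2, 32, 32, 320), i.e. the pre-block
# features that the reference-attention machinery can reuse.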
musepose/models/transformer_3d.py ADDED
@@ -0,0 +1,169 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
6
+ from diffusers.models import ModelMixin
7
+ from diffusers.utils import BaseOutput
8
+ from diffusers.utils.import_utils import is_xformers_available
9
+ from einops import rearrange, repeat
10
+ from torch import nn
11
+
12
+ from .attention import TemporalBasicTransformerBlock
13
+
14
+
15
+ @dataclass
16
+ class Transformer3DModelOutput(BaseOutput):
17
+ sample: torch.FloatTensor
18
+
19
+
20
+ if is_xformers_available():
21
+ import xformers
22
+ import xformers.ops
23
+ else:
24
+ xformers = None
25
+
26
+
27
+ class Transformer3DModel(ModelMixin, ConfigMixin):
28
+ _supports_gradient_checkpointing = True
29
+
30
+ @register_to_config
31
+ def __init__(
32
+ self,
33
+ num_attention_heads: int = 16,
34
+ attention_head_dim: int = 88,
35
+ in_channels: Optional[int] = None,
36
+ num_layers: int = 1,
37
+ dropout: float = 0.0,
38
+ norm_num_groups: int = 32,
39
+ cross_attention_dim: Optional[int] = None,
40
+ attention_bias: bool = False,
41
+ activation_fn: str = "geglu",
42
+ num_embeds_ada_norm: Optional[int] = None,
43
+ use_linear_projection: bool = False,
44
+ only_cross_attention: bool = False,
45
+ upcast_attention: bool = False,
46
+ unet_use_cross_frame_attention=None,
47
+ unet_use_temporal_attention=None,
48
+ ):
49
+ super().__init__()
50
+ self.use_linear_projection = use_linear_projection
51
+ self.num_attention_heads = num_attention_heads
52
+ self.attention_head_dim = attention_head_dim
53
+ inner_dim = num_attention_heads * attention_head_dim
54
+
55
+ # Define input layers
56
+ self.in_channels = in_channels
57
+
58
+ self.norm = torch.nn.GroupNorm(
59
+ num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True
60
+ )
61
+ if use_linear_projection:
62
+ self.proj_in = nn.Linear(in_channels, inner_dim)
63
+ else:
64
+ self.proj_in = nn.Conv2d(
65
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
66
+ )
67
+
68
+ # Define transformers blocks
69
+ self.transformer_blocks = nn.ModuleList(
70
+ [
71
+ TemporalBasicTransformerBlock(
72
+ inner_dim,
73
+ num_attention_heads,
74
+ attention_head_dim,
75
+ dropout=dropout,
76
+ cross_attention_dim=cross_attention_dim,
77
+ activation_fn=activation_fn,
78
+ num_embeds_ada_norm=num_embeds_ada_norm,
79
+ attention_bias=attention_bias,
80
+ only_cross_attention=only_cross_attention,
81
+ upcast_attention=upcast_attention,
82
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
83
+ unet_use_temporal_attention=unet_use_temporal_attention,
84
+ )
85
+ for d in range(num_layers)
86
+ ]
87
+ )
88
+
89
+ # 4. Define output layers
90
+ if use_linear_projection:
91
+ self.proj_out = nn.Linear(in_channels, inner_dim)
92
+ else:
93
+ self.proj_out = nn.Conv2d(
94
+ inner_dim, in_channels, kernel_size=1, stride=1, padding=0
95
+ )
96
+
97
+ self.gradient_checkpointing = False
98
+
99
+ def _set_gradient_checkpointing(self, module, value=False):
100
+ if hasattr(module, "gradient_checkpointing"):
101
+ module.gradient_checkpointing = value
102
+
103
+ def forward(
104
+ self,
105
+ hidden_states,
106
+ encoder_hidden_states=None,
107
+ timestep=None,
108
+ return_dict: bool = True,
109
+ ):
110
+ # Input
111
+ assert (
112
+ hidden_states.dim() == 5
113
+ ), f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
114
+ video_length = hidden_states.shape[2]
115
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
116
+ if encoder_hidden_states.shape[0] != hidden_states.shape[0]:
117
+ encoder_hidden_states = repeat(
118
+ encoder_hidden_states, "b n c -> (b f) n c", f=video_length
119
+ )
120
+
121
+ batch, channel, height, weight = hidden_states.shape
122
+ residual = hidden_states
123
+
124
+ hidden_states = self.norm(hidden_states)
125
+ if not self.use_linear_projection:
126
+ hidden_states = self.proj_in(hidden_states)
127
+ inner_dim = hidden_states.shape[1]
128
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
129
+ batch, height * weight, inner_dim
130
+ )
131
+ else:
132
+ inner_dim = hidden_states.shape[1]
133
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
134
+ batch, height * weight, inner_dim
135
+ )
136
+ hidden_states = self.proj_in(hidden_states)
137
+
138
+ # Blocks
139
+ for i, block in enumerate(self.transformer_blocks):
140
+ hidden_states = block(
141
+ hidden_states,
142
+ encoder_hidden_states=encoder_hidden_states,
143
+ timestep=timestep,
144
+ video_length=video_length,
145
+ )
146
+
147
+ # Output
148
+ if not self.use_linear_projection:
149
+ hidden_states = (
150
+ hidden_states.reshape(batch, height, weight, inner_dim)
151
+ .permute(0, 3, 1, 2)
152
+ .contiguous()
153
+ )
154
+ hidden_states = self.proj_out(hidden_states)
155
+ else:
156
+ hidden_states = self.proj_out(hidden_states)
157
+ hidden_states = (
158
+ hidden_states.reshape(batch, height, weight, inner_dim)
159
+ .permute(0, 3, 1, 2)
160
+ .contiguous()
161
+ )
162
+
163
+ output = hidden_states + residual
164
+
165
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
166
+ if not return_dict:
167
+ return (output,)
168
+
169
+ return Transformer3DModelOutput(sample=output)
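For reference, `Transformer3DModel` flattens the frame axis into the batch, repeats the cross-attention context per frame, runs the `TemporalBasicTransformerBlock` stack, and folds the result back to video shape. A minimal sketch with illustrative sizes (not part of the commit):

import torch
from musepose.models.transformer_3d import Transformer3DModel

model = Transformer3DModel(
    num_attention_heads=8,
    attention_head_dim=40,   # inner_dim = 320 = in_channels
    in_channels=320,
    num_layers=1,
    cross_attention_dim=768,
)
latents = torch.randn(1, 320, 16, 32, 32)   # (b, c, f, h, w): 16 frames
context = torch.randn(1, 77, 768)           # repeated to (b * f, 77, 768) inside forward
out = model(latents, encoder_hidden_states=context).sample
# out has the same (b, c, f, h, w) shape as latents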
musepose/models/unet_2d_blocks.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+ from typing import Any, Dict, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from diffusers.models.activations import get_activation
8
+ from diffusers.models.attention_processor import Attention
9
+ from diffusers.models.dual_transformer_2d import DualTransformer2DModel
10
+ from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
11
+ from diffusers.utils import is_torch_version, logging
12
+ from diffusers.utils.torch_utils import apply_freeu
13
+ from torch import nn
14
+
15
+ from .transformer_2d import Transformer2DModel
16
+
17
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
18
+
19
+
20
+ def get_down_block(
21
+ down_block_type: str,
22
+ num_layers: int,
23
+ in_channels: int,
24
+ out_channels: int,
25
+ temb_channels: int,
26
+ add_downsample: bool,
27
+ resnet_eps: float,
28
+ resnet_act_fn: str,
29
+ transformer_layers_per_block: int = 1,
30
+ num_attention_heads: Optional[int] = None,
31
+ resnet_groups: Optional[int] = None,
32
+ cross_attention_dim: Optional[int] = None,
33
+ downsample_padding: Optional[int] = None,
34
+ dual_cross_attention: bool = False,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ resnet_time_scale_shift: str = "default",
39
+ attention_type: str = "default",
40
+ resnet_skip_time_act: bool = False,
41
+ resnet_out_scale_factor: float = 1.0,
42
+ cross_attention_norm: Optional[str] = None,
43
+ attention_head_dim: Optional[int] = None,
44
+ downsample_type: Optional[str] = None,
45
+ dropout: float = 0.0,
46
+ ):
47
+ # If attn head dim is not defined, we default it to the number of heads
48
+ if attention_head_dim is None:
49
+ logger.warn(
50
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
51
+ )
52
+ attention_head_dim = num_attention_heads
53
+
54
+ down_block_type = (
55
+ down_block_type[7:]
56
+ if down_block_type.startswith("UNetRes")
57
+ else down_block_type
58
+ )
59
+ if down_block_type == "DownBlock2D":
60
+ return DownBlock2D(
61
+ num_layers=num_layers,
62
+ in_channels=in_channels,
63
+ out_channels=out_channels,
64
+ temb_channels=temb_channels,
65
+ dropout=dropout,
66
+ add_downsample=add_downsample,
67
+ resnet_eps=resnet_eps,
68
+ resnet_act_fn=resnet_act_fn,
69
+ resnet_groups=resnet_groups,
70
+ downsample_padding=downsample_padding,
71
+ resnet_time_scale_shift=resnet_time_scale_shift,
72
+ )
73
+ elif down_block_type == "CrossAttnDownBlock2D":
74
+ if cross_attention_dim is None:
75
+ raise ValueError(
76
+ "cross_attention_dim must be specified for CrossAttnDownBlock2D"
77
+ )
78
+ return CrossAttnDownBlock2D(
79
+ num_layers=num_layers,
80
+ transformer_layers_per_block=transformer_layers_per_block,
81
+ in_channels=in_channels,
82
+ out_channels=out_channels,
83
+ temb_channels=temb_channels,
84
+ dropout=dropout,
85
+ add_downsample=add_downsample,
86
+ resnet_eps=resnet_eps,
87
+ resnet_act_fn=resnet_act_fn,
88
+ resnet_groups=resnet_groups,
89
+ downsample_padding=downsample_padding,
90
+ cross_attention_dim=cross_attention_dim,
91
+ num_attention_heads=num_attention_heads,
92
+ dual_cross_attention=dual_cross_attention,
93
+ use_linear_projection=use_linear_projection,
94
+ only_cross_attention=only_cross_attention,
95
+ upcast_attention=upcast_attention,
96
+ resnet_time_scale_shift=resnet_time_scale_shift,
97
+ attention_type=attention_type,
98
+ )
99
+ raise ValueError(f"{down_block_type} does not exist.")
100
+
101
+
102
+ def get_up_block(
103
+ up_block_type: str,
104
+ num_layers: int,
105
+ in_channels: int,
106
+ out_channels: int,
107
+ prev_output_channel: int,
108
+ temb_channels: int,
109
+ add_upsample: bool,
110
+ resnet_eps: float,
111
+ resnet_act_fn: str,
112
+ resolution_idx: Optional[int] = None,
113
+ transformer_layers_per_block: int = 1,
114
+ num_attention_heads: Optional[int] = None,
115
+ resnet_groups: Optional[int] = None,
116
+ cross_attention_dim: Optional[int] = None,
117
+ dual_cross_attention: bool = False,
118
+ use_linear_projection: bool = False,
119
+ only_cross_attention: bool = False,
120
+ upcast_attention: bool = False,
121
+ resnet_time_scale_shift: str = "default",
122
+ attention_type: str = "default",
123
+ resnet_skip_time_act: bool = False,
124
+ resnet_out_scale_factor: float = 1.0,
125
+ cross_attention_norm: Optional[str] = None,
126
+ attention_head_dim: Optional[int] = None,
127
+ upsample_type: Optional[str] = None,
128
+ dropout: float = 0.0,
129
+ ) -> nn.Module:
130
+ # If attn head dim is not defined, we default it to the number of heads
131
+ if attention_head_dim is None:
132
+ logger.warn(
133
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
134
+ )
135
+ attention_head_dim = num_attention_heads
136
+
137
+ up_block_type = (
138
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
139
+ )
140
+ if up_block_type == "UpBlock2D":
141
+ return UpBlock2D(
142
+ num_layers=num_layers,
143
+ in_channels=in_channels,
144
+ out_channels=out_channels,
145
+ prev_output_channel=prev_output_channel,
146
+ temb_channels=temb_channels,
147
+ resolution_idx=resolution_idx,
148
+ dropout=dropout,
149
+ add_upsample=add_upsample,
150
+ resnet_eps=resnet_eps,
151
+ resnet_act_fn=resnet_act_fn,
152
+ resnet_groups=resnet_groups,
153
+ resnet_time_scale_shift=resnet_time_scale_shift,
154
+ )
155
+ elif up_block_type == "CrossAttnUpBlock2D":
156
+ if cross_attention_dim is None:
157
+ raise ValueError(
158
+ "cross_attention_dim must be specified for CrossAttnUpBlock2D"
159
+ )
160
+ return CrossAttnUpBlock2D(
161
+ num_layers=num_layers,
162
+ transformer_layers_per_block=transformer_layers_per_block,
163
+ in_channels=in_channels,
164
+ out_channels=out_channels,
165
+ prev_output_channel=prev_output_channel,
166
+ temb_channels=temb_channels,
167
+ resolution_idx=resolution_idx,
168
+ dropout=dropout,
169
+ add_upsample=add_upsample,
170
+ resnet_eps=resnet_eps,
171
+ resnet_act_fn=resnet_act_fn,
172
+ resnet_groups=resnet_groups,
173
+ cross_attention_dim=cross_attention_dim,
174
+ num_attention_heads=num_attention_heads,
175
+ dual_cross_attention=dual_cross_attention,
176
+ use_linear_projection=use_linear_projection,
177
+ only_cross_attention=only_cross_attention,
178
+ upcast_attention=upcast_attention,
179
+ resnet_time_scale_shift=resnet_time_scale_shift,
180
+ attention_type=attention_type,
181
+ )
182
+
183
+ raise ValueError(f"{up_block_type} does not exist.")
184
+
185
+
186
+ class AutoencoderTinyBlock(nn.Module):
187
+ """
188
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
189
+ blocks.
190
+
191
+ Args:
192
+ in_channels (`int`): The number of input channels.
193
+ out_channels (`int`): The number of output channels.
194
+ act_fn (`str`):
195
+ The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
196
+
197
+ Returns:
198
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
199
+ `out_channels`.
200
+ """
201
+
202
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
203
+ super().__init__()
204
+ act_fn = get_activation(act_fn)
205
+ self.conv = nn.Sequential(
206
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
207
+ act_fn,
208
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
209
+ act_fn,
210
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
211
+ )
212
+ self.skip = (
213
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
214
+ if in_channels != out_channels
215
+ else nn.Identity()
216
+ )
217
+ self.fuse = nn.ReLU()
218
+
219
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
220
+ return self.fuse(self.conv(x) + self.skip(x))
221
+
222
+
223
+ class UNetMidBlock2D(nn.Module):
224
+ """
225
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
226
+
227
+ Args:
228
+ in_channels (`int`): The number of input channels.
229
+ temb_channels (`int`): The number of temporal embedding channels.
230
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
231
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
232
+ resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks.
233
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
234
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
235
+ model on tasks with long-range temporal dependencies.
236
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
237
+ resnet_groups (`int`, *optional*, defaults to 32):
238
+ The number of groups to use in the group normalization layers of the resnet blocks.
239
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
240
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
241
+ Whether to use pre-normalization for the resnet blocks.
242
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
243
+ attention_head_dim (`int`, *optional*, defaults to 1):
244
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
245
+ the number of input channels.
246
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
247
+
248
+ Returns:
249
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
250
+ in_channels, height, width)`.
251
+
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ in_channels: int,
257
+ temb_channels: int,
258
+ dropout: float = 0.0,
259
+ num_layers: int = 1,
260
+ resnet_eps: float = 1e-6,
261
+ resnet_time_scale_shift: str = "default", # default, spatial
262
+ resnet_act_fn: str = "swish",
263
+ resnet_groups: int = 32,
264
+ attn_groups: Optional[int] = None,
265
+ resnet_pre_norm: bool = True,
266
+ add_attention: bool = True,
267
+ attention_head_dim: int = 1,
268
+ output_scale_factor: float = 1.0,
269
+ ):
270
+ super().__init__()
271
+ resnet_groups = (
272
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
273
+ )
274
+ self.add_attention = add_attention
275
+
276
+ if attn_groups is None:
277
+ attn_groups = (
278
+ resnet_groups if resnet_time_scale_shift == "default" else None
279
+ )
280
+
281
+ # there is always at least one resnet
282
+ resnets = [
283
+ ResnetBlock2D(
284
+ in_channels=in_channels,
285
+ out_channels=in_channels,
286
+ temb_channels=temb_channels,
287
+ eps=resnet_eps,
288
+ groups=resnet_groups,
289
+ dropout=dropout,
290
+ time_embedding_norm=resnet_time_scale_shift,
291
+ non_linearity=resnet_act_fn,
292
+ output_scale_factor=output_scale_factor,
293
+ pre_norm=resnet_pre_norm,
294
+ )
295
+ ]
296
+ attentions = []
297
+
298
+ if attention_head_dim is None:
299
+ logger.warn(
300
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
301
+ )
302
+ attention_head_dim = in_channels
303
+
304
+ for _ in range(num_layers):
305
+ if self.add_attention:
306
+ attentions.append(
307
+ Attention(
308
+ in_channels,
309
+ heads=in_channels // attention_head_dim,
310
+ dim_head=attention_head_dim,
311
+ rescale_output_factor=output_scale_factor,
312
+ eps=resnet_eps,
313
+ norm_num_groups=attn_groups,
314
+ spatial_norm_dim=temb_channels
315
+ if resnet_time_scale_shift == "spatial"
316
+ else None,
317
+ residual_connection=True,
318
+ bias=True,
319
+ upcast_softmax=True,
320
+ _from_deprecated_attn_block=True,
321
+ )
322
+ )
323
+ else:
324
+ attentions.append(None)
325
+
326
+ resnets.append(
327
+ ResnetBlock2D(
328
+ in_channels=in_channels,
329
+ out_channels=in_channels,
330
+ temb_channels=temb_channels,
331
+ eps=resnet_eps,
332
+ groups=resnet_groups,
333
+ dropout=dropout,
334
+ time_embedding_norm=resnet_time_scale_shift,
335
+ non_linearity=resnet_act_fn,
336
+ output_scale_factor=output_scale_factor,
337
+ pre_norm=resnet_pre_norm,
338
+ )
339
+ )
340
+
341
+ self.attentions = nn.ModuleList(attentions)
342
+ self.resnets = nn.ModuleList(resnets)
343
+
344
+ def forward(
345
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
346
+ ) -> torch.FloatTensor:
347
+ hidden_states = self.resnets[0](hidden_states, temb)
348
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
349
+ if attn is not None:
350
+ hidden_states = attn(hidden_states, temb=temb)
351
+ hidden_states = resnet(hidden_states, temb)
352
+
353
+ return hidden_states
354
+
355
+
356
+ class UNetMidBlock2DCrossAttn(nn.Module):
357
+ def __init__(
358
+ self,
359
+ in_channels: int,
360
+ temb_channels: int,
361
+ dropout: float = 0.0,
362
+ num_layers: int = 1,
363
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
364
+ resnet_eps: float = 1e-6,
365
+ resnet_time_scale_shift: str = "default",
366
+ resnet_act_fn: str = "swish",
367
+ resnet_groups: int = 32,
368
+ resnet_pre_norm: bool = True,
369
+ num_attention_heads: int = 1,
370
+ output_scale_factor: float = 1.0,
371
+ cross_attention_dim: int = 1280,
372
+ dual_cross_attention: bool = False,
373
+ use_linear_projection: bool = False,
374
+ upcast_attention: bool = False,
375
+ attention_type: str = "default",
376
+ ):
377
+ super().__init__()
378
+
379
+ self.has_cross_attention = True
380
+ self.num_attention_heads = num_attention_heads
381
+ resnet_groups = (
382
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
383
+ )
384
+
385
+ # support for variable transformer layers per block
386
+ if isinstance(transformer_layers_per_block, int):
387
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
388
+
389
+ # there is always at least one resnet
390
+ resnets = [
391
+ ResnetBlock2D(
392
+ in_channels=in_channels,
393
+ out_channels=in_channels,
394
+ temb_channels=temb_channels,
395
+ eps=resnet_eps,
396
+ groups=resnet_groups,
397
+ dropout=dropout,
398
+ time_embedding_norm=resnet_time_scale_shift,
399
+ non_linearity=resnet_act_fn,
400
+ output_scale_factor=output_scale_factor,
401
+ pre_norm=resnet_pre_norm,
402
+ )
403
+ ]
404
+ attentions = []
405
+
406
+ for i in range(num_layers):
407
+ if not dual_cross_attention:
408
+ attentions.append(
409
+ Transformer2DModel(
410
+ num_attention_heads,
411
+ in_channels // num_attention_heads,
412
+ in_channels=in_channels,
413
+ num_layers=transformer_layers_per_block[i],
414
+ cross_attention_dim=cross_attention_dim,
415
+ norm_num_groups=resnet_groups,
416
+ use_linear_projection=use_linear_projection,
417
+ upcast_attention=upcast_attention,
418
+ attention_type=attention_type,
419
+ )
420
+ )
421
+ else:
422
+ attentions.append(
423
+ DualTransformer2DModel(
424
+ num_attention_heads,
425
+ in_channels // num_attention_heads,
426
+ in_channels=in_channels,
427
+ num_layers=1,
428
+ cross_attention_dim=cross_attention_dim,
429
+ norm_num_groups=resnet_groups,
430
+ )
431
+ )
432
+ resnets.append(
433
+ ResnetBlock2D(
434
+ in_channels=in_channels,
435
+ out_channels=in_channels,
436
+ temb_channels=temb_channels,
437
+ eps=resnet_eps,
438
+ groups=resnet_groups,
439
+ dropout=dropout,
440
+ time_embedding_norm=resnet_time_scale_shift,
441
+ non_linearity=resnet_act_fn,
442
+ output_scale_factor=output_scale_factor,
443
+ pre_norm=resnet_pre_norm,
444
+ )
445
+ )
446
+
447
+ self.attentions = nn.ModuleList(attentions)
448
+ self.resnets = nn.ModuleList(resnets)
449
+
450
+ self.gradient_checkpointing = False
451
+
452
+ def forward(
453
+ self,
454
+ hidden_states: torch.FloatTensor,
455
+ temb: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ attention_mask: Optional[torch.FloatTensor] = None,
458
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
459
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
460
+ ) -> torch.FloatTensor:
461
+ lora_scale = (
462
+ cross_attention_kwargs.get("scale", 1.0)
463
+ if cross_attention_kwargs is not None
464
+ else 1.0
465
+ )
466
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
467
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
468
+ if self.training and self.gradient_checkpointing:
469
+
470
+ def create_custom_forward(module, return_dict=None):
471
+ def custom_forward(*inputs):
472
+ if return_dict is not None:
473
+ return module(*inputs, return_dict=return_dict)
474
+ else:
475
+ return module(*inputs)
476
+
477
+ return custom_forward
478
+
479
+ ckpt_kwargs: Dict[str, Any] = (
480
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
481
+ )
482
+ hidden_states, ref_feature = attn(
483
+ hidden_states,
484
+ encoder_hidden_states=encoder_hidden_states,
485
+ cross_attention_kwargs=cross_attention_kwargs,
486
+ attention_mask=attention_mask,
487
+ encoder_attention_mask=encoder_attention_mask,
488
+ return_dict=False,
489
+ )
490
+ hidden_states = torch.utils.checkpoint.checkpoint(
491
+ create_custom_forward(resnet),
492
+ hidden_states,
493
+ temb,
494
+ **ckpt_kwargs,
495
+ )
496
+ else:
497
+ hidden_states, ref_feature = attn(
498
+ hidden_states,
499
+ encoder_hidden_states=encoder_hidden_states,
500
+ cross_attention_kwargs=cross_attention_kwargs,
501
+ attention_mask=attention_mask,
502
+ encoder_attention_mask=encoder_attention_mask,
503
+ return_dict=False,
504
+ )
505
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
506
+
507
+ return hidden_states
508
+
509
+
510
+ class CrossAttnDownBlock2D(nn.Module):
511
+ def __init__(
512
+ self,
513
+ in_channels: int,
514
+ out_channels: int,
515
+ temb_channels: int,
516
+ dropout: float = 0.0,
517
+ num_layers: int = 1,
518
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
519
+ resnet_eps: float = 1e-6,
520
+ resnet_time_scale_shift: str = "default",
521
+ resnet_act_fn: str = "swish",
522
+ resnet_groups: int = 32,
523
+ resnet_pre_norm: bool = True,
524
+ num_attention_heads: int = 1,
525
+ cross_attention_dim: int = 1280,
526
+ output_scale_factor: float = 1.0,
527
+ downsample_padding: int = 1,
528
+ add_downsample: bool = True,
529
+ dual_cross_attention: bool = False,
530
+ use_linear_projection: bool = False,
531
+ only_cross_attention: bool = False,
532
+ upcast_attention: bool = False,
533
+ attention_type: str = "default",
534
+ ):
535
+ super().__init__()
536
+ resnets = []
537
+ attentions = []
538
+
539
+ self.has_cross_attention = True
540
+ self.num_attention_heads = num_attention_heads
541
+ if isinstance(transformer_layers_per_block, int):
542
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
543
+
544
+ for i in range(num_layers):
545
+ in_channels = in_channels if i == 0 else out_channels
546
+ resnets.append(
547
+ ResnetBlock2D(
548
+ in_channels=in_channels,
549
+ out_channels=out_channels,
550
+ temb_channels=temb_channels,
551
+ eps=resnet_eps,
552
+ groups=resnet_groups,
553
+ dropout=dropout,
554
+ time_embedding_norm=resnet_time_scale_shift,
555
+ non_linearity=resnet_act_fn,
556
+ output_scale_factor=output_scale_factor,
557
+ pre_norm=resnet_pre_norm,
558
+ )
559
+ )
560
+ if not dual_cross_attention:
561
+ attentions.append(
562
+ Transformer2DModel(
563
+ num_attention_heads,
564
+ out_channels // num_attention_heads,
565
+ in_channels=out_channels,
566
+ num_layers=transformer_layers_per_block[i],
567
+ cross_attention_dim=cross_attention_dim,
568
+ norm_num_groups=resnet_groups,
569
+ use_linear_projection=use_linear_projection,
570
+ only_cross_attention=only_cross_attention,
571
+ upcast_attention=upcast_attention,
572
+ attention_type=attention_type,
573
+ )
574
+ )
575
+ else:
576
+ attentions.append(
577
+ DualTransformer2DModel(
578
+ num_attention_heads,
579
+ out_channels // num_attention_heads,
580
+ in_channels=out_channels,
581
+ num_layers=1,
582
+ cross_attention_dim=cross_attention_dim,
583
+ norm_num_groups=resnet_groups,
584
+ )
585
+ )
586
+ self.attentions = nn.ModuleList(attentions)
587
+ self.resnets = nn.ModuleList(resnets)
588
+
589
+ if add_downsample:
590
+ self.downsamplers = nn.ModuleList(
591
+ [
592
+ Downsample2D(
593
+ out_channels,
594
+ use_conv=True,
595
+ out_channels=out_channels,
596
+ padding=downsample_padding,
597
+ name="op",
598
+ )
599
+ ]
600
+ )
601
+ else:
602
+ self.downsamplers = None
603
+
604
+ self.gradient_checkpointing = False
605
+
606
+ def forward(
607
+ self,
608
+ hidden_states: torch.FloatTensor,
609
+ temb: Optional[torch.FloatTensor] = None,
610
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
611
+ attention_mask: Optional[torch.FloatTensor] = None,
612
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
613
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
614
+ additional_residuals: Optional[torch.FloatTensor] = None,
615
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
616
+ output_states = ()
617
+
618
+ lora_scale = (
619
+ cross_attention_kwargs.get("scale", 1.0)
620
+ if cross_attention_kwargs is not None
621
+ else 1.0
622
+ )
623
+
624
+ blocks = list(zip(self.resnets, self.attentions))
625
+
626
+ for i, (resnet, attn) in enumerate(blocks):
627
+ if self.training and self.gradient_checkpointing:
628
+
629
+ def create_custom_forward(module, return_dict=None):
630
+ def custom_forward(*inputs):
631
+ if return_dict is not None:
632
+ return module(*inputs, return_dict=return_dict)
633
+ else:
634
+ return module(*inputs)
635
+
636
+ return custom_forward
637
+
638
+ ckpt_kwargs: Dict[str, Any] = (
639
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
640
+ )
641
+ hidden_states = torch.utils.checkpoint.checkpoint(
642
+ create_custom_forward(resnet),
643
+ hidden_states,
644
+ temb,
645
+ **ckpt_kwargs,
646
+ )
647
+ hidden_states, ref_feature = attn(
648
+ hidden_states,
649
+ encoder_hidden_states=encoder_hidden_states,
650
+ cross_attention_kwargs=cross_attention_kwargs,
651
+ attention_mask=attention_mask,
652
+ encoder_attention_mask=encoder_attention_mask,
653
+ return_dict=False,
654
+ )
655
+ else:
656
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
657
+ hidden_states, ref_feature = attn(
658
+ hidden_states,
659
+ encoder_hidden_states=encoder_hidden_states,
660
+ cross_attention_kwargs=cross_attention_kwargs,
661
+ attention_mask=attention_mask,
662
+ encoder_attention_mask=encoder_attention_mask,
663
+ return_dict=False,
664
+ )
665
+
666
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
667
+ if i == len(blocks) - 1 and additional_residuals is not None:
668
+ hidden_states = hidden_states + additional_residuals
669
+
670
+ output_states = output_states + (hidden_states,)
671
+
672
+ if self.downsamplers is not None:
673
+ for downsampler in self.downsamplers:
674
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
675
+
676
+ output_states = output_states + (hidden_states,)
677
+
678
+ return hidden_states, output_states
679
+
680
+
681
+ class DownBlock2D(nn.Module):
682
+ def __init__(
683
+ self,
684
+ in_channels: int,
685
+ out_channels: int,
686
+ temb_channels: int,
687
+ dropout: float = 0.0,
688
+ num_layers: int = 1,
689
+ resnet_eps: float = 1e-6,
690
+ resnet_time_scale_shift: str = "default",
691
+ resnet_act_fn: str = "swish",
692
+ resnet_groups: int = 32,
693
+ resnet_pre_norm: bool = True,
694
+ output_scale_factor: float = 1.0,
695
+ add_downsample: bool = True,
696
+ downsample_padding: int = 1,
697
+ ):
698
+ super().__init__()
699
+ resnets = []
700
+
701
+ for i in range(num_layers):
702
+ in_channels = in_channels if i == 0 else out_channels
703
+ resnets.append(
704
+ ResnetBlock2D(
705
+ in_channels=in_channels,
706
+ out_channels=out_channels,
707
+ temb_channels=temb_channels,
708
+ eps=resnet_eps,
709
+ groups=resnet_groups,
710
+ dropout=dropout,
711
+ time_embedding_norm=resnet_time_scale_shift,
712
+ non_linearity=resnet_act_fn,
713
+ output_scale_factor=output_scale_factor,
714
+ pre_norm=resnet_pre_norm,
715
+ )
716
+ )
717
+
718
+ self.resnets = nn.ModuleList(resnets)
719
+
720
+ if add_downsample:
721
+ self.downsamplers = nn.ModuleList(
722
+ [
723
+ Downsample2D(
724
+ out_channels,
725
+ use_conv=True,
726
+ out_channels=out_channels,
727
+ padding=downsample_padding,
728
+ name="op",
729
+ )
730
+ ]
731
+ )
732
+ else:
733
+ self.downsamplers = None
734
+
735
+ self.gradient_checkpointing = False
736
+
737
+ def forward(
738
+ self,
739
+ hidden_states: torch.FloatTensor,
740
+ temb: Optional[torch.FloatTensor] = None,
741
+ scale: float = 1.0,
742
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
743
+ output_states = ()
744
+
745
+ for resnet in self.resnets:
746
+ if self.training and self.gradient_checkpointing:
747
+
748
+ def create_custom_forward(module):
749
+ def custom_forward(*inputs):
750
+ return module(*inputs)
751
+
752
+ return custom_forward
753
+
754
+ if is_torch_version(">=", "1.11.0"):
755
+ hidden_states = torch.utils.checkpoint.checkpoint(
756
+ create_custom_forward(resnet),
757
+ hidden_states,
758
+ temb,
759
+ use_reentrant=False,
760
+ )
761
+ else:
762
+ hidden_states = torch.utils.checkpoint.checkpoint(
763
+ create_custom_forward(resnet), hidden_states, temb
764
+ )
765
+ else:
766
+ hidden_states = resnet(hidden_states, temb, scale=scale)
767
+
768
+ output_states = output_states + (hidden_states,)
769
+
770
+ if self.downsamplers is not None:
771
+ for downsampler in self.downsamplers:
772
+ hidden_states = downsampler(hidden_states, scale=scale)
773
+
774
+ output_states = output_states + (hidden_states,)
775
+
776
+ return hidden_states, output_states
777
+
778
+
779
+ class CrossAttnUpBlock2D(nn.Module):
780
+ def __init__(
781
+ self,
782
+ in_channels: int,
783
+ out_channels: int,
784
+ prev_output_channel: int,
785
+ temb_channels: int,
786
+ resolution_idx: Optional[int] = None,
787
+ dropout: float = 0.0,
788
+ num_layers: int = 1,
789
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
790
+ resnet_eps: float = 1e-6,
791
+ resnet_time_scale_shift: str = "default",
792
+ resnet_act_fn: str = "swish",
793
+ resnet_groups: int = 32,
794
+ resnet_pre_norm: bool = True,
795
+ num_attention_heads: int = 1,
796
+ cross_attention_dim: int = 1280,
797
+ output_scale_factor: float = 1.0,
798
+ add_upsample: bool = True,
799
+ dual_cross_attention: bool = False,
800
+ use_linear_projection: bool = False,
801
+ only_cross_attention: bool = False,
802
+ upcast_attention: bool = False,
803
+ attention_type: str = "default",
804
+ ):
805
+ super().__init__()
806
+ resnets = []
807
+ attentions = []
808
+
809
+ self.has_cross_attention = True
810
+ self.num_attention_heads = num_attention_heads
811
+
812
+ if isinstance(transformer_layers_per_block, int):
813
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
814
+
815
+ for i in range(num_layers):
816
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
817
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
818
+
819
+ resnets.append(
820
+ ResnetBlock2D(
821
+ in_channels=resnet_in_channels + res_skip_channels,
822
+ out_channels=out_channels,
823
+ temb_channels=temb_channels,
824
+ eps=resnet_eps,
825
+ groups=resnet_groups,
826
+ dropout=dropout,
827
+ time_embedding_norm=resnet_time_scale_shift,
828
+ non_linearity=resnet_act_fn,
829
+ output_scale_factor=output_scale_factor,
830
+ pre_norm=resnet_pre_norm,
831
+ )
832
+ )
833
+ if not dual_cross_attention:
834
+ attentions.append(
835
+ Transformer2DModel(
836
+ num_attention_heads,
837
+ out_channels // num_attention_heads,
838
+ in_channels=out_channels,
839
+ num_layers=transformer_layers_per_block[i],
840
+ cross_attention_dim=cross_attention_dim,
841
+ norm_num_groups=resnet_groups,
842
+ use_linear_projection=use_linear_projection,
843
+ only_cross_attention=only_cross_attention,
844
+ upcast_attention=upcast_attention,
845
+ attention_type=attention_type,
846
+ )
847
+ )
848
+ else:
849
+ attentions.append(
850
+ DualTransformer2DModel(
851
+ num_attention_heads,
852
+ out_channels // num_attention_heads,
853
+ in_channels=out_channels,
854
+ num_layers=1,
855
+ cross_attention_dim=cross_attention_dim,
856
+ norm_num_groups=resnet_groups,
857
+ )
858
+ )
859
+ self.attentions = nn.ModuleList(attentions)
860
+ self.resnets = nn.ModuleList(resnets)
861
+
862
+ if add_upsample:
863
+ self.upsamplers = nn.ModuleList(
864
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
865
+ )
866
+ else:
867
+ self.upsamplers = None
868
+
869
+ self.gradient_checkpointing = False
870
+ self.resolution_idx = resolution_idx
871
+
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.FloatTensor,
875
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
876
+ temb: Optional[torch.FloatTensor] = None,
877
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
878
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
879
+ upsample_size: Optional[int] = None,
880
+ attention_mask: Optional[torch.FloatTensor] = None,
881
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
882
+ ) -> torch.FloatTensor:
883
+ lora_scale = (
884
+ cross_attention_kwargs.get("scale", 1.0)
885
+ if cross_attention_kwargs is not None
886
+ else 1.0
887
+ )
888
+ is_freeu_enabled = (
889
+ getattr(self, "s1", None)
890
+ and getattr(self, "s2", None)
891
+ and getattr(self, "b1", None)
892
+ and getattr(self, "b2", None)
893
+ )
894
+
895
+ for resnet, attn in zip(self.resnets, self.attentions):
896
+ # pop res hidden states
897
+ res_hidden_states = res_hidden_states_tuple[-1]
898
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
899
+
900
+ # FreeU: Only operate on the first two stages
901
+ if is_freeu_enabled:
902
+ hidden_states, res_hidden_states = apply_freeu(
903
+ self.resolution_idx,
904
+ hidden_states,
905
+ res_hidden_states,
906
+ s1=self.s1,
907
+ s2=self.s2,
908
+ b1=self.b1,
909
+ b2=self.b2,
910
+ )
911
+
912
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
913
+
914
+ if self.training and self.gradient_checkpointing:
915
+
916
+ def create_custom_forward(module, return_dict=None):
917
+ def custom_forward(*inputs):
918
+ if return_dict is not None:
919
+ return module(*inputs, return_dict=return_dict)
920
+ else:
921
+ return module(*inputs)
922
+
923
+ return custom_forward
924
+
925
+ ckpt_kwargs: Dict[str, Any] = (
926
+ {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
927
+ )
928
+ hidden_states = torch.utils.checkpoint.checkpoint(
929
+ create_custom_forward(resnet),
930
+ hidden_states,
931
+ temb,
932
+ **ckpt_kwargs,
933
+ )
934
+ hidden_states, ref_feature = attn(
935
+ hidden_states,
936
+ encoder_hidden_states=encoder_hidden_states,
937
+ cross_attention_kwargs=cross_attention_kwargs,
938
+ attention_mask=attention_mask,
939
+ encoder_attention_mask=encoder_attention_mask,
940
+ return_dict=False,
941
+ )
942
+ else:
943
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
944
+ hidden_states, ref_feature = attn(
945
+ hidden_states,
946
+ encoder_hidden_states=encoder_hidden_states,
947
+ cross_attention_kwargs=cross_attention_kwargs,
948
+ attention_mask=attention_mask,
949
+ encoder_attention_mask=encoder_attention_mask,
950
+ return_dict=False,
951
+ )
952
+
953
+ if self.upsamplers is not None:
954
+ for upsampler in self.upsamplers:
955
+ hidden_states = upsampler(
956
+ hidden_states, upsample_size, scale=lora_scale
957
+ )
958
+
959
+ return hidden_states
960
+
961
+
962
+ class UpBlock2D(nn.Module):
963
+ def __init__(
964
+ self,
965
+ in_channels: int,
966
+ prev_output_channel: int,
967
+ out_channels: int,
968
+ temb_channels: int,
969
+ resolution_idx: Optional[int] = None,
970
+ dropout: float = 0.0,
971
+ num_layers: int = 1,
972
+ resnet_eps: float = 1e-6,
973
+ resnet_time_scale_shift: str = "default",
974
+ resnet_act_fn: str = "swish",
975
+ resnet_groups: int = 32,
976
+ resnet_pre_norm: bool = True,
977
+ output_scale_factor: float = 1.0,
978
+ add_upsample: bool = True,
979
+ ):
980
+ super().__init__()
981
+ resnets = []
982
+
983
+ for i in range(num_layers):
984
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
985
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
986
+
987
+ resnets.append(
988
+ ResnetBlock2D(
989
+ in_channels=resnet_in_channels + res_skip_channels,
990
+ out_channels=out_channels,
991
+ temb_channels=temb_channels,
992
+ eps=resnet_eps,
993
+ groups=resnet_groups,
994
+ dropout=dropout,
995
+ time_embedding_norm=resnet_time_scale_shift,
996
+ non_linearity=resnet_act_fn,
997
+ output_scale_factor=output_scale_factor,
998
+ pre_norm=resnet_pre_norm,
999
+ )
1000
+ )
1001
+
1002
+ self.resnets = nn.ModuleList(resnets)
1003
+
1004
+ if add_upsample:
1005
+ self.upsamplers = nn.ModuleList(
1006
+ [Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]
1007
+ )
1008
+ else:
1009
+ self.upsamplers = None
1010
+
1011
+ self.gradient_checkpointing = False
1012
+ self.resolution_idx = resolution_idx
1013
+
1014
+ def forward(
1015
+ self,
1016
+ hidden_states: torch.FloatTensor,
1017
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1018
+ temb: Optional[torch.FloatTensor] = None,
1019
+ upsample_size: Optional[int] = None,
1020
+ scale: float = 1.0,
1021
+ ) -> torch.FloatTensor:
1022
+ is_freeu_enabled = (
1023
+ getattr(self, "s1", None)
1024
+ and getattr(self, "s2", None)
1025
+ and getattr(self, "b1", None)
1026
+ and getattr(self, "b2", None)
1027
+ )
1028
+
1029
+ for resnet in self.resnets:
1030
+ # pop res hidden states
1031
+ res_hidden_states = res_hidden_states_tuple[-1]
1032
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
1033
+
1034
+ # FreeU: Only operate on the first two stages
1035
+ if is_freeu_enabled:
1036
+ hidden_states, res_hidden_states = apply_freeu(
1037
+ self.resolution_idx,
1038
+ hidden_states,
1039
+ res_hidden_states,
1040
+ s1=self.s1,
1041
+ s2=self.s2,
1042
+ b1=self.b1,
1043
+ b2=self.b2,
1044
+ )
1045
+
1046
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1047
+
1048
+ if self.training and self.gradient_checkpointing:
1049
+
1050
+ def create_custom_forward(module):
1051
+ def custom_forward(*inputs):
1052
+ return module(*inputs)
1053
+
1054
+ return custom_forward
1055
+
1056
+ if is_torch_version(">=", "1.11.0"):
1057
+ hidden_states = torch.utils.checkpoint.checkpoint(
1058
+ create_custom_forward(resnet),
1059
+ hidden_states,
1060
+ temb,
1061
+ use_reentrant=False,
1062
+ )
1063
+ else:
1064
+ hidden_states = torch.utils.checkpoint.checkpoint(
1065
+ create_custom_forward(resnet), hidden_states, temb
1066
+ )
1067
+ else:
1068
+ hidden_states = resnet(hidden_states, temb, scale=scale)
1069
+
1070
+ if self.upsamplers is not None:
1071
+ for upsampler in self.upsamplers:
1072
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1073
+
1074
+ return hidden_states
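As a usage note (not part of the commit), the two factory functions at the top of this file only recognize the four block types used here (`DownBlock2D`, `CrossAttnDownBlock2D`, `UpBlock2D`, `CrossAttnUpBlock2D`). A minimal, hedged sketch of building and running one cross-attention down block directly; all sizes are illustrative and not taken from the shipped configs:

import torch
from musepose.models.unet_2d_blocks import get_down_block

block = get_down_block(
    "CrossAttnDownBlock2D",
    num_layers=2,
    in_channels=320,
    out_channels=640,
    temb_channels=1280,
    add_downsample=True,
    resnet_eps=1e-5,
    resnet_act_fn="silu",
    resnet_groups=32,
    downsample_padding=1,
    cross_attention_dim=768,
    num_attention_heads=8,
    attention_head_dim=80,
)
x = torch.randn(1, 320, 64, 64)     # spatial features entering the block
temb = torch.randn(1, 1280)         # timestep embedding
context = torch.randn(1, 77, 768)   # cross-attention context
x, skip_states = block(x, temb=temb, encoder_hidden_states=context)
# skip_states is the tuple of per-layer outputs later consumed by the matching up block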
musepose/models/unet_2d_condition.py ADDED
@@ -0,0 +1,1307 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ from dataclasses import dataclass
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.utils.checkpoint
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.loaders import UNet2DConditionLoadersMixin
10
+ from diffusers.models.activations import get_activation
11
+ from diffusers.models.attention_processor import (
12
+ ADDED_KV_ATTENTION_PROCESSORS,
13
+ CROSS_ATTENTION_PROCESSORS,
14
+ AttentionProcessor,
15
+ AttnAddedKVProcessor,
16
+ AttnProcessor,
17
+ )
18
+ from diffusers.models.embeddings import (
19
+ GaussianFourierProjection,
20
+ ImageHintTimeEmbedding,
21
+ ImageProjection,
22
+ ImageTimeEmbedding,
23
+ TextImageProjection,
24
+ TextImageTimeEmbedding,
25
+ TextTimeEmbedding,
26
+ TimestepEmbedding,
27
+ Timesteps,
28
+ )
29
+ from diffusers.models.modeling_utils import ModelMixin
30
+ from diffusers.utils import (
31
+ USE_PEFT_BACKEND,
32
+ BaseOutput,
33
+ deprecate,
34
+ logging,
35
+ scale_lora_layers,
36
+ unscale_lora_layers,
37
+ )
38
+
39
+ from .unet_2d_blocks import (
40
+ UNetMidBlock2D,
41
+ UNetMidBlock2DCrossAttn,
42
+ get_down_block,
43
+ get_up_block,
44
+ )
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+
49
+ @dataclass
50
+ class UNet2DConditionOutput(BaseOutput):
51
+ """
52
+ The output of [`UNet2DConditionModel`].
53
+
54
+ Args:
55
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
56
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
57
+ """
58
+
59
+ sample: torch.FloatTensor = None
60
+ ref_features: Tuple[torch.FloatTensor] = None
61
+
62
+
63
+ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
64
+ r"""
65
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
66
+ shaped output.
67
+
68
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
69
+ for all models (such as downloading or saving).
70
+
71
+ Parameters:
72
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
73
+ Height and width of input/output sample.
74
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
75
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
76
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
77
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
78
+ Whether to flip the sin to cos in the time embedding.
79
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
80
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
81
+ The tuple of downsample blocks to use.
82
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
83
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
84
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
85
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
86
+ The tuple of upsample blocks to use.
87
+ only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`):
88
+ Whether to include self-attention in the basic transformer blocks, see
89
+ [`~models.attention.BasicTransformerBlock`].
90
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
91
+ The tuple of output channels for each block.
92
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
93
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
94
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
95
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
96
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
97
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
98
+ If `None`, normalization and activation layers are skipped in post-processing.
99
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
100
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
101
+ The dimension of the cross attention features.
102
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
103
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
104
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
105
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
106
+ reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None):
107
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
108
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
109
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
110
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
111
+ encoder_hid_dim (`int`, *optional*, defaults to None):
112
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
113
+ dimension to `cross_attention_dim`.
114
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
115
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
116
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
117
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
118
+ num_attention_heads (`int`, *optional*):
119
+ The number of attention heads. If not defined, defaults to `attention_head_dim`
120
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
121
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
122
+ class_embed_type (`str`, *optional*, defaults to `None`):
123
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
124
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
125
+ addition_embed_type (`str`, *optional*, defaults to `None`):
126
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
127
+ "text". "text" will use the `TextTimeEmbedding` layer.
128
+ addition_time_embed_dim: (`int`, *optional*, defaults to `None`):
129
+ Dimension for the timestep embeddings.
130
+ num_class_embeds (`int`, *optional*, defaults to `None`):
131
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
132
+ class conditioning with `class_embed_type` equal to `None`.
133
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
134
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
135
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
136
+ An optional override for the dimension of the projected time embedding.
137
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
138
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
139
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
140
+ timestep_post_act (`str`, *optional*, defaults to `None`):
141
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
142
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
143
+ The dimension of `cond_proj` layer in the timestep embedding.
144
+ conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
145
+ conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
146
+ projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when
147
+ `class_embed_type="projection"`. Required when `class_embed_type="projection"`.
148
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
149
+ embeddings with the class embeddings.
150
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
151
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
152
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
153
+ `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False`
154
+ otherwise.
155
+ """
156
+
157
+ _supports_gradient_checkpointing = True
158
+
159
+ @register_to_config
160
+ def __init__(
161
+ self,
162
+ sample_size: Optional[int] = None,
163
+ in_channels: int = 4,
164
+ out_channels: int = 4,
165
+ center_input_sample: bool = False,
166
+ flip_sin_to_cos: bool = True,
167
+ freq_shift: int = 0,
168
+ down_block_types: Tuple[str] = (
169
+ "CrossAttnDownBlock2D",
170
+ "CrossAttnDownBlock2D",
171
+ "CrossAttnDownBlock2D",
172
+ "DownBlock2D",
173
+ ),
174
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
175
+ up_block_types: Tuple[str] = (
176
+ "UpBlock2D",
177
+ "CrossAttnUpBlock2D",
178
+ "CrossAttnUpBlock2D",
179
+ "CrossAttnUpBlock2D",
180
+ ),
181
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
182
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
183
+ layers_per_block: Union[int, Tuple[int]] = 2,
184
+ downsample_padding: int = 1,
185
+ mid_block_scale_factor: float = 1,
186
+ dropout: float = 0.0,
187
+ act_fn: str = "silu",
188
+ norm_num_groups: Optional[int] = 32,
189
+ norm_eps: float = 1e-5,
190
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
191
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
192
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
193
+ encoder_hid_dim: Optional[int] = None,
194
+ encoder_hid_dim_type: Optional[str] = None,
195
+ attention_head_dim: Union[int, Tuple[int]] = 8,
196
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
197
+ dual_cross_attention: bool = False,
198
+ use_linear_projection: bool = False,
199
+ class_embed_type: Optional[str] = None,
200
+ addition_embed_type: Optional[str] = None,
201
+ addition_time_embed_dim: Optional[int] = None,
202
+ num_class_embeds: Optional[int] = None,
203
+ upcast_attention: bool = False,
204
+ resnet_time_scale_shift: str = "default",
205
+ resnet_skip_time_act: bool = False,
206
+ resnet_out_scale_factor: int = 1.0,
207
+ time_embedding_type: str = "positional",
208
+ time_embedding_dim: Optional[int] = None,
209
+ time_embedding_act_fn: Optional[str] = None,
210
+ timestep_post_act: Optional[str] = None,
211
+ time_cond_proj_dim: Optional[int] = None,
212
+ conv_in_kernel: int = 3,
213
+ conv_out_kernel: int = 3,
214
+ projection_class_embeddings_input_dim: Optional[int] = None,
215
+ attention_type: str = "default",
216
+ class_embeddings_concat: bool = False,
217
+ mid_block_only_cross_attention: Optional[bool] = None,
218
+ cross_attention_norm: Optional[str] = None,
219
+ addition_embed_type_num_heads=64,
220
+ ):
221
+ super().__init__()
222
+
223
+ self.sample_size = sample_size
224
+
225
+ if num_attention_heads is not None:
226
+ raise ValueError(
227
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
228
+ )
229
+
230
+ # If `num_attention_heads` is not defined (which is the case for most models)
231
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
232
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
233
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
234
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
235
+ # which is why we correct for the naming here.
236
+ num_attention_heads = num_attention_heads or attention_head_dim
237
+
238
+ # Check inputs
239
+ if len(down_block_types) != len(up_block_types):
240
+ raise ValueError(
241
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
242
+ )
243
+
244
+ if len(block_out_channels) != len(down_block_types):
245
+ raise ValueError(
246
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
247
+ )
248
+
249
+ if not isinstance(only_cross_attention, bool) and len(
250
+ only_cross_attention
251
+ ) != len(down_block_types):
252
+ raise ValueError(
253
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
254
+ )
255
+
256
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(
257
+ down_block_types
258
+ ):
259
+ raise ValueError(
260
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
261
+ )
262
+
263
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(
264
+ down_block_types
265
+ ):
266
+ raise ValueError(
267
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
268
+ )
269
+
270
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(
271
+ down_block_types
272
+ ):
273
+ raise ValueError(
274
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
275
+ )
276
+
277
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(
278
+ down_block_types
279
+ ):
280
+ raise ValueError(
281
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
282
+ )
283
+ if (
284
+ isinstance(transformer_layers_per_block, list)
285
+ and reverse_transformer_layers_per_block is None
286
+ ):
287
+ for layer_number_per_block in transformer_layers_per_block:
288
+ if isinstance(layer_number_per_block, list):
289
+ raise ValueError(
290
+                        "Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet."
291
+ )
292
+
293
+ # input
294
+ conv_in_padding = (conv_in_kernel - 1) // 2
295
+ self.conv_in = nn.Conv2d(
296
+ in_channels,
297
+ block_out_channels[0],
298
+ kernel_size=conv_in_kernel,
299
+ padding=conv_in_padding,
300
+ )
301
+
302
+ # time
303
+ if time_embedding_type == "fourier":
304
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
305
+ if time_embed_dim % 2 != 0:
306
+ raise ValueError(
307
+ f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
308
+ )
309
+ self.time_proj = GaussianFourierProjection(
310
+ time_embed_dim // 2,
311
+ set_W_to_weight=False,
312
+ log=False,
313
+ flip_sin_to_cos=flip_sin_to_cos,
314
+ )
315
+ timestep_input_dim = time_embed_dim
316
+ elif time_embedding_type == "positional":
317
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
318
+
319
+ self.time_proj = Timesteps(
320
+ block_out_channels[0], flip_sin_to_cos, freq_shift
321
+ )
322
+ timestep_input_dim = block_out_channels[0]
323
+ else:
324
+ raise ValueError(
325
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
326
+ )
327
+
328
+ self.time_embedding = TimestepEmbedding(
329
+ timestep_input_dim,
330
+ time_embed_dim,
331
+ act_fn=act_fn,
332
+ post_act_fn=timestep_post_act,
333
+ cond_proj_dim=time_cond_proj_dim,
334
+ )
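+        # For reference with the default config: block_out_channels[0] = 320 and
+        # `time_embedding_type="positional"`, so `timestep_input_dim` is 320 and
+        # `time_embed_dim` is 320 * 4 = 1280, making `emb` in `forward` a (batch, 1280) tensor.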
335
+
336
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
337
+ encoder_hid_dim_type = "text_proj"
338
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
339
+ logger.info(
340
+ "encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined."
341
+ )
342
+
343
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
344
+ raise ValueError(
345
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
346
+ )
347
+
348
+ if encoder_hid_dim_type == "text_proj":
349
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
350
+ elif encoder_hid_dim_type == "text_image_proj":
351
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
352
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
353
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)
354
+ self.encoder_hid_proj = TextImageProjection(
355
+ text_embed_dim=encoder_hid_dim,
356
+ image_embed_dim=cross_attention_dim,
357
+ cross_attention_dim=cross_attention_dim,
358
+ )
359
+ elif encoder_hid_dim_type == "image_proj":
360
+ # Kandinsky 2.2
361
+ self.encoder_hid_proj = ImageProjection(
362
+ image_embed_dim=encoder_hid_dim,
363
+ cross_attention_dim=cross_attention_dim,
364
+ )
365
+ elif encoder_hid_dim_type is not None:
366
+ raise ValueError(
367
+                f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
368
+ )
369
+ else:
370
+ self.encoder_hid_proj = None
371
+
372
+ # class embedding
373
+ if class_embed_type is None and num_class_embeds is not None:
374
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
375
+ elif class_embed_type == "timestep":
376
+ self.class_embedding = TimestepEmbedding(
377
+ timestep_input_dim, time_embed_dim, act_fn=act_fn
378
+ )
379
+ elif class_embed_type == "identity":
380
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
381
+ elif class_embed_type == "projection":
382
+ if projection_class_embeddings_input_dim is None:
383
+ raise ValueError(
384
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
385
+ )
386
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
387
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
388
+ # 2. it projects from an arbitrary input dimension.
389
+ #
390
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
391
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
392
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
393
+ self.class_embedding = TimestepEmbedding(
394
+ projection_class_embeddings_input_dim, time_embed_dim
395
+ )
396
+ elif class_embed_type == "simple_projection":
397
+ if projection_class_embeddings_input_dim is None:
398
+ raise ValueError(
399
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
400
+ )
401
+ self.class_embedding = nn.Linear(
402
+ projection_class_embeddings_input_dim, time_embed_dim
403
+ )
404
+ else:
405
+ self.class_embedding = None
406
+
407
+ if addition_embed_type == "text":
408
+ if encoder_hid_dim is not None:
409
+ text_time_embedding_from_dim = encoder_hid_dim
410
+ else:
411
+ text_time_embedding_from_dim = cross_attention_dim
412
+
413
+ self.add_embedding = TextTimeEmbedding(
414
+ text_time_embedding_from_dim,
415
+ time_embed_dim,
416
+ num_heads=addition_embed_type_num_heads,
417
+ )
418
+ elif addition_embed_type == "text_image":
419
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
420
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
421
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)
422
+ self.add_embedding = TextImageTimeEmbedding(
423
+ text_embed_dim=cross_attention_dim,
424
+ image_embed_dim=cross_attention_dim,
425
+ time_embed_dim=time_embed_dim,
426
+ )
427
+ elif addition_embed_type == "text_time":
428
+ self.add_time_proj = Timesteps(
429
+ addition_time_embed_dim, flip_sin_to_cos, freq_shift
430
+ )
431
+ self.add_embedding = TimestepEmbedding(
432
+ projection_class_embeddings_input_dim, time_embed_dim
433
+ )
434
+ elif addition_embed_type == "image":
435
+ # Kandinsky 2.2
436
+ self.add_embedding = ImageTimeEmbedding(
437
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
438
+ )
439
+ elif addition_embed_type == "image_hint":
440
+ # Kandinsky 2.2 ControlNet
441
+ self.add_embedding = ImageHintTimeEmbedding(
442
+ image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim
443
+ )
444
+ elif addition_embed_type is not None:
445
+ raise ValueError(
446
+                f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'."
447
+ )
448
+
449
+ if time_embedding_act_fn is None:
450
+ self.time_embed_act = None
451
+ else:
452
+ self.time_embed_act = get_activation(time_embedding_act_fn)
453
+
454
+ self.down_blocks = nn.ModuleList([])
455
+ self.up_blocks = nn.ModuleList([])
456
+
457
+ if isinstance(only_cross_attention, bool):
458
+ if mid_block_only_cross_attention is None:
459
+ mid_block_only_cross_attention = only_cross_attention
460
+
461
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
462
+
463
+ if mid_block_only_cross_attention is None:
464
+ mid_block_only_cross_attention = False
465
+
466
+ if isinstance(num_attention_heads, int):
467
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
468
+
469
+ if isinstance(attention_head_dim, int):
470
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
471
+
472
+ if isinstance(cross_attention_dim, int):
473
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
474
+
475
+ if isinstance(layers_per_block, int):
476
+ layers_per_block = [layers_per_block] * len(down_block_types)
477
+
478
+ if isinstance(transformer_layers_per_block, int):
479
+ transformer_layers_per_block = [transformer_layers_per_block] * len(
480
+ down_block_types
481
+ )
482
+
483
+ if class_embeddings_concat:
484
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
485
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
486
+ # regular time embeddings
487
+ blocks_time_embed_dim = time_embed_dim * 2
488
+ else:
489
+ blocks_time_embed_dim = time_embed_dim
490
+
491
+ # down
492
+ output_channel = block_out_channels[0]
493
+ for i, down_block_type in enumerate(down_block_types):
494
+ input_channel = output_channel
495
+ output_channel = block_out_channels[i]
496
+ is_final_block = i == len(block_out_channels) - 1
497
+
498
+ down_block = get_down_block(
499
+ down_block_type,
500
+ num_layers=layers_per_block[i],
501
+ transformer_layers_per_block=transformer_layers_per_block[i],
502
+ in_channels=input_channel,
503
+ out_channels=output_channel,
504
+ temb_channels=blocks_time_embed_dim,
505
+ add_downsample=not is_final_block,
506
+ resnet_eps=norm_eps,
507
+ resnet_act_fn=act_fn,
508
+ resnet_groups=norm_num_groups,
509
+ cross_attention_dim=cross_attention_dim[i],
510
+ num_attention_heads=num_attention_heads[i],
511
+ downsample_padding=downsample_padding,
512
+ dual_cross_attention=dual_cross_attention,
513
+ use_linear_projection=use_linear_projection,
514
+ only_cross_attention=only_cross_attention[i],
515
+ upcast_attention=upcast_attention,
516
+ resnet_time_scale_shift=resnet_time_scale_shift,
517
+ attention_type=attention_type,
518
+ resnet_skip_time_act=resnet_skip_time_act,
519
+ resnet_out_scale_factor=resnet_out_scale_factor,
520
+ cross_attention_norm=cross_attention_norm,
521
+ attention_head_dim=attention_head_dim[i]
522
+ if attention_head_dim[i] is not None
523
+ else output_channel,
524
+ dropout=dropout,
525
+ )
526
+ self.down_blocks.append(down_block)
527
+
528
+ # mid
529
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
530
+ self.mid_block = UNetMidBlock2DCrossAttn(
531
+ transformer_layers_per_block=transformer_layers_per_block[-1],
532
+ in_channels=block_out_channels[-1],
533
+ temb_channels=blocks_time_embed_dim,
534
+ dropout=dropout,
535
+ resnet_eps=norm_eps,
536
+ resnet_act_fn=act_fn,
537
+ output_scale_factor=mid_block_scale_factor,
538
+ resnet_time_scale_shift=resnet_time_scale_shift,
539
+ cross_attention_dim=cross_attention_dim[-1],
540
+ num_attention_heads=num_attention_heads[-1],
541
+ resnet_groups=norm_num_groups,
542
+ dual_cross_attention=dual_cross_attention,
543
+ use_linear_projection=use_linear_projection,
544
+ upcast_attention=upcast_attention,
545
+ attention_type=attention_type,
546
+ )
547
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
548
+            raise NotImplementedError(f"Unsupported mid_block_type: {mid_block_type}")
549
+ elif mid_block_type == "UNetMidBlock2D":
550
+ self.mid_block = UNetMidBlock2D(
551
+ in_channels=block_out_channels[-1],
552
+ temb_channels=blocks_time_embed_dim,
553
+ dropout=dropout,
554
+ num_layers=0,
555
+ resnet_eps=norm_eps,
556
+ resnet_act_fn=act_fn,
557
+ output_scale_factor=mid_block_scale_factor,
558
+ resnet_groups=norm_num_groups,
559
+ resnet_time_scale_shift=resnet_time_scale_shift,
560
+ add_attention=False,
561
+ )
562
+ elif mid_block_type is None:
563
+ self.mid_block = None
564
+ else:
565
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
566
+
567
+ # count how many layers upsample the images
568
+ self.num_upsamplers = 0
569
+
570
+ # up
571
+ reversed_block_out_channels = list(reversed(block_out_channels))
572
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
573
+ reversed_layers_per_block = list(reversed(layers_per_block))
574
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
575
+ reversed_transformer_layers_per_block = (
576
+ list(reversed(transformer_layers_per_block))
577
+ if reverse_transformer_layers_per_block is None
578
+ else reverse_transformer_layers_per_block
579
+ )
580
+ only_cross_attention = list(reversed(only_cross_attention))
581
+
582
+ output_channel = reversed_block_out_channels[0]
583
+ for i, up_block_type in enumerate(up_block_types):
584
+ is_final_block = i == len(block_out_channels) - 1
585
+
586
+ prev_output_channel = output_channel
587
+ output_channel = reversed_block_out_channels[i]
588
+ input_channel = reversed_block_out_channels[
589
+ min(i + 1, len(block_out_channels) - 1)
590
+ ]
591
+
592
+ # add upsample block for all BUT final layer
593
+ if not is_final_block:
594
+ add_upsample = True
595
+ self.num_upsamplers += 1
596
+ else:
597
+ add_upsample = False
598
+
599
+ up_block = get_up_block(
600
+ up_block_type,
601
+ num_layers=reversed_layers_per_block[i] + 1,
602
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
603
+ in_channels=input_channel,
604
+ out_channels=output_channel,
605
+ prev_output_channel=prev_output_channel,
606
+ temb_channels=blocks_time_embed_dim,
607
+ add_upsample=add_upsample,
608
+ resnet_eps=norm_eps,
609
+ resnet_act_fn=act_fn,
610
+ resolution_idx=i,
611
+ resnet_groups=norm_num_groups,
612
+ cross_attention_dim=reversed_cross_attention_dim[i],
613
+ num_attention_heads=reversed_num_attention_heads[i],
614
+ dual_cross_attention=dual_cross_attention,
615
+ use_linear_projection=use_linear_projection,
616
+ only_cross_attention=only_cross_attention[i],
617
+ upcast_attention=upcast_attention,
618
+ resnet_time_scale_shift=resnet_time_scale_shift,
619
+ attention_type=attention_type,
620
+ resnet_skip_time_act=resnet_skip_time_act,
621
+ resnet_out_scale_factor=resnet_out_scale_factor,
622
+ cross_attention_norm=cross_attention_norm,
623
+ attention_head_dim=attention_head_dim[i]
624
+ if attention_head_dim[i] is not None
625
+ else output_channel,
626
+ dropout=dropout,
627
+ )
628
+ self.up_blocks.append(up_block)
629
+ prev_output_channel = output_channel
630
+
631
+ # out
632
+ if norm_num_groups is not None:
633
+ self.conv_norm_out = nn.GroupNorm(
634
+ num_channels=block_out_channels[0],
635
+ num_groups=norm_num_groups,
636
+ eps=norm_eps,
637
+ )
638
+
639
+ self.conv_act = get_activation(act_fn)
640
+
641
+ else:
642
+ self.conv_norm_out = None
643
+ self.conv_act = None
644
+ self.conv_norm_out = None
645
+
646
+ conv_out_padding = (conv_out_kernel - 1) // 2
647
+ # self.conv_out = nn.Conv2d(
648
+ # block_out_channels[0],
649
+ # out_channels,
650
+ # kernel_size=conv_out_kernel,
651
+ # padding=conv_out_padding,
652
+ # )
653
+
654
+ if attention_type in ["gated", "gated-text-image"]:
655
+ positive_len = 768
656
+ if isinstance(cross_attention_dim, int):
657
+ positive_len = cross_attention_dim
658
+ elif isinstance(cross_attention_dim, tuple) or isinstance(
659
+ cross_attention_dim, list
660
+ ):
661
+ positive_len = cross_attention_dim[0]
662
+
663
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
664
+ self.position_net = PositionNet(
665
+ positive_len=positive_len,
666
+ out_dim=cross_attention_dim,
667
+ feature_type=feature_type,
668
+ )
669
+
670
+ @property
671
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
672
+ r"""
673
+ Returns:
674
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
675
+ indexed by its weight name.
676
+ """
677
+ # set recursively
678
+ processors = {}
679
+
680
+ def fn_recursive_add_processors(
681
+ name: str,
682
+ module: torch.nn.Module,
683
+ processors: Dict[str, AttentionProcessor],
684
+ ):
685
+ if hasattr(module, "get_processor"):
686
+ processors[f"{name}.processor"] = module.get_processor(
687
+ return_deprecated_lora=True
688
+ )
689
+
690
+ for sub_name, child in module.named_children():
691
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
692
+
693
+ return processors
694
+
695
+ for name, module in self.named_children():
696
+ fn_recursive_add_processors(name, module, processors)
697
+
698
+ return processors
699
+
700
+ def set_attn_processor(
701
+ self,
702
+ processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]],
703
+ _remove_lora=False,
704
+ ):
705
+ r"""
706
+ Sets the attention processor to use to compute attention.
707
+
708
+ Parameters:
709
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
710
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
711
+ for **all** `Attention` layers.
712
+
713
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
714
+ processor. This is strongly recommended when setting trainable attention processors.
715
+
716
+ """
717
+ count = len(self.attn_processors.keys())
718
+
719
+ if isinstance(processor, dict) and len(processor) != count:
720
+ raise ValueError(
721
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
722
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
723
+ )
724
+
725
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
726
+ if hasattr(module, "set_processor"):
727
+ if not isinstance(processor, dict):
728
+ module.set_processor(processor, _remove_lora=_remove_lora)
729
+ else:
730
+ module.set_processor(
731
+ processor.pop(f"{name}.processor"), _remove_lora=_remove_lora
732
+ )
733
+
734
+ for sub_name, child in module.named_children():
735
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
736
+
737
+ for name, module in self.named_children():
738
+ fn_recursive_attn_processor(name, module, processor)
739
+
740
+ def set_default_attn_processor(self):
741
+ """
742
+ Disables custom attention processors and sets the default attention implementation.
743
+ """
744
+ if all(
745
+ proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS
746
+ for proc in self.attn_processors.values()
747
+ ):
748
+ processor = AttnAddedKVProcessor()
749
+ elif all(
750
+ proc.__class__ in CROSS_ATTENTION_PROCESSORS
751
+ for proc in self.attn_processors.values()
752
+ ):
753
+ processor = AttnProcessor()
754
+ else:
755
+ raise ValueError(
756
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
757
+ )
758
+
759
+ self.set_attn_processor(processor, _remove_lora=True)
760
+
761
+ def set_attention_slice(self, slice_size):
762
+ r"""
763
+ Enable sliced attention computation.
764
+
765
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
766
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
767
+
768
+ Args:
769
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
770
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
771
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
772
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
773
+ must be a multiple of `slice_size`.
774
+ """
775
+ sliceable_head_dims = []
776
+
777
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
778
+ if hasattr(module, "set_attention_slice"):
779
+ sliceable_head_dims.append(module.sliceable_head_dim)
780
+
781
+ for child in module.children():
782
+ fn_recursive_retrieve_sliceable_dims(child)
783
+
784
+ # retrieve number of attention layers
785
+ for module in self.children():
786
+ fn_recursive_retrieve_sliceable_dims(module)
787
+
788
+ num_sliceable_layers = len(sliceable_head_dims)
789
+
790
+ if slice_size == "auto":
791
+ # half the attention head size is usually a good trade-off between
792
+ # speed and memory
793
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
794
+ elif slice_size == "max":
795
+ # make smallest slice possible
796
+ slice_size = num_sliceable_layers * [1]
797
+
798
+ slice_size = (
799
+ num_sliceable_layers * [slice_size]
800
+ if not isinstance(slice_size, list)
801
+ else slice_size
802
+ )
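+        # For example, with sliceable head dims [8, 8, 16], "auto" yields slice sizes
+        # [4, 4, 8], "max" yields [1, 1, 1], and an int such as 2 becomes [2, 2, 2].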
803
+
804
+ if len(slice_size) != len(sliceable_head_dims):
805
+ raise ValueError(
806
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
807
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
808
+ )
809
+
810
+ for i in range(len(slice_size)):
811
+ size = slice_size[i]
812
+ dim = sliceable_head_dims[i]
813
+ if size is not None and size > dim:
814
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
815
+
816
+ # Recursively walk through all the children.
817
+ # Any children which exposes the set_attention_slice method
818
+ # gets the message
819
+ def fn_recursive_set_attention_slice(
820
+ module: torch.nn.Module, slice_size: List[int]
821
+ ):
822
+ if hasattr(module, "set_attention_slice"):
823
+ module.set_attention_slice(slice_size.pop())
824
+
825
+ for child in module.children():
826
+ fn_recursive_set_attention_slice(child, slice_size)
827
+
828
+ reversed_slice_size = list(reversed(slice_size))
829
+ for module in self.children():
830
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
831
+
832
+ def _set_gradient_checkpointing(self, module, value=False):
833
+ if hasattr(module, "gradient_checkpointing"):
834
+ module.gradient_checkpointing = value
835
+
836
+ def enable_freeu(self, s1, s2, b1, b2):
837
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
838
+
839
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
840
+
841
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
842
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
843
+
844
+ Args:
845
+ s1 (`float`):
846
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
847
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
848
+ s2 (`float`):
849
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
850
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
851
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
852
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
853
+ """
854
+ for i, upsample_block in enumerate(self.up_blocks):
855
+ setattr(upsample_block, "s1", s1)
856
+ setattr(upsample_block, "s2", s2)
857
+ setattr(upsample_block, "b1", b1)
858
+ setattr(upsample_block, "b2", b2)
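+        # Usage sketch (the s1/s2/b1/b2 values below are the ones commonly suggested for
+        # SD1.x in the FreeU repository; treat them as a starting point, not a guarantee):
+        #   unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)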
859
+
860
+ def disable_freeu(self):
861
+ """Disables the FreeU mechanism."""
862
+ freeu_keys = {"s1", "s2", "b1", "b2"}
863
+ for i, upsample_block in enumerate(self.up_blocks):
864
+ for k in freeu_keys:
865
+ if (
866
+ hasattr(upsample_block, k)
867
+ or getattr(upsample_block, k, None) is not None
868
+ ):
869
+ setattr(upsample_block, k, None)
870
+
871
+ def forward(
872
+ self,
873
+ sample: torch.FloatTensor,
874
+ timestep: Union[torch.Tensor, float, int],
875
+ encoder_hidden_states: torch.Tensor,
876
+ class_labels: Optional[torch.Tensor] = None,
877
+ timestep_cond: Optional[torch.Tensor] = None,
878
+ attention_mask: Optional[torch.Tensor] = None,
879
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
880
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
881
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
882
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
883
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
884
+ encoder_attention_mask: Optional[torch.Tensor] = None,
885
+ return_dict: bool = True,
886
+ ) -> Union[UNet2DConditionOutput, Tuple]:
887
+ r"""
888
+ The [`UNet2DConditionModel`] forward method.
889
+
890
+ Args:
891
+ sample (`torch.FloatTensor`):
892
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
893
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
894
+ encoder_hidden_states (`torch.FloatTensor`):
895
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
896
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
897
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
898
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
899
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
900
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
901
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
902
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
903
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
904
+ negative values to the attention scores corresponding to "discard" tokens.
905
+ cross_attention_kwargs (`dict`, *optional*):
906
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
907
+ `self.processor` in
908
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
909
+ added_cond_kwargs: (`dict`, *optional*):
910
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
911
+ are passed along to the UNet blocks.
912
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
913
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
914
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
915
+ A tensor that if specified is added to the residual of the middle unet block.
916
+ encoder_attention_mask (`torch.Tensor`):
917
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
918
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
919
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
920
+ return_dict (`bool`, *optional*, defaults to `True`):
921
+ Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
922
+ tuple.
923
+ cross_attention_kwargs (`dict`, *optional*):
924
+ A kwargs dictionary that if specified is passed along to the [`AttnProcessor`].
925
+ added_cond_kwargs: (`dict`, *optional*):
926
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
927
+ are passed along to the UNet blocks.
928
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
929
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
930
+ example from ControlNet side model(s)
931
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
932
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
933
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
934
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
935
+
936
+ Returns:
937
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
938
+ If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
939
+ a `tuple` is returned where the first element is the sample tensor.
940
+ """
941
+        # By default samples have to be at least a multiple of the overall upsampling factor.
942
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
943
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
944
+ # on the fly if necessary.
945
+ default_overall_up_factor = 2**self.num_upsamplers
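+        # With the default 4-level UNet there are 3 upsamplers, so spatial dims that are
+        # not multiples of 2**3 = 8 trigger the `forward_upsample_size` path below.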
946
+
947
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
948
+ forward_upsample_size = False
949
+ upsample_size = None
950
+
951
+ for dim in sample.shape[-2:]:
952
+ if dim % default_overall_up_factor != 0:
953
+ # Forward upsample size to force interpolation output size.
954
+ forward_upsample_size = True
955
+ break
956
+
957
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
958
+ # expects mask of shape:
959
+ # [batch, key_tokens]
960
+ # adds singleton query_tokens dimension:
961
+ # [batch, 1, key_tokens]
962
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
963
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
964
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
965
+ if attention_mask is not None:
966
+ # assume that mask is expressed as:
967
+ # (1 = keep, 0 = discard)
968
+ # convert mask into a bias that can be added to attention scores:
969
+ # (keep = +0, discard = -10000.0)
970
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
971
+ attention_mask = attention_mask.unsqueeze(1)
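+            # For example, a key mask of [1, 1, 0] becomes the additive bias
+            # [[0.0, 0.0, -10000.0]], which drives the masked token's attention
+            # weight to (near) zero after the softmax.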
972
+
973
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
974
+ if encoder_attention_mask is not None:
975
+ encoder_attention_mask = (
976
+ 1 - encoder_attention_mask.to(sample.dtype)
977
+ ) * -10000.0
978
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
979
+
980
+ # 0. center input if necessary
981
+ if self.config.center_input_sample:
982
+ sample = 2 * sample - 1.0
983
+
984
+ # 1. time
985
+ timesteps = timestep
986
+ if not torch.is_tensor(timesteps):
987
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
988
+ # This would be a good case for the `match` statement (Python 3.10+)
989
+ is_mps = sample.device.type == "mps"
990
+ if isinstance(timestep, float):
991
+ dtype = torch.float32 if is_mps else torch.float64
992
+ else:
993
+ dtype = torch.int32 if is_mps else torch.int64
994
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
995
+ elif len(timesteps.shape) == 0:
996
+ timesteps = timesteps[None].to(sample.device)
997
+
998
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
999
+ timesteps = timesteps.expand(sample.shape[0])
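+        # e.g. a single timestep tensor([999]) is broadcast to tensor([999, 999]) for a batch of 2.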
1000
+
1001
+ t_emb = self.time_proj(timesteps)
1002
+
1003
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1004
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1005
+ # there might be better ways to encapsulate this.
1006
+ t_emb = t_emb.to(dtype=sample.dtype)
1007
+
1008
+ emb = self.time_embedding(t_emb, timestep_cond)
1009
+ aug_emb = None
1010
+
1011
+ if self.class_embedding is not None:
1012
+ if class_labels is None:
1013
+ raise ValueError(
1014
+ "class_labels should be provided when num_class_embeds > 0"
1015
+ )
1016
+
1017
+ if self.config.class_embed_type == "timestep":
1018
+ class_labels = self.time_proj(class_labels)
1019
+
1020
+ # `Timesteps` does not contain any weights and will always return f32 tensors
1021
+ # there might be better ways to encapsulate this.
1022
+ class_labels = class_labels.to(dtype=sample.dtype)
1023
+
1024
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
1025
+
1026
+ if self.config.class_embeddings_concat:
1027
+ emb = torch.cat([emb, class_emb], dim=-1)
1028
+ else:
1029
+ emb = emb + class_emb
1030
+
1031
+ if self.config.addition_embed_type == "text":
1032
+ aug_emb = self.add_embedding(encoder_hidden_states)
1033
+ elif self.config.addition_embed_type == "text_image":
1034
+ # Kandinsky 2.1 - style
1035
+ if "image_embeds" not in added_cond_kwargs:
1036
+ raise ValueError(
1037
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1038
+ )
1039
+
1040
+ image_embs = added_cond_kwargs.get("image_embeds")
1041
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
1042
+ aug_emb = self.add_embedding(text_embs, image_embs)
1043
+ elif self.config.addition_embed_type == "text_time":
1044
+ # SDXL - style
1045
+ if "text_embeds" not in added_cond_kwargs:
1046
+ raise ValueError(
1047
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
1048
+ )
1049
+ text_embeds = added_cond_kwargs.get("text_embeds")
1050
+ if "time_ids" not in added_cond_kwargs:
1051
+ raise ValueError(
1052
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
1053
+ )
1054
+ time_ids = added_cond_kwargs.get("time_ids")
1055
+ time_embeds = self.add_time_proj(time_ids.flatten())
1056
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
1057
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
1058
+ add_embeds = add_embeds.to(emb.dtype)
1059
+ aug_emb = self.add_embedding(add_embeds)
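+            # In SDXL-style pipelines `time_ids` typically packs the micro-conditioning
+            # (original size, crop coordinates, target size); each id is embedded by
+            # `add_time_proj`, flattened per sample, and concatenated with the pooled
+            # text embeddings before `add_embedding`.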
1060
+ elif self.config.addition_embed_type == "image":
1061
+ # Kandinsky 2.2 - style
1062
+ if "image_embeds" not in added_cond_kwargs:
1063
+ raise ValueError(
1064
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
1065
+ )
1066
+ image_embs = added_cond_kwargs.get("image_embeds")
1067
+ aug_emb = self.add_embedding(image_embs)
1068
+ elif self.config.addition_embed_type == "image_hint":
1069
+ # Kandinsky 2.2 - style
1070
+ if (
1071
+ "image_embeds" not in added_cond_kwargs
1072
+ or "hint" not in added_cond_kwargs
1073
+ ):
1074
+ raise ValueError(
1075
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
1076
+ )
1077
+ image_embs = added_cond_kwargs.get("image_embeds")
1078
+ hint = added_cond_kwargs.get("hint")
1079
+ aug_emb, hint = self.add_embedding(image_embs, hint)
1080
+ sample = torch.cat([sample, hint], dim=1)
1081
+
1082
+ emb = emb + aug_emb if aug_emb is not None else emb
1083
+
1084
+ if self.time_embed_act is not None:
1085
+ emb = self.time_embed_act(emb)
1086
+
1087
+ if (
1088
+ self.encoder_hid_proj is not None
1089
+ and self.config.encoder_hid_dim_type == "text_proj"
1090
+ ):
1091
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
1092
+ elif (
1093
+ self.encoder_hid_proj is not None
1094
+ and self.config.encoder_hid_dim_type == "text_image_proj"
1095
+ ):
1096
+            # Kandinsky 2.1 - style
1097
+ if "image_embeds" not in added_cond_kwargs:
1098
+ raise ValueError(
1099
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1100
+ )
1101
+
1102
+ image_embeds = added_cond_kwargs.get("image_embeds")
1103
+ encoder_hidden_states = self.encoder_hid_proj(
1104
+ encoder_hidden_states, image_embeds
1105
+ )
1106
+ elif (
1107
+ self.encoder_hid_proj is not None
1108
+ and self.config.encoder_hid_dim_type == "image_proj"
1109
+ ):
1110
+ # Kandinsky 2.2 - style
1111
+ if "image_embeds" not in added_cond_kwargs:
1112
+ raise ValueError(
1113
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1114
+ )
1115
+ image_embeds = added_cond_kwargs.get("image_embeds")
1116
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
1117
+ elif (
1118
+ self.encoder_hid_proj is not None
1119
+ and self.config.encoder_hid_dim_type == "ip_image_proj"
1120
+ ):
1121
+ if "image_embeds" not in added_cond_kwargs:
1122
+ raise ValueError(
1123
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
1124
+ )
1125
+ image_embeds = added_cond_kwargs.get("image_embeds")
1126
+ image_embeds = self.encoder_hid_proj(image_embeds).to(
1127
+ encoder_hidden_states.dtype
1128
+ )
1129
+ encoder_hidden_states = torch.cat(
1130
+ [encoder_hidden_states, image_embeds], dim=1
1131
+ )
1132
+
1133
+ # 2. pre-process
1134
+ sample = self.conv_in(sample)
1135
+
1136
+ # 2.5 GLIGEN position net
1137
+ if (
1138
+ cross_attention_kwargs is not None
1139
+ and cross_attention_kwargs.get("gligen", None) is not None
1140
+ ):
1141
+ cross_attention_kwargs = cross_attention_kwargs.copy()
1142
+ gligen_args = cross_attention_kwargs.pop("gligen")
1143
+ cross_attention_kwargs["gligen"] = {
1144
+ "objs": self.position_net(**gligen_args)
1145
+ }
1146
+
1147
+ # 3. down
1148
+ lora_scale = (
1149
+ cross_attention_kwargs.get("scale", 1.0)
1150
+ if cross_attention_kwargs is not None
1151
+ else 1.0
1152
+ )
1153
+ if USE_PEFT_BACKEND:
1154
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
1155
+ scale_lora_layers(self, lora_scale)
1156
+
1157
+ is_controlnet = (
1158
+ mid_block_additional_residual is not None
1159
+ and down_block_additional_residuals is not None
1160
+ )
1161
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
1162
+ is_adapter = down_intrablock_additional_residuals is not None
1163
+ # maintain backward compatibility for legacy usage, where
1164
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
1165
+ # but can only use one or the other
1166
+ if (
1167
+ not is_adapter
1168
+ and mid_block_additional_residual is None
1169
+ and down_block_additional_residuals is not None
1170
+ ):
1171
+ deprecate(
1172
+ "T2I should not use down_block_additional_residuals",
1173
+ "1.3.0",
1174
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
1175
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
1176
+                for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
1177
+ standard_warn=False,
1178
+ )
1179
+ down_intrablock_additional_residuals = down_block_additional_residuals
1180
+ is_adapter = True
1181
+
1182
+ down_block_res_samples = (sample,)
1183
+ tot_referece_features = ()
1184
+ for downsample_block in self.down_blocks:
1185
+ if (
1186
+ hasattr(downsample_block, "has_cross_attention")
1187
+ and downsample_block.has_cross_attention
1188
+ ):
1189
+ # For t2i-adapter CrossAttnDownBlock2D
1190
+ additional_residuals = {}
1191
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1192
+ additional_residuals[
1193
+ "additional_residuals"
1194
+ ] = down_intrablock_additional_residuals.pop(0)
1195
+
1196
+ sample, res_samples = downsample_block(
1197
+ hidden_states=sample,
1198
+ temb=emb,
1199
+ encoder_hidden_states=encoder_hidden_states,
1200
+ attention_mask=attention_mask,
1201
+ cross_attention_kwargs=cross_attention_kwargs,
1202
+ encoder_attention_mask=encoder_attention_mask,
1203
+ **additional_residuals,
1204
+ )
1205
+ else:
1206
+ sample, res_samples = downsample_block(
1207
+ hidden_states=sample, temb=emb, scale=lora_scale
1208
+ )
1209
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
1210
+ sample += down_intrablock_additional_residuals.pop(0)
1211
+
1212
+ down_block_res_samples += res_samples
1213
+
1214
+ if is_controlnet:
1215
+ new_down_block_res_samples = ()
1216
+
1217
+ for down_block_res_sample, down_block_additional_residual in zip(
1218
+ down_block_res_samples, down_block_additional_residuals
1219
+ ):
1220
+ down_block_res_sample = (
1221
+ down_block_res_sample + down_block_additional_residual
1222
+ )
1223
+ new_down_block_res_samples = new_down_block_res_samples + (
1224
+ down_block_res_sample,
1225
+ )
1226
+
1227
+ down_block_res_samples = new_down_block_res_samples
1228
+
1229
+ # 4. mid
1230
+ if self.mid_block is not None:
1231
+ if (
1232
+ hasattr(self.mid_block, "has_cross_attention")
1233
+ and self.mid_block.has_cross_attention
1234
+ ):
1235
+ sample = self.mid_block(
1236
+ sample,
1237
+ emb,
1238
+ encoder_hidden_states=encoder_hidden_states,
1239
+ attention_mask=attention_mask,
1240
+ cross_attention_kwargs=cross_attention_kwargs,
1241
+ encoder_attention_mask=encoder_attention_mask,
1242
+ )
1243
+ else:
1244
+ sample = self.mid_block(sample, emb)
1245
+
1246
+ # To support T2I-Adapter-XL
1247
+ if (
1248
+ is_adapter
1249
+ and len(down_intrablock_additional_residuals) > 0
1250
+ and sample.shape == down_intrablock_additional_residuals[0].shape
1251
+ ):
1252
+ sample += down_intrablock_additional_residuals.pop(0)
1253
+
1254
+ if is_controlnet:
1255
+ sample = sample + mid_block_additional_residual
1256
+
1257
+ # 5. up
1258
+ for i, upsample_block in enumerate(self.up_blocks):
1259
+ is_final_block = i == len(self.up_blocks) - 1
1260
+
1261
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
1262
+ down_block_res_samples = down_block_res_samples[
1263
+ : -len(upsample_block.resnets)
1264
+ ]
1265
+
1266
+ # if we have not reached the final block and need to forward the
1267
+ # upsample size, we do it here
1268
+ if not is_final_block and forward_upsample_size:
1269
+ upsample_size = down_block_res_samples[-1].shape[2:]
1270
+
1271
+ if (
1272
+ hasattr(upsample_block, "has_cross_attention")
1273
+ and upsample_block.has_cross_attention
1274
+ ):
1275
+ sample = upsample_block(
1276
+ hidden_states=sample,
1277
+ temb=emb,
1278
+ res_hidden_states_tuple=res_samples,
1279
+ encoder_hidden_states=encoder_hidden_states,
1280
+ cross_attention_kwargs=cross_attention_kwargs,
1281
+ upsample_size=upsample_size,
1282
+ attention_mask=attention_mask,
1283
+ encoder_attention_mask=encoder_attention_mask,
1284
+ )
1285
+ else:
1286
+ sample = upsample_block(
1287
+ hidden_states=sample,
1288
+ temb=emb,
1289
+ res_hidden_states_tuple=res_samples,
1290
+ upsample_size=upsample_size,
1291
+ scale=lora_scale,
1292
+ )
1293
+
1294
+ # 6. post-process
1295
+ # if self.conv_norm_out:
1296
+ # sample = self.conv_norm_out(sample)
1297
+ # sample = self.conv_act(sample)
1298
+ # sample = self.conv_out(sample)
1299
+
1300
+ if USE_PEFT_BACKEND:
1301
+ # remove `lora_scale` from each PEFT layer
1302
+ unscale_lora_layers(self, lora_scale)
1303
+
1304
+ if not return_dict:
1305
+ return (sample,)
1306
+
1307
+ return UNet2DConditionOutput(sample=sample)
musepose/models/unet_3d.py ADDED
@@ -0,0 +1,675 @@
1
+ # Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
2
+
3
+ from collections import OrderedDict
4
+ from dataclasses import dataclass
5
+ from os import PathLike
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.checkpoint
12
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
13
+ from diffusers.models.attention_processor import AttentionProcessor
14
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME, WEIGHTS_NAME, BaseOutput, logging
17
+ from safetensors.torch import load_file
18
+
19
+ from .resnet import InflatedConv3d, InflatedGroupNorm
20
+ from .unet_3d_blocks import UNetMidBlock3DCrossAttn, get_down_block, get_up_block
21
+
22
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
23
+
24
+
25
+ @dataclass
26
+ class UNet3DConditionOutput(BaseOutput):
27
+ sample: torch.FloatTensor
28
+
29
+
30
+ class UNet3DConditionModel(ModelMixin, ConfigMixin):
31
+ _supports_gradient_checkpointing = True
32
+
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ sample_size: Optional[int] = None,
37
+ in_channels: int = 4,
38
+ out_channels: int = 4,
39
+ center_input_sample: bool = False,
40
+ flip_sin_to_cos: bool = True,
41
+ freq_shift: int = 0,
42
+ down_block_types: Tuple[str] = (
43
+ "CrossAttnDownBlock3D",
44
+ "CrossAttnDownBlock3D",
45
+ "CrossAttnDownBlock3D",
46
+ "DownBlock3D",
47
+ ),
48
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
49
+ up_block_types: Tuple[str] = (
50
+ "UpBlock3D",
51
+ "CrossAttnUpBlock3D",
52
+ "CrossAttnUpBlock3D",
53
+ "CrossAttnUpBlock3D",
54
+ ),
55
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
56
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
57
+ layers_per_block: int = 2,
58
+ downsample_padding: int = 1,
59
+ mid_block_scale_factor: float = 1,
60
+ act_fn: str = "silu",
61
+ norm_num_groups: int = 32,
62
+ norm_eps: float = 1e-5,
63
+ cross_attention_dim: int = 1280,
64
+ attention_head_dim: Union[int, Tuple[int]] = 8,
65
+ dual_cross_attention: bool = False,
66
+ use_linear_projection: bool = False,
67
+ class_embed_type: Optional[str] = None,
68
+ num_class_embeds: Optional[int] = None,
69
+ upcast_attention: bool = False,
70
+ resnet_time_scale_shift: str = "default",
71
+ use_inflated_groupnorm=False,
72
+ # Additional
73
+ use_motion_module=False,
74
+ motion_module_resolutions=(1, 2, 4, 8),
75
+ motion_module_mid_block=False,
76
+ motion_module_decoder_only=False,
77
+ motion_module_type=None,
78
+ motion_module_kwargs={},
79
+ unet_use_cross_frame_attention=None,
80
+ unet_use_temporal_attention=None,
81
+ ):
82
+ super().__init__()
83
+
84
+ self.sample_size = sample_size
85
+ time_embed_dim = block_out_channels[0] * 4
86
+
87
+ # input
88
+ self.conv_in = InflatedConv3d(
89
+ in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)
90
+ )
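+        # InflatedConv3d (from .resnet) effectively applies a 2D convolution to every frame
+        # of the (batch, channels, frames, height, width) tensor.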
91
+
92
+ # time
93
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
94
+ timestep_input_dim = block_out_channels[0]
95
+
96
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
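+        # Unlike the 2D UNet above, this 3D variant always uses positional timestep
+        # embeddings with time_embed_dim = block_out_channels[0] * 4.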
97
+
98
+ # class embedding
99
+ if class_embed_type is None and num_class_embeds is not None:
100
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
101
+ elif class_embed_type == "timestep":
102
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
103
+ elif class_embed_type == "identity":
104
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
105
+ else:
106
+ self.class_embedding = None
107
+
108
+ self.down_blocks = nn.ModuleList([])
109
+ self.mid_block = None
110
+ self.up_blocks = nn.ModuleList([])
111
+
112
+ if isinstance(only_cross_attention, bool):
113
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
114
+
115
+ if isinstance(attention_head_dim, int):
116
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
117
+
118
+ # down
119
+ output_channel = block_out_channels[0]
120
+ for i, down_block_type in enumerate(down_block_types):
121
+ res = 2**i
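+            # `res` indexes the feature-map scale (1, 2, 4, 8); a motion module is inserted
+            # only when `use_motion_module` is set, `res` is in `motion_module_resolutions`,
+            # and (for these encoder blocks) `motion_module_decoder_only` is False.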
122
+ input_channel = output_channel
123
+ output_channel = block_out_channels[i]
124
+ is_final_block = i == len(block_out_channels) - 1
125
+
126
+ down_block = get_down_block(
127
+ down_block_type,
128
+ num_layers=layers_per_block,
129
+ in_channels=input_channel,
130
+ out_channels=output_channel,
131
+ temb_channels=time_embed_dim,
132
+ add_downsample=not is_final_block,
133
+ resnet_eps=norm_eps,
134
+ resnet_act_fn=act_fn,
135
+ resnet_groups=norm_num_groups,
136
+ cross_attention_dim=cross_attention_dim,
137
+ attn_num_head_channels=attention_head_dim[i],
138
+ downsample_padding=downsample_padding,
139
+ dual_cross_attention=dual_cross_attention,
140
+ use_linear_projection=use_linear_projection,
141
+ only_cross_attention=only_cross_attention[i],
142
+ upcast_attention=upcast_attention,
143
+ resnet_time_scale_shift=resnet_time_scale_shift,
144
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
145
+ unet_use_temporal_attention=unet_use_temporal_attention,
146
+ use_inflated_groupnorm=use_inflated_groupnorm,
147
+ use_motion_module=use_motion_module
148
+ and (res in motion_module_resolutions)
149
+ and (not motion_module_decoder_only),
150
+ motion_module_type=motion_module_type,
151
+ motion_module_kwargs=motion_module_kwargs,
152
+ )
153
+ self.down_blocks.append(down_block)
154
+
155
+ # mid
156
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
157
+ self.mid_block = UNetMidBlock3DCrossAttn(
158
+ in_channels=block_out_channels[-1],
159
+ temb_channels=time_embed_dim,
160
+ resnet_eps=norm_eps,
161
+ resnet_act_fn=act_fn,
162
+ output_scale_factor=mid_block_scale_factor,
163
+ resnet_time_scale_shift=resnet_time_scale_shift,
164
+ cross_attention_dim=cross_attention_dim,
165
+ attn_num_head_channels=attention_head_dim[-1],
166
+ resnet_groups=norm_num_groups,
167
+ dual_cross_attention=dual_cross_attention,
168
+ use_linear_projection=use_linear_projection,
169
+ upcast_attention=upcast_attention,
170
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
171
+ unet_use_temporal_attention=unet_use_temporal_attention,
172
+ use_inflated_groupnorm=use_inflated_groupnorm,
173
+ use_motion_module=use_motion_module and motion_module_mid_block,
174
+ motion_module_type=motion_module_type,
175
+ motion_module_kwargs=motion_module_kwargs,
176
+ )
177
+ else:
178
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
179
+
180
+ # count how many layers upsample the videos
181
+ self.num_upsamplers = 0
182
+
183
+ # up
184
+ reversed_block_out_channels = list(reversed(block_out_channels))
185
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
186
+ only_cross_attention = list(reversed(only_cross_attention))
187
+ output_channel = reversed_block_out_channels[0]
188
+ for i, up_block_type in enumerate(up_block_types):
189
+ res = 2 ** (3 - i)
190
+ is_final_block = i == len(block_out_channels) - 1
191
+
192
+ prev_output_channel = output_channel
193
+ output_channel = reversed_block_out_channels[i]
194
+ input_channel = reversed_block_out_channels[
195
+ min(i + 1, len(block_out_channels) - 1)
196
+ ]
197
+
198
+ # add upsample block for all BUT final layer
199
+ if not is_final_block:
200
+ add_upsample = True
201
+ self.num_upsamplers += 1
202
+ else:
203
+ add_upsample = False
204
+
205
+ up_block = get_up_block(
206
+ up_block_type,
207
+ num_layers=layers_per_block + 1,
208
+ in_channels=input_channel,
209
+ out_channels=output_channel,
210
+ prev_output_channel=prev_output_channel,
211
+ temb_channels=time_embed_dim,
212
+ add_upsample=add_upsample,
213
+ resnet_eps=norm_eps,
214
+ resnet_act_fn=act_fn,
215
+ resnet_groups=norm_num_groups,
216
+ cross_attention_dim=cross_attention_dim,
217
+ attn_num_head_channels=reversed_attention_head_dim[i],
218
+ dual_cross_attention=dual_cross_attention,
219
+ use_linear_projection=use_linear_projection,
220
+ only_cross_attention=only_cross_attention[i],
221
+ upcast_attention=upcast_attention,
222
+ resnet_time_scale_shift=resnet_time_scale_shift,
223
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
224
+ unet_use_temporal_attention=unet_use_temporal_attention,
225
+ use_inflated_groupnorm=use_inflated_groupnorm,
226
+ use_motion_module=use_motion_module
227
+ and (res in motion_module_resolutions),
228
+ motion_module_type=motion_module_type,
229
+ motion_module_kwargs=motion_module_kwargs,
230
+ )
231
+ self.up_blocks.append(up_block)
232
+ prev_output_channel = output_channel
233
+
234
+ # out
235
+ if use_inflated_groupnorm:
236
+ self.conv_norm_out = InflatedGroupNorm(
237
+ num_channels=block_out_channels[0],
238
+ num_groups=norm_num_groups,
239
+ eps=norm_eps,
240
+ )
241
+ else:
242
+ self.conv_norm_out = nn.GroupNorm(
243
+ num_channels=block_out_channels[0],
244
+ num_groups=norm_num_groups,
245
+ eps=norm_eps,
246
+ )
247
+ self.conv_act = nn.SiLU()
248
+ self.conv_out = InflatedConv3d(
249
+ block_out_channels[0], out_channels, kernel_size=3, padding=1
250
+ )
251
+
252
+ @property
253
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
254
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
255
+ r"""
256
+ Returns:
257
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
258
+ indexed by their weight names.
259
+ """
260
+ # set recursively
261
+ processors = {}
262
+
263
+ def fn_recursive_add_processors(
264
+ name: str,
265
+ module: torch.nn.Module,
266
+ processors: Dict[str, AttentionProcessor],
267
+ ):
268
+ if hasattr(module, "set_processor"):
269
+ processors[f"{name}.processor"] = module.processor
270
+
271
+ for sub_name, child in module.named_children():
272
+ if "temporal_transformer" not in sub_name:
273
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
274
+
275
+ return processors
276
+
277
+ for name, module in self.named_children():
278
+ if "temporal_transformer" not in name:
279
+ fn_recursive_add_processors(name, module, processors)
280
+
281
+ return processors
282
+
283
+ def set_attention_slice(self, slice_size):
284
+ r"""
285
+ Enable sliced attention computation.
286
+
287
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
288
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
289
+
290
+ Args:
291
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
292
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
293
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
294
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
295
+ must be a multiple of `slice_size`.
296
+ """
297
+ sliceable_head_dims = []
298
+
299
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
300
+ if hasattr(module, "set_attention_slice"):
301
+ sliceable_head_dims.append(module.sliceable_head_dim)
302
+
303
+ for child in module.children():
304
+ fn_recursive_retrieve_slicable_dims(child)
305
+
306
+ # retrieve number of attention layers
307
+ for module in self.children():
308
+ fn_recursive_retrieve_slicable_dims(module)
309
+
310
+ num_slicable_layers = len(sliceable_head_dims)
311
+
312
+ if slice_size == "auto":
313
+ # half the attention head size is usually a good trade-off between
314
+ # speed and memory
315
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
316
+ elif slice_size == "max":
317
+ # make smallest slice possible
318
+ slice_size = num_slicable_layers * [1]
319
+
320
+ slice_size = (
321
+ num_slicable_layers * [slice_size]
322
+ if not isinstance(slice_size, list)
323
+ else slice_size
324
+ )
325
+
326
+ if len(slice_size) != len(sliceable_head_dims):
327
+ raise ValueError(
328
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
329
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
330
+ )
331
+
332
+ for i in range(len(slice_size)):
333
+ size = slice_size[i]
334
+ dim = sliceable_head_dims[i]
335
+ if size is not None and size > dim:
336
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
337
+
338
+ # Recursively walk through all the children.
339
+ # Any child that exposes the set_attention_slice method
340
+ # gets the message
341
+ def fn_recursive_set_attention_slice(
342
+ module: torch.nn.Module, slice_size: List[int]
343
+ ):
344
+ if hasattr(module, "set_attention_slice"):
345
+ module.set_attention_slice(slice_size.pop())
346
+
347
+ for child in module.children():
348
+ fn_recursive_set_attention_slice(child, slice_size)
349
+
350
+ reversed_slice_size = list(reversed(slice_size))
351
+ for module in self.children():
352
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
353
+
354
+ def _set_gradient_checkpointing(self, module, value=False):
355
+ if hasattr(module, "gradient_checkpointing"):
356
+ module.gradient_checkpointing = value
357
+
358
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
359
+ def set_attn_processor(
360
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
361
+ ):
362
+ r"""
363
+ Sets the attention processor to use to compute attention.
364
+
365
+ Parameters:
366
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
367
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
368
+ for **all** `Attention` layers.
369
+
370
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
371
+ processor. This is strongly recommended when setting trainable attention processors.
372
+
373
+ """
374
+ count = len(self.attn_processors.keys())
375
+
376
+ if isinstance(processor, dict) and len(processor) != count:
377
+ raise ValueError(
378
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
379
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
380
+ )
381
+
382
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
383
+ if hasattr(module, "set_processor"):
384
+ if not isinstance(processor, dict):
385
+ module.set_processor(processor)
386
+ else:
387
+ module.set_processor(processor.pop(f"{name}.processor"))
388
+
389
+ for sub_name, child in module.named_children():
390
+ if "temporal_transformer" not in sub_name:
391
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
392
+
393
+ for name, module in self.named_children():
394
+ if "temporal_transformer" not in name:
395
+ fn_recursive_attn_processor(name, module, processor)
396
+
397
+ def forward(
398
+ self,
399
+ sample: torch.FloatTensor,
400
+ timestep: Union[torch.Tensor, float, int],
401
+ encoder_hidden_states: torch.Tensor,
402
+ class_labels: Optional[torch.Tensor] = None,
403
+ pose_cond_fea: Optional[torch.Tensor] = None,
404
+ attention_mask: Optional[torch.Tensor] = None,
405
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
406
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
407
+ return_dict: bool = True,
408
+ ) -> Union[UNet3DConditionOutput, Tuple]:
409
+ r"""
410
+ Args:
411
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
412
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
413
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
414
+ return_dict (`bool`, *optional*, defaults to `True`):
415
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
416
+
417
+ Returns:
418
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
419
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
420
+ returning a tuple, the first element is the sample tensor.
421
+ """
422
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
423
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
424
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
425
+ # on the fly if necessary.
426
+ default_overall_up_factor = 2**self.num_upsamplers
427
+
428
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
429
+ forward_upsample_size = False
430
+ upsample_size = None
431
+
432
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
433
+ logger.info("Forward upsample size to force interpolation output size.")
434
+ forward_upsample_size = True
435
+
436
+ # prepare attention_mask
437
+ if attention_mask is not None:
438
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
439
+ attention_mask = attention_mask.unsqueeze(1)
440
+
441
+ # center input if necessary
442
+ if self.config.center_input_sample:
443
+ sample = 2 * sample - 1.0
444
+
445
+ # time
446
+ timesteps = timestep
447
+ if not torch.is_tensor(timesteps):
448
+ # This would be a good case for the `match` statement (Python 3.10+)
449
+ is_mps = sample.device.type == "mps"
450
+ if isinstance(timestep, float):
451
+ dtype = torch.float32 if is_mps else torch.float64
452
+ else:
453
+ dtype = torch.int32 if is_mps else torch.int64
454
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
455
+ elif len(timesteps.shape) == 0:
456
+ timesteps = timesteps[None].to(sample.device)
457
+
458
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
459
+ timesteps = timesteps.expand(sample.shape[0])
460
+
461
+ t_emb = self.time_proj(timesteps)
462
+
463
+ # timesteps does not contain any weights and will always return f32 tensors
464
+ # but time_embedding might actually be running in fp16. so we need to cast here.
465
+ # there might be better ways to encapsulate this.
466
+ t_emb = t_emb.to(dtype=self.dtype)
467
+ emb = self.time_embedding(t_emb)
468
+
469
+ if self.class_embedding is not None:
470
+ if class_labels is None:
471
+ raise ValueError(
472
+ "class_labels should be provided when num_class_embeds > 0"
473
+ )
474
+
475
+ if self.config.class_embed_type == "timestep":
476
+ class_labels = self.time_proj(class_labels)
477
+
478
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
479
+ emb = emb + class_emb
480
+
481
+ # pre-process
482
+ sample = self.conv_in(sample)
483
+ if pose_cond_fea is not None:
484
+ sample = sample + pose_cond_fea
485
+
486
+ # down
487
+ down_block_res_samples = (sample,)
488
+ for downsample_block in self.down_blocks:
489
+ if (
490
+ hasattr(downsample_block, "has_cross_attention")
491
+ and downsample_block.has_cross_attention
492
+ ):
493
+ sample, res_samples = downsample_block(
494
+ hidden_states=sample,
495
+ temb=emb,
496
+ encoder_hidden_states=encoder_hidden_states,
497
+ attention_mask=attention_mask,
498
+ )
499
+ else:
500
+ sample, res_samples = downsample_block(
501
+ hidden_states=sample,
502
+ temb=emb,
503
+ encoder_hidden_states=encoder_hidden_states,
504
+ )
505
+
506
+ down_block_res_samples += res_samples
507
+
508
+ if down_block_additional_residuals is not None:
509
+ new_down_block_res_samples = ()
510
+
511
+ for down_block_res_sample, down_block_additional_residual in zip(
512
+ down_block_res_samples, down_block_additional_residuals
513
+ ):
514
+ down_block_res_sample = (
515
+ down_block_res_sample + down_block_additional_residual
516
+ )
517
+ new_down_block_res_samples += (down_block_res_sample,)
518
+
519
+ down_block_res_samples = new_down_block_res_samples
520
+
521
+ # mid
522
+ sample = self.mid_block(
523
+ sample,
524
+ emb,
525
+ encoder_hidden_states=encoder_hidden_states,
526
+ attention_mask=attention_mask,
527
+ )
528
+
529
+ if mid_block_additional_residual is not None:
530
+ sample = sample + mid_block_additional_residual
531
+
532
+ # up
533
+ for i, upsample_block in enumerate(self.up_blocks):
534
+ is_final_block = i == len(self.up_blocks) - 1
535
+
536
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
537
+ down_block_res_samples = down_block_res_samples[
538
+ : -len(upsample_block.resnets)
539
+ ]
540
+
541
+ # if we have not reached the final block and need to forward the
542
+ # upsample size, we do it here
543
+ if not is_final_block and forward_upsample_size:
544
+ upsample_size = down_block_res_samples[-1].shape[2:]
545
+
546
+ if (
547
+ hasattr(upsample_block, "has_cross_attention")
548
+ and upsample_block.has_cross_attention
549
+ ):
550
+ sample = upsample_block(
551
+ hidden_states=sample,
552
+ temb=emb,
553
+ res_hidden_states_tuple=res_samples,
554
+ encoder_hidden_states=encoder_hidden_states,
555
+ upsample_size=upsample_size,
556
+ attention_mask=attention_mask,
557
+ )
558
+ else:
559
+ sample = upsample_block(
560
+ hidden_states=sample,
561
+ temb=emb,
562
+ res_hidden_states_tuple=res_samples,
563
+ upsample_size=upsample_size,
564
+ encoder_hidden_states=encoder_hidden_states,
565
+ )
566
+
567
+ # post-process
568
+ sample = self.conv_norm_out(sample)
569
+ sample = self.conv_act(sample)
570
+ sample = self.conv_out(sample)
571
+
572
+ if not return_dict:
573
+ return (sample,)
574
+
575
+ return UNet3DConditionOutput(sample=sample)
576
+
577
+ @classmethod
578
+ def from_pretrained_2d(
579
+ cls,
580
+ pretrained_model_path: PathLike,
581
+ motion_module_path: PathLike,
582
+ subfolder=None,
583
+ unet_additional_kwargs=None,
584
+ mm_zero_proj_out=False,
585
+ ):
586
+ pretrained_model_path = Path(pretrained_model_path)
587
+ motion_module_path = Path(motion_module_path)
588
+ if subfolder is not None:
589
+ pretrained_model_path = pretrained_model_path.joinpath(subfolder)
590
+ logger.info(
591
+ f"loaded temporal unet's pretrained weights from {pretrained_model_path} ..."
592
+ )
593
+
594
+ config_file = pretrained_model_path / "config.json"
595
+ if not (config_file.exists() and config_file.is_file()):
596
+ raise RuntimeError(f"{config_file} does not exist or is not a file")
597
+
598
+ unet_config = cls.load_config(config_file)
599
+ unet_config["_class_name"] = cls.__name__
600
+ unet_config["down_block_types"] = [
601
+ "CrossAttnDownBlock3D",
602
+ "CrossAttnDownBlock3D",
603
+ "CrossAttnDownBlock3D",
604
+ "DownBlock3D",
605
+ ]
606
+ unet_config["up_block_types"] = [
607
+ "UpBlock3D",
608
+ "CrossAttnUpBlock3D",
609
+ "CrossAttnUpBlock3D",
610
+ "CrossAttnUpBlock3D",
611
+ ]
612
+ unet_config["mid_block_type"] = "UNetMidBlock3DCrossAttn"
613
+
614
+ model = cls.from_config(unet_config, **unet_additional_kwargs)
615
+ # load the vanilla weights
616
+ if pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME).exists():
617
+ logger.debug(
618
+ f"loading safeTensors weights from {pretrained_model_path} ..."
619
+ )
620
+ state_dict = load_file(
621
+ pretrained_model_path.joinpath(SAFETENSORS_WEIGHTS_NAME), device="cpu"
622
+ )
623
+
624
+ elif pretrained_model_path.joinpath(WEIGHTS_NAME).exists():
625
+ logger.debug(f"loading weights from {pretrained_model_path} ...")
626
+ state_dict = torch.load(
627
+ pretrained_model_path.joinpath(WEIGHTS_NAME),
628
+ map_location="cpu",
629
+ weights_only=True,
630
+ )
631
+ else:
632
+ raise FileNotFoundError(f"no weights file found in {pretrained_model_path}")
633
+
634
+ # load the motion module weights
635
+ if motion_module_path.exists() and motion_module_path.is_file():
636
+ if motion_module_path.suffix.lower() in [".pth", ".pt", ".ckpt"]:
637
+ logger.info(f"Load motion module params from {motion_module_path}")
638
+ motion_state_dict = torch.load(
639
+ motion_module_path, map_location="cpu", weights_only=True
640
+ )
641
+ elif motion_module_path.suffix.lower() == ".safetensors":
642
+ motion_state_dict = load_file(motion_module_path, device="cpu")
643
+ else:
644
+ raise RuntimeError(
645
+ f"unknown file format for motion module weights: {motion_module_path.suffix}"
646
+ )
647
+ if mm_zero_proj_out:
648
+ logger.info(f"Zero initialize proj_out layers in motion module...")
649
+ new_motion_state_dict = OrderedDict()
650
+ for k in motion_state_dict:
651
+ if "proj_out" in k:
652
+ continue
653
+ new_motion_state_dict[k] = motion_state_dict[k]
654
+ motion_state_dict = new_motion_state_dict
655
+
656
+
657
+
658
+ for weight_name in list(motion_state_dict.keys()):
659
+ if weight_name[-2:] == 'pe':
660
+ del motion_state_dict[weight_name]
661
+ # print(weight_name)
662
+
663
+ # merge the state dicts
664
+ state_dict.update(motion_state_dict)
665
+
666
+ # load the weights into the model
667
+ m, u = model.load_state_dict(state_dict, strict=False)
668
+ logger.debug(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};")
669
+
670
+ params = [
671
+ p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()
672
+ ]
673
+ logger.info(f"Loaded {sum(params) / 1e6}M-parameter motion module")
674
+
675
+ return model
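A minimal usage sketch for the `from_pretrained_2d` classmethod above, assuming the class defined in `musepose/models/unet_3d.py` is named `UNet3DConditionModel` (only its output dataclass is visible in this excerpt); the weight paths below are illustrative placeholders, and `unet_additional_kwargs` is left empty because in practice it is populated from the repository's YAML configs:

    # Load 2D Stable Diffusion UNet weights plus a motion-module checkpoint into the inflated 3D UNet.
    from musepose.models.unet_3d import UNet3DConditionModel

    denoising_unet = UNet3DConditionModel.from_pretrained_2d(
        pretrained_model_path="./pretrained_weights/sd-image-variations-diffusers",  # placeholder path
        motion_module_path="./pretrained_weights/MusePose/motion_module.pth",        # placeholder path
        subfolder="unet",
        unet_additional_kwargs={},  # illustrative; forwarded as keyword overrides to from_config
    )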
musepose/models/unet_3d_blocks.py ADDED
@@ -0,0 +1,871 @@
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import pdb
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from .motion_module import get_motion_module
9
+
10
+ # from .motion_module import get_motion_module
11
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
12
+ from .transformer_3d import Transformer3DModel
13
+
14
+
15
+ def get_down_block(
16
+ down_block_type,
17
+ num_layers,
18
+ in_channels,
19
+ out_channels,
20
+ temb_channels,
21
+ add_downsample,
22
+ resnet_eps,
23
+ resnet_act_fn,
24
+ attn_num_head_channels,
25
+ resnet_groups=None,
26
+ cross_attention_dim=None,
27
+ downsample_padding=None,
28
+ dual_cross_attention=False,
29
+ use_linear_projection=False,
30
+ only_cross_attention=False,
31
+ upcast_attention=False,
32
+ resnet_time_scale_shift="default",
33
+ unet_use_cross_frame_attention=None,
34
+ unet_use_temporal_attention=None,
35
+ use_inflated_groupnorm=None,
36
+ use_motion_module=None,
37
+ motion_module_type=None,
38
+ motion_module_kwargs=None,
39
+ ):
40
+ down_block_type = (
41
+ down_block_type[7:]
42
+ if down_block_type.startswith("UNetRes")
43
+ else down_block_type
44
+ )
45
+ if down_block_type == "DownBlock3D":
46
+ return DownBlock3D(
47
+ num_layers=num_layers,
48
+ in_channels=in_channels,
49
+ out_channels=out_channels,
50
+ temb_channels=temb_channels,
51
+ add_downsample=add_downsample,
52
+ resnet_eps=resnet_eps,
53
+ resnet_act_fn=resnet_act_fn,
54
+ resnet_groups=resnet_groups,
55
+ downsample_padding=downsample_padding,
56
+ resnet_time_scale_shift=resnet_time_scale_shift,
57
+ use_inflated_groupnorm=use_inflated_groupnorm,
58
+ use_motion_module=use_motion_module,
59
+ motion_module_type=motion_module_type,
60
+ motion_module_kwargs=motion_module_kwargs,
61
+ )
62
+ elif down_block_type == "CrossAttnDownBlock3D":
63
+ if cross_attention_dim is None:
64
+ raise ValueError(
65
+ "cross_attention_dim must be specified for CrossAttnDownBlock3D"
66
+ )
67
+ return CrossAttnDownBlock3D(
68
+ num_layers=num_layers,
69
+ in_channels=in_channels,
70
+ out_channels=out_channels,
71
+ temb_channels=temb_channels,
72
+ add_downsample=add_downsample,
73
+ resnet_eps=resnet_eps,
74
+ resnet_act_fn=resnet_act_fn,
75
+ resnet_groups=resnet_groups,
76
+ downsample_padding=downsample_padding,
77
+ cross_attention_dim=cross_attention_dim,
78
+ attn_num_head_channels=attn_num_head_channels,
79
+ dual_cross_attention=dual_cross_attention,
80
+ use_linear_projection=use_linear_projection,
81
+ only_cross_attention=only_cross_attention,
82
+ upcast_attention=upcast_attention,
83
+ resnet_time_scale_shift=resnet_time_scale_shift,
84
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
85
+ unet_use_temporal_attention=unet_use_temporal_attention,
86
+ use_inflated_groupnorm=use_inflated_groupnorm,
87
+ use_motion_module=use_motion_module,
88
+ motion_module_type=motion_module_type,
89
+ motion_module_kwargs=motion_module_kwargs,
90
+ )
91
+ raise ValueError(f"{down_block_type} does not exist.")
92
+
93
+
94
+ def get_up_block(
95
+ up_block_type,
96
+ num_layers,
97
+ in_channels,
98
+ out_channels,
99
+ prev_output_channel,
100
+ temb_channels,
101
+ add_upsample,
102
+ resnet_eps,
103
+ resnet_act_fn,
104
+ attn_num_head_channels,
105
+ resnet_groups=None,
106
+ cross_attention_dim=None,
107
+ dual_cross_attention=False,
108
+ use_linear_projection=False,
109
+ only_cross_attention=False,
110
+ upcast_attention=False,
111
+ resnet_time_scale_shift="default",
112
+ unet_use_cross_frame_attention=None,
113
+ unet_use_temporal_attention=None,
114
+ use_inflated_groupnorm=None,
115
+ use_motion_module=None,
116
+ motion_module_type=None,
117
+ motion_module_kwargs=None,
118
+ ):
119
+ up_block_type = (
120
+ up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
121
+ )
122
+ if up_block_type == "UpBlock3D":
123
+ return UpBlock3D(
124
+ num_layers=num_layers,
125
+ in_channels=in_channels,
126
+ out_channels=out_channels,
127
+ prev_output_channel=prev_output_channel,
128
+ temb_channels=temb_channels,
129
+ add_upsample=add_upsample,
130
+ resnet_eps=resnet_eps,
131
+ resnet_act_fn=resnet_act_fn,
132
+ resnet_groups=resnet_groups,
133
+ resnet_time_scale_shift=resnet_time_scale_shift,
134
+ use_inflated_groupnorm=use_inflated_groupnorm,
135
+ use_motion_module=use_motion_module,
136
+ motion_module_type=motion_module_type,
137
+ motion_module_kwargs=motion_module_kwargs,
138
+ )
139
+ elif up_block_type == "CrossAttnUpBlock3D":
140
+ if cross_attention_dim is None:
141
+ raise ValueError(
142
+ "cross_attention_dim must be specified for CrossAttnUpBlock3D"
143
+ )
144
+ return CrossAttnUpBlock3D(
145
+ num_layers=num_layers,
146
+ in_channels=in_channels,
147
+ out_channels=out_channels,
148
+ prev_output_channel=prev_output_channel,
149
+ temb_channels=temb_channels,
150
+ add_upsample=add_upsample,
151
+ resnet_eps=resnet_eps,
152
+ resnet_act_fn=resnet_act_fn,
153
+ resnet_groups=resnet_groups,
154
+ cross_attention_dim=cross_attention_dim,
155
+ attn_num_head_channels=attn_num_head_channels,
156
+ dual_cross_attention=dual_cross_attention,
157
+ use_linear_projection=use_linear_projection,
158
+ only_cross_attention=only_cross_attention,
159
+ upcast_attention=upcast_attention,
160
+ resnet_time_scale_shift=resnet_time_scale_shift,
161
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
162
+ unet_use_temporal_attention=unet_use_temporal_attention,
163
+ use_inflated_groupnorm=use_inflated_groupnorm,
164
+ use_motion_module=use_motion_module,
165
+ motion_module_type=motion_module_type,
166
+ motion_module_kwargs=motion_module_kwargs,
167
+ )
168
+ raise ValueError(f"{up_block_type} does not exist.")
169
+
170
+
171
+ class UNetMidBlock3DCrossAttn(nn.Module):
172
+ def __init__(
173
+ self,
174
+ in_channels: int,
175
+ temb_channels: int,
176
+ dropout: float = 0.0,
177
+ num_layers: int = 1,
178
+ resnet_eps: float = 1e-6,
179
+ resnet_time_scale_shift: str = "default",
180
+ resnet_act_fn: str = "swish",
181
+ resnet_groups: int = 32,
182
+ resnet_pre_norm: bool = True,
183
+ attn_num_head_channels=1,
184
+ output_scale_factor=1.0,
185
+ cross_attention_dim=1280,
186
+ dual_cross_attention=False,
187
+ use_linear_projection=False,
188
+ upcast_attention=False,
189
+ unet_use_cross_frame_attention=None,
190
+ unet_use_temporal_attention=None,
191
+ use_inflated_groupnorm=None,
192
+ use_motion_module=None,
193
+ motion_module_type=None,
194
+ motion_module_kwargs=None,
195
+ ):
196
+ super().__init__()
197
+
198
+ self.has_cross_attention = True
199
+ self.attn_num_head_channels = attn_num_head_channels
200
+ resnet_groups = (
201
+ resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
202
+ )
203
+
204
+ # there is always at least one resnet
205
+ resnets = [
206
+ ResnetBlock3D(
207
+ in_channels=in_channels,
208
+ out_channels=in_channels,
209
+ temb_channels=temb_channels,
210
+ eps=resnet_eps,
211
+ groups=resnet_groups,
212
+ dropout=dropout,
213
+ time_embedding_norm=resnet_time_scale_shift,
214
+ non_linearity=resnet_act_fn,
215
+ output_scale_factor=output_scale_factor,
216
+ pre_norm=resnet_pre_norm,
217
+ use_inflated_groupnorm=use_inflated_groupnorm,
218
+ )
219
+ ]
220
+ attentions = []
221
+ motion_modules = []
222
+
223
+ for _ in range(num_layers):
224
+ if dual_cross_attention:
225
+ raise NotImplementedError
226
+ attentions.append(
227
+ Transformer3DModel(
228
+ attn_num_head_channels,
229
+ in_channels // attn_num_head_channels,
230
+ in_channels=in_channels,
231
+ num_layers=1,
232
+ cross_attention_dim=cross_attention_dim,
233
+ norm_num_groups=resnet_groups,
234
+ use_linear_projection=use_linear_projection,
235
+ upcast_attention=upcast_attention,
236
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
237
+ unet_use_temporal_attention=unet_use_temporal_attention,
238
+ )
239
+ )
240
+ motion_modules.append(
241
+ get_motion_module(
242
+ in_channels=in_channels,
243
+ motion_module_type=motion_module_type,
244
+ motion_module_kwargs=motion_module_kwargs,
245
+ )
246
+ if use_motion_module
247
+ else None
248
+ )
249
+ resnets.append(
250
+ ResnetBlock3D(
251
+ in_channels=in_channels,
252
+ out_channels=in_channels,
253
+ temb_channels=temb_channels,
254
+ eps=resnet_eps,
255
+ groups=resnet_groups,
256
+ dropout=dropout,
257
+ time_embedding_norm=resnet_time_scale_shift,
258
+ non_linearity=resnet_act_fn,
259
+ output_scale_factor=output_scale_factor,
260
+ pre_norm=resnet_pre_norm,
261
+ use_inflated_groupnorm=use_inflated_groupnorm,
262
+ )
263
+ )
264
+
265
+ self.attentions = nn.ModuleList(attentions)
266
+ self.resnets = nn.ModuleList(resnets)
267
+ self.motion_modules = nn.ModuleList(motion_modules)
268
+
269
+ def forward(
270
+ self,
271
+ hidden_states,
272
+ temb=None,
273
+ encoder_hidden_states=None,
274
+ attention_mask=None,
275
+ ):
276
+ hidden_states = self.resnets[0](hidden_states, temb)
277
+ for attn, resnet, motion_module in zip(
278
+ self.attentions, self.resnets[1:], self.motion_modules
279
+ ):
280
+ hidden_states = attn(
281
+ hidden_states,
282
+ encoder_hidden_states=encoder_hidden_states,
283
+ ).sample
284
+ hidden_states = (
285
+ motion_module(
286
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
287
+ )
288
+ if motion_module is not None
289
+ else hidden_states
290
+ )
291
+ hidden_states = resnet(hidden_states, temb)
292
+
293
+ return hidden_states
294
+
295
+
296
+ class CrossAttnDownBlock3D(nn.Module):
297
+ def __init__(
298
+ self,
299
+ in_channels: int,
300
+ out_channels: int,
301
+ temb_channels: int,
302
+ dropout: float = 0.0,
303
+ num_layers: int = 1,
304
+ resnet_eps: float = 1e-6,
305
+ resnet_time_scale_shift: str = "default",
306
+ resnet_act_fn: str = "swish",
307
+ resnet_groups: int = 32,
308
+ resnet_pre_norm: bool = True,
309
+ attn_num_head_channels=1,
310
+ cross_attention_dim=1280,
311
+ output_scale_factor=1.0,
312
+ downsample_padding=1,
313
+ add_downsample=True,
314
+ dual_cross_attention=False,
315
+ use_linear_projection=False,
316
+ only_cross_attention=False,
317
+ upcast_attention=False,
318
+ unet_use_cross_frame_attention=None,
319
+ unet_use_temporal_attention=None,
320
+ use_inflated_groupnorm=None,
321
+ use_motion_module=None,
322
+ motion_module_type=None,
323
+ motion_module_kwargs=None,
324
+ ):
325
+ super().__init__()
326
+ resnets = []
327
+ attentions = []
328
+ motion_modules = []
329
+
330
+ self.has_cross_attention = True
331
+ self.attn_num_head_channels = attn_num_head_channels
332
+
333
+ for i in range(num_layers):
334
+ in_channels = in_channels if i == 0 else out_channels
335
+ resnets.append(
336
+ ResnetBlock3D(
337
+ in_channels=in_channels,
338
+ out_channels=out_channels,
339
+ temb_channels=temb_channels,
340
+ eps=resnet_eps,
341
+ groups=resnet_groups,
342
+ dropout=dropout,
343
+ time_embedding_norm=resnet_time_scale_shift,
344
+ non_linearity=resnet_act_fn,
345
+ output_scale_factor=output_scale_factor,
346
+ pre_norm=resnet_pre_norm,
347
+ use_inflated_groupnorm=use_inflated_groupnorm,
348
+ )
349
+ )
350
+ if dual_cross_attention:
351
+ raise NotImplementedError
352
+ attentions.append(
353
+ Transformer3DModel(
354
+ attn_num_head_channels,
355
+ out_channels // attn_num_head_channels,
356
+ in_channels=out_channels,
357
+ num_layers=1,
358
+ cross_attention_dim=cross_attention_dim,
359
+ norm_num_groups=resnet_groups,
360
+ use_linear_projection=use_linear_projection,
361
+ only_cross_attention=only_cross_attention,
362
+ upcast_attention=upcast_attention,
363
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
364
+ unet_use_temporal_attention=unet_use_temporal_attention,
365
+ )
366
+ )
367
+ motion_modules.append(
368
+ get_motion_module(
369
+ in_channels=out_channels,
370
+ motion_module_type=motion_module_type,
371
+ motion_module_kwargs=motion_module_kwargs,
372
+ )
373
+ if use_motion_module
374
+ else None
375
+ )
376
+
377
+ self.attentions = nn.ModuleList(attentions)
378
+ self.resnets = nn.ModuleList(resnets)
379
+ self.motion_modules = nn.ModuleList(motion_modules)
380
+
381
+ if add_downsample:
382
+ self.downsamplers = nn.ModuleList(
383
+ [
384
+ Downsample3D(
385
+ out_channels,
386
+ use_conv=True,
387
+ out_channels=out_channels,
388
+ padding=downsample_padding,
389
+ name="op",
390
+ )
391
+ ]
392
+ )
393
+ else:
394
+ self.downsamplers = None
395
+
396
+ self.gradient_checkpointing = False
397
+
398
+ def forward(
399
+ self,
400
+ hidden_states,
401
+ temb=None,
402
+ encoder_hidden_states=None,
403
+ attention_mask=None,
404
+ ):
405
+ output_states = ()
406
+
407
+ for i, (resnet, attn, motion_module) in enumerate(
408
+ zip(self.resnets, self.attentions, self.motion_modules)
409
+ ):
410
+ # self.gradient_checkpointing = False
411
+ if self.training and self.gradient_checkpointing:
412
+
413
+ def create_custom_forward(module, return_dict=None):
414
+ def custom_forward(*inputs):
415
+ if return_dict is not None:
416
+ return module(*inputs, return_dict=return_dict)
417
+ else:
418
+ return module(*inputs)
419
+
420
+ return custom_forward
421
+
422
+ hidden_states = torch.utils.checkpoint.checkpoint(
423
+ create_custom_forward(resnet), hidden_states, temb
424
+ )
425
+ hidden_states = torch.utils.checkpoint.checkpoint(
426
+ create_custom_forward(attn, return_dict=False),
427
+ hidden_states,
428
+ encoder_hidden_states,
429
+ )[0]
430
+
431
+ # add motion module
432
+ if motion_module is not None:
433
+ hidden_states = torch.utils.checkpoint.checkpoint(
434
+ create_custom_forward(motion_module),
435
+ hidden_states.requires_grad_(),
436
+ temb,
437
+ encoder_hidden_states,
438
+ )
439
+
440
+ # # add motion module
441
+ # hidden_states = (
442
+ # motion_module(
443
+ # hidden_states, temb, encoder_hidden_states=encoder_hidden_states
444
+ # )
445
+ # if motion_module is not None
446
+ # else hidden_states
447
+ # )
448
+
449
+ else:
450
+ hidden_states = resnet(hidden_states, temb)
451
+ hidden_states = attn(
452
+ hidden_states,
453
+ encoder_hidden_states=encoder_hidden_states,
454
+ ).sample
455
+
456
+ # add motion module
457
+ hidden_states = (
458
+ motion_module(
459
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
460
+ )
461
+ if motion_module is not None
462
+ else hidden_states
463
+ )
464
+
465
+ output_states += (hidden_states,)
466
+
467
+ if self.downsamplers is not None:
468
+ for downsampler in self.downsamplers:
469
+ hidden_states = downsampler(hidden_states)
470
+
471
+ output_states += (hidden_states,)
472
+
473
+ return hidden_states, output_states
474
+
475
+
476
+ class DownBlock3D(nn.Module):
477
+ def __init__(
478
+ self,
479
+ in_channels: int,
480
+ out_channels: int,
481
+ temb_channels: int,
482
+ dropout: float = 0.0,
483
+ num_layers: int = 1,
484
+ resnet_eps: float = 1e-6,
485
+ resnet_time_scale_shift: str = "default",
486
+ resnet_act_fn: str = "swish",
487
+ resnet_groups: int = 32,
488
+ resnet_pre_norm: bool = True,
489
+ output_scale_factor=1.0,
490
+ add_downsample=True,
491
+ downsample_padding=1,
492
+ use_inflated_groupnorm=None,
493
+ use_motion_module=None,
494
+ motion_module_type=None,
495
+ motion_module_kwargs=None,
496
+ ):
497
+ super().__init__()
498
+ resnets = []
499
+ motion_modules = []
500
+
501
+ # use_motion_module = False
502
+ for i in range(num_layers):
503
+ in_channels = in_channels if i == 0 else out_channels
504
+ resnets.append(
505
+ ResnetBlock3D(
506
+ in_channels=in_channels,
507
+ out_channels=out_channels,
508
+ temb_channels=temb_channels,
509
+ eps=resnet_eps,
510
+ groups=resnet_groups,
511
+ dropout=dropout,
512
+ time_embedding_norm=resnet_time_scale_shift,
513
+ non_linearity=resnet_act_fn,
514
+ output_scale_factor=output_scale_factor,
515
+ pre_norm=resnet_pre_norm,
516
+ use_inflated_groupnorm=use_inflated_groupnorm,
517
+ )
518
+ )
519
+ motion_modules.append(
520
+ get_motion_module(
521
+ in_channels=out_channels,
522
+ motion_module_type=motion_module_type,
523
+ motion_module_kwargs=motion_module_kwargs,
524
+ )
525
+ if use_motion_module
526
+ else None
527
+ )
528
+
529
+ self.resnets = nn.ModuleList(resnets)
530
+ self.motion_modules = nn.ModuleList(motion_modules)
531
+
532
+ if add_downsample:
533
+ self.downsamplers = nn.ModuleList(
534
+ [
535
+ Downsample3D(
536
+ out_channels,
537
+ use_conv=True,
538
+ out_channels=out_channels,
539
+ padding=downsample_padding,
540
+ name="op",
541
+ )
542
+ ]
543
+ )
544
+ else:
545
+ self.downsamplers = None
546
+
547
+ self.gradient_checkpointing = False
548
+
549
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
550
+ output_states = ()
551
+
552
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
553
+ # print(f"DownBlock3D {self.gradient_checkpointing = }")
554
+ if self.training and self.gradient_checkpointing:
555
+
556
+ def create_custom_forward(module):
557
+ def custom_forward(*inputs):
558
+ return module(*inputs)
559
+
560
+ return custom_forward
561
+
562
+ hidden_states = torch.utils.checkpoint.checkpoint(
563
+ create_custom_forward(resnet), hidden_states, temb
564
+ )
565
+ if motion_module is not None:
566
+ hidden_states = torch.utils.checkpoint.checkpoint(
567
+ create_custom_forward(motion_module),
568
+ hidden_states.requires_grad_(),
569
+ temb,
570
+ encoder_hidden_states,
571
+ )
572
+ else:
573
+ hidden_states = resnet(hidden_states, temb)
574
+
575
+ # add motion module
576
+ hidden_states = (
577
+ motion_module(
578
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
579
+ )
580
+ if motion_module is not None
581
+ else hidden_states
582
+ )
583
+
584
+ output_states += (hidden_states,)
585
+
586
+ if self.downsamplers is not None:
587
+ for downsampler in self.downsamplers:
588
+ hidden_states = downsampler(hidden_states)
589
+
590
+ output_states += (hidden_states,)
591
+
592
+ return hidden_states, output_states
593
+
594
+
595
+ class CrossAttnUpBlock3D(nn.Module):
596
+ def __init__(
597
+ self,
598
+ in_channels: int,
599
+ out_channels: int,
600
+ prev_output_channel: int,
601
+ temb_channels: int,
602
+ dropout: float = 0.0,
603
+ num_layers: int = 1,
604
+ resnet_eps: float = 1e-6,
605
+ resnet_time_scale_shift: str = "default",
606
+ resnet_act_fn: str = "swish",
607
+ resnet_groups: int = 32,
608
+ resnet_pre_norm: bool = True,
609
+ attn_num_head_channels=1,
610
+ cross_attention_dim=1280,
611
+ output_scale_factor=1.0,
612
+ add_upsample=True,
613
+ dual_cross_attention=False,
614
+ use_linear_projection=False,
615
+ only_cross_attention=False,
616
+ upcast_attention=False,
617
+ unet_use_cross_frame_attention=None,
618
+ unet_use_temporal_attention=None,
619
+ use_motion_module=None,
620
+ use_inflated_groupnorm=None,
621
+ motion_module_type=None,
622
+ motion_module_kwargs=None,
623
+ ):
624
+ super().__init__()
625
+ resnets = []
626
+ attentions = []
627
+ motion_modules = []
628
+
629
+ self.has_cross_attention = True
630
+ self.attn_num_head_channels = attn_num_head_channels
631
+
632
+ for i in range(num_layers):
633
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
634
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
635
+
636
+ resnets.append(
637
+ ResnetBlock3D(
638
+ in_channels=resnet_in_channels + res_skip_channels,
639
+ out_channels=out_channels,
640
+ temb_channels=temb_channels,
641
+ eps=resnet_eps,
642
+ groups=resnet_groups,
643
+ dropout=dropout,
644
+ time_embedding_norm=resnet_time_scale_shift,
645
+ non_linearity=resnet_act_fn,
646
+ output_scale_factor=output_scale_factor,
647
+ pre_norm=resnet_pre_norm,
648
+ use_inflated_groupnorm=use_inflated_groupnorm,
649
+ )
650
+ )
651
+ if dual_cross_attention:
652
+ raise NotImplementedError
653
+ attentions.append(
654
+ Transformer3DModel(
655
+ attn_num_head_channels,
656
+ out_channels // attn_num_head_channels,
657
+ in_channels=out_channels,
658
+ num_layers=1,
659
+ cross_attention_dim=cross_attention_dim,
660
+ norm_num_groups=resnet_groups,
661
+ use_linear_projection=use_linear_projection,
662
+ only_cross_attention=only_cross_attention,
663
+ upcast_attention=upcast_attention,
664
+ unet_use_cross_frame_attention=unet_use_cross_frame_attention,
665
+ unet_use_temporal_attention=unet_use_temporal_attention,
666
+ )
667
+ )
668
+ motion_modules.append(
669
+ get_motion_module(
670
+ in_channels=out_channels,
671
+ motion_module_type=motion_module_type,
672
+ motion_module_kwargs=motion_module_kwargs,
673
+ )
674
+ if use_motion_module
675
+ else None
676
+ )
677
+
678
+ self.attentions = nn.ModuleList(attentions)
679
+ self.resnets = nn.ModuleList(resnets)
680
+ self.motion_modules = nn.ModuleList(motion_modules)
681
+
682
+ if add_upsample:
683
+ self.upsamplers = nn.ModuleList(
684
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
685
+ )
686
+ else:
687
+ self.upsamplers = None
688
+
689
+ self.gradient_checkpointing = False
690
+
691
+ def forward(
692
+ self,
693
+ hidden_states,
694
+ res_hidden_states_tuple,
695
+ temb=None,
696
+ encoder_hidden_states=None,
697
+ upsample_size=None,
698
+ attention_mask=None,
699
+ ):
700
+ for i, (resnet, attn, motion_module) in enumerate(
701
+ zip(self.resnets, self.attentions, self.motion_modules)
702
+ ):
703
+ # pop res hidden states
704
+ res_hidden_states = res_hidden_states_tuple[-1]
705
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
706
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
707
+
708
+ if self.training and self.gradient_checkpointing:
709
+
710
+ def create_custom_forward(module, return_dict=None):
711
+ def custom_forward(*inputs):
712
+ if return_dict is not None:
713
+ return module(*inputs, return_dict=return_dict)
714
+ else:
715
+ return module(*inputs)
716
+
717
+ return custom_forward
718
+
719
+ hidden_states = torch.utils.checkpoint.checkpoint(
720
+ create_custom_forward(resnet), hidden_states, temb
721
+ )
722
+ hidden_states = attn(
723
+ hidden_states,
724
+ encoder_hidden_states=encoder_hidden_states,
725
+ ).sample
726
+ if motion_module is not None:
727
+ hidden_states = torch.utils.checkpoint.checkpoint(
728
+ create_custom_forward(motion_module),
729
+ hidden_states.requires_grad_(),
730
+ temb,
731
+ encoder_hidden_states,
732
+ )
733
+
734
+ else:
735
+ hidden_states = resnet(hidden_states, temb)
736
+ hidden_states = attn(
737
+ hidden_states,
738
+ encoder_hidden_states=encoder_hidden_states,
739
+ ).sample
740
+
741
+ # add motion module
742
+ hidden_states = (
743
+ motion_module(
744
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
745
+ )
746
+ if motion_module is not None
747
+ else hidden_states
748
+ )
749
+
750
+ if self.upsamplers is not None:
751
+ for upsampler in self.upsamplers:
752
+ hidden_states = upsampler(hidden_states, upsample_size)
753
+
754
+ return hidden_states
755
+
756
+
757
+ class UpBlock3D(nn.Module):
758
+ def __init__(
759
+ self,
760
+ in_channels: int,
761
+ prev_output_channel: int,
762
+ out_channels: int,
763
+ temb_channels: int,
764
+ dropout: float = 0.0,
765
+ num_layers: int = 1,
766
+ resnet_eps: float = 1e-6,
767
+ resnet_time_scale_shift: str = "default",
768
+ resnet_act_fn: str = "swish",
769
+ resnet_groups: int = 32,
770
+ resnet_pre_norm: bool = True,
771
+ output_scale_factor=1.0,
772
+ add_upsample=True,
773
+ use_inflated_groupnorm=None,
774
+ use_motion_module=None,
775
+ motion_module_type=None,
776
+ motion_module_kwargs=None,
777
+ ):
778
+ super().__init__()
779
+ resnets = []
780
+ motion_modules = []
781
+
782
+ # use_motion_module = False
783
+ for i in range(num_layers):
784
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
785
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
786
+
787
+ resnets.append(
788
+ ResnetBlock3D(
789
+ in_channels=resnet_in_channels + res_skip_channels,
790
+ out_channels=out_channels,
791
+ temb_channels=temb_channels,
792
+ eps=resnet_eps,
793
+ groups=resnet_groups,
794
+ dropout=dropout,
795
+ time_embedding_norm=resnet_time_scale_shift,
796
+ non_linearity=resnet_act_fn,
797
+ output_scale_factor=output_scale_factor,
798
+ pre_norm=resnet_pre_norm,
799
+ use_inflated_groupnorm=use_inflated_groupnorm,
800
+ )
801
+ )
802
+ motion_modules.append(
803
+ get_motion_module(
804
+ in_channels=out_channels,
805
+ motion_module_type=motion_module_type,
806
+ motion_module_kwargs=motion_module_kwargs,
807
+ )
808
+ if use_motion_module
809
+ else None
810
+ )
811
+
812
+ self.resnets = nn.ModuleList(resnets)
813
+ self.motion_modules = nn.ModuleList(motion_modules)
814
+
815
+ if add_upsample:
816
+ self.upsamplers = nn.ModuleList(
817
+ [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
818
+ )
819
+ else:
820
+ self.upsamplers = None
821
+
822
+ self.gradient_checkpointing = False
823
+
824
+ def forward(
825
+ self,
826
+ hidden_states,
827
+ res_hidden_states_tuple,
828
+ temb=None,
829
+ upsample_size=None,
830
+ encoder_hidden_states=None,
831
+ ):
832
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
833
+ # pop res hidden states
834
+ res_hidden_states = res_hidden_states_tuple[-1]
835
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
836
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
837
+
838
+ # print(f"UpBlock3D {self.gradient_checkpointing = }")
839
+ if self.training and self.gradient_checkpointing:
840
+
841
+ def create_custom_forward(module):
842
+ def custom_forward(*inputs):
843
+ return module(*inputs)
844
+
845
+ return custom_forward
846
+
847
+ hidden_states = torch.utils.checkpoint.checkpoint(
848
+ create_custom_forward(resnet), hidden_states, temb
849
+ )
850
+ if motion_module is not None:
851
+ hidden_states = torch.utils.checkpoint.checkpoint(
852
+ create_custom_forward(motion_module),
853
+ hidden_states.requires_grad_(),
854
+ temb,
855
+ encoder_hidden_states,
856
+ )
857
+ else:
858
+ hidden_states = resnet(hidden_states, temb)
859
+ hidden_states = (
860
+ motion_module(
861
+ hidden_states, temb, encoder_hidden_states=encoder_hidden_states
862
+ )
863
+ if motion_module is not None
864
+ else hidden_states
865
+ )
866
+
867
+ if self.upsamplers is not None:
868
+ for upsampler in self.upsamplers:
869
+ hidden_states = upsampler(hidden_states, upsample_size)
870
+
871
+ return hidden_states
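The down/up blocks above all share the same gradient-checkpointing idiom: the module is wrapped in a small closure so `torch.utils.checkpoint` recomputes its forward pass during the backward pass instead of caching activations. A self-contained sketch of that pattern with a toy layer (the module and tensor sizes are illustrative, not taken from the blocks above):

    import torch
    from torch import nn
    from torch.utils.checkpoint import checkpoint

    def create_custom_forward(module):
        def custom_forward(*inputs):
            return module(*inputs)
        return custom_forward

    layer = nn.Linear(64, 64)                    # stand-in for a ResnetBlock3D or motion module
    x = torch.randn(8, 64, requires_grad=True)
    out = checkpoint(create_custom_forward(layer), x, use_reentrant=False)
    out.sum().backward()                         # the wrapped forward is recomputed here, not cached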
musepose/pipelines/__init__.py ADDED
File without changes
musepose/pipelines/context.py ADDED
@@ -0,0 +1,76 @@
1
+ # TODO: Adapted from cli
2
+ from typing import Callable, List, Optional
3
+
4
+ import numpy as np
5
+
6
+
7
+ def ordered_halving(val):
8
+ bin_str = f"{val:064b}"
9
+ bin_flip = bin_str[::-1]
10
+ as_int = int(bin_flip, 2)
11
+
12
+ return as_int / (1 << 64)
13
+
14
+
15
+ def uniform(
16
+ step: int = ...,
17
+ num_steps: Optional[int] = None,
18
+ num_frames: int = ...,
19
+ context_size: Optional[int] = None,
20
+ context_stride: int = 3,
21
+ context_overlap: int = 4,
22
+ closed_loop: bool = False,
23
+ ):
24
+ if num_frames <= context_size:
25
+ yield list(range(num_frames))
26
+ return
27
+
28
+ context_stride = min(
29
+ context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1
30
+ )
31
+
32
+ for context_step in 1 << np.arange(context_stride):
33
+ pad = int(round(num_frames * ordered_halving(step)))
34
+ for j in range(
35
+ int(ordered_halving(step) * context_step) + pad,
36
+ num_frames + pad + (0 if closed_loop else -context_overlap),
37
+ (context_size * context_step - context_overlap),
38
+ ):
39
+ yield [
40
+ e % num_frames
41
+ for e in range(j, j + context_size * context_step, context_step)
42
+ ]
43
+
44
+
45
+ def get_context_scheduler(name: str) -> Callable:
46
+ if name == "uniform":
47
+ return uniform
48
+ else:
49
+ raise ValueError(f"Unknown context_overlap policy {name}")
50
+
51
+
52
+ def get_total_steps(
53
+ scheduler,
54
+ timesteps: List[int],
55
+ num_steps: Optional[int] = None,
56
+ num_frames: int = ...,
57
+ context_size: Optional[int] = None,
58
+ context_stride: int = 3,
59
+ context_overlap: int = 4,
60
+ closed_loop: bool = True,
61
+ ):
62
+ return sum(
63
+ len(
64
+ list(
65
+ scheduler(
66
+ i,
67
+ num_steps,
68
+ num_frames,
69
+ context_size,
70
+ context_stride,
71
+ context_overlap,
72
+ )
73
+ )
74
+ )
75
+ for i in range(len(timesteps))
76
+ )
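A small usage sketch for the `uniform` context scheduler above, showing how it yields overlapping frame-index windows for chunked denoising of long videos; the frame count, window size, and overlap below are illustrative values, not the pipeline defaults:

    from musepose.pipelines.context import get_context_scheduler

    scheduler = get_context_scheduler("uniform")
    windows = list(
        scheduler(
            step=0,
            num_steps=20,
            num_frames=48,
            context_size=16,
            context_stride=1,
            context_overlap=4,
            closed_loop=False,
        )
    )
    # e.g. windows[0] == [0, 1, ..., 15] and windows[1] == [12, 13, ..., 27]:
    # each window holds context_size frame indices and overlaps its neighbor by context_overlap frames.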
musepose/pipelines/pipeline_pose2img.py ADDED
@@ -0,0 +1,360 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (
10
+ DDIMScheduler,
11
+ DPMSolverMultistepScheduler,
12
+ EulerAncestralDiscreteScheduler,
13
+ EulerDiscreteScheduler,
14
+ LMSDiscreteScheduler,
15
+ PNDMScheduler,
16
+ )
17
+ from diffusers.utils import BaseOutput, is_accelerate_available
18
+ from diffusers.utils.torch_utils import randn_tensor
19
+ from einops import rearrange
20
+ from tqdm import tqdm
21
+ from transformers import CLIPImageProcessor
22
+
23
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
24
+
25
+
26
+ @dataclass
27
+ class Pose2ImagePipelineOutput(BaseOutput):
28
+ images: Union[torch.Tensor, np.ndarray]
29
+
30
+
31
+ class Pose2ImagePipeline(DiffusionPipeline):
32
+ _optional_components = []
33
+
34
+ def __init__(
35
+ self,
36
+ vae,
37
+ image_encoder,
38
+ reference_unet,
39
+ denoising_unet,
40
+ pose_guider,
41
+ scheduler: Union[
42
+ DDIMScheduler,
43
+ PNDMScheduler,
44
+ LMSDiscreteScheduler,
45
+ EulerDiscreteScheduler,
46
+ EulerAncestralDiscreteScheduler,
47
+ DPMSolverMultistepScheduler,
48
+ ],
49
+ ):
50
+ super().__init__()
51
+
52
+ self.register_modules(
53
+ vae=vae,
54
+ image_encoder=image_encoder,
55
+ reference_unet=reference_unet,
56
+ denoising_unet=denoising_unet,
57
+ pose_guider=pose_guider,
58
+ scheduler=scheduler,
59
+ )
60
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
61
+ self.clip_image_processor = CLIPImageProcessor()
62
+ self.ref_image_processor = VaeImageProcessor(
63
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
64
+ )
65
+ self.cond_image_processor = VaeImageProcessor(
66
+ vae_scale_factor=self.vae_scale_factor,
67
+ do_convert_rgb=True,
68
+ do_normalize=False,
69
+ )
70
+
71
+ def enable_vae_slicing(self):
72
+ self.vae.enable_slicing()
73
+
74
+ def disable_vae_slicing(self):
75
+ self.vae.disable_slicing()
76
+
77
+ def enable_sequential_cpu_offload(self, gpu_id=0):
78
+ if is_accelerate_available():
79
+ from accelerate import cpu_offload
80
+ else:
81
+ raise ImportError("Please install accelerate via `pip install accelerate`")
82
+
83
+ device = torch.device(f"cuda:{gpu_id}")
84
+
85
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
86
+ if cpu_offloaded_model is not None:
87
+ cpu_offload(cpu_offloaded_model, device)
88
+
89
+ @property
90
+ def _execution_device(self):
91
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
92
+ return self.device
93
+ for module in self.unet.modules():
94
+ if (
95
+ hasattr(module, "_hf_hook")
96
+ and hasattr(module._hf_hook, "execution_device")
97
+ and module._hf_hook.execution_device is not None
98
+ ):
99
+ return torch.device(module._hf_hook.execution_device)
100
+ return self.device
101
+
102
+ def decode_latents(self, latents):
103
+ video_length = latents.shape[2]
104
+ latents = 1 / 0.18215 * latents
105
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
106
+ # video = self.vae.decode(latents).sample
107
+ video = []
108
+ for frame_idx in tqdm(range(latents.shape[0])):
109
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
110
+ video = torch.cat(video)
111
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
112
+ video = (video / 2 + 0.5).clamp(0, 1)
113
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
114
+ video = video.cpu().float().numpy()
115
+ return video
116
+
117
+ def prepare_extra_step_kwargs(self, generator, eta):
118
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
119
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
120
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
121
+ # and should be between [0, 1]
122
+
123
+ accepts_eta = "eta" in set(
124
+ inspect.signature(self.scheduler.step).parameters.keys()
125
+ )
126
+ extra_step_kwargs = {}
127
+ if accepts_eta:
128
+ extra_step_kwargs["eta"] = eta
129
+
130
+ # check if the scheduler accepts generator
131
+ accepts_generator = "generator" in set(
132
+ inspect.signature(self.scheduler.step).parameters.keys()
133
+ )
134
+ if accepts_generator:
135
+ extra_step_kwargs["generator"] = generator
136
+ return extra_step_kwargs
137
+
138
+ def prepare_latents(
139
+ self,
140
+ batch_size,
141
+ num_channels_latents,
142
+ width,
143
+ height,
144
+ dtype,
145
+ device,
146
+ generator,
147
+ latents=None,
148
+ ):
149
+ shape = (
150
+ batch_size,
151
+ num_channels_latents,
152
+ height // self.vae_scale_factor,
153
+ width // self.vae_scale_factor,
154
+ )
155
+ if isinstance(generator, list) and len(generator) != batch_size:
156
+ raise ValueError(
157
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
158
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
159
+ )
160
+
161
+ if latents is None:
162
+ latents = randn_tensor(
163
+ shape, generator=generator, device=device, dtype=dtype
164
+ )
165
+ else:
166
+ latents = latents.to(device)
167
+
168
+ # scale the initial noise by the standard deviation required by the scheduler
169
+ latents = latents * self.scheduler.init_noise_sigma
170
+ return latents
171
+
172
+ def prepare_condition(
173
+ self,
174
+ cond_image,
175
+ width,
176
+ height,
177
+ device,
178
+ dtype,
179
+ do_classififer_free_guidance=False,
180
+ ):
181
+ image = self.cond_image_processor.preprocess(
182
+ cond_image, height=height, width=width
183
+ ).to(dtype=torch.float32)
184
+
185
+ image = image.to(device=device, dtype=dtype)
186
+
187
+ if do_classififer_free_guidance:
188
+ image = torch.cat([image] * 2)
189
+
190
+ return image
191
+
192
+ @torch.no_grad()
193
+ def __call__(
194
+ self,
195
+ ref_image,
196
+ pose_image,
197
+ width,
198
+ height,
199
+ num_inference_steps,
200
+ guidance_scale,
201
+ num_images_per_prompt=1,
202
+ eta: float = 0.0,
203
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
204
+ output_type: Optional[str] = "tensor",
205
+ return_dict: bool = True,
206
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
207
+ callback_steps: Optional[int] = 1,
208
+ **kwargs,
209
+ ):
210
+ # Default height and width to unet
211
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
212
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
213
+
214
+ device = self._execution_device
215
+
216
+ do_classifier_free_guidance = guidance_scale > 1.0
217
+
218
+ # Prepare timesteps
219
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
220
+ timesteps = self.scheduler.timesteps
221
+
222
+ batch_size = 1
223
+
224
+ # Prepare clip image embeds
225
+ clip_image = self.clip_image_processor.preprocess(
226
+ ref_image.resize((224, 224)), return_tensors="pt"
227
+ ).pixel_values
228
+ clip_image_embeds = self.image_encoder(
229
+ clip_image.to(device, dtype=self.image_encoder.dtype)
230
+ ).image_embeds
231
+ image_prompt_embeds = clip_image_embeds.unsqueeze(1)
232
+ uncond_image_prompt_embeds = torch.zeros_like(image_prompt_embeds)
233
+
234
+ if do_classifier_free_guidance:
235
+ image_prompt_embeds = torch.cat(
236
+ [uncond_image_prompt_embeds, image_prompt_embeds], dim=0
237
+ )
238
+
239
+ reference_control_writer = ReferenceAttentionControl(
240
+ self.reference_unet,
241
+ do_classifier_free_guidance=do_classifier_free_guidance,
242
+ mode="write",
243
+ batch_size=batch_size,
244
+ fusion_blocks="full",
245
+ )
246
+ reference_control_reader = ReferenceAttentionControl(
247
+ self.denoising_unet,
248
+ do_classifier_free_guidance=do_classifier_free_guidance,
249
+ mode="read",
250
+ batch_size=batch_size,
251
+ fusion_blocks="full",
252
+ )
253
+
254
+ num_channels_latents = self.denoising_unet.in_channels
255
+ latents = self.prepare_latents(
256
+ batch_size * num_images_per_prompt,
257
+ num_channels_latents,
258
+ width,
259
+ height,
260
+ clip_image_embeds.dtype,
261
+ device,
262
+ generator,
263
+ )
264
+ latents = latents.unsqueeze(2) # (bs, c, 1, h', w')
265
+ latents_dtype = latents.dtype
266
+
267
+ # Prepare extra step kwargs.
268
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
269
+
270
+ # Prepare ref image latents
271
+ ref_image_tensor = self.ref_image_processor.preprocess(
272
+ ref_image, height=height, width=width
273
+ ) # (bs, c, height, width)
274
+ ref_image_tensor = ref_image_tensor.to(
275
+ dtype=self.vae.dtype, device=self.vae.device
276
+ )
277
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
278
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
279
+
280
+ # Prepare pose condition image
281
+ pose_cond_tensor = self.cond_image_processor.preprocess(
282
+ pose_image, height=height, width=width
283
+ )
284
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
285
+ pose_cond_tensor = pose_cond_tensor.to(
286
+ device=device, dtype=self.pose_guider.dtype
287
+ )
288
+ pose_fea = self.pose_guider(pose_cond_tensor)
289
+ pose_fea = (
290
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
291
+ )
292
+
293
+ # denoising loop
294
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
295
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
296
+ for i, t in enumerate(timesteps):
297
+ # 1. Forward reference image
298
+ if i == 0:
299
+ self.reference_unet(
300
+ ref_image_latents.repeat(
301
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
302
+ ),
303
+ torch.zeros_like(t),
304
+ encoder_hidden_states=image_prompt_embeds,
305
+ return_dict=False,
306
+ )
307
+
308
+ # 2. Update reference unet features into the denoising net
309
+ reference_control_reader.update(reference_control_writer)
310
+
311
+ # 3.1 expand the latents if we are doing classifier free guidance
312
+ latent_model_input = (
313
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
314
+ )
315
+ latent_model_input = self.scheduler.scale_model_input(
316
+ latent_model_input, t
317
+ )
318
+
319
+ noise_pred = self.denoising_unet(
320
+ latent_model_input,
321
+ t,
322
+ encoder_hidden_states=image_prompt_embeds,
323
+ pose_cond_fea=pose_fea,
324
+ return_dict=False,
325
+ )[0]
326
+
327
+ # perform guidance
328
+ if do_classifier_free_guidance:
329
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
330
+ noise_pred = noise_pred_uncond + guidance_scale * (
331
+ noise_pred_text - noise_pred_uncond
332
+ )
333
+
334
+ # compute the previous noisy sample x_t -> x_t-1
335
+ latents = self.scheduler.step(
336
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
337
+ )[0]
338
+
339
+ # call the callback, if provided
340
+ if i == len(timesteps) - 1 or (
341
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
342
+ ):
343
+ progress_bar.update()
344
+ if callback is not None and i % callback_steps == 0:
345
+ step_idx = i // getattr(self.scheduler, "order", 1)
346
+ callback(step_idx, t, latents)
347
+ reference_control_reader.clear()
348
+ reference_control_writer.clear()
349
+
350
+ # Post-processing
351
+ image = self.decode_latents(latents) # (b, c, 1, h, w)
352
+
353
+ # Convert to tensor
354
+ if output_type == "tensor":
355
+ image = torch.from_numpy(image)
356
+
357
+ if not return_dict:
358
+ return image
359
+
360
+ return Pose2ImagePipelineOutput(images=image)
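A minimal usage sketch of the single-frame pipeline above, assuming `pipe` is an already-constructed Pose2ImagePipeline on CUDA (the actual module loading lives in test_stage_1.py) and the pose-frame path is a placeholder:

import torch
from PIL import Image

# pipe: a Pose2ImagePipeline instance already built and moved to CUDA (see test_stage_1.py)
ref_image = Image.open("assets/images/ref.png").convert("RGB")
pose_image = Image.open("aligned_pose_frame.png").convert("RGB")  # placeholder path

result = pipe(
    ref_image,
    pose_image,
    width=768,
    height=768,
    num_inference_steps=20,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
)
image = result.images  # (b, c, 1, h, w) tensor with values in [0, 1]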
musepose/pipelines/pipeline_pose2vid.py ADDED
@@ -0,0 +1,458 @@
1
+ import inspect
2
+ from dataclasses import dataclass
3
+ from typing import Callable, List, Optional, Union
4
+
5
+ import numpy as np
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+ from diffusers.image_processor import VaeImageProcessor
9
+ from diffusers.schedulers import (DDIMScheduler, DPMSolverMultistepScheduler,
10
+ EulerAncestralDiscreteScheduler,
11
+ EulerDiscreteScheduler, LMSDiscreteScheduler,
12
+ PNDMScheduler)
13
+ from diffusers.utils import BaseOutput, is_accelerate_available
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+ from einops import rearrange
16
+ from tqdm import tqdm
17
+ from transformers import CLIPImageProcessor
18
+
19
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
20
+
21
+
22
+ @dataclass
23
+ class Pose2VideoPipelineOutput(BaseOutput):
24
+ videos: Union[torch.Tensor, np.ndarray]
25
+
26
+
27
+ class Pose2VideoPipeline(DiffusionPipeline):
28
+ _optional_components = []
29
+
30
+ def __init__(
31
+ self,
32
+ vae,
33
+ image_encoder,
34
+ reference_unet,
35
+ denoising_unet,
36
+ pose_guider,
37
+ scheduler: Union[
38
+ DDIMScheduler,
39
+ PNDMScheduler,
40
+ LMSDiscreteScheduler,
41
+ EulerDiscreteScheduler,
42
+ EulerAncestralDiscreteScheduler,
43
+ DPMSolverMultistepScheduler,
44
+ ],
45
+ image_proj_model=None,
46
+ tokenizer=None,
47
+ text_encoder=None,
48
+ ):
49
+ super().__init__()
50
+
51
+ self.register_modules(
52
+ vae=vae,
53
+ image_encoder=image_encoder,
54
+ reference_unet=reference_unet,
55
+ denoising_unet=denoising_unet,
56
+ pose_guider=pose_guider,
57
+ scheduler=scheduler,
58
+ image_proj_model=image_proj_model,
59
+ tokenizer=tokenizer,
60
+ text_encoder=text_encoder,
61
+ )
62
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
63
+ self.clip_image_processor = CLIPImageProcessor()
64
+ self.ref_image_processor = VaeImageProcessor(
65
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
66
+ )
67
+ self.cond_image_processor = VaeImageProcessor(
68
+ vae_scale_factor=self.vae_scale_factor,
69
+ do_convert_rgb=True,
70
+ do_normalize=False,
71
+ )
72
+
73
+ def enable_vae_slicing(self):
74
+ self.vae.enable_slicing()
75
+
76
+ def disable_vae_slicing(self):
77
+ self.vae.disable_slicing()
78
+
79
+ def enable_sequential_cpu_offload(self, gpu_id=0):
80
+ if is_accelerate_available():
81
+ from accelerate import cpu_offload
82
+ else:
83
+ raise ImportError("Please install accelerate via `pip install accelerate`")
84
+
85
+ device = torch.device(f"cuda:{gpu_id}")
86
+
87
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
88
+ if cpu_offloaded_model is not None:
89
+ cpu_offload(cpu_offloaded_model, device)
90
+
91
+ @property
92
+ def _execution_device(self):
93
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
94
+ return self.device
95
+ for module in self.unet.modules():
96
+ if (
97
+ hasattr(module, "_hf_hook")
98
+ and hasattr(module._hf_hook, "execution_device")
99
+ and module._hf_hook.execution_device is not None
100
+ ):
101
+ return torch.device(module._hf_hook.execution_device)
102
+ return self.device
103
+
104
+ def decode_latents(self, latents):
105
+ video_length = latents.shape[2]
106
+ latents = 1 / 0.18215 * latents
107
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
108
+ # video = self.vae.decode(latents).sample
109
+ video = []
110
+ for frame_idx in tqdm(range(latents.shape[0])):
111
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
112
+ video = torch.cat(video)
113
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
114
+ video = (video / 2 + 0.5).clamp(0, 1)
115
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
116
+ video = video.cpu().float().numpy()
117
+ return video
118
+
119
+ def prepare_extra_step_kwargs(self, generator, eta):
120
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
121
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
122
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
123
+ # and should be between [0, 1]
124
+
125
+ accepts_eta = "eta" in set(
126
+ inspect.signature(self.scheduler.step).parameters.keys()
127
+ )
128
+ extra_step_kwargs = {}
129
+ if accepts_eta:
130
+ extra_step_kwargs["eta"] = eta
131
+
132
+ # check if the scheduler accepts generator
133
+ accepts_generator = "generator" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ if accepts_generator:
137
+ extra_step_kwargs["generator"] = generator
138
+ return extra_step_kwargs
139
+
140
+ def prepare_latents(
141
+ self,
142
+ batch_size,
143
+ num_channels_latents,
144
+ width,
145
+ height,
146
+ video_length,
147
+ dtype,
148
+ device,
149
+ generator,
150
+ latents=None,
151
+ ):
152
+ shape = (
153
+ batch_size,
154
+ num_channels_latents,
155
+ video_length,
156
+ height // self.vae_scale_factor,
157
+ width // self.vae_scale_factor,
158
+ )
159
+ if isinstance(generator, list) and len(generator) != batch_size:
160
+ raise ValueError(
161
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
162
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
163
+ )
164
+
165
+ if latents is None:
166
+ latents = randn_tensor(
167
+ shape, generator=generator, device=device, dtype=dtype
168
+ )
169
+ else:
170
+ latents = latents.to(device)
171
+
172
+ # scale the initial noise by the standard deviation required by the scheduler
173
+ latents = latents * self.scheduler.init_noise_sigma
174
+ return latents
175
+
176
+ def _encode_prompt(
177
+ self,
178
+ prompt,
179
+ device,
180
+ num_videos_per_prompt,
181
+ do_classifier_free_guidance,
182
+ negative_prompt,
183
+ ):
184
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
185
+
186
+ text_inputs = self.tokenizer(
187
+ prompt,
188
+ padding="max_length",
189
+ max_length=self.tokenizer.model_max_length,
190
+ truncation=True,
191
+ return_tensors="pt",
192
+ )
193
+ text_input_ids = text_inputs.input_ids
194
+ untruncated_ids = self.tokenizer(
195
+ prompt, padding="longest", return_tensors="pt"
196
+ ).input_ids
197
+
198
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
199
+ text_input_ids, untruncated_ids
200
+ ):
201
+ removed_text = self.tokenizer.batch_decode(
202
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
203
+ )
204
+
205
+ if (
206
+ hasattr(self.text_encoder.config, "use_attention_mask")
207
+ and self.text_encoder.config.use_attention_mask
208
+ ):
209
+ attention_mask = text_inputs.attention_mask.to(device)
210
+ else:
211
+ attention_mask = None
212
+
213
+ text_embeddings = self.text_encoder(
214
+ text_input_ids.to(device),
215
+ attention_mask=attention_mask,
216
+ )
217
+ text_embeddings = text_embeddings[0]
218
+
219
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
220
+ bs_embed, seq_len, _ = text_embeddings.shape
221
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
222
+ text_embeddings = text_embeddings.view(
223
+ bs_embed * num_videos_per_prompt, seq_len, -1
224
+ )
225
+
226
+ # get unconditional embeddings for classifier free guidance
227
+ if do_classifier_free_guidance:
228
+ uncond_tokens: List[str]
229
+ if negative_prompt is None:
230
+ uncond_tokens = [""] * batch_size
231
+ elif type(prompt) is not type(negative_prompt):
232
+ raise TypeError(
233
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
234
+ f" {type(prompt)}."
235
+ )
236
+ elif isinstance(negative_prompt, str):
237
+ uncond_tokens = [negative_prompt]
238
+ elif batch_size != len(negative_prompt):
239
+ raise ValueError(
240
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
241
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
242
+ " the batch size of `prompt`."
243
+ )
244
+ else:
245
+ uncond_tokens = negative_prompt
246
+
247
+ max_length = text_input_ids.shape[-1]
248
+ uncond_input = self.tokenizer(
249
+ uncond_tokens,
250
+ padding="max_length",
251
+ max_length=max_length,
252
+ truncation=True,
253
+ return_tensors="pt",
254
+ )
255
+
256
+ if (
257
+ hasattr(self.text_encoder.config, "use_attention_mask")
258
+ and self.text_encoder.config.use_attention_mask
259
+ ):
260
+ attention_mask = uncond_input.attention_mask.to(device)
261
+ else:
262
+ attention_mask = None
263
+
264
+ uncond_embeddings = self.text_encoder(
265
+ uncond_input.input_ids.to(device),
266
+ attention_mask=attention_mask,
267
+ )
268
+ uncond_embeddings = uncond_embeddings[0]
269
+
270
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
271
+ seq_len = uncond_embeddings.shape[1]
272
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
273
+ uncond_embeddings = uncond_embeddings.view(
274
+ batch_size * num_videos_per_prompt, seq_len, -1
275
+ )
276
+
277
+ # For classifier free guidance, we need to do two forward passes.
278
+ # Here we concatenate the unconditional and text embeddings into a single batch
279
+ # to avoid doing two forward passes
280
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
281
+
282
+ return text_embeddings
283
+
284
+ @torch.no_grad()
285
+ def __call__(
286
+ self,
287
+ ref_image,
288
+ pose_images,
289
+ width,
290
+ height,
291
+ video_length,
292
+ num_inference_steps,
293
+ guidance_scale,
294
+ num_images_per_prompt=1,
295
+ eta: float = 0.0,
296
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
297
+ output_type: Optional[str] = "tensor",
298
+ return_dict: bool = True,
299
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
300
+ callback_steps: Optional[int] = 1,
301
+ **kwargs,
302
+ ):
303
+ # Default height and width to unet
304
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
305
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
306
+
307
+ device = self._execution_device
308
+
309
+ do_classifier_free_guidance = guidance_scale > 1.0
310
+
311
+ # Prepare timesteps
312
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
313
+ timesteps = self.scheduler.timesteps
314
+
315
+ batch_size = 1
316
+
317
+ # Prepare clip image embeds
318
+ clip_image = self.clip_image_processor.preprocess(
319
+ ref_image, return_tensors="pt"
320
+ ).pixel_values
321
+ clip_image_embeds = self.image_encoder(
322
+ clip_image.to(device, dtype=self.image_encoder.dtype)
323
+ ).image_embeds
324
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
325
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
326
+
327
+ if do_classifier_free_guidance:
328
+ encoder_hidden_states = torch.cat(
329
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
330
+ )
331
+ reference_control_writer = ReferenceAttentionControl(
332
+ self.reference_unet,
333
+ do_classifier_free_guidance=do_classifier_free_guidance,
334
+ mode="write",
335
+ batch_size=batch_size,
336
+ fusion_blocks="full",
337
+ )
338
+ reference_control_reader = ReferenceAttentionControl(
339
+ self.denoising_unet,
340
+ do_classifier_free_guidance=do_classifier_free_guidance,
341
+ mode="read",
342
+ batch_size=batch_size,
343
+ fusion_blocks="full",
344
+ )
345
+
346
+ num_channels_latents = self.denoising_unet.in_channels
347
+ latents = self.prepare_latents(
348
+ batch_size * num_images_per_prompt,
349
+ num_channels_latents,
350
+ width,
351
+ height,
352
+ video_length,
353
+ clip_image_embeds.dtype,
354
+ device,
355
+ generator,
356
+ )
357
+
358
+ # Prepare extra step kwargs.
359
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
360
+
361
+ # Prepare ref image latents
362
+ ref_image_tensor = self.ref_image_processor.preprocess(
363
+ ref_image, height=height, width=width
364
+ ) # (bs, c, height, width)
365
+ ref_image_tensor = ref_image_tensor.to(
366
+ dtype=self.vae.dtype, device=self.vae.device
367
+ )
368
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
369
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
370
+
371
+ # Prepare a list of pose condition images
372
+ pose_cond_tensor_list = []
373
+ for pose_image in pose_images:
374
+ pose_cond_tensor = (
375
+ torch.from_numpy(np.array(pose_image.resize((width, height)))) / 255.0
376
+ )
377
+ pose_cond_tensor = pose_cond_tensor.permute(2, 0, 1).unsqueeze(
378
+ 1
379
+ ) # (c, 1, h, w)
380
+ pose_cond_tensor_list.append(pose_cond_tensor)
381
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=1) # (c, t, h, w)
382
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(0)
383
+ pose_cond_tensor = pose_cond_tensor.to(
384
+ device=device, dtype=self.pose_guider.dtype
385
+ )
386
+ pose_fea = self.pose_guider(pose_cond_tensor)
387
+ pose_fea = (
388
+ torch.cat([pose_fea] * 2) if do_classifier_free_guidance else pose_fea
389
+ )
390
+
391
+ # denoising loop
392
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
393
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
394
+ for i, t in enumerate(timesteps):
395
+ # 1. Forward reference image
396
+ if i == 0:
397
+ self.reference_unet(
398
+ ref_image_latents.repeat(
399
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
400
+ ),
401
+ torch.zeros_like(t),
402
+ # t,
403
+ encoder_hidden_states=encoder_hidden_states,
404
+ return_dict=False,
405
+ )
406
+ reference_control_reader.update(reference_control_writer)
407
+
408
+ # 3.1 expand the latents if we are doing classifier free guidance
409
+ latent_model_input = (
410
+ torch.cat([latents] * 2) if do_classifier_free_guidance else latents
411
+ )
412
+ latent_model_input = self.scheduler.scale_model_input(
413
+ latent_model_input, t
414
+ )
415
+
416
+ noise_pred = self.denoising_unet(
417
+ latent_model_input,
418
+ t,
419
+ encoder_hidden_states=encoder_hidden_states,
420
+ pose_cond_fea=pose_fea,
421
+ return_dict=False,
422
+ )[0]
423
+
424
+ # perform guidance
425
+ if do_classifier_free_guidance:
426
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
427
+ noise_pred = noise_pred_uncond + guidance_scale * (
428
+ noise_pred_text - noise_pred_uncond
429
+ )
430
+
431
+ # compute the previous noisy sample x_t -> x_t-1
432
+ latents = self.scheduler.step(
433
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
434
+ )[0]
435
+
436
+ # call the callback, if provided
437
+ if i == len(timesteps) - 1 or (
438
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
439
+ ):
440
+ progress_bar.update()
441
+ if callback is not None and i % callback_steps == 0:
442
+ step_idx = i // getattr(self.scheduler, "order", 1)
443
+ callback(step_idx, t, latents)
444
+
445
+ reference_control_reader.clear()
446
+ reference_control_writer.clear()
447
+
448
+ # Post-processing
449
+ images = self.decode_latents(latents) # (b, c, f, h, w)
450
+
451
+ # Convert to tensor
452
+ if output_type == "tensor":
453
+ images = torch.from_numpy(images)
454
+
455
+ if not return_dict:
456
+ return images
457
+
458
+ return Pose2VideoPipelineOutput(videos=images)
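The fixed-length video pipeline is invoked analogously; a hedged sketch, assuming the sub-modules (`vae`, `image_encoder`, the two UNets, `pose_guider`, `scheduler`) have already been loaded as in test_stage_2.py, `ref_image` is a PIL reference image, and `pose_images` is a list of aligned pose frames:

import torch

pipe = Pose2VideoPipeline(
    vae=vae,
    image_encoder=image_encoder,
    reference_unet=reference_unet,
    denoising_unet=denoising_unet,
    pose_guider=pose_guider,
    scheduler=scheduler,
)
pipe = pipe.to("cuda", torch.float16)

output = pipe(
    ref_image,                     # PIL reference image
    pose_images,                   # one pose frame per output frame
    width=512,
    height=768,
    video_length=len(pose_images),
    num_inference_steps=20,
    guidance_scale=3.5,
    generator=torch.manual_seed(42),
)
video = output.videos              # (b, c, f, h, w) tensor in [0, 1]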
musepose/pipelines/pipeline_pose2vid_long.py ADDED
@@ -0,0 +1,571 @@
1
+ # Adapted from https://github.com/magic-research/magic-animate/blob/main/magicanimate/pipelines/pipeline_animation.py
2
+ import inspect
3
+ import math
4
+ from dataclasses import dataclass
5
+ from typing import Callable, List, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import DiffusionPipeline
10
+ from diffusers.image_processor import VaeImageProcessor
11
+ from diffusers.schedulers import (
12
+ DDIMScheduler,
13
+ DPMSolverMultistepScheduler,
14
+ EulerAncestralDiscreteScheduler,
15
+ EulerDiscreteScheduler,
16
+ LMSDiscreteScheduler,
17
+ PNDMScheduler,
18
+ )
19
+ from diffusers.utils import BaseOutput, deprecate, is_accelerate_available, logging
20
+ from diffusers.utils.torch_utils import randn_tensor
21
+ from einops import rearrange
22
+ from tqdm import tqdm
23
+ from transformers import CLIPImageProcessor
24
+
25
+ from musepose.models.mutual_self_attention import ReferenceAttentionControl
26
+ from musepose.pipelines.context import get_context_scheduler
27
+ from musepose.pipelines.utils import get_tensor_interpolation_method
28
+
29
+
30
+ @dataclass
31
+ class Pose2VideoPipelineOutput(BaseOutput):
32
+ videos: Union[torch.Tensor, np.ndarray]
33
+
34
+
35
+ class Pose2VideoPipeline(DiffusionPipeline):
36
+ _optional_components = []
37
+
38
+ def __init__(
39
+ self,
40
+ vae,
41
+ image_encoder,
42
+ reference_unet,
43
+ denoising_unet,
44
+ pose_guider,
45
+ scheduler: Union[
46
+ DDIMScheduler,
47
+ PNDMScheduler,
48
+ LMSDiscreteScheduler,
49
+ EulerDiscreteScheduler,
50
+ EulerAncestralDiscreteScheduler,
51
+ DPMSolverMultistepScheduler,
52
+ ],
53
+ image_proj_model=None,
54
+ tokenizer=None,
55
+ text_encoder=None,
56
+ ):
57
+ super().__init__()
58
+
59
+ self.register_modules(
60
+ vae=vae,
61
+ image_encoder=image_encoder,
62
+ reference_unet=reference_unet,
63
+ denoising_unet=denoising_unet,
64
+ pose_guider=pose_guider,
65
+ scheduler=scheduler,
66
+ image_proj_model=image_proj_model,
67
+ tokenizer=tokenizer,
68
+ text_encoder=text_encoder,
69
+ )
70
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
71
+ self.clip_image_processor = CLIPImageProcessor()
72
+ self.ref_image_processor = VaeImageProcessor(
73
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True
74
+ )
75
+ self.cond_image_processor = VaeImageProcessor(
76
+ vae_scale_factor=self.vae_scale_factor,
77
+ do_convert_rgb=True,
78
+ do_normalize=False,
79
+ )
80
+
81
+ def enable_vae_slicing(self):
82
+ self.vae.enable_slicing()
83
+
84
+ def disable_vae_slicing(self):
85
+ self.vae.disable_slicing()
86
+
87
+ def enable_sequential_cpu_offload(self, gpu_id=0):
88
+ if is_accelerate_available():
89
+ from accelerate import cpu_offload
90
+ else:
91
+ raise ImportError("Please install accelerate via `pip install accelerate`")
92
+
93
+ device = torch.device(f"cuda:{gpu_id}")
94
+
95
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
96
+ if cpu_offloaded_model is not None:
97
+ cpu_offload(cpu_offloaded_model, device)
98
+
99
+ @property
100
+ def _execution_device(self):
101
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
102
+ return self.device
103
+ for module in self.unet.modules():
104
+ if (
105
+ hasattr(module, "_hf_hook")
106
+ and hasattr(module._hf_hook, "execution_device")
107
+ and module._hf_hook.execution_device is not None
108
+ ):
109
+ return torch.device(module._hf_hook.execution_device)
110
+ return self.device
111
+
112
+ def decode_latents(self, latents):
113
+ video_length = latents.shape[2]
114
+ latents = 1 / 0.18215 * latents
115
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
116
+ # video = self.vae.decode(latents).sample
117
+ video = []
118
+ for frame_idx in tqdm(range(latents.shape[0])):
119
+ video.append(self.vae.decode(latents[frame_idx : frame_idx + 1]).sample)
120
+ video = torch.cat(video)
121
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
122
+ video = (video / 2 + 0.5).clamp(0, 1)
123
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
124
+ video = video.cpu().float().numpy()
125
+ return video
126
+
127
+ def prepare_extra_step_kwargs(self, generator, eta):
128
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
129
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
130
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
131
+ # and should be between [0, 1]
132
+
133
+ accepts_eta = "eta" in set(
134
+ inspect.signature(self.scheduler.step).parameters.keys()
135
+ )
136
+ extra_step_kwargs = {}
137
+ if accepts_eta:
138
+ extra_step_kwargs["eta"] = eta
139
+
140
+ # check if the scheduler accepts generator
141
+ accepts_generator = "generator" in set(
142
+ inspect.signature(self.scheduler.step).parameters.keys()
143
+ )
144
+ if accepts_generator:
145
+ extra_step_kwargs["generator"] = generator
146
+ return extra_step_kwargs
147
+
148
+ def prepare_latents(
149
+ self,
150
+ batch_size,
151
+ num_channels_latents,
152
+ width,
153
+ height,
154
+ video_length,
155
+ dtype,
156
+ device,
157
+ generator,
158
+ latents=None,
159
+ ):
160
+ shape = (
161
+ batch_size,
162
+ num_channels_latents,
163
+ video_length,
164
+ height // self.vae_scale_factor,
165
+ width // self.vae_scale_factor,
166
+ )
167
+ if isinstance(generator, list) and len(generator) != batch_size:
168
+ raise ValueError(
169
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
170
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
171
+ )
172
+
173
+ if latents is None:
174
+ latents = randn_tensor(
175
+ shape, generator=generator, device=device, dtype=dtype
176
+ )
177
+ else:
178
+ latents = latents.to(device)
179
+
180
+ # scale the initial noise by the standard deviation required by the scheduler
181
+ latents = latents * self.scheduler.init_noise_sigma
182
+ return latents
183
+
184
+ def _encode_prompt(
185
+ self,
186
+ prompt,
187
+ device,
188
+ num_videos_per_prompt,
189
+ do_classifier_free_guidance,
190
+ negative_prompt,
191
+ ):
192
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
193
+
194
+ text_inputs = self.tokenizer(
195
+ prompt,
196
+ padding="max_length",
197
+ max_length=self.tokenizer.model_max_length,
198
+ truncation=True,
199
+ return_tensors="pt",
200
+ )
201
+ text_input_ids = text_inputs.input_ids
202
+ untruncated_ids = self.tokenizer(
203
+ prompt, padding="longest", return_tensors="pt"
204
+ ).input_ids
205
+
206
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
207
+ text_input_ids, untruncated_ids
208
+ ):
209
+ removed_text = self.tokenizer.batch_decode(
210
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
211
+ )
212
+
213
+ if (
214
+ hasattr(self.text_encoder.config, "use_attention_mask")
215
+ and self.text_encoder.config.use_attention_mask
216
+ ):
217
+ attention_mask = text_inputs.attention_mask.to(device)
218
+ else:
219
+ attention_mask = None
220
+
221
+ text_embeddings = self.text_encoder(
222
+ text_input_ids.to(device),
223
+ attention_mask=attention_mask,
224
+ )
225
+ text_embeddings = text_embeddings[0]
226
+
227
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
228
+ bs_embed, seq_len, _ = text_embeddings.shape
229
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
230
+ text_embeddings = text_embeddings.view(
231
+ bs_embed * num_videos_per_prompt, seq_len, -1
232
+ )
233
+
234
+ # get unconditional embeddings for classifier free guidance
235
+ if do_classifier_free_guidance:
236
+ uncond_tokens: List[str]
237
+ if negative_prompt is None:
238
+ uncond_tokens = [""] * batch_size
239
+ elif type(prompt) is not type(negative_prompt):
240
+ raise TypeError(
241
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
242
+ f" {type(prompt)}."
243
+ )
244
+ elif isinstance(negative_prompt, str):
245
+ uncond_tokens = [negative_prompt]
246
+ elif batch_size != len(negative_prompt):
247
+ raise ValueError(
248
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
249
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
250
+ " the batch size of `prompt`."
251
+ )
252
+ else:
253
+ uncond_tokens = negative_prompt
254
+
255
+ max_length = text_input_ids.shape[-1]
256
+ uncond_input = self.tokenizer(
257
+ uncond_tokens,
258
+ padding="max_length",
259
+ max_length=max_length,
260
+ truncation=True,
261
+ return_tensors="pt",
262
+ )
263
+
264
+ if (
265
+ hasattr(self.text_encoder.config, "use_attention_mask")
266
+ and self.text_encoder.config.use_attention_mask
267
+ ):
268
+ attention_mask = uncond_input.attention_mask.to(device)
269
+ else:
270
+ attention_mask = None
271
+
272
+ uncond_embeddings = self.text_encoder(
273
+ uncond_input.input_ids.to(device),
274
+ attention_mask=attention_mask,
275
+ )
276
+ uncond_embeddings = uncond_embeddings[0]
277
+
278
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
279
+ seq_len = uncond_embeddings.shape[1]
280
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
281
+ uncond_embeddings = uncond_embeddings.view(
282
+ batch_size * num_videos_per_prompt, seq_len, -1
283
+ )
284
+
285
+ # For classifier free guidance, we need to do two forward passes.
286
+ # Here we concatenate the unconditional and text embeddings into a single batch
287
+ # to avoid doing two forward passes
288
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
289
+
290
+ return text_embeddings
291
+
292
+ def interpolate_latents(
293
+ self, latents: torch.Tensor, interpolation_factor: int, device
294
+ ):
295
+ if interpolation_factor < 2:
296
+ return latents
297
+
298
+ new_latents = torch.zeros(
299
+ (
300
+ latents.shape[0],
301
+ latents.shape[1],
302
+ ((latents.shape[2] - 1) * interpolation_factor) + 1,
303
+ latents.shape[3],
304
+ latents.shape[4],
305
+ ),
306
+ device=latents.device,
307
+ dtype=latents.dtype,
308
+ )
309
+
310
+ org_video_length = latents.shape[2]
311
+ rate = [i / interpolation_factor for i in range(interpolation_factor)][1:]
312
+
313
+ new_index = 0
314
+
315
+ v0 = None
316
+ v1 = None
317
+
318
+ for i0, i1 in zip(range(org_video_length), range(org_video_length)[1:]):
319
+ v0 = latents[:, :, i0, :, :]
320
+ v1 = latents[:, :, i1, :, :]
321
+
322
+ new_latents[:, :, new_index, :, :] = v0
323
+ new_index += 1
324
+
325
+ for f in rate:
326
+ v = get_tensor_interpolation_method()(
327
+ v0.to(device=device), v1.to(device=device), f
328
+ )
329
+ new_latents[:, :, new_index, :, :] = v.to(latents.device)
330
+ new_index += 1
331
+
332
+ new_latents[:, :, new_index, :, :] = v1
333
+ new_index += 1
334
+
335
+ return new_latents
336
+
337
+ @torch.no_grad()
338
+ def __call__(
339
+ self,
340
+ ref_image,
341
+ pose_images,
342
+ width,
343
+ height,
344
+ video_length,
345
+ num_inference_steps,
346
+ guidance_scale,
347
+ num_images_per_prompt=1,
348
+ eta: float = 0.0,
349
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
350
+ output_type: Optional[str] = "tensor",
351
+ return_dict: bool = True,
352
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
353
+ callback_steps: Optional[int] = 1,
354
+ context_schedule="uniform",
355
+ context_frames=24,
356
+ context_stride=1,
357
+ context_overlap=4,
358
+ context_batch_size=1,
359
+ interpolation_factor=1,
360
+ **kwargs,
361
+ ):
362
+ # Default height and width to unet
363
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
364
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
365
+
366
+ device = self._execution_device
367
+
368
+ do_classifier_free_guidance = guidance_scale > 1.0
369
+
370
+ # Prepare timesteps
371
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
372
+ timesteps = self.scheduler.timesteps
373
+
374
+ batch_size = 1
375
+
376
+ # Prepare clip image embeds
377
+ clip_image = self.clip_image_processor.preprocess(
378
+ ref_image.resize((224, 224)), return_tensors="pt"
379
+ ).pixel_values
380
+ clip_image_embeds = self.image_encoder(
381
+ clip_image.to(device, dtype=self.image_encoder.dtype)
382
+ ).image_embeds
383
+ encoder_hidden_states = clip_image_embeds.unsqueeze(1)
384
+ uncond_encoder_hidden_states = torch.zeros_like(encoder_hidden_states)
385
+
386
+ if do_classifier_free_guidance:
387
+ encoder_hidden_states = torch.cat(
388
+ [uncond_encoder_hidden_states, encoder_hidden_states], dim=0
389
+ )
390
+
391
+ reference_control_writer = ReferenceAttentionControl(
392
+ self.reference_unet,
393
+ do_classifier_free_guidance=do_classifier_free_guidance,
394
+ mode="write",
395
+ batch_size=batch_size,
396
+ fusion_blocks="full",
397
+ )
398
+ reference_control_reader = ReferenceAttentionControl(
399
+ self.denoising_unet,
400
+ do_classifier_free_guidance=do_classifier_free_guidance,
401
+ mode="read",
402
+ batch_size=batch_size,
403
+ fusion_blocks="full",
404
+ )
405
+
406
+ num_channels_latents = self.denoising_unet.in_channels
407
+ latents = self.prepare_latents(
408
+ batch_size * num_images_per_prompt,
409
+ num_channels_latents,
410
+ width,
411
+ height,
412
+ video_length,
413
+ clip_image_embeds.dtype,
414
+ device,
415
+ generator,
416
+ )
417
+
418
+ # Prepare extra step kwargs.
419
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
420
+
421
+ # Prepare ref image latents
422
+ ref_image_tensor = self.ref_image_processor.preprocess(
423
+ ref_image, height=height, width=width
424
+ ) # (bs, c, height, width)
425
+ ref_image_tensor = ref_image_tensor.to(
426
+ dtype=self.vae.dtype, device=self.vae.device
427
+ )
428
+ ref_image_latents = self.vae.encode(ref_image_tensor).latent_dist.mean
429
+ ref_image_latents = ref_image_latents * 0.18215 # (b, 4, h, w)
430
+
431
+ # Prepare a list of pose condition images
432
+ pose_cond_tensor_list = []
433
+ for pose_image in pose_images:
434
+ pose_cond_tensor = self.cond_image_processor.preprocess(
435
+ pose_image, height=height, width=width
436
+ )
437
+ pose_cond_tensor = pose_cond_tensor.unsqueeze(2) # (bs, c, 1, h, w)
438
+ pose_cond_tensor_list.append(pose_cond_tensor)
439
+ pose_cond_tensor = torch.cat(pose_cond_tensor_list, dim=2) # (bs, c, t, h, w)
440
+ pose_cond_tensor = pose_cond_tensor.to(
441
+ device=device, dtype=self.pose_guider.dtype
442
+ )
443
+ pose_fea = self.pose_guider(pose_cond_tensor)
444
+
445
+ context_scheduler = get_context_scheduler(context_schedule)
446
+
447
+ # denoising loop
448
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
449
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
450
+ for i, t in enumerate(timesteps):
451
+ noise_pred = torch.zeros(
452
+ (
453
+ latents.shape[0] * (2 if do_classifier_free_guidance else 1),
454
+ *latents.shape[1:],
455
+ ),
456
+ device=latents.device,
457
+ dtype=latents.dtype,
458
+ )
459
+ counter = torch.zeros(
460
+ (1, 1, latents.shape[2], 1, 1),
461
+ device=latents.device,
462
+ dtype=latents.dtype,
463
+ )
464
+
465
+ # 1. Forward reference image
466
+ if i == 0:
467
+ self.reference_unet(
468
+ ref_image_latents.repeat(
469
+ (2 if do_classifier_free_guidance else 1), 1, 1, 1
470
+ ),
471
+ torch.zeros_like(t),
472
+ # t,
473
+ encoder_hidden_states=encoder_hidden_states,
474
+ return_dict=False,
475
+ )
476
+ reference_control_reader.update(reference_control_writer)
477
+
478
+ context_queue = list(
479
+ context_scheduler(
480
+ 0,
481
+ num_inference_steps,
482
+ latents.shape[2],
483
+ context_frames,
484
+ context_stride,
485
+ 0,
486
+ )
487
+ )
488
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
489
+
490
+ context_queue = list(
491
+ context_scheduler(
492
+ 0,
493
+ num_inference_steps,
494
+ latents.shape[2],
495
+ context_frames,
496
+ context_stride,
497
+ context_overlap,
498
+ )
499
+ )
500
+
501
+ num_context_batches = math.ceil(len(context_queue) / context_batch_size)
502
+ global_context = []
503
+ for i in range(num_context_batches):
504
+ global_context.append(
505
+ context_queue[
506
+ i * context_batch_size : (i + 1) * context_batch_size
507
+ ]
508
+ )
509
+
510
+ for context in global_context:
511
+ # 3.1 expand the latents if we are doing classifier free guidance
512
+ latent_model_input = (
513
+ torch.cat([latents[:, :, c] for c in context])
514
+ .to(device)
515
+ .repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
516
+ )
517
+ latent_model_input = self.scheduler.scale_model_input(
518
+ latent_model_input, t
519
+ )
520
+ b, c, f, h, w = latent_model_input.shape
521
+ latent_pose_input = torch.cat(
522
+ [pose_fea[:, :, c] for c in context]
523
+ ).repeat(2 if do_classifier_free_guidance else 1, 1, 1, 1, 1)
524
+
525
+ pred = self.denoising_unet(
526
+ latent_model_input,
527
+ t,
528
+ encoder_hidden_states=encoder_hidden_states[:b],
529
+ pose_cond_fea=latent_pose_input,
530
+ return_dict=False,
531
+ )[0]
532
+
533
+ for j, c in enumerate(context):
534
+ noise_pred[:, :, c] = noise_pred[:, :, c] + pred
535
+ counter[:, :, c] = counter[:, :, c] + 1
536
+
537
+ # perform guidance
538
+ if do_classifier_free_guidance:
539
+ noise_pred_uncond, noise_pred_text = (noise_pred / counter).chunk(2)
540
+ noise_pred = noise_pred_uncond + guidance_scale * (
541
+ noise_pred_text - noise_pred_uncond
542
+ )
543
+
544
+ latents = self.scheduler.step(
545
+ noise_pred, t, latents, **extra_step_kwargs
546
+ ).prev_sample
547
+
548
+ if i == len(timesteps) - 1 or (
549
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
550
+ ):
551
+ progress_bar.update()
552
+ if callback is not None and i % callback_steps == 0:
553
+ step_idx = i // getattr(self.scheduler, "order", 1)
554
+ callback(step_idx, t, latents)
555
+
556
+ reference_control_reader.clear()
557
+ reference_control_writer.clear()
558
+
559
+ if interpolation_factor > 0:
560
+ latents = self.interpolate_latents(latents, interpolation_factor, device)
561
+ # Post-processing
562
+ images = self.decode_latents(latents) # (b, c, f, h, w)
563
+
564
+ # Convert to tensor
565
+ if output_type == "tensor":
566
+ images = torch.from_numpy(images)
567
+
568
+ if not return_dict:
569
+ return images
570
+
571
+ return Pose2VideoPipelineOutput(videos=images)
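The long-video variant differs from the fixed-length pipeline mainly in the denoising loop: the temporal axis is split into overlapping context windows (`context_frames`, `context_overlap`), each window is denoised separately, and per-frame predictions are summed into `noise_pred` while `counter` records how many windows covered each frame, so `noise_pred / counter` gives each frame the mean of all windows containing it. A toy illustration of that averaging with a simplified sliding-window schedule (the real schedule comes from `get_context_scheduler` in context.py):

import torch

video_length, context_frames, context_overlap = 48, 24, 4
stride = context_frames - context_overlap
windows = [list(range(s, min(s + context_frames, video_length)))
           for s in range(0, video_length - context_overlap, stride)]

counter = torch.zeros(video_length)
for w in windows:
    counter[w] += 1
print([w[0] for w in windows])  # window start frames: [0, 20, 40]
print(counter.tolist())         # overlapped frames were predicted twice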
musepose/pipelines/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch
2
+
3
+ tensor_interpolation = None
4
+
5
+
6
+ def get_tensor_interpolation_method():
7
+ return tensor_interpolation
8
+
9
+
10
+ def set_tensor_interpolation_method(is_slerp):
11
+ global tensor_interpolation
12
+ tensor_interpolation = slerp if is_slerp else linear
13
+
14
+
15
+ def linear(v1, v2, t):
16
+ return (1.0 - t) * v1 + t * v2
17
+
18
+
19
+ def slerp(
20
+ v0: torch.Tensor, v1: torch.Tensor, t: float, DOT_THRESHOLD: float = 0.9995
21
+ ) -> torch.Tensor:
22
+ u0 = v0 / v0.norm()
23
+ u1 = v1 / v1.norm()
24
+ dot = (u0 * u1).sum()
25
+ if dot.abs() > DOT_THRESHOLD:
26
+ # logger.info(f'warning: v0 and v1 close to parallel, using linear interpolation instead.')
27
+ return (1.0 - t) * v0 + t * v1
28
+ omega = dot.acos()
29
+ return (((1.0 - t) * omega).sin() * v0 + (t * omega).sin() * v1) / omega.sin()
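A quick sanity check of the interpolation helpers above (illustrative values only): slerp keeps the interpolant on the arc between the two vectors, while the linear fallback simply averages them.

import torch
from musepose.pipelines.utils import (get_tensor_interpolation_method,
                                      set_tensor_interpolation_method)

v0 = torch.tensor([1.0, 0.0])
v1 = torch.tensor([0.0, 1.0])

set_tensor_interpolation_method(is_slerp=True)
print(get_tensor_interpolation_method()(v0, v1, 0.5))   # tensor([0.7071, 0.7071])

set_tensor_interpolation_method(is_slerp=False)
print(get_tensor_interpolation_method()(v0, v1, 0.5))   # tensor([0.5000, 0.5000])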
musepose/utils/util.py ADDED
@@ -0,0 +1,133 @@
1
+ import importlib
2
+ import os
3
+ import os.path as osp
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import av
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from einops import rearrange
13
+ from PIL import Image
14
+
15
+
16
+ def seed_everything(seed):
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ np.random.seed(seed % (2**32))
24
+ random.seed(seed)
25
+
26
+
27
+ def import_filename(filename):
28
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
29
+ module = importlib.util.module_from_spec(spec)
30
+ sys.modules[spec.name] = module
31
+ spec.loader.exec_module(module)
32
+ return module
33
+
34
+
35
+ def delete_additional_ckpt(base_path, num_keep):
36
+ dirs = []
37
+ for d in os.listdir(base_path):
38
+ if d.startswith("checkpoint-"):
39
+ dirs.append(d)
40
+ num_tot = len(dirs)
41
+ if num_tot <= num_keep:
42
+ return
43
+ # ensure checkpoints are sorted by step and delete the earliest ones
44
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45
+ for d in del_dirs:
46
+ path_to_dir = osp.join(base_path, d)
47
+ if osp.exists(path_to_dir):
48
+ shutil.rmtree(path_to_dir)
49
+
50
+
51
+ def save_videos_from_pil(pil_images, path, fps=8):
52
+ import av
53
+
54
+ save_fmt = Path(path).suffix
55
+ os.makedirs(os.path.dirname(path), exist_ok=True)
56
+ width, height = pil_images[0].size
57
+
58
+ if save_fmt == ".mp4":
59
+ codec = "libx264"
60
+ container = av.open(path, "w")
61
+ stream = container.add_stream(codec, rate=fps)
62
+
63
+ stream.width = width
64
+ stream.height = height
65
+ stream.pix_fmt = 'yuv420p'
66
+ stream.bit_rate = 10000000
67
+ stream.options["crf"] = "18"
68
+
69
+
70
+
71
+ for pil_image in pil_images:
72
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
73
+ av_frame = av.VideoFrame.from_image(pil_image)
74
+ container.mux(stream.encode(av_frame))
75
+ container.mux(stream.encode())
76
+ container.close()
77
+
78
+ elif save_fmt == ".gif":
79
+ pil_images[0].save(
80
+ fp=path,
81
+ format="GIF",
82
+ append_images=pil_images[1:],
83
+ save_all=True,
84
+ duration=(1 / fps * 1000),
85
+ loop=0,
86
+ )
87
+ else:
88
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
89
+
90
+
91
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
92
+ videos = rearrange(videos, "b c t h w -> t b c h w")
93
+ height, width = videos.shape[-2:]
94
+ outputs = []
95
+
96
+ for x in videos:
97
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
98
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
99
+ if rescale:
100
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
101
+ x = (x * 255).numpy().astype(np.uint8)
102
+ x = Image.fromarray(x)
103
+
104
+ outputs.append(x)
105
+
106
+ os.makedirs(os.path.dirname(path), exist_ok=True)
107
+
108
+ save_videos_from_pil(outputs, path, fps)
109
+
110
+
111
+ def read_frames(video_path):
112
+ container = av.open(video_path)
113
+
114
+ video_stream = next(s for s in container.streams if s.type == "video")
115
+ frames = []
116
+ for packet in container.demux(video_stream):
117
+ for frame in packet.decode():
118
+ image = Image.frombytes(
119
+ "RGB",
120
+ (frame.width, frame.height),
121
+ frame.to_rgb().to_ndarray(),
122
+ )
123
+ frames.append(image)
124
+
125
+ return frames
126
+
127
+
128
+ def get_fps(video_path):
129
+ container = av.open(video_path)
130
+ video_stream = next(s for s in container.streams if s.type == "video")
131
+ fps = video_stream.average_rate
132
+ container.close()
133
+ return fps
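A small, hedged example of how the video helpers above fit together, reading frames from the bundled demo clip and re-encoding a short preview (the output path is a placeholder):

from musepose.utils.util import get_fps, read_frames, save_videos_from_pil

frames = read_frames("assets/videos/dance.mp4")           # list of RGB PIL images
fps = get_fps("assets/videos/dance.mp4")
save_videos_from_pil(frames[:48], "output/preview.mp4", fps=fps)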
pose/config/dwpose-l_384x288.py ADDED
@@ -0,0 +1,257 @@
1
+ # runtime
2
+ max_epochs = 270
3
+ stage2_num_epochs = 30
4
+ base_lr = 4e-3
5
+
6
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
7
+ randomness = dict(seed=21)
8
+
9
+ # optimizer
10
+ optim_wrapper = dict(
11
+ type='OptimWrapper',
12
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
13
+ paramwise_cfg=dict(
14
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
15
+
16
+ # learning rate
17
+ param_scheduler = [
18
+ dict(
19
+ type='LinearLR',
20
+ start_factor=1.0e-5,
21
+ by_epoch=False,
22
+ begin=0,
23
+ end=1000),
24
+ dict(
25
+ # use cosine lr over the second half of training (max_epochs // 2 to max_epochs)
26
+ type='CosineAnnealingLR',
27
+ eta_min=base_lr * 0.05,
28
+ begin=max_epochs // 2,
29
+ end=max_epochs,
30
+ T_max=max_epochs // 2,
31
+ by_epoch=True,
32
+ convert_to_iter_based=True),
33
+ ]
34
+
35
+ # automatically scaling LR based on the actual training batch size
36
+ auto_scale_lr = dict(base_batch_size=512)
37
+
38
+ # codec settings
39
+ codec = dict(
40
+ type='SimCCLabel',
41
+ input_size=(288, 384),
42
+ sigma=(6., 6.93),
43
+ simcc_split_ratio=2.0,
44
+ normalize=False,
45
+ use_dark=False)
46
+
47
+ # model settings
48
+ model = dict(
49
+ type='TopdownPoseEstimator',
50
+ data_preprocessor=dict(
51
+ type='PoseDataPreprocessor',
52
+ mean=[123.675, 116.28, 103.53],
53
+ std=[58.395, 57.12, 57.375],
54
+ bgr_to_rgb=True),
55
+ backbone=dict(
56
+ _scope_='mmdet',
57
+ type='CSPNeXt',
58
+ arch='P5',
59
+ expand_ratio=0.5,
60
+ deepen_factor=1.,
61
+ widen_factor=1.,
62
+ out_indices=(4, ),
63
+ channel_attention=True,
64
+ norm_cfg=dict(type='SyncBN'),
65
+ act_cfg=dict(type='SiLU'),
66
+ init_cfg=dict(
67
+ type='Pretrained',
68
+ prefix='backbone.',
69
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
70
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa
71
+ )),
72
+ head=dict(
73
+ type='RTMCCHead',
74
+ in_channels=1024,
75
+ out_channels=133,
76
+ input_size=codec['input_size'],
77
+ in_featuremap_size=(9, 12),
78
+ simcc_split_ratio=codec['simcc_split_ratio'],
79
+ final_layer_kernel_size=7,
80
+ gau_cfg=dict(
81
+ hidden_dims=256,
82
+ s=128,
83
+ expansion_factor=2,
84
+ dropout_rate=0.,
85
+ drop_path=0.,
86
+ act_fn='SiLU',
87
+ use_rel_bias=False,
88
+ pos_enc=False),
89
+ loss=dict(
90
+ type='KLDiscretLoss',
91
+ use_target_weight=True,
92
+ beta=10.,
93
+ label_softmax=True),
94
+ decoder=codec),
95
+ test_cfg=dict(flip_test=True, ))
96
+
97
+ # base dataset settings
98
+ dataset_type = 'CocoWholeBodyDataset'
99
+ data_mode = 'topdown'
100
+ data_root = '/data/'
101
+
102
+ backend_args = dict(backend='local')
103
+ # backend_args = dict(
104
+ # backend='petrel',
105
+ # path_mapping=dict({
106
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
107
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
108
+ # }))
109
+
110
+ # pipelines
111
+ train_pipeline = [
112
+ dict(type='LoadImage', backend_args=backend_args),
113
+ dict(type='GetBBoxCenterScale'),
114
+ dict(type='RandomFlip', direction='horizontal'),
115
+ dict(type='RandomHalfBody'),
116
+ dict(
117
+ type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
118
+ dict(type='TopdownAffine', input_size=codec['input_size']),
119
+ dict(type='mmdet.YOLOXHSVRandomAug'),
120
+ dict(
121
+ type='Albumentation',
122
+ transforms=[
123
+ dict(type='Blur', p=0.1),
124
+ dict(type='MedianBlur', p=0.1),
125
+ dict(
126
+ type='CoarseDropout',
127
+ max_holes=1,
128
+ max_height=0.4,
129
+ max_width=0.4,
130
+ min_holes=1,
131
+ min_height=0.2,
132
+ min_width=0.2,
133
+ p=1.0),
134
+ ]),
135
+ dict(type='GenerateTarget', encoder=codec),
136
+ dict(type='PackPoseInputs')
137
+ ]
138
+ val_pipeline = [
139
+ dict(type='LoadImage', backend_args=backend_args),
140
+ dict(type='GetBBoxCenterScale'),
141
+ dict(type='TopdownAffine', input_size=codec['input_size']),
142
+ dict(type='PackPoseInputs')
143
+ ]
144
+
145
+ train_pipeline_stage2 = [
146
+ dict(type='LoadImage', backend_args=backend_args),
147
+ dict(type='GetBBoxCenterScale'),
148
+ dict(type='RandomFlip', direction='horizontal'),
149
+ dict(type='RandomHalfBody'),
150
+ dict(
151
+ type='RandomBBoxTransform',
152
+ shift_factor=0.,
153
+ scale_factor=[0.75, 1.25],
154
+ rotate_factor=60),
155
+ dict(type='TopdownAffine', input_size=codec['input_size']),
156
+ dict(type='mmdet.YOLOXHSVRandomAug'),
157
+ dict(
158
+ type='Albumentation',
159
+ transforms=[
160
+ dict(type='Blur', p=0.1),
161
+ dict(type='MedianBlur', p=0.1),
162
+ dict(
163
+ type='CoarseDropout',
164
+ max_holes=1,
165
+ max_height=0.4,
166
+ max_width=0.4,
167
+ min_holes=1,
168
+ min_height=0.2,
169
+ min_width=0.2,
170
+ p=0.5),
171
+ ]),
172
+ dict(type='GenerateTarget', encoder=codec),
173
+ dict(type='PackPoseInputs')
174
+ ]
175
+
176
+ datasets = []
177
+ dataset_coco=dict(
178
+ type=dataset_type,
179
+ data_root=data_root,
180
+ data_mode=data_mode,
181
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
182
+ data_prefix=dict(img='coco/train2017/'),
183
+ pipeline=[],
184
+ )
185
+ datasets.append(dataset_coco)
186
+
187
+ scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
188
+ 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
189
+ 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
190
+
191
+ for i in range(len(scene)):
192
+ datasets.append(
193
+ dict(
194
+ type=dataset_type,
195
+ data_root=data_root,
196
+ data_mode=data_mode,
197
+ ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
198
+ data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
199
+ pipeline=[],
200
+ )
201
+ )
202
+
203
+ # data loaders
204
+ train_dataloader = dict(
205
+ batch_size=32,
206
+ num_workers=10,
207
+ persistent_workers=True,
208
+ sampler=dict(type='DefaultSampler', shuffle=True),
209
+ dataset=dict(
210
+ type='CombinedDataset',
211
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
212
+ datasets=datasets,
213
+ pipeline=train_pipeline,
214
+ test_mode=False,
215
+ ))
216
+ val_dataloader = dict(
217
+ batch_size=32,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type=dataset_type,
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file=f'{data_root}coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
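One relationship in the DWPose config above worth keeping in mind: `in_featuremap_size=(9, 12)` is simply the SimCC `input_size=(288, 384)` divided by the CSPNeXt backbone's overall stride of 32 (only stage 4 is used via `out_indices=(4, )`), so the two fields have to be changed together if the input resolution changes:

input_size = (288, 384)   # (w, h) from the codec settings above
backbone_stride = 32      # CSPNeXt, out_indices=(4, )
print(tuple(s // backbone_stride for s in input_size))   # -> (9, 12)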
pose/config/yolox_l_8xb8-300e_coco.py ADDED
@@ -0,0 +1,245 @@
1
+ img_scale = (640, 640) # width, height
2
+
3
+ # model settings
4
+ model = dict(
5
+ type='YOLOX',
6
+ data_preprocessor=dict(
7
+ type='DetDataPreprocessor',
8
+ pad_size_divisor=32,
9
+ batch_augments=[
10
+ dict(
11
+ type='BatchSyncRandomResize',
12
+ random_size_range=(480, 800),
13
+ size_divisor=32,
14
+ interval=10)
15
+ ]),
16
+ backbone=dict(
17
+ type='CSPDarknet',
18
+ deepen_factor=1.0,
19
+ widen_factor=1.0,
20
+ out_indices=(2, 3, 4),
21
+ use_depthwise=False,
22
+ spp_kernal_sizes=(5, 9, 13),
23
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
24
+ act_cfg=dict(type='Swish'),
25
+ ),
26
+ neck=dict(
27
+ type='YOLOXPAFPN',
28
+ in_channels=[256, 512, 1024],
29
+ out_channels=256,
30
+ num_csp_blocks=3,
31
+ use_depthwise=False,
32
+ upsample_cfg=dict(scale_factor=2, mode='nearest'),
33
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
34
+ act_cfg=dict(type='Swish')),
35
+ bbox_head=dict(
36
+ type='YOLOXHead',
37
+ num_classes=80,
38
+ in_channels=256,
39
+ feat_channels=256,
40
+ stacked_convs=2,
41
+ strides=(8, 16, 32),
42
+ use_depthwise=False,
43
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
44
+ act_cfg=dict(type='Swish'),
45
+ loss_cls=dict(
46
+ type='CrossEntropyLoss',
47
+ use_sigmoid=True,
48
+ reduction='sum',
49
+ loss_weight=1.0),
50
+ loss_bbox=dict(
51
+ type='IoULoss',
52
+ mode='square',
53
+ eps=1e-16,
54
+ reduction='sum',
55
+ loss_weight=5.0),
56
+ loss_obj=dict(
57
+ type='CrossEntropyLoss',
58
+ use_sigmoid=True,
59
+ reduction='sum',
60
+ loss_weight=1.0),
61
+ loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
62
+ train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
63
+ # In order to align the source code, the threshold of the val phase is
64
+ # 0.01, and the threshold of the test phase is 0.001.
65
+ test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
66
+
67
+ # dataset settings
68
+ data_root = 'data/coco/'
69
+ dataset_type = 'CocoDataset'
70
+
71
+ # Example to use different file client
72
+ # Method 1: simply set the data root and let the file I/O module
73
+ # automatically infer from prefix (not support LMDB and Memcache yet)
74
+
75
+ # data_root = 's3://openmmlab/datasets/detection/coco/'
76
+
77
+ # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
78
+ # backend_args = dict(
79
+ # backend='petrel',
80
+ # path_mapping=dict({
81
+ # './data/': 's3://openmmlab/datasets/detection/',
82
+ # 'data/': 's3://openmmlab/datasets/detection/'
83
+ # }))
84
+ backend_args = None
85
+
86
+ train_pipeline = [
87
+ dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
88
+ dict(
89
+ type='RandomAffine',
90
+ scaling_ratio_range=(0.1, 2),
91
+ # img_scale is (width, height)
92
+ border=(-img_scale[0] // 2, -img_scale[1] // 2)),
93
+ dict(
94
+ type='MixUp',
95
+ img_scale=img_scale,
96
+ ratio_range=(0.8, 1.6),
97
+ pad_val=114.0),
98
+ dict(type='YOLOXHSVRandomAug'),
99
+ dict(type='RandomFlip', prob=0.5),
100
+ # According to the official implementation, multi-scale
101
+ # training is not considered here but in the
102
+ # 'mmdet/models/detectors/yolox.py'.
103
+ # Resize and Pad are for the last 15 epochs when Mosaic,
104
+ # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
105
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
106
+ dict(
107
+ type='Pad',
108
+ pad_to_square=True,
109
+ # If the image is three-channel, the pad value needs
110
+ # to be set separately for each channel.
111
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
112
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
113
+ dict(type='PackDetInputs')
114
+ ]
115
+
116
+ train_dataset = dict(
117
+ # use MultiImageMixDataset wrapper to support mosaic and mixup
118
+ type='MultiImageMixDataset',
119
+ dataset=dict(
120
+ type=dataset_type,
121
+ data_root=data_root,
122
+ ann_file='annotations/instances_train2017.json',
123
+ data_prefix=dict(img='train2017/'),
124
+ pipeline=[
125
+ dict(type='LoadImageFromFile', backend_args=backend_args),
126
+ dict(type='LoadAnnotations', with_bbox=True)
127
+ ],
128
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
129
+ backend_args=backend_args),
130
+ pipeline=train_pipeline)
131
+
132
+ test_pipeline = [
133
+ dict(type='LoadImageFromFile', backend_args=backend_args),
134
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
135
+ dict(
136
+ type='Pad',
137
+ pad_to_square=True,
138
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
139
+ dict(type='LoadAnnotations', with_bbox=True),
140
+ dict(
141
+ type='PackDetInputs',
142
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
143
+ 'scale_factor'))
144
+ ]
145
+
146
+ train_dataloader = dict(
147
+ batch_size=8,
148
+ num_workers=4,
149
+ persistent_workers=True,
150
+ sampler=dict(type='DefaultSampler', shuffle=True),
151
+ dataset=train_dataset)
152
+ val_dataloader = dict(
153
+ batch_size=8,
154
+ num_workers=4,
155
+ persistent_workers=True,
156
+ drop_last=False,
157
+ sampler=dict(type='DefaultSampler', shuffle=False),
158
+ dataset=dict(
159
+ type=dataset_type,
160
+ data_root=data_root,
161
+ ann_file='annotations/instances_val2017.json',
162
+ data_prefix=dict(img='val2017/'),
163
+ test_mode=True,
164
+ pipeline=test_pipeline,
165
+ backend_args=backend_args))
166
+ test_dataloader = val_dataloader
167
+
168
+ val_evaluator = dict(
169
+ type='CocoMetric',
170
+ ann_file=data_root + 'annotations/instances_val2017.json',
171
+ metric='bbox',
172
+ backend_args=backend_args)
173
+ test_evaluator = val_evaluator
174
+
175
+ # training settings
176
+ max_epochs = 300
177
+ num_last_epochs = 15
178
+ interval = 10
179
+
180
+ train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
181
+
182
+ # optimizer
183
+ # default 8 gpu
184
+ base_lr = 0.01
185
+ optim_wrapper = dict(
186
+ type='OptimWrapper',
187
+ optimizer=dict(
188
+ type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
189
+ nesterov=True),
190
+ paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
191
+
192
+ # learning rate
193
+ param_scheduler = [
194
+ dict(
195
+ # use quadratic formula to warm up 5 epochs
196
+ # and lr is updated by iteration
197
+ # TODO: fix default scope in get function
198
+ type='mmdet.QuadraticWarmupLR',
199
+ by_epoch=True,
200
+ begin=0,
201
+ end=5,
202
+ convert_to_iter_based=True),
203
+ dict(
204
+ # use cosine lr from 5 to 285 epoch
205
+ type='CosineAnnealingLR',
206
+ eta_min=base_lr * 0.05,
207
+ begin=5,
208
+ T_max=max_epochs - num_last_epochs,
209
+ end=max_epochs - num_last_epochs,
210
+ by_epoch=True,
211
+ convert_to_iter_based=True),
212
+ dict(
213
+ # use fixed lr during last 15 epochs
214
+ type='ConstantLR',
215
+ by_epoch=True,
216
+ factor=1,
217
+ begin=max_epochs - num_last_epochs,
218
+ end=max_epochs,
219
+ )
220
+ ]
221
+
222
+ default_hooks = dict(
223
+ checkpoint=dict(
224
+ interval=interval,
225
+ max_keep_ckpts=3 # only keep latest 3 checkpoints
226
+ ))
227
+
228
+ custom_hooks = [
229
+ dict(
230
+ type='YOLOXModeSwitchHook',
231
+ num_last_epochs=num_last_epochs,
232
+ priority=48),
233
+ dict(type='SyncNormHook', priority=48),
234
+ dict(
235
+ type='EMAHook',
236
+ ema_type='ExpMomentumEMA',
237
+ momentum=0.0001,
238
+ update_buffers=True,
239
+ priority=49)
240
+ ]
241
+
242
+ # NOTE: `auto_scale_lr` is for automatically scaling LR,
243
+ # USER SHOULD NOT CHANGE ITS VALUES.
244
+ # base_batch_size = (8 GPUs) x (8 samples per GPU)
245
+ auto_scale_lr = dict(base_batch_size=64)
pose/script/dwpose.py ADDED
@@ -0,0 +1,143 @@
1
+ # Openpose
2
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4
+ # 3rd Edited by ControlNet
5
+ # 4th Edited by ControlNet (added face and correct hands)
6
+
7
+ import os
8
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9
+
10
+ import cv2
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+
16
+ import pose.script.util as util
17
+
18
+ def resize_image(input_image, resolution):
19
+ H, W, C = input_image.shape
20
+ H = float(H)
21
+ W = float(W)
22
+ k = float(resolution) / min(H, W)
23
+ H *= k
24
+ W *= k
25
+ H = int(np.round(H / 64.0)) * 64
26
+ W = int(np.round(W / 64.0)) * 64
27
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
28
+ return img
29
+
30
+ def HWC3(x):
31
+ assert x.dtype == np.uint8
32
+ if x.ndim == 2:
33
+ x = x[:, :, None]
34
+ assert x.ndim == 3
35
+ H, W, C = x.shape
36
+ assert C == 1 or C == 3 or C == 4
37
+ if C == 3:
38
+ return x
39
+ if C == 1:
40
+ return np.concatenate([x, x, x], axis=2)
41
+ if C == 4:
42
+ color = x[:, :, 0:3].astype(np.float32)
43
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
44
+ y = color * alpha + 255.0 * (1.0 - alpha)
45
+ y = y.clip(0, 255).astype(np.uint8)
46
+ return y
47
+
48
+ def draw_pose(pose, H, W, draw_face):
49
+ bodies = pose['bodies']
50
+ faces = pose['faces']
51
+ hands = pose['hands']
52
+ candidate = bodies['candidate']
53
+ subset = bodies['subset']
54
+
55
+ # only the most significant person
56
+ faces = pose['faces'][:1]
57
+ hands = pose['hands'][:2]
58
+ candidate = bodies['candidate'][:18]
59
+ subset = bodies['subset'][:1]
60
+
61
+ # draw
62
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
63
+ canvas = util.draw_bodypose(canvas, candidate, subset)
64
+ canvas = util.draw_handpose(canvas, hands)
65
+ if draw_face:
66
+ canvas = util.draw_facepose(canvas, faces)
67
+
68
+ return canvas
69
+
70
+ class DWposeDetector:
71
+ def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu", keypoints_only=False):
72
+ from pose.script.wholebody import Wholebody
73
+
74
+ self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
75
+ self.keypoints_only = keypoints_only
76
+ def to(self, device):
77
+ self.pose_estimation.to(device)
78
+ return self
79
+ '''
80
+ detect_resolution: resize the short edge to this size; this is the raw rendering resolution used when drawing the pose. Recommended: 1024
81
+ image_resolution: resize the short edge to this size; this is the file resolution used when saving the pose. Recommended: 768
82
+
83
+ Actual detection resolutions:
84
+ yolox: (640, 640)
85
+ dwpose: (288, 384)
86
+ '''
87
+
88
+ def __call__(self, input_image, detect_resolution=1024, image_resolution=768, output_type="pil", **kwargs):
89
+
90
+ input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
91
+ # cv2.imshow('', input_image)
92
+ # cv2.waitKey(0)
93
+
94
+ input_image = HWC3(input_image)
95
+ input_image = resize_image(input_image, detect_resolution)
96
+ H, W, C = input_image.shape
97
+
98
+ with torch.no_grad():
99
+ candidate, subset = self.pose_estimation(input_image)
100
+ nums, keys, locs = candidate.shape
101
+ candidate[..., 0] /= float(W)
102
+ candidate[..., 1] /= float(H)
103
+ body = candidate[:,:18].copy()
104
+ body = body.reshape(nums*18, locs)
105
+ score = subset[:,:18]
106
+
107
+ for i in range(len(score)):
108
+ for j in range(len(score[i])):
109
+ if score[i][j] > 0.3:
110
+ score[i][j] = int(18*i+j)
111
+ else:
112
+ score[i][j] = -1
113
+
114
+ un_visible = subset<0.3
115
+ candidate[un_visible] = -1
116
+
117
+ foot = candidate[:,18:24]
118
+
119
+ faces = candidate[:,24:92]
120
+
121
+ hands = candidate[:,92:113]
122
+ hands = np.vstack([hands, candidate[:,113:]])
123
+
124
+ bodies = dict(candidate=body, subset=score)
125
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
126
+
127
+ if self.keypoints_only:
128
+ return pose
129
+ else:
130
+ detected_map = draw_pose(pose, H, W, draw_face=False)
131
+ detected_map = HWC3(detected_map)
132
+ img = resize_image(input_image, image_resolution)
133
+ H, W, C = img.shape
134
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
135
+ # cv2.imshow('detected_map',detected_map)
136
+ # cv2.waitKey(0)
137
+
138
+ if output_type == "pil":
139
+ detected_map = cv2.cvtColor(detected_map, cv2.COLOR_BGR2RGB)
140
+ detected_map = Image.fromarray(detected_map)
141
+
142
+ return detected_map, pose
143
+
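Note: a minimal usage sketch for the detector defined above. The config and checkpoint paths are the defaults also used by pose_align.py later in this commit; the input image path and output filename are placeholders.

import cv2
from pose.script.dwpose import DWposeDetector

# Build the detector; weights are expected under ./pretrained_weights/dwpose/.
detector = DWposeDetector(
    det_config="./pose/config/yolox_l_8xb8-300e_coco.py",
    det_ckpt="./pretrained_weights/dwpose/yolox_l_8x8_300e_coco.pth",
    pose_config="./pose/config/dwpose-l_384x288.py",
    pose_ckpt="./pretrained_weights/dwpose/dw-ll_ucoco_384.pth",
).to("cuda")

# __call__ expects an RGB array; cv2.imread returns BGR, so convert first.
img = cv2.cvtColor(cv2.imread("my_ref.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path
pose_image, pose = detector(img, detect_resolution=1024, image_resolution=768, output_type="pil")
pose_image.save("my_ref_pose.png")  # rendered body/hand skeleton (faces are not drawn)
# pose["bodies"], pose["hands"], pose["faces"] hold the normalized keypoints.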
pose/script/tool.py ADDED
@@ -0,0 +1,130 @@
1
+ import importlib
2
+ import os
3
+ import os.path as osp
4
+ import shutil
5
+ import sys
6
+ from pathlib import Path
7
+
8
+ import av
9
+ import numpy as np
10
+ import torch
11
+ import torchvision
12
+ from einops import rearrange
13
+ from PIL import Image
14
+
15
+
16
+ def seed_everything(seed):
17
+ import random
18
+
19
+ import numpy as np
20
+
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed_all(seed)
23
+ np.random.seed(seed % (2**32))
24
+ random.seed(seed)
25
+
26
+
27
+ def import_filename(filename):
28
+ spec = importlib.util.spec_from_file_location("mymodule", filename)
29
+ module = importlib.util.module_from_spec(spec)
30
+ sys.modules[spec.name] = module
31
+ spec.loader.exec_module(module)
32
+ return module
33
+
34
+
35
+ def delete_additional_ckpt(base_path, num_keep):
36
+ dirs = []
37
+ for d in os.listdir(base_path):
38
+ if d.startswith("checkpoint-"):
39
+ dirs.append(d)
40
+ num_tot = len(dirs)
41
+ if num_tot <= num_keep:
42
+ return
43
+ # ensure checkpoints are sorted and delete the earlier ones!
44
+ del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
45
+ for d in del_dirs:
46
+ path_to_dir = osp.join(base_path, d)
47
+ if osp.exists(path_to_dir):
48
+ shutil.rmtree(path_to_dir)
49
+
50
+
51
+ def save_videos_from_pil(pil_images, path, fps):
52
+
53
+ save_fmt = Path(path).suffix
54
+ os.makedirs(os.path.dirname(path), exist_ok=True)
55
+ width, height = pil_images[0].size
56
+
57
+ if save_fmt == ".mp4":
58
+ codec = "libx264"
59
+ container = av.open(path, "w")
60
+ stream = container.add_stream(codec, rate=fps)
61
+
62
+ stream.width = width
63
+ stream.height = height
64
+ stream.pix_fmt = 'yuv420p'
65
+ stream.bit_rate = 10000000
66
+ stream.options["crf"] = "18"
67
+
68
+ for pil_image in pil_images:
69
+ # pil_image = Image.fromarray(image_arr).convert("RGB")
70
+ av_frame = av.VideoFrame.from_image(pil_image)
71
+ container.mux(stream.encode(av_frame))
72
+ container.mux(stream.encode())
73
+ container.close()
74
+
75
+ elif save_fmt == ".gif":
76
+ pil_images[0].save(
77
+ fp=path,
78
+ format="GIF",
79
+ append_images=pil_images[1:],
80
+ save_all=True,
81
+ duration=(1 / fps * 1000),
82
+ loop=0,
83
+ )
84
+ else:
85
+ raise ValueError("Unsupported file type. Use .mp4 or .gif.")
86
+
87
+
88
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
89
+ videos = rearrange(videos, "b c t h w -> t b c h w")
90
+ height, width = videos.shape[-2:]
91
+ outputs = []
92
+
93
+ for x in videos:
94
+ x = torchvision.utils.make_grid(x, nrow=n_rows) # (c h w)
95
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1) # (h w c)
96
+ if rescale:
97
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
98
+ x = (x * 255).numpy().astype(np.uint8)
99
+ x = Image.fromarray(x)
100
+
101
+ outputs.append(x)
102
+
103
+ os.makedirs(os.path.dirname(path), exist_ok=True)
104
+
105
+ save_videos_from_pil(outputs, path, fps)
106
+
107
+
108
+ def read_frames(video_path):
109
+ container = av.open(video_path)
110
+
111
+ video_stream = next(s for s in container.streams if s.type == "video")
112
+ frames = []
113
+ for packet in container.demux(video_stream):
114
+ for frame in packet.decode():
115
+ image = Image.frombytes(
116
+ "RGB",
117
+ (frame.width, frame.height),
118
+ frame.to_rgb().to_ndarray(),
119
+ )
120
+ frames.append(image)
121
+
122
+ return frames
123
+
124
+
125
+ def get_fps(video_path):
126
+ container = av.open(video_path)
127
+ video_stream = next(s for s in container.streams if s.type == "video")
128
+ fps = video_stream.average_rate
129
+ container.close()
130
+ return fps
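Note: a short sketch of how the video helpers above fit together; the file paths are placeholders, and the output path must include a directory because save_videos_from_pil creates its parent folder.

from pose.script.tool import read_frames, get_fps, save_videos_from_pil

frames = read_frames("input.mp4")                      # list of PIL.Image frames decoded with PyAV
fps = get_fps("input.mp4")                             # average frame rate of the video stream
# Re-encode with libx264 (crf 18); assumes even frame dimensions for yuv420p.
save_videos_from_pil(frames, "output/copy.mp4", fps)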
pose/script/util.py ADDED
@@ -0,0 +1,153 @@
1
+ import math
2
+ import numpy as np
3
+ import cv2
4
+
5
+
6
+ eps = 0.01
7
+
8
+ def smart_width(d):
9
+ if d<5:
10
+ return 1
11
+ elif d<10:
12
+ return 2
13
+ elif d<20:
14
+ return 3
15
+ elif d<40:
16
+ return 4
17
+ elif d<80:
18
+ return 5
19
+ elif d<160:
20
+ return 6
21
+ elif d<320:
22
+ return 7
23
+ else:
24
+ return 8
25
+
26
+
27
+
28
+ def draw_bodypose(canvas, candidate, subset):
29
+ H, W, C = canvas.shape
30
+ candidate = np.array(candidate)
31
+ subset = np.array(subset)
32
+
33
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
34
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
35
+ [1, 16], [16, 18], [3, 17], [6, 18]]
36
+
37
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
38
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
39
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
40
+
41
+ for i in range(17):
42
+ for n in range(len(subset)):
43
+ index = subset[n][np.array(limbSeq[i]) - 1]
44
+ if -1 in index:
45
+ continue
46
+ Y = candidate[index.astype(int), 0] * float(W)
47
+ X = candidate[index.astype(int), 1] * float(H)
48
+ mX = np.mean(X)
49
+ mY = np.mean(Y)
50
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
51
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
52
+
53
+ width = smart_width(length)
54
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), width), int(angle), 0, 360, 1)
55
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
56
+
57
+ canvas = (canvas * 0.6).astype(np.uint8)
58
+
59
+ for i in range(18):
60
+ for n in range(len(subset)):
61
+ index = int(subset[n][i])
62
+ if index == -1:
63
+ continue
64
+ x, y = candidate[index][0:2]
65
+ x = int(x * W)
66
+ y = int(y * H)
67
+ radius = 4
68
+ cv2.circle(canvas, (int(x), int(y)), radius, colors[i], thickness=-1)
69
+
70
+ return canvas
71
+
72
+
73
+ def draw_handpose(canvas, all_hand_peaks):
74
+ import matplotlib
75
+
76
+ H, W, C = canvas.shape
77
+
78
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
79
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
80
+
81
+ # (person_number*2, 21, 2)
82
+ for i in range(len(all_hand_peaks)):
83
+ peaks = all_hand_peaks[i]
84
+ peaks = np.array(peaks)
85
+
86
+ for ie, e in enumerate(edges):
87
+
88
+ x1, y1 = peaks[e[0]]
89
+ x2, y2 = peaks[e[1]]
90
+
91
+ x1 = int(x1 * W)
92
+ y1 = int(y1 * H)
93
+ x2 = int(x2 * W)
94
+ y2 = int(y2 * H)
95
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
96
+ length = ((x1 - x2) ** 2 + (y1 - y2) ** 2) ** 0.5
97
+ width = smart_width(length)
98
+ cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=width)
99
+
100
+ for _, keypoint in enumerate(peaks):
101
+ x, y = keypoint
102
+
103
+ x = int(x * W)
104
+ y = int(y * H)
105
+ if x > eps and y > eps:
106
+ radius = 3
107
+ cv2.circle(canvas, (x, y), radius, (0, 0, 255), thickness=-1)
108
+ return canvas
109
+
110
+
111
+ def draw_facepose(canvas, all_lmks):
112
+ H, W, C = canvas.shape
113
+ for lmks in all_lmks:
114
+ lmks = np.array(lmks)
115
+ for lmk in lmks:
116
+ x, y = lmk
117
+ x = int(x * W)
118
+ y = int(y * H)
119
+ if x > eps and y > eps:
120
+ radius = 3
121
+ cv2.circle(canvas, (x, y), radius, (255, 255, 255), thickness=-1)
122
+ return canvas
123
+
124
+
125
+
126
+
127
+ # Calculate the resolution
128
+ def size_calculate(h, w, resolution):
129
+
130
+ H = float(h)
131
+ W = float(w)
132
+
133
+ # resize the short edge to the resolution
134
+ k = float(resolution) / min(H, W) # short edge
135
+ H *= k
136
+ W *= k
137
+
138
+ # resize to the nearest integer multiple of 64
139
+ H = int(np.round(H / 64.0)) * 64
140
+ W = int(np.round(W / 64.0)) * 64
141
+ return H, W
142
+
143
+
144
+
145
+ def warpAffine_kps(kps, M):
146
+ a = M[:,:2]
147
+ t = M[:,2]
148
+ kps = np.dot(kps, a.T) + t
149
+ return kps
150
+
151
+
152
+
153
+
pose/script/wholebody.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import numpy as np
4
+ import warnings
5
+
6
+ try:
7
+ import mmcv
8
+ except ImportError:
9
+ warnings.warn(
10
+ "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'"
11
+ )
12
+
13
+ try:
14
+ from mmpose.apis import inference_topdown
15
+ from mmpose.apis import init_model as init_pose_estimator
16
+ from mmpose.evaluation.functional import nms
17
+ from mmpose.utils import adapt_mmdet_pipeline
18
+ from mmpose.structures import merge_data_samples
19
+ except ImportError:
20
+ warnings.warn(
21
+ "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'"
22
+ )
23
+
24
+ try:
25
+ from mmdet.apis import inference_detector, init_detector
26
+ except ImportError:
27
+ warnings.warn(
28
+ "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'"
29
+ )
30
+
31
+
32
+ class Wholebody:
33
+ def __init__(self,
34
+ det_config=None, det_ckpt=None,
35
+ pose_config=None, pose_ckpt=None,
36
+ device="cpu"):
37
+
38
+ if det_config is None:
39
+ det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")
40
+
41
+ if pose_config is None:
42
+ pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")
43
+
44
+ if det_ckpt is None:
45
+ det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
46
+
47
+ if pose_ckpt is None:
48
+ pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
49
+
50
+ # build detector
51
+ self.detector = init_detector(det_config, det_ckpt, device=device)
52
+ self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
53
+
54
+ # build pose estimator
55
+ self.pose_estimator = init_pose_estimator(
56
+ pose_config,
57
+ pose_ckpt,
58
+ device=device)
59
+
60
+ def to(self, device):
61
+ self.detector.to(device)
62
+ self.pose_estimator.to(device)
63
+ return self
64
+
65
+ def __call__(self, oriImg):
66
+ # predict bbox
67
+ det_result = inference_detector(self.detector, oriImg)
68
+ pred_instance = det_result.pred_instances.cpu().numpy()
69
+ bboxes = np.concatenate(
70
+ (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
71
+ bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
72
+ pred_instance.scores > 0.5)]
73
+
74
+ # set NMS threshold
75
+ bboxes = bboxes[nms(bboxes, 0.7), :4]
76
+
77
+ # predict keypoints
78
+ if len(bboxes) == 0:
79
+ pose_results = inference_topdown(self.pose_estimator, oriImg)
80
+ else:
81
+ pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
82
+ preds = merge_data_samples(pose_results)
83
+ preds = preds.pred_instances
84
+
85
+ # preds = pose_results[0].pred_instances
86
+ keypoints = preds.get('transformed_keypoints',
87
+ preds.keypoints)
88
+ if 'keypoint_scores' in preds:
89
+ scores = preds.keypoint_scores
90
+ else:
91
+ scores = np.ones(keypoints.shape[:-1])
92
+
93
+ if 'keypoints_visible' in preds:
94
+ visible = preds.keypoints_visible
95
+ else:
96
+ visible = np.ones(keypoints.shape[:-1])
97
+ keypoints_info = np.concatenate(
98
+ (keypoints, scores[..., None], visible[..., None]),
99
+ axis=-1)
100
+ # compute neck joint
101
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
102
+ # neck score when visualizing pred
103
+ neck[:, 2:4] = np.logical_and(
104
+ keypoints_info[:, 5, 2:4] > 0.3,
105
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
106
+ new_keypoints_info = np.insert(
107
+ keypoints_info, 17, neck, axis=1)
108
+ mmpose_idx = [
109
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
110
+ ]
111
+ openpose_idx = [
112
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
113
+ ]
114
+ new_keypoints_info[:, openpose_idx] = \
115
+ new_keypoints_info[:, mmpose_idx]
116
+ keypoints_info = new_keypoints_info
117
+
118
+ keypoints, scores, visible = keypoints_info[
119
+ ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
120
+
121
+ return keypoints, scores
pose_align.py ADDED
@@ -0,0 +1,556 @@
1
+ import numpy as np
2
+ import argparse
3
+ import torch
4
+ import copy
5
+ import cv2
6
+ import os
7
+ import moviepy.video.io.ImageSequenceClip
8
+
9
+ from pose.script.dwpose import DWposeDetector, draw_pose
10
+ from pose.script.util import size_calculate, warpAffine_kps
11
+
12
+
13
+
14
+ '''
15
+ Detect dwpose from img, then align it by scale parameters
16
+ img: frame from the pose video
17
+ detector: DWpose
18
+ scales: scale parameters
19
+ '''
20
+ def align_img(img, pose_ori, scales, detect_resolution, image_resolution):
21
+
22
+ body_pose = copy.deepcopy(pose_ori['bodies']['candidate'])
23
+ hands = copy.deepcopy(pose_ori['hands'])
24
+ faces = copy.deepcopy(pose_ori['faces'])
25
+
26
+ '''
27
+ Computation logic:
28
+ 0. This function applies absolute transforms while always keeping the body centre point body_pose[1] fixed.
29
+ 1. First resize the heights of ref and pose to the same value, each keeping its original aspect ratio.
30
+ 2. Compute with the actual coordinates of the points in the image.
31
+ 3. In practice, normalize the h coordinate to [0, 1] and w to [0, W/H].
32
+ 4. Since the dwpose output is already normalized, h stays unchanged and w is multiplied by W/H.
33
+ Note: dwpose outputs (w, h).
34
+ '''
35
+
36
+ # keep h unchanged; scale w back to the original aspect ratio
37
+ H_in, W_in, C_in = img.shape
38
+ video_ratio = W_in / H_in
39
+ body_pose[:, 0] = body_pose[:, 0] * video_ratio
40
+ hands[:, :, 0] = hands[:, :, 0] * video_ratio
41
+ faces[:, :, 0] = faces[:, :, 0] * video_ratio
42
+
43
+ # scales of 10 body parts
44
+ scale_neck = scales["scale_neck"]
45
+ scale_face = scales["scale_face"]
46
+ scale_shoulder = scales["scale_shoulder"]
47
+ scale_arm_upper = scales["scale_arm_upper"]
48
+ scale_arm_lower = scales["scale_arm_lower"]
49
+ scale_hand = scales["scale_hand"]
50
+ scale_body_len = scales["scale_body_len"]
51
+ scale_leg_upper = scales["scale_leg_upper"]
52
+ scale_leg_lower = scales["scale_leg_lower"]
53
+
54
+ scale_sum = 0
55
+ count = 0
56
+ scale_list = [scale_neck, scale_face, scale_shoulder, scale_arm_upper, scale_arm_lower, scale_hand, scale_body_len, scale_leg_upper, scale_leg_lower]
57
+ for i in range(len(scale_list)):
58
+ if not np.isinf(scale_list[i]):
59
+ scale_sum = scale_sum + scale_list[i]
60
+ count = count + 1
61
+ for i in range(len(scale_list)):
62
+ if np.isinf(scale_list[i]):
63
+ scale_list[i] = scale_sum/count
64
+
65
+
66
+
67
+ # offsets of each part
68
+ offset = dict()
69
+ offset["14_15_16_17_to_0"] = body_pose[[14,15,16,17], :] - body_pose[[0], :]
70
+ offset["3_to_2"] = body_pose[[3], :] - body_pose[[2], :]
71
+ offset["4_to_3"] = body_pose[[4], :] - body_pose[[3], :]
72
+ offset["6_to_5"] = body_pose[[6], :] - body_pose[[5], :]
73
+ offset["7_to_6"] = body_pose[[7], :] - body_pose[[6], :]
74
+ offset["9_to_8"] = body_pose[[9], :] - body_pose[[8], :]
75
+ offset["10_to_9"] = body_pose[[10], :] - body_pose[[9], :]
76
+ offset["12_to_11"] = body_pose[[12], :] - body_pose[[11], :]
77
+ offset["13_to_12"] = body_pose[[13], :] - body_pose[[12], :]
78
+ offset["hand_left_to_4"] = hands[1, :, :] - body_pose[[4], :]
79
+ offset["hand_right_to_7"] = hands[0, :, :] - body_pose[[7], :]
80
+
81
+ # neck
82
+ c_ = body_pose[1]
83
+ cx = c_[0]
84
+ cy = c_[1]
85
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_neck)
86
+
87
+ neck = body_pose[[0], :]
88
+ neck = warpAffine_kps(neck, M)
89
+ body_pose[[0], :] = neck
90
+
91
+ # body_pose_up_shoulder
92
+ c_ = body_pose[0]
93
+ cx = c_[0]
94
+ cy = c_[1]
95
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_face)
96
+
97
+ body_pose_up_shoulder = offset["14_15_16_17_to_0"] + body_pose[[0], :]
98
+ body_pose_up_shoulder = warpAffine_kps(body_pose_up_shoulder, M)
99
+ body_pose[[14,15,16,17], :] = body_pose_up_shoulder
100
+
101
+ # shoulder
102
+ c_ = body_pose[1]
103
+ cx = c_[0]
104
+ cy = c_[1]
105
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_shoulder)
106
+
107
+ body_pose_shoulder = body_pose[[2,5], :]
108
+ body_pose_shoulder = warpAffine_kps(body_pose_shoulder, M)
109
+ body_pose[[2,5], :] = body_pose_shoulder
110
+
111
+ # arm upper left
112
+ c_ = body_pose[2]
113
+ cx = c_[0]
114
+ cy = c_[1]
115
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_arm_upper)
116
+
117
+ elbow = offset["3_to_2"] + body_pose[[2], :]
118
+ elbow = warpAffine_kps(elbow, M)
119
+ body_pose[[3], :] = elbow
120
+
121
+ # arm lower left
122
+ c_ = body_pose[3]
123
+ cx = c_[0]
124
+ cy = c_[1]
125
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_arm_lower)
126
+
127
+ wrist = offset["4_to_3"] + body_pose[[3], :]
128
+ wrist = warpAffine_kps(wrist, M)
129
+ body_pose[[4], :] = wrist
130
+
131
+ # hand left
132
+ c_ = body_pose[4]
133
+ cx = c_[0]
134
+ cy = c_[1]
135
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_hand)
136
+
137
+ hand = offset["hand_left_to_4"] + body_pose[[4], :]
138
+ hand = warpAffine_kps(hand, M)
139
+ hands[1, :, :] = hand
140
+
141
+ # arm upper right
142
+ c_ = body_pose[5]
143
+ cx = c_[0]
144
+ cy = c_[1]
145
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_arm_upper)
146
+
147
+ elbow = offset["6_to_5"] + body_pose[[5], :]
148
+ elbow = warpAffine_kps(elbow, M)
149
+ body_pose[[6], :] = elbow
150
+
151
+ # arm lower right
152
+ c_ = body_pose[6]
153
+ cx = c_[0]
154
+ cy = c_[1]
155
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_arm_lower)
156
+
157
+ wrist = offset["7_to_6"] + body_pose[[6], :]
158
+ wrist = warpAffine_kps(wrist, M)
159
+ body_pose[[7], :] = wrist
160
+
161
+ # hand right
162
+ c_ = body_pose[7]
163
+ cx = c_[0]
164
+ cy = c_[1]
165
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_hand)
166
+
167
+ hand = offset["hand_right_to_7"] + body_pose[[7], :]
168
+ hand = warpAffine_kps(hand, M)
169
+ hands[0, :, :] = hand
170
+
171
+ # body len
172
+ c_ = body_pose[1]
173
+ cx = c_[0]
174
+ cy = c_[1]
175
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_body_len)
176
+
177
+ body_len = body_pose[[8,11], :]
178
+ body_len = warpAffine_kps(body_len, M)
179
+ body_pose[[8,11], :] = body_len
180
+
181
+ # leg upper left
182
+ c_ = body_pose[8]
183
+ cx = c_[0]
184
+ cy = c_[1]
185
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_leg_upper)
186
+
187
+ knee = offset["9_to_8"] + body_pose[[8], :]
188
+ knee = warpAffine_kps(knee, M)
189
+ body_pose[[9], :] = knee
190
+
191
+ # leg lower left
192
+ c_ = body_pose[9]
193
+ cx = c_[0]
194
+ cy = c_[1]
195
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_leg_lower)
196
+
197
+ ankle = offset["10_to_9"] + body_pose[[9], :]
198
+ ankle = warpAffine_kps(ankle, M)
199
+ body_pose[[10], :] = ankle
200
+
201
+ # leg upper right
202
+ c_ = body_pose[11]
203
+ cx = c_[0]
204
+ cy = c_[1]
205
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_leg_upper)
206
+
207
+ knee = offset["12_to_11"] + body_pose[[11], :]
208
+ knee = warpAffine_kps(knee, M)
209
+ body_pose[[12], :] = knee
210
+
211
+ # leg lower right
212
+ c_ = body_pose[12]
213
+ cx = c_[0]
214
+ cy = c_[1]
215
+ M = cv2.getRotationMatrix2D((cx,cy), 0, scale_leg_lower)
216
+
217
+ ankle = offset["13_to_12"] + body_pose[[12], :]
218
+ ankle = warpAffine_kps(ankle, M)
219
+ body_pose[[13], :] = ankle
220
+
221
+ # none part
222
+ body_pose_none = pose_ori['bodies']['candidate'] == -1.
223
+ hands_none = pose_ori['hands'] == -1.
224
+ faces_none = pose_ori['faces'] == -1.
225
+
226
+ body_pose[body_pose_none] = -1.
227
+ hands[hands_none] = -1.
228
+ nan = float('nan')
229
+ if len(hands[np.isnan(hands)]) > 0:
230
+ print('nan')
231
+ faces[faces_none] = -1.
232
+
233
+ # last check nan -> -1.
234
+ body_pose = np.nan_to_num(body_pose, nan=-1.)
235
+ hands = np.nan_to_num(hands, nan=-1.)
236
+ faces = np.nan_to_num(faces, nan=-1.)
237
+
238
+ # return
239
+ pose_align = copy.deepcopy(pose_ori)
240
+ pose_align['bodies']['candidate'] = body_pose
241
+ pose_align['hands'] = hands
242
+ pose_align['faces'] = faces
243
+
244
+ return pose_align
245
+
246
+
247
+
248
+ def run_align_video_with_filterPose_translate_smooth(args):
249
+
250
+ vidfn=args.vidfn
251
+ imgfn_refer=args.imgfn_refer
252
+ outfn=args.outfn
253
+
254
+ video = cv2.VideoCapture(vidfn)
255
+ width= video.get(cv2.CAP_PROP_FRAME_WIDTH)
256
+ height= video.get(cv2.CAP_PROP_FRAME_HEIGHT)
257
+
258
+ total_frame= video.get(cv2.CAP_PROP_FRAME_COUNT)
259
+ fps= video.get(cv2.CAP_PROP_FPS)
260
+
261
+ print("height:", height)
262
+ print("width:", width)
263
+ print("fps:", fps)
264
+
265
+ H_in, W_in = height, width
266
+ H_out, W_out = size_calculate(H_in,W_in,args.detect_resolution)
267
+ H_out, W_out = size_calculate(H_out,W_out,args.image_resolution)
268
+
269
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
270
+ detector = DWposeDetector(
271
+ det_config = args.yolox_config,
272
+ det_ckpt = args.yolox_ckpt,
273
+ pose_config = args.dwpose_config,
274
+ pose_ckpt = args.dwpose_ckpt,
275
+ keypoints_only=False
276
+ )
277
+ detector = detector.to(device)
278
+
279
+ refer_img = cv2.imread(imgfn_refer)
280
+ output_refer, pose_refer = detector(refer_img,detect_resolution=args.detect_resolution, image_resolution=args.image_resolution, output_type='cv2',return_pose_dict=True)
281
+ body_ref_img = pose_refer['bodies']['candidate']
282
+ hands_ref_img = pose_refer['hands']
283
+ faces_ref_img = pose_refer['faces']
284
+ output_refer = cv2.cvtColor(output_refer, cv2.COLOR_RGB2BGR)
285
+
286
+
287
+ skip_frames = args.align_frame
288
+ max_frame = args.max_frame
289
+ pose_list, video_frame_buffer, video_pose_buffer = [], [], []
290
+
291
+
292
+ cap = cv2.VideoCapture('2.mp4') # open the video (leftover debug preview)
293
+ while cap.isOpened(): # while the video is open:
294
+ ret, frame = cap.read() # read one frame into frame; ret is True on success, False otherwise
295
+ if ret: # if the frame was read successfully
296
+ cv2.imshow('frame', frame) # display this frame
297
+ key = cv2.waitKey(25) # wait briefly and poll the keyboard
298
+ if key == ord('q'): # if 'q' is pressed, release the video and exit
299
+ cap.release() # release the video
300
+ break
301
+ else:
302
+ cap.release()
303
+ cv2.destroyAllWindows() # close all windows
304
+
305
+
306
+ for i in range(max_frame):
307
+ ret, img = video.read()
308
+ if img is None:
309
+ break
310
+ else:
311
+ if i < skip_frames:
312
+ continue
313
+ video_frame_buffer.append(img)
314
+
315
+
316
+
317
+ # estimate scale parameters by the 1st frame in the video
318
+ if i==skip_frames:
319
+ output_1st_img, pose_1st_img = detector(img, args.detect_resolution, args.image_resolution, output_type='cv2', return_pose_dict=True)
320
+ body_1st_img = pose_1st_img['bodies']['candidate']
321
+ hands_1st_img = pose_1st_img['hands']
322
+ faces_1st_img = pose_1st_img['faces']
323
+
324
+ '''
325
+ Computation logic:
326
+ 1. First resize the heights of ref and pose to the same value, each keeping its original aspect ratio.
327
+ 2. Compute with the actual coordinates of the points in the image.
328
+ 3. In practice, normalize the h coordinate to [0, 1] and w to [0, W/H].
329
+ 4. Since the dwpose output is already normalized, h stays unchanged and w is multiplied by W/H.
330
+ Note: dwpose outputs (w, h).
331
+ '''
332
+
333
+ # keep h unchanged; scale w back to the original aspect ratio
334
+ ref_H, ref_W = refer_img.shape[0], refer_img.shape[1]
335
+ ref_ratio = ref_W / ref_H
336
+ body_ref_img[:, 0] = body_ref_img[:, 0] * ref_ratio
337
+ hands_ref_img[:, :, 0] = hands_ref_img[:, :, 0] * ref_ratio
338
+ faces_ref_img[:, :, 0] = faces_ref_img[:, :, 0] * ref_ratio
339
+
340
+ video_ratio = width / height
341
+ body_1st_img[:, 0] = body_1st_img[:, 0] * video_ratio
342
+ hands_1st_img[:, :, 0] = hands_1st_img[:, :, 0] * video_ratio
343
+ faces_1st_img[:, :, 0] = faces_1st_img[:, :, 0] * video_ratio
344
+
345
+ # scale
346
+ align_args = dict()
347
+
348
+ dist_1st_img = np.linalg.norm(body_1st_img[0]-body_1st_img[1]) # 0.078
349
+ dist_ref_img = np.linalg.norm(body_ref_img[0]-body_ref_img[1]) # 0.106
350
+ align_args["scale_neck"] = dist_ref_img / dist_1st_img # align / pose = ref / 1st
351
+
352
+ dist_1st_img = np.linalg.norm(body_1st_img[16]-body_1st_img[17])
353
+ dist_ref_img = np.linalg.norm(body_ref_img[16]-body_ref_img[17])
354
+ align_args["scale_face"] = dist_ref_img / dist_1st_img
355
+
356
+ dist_1st_img = np.linalg.norm(body_1st_img[2]-body_1st_img[5]) # 0.112
357
+ dist_ref_img = np.linalg.norm(body_ref_img[2]-body_ref_img[5]) # 0.174
358
+ align_args["scale_shoulder"] = dist_ref_img / dist_1st_img
359
+
360
+ dist_1st_img = np.linalg.norm(body_1st_img[2]-body_1st_img[3]) # 0.895
361
+ dist_ref_img = np.linalg.norm(body_ref_img[2]-body_ref_img[3]) # 0.134
362
+ s1 = dist_ref_img / dist_1st_img
363
+ dist_1st_img = np.linalg.norm(body_1st_img[5]-body_1st_img[6])
364
+ dist_ref_img = np.linalg.norm(body_ref_img[5]-body_ref_img[6])
365
+ s2 = dist_ref_img / dist_1st_img
366
+ align_args["scale_arm_upper"] = (s1+s2)/2 # 1.548
367
+
368
+ dist_1st_img = np.linalg.norm(body_1st_img[3]-body_1st_img[4])
369
+ dist_ref_img = np.linalg.norm(body_ref_img[3]-body_ref_img[4])
370
+ s1 = dist_ref_img / dist_1st_img
371
+ dist_1st_img = np.linalg.norm(body_1st_img[6]-body_1st_img[7])
372
+ dist_ref_img = np.linalg.norm(body_ref_img[6]-body_ref_img[7])
373
+ s2 = dist_ref_img / dist_1st_img
374
+ align_args["scale_arm_lower"] = (s1+s2)/2
375
+
376
+ # hand
377
+ dist_1st_img = np.zeros(10)
378
+ dist_ref_img = np.zeros(10)
379
+
380
+ dist_1st_img[0] = np.linalg.norm(hands_1st_img[0,0]-hands_1st_img[0,1])
381
+ dist_1st_img[1] = np.linalg.norm(hands_1st_img[0,0]-hands_1st_img[0,5])
382
+ dist_1st_img[2] = np.linalg.norm(hands_1st_img[0,0]-hands_1st_img[0,9])
383
+ dist_1st_img[3] = np.linalg.norm(hands_1st_img[0,0]-hands_1st_img[0,13])
384
+ dist_1st_img[4] = np.linalg.norm(hands_1st_img[0,0]-hands_1st_img[0,17])
385
+ dist_1st_img[5] = np.linalg.norm(hands_1st_img[1,0]-hands_1st_img[1,1])
386
+ dist_1st_img[6] = np.linalg.norm(hands_1st_img[1,0]-hands_1st_img[1,5])
387
+ dist_1st_img[7] = np.linalg.norm(hands_1st_img[1,0]-hands_1st_img[1,9])
388
+ dist_1st_img[8] = np.linalg.norm(hands_1st_img[1,0]-hands_1st_img[1,13])
389
+ dist_1st_img[9] = np.linalg.norm(hands_1st_img[1,0]-hands_1st_img[1,17])
390
+
391
+ dist_ref_img[0] = np.linalg.norm(hands_ref_img[0,0]-hands_ref_img[0,1])
392
+ dist_ref_img[1] = np.linalg.norm(hands_ref_img[0,0]-hands_ref_img[0,5])
393
+ dist_ref_img[2] = np.linalg.norm(hands_ref_img[0,0]-hands_ref_img[0,9])
394
+ dist_ref_img[3] = np.linalg.norm(hands_ref_img[0,0]-hands_ref_img[0,13])
395
+ dist_ref_img[4] = np.linalg.norm(hands_ref_img[0,0]-hands_ref_img[0,17])
396
+ dist_ref_img[5] = np.linalg.norm(hands_ref_img[1,0]-hands_ref_img[1,1])
397
+ dist_ref_img[6] = np.linalg.norm(hands_ref_img[1,0]-hands_ref_img[1,5])
398
+ dist_ref_img[7] = np.linalg.norm(hands_ref_img[1,0]-hands_ref_img[1,9])
399
+ dist_ref_img[8] = np.linalg.norm(hands_ref_img[1,0]-hands_ref_img[1,13])
400
+ dist_ref_img[9] = np.linalg.norm(hands_ref_img[1,0]-hands_ref_img[1,17])
401
+
402
+ ratio = 0
403
+ count = 0
404
+ for j in range(10): # separate index to avoid shadowing the outer frame-loop variable i
405
+ if dist_1st_img[j] != 0:
406
+ ratio = ratio + dist_ref_img[j]/dist_1st_img[j]
407
+ count = count + 1
408
+ if count!=0:
409
+ align_args["scale_hand"] = (ratio/count+align_args["scale_arm_upper"]+align_args["scale_arm_lower"])/3
410
+ else:
411
+ align_args["scale_hand"] = (align_args["scale_arm_upper"]+align_args["scale_arm_lower"])/2
412
+
413
+ # body
414
+ dist_1st_img = np.linalg.norm(body_1st_img[1] - (body_1st_img[8] + body_1st_img[11])/2 )
415
+ dist_ref_img = np.linalg.norm(body_ref_img[1] - (body_ref_img[8] + body_ref_img[11])/2 )
416
+ align_args["scale_body_len"]=dist_ref_img / dist_1st_img
417
+
418
+ dist_1st_img = np.linalg.norm(body_1st_img[8]-body_1st_img[9])
419
+ dist_ref_img = np.linalg.norm(body_ref_img[8]-body_ref_img[9])
420
+ s1 = dist_ref_img / dist_1st_img
421
+ dist_1st_img = np.linalg.norm(body_1st_img[11]-body_1st_img[12])
422
+ dist_ref_img = np.linalg.norm(body_ref_img[11]-body_ref_img[12])
423
+ s2 = dist_ref_img / dist_1st_img
424
+ align_args["scale_leg_upper"] = (s1+s2)/2
425
+
426
+ dist_1st_img = np.linalg.norm(body_1st_img[9]-body_1st_img[10])
427
+ dist_ref_img = np.linalg.norm(body_ref_img[9]-body_ref_img[10])
428
+ s1 = dist_ref_img / dist_1st_img
429
+ dist_1st_img = np.linalg.norm(body_1st_img[12]-body_1st_img[13])
430
+ dist_ref_img = np.linalg.norm(body_ref_img[12]-body_ref_img[13])
431
+ s2 = dist_ref_img / dist_1st_img
432
+ align_args["scale_leg_lower"] = (s1+s2)/2
433
+
434
+ ####################
435
+ ####################
436
+ # need adjust nan
437
+ for k,v in align_args.items():
438
+ if np.isnan(v):
439
+ align_args[k]=1
440
+
441
+ # centre offset (the offset of key point 1)
442
+ offset = body_ref_img[1] - body_1st_img[1]
443
+
444
+
445
+ # pose align
446
+ pose_img, pose_ori = detector(img, args.detect_resolution, args.image_resolution, output_type='cv2', return_pose_dict=True)
447
+ video_pose_buffer.append(pose_img)
448
+ pose_align = align_img(img, pose_ori, align_args, args.detect_resolution, args.image_resolution)
449
+
450
+
451
+ # add centre offset
452
+ pose = pose_align
453
+ pose['bodies']['candidate'] = pose['bodies']['candidate'] + offset
454
+ pose['hands'] = pose['hands'] + offset
455
+ pose['faces'] = pose['faces'] + offset
456
+
457
+
458
+ # keep h unchanged; scale w from absolute coordinates back to 0-1 (note: this must return to the ref coordinate system)
459
+ pose['bodies']['candidate'][:, 0] = pose['bodies']['candidate'][:, 0] / ref_ratio
460
+ pose['hands'][:, :, 0] = pose['hands'][:, :, 0] / ref_ratio
461
+ pose['faces'][:, :, 0] = pose['faces'][:, :, 0] / ref_ratio
462
+ pose_list.append(pose)
463
+
464
+ # stack
465
+ body_list = [pose['bodies']['candidate'][:18] for pose in pose_list]
466
+ body_list_subset = [pose['bodies']['subset'][:1] for pose in pose_list]
467
+ hands_list = [pose['hands'][:2] for pose in pose_list]
468
+ faces_list = [pose['faces'][:1] for pose in pose_list]
469
+
470
+ body_seq = np.stack(body_list , axis=0)
471
+ body_seq_subset = np.stack(body_list_subset, axis=0)
472
+ hands_seq = np.stack(hands_list , axis=0)
473
+ faces_seq = np.stack(faces_list , axis=0)
474
+
475
+
476
+ # concatenate and paint results
477
+ H = 768 # paint height
478
+ W1 = int((H/ref_H * ref_W)//2 *2)
479
+ W2 = int((H/height * width)//2 *2)
480
+ result_demo = [] # = Writer(args, None, H, 3*W1+2*W2, outfn, fps)
481
+ result_pose_only = [] # Writer(args, None, H, W1, args.outfn_align_pose_video, fps)
482
+ for i in range(len(body_seq)):
483
+ pose_t={}
484
+ pose_t["bodies"]={}
485
+ pose_t["bodies"]["candidate"]=body_seq[i]
486
+ pose_t["bodies"]["subset"]=body_seq_subset[i]
487
+ pose_t["hands"]=hands_seq[i]
488
+ pose_t["faces"]=faces_seq[i]
489
+
490
+ ref_img = cv2.cvtColor(refer_img, cv2.COLOR_RGB2BGR)
491
+ ref_img = cv2.resize(ref_img, (W1, H))
492
+ ref_pose= cv2.resize(output_refer, (W1, H))
493
+
494
+ output_transformed = draw_pose(
495
+ pose_t,
496
+ int(H_in*1024/W_in),
497
+ 1024,
498
+ draw_face=False,
499
+ )
500
+ output_transformed = cv2.cvtColor(output_transformed, cv2.COLOR_BGR2RGB)
501
+ output_transformed = cv2.resize(output_transformed, (W1, H))
502
+
503
+ video_frame = cv2.resize(video_frame_buffer[i], (W2, H))
504
+ video_pose = cv2.resize(video_pose_buffer[i], (W2, H))
505
+
506
+ res = np.concatenate([ref_img, ref_pose, output_transformed, video_frame, video_pose], axis=1)
507
+ result_demo.append(res)
508
+ result_pose_only.append(output_transformed)
509
+
510
+ print(f"pose_list len: {len(pose_list)}")
511
+ clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(result_demo, fps=fps)
512
+ clip.write_videofile(outfn, fps=fps)
513
+ clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(result_pose_only, fps=fps)
514
+ clip.write_videofile(args.outfn_align_pose_video, fps=fps)
515
+ print('pose align done')
516
+
517
+
518
+
519
+ def main():
520
+ parser = argparse.ArgumentParser()
521
+
522
+ parser.add_argument('--detect_resolution', type=int, default=512, help='detect_resolution')
523
+ parser.add_argument('--image_resolution', type=int, default=720, help='image_resolution')
524
+
525
+ parser.add_argument("--yolox_config", type=str, default="./pose/config/yolox_l_8xb8-300e_coco.py")
526
+ parser.add_argument("--dwpose_config", type=str, default="./pose/config/dwpose-l_384x288.py")
527
+ parser.add_argument("--yolox_ckpt", type=str, default="./pretrained_weights/dwpose/yolox_l_8x8_300e_coco.pth")
528
+ parser.add_argument("--dwpose_ckpt", type=str, default="./pretrained_weights/dwpose/dw-ll_ucoco_384.pth")
529
+
530
+
531
+ parser.add_argument('--align_frame', type=int, default=0, help='the frame index of the video to align')
532
+ parser.add_argument('--max_frame', type=int, default=300, help='maximum frame number of the video to align')
533
+ parser.add_argument('--imgfn_refer', type=str, default="./assets/images/0.jpg", help='refer image path')
534
+ parser.add_argument('--vidfn', type=str, default="./assets/videos/0.mp4", help='Input video path')
535
+ parser.add_argument('--outfn_align_pose_video', type=str, default=None, help='output path of the aligned video of the refer img')
536
+ parser.add_argument('--outfn', type=str, default=None, help='Output path of the alignment visualization')
537
+ args = parser.parse_args()
538
+
539
+ # create the output directories if they do not exist yet
540
+ # os.makedirs("./assets/poses/")
541
+ os.makedirs("./assets/poses/align", exist_ok=True)
542
+ os.makedirs("./assets/poses/align_demo", exist_ok=True)
543
+
544
+ img_name = os.path.basename(args.imgfn_refer).split('.')[0]
545
+ video_name = os.path.basename(args.vidfn).split('.')[0]
546
+ if args.outfn_align_pose_video is None:
547
+ args.outfn_align_pose_video = "./assets/poses/align/img_{}_video_{}.mp4".format(img_name, video_name)
548
+ if args.outfn is None:
549
+ args.outfn = "./assets/poses/align_demo/img_{}_video_{}.mp4".format(img_name, video_name)
550
+
551
+ run_align_video_with_filterPose_translate_smooth(args)
552
+
553
+
554
+
555
+ if __name__ == '__main__':
556
+ main()
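Note: a minimal sketch of invoking the alignment script above. The reference image and driving video paths are placeholders; the outputs default to ./assets/poses/align/ and ./assets/poses/align_demo/ as set in main().

import subprocess

# Runs dwpose on the reference image and on each video frame, rescales the driving
# skeleton to the reference body proportions, and writes the aligned pose video plus
# a side-by-side demo video.
subprocess.run([
    "python", "pose_align.py",
    "--imgfn_refer", "./assets/images/my_ref.jpg",  # placeholder reference image
    "--vidfn", "./assets/videos/my_dance.mp4",      # placeholder driving video
    "--max_frame", "300",
], check=True)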
requirements.txt ADDED
@@ -0,0 +1,25 @@
1
+ torch==2.0.1
2
+ torchdiffeq==0.2.3
3
+ torchmetrics==1.2.1
4
+ torchsde==0.2.5
5
+ torchvision==0.15.2
6
+ accelerate==0.29.3
7
+ av==11.0.0
8
+ clip @ https://github.com/openai/CLIP/archive/d50d76daa670286dd6cacf3bcd80b5e4823fc8e1.zip#sha256=b5842c25da441d6c581b53a5c60e0c2127ebafe0f746f8e15561a006c6c3be6a
9
+ decord==0.6.0
10
+ diffusers>=0.24.0,<=0.27.2
11
+ einops==0.4.1
12
+ imageio==2.33.0
13
+ imageio-ffmpeg==0.4.9
14
+ ffmpeg-python==0.2.0
15
+ omegaconf==2.2.3
16
+ open-clip-torch==2.20.0
17
+ opencv-contrib-python==4.8.1.78
18
+ opencv-python==4.8.1.78
19
+ scikit-image==0.21.0
20
+ scikit-learn==1.3.2
21
+ transformers==4.33.1
22
+ xformers==0.0.22
23
+ moviepy==1.0.3
24
+ wget==3.2
25
+ huggingface_hub==0.24.7
test_stage_1.py ADDED
@@ -0,0 +1,192 @@
1
+ import os,sys
2
+ import argparse
3
+ import os
4
+ import sys
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import List
8
+ import glob
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchvision
13
+ from diffusers import AutoencoderKL, DDIMScheduler
14
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
15
+ from einops import repeat
16
+ from omegaconf import OmegaConf
17
+ from PIL import Image
18
+ from torchvision import transforms
19
+ from transformers import CLIPVisionModelWithProjection
20
+
21
+
22
+ from musepose.models.pose_guider import PoseGuider
23
+ from musepose.models.unet_2d_condition import UNet2DConditionModel
24
+ from musepose.models.unet_3d import UNet3DConditionModel
25
+ from musepose.pipelines.pipeline_pose2img import Pose2ImagePipeline
26
+ from musepose.utils.util import get_fps, read_frames, save_videos_grid
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--config",default="./configs/test_stage_1.yaml")
32
+ parser.add_argument("-W", type=int, default=768)
33
+ parser.add_argument("-H", type=int, default=768)
34
+ parser.add_argument("--seed", type=int, default=42)
35
+ parser.add_argument("--cnt", type=int, default=1)
36
+ parser.add_argument("--cfg", type=float, default=7)
37
+ parser.add_argument("--steps", type=int, default=20)
38
+ parser.add_argument("--fps", type=int)
39
+ args = parser.parse_args()
40
+
41
+ return args
42
+
43
+
44
+
45
+ def main():
46
+ args = parse_args()
47
+
48
+ config = OmegaConf.load(args.config)
49
+
50
+ if config.weight_dtype == "fp16":
51
+ weight_dtype = torch.float16
52
+ else:
53
+ weight_dtype = torch.float32
54
+
55
+ vae = AutoencoderKL.from_pretrained(
56
+ config.pretrained_vae_path,
57
+ ).to("cuda", dtype=weight_dtype)
58
+
59
+ reference_unet = UNet2DConditionModel.from_pretrained(
60
+ config.pretrained_base_model_path,
61
+ subfolder="unet",
62
+ ).to(dtype=weight_dtype, device="cuda")
63
+
64
+ inference_config_path = config.inference_config
65
+ infer_config = OmegaConf.load(inference_config_path)
66
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
67
+ config.pretrained_base_model_path,
68
+ # config.motion_module_path,
69
+ "",
70
+ subfolder="unet",
71
+ unet_additional_kwargs={
72
+ "use_motion_module": False,
73
+ "unet_use_temporal_attention": False,
74
+ },
75
+ ).to(dtype=weight_dtype, device="cuda")
76
+
77
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
78
+ dtype=weight_dtype, device="cuda"
79
+ )
80
+
81
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
82
+ config.image_encoder_path
83
+ ).to(dtype=weight_dtype, device="cuda")
84
+
85
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
86
+ scheduler = DDIMScheduler(**sched_kwargs)
87
+
88
+
89
+ width, height = args.W, args.H
90
+
91
+ # load pretrained weights
92
+ denoising_unet.load_state_dict(
93
+ torch.load(config.denoising_unet_path, map_location="cpu"),
94
+ strict=False,
95
+ )
96
+ reference_unet.load_state_dict(
97
+ torch.load(config.reference_unet_path, map_location="cpu"),
98
+ )
99
+ pose_guider.load_state_dict(
100
+ torch.load(config.pose_guider_path, map_location="cpu"),
101
+ )
102
+
103
+ pipe = Pose2ImagePipeline(
104
+ vae=vae,
105
+ image_encoder=image_enc,
106
+ reference_unet=reference_unet,
107
+ denoising_unet=denoising_unet,
108
+ pose_guider=pose_guider,
109
+ scheduler=scheduler,
110
+ )
111
+
112
+ pipe = pipe.to("cuda", dtype=weight_dtype)
113
+
114
+ date_str = datetime.now().strftime("%Y%m%d")
115
+ time_str = datetime.now().strftime("%H%M")
116
+
117
+ m1 = config.pose_guider_path.split('.')[0].split('/')[-1]
118
+ save_dir_name = f"{time_str}-{m1}"
119
+
120
+ save_dir = Path(f"./output/image-{date_str}/{save_dir_name}")
121
+ save_dir.mkdir(exist_ok=True, parents=True)
122
+
123
+ def handle_single(ref_image_path, pose_path,seed):
124
+ generator = torch.manual_seed(seed)
125
+ ref_name = Path(ref_image_path).stem
126
+ # pose_name = Path(pose_image_path).stem.replace("_kps", "")
127
+ pose_name = Path(pose_path).stem
128
+
129
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
130
+ pose_image = Image.open(pose_path).convert("RGB")
131
+
132
+ original_width, original_height = pose_image.size
133
+
134
+ pose_transform = transforms.Compose(
135
+ [transforms.Resize((height, width)), transforms.ToTensor()]
136
+ )
137
+
138
+ pose_image_tensor = pose_transform(pose_image)
139
+ pose_image_tensor = pose_image_tensor.unsqueeze(0) # (1, c, h, w)
140
+
141
+ ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
142
+ ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
143
+
144
+ image = pipe(
145
+ ref_image_pil,
146
+ pose_image,
147
+ width,
148
+ height,
149
+ args.steps,
150
+ args.cfg,
151
+ generator=generator,
152
+ ).images
153
+
154
+ image = image.squeeze(2).squeeze(0) # (c, h, w)
155
+ image = image.transpose(0, 1).transpose(1, 2) # (h w c)
156
+ #image = (image + 1.0) / 2.0 # -1,1 -> 0,1
157
+
158
+ image = (image * 255).numpy().astype(np.uint8)
159
+ image = Image.fromarray(image, 'RGB')
160
+ # image.save(os.path.join(save_dir, f"{ref_name}_{pose_name}_{args.H}x{args.W}_{int(args.cfg)}_{time_str}.png"))
161
+
162
+ image_grid = Image.new('RGB',(original_width*3,original_height))
163
+ imgs = [ref_image_pil,pose_image,image]
164
+ x_offset = 0
165
+ for img in imgs:
166
+ img = img.resize((original_width*2, original_height*2))
167
+ img.save(os.path.join(save_dir, f"res_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg"))
168
+ img = img.resize((original_width,original_height))
169
+ image_grid.paste(img, (x_offset,0))
170
+ x_offset += img.size[0]
171
+ image_grid.save(os.path.join(save_dir, f"grid_{ref_name}_{pose_name}_{args.cfg}_{seed}.jpg"))
172
+
173
+
174
+ for ref_image_path_dir in config["test_cases"].keys():
175
+ if os.path.isdir(ref_image_path_dir):
176
+ ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg'))
177
+ else:
178
+ ref_image_paths = [ref_image_path_dir]
179
+ for ref_image_path in ref_image_paths:
180
+ for pose_image_path_dir in config["test_cases"][ref_image_path_dir]:
181
+ if os.path.isdir(pose_image_path_dir):
182
+ pose_image_paths = glob.glob(os.path.join(pose_image_path_dir, '*.jpg'))
183
+ else:
184
+ pose_image_paths = [pose_image_path_dir]
185
+ for pose_image_path in pose_image_paths:
186
+ for i in range(args.cnt):
187
+ handle_single(ref_image_path, pose_image_path, args.seed + i)
188
+
189
+
190
+ if __name__ == "__main__":
191
+ main()
192
+
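Note: a minimal sketch of running the stage-1 test script above. It assumes configs/test_stage_1.yaml lists the reference/pose pairs under test_cases, where each key is a reference image (or a directory of .jpg files) and each value is a list of pose images (or directories).

import subprocess

# Generates single images for every ref/pose pair; results are written under ./output/image-<date>/.
subprocess.run([
    "python", "test_stage_1.py",
    "--config", "./configs/test_stage_1.yaml",
    "-W", "768", "-H", "768",
    "--steps", "20", "--cfg", "7",
], check=True)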
test_stage_2.py ADDED
@@ -0,0 +1,237 @@
1
+ import os,sys
2
+ import argparse
3
+ from datetime import datetime
4
+ from pathlib import Path
5
+ from typing import List
6
+
7
+ import av
8
+ import numpy as np
9
+ import torch
10
+ import torchvision
11
+ from diffusers import AutoencoderKL, DDIMScheduler
12
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
13
+ from einops import repeat
14
+ from omegaconf import OmegaConf
15
+ from PIL import Image
16
+ from torchvision import transforms
17
+ from transformers import CLIPVisionModelWithProjection
18
+ import glob
19
+ import torch.nn.functional as F
20
+
21
+ from musepose.models.pose_guider import PoseGuider
22
+ from musepose.models.unet_2d_condition import UNet2DConditionModel
23
+ from musepose.models.unet_3d import UNet3DConditionModel
24
+ from musepose.pipelines.pipeline_pose2vid_long import Pose2VideoPipeline
25
+ from musepose.utils.util import get_fps, read_frames, save_videos_grid
26
+
27
+
28
+
29
+ def parse_args():
30
+ parser = argparse.ArgumentParser()
31
+ parser.add_argument("--config", type=str, default="./configs/test_stage_2.yaml")
32
+ parser.add_argument("-W", type=int, default=768, help="Width")
33
+ parser.add_argument("-H", type=int, default=768, help="Height")
34
+ parser.add_argument("-L", type=int, default=300, help="video frame length")
35
+ parser.add_argument("-S", type=int, default=48, help="video slice frame number")
36
+ parser.add_argument("-O", type=int, default=4, help="video slice overlap frame number")
37
+
38
+ parser.add_argument("--cfg", type=float, default=3.5, help="Classifier free guidance")
39
+ parser.add_argument("--seed", type=int, default=99)
40
+ parser.add_argument("--steps", type=int, default=20, help="DDIM sampling steps")
41
+ parser.add_argument("--fps", type=int)
42
+
43
+ parser.add_argument("--skip", type=int, default=1, help="frame sample rate = (skip+1)")
44
+ args = parser.parse_args()
45
+
46
+ print('Width:', args.W)
47
+ print('Height:', args.H)
48
+ print('Length:', args.L)
49
+ print('Slice:', args.S)
50
+ print('Overlap:', args.O)
51
+ print('Classifier free guidance:', args.cfg)
52
+ print('DDIM sampling steps :', args.steps)
53
+ print("skip", args.skip)
54
+
55
+ return args
56
+
57
+
58
+ def scale_video(video,width,height):
59
+ video_reshaped = video.view(-1, *video.shape[2:]) # [batch*frames, channels, height, width]
60
+ scaled_video = F.interpolate(video_reshaped, size=(height, width), mode='bilinear', align_corners=False)
61
+ scaled_video = scaled_video.view(*video.shape[:2], scaled_video.shape[1], height, width) # [batch, frames, channels, height, width]
62
+
63
+ return scaled_video
64
+
65
+
66
+ def main():
67
+ args = parse_args()
68
+
69
+ config = OmegaConf.load(args.config)
70
+
71
+ if config.weight_dtype == "fp16":
72
+ weight_dtype = torch.float16
73
+ else:
74
+ weight_dtype = torch.float32
75
+
76
+ vae = AutoencoderKL.from_pretrained(
77
+ config.pretrained_vae_path,
78
+ ).to("cuda", dtype=weight_dtype)
79
+
80
+ reference_unet = UNet2DConditionModel.from_pretrained(
81
+ config.pretrained_base_model_path,
82
+ subfolder="unet",
83
+ ).to(dtype=weight_dtype, device="cuda")
84
+
85
+ inference_config_path = config.inference_config
86
+ infer_config = OmegaConf.load(inference_config_path)
87
+ denoising_unet = UNet3DConditionModel.from_pretrained_2d(
88
+ config.pretrained_base_model_path,
89
+ config.motion_module_path,
90
+ subfolder="unet",
91
+ unet_additional_kwargs=infer_config.unet_additional_kwargs,
92
+ ).to(dtype=weight_dtype, device="cuda")
93
+
94
+ pose_guider = PoseGuider(320, block_out_channels=(16, 32, 96, 256)).to(
95
+ dtype=weight_dtype, device="cuda"
96
+ )
97
+
98
+ image_enc = CLIPVisionModelWithProjection.from_pretrained(
99
+ config.image_encoder_path
100
+ ).to(dtype=weight_dtype, device="cuda")
101
+
102
+ sched_kwargs = OmegaConf.to_container(infer_config.noise_scheduler_kwargs)
103
+ scheduler = DDIMScheduler(**sched_kwargs)
104
+
105
+ generator = torch.manual_seed(args.seed)
106
+
107
+ width, height = args.W, args.H
108
+
109
+ # load pretrained weights
110
+ denoising_unet.load_state_dict(
111
+ torch.load(config.denoising_unet_path, map_location="cpu"),
112
+ strict=False,
113
+ )
114
+ reference_unet.load_state_dict(
115
+ torch.load(config.reference_unet_path, map_location="cpu"),
116
+ )
117
+ pose_guider.load_state_dict(
118
+ torch.load(config.pose_guider_path, map_location="cpu"),
119
+ )
120
+
121
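+ # Assemble the long-video pose-to-video pipeline from the loaded components.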
+ pipe = Pose2VideoPipeline(
122
+ vae=vae,
123
+ image_encoder=image_enc,
124
+ reference_unet=reference_unet,
125
+ denoising_unet=denoising_unet,
126
+ pose_guider=pose_guider,
127
+ scheduler=scheduler,
128
+ )
129
+ pipe = pipe.to("cuda", dtype=weight_dtype)
130
+
131
+ date_str = datetime.now().strftime("%Y%m%d")
132
+ time_str = datetime.now().strftime("%H%M")
133
+
134
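+ # Generate one video for a (reference image, pose video) pair: read and subsample the pose frames, pad them to fill the last context window, run the pipeline, and save the result plus a reference/pose/output comparison grid.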
+ def handle_single(ref_image_path, pose_video_path):
135
+ print('handling:', ref_image_path, pose_video_path)
136
+ ref_name = Path(ref_image_path).stem
137
+ pose_name = Path(pose_video_path).stem.replace("_kps", "")
138
+
139
+ ref_image_pil = Image.open(ref_image_path).convert("RGB")
140
+
141
+ pose_list = []
142
+ pose_tensor_list = []
143
+ pose_images = read_frames(pose_video_path)
144
+ src_fps = get_fps(pose_video_path)
145
+ print(f"pose video has {len(pose_images)} frames, with {src_fps} fps")
146
+ L = min(args.L, len(pose_images))
147
+ pose_transform = transforms.Compose(
148
+ [transforms.Resize((height, width)), transforms.ToTensor()]
149
+ )
150
+ original_width, original_height = 0, 0
151
+
152
+ pose_images = pose_images[::args.skip+1]
153
+ print("processing length:", len(pose_images))
154
+ src_fps = src_fps // (args.skip + 1)
155
+ print("fps", src_fps)
156
+ L = L // (args.skip + 1)
157
+
158
+ for pose_image_pil in pose_images[: L]:
159
+ pose_tensor_list.append(pose_transform(pose_image_pil))
160
+ pose_list.append(pose_image_pil)
161
+ original_width, original_height = pose_image_pil.size
162
+ pose_image_pil = pose_image_pil.resize((width,height))
163
+
164
+ # repeat the last pose frame so the padded length splits evenly into overlapping windows
165
+ last_segment_frame_num = (L - args.S) % (args.S - args.O)
166
+ repeat_frame_num = (args.S - args.O - last_segment_frame_num) % (args.S - args.O)
167
+ for i in range(repeat_frame_num):
168
+ pose_list.append(pose_list[-1])
169
+ pose_tensor_list.append(pose_tensor_list[-1])
170
+
171
+
172
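+ # Build the reference-image and pose tensors used for the comparison grid saved below.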
+ ref_image_tensor = pose_transform(ref_image_pil) # (c, h, w)
173
+ ref_image_tensor = ref_image_tensor.unsqueeze(1).unsqueeze(0) # (1, c, 1, h, w)
174
+ ref_image_tensor = repeat(ref_image_tensor, "b c f h w -> b c (repeat f) h w", repeat=L)
175
+
176
+ pose_tensor = torch.stack(pose_tensor_list, dim=0) # (f, c, h, w)
177
+ pose_tensor = pose_tensor.transpose(0, 1)
178
+ pose_tensor = pose_tensor.unsqueeze(0)
179
+
180
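+ # Denoise the whole sequence in overlapping temporal windows of S frames with O-frame overlap.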
+ video = pipe(
181
+ ref_image_pil,
182
+ pose_list,
183
+ width,
184
+ height,
185
+ len(pose_list),
186
+ args.steps,
187
+ args.cfg,
188
+ generator=generator,
189
+ context_frames=args.S,
190
+ context_stride=1,
191
+ context_overlap=args.O,
192
+ ).videos
193
+
194
+
195
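+ # Tag the output directory with the pose guider and motion module checkpoint names.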
+ m1 = config.pose_guider_path.split('.')[0].split('/')[-1]
196
+ m2 = config.motion_module_path.split('.')[0].split('/')[-1]
197
+
198
+ save_dir_name = f"{time_str}-{args.cfg}-{m1}-{m2}"
199
+ save_dir = Path(f"./output/video-{date_str}/{save_dir_name}")
200
+ save_dir.mkdir(exist_ok=True, parents=True)
201
+
202
+ result = scale_video(video[:,:,:L], original_width, original_height)
203
+ save_videos_grid(
204
+ result,
205
+ f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}.mp4",
206
+ n_rows=1,
207
+ fps=src_fps if args.fps is None else args.fps,
208
+ )
209
+
210
+ video = torch.cat([ref_image_tensor, pose_tensor[:,:,:L], video[:,:,:L]], dim=0)
211
+ video = scale_video(video, original_width, original_height)
212
+ save_videos_grid(
213
+ video,
214
+ f"{save_dir}/{ref_name}_{pose_name}_{args.cfg}_{args.steps}_{args.skip}_{m1}_{m2}.mp4",
215
+ n_rows=3,
216
+ fps=src_fps if args.fps is None else args.fps,
217
+ )
218
+
219
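+ # test_cases maps each reference image (or directory of .jpg images) to a list of pose videos (or directories of .mp4 files).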
+ for ref_image_path_dir in config["test_cases"].keys():
220
+ if os.path.isdir(ref_image_path_dir):
221
+ ref_image_paths = glob.glob(os.path.join(ref_image_path_dir, '*.jpg'))
222
+ else:
223
+ ref_image_paths = [ref_image_path_dir]
224
+ for ref_image_path in ref_image_paths:
225
+ for pose_video_path_dir in config["test_cases"][ref_image_path_dir]:
226
+ if os.path.isdir(pose_video_path_dir):
227
+ pose_video_paths = glob.glob(os.path.join(pose_video_path_dir, '*.mp4'))
228
+ else:
229
+ pose_video_paths = [pose_video_path_dir]
230
+ for pose_video_path in pose_video_paths:
231
+ handle_single(ref_image_path, pose_video_path)
232
+
233
+
234
+
235
+
236
+ if __name__ == "__main__":
237
+ main()