acmyu committed
Commit 3366cca · verified · 1 Parent(s): d298085

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitattributes +1 -0
  2. LICENSE +201 -0
  3. README.md +12 -12
  4. app.py +48 -7
  5. caculate_metrics_256.py +27 -0
  6. caculate_metrics_512.py +27 -0
  7. evaluate.py +186 -0
  8. inception.py +138 -0
  9. main.py +1097 -0
  10. metrics.json +1538 -0
  11. metrics.py +522 -0
  12. pose-frames.py +16 -0
  13. pose.py +15 -0
  14. requirements.txt +8 -0
  15. run_stage1.sh +18 -0
  16. run_stage2.sh +18 -0
  17. run_stage3.sh +15 -0
  18. run_test_stage1.sh +9 -0
  19. run_test_stage2.sh +13 -0
  20. run_test_stage3.sh +12 -0
  21. sd.py +13 -0
  22. setup.txt +41 -0
  23. single_extract_pose.py +35 -0
  24. src/__init__.py +0 -0
  25. src/__pycache__/__init__.cpython-311.pyc +0 -0
  26. src/configs/dwpose-l_384x288.py +257 -0
  27. src/configs/stage1_config.py +181 -0
  28. src/configs/stage2_config.py +192 -0
  29. src/configs/stage3_config.py +217 -0
  30. src/configs/yolox_l_8xb8-300e_coco.py +245 -0
  31. src/controlnet_aux/__init__.py +18 -0
  32. src/controlnet_aux/__pycache__/__init__.cpython-311.pyc +0 -0
  33. src/controlnet_aux/__pycache__/util.cpython-311.pyc +0 -0
  34. src/controlnet_aux/canny/__init__.py +36 -0
  35. src/controlnet_aux/canny/__pycache__/__init__.cpython-311.pyc +0 -0
  36. src/controlnet_aux/dwpose/__init__.py +92 -0
  37. src/controlnet_aux/dwpose/__pycache__/__init__.cpython-311.pyc +0 -0
  38. src/controlnet_aux/dwpose/__pycache__/util.cpython-311.pyc +0 -0
  39. src/controlnet_aux/dwpose/__pycache__/wholebody.cpython-311.pyc +0 -0
  40. src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py +257 -0
  41. src/controlnet_aux/dwpose/util.py +303 -0
  42. src/controlnet_aux/dwpose/wholebody.py +121 -0
  43. src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py +245 -0
  44. src/controlnet_aux/hed/__init__.py +129 -0
  45. src/controlnet_aux/hed/__pycache__/__init__.cpython-311.pyc +0 -0
  46. src/controlnet_aux/leres/__init__.py +118 -0
  47. src/controlnet_aux/leres/__pycache__/__init__.cpython-311.pyc +0 -0
  48. src/controlnet_aux/leres/leres/LICENSE +23 -0
  49. src/controlnet_aux/leres/leres/Resnet.py +199 -0
  50. src/controlnet_aux/leres/leres/Resnext_torch.py +237 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ src/controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,12 @@
1
- ---
2
- title: KeyframesAI2
3
- emoji: 📊
4
- colorFrom: green
5
- colorTo: red
6
- sdk: gradio
7
- sdk_version: 5.41.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
+ ---
2
+ title: KeyframesAI
3
+ emoji: 📈
4
+ colorFrom: gray
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 5.22.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,7 +1,48 @@
1
- import gradio as gr
2
-
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
+ from main import run_app, run_train, run_inference
2
+
3
+ import spaces
4
+ from PIL import Image
5
+ import cv2
6
+ import os
7
+ import gradio as gr
8
+
9
+ with gr.Blocks() as demo:
10
+ with gr.Row():
11
+ with gr.Column():
12
+ char_imgs = gr.Gallery(type="pil", label="Images of the Character")
13
+ mocap = gr.Video(label="Motion-Capture Video")
14
+ tr_steps = gr.Number(label="Training steps", value=10)
15
+ inf_steps = gr.Number(label="Inference steps", value=10)
16
+ fps = gr.Number(label="Output frame rate", value=12)
17
+ modelId = gr.Text(label="Model Id", value="fine_tuned_pcdms")
18
+ remove_bg = gr.Checkbox(label="Remove background", value=False)
19
+ resize_inputs = gr.Checkbox(label="Resize images to match video", value=True)
20
+ train_btn = gr.Button(value="Train")
21
+ inference_btn = gr.Button(value="Inference")
22
+ submit_btn = gr.Button(value="Generate")
23
+ with gr.Column():
24
+ animation = gr.Video(label="Result")
25
+ frames = gr.Gallery(type="pil", label="Frames")
26
+
27
+ submit_btn.click(
28
+ run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
29
+ )
30
+
31
+ train_btn.click(
32
+ run_train, inputs=[char_imgs, tr_steps, remove_bg, resize_inputs, modelId], outputs=[]
33
+ )
34
+
35
+ inference_btn.click(
36
+ run_inference, inputs=[char_imgs, mocap, inf_steps, fps, remove_bg, resize_inputs, modelId], outputs=[animation, frames]
37
+ )
38
+
39
+
40
+
41
+
42
+ demo.launch(share=True)
43
+
44
+
45
+
46
+
47
+
48
+
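The interface above wires three buttons to run_app, run_train and run_inference imported from main.py. Below is a minimal headless sketch of the same flow; it assumes the gallery values arrive as plain PIL images and that the argument order matches the click handlers above, and the file names are placeholders:

    # Hedged sketch: drive the same entry points without the Gradio UI.
    # Assumes run_train/run_inference accept the argument order used in the
    # click handlers above and that character images are plain PIL images.
    from PIL import Image
    from main import run_train, run_inference

    char_imgs = [Image.open(p) for p in ("char_1.png", "char_2.png")]  # placeholder paths
    model_id = "fine_tuned_pcdms"

    # Fine-tune for 10 steps, keep backgrounds, resize inputs to match the video
    run_train(char_imgs, 10, False, True, model_id)

    # Animate with 10 inference steps at 12 fps using the fine-tuned weights
    animation, frames = run_inference(char_imgs, "mocap.mp4", 10, 12, False, True, model_id)
    print(animation, len(frames))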
caculate_metrics_256.py ADDED
@@ -0,0 +1,27 @@
1
+ from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task
2
+ import torch
3
+
4
+
5
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
6
+ fid = FID()
7
+ lpips_obj = LPIPS()
8
+ rec = Reconstruction_Metrics()
9
+
10
+ real_path = './datasets/deepfashing/train_lst_256_png'
11
+ gt_path = '/datasets/deepfashing/test_lst_256_png'
12
+
13
+
14
+ distorated_path = './PCDMs_Results/stage3_256_results'
15
+ results_save_path = distorated_path + '_results.txt' # save path
16
+
17
+
18
+ gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)
19
+ print(len(gt_list), len(distorated_list))
20
+
21
+ FID = fid.calculate_from_disk(distorated_path, real_path, img_size=(176,256))
22
+ LPIPS = lpips_obj.calculate_from_disk(distorated_list, gt_list, img_size=(176,256), sort=False)
23
+ REC = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, img_size=(176,256), sort=False, debug=False)
24
+
25
+ print ("FID: "+str(FID)+"\nLPIPS: "+str(LPIPS)+"\nSSIM: "+str(REC))
26
+ with open(results_save_path, 'a') as ff:
27
+ ff.write("\nFID: "+str(FID)+"\nLPIPS: "+str(LPIPS)+"\nSSIM: "+str(REC))
caculate_metrics_512.py ADDED
@@ -0,0 +1,27 @@
1
+ from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task
2
+ import torch
3
+
4
+
5
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
6
+ fid = FID()
7
+ lpips_obj = LPIPS()
8
+ rec = Reconstruction_Metrics()
9
+
10
+ real_path = './datasets/deepfashing/train_lst_512_png'
11
+ gt_path = '/datasets/deepfashing/test_lst_512_png'
12
+
13
+
14
+ distorated_path = './PCDMs_Results/stage3_512_results'
15
+ results_save_path = distorated_path + '_results.txt' # save path
16
+
17
+
18
+ gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)
19
+ print(len(gt_list), len(distorated_list))
20
+
21
+ FID = fid.calculate_from_disk(distorated_path, real_path, img_size=(352,512))
22
+ LPIPS = lpips_obj.calculate_from_disk(distorated_list, gt_list, img_size=(352,512), sort=False)
23
+ REC = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, img_size=(352,512), sort=False, debug=False)
24
+
25
+ print ("FID: "+str(FID)+"\nLPIPS: "+str(LPIPS)+"\nSSIM: "+str(REC))
26
+ with open(results_save_path, 'a') as ff:
27
+ ff.write("\nFID: "+str(FID)+"\nLPIPS: "+str(LPIPS)+"\nSSIM: "+str(REC))
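caculate_metrics_256.py and caculate_metrics_512.py differ only in the dataset paths and the target image size (note that gt_path starts with '/datasets/…' while real_path is relative; both are presumably meant to point into ./datasets). A hedged sketch that folds both scripts into one parameterized helper, reusing the same metrics API; the paths are assumptions carried over from the scripts above:

    # Hedged sketch: one helper covering both resolution-specific scripts above.
    from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task

    def evaluate_resolution(real_path, gt_path, distorted_path, img_size):
        fid, lpips_obj, rec = FID(), LPIPS(), Reconstruction_Metrics()
        gt_list, distorted_list = preprocess_path_for_deform_task(gt_path, distorted_path)
        fid_score = fid.calculate_from_disk(distorted_path, real_path, img_size=img_size)
        lpips_score = lpips_obj.calculate_from_disk(distorted_list, gt_list, img_size=img_size, sort=False)
        rec_scores = rec.calculate_from_disk(distorted_list, gt_list, distorted_path,
                                             img_size=img_size, sort=False, debug=False)
        return fid_score, lpips_score, rec_scores

    for res, size in ((256, (176, 256)), (512, (352, 512))):
        print(res, evaluate_resolution(
            f"./datasets/deepfashing/train_lst_{res}_png",
            f"./datasets/deepfashing/test_lst_{res}_png",
            f"./PCDMs_Results/stage3_{res}_results",
            size))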
evaluate.py ADDED
@@ -0,0 +1,186 @@
1
+ from main import extract_frames, run
2
+
3
+ from PIL import Image
4
+ import numpy as np
5
+ from skimage.metrics import structural_similarity as ssim
6
+ from skimage.metrics import peak_signal_noise_ratio as psnr
7
+ import torch
8
+ import torchvision.transforms as transforms
9
+ import lpips
10
+ from pytorch_fid.fid_score import calculate_fid_given_paths
11
+ import os
12
+ import json
13
+
14
+ # Convert PIL to numpy
15
+ def pil_to_np(img):
16
+ return np.array(img).astype(np.float32) / 255.0
17
+
18
+ # SSIM
19
+ def compute_ssim(img1, img2):
20
+ img1_np = pil_to_np(img1)
21
+ img2_np = pil_to_np(img2)
22
+
23
+ h, w = img1_np.shape[:2]
24
+ min_dim = min(h, w)
25
+ win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1) # ensure odd
26
+
27
+ return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)
28
+
29
+ # PSNR
30
+ def compute_psnr(img1, img2):
31
+ img1_np = pil_to_np(img1)
32
+ img2_np = pil_to_np(img2)
33
+ return psnr(img1_np, img2_np, data_range=1.0)
34
+
35
+ # LPIPS
36
+ lpips_model = lpips.LPIPS(net='alex')
37
+ lpips_transform = transforms.Compose([
38
+ transforms.Resize((256, 256)),
39
+ transforms.ToTensor(),
40
+ transforms.Normalize([0.5]*3, [0.5]*3)
41
+ ])
42
+ def compute_lpips(img1, img2):
43
+
44
+ img1_tensor = lpips_transform(img1).unsqueeze(0)
45
+ img2_tensor = lpips_transform(img2).unsqueeze(0)
46
+ return lpips_model(img1_tensor, img2_tensor).item()
47
+
48
+ # FID: Save images to temp folders for FID calculation
49
+ def compute_fid(img1, img2):
50
+ os.makedirs('temp/img1', exist_ok=True)
51
+ os.makedirs('temp/img2', exist_ok=True)
52
+ img1.save('temp/img1/0.png')
53
+ img2.save('temp/img2/0.png')
54
+ fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cpu', dims=2048)
55
+ return fid
56
+
57
+
58
+ with open('metrics.json', 'r') as file:
59
+ metrics = json.load(file)
60
+
61
+ def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
62
+ print(item)
63
+
64
+ images = []
65
+ for path in image_paths:
66
+ img = Image.open(path)
67
+ images.append(img)
68
+
69
+ gt_frames = extract_frames(video_path, fps)
70
+
71
+ os.makedirs('out/'+item, exist_ok=True)
72
+
73
+
74
+ for i, frame in enumerate(gt_frames):
75
+ frame.save("out/"+item+"/frame_"+str(i)+".png")
76
+
77
+ results = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=True)
78
+
79
+ for i, result in enumerate(results):
80
+ result.save("out/"+item+"/result_"+str(i)+".png")
81
+
82
+ results_base = run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, finetune=False)
83
+
84
+ for i, result in enumerate(results_base):
85
+ result.save("out/"+item+"/base_"+str(i)+".png")
86
+
87
+ """
88
+ img1=gt_frames[0]
89
+ img2=Image.open("out/base_0.png")
90
+ print("SSIM:", compute_ssim(img1, img2))
91
+ print("PSNR:", compute_psnr(img1, img2))
92
+ print("LPIPS:", compute_lpips(img1, img2))
93
+ print("FID:", compute_fid(img1, img2))
94
+ """
95
+
96
+ ssim = []
97
+ psnr = []
98
+ lpips = []
99
+ fid = []
100
+ ssim2 = []
101
+ psnr2 = []
102
+ lpips2 = []
103
+ fid2 = []
104
+ for gt, result, base in zip(gt_frames, results, results_base):
105
+ ssim.append(float(compute_ssim(gt, result)))
106
+ psnr.append(float(compute_psnr(gt, result)))
107
+ lpips.append(float(compute_lpips(gt, result)))
108
+ fid.append(float(compute_fid(gt, result)))
109
+
110
+ ssim2.append(float(compute_ssim(gt, base)))
111
+ psnr2.append(float(compute_psnr(gt, base)))
112
+ lpips2.append(float(compute_lpips(gt, base)))
113
+ fid2.append(float(compute_fid(gt, base)))
114
+
115
+
116
+ print("SSIM:", sum(ssim)/len(ssim))
117
+ print("PSNR:", sum(psnr)/len(psnr))
118
+ print("LPIPS:", sum(lpips)/len(lpips))
119
+ print("FID:", sum(fid)/len(fid))
120
+ print('baseline:')
121
+ print("SSIM:", sum(ssim2)/len(ssim2))
122
+ print("PSNR:", sum(psnr2)/len(psnr2))
123
+ print("LPIPS:", sum(lpips2)/len(lpips2))
124
+ print("FID:", sum(fid2)/len(fid2))
125
+
126
+ metrics[item] = {'ft': {}, 'base': {}}
127
+ metrics[item]['ft']['ssim'] = {'avg': sum(ssim)/len(ssim), 'vals': ssim}
128
+ metrics[item]['ft']['psnr'] = {'avg': sum(psnr)/len(psnr), 'vals': psnr}
129
+ metrics[item]['ft']['lpips'] = {'avg': sum(lpips)/len(lpips), 'vals': lpips}
130
+ metrics[item]['ft']['fid'] = {'avg': sum(fid)/len(fid), 'vals': fid}
131
+ metrics[item]['base']['ssim'] = {'avg': sum(ssim2)/len(ssim2), 'vals': ssim2}
132
+ metrics[item]['base']['psnr'] = {'avg': sum(psnr2)/len(psnr2), 'vals': psnr2}
133
+ metrics[item]['base']['lpips'] = {'avg': sum(lpips2)/len(lpips2), 'vals': lpips2}
134
+ metrics[item]['base']['fid'] = {'avg': sum(fid2)/len(fid2), 'vals': fid2}
135
+
136
+ with open('metrics.json', "w", encoding="utf-8") as json_file:
137
+ json.dump(metrics, json_file, ensure_ascii=False, indent=4)
138
+
139
+
140
+
141
+
142
+ items = ['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
143
+ for item in items:
144
+ if item in metrics:
145
+ continue
146
+ get_score(item, ['test/'+item+'/1.jpg', 'test/'+item+'/2.jpg', 'test/'+item+'/3.jpg'], 'test/'+item+'/v.mp4')
147
+
148
+
149
+
150
+ ssim = []
151
+ psnr = []
152
+ lpips = []
153
+ fid = []
154
+ ssim2 = []
155
+ psnr2 = []
156
+ lpips2 = []
157
+ fid2 = []
158
+ for item in metrics.keys():
159
+ ssim.append(metrics[item]['ft']['ssim']['avg'])
160
+ psnr.append(metrics[item]['ft']['psnr']['avg'])
161
+ lpips.append(metrics[item]['ft']['lpips']['avg'])
162
+ fid.append(metrics[item]['ft']['fid']['avg'])
163
+
164
+ ssim2.append(metrics[item]['base']['ssim']['avg'])
165
+ psnr2.append(metrics[item]['base']['psnr']['avg'])
166
+ lpips2.append(metrics[item]['base']['lpips']['avg'])
167
+ fid2.append(metrics[item]['base']['fid']['avg'])
168
+
169
+ print(item)
170
+ print("SSIM:", metrics[item]['ft']['ssim']['avg'], metrics[item]['base']['ssim']['avg'])
171
+ print("PSNR:", metrics[item]['ft']['psnr']['avg'], metrics[item]['base']['psnr']['avg'])
172
+ print("LPIPS:", metrics[item]['ft']['lpips']['avg'], metrics[item]['base']['lpips']['avg'])
173
+ print("FID:", metrics[item]['ft']['fid']['avg'], metrics[item]['base']['fid']['avg'])
174
+
175
+ print('Results:')
176
+ print("SSIM:", sum(ssim)/len(ssim))
177
+ print("PSNR:", sum(psnr)/len(psnr))
178
+ print("LPIPS:", sum(lpips)/len(lpips))
179
+ print("FID:", sum(fid)/len(fid))
180
+ print('baseline:')
181
+ print("SSIM:", sum(ssim2)/len(ssim2))
182
+ print("PSNR:", sum(psnr2)/len(psnr2))
183
+ print("LPIPS:", sum(lpips2)/len(lpips2))
184
+ print("FID:", sum(fid2)/len(fid2))
185
+
186
+
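The helpers at the top of evaluate.py each compare a pair of PIL images; FID here is computed from two single-image folders, so it is only a rough indicator rather than a true distribution distance. A minimal sketch of using them on one ground-truth/result pair is below (paths are placeholders; evaluate.py also runs its benchmark loop at module level, so the helpers would need to be factored out or guarded by __main__ before importing them like this):

    # Hedged sketch: score one generated frame against its ground-truth frame.
    # Assumes the metric helpers have been moved into a module that does not
    # execute the benchmark loop on import.
    from PIL import Image
    from evaluate import compute_ssim, compute_psnr, compute_lpips, compute_fid

    gt = Image.open("out/walk/frame_0.png").convert("RGB")       # placeholder paths
    pred = Image.open("out/walk/result_0.png").convert("RGB")

    print("SSIM :", compute_ssim(gt, pred))    # closer to 1.0 is better
    print("PSNR :", compute_psnr(gt, pred))    # higher is better, in dB
    print("LPIPS:", compute_lpips(gt, pred))   # lower is better
    print("FID  :", compute_fid(gt, pred))     # single-image FID, indicative only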
inception.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torchvision import models
5
+
6
+
7
+ class InceptionV3(nn.Module):
8
+ """Pretrained InceptionV3 network returning feature maps"""
9
+
10
+ # Index of default block of inception to return,
11
+ # corresponds to output of final average pooling
12
+ DEFAULT_BLOCK_INDEX = 3
13
+
14
+ # Maps feature dimensionality to their output blocks indices
15
+ BLOCK_INDEX_BY_DIM = {
16
+ 64: 0, # First max pooling features
17
+ 192: 1, # Second max pooling features
18
+ 768: 2, # Pre-aux classifier features
19
+ 2048: 3 # Final average pooling features
20
+ }
21
+
22
+ def __init__(self,
23
+ output_blocks=[DEFAULT_BLOCK_INDEX],
24
+ resize_input=True,
25
+ normalize_input=True,
26
+ requires_grad=False):
27
+ """Build pretrained InceptionV3
28
+ Parameters
29
+ ----------
30
+ output_blocks : list of int
31
+ Indices of blocks to return features of. Possible values are:
32
+ - 0: corresponds to output of first max pooling
33
+ - 1: corresponds to output of second max pooling
34
+ - 2: corresponds to output which is fed to aux classifier
35
+ - 3: corresponds to output of final average pooling
36
+ resize_input : bool
37
+ If true, bilinearly resizes input to width and height 299 before
38
+ feeding input to model. As the network without fully connected
39
+ layers is fully convolutional, it should be able to handle inputs
40
+ of arbitrary size, so resizing might not be strictly needed
41
+ normalize_input : bool
42
+ If true, normalizes the input to the statistics the pretrained
43
+ Inception network expects
44
+ requires_grad : bool
45
+ If true, parameters of the model require gradient. Possibly useful
46
+ for finetuning the network
47
+ """
48
+ super(InceptionV3, self).__init__()
49
+
50
+ self.resize_input = resize_input
51
+ self.normalize_input = normalize_input
52
+ self.output_blocks = sorted(output_blocks)
53
+ self.last_needed_block = max(output_blocks)
54
+
55
+ assert self.last_needed_block <= 3, \
56
+ 'Last possible output block index is 3'
57
+
58
+ self.blocks = nn.ModuleList()
59
+
60
+ inception = models.inception_v3(pretrained=True)
61
+ # Block 0: input to maxpool1
62
+ block0 = [
63
+ inception.Conv2d_1a_3x3,
64
+ inception.Conv2d_2a_3x3,
65
+ inception.Conv2d_2b_3x3,
66
+ nn.MaxPool2d(kernel_size=3, stride=2)
67
+ ]
68
+ self.blocks.append(nn.Sequential(*block0))
69
+
70
+ # Block 1: maxpool1 to maxpool2
71
+ if self.last_needed_block >= 1:
72
+ block1 = [
73
+ inception.Conv2d_3b_1x1,
74
+ inception.Conv2d_4a_3x3,
75
+ nn.MaxPool2d(kernel_size=3, stride=2)
76
+ ]
77
+ self.blocks.append(nn.Sequential(*block1))
78
+
79
+ # Block 2: maxpool2 to aux classifier
80
+ if self.last_needed_block >= 2:
81
+ block2 = [
82
+ inception.Mixed_5b,
83
+ inception.Mixed_5c,
84
+ inception.Mixed_5d,
85
+ inception.Mixed_6a,
86
+ inception.Mixed_6b,
87
+ inception.Mixed_6c,
88
+ inception.Mixed_6d,
89
+ inception.Mixed_6e,
90
+ ]
91
+ self.blocks.append(nn.Sequential(*block2))
92
+
93
+ # Block 3: aux classifier to final avgpool
94
+ if self.last_needed_block >= 3:
95
+ block3 = [
96
+ inception.Mixed_7a,
97
+ inception.Mixed_7b,
98
+ inception.Mixed_7c,
99
+ nn.AdaptiveAvgPool2d(output_size=(1, 1))
100
+ ]
101
+ self.blocks.append(nn.Sequential(*block3))
102
+
103
+ for param in self.parameters():
104
+ param.requires_grad = requires_grad
105
+
106
+ def forward(self, inp):
107
+ """Get Inception feature maps
108
+ Parameters
109
+ ----------
110
+ inp : torch.autograd.Variable
111
+ Input tensor of shape Bx3xHxW. Values are expected to be in
112
+ range (0, 1)
113
+ Returns
114
+ -------
115
+ List of torch.autograd.Variable, corresponding to the selected output
116
+ block, sorted ascending by index
117
+ """
118
+ outp = []
119
+ x = inp
120
+
121
+ if self.resize_input:
122
+ x = F.upsample(x, size=(299, 299), mode='bilinear')
123
+
124
+ if self.normalize_input:
125
+ x = x.clone()
126
+ x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
127
+ x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
128
+ x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
129
+
130
+ for idx, block in enumerate(self.blocks):
131
+ x = block(x)
132
+ if idx in self.output_blocks:
133
+ outp.append(x)
134
+
135
+ if idx == self.last_needed_block:
136
+ break
137
+
138
+ return outp
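InceptionV3 above returns the intermediate feature maps selected by output_blocks; block 3 (the 2048-dimensional final-average-pool features) is the standard choice for FID. A short sketch of extracting those features, with a random batch standing in for real images:

    # Hedged sketch: pull the 2048-d pool features typically used for FID.
    import torch
    from inception import InceptionV3

    block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]   # -> 3, final average pooling
    model = InceptionV3(output_blocks=[block_idx]).eval()

    images = torch.rand(4, 3, 299, 299)                # stand-in batch, values in [0, 1]
    with torch.no_grad():
        features = model(images)[0]                    # shape (4, 2048, 1, 1)
    features = features.squeeze(-1).squeeze(-1)        # shape (4, 2048)
    print(features.shape)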
main.py ADDED
@@ -0,0 +1,1097 @@
1
+ import logging
2
+ import math
3
+ import os
4
+ from typing import Any, Dict, List, Optional, Tuple, Union
5
+ from diffusers.models.controlnet import ControlNetConditioningEmbedding
6
+ import torch
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint
10
+ import transformers
11
+ from accelerate import Accelerator
12
+ from accelerate.logging import get_logger
13
+ from accelerate.utils import ProjectConfiguration, set_seed
14
+
15
+ from tqdm.auto import tqdm
16
+ from src.configs.stage2_config import args
17
+
18
+ import diffusers
19
+ from diffusers import (
20
+ AutoencoderKL,
21
+ DDPMScheduler,
22
+ )
23
+ from diffusers.optimization import get_scheduler
24
+ from diffusers.utils import check_min_version, is_wandb_available
25
+ from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
26
+ from transformers import CLIPVisionModelWithProjection
27
+ from transformers import Dinov2Model
28
+ from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
29
+
30
+
31
+
32
+ import glob
33
+ import os
34
+ import torch
35
+ from torch import nn
36
+ from PIL import Image, ImageOps
37
+ import numpy as np
38
+ from diffusers import UniPCMultistepScheduler
39
+ from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel
40
+
41
+ from torchvision import transforms
42
+ from diffusers.models.controlnet import ControlNetConditioningEmbedding
43
+ from transformers import CLIPImageProcessor
44
+ from transformers import Dinov2Model
45
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel,ControlNetModel,DDIMScheduler
46
+ from src.pipelines.PCDMs_pipeline import PCDMsPipeline
47
+ #from single_extract_pose import inference_pose
48
+
49
+
50
+ import spaces
51
+ from easy_dwpose import DWposeDetector
52
+ from PIL import Image
53
+ import cv2
54
+ import os
55
+ import gradio as gr
56
+ import rembg
57
+ import uuid
58
+ import gc
59
+ from numba import cuda
60
+
61
+ from huggingface_hub import hf_hub_download
62
+
63
+
64
+ # Inputs ===================================================================================================
65
+
66
+ input_img = "sm.png"
67
+ train_imgs = ["target.png"]
68
+ in_vid = "walk.mp4"
69
+ out_vid = 'out.mp4'
70
+
71
+ """
72
+ train_steps = 100
73
+ inference_steps = 10
74
+ fps = 12
75
+ """
76
+
77
+ debug = False
78
+ save_model = True
79
+ max_batch_size = 8
80
+
81
+ # Pose detection ==============================================================================================
82
+
83
+ def load_models():
84
+ dwpose = DWposeDetector(device="cpu")
85
+ rembg_session = rembg.new_session("u2netp")
86
+
87
+ pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")
88
+
89
+ # Load scheduler
90
+ noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")
91
+
92
+ # Load model
93
+ image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
94
+ image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')#("openai/clip-vit-base-patch32")
95
+
96
+ vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
97
+ unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
98
+ "stabilityai/stable-diffusion-2-1-base",
99
+ torch_dtype=torch.float16,
100
+ subfolder="unet",
101
+ in_channels=9,
102
+ low_cpu_mem_usage=False,
103
+ ignore_mismatched_sizes=True)
104
+
105
+
106
+ return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet
107
+
108
+
109
+ #load_models()
110
+
111
+
112
+ def resize_and_pad(img, target_img):
113
+ tw, th = target_img.size
114
+ w, h = img.size
115
+
116
+ if tw/th > w/h:
117
+ tw = int(th * w/h)
118
+ elif tw/th < w/h:
119
+ th = int(tw * h/w)
120
+
121
+ img = img.resize((tw, th), Image.BICUBIC)
122
+
123
+ tw, th = target_img.size
124
+ new_img = Image.new("RGB", (tw, th), (0, 0, 0))
125
+ left = (tw - img.width) // 2
126
+ top = (th - img.height) // 2
127
+ new_img.paste(img, (left, top))
128
+
129
+ return new_img
130
+
131
+
132
+ def remove_zero_pad(image):
133
+ image = np.array(image)
134
+ dummy = np.argwhere(image != 0) # assume background is zero
135
+ max_y = dummy[:, 0].max()
136
+ min_y = dummy[:, 0].min()
137
+ min_x = dummy[:, 1].min()
138
+ max_x = dummy[:, 1].max()
139
+ crop_image = image[min_y:max_y, min_x:max_x]
140
+
141
+ return Image.fromarray(crop_image)
142
+
143
+
144
+ def get_pose(img, dwpose, outfile, crop=False):
145
+ #pil_image = Image.open("imgs/"+img).convert("RGB")
146
+ #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
147
+
148
+ #img.thumbnail((512,512))
149
+ out_img = dwpose(img, include_hands=True, include_face=False)
150
+
151
+ #print(pose['bodies'])
152
+
153
+ if crop:
154
+ bbox = out_img.getbbox()
155
+ out_img = out_img.crop(bbox)
156
+ out_img = ImageOps.expand(out_img, border=int(out_img.width*0.2), fill=(0,0,0))
157
+
158
+ return out_img
159
+
160
+
161
+ def extract_frames(video_path, fps):
162
+ video_capture = cv2.VideoCapture(video_path)
163
+ frame_count = 0
164
+ frames = []
165
+
166
+ fps_in = video_capture.get(cv2.CAP_PROP_FPS)
167
+ fps_out = fps
168
+
169
+ index_in = -1
170
+ index_out = -1
171
+
172
+ while True:
173
+ success = video_capture.grab()
174
+ if not success: break
175
+ index_in += 1
176
+
177
+ out_due = int(index_in / fps_in * fps_out)
178
+ if out_due > index_out:
179
+ success, frame = video_capture.retrieve()
180
+ if not success:
181
+ break
182
+ index_out += 1
183
+
184
+ frame_count += 1
185
+ frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
186
+
187
+ video_capture.release()
188
+ print(f"Extracted {frame_count} frames")
189
+ return frames
190
+
191
+
192
+ def removebg(img, rembg_session):
193
+ result = Image.new("RGB", img.size, "#ffffff")
194
+ out = rembg.remove(img, session=rembg_session)
195
+ result.paste(out, mask=out)
196
+ return result
197
+
198
+
199
+ def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
200
+ if bg_remove:
201
+ images = [removebg(img, rembg_session) for img in images]
202
+
203
+ in_img = images[0]
204
+ in_pose = get_pose(in_img, dwpose, "in_pose.png")
205
+ train_poses = []
206
+ train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
207
+
208
+ for i, img in enumerate(train_imgs):
209
+ train_poses.append(get_pose(img, dwpose, "tr_pose"+str(i)+".png"))
210
+
211
+ return in_img, in_pose, train_imgs, train_poses
212
+
213
+
214
+ def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_app=False):
215
+ progress=gr.Progress(track_tqdm=True)
216
+
217
+ print("prepare_inputs_inference")
218
+
219
+ in_pose = get_pose(in_img, dwpose, "in_pose.png")
220
+
221
+ frames = extract_frames(in_vid, fps)
222
+ #frames = [removebg(img, rembg_session) for img in frames]
223
+ if debug:
224
+ for i, frame in enumerate(frames):
225
+ frame.save("out/frame_"+str(i)+".png")
226
+
227
+ print("vid: ", in_vid, fps)
228
+
229
+ progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
230
+ target_poses = []
231
+ max_left = max_top = 999999
232
+ max_right = max_bottom = 0
233
+ it = frames
234
+ if is_app:
235
+ it = progress.tqdm(frames, desc="Pose Detection")
236
+ for f in it:
237
+ tpose = get_pose(f, dwpose, "tar_pose"+str(len(target_poses))+".png")
238
+ target_poses.append(tpose)
239
+ progress_bar.update(1)
240
+
241
+ bbox = tpose.getbbox()
242
+ left, top, right, bottom = bbox
243
+ max_left = min(max_left, left)
244
+ max_top = min(max_top, top)
245
+ max_right = max(max_right, right)
246
+ max_bottom = max(max_bottom, bottom)
247
+
248
+ target_poses_cropped = []
249
+ for tpose in target_poses:
250
+ if resize=='target':
251
+ tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
252
+ tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
253
+
254
+ tpose = resize_and_pad(tpose, in_img)
255
+
256
+
257
+ if debug:
258
+ tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
259
+ target_poses_cropped.append(tpose)
260
+
261
+ return target_poses_cropped, in_pose
262
+
263
+
264
+ def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
265
+
266
+ in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
267
+
268
+ target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize, is_app)
269
+
270
+
271
+ return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
272
+
273
+
274
+ # Training ===================================================================================================
275
+
276
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
277
+ check_min_version("0.18.0.dev0")
278
+
279
+ logger = get_logger(__name__)
280
+
281
+
282
+ class ImageProjModel_p(torch.nn.Module):
283
+ """SD model with image prompt"""
284
+
285
+ def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
286
+ super().__init__()
287
+
288
+ self.net = nn.Sequential(
289
+ nn.Linear(in_dim, hidden_dim),
290
+ nn.GELU(),
291
+ nn.Dropout(dropout),
292
+ nn.LayerNorm(hidden_dim),
293
+ nn.Linear(hidden_dim, out_dim),
294
+ nn.Dropout(dropout)
295
+ )
296
+
297
+ def forward(self, x):
298
+ return self.net(x)
299
+
300
+ class ImageProjModel_g(torch.nn.Module):
301
+ """SD model with image prompt"""
302
+
303
+ def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
304
+ super().__init__()
305
+
306
+ self.net = nn.Sequential(
307
+ nn.Linear(in_dim, hidden_dim),
308
+ nn.GELU(),
309
+ nn.Dropout(dropout),
310
+ nn.LayerNorm(hidden_dim),
311
+ nn.Linear(hidden_dim, out_dim),
312
+ nn.Dropout(dropout)
313
+ )
314
+
315
+ def forward(self, x): # b, 257,1280
316
+ return self.net(x)
317
+
318
+
319
+ class SDModel(torch.nn.Module):
320
+ """SD model with image prompt"""
321
+ def __init__(self, unet) -> None:
322
+ super().__init__()
323
+ self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
324
+
325
+ self.unet = unet
326
+ self.pose_proj = ControlNetConditioningEmbedding(
327
+ conditioning_embedding_channels=320,
328
+ block_out_channels=(16, 32, 96, 256),
329
+ conditioning_channels=3)
330
+
331
+
332
+ def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
333
+
334
+ extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
335
+ extra_image_embeddings_g = timg_f_g
336
+
337
+ print(extra_image_embeddings_p.size())
338
+ print(extra_image_embeddings_g.size())
339
+
340
+ encoder_image_hidden_states = torch.cat([extra_image_embeddings_p ,extra_image_embeddings_g], dim=1)
341
+ pose_cond = self.pose_proj(pose_f)
342
+
343
+ pred_noise = self.unet(noisy_latents, timesteps, class_labels=timg_f_g, encoder_hidden_states=encoder_image_hidden_states,my_pose_cond=pose_cond).sample
344
+ return pred_noise
345
+
346
+ def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
347
+ #model_sd = torch.load(load_dir, map_location="cpu")["module"]
348
+ model_sd = torch.load(
349
+ pcdms_model,
350
+ map_location="cpu"
351
+ )["module"]
352
+
353
+
354
+ image_proj_model_dict = {}
355
+ pose_proj_dict = {}
356
+ unet_dict = {}
357
+ for k in model_sd.keys():
358
+ if k.startswith("pose_proj"):
359
+ pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
360
+
361
+ elif k.startswith("image_proj_model_p"):
362
+ image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
363
+
364
+ elif k.startswith("image_proj_model."):
365
+ image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
366
+
367
+
368
+ elif k.startswith("unet"):
369
+ unet_dict[k.replace("unet.", "")] = model_sd[k]
370
+ else:
371
+ print(k)
372
+
373
+ model.pose_proj.load_state_dict(pose_proj_dict)
374
+ model.image_proj_model_p.load_state_dict(image_proj_model_dict)
375
+ model.unet.load_state_dict(unet_dict)
376
+
377
+ return model, 0, 0
378
+
379
+
380
+ def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
381
+ """Utility function for checkpointing model + optimizer dictionaries
382
+ The main purpose for this is to be able to resume training from that instant again
383
+ """
384
+ checkpoint_state_dict = {
385
+ "epoch": epoch,
386
+ "last_global_step": last_global_step,
387
+ }
388
+ # Add extra kwargs too
389
+ checkpoint_state_dict.update(kwargs)
390
+
391
+ success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
392
+ status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
393
+ if success:
394
+ logging.info(f"Success {status_msg}")
395
+ else:
396
+ logging.warning(f"Failure {status_msg}")
397
+ return
398
+
399
+
400
+ @spaces.GPU(duration=600)
401
+ def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
402
+ logging_dir = 'outputs/logging'
403
+ print('start train')
404
+
405
+
406
+ progress=gr.Progress(track_tqdm=True)
407
+
408
+ accelerator = Accelerator(
409
+ log_with=args.report_to,
410
+ project_dir=logging_dir,
411
+ mixed_precision=args.mixed_precision,
412
+ gradient_accumulation_steps=args.gradient_accumulation_steps
413
+ )
414
+
415
+ # Make one log on every process with the configuration for debugging.
416
+ #logging.basicConfig(
417
+ # format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
418
+ # datefmt="%m/%d/%Y %H:%M:%S",
419
+ # level=logging.INFO, )
420
+
421
+ print(accelerator.state)
422
+ if accelerator.is_local_main_process:
423
+ transformers.utils.logging.set_verbosity_warning()
424
+ diffusers.utils.logging.set_verbosity_info()
425
+ else:
426
+ transformers.utils.logging.set_verbosity_error()
427
+ diffusers.utils.logging.set_verbosity_error()
428
+
429
+ # If passed along, set the training seed now.
430
+ set_seed(42)
431
+
432
+ # Handle the repository creation
433
+ if accelerator.is_main_process:
434
+ os.makedirs('outputs', exist_ok=True)
435
+
436
+
437
+ """
438
+ unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
439
+ in_channels=9, class_embed_type="projection" ,projection_class_embeddings_input_dim=1024,
440
+ low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
441
+ """
442
+ image_encoder_p.requires_grad_(False)
443
+ image_encoder_g.requires_grad_(False)
444
+ vae.requires_grad_(False)
445
+
446
+ sd_model = SDModel(unet=unet)
447
+ sd_model.train()
448
+
449
+
450
+ if args.gradient_checkpointing:
451
+ sd_model.enable_gradient_checkpointing()
452
+
453
+
454
+ # Enable TF32 for faster training on Ampere GPUs,
455
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
456
+ if args.allow_tf32:
457
+ torch.backends.cuda.matmul.allow_tf32 = True
458
+
459
+ learning_rate = 1e-4
460
+ train_batch_size = min(len(train_images), max_batch_size) #len(train_images) % 16
461
+
462
+
463
+ # Optimizer creation
464
+ params_to_optimize = sd_model.parameters()
465
+ optimizer = torch.optim.AdamW(
466
+ params_to_optimize,
467
+ lr=learning_rate,
468
+ betas=(args.adam_beta1, args.adam_beta2),
469
+ weight_decay=args.adam_weight_decay,
470
+ eps=args.adam_epsilon,
471
+ )
472
+
473
+ inputs = [{
474
+ "source_image": in_image,
475
+ "source_pose": in_pose,
476
+ "target_image": timg,
477
+ "target_pose": tpose,
478
+ } for timg, tpose in zip(train_images, train_poses)]
479
+
480
+ """
481
+ inputs = {[
482
+ "source_image": Image.open('imgs/sm.png'),
483
+ "source_pose": Image.open('imgs/sm_pose.jpg'),
484
+ "target_image": Image.open('imgs/target.png'),
485
+ "target_pose": Image.open('imgs/target_pose.jpg'),
486
+ ]}
487
+ """
488
+
489
+ #print(inputs)
490
+
491
+ dataset = InpaintDataset(
492
+ inputs,
493
+ 'imgs/',
494
+ size=(args.img_width, args.img_height), # w h
495
+ imgp_drop_rate=0.1,
496
+ imgg_drop_rate=0.1,
497
+ )
498
+
499
+ """
500
+ dataset = InpaintDataset(
501
+ args.json_path,
502
+ args.image_root_path,
503
+ size=(args.img_width, args.img_height), # w h
504
+ imgp_drop_rate=0.1,
505
+ imgg_drop_rate=0.1,
506
+ )
507
+ """
508
+
509
+ train_sampler = torch.utils.data.distributed.DistributedSampler(
510
+ dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True)
511
+
512
+ train_dataloader = torch.utils.data.DataLoader(
513
+ dataset,
514
+ sampler=train_sampler,
515
+ collate_fn=InpaintCollate_fn,
516
+ batch_size=train_batch_size,
517
+ num_workers=0,)
518
+
519
+
520
+ # Scheduler and math around the number of training steps.
521
+ overrode_max_train_steps = False
522
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
523
+ if args.max_train_steps is None:
524
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
525
+ overrode_max_train_steps = True
526
+ args.max_train_steps = train_steps
527
+
528
+ lr_scheduler = get_scheduler(
529
+ args.lr_scheduler,
530
+ optimizer=optimizer,
531
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
532
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
533
+ num_cycles=args.lr_num_cycles,
534
+ power=args.lr_power,
535
+ )
536
+
537
+ # Prepare everything with our `accelerator`.
538
+ sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)
539
+
540
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
541
+ # as these models are only used for inference, keeping weights in full precision is not required.
542
+ weight_dtype = torch.float32
543
+ """
544
+ if accelerator.mixed_precision == "fp16":
545
+ weight_dtype = torch.float16
546
+ elif accelerator.mixed_precision == "bf16":
547
+ weight_dtype = torch.bfloat16
548
+ """
549
+
550
+ # Move vae, unet and text_encoder to device and cast to weight_dtype
551
+ vae.to(accelerator.device, dtype=weight_dtype)
552
+ sd_model.unet.to(accelerator.device, dtype=weight_dtype)
553
+ image_encoder_p.to(accelerator.device, dtype=weight_dtype)
554
+ image_encoder_g.to(accelerator.device, dtype=weight_dtype)
555
+
556
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
557
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
558
+ if overrode_max_train_steps:
559
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
560
+ # Afterwards we recalculate our number of training epochs
561
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
562
+
563
+
564
+ args.num_train_epochs = train_steps
565
+
566
+
567
+ # Train!
568
+ total_batch_size = (
569
+ train_batch_size
570
+ * accelerator.num_processes
571
+ * args.gradient_accumulation_steps
572
+ )
573
+
574
+ print("***** Running training *****")
575
+ print(f" Num batches each epoch = {len(train_dataloader)}")
576
+ print(f" Num Epochs = {args.num_train_epochs}")
577
+ print(f" Instantaneous batch size per device = {train_batch_size}")
578
+ print(
579
+ f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
580
+ )
581
+ print(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
582
+ print(f" Total optimization steps = {args.max_train_steps}")
583
+
584
+
585
+ if args.resume_from_checkpoint:
586
+ # New Code #
587
+ # Loads the DeepSpeed checkpoint from the specified path
588
+ prior_model, last_epoch, last_global_step = load_training_checkpoint(
589
+ sd_model,
590
+ pcdms_model,
591
+ **{"load_optimizer_states": True, "load_lr_scheduler_states": True},
592
+ )
593
+ print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
594
+ starting_epoch = last_epoch
595
+ global_steps = last_global_step
596
+ sd_model = sd_model
597
+ else:
598
+ global_steps = 0
599
+ starting_epoch = 0
600
+ sd_model = sd_model
601
+
602
+ progress_bar = tqdm(range(global_steps, args.max_train_steps), initial=global_steps, desc="Steps",
603
+ # Only show the progress bar once on each machine.
604
+ disable=not accelerator.is_local_main_process, )
605
+
606
+ bsz = train_batch_size
607
+
608
+ if not finetune or train_steps == 0:
609
+ accelerator.wait_for_everyone()
610
+ accelerator.end_training()
611
+ return {k: v.cpu() for k, v in sd_model.state_dict().items()}
612
+
613
+
614
+ it = range(starting_epoch, args.num_train_epochs)
615
+ if is_app:
616
+ it = progress.tqdm(it, desc="Fine-tuning")
617
+ for epoch in it:
618
+ for step, batch in enumerate(train_dataloader):
619
+ with accelerator.accumulate(sd_model):
620
+ with torch.no_grad():
621
+ # Convert images to latent space
622
+ latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
623
+ latents = latents * vae.config.scaling_factor
624
+
625
+ # Get the masked image latents
626
+ masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
627
+ masked_latents = masked_latents * vae.config.scaling_factor
628
+
629
+ bsz = batch["target_image"].size(dim=0)
630
+
631
+ # mask
632
+ mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
633
+ mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
634
+ mask = torch.cat([mask1, mask0], dim=3)
635
+ # Get the image embedding for conditioning
636
+ cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
637
+ cond_image_feature_p = (cond_image_feature_p.last_hidden_state)
638
+
639
+
640
+ cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype), ).image_embeds
641
+ cond_image_feature_g =cond_image_feature_g.unsqueeze(1)
642
+
643
+ # Sample noise that we'll add to the latents
644
+ noise = torch.randn_like(latents)
645
+ if args.noise_offset:
646
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
647
+ noise += args.noise_offset * torch.randn(
648
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
649
+ )
650
+
651
+ # Sample a random timestep for each image
652
+ #timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,),device=latents.device, )
653
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),device=latents.device, )
654
+ timesteps = timesteps.long()
655
+
656
+
657
+
658
+ # Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion process)
659
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
660
+
661
+ #print(noisy_latents.size(), mask.size(), masked_latents.size())
662
+
663
+ noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)
664
+ # Get the text embedding for conditioning
665
+
666
+
667
+ cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)
668
+
669
+ #print(noisy_latents.size())
670
+ #print(cond_image_feature_p.size())
671
+ #print(cond_image_feature_g.size())
672
+ #print(cond_pose.size())
673
+
674
+ # Predict the noise residual
675
+ model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p, cond_image_feature_g, cond_pose)
676
+
677
+ # Get the target for loss depending on the prediction type
678
+ if noise_scheduler.config.prediction_type == "epsilon":
679
+ target = noise
680
+ elif noise_scheduler.config.prediction_type == "v_prediction":
681
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
682
+ else:
683
+ raise ValueError(
684
+ f"Unknown prediction type {noise_scheduler.config.prediction_type}"
685
+ )
686
+
687
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
688
+
689
+ accelerator.backward(loss)
690
+ if accelerator.sync_gradients:
691
+ params_to_clip = sd_model.parameters()
692
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
693
+ optimizer.step()
694
+ lr_scheduler.step()
695
+ optimizer.zero_grad(set_to_none=args.set_grads_to_none)
696
+
697
+ # Checks if the accelerator has performed an optimization step behind the scenes
698
+ if accelerator.sync_gradients:
699
+ global_steps += 1
700
+
701
+ if global_steps >= args.max_train_steps:
702
+ break
703
+
704
+
705
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
706
+ print(logs)
707
+ progress_bar.set_postfix(**logs)
708
+
709
+ progress_bar.update(1)
710
+
711
+ # Create the pipeline using the trained modules and save it.
712
+ accelerator.wait_for_everyone()
713
+ accelerator.end_training()
714
+
715
+
716
+
717
+ if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
718
+ print('saving', modelId)
719
+
720
+ checkpoint_state_dict = {
721
+ "epoch": 0,
722
+ "module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
723
+ }
724
+ print(list(sd_model.state_dict().keys())[:20])
725
+ torch.save(checkpoint_state_dict, modelId+".pt")
726
+
727
+ gc.collect()
728
+ torch.cuda.empty_cache()
729
+ #device = cuda.get_current_device()
730
+ #device.reset()
731
+ print('done train')
732
+ return
733
+
734
+ gc.collect()
735
+ torch.cuda.empty_cache()
736
+ return {k: v.cpu() for k, v in sd_model.state_dict().items()}
737
+
738
+
739
+
740
+
741
+ # Pose-transfer ===================================================================================================
742
+
743
+
744
+ device = "cuda"
745
+
746
+ class ImageProjModel(torch.nn.Module):
747
+ """Projection head that maps image-encoder embeddings into the SD prompt-embedding space"""
748
+ def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
749
+ super().__init__()
750
+
751
+ self.net = nn.Sequential(
752
+ nn.Linear(in_dim, hidden_dim),
753
+ nn.GELU(),
754
+ nn.Dropout(dropout),
755
+ nn.LayerNorm(hidden_dim),
756
+ nn.Linear(hidden_dim, out_dim),
757
+ nn.Dropout(dropout)
758
+ )
759
+
760
+ def forward(self, x):
761
+ return self.net(x)
762
+
763
+ def image_grid(imgs, rows, cols):
764
+ assert len(imgs) == rows * cols
765
+ w, h = imgs[0].size
766
+ print(w, h)
767
+ grid = Image.new("RGB", size=(cols * w, rows * h))
768
+ grid_w, grid_h = grid.size
769
+
770
+ for i, img in enumerate(imgs):
771
+ grid.paste(img, box=(i % cols * w, i // cols * h))
772
+ return grid
773
+
774
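+ # Split the combined fine-tuned checkpoint into per-module state dicts (image projection, pose projection, UNet)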
+ def load_mydict(modelId, finetuned_model):
775
+ if save_model:
776
+ model_ckpt_path = modelId+'.pt'
777
+ model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
778
+ else:
779
+ model_sd = finetuned_model #torch.load(model_ckpt_path, map_location="cpu")["module"]
780
+
781
+ image_proj_model_dict = {}
782
+ pose_proj_dict = {}
783
+ unet_dict = {}
784
+ for k in model_sd.keys():
785
+ if k.startswith("pose_proj"):
786
+ pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
787
+
788
+ elif k.startswith("image_proj_model_p"):
789
+ image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
790
+ elif k.startswith("image_proj_model"):
791
+ image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
792
+
793
+
794
+ elif k.startswith("unet"):
795
+ unet_dict[k.replace("unet.", "")] = model_sd[k]
796
+ else:
797
+ print(k)
798
+ return image_proj_model_dict, pose_proj_dict, unet_dict
799
+
800
+
801
+
802
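+ # Request a GPU for up to 600 s per call when running on Hugging Face Spaces (ZeroGPU)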
+ @spaces.GPU(duration=600)
803
+ def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
804
+ print('start inference')
805
+ progress=gr.Progress(track_tqdm=True)
806
+
807
+ if not save_model:
808
+ finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}
809
+
810
+ device = "cuda"
811
+ pretrained_model_name_or_path ="stabilityai/stable-diffusion-2-1-base"
812
+ image_encoder_path = "facebook/dinov2-giant"
813
+ #model_ckpt_path = "./pcdms_ckpt.pt" # ckpt path
814
+ model_ckpt_path = modelId+'.pt'
815
+
816
+
817
+ clip_image_processor = CLIPImageProcessor()
818
+ img_transform = transforms.Compose([
819
+ transforms.ToTensor(),
820
+ transforms.Normalize([0.5], [0.5]),
821
+ ])
822
+
823
+ generator = torch.Generator(device=device).manual_seed(42)
824
+
825
+ """
826
+ unet = Stage2_InapintUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16,subfolder="unet",in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
827
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path,subfolder="vae").to(device, dtype=torch.float16)
828
+ image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
829
+ """
830
+ noise_scheduler = DDIMScheduler(
831
+ num_train_timesteps=1000,
832
+ beta_start=0.00085,
833
+ beta_end=0.012,
834
+ beta_schedule="scaled_linear",
835
+ clip_sample=False,
836
+ set_alpha_to_one=False,
837
+ steps_offset=1,
838
+ )
839
+
840
+ unet = unet.to(device, dtype=torch.float16)
841
+ vae = vae.to(device, dtype=torch.float16)
842
+ image_encoder = image_encoder.to(device, dtype=torch.float16)
843
+
844
+
845
+ image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
846
+ pose_proj_model = ControlNetConditioningEmbedding(
847
+ conditioning_embedding_channels=320,
848
+ block_out_channels=(16, 32, 96, 256),
849
+ conditioning_channels=3).to(device).to(dtype=torch.float16)
850
+
851
+
852
+ # load weight
853
+ print('loading', modelId)
854
+ image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
855
+ print('loaded', modelId)
856
+ image_proj_model.load_state_dict(image_proj_model_dict)
857
+ pose_proj_model.load_state_dict(pose_proj_dict)
858
+ unet.load_state_dict(unet_dict)
859
+
860
+
861
+ pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet, torch_dtype=torch.float16, scheduler=noise_scheduler,feature_extractor=None,safety_checker=None).to(device)
862
+
863
+ print('====================== model load finish ===================')
864
+
865
+ results = []
866
+ progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")
867
+
868
+
869
+ it = target_poses
870
+ if is_app:
871
+ it = progress.tqdm(it, desc="Pose Transfer")
872
+ for pose in it:
873
+
874
+ num_samples = 1
875
+ image_size = (512, 512)
876
+ #s_img_path = 'imgs/'+input_img # debug-only path; the source image is taken from in_image below
877
+ #target_pose_img = 'imgs/pose_'+str(n)+'.png' # input image 2
878
+
879
+ #t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
880
+ #t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
881
+ t_pose = pose.convert("RGB").resize(image_size, Image.BICUBIC)
882
+ #t_pose = resize_and_pad(pose.convert("RGB"))
883
+
884
+
885
+ #s_img = Image.open(s_img_path)
886
+ width_orig, height_orig = in_image.size
887
+ s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
888
+ #s_img = resize_and_pad(in_image.convert("RGB"))
889
+ black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)
890
+
891
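+ # Side-by-side conditioning canvas: source image on the left, black placeholder for the target on the right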
+ s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
892
+ s_img_t_mask.paste(s_img, (0, 0))
893
+ s_img_t_mask.paste(black_image, (s_img.width, 0))
894
+
895
+ #s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
896
+ #s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
897
+ s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
898
+ #s_pose = resize_and_pad(in_pose.convert("RGB"))
899
+ print('source pose width: {}, height: {}'.format(s_pose.width, s_pose.height))
900
+ #t_pose = Image.open(target_pose_img).convert("RGB").resize((image_size), Image.BICUBIC)
901
+
902
+ st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
903
+ st_pose.paste(s_pose, (0, 0))
904
+ st_pose.paste(t_pose, (s_pose.width, 0))
905
+
906
+
907
+ clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
908
+ vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
909
+ cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)
910
+
911
+ mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
912
+ mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
913
+ mask = torch.cat([mask1, mask0], dim=3)
914
+
915
+
916
+ with torch.inference_mode():
917
+ cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
918
+ simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
919
+ simg_mask_latents = simg_mask_latents * 0.18215
920
+
921
+ images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
922
+ image_prompt_embeds = image_proj_model(images_embeds)
923
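+ # Zeroed image features give the unconditional embedding used for classifier-free guidance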
+ uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))
924
+
925
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
926
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
927
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
928
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
929
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
930
+
931
+ output, _ = pipe(
932
+ simg_mask_latents= simg_mask_latents,
933
+ mask = mask,
934
+ cond_pose = cond_pose,
935
+ prompt_embeds=image_prompt_embeds,
936
+ negative_prompt_embeds=uncond_image_prompt_embeds,
937
+ height=image_size[1],
938
+ width=image_size[0]*2,
939
+ num_images_per_prompt=num_samples,
940
+ guidance_scale=2.0,
941
+ generator=generator,
942
+ num_inference_steps=inference_steps,
943
+ )
944
+
945
+ output = output.images[-1]
946
+
947
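+ # Keep only the right half of the side-by-side output, i.e. the generated target view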
+ result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
948
+ result = result.resize((width_orig, height_orig), Image.BICUBIC)
949
+ #result = remove_zero_pad(result)
950
+
951
+ if debug:
952
+ result.save('out/'+str(len(results))+'.png')
953
+ results.append(result)
954
+ progress_bar.update(1)
955
+
956
+ gc.collect()
957
+ torch.cuda.empty_cache()
958
+
959
+ return results
960
+
961
+
962
+ def gen_vid(frames, video_name, fps, codec):
963
+ progress=gr.Progress(track_tqdm=True)
964
+
965
+ frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
966
+ height, width, layers = frame.shape
967
+
968
+ #video = cv2.VideoWriter(video_name, 0, 1, (width,height))
969
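+ # Choose the FourCC to match the container: mp4v for .mp4, VP90 (VP9) for .webm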
+ if codec == 'mp4':
970
+ video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
971
+ else:
972
+ video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))
973
+
974
+ for r in progress.tqdm(frames, desc="Creating video"):
975
+ image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
976
+ video.write(image)
977
+
978
+ video.release()  # finalize the container so the file is playable
979
980
+
981
+
982
+
983
+ def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
984
+ print("==== Load Models ====")
985
+ dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
986
+
987
+ print("==== Pose Detection ====")
988
+ if resize_inputs:
989
+ resize = 'target'
990
+ else:
991
+ resize = 'none'
992
+ in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize=resize, is_app=is_app)
993
+
994
+ if save_model:
995
+ train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
996
+ print("==== Pose Transfer ====")
997
+ results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
998
+
999
+ else:
1000
+ print("==== Finetuning ====")
1001
+ finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1002
+
1003
+ print("==== Pose Transfer ====")
1004
+ results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)
1005
+
1006
+ return results
1007
+
1008
+
1009
+ def run_train(images, train_steps=100, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
1010
+ finetune=True
1011
+ is_app=True
1012
+ images = [img[0] for img in images]
1013
+
1014
+ dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1015
+
1016
+ if resize_inputs:
1017
+ resize = 'target'
1018
+ else:
1019
+ resize = 'none'
1020
+
1021
+ in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
1022
+
1023
+ train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
1024
+
1025
+
1026
+ def run_inference(images, video_path, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
1027
+ is_app=True
1028
+ images = [img[0] for img in images]
1029
+ in_img = images[0]
1030
+
1031
+ dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()
1032
+
1033
+ target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)
1034
+
1035
+ results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
1036
+
1037
+ if debug:
1038
+ gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1039
+ else:
1040
+ gen_vid(results, out_vid+'.webm', fps, 'webm')
1041
+
1042
+ print("Done!")
1043
+
1044
+ return out_vid+'.webm', results
1045
+
1046
+
1047
+ def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
1048
+
1049
+ images = [img[0] for img in images]
1050
+
1051
+ results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)
1052
+
1053
+
1054
+ print("==== Video generation ====")
1055
+ out_vid = f"out_{uuid.uuid4()}"
1056
+
1057
+ if debug:
1058
+ gen_vid(results, out_vid+'.mp4', fps, 'mp4')
1059
+ else:
1060
+ gen_vid(results, out_vid+'.webm', fps, 'webm')
1061
+
1062
+
1063
+
1064
+ print("Done!")
1065
+
1066
+ return out_vid+'.webm', results
1067
+
1068
+
1069
+
1070
+ """
1071
+ train_steps = 100
1072
+ inference_steps = 10
1073
+ fps = 12
1074
+ """
1075
+
1076
+ """
1077
+ iface = gr.Interface(
1078
+ fn=run,
1079
+ inputs=[
1080
+ gr.Gallery(type="pil", label="Images of the Character"),
1081
+ gr.Video(label="Motion-Capture Video"),
1082
+ gr.Number(label="Training steps", value=100),
1083
+ gr.Number(label="Inference steps", value=10),
1084
+ gr.Number(label="Output frame rate", value=12),
1085
+ gr.Checkbox(label="Remove background", value=False),
1086
+ ],
1087
+ outputs=[gr.Video(label="Result"), gr.Gallery(type="pil", label="Frames")],
1088
+ title="Keyframes AI",
1089
+ description="Upload images of your character and a motion-capture video to generate an animation of the character.",
1090
+ )
1091
+ """
1092
+
1093
+
1094
+
1095
+
1096
+
1097
+
metrics.json ADDED
@@ -0,0 +1,1538 @@
1
+ {
2
+ "sidewalk": {
3
+ "ft": {
4
+ "ssim": {
5
+ "avg": 0.9425153732299805,
6
+ "vals": [
7
+ 0.9539688229560852,
8
+ 0.9410380721092224,
9
+ 0.9339408278465271,
10
+ 0.9381983876228333,
11
+ 0.937977135181427,
12
+ 0.937738835811615,
13
+ 0.9322819709777832,
14
+ 0.9447341561317444,
15
+ 0.9701671004295349,
16
+ 0.9403354525566101,
17
+ 0.9407327771186829,
18
+ 0.9384147524833679,
19
+ 0.9429612159729004,
20
+ 0.9541351199150085,
21
+ 0.9593266844749451,
22
+ 0.9423799514770508,
23
+ 0.9342405200004578,
24
+ 0.9316573143005371,
25
+ 0.9376768469810486,
26
+ 0.9355674386024475,
27
+ 0.9092764258384705,
28
+ 0.9486138820648193,
29
+ 0.9740718007087708,
30
+ 0.9381063580513,
31
+ 0.9418080449104309,
32
+ 0.9435502886772156,
33
+ 0.9383361339569092,
34
+ 0.9491550326347351,
35
+ 0.9579421877861023,
36
+ 0.9433062672615051,
37
+ 0.9372091889381409,
38
+ 0.9316847920417786,
39
+ 0.9353565573692322,
40
+ 0.9314845204353333,
41
+ 0.9296603202819824,
42
+ 0.947252094745636,
43
+ 0.9745006561279297,
44
+ 0.9371557235717773,
45
+ 0.9355664849281311,
46
+ 0.9372367858886719,
47
+ 0.9420561790466309,
48
+ 0.9548425674438477
49
+ ]
50
+ },
51
+ "psnr": {
52
+ "avg": 29.313762982030134,
53
+ "vals": [
54
+ 31.79728361792486,
55
+ 29.12789926111764,
56
+ 28.150918151496114,
57
+ 28.197009952290536,
58
+ 28.30431945800427,
59
+ 28.3766962913435,
60
+ 27.95230686775166,
61
+ 29.60480299071412,
62
+ 33.870620988359505,
63
+ 28.2723202748136,
64
+ 28.65414745402314,
65
+ 29.44914322199964,
66
+ 28.918554998476544,
67
+ 30.829394826527686,
68
+ 32.453438778794194,
69
+ 29.07800808482171,
70
+ 28.126848239355724,
71
+ 28.219538928254238,
72
+ 28.15327266691132,
73
+ 28.37061664435462,
74
+ 22.572103732095137,
75
+ 30.218097049078665,
76
+ 35.291027583657176,
77
+ 28.264550325776455,
78
+ 28.548636901711063,
79
+ 30.077494275071125,
80
+ 28.765906730485675,
81
+ 30.473295981007006,
82
+ 32.31943212041054,
83
+ 29.03166397751077,
84
+ 28.20738376155079,
85
+ 27.52313906207686,
86
+ 28.068443800302507,
87
+ 28.02290963487483,
88
+ 27.137965527670815,
89
+ 29.891409757083256,
90
+ 35.52381896568835,
91
+ 28.314135705730163,
92
+ 28.17076848548687,
93
+ 29.376791831995504,
94
+ 28.654625695022126,
95
+ 30.817302643645203
96
+ ]
97
+ },
98
+ "lpips": {
99
+ "avg": 0.0587034212159259,
100
+ "vals": [
101
+ 0.03719184175133705,
102
+ 0.05626270920038223,
103
+ 0.06762276589870453,
104
+ 0.07581378519535065,
105
+ 0.06386660039424896,
106
+ 0.06519967317581177,
107
+ 0.06447210162878036,
108
+ 0.044685374945402145,
109
+ 0.02818106673657894,
110
+ 0.06283751875162125,
111
+ 0.05770150199532509,
112
+ 0.05407204478979111,
113
+ 0.06568614393472672,
114
+ 0.041273392736911774,
115
+ 0.03479328751564026,
116
+ 0.05250183865427971,
117
+ 0.06724748760461807,
118
+ 0.07919331640005112,
119
+ 0.06452416628599167,
120
+ 0.06678760051727295,
121
+ 0.15427803993225098,
122
+ 0.04608333483338356,
123
+ 0.026074513792991638,
124
+ 0.06458891183137894,
125
+ 0.06116756051778793,
126
+ 0.050854653120040894,
127
+ 0.06727714836597443,
128
+ 0.04326387867331505,
129
+ 0.037447478622198105,
130
+ 0.0521712489426136,
131
+ 0.0674269050359726,
132
+ 0.07980889827013016,
133
+ 0.06633622944355011,
134
+ 0.06805833429098129,
135
+ 0.07028443366289139,
136
+ 0.043847665190696716,
137
+ 0.025734560564160347,
138
+ 0.06363274157047272,
139
+ 0.06513398140668869,
140
+ 0.0585525706410408,
141
+ 0.06366591155529022,
142
+ 0.039940472692251205
143
+ ]
144
+ },
145
+ "fid": {
146
+ "avg": 42.43536594462592,
147
+ "vals": [
148
+ 20.22960865639975,
149
+ 47.958607601608676,
150
+ 57.62921688800561,
151
+ 95.16926849364827,
152
+ 52.86706191106997,
153
+ 40.88534322140695,
154
+ 66.72949840511895,
155
+ 24.21628873128267,
156
+ 18.986568909208156,
157
+ 33.69661265603083,
158
+ 24.557634338057728,
159
+ 54.261117389557846,
160
+ 51.156090096940055,
161
+ 22.082607032341244,
162
+ 24.085151411940856,
163
+ 45.16219303566392,
164
+ 51.10103371191223,
165
+ 94.4284018519586,
166
+ 61.20478944897824,
167
+ 53.09587418802096,
168
+ 45.77148143461985,
169
+ 25.706806643970687,
170
+ 17.209723468709765,
171
+ 38.78578933921177,
172
+ 25.072541814813434,
173
+ 41.178136179334956,
174
+ 37.296019926311956,
175
+ 27.822427143452785,
176
+ 22.649814443169078,
177
+ 37.95954774350004,
178
+ 47.89610979304835,
179
+ 90.26735601619404,
180
+ 50.99027170447254,
181
+ 47.766544543497034,
182
+ 74.45788273316838,
183
+ 22.916864843654203,
184
+ 17.551236020192366,
185
+ 29.612576631133486,
186
+ 32.3629873396745,
187
+ 46.47368512482174,
188
+ 39.25242590132321,
189
+ 23.782172906863416
190
+ ]
191
+ }
192
+ },
193
+ "base": {
194
+ "ssim": {
195
+ "avg": 0.8774991716657367,
196
+ "vals": [
197
+ 0.8641023635864258,
198
+ 0.8767493367195129,
199
+ 0.8762156367301941,
200
+ 0.90242999792099,
201
+ 0.9017764925956726,
202
+ 0.8788102269172668,
203
+ 0.8805813789367676,
204
+ 0.8612765669822693,
205
+ 0.8842142224311829,
206
+ 0.8914286494255066,
207
+ 0.8772357106208801,
208
+ 0.8998511433601379,
209
+ 0.8634107112884521,
210
+ 0.8928578495979309,
211
+ 0.8836771845817566,
212
+ 0.8996903300285339,
213
+ 0.866424560546875,
214
+ 0.8816824555397034,
215
+ 0.8785967230796814,
216
+ 0.8867073655128479,
217
+ 0.8279438614845276,
218
+ 0.8717477321624756,
219
+ 0.8780853748321533,
220
+ 0.8800415992736816,
221
+ 0.8771164417266846,
222
+ 0.8793924450874329,
223
+ 0.8809206485748291,
224
+ 0.8608385920524597,
225
+ 0.893017590045929,
226
+ 0.8689785003662109,
227
+ 0.8729929327964783,
228
+ 0.8567183017730713,
229
+ 0.8588916659355164,
230
+ 0.8677377104759216,
231
+ 0.8715691566467285,
232
+ 0.8649471402168274,
233
+ 0.884833037853241,
234
+ 0.913370668888092,
235
+ 0.8970480561256409,
236
+ 0.8238387107849121,
237
+ 0.8899227976799011,
238
+ 0.8872933387756348
239
+ ]
240
+ },
241
+ "psnr": {
242
+ "avg": 21.225323061553667,
243
+ "vals": [
244
+ 18.731802028716977,
245
+ 22.517723678922334,
246
+ 21.373543493859014,
247
+ 22.758004432454015,
248
+ 22.955530937204767,
249
+ 20.145150147854874,
250
+ 21.634592878743387,
251
+ 20.379519268519918,
252
+ 23.696208467856458,
253
+ 22.98936924589105,
254
+ 21.756553483824376,
255
+ 23.69843738135742,
256
+ 17.770960047525655,
257
+ 22.7351530856508,
258
+ 20.95728828090057,
259
+ 23.440357115988917,
260
+ 20.458007935114516,
261
+ 19.903659717164693,
262
+ 19.636953200826923,
263
+ 21.34791655058212,
264
+ 19.213297117979387,
265
+ 21.474593210996353,
266
+ 23.02245609312815,
267
+ 21.57037201859069,
268
+ 20.49675090194842,
269
+ 21.149239825321867,
270
+ 21.768542181018454,
271
+ 20.75462471848509,
272
+ 23.340499225598812,
273
+ 20.184658098890935,
274
+ 19.84577062228628,
275
+ 18.651530199516753,
276
+ 17.713097904145002,
277
+ 19.501488720897637,
278
+ 20.525847469970195,
279
+ 20.044060320208107,
280
+ 23.33142797550424,
281
+ 25.84051181341309,
282
+ 24.517940621682737,
283
+ 16.233878462859337,
284
+ 21.53839254074825,
285
+ 21.857857163105283
286
+ ]
287
+ },
288
+ "lpips": {
289
+ "avg": 0.17535314992779777,
290
+ "vals": [
291
+ 0.1866557002067566,
292
+ 0.19489219784736633,
293
+ 0.1676192283630371,
294
+ 0.13815467059612274,
295
+ 0.1378740817308426,
296
+ 0.19519048929214478,
297
+ 0.1394621878862381,
298
+ 0.18285232782363892,
299
+ 0.1300656795501709,
300
+ 0.13945598900318146,
301
+ 0.1777152568101883,
302
+ 0.1422179788351059,
303
+ 0.2465370148420334,
304
+ 0.173164963722229,
305
+ 0.16046200692653656,
306
+ 0.1408388763666153,
307
+ 0.1748061180114746,
308
+ 0.15652400255203247,
309
+ 0.18918247520923615,
310
+ 0.13944856822490692,
311
+ 0.3080754280090332,
312
+ 0.15395303070545197,
313
+ 0.13146977126598358,
314
+ 0.21760974824428558,
315
+ 0.16546086966991425,
316
+ 0.18631187081336975,
317
+ 0.17548373341560364,
318
+ 0.2248706817626953,
319
+ 0.14110884070396423,
320
+ 0.1748245805501938,
321
+ 0.17234086990356445,
322
+ 0.1958460807800293,
323
+ 0.22889724373817444,
324
+ 0.1727704405784607,
325
+ 0.15012094378471375,
326
+ 0.20727047324180603,
327
+ 0.1201525330543518,
328
+ 0.09184442460536957,
329
+ 0.11003588140010834,
330
+ 0.34251728653907776,
331
+ 0.18208707869052887,
332
+ 0.19866067171096802
333
+ ]
334
+ },
335
+ "fid": {
336
+ "avg": 104.964235596045,
337
+ "vals": [
338
+ 111.40174829746752,
339
+ 148.92090091433516,
340
+ 128.51386183067834,
341
+ 61.321546832998955,
342
+ 72.54608643303393,
343
+ 119.84550665904197,
344
+ 104.43999062002221,
345
+ 60.195703096095556,
346
+ 64.8998572256939,
347
+ 49.645917433260855,
348
+ 114.28769797616002,
349
+ 66.89825805245846,
350
+ 155.83362493523907,
351
+ 88.44848612474361,
352
+ 104.61621566440316,
353
+ 163.43248950817738,
354
+ 172.53727330437798,
355
+ 101.98603573984371,
356
+ 125.94605415324145,
357
+ 102.60611610457259,
358
+ 104.21668109264168,
359
+ 102.24131609759299,
360
+ 70.50377311449385,
361
+ 161.15530398717667,
362
+ 96.17823491839759,
363
+ 105.78022442714524,
364
+ 120.89941533798705,
365
+ 94.75931214260221,
366
+ 73.59833927600653,
367
+ 112.51851327401378,
368
+ 148.31978447992762,
369
+ 113.95507129041124,
370
+ 82.26377459422716,
371
+ 154.98015860895285,
372
+ 96.10781829434029,
373
+ 100.0667395138281,
374
+ 83.65808540884001,
375
+ 50.67035884888851,
376
+ 64.41624770555333,
377
+ 123.2744913103858,
378
+ 106.94257164137686,
379
+ 123.66830876325491
380
+ ]
381
+ }
382
+ }
383
+ },
384
+ "aaa": {
385
+ "ft": {
386
+ "ssim": {
387
+ "avg": 0.8912452217694875,
388
+ "vals": [
389
+ 0.9061050415039062,
390
+ 0.9104815125465393,
391
+ 0.9111503958702087,
392
+ 0.8999118208885193,
393
+ 0.9075160026550293,
394
+ 0.9037303924560547,
395
+ 0.9100475311279297,
396
+ 0.9105749130249023,
397
+ 0.9070033431053162,
398
+ 0.928657054901123,
399
+ 0.9301331043243408,
400
+ 0.9153156280517578,
401
+ 0.9060927033424377,
402
+ 0.9026443362236023,
403
+ 0.9006186127662659,
404
+ 0.8967500329017639,
405
+ 0.8858065605163574,
406
+ 0.8841550946235657,
407
+ 0.8957152366638184,
408
+ 0.9159244894981384,
409
+ 0.9212849140167236,
410
+ 0.8954355716705322,
411
+ 0.882415771484375,
412
+ 0.8875539898872375,
413
+ 0.8762226700782776,
414
+ 0.8741069436073303,
415
+ 0.8663904070854187,
416
+ 0.8608574867248535,
417
+ 0.8592395782470703,
418
+ 0.8617714047431946,
419
+ 0.8670153021812439,
420
+ 0.8589889407157898,
421
+ 0.8667375445365906,
422
+ 0.8646472096443176,
423
+ 0.8644879460334778,
424
+ 0.8704419732093811,
425
+ 0.8701417446136475
426
+ ]
427
+ },
428
+ "psnr": {
429
+ "avg": 23.218790631433333,
430
+ "vals": [
431
+ 22.60000443641932,
432
+ 23.213327330683896,
433
+ 23.966372308676554,
434
+ 22.61628135760781,
435
+ 23.334516149012956,
436
+ 22.76956956037783,
437
+ 23.522473461596828,
438
+ 23.818158050349368,
439
+ 24.35100933616721,
440
+ 27.585580066737258,
441
+ 27.485016628132442,
442
+ 25.653444844066296,
443
+ 24.116901996142047,
444
+ 23.141365324276826,
445
+ 23.439784068380483,
446
+ 23.33054311947808,
447
+ 22.82591865046196,
448
+ 22.71185043604199,
449
+ 24.538035778803117,
450
+ 27.826709590087102,
451
+ 28.474405666001857,
452
+ 24.769433736339693,
453
+ 23.449171903609013,
454
+ 23.46383536533758,
455
+ 21.324737956273246,
456
+ 21.96812016491501,
457
+ 20.7294582590173,
458
+ 20.641701301001603,
459
+ 20.774969220796272,
460
+ 21.108372009572886,
461
+ 22.00821781586527,
462
+ 20.4745834349127,
463
+ 21.075791929560445,
464
+ 21.104358437386903,
465
+ 21.55546473865332,
466
+ 22.000344888516832,
467
+ 21.32542404177411
468
+ ]
469
+ },
470
+ "lpips": {
471
+ "avg": 0.10130888052486084,
472
+ "vals": [
473
+ 0.14061428606510162,
474
+ 0.14724630117416382,
475
+ 0.1335088610649109,
476
+ 0.13316522538661957,
477
+ 0.12164061516523361,
478
+ 0.11964336037635803,
479
+ 0.10177505016326904,
480
+ 0.10693053156137466,
481
+ 0.0958232507109642,
482
+ 0.05439213290810585,
483
+ 0.05292022228240967,
484
+ 0.07251676917076111,
485
+ 0.09491521865129471,
486
+ 0.11164701730012894,
487
+ 0.1037021279335022,
488
+ 0.09932567179203033,
489
+ 0.0995492935180664,
490
+ 0.0946572870016098,
491
+ 0.07657209038734436,
492
+ 0.04164290428161621,
493
+ 0.04009518772363663,
494
+ 0.06518009305000305,
495
+ 0.07143466174602509,
496
+ 0.07886430621147156,
497
+ 0.09445148706436157,
498
+ 0.11044608056545258,
499
+ 0.1215718686580658,
500
+ 0.12195122241973877,
501
+ 0.12177268415689468,
502
+ 0.13400736451148987,
503
+ 0.1025841161608696,
504
+ 0.11698935180902481,
505
+ 0.11627939343452454,
506
+ 0.11498457193374634,
507
+ 0.11028993129730225,
508
+ 0.10681547224521637,
509
+ 0.11852256953716278
510
+ ]
511
+ },
512
+ "fid": {
513
+ "avg": 128.47027428857402,
514
+ "vals": [
515
+ 210.6965716493914,
516
+ 136.02845249812458,
517
+ 142.9994319683501,
518
+ 102.24152561128506,
519
+ 106.32239193197034,
520
+ 133.25417613786345,
521
+ 148.41552185542787,
522
+ 156.6543860781757,
523
+ 169.9280658509937,
524
+ 96.20484628876847,
525
+ 124.39223402724446,
526
+ 157.57056317546318,
527
+ 182.60839379211808,
528
+ 143.061621048883,
529
+ 129.08572262883925,
530
+ 146.6837514000539,
531
+ 121.28023529371208,
532
+ 133.09083552748842,
533
+ 133.1701816777087,
534
+ 106.30667335984795,
535
+ 92.99936563096986,
536
+ 106.42616281960741,
537
+ 91.02325698154696,
538
+ 82.03384095285311,
539
+ 100.23715859488628,
540
+ 103.07187891625573,
541
+ 126.73029941764449,
542
+ 120.07365192426221,
543
+ 120.96345164031321,
544
+ 116.58719477188482,
545
+ 167.96926051103426,
546
+ 180.28950392287587,
547
+ 105.1495396073212,
548
+ 140.74437360480772,
549
+ 99.20108044973863,
550
+ 113.96675633000953,
551
+ 105.93779079951634
552
+ ]
553
+ }
554
+ },
555
+ "base": {
556
+ "ssim": {
557
+ "avg": 0.8579733484500164,
558
+ "vals": [
559
+ 0.8666160702705383,
560
+ 0.7987958788871765,
561
+ 0.8860695362091064,
562
+ 0.8763647079467773,
563
+ 0.8763935565948486,
564
+ 0.8563421368598938,
565
+ 0.8911734223365784,
566
+ 0.8716912269592285,
567
+ 0.8579948544502258,
568
+ 0.8772208094596863,
569
+ 0.8896298408508301,
570
+ 0.8552396893501282,
571
+ 0.8748952746391296,
572
+ 0.8613572120666504,
573
+ 0.8676340579986572,
574
+ 0.8660659193992615,
575
+ 0.8730340600013733,
576
+ 0.8534577488899231,
577
+ 0.8619491457939148,
578
+ 0.8790878653526306,
579
+ 0.869230329990387,
580
+ 0.8246889114379883,
581
+ 0.8481850624084473,
582
+ 0.8987539410591125,
583
+ 0.8677256107330322,
584
+ 0.8619834780693054,
585
+ 0.8389760851860046,
586
+ 0.8448922634124756,
587
+ 0.8406286239624023,
588
+ 0.805531919002533,
589
+ 0.8549509048461914,
590
+ 0.8503775596618652,
591
+ 0.8362762928009033,
592
+ 0.8465806841850281,
593
+ 0.8307406902313232,
594
+ 0.850950300693512,
595
+ 0.8335282206535339
596
+ ]
597
+ },
598
+ "psnr": {
599
+ "avg": 19.081316861697736,
600
+ "vals": [
601
+ 16.88895564145245,
602
+ 15.797760205800826,
603
+ 18.41262183621939,
604
+ 18.562114490640262,
605
+ 17.917161492127512,
606
+ 18.252547952717784,
607
+ 19.64865146927212,
608
+ 17.6865025569328,
609
+ 17.99347739896675,
610
+ 18.854024365954317,
611
+ 19.934364132621422,
612
+ 18.675988138292176,
613
+ 19.247068668941317,
614
+ 19.570660435858237,
615
+ 17.830968802374684,
616
+ 19.397491216401374,
617
+ 21.611632805892775,
618
+ 20.071126840114104,
619
+ 19.438827872416823,
620
+ 22.09851766389441,
621
+ 20.985060014772596,
622
+ 17.811881997736606,
623
+ 20.828426349668014,
624
+ 21.464988324765418,
625
+ 19.570091780949138,
626
+ 20.7044778911251,
627
+ 18.34111665771247,
628
+ 19.15076106907652,
629
+ 20.11914583622098,
630
+ 16.338984721289947,
631
+ 21.540692338174694,
632
+ 19.138283664226506,
633
+ 17.11142702905118,
634
+ 18.437481626987633,
635
+ 18.706711870048608,
636
+ 20.13241290437168,
637
+ 17.73631581974743
638
+ ]
639
+ },
640
+ "lpips": {
641
+ "avg": 0.2180521393547187,
642
+ "vals": [
643
+ 0.24907447397708893,
644
+ 0.5150208473205566,
645
+ 0.21998971700668335,
646
+ 0.30011799931526184,
647
+ 0.24314486980438232,
648
+ 0.2725614607334137,
649
+ 0.2167254090309143,
650
+ 0.22530333697795868,
651
+ 0.2836877703666687,
652
+ 0.2629338204860687,
653
+ 0.16042795777320862,
654
+ 0.2589859664440155,
655
+ 0.18903297185897827,
656
+ 0.19533997774124146,
657
+ 0.27064773440361023,
658
+ 0.1686403751373291,
659
+ 0.12280213087797165,
660
+ 0.25335609912872314,
661
+ 0.29388710856437683,
662
+ 0.09789006412029266,
663
+ 0.1359485536813736,
664
+ 0.37839996814727783,
665
+ 0.16171419620513916,
666
+ 0.09748921543359756,
667
+ 0.11395677924156189,
668
+ 0.13522587716579437,
669
+ 0.20320472121238708,
670
+ 0.20818984508514404,
671
+ 0.16585132479667664,
672
+ 0.2417808175086975,
673
+ 0.11952673643827438,
674
+ 0.15657898783683777,
675
+ 0.21135613322257996,
676
+ 0.20219209790229797,
677
+ 0.2686466574668884,
678
+ 0.16088804602622986,
679
+ 0.3074091076850891
680
+ ]
681
+ },
682
+ "fid": {
683
+ "avg": 168.1660536989984,
684
+ "vals": [
685
+ 247.82225753332435,
686
+ 261.94302538502393,
687
+ 148.08945162416452,
688
+ 156.38665720624414,
689
+ 164.2083172443518,
690
+ 152.55262134635595,
691
+ 167.74632833375753,
692
+ 186.55618400333205,
693
+ 237.13432800336085,
694
+ 295.59781724649554,
695
+ 183.50749905003778,
696
+ 233.3946841940326,
697
+ 127.19021998668687,
698
+ 237.64574583795806,
699
+ 212.50245640676792,
700
+ 138.55530819236137,
701
+ 214.2925259680326,
702
+ 206.29937532009325,
703
+ 166.21050755154022,
704
+ 78.77898145774333,
705
+ 102.08709635256376,
706
+ 205.22517444539076,
707
+ 103.51534026875984,
708
+ 79.41715249937273,
709
+ 119.52601444212452,
710
+ 103.10200144999655,
711
+ 177.805092182125,
712
+ 132.67670025242347,
713
+ 131.8997749255857,
714
+ 137.54224584649327,
715
+ 132.2298444263562,
716
+ 220.02890730124264,
717
+ 148.88606776876063,
718
+ 136.47310567443643,
719
+ 165.10752223661132,
720
+ 125.63803892245707,
721
+ 184.56961597657588
722
+ ]
723
+ }
724
+ }
725
+ },
726
+ "azri": {
727
+ "ft": {
728
+ "ssim": {
729
+ "avg": 0.9058952973439143,
730
+ "vals": [
731
+ 0.8917362689971924,
732
+ 0.9481084942817688,
733
+ 0.8861625790596008,
734
+ 0.9341784119606018,
735
+ 0.9068629145622253,
736
+ 0.9030451774597168,
737
+ 0.9057510495185852,
738
+ 0.8903522491455078,
739
+ 0.8899074196815491,
740
+ 0.8972444534301758,
741
+ 0.9205942153930664,
742
+ 0.9142631888389587,
743
+ 0.8862788081169128,
744
+ 0.9196522235870361,
745
+ 0.8918246626853943,
746
+ 0.9560407996177673,
747
+ 0.901726484298706,
748
+ 0.8848044276237488,
749
+ 0.8926780819892883,
750
+ 0.8907856941223145,
751
+ 0.9449396729469299,
752
+ 0.8873879313468933,
753
+ 0.9299740791320801,
754
+ 0.904637336730957,
755
+ 0.9079319834709167,
756
+ 0.9113084673881531,
757
+ 0.8961232304573059,
758
+ 0.881411075592041,
759
+ 0.891535222530365,
760
+ 0.9156699776649475,
761
+ 0.9131000638008118,
762
+ 0.8913257122039795,
763
+ 0.9246047139167786,
764
+ 0.8882644772529602,
765
+ 0.9553866982460022,
766
+ 0.8979193568229675,
767
+ 0.884848415851593,
768
+ 0.8858329653739929,
769
+ 0.8967573046684265,
770
+ 0.9442141652107239,
771
+ 0.8944272398948669,
772
+ 0.9422566294670105,
773
+ 0.9031774401664734,
774
+ 0.9035826325416565,
775
+ 0.9115505814552307,
776
+ 0.8966580033302307,
777
+ 0.8833451867103577,
778
+ 0.8935332894325256,
779
+ 0.9135630130767822,
780
+ 0.9111640453338623,
781
+ 0.8814460635185242,
782
+ 0.9066808819770813
783
+ ]
784
+ },
785
+ "psnr": {
786
+ "avg": 22.26210885067453,
787
+ "vals": [
788
+ 21.127093280116206,
789
+ 24.839156863255635,
790
+ 20.600073186747085,
791
+ 25.025180077941343,
792
+ 22.081731981690943,
793
+ 22.42101370484669,
794
+ 22.81864654433271,
795
+ 21.18838097936547,
796
+ 22.058042539733158,
797
+ 21.10834360782046,
798
+ 23.075382471633933,
799
+ 22.544052476196704,
800
+ 19.00360259824341,
801
+ 21.92184820041113,
802
+ 21.430364755014008,
803
+ 28.074184819479914,
804
+ 22.007287073710827,
805
+ 20.596292100864723,
806
+ 22.0403525229489,
807
+ 20.915785732740424,
808
+ 24.24433505144897,
809
+ 20.759471581753495,
810
+ 23.580874095639594,
811
+ 21.765209438099404,
812
+ 22.735731723488712,
813
+ 23.165893102346963,
814
+ 21.193429487248192,
815
+ 20.883114381461837,
816
+ 20.509329225340185,
817
+ 22.74356627523407,
818
+ 23.076752632172514,
819
+ 20.80786955805867,
820
+ 22.356163191215202,
821
+ 21.454672093200195,
822
+ 28.091895600685465,
823
+ 21.89353215478105,
824
+ 20.445727082361902,
825
+ 21.73245024548928,
826
+ 21.415697749014967,
827
+ 25.255915842679492,
828
+ 21.344568596586,
829
+ 25.736038367482795,
830
+ 21.80143662440821,
831
+ 22.556710854506832,
832
+ 23.011730483476125,
833
+ 21.360129615824693,
834
+ 21.183864623422796,
835
+ 20.786549475403824,
836
+ 23.05387236898404,
837
+ 22.61546996399324,
838
+ 18.695781286483605,
839
+ 22.495061945689415
840
+ ]
841
+ },
842
+ "lpips": {
843
+ "avg": 0.0681791385420813,
844
+ "vals": [
845
+ 0.07602706551551819,
846
+ 0.04024939239025116,
847
+ 0.08144165575504303,
848
+ 0.04438445717096329,
849
+ 0.07352830469608307,
850
+ 0.05720607936382294,
851
+ 0.06115525960922241,
852
+ 0.08031078428030014,
853
+ 0.07755088806152344,
854
+ 0.08515866100788116,
855
+ 0.058882661163806915,
856
+ 0.07514132559299469,
857
+ 0.08883378654718399,
858
+ 0.08077041804790497,
859
+ 0.0614573135972023,
860
+ 0.017901955172419548,
861
+ 0.06920649111270905,
862
+ 0.07639990746974945,
863
+ 0.07837355881929398,
864
+ 0.0728088989853859,
865
+ 0.04596266895532608,
866
+ 0.09040576219558716,
867
+ 0.057428933680057526,
868
+ 0.07047676295042038,
869
+ 0.054519202560186386,
870
+ 0.0576803982257843,
871
+ 0.0698535293340683,
872
+ 0.08105408400297165,
873
+ 0.09364423155784607,
874
+ 0.06507238000631332,
875
+ 0.07589171826839447,
876
+ 0.0750972181558609,
877
+ 0.074916310608387,
878
+ 0.062267474830150604,
879
+ 0.021616671234369278,
880
+ 0.06903669983148575,
881
+ 0.07748014479875565,
882
+ 0.08583179116249084,
883
+ 0.06618104875087738,
884
+ 0.05046549439430237,
885
+ 0.07481929659843445,
886
+ 0.042737700045108795,
887
+ 0.07031229138374329,
888
+ 0.05742187052965164,
889
+ 0.058822184801101685,
890
+ 0.08137130737304688,
891
+ 0.08340458571910858,
892
+ 0.08779360353946686,
893
+ 0.05354199558496475,
894
+ 0.0719795972108841,
895
+ 0.09594655781984329,
896
+ 0.06549282371997833
897
+ ]
898
+ },
899
+ "fid": {
900
+ "avg": 87.69454384414368,
901
+ "vals": [
902
+ 77.53360222199552,
903
+ 81.69009799979364,
904
+ 126.35462309879928,
905
+ 128.097754928267,
906
+ 117.15623067707392,
907
+ 98.01858464323817,
908
+ 77.96826502055629,
909
+ 78.41206612900142,
910
+ 110.78376599817574,
911
+ 63.587186575144706,
912
+ 54.89619147757511,
913
+ 84.46815633637291,
914
+ 76.09334109513291,
915
+ 89.06784327761028,
916
+ 112.09989177281398,
917
+ 77.76499620466669,
918
+ 127.0296162823459,
919
+ 56.2819223080697,
920
+ 69.43078527369693,
921
+ 64.65111942898734,
922
+ 62.87374700137458,
923
+ 128.69000277570802,
924
+ 85.75299650509533,
925
+ 101.57266095839137,
926
+ 69.76067120877498,
927
+ 86.34886735845666,
928
+ 68.48340578638442,
929
+ 98.73242401936356,
930
+ 63.17021113035836,
931
+ 65.11579193591135,
932
+ 77.95177629642525,
933
+ 89.78335003024443,
934
+ 86.28969096009206,
935
+ 119.6802369335825,
936
+ 89.66745457666659,
937
+ 132.85212105669805,
938
+ 114.10763350784305,
939
+ 103.6134398480294,
940
+ 59.686968923134835,
941
+ 82.54559469802132,
942
+ 111.1869442643463,
943
+ 85.38777030577579,
944
+ 106.32082608988044,
945
+ 106.12593735864233,
946
+ 86.00820771852335,
947
+ 64.28650309776683,
948
+ 102.70937080791995,
949
+ 47.320553137933764,
950
+ 53.731227883240436,
951
+ 71.80043885301173,
952
+ 96.20161778732583,
953
+ 70.97179633123264
954
+ ]
955
+ }
956
+ },
957
+ "base": {
958
+ "ssim": {
959
+ "avg": 0.765552927668278,
960
+ "vals": [
961
+ 0.7876270413398743,
962
+ 0.6932912468910217,
963
+ 0.7860949039459229,
964
+ 0.8119332194328308,
965
+ 0.7923873066902161,
966
+ 0.7616052031517029,
967
+ 0.7635032534599304,
968
+ 0.752696692943573,
969
+ 0.7297654151916504,
970
+ 0.7741971015930176,
971
+ 0.7830759882926941,
972
+ 0.787074089050293,
973
+ 0.7227017283439636,
974
+ 0.7842667698860168,
975
+ 0.7618840336799622,
976
+ 0.7612975239753723,
977
+ 0.7881305813789368,
978
+ 0.7242482304573059,
979
+ 0.7725765109062195,
980
+ 0.7895841598510742,
981
+ 0.7584860920906067,
982
+ 0.7806427478790283,
983
+ 0.7808322906494141,
984
+ 0.8068225979804993,
985
+ 0.7732831835746765,
986
+ 0.7021427154541016,
987
+ 0.7602499127388,
988
+ 0.7775993943214417,
989
+ 0.7879728674888611,
990
+ 0.7852631211280823,
991
+ 0.7593006491661072,
992
+ 0.7491123080253601,
993
+ 0.8211724162101746,
994
+ 0.791597306728363,
995
+ 0.8033311367034912,
996
+ 0.751656711101532,
997
+ 0.7145156860351562,
998
+ 0.7085480690002441,
999
+ 0.7710719704627991,
1000
+ 0.7261748313903809,
1001
+ 0.7700297236442566,
1002
+ 0.721587598323822,
1003
+ 0.769636869430542,
1004
+ 0.7924219965934753,
1005
+ 0.7783514857292175,
1006
+ 0.7467233538627625,
1007
+ 0.7454271912574768,
1008
+ 0.7511434555053711,
1009
+ 0.7623993754386902,
1010
+ 0.7588648796081543,
1011
+ 0.7819810509681702,
1012
+ 0.792468249797821
1013
+ ]
1014
+ },
1015
+ "psnr": {
1016
+ "avg": 12.04275296711651,
1017
+ "vals": [
1018
+ 11.831093266628356,
1019
+ 12.528152127082233,
1020
+ 11.504680724545494,
1021
+ 12.06579748941828,
1022
+ 12.193623545830743,
1023
+ 12.967612777094432,
1024
+ 11.22349001603837,
1025
+ 12.676957730222881,
1026
+ 10.463052935595755,
1027
+ 12.068058666401921,
1028
+ 11.815425022083929,
1029
+ 11.474556127083932,
1030
+ 11.65343712380346,
1031
+ 12.970802840776267,
1032
+ 10.921165557898965,
1033
+ 11.487784206793332,
1034
+ 11.862155399851087,
1035
+ 12.500533116810459,
1036
+ 13.26478322723994,
1037
+ 11.454244322299186,
1038
+ 12.846474859150605,
1039
+ 11.178912278082944,
1040
+ 11.668886917161911,
1041
+ 13.535748073878615,
1042
+ 11.74158297715729,
1043
+ 11.307205931581416,
1044
+ 11.122995221603784,
1045
+ 12.315128915653881,
1046
+ 11.927970316701074,
1047
+ 11.471458996386659,
1048
+ 10.8378095973721,
1049
+ 12.40354222062087,
1050
+ 13.883183470169415,
1051
+ 12.746526277133636,
1052
+ 13.423830377359124,
1053
+ 12.183072443407536,
1054
+ 11.745039572219284,
1055
+ 11.288605142920026,
1056
+ 10.811118546539486,
1057
+ 11.492946654674247,
1058
+ 12.552522641514392,
1059
+ 12.042564039263278,
1060
+ 11.3939431319692,
1061
+ 12.898514335036811,
1062
+ 11.114529050567972,
1063
+ 11.339446334579078,
1064
+ 12.24930476842462,
1065
+ 13.364881567497516,
1066
+ 13.445838873213505,
1067
+ 12.214397326986626,
1068
+ 12.84755151678836,
1069
+ 11.900215690944371
1070
+ ]
1071
+ },
1072
+ "lpips": {
1073
+ "avg": 0.3510053243774634,
1074
+ "vals": [
1075
+ 0.3419290781021118,
1076
+ 0.4748748242855072,
1077
+ 0.3430050015449524,
1078
+ 0.3214326500892639,
1079
+ 0.28315189480781555,
1080
+ 0.38122087717056274,
1081
+ 0.40785902738571167,
1082
+ 0.283894807100296,
1083
+ 0.28781914710998535,
1084
+ 0.31599509716033936,
1085
+ 0.31139951944351196,
1086
+ 0.40369728207588196,
1087
+ 0.4894903302192688,
1088
+ 0.3339408040046692,
1089
+ 0.3329698443412781,
1090
+ 0.2978940010070801,
1091
+ 0.39562928676605225,
1092
+ 0.3887198865413666,
1093
+ 0.3585241734981537,
1094
+ 0.31963711977005005,
1095
+ 0.31140702962875366,
1096
+ 0.3825278878211975,
1097
+ 0.3666013181209564,
1098
+ 0.24337142705917358,
1099
+ 0.36876243352890015,
1100
+ 0.36285972595214844,
1101
+ 0.33685219287872314,
1102
+ 0.3693651854991913,
1103
+ 0.2829691767692566,
1104
+ 0.30595219135284424,
1105
+ 0.3462897539138794,
1106
+ 0.49221092462539673,
1107
+ 0.35862377285957336,
1108
+ 0.2963302731513977,
1109
+ 0.2860722541809082,
1110
+ 0.33742064237594604,
1111
+ 0.46653813123703003,
1112
+ 0.4688953757286072,
1113
+ 0.3590819835662842,
1114
+ 0.3915751874446869,
1115
+ 0.43592336773872375,
1116
+ 0.4350709915161133,
1117
+ 0.3321887254714966,
1118
+ 0.2887817919254303,
1119
+ 0.3315331041812897,
1120
+ 0.34705182909965515,
1121
+ 0.29848411679267883,
1122
+ 0.3338838815689087,
1123
+ 0.3551400601863861,
1124
+ 0.2849394679069519,
1125
+ 0.276046484708786,
1126
+ 0.32644152641296387
1127
+ ]
1128
+ },
1129
+ "fid": {
1130
+ "avg": 290.9468764478715,
1131
+ "vals": [
1132
+ 317.3580192368432,
1133
+ 285.3201578307224,
1134
+ 227.35443388770344,
1135
+ 284.40224085168455,
1136
+ 250.42273416927597,
1137
+ 332.1637800582306,
1138
+ 421.4826897186527,
1139
+ 175.65665020370423,
1140
+ 462.80212160877016,
1141
+ 313.11887539989885,
1142
+ 216.874426261107,
1143
+ 318.69325303966303,
1144
+ 352.27748822517265,
1145
+ 337.56780348968255,
1146
+ 333.41852463427074,
1147
+ 291.18619781007527,
1148
+ 231.22305793958952,
1149
+ 182.59524666699008,
1150
+ 174.66560996071973,
1151
+ 287.25586265039675,
1152
+ 245.32536993670467,
1153
+ 340.7025359683132,
1154
+ 430.81577040883786,
1155
+ 151.6384776188366,
1156
+ 179.9751120645675,
1157
+ 262.16017361838544,
1158
+ 252.1575376021156,
1159
+ 383.6182322995117,
1160
+ 293.8831833204783,
1161
+ 238.37454832172995,
1162
+ 253.2247312385866,
1163
+ 300.3330815210274,
1164
+ 263.0408874830202,
1165
+ 308.8415900149356,
1166
+ 208.06900694354897,
1167
+ 312.3743463947788,
1168
+ 281.0093276692239,
1169
+ 401.06852280176753,
1170
+ 412.57353143782905,
1171
+ 245.6395460326994,
1172
+ 330.86620573911364,
1173
+ 396.22633300703814,
1174
+ 253.3246381630602,
1175
+ 243.78755013247724,
1176
+ 343.90170319814894,
1177
+ 411.48625023760695,
1178
+ 426.23161153775044,
1179
+ 167.20827449948533,
1180
+ 253.45101894407023,
1181
+ 249.6648801639487,
1182
+ 233.26966302525298,
1183
+ 259.15476030131083
1184
+ ]
1185
+ }
1186
+ }
1187
+ },
1188
+ "dead": {
1189
+ "ft": {
1190
+ "ssim": {
1191
+ "avg": 0.8463698829475202,
1192
+ "vals": [
1193
+ 0.819096565246582,
1194
+ 0.8230092525482178,
1195
+ 0.8352133631706238,
1196
+ 0.9176575541496277,
1197
+ 0.8429977297782898,
1198
+ 0.8367952704429626,
1199
+ 0.8534295558929443,
1200
+ 0.838138997554779,
1201
+ 0.8528039455413818,
1202
+ 0.8470211029052734,
1203
+ 0.8170969486236572,
1204
+ 0.8159563541412354,
1205
+ 0.8493235111236572,
1206
+ 0.8536948561668396,
1207
+ 0.8294236063957214,
1208
+ 0.90614253282547,
1209
+ 0.8337429165840149,
1210
+ 0.8743107914924622,
1211
+ 0.8367175459861755,
1212
+ 0.8243162631988525,
1213
+ 0.8247227668762207,
1214
+ 0.8266720771789551,
1215
+ 0.9225577712059021,
1216
+ 0.8386837840080261,
1217
+ 0.8380022644996643,
1218
+ 0.8386256098747253,
1219
+ 0.8460710048675537,
1220
+ 0.8459582924842834,
1221
+ 0.835181713104248,
1222
+ 0.8186705708503723,
1223
+ 0.8184948563575745,
1224
+ 0.8417790532112122,
1225
+ 0.8489491939544678,
1226
+ 0.8365125060081482,
1227
+ 0.9095039963722229,
1228
+ 0.8317481875419617,
1229
+ 0.8886545300483704,
1230
+ 0.8443787097930908
1231
+ ]
1232
+ },
1233
+ "psnr": {
1234
+ "avg": 23.795184903621124,
1235
+ "vals": [
1236
+ 22.123992625667,
1237
+ 22.383136736110977,
1238
+ 22.815137157510414,
1239
+ 29.058328544219666,
1240
+ 23.458824312069826,
1241
+ 23.35115778300144,
1242
+ 23.899720737917086,
1243
+ 23.251382000597822,
1244
+ 23.8556550518376,
1245
+ 24.496927997944468,
1246
+ 22.229396380994544,
1247
+ 21.835705931258186,
1248
+ 24.044979173806603,
1249
+ 23.87095747368083,
1250
+ 22.755647052287486,
1251
+ 27.255022608267765,
1252
+ 23.384111730889643,
1253
+ 24.47027594982054,
1254
+ 22.969660848990017,
1255
+ 22.118529030027556,
1256
+ 22.65373935684274,
1257
+ 22.26046784537836,
1258
+ 29.458407087230075,
1259
+ 23.06010458055496,
1260
+ 23.57190723261845,
1261
+ 23.327430637055485,
1262
+ 23.589256344441864,
1263
+ 23.48439936889143,
1264
+ 23.067392439373478,
1265
+ 22.096704419878947,
1266
+ 21.619810610831216,
1267
+ 23.50251506589982,
1268
+ 23.94463595644848,
1269
+ 23.01993677097123,
1270
+ 28.729724692646393,
1271
+ 23.230957096762136,
1272
+ 26.24107924644081,
1273
+ 23.730008458437332
1274
+ ]
1275
+ },
1276
+ "lpips": {
1277
+ "avg": 0.07999030775145481,
1278
+ "vals": [
1279
+ 0.09683731198310852,
1280
+ 0.10403041541576385,
1281
+ 0.10699359327554703,
1282
+ 0.02980189025402069,
1283
+ 0.07957077771425247,
1284
+ 0.08596174418926239,
1285
+ 0.06019214540719986,
1286
+ 0.07882507890462875,
1287
+ 0.07737825810909271,
1288
+ 0.08655022084712982,
1289
+ 0.09879130125045776,
1290
+ 0.12470565736293793,
1291
+ 0.055472321808338165,
1292
+ 0.06304865330457687,
1293
+ 0.10779929161071777,
1294
+ 0.030872609466314316,
1295
+ 0.07079628109931946,
1296
+ 0.06584230065345764,
1297
+ 0.10193256288766861,
1298
+ 0.09854554384946823,
1299
+ 0.10243230313062668,
1300
+ 0.10806915909051895,
1301
+ 0.025890624150633812,
1302
+ 0.08075970411300659,
1303
+ 0.077939972281456,
1304
+ 0.06577856838703156,
1305
+ 0.07062944024801254,
1306
+ 0.08819691836833954,
1307
+ 0.10097920894622803,
1308
+ 0.105897456407547,
1309
+ 0.11288004368543625,
1310
+ 0.06032264977693558,
1311
+ 0.063026562333107,
1312
+ 0.10948407649993896,
1313
+ 0.030226124450564384,
1314
+ 0.06960238516330719,
1315
+ 0.053624995052814484,
1316
+ 0.0899435430765152
1317
+ ]
1318
+ },
1319
+ "fid": {
1320
+ "avg": 80.0617442123876,
1321
+ "vals": [
1322
+ 77.72583958495193,
1323
+ 108.01430563351622,
1324
+ 83.80114371601998,
1325
+ 43.79561757437186,
1326
+ 64.21238376481085,
1327
+ 107.90633495184883,
1328
+ 94.90226901876251,
1329
+ 81.66996436668452,
1330
+ 56.44677259421701,
1331
+ 133.90454895719694,
1332
+ 94.53365874681484,
1333
+ 103.25655943118333,
1334
+ 56.07493043457935,
1335
+ 76.20330391887549,
1336
+ 86.46052802881235,
1337
+ 35.14997667649733,
1338
+ 62.05464557602466,
1339
+ 82.75703146138048,
1340
+ 78.68766046633303,
1341
+ 95.9801870274251,
1342
+ 69.69440837156752,
1343
+ 133.97665503608607,
1344
+ 23.992440096968796,
1345
+ 73.95956883955756,
1346
+ 101.69397444074366,
1347
+ 96.09931824946253,
1348
+ 44.15662245914518,
1349
+ 83.0950768685819,
1350
+ 113.04304804723927,
1351
+ 123.68218057444115,
1352
+ 80.64456666041356,
1353
+ 51.48177417838611,
1354
+ 61.27753437751468,
1355
+ 89.59810342754149,
1356
+ 64.8439294361588,
1357
+ 63.560962531495846,
1358
+ 68.41856099127594,
1359
+ 75.58989355384159
1360
+ ]
1361
+ }
1362
+ },
1363
+ "base": {
1364
+ "ssim": {
1365
+ "avg": 0.7146695460143843,
1366
+ "vals": [
1367
+ 0.6920250058174133,
1368
+ 0.6781535148620605,
1369
+ 0.7473878860473633,
1370
+ 0.7988237738609314,
1371
+ 0.7046812176704407,
1372
+ 0.6728485226631165,
1373
+ 0.6648303866386414,
1374
+ 0.6507443785667419,
1375
+ 0.6572958827018738,
1376
+ 0.7580508589744568,
1377
+ 0.7515758872032166,
1378
+ 0.759639322757721,
1379
+ 0.7379322052001953,
1380
+ 0.7911661267280579,
1381
+ 0.7242826819419861,
1382
+ 0.6797776222229004,
1383
+ 0.7349328994750977,
1384
+ 0.7590611577033997,
1385
+ 0.6552085876464844,
1386
+ 0.745119035243988,
1387
+ 0.7352902293205261,
1388
+ 0.7022021412849426,
1389
+ 0.7449350357055664,
1390
+ 0.7506275773048401,
1391
+ 0.7204473614692688,
1392
+ 0.666022002696991,
1393
+ 0.7632265090942383,
1394
+ 0.7753145098686218,
1395
+ 0.6686270833015442,
1396
+ 0.6356704831123352,
1397
+ 0.6632815003395081,
1398
+ 0.7420862317085266,
1399
+ 0.6526774764060974,
1400
+ 0.7651051878929138,
1401
+ 0.7491373419761658,
1402
+ 0.6927105784416199,
1403
+ 0.7062166333198547,
1404
+ 0.6603279113769531
1405
+ ]
1406
+ },
1407
+ "psnr": {
1408
+ "avg": 17.075681593330103,
1409
+ "vals": [
1410
+ 16.02221993829837,
1411
+ 16.391894818273492,
1412
+ 19.158898065865564,
1413
+ 21.168902461771523,
1414
+ 15.849282724811843,
1415
+ 13.438762906336217,
1416
+ 14.041096876172848,
1417
+ 14.614652873958555,
1418
+ 14.836499194441464,
1419
+ 19.42617293541887,
1420
+ 19.6028175608849,
1421
+ 19.50721422138582,
1422
+ 17.430967194872665,
1423
+ 20.64046847536366,
1424
+ 17.37638733796819,
1425
+ 15.173079056741063,
1426
+ 18.277006035337127,
1427
+ 18.80925161485643,
1428
+ 14.095576974205786,
1429
+ 19.420502974160595,
1430
+ 19.454135471046165,
1431
+ 17.222296068007893,
1432
+ 18.225982648513018,
1433
+ 19.51789549326609,
1434
+ 16.618582564745015,
1435
+ 15.364057378648432,
1436
+ 18.71637998732089,
1437
+ 19.61078528607858,
1438
+ 14.215511993158811,
1439
+ 13.88753261057972,
1440
+ 10.856355225225222,
1441
+ 18.09338753925673,
1442
+ 14.179305881135015,
1443
+ 20.07084940675992,
1444
+ 19.650702290598378,
1445
+ 15.665586196724535,
1446
+ 16.945129336951908,
1447
+ 15.299770927402514
1448
+ ]
1449
+ },
1450
+ "lpips": {
1451
+ "avg": 0.2726066539946355,
1452
+ "vals": [
1453
+ 0.2813563346862793,
1454
+ 0.3176368176937103,
1455
+ 0.239101380109787,
1456
+ 0.1326800435781479,
1457
+ 0.31650233268737793,
1458
+ 0.3085799217224121,
1459
+ 0.32426244020462036,
1460
+ 0.32414373755455017,
1461
+ 0.3856937885284424,
1462
+ 0.21604351699352264,
1463
+ 0.2195548415184021,
1464
+ 0.18589559197425842,
1465
+ 0.21749232709407806,
1466
+ 0.1575351059436798,
1467
+ 0.2156955450773239,
1468
+ 0.29644495248794556,
1469
+ 0.22918514907360077,
1470
+ 0.21251118183135986,
1471
+ 0.37587040662765503,
1472
+ 0.193213552236557,
1473
+ 0.2757120132446289,
1474
+ 0.31559503078460693,
1475
+ 0.2390986531972885,
1476
+ 0.2627783417701721,
1477
+ 0.22325637936592102,
1478
+ 0.35516589879989624,
1479
+ 0.2559050917625427,
1480
+ 0.2075577825307846,
1481
+ 0.3593791425228119,
1482
+ 0.39297720789909363,
1483
+ 0.3490094244480133,
1484
+ 0.22949230670928955,
1485
+ 0.3761798143386841,
1486
+ 0.20215649902820587,
1487
+ 0.21151788532733917,
1488
+ 0.33757245540618896,
1489
+ 0.27517855167388916,
1490
+ 0.3411214053630829
1491
+ ]
1492
+ },
1493
+ "fid": {
1494
+ "avg": 171.8519597862991,
1495
+ "vals": [
1496
+ 144.77019624546853,
1497
+ 154.75288495007914,
1498
+ 160.95209947656429,
1499
+ 117.83869765679046,
1500
+ 151.64615714342628,
1501
+ 157.47977926729806,
1502
+ 162.33033867519435,
1503
+ 265.0183384600493,
1504
+ 281.5918475384978,
1505
+ 152.44583427606105,
1506
+ 133.1601149299724,
1507
+ 194.62961438661011,
1508
+ 244.89585280015393,
1509
+ 131.48030011431425,
1510
+ 167.5752450544133,
1511
+ 201.44646187475317,
1512
+ 107.98531116987058,
1513
+ 170.08420689868527,
1514
+ 236.9186629037794,
1515
+ 141.81304882152676,
1516
+ 138.10872705204866,
1517
+ 174.07596100327163,
1518
+ 133.47905631898664,
1519
+ 200.28084705986407,
1520
+ 100.24977206742142,
1521
+ 146.11378626732036,
1522
+ 180.40072384521355,
1523
+ 134.37882246824165,
1524
+ 266.2742357839778,
1525
+ 204.80154766025774,
1526
+ 283.97677018987224,
1527
+ 144.08380223026322,
1528
+ 176.91537688202212,
1529
+ 127.66897079587724,
1530
+ 130.9521420141355,
1531
+ 166.36337398681758,
1532
+ 214.39949796639,
1533
+ 129.03606564387414
1534
+ ]
1535
+ }
1536
+ }
1537
+ }
1538
+ }
metrics.py ADDED
@@ -0,0 +1,522 @@
1
+ import os
2
+ import pathlib
3
+ import torch
4
+ import numpy as np
5
+ import skimage
6
+ from imageio import imread
7
+ from scipy import linalg
8
+ from torch.nn.functional import adaptive_avg_pool2d
9
+ from skimage.metrics import structural_similarity as compare_ssim
10
+ from skimage.metrics import peak_signal_noise_ratio as compare_psnr
11
+ import glob
12
+ import argparse
13
+ import matplotlib.pyplot as plt
14
+ from inception import InceptionV3
15
+ #from scripts.PerceptualSimilarity.models import dist_model as dm
16
+ import lpips
17
+ import pandas as pd
18
+ import json
19
+ import imageio
20
+ import cv2
21
+ print(skimage.__version__)
22
+
23
+ class FID():
24
+ """docstring for FID
25
+ Calculates the Frechet Inception Distance (FID) to evaluate GANs
26
+ The FID metric calculates the distance between two distributions of images.
27
+ Typically, we have summary statistics (mean & covariance matrix) of one
28
+ of these distributions, while the 2nd distribution is given by a GAN.
29
+ When run as a stand-alone program, it compares the distribution of
30
+ images that are stored as PNG/JPEG at a specified location with a
31
+ distribution given by summary statistics (in pickle format).
32
+ The FID is calculated by assuming that X_1 and X_2 are the activations of
33
+ the pool_3 layer of the inception net for generated samples and real world
34
+ samples respectively.
35
+ See --help to see further details.
36
+ Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
37
+ of Tensorflow
38
+ Copyright 2018 Institute of Bioinformatics, JKU Linz
39
+ Licensed under the Apache License, Version 2.0 (the "License");
40
+ you may not use this file except in compliance with the License.
41
+ You may obtain a copy of the License at
42
+ http://www.apache.org/licenses/LICENSE-2.0
43
+ Unless required by applicable law or agreed to in writing, software
44
+ distributed under the License is distributed on an "AS IS" BASIS,
45
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
46
+ See the License for the specific language governing permissions and
47
+ limitations under the License.
48
+ """
49
+ def __init__(self):
50
+ self.dims = 2048
51
+ self.batch_size = 128
52
+ self.cuda = True
53
+ self.verbose=False
54
+
55
+ block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
56
+ self.model = InceptionV3([block_idx])
57
+ if self.cuda:
58
+ # TODO: put model into specific GPU
59
+ self.model.cuda()
60
+
61
+ def __call__(self, images, gt_path):
62
+ """ images: list of the generated image. The values must lie between 0 and 1.
63
+ gt_path: the path of the ground truth images. The values must lie between 0 and 1.
64
+ """
65
+ if not os.path.exists(gt_path):
66
+ raise RuntimeError('Invalid path: %s' % gt_path)
67
+
68
+
69
+ print('calculate gt_path statistics...')
70
+ m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose)
71
+ print('calculate generated_images statistics...')
72
+ m2, s2 = self.calculate_activation_statistics(images, self.verbose)
73
+ fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
74
+ return fid_value
75
+
76
+
77
+ def calculate_from_disk(self, generated_path, gt_path, img_size):
78
+ """
79
+ """
80
+ if not os.path.exists(gt_path):
81
+ raise RuntimeError('Invalid path: %s' % gt_path)
82
+ if not os.path.exists(generated_path):
83
+ raise RuntimeError('Invalid path: %s' % generated_path)
84
+
85
+ print ('exp-path - '+generated_path)
86
+
87
+ print('calculate gt_path statistics...')
88
+ m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose, img_size)
89
+ print('calculate generated_path statistics...')
90
+ m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose, img_size)
91
+ print('calculate frechet distance...')
92
+ fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
93
+ print('fid_distance %f' % (fid_value))
94
+ return fid_value
95
+
96
+
97
+ def compute_statistics_of_path(self, path , verbose, img_size):
98
+
99
+ size_flag = '{}_{}'.format(img_size[0], img_size[1])
100
+ npz_file = os.path.join(path, size_flag + '_statistics.npz')
101
+ if os.path.exists(npz_file):
102
+ f = np.load(npz_file)
103
+ m, s = f['mu'][:], f['sigma'][:]
104
+ f.close()
105
+
106
+ else:
107
+
108
+ path = pathlib.Path(path)
109
+ files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
110
+
111
+ imgs = (np.array([(cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)) for fn in files]))/255.0
112
+ # Bring images to shape (B, 3, H, W)
113
+ imgs = imgs.transpose((0, 3, 1, 2))
114
+
115
+ # Rescale images to be between 0 and 1
116
+
117
+
118
+ m, s = self.calculate_activation_statistics(imgs, verbose)
119
+ np.savez(npz_file, mu=m, sigma=s)
120
+
121
+ return m, s
122
+
123
+ def calculate_activation_statistics(self, images, verbose):
124
+ """Calculation of the statistics used by the FID.
125
+ Params:
126
+ -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
127
+ must lie between 0 and 1.
128
+ -- model : Instance of inception model
129
+ -- batch_size : The images numpy array is split into batches with
130
+ batch size batch_size. A reasonable batch size
131
+ depends on the hardware.
132
+ -- dims : Dimensionality of features returned by Inception
133
+ -- cuda : If set to True, use GPU
134
+ -- verbose : If set to True and parameter out_step is given, the
135
+ number of calculated batches is reported.
136
+ Returns:
137
+ -- mu : The mean over samples of the activations of the pool_3 layer of
138
+ the inception model.
139
+ -- sigma : The covariance matrix of the activations of the pool_3 layer of
140
+ the inception model.
141
+ """
142
+ act = self.get_activations(images, verbose)
143
+ mu = np.mean(act, axis=0)
144
+ sigma = np.cov(act, rowvar=False)
145
+ return mu, sigma
146
+
147
+
148
+
149
+ def get_activations(self, images, verbose=False):
150
+ """Calculates the activations of the pool_3 layer for all images.
151
+ Params:
152
+ -- images : Numpy array of dimension (n_images, 3, hi, wi). The values
153
+ must lie between 0 and 1.
154
+ -- model : Instance of inception model
155
+ -- batch_size : the images numpy array is split into batches with
156
+ batch size batch_size. A reasonable batch size depends
157
+ on the hardware.
158
+ -- dims : Dimensionality of features returned by Inception
159
+ -- cuda : If set to True, use GPU
160
+ -- verbose : If set to True and parameter out_step is given, the number
161
+ of calculated batches is reported.
162
+ Returns:
163
+ -- A numpy array of dimension (num images, dims) that contains the
164
+ activations of the given tensor when feeding inception with the
165
+ query tensor.
166
+ """
167
+ self.model.eval()
168
+
169
+ d0 = images.shape[0]
170
+ if self.batch_size > d0:
171
+ print(('Warning: batch size is bigger than the data size. '
172
+ 'Setting batch size to data size'))
173
+ self.batch_size = d0
174
+
175
+ n_batches = d0 // self.batch_size
176
+ n_used_imgs = n_batches * self.batch_size
177
+
178
+ pred_arr = np.empty((n_used_imgs, self.dims))
179
+ for i in range(n_batches):
180
+ if verbose:
181
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches))
182
+ # end='', flush=True)
183
+ start = i * self.batch_size
184
+ end = start + self.batch_size
185
+
186
+ batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
187
+ # batch = Variable(batch, volatile=True)
188
+ if self.cuda:
189
+ batch = batch.cuda()
190
+
191
+ pred = self.model(batch)[0]
192
+
193
+ # If model output is not scalar, apply global spatial average pooling.
194
+ # This happens if you choose a dimensionality not equal 2048.
195
+ if pred.shape[2] != 1 or pred.shape[3] != 1:
196
+ pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
197
+
198
+ pred_arr[start:end] = pred.cpu().data.numpy().reshape(self.batch_size, -1)
199
+
200
+ if verbose:
201
+ print(' done')
202
+
203
+ return pred_arr
204
+
205
+
206
+ def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
207
+ """Numpy implementation of the Frechet Distance.
208
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
209
+ and X_2 ~ N(mu_2, C_2) is
210
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
211
+ Stable version by Dougal J. Sutherland.
212
+ Params:
213
+ -- mu1 : Numpy array containing the activations of a layer of the
214
+ inception net (like returned by the function 'get_predictions')
215
+ for generated samples.
216
+ -- mu2 : The sample mean over activations, precalculated on a
217
+ representative data set.
218
+ -- sigma1: The covariance matrix over activations for generated samples.
219
+ -- sigma2: The covariance matrix over activations, precalculated on a
220
+ representative data set.
221
+ Returns:
222
+ -- : The Frechet Distance.
223
+ """
224
+
225
+ mu1 = np.atleast_1d(mu1)
226
+ mu2 = np.atleast_1d(mu2)
227
+
228
+ sigma1 = np.atleast_2d(sigma1)
229
+ sigma2 = np.atleast_2d(sigma2)
230
+
231
+ assert mu1.shape == mu2.shape, \
232
+ 'Training and test mean vectors have different lengths'
233
+ assert sigma1.shape == sigma2.shape, \
234
+ 'Training and test covariances have different dimensions'
235
+
236
+ diff = mu1 - mu2
237
+
238
+ # Product might be almost singular
239
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
240
+ if not np.isfinite(covmean).all():
241
+ msg = ('fid calculation produces singular product; '
242
+ 'adding %s to diagonal of cov estimates') % eps
243
+ print(msg)
244
+ offset = np.eye(sigma1.shape[0]) * eps
245
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
246
+
247
+ # Numerical error might give slight imaginary component
248
+ if np.iscomplexobj(covmean):
249
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
250
+ m = np.max(np.abs(covmean.imag))
251
+ raise ValueError('Imaginary component {}'.format(m))
252
+ covmean = covmean.real
253
+
254
+ tr_covmean = np.trace(covmean)
255
+
256
+ return (diff.dot(diff) + np.trace(sigma1) +
257
+ np.trace(sigma2) - 2 * tr_covmean)
258
+
259
+
260
+ class Reconstruction_Metrics():
261
+ def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
262
+ self.data_range = data_range
263
+ self.win_size = win_size
264
+ self.multichannel = multichannel
265
+ for metric in metric_list:
266
+ if metric in ['ssim', 'psnr', 'l1', 'mae']:
267
+ setattr(self, metric, True)
268
+ else:
269
+ print('unsupported reconstruction metric: %s' % metric)
270
+
271
+
272
+ def __call__(self, inputs, gts):
273
+ """
274
+ inputs: the generated image, size (b,c,w,h), data range(0, data_range)
275
+ gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
276
+ """
277
+ result = dict()
278
+ [b,n,w,h] = inputs.size()
279
+ inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
280
+ gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
281
+
282
+ if hasattr(self, 'ssim'):
283
+ ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
284
+ win_size=self.win_size, multichannel=self.multichannel)
285
+ result['ssim'] = ssim_value
286
+
287
+
288
+ if hasattr(self, 'psnr'):
289
+ psnr_value = compare_psnr(inputs, gts, self.data_range)
290
+ result['psnr'] = psnr_value
291
+
292
+ if hasattr(self, 'l1'):
293
+ l1_value = compare_l1(inputs, gts)
294
+ result['l1'] = l1_value
295
+
296
+ if hasattr(self, 'mae'):
297
+ mae_value = compare_mae(inputs, gts)
298
+ result['mae'] = mae_value
299
+ return result
300
+
301
+
302
+ def calculate_from_disk(self, inputs, gts, save_path=None, img_size=(176,256), sort=True, debug=0):
303
+ """
304
+ inputs: .txt files, folders, image files (string), image files (list)
305
+ gts: .txt files, folders, image files (string), image files (list)
306
+ """
307
+ if sort:
308
+ input_image_list = sorted(get_image_list(inputs))
309
+ gt_image_list = sorted(get_image_list(gts))
310
+ else:
311
+ input_image_list = get_image_list(inputs)
312
+ gt_image_list = get_image_list(gts)
313
+
314
+ size_flag = '{}_{}'.format(img_size[0], img_size[1])
315
+ npz_file = os.path.join(save_path, size_flag + '_metrics.npz')
316
+ if os.path.exists(npz_file):
317
+ f = np.load(npz_file)
318
+ psnr,ssim,ssim_256,mae,l1=f['psnr'],f['ssim'],f['ssim_256'],f['mae'],f['l1']
319
+ else:
320
+ psnr = []
321
+ ssim = []
322
+ ssim_256 = []
323
+ mae = []
324
+ l1 = []
325
+ names = []
326
+
327
+ for index in range(len(input_image_list)):
328
+ name = os.path.basename(input_image_list[index])
329
+ names.append(name)
330
+
331
+
332
+ img_gt = (cv2.resize(imread(str(gt_image_list[index])).astype(np.float32), img_size,interpolation=cv2.INTER_CUBIC)) /255.0
333
+ img_pred = (cv2.resize(imread(str(input_image_list[index])).astype(np.float32), img_size,interpolation=cv2.INTER_CUBIC)) / 255.0
334
+
335
+
336
+ if debug != 0:
337
+ plt.subplot(121)
338
+ plt.imshow(img_gt)
339
+ plt.title('Ground truth')
340
+ plt.subplot(122)
341
+ plt.imshow(img_pred)
342
+ plt.title('Output')
343
+ plt.show()
344
+
345
+ psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
346
+ ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
347
+ win_size=self.win_size,multichannel=self.multichannel, channel_axis=2))
348
+ mae.append(compare_mae(img_gt, img_pred))
349
+ l1.append(compare_l1(img_gt, img_pred))
350
+
351
+ img_gt_256 = img_gt*255.0
352
+ img_pred_256 = img_pred*255.0
353
+ ssim_256.append(compare_ssim(img_gt_256, img_pred_256, gaussian_weights=True, sigma=1.2,
354
+ use_sample_covariance=False, multichannel=True, channel_axis=2,
355
+ data_range=img_pred_256.max() - img_pred_256.min()))
356
+
357
+ if np.mod(index, 200) == 0:
358
+ print(
359
+ str(index) + ' images processed',
360
+ "PSNR: %.4f" % round(np.mean(psnr), 4),
361
+ "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
362
+ "MAE: %.4f" % round(np.mean(mae), 4),
363
+ "l1: %.4f" % round(np.mean(l1), 4),
364
+ )
365
+
366
+ if save_path:
367
+ np.savez(save_path + '/' + size_flag + '_metrics.npz', psnr=psnr, ssim=ssim, ssim_256=ssim_256, mae=mae, l1=l1, names=names)
368
+
369
+ print(
370
+ "PSNR: %.4f" % round(np.mean(psnr), 4),
371
+ "PSNR Variance: %.4f" % round(np.var(psnr), 4),
372
+ "SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
373
+ "SSIM_256 Variance: %.4f" % round(np.var(ssim_256), 4),
374
+ "MAE: %.4f" % round(np.mean(mae), 4),
375
+ "MAE Variance: %.4f" % round(np.var(mae), 4),
376
+ "l1: %.4f" % round(np.mean(l1), 4),
377
+ "l1 Variance: %.4f" % round(np.var(l1), 4)
378
+ )
379
+
380
+ dic = {"psnr":[round(np.mean(psnr), 6)],
381
+ "psnr_variance": [round(np.var(psnr), 6)],
382
+ "ssim_256": [round(np.mean(ssim_256), 6)],
383
+ "ssim_256_variance": [round(np.var(ssim_256), 6)],
384
+ "mae": [round(np.mean(mae), 6)],
385
+ "mae_variance": [round(np.var(mae), 6)],
386
+ "l1": [round(np.mean(l1), 6)],
387
+ "l1_variance": [round(np.var(l1), 6)] }
388
+
389
+ return dic
390
+
391
+
392
+ def get_image_list(flist):
393
+ if isinstance(flist, list):
394
+ return flist
395
+
396
+ # flist: image file path, image directory path, text file flist path
397
+ if isinstance(flist, str):
398
+ if os.path.isdir(flist):
399
+ flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
400
+ flist.sort()
401
+ return flist
402
+
403
+ if os.path.isfile(flist):
404
+ try:
405
+ return np.genfromtxt(flist, dtype=str)
406
+ except:
407
+ return [flist]
408
+ print('can not read files from %s return empty list'%flist)
409
+ return []
410
+
411
+ def compare_l1(img_true, img_test):
412
+ img_true = img_true.astype(np.float32)
413
+ img_test = img_test.astype(np.float32)
414
+ return np.mean(np.abs(img_true - img_test))
415
+
416
+ def compare_mae(img_true, img_test):
417
+ img_true = img_true.astype(np.float32)
418
+ img_test = img_test.astype(np.float32)
419
+ return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
420
+
421
+ def preprocess_path_for_deform_task(gt_path, distorted_path):
422
+ distorted_image_list = sorted(get_image_list(distorted_path))
423
+ gt_list=[]
424
+ distorated_list=[]
425
+
426
+ for distorted_image in distorted_image_list:
427
+ image = os.path.basename(distorted_image)[1:]
428
+ image = image.split('_to_')[-1]
429
+ gt_image = gt_path + '/' + image.replace('jpg', 'png')
430
+ if not os.path.isfile(gt_image):
431
+ print(distorted_image, gt_image)
432
+ print('=====')
433
+ continue
434
+ gt_list.append(gt_image)
435
+ distorated_list.append(distorted_image)
436
+
437
+ return gt_list, distorated_list
438
+
439
+
440
+
441
+ class LPIPS():
442
+ def __init__(self, use_gpu=True):
443
+
444
+ self.model = lpips.LPIPS(net='alex').eval().cuda()
445
+ self.use_gpu=use_gpu
446
+
447
+ def __call__(self, image_1, image_2):
448
+ """
449
+ image_1: images with size (n, 3, w, h) with value [-1, 1]
450
+ image_2: images with size (n, 3, w, h) with value [-1, 1]
451
+ """
452
+ result = self.model.forward(image_1, image_2)
453
+ return result
454
+
455
+ def calculate_from_disk(self, path_1, path_2,img_size, batch_size=64, verbose=False, sort=True):
456
+
457
+ if sort:
458
+ files_1 = sorted(get_image_list(path_1))
459
+ files_2 = sorted(get_image_list(path_2))
460
+ else:
461
+ files_1 = get_image_list(path_1)
462
+ files_2 = get_image_list(path_2)
463
+
464
+
465
+ results=[]
466
+
467
+
468
+ d0 = len(files_1)
469
+ if batch_size > d0:
470
+ print(('Warning: batch size is bigger than the data size. '
471
+ 'Setting batch size to data size'))
472
+ batch_size = d0
473
+
474
+ n_batches = d0 // batch_size
475
+
476
+
477
+ for i in range(n_batches):
478
+ if verbose:
479
+ print('\rPropagating batch %d/%d' % (i + 1, n_batches))
480
+ # end='', flush=True)
481
+ start = i * batch_size
482
+ end = start + batch_size
483
+
484
+ imgs_1 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)/255.0 for fn in files_1[start:end]])
485
+ imgs_2 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)/255.0 for fn in files_2[start:end]])
486
+
487
+ imgs_1 = imgs_1.transpose((0, 3, 1, 2))
488
+ imgs_2 = imgs_2.transpose((0, 3, 1, 2))
489
+
490
+ img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor)
491
+ img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor)
492
+
493
+ if self.use_gpu:
494
+ img_1_batch = img_1_batch.cuda()
495
+ img_2_batch = img_2_batch.cuda()
496
+
497
+ with torch.no_grad():
498
+ result = self.model.forward(img_1_batch, img_2_batch)
499
+
500
+ results.append(result)
501
+
502
+
503
+ distance = torch.cat(results,0)[:,0,0,0].mean()
504
+
505
+ print('lpips: %.3f'%distance)
506
+ return distance
507
+
508
+
509
+
510
+
511
+
512
+
513
+
514
+
515
+
516
+
517
+
518
+
519
+
520
+
521
+
522
+
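For reference, a minimal usage sketch (not part of the repo) showing how the classes above might be driven from a separate script. The folder paths, image size, and batch size are placeholders, and a CUDA device is assumed because FID and LPIPS move their models to the GPU.

```python
# Hypothetical driver for metrics.py; all paths below are placeholders.
from metrics import FID, LPIPS, Reconstruction_Metrics

gt_dir = "./logs/gt_frames"              # ground-truth images (placeholder)
gen_dir = "./logs/view_stage3/512_512"   # generated images (placeholder)
img_size = (512, 512)                    # (width, height) used when resizing

# FID between the two folders; activation statistics are cached as
# *_statistics.npz inside each folder.
fid_value = FID().calculate_from_disk(gen_dir, gt_dir, img_size)

# PSNR / SSIM / L1 / MAE; per-image values are cached as *_metrics.npz in save_path.
rec_metrics = Reconstruction_Metrics().calculate_from_disk(
    gen_dir, gt_dir, save_path=gen_dir, img_size=img_size)

# LPIPS (AlexNet backbone), averaged over the folder in GPU batches.
lpips_value = LPIPS().calculate_from_disk(gen_dir, gt_dir, img_size, batch_size=32)

print("FID:", fid_value)
print("Reconstruction:", rec_metrics)
print("LPIPS:", float(lpips_value))
```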
pose-frames.py ADDED
@@ -0,0 +1,16 @@
1
+ #from annotator.dwpose import DWposeDetector
2
+ from easy_dwpose import DWposeDetector
3
+ from PIL import Image
4
+
5
+
6
+ device = "cpu"
7
+ dwpose = DWposeDetector(device=device)
8
+
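+ # Run DWpose on frames 1-45 of the clip and save each rendered skeleton image.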
9
+ for n in range(1, 46):
10
+ pil_image = Image.open("videos/dance2/frame ("+str(n)+").png").convert("RGB")
11
+ #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
12
+
13
+ out_img, pose = dwpose(pil_image, include_hands=True, include_face=True)
14
+
15
+ print(pose['bodies'])
16
+ out_img.save('videos/dance'+str(n)+'.png')
pose.py ADDED
@@ -0,0 +1,15 @@
1
+ #from annotator.dwpose import DWposeDetector
2
+ from easy_dwpose import DWposeDetector
3
+ from PIL import Image
4
+
5
+
6
+ device = "cpu"
7
+ dwpose = DWposeDetector(device=device)
8
+
9
+ pil_image = Image.open("imgs/baggy.png").convert("RGB")
10
+ #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
11
+
12
+ out_img, _ = dwpose(pil_image, include_hands=True, include_face=False)
13
+
14
+ #print(pose['bodies'])
15
+ out_img.save("pose.png")
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ easy-dwpose
2
+ diffusers
3
+ controlnet-aux
4
+ transformers
5
+ accelerate
6
+ gradio
7
+ rembg[cpu]
8
+ spaces
run_stage1.sh ADDED
@@ -0,0 +1,18 @@
1
+ accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --use_deepspeed --num_processes 8 \
2
+ stage1_train_prior_model.py \
3
+ --pretrained_model_name_or_path="kandinsky-community/kandinsky-2-2-prior" \
4
+ --image_encoder_path='{image_encoder_path}' \
5
+ --img_path='{image_path}' \
6
+ --json_path='{data.json}' \
7
+ --output_dir="{output_dir}" \
8
+ --img_height=512 \
9
+ --img_width=512 \
10
+ --train_batch_size=128 \
11
+ --gradient_accumulation_steps=1 \
12
+ --max_train_steps=100000 \
13
+ --noise_offset=0.1 \
14
+ --learning_rate=1e-05 \
15
+ --weight_decay=0.01 \
16
+ --lr_scheduler="constant" --num_warmup_steps=2000 \
17
+ --checkpointing_steps=5000 \
18
+ --seed 42
run_stage2.sh ADDED
@@ -0,0 +1,18 @@
1
+
2
+ accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --use_deepspeed --mixed_precision="fp16" stage2_train_inpaint_model.py \
3
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
4
+ --image_encoder_p_path='facebook/dinov2-giant' \
5
+ --image_encoder_g_path='{image_encoder_path}' \
6
+ --json_path='{data.json}' \
7
+ --image_root_path="{image_path}" \
8
+ --output_dir="{output_dir}" \
9
+ --img_height=512 \
10
+ --img_width=512 \
11
+ --learning_rate=1e-4 \
12
+ --train_batch_size=8 \
13
+ --max_train_steps=1000000 \
14
+ --mixed_precision="fp16" \
15
+ --checkpointing_steps=5000 \
16
+ --noise_offset=0.1 \
17
+ --lr_warmup_steps 5000 \
18
+ --seed 42
run_stage3.sh ADDED
@@ -0,0 +1,15 @@
1
+ accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --use_deepspeed --mixed_precision="fp16" stage3_train_refined_model.py \
2
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
3
+ --image_encoder_path='facebook/dinov2-giant' \
4
+ --img_path='{image_path}' \
5
+ --json_path='{data.json}' \
6
+ --gen_t_img_path='{stage2_generate}' \
7
+ --output_dir="{output_dir}" \
8
+ --learning_rate=1e-5 \
9
+ --train_batch_size=16 \
10
+ --max_train_steps=1000000 \
11
+ --mixed_precision="fp16" \
12
+ --checkpointing_steps=5000 \
13
+ --noise_offset=0.1 \
14
+ --report_to=tensorboard \
15
+ --lr_warmup_steps 5000
run_test_stage1.sh ADDED
@@ -0,0 +1,9 @@
1
+ python3 stage1_batchtest_prior_model.py \
2
+ --pretrained_model_name_or_path="kandinsky-community/kandinsky-2-2-prior" \
3
+ --image_encoder_path="{image_encoder_path}" \
4
+ --img_path='{image_path}' \
5
+ --json_path='{data.json}' \
6
+ --pose_path="{normalized_pose_txt}" \
7
+ --save_path="./logs/view_stage1/512_512" \
8
+ --weights_name="{save_ckpt}"
9
+
run_test_stage2.sh ADDED
@@ -0,0 +1,13 @@
1
+ python3 stage2_batchtest_inpaint_model.py \
2
+ --img_weigh 512 \
3
+ --img_height 512 \
4
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
5
+ --image_encoder_g_path='{image_encoder_path}' \
6
+ --image_encoder_p_path='facebook/dinov2-giant' \
7
+ --img_path='{image_path}' \
8
+ --json_path='{data.json}' \
9
+ --pose_path="{pose_path}" \
10
+ --target_embed_path="./logs/view_stage1/512_512/" \
11
+ --save_path="./logs/view_stage2/512_512" \
12
+ --weights_name="{save_ckpt}" \
13
+ --calculate_metrics
run_test_stage3.sh ADDED
@@ -0,0 +1,12 @@
1
+ python3 stage3_batchtest_refined_model.py \
2
+ --img_weigh 512 \
3
+ --img_height 512 \
4
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
5
+ --image_encoder_p_path='facebook/dinov2-giant' \
6
+ --img_path='{image_path}' \
7
+ --json_path='{data.json}' \
8
+ --pose_path="{pose_path}" \
9
+ --gen_t_img_path="./logs/view_stage2/512_512/" \
10
+ --save_path="./logs/view_stage3/512_512" \
11
+ --weights_name="{save_ckpt}" \
12
+ --calculate_metrics
sd.py ADDED
@@ -0,0 +1,13 @@
1
+ from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
2
+ import torch
3
+
4
+ model_id = "stabilityai/stable-diffusion-2-1-base"
5
+
6
+ scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
7
+ pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
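+ # Note: float16 weights generally require a CUDA device; consider torch_dtype=torch.float32 when running on CPU.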
8
+ pipe = pipe.to("cpu")
9
+
10
+ prompt = "a photo of an astronaut riding a horse on mars"
11
+ image = pipe(prompt).images[0]
12
+
13
+ image.save("astronaut_rides_horse.png")
setup.txt ADDED
@@ -0,0 +1,41 @@
1
+ pip install -U xformers --index-url https://download.pytorch.org/whl/cu118
2
+ pip install easy-dwpose
3
+ pip install diffusers==0.24.0
4
+ pip install controlnet-aux==0.0.7
5
+ pip install transformers==4.32.1
6
+ pip install accelerate==0.24.1
7
+ pip install huggingface-hub==0.25.2
8
+
9
+ pip install gdown
10
+
11
+ pip install -U openmim
12
+ mim install mmengine
13
+ mim install "mmcv==2.1.0"
14
+ mim install "mmdet==3.3.0"
15
+ mim install "mmpose==1.3.2"
16
+
17
+ apt-get install libgl1
18
+
19
+
20
+ #s1
21
+ gdown https://drive.google.com/uc?id=11a5-a5C8NWA4m6i1g099fQQrdohz5gQ1
22
+
23
+ #s2
24
+ gdown https://drive.google.com/uc?id=1JhWeScr9bQtoQmB503VDyomaDmxBHail
25
+
26
+ #s3
27
+ gdown https://drive.google.com/uc?id=11JZXfYVlgLFqmE8jCbLWjQWO7LrwYq-I
28
+
29
+ #demo
30
+ gdown https://drive.google.com/uc?id=1JFFy_FBxOFuGFBcB6xMIVwcQb8bfnpO9
31
+
32
+
33
+
34
+ pip install numpy==1.26.4
35
+ pip install -U xformers --index-url https://download.pytorch.org/whl/cu118
36
+
37
+ For PyTorch 2.4.0:
38
+ pip install xformers==0.0.28.dev895 or pip install xformers==0.0.28.dev893
39
+
40
+ 48 GB of RAM is needed for finetuning
41
+
single_extract_pose.py ADDED
@@ -0,0 +1,35 @@
1
+ from src.controlnet_aux import DWposeDetector
2
+ from PIL import Image
3
+ import torchvision.transforms as transforms
4
+ import torch
5
+
6
+ def init_dwpose_detector(device):
7
+ # specify configs, ckpts and device, or it will be downloaded automatically and use cpu by default
8
+ det_config = './src/configs/yolox_l_8xb8-300e_coco.py'
9
+ det_ckpt = './ckpts/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
10
+ pose_config = './src/configs/dwpose-l_384x288.py'
11
+ pose_ckpt = './ckpts/dw-ll_ucoco_384.pth'
12
+
13
+ dwpose_model = DWposeDetector(
14
+ det_config=det_config,
15
+ det_ckpt=det_ckpt,
16
+ pose_config=pose_config,
17
+ pose_ckpt=pose_ckpt,
18
+ device=device
19
+ )
20
+ return dwpose_model.to(device)
21
+
22
+
23
+ def inference_pose(img_path, image_size=(1024, 1024)):
24
+ device = torch.device(f"cuda:{0}")
25
+ model = init_dwpose_detector(device=device)
26
+ pil_image = Image.open(img_path).convert("RGB").resize(image_size, Image.BICUBIC)
27
+ dwpose_image = model(pil_image, output_type='np', image_resolution=image_size[1])
28
+ save_dwpose_image = Image.fromarray(dwpose_image)
29
+ return save_dwpose_image
30
+
31
+
32
+
33
+ inference_pose('imgs/test.png').save("pose.png")
34
+
35
+
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (178 Bytes).
src/configs/dwpose-l_384x288.py ADDED
@@ -0,0 +1,257 @@
1
+ # runtime
2
+ max_epochs = 270
3
+ stage2_num_epochs = 30
4
+ base_lr = 4e-3
5
+
6
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
7
+ randomness = dict(seed=21)
8
+
9
+ # optimizer
10
+ optim_wrapper = dict(
11
+ type='OptimWrapper',
12
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
13
+ paramwise_cfg=dict(
14
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
15
+
16
+ # learning rate
17
+ param_scheduler = [
18
+ dict(
19
+ type='LinearLR',
20
+ start_factor=1.0e-5,
21
+ by_epoch=False,
22
+ begin=0,
23
+ end=1000),
24
+ dict(
25
+ # use cosine lr from 150 to 300 epoch
26
+ type='CosineAnnealingLR',
27
+ eta_min=base_lr * 0.05,
28
+ begin=max_epochs // 2,
29
+ end=max_epochs,
30
+ T_max=max_epochs // 2,
31
+ by_epoch=True,
32
+ convert_to_iter_based=True),
33
+ ]
34
+
35
+ # automatically scaling LR based on the actual training batch size
36
+ auto_scale_lr = dict(base_batch_size=512)
37
+
38
+ # codec settings
39
+ codec = dict(
40
+ type='SimCCLabel',
41
+ input_size=(288, 384),
42
+ sigma=(6., 6.93),
43
+ simcc_split_ratio=2.0,
44
+ normalize=False,
45
+ use_dark=False)
46
+
47
+ # model settings
48
+ model = dict(
49
+ type='TopdownPoseEstimator',
50
+ data_preprocessor=dict(
51
+ type='PoseDataPreprocessor',
52
+ mean=[123.675, 116.28, 103.53],
53
+ std=[58.395, 57.12, 57.375],
54
+ bgr_to_rgb=True),
55
+ backbone=dict(
56
+ _scope_='mmdet',
57
+ type='CSPNeXt',
58
+ arch='P5',
59
+ expand_ratio=0.5,
60
+ deepen_factor=1.,
61
+ widen_factor=1.,
62
+ out_indices=(4, ),
63
+ channel_attention=True,
64
+ norm_cfg=dict(type='SyncBN'),
65
+ act_cfg=dict(type='SiLU'),
66
+ init_cfg=dict(
67
+ type='Pretrained',
68
+ prefix='backbone.',
69
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
70
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa
71
+ )),
72
+ head=dict(
73
+ type='RTMCCHead',
74
+ in_channels=1024,
75
+ out_channels=133,
76
+ input_size=codec['input_size'],
77
+ in_featuremap_size=(9, 12),
78
+ simcc_split_ratio=codec['simcc_split_ratio'],
79
+ final_layer_kernel_size=7,
80
+ gau_cfg=dict(
81
+ hidden_dims=256,
82
+ s=128,
83
+ expansion_factor=2,
84
+ dropout_rate=0.,
85
+ drop_path=0.,
86
+ act_fn='SiLU',
87
+ use_rel_bias=False,
88
+ pos_enc=False),
89
+ loss=dict(
90
+ type='KLDiscretLoss',
91
+ use_target_weight=True,
92
+ beta=10.,
93
+ label_softmax=True),
94
+ decoder=codec),
95
+ test_cfg=dict(flip_test=True, ))
96
+
97
+ # base dataset settings
98
+ dataset_type = 'CocoWholeBodyDataset'
99
+ data_mode = 'topdown'
100
+ data_root = '/data/'
101
+
102
+ backend_args = dict(backend='local')
103
+ # backend_args = dict(
104
+ # backend='petrel',
105
+ # path_mapping=dict({
106
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
107
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
108
+ # }))
109
+
110
+ # pipelines
111
+ train_pipeline = [
112
+ dict(type='LoadImage', backend_args=backend_args),
113
+ dict(type='GetBBoxCenterScale'),
114
+ dict(type='RandomFlip', direction='horizontal'),
115
+ dict(type='RandomHalfBody'),
116
+ dict(
117
+ type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
118
+ dict(type='TopdownAffine', input_size=codec['input_size']),
119
+ dict(type='mmdet.YOLOXHSVRandomAug'),
120
+ dict(
121
+ type='Albumentation',
122
+ transforms=[
123
+ dict(type='Blur', p=0.1),
124
+ dict(type='MedianBlur', p=0.1),
125
+ dict(
126
+ type='CoarseDropout',
127
+ max_holes=1,
128
+ max_height=0.4,
129
+ max_width=0.4,
130
+ min_holes=1,
131
+ min_height=0.2,
132
+ min_width=0.2,
133
+ p=1.0),
134
+ ]),
135
+ dict(type='GenerateTarget', encoder=codec),
136
+ dict(type='PackPoseInputs')
137
+ ]
138
+ val_pipeline = [
139
+ dict(type='LoadImage', backend_args=backend_args),
140
+ dict(type='GetBBoxCenterScale'),
141
+ dict(type='TopdownAffine', input_size=codec['input_size']),
142
+ dict(type='PackPoseInputs')
143
+ ]
144
+
145
+ train_pipeline_stage2 = [
146
+ dict(type='LoadImage', backend_args=backend_args),
147
+ dict(type='GetBBoxCenterScale'),
148
+ dict(type='RandomFlip', direction='horizontal'),
149
+ dict(type='RandomHalfBody'),
150
+ dict(
151
+ type='RandomBBoxTransform',
152
+ shift_factor=0.,
153
+ scale_factor=[0.75, 1.25],
154
+ rotate_factor=60),
155
+ dict(type='TopdownAffine', input_size=codec['input_size']),
156
+ dict(type='mmdet.YOLOXHSVRandomAug'),
157
+ dict(
158
+ type='Albumentation',
159
+ transforms=[
160
+ dict(type='Blur', p=0.1),
161
+ dict(type='MedianBlur', p=0.1),
162
+ dict(
163
+ type='CoarseDropout',
164
+ max_holes=1,
165
+ max_height=0.4,
166
+ max_width=0.4,
167
+ min_holes=1,
168
+ min_height=0.2,
169
+ min_width=0.2,
170
+ p=0.5),
171
+ ]),
172
+ dict(type='GenerateTarget', encoder=codec),
173
+ dict(type='PackPoseInputs')
174
+ ]
175
+
176
+ datasets = []
177
+ dataset_coco=dict(
178
+ type=dataset_type,
179
+ data_root=data_root,
180
+ data_mode=data_mode,
181
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
182
+ data_prefix=dict(img='coco/train2017/'),
183
+ pipeline=[],
184
+ )
185
+ datasets.append(dataset_coco)
186
+
187
+ scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
188
+ 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
189
+ 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
190
+
191
+ for i in range(len(scene)):
192
+ datasets.append(
193
+ dict(
194
+ type=dataset_type,
195
+ data_root=data_root,
196
+ data_mode=data_mode,
197
+ ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
198
+ data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
199
+ pipeline=[],
200
+ )
201
+ )
202
+
203
+ # data loaders
204
+ train_dataloader = dict(
205
+ batch_size=32,
206
+ num_workers=10,
207
+ persistent_workers=True,
208
+ sampler=dict(type='DefaultSampler', shuffle=True),
209
+ dataset=dict(
210
+ type='CombinedDataset',
211
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
212
+ datasets=datasets,
213
+ pipeline=train_pipeline,
214
+ test_mode=False,
215
+ ))
216
+ val_dataloader = dict(
217
+ batch_size=32,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type=dataset_type,
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file=f'{data_root}coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
src/configs/stage1_config.py ADDED
@@ -0,0 +1,181 @@
1
+
2
+
3
+ import os
4
+ import argparse
5
+
6
+
7
+
8
+
9
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
10
+ parser.add_argument(
11
+ "--pretrained_model_name_or_path",
12
+ type=str,
13
+ default=None,
14
+ required=True,
15
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
16
+ )
17
+ parser.add_argument(
18
+ "--pretrained_image_model_path",
19
+ type=str,
20
+ default=None,
21
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
22
+ )
23
+ parser.add_argument(
24
+ "--pretrained_pose_model_path",
25
+ type=str,
26
+ default=None,
27
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
28
+ )
29
+
30
+ parser.add_argument(
31
+ "--unet_config_file",
32
+ type=str,
33
+ default=None,
34
+ help="Config file of UNet model",
35
+ )
36
+ parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/train_data.json", help="json path", )
37
+ parser.add_argument("--img_path", type=str, default="./datasets/deepfashing/all_data_png/", help="image path", )
38
+ parser.add_argument("--image_encoder_path", type=str, default="./OpenCLIP-ViT-H-14",
39
+ help="Path to pretrained model or model identifier from huggingface.co/models.", )
40
+ parser.add_argument("--img_width", type=int, default=512, help="width", )
41
+ parser.add_argument("--img_height", type=int, default=512, help="height", )
42
+ parser.add_argument(
43
+ "--output_dir",
44
+ type=str,
45
+ default="sd-model-finetuned",
46
+ help="The output directory where the model predictions and checkpoints will be written.",
47
+ )
48
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
49
+
50
+ parser.add_argument(
51
+ "--center_crop",
52
+ action="store_true",
53
+ help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
54
+ )
55
+ parser.add_argument(
56
+ "--random_flip",
57
+ action="store_true",
58
+ help="whether to randomly flip images horizontally",
59
+ )
60
+ parser.add_argument(
61
+ '--clip_penultimate',
62
+ type=bool,
63
+ default=False,
64
+ help='Use penultimate CLIP layer for text embedding'
65
+ )
66
+ parser.add_argument(
67
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
68
+ )
69
+ parser.add_argument("--num_train_epochs", type=int, default=100000000)
70
+ parser.add_argument(
71
+ "--max_train_steps",
72
+ type=int,
73
+ default=None,
74
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
75
+ )
76
+ parser.add_argument(
77
+ "--gradient_accumulation_steps",
78
+ type=int,
79
+ default=1,
80
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
81
+ )
82
+ parser.add_argument(
83
+ "--gradient_checkpointing",
84
+ action="store_true",
85
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
86
+ )
87
+ parser.add_argument(
88
+ "--learning_rate",
89
+ type=float,
90
+ default=1e-4,
91
+ help="Initial learning rate (after the potential warmup period) to use.",
92
+ )
93
+ parser.add_argument(
94
+ "--weight_decay",
95
+ type=float,
96
+ default=0.01,
97
+ help="Initial learning rate (after the potential warmup period) to use.",
98
+ )
99
+ parser.add_argument(
100
+ "--lr_scheduler",
101
+ type=str,
102
+ default="linear",
103
+ help="The scheduler type to use.",
104
+ choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
105
+ )
106
+ parser.add_argument(
107
+ "--num_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
108
+ )
109
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
110
+ parser.add_argument(
111
+ "--logging_dir",
112
+ type=str,
113
+ default="logs",
114
+ help=(
115
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
116
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
117
+ ),
118
+ )
119
+ parser.add_argument(
120
+ "--print_freq",
121
+ type=int,
122
+ default=1,
123
+ help=(
124
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
125
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
126
+ ),
127
+ )
128
+ parser.add_argument(
129
+ "--report_to",
130
+ type=str,
131
+ default="tensorboard",
132
+ help=(
133
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
134
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
135
+ "Only applicable when `--with_tracking` is passed."
136
+ ),
137
+ )
138
+ parser.add_argument(
139
+ "--checkpointing_steps",
140
+ type=int,
141
+ default=500,
142
+ help="Whether the various states should be saved at the end of every n steps, or 'epoch' for each epoch.",
143
+ )
144
+ parser.add_argument(
145
+ "--resume_from_checkpoint",
146
+ type=str,
147
+ default=None,
148
+ help="If the training should continue from a checkpoint folder.",
149
+ )
150
+ parser.add_argument(
151
+ "--unet_init_ckpt",
152
+ type=str,
153
+ default=None,
154
+ help="If the training should continue from a checkpoint folder.",
155
+ )
156
+
157
+ parser.add_argument(
158
+ "--mixed_precision",
159
+ type=str,
160
+ default="fp16",
161
+ choices=["no", "fp16", "bf16"],
162
+ help=(
163
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
164
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
165
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
166
+ ),
167
+ )
168
+ parser.add_argument(
169
+ "--enable_xformers_memory_efficient_attention",
170
+ action="store_true",
171
+ help="Whether or not to use xformers.",
172
+ )
173
+ parser.add_argument(
174
+ "--max_grad_norm", default=10.0, type=float, help="Max gradient norm."
175
+ )
176
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
177
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
178
+
179
+ args = parser.parse_args()
180
+ print(args)
181
+
src/configs/stage2_config.py ADDED
@@ -0,0 +1,192 @@
1
+ import argparse
2
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
3
+ parser.add_argument(
4
+ "--pretrained_model_name_or_path",
5
+ type=str,
6
+ default="stabilityai/stable-diffusion-2-1-base",
7
+ help="Path to pretrained model or model identifier from huggingface.co/models.",)
8
+
9
+ parser.add_argument(
10
+ "--seed", type=int, default=42, help="A seed for reproducible training."
11
+ )
12
+
13
+ parser.add_argument(
14
+ "--train_batch_size",
15
+ type=int,
16
+ default=8,
17
+ help="Batch size (per device) for the training dataloader.",
18
+ )
19
+ parser.add_argument("--num_train_epochs", type=int, default=10000)
20
+ parser.add_argument(
21
+ "--max_train_steps",
22
+ type=int,
23
+ default=100,
24
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
25
+ )
26
+ parser.add_argument(
27
+ "--checkpointing_steps",
28
+ type=int,
29
+ default=1000,
30
+ help=(
31
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
32
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
33
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
34
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
35
+ "instructions."
36
+ ),
37
+ )
38
+ parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/train_data.json", help="json path", )
39
+ parser.add_argument("--image_root_path", type=str, default="./datasets/deepfashing/all_data_png/", help="image path", )
40
+ parser.add_argument("--image_encoder_g_path", type=str, default="./OpenCLIP-ViT-H-14",
41
+ help="Path to pretrained model or model identifier from huggingface.co/models.", )
42
+ parser.add_argument("--image_encoder_p_path", type=str, default="./dinov2-giant",
43
+ help="Path to pretrained model or model identifier from huggingface.co/models.", )
44
+ parser.add_argument("--output_dir",type=str,default="out/",help="The output directory where the model predictions and checkpoints will be written.",)
45
+ parser.add_argument("--img_width", type=int, default=512, help="device number", )
46
+ parser.add_argument("--img_height", type=int, default=512, help="device number", )
47
+ parser.add_argument(
48
+ "--resume_from_checkpoint",
49
+ type=str,
50
+ default="pcdms_ckpt.pt",
51
+ help=(
52
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
53
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
54
+ ),
55
+ )
56
+ parser.add_argument(
57
+ "--set_grads_to_none",
58
+ action="store_true",
59
+ help=(
60
+ "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
61
+ " behaviors, so disable this argument if it causes any problems. More info:"
62
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
63
+ ),
64
+ )
65
+ parser.add_argument(
66
+ "--gradient_accumulation_steps",
67
+ type=int,
68
+ default=1,
69
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
70
+ )
71
+ parser.add_argument(
72
+ "--gradient_checkpointing",
73
+ action="store_true",
74
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
75
+ )
76
+ parser.add_argument(
77
+ "--learning_rate",
78
+ type=float,
79
+ default=1e-4, #5e-6,
80
+ help="Initial learning rate (after the potential warmup period) to use.",
81
+ )
82
+
83
+ parser.add_argument(
84
+ "--scale_lr",
85
+ action="store_true",
86
+ default=False,
87
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
88
+ )
89
+ parser.add_argument(
90
+ "--lr_scheduler",
91
+ type=str,
92
+ default="constant",
93
+ help=(
94
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
95
+ ' "constant", "constant_with_warmup"]'
96
+ ),
97
+ )
98
+ parser.add_argument(
99
+ "--lr_warmup_steps",
100
+ type=int,
101
+ default=5000,
102
+ help="Number of steps for the warmup in the lr scheduler.",
103
+ )
104
+ parser.add_argument(
105
+ "--lr_num_cycles",
106
+ type=int,
107
+ default=1,
108
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
109
+ )
110
+ parser.add_argument(
111
+ "--lr_power",
112
+ type=float,
113
+ default=1.0,
114
+ help="Power factor of the polynomial scheduler.",
115
+ )
116
+
117
+
118
+ parser.add_argument(
119
+ "--adam_beta1",
120
+ type=float,
121
+ default=0.9,
122
+ help="The beta1 parameter for the Adam optimizer.",
123
+ )
124
+ parser.add_argument(
125
+ "--adam_beta2",
126
+ type=float,
127
+ default=0.999,
128
+ help="The beta2 parameter for the Adam optimizer.",
129
+ )
130
+ parser.add_argument(
131
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
132
+ )
133
+ parser.add_argument(
134
+ "--adam_epsilon",
135
+ type=float,
136
+ default=1e-08,
137
+ help="Epsilon value for the Adam optimizer",
138
+ )
139
+ parser.add_argument(
140
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
141
+ )
142
+ parser.add_argument(
143
+ "--logging_dir",
144
+ type=str,
145
+ default="logs",
146
+ help=(
147
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
148
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
149
+ ),
150
+ )
151
+ parser.add_argument(
152
+ "--report_to",
153
+ type=str,
154
+ default="tensorboard",
155
+ help=(
156
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
157
+ ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
158
+ "Only applicable when `--with_tracking` is passed."
159
+ ),
160
+ )
161
+ parser.add_argument(
162
+ "--allow_tf32",
163
+ action="store_true",
164
+ help=(
165
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
166
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
167
+ ),
168
+ )
169
+ parser.add_argument(
170
+ "--mixed_precision",
171
+ type=str,
172
+ default="fp16",
173
+ choices=["no", "fp16", "bf16"],
174
+ help=(
175
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
176
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
177
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
178
+ ),
179
+ )
180
+
181
+
182
+
183
+ parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
184
+
185
+
186
+
187
+
188
+ args = parser.parse_args()
189
+ print(args)
190
+
191
+
192
+
src/configs/stage3_config.py ADDED
@@ -0,0 +1,217 @@
1
+ import argparse
2
+ parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
3
+ parser.add_argument(
4
+ "--pretrained_model_name_or_path",
5
+ type=str,
6
+ default=None,
7
+ required=True,
8
+ help="Path to pretrained model or model identifier from huggingface.co/models.",)
9
+
10
+ parser.add_argument("--revision",type=str,default=None,required=False,help=(
11
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
12
+ " float32 precision."),)
13
+ parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/test_data.json", help="json path", )
14
+ parser.add_argument("--img_path", type=str, default="./datasets/deepfashing/train_all_png/", help="image path", )
15
+ parser.add_argument("--gen_t_img_path", type=str,default="./save_data/stage2/guidancescale2_seed42_numsteps20/",help="gen target image path", )
16
+ parser.add_argument("--image_encoder_path", type=str, default="./dinov2-giant",
17
+ help="Path to pretrained model or model identifier from huggingface.co/models.", )
18
+ parser.add_argument("--output_dir",type=str,default="controlnet-model",help="The output directory where the model predictions and checkpoints will be written.",)
19
+
20
+
21
+ parser.add_argument(
22
+ "--seed", type=int, default=None, help="A seed for reproducible training."
23
+ )
24
+ parser.add_argument(
25
+ "--resolution",
26
+ type=int,
27
+ default=512,
28
+ help=(
29
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
30
+ " resolution"
31
+ ),
32
+ )
33
+ parser.add_argument(
34
+ "--train_batch_size",
35
+ type=int,
36
+ default=4,
37
+ help="Batch size (per device) for the training dataloader.",
38
+ )
39
+ parser.add_argument("--num_train_epochs", type=int, default=1)
40
+ parser.add_argument("--noise_level", type=int, default=250)
41
+
42
+ parser.add_argument(
43
+ "--max_train_steps",
44
+ type=int,
45
+ default=None,
46
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
47
+ )
48
+ parser.add_argument(
49
+ "--checkpointing_steps",
50
+ type=int,
51
+ default=500,
52
+ help=(
53
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
54
+ "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
55
+ "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
56
+ "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
57
+ "instructions."
58
+ ),
59
+ )
60
+
61
+ parser.add_argument(
62
+ "--resume_from_checkpoint",
63
+ type=str,
64
+ default=None,
65
+ help=(
66
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
67
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
68
+ ),
69
+ )
70
+ parser.add_argument(
71
+ "--gradient_accumulation_steps",
72
+ type=int,
73
+ default=1,
74
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
75
+ )
76
+ parser.add_argument(
77
+ "--gradient_checkpointing",
78
+ action="store_true",
79
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
80
+ )
81
+ parser.add_argument(
82
+ "--learning_rate",
83
+ type=float,
84
+ default=5e-6,
85
+ help="Initial learning rate (after the potential warmup period) to use.",
86
+ )
87
+
88
+ parser.add_argument(
89
+ "--scale_lr",
90
+ action="store_true",
91
+ default=False,
92
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
93
+ )
94
+ parser.add_argument(
95
+ "--lr_scheduler",
96
+ type=str,
97
+ default="constant",
98
+ help=(
99
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
100
+ ' "constant", "constant_with_warmup"]'
101
+ ),
102
+ )
103
+ parser.add_argument(
104
+ "--lr_warmup_steps",
105
+ type=int,
106
+ default=500,
107
+ help="Number of steps for the warmup in the lr scheduler.",
108
+ )
109
+ parser.add_argument(
110
+ "--lr_num_cycles",
111
+ type=int,
112
+ default=1,
113
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
114
+ )
115
+ parser.add_argument(
116
+ "--lr_power",
117
+ type=float,
118
+ default=1.0,
119
+ help="Power factor of the polynomial scheduler.",
120
+ )
121
+
122
+
123
+ parser.add_argument(
124
+ "--adam_beta1",
125
+ type=float,
126
+ default=0.9,
127
+ help="The beta1 parameter for the Adam optimizer.",
128
+ )
129
+ parser.add_argument(
130
+ "--adam_beta2",
131
+ type=float,
132
+ default=0.999,
133
+ help="The beta2 parameter for the Adam optimizer.",
134
+ )
135
+ parser.add_argument(
136
+ "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
137
+ )
138
+ parser.add_argument(
139
+ "--adam_epsilon",
140
+ type=float,
141
+ default=1e-08,
142
+ help="Epsilon value for the Adam optimizer",
143
+ )
144
+ parser.add_argument(
145
+ "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
146
+ )
147
+
148
+
149
+ parser.add_argument(
150
+ "--logging_dir",
151
+ type=str,
152
+ default="logs",
153
+ help=(
154
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
155
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
156
+ ),
157
+ )
158
+ parser.add_argument(
159
+ "--allow_tf32",
160
+ action="store_true",
161
+ help=(
162
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
163
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
164
+ ),
165
+ )
166
+ parser.add_argument(
167
+ "--report_to",
168
+ type=str,
169
+ default="tensorboard",
170
+ help=(
171
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
172
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
173
+ ),
174
+ )
175
+ parser.add_argument(
176
+ "--mixed_precision",
177
+ type=str,
178
+ default=None,
179
+ choices=["no", "fp16", "bf16"],
180
+ help=(
181
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
182
+             " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
183
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
184
+ ),
185
+ )
186
+
187
+ parser.add_argument(
188
+ "--set_grads_to_none",
189
+ action="store_true",
190
+ help=(
191
+             "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
192
+ " behaviors, so disable this argument if it causes any problems. More info:"
193
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
194
+ ),
195
+ )
196
+
197
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
198
+
199
+ parser.add_argument(
200
+ "--tracker_project_name",
201
+ type=str,
202
+ default="train_baseline",
203
+ help=(
204
+             "The `project_name` argument passed to Accelerator.init_trackers. For"
205
+             " more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
206
+ ),
207
+ )
208
+
209
+
210
+ args = parser.parse_args()
211
+ print(args)
212
+ if args.resolution % 8 != 0:
213
+ raise ValueError(
214
+ "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
215
+ )
216
+
217
+
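For context, the `--resolution % 8` check above exists because the Stable Diffusion VAE downsamples images by a factor of 8 before the latents reach the ControlNet encoder. A minimal sketch of the latent sizes this guarantees (the helper name is illustrative, not part of this repo):
# Sketch only: why --resolution must be divisible by 8.
def latent_grid(resolution: int) -> tuple[int, int]:
    # The SD VAE maps an H x W image to an (H // 8) x (W // 8) latent grid,
    # so a non-multiple of 8 would give mismatched shapes downstream.
    if resolution % 8 != 0:
        raise ValueError("`--resolution` must be divisible by 8")
    return resolution // 8, resolution // 8

print(latent_grid(512))  # -> (64, 64)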
src/configs/yolox_l_8xb8-300e_coco.py ADDED
@@ -0,0 +1,245 @@
1
+ img_scale = (640, 640) # width, height
2
+
3
+ # model settings
4
+ model = dict(
5
+ type='YOLOX',
6
+ data_preprocessor=dict(
7
+ type='DetDataPreprocessor',
8
+ pad_size_divisor=32,
9
+ batch_augments=[
10
+ dict(
11
+ type='BatchSyncRandomResize',
12
+ random_size_range=(480, 800),
13
+ size_divisor=32,
14
+ interval=10)
15
+ ]),
16
+ backbone=dict(
17
+ type='CSPDarknet',
18
+ deepen_factor=1.0,
19
+ widen_factor=1.0,
20
+ out_indices=(2, 3, 4),
21
+ use_depthwise=False,
22
+ spp_kernal_sizes=(5, 9, 13),
23
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
24
+ act_cfg=dict(type='Swish'),
25
+ ),
26
+ neck=dict(
27
+ type='YOLOXPAFPN',
28
+ in_channels=[256, 512, 1024],
29
+ out_channels=256,
30
+ num_csp_blocks=3,
31
+ use_depthwise=False,
32
+ upsample_cfg=dict(scale_factor=2, mode='nearest'),
33
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
34
+ act_cfg=dict(type='Swish')),
35
+ bbox_head=dict(
36
+ type='YOLOXHead',
37
+ num_classes=80,
38
+ in_channels=256,
39
+ feat_channels=256,
40
+ stacked_convs=2,
41
+ strides=(8, 16, 32),
42
+ use_depthwise=False,
43
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
44
+ act_cfg=dict(type='Swish'),
45
+ loss_cls=dict(
46
+ type='CrossEntropyLoss',
47
+ use_sigmoid=True,
48
+ reduction='sum',
49
+ loss_weight=1.0),
50
+ loss_bbox=dict(
51
+ type='IoULoss',
52
+ mode='square',
53
+ eps=1e-16,
54
+ reduction='sum',
55
+ loss_weight=5.0),
56
+ loss_obj=dict(
57
+ type='CrossEntropyLoss',
58
+ use_sigmoid=True,
59
+ reduction='sum',
60
+ loss_weight=1.0),
61
+ loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
62
+ train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
63
+ # In order to align the source code, the threshold of the val phase is
64
+ # 0.01, and the threshold of the test phase is 0.001.
65
+ test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
66
+
67
+ # dataset settings
68
+ data_root = 'data/coco/'
69
+ dataset_type = 'CocoDataset'
70
+
71
+ # Example to use different file client
72
+ # Method 1: simply set the data root and let the file I/O module
73
+ # automatically infer from prefix (not support LMDB and Memcache yet)
74
+
75
+ # data_root = 's3://openmmlab/datasets/detection/coco/'
76
+
77
+ # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
78
+ # backend_args = dict(
79
+ # backend='petrel',
80
+ # path_mapping=dict({
81
+ # './data/': 's3://openmmlab/datasets/detection/',
82
+ # 'data/': 's3://openmmlab/datasets/detection/'
83
+ # }))
84
+ backend_args = None
85
+
86
+ train_pipeline = [
87
+ dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
88
+ dict(
89
+ type='RandomAffine',
90
+ scaling_ratio_range=(0.1, 2),
91
+ # img_scale is (width, height)
92
+ border=(-img_scale[0] // 2, -img_scale[1] // 2)),
93
+ dict(
94
+ type='MixUp',
95
+ img_scale=img_scale,
96
+ ratio_range=(0.8, 1.6),
97
+ pad_val=114.0),
98
+ dict(type='YOLOXHSVRandomAug'),
99
+ dict(type='RandomFlip', prob=0.5),
100
+ # According to the official implementation, multi-scale
101
+ # training is not considered here but in the
102
+ # 'mmdet/models/detectors/yolox.py'.
103
+ # Resize and Pad are for the last 15 epochs when Mosaic,
104
+ # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
105
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
106
+ dict(
107
+ type='Pad',
108
+ pad_to_square=True,
109
+ # If the image is three-channel, the pad value needs
110
+ # to be set separately for each channel.
111
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
112
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
113
+ dict(type='PackDetInputs')
114
+ ]
115
+
116
+ train_dataset = dict(
117
+ # use MultiImageMixDataset wrapper to support mosaic and mixup
118
+ type='MultiImageMixDataset',
119
+ dataset=dict(
120
+ type=dataset_type,
121
+ data_root=data_root,
122
+ ann_file='annotations/instances_train2017.json',
123
+ data_prefix=dict(img='train2017/'),
124
+ pipeline=[
125
+ dict(type='LoadImageFromFile', backend_args=backend_args),
126
+ dict(type='LoadAnnotations', with_bbox=True)
127
+ ],
128
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
129
+ backend_args=backend_args),
130
+ pipeline=train_pipeline)
131
+
132
+ test_pipeline = [
133
+ dict(type='LoadImageFromFile', backend_args=backend_args),
134
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
135
+ dict(
136
+ type='Pad',
137
+ pad_to_square=True,
138
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
139
+ dict(type='LoadAnnotations', with_bbox=True),
140
+ dict(
141
+ type='PackDetInputs',
142
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
143
+ 'scale_factor'))
144
+ ]
145
+
146
+ train_dataloader = dict(
147
+ batch_size=8,
148
+ num_workers=4,
149
+ persistent_workers=True,
150
+ sampler=dict(type='DefaultSampler', shuffle=True),
151
+ dataset=train_dataset)
152
+ val_dataloader = dict(
153
+ batch_size=8,
154
+ num_workers=4,
155
+ persistent_workers=True,
156
+ drop_last=False,
157
+ sampler=dict(type='DefaultSampler', shuffle=False),
158
+ dataset=dict(
159
+ type=dataset_type,
160
+ data_root=data_root,
161
+ ann_file='annotations/instances_val2017.json',
162
+ data_prefix=dict(img='val2017/'),
163
+ test_mode=True,
164
+ pipeline=test_pipeline,
165
+ backend_args=backend_args))
166
+ test_dataloader = val_dataloader
167
+
168
+ val_evaluator = dict(
169
+ type='CocoMetric',
170
+ ann_file=data_root + 'annotations/instances_val2017.json',
171
+ metric='bbox',
172
+ backend_args=backend_args)
173
+ test_evaluator = val_evaluator
174
+
175
+ # training settings
176
+ max_epochs = 300
177
+ num_last_epochs = 15
178
+ interval = 10
179
+
180
+ train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
181
+
182
+ # optimizer
183
+ # default 8 gpu
184
+ base_lr = 0.01
185
+ optim_wrapper = dict(
186
+ type='OptimWrapper',
187
+ optimizer=dict(
188
+ type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
189
+ nesterov=True),
190
+ paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
191
+
192
+ # learning rate
193
+ param_scheduler = [
194
+ dict(
195
+ # use quadratic formula to warm up 5 epochs
196
+ # and lr is updated by iteration
197
+ # TODO: fix default scope in get function
198
+ type='mmdet.QuadraticWarmupLR',
199
+ by_epoch=True,
200
+ begin=0,
201
+ end=5,
202
+ convert_to_iter_based=True),
203
+ dict(
204
+ # use cosine lr from 5 to 285 epoch
205
+ type='CosineAnnealingLR',
206
+ eta_min=base_lr * 0.05,
207
+ begin=5,
208
+ T_max=max_epochs - num_last_epochs,
209
+ end=max_epochs - num_last_epochs,
210
+ by_epoch=True,
211
+ convert_to_iter_based=True),
212
+ dict(
213
+ # use fixed lr during last 15 epochs
214
+ type='ConstantLR',
215
+ by_epoch=True,
216
+ factor=1,
217
+ begin=max_epochs - num_last_epochs,
218
+ end=max_epochs,
219
+ )
220
+ ]
221
+
222
+ default_hooks = dict(
223
+ checkpoint=dict(
224
+ interval=interval,
225
+ max_keep_ckpts=3 # only keep latest 3 checkpoints
226
+ ))
227
+
228
+ custom_hooks = [
229
+ dict(
230
+ type='YOLOXModeSwitchHook',
231
+ num_last_epochs=num_last_epochs,
232
+ priority=48),
233
+ dict(type='SyncNormHook', priority=48),
234
+ dict(
235
+ type='EMAHook',
236
+ ema_type='ExpMomentumEMA',
237
+ momentum=0.0001,
238
+ update_buffers=True,
239
+ priority=49)
240
+ ]
241
+
242
+ # NOTE: `auto_scale_lr` is for automatically scaling LR,
243
+ # USER SHOULD NOT CHANGE ITS VALUES.
244
+ # base_batch_size = (8 GPUs) x (8 samples per GPU)
245
+ auto_scale_lr = dict(base_batch_size=64)
src/controlnet_aux/__init__.py ADDED
@@ -0,0 +1,18 @@
1
+ __version__ = "0.0.6"
2
+
3
+ from .hed import HEDdetector
4
+ from .leres import LeresDetector
5
+ from .lineart import LineartDetector
6
+ from .lineart_anime import LineartAnimeDetector
7
+ from .midas import MidasDetector
8
+ from .mlsd import MLSDdetector
9
+ from .normalbae import NormalBaeDetector
10
+ from .open_pose import OpenposeDetector
11
+ from .pidi import PidiNetDetector
12
+ from .zoe import ZoeDetector
13
+
14
+ from .canny import CannyDetector
15
+ from .mediapipe_face import MediapipeFaceDetector
16
+ from .segment_anything import SamDetector
17
+ from .shuffle import ContentShuffleDetector
18
+ from .dwpose import DWposeDetector
src/controlnet_aux/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.17 kB). View file
 
src/controlnet_aux/__pycache__/util.cpython-311.pyc ADDED
Binary file (13 kB). View file
 
src/controlnet_aux/canny/__init__.py ADDED
@@ -0,0 +1,36 @@
1
+ import warnings
2
+ import cv2
3
+ import numpy as np
4
+ from PIL import Image
5
+ from ..util import HWC3, resize_image
6
+
7
+ class CannyDetector:
8
+ def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
9
+ if "img" in kwargs:
10
+ warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
11
+ input_image = kwargs.pop("img")
12
+
13
+ if input_image is None:
14
+ raise ValueError("input_image must be defined.")
15
+
16
+ if not isinstance(input_image, np.ndarray):
17
+ input_image = np.array(input_image, dtype=np.uint8)
18
+ output_type = output_type or "pil"
19
+ else:
20
+ output_type = output_type or "np"
21
+
22
+ input_image = HWC3(input_image)
23
+ input_image = resize_image(input_image, detect_resolution)
24
+
25
+ detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
26
+ detected_map = HWC3(detected_map)
27
+
28
+ img = resize_image(input_image, image_resolution)
29
+ H, W, C = img.shape
30
+
31
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
32
+
33
+ if output_type == "pil":
34
+ detected_map = Image.fromarray(detected_map)
35
+
36
+ return detected_map
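A minimal usage sketch for the detector above (the image path is illustrative); with a PIL input it returns a PIL edge map by default:
# Sketch only: run the CannyDetector defined above on a single image.
from PIL import Image
from src.controlnet_aux import CannyDetector  # assumes the repo root is on sys.path

canny = CannyDetector()
edges = canny(Image.open("example.png"), low_threshold=100, high_threshold=200)
edges.save("example_canny.png")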
src/controlnet_aux/canny/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (2.11 kB). View file
 
src/controlnet_aux/dwpose/__init__.py ADDED
@@ -0,0 +1,92 @@
1
+ # Openpose
2
+ # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3
+ # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4
+ # 3rd Edited by ControlNet
5
+ # 4th Edited by ControlNet (added face and correct hands)
6
+
7
+ import os
8
+ os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9
+
10
+ import cv2
11
+ import torch
12
+ import numpy as np
13
+ from PIL import Image
14
+
15
+ from ..util import HWC3, resize_image
16
+ from . import util
17
+
18
+
19
+ def draw_pose(pose, H, W):
20
+ bodies = pose['bodies']
21
+ faces = pose['faces']
22
+ hands = pose['hands']
23
+ candidate = bodies['candidate']
24
+ subset = bodies['subset']
25
+
26
+ canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
27
+ canvas = util.draw_bodypose(canvas, candidate, subset)
28
+ canvas = util.draw_handpose(canvas, hands)
29
+ # canvas = util.draw_facepose(canvas, faces)
30
+
31
+ return canvas
32
+
33
+ class DWposeDetector:
34
+ def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
35
+ from .wholebody import Wholebody
36
+
37
+ self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)
38
+
39
+ def to(self, device):
40
+ self.pose_estimation.to(device)
41
+ return self
42
+
43
+ def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):
44
+
45
+ input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)
46
+
47
+ input_image = HWC3(input_image)
48
+ input_image = resize_image(input_image, detect_resolution)
49
+
50
+ H, W, C = input_image.shape
51
+
52
+ with torch.no_grad():
53
+ candidate, subset = self.pose_estimation(input_image)
54
+ nums, keys, locs = candidate.shape
55
+ candidate[..., 0] /= float(W)
56
+ candidate[..., 1] /= float(H)
57
+ body = candidate[:,:18].copy()
58
+ body = body.reshape(nums*18, locs)
59
+ score = subset[:,:18]
60
+
61
+ for i in range(len(score)):
62
+ for j in range(len(score[i])):
63
+ if score[i][j] > 0.3:
64
+ score[i][j] = int(18*i+j)
65
+ else:
66
+ score[i][j] = -1
67
+
68
+ un_visible = subset<0.3
69
+ candidate[un_visible] = -1
70
+
71
+ foot = candidate[:,18:24]
72
+
73
+ faces = candidate[:,24:92]
74
+
75
+ hands = candidate[:,92:113]
76
+ hands = np.vstack([hands, candidate[:,113:]])
77
+
78
+ bodies = dict(candidate=body, subset=score)
79
+ pose = dict(bodies=bodies, hands=hands, faces=faces)
80
+
81
+ detected_map = draw_pose(pose, H, W)
82
+ detected_map = HWC3(detected_map)
83
+
84
+ img = resize_image(input_image, image_resolution)
85
+ H, W, C = img.shape
86
+
87
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
88
+
89
+ if output_type == "pil":
90
+ detected_map = Image.fromarray(detected_map)
91
+
92
+ return detected_map
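A minimal usage sketch for DWposeDetector (device and file names are illustrative; mmcv/mmdet/mmpose must be installed, and the default configs/checkpoints are resolved inside Wholebody):
# Sketch only: extract a pose map from one frame with the detector above.
from PIL import Image
from src.controlnet_aux import DWposeDetector  # assumes the repo root is on sys.path

detector = DWposeDetector(device="cuda")  # falls back to the default det/pose configs and checkpoints
pose_map = detector(Image.open("frame.png"), output_type="pil")
pose_map.save("frame_pose.png")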
src/controlnet_aux/dwpose/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (4.84 kB). View file
 
src/controlnet_aux/dwpose/__pycache__/util.cpython-311.pyc ADDED
Binary file (16.5 kB). View file
 
src/controlnet_aux/dwpose/__pycache__/wholebody.cpython-311.pyc ADDED
Binary file (6.2 kB). View file
 
src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py ADDED
@@ -0,0 +1,257 @@
1
+ # runtime
2
+ max_epochs = 270
3
+ stage2_num_epochs = 30
4
+ base_lr = 4e-3
5
+
6
+ train_cfg = dict(max_epochs=max_epochs, val_interval=10)
7
+ randomness = dict(seed=21)
8
+
9
+ # optimizer
10
+ optim_wrapper = dict(
11
+ type='OptimWrapper',
12
+ optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
13
+ paramwise_cfg=dict(
14
+ norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
15
+
16
+ # learning rate
17
+ param_scheduler = [
18
+ dict(
19
+ type='LinearLR',
20
+ start_factor=1.0e-5,
21
+ by_epoch=False,
22
+ begin=0,
23
+ end=1000),
24
+ dict(
25
+ # use cosine lr from 150 to 300 epoch
26
+ type='CosineAnnealingLR',
27
+ eta_min=base_lr * 0.05,
28
+ begin=max_epochs // 2,
29
+ end=max_epochs,
30
+ T_max=max_epochs // 2,
31
+ by_epoch=True,
32
+ convert_to_iter_based=True),
33
+ ]
34
+
35
+ # automatically scaling LR based on the actual training batch size
36
+ auto_scale_lr = dict(base_batch_size=512)
37
+
38
+ # codec settings
39
+ codec = dict(
40
+ type='SimCCLabel',
41
+ input_size=(288, 384),
42
+ sigma=(6., 6.93),
43
+ simcc_split_ratio=2.0,
44
+ normalize=False,
45
+ use_dark=False)
46
+
47
+ # model settings
48
+ model = dict(
49
+ type='TopdownPoseEstimator',
50
+ data_preprocessor=dict(
51
+ type='PoseDataPreprocessor',
52
+ mean=[123.675, 116.28, 103.53],
53
+ std=[58.395, 57.12, 57.375],
54
+ bgr_to_rgb=True),
55
+ backbone=dict(
56
+ _scope_='mmdet',
57
+ type='CSPNeXt',
58
+ arch='P5',
59
+ expand_ratio=0.5,
60
+ deepen_factor=1.,
61
+ widen_factor=1.,
62
+ out_indices=(4, ),
63
+ channel_attention=True,
64
+ norm_cfg=dict(type='SyncBN'),
65
+ act_cfg=dict(type='SiLU'),
66
+ init_cfg=dict(
67
+ type='Pretrained',
68
+ prefix='backbone.',
69
+ checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
70
+ 'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa
71
+ )),
72
+ head=dict(
73
+ type='RTMCCHead',
74
+ in_channels=1024,
75
+ out_channels=133,
76
+ input_size=codec['input_size'],
77
+ in_featuremap_size=(9, 12),
78
+ simcc_split_ratio=codec['simcc_split_ratio'],
79
+ final_layer_kernel_size=7,
80
+ gau_cfg=dict(
81
+ hidden_dims=256,
82
+ s=128,
83
+ expansion_factor=2,
84
+ dropout_rate=0.,
85
+ drop_path=0.,
86
+ act_fn='SiLU',
87
+ use_rel_bias=False,
88
+ pos_enc=False),
89
+ loss=dict(
90
+ type='KLDiscretLoss',
91
+ use_target_weight=True,
92
+ beta=10.,
93
+ label_softmax=True),
94
+ decoder=codec),
95
+ test_cfg=dict(flip_test=True, ))
96
+
97
+ # base dataset settings
98
+ dataset_type = 'CocoWholeBodyDataset'
99
+ data_mode = 'topdown'
100
+ data_root = '/data/'
101
+
102
+ backend_args = dict(backend='local')
103
+ # backend_args = dict(
104
+ # backend='petrel',
105
+ # path_mapping=dict({
106
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
107
+ # f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
108
+ # }))
109
+
110
+ # pipelines
111
+ train_pipeline = [
112
+ dict(type='LoadImage', backend_args=backend_args),
113
+ dict(type='GetBBoxCenterScale'),
114
+ dict(type='RandomFlip', direction='horizontal'),
115
+ dict(type='RandomHalfBody'),
116
+ dict(
117
+ type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
118
+ dict(type='TopdownAffine', input_size=codec['input_size']),
119
+ dict(type='mmdet.YOLOXHSVRandomAug'),
120
+ dict(
121
+ type='Albumentation',
122
+ transforms=[
123
+ dict(type='Blur', p=0.1),
124
+ dict(type='MedianBlur', p=0.1),
125
+ dict(
126
+ type='CoarseDropout',
127
+ max_holes=1,
128
+ max_height=0.4,
129
+ max_width=0.4,
130
+ min_holes=1,
131
+ min_height=0.2,
132
+ min_width=0.2,
133
+ p=1.0),
134
+ ]),
135
+ dict(type='GenerateTarget', encoder=codec),
136
+ dict(type='PackPoseInputs')
137
+ ]
138
+ val_pipeline = [
139
+ dict(type='LoadImage', backend_args=backend_args),
140
+ dict(type='GetBBoxCenterScale'),
141
+ dict(type='TopdownAffine', input_size=codec['input_size']),
142
+ dict(type='PackPoseInputs')
143
+ ]
144
+
145
+ train_pipeline_stage2 = [
146
+ dict(type='LoadImage', backend_args=backend_args),
147
+ dict(type='GetBBoxCenterScale'),
148
+ dict(type='RandomFlip', direction='horizontal'),
149
+ dict(type='RandomHalfBody'),
150
+ dict(
151
+ type='RandomBBoxTransform',
152
+ shift_factor=0.,
153
+ scale_factor=[0.75, 1.25],
154
+ rotate_factor=60),
155
+ dict(type='TopdownAffine', input_size=codec['input_size']),
156
+ dict(type='mmdet.YOLOXHSVRandomAug'),
157
+ dict(
158
+ type='Albumentation',
159
+ transforms=[
160
+ dict(type='Blur', p=0.1),
161
+ dict(type='MedianBlur', p=0.1),
162
+ dict(
163
+ type='CoarseDropout',
164
+ max_holes=1,
165
+ max_height=0.4,
166
+ max_width=0.4,
167
+ min_holes=1,
168
+ min_height=0.2,
169
+ min_width=0.2,
170
+ p=0.5),
171
+ ]),
172
+ dict(type='GenerateTarget', encoder=codec),
173
+ dict(type='PackPoseInputs')
174
+ ]
175
+
176
+ datasets = []
177
+ dataset_coco=dict(
178
+ type=dataset_type,
179
+ data_root=data_root,
180
+ data_mode=data_mode,
181
+ ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
182
+ data_prefix=dict(img='coco/train2017/'),
183
+ pipeline=[],
184
+ )
185
+ datasets.append(dataset_coco)
186
+
187
+ scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
188
+ 'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
189
+ 'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
190
+
191
+ for i in range(len(scene)):
192
+ datasets.append(
193
+ dict(
194
+ type=dataset_type,
195
+ data_root=data_root,
196
+ data_mode=data_mode,
197
+ ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
198
+ data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
199
+ pipeline=[],
200
+ )
201
+ )
202
+
203
+ # data loaders
204
+ train_dataloader = dict(
205
+ batch_size=32,
206
+ num_workers=10,
207
+ persistent_workers=True,
208
+ sampler=dict(type='DefaultSampler', shuffle=True),
209
+ dataset=dict(
210
+ type='CombinedDataset',
211
+ metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
212
+ datasets=datasets,
213
+ pipeline=train_pipeline,
214
+ test_mode=False,
215
+ ))
216
+ val_dataloader = dict(
217
+ batch_size=32,
218
+ num_workers=10,
219
+ persistent_workers=True,
220
+ drop_last=False,
221
+ sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
222
+ dataset=dict(
223
+ type=dataset_type,
224
+ data_root=data_root,
225
+ data_mode=data_mode,
226
+ ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
227
+ bbox_file=f'{data_root}coco/person_detection_results/'
228
+ 'COCO_val2017_detections_AP_H_56_person.json',
229
+ data_prefix=dict(img='coco/val2017/'),
230
+ test_mode=True,
231
+ pipeline=val_pipeline,
232
+ ))
233
+ test_dataloader = val_dataloader
234
+
235
+ # hooks
236
+ default_hooks = dict(
237
+ checkpoint=dict(
238
+ save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
239
+
240
+ custom_hooks = [
241
+ dict(
242
+ type='EMAHook',
243
+ ema_type='ExpMomentumEMA',
244
+ momentum=0.0002,
245
+ update_buffers=True,
246
+ priority=49),
247
+ dict(
248
+ type='mmdet.PipelineSwitchHook',
249
+ switch_epoch=max_epochs - stage2_num_epochs,
250
+ switch_pipeline=train_pipeline_stage2)
251
+ ]
252
+
253
+ # evaluators
254
+ val_evaluator = dict(
255
+ type='CocoWholeBodyMetric',
256
+ ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
257
+ test_evaluator = val_evaluator
src/controlnet_aux/dwpose/util.py ADDED
@@ -0,0 +1,303 @@
1
+ import math
2
+ import numpy as np
3
+ import cv2
4
+
5
+ eps = 0.01
6
+
7
+
8
+ def smart_resize(x, s):
9
+ Ht, Wt = s
10
+ if x.ndim == 2:
11
+ Ho, Wo = x.shape
12
+ Co = 1
13
+ else:
14
+ Ho, Wo, Co = x.shape
15
+ if Co == 3 or Co == 1:
16
+ k = float(Ht + Wt) / float(Ho + Wo)
17
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
18
+ else:
19
+ return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
20
+
21
+
22
+ def smart_resize_k(x, fx, fy):
23
+ if x.ndim == 2:
24
+ Ho, Wo = x.shape
25
+ Co = 1
26
+ else:
27
+ Ho, Wo, Co = x.shape
28
+ Ht, Wt = Ho * fy, Wo * fx
29
+ if Co == 3 or Co == 1:
30
+ k = float(Ht + Wt) / float(Ho + Wo)
31
+ return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
32
+ else:
33
+ return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
34
+
35
+
36
+ def padRightDownCorner(img, stride, padValue):
37
+ h = img.shape[0]
38
+ w = img.shape[1]
39
+
40
+ pad = 4 * [None]
41
+ pad[0] = 0 # up
42
+ pad[1] = 0 # left
43
+ pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
44
+ pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
45
+
46
+ img_padded = img
47
+ pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
48
+ img_padded = np.concatenate((pad_up, img_padded), axis=0)
49
+ pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
50
+ img_padded = np.concatenate((pad_left, img_padded), axis=1)
51
+ pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
52
+ img_padded = np.concatenate((img_padded, pad_down), axis=0)
53
+ pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
54
+ img_padded = np.concatenate((img_padded, pad_right), axis=1)
55
+
56
+ return img_padded, pad
57
+
58
+
59
+ def transfer(model, model_weights):
60
+     transferred_model_weights = {}
61
+ for weights_name in model.state_dict().keys():
62
+         transferred_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
63
+     return transferred_model_weights
64
+
65
+
66
+ def draw_bodypose(canvas, candidate, subset):
67
+ H, W, C = canvas.shape
68
+ candidate = np.array(candidate)
69
+ subset = np.array(subset)
70
+
71
+ stickwidth = 4
72
+
73
+ limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
74
+ [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
75
+ [1, 16], [16, 18], [3, 17], [6, 18]]
76
+
77
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
78
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
79
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
80
+
81
+ for i in range(17):
82
+ for n in range(len(subset)):
83
+ index = subset[n][np.array(limbSeq[i]) - 1]
84
+ if -1 in index:
85
+ continue
86
+ Y = candidate[index.astype(int), 0] * float(W)
87
+ X = candidate[index.astype(int), 1] * float(H)
88
+ mX = np.mean(X)
89
+ mY = np.mean(Y)
90
+ length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
91
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
92
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
93
+ cv2.fillConvexPoly(canvas, polygon, colors[i])
94
+
95
+ canvas = (canvas * 0.6).astype(np.uint8)
96
+
97
+ for i in range(18):
98
+ for n in range(len(subset)):
99
+ index = int(subset[n][i])
100
+ if index == -1:
101
+ continue
102
+ x, y = candidate[index][0:2]
103
+ x = int(x * W)
104
+ y = int(y * H)
105
+ cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
106
+
107
+ return canvas
108
+
109
+
110
+ def draw_handpose(canvas, all_hand_peaks):
111
+ import matplotlib
112
+
113
+ H, W, C = canvas.shape
114
+
115
+ edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
116
+ [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
117
+
118
+ # (person_number*2, 21, 2)
119
+ for i in range(len(all_hand_peaks)):
120
+ peaks = all_hand_peaks[i]
121
+ peaks = np.array(peaks)
122
+
123
+ for ie, e in enumerate(edges):
124
+
125
+ x1, y1 = peaks[e[0]]
126
+ x2, y2 = peaks[e[1]]
127
+
128
+ x1 = int(x1 * W)
129
+ y1 = int(y1 * H)
130
+ x2 = int(x2 * W)
131
+ y2 = int(y2 * H)
132
+ if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
133
+ cv2.line(canvas, (x1, y1), (x2, y2),
134
+ matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=1)
135
+
136
+         for _, keypoint in enumerate(peaks):
137
+             x, y = keypoint
138
+
139
+ x = int(x * W)
140
+ y = int(y * H)
141
+ if x > eps and y > eps:
142
+ cv2.circle(canvas, (x, y), 1, (0, 0, 255), thickness=-1)
143
+ return canvas
144
+
145
+
146
+ def draw_facepose(canvas, all_lmks):
147
+ H, W, C = canvas.shape
148
+ for lmks in all_lmks:
149
+ lmks = np.array(lmks)
150
+ for lmk in lmks:
151
+ x, y = lmk
152
+ x = int(x * W)
153
+ y = int(y * H)
154
+ if x > eps and y > eps:
155
+ cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
156
+ return canvas
157
+
158
+
159
+ # detect hand according to body pose keypoints
160
+ # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
161
+ def handDetect(candidate, subset, oriImg):
162
+ # right hand: wrist 4, elbow 3, shoulder 2
163
+ # left hand: wrist 7, elbow 6, shoulder 5
164
+ ratioWristElbow = 0.33
165
+ detect_result = []
166
+ image_height, image_width = oriImg.shape[0:2]
167
+ for person in subset.astype(int):
168
+ # if any of three not detected
169
+ has_left = np.sum(person[[5, 6, 7]] == -1) == 0
170
+ has_right = np.sum(person[[2, 3, 4]] == -1) == 0
171
+ if not (has_left or has_right):
172
+ continue
173
+ hands = []
174
+ # left hand
175
+ if has_left:
176
+ left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
177
+ x1, y1 = candidate[left_shoulder_index][:2]
178
+ x2, y2 = candidate[left_elbow_index][:2]
179
+ x3, y3 = candidate[left_wrist_index][:2]
180
+ hands.append([x1, y1, x2, y2, x3, y3, True])
181
+ # right hand
182
+ if has_right:
183
+ right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
184
+ x1, y1 = candidate[right_shoulder_index][:2]
185
+ x2, y2 = candidate[right_elbow_index][:2]
186
+ x3, y3 = candidate[right_wrist_index][:2]
187
+ hands.append([x1, y1, x2, y2, x3, y3, False])
188
+
189
+ for x1, y1, x2, y2, x3, y3, is_left in hands:
190
+ # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
191
+ # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
192
+ # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
193
+ # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
194
+ # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
195
+ # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
196
+ x = x3 + ratioWristElbow * (x3 - x2)
197
+ y = y3 + ratioWristElbow * (y3 - y2)
198
+ distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
199
+ distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
200
+ width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
201
+ # x-y refers to the center --> offset to topLeft point
202
+ # handRectangle.x -= handRectangle.width / 2.f;
203
+ # handRectangle.y -= handRectangle.height / 2.f;
204
+ x -= width / 2
205
+ y -= width / 2 # width = height
206
+ # overflow the image
207
+ if x < 0: x = 0
208
+ if y < 0: y = 0
209
+ width1 = width
210
+ width2 = width
211
+ if x + width > image_width: width1 = image_width - x
212
+ if y + width > image_height: width2 = image_height - y
213
+ width = min(width1, width2)
214
+             # keep only hand boxes that are at least 20 pixels wide
215
+ if width >= 20:
216
+ detect_result.append([int(x), int(y), int(width), is_left])
217
+
218
+ '''
219
+ return value: [[x, y, w, True if left hand else False]].
220
+ width=height since the network require squared input.
221
+ x, y is the coordinate of top left
222
+ '''
223
+ return detect_result
224
+
225
+
226
+ # Written by Lvmin
227
+ def faceDetect(candidate, subset, oriImg):
228
+ # left right eye ear 14 15 16 17
229
+ detect_result = []
230
+ image_height, image_width = oriImg.shape[0:2]
231
+ for person in subset.astype(int):
232
+ has_head = person[0] > -1
233
+ if not has_head:
234
+ continue
235
+
236
+ has_left_eye = person[14] > -1
237
+ has_right_eye = person[15] > -1
238
+ has_left_ear = person[16] > -1
239
+ has_right_ear = person[17] > -1
240
+
241
+ if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
242
+ continue
243
+
244
+ head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
245
+
246
+ width = 0.0
247
+ x0, y0 = candidate[head][:2]
248
+
249
+ if has_left_eye:
250
+ x1, y1 = candidate[left_eye][:2]
251
+ d = max(abs(x0 - x1), abs(y0 - y1))
252
+ width = max(width, d * 3.0)
253
+
254
+ if has_right_eye:
255
+ x1, y1 = candidate[right_eye][:2]
256
+ d = max(abs(x0 - x1), abs(y0 - y1))
257
+ width = max(width, d * 3.0)
258
+
259
+ if has_left_ear:
260
+ x1, y1 = candidate[left_ear][:2]
261
+ d = max(abs(x0 - x1), abs(y0 - y1))
262
+ width = max(width, d * 1.5)
263
+
264
+ if has_right_ear:
265
+ x1, y1 = candidate[right_ear][:2]
266
+ d = max(abs(x0 - x1), abs(y0 - y1))
267
+ width = max(width, d * 1.5)
268
+
269
+ x, y = x0, y0
270
+
271
+ x -= width
272
+ y -= width
273
+
274
+ if x < 0:
275
+ x = 0
276
+
277
+ if y < 0:
278
+ y = 0
279
+
280
+ width1 = width * 2
281
+ width2 = width * 2
282
+
283
+ if x + width > image_width:
284
+ width1 = image_width - x
285
+
286
+ if y + width > image_height:
287
+ width2 = image_height - y
288
+
289
+ width = min(width1, width2)
290
+
291
+ if width >= 20:
292
+ detect_result.append([int(x), int(y), int(width)])
293
+
294
+ return detect_result
295
+
296
+
297
+ # get max index of 2d array
298
+ def npmax(array):
299
+ arrayindex = array.argmax(1)
300
+ arrayvalue = array.max(1)
301
+ i = arrayvalue.argmax()
302
+ j = arrayindex[i]
303
+ return i, j
src/controlnet_aux/dwpose/wholebody.py ADDED
@@ -0,0 +1,121 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import os
3
+ import numpy as np
4
+ import warnings
5
+
6
+ try:
7
+ import mmcv
8
+ except ImportError:
9
+ warnings.warn(
10
+ "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'"
11
+ )
12
+
13
+ try:
14
+ from mmpose.apis import inference_topdown
15
+ from mmpose.apis import init_model as init_pose_estimator
16
+ from mmpose.evaluation.functional import nms
17
+ from mmpose.utils import adapt_mmdet_pipeline
18
+ from mmpose.structures import merge_data_samples
19
+ except ImportError:
20
+ warnings.warn(
21
+ "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'"
22
+ )
23
+
24
+ try:
25
+ from mmdet.apis import inference_detector, init_detector
26
+ except ImportError:
27
+ warnings.warn(
28
+ "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'"
29
+ )
30
+
31
+
32
+ class Wholebody:
33
+ def __init__(self,
34
+ det_config=None, det_ckpt=None,
35
+ pose_config=None, pose_ckpt=None,
36
+ device="cpu"):
37
+
38
+ if det_config is None:
39
+ det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")
40
+
41
+ if pose_config is None:
42
+ pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")
43
+
44
+ if det_ckpt is None:
45
+ det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
46
+
47
+ if pose_ckpt is None:
48
+ pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"
49
+
50
+ # build detector
51
+ self.detector = init_detector(det_config, det_ckpt, device=device)
52
+ self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)
53
+
54
+ # build pose estimator
55
+ self.pose_estimator = init_pose_estimator(
56
+ pose_config,
57
+ pose_ckpt,
58
+ device=device)
59
+
60
+ def to(self, device):
61
+ self.detector.to(device)
62
+ self.pose_estimator.to(device)
63
+ return self
64
+
65
+ def __call__(self, oriImg):
66
+ # predict bbox
67
+ det_result = inference_detector(self.detector, oriImg)
68
+ pred_instance = det_result.pred_instances.cpu().numpy()
69
+ bboxes = np.concatenate(
70
+ (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
71
+ bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
72
+ pred_instance.scores > 0.5)]
73
+
74
+ # set NMS threshold
75
+ bboxes = bboxes[nms(bboxes, 0.7), :4]
76
+
77
+ # predict keypoints
78
+ if len(bboxes) == 0:
79
+ pose_results = inference_topdown(self.pose_estimator, oriImg)
80
+ else:
81
+ pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
82
+ preds = merge_data_samples(pose_results)
83
+ preds = preds.pred_instances
84
+
85
+ # preds = pose_results[0].pred_instances
86
+ keypoints = preds.get('transformed_keypoints',
87
+ preds.keypoints)
88
+ if 'keypoint_scores' in preds:
89
+ scores = preds.keypoint_scores
90
+ else:
91
+ scores = np.ones(keypoints.shape[:-1])
92
+
93
+ if 'keypoints_visible' in preds:
94
+ visible = preds.keypoints_visible
95
+ else:
96
+ visible = np.ones(keypoints.shape[:-1])
97
+ keypoints_info = np.concatenate(
98
+ (keypoints, scores[..., None], visible[..., None]),
99
+ axis=-1)
100
+ # compute neck joint
101
+ neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
102
+ # neck score when visualizing pred
103
+ neck[:, 2:4] = np.logical_and(
104
+ keypoints_info[:, 5, 2:4] > 0.3,
105
+ keypoints_info[:, 6, 2:4] > 0.3).astype(int)
106
+ new_keypoints_info = np.insert(
107
+ keypoints_info, 17, neck, axis=1)
108
+ mmpose_idx = [
109
+ 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
110
+ ]
111
+ openpose_idx = [
112
+ 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
113
+ ]
114
+ new_keypoints_info[:, openpose_idx] = \
115
+ new_keypoints_info[:, mmpose_idx]
116
+ keypoints_info = new_keypoints_info
117
+
118
+ keypoints, scores, visible = keypoints_info[
119
+ ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]
120
+
121
+ return keypoints, scores
src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py ADDED
@@ -0,0 +1,245 @@
1
+ img_scale = (640, 640) # width, height
2
+
3
+ # model settings
4
+ model = dict(
5
+ type='YOLOX',
6
+ data_preprocessor=dict(
7
+ type='DetDataPreprocessor',
8
+ pad_size_divisor=32,
9
+ batch_augments=[
10
+ dict(
11
+ type='BatchSyncRandomResize',
12
+ random_size_range=(480, 800),
13
+ size_divisor=32,
14
+ interval=10)
15
+ ]),
16
+ backbone=dict(
17
+ type='CSPDarknet',
18
+ deepen_factor=1.0,
19
+ widen_factor=1.0,
20
+ out_indices=(2, 3, 4),
21
+ use_depthwise=False,
22
+ spp_kernal_sizes=(5, 9, 13),
23
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
24
+ act_cfg=dict(type='Swish'),
25
+ ),
26
+ neck=dict(
27
+ type='YOLOXPAFPN',
28
+ in_channels=[256, 512, 1024],
29
+ out_channels=256,
30
+ num_csp_blocks=3,
31
+ use_depthwise=False,
32
+ upsample_cfg=dict(scale_factor=2, mode='nearest'),
33
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
34
+ act_cfg=dict(type='Swish')),
35
+ bbox_head=dict(
36
+ type='YOLOXHead',
37
+ num_classes=80,
38
+ in_channels=256,
39
+ feat_channels=256,
40
+ stacked_convs=2,
41
+ strides=(8, 16, 32),
42
+ use_depthwise=False,
43
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
44
+ act_cfg=dict(type='Swish'),
45
+ loss_cls=dict(
46
+ type='CrossEntropyLoss',
47
+ use_sigmoid=True,
48
+ reduction='sum',
49
+ loss_weight=1.0),
50
+ loss_bbox=dict(
51
+ type='IoULoss',
52
+ mode='square',
53
+ eps=1e-16,
54
+ reduction='sum',
55
+ loss_weight=5.0),
56
+ loss_obj=dict(
57
+ type='CrossEntropyLoss',
58
+ use_sigmoid=True,
59
+ reduction='sum',
60
+ loss_weight=1.0),
61
+ loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
62
+ train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
63
+ # In order to align the source code, the threshold of the val phase is
64
+ # 0.01, and the threshold of the test phase is 0.001.
65
+ test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
66
+
67
+ # dataset settings
68
+ data_root = 'data/coco/'
69
+ dataset_type = 'CocoDataset'
70
+
71
+ # Example to use different file client
72
+ # Method 1: simply set the data root and let the file I/O module
73
+ # automatically infer from prefix (not support LMDB and Memcache yet)
74
+
75
+ # data_root = 's3://openmmlab/datasets/detection/coco/'
76
+
77
+ # Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
78
+ # backend_args = dict(
79
+ # backend='petrel',
80
+ # path_mapping=dict({
81
+ # './data/': 's3://openmmlab/datasets/detection/',
82
+ # 'data/': 's3://openmmlab/datasets/detection/'
83
+ # }))
84
+ backend_args = None
85
+
86
+ train_pipeline = [
87
+ dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
88
+ dict(
89
+ type='RandomAffine',
90
+ scaling_ratio_range=(0.1, 2),
91
+ # img_scale is (width, height)
92
+ border=(-img_scale[0] // 2, -img_scale[1] // 2)),
93
+ dict(
94
+ type='MixUp',
95
+ img_scale=img_scale,
96
+ ratio_range=(0.8, 1.6),
97
+ pad_val=114.0),
98
+ dict(type='YOLOXHSVRandomAug'),
99
+ dict(type='RandomFlip', prob=0.5),
100
+ # According to the official implementation, multi-scale
101
+ # training is not considered here but in the
102
+ # 'mmdet/models/detectors/yolox.py'.
103
+ # Resize and Pad are for the last 15 epochs when Mosaic,
104
+ # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
105
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
106
+ dict(
107
+ type='Pad',
108
+ pad_to_square=True,
109
+ # If the image is three-channel, the pad value needs
110
+ # to be set separately for each channel.
111
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
112
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
113
+ dict(type='PackDetInputs')
114
+ ]
115
+
116
+ train_dataset = dict(
117
+ # use MultiImageMixDataset wrapper to support mosaic and mixup
118
+ type='MultiImageMixDataset',
119
+ dataset=dict(
120
+ type=dataset_type,
121
+ data_root=data_root,
122
+ ann_file='annotations/instances_train2017.json',
123
+ data_prefix=dict(img='train2017/'),
124
+ pipeline=[
125
+ dict(type='LoadImageFromFile', backend_args=backend_args),
126
+ dict(type='LoadAnnotations', with_bbox=True)
127
+ ],
128
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
129
+ backend_args=backend_args),
130
+ pipeline=train_pipeline)
131
+
132
+ test_pipeline = [
133
+ dict(type='LoadImageFromFile', backend_args=backend_args),
134
+ dict(type='Resize', scale=img_scale, keep_ratio=True),
135
+ dict(
136
+ type='Pad',
137
+ pad_to_square=True,
138
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
139
+ dict(type='LoadAnnotations', with_bbox=True),
140
+ dict(
141
+ type='PackDetInputs',
142
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
143
+ 'scale_factor'))
144
+ ]
145
+
146
+ train_dataloader = dict(
147
+ batch_size=8,
148
+ num_workers=4,
149
+ persistent_workers=True,
150
+ sampler=dict(type='DefaultSampler', shuffle=True),
151
+ dataset=train_dataset)
152
+ val_dataloader = dict(
153
+ batch_size=8,
154
+ num_workers=4,
155
+ persistent_workers=True,
156
+ drop_last=False,
157
+ sampler=dict(type='DefaultSampler', shuffle=False),
158
+ dataset=dict(
159
+ type=dataset_type,
160
+ data_root=data_root,
161
+ ann_file='annotations/instances_val2017.json',
162
+ data_prefix=dict(img='val2017/'),
163
+ test_mode=True,
164
+ pipeline=test_pipeline,
165
+ backend_args=backend_args))
166
+ test_dataloader = val_dataloader
167
+
168
+ val_evaluator = dict(
169
+ type='CocoMetric',
170
+ ann_file=data_root + 'annotations/instances_val2017.json',
171
+ metric='bbox',
172
+ backend_args=backend_args)
173
+ test_evaluator = val_evaluator
174
+
175
+ # training settings
176
+ max_epochs = 300
177
+ num_last_epochs = 15
178
+ interval = 10
179
+
180
+ train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
181
+
182
+ # optimizer
183
+ # default 8 gpu
184
+ base_lr = 0.01
185
+ optim_wrapper = dict(
186
+ type='OptimWrapper',
187
+ optimizer=dict(
188
+ type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
189
+ nesterov=True),
190
+ paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
191
+
192
+ # learning rate
193
+ param_scheduler = [
194
+ dict(
195
+ # use quadratic formula to warm up 5 epochs
196
+ # and lr is updated by iteration
197
+ # TODO: fix default scope in get function
198
+ type='mmdet.QuadraticWarmupLR',
199
+ by_epoch=True,
200
+ begin=0,
201
+ end=5,
202
+ convert_to_iter_based=True),
203
+ dict(
204
+ # use cosine lr from 5 to 285 epoch
205
+ type='CosineAnnealingLR',
206
+ eta_min=base_lr * 0.05,
207
+ begin=5,
208
+ T_max=max_epochs - num_last_epochs,
209
+ end=max_epochs - num_last_epochs,
210
+ by_epoch=True,
211
+ convert_to_iter_based=True),
212
+ dict(
213
+ # use fixed lr during last 15 epochs
214
+ type='ConstantLR',
215
+ by_epoch=True,
216
+ factor=1,
217
+ begin=max_epochs - num_last_epochs,
218
+ end=max_epochs,
219
+ )
220
+ ]
221
+
222
+ default_hooks = dict(
223
+ checkpoint=dict(
224
+ interval=interval,
225
+ max_keep_ckpts=3 # only keep latest 3 checkpoints
226
+ ))
227
+
228
+ custom_hooks = [
229
+ dict(
230
+ type='YOLOXModeSwitchHook',
231
+ num_last_epochs=num_last_epochs,
232
+ priority=48),
233
+ dict(type='SyncNormHook', priority=48),
234
+ dict(
235
+ type='EMAHook',
236
+ ema_type='ExpMomentumEMA',
237
+ momentum=0.0001,
238
+ update_buffers=True,
239
+ priority=49)
240
+ ]
241
+
242
+ # NOTE: `auto_scale_lr` is for automatically scaling LR,
243
+ # USER SHOULD NOT CHANGE ITS VALUES.
244
+ # base_batch_size = (8 GPUs) x (8 samples per GPU)
245
+ auto_scale_lr = dict(base_batch_size=64)
src/controlnet_aux/hed/__init__.py ADDED
@@ -0,0 +1,129 @@
1
+ # This is an improved version and model of HED edge detection with Apache License, Version 2.0.
2
+ # Please use this implementation in your products
3
+ # This implementation may produce slightly different results from Saining Xie's official implementations,
4
+ # but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
5
+ # Different from official models and other implementations, this is an RGB-input model (rather than BGR)
6
+ # and in this way it works better for gradio's RGB protocol
7
+
8
+ import os
9
+ import warnings
10
+
11
+ import cv2
12
+ import numpy as np
13
+ import torch
14
+ from einops import rearrange
15
+ from huggingface_hub import hf_hub_download
16
+ from PIL import Image
17
+
18
+ from ..util import HWC3, nms, resize_image, safe_step
19
+
20
+
21
+ class DoubleConvBlock(torch.nn.Module):
22
+ def __init__(self, input_channel, output_channel, layer_number):
23
+ super().__init__()
24
+ self.convs = torch.nn.Sequential()
25
+ self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
26
+ for i in range(1, layer_number):
27
+ self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
28
+ self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)
29
+
30
+ def __call__(self, x, down_sampling=False):
31
+ h = x
32
+ if down_sampling:
33
+ h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
34
+ for conv in self.convs:
35
+ h = conv(h)
36
+ h = torch.nn.functional.relu(h)
37
+ return h, self.projection(h)
38
+
39
+
40
+ class ControlNetHED_Apache2(torch.nn.Module):
41
+ def __init__(self):
42
+ super().__init__()
43
+ self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
44
+ self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
45
+ self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
46
+ self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
47
+ self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
48
+ self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)
49
+
50
+ def __call__(self, x):
51
+ h = x - self.norm
52
+ h, projection1 = self.block1(h)
53
+ h, projection2 = self.block2(h, down_sampling=True)
54
+ h, projection3 = self.block3(h, down_sampling=True)
55
+ h, projection4 = self.block4(h, down_sampling=True)
56
+ h, projection5 = self.block5(h, down_sampling=True)
57
+ return projection1, projection2, projection3, projection4, projection5
58
+
59
+ class HEDdetector:
60
+ def __init__(self, netNetwork):
61
+ self.netNetwork = netNetwork
62
+
63
+ @classmethod
64
+ def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
65
+ filename = filename or "ControlNetHED.pth"
66
+
67
+ if os.path.isdir(pretrained_model_or_path):
68
+ model_path = os.path.join(pretrained_model_or_path, filename)
69
+ else:
70
+ model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
71
+
72
+ netNetwork = ControlNetHED_Apache2()
73
+ netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
74
+ netNetwork.float().eval()
75
+
76
+ return cls(netNetwork)
77
+
78
+ def to(self, device):
79
+ self.netNetwork.to(device)
80
+ return self
81
+
82
+ def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs):
83
+ if "return_pil" in kwargs:
84
+ warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
85
+ output_type = "pil" if kwargs["return_pil"] else "np"
86
+ if type(output_type) is bool:
87
+ warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
88
+ if output_type:
89
+ output_type = "pil"
90
+
91
+ device = next(iter(self.netNetwork.parameters())).device
92
+ if not isinstance(input_image, np.ndarray):
93
+ input_image = np.array(input_image, dtype=np.uint8)
94
+
95
+ input_image = HWC3(input_image)
96
+ input_image = resize_image(input_image, detect_resolution)
97
+
98
+ assert input_image.ndim == 3
99
+ H, W, C = input_image.shape
100
+ with torch.no_grad():
101
+ image_hed = torch.from_numpy(input_image.copy()).float().to(device)
102
+ image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
103
+ edges = self.netNetwork(image_hed)
104
+ edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
105
+ edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
106
+ edges = np.stack(edges, axis=2)
107
+ edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
108
+ if safe:
109
+ edge = safe_step(edge)
110
+ edge = (edge * 255.0).clip(0, 255).astype(np.uint8)
111
+
112
+ detected_map = edge
113
+ detected_map = HWC3(detected_map)
114
+
115
+ img = resize_image(input_image, image_resolution)
116
+ H, W, C = img.shape
117
+
118
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
119
+
120
+ if scribble:
121
+ detected_map = nms(detected_map, 127, 3.0)
122
+ detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
123
+ detected_map[detected_map > 4] = 255
124
+ detected_map[detected_map < 255] = 0
125
+
126
+ if output_type == "pil":
127
+ detected_map = Image.fromarray(detected_map)
128
+
129
+ return detected_map
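A minimal usage sketch for HEDdetector; the hub repo id below is an assumption (any local directory or repo containing ControlNetHED.pth works the same way):
# Sketch only: load HED weights and produce a soft-edge map.
from PIL import Image
from src.controlnet_aux import HEDdetector  # assumes the repo root is on sys.path

# "lllyasviel/Annotators" is an assumed checkpoint repo hosting ControlNetHED.pth.
hed = HEDdetector.from_pretrained("lllyasviel/Annotators").to("cuda")
soft_edges = hed(Image.open("example.png"))
soft_edges.save("example_hed.png")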
src/controlnet_aux/hed/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (10.1 kB). View file
 
src/controlnet_aux/leres/__init__.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from huggingface_hub import hf_hub_download
7
+ from PIL import Image
8
+
9
+ from ..util import HWC3, resize_image
10
+ from .leres.depthmap import estimateboost, estimateleres
11
+ from .leres.multi_depth_model_woauxi import RelDepthModel
12
+ from .leres.net_tools import strip_prefix_if_present
13
+ from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
14
+ from .pix2pix.options.test_options import TestOptions
15
+
16
+
17
+ class LeresDetector:
18
+ def __init__(self, model, pix2pixmodel):
19
+ self.model = model
20
+ self.pix2pixmodel = pix2pixmodel
21
+
22
+ @classmethod
23
+ def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None):
24
+ filename = filename or "res101.pth"
25
+ pix2pix_filename = pix2pix_filename or "latest_net_G.pth"
26
+
27
+ if os.path.isdir(pretrained_model_or_path):
28
+ model_path = os.path.join(pretrained_model_or_path, filename)
29
+ else:
30
+ model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)
31
+
32
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
33
+
34
+ model = RelDepthModel(backbone='resnext101')
35
+ model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
36
+ del checkpoint
37
+
38
+ if os.path.isdir(pretrained_model_or_path):
39
+ model_path = os.path.join(pretrained_model_or_path, pix2pix_filename)
40
+ else:
41
+ model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir)
42
+
43
+ opt = TestOptions().parse()
44
+ if not torch.cuda.is_available():
45
+ opt.gpu_ids = [] # cpu mode
46
+ pix2pixmodel = Pix2Pix4DepthModel(opt)
47
+ pix2pixmodel.save_dir = os.path.dirname(model_path)
48
+ pix2pixmodel.load_networks('latest')
49
+ pix2pixmodel.eval()
50
+
51
+ return cls(model, pix2pixmodel)
52
+
53
+ def to(self, device):
54
+ self.model.to(device)
55
+ # TODO - refactor pix2pix implementation to support device migration
56
+ # self.pix2pixmodel.to(device)
57
+ return self
58
+
59
+ def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"):
60
+ device = next(iter(self.model.parameters())).device
61
+ if not isinstance(input_image, np.ndarray):
62
+ input_image = np.array(input_image, dtype=np.uint8)
63
+
64
+ input_image = HWC3(input_image)
65
+ input_image = resize_image(input_image, detect_resolution)
66
+
67
+ assert input_image.ndim == 3
68
+ height, width, dim = input_image.shape
69
+
70
+ with torch.no_grad():
71
+
72
+ if boost:
73
+ depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height))
74
+ else:
75
+ depth = estimateleres(input_image, self.model, width, height)
76
+
77
+ numbytes=2
78
+ depth_min = depth.min()
79
+ depth_max = depth.max()
80
+ max_val = (2**(8*numbytes))-1
81
+
82
+ # check output before normalizing and mapping to 16 bit
83
+ if depth_max - depth_min > np.finfo("float").eps:
84
+ out = max_val * (depth - depth_min) / (depth_max - depth_min)
85
+ else:
86
+ out = np.zeros(depth.shape)
87
+
88
+ # single channel, 16 bit image
89
+ depth_image = out.astype("uint16")
90
+
91
+ # convert to uint8
92
+ depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0))
93
+
94
+ # remove near
95
+ if thr_a != 0:
96
+ thr_a = ((thr_a/100)*255)
97
+ depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]
98
+
99
+ # invert image
100
+ depth_image = cv2.bitwise_not(depth_image)
101
+
102
+ # remove bg
103
+ if thr_b != 0:
104
+ thr_b = ((thr_b/100)*255)
105
+ depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]
106
+
107
+ detected_map = depth_image
108
+ detected_map = HWC3(detected_map)
109
+
110
+ img = resize_image(input_image, image_resolution)
111
+ H, W, C = img.shape
112
+
113
+ detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
114
+
115
+ if output_type == "pil":
116
+ detected_map = Image.fromarray(detected_map)
117
+
118
+ return detected_map
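A matching usage sketch for `LeresDetector`. The weight filenames (`res101.pth`, `latest_net_G.pth`) are the defaults shown in `from_pretrained` above; the Hub repo id is an assumption.

```python
# Hedged sketch: the Hub repo id is an assumption; filenames are the defaults above.
from PIL import Image
from src.controlnet_aux.leres import LeresDetector

leres = LeresDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
image = Image.open("input.png")

# thr_a removes near depth, thr_b removes background (both as percentages);
# boost=True runs the slower pix2pix-merged estimate.
depth = leres(image, thr_a=0, thr_b=0, boost=False, output_type="pil")
depth.save("leres_depth.png")
```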
src/controlnet_aux/leres/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (6.37 kB).
 
src/controlnet_aux/leres/leres/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ https://github.com/thygate/stable-diffusion-webui-depthmap-script
2
+
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Bob Thiry
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
src/controlnet_aux/leres/leres/Resnet.py ADDED
@@ -0,0 +1,199 @@
1
+ import torch.nn as nn
2
+ import torch.nn as NN
3
+
4
+ __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
5
+ 'resnet152']
6
+
7
+
8
+ model_urls = {
9
+ 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
10
+ 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
11
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
12
+ 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
13
+ 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
14
+ }
15
+
16
+
17
+ def conv3x3(in_planes, out_planes, stride=1):
18
+ """3x3 convolution with padding"""
19
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
20
+ padding=1, bias=False)
21
+
22
+
23
+ class BasicBlock(nn.Module):
24
+ expansion = 1
25
+
26
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
27
+ super(BasicBlock, self).__init__()
28
+ self.conv1 = conv3x3(inplanes, planes, stride)
29
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
30
+ self.relu = nn.ReLU(inplace=True)
31
+ self.conv2 = conv3x3(planes, planes)
32
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
33
+ self.downsample = downsample
34
+ self.stride = stride
35
+
36
+ def forward(self, x):
37
+ residual = x
38
+
39
+ out = self.conv1(x)
40
+ out = self.bn1(out)
41
+ out = self.relu(out)
42
+
43
+ out = self.conv2(out)
44
+ out = self.bn2(out)
45
+
46
+ if self.downsample is not None:
47
+ residual = self.downsample(x)
48
+
49
+ out += residual
50
+ out = self.relu(out)
51
+
52
+ return out
53
+
54
+
55
+ class Bottleneck(nn.Module):
56
+ expansion = 4
57
+
58
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
59
+ super(Bottleneck, self).__init__()
60
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
61
+ self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
62
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
63
+ padding=1, bias=False)
64
+ self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d
65
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
66
+ self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d
67
+ self.relu = nn.ReLU(inplace=True)
68
+ self.downsample = downsample
69
+ self.stride = stride
70
+
71
+ def forward(self, x):
72
+ residual = x
73
+
74
+ out = self.conv1(x)
75
+ out = self.bn1(out)
76
+ out = self.relu(out)
77
+
78
+ out = self.conv2(out)
79
+ out = self.bn2(out)
80
+ out = self.relu(out)
81
+
82
+ out = self.conv3(out)
83
+ out = self.bn3(out)
84
+
85
+ if self.downsample is not None:
86
+ residual = self.downsample(x)
87
+
88
+ out += residual
89
+ out = self.relu(out)
90
+
91
+ return out
92
+
93
+
94
+ class ResNet(nn.Module):
95
+
96
+ def __init__(self, block, layers, num_classes=1000):
97
+ self.inplanes = 64
98
+ super(ResNet, self).__init__()
99
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
100
+ bias=False)
101
+ self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d
102
+ self.relu = nn.ReLU(inplace=True)
103
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
104
+ self.layer1 = self._make_layer(block, 64, layers[0])
105
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
106
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
107
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
108
+ #self.avgpool = nn.AvgPool2d(7, stride=1)
109
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
110
+
111
+ for m in self.modules():
112
+ if isinstance(m, nn.Conv2d):
113
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
114
+ elif isinstance(m, nn.BatchNorm2d):
115
+ nn.init.constant_(m.weight, 1)
116
+ nn.init.constant_(m.bias, 0)
117
+
118
+ def _make_layer(self, block, planes, blocks, stride=1):
119
+ downsample = None
120
+ if stride != 1 or self.inplanes != planes * block.expansion:
121
+ downsample = nn.Sequential(
122
+ nn.Conv2d(self.inplanes, planes * block.expansion,
123
+ kernel_size=1, stride=stride, bias=False),
124
+ NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d
125
+ )
126
+
127
+ layers = []
128
+ layers.append(block(self.inplanes, planes, stride, downsample))
129
+ self.inplanes = planes * block.expansion
130
+ for i in range(1, blocks):
131
+ layers.append(block(self.inplanes, planes))
132
+
133
+ return nn.Sequential(*layers)
134
+
135
+ def forward(self, x):
136
+ features = []
137
+
138
+ x = self.conv1(x)
139
+ x = self.bn1(x)
140
+ x = self.relu(x)
141
+ x = self.maxpool(x)
142
+
143
+ x = self.layer1(x)
144
+ features.append(x)
145
+ x = self.layer2(x)
146
+ features.append(x)
147
+ x = self.layer3(x)
148
+ features.append(x)
149
+ x = self.layer4(x)
150
+ features.append(x)
151
+
152
+ return features
153
+
154
+
155
+ def resnet18(pretrained=True, **kwargs):
156
+ """Constructs a ResNet-18 model.
157
+ Args:
158
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
159
+ """
160
+ model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
161
+ return model
162
+
163
+
164
+ def resnet34(pretrained=True, **kwargs):
165
+ """Constructs a ResNet-34 model.
166
+ Args:
167
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
168
+ """
169
+ model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
170
+ return model
171
+
172
+
173
+ def resnet50(pretrained=True, **kwargs):
174
+ """Constructs a ResNet-50 model.
175
+ Args:
176
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
177
+ """
178
+ model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
179
+
180
+ return model
181
+
182
+
183
+ def resnet101(pretrained=True, **kwargs):
184
+ """Constructs a ResNet-101 model.
185
+ Args:
186
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
187
+ """
188
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
189
+
190
+ return model
191
+
192
+
193
+ def resnet152(pretrained=True, **kwargs):
194
+ """Constructs a ResNet-152 model.
195
+ Args:
196
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
197
+ """
198
+ model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
199
+ return model
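This vendored copy keeps `model_urls` but never downloads them, so the `pretrained` flag is effectively ignored, and `forward` returns the four stage feature maps rather than logits (the avgpool/fc head is commented out). A small sanity-check sketch, assuming the module is importable from the repo root:

```python
# Sanity check: resnet50 here is a feature extractor, not a classifier.
import torch
from src.controlnet_aux.leres.leres.Resnet import resnet50

model = resnet50().eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))

print([tuple(f.shape) for f in feats])
# Expected with Bottleneck (expansion=4):
# [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]
```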
src/controlnet_aux/leres/leres/Resnext_torch.py ADDED
@@ -0,0 +1,237 @@
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import torch.nn as nn
4
+
5
+ try:
6
+ from urllib import urlretrieve
7
+ except ImportError:
8
+ from urllib.request import urlretrieve
9
+
10
+ __all__ = ['resnext101_32x8d']
11
+
12
+
13
+ model_urls = {
14
+ 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
15
+ 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
16
+ }
17
+
18
+
19
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
20
+ """3x3 convolution with padding"""
21
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
22
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
23
+
24
+
25
+ def conv1x1(in_planes, out_planes, stride=1):
26
+ """1x1 convolution"""
27
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
28
+
29
+
30
+ class BasicBlock(nn.Module):
31
+ expansion = 1
32
+
33
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
34
+ base_width=64, dilation=1, norm_layer=None):
35
+ super(BasicBlock, self).__init__()
36
+ if norm_layer is None:
37
+ norm_layer = nn.BatchNorm2d
38
+ if groups != 1 or base_width != 64:
39
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
40
+ if dilation > 1:
41
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
42
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
43
+ self.conv1 = conv3x3(inplanes, planes, stride)
44
+ self.bn1 = norm_layer(planes)
45
+ self.relu = nn.ReLU(inplace=True)
46
+ self.conv2 = conv3x3(planes, planes)
47
+ self.bn2 = norm_layer(planes)
48
+ self.downsample = downsample
49
+ self.stride = stride
50
+
51
+ def forward(self, x):
52
+ identity = x
53
+
54
+ out = self.conv1(x)
55
+ out = self.bn1(out)
56
+ out = self.relu(out)
57
+
58
+ out = self.conv2(out)
59
+ out = self.bn2(out)
60
+
61
+ if self.downsample is not None:
62
+ identity = self.downsample(x)
63
+
64
+ out += identity
65
+ out = self.relu(out)
66
+
67
+ return out
68
+
69
+
70
+ class Bottleneck(nn.Module):
71
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
72
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
73
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
74
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
75
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
76
+
77
+ expansion = 4
78
+
79
+ def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
80
+ base_width=64, dilation=1, norm_layer=None):
81
+ super(Bottleneck, self).__init__()
82
+ if norm_layer is None:
83
+ norm_layer = nn.BatchNorm2d
84
+ width = int(planes * (base_width / 64.)) * groups
85
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
86
+ self.conv1 = conv1x1(inplanes, width)
87
+ self.bn1 = norm_layer(width)
88
+ self.conv2 = conv3x3(width, width, stride, groups, dilation)
89
+ self.bn2 = norm_layer(width)
90
+ self.conv3 = conv1x1(width, planes * self.expansion)
91
+ self.bn3 = norm_layer(planes * self.expansion)
92
+ self.relu = nn.ReLU(inplace=True)
93
+ self.downsample = downsample
94
+ self.stride = stride
95
+
96
+ def forward(self, x):
97
+ identity = x
98
+
99
+ out = self.conv1(x)
100
+ out = self.bn1(out)
101
+ out = self.relu(out)
102
+
103
+ out = self.conv2(out)
104
+ out = self.bn2(out)
105
+ out = self.relu(out)
106
+
107
+ out = self.conv3(out)
108
+ out = self.bn3(out)
109
+
110
+ if self.downsample is not None:
111
+ identity = self.downsample(x)
112
+
113
+ out += identity
114
+ out = self.relu(out)
115
+
116
+ return out
117
+
118
+
119
+ class ResNet(nn.Module):
120
+
121
+ def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
122
+ groups=1, width_per_group=64, replace_stride_with_dilation=None,
123
+ norm_layer=None):
124
+ super(ResNet, self).__init__()
125
+ if norm_layer is None:
126
+ norm_layer = nn.BatchNorm2d
127
+ self._norm_layer = norm_layer
128
+
129
+ self.inplanes = 64
130
+ self.dilation = 1
131
+ if replace_stride_with_dilation is None:
132
+ # each element in the tuple indicates if we should replace
133
+ # the 2x2 stride with a dilated convolution instead
134
+ replace_stride_with_dilation = [False, False, False]
135
+ if len(replace_stride_with_dilation) != 3:
136
+ raise ValueError("replace_stride_with_dilation should be None "
137
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
138
+ self.groups = groups
139
+ self.base_width = width_per_group
140
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
141
+ bias=False)
142
+ self.bn1 = norm_layer(self.inplanes)
143
+ self.relu = nn.ReLU(inplace=True)
144
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
145
+ self.layer1 = self._make_layer(block, 64, layers[0])
146
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
147
+ dilate=replace_stride_with_dilation[0])
148
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
149
+ dilate=replace_stride_with_dilation[1])
150
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
151
+ dilate=replace_stride_with_dilation[2])
152
+ #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
153
+ #self.fc = nn.Linear(512 * block.expansion, num_classes)
154
+
155
+ for m in self.modules():
156
+ if isinstance(m, nn.Conv2d):
157
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
158
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
159
+ nn.init.constant_(m.weight, 1)
160
+ nn.init.constant_(m.bias, 0)
161
+
162
+ # Zero-initialize the last BN in each residual branch,
163
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
164
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
165
+ if zero_init_residual:
166
+ for m in self.modules():
167
+ if isinstance(m, Bottleneck):
168
+ nn.init.constant_(m.bn3.weight, 0)
169
+ elif isinstance(m, BasicBlock):
170
+ nn.init.constant_(m.bn2.weight, 0)
171
+
172
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
173
+ norm_layer = self._norm_layer
174
+ downsample = None
175
+ previous_dilation = self.dilation
176
+ if dilate:
177
+ self.dilation *= stride
178
+ stride = 1
179
+ if stride != 1 or self.inplanes != planes * block.expansion:
180
+ downsample = nn.Sequential(
181
+ conv1x1(self.inplanes, planes * block.expansion, stride),
182
+ norm_layer(planes * block.expansion),
183
+ )
184
+
185
+ layers = []
186
+ layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
187
+ self.base_width, previous_dilation, norm_layer))
188
+ self.inplanes = planes * block.expansion
189
+ for _ in range(1, blocks):
190
+ layers.append(block(self.inplanes, planes, groups=self.groups,
191
+ base_width=self.base_width, dilation=self.dilation,
192
+ norm_layer=norm_layer))
193
+
194
+ return nn.Sequential(*layers)
195
+
196
+ def _forward_impl(self, x):
197
+ # See note [TorchScript super()]
198
+ features = []
199
+ x = self.conv1(x)
200
+ x = self.bn1(x)
201
+ x = self.relu(x)
202
+ x = self.maxpool(x)
203
+
204
+ x = self.layer1(x)
205
+ features.append(x)
206
+
207
+ x = self.layer2(x)
208
+ features.append(x)
209
+
210
+ x = self.layer3(x)
211
+ features.append(x)
212
+
213
+ x = self.layer4(x)
214
+ features.append(x)
215
+
216
+ #x = self.avgpool(x)
217
+ #x = torch.flatten(x, 1)
218
+ #x = self.fc(x)
219
+
220
+ return features
221
+
222
+ def forward(self, x):
223
+ return self._forward_impl(x)
224
+
225
+
226
+
227
+ def resnext101_32x8d(pretrained=True, **kwargs):
228
+ """Constructs a ResNeXt-101 32x8d model.
229
+ Args:
230
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
231
+ """
232
+ kwargs['groups'] = 32
233
+ kwargs['width_per_group'] = 8
234
+
235
+ model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
236
+ return model
237
+
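The factory pins `groups=32` and `width_per_group=8`, so each Bottleneck's grouped 3x3 conv runs at width `planes * (8 / 64) * 32`, while the stage output channels match the plain ResNet-101 above. A brief sketch under the same import-path assumption:

```python
# Hedged sketch: import path assumes the repo root is on sys.path.
import torch
from src.controlnet_aux.leres.leres.Resnext_torch import resnext101_32x8d

backbone = resnext101_32x8d().eval()
with torch.no_grad():
    feats = backbone(torch.randn(1, 3, 224, 224))

print([f.shape[1] for f in feats])  # [256, 512, 1024, 2048]
```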