initial commit
This view is limited to 50 files because the commit contains too many changes. Files changed:
- .gitattributes +1 -0
- LICENSE +201 -0
- README.md +12 -12
- app.py +48 -7
- caculate_metrics_256.py +27 -0
- caculate_metrics_512.py +27 -0
- evaluate.py +186 -0
- inception.py +138 -0
- main.py +1097 -0
- metrics.json +1538 -0
- metrics.py +522 -0
- pose-frames.py +16 -0
- pose.py +15 -0
- requirements.txt +8 -0
- run_stage1.sh +18 -0
- run_stage2.sh +18 -0
- run_stage3.sh +15 -0
- run_test_stage1.sh +9 -0
- run_test_stage2.sh +13 -0
- run_test_stage3.sh +12 -0
- sd.py +13 -0
- setup.txt +41 -0
- single_extract_pose.py +35 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-311.pyc +0 -0
- src/configs/dwpose-l_384x288.py +257 -0
- src/configs/stage1_config.py +181 -0
- src/configs/stage2_config.py +192 -0
- src/configs/stage3_config.py +217 -0
- src/configs/yolox_l_8xb8-300e_coco.py +245 -0
- src/controlnet_aux/__init__.py +18 -0
- src/controlnet_aux/__pycache__/__init__.cpython-311.pyc +0 -0
- src/controlnet_aux/__pycache__/util.cpython-311.pyc +0 -0
- src/controlnet_aux/canny/__init__.py +36 -0
- src/controlnet_aux/canny/__pycache__/__init__.cpython-311.pyc +0 -0
- src/controlnet_aux/dwpose/__init__.py +92 -0
- src/controlnet_aux/dwpose/__pycache__/__init__.cpython-311.pyc +0 -0
- src/controlnet_aux/dwpose/__pycache__/util.cpython-311.pyc +0 -0
- src/controlnet_aux/dwpose/__pycache__/wholebody.cpython-311.pyc +0 -0
- src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py +257 -0
- src/controlnet_aux/dwpose/util.py +303 -0
- src/controlnet_aux/dwpose/wholebody.py +121 -0
- src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py +245 -0
- src/controlnet_aux/hed/__init__.py +129 -0
- src/controlnet_aux/hed/__pycache__/__init__.cpython-311.pyc +0 -0
- src/controlnet_aux/leres/__init__.py +118 -0
- src/controlnet_aux/leres/__pycache__/__init__.cpython-311.pyc +0 -0
- src/controlnet_aux/leres/leres/LICENSE +23 -0
- src/controlnet_aux/leres/leres/Resnet.py +199 -0
- src/controlnet_aux/leres/leres/Resnext_torch.py +237 -0
.gitattributes CHANGED

@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+src/controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED

@@ -0,0 +1,201 @@

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright [yyyy] [name of copyright owner]

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
README.md CHANGED

@@ -1,12 +1,12 @@
----
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk: gradio
-sdk_version: 5.
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: KeyframesAI
+emoji: 📈
+colorFrom: gray
+colorTo: gray
+sdk: gradio
+sdk_version: 5.22.0
+app_file: app.py
+pinned: false
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED

@@ -1,7 +1,48 @@
-import
-
-
-
-
-
-
+from main import run_app, run_train, run_inference
+
+import spaces
+from PIL import Image
+import cv2
+import os
+import gradio as gr
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+            char_imgs = gr.Gallery(type="pil", label="Images of the Character")
+            mocap = gr.Video(label="Motion-Capture Video")
+            tr_steps = gr.Number(label="Training steps", value=10)
+            inf_steps = gr.Number(label="Inference steps", value=10)
+            fps = gr.Number(label="Output frame rate", value=12)
+            modelId = gr.Text(label="Model Id", value="fine_tuned_pcdms")
+            remove_bg = gr.Checkbox(label="Remove background", value=False)
+            resize_inputs = gr.Checkbox(label="Resize images to match video", value=True)
+            train_btn = gr.Button(value="Train")
+            inference_btn = gr.Button(value="Inference")
+            submit_btn = gr.Button(value="Generate")
+        with gr.Column():
+            animation = gr.Video(label="Result")
+            frames = gr.Gallery(type="pil", label="Frames")
+
+    submit_btn.click(
+        run_app, inputs=[char_imgs, mocap, tr_steps, inf_steps, fps, remove_bg, resize_inputs], outputs=[animation, frames]
+    )
+
+    train_btn.click(
+        run_train, inputs=[char_imgs, tr_steps, remove_bg, resize_inputs, modelId], outputs=[]
+    )
+
+    inference_btn.click(
+        run_inference, inputs=[char_imgs, mocap, inf_steps, fps, remove_bg, resize_inputs, modelId], outputs=[animation, frames]
+    )
+
+demo.launch(share=True)
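The three click handlers above pin down the call interfaces that main.py must export. A sketch of the implied signatures (the parameter names here are illustrative assumptions; only the argument order and the outputs follow from the wiring):

def run_train(images, train_steps, bg_remove, resize_inputs, model_id):
    ...  # trains/fine-tunes and saves a model; no UI outputs

def run_inference(images, video, inference_steps, fps, bg_remove, resize_inputs, model_id):
    ...  # returns (animation_video, frame_gallery)

def run_app(images, video, train_steps, inference_steps, fps, bg_remove, resize_inputs):
    ...  # train + inference in one pass; returns (animation_video, frame_gallery)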
caculate_metrics_256.py ADDED

@@ -0,0 +1,27 @@

from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task
import torch


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
fid = FID()
lpips_obj = LPIPS()
rec = Reconstruction_Metrics()

real_path = './datasets/deepfashing/train_lst_256_png'
gt_path = './datasets/deepfashing/test_lst_256_png'


distorated_path = './PCDMs_Results/stage3_256_results'
results_save_path = distorated_path + '_results.txt'  # save path


gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)
print(len(gt_list), len(distorated_list))

fid_score = fid.calculate_from_disk(distorated_path, real_path, img_size=(176, 256))
lpips_score = lpips_obj.calculate_from_disk(distorated_list, gt_list, img_size=(176, 256), sort=False)
rec_score = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, img_size=(176, 256), sort=False, debug=False)

print("FID: " + str(fid_score) + "\nLPIPS: " + str(lpips_score) + "\nSSIM: " + str(rec_score))
with open(results_save_path, 'a') as ff:
    ff.write("\nFID: " + str(fid_score) + "\nLPIPS: " + str(lpips_score) + "\nSSIM: " + str(rec_score))
caculate_metrics_512.py ADDED

@@ -0,0 +1,27 @@

from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task
import torch


device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
fid = FID()
lpips_obj = LPIPS()
rec = Reconstruction_Metrics()

real_path = './datasets/deepfashing/train_lst_512_png'
gt_path = './datasets/deepfashing/test_lst_512_png'


distorated_path = './PCDMs_Results/stage3_512_results'
results_save_path = distorated_path + '_results.txt'  # save path


gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)
print(len(gt_list), len(distorated_list))

fid_score = fid.calculate_from_disk(distorated_path, real_path, img_size=(352, 512))
lpips_score = lpips_obj.calculate_from_disk(distorated_list, gt_list, img_size=(352, 512), sort=False)
rec_score = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, img_size=(352, 512), sort=False, debug=False)

print("FID: " + str(fid_score) + "\nLPIPS: " + str(lpips_score) + "\nSSIM: " + str(rec_score))
with open(results_save_path, 'a') as ff:
    ff.write("\nFID: " + str(fid_score) + "\nLPIPS: " + str(lpips_score) + "\nSSIM: " + str(rec_score))
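The two metric scripts differ only in their paths and image sizes. A minimal merged sketch (the --res flag and variable names are assumptions, not part of this commit; the "deepfashing" spelling is kept because it matches the repository's paths):

import argparse
import torch
from metrics import FID, LPIPS, Reconstruction_Metrics, preprocess_path_for_deform_task

parser = argparse.ArgumentParser()
parser.add_argument("--res", choices=["256", "512"], default="256")
cli = parser.parse_args()

# Resolution-dependent settings collected in one place
img_size = {"256": (176, 256), "512": (352, 512)}[cli.res]
real_path = f"./datasets/deepfashing/train_lst_{cli.res}_png"
gt_path = f"./datasets/deepfashing/test_lst_{cli.res}_png"
distorated_path = f"./PCDMs_Results/stage3_{cli.res}_results"

fid = FID()
lpips_obj = LPIPS()
rec = Reconstruction_Metrics()
gt_list, distorated_list = preprocess_path_for_deform_task(gt_path, distorated_path)

fid_score = fid.calculate_from_disk(distorated_path, real_path, img_size=img_size)
lpips_score = lpips_obj.calculate_from_disk(distorated_list, gt_list, img_size=img_size, sort=False)
rec_score = rec.calculate_from_disk(distorated_list, gt_list, distorated_path, img_size=img_size, sort=False, debug=False)
print("FID:", fid_score, "LPIPS:", lpips_score, "SSIM:", rec_score)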
evaluate.py ADDED

@@ -0,0 +1,186 @@

from main import extract_frames, run

from PIL import Image
import numpy as np
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import torch
import torchvision.transforms as transforms
import lpips
from pytorch_fid.fid_score import calculate_fid_given_paths
import os
import json

# Convert PIL to numpy
def pil_to_np(img):
    return np.array(img).astype(np.float32) / 255.0

# SSIM
def compute_ssim(img1, img2):
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)

    h, w = img1_np.shape[:2]
    min_dim = min(h, w)
    win_size = min(7, min_dim if min_dim % 2 == 1 else min_dim - 1)  # ensure odd

    return ssim(img1_np, img2_np, win_size=win_size, channel_axis=-1, data_range=1.0)

# PSNR
def compute_psnr(img1, img2):
    img1_np = pil_to_np(img1)
    img2_np = pil_to_np(img2)
    return psnr(img1_np, img2_np, data_range=1.0)

# LPIPS
lpips_model = lpips.LPIPS(net='alex')
lpips_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3)
])
def compute_lpips(img1, img2):
    img1_tensor = lpips_transform(img1).unsqueeze(0)
    img2_tensor = lpips_transform(img2).unsqueeze(0)
    return lpips_model(img1_tensor, img2_tensor).item()

# FID: Save images to temp folders for FID calculation
def compute_fid(img1, img2):
    os.makedirs('temp/img1', exist_ok=True)
    os.makedirs('temp/img2', exist_ok=True)
    img1.save('temp/img1/0.png')
    img2.save('temp/img2/0.png')
    fid = calculate_fid_given_paths(['temp/img1', 'temp/img2'], batch_size=1, device='cpu', dims=2048)
    return fid


with open('metrics.json', 'r') as file:
    metrics = json.load(file)

def get_score(item, image_paths, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False):
    print(item)

    images = []
    for path in image_paths:
        img = Image.open(path)
        images.append(img)

    gt_frames = extract_frames(video_path, fps)

    os.makedirs('out/' + item, exist_ok=True)

    for i, frame in enumerate(gt_frames):
        frame.save("out/" + item + "/frame_" + str(i) + ".png")

    # Forward this call's settings to run() rather than hard-coded values
    results = run(images, video_path, train_steps=train_steps, inference_steps=inference_steps, fps=fps, bg_remove=bg_remove, finetune=True)

    for i, result in enumerate(results):
        result.save("out/" + item + "/result_" + str(i) + ".png")

    results_base = run(images, video_path, train_steps=train_steps, inference_steps=inference_steps, fps=fps, bg_remove=bg_remove, finetune=False)

    for i, result in enumerate(results_base):
        result.save("out/" + item + "/base_" + str(i) + ".png")

    """
    img1 = gt_frames[0]
    img2 = Image.open("out/base_0.png")
    print("SSIM:", compute_ssim(img1, img2))
    print("PSNR:", compute_psnr(img1, img2))
    print("LPIPS:", compute_lpips(img1, img2))
    print("FID:", compute_fid(img1, img2))
    """

    ssim = []
    psnr = []
    lpips = []
    fid = []
    ssim2 = []
    psnr2 = []
    lpips2 = []
    fid2 = []
    for gt, result, base in zip(gt_frames, results, results_base):
        ssim.append(float(compute_ssim(gt, result)))
        psnr.append(float(compute_psnr(gt, result)))
        lpips.append(float(compute_lpips(gt, result)))
        fid.append(float(compute_fid(gt, result)))

        ssim2.append(float(compute_ssim(gt, base)))
        psnr2.append(float(compute_psnr(gt, base)))
        lpips2.append(float(compute_lpips(gt, base)))
        fid2.append(float(compute_fid(gt, base)))

    print("SSIM:", sum(ssim) / len(ssim))
    print("PSNR:", sum(psnr) / len(psnr))
    print("LPIPS:", sum(lpips) / len(lpips))
    print("FID:", sum(fid) / len(fid))
    print('baseline:')
    print("SSIM:", sum(ssim2) / len(ssim2))
    print("PSNR:", sum(psnr2) / len(psnr2))
    print("LPIPS:", sum(lpips2) / len(lpips2))
    print("FID:", sum(fid2) / len(fid2))

    metrics[item] = {'ft': {}, 'base': {}}
    metrics[item]['ft']['ssim'] = {'avg': sum(ssim) / len(ssim), 'vals': ssim}
    metrics[item]['ft']['psnr'] = {'avg': sum(psnr) / len(psnr), 'vals': psnr}
    metrics[item]['ft']['lpips'] = {'avg': sum(lpips) / len(lpips), 'vals': lpips}
    metrics[item]['ft']['fid'] = {'avg': sum(fid) / len(fid), 'vals': fid}
    metrics[item]['base']['ssim'] = {'avg': sum(ssim2) / len(ssim2), 'vals': ssim2}
    metrics[item]['base']['psnr'] = {'avg': sum(psnr2) / len(psnr2), 'vals': psnr2}
    metrics[item]['base']['lpips'] = {'avg': sum(lpips2) / len(lpips2), 'vals': lpips2}
    metrics[item]['base']['fid'] = {'avg': sum(fid2) / len(fid2), 'vals': fid2}

    with open('metrics.json', "w", encoding="utf-8") as json_file:
        json.dump(metrics, json_file, ensure_ascii=False, indent=4)


items = ['sidewalk', 'aaa', 'azri', 'dead', 'frankgirl', 'kobold', 'ramona', 'renee', 'walk', 'woody']
for item in items:
    if item in metrics:
        continue
    get_score(item, ['test/' + item + '/1.jpg', 'test/' + item + '/2.jpg', 'test/' + item + '/3.jpg'], 'test/' + item + '/v.mp4')


ssim = []
psnr = []
lpips = []
fid = []
ssim2 = []
psnr2 = []
lpips2 = []
fid2 = []
for item in metrics.keys():
    ssim.append(metrics[item]['ft']['ssim']['avg'])
    psnr.append(metrics[item]['ft']['psnr']['avg'])
    lpips.append(metrics[item]['ft']['lpips']['avg'])
    fid.append(metrics[item]['ft']['fid']['avg'])

    ssim2.append(metrics[item]['base']['ssim']['avg'])
    psnr2.append(metrics[item]['base']['psnr']['avg'])
    lpips2.append(metrics[item]['base']['lpips']['avg'])
    fid2.append(metrics[item]['base']['fid']['avg'])

    print(item)
    print("SSIM:", metrics[item]['ft']['ssim']['avg'], metrics[item]['base']['ssim']['avg'])
    print("PSNR:", metrics[item]['ft']['psnr']['avg'], metrics[item]['base']['psnr']['avg'])
    print("LPIPS:", metrics[item]['ft']['lpips']['avg'], metrics[item]['base']['lpips']['avg'])
    print("FID:", metrics[item]['ft']['fid']['avg'], metrics[item]['base']['fid']['avg'])

print('Results:')
print("SSIM:", sum(ssim) / len(ssim))
print("PSNR:", sum(psnr) / len(psnr))
print("LPIPS:", sum(lpips) / len(lpips))
print("FID:", sum(fid) / len(fid))
print('baseline:')
print("SSIM:", sum(ssim2) / len(ssim2))
print("PSNR:", sum(psnr2) / len(psnr2))
print("LPIPS:", sum(lpips2) / len(lpips2))
print("FID:", sum(fid2) / len(fid2))
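One caveat worth noting about the script above: FID compares feature distributions, so compute_fid over a single image pair is extremely noisy. A sketch of the more standard usage, one FID over whole frame folders (the folder names here are hypothetical, not produced by this commit):

from pytorch_fid.fid_score import calculate_fid_given_paths

# Hypothetical folders: one holding every ground-truth frame of a clip, one
# holding every generated frame. A single FID over the full sets is far more
# stable than averaging per-frame pair scores.
fid_all = calculate_fid_given_paths(
    ['out/walk_gt', 'out/walk_generated'],
    batch_size=8, device='cpu', dims=2048)
print("FID (whole clip):", fid_all)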
inception.py ADDED

@@ -0,0 +1,138 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models


class InceptionV3(nn.Module):
    """Pretrained InceptionV3 network returning feature maps"""

    # Index of default block of inception to return,
    # corresponds to output of final average pooling
    DEFAULT_BLOCK_INDEX = 3

    # Maps feature dimensionality to their output blocks indices
    BLOCK_INDEX_BY_DIM = {
        64: 0,    # First max pooling features
        192: 1,   # Second max pooling features
        768: 2,   # Pre-aux classifier features
        2048: 3   # Final average pooling features
    }

    def __init__(self,
                 output_blocks=[DEFAULT_BLOCK_INDEX],
                 resize_input=True,
                 normalize_input=True,
                 requires_grad=False):
        """Build pretrained InceptionV3
        Parameters
        ----------
        output_blocks : list of int
            Indices of blocks to return features of. Possible values are:
                - 0: corresponds to output of first max pooling
                - 1: corresponds to output of second max pooling
                - 2: corresponds to output which is fed to aux classifier
                - 3: corresponds to output of final average pooling
        resize_input : bool
            If true, bilinearly resizes input to width and height 299 before
            feeding input to model. As the network without fully connected
            layers is fully convolutional, it should be able to handle inputs
            of arbitrary size, so resizing might not be strictly needed
        normalize_input : bool
            If true, normalizes the input to the statistics the pretrained
            Inception network expects
        requires_grad : bool
            If true, parameters of the model require gradient. Possibly useful
            for finetuning the network
        """
        super(InceptionV3, self).__init__()

        self.resize_input = resize_input
        self.normalize_input = normalize_input
        self.output_blocks = sorted(output_blocks)
        self.last_needed_block = max(output_blocks)

        assert self.last_needed_block <= 3, \
            'Last possible output block index is 3'

        self.blocks = nn.ModuleList()

        inception = models.inception_v3(pretrained=True)

        # Block 0: input to maxpool1
        block0 = [
            inception.Conv2d_1a_3x3,
            inception.Conv2d_2a_3x3,
            inception.Conv2d_2b_3x3,
            nn.MaxPool2d(kernel_size=3, stride=2)
        ]
        self.blocks.append(nn.Sequential(*block0))

        # Block 1: maxpool1 to maxpool2
        if self.last_needed_block >= 1:
            block1 = [
                inception.Conv2d_3b_1x1,
                inception.Conv2d_4a_3x3,
                nn.MaxPool2d(kernel_size=3, stride=2)
            ]
            self.blocks.append(nn.Sequential(*block1))

        # Block 2: maxpool2 to aux classifier
        if self.last_needed_block >= 2:
            block2 = [
                inception.Mixed_5b,
                inception.Mixed_5c,
                inception.Mixed_5d,
                inception.Mixed_6a,
                inception.Mixed_6b,
                inception.Mixed_6c,
                inception.Mixed_6d,
                inception.Mixed_6e,
            ]
            self.blocks.append(nn.Sequential(*block2))

        # Block 3: aux classifier to final avgpool
        if self.last_needed_block >= 3:
            block3 = [
                inception.Mixed_7a,
                inception.Mixed_7b,
                inception.Mixed_7c,
                nn.AdaptiveAvgPool2d(output_size=(1, 1))
            ]
            self.blocks.append(nn.Sequential(*block3))

        for param in self.parameters():
            param.requires_grad = requires_grad

    def forward(self, inp):
        """Get Inception feature maps
        Parameters
        ----------
        inp : torch.Tensor
            Input tensor of shape Bx3xHxW. Values are expected to be in
            range (0, 1)
        Returns
        -------
        List of torch.Tensor, corresponding to the selected output
        blocks, sorted ascending by index
        """
        outp = []
        x = inp

        if self.resize_input:
            # F.upsample is deprecated; F.interpolate is the current equivalent
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        if self.normalize_input:
            # Remap inputs in (0, 1) scaled by (0.5, 0.5) stats onto the
            # ImageNet statistics the pretrained weights expect
            x = x.clone()
            x[:, 0] = x[:, 0] * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x[:, 1] = x[:, 1] * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x[:, 2] = x[:, 2] * (0.225 / 0.5) + (0.406 - 0.5) / 0.5

        for idx, block in enumerate(self.blocks):
            x = block(x)
            if idx in self.output_blocks:
                outp.append(x)

            if idx == self.last_needed_block:
                break

        return outp
main.py ADDED

@@ -0,0 +1,1097 @@

import logging
import math
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from diffusers.models.controlnet import ControlNetConditioningEmbedding
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed

from tqdm.auto import tqdm
from src.configs.stage2_config import args

import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from src.dataset.stage2_dataset import InpaintDataset, InpaintCollate_fn
from transformers import CLIPVisionModelWithProjection
from transformers import Dinov2Model
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel


import glob
import os
import torch
from torch import nn
from PIL import Image, ImageOps
import numpy as np
from diffusers import UniPCMultistepScheduler
from src.models.stage2_inpaint_unet_2d_condition import Stage2_InapintUNet2DConditionModel

from torchvision import transforms
from diffusers.models.controlnet import ControlNetConditioningEmbedding
from transformers import CLIPImageProcessor
from transformers import Dinov2Model
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, ControlNetModel, DDIMScheduler
from src.pipelines.PCDMs_pipeline import PCDMsPipeline
# from single_extract_pose import inference_pose


import spaces
from easy_dwpose import DWposeDetector
from PIL import Image
import cv2
import os
import gradio as gr
import rembg
import uuid
import gc
from numba import cuda

from huggingface_hub import hf_hub_download


# Inputs ===================================================================================================

input_img = "sm.png"
train_imgs = ["target.png"]
in_vid = "walk.mp4"
out_vid = 'out.mp4'

"""
train_steps = 100
inference_steps = 10
fps = 12
"""

debug = False
save_model = True
max_batch_size = 8

# Pose detection ==============================================================================================

def load_models():
    dwpose = DWposeDetector(device="cpu")
    rembg_session = rembg.new_session("u2netp")

    pcdms_model = hf_hub_download(repo_id="acmyu/PCDMs", filename="pcdms_ckpt.pt")

    # Load scheduler
    noise_scheduler = DDPMScheduler.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="scheduler")

    # Load model
    image_encoder_p = Dinov2Model.from_pretrained('facebook/dinov2-giant')
    image_encoder_g = CLIPVisionModelWithProjection.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')  # ("openai/clip-vit-base-patch32")

    vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="vae")
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base",
        torch_dtype=torch.float16,
        subfolder="unet",
        in_channels=9,
        low_cpu_mem_usage=False,
        ignore_mismatched_sizes=True)

    return dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet


# load_models()


def resize_and_pad(img, target_img):
    # Fit img inside target_img's frame while preserving aspect ratio,
    # then letterbox with black to exactly target_img's size
    tw, th = target_img.size
    w, h = img.size

    if tw / th > w / h:
        tw = int(th * w / h)
    elif tw / th < w / h:
        th = int(tw * h / w)

    img = img.resize((tw, th), Image.BICUBIC)

    tw, th = target_img.size
    new_img = Image.new("RGB", (tw, th), (0, 0, 0))
    left = (tw - img.width) // 2
    top = (th - img.height) // 2
    new_img.paste(img, (left, top))

    return new_img


def remove_zero_pad(image):
    image = np.array(image)
    dummy = np.argwhere(image != 0)  # assume background is zero
    max_y = dummy[:, 0].max()
    min_y = dummy[:, 0].min()
    min_x = dummy[:, 1].min()
    max_x = dummy[:, 1].max()
    crop_image = image[min_y:max_y, min_x:max_x]

    return Image.fromarray(crop_image)


def get_pose(img, dwpose, outfile, crop=False):
    # pil_image = Image.open("imgs/"+img).convert("RGB")
    # skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)

    # img.thumbnail((512, 512))
    out_img = dwpose(img, include_hands=True, include_face=False)

    # print(pose['bodies'])

    if crop:
        bbox = out_img.getbbox()
        out_img = out_img.crop(bbox)
        out_img = ImageOps.expand(out_img, border=int(out_img.width * 0.2), fill=(0, 0, 0))

    return out_img


def extract_frames(video_path, fps):
    # Resample the clip to the requested output rate: grab every source frame,
    # but only retrieve (decode) the ones whose output index has come due
    video_capture = cv2.VideoCapture(video_path)
    frame_count = 0
    frames = []

    fps_in = video_capture.get(cv2.CAP_PROP_FPS)
    fps_out = fps

    index_in = -1
    index_out = -1

    while True:
        success = video_capture.grab()
        if not success:
            break
        index_in += 1

        out_due = int(index_in / fps_in * fps_out)
        if out_due > index_out:
            success, frame = video_capture.retrieve()
            if not success:
                break
            index_out += 1

            frame_count += 1
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))

    video_capture.release()
    print(f"Extracted {frame_count} frames")
    return frames
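A quick sanity check of the resampling arithmetic above (a sketch under the assumption of a 30 fps source clip downsampled to 12 fps):

# Sketch: which source indices survive resampling from 30 fps to 12 fps
fps_in, fps_out = 30.0, 12
kept, index_out = [], -1
for index_in in range(30):
    out_due = int(index_in / fps_in * fps_out)
    if out_due > index_out:
        index_out += 1
        kept.append(index_in)
# kept == [0, 3, 5, 8, 10, 13, 15, 18, 20, 23, 25, 28] -> 12 of 30 frames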
+
|
| 191 |
+
|
| 192 |
+
def removebg(img, rembg_session):
|
| 193 |
+
result = Image.new("RGB", img.size, "#ffffff")
|
| 194 |
+
out = rembg.remove(img, session=rembg_session)
|
| 195 |
+
result.paste(out, mask=out)
|
| 196 |
+
return result
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def prepare_inputs_train(images, bg_remove, dwpose, rembg_session):
|
| 200 |
+
if bg_remove:
|
| 201 |
+
images = [removebg(img, rembg_session) for img in images]
|
| 202 |
+
|
| 203 |
+
in_img = images[0]
|
| 204 |
+
in_pose = get_pose(in_img, dwpose, "in_pose.png")
|
| 205 |
+
train_poses = []
|
| 206 |
+
train_imgs = [resize_and_pad(img, in_img) for img in images[1:]]
|
| 207 |
+
|
| 208 |
+
for i, img in enumerate(train_imgs):
|
| 209 |
+
train_poses.append(get_pose(img, dwpose, "tr_pose"+str(i)+".png"))
|
| 210 |
+
|
| 211 |
+
return in_img, in_pose, train_imgs, train_poses
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize='target', is_app=False):
|
| 215 |
+
progress=gr.Progress(track_tqdm=True)
|
| 216 |
+
|
| 217 |
+
print("prepare_inputs_inference")
|
| 218 |
+
|
| 219 |
+
in_pose = get_pose(in_img, dwpose, "in_pose.png")
|
| 220 |
+
|
| 221 |
+
frames = extract_frames(in_vid, fps)
|
| 222 |
+
#frames = [removebg(img, rembg_session) for img in frames]
|
| 223 |
+
if debug:
|
| 224 |
+
for i, frame in enumerate(frames):
|
| 225 |
+
frame.save("out/frame_"+str(i)+".png")
|
| 226 |
+
|
| 227 |
+
print("vid: ", in_vid, fps)
|
| 228 |
+
|
| 229 |
+
progress_bar = tqdm(range(len(frames)), initial=0, desc="Frames")
|
| 230 |
+
target_poses = []
|
| 231 |
+
max_left = max_top = 999999
|
| 232 |
+
max_right = max_bottom = 0
|
| 233 |
+
it = frames
|
| 234 |
+
if is_app:
|
| 235 |
+
it = progress.tqdm(frames, desc="Pose Detection")
|
| 236 |
+
for f in it:
|
| 237 |
+
tpose = get_pose(f, dwpose, "tar_pose"+str(len(target_poses))+".png")
|
| 238 |
+
target_poses.append(tpose)
|
| 239 |
+
progress_bar.update(1)
|
| 240 |
+
|
| 241 |
+
bbox = tpose.getbbox()
|
| 242 |
+
left, top, right, bottom = bbox
|
| 243 |
+
max_left = min(max_left, left)
|
| 244 |
+
max_top = min(max_top, top)
|
| 245 |
+
max_right = max(max_right, right)
|
| 246 |
+
max_bottom = max(max_bottom, bottom)
|
| 247 |
+
|
| 248 |
+
target_poses_cropped = []
|
| 249 |
+
for tpose in target_poses:
|
| 250 |
+
if resize=='target':
|
| 251 |
+
tpose = tpose.crop((max_left, max_top, max_right, max_bottom))
|
| 252 |
+
tpose = ImageOps.expand(tpose, border=int(tpose.width*0.2), fill=(0,0,0))
|
| 253 |
+
|
| 254 |
+
tpose = resize_and_pad(tpose, in_img)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
if debug:
|
| 258 |
+
tpose.save("out/"+"tar_pose"+str(len(target_poses_cropped))+".png")
|
| 259 |
+
target_poses_cropped.append(tpose)
|
| 260 |
+
|
| 261 |
+
return target_poses_cropped, in_pose
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
def prepare_inputs(images, in_vid, fps, bg_remove, dwpose, rembg_session, resize='target', is_app=False):
|
| 265 |
+
|
| 266 |
+
in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)
|
| 267 |
+
|
| 268 |
+
target_poses_cropped, _ = prepare_inputs_inference(in_img, in_vid, fps, dwpose, resize, is_app)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
return in_img, in_pose, train_imgs, train_poses, target_poses_cropped
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# Training ===================================================================================================
|
| 275 |
+
|
| 276 |
+
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
| 277 |
+
check_min_version("0.18.0.dev0")
|
| 278 |
+
|
| 279 |
+
logger = get_logger(__name__)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class ImageProjModel_p(torch.nn.Module):
|
| 283 |
+
"""SD model with image prompt"""
|
| 284 |
+
|
| 285 |
+
def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
|
| 286 |
+
super().__init__()
|
| 287 |
+
|
| 288 |
+
self.net = nn.Sequential(
|
| 289 |
+
nn.Linear(in_dim, hidden_dim),
|
| 290 |
+
nn.GELU(),
|
| 291 |
+
nn.Dropout(dropout),
|
| 292 |
+
nn.LayerNorm(hidden_dim),
|
| 293 |
+
nn.Linear(hidden_dim, out_dim),
|
| 294 |
+
nn.Dropout(dropout)
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def forward(self, x):
|
| 298 |
+
return self.net(x)
|
| 299 |
+
|
| 300 |
+
class ImageProjModel_g(torch.nn.Module):
|
| 301 |
+
"""SD model with image prompt"""
|
| 302 |
+
|
| 303 |
+
def __init__(self, in_dim, hidden_dim, out_dim, dropout = 0.):
|
| 304 |
+
super().__init__()
|
| 305 |
+
|
| 306 |
+
self.net = nn.Sequential(
|
| 307 |
+
nn.Linear(in_dim, hidden_dim),
|
| 308 |
+
nn.GELU(),
|
| 309 |
+
nn.Dropout(dropout),
|
| 310 |
+
nn.LayerNorm(hidden_dim),
|
| 311 |
+
nn.Linear(hidden_dim, out_dim),
|
| 312 |
+
nn.Dropout(dropout)
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
def forward(self, x): # b, 257,1280
|
| 316 |
+
return self.net(x)
|
| 317 |
+
|
| 318 |
+
|


class SDModel(torch.nn.Module):
    """SD inpainting UNet conditioned on an image prompt and a pose map."""

    def __init__(self, unet) -> None:
        super().__init__()
        self.image_proj_model_p = ImageProjModel_p(in_dim=1536, hidden_dim=768, out_dim=1024)
        self.unet = unet
        self.pose_proj = ControlNetConditioningEmbedding(
            conditioning_embedding_channels=320,
            block_out_channels=(16, 32, 96, 256),
            conditioning_channels=3)

    def forward(self, noisy_latents, timesteps, simg_f_p, timg_f_g, pose_f):
        extra_image_embeddings_p = self.image_proj_model_p(simg_f_p)
        extra_image_embeddings_g = timg_f_g

        print(extra_image_embeddings_p.size())
        print(extra_image_embeddings_g.size())

        encoder_image_hidden_states = torch.cat([extra_image_embeddings_p, extra_image_embeddings_g], dim=1)
        pose_cond = self.pose_proj(pose_f)

        pred_noise = self.unet(noisy_latents, timesteps, class_labels=timg_f_g,
                               encoder_hidden_states=encoder_image_hidden_states,
                               my_pose_cond=pose_cond).sample
        return pred_noise
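
# A minimal shape sketch of the conditioning concat above, assuming DINOv2-giant
# patch features (257 tokens of width 1536, matching the encoder used further
# below) and a global embedding whose width equals the projected patch width
# (1024 here); the tensors are illustrative placeholders, not pipeline values:
#
#   simg_f_p = torch.randn(2, 257, 1536)   # image_encoder_p(...).last_hidden_state
#   timg_f_g = torch.randn(2, 1, 1024)     # image_encoder_g(...).image_embeds.unsqueeze(1)
#   ctx = torch.cat([ImageProjModel_p(1536, 768, 1024)(simg_f_p), timg_f_g], dim=1)
#   assert ctx.shape == (2, 258, 1024)     # [patch tokens ; global token] for cross-attention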


def load_training_checkpoint(model, pcdms_model, tag=None, **kwargs):
    #model_sd = torch.load(load_dir, map_location="cpu")["module"]
    model_sd = torch.load(
        pcdms_model,
        map_location="cpu"
    )["module"]

    image_proj_model_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model_p"):
            image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
        elif k.startswith("image_proj_model."):
            image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)

    model.pose_proj.load_state_dict(pose_proj_dict)
    model.image_proj_model_p.load_state_dict(image_proj_model_dict)
    model.unet.load_state_dict(unet_dict)

    return model, 0, 0


def checkpoint_model(checkpoint_folder, ckpt_id, model, epoch, last_global_step, **kwargs):
    """Utility function for checkpointing model + optimizer dictionaries.
    The main purpose of this is to be able to resume training from that instant again.
    """
    checkpoint_state_dict = {
        "epoch": epoch,
        "last_global_step": last_global_step,
    }
    # Add extra kwargs too
    checkpoint_state_dict.update(kwargs)

    success = model.save_checkpoint(checkpoint_folder, ckpt_id, checkpoint_state_dict)
    status_msg = f"checkpointing: checkpoint_folder={checkpoint_folder}, ckpt_id={ckpt_id}"
    if success:
        logging.info(f"Success {status_msg}")
    else:
        logging.warning(f"Failure {status_msg}")
    return
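
# Note: save_checkpoint(...) above is the DeepSpeed engine API; this helper
# assumes `model` has been wrapped by DeepSpeed, since a plain nn.Module
# exposes no such method.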


@spaces.GPU(duration=600)
def train(modelId, in_image, in_pose, train_images, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune=True, is_app=False):
    logging_dir = 'outputs/logging'
    print('start train')

    progress = gr.Progress(track_tqdm=True)

    accelerator = Accelerator(
        log_with=args.report_to,
        project_dir=logging_dir,
        mixed_precision=args.mixed_precision,
        gradient_accumulation_steps=args.gradient_accumulation_steps
    )

    # Make one log on every process with the configuration for debugging.
    #logging.basicConfig(
    #    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    #    datefmt="%m/%d/%Y %H:%M:%S",
    #    level=logging.INFO, )

    print(accelerator.state)
    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    set_seed(42)

    # Handle the repository creation
    if accelerator.is_main_process:
        os.makedirs('outputs', exist_ok=True)

    """
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="unet",
        in_channels=9, class_embed_type="projection", projection_class_embeddings_input_dim=1024,
        low_cpu_mem_usage=False, ignore_mismatched_sizes=True)
    """
    image_encoder_p.requires_grad_(False)
    image_encoder_g.requires_grad_(False)
    vae.requires_grad_(False)

    sd_model = SDModel(unet=unet)
    sd_model.train()

    if args.gradient_checkpointing:
        sd_model.enable_gradient_checkpointing()

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    learning_rate = 1e-4
    train_batch_size = min(len(train_images), max_batch_size)  #len(train_images) % 16

    # Optimizer creation
    params_to_optimize = sd_model.parameters()
    optimizer = torch.optim.AdamW(
        params_to_optimize,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    inputs = [{
        "source_image": in_image,
        "source_pose": in_pose,
        "target_image": timg,
        "target_pose": tpose,
    } for timg, tpose in zip(train_images, train_poses)]
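    # Every pair shares the same source image and pose; only the target image and
    # pose vary, so the model is fine-tuned few-shot on the user's character images.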

    """
    inputs = {[
        "source_image": Image.open('imgs/sm.png'),
        "source_pose": Image.open('imgs/sm_pose.jpg'),
        "target_image": Image.open('imgs/target.png'),
        "target_pose": Image.open('imgs/target_pose.jpg'),
    ]}
    """

    #print(inputs)

    dataset = InpaintDataset(
        inputs,
        'imgs/',
        size=(args.img_width, args.img_height),  # w h
        imgp_drop_rate=0.1,
        imgg_drop_rate=0.1,
    )

    """
    dataset = InpaintDataset(
        args.json_path,
        args.image_root_path,
        size=(args.img_width, args.img_height),  # w h
        imgp_drop_rate=0.1,
        imgg_drop_rate=0.1,
    )
    """

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=accelerator.num_processes, rank=accelerator.process_index, shuffle=True)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        sampler=train_sampler,
        collate_fn=InpaintCollate_fn,
        batch_size=train_batch_size,
        num_workers=0,)

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    args.max_train_steps = train_steps

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    sd_model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(sd_model, optimizer, train_dataloader, lr_scheduler)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    """
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    """

    # Move vae, unet and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    sd_model.unet.to(accelerator.device, dtype=weight_dtype)
    image_encoder_p.to(accelerator.device, dtype=weight_dtype)
    image_encoder_g.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    args.num_train_epochs = train_steps

    # Train!
    total_batch_size = (
        train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    print("***** Running training *****")
    print(f"  Num batches each epoch = {len(train_dataloader)}")
    print(f"  Num Epochs = {args.num_train_epochs}")
    print(f"  Instantaneous batch size per device = {train_batch_size}")
    print(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    print(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    print(f"  Total optimization steps = {args.max_train_steps}")

    if args.resume_from_checkpoint:
        # New Code #
        # Loads the DeepSpeed checkpoint from the specified path
        prior_model, last_epoch, last_global_step = load_training_checkpoint(
            sd_model,
            pcdms_model,
            **{"load_optimizer_states": True, "load_lr_scheduler_states": True},
        )
        print(f"Resumed from checkpoint: {args.resume_from_checkpoint}, global step: {last_global_step}")
        starting_epoch = last_epoch
        global_steps = last_global_step
    else:
        global_steps = 0
        starting_epoch = 0

    progress_bar = tqdm(range(global_steps, args.max_train_steps), initial=global_steps, desc="Steps",
                        # Only show the progress bar once on each machine.
                        disable=not accelerator.is_local_main_process, )

    bsz = train_batch_size

    if not finetune or train_steps == 0:
        accelerator.wait_for_everyone()
        accelerator.end_training()
        return {k: v.cpu() for k, v in sd_model.state_dict().items()}

    it = range(starting_epoch, args.num_train_epochs)
    if is_app:
        it = progress.tqdm(it, desc="Fine-tuning")
    for epoch in it:
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(sd_model):
                with torch.no_grad():
                    # Convert images to latent space
                    latents = vae.encode(batch["source_target_image"].to(dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    # Get the masked image latents
                    masked_latents = vae.encode(batch["vae_source_mask_image"].to(dtype=weight_dtype)).latent_dist.sample()
                    masked_latents = masked_latents * vae.config.scaling_factor

                    bsz = batch["target_image"].size(dim=0)

                    # mask
                    mask1 = torch.ones((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
                    mask0 = torch.zeros((bsz, 1, int(args.img_height / 8), int(args.img_width / 8))).to(accelerator.device, dtype=weight_dtype)
                    mask = torch.cat([mask1, mask0], dim=3)
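                    # The latents come from a side-by-side [source | target] canvas
                    # ("source_target_image"), so the mask just flags the two halves
                    # along the latent width axis (dim=3): ones over the source half,
                    # zeros over the target half the model must synthesize.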
+
# Get the image embedding for conditioning
|
| 636 |
+
cond_image_feature_p = image_encoder_p(batch["source_image"].to(accelerator.device, dtype=weight_dtype))
|
| 637 |
+
cond_image_feature_p = (cond_image_feature_p.last_hidden_state)
|
| 638 |
+
|
| 639 |
+
|
| 640 |
+
cond_image_feature_g = image_encoder_g(batch["target_image"].to(accelerator.device, dtype=weight_dtype), ).image_embeds
|
| 641 |
+
cond_image_feature_g =cond_image_feature_g.unsqueeze(1)
|
| 642 |
+
|
| 643 |
+
# Sample noise that we'll add to the latents
|
| 644 |
+
noise = torch.randn_like(latents)
|
| 645 |
+
if args.noise_offset:
|
| 646 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
| 647 |
+
noise += args.noise_offset * torch.randn(
|
| 648 |
+
(latents.shape[0], latents.shape[1], 1, 1), device=latents.device
|
| 649 |
+
)
|
| 650 |
+
|
| 651 |
+
# Sample a random timestep for each image
|
| 652 |
+
#timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (train_batch_size,),device=latents.device, )
|
| 653 |
+
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),device=latents.device, )
|
| 654 |
+
timesteps = timesteps.long()
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
# Add noise to the latents according to the noise magnitude at each timestep (this is the forward diffusion process)
|
| 659 |
+
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
| 660 |
+
|
| 661 |
+
#print(noisy_latents.size(), mask.size(), masked_latents.size())
|
| 662 |
+
|
| 663 |
+
noisy_latents = torch.cat([noisy_latents, mask, masked_latents], dim=1)
|
| 664 |
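                    # 4 noisy latent channels + 1 mask channel + 4 masked-latent channels
                    # = 9 input channels, matching the in_channels=9 of the inpainting
                    # UNet shown in the commented-out constructor above.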
+
# Get the text embedding for conditioning
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
cond_pose = batch["source_target_pose"].to(dtype=weight_dtype)
|
| 668 |
+
|
| 669 |
+
#print(noisy_latents.size())
|
| 670 |
+
#print(cond_image_feature_p.size())
|
| 671 |
+
#print(cond_image_feature_g.size())
|
| 672 |
+
#print(cond_pose.size())
|
| 673 |
+
|
| 674 |
+
# Predict the noise residual
|
| 675 |
+
model_pred = sd_model(noisy_latents, timesteps, cond_image_feature_p,cond_image_feature_g, cond_pose, )
|
| 676 |
+
|
| 677 |
+
# Get the target for loss depending on the prediction type
|
| 678 |
+
if noise_scheduler.config.prediction_type == "epsilon":
|
| 679 |
+
target = noise
|
| 680 |
+
elif noise_scheduler.config.prediction_type == "v_prediction":
|
| 681 |
+
target = noise_scheduler.get_velocity(latents, noise, timesteps)
|
| 682 |
+
else:
|
| 683 |
+
raise ValueError(
|
| 684 |
+
f"Unknown prediction type {noise_scheduler.config.prediction_type}"
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
| 688 |
+
|
| 689 |
+
accelerator.backward(loss)
|
| 690 |
+
if accelerator.sync_gradients:
|
| 691 |
+
params_to_clip = sd_model.parameters()
|
| 692 |
+
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
| 693 |
+
optimizer.step()
|
| 694 |
+
lr_scheduler.step()
|
| 695 |
+
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
|
| 696 |
+
|
| 697 |
+
# Checks if the accelerator has performed an optimization step behind the scenes
|
| 698 |
+
if accelerator.sync_gradients:
|
| 699 |
+
global_steps += 1
|
| 700 |
+
|
| 701 |
+
if global_steps >= args.max_train_steps:
|
| 702 |
+
break
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
| 706 |
+
print(logs)
|
| 707 |
+
progress_bar.set_postfix(**logs)
|
| 708 |
+
|
| 709 |
+
progress_bar.update(1)
|
| 710 |
+
|
| 711 |
+
# Create the pipeline using the trained modules and save it.
|
| 712 |
+
accelerator.wait_for_everyone()
|
| 713 |
+
accelerator.end_training()
|
| 714 |
+
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
if save_model: #if global_steps % args.checkpointing_steps == 0 or global_steps == args.max_train_steps:
|
| 718 |
+
print('saving', modelId)
|
| 719 |
+
|
| 720 |
+
checkpoint_state_dict = {
|
| 721 |
+
"epoch": 0,
|
| 722 |
+
"module": {k: v.cpu() for k, v in sd_model.state_dict().items()}, #sd_model.state_dict(),
|
| 723 |
+
}
|
| 724 |
+
print(list(sd_model.state_dict().keys())[:20])
|
| 725 |
+
torch.save(checkpoint_state_dict, modelId+".pt")
|
| 726 |
+
|
| 727 |
+
gc.collect()
|
| 728 |
+
torch.cuda.empty_cache()
|
| 729 |
+
#device = cuda.get_current_device()
|
| 730 |
+
#device.reset()
|
| 731 |
+
print('done train')
|
| 732 |
+
return
|
| 733 |
+
|
| 734 |
+
gc.collect()
|
| 735 |
+
torch.cuda.empty_cache()
|
| 736 |
+
return {k: v.cpu() for k, v in sd_model.state_dict().items()}
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
|


# Pose-transfer ===================================================================================================


device = "cuda"


class ImageProjModel(torch.nn.Module):
    """Projects image features into the UNet cross-attention width (inference-time copy)."""

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, out_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)


def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    print(w, h)
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid


def load_mydict(modelId, finetuned_model):
    if save_model:
        model_ckpt_path = modelId + '.pt'
        model_sd = torch.load(model_ckpt_path, map_location="cpu")["module"]
    else:
        model_sd = finetuned_model  #torch.load(model_ckpt_path, map_location="cpu")["module"]

    image_proj_model_dict = {}
    pose_proj_dict = {}
    unet_dict = {}
    for k in model_sd.keys():
        if k.startswith("pose_proj"):
            pose_proj_dict[k.replace("pose_proj.", "")] = model_sd[k]
        elif k.startswith("image_proj_model_p"):
            image_proj_model_dict[k.replace("image_proj_model_p.", "")] = model_sd[k]
        elif k.startswith("image_proj_model"):
            image_proj_model_dict[k.replace("image_proj_model.", "")] = model_sd[k]
        elif k.startswith("unet"):
            unet_dict[k.replace("unet.", "")] = model_sd[k]
        else:
            print(k)
    return image_proj_model_dict, pose_proj_dict, unet_dict


@spaces.GPU(duration=600)
def inference(modelId, in_image, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder, is_app=False):
    print('start inference')
    progress = gr.Progress(track_tqdm=True)

    if not save_model:
        finetuned_model = {k: v.cuda() for k, v in finetuned_model.items()}

    device = "cuda"
    pretrained_model_name_or_path = "stabilityai/stable-diffusion-2-1-base"
    image_encoder_path = "facebook/dinov2-giant"
    #model_ckpt_path = "./pcdms_ckpt.pt"  # ckpt path
    model_ckpt_path = modelId + '.pt'

    clip_image_processor = CLIPImageProcessor()
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    generator = torch.Generator(device=device).manual_seed(42)

    """
    unet = Stage2_InapintUNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, torch_dtype=torch.float16, subfolder="unet", in_channels=9, low_cpu_mem_usage=False, ignore_mismatched_sizes=True).to(device)
    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae").to(device, dtype=torch.float16)
    image_encoder = Dinov2Model.from_pretrained(image_encoder_path).to(device, dtype=torch.float16)
    """
    noise_scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        clip_sample=False,
        set_alpha_to_one=False,
        steps_offset=1,
    )

    unet = unet.to(device, dtype=torch.float16)
    vae = vae.to(device, dtype=torch.float16)
    image_encoder = image_encoder.to(device, dtype=torch.float16)

    image_proj_model = ImageProjModel(in_dim=1536, hidden_dim=768, out_dim=1024).to(device).to(dtype=torch.float16)
    pose_proj_model = ControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        block_out_channels=(16, 32, 96, 256),
        conditioning_channels=3).to(device).to(dtype=torch.float16)

    # load weight
    print('loading', modelId)
    image_proj_model_dict, pose_proj_dict, unet_dict = load_mydict(modelId, finetuned_model)
    print('loaded', modelId)
    image_proj_model.load_state_dict(image_proj_model_dict)
    pose_proj_model.load_state_dict(pose_proj_dict)
    unet.load_state_dict(unet_dict)

    pipe = PCDMsPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base", unet=unet, torch_dtype=torch.float16, scheduler=noise_scheduler, feature_extractor=None, safety_checker=None).to(device)

    print('====================== model load finish ===================')

    results = []
    progress_bar = tqdm(range(len(target_poses)), initial=0, desc="Frames")

    it = target_poses
    if is_app:
        it = progress.tqdm(it, desc="Pose Transfer")
    for pose in it:

        num_samples = 1
        image_size = (512, 512)
        #s_img_path = 'imgs/' + input_img  # input image 1 (unused; input_img is not defined in this scope)
        #target_pose_img = 'imgs/pose_' + str(n) + '.png'  # input image 2

        #t_pose = inference_pose(target_pose_img, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
        #t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)
        t_pose = pose.convert("RGB").resize(image_size, Image.BICUBIC)
        #t_pose = resize_and_pad(pose.convert("RGB"))

        #s_img = Image.open(s_img_path)
        width_orig, height_orig = in_image.size
        s_img = in_image.convert("RGB").resize(image_size, Image.BICUBIC)
        #s_img = resize_and_pad(in_image.convert("RGB"))
        black_image = Image.new("RGB", s_img.size, (0, 0, 0)).resize(image_size, Image.BICUBIC)

        s_img_t_mask = Image.new("RGB", (s_img.width * 2, s_img.height))
        s_img_t_mask.paste(s_img, (0, 0))
        s_img_t_mask.paste(black_image, (s_img.width, 0))

        #s_pose = inference_pose(s_img_path, image_size=(image_size[1], image_size[0])).resize(image_size, Image.BICUBIC)
        #s_pose = Image.open('imgs/sm_pose.jpg').convert("RGB").resize(image_size, Image.BICUBIC)
        s_pose = in_pose.convert("RGB").resize(image_size, Image.BICUBIC)
        #s_pose = resize_and_pad(in_pose.convert("RGB"))
        print('source pose width: {}, height: {}'.format(s_pose.width, s_pose.height))
        #t_pose = Image.open(target_pose_img).convert("RGB").resize(image_size, Image.BICUBIC)

        st_pose = Image.new("RGB", (s_pose.width * 2, s_pose.height))
        st_pose.paste(s_pose, (0, 0))
        st_pose.paste(t_pose, (s_pose.width, 0))
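        # Inference mirrors the training layout: an image canvas of [source | black]
        # whose right half is inpainted, and a pose canvas of [source pose | target pose];
        # the generated right half is cropped out of the result below.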

        clip_s_img = clip_image_processor(images=s_img, return_tensors="pt").pixel_values
        vae_image = torch.unsqueeze(img_transform(s_img_t_mask), 0)
        cond_st_pose = torch.unsqueeze(img_transform(st_pose), 0)

        mask1 = torch.ones((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
        mask0 = torch.zeros((1, 1, int(image_size[0] / 8), int(image_size[1] / 8))).to(device, dtype=torch.float16)
        mask = torch.cat([mask1, mask0], dim=3)

        with torch.inference_mode():
            cond_pose = pose_proj_model(cond_st_pose.to(dtype=torch.float16, device=device))
            simg_mask_latents = pipe.vae.encode(vae_image.to(device, dtype=torch.float16)).latent_dist.sample()
            simg_mask_latents = simg_mask_latents * 0.18215

            images_embeds = image_encoder(clip_s_img.to(device, dtype=torch.float16)).last_hidden_state
            image_prompt_embeds = image_proj_model(images_embeds)
            uncond_image_prompt_embeds = image_proj_model(torch.zeros_like(images_embeds))

            bs_embed, seq_len, _ = image_prompt_embeds.shape
            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
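            # The "negative" embeddings are the projection of an all-zero feature map,
            # so guidance_scale=2.0 below applies classifier-free guidance over the
            # image prompt rather than a text prompt.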

            output, _ = pipe(
                simg_mask_latents=simg_mask_latents,
                mask=mask,
                cond_pose=cond_pose,
                prompt_embeds=image_prompt_embeds,
                negative_prompt_embeds=uncond_image_prompt_embeds,
                height=image_size[1],
                width=image_size[0] * 2,
                num_images_per_prompt=num_samples,
                guidance_scale=2.0,
                generator=generator,
                num_inference_steps=inference_steps,
            )

        output = output.images[-1]

        result = output.crop((image_size[0], 0, image_size[0] * 2, image_size[1]))
        result = result.resize((width_orig, height_orig), Image.BICUBIC)
        #result = remove_zero_pad(result)

        if debug:
            result.save('out/' + str(len(results)) + '.png')
        results.append(result)
        progress_bar.update(1)

    gc.collect()
    torch.cuda.empty_cache()

    return results


def gen_vid(frames, video_name, fps, codec):
    progress = gr.Progress(track_tqdm=True)

    frame = cv2.cvtColor(np.array(frames[0]), cv2.COLOR_RGB2BGR)
    height, width, layers = frame.shape

    #video = cv2.VideoWriter(video_name, 0, 1, (width, height))
    if codec == 'mp4':
        video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
    else:
        video = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'VP90'), fps, (width, height))

    for r in progress.tqdm(frames, desc="Creating video"):
        image = cv2.cvtColor(np.array(r), cv2.COLOR_RGB2BGR)
        video.write(image)

    #cv2.destroyAllWindows()
    video.release()  # flush and finalize the container so the file is playable
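
# End-to-end flow: load the models, detect poses on the character images and the
# driving video, fine-tune the inpainting UNet on the character, then transfer
# each video pose frame by frame and assemble the frames into a video.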

def run(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, finetune=True, is_app=False):
    print("==== Load Models ====")
    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    print("==== Pose Detection ====")
    if resize_inputs:
        resize = 'target'
    else:
        resize = 'none'
    in_img, in_pose, train_imgs, train_poses, target_poses = prepare_inputs(images, video_path, fps, bg_remove, dwpose, rembg_session, resize=resize, is_app=is_app)

    if save_model:
        train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)
        print('next')
        results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)
    else:
        print("==== Finetuning ====")
        finetuned_model = train("fine_tuned_pcdms", in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)

        print("==== Pose Transfer ====")
        results = inference("fine_tuned_pcdms", in_img, in_pose, target_poses, inference_steps, finetuned_model, vae, unet, image_encoder_p, is_app)

    return results


def run_train(images, train_steps=100, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
    finetune = True
    is_app = True
    images = [img[0] for img in images]

    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    if resize_inputs:
        resize = 'target'
    else:
        resize = 'none'

    in_img, in_pose, train_imgs, train_poses = prepare_inputs_train(images, bg_remove, dwpose, rembg_session)

    train(modelId, in_img, in_pose, train_imgs, train_poses, train_steps, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet, finetune, is_app)


def run_inference(images, video_path, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True, modelId="fine_tuned_pcdms"):
    is_app = True
    images = [img[0] for img in images]
    in_img = images[0]

    dwpose, rembg_session, pcdms_model, noise_scheduler, image_encoder_p, image_encoder_g, vae, unet = load_models()

    target_poses, in_pose = prepare_inputs_inference(in_img, video_path, fps, dwpose, 'target', is_app)

    results = inference(modelId, in_img, in_pose, target_poses, inference_steps, None, vae, unet, image_encoder_p, is_app)

    out_vid = f"out_{uuid.uuid4()}"  # unique basename for the rendered video
    if debug:
        gen_vid(results, out_vid + '.mp4', fps, 'mp4')
    else:
        gen_vid(results, out_vid + '.webm', fps, 'webm')

    print("Done!")

    return out_vid + '.webm', results


def run_app(images, video_path, train_steps=100, inference_steps=10, fps=12, bg_remove=False, resize_inputs=True):
    images = [img[0] for img in images]

    results = run(images, video_path, train_steps, inference_steps, fps, bg_remove, resize_inputs, finetune=True, is_app=True)

    print("==== Video generation ====")
    out_vid = f"out_{uuid.uuid4()}"

    if debug:
        gen_vid(results, out_vid + '.mp4', fps, 'mp4')
    else:
        gen_vid(results, out_vid + '.webm', fps, 'webm')

    print("Done!")

    return out_vid + '.webm', results


"""
train_steps = 100
inference_steps = 10
fps = 12
"""

"""
iface = gr.Interface(
    fn=run,
    inputs=[
        gr.Gallery(type="pil", label="Images of the Character"),
        gr.Video(label="Motion-Capture Video"),
        gr.Number(label="Training steps", value=100),
        gr.Number(label="Inference steps", value=10),
        gr.Number(label="Output frame rate", value=12),
        gr.Checkbox(label="Remove background", value=False),
    ],
    outputs=[gr.Video(label="Result"), gr.Gallery(type="pil", label="Frames")],
    title="Keyframes AI",
    description="Upload images of your character and a motion-capture video to generate an animation of the character.",
)
"""
metrics.json
ADDED
@@ -0,0 +1,1538 @@
| 1 |
+
{
|
| 2 |
+
"sidewalk": {
|
| 3 |
+
"ft": {
|
| 4 |
+
"ssim": {
|
| 5 |
+
"avg": 0.9425153732299805,
|
| 6 |
+
"vals": [
|
| 7 |
+
0.9539688229560852,
|
| 8 |
+
0.9410380721092224,
|
| 9 |
+
0.9339408278465271,
|
| 10 |
+
0.9381983876228333,
|
| 11 |
+
0.937977135181427,
|
| 12 |
+
0.937738835811615,
|
| 13 |
+
0.9322819709777832,
|
| 14 |
+
0.9447341561317444,
|
| 15 |
+
0.9701671004295349,
|
| 16 |
+
0.9403354525566101,
|
| 17 |
+
0.9407327771186829,
|
| 18 |
+
0.9384147524833679,
|
| 19 |
+
0.9429612159729004,
|
| 20 |
+
0.9541351199150085,
|
| 21 |
+
0.9593266844749451,
|
| 22 |
+
0.9423799514770508,
|
| 23 |
+
0.9342405200004578,
|
| 24 |
+
0.9316573143005371,
|
| 25 |
+
0.9376768469810486,
|
| 26 |
+
0.9355674386024475,
|
| 27 |
+
0.9092764258384705,
|
| 28 |
+
0.9486138820648193,
|
| 29 |
+
0.9740718007087708,
|
| 30 |
+
0.9381063580513,
|
| 31 |
+
0.9418080449104309,
|
| 32 |
+
0.9435502886772156,
|
| 33 |
+
0.9383361339569092,
|
| 34 |
+
0.9491550326347351,
|
| 35 |
+
0.9579421877861023,
|
| 36 |
+
0.9433062672615051,
|
| 37 |
+
0.9372091889381409,
|
| 38 |
+
0.9316847920417786,
|
| 39 |
+
0.9353565573692322,
|
| 40 |
+
0.9314845204353333,
|
| 41 |
+
0.9296603202819824,
|
| 42 |
+
0.947252094745636,
|
| 43 |
+
0.9745006561279297,
|
| 44 |
+
0.9371557235717773,
|
| 45 |
+
0.9355664849281311,
|
| 46 |
+
0.9372367858886719,
|
| 47 |
+
0.9420561790466309,
|
| 48 |
+
0.9548425674438477
|
| 49 |
+
]
|
| 50 |
+
},
|
| 51 |
+
"psnr": {
|
| 52 |
+
"avg": 29.313762982030134,
|
| 53 |
+
"vals": [
|
| 54 |
+
31.79728361792486,
|
| 55 |
+
29.12789926111764,
|
| 56 |
+
28.150918151496114,
|
| 57 |
+
28.197009952290536,
|
| 58 |
+
28.30431945800427,
|
| 59 |
+
28.3766962913435,
|
| 60 |
+
27.95230686775166,
|
| 61 |
+
29.60480299071412,
|
| 62 |
+
33.870620988359505,
|
| 63 |
+
28.2723202748136,
|
| 64 |
+
28.65414745402314,
|
| 65 |
+
29.44914322199964,
|
| 66 |
+
28.918554998476544,
|
| 67 |
+
30.829394826527686,
|
| 68 |
+
32.453438778794194,
|
| 69 |
+
29.07800808482171,
|
| 70 |
+
28.126848239355724,
|
| 71 |
+
28.219538928254238,
|
| 72 |
+
28.15327266691132,
|
| 73 |
+
28.37061664435462,
|
| 74 |
+
22.572103732095137,
|
| 75 |
+
30.218097049078665,
|
| 76 |
+
35.291027583657176,
|
| 77 |
+
28.264550325776455,
|
| 78 |
+
28.548636901711063,
|
| 79 |
+
30.077494275071125,
|
| 80 |
+
28.765906730485675,
|
| 81 |
+
30.473295981007006,
|
| 82 |
+
32.31943212041054,
|
| 83 |
+
29.03166397751077,
|
| 84 |
+
28.20738376155079,
|
| 85 |
+
27.52313906207686,
|
| 86 |
+
28.068443800302507,
|
| 87 |
+
28.02290963487483,
|
| 88 |
+
27.137965527670815,
|
| 89 |
+
29.891409757083256,
|
| 90 |
+
35.52381896568835,
|
| 91 |
+
28.314135705730163,
|
| 92 |
+
28.17076848548687,
|
| 93 |
+
29.376791831995504,
|
| 94 |
+
28.654625695022126,
|
| 95 |
+
30.817302643645203
|
| 96 |
+
]
|
| 97 |
+
},
|
| 98 |
+
"lpips": {
|
| 99 |
+
"avg": 0.0587034212159259,
|
| 100 |
+
"vals": [
|
| 101 |
+
0.03719184175133705,
|
| 102 |
+
0.05626270920038223,
|
| 103 |
+
0.06762276589870453,
|
| 104 |
+
0.07581378519535065,
|
| 105 |
+
0.06386660039424896,
|
| 106 |
+
0.06519967317581177,
|
| 107 |
+
0.06447210162878036,
|
| 108 |
+
0.044685374945402145,
|
| 109 |
+
0.02818106673657894,
|
| 110 |
+
0.06283751875162125,
|
| 111 |
+
0.05770150199532509,
|
| 112 |
+
0.05407204478979111,
|
| 113 |
+
0.06568614393472672,
|
| 114 |
+
0.041273392736911774,
|
| 115 |
+
0.03479328751564026,
|
| 116 |
+
0.05250183865427971,
|
| 117 |
+
0.06724748760461807,
|
| 118 |
+
0.07919331640005112,
|
| 119 |
+
0.06452416628599167,
|
| 120 |
+
0.06678760051727295,
|
| 121 |
+
0.15427803993225098,
|
| 122 |
+
0.04608333483338356,
|
| 123 |
+
0.026074513792991638,
|
| 124 |
+
0.06458891183137894,
|
| 125 |
+
0.06116756051778793,
|
| 126 |
+
0.050854653120040894,
|
| 127 |
+
0.06727714836597443,
|
| 128 |
+
0.04326387867331505,
|
| 129 |
+
0.037447478622198105,
|
| 130 |
+
0.0521712489426136,
|
| 131 |
+
0.0674269050359726,
|
| 132 |
+
0.07980889827013016,
|
| 133 |
+
0.06633622944355011,
|
| 134 |
+
0.06805833429098129,
|
| 135 |
+
0.07028443366289139,
|
| 136 |
+
0.043847665190696716,
|
| 137 |
+
0.025734560564160347,
|
| 138 |
+
0.06363274157047272,
|
| 139 |
+
0.06513398140668869,
|
| 140 |
+
0.0585525706410408,
|
| 141 |
+
0.06366591155529022,
|
| 142 |
+
0.039940472692251205
|
| 143 |
+
]
|
| 144 |
+
},
|
| 145 |
+
"fid": {
|
| 146 |
+
"avg": 42.43536594462592,
|
| 147 |
+
"vals": [
|
| 148 |
+
20.22960865639975,
|
| 149 |
+
47.958607601608676,
|
| 150 |
+
57.62921688800561,
|
| 151 |
+
95.16926849364827,
|
| 152 |
+
52.86706191106997,
|
| 153 |
+
40.88534322140695,
|
| 154 |
+
66.72949840511895,
|
| 155 |
+
24.21628873128267,
|
| 156 |
+
18.986568909208156,
|
| 157 |
+
33.69661265603083,
|
| 158 |
+
24.557634338057728,
|
| 159 |
+
54.261117389557846,
|
| 160 |
+
51.156090096940055,
|
| 161 |
+
22.082607032341244,
|
| 162 |
+
24.085151411940856,
|
| 163 |
+
45.16219303566392,
|
| 164 |
+
51.10103371191223,
|
| 165 |
+
94.4284018519586,
|
| 166 |
+
61.20478944897824,
|
| 167 |
+
53.09587418802096,
|
| 168 |
+
45.77148143461985,
|
| 169 |
+
25.706806643970687,
|
| 170 |
+
17.209723468709765,
|
| 171 |
+
38.78578933921177,
|
| 172 |
+
25.072541814813434,
|
| 173 |
+
41.178136179334956,
|
| 174 |
+
37.296019926311956,
|
| 175 |
+
27.822427143452785,
|
| 176 |
+
22.649814443169078,
|
| 177 |
+
37.95954774350004,
|
| 178 |
+
47.89610979304835,
|
| 179 |
+
90.26735601619404,
|
| 180 |
+
50.99027170447254,
|
| 181 |
+
47.766544543497034,
|
| 182 |
+
74.45788273316838,
|
| 183 |
+
22.916864843654203,
|
| 184 |
+
17.551236020192366,
|
| 185 |
+
29.612576631133486,
|
| 186 |
+
32.3629873396745,
|
| 187 |
+
46.47368512482174,
|
| 188 |
+
39.25242590132321,
|
| 189 |
+
23.782172906863416
|
| 190 |
+
]
|
| 191 |
+
}
|
| 192 |
+
},
|
| 193 |
+
"base": {
|
| 194 |
+
"ssim": {
|
| 195 |
+
"avg": 0.8774991716657367,
|
| 196 |
+
"vals": [
|
| 197 |
+
0.8641023635864258,
|
| 198 |
+
0.8767493367195129,
|
| 199 |
+
0.8762156367301941,
|
| 200 |
+
0.90242999792099,
|
| 201 |
+
0.9017764925956726,
|
| 202 |
+
0.8788102269172668,
|
| 203 |
+
0.8805813789367676,
|
| 204 |
+
0.8612765669822693,
|
| 205 |
+
0.8842142224311829,
|
| 206 |
+
0.8914286494255066,
|
| 207 |
+
0.8772357106208801,
|
| 208 |
+
0.8998511433601379,
|
| 209 |
+
0.8634107112884521,
|
| 210 |
+
0.8928578495979309,
|
| 211 |
+
0.8836771845817566,
|
| 212 |
+
0.8996903300285339,
|
| 213 |
+
0.866424560546875,
|
| 214 |
+
0.8816824555397034,
|
| 215 |
+
0.8785967230796814,
|
| 216 |
+
0.8867073655128479,
|
| 217 |
+
0.8279438614845276,
|
| 218 |
+
0.8717477321624756,
|
| 219 |
+
0.8780853748321533,
|
| 220 |
+
0.8800415992736816,
|
| 221 |
+
0.8771164417266846,
|
| 222 |
+
0.8793924450874329,
|
| 223 |
+
0.8809206485748291,
|
| 224 |
+
0.8608385920524597,
|
| 225 |
+
0.893017590045929,
|
| 226 |
+
0.8689785003662109,
|
| 227 |
+
0.8729929327964783,
|
| 228 |
+
0.8567183017730713,
|
| 229 |
+
0.8588916659355164,
|
| 230 |
+
0.8677377104759216,
|
| 231 |
+
0.8715691566467285,
|
| 232 |
+
0.8649471402168274,
|
| 233 |
+
0.884833037853241,
|
| 234 |
+
0.913370668888092,
|
| 235 |
+
0.8970480561256409,
|
| 236 |
+
0.8238387107849121,
|
| 237 |
+
0.8899227976799011,
|
| 238 |
+
0.8872933387756348
|
| 239 |
+
]
|
| 240 |
+
},
|
| 241 |
+
"psnr": {
|
| 242 |
+
"avg": 21.225323061553667,
|
| 243 |
+
"vals": [
|
| 244 |
+
18.731802028716977,
|
| 245 |
+
22.517723678922334,
|
| 246 |
+
21.373543493859014,
|
| 247 |
+
22.758004432454015,
|
| 248 |
+
22.955530937204767,
|
| 249 |
+
20.145150147854874,
|
| 250 |
+
21.634592878743387,
|
| 251 |
+
20.379519268519918,
|
| 252 |
+
23.696208467856458,
|
| 253 |
+
22.98936924589105,
|
| 254 |
+
21.756553483824376,
|
| 255 |
+
23.69843738135742,
|
| 256 |
+
17.770960047525655,
|
| 257 |
+
22.7351530856508,
|
| 258 |
+
20.95728828090057,
|
| 259 |
+
23.440357115988917,
|
| 260 |
+
20.458007935114516,
|
| 261 |
+
19.903659717164693,
|
| 262 |
+
19.636953200826923,
|
| 263 |
+
21.34791655058212,
|
| 264 |
+
19.213297117979387,
|
| 265 |
+
21.474593210996353,
|
| 266 |
+
23.02245609312815,
|
| 267 |
+
21.57037201859069,
|
| 268 |
+
20.49675090194842,
|
| 269 |
+
21.149239825321867,
|
| 270 |
+
21.768542181018454,
|
| 271 |
+
20.75462471848509,
|
| 272 |
+
23.340499225598812,
|
| 273 |
+
20.184658098890935,
|
| 274 |
+
19.84577062228628,
|
| 275 |
+
18.651530199516753,
|
| 276 |
+
17.713097904145002,
|
| 277 |
+
19.501488720897637,
|
| 278 |
+
20.525847469970195,
|
| 279 |
+
20.044060320208107,
|
| 280 |
+
23.33142797550424,
|
| 281 |
+
25.84051181341309,
|
| 282 |
+
24.517940621682737,
|
| 283 |
+
16.233878462859337,
|
| 284 |
+
21.53839254074825,
|
| 285 |
+
21.857857163105283
|
| 286 |
+
]
|
| 287 |
+
},
|
| 288 |
+
"lpips": {
|
| 289 |
+
"avg": 0.17535314992779777,
|
| 290 |
+
"vals": [
|
| 291 |
+
0.1866557002067566,
|
| 292 |
+
0.19489219784736633,
|
| 293 |
+
0.1676192283630371,
|
| 294 |
+
0.13815467059612274,
|
| 295 |
+
0.1378740817308426,
|
| 296 |
+
0.19519048929214478,
|
| 297 |
+
0.1394621878862381,
|
| 298 |
+
0.18285232782363892,
|
| 299 |
+
0.1300656795501709,
|
| 300 |
+
0.13945598900318146,
|
| 301 |
+
0.1777152568101883,
|
| 302 |
+
0.1422179788351059,
|
| 303 |
+
0.2465370148420334,
|
| 304 |
+
0.173164963722229,
|
| 305 |
+
0.16046200692653656,
|
| 306 |
+
0.1408388763666153,
|
| 307 |
+
0.1748061180114746,
|
| 308 |
+
0.15652400255203247,
|
| 309 |
+
0.18918247520923615,
|
| 310 |
+
0.13944856822490692,
|
| 311 |
+
0.3080754280090332,
|
| 312 |
+
0.15395303070545197,
|
| 313 |
+
0.13146977126598358,
|
| 314 |
+
0.21760974824428558,
|
| 315 |
+
0.16546086966991425,
|
| 316 |
+
0.18631187081336975,
|
| 317 |
+
0.17548373341560364,
|
| 318 |
+
0.2248706817626953,
|
| 319 |
+
0.14110884070396423,
|
| 320 |
+
0.1748245805501938,
|
| 321 |
+
0.17234086990356445,
|
| 322 |
+
0.1958460807800293,
|
| 323 |
+
0.22889724373817444,
|
| 324 |
+
0.1727704405784607,
|
| 325 |
+
0.15012094378471375,
|
| 326 |
+
0.20727047324180603,
|
| 327 |
+
0.1201525330543518,
|
| 328 |
+
0.09184442460536957,
|
| 329 |
+
0.11003588140010834,
|
| 330 |
+
0.34251728653907776,
|
| 331 |
+
0.18208707869052887,
|
| 332 |
+
0.19866067171096802
|
| 333 |
+
]
|
| 334 |
+
},
|
| 335 |
+
"fid": {
|
| 336 |
+
"avg": 104.964235596045,
|
| 337 |
+
"vals": [
|
| 338 |
+
111.40174829746752,
|
| 339 |
+
148.92090091433516,
|
| 340 |
+
128.51386183067834,
|
| 341 |
+
61.321546832998955,
|
| 342 |
+
72.54608643303393,
|
| 343 |
+
119.84550665904197,
|
| 344 |
+
104.43999062002221,
|
| 345 |
+
60.195703096095556,
|
| 346 |
+
64.8998572256939,
|
| 347 |
+
49.645917433260855,
|
| 348 |
+
114.28769797616002,
|
| 349 |
+
66.89825805245846,
|
| 350 |
+
155.83362493523907,
|
| 351 |
+
88.44848612474361,
|
| 352 |
+
104.61621566440316,
|
| 353 |
+
163.43248950817738,
|
| 354 |
+
172.53727330437798,
|
| 355 |
+
101.98603573984371,
|
| 356 |
+
125.94605415324145,
|
| 357 |
+
102.60611610457259,
|
| 358 |
+
104.21668109264168,
|
| 359 |
+
102.24131609759299,
|
| 360 |
+
70.50377311449385,
|
| 361 |
+
161.15530398717667,
|
| 362 |
+
96.17823491839759,
|
| 363 |
+
105.78022442714524,
|
| 364 |
+
120.89941533798705,
|
| 365 |
+
94.75931214260221,
|
| 366 |
+
73.59833927600653,
|
| 367 |
+
112.51851327401378,
|
| 368 |
+
148.31978447992762,
|
| 369 |
+
113.95507129041124,
|
| 370 |
+
82.26377459422716,
|
| 371 |
+
154.98015860895285,
|
| 372 |
+
96.10781829434029,
|
| 373 |
+
100.0667395138281,
|
| 374 |
+
83.65808540884001,
|
| 375 |
+
50.67035884888851,
|
| 376 |
+
64.41624770555333,
|
| 377 |
+
123.2744913103858,
|
| 378 |
+
106.94257164137686,
|
| 379 |
+
123.66830876325491
|
| 380 |
+
]
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
},
|
| 384 |
+
"aaa": {
|
| 385 |
+
"ft": {
|
| 386 |
+
"ssim": {
|
| 387 |
+
"avg": 0.8912452217694875,
|
| 388 |
+
"vals": [
|
| 389 |
+
0.9061050415039062,
|
| 390 |
+
0.9104815125465393,
|
| 391 |
+
0.9111503958702087,
|
| 392 |
+
0.8999118208885193,
|
| 393 |
+
0.9075160026550293,
|
| 394 |
+
0.9037303924560547,
|
| 395 |
+
0.9100475311279297,
|
| 396 |
+
0.9105749130249023,
|
| 397 |
+
0.9070033431053162,
|
| 398 |
+
0.928657054901123,
|
| 399 |
+
0.9301331043243408,
|
| 400 |
+
0.9153156280517578,
|
| 401 |
+
0.9060927033424377,
|
| 402 |
+
0.9026443362236023,
|
| 403 |
+
0.9006186127662659,
|
| 404 |
+
0.8967500329017639,
|
| 405 |
+
0.8858065605163574,
|
| 406 |
+
0.8841550946235657,
|
| 407 |
+
0.8957152366638184,
|
| 408 |
+
0.9159244894981384,
|
| 409 |
+
0.9212849140167236,
|
| 410 |
+
0.8954355716705322,
|
| 411 |
+
0.882415771484375,
|
| 412 |
+
0.8875539898872375,
|
| 413 |
+
0.8762226700782776,
|
| 414 |
+
0.8741069436073303,
|
| 415 |
+
0.8663904070854187,
|
| 416 |
+
0.8608574867248535,
|
| 417 |
+
0.8592395782470703,
|
| 418 |
+
0.8617714047431946,
|
| 419 |
+
0.8670153021812439,
|
| 420 |
+
0.8589889407157898,
|
| 421 |
+
0.8667375445365906,
|
| 422 |
+
0.8646472096443176,
|
| 423 |
+
0.8644879460334778,
|
| 424 |
+
0.8704419732093811,
|
| 425 |
+
0.8701417446136475
|
| 426 |
+
]
|
| 427 |
+
},
|
| 428 |
+
"psnr": {
|
| 429 |
+
"avg": 23.218790631433333,
|
| 430 |
+
"vals": [
|
| 431 |
+
22.60000443641932,
|
| 432 |
+
23.213327330683896,
|
| 433 |
+
23.966372308676554,
|
| 434 |
+
22.61628135760781,
|
| 435 |
+
23.334516149012956,
|
| 436 |
+
22.76956956037783,
|
| 437 |
+
23.522473461596828,
|
| 438 |
+
23.818158050349368,
|
| 439 |
+
24.35100933616721,
|
| 440 |
+
27.585580066737258,
|
| 441 |
+
27.485016628132442,
|
| 442 |
+
25.653444844066296,
|
| 443 |
+
24.116901996142047,
|
| 444 |
+
23.141365324276826,
|
| 445 |
+
23.439784068380483,
|
| 446 |
+
23.33054311947808,
|
| 447 |
+
22.82591865046196,
|
| 448 |
+
22.71185043604199,
|
| 449 |
+
24.538035778803117,
|
| 450 |
+
27.826709590087102,
|
| 451 |
+
28.474405666001857,
|
| 452 |
+
24.769433736339693,
|
| 453 |
+
23.449171903609013,
|
| 454 |
+
23.46383536533758,
|
| 455 |
+
21.324737956273246,
|
| 456 |
+
21.96812016491501,
|
| 457 |
+
20.7294582590173,
|
| 458 |
+
20.641701301001603,
|
| 459 |
+
20.774969220796272,
|
| 460 |
+
21.108372009572886,
|
| 461 |
+
22.00821781586527,
|
| 462 |
+
20.4745834349127,
|
| 463 |
+
21.075791929560445,
|
| 464 |
+
21.104358437386903,
|
| 465 |
+
21.55546473865332,
|
| 466 |
+
22.000344888516832,
|
| 467 |
+
21.32542404177411
|
| 468 |
+
]
|
| 469 |
+
},
|
| 470 |
+
"lpips": {
|
| 471 |
+
"avg": 0.10130888052486084,
|
| 472 |
+
"vals": [
|
| 473 |
+
0.14061428606510162,
|
| 474 |
+
0.14724630117416382,
|
| 475 |
+
0.1335088610649109,
|
| 476 |
+
0.13316522538661957,
|
| 477 |
+
0.12164061516523361,
|
| 478 |
+
0.11964336037635803,
|
| 479 |
+
0.10177505016326904,
|
| 480 |
+
0.10693053156137466,
|
| 481 |
+
0.0958232507109642,
|
| 482 |
+
0.05439213290810585,
|
| 483 |
+
0.05292022228240967,
|
| 484 |
+
0.07251676917076111,
|
| 485 |
+
0.09491521865129471,
|
| 486 |
+
0.11164701730012894,
|
| 487 |
+
0.1037021279335022,
|
| 488 |
+
0.09932567179203033,
|
| 489 |
+
0.0995492935180664,
|
| 490 |
+
0.0946572870016098,
|
| 491 |
+
0.07657209038734436,
|
| 492 |
+
0.04164290428161621,
|
| 493 |
+
0.04009518772363663,
|
| 494 |
+
0.06518009305000305,
|
| 495 |
+
0.07143466174602509,
|
| 496 |
+
0.07886430621147156,
|
| 497 |
+
0.09445148706436157,
|
| 498 |
+
0.11044608056545258,
|
| 499 |
+
0.1215718686580658,
|
| 500 |
+
0.12195122241973877,
|
| 501 |
+
0.12177268415689468,
|
| 502 |
+
0.13400736451148987,
|
| 503 |
+
0.1025841161608696,
|
| 504 |
+
0.11698935180902481,
|
| 505 |
+
0.11627939343452454,
|
| 506 |
+
0.11498457193374634,
|
| 507 |
+
0.11028993129730225,
|
| 508 |
+
0.10681547224521637,
|
| 509 |
+
0.11852256953716278
|
| 510 |
+
]
|
| 511 |
+
},
|
| 512 |
+
"fid": {
|
| 513 |
+
"avg": 128.47027428857402,
|
| 514 |
+
"vals": [
|
| 515 |
+
210.6965716493914,
|
| 516 |
+
136.02845249812458,
|
| 517 |
+
142.9994319683501,
|
| 518 |
+
102.24152561128506,
|
| 519 |
+
106.32239193197034,
|
| 520 |
+
133.25417613786345,
|
| 521 |
+
148.41552185542787,
|
| 522 |
+
156.6543860781757,
|
| 523 |
+
169.9280658509937,
|
| 524 |
+
96.20484628876847,
|
| 525 |
+
124.39223402724446,
|
| 526 |
+
157.57056317546318,
|
| 527 |
+
182.60839379211808,
|
| 528 |
+
143.061621048883,
|
| 529 |
+
129.08572262883925,
|
| 530 |
+
146.6837514000539,
|
| 531 |
+
121.28023529371208,
|
| 532 |
+
133.09083552748842,
|
| 533 |
+
133.1701816777087,
|
| 534 |
+
106.30667335984795,
|
| 535 |
+
92.99936563096986,
|
| 536 |
+
106.42616281960741,
|
| 537 |
+
91.02325698154696,
|
| 538 |
+
82.03384095285311,
|
| 539 |
+
100.23715859488628,
|
| 540 |
+
103.07187891625573,
|
| 541 |
+
126.73029941764449,
|
| 542 |
+
120.07365192426221,
|
| 543 |
+
120.96345164031321,
|
| 544 |
+
116.58719477188482,
|
| 545 |
+
167.96926051103426,
|
| 546 |
+
180.28950392287587,
|
| 547 |
+
105.1495396073212,
|
| 548 |
+
140.74437360480772,
|
| 549 |
+
99.20108044973863,
|
| 550 |
+
113.96675633000953,
|
| 551 |
+
105.93779079951634
|
| 552 |
+
]
|
| 553 |
+
}
|
| 554 |
+
},
|
| 555 |
+
"base": {
|
| 556 |
+
"ssim": {
|
| 557 |
+
"avg": 0.8579733484500164,
|
| 558 |
+
"vals": [
|
| 559 |
+
0.8666160702705383,
|
| 560 |
+
0.7987958788871765,
|
| 561 |
+
0.8860695362091064,
|
| 562 |
+
0.8763647079467773,
|
| 563 |
+
0.8763935565948486,
|
| 564 |
+
0.8563421368598938,
|
| 565 |
+
0.8911734223365784,
|
| 566 |
+
0.8716912269592285,
|
| 567 |
+
0.8579948544502258,
|
| 568 |
+
0.8772208094596863,
|
| 569 |
+
0.8896298408508301,
|
| 570 |
+
0.8552396893501282,
|
| 571 |
+
0.8748952746391296,
|
| 572 |
+
0.8613572120666504,
|
| 573 |
+
0.8676340579986572,
|
| 574 |
+
0.8660659193992615,
|
| 575 |
+
0.8730340600013733,
|
| 576 |
+
0.8534577488899231,
|
| 577 |
+
0.8619491457939148,
|
| 578 |
+
0.8790878653526306,
|
| 579 |
+
0.869230329990387,
|
| 580 |
+
0.8246889114379883,
|
| 581 |
+
0.8481850624084473,
|
| 582 |
+
0.8987539410591125,
|
| 583 |
+
0.8677256107330322,
|
| 584 |
+
0.8619834780693054,
|
| 585 |
+
0.8389760851860046,
|
| 586 |
+
0.8448922634124756,
|
| 587 |
+
0.8406286239624023,
|
| 588 |
+
0.805531919002533,
|
| 589 |
+
0.8549509048461914,
|
| 590 |
+
0.8503775596618652,
|
| 591 |
+
0.8362762928009033,
|
| 592 |
+
0.8465806841850281,
|
| 593 |
+
0.8307406902313232,
|
| 594 |
+
0.850950300693512,
|
| 595 |
+
0.8335282206535339
|
| 596 |
+
]
|
| 597 |
+
},
|
| 598 |
+
"psnr": {
|
| 599 |
+
"avg": 19.081316861697736,
|
| 600 |
+
"vals": [
|
| 601 |
+
16.88895564145245,
|
| 602 |
+
15.797760205800826,
|
| 603 |
+
18.41262183621939,
|
| 604 |
+
18.562114490640262,
|
| 605 |
+
17.917161492127512,
|
| 606 |
+
18.252547952717784,
|
| 607 |
+
19.64865146927212,
|
| 608 |
+
17.6865025569328,
|
| 609 |
+
17.99347739896675,
|
| 610 |
+
18.854024365954317,
|
| 611 |
+
19.934364132621422,
|
| 612 |
+
18.675988138292176,
|
| 613 |
+
19.247068668941317,
|
| 614 |
+
19.570660435858237,
|
| 615 |
+
17.830968802374684,
|
| 616 |
+
19.397491216401374,
|
| 617 |
+
21.611632805892775,
|
| 618 |
+
20.071126840114104,
|
| 619 |
+
19.438827872416823,
|
| 620 |
+
22.09851766389441,
|
| 621 |
+
20.985060014772596,
|
| 622 |
+
17.811881997736606,
|
| 623 |
+
20.828426349668014,
|
| 624 |
+
21.464988324765418,
|
| 625 |
+
19.570091780949138,
|
| 626 |
+
20.7044778911251,
|
| 627 |
+
18.34111665771247,
|
| 628 |
+
19.15076106907652,
|
| 629 |
+
20.11914583622098,
|
| 630 |
+
16.338984721289947,
|
| 631 |
+
21.540692338174694,
|
| 632 |
+
19.138283664226506,
|
| 633 |
+
17.11142702905118,
|
| 634 |
+
18.437481626987633,
|
| 635 |
+
18.706711870048608,
|
| 636 |
+
20.13241290437168,
|
| 637 |
+
17.73631581974743
|
| 638 |
+
]
|
| 639 |
+
},
|
| 640 |
+
"lpips": {
|
| 641 |
+
"avg": 0.2180521393547187,
|
| 642 |
+
"vals": [
|
| 643 |
+
0.24907447397708893,
|
| 644 |
+
0.5150208473205566,
|
| 645 |
+
0.21998971700668335,
|
| 646 |
+
0.30011799931526184,
|
| 647 |
+
0.24314486980438232,
|
| 648 |
+
0.2725614607334137,
|
| 649 |
+
0.2167254090309143,
|
| 650 |
+
0.22530333697795868,
|
| 651 |
+
0.2836877703666687,
|
| 652 |
+
0.2629338204860687,
|
| 653 |
+
0.16042795777320862,
|
| 654 |
+
0.2589859664440155,
|
| 655 |
+
0.18903297185897827,
|
| 656 |
+
0.19533997774124146,
|
| 657 |
+
0.27064773440361023,
|
| 658 |
+
0.1686403751373291,
|
| 659 |
+
0.12280213087797165,
|
| 660 |
+
0.25335609912872314,
|
| 661 |
+
0.29388710856437683,
|
| 662 |
+
0.09789006412029266,
|
| 663 |
+
0.1359485536813736,
|
| 664 |
+
0.37839996814727783,
|
| 665 |
+
0.16171419620513916,
|
| 666 |
+
0.09748921543359756,
|
| 667 |
+
0.11395677924156189,
|
| 668 |
+
0.13522587716579437,
|
| 669 |
+
0.20320472121238708,
|
| 670 |
+
0.20818984508514404,
|
| 671 |
+
0.16585132479667664,
|
| 672 |
+
0.2417808175086975,
|
| 673 |
+
0.11952673643827438,
|
| 674 |
+
0.15657898783683777,
|
| 675 |
+
0.21135613322257996,
|
| 676 |
+
0.20219209790229797,
|
| 677 |
+
0.2686466574668884,
|
| 678 |
+
0.16088804602622986,
|
| 679 |
+
0.3074091076850891
|
| 680 |
+
]
|
| 681 |
+
},
|
| 682 |
+
"fid": {
|
| 683 |
+
"avg": 168.1660536989984,
|
| 684 |
+
"vals": [
|
| 685 |
+
247.82225753332435,
|
| 686 |
+
261.94302538502393,
|
| 687 |
+
148.08945162416452,
|
| 688 |
+
156.38665720624414,
|
| 689 |
+
164.2083172443518,
|
| 690 |
+
152.55262134635595,
|
| 691 |
+
167.74632833375753,
|
| 692 |
+
186.55618400333205,
|
| 693 |
+
237.13432800336085,
|
| 694 |
+
295.59781724649554,
|
| 695 |
+
183.50749905003778,
|
| 696 |
+
233.3946841940326,
|
| 697 |
+
127.19021998668687,
|
| 698 |
+
237.64574583795806,
|
| 699 |
+
212.50245640676792,
|
| 700 |
+
138.55530819236137,
|
| 701 |
+
214.2925259680326,
|
| 702 |
+
206.29937532009325,
|
| 703 |
+
166.21050755154022,
|
| 704 |
+
78.77898145774333,
|
| 705 |
+
102.08709635256376,
|
| 706 |
+
205.22517444539076,
|
| 707 |
+
103.51534026875984,
|
| 708 |
+
79.41715249937273,
|
| 709 |
+
119.52601444212452,
|
| 710 |
+
103.10200144999655,
|
| 711 |
+
177.805092182125,
|
| 712 |
+
132.67670025242347,
|
| 713 |
+
131.8997749255857,
|
| 714 |
+
137.54224584649327,
|
| 715 |
+
132.2298444263562,
|
| 716 |
+
220.02890730124264,
|
| 717 |
+
148.88606776876063,
|
| 718 |
+
136.47310567443643,
|
| 719 |
+
165.10752223661132,
|
| 720 |
+
125.63803892245707,
|
| 721 |
+
184.56961597657588
|
| 722 |
+
]
|
| 723 |
+
}
|
| 724 |
+
}
|
| 725 |
+
},
|
| 726 |
+
"azri": {
|
| 727 |
+
"ft": {
|
| 728 |
+
"ssim": {
|
| 729 |
+
"avg": 0.9058952973439143,
|
| 730 |
+
"vals": [
|
| 731 |
+
0.8917362689971924,
|
| 732 |
+
0.9481084942817688,
|
| 733 |
+
0.8861625790596008,
|
| 734 |
+
0.9341784119606018,
|
| 735 |
+
0.9068629145622253,
|
| 736 |
+
0.9030451774597168,
|
| 737 |
+
0.9057510495185852,
|
| 738 |
+
0.8903522491455078,
|
| 739 |
+
0.8899074196815491,
|
| 740 |
+
0.8972444534301758,
|
| 741 |
+
0.9205942153930664,
|
| 742 |
+
0.9142631888389587,
|
| 743 |
+
0.8862788081169128,
|
| 744 |
+
0.9196522235870361,
|
| 745 |
+
0.8918246626853943,
|
| 746 |
+
0.9560407996177673,
|
| 747 |
+
0.901726484298706,
|
| 748 |
+
0.8848044276237488,
|
| 749 |
+
0.8926780819892883,
|
| 750 |
+
0.8907856941223145,
|
| 751 |
+
0.9449396729469299,
|
| 752 |
+
0.8873879313468933,
|
| 753 |
+
0.9299740791320801,
|
| 754 |
+
0.904637336730957,
|
| 755 |
+
0.9079319834709167,
|
| 756 |
+
0.9113084673881531,
|
| 757 |
+
0.8961232304573059,
|
| 758 |
+
0.881411075592041,
|
| 759 |
+
0.891535222530365,
|
| 760 |
+
0.9156699776649475,
|
| 761 |
+
0.9131000638008118,
|
| 762 |
+
0.8913257122039795,
|
| 763 |
+
0.9246047139167786,
|
| 764 |
+
0.8882644772529602,
|
| 765 |
+
0.9553866982460022,
|
| 766 |
+
0.8979193568229675,
|
| 767 |
+
0.884848415851593,
|
| 768 |
+
0.8858329653739929,
|
| 769 |
+
0.8967573046684265,
|
| 770 |
+
0.9442141652107239,
|
| 771 |
+
0.8944272398948669,
|
| 772 |
+
0.9422566294670105,
|
| 773 |
+
0.9031774401664734,
|
| 774 |
+
0.9035826325416565,
|
| 775 |
+
0.9115505814552307,
|
| 776 |
+
0.8966580033302307,
|
| 777 |
+
0.8833451867103577,
|
| 778 |
+
0.8935332894325256,
|
| 779 |
+
0.9135630130767822,
|
| 780 |
+
0.9111640453338623,
|
| 781 |
+
0.8814460635185242,
|
| 782 |
+
0.9066808819770813
|
| 783 |
+
]
|
| 784 |
+
},
|
| 785 |
+
"psnr": {
|
| 786 |
+
"avg": 22.26210885067453,
|
| 787 |
+
"vals": [
|
| 788 |
+
21.127093280116206,
|
| 789 |
+
24.839156863255635,
|
| 790 |
+
20.600073186747085,
|
| 791 |
+
25.025180077941343,
|
| 792 |
+
22.081731981690943,
|
| 793 |
+
22.42101370484669,
|
| 794 |
+
22.81864654433271,
|
| 795 |
+
21.18838097936547,
|
| 796 |
+
22.058042539733158,
|
| 797 |
+
21.10834360782046,
|
| 798 |
+
23.075382471633933,
|
| 799 |
+
22.544052476196704,
|
| 800 |
+
19.00360259824341,
|
| 801 |
+
21.92184820041113,
|
| 802 |
+
21.430364755014008,
|
| 803 |
+
28.074184819479914,
|
| 804 |
+
22.007287073710827,
|
| 805 |
+
20.596292100864723,
|
| 806 |
+
22.0403525229489,
|
| 807 |
+
20.915785732740424,
|
| 808 |
+
24.24433505144897,
|
| 809 |
+
20.759471581753495,
|
| 810 |
+
23.580874095639594,
|
| 811 |
+
21.765209438099404,
|
| 812 |
+
22.735731723488712,
|
| 813 |
+
23.165893102346963,
|
| 814 |
+
21.193429487248192,
|
| 815 |
+
20.883114381461837,
|
| 816 |
+
20.509329225340185,
|
| 817 |
+
22.74356627523407,
|
| 818 |
+
23.076752632172514,
|
| 819 |
+
20.80786955805867,
|
| 820 |
+
22.356163191215202,
|
| 821 |
+
21.454672093200195,
|
| 822 |
+
28.091895600685465,
|
| 823 |
+
21.89353215478105,
|
| 824 |
+
20.445727082361902,
|
| 825 |
+
21.73245024548928,
|
| 826 |
+
21.415697749014967,
|
| 827 |
+
25.255915842679492,
|
| 828 |
+
21.344568596586,
|
| 829 |
+
25.736038367482795,
|
| 830 |
+
21.80143662440821,
|
| 831 |
+
22.556710854506832,
|
| 832 |
+
23.011730483476125,
|
| 833 |
+
21.360129615824693,
|
| 834 |
+
21.183864623422796,
|
| 835 |
+
20.786549475403824,
|
| 836 |
+
23.05387236898404,
|
| 837 |
+
22.61546996399324,
|
| 838 |
+
18.695781286483605,
|
| 839 |
+
22.495061945689415
|
| 840 |
+
]
|
| 841 |
+
},
|
| 842 |
+
"lpips": {
|
| 843 |
+
"avg": 0.0681791385420813,
|
| 844 |
+
"vals": [
|
| 845 |
+
0.07602706551551819,
|
| 846 |
+
0.04024939239025116,
|
| 847 |
+
0.08144165575504303,
|
| 848 |
+
0.04438445717096329,
|
| 849 |
+
0.07352830469608307,
|
| 850 |
+
0.05720607936382294,
|
| 851 |
+
0.06115525960922241,
|
| 852 |
+
0.08031078428030014,
|
| 853 |
+
0.07755088806152344,
|
| 854 |
+
0.08515866100788116,
|
| 855 |
+
0.058882661163806915,
|
| 856 |
+
0.07514132559299469,
|
| 857 |
+
0.08883378654718399,
|
| 858 |
+
0.08077041804790497,
|
| 859 |
+
0.0614573135972023,
|
| 860 |
+
0.017901955172419548,
|
| 861 |
+
0.06920649111270905,
|
| 862 |
+
0.07639990746974945,
|
| 863 |
+
0.07837355881929398,
|
| 864 |
+
0.0728088989853859,
|
| 865 |
+
0.04596266895532608,
|
| 866 |
+
0.09040576219558716,
|
| 867 |
+
0.057428933680057526,
|
| 868 |
+
0.07047676295042038,
|
| 869 |
+
0.054519202560186386,
|
| 870 |
+
0.0576803982257843,
|
| 871 |
+
0.0698535293340683,
|
| 872 |
+
0.08105408400297165,
|
| 873 |
+
0.09364423155784607,
|
| 874 |
+
0.06507238000631332,
|
| 875 |
+
0.07589171826839447,
|
| 876 |
+
0.0750972181558609,
|
| 877 |
+
0.074916310608387,
|
| 878 |
+
0.062267474830150604,
|
| 879 |
+
0.021616671234369278,
|
| 880 |
+
0.06903669983148575,
|
| 881 |
+
0.07748014479875565,
|
| 882 |
+
0.08583179116249084,
|
| 883 |
+
0.06618104875087738,
|
| 884 |
+
0.05046549439430237,
|
| 885 |
+
0.07481929659843445,
|
| 886 |
+
0.042737700045108795,
|
| 887 |
+
0.07031229138374329,
|
| 888 |
+
0.05742187052965164,
|
| 889 |
+
0.058822184801101685,
|
| 890 |
+
0.08137130737304688,
|
| 891 |
+
0.08340458571910858,
|
| 892 |
+
0.08779360353946686,
|
| 893 |
+
0.05354199558496475,
|
| 894 |
+
0.0719795972108841,
|
| 895 |
+
0.09594655781984329,
|
| 896 |
+
0.06549282371997833
|
| 897 |
+
]
|
| 898 |
+
},
|
| 899 |
+
"fid": {
|
| 900 |
+
"avg": 87.69454384414368,
|
| 901 |
+
"vals": [
|
| 902 |
+
77.53360222199552,
|
| 903 |
+
81.69009799979364,
|
| 904 |
+
126.35462309879928,
|
| 905 |
+
128.097754928267,
|
| 906 |
+
117.15623067707392,
|
| 907 |
+
98.01858464323817,
|
| 908 |
+
77.96826502055629,
|
| 909 |
+
78.41206612900142,
|
| 910 |
+
110.78376599817574,
|
| 911 |
+
63.587186575144706,
|
| 912 |
+
54.89619147757511,
|
| 913 |
+
84.46815633637291,
|
| 914 |
+
76.09334109513291,
|
| 915 |
+
89.06784327761028,
|
| 916 |
+
112.09989177281398,
|
| 917 |
+
77.76499620466669,
|
| 918 |
+
127.0296162823459,
|
| 919 |
+
56.2819223080697,
|
| 920 |
+
69.43078527369693,
|
| 921 |
+
64.65111942898734,
|
| 922 |
+
62.87374700137458,
|
| 923 |
+
128.69000277570802,
|
| 924 |
+
85.75299650509533,
|
| 925 |
+
101.57266095839137,
|
| 926 |
+
69.76067120877498,
|
| 927 |
+
86.34886735845666,
|
| 928 |
+
68.48340578638442,
|
| 929 |
+
98.73242401936356,
|
| 930 |
+
63.17021113035836,
|
| 931 |
+
65.11579193591135,
|
| 932 |
+
77.95177629642525,
|
| 933 |
+
89.78335003024443,
|
| 934 |
+
86.28969096009206,
|
| 935 |
+
119.6802369335825,
|
| 936 |
+
89.66745457666659,
|
| 937 |
+
132.85212105669805,
|
| 938 |
+
114.10763350784305,
|
| 939 |
+
103.6134398480294,
|
| 940 |
+
59.686968923134835,
|
| 941 |
+
82.54559469802132,
|
| 942 |
+
111.1869442643463,
|
| 943 |
+
85.38777030577579,
|
| 944 |
+
106.32082608988044,
|
| 945 |
+
106.12593735864233,
|
| 946 |
+
86.00820771852335,
|
| 947 |
+
64.28650309776683,
|
| 948 |
+
102.70937080791995,
|
| 949 |
+
47.320553137933764,
|
| 950 |
+
53.731227883240436,
|
| 951 |
+
71.80043885301173,
|
| 952 |
+
96.20161778732583,
|
| 953 |
+
70.97179633123264
|
| 954 |
+
]
|
| 955 |
+
}
|
| 956 |
+
},
|
| 957 |
+
"base": {
|
| 958 |
+
"ssim": {
|
| 959 |
+
"avg": 0.765552927668278,
|
| 960 |
+
"vals": [
|
| 961 |
+
0.7876270413398743,
|
| 962 |
+
0.6932912468910217,
|
| 963 |
+
0.7860949039459229,
|
| 964 |
+
0.8119332194328308,
|
| 965 |
+
0.7923873066902161,
|
| 966 |
+
0.7616052031517029,
|
| 967 |
+
0.7635032534599304,
|
| 968 |
+
0.752696692943573,
|
| 969 |
+
0.7297654151916504,
|
| 970 |
+
0.7741971015930176,
|
| 971 |
+
0.7830759882926941,
|
| 972 |
+
0.787074089050293,
|
| 973 |
+
0.7227017283439636,
|
| 974 |
+
0.7842667698860168,
|
| 975 |
+
0.7618840336799622,
|
| 976 |
+
0.7612975239753723,
|
| 977 |
+
0.7881305813789368,
|
| 978 |
+
0.7242482304573059,
|
| 979 |
+
0.7725765109062195,
|
| 980 |
+
0.7895841598510742,
|
| 981 |
+
0.7584860920906067,
|
| 982 |
+
0.7806427478790283,
|
| 983 |
+
0.7808322906494141,
|
| 984 |
+
0.8068225979804993,
|
| 985 |
+
0.7732831835746765,
|
| 986 |
+
0.7021427154541016,
|
| 987 |
+
0.7602499127388,
|
| 988 |
+
0.7775993943214417,
|
| 989 |
+
0.7879728674888611,
|
| 990 |
+
0.7852631211280823,
|
| 991 |
+
0.7593006491661072,
|
| 992 |
+
0.7491123080253601,
|
| 993 |
+
0.8211724162101746,
|
| 994 |
+
0.791597306728363,
|
| 995 |
+
0.8033311367034912,
|
| 996 |
+
0.751656711101532,
|
| 997 |
+
0.7145156860351562,
|
| 998 |
+
0.7085480690002441,
|
| 999 |
+
0.7710719704627991,
|
| 1000 |
+
0.7261748313903809,
|
| 1001 |
+
0.7700297236442566,
|
| 1002 |
+
0.721587598323822,
|
| 1003 |
+
0.769636869430542,
|
| 1004 |
+
0.7924219965934753,
|
| 1005 |
+
0.7783514857292175,
|
| 1006 |
+
0.7467233538627625,
|
| 1007 |
+
0.7454271912574768,
|
| 1008 |
+
0.7511434555053711,
|
| 1009 |
+
0.7623993754386902,
|
| 1010 |
+
0.7588648796081543,
|
| 1011 |
+
0.7819810509681702,
|
| 1012 |
+
0.792468249797821
|
| 1013 |
+
]
|
| 1014 |
+
},
|
| 1015 |
+
"psnr": {
|
| 1016 |
+
"avg": 12.04275296711651,
|
| 1017 |
+
"vals": [
|
| 1018 |
+
11.831093266628356,
|
| 1019 |
+
12.528152127082233,
|
| 1020 |
+
11.504680724545494,
|
| 1021 |
+
12.06579748941828,
|
| 1022 |
+
12.193623545830743,
|
| 1023 |
+
12.967612777094432,
|
| 1024 |
+
11.22349001603837,
|
| 1025 |
+
12.676957730222881,
|
| 1026 |
+
10.463052935595755,
|
| 1027 |
+
12.068058666401921,
|
| 1028 |
+
11.815425022083929,
|
| 1029 |
+
11.474556127083932,
|
| 1030 |
+
11.65343712380346,
|
| 1031 |
+
12.970802840776267,
|
| 1032 |
+
10.921165557898965,
|
| 1033 |
+
11.487784206793332,
|
| 1034 |
+
11.862155399851087,
|
| 1035 |
+
12.500533116810459,
|
| 1036 |
+
13.26478322723994,
|
| 1037 |
+
11.454244322299186,
|
| 1038 |
+
12.846474859150605,
|
| 1039 |
+
11.178912278082944,
|
| 1040 |
+
11.668886917161911,
|
| 1041 |
+
13.535748073878615,
|
| 1042 |
+
11.74158297715729,
|
| 1043 |
+
11.307205931581416,
|
| 1044 |
+
11.122995221603784,
|
| 1045 |
+
12.315128915653881,
|
| 1046 |
+
11.927970316701074,
|
| 1047 |
+
11.471458996386659,
|
| 1048 |
+
10.8378095973721,
|
| 1049 |
+
12.40354222062087,
|
| 1050 |
+
13.883183470169415,
|
| 1051 |
+
12.746526277133636,
|
| 1052 |
+
13.423830377359124,
|
| 1053 |
+
12.183072443407536,
|
| 1054 |
+
11.745039572219284,
|
| 1055 |
+
11.288605142920026,
|
| 1056 |
+
10.811118546539486,
|
| 1057 |
+
11.492946654674247,
|
| 1058 |
+
12.552522641514392,
|
| 1059 |
+
12.042564039263278,
|
| 1060 |
+
11.3939431319692,
|
| 1061 |
+
12.898514335036811,
|
| 1062 |
+
11.114529050567972,
|
| 1063 |
+
11.339446334579078,
|
| 1064 |
+
12.24930476842462,
|
| 1065 |
+
13.364881567497516,
|
| 1066 |
+
13.445838873213505,
|
| 1067 |
+
12.214397326986626,
|
| 1068 |
+
12.84755151678836,
|
| 1069 |
+
11.900215690944371
|
| 1070 |
+
]
|
| 1071 |
+
},
|
| 1072 |
+
"lpips": {
|
| 1073 |
+
"avg": 0.3510053243774634,
|
| 1074 |
+
"vals": [
|
| 1075 |
+
0.3419290781021118,
|
| 1076 |
+
0.4748748242855072,
|
| 1077 |
+
0.3430050015449524,
|
| 1078 |
+
0.3214326500892639,
|
| 1079 |
+
0.28315189480781555,
|
| 1080 |
+
0.38122087717056274,
|
| 1081 |
+
0.40785902738571167,
|
| 1082 |
+
0.283894807100296,
|
| 1083 |
+
0.28781914710998535,
|
| 1084 |
+
0.31599509716033936,
|
| 1085 |
+
0.31139951944351196,
|
| 1086 |
+
0.40369728207588196,
|
| 1087 |
+
0.4894903302192688,
|
| 1088 |
+
0.3339408040046692,
|
| 1089 |
+
0.3329698443412781,
|
| 1090 |
+
0.2978940010070801,
|
| 1091 |
+
0.39562928676605225,
|
| 1092 |
+
0.3887198865413666,
|
| 1093 |
+
0.3585241734981537,
|
| 1094 |
+
0.31963711977005005,
|
| 1095 |
+
0.31140702962875366,
|
| 1096 |
+
0.3825278878211975,
|
| 1097 |
+
0.3666013181209564,
|
| 1098 |
+
0.24337142705917358,
|
| 1099 |
+
0.36876243352890015,
|
| 1100 |
+
0.36285972595214844,
|
| 1101 |
+
0.33685219287872314,
|
| 1102 |
+
0.3693651854991913,
|
| 1103 |
+
0.2829691767692566,
|
| 1104 |
+
0.30595219135284424,
|
| 1105 |
+
0.3462897539138794,
|
| 1106 |
+
0.49221092462539673,
|
| 1107 |
+
0.35862377285957336,
|
| 1108 |
+
0.2963302731513977,
|
| 1109 |
+
0.2860722541809082,
|
| 1110 |
+
0.33742064237594604,
|
| 1111 |
+
0.46653813123703003,
|
| 1112 |
+
0.4688953757286072,
|
| 1113 |
+
0.3590819835662842,
|
| 1114 |
+
0.3915751874446869,
|
| 1115 |
+
0.43592336773872375,
|
| 1116 |
+
0.4350709915161133,
|
| 1117 |
+
0.3321887254714966,
|
| 1118 |
+
0.2887817919254303,
|
| 1119 |
+
0.3315331041812897,
|
| 1120 |
+
0.34705182909965515,
|
| 1121 |
+
0.29848411679267883,
|
| 1122 |
+
0.3338838815689087,
|
| 1123 |
+
0.3551400601863861,
|
| 1124 |
+
0.2849394679069519,
|
| 1125 |
+
0.276046484708786,
|
| 1126 |
+
0.32644152641296387
|
| 1127 |
+
]
|
| 1128 |
+
},
|
| 1129 |
+
"fid": {
|
| 1130 |
+
"avg": 290.9468764478715,
|
| 1131 |
+
"vals": [
|
| 1132 |
+
317.3580192368432,
|
| 1133 |
+
285.3201578307224,
|
| 1134 |
+
227.35443388770344,
|
| 1135 |
+
284.40224085168455,
|
| 1136 |
+
250.42273416927597,
|
| 1137 |
+
332.1637800582306,
|
| 1138 |
+
421.4826897186527,
|
| 1139 |
+
175.65665020370423,
|
| 1140 |
+
462.80212160877016,
|
| 1141 |
+
313.11887539989885,
|
| 1142 |
+
216.874426261107,
|
| 1143 |
+
318.69325303966303,
|
| 1144 |
+
352.27748822517265,
|
| 1145 |
+
337.56780348968255,
|
| 1146 |
+
333.41852463427074,
|
| 1147 |
+
291.18619781007527,
|
| 1148 |
+
231.22305793958952,
|
| 1149 |
+
182.59524666699008,
|
| 1150 |
+
174.66560996071973,
|
| 1151 |
+
287.25586265039675,
|
| 1152 |
+
245.32536993670467,
|
| 1153 |
+
340.7025359683132,
|
| 1154 |
+
430.81577040883786,
|
| 1155 |
+
151.6384776188366,
|
| 1156 |
+
179.9751120645675,
|
| 1157 |
+
262.16017361838544,
|
| 1158 |
+
252.1575376021156,
|
| 1159 |
+
383.6182322995117,
|
| 1160 |
+
293.8831833204783,
|
| 1161 |
+
238.37454832172995,
|
| 1162 |
+
253.2247312385866,
|
| 1163 |
+
300.3330815210274,
|
| 1164 |
+
263.0408874830202,
|
| 1165 |
+
308.8415900149356,
|
| 1166 |
+
208.06900694354897,
|
| 1167 |
+
312.3743463947788,
|
| 1168 |
+
281.0093276692239,
|
| 1169 |
+
401.06852280176753,
|
| 1170 |
+
412.57353143782905,
|
| 1171 |
+
245.6395460326994,
|
| 1172 |
+
330.86620573911364,
|
| 1173 |
+
396.22633300703814,
|
| 1174 |
+
253.3246381630602,
|
| 1175 |
+
243.78755013247724,
|
| 1176 |
+
343.90170319814894,
|
| 1177 |
+
411.48625023760695,
|
| 1178 |
+
426.23161153775044,
|
| 1179 |
+
167.20827449948533,
|
| 1180 |
+
253.45101894407023,
|
| 1181 |
+
249.6648801639487,
|
| 1182 |
+
233.26966302525298,
|
| 1183 |
+
259.15476030131083
|
| 1184 |
+
]
|
| 1185 |
+
}
|
| 1186 |
+
}
|
| 1187 |
+
},
|
| 1188 |
+
"dead": {
|
| 1189 |
+
"ft": {
|
| 1190 |
+
"ssim": {
|
| 1191 |
+
"avg": 0.8463698829475202,
|
| 1192 |
+
"vals": [
|
| 1193 |
+
0.819096565246582,
|
| 1194 |
+
0.8230092525482178,
|
| 1195 |
+
0.8352133631706238,
|
| 1196 |
+
0.9176575541496277,
|
| 1197 |
+
0.8429977297782898,
|
| 1198 |
+
0.8367952704429626,
|
| 1199 |
+
0.8534295558929443,
|
| 1200 |
+
0.838138997554779,
|
| 1201 |
+
0.8528039455413818,
|
| 1202 |
+
0.8470211029052734,
|
| 1203 |
+
0.8170969486236572,
|
| 1204 |
+
0.8159563541412354,
|
| 1205 |
+
0.8493235111236572,
|
| 1206 |
+
0.8536948561668396,
|
| 1207 |
+
0.8294236063957214,
|
| 1208 |
+
0.90614253282547,
|
| 1209 |
+
0.8337429165840149,
|
| 1210 |
+
0.8743107914924622,
|
| 1211 |
+
0.8367175459861755,
|
| 1212 |
+
0.8243162631988525,
|
| 1213 |
+
0.8247227668762207,
|
| 1214 |
+
0.8266720771789551,
|
| 1215 |
+
0.9225577712059021,
|
| 1216 |
+
0.8386837840080261,
|
| 1217 |
+
0.8380022644996643,
|
| 1218 |
+
0.8386256098747253,
|
| 1219 |
+
0.8460710048675537,
|
| 1220 |
+
0.8459582924842834,
|
| 1221 |
+
0.835181713104248,
|
| 1222 |
+
0.8186705708503723,
|
| 1223 |
+
0.8184948563575745,
|
| 1224 |
+
0.8417790532112122,
|
| 1225 |
+
0.8489491939544678,
|
| 1226 |
+
0.8365125060081482,
|
| 1227 |
+
0.9095039963722229,
|
| 1228 |
+
0.8317481875419617,
|
| 1229 |
+
0.8886545300483704,
|
| 1230 |
+
0.8443787097930908
|
| 1231 |
+
]
|
| 1232 |
+
},
|
| 1233 |
+
"psnr": {
|
| 1234 |
+
"avg": 23.795184903621124,
|
| 1235 |
+
"vals": [
|
| 1236 |
+
22.123992625667,
|
| 1237 |
+
22.383136736110977,
|
| 1238 |
+
22.815137157510414,
|
| 1239 |
+
29.058328544219666,
|
| 1240 |
+
23.458824312069826,
|
| 1241 |
+
23.35115778300144,
|
| 1242 |
+
23.899720737917086,
|
| 1243 |
+
23.251382000597822,
|
| 1244 |
+
23.8556550518376,
|
| 1245 |
+
24.496927997944468,
|
| 1246 |
+
22.229396380994544,
|
| 1247 |
+
21.835705931258186,
|
| 1248 |
+
24.044979173806603,
|
| 1249 |
+
23.87095747368083,
|
| 1250 |
+
22.755647052287486,
|
| 1251 |
+
27.255022608267765,
|
| 1252 |
+
23.384111730889643,
|
| 1253 |
+
24.47027594982054,
|
| 1254 |
+
22.969660848990017,
|
| 1255 |
+
22.118529030027556,
|
| 1256 |
+
22.65373935684274,
|
| 1257 |
+
22.26046784537836,
|
| 1258 |
+
29.458407087230075,
|
| 1259 |
+
23.06010458055496,
|
| 1260 |
+
23.57190723261845,
|
| 1261 |
+
23.327430637055485,
|
| 1262 |
+
23.589256344441864,
|
| 1263 |
+
23.48439936889143,
|
| 1264 |
+
23.067392439373478,
|
| 1265 |
+
22.096704419878947,
|
| 1266 |
+
21.619810610831216,
|
| 1267 |
+
23.50251506589982,
|
| 1268 |
+
23.94463595644848,
|
| 1269 |
+
23.01993677097123,
|
| 1270 |
+
28.729724692646393,
|
| 1271 |
+
23.230957096762136,
|
| 1272 |
+
26.24107924644081,
|
| 1273 |
+
23.730008458437332
|
| 1274 |
+
]
|
| 1275 |
+
},
|
| 1276 |
+
"lpips": {
|
| 1277 |
+
"avg": 0.07999030775145481,
|
| 1278 |
+
"vals": [
|
| 1279 |
+
0.09683731198310852,
|
| 1280 |
+
0.10403041541576385,
|
| 1281 |
+
0.10699359327554703,
|
| 1282 |
+
0.02980189025402069,
|
| 1283 |
+
0.07957077771425247,
|
| 1284 |
+
0.08596174418926239,
|
| 1285 |
+
0.06019214540719986,
|
| 1286 |
+
0.07882507890462875,
|
| 1287 |
+
0.07737825810909271,
|
| 1288 |
+
0.08655022084712982,
|
| 1289 |
+
0.09879130125045776,
|
| 1290 |
+
0.12470565736293793,
|
| 1291 |
+
0.055472321808338165,
|
| 1292 |
+
0.06304865330457687,
|
| 1293 |
+
0.10779929161071777,
|
| 1294 |
+
0.030872609466314316,
|
| 1295 |
+
0.07079628109931946,
|
| 1296 |
+
0.06584230065345764,
|
| 1297 |
+
0.10193256288766861,
|
| 1298 |
+
0.09854554384946823,
|
| 1299 |
+
0.10243230313062668,
|
| 1300 |
+
0.10806915909051895,
|
| 1301 |
+
0.025890624150633812,
|
| 1302 |
+
0.08075970411300659,
|
| 1303 |
+
0.077939972281456,
|
| 1304 |
+
0.06577856838703156,
|
| 1305 |
+
0.07062944024801254,
|
| 1306 |
+
0.08819691836833954,
|
| 1307 |
+
0.10097920894622803,
|
| 1308 |
+
0.105897456407547,
|
| 1309 |
+
0.11288004368543625,
|
| 1310 |
+
0.06032264977693558,
|
| 1311 |
+
0.063026562333107,
|
| 1312 |
+
0.10948407649993896,
|
| 1313 |
+
0.030226124450564384,
|
| 1314 |
+
0.06960238516330719,
|
| 1315 |
+
0.053624995052814484,
|
| 1316 |
+
0.0899435430765152
|
| 1317 |
+
]
|
| 1318 |
+
},
|
| 1319 |
+
"fid": {
|
| 1320 |
+
"avg": 80.0617442123876,
|
| 1321 |
+
"vals": [
|
| 1322 |
+
77.72583958495193,
|
| 1323 |
+
108.01430563351622,
|
| 1324 |
+
83.80114371601998,
|
| 1325 |
+
43.79561757437186,
|
| 1326 |
+
64.21238376481085,
|
| 1327 |
+
107.90633495184883,
|
| 1328 |
+
94.90226901876251,
|
| 1329 |
+
81.66996436668452,
|
| 1330 |
+
56.44677259421701,
|
| 1331 |
+
133.90454895719694,
|
| 1332 |
+
94.53365874681484,
|
| 1333 |
+
103.25655943118333,
|
| 1334 |
+
56.07493043457935,
|
| 1335 |
+
76.20330391887549,
|
| 1336 |
+
86.46052802881235,
|
| 1337 |
+
35.14997667649733,
|
| 1338 |
+
62.05464557602466,
|
| 1339 |
+
82.75703146138048,
|
| 1340 |
+
78.68766046633303,
|
| 1341 |
+
95.9801870274251,
|
| 1342 |
+
69.69440837156752,
|
| 1343 |
+
133.97665503608607,
|
| 1344 |
+
23.992440096968796,
|
| 1345 |
+
73.95956883955756,
|
| 1346 |
+
101.69397444074366,
|
| 1347 |
+
96.09931824946253,
|
| 1348 |
+
44.15662245914518,
|
| 1349 |
+
83.0950768685819,
|
| 1350 |
+
113.04304804723927,
|
| 1351 |
+
123.68218057444115,
|
| 1352 |
+
80.64456666041356,
|
| 1353 |
+
51.48177417838611,
|
| 1354 |
+
61.27753437751468,
|
| 1355 |
+
89.59810342754149,
|
| 1356 |
+
64.8439294361588,
|
| 1357 |
+
63.560962531495846,
|
| 1358 |
+
68.41856099127594,
|
| 1359 |
+
75.58989355384159
|
| 1360 |
+
]
|
| 1361 |
+
}
|
| 1362 |
+
},
|
| 1363 |
+
"base": {
|
| 1364 |
+
"ssim": {
|
| 1365 |
+
"avg": 0.7146695460143843,
|
| 1366 |
+
"vals": [
|
| 1367 |
+
0.6920250058174133,
|
| 1368 |
+
0.6781535148620605,
|
| 1369 |
+
0.7473878860473633,
|
| 1370 |
+
0.7988237738609314,
|
| 1371 |
+
0.7046812176704407,
|
| 1372 |
+
0.6728485226631165,
|
| 1373 |
+
0.6648303866386414,
|
| 1374 |
+
0.6507443785667419,
|
| 1375 |
+
0.6572958827018738,
|
| 1376 |
+
0.7580508589744568,
|
| 1377 |
+
0.7515758872032166,
|
| 1378 |
+
0.759639322757721,
|
| 1379 |
+
0.7379322052001953,
|
| 1380 |
+
0.7911661267280579,
|
| 1381 |
+
0.7242826819419861,
|
| 1382 |
+
0.6797776222229004,
|
| 1383 |
+
0.7349328994750977,
|
| 1384 |
+
0.7590611577033997,
|
| 1385 |
+
0.6552085876464844,
|
| 1386 |
+
0.745119035243988,
|
| 1387 |
+
0.7352902293205261,
|
| 1388 |
+
0.7022021412849426,
|
| 1389 |
+
0.7449350357055664,
|
| 1390 |
+
0.7506275773048401,
|
| 1391 |
+
0.7204473614692688,
|
| 1392 |
+
0.666022002696991,
|
| 1393 |
+
0.7632265090942383,
|
| 1394 |
+
0.7753145098686218,
|
| 1395 |
+
0.6686270833015442,
|
| 1396 |
+
0.6356704831123352,
|
| 1397 |
+
0.6632815003395081,
|
| 1398 |
+
0.7420862317085266,
|
| 1399 |
+
0.6526774764060974,
|
| 1400 |
+
0.7651051878929138,
|
| 1401 |
+
0.7491373419761658,
|
| 1402 |
+
0.6927105784416199,
|
| 1403 |
+
0.7062166333198547,
|
| 1404 |
+
0.6603279113769531
|
| 1405 |
+
]
|
| 1406 |
+
},
|
| 1407 |
+
"psnr": {
|
| 1408 |
+
"avg": 17.075681593330103,
|
| 1409 |
+
"vals": [
|
| 1410 |
+
16.02221993829837,
|
| 1411 |
+
16.391894818273492,
|
| 1412 |
+
19.158898065865564,
|
| 1413 |
+
21.168902461771523,
|
| 1414 |
+
15.849282724811843,
|
| 1415 |
+
13.438762906336217,
|
| 1416 |
+
14.041096876172848,
|
| 1417 |
+
14.614652873958555,
|
| 1418 |
+
14.836499194441464,
|
| 1419 |
+
19.42617293541887,
|
| 1420 |
+
19.6028175608849,
|
| 1421 |
+
19.50721422138582,
|
| 1422 |
+
17.430967194872665,
|
| 1423 |
+
20.64046847536366,
|
| 1424 |
+
17.37638733796819,
|
| 1425 |
+
15.173079056741063,
|
| 1426 |
+
18.277006035337127,
|
| 1427 |
+
18.80925161485643,
|
| 1428 |
+
14.095576974205786,
|
| 1429 |
+
19.420502974160595,
|
| 1430 |
+
19.454135471046165,
|
| 1431 |
+
17.222296068007893,
|
| 1432 |
+
18.225982648513018,
|
| 1433 |
+
19.51789549326609,
|
| 1434 |
+
16.618582564745015,
|
| 1435 |
+
15.364057378648432,
|
| 1436 |
+
18.71637998732089,
|
| 1437 |
+
19.61078528607858,
|
| 1438 |
+
14.215511993158811,
|
| 1439 |
+
13.88753261057972,
|
| 1440 |
+
10.856355225225222,
|
| 1441 |
+
18.09338753925673,
|
| 1442 |
+
14.179305881135015,
|
| 1443 |
+
20.07084940675992,
|
| 1444 |
+
19.650702290598378,
|
| 1445 |
+
15.665586196724535,
|
| 1446 |
+
16.945129336951908,
|
| 1447 |
+
15.299770927402514
|
| 1448 |
+
]
|
| 1449 |
+
},
|
| 1450 |
+
"lpips": {
|
| 1451 |
+
"avg": 0.2726066539946355,
|
| 1452 |
+
"vals": [
|
| 1453 |
+
0.2813563346862793,
|
| 1454 |
+
0.3176368176937103,
|
| 1455 |
+
0.239101380109787,
|
| 1456 |
+
0.1326800435781479,
|
| 1457 |
+
0.31650233268737793,
|
| 1458 |
+
0.3085799217224121,
|
| 1459 |
+
0.32426244020462036,
|
| 1460 |
+
0.32414373755455017,
|
| 1461 |
+
0.3856937885284424,
|
| 1462 |
+
0.21604351699352264,
|
| 1463 |
+
0.2195548415184021,
|
| 1464 |
+
0.18589559197425842,
|
| 1465 |
+
0.21749232709407806,
|
| 1466 |
+
0.1575351059436798,
|
| 1467 |
+
0.2156955450773239,
|
| 1468 |
+
0.29644495248794556,
|
| 1469 |
+
0.22918514907360077,
|
| 1470 |
+
0.21251118183135986,
|
| 1471 |
+
0.37587040662765503,
|
| 1472 |
+
0.193213552236557,
|
| 1473 |
+
0.2757120132446289,
|
| 1474 |
+
0.31559503078460693,
|
| 1475 |
+
0.2390986531972885,
|
| 1476 |
+
0.2627783417701721,
|
| 1477 |
+
0.22325637936592102,
|
| 1478 |
+
0.35516589879989624,
|
| 1479 |
+
0.2559050917625427,
|
| 1480 |
+
0.2075577825307846,
|
| 1481 |
+
0.3593791425228119,
|
| 1482 |
+
0.39297720789909363,
|
| 1483 |
+
0.3490094244480133,
|
| 1484 |
+
0.22949230670928955,
|
| 1485 |
+
0.3761798143386841,
|
| 1486 |
+
0.20215649902820587,
|
| 1487 |
+
0.21151788532733917,
|
| 1488 |
+
0.33757245540618896,
|
| 1489 |
+
0.27517855167388916,
|
| 1490 |
+
0.3411214053630829
|
| 1491 |
+
]
|
| 1492 |
+
},
|
| 1493 |
+
"fid": {
|
| 1494 |
+
"avg": 171.8519597862991,
|
| 1495 |
+
"vals": [
|
| 1496 |
+
144.77019624546853,
|
| 1497 |
+
154.75288495007914,
|
| 1498 |
+
160.95209947656429,
|
| 1499 |
+
117.83869765679046,
|
| 1500 |
+
151.64615714342628,
|
| 1501 |
+
157.47977926729806,
|
| 1502 |
+
162.33033867519435,
|
| 1503 |
+
265.0183384600493,
|
| 1504 |
+
281.5918475384978,
|
| 1505 |
+
152.44583427606105,
|
| 1506 |
+
133.1601149299724,
|
| 1507 |
+
194.62961438661011,
|
| 1508 |
+
244.89585280015393,
|
| 1509 |
+
131.48030011431425,
|
| 1510 |
+
167.5752450544133,
|
| 1511 |
+
201.44646187475317,
|
| 1512 |
+
107.98531116987058,
|
| 1513 |
+
170.08420689868527,
|
| 1514 |
+
236.9186629037794,
|
| 1515 |
+
141.81304882152676,
|
| 1516 |
+
138.10872705204866,
|
| 1517 |
+
174.07596100327163,
|
| 1518 |
+
133.47905631898664,
|
| 1519 |
+
200.28084705986407,
|
| 1520 |
+
100.24977206742142,
|
| 1521 |
+
146.11378626732036,
|
| 1522 |
+
180.40072384521355,
|
| 1523 |
+
134.37882246824165,
|
| 1524 |
+
266.2742357839778,
|
| 1525 |
+
204.80154766025774,
|
| 1526 |
+
283.97677018987224,
|
| 1527 |
+
144.08380223026322,
|
| 1528 |
+
176.91537688202212,
|
| 1529 |
+
127.66897079587724,
|
| 1530 |
+
130.9521420141355,
|
| 1531 |
+
166.36337398681758,
|
| 1532 |
+
214.39949796639,
|
| 1533 |
+
129.03606564387414
|
| 1534 |
+
]
|
| 1535 |
+
}
|
| 1536 |
+
}
|
| 1537 |
+
}
|
| 1538 |
+
}
|
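The dump above closes the per-sequence results: each sequence ("aaa", "azri", "dead") carries an "ft" run and a "base" run, and each run reports ssim, psnr, lpips, and fid as an "avg" plus the per-frame "vals". A minimal sketch for summarizing it, assuming the file is saved as metrics.json (the file name is an assumption here, not read from this diff):

import json

with open("metrics.json") as f:
    metrics = json.load(f)

for seq, variants in metrics.items():           # "aaa", "azri", "dead"
    for variant, scores in variants.items():    # "ft" vs. "base"
        summary = ", ".join(f"{name}={entry['avg']:.4f}"
                            for name, entry in scores.items())
        print(f"{seq}/{variant}: {summary}")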
metrics.py
ADDED
|
@@ -0,0 +1,522 @@
|
| 1 |
+
import os
|
| 2 |
+
import pathlib
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import skimage
|
| 6 |
+
from imageio import imread
|
| 7 |
+
from scipy import linalg
|
| 8 |
+
from torch.nn.functional import adaptive_avg_pool2d
|
| 9 |
+
from skimage.metrics import structural_similarity as compare_ssim
|
| 10 |
+
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
|
| 11 |
+
import glob
|
| 12 |
+
import argparse
|
| 13 |
+
import matplotlib.pyplot as plt
|
| 14 |
+
from inception import InceptionV3
|
| 15 |
+
#from scripts.PerceptualSimilarity.models import dist_model as dm
|
| 16 |
+
import lpips
|
| 17 |
+
import pandas as pd
|
| 18 |
+
import json
|
| 19 |
+
import imageio
|
| 20 |
+
import cv2
|
| 21 |
+
print(skimage.__version__)
|
| 22 |
+
|
| 23 |
+
class FID():
|
| 24 |
+
"""docstring for FID
|
| 25 |
+
Calculates the Frechet Inception Distance (FID) to evaluate GANs
|
| 26 |
+
The FID metric calculates the distance between two distributions of images.
|
| 27 |
+
Typically, we have summary statistics (mean & covariance matrix) of one
|
| 28 |
+
of these distributions, while the 2nd distribution is given by a GAN.
|
| 29 |
+
When run as a stand-alone program, it compares the distribution of
|
| 30 |
+
images that are stored as PNG/JPEG at a specified location with a
|
| 31 |
+
distribution given by summary statistics (in pickle format).
|
| 32 |
+
The FID is calculated by assuming that X_1 and X_2 are the activations of
|
| 33 |
+
the pool_3 layer of the inception net for generated samples and real world
|
| 34 |
+
samples respectively.
|
| 35 |
+
See --help to see further details.
|
| 36 |
+
Code adapted from https://github.com/bioinf-jku/TTUR to use PyTorch instead
|
| 37 |
+
of Tensorflow
|
| 38 |
+
Copyright 2018 Institute of Bioinformatics, JKU Linz
|
| 39 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 40 |
+
you may not use this file except in compliance with the License.
|
| 41 |
+
You may obtain a copy of the License at
|
| 42 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 43 |
+
Unless required by applicable law or agreed to in writing, software
|
| 44 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 45 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 46 |
+
See the License for the specific language governing permissions and
|
| 47 |
+
limitations under the License.
|
| 48 |
+
"""
|
| 49 |
+
def __init__(self):
|
| 50 |
+
self.dims = 2048
|
| 51 |
+
self.batch_size = 128
|
| 52 |
+
self.cuda = True
|
| 53 |
+
self.verbose=False
|
| 54 |
+
|
| 55 |
+
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims]
|
| 56 |
+
self.model = InceptionV3([block_idx])
|
| 57 |
+
if self.cuda:
|
| 58 |
+
# TODO: put model into specific GPU
|
| 59 |
+
self.model.cuda()
|
| 60 |
+
|
| 61 |
+
# img_size added so that compute_statistics_of_path() below receives its
# required third argument; (256, 256) is only a placeholder default
def __call__(self, images, gt_path, img_size=(256, 256)):
|
| 62 |
+
""" images: list of the generated image. The values must lie between 0 and 1.
|
| 63 |
+
gt_path: the path of the ground truth images. The values must lie between 0 and 1.
|
| 64 |
+
"""
|
| 65 |
+
if not os.path.exists(gt_path):
|
| 66 |
+
raise RuntimeError('Invalid path: %s' % gt_path)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
print('calculate gt_path statistics...')
|
| 70 |
+
m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose, img_size)
|
| 71 |
+
print('calculate generated_images statistics...')
|
| 72 |
+
m2, s2 = self.calculate_activation_statistics(images, self.verbose)
|
| 73 |
+
fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
|
| 74 |
+
return fid_value
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def calculate_from_disk(self, generated_path, gt_path, img_size):
|
| 78 |
+
"""
|
| 79 |
+
"""
|
| 80 |
+
if not os.path.exists(gt_path):
|
| 81 |
+
raise RuntimeError('Invalid path: %s' % gt_path)
|
| 82 |
+
if not os.path.exists(generated_path):
|
| 83 |
+
raise RuntimeError('Invalid path: %s' % generated_path)
|
| 84 |
+
|
| 85 |
+
print('exp-path - ' + generated_path)
|
| 86 |
+
|
| 87 |
+
print('calculate gt_path statistics...')
|
| 88 |
+
m1, s1 = self.compute_statistics_of_path(gt_path, self.verbose, img_size)
|
| 89 |
+
print('calculate generated_path statistics...')
|
| 90 |
+
m2, s2 = self.compute_statistics_of_path(generated_path, self.verbose, img_size)
|
| 91 |
+
print('calculate frechet distance...')
|
| 92 |
+
fid_value = self.calculate_frechet_distance(m1, s1, m2, s2)
|
| 93 |
+
print('fid_distance %f' % (fid_value))
|
| 94 |
+
return fid_value
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def compute_statistics_of_path(self, path, verbose, img_size):
|
| 98 |
+
|
| 99 |
+
size_flag = '{}_{}'.format(img_size[0], img_size[1])
|
| 100 |
+
npz_file = os.path.join(path, size_flag + '_statistics.npz')
|
| 101 |
+
if os.path.exists(npz_file):
|
| 102 |
+
f = np.load(npz_file)
|
| 103 |
+
m, s = f['mu'][:], f['sigma'][:]
|
| 104 |
+
f.close()
|
| 105 |
+
|
| 106 |
+
else:
|
| 107 |
+
|
| 108 |
+
path = pathlib.Path(path)
|
| 109 |
+
files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
|
| 110 |
+
|
| 111 |
+
imgs = (np.array([(cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)) for fn in files]))/255.0
|
| 112 |
+
# Bring images to shape (B, 3, H, W)
|
| 113 |
+
imgs = imgs.transpose((0, 3, 1, 2))
|
| 114 |
+
|
| 115 |
+
# Rescale images to be between 0 and 1
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
m, s = self.calculate_activation_statistics(imgs, verbose)
|
| 119 |
+
np.savez(npz_file, mu=m, sigma=s)
|
| 120 |
+
|
| 121 |
+
return m, s
|
| 122 |
+
|
| 123 |
+
def calculate_activation_statistics(self, images, verbose):
|
| 124 |
+
"""Calculation of the statistics used by the FID.
|
| 125 |
+
Params:
|
| 126 |
+
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
|
| 127 |
+
must lie between 0 and 1.
|
| 128 |
+
-- model : Instance of inception model
|
| 129 |
+
-- batch_size : The images numpy array is split into batches with
|
| 130 |
+
batch size batch_size. A reasonable batch size
|
| 131 |
+
depends on the hardware.
|
| 132 |
+
-- dims : Dimensionality of features returned by Inception
|
| 133 |
+
-- cuda : If set to True, use GPU
|
| 134 |
+
-- verbose : If set to True and parameter out_step is given, the
|
| 135 |
+
number of calculated batches is reported.
|
| 136 |
+
Returns:
|
| 137 |
+
-- mu : The mean over samples of the activations of the pool_3 layer of
|
| 138 |
+
the inception model.
|
| 139 |
+
-- sigma : The covariance matrix of the activations of the pool_3 layer of
|
| 140 |
+
the inception model.
|
| 141 |
+
"""
|
| 142 |
+
act = self.get_activations(images, verbose)
|
| 143 |
+
mu = np.mean(act, axis=0)
|
| 144 |
+
sigma = np.cov(act, rowvar=False)
|
| 145 |
+
return mu, sigma
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_activations(self, images, verbose=False):
|
| 150 |
+
"""Calculates the activations of the pool_3 layer for all images.
|
| 151 |
+
Params:
|
| 152 |
+
-- images : Numpy array of dimension (n_images, 3, hi, wi). The values
|
| 153 |
+
must lie between 0 and 1.
|
| 154 |
+
-- model : Instance of inception model
|
| 155 |
+
-- batch_size : the images numpy array is split into batches with
|
| 156 |
+
batch size batch_size. A reasonable batch size depends
|
| 157 |
+
on the hardware.
|
| 158 |
+
-- dims : Dimensionality of features returned by Inception
|
| 159 |
+
-- cuda : If set to True, use GPU
|
| 160 |
+
-- verbose : If set to True and parameter out_step is given, the number
|
| 161 |
+
of calculated batches is reported.
|
| 162 |
+
Returns:
|
| 163 |
+
-- A numpy array of dimension (num images, dims) that contains the
|
| 164 |
+
activations of the given tensor when feeding inception with the
|
| 165 |
+
query tensor.
|
| 166 |
+
"""
|
| 167 |
+
self.model.eval()
|
| 168 |
+
|
| 169 |
+
d0 = images.shape[0]
|
| 170 |
+
if self.batch_size > d0:
|
| 171 |
+
print(('Warning: batch size is bigger than the data size. '
|
| 172 |
+
'Setting batch size to data size'))
|
| 173 |
+
self.batch_size = d0
|
| 174 |
+
|
| 175 |
+
n_batches = d0 // self.batch_size
|
| 176 |
+
n_used_imgs = n_batches * self.batch_size
|
| 177 |
+
|
| 178 |
+
pred_arr = np.empty((n_used_imgs, self.dims))
|
| 179 |
+
for i in range(n_batches):
|
| 180 |
+
if verbose:
|
| 181 |
+
print('\rPropagating batch %d/%d' % (i + 1, n_batches))
|
| 182 |
+
# end='', flush=True)
|
| 183 |
+
start = i * self.batch_size
|
| 184 |
+
end = start + self.batch_size
|
| 185 |
+
|
| 186 |
+
batch = torch.from_numpy(images[start:end]).type(torch.FloatTensor)
|
| 187 |
+
# batch = Variable(batch, volatile=True)
|
| 188 |
+
if self.cuda:
|
| 189 |
+
batch = batch.cuda()
|
| 190 |
+
|
| 191 |
+
pred = self.model(batch)[0]
|
| 192 |
+
|
| 193 |
+
# If model output is not scalar, apply global spatial average pooling.
|
| 194 |
+
# This happens if you choose a dimensionality not equal 2048.
|
| 195 |
+
if pred.shape[2] != 1 or pred.shape[3] != 1:
|
| 196 |
+
pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
|
| 197 |
+
|
| 198 |
+
pred_arr[start:end] = pred.detach().cpu().numpy().reshape(self.batch_size, -1)
|
| 199 |
+
|
| 200 |
+
if verbose:
|
| 201 |
+
print(' done')
|
| 202 |
+
|
| 203 |
+
return pred_arr
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def calculate_frechet_distance(self, mu1, sigma1, mu2, sigma2, eps=1e-6):
|
| 207 |
+
"""Numpy implementation of the Frechet Distance.
|
| 208 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
| 209 |
+
and X_2 ~ N(mu_2, C_2) is
|
| 210 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
| 211 |
+
Stable version by Dougal J. Sutherland.
|
| 212 |
+
Params:
|
| 213 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
| 214 |
+
inception net (like returned by the function 'get_predictions')
|
| 215 |
+
for generated samples.
|
| 216 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
| 217 |
+
representative data set.
|
| 218 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
| 219 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
| 220 |
+
representative data set.
|
| 221 |
+
Returns:
|
| 222 |
+
-- : The Frechet Distance.
|
| 223 |
+
"""
|
| 224 |
+
|
| 225 |
+
mu1 = np.atleast_1d(mu1)
|
| 226 |
+
mu2 = np.atleast_1d(mu2)
|
| 227 |
+
|
| 228 |
+
sigma1 = np.atleast_2d(sigma1)
|
| 229 |
+
sigma2 = np.atleast_2d(sigma2)
|
| 230 |
+
|
| 231 |
+
assert mu1.shape == mu2.shape, \
|
| 232 |
+
'Training and test mean vectors have different lengths'
|
| 233 |
+
assert sigma1.shape == sigma2.shape, \
|
| 234 |
+
'Training and test covariances have different dimensions'
|
| 235 |
+
|
| 236 |
+
diff = mu1 - mu2
|
| 237 |
+
|
| 238 |
+
# Product might be almost singular
|
| 239 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
| 240 |
+
if not np.isfinite(covmean).all():
|
| 241 |
+
msg = ('fid calculation produces singular product; '
|
| 242 |
+
'adding %s to diagonal of cov estimates') % eps
|
| 243 |
+
print(msg)
|
| 244 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
| 245 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
| 246 |
+
|
| 247 |
+
# Numerical error might give slight imaginary component
|
| 248 |
+
if np.iscomplexobj(covmean):
|
| 249 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
| 250 |
+
m = np.max(np.abs(covmean.imag))
|
| 251 |
+
raise ValueError('Imaginary component {}'.format(m))
|
| 252 |
+
covmean = covmean.real
|
| 253 |
+
|
| 254 |
+
tr_covmean = np.trace(covmean)
|
| 255 |
+
|
| 256 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
| 257 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class Reconstruction_Metrics():
|
| 261 |
+
def __init__(self, metric_list=['ssim', 'psnr', 'l1', 'mae'], data_range=1, win_size=51, multichannel=True):
|
| 262 |
+
self.data_range = data_range
|
| 263 |
+
self.win_size = win_size
|
| 264 |
+
self.multichannel = multichannel
|
| 265 |
+
for metric in metric_list:
|
| 266 |
+
if metric in ['ssim', 'psnr', 'l1', 'mae']:
|
| 267 |
+
setattr(self, metric, True)
|
| 268 |
+
else:
|
| 269 |
+
print('unsupported reconstruction metric: %s' % metric)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def __call__(self, inputs, gts):
|
| 273 |
+
"""
|
| 274 |
+
inputs: the generated image, size (b,c,w,h), data range(0, data_range)
|
| 275 |
+
gts: the ground-truth image, size (b,c,w,h), data range(0, data_range)
|
| 276 |
+
"""
|
| 277 |
+
result = dict()
|
| 278 |
+
[b,n,w,h] = inputs.size()
|
| 279 |
+
inputs = inputs.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
|
| 280 |
+
gts = gts.view(b*n, w, h).detach().cpu().numpy().astype(np.float32).transpose(1,2,0)
|
| 281 |
+
|
| 282 |
+
if hasattr(self, 'ssim'):
|
| 283 |
+
ssim_value = compare_ssim(inputs, gts, data_range=self.data_range,
|
| 284 |
+
win_size=self.win_size, channel_axis=-1)  # channel_axis replaces the multichannel kwarg removed from recent scikit-image
|
| 285 |
+
result['ssim'] = ssim_value
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
if hasattr(self, 'psnr'):
|
| 289 |
+
psnr_value = compare_psnr(inputs, gts, self.data_range)
|
| 290 |
+
result['psnr'] = psnr_value
|
| 291 |
+
|
| 292 |
+
if hasattr(self, 'l1'):
|
| 293 |
+
l1_value = compare_l1(inputs, gts)
|
| 294 |
+
result['l1'] = l1_value
|
| 295 |
+
|
| 296 |
+
if hasattr(self, 'mae'):
|
| 297 |
+
mae_value = compare_mae(inputs, gts)
|
| 298 |
+
result['mae'] = mae_value
|
| 299 |
+
return result
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def calculate_from_disk(self, inputs, gts, save_path=None, img_size=(176,256), sort=True, debug=0):
|
| 303 |
+
"""
|
| 304 |
+
inputs: .txt files, folders, image files (string), image files (list)
|
| 305 |
+
gts: .txt files, folders, image files (string), image files (list)
|
| 306 |
+
"""
|
| 307 |
+
if sort:
|
| 308 |
+
input_image_list = sorted(get_image_list(inputs))
|
| 309 |
+
gt_image_list = sorted(get_image_list(gts))
|
| 310 |
+
else:
|
| 311 |
+
input_image_list = get_image_list(inputs)
|
| 312 |
+
gt_image_list = get_image_list(gts)
|
| 313 |
+
|
| 314 |
+
size_flag = '{}_{}'.format(img_size[0], img_size[1])
|
| 315 |
+
npz_file = os.path.join(save_path, size_flag + '_metrics.npz')
|
| 316 |
+
if os.path.exists(npz_file):
|
| 317 |
+
f = np.load(npz_file)
|
| 318 |
+
psnr,ssim,ssim_256,mae,l1=f['psnr'],f['ssim'],f['ssim_256'],f['mae'],f['l1']
|
| 319 |
+
else:
|
| 320 |
+
psnr = []
|
| 321 |
+
ssim = []
|
| 322 |
+
ssim_256 = []
|
| 323 |
+
mae = []
|
| 324 |
+
l1 = []
|
| 325 |
+
names = []
|
| 326 |
+
|
| 327 |
+
for index in range(len(input_image_list)):
|
| 328 |
+
name = os.path.basename(input_image_list[index])
|
| 329 |
+
names.append(name)
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
img_gt = (cv2.resize(imread(str(gt_image_list[index])).astype(np.float32), img_size,interpolation=cv2.INTER_CUBIC)) /255.0
|
| 333 |
+
img_pred = (cv2.resize(imread(str(input_image_list[index])).astype(np.float32), img_size,interpolation=cv2.INTER_CUBIC)) / 255.0
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
if debug != 0:
|
| 337 |
+
plt.subplot(121)  # integer spec; the string form is rejected by current matplotlib
|
| 338 |
+
plt.imshow(img_gt)
|
| 339 |
+
plt.title('Ground truth')
|
| 340 |
+
plt.subplot(122)
|
| 341 |
+
plt.imshow(img_pred)
|
| 342 |
+
plt.title('Output')
|
| 343 |
+
plt.show()
|
| 344 |
+
|
| 345 |
+
psnr.append(compare_psnr(img_gt, img_pred, data_range=self.data_range))
|
| 346 |
+
ssim.append(compare_ssim(img_gt, img_pred, data_range=self.data_range,
|
| 347 |
+
win_size=self.win_size, channel_axis=2))  # multichannel dropped: passing it alongside channel_axis fails on recent scikit-image
|
| 348 |
+
mae.append(compare_mae(img_gt, img_pred))
|
| 349 |
+
l1.append(compare_l1(img_gt, img_pred))
|
| 350 |
+
|
| 351 |
+
img_gt_256 = img_gt*255.0
|
| 352 |
+
img_pred_256 = img_pred*255.0
|
| 353 |
+
ssim_256.append(compare_ssim(img_gt_256, img_pred_256, gaussian_weights=True, sigma=1.2,
|
| 354 |
+
use_sample_covariance=False, channel_axis=2,
|
| 355 |
+
data_range=img_pred_256.max() - img_pred_256.min()))
|
| 356 |
+
|
| 357 |
+
if np.mod(index, 200) == 0:
|
| 358 |
+
print(
|
| 359 |
+
str(index) + ' images processed',
|
| 360 |
+
"PSNR: %.4f" % round(np.mean(psnr), 4),
|
| 361 |
+
"SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
|
| 362 |
+
"MAE: %.4f" % round(np.mean(mae), 4),
|
| 363 |
+
"l1: %.4f" % round(np.mean(l1), 4),
|
| 364 |
+
)
|
| 365 |
+
|
| 366 |
+
if save_path:
|
| 367 |
+
np.savez(save_path + '/' + size_flag + '_metrics.npz', psnr=psnr, ssim=ssim, ssim_256=ssim_256, mae=mae, l1=l1, names=names)
|
| 368 |
+
|
| 369 |
+
print(
|
| 370 |
+
"PSNR: %.4f" % round(np.mean(psnr), 4),
|
| 371 |
+
"PSNR Variance: %.4f" % round(np.var(psnr), 4),
|
| 372 |
+
"SSIM_256: %.4f" % round(np.mean(ssim_256), 4),
|
| 373 |
+
"SSIM_256 Variance: %.4f" % round(np.var(ssim_256), 4),
|
| 374 |
+
"MAE: %.4f" % round(np.mean(mae), 4),
|
| 375 |
+
"MAE Variance: %.4f" % round(np.var(mae), 4),
|
| 376 |
+
"l1: %.4f" % round(np.mean(l1), 4),
|
| 377 |
+
"l1 Variance: %.4f" % round(np.var(l1), 4)
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
dic = {"psnr":[round(np.mean(psnr), 6)],
|
| 381 |
+
"psnr_variance": [round(np.var(psnr), 6)],
|
| 382 |
+
"ssim_256": [round(np.mean(ssim_256), 6)],
|
| 383 |
+
"ssim_256_variance": [round(np.var(ssim_256), 6)],
|
| 384 |
+
"mae": [round(np.mean(mae), 6)],
|
| 385 |
+
"mae_variance": [round(np.var(mae), 6)],
|
| 386 |
+
"l1": [round(np.mean(l1), 6)],
|
| 387 |
+
"l1_variance": [round(np.var(l1), 6)] }
|
| 388 |
+
|
| 389 |
+
return dic
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def get_image_list(flist):
|
| 393 |
+
if isinstance(flist, list):
|
| 394 |
+
return flist
|
| 395 |
+
|
| 396 |
+
# flist: image file path, image directory path, text file flist path
|
| 397 |
+
if isinstance(flist, str):
|
| 398 |
+
if os.path.isdir(flist):
|
| 399 |
+
flist = list(glob.glob(flist + '/*.jpg')) + list(glob.glob(flist + '/*.png'))
|
| 400 |
+
flist.sort()
|
| 401 |
+
return flist
|
| 402 |
+
|
| 403 |
+
if os.path.isfile(flist):
|
| 404 |
+
try:
|
| 405 |
+
return np.genfromtxt(flist, dtype=str)  # np.str was removed in NumPy 1.24
|
| 406 |
+
except Exception:
|
| 407 |
+
return [flist]
|
| 408 |
+
print('cannot read files from %s, returning an empty list' % flist)
|
| 409 |
+
return []
|
| 410 |
+
|
| 411 |
+
def compare_l1(img_true, img_test):
|
| 412 |
+
img_true = img_true.astype(np.float32)
|
| 413 |
+
img_test = img_test.astype(np.float32)
|
| 414 |
+
return np.mean(np.abs(img_true - img_test))
|
| 415 |
+
|
| 416 |
+
def compare_mae(img_true, img_test):
|
| 417 |
+
img_true = img_true.astype(np.float32)
|
| 418 |
+
img_test = img_test.astype(np.float32)
|
| 419 |
+
return np.sum(np.abs(img_true - img_test)) / np.sum(img_true + img_test)
|
| 420 |
+
|
| 421 |
+
def preprocess_path_for_deform_task(gt_path, distorted_path):
|
| 422 |
+
distorted_image_list = sorted(get_image_list(distorted_path))
|
| 423 |
+
gt_list=[]
|
| 424 |
+
distorted_list = []
|
| 425 |
+
|
| 426 |
+
for distorted_image in distorted_image_list:
|
| 427 |
+
image = os.path.basename(distorted_image)[1:]
|
| 428 |
+
image = image.split('_to_')[-1]
|
| 429 |
+
gt_image = gt_path + '/' + image.replace('jpg', 'png')
|
| 430 |
+
if not os.path.isfile(gt_image):
|
| 431 |
+
print(distorted_image, gt_image)
|
| 432 |
+
print('=====')
|
| 433 |
+
continue
|
| 434 |
+
gt_list.append(gt_image)
|
| 435 |
+
distorted_list.append(distorted_image)
|
| 436 |
+
|
| 437 |
+
return gt_list, distorted_list
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
|
| 441 |
+
class LPIPS():
|
| 442 |
+
def __init__(self, use_gpu=True):
|
| 443 |
+
|
| 444 |
+
# build the LPIPS (AlexNet) model; move it to GPU only when requested
self.model = lpips.LPIPS(net='alex').eval()
if use_gpu:
    self.model = self.model.cuda()
|
| 445 |
+
self.use_gpu=use_gpu
|
| 446 |
+
|
| 447 |
+
def __call__(self, image_1, image_2):
|
| 448 |
+
"""
|
| 449 |
+
image_1: images with size (n, 3, w, h) with value [-1, 1]
|
| 450 |
+
image_2: images with size (n, 3, w, h) with value [-1, 1]
|
| 451 |
+
"""
|
| 452 |
+
result = self.model.forward(image_1, image_2)
|
| 453 |
+
return result
|
| 454 |
+
|
| 455 |
+
def calculate_from_disk(self, path_1, path_2,img_size, batch_size=64, verbose=False, sort=True):
|
| 456 |
+
|
| 457 |
+
if sort:
|
| 458 |
+
files_1 = sorted(get_image_list(path_1))
|
| 459 |
+
files_2 = sorted(get_image_list(path_2))
|
| 460 |
+
else:
|
| 461 |
+
files_1 = get_image_list(path_1)
|
| 462 |
+
files_2 = get_image_list(path_2)
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
results=[]
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
d0 = len(files_1)
|
| 469 |
+
if batch_size > d0:
|
| 470 |
+
print(('Warning: batch size is bigger than the data size. '
|
| 471 |
+
'Setting batch size to data size'))
|
| 472 |
+
batch_size = d0
|
| 473 |
+
|
| 474 |
+
n_batches = d0 // batch_size
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
for i in range(n_batches):
|
| 478 |
+
if verbose:
|
| 479 |
+
print('\rPropagating batch %d/%d' % (i + 1, n_batches))
|
| 480 |
+
# end='', flush=True)
|
| 481 |
+
start = i * batch_size
|
| 482 |
+
end = start + batch_size
|
| 483 |
+
|
| 484 |
+
imgs_1 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)/255.0 for fn in files_1[start:end]])
|
| 485 |
+
imgs_2 = np.array([cv2.resize(imread(str(fn)).astype(np.float32),img_size,interpolation=cv2.INTER_CUBIC)/255.0 for fn in files_2[start:end]])
|
| 486 |
+
|
| 487 |
+
imgs_1 = imgs_1.transpose((0, 3, 1, 2))
|
| 488 |
+
imgs_2 = imgs_2.transpose((0, 3, 1, 2))
|
| 489 |
+
|
| 490 |
+
img_1_batch = torch.from_numpy(imgs_1).type(torch.FloatTensor)
|
| 491 |
+
img_2_batch = torch.from_numpy(imgs_2).type(torch.FloatTensor)
|
| 492 |
+
|
| 493 |
+
if self.use_gpu:
|
| 494 |
+
img_1_batch = img_1_batch.cuda()
|
| 495 |
+
img_2_batch = img_2_batch.cuda()
|
| 496 |
+
|
| 497 |
+
with torch.no_grad():
|
| 498 |
+
result = self.model.forward(img_1_batch, img_2_batch)
|
| 499 |
+
|
| 500 |
+
results.append(result)
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
distance = torch.cat(results,0)[:,0,0,0].mean()
|
| 504 |
+
|
| 505 |
+
print('lpips: %.3f'%distance)
|
| 506 |
+
return distance
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
|
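The deform-task helpers and the LPIPS wrapper above combine into a small evaluation loop. A minimal usage sketch, assuming hypothetical folders ./gt and ./generated holding same-sized ground-truth and generated images (neither path comes from this commit):

# Sketch only: './generated' and './gt' are placeholder paths.
gt_list, gen_list = preprocess_path_for_deform_task('./gt', './generated')
lpips_metric = LPIPS(use_gpu=True)
# get_image_list() passes lists through unchanged, so the filtered lists can be
# fed straight to calculate_from_disk(); img_size is (width, height) for cv2.resize.
score = lpips_metric.calculate_from_disk(gen_list, gt_list, img_size=(256, 256), batch_size=32, verbose=True)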
pose-frames.py
ADDED
@@ -0,0 +1,16 @@
+#from annotator.dwpose import DWposeDetector
+from easy_dwpose import DWposeDetector
+from PIL import Image
+
+
+device = "cpu"
+dwpose = DWposeDetector(device=device)
+
+for n in range(1, 46):
+    pil_image = Image.open("videos/dance2/frame ("+str(n)+").png").convert("RGB")
+    #skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
+
+    out_img, pose = dwpose(pil_image, include_hands=True, include_face=True)
+
+    print(pose['bodies'])
+    out_img.save('videos/dance'+str(n)+'.png')
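If the frame count changes, a glob-based loop (a sketch, not part of this commit) avoids the hard-coded range(1, 46); the numeric sort key keeps "frame (10)" after "frame (2)":

import glob, os, re
# Sketch: reuses the dwpose detector and Image import from pose-frames.py above.
frames = sorted(glob.glob("videos/dance2/frame (*).png"),
                key=lambda p: int(re.search(r"\((\d+)\)", p).group(1)))
for path in frames:
    img = Image.open(path).convert("RGB")
    out_img, pose = dwpose(img, include_hands=True, include_face=True)
    out_img.save(os.path.join("videos", "pose_" + os.path.basename(path)))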
pose.py
ADDED
@@ -0,0 +1,15 @@
+#from annotator.dwpose import DWposeDetector
+from easy_dwpose import DWposeDetector
+from PIL import Image
+
+
+device = "cpu"
+dwpose = DWposeDetector(device=device)
+
+pil_image = Image.open("imgs/baggy.png").convert("RGB")
+#skeleton = dwpose(pil_image, output_type="np", include_hands=True, include_face=False)
+
+out_img, _ = dwpose(pil_image, include_hands=True, include_face=False)
+
+#print(pose['bodies'])
+out_img.save("pose.png")
requirements.txt
ADDED
@@ -0,0 +1,8 @@
+easy-dwpose
+diffusers
+controlnet-aux
+transformers
+accelerate
+gradio
+rembg[cpu]
+spaces
run_stage1.sh
ADDED
@@ -0,0 +1,18 @@
+accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --use_deepspeed --num_processes 8 \
+  stage1_train_prior_model.py \
+  --pretrained_model_name_or_path="kandinsky-community/kandinsky-2-2-prior" \
+  --image_encoder_path='{image_encoder_path}' \
+  --img_path='{image_path}' \
+  --json_path='{data.json}' \
+  --output_dir="{output_dir}" \
+  --img_height=512 \
+  --img_width=512 \
+  --train_batch_size=128 \
+  --gradient_accumulation_steps=1 \
+  --max_train_steps=100000 \
+  --noise_offset=0.1 \
+  --learning_rate=1e-05 \
+  --weight_decay=0.01 \
+  --lr_scheduler="constant" --num_warmup_steps=2000 \
+  --checkpointing_steps=5000 \
+  --seed 42
run_stage2.sh
ADDED
@@ -0,0 +1,18 @@
+
+accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --use_deepspeed --mixed_precision="fp16" stage2_train_inpaint_model.py \
+  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
+  --image_encoder_p_path='facebook/dinov2-giant' \
+  --image_encoder_g_path='{image_encoder_path}' \
+  --json_path='{data.json}' \
+  --image_root_path="{image_path}" \
+  --output_dir="{output_dir}" \
+  --img_height=512 \
+  --img_width=512 \
+  --learning_rate=1e-4 \
+  --train_batch_size=8 \
+  --max_train_steps=1000000 \
+  --mixed_precision="fp16" \
+  --checkpointing_steps=5000 \
+  --noise_offset=0.1 \
+  --lr_warmup_steps 5000 \
+  --seed 42
run_stage3.sh
ADDED
@@ -0,0 +1,15 @@
+accelerate launch --gpu_ids 0,1,2,3,4,5,6,7 --num_processes 8 --use_deepspeed --mixed_precision="fp16" stage3_train_refined_model.py \
+  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
+  --image_encoder_path='facebook/dinov2-giant' \
+  --img_path='{image_path}' \
+  --json_path='{data.json}' \
+  --gen_t_img_path='{stage2_generate}' \
+  --output_dir="{output_dir}" \
+  --learning_rate=1e-5 \
+  --train_batch_size=16 \
+  --max_train_steps=1000000 \
+  --mixed_precision="fp16" \
+  --checkpointing_steps=5000 \
+  --noise_offset=0.1 \
+  --report_to=tensorboard \
+  --lr_warmup_steps 5000
run_test_stage1.sh
ADDED
@@ -0,0 +1,9 @@
+python3 stage1_batchtest_prior_model.py \
+  --pretrained_model_name_or_path="kandinsky-community/kandinsky-2-2-prior" \
+  --image_encoder_path="{image_encoder_path}" \
+  --img_path='{image_path}' \
+  --json_path='{data.json}' \
+  --pose_path="{normalized_pose_txt}" \
+  --save_path="./logs/view_stage1/512_512" \
+  --weights_name="{save_ckpt}"
+
run_test_stage2.sh
ADDED
@@ -0,0 +1,13 @@
+python3 stage2_batchtest_inpaint_model.py \
+  --img_weigh 512 \
+  --img_height 512 \
+  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
+  --image_encoder_g_path='{image_encoder_path}' \
+  --image_encoder_p_path='facebook/dinov2-giant' \
+  --img_path='{image_path}' \
+  --json_path='{data.json}' \
+  --pose_path="{pose_path}" \
+  --target_embed_path="./logs/view_stage1/512_512/" \
+  --save_path="./logs/view_stage2/512_512" \
+  --weights_name="{save_ckpt}" \
+  --calculate_metrics
run_test_stage3.sh
ADDED
@@ -0,0 +1,12 @@
+python3 stage3_batchtest_refined_model.py \
+  --img_weigh 512 \
+  --img_height 512 \
+  --pretrained_model_name_or_path="stabilityai/stable-diffusion-2-1-base" \
+  --image_encoder_p_path='facebook/dinov2-giant' \
+  --img_path='{image_path}' \
+  --json_path='{data.json}' \
+  --pose_path="{pose_path}" \
+  --gen_t_img_path="./logs/view_stage2/512_512/" \
+  --save_path="./logs/view_stage3/512_512" \
+  --weights_name="{save_ckpt}" \
+  --calculate_metrics
sd.py
ADDED
@@ -0,0 +1,13 @@
+from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
+import torch
+
+model_id = "stabilityai/stable-diffusion-2-1-base"
+
+scheduler = EulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
+pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float16)
+pipe = pipe.to("cpu")
+
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt).images[0]
+
+image.save("astronaut_rides_horse.png")
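One caveat in sd.py: the weights are loaded in float16 but the pipeline is moved to "cpu", and half precision is generally only usable on a GPU (several CPU ops either error or run very slowly). A hedged fix for CPU-only runs:

# Sketch: on CPU, load the pipeline in full precision instead of float16.
pipe = StableDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, torch_dtype=torch.float32)
pipe = pipe.to("cpu")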
setup.txt
ADDED
@@ -0,0 +1,41 @@
+pip install -U xformers --index-url https://download.pytorch.org/whl/cu118
+pip install easy-dwpose
+pip install diffusers==0.24.0
+pip install controlnet-aux==0.0.7
+pip install transformers==4.32.1
+pip install accelerate==0.24.1
+pip install huggingface-hub==0.25.2
+
+pip install gdown
+
+pip install -U openmim
+mim install mmengine
+mim install "mmcv==2.1.0"
+mim install "mmdet==3.3.0"
+mim install "mmpose==1.3.2"
+
+apt-get install libgl1
+
+
+#s1
+gdown https://drive.google.com/uc?id=11a5-a5C8NWA4m6i1g099fQQrdohz5gQ1
+
+#s2
+gdown https://drive.google.com/uc?id=1JhWeScr9bQtoQmB503VDyomaDmxBHail
+
+#s3
+gdown https://drive.google.com/uc?id=11JZXfYVlgLFqmE8jCbLWjQWO7LrwYq-I
+
+#demo
+gdown https://drive.google.com/uc?id=1JFFy_FBxOFuGFBcB6xMIVwcQb8bfnpO9
+
+
+
+pip install numpy==1.26.4
+pip install -U xformers --index-url https://download.pytorch.org/whl/cu118
+
+For PyTorch 2.4.0 (untested), use:
+pip install xformers==0.0.28.dev895 or pip install xformers==0.0.28.dev893
+
+About 48 GB of RAM is needed for fine-tuning.
+
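A quick way to confirm the pinned OpenMMLab stack from setup.txt installed correctly is an import check (a sketch; run it in the same environment):

# Sketch: should print 2.1.0, 3.3.0, 1.3.2 if the mim installs above succeeded.
import mmcv, mmdet, mmpose
print(mmcv.__version__, mmdet.__version__, mmpose.__version__)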
single_extract_pose.py
ADDED
@@ -0,0 +1,35 @@
+from src.controlnet_aux import DWposeDetector
+from PIL import Image
+import torchvision.transforms as transforms
+import torch
+
+def init_dwpose_detector(device):
+    # specify configs, ckpts and device, or it will be downloaded automatically and use cpu by default
+    det_config = './src/configs/yolox_l_8xb8-300e_coco.py'
+    det_ckpt = './ckpts/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'
+    pose_config = './src/configs/dwpose-l_384x288.py'
+    pose_ckpt = './ckpts/dw-ll_ucoco_384.pth'
+
+    dwpose_model = DWposeDetector(
+        det_config=det_config,
+        det_ckpt=det_ckpt,
+        pose_config=pose_config,
+        pose_ckpt=pose_ckpt,
+        device=device
+    )
+    return dwpose_model.to(device)
+
+
+def inference_pose(img_path, image_size=(1024, 1024)):
+    device = torch.device(f"cuda:{0}")
+    model = init_dwpose_detector(device=device)
+    pil_image = Image.open(img_path).convert("RGB").resize(image_size, Image.BICUBIC)
+    dwpose_image = model(pil_image, output_type='np', image_resolution=image_size[1])
+    save_dwpose_image = Image.fromarray(dwpose_image)
+    return save_dwpose_image
+
+
+inference_pose('imgs/test.png').save("pose.png")
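Note that inference_pose() rebuilds the detector on every call, which reloads both checkpoints each time. When extracting poses for many images it is cheaper to initialize once and reuse the instance; a sketch under that assumption:

# Sketch: initialize the detector once, then loop over image paths.
device = torch.device("cuda:0")
model = init_dwpose_detector(device=device)
for img_path in ["imgs/test.png"]:  # replace with your own list of paths
    pil_image = Image.open(img_path).convert("RGB").resize((1024, 1024), Image.BICUBIC)
    pose_np = model(pil_image, output_type='np', image_resolution=1024)
    Image.fromarray(pose_np).save(img_path.replace(".png", "_pose.png"))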
src/__init__.py
ADDED
File without changes
src/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (178 Bytes).
src/configs/dwpose-l_384x288.py
ADDED
@@ -0,0 +1,257 @@
+# runtime
+max_epochs = 270
+stage2_num_epochs = 30
+base_lr = 4e-3
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=21)
+
+# optimizer
+optim_wrapper = dict(
+    type='OptimWrapper',
+    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
+    paramwise_cfg=dict(
+        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
+
+# learning rate
+param_scheduler = [
+    dict(
+        type='LinearLR',
+        start_factor=1.0e-5,
+        by_epoch=False,
+        begin=0,
+        end=1000),
+    dict(
+        # use cosine lr from 150 to 300 epoch
+        type='CosineAnnealingLR',
+        eta_min=base_lr * 0.05,
+        begin=max_epochs // 2,
+        end=max_epochs,
+        T_max=max_epochs // 2,
+        by_epoch=True,
+        convert_to_iter_based=True),
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=512)
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(288, 384),
+    sigma=(6., 6.93),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        _scope_='mmdet',
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=1.,
+        widen_factor=1.,
+        out_indices=(4, ),
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU'),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'  # noqa
+        )),
+    head=dict(
+        type='RTMCCHead',
+        in_channels=1024,
+        out_channels=133,
+        input_size=codec['input_size'],
+        in_featuremap_size=(9, 12),
+        simcc_split_ratio=codec['simcc_split_ratio'],
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            hidden_dims=256,
+            s=128,
+            expansion_factor=2,
+            dropout_rate=0.,
+            drop_path=0.,
+            act_fn='SiLU',
+            use_rel_bias=False,
+            pos_enc=False),
+        loss=dict(
+            type='KLDiscretLoss',
+            use_target_weight=True,
+            beta=10.,
+            label_softmax=True),
+        decoder=codec),
+    test_cfg=dict(flip_test=True, ))
+
+# base dataset settings
+dataset_type = 'CocoWholeBodyDataset'
+data_mode = 'topdown'
+data_root = '/data/'
+
+backend_args = dict(backend='local')
+# backend_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
+#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
+#     }))
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage', backend_args=backend_args),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(
+        type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='mmdet.YOLOXHSVRandomAug'),
+    dict(
+        type='Albumentation',
+        transforms=[
+            dict(type='Blur', p=0.1),
+            dict(type='MedianBlur', p=0.1),
+            dict(
+                type='CoarseDropout',
+                max_holes=1,
+                max_height=0.4,
+                max_width=0.4,
+                min_holes=1,
+                min_height=0.2,
+                min_width=0.2,
+                p=1.0),
+        ]),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage', backend_args=backend_args),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+train_pipeline_stage2 = [
+    dict(type='LoadImage', backend_args=backend_args),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(
+        type='RandomBBoxTransform',
+        shift_factor=0.,
+        scale_factor=[0.75, 1.25],
+        rotate_factor=60),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='mmdet.YOLOXHSVRandomAug'),
+    dict(
+        type='Albumentation',
+        transforms=[
+            dict(type='Blur', p=0.1),
+            dict(type='MedianBlur', p=0.1),
+            dict(
+                type='CoarseDropout',
+                max_holes=1,
+                max_height=0.4,
+                max_width=0.4,
+                min_holes=1,
+                min_height=0.2,
+                min_width=0.2,
+                p=0.5),
+        ]),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+
+datasets = []
+dataset_coco = dict(
+    type=dataset_type,
+    data_root=data_root,
+    data_mode=data_mode,
+    ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
+    data_prefix=dict(img='coco/train2017/'),
+    pipeline=[],
+)
+datasets.append(dataset_coco)
+
+scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
+         'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
+         'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']
+
+for i in range(len(scene)):
+    datasets.append(
+        dict(
+            type=dataset_type,
+            data_root=data_root,
+            data_mode=data_mode,
+            ann_file='UBody/annotations/'+scene[i]+'/keypoint_annotation.json',
+            data_prefix=dict(img='UBody/images/'+scene[i]+'/'),
+            pipeline=[],
+        )
+    )
+
+# data loaders
+train_dataloader = dict(
+    batch_size=32,
+    num_workers=10,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type='CombinedDataset',
+        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
+        datasets=datasets,
+        pipeline=train_pipeline,
+        test_mode=False,
+    ))
+val_dataloader = dict(
+    batch_size=32,
+    num_workers=10,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
+        bbox_file=f'{data_root}coco/person_detection_results/'
+        'COCO_val2017_detections_AP_H_56_person.json',
+        data_prefix=dict(img='coco/val2017/'),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = val_dataloader
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(
+        save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
+
+custom_hooks = [
+    dict(
+        type='EMAHook',
+        ema_type='ExpMomentumEMA',
+        momentum=0.0002,
+        update_buffers=True,
+        priority=49),
+    dict(
+        type='mmdet.PipelineSwitchHook',
+        switch_epoch=max_epochs - stage2_num_epochs,
+        switch_pipeline=train_pipeline_stage2)
+]
+
+# evaluators
+val_evaluator = dict(
+    type='CocoWholeBodyMetric',
+    ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
+test_evaluator = val_evaluator
src/configs/stage1_config.py
ADDED
@@ -0,0 +1,181 @@
+
+
+import os
+import argparse
+
+
+
+
+parser = argparse.ArgumentParser(description="Simple example of a training script.")
+parser.add_argument(
+    "--pretrained_model_name_or_path",
+    type=str,
+    default=None,
+    required=True,
+    help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+parser.add_argument(
+    "--pretrained_image_model_path",
+    type=str,
+    default=None,
+    help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+parser.add_argument(
+    "--pretrained_pose_model_path",
+    type=str,
+    default=None,
+    help="Path to pretrained model or model identifier from huggingface.co/models.",
+)
+
+parser.add_argument(
+    "--unet_config_file",
+    type=str,
+    default=None,
+    help="Config file of UNet model",
+)
+parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/train_data.json", help="json path", )
+parser.add_argument("--img_path", type=str, default="./datasets/deepfashing/all_data_png/", help="image path", )
+parser.add_argument("--image_encoder_path", type=str, default="./OpenCLIP-ViT-H-14",
+                    help="Path to pretrained model or model identifier from huggingface.co/models.", )
+parser.add_argument("--img_width", type=int, default=512, help="width", )
+parser.add_argument("--img_height", type=int, default=512, help="height", )
+parser.add_argument(
+    "--output_dir",
+    type=str,
+    default="sd-model-finetuned",
+    help="The output directory where the model predictions and checkpoints will be written.",
+)
+parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
+
+parser.add_argument(
+    "--center_crop",
+    action="store_true",
+    help="Whether to center crop images before resizing to resolution (if not set, random crop will be used)",
+)
+parser.add_argument(
+    "--random_flip",
+    action="store_true",
+    help="whether to randomly flip images horizontally",
+)
+parser.add_argument(
+    '--clip_penultimate',
+    type=bool,
+    default=False,
+    help='Use penultimate CLIP layer for text embedding'
+)
+parser.add_argument(
+    "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
+)
+parser.add_argument("--num_train_epochs", type=int, default=100000000)
+parser.add_argument(
+    "--max_train_steps",
+    type=int,
+    default=None,
+    help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+)
+parser.add_argument(
+    "--gradient_accumulation_steps",
+    type=int,
+    default=1,
+    help="Number of updates steps to accumulate before performing a backward/update pass.",
+)
+parser.add_argument(
+    "--gradient_checkpointing",
+    action="store_true",
+    help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+)
+parser.add_argument(
+    "--learning_rate",
+    type=float,
+    default=1e-4,
+    help="Initial learning rate (after the potential warmup period) to use.",
+)
+parser.add_argument(
+    "--weight_decay",
+    type=float,
+    default=0.01,
+    help="Weight decay to use.",
+)
+parser.add_argument(
+    "--lr_scheduler",
+    type=str,
+    default="linear",
+    help="The scheduler type to use.",
+    choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"],
+)
+parser.add_argument(
+    "--num_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+)
+parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
+parser.add_argument(
+    "--logging_dir",
+    type=str,
+    default="logs",
+    help=(
+        "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+        " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+    ),
+)
+parser.add_argument(
+    "--print_freq",
+    type=int,
+    default=1,
+    help="Print training logs every N steps.",
+)
+parser.add_argument(
+    "--report_to",
+    type=str,
+    default="tensorboard",
+    help=(
+        'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+        ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+        " Only applicable when `--with_tracking` is passed."
+    ),
+)
+parser.add_argument(
+    "--checkpointing_steps",
+    type=int,
+    default=500,
+    help="Save the training state every N update steps.",
+)
+parser.add_argument(
+    "--resume_from_checkpoint",
+    type=str,
+    default=None,
+    help="If the training should continue from a checkpoint folder.",
+)
+parser.add_argument(
+    "--unet_init_ckpt",
+    type=str,
+    default=None,
+    help="Path to an initial UNet checkpoint to load before training.",
+)
+
+parser.add_argument(
+    "--mixed_precision",
+    type=str,
+    default="fp16",
+    choices=["no", "fp16", "bf16"],
+    help=(
+        "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+        " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+        " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+    ),
+)
+parser.add_argument(
+    "--enable_xformers_memory_efficient_attention",
+    action="store_true",
+    help="Whether or not to use xformers.",
+)
+parser.add_argument(
+    "--max_grad_norm", default=10.0, type=float, help="Max gradient norm."
+)
+parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
+parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+args = parser.parse_args()
+print(args)
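Because parse_args() runs at module import time, importing this config consumes sys.argv of the current process; the stage-1 training and test scripts presumably rely on that. A minimal sketch of the intended use:

# Sketch: the flags from run_stage1.sh must be on the command line when this import executes.
from src.configs.stage1_config import args
print(args.img_width, args.img_height, args.learning_rate)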
src/configs/stage2_config.py
ADDED
@@ -0,0 +1,192 @@
+import argparse
+parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+parser.add_argument(
+    "--pretrained_model_name_or_path",
+    type=str,
+    default="stabilityai/stable-diffusion-2-1-base",
+    help="Path to pretrained model or model identifier from huggingface.co/models.",)
+
+parser.add_argument(
+    "--seed", type=int, default=42, help="A seed for reproducible training."
+)
+
+parser.add_argument(
+    "--train_batch_size",
+    type=int,
+    default=8,
+    help="Batch size (per device) for the training dataloader.",
+)
+parser.add_argument("--num_train_epochs", type=int, default=10000)
+parser.add_argument(
+    "--max_train_steps",
+    type=int,
+    default=100,
+    help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+)
+parser.add_argument(
+    "--checkpointing_steps",
+    type=int,
+    default=1000,
+    help=(
+        "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+        "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+        "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+        "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
+        "instructions."
+    ),
+)
+parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/train_data.json", help="json path", )
+parser.add_argument("--image_root_path", type=str, default="./datasets/deepfashing/all_data_png/", help="image path", )
+parser.add_argument("--image_encoder_g_path", type=str, default="./OpenCLIP-ViT-H-14",
+                    help="Path to pretrained model or model identifier from huggingface.co/models.", )
+parser.add_argument("--image_encoder_p_path", type=str, default="./dinov2-giant",
+                    help="Path to pretrained model or model identifier from huggingface.co/models.", )
+parser.add_argument("--output_dir", type=str, default="out/",
+                    help="The output directory where the model predictions and checkpoints will be written.", )
+parser.add_argument("--img_width", type=int, default=512, help="image width", )
+parser.add_argument("--img_height", type=int, default=512, help="image height", )
+parser.add_argument(
+    "--resume_from_checkpoint",
+    type=str,
+    default="pcdms_ckpt.pt",
+    help=(
+        "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+        ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+    ),
+)
+parser.add_argument(
+    "--set_grads_to_none",
+    action="store_true",
+    help=(
+        "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+        " behaviors, so disable this argument if it causes any problems. More info:"
+        " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+    ),
+)
+parser.add_argument(
+    "--gradient_accumulation_steps",
+    type=int,
+    default=1,
+    help="Number of updates steps to accumulate before performing a backward/update pass.",
+)
+parser.add_argument(
+    "--gradient_checkpointing",
+    action="store_true",
+    help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+)
+parser.add_argument(
+    "--learning_rate",
+    type=float,
+    default=1e-4,  # 5e-6,
+    help="Initial learning rate (after the potential warmup period) to use.",
+)
+
+parser.add_argument(
+    "--scale_lr",
+    action="store_true",
+    default=False,
+    help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+)
+parser.add_argument(
+    "--lr_scheduler",
+    type=str,
+    default="constant",
+    help=(
+        'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+        ' "constant", "constant_with_warmup"]'
+    ),
+)
+parser.add_argument(
+    "--lr_warmup_steps",
+    type=int,
+    default=5000,
+    help="Number of steps for the warmup in the lr scheduler.",
+)
+parser.add_argument(
+    "--lr_num_cycles",
+    type=int,
+    default=1,
+    help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+)
+parser.add_argument(
+    "--lr_power",
+    type=float,
+    default=1.0,
+    help="Power factor of the polynomial scheduler.",
+)
+
+
+parser.add_argument(
+    "--adam_beta1",
+    type=float,
+    default=0.9,
+    help="The beta1 parameter for the Adam optimizer.",
+)
+parser.add_argument(
+    "--adam_beta2",
+    type=float,
+    default=0.999,
+    help="The beta2 parameter for the Adam optimizer.",
+)
+parser.add_argument(
+    "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
+)
+parser.add_argument(
+    "--adam_epsilon",
+    type=float,
+    default=1e-08,
+    help="Epsilon value for the Adam optimizer",
+)
+parser.add_argument(
+    "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+)
+parser.add_argument(
+    "--logging_dir",
+    type=str,
+    default="logs",
+    help=(
+        "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+        " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+    ),
+)
+parser.add_argument(
+    "--report_to",
+    type=str,
+    default="tensorboard",
+    help=(
+        'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
+        ' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
+        " Only applicable when `--with_tracking` is passed."
+    ),
+)
+parser.add_argument(
+    "--allow_tf32",
+    action="store_true",
+    help=(
+        "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+        " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+    ),
+)
+parser.add_argument(
+    "--mixed_precision",
+    type=str,
+    default="fp16",
+    choices=["no", "fp16", "bf16"],
+    help=(
+        "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+        " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+        " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+    ),
+)
+
+
+parser.add_argument("--noise_offset", type=float, default=0.1, help="The scale of noise offset.")
+
+
+
+args = parser.parse_args()
+print(args)
src/configs/stage3_config.py
ADDED
@@ -0,0 +1,217 @@
+import argparse
+parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
+parser.add_argument(
+    "--pretrained_model_name_or_path",
+    type=str,
+    default=None,
+    required=True,
+    help="Path to pretrained model or model identifier from huggingface.co/models.",)
+
+parser.add_argument("--revision", type=str, default=None, required=False, help=(
+    "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
+    " float32 precision."),)
+parser.add_argument("--json_path", type=str, default="./datasets/deepfashing/test_data.json", help="json path", )
+parser.add_argument("--img_path", type=str, default="./datasets/deepfashing/train_all_png/", help="image path", )
+parser.add_argument("--gen_t_img_path", type=str, default="./save_data/stage2/guidancescale2_seed42_numsteps20/", help="gen target image path", )
+parser.add_argument("--image_encoder_path", type=str, default="./dinov2-giant",
+                    help="Path to pretrained model or model identifier from huggingface.co/models.", )
+parser.add_argument("--output_dir", type=str, default="controlnet-model",
+                    help="The output directory where the model predictions and checkpoints will be written.", )
+
+
+parser.add_argument(
+    "--seed", type=int, default=None, help="A seed for reproducible training."
+)
+parser.add_argument(
+    "--resolution",
+    type=int,
+    default=512,
+    help=(
+        "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+        " resolution"
+    ),
+)
+parser.add_argument(
+    "--train_batch_size",
+    type=int,
+    default=4,
+    help="Batch size (per device) for the training dataloader.",
+)
+parser.add_argument("--num_train_epochs", type=int, default=1)
+parser.add_argument("--noise_level", type=int, default=250)
+
+parser.add_argument(
+    "--max_train_steps",
+    type=int,
+    default=None,
+    help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+)
+parser.add_argument(
+    "--checkpointing_steps",
+    type=int,
+    default=500,
+    help=(
+        "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
+        "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference. "
+        "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components. "
+        "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step "
+        "instructions."
+    ),
+)
+
+parser.add_argument(
+    "--resume_from_checkpoint",
+    type=str,
+    default=None,
+    help=(
+        "Whether training should be resumed from a previous checkpoint. Use a path saved by"
+        ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
+    ),
+)
+parser.add_argument(
+    "--gradient_accumulation_steps",
+    type=int,
+    default=1,
+    help="Number of updates steps to accumulate before performing a backward/update pass.",
+)
+parser.add_argument(
+    "--gradient_checkpointing",
+    action="store_true",
+    help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+)
+parser.add_argument(
+    "--learning_rate",
+    type=float,
+    default=5e-6,
+    help="Initial learning rate (after the potential warmup period) to use.",
+)
+
+parser.add_argument(
+    "--scale_lr",
+    action="store_true",
+    default=False,
+    help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+)
+parser.add_argument(
+    "--lr_scheduler",
+    type=str,
+    default="constant",
+    help=(
+        'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+        ' "constant", "constant_with_warmup"]'
+    ),
+)
+parser.add_argument(
+    "--lr_warmup_steps",
+    type=int,
+    default=500,
+    help="Number of steps for the warmup in the lr scheduler.",
+)
+parser.add_argument(
+    "--lr_num_cycles",
+    type=int,
+    default=1,
+    help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
+)
+parser.add_argument(
+    "--lr_power",
+    type=float,
+    default=1.0,
+    help="Power factor of the polynomial scheduler.",
+)
+
+
+parser.add_argument(
+    "--adam_beta1",
+    type=float,
+    default=0.9,
+    help="The beta1 parameter for the Adam optimizer.",
+)
+parser.add_argument(
+    "--adam_beta2",
+    type=float,
+    default=0.999,
+    help="The beta2 parameter for the Adam optimizer.",
+)
+parser.add_argument(
+    "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
+)
+parser.add_argument(
+    "--adam_epsilon",
+    type=float,
+    default=1e-08,
+    help="Epsilon value for the Adam optimizer",
+)
+parser.add_argument(
+    "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
+)
+
+
+parser.add_argument(
+    "--logging_dir",
+    type=str,
+    default="logs",
+    help=(
+        "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+        " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+    ),
+)
+parser.add_argument(
+    "--allow_tf32",
+    action="store_true",
+    help=(
+        "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
+        " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
+    ),
+)
+parser.add_argument(
+    "--report_to",
+    type=str,
+    default="tensorboard",
+    help=(
+        'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
+        ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
+    ),
+)
+parser.add_argument(
+    "--mixed_precision",
+    type=str,
+    default=None,
+    choices=["no", "fp16", "bf16"],
+    help=(
+        "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). bf16 requires PyTorch >="
+        " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
+        " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
+    ),
+)
+
+parser.add_argument(
+    "--set_grads_to_none",
+    action="store_true",
+    help=(
+        "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
+        " behaviors, so disable this argument if it causes any problems. More info:"
+        " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
+    ),
+)
+
+parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
+
+parser.add_argument(
+    "--tracker_project_name",
+    type=str,
+    default="train_baseline",
+    help=(
+        "The `project_name` argument passed to Accelerator.init_trackers; for"
+        " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
+    ),
+)
+
+
+args = parser.parse_args()
+print(args)
+if args.resolution % 8 != 0:
+    raise ValueError(
+        "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+    )
src/configs/yolox_l_8xb8-300e_coco.py
ADDED
@@ -0,0 +1,245 @@
+img_scale = (640, 640)  # width, height
+
+# model settings
+model = dict(
+    type='YOLOX',
+    data_preprocessor=dict(
+        type='DetDataPreprocessor',
+        pad_size_divisor=32,
+        batch_augments=[
+            dict(
+                type='BatchSyncRandomResize',
+                random_size_range=(480, 800),
+                size_divisor=32,
+                interval=10)
+        ]),
+    backbone=dict(
+        type='CSPDarknet',
+        deepen_factor=1.0,
+        widen_factor=1.0,
+        out_indices=(2, 3, 4),
+        use_depthwise=False,
+        spp_kernal_sizes=(5, 9, 13),
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish'),
+    ),
+    neck=dict(
+        type='YOLOXPAFPN',
+        in_channels=[256, 512, 1024],
+        out_channels=256,
+        num_csp_blocks=3,
+        use_depthwise=False,
+        upsample_cfg=dict(scale_factor=2, mode='nearest'),
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish')),
+    bbox_head=dict(
+        type='YOLOXHead',
+        num_classes=80,
+        in_channels=256,
+        feat_channels=256,
+        stacked_convs=2,
+        strides=(8, 16, 32),
+        use_depthwise=False,
+        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+        act_cfg=dict(type='Swish'),
+        loss_cls=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_bbox=dict(
+            type='IoULoss',
+            mode='square',
+            eps=1e-16,
+            reduction='sum',
+            loss_weight=5.0),
+        loss_obj=dict(
+            type='CrossEntropyLoss',
+            use_sigmoid=True,
+            reduction='sum',
+            loss_weight=1.0),
+        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
+    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+    # In order to align the source code, the threshold of the val phase is
+    # 0.01, and the threshold of the test phase is 0.001.
+    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+
+# dataset settings
+data_root = 'data/coco/'
+dataset_type = 'CocoDataset'
+
+# Example to use different file client
+# Method 1: simply set the data root and let the file I/O module
+# automatically infer from prefix (not support LMDB and Memcache yet)
+
+# data_root = 's3://openmmlab/datasets/detection/coco/'
+
+# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
+# backend_args = dict(
+#     backend='petrel',
+#     path_mapping=dict({
+#         './data/': 's3://openmmlab/datasets/detection/',
+#         'data/': 's3://openmmlab/datasets/detection/'
+#     }))
+backend_args = None
+
+train_pipeline = [
+    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
+    dict(
+        type='RandomAffine',
+        scaling_ratio_range=(0.1, 2),
+        # img_scale is (width, height)
+        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
+    dict(
+        type='MixUp',
+        img_scale=img_scale,
+        ratio_range=(0.8, 1.6),
+        pad_val=114.0),
+    dict(type='YOLOXHSVRandomAug'),
+    dict(type='RandomFlip', prob=0.5),
+    # According to the official implementation, multi-scale
+    # training is not considered here but in the
+    # 'mmdet/models/detectors/yolox.py'.
+    # Resize and Pad are for the last 15 epochs when Mosaic,
+    # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
+    dict(type='Resize', scale=img_scale, keep_ratio=True),
+    dict(
+        type='Pad',
+        pad_to_square=True,
+        # If the image is three-channel, the pad value needs
+        # to be set separately for each channel.
+        pad_val=dict(img=(114.0, 114.0, 114.0))),
+    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
+    dict(type='PackDetInputs')
+]
+
+train_dataset = dict(
+    # use MultiImageMixDataset wrapper to support mosaic and mixup
+    type='MultiImageMixDataset',
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='annotations/instances_train2017.json',
+        data_prefix=dict(img='train2017/'),
+        pipeline=[
+            dict(type='LoadImageFromFile', backend_args=backend_args),
+            dict(type='LoadAnnotations', with_bbox=True)
+        ],
+        filter_cfg=dict(filter_empty_gt=False, min_size=32),
+        backend_args=backend_args),
+    pipeline=train_pipeline)
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=backend_args),
+    dict(type='Resize', scale=img_scale, keep_ratio=True),
+    dict(
+        type='Pad',
+        pad_to_square=True,
+        pad_val=dict(img=(114.0, 114.0, 114.0))),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+train_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=train_dataset)
+val_dataloader = dict(
+    batch_size=8,
+    num_workers=4,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        ann_file='annotations/instances_val2017.json',
+        data_prefix=dict(img='val2017/'),
+        test_mode=True,
+        pipeline=test_pipeline,
+        backend_args=backend_args))
+test_dataloader = val_dataloader
+
+val_evaluator = dict(
+    type='CocoMetric',
+    ann_file=data_root + 'annotations/instances_val2017.json',
+    metric='bbox',
+    backend_args=backend_args)
+test_evaluator = val_evaluator
+
+# training settings
+max_epochs = 300
+num_last_epochs = 15
+interval = 10
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=interval)
+
+# optimizer
+# default 8 gpu
+base_lr = 0.01
+optim_wrapper = dict(
+    type='OptimWrapper',
+
optimizer=dict(
|
| 188 |
+
type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
|
| 189 |
+
nesterov=True),
|
| 190 |
+
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
|
| 191 |
+
|
| 192 |
+
# learning rate
|
| 193 |
+
param_scheduler = [
|
| 194 |
+
dict(
|
| 195 |
+
# use quadratic formula to warm up 5 epochs
|
| 196 |
+
# and lr is updated by iteration
|
| 197 |
+
# TODO: fix default scope in get function
|
| 198 |
+
type='mmdet.QuadraticWarmupLR',
|
| 199 |
+
by_epoch=True,
|
| 200 |
+
begin=0,
|
| 201 |
+
end=5,
|
| 202 |
+
convert_to_iter_based=True),
|
| 203 |
+
dict(
|
| 204 |
+
# use cosine lr from 5 to 285 epoch
|
| 205 |
+
type='CosineAnnealingLR',
|
| 206 |
+
eta_min=base_lr * 0.05,
|
| 207 |
+
begin=5,
|
| 208 |
+
T_max=max_epochs - num_last_epochs,
|
| 209 |
+
end=max_epochs - num_last_epochs,
|
| 210 |
+
by_epoch=True,
|
| 211 |
+
convert_to_iter_based=True),
|
| 212 |
+
dict(
|
| 213 |
+
# use fixed lr during last 15 epochs
|
| 214 |
+
type='ConstantLR',
|
| 215 |
+
by_epoch=True,
|
| 216 |
+
factor=1,
|
| 217 |
+
begin=max_epochs - num_last_epochs,
|
| 218 |
+
end=max_epochs,
|
| 219 |
+
)
|
| 220 |
+
]
|
| 221 |
+
|
| 222 |
+
default_hooks = dict(
|
| 223 |
+
checkpoint=dict(
|
| 224 |
+
interval=interval,
|
| 225 |
+
max_keep_ckpts=3 # only keep latest 3 checkpoints
|
| 226 |
+
))
|
| 227 |
+
|
| 228 |
+
custom_hooks = [
|
| 229 |
+
dict(
|
| 230 |
+
type='YOLOXModeSwitchHook',
|
| 231 |
+
num_last_epochs=num_last_epochs,
|
| 232 |
+
priority=48),
|
| 233 |
+
dict(type='SyncNormHook', priority=48),
|
| 234 |
+
dict(
|
| 235 |
+
type='EMAHook',
|
| 236 |
+
ema_type='ExpMomentumEMA',
|
| 237 |
+
momentum=0.0001,
|
| 238 |
+
update_buffers=True,
|
| 239 |
+
priority=49)
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
|
| 243 |
+
# USER SHOULD NOT CHANGE ITS VALUES.
|
| 244 |
+
# base_batch_size = (8 GPUs) x (8 samples per GPU)
|
| 245 |
+
auto_scale_lr = dict(base_batch_size=64)
|
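For reference: this is a standard mmdetection-style Python config, so it can be loaded and inspected with mmengine's Config. A minimal sketch (assumes mmengine is installed; the file path is a placeholder, adjust it to wherever the config lives in the repo):

from mmengine.config import Config

cfg = Config.fromfile('yolox_l_8xb8-300e_coco.py')  # placeholder path
print(cfg.model.type)              # 'YOLOX'
print(cfg.train_cfg.val_interval)  # 10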
src/controlnet_aux/__init__.py
ADDED
@@ -0,0 +1,18 @@
__version__ = "0.0.6"

from .hed import HEDdetector
from .leres import LeresDetector
from .lineart import LineartDetector
from .lineart_anime import LineartAnimeDetector
from .midas import MidasDetector
from .mlsd import MLSDdetector
from .normalbae import NormalBaeDetector
from .open_pose import OpenposeDetector
from .pidi import PidiNetDetector
from .zoe import ZoeDetector

from .canny import CannyDetector
from .mediapipe_face import MediapipeFaceDetector
from .segment_anything import SamDetector
from .shuffle import ContentShuffleDetector
from .dwpose import DWposeDetector
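A note for readers: this __init__ re-exports every detector at the package root, so downstream code needs only one import line. A sketch (assumes the repo root is on PYTHONPATH):

from src.controlnet_aux import CannyDetector, DWposeDetector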
src/controlnet_aux/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (1.17 kB)
src/controlnet_aux/__pycache__/util.cpython-311.pyc
ADDED
Binary file (13 kB)
src/controlnet_aux/canny/__init__.py
ADDED
@@ -0,0 +1,36 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from ..util import HWC3, resize_image

class CannyDetector:
    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, image_resolution=512, output_type=None, **kwargs):
        if "img" in kwargs:
            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
            input_image = kwargs.pop("img")

        if input_image is None:
            raise ValueError("input_image must be defined.")

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        detected_map = cv2.Canny(input_image, low_threshold, high_threshold)
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
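For reference, a minimal usage sketch of the detector above ('input.jpg' is a placeholder path; a PIL input makes output_type default to "pil", per the branch in __call__):

from PIL import Image
from src.controlnet_aux.canny import CannyDetector

detector = CannyDetector()
edge_map = detector(Image.open('input.jpg'), low_threshold=100, high_threshold=200)  # PIL in, PIL out
edge_map.save('canny_edges.png')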
src/controlnet_aux/canny/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (2.11 kB)
src/controlnet_aux/dwpose/__init__.py
ADDED
@@ -0,0 +1,92 @@
# Openpose
# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
# 2nd Edited by https://github.com/Hzzone/pytorch-openpose
# 3rd Edited by ControlNet
# 4th Edited by ControlNet (added face and correct hands)

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import cv2
import torch
import numpy as np
from PIL import Image

from ..util import HWC3, resize_image
from . import util


def draw_pose(pose, H, W):
    bodies = pose['bodies']
    faces = pose['faces']
    hands = pose['hands']
    candidate = bodies['candidate']
    subset = bodies['subset']

    canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
    canvas = util.draw_bodypose(canvas, candidate, subset)
    canvas = util.draw_handpose(canvas, hands)
    # canvas = util.draw_facepose(canvas, faces)

    return canvas

class DWposeDetector:
    def __init__(self, det_config=None, det_ckpt=None, pose_config=None, pose_ckpt=None, device="cpu"):
        from .wholebody import Wholebody

        self.pose_estimation = Wholebody(det_config, det_ckpt, pose_config, pose_ckpt, device)

    def to(self, device):
        self.pose_estimation.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, output_type="pil", **kwargs):

        input_image = cv2.cvtColor(np.array(input_image, dtype=np.uint8), cv2.COLOR_RGB2BGR)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        H, W, C = input_image.shape

        with torch.no_grad():
            candidate, subset = self.pose_estimation(input_image)
            nums, keys, locs = candidate.shape
            candidate[..., 0] /= float(W)
            candidate[..., 1] /= float(H)
            body = candidate[:, :18].copy()
            body = body.reshape(nums * 18, locs)
            score = subset[:, :18]

            for i in range(len(score)):
                for j in range(len(score[i])):
                    if score[i][j] > 0.3:
                        score[i][j] = int(18 * i + j)
                    else:
                        score[i][j] = -1

            un_visible = subset < 0.3
            candidate[un_visible] = -1

            foot = candidate[:, 18:24]

            faces = candidate[:, 24:92]

            hands = candidate[:, 92:113]
            hands = np.vstack([hands, candidate[:, 113:]])

            bodies = dict(candidate=body, subset=score)
            pose = dict(bodies=bodies, hands=hands, faces=faces)

            detected_map = draw_pose(pose, H, W)
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
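For reference, a minimal usage sketch of DWposeDetector ('person.jpg' is a placeholder; with all config/checkpoint arguments left as None, Wholebody falls back to its bundled YOLOX and DWPose defaults, see wholebody.py further down):

from PIL import Image
from src.controlnet_aux.dwpose import DWposeDetector

pose = DWposeDetector(device='cpu')
skeleton = pose(Image.open('person.jpg'))  # returns a PIL image of the drawn body+hand skeleton
skeleton.save('pose.png')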
src/controlnet_aux/dwpose/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (4.84 kB)
src/controlnet_aux/dwpose/__pycache__/util.cpython-311.pyc
ADDED
Binary file (16.5 kB)
src/controlnet_aux/dwpose/__pycache__/wholebody.cpython-311.pyc
ADDED
Binary file (6.2 kB)
src/controlnet_aux/dwpose/dwpose_config/dwpose-l_384x288.py
ADDED
@@ -0,0 +1,257 @@
# runtime
max_epochs = 270
stage2_num_epochs = 30
base_lr = 4e-3

train_cfg = dict(max_epochs=max_epochs, val_interval=10)
randomness = dict(seed=21)

# optimizer
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
    paramwise_cfg=dict(
        norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))

# learning rate
param_scheduler = [
    dict(
        type='LinearLR',
        start_factor=1.0e-5,
        by_epoch=False,
        begin=0,
        end=1000),
    dict(
        # use cosine lr from epoch 135 (max_epochs // 2) to 270
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=max_epochs // 2,
        end=max_epochs,
        T_max=max_epochs // 2,
        by_epoch=True,
        convert_to_iter_based=True),
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# codec settings
codec = dict(
    type='SimCCLabel',
    input_size=(288, 384),
    sigma=(6., 6.93),
    simcc_split_ratio=2.0,
    normalize=False,
    use_dark=False)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        _scope_='mmdet',
        type='CSPNeXt',
        arch='P5',
        expand_ratio=0.5,
        deepen_factor=1.,
        widen_factor=1.,
        out_indices=(4, ),
        channel_attention=True,
        norm_cfg=dict(type='SyncBN'),
        act_cfg=dict(type='SiLU'),
        init_cfg=dict(
            type='Pretrained',
            prefix='backbone.',
            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'  # noqa
        )),
    head=dict(
        type='RTMCCHead',
        in_channels=1024,
        out_channels=133,
        input_size=codec['input_size'],
        in_featuremap_size=(9, 12),
        simcc_split_ratio=codec['simcc_split_ratio'],
        final_layer_kernel_size=7,
        gau_cfg=dict(
            hidden_dims=256,
            s=128,
            expansion_factor=2,
            dropout_rate=0.,
            drop_path=0.,
            act_fn='SiLU',
            use_rel_bias=False,
            pos_enc=False),
        loss=dict(
            type='KLDiscretLoss',
            use_target_weight=True,
            beta=10.,
            label_softmax=True),
        decoder=codec),
    test_cfg=dict(flip_test=True, ))

# base dataset settings
dataset_type = 'CocoWholeBodyDataset'
data_mode = 'topdown'
data_root = '/data/'

backend_args = dict(backend='local')
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/',
#         f'{data_root}': 's3://openmmlab/datasets/detection/coco/'
#     }))

# pipelines
train_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform', scale_factor=[0.6, 1.4], rotate_factor=80),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=1.0),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

train_pipeline_stage2 = [
    dict(type='LoadImage', backend_args=backend_args),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(
        type='RandomBBoxTransform',
        shift_factor=0.,
        scale_factor=[0.75, 1.25],
        rotate_factor=60),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='mmdet.YOLOXHSVRandomAug'),
    dict(
        type='Albumentation',
        transforms=[
            dict(type='Blur', p=0.1),
            dict(type='MedianBlur', p=0.1),
            dict(
                type='CoarseDropout',
                max_holes=1,
                max_height=0.4,
                max_width=0.4,
                min_holes=1,
                min_height=0.2,
                min_width=0.2,
                p=0.5),
        ]),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]

datasets = []
dataset_coco = dict(
    type=dataset_type,
    data_root=data_root,
    data_mode=data_mode,
    ann_file='coco/annotations/coco_wholebody_train_v1.0.json',
    data_prefix=dict(img='coco/train2017/'),
    pipeline=[],
)
datasets.append(dataset_coco)

scene = ['Magic_show', 'Entertainment', 'ConductMusic', 'Online_class',
         'TalkShow', 'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow',
         'Singing', 'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference']

for i in range(len(scene)):
    datasets.append(
        dict(
            type=dataset_type,
            data_root=data_root,
            data_mode=data_mode,
            ann_file='UBody/annotations/' + scene[i] + '/keypoint_annotation.json',
            data_prefix=dict(img='UBody/images/' + scene[i] + '/'),
            pipeline=[],
        )
    )

# data loaders
train_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type='CombinedDataset',
        metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
        datasets=datasets,
        pipeline=train_pipeline,
        test_mode=False,
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=10,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='coco/annotations/coco_wholebody_val_v1.0.json',
        bbox_file=f'{data_root}coco/person_detection_results/'
        'COCO_val2017_detections_AP_H_56_person.json',
        data_prefix=dict(img='coco/val2017/'),
        test_mode=True,
        pipeline=val_pipeline,
    ))
test_dataloader = val_dataloader

# hooks
default_hooks = dict(
    checkpoint=dict(
        save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))

custom_hooks = [
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0002,
        update_buffers=True,
        priority=49),
    dict(
        type='mmdet.PipelineSwitchHook',
        switch_epoch=max_epochs - stage2_num_epochs,
        switch_pipeline=train_pipeline_stage2)
]

# evaluators
val_evaluator = dict(
    type='CocoWholeBodyMetric',
    ann_file=data_root + 'coco/annotations/coco_wholebody_val_v1.0.json')
test_evaluator = val_evaluator
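A quick sanity check on the schedule constants defined at the top of this config (plain arithmetic on the values above):

max_epochs = 270
stage2_num_epochs = 30
base_lr = 4e-3

cosine_begin = max_epochs // 2                 # 135: CosineAnnealingLR starts here
eta_min = base_lr * 0.05                       # 0.0002: LR floor of the cosine phase
switch_epoch = max_epochs - stage2_num_epochs  # 240: PipelineSwitchHook swaps in train_pipeline_stage2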
src/controlnet_aux/dwpose/util.py
ADDED
@@ -0,0 +1,303 @@
import math
import numpy as np
import cv2

eps = 0.01


def smart_resize(x, s):
    Ht, Wt = s
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)


def smart_resize_k(x, fx, fy):
    if x.ndim == 2:
        Ho, Wo = x.shape
        Co = 1
    else:
        Ho, Wo, Co = x.shape
    Ht, Wt = Ho * fy, Wo * fx
    if Co == 3 or Co == 1:
        k = float(Ht + Wt) / float(Ho + Wo)
        return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
    else:
        return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)


def padRightDownCorner(img, stride, padValue):
    h = img.shape[0]
    w = img.shape[1]

    pad = 4 * [None]
    pad[0] = 0  # up
    pad[1] = 0  # left
    pad[2] = 0 if (h % stride == 0) else stride - (h % stride)  # down
    pad[3] = 0 if (w % stride == 0) else stride - (w % stride)  # right

    img_padded = img
    pad_up = np.tile(img_padded[0:1, :, :] * 0 + padValue, (pad[0], 1, 1))
    img_padded = np.concatenate((pad_up, img_padded), axis=0)
    pad_left = np.tile(img_padded[:, 0:1, :] * 0 + padValue, (1, pad[1], 1))
    img_padded = np.concatenate((pad_left, img_padded), axis=1)
    pad_down = np.tile(img_padded[-2:-1, :, :] * 0 + padValue, (pad[2], 1, 1))
    img_padded = np.concatenate((img_padded, pad_down), axis=0)
    pad_right = np.tile(img_padded[:, -2:-1, :] * 0 + padValue, (1, pad[3], 1))
    img_padded = np.concatenate((img_padded, pad_right), axis=1)

    return img_padded, pad


def transfer(model, model_weights):
    transfered_model_weights = {}
    for weights_name in model.state_dict().keys():
        transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
    return transfered_model_weights


def draw_bodypose(canvas, candidate, subset):
    H, W, C = canvas.shape
    candidate = np.array(candidate)
    subset = np.array(subset)

    stickwidth = 4

    limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
               [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
               [1, 16], [16, 18], [3, 17], [6, 18]]

    colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

    for i in range(17):
        for n in range(len(subset)):
            index = subset[n][np.array(limbSeq[i]) - 1]
            if -1 in index:
                continue
            Y = candidate[index.astype(int), 0] * float(W)
            X = candidate[index.astype(int), 1] * float(H)
            mX = np.mean(X)
            mY = np.mean(Y)
            length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
            angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
            polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
            cv2.fillConvexPoly(canvas, polygon, colors[i])

    canvas = (canvas * 0.6).astype(np.uint8)

    for i in range(18):
        for n in range(len(subset)):
            index = int(subset[n][i])
            if index == -1:
                continue
            x, y = candidate[index][0:2]
            x = int(x * W)
            y = int(y * H)
            cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)

    return canvas


def draw_handpose(canvas, all_hand_peaks):
    import matplotlib

    H, W, C = canvas.shape

    edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
             [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]

    # (person_number*2, 21, 2)
    for i in range(len(all_hand_peaks)):
        peaks = all_hand_peaks[i]
        peaks = np.array(peaks)

        for ie, e in enumerate(edges):

            x1, y1 = peaks[e[0]]
            x2, y2 = peaks[e[1]]

            x1 = int(x1 * W)
            y1 = int(y1 * H)
            x2 = int(x2 * W)
            y2 = int(y2 * H)
            if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
                cv2.line(canvas, (x1, y1), (x2, y2),
                         matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=1)

        for _, keypoint in enumerate(peaks):
            x, y = keypoint

            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 1, (0, 0, 255), thickness=-1)
    return canvas


def draw_facepose(canvas, all_lmks):
    H, W, C = canvas.shape
    for lmks in all_lmks:
        lmks = np.array(lmks)
        for lmk in lmks:
            x, y = lmk
            x = int(x * W)
            y = int(y * H)
            if x > eps and y > eps:
                cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
    return canvas


# detect hand according to body pose keypoints
# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
def handDetect(candidate, subset, oriImg):
    # right hand: wrist 4, elbow 3, shoulder 2
    # left hand: wrist 7, elbow 6, shoulder 5
    ratioWristElbow = 0.33
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        # if any of three not detected
        has_left = np.sum(person[[5, 6, 7]] == -1) == 0
        has_right = np.sum(person[[2, 3, 4]] == -1) == 0
        if not (has_left or has_right):
            continue
        hands = []
        # left hand
        if has_left:
            left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
            x1, y1 = candidate[left_shoulder_index][:2]
            x2, y2 = candidate[left_elbow_index][:2]
            x3, y3 = candidate[left_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, True])
        # right hand
        if has_right:
            right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
            x1, y1 = candidate[right_shoulder_index][:2]
            x2, y2 = candidate[right_elbow_index][:2]
            x3, y3 = candidate[right_wrist_index][:2]
            hands.append([x1, y1, x2, y2, x3, y3, False])

        for x1, y1, x2, y2, x3, y3, is_left in hands:
            # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbow) = (1 + ratio) * pos_wrist - ratio * pos_elbow
            # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
            # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
            # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
            # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
            # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
            x = x3 + ratioWristElbow * (x3 - x2)
            y = y3 + ratioWristElbow * (y3 - y2)
            distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
            distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
            width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
            # x-y refers to the center --> offset to topLeft point
            # handRectangle.x -= handRectangle.width / 2.f;
            # handRectangle.y -= handRectangle.height / 2.f;
            x -= width / 2
            y -= width / 2  # width = height
            # clamp boxes that overflow the image
            if x < 0: x = 0
            if y < 0: y = 0
            width1 = width
            width2 = width
            if x + width > image_width: width1 = image_width - x
            if y + width > image_height: width2 = image_height - y
            width = min(width1, width2)
            # the minimum accepted hand box size is 20 pixels
            if width >= 20:
                detect_result.append([int(x), int(y), int(width), is_left])

    '''
    return value: [[x, y, w, True if left hand else False]].
    width=height since the network requires squared input.
    x, y is the coordinate of the top-left corner
    '''
    return detect_result


# Written by Lvmin
def faceDetect(candidate, subset, oriImg):
    # left/right eye and ear keypoints: 14 15 16 17
    detect_result = []
    image_height, image_width = oriImg.shape[0:2]
    for person in subset.astype(int):
        has_head = person[0] > -1
        if not has_head:
            continue

        has_left_eye = person[14] > -1
        has_right_eye = person[15] > -1
        has_left_ear = person[16] > -1
        has_right_ear = person[17] > -1

        if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
            continue

        head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]

        width = 0.0
        x0, y0 = candidate[head][:2]

        if has_left_eye:
            x1, y1 = candidate[left_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_right_eye:
            x1, y1 = candidate[right_eye][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 3.0)

        if has_left_ear:
            x1, y1 = candidate[left_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        if has_right_ear:
            x1, y1 = candidate[right_ear][:2]
            d = max(abs(x0 - x1), abs(y0 - y1))
            width = max(width, d * 1.5)

        x, y = x0, y0

        x -= width
        y -= width

        if x < 0:
            x = 0

        if y < 0:
            y = 0

        width1 = width * 2
        width2 = width * 2

        if x + width > image_width:
            width1 = image_width - x

        if y + width > image_height:
            width2 = image_height - y

        width = min(width1, width2)

        if width >= 20:
            detect_result.append([int(x), int(y), int(width)])

    return detect_result


# get the index of the maximum of a 2d array
def npmax(array):
    arrayindex = array.argmax(1)
    arrayvalue = array.max(1)
    i = arrayvalue.argmax()
    j = arrayindex[i]
    return i, j
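For reference, a small worked example of npmax from the module above (it returns the row of the global maximum, then the column of that row's maximum):

import numpy as np
from src.controlnet_aux.dwpose.util import npmax

heat = np.array([[0.1, 0.9, 0.3],
                 [0.2, 0.4, 0.8]])
i, j = npmax(heat)
print(i, j)  # 0 1 -- row 0 holds the global max (0.9), at column 1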
src/controlnet_aux/dwpose/wholebody.py
ADDED
@@ -0,0 +1,121 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import numpy as np
import warnings

try:
    import mmcv
except ImportError:
    warnings.warn(
        "The module 'mmcv' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmcv>=2.0.1'"
    )

try:
    from mmpose.apis import inference_topdown
    from mmpose.apis import init_model as init_pose_estimator
    from mmpose.evaluation.functional import nms
    from mmpose.utils import adapt_mmdet_pipeline
    from mmpose.structures import merge_data_samples
except ImportError:
    warnings.warn(
        "The module 'mmpose' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmpose>=1.1.0'"
    )

try:
    from mmdet.apis import inference_detector, init_detector
except ImportError:
    warnings.warn(
        "The module 'mmdet' is not installed. The package will have limited functionality. Please install it using the command: mim install 'mmdet>=3.1.0'"
    )


class Wholebody:
    def __init__(self,
                 det_config=None, det_ckpt=None,
                 pose_config=None, pose_ckpt=None,
                 device="cpu"):

        if det_config is None:
            det_config = os.path.join(os.path.dirname(__file__), "yolox_config/yolox_l_8xb8-300e_coco.py")

        if pose_config is None:
            pose_config = os.path.join(os.path.dirname(__file__), "dwpose_config/dwpose-l_384x288.py")

        if det_ckpt is None:
            det_ckpt = 'https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_l_8x8_300e_coco/yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth'

        if pose_ckpt is None:
            pose_ckpt = "https://huggingface.co/wanghaofan/dw-ll_ucoco_384/resolve/main/dw-ll_ucoco_384.pth"

        # build detector
        self.detector = init_detector(det_config, det_ckpt, device=device)
        self.detector.cfg = adapt_mmdet_pipeline(self.detector.cfg)

        # build pose estimator
        self.pose_estimator = init_pose_estimator(
            pose_config,
            pose_ckpt,
            device=device)

    def to(self, device):
        self.detector.to(device)
        self.pose_estimator.to(device)
        return self

    def __call__(self, oriImg):
        # predict bbox
        det_result = inference_detector(self.detector, oriImg)
        pred_instance = det_result.pred_instances.cpu().numpy()
        bboxes = np.concatenate(
            (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
        bboxes = bboxes[np.logical_and(pred_instance.labels == 0,
                                       pred_instance.scores > 0.5)]

        # set NMS threshold
        bboxes = bboxes[nms(bboxes, 0.7), :4]

        # predict keypoints
        if len(bboxes) == 0:
            pose_results = inference_topdown(self.pose_estimator, oriImg)
        else:
            pose_results = inference_topdown(self.pose_estimator, oriImg, bboxes)
        preds = merge_data_samples(pose_results)
        preds = preds.pred_instances

        # preds = pose_results[0].pred_instances
        keypoints = preds.get('transformed_keypoints',
                              preds.keypoints)
        if 'keypoint_scores' in preds:
            scores = preds.keypoint_scores
        else:
            scores = np.ones(keypoints.shape[:-1])

        if 'keypoints_visible' in preds:
            visible = preds.keypoints_visible
        else:
            visible = np.ones(keypoints.shape[:-1])
        keypoints_info = np.concatenate(
            (keypoints, scores[..., None], visible[..., None]),
            axis=-1)
        # compute neck joint
        neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
        # neck score when visualizing pred
        neck[:, 2:4] = np.logical_and(
            keypoints_info[:, 5, 2:4] > 0.3,
            keypoints_info[:, 6, 2:4] > 0.3).astype(int)
        new_keypoints_info = np.insert(
            keypoints_info, 17, neck, axis=1)
        mmpose_idx = [
            17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
        ]
        openpose_idx = [
            1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
        ]
        new_keypoints_info[:, openpose_idx] = \
            new_keypoints_info[:, mmpose_idx]
        keypoints_info = new_keypoints_info

        keypoints, scores, visible = keypoints_info[
            ..., :2], keypoints_info[..., 2], keypoints_info[..., 3]

        return keypoints, scores
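For reference, a minimal usage sketch of the Wholebody wrapper above ('person.jpg' is a placeholder; the shapes follow from the code: 133 COCO-WholeBody keypoints plus the synthesized neck inserted at index 17):

import cv2
from src.controlnet_aux.dwpose.wholebody import Wholebody

estimator = Wholebody(device='cpu')  # None args -> bundled YOLOX/DWPose configs and checkpoints
frame = cv2.imread('person.jpg')
keypoints, scores = estimator(frame)
# keypoints: (num_people, 134, 2); scores: (num_people, 134)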
src/controlnet_aux/dwpose/yolox_config/yolox_l_8xb8-300e_coco.py
ADDED
@@ -0,0 +1,245 @@
img_scale = (640, 640)  # width, height

# model settings
model = dict(
    type='YOLOX',
    data_preprocessor=dict(
        type='DetDataPreprocessor',
        pad_size_divisor=32,
        batch_augments=[
            dict(
                type='BatchSyncRandomResize',
                random_size_range=(480, 800),
                size_divisor=32,
                interval=10)
        ]),
    backbone=dict(
        type='CSPDarknet',
        deepen_factor=1.0,
        widen_factor=1.0,
        out_indices=(2, 3, 4),
        use_depthwise=False,
        spp_kernal_sizes=(5, 9, 13),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
    ),
    neck=dict(
        type='YOLOXPAFPN',
        in_channels=[256, 512, 1024],
        out_channels=256,
        num_csp_blocks=3,
        use_depthwise=False,
        upsample_cfg=dict(scale_factor=2, mode='nearest'),
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish')),
    bbox_head=dict(
        type='YOLOXHead',
        num_classes=80,
        in_channels=256,
        feat_channels=256,
        stacked_convs=2,
        strides=(8, 16, 32),
        use_depthwise=False,
        norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
        act_cfg=dict(type='Swish'),
        loss_cls=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_bbox=dict(
            type='IoULoss',
            mode='square',
            eps=1e-16,
            reduction='sum',
            loss_weight=5.0),
        loss_obj=dict(
            type='CrossEntropyLoss',
            use_sigmoid=True,
            reduction='sum',
            loss_weight=1.0),
        loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
    train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
    # In order to align the source code, the threshold of the val phase is
    # 0.01, and the threshold of the test phase is 0.001.
    test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))

# dataset settings
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

# Example to use different file client
# Method 1: simply set the data root and let the file I/O module
# automatically infer from prefix (not support LMDB and Memcache yet)

# data_root = 's3://openmmlab/datasets/detection/coco/'

# Method 2: Use `backend_args`, `file_client_args` in versions before 3.0.0rc6
# backend_args = dict(
#     backend='petrel',
#     path_mapping=dict({
#         './data/': 's3://openmmlab/datasets/detection/',
#         'data/': 's3://openmmlab/datasets/detection/'
#     }))
backend_args = None

train_pipeline = [
    dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
    dict(
        type='RandomAffine',
        scaling_ratio_range=(0.1, 2),
        # img_scale is (width, height)
        border=(-img_scale[0] // 2, -img_scale[1] // 2)),
    dict(
        type='MixUp',
        img_scale=img_scale,
        ratio_range=(0.8, 1.6),
        pad_val=114.0),
    dict(type='YOLOXHSVRandomAug'),
    dict(type='RandomFlip', prob=0.5),
    # According to the official implementation, multi-scale
    # training is not considered here but in the
    # 'mmdet/models/detectors/yolox.py'.
    # Resize and Pad are for the last 15 epochs when Mosaic,
    # RandomAffine, and MixUp are closed by YOLOXModeSwitchHook.
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        # If the image is three-channel, the pad value needs
        # to be set separately for each channel.
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='PackDetInputs')
]

train_dataset = dict(
    # use MultiImageMixDataset wrapper to support mosaic and mixup
    type='MultiImageMixDataset',
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_train2017.json',
        data_prefix=dict(img='train2017/'),
        pipeline=[
            dict(type='LoadImageFromFile', backend_args=backend_args),
            dict(type='LoadAnnotations', with_bbox=True)
        ],
        filter_cfg=dict(filter_empty_gt=False, min_size=32),
        backend_args=backend_args),
    pipeline=train_pipeline)

test_pipeline = [
    dict(type='LoadImageFromFile', backend_args=backend_args),
    dict(type='Resize', scale=img_scale, keep_ratio=True),
    dict(
        type='Pad',
        pad_to_square=True,
        pad_val=dict(img=(114.0, 114.0, 114.0))),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='PackDetInputs',
        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
                   'scale_factor'))
]

train_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=train_dataset)
val_dataloader = dict(
    batch_size=8,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        ann_file='annotations/instances_val2017.json',
        data_prefix=dict(img='val2017/'),
        test_mode=True,
        pipeline=test_pipeline,
        backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'annotations/instances_val2017.json',
    metric='bbox',
    backend_args=backend_args)
test_evaluator = val_evaluator

# training settings
max_epochs = 300
num_last_epochs = 15
interval = 10

train_cfg = dict(max_epochs=max_epochs, val_interval=interval)

# optimizer
# default 8 gpu
base_lr = 0.01
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(
        type='SGD', lr=base_lr, momentum=0.9, weight_decay=5e-4,
        nesterov=True),
    paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))

# learning rate
param_scheduler = [
    dict(
        # use quadratic formula to warm up 5 epochs
        # and lr is updated by iteration
        # TODO: fix default scope in get function
        type='mmdet.QuadraticWarmupLR',
        by_epoch=True,
        begin=0,
        end=5,
        convert_to_iter_based=True),
    dict(
        # use cosine lr from 5 to 285 epoch
        type='CosineAnnealingLR',
        eta_min=base_lr * 0.05,
        begin=5,
        T_max=max_epochs - num_last_epochs,
        end=max_epochs - num_last_epochs,
        by_epoch=True,
        convert_to_iter_based=True),
    dict(
        # use fixed lr during last 15 epochs
        type='ConstantLR',
        by_epoch=True,
        factor=1,
        begin=max_epochs - num_last_epochs,
        end=max_epochs,
    )
]

default_hooks = dict(
    checkpoint=dict(
        interval=interval,
        max_keep_ckpts=3  # only keep latest 3 checkpoints
    ))

custom_hooks = [
    dict(
        type='YOLOXModeSwitchHook',
        num_last_epochs=num_last_epochs,
        priority=48),
    dict(type='SyncNormHook', priority=48),
    dict(
        type='EMAHook',
        ema_type='ExpMomentumEMA',
        momentum=0.0001,
        update_buffers=True,
        priority=49)
]

# NOTE: `auto_scale_lr` is for automatically scaling LR,
# USER SHOULD NOT CHANGE ITS VALUES.
# base_batch_size = (8 GPUs) x (8 samples per GPU)
auto_scale_lr = dict(base_batch_size=64)
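A quick sanity check on the training schedule defined above (plain arithmetic on the config's own constants):

max_epochs = 300
num_last_epochs = 15
base_lr = 0.01

cosine_end = max_epochs - num_last_epochs  # 285: cosine phase runs from epoch 5 to 285
eta_min = base_lr * 0.05                   # 0.0005: LR floor during cosine annealing
# epochs 285-300 hold a constant LR while YOLOXModeSwitchHook has already
# disabled Mosaic/MixUp/RandomAffine for those last 15 epochs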
src/controlnet_aux/hed/__init__.py
ADDED
@@ -0,0 +1,129 @@
# This is an improved version and model of HED edge detection with Apache License, Version 2.0.
# Please use this implementation in your products
# This implementation may produce slightly different results from Saining Xie's official implementations,
# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations.
# Different from official models and other implementations, this is an RGB-input model (rather than BGR)
# and in this way it works better for gradio's RGB protocol

import os
import warnings

import cv2
import numpy as np
import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from PIL import Image

from ..util import HWC3, nms, resize_image, safe_step


class DoubleConvBlock(torch.nn.Module):
    def __init__(self, input_channel, output_channel, layer_number):
        super().__init__()
        self.convs = torch.nn.Sequential()
        self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        for i in range(1, layer_number):
            self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1))
        self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0)

    def __call__(self, x, down_sampling=False):
        h = x
        if down_sampling:
            h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2))
        for conv in self.convs:
            h = conv(h)
            h = torch.nn.functional.relu(h)
        return h, self.projection(h)


class ControlNetHED_Apache2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1)))
        self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2)
        self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2)
        self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3)
        self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3)
        self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3)

    def __call__(self, x):
        h = x - self.norm
        h, projection1 = self.block1(h)
        h, projection2 = self.block2(h, down_sampling=True)
        h, projection3 = self.block3(h, down_sampling=True)
        h, projection4 = self.block4(h, down_sampling=True)
        h, projection5 = self.block5(h, down_sampling=True)
        return projection1, projection2, projection3, projection4, projection5

class HEDdetector:
    def __init__(self, netNetwork):
        self.netNetwork = netNetwork

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, cache_dir=None):
        filename = filename or "ControlNetHED.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

        netNetwork = ControlNetHED_Apache2()
        netNetwork.load_state_dict(torch.load(model_path, map_location='cpu'))
        netNetwork.float().eval()

        return cls(netNetwork)

    def to(self, device):
        self.netNetwork.to(device)
        return self

    def __call__(self, input_image, detect_resolution=512, image_resolution=512, safe=False, output_type="pil", scribble=False, **kwargs):
        if "return_pil" in kwargs:
            warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
            output_type = "pil" if kwargs["return_pil"] else "np"
        if type(output_type) is bool:
            warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
            if output_type:
                output_type = "pil"

        device = next(iter(self.netNetwork.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        H, W, C = input_image.shape
        with torch.no_grad():
            image_hed = torch.from_numpy(input_image.copy()).float().to(device)
            image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
            edges = self.netNetwork(image_hed)
            edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
            edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges]
            edges = np.stack(edges, axis=2)
            edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64)))
            if safe:
                edge = safe_step(edge)
            edge = (edge * 255.0).clip(0, 255).astype(np.uint8)

        detected_map = edge
        detected_map = HWC3(detected_map)
| 114 |
+
|
| 115 |
+
img = resize_image(input_image, image_resolution)
|
| 116 |
+
H, W, C = img.shape
|
| 117 |
+
|
| 118 |
+
detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
|
| 119 |
+
|
| 120 |
+
if scribble:
|
| 121 |
+
detected_map = nms(detected_map, 127, 3.0)
|
| 122 |
+
detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0)
|
| 123 |
+
detected_map[detected_map > 4] = 255
|
| 124 |
+
detected_map[detected_map < 255] = 0
|
| 125 |
+
|
| 126 |
+
if output_type == "pil":
|
| 127 |
+
detected_map = Image.fromarray(detected_map)
|
| 128 |
+
|
| 129 |
+
return detected_map
|
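A brief usage sketch for the detector defined above (illustrative; not part of the commit). It assumes `HEDdetector` is re-exported from `src/controlnet_aux/__init__.py` and that `ControlNetHED.pth` is hosted in the `lllyasviel/Annotators` Hub repo:

# Illustrative usage only; the repo name and import path are assumptions.
from PIL import Image
from src.controlnet_aux import HEDdetector

hed = HEDdetector.from_pretrained("lllyasviel/Annotators").to("cuda")
edges = hed(Image.open("input.png"), detect_resolution=512, image_resolution=512)
edges.save("hed_edges.png")  # a PIL image by default (output_type="pil")

Passing scribble=True post-processes the soft edges into a binary scribble map via NMS, Gaussian blur, and thresholding, matching ControlNet's scribble conditioning.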
src/controlnet_aux/hed/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (10.1 kB)
src/controlnet_aux/leres/__init__.py
ADDED
@@ -0,0 +1,118 @@
import os

import cv2
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from ..util import HWC3, resize_image
from .leres.depthmap import estimateboost, estimateleres
from .leres.multi_depth_model_woauxi import RelDepthModel
from .leres.net_tools import strip_prefix_if_present
from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel
from .pix2pix.options.test_options import TestOptions


class LeresDetector:
    def __init__(self, model, pix2pixmodel):
        self.model = model
        self.pix2pixmodel = pix2pixmodel

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path, filename=None, pix2pix_filename=None, cache_dir=None):
        filename = filename or "res101.pth"
        pix2pix_filename = pix2pix_filename or "latest_net_G.pth"

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, filename, cache_dir=cache_dir)

        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

        model = RelDepthModel(backbone='resnext101')
        model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True)
        del checkpoint

        if os.path.isdir(pretrained_model_or_path):
            model_path = os.path.join(pretrained_model_or_path, pix2pix_filename)
        else:
            model_path = hf_hub_download(pretrained_model_or_path, pix2pix_filename, cache_dir=cache_dir)

        opt = TestOptions().parse()
        if not torch.cuda.is_available():
            opt.gpu_ids = []  # cpu mode
        pix2pixmodel = Pix2Pix4DepthModel(opt)
        pix2pixmodel.save_dir = os.path.dirname(model_path)
        pix2pixmodel.load_networks('latest')
        pix2pixmodel.eval()

        return cls(model, pix2pixmodel)

    def to(self, device):
        self.model.to(device)
        # TODO - refactor pix2pix implementation to support device migration
        # self.pix2pixmodel.to(device)
        return self

    def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, image_resolution=512, output_type="pil"):
        device = next(iter(self.model.parameters())).device
        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)

        input_image = HWC3(input_image)
        input_image = resize_image(input_image, detect_resolution)

        assert input_image.ndim == 3
        height, width, dim = input_image.shape

        with torch.no_grad():
            if boost:
                depth = estimateboost(input_image, self.model, 0, self.pix2pixmodel, max(width, height))
            else:
                depth = estimateleres(input_image, self.model, width, height)

        numbytes = 2
        depth_min = depth.min()
        depth_max = depth.max()
        max_val = (2 ** (8 * numbytes)) - 1

        # check output before normalizing and mapping to 16 bit
        if depth_max - depth_min > np.finfo("float").eps:
            out = max_val * (depth - depth_min) / (depth_max - depth_min)
        else:
            out = np.zeros(depth.shape)

        # single channel, 16 bit image
        depth_image = out.astype("uint16")

        # convert to uint8
        depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0 / 65535.0))

        # remove near
        if thr_a != 0:
            thr_a = (thr_a / 100) * 255
            depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1]

        # invert image
        depth_image = cv2.bitwise_not(depth_image)

        # remove bg
        if thr_b != 0:
            thr_b = (thr_b / 100) * 255
            depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1]

        detected_map = depth_image
        detected_map = HWC3(detected_map)

        img = resize_image(input_image, image_resolution)
        H, W, C = img.shape

        detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
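As above, a usage sketch (illustrative; not part of the commit; the Hub repo hosting res101.pth and latest_net_G.pth is an assumption):

# Illustrative usage only; the repo name and import path are assumptions.
from PIL import Image
from src.controlnet_aux import LeresDetector

leres = LeresDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
depth = leres(Image.open("input.png"), thr_a=0, thr_b=0, boost=False)
depth.save("leres_depth.png")  # near/far thresholds off, single-pass estimation

Setting boost=True routes through estimateboost, which (following the BoostingMonocularDepth approach) merges multi-resolution depth estimates with the pix2pix network for sharper high-resolution output.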
src/controlnet_aux/leres/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (6.37 kB)
src/controlnet_aux/leres/leres/LICENSE
ADDED
@@ -0,0 +1,23 @@
https://github.com/thygate/stable-diffusion-webui-depthmap-script

MIT License

Copyright (c) 2023 Bob Thiry

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
src/controlnet_aux/leres/leres/Resnet.py
ADDED
@@ -0,0 +1,199 @@
import torch.nn as nn

__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        #self.avgpool = nn.AvgPool2d(7, stride=1)
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        features = []

        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        features.append(x)
        x = self.layer2(x)
        features.append(x)
        x = self.layer3(x)
        features.append(x)
        x = self.layer4(x)
        features.append(x)

        return features


def resnet18(pretrained=True, **kwargs):
    """Constructs a ResNet-18 backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    return model


def resnet34(pretrained=True, **kwargs):
    """Constructs a ResNet-34 backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    return model


def resnet50(pretrained=True, **kwargs):
    """Constructs a ResNet-50 backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


def resnet101(pretrained=True, **kwargs):
    """Constructs a ResNet-101 backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)

    return model


def resnet152(pretrained=True, **kwargs):
    """Constructs a ResNet-152 backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    return model
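Worth noting: this ResNet variant is a pure feature extractor; the classifier head is commented out and forward() returns the four stage outputs. A quick sketch (illustrative; not part of the commit):

# Illustrative only: inspect the four feature maps returned by the backbone.
import torch

model = resnet50(pretrained=False)
feats = model(torch.randn(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
# [(1, 256, 56, 56), (1, 512, 28, 28), (1, 1024, 14, 14), (1, 2048, 7, 7)]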
src/controlnet_aux/leres/leres/Resnext_torch.py
ADDED
@@ -0,0 +1,237 @@
#!/usr/bin/env python
# coding: utf-8
import torch.nn as nn

try:
    from urllib import urlretrieve
except ImportError:
    from urllib.request import urlretrieve

__all__ = ['resnext101_32x8d']


model_urls = {
    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
}


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" (https://arxiv.org/abs/1512.03385).
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None):
        super(Bottleneck, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None,
                 norm_layer=None):
        super(ResNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        #self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        #self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation,
                                norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
        features = []
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        features.append(x)

        x = self.layer2(x)
        features.append(x)

        x = self.layer3(x)
        features.append(x)

        x = self.layer4(x)
        features.append(x)

        #x = self.avgpool(x)
        #x = torch.flatten(x, 1)
        #x = self.fc(x)

        return features

    def forward(self, x):
        return self._forward_impl(x)


def resnext101_32x8d(pretrained=True, **kwargs):
    """Constructs a ResNeXt-101 32x8d backbone.
    Args:
        pretrained (bool): kept for API compatibility; ImageNet weights are not loaded here
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8

    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    return model
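Like Resnet.py above, this file exposes a headless backbone: resnext101_32x8d fixes groups=32 and width_per_group=8 and returns the four stage feature maps, presumably the backbone behind RelDepthModel(backbone='resnext101'). A quick shape check (illustrative; not part of the commit):

# Illustrative only: feature-map shapes for a 448x448 input.
import torch

backbone = resnext101_32x8d(pretrained=False)
feats = backbone(torch.randn(1, 3, 448, 448))
print([tuple(f.shape) for f in feats])
# [(1, 256, 112, 112), (1, 512, 56, 56), (1, 1024, 28, 28), (1, 2048, 14, 14)]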