Implement Moebius Gradio Space
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +18 -0
- .gitignore +5 -0
- LICENSE +199 -0
- README.md +195 -8
- app.py +229 -0
- assets/logo_dynamic_woWaterMark.gif +3 -0
- assets/pipeline.png +3 -0
- assets/qualitative_comparison.png +3 -0
- assets/sup_showcase_celebahq_ffhq.png +3 -0
- assets/sup_showcase_places_v2.png +3 -0
- assets/tab1.png +3 -0
- assets/tab1_woTitle.png +3 -0
- assets/tab2.png +3 -0
- assets/tab3.png +3 -0
- assets/tab4.png +3 -0
- config/data_demo.yaml +9 -0
- config/model_cfg/moebius.yaml +47 -0
- config/model_cfg/pixelhacker.yaml +24 -0
- config/rand_mask_cfg/random_medium_256.yaml +33 -0
- config/rand_mask_cfg/random_medium_512.yaml +32 -0
- config/rand_mask_cfg/random_thick_256.yaml +33 -0
- config/rand_mask_cfg/random_thick_512.yaml +33 -0
- config/rand_mask_cfg/random_thin_256.yaml +25 -0
- config/rand_mask_cfg/random_thin_512.yaml +25 -0
- config/train_demo.sh +45 -0
- data/images/0.png +3 -0
- data/images/1.png +3 -0
- data/images/10.png +3 -0
- data/images/100.png +3 -0
- data/images/10000.png +3 -0
- data/images/10001.png +3 -0
- data/images/10002.png +3 -0
- data/images/10003.png +3 -0
- data/masks/000000.png +0 -0
- data/masks/000001.png +0 -0
- data/masks/000002.png +0 -0
- data/masks/000003.png +0 -0
- data/masks/000004.png +0 -0
- data/masks/000005.png +0 -0
- data/masks/000006.png +0 -0
- data/masks/000007.png +0 -0
- data/train_data.jsonl +8 -0
- infer/__init__.py +0 -0
- infer/infer_moebius.py +45 -0
- infer/utils.py +123 -0
- infer/utils_dataset.py +211 -0
- library/__init__.py +0 -0
- library/chinese_sdxl_train_util.py +350 -0
- library/custom_train_functions.py +515 -0
- library/train_util.py +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
assets/logo_dynamic_woWaterMark.gif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
assets/pipeline.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
assets/qualitative_comparison.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
assets/sup_showcase_celebahq_ffhq.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
assets/sup_showcase_places_v2.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
assets/tab1.png filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
assets/tab1_woTitle.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
assets/tab2.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
assets/tab3.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
assets/tab4.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
data/images/0.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
data/images/1.png filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
data/images/10.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
data/images/100.png filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
data/images/10000.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
data/images/10001.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
data/images/10002.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
data/images/10003.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
.venv/
|
| 4 |
+
outputs/
|
| 5 |
+
weight/
|
LICENSE
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to the Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by the Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding any notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. Please also get an
|
| 185 |
+
"Alarm or alarm" page at http://www.apache.org/
|
| 186 |
+
|
| 187 |
+
Copyright 2024 Moebius Authors
|
| 188 |
+
|
| 189 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 190 |
+
you may not use this file except in compliance with the License.
|
| 191 |
+
You may obtain a copy of the License at
|
| 192 |
+
|
| 193 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 194 |
+
|
| 195 |
+
Unless required by applicable law or agreed to in writing, software
|
| 196 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 197 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 198 |
+
See the License for the specific language governing permissions and
|
| 199 |
+
limitations under the License.
|
README.md
CHANGED
|
@@ -1,13 +1,200 @@
|
|
| 1 |
---
|
| 2 |
-
title: Moebius
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 6.
|
| 8 |
-
python_version: '3.13'
|
| 9 |
app_file: app.py
|
| 10 |
-
|
|
|
|
|
|
|
| 11 |
---
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Moebius Inpainting
|
| 3 |
+
emoji: 🖌️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 6.10.0
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
+
short_description: Lightweight Moebius image inpainting
|
| 10 |
+
startup_duration_timeout: 1h
|
| 11 |
+
python_version: 3.11
|
| 12 |
---
|
| 13 |
|
| 14 |
+
<div align="center">
|
| 15 |
+
<img src="./assets/logo_dynamic_woWaterMark.gif" width="100%"></img>
|
| 16 |
+
</div>
|
| 17 |
+
|
| 18 |
+
<div align="center">
|
| 19 |
+
<h2>Moebius: 0.2B Lightweight Image Inpainting Framework with 10B-Level Performance</h2>
|
| 20 |
+
|
| 21 |
+
***On-par-with/surpass 10B-level industrial SOTA generalist (FLUX.1-Fill-Dev) on 6 benchmarks across natural and portrait scenes & Only 2% (0.2B) parameters, and inference 15× faster***
|
| 22 |
+
|
| 23 |
+
[Kangsheng Duan](https://github.com/AnduinD)<sup>1,</sup>\*, [Ziyang Xu](https://ziyangxu.top)<sup>1,</sup>\*<sup>,†</sup>, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu)<sup>1</sup>, Xiaohu Ruan<sup>2</sup>, [Xiaoxin Chen](https://scholar.google.com/citations?hl=zh-CN&user=SI_oBwsAAAAJ)<sup>2</sup>, [Xinggang Wang](https://xwcv.github.io)<sup>1, :email:</sup>
|
| 24 |
+
|
| 25 |
+
(*) Equal Contribution, (<sup>†</sup>) Project Leader, (<sup>:email:</sup>) Corresponding Author.
|
| 26 |
+
|
| 27 |
+
<sup>1</sup> Huazhong University of Science and Technology. <sup>2</sup> VIVO AI Lab.
|
| 28 |
+
|
| 29 |
+
[](https://arxiv.org/abs/2606.19195) [](LICENSE) [](https://hustvl.github.io/Moebius) [](https://huggingface.co/papers/date/2026-06-19)
|
| 30 |
+
|
| 31 |
+
<br>
|
| 32 |
+
|
| 33 |
+
<img src="./assets/pipeline.png" style="margin-bottom: 10px;"></img>
|
| 34 |
+
|
| 35 |
+
<img src="./assets/tab1_woTitle.png"></img>
|
| 36 |
+
|
| 37 |
+
<img src="./assets/qualitative_comparison.png"></img>
|
| 38 |
+
|
| 39 |
+
</div>
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
## 🐱🏍 Insight & Small Talk
|
| 43 |
+
|
| 44 |
+
> ***Moebius*** *is our latest AI Image Inpainting endeavor, serving as a direct continuation of our previous work, **[PixelHacker](https://github.com/hustvl/PixelHacker)**. Named after the concepts of "infinity" and "master painter," Moebius embodies our vision: maintaining exceptional generation quality under highly constrained computational resources while pushing the efficiency of image inpainting to its limits as much as possible.*
|
| 45 |
+
>
|
| 46 |
+
> *Under the iron grip of the Scaling Law, AI research has long devolved into a grueling arms race of burning capital, compute, and data. Consequently, the academic community finds it increasingly difficult to keep pace with the ever-expanding model scales driven by the tech industry.*
|
| 47 |
+
>
|
| 48 |
+
> <p align="center"><b>"<ins>But is this brute-force scaling truly the only path forward?</ins>"</b></p>
|
| 49 |
+
>
|
| 50 |
+
> *Using general-purpose image inpainting as our strategic entry point, we challenge the "scale-at-all-costs" path dependency dictated by the Scaling Law narrative. Through the synergistic optimization of architectural design and knowledge distillation, Moebius achieves a remarkably compact footprint of just **0.22B parameters**. It liberates high-quality image inpainting from the heavy-compute narrative of 10B+ foundation models:*
|
| 51 |
+
> *Across six comprehensive benchmarks spanning both natural and portrait scenes, Moebius performs **on par with**, and in certain scenarios **surpasses**, the inpainting quality of 10B+ industrial state-of-the-art (SOTA) generalist models like *FLUX.1-Fill-Dev*, while delivering a massive **>15× inference acceleration**.*
|
| 52 |
+
>
|
| 53 |
+
> 💡 **The core insight of Moebius can be summarized in a single equation:**
|
| 54 |
+
>
|
| 55 |
+
> $$\begin{aligned}
|
| 56 |
+
> \text{Synergy} \times (\text{Architecture} + \text{Distillation}) = & \text{Shattering the "Impossible Triangle" of} \\
|
| 57 |
+
> & \text{Low Parameters, Fast Inference, and High Quality}
|
| 58 |
+
> \end{aligned}$$
|
| 59 |
+
>
|
| 60 |
+
> --- *written on June 16, 2026* ---
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
## 🌟 Highlights
|
| 65 |
+
|
| 66 |
+
- **📉 Extreme Parametric Efficiency (< 2%)**: Moebius operates with a mere **0.22B (226M) parameters**, which represents **less than 2%** of the size of the colossal industrial giant *FLUX.1-Fill-Dev (11.9B)*. It shatters the heavy-compute narrative, making high-quality inpainting accessible on consumer-grade and edge devices.
|
| 67 |
+
- **⚡ 15× Inference Speedup (26ms/step)**: Achieves a blistering inference latency of only **26.01 ms per step** on a single GPU. Combined with optimized sampling steps, Moebius delivers an overall **>15× total runtime acceleration** compared to 10B-level models.
|
| 68 |
+
- **🏆 10B-Level Inpainting Quality (on-par-with/surpass FLUX.1-Fill-Dev across 6 benchmarks)**: Size contraction does not mean representation degradation. Through the synergistic optimization of architecture and distillation, Moebius performs on par with, and in certain scenarios (such as complex textures and facial plausibility), surpasses 10B-level state-of-the-art (SOTA) generalist models (*FLUX.1-Fill-Dev, SD3.5 Large-Inpainting*) across 6 comprehensive benchmarks spanning **both natural** scenes (*Places2*) and **portrait** scenes (*CelebA-HQ*, *FFHQ*).
|
| 69 |
+
- **💡 Synergistic Core Innovations**:
|
| 70 |
+
- **Architecture Design (LλMI Block)**: Reformulates both self- and cross-attention by condensing spatial context and global semantic priors into fixed-size linear matrices, bypassing quadratic computational overhead.
|
| 71 |
+
- **Adaptive Multi-Granularity Distillation Strategy**: Transfers the representational capacity from our *[PixelHacker](https://github.com/hustvl/PixelHacker)* (teacher) strictly within the latent space (avoiding expensive pixel-space decoding). It bridges the giant capacity gap by aligning multi-granularity supervision—ranging from microscopic intermediate features to macroscopic diffusion trajectories—while dynamically balancing training via a gradient norm adaptive loss weighting mechanism.
|
| 72 |
+
- **Optimal Synergistic Balancing**: Systematically explores the mutual constraint and upper bound between compact structure and distillation. By mapping this architecture-distillation synergy frontier, we ensure our 0.22B *Moebius* (student) absorbs the maximum semantic reasoning of *[PixelHacker](https://github.com/hustvl/PixelHacker)* (teacher) without triggering representation saturation.
|
| 73 |
+
|
| 74 |
+
<div align="center">
|
| 75 |
+
<img src="./assets/tab2.png" width="70%" style="margin-bottom: 10px;"></img>
|
| 76 |
+
</div>
|
| 77 |
+
|
| 78 |
+
- **🚀 Task-Specific Specialist over Bloated Generalists**: Rather than blindly scaling up, Moebius answers a fundamental question: *<ins>Can a model be smarter, lighter, and faster when the task is explicitly defined?</ins>* It serves as a highly optimized specialist that liberates real-world image inpainting and AI object removal from parameter bloat.
|
| 79 |
+
|
| 80 |
+
## 🔥 News
|
| 81 |
+
* **`June 19, 2026`:** 🎉 Moebius has achieved the [No. 1 daily ranking](https://huggingface.co/papers/date/2026-06-19) on Hugging Face!
|
| 82 |
+
|
| 83 |
+
* **`June 18, 2026`:** 🔥🔥 We have released the training and inference code, and open-sourced the [model weights](https://huggingface.co/hustvl/Moebius) on Hugging Face.
|
| 84 |
+
|
| 85 |
+
* **`June 18, 2026`:** 🎉 Moebius is accepted by ECCV'26! We have released the preprint on arXiv, check it [here](https://arxiv.org/abs/2606.19195) ~ 🍻
|
| 86 |
+
|
| 87 |
+
* **`June 16, 2026`:** 🔥 We have submitted the GitHub repo for the first time, and there will be more updates soon. Stay tuned! 🤗
|
| 88 |
+
|
| 89 |
+
## 🏕️ Performance on Natural Scene
|
| 90 |
+
|
| 91 |
+
<div align="center">
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
<img src="./assets/tab3.png"></img>
|
| 95 |
+
|
| 96 |
+
<img src="./assets/sup_showcase_places_v2.png"></img>
|
| 97 |
+
|
| 98 |
+
</div>
|
| 99 |
+
|
| 100 |
+
## 🤗 Performance on Portrait Scene
|
| 101 |
+
<div align="center">
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
<img src="./assets/tab4.png" width="50%"></img>
|
| 105 |
+
|
| 106 |
+
<img src="./assets/sup_showcase_celebahq_ffhq.png"></img>
|
| 107 |
+
|
| 108 |
+
</div>
|
| 109 |
+
|
| 110 |
+
## ⚖️ Evaluation Resources
|
| 111 |
+
|
| 112 |
+
The masks of the evaluation set are shared in [Google Drive](https://drive.google.com/drive/folders/13J91fdQt2RnHp4j-VzdtSrHRHPA1OxJ5?usp=sharing), and the corresponding images can be downloaded from the following open source platforms:
|
| 113 |
+
* Places2: [Places2](http://places2.csail.mit.edu/download-private.html)
|
| 114 |
+
* CelebA-HQ: [CelebA-HQ](https://openxlab.org.cn/datasets/OpenDataLab/CelebA-HQ)
|
| 115 |
+
* FFHQ: [FFHQ](https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_link)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
## 📦 Environment Setups
|
| 120 |
+
|
| 121 |
+
* torch=2.7.1
|
| 122 |
+
* diffusers=0.38.0
|
| 123 |
+
* transformers=4.56.2
|
| 124 |
+
* flash-linear-attention=0.3.2
|
| 125 |
+
* See 'requirements.txt' for detailed Python libraries required
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
conda create -n moebius python=3.14.4
|
| 129 |
+
conda activate moebius
|
| 130 |
+
# cd /xx/xx/Moebius
|
| 131 |
+
pip install -r requirements.txt
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
## 🗃️ Model Checkpoints
|
| 135 |
+
* Download the checkpoint of [VAE](https://huggingface.co/hustvl/PixelHacker/tree/main/vae) and put it into ./weight/vae.
|
| 136 |
+
|
| 137 |
+
* Download the checkpoints of [pretrained version](https://huggingface.co/hustvl/Moebius/tree/main/pretrained), [fine-tuned version (places2)](https://huggingface.co/hustvl/Moebius/tree/main/ft_places2), [fine-tuned version (celeba-hq)](https://huggingface.co/hustvl/Moebius/tree/main/ft_celebahq), [fine-tuned version (ffhq)](https://huggingface.co/hustvl/Moebius/tree/main/ft_ffhq), and put them into ./weight/Moebius.
|
| 138 |
+
|
| 139 |
+
* Finally, the detailed organizational form is as follows:
|
| 140 |
+
```bash
|
| 141 |
+
├── weight
|
| 142 |
+
| ├── Moebius
|
| 143 |
+
| ├── pretrained
|
| 144 |
+
| ├── diffusion_pytorch_model.bin
|
| 145 |
+
| ├── ft_places2
|
| 146 |
+
| ├── diffusion_pytorch_model.bin
|
| 147 |
+
| ├── ft_celebahq
|
| 148 |
+
| ├── diffusion_pytorch_model.bin
|
| 149 |
+
| ├── ft_ffhq
|
| 150 |
+
| ├── diffusion_pytorch_model.bin
|
| 151 |
+
| ├── vae
|
| 152 |
+
| ├── config.json
|
| 153 |
+
| ├── diffusion_pytorch_model.bin
|
| 154 |
+
├── ...
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
<!-- * teacher model and vae: [hustvl/PixelHacker](https://huggingface.co/hustvl/PixelHacker)
|
| 158 |
+
* student model: [hustvl/Moebius](https://huggingface.co/hustvl/Moebius) -->
|
| 159 |
+
|
| 160 |
+
## 🚂 Training
|
| 161 |
+
You can run the following code to start training. The training script supports distributed training, and you can configure the GPU count via environment variables.
|
| 162 |
+
```bash
|
| 163 |
+
# For single GPU training:
|
| 164 |
+
PY_TRAINER=train_distillation.py bash run/run_ddp_1node.sh config/train_demo.sh
|
| 165 |
+
|
| 166 |
+
# For multi GPU training:
|
| 167 |
+
NUM_GPUS_PER_MACHINE=4 bash run/run_ddp_1node.sh config/train_demo.sh
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## 🔮 Inference
|
| 171 |
+
You can run the following code directly to get the inpainting result of the example image-mask pair, and the result will be generated in ./outputs. If you want to infer on custom data, just place the image and mask with the same name in ./dataset.local/imgs and ./dataset.local/masks, respectively, then run the following code as well.
|
| 172 |
+
```bash
|
| 173 |
+
python -m infer.infer_moebius \
|
| 174 |
+
--model-config config/model_cfg/moebius.yaml \
|
| 175 |
+
--model-weight weight/Moebius/ft_celebahq/diffusion_pytorch_model.bin \
|
| 176 |
+
--real-dir data/images \
|
| 177 |
+
--mask-dir data/masks \
|
| 178 |
+
--save-dir ./outputs \
|
| 179 |
+
--cfg 2.0 \
|
| 180 |
+
--batch-size 8 \
|
| 181 |
+
--num-workers 8
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
## 🎓 Citation
|
| 186 |
+
|
| 187 |
+
```shell
|
| 188 |
+
@misc{DuanAndXu2026Moebius,
|
| 189 |
+
title={Moebius: 0.2B Lightweight Image Inpainting Framework with 10B-Level Performance},
|
| 190 |
+
author={Kangsheng Duan and Ziyang Xu and Wenyu Liu and Xiaohu Ruan and Xiaoxin Chen and Xinggang Wang},
|
| 191 |
+
year={2026},
|
| 192 |
+
eprint={2606.19195},
|
| 193 |
+
archivePrefix={arXiv},
|
| 194 |
+
primaryClass={cs.CV},
|
| 195 |
+
url={https://arxiv.org/abs/2606.19195},
|
| 196 |
+
}
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
## 🧑🤝🧑 Acknowledgement
|
| 200 |
+
We sincerely thank the authors of the following open-source repositories for their contributions to the community, which have greatly facilitated our research and development of Moebius: [Sana](https://github.com/NVlabs/Sana), [flash-linear-attention](https://github.com/fla-org/flash-linear-attention), [lambda-networks](https://github.com/lucidrains/lambda-networks), [timm](https://github.com/huggingface/pytorch-image-models), [Muon](https://github.com/KellerJordan/Muon), [diffusers](https://github.com/huggingface/diffusers).
|
app.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
_hf_cache = "/data/.cache/huggingface" if os.path.isdir("/data") and os.access("/data", os.W_OK) else "/tmp/hf_home"
|
| 4 |
+
os.environ.setdefault("HF_HOME", _hf_cache)
|
| 5 |
+
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
|
| 6 |
+
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
|
| 7 |
+
os.environ.setdefault("GRADIO_SSR_MODE", "false")
|
| 8 |
+
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Dict, Tuple
|
| 12 |
+
|
| 13 |
+
import spaces
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import torch
|
| 16 |
+
from diffusers import DDIMScheduler
|
| 17 |
+
from diffusers.models import AutoencoderKL
|
| 18 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 19 |
+
from PIL import Image
|
| 20 |
+
|
| 21 |
+
from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
|
| 22 |
+
from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
ROOT = Path(__file__).resolve().parent
|
| 26 |
+
CONFIG_PATH = ROOT / "config" / "model_cfg" / "moebius.yaml"
|
| 27 |
+
MOEBIUS_REPO = "hustvl/Moebius"
|
| 28 |
+
PIXELHACKER_REPO = "hustvl/PixelHacker"
|
| 29 |
+
DEFAULT_MODEL_KEY = "ft_places2"
|
| 30 |
+
|
| 31 |
+
MODEL_CHOICES = {
|
| 32 |
+
"General scenes (Places2)": "ft_places2",
|
| 33 |
+
"Portraits (CelebA-HQ)": "ft_celebahq",
|
| 34 |
+
"Faces (FFHQ)": "ft_ffhq",
|
| 35 |
+
"Pretrained": "pretrained",
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
_PIPELINE_CACHE: Dict[str, RemovalSDXLPipeline_BatchMode] = {}
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _download_vae_dir() -> str:
|
| 42 |
+
repo_dir = snapshot_download(
|
| 43 |
+
repo_id=PIXELHACKER_REPO,
|
| 44 |
+
allow_patterns=["vae/*"],
|
| 45 |
+
)
|
| 46 |
+
return str(Path(repo_dir) / "vae")
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def _download_model_weight(model_key: str) -> str:
|
| 50 |
+
return hf_hub_download(
|
| 51 |
+
repo_id=MOEBIUS_REPO,
|
| 52 |
+
filename=f"{model_key}/diffusion_pytorch_model.bin",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _build_cpu_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
|
| 57 |
+
model_cfg = load_cfg(str(CONFIG_PATH))
|
| 58 |
+
model_cfg["vae"]["model_dir"] = _download_vae_dir()
|
| 59 |
+
|
| 60 |
+
removal_model = build_removal_model(model_cfg, 20)
|
| 61 |
+
weight_path = _download_model_weight(model_key)
|
| 62 |
+
print(load_removal_model(removal_model, weight_path, device="cpu"))
|
| 63 |
+
|
| 64 |
+
vae = AutoencoderKL.from_pretrained(model_cfg["vae"]["model_dir"])
|
| 65 |
+
scheduler = DDIMScheduler(
|
| 66 |
+
beta_start=0.00085,
|
| 67 |
+
beta_end=0.012,
|
| 68 |
+
beta_schedule="scaled_linear",
|
| 69 |
+
num_train_timesteps=1000,
|
| 70 |
+
clip_sample=False,
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
return RemovalSDXLPipeline_BatchMode(
|
| 74 |
+
removal_model=removal_model,
|
| 75 |
+
vae=vae,
|
| 76 |
+
scheduler=scheduler,
|
| 77 |
+
device="cpu",
|
| 78 |
+
dtype=torch.float32,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _get_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
|
| 83 |
+
if model_key not in _PIPELINE_CACHE:
|
| 84 |
+
_PIPELINE_CACHE[model_key] = _build_cpu_pipeline(model_key)
|
| 85 |
+
return _PIPELINE_CACHE[model_key]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _set_pipeline_device(pipe: RemovalSDXLPipeline_BatchMode, device: str) -> None:
|
| 89 |
+
pipe.device = device
|
| 90 |
+
pipe.vae.to(device=device, dtype=pipe.dtype).eval()
|
| 91 |
+
pipe.removal_model.to(device=device, dtype=pipe.dtype).eval()
|
| 92 |
+
|
| 93 |
+
half_id_num = pipe.removal_model.num_embeddings // 2
|
| 94 |
+
id_num = pipe.removal_model.num_embeddings
|
| 95 |
+
input_ids = torch.tensor([list(range(half_id_num))], dtype=torch.int64, device=device, requires_grad=False)
|
| 96 |
+
un_input_ids = torch.tensor([list(range(half_id_num, id_num))], dtype=torch.int64, device=device, requires_grad=False)
|
| 97 |
+
pipe.input_ids = torch.cat([un_input_ids, input_ids]).to(device=device)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _normalize_inputs(image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
|
| 101 |
+
if image is None:
|
| 102 |
+
raise gr.Error("Upload an image.")
|
| 103 |
+
if mask is None:
|
| 104 |
+
raise gr.Error("Upload a mask.")
|
| 105 |
+
|
| 106 |
+
image = image.convert("RGB")
|
| 107 |
+
mask = mask.convert("L").resize(image.size, Image.Resampling.NEAREST)
|
| 108 |
+
|
| 109 |
+
mask_min, mask_max = mask.getextrema()
|
| 110 |
+
if mask_max < 8:
|
| 111 |
+
raise gr.Error("The mask is empty. Use white pixels for the area to inpaint.")
|
| 112 |
+
if mask_min > 247:
|
| 113 |
+
raise gr.Error("The mask covers the whole image. Leave black pixels outside the edit area.")
|
| 114 |
+
|
| 115 |
+
return image, mask
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _model_key(label: str) -> str:
|
| 119 |
+
return MODEL_CHOICES.get(label, DEFAULT_MODEL_KEY)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _estimate_duration(image, mask, model_name, steps, guidance_scale, paste, compensate, seed, *args, **kwargs):
|
| 123 |
+
del image, mask, model_name, guidance_scale, paste, compensate, seed, args, kwargs
|
| 124 |
+
return min(240, 90 + int(steps) * 5)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
_get_pipeline(DEFAULT_MODEL_KEY)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@spaces.GPU(duration=1)
|
| 131 |
+
def _zerogpu_probe():
|
| 132 |
+
return "ready"
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@spaces.GPU(duration=_estimate_duration)
|
| 136 |
+
def run_inpaint(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
|
| 137 |
+
image, mask = _normalize_inputs(image, mask)
|
| 138 |
+
model_key = _model_key(model_name)
|
| 139 |
+
seed_value = 0 if seed is None else int(seed)
|
| 140 |
+
pipe = _get_pipeline(model_key)
|
| 141 |
+
|
| 142 |
+
started = time.perf_counter()
|
| 143 |
+
try:
|
| 144 |
+
_set_pipeline_device(pipe, "cuda")
|
| 145 |
+
with torch.inference_mode():
|
| 146 |
+
outputs = pipe(
|
| 147 |
+
[image],
|
| 148 |
+
[mask],
|
| 149 |
+
image_size=512,
|
| 150 |
+
num_steps=int(steps),
|
| 151 |
+
guidance_scale=float(guidance_scale),
|
| 152 |
+
paste=bool(paste),
|
| 153 |
+
compensate=bool(compensate),
|
| 154 |
+
noise_offset=0.0357,
|
| 155 |
+
retry=seed_value,
|
| 156 |
+
mute=True,
|
| 157 |
+
)
|
| 158 |
+
elapsed = time.perf_counter() - started
|
| 159 |
+
return outputs[0], f"Completed in {elapsed:.1f}s"
|
| 160 |
+
finally:
|
| 161 |
+
_set_pipeline_device(pipe, "cpu")
|
| 162 |
+
if torch.cuda.is_available():
|
| 163 |
+
torch.cuda.empty_cache()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
with gr.Blocks(title="Moebius Inpainting", fill_width=True) as demo:
|
| 167 |
+
gr.Markdown("# Moebius Inpainting")
|
| 168 |
+
with gr.Row():
|
| 169 |
+
with gr.Column(scale=1, min_width=320):
|
| 170 |
+
input_image = gr.Image(
|
| 171 |
+
label="Image",
|
| 172 |
+
type="pil",
|
| 173 |
+
image_mode="RGB",
|
| 174 |
+
sources=["upload", "clipboard"],
|
| 175 |
+
height=360,
|
| 176 |
+
)
|
| 177 |
+
input_mask = gr.Image(
|
| 178 |
+
label="Mask",
|
| 179 |
+
type="pil",
|
| 180 |
+
image_mode="L",
|
| 181 |
+
sources=["upload", "clipboard"],
|
| 182 |
+
height=360,
|
| 183 |
+
)
|
| 184 |
+
with gr.Column(scale=1, min_width=320):
|
| 185 |
+
output_image = gr.Image(label="Result", type="pil", height=520)
|
| 186 |
+
status = gr.Markdown()
|
| 187 |
+
|
| 188 |
+
with gr.Row():
|
| 189 |
+
model_name = gr.Dropdown(
|
| 190 |
+
label="Checkpoint",
|
| 191 |
+
choices=list(MODEL_CHOICES.keys()),
|
| 192 |
+
value="General scenes (Places2)",
|
| 193 |
+
min_width=240,
|
| 194 |
+
)
|
| 195 |
+
steps = gr.Slider(4, 30, value=20, step=1, label="Steps", min_width=180)
|
| 196 |
+
guidance_scale = gr.Slider(1.0, 6.0, value=2.0, step=0.1, label="CFG", min_width=180)
|
| 197 |
+
seed = gr.Number(value=0, precision=0, label="Seed", min_width=140)
|
| 198 |
+
|
| 199 |
+
with gr.Row():
|
| 200 |
+
paste = gr.Checkbox(value=True, label="Paste")
|
| 201 |
+
compensate = gr.Checkbox(value=False, label="Compensate")
|
| 202 |
+
run_button = gr.Button("Inpaint", variant="primary")
|
| 203 |
+
|
| 204 |
+
run_button.click(
|
| 205 |
+
fn=run_inpaint,
|
| 206 |
+
inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
|
| 207 |
+
outputs=[output_image, status],
|
| 208 |
+
api_name="inpaint",
|
| 209 |
+
concurrency_limit=1,
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
gr.Examples(
|
| 213 |
+
examples=[
|
| 214 |
+
["data/images/0.png", "data/masks/000000.png", "General scenes (Places2)", 20, 2.0, True, False, 0],
|
| 215 |
+
["data/images/1.png", "data/masks/000001.png", "General scenes (Places2)", 20, 2.0, True, False, 1],
|
| 216 |
+
],
|
| 217 |
+
inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
|
| 218 |
+
outputs=[output_image, status],
|
| 219 |
+
fn=run_inpaint,
|
| 220 |
+
cache_examples=True,
|
| 221 |
+
cache_mode="lazy",
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
demo.queue(max_size=8, default_concurrency_limit=1)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
if __name__ == "__main__":
|
| 229 |
+
demo.launch()
|
assets/logo_dynamic_woWaterMark.gif
ADDED
|
Git LFS Details
|
assets/pipeline.png
ADDED
|
Git LFS Details
|
assets/qualitative_comparison.png
ADDED
|
Git LFS Details
|
assets/sup_showcase_celebahq_ffhq.png
ADDED
|
Git LFS Details
|
assets/sup_showcase_places_v2.png
ADDED
|
Git LFS Details
|
assets/tab1.png
ADDED
|
Git LFS Details
|
assets/tab1_woTitle.png
ADDED
|
Git LFS Details
|
assets/tab2.png
ADDED
|
Git LFS Details
|
assets/tab3.png
ADDED
|
Git LFS Details
|
assets/tab4.png
ADDED
|
Git LFS Details
|
config/data_demo.yaml
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
path: data/train_data.jsonl
|
| 3 |
+
|
| 4 |
+
use_rand_mask: True
|
| 5 |
+
rand_mask_config: config/rand_mask_cfg/random_medium_512.yaml
|
| 6 |
+
|
| 7 |
+
use_extra_fg_mask: False
|
| 8 |
+
|
| 9 |
+
extra_ann_files_4_PureBackTrain_2_RandMask: null
|
config/model_cfg/moebius.yaml
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
image_size: 512
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
vae:
|
| 6 |
+
model_name: 'sdvae_f8d4'
|
| 7 |
+
model_dir: ./weight/vae
|
| 8 |
+
downsample_ratio: 8
|
| 9 |
+
embed_dim: 4
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
model:
|
| 13 |
+
model_type: UNet2DLambdaDWConvMixFFNConditionModel_prune_down_mid_up_block_8x8
|
| 14 |
+
|
| 15 |
+
in_channels: 9
|
| 16 |
+
out_channels: 4
|
| 17 |
+
|
| 18 |
+
attention_head_dim: 8
|
| 19 |
+
|
| 20 |
+
conv_in_kernel: 3
|
| 21 |
+
conv_out_kernel: 3
|
| 22 |
+
cross_attention_dim: 768
|
| 23 |
+
|
| 24 |
+
encoder_hid_dim: 3072
|
| 25 |
+
encoder_hid_dim_type: 'text_proj'
|
| 26 |
+
|
| 27 |
+
projection_class_embeddings_input_dim: 2560
|
| 28 |
+
|
| 29 |
+
use_lambda_cross_attn: True
|
| 30 |
+
use_local_self_attn: True
|
| 31 |
+
|
| 32 |
+
down_block_types:
|
| 33 |
+
- DWMixTFDownBlock2D
|
| 34 |
+
- DWMixTFDownBlock2D
|
| 35 |
+
- DWMixTFDownBlock2D
|
| 36 |
+
mid_block_type: null
|
| 37 |
+
up_block_types:
|
| 38 |
+
- DWMixTFUpBlock2D
|
| 39 |
+
- DWMixTFUpBlock2D
|
| 40 |
+
- DWMixTFUpBlock2D
|
| 41 |
+
|
| 42 |
+
block_out_channels:
|
| 43 |
+
- 320
|
| 44 |
+
- 640
|
| 45 |
+
- 1280
|
| 46 |
+
|
| 47 |
+
mix_mlp_ratio: 2.5
|
config/model_cfg/pixelhacker.yaml
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
data:
|
| 2 |
+
image_size: 512
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
vae:
|
| 6 |
+
model_name: 'sdvae_f8d4'
|
| 7 |
+
model_dir: ./weight/vae
|
| 8 |
+
downsample_ratio: 8
|
| 9 |
+
embed_dim: 4
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
model:
|
| 13 |
+
model_type: UNet2DGLAConditionModel
|
| 14 |
+
|
| 15 |
+
in_channels: 9
|
| 16 |
+
out_channels: 4
|
| 17 |
+
|
| 18 |
+
attention_head_dim: 8
|
| 19 |
+
cross_attention_dim: 768
|
| 20 |
+
|
| 21 |
+
encoder_hid_dim: 3072
|
| 22 |
+
encoder_hid_dim_type: 'text_proj'
|
| 23 |
+
|
| 24 |
+
projection_class_embeddings_input_dim: 2560
|
config/rand_mask_cfg/random_medium_256.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 4
|
| 7 |
+
max_times: 5
|
| 8 |
+
max_width: 50
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 100
|
| 11 |
+
|
| 12 |
+
box_proba: 0.3
|
| 13 |
+
box_kwargs:
|
| 14 |
+
margin: 0
|
| 15 |
+
bbox_min_size: 10
|
| 16 |
+
bbox_max_size: 50
|
| 17 |
+
max_times: 5
|
| 18 |
+
min_times: 1
|
| 19 |
+
|
| 20 |
+
segm_proba: 0
|
| 21 |
+
squares_proba: 0
|
| 22 |
+
|
| 23 |
+
# variants_n: 5
|
| 24 |
+
|
| 25 |
+
max_masks_per_image: 1
|
| 26 |
+
|
| 27 |
+
cropping:
|
| 28 |
+
out_min_size: 256
|
| 29 |
+
handle_small_mode: upscale
|
| 30 |
+
out_square_crop: True
|
| 31 |
+
crop_min_overlap: 1
|
| 32 |
+
|
| 33 |
+
max_tamper_area: 0.5
|
config/rand_mask_cfg/random_medium_512.yaml
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 4
|
| 7 |
+
max_times: 10
|
| 8 |
+
max_width: 100
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 200
|
| 11 |
+
|
| 12 |
+
box_proba: 0.3
|
| 13 |
+
box_kwargs:
|
| 14 |
+
margin: 0
|
| 15 |
+
bbox_min_size: 30
|
| 16 |
+
bbox_max_size: 150
|
| 17 |
+
max_times: 5
|
| 18 |
+
min_times: 1
|
| 19 |
+
|
| 20 |
+
segm_proba: 0
|
| 21 |
+
squares_proba: 0
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
max_masks_per_image: 1
|
| 25 |
+
|
| 26 |
+
cropping:
|
| 27 |
+
out_min_size: 512
|
| 28 |
+
handle_small_mode: upscale
|
| 29 |
+
out_square_crop: True
|
| 30 |
+
crop_min_overlap: 1
|
| 31 |
+
|
| 32 |
+
max_tamper_area: 0.5
|
config/rand_mask_cfg/random_thick_256.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 1
|
| 7 |
+
max_times: 5
|
| 8 |
+
max_width: 100
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 200
|
| 11 |
+
|
| 12 |
+
box_proba: 0.3
|
| 13 |
+
box_kwargs:
|
| 14 |
+
margin: 10
|
| 15 |
+
bbox_min_size: 30
|
| 16 |
+
bbox_max_size: 150
|
| 17 |
+
max_times: 3
|
| 18 |
+
min_times: 1
|
| 19 |
+
|
| 20 |
+
segm_proba: 0
|
| 21 |
+
squares_proba: 0
|
| 22 |
+
|
| 23 |
+
# variants_n: 5
|
| 24 |
+
|
| 25 |
+
max_masks_per_image: 1
|
| 26 |
+
|
| 27 |
+
cropping:
|
| 28 |
+
out_min_size: 256
|
| 29 |
+
handle_small_mode: upscale
|
| 30 |
+
out_square_crop: True
|
| 31 |
+
crop_min_overlap: 1
|
| 32 |
+
|
| 33 |
+
max_tamper_area: 0.5
|
config/rand_mask_cfg/random_thick_512.yaml
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 1
|
| 7 |
+
max_times: 5
|
| 8 |
+
max_width: 250
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 450
|
| 11 |
+
|
| 12 |
+
box_proba: 0.3
|
| 13 |
+
box_kwargs:
|
| 14 |
+
margin: 10
|
| 15 |
+
bbox_min_size: 30
|
| 16 |
+
bbox_max_size: 300
|
| 17 |
+
max_times: 4
|
| 18 |
+
min_times: 1
|
| 19 |
+
|
| 20 |
+
segm_proba: 0
|
| 21 |
+
squares_proba: 0
|
| 22 |
+
|
| 23 |
+
# variants_n: 5
|
| 24 |
+
|
| 25 |
+
max_masks_per_image: 1
|
| 26 |
+
|
| 27 |
+
cropping:
|
| 28 |
+
out_min_size: 512
|
| 29 |
+
handle_small_mode: upscale
|
| 30 |
+
out_square_crop: True
|
| 31 |
+
crop_min_overlap: 1
|
| 32 |
+
|
| 33 |
+
max_tamper_area: 0.5
|
config/rand_mask_cfg/random_thin_256.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 4
|
| 7 |
+
max_times: 50
|
| 8 |
+
max_width: 10
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 40
|
| 11 |
+
box_proba: 0
|
| 12 |
+
segm_proba: 0
|
| 13 |
+
squares_proba: 0
|
| 14 |
+
|
| 15 |
+
variants_n: 5
|
| 16 |
+
|
| 17 |
+
max_masks_per_image: 1
|
| 18 |
+
|
| 19 |
+
cropping:
|
| 20 |
+
out_min_size: 256
|
| 21 |
+
handle_small_mode: upscale
|
| 22 |
+
out_square_crop: True
|
| 23 |
+
crop_min_overlap: 1
|
| 24 |
+
|
| 25 |
+
max_tamper_area: 0.5
|
config/rand_mask_cfg/random_thin_512.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
generator_kind: random
|
| 2 |
+
|
| 3 |
+
mask_generator_kwargs:
|
| 4 |
+
irregular_proba: 1
|
| 5 |
+
irregular_kwargs:
|
| 6 |
+
min_times: 4
|
| 7 |
+
max_times: 70
|
| 8 |
+
max_width: 20
|
| 9 |
+
max_angle: 4
|
| 10 |
+
max_len: 100
|
| 11 |
+
box_proba: 0
|
| 12 |
+
segm_proba: 0
|
| 13 |
+
squares_proba: 0
|
| 14 |
+
|
| 15 |
+
variants_n: 5
|
| 16 |
+
|
| 17 |
+
max_masks_per_image: 1
|
| 18 |
+
|
| 19 |
+
cropping:
|
| 20 |
+
out_min_size: 512
|
| 21 |
+
handle_small_mode: upscale
|
| 22 |
+
out_square_crop: True
|
| 23 |
+
crop_min_overlap: 1
|
| 24 |
+
|
| 25 |
+
max_tamper_area: 0.5
|
config/train_demo.sh
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Set WORK_DIR to your project root before running
|
| 2 |
+
THIS_SH_PATH=$CONFIG_FILE
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
OUTPUT_DIR='exp_outputs'
|
| 6 |
+
OUTPUT_DIR_EXP_NAME="${OUTPUT_DIR}/${EXP_NAME}"
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
export OUTPUT_DIR=$OUTPUT_DIR
|
| 10 |
+
export OUTPUT_DIR_EXP_NAME=$OUTPUT_DIR_EXP_NAME
|
| 11 |
+
export HF_HOME=$HF_HOME
|
| 12 |
+
|
| 13 |
+
export TRAIN_ARGS=" --data_type RemovalDataset_v1_2 \
|
| 14 |
+
--lognorm_t \
|
| 15 |
+
--elatentlpips_loss --elatentlpips_loss_weight 0.5 \
|
| 16 |
+
--task_loss --task_loss_weight 0.5 \
|
| 17 |
+
--KD_loss_weight 0.01 \
|
| 18 |
+
--mse_feat_loss --feat_loss_weight 1.0 --feat_index_T 5 --feat_index_S 2 \
|
| 19 |
+
--model_config_path=config/model_cfg/moebius.yaml \
|
| 20 |
+
--teacher_weight_path=../../hf_models/hustvl/PixelHacker/pretrained/diffusion_pytorch_model.bin \
|
| 21 |
+
--teacher_config_path=config/model_cfg/pixelhacker.yaml \
|
| 22 |
+
--data_config=config/data_demo.yaml \
|
| 23 |
+
--num_embeddings 20 \
|
| 24 |
+
--image_size 512 \
|
| 25 |
+
--batch_size 2 \
|
| 26 |
+
--num_workers 4 \
|
| 27 |
+
--output_dir=${OUTPUT_DIR_EXP_NAME} \
|
| 28 |
+
--output_name=exp \
|
| 29 |
+
--seed=42 \
|
| 30 |
+
--learning_rate=1e-4 \
|
| 31 |
+
--global_step=0 \
|
| 32 |
+
--max_train_steps=200000 \
|
| 33 |
+
--save_every_n_steps=3000 \
|
| 34 |
+
--logging_dir=${OUTPUT_DIR_EXP_NAME}/log \
|
| 35 |
+
--gradient_accumulation_steps=1 \
|
| 36 |
+
--optimizer_type=Muon \
|
| 37 |
+
--lr_scheduler=constant_with_warmup \
|
| 38 |
+
--lr_warmup_steps=0 \
|
| 39 |
+
--save_precision=bf16 \
|
| 40 |
+
--mixed_precision=bf16 \
|
| 41 |
+
--noise_offset=0.0357 \
|
| 42 |
+
--gradient_checkpointing \
|
| 43 |
+
--xformers \
|
| 44 |
+
--log_with=tensorboard \
|
| 45 |
+
--script_args=$THIS_SH_PATH "
|
data/images/0.png
ADDED
|
Git LFS Details
|
data/images/1.png
ADDED
|
Git LFS Details
|
data/images/10.png
ADDED
|
Git LFS Details
|
data/images/100.png
ADDED
|
Git LFS Details
|
data/images/10000.png
ADDED
|
Git LFS Details
|
data/images/10001.png
ADDED
|
Git LFS Details
|
data/images/10002.png
ADDED
|
Git LFS Details
|
data/images/10003.png
ADDED
|
Git LFS Details
|
data/masks/000000.png
ADDED
|
data/masks/000001.png
ADDED
|
data/masks/000002.png
ADDED
|
data/masks/000003.png
ADDED
|
data/masks/000004.png
ADDED
|
data/masks/000005.png
ADDED
|
data/masks/000006.png
ADDED
|
data/masks/000007.png
ADDED
|
data/train_data.jsonl
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"image": "data/images/0.png", "prompt": "background"}
|
| 2 |
+
{"image": "data/images/1.png", "prompt": "background"}
|
| 3 |
+
{"image": "data/images/10.png", "prompt": "background"}
|
| 4 |
+
{"image": "data/images/100.png", "prompt": "background"}
|
| 5 |
+
{"image": "data/images/10000.png", "prompt": "background"}
|
| 6 |
+
{"image": "data/images/10001.png", "prompt": "background"}
|
| 7 |
+
{"image": "data/images/10002.png", "prompt": "background"}
|
| 8 |
+
{"image": "data/images/10003.png", "prompt": "background"}
|
infer/__init__.py
ADDED
|
File without changes
|
infer/infer_moebius.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import os
|
| 3 |
+
from typing import List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from PIL import Image
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from .utils import get_batch_infer_args, build_pipeline, SAVER
|
| 14 |
+
from .utils_dataset import SimpleInferDataset, build_dataloader
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def main():
|
| 19 |
+
args = get_batch_infer_args()
|
| 20 |
+
|
| 21 |
+
dataloader = build_dataloader(args, SimpleInferDataset)
|
| 22 |
+
|
| 23 |
+
pipe = build_pipeline(args)
|
| 24 |
+
pipe = partial(pipe,
|
| 25 |
+
guidance_scale=args.cfg,
|
| 26 |
+
paste=args.pst,
|
| 27 |
+
compensate=args.cps,
|
| 28 |
+
num_steps=args.num_step,
|
| 29 |
+
noise_offset=args.noise_offset
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
save_root = Path(args.save_dir)
|
| 33 |
+
save_root.mkdir(parents=True, exist_ok=True)
|
| 34 |
+
|
| 35 |
+
pbar_loader = tqdm(enumerate(dataloader),
|
| 36 |
+
total=dataloader.dataset.__len__()//args.batch_size+1)
|
| 37 |
+
|
| 38 |
+
for idx, (images, masks, inames) in pbar_loader:
|
| 39 |
+
image_inpaint_list = pipe(images, masks)
|
| 40 |
+
names = [iname+'.png' for iname in inames]
|
| 41 |
+
SAVER.save_images_mp(image_inpaint_list, names, save_root)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == '__main__':
|
| 45 |
+
main()
|
infer/utils.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import os
|
| 3 |
+
from typing import List
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import math
|
| 6 |
+
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
from PIL import Image
|
| 11 |
+
|
| 12 |
+
def get_batch_infer_args(parser=None):
|
| 13 |
+
|
| 14 |
+
if parser is None:
|
| 15 |
+
import argparse
|
| 16 |
+
parser = argparse.ArgumentParser()
|
| 17 |
+
|
| 18 |
+
def str2bool(v):
|
| 19 |
+
if isinstance(v, bool):
|
| 20 |
+
return v
|
| 21 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
| 22 |
+
return True
|
| 23 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
| 24 |
+
return False
|
| 25 |
+
else:
|
| 26 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# model argument
|
| 30 |
+
parser.add_argument("--model-config", type=str, required=False, default=None)
|
| 31 |
+
parser.add_argument("--model-weight", type=str, required=False, default=None)
|
| 32 |
+
|
| 33 |
+
# sampling argument
|
| 34 |
+
parser.add_argument("--num-step", type=int, required=False, default=20)
|
| 35 |
+
parser.add_argument("--cfg", type=float, required=False, default=2.5)
|
| 36 |
+
parser.add_argument("--pst", type=str2bool, required=False, default=True)
|
| 37 |
+
parser.add_argument("--cps", type=str2bool, required=False, default=False)
|
| 38 |
+
parser.add_argument("--noise-offset", type=float, required=False, default=0.0357)
|
| 39 |
+
parser.add_argument("--seed", type=int, default=0, required=False)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# data argument
|
| 43 |
+
parser.add_argument("--real-dir", type=Path, required=True)
|
| 44 |
+
parser.add_argument("--mask-dir", type=Path, required=False)
|
| 45 |
+
parser.add_argument("--resolution", type=int, default=512, required=False)
|
| 46 |
+
|
| 47 |
+
# runtime argument
|
| 48 |
+
parser.add_argument("--device", type=str, required=False, default="cuda")
|
| 49 |
+
parser.add_argument("--batch-size", type=int, required=False, default=32)
|
| 50 |
+
parser.add_argument("--num-workers", type=int, required=False, default=64)
|
| 51 |
+
|
| 52 |
+
# save argument
|
| 53 |
+
parser.add_argument("--save-dir", type=str, required=True)
|
| 54 |
+
parser.add_argument("--visualize-latent", action="store_true", default=False)
|
| 55 |
+
|
| 56 |
+
return parser.parse_args()
|
| 57 |
+
|
| 58 |
+
def build_pipeline(args):
|
| 59 |
+
from diffusers import DDIMScheduler
|
| 60 |
+
from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode as Removal_Pipeline
|
| 61 |
+
from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
|
| 62 |
+
from utils_train import build_vae
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
model_cfg = load_cfg(args.model_config)
|
| 66 |
+
|
| 67 |
+
removal_model = build_removal_model(model_cfg, 20).to(args.device)
|
| 68 |
+
print(load_removal_model(removal_model, args.model_weight,args.device))
|
| 69 |
+
|
| 70 |
+
vae = build_vae(model_cfg).to(args.device)
|
| 71 |
+
scheduler = DDIMScheduler(
|
| 72 |
+
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
|
| 73 |
+
num_train_timesteps=1000, clip_sample=False)
|
| 74 |
+
|
| 75 |
+
pipe = Removal_Pipeline(
|
| 76 |
+
removal_model=removal_model,
|
| 77 |
+
vae=vae,
|
| 78 |
+
scheduler=scheduler,
|
| 79 |
+
device=args.device,
|
| 80 |
+
dtype=torch.float)
|
| 81 |
+
|
| 82 |
+
return pipe
|
| 83 |
+
|
| 84 |
+
class SAVER:
|
| 85 |
+
@staticmethod
|
| 86 |
+
def save_image(img, name, path):
|
| 87 |
+
img.save(path / name)
|
| 88 |
+
return name
|
| 89 |
+
|
| 90 |
+
@staticmethod
|
| 91 |
+
def save_images(images:List[Image.Image], names:List[str], save_root:str):
|
| 92 |
+
assert len(images) == len(names), \
|
| 93 |
+
f"images and names are not equal: {len(images)}!={len(names)}"
|
| 94 |
+
|
| 95 |
+
pbar_save = tqdm(zip(images, names), total=len(names))
|
| 96 |
+
|
| 97 |
+
cache_names = os.listdir(save_root)
|
| 98 |
+
for image, name in pbar_save:
|
| 99 |
+
if name not in cache_names:
|
| 100 |
+
SAVER.save_image(image, name, save_root)
|
| 101 |
+
|
| 102 |
+
@staticmethod
|
| 103 |
+
def save_images_mt(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
|
| 104 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 105 |
+
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
| 106 |
+
futures = [
|
| 107 |
+
executor.submit(SAVER.save_image, image, name, save_root) for image, name in zip(images, names)]
|
| 108 |
+
|
| 109 |
+
for future in tqdm(futures):
|
| 110 |
+
future.result()
|
| 111 |
+
|
| 112 |
+
@staticmethod
|
| 113 |
+
def save_images_mp(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
|
| 114 |
+
from concurrent.futures import ProcessPoolExecutor
|
| 115 |
+
with ProcessPoolExecutor(max_workers=num_workers) as executor:
|
| 116 |
+
futures = [
|
| 117 |
+
executor.submit(SAVER.save_image, image, name, save_root) for image, name in zip(images, names)]
|
| 118 |
+
|
| 119 |
+
for future in tqdm(futures):
|
| 120 |
+
future.result()
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
|
infer/utils_dataset.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------------- Dataset Utils -----------------------
|
| 2 |
+
|
| 3 |
+
import warnings
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
import math
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from PIL import Image, ImageDraw
|
| 14 |
+
from torch.utils.data import Dataset, DataLoader
|
| 15 |
+
|
| 16 |
+
warnings.filterwarnings("ignore")
|
| 17 |
+
|
| 18 |
+
def RandomBrush(
|
| 19 |
+
max_tries,
|
| 20 |
+
s,
|
| 21 |
+
min_num_vertex=4,
|
| 22 |
+
max_num_vertex=18,
|
| 23 |
+
mean_angle=2*math.pi / 5,
|
| 24 |
+
angle_range=2*math.pi / 15,
|
| 25 |
+
min_width=12,
|
| 26 |
+
max_width=48
|
| 27 |
+
):
|
| 28 |
+
H, W = s, s
|
| 29 |
+
average_radius = math.sqrt(H*H+W*W) / 8
|
| 30 |
+
mask = Image.new('L', (W, H), 0)
|
| 31 |
+
for _ in range(np.random.randint(max_tries)):
|
| 32 |
+
num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
|
| 33 |
+
angle_min = mean_angle - np.random.uniform(0, angle_range)
|
| 34 |
+
angle_max = mean_angle + np.random.uniform(0, angle_range)
|
| 35 |
+
angles = []
|
| 36 |
+
vertex = []
|
| 37 |
+
for i in range(num_vertex):
|
| 38 |
+
if i % 2 == 0:
|
| 39 |
+
angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
|
| 40 |
+
else:
|
| 41 |
+
angles.append(np.random.uniform(angle_min, angle_max))
|
| 42 |
+
|
| 43 |
+
h, w = mask.size
|
| 44 |
+
vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
|
| 45 |
+
for i in range(num_vertex):
|
| 46 |
+
r = np.clip(
|
| 47 |
+
np.random.normal(loc=average_radius, scale=average_radius//2),
|
| 48 |
+
0, 2*average_radius)
|
| 49 |
+
new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
|
| 50 |
+
new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
|
| 51 |
+
vertex.append((int(new_x), int(new_y)))
|
| 52 |
+
|
| 53 |
+
draw = ImageDraw.Draw(mask)
|
| 54 |
+
width = int(np.random.uniform(min_width, max_width))
|
| 55 |
+
draw.line(vertex, fill=1, width=width)
|
| 56 |
+
for v in vertex:
|
| 57 |
+
draw.ellipse((v[0] - width//2,
|
| 58 |
+
v[1] - width//2,
|
| 59 |
+
v[0] + width//2,
|
| 60 |
+
v[1] + width//2),
|
| 61 |
+
fill=1)
|
| 62 |
+
if np.random.random() > 0.5:
|
| 63 |
+
mask.transpose(Image.FLIP_LEFT_RIGHT)
|
| 64 |
+
if np.random.random() > 0.5:
|
| 65 |
+
mask.transpose(Image.FLIP_TOP_BOTTOM)
|
| 66 |
+
mask = np.asarray(mask, np.uint8)
|
| 67 |
+
if np.random.random() > 0.5:
|
| 68 |
+
mask = np.flip(mask, 0)
|
| 69 |
+
if np.random.random() > 0.5:
|
| 70 |
+
mask = np.flip(mask, 1)
|
| 71 |
+
return mask
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def RandomMask(s, hole_range=[0,1]):
|
| 75 |
+
coef = min(hole_range[0] + hole_range[1], 1.0)
|
| 76 |
+
while True:
|
| 77 |
+
mask = np.ones((s, s), np.uint8)
|
| 78 |
+
|
| 79 |
+
def Fill(max_size):
|
| 80 |
+
w, h = np.random.randint(max_size), np.random.randint(max_size)
|
| 81 |
+
ww, hh = w // 2, h // 2
|
| 82 |
+
x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
|
| 83 |
+
mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0
|
| 84 |
+
|
| 85 |
+
def MultiFill(max_tries, max_size):
|
| 86 |
+
for _ in range(np.random.randint(max_tries)):
|
| 87 |
+
Fill(max_size)
|
| 88 |
+
|
| 89 |
+
MultiFill(int(10 * coef), s // 2)
|
| 90 |
+
MultiFill(int(5 * coef), s)
|
| 91 |
+
mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
|
| 92 |
+
hole_ratio = 1 - np.mean(mask)
|
| 93 |
+
if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
|
| 94 |
+
continue
|
| 95 |
+
return (mask * 255).astype(np.uint8)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class InferDataset(Dataset): # ABC
|
| 99 |
+
img_ext = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}
|
| 100 |
+
def __init__(
|
| 101 |
+
self,
|
| 102 |
+
real_dir: Path,
|
| 103 |
+
mask_dir: Optional[Path] = None,
|
| 104 |
+
resolution: int = None
|
| 105 |
+
):
|
| 106 |
+
super(InferDataset, self).__init__()
|
| 107 |
+
|
| 108 |
+
self.img_paths = sorted([i for i in Path(real_dir).iterdir() if i.suffix in self.img_ext])
|
| 109 |
+
self.mask_dir = mask_dir
|
| 110 |
+
self.resolution = resolution
|
| 111 |
+
|
| 112 |
+
def __len__(self):
|
| 113 |
+
return len(self.img_paths)
|
| 114 |
+
|
| 115 |
+
def __getitem__(self, index) -> Tuple[torch.Tensor, np.array, np.array, str]:
|
| 116 |
+
img_path = Path(self.img_paths[index])
|
| 117 |
+
img_name = img_path.stem
|
| 118 |
+
img = Image.open(img_path).convert("RGB")
|
| 119 |
+
if img.size[0] != self.resolution or img.size[1] != self.resolution:
|
| 120 |
+
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
| 121 |
+
assert img.size[0] == self.resolution
|
| 122 |
+
|
| 123 |
+
if self.mask_dir is not None:
|
| 124 |
+
# mask_path = self.mask_dir / f"{img_name}.png"
|
| 125 |
+
mask_path = self.mask_dir / f"img000{img_name}.png"
|
| 126 |
+
mask = Image.open(mask_path).convert("L")
|
| 127 |
+
mask = mask.resize((self.resolution, self.resolution), Image.NEAREST)
|
| 128 |
+
assert mask.size[0] == self.resolution
|
| 129 |
+
else:
|
| 130 |
+
mask = RandomMask(img.size[0])
|
| 131 |
+
mask = Image.fromarray(mask).convert("L")
|
| 132 |
+
|
| 133 |
+
img = np.array(img)
|
| 134 |
+
mask = np.array(mask)[:, :, np.newaxis] // 255
|
| 135 |
+
img = torch.Tensor(img).float() * 2 / 255 - 1
|
| 136 |
+
mask = torch.Tensor(mask).float()
|
| 137 |
+
img = img.permute(2, 0, 1)
|
| 138 |
+
mask = mask.permute(2, 0, 1)
|
| 139 |
+
x = torch.cat([mask - 0.5, img * mask], dim=0)
|
| 140 |
+
return x, np.array(img), mask, img_name
|
| 141 |
+
|
| 142 |
+
class SimpleInferDataset(torch.utils.data.Dataset):
|
| 143 |
+
def __init__(
|
| 144 |
+
self,
|
| 145 |
+
real_dir: Path,
|
| 146 |
+
mask_dir: Path = None,
|
| 147 |
+
resolution: int = 512
|
| 148 |
+
):
|
| 149 |
+
super(SimpleInferDataset, self).__init__()
|
| 150 |
+
|
| 151 |
+
img_extensions = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}
|
| 152 |
+
self.img_paths = sorted([i for i in Path(real_dir).iterdir() if i.suffix in img_extensions])
|
| 153 |
+
self.img_dir = real_dir
|
| 154 |
+
|
| 155 |
+
if mask_dir:
|
| 156 |
+
self.mask_paths = sorted([i for i in Path(mask_dir).iterdir() if i.suffix in img_extensions])
|
| 157 |
+
self.mask_dir = mask_dir
|
| 158 |
+
|
| 159 |
+
self.resolution = resolution
|
| 160 |
+
|
| 161 |
+
def __getitem__(self, index):
|
| 162 |
+
img_path = Path(self.img_paths[index])
|
| 163 |
+
img_name = os.path.basename(img_path)
|
| 164 |
+
|
| 165 |
+
img = Image.open(img_path).convert("RGB")
|
| 166 |
+
|
| 167 |
+
if self.mask_dir:
|
| 168 |
+
mask_path = Path(self.mask_paths[index])
|
| 169 |
+
mask = Image.open(mask_path).convert("L")
|
| 170 |
+
else:
|
| 171 |
+
mask = RandomMask(img.size[0])
|
| 172 |
+
mask = Image.fromarray(mask).convert("L")
|
| 173 |
+
|
| 174 |
+
mask = mask.resize((self.resolution, self.resolution), Image.NEAREST)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if img.size[0] != self.resolution or img.size[1] != self.resolution:
|
| 178 |
+
img = img.resize((self.resolution, self.resolution), Image.BICUBIC)
|
| 179 |
+
|
| 180 |
+
return img, mask, img_name
|
| 181 |
+
|
| 182 |
+
def __len__(self):
|
| 183 |
+
return len(self.img_paths)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def collate_fn(inputs):
|
| 188 |
+
image_list = [i[0] for i in inputs]
|
| 189 |
+
mask_list = [i[1] for i in inputs]
|
| 190 |
+
iname_list = [i[2] for i in inputs]
|
| 191 |
+
return image_list, mask_list, iname_list
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def build_dataloader(args, dataset_class=InferDataset):
|
| 195 |
+
dataset = dataset_class(
|
| 196 |
+
real_dir=args.real_dir,
|
| 197 |
+
mask_dir=args.mask_dir,
|
| 198 |
+
resolution=args.resolution)
|
| 199 |
+
|
| 200 |
+
dataloader = DataLoader(
|
| 201 |
+
dataset,
|
| 202 |
+
shuffle=False,
|
| 203 |
+
batch_size=args.batch_size,
|
| 204 |
+
num_workers=args.num_workers,
|
| 205 |
+
drop_last=False,
|
| 206 |
+
collate_fn = collate_fn,
|
| 207 |
+
pin_memory=True,
|
| 208 |
+
# persistent_workers=True
|
| 209 |
+
)
|
| 210 |
+
|
| 211 |
+
return dataloader
|
library/__init__.py
ADDED
|
File without changes
|
library/chinese_sdxl_train_util.py
ADDED
|
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import gc
|
| 4 |
+
import re
|
| 5 |
+
import json
|
| 6 |
+
import math
|
| 7 |
+
import time
|
| 8 |
+
import toml
|
| 9 |
+
import shutil
|
| 10 |
+
import argparse
|
| 11 |
+
from typing import Optional
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from PIL import Image
|
| 14 |
+
import importlib
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 18 |
+
|
| 19 |
+
from library import train_util
|
| 20 |
+
from transformers import BertTokenizer, BertTokenizerFast, ChineseCLIPTextModel, PreTrainedTokenizerFast, T5Tokenizer, T5ForConditionalGeneration
|
| 21 |
+
|
| 22 |
+
from diffusers import (
|
| 23 |
+
DDPMScheduler,
|
| 24 |
+
EulerAncestralDiscreteScheduler,
|
| 25 |
+
DPMSolverMultistepScheduler,
|
| 26 |
+
DDIMScheduler,
|
| 27 |
+
EulerDiscreteScheduler,
|
| 28 |
+
KDPM2DiscreteScheduler,
|
| 29 |
+
AutoencoderKL,
|
| 30 |
+
UNet2DConditionModel,
|
| 31 |
+
)
|
| 32 |
+
from diffusers.models import UNet2DConditionModel, Transformer2DModel
|
| 33 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
from transformers import BertTokenizerFast, ChineseCLIPTextModel
|
| 36 |
+
from library import train_util
|
| 37 |
+
|
| 38 |
+
# from mmmp_text import DebertaV2Model
|
| 39 |
+
# from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
|
| 40 |
+
# from diffusers_patch.models.vivo_llm2vec import LLM2VecWithoutPool
|
| 41 |
+
# from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
|
| 42 |
+
|
| 43 |
+
DEFAULT_NOISE_OFFSET = 0.0357
|
| 44 |
+
|
| 45 |
+
def load_target_model(args, accelerator, pipe_class, weight_dtype):
|
| 46 |
+
# load models for each process
|
| 47 |
+
for pi in range(accelerator.state.num_processes):
|
| 48 |
+
if pi == accelerator.state.local_process_index:
|
| 49 |
+
print(f"loading model for process {accelerator.process_index}/{accelerator.state.num_processes}")
|
| 50 |
+
|
| 51 |
+
(
|
| 52 |
+
text_encoder1,
|
| 53 |
+
text_encoder2,
|
| 54 |
+
vae,
|
| 55 |
+
unet,
|
| 56 |
+
) = _load_target_model(
|
| 57 |
+
args,
|
| 58 |
+
args.pretrained_model_name_or_path,
|
| 59 |
+
args.vae,
|
| 60 |
+
pipe_class,
|
| 61 |
+
weight_dtype,
|
| 62 |
+
accelerator.device if args.lowram else "cpu",
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
gc.collect()
|
| 66 |
+
torch.cuda.empty_cache()
|
| 67 |
+
accelerator.wait_for_everyone()
|
| 68 |
+
|
| 69 |
+
return text_encoder1, text_encoder2, vae, unet
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def _load_target_model(args, name_or_path: str, vae_path: Optional[str], pipe_class, weight_dtype, device="cpu"):
|
| 73 |
+
name_or_path = os.readlink(name_or_path) if os.path.islink(name_or_path) else name_or_path
|
| 74 |
+
|
| 75 |
+
model_index_path = os.path.join(name_or_path, 'model_index.json')
|
| 76 |
+
model_index = read_json(model_index_path)
|
| 77 |
+
TextEncoderLib1 = model_index['text_encoder'][0]
|
| 78 |
+
TextEncoderLib2 = model_index['text_encoder_2'][0]
|
| 79 |
+
TextEncoderClass1 = model_index['text_encoder'][-1]
|
| 80 |
+
TextEncoderClass2 = model_index['text_encoder_2'][-1]
|
| 81 |
+
|
| 82 |
+
library1 = importlib.import_module(TextEncoderLib1)
|
| 83 |
+
library2 = importlib.import_module(TextEncoderLib2)
|
| 84 |
+
TextEncoderClass1 = getattr(library1, TextEncoderClass1)
|
| 85 |
+
TextEncoderClass2 = getattr(library2, TextEncoderClass2)
|
| 86 |
+
|
| 87 |
+
if 'unet' in model_index:
|
| 88 |
+
UNetClass = eval(model_index['unet'][-1])
|
| 89 |
+
unet_dir = 'unet'
|
| 90 |
+
elif 'transformer' in model_index:
|
| 91 |
+
UNetClass = eval(model_index['transformer'][-1])
|
| 92 |
+
unet_dir = 'transformer'
|
| 93 |
+
|
| 94 |
+
print(f"TextEncoderClass1:{TextEncoderClass1}")
|
| 95 |
+
print(f"TextEncoderClass2:{TextEncoderClass2}")
|
| 96 |
+
print(f"UNetClass:{UNetClass}")
|
| 97 |
+
|
| 98 |
+
vae = AutoencoderKL.from_pretrained(os.path.join(name_or_path, 'vae'), torch_dtype=weight_dtype, low_cpu_mem_usage=False, device_map=None)
|
| 99 |
+
unet = UNetClass.from_pretrained(os.path.join(name_or_path, unet_dir), torch_dtype=weight_dtype, low_cpu_mem_usage=False, device_map=None, ignore_mismatched_sizes=True)
|
| 100 |
+
|
| 101 |
+
text_encoder1 = TextEncoderClass1.from_pretrained(os.path.join(name_or_path, 'text_encoder'), torch_dtype=weight_dtype)
|
| 102 |
+
text_encoder2 = TextEncoderClass2.from_pretrained(os.path.join(name_or_path, 'text_encoder_2'), torch_dtype=weight_dtype)
|
| 103 |
+
|
| 104 |
+
vae_version = vae.config.version if 'version' in vae.config else ''
|
| 105 |
+
if vae_version == 'vivo':
|
| 106 |
+
vae.quant_conv = torch.nn.Identity()
|
| 107 |
+
vae.post_quant_conv = torch.nn.Identity()
|
| 108 |
+
|
| 109 |
+
return text_encoder1, text_encoder2, vae, unet
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def load_tokenizers(args: argparse.Namespace):
|
| 113 |
+
print("prepare tokenizers")
|
| 114 |
+
model_index_path = os.path.join(args.pretrained_model_name_or_path, 'model_index.json')
|
| 115 |
+
model_index = read_json(model_index_path)
|
| 116 |
+
ToeknierLib1 = model_index['tokenizer'][0]
|
| 117 |
+
ToeknierLib2 = model_index['tokenizer_2'][0]
|
| 118 |
+
TokenierClass1 = model_index['tokenizer'][-1]
|
| 119 |
+
TokenierClass2 = "BertTokenizer" # ToDo: model_index['tokenizer_2'][-1]
|
| 120 |
+
|
| 121 |
+
library1 = importlib.import_module(ToeknierLib1)
|
| 122 |
+
library2 = importlib.import_module(ToeknierLib2)
|
| 123 |
+
TokenierClass1 = getattr(library1, TokenierClass1)
|
| 124 |
+
TokenierClass2 = getattr(library2, TokenierClass2)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
tokenizer_1 = TokenierClass1.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
|
| 128 |
+
tokenizer_2 = TokenierClass2.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer_2')
|
| 129 |
+
tokeniers = [tokenizer_1, tokenizer_2]
|
| 130 |
+
|
| 131 |
+
if hasattr(args, "max_token_length") and args.max_token_length is not None:
|
| 132 |
+
print(f"update token length: {args.max_token_length}")
|
| 133 |
+
|
| 134 |
+
return tokeniers
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def get_hidden_states_sdxl(
|
| 138 |
+
input_ids1: torch.Tensor,
|
| 139 |
+
input_ids2: torch.Tensor,
|
| 140 |
+
tokenizer1: BertTokenizerFast,
|
| 141 |
+
tokenizer2: BertTokenizerFast,
|
| 142 |
+
text_encoder1: ChineseCLIPTextModel,
|
| 143 |
+
text_encoder2: ChineseCLIPTextModel,
|
| 144 |
+
weight_dtype: Optional[str] = None,
|
| 145 |
+
attention_mask1: torch.Tensor = None,
|
| 146 |
+
attention_mask2: torch.Tensor = None,
|
| 147 |
+
):
|
| 148 |
+
# input_ids: b,n,77 -> b*n, 77
|
| 149 |
+
b_size = input_ids1.size()[0]
|
| 150 |
+
input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length)) # batch_size*n, 77
|
| 151 |
+
input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length)) # batch_size*n, 77
|
| 152 |
+
if attention_mask1 is not None:
|
| 153 |
+
attention_mask1 = attention_mask1.reshape((-1, tokenizer1.model_max_length))
|
| 154 |
+
attention_mask2 = attention_mask2.reshape((-1, tokenizer2.model_max_length))
|
| 155 |
+
|
| 156 |
+
hidden_states1, _ = encode_token(input_ids1, attention_mask1, text_encoder1)
|
| 157 |
+
hidden_states2, pool2 = encode_token(input_ids2, attention_mask2, text_encoder2)
|
| 158 |
+
|
| 159 |
+
hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
|
| 160 |
+
hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
|
| 161 |
+
if weight_dtype is not None:
|
| 162 |
+
# this is required for additional network training
|
| 163 |
+
hidden_states1 = hidden_states1.to(weight_dtype)
|
| 164 |
+
hidden_states2 = hidden_states2.to(weight_dtype)
|
| 165 |
+
|
| 166 |
+
return hidden_states1, hidden_states2, pool2
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def encode_token(input_ids, attention_mask, text_encoder):
|
| 170 |
+
|
| 171 |
+
# T5
|
| 172 |
+
if isinstance(text_encoder, T5ForConditionalGeneration):
|
| 173 |
+
prompt_embeds = text_encoder.encoder(
|
| 174 |
+
input_ids,
|
| 175 |
+
attention_mask=attention_mask,
|
| 176 |
+
output_hidden_states=True,
|
| 177 |
+
)
|
| 178 |
+
pooled_prompt_embeds = None
|
| 179 |
+
prompt_embeds = prompt_embeds.hidden_states[-1]
|
| 180 |
+
|
| 181 |
+
# clip Bert
|
| 182 |
+
elif isinstance(text_encoder, ChineseCLIPTextModel):
|
| 183 |
+
prompt_embeds = text_encoder(
|
| 184 |
+
input_ids,
|
| 185 |
+
attention_mask=attention_mask,
|
| 186 |
+
output_hidden_states=True,
|
| 187 |
+
)
|
| 188 |
+
# We are only ALWAYS interested in the pooled output of the final text encoder
|
| 189 |
+
pooled_prompt_embeds = prompt_embeds['pooler_output']
|
| 190 |
+
prompt_embeds = prompt_embeds.hidden_states[-2]
|
| 191 |
+
|
| 192 |
+
# 3mp_Bert\Qwen2Model\LLM2VecWithoutPool\GLMModel
|
| 193 |
+
else:
|
| 194 |
+
prompt_embeds = text_encoder(
|
| 195 |
+
input_ids,
|
| 196 |
+
attention_mask=attention_mask,
|
| 197 |
+
output_hidden_states=True,
|
| 198 |
+
)
|
| 199 |
+
|
| 200 |
+
if 'last_hidden_states' in prompt_embeds:
|
| 201 |
+
prompt_embeds = prompt_embeds.last_hidden_states
|
| 202 |
+
else:
|
| 203 |
+
prompt_embeds = prompt_embeds.last_hidden_state
|
| 204 |
+
|
| 205 |
+
pooled_prompt_embeds = prompt_embeds.mean(dim=1)
|
| 206 |
+
|
| 207 |
+
return prompt_embeds, pooled_prompt_embeds
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def prepare_logging(args: argparse.Namespace, is_main_process):
|
| 213 |
+
if args.logging_dir is None:
|
| 214 |
+
logging_dir = None
|
| 215 |
+
else:
|
| 216 |
+
log_prefix = "" if args.log_prefix is None else args.log_prefix
|
| 217 |
+
logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())
|
| 218 |
+
|
| 219 |
+
log_with = args.log_with
|
| 220 |
+
if log_with in ["tensorboard", "all"]:
|
| 221 |
+
if logging_dir is None:
|
| 222 |
+
raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
|
| 223 |
+
tensorboard_dir = os.path.join(logging_dir, 'tensorboard')
|
| 224 |
+
|
| 225 |
+
writer = None
|
| 226 |
+
if is_main_process:
|
| 227 |
+
os.makedirs(logging_dir, exist_ok=True)
|
| 228 |
+
os.makedirs(tensorboard_dir, exist_ok=True)
|
| 229 |
+
|
| 230 |
+
if args.script_args:
|
| 231 |
+
sh_basename = os.path.basename(args.script_args)
|
| 232 |
+
sh_dst_path = os.path.join(logging_dir, sh_basename)
|
| 233 |
+
data_basename = os.path.basename(args.dataset_config)
|
| 234 |
+
data_dst_path = os.path.join(logging_dir, data_basename)
|
| 235 |
+
shutil.copyfile(args.script_args, sh_dst_path)
|
| 236 |
+
shutil.copyfile(args.dataset_config, data_dst_path)
|
| 237 |
+
|
| 238 |
+
writer = SummaryWriter(tensorboard_dir)
|
| 239 |
+
|
| 240 |
+
return writer
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
|
| 244 |
+
|
| 245 |
+
if args.clip_skip is not None:
|
| 246 |
+
print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")
|
| 247 |
+
|
| 248 |
+
if args.multires_noise_iterations:
|
| 249 |
+
print(
|
| 250 |
+
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations"
|
| 251 |
+
)
|
| 252 |
+
else:
|
| 253 |
+
if args.noise_offset is None:
|
| 254 |
+
args.noise_offset = DEFAULT_NOISE_OFFSET
|
| 255 |
+
elif args.noise_offset != DEFAULT_NOISE_OFFSET:
|
| 256 |
+
print(
|
| 257 |
+
f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
|
| 258 |
+
)
|
| 259 |
+
print(f"noise_offset is set to {args.noise_offset}")
|
| 260 |
+
|
| 261 |
+
assert (
|
| 262 |
+
not hasattr(args, "weighted_captions") or not args.weighted_captions
|
| 263 |
+
), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"
|
| 264 |
+
|
| 265 |
+
if supportTextEncoderCaching:
|
| 266 |
+
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
|
| 267 |
+
args.cache_text_encoder_outputs = True
|
| 268 |
+
print(
|
| 269 |
+
"cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
|
| 270 |
+
+ "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
| 275 |
+
"""
|
| 276 |
+
Create sinusoidal timestep embeddings.
|
| 277 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
| 278 |
+
These may be fractional.
|
| 279 |
+
:param dim: the dimension of the output.
|
| 280 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
| 281 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
| 282 |
+
"""
|
| 283 |
+
half = dim // 2
|
| 284 |
+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
| 285 |
+
device=timesteps.device
|
| 286 |
+
)
|
| 287 |
+
args = timesteps[:, None].float() * freqs[None]
|
| 288 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
| 289 |
+
if dim % 2:
|
| 290 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
| 291 |
+
return embedding
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def get_timestep_embedding(x, outdim):
|
| 295 |
+
assert len(x.shape) == 2
|
| 296 |
+
b, dims = x.shape[0], x.shape[1]
|
| 297 |
+
x = torch.flatten(x)
|
| 298 |
+
emb = timestep_embedding(x, outdim)
|
| 299 |
+
emb = torch.reshape(emb, (b, dims * outdim))
|
| 300 |
+
return emb
|
| 301 |
+
|
| 302 |
+
|
| 303 |
+
def get_size_embeddings(orig_size, crop_size, target_size, device):
|
| 304 |
+
emb1 = get_timestep_embedding(orig_size, 256)
|
| 305 |
+
emb2 = get_timestep_embedding(crop_size, 256)
|
| 306 |
+
emb3 = get_timestep_embedding(target_size, 256)
|
| 307 |
+
vector = torch.cat([emb1, emb2, emb3], dim=1).to(device)
|
| 308 |
+
return vector
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
|
| 314 |
+
)
|
| 315 |
+
parser.add_argument(
|
| 316 |
+
"--cache_text_encoder_outputs_to_disk",
|
| 317 |
+
action="store_true",
|
| 318 |
+
help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def set_unet_eff_attn(unet, mem_eff_attn, xformers, sdpa):
|
| 323 |
+
if mem_eff_attn:
|
| 324 |
+
print("Enable memory efficient attention for U-Net")
|
| 325 |
+
unet.set_use_memory_efficient_attention_xformers(False, True)
|
| 326 |
+
elif xformers:
|
| 327 |
+
print("Enable xformers for U-Net")
|
| 328 |
+
try:
|
| 329 |
+
import xformers.ops
|
| 330 |
+
except ImportError:
|
| 331 |
+
raise ImportError("No xformers / xformersがインストールされていないようです")
|
| 332 |
+
|
| 333 |
+
unet.set_use_memory_efficient_attention_xformers(True, False)
|
| 334 |
+
elif sdpa:
|
| 335 |
+
print("Enable SDPA for U-Net")
|
| 336 |
+
unet.set_use_sdpa(True)
|
| 337 |
+
|
| 338 |
+
def set_diffusers_xformers_flag(model, valid):
|
| 339 |
+
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
| 340 |
+
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
| 341 |
+
module.set_use_memory_efficient_attention_xformers(valid)
|
| 342 |
+
|
| 343 |
+
for child in module.children():
|
| 344 |
+
fn_recursive_set_mem_eff(child)
|
| 345 |
+
|
| 346 |
+
fn_recursive_set_mem_eff(model)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
def read_json(json_path):
|
| 350 |
+
return json.load(open(json_path))
|
library/custom_train_functions.py
ADDED
|
@@ -0,0 +1,515 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import argparse
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
from typing import List, Optional, Union
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def prepare_scheduler_for_custom_training(noise_scheduler, device):
|
| 9 |
+
if hasattr(noise_scheduler, "all_snr"):
|
| 10 |
+
return
|
| 11 |
+
|
| 12 |
+
alphas_cumprod = noise_scheduler.alphas_cumprod
|
| 13 |
+
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
|
| 14 |
+
sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
|
| 15 |
+
alpha = sqrt_alphas_cumprod
|
| 16 |
+
sigma = sqrt_one_minus_alphas_cumprod
|
| 17 |
+
all_snr = (alpha / sigma) ** 2
|
| 18 |
+
|
| 19 |
+
noise_scheduler.all_snr = all_snr.to(device)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
|
| 23 |
+
# fix beta: zero terminal SNR
|
| 24 |
+
print(f"fix noise scheduler betas: https://arxiv.org/abs/2305.08891")
|
| 25 |
+
|
| 26 |
+
def enforce_zero_terminal_snr(betas):
|
| 27 |
+
# Convert betas to alphas_bar_sqrt
|
| 28 |
+
alphas = 1 - betas
|
| 29 |
+
alphas_bar = alphas.cumprod(0)
|
| 30 |
+
alphas_bar_sqrt = alphas_bar.sqrt()
|
| 31 |
+
|
| 32 |
+
# Store old values.
|
| 33 |
+
alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
|
| 34 |
+
alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
|
| 35 |
+
# Shift so last timestep is zero.
|
| 36 |
+
alphas_bar_sqrt -= alphas_bar_sqrt_T
|
| 37 |
+
# Scale so first timestep is back to old value.
|
| 38 |
+
alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
|
| 39 |
+
|
| 40 |
+
# Convert alphas_bar_sqrt to betas
|
| 41 |
+
alphas_bar = alphas_bar_sqrt**2
|
| 42 |
+
alphas = alphas_bar[1:] / alphas_bar[:-1]
|
| 43 |
+
alphas = torch.cat([alphas_bar[0:1], alphas])
|
| 44 |
+
betas = 1 - alphas
|
| 45 |
+
return betas
|
| 46 |
+
|
| 47 |
+
betas = noise_scheduler.betas
|
| 48 |
+
betas = enforce_zero_terminal_snr(betas)
|
| 49 |
+
alphas = 1.0 - betas
|
| 50 |
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
| 51 |
+
|
| 52 |
+
# print("original:", noise_scheduler.betas)
|
| 53 |
+
# print("fixed:", betas)
|
| 54 |
+
|
| 55 |
+
noise_scheduler.betas = betas
|
| 56 |
+
noise_scheduler.alphas = alphas
|
| 57 |
+
noise_scheduler.alphas_cumprod = alphas_cumprod
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
|
| 61 |
+
snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
|
| 62 |
+
gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr)
|
| 63 |
+
snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float().to(loss.device) # from paper
|
| 64 |
+
loss = loss * snr_weight
|
| 65 |
+
return loss
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
|
| 69 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
| 70 |
+
loss = loss * scale
|
| 71 |
+
return loss
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_snr_scale(timesteps, noise_scheduler):
|
| 75 |
+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
|
| 76 |
+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
|
| 77 |
+
scale = snr_t / (snr_t + 1)
|
| 78 |
+
# # show debug info
|
| 79 |
+
# print(f"timesteps: {timesteps}, snr_t: {snr_t}, scale: {scale}")
|
| 80 |
+
return scale
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
|
| 84 |
+
scale = get_snr_scale(timesteps, noise_scheduler)
|
| 85 |
+
# print(f"add v-prediction like loss: {v_pred_like_loss}, scale: {scale}, loss: {loss}, time: {timesteps}")
|
| 86 |
+
loss = loss + loss / scale * v_pred_like_loss
|
| 87 |
+
return loss
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# TODO train_utilと分散しているのでどちらかに寄せる
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
|
| 94 |
+
parser.add_argument(
|
| 95 |
+
"--min_snr_gamma",
|
| 96 |
+
type=float,
|
| 97 |
+
default=None,
|
| 98 |
+
help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
|
| 99 |
+
)
|
| 100 |
+
parser.add_argument(
|
| 101 |
+
"--scale_v_pred_loss_like_noise_pred",
|
| 102 |
+
action="store_true",
|
| 103 |
+
help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
|
| 104 |
+
)
|
| 105 |
+
parser.add_argument(
|
| 106 |
+
"--v_pred_like_loss",
|
| 107 |
+
type=float,
|
| 108 |
+
default=None,
|
| 109 |
+
help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
|
| 110 |
+
)
|
| 111 |
+
if support_weighted_captions:
|
| 112 |
+
parser.add_argument(
|
| 113 |
+
"--weighted_captions",
|
| 114 |
+
action="store_true",
|
| 115 |
+
default=False,
|
| 116 |
+
help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
re_attention = re.compile(
|
| 121 |
+
r"""
|
| 122 |
+
\\\(|
|
| 123 |
+
\\\)|
|
| 124 |
+
\\\[|
|
| 125 |
+
\\]|
|
| 126 |
+
\\\\|
|
| 127 |
+
\\|
|
| 128 |
+
\(|
|
| 129 |
+
\[|
|
| 130 |
+
:([+-]?[.\d]+)\)|
|
| 131 |
+
\)|
|
| 132 |
+
]|
|
| 133 |
+
[^\\()\[\]:]+|
|
| 134 |
+
:
|
| 135 |
+
""",
|
| 136 |
+
re.X,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def parse_prompt_attention(text):
|
| 141 |
+
"""
|
| 142 |
+
Parses a string with attention tokens and returns a list of pairs: text and its associated weight.
|
| 143 |
+
Accepted tokens are:
|
| 144 |
+
(abc) - increases attention to abc by a multiplier of 1.1
|
| 145 |
+
(abc:3.12) - increases attention to abc by a multiplier of 3.12
|
| 146 |
+
[abc] - decreases attention to abc by a multiplier of 1.1
|
| 147 |
+
\( - literal character '('
|
| 148 |
+
\[ - literal character '['
|
| 149 |
+
\) - literal character ')'
|
| 150 |
+
\] - literal character ']'
|
| 151 |
+
\\ - literal character '\'
|
| 152 |
+
anything else - just text
|
| 153 |
+
>>> parse_prompt_attention('normal text')
|
| 154 |
+
[['normal text', 1.0]]
|
| 155 |
+
>>> parse_prompt_attention('an (important) word')
|
| 156 |
+
[['an ', 1.0], ['important', 1.1], [' word', 1.0]]
|
| 157 |
+
>>> parse_prompt_attention('(unbalanced')
|
| 158 |
+
[['unbalanced', 1.1]]
|
| 159 |
+
>>> parse_prompt_attention('\(literal\]')
|
| 160 |
+
[['(literal]', 1.0]]
|
| 161 |
+
>>> parse_prompt_attention('(unnecessary)(parens)')
|
| 162 |
+
[['unnecessaryparens', 1.1]]
|
| 163 |
+
>>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
|
| 164 |
+
[['a ', 1.0],
|
| 165 |
+
['house', 1.5730000000000004],
|
| 166 |
+
[' ', 1.1],
|
| 167 |
+
['on', 1.0],
|
| 168 |
+
[' a ', 1.1],
|
| 169 |
+
['hill', 0.55],
|
| 170 |
+
[', sun, ', 1.1],
|
| 171 |
+
['sky', 1.4641000000000006],
|
| 172 |
+
['.', 1.1]]
|
| 173 |
+
"""
|
| 174 |
+
|
| 175 |
+
res = []
|
| 176 |
+
round_brackets = []
|
| 177 |
+
square_brackets = []
|
| 178 |
+
|
| 179 |
+
round_bracket_multiplier = 1.1
|
| 180 |
+
square_bracket_multiplier = 1 / 1.1
|
| 181 |
+
|
| 182 |
+
def multiply_range(start_position, multiplier):
|
| 183 |
+
for p in range(start_position, len(res)):
|
| 184 |
+
res[p][1] *= multiplier
|
| 185 |
+
|
| 186 |
+
for m in re_attention.finditer(text):
|
| 187 |
+
text = m.group(0)
|
| 188 |
+
weight = m.group(1)
|
| 189 |
+
|
| 190 |
+
if text.startswith("\\"):
|
| 191 |
+
res.append([text[1:], 1.0])
|
| 192 |
+
elif text == "(":
|
| 193 |
+
round_brackets.append(len(res))
|
| 194 |
+
elif text == "[":
|
| 195 |
+
square_brackets.append(len(res))
|
| 196 |
+
elif weight is not None and len(round_brackets) > 0:
|
| 197 |
+
multiply_range(round_brackets.pop(), float(weight))
|
| 198 |
+
elif text == ")" and len(round_brackets) > 0:
|
| 199 |
+
multiply_range(round_brackets.pop(), round_bracket_multiplier)
|
| 200 |
+
elif text == "]" and len(square_brackets) > 0:
|
| 201 |
+
multiply_range(square_brackets.pop(), square_bracket_multiplier)
|
| 202 |
+
else:
|
| 203 |
+
res.append([text, 1.0])
|
| 204 |
+
|
| 205 |
+
for pos in round_brackets:
|
| 206 |
+
multiply_range(pos, round_bracket_multiplier)
|
| 207 |
+
|
| 208 |
+
for pos in square_brackets:
|
| 209 |
+
multiply_range(pos, square_bracket_multiplier)
|
| 210 |
+
|
| 211 |
+
if len(res) == 0:
|
| 212 |
+
res = [["", 1.0]]
|
| 213 |
+
|
| 214 |
+
# merge runs of identical weights
|
| 215 |
+
i = 0
|
| 216 |
+
while i + 1 < len(res):
|
| 217 |
+
if res[i][1] == res[i + 1][1]:
|
| 218 |
+
res[i][0] += res[i + 1][0]
|
| 219 |
+
res.pop(i + 1)
|
| 220 |
+
else:
|
| 221 |
+
i += 1
|
| 222 |
+
|
| 223 |
+
return res
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
|
| 227 |
+
r"""
|
| 228 |
+
Tokenize a list of prompts and return its tokens with weights of each token.
|
| 229 |
+
|
| 230 |
+
No padding, starting or ending token is included.
|
| 231 |
+
"""
|
| 232 |
+
tokens = []
|
| 233 |
+
weights = []
|
| 234 |
+
truncated = False
|
| 235 |
+
for text in prompt:
|
| 236 |
+
texts_and_weights = parse_prompt_attention(text)
|
| 237 |
+
text_token = []
|
| 238 |
+
text_weight = []
|
| 239 |
+
for word, weight in texts_and_weights:
|
| 240 |
+
# tokenize and discard the starting and the ending token
|
| 241 |
+
token = tokenizer(word).input_ids[1:-1]
|
| 242 |
+
text_token += token
|
| 243 |
+
# copy the weight by length of token
|
| 244 |
+
text_weight += [weight] * len(token)
|
| 245 |
+
# stop if the text is too long (longer than truncation limit)
|
| 246 |
+
if len(text_token) > max_length:
|
| 247 |
+
truncated = True
|
| 248 |
+
break
|
| 249 |
+
# truncate
|
| 250 |
+
if len(text_token) > max_length:
|
| 251 |
+
truncated = True
|
| 252 |
+
text_token = text_token[:max_length]
|
| 253 |
+
text_weight = text_weight[:max_length]
|
| 254 |
+
tokens.append(text_token)
|
| 255 |
+
weights.append(text_weight)
|
| 256 |
+
if truncated:
|
| 257 |
+
print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
|
| 258 |
+
return tokens, weights
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
|
| 262 |
+
r"""
|
| 263 |
+
Pad the tokens (with starting and ending tokens) and weights (with 1.0) to max_length.
|
| 264 |
+
"""
|
| 265 |
+
max_embeddings_multiples = (max_length - 2) // (chunk_length - 2)
|
| 266 |
+
weights_length = max_length if no_boseos_middle else max_embeddings_multiples * chunk_length
|
| 267 |
+
for i in range(len(tokens)):
|
| 268 |
+
tokens[i] = [bos] + tokens[i] + [eos] * (max_length - 1 - len(tokens[i]))
|
| 269 |
+
if no_boseos_middle:
|
| 270 |
+
weights[i] = [1.0] + weights[i] + [1.0] * (max_length - 1 - len(weights[i]))
|
| 271 |
+
else:
|
| 272 |
+
w = []
|
| 273 |
+
if len(weights[i]) == 0:
|
| 274 |
+
w = [1.0] * weights_length
|
| 275 |
+
else:
|
| 276 |
+
for j in range(max_embeddings_multiples):
|
| 277 |
+
w.append(1.0) # weight for starting token in this chunk
|
| 278 |
+
w += weights[i][j * (chunk_length - 2) : min(len(weights[i]), (j + 1) * (chunk_length - 2))]
|
| 279 |
+
w.append(1.0) # weight for ending token in this chunk
|
| 280 |
+
w += [1.0] * (weights_length - len(w))
|
| 281 |
+
weights[i] = w[:]
|
| 282 |
+
|
| 283 |
+
return tokens, weights
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def get_unweighted_text_embeddings(
|
| 287 |
+
tokenizer,
|
| 288 |
+
text_encoder,
|
| 289 |
+
text_input: torch.Tensor,
|
| 290 |
+
chunk_length: int,
|
| 291 |
+
clip_skip: int,
|
| 292 |
+
eos: int,
|
| 293 |
+
pad: int,
|
| 294 |
+
no_boseos_middle: Optional[bool] = True,
|
| 295 |
+
):
|
| 296 |
+
"""
|
| 297 |
+
When the length of tokens is a multiple of the capacity of the text encoder,
|
| 298 |
+
it should be split into chunks and sent to the text encoder individually.
|
| 299 |
+
"""
|
| 300 |
+
max_embeddings_multiples = (text_input.shape[1] - 2) // (chunk_length - 2)
|
| 301 |
+
if max_embeddings_multiples > 1:
|
| 302 |
+
text_embeddings = []
|
| 303 |
+
for i in range(max_embeddings_multiples):
|
| 304 |
+
# extract the i-th chunk
|
| 305 |
+
text_input_chunk = text_input[:, i * (chunk_length - 2) : (i + 1) * (chunk_length - 2) + 2].clone()
|
| 306 |
+
|
| 307 |
+
# cover the head and the tail by the starting and the ending tokens
|
| 308 |
+
text_input_chunk[:, 0] = text_input[0, 0]
|
| 309 |
+
if pad == eos: # v1
|
| 310 |
+
text_input_chunk[:, -1] = text_input[0, -1]
|
| 311 |
+
else: # v2
|
| 312 |
+
for j in range(len(text_input_chunk)):
|
| 313 |
+
if text_input_chunk[j, -1] != eos and text_input_chunk[j, -1] != pad: # 最後に普通の文字がある
|
| 314 |
+
text_input_chunk[j, -1] = eos
|
| 315 |
+
if text_input_chunk[j, 1] == pad: # BOSだけであとはPAD
|
| 316 |
+
text_input_chunk[j, 1] = eos
|
| 317 |
+
|
| 318 |
+
if clip_skip is None or clip_skip == 1:
|
| 319 |
+
text_embedding = text_encoder(text_input_chunk)[0]
|
| 320 |
+
else:
|
| 321 |
+
enc_out = text_encoder(text_input_chunk, output_hidden_states=True, return_dict=True)
|
| 322 |
+
text_embedding = enc_out["hidden_states"][-clip_skip]
|
| 323 |
+
text_embedding = text_encoder.text_model.final_layer_norm(text_embedding)
|
| 324 |
+
|
| 325 |
+
if no_boseos_middle:
|
| 326 |
+
if i == 0:
|
| 327 |
+
# discard the ending token
|
| 328 |
+
text_embedding = text_embedding[:, :-1]
|
| 329 |
+
elif i == max_embeddings_multiples - 1:
|
| 330 |
+
# discard the starting token
|
| 331 |
+
text_embedding = text_embedding[:, 1:]
|
| 332 |
+
else:
|
| 333 |
+
# discard both starting and ending tokens
|
| 334 |
+
text_embedding = text_embedding[:, 1:-1]
|
| 335 |
+
|
| 336 |
+
text_embeddings.append(text_embedding)
|
| 337 |
+
text_embeddings = torch.concat(text_embeddings, axis=1)
|
| 338 |
+
else:
|
| 339 |
+
if clip_skip is None or clip_skip == 1:
|
| 340 |
+
text_embeddings = text_encoder(text_input)[0]
|
| 341 |
+
else:
|
| 342 |
+
enc_out = text_encoder(text_input, output_hidden_states=True, return_dict=True)
|
| 343 |
+
text_embeddings = enc_out["hidden_states"][-clip_skip]
|
| 344 |
+
text_embeddings = text_encoder.text_model.final_layer_norm(text_embeddings)
|
| 345 |
+
return text_embeddings
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
def get_weighted_text_embeddings(
|
| 349 |
+
tokenizer,
|
| 350 |
+
text_encoder,
|
| 351 |
+
prompt: Union[str, List[str]],
|
| 352 |
+
device,
|
| 353 |
+
max_embeddings_multiples: Optional[int] = 3,
|
| 354 |
+
no_boseos_middle: Optional[bool] = False,
|
| 355 |
+
clip_skip=None,
|
| 356 |
+
):
|
| 357 |
+
r"""
|
| 358 |
+
Prompts can be assigned with local weights using brackets. For example,
|
| 359 |
+
prompt 'A (very beautiful) masterpiece' highlights the words 'very beautiful',
|
| 360 |
+
and the embedding tokens corresponding to the words get multiplied by a constant, 1.1.
|
| 361 |
+
|
| 362 |
+
Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
|
| 363 |
+
|
| 364 |
+
Args:
|
| 365 |
+
prompt (`str` or `List[str]`):
|
| 366 |
+
The prompt or prompts to guide the image generation.
|
| 367 |
+
max_embeddings_multiples (`int`, *optional*, defaults to `3`):
|
| 368 |
+
The max multiple length of prompt embeddings compared to the max output length of text encoder.
|
| 369 |
+
no_boseos_middle (`bool`, *optional*, defaults to `False`):
|
| 370 |
+
If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
|
| 371 |
+
ending token in each of the chunk in the middle.
|
| 372 |
+
skip_parsing (`bool`, *optional*, defaults to `False`):
|
| 373 |
+
Skip the parsing of brackets.
|
| 374 |
+
skip_weighting (`bool`, *optional*, defaults to `False`):
|
| 375 |
+
Skip the weighting. When the parsing is skipped, it is forced True.
|
| 376 |
+
"""
|
| 377 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 378 |
+
if isinstance(prompt, str):
|
| 379 |
+
prompt = [prompt]
|
| 380 |
+
|
| 381 |
+
prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, max_length - 2)
|
| 382 |
+
|
| 383 |
+
# round up the longest length of tokens to a multiple of (model_max_length - 2)
|
| 384 |
+
max_length = max([len(token) for token in prompt_tokens])
|
| 385 |
+
|
| 386 |
+
max_embeddings_multiples = min(
|
| 387 |
+
max_embeddings_multiples,
|
| 388 |
+
(max_length - 1) // (tokenizer.model_max_length - 2) + 1,
|
| 389 |
+
)
|
| 390 |
+
max_embeddings_multiples = max(1, max_embeddings_multiples)
|
| 391 |
+
max_length = (tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
|
| 392 |
+
|
| 393 |
+
# pad the length of tokens and weights
|
| 394 |
+
bos = tokenizer.bos_token_id
|
| 395 |
+
eos = tokenizer.eos_token_id
|
| 396 |
+
pad = tokenizer.pad_token_id
|
| 397 |
+
prompt_tokens, prompt_weights = pad_tokens_and_weights(
|
| 398 |
+
prompt_tokens,
|
| 399 |
+
prompt_weights,
|
| 400 |
+
max_length,
|
| 401 |
+
bos,
|
| 402 |
+
eos,
|
| 403 |
+
no_boseos_middle=no_boseos_middle,
|
| 404 |
+
chunk_length=tokenizer.model_max_length,
|
| 405 |
+
)
|
| 406 |
+
prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)
|
| 407 |
+
|
| 408 |
+
# get the embeddings
|
| 409 |
+
text_embeddings = get_unweighted_text_embeddings(
|
| 410 |
+
tokenizer,
|
| 411 |
+
text_encoder,
|
| 412 |
+
prompt_tokens,
|
| 413 |
+
tokenizer.model_max_length,
|
| 414 |
+
clip_skip,
|
| 415 |
+
eos,
|
| 416 |
+
pad,
|
| 417 |
+
no_boseos_middle=no_boseos_middle,
|
| 418 |
+
)
|
| 419 |
+
prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=device)
|
| 420 |
+
|
| 421 |
+
# assign weights to the prompts and normalize in the sense of mean
|
| 422 |
+
previous_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 423 |
+
text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
|
| 424 |
+
current_mean = text_embeddings.float().mean(axis=[-2, -1]).to(text_embeddings.dtype)
|
| 425 |
+
text_embeddings = text_embeddings * (previous_mean / current_mean).unsqueeze(-1).unsqueeze(-1)
|
| 426 |
+
|
| 427 |
+
return text_embeddings
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
# https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
|
| 431 |
+
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
|
| 432 |
+
b, c, w, h = noise.shape # EDIT: w and h get over-written, rename for a different variant!
|
| 433 |
+
u = torch.nn.Upsample(size=(w, h), mode="bilinear").to(device)
|
| 434 |
+
for i in range(iterations):
|
| 435 |
+
r = random.random() * 2 + 2 # Rather than always going 2x,
|
| 436 |
+
wn, hn = max(1, int(w / (r**i))), max(1, int(h / (r**i)))
|
| 437 |
+
noise += u(torch.randn(b, c, wn, hn).to(device)) * discount**i
|
| 438 |
+
if wn == 1 or hn == 1:
|
| 439 |
+
break # Lowest resolution is 1x1
|
| 440 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
| 444 |
+
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
|
| 445 |
+
if noise_offset is None:
|
| 446 |
+
return noise
|
| 447 |
+
if adaptive_noise_scale is not None:
|
| 448 |
+
# latent shape: (batch_size, channels, height, width)
|
| 449 |
+
# abs mean value for each channel
|
| 450 |
+
latent_mean = torch.abs(latents.mean(dim=(2, 3), keepdim=True))
|
| 451 |
+
|
| 452 |
+
# multiply adaptive noise scale to the mean value and add it to the noise offset
|
| 453 |
+
noise_offset = noise_offset + adaptive_noise_scale * latent_mean
|
| 454 |
+
noise_offset = torch.clamp(noise_offset, 0.0, None) # in case of adaptive noise scale is negative
|
| 455 |
+
|
| 456 |
+
noise = noise + noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)
|
| 457 |
+
return noise
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
"""
|
| 461 |
+
##########################################
|
| 462 |
+
# Perlin Noise
|
| 463 |
+
def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
|
| 464 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
| 465 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
| 466 |
+
|
| 467 |
+
grid = (
|
| 468 |
+
torch.stack(
|
| 469 |
+
torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
|
| 470 |
+
dim=-1,
|
| 471 |
+
)
|
| 472 |
+
% 1
|
| 473 |
+
)
|
| 474 |
+
angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
|
| 475 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
| 476 |
+
|
| 477 |
+
tile_grads = (
|
| 478 |
+
lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
|
| 479 |
+
.repeat_interleave(d[0], 0)
|
| 480 |
+
.repeat_interleave(d[1], 1)
|
| 481 |
+
)
|
| 482 |
+
dot = lambda grad, shift: (
|
| 483 |
+
torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
|
| 484 |
+
* grad[: shape[0], : shape[1]]
|
| 485 |
+
).sum(dim=-1)
|
| 486 |
+
|
| 487 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
| 488 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
| 489 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
| 490 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
| 491 |
+
t = fade(grid[: shape[0], : shape[1]])
|
| 492 |
+
return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
|
| 496 |
+
noise = torch.zeros(shape, device=device)
|
| 497 |
+
frequency = 1
|
| 498 |
+
amplitude = 1
|
| 499 |
+
for _ in range(octaves):
|
| 500 |
+
noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
|
| 501 |
+
frequency *= 2
|
| 502 |
+
amplitude *= persistence
|
| 503 |
+
return noise
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
def perlin_noise(noise, device, octaves):
|
| 507 |
+
_, c, w, h = noise.shape
|
| 508 |
+
perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
|
| 509 |
+
noise_perlin = []
|
| 510 |
+
for _ in range(c):
|
| 511 |
+
noise_perlin.append(perlin())
|
| 512 |
+
noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
|
| 513 |
+
noise += noise_perlin # broadcast for each batch
|
| 514 |
+
return noise / noise.std() # Scaled back to roughly unit variance
|
| 515 |
+
"""
|
library/train_util.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|