Mike0021 committed on
Commit
166ab04
·
verified ·
1 Parent(s): 7874d4a

Implement Moebius Gradio Space

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +18 -0
  2. .gitignore +5 -0
  3. LICENSE +199 -0
  4. README.md +195 -8
  5. app.py +229 -0
  6. assets/logo_dynamic_woWaterMark.gif +3 -0
  7. assets/pipeline.png +3 -0
  8. assets/qualitative_comparison.png +3 -0
  9. assets/sup_showcase_celebahq_ffhq.png +3 -0
  10. assets/sup_showcase_places_v2.png +3 -0
  11. assets/tab1.png +3 -0
  12. assets/tab1_woTitle.png +3 -0
  13. assets/tab2.png +3 -0
  14. assets/tab3.png +3 -0
  15. assets/tab4.png +3 -0
  16. config/data_demo.yaml +9 -0
  17. config/model_cfg/moebius.yaml +47 -0
  18. config/model_cfg/pixelhacker.yaml +24 -0
  19. config/rand_mask_cfg/random_medium_256.yaml +33 -0
  20. config/rand_mask_cfg/random_medium_512.yaml +32 -0
  21. config/rand_mask_cfg/random_thick_256.yaml +33 -0
  22. config/rand_mask_cfg/random_thick_512.yaml +33 -0
  23. config/rand_mask_cfg/random_thin_256.yaml +25 -0
  24. config/rand_mask_cfg/random_thin_512.yaml +25 -0
  25. config/train_demo.sh +45 -0
  26. data/images/0.png +3 -0
  27. data/images/1.png +3 -0
  28. data/images/10.png +3 -0
  29. data/images/100.png +3 -0
  30. data/images/10000.png +3 -0
  31. data/images/10001.png +3 -0
  32. data/images/10002.png +3 -0
  33. data/images/10003.png +3 -0
  34. data/masks/000000.png +0 -0
  35. data/masks/000001.png +0 -0
  36. data/masks/000002.png +0 -0
  37. data/masks/000003.png +0 -0
  38. data/masks/000004.png +0 -0
  39. data/masks/000005.png +0 -0
  40. data/masks/000006.png +0 -0
  41. data/masks/000007.png +0 -0
  42. data/train_data.jsonl +8 -0
  43. infer/__init__.py +0 -0
  44. infer/infer_moebius.py +45 -0
  45. infer/utils.py +123 -0
  46. infer/utils_dataset.py +211 -0
  47. library/__init__.py +0 -0
  48. library/chinese_sdxl_train_util.py +350 -0
  49. library/custom_train_functions.py +515 -0
  50. library/train_util.py +0 -0
.gitattributes CHANGED
@@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/logo_dynamic_woWaterMark.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/pipeline.png filter=lfs diff=lfs merge=lfs -text
38
+ assets/qualitative_comparison.png filter=lfs diff=lfs merge=lfs -text
39
+ assets/sup_showcase_celebahq_ffhq.png filter=lfs diff=lfs merge=lfs -text
40
+ assets/sup_showcase_places_v2.png filter=lfs diff=lfs merge=lfs -text
41
+ assets/tab1.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/tab1_woTitle.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/tab2.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/tab3.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/tab4.png filter=lfs diff=lfs merge=lfs -text
46
+ data/images/0.png filter=lfs diff=lfs merge=lfs -text
47
+ data/images/1.png filter=lfs diff=lfs merge=lfs -text
48
+ data/images/10.png filter=lfs diff=lfs merge=lfs -text
49
+ data/images/100.png filter=lfs diff=lfs merge=lfs -text
50
+ data/images/10000.png filter=lfs diff=lfs merge=lfs -text
51
+ data/images/10001.png filter=lfs diff=lfs merge=lfs -text
52
+ data/images/10002.png filter=lfs diff=lfs merge=lfs -text
53
+ data/images/10003.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .venv/
4
+ outputs/
5
+ weight/
LICENSE ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to the Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by the Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding any notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
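+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.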
186
+
187
+ Copyright 2024 Moebius Authors
188
+
189
+ Licensed under the Apache License, Version 2.0 (the "License");
190
+ you may not use this file except in compliance with the License.
191
+ You may obtain a copy of the License at
192
+
193
+ http://www.apache.org/licenses/LICENSE-2.0
194
+
195
+ Unless required by applicable law or agreed to in writing, software
196
+ distributed under the License is distributed on an "AS IS" BASIS,
197
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
198
+ See the License for the specific language governing permissions and
199
+ limitations under the License.
README.md CHANGED
@@ -1,13 +1,200 @@
1
  ---
2
- title: Moebius
3
- emoji: 📊
4
- colorFrom: green
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 6.19.0
8
- python_version: '3.13'
9
  app_file: app.py
10
- pinned: false
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Moebius Inpainting
3
+ emoji: 🖌️
4
+ colorFrom: blue
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 6.10.0
 
8
  app_file: app.py
9
+ short_description: Lightweight Moebius image inpainting
10
+ startup_duration_timeout: 1h
11
+ python_version: 3.11
12
  ---
13
 
14
+ <div align="center">
15
+ <img src="./assets/logo_dynamic_woWaterMark.gif" width="100%"></img>
16
+ </div>
17
+
18
+ <div align="center">
19
+ <h2>Moebius: 0.2B Lightweight Image Inpainting Framework with 10B-Level Performance</h2>
20
+
21
+ ***On par with or surpassing the 10B-level industrial SOTA generalist (FLUX.1-Fill-Dev) on 6 benchmarks across natural and portrait scenes, with only 2% of the parameters (0.2B) and 15× faster inference***
22
+
23
+ [Kangsheng Duan](https://github.com/AnduinD)<sup>1,</sup>\*, [Ziyang Xu](https://ziyangxu.top)<sup>1,</sup>\*<sup>,&dagger;</sup>, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu)<sup>1</sup>, Xiaohu Ruan<sup>2</sup>, [Xiaoxin Chen](https://scholar.google.com/citations?hl=zh-CN&user=SI_oBwsAAAAJ)<sup>2</sup>, [Xinggang Wang](https://xwcv.github.io)<sup>1, :email:</sup>
24
+
25
+ (*) Equal Contribution, (<sup>&dagger;</sup>) Project Leader, (<sup>:email:</sup>) Corresponding Author.
26
+
27
+ <sup>1</sup> Huazhong University of Science and Technology. <sup>2</sup> VIVO AI Lab.
28
+
29
+ [![arxiv](https://img.shields.io/badge/ECCV'26-paper-orange)](https://arxiv.org/abs/2606.19195) [![license](https://img.shields.io/badge/License-Apache_2.0-blue)](LICENSE) [![Project](https://img.shields.io/badge/Project_Page-https://hustvl.github.io/Moebius-purple)](https://hustvl.github.io/Moebius) [![HF Daily Rank](https://img.shields.io/badge/Hugging%20Face-No.%201%20Daily%20Ranking-ffbd00)](https://huggingface.co/papers/date/2026-06-19)
30
+
31
+ <br>
32
+
33
+ <img src="./assets/pipeline.png" style="margin-bottom: 10px;"></img>
34
+
35
+ <img src="./assets/tab1_woTitle.png"></img>
36
+
37
+ <img src="./assets/qualitative_comparison.png"></img>
38
+
39
+ </div>
40
+
41
+
42
+ ## 🐱‍🏍 Insight & Small Talk
43
+
44
+ > ***Moebius*** *is our latest AI Image Inpainting endeavor, serving as a direct continuation of our previous work, **[PixelHacker](https://github.com/hustvl/PixelHacker)**. Named after the concepts of "infinity" and "master painter," Moebius embodies our vision: maintaining exceptional generation quality under highly constrained computational resources while pushing the efficiency of image inpainting to its limits as much as possible.*
45
+ >
46
+ > *Under the iron grip of the Scaling Law, AI research has long devolved into a grueling arms race of burning capital, compute, and data. Consequently, the academic community finds it increasingly difficult to keep pace with the ever-expanding model scales driven by the tech industry.*
47
+ >
48
+ > <p align="center"><b>"<ins>But is this brute-force scaling truly the only path forward?</ins>"</b></p>
49
+ >
50
+ > *Using general-purpose image inpainting as our strategic entry point, we challenge the "scale-at-all-costs" path dependency dictated by the Scaling Law narrative. Through the synergistic optimization of architectural design and knowledge distillation, Moebius achieves a remarkably compact footprint of just **0.22B parameters**. It liberates high-quality image inpainting from the heavy-compute narrative of 10B+ foundation models:*
51
+ > *Across six comprehensive benchmarks spanning both natural and portrait scenes, Moebius performs **on par with**, and in certain scenarios **surpasses**, the inpainting quality of 10B+ industrial state-of-the-art (SOTA) generalist models like *FLUX.1-Fill-Dev*, while delivering a massive **>15× inference acceleration**.*
52
+ >
53
+ > 💡 **The core insight of Moebius can be summarized in a single equation:**
54
+ >
55
+ > $$\begin{aligned}
56
+ > \text{Synergy} \times (\text{Architecture} + \text{Distillation}) = & \text{Shattering the "Impossible Triangle" of} \\
57
+ > & \text{Low Parameters, Fast Inference, and High Quality}
58
+ > \end{aligned}$$
59
+ >
60
+ > --- *written on June 16, 2026* ---
61
+
62
+
63
+
64
+ ## 🌟 Highlights
65
+
66
+ - **📉 Extreme Parametric Efficiency (< 2%)**: Moebius operates with a mere **0.22B (226M) parameters**, which represents **less than 2%** of the size of the colossal industrial giant *FLUX.1-Fill-Dev (11.9B)*. It shatters the heavy-compute narrative, making high-quality inpainting accessible on consumer-grade and edge devices.
67
+ - **⚡ 15× Inference Speedup (26ms/step)**: Achieves a blistering inference latency of only **26.01 ms per step** on a single GPU. Combined with optimized sampling steps, Moebius delivers an overall **>15× total runtime acceleration** compared to 10B-level models.
68
+ - **🏆 10B-Level Inpainting Quality (on-par-with/surpass FLUX.1-Fill-Dev across 6 benchmarks)**: Size contraction does not mean representation degradation. Through the synergistic optimization of architecture and distillation, Moebius performs on par with, and in certain scenarios (such as complex textures and facial plausibility), surpasses 10B-level state-of-the-art (SOTA) generalist models (*FLUX.1-Fill-Dev, SD3.5 Large-Inpainting*) across 6 comprehensive benchmarks spanning **both natural** scenes (*Places2*) and **portrait** scenes (*CelebA-HQ*, *FFHQ*).
69
+ - **💡 Synergistic Core Innovations**:
70
+ - **Architecture Design (LλMI Block)**: Reformulates both self- and cross-attention by condensing spatial context and global semantic priors into fixed-size linear matrices, bypassing quadratic computational overhead.
71
+ - **Adaptive Multi-Granularity Distillation Strategy**: Transfers the representational capacity from our *[PixelHacker](https://github.com/hustvl/PixelHacker)* (teacher) strictly within the latent space (avoiding expensive pixel-space decoding). It bridges the giant capacity gap by aligning multi-granularity supervision—ranging from microscopic intermediate features to macroscopic diffusion trajectories—while dynamically balancing training via a gradient norm adaptive loss weighting mechanism.
72
+ - **Optimal Synergistic Balancing**: Systematically explores the mutual constraint and upper bound between compact structure and distillation. By mapping this architecture-distillation synergy frontier, we ensure our 0.22B *Moebius* (student) absorbs the maximum semantic reasoning of *[PixelHacker](https://github.com/hustvl/PixelHacker)* (teacher) without triggering representation saturation.
73
+
74
+ <div align="center">
75
+ <img src="./assets/tab2.png" width="70%" style="margin-bottom: 10px;"></img>
76
+ </div>
77
+
78
+ - **🚀 Task-Specific Specialist over Bloated Generalists**: Rather than blindly scaling up, Moebius answers a fundamental question: *<ins>Can a model be smarter, lighter, and faster when the task is explicitly defined?</ins>* It serves as a highly optimized specialist that liberates real-world image inpainting and AI object removal from parameter bloat.
79
+
80
+ ## 🔥 News
81
+ * **`June 19, 2026`:** 🎉 Moebius has achieved the [No. 1 daily ranking](https://huggingface.co/papers/date/2026-06-19) on Hugging Face!
82
+
83
+ * **`June 18, 2026`:** 🔥🔥 We have released the training and inference code, and open-sourced the [model weights](https://huggingface.co/hustvl/Moebius) on Hugging Face.
84
+
85
+ * **`June 18, 2026`:** 🎉 Moebius has been accepted to ECCV'26! We have released the preprint on arXiv; check it out [here](https://arxiv.org/abs/2606.19195) ~ 🍻
86
+
87
+ * **`June 16, 2026`:** 🔥 We have submitted the GitHub repo for the first time, and there will be more updates soon. Stay tuned! 🤗
88
+
89
+ ## 🏕️ Performance on Natural Scene
90
+
91
+ <div align="center">
92
+
93
+
94
+ <img src="./assets/tab3.png"></img>
95
+
96
+ <img src="./assets/sup_showcase_places_v2.png"></img>
97
+
98
+ </div>
99
+
100
+ ## 🤗 Performance on Portrait Scene
101
+ <div align="center">
102
+
103
+
104
+ <img src="./assets/tab4.png" width="50%"></img>
105
+
106
+ <img src="./assets/sup_showcase_celebahq_ffhq.png"></img>
107
+
108
+ </div>
109
+
110
+ ## ⚖️ Evaluation Resources
111
+
112
+ The masks of the evaluation set are shared in [Google Drive](https://drive.google.com/drive/folders/13J91fdQt2RnHp4j-VzdtSrHRHPA1OxJ5?usp=sharing), and the corresponding images can be downloaded from the following open source platforms:
113
+ * Places2: [Places2](http://places2.csail.mit.edu/download-private.html)
114
+ * CelebA-HQ: [CelebA-HQ](https://openxlab.org.cn/datasets/OpenDataLab/CelebA-HQ)
115
+ * FFHQ: [FFHQ](https://drive.google.com/drive/folders/1tZUcXDBeOibC6jcMCtgRRz67pzrAHeHL?usp=drive_link)
116
+
117
+
118
+
119
+ ## 📦 Environment Setups
120
+
121
+ * torch=2.7.1
122
+ * diffusers=0.38.0
123
+ * transformers=4.56.2
124
+ * flash-linear-attention=0.3.2
125
+ * See `requirements.txt` for the full list of required Python libraries
126
+
127
+ ```bash
128
+ conda create -n moebius python=3.11
129
+ conda activate moebius
130
+ # cd /xx/xx/Moebius
131
+ pip install -r requirements.txt
132
+ ```
133
+
134
+ ## 🗃️ Model Checkpoints
135
+ * Download the checkpoint of [VAE](https://huggingface.co/hustvl/PixelHacker/tree/main/vae) and put it into ./weight/vae.
136
+
137
+ * Download the checkpoints of [pretrained version](https://huggingface.co/hustvl/Moebius/tree/main/pretrained), [fine-tuned version (places2)](https://huggingface.co/hustvl/Moebius/tree/main/ft_places2), [fine-tuned version (celeba-hq)](https://huggingface.co/hustvl/Moebius/tree/main/ft_celebahq), [fine-tuned version (ffhq)](https://huggingface.co/hustvl/Moebius/tree/main/ft_ffhq), and put them into ./weight/Moebius.
138
+
139
+ * Finally, the weight directory should be organized as follows:
140
+ ```bash
141
+ ├── weight
142
+ | ├── Moebius
143
+ | ├── pretrained
144
+ | ├── diffusion_pytorch_model.bin
145
+ | ├── ft_places2
146
+ | ├── diffusion_pytorch_model.bin
147
+ | ├── ft_celebahq
148
+ | ├── diffusion_pytorch_model.bin
149
+ | ├── ft_ffhq
150
+ | ├── diffusion_pytorch_model.bin
151
+ | ├── vae
152
+ | ├── config.json
153
+ | ├── diffusion_pytorch_model.bin
154
+ ├── ...
155
+ ```
156
+
157
+ <!-- * teacher model and vae: [hustvl/PixelHacker](https://huggingface.co/hustvl/PixelHacker)
158
+ * student model: [hustvl/Moebius](https://huggingface.co/hustvl/Moebius) -->
159
+
160
+ ## 🚂 Training
161
+ You can run the following code to start training. The training script supports distributed training, and you can configure the GPU count via environment variables.
162
+ ```bash
163
+ # For single GPU training:
164
+ PY_TRAINER=train_distillation.py bash run/run_ddp_1node.sh config/train_demo.sh
165
+
166
+ # For multi GPU training:
167
+ NUM_GPUS_PER_MACHINE=4 bash run/run_ddp_1node.sh config/train_demo.sh
168
+ ```
169
+
170
+ ## 🔮 Inference
171
+ You can run the following command directly to inpaint the bundled example image-mask pairs; the results are written to ./outputs. To run inference on custom data, place each image and its mask (with the same file name) in two directories, e.g. ./dataset.local/imgs and ./dataset.local/masks, and pass those directories to the same command via `--real-dir` and `--mask-dir`.
172
+ ```bash
173
+ python -m infer.infer_moebius \
174
+ --model-config config/model_cfg/moebius.yaml \
175
+ --model-weight weight/Moebius/ft_celebahq/diffusion_pytorch_model.bin \
176
+ --real-dir data/images \
177
+ --mask-dir data/masks \
178
+ --save-dir ./outputs \
179
+ --cfg 2.0 \
180
+ --batch-size 8 \
181
+ --num-workers 8
182
+ ```
183
+
184
+
185
+ ## 🎓 Citation
186
+
187
+ ```bibtex
188
+ @misc{DuanAndXu2026Moebius,
189
+ title={Moebius: 0.2B Lightweight Image Inpainting Framework with 10B-Level Performance},
190
+ author={Kangsheng Duan and Ziyang Xu and Wenyu Liu and Xiaohu Ruan and Xiaoxin Chen and Xinggang Wang},
191
+ year={2026},
192
+ eprint={2606.19195},
193
+ archivePrefix={arXiv},
194
+ primaryClass={cs.CV},
195
+ url={https://arxiv.org/abs/2606.19195},
196
+ }
197
+ ```
198
+
199
+ ## 🧑‍🤝‍🧑 Acknowledgement
200
+ We sincerely thank the authors of the following open-source repositories for their contributions to the community, which have greatly facilitated our research and development of Moebius: [Sana](https://github.com/NVlabs/Sana), [flash-linear-attention](https://github.com/fla-org/flash-linear-attention), [lambda-networks](https://github.com/lucidrains/lambda-networks), [timm](https://github.com/huggingface/pytorch-image-models), [Muon](https://github.com/KellerJordan/Muon), [diffusers](https://github.com/huggingface/diffusers).
app.py ADDED
@@ -0,0 +1,229 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
# Choose the Hugging Face cache root: use /data when that directory exists
# and is writable, otherwise fall back to a location under /tmp.
if os.path.isdir("/data") and os.access("/data", os.W_OK):
    _hf_cache = "/data/.cache/huggingface"
else:
    _hf_cache = "/tmp/hf_home"

# setdefault leaves any value already present in the environment untouched.
# NOTE(review): these are set ahead of the third-party imports, presumably
# because those libraries read the variables at import time - confirm.
os.environ.setdefault("HF_HOME", _hf_cache)
os.environ.setdefault("HF_MODULES_CACHE", "/tmp/hf_modules")
os.environ.setdefault("MPLCONFIGDIR", "/tmp/matplotlib")
os.environ.setdefault("GRADIO_SSR_MODE", "false")
8
+
9
+ import time
10
+ from pathlib import Path
11
+ from typing import Dict, Tuple
12
+
13
+ import spaces
14
+ import gradio as gr
15
+ import torch
16
+ from diffusers import DDIMScheduler
17
+ from diffusers.models import AutoencoderKL
18
+ from huggingface_hub import hf_hub_download, snapshot_download
19
+ from PIL import Image
20
+
21
+ from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
22
+ from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode
23
+
24
+
25
# Directory that contains this file; repo-relative paths are resolved from it.
ROOT = Path(__file__).resolve().parent
CONFIG_PATH = ROOT.joinpath("config", "model_cfg", "moebius.yaml")

# Hugging Face Hub repositories that the weights are downloaded from.
MOEBIUS_REPO = "hustvl/Moebius"
PIXELHACKER_REPO = "hustvl/PixelHacker"

# Checkpoint key used when a UI label is not found in MODEL_CHOICES.
DEFAULT_MODEL_KEY = "ft_places2"

# UI label -> checkpoint sub-folder name inside MOEBIUS_REPO.
MODEL_CHOICES = {
    "General scenes (Places2)": "ft_places2",
    "Portraits (CelebA-HQ)": "ft_celebahq",
    "Faces (FFHQ)": "ft_ffhq",
    "Pretrained": "pretrained",
}

# Pipelines that have already been built, keyed by checkpoint key.
_PIPELINE_CACHE: Dict[str, RemovalSDXLPipeline_BatchMode] = {}
39
+
40
+
41
def _download_vae_dir() -> str:
    """Download the ``vae/`` files of ``PIXELHACKER_REPO`` and return their folder.

    Only files matching ``vae/*`` are requested from the Hub snapshot.

    Returns:
        The path of the ``vae`` sub-directory inside the downloaded snapshot.
    """
    snapshot_root = Path(
        snapshot_download(repo_id=PIXELHACKER_REPO, allow_patterns=["vae/*"])
    )
    return str(snapshot_root / "vae")
47
+
48
+
49
def _download_model_weight(model_key: str) -> str:
    """Download the checkpoint file for ``model_key`` from ``MOEBIUS_REPO``.

    Args:
        model_key: Sub-folder of the repo that holds the checkpoint
            (one of the values in ``MODEL_CHOICES``).

    Returns:
        Whatever ``hf_hub_download`` returns for
        ``<model_key>/diffusion_pytorch_model.bin``.
    """
    remote_name = f"{model_key}/diffusion_pytorch_model.bin"
    return hf_hub_download(repo_id=MOEBIUS_REPO, filename=remote_name)
54
+
55
+
56
def _build_cpu_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Assemble a float32 inpainting pipeline on the CPU for ``model_key``.

    Loads the model config from ``CONFIG_PATH``, points its VAE entry at the
    downloaded VAE folder, builds the removal model, loads the requested
    checkpoint into it on the CPU, and wires model, VAE and a DDIM scheduler
    into a ``RemovalSDXLPipeline_BatchMode``.

    Args:
        model_key: Checkpoint sub-folder name inside ``MOEBIUS_REPO``.

    Returns:
        A pipeline constructed with ``device="cpu"`` and ``torch.float32``.
    """
    cfg = load_cfg(str(CONFIG_PATH))
    cfg["vae"]["model_dir"] = _download_vae_dir()

    # NOTE(review): the meaning of the literal 20 is not visible from this
    # file - confirm against build_removal_model.
    net = build_removal_model(cfg, 20)
    checkpoint_path = _download_model_weight(model_key)
    load_report = load_removal_model(net, checkpoint_path, device="cpu")
    print(load_report)

    autoencoder = AutoencoderKL.from_pretrained(cfg["vae"]["model_dir"])

    scheduler_kwargs = {
        "beta_start": 0.00085,
        "beta_end": 0.012,
        "beta_schedule": "scaled_linear",
        "num_train_timesteps": 1000,
        "clip_sample": False,
    }
    ddim = DDIMScheduler(**scheduler_kwargs)

    return RemovalSDXLPipeline_BatchMode(
        removal_model=net,
        vae=autoencoder,
        scheduler=ddim,
        device="cpu",
        dtype=torch.float32,
    )
80
+
81
+
82
def _get_pipeline(model_key: str) -> RemovalSDXLPipeline_BatchMode:
    """Return the pipeline for ``model_key``, building it on first request.

    Built pipelines are stored in ``_PIPELINE_CACHE`` so each checkpoint is
    constructed at most once per process.
    """
    if model_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[model_key]

    _PIPELINE_CACHE[model_key] = _build_cpu_pipeline(model_key)
    return _PIPELINE_CACHE[model_key]
86
+
87
+
88
def _set_pipeline_device(pipe: RemovalSDXLPipeline_BatchMode, device: str) -> None:
    """Move ``pipe`` to ``device`` in place and rebuild its ``input_ids`` there.

    The VAE and the removal model are moved (keeping ``pipe.dtype``) and put in
    eval mode. ``pipe.input_ids`` becomes a ``(2, num_embeddings // 2)`` int64
    tensor: row 0 holds the upper half of the id range, row 1 the lower half.
    """
    pipe.device = device
    for module in (pipe.vae, pipe.removal_model):
        module.to(device=device, dtype=pipe.dtype).eval()

    total_ids = pipe.removal_model.num_embeddings
    split = total_ids // 2
    lower_ids = torch.arange(split, dtype=torch.int64, device=device).unsqueeze(0)
    upper_ids = torch.arange(split, total_ids, dtype=torch.int64, device=device).unsqueeze(0)
    # Upper-half ids go first, matching the original [un_input_ids, input_ids] order.
    pipe.input_ids = torch.cat([upper_ids, lower_ids])
98
+
99
+
100
def _normalize_inputs(image: Image.Image, mask: Image.Image) -> Tuple[Image.Image, Image.Image]:
    """Validate the uploaded image/mask pair and bring both to a common format.

    Args:
        image: Uploaded picture, or ``None`` when nothing was uploaded.
        mask: Uploaded mask, or ``None`` when nothing was uploaded.

    Returns:
        ``(image, mask)`` where the image is RGB and the mask is single-channel
        ``"L"``, resized with nearest-neighbour sampling to the image size.

    Raises:
        gr.Error: If either input is missing, if the mask has no pixel of value
            8 or more (treated as empty), or if it has no pixel of value 247 or
            less (treated as covering everything).
    """
    for uploaded, message in ((image, "Upload an image."), (mask, "Upload a mask.")):
        if uploaded is None:
            raise gr.Error(message)

    rgb_image = image.convert("RGB")
    gray_mask = mask.convert("L").resize(rgb_image.size, Image.Resampling.NEAREST)

    darkest, brightest = gray_mask.getextrema()
    if brightest < 8:
        raise gr.Error("The mask is empty. Use white pixels for the area to inpaint.")
    if darkest > 247:
        raise gr.Error("The mask covers the whole image. Leave black pixels outside the edit area.")

    return rgb_image, gray_mask
116
+
117
+
118
def _model_key(label: str) -> str:
    """Translate a dropdown label into a checkpoint key.

    Labels that are not in ``MODEL_CHOICES`` map to ``DEFAULT_MODEL_KEY``.
    """
    if label in MODEL_CHOICES:
        return MODEL_CHOICES[label]
    return DEFAULT_MODEL_KEY
120
+
121
+
122
def _estimate_duration(image, mask, model_name, steps, guidance_scale, paste, compensate, seed, *args, **kwargs):
    """Compute the ``duration`` value handed to ``spaces.GPU`` for one request.

    The signature mirrors ``run_inpaint``; only ``steps`` influences the result.

    Returns:
        ``90 + 5 * int(steps)``, capped at 240.
    """
    base, per_step, cap = 90, 5, 240
    estimate = base + per_step * int(steps)
    return cap if estimate > cap else estimate
125
+
126
+
127
+ _get_pipeline(DEFAULT_MODEL_KEY)
128
+
129
+
130
@spaces.GPU(duration=1)
def _zerogpu_probe():
    """Return the constant string ``"ready"`` from a ``spaces.GPU(duration=1)`` function.

    NOTE(review): nothing in the visible code calls this. Presumably it exists so
    the Space registers a trivial GPU-decorated function at import time -- confirm.
    """
    return "ready"
133
+
134
+
135
@spaces.GPU(duration=_estimate_duration)
def run_inpaint(image, mask, model_name, steps, guidance_scale, paste, compensate, seed):
    """Inpaint one image on the GPU and report how long it took.

    Args:
        image: Uploaded PIL image (validated by ``_normalize_inputs``).
        mask: Uploaded PIL mask; white marks the area to inpaint.
        model_name: Dropdown label, resolved through ``_model_key``.
        steps: Number of sampling steps (cast to ``int``).
        guidance_scale: CFG scale (cast to ``float``).
        paste: Forwarded to the pipeline as ``paste``.
        compensate: Forwarded to the pipeline as ``compensate``.
        seed: Integer seed or ``None`` (treated as 0); forwarded to the pipeline
            as its ``retry`` argument.

    Returns:
        ``(result_image, status_text)``.

    Raises:
        gr.Error: Propagated from ``_normalize_inputs`` on invalid uploads.
    """
    image, mask = _normalize_inputs(image, mask)
    model_key = _model_key(model_name)
    seed_value = int(seed) if seed is not None else 0
    pipe = _get_pipeline(model_key)

    t_start = time.perf_counter()
    try:
        _set_pipeline_device(pipe, "cuda")
        sampling_kwargs = dict(
            image_size=512,
            num_steps=int(steps),
            guidance_scale=float(guidance_scale),
            paste=bool(paste),
            compensate=bool(compensate),
            noise_offset=0.0357,
            retry=seed_value,
            mute=True,
        )
        with torch.inference_mode():
            outputs = pipe([image], [mask], **sampling_kwargs)
        elapsed = time.perf_counter() - t_start
        return outputs[0], f"Completed in {elapsed:.1f}s"
    finally:
        # Always hand the weights back to the CPU, even when inference failed,
        # so the cached pipeline never stays bound to the GPU.
        _set_pipeline_device(pipe, "cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
164
+
165
+
166
# Gradio UI: inputs on the left, result on the right, sampling controls below.
with gr.Blocks(title="Moebius Inpainting", fill_width=True) as demo:
    gr.Markdown("# Moebius Inpainting")
    with gr.Row():
        # Left column: the picture and its mask, both delivered to the callback as PIL images.
        with gr.Column(scale=1, min_width=320):
            input_image = gr.Image(
                label="Image",
                type="pil",
                image_mode="RGB",
                sources=["upload", "clipboard"],
                height=360,
            )
            input_mask = gr.Image(
                label="Mask",
                type="pil",
                image_mode="L",
                sources=["upload", "clipboard"],
                height=360,
            )
        # Right column: inpainted result plus the status line returned by run_inpaint.
        with gr.Column(scale=1, min_width=320):
            output_image = gr.Image(label="Result", type="pil", height=520)
            status = gr.Markdown()

    with gr.Row():
        # Dropdown labels are the keys of MODEL_CHOICES; run_inpaint maps them to checkpoint keys.
        model_name = gr.Dropdown(
            label="Checkpoint",
            choices=list(MODEL_CHOICES.keys()),
            value="General scenes (Places2)",
            min_width=240,
        )
        steps = gr.Slider(4, 30, value=20, step=1, label="Steps", min_width=180)
        guidance_scale = gr.Slider(1.0, 6.0, value=2.0, step=0.1, label="CFG", min_width=180)
        seed = gr.Number(value=0, precision=0, label="Seed", min_width=140)

    with gr.Row():
        paste = gr.Checkbox(value=True, label="Paste")
        compensate = gr.Checkbox(value=False, label="Compensate")
        run_button = gr.Button("Inpaint", variant="primary")

    # The input order here must match the positional parameters of run_inpaint.
    run_button.click(
        fn=run_inpaint,
        inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
        outputs=[output_image, status],
        api_name="inpaint",
        concurrency_limit=1,
    )

    # Bundled examples; each row follows the same column order as `inputs`.
    gr.Examples(
        examples=[
            ["data/images/0.png", "data/masks/000000.png", "General scenes (Places2)", 20, 2.0, True, False, 0],
            ["data/images/1.png", "data/masks/000001.png", "General scenes (Places2)", 20, 2.0, True, False, 1],
        ],
        inputs=[input_image, input_mask, model_name, steps, guidance_scale, paste, compensate, seed],
        outputs=[output_image, status],
        fn=run_inpaint,
        cache_examples=True,
        cache_mode="lazy",
    )
223
+
224
+
225
+ demo.queue(max_size=8, default_concurrency_limit=1)
226
+
227
+
228
+ if __name__ == "__main__":
229
+ demo.launch()
assets/logo_dynamic_woWaterMark.gif ADDED

Git LFS Details

  • SHA256: 50c044d0b741ab057dd3debc74aca6bcf68752e5dc9f6c238e28ff44cd828dd0
  • Pointer size: 133 Bytes
  • Size of remote file: 22.4 MB
assets/pipeline.png ADDED

Git LFS Details

  • SHA256: 062dadb63eac08f76e02c53d5f2fda6bed9281e405b890f8fc25890510af0e23
  • Pointer size: 131 Bytes
  • Size of remote file: 624 kB
assets/qualitative_comparison.png ADDED

Git LFS Details

  • SHA256: c89293e6c515a7625e0c90e0f2587c94acd88336fbc8cf2017bbe414e425b93d
  • Pointer size: 132 Bytes
  • Size of remote file: 5.07 MB
assets/sup_showcase_celebahq_ffhq.png ADDED

Git LFS Details

  • SHA256: 0c4c7a8775cedafd26308fb9c6d38606f63dc822e89d00aa14fc45bdf8b707c3
  • Pointer size: 132 Bytes
  • Size of remote file: 9.75 MB
assets/sup_showcase_places_v2.png ADDED

Git LFS Details

  • SHA256: 1672c61e662e87eb8cc37bd6a858959d2a544c3cc74c9cd088f390c88010f2d8
  • Pointer size: 133 Bytes
  • Size of remote file: 11.4 MB
assets/tab1.png ADDED

Git LFS Details

  • SHA256: 10965aef428ca1f1b942503a09a9c4b1959464c3ccde5dfad7a99d0b5170b8b5
  • Pointer size: 131 Bytes
  • Size of remote file: 276 kB
assets/tab1_woTitle.png ADDED

Git LFS Details

  • SHA256: 537927361e74307d0f19516757a856935e819544e22bc2b7fb0458e363c375e2
  • Pointer size: 131 Bytes
  • Size of remote file: 188 kB
assets/tab2.png ADDED

Git LFS Details

  • SHA256: 8f17dfbc00cccd7422546273df24004d33ee11c9e45674bea37770162118f409
  • Pointer size: 131 Bytes
  • Size of remote file: 805 kB
assets/tab3.png ADDED

Git LFS Details

  • SHA256: 62eac94e1636d9a97b6c7f816267548a9725b3304bd01a69a0c9ba163161f5d5
  • Pointer size: 131 Bytes
  • Size of remote file: 793 kB
assets/tab4.png ADDED

Git LFS Details

  • SHA256: 2421f7d32c75242a3bf24d0fa1c471fa8f51ecff63afaa2e247fcbe50fb604e8
  • Pointer size: 131 Bytes
  • Size of remote file: 511 kB
config/data_demo.yaml ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ path: data/train_data.jsonl
3
+
4
+ use_rand_mask: True
5
+ rand_mask_config: config/rand_mask_cfg/random_medium_512.yaml
6
+
7
+ use_extra_fg_mask: False
8
+
9
+ extra_ann_files_4_PureBackTrain_2_RandMask: null
config/model_cfg/moebius.yaml ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ image_size: 512
3
+
4
+
5
+ vae:
6
+ model_name: 'sdvae_f8d4'
7
+ model_dir: ./weight/vae
8
+ downsample_ratio: 8
9
+ embed_dim: 4
10
+
11
+
12
+ model:
13
+ model_type: UNet2DLambdaDWConvMixFFNConditionModel_prune_down_mid_up_block_8x8
14
+
15
+ in_channels: 9
16
+ out_channels: 4
17
+
18
+ attention_head_dim: 8
19
+
20
+ conv_in_kernel: 3
21
+ conv_out_kernel: 3
22
+ cross_attention_dim: 768
23
+
24
+ encoder_hid_dim: 3072
25
+ encoder_hid_dim_type: 'text_proj'
26
+
27
+ projection_class_embeddings_input_dim: 2560
28
+
29
+ use_lambda_cross_attn: True
30
+ use_local_self_attn: True
31
+
32
+ down_block_types:
33
+ - DWMixTFDownBlock2D
34
+ - DWMixTFDownBlock2D
35
+ - DWMixTFDownBlock2D
36
+ mid_block_type: null
37
+ up_block_types:
38
+ - DWMixTFUpBlock2D
39
+ - DWMixTFUpBlock2D
40
+ - DWMixTFUpBlock2D
41
+
42
+ block_out_channels:
43
+ - 320
44
+ - 640
45
+ - 1280
46
+
47
+ mix_mlp_ratio: 2.5
config/model_cfg/pixelhacker.yaml ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ data:
2
+ image_size: 512
3
+
4
+
5
+ vae:
6
+ model_name: 'sdvae_f8d4'
7
+ model_dir: ./weight/vae
8
+ downsample_ratio: 8
9
+ embed_dim: 4
10
+
11
+
12
+ model:
13
+ model_type: UNet2DGLAConditionModel
14
+
15
+ in_channels: 9
16
+ out_channels: 4
17
+
18
+ attention_head_dim: 8
19
+ cross_attention_dim: 768
20
+
21
+ encoder_hid_dim: 3072
22
+ encoder_hid_dim_type: 'text_proj'
23
+
24
+ projection_class_embeddings_input_dim: 2560
config/rand_mask_cfg/random_medium_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 5
8
+ max_width: 50
9
+ max_angle: 4
10
+ max_len: 100
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 10
16
+ bbox_max_size: 50
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ # variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
config/rand_mask_cfg/random_medium_512.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 10
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 0
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 5
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+
24
+ max_masks_per_image: 1
25
+
26
+ cropping:
27
+ out_min_size: 512
28
+ handle_small_mode: upscale
29
+ out_square_crop: True
30
+ crop_min_overlap: 1
31
+
32
+ max_tamper_area: 0.5
config/rand_mask_cfg/random_thick_256.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 100
9
+ max_angle: 4
10
+ max_len: 200
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 150
17
+ max_times: 3
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ # variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 256
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
config/rand_mask_cfg/random_thick_512.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 1
7
+ max_times: 5
8
+ max_width: 250
9
+ max_angle: 4
10
+ max_len: 450
11
+
12
+ box_proba: 0.3
13
+ box_kwargs:
14
+ margin: 10
15
+ bbox_min_size: 30
16
+ bbox_max_size: 300
17
+ max_times: 4
18
+ min_times: 1
19
+
20
+ segm_proba: 0
21
+ squares_proba: 0
22
+
23
+ # variants_n: 5
24
+
25
+ max_masks_per_image: 1
26
+
27
+ cropping:
28
+ out_min_size: 512
29
+ handle_small_mode: upscale
30
+ out_square_crop: True
31
+ crop_min_overlap: 1
32
+
33
+ max_tamper_area: 0.5
config/rand_mask_cfg/random_thin_256.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 50
8
+ max_width: 10
9
+ max_angle: 4
10
+ max_len: 40
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 256
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
config/rand_mask_cfg/random_thin_512.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ generator_kind: random
2
+
3
+ mask_generator_kwargs:
4
+ irregular_proba: 1
5
+ irregular_kwargs:
6
+ min_times: 4
7
+ max_times: 70
8
+ max_width: 20
9
+ max_angle: 4
10
+ max_len: 100
11
+ box_proba: 0
12
+ segm_proba: 0
13
+ squares_proba: 0
14
+
15
+ variants_n: 5
16
+
17
+ max_masks_per_image: 1
18
+
19
+ cropping:
20
+ out_min_size: 512
21
+ handle_small_mode: upscale
22
+ out_square_crop: True
23
+ crop_min_overlap: 1
24
+
25
+ max_tamper_area: 0.5
config/train_demo.sh ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Set WORK_DIR to your project root before running. This script also reads CONFIG_FILE, EXP_NAME and HF_HOME from the environment.
2
+ THIS_SH_PATH=$CONFIG_FILE
3
+
4
+
5
+ OUTPUT_DIR='exp_outputs'
6
+ OUTPUT_DIR_EXP_NAME="${OUTPUT_DIR}/${EXP_NAME}"
7
+
8
+
9
+ export OUTPUT_DIR=$OUTPUT_DIR
10
+ export OUTPUT_DIR_EXP_NAME=$OUTPUT_DIR_EXP_NAME
11
+ export HF_HOME=$HF_HOME
12
+
13
+ export TRAIN_ARGS=" --data_type RemovalDataset_v1_2 \
14
+ --lognorm_t \
15
+ --elatentlpips_loss --elatentlpips_loss_weight 0.5 \
16
+ --task_loss --task_loss_weight 0.5 \
17
+ --KD_loss_weight 0.01 \
18
+ --mse_feat_loss --feat_loss_weight 1.0 --feat_index_T 5 --feat_index_S 2 \
19
+ --model_config_path=config/model_cfg/moebius.yaml \
20
+ --teacher_weight_path=../../hf_models/hustvl/PixelHacker/pretrained/diffusion_pytorch_model.bin \
21
+ --teacher_config_path=config/model_cfg/pixelhacker.yaml \
22
+ --data_config=config/data_demo.yaml \
23
+ --num_embeddings 20 \
24
+ --image_size 512 \
25
+ --batch_size 2 \
26
+ --num_workers 4 \
27
+ --output_dir=${OUTPUT_DIR_EXP_NAME} \
28
+ --output_name=exp \
29
+ --seed=42 \
30
+ --learning_rate=1e-4 \
31
+ --global_step=0 \
32
+ --max_train_steps=200000 \
33
+ --save_every_n_steps=3000 \
34
+ --logging_dir=${OUTPUT_DIR_EXP_NAME}/log \
35
+ --gradient_accumulation_steps=1 \
36
+ --optimizer_type=Muon \
37
+ --lr_scheduler=constant_with_warmup \
38
+ --lr_warmup_steps=0 \
39
+ --save_precision=bf16 \
40
+ --mixed_precision=bf16 \
41
+ --noise_offset=0.0357 \
42
+ --gradient_checkpointing \
43
+ --xformers \
44
+ --log_with=tensorboard \
45
+ --script_args=$THIS_SH_PATH "
data/images/0.png ADDED

Git LFS Details

  • SHA256: a6a27b1be3be48d8d89882dbd66927f5eba5be4a49b471a16853b663dde7a3b4
  • Pointer size: 131 Bytes
  • Size of remote file: 400 kB
data/images/1.png ADDED

Git LFS Details

  • SHA256: 4750545d31413533e45a29736235654ac9b2c0f1dc1956080406861e37dc74e8
  • Pointer size: 131 Bytes
  • Size of remote file: 417 kB
data/images/10.png ADDED

Git LFS Details

  • SHA256: bbc5088521255df0f23a55ff2e6941cf9e172f4bfd518701c2b2594e81680e20
  • Pointer size: 131 Bytes
  • Size of remote file: 317 kB
data/images/100.png ADDED

Git LFS Details

  • SHA256: c112d470d70193d2ab1ea0e8746db8d892b8576568dcb257c91e793e836c323f
  • Pointer size: 131 Bytes
  • Size of remote file: 377 kB
data/images/10000.png ADDED

Git LFS Details

  • SHA256: 120dcf025efb50a6d220ddba07c435a5d12abc3ce1f00b4b4cc868249d1781c0
  • Pointer size: 131 Bytes
  • Size of remote file: 362 kB
data/images/10001.png ADDED

Git LFS Details

  • SHA256: 89749ea1a5442b738dba155f55d51a225edd6e2a405032554cff1034a81e542e
  • Pointer size: 131 Bytes
  • Size of remote file: 318 kB
data/images/10002.png ADDED

Git LFS Details

  • SHA256: a4afea8d4e239d74186308ac0dc1994e8c607caae5c17e63f63a62676401a5f9
  • Pointer size: 131 Bytes
  • Size of remote file: 278 kB
data/images/10003.png ADDED

Git LFS Details

  • SHA256: a133da270d2019d2b6bdf3741293768f0b1cc845395e6a81c31d0b0107539aad
  • Pointer size: 131 Bytes
  • Size of remote file: 225 kB
data/masks/000000.png ADDED
data/masks/000001.png ADDED
data/masks/000002.png ADDED
data/masks/000003.png ADDED
data/masks/000004.png ADDED
data/masks/000005.png ADDED
data/masks/000006.png ADDED
data/masks/000007.png ADDED
data/train_data.jsonl ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {"image": "data/images/0.png", "prompt": "background"}
2
+ {"image": "data/images/1.png", "prompt": "background"}
3
+ {"image": "data/images/10.png", "prompt": "background"}
4
+ {"image": "data/images/100.png", "prompt": "background"}
5
+ {"image": "data/images/10000.png", "prompt": "background"}
6
+ {"image": "data/images/10001.png", "prompt": "background"}
7
+ {"image": "data/images/10002.png", "prompt": "background"}
8
+ {"image": "data/images/10003.png", "prompt": "background"}
infer/__init__.py ADDED
File without changes
infer/infer_moebius.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import os
3
+ from typing import List
4
+
5
+
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import torch
9
+
10
+ from PIL import Image
11
+ from pathlib import Path
12
+
13
+ from .utils import get_batch_infer_args, build_pipeline, SAVER
14
+ from .utils_dataset import SimpleInferDataset, build_dataloader
15
+
16
+
17
+
18
+ def main():
19
+ args = get_batch_infer_args()
20
+
21
+ dataloader = build_dataloader(args, SimpleInferDataset)
22
+
23
+ pipe = build_pipeline(args)
24
+ pipe = partial(pipe,
25
+ guidance_scale=args.cfg,
26
+ paste=args.pst,
27
+ compensate=args.cps,
28
+ num_steps=args.num_step,
29
+ noise_offset=args.noise_offset
30
+ )
31
+
32
+ save_root = Path(args.save_dir)
33
+ save_root.mkdir(parents=True, exist_ok=True)
34
+
35
+ pbar_loader = tqdm(enumerate(dataloader),
36
+ total=dataloader.dataset.__len__()//args.batch_size+1)
37
+
38
+ for idx, (images, masks, inames) in pbar_loader:
39
+ image_inpaint_list = pipe(images, masks)
40
+ names = [iname+'.png' for iname in inames]
41
+ SAVER.save_images_mp(image_inpaint_list, names, save_root)
42
+
43
+
44
+ if __name__ == '__main__':
45
+ main()
infer/utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import partial
2
+ import os
3
+ from typing import List
4
+ from pathlib import Path
5
+ import math
6
+
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ import torch
10
+ from PIL import Image
11
+
12
def get_batch_infer_args(parser=None):
    """Register the batch-inference options on a parser and parse ``sys.argv``.

    Args:
        parser: Optional existing ``argparse.ArgumentParser`` to extend. A new
            one is created when omitted.

    Returns:
        The ``argparse.Namespace`` produced by ``parser.parse_args()``.
    """
    # Imported unconditionally: ``str2bool`` below needs ``argparse`` even when
    # the caller supplies the parser. Importing it only in the ``parser is None``
    # branch left the name unbound in that case, so an invalid boolean raised
    # NameError instead of a normal argparse usage error.
    import argparse

    if parser is None:
        parser = argparse.ArgumentParser()

    def str2bool(v):
        """Parse yes/no style command-line values into a bool."""
        if isinstance(v, bool):
            return v
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')

    # model argument
    parser.add_argument("--model-config", type=str, required=False, default=None)
    parser.add_argument("--model-weight", type=str, required=False, default=None)

    # sampling argument
    parser.add_argument("--num-step", type=int, required=False, default=20)
    parser.add_argument("--cfg", type=float, required=False, default=2.5)
    parser.add_argument("--pst", type=str2bool, required=False, default=True)
    parser.add_argument("--cps", type=str2bool, required=False, default=False)
    parser.add_argument("--noise-offset", type=float, required=False, default=0.0357)
    parser.add_argument("--seed", type=int, default=0, required=False)

    # data argument
    parser.add_argument("--real-dir", type=Path, required=True)
    parser.add_argument("--mask-dir", type=Path, required=False)
    parser.add_argument("--resolution", type=int, default=512, required=False)

    # runtime argument
    parser.add_argument("--device", type=str, required=False, default="cuda")
    parser.add_argument("--batch-size", type=int, required=False, default=32)
    parser.add_argument("--num-workers", type=int, required=False, default=64)

    # save argument
    parser.add_argument("--save-dir", type=str, required=True)
    parser.add_argument("--visualize-latent", action="store_true", default=False)

    return parser.parse_args()
57
+
58
def build_pipeline(args):
    """Create the removal pipeline described by the parsed CLI arguments.

    Args:
        args: Namespace providing ``model_config``, ``model_weight`` and
            ``device``.

    Returns:
        A ``RemovalSDXLPipeline_BatchMode`` whose removal model and VAE were
        moved to ``args.device``; the pipeline is created with
        ``dtype=torch.float``.
    """
    from diffusers import DDIMScheduler
    from removal.v1_2.pipeline import RemovalSDXLPipeline_BatchMode as Removal_Pipeline
    from removal.v1_2 import build_removal_model, load_cfg, load_removal_model
    from utils_train import build_vae

    device = args.device
    cfg = load_cfg(args.model_config)

    # 20 is the embedding count handed to build_removal_model.
    model = build_removal_model(cfg, 20).to(device)
    load_report = load_removal_model(model, args.model_weight, device)
    print(load_report)

    autoencoder = build_vae(cfg).to(device)
    ddim = DDIMScheduler(
        beta_start=0.00085,
        beta_end=0.012,
        beta_schedule="scaled_linear",
        num_train_timesteps=1000,
        clip_sample=False,
    )

    return Removal_Pipeline(
        removal_model=model,
        vae=autoencoder,
        scheduler=ddim,
        device=device,
        dtype=torch.float,
    )
83
+
84
class SAVER:
    """Static helpers for writing batches of PIL images to a directory."""

    @staticmethod
    def save_image(img, name, path):
        """Save ``img`` as ``path / name`` and return ``name``.

        ``path`` may be a ``str`` or a ``Path``; it is wrapped in ``Path`` here
        because the batch helpers annotate ``save_root`` as ``str``, which
        previously failed on the ``/`` operator.
        """
        img.save(Path(path) / name)
        return name

    @staticmethod
    def save_images(images:List[Image.Image], names:List[str], save_root:str):
        """Save images one by one, skipping names already present in ``save_root``."""
        assert len(images) == len(names), \
            f"images and names are not equal: {len(images)}!={len(names)}"

        pbar_save = tqdm(zip(images, names), total=len(names))

        # A set gives O(1) membership tests; the directory is listed once up front.
        cache_names = set(os.listdir(save_root))
        for image, name in pbar_save:
            if name not in cache_names:
                SAVER.save_image(image, name, save_root)

    @staticmethod
    def _save_images_with_pool(executor_cls, images, names, save_root, num_workers):
        """Submit one ``save_image`` job per (image, name) pair and wait for all."""
        with executor_cls(max_workers=num_workers) as executor:
            futures = [
                executor.submit(SAVER.save_image, image, name, save_root)
                for image, name in zip(images, names)]

            for future in tqdm(futures):
                # result() re-raises any exception from the worker.
                future.result()

    @staticmethod
    def save_images_mt(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
        """Save images concurrently with a thread pool (existing files are overwritten)."""
        from concurrent.futures import ThreadPoolExecutor
        SAVER._save_images_with_pool(ThreadPoolExecutor, images, names, save_root, num_workers)

    @staticmethod
    def save_images_mp(images:List[Image.Image], names:List[str], save_root:str, num_workers=8):
        """Save images concurrently with a process pool (existing files are overwritten)."""
        from concurrent.futures import ProcessPoolExecutor
        SAVER._save_images_with_pool(ProcessPoolExecutor, images, names, save_root, num_workers)
121
+
122
+
123
+
infer/utils_dataset.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---------------- Dataset Utils -----------------------
2
+
3
+ import warnings
4
+ from pathlib import Path
5
+ from typing import Tuple, Optional
6
+ import math
7
+ import os
8
+
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from PIL import Image, ImageDraw
14
+ from torch.utils.data import Dataset, DataLoader
15
+
16
+ warnings.filterwarnings("ignore")
17
+
18
def RandomBrush(
    max_tries,
    s,
    min_num_vertex=4,
    max_num_vertex=18,
    mean_angle=2*math.pi / 5,
    angle_range=2*math.pi / 15,
    min_width=12,
    max_width=48
):
    """Draw random free-form brush strokes on an ``s`` x ``s`` canvas.

    Args:
        max_tries: Exclusive upper bound for the random number of strokes.
        s: Side length of the square mask.
        min_num_vertex, max_num_vertex: Range for the vertex count of a stroke.
        mean_angle, angle_range: Control the turning angle between segments.
        min_width, max_width: Range for the stroke width.

    Returns:
        A ``uint8`` array of shape ``(s, s)`` holding 1 on strokes and 0
        elsewhere, randomly flipped along each axis.
    """
    H, W = s, s
    average_radius = math.sqrt(H*H+W*W) / 8
    mask = Image.new('L', (W, H), 0)
    for _ in range(np.random.randint(max_tries)):
        num_vertex = np.random.randint(min_num_vertex, max_num_vertex)
        angle_min = mean_angle - np.random.uniform(0, angle_range)
        angle_max = mean_angle + np.random.uniform(0, angle_range)
        angles = []
        vertex = []
        for i in range(num_vertex):
            # Alternate the turning direction so the stroke zig-zags.
            if i % 2 == 0:
                angles.append(2*math.pi - np.random.uniform(angle_min, angle_max))
            else:
                angles.append(np.random.uniform(angle_min, angle_max))

        # PIL reports size as (width, height); the old code unpacked it the
        # other way round, which only went unnoticed because the canvas is square.
        w, h = mask.size
        vertex.append((int(np.random.randint(0, w)), int(np.random.randint(0, h))))
        for i in range(num_vertex):
            r = np.clip(
                np.random.normal(loc=average_radius, scale=average_radius//2),
                0, 2*average_radius)
            new_x = np.clip(vertex[-1][0] + r * math.cos(angles[i]), 0, w)
            new_y = np.clip(vertex[-1][1] + r * math.sin(angles[i]), 0, h)
            vertex.append((int(new_x), int(new_y)))

        draw = ImageDraw.Draw(mask)
        width = int(np.random.uniform(min_width, max_width))
        draw.line(vertex, fill=1, width=width)
        # Round off the joints so the polyline reads as one continuous stroke.
        for v in vertex:
            draw.ellipse((v[0] - width//2,
                          v[1] - width//2,
                          v[0] + width//2,
                          v[1] + width//2),
                         fill=1)
        # Image.transpose returns a new image; the result has to be assigned,
        # otherwise the flip is silently dropped.
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        if np.random.random() > 0.5:
            mask = mask.transpose(Image.FLIP_TOP_BOTTOM)
    mask = np.asarray(mask, np.uint8)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 0)
    if np.random.random() > 0.5:
        mask = np.flip(mask, 1)
    return mask
72
+
73
+
74
def RandomMask(s, hole_range=(0, 1)):
    """Sample a random ``s`` x ``s`` mask made of rectangles and brush strokes.

    Args:
        s: Side length of the square mask.
        hole_range: ``(low, high)`` bounds (exclusive) for the fraction of
            zeroed pixels. Masks outside the bounds are rejected and resampled.
            ``None`` disables the check.

    Returns:
        A ``uint8`` array of shape ``(s, s)`` with 255 where the mask is kept
        and 0 inside the holes.
    """
    # hole_range used to be indexed before the ``is not None`` check further
    # down, so passing None crashed here; treat None as "no constraint".
    if hole_range is None:
        coef = 1.0
    else:
        coef = min(hole_range[0] + hole_range[1], 1.0)

    def Fill(mask, max_size):
        """Zero one random rectangle, which may partly leave the canvas."""
        w, h = np.random.randint(max_size), np.random.randint(max_size)
        ww, hh = w // 2, h // 2
        x, y = np.random.randint(-ww, s - w + ww), np.random.randint(-hh, s - h + hh)
        mask[max(y, 0): min(y + h, s), max(x, 0): min(x + w, s)] = 0

    def MultiFill(mask, max_tries, max_size):
        """Zero a random number (below ``max_tries``) of rectangles."""
        for _ in range(np.random.randint(max_tries)):
            Fill(mask, max_size)

    while True:
        mask = np.ones((s, s), np.uint8)
        MultiFill(mask, int(10 * coef), s // 2)
        MultiFill(mask, int(5 * coef), s)
        mask = np.logical_and(mask, 1 - RandomBrush(int(20 * coef), s))
        hole_ratio = 1 - np.mean(mask)
        if hole_range is not None and (hole_ratio <= hole_range[0] or hole_ratio >= hole_range[1]):
            continue
        return (mask * 255).astype(np.uint8)
96
+
97
+
98
class InferDataset(Dataset): # ABC
    """Tensor dataset built from a folder of images and optional mask files.

    Each item is ``(x, img, mask, name)``:
      * ``x``: ``mask - 0.5`` concatenated with ``img * mask`` along dim 0,
      * ``img``: the image scaled to ``[-1, 1]`` in channel-first layout, as a
        NumPy array,
      * ``mask``: float tensor with a single leading channel,
      * ``name``: the image file name without its extension.
    """
    img_ext = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}

    def __init__(
        self,
        real_dir: Path,
        mask_dir: Optional[Path] = None,
        resolution: int = None
    ):
        super().__init__()

        candidates = Path(real_dir).iterdir()
        self.img_paths = sorted(p for p in candidates if p.suffix in self.img_ext)
        self.mask_dir = mask_dir
        self.resolution = resolution

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, index) -> Tuple[torch.Tensor, np.array, np.array, str]:
        res = self.resolution
        img_path = Path(self.img_paths[index])
        img_name = img_path.stem

        picture = Image.open(img_path).convert("RGB")
        if picture.size != (res, res):
            picture = picture.resize((res, res), Image.BICUBIC)
        assert picture.size[0] == res

        if self.mask_dir is None:
            # No mask folder: synthesise a random mask of the image size.
            hole = Image.fromarray(RandomMask(picture.size[0])).convert("L")
        else:
            # mask_path = self.mask_dir / f"{img_name}.png"
            mask_file = self.mask_dir / f"img000{img_name}.png"
            hole = Image.open(mask_file).convert("L")
            hole = hole.resize((res, res), Image.NEAREST)
            assert hole.size[0] == res

        # Image: HWC uint8 -> float in [-1, 1]; mask: HW -> HW1, integer-divided by 255.
        img = torch.Tensor(np.array(picture)).float() * 2 / 255 - 1
        mask = torch.Tensor(np.array(hole)[:, :, np.newaxis] // 255).float()

        # Channel-last -> channel-first.
        img = img.permute(2, 0, 1)
        mask = mask.permute(2, 0, 1)

        x = torch.cat([mask - 0.5, img * mask], dim=0)
        return x, np.array(img), mask, img_name
141
+
142
class SimpleInferDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, mask, file_name)`` with PIL images.

    Images are read from ``real_dir``. When ``mask_dir`` is given, masks are
    paired with images by position in the sorted file lists; otherwise a random
    mask is generated per item. Both outputs are resized to
    ``resolution`` x ``resolution``. ``file_name`` keeps its extension.
    """

    def __init__(
        self,
        real_dir: Path,
        mask_dir: Path = None,
        resolution: int = 512
    ):
        super(SimpleInferDataset, self).__init__()

        img_extensions = {".jpg", ".jpeg", ".JPG", ".JPEG", ".png", ".PNG"}
        self.img_paths = sorted([i for i in Path(real_dir).iterdir() if i.suffix in img_extensions])
        self.img_dir = real_dir

        # Always define both attributes. They used to be set only when a mask
        # directory was supplied, so __getitem__ raised AttributeError on
        # ``self.mask_dir`` whenever random masks were requested.
        if mask_dir:
            self.mask_paths = sorted([i for i in Path(mask_dir).iterdir() if i.suffix in img_extensions])
            self.mask_dir = mask_dir
        else:
            self.mask_paths = []
            self.mask_dir = None

        self.resolution = resolution

    def __getitem__(self, index):
        img_path = Path(self.img_paths[index])
        img_name = os.path.basename(img_path)

        img = Image.open(img_path).convert("RGB")

        if self.mask_dir:
            # NOTE(review): pairing is purely positional -- the mask folder is
            # assumed to sort into the same order as the image folder.
            mask_path = Path(self.mask_paths[index])
            mask = Image.open(mask_path).convert("L")
        else:
            mask = RandomMask(img.size[0])
            mask = Image.fromarray(mask).convert("L")

        # Nearest-neighbour resampling for the mask.
        mask = mask.resize((self.resolution, self.resolution), Image.NEAREST)

        if img.size[0] != self.resolution or img.size[1] != self.resolution:
            img = img.resize((self.resolution, self.resolution), Image.BICUBIC)

        return img, mask, img_name

    def __len__(self):
        return len(self.img_paths)
184
+
185
+
186
+
187
def collate_fn(inputs):
    """Regroup a list of ``(image, mask, name, ...)`` samples into three lists.

    Returns:
        ``(images, masks, names)`` in sample order. Items are left as they are
        (no tensor stacking); elements past the third are ignored.
    """
    images, masks, names = [], [], []
    for sample in inputs:
        images.append(sample[0])
        masks.append(sample[1])
        names.append(sample[2])
    return images, masks, names
192
+
193
+
194
def build_dataloader(args, dataset_class=InferDataset):
    """Create the inference ``DataLoader`` from parsed CLI arguments.

    Args:
        args: Namespace providing ``real_dir``, ``mask_dir``, ``resolution``,
            ``batch_size`` and ``num_workers``.
        dataset_class: Dataset type, constructed with ``real_dir``, ``mask_dir``
            and ``resolution`` keyword arguments.

    Returns:
        A non-shuffling ``DataLoader`` that keeps the last partial batch and
        uses ``collate_fn`` to group samples.
    """
    dataset = dataset_class(
        real_dir=args.real_dir,
        mask_dir=args.mask_dir,
        resolution=args.resolution,
    )

    loader_options = dict(
        shuffle=False,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        drop_last=False,
        collate_fn=collate_fn,
        pin_memory=True,
        # persistent_workers=True
    )
    return DataLoader(dataset, **loader_options)
library/__init__.py ADDED
File without changes
library/chinese_sdxl_train_util.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import gc
4
+ import re
5
+ import json
6
+ import math
7
+ import time
8
+ import toml
9
+ import shutil
10
+ import argparse
11
+ from typing import Optional
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ import importlib
15
+
16
+ import torch
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from library import train_util
20
+ from transformers import BertTokenizer, BertTokenizerFast, ChineseCLIPTextModel, PreTrainedTokenizerFast, T5Tokenizer, T5ForConditionalGeneration
21
+
22
+ from diffusers import (
23
+ DDPMScheduler,
24
+ EulerAncestralDiscreteScheduler,
25
+ DPMSolverMultistepScheduler,
26
+ DDIMScheduler,
27
+ EulerDiscreteScheduler,
28
+ KDPM2DiscreteScheduler,
29
+ AutoencoderKL,
30
+ UNet2DConditionModel,
31
+ )
32
+ from diffusers.models import UNet2DConditionModel, Transformer2DModel
33
+ from torch.utils.tensorboard import SummaryWriter
34
+ from tqdm import tqdm
35
+ from transformers import BertTokenizerFast, ChineseCLIPTextModel
36
+ from library import train_util
37
+
38
+ # from mmmp_text import DebertaV2Model
39
+ # from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
40
+ # from diffusers_patch.models.vivo_llm2vec import LLM2VecWithoutPool
41
+ # from transformers.models.qwen2.tokenization_qwen2_fast import Qwen2TokenizerFast
42
+
43
# Noise offset applied by verify_sdxl_training_args when args.noise_offset is None;
# its warning messages state that SDXL was trained with this value.
DEFAULT_NOISE_OFFSET = 0.0357
44
+
45
def load_target_model(args, accelerator, pipe_class, weight_dtype):
    """Load both text encoders, the VAE and the denoiser, one process at a time.

    The loop walks over every process slot; only the slot matching this
    process's local index performs the load, and all processes synchronise
    after each slot.

    Returns:
        Tuple ``(text_encoder1, text_encoder2, vae, unet)``.
    """
    # load models for each process
    state = accelerator.state
    for slot in range(state.num_processes):
        if slot == state.local_process_index:
            print(f"loading model for process {accelerator.process_index}/{state.num_processes}")

            load_device = accelerator.device if args.lowram else "cpu"
            loaded = _load_target_model(
                args,
                args.pretrained_model_name_or_path,
                args.vae,
                pipe_class,
                weight_dtype,
                load_device,
            )
            text_encoder1, text_encoder2, vae, unet = loaded

            gc.collect()
            torch.cuda.empty_cache()
        accelerator.wait_for_everyone()

    return text_encoder1, text_encoder2, vae, unet
70
+
71
+
72
def _load_target_model(args, name_or_path: str, vae_path: Optional[str], pipe_class, weight_dtype, device="cpu"):
    """Instantiate the text encoders, VAE and denoiser described by ``model_index.json``.

    Args:
        args: Parsed training arguments (not read here).
        name_or_path: Directory of a diffusers-style checkpoint; may be a symlink.
        vae_path: Accepted for interface compatibility; not read here.
        pipe_class: Accepted for interface compatibility; not read here.
        weight_dtype: ``torch_dtype`` passed to every ``from_pretrained`` call.
        device: Accepted for interface compatibility; not read here.

    Returns:
        Tuple ``(text_encoder1, text_encoder2, vae, unet)``.

    Raises:
        ValueError: If ``model_index.json`` names neither a ``unet`` nor a
            ``transformer``, or names a denoiser class this module has not imported.
    """
    # os.readlink() returns the raw link target, which may be relative to the
    # link's own directory; realpath() resolves it to a usable absolute path.
    if os.path.islink(name_or_path):
        name_or_path = os.path.realpath(name_or_path)

    model_index = read_json(os.path.join(name_or_path, 'model_index.json'))

    # Each entry is a [library name, class name] pair.
    library1 = importlib.import_module(model_index['text_encoder'][0])
    library2 = importlib.import_module(model_index['text_encoder_2'][0])
    TextEncoderClass1 = getattr(library1, model_index['text_encoder'][-1])
    TextEncoderClass2 = getattr(library2, model_index['text_encoder_2'][-1])

    if 'unet' in model_index:
        unet_dir = 'unet'
    elif 'transformer' in model_index:
        unet_dir = 'transformer'
    else:
        raise ValueError("model_index.json must contain either a 'unet' or a 'transformer' entry")

    # Resolve the class by name among this module's globals rather than
    # eval()-ing text read from a JSON file, which would run arbitrary expressions.
    unet_class_name = model_index[unet_dir][-1]
    UNetClass = globals().get(unet_class_name)
    if not isinstance(UNetClass, type):
        raise ValueError(f"unsupported denoiser class in model_index.json: {unet_class_name!r}")

    print(f"TextEncoderClass1:{TextEncoderClass1}")
    print(f"TextEncoderClass2:{TextEncoderClass2}")
    print(f"UNetClass:{UNetClass}")

    vae = AutoencoderKL.from_pretrained(os.path.join(name_or_path, 'vae'), torch_dtype=weight_dtype, low_cpu_mem_usage=False, device_map=None)
    unet = UNetClass.from_pretrained(os.path.join(name_or_path, unet_dir), torch_dtype=weight_dtype, low_cpu_mem_usage=False, device_map=None, ignore_mismatched_sizes=True)

    text_encoder1 = TextEncoderClass1.from_pretrained(os.path.join(name_or_path, 'text_encoder'), torch_dtype=weight_dtype)
    text_encoder2 = TextEncoderClass2.from_pretrained(os.path.join(name_or_path, 'text_encoder_2'), torch_dtype=weight_dtype)

    # VAEs whose config declares version 'vivo' get both quant convs replaced by identity.
    vae_version = vae.config.version if 'version' in vae.config else ''
    if vae_version == 'vivo':
        vae.quant_conv = torch.nn.Identity()
        vae.post_quant_conv = torch.nn.Identity()

    return text_encoder1, text_encoder2, vae, unet
110
+
111
+
112
def load_tokenizers(args: argparse.Namespace):
    """Load the two tokenizers named in the checkpoint's ``model_index.json``.

    The second tokenizer class is currently fixed to ``BertTokenizer``
    regardless of what the index file says.

    Returns:
        List ``[tokenizer_1, tokenizer_2]``.
    """
    print("prepare tokenizers")
    model_root = args.pretrained_model_name_or_path
    model_index = read_json(os.path.join(model_root, 'model_index.json'))

    lib_name_1 = model_index['tokenizer'][0]
    lib_name_2 = model_index['tokenizer_2'][0]
    class_name_1 = model_index['tokenizer'][-1]
    class_name_2 = "BertTokenizer"  # ToDo: model_index['tokenizer_2'][-1]

    module_1 = importlib.import_module(lib_name_1)
    module_2 = importlib.import_module(lib_name_2)
    tokenizer_cls_1 = getattr(module_1, class_name_1)
    tokenizer_cls_2 = getattr(module_2, class_name_2)

    loaded = [
        tokenizer_cls_1.from_pretrained(model_root, subfolder='tokenizer'),
        tokenizer_cls_2.from_pretrained(model_root, subfolder='tokenizer_2'),
    ]

    requested_length = getattr(args, "max_token_length", None)
    if requested_length is not None:
        # Only reported; the tokenizers themselves are not modified here.
        print(f"update token length: {args.max_token_length}")

    return loaded
135
+
136
+
137
def get_hidden_states_sdxl(
    input_ids1: torch.Tensor,
    input_ids2: torch.Tensor,
    tokenizer1: BertTokenizerFast,
    tokenizer2: BertTokenizerFast,
    text_encoder1: ChineseCLIPTextModel,
    text_encoder2: ChineseCLIPTextModel,
    weight_dtype: Optional[torch.dtype] = None,
    attention_mask1: Optional[torch.Tensor] = None,
    attention_mask2: Optional[torch.Tensor] = None,
):
    """Encode both token streams and return their hidden states plus the second pooled output.

    Token ids (and masks, when given) are flattened to rows of each
    tokenizer's ``model_max_length`` before encoding, then the hidden states
    are reshaped back so the leading dimension is the original batch size.

    Args:
        input_ids1, input_ids2: Token ids for encoder 1 / encoder 2.
        tokenizer1, tokenizer2: Only ``model_max_length`` is read.
        text_encoder1, text_encoder2: Encoders passed to ``encode_token``.
        weight_dtype: When given, both hidden states are cast to it.
        attention_mask1, attention_mask2: Optional masks; each is handled
            independently, so either may be omitted.

    Returns:
        Tuple ``(hidden_states1, hidden_states2, pool2)``.
    """
    # input_ids: b,n,77 -> b*n, 77
    b_size = input_ids1.size()[0]
    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77

    # Reshape each mask on its own check: the previous code reshaped mask 2
    # whenever mask 1 was present and crashed if only mask 1 was supplied.
    if attention_mask1 is not None:
        attention_mask1 = attention_mask1.reshape((-1, tokenizer1.model_max_length))
    if attention_mask2 is not None:
        attention_mask2 = attention_mask2.reshape((-1, tokenizer2.model_max_length))

    hidden_states1, _ = encode_token(input_ids1, attention_mask1, text_encoder1)
    hidden_states2, pool2 = encode_token(input_ids2, attention_mask2, text_encoder2)

    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))
    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
167
+
168
+
169
def encode_token(input_ids, attention_mask, text_encoder):
    """Run one text encoder and return ``(token_embeddings, pooled_embeddings)``.

    Three encoder families are distinguished:
      * ``T5ForConditionalGeneration``: last encoder hidden state, no pooled output.
      * ``ChineseCLIPTextModel``: second-to-last hidden state and ``pooler_output``.
      * anything else: ``last_hidden_states`` / ``last_hidden_state`` and its
        mean over dim 1 as the pooled output.
    """
    # T5
    if isinstance(text_encoder, T5ForConditionalGeneration):
        encoded = text_encoder.encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        return encoded.hidden_states[-1], None

    # clip Bert
    if isinstance(text_encoder, ChineseCLIPTextModel):
        encoded = text_encoder(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # The pooled output is taken from the encoder's pooler.
        pooled = encoded['pooler_output']
        return encoded.hidden_states[-2], pooled

    # 3mp_Bert\Qwen2Model\LLM2VecWithoutPool\GLMModel
    encoded = text_encoder(
        input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
    )
    # Some encoders expose the plural key, others the singular one.
    if 'last_hidden_states' in encoded:
        token_embeds = encoded.last_hidden_states
    else:
        token_embeds = encoded.last_hidden_state

    return token_embeds, token_embeds.mean(dim=1)
208
+
209
+
210
+
211
+
212
def prepare_logging(args: argparse.Namespace, is_main_process):
    """Create the run's logging directory and, if requested, a TensorBoard writer.

    The run directory is ``<logging_dir>/<log_prefix><timestamp>``. On the
    main process it is created, and when ``args.script_args`` is set that
    file (and ``args.dataset_config``, if set) is copied into it.

    Args:
        args: Needs ``logging_dir``, ``log_prefix``, ``log_with``,
            ``script_args`` and ``dataset_config``.
        is_main_process: Only the main process touches the filesystem.

    Returns:
        A ``SummaryWriter`` when ``log_with`` is ``"tensorboard"``/``"all"``
        and this is the main process; otherwise ``None``.

    Raises:
        ValueError: If TensorBoard logging is requested without ``logging_dir``.
    """
    if args.logging_dir is None:
        logging_dir = None
    else:
        log_prefix = "" if args.log_prefix is None else args.log_prefix
        logging_dir = args.logging_dir + "/" + log_prefix + time.strftime("%Y%m%d%H%M%S", time.localtime())

    # Stays None unless TensorBoard was requested; previously the name was
    # undefined in that case and the main process crashed with NameError.
    tensorboard_dir = None
    if args.log_with in ["tensorboard", "all"]:
        if logging_dir is None:
            raise ValueError("logging_dir is required when log_with is tensorboard / Tensorboardを使う場合、logging_dirを指定してください")
        tensorboard_dir = os.path.join(logging_dir, 'tensorboard')

    # Nothing to create without a directory (os.makedirs(None) would raise).
    if not is_main_process or logging_dir is None:
        return None

    os.makedirs(logging_dir, exist_ok=True)

    if args.script_args:
        # Keep a copy of the launch script and dataset config next to the logs.
        shutil.copyfile(args.script_args, os.path.join(logging_dir, os.path.basename(args.script_args)))
        if args.dataset_config:
            shutil.copyfile(args.dataset_config, os.path.join(logging_dir, os.path.basename(args.dataset_config)))

    if tensorboard_dir is None:
        return None

    os.makedirs(tensorboard_dir, exist_ok=True)
    return SummaryWriter(tensorboard_dir)
241
+
242
+
243
def verify_sdxl_training_args(args: argparse.Namespace, supportTextEncoderCaching: bool = True):
    """Check SDXL-specific argument combinations, normalising ``args`` in place.

    * Warns when ``clip_skip`` is set.
    * Fills ``noise_offset`` with ``DEFAULT_NOISE_OFFSET`` when unset (unless
      multires noise is active) and warns when a different value is used.
    * Rejects ``weighted_captions``.
    * Turns on ``cache_text_encoder_outputs`` when only the to-disk variant was requested.

    Raises:
        AssertionError: If ``weighted_captions`` is enabled.
    """
    if args.clip_skip is not None:
        print("clip_skip will be unexpected / SDXL学習ではclip_skipは動作しません")

    if args.multires_noise_iterations:
        print(
            f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET}, but noise_offset is disabled due to multires_noise_iterations"
        )
    else:
        if args.noise_offset is None:
            args.noise_offset = DEFAULT_NOISE_OFFSET
        elif args.noise_offset != DEFAULT_NOISE_OFFSET:
            print(
                f"Warning: SDXL has been trained with noise_offset={DEFAULT_NOISE_OFFSET} / SDXLはnoise_offset={DEFAULT_NOISE_OFFSET}で学習されています"
            )
        print(f"noise_offset is set to {args.noise_offset}")

    # A missing attribute counts as "not enabled".
    wants_weighted_captions = getattr(args, "weighted_captions", False)
    assert (
        not wants_weighted_captions
    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

    if (
        supportTextEncoderCaching
        and args.cache_text_encoder_outputs_to_disk
        and not args.cache_text_encoder_outputs
    ):
        args.cache_text_encoder_outputs = True
        print(
            "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
            + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
        )
272
+
273
+
274
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) indices.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor; the cosine half comes first, then the sine
             half, plus one zero column when ``dim`` is odd.
    """
    half = dim // 2
    steps = torch.arange(half, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * steps / half).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(angles), torch.sin(angles)], dim=-1)
    if dim % 2:
        # Odd width: pad with a single zero column.
        zero_col = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, zero_col], dim=-1)
    return embedding
292
+
293
+
294
def get_timestep_embedding(x, outdim):
    """Embed every scalar of a 2-D tensor and concatenate the results per row.

    Returns a tensor of shape ``(x.shape[0], x.shape[1] * outdim)``.
    """
    assert len(x.shape) == 2
    rows, cols = x.shape[0], x.shape[1]
    per_value = timestep_embedding(x.flatten(), outdim)
    return per_value.reshape((rows, cols * outdim))
301
+
302
+
303
def get_size_embeddings(orig_size, crop_size, target_size, device):
    """Concatenate 256-wide embeddings of the three size tensors along dim 1 and move to ``device``."""
    pieces = [get_timestep_embedding(sizes, 256) for sizes in (orig_size, crop_size, target_size)]
    return torch.cat(pieces, dim=1).to(device)
309
+
310
+
311
def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
    """Register the two boolean text-encoder caching flags on ``parser``."""
    caching_flags = (
        ("--cache_text_encoder_outputs", "cache text encoder outputs / text encoderの出力をキャッシュする"),
        (
            "--cache_text_encoder_outputs_to_disk",
            "cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
        ),
    )
    for flag, help_text in caching_flags:
        parser.add_argument(flag, action="store_true", help=help_text)
320
+
321
+
322
def set_unet_eff_attn(unet, mem_eff_attn, xformers, sdpa):
    """Enable at most one attention backend on ``unet``.

    Priority is ``mem_eff_attn``, then ``xformers``, then ``sdpa``; when none
    is truthy the model is left untouched.

    Raises:
        ImportError: If ``xformers`` is requested but the package cannot be imported.
    """
    if mem_eff_attn:
        print("Enable memory efficient attention for U-Net")
        unet.set_use_memory_efficient_attention_xformers(False, True)
        return

    if xformers:
        print("Enable xformers for U-Net")
        # Probe for the package before switching the model over.
        try:
            import xformers.ops
        except ImportError:
            raise ImportError("No xformers / xformersがインストールされていないようです")

        unet.set_use_memory_efficient_attention_xformers(True, False)
        return

    if sdpa:
        print("Enable SDPA for U-Net")
        unet.set_use_sdpa(True)
337
+
338
def set_diffusers_xformers_flag(model, valid):
    """Call ``set_use_memory_efficient_attention_xformers(valid)`` on every module that has it.

    The module tree under ``model`` is visited parent-first, children in
    the order ``children()`` yields them.
    """
    pending = [model]
    while pending:
        module = pending.pop()
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(valid)
        # Push in reverse so the first child is popped (visited) first.
        pending.extend(reversed(list(module.children())))
347
+
348
+
349
def read_json(json_path):
    """Parse the JSON file at ``json_path`` and return the decoded object.

    The file is opened in a ``with`` block so the handle is closed
    deterministically (the previous ``json.load(open(...))`` left it to the
    garbage collector), and read as UTF-8, the standard JSON encoding.
    """
    with open(json_path, encoding="utf-8") as handle:
        return json.load(handle)
library/custom_train_functions.py ADDED
@@ -0,0 +1,515 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import argparse
3
+ import random
4
+ import re
5
+ from typing import List, Optional, Union
6
+
7
+
8
def prepare_scheduler_for_custom_training(noise_scheduler, device):
    """Attach a per-timestep SNR table (``all_snr``) to the scheduler, once.

    The table is ``(sqrt(alphas_cumprod) / sqrt(1 - alphas_cumprod)) ** 2``
    moved to ``device``. Schedulers that already carry ``all_snr`` are left alone.
    """
    if hasattr(noise_scheduler, "all_snr"):
        return

    cumprod = noise_scheduler.alphas_cumprod
    signal_scale = torch.sqrt(cumprod)
    noise_scale = torch.sqrt(1.0 - cumprod)
    snr_table = (signal_scale / noise_scale) ** 2

    noise_scheduler.all_snr = snr_table.to(device)
20
+
21
+
22
def fix_noise_scheduler_betas_for_zero_terminal_snr(noise_scheduler):
    """Rescale the scheduler's betas so the final timestep has zero SNR.

    Overwrites ``betas``, ``alphas`` and ``alphas_cumprod`` on the scheduler.
    """
    # fix beta: zero terminal SNR
    print("fix noise scheduler betas: https://arxiv.org/abs/2305.08891")

    def _rescale(betas):
        # betas -> sqrt of the cumulative alpha product
        bar_sqrt = (1 - betas).cumprod(0).sqrt()

        # Remember the endpoints before shifting.
        first = bar_sqrt[0].clone()
        last = bar_sqrt[-1].clone()

        # Shift so the last timestep is zero, then scale so the first
        # timestep returns to its old value.
        bar_sqrt = bar_sqrt - last
        bar_sqrt = bar_sqrt * (first / (first - last))

        # sqrt of the cumulative product -> betas
        bar = bar_sqrt**2
        step_alphas = torch.cat([bar[0:1], bar[1:] / bar[:-1]])
        return 1 - step_alphas

    fixed_betas = _rescale(noise_scheduler.betas)
    fixed_alphas = 1.0 - fixed_betas

    noise_scheduler.betas = fixed_betas
    noise_scheduler.alphas = fixed_alphas
    noise_scheduler.alphas_cumprod = torch.cumprod(fixed_alphas, dim=0)
58
+
59
+
60
def apply_snr_weight(loss, timesteps, noise_scheduler, gamma):
    """Scale ``loss`` by ``min(gamma / SNR(t), 1)`` for each sampled timestep.

    Requires ``noise_scheduler.all_snr`` to be populated.
    """
    per_step_snr = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])
    ratio = torch.div(torch.full_like(per_step_snr, gamma), per_step_snr)
    capped = torch.minimum(ratio, torch.ones_like(ratio))  # from paper
    return loss * capped.float().to(loss.device)
66
+
67
+
68
def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    """Multiply ``loss`` by the per-timestep factor from ``get_snr_scale``."""
    return loss * get_snr_scale(timesteps, noise_scheduler)
72
+
73
+
74
def get_snr_scale(timesteps, noise_scheduler):
    """Return ``snr / (snr + 1)`` per sampled timestep, with SNR capped at 1000."""
    per_step_snr = torch.stack([noise_scheduler.all_snr[step] for step in timesteps])  # batch_size
    # if timestep is 0, snr_t is inf, so limit it to 1000
    limited = torch.minimum(per_step_snr, torch.full_like(per_step_snr, 1000))
    return limited / (limited + 1)
81
+
82
+
83
def add_v_prediction_like_loss(loss, timesteps, noise_scheduler, v_pred_like_loss):
    """Add ``loss / scale * v_pred_like_loss`` to ``loss``, with ``scale`` from ``get_snr_scale``."""
    snr_scale = get_snr_scale(timesteps, noise_scheduler)
    extra = loss / snr_scale * v_pred_like_loss
    return loss + extra
88
+
89
+
90
+ # TODO: this logic is split between here and train_util; consolidate it in one place
91
+
92
+
93
def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted_captions: bool = True):
    """Register the custom loss-weighting options on ``parser``.

    Always adds ``--min_snr_gamma``, ``--scale_v_pred_loss_like_noise_pred``
    and ``--v_pred_like_loss``; adds ``--weighted_captions`` only when
    ``support_weighted_captions`` is true.
    """
    parser.add_argument(
        "--min_snr_gamma",
        type=float,
        default=None,
        help="gamma for reducing the weight of high loss timesteps. Lower numbers have stronger effect. 5 is recommended by paper. / 低いタイムステップでの高いlossに対して重みを減らすためのgamma値、低いほど効果が強く、論文では5が推奨",
    )
    parser.add_argument(
        "--scale_v_pred_loss_like_noise_pred",
        action="store_true",
        help="scale v-prediction loss like noise prediction loss / v-prediction lossをnoise prediction lossと同じようにスケーリングする",
    )
    parser.add_argument(
        "--v_pred_like_loss",
        type=float,
        default=None,
        help="add v-prediction like loss multiplied by this value / v-prediction lossをこの値をかけたものをlossに加算する",
    )

    if not support_weighted_captions:
        return

    parser.add_argument(
        "--weighted_captions",
        action="store_true",
        default=False,
        help="Enable weighted captions in the standard style (token:1.3). No commas inside parens, or shuffle/dropout may break the decoder. / 「[token]」、「(token)」「(token:1.3)」のような重み付きキャプションを有効にする。カンマを括弧内に入れるとシャッフルやdropoutで重みづけがおかしくなるので注意",
    )
118
+
119
+
120
# Tokenizer for weighted-prompt syntax. Alternatives are tried in order:
# backslash-escaped brackets and backslash, opening brackets, a ":<number>)"
# weight suffix (number captured in group 1), closing brackets, a run of
# ordinary text, and finally a lone colon.
re_attention = re.compile(
    r"""
\\\(|
\\\)|
\\\[|
\\]|
\\\\|
\\|
\(|
\[|
:([+-]?[.\d]+)\)|
\)|
]|
[^\\()\[\]:]+|
:
""",
    re.VERBOSE,
)
138
+
139
+
140
def parse_prompt_attention(text):
    r"""
    Parse a prompt with attention markup into ``[text, weight]`` pairs.

    Accepted tokens are:
      (abc) - multiplies attention to abc by 1.1
      (abc:3.12) - multiplies attention to abc by 3.12
      [abc] - divides attention to abc by 1.1
      \( - literal character '('
      \[ - literal character '['
      \) - literal character ')'
      \] - literal character ']'
      \\ - literal character '\'
      anything else - just text

    Unclosed brackets are treated as if closed at the end of the string, and
    adjacent pieces with equal weight are merged.

    >>> parse_prompt_attention('normal text')
    [['normal text', 1.0]]
    >>> parse_prompt_attention('an (important) word')
    [['an ', 1.0], ['important', 1.1], [' word', 1.0]]
    >>> parse_prompt_attention('(unbalanced')
    [['unbalanced', 1.1]]
    >>> parse_prompt_attention('\(literal\]')
    [['(literal]', 1.0]]
    >>> parse_prompt_attention('(unnecessary)(parens)')
    [['unnecessaryparens', 1.1]]
    >>> parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    [['a ', 1.0],
     ['house', 1.5730000000000004],
     [' ', 1.1],
     ['on', 1.0],
     [' a ', 1.1],
     ['hill', 0.55],
     [', sun, ', 1.1],
     ['sky', 1.4641000000000006],
     ['.', 1.1]]
    """
    # NOTE: the docstring is a raw string; the backslash examples above are
    # invalid escape sequences in an ordinary string literal.

    res = []
    round_brackets = []  # positions in ``res`` where each open "(" started
    square_brackets = []  # positions in ``res`` where each open "[" started

    round_bracket_multiplier = 1.1
    square_bracket_multiplier = 1 / 1.1

    def multiply_range(start_position, multiplier):
        # Scale every piece emitted since the bracket was opened.
        for p in range(start_position, len(res)):
            res[p][1] *= multiplier

    for m in re_attention.finditer(text):
        # Distinct name so the ``text`` parameter is not shadowed mid-loop.
        token = m.group(0)
        weight = m.group(1)

        if token.startswith("\\"):
            # Escaped character: keep it literally, without the backslash.
            res.append([token[1:], 1.0])
        elif token == "(":
            round_brackets.append(len(res))
        elif token == "[":
            square_brackets.append(len(res))
        elif weight is not None and len(round_brackets) > 0:
            # ":<number>)" closes the innermost "(" with an explicit weight.
            multiply_range(round_brackets.pop(), float(weight))
        elif token == ")" and len(round_brackets) > 0:
            multiply_range(round_brackets.pop(), round_bracket_multiplier)
        elif token == "]" and len(square_brackets) > 0:
            multiply_range(square_brackets.pop(), square_bracket_multiplier)
        else:
            res.append([token, 1.0])

    # Brackets left open apply through to the end of the prompt.
    for pos in round_brackets:
        multiply_range(pos, round_bracket_multiplier)

    for pos in square_brackets:
        multiply_range(pos, square_bracket_multiplier)

    if len(res) == 0:
        res = [["", 1.0]]

    # merge runs of identical weights
    i = 0
    while i + 1 < len(res):
        if res[i][1] == res[i + 1][1]:
            res[i][0] += res[i + 1][0]
            res.pop(i + 1)
        else:
            i += 1

    return res
224
+
225
+
226
def get_prompts_with_weights(tokenizer, prompt: List[str], max_length: int):
    r"""
    Tokenize each prompt and pair every token id with its attention weight.

    The first and last id produced by the tokenizer for each fragment are
    dropped, and no padding is added. Prompts longer than ``max_length``
    tokens are cut and a notice is printed once.

    Returns:
        ``(tokens, weights)``: two lists of per-prompt lists of equal length.
    """
    all_ids = []
    all_weights = []
    was_truncated = False

    for entry in prompt:
        ids = []
        id_weights = []
        for fragment, fragment_weight in parse_prompt_attention(entry):
            # tokenize and discard the starting and the ending token
            piece = tokenizer(fragment).input_ids[1:-1]
            ids += piece
            # one weight per token of this fragment
            id_weights += [fragment_weight] * len(piece)
            # stop once the truncation limit has been passed
            if len(ids) > max_length:
                was_truncated = True
                break

        # truncate
        if len(ids) > max_length:
            was_truncated = True
            ids = ids[:max_length]
            id_weights = id_weights[:max_length]

        all_ids.append(ids)
        all_weights.append(id_weights)

    if was_truncated:
        print("Prompt was truncated. Try to shorten the prompt or increase max_embeddings_multiples")
    return all_ids, all_weights
259
+
260
+
261
def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_middle=True, chunk_length=77):
    r"""
    Pad every token list to ``max_length`` (``bos`` first, ``eos`` as filler)
    and every weight list with 1.0. Both input lists are updated in place and
    also returned.

    With ``no_boseos_middle`` false, a 1.0 weight is inserted around each
    ``chunk_length - 2`` sized slice of weights.
    """
    body_length = chunk_length - 2
    chunk_count = (max_length - 2) // body_length
    weights_length = max_length if no_boseos_middle else chunk_count * chunk_length

    for idx, ids in enumerate(tokens):
        tokens[idx] = [bos] + ids + [eos] * (max_length - 1 - len(ids))

        current = weights[idx]
        if no_boseos_middle:
            weights[idx] = [1.0] + current + [1.0] * (max_length - 1 - len(current))
            continue

        if len(current) == 0:
            padded = [1.0] * weights_length
        else:
            padded = []
            for chunk in range(chunk_count):
                padded.append(1.0)  # weight for starting token in this chunk
                lo = chunk * body_length
                hi = min(len(current), (chunk + 1) * body_length)
                padded += current[lo:hi]
                padded.append(1.0)  # weight for ending token in this chunk
            padded += [1.0] * (weights_length - len(padded))
        weights[idx] = padded[:]

    return tokens, weights
284
+
285
+
286
def get_unweighted_text_embeddings(
    tokenizer,
    text_encoder,
    text_input: torch.Tensor,
    chunk_length: int,
    clip_skip: int,
    eos: int,
    pad: int,
    no_boseos_middle: Optional[bool] = True,
):
    """
    Encode token ids, splitting them into encoder-sized chunks when the input
    spans more than one chunk, and concatenate the chunk embeddings along dim 1.

    With ``clip_skip`` of None or 1 the encoder's first output is used;
    otherwise the hidden state ``clip_skip`` layers from the end is taken and
    passed through the encoder's final layer norm.
    """

    def _encode(ids):
        # Shared by the single-chunk and multi-chunk paths.
        if clip_skip is None or clip_skip == 1:
            return text_encoder(ids)[0]
        enc_out = text_encoder(ids, output_hidden_states=True, return_dict=True)
        hidden = enc_out["hidden_states"][-clip_skip]
        return text_encoder.text_model.final_layer_norm(hidden)

    body_length = chunk_length - 2
    chunk_count = (text_input.shape[1] - 2) // body_length
    if chunk_count <= 1:
        return _encode(text_input)

    pieces = []
    for idx in range(chunk_count):
        # extract the idx-th chunk (body plus one slot on each side)
        start = idx * body_length
        chunk = text_input[:, start : start + body_length + 2].clone()

        # cover the head and the tail by the starting and the ending tokens
        chunk[:, 0] = text_input[0, 0]
        if pad == eos:  # v1
            chunk[:, -1] = text_input[0, -1]
        else:  # v2
            for row in chunk:
                tail = row[-1]
                if tail != eos and tail != pad:  # chunk ends on an ordinary token
                    row[-1] = eos
                if row[1] == pad:  # only BOS followed by padding
                    row[1] = eos

        embedding = _encode(chunk)

        if no_boseos_middle:
            if idx == 0:
                # discard the ending token
                embedding = embedding[:, :-1]
            elif idx == chunk_count - 1:
                # discard the starting token
                embedding = embedding[:, 1:]
            else:
                # discard both starting and ending tokens
                embedding = embedding[:, 1:-1]

        pieces.append(embedding)

    return torch.concat(pieces, axis=1)
346
+
347
+
348
def get_weighted_text_embeddings(
    tokenizer,
    text_encoder,
    prompt: Union[str, List[str]],
    device,
    max_embeddings_multiples: Optional[int] = 3,
    no_boseos_middle: Optional[bool] = False,
    clip_skip=None,
):
    r"""
    Encode prompts that carry bracket weights, e.g. 'A (very beautiful) masterpiece'.

    Each token embedding is multiplied by its parsed weight, then the whole
    embedding is rescaled so its mean matches the unweighted one.

    Args:
        prompt (`str` or `List[str]`):
            The prompt or prompts to encode.
        max_embeddings_multiples (`int`, *optional*, defaults to `3`):
            Upper bound on how many encoder-sized chunks a prompt may span.
        no_boseos_middle (`bool`, *optional*, defaults to `False`):
            Forwarded to the padding and encoding helpers; controls whether the
            start/end positions of intermediate chunks are dropped.
        clip_skip:
            Forwarded to `get_unweighted_text_embeddings`.
    """
    body_length = tokenizer.model_max_length - 2
    token_budget = body_length * max_embeddings_multiples + 2
    if isinstance(prompt, str):
        prompt = [prompt]

    prompt_tokens, prompt_weights = get_prompts_with_weights(tokenizer, prompt, token_budget - 2)

    # round up the longest length of tokens to a multiple of (model_max_length - 2)
    longest = max(len(ids) for ids in prompt_tokens)
    needed_multiples = (longest - 1) // body_length + 1
    max_embeddings_multiples = max(1, min(max_embeddings_multiples, needed_multiples))
    padded_length = body_length * max_embeddings_multiples + 2

    # pad the length of tokens and weights
    bos = tokenizer.bos_token_id
    eos = tokenizer.eos_token_id
    pad = tokenizer.pad_token_id
    prompt_tokens, prompt_weights = pad_tokens_and_weights(
        prompt_tokens,
        prompt_weights,
        padded_length,
        bos,
        eos,
        no_boseos_middle=no_boseos_middle,
        chunk_length=tokenizer.model_max_length,
    )
    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=device)

    # get the embeddings
    text_embeddings = get_unweighted_text_embeddings(
        tokenizer,
        text_encoder,
        prompt_tokens,
        tokenizer.model_max_length,
        clip_skip,
        eos,
        pad,
        no_boseos_middle=no_boseos_middle,
    )
    out_dtype = text_embeddings.dtype
    prompt_weights = torch.tensor(prompt_weights, dtype=out_dtype, device=device)

    # assign weights to the prompts and normalize in the sense of mean
    mean_before = text_embeddings.float().mean(axis=[-2, -1]).to(out_dtype)
    text_embeddings = text_embeddings * prompt_weights.unsqueeze(-1)
    mean_after = text_embeddings.float().mean(axis=[-2, -1]).to(out_dtype)
    rescale = (mean_before / mean_after).unsqueeze(-1).unsqueeze(-1)

    return text_embeddings * rescale
428
+
429
+
430
+ # https://wandb.ai/johnowhitaker/multires_noise/reports/Multi-Resolution-Noise-for-Diffusion-Model-Training--VmlldzozNjYyOTU2
431
def pyramid_noise_like(noise, device, iterations=6, discount=0.4):
    """Add progressively coarser, upsampled Gaussian noise to ``noise`` and renormalise.

    ``noise`` is accumulated into in place; the returned tensor is that sum
    divided by its standard deviation.
    """
    batch, channels, size_a, size_b = noise.shape
    upsample = torch.nn.Upsample(size=(size_a, size_b), mode="bilinear").to(device)
    for level in range(iterations):
        # Rather than always going 2x, pick a random factor in [2, 4).
        factor = random.random() * 2 + 2
        shrink = factor**level
        small_a = max(1, int(size_a / shrink))
        small_b = max(1, int(size_b / shrink))
        coarse = torch.randn(batch, channels, small_a, small_b).to(device)
        noise += upsample(coarse) * discount**level
        if small_a == 1 or small_b == 1:
            break  # Lowest resolution is 1x1
    return noise / noise.std()  # Scaled back to roughly unit variance
441
+
442
+
443
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
444
def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
    """Add a random per-(sample, channel) constant offset to ``noise``.

    Args:
        latents: tensor whose first two dimensions size the random offset and
            whose device is used for it. In the adaptive case its mean over
            dims 2 and 3 is also read.
        noise: tensor the offset is added to (not modified in place).
        noise_offset: offset strength; ``None`` disables the feature and
            ``noise`` is returned unchanged.
        adaptive_noise_scale: optional multiplier; when given, the strength is
            increased by ``adaptive_noise_scale * |mean(latents)|`` per
            sample/channel and clamped to be non-negative.

    Returns:
        ``noise`` itself when ``noise_offset`` is ``None``, otherwise a new
        tensor ``noise + strength * randn(B, C, 1, 1)``.
    """
    if noise_offset is None:
        return noise

    strength = noise_offset
    if adaptive_noise_scale is not None:
        # Absolute mean over dims (2, 3), kept as size-1 dims so it broadcasts.
        channel_abs_mean = latents.mean(dim=(2, 3), keepdim=True).abs()
        # Scale the mean and add it to the base offset; clamp at zero in case
        # a negative adaptive scale pushed the strength below zero.
        strength = torch.clamp(strength + adaptive_noise_scale * channel_abs_mean, min=0.0)

    batch_size, num_channels = latents.shape[0], latents.shape[1]
    shift = torch.randn((batch_size, num_channels, 1, 1), device=latents.device)
    return noise + strength * shift
458
+
459
+
460
+ """
461
+ ##########################################
462
+ # Perlin Noise
463
+ def rand_perlin_2d(device, shape, res, fade=lambda t: 6 * t**5 - 15 * t**4 + 10 * t**3):
464
+ delta = (res[0] / shape[0], res[1] / shape[1])
465
+ d = (shape[0] // res[0], shape[1] // res[1])
466
+
467
+ grid = (
468
+ torch.stack(
469
+ torch.meshgrid(torch.arange(0, res[0], delta[0], device=device), torch.arange(0, res[1], delta[1], device=device)),
470
+ dim=-1,
471
+ )
472
+ % 1
473
+ )
474
+ angles = 2 * torch.pi * torch.rand(res[0] + 1, res[1] + 1, device=device)
475
+ gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
476
+
477
+ tile_grads = (
478
+ lambda slice1, slice2: gradients[slice1[0] : slice1[1], slice2[0] : slice2[1]]
479
+ .repeat_interleave(d[0], 0)
480
+ .repeat_interleave(d[1], 1)
481
+ )
482
+ dot = lambda grad, shift: (
483
+ torch.stack((grid[: shape[0], : shape[1], 0] + shift[0], grid[: shape[0], : shape[1], 1] + shift[1]), dim=-1)
484
+ * grad[: shape[0], : shape[1]]
485
+ ).sum(dim=-1)
486
+
487
+ n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
488
+ n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
489
+ n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
490
+ n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
491
+ t = fade(grid[: shape[0], : shape[1]])
492
+ return 1.414 * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
493
+
494
+
495
+ def rand_perlin_2d_octaves(device, shape, res, octaves=1, persistence=0.5):
496
+ noise = torch.zeros(shape, device=device)
497
+ frequency = 1
498
+ amplitude = 1
499
+ for _ in range(octaves):
500
+ noise += amplitude * rand_perlin_2d(device, shape, (frequency * res[0], frequency * res[1]))
501
+ frequency *= 2
502
+ amplitude *= persistence
503
+ return noise
504
+
505
+
506
+ def perlin_noise(noise, device, octaves):
507
+ _, c, w, h = noise.shape
508
+ perlin = lambda: rand_perlin_2d_octaves(device, (w, h), (4, 4), octaves)
509
+ noise_perlin = []
510
+ for _ in range(c):
511
+ noise_perlin.append(perlin())
512
+ noise_perlin = torch.stack(noise_perlin).unsqueeze(0) # (1, c, w, h)
513
+ noise += noise_perlin # broadcast for each batch
514
+ return noise / noise.std() # Scaled back to roughly unit variance
515
+ """
library/train_util.py ADDED
The diff for this file is too large to render. See raw diff