Spaces:
Running on Zero
Running on Zero
AndranikSargsyan committed on
Commit ·
a8a9bce
0
Parent(s):
Add FlowDIS inference and demo
Browse files
- .gitattributes +5 -0
- .gitignore +43 -0
- APACHE-2.0-LICENSE +201 -0
- LICENSE +311 -0
- README.md +13 -0
- app.py +480 -0
- assets/examples/0.jpg +3 -0
- assets/examples/1.jpg +3 -0
- assets/examples/2.png +3 -0
- assets/examples/3.jpg +3 -0
- assets/examples/4.jpg +3 -0
- assets/examples/5.jpg +3 -0
- assets/examples/6.jpg +3 -0
- assets/examples/examples.csv +8 -0
- assets/examples/prompts.json +9 -0
- assets/preview.png +3 -0
- flowdis/__init__.py +16 -0
- flowdis/autoencoder.py +318 -0
- flowdis/conditioner.py +44 -0
- flowdis/configs.py +32 -0
- flowdis/layers.py +263 -0
- flowdis/loaders.py +75 -0
- flowdis/math.py +30 -0
- flowdis/model.py +118 -0
- flowdis/sampling.py +136 -0
- flowdis/util.py +116 -0
- pyproject.toml +50 -0
- qwen.py +73 -0
- requirements.txt +13 -0
.gitattributes
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.jpg filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.jpeg filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.webp filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Gradio temporary files
|
| 2 |
+
gradio_temp/
|
| 3 |
+
|
| 4 |
+
# Python
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
*.so
|
| 9 |
+
.Python
|
| 10 |
+
build/
|
| 11 |
+
develop-eggs/
|
| 12 |
+
dist/
|
| 13 |
+
downloads/
|
| 14 |
+
eggs/
|
| 15 |
+
.eggs/
|
| 16 |
+
lib/
|
| 17 |
+
lib64/
|
| 18 |
+
parts/
|
| 19 |
+
sdist/
|
| 20 |
+
var/
|
| 21 |
+
wheels/
|
| 22 |
+
*.egg-info/
|
| 23 |
+
.installed.cfg
|
| 24 |
+
*.egg
|
| 25 |
+
|
| 26 |
+
# Virtual Environment
|
| 27 |
+
venv/
|
| 28 |
+
env/
|
| 29 |
+
ENV/
|
| 30 |
+
|
| 31 |
+
# IDE
|
| 32 |
+
.vscode/
|
| 33 |
+
.idea/
|
| 34 |
+
*.swp
|
| 35 |
+
*.swo
|
| 36 |
+
*~
|
| 37 |
+
|
| 38 |
+
# OS
|
| 39 |
+
.DS_Store
|
| 40 |
+
Thumbs.db
|
| 41 |
+
|
| 42 |
+
outputs/
|
| 43 |
+
.gradio/
|
APACHE-2.0-LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
LICENSE
ADDED
|
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PicsArt Inc. FlowDIS Model License v1.0
|
| 2 |
+
Non-Commercial Use License
|
| 3 |
+
|
| 4 |
+
PicsArt Inc. ("PicsArt," "we," "our," or "Company") makes the weights,
|
| 5 |
+
parameters, and inference code for FlowDIS (as defined below)
|
| 6 |
+
available for your non-commercial and non-production use under the
|
| 7 |
+
terms of this License.
|
| 8 |
+
|
| 9 |
+
FlowDIS is a derivative of FLUX.1 [schnell] by Black Forest Labs,
|
| 10 |
+
Inc., which is licensed under the Apache License, Version 2.0. The
|
| 11 |
+
original FLUX.1 [schnell] model and its associated copyright, patent,
|
| 12 |
+
trademark, and attribution notices are included with this
|
| 13 |
+
distribution. A copy of the Apache License, Version 2.0 is provided
|
| 14 |
+
in the accompanying APACHE-2.0-LICENSE file. This model contains
|
| 15 |
+
modifications made by PicsArt Inc. to the original FLUX.1 [schnell]
|
| 16 |
+
model. By downloading, accessing, using, Distributing, or creating a
|
| 17 |
+
Derivative of FlowDIS, you agree to the terms of this License. If you
|
| 18 |
+
do not agree, you have no rights to access, use, Distribute, or
|
| 19 |
+
create a Derivative of FlowDIS and must immediately cease using it.
|
| 20 |
+
If you accept this License on behalf of your employer or another
|
| 21 |
+
entity, you represent and warrant that you have full legal authority
|
| 22 |
+
to bind that employer or entity.
|
| 23 |
+
|
| 24 |
+
1. Definitions
|
| 25 |
+
|
| 26 |
+
(a) "Derivative" means any (i) modified version of FlowDIS (including
|
| 27 |
+
any fine-tuned or distilled version), (ii) work based on FlowDIS,
|
| 28 |
+
or (iii) any other derivative work thereof. For clarity, Outputs
|
| 29 |
+
are not Derivatives.
|
| 30 |
+
|
| 31 |
+
(b) "Distribution" or "Distribute" means providing or making
|
| 32 |
+
available, by any means, a copy of FlowDIS and/or Derivatives.
|
| 33 |
+
|
| 34 |
+
(c) "Non-Commercial Purpose" means any of the following uses, but
|
| 35 |
+
only so far as you do not receive any direct or indirect payment
|
| 36 |
+
arising from the use of FlowDIS or Derivatives:
|
| 37 |
+
|
| 38 |
+
(i) personal use for research, experimentation, and testing for the
|
| 39 |
+
benefit of public knowledge, personal study, private
|
| 40 |
+
entertainment, hobby projects, or otherwise not directly or
|
| 41 |
+
indirectly connected to any commercial activities, business
|
| 42 |
+
operations, or employment responsibilities;
|
| 43 |
+
|
| 44 |
+
(ii) use by commercial or for-profit entities for testing,
|
| 45 |
+
evaluation, or non-commercial research and development in a
|
| 46 |
+
non-production environment; and
|
| 47 |
+
|
| 48 |
+
(iii) use by any charitable organization for charitable purposes,
|
| 49 |
+
or for testing or evaluation. For clarity, use (a) for
|
| 50 |
+
revenue-generating activity, (b) in direct interactions with
|
| 51 |
+
or that has impact on end users, or (c) to train, fine-tune,
|
| 52 |
+
or distill other models for commercial use, in each case, is
|
| 53 |
+
not a Non-Commercial Purpose.
|
| 54 |
+
|
| 55 |
+
(d) "Outputs" means any content generated by the operation of FlowDIS
|
| 56 |
+
or Derivatives from an input or prompt provided by users. Outputs
|
| 57 |
+
do not include any components of FlowDIS such as fine-tuned
|
| 58 |
+
versions, weights, or parameters.
|
| 59 |
+
|
| 60 |
+
(e) "you" or "your" means the individual or entity entering into this
|
| 61 |
+
License with Company.
|
| 62 |
+
|
| 63 |
+
2. License Grant
|
| 64 |
+
|
| 65 |
+
(a) License. Subject to your compliance with this License, Company
|
| 66 |
+
grants you a non-exclusive, worldwide, non-transferable,
|
| 67 |
+
non-sublicensable, revocable, royalty-free, and limited license
|
| 68 |
+
to access, use, create Derivatives of, and Distribute FlowDIS and
|
| 69 |
+
Derivatives solely for Non-Commercial Purposes. This license is
|
| 70 |
+
personal to you, and you may not assign or sublicense this
|
| 71 |
+
License or any rights or obligations under it without Company's
|
| 72 |
+
prior written consent; any such assignment or sublicense will be
|
| 73 |
+
void and will automatically and immediately terminate this
|
| 74 |
+
License. Any restrictions set forth herein regarding FlowDIS also
|
| 75 |
+
apply to any Derivative you create or that is created on your
|
| 76 |
+
behalf.
|
| 77 |
+
|
| 78 |
+
(b) Non-Commercial Use Only. You may only access, use, Distribute, or
|
| 79 |
+
create Derivatives of FlowDIS or Derivatives for Non-Commercial
|
| 80 |
+
Purposes. If you wish to use FlowDIS or a Derivative for any
|
| 81 |
+
purpose not expressly authorized under this License, you must
|
| 82 |
+
request a license from Company, which Company may grant in its
|
| 83 |
+
sole discretion and which may be subject to a fee, royalty, or
|
| 84 |
+
other revenue share.
|
| 85 |
+
|
| 86 |
+
(c) Reserved Rights. The grant of rights expressly set forth in this
|
| 87 |
+
License constitutes the complete grant of rights to use FlowDIS,
|
| 88 |
+
and no other licenses are granted, whether by waiver, estoppel,
|
| 89 |
+
implication, equity, or otherwise. Company and its licensors
|
| 90 |
+
reserve all rights not expressly granted by this License.
|
| 91 |
+
|
| 92 |
+
(d) Outputs. Company claims no ownership rights in Outputs. You are
|
| 93 |
+
solely responsible for the Outputs you generate and their
|
| 94 |
+
subsequent uses in accordance with this License. You may use
|
| 95 |
+
Outputs for any purpose (including commercial purposes), except
|
| 96 |
+
as expressly prohibited herein.
|
| 97 |
+
|
| 98 |
+
3. Distribution
|
| 99 |
+
|
| 100 |
+
Subject to this License, you may Distribute copies of FlowDIS and/or
|
| 101 |
+
Derivatives made by you under the following conditions:
|
| 102 |
+
(a) You must make available a copy of this License to third-party
|
| 103 |
+
recipients of FlowDIS and/or Derivatives you Distribute, and
|
| 104 |
+
specify that any rights to use FlowDIS and/or Derivatives shall
|
| 105 |
+
be granted directly by Company to said third-party recipients
|
| 106 |
+
pursuant to this License.
|
| 107 |
+
|
| 108 |
+
(b) You must prominently display the following notice alongside the
|
| 109 |
+
Distribution (such as via a "NOTICE" text file distributed as
|
| 110 |
+
part of FlowDIS or the Derivative) (the "Attribution Notice"):
|
| 111 |
+
This model is licensed by PicsArt Inc. under the PicsArt Inc.
|
| 112 |
+
FlowDIS Model License v1.0. Copyright 2026 PicsArt Inc. This
|
| 113 |
+
model is a derivative of FLUX.1 [schnell] by Black Forest Labs,
|
| 114 |
+
Inc., licensed under the Apache License, Version 2.0. IN NO EVENT
|
| 115 |
+
SHALL PICSART INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 116 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
|
| 117 |
+
ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.
|
| 118 |
+
|
| 119 |
+
(c) In the case of Distribution of Derivatives made by you: (i) you
|
| 120 |
+
must include in the Attribution Notice a statement that you have
|
| 121 |
+
modified FlowDIS; (ii) any terms and conditions you impose on
|
| 122 |
+
third-party recipients relating to your Derivatives shall neither
|
| 123 |
+
limit such recipients' use of FlowDIS or any Derivatives made by
|
| 124 |
+
Company in accordance with this License, nor conflict with any of
|
| 125 |
+
its terms and conditions, and must include disclaimer of
|
| 126 |
+
warranties and limitation of liability provisions at least as
|
| 127 |
+
protective of Company as those set forth herein; and (iii) you
|
| 128 |
+
must not misrepresent or imply that Derivatives made by or for
|
| 129 |
+
you are an official product of PicsArt Inc. or have been
|
| 130 |
+
endorsed, approved, or validated by PicsArt Inc., unless
|
| 131 |
+
authorized by Company in writing.
|
| 132 |
+
|
| 133 |
+
(d) Apache 2.0 Compliance. All Distributions must include: (i) a copy
|
| 134 |
+
of the Apache License, Version 2.0; (ii) all copyright, patent,
|
| 135 |
+
trademark, and attribution notices from the original FLUX.1
|
| 136 |
+
[schnell] model; and (iii) prominent notices stating that the
|
| 137 |
+
files have been modified from the original FLUX.1 [schnell]
|
| 138 |
+
model.
|
| 139 |
+
|
| 140 |
+
4. Restrictions
|
| 141 |
+
|
| 142 |
+
You will not, and will not permit, assist, or cause any third party
|
| 143 |
+
to:
|
| 144 |
+
(a) use, modify, copy, reproduce, create Derivatives of, or
|
| 145 |
+
Distribute FlowDIS (or any Derivative or data produced by
|
| 146 |
+
FlowDIS), in whole or in part, for:
|
| 147 |
+
|
| 148 |
+
(i) any commercial or production purpose;
|
| 149 |
+
|
| 150 |
+
(ii) any military purpose, including research, development,
|
| 151 |
+
design, manufacture, production, or use of weapons, weapons
|
| 152 |
+
systems, munitions, or any military or defense applications;
|
| 153 |
+
|
| 154 |
+
(iii) purposes of surveillance, including any research or
|
| 155 |
+
development relating to surveillance;
|
| 156 |
+
|
| 157 |
+
(iv) biometric processing;
|
| 158 |
+
|
| 159 |
+
(v) any manner that infringes, misappropriates, or otherwise violates
|
| 160 |
+
any third party's legal rights, including rights of publicity or
|
| 161 |
+
digital replica rights;
|
| 162 |
+
|
| 163 |
+
(vi) any unlawful, fraudulent, defamatory, or abusive activity;
|
| 164 |
+
|
| 165 |
+
(vii) generating unlawful content, including child sexual abuse
|
| 166 |
+
material or non-consensual intimate images; or
|
| 167 |
+
|
| 168 |
+
(viii) any manner that violates any applicable law, privacy or
|
| 169 |
+
security laws, rules, regulations, directives, or
|
| 170 |
+
governmental requirements (including the GDPR, the California
|
| 171 |
+
Consumer Privacy Act, laws governing the processing of
|
| 172 |
+
biometric information, and the EU AI Act, as well as all
|
| 173 |
+
amendments and successor laws to any of the foregoing);
|
| 174 |
+
|
| 175 |
+
(b) alter or remove copyright and other proprietary notices which
|
| 176 |
+
appear on or in any portion of FlowDIS;
|
| 177 |
+
|
| 178 |
+
(c) utilize any equipment, device, software, or other means to
|
| 179 |
+
circumvent or remove any security or protection used by Company
|
| 180 |
+
in connection with FlowDIS, or to circumvent or remove any usage
|
| 181 |
+
restrictions, or to enable functionality disabled by Company;
|
| 182 |
+
|
| 183 |
+
(d) offer or impose any terms on FlowDIS that alter, restrict, or are
|
| 184 |
+
inconsistent with the terms of this License;
|
| 185 |
+
|
| 186 |
+
(e) violate any applicable U.S. and non-U.S. export control and trade
|
| 187 |
+
sanctions laws ("Export Laws") in connection with your use or
|
| 188 |
+
Distribution of FlowDIS; or
|
| 189 |
+
|
| 190 |
+
(f) directly or indirectly Distribute, export, or otherwise transfer
|
| 191 |
+
FlowDIS (i) to any individual, entity, or country prohibited by
|
| 192 |
+
Export Laws; (ii) to anyone on U.S. or non-U.S. government
|
| 193 |
+
restricted parties lists; (iii) for any purpose prohibited by
|
| 194 |
+
Export Laws, including nuclear, chemical or biological weapons,
|
| 195 |
+
or missile technology applications; (iv) if you or they are
|
| 196 |
+
located in a comprehensively sanctioned jurisdiction, currently
|
| 197 |
+
listed on any U.S. or non-U.S. restricted parties list, or for
|
| 198 |
+
any purpose prohibited by Export Laws; or (v) while disguising
|
| 199 |
+
your location through IP proxying or other methods.
|
| 200 |
+
|
| 201 |
+
5. Disclaimers
|
| 202 |
+
|
| 203 |
+
THE MODEL IS PROVIDED "AS IS" AND "WITH ALL FAULTS" WITH NO WARRANTY
|
| 204 |
+
OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL
|
| 205 |
+
REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY
|
| 206 |
+
STATUTE, CUSTOM, USAGE OR OTHERWISE, INCLUDING BUT NOT LIMITED TO THE
|
| 207 |
+
IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
| 208 |
+
PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY
|
| 209 |
+
MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE MODEL WILL BE ERROR
|
| 210 |
+
FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY
|
| 211 |
+
PARTICULAR RESULTS.
|
| 212 |
+
|
| 213 |
+
6. Limitation of Liability
|
| 214 |
+
|
| 215 |
+
TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE
|
| 216 |
+
LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS, OR
|
| 217 |
+
DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN
|
| 218 |
+
CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE
|
| 219 |
+
UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL,
|
| 220 |
+
EXEMPLARY, INCIDENTAL, PUNITIVE, OR SPECIAL DAMAGES OR LOST PROFITS,
|
| 221 |
+
EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
| 222 |
+
THE MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY,
|
| 223 |
+
"MODEL MATERIALS") ARE NOT DESIGNED OR INTENDED FOR USE IN ANY
|
| 224 |
+
APPLICATION OR SITUATION WHERE FAILURE OR FAULT COULD REASONABLY BE
|
| 225 |
+
ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING
|
| 226 |
+
POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL'S PRIVACY
|
| 227 |
+
RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE
|
| 228 |
+
(EACH, A "HIGH-RISK USE"). IF YOU ELECT TO USE ANY MODEL MATERIALS
|
| 229 |
+
FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN
|
| 230 |
+
AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION
|
| 231 |
+
PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE.
|
| 232 |
+
|
| 233 |
+
7. Indemnification
|
| 234 |
+
|
| 235 |
+
You will indemnify, defend, and hold harmless Company and its
|
| 236 |
+
subsidiaries and affiliates, and each of their respective
|
| 237 |
+
shareholders, directors, officers, employees, agents, successors, and
|
| 238 |
+
assigns (collectively, the "Company Parties") from and against any
|
| 239 |
+
losses, liabilities, damages, fines, penalties, and expenses
|
| 240 |
+
(including reasonable attorneys' fees) incurred by any Company Party
|
| 241 |
+
in connection with any claim, demand, allegation, lawsuit,
|
| 242 |
+
proceeding, or investigation ("Claims") arising out of or related to:
|
| 243 |
+
(a) your access to or use of FlowDIS (including any Output or data
|
| 244 |
+
generated from such use), including any High-Risk Use; (b) your
|
| 245 |
+
violation of this License; or (c) your violation, misappropriation,
|
| 246 |
+
or infringement of any rights of another (including intellectual
|
| 247 |
+
property or other proprietary rights and privacy rights). You will
|
| 248 |
+
promptly notify the Company Parties of any such Claims and cooperate
|
| 249 |
+
with Company Parties in defending such Claims. You will also grant
|
| 250 |
+
the Company Parties sole control of the defense or settlement, at
|
| 251 |
+
Company's sole option, of any Claims. This indemnity is in addition
|
| 252 |
+
to, and not in lieu of, any other indemnities or remedies set forth
|
| 253 |
+
in a written agreement between you and Company or the other Company
|
| 254 |
+
Parties.
|
| 255 |
+
|
| 256 |
+
8. Termination; Survival
|
| 257 |
+
|
| 258 |
+
(a) This License will automatically terminate upon any breach by you
|
| 259 |
+
of the terms of this License.
|
| 260 |
+
|
| 261 |
+
(b) Company may terminate this License, in whole or in part, at any
|
| 262 |
+
time upon notice (including electronic) to you.
|
| 263 |
+
|
| 264 |
+
(c) If you initiate any legal action or proceedings against Company
|
| 265 |
+
or any other entity (including a cross-claim or counterclaim),
|
| 266 |
+
alleging that FlowDIS, any Derivative, or any part thereof,
|
| 267 |
+
infringes upon intellectual property or other rights owned or
|
| 268 |
+
licensable by you, then any licenses granted to you under this
|
| 269 |
+
License will immediately terminate as of the date such legal
|
| 270 |
+
action or claim is filed.
|
| 271 |
+
|
| 272 |
+
(d) Upon termination, you must cease all use, access, or Distribution
|
| 273 |
+
of FlowDIS and any Derivatives. Sections 2(c), 2(d), 4 through 11
|
| 274 |
+
survive termination.
|
| 275 |
+
|
| 276 |
+
9. Third-Party Materials
|
| 277 |
+
|
| 278 |
+
FlowDIS is derived from FLUX.1 [schnell] by Black Forest Labs, Inc.,
|
| 279 |
+
and may contain additional third-party software or components
|
| 280 |
+
(including free and open-source software) ("Third-Party Materials"),
|
| 281 |
+
which are subject to the license terms of the respective third-party
|
| 282 |
+
licensors. Your dealings or correspondence with third parties and
|
| 283 |
+
your use of or interaction with any Third-Party Materials are solely
|
| 284 |
+
between you and the third party. Company does not control or endorse,
|
| 285 |
+
and makes no representations or warranties regarding, any Third-Party
|
| 286 |
+
Materials, and your access to and use of such Third-Party Materials
|
| 287 |
+
are at your own risk.
|
| 288 |
+
|
| 289 |
+
10. Trademarks
|
| 290 |
+
|
| 291 |
+
No trademark license is granted as part of this License. You may not
|
| 292 |
+
use any name, logo, or trademark associated with PicsArt Inc. without
|
| 293 |
+
Company's prior written permission, except to the extent necessary to
|
| 294 |
+
make the reference required in the Attribution Notice or as is
|
| 295 |
+
reasonably necessary in describing FlowDIS and its creators.
|
| 296 |
+
|
| 297 |
+
11. General
|
| 298 |
+
|
| 299 |
+
This License will be governed and construed under the laws of the
|
| 300 |
+
State of Delaware without regard to conflicts of law provisions. If
|
| 301 |
+
any provision or part of a provision of this License is unlawful,
|
| 302 |
+
void, or unenforceable, that provision or part is deemed severed from
|
| 303 |
+
this License and will not affect the validity and enforceability of
|
| 304 |
+
any remaining provisions. The failure of Company to exercise or
|
| 305 |
+
enforce any right or provision of this License will not operate as a
|
| 306 |
+
waiver of such right or provision. This License does not confer any
|
| 307 |
+
third-party beneficiary rights upon any other person or entity. This
|
| 308 |
+
License, together with any accompanying documentation, contains the
|
| 309 |
+
entire understanding between you and Company regarding its subject
|
| 310 |
+
matter and supersedes all other written or oral agreements and
|
| 311 |
+
understandings between you and Company regarding such subject matter.
|
README.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: FlowDIS
|
| 3 |
+
emoji: 🌀
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: blue
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 6.3.0
|
| 8 |
+
python_version: 3.12
|
| 9 |
+
app_file: app.py
|
| 10 |
+
pinned: true
|
| 11 |
+
thumbnail: assets/preview.png
|
| 12 |
+
---
|
| 13 |
+
Paper: https://arxiv.org/abs/2605.05077
|
app.py
ADDED
|
@@ -0,0 +1,480 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
import uuid
|
| 5 |
+
import shutil
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
# Configure logging once, up front, so every later message shares one format.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# The temp directory must be exported through the environment BEFORE gradio
# is imported (avoids permission issues with the default location).
TEMP_DIR = Path(__file__).parent / "gradio_temp"
# Start each run from an empty directory: drop whatever a previous run left.
if TEMP_DIR.exists():
    shutil.rmtree(str(TEMP_DIR))
TEMP_DIR.mkdir(exist_ok=True)
os.environ.update(GRADIO_TEMP_DIR=str(TEMP_DIR), TMPDIR=str(TEMP_DIR))
|
| 24 |
+
|
| 25 |
+
import gradio as gr
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
from PIL import Image
|
| 29 |
+
|
| 30 |
+
# True when the SPACE_ID environment variable is set; read at launch time to
# pick the launch arguments.
# NOTE(review): SPACE_ID is presumably exported by the Hugging Face Spaces runtime — confirm.
IS_HF_SPACE = os.environ.get("SPACE_ID") is not None

# `zero_gpu` decorates the inference entry point. When the optional `spaces`
# package is importable its GPU decorator is used; otherwise a no-op stands in
# so the same code runs unchanged elsewhere.
try:
    import spaces
except ImportError:
    def zero_gpu(f):
        """Fallback decorator: return *f* unchanged."""
        return f
else:
    zero_gpu = spaces.GPU
|
| 37 |
+
|
| 38 |
+
from flowdis.sampling import flowdis_predict
|
| 39 |
+
from flowdis.util import load_models
|
| 40 |
+
from qwen import expand_prompt
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Models are loaded once, at import time. `models` stays None when no CUDA
# device is visible, in which case inference cannot run.
models = None
device = "cuda"
if torch.cuda.is_available():
    models = load_models(device=device)
else:
    # Reported through the module logger (not print) so it carries the same
    # timestamp/level format as every other message of the app.
    logger.warning("No GPU available, the demo will not be able to run.")
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def disable_download_btn():
    """Return a Gradio update that makes the receiving component non-interactive.

    Used as the first step of each processing chain, with the download button
    as output, so the button is disabled before a new result is computed.
    """
    update = gr.update(interactive=False)
    return update
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@zero_gpu
def process_image(image, prompt, expand_prompt_enabled, resolution, num_inference_steps):
    """Run FlowDIS on *image* and build the values for the demo's output widgets.

    Args:
        image: PIL Image or numpy array; ``None`` when nothing was provided.
        prompt: str, the text input from the user (may be empty).
        expand_prompt_enabled: bool, whether a non-empty prompt is rewritten
            with ``expand_prompt`` before inference.
        resolution: int, the inference resolution.
        num_inference_steps: the number of inference steps; coerced to ``int``.

    Returns:
        ``(None, None)`` when *image* is ``None``. Otherwise a pair of Gradio
        updates: the first sets the slider to ``[input image, cut-out image]``,
        the second points the download button at the saved PNG and enables it.

    Raises:
        gr.Error: if the models were not loaded at startup.
    """
    if image is None:
        return None, None

    if models is None:
        # Show a readable message in the UI instead of failing somewhere
        # inside the predictor with an unrelated-looking error.
        raise gr.Error("No GPU available, the demo will not be able to run.")

    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    logger.info("Original prompt: %s", prompt)
    if prompt != "" and expand_prompt_enabled:
        prompt = expand_prompt(image, prompt)
        logger.info("Expanded prompt: %s", prompt)

    pred_mask = flowdis_predict(
        image=image,
        prompt=prompt,
        models=models,
        resolution=resolution,
        num_inference_steps=int(num_inference_steps),
        device=device,
    )

    mask = np.array(pred_mask)
    # Zero the colour values wherever the mask is exactly 0, then attach the
    # mask itself as the last channel.
    # NOTE(review): assumes a 3-channel uint8 image and a single-channel mask
    # of the same height/width — confirm against flowdis_predict.
    blacked_image = np.array(image) * (mask[:, :, np.newaxis] > 0).astype(np.uint8)
    transparent_png = Image.fromarray(np.dstack([blacked_image, mask]))

    # A fresh id per call gives a unique file name and a new component key.
    uid = uuid.uuid4().hex
    png_path = TEMP_DIR / f"{uid}.png"
    transparent_png.save(png_path)
    return (
        gr.update(value=[image, transparent_png], key=uid),
        gr.update(value=str(png_path), interactive=True),
    )
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# Example rows come from assets/examples/examples.csv, whose columns are
# image_name, prompt, resolution, num_steps. Without the CSV the list is empty.
_example_dir = Path(__file__).parent / "assets" / "examples"
_examples_csv = _example_dir / "examples.csv"
examples = []
if _examples_csv.exists():
    with open(_examples_csv, newline="", encoding="utf-8") as f:
        # Each row follows the order of the demo inputs:
        # image path, prompt, expand-prompt flag, resolution, steps.
        examples = [
            [
                str(_example_dir / row["image_name"].strip()),
                row["prompt"].strip(),
                True,  # expand prompt (default for examples)
                int(row["resolution"].strip()),
                int(row["num_steps"].strip()),
            ]
            for row in csv.DictReader(f)
        ]
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# Markup injected into the page <head>:
#   * CSS that greys out the "Expand prompt" checkbox and hides the matching
#     column of the examples table;
#   * one script that disables the checkbox while the prompt box is empty.
# The script is included exactly once: a second identical copy would register
# every listener and the 250 ms poller twice.
_head_js = """
<style>
  #expand-prompt.is-disabled { pointer-events: none !important; }
  #expand-prompt.is-disabled label,
  #expand-prompt.is-disabled input,
  #expand-prompt.is-disabled .info { opacity: 0.4 !important; }
  /* Hide the "Expand prompt" column (3rd) in the examples table */
  #examples-table table th:nth-child(3),
  #examples-table table td:nth-child(3) { display: none !important; }
</style>
<script>
(function() {
  function findEls() {
    return {
      ta: document.querySelector('#text-prompt textarea, #text-prompt input'),
      cb: document.querySelector('#expand-prompt'),
    };
  }
  function syncFromText() {
    var els = findEls();
    if (!els.ta || !els.cb) return;
    var empty = !els.ta.value.trim();
    els.cb.classList.toggle('is-disabled', empty);
    var input = els.cb.querySelector('input[type=checkbox]');
    if (input) input.disabled = empty;
  }
  function init() {
    var els = findEls();
    if (!els.ta || !els.cb) { setTimeout(init, 200); return; }
    els.ta.addEventListener('input', syncFromText);
    els.ta.addEventListener('change', syncFromText);
    // Catch programmatic value changes (e.g. example selection)
    var lastVal = els.ta.value;
    setInterval(function() {
      if (els.ta.value !== lastVal) { lastVal = els.ta.value; syncFromText(); }
    }, 250);
    syncFromText();
  }
  if (document.readyState === 'loading')
    document.addEventListener('DOMContentLoaded', init);
  else
    init();
})();
</script>
"""
|
| 199 |
+
# Demo UI: a header, a two-column row (inputs on the left, result on the
# right), an examples table, and the event wiring.
# NOTE(review): Gradio 6 documents `theme` and `head` as launch() parameters;
# confirm they are still honoured when passed to gr.Blocks().
with gr.Blocks(
    title="FlowDIS – Precise Background Removal",
    head=_head_js,
    theme=gr.themes.Default(
        font=gr.themes.GoogleFont("Inter"),
    ).set(
        button_primary_background_fill="#C209C1",
        button_primary_background_fill_dark="#C209C1",
        button_primary_background_fill_hover="#d63bd5",
        button_primary_background_fill_hover_dark="#d63bd5",
        button_primary_text_color="#ffffff",
        button_primary_text_color_dark="#ffffff",
    ),
    delete_cache=(1800, 1800)
) as demo:
    # Page-level CSS plus the title / links / description header.
    gr.HTML(
        """
        <style>
        /* Theme-adaptive tokens */
        :root {
            --flow-text: #0f172a; /* slate-900 */
            --flow-muted: #475569; /* slate-600 */
            --flow-link: #2563eb; /* blue-600 */
            --flow-link-hover: #1d4ed8; /* blue-700 */
            --flow-title: #C209C1; /* Picsart pink */
        }

        @media (prefers-color-scheme: dark) {
            :root {
                --flow-text: #f1f5f9; /* slate-100 */
                --flow-muted: #94a3b8; /* slate-400 */
                --flow-link: #60a5fa; /* blue-400 */
                --flow-link-hover: #93c5fd; /* blue-300 */
                --flow-title: #e45fe3; /* Picsart pink (lighter for dark mode) */
            }
        }

        .flow-header {
            text-align: center;
            max-width: 900px;
            margin: 18px auto 12px auto;
            font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
        }

        .flow-title {
            font-size: 1.9rem;
            font-weight: 750;
            letter-spacing: -0.3px;
            margin-bottom: 4px;
            color: var(--flow-title); /* title accent (needle stays as-is) */
        }

        .flow-links {
            margin-bottom: 8px;
        }

        .flow-links a {
            color: var(--flow-link); /* cool blue links */
            font-weight: 600;
            text-decoration: none;
            margin: 0 0px;
            font-size: 0.95rem;
            transition: color 0.2s ease, text-shadow 0.2s ease;
        }

        .flow-links a:hover {
            color: var(--flow-link-hover);
            text-shadow: 0 0 10px rgba(37, 99, 235, 0.25);
        }

        @media (prefers-color-scheme: dark) {
            .flow-links a:hover {
                text-shadow: 0 0 12px rgba(147, 197, 253, 0.35);
            }
        }

        .flow-desc {
            font-size: 0.95rem;
            color: var(--flow-muted);
            max-width: 650px;
            margin: 0 auto;
            line-height: 1.5;
        }

        .bg-btn-row { display: flex; gap: 6px; overflow-x: auto; scrollbar-width: thin; }
        .bg-btn {
            width: 42px !important; height: 42px !important;
            border: 2.5px solid #aaa !important; border-radius: 8px !important;
            cursor: pointer !important; flex-shrink: 0 !important;
            padding: 0 !important; outline: none !important;
            transition: transform 0.15s ease, box-shadow 0.15s ease,
                        border-color 0.15s ease, filter 0.15s ease;
        }
        .bg-btn:hover {
            transform: scale(1.15);
            border-color: #333 !important;
            box-shadow: 0 3px 10px rgba(0,0,0,0.4);
            filter: brightness(1.15);
        }
        .bg-btn:active {
            transform: scale(0.95);
        }


        @media (max-width: 1024px) {
            #main-row {
                flex-direction: column !important;
                flex-wrap: wrap !important;
            }
            #main-row > * {
                width: 100% !important;
                flex: 1 1 100% !important;
                min-width: 0 !important;
            }
        }

        @media (max-width: 500px) {
            #input-image { height: 400px !important; }
        }
        @media (max-width: 400px) {
            #input-image { height: 300px !important; }
        }
        .prose :is(label span, .info) { font-weight: 400 !important; }
        </style>

        <div class="flow-header">
            <div class="flow-title"><span style="color:#C209C1">✦</span> FlowDIS Demo</div>

            <div class="flow-links">
                <span>📄</span><a href="https://arxiv.org/abs/2605.05077" target="_blank" rel="noopener noreferrer">arXiv</a>
                <span>💻</span><a href="https://github.com/Picsart-AI-Research/FlowDIS" target="_blank" rel="noopener noreferrer">Code</a>
            </div>

            <div class="flow-desc">
                FlowDIS performs precise foreground segmentation, optionally guided by a text prompt to only preserve the specified objects.
            </div>
        </div>
        """
    )

    with gr.Row(elem_id="main-row"):
        # Left column: Input image, text field, and submit button
        with gr.Column(scale=1):
            input_image = gr.Image(
                label="Input Image",
                type="pil",
                height=500,
                elem_id="input-image",
            )
            text_input = gr.Textbox(
                label="Text Prompt (Optional)",
                placeholder="Enter what you want to retain...",
                lines=1,
                elem_id="text-prompt",
            )
            expand_prompt_check = gr.Checkbox(
                label="Expand prompt",
                value=True,
                elem_id="expand-prompt",
                info="Use Qwen3-VL-4B-Instruct model to expand the prompt for better text-guided segmentation.",
            )

            # Sliders for resolution and steps
            with gr.Row():
                with gr.Column(scale=1, min_width=300):
                    resolution_slider = gr.Slider(
                        minimum=1024,
                        maximum=2048,
                        value=1536,
                        step=64,
                        label="Inference Resolution",
                        info="Higher resolution preserves more details.",
                    )

                with gr.Column(scale=1, min_width=300):
                    steps_slider = gr.Slider(
                        minimum=1,
                        maximum=12,
                        value=4,
                        step=1,
                        label="Number of Steps",
                        info="More steps generate sharper results.",
                    )

            submit_btn = gr.Button("🚀 Remove Background", variant="primary")

        # Right column: Output image
        with gr.Column(scale=1):
            output_image = gr.ImageSlider(
                label="FlowDIS prediction",
                type="pil",
                format="webp",
                height=500,
                slider_position=10,
                elem_id="output-slider",
            )

            # Row of swatches; clicking one rewrites a <style> tag that sets
            # the background shown behind the images inside the slider.
            _checker = "repeating-conic-gradient(#ccc 0% 25%,#fff 0% 50%) 50%/12px 12px"
            # (button style, value stored in data-bg) pairs.
            _bg_buttons = [
                (_checker, _checker),
                ("#ffffff", "#ffffff"),
                ("#000000", "#000000"),
                ("#00ff00", "#00ff00"),
                ("#0000ff", "#0000ff"),
                ("#ff0000", "#ff0000"),
                ("#ffff00", "#ffff00"),
                ("#ff00ff", "#ff00ff"),
                ("#00ffff", "#00ffff"),
            ]
            _onclick = (
                "var s=document.getElementById('slider-bg-style');"
                "if(!s){s=document.createElement('style');"
                "s.id='slider-bg-style';document.head.appendChild(s);}"
                "s.textContent='#output-slider img,#output-slider canvas"
                "{background:'+this.dataset.bg+' !important}';"
            )
            gr.HTML(
                value='<div class="bg-btn-row">'
                + "".join(
                    f'<button class="bg-btn" style="background:{style}"'
                    f' data-bg="{bg}" onclick="{_onclick}"></button>'
                    for style, bg in _bg_buttons
                )
                + "</div>"
            )

            download_btn = gr.DownloadButton(
                label="📥 Download PNG",
                variant="primary",
                interactive=False
            )

    _process_inputs = [input_image, text_input, expand_prompt_check, resolution_slider, steps_slider]

    examples_component = gr.Examples(
        examples=examples,
        inputs=_process_inputs,
        label="Examples",
        elem_id="examples-table",
    )

    # The submit button, Enter in the prompt box and a click on an example all
    # run the same chain: disable the download button, then process the image.
    for _trigger in (
        submit_btn.click,
        text_input.submit,
        examples_component.dataset.click,
    ):
        _trigger(
            disable_download_btn,
            outputs=download_btn,
        ).then(
            fn=process_image,
            inputs=_process_inputs,
            outputs=[output_image, download_btn],
        )
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# Launch the app
if __name__ == "__main__":
    demo.queue(max_size=20)
    # Both launch modes serve files from the temp directory and from "assets".
    launch_kwargs = {"allowed_paths": [str(TEMP_DIR), "assets"]}
    if not IS_HF_SPACE:
        # Outside a Space: bind all interfaces on port 7860 and request a share link.
        launch_kwargs.update(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_kwargs)
|
| 480 |
+
|
assets/examples/0.jpg
ADDED
|
Git LFS Details
|
assets/examples/1.jpg
ADDED
|
Git LFS Details
|
assets/examples/2.png
ADDED
|
Git LFS Details
|
assets/examples/3.jpg
ADDED
|
Git LFS Details
|
assets/examples/4.jpg
ADDED
|
Git LFS Details
|
assets/examples/5.jpg
ADDED
|
Git LFS Details
|
assets/examples/6.jpg
ADDED
|
Git LFS Details
|
assets/examples/examples.csv
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
image_name,prompt,resolution,num_steps
|
| 2 |
+
0.jpg,,2048,8
|
| 3 |
+
1.jpg,,2048,8
|
| 4 |
+
2.png,,1536,4
|
| 5 |
+
3.jpg,,1536,2
|
| 6 |
+
4.jpg,measuring tape,2048,8
|
| 7 |
+
5.jpg,white shoes,1280,2
|
| 8 |
+
6.jpg,bicycle,2048,8
|
assets/examples/prompts.json
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"0.jpg": "",
|
| 3 |
+
"1.jpg": "",
|
| 4 |
+
"2.png": "",
|
| 5 |
+
"3.jpg": "",
|
| 6 |
+
"4.jpg": "measuring tape",
|
| 7 |
+
"5.jpg": "white shoes",
|
| 8 |
+
"6.jpg": "bicycle"
|
| 9 |
+
}
|
assets/preview.png
ADDED
|
Git LFS Details
|
flowdis/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FlowDIS: Language-Guided Dichotomous Image Segmentation with Flow Matching"""
|
| 2 |
+
|
| 3 |
+
from flowdis.configs import configs
|
| 4 |
+
from flowdis.loaders import load_autoencoder, load_clip, load_t5, load_transformer
|
| 5 |
+
from flowdis.sampling import flowdis_predict
|
| 6 |
+
from flowdis.util import load_models
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"configs",
|
| 10 |
+
"load_autoencoder",
|
| 11 |
+
"load_clip",
|
| 12 |
+
"load_t5",
|
| 13 |
+
"load_transformer",
|
| 14 |
+
"flowdis_predict",
|
| 15 |
+
"load_models",
|
| 16 |
+
]
|
flowdis/autoencoder.py
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
from torch import Tensor, nn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@dataclass
class AutoEncoderParams:
    """Bundle of autoencoder hyper-parameters.

    Most field names mirror the constructor arguments of ``Encoder`` and
    ``Decoder`` in this module; how the bundle is consumed is not visible in
    this part of the file.
    """

    resolution: int  # stored by Encoder/Decoder as ``self.resolution``
    in_channels: int  # input channels of the encoder's first convolution
    ch: int  # base channel width; level i works at ``ch * ch_mult[i]``
    out_ch: int  # Decoder argument; NOTE(review): presumably output channels — confirm
    ch_mult: list[int]  # one width multiplier per resolution level
    num_res_blocks: int  # number of ResnetBlocks per level
    z_channels: int  # latent channels (the encoder's last conv emits ``2 * z_channels``)
    scale_factor: float  # NOTE(review): presumably a latent scale — usage not visible here
    shift_factor: float  # NOTE(review): presumably a latent shift — usage not visible here
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def swish(x: Tensor) -> Tensor:
    """Return ``x * sigmoid(x)`` element-wise (the swish / SiLU activation)."""
    gate = torch.sigmoid(x)
    return x * gate
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AttnBlock(nn.Module):
    """Residual single-head self-attention over the spatial positions of a feature map.

    Input and output are both ``(batch, in_channels, height, width)``.
    ``in_channels`` must be divisible by 32 because of the GroupNorm.
    """

    def __init__(self, in_channels: int):
        super().__init__()
        self.in_channels = in_channels

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)

        # 1x1 convolutions act as per-position linear projections.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)

    def attention(self, h_: Tensor) -> Tensor:
        """Normalise *h_*, then attend every spatial position to every other one."""
        normed = self.norm(h_)
        b, c, h, w = normed.shape

        def as_tokens(t: Tensor) -> Tensor:
            # (b, c, h, w) -> (b, 1, h*w, c): one head, one token per position.
            return t.flatten(2).transpose(1, 2).unsqueeze(1).contiguous()

        attended = nn.functional.scaled_dot_product_attention(
            as_tokens(self.q(normed)),
            as_tokens(self.k(normed)),
            as_tokens(self.v(normed)),
        )

        # (b, 1, h*w, c) -> (b, c, h, w)
        return attended.squeeze(1).transpose(1, 2).reshape(b, c, h, w)

    def forward(self, x: Tensor) -> Tensor:
        """Add the projected attention output to *x* (skip connection)."""
        delta = self.proj_out(self.attention(x))
        return x + delta
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class ResnetBlock(nn.Module):
    """Residual block: two (GroupNorm -> swish -> 3x3 conv) stages plus a skip path.

    When the channel count changes, the skip path goes through a 1x1
    convolution (``nin_shortcut``); otherwise the input is added unchanged.
    Passing ``out_channels=None`` keeps the input width.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # The shortcut projection exists only when the widths differ.
        if in_channels != out_channels:
            self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return the skip path plus the two-stage residual branch."""
        h = self.conv1(swish(self.norm1(x)))
        h = self.conv2(swish(self.norm2(h)))

        skip = x if self.in_channels == self.out_channels else self.nin_shortcut(x)
        return skip + h
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Downsample(nn.Module):
    """Downsample by 2 with a stride-2 3x3 convolution; channel count is unchanged."""

    def __init__(self, in_channels: int):
        super().__init__()
        # The conv is unpadded: forward() pads only the right and bottom edges,
        # an asymmetric layout that Conv2d's own ``padding`` cannot express.
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x: Tensor):
        """Zero-pad one pixel on the right and bottom, then apply the strided conv."""
        # Pad tuple order for the last two dims: (left, right, top, bottom).
        padded = nn.functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class Upsample(nn.Module):
    """Double the spatial size by nearest-neighbour interpolation, then apply a 3x3 conv."""

    def __init__(self, in_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: Tensor):
        """Return the convolved 2x-upsampled version of *x* (same channel count)."""
        enlarged = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.conv(enlarged)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class Encoder(nn.Module):
|
| 110 |
+
def __init__(
|
| 111 |
+
self,
|
| 112 |
+
resolution: int,
|
| 113 |
+
in_channels: int,
|
| 114 |
+
ch: int,
|
| 115 |
+
ch_mult: list[int],
|
| 116 |
+
num_res_blocks: int,
|
| 117 |
+
z_channels: int,
|
| 118 |
+
):
|
| 119 |
+
super().__init__()
|
| 120 |
+
self.ch = ch
|
| 121 |
+
self.num_resolutions = len(ch_mult)
|
| 122 |
+
self.num_res_blocks = num_res_blocks
|
| 123 |
+
self.resolution = resolution
|
| 124 |
+
self.in_channels = in_channels
|
| 125 |
+
# downsampling
|
| 126 |
+
self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
| 127 |
+
|
| 128 |
+
curr_res = resolution
|
| 129 |
+
in_ch_mult = (1,) + tuple(ch_mult)
|
| 130 |
+
self.in_ch_mult = in_ch_mult
|
| 131 |
+
self.down = nn.ModuleList()
|
| 132 |
+
block_in = self.ch
|
| 133 |
+
for i_level in range(self.num_resolutions):
|
| 134 |
+
block = nn.ModuleList()
|
| 135 |
+
attn = nn.ModuleList()
|
| 136 |
+
block_in = ch * in_ch_mult[i_level]
|
| 137 |
+
block_out = ch * ch_mult[i_level]
|
| 138 |
+
for _ in range(self.num_res_blocks):
|
| 139 |
+
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
|
| 140 |
+
block_in = block_out
|
| 141 |
+
down = nn.Module()
|
| 142 |
+
down.block = block
|
| 143 |
+
down.attn = attn
|
| 144 |
+
if i_level != self.num_resolutions - 1:
|
| 145 |
+
down.downsample = Downsample(block_in)
|
| 146 |
+
curr_res = curr_res // 2
|
| 147 |
+
self.down.append(down)
|
| 148 |
+
|
| 149 |
+
# middle
|
| 150 |
+
self.mid = nn.Module()
|
| 151 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 152 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
| 153 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
|
| 154 |
+
|
| 155 |
+
# end
|
| 156 |
+
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
|
| 157 |
+
self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
|
| 158 |
+
|
| 159 |
+
def forward(self, x: Tensor) -> Tensor:
    """Run ``x`` through the down path, the middle blocks and the output head.

    Args:
        x: Input tensor accepted by ``conv_in``.

    Returns:
        The output of ``conv_out``.
    """
    # Only the most recent activation is ever consumed, so carry a single
    # tensor instead of appending every intermediate result to a list that
    # keeps them all alive until the method returns.
    h = self.conv_in(x)
    last_level = self.num_resolutions - 1

    # downsampling
    for i_level in range(self.num_resolutions):
        level = self.down[i_level]
        has_attn = len(level.attn) > 0
        for i_block in range(self.num_res_blocks):
            h = level.block[i_block](h)
            if has_attn:
                h = level.attn[i_block](h)
        if i_level != last_level:
            h = level.downsample(h)

    # middle
    h = self.mid.block_1(h)
    h = self.mid.attn_1(h)
    h = self.mid.block_2(h)

    # end
    h = self.norm_out(h)
    h = swish(h)
    h = self.conv_out(h)
    return h
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class Decoder(nn.Module):
    """Convolutional decoder: latent map in, ``out_ch``-channel map out.

    The stack is ``conv_in`` -> middle (resnet, attention, resnet) -> one
    stage per entry of ``ch_mult`` (walked from the last entry to the first,
    with an upsample after every stage except stage 0) -> norm/swish/conv.
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
    ):
        """Create all layers.

        Args:
            ch: Base channel width; stage ``i`` outputs ``ch * ch_mult[i]``.
            out_ch: Channel count produced by ``conv_out``.
            ch_mult: Per-stage channel multipliers.
            num_res_blocks: Each stage holds ``num_res_blocks + 1`` resnet blocks.
            in_channels: Stored on the instance only.
            resolution: Used to derive the spatial size recorded in ``z_shape``.
            z_channels: Channel count accepted by ``conv_in``.
        """
        super().__init__()
        n_levels = len(ch_mult)
        self.ch = ch
        self.num_resolutions = n_levels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.ffactor = 2 ** (n_levels - 1)

        # Width and spatial side at the coarsest stage.
        width = ch * ch_mult[n_levels - 1]
        side = resolution // 2 ** (n_levels - 1)
        self.z_shape = (1, z_channels, side, side)

        # Latent -> feature width.
        self.conv_in = nn.Conv2d(z_channels, width, kernel_size=3, stride=1, padding=1)

        # Middle section.
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=width, out_channels=width)
        self.mid.attn_1 = AttnBlock(width)
        self.mid.block_2 = ResnetBlock(in_channels=width, out_channels=width)

        # Upsampling stages, built coarse-to-fine.
        self.up = nn.ModuleList()
        for level_idx in range(n_levels - 1, -1, -1):
            level_width = ch * ch_mult[level_idx]
            stage = nn.Module()
            stage.block = nn.ModuleList()
            stage.attn = nn.ModuleList()  # left empty: no per-stage attention
            for _ in range(num_res_blocks + 1):
                stage.block.append(ResnetBlock(in_channels=width, out_channels=level_width))
                width = level_width
            if level_idx != 0:
                stage.upsample = Upsample(width)
            # Inserting at the front keeps ``self.up[i]`` aligned with ``ch_mult[i]``.
            self.up.insert(0, stage)

        # Output head.
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=width, eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(width, out_ch, kernel_size=3, stride=1, padding=1)

    def forward(self, z: Tensor) -> Tensor:
        """Decode ``z`` and return the output of ``conv_out``."""
        # Read the parameter dtype of the upsampling stack up front; the
        # activations are cast to it before entering that stack.
        target_dtype = next(self.up.parameters()).dtype

        h = self.conv_in(z)
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
        h = h.to(target_dtype)

        for level_idx in reversed(range(self.num_resolutions)):
            stage = self.up[level_idx]
            for block_idx in range(self.num_res_blocks + 1):
                h = stage.block[block_idx](h)
                if len(stage.attn) > 0:
                    h = stage.attn[block_idx](h)
            if level_idx != 0:
                h = stage.upsample(h)

        return self.conv_out(swish(self.norm_out(h)))
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
class DiagonalGaussian(nn.Module):
    """Split a tensor into mean / log-variance halves and optionally sample.

    Args:
        sample: When true, return ``mean + std * noise``; otherwise return
            the mean half unchanged.
        chunk_dim: Dimension along which the input is split in two.
    """

    def __init__(self, sample: bool = True, chunk_dim: int = 1):
        super().__init__()
        self.sample = sample
        self.chunk_dim = chunk_dim

    def forward(self, z: Tensor) -> Tensor:
        """Return a sample (or the mean) of the Gaussian parameterised by ``z``."""
        mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
        if not self.sample:
            return mean
        noise_scale = torch.exp(0.5 * logvar)
        return mean + noise_scale * torch.randn_like(mean)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class AutoEncoder(nn.Module):
    """Encoder/decoder pair with a diagonal-Gaussian bottleneck.

    Latents returned by ``encode`` are shifted by ``shift_factor`` and scaled
    by ``scale_factor``; ``decode`` undoes that before running the decoder.
    """

    def __init__(self, params: AutoEncoderParams):
        super().__init__()
        self.params = params

        # Keyword arguments common to both halves of the model.
        shared = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
        )
        self.encoder = Encoder(**shared)
        self.decoder = Decoder(out_ch=params.out_ch, **shared)
        self.reg = DiagonalGaussian()

        self.scale_factor = params.scale_factor
        self.shift_factor = params.shift_factor

    def encode(self, x: Tensor) -> Tensor:
        """Encode ``x`` and return the shifted, scaled latent."""
        latent = self.reg(self.encoder(x))
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Tensor) -> Tensor:
        """Undo the latent scaling/shift and run the decoder."""
        unscaled = z / self.scale_factor + self.shift_factor
        return self.decoder(unscaled)

    def forward(self, x: Tensor) -> Tensor:
        """Round-trip ``x`` through ``encode`` then ``decode``."""
        latent = self.encode(x)
        return self.decode(latent)
|
flowdis/conditioner.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import Tensor, nn
|
| 2 |
+
from transformers import (
|
| 3 |
+
CLIPTextConfig,
|
| 4 |
+
CLIPTextModel,
|
| 5 |
+
CLIPTokenizer,
|
| 6 |
+
T5Config,
|
| 7 |
+
T5EncoderModel,
|
| 8 |
+
T5Tokenizer,
|
| 9 |
+
)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HFEmbedder(nn.Module):
    """Text encoder wrapper around a Hugging Face CLIP or T5 model.

    The tokenizer and model config are read from ``version``; the model
    itself is built from that config via ``_from_config``, so no weights are
    loaded by this class.

    Args:
        version: Identifier or path handed to the ``from_pretrained`` calls.
        max_length: Token length used for truncation and padding.
        is_clip: Select the CLIP classes (and ``pooler_output``) instead of
            the T5 classes (and ``last_hidden_state``).
        **hf_kwargs: Forwarded to the config's ``from_pretrained``.
    """

    def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
        super().__init__()
        self.is_clip = is_clip
        self.max_length = max_length
        self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"

        # Pick the class triple for the requested family once, then run a
        # single construction path.
        if self.is_clip:
            tokenizer_cls, config_cls, model_cls = CLIPTokenizer, CLIPTextConfig, CLIPTextModel
            tokenizer_extra = {}
        else:
            tokenizer_cls, config_cls, model_cls = T5Tokenizer, T5Config, T5EncoderModel
            tokenizer_extra = {"legacy": True}

        self.tokenizer = tokenizer_cls.from_pretrained(version, max_length=max_length, **tokenizer_extra)
        config = config_cls.from_pretrained(version, **hf_kwargs)
        self.hf_module = model_cls._from_config(config)

    def forward(self, text: list[str]) -> Tensor:
        """Tokenize ``text`` and return the model output selected by ``output_key``."""
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_length=False,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        )
        token_ids = encoded["input_ids"].to(self.hf_module.device)

        outputs = self.hf_module(
            input_ids=token_ids,
            attention_mask=None,
            output_hidden_states=False,
        )
        return outputs[self.output_key]
|
flowdis/configs.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from flowdis.autoencoder import AutoEncoderParams
|
| 2 |
+
from flowdis.model import FluxParams
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# Hyperparameter registry keyed by component name.
configs = {
    "autoencoder": AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],  # one entry per resolution level
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    ),
    "flowdis": FluxParams(
        in_channels=128,
        out_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,  # hidden_size / num_heads = 128 per head
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],  # sums to 128, the per-head width
        theta=10_000,
        qkv_bias=True,
        guidance_embed=False,
    ),
}
|
flowdis/layers.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from einops import rearrange
|
| 6 |
+
from torch import Tensor, nn
|
| 7 |
+
|
| 8 |
+
from flowdis.math import attention, rope
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class EmbedND(nn.Module):
    """Build rotary position embeddings from multi-axis position ids.

    Args:
        dim: Stored on the instance only.
        theta: Base passed to ``rope``.
        axes_dim: Embedding width assigned to each position axis.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        """Return the per-axis ``rope`` tables joined on dim ``-3``, with a new dim 1."""
        per_axis = []
        for axis in range(ids.shape[-1]):
            per_axis.append(rope(ids[..., axis], self.axes_dim[axis], self.theta))
        return torch.cat(per_axis, dim=-3).unsqueeze(1)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    """
    Build sinusoidal embeddings for a batch of (possibly fractional) timesteps.

    :param t: a 1-D Tensor of N values, one per batch element.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to ``t`` before embedding.
    :return: an (N, dim) Tensor laid out as ``[cos | sin]``, with one trailing
             zero column when ``dim`` is odd.
    """
    scaled = time_factor * t
    n_freqs = dim // 2

    # Frequencies are computed in float32 on the default device, then moved.
    ramp = torch.arange(start=0, end=n_freqs, dtype=torch.float32)
    freqs = torch.exp(-math.log(max_period) * ramp / n_freqs).to(scaled.device)

    phase = scaled[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

    if dim % 2:
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)

    # Match the dtype/device of the (scaled) input when it is floating point.
    if torch.is_floating_point(scaled):
        embedding = embedding.to(scaled)
    return embedding
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class MLPEmbedder(nn.Module):
    """Two-layer MLP: ``Linear -> SiLU -> Linear``.

    Args:
        in_dim: Input feature width.
        hidden_dim: Width of both the hidden layer and the output.
    """

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        """Project ``x`` through both linear layers with a SiLU in between."""
        hidden = self.in_layer(x)
        activated = self.silu(hidden)
        return self.out_layer(activated)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    Args:
        dim: Size of the last dimension (length of the ``scale`` parameter).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))
        self.dim = dim

    def forward(self, x: Tensor):
        """Normalise ``x`` in float32, cast back to its dtype, then apply ``scale``."""
        in_dtype = x.dtype
        x32 = x.float()
        mean_sq = x32.pow(2).mean(dim=-1, keepdim=True)
        normed = x32 * torch.rsqrt(mean_sq + 1e-6)
        return normed.to(dtype=in_dtype) * self.scale
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class QKNorm(torch.nn.Module):
    """Apply separate ``RMSNorm`` layers to query and key tensors.

    Args:
        dim: Size of the last dimension of the query/key tensors.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        """Return normalised ``q`` and ``k``, each converted via ``.to(v)``.

        ``v`` itself is not modified or returned.
        """
        q_normed = self.query_norm(q).to(v)
        k_normed = self.key_norm(k).to(v)
        return q_normed, k_normed
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class SelfAttention(nn.Module):
    """Multi-head self-attention with a fused QKV projection and QK normalisation.

    Args:
        dim: Model width.
        num_heads: Number of attention heads; the per-head width is
            ``dim // num_heads``.
        qkv_bias: Whether the fused QKV projection has a bias.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(dim // num_heads)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        """Attend over ``x`` using positional table ``pe`` and project the result."""
        fused = self.qkv(x)
        # Split the fused projection into per-head q, k, v laid out as B H L D.
        q, k, v = rearrange(fused, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe)
        return self.proj(attended)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@dataclass
class ModulationOut:
    """One (shift, scale, gate) triple produced by ``Modulation``."""

    # Added after the scaled activations.
    shift: Tensor
    # Applied by the transformer blocks as ``(1 + scale)``.
    scale: Tensor
    # Multiplies a branch output before it is added to the residual stream.
    gate: Tensor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Modulation(nn.Module):
    """Project a conditioning vector into one or two modulation triples.

    Args:
        dim: Width of the conditioning vector and of each output chunk.
        double: Produce two ``ModulationOut`` triples (six chunks) instead
            of one (three chunks).
    """

    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        """Return ``(first, second)``; ``second`` is ``None`` unless ``double``."""
        projected = self.lin(nn.functional.silu(vec))
        # Insert a length-1 axis at position 1, then split the last axis.
        chunks = projected[:, None, :].chunk(self.multiplier, dim=-1)

        first = ModulationOut(*chunks[:3])
        second = ModulationOut(*chunks[3:]) if self.is_double else None
        return first, second
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class DoubleStreamBlock(nn.Module):
    """Transformer block with separate image and text weights but joint attention.

    Each stream has its own modulation, norms, QKV/output projections and
    MLP. The two streams only interact inside a single attention call over
    the concatenated (text, image) sequence.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float,
        qkv_bias: bool = False,
    ):
        """Create both streams.

        Args:
            hidden_size: Model width of both streams.
            num_heads: Number of attention heads.
            mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
            qkv_bias: Whether the QKV projections have a bias.
        """
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)

        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Thin wrapper that delegates to ``_forward``."""
        return self._forward(img, txt, vec, pe)

    def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
        """Update both streams and return ``(img, txt)``.

        Args:
            img: Image-stream tokens.
            txt: Text-stream tokens.
            vec: Conditioning vector fed to both modulation layers.
            pe: Positional table passed to ``attention``.
        """
        # Each stream gets two modulation triples: one for the attention
        # branch and one for the MLP branch.
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        # Text tokens come first along the sequence axis (dim 2 of B H L D).
        q = torch.cat((txt_q, img_q), dim=2)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        attn = attention(q, k, v, pe=pe)
        # Split the joint result back into its text and image parts.
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
        return img, txt
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class SingleStreamBlock(nn.Module):
    """
    Single-stream DiT block whose attention and MLP branches share one input
    projection and one output projection (parallel layers, see
    https://arxiv.org/abs/2302.05442), driven by a single modulation triple.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        per_head = hidden_size // num_heads

        self.hidden_dim = hidden_size
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        # Stored for reference only; ``_forward`` does not read it.
        self.scale = qk_scale or per_head**-0.5
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)

        # Fused input projection: q, k, v and the MLP hidden activations.
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # Fused output projection: attention output and MLP output together.
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(per_head)
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Thin wrapper that delegates to ``_forward``."""
        return self._forward(x, vec, pe)

    def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        """Return ``x`` plus the gated output of the fused attention/MLP branch."""
        mod, _ = self.modulation(vec)
        modulated = (1 + mod.scale) * self.pre_norm(x) + mod.shift

        fused = self.linear1(modulated)
        qkv, mlp_hidden = torch.split(fused, [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        attended = attention(q, k, v, pe=pe)

        # Activate the MLP half, rejoin it with the attention half, project.
        joined = torch.cat((attended, self.mlp_act(mlp_hidden)), 2)
        branch = self.linear2(joined)
        return x + mod.gate * branch
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class LastLayer(nn.Module):
    """Final adaptive-LayerNorm modulation followed by a linear output head.

    Args:
        hidden_size: Width of the incoming tokens and conditioning vector.
        patch_size: Output width is ``patch_size * patch_size * out_channels``.
        out_channels: See ``patch_size``.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        """Modulate the normalised tokens with ``vec`` and project them."""
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        # Add an axis at position 1 so shift/scale broadcast over tokens.
        shift = shift.unsqueeze(1)
        scale = scale.unsqueeze(1)
        modulated = (1 + scale) * self.norm_final(x) + shift
        return self.linear(modulated)
|
flowdis/loaders.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
from safetensors.torch import load_file
|
| 4 |
+
|
| 5 |
+
from flowdis.autoencoder import AutoEncoder
|
| 6 |
+
from flowdis.conditioner import HFEmbedder
|
| 7 |
+
from flowdis.configs import configs
|
| 8 |
+
from flowdis.model import Flux, FluxParams
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_transformer(
    model_name: str,
    model_path: str,
    device: str | torch.device = "cuda",
    config: FluxParams | None = None,
    state_dict: dict | None = None,
) -> Flux:
    """Build a ``Flux`` model in bfloat16, load weights into it and return it in eval mode.

    Args:
        model_name: Key into ``configs``; only consulted when ``config`` is
            not given.
        model_path: Checkpoint path; only read when ``state_dict`` is not
            given. Files ending in ``.safetensors`` go through ``load_file``,
            anything else through ``torch.load``.
        device: Device the finished model is moved to.
        config: Explicit model parameters, overriding the ``configs`` lookup.
        state_dict: Pre-loaded weights, overriding ``model_path``.

    Returns:
        The model on ``device`` in bfloat16, in eval mode.
    """
    params = config if config is not None else configs[model_name]

    # Construct on the meta device so no parameter memory is initialised,
    # then materialise empty storage on the CPU for the weights to replace.
    with torch.device("meta"):
        model = Flux(params).to(dtype=torch.bfloat16)
    model.to_empty(device="cpu")

    if state_dict is None:
        if str(model_path).endswith(".safetensors"):
            state_dict = load_file(model_path, device="cpu")
        else:
            # NOTE(security): torch.load is pickle-based and can execute code
            # embedded in the file. Only point this at trusted checkpoints,
            # or pass weights_only=True where the installed torch supports it.
            state_dict = torch.load(model_path, map_location="cpu")

    # NOTE(review): strict=False means keys absent from the checkpoint are
    # not reported, and after to_empty() those parameters hold uninitialised
    # memory. Confirm the checkpoints are expected to be complete.
    model.load_state_dict(state_dict, assign=True, strict=False)
    model = model.to(device=device, dtype=torch.bfloat16)
    return model.eval()
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_autoencoder(
    model_path: str,
    device: str | torch.device = "cuda",
) -> AutoEncoder:
    """Build the ``"autoencoder"`` config, load safetensors weights, return it in eval mode.

    Args:
        model_path: Path to a safetensors checkpoint.
        device: Device the finished model is moved to.

    Returns:
        The autoencoder on ``device`` in bfloat16, in eval mode.
    """
    # Meta-device construction skips weight initialisation; to_empty then
    # allocates CPU storage for the checkpoint tensors to replace.
    with torch.device("meta"):
        autoencoder = AutoEncoder(configs["autoencoder"])
    autoencoder.to_empty(device="cpu")

    weights = load_file(model_path, device="cpu")
    autoencoder.load_state_dict(weights, assign=True, strict=False)

    autoencoder = autoencoder.to(device=device, dtype=torch.bfloat16)
    return autoencoder.eval()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_t5(
    model_path: str,
    max_length: int = 512,
    device: str | torch.device = "cuda",
) -> HFEmbedder:
    """Build a T5 ``HFEmbedder``, load safetensors weights and return it in eval mode.

    Args:
        model_path: Path to the safetensors weight file (``str`` or path-like).
            Its parent directory is handed to ``HFEmbedder`` as the location
            of the tokenizer and model config.
        max_length: Token length used by the embedder.
        device: Device the finished embedder is moved to.

    Returns:
        The embedder on ``device`` in bfloat16, in eval mode.
    """
    from pathlib import Path

    # Accept plain strings as the annotation promises: ``str`` has no
    # ``.parent`` attribute, so normalise to ``Path`` first.
    model_dir = Path(model_path).parent

    with torch.device("meta"):
        t5 = HFEmbedder(
            model_dir,
            max_length=max_length,
            is_clip=False,
            dtype=torch.bfloat16,
        )
    t5.to_empty(device="cpu")
    state_dict = load_file(model_path, device="cpu")
    t5.load_state_dict(state_dict, assign=True, strict=False)
    # Freshly constructed modules start in training mode; switch to eval so
    # this loader matches the transformer and autoencoder loaders.
    return t5.to(device=device, dtype=torch.bfloat16).eval()
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def load_clip(
    model_path: str,
    device: str | torch.device = "cuda",
) -> HFEmbedder:
    """Build a CLIP ``HFEmbedder``, load safetensors weights and return it in eval mode.

    Args:
        model_path: Path to the safetensors weight file (``str`` or path-like).
            Its parent directory is handed to ``HFEmbedder`` as the location
            of the tokenizer and model config.
        device: Device the finished embedder is moved to.

    Returns:
        The embedder on ``device`` in bfloat16, in eval mode.
    """
    from pathlib import Path

    # Accept plain strings as the annotation promises: ``str`` has no
    # ``.parent`` attribute, so normalise to ``Path`` first.
    model_dir = Path(model_path).parent

    # NOTE(review): unlike load_t5 this builds the module on a real device
    # rather than the meta device; presumably deliberate, confirm before
    # changing.
    clip = HFEmbedder(
        model_dir,
        max_length=77,
        is_clip=True,
        dtype=torch.bfloat16,
    )
    state_dict = load_file(model_path, device="cpu")
    clip.load_state_dict(state_dict, assign=True, strict=False)
    # Freshly constructed modules start in training mode; switch to eval so
    # this loader matches the transformer and autoencoder loaders.
    return clip.to(device=device, dtype=torch.bfloat16).eval()
|
| 75 |
+
|
flowdis/math.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Rotate ``q``/``k`` with ``pe``, run scaled dot-product attention, merge heads.

    The result of the attention call is rearranged from ``B H L D`` to
    ``B L (H D)`` before being returned.
    """
    q_rot, k_rot = apply_rope(q, k, pe)
    attended = torch.nn.functional.scaled_dot_product_attention(q_rot, k_rot, v)
    return rearrange(attended, "B H L D -> B L (H D)")
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    """Return 2x2 rotation matrices for each position and frequency.

    Args:
        pos: Position values; the rearrange pattern below expects ``(b, n)``.
        dim: Embedding width for this axis; must be even.
        theta: Base of the geometric frequency progression.

    Returns:
        A float32 tensor of shape ``(b, n, dim // 2, 2, 2)`` whose trailing
        2x2 block is ``[[cos, -sin], [sin, cos]]``.
    """
    assert dim % 2 == 0
    exponents = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
    omega = 1.0 / (theta**exponents)

    # Outer product of positions and frequencies gives the rotation angles.
    angles = torch.einsum("...n,d->...nd", pos, omega)
    cos, sin = torch.cos(angles), torch.sin(angles)

    flat = torch.stack([cos, -sin, sin, cos], dim=-1)
    matrices = rearrange(flat, "b n d (i j) -> b n d i j", i=2, j=2)
    return matrices.float()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    """Apply the 2x2 rotation table ``freqs_cis`` to queries and keys.

    The last dimension of each input is viewed as consecutive pairs, every
    pair is multiplied by its rotation matrix in float32, and the result is
    reshaped and cast back to the input's shape and type.
    """

    def _rotate(x: Tensor) -> Tensor:
        # (..., D) -> (..., D/2, 1, 2): one trailing pair per rotation.
        pairs = x.float().reshape(*x.shape[:-1], -1, 1, 2)
        mixed = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
        return mixed.reshape(*x.shape).type_as(x)

    return _rotate(xq), _rotate(xk)
|
flowdis/model.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from torch import Tensor, nn
|
| 5 |
+
|
| 6 |
+
from flowdis.layers import (
|
| 7 |
+
DoubleStreamBlock,
|
| 8 |
+
EmbedND,
|
| 9 |
+
LastLayer,
|
| 10 |
+
MLPEmbedder,
|
| 11 |
+
SingleStreamBlock,
|
| 12 |
+
timestep_embedding,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class FluxParams:
    """Hyperparameters consumed by ``Flux.__init__``."""

    # Input width of the image-token projection (``img_in``).
    in_channels: int
    # Passed to ``LastLayer`` as its output channel count.
    out_channels: int
    # Input width of the ``vector_in`` embedder.
    vec_in_dim: int
    # Input width of the text-token projection (``txt_in``).
    context_in_dim: int
    # Model width; must be divisible by ``num_heads``.
    hidden_size: int
    # MLP hidden width as a multiple of ``hidden_size``.
    mlp_ratio: float
    num_heads: int
    # Number of ``DoubleStreamBlock``s.
    depth: int
    # Number of ``SingleStreamBlock``s.
    depth_single_blocks: int
    # Per-axis positional widths; must sum to ``hidden_size // num_heads``.
    axes_dim: list[int]
    # Base passed to the positional embedder.
    theta: int
    # Whether the double-stream QKV projections carry a bias.
    qkv_bias: bool
    # When true, a guidance embedder is built and ``forward`` requires ``guidance``.
    guidance_embed: bool
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Flux(nn.Module):
    """
    Transformer model for flow matching on sequences.

    Image and text tokens first pass through a stack of double-stream
    blocks, are then concatenated (text first) and passed through a stack of
    single-stream blocks, after which the text tokens are dropped and the
    image tokens go through the final layer.
    """

    def __init__(self, params: FluxParams):
        """Create all sub-modules from ``params``.

        Raises:
            ValueError: If ``hidden_size`` is not divisible by ``num_heads``,
                or if ``axes_dim`` does not sum to the per-head width.
        """
        super().__init__()

        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.out_channels
        if params.hidden_size % params.num_heads != 0:
            raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
        # Positional embeddings are applied per head, so their total width
        # must equal the per-head width.
        pe_dim = params.hidden_size // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
        self.hidden_size = params.hidden_size
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
        self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
        # 256 matches the timestep-embedding width requested in ``forward``.
        self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
        self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
        self.guidance_in = (
            MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
        )
        self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)

        self.double_blocks = nn.ModuleList(
            [
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_bias=params.qkv_bias,
                ) for _ in range(params.depth)
            ]
        )

        self.single_blocks = nn.ModuleList(
            [
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                ) for _ in range(params.depth_single_blocks)
            ]
        )

        # Patch size is fixed at 1 here.
        self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)

    def forward(
        self,
        img: Tensor,
        img_ids: Tensor,
        txt: Tensor,
        txt_ids: Tensor,
        timesteps: Tensor,
        y: Tensor,
        guidance: Tensor | None = None,
    ) -> Tensor:
        """Run the model on one batch of image and text token sequences.

        Args:
            img: Image tokens; must be 3-dimensional.
            img_ids: Position ids for the image tokens.
            txt: Text tokens; must be 3-dimensional.
            txt_ids: Position ids for the text tokens.
            timesteps: Per-sample timesteps, embedded sinusoidally.
            y: Conditioning vector fed through ``vector_in``.
            guidance: Guidance strength; required when the model was built
                with ``guidance_embed=True`` and ignored otherwise.

        Returns:
            The final-layer output for the image tokens only.

        Raises:
            ValueError: If ``img`` or ``txt`` is not 3-dimensional, or if
                guidance embedding is enabled but ``guidance`` is ``None``.
        """
        if img.ndim != 3 or txt.ndim != 3:
            raise ValueError("Input img and txt tensors must have 3 dimensions.")

        # running on sequences img
        img = self.img_in(img)
        # ``vec`` accumulates the timestep, optional guidance and ``y``
        # embeddings; every block is conditioned on it.
        vec = self.time_in(timestep_embedding(timesteps, 256))
        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError("Didn't get guidance strength for guidance distilled model.")
            vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
        vec = vec + self.vector_in(y)
        txt = self.txt_in(txt)

        # Text ids precede image ids, matching the token order the blocks use.
        ids = torch.cat((txt_ids, img_ids), dim=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks:
            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)

        # From here on both modalities share one sequence (text first).
        img = torch.cat((txt, img), 1)
        for block in self.single_blocks:
            img = block(img, vec=vec, pe=pe)
        # Drop the leading text tokens again.
        img = img[:, txt.shape[1] :, ...]

        img = self.final_layer(img, vec)  # (N, T, patch_size ** 2 * out_channels)
        return img
|
flowdis/sampling.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchvision.transforms.functional as tvF
|
| 5 |
+
from einops import rearrange, repeat
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from scipy import stats
|
| 8 |
+
from torch import Tensor
|
| 9 |
+
|
| 10 |
+
from flowdis.model import Flux
|
| 11 |
+
from flowdis.util import Models
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def unpack(x: Tensor, height: int, width: int) -> Tensor:
    """Fold a packed token sequence back into a 2D feature map.

    Args:
        x: Packed tensor laid out as ``b (h w) (c ph pw)`` with 2x2 patches.
        height: Height used to size the token grid; the grid has
            ``ceil(height / 16)`` rows.
        width: Width used to size the token grid; the grid has
            ``ceil(width / 16)`` columns.

    Returns:
        Tensor laid out as ``b c (h ph) (w pw)``.
    """
    grid_rows = math.ceil(height / 16)
    grid_cols = math.ceil(width / 16)
    pattern = "b (h w) (c ph pw) -> b c (h ph) (w pw)"
    return rearrange(x, pattern, h=grid_rows, w=grid_cols, ph=2, pw=2)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def beta_scheduler(num_timesteps: int, alpha: float = 2.5, beta: float = 1.0) -> list[float]:
    """Build a descending timestep schedule from Beta-distribution quantiles.

    ``num_timesteps + 1`` evenly spaced probabilities running from 1 down to 0
    are mapped through the inverse CDF of ``Beta(alpha, beta)``.

    Args:
        num_timesteps: Number of integration intervals.
        alpha: First shape parameter of the Beta distribution.
        beta: Second shape parameter of the Beta distribution.

    Returns:
        The schedule as a list of floats; a trailing ``0.0`` is appended when
        the last quantile is still positive.
    """
    probabilities = torch.linspace(1, 0, num_timesteps + 1)
    schedule = stats.beta.ppf(probabilities, alpha, beta).tolist()
    # Make sure the schedule ends at t = 0.
    return [*schedule, 0.0] if schedule[-1] > 0.0 else schedule
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def prepare(
    img: Tensor,
    prompt: str | list[str],
    models: Models,
    device: str = "cuda"
) -> dict[str, Tensor]:
    """Encode the image and prompt into the inputs expected by the transformer.

    Args:
        img: Batch of images as a 4-dimensional tensor.
        prompt: A single prompt or a list of prompts. With a single image and
            a list of prompts, the batch size becomes ``len(prompt)``.
        models: Bundle providing the ``ae``, ``t5`` and ``clip`` models.
        device: Device the image is moved to before encoding.

    Returns:
        Dict with keys ``img`` (packed latent tokens), ``img_ids``, ``txt``,
        ``txt_ids`` and ``vec``, all placed on the device of the latents.
    """
    bs, _, _, _ = img.shape
    if isinstance(prompt, str):
        prompt = [prompt]
    elif bs == 1:
        # One image shared by several prompts: the prompts set the batch size.
        bs = len(prompt)

    def _match_batch(t: Tensor) -> Tensor:
        # Broadcast a singleton batch up to the effective batch size.
        if t.shape[0] == 1 and bs > 1:
            return repeat(t, "1 ... -> bs ...", bs=bs)
        return t

    with torch.no_grad():
        latents = models.ae.encode(img.to(device=device, dtype=torch.bfloat16))
    rows, cols = latents.shape[2] // 2, latents.shape[3] // 2

    # Per-token position ids: channel 1 holds the row, channel 2 the column.
    grid_ids = torch.zeros(rows, cols, 3)
    grid_ids[..., 1] += torch.arange(rows)[:, None]
    grid_ids[..., 2] += torch.arange(cols)[None, :]
    img_ids = repeat(grid_ids, "h w c -> b (h w) c", b=bs)

    # Pack 2x2 latent patches into a token sequence.
    tokens = rearrange(latents, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
    tokens = _match_batch(tokens)

    txt = _match_batch(models.t5(prompt))
    txt_ids = torch.zeros(bs, txt.shape[1], 3)

    vec = _match_batch(models.clip(prompt))

    target = tokens.device
    return {
        "img": tokens,
        "img_ids": img_ids.to(target),
        "txt": txt.to(target),
        "txt_ids": txt_ids.to(target),
        "vec": vec.to(target),
    }
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def solve_flowdis_ode(
    model: Flux,
    img: Tensor,
    img_ids: Tensor,
    txt: Tensor,
    txt_ids: Tensor,
    vec: Tensor,
    num_inference_steps: int,
):
    """Integrate the FlowDIS ODE with fixed Euler steps.

    The state starts at ``img`` and follows the timesteps from
    ``beta_scheduler``. At every step the model receives the current state
    concatenated with ``img`` along the last dimension.

    Args:
        model: Transformer predicting the update direction.
        img: Packed image latents; both the initial state and the conditioning.
        img_ids: Position ids for the image tokens.
        txt: Text token embeddings.
        txt_ids: Position ids for the text tokens.
        vec: Vector conditioning passed to the model as ``y``.
        num_inference_steps: Number of steps requested from the scheduler.

    Returns:
        The state after the last step.
    """
    schedule = beta_scheduler(num_inference_steps)
    state = img
    for t_now, t_next in zip(schedule, schedule[1:]):
        # One timestep value per batch element.
        t_batch = torch.full((state.shape[0],), t_now, dtype=state.dtype, device=state.device)
        model_input = torch.cat((state, img), dim=-1)
        velocity = model(
            img=model_input,
            img_ids=img_ids,
            txt=txt,
            txt_ids=txt_ids,
            y=vec,
            timesteps=t_batch,
        )
        step = t_next - t_now
        state = state + step * velocity
    return state
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
@torch.no_grad()
def flowdis_predict(
    image: Image.Image,
    prompt: str | list[str],
    models: Models,
    resolution: int = 1024,
    num_inference_steps: int = 2,
    device: str = "cuda",
) -> Image.Image:
    """Predict a segmentation mask for ``image`` guided by ``prompt``.

    Args:
        image: Input PIL image, in any mode; it is converted to RGB first.
        prompt: Text prompt (or list of prompts) describing the target.
        models: Bundle providing the autoencoder, text encoders and transformer.
        resolution: Side length of the square the image is resized to.
        num_inference_steps: Number of ODE solver steps.
        device: Device to run on; may include an index, e.g. ``"cuda:0"``.

    Returns:
        Single-channel ("L") PIL mask resized to the input image's size.
    """
    image_orig = image.convert("RGB")
    # Resize the RGB copy, not the raw input: RGBA / L / P images would
    # otherwise give a tensor with the wrong number of channels.
    image = image_orig.resize((resolution, resolution))

    # to_tensor yields values in [0, 1]; rescale them to [-1, 1].
    image_t = tvF.to_tensor(image).unsqueeze(0).to(device=device)
    image_t = (image_t - 0.5) / 0.5

    inp = prepare(image_t, prompt, models, device)

    pred_mask_latent_t = solve_flowdis_ode(
        models.transformer,
        **inp,
        num_inference_steps=num_inference_steps,
    )

    pred_mask_latent_t = unpack(pred_mask_latent_t.float(), resolution, resolution)
    # autocast expects the bare device type ("cuda"), not an indexed
    # device string such as "cuda:0".
    device_type = torch.device(device).type
    with torch.autocast(device_type=device_type, dtype=torch.bfloat16):
        pred_mask_t = models.ae.decode(pred_mask_latent_t).clamp(-1, 1)

    # Map [-1, 1] to [0, 255] and average the channels into one gray mask.
    pred_mask_t = rearrange(pred_mask_t[0], "c h w -> h w c")
    pred_mask_np = (127.5 * (pred_mask_t + 1.0)).mean(dim=-1).cpu().byte().numpy()
    pred_mask = Image.fromarray(pred_mask_np).convert("L")
    pred_mask = pred_mask.resize(image_orig.size)

    return pred_mask
|
flowdis/util.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from copy import deepcopy
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
from huggingface_hub import snapshot_download
|
| 9 |
+
from safetensors.torch import load_file
|
| 10 |
+
from flowdis.autoencoder import AutoEncoder
|
| 11 |
+
from flowdis.conditioner import HFEmbedder
|
| 12 |
+
from flowdis.configs import configs
|
| 13 |
+
from flowdis.loaders import load_autoencoder, load_clip, load_t5, load_transformer
|
| 14 |
+
from flowdis.model import Flux
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@dataclass
class Models:
    """Bundle of the networks used by the FlowDIS pipeline."""

    # CLIP text embedder.
    clip: HFEmbedder
    # T5 text embedder.
    t5: HFEmbedder
    # Autoencoder used to encode images and decode predicted latents.
    ae: AutoEncoder
    # FlowDIS transformer.
    transformer: Flux
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def load_models(
    root_model_dir: str | Path | None = None,
    device: str | torch.device = "cuda"
) -> Models:
    """
    Load the models for the FlowDIS pipeline.

    Args:
        root_model_dir: The root model directory, as a `Path` or a string.
            If None, the models are downloaded from the Hugging Face Hub.
        device: The device to load the models on.

    Returns:
        Models: The loaded models.
    """
    if root_model_dir is None:
        root_model_dir = download_from_hf_hub("PAIR/FlowDIS")
    # Normalise to Path so string arguments work with the `/` joins below.
    root_model_dir = Path(root_model_dir)

    logger.info("Loading T5.")
    t5 = load_t5(
        model_path=root_model_dir / "t5-v1_1-xxl" / "model.safetensors",
        device=device,
        max_length=512
    )

    logger.info("Loading CLIP.")
    clip = load_clip(
        model_path=root_model_dir / "clip-vit-large-patch14" / "model.safetensors",
        device=device
    )

    logger.info("Loading AE.")
    ae = load_autoencoder(
        model_path=root_model_dir / "ae.safetensors",
        device=device
    )

    logger.info("Loading Transformer.")
    model = load_transformer(
        model_name="flowdis",
        model_path=root_model_dir / "flowdis-transformer.safetensors",
        device=device,
    )

    logger.info("All models loaded.")

    return Models(
        clip=clip,
        t5=t5,
        ae=ae,
        transformer=model,
    )
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def download_from_hf_hub(
    repo_id: str,
    cache_dir: str | Path | None = None,
    revision: str | None = None,
) -> Path:
    """
    Fetch a snapshot of a FlowDIS model repository from the Hugging Face Hub.

    Args:
        repo_id: Hub repository id, e.g. "PAIR/FlowDIS".
        cache_dir: Optional cache directory; when None the huggingface_hub
            default is used (typically ~/.cache/huggingface/hub).
        revision: Optional git revision (branch, tag, or commit SHA).

    Returns:
        Path to the local snapshot directory. Its layout mirrors the repo on
        the Hub, so it can be handed to `load_models` as `root_model_dir`.
    """
    logger.info(f"Downloading {repo_id} from Hugging Face Hub.")
    download_kwargs = {
        "repo_id": repo_id,
        "cache_dir": cache_dir,
        "revision": revision,
    }
    snapshot_path = snapshot_download(**download_kwargs)
    logger.info(f"Snapshot available at {snapshot_path}.")
    return Path(snapshot_path)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def green_screen(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Composite an image over a pure green background using a soft mask.

    Args:
        img: RGB image of shape (H, W, 3); anything ``np.asarray`` accepts.
        mask: Mask with values in 0-255, where 255 keeps the image pixel and
            0 shows the green background. Shape (H, W), or with a trailing
            channel axis that broadcasts against ``img``.

    Returns:
        The blended image as a ``uint8`` array of shape (H, W, 3).
    """
    img_np = np.asarray(img).astype(np.float64)
    alpha = np.asarray(mask).astype(np.float64) / 255.0
    if alpha.ndim == 2:
        # Add a channel axis so the mask broadcasts over the RGB channels.
        alpha = alpha[:, :, np.newaxis]
    green = np.array([0.0, 255.0, 0.0])
    combined = img_np * alpha + (1.0 - alpha) * green
    # Round to nearest instead of truncating, so soft edges are not biased
    # darker by up to one level; clip guards against out-of-range masks.
    return np.clip(np.rint(combined), 0, 255).astype(np.uint8)
|
pyproject.toml
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "flowdis"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "FlowDIS: Language-Guided Dichotomous Image Segmentation with Flow Matching"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
authors = [
|
| 13 |
+
{ name = "Andranik Sargsyan" },
|
| 14 |
+
{ name = "Shant Navasardyan" },
|
| 15 |
+
]
|
| 16 |
+
keywords = ["segmentation", "flow-matching", "background removal", "deep-learning"]
|
| 17 |
+
classifiers = [
|
| 18 |
+
"Development Status :: 3 - Alpha",
|
| 19 |
+
"Intended Audience :: Science/Research",
|
| 20 |
+
"License :: OSI Approved :: MIT License",
|
| 21 |
+
"Programming Language :: Python :: 3",
|
| 22 |
+
"Programming Language :: Python :: 3.10",
|
| 23 |
+
"Programming Language :: Python :: 3.11",
|
| 24 |
+
"Programming Language :: Python :: 3.12",
|
| 25 |
+
"Topic :: Scientific/Engineering :: Image Recognition",
|
| 26 |
+
]
|
| 27 |
+
dependencies = [
|
| 28 |
+
"accelerate>=1.12.0,<2.0",
|
| 29 |
+
"einops>=0.8.2,<1.0",
|
| 30 |
+
"gradio==6.3.0",
|
| 31 |
+
"numpy>=1.24.0,<2.0",
|
| 32 |
+
"opencv-python>=4.11.0,<5.0",
|
| 33 |
+
"Pillow>=10.0.0,<11.0",
|
| 34 |
+
"safetensors>=0.7.0,<1.0",
|
| 35 |
+
"scipy>=1.17.1,<2.0",
|
| 36 |
+
"sentencepiece>=0.2.1,<1.0",
|
| 37 |
+
"tiktoken>=0.12.0,<1.0",
|
| 38 |
+
"torch>=2.8.0,<=2.10",
|
| 39 |
+
"torchvision>=0.25.0",
|
| 40 |
+
"transformers>=4.39.0,<5.0",
|
| 41 |
+
]
|
| 42 |
+
|
| 43 |
+
[project.optional-dependencies]
|
| 44 |
+
dev = [
|
| 45 |
+
"pytest>=7.0",
|
| 46 |
+
"ruff>=0.1.0",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
[tool.setuptools]
|
| 50 |
+
packages = ["flowdis"]
|
qwen.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
logger = logging.getLogger(__name__)
|
| 8 |
+
|
| 9 |
+
# The Qwen3VL model and its processor are loaded at import time, and only
# when a GPU is available; otherwise both names stay None.
model = None
processor = None
if not torch.cuda.is_available():
    logger.info("Qwen3VL was not loaded because no GPU is available.")
else:
    logger.info("Loading Qwen3VL model.")
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen3-VL-4B-Instruct",
        dtype=torch.bfloat16,
        device_map="auto",
    )
    processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
    logger.info("Qwen3VL model loaded.")
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def expand_prompt(image: Image.Image, user_prompt: str) -> str:
    """
    Expand the user prompt using the Qwen3VL model.

    Args:
        image: The image to use for the prompt expansion.
        user_prompt: The user prompt to expand.

    Returns:
        The expanded prompt.

    Raises:
        RuntimeError: If the Qwen3VL model or processor was not loaded
            (they are only loaded when a GPU is available at import time).
    """
    # Fail with a clear message instead of an AttributeError on None.
    if model is None or processor is None:
        raise RuntimeError(
            "Qwen3VL model is not loaded (no GPU was available at import time); "
            "cannot expand the prompt."
        )

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"Describe the {user_prompt} in this image with a short prompt. Don't use surrounding objects in the description. Also don't describe the background, like what it is sitting on or what it is on top of, etc..."}
            ]
        }
    ]

    text = processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = processor(
        text=[text],
        images=[image],
        padding=True,
        return_tensors="pt"
    )

    inputs = inputs.to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512
        )
    # Keep only the newly generated tokens, dropping the echoed prompt.
    generated_ids_trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True
    )[0]

    return output_text
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate>=1.12.0,<2.0
|
| 2 |
+
einops>=0.8.2,<1.0
|
| 3 |
+
gradio==6.3.0
|
| 4 |
+
numpy>=1.24.0,<2.0
|
| 5 |
+
opencv-python>=4.11.0,<5.0
|
| 6 |
+
Pillow>=10.0.0,<11.0
|
| 7 |
+
safetensors>=0.7.0,<1.0
|
| 8 |
+
scipy>=1.17.1,<2.0
|
| 9 |
+
sentencepiece>=0.2.1,<1.0
|
| 10 |
+
tiktoken>=0.12.0,<1.0
|
| 11 |
+
torch>=2.8.0,<=2.10
|
| 12 |
+
torchvision>=0.25.0
|
| 13 |
+
transformers>=4.39.0,<5.0
|