Commit a8a9bce by AndranikSargsyan · 0 parent(s)

Add FlowDIS inference and demo

.gitattributes ADDED
@@ -0,0 +1,5 @@
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
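Note: the five rules above route image files and `.safetensors` weights through Git LFS. As a rough illustration of which paths they catch, here is a minimal Python sketch. It is an approximation, not Git's own matcher: it only models slash-free patterns matched against a file's basename, it matches case-sensitively (an assumption; Git's behavior can depend on configuration), and the helper name `is_lfs_tracked` and the path `weights/model.safetensors` are hypothetical.

```python
from fnmatch import fnmatchcase
from pathlib import PurePosixPath

# Patterns copied from the .gitattributes rules in this commit.
LFS_PATTERNS = ["*.jpg", "*.jpeg", "*.png", "*.webp", "*.safetensors"]


def is_lfs_tracked(path: str) -> bool:
    """Return True if the basename of `path` matches one of the LFS patterns.

    Approximation only: slash-free patterns are compared against the
    basename, and the comparison is case-sensitive.
    """
    name = PurePosixPath(path).name
    return any(fnmatchcase(name, pattern) for pattern in LFS_PATTERNS)


if __name__ == "__main__":
    # "weights/model.safetensors" is a made-up path used for illustration.
    for candidate in ["assets/preview.png", "app.py", "weights/model.safetensors"]:
        print(candidate, is_lfs_tracked(candidate))
```

For an authoritative answer on a real checkout, `git check-attr filter -- <path>` reports the attribute Git actually applies.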
.gitignore ADDED
@@ -0,0 +1,43 @@
+ # Gradio temporary files
+ gradio_temp/
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *$py.class
+ *.so
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+
+ # Virtual Environment
+ venv/
+ env/
+ ENV/
+
+ # IDE
+ .vscode/
+ .idea/
+ *.swp
+ *.swo
+ *~
+
+ # OS
+ .DS_Store
+ Thumbs.db
+
+ outputs/
+ .gradio/
APACHE-2.0-LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
LICENSE ADDED
@@ -0,0 +1,311 @@
+ PicsArt Inc. FlowDIS Model License v1.0
+ Non-Commercial Use License
+
+ PicsArt Inc. ("PicsArt," "we," "our," or "Company") makes the weights,
+ parameters, and inference code for FlowDIS (as defined below)
+ available for your non-commercial and non-production use under the
+ terms of this License.
+
+ FlowDIS is a derivative of FLUX.1 [schnell] by Black Forest Labs,
+ Inc., which is licensed under the Apache License, Version 2.0. The
+ original FLUX.1 [schnell] model and its associated copyright, patent,
+ trademark, and attribution notices are included with this
+ distribution. A copy of the Apache License, Version 2.0 is provided
+ in the accompanying APACHE-2.0-LICENSE file. This model contains
+ modifications made by PicsArt Inc. to the original FLUX.1 [schnell]
+ model. By downloading, accessing, using, Distributing, or creating a
+ Derivative of FlowDIS, you agree to the terms of this License. If you
+ do not agree, you have no rights to access, use, Distribute, or
+ create a Derivative of FlowDIS and must immediately cease using it.
+ If you accept this License on behalf of your employer or another
+ entity, you represent and warrant that you have full legal authority
+ to bind that employer or entity.
+
+ 1. Definitions
+
+ (a) "Derivative" means any (i) modified version of FlowDIS (including
+ any fine-tuned or distilled version), (ii) work based on FlowDIS,
+ or (iii) any other derivative work thereof. For clarity, Outputs
+ are not Derivatives.
+
+ (b) "Distribution" or "Distribute" means providing or making
+ available, by any means, a copy of FlowDIS and/or Derivatives.
+
+ (c) "Non-Commercial Purpose" means any of the following uses, but
+ only so far as you do not receive any direct or indirect payment
+ arising from the use of FlowDIS or Derivatives:
+
+ (i) personal use for research, experimentation, and testing for the
+ benefit of public knowledge, personal study, private
+ entertainment, hobby projects, or otherwise not directly or
+ indirectly connected to any commercial activities, business
+ operations, or employment responsibilities;
+
+ (ii) use by commercial or for-profit entities for testing,
+ evaluation, or non-commercial research and development in a
+ non-production environment; and
+
+ (iii) use by any charitable organization for charitable purposes,
+ or for testing or evaluation. For clarity, use (a) for
+ revenue-generating activity, (b) in direct interactions with
+ or that has impact on end users, or (c) to train, fine-tune,
+ or distill other models for commercial use, in each case, is
+ not a Non-Commercial Purpose.
+
+ (d) "Outputs" means any content generated by the operation of FlowDIS
+ or Derivatives from an input or prompt provided by users. Outputs
+ do not include any components of FlowDIS such as fine-tuned
+ versions, weights, or parameters.
+
+ (e) "you" or "your" means the individual or entity entering into this
+ License with Company.
+
+ 2. License Grant
+
+ (a) License. Subject to your compliance with this License, Company
+ grants you a non-exclusive, worldwide, non-transferable,
+ non-sublicensable, revocable, royalty-free, and limited license
+ to access, use, create Derivatives of, and Distribute FlowDIS and
+ Derivatives solely for Non-Commercial Purposes. This license is
+ personal to you, and you may not assign or sublicense this
+ License or any rights or obligations under it without Company's
+ prior written consent; any such assignment or sublicense will be
+ void and will automatically and immediately terminate this
+ License. Any restrictions set forth herein regarding FlowDIS also
+ apply to any Derivative you create or that is created on your
+ behalf.
+
+ (b) Non-Commercial Use Only. You may only access, use, Distribute, or
+ create Derivatives of FlowDIS or Derivatives for Non-Commercial
+ Purposes. If you wish to use FlowDIS or a Derivative for any
+ purpose not expressly authorized under this License, you must
+ request a license from Company, which Company may grant in its
+ sole discretion and which may be subject to a fee, royalty, or
+ other revenue share.
+
+ (c) Reserved Rights. The grant of rights expressly set forth in this
+ License constitutes the complete grant of rights to use FlowDIS,
+ and no other licenses are granted, whether by waiver, estoppel,
+ implication, equity, or otherwise. Company and its licensors
+ reserve all rights not expressly granted by this License.
+
+ (d) Outputs. Company claims no ownership rights in Outputs. You are
+ solely responsible for the Outputs you generate and their
+ subsequent uses in accordance with this License. You may use
+ Outputs for any purpose (including commercial purposes), except
+ as expressly prohibited herein.
+
+ 3. Distribution
+
+ Subject to this License, you may Distribute copies of FlowDIS and/or
+ Derivatives made by you under the following conditions:
+ (a) You must make available a copy of this License to third-party
+ recipients of FlowDIS and/or Derivatives you Distribute, and
+ specify that any rights to use FlowDIS and/or Derivatives shall
+ be granted directly by Company to said third-party recipients
+ pursuant to this License.
+
+ (b) You must prominently display the following notice alongside the
+ Distribution (such as via a "NOTICE" text file distributed as
+ part of FlowDIS or the Derivative) (the "Attribution Notice"):
+ This model is licensed by PicsArt Inc. under the PicsArt Inc.
+ FlowDIS Model License v1.0. Copyright 2026 PicsArt Inc. This
+ model is a derivative of FLUX.1 [schnell] by Black Forest Labs,
+ Inc., licensed under the Apache License, Version 2.0. IN NO EVENT
+ SHALL PICSART INC. BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
+ ARISING FROM, OUT OF OR IN CONNECTION WITH USE OF THIS MODEL.
+
+ (c) In the case of Distribution of Derivatives made by you: (i) you
+ must include in the Attribution Notice a statement that you have
+ modified FlowDIS; (ii) any terms and conditions you impose on
+ third-party recipients relating to your Derivatives shall neither
+ limit such recipients' use of FlowDIS or any Derivatives made by
+ Company in accordance with this License, nor conflict with any of
+ its terms and conditions, and must include disclaimer of
+ warranties and limitation of liability provisions at least as
+ protective of Company as those set forth herein; and (iii) you
+ must not misrepresent or imply that Derivatives made by or for
+ you are an official product of PicsArt Inc. or have been
+ endorsed, approved, or validated by PicsArt Inc., unless
+ authorized by Company in writing.
+
+ (d) Apache 2.0 Compliance. All Distributions must include: (i) a copy
+ of the Apache License, Version 2.0; (ii) all copyright, patent,
+ trademark, and attribution notices from the original FLUX.1
+ [schnell] model; and (iii) prominent notices stating that the
+ files have been modified from the original FLUX.1 [schnell]
+ model.
+
+ 4. Restrictions
+
+ You will not, and will not permit, assist, or cause any third party
+ to:
+ (a) use, modify, copy, reproduce, create Derivatives of, or
+ Distribute FlowDIS (or any Derivative or data produced by
+ FlowDIS), in whole or in part, for:
+
+ (i) any commercial or production purpose;
+
+ (ii) any military purpose, including research, development,
+ design, manufacture, production, or use of weapons, weapons
+ systems, munitions, or any military or defense applications;
+
+ (iii) purposes of surveillance, including any research or
+ development relating to surveillance;
+
+ (iv) biometric processing;
+
+ (v) any manner that infringes, misappropriates, or otherwise violates
+ any third party's legal rights, including rights of publicity or
+ digital replica rights;
+
+ (vi) any unlawful, fraudulent, defamatory, or abusive activity;
+
+ (vii) generating unlawful content, including child sexual abuse
+ material or non-consensual intimate images; or
+
+ (viii) any manner that violates any applicable law, privacy or
+ security laws, rules, regulations, directives, or
+ governmental requirements (including the GDPR, the California
+ Consumer Privacy Act, laws governing the processing of
+ biometric information, and the EU AI Act, as well as all
+ amendments and successor laws to any of the foregoing);
+
+ (b) alter or remove copyright and other proprietary notices which
+ appear on or in any portion of FlowDIS;
+
+ (c) utilize any equipment, device, software, or other means to
+ circumvent or remove any security or protection used by Company
+ in connection with FlowDIS, or to circumvent or remove any usage
+ restrictions, or to enable functionality disabled by Company;
+
+ (d) offer or impose any terms on FlowDIS that alter, restrict, or are
+ inconsistent with the terms of this License;
+
+ (e) violate any applicable U.S. and non-U.S. export control and trade
+ sanctions laws ("Export Laws") in connection with your use or
+ Distribution of FlowDIS; or
+
+ (f) directly or indirectly Distribute, export, or otherwise transfer
+ FlowDIS (i) to any individual, entity, or country prohibited by
+ Export Laws; (ii) to anyone on U.S. or non-U.S. government
+ restricted parties lists; (iii) for any purpose prohibited by
+ Export Laws, including nuclear, chemical or biological weapons,
+ or missile technology applications; (iv) if you or they are
+ located in a comprehensively sanctioned jurisdiction, currently
+ listed on any U.S. or non-U.S. restricted parties list, or for
+ any purpose prohibited by Export Laws; or (v) while disguising
+ your location through IP proxying or other methods.
+
+ 5. Disclaimers
+
+ THE MODEL IS PROVIDED "AS IS" AND "WITH ALL FAULTS" WITH NO WARRANTY
+ OF ANY KIND, EXPRESS OR IMPLIED. COMPANY EXPRESSLY DISCLAIMS ALL
+ REPRESENTATIONS AND WARRANTIES, EXPRESS OR IMPLIED, WHETHER BY
+ STATUTE, CUSTOM, USAGE OR OTHERWISE, INCLUDING BUT NOT LIMITED TO THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, TITLE, SATISFACTORY QUALITY, OR NON-INFRINGEMENT. COMPANY
+ MAKES NO WARRANTIES OR REPRESENTATIONS THAT THE MODEL WILL BE ERROR
+ FREE OR FREE OF VIRUSES OR OTHER HARMFUL COMPONENTS, OR PRODUCE ANY
+ PARTICULAR RESULTS.
+
+ 6. Limitation of Liability
+
+ TO THE FULLEST EXTENT PERMITTED BY LAW, IN NO EVENT WILL COMPANY BE
+ LIABLE TO YOU OR YOUR EMPLOYEES, AFFILIATES, USERS, OFFICERS, OR
+ DIRECTORS (A) UNDER ANY THEORY OF LIABILITY, WHETHER BASED IN
+ CONTRACT, TORT, NEGLIGENCE, STRICT LIABILITY, WARRANTY, OR OTHERWISE
+ UNDER THIS LICENSE, OR (B) FOR ANY INDIRECT, CONSEQUENTIAL,
+ EXEMPLARY, INCIDENTAL, PUNITIVE, OR SPECIAL DAMAGES OR LOST PROFITS,
+ EVEN IF COMPANY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
+ THE MODEL, ITS CONSTITUENT COMPONENTS, AND ANY OUTPUT (COLLECTIVELY,
+ "MODEL MATERIALS") ARE NOT DESIGNED OR INTENDED FOR USE IN ANY
+ APPLICATION OR SITUATION WHERE FAILURE OR FAULT COULD REASONABLY BE
+ ANTICIPATED TO LEAD TO SERIOUS INJURY OF ANY PERSON, INCLUDING
+ POTENTIAL DISCRIMINATION OR VIOLATION OF AN INDIVIDUAL'S PRIVACY
+ RIGHTS, OR TO SEVERE PHYSICAL, PROPERTY, OR ENVIRONMENTAL DAMAGE
+ (EACH, A "HIGH-RISK USE"). IF YOU ELECT TO USE ANY MODEL MATERIALS
+ FOR A HIGH-RISK USE, YOU DO SO AT YOUR OWN RISK. YOU AGREE TO DESIGN
+ AND IMPLEMENT APPROPRIATE DECISION-MAKING AND RISK-MITIGATION
+ PROCEDURES AND POLICIES IN CONNECTION WITH A HIGH-RISK USE.
+
+ 7. Indemnification
+
+ You will indemnify, defend, and hold harmless Company and its
+ subsidiaries and affiliates, and each of their respective
+ shareholders, directors, officers, employees, agents, successors, and
+ assigns (collectively, the "Company Parties") from and against any
+ losses, liabilities, damages, fines, penalties, and expenses
+ (including reasonable attorneys' fees) incurred by any Company Party
+ in connection with any claim, demand, allegation, lawsuit,
+ proceeding, or investigation ("Claims") arising out of or related to:
+ (a) your access to or use of FlowDIS (including any Output or data
+ generated from such use), including any High-Risk Use; (b) your
+ violation of this License; or (c) your violation, misappropriation,
+ or infringement of any rights of another (including intellectual
+ property or other proprietary rights and privacy rights). You will
+ promptly notify the Company Parties of any such Claims and cooperate
+ with Company Parties in defending such Claims. You will also grant
+ the Company Parties sole control of the defense or settlement, at
+ Company's sole option, of any Claims. This indemnity is in addition
+ to, and not in lieu of, any other indemnities or remedies set forth
+ in a written agreement between you and Company or the other Company
+ Parties.
+
+ 8. Termination; Survival
+
+ (a) This License will automatically terminate upon any breach by you
+ of the terms of this License.
+
+ (b) Company may terminate this License, in whole or in part, at any
+ time upon notice (including electronic) to you.
+
+ (c) If you initiate any legal action or proceedings against Company
+ or any other entity (including a cross-claim or counterclaim),
+ alleging that FlowDIS, any Derivative, or any part thereof,
+ infringes upon intellectual property or other rights owned or
+ licensable by you, then any licenses granted to you under this
+ License will immediately terminate as of the date such legal
+ action or claim is filed.
+
+ (d) Upon termination, you must cease all use, access, or Distribution
+ of FlowDIS and any Derivatives. Sections 2(c), 2(d), 4 through 11
+ survive termination.
+
+ 9. Third-Party Materials
+
+ FlowDIS is derived from FLUX.1 [schnell] by Black Forest Labs, Inc.,
+ and may contain additional third-party software or components
+ (including free and open-source software) ("Third-Party Materials"),
+ which are subject to the license terms of the respective third-party
+ licensors. Your dealings or correspondence with third parties and
+ your use of or interaction with any Third-Party Materials are solely
+ between you and the third party. Company does not control or endorse,
+ and makes no representations or warranties regarding, any Third-Party
+ Materials, and your access to and use of such Third-Party Materials
+ are at your own risk.
+
+ 10. Trademarks
+
+ No trademark license is granted as part of this License. You may not
+ use any name, logo, or trademark associated with PicsArt Inc. without
+ Company's prior written permission, except to the extent necessary to
+ make the reference required in the Attribution Notice or as is
+ reasonably necessary in describing FlowDIS and its creators.
+
+ 11. General
+
+ This License will be governed and construed under the laws of the
+ State of Delaware without regard to conflicts of law provisions. If
+ any provision or part of a provision of this License is unlawful,
+ void, or unenforceable, that provision or part is deemed severed from
+ this License and will not affect the validity and enforceability of
+ any remaining provisions. The failure of Company to exercise or
+ enforce any right or provision of this License will not operate as a
+ waiver of such right or provision. This License does not confer any
+ third-party beneficiary rights upon any other person or entity. This
+ License, together with any accompanying documentation, contains the
+ entire understanding between you and Company regarding its subject
+ matter and supersedes all other written or oral agreements and
+ understandings between you and Company regarding such subject matter.
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: FlowDIS
+ emoji: 🌀
+ colorFrom: indigo
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 6.3.0
+ python_version: 3.12
+ app_file: app.py
+ pinned: true
+ thumbnail: assets/preview.png
+ ---
+ Paper: https://arxiv.org/abs/2605.05077
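Note: the README's leading `---` block is the Space configuration (SDK, versions, entry file). A minimal sketch of reading those keys with only the standard library is shown below; it assumes flat `key: value` lines as in this README and is not a general YAML parser. The helper name `read_front_matter` is hypothetical, and `README_TEXT` is an abridged copy of the README above. Values are deliberately kept as strings: a YAML loader would typically read an unquoted `3.12` as a floating-point number.

```python
# Abridged copy of the README added in this commit.
README_TEXT = """\
---
title: FlowDIS
sdk: gradio
sdk_version: 6.3.0
python_version: 3.12
app_file: app.py
pinned: true
---
Paper: https://arxiv.org/abs/2605.05077
"""


def read_front_matter(text: str) -> dict[str, str]:
    """Extract flat key: value pairs from a leading '---' delimited block."""
    lines = text.splitlines()
    if not lines or lines[0].strip() != "---":
        return {}  # no front matter present
    meta: dict[str, str] = {}
    for line in lines[1:]:
        if line.strip() == "---":
            break  # closing delimiter: stop before the README body
        key, sep, value = line.partition(":")
        if sep:
            meta[key.strip()] = value.strip()
    return meta


meta = read_front_matter(README_TEXT)
print(meta["sdk"], meta["sdk_version"], meta["python_version"])
```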
app.py ADDED
@@ -0,0 +1,480 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ import logging
4
+ import uuid
5
+ import shutil
6
+ from copy import deepcopy
7
+ from pathlib import Path
8
+
9
+ # Set up logging
10
+ logging.basicConfig(
11
+ level=logging.INFO,
12
+ format="%(asctime)s | %(levelname)s | %(message)s",
13
+ datefmt="%Y-%m-%d %H:%M:%S",
14
+ )
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Set Gradio temp directory BEFORE importing gradio to avoid permission issues
18
+ TEMP_DIR = Path(__file__).parent / "gradio_temp"
19
+ if TEMP_DIR.exists():
20
+ shutil.rmtree(str(TEMP_DIR))
21
+ TEMP_DIR.mkdir(exist_ok=True)
22
+ os.environ["GRADIO_TEMP_DIR"] = str(TEMP_DIR)
23
+ os.environ["TMPDIR"] = str(TEMP_DIR)
24
+
25
+ import gradio as gr
26
+ import numpy as np
27
+ import torch
28
+ from PIL import Image
29
+
30
+ IS_HF_SPACE = os.environ.get("SPACE_ID") is not None
31
+
32
+ try:
33
+ import spaces
34
+ zero_gpu = spaces.GPU
35
+ except ImportError:
36
+ zero_gpu = lambda f: f
37
+
38
+ from flowdis.sampling import flowdis_predict
39
+ from flowdis.util import load_models
40
+ from qwen import expand_prompt
41
+
42
+
43
+ models = None
44
+ device = "cuda"
45
+ if torch.cuda.is_available():
46
+ models = load_models(device=device)
47
+ else:
48
+ logger.warning("No GPU available; the demo will not be able to run.")
49
+
50
+
51
+ def disable_download_btn():
52
+ return gr.update(interactive=False)
53
+
54
+
55
+ @zero_gpu
56
+ def process_image(image, prompt, expand_prompt_enabled, resolution, num_inference_steps):
57
+ """
58
+ Run FlowDIS on the input image, optionally guided by a text prompt.
59
+ Builds a transparent PNG cutout from the predicted mask and saves it under TEMP_DIR.
60
+
61
+ Args:
62
+ image: PIL Image or numpy array
63
+ prompt: str, the text input from the user
64
+ expand_prompt_enabled: bool, whether to expand the prompt via the model
65
+ resolution: int, the inference resolution
66
+ num_inference_steps: int, the number of inference steps
67
+
68
+ Returns:
69
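+ A pair of updates: (ImageSlider value [input, cutout], DownloadButton value), or (None, None) if no image was given.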
70
+ """
71
+ if image is None:
72
+ return None, None
73
+
74
+ if isinstance(image, np.ndarray):
75
+ image = Image.fromarray(image)
76
+
77
+ logger.info(f"Original prompt: {prompt}")
78
+ if prompt != "" and expand_prompt_enabled:
79
+ prompt = expand_prompt(image, prompt)
80
+ logger.info(f"Expanded prompt: {prompt}")
81
+
82
+ num_inference_steps = int(num_inference_steps)
83
+
84
+ pred_mask = flowdis_predict(
85
+ image=image,
86
+ prompt=prompt,
87
+ models=models,
88
+ resolution=resolution,
89
+ num_inference_steps=num_inference_steps,
90
+ device=device,
91
+ )
92
+ blacked_image = Image.fromarray(np.array(image) * (np.array(pred_mask)[:, :, np.newaxis] > 0).astype(np.uint8))
93
+ transparent_png = Image.fromarray(np.dstack([blacked_image, np.array(pred_mask)]))
94
+ uid = uuid.uuid4().hex
95
+ png_path = TEMP_DIR / f"{uid}.png"
96
+ transparent_png.save(png_path)
97
+ return (
98
+ gr.update(value=[image, transparent_png], key=uid),
99
+ gr.update(value=str(png_path), interactive=True)
100
+ )
101
+
102
+
103
+ # Load examples from assets/examples/examples.csv: image_name, prompt, resolution, num_steps
104
+ _example_dir = Path(__file__).parent / "assets" / "examples"
105
+ _examples_csv = _example_dir / "examples.csv"
106
+ examples = []
107
+ if _examples_csv.exists():
108
+ with open(_examples_csv, newline="", encoding="utf-8") as f:
109
+ for row in csv.DictReader(f):
110
+ image_path = str(_example_dir / row["image_name"].strip())
111
+ examples.append([
112
+ image_path,
113
+ row["prompt"].strip(),
114
+ True, # expand prompt (default for examples)
115
+ int(row["resolution"].strip()),
116
+ int(row["num_steps"].strip()),
117
+ ])
118
+
119
+
120
+ _head_js = """
121
+ <style>
122
+ #expand-prompt.is-disabled { pointer-events: none !important; }
123
+ #expand-prompt.is-disabled label,
124
+ #expand-prompt.is-disabled input,
125
+ #expand-prompt.is-disabled .info { opacity: 0.4 !important; }
126
+ /* Hide the "Expand prompt" column (3rd) in the examples table */
127
+ #examples-table table th:nth-child(3),
128
+ #examples-table table td:nth-child(3) { display: none !important; }
129
+ </style>
130
+ <script>
131
+ (function() {
132
+ function findEls() {
133
+ return {
134
+ ta: document.querySelector('#text-prompt textarea, #text-prompt input'),
135
+ cb: document.querySelector('#expand-prompt'),
136
+ };
137
+ }
138
+ function syncFromText() {
139
+ var els = findEls();
140
+ if (!els.ta || !els.cb) return;
141
+ var empty = !els.ta.value.trim();
142
+ els.cb.classList.toggle('is-disabled', empty);
143
+ var input = els.cb.querySelector('input[type=checkbox]');
144
+ if (input) input.disabled = empty;
145
+ }
146
+ function init() {
147
+ var els = findEls();
148
+ if (!els.ta || !els.cb) { setTimeout(init, 200); return; }
149
+ els.ta.addEventListener('input', syncFromText);
150
+ els.ta.addEventListener('change', syncFromText);
151
+ // Catch programmatic value changes (e.g. example selection)
152
+ var lastVal = els.ta.value;
153
+ setInterval(function() {
154
+ if (els.ta.value !== lastVal) { lastVal = els.ta.value; syncFromText(); }
155
+ }, 250);
156
+ syncFromText();
157
+ }
158
+ if (document.readyState === 'loading')
159
+ document.addEventListener('DOMContentLoaded', init);
160
+ else
161
+ init();
162
+ })();
163
+ </script>
198
+ """
199
+ with gr.Blocks(
200
+ title="FlowDIS – Precise Background Removal",
201
+ head=_head_js,
202
+ theme=gr.themes.Default(
203
+ font=gr.themes.GoogleFont("Inter"),
204
+ ).set(
205
+ button_primary_background_fill="#C209C1",
206
+ button_primary_background_fill_dark="#C209C1",
207
+ button_primary_background_fill_hover="#d63bd5",
208
+ button_primary_background_fill_hover_dark="#d63bd5",
209
+ button_primary_text_color="#ffffff",
210
+ button_primary_text_color_dark="#ffffff",
211
+ ),
212
+ delete_cache=(1800, 1800)
213
+ ) as demo:
214
+ gr.HTML(
215
+ """
216
+ <style>
217
+ /* Theme-adaptive tokens */
218
+ :root {
219
+ --flow-text: #0f172a; /* slate-900 */
220
+ --flow-muted: #475569; /* slate-600 */
221
+ --flow-link: #2563eb; /* blue-600 */
222
+ --flow-link-hover: #1d4ed8; /* blue-700 */
223
+ --flow-title: #C209C1; /* Picsart pink */
224
+ }
225
+
226
+ @media (prefers-color-scheme: dark) {
227
+ :root {
228
+ --flow-text: #f1f5f9; /* slate-100 */
229
+ --flow-muted: #94a3b8; /* slate-400 */
230
+ --flow-link: #60a5fa; /* blue-400 */
231
+ --flow-link-hover: #93c5fd; /* blue-300 */
232
+ --flow-title: #e45fe3; /* Picsart pink (lighter for dark mode) */
233
+ }
234
+ }
235
+
236
+ .flow-header {
237
+ text-align: center;
238
+ max-width: 900px;
239
+ margin: 18px auto 12px auto;
240
+ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
241
+ }
242
+
243
+ .flow-title {
244
+ font-size: 1.9rem;
245
+ font-weight: 750;
246
+ letter-spacing: -0.3px;
247
+ margin-bottom: 4px;
248
+ color: var(--flow-title); /* title accent (needle stays as-is) */
249
+ }
250
+
251
+ .flow-links {
252
+ margin-bottom: 8px;
253
+ }
254
+
255
+ .flow-links a {
256
+ color: var(--flow-link); /* cool blue links */
257
+ font-weight: 600;
258
+ text-decoration: none;
259
+ margin: 0 0px;
260
+ font-size: 0.95rem;
261
+ transition: color 0.2s ease, text-shadow 0.2s ease;
262
+ }
263
+
264
+ .flow-links a:hover {
265
+ color: var(--flow-link-hover);
266
+ text-shadow: 0 0 10px rgba(37, 99, 235, 0.25);
267
+ }
268
+
269
+ @media (prefers-color-scheme: dark) {
270
+ .flow-links a:hover {
271
+ text-shadow: 0 0 12px rgba(147, 197, 253, 0.35);
272
+ }
273
+ }
274
+
275
+ .flow-desc {
276
+ font-size: 0.95rem;
277
+ color: var(--flow-muted);
278
+ max-width: 650px;
279
+ margin: 0 auto;
280
+ line-height: 1.5;
281
+ }
282
+
283
+ .bg-btn-row { display: flex; gap: 6px; overflow-x: auto; scrollbar-width: thin; }
284
+ .bg-btn {
285
+ width: 42px !important; height: 42px !important;
286
+ border: 2.5px solid #aaa !important; border-radius: 8px !important;
287
+ cursor: pointer !important; flex-shrink: 0 !important;
288
+ padding: 0 !important; outline: none !important;
289
+ transition: transform 0.15s ease, box-shadow 0.15s ease,
290
+ border-color 0.15s ease, filter 0.15s ease;
291
+ }
292
+ .bg-btn:hover {
293
+ transform: scale(1.15);
294
+ border-color: #333 !important;
295
+ box-shadow: 0 3px 10px rgba(0,0,0,0.4);
296
+ filter: brightness(1.15);
297
+ }
298
+ .bg-btn:active {
299
+ transform: scale(0.95);
300
+ }
301
+
302
+
303
+ @media (max-width: 1024px) {
304
+ #main-row {
305
+ flex-direction: column !important;
306
+ flex-wrap: wrap !important;
307
+ }
308
+ #main-row > * {
309
+ width: 100% !important;
310
+ flex: 1 1 100% !important;
311
+ min-width: 0 !important;
312
+ }
313
+ }
314
+
315
+ @media (max-width: 500px) {
316
+ #input-image { height: 400px !important; }
317
+ }
318
+ @media (max-width: 400px) {
319
+ #input-image { height: 300px !important; }
320
+ }
321
+ .prose :is(label span, .info) { font-weight: 400 !important; }
322
+ </style>
323
+
324
+ <div class="flow-header">
325
+ <div class="flow-title"><span style="color:#C209C1">✦</span> FlowDIS Demo</div>
326
+
327
+ <div class="flow-links">
328
+ <span>📄</span><a href="https://arxiv.org/abs/2605.05077" target="_blank" rel="noopener noreferrer">arXiv</a>
329
+ <span>💻</span><a href="https://github.com/Picsart-AI-Research/FlowDIS" target="_blank" rel="noopener noreferrer">Code</a>
330
+ </div>
331
+
332
+ <div class="flow-desc">
333
+ FlowDIS performs precise foreground segmentation, optionally guided by a text prompt to only preserve the specified objects.
334
+ </div>
335
+ </div>
336
+ """
337
+ )
338
+
339
+ with gr.Row(elem_id="main-row"):
340
+ # Left column: Input image, text field, and submit button
341
+ with gr.Column(scale=1):
342
+ input_image = gr.Image(
343
+ label="Input Image",
344
+ type="pil",
345
+ height=500,
346
+ elem_id="input-image",
347
+ )
348
+ text_input = gr.Textbox(
349
+ label="Text Prompt (Optional)",
350
+ placeholder="Enter what you want to retain...",
351
+ lines=1,
352
+ elem_id="text-prompt",
353
+ )
354
+ expand_prompt_check = gr.Checkbox(
355
+ label="Expand prompt",
356
+ value=True,
357
+ elem_id="expand-prompt",
358
+ info="Use Qwen3-VL-4B-Instruct model to expand the prompt for better text-guided segmentation.",
359
+ )
360
+
361
+ # Sliders for resolution and steps
362
+ with gr.Row():
363
+ with gr.Column(scale=1, min_width=300):
364
+ resolution_slider = gr.Slider(
365
+ minimum=1024,
366
+ maximum=2048,
367
+ value=1536,
368
+ step=64,
369
+ label="Inference Resolution",
370
+ info="Higher resolution preserves more details.",
371
+ )
372
+
373
+ with gr.Column(scale=1, min_width=300):
374
+ steps_slider = gr.Slider(
375
+ minimum=1,
376
+ maximum=12,
377
+ value=4,
378
+ step=1,
379
+ label="Number of Steps",
380
+ info="More steps generate sharper results.",
381
+ )
382
+
383
+ submit_btn = gr.Button("🚀 Remove Background", variant="primary")
384
+
385
+ # Right column: Output image
386
+ with gr.Column(scale=1):
387
+ output_image = gr.ImageSlider(
388
+ label="FlowDIS prediction",
389
+ type="pil",
390
+ format="webp",
391
+ height=500,
392
+ slider_position=10,
393
+ elem_id="output-slider",
394
+ )
395
+
396
+ _checker = "repeating-conic-gradient(#ccc 0% 25%,#fff 0% 50%) 50%/12px 12px"
397
+ _bg_buttons = [
398
+ (_checker, _checker),
399
+ ("#ffffff", "#ffffff"),
400
+ ("#000000", "#000000"),
401
+ ("#00ff00", "#00ff00"),
402
+ ("#0000ff", "#0000ff"),
403
+ ("#ff0000", "#ff0000"),
404
+ ("#ffff00", "#ffff00"),
405
+ ("#ff00ff", "#ff00ff"),
406
+ ("#00ffff", "#00ffff"),
407
+ ]
408
+ _onclick = (
409
+ "var s=document.getElementById('slider-bg-style');"
410
+ "if(!s){s=document.createElement('style');"
411
+ "s.id='slider-bg-style';document.head.appendChild(s);}"
412
+ "s.textContent='#output-slider img,#output-slider canvas"
413
+ "{background:'+this.dataset.bg+' !important}';"
414
+ )
415
+ gr.HTML(
416
+ value='<div class="bg-btn-row">'
417
+ + "".join(
418
+ f'<button class="bg-btn" style="background:{style}"'
419
+ f' data-bg="{bg}" onclick="{_onclick}"></button>'
420
+ for style, bg in _bg_buttons
421
+ )
422
+ + "</div>"
423
+ )
424
+
425
+ download_btn = gr.DownloadButton(
426
+ label="📥 Download PNG",
427
+ variant="primary",
428
+ interactive=False
429
+ )
430
+
431
+ # Connect the submit button to the processing function
432
+ submit_btn.click(
433
+ disable_download_btn,
434
+ outputs=download_btn
435
+ ).then(
436
+ fn=process_image,
437
+ inputs=[input_image, text_input, expand_prompt_check, resolution_slider, steps_slider],
438
+ outputs=[output_image, download_btn]
439
+ )
440
+
441
+ # Optional: Also trigger on text input enter key
442
+ text_input.submit(
443
+ disable_download_btn,
444
+ outputs=download_btn
445
+ ).then(
446
+ fn=process_image,
447
+ inputs=[input_image, text_input, expand_prompt_check, resolution_slider, steps_slider],
448
+ outputs=[output_image, download_btn],
449
+ )
450
+
451
+ examples_component = gr.Examples(
452
+ examples=examples,
453
+ inputs=[input_image, text_input, expand_prompt_check, resolution_slider, steps_slider],
454
+ label="Examples",
455
+ elem_id="examples-table",
456
+ )
457
+
458
+ examples_component.dataset.click(
459
+ disable_download_btn,
460
+ outputs=download_btn
461
+ ).then(
462
+ process_image,
463
+ inputs=[input_image, text_input, expand_prompt_check, resolution_slider, steps_slider],
464
+ outputs=[output_image, download_btn],
465
+ )
466
+
467
+
468
+ # Launch the app
469
+ if __name__ == "__main__":
470
+ demo.queue(max_size=20)
471
+ if IS_HF_SPACE:
472
+ demo.launch(allowed_paths=[str(TEMP_DIR), "assets"])
473
+ else:
474
+ demo.launch(
475
+ server_name="0.0.0.0",
476
+ server_port=7860,
477
+ share=True,
478
+ allowed_paths=[str(TEMP_DIR), "assets"],
479
+ )
480
+
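The cutout step in `process_image` (black out the background, then attach the mask as an alpha channel) can be sketched in isolation with NumPy alone. This is a minimal illustration, not the demo's code: `make_cutout` is a hypothetical helper, and the explicit shape checks are an added assumption, since the `dstack` only yields a valid RGBA array when the input has exactly three channels.

```python
import numpy as np


def make_cutout(rgb: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Hypothetical helper: return an H x W x 4 RGBA array from an RGB image and a mask."""
    if rgb.ndim != 3 or rgb.shape[2] != 3:
        raise ValueError("expected an H x W x 3 RGB array")
    if mask.shape != rgb.shape[:2]:
        raise ValueError("mask must match the image height and width")
    keep = (mask > 0).astype(np.uint8)      # 1 where the mask is non-zero
    blacked = rgb * keep[:, :, np.newaxis]  # zero the RGB values where the mask is 0
    return np.dstack([blacked, mask])       # append the mask as the alpha channel


# Tiny 2 x 2 example: uniform grey image, mixed mask values.
rgb = np.full((2, 2, 3), 200, dtype=np.uint8)
mask = np.array([[0, 128], [255, 0]], dtype=np.uint8)
rgba = make_cutout(rgb, mask)
```

Pixels with a mask value of 0 become fully transparent black, while partially masked pixels keep their colour and receive a fractional alpha.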
assets/examples/0.jpg ADDED

Git LFS Details

  • SHA256: 3c3404d60a4de67e2764479266bf6b7130ba1169d3fe0f37acebab5a62f8b16c
  • Pointer size: 131 Bytes
  • Size of remote file: 359 kB
assets/examples/1.jpg ADDED

Git LFS Details

  • SHA256: 51809ee1a7e27bbbbde5a0325723895fd2592e70dda327e2a06795bf7d98c0cf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
assets/examples/2.png ADDED

Git LFS Details

  • SHA256: 4d17b0f354a2c72baedc68137057e384ca68812b05d7070f4cfd4bb0ea8f49f5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.59 MB
assets/examples/3.jpg ADDED

Git LFS Details

  • SHA256: 7bc8ac73466b7af3eb5f0ea9b7bb99b7cbcd541532641d7d56e5106ed3282cd3
  • Pointer size: 132 Bytes
  • Size of remote file: 2.41 MB
assets/examples/4.jpg ADDED

Git LFS Details

  • SHA256: da7bab1e177dbd5dd3aaf1cb8d5e7040d84b8ff3f2e844860735510e63eb8b66
  • Pointer size: 131 Bytes
  • Size of remote file: 890 kB
assets/examples/5.jpg ADDED

Git LFS Details

  • SHA256: 5c35568c50cc28582379689dc79382b61dcb7bd0352f5345248c50f641c48926
  • Pointer size: 131 Bytes
  • Size of remote file: 998 kB
assets/examples/6.jpg ADDED

Git LFS Details

  • SHA256: 4a27f6ccf014f0c2e073eda7bb148cf7c70eb40ff4e7027f69d3e34d63854bca
  • Pointer size: 132 Bytes
  • Size of remote file: 1.4 MB
assets/examples/examples.csv ADDED
@@ -0,0 +1,8 @@
1
+ image_name,prompt,resolution,num_steps
2
+ 0.jpg,,2048,8
3
+ 1.jpg,,2048,8
4
+ 2.png,,1536,4
5
+ 3.jpg,,1536,2
6
+ 4.jpg,measuring tape,2048,8
7
+ 5.jpg,white shoes,1280,2
8
+ 6.jpg,bicycle,2048,8
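Because these rows are fed straight into the resolution and steps sliders, it can help to check them against the slider bounds declared in app.py (1024 to 2048 in steps of 64, and 1 to 12 steps). The sketch below is one way to do that with the standard library; `check_row` is a hypothetical helper and is not part of the commit.

```python
import csv
import io

# Contents of assets/examples/examples.csv, copied from the diff above.
CSV_TEXT = """image_name,prompt,resolution,num_steps
0.jpg,,2048,8
1.jpg,,2048,8
2.png,,1536,4
3.jpg,,1536,2
4.jpg,measuring tape,2048,8
5.jpg,white shoes,1280,2
6.jpg,bicycle,2048,8
"""

# Slider bounds as declared in app.py (resolution_slider and steps_slider).
RES_MIN, RES_MAX, RES_STEP = 1024, 2048, 64
STEPS_MIN, STEPS_MAX = 1, 12


def check_row(row: dict) -> list[str]:
    """Hypothetical helper: list the ways a row falls outside the UI slider ranges."""
    problems = []
    resolution = int(row["resolution"].strip())
    num_steps = int(row["num_steps"].strip())
    if not RES_MIN <= resolution <= RES_MAX:
        problems.append(f"resolution {resolution} out of range")
    elif (resolution - RES_MIN) % RES_STEP:
        problems.append(f"resolution {resolution} is not on a {RES_STEP}-pixel step")
    if not STEPS_MIN <= num_steps <= STEPS_MAX:
        problems.append(f"num_steps {num_steps} out of range")
    return problems


rows = list(csv.DictReader(io.StringIO(CSV_TEXT)))
report = {row["image_name"]: check_row(row) for row in rows}
```

An empty list for every image name means the example values sit on valid slider positions; rows with an empty `prompt` column parse to an empty string, which app.py treats as "no text guidance".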
assets/examples/prompts.json ADDED
@@ -0,0 +1,9 @@
1
+ {
2
+ "0.jpg": "",
3
+ "1.jpg": "",
4
+ "2.png": "",
5
+ "3.jpg": "",
6
+ "4.jpg": "measuring tape",
7
+ "5.jpg": "white shoes",
8
+ "6.jpg": "bicycle"
9
+ }
assets/preview.png ADDED

Git LFS Details

  • SHA256: 6d640682e2c07159c5d43c530948dfcf87a036ae20b21f1f2a25395db45daf74
  • Pointer size: 131 Bytes
  • Size of remote file: 293 kB
flowdis/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """FlowDIS: Language-Guided Dichotomous Image Segmentation with Flow Matching"""
2
+
3
+ from flowdis.configs import configs
4
+ from flowdis.loaders import load_autoencoder, load_clip, load_t5, load_transformer
5
+ from flowdis.sampling import flowdis_predict
6
+ from flowdis.util import load_models
7
+
8
+ __all__ = [
9
+ "configs",
10
+ "load_autoencoder",
11
+ "load_clip",
12
+ "load_t5",
13
+ "load_transformer",
14
+ "flowdis_predict",
15
+ "load_models",
16
+ ]
flowdis/autoencoder.py ADDED
@@ -0,0 +1,318 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from einops import rearrange
5
+ from torch import Tensor, nn
6
+
7
+
8
+ @dataclass
9
+ class AutoEncoderParams:
10
+ resolution: int
11
+ in_channels: int
12
+ ch: int
13
+ out_ch: int
14
+ ch_mult: list[int]
15
+ num_res_blocks: int
16
+ z_channels: int
17
+ scale_factor: float
18
+ shift_factor: float
19
+
20
+
21
+ def swish(x: Tensor) -> Tensor:
22
+ return x * torch.sigmoid(x)
23
+
24
+
25
+ class AttnBlock(nn.Module):
26
+ def __init__(self, in_channels: int):
27
+ super().__init__()
28
+ self.in_channels = in_channels
29
+
30
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
31
+
32
+ self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1)
33
+ self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1)
34
+ self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1)
35
+ self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1)
36
+
37
+ def attention(self, h_: Tensor) -> Tensor:
38
+ h_ = self.norm(h_)
39
+ q = self.q(h_)
40
+ k = self.k(h_)
41
+ v = self.v(h_)
42
+
43
+ b, c, h, w = q.shape
44
+ q = rearrange(q, "b c h w -> b 1 (h w) c").contiguous()
45
+ k = rearrange(k, "b c h w -> b 1 (h w) c").contiguous()
46
+ v = rearrange(v, "b c h w -> b 1 (h w) c").contiguous()
47
+ h_ = nn.functional.scaled_dot_product_attention(q, k, v)
48
+
49
+ return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)
50
+
51
+ def forward(self, x: Tensor) -> Tensor:
52
+ return x + self.proj_out(self.attention(x))
53
+
54
+
55
+ class ResnetBlock(nn.Module):
56
+ def __init__(self, in_channels: int, out_channels: int):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+ out_channels = in_channels if out_channels is None else out_channels
60
+ self.out_channels = out_channels
61
+
62
+ self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
63
+ self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
64
+ self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
65
+ self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
66
+ if self.in_channels != self.out_channels:
67
+ self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
68
+
69
+ def forward(self, x):
70
+ h = x
71
+ h = self.norm1(h)
72
+ h = swish(h)
73
+ h = self.conv1(h)
74
+
75
+ h = self.norm2(h)
76
+ h = swish(h)
77
+ h = self.conv2(h)
78
+
79
+ if self.in_channels != self.out_channels:
80
+ x = self.nin_shortcut(x)
81
+
82
+ return x + h
83
+
84
+
85
+ class Downsample(nn.Module):
86
+ def __init__(self, in_channels: int):
87
+ super().__init__()
88
+ # no asymmetric padding in torch conv, must do it ourselves
89
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
90
+
91
+ def forward(self, x: Tensor):
92
+ pad = (0, 1, 0, 1)
93
+ x = nn.functional.pad(x, pad, mode="constant", value=0)
94
+ x = self.conv(x)
95
+ return x
96
+
97
+
98
+ class Upsample(nn.Module):
99
+ def __init__(self, in_channels: int):
100
+ super().__init__()
101
+ self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
102
+
103
+ def forward(self, x: Tensor):
104
+ x = nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
105
+ x = self.conv(x)
106
+ return x
107
+
108
+
109
+ class Encoder(nn.Module):
110
+ def __init__(
111
+ self,
112
+ resolution: int,
113
+ in_channels: int,
114
+ ch: int,
115
+ ch_mult: list[int],
116
+ num_res_blocks: int,
117
+ z_channels: int,
118
+ ):
119
+ super().__init__()
120
+ self.ch = ch
121
+ self.num_resolutions = len(ch_mult)
122
+ self.num_res_blocks = num_res_blocks
123
+ self.resolution = resolution
124
+ self.in_channels = in_channels
125
+ # downsampling
126
+ self.conv_in = nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
127
+
128
+ curr_res = resolution
129
+ in_ch_mult = (1,) + tuple(ch_mult)
130
+ self.in_ch_mult = in_ch_mult
131
+ self.down = nn.ModuleList()
132
+ block_in = self.ch
133
+ for i_level in range(self.num_resolutions):
134
+ block = nn.ModuleList()
135
+ attn = nn.ModuleList()
136
+ block_in = ch * in_ch_mult[i_level]
137
+ block_out = ch * ch_mult[i_level]
138
+ for _ in range(self.num_res_blocks):
139
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
140
+ block_in = block_out
141
+ down = nn.Module()
142
+ down.block = block
143
+ down.attn = attn
144
+ if i_level != self.num_resolutions - 1:
145
+ down.downsample = Downsample(block_in)
146
+ curr_res = curr_res // 2
147
+ self.down.append(down)
148
+
149
+ # middle
150
+ self.mid = nn.Module()
151
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
152
+ self.mid.attn_1 = AttnBlock(block_in)
153
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
154
+
155
+ # end
156
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
157
+ self.conv_out = nn.Conv2d(block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1)
158
+
159
+ def forward(self, x: Tensor) -> Tensor:
160
+ # downsampling
161
+ hs = [self.conv_in(x)]
162
+ for i_level in range(self.num_resolutions):
163
+ for i_block in range(self.num_res_blocks):
164
+ h = self.down[i_level].block[i_block](hs[-1])
165
+ if len(self.down[i_level].attn) > 0:
166
+ h = self.down[i_level].attn[i_block](h)
167
+ hs.append(h)
168
+ if i_level != self.num_resolutions - 1:
169
+ hs.append(self.down[i_level].downsample(hs[-1]))
170
+
171
+ # middle
172
+ h = hs[-1]
173
+ h = self.mid.block_1(h)
174
+ h = self.mid.attn_1(h)
175
+ h = self.mid.block_2(h)
176
+ # end
177
+ h = self.norm_out(h)
178
+ h = swish(h)
179
+ h = self.conv_out(h)
180
+ return h
181
+
182
+
183
+ class Decoder(nn.Module):
184
+ def __init__(
185
+ self,
186
+ ch: int,
187
+ out_ch: int,
188
+ ch_mult: list[int],
189
+ num_res_blocks: int,
190
+ in_channels: int,
191
+ resolution: int,
192
+ z_channels: int,
193
+ ):
194
+ super().__init__()
195
+ self.ch = ch
196
+ self.num_resolutions = len(ch_mult)
197
+ self.num_res_blocks = num_res_blocks
198
+ self.resolution = resolution
199
+ self.in_channels = in_channels
200
+ self.ffactor = 2 ** (self.num_resolutions - 1)
201
+
202
+ # compute in_ch_mult, block_in and curr_res at lowest res
203
+ block_in = ch * ch_mult[self.num_resolutions - 1]
204
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
205
+ self.z_shape = (1, z_channels, curr_res, curr_res)
206
+
207
+ # z to block_in
208
+ self.conv_in = nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
209
+
210
+ # middle
211
+ self.mid = nn.Module()
212
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in)
213
+ self.mid.attn_1 = AttnBlock(block_in)
214
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in)
215
+
216
+ # upsampling
217
+ self.up = nn.ModuleList()
218
+ for i_level in reversed(range(self.num_resolutions)):
219
+ block = nn.ModuleList()
220
+ attn = nn.ModuleList()
221
+ block_out = ch * ch_mult[i_level]
222
+ for _ in range(self.num_res_blocks + 1):
223
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
224
+ block_in = block_out
225
+ up = nn.Module()
226
+ up.block = block
227
+ up.attn = attn
228
+ if i_level != 0:
229
+ up.upsample = Upsample(block_in)
230
+ curr_res = curr_res * 2
231
+ self.up.insert(0, up) # prepend to get consistent order
232
+
233
+ # end
234
+ self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_in, eps=1e-6, affine=True)
235
+ self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
236
+
237
+ def forward(self, z: Tensor) -> Tensor:
238
+ # get dtype for proper tracing
239
+ upscale_dtype = next(self.up.parameters()).dtype
240
+
241
+ # z to block_in
242
+ h = self.conv_in(z)
243
+
244
+ # middle
245
+ h = self.mid.block_1(h)
246
+ h = self.mid.attn_1(h)
247
+ h = self.mid.block_2(h)
248
+
249
+ # cast to proper dtype
250
+ h = h.to(upscale_dtype)
251
+ # upsampling
252
+ for i_level in reversed(range(self.num_resolutions)):
253
+ for i_block in range(self.num_res_blocks + 1):
254
+ h = self.up[i_level].block[i_block](h)
255
+ if len(self.up[i_level].attn) > 0:
256
+ h = self.up[i_level].attn[i_block](h)
257
+ if i_level != 0:
258
+ h = self.up[i_level].upsample(h)
259
+
260
+ # end
261
+ h = self.norm_out(h)
262
+ h = swish(h)
263
+ h = self.conv_out(h)
264
+ return h
265
+
266
+
267
+ class DiagonalGaussian(nn.Module):
268
+ def __init__(self, sample: bool = True, chunk_dim: int = 1):
269
+ super().__init__()
270
+ self.sample = sample
271
+ self.chunk_dim = chunk_dim
272
+
273
+ def forward(self, z: Tensor) -> Tensor:
274
+ mean, logvar = torch.chunk(z, 2, dim=self.chunk_dim)
275
+ if self.sample:
276
+ std = torch.exp(0.5 * logvar)
277
+ return mean + std * torch.randn_like(mean)
278
+ else:
279
+ return mean
280
+
281
+
282
+ class AutoEncoder(nn.Module):
283
+ def __init__(self, params: AutoEncoderParams):
284
+ super().__init__()
285
+ self.params = params
286
+ self.encoder = Encoder(
287
+ resolution=params.resolution,
288
+ in_channels=params.in_channels,
289
+ ch=params.ch,
290
+ ch_mult=params.ch_mult,
291
+ num_res_blocks=params.num_res_blocks,
292
+ z_channels=params.z_channels,
293
+ )
294
+ self.decoder = Decoder(
295
+ resolution=params.resolution,
296
+ in_channels=params.in_channels,
297
+ ch=params.ch,
298
+ out_ch=params.out_ch,
299
+ ch_mult=params.ch_mult,
300
+ num_res_blocks=params.num_res_blocks,
301
+ z_channels=params.z_channels,
302
+ )
303
+ self.reg = DiagonalGaussian()
304
+
305
+ self.scale_factor = params.scale_factor
306
+ self.shift_factor = params.shift_factor
307
+
308
+ def encode(self, x: Tensor) -> Tensor:
309
+ z = self.reg(self.encoder(x))
310
+ z = self.scale_factor * (z - self.shift_factor)
311
+ return z
312
+
313
+ def decode(self, z: Tensor) -> Tensor:
314
+ z = z / self.scale_factor + self.shift_factor
315
+ return self.decoder(z)
316
+
317
+ def forward(self, x: Tensor) -> Tensor:
318
+ return self.decode(self.encode(x))
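The spatial size of the latent follows from the `Downsample` block: it pads one pixel on the right and bottom, then applies a 3x3 convolution with stride 2, and the encoder applies it at every level except the last. The arithmetic can be sketched in plain Python; `conv_out_size` and `latent_shape` are hypothetical helpers, and the default `ch_mult` and `z_channels` are the values from flowdis/configs.py.

```python
def conv_out_size(size: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> int:
    """Standard convolution output-size formula along one axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


def latent_shape(height: int, width: int, ch_mult=(1, 2, 4, 4), z_channels: int = 16):
    """Hypothetical helper: (channels, height, width) of the sampled latent."""
    n_down = len(ch_mult) - 1  # every level except the last one downsamples
    h, w = height, width
    for _ in range(n_down):
        # Downsample: pad (0, 1, 0, 1), then a 3x3 convolution with stride 2 and no padding.
        h = conv_out_size(h, kernel=3, stride=2, pad_before=0, pad_after=1)
        w = conv_out_size(w, kernel=3, stride=2, pad_before=0, pad_after=1)
    # conv_out emits 2 * z_channels (mean and log-variance); DiagonalGaussian keeps z_channels.
    return (z_channels, h, w)


shape_1536 = latent_shape(1536, 1536)
shape_256 = latent_shape(256, 256)
```

With four levels there are three downsampling steps, so even input sizes shrink by a factor of 8 along each axis.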
flowdis/conditioner.py ADDED
@@ -0,0 +1,44 @@
1
+ from torch import Tensor, nn
2
+ from transformers import (
3
+ CLIPTextConfig,
4
+ CLIPTextModel,
5
+ CLIPTokenizer,
6
+ T5Config,
7
+ T5EncoderModel,
8
+ T5Tokenizer,
9
+ )
10
+
11
+
12
+ class HFEmbedder(nn.Module):
13
+ def __init__(self, version: str, max_length: int, is_clip: bool, **hf_kwargs):
14
+ super().__init__()
15
+ self.is_clip = is_clip
16
+ self.max_length = max_length
17
+ self.output_key = "pooler_output" if self.is_clip else "last_hidden_state"
18
+
19
+ if self.is_clip:
20
+ self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(version, max_length=max_length)
21
+ config = CLIPTextConfig.from_pretrained(version, **hf_kwargs)
22
+ self.hf_module: CLIPTextModel = CLIPTextModel._from_config(config)
23
+ else:
24
+ self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained(version, max_length=max_length, legacy=True)
25
+ config = T5Config.from_pretrained(version, **hf_kwargs)
26
+ self.hf_module: T5EncoderModel = T5EncoderModel._from_config(config)
27
+
28
+ def forward(self, text: list[str]) -> Tensor:
29
+ batch_encoding = self.tokenizer(
30
+ text,
31
+ truncation=True,
32
+ max_length=self.max_length,
33
+ return_length=False,
34
+ return_overflowing_tokens=False,
35
+ padding="max_length",
36
+ return_tensors="pt",
37
+ )
38
+
39
+ outputs = self.hf_module(
40
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
41
+ attention_mask=None,
42
+ output_hidden_states=False,
43
+ )
44
+ return outputs[self.output_key]
flowdis/configs.py ADDED
@@ -0,0 +1,32 @@
1
+ from flowdis.autoencoder import AutoEncoderParams
2
+ from flowdis.model import FluxParams
3
+
4
+
5
+ configs = {
6
+ "autoencoder": AutoEncoderParams(
7
+ resolution=256,
8
+ in_channels=3,
9
+ ch=128,
10
+ out_ch=3,
11
+ ch_mult=[1, 2, 4, 4],
12
+ num_res_blocks=2,
13
+ z_channels=16,
14
+ scale_factor=0.3611,
15
+ shift_factor=0.1159,
16
+ ),
17
+ "flowdis": FluxParams(
18
+ in_channels=128,
19
+ out_channels=64,
20
+ vec_in_dim=768,
21
+ context_in_dim=4096,
22
+ hidden_size=3072,
23
+ mlp_ratio=4.0,
24
+ num_heads=24,
25
+ depth=19,
26
+ depth_single_blocks=38,
27
+ axes_dim=[16, 56, 56],
28
+ theta=10_000,
29
+ qkv_bias=True,
30
+ guidance_embed=False,
31
+ ),
32
+ }
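The transformer settings above are tied together by the per-head dimension: `hidden_size` is split across `num_heads`, and the rotary position embedding spends `axes_dim` channels per position axis. A small consistency check can be sketched as follows. It assumes, as in the upstream Flux reference code, that `sum(axes_dim)` must equal the head dimension and that each entry is even; flowdis/model.py and flowdis/math.py are not shown in this section, so treat both conditions as assumptions. `check_rope_dims` is a hypothetical helper.

```python
def check_rope_dims(hidden_size: int, num_heads: int, axes_dim: list[int]) -> int:
    """Hypothetical helper: validate the head and RoPE dimensions, return the head dimension."""
    if hidden_size % num_heads:
        raise ValueError("hidden_size must be divisible by num_heads")
    head_dim = hidden_size // num_heads
    # Assumption: the rotary embedding must cover the whole head dimension.
    if sum(axes_dim) != head_dim:
        raise ValueError(f"sum(axes_dim)={sum(axes_dim)} does not match head_dim={head_dim}")
    # Assumption: RoPE rotates channel pairs, so each axis needs an even width.
    if any(d % 2 for d in axes_dim):
        raise ValueError("each axes_dim entry must be even")
    return head_dim


# Values from the "flowdis" entry in configs.
head_dim = check_rope_dims(hidden_size=3072, num_heads=24, axes_dim=[16, 56, 56])
```

For the configuration in this commit the head dimension is 3072 / 24 = 128, which matches 16 + 56 + 56.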
flowdis/layers.py ADDED
@@ -0,0 +1,263 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+
4
+ import torch
5
+ from einops import rearrange
6
+ from torch import Tensor, nn
7
+
8
+ from flowdis.math import attention, rope
9
+
10
+
11
+ class EmbedND(nn.Module):
12
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
13
+ super().__init__()
14
+ self.dim = dim
15
+ self.theta = theta
16
+ self.axes_dim = axes_dim
17
+
18
+ def forward(self, ids: Tensor) -> Tensor:
19
+ n_axes = ids.shape[-1]
20
+ emb = torch.cat(
21
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
22
+ dim=-3,
23
+ )
24
+
25
+ return emb.unsqueeze(1)
26
+
27
+
28
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
29
+ """
30
+ Create sinusoidal timestep embeddings.
31
+ :param t: a 1-D Tensor of N indices, one per batch element.
32
+ These may be fractional.
33
+ :param dim: the dimension of the output.
34
+ :param max_period: controls the minimum frequency of the embeddings.
35
+ :return: an (N, D) Tensor of positional embeddings.
36
+ """
37
+ t = time_factor * t
38
+ half = dim // 2
39
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)
40
+
41
+ args = t[:, None].float() * freqs[None]
42
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
43
+ if dim % 2:
44
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
45
+ if torch.is_floating_point(t):
46
+ embedding = embedding.to(t)
47
+ return embedding
48
+
49
+
50
+ class MLPEmbedder(nn.Module):
51
+ def __init__(self, in_dim: int, hidden_dim: int):
52
+ super().__init__()
53
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
54
+ self.silu = nn.SiLU()
55
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
56
+
57
+ def forward(self, x: Tensor) -> Tensor:
58
+ return self.out_layer(self.silu(self.in_layer(x)))
59
+
60
+
61
+ class RMSNorm(torch.nn.Module):
62
+ def __init__(self, dim: int):
63
+ super().__init__()
64
+ self.scale = nn.Parameter(torch.ones(dim))
65
+ self.dim = dim
66
+
67
+ def forward(self, x: Tensor):
68
+ x_dtype = x.dtype
69
+ x = x.float()
70
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
71
+ return (x * rrms).to(dtype=x_dtype) * self.scale
72
+
73
+
74
+ class QKNorm(torch.nn.Module):
75
+ def __init__(self, dim: int):
76
+ super().__init__()
77
+ self.query_norm = RMSNorm(dim)
78
+ self.key_norm = RMSNorm(dim)
79
+
80
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
81
+ q = self.query_norm(q)
82
+ k = self.key_norm(k)
83
+ return q.to(v), k.to(v)
84
+
85
+
86
+ class SelfAttention(nn.Module):
87
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
88
+ super().__init__()
89
+ self.num_heads = num_heads
90
+ head_dim = dim // num_heads
91
+
92
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
93
+ self.norm = QKNorm(head_dim)
94
+ self.proj = nn.Linear(dim, dim)
95
+
96
+ def forward(self, x: Tensor, pe: Tensor) -> Tensor:
97
+ qkv = self.qkv(x)
98
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
99
+ q, k = self.norm(q, k, v)
100
+ x = attention(q, k, v, pe=pe)
101
+ x = self.proj(x)
102
+ return x
103
+
104
+
105
+ @dataclass
106
+ class ModulationOut:
107
+ shift: Tensor
108
+ scale: Tensor
109
+ gate: Tensor
110
+
111
+
112
+ class Modulation(nn.Module):
113
+ def __init__(self, dim: int, double: bool):
114
+ super().__init__()
115
+ self.is_double = double
116
+ self.multiplier = 6 if double else 3
117
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
118
+
119
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
120
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
121
+
122
+ return (
123
+ ModulationOut(*out[:3]),
124
+ ModulationOut(*out[3:]) if self.is_double else None,
125
+ )
126
+
127
+
128
+ class DoubleStreamBlock(nn.Module):
129
+ def __init__(
130
+ self,
131
+ hidden_size: int,
132
+ num_heads: int,
133
+ mlp_ratio: float,
134
+ qkv_bias: bool = False,
135
+ ):
136
+ super().__init__()
137
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
138
+ self.num_heads = num_heads
139
+ self.hidden_size = hidden_size
140
+ self.img_mod = Modulation(hidden_size, double=True)
141
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
142
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
143
+
144
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
145
+ self.img_mlp = nn.Sequential(
146
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
147
+ nn.GELU(approximate="tanh"),
148
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
149
+ )
150
+
151
+ self.txt_mod = Modulation(hidden_size, double=True)
152
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
153
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
154
+
155
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
156
+ self.txt_mlp = nn.Sequential(
157
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
158
+ nn.GELU(approximate="tanh"),
159
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
160
+ )
161
+
162
+ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
163
+ return self._forward(img, txt, vec, pe)
164
+
165
+ def _forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor) -> tuple[Tensor, Tensor]:
166
+ img_mod1, img_mod2 = self.img_mod(vec)
167
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
168
+
169
+ # prepare image for attention
170
+ img_modulated = self.img_norm1(img)
171
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
172
+ img_qkv = self.img_attn.qkv(img_modulated)
173
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
175
+
176
+ # prepare txt for attention
177
+ txt_modulated = self.txt_norm1(txt)
178
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
179
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
180
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
181
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
182
+
183
+ # run actual attention
184
+ q = torch.cat((txt_q, img_q), dim=2)
185
+ k = torch.cat((txt_k, img_k), dim=2)
186
+ v = torch.cat((txt_v, img_v), dim=2)
187
+
188
+ attn = attention(q, k, v, pe=pe)
189
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
190
+
191
+ # calculate the img blocks
192
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
193
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
194
+
195
+ # calculate the txt blocks
196
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
197
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
198
+ return img, txt
199
+
200
+
201
+ class SingleStreamBlock(nn.Module):
202
+ """
203
+ A DiT block with parallel linear layers as described in
204
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
205
+ """
206
+
207
+ def __init__(
208
+ self,
209
+ hidden_size: int,
210
+ num_heads: int,
211
+ mlp_ratio: float = 4.0,
212
+ qk_scale: float | None = None,
213
+ ):
214
+ super().__init__()
215
+ self.hidden_dim = hidden_size
216
+ self.num_heads = num_heads
217
+ head_dim = hidden_size // num_heads
218
+ self.scale = qk_scale or head_dim**-0.5
219
+
220
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
221
+ # qkv and mlp_in
222
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
223
+ # proj and mlp_out
224
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
225
+
226
+ self.norm = QKNorm(head_dim)
227
+
228
+ self.hidden_size = hidden_size
229
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
230
+
231
+ self.mlp_act = nn.GELU(approximate="tanh")
232
+ self.modulation = Modulation(hidden_size, double=False)
233
+
234
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
235
+ return self._forward(x, vec, pe)
236
+
237
+ def _forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
238
+ mod, _ = self.modulation(vec)
239
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
240
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
241
+
242
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
243
+ q, k = self.norm(q, k, v)
244
+
245
+ # compute attention
246
+ attn = attention(q, k, v, pe=pe)
247
+ # compute activation in mlp stream, cat again and run second linear layer
248
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
249
+ return x + mod.gate * output
250
+
251
+
252
+ class LastLayer(nn.Module):
253
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
254
+ super().__init__()
255
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
256
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
257
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
258
+
259
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
260
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
261
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
262
+ x = self.linear(x)
263
+ return x
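As a rough illustration of what `timestep_embedding` in `flowdis/layers.py` computes, the following dependency-free sketch applies the same formula to a single scalar timestep. The name `sinusoidal_embedding` is hypothetical and exists only in this example; the function in the commit operates on batched tensors.

```python
import math


def sinusoidal_embedding(t: float, dim: int, max_period: float = 10000.0,
                         time_factor: float = 1000.0) -> list[float]:
    """Scalar, pure-Python version of the sinusoidal timestep embedding."""
    t = time_factor * t
    half = dim // 2
    # Frequencies decay geometrically from 1 toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    # Cosines first, then sines, the same order used in timestep_embedding.
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:
        emb.append(0.0)  # pad odd dimensions with a single zero
    return emb


emb = sinusoidal_embedding(0.0, 8)
print(emb)
```

At `t = 0` every cosine is 1 and every sine is 0, which makes the layout of the two halves easy to see.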
flowdis/loaders.py ADDED
@@ -0,0 +1,75 @@
1
+
2
+ import torch
3
+ from safetensors.torch import load_file
4
+
5
+ from flowdis.autoencoder import AutoEncoder
6
+ from flowdis.conditioner import HFEmbedder
7
+ from flowdis.configs import configs
8
+ from flowdis.model import Flux, FluxParams
9
+
10
+
11
+ def load_transformer(
12
+ model_name: str,
13
+ model_path: str,
14
+ device: str | torch.device = "cuda",
15
+ config: FluxParams = None,
16
+ state_dict: dict = None,
17
+ ) -> Flux:
18
+ with torch.device("meta"):
19
+ model = Flux(config if config else configs[model_name]).to(dtype=torch.bfloat16)
20
+ model.to_empty(device="cpu")
21
+ if state_dict is None:
22
+ if str(model_path).endswith(".safetensors"):
23
+ state_dict = load_file(model_path, device="cpu")
24
+ else:
25
+ state_dict = torch.load(model_path, map_location="cpu")
26
+ model.load_state_dict(state_dict, assign=True, strict=False)
27
+ model = model.to(device=device, dtype=torch.bfloat16)
28
+ return model.eval()
29
+
30
+
31
+ def load_autoencoder(
32
+ model_path: str,
33
+ device: str | torch.device = "cuda"
34
+ ) -> AutoEncoder:
35
+ with torch.device("meta"):
36
+ ae = AutoEncoder(configs["autoencoder"])
37
+ ae.to_empty(device="cpu")
38
+ state_dict = load_file(model_path, device="cpu")
39
+ ae.load_state_dict(state_dict, assign=True, strict=False)
40
+ ae = ae.to(device=device, dtype=torch.bfloat16)
41
+ return ae.eval()
42
+
43
+
44
+ def load_t5(
45
+ model_path: str,
46
+ max_length: int = 512,
47
+ device: str | torch.device = "cuda"
48
+ ) -> HFEmbedder:
49
+ with torch.device("meta"):
50
+ t5 = HFEmbedder(
51
+ model_path.parent,
52
+ max_length=max_length,
53
+ is_clip=False,
54
+ dtype=torch.bfloat16
55
+ )
56
+ t5.to_empty(device="cpu")
57
+ state_dict = load_file(model_path, device="cpu")
58
+ t5.load_state_dict(state_dict, assign=True, strict=False)
59
+ return t5.to(device=device, dtype=torch.bfloat16)
60
+
61
+
62
+ def load_clip(
63
+ model_path: str,
64
+ device: str | torch.device = "cuda"
65
+ ) -> HFEmbedder:
66
+ clip = HFEmbedder(
67
+ model_path.parent,
68
+ max_length=77,
69
+ is_clip=True,
70
+ dtype=torch.bfloat16
71
+ )
72
+ state_dict = load_file(model_path, device="cpu")
73
+ clip.load_state_dict(state_dict, assign=True, strict=False)
74
+ return clip.to(device=device, dtype=torch.bfloat16)
75
+
flowdis/math.py ADDED
@@ -0,0 +1,30 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import Tensor
4
+
5
+
6
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
+ q, k = apply_rope(q, k, pe)
8
+
9
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
+ x = rearrange(x, "B H L D -> B L (H D)")
11
+
12
+ return x
13
+
14
+
15
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
16
+ assert dim % 2 == 0
17
+ scale = torch.arange(0, dim, 2, dtype=pos.dtype, device=pos.device) / dim
18
+ omega = 1.0 / (theta**scale)
19
+ out = torch.einsum("...n,d->...nd", pos, omega)
20
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
21
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
22
+ return out.float()
23
+
24
+
25
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
26
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
27
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
28
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
29
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
30
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
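The `rope` / `apply_rope` pair in `flowdis/math.py` rotates each consecutive pair of channels by an angle proportional to the token position. The sketch below restates that rotation in plain Python for one vector, so the relative-position property can be checked by hand. The helper names (`rope_angles`, `rotate`, `dot`) are hypothetical and only mirror the formulas, not the tensor layout.

```python
import math


def rope_angles(pos: float, dim: int, theta: float = 10000.0) -> list[float]:
    """Rotation angle for each of the dim/2 channel pairs at position `pos`."""
    assert dim % 2 == 0
    return [pos / theta ** (i / dim) for i in range(0, dim, 2)]


def rotate(x: list[float], pos: float, theta: float = 10000.0) -> list[float]:
    """Rotate consecutive channel pairs (x[2i], x[2i+1]) by their angle."""
    out = []
    for i, angle in enumerate(rope_angles(pos, len(x), theta)):
        x0, x1 = x[2 * i], x[2 * i + 1]
        c, s = math.cos(angle), math.sin(angle)
        # Same 2x2 matrix as in rope(): [[cos, -sin], [sin, cos]].
        out += [c * x0 - s * x1, s * x0 + c * x1]
    return out


def dot(a: list[float], b: list[float]) -> float:
    return sum(p * q for p, q in zip(a, b))


q = [1.0, 2.0, 3.0, 4.0]
k = [0.5, -1.0, 2.0, 0.0]
# The same offset (3) at two different absolute positions.
score_a = dot(rotate(q, 5), rotate(k, 2))
score_b = dot(rotate(q, 105), rotate(k, 102))
print(score_a, score_b)
```

The two scores agree up to floating-point error, because the attention logit between rotated queries and keys depends only on the difference of their positions.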
flowdis/model.py ADDED
@@ -0,0 +1,118 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+ from torch import Tensor, nn
5
+
6
+ from flowdis.layers import (
7
+ DoubleStreamBlock,
8
+ EmbedND,
9
+ LastLayer,
10
+ MLPEmbedder,
11
+ SingleStreamBlock,
12
+ timestep_embedding,
13
+ )
14
+
15
+
16
+ @dataclass
17
+ class FluxParams:
18
+ in_channels: int
19
+ out_channels: int
20
+ vec_in_dim: int
21
+ context_in_dim: int
22
+ hidden_size: int
23
+ mlp_ratio: float
24
+ num_heads: int
25
+ depth: int
26
+ depth_single_blocks: int
27
+ axes_dim: list[int]
28
+ theta: int
29
+ qkv_bias: bool
30
+ guidance_embed: bool
31
+
32
+
33
+ class Flux(nn.Module):
34
+ """
35
+ Transformer model for flow matching on sequences.
36
+ """
37
+
38
+ def __init__(self, params: FluxParams):
39
+ super().__init__()
40
+
41
+ self.params = params
42
+ self.in_channels = params.in_channels
43
+ self.out_channels = params.out_channels
44
+ if params.hidden_size % params.num_heads != 0:
45
+ raise ValueError(f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}")
46
+ pe_dim = params.hidden_size // params.num_heads
47
+ if sum(params.axes_dim) != pe_dim:
48
+ raise ValueError(f"Got {params.axes_dim} but expected positional dim {pe_dim}")
49
+ self.hidden_size = params.hidden_size
50
+ self.num_heads = params.num_heads
51
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
52
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
53
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
54
+ self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
55
+ self.guidance_in = (
56
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if params.guidance_embed else nn.Identity()
57
+ )
58
+ self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
59
+
60
+ self.double_blocks = nn.ModuleList(
61
+ [
62
+ DoubleStreamBlock(
63
+ self.hidden_size,
64
+ self.num_heads,
65
+ mlp_ratio=params.mlp_ratio,
66
+ qkv_bias=params.qkv_bias,
67
+ ) for _ in range(params.depth)
68
+ ]
69
+ )
70
+
71
+ self.single_blocks = nn.ModuleList(
72
+ [
73
+ SingleStreamBlock(
74
+ self.hidden_size,
75
+ self.num_heads,
76
+ mlp_ratio=params.mlp_ratio,
77
+ ) for _ in range(params.depth_single_blocks)
78
+ ]
79
+ )
80
+
81
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
82
+
83
+ def forward(
84
+ self,
85
+ img: Tensor,
86
+ img_ids: Tensor,
87
+ txt: Tensor,
88
+ txt_ids: Tensor,
89
+ timesteps: Tensor,
90
+ y: Tensor,
91
+ guidance: Tensor | None = None,
92
+ ) -> Tensor:
93
+ if img.ndim != 3 or txt.ndim != 3:
94
+ raise ValueError("Input img and txt tensors must have 3 dimensions.")
95
+
96
+ # project the packed image token sequence to the hidden size
97
+ img = self.img_in(img)
98
+ vec = self.time_in(timestep_embedding(timesteps, 256))
99
+ if self.params.guidance_embed:
100
+ if guidance is None:
101
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
102
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
103
+ vec = vec + self.vector_in(y)
104
+ txt = self.txt_in(txt)
105
+
106
+ ids = torch.cat((txt_ids, img_ids), dim=1)
107
+ pe = self.pe_embedder(ids)
108
+
109
+ for block in self.double_blocks:
110
+ img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
111
+
112
+ img = torch.cat((txt, img), 1)
113
+ for block in self.single_blocks:
114
+ img = block(img, vec=vec, pe=pe)
115
+ img = img[:, txt.shape[1] :, ...]
116
+
117
+ img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
118
+ return img
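The two consistency checks at the top of `Flux.__init__` tie the per-head width to the RoPE axis sizes. A minimal standalone sketch of that constraint follows. `check_rope_dims` is a hypothetical helper, and the `hidden_size` and `num_heads` values are assumed for illustration; only `axes_dim=[16, 56, 56]` is taken from the config in this commit.

```python
def check_rope_dims(hidden_size: int, num_heads: int, axes_dim: list[int]) -> int:
    """Mirror the divisibility and positional-dimension checks in Flux.__init__."""
    if hidden_size % num_heads != 0:
        raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
    pe_dim = hidden_size // num_heads
    if sum(axes_dim) != pe_dim:
        raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
    return pe_dim


# hidden_size and num_heads are assumptions made for this example only.
pe_dim = check_rope_dims(hidden_size=3072, num_heads=24, axes_dim=[16, 56, 56])
print(pe_dim)
```

Each entry of `axes_dim` must also be even, since `rope` asserts `dim % 2 == 0` before building its rotation pairs.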
flowdis/sampling.py ADDED
@@ -0,0 +1,136 @@
1
+ import math
2
+
3
+ import torch
4
+ import torchvision.transforms.functional as tvF
5
+ from einops import rearrange, repeat
6
+ from PIL import Image
7
+ from scipy import stats
8
+ from torch import Tensor
9
+
10
+ from flowdis.model import Flux
11
+ from flowdis.util import Models
12
+
13
+
14
+ def unpack(x: Tensor, height: int, width: int) -> Tensor:
15
+ return rearrange(
16
+ x,
17
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
18
+ h=math.ceil(height / 16),
19
+ w=math.ceil(width / 16),
20
+ ph=2,
21
+ pw=2,
22
+ )
23
+
24
+
25
+ def beta_scheduler(num_timesteps: int, alpha: float = 2.5, beta: float = 1.0) -> list[float]:
26
+ q = torch.linspace(1, 0, num_timesteps+1)
27
+ steps = stats.beta.ppf(q, alpha, beta).tolist()
28
+ if steps[-1] > 0.0:
29
+ steps.append(0.0)
30
+ return steps
31
+
32
+
33
+ def prepare(
34
+ img: Tensor,
35
+ prompt: str | list[str],
36
+ models: Models,
37
+ device: str = "cuda"
38
+ ) -> dict[str, Tensor]:
39
+ # encode the conditioning image into latents, then build position ids and text embeddings
40
+ bs, _, _, _ = img.shape
41
+ if bs == 1 and not isinstance(prompt, str):
42
+ bs = len(prompt)
43
+ if isinstance(prompt, str):
44
+ prompt = [prompt]
45
+
46
+ with torch.no_grad():
47
+ img = models.ae.encode(img.to(device=device, dtype=torch.bfloat16))
48
+ h, w = img.shape[2], img.shape[3]
49
+
50
+ img_ids = torch.zeros(h // 2, w // 2, 3)
51
+ img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
52
+ img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
53
+ img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
54
+
55
+ img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
56
+ if img.shape[0] == 1 and bs > 1:
57
+ img = repeat(img, "1 ... -> bs ...", bs=bs)
58
+
59
+ txt = models.t5(prompt)
60
+ if txt.shape[0] == 1 and bs > 1:
61
+ txt = repeat(txt, "1 ... -> bs ...", bs=bs)
62
+ txt_ids = torch.zeros(bs, txt.shape[1], 3)
63
+
64
+ vec = models.clip(prompt)
65
+ if vec.shape[0] == 1 and bs > 1:
66
+ vec = repeat(vec, "1 ... -> bs ...", bs=bs)
67
+
68
+ return_dict = {
69
+ "img": img,
70
+ "img_ids": img_ids.to(img.device),
71
+ "txt": txt.to(img.device),
72
+ "txt_ids": txt_ids.to(img.device),
73
+ "vec": vec.to(img.device),
74
+ }
75
+
76
+ return return_dict
77
+
78
+
79
+ def solve_flowdis_ode(
80
+ model: Flux,
81
+ img: Tensor,
82
+ img_ids: Tensor,
83
+ txt: Tensor,
84
+ txt_ids: Tensor,
85
+ vec: Tensor,
86
+ num_inference_steps: int,
87
+ ):
88
+ zt = img
89
+ timesteps = beta_scheduler(num_inference_steps)
90
+ for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
91
+ t_vec = torch.full((zt.shape[0],), t_curr, dtype=zt.dtype, device=zt.device)
92
+ pred = model(
93
+ img=torch.cat((zt, img), dim=-1),
94
+ img_ids=img_ids,
95
+ txt=txt,
96
+ txt_ids=txt_ids,
97
+ y=vec,
98
+ timesteps=t_vec,
99
+ )
100
+ zt = zt + (t_prev - t_curr) * pred
101
+ return zt
102
+
103
+
104
+ @torch.no_grad()
105
+ def flowdis_predict(
106
+ image: Image.Image,
107
+ prompt: str | list[str],
108
+ models: Models,
109
+ resolution: int = 1024,
110
+ num_inference_steps: int = 2,
111
+ device: str = "cuda",
112
+ ):
113
+ image_orig = image.convert("RGB")
114
+ image = image_orig.resize((resolution, resolution))
115
+
116
+ image_t = tvF.to_tensor(image).unsqueeze(0).to(device=device)
117
+ image_t = (image_t - 0.5) / 0.5
118
+
119
+ inp = prepare(image_t, prompt, models, device)
120
+
121
+ pred_mask_latent_t = solve_flowdis_ode(
122
+ models.transformer,
123
+ **inp,
124
+ num_inference_steps=num_inference_steps,
125
+ )
126
+
127
+ pred_mask_latent_t = unpack(pred_mask_latent_t.float(), resolution, resolution)
128
+ with torch.autocast(device_type=device, dtype=torch.bfloat16):
129
+ pred_mask_t = models.ae.decode(pred_mask_latent_t).clamp(-1, 1)
130
+
131
+ pred_mask_t = rearrange(pred_mask_t[0], "c h w -> h w c")
132
+ pred_mask_np = (127.5 * (pred_mask_t + 1.0)).mean(dim=-1).cpu().byte().numpy()
133
+ pred_mask = Image.fromarray(pred_mask_np).convert("L")
134
+ pred_mask = pred_mask.resize(image_orig.size)
135
+
136
+ return pred_mask
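To make the sampling loop easier to follow, here is a small sketch of the timestep schedule and the Euler update used by `solve_flowdis_ode`, written without torch or scipy. It relies on one fact: for the default `beta=1.0`, the Beta(alpha, 1) quantile function has the closed form `q ** (1 / alpha)`. The names `beta_schedule` and `euler_solve` and the constant toy velocity are assumptions for illustration; the real velocity comes from the transformer.

```python
def beta_schedule(num_steps: int, alpha: float = 2.5) -> list[float]:
    """Closed-form Beta(alpha, 1) quantiles, valid only for beta == 1."""
    qs = [1.0 - i / num_steps for i in range(num_steps + 1)]
    return [q ** (1.0 / alpha) for q in qs]


def euler_solve(z: float, velocity, timesteps: list[float]) -> float:
    """Same update rule as the sampling loop: z += (t_prev - t_curr) * v."""
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        z = z + (t_prev - t_curr) * velocity(z, t_curr)
    return z


steps = beta_schedule(2)
print(steps)
# Toy constant velocity field; a real model predicts this from (z, t).
result = euler_solve(3.0, lambda z, t: 2.0, steps)
print(result)
```

With the default of two inference steps the schedule runs from 1 to 0 with one intermediate point, so the first step covers a shorter interval in `t` than the second. For a constant velocity the Euler result does not depend on where that intermediate point falls.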
flowdis/util.py ADDED
@@ -0,0 +1,116 @@
1
+ import logging
2
+ from copy import deepcopy
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
+ import torch
7
+ import numpy as np
8
+ from huggingface_hub import snapshot_download
9
+ from safetensors.torch import load_file
10
+ from flowdis.autoencoder import AutoEncoder
11
+ from flowdis.conditioner import HFEmbedder
12
+ from flowdis.configs import configs
13
+ from flowdis.loaders import load_autoencoder, load_clip, load_t5, load_transformer
14
+ from flowdis.model import Flux
15
+
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ @dataclass
21
+ class Models:
22
+ clip: HFEmbedder
23
+ t5: HFEmbedder
24
+ ae: AutoEncoder
25
+ transformer: Flux
26
+
27
+
28
+ def load_models(
29
+ root_model_dir: Path = None,
30
+ device: str | torch.device = "cuda"
31
+ ) -> Models:
32
+ """
33
+ Load the models for the FlowDIS pipeline.
34
+
35
+ Args:
36
+ root_model_dir: The root model directory.
37
+ If None, the models are downloaded from the Hugging Face Hub.
38
+ device: The device to load the models on.
39
+
40
+ Returns:
41
+ Models: The loaded models.
42
+ """
43
+ if root_model_dir is None:
44
+ root_model_dir = download_from_hf_hub("PAIR/FlowDIS")
45
+
46
+ logger.info("Loading T5.")
47
+ t5 = load_t5(
48
+ model_path=root_model_dir / "t5-v1_1-xxl" / "model.safetensors",
49
+ device=device,
50
+ max_length=512
51
+ )
52
+
53
+ logger.info("Loading CLIP.")
54
+ clip = load_clip(
55
+ model_path=root_model_dir / "clip-vit-large-patch14" / "model.safetensors",
56
+ device=device
57
+ )
58
+
59
+ logger.info("Loading AE.")
60
+ ae = load_autoencoder(
61
+ model_path=root_model_dir / "ae.safetensors",
62
+ device=device
63
+ )
64
+
65
+ logger.info("Loading Transformer.")
66
+ model = load_transformer(
67
+ model_name="flowdis",
68
+ model_path=root_model_dir / "flowdis-transformer.safetensors",
69
+ device=device,
70
+ )
71
+
72
+ logger.info("All models loaded.")
73
+
74
+ return Models(
75
+ clip=clip,
76
+ t5=t5,
77
+ ae=ae,
78
+ transformer=model,
79
+ )
80
+
81
+
82
+ def download_from_hf_hub(
83
+ repo_id: str,
84
+ cache_dir: str | Path | None = None,
85
+ revision: str | None = None,
86
+ ) -> Path:
87
+ """
88
+ Download a FlowDIS model repository from the Hugging Face Hub.
89
+
90
+ Args:
91
+ repo_id: The Hugging Face Hub repo id (e.g. "PAIR/FlowDIS").
92
+ cache_dir: Optional cache directory. Defaults to the huggingface_hub
93
+ default (typically ~/.cache/huggingface/hub).
94
+ revision: Optional git revision (branch, tag, or commit SHA).
95
+
96
+ Returns:
97
+ Path to the local directory containing the downloaded snapshot. The
98
+ directory layout matches the repo layout on the Hub, so it can be
99
+ passed directly to `load_models` as `root_model_dir`.
100
+ """
101
+ logger.info(f"Downloading {repo_id} from Hugging Face Hub.")
102
+ local_dir = snapshot_download(
103
+ repo_id=repo_id,
104
+ cache_dir=cache_dir,
105
+ revision=revision,
106
+ )
107
+ logger.info(f"Snapshot available at {local_dir}.")
108
+ return Path(local_dir)
109
+
110
+
111
+ def green_screen(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
112
+ img_np = np.array(img)
113
+ mask = (np.array(mask) / 255)[:, :, np.newaxis].repeat(3, axis=2)
114
+ combined = img_np * mask + (1-mask) * np.array([0, 255, 0], dtype=np.uint8)
115
+ combined = combined.astype(np.uint8)
116
+ return combined
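The blend performed by `green_screen` is ordinary alpha compositing with the predicted mask as the alpha channel. The per-pixel sketch below shows the same arithmetic without NumPy. `composite_pixel` is a hypothetical helper used only for this illustration.

```python
def composite_pixel(rgb: tuple[int, int, int], mask_value: int,
                    background: tuple[int, int, int] = (0, 255, 0)) -> tuple[int, ...]:
    """Blend one RGB pixel over a background colour using an 8-bit mask value."""
    alpha = mask_value / 255  # 0.0 keeps only the background, 1.0 only the foreground
    # int() truncates the blended value, as the uint8 cast does for in-range values.
    return tuple(int(c * alpha + (1 - alpha) * b) for c, b in zip(rgb, background))


foreground = composite_pixel((200, 100, 50), 255)
background = composite_pixel((200, 100, 50), 0)
print(foreground, background)
```

Intermediate mask values give a proportional mix, which is what produces soft edges around fine structures such as hair.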
pyproject.toml ADDED
@@ -0,0 +1,50 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "flowdis"
7
+ version = "0.1.0"
8
+ description = "FlowDIS: Language-Guided Dichotomous Image Segmentation with Flow Matching"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+ authors = [
13
+ { name = "Andranik Sargsyan" },
14
+ { name = "Shant Navasardyan" },
15
+ ]
16
+ keywords = ["segmentation", "flow-matching", "background removal", "deep-learning"]
17
+ classifiers = [
18
+ "Development Status :: 3 - Alpha",
19
+ "Intended Audience :: Science/Research",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Programming Language :: Python :: 3",
22
+ "Programming Language :: Python :: 3.10",
23
+ "Programming Language :: Python :: 3.11",
24
+ "Programming Language :: Python :: 3.12",
25
+ "Topic :: Scientific/Engineering :: Image Recognition",
26
+ ]
27
+ dependencies = [
28
+ "accelerate>=1.12.0,<2.0",
29
+ "einops>=0.8.2,<1.0",
30
+ "gradio==6.3.0",
31
+ "numpy>=1.24.0,<2.0",
32
+ "opencv-python>=4.11.0,<5.0",
33
+ "Pillow>=10.0.0,<11.0",
34
+ "safetensors>=0.7.0,<1.0",
35
+ "scipy>=1.17.1,<2.0",
36
+ "sentencepiece>=0.2.1,<1.0",
37
+ "tiktoken>=0.12.0,<1.0",
38
+ "torch>=2.8.0,<=2.10",
39
+ "torchvision>=0.25.0",
40
+ "transformers>=4.39.0,<5.0",
41
+ ]
42
+
43
+ [project.optional-dependencies]
44
+ dev = [
45
+ "pytest>=7.0",
46
+ "ruff>=0.1.0",
47
+ ]
48
+
49
+ [tool.setuptools]
50
+ packages = ["flowdis"]
qwen.py ADDED
@@ -0,0 +1,73 @@
1
+ import logging
2
+
3
+ import torch
4
+ from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
5
+ from PIL import Image
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+ # Load model if GPU is available
10
+ model = None
11
+ processor = None
12
+ if torch.cuda.is_available():
13
+ logger.info("Loading Qwen3VL model.")
14
+ model = Qwen3VLForConditionalGeneration.from_pretrained(
15
+ "Qwen/Qwen3-VL-4B-Instruct",
16
+ dtype=torch.bfloat16,
17
+ device_map="auto"
18
+ )
19
+ processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
20
+ logger.info("Qwen3VL model loaded.")
21
+ else:
22
+ logger.info("Qwen3VL was not loaded because no GPU is available.")
23
+
24
+
25
+ def expand_prompt(image: Image.Image, user_prompt: str) -> str:
26
+ """
27
+ Expand the user prompt using the Qwen3VL model.
28
+
29
+ Args:
30
+ image: The image to use for the prompt expansion.
31
+ user_prompt: The user prompt to expand.
32
+
33
+ Returns:
34
+ The expanded prompt.
35
+ """
36
+ messages = [
37
+ {
38
+ "role": "user",
39
+ "content": [
40
+ {"type": "image"},
41
+ {"type": "text", "text": f"Describe the {user_prompt} in this image with a short prompt. Don't use surrounding objects in the description. Also don't describe the background, like what it is sitting on or what it is on top of, etc..."}
42
+ ]
43
+ }
44
+ ]
45
+
46
+ text = processor.apply_chat_template(
47
+ messages,
48
+ tokenize=False,
49
+ add_generation_prompt=True
50
+ )
51
+
52
+ inputs = processor(
53
+ text=[text],
54
+ images=[image],
55
+ padding=True,
56
+ return_tensors="pt"
57
+ )
58
+
59
+ inputs = inputs.to(model.device)
60
+
61
+ with torch.no_grad():
62
+ generated_ids = model.generate(
63
+ **inputs,
64
+ max_new_tokens=512
65
+ )
66
+ generated_ids_trimmed = generated_ids[:, inputs["input_ids"].shape[1]:]
67
+
68
+ output_text = processor.batch_decode(
69
+ generated_ids_trimmed,
70
+ skip_special_tokens=True
71
+ )[0]
72
+
73
+ return output_text
requirements.txt ADDED
@@ -0,0 +1,13 @@
1
+ accelerate>=1.12.0,<2.0
2
+ einops>=0.8.2,<1.0
3
+ gradio==6.3.0
4
+ numpy>=1.24.0,<2.0
5
+ opencv-python>=4.11.0,<5.0
6
+ Pillow>=10.0.0,<11.0
7
+ safetensors>=0.7.0,<1.0
8
+ scipy>=1.17.1,<2.0
9
+ sentencepiece>=0.2.1,<1.0
10
+ tiktoken>=0.12.0,<1.0
11
+ torch>=2.8.0,<=2.10
12
+ torchvision>=0.25.0
13
+ transformers>=4.39.0,<5.0