va1bhavagrawa1 committed
Commit c714a7e · 0 Parent(s)

first commit
.gitattributes ADDED
@@ -0,0 +1,3 @@
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.webp filter=lfs diff=lfs merge=lfs -text
Image0001.png ADDED

Git LFS Details

  • SHA256: 823adf77d40fd5e79684be532cb5c3dd5428665c15aa9883d3eac4538bf3a787
  • Pointer size: 131 Bytes
  • Size of remote file: 521 kB
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
app.py ADDED
@@ -0,0 +1,1504 @@
+ import os
+ import os.path as osp
+ import sys
+ import numpy as np
+ import tempfile
+ import shutil
+ import base64
+ import io
+ from PIL import Image
+ import gradio as gr
+ import time
+ import copy
+ import requests
+ import json
+ import pickle
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from object_scales import scales
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+ from datetime import datetime
+ from infer_backend import initialize_inference_engine, run_inference_from_gradio
+
+ COLORS = [
+     (1.0, 0.0, 0.0),  # Red
+     (0.0, 0.8, 0.2),  # Green
+     (0.0, 0.0, 1.0),  # Blue
+     (1.0, 1.0, 0.0),  # Yellow
+     (0.0, 1.0, 1.0),  # Cyan
+     (1.0, 0.0, 1.0),  # Magenta
+     (1.0, 0.6, 0.0),  # Orange
+     (0.6, 0.0, 0.8),  # Purple
+     (0.0, 0.4, 0.0),  # Dark Green
+     (0.8, 0.8, 0.8),  # Light Gray
+     (0.2, 0.2, 0.2),  # Dark Gray
+ ]
+
+ CHECKPOINT_NAMES = [
+     "rgb__r1/epoch-0__checkpoint-25917",
+     "rgb__finetune_1024/epoch-0__checkpoint-3000",
+     "rgb__finetune_1024/epoch-1__checkpoint-4000",
+     "rgb__finetune_1024/epoch-1__checkpoint-5000",
+     "rgb__finetune_1024/epoch-1__checkpoint-6000",
+     "rgb__finetune_1024/epoch-1__checkpoint-7000",
+     "rgb__finetune_1024/epoch-1__checkpoint-7932",
+ ]
+
+ PRETRAINED_MODEL_NAME_OR_PATH = "black-forest-labs/FLUX.1-dev"
+
+ tokenizer = T5TokenizerFast.from_pretrained(
+     PRETRAINED_MODEL_NAME_OR_PATH,
+     subfolder="tokenizer_2",
+     revision=None,
+ )
+
+ placeholder_token_str = ["<placeholder>"]
+ num_added_tokens = tokenizer.add_tokens(placeholder_token_str)
+ assert num_added_tokens == 1
+
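+ # generate_image_event drives the full pipeline: update the inference params,
+ # render the camera-view condition and per-object segmentation masks via the
+ # Blender render servers, build the placeholder prompt and call ids, write
+ # cuboids.jsonl, and run diffusion inference on the result.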
+ def generate_image_event(camera_elevation, camera_lens, surrounding_prompt, checkpoint_name,
60
+ height, width, seed, guidance_scale, num_steps):
61
+ """Generate final image with segmentation masks and run inference"""
62
+ # Update scene manager's inference params before generation
63
+ scene_manager.update_inference_params(height, width, seed, guidance_scale, num_steps, checkpoint_name)
64
+ if not scene_manager.objects:
65
+ return (
66
+ "⚠️ No objects to render",
67
+ gr.update(),
68
+ Image.new('RGB', (512, 512), color='white')
69
+ )
70
+
71
+ # Get subject descriptions
72
+ subject_descriptions = [obj['description'] for obj in scene_manager.objects]
73
+
74
+ print(f"Surrounding prompt: {surrounding_prompt}")
75
+ print(f"Subject descriptions: {subject_descriptions}")
76
+ print(f"Selected checkpoint: {checkpoint_name}")
77
+
78
+ placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt
79
+
80
+ # Create placeholder text
81
+ subject_embeds = []
82
+ for subject_idx, subject_desc in enumerate(subject_descriptions):
83
+ input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0]
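+ # NOTE: without truncation=True, max_length=77 does not actually truncate here;
+ # very long subject descriptions are encoded in full (the tokenizer only warns).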
84
+ subject_embed = {"input_ids_t5": input_ids.tolist()}
85
+ subject_embeds.append(subject_embed)
86
+
87
+ placeholder_text = ""
88
+ for subject in subject_descriptions[:-1]:
89
+ placeholder_text = placeholder_text + f"<placeholder> {subject} and "
90
+ for subject in subject_descriptions[-1:]:
91
+ placeholder_text = placeholder_text + f"<placeholder> {subject}"
92
+ placeholder_text = placeholder_text.strip()
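+ # e.g. for subjects ["a red car", "a dog"] (illustrative), placeholder_text becomes
+ # "<placeholder> a red car and <placeholder> a dog"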
93
+
94
+ placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
95
+
96
+ call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt,
97
+ subjects=subject_descriptions,
98
+ subjects_embeds=subject_embeds,
99
+ debug=True
100
+ )
101
+ print(f"Generated call IDs: {call_ids}")
102
+
103
+ # Convert to server expected format
104
+ subjects_data, camera_data = scene_manager._convert_to_blender_format()
105
+
106
+ # Render final high-quality image using CYCLES (port 5002)
107
+ final_img = scene_manager.render_client._send_render_request(
108
+ scene_manager.render_client.final_server_url,
109
+ subjects_data,
110
+ camera_data
111
+ )
112
+
113
+ final_img.save("model_condition.jpg")
114
+
115
+ # Render segmentation masks
116
+ success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data)
117
+
118
+ if not success:
119
+ return (
120
+ f"❌ Failed to render segmentation masks: {error_msg}",
121
+ gr.update(),
122
+ Image.new('RGB', (512, 512), color='white')
123
+ )
124
+
125
+ # Save all files to the correct location
126
+ root_save_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/"
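+ # NOTE: root_save_dir is hard-coded to the author's environment; point it at a
+ # writable local directory when running this demo elsewhere.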
127
+ os.system(f"rm -f {root_save_dir}/*")
128
+
129
+ # Save final render to root directory
130
+ final_render_path = osp.join(root_save_dir, "cv_render.jpg")
131
+ final_img.save(final_render_path)
132
+
133
+ # Move segmentation masks
134
+ for subject_idx in range(len(subject_descriptions)):
135
+ shutil.move(
136
+ f"{str(subject_idx).zfill(3)}_segmask_cv.png",
137
+ osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png")
138
+ )
139
+
140
+ # Create JSONL
141
+ jsonl = [{
142
+ "cv": final_render_path,
143
+ "target": final_render_path,
144
+ "cuboids_segmasks": [
145
+ osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png")
146
+ for subject_idx in range(len(subject_descriptions))
147
+ ],
148
+ "PLACEHOLDER_prompts": placeholder_prompt,
149
+ "subjects": subject_descriptions,
150
+ "call_ids": call_ids,
151
+ }]
152
+
153
+ jsonl_path = osp.join(root_save_dir, "cuboids.jsonl")
154
+ with open(jsonl_path, "w") as f:
155
+ json.dump(jsonl[0], f)
156
+
157
+ # Run inference using the pre-loaded model
158
+ print(f"\n{'='*60}")
159
+ print(f"RUNNING INFERENCE")
160
+ print(f"{'='*60}\n")
161
+
162
+ inference_success, generated_image, inference_msg = run_inference_from_gradio(
163
+ checkpoint_name=checkpoint_name,
164
+ height=height,
165
+ width=width,
166
+ seed=seed,
167
+ guidance_scale=guidance_scale,
168
+ num_inference_steps=num_steps,
169
+ jsonl_path=jsonl_path
170
+ )
171
+
172
+ if not inference_success:
173
+ return (
174
+ f"✅ Saved files but inference failed: {inference_msg}",
175
+ final_img,
176
+ Image.new('RGB', (512, 512), color='white')
177
+ )
178
+
179
+ status_msg = f"✅ Generated image using {checkpoint_name} with {len(segmask_images)} segmentation masks"
180
+
181
+ # Re-render the scene with the paper-figure server (port 5004) for the displayed camera view
182
+ final_img = scene_manager.render_client._send_render_request(
183
+ scene_manager.render_client.paper_figure_server_url,
184
+ subjects_data,
185
+ camera_data
186
+ )
187
+
188
+ return (
189
+ status_msg,
190
+ final_img, # Display CV render in Camera View
191
+ generated_image # Display generated image in Generated Image section
192
+ )
193
+
194
+
195
+ def get_call_ids_from_placeholder_prompt_flux(prompt: str, subjects, subjects_embeds: list, debug: bool):
196
+ assert prompt.find("<placeholder>") != -1, "Prompt must contain <placeholder> to get call ids"
197
+
198
+ # the placeholder token ID for all the tokenizers
199
+ placeholder_token_three = tokenizer.encode("<placeholder>", return_tensors="pt")[0][:-1].item()
200
+ prompt_tokens_three = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
201
+
202
+ placeholder_token_locations_three = [i for i, w in enumerate(prompt_tokens_three) if w == placeholder_token_three]
203
+ prompt = prompt.replace("<placeholder> ", "")
204
+
205
+
206
+ call_ids = []
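+ # For each subject, collect the indices its T5 tokens occupy in the prompt once the
+ # "<placeholder>" markers are stripped; these indices are the subject's call ids.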
207
+ for subject_idx, (subject, subject_embed) in enumerate(zip(subjects, subjects_embeds)):
208
+ subject_prompt_ids_t5 = subject_embed["input_ids_t5"][:-1] # T5 has SOT token only
209
+ num_t5_tokens_subject = len(subject_prompt_ids_t5)
210
+
211
+ t5_call_ids_subject = [i + placeholder_token_locations_three[subject_idx] - 2 * subject_idx - 1 for i in range(num_t5_tokens_subject)]
212
+ call_ids.append(t5_call_ids_subject)
213
+
214
+ prompt_wo_placeholder = prompt.replace("<placeholder> ", "")
215
+ t5_call_strs = tokenizer.batch_decode(tokenizer.encode(prompt_wo_placeholder, return_tensors="pt")[0].tolist())
216
+ t5_call_strs = [t5_call_strs[i] for i in t5_call_ids_subject]
217
+ if debug:
218
+ print(f"{prompt = }, t5 CALL strs for {subject} = {t5_call_strs}")
219
+
220
+ return call_ids
221
+
222
+
223
+ def map_point_to_rgb(x, y):
224
+ """
225
+ Map (x, y) inside the frustum to an RGB color with continuity and variation.
226
+ """
227
+ # Frustum boundaries
228
+ X_MIN, X_MAX = -10.0, -1.0
229
+ Y_MIN_AT_XMIN, Y_MAX_AT_XMIN = -4.5, 4.5
230
+ Y_MIN_AT_XMAX, Y_MAX_AT_XMAX = -0.5, 0.5
231
+
232
+ # Normalize x to [0, 1]
233
+ x_norm = (x - X_MIN) / (X_MAX - X_MIN)
234
+ # x_norm = np.clip(x_norm, 0, 1)
235
+
236
+ # Compute current Y bounds at given x using linear interpolation
237
+ y_min = Y_MIN_AT_XMIN + x_norm * (Y_MIN_AT_XMAX - Y_MIN_AT_XMIN)
238
+ y_max = Y_MAX_AT_XMIN + x_norm * (Y_MAX_AT_XMAX - Y_MAX_AT_XMIN)
239
+
240
+ # Normalize y to [0, 1] within current bounds
241
+ if y_max != y_min:
242
+ y_norm = (y - y_min) / (y_max - y_min)
243
+ else:
244
+ y_norm = 0.5
245
+ y_norm = np.clip(y_norm, 0.0, 1.0)
246
+
247
+ # Color mapping: more variation along x
248
+ r = x_norm
249
+ g = y_norm
250
+ b = 1.0 - x_norm
251
+
252
+ return (r, g, b)
253
+
254
+ def rgb_to_hex(rgb_tuple):
255
+ """Convert RGB tuple (0-1 range) to hex color string."""
256
+ r, g, b = rgb_tuple
257
+ return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
258
+
259
+
260
+ class BlenderRenderClient:
261
+ def __init__(self, cv_server_url="http://127.0.0.1:5001", segmask_server_url="http://127.0.0.1:5003", final_server_url="http://127.0.0.1:5002", paper_figure_server_url="http://127.0.0.1:5004"):
262
+ """
263
+ Initialize the Blender render client.
264
+
265
+ Args:
266
+ cv_server_url (str): URL of the camera view render server
267
+ segmask_server_url (str): URL of the segmentation mask render server
268
+ """
269
+ self.cv_server_url = cv_server_url
270
+ self.segmask_server_url = segmask_server_url
271
+ self.final_server_url = final_server_url
272
+ self.paper_figure_server_url = paper_figure_server_url
273
+ self.timeout = 30 # 30 second timeout for renders
274
+
275
+ def render_segmasks(self, subjects_data: list, camera_data: dict) -> tuple:
276
+ """
277
+ Send a segmentation mask render request.
278
+ Returns (success: bool, segmask_images: list of PIL Images or None, error_message: str or None)
279
+ """
280
+ try:
281
+ request_data = {
282
+ "subjects_data": subjects_data,
283
+ "camera_data": camera_data,
284
+ "num_samples": 1
285
+ }
286
+
287
+ response = requests.post(
288
+ f"{self.segmask_server_url}/render_segmasks",
289
+ json=request_data,
290
+ timeout=self.timeout
291
+ )
292
+
293
+ if response.status_code == 200:
294
+ result = response.json()
295
+ if result["success"]:
296
+ # Decode all segmentation masks
297
+ segmask_images = []
298
+ for img_base64 in result["segmasks_base64"]:
299
+ img_data = base64.b64decode(img_base64)
300
+ img = Image.open(io.BytesIO(img_data))
301
+ segmask_images.append(img)
302
+
303
+ print(f"Successfully rendered {len(segmask_images)} segmentation masks")
304
+ return True, segmask_images, None
305
+ else:
306
+ error_msg = result.get('error_message', 'Unknown error')
307
+ print(f"Segmask render failed: {error_msg}")
308
+ return False, None, error_msg
309
+ else:
310
+ error_msg = f"HTTP error {response.status_code}: {response.text}"
311
+ print(error_msg)
312
+ return False, None, error_msg
313
+
314
+ except requests.exceptions.Timeout:
315
+ error_msg = "Segmask render request timed out"
316
+ print(error_msg)
317
+ return False, None, error_msg
318
+ except Exception as e:
319
+ error_msg = f"Segmask render request failed: {e}"
320
+ print(error_msg)
321
+ return False, None, error_msg
322
+
323
+
324
+ def _send_render_request(self, server_url: str, subjects_data: list, camera_data: dict) -> Image.Image:
325
+ """Send a render request to a server and return the image."""
326
+ try:
327
+ request_data = {
328
+ "subjects_data": subjects_data,
329
+ "camera_data": camera_data,
330
+ "num_samples": 1
331
+ }
332
+ print(f"passing {subjects_data = } to server at {server_url}")
333
+
334
+ response = requests.post(
335
+ f"{server_url}/render",
336
+ json=request_data,
337
+ timeout=self.timeout
338
+ )
339
+
340
+ if response.status_code == 200:
341
+ result = response.json()
342
+ if result["success"]:
343
+ # Decode base64 image
344
+ img_data = base64.b64decode(result["image_base64"])
345
+ img = Image.open(io.BytesIO(img_data))
346
+ return img
347
+ else:
348
+ print(f"Render failed: {result.get('error_message', 'Unknown error')}")
349
+ return self._create_error_image("red")
350
+ else:
351
+ print(f"HTTP error {response.status_code}: {response.text}")
352
+ return self._create_error_image("orange")
353
+
354
+ except requests.exceptions.Timeout:
355
+ print("Render request timed out")
356
+ return self._create_error_image("yellow")
357
+ except Exception as e:
358
+ print(f"Render request failed: {e}")
359
+ return self._create_error_image("red")
360
+
361
+ def _create_error_image(self, color: str) -> Image.Image:
362
+ """Create a colored error image."""
363
+ return Image.new('RGB', (512, 512), color=color)
364
+
365
+ # --- Scene Management Class ---
366
+ class SceneManager:
367
+ def __init__(self):
368
+ self.objects = []
369
+ self.camera_elevation = 30.0
370
+ self.camera_lens = 50.0
371
+ self.surrounding_prompt = ""
372
+ self.next_color_idx = 0
373
+ self.colors = [
374
+ (1.0, 0.0, 0.0), # red
375
+ (0.0, 0.0, 1.0), # blue
376
+ (0.0, 1.0, 0.0), # green
377
+ (0.5, 0.0, 0.5), # purple
378
+ (1.0, 0.5, 0.0), # orange
379
+ (1.0, 1.0, 0.0), # yellow
380
+ (0.0, 1.0, 1.0), # cyan
381
+ (1.0, 0.0, 1.0), # magenta
382
+ ]
383
+
384
+ # Add inference parameters with defaults
385
+ self.inference_params = {
386
+ 'height': 512,
387
+ 'width': 512,
388
+ 'seed': 42,
389
+ 'guidance_scale': 3.5,
390
+ 'num_inference_steps': 25,
391
+ 'checkpoint': CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None
392
+ }
393
+
394
+ # Initialize BlenderRenderClient
395
+ self.render_client = BlenderRenderClient()
396
+
397
+ # Load asset dimensions
398
+ self.asset_dimensions = self._load_asset_dimensions()
399
+
400
+
401
+ def update_inference_params(self, height, width, seed, guidance_scale, num_steps, checkpoint):
402
+ """Update inference parameters"""
403
+ self.inference_params = {
404
+ 'height': height,
405
+ 'width': width,
406
+ 'seed': seed,
407
+ 'guidance_scale': guidance_scale,
408
+ 'num_inference_steps': num_steps,
409
+ 'checkpoint': checkpoint
410
+ }
411
+
412
+
413
+ def update_cuboid_description(self, obj_id, new_description):
414
+ """Update the description of a cuboid"""
415
+ if 0 <= obj_id < len(self.objects):
416
+ if new_description.strip(): # Check not empty
417
+ self.objects[obj_id]['description'] = new_description.strip()
418
+ return True
419
+ return False
420
+
421
+
422
+ def save_scene_to_pkl(self, filepath=None):
423
+ """Save current scene data to pkl file including inference parameters"""
424
+ if filepath is None:
425
+ # Auto-generate filename with timestamp
426
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
427
+ filepath = f"scene_{timestamp}.pkl"
428
+
429
+ # Convert to the expected format
430
+ subjects_data = []
431
+ for obj in self.objects:
432
+ subject_dict = {
433
+ 'name': obj['description'],
434
+ 'type': obj['type'], # Save the object type
435
+ 'dims': tuple(obj['size']), # (width, depth, height)
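+ # NOTE: x is shifted by -6.0 on save and by +6.0 in load_scene_from_pkl; this
+ # appears to translate between the UI frame and the saved scene frame.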
436
+ 'x': [obj['position'][0] - 6.0],
437
+ 'y': [obj['position'][1]],
438
+ 'z': [obj['position'][2]],
439
+ 'azimuth': [np.radians(obj['azimuth'])], # Convert to radians
440
+ 'bbox': [(0, 0, 0, 0)] # Placeholder, can be computed if needed
441
+ }
442
+ subjects_data.append(subject_dict)
443
+
444
+ camera_data = {
445
+ 'camera_elevation': np.radians(self.camera_elevation),
446
+ 'lens': self.camera_lens,
447
+ 'global_scale': 1.0 # Default value
448
+ }
449
+
450
+ scene_dict = {
451
+ 'subjects_data': subjects_data,
452
+ 'camera_data': camera_data,
453
+ 'surrounding_prompt': self.surrounding_prompt,
454
+ 'inference_params': self.inference_params.copy()
455
+ }
456
+
457
+ try:
458
+ with open(filepath, 'wb') as f:
459
+ pickle.dump(scene_dict, f)
460
+ return True, filepath, None
461
+ except Exception as e:
462
+ return False, None, str(e)
463
+
464
+
465
+ def load_scene_from_pkl(self, filepath):
466
+ """Load scene data from pkl file including inference parameters"""
467
+ try:
468
+ with open(filepath, 'rb') as f:
469
+ scene_dict = pickle.load(f)
470
+
471
+ # Clear existing objects
472
+ self.objects = []
473
+ self.next_color_idx = 0
474
+
475
+ # Load subjects
476
+ subjects_data = scene_dict.get('subjects_data', [])
477
+ for subject_dict in subjects_data:
478
+ name = subject_dict.get('name', 'Loaded Object')
479
+ asset_type = subject_dict.get('type', 'Custom') # Load the type
480
+ dims = subject_dict.get('dims', (1.0, 1.0, 1.0))
481
+ x = float(subject_dict.get('x', [0.0])[0]) + 6.0
482
+ y = float(subject_dict.get('y', [0.0])[0])
483
+ z = float(subject_dict.get('z', [0.0])[0])
484
+ azimuth_rad = float(subject_dict.get('azimuth', [0.0])[0])
485
+ azimuth_deg = np.degrees(azimuth_rad)
486
+
487
+ # Determine original_asset_size based on type
488
+ if asset_type == "Custom" or asset_type not in self.asset_dimensions:
489
+ original_asset_size = None
490
+ else:
491
+ # Look up the original asset dimensions
492
+ asset_dims = self.asset_dimensions[asset_type]
493
+ original_asset_size = [float(asset_dims[0]), float(asset_dims[1]), float(asset_dims[2])]
494
+
495
+ # Create object
496
+ obj_id = len(self.objects)
497
+ size_list = [float(d) for d in dims]
498
+ cuboid = {
499
+ 'id': obj_id,
500
+ 'description': name,
501
+ 'type': asset_type, # Use the loaded type
502
+ 'position': [x, y, z],
503
+ 'size': size_list,
504
+ 'original_asset_size': original_asset_size, # Restore from asset_dimensions
505
+ 'azimuth': float(azimuth_deg),
506
+ 'color': self._get_next_color()
507
+ }
508
+ self.objects.append(cuboid)
509
+
510
+ # Load camera settings
511
+ camera_data = scene_dict.get('camera_data', {})
512
+ camera_elev_rad = float(camera_data.get('camera_elevation', np.radians(30.0)))
513
+ self.camera_elevation = float(np.degrees(camera_elev_rad))
514
+ self.camera_lens = float(camera_data.get('lens', 50.0))
515
+
516
+ # Load surrounding prompt
517
+ self.surrounding_prompt = scene_dict.get('surrounding_prompt', '')
518
+
519
+ # Load inference parameters
520
+ loaded_inference_params = scene_dict.get('inference_params', {})
521
+
522
+ # Get checkpoint, fall back to first available if not found
523
+ saved_checkpoint = loaded_inference_params.get('checkpoint')
524
+ if saved_checkpoint and saved_checkpoint in CHECKPOINT_NAMES:
525
+ checkpoint = saved_checkpoint
526
+ else:
527
+ checkpoint = CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None
528
+ if saved_checkpoint:
529
+ print(f"Warning: Saved checkpoint '{saved_checkpoint}' not found, using '{checkpoint}' instead")
530
+
531
+ self.inference_params = {
532
+ 'height': loaded_inference_params.get('height', 512),
533
+ 'width': loaded_inference_params.get('width', 512),
534
+ 'seed': loaded_inference_params.get('seed', 42),
535
+ 'guidance_scale': loaded_inference_params.get('guidance_scale', 3.5),
536
+ 'num_inference_steps': loaded_inference_params.get('num_inference_steps', 25),
537
+ 'checkpoint': checkpoint
538
+ }
539
+
540
+ return True, len(subjects_data), None
541
+ except FileNotFoundError:
542
+ return False, 0, f"File not found: {filepath}"
543
+ except Exception as e:
544
+ return False, 0, f"Error loading file: {str(e)}"
545
+
546
+
547
+ def _load_asset_dimensions(self):
548
+ """Load asset dimensions from pickle file"""
549
+ pkl_path = "asset_dimensions.pkl"
550
+ if os.path.exists(pkl_path):
551
+ try:
552
+ with open(pkl_path, 'rb') as f:
553
+ return pickle.load(f)
554
+ except Exception as e:
555
+ print(f"Warning: Could not load asset dimensions: {e}")
556
+ return {}
557
+ else:
558
+ print(f"Warning: asset_dimensions.pkl not found at {pkl_path}")
559
+ return {}
560
+
561
+ def get_asset_type_choices(self):
562
+ """Get list of asset types for dropdown"""
563
+ choices = ["Custom"]
564
+ if self.asset_dimensions:
565
+ choices.extend(sorted(self.asset_dimensions.keys()))
566
+ return choices
567
+
568
+ def _get_next_color(self):
569
+ color = self.colors[self.next_color_idx % len(self.colors)]
570
+ self.next_color_idx += 1
571
+ return color
572
+
573
+
574
+ def harmonize_scales(self):
575
+ """
576
+ Harmonize the scales of all non-Custom objects based on object scales.
577
+ Always scales from original asset dimensions, ignoring any manual edits.
578
+ Custom objects remain unchanged.
579
+ """
580
+ if not self.objects:
581
+ return "No objects to harmonize"
582
+
583
+ # Find objects that can be harmonized (non-Custom with valid scales and original_asset_size)
584
+ harmonizable_objects = []
585
+ for obj in self.objects:
586
+ if (obj['type'] != "Custom" and
587
+ obj['type'] in scales and
588
+ obj['original_asset_size'] is not None):
589
+ harmonizable_objects.append(obj)
590
+
591
+ if not harmonizable_objects:
592
+ return "No objects with defined scales to harmonize (all are Custom)"
593
+
594
+ # Find the largest scale among harmonizable objects
595
+ max_scale = max(scales[obj['type']] for obj in harmonizable_objects)
596
+
597
+ if max_scale == 0:
598
+ return "Invalid max scale (0)"
599
+
600
+ # Harmonize each object by scaling from ORIGINAL ASSET dimensions
601
+ for obj in harmonizable_objects:
602
+ obj_scale = scales[obj['type']]
603
+ scale_factor = obj_scale / max_scale
604
+
605
+ # Scale from ORIGINAL ASSET dimensions, not current dimensions
606
+ obj['size'][0] = obj['original_asset_size'][0] * scale_factor # width
607
+ obj['size'][1] = obj['original_asset_size'][1] * scale_factor # depth
608
+ obj['size'][2] = obj['original_asset_size'][2] * scale_factor # height
609
+
610
+ # Update z position to keep object on ground
611
+ obj['position'][2] = 0.0
612
+
613
+ return f"Harmonized {len(harmonizable_objects)} objects based on largest scale: {max_scale}"
614
+
615
+
616
+ def add_cuboid(self, description="New Cuboid", asset_type="Custom"):
617
+ """Add a cuboid with dimensions based on asset type"""
618
+ obj_id = len(self.objects)
619
+
620
+ # Determine dimensions based on asset type
621
+ if asset_type == "Custom" or asset_type not in self.asset_dimensions:
622
+ size = [1.0, 1.0, 1.0] # Default size
623
+ original_asset_size = None # Custom objects have no original asset size
624
+ else:
625
+ # Load dimensions from pkl file
626
+ dims = self.asset_dimensions[asset_type]
627
+ size = [float(dims[0]), float(dims[1]), float(dims[2])] # [width, depth, height]
628
+ original_asset_size = size.copy() # Store the original asset dimensions
629
+
630
+ cuboid = {
631
+ 'id': obj_id,
632
+ 'description': description,
633
+ 'type': asset_type, # Store the asset type
634
+ 'position': [0.0, 0.0, 0.0], # Place on ground (z = height/2)
635
+ 'size': size,
636
+ 'original_asset_size': original_asset_size, # Store original asset dimensions
637
+ 'azimuth': 0.0,
638
+ 'color': self._get_next_color()
639
+ }
640
+ self.objects.append(cuboid)
641
+ return obj_id
642
+
643
+
644
+ def update_cuboid(self, obj_id, x, y, z, azimuth, width, depth, height):
645
+ if 0 <= obj_id < len(self.objects):
646
+ obj = self.objects[obj_id]
647
+ obj['position'] = [x, y, z]
648
+ obj['size'] = [width, depth, height]
649
+ # Note: We do NOT update original_asset_size here - it stays unchanged
650
+ obj['azimuth'] = azimuth
651
+ return True
652
+ return False
653
+
654
+
655
+ def delete_cuboid(self, obj_id):
656
+ if 0 <= obj_id < len(self.objects):
657
+ del self.objects[obj_id]
658
+ # Update IDs for remaining objects
659
+ for i, obj in enumerate(self.objects):
660
+ obj['id'] = i
661
+ return True
662
+ return False
663
+
664
+ def set_camera_elevation(self, elevation_deg):
665
+ assert type(elevation_deg) == float or type(elevation_deg) == int, f"{type(elevation_deg) = }"
666
+ self.camera_elevation = np.clip(elevation_deg, 0.0, 90.0)
667
+ return f"Camera elevation set to {elevation_deg}°"
668
+
669
+ def set_camera_lens(self, lens_value):
670
+ self.camera_lens = np.clip(lens_value, 10.0, 200.0)
671
+ return f"Camera lens set to {lens_value}mm"
672
+
673
+ def set_surrounding_prompt(self, prompt): # Add this method
674
+ self.surrounding_prompt = prompt
675
+ return f"Surrounding prompt updated"
676
+
677
+ def _convert_to_blender_format(self):
678
+ """Convert internal objects format to server expected format"""
679
+ subjects_data = []
680
+
681
+ for obj in self.objects:
682
+ subject_data = {
683
+ 'subject_name': obj['description'],
684
+ 'x': float(obj['position'][0]),
685
+ 'y': float(obj['position'][1]),
686
+ 'z': float(obj['position'][2]),
687
+ 'azimuth': float(obj['azimuth']),
688
+ 'width': float(obj['size'][0]),
689
+ 'depth': float(obj['size'][1]),
690
+ 'height': float(obj['size'][2]),
691
+ 'base_color': obj['color']
692
+ }
693
+ subjects_data.append(subject_data)
694
+
695
+ camera_data = {
696
+ 'camera_elevation': float(np.radians(self.camera_elevation)),
697
+ 'lens': float(self.camera_lens),
698
+ 'global_scale': 1.0
699
+ }
700
+
701
+ return subjects_data, camera_data
702
+
703
+ def render_cv_view(self, subjects_data: list, camera_data: dict) -> Image.Image:
704
+ """Render only the CV view."""
705
+ if not subjects_data:
706
+ return Image.new('RGB', (512, 512), color='gray')
707
+
708
+ return self.render_client._send_render_request(self.render_client.cv_server_url, subjects_data, camera_data)
709
+
710
+
711
+ def render_scene(self, width=512, height=512):
712
+ """Render only CV view using the render client."""
713
+ print(f"calling render_scene")
714
+ if not self.objects:
715
+ # Return empty image if no objects
716
+ empty_cv = Image.new('RGB', (width, height), color='gray')
717
+ return empty_cv
718
+
719
+ # Convert to server expected format
720
+ subjects_data, camera_data = self._convert_to_blender_format()
721
+ print(f"passing {subjects_data = } to render_cv_view in SceneManager")
722
+
723
+ # Render CV view only
724
+ cv_img = self.render_cv_view(subjects_data, camera_data)
725
+
726
+ return cv_img
727
+
728
+ # --- Gradio Interface Logic ---
729
+ scene_manager = SceneManager()
730
+
731
+ def get_cuboid_list_html():
732
+ """Generate HTML for the cuboid list with position-based colors"""
733
+ if not scene_manager.objects:
734
+ return "<div style='text-align: center; padding: 20px; color: #888;'>No cuboids yet. Add one to get started!</div>"
735
+
736
+ html = "<div style='display: flex; flex-direction: column; gap: 8px;'>"
737
+ for obj_idx, obj in enumerate(scene_manager.objects):
738
+ # Get position-based color
739
+ # x, y = obj['position'][0], obj['position'][1]
740
+ # rgb_color = map_point_to_rgb(x, y)
741
+ rgb_color = COLORS[obj_idx % len(COLORS)]
742
+ hex_color = rgb_to_hex(rgb_color)
743
+
744
+ # Create a lighter version for gradient end
745
+ lighter_rgb = tuple(min(1.0, c + 0.2) for c in rgb_color)
746
+ lighter_hex = rgb_to_hex(lighter_rgb)
747
+
748
+ html += f"""
749
+ <div style='background: linear-gradient(135deg, {hex_color} 0%, {lighter_hex} 100%);
750
+ padding: 12px; border-radius: 8px; color: white; text-shadow: 1px 1px 2px rgba(0,0,0,0.5);'>
751
+ <div style='font-weight: bold; font-size: 14px;'>{obj['description']}</div>
752
+ <div style='font-size: 11px; opacity: 0.9; margin-top: 4px;'>
753
+ Pos: ({obj['position'][0]:.1f}, {obj['position'][1]:.1f}, {obj['position'][2]:.1f}) |
754
+ Size: {obj['size'][0]:.1f}×{obj['size'][1]:.1f}×{obj['size'][2]:.1f}
755
+ </div>
756
+ </div>
757
+ """
758
+ html += "</div>"
759
+ return html
760
+
761
+
762
+ def add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens):
763
+ """Add a new cuboid"""
764
+ if not description_input.strip():
765
+ description_input = "New Cuboid"
766
+
767
+ new_id = scene_manager.add_cuboid(description_input, asset_type)
768
+ cv_img = scene_manager.render_scene()
769
+
770
+ # Create choices for radio buttons
771
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
772
+
773
+ # Get the new object data
774
+ new_obj = scene_manager.objects[new_id]
775
+
776
+ return (
777
+ gr.update(value=""), # Clear description input
778
+ gr.update(value="Custom"), # Reset type dropdown to Custom
779
+ cv_img,
780
+ get_cuboid_list_html(),
781
+ gr.update(choices=choices, value=new_obj['description']), # Radio with new selection
782
+ gr.update(visible=True), # Show editor
783
+ gr.update(value=new_obj['description']), # Set description in editor
784
+ gr.update(value=new_obj['position'][0]),
785
+ gr.update(value=new_obj['position'][1]),
786
+ gr.update(value=new_obj['position'][2]),
787
+ gr.update(value=new_obj['azimuth']),
788
+ gr.update(value=new_obj['size'][0]),
789
+ gr.update(value=new_obj['size'][1]),
790
+ gr.update(value=new_obj['size'][2]),
791
+ gr.update(value=1.0) # Reset scale to 1.0
792
+ )
793
+
794
+
795
+ def select_cuboid_event(selected_name):
796
+ """When a cuboid is selected from radio buttons"""
797
+ if not selected_name:
798
+ return [gr.update(visible=False)] + [gr.update() for _ in range(9)] # Changed from 8 to 9
799
+
800
+ # Find the cuboid by description
801
+ obj = None
802
+ for o in scene_manager.objects:
803
+ if o['description'] == selected_name:
804
+ obj = o
805
+ break
806
+
807
+ if obj is None:
808
+ return [gr.update(visible=False)] + [gr.update() for _ in range(9)]
809
+
810
+ return (
811
+ gr.update(visible=True), # Show editor
812
+ gr.update(value=obj['description']), # Set description
813
+ gr.update(value=obj['position'][0]),
814
+ gr.update(value=obj['position'][1]),
815
+ gr.update(value=obj['position'][2]),
816
+ gr.update(value=obj['azimuth']),
817
+ gr.update(value=obj['size'][0]),
818
+ gr.update(value=obj['size'][1]),
819
+ gr.update(value=obj['size'][2]),
820
+ gr.update(value=1.0) # Reset scale to 1.0
821
+ )
822
+
823
+
824
+ def delete_selected_cuboid(selected_name, camera_elevation, camera_lens):
825
+ """Delete the currently selected cuboid"""
826
+ if not selected_name:
827
+ return gr.update(), get_cuboid_list_html(), gr.update(), gr.update(visible=False)
828
+
829
+ # Find and delete the cuboid
830
+ obj_id = None
831
+ for i, obj in enumerate(scene_manager.objects):
832
+ if obj['description'] == selected_name:
833
+ obj_id = i
834
+ break
835
+
836
+ if obj_id is not None:
837
+ scene_manager.delete_cuboid(obj_id)
838
+
839
+ cv_img = scene_manager.render_scene()
840
+
841
+ # Update choices
842
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
843
+
844
+ return (
845
+ cv_img,
846
+ get_cuboid_list_html(),
847
+ gr.update(choices=choices, value=None),
848
+ gr.update(visible=False)
849
+ )
850
+
851
+
852
+ def update_cuboid_event(selected_name, camera_elevation, camera_lens, description, x, y, z, azimuth, width, depth, height, scale):
853
+ """Update the selected cuboid including description and scale"""
854
+ scene_manager.set_camera_elevation(camera_elevation)
855
+ scene_manager.set_camera_lens(camera_lens)
856
+
857
+ if selected_name:
858
+ # Find the cuboid by description
859
+ obj_id = None
860
+ for i, obj in enumerate(scene_manager.objects):
861
+ if obj['description'] == selected_name:
862
+ obj_id = i
863
+ break
864
+
865
+ if obj_id is not None:
866
+ # Update description first if changed
867
+ if description.strip() and description.strip() != selected_name:
868
+ scene_manager.update_cuboid_description(obj_id, description.strip())
869
+
870
+ # Apply scale to dimensions
871
+ scaled_width = width * scale
872
+ scaled_depth = depth * scale
873
+ scaled_height = height * scale
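+ # The scale multiplier is applied once to the current slider values; the returned
+ # updates push the scaled dims back into the sliders and reset the scale slider to 1.0.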
874
+
875
+ # Update other properties with scaled dimensions
876
+ scene_manager.update_cuboid(obj_id, x, y, z, azimuth, scaled_width, scaled_depth, scaled_height)
877
+
878
+ # Get updated object for return
879
+ updated_obj = scene_manager.objects[obj_id]
880
+ new_name = updated_obj['description']
881
+
882
+ cv_img = scene_manager.render_scene()
883
+
884
+ # Update choices with new descriptions
885
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
886
+
887
+ # Return updated HTML, image, radio choices, new selection, updated sliders, and reset scale
888
+ return (
889
+ get_cuboid_list_html(),
890
+ cv_img,
891
+ gr.update(choices=choices, value=new_name if obj_id is not None else None),
892
+ gr.update(value=scaled_width if obj_id is not None else width), # Update width slider
893
+ gr.update(value=scaled_depth if obj_id is not None else depth), # Update depth slider
894
+ gr.update(value=scaled_height if obj_id is not None else height), # Update height slider
895
+ gr.update(value=1.0) # Reset scale to 1.0
896
+ )
897
+
898
+
899
+ def camera_change_event(camera_elevation, camera_lens):
900
+ """Handle camera control changes"""
901
+ scene_manager.set_camera_elevation(camera_elevation)
902
+ scene_manager.set_camera_lens(camera_lens)
903
+ cv_img = scene_manager.render_scene()
904
+ return cv_img
905
+
906
+
907
+ def surrounding_prompt_change_event(prompt_text): # Add this function
908
+ """Handle surrounding prompt changes"""
909
+ scene_manager.set_surrounding_prompt(prompt_text)
910
+ return None # No visual update needed
911
+
912
+
913
+ def render_segmask_event(camera_elevation, camera_lens, surrounding_prompt):
914
+ """Render segmentation masks for all objects"""
915
+ if not scene_manager.objects:
916
+ return "⚠️ No objects to render", gr.update(visible=False), []
917
+
918
+ # Get subject descriptions
919
+ subject_descriptions = [obj['description'] for obj in scene_manager.objects]
920
+
921
+ # Now you have access to:
922
+ # - surrounding_prompt: the text from surrounding_prompt_input
923
+ # - subject_descriptions: list of all subject descriptions
924
+
925
+ print(f"Surrounding prompt: {surrounding_prompt}")
926
+ print(f"Subject descriptions: {subject_descriptions}")
927
+
928
+ placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt
929
+
930
+ # Create placeholder text
931
+ subject_embeds = []
932
+ for subject_idx, subject_desc in enumerate(subject_descriptions):
933
+ input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0]
934
+ subject_embed = {"input_ids_t5": input_ids.tolist()}
935
+ subject_embeds.append(subject_embed)
936
+
937
+ placeholder_text = ""
938
+ for subject in subject_descriptions[:-1]:
939
+ placeholder_text = placeholder_text + f"<placeholder> {subject} and "
940
+ for subject in subject_descriptions[-1:]:
941
+ placeholder_text = placeholder_text + f"<placeholder> {subject}"
942
+ placeholder_text = placeholder_text.strip()
943
+
944
+ placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
945
+
946
+ call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt,
947
+ subjects=subject_descriptions,
948
+ subjects_embeds=subject_embeds,
949
+ debug=True
950
+ )
951
+ print(f"Generated call IDs: {call_ids}")
952
+
953
+
954
+ # Convert to server expected format
955
+ subjects_data, camera_data = scene_manager._convert_to_blender_format()
956
+
957
+ # You can add the prompt and descriptions to the request if needed
958
+ # For example, add to subjects_data or camera_data before sending
959
+
960
+ # Render segmentation masks
961
+ success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data)
962
+
963
+ # copy all the data to the correct location
964
+ root_save_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/"
965
+ os.system("rm /archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/*")
966
+ shutil.move("cv_render.jpg", osp.join(root_save_dir, "cv_render.jpg"))
967
+ for subject_idx in range(len(subject_descriptions)):
968
+ shutil.move(f"{str(subject_idx).zfill(3)}_segmask_cv.png", osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png"))
969
+
970
+ jsonl = [{
971
+ "cv": osp.join(root_save_dir, "cv_render.jpg"),
972
+ "target": osp.join(root_save_dir, "cv_render.jpg"),
973
+ "cuboids_segmasks": [osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png") for subject_idx in range(len(subject_descriptions))],
974
+ "PLACEHOLDER_prompts": placeholder_prompt,
975
+ "subjects": subject_descriptions,
976
+ "call_ids": call_ids,
977
+ }]
978
+
979
+ with open(osp.join(root_save_dir, "cuboids.jsonl"), "w") as f:
980
+ for item in jsonl:
981
+ f.write(json.dumps(item) + "\n")
982
+
983
+ if success:
984
+ return (
985
+ f"✅ Successfully rendered {len(segmask_images)} segmentation masks",
986
+ gr.update(visible=True),
987
+ segmask_images
988
+ )
989
+ else:
990
+ return (
991
+ f"❌ Failed to render segmentation masks: {error_msg}",
992
+ gr.update(visible=False),
993
+ []
994
+ )
995
+
996
+
997
+ def harmonize_event(selected_name, camera_elevation, camera_lens):
998
+ """Harmonize all object scales and update the scene"""
999
+ message = scene_manager.harmonize_scales()
1000
+ print(message)
1001
+
1002
+ cv_img = scene_manager.render_scene()
1003
+
1004
+ # If a cuboid is selected, update its sliders
1005
+ if selected_name:
1006
+ obj = None
1007
+ for o in scene_manager.objects:
1008
+ if o['description'] == selected_name:
1009
+ obj = o
1010
+ break
1011
+
1012
+ if obj is not None:
1013
+ return (
1014
+ cv_img,
1015
+ get_cuboid_list_html(),
1016
+ gr.update(value=obj['position'][0]),
1017
+ gr.update(value=obj['position'][1]),
1018
+ gr.update(value=obj['position'][2]),
1019
+ gr.update(value=obj['azimuth']),
1020
+ gr.update(value=obj['size'][0]),
1021
+ gr.update(value=obj['size'][1]),
1022
+ gr.update(value=obj['size'][2])
1023
+ )
1024
+
1025
+ # No object selected or object not found
1026
+ return (
1027
+ cv_img,
1028
+ get_cuboid_list_html(),
1029
+ gr.update(),
1030
+ gr.update(),
1031
+ gr.update(),
1032
+ gr.update(),
1033
+ gr.update(),
1034
+ gr.update(),
1035
+ gr.update()
1036
+ )
1037
+
1038
+
1039
+ def save_scene_event():
1040
+ """Save the current scene to a pkl file"""
1041
+ success, filepath, error = scene_manager.save_scene_to_pkl()
1042
+
1043
+ if success:
1044
+ return f"✅ Scene saved successfully to: {filepath}\n📋 Saved parameters: {scene_manager.inference_params}"
1045
+ else:
1046
+ return f"❌ Failed to save scene: {error}"
1047
+
1048
+
1049
+ def load_scene_event(filepath):
1050
+ """Load a scene from a pkl file and restore all parameters"""
1051
+ if not filepath.strip():
1052
+ return (
1053
+ "⚠️ Please enter a file path",
1054
+ gr.update(),
1055
+ gr.update(),
1056
+ gr.update(),
1057
+ gr.update(),
1058
+ gr.update(),
1059
+ gr.update(),
1060
+ gr.update(), # surrounding_prompt
1061
+ gr.update(), # checkpoint
1062
+ gr.update(), # height
1063
+ gr.update(), # width
1064
+ gr.update(), # seed
1065
+ gr.update(), # guidance
1066
+ gr.update() # steps
1067
+ )
1068
+
1069
+ success, num_objects, error = scene_manager.load_scene_from_pkl(filepath)
1070
+
1071
+ if success:
1072
+ # Re-render the scene
1073
+ cv_img = scene_manager.render_scene()
1074
+
1075
+ # Update UI components
1076
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
1077
+
1078
+ params_msg = f"✅ Scene loaded: {num_objects} objects\n📋 Restored parameters: {scene_manager.inference_params}"
1079
+
1080
+ return (
1081
+ params_msg,
1082
+ cv_img,
1083
+ get_cuboid_list_html(),
1084
+ gr.update(choices=choices, value=None),
1085
+ gr.update(visible=False),
1086
+ gr.update(value=scene_manager.camera_elevation),
1087
+ gr.update(value=scene_manager.camera_lens),
1088
+ gr.update(value=scene_manager.surrounding_prompt),
1089
+ gr.update(value=scene_manager.inference_params['checkpoint']),
1090
+ gr.update(value=scene_manager.inference_params['height']),
1091
+ gr.update(value=scene_manager.inference_params['width']),
1092
+ gr.update(value=scene_manager.inference_params['seed']),
1093
+ gr.update(value=scene_manager.inference_params['guidance_scale']),
1094
+ gr.update(value=scene_manager.inference_params['num_inference_steps'])
1095
+ )
1096
+ else:
1097
+ return (
1098
+ f"❌ {error}",
1099
+ gr.update(),
1100
+ gr.update(),
1101
+ gr.update(),
1102
+ gr.update(),
1103
+ gr.update(),
1104
+ gr.update(),
1105
+ gr.update(),
1106
+ gr.update(),
1107
+ gr.update(),
1108
+ gr.update(),
1109
+ gr.update(),
1110
+ gr.update(),
1111
+ gr.update()
1112
+ )
1113
+
1114
+
1115
+ # --- Gradio UI Layout ---
1116
+ with gr.Blocks(
1117
+ theme=gr.themes.Soft(
1118
+ primary_hue="green",
1119
+ secondary_hue="gray",
1120
+ neutral_hue="gray"
1121
+ ),
1122
+ css="""
1123
+ .gradio-container {
1124
+ background: linear-gradient(135deg, #0d1117 0%, #1a3d2e 50%, #000000 100%) !important;
1125
+ color: #ffffff !important;
1126
+ }
1127
+ .block {
1128
+ background: rgba(15, 36, 25, 0.8) !important;
1129
+ border: 1px solid #2d5a41 !important;
1130
+ border-radius: 8px !important;
1131
+ }
1132
+ .form {
1133
+ background: rgba(15, 36, 25, 0.6) !important;
1134
+ }
1135
+ h1, h2, h3, h4, h5, h6 {
1136
+ color: #ffffff !important;
1137
+ }
1138
+ .markdown {
1139
+ color: #e6e6e6 !important;
1140
+ }
1141
+ label {
1142
+ color: #cccccc !important;
1143
+ }
1144
+ .gr-button {
1145
+ background: linear-gradient(135deg, #2d5a41, #3d6a51) !important;
1146
+ border: 1px solid #4a7c59 !important;
1147
+ color: #ffffff !important;
1148
+ }
1149
+ .gr-button:hover {
1150
+ background: linear-gradient(135deg, #3d6a51, #4a7c59) !important;
1151
+ }
1152
+ .gr-input, .gr-textbox, .gr-dropdown {
1153
+ background: rgba(15, 36, 25, 0.8) !important;
1154
+ border: 1px solid #2d5a41 !important;
1155
+ color: #ffffff !important;
1156
+ }
1157
+ .gr-input:focus, .gr-textbox:focus {
1158
+ border-color: #4a7c59 !important;
1159
+ background: rgba(26, 61, 46, 0.8) !important;
1160
+ }
1161
+ .gr-slider input[type="range"] {
1162
+ background: #2d5a41 !important;
1163
+ }
1164
+ .gr-slider input[type="range"]::-webkit-slider-thumb {
1165
+ background: #4a7c59 !important;
1166
+ }
1167
+ .gr-radio label {
1168
+ color: #cccccc !important;
1169
+ }
1170
+ .gr-panel {
1171
+ background: rgba(15, 36, 25, 0.6) !important;
1172
+ border: 1px solid #2d5a41 !important;
1173
+ }
1174
+ """
1175
+ ) as demo:
1176
+ gr.Markdown("# [CVPR-2026] 3D Aware Occlusion Control in Text-to-Image Generation 🏞️🧱")
1177
+ # TOP ROW
1178
+ with gr.Row():
1179
+ # TOP LEFT - Edit Properties
1180
+ with gr.Column(scale=1):
1181
+ # Add description textbox at the top
1182
+ # with gr.Column(visible=False) as editor_section:
1183
+ # gr.Markdown("## ✏️ Edit Properties")
1184
+
1185
+ # delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm")
1186
+
1187
+ # with gr.Row():
1188
+ # edit_x = gr.Slider(-10, 10, value=0, step=0.1, label="X")
1189
+ # edit_y = gr.Slider(-10, 10, value=0, step=0.1, label="Y")
1190
+ # edit_z = gr.Slider(0, 10, value=1, step=0.1, label="Z")
1191
+
1192
+ # edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)")
1193
+
1194
+ # with gr.Row():
1195
+ # edit_width = gr.Slider(0.1, 5, value=1, step=0.1, label="Width")
1196
+ # edit_depth = gr.Slider(0.1, 5, value=1, step=0.1, label="Depth")
1197
+ # edit_height = gr.Slider(0.1, 5, value=1, step=0.1, label="Height")
1198
+ with gr.Column(visible=False) as editor_section:
1199
+ gr.Markdown("## ✏️ Edit Properties")
1200
+
1201
+ edit_description = gr.Textbox(
1202
+ label="Description",
1203
+ placeholder="Enter object description",
1204
+ info="Description cannot be empty"
1205
+ )
1206
+
1207
+ delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm")
1208
+
1209
+ with gr.Row():
1210
+ edit_x = gr.Slider(-10, 10, value=0, step=0.1, label="X")
1211
+ edit_y = gr.Slider(-10, 10, value=0, step=0.1, label="Y")
1212
+ edit_z = gr.Slider(0, 10, value=1, step=0.1, label="Z")
1213
+
1214
+ edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)")
1215
+
1216
+ with gr.Row():
1217
+ edit_width = gr.Slider(0.1, 5, value=1, step=0.1, label="Width")
1218
+ edit_depth = gr.Slider(0.1, 5, value=1, step=0.1, label="Depth")
1219
+ edit_height = gr.Slider(0.1, 5, value=1, step=0.1, label="Height")
1220
+
1221
+ # Add scale slider
1222
+ edit_scale = gr.Slider(
1223
+ 0.1, 3.0, value=1.0, step=0.1,
1224
+ label="Scale",
1225
+ info="Multiplier for all dimensions (resets to 1.0 after update)"
1226
+ )
1227
+
1228
+ # Add the Update Scene button
1229
+ update_scene_btn = gr.Button("🔄 Update Scene", variant="primary", size="sm")
1230
+
1231
+ # TOP MIDDLE - Camera View
1232
+ with gr.Column(scale=1):
1233
+ gr.Markdown("## 📷 Camera View")
1234
+ cv_image_output = gr.Image(label="Camera View", height=400)
1235
+
1236
+ # TOP RIGHT - Generated Image
1237
+ with gr.Column(scale=1):
1238
+ gr.Markdown("## 🎨 Generated Image")
1239
+ generated_image_output = gr.Image(label="Generated Image", height=400)
1240
+
1241
+ # BOTTOM ROW
1242
+ with gr.Row():
1243
+ # BOTTOM LEFT - Cuboid List and Selection
1244
+ with gr.Column(scale=1):
1245
+ gr.Markdown("## 📦 Cuboids")
1246
+ cuboid_list_html = gr.HTML(get_cuboid_list_html())
1247
+
1248
+ gr.Markdown("### Select Cuboid to Edit")
1249
+ cuboid_radio = gr.Radio(choices=[], label="", visible=True)
1250
+
1251
+ # BOTTOM RIGHT - Camera Controls and Add New Cuboid
1252
+ with gr.Column(scale=2):
1253
+ with gr.Row():
1254
+ with gr.Column():
1255
+ gr.Markdown("## Global Controls")
1256
+ camera_elevation_slider = gr.Slider(0, 90, value=30, label="Camera Elevation (degrees)")
1257
+ camera_lens_slider = gr.Slider(10, 200, value=50, label="Camera Lens (mm)")
1258
+
1259
+ # Add surrounding prompt textbox
1260
+ surrounding_prompt_input = gr.Textbox(
1261
+ placeholder="e.g., in a forest, in a city, on a beach",
1262
+ label="Surrounding Prompt",
1263
+ info="Describe the surrounding environment"
1264
+ )
1265
+
1266
+ gr.Markdown("## 🔧 Scene Tools")
1267
+ harmonize_btn = gr.Button("⚖️ Harmonize Scales", variant="secondary")
1268
+
1269
+ # Save/Load Section
1270
+ gr.Markdown("## 💾 Save/Load Scene")
1271
+ with gr.Row():
1272
+ save_scene_btn = gr.Button("💾 Save Scene", variant="secondary")
1273
+ load_scene_btn = gr.Button("📂 Load Scene", variant="secondary")
1274
+
1275
+ load_path_input = gr.Textbox(
1276
+ placeholder="/path/to/scene.pkl",
1277
+ label="Load Scene Path",
1278
+ info="Enter path to pkl file to load"
1279
+ )
1280
+ save_load_status = gr.Markdown("")
1281
+
1282
+ with gr.Column():
1283
+ gr.Markdown("## ➕ Add New Cuboid")
1284
+ add_cuboid_description_input = gr.Textbox(placeholder="Enter cuboid description", label="Description")
1285
+ asset_type_dropdown = gr.Dropdown(
1286
+ choices=scene_manager.get_asset_type_choices(),
1287
+ value="Custom",
1288
+ label="Type",
1289
+ info="Select asset type to load dimensions, or choose Custom"
1290
+ )
1291
+ add_cuboid_btn = gr.Button("Add Cuboid", variant="primary")
1292
+ generate_btn = gr.Button("🎨 Generate Image", variant="primary")
1293
+
1294
+ # Add checkpoint dropdown
1295
+ checkpoint_dropdown = gr.Dropdown(
1296
+ choices=CHECKPOINT_NAMES,
1297
+ value=CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None,
1298
+ label="Checkpoint",
1299
+ info="Select model checkpoint for generation"
1300
+ )
1301
+
1302
+ # Inference Parameters
1303
+ gr.Markdown("### Inference Parameters")
1304
+
1305
+ with gr.Row():
1306
+ inference_height = gr.Slider(
1307
+ minimum=256, maximum=1024, value=512, step=64,
1308
+ label="Height"
1309
+ )
1310
+ inference_width = gr.Slider(
1311
+ minimum=256, maximum=1024, value=512, step=64,
1312
+ label="Width"
1313
+ )
1314
+
1315
+ inference_seed = gr.Number(
1316
+ value=42, label="Random Seed", precision=0
1317
+ )
1318
+
1319
+ inference_guidance = gr.Slider(
1320
+ minimum=1.0, maximum=10.0, value=3.5, step=0.5,
1321
+ label="Guidance Scale"
1322
+ )
1323
+
1324
+ inference_steps = gr.Slider(
1325
+ minimum=10, maximum=50, value=25, step=1,
1326
+ label="Inference Steps"
1327
+ )
1328
+
1329
+ # Event Handlers
1330
+ def add_cuboid_with_auto_update(description_input, asset_type, camera_elevation, camera_lens):
1331
+ """Add cuboid and auto-update scene"""
1332
+ result = add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens)
1333
+ return result
1334
+
1335
+ # Update add_cuboid_btn.click event handler (around line 850):
1336
+ add_cuboid_btn.click(
1337
+ add_cuboid_with_auto_update,
1338
+ inputs=[add_cuboid_description_input, asset_type_dropdown, camera_elevation_slider, camera_lens_slider],
1339
+ outputs=[
1340
+ add_cuboid_description_input,
1341
+ asset_type_dropdown,
1342
+ cv_image_output,
1343
+ cuboid_list_html,
1344
+ cuboid_radio,
1345
+ editor_section,
1346
+ edit_description,
1347
+ edit_x, edit_y, edit_z,
1348
+ edit_azimuth,
1349
+ edit_width, edit_depth, edit_height,
1350
+ edit_scale # Add this
1351
+ ]
1352
+ )
1353
+
1354
+ # Update the cuboid_radio.change event handler (around line 860):
1355
+ cuboid_radio.change(
1356
+ select_cuboid_event,
1357
+ inputs=[cuboid_radio],
1358
+ outputs=[
1359
+ editor_section,
1360
+ edit_description,
1361
+ edit_x, edit_y, edit_z,
1362
+ edit_azimuth,
1363
+ edit_width, edit_depth, edit_height,
1364
+ edit_scale # Add this
1365
+ ]
1366
+ )
1367
+
1368
+ delete_btn.click(
1369
+ delete_selected_cuboid,
1370
+ inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider],
1371
+ outputs=[cv_image_output, cuboid_list_html, cuboid_radio, editor_section]
1372
+ )
1373
+
1374
+ # Save/Load handlers
1375
+ save_scene_btn.click(
1376
+ save_scene_event,
1377
+ inputs=[],
1378
+ outputs=[save_load_status]
1379
+ )
1380
+
1381
+ load_scene_btn.click(
1382
+ load_scene_event,
1383
+ inputs=[load_path_input],
1384
+ outputs=[
1385
+ save_load_status,
1386
+ cv_image_output,
1387
+ cuboid_list_html,
1388
+ cuboid_radio,
1389
+ editor_section,
1390
+ camera_elevation_slider,
1391
+ camera_lens_slider,
1392
+ surrounding_prompt_input,
1393
+ checkpoint_dropdown,
1394
+ inference_height,
1395
+ inference_width,
1396
+ inference_seed,
1397
+ inference_guidance,
1398
+ inference_steps
1399
+ ]
1400
+ )
1401
+
1402
+ # Auto-update scene when sliders change
1403
+ # for slider in [edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height]:
1404
+ # slider.change(
1405
+ # update_cuboid_event,
1406
+ # inputs=[
1407
+ # cuboid_radio,
1408
+ # camera_elevation_slider,
1409
+ # camera_lens_slider,
1410
+ # edit_x, edit_y, edit_z,
1411
+ # edit_azimuth,
1412
+ # edit_width, edit_depth, edit_height
1413
+ # ],
1414
+ # outputs=[cuboid_list_html, cv_image_output]
1415
+ # )
1416
+ # Update the update_scene_btn.click event handler (around line 920):
1417
+ update_scene_btn.click(
1418
+ update_cuboid_event,
1419
+ inputs=[
1420
+ cuboid_radio,
1421
+ camera_elevation_slider,
1422
+ camera_lens_slider,
1423
+ edit_description,
1424
+ edit_x, edit_y, edit_z,
1425
+ edit_azimuth,
1426
+ edit_width, edit_depth, edit_height,
1427
+ edit_scale # Add this
1428
+ ],
1429
+ outputs=[
1430
+ cuboid_list_html,
1431
+ cv_image_output,
1432
+ cuboid_radio,
1433
+ edit_width, # Add this
1434
+ edit_depth, # Add this
1435
+ edit_height, # Add this
1436
+ edit_scale # Add this (to reset to 1.0)
1437
+ ]
1438
+ )
1439
+
1440
+
1441
+ # Update generate button click handler
1442
+ generate_btn.click(
1443
+ generate_image_event,
1444
+ inputs=[
1445
+ camera_elevation_slider,
1446
+ camera_lens_slider,
1447
+ surrounding_prompt_input,
1448
+ checkpoint_dropdown,
1449
+ inference_height,
1450
+ inference_width,
1451
+ inference_seed,
1452
+ inference_guidance,
1453
+ inference_steps
1454
+ ],
1455
+ outputs=[save_load_status, cv_image_output, generated_image_output]
1456
+ )
1457
+
1458
+
1459
+ harmonize_btn.click(
1460
+ harmonize_event,
1461
+ inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider],
1462
+ outputs=[
1463
+ cv_image_output,
1464
+ cuboid_list_html,
1465
+ edit_x, edit_y, edit_z,
1466
+ edit_azimuth,
1467
+ edit_width, edit_depth, edit_height
1468
+ ]
1469
+ )
1470
+
1471
+ # Camera controls
1472
+ for control in [camera_elevation_slider, camera_lens_slider]:
1473
+ control.change(
1474
+ camera_change_event,
1475
+ inputs=[camera_elevation_slider, camera_lens_slider],
1476
+ outputs=[cv_image_output]
1477
+ )
1478
+
1479
+ # Surrounding prompt control
1480
+ surrounding_prompt_input.change(
1481
+ surrounding_prompt_change_event,
1482
+ inputs=[surrounding_prompt_input],
1483
+ outputs=[]
1484
+ )
1485
+
1486
+
1487
+ # Initial render
1488
+ def initial_render():
1489
+ cv_img = scene_manager.render_scene()
1490
+ gen_img = Image.new('RGB', (512, 512), color='white')
1491
+ return cv_img, gen_img
1492
+
1493
+ demo.load(
1494
+ initial_render,
1495
+ outputs=[cv_image_output, generated_image_output]
1496
+ )
1497
+
1498
+
1499
+ if __name__ == "__main__":
1500
+ import os
1501
+ os.system("./launch_blender_backend.sh &")
1502
+ # Initialize inference engine (load model once at startup)
1503
+ initialize_inference_engine(base_model_path="black-forest-labs/FLUX.1-dev")
1504
+ demo.launch(share=True)
asset_dimensions.pkl ADDED
Binary file (1.75 kB). View file
 
blender_backend.py ADDED
@@ -0,0 +1,1521 @@
1
+ import bpy
2
+ import bpy_extras
3
+ import numpy as np
4
+ import bmesh
5
+ import copy
6
+ import PIL
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import colorsys
10
+ import os
11
+ import os.path as osp
12
+ import shutil
13
+ import sys
14
+ import math
15
+ import mathutils
16
+ import random
17
+ import cv2
18
+ from object_scales import scales
19
+ import matplotlib.colors as mcolors
20
+ import torch
21
+
22
+ def map_point_to_rgb(x, y, z):
23
+ """
24
+ Map a point (x, y, z) inside the frustum to an RGB color with continuity and variation.
25
+ """
26
+ # Frustum boundaries
27
+ X_MIN, X_MAX = -12.0, -1.0
28
+ Y_MIN_AT_XMIN, Y_MAX_AT_XMIN = -4.5, 4.5
29
+ Y_MIN_AT_XMAX, Y_MAX_AT_XMAX = -0.5, 0.5
30
+ Z_MIN, Z_MAX = 0.0, 2.50
31
+ # Normalize x to [0, 1]
32
+ x_norm = (x - X_MIN) / (X_MAX - X_MIN)
33
+ x_norm = np.clip(x_norm, 0, 1)
34
+
35
+ # Compute current Y bounds at given x using linear interpolation
36
+ y_min = Y_MIN_AT_XMIN + x_norm * (Y_MIN_AT_XMAX - Y_MIN_AT_XMIN)
37
+ y_max = Y_MAX_AT_XMIN + x_norm * (Y_MAX_AT_XMAX - Y_MAX_AT_XMIN)
38
+
39
+ # Normalize y to [0, 1] within current bounds
40
+ if y_max != y_min:
41
+ y_norm = (y - y_min) / (y_max - y_min)
42
+ else:
43
+ y_norm = 0.5
44
+ y_norm = np.clip(y_norm, 0, 1)
45
+
46
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
47
+
48
+ # Color mapping: more variation along x
49
+ r = x_norm
50
+ # g = 0.5 * y_norm + 0.25 * x_norm
51
+ g = y_norm
52
+ b = z_norm
53
+
54
+ return (r, g, b)
55
+
56
+
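A quick worked example of the mapping above (an editorial sketch, not part of the committed file): a point at the far plane of the frustum (x = -12) on the centre line (y = 0) and at ground level (z = 0) normalizes to x_norm = 0, y_norm = 0.5, z_norm = 0, so the function returns (0.0, 0.5, 0.0).

# Editorial sketch: far plane, centre line, ground level
print(map_point_to_rgb(-12.0, 0.0, 0.0))   # -> (0.0, 0.5, 0.0)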
57
+ def set_world_color(color=(0.1, 0.1, 0.1)):
58
+ """
59
+ Sets the world background color to match the grid floor.
60
+
61
+ Args:
62
+ color (tuple): RGB color values (0-1 range)
63
+ """
64
+ scene = bpy.context.scene
65
+
66
+ # Create a new world if it doesn't exist
67
+ if scene.world is None:
68
+ world = bpy.data.worlds.new(name="World")
69
+ scene.world = world
70
+ else:
71
+ world = scene.world
72
+
73
+ # Enable use of nodes for the world
74
+ world.use_nodes = True
75
+
76
+ # Get the node tree
77
+ nodes = world.node_tree.nodes
78
+ links = world.node_tree.links
79
+
80
+ # Find or create the Background node
81
+ background_node = None
82
+ for node in nodes:
83
+ if node.type == 'BACKGROUND':
84
+ background_node = node
85
+ break
86
+
87
+ if background_node is None:
88
+ # Clear existing nodes and create new ones
89
+ nodes.clear()
90
+ background_node = nodes.new(type='ShaderNodeBackground')
91
+ output_node = nodes.new(type='ShaderNodeOutputWorld')
92
+ links.new(background_node.outputs['Background'], output_node.inputs['Surface'])
93
+
94
+ # Set the background color
95
+ background_node.inputs['Color'].default_value = (*color, 1.0)
96
+ background_node.inputs['Strength'].default_value = 1.0
97
+
98
+
99
+ COLORS = [
100
+ (1.0, 0.0, 0.0), # Red
101
+ (0.0, 0.8, 0.2), # Green
102
+ (0.0, 0.0, 1.0), # Blue
103
+ (1.0, 1.0, 0.0), # Yellow
104
+ (0.0, 1.0, 1.0), # Cyan
105
+ (1.0, 0.0, 1.0), # Magenta
106
+ (1.0, 0.6, 0.0), # Orange
107
+ (0.6, 0.0, 0.8), # Purple
108
+ (0.0, 0.4, 0.0), # Dark Green
109
+ (0.8, 0.8, 0.8), # Light Gray
110
+ (0.2, 0.2, 0.2) # Dark Gray
111
+ ]
112
+
113
+ def do_z_pass(seg_masks: torch.Tensor, dist_values: torch.Tensor) -> torch.Tensor:
114
+ """
115
+ Performs a z-pass on segmentation masks based on distance values to the camera.
116
+ For each pixel, if multiple subjects' masks are active, only the one with the smallest distance (closest) remains active.
117
+
118
+ Args:
119
+ seg_masks (torch.Tensor): Binary segmentation masks of shape (n_subjects, h, w) with dtype uint8.
120
+ dist_values (torch.Tensor): Distance values for each subject of shape (n_subjects,).
121
+
122
+ Returns:
123
+ torch.Tensor: Processed segmentation masks after z-pass, same shape and dtype as seg_masks.
124
+ """
125
+ # Ensure tensors are on the same device
126
+ device = seg_masks.device
127
+
128
+ # Get dimensions
129
+ n_subjects, h, w = seg_masks.shape
130
+
131
+ # Reshape distance values for broadcasting across spatial dimensions
132
+ dist_values_expanded = dist_values.view(n_subjects, 1, 1)
133
+
134
+ # Create a tensor where active pixels have their distance, others have a high value (1e10)
135
+ masked_dist = torch.where(seg_masks.bool(), dist_values_expanded, torch.tensor(1e10, device=device))
136
+
137
+ # Find the subject index with the minimum distance for each pixel (shape (h, w))
138
+ closest_indices = torch.argmin(masked_dist, dim=0)
139
+
140
+ # Initialize output tensor with zeros
141
+ output = torch.zeros_like(seg_masks)
142
+
143
+ # Scatter 1s into the output tensor where the closest subject's indices are
144
+ # closest_indices.unsqueeze(0) adds a dummy dimension to match scatter's expected shape
145
+ output.scatter_(
146
+ dim=0,
147
+ index=closest_indices.unsqueeze(0),
148
+ src=torch.ones_like(closest_indices.unsqueeze(0), dtype=output.dtype)
149
+ )
150
+
151
+ # Zero out any positions where the original mask was inactive
152
+ output = output * seg_masks
153
+
154
+ return output
155
+
156
+
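To make the z-pass behaviour concrete, here is a small self-contained check (an editorial sketch, not part of the commit): two masks overlap on one column, and after do_z_pass only the subject with the smaller camera distance keeps the shared pixels.

import torch  # also imported at the top of this module

masks = torch.zeros(2, 4, 4, dtype=torch.uint8)
masks[0, :, 0:3] = 1                      # subject 0 covers columns 0-2
masks[1, :, 2:4] = 1                      # subject 1 covers columns 2-3 (overlap at column 2)
dists = torch.tensor([5.0, 2.0])          # subject 1 is closer to the camera

resolved = do_z_pass(masks, dists)
assert resolved[1, 0, 2] == 1 and resolved[0, 0, 2] == 0   # the overlapping column goes to subject 1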
157
+ def get_image_to_world_matrix(camera_obj, render):
158
+ """
159
+ Calculates the matrix to transform a point from clip space to world space.
160
+
161
+ Args:
162
+ camera_obj (bpy.types.Object): The camera object.
163
+ render (bpy.types.RenderSettings): The scene's render settings.
164
+
165
+ Returns:
166
+ mathutils.Matrix: The 4x4 matrix for clip-to-world transformation.
167
+ """
168
+ # Get the camera's view matrix (world to camera)
169
+ view_matrix = camera_obj.matrix_world.inverted()
170
+
171
+ # Get the camera's projection matrix
172
+ # This matrix depends on the render resolution, so it's best to calculate it
173
+ # for the specific dimensions you're using.
174
+ projection_matrix = camera_obj.calc_matrix_camera(
175
+ bpy.context.evaluated_depsgraph_get(),
176
+ x=render.resolution_x,
177
+ y=render.resolution_y,
178
+ scale_x=render.pixel_aspect_x,
179
+ scale_y=render.pixel_aspect_y,
180
+ )
181
+
182
+ # Combine and invert to get the clip-to-world matrix
183
+ clip_to_world_matrix = (projection_matrix @ view_matrix).inverted()
184
+
185
+ return clip_to_world_matrix
186
+
187
+
188
+ def unproject_image_point(camera_obj, image_coord, depth):
189
+ """
190
+ Transforms a 2D image coordinate with a depth value into a 3D world coordinate.
191
+
192
+ Args:
193
+ camera_obj (bpy.types.Object): The camera used for rendering.
194
+ image_coord (tuple or list): The (x, y) pixel coordinate.
195
+ depth (float): The depth value at that coordinate (from the Z-pass).
196
+
197
+ Returns:
198
+ mathutils.Vector: The calculated 3D point in world space.
199
+ """
200
+ render = bpy.context.scene.render
201
+
202
+ # 1. Get the clip-to-world transformation matrix
203
+ clip_to_world_mat = get_image_to_world_matrix(camera_obj, render)
204
+
205
+ # 2. Convert image coordinates to Normalized Device Coordinates (NDC)
206
+ # (from [0, res] to [-1, 1])
207
+ ndc_x = (image_coord[0] / render.resolution_x) * 2 - 1
208
+ ndc_y = (image_coord[1] / render.resolution_y) * 2 - 1
209
+
210
+ # In Blender's Z-pass, the depth value is the distance from the camera's plane.
211
+ # We can use Blender's utility function to find the 3D vector for the pixel.
212
+ # This vector is in camera space and points from the camera towards the pixel.
213
+ view_vector = bpy_extras.view3d_utils.region_2d_to_vector_3d(
214
+ bpy.context.region,
215
+ bpy.context.space_data.region_3d,
216
+ image_coord
217
+ )
218
+
219
+ # 4. Project the view vector into world space and scale by depth
220
+ # The view_vector is normalized and in camera space.
221
+ # To get the point in world space, we transform the vector by the camera's
222
+ # world matrix (not the view matrix).
223
+ world_vector = camera_obj.matrix_world.to_3x3() @ view_vector
224
+
225
+ # The depth from the Z-pass is the distance along the camera's local Z-axis.
226
+ # To find the true distance along the ray, we must account for the angle.
227
+ # We can calculate the scaling factor 't' for our world_vector.
228
+ camera_forward = -camera_obj.matrix_world.col[2].xyz
229
+ t = depth / world_vector.dot(camera_forward)
230
+
231
+ # 5. Calculate the final world coordinate
232
+ # Start from the camera's location and move along the ray.
233
+ world_point = camera_obj.matrix_world.translation + (t * world_vector)
234
+
235
+ return world_point
236
+
237
+ # --- Example Usage ---
238
+ # This example assumes you have an active scene with a camera and have rendered an image.
239
+ # You would typically run this after rendering, where you can access the depth map.
240
+
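The example that the comments above refer to is not included in the file; a hedged sketch of what such usage could look like is given below. The depth value is a placeholder for whatever you read out of the Z-pass, and note that unproject_image_point as written relies on a 3D-viewport context (bpy.context.region / bpy.context.space_data), so it is meant to be run from inside the viewport.

# Editorial sketch (assumes a scene camera and a rendered Z-pass):
cam = bpy.context.scene.camera
render = bpy.context.scene.render
pixel = (render.resolution_x // 2, render.resolution_y // 2)   # centre pixel
depth_at_pixel = 6.0                                           # placeholder: read this from the depth map
print(unproject_image_point(cam, pixel, depth_at_pixel))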
241
+
242
+ def multiply_random_color(obj, random_color):
243
+ """
244
+ Multiplies the existing base color of an object's materials
245
+ with a random color.
246
+ """
247
+ for material_slot in obj.material_slots:
248
+ if material_slot.material:
249
+ material = material_slot.material
250
+ if material.use_nodes:
251
+ nodes = material.node_tree.nodes
252
+ links = material.node_tree.links
253
+
254
+ # Find the Principled BSDF node
255
+ principled_bsdf = nodes.get("Principled BSDF")
256
+ if not principled_bsdf:
257
+ continue
258
+
259
+ # Get the node connected to the Base Color input
260
+ base_color_input = principled_bsdf.inputs.get("Base Color")
261
+ if not base_color_input:
262
+ continue
263
+
264
+ # Create a MixRGB node and set it to multiply
265
+ mix_rgb_node = nodes.new(type='ShaderNodeMixRGB')
266
+ mix_rgb_node.blend_type = 'MULTIPLY'
267
+ mix_rgb_node.inputs['Fac'].default_value = 2.00
268
+ mix_rgb_node.location = (principled_bsdf.location.x - 200, principled_bsdf.location.y)
269
+
270
+ # Set the second color to a random color
271
+ mix_rgb_node.inputs['Color2'].default_value = random_color
272
+
273
+ # If a node is already connected to the Base Color,
274
+ # connect it to the first color input of the MixRGB node.
275
+ if base_color_input.is_linked:
276
+ original_link = base_color_input.links[0]
277
+ original_node = original_link.from_node
278
+ original_socket = original_link.from_socket
279
+ links.new(original_node.outputs[original_socket.name], mix_rgb_node.inputs['Color1'])
280
+ links.remove(original_link)
281
+ else:
282
+ # If no node is connected, use the original default color
283
+ original_color = base_color_input.default_value
284
+ mix_rgb_node.inputs['Color1'].default_value = original_color
285
+
286
+ # Connect the MixRGB node to the Principled BSDF's Base Color
287
+ links.new(mix_rgb_node.outputs['Color'], base_color_input)
288
+
289
+
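A minimal usage sketch for the helper above (editorial; assumes an active object with node-based materials and relies on the bpy and random imports at the top of this module): pick a random RGBA tint and multiply it into every material's base colour.

tint = (random.random(), random.random(), random.random(), 1.0)   # RGBA, as expected by the Color2 input
multiply_random_color(bpy.context.active_object, tint)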
290
+ OUTPUT_DIR = "four_subject_renders"
291
+ OBJECTS_DIR = "obja_2units_along_y/glbs"
292
+
293
+ NUM_AZIMUTH_BINS = 1
294
+ NUM_LIGHTS = 1
295
+
296
+ MAX_TRIES = 25
297
+
298
+ IMG_DIM = 1024
299
+
300
+ MASK_RES = 50
301
+
302
+ THRESHOLD_LOWER = 150
303
+ THRESHOLD_UPPER = 768
304
+
305
+ ROOT_OBJS_DIR = "/ssd_scratch/vaibhav.agrawal/a-bev-of-the-latents/glb_files/"
306
+
307
+ OBJ_SIDE_LENGTH = 2.0
308
+
309
+ def calculate_iou(box1, box2):
310
+ """
311
+ Calculate the Intersection over Union (IoU) of two bounding boxes.
312
+
313
+ Parameters:
314
+ box1, box2: Each box is defined by a tuple (x1, y1, x2, y2)
315
+ where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
316
+
317
+ Returns:
318
+ float: IoU value
319
+ """
320
+ # Unpack coordinates
321
+ x1_min, y1_min, x1_max, y1_max = box1
322
+ x2_min, y2_min, x2_max, y2_max = box2
323
+
324
+ # Determine the coordinates of the intersection rectangle
325
+ inter_x_min = max(x1_min, x2_min)
326
+ inter_y_min = max(y1_min, y2_min)
327
+ inter_x_max = min(x1_max, x2_max)
328
+ inter_y_max = min(y1_max, y2_max)
329
+
330
+ # Compute the area of intersection rectangle
331
+ inter_width = max(0, inter_x_max - inter_x_min)
332
+ inter_height = max(0, inter_y_max - inter_y_min)
333
+ intersection_area = inter_width * inter_height
334
+
335
+ # Compute the area of both bounding boxes
336
+ box1_area = (x1_max - x1_min) * (y1_max - y1_min)
337
+ box2_area = (x2_max - x2_min) * (y2_max - y2_min)
338
+
339
+ # Compute the area of the union
340
+ union_area = box1_area + box2_area - intersection_area
341
+
342
+ # Compute IoU
343
+ iou = intersection_area / union_area if union_area > 0 else 0
344
+
345
+ return iou
346
+
347
+
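A worked example for calculate_iou (editorial sketch): two 10x10 boxes offset by 5 pixels in x share a 5x10 strip, so IoU = 50 / (100 + 100 - 50) = 1/3.

box_a = (0, 0, 10, 10)
box_b = (5, 0, 15, 10)
print(calculate_iou(box_a, box_b))   # -> 0.333...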
348
+ def get_object_2d_bbox(empty_obj, scene):
349
+ """
350
+ Get the 2D bounding box coordinates of an object in the rendered image.
351
+
352
+ Args:
353
+ empty_obj (bpy.types.Object): The empty object containing the child mesh objects.
354
+ scene (bpy.types.Scene): The current scene.
355
+
356
+ Returns:
357
+ tuple: A tuple containing the 2D bounding box coordinates in pixel space
358
+ in the format (min_x, min_y, max_x, max_y).
359
+ """
360
+ # Get the render settings
361
+ render = scene.render
362
+ res_x = render.resolution_x
363
+ res_y = render.resolution_y
364
+
365
+ # Initialize the bounding box coordinates
366
+ min_x, min_y = float('inf'), float('inf')
367
+ max_x, max_y = float('-inf'), float('-inf')
368
+
369
+ depsgraph = bpy.context.evaluated_depsgraph_get()
370
+
371
+ # Iterate through the child mesh objects
372
+ for obj in empty_obj.children:
373
+ if obj.type == 'MESH':
374
+ # Get the bounding box coordinates in world space
375
+ bbox_corners = [obj.matrix_world @ mathutils.Vector(corner) for corner in obj.bound_box]
376
+
377
+ # Transform the bounding box corners to camera space
378
+ for corner in bbox_corners:
379
+ corner_2d = bpy_extras.object_utils.world_to_camera_view(scene, scene.camera, corner)
380
+
381
+ # Scale the coordinates to pixel space
382
+ x = corner_2d.x * res_x
383
+ y = (1 - corner_2d.y) * res_y # Flip Y since Blender renders from bottom to top
384
+
385
+ # Update the bounding box coordinates
386
+ min_x = min(min_x, x)
387
+ min_y = min(min_y, y)
388
+ max_x = max(max_x, x)
389
+ max_y = max(max_y, y)
390
+
391
+ # Return the 2D bounding box coordinates in pixel space
392
+ return (int(min_x), int(min_y), int(max_x), int(max_y))
393
+
394
+ def reset_cameras(scene) -> None:
395
+ """Resets the cameras in the scene to a single default camera."""
396
+ # Delete all existing cameras
397
+ bpy.ops.object.select_all(action="DESELECT")
398
+ bpy.ops.object.select_by_type(type="CAMERA")
399
+ bpy.ops.object.delete()
400
+
401
+ # Create a new camera with default properties
402
+ bpy.ops.object.camera_add()
403
+
404
+ # Get the camera by searching for it (it will be the only camera)
405
+ new_camera = None
406
+ for obj in scene.objects:
407
+ if obj.type == 'CAMERA':
408
+ new_camera = obj
409
+ break
410
+
411
+ new_camera.name = "Camera"
412
+
413
+ # Set the new camera as the active camera for the scene
414
+ scene.camera = new_camera
415
+
416
+
417
+ def add_plane():
418
+ print(f"in add_plane")
419
+
420
+ # Create mesh data
421
+ mesh = bpy.data.meshes.new("Plane")
422
+ backdrop = bpy.data.objects.new("Plane", mesh)
423
+ bpy.context.scene.collection.objects.link(backdrop)
424
+
425
+ # Create plane geometry using bmesh
426
+ bm = bmesh.new()
427
+ bmesh.ops.create_grid(bm, x_segments=1, y_segments=1, size=25.0) # size=25 gives 50x50 plane
428
+ bm.to_mesh(mesh)
429
+ bm.free()
430
+
431
+ # Add material
432
+ mat_backdrop = bpy.data.materials.new(name="WhiteMaterial")
433
+ mat_backdrop.diffuse_color = (0, 0, 0, 1) # Black
434
+ backdrop.data.materials.append(mat_backdrop)
435
+
436
+
437
+ def add_plane_cycles():
438
+ print(f"in add_plane")
439
+
440
+ # Create mesh data
441
+ mesh = bpy.data.meshes.new("Plane")
442
+ backdrop = bpy.data.objects.new("Plane", mesh)
443
+ bpy.context.scene.collection.objects.link(backdrop)
444
+
445
+ # Create plane geometry using bmesh
446
+ bm = bmesh.new()
447
+ bmesh.ops.create_grid(bm, x_segments=1, y_segments=1, size=25.0) # size=25 gives 50x50 plane
448
+ bm.to_mesh(mesh)
449
+ bm.free()
450
+
451
+ # Add material
452
+ mat_backdrop = bpy.data.materials.new(name="WhiteMaterial")
453
+ mat_backdrop.diffuse_color = (0.050, 0.050, 0.050, 1) # White
454
+ backdrop.data.materials.append(mat_backdrop)
455
+
456
+
457
+ def remove_all_planes():
458
+ # Deselect all objects first
459
+ bpy.ops.object.select_all(action='DESELECT')
460
+
461
+ # Select all plane objects in the scene
462
+ for obj in bpy.data.objects:
463
+ if obj.type == 'MESH' and obj.name.startswith('Plane'):
464
+ obj.select_set(True)
465
+
466
+ # Delete all selected planes
467
+ bpy.ops.object.delete()
468
+
469
+
470
+ def remove_all_lights():
471
+ """Remove all lights from the scene without using operators."""
472
+ lights_to_remove = [obj for obj in bpy.data.objects if obj.type == 'LIGHT']
473
+
474
+ for light in lights_to_remove:
475
+ bpy.data.objects.remove(light, do_unlink=True)
476
+
477
+ # Clean up orphaned light data blocks
478
+ for light_data in bpy.data.lights:
479
+ if light_data.users == 0:
480
+ bpy.data.lights.remove(light_data)
481
+
482
+
483
+ def set_lights_cv(radius, center, num_points, intensity):
484
+ print(f"in set_lights_cv")
485
+ radius = radius + 10.0
486
+ phi = np.random.uniform(-np.pi / 2, np.pi / 2, num_points) # azimuthal angle
487
+ cos_theta = np.random.uniform(0.50, 1.0, num_points) # cos of polar angle
488
+ theta = np.arccos(cos_theta) # polar angle
489
+ x = np.sin(theta) * np.cos(phi)
490
+ y = np.sin(theta) * np.sin(phi)
491
+ z = cos_theta # cos(theta) == z on unit sphere
492
+ # Scale to radius and shift to center
493
+ points = np.stack([x, y, z], axis=1) * radius + center
494
+ for point in points:
495
+ # Track objects before adding light
496
+ before_objs = set(bpy.data.objects)
497
+ bpy.ops.object.light_add(type='POINT', location=point)
498
+ after_objs = set(bpy.data.objects)
499
+
500
+ # Get the newly created light
501
+ diff_objs = after_objs - before_objs
502
+ light = list(diff_objs)[0]
503
+
504
+ light.data.energy = intensity
505
+ light.data.use_shadow = True
506
+ # light.data.shadow_soft_size = 1.0 # Adjust shadow softness if needed
507
+ return points
508
+
509
+
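For reference, the call below mirrors the one made later in render_cv (an editorial sketch, not part of the commit); note that the function adds 10.0 to the radius before sampling points on the upper hemisphere.

light_positions = set_lights_cv(
    radius=6.0,
    center=np.array([-6.0, 0.0, 0.0]),
    num_points=20,
    intensity=7000.0,
)
print(light_positions.shape)   # (20, 3)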
510
+ def adjust_color_brightness(rgb_color, factor):
511
+ """
512
+ Adjusts the brightness of an RGB color by a multiplicative factor.
513
+
514
+ Args:
515
+ rgb_color (tuple): The base color as an (R, G, B) or (R, G, B, A) tuple.
516
+ factor (float): The factor to multiply the brightness by.
517
+ > 1.0 makes it lighter, < 1.0 makes it darker.
518
+
519
+ Returns:
520
+ tuple: The new (R, G, B, A) color.
521
+ """
522
+ # Use only RGB for conversion, keep alpha separate
523
+ h, s, v = colorsys.rgb_to_hsv(rgb_color[0], rgb_color[1], rgb_color[2])
524
+
525
+ # Multiply the Value (brightness) by the factor, and clamp it between 0 and 1
526
+ v = max(0, min(1, v * factor))
527
+
528
+ new_rgb = colorsys.hsv_to_rgb(h, s, v)
529
+
530
+ # Return as an RGBA tuple, preserving original alpha if it exists
531
+ alpha = rgb_color[3] if len(rgb_color) == 4 else 1.0
532
+ return (new_rgb[0], new_rgb[1], new_rgb[2], alpha)
533
+
534
+
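A quick check of the brightness helper (editorial sketch): darkening pure green by a factor of 0.5 halves the HSV value channel and preserves alpha.

print(adjust_color_brightness((0.0, 1.0, 0.0, 1.0), 0.5))   # -> (0.0, 0.5, 0.0, 1.0)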
535
+ def get_primitive_object_translucent(base_color=(0.0, 1.0, 0.0), edge_color=None, face_opacity=0.025):
536
+ """
537
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
538
+
539
+ Args:
540
+ base_color (tuple): The base RGB color for the faces.
541
+ edge_color (tuple): The RGBA color for the edges (defaults to white).
542
+ face_opacity (float): The opacity of the cuboid faces (0.0 = invisible, 1.0 = opaque). Default is 0.025.
543
+ """
544
+ # --- Create the Cuboid and Parent ---
545
+ bpy.ops.object.empty_add(type="PLAIN_AXES")
546
+ # empty_object = bpy.context.object
547
+ empty_object = bpy.data.objects.new("Empty", None)
548
+ before_objs = set(bpy.data.objects)
549
+ bpy.ops.mesh.primitive_cube_add(size=0.5, location=(0, 0, 0))
550
+ after_objs = set(bpy.data.objects)
551
+ diff_objs = after_objs - before_objs
552
+
553
+ obj = None
554
+ for o in diff_objs:
555
+ obj = o
556
+ obj.parent = empty_object
557
+ world_matrix = obj.matrix_world
558
+ obj.matrix_world = world_matrix
559
+
560
+ # --- Create and Assign Materials for Each Face ---
561
+ if obj:
562
+ # left front right back bottom top
563
+ brightness_factors = [
564
+ 0.30, 0.30, 0.30, 0.30, 1.00, 0.30,
565
+ ]
566
+ colors = [adjust_color_brightness(base_color, factor) for factor in brightness_factors]
567
+
568
+ for i, color in enumerate(colors):
569
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
570
+ material.use_nodes = True
571
+ obj.data.materials.append(material)
572
+
573
+ nodes = material.node_tree.nodes
574
+ links = material.node_tree.links
575
+ nodes.clear()
576
+
577
+ # Create Principled BSDF instead of Emission for proper transparency
578
+ bsdf = nodes.new(type="ShaderNodeBsdfPrincipled")
579
+ bsdf.location = (0, 0)
580
+ bsdf.inputs['Base Color'].default_value = color
581
+ bsdf.inputs['Alpha'].default_value = face_opacity # Set face opacity
582
+ bsdf.inputs['Emission Color'].default_value = color[:3] + (1.0,) # Fixed: Use 'Emission Color' instead of 'Emission'
583
+ bsdf.inputs['Emission Strength'].default_value = 1.0 # Emission strength
584
+
585
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
586
+ material_output.location = (200, 0)
587
+ links.new(bsdf.outputs['BSDF'], material_output.inputs['Surface'])
588
+
589
+ # Enable transparency settings for the material
590
+ material.blend_method = 'BLEND'
591
+ material.show_transparent_back = False
592
+
593
+ if len(obj.data.polygons) == len(colors):
594
+ for i, poly in enumerate(obj.data.polygons):
595
+ poly.material_index = i
596
+ else:
597
+ print("Warning: The number of colors does not match the number of faces.")
598
+
599
+ # --- Add Wireframe Edges ---
600
+ # edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
601
+ # edge_material.use_nodes = True
602
+
603
+ # nodes = edge_material.node_tree.nodes
604
+ # links = edge_material.node_tree.links
605
+ # nodes.clear()
606
+
607
+ # if edge_color is None:
608
+ # edge_color = adjust_color_brightness(base_color, 0.10)
609
+
610
+ # edge_emission_node = nodes.new(type="ShaderNodeEmission")
611
+ # edge_emission_node.inputs['Color'].default_value = edge_color
612
+ # edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
613
+ # links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
614
+
615
+ # obj.data.materials.append(edge_material)
616
+
617
+ # wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
618
+ # wire_mod.thickness = 0.01
619
+ # wire_mod.use_replace = False
620
+ # wire_mod.material_offset = len(obj.data.materials) - 1
621
+
622
+ # --- Bounding Box Calculation ---
623
+ bbox_corners = []
624
+ bpy.context.view_layer.update()
625
+ for child in empty_object.children:
626
+ for corner in child.bound_box:
627
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
628
+ bbox_corners.append(world_corner)
629
+
630
+ if not bbox_corners:
631
+ return 0, empty_object
632
+
633
+ min_x = min(corner.x for corner in bbox_corners)
634
+ min_y = min(corner.y for corner in bbox_corners)
635
+ min_z = min(corner.z for corner in bbox_corners)
636
+
637
+ max_x = max(corner.x for corner in bbox_corners)
638
+ max_y = max(corner.y for corner in bbox_corners)
639
+ max_z = max(corner.z for corner in bbox_corners)
640
+
641
+ return max_z, empty_object
642
+
643
+
644
+ def get_primitive_object_translucent_rgb(base_color=(0.0, 1.0, 0.0), edge_color=None, face_opacity=0.025):
645
+ """
646
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
647
+
648
+ Args:
649
+ base_color (tuple): The base RGB color for the faces.
650
+ edge_color (tuple): The RGBA color for the edges (defaults to a darkened shade of the base color).
651
+ face_opacity (float): The opacity of the cuboid faces (0.0 = invisible, 1.0 = opaque). Default is 0.025.
652
+ """
653
+ # --- Create the Cuboid and Parent ---
654
+ bpy.ops.object.empty_add(type="PLAIN_AXES")
655
+ # empty_object = bpy.context.object
656
+ empty_object = bpy.data.objects.new("Empty", None)
657
+ before_objs = set(bpy.data.objects)
658
+ bpy.ops.mesh.primitive_cube_add(size=0.5, location=(0, 0, 0))
659
+ after_objs = set(bpy.data.objects)
660
+ diff_objs = after_objs - before_objs
661
+
662
+ obj = None
663
+ for o in diff_objs:
664
+ obj = o
665
+ obj.parent = empty_object
666
+ world_matrix = obj.matrix_world
667
+ obj.matrix_world = world_matrix
668
+
669
+ # --- Create and Assign Materials for Each Face ---
670
+ if obj:
671
+ # left front right back bottom top
672
+ brightness_factors = [
673
+ 0.50, 0.50, 0.50, 0.50, 0.50, 0.50,
674
+ ]
675
+ red = (1.0, 0.0, 0.0, 1.0)
676
+ green = (0.0, 1.0, 0.0, 1.0)
677
+ blue = (0.0, 0.0, 1.0, 1.0)
678
+ colors = [adjust_color_brightness(green, factor) for factor in brightness_factors[:4]] + [adjust_color_brightness(blue, brightness_factors[4])] + [adjust_color_brightness(red, brightness_factors[5])]
679
+ colors = [colors[-2], colors[-1], colors[0], colors[1], colors[2], colors[3]]
680
+
681
+ for i, color in enumerate(colors):
682
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
683
+ material.use_nodes = True
684
+ obj.data.materials.append(material)
685
+
686
+ nodes = material.node_tree.nodes
687
+ links = material.node_tree.links
688
+ nodes.clear()
689
+
690
+ # Create Principled BSDF instead of Emission for proper transparency
691
+ bsdf = nodes.new(type="ShaderNodeBsdfPrincipled")
692
+ bsdf.location = (0, 0)
693
+ bsdf.inputs['Base Color'].default_value = color
694
+ bsdf.inputs['Alpha'].default_value = face_opacity # Set face opacity
695
+ bsdf.inputs['Emission Color'].default_value = color[:3] + (1.0,) # Fixed: Use 'Emission Color' instead of 'Emission'
696
+ bsdf.inputs['Emission Strength'].default_value = 1.0 # Emission strength
697
+
698
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
699
+ material_output.location = (200, 0)
700
+ links.new(bsdf.outputs['BSDF'], material_output.inputs['Surface'])
701
+
702
+ # Enable transparency settings for the material
703
+ material.blend_method = 'BLEND'
704
+ material.show_transparent_back = False
705
+
706
+ if len(obj.data.polygons) == len(colors):
707
+ for i, poly in enumerate(obj.data.polygons):
708
+ poly.material_index = i
709
+ else:
710
+ print("Warning: The number of colors does not match the number of faces.")
711
+
712
+ # --- Add Wireframe Edges ---
713
+ edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
714
+ edge_material.use_nodes = True
715
+
716
+ nodes = edge_material.node_tree.nodes
717
+ links = edge_material.node_tree.links
718
+ nodes.clear()
719
+
720
+ if edge_color is None:
721
+ edge_color = adjust_color_brightness(base_color, 0.10)
722
+
723
+ edge_emission_node = nodes.new(type="ShaderNodeEmission")
724
+ edge_emission_node.inputs['Color'].default_value = edge_color
725
+ edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
726
+ links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
727
+
728
+ obj.data.materials.append(edge_material)
729
+
730
+ wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
731
+ wire_mod.thickness = 0.01
732
+ wire_mod.use_replace = False
733
+ wire_mod.material_offset = len(obj.data.materials) - 1
734
+
735
+ # --- Bounding Box Calculation ---
736
+ bbox_corners = []
737
+ bpy.context.view_layer.update()
738
+ for child in empty_object.children:
739
+ for corner in child.bound_box:
740
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
741
+ bbox_corners.append(world_corner)
742
+
743
+ if not bbox_corners:
744
+ return 0, empty_object
745
+
746
+ min_x = min(corner.x for corner in bbox_corners)
747
+ min_y = min(corner.y for corner in bbox_corners)
748
+ min_z = min(corner.z for corner in bbox_corners)
749
+
750
+ max_x = max(corner.x for corner in bbox_corners)
751
+ max_y = max(corner.y for corner in bbox_corners)
752
+ max_z = max(corner.z for corner in bbox_corners)
753
+
754
+ return max_z, empty_object
755
+
756
+
757
+
758
+ def get_primitive_object(base_color=(0.0, 1.0, 0.0), edge_color=None):
759
+ """
760
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
761
+
762
+ Args:
763
+ base_color (tuple): The base RGB color for the faces.
764
+ edge_color (tuple): The RGBA color for the edges (defaults to a darkened shade of the base color).
765
+ """
766
+ # --- Create the Empty Parent ---
767
+ empty_object = bpy.data.objects.new("Empty", None)
768
+ bpy.context.scene.collection.objects.link(empty_object)
769
+ empty_object.empty_display_type = 'PLAIN_AXES'
770
+
771
+ # --- Create the Cuboid using bmesh ---
772
+ mesh = bpy.data.meshes.new("Cube")
773
+ obj = bpy.data.objects.new("Cube", mesh)
774
+ bpy.context.scene.collection.objects.link(obj)
775
+
776
+ # Create cube geometry
777
+ bm = bmesh.new()
778
+ bmesh.ops.create_cube(bm, size=0.5)
779
+ bm.to_mesh(mesh)
780
+ bm.free()
781
+
782
+ # Set parent
783
+ obj.parent = empty_object
784
+ world_matrix = obj.matrix_world
785
+ obj.matrix_world = world_matrix
786
+
787
+ # --- Create and Assign Materials for Each Face ---
788
+ if obj:
789
+ # left front right back bottom top
790
+ brightness_factors = [
791
+ 0.35, 0.20, 0.65, 0.90, 0.50, 0.50
792
+ ]
793
+ colors = [adjust_color_brightness(base_color, factor) for factor in brightness_factors]
794
+
795
+ for i, color in enumerate(colors):
796
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
797
+ material.use_nodes = True
798
+ obj.data.materials.append(material)
799
+
800
+ nodes = material.node_tree.nodes
801
+ links = material.node_tree.links
802
+ nodes.clear()
803
+
804
+ emission_node = nodes.new(type="ShaderNodeEmission")
805
+ emission_node.inputs['Color'].default_value = color
806
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
807
+ links.new(emission_node.outputs['Emission'], material_output.inputs['Surface'])
808
+
809
+ material.blend_method = 'BLEND'
810
+ material.show_transparent_back = False
811
+
812
+ if len(obj.data.polygons) == len(colors):
813
+ for i, poly in enumerate(obj.data.polygons):
814
+ poly.material_index = i
815
+ else:
816
+ print("Warning: The number of colors does not match the number of faces.")
817
+
818
+ # --- MODIFICATION START: Add White Edges ---
819
+
820
+ # 1. Create a new material for the wireframe edges
821
+ edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
822
+ edge_material.use_nodes = True
823
+
824
+ # Set up the nodes for a simple white emission shader
825
+ nodes = edge_material.node_tree.nodes
826
+ links = edge_material.node_tree.links
827
+ nodes.clear()
828
+
829
+ if edge_color is None:
830
+ edge_color = adjust_color_brightness(base_color, 0.10)
831
+
832
+ edge_emission_node = nodes.new(type="ShaderNodeEmission")
833
+ edge_emission_node.inputs['Color'].default_value = edge_color
834
+ edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
835
+ links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
836
+
837
+ # 2. Add the edge material to the object's material slots
838
+ obj.data.materials.append(edge_material)
839
+
840
+ # 3. Add and configure the Wireframe modifier
841
+ wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
842
+ wire_mod.thickness = 0.01 # The thickness of the edge lines
843
+ wire_mod.use_replace = False # Set to False to keep the original faces
844
+ # This offset tells the modifier to use the last material we added (the white one)
845
+ wire_mod.material_offset = len(obj.data.materials) - 1
846
+
847
+ # --- MODIFICATION END ---
848
+
849
+
850
+ # --- Bounding Box Calculation (remains the same) ---
851
+ bbox_corners = []
852
+ # Update the dependency graph to ensure modifiers are accounted for
853
+ bpy.context.view_layer.update()
854
+ for child in empty_object.children:
855
+ # Use child.bound_box which is in object's local space
856
+ for corner in child.bound_box:
857
+ # Convert corner to world space
858
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
859
+ bbox_corners.append(world_corner)
860
+
861
+ if not bbox_corners:
862
+ return 0, empty_object # Return a default value if no corners found
863
+
864
+ min_x = min(corner.x for corner in bbox_corners)
865
+ min_y = min(corner.y for corner in bbox_corners)
866
+ min_z = min(corner.z for corner in bbox_corners)
867
+
868
+ max_x = max(corner.x for corner in bbox_corners)
869
+ max_y = max(corner.y for corner in bbox_corners)
870
+ max_z = max(corner.z for corner in bbox_corners)
871
+
872
+ return max_z, empty_object
873
+
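Usage sketch for the cuboid primitive (editorial; the positions and scales are illustrative): the function returns the top z of the spawned cube together with its empty parent, which the renderer then moves and scales.

top_z, cuboid = get_primitive_object(base_color=(1.0, 0.0, 0.0))
cuboid.location = (-6.0, 0.0, top_z)     # rest the cube on the ground plane
cuboid.scale = (2.0, 1.0, 1.0)           # stretch along x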
874
+ class BlenderCuboidRenderer:
875
+ def __init__(self, render_engine):
876
+ """
877
+ Initialize the Blender cuboid renderer.
878
+
879
+ Args:
880
+ render_engine (str): Blender render engine ('EEVEE' or 'CYCLES')
881
+
882
+ Note: the image resolution is fixed at 1024x1024; lighting and object
883
+ placement are configured per render call.
884
+ """
885
+ self.img_dim = 1024
886
+ self.render_engine = render_engine
887
+ self.blender_grid_dims = scales
888
+
889
+ self.radius = 6.0
890
+ self.center = -6.0
891
+
892
+ # Scene references
893
+ self.context = None
894
+ self.scene = None
895
+ self.camera = None
896
+ self.render = None
897
+
898
+ # Setup the scene
899
+ self.setup_scene()
900
+
901
+
902
+ def setup_scene(self):
903
+ """
904
+ Setup the basic Blender scene with camera, lighting, and render settings.
905
+
906
+ Args:
907
+ camera_data (dict): Camera configuration containing elevation, lens, global_scale, etc.
908
+ """
909
+ # Get all objects in the scene
910
+ objects_to_remove = []
911
+
912
+ for obj in bpy.data.objects:
913
+ # Remove default cube, plane, camera, and lights
914
+ if obj.type in {'MESH', 'LIGHT', 'CAMERA'}:
915
+ objects_to_remove.append(obj)
916
+
917
+ # Delete the objects
918
+ for obj in objects_to_remove:
919
+ bpy.data.objects.remove(obj, do_unlink=True)
920
+
921
+ # Also clear orphaned data
922
+ for mesh in bpy.data.meshes:
923
+ if mesh.users == 0:
924
+ bpy.data.meshes.remove(mesh)
925
+
926
+ for light in bpy.data.lights:
927
+ if light.users == 0:
928
+ bpy.data.lights.remove(light)
929
+
930
+ for camera in bpy.data.cameras:
931
+ if camera.users == 0:
932
+ bpy.data.cameras.remove(camera)
933
+
934
+ bpy.context.scene.world = None
935
+
936
+ # Initialize Blender scene
937
+ # bpy.ops.wm.read_factory_settings(use_empty=True)
938
+ self.context = bpy.context
939
+ self.scene = self.context.scene
940
+ if self.render_engine == "CYCLES":
941
+ self.scene.cycles.samples = 32
942
+ self.render = self.scene.render
943
+
944
+ # Set render engine and resolution
945
+ self.render.engine = self.render_engine
946
+ self.context.scene.render.resolution_x = self.img_dim
947
+ self.context.scene.render.resolution_y = self.img_dim
948
+ self.context.scene.render.resolution_percentage = 100
949
+
950
+ # Setup compositing nodes
951
+ self._setup_compositing()
952
+
953
+
954
+ def _setup_compositing(self):
955
+ """Setup Blender compositing nodes for depth and RGB output."""
956
+ self.context.scene.use_nodes = True
957
+ tree = self.context.scene.node_tree
958
+ links = tree.links
959
+
960
+ self.context.scene.render.use_compositing = True
961
+ self.context.view_layer.use_pass_z = True
962
+
963
+ # clear default nodes
964
+ for n in tree.nodes:
965
+ tree.nodes.remove(n)
966
+
967
+ # create input render layer node
968
+ rl = tree.nodes.new('CompositorNodeRLayers')
969
+
970
+ map_node = tree.nodes.new(type="CompositorNodeMapValue")
971
+ map_node.size = [0.05]
972
+ map_node.use_min = True
973
+ map_node.min = [0]
974
+ map_node.use_max = True
975
+ map_node.max = [65336]
976
+ links.new(rl.outputs[2], map_node.inputs[0])
977
+
978
+ invert = tree.nodes.new(type="CompositorNodeInvert")
979
+ links.new(map_node.outputs[0], invert.inputs[1])
980
+
981
+ # create output node
982
+ v = tree.nodes.new('CompositorNodeViewer')
983
+ v.use_alpha = True
984
+
985
+ # create a file output node and set the path
986
+ fileOutput = tree.nodes.new(type="CompositorNodeOutputFile")
987
+ fileOutput.base_path = "."
988
+ links.new(invert.outputs[0], fileOutput.inputs[0])
989
+
990
+ # Links
991
+ links.new(rl.outputs[0], v.inputs[0]) # link Image to Viewer Image RGB
992
+ links.new(rl.outputs['Depth'], v.inputs[1]) # link Render Z to Viewer Image Alpha
993
+
994
+ # Update scene to apply changes
995
+ self.context.view_layer.update()
996
+
997
+
998
+ def _setup_camera_cv(self, camera_data):
999
+ """Setup camera position and orientation."""
1000
+ reset_cameras(self.scene)
1001
+ self.camera = self.scene.objects["Camera"]
1002
+
1003
+ elevation = camera_data["camera_elevation"]
1004
+ tan_elevation = np.tan(elevation)
1005
+ cos_elevation = np.cos(elevation)
1006
+ sin_elevation = np.sin(elevation)
1007
+
1008
+ radius = self.radius
1009
+ center = self.center
1010
+
1011
+ self.camera.location = mathutils.Vector((radius * cos_elevation + center, 0, radius * sin_elevation))
1012
+ direction = mathutils.Vector((-1, 0, -tan_elevation))
1013
+ self.context.scene.camera = self.camera
1014
+ rot_quat = direction.to_track_quat("-Z", "Y")
1015
+ self.camera.rotation_euler = rot_quat.to_euler()
1016
+ self.camera.data.lens = camera_data["lens"]
1017
+
1018
+ def _create_cuboid_objects_translucent(self, subjects_data, opacity=0.025):
1019
+ """Create primitive cuboid objects for all subjects."""
1020
+ for subject_idx, subject_data in enumerate(subjects_data):
1021
+ # rgb_color = map_point_to_rgb(x, y)
1022
+ rgb_color = COLORS[subject_idx % len(COLORS)]
1023
+ _, prim_obj = get_primitive_object_translucent(base_color=rgb_color, face_opacity=opacity)
1024
+ prim_obj.location = np.array([100, 0, 0])
1025
+ subject_data["prim_obj"] = prim_obj
1026
+
1027
+ def _create_cuboid_objects_translucent_rgb(self, subjects_data, opacity=0.025):
1028
+ """Create primitive cuboid objects for all subjects."""
1029
+ for subject_idx, subject_data in enumerate(subjects_data):
1030
+ x = subject_data["x"][0]
1031
+ y = subject_data["y"][0]
1032
+ z = subject_data["z"][0]
1033
+ base_color = map_point_to_rgb(x, y, z)
1034
+ _, prim_obj = get_primitive_object_translucent_rgb(base_color=base_color, face_opacity=opacity)
1035
+ prim_obj.location = np.array([100, 0, 0])
1036
+ subject_data["prim_obj"] = prim_obj
1037
+
1038
+
1039
+ def _place_objects(self, subjects_data, camera_data):
1040
+ """Place objects in the scene according to their data."""
1041
+ global_scale = camera_data["global_scale"]
1042
+
1043
+ for subject_data in subjects_data:
1044
+ x = subject_data["x"][0]
1045
+ y = subject_data["y"][0]
1046
+ z = global_scale * subject_data["dims"][2] / 2.0 + subject_data["z"][0]
1047
+ subject_data["prim_obj"].location = np.array([x, y, z])
1048
+ subject_data["prim_obj"].scale = global_scale * np.array(subject_data["dims"]) * 2.0
1049
+ subject_data["prim_obj"].rotation_euler[2] = subject_data["azimuth"][0]
1050
+
1051
+ def render_cv(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1052
+ """
1053
+ Main render method that takes subjects data and renders the scene.
1054
+
1055
+ Args:
1056
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1057
+ camera_data (dict): Camera configuration
1058
+ num_samples (int): Number of samples to render (currently only supports 1)
1059
+ output_path (str): Path to save the rendered image
1060
+
1061
+ Returns:
1062
+ None
1063
+ """
1064
+ center = (-6.0, 0.0, 0.0)
1065
+ radius = 6.0
1066
+
1067
+ print(f"render_cv received {subjects_data = }")
1068
+
1069
+ # print(f"render_cv received {subjects_data = }")
1070
+ for subject_data in subjects_data:
1071
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1072
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1073
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1074
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1075
+ # Setup camera
1076
+ self._setup_camera_cv(camera_data)
1077
+
1078
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 20, intensity=7000.0)
1079
+
1080
+ # Add ground plane
1081
+ add_plane()
1082
+
1083
+ assert num_samples == 1, "for now, only implemented for a single sample"
1084
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1085
+
1086
+ # Create primitive objects for subjects
1087
+ self._create_cuboid_objects_translucent(subjects_data, opacity=0.025)
1088
+ # self._create_cuboid_objects(subjects_data)
1089
+
1090
+ # Place objects in scene
1091
+ self._place_objects(subjects_data, camera_data)
1092
+
1093
+ # Perform rendering
1094
+ print(f"SUCCESS, rendering...")
1095
+ self.context.scene.render.filepath = output_path
1096
+ self.context.scene.render.image_settings.file_format = "JPEG"
1097
+ bpy.ops.render.render(write_still=True)
1098
+
1099
+ print(f"Rendered scene saved to: {output_path}")
1100
+
1101
+ self.cleanup()
1102
+
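The input format render_cv expects can be read off _place_objects and the coordinate shifts above; the sketch below is editorial and the concrete values are illustrative. Per-subject x/y/z/azimuth are one-element lists (azimuth in degrees, converted to radians inside render_cv), dims is the (width, depth, height) triple, and camera_data carries the elevation in radians, the lens in millimetres, and a global scale.

renderer = BlenderCuboidRenderer(render_engine="CYCLES")
subjects = [{
    "x": [2.0], "y": [0.0], "z": [0.0],
    "azimuth": [45.0],
    "dims": (1.0, 1.0, 1.5),
}]
camera = {"camera_elevation": np.deg2rad(30.0), "lens": 50.0, "global_scale": 1.0}
renderer.render_cv(subjects, camera, output_path="preview.jpg")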
1103
+ def render_final_representation(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1104
+ """
1105
+ Main render method that takes subjects data and renders the scene.
1106
+
1107
+ Args:
1108
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1109
+ camera_data (dict): Camera configuration
1110
+ num_samples (int): Number of samples to render (currently only supports 1)
1111
+ output_path (str): Path to save the rendered image
1112
+
1113
+ Returns:
1114
+ None
1115
+ """
1116
+ assert self.render.engine == "CYCLES", "render_final_representation only works with CYCLES render engine"
1117
+ center = (-6.0, 0.0, 0.0)
1118
+ radius = 6.0
1119
+
1120
+ print(f"render_cv received {subjects_data = }")
1121
+
1122
+ # print(f"render_cv received {subjects_data = }")
1123
+ for subject_data in subjects_data:
1124
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1125
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1126
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1127
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1128
+ # Setup camera
1129
+ self._setup_camera_cv(camera_data)
1130
+
1131
+ print(f"setting lights in cycles...")
1132
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 5, intensity=700.0)
1133
+
1134
+ # Add ground plane
1135
+ print(f"adding plane in cycles...")
1136
+ add_plane_cycles()
1137
+
1138
+ assert num_samples == 1, "for now, only implemented for a single sample"
1139
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1140
+
1141
+ # Create primitive objects for subjects
1142
+ self._create_cuboid_objects_translucent_rgb(subjects_data, opacity=0.025)
1143
+ # self._create_cuboid_objects(subjects_data)
1144
+
1145
+ # Place objects in scene
1146
+ self._place_objects(subjects_data, camera_data)
1147
+
1148
+ # Perform rendering
1149
+ print(f"SUCCESS, rendering...")
1150
+ self.context.scene.render.filepath = output_path
1151
+ self.context.scene.render.image_settings.file_format = "JPEG"
1152
+ bpy.ops.render.render(write_still=True)
1153
+
1154
+ print(f"Rendered scene saved to: {output_path}")
1155
+
1156
+ self.cleanup()
1157
+
1158
+
1159
+ def render_paper_figure(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1160
+ """
1161
+ Main render method that takes subjects data and renders the scene.
1162
+
1163
+ Args:
1164
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1165
+ camera_data (dict): Camera configuration
1166
+ num_samples (int): Number of samples to render (currently only supports 1)
1167
+ output_path (str): Path to save the rendered image
1168
+
1169
+ Returns:
1170
+ None
1171
+ """
1172
+ assert self.render.engine == "CYCLES", "render_final_representation only works with CYCLES render engine"
1173
+ center = (-6.0, 0.0, 0.0)
1174
+ radius = 6.0
1175
+
1176
+ print(f"render_cv received {subjects_data = }")
1177
+
1178
+ set_world_color((1.0, 1.0, 1.0)) # white background
1179
+
1180
+ # print(f"render_cv received {subjects_data = }")
1181
+ for subject_data in subjects_data:
1182
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1183
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1184
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1185
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1186
+ # Setup camera
1187
+ self._setup_camera_cv(camera_data)
1188
+
1189
+ print(f"setting lights in cycles...")
1190
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 5, intensity=7000.0)
1191
+
1192
+ # Add ground plane
1193
+ print(f"adding plane in cycles...")
1194
+
1195
+ assert num_samples == 1, "for now, only implemented for a single sample"
1196
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1197
+
1198
+ # Create primitive objects for subjects
1199
+ self._create_cuboid_objects_translucent(subjects_data, opacity=0.35)
1200
+ # self._create_cuboid_objects(subjects_data)
1201
+
1202
+ # Place objects in scene
1203
+ self._place_objects(subjects_data, camera_data)
1204
+
1205
+ # Perform rendering
1206
+ print(f"SUCCESS, rendering...")
1207
+ self.context.scene.render.filepath = output_path
1208
+ self.context.scene.render.image_settings.file_format = "JPEG"
1209
+ bpy.ops.render.render(write_still=True)
1210
+
1211
+ print(f"Rendered scene saved to: {output_path}")
1212
+
1213
+ self.cleanup()
1214
+
1215
+
1216
+ def cleanup(self):
1217
+ """Clean up the scene for next render."""
1218
+ # Remove all lights
1219
+ remove_all_lights()
1220
+
1221
+ # Remove all other objects (meshes, empties, etc.)
1222
+ objects_to_remove = [obj for obj in bpy.data.objects]
1223
+
1224
+ for obj in objects_to_remove:
1225
+ bpy.data.objects.remove(obj, do_unlink=True)
1226
+
1227
+ # Clean up orphaned data blocks
1228
+ for mesh in bpy.data.meshes:
1229
+ if mesh.users == 0:
1230
+ bpy.data.meshes.remove(mesh)
1231
+
1232
+ for material in bpy.data.materials:
1233
+ if material.users == 0:
1234
+ bpy.data.materials.remove(material)
1235
+
1236
+ for light_data in bpy.data.lights:
1237
+ if light_data.users == 0:
1238
+ bpy.data.lights.remove(light_data)
1239
+
1240
+
1241
+ class BlenderSegmaskRenderer:
1242
+ def __init__(self):
1243
+ """
1244
+ Initialize the Blender segmentation-mask renderer.
1245
+
1246
+ Takes no arguments; fixed settings are used:
1247
+ img_dim: 1024 (square output)
1248
+ render_engine: BLENDER_WORKBENCH
1249
+ radius / center: 6.0 / -6.0 (camera orbit around the scene)
1250
+ output: one {subject_idx:03d}_segmask_cv.png mask per subject
1251
+ """
1252
+ self.img_dim = 1024
1253
+ self.render_engine = "BLENDER_WORKBENCH"
1254
+ self.blender_grid_dims = scales
1255
+
1256
+ self.radius = 6.0
1257
+ self.center = -6.0
1258
+
1259
+ # Scene references
1260
+ self.context = None
1261
+ self.scene = None
1262
+ self.camera = None
1263
+ self.render = None
1264
+
1265
+ # Setup the scene
1266
+ self.setup_scene()
1267
+
1268
+
1269
+ def setup_scene(self):
1270
+ """
1271
+ Setup the basic Blender scene: clear existing objects, set the render
1272
+ engine and resolution, and configure the compositing nodes.
1273
+
1274
+ Takes no arguments.
1275
+ """
1276
+ # Get all objects in the scene
1277
+ objects_to_remove = []
1278
+
1279
+ for obj in bpy.data.objects:
1280
+ # Remove default cube, plane, camera, and lights
1281
+ if obj.type in {'MESH', 'LIGHT', 'CAMERA'}:
1282
+ objects_to_remove.append(obj)
1283
+
1284
+ # Delete the objects
1285
+ for obj in objects_to_remove:
1286
+ bpy.data.objects.remove(obj, do_unlink=True)
1287
+
1288
+ # Also clear orphaned data
1289
+ for mesh in bpy.data.meshes:
1290
+ if mesh.users == 0:
1291
+ bpy.data.meshes.remove(mesh)
1292
+
1293
+ for light in bpy.data.lights:
1294
+ if light.users == 0:
1295
+ bpy.data.lights.remove(light)
1296
+
1297
+ for camera in bpy.data.cameras:
1298
+ if camera.users == 0:
1299
+ bpy.data.cameras.remove(camera)
1300
+
1301
+ bpy.context.scene.world = None
1302
+
1303
+ # Initialize Blender scene
1304
+ # bpy.ops.wm.read_factory_settings(use_empty=True)
1305
+ self.context = bpy.context
1306
+ self.scene = self.context.scene
1307
+ self.render = self.scene.render
1308
+
1309
+ # Set render engine and resolution
1310
+ self.render.engine = self.render_engine
1311
+ self.context.scene.render.resolution_x = self.img_dim
1312
+ self.context.scene.render.resolution_y = self.img_dim
1313
+ self.context.scene.render.resolution_percentage = 100
1314
+
1315
+ # Setup compositing nodes
1316
+ self._setup_compositing()
1317
+
1318
+
1319
+ def _setup_compositing(self):
1320
+ """Setup Blender compositing nodes for depth and RGB output."""
1321
+ self.context.scene.use_nodes = True
1322
+ tree = self.context.scene.node_tree
1323
+ links = tree.links
1324
+
1325
+ self.context.scene.render.use_compositing = True
1326
+ self.context.view_layer.use_pass_z = True
1327
+
1328
+ # clear default nodes
1329
+ for n in tree.nodes:
1330
+ tree.nodes.remove(n)
1331
+
1332
+ # create input render layer node
1333
+ rl = tree.nodes.new('CompositorNodeRLayers')
1334
+
1335
+ map_node = tree.nodes.new(type="CompositorNodeMapValue")
1336
+ map_node.size = [0.05]
1337
+ map_node.use_min = True
1338
+ map_node.min = [0]
1339
+ map_node.use_max = True
1340
+ map_node.max = [65536]
1341
+ links.new(rl.outputs[2], map_node.inputs[0])
1342
+
1343
+ invert = tree.nodes.new(type="CompositorNodeInvert")
1344
+ links.new(map_node.outputs[0], invert.inputs[1])
1345
+
1346
+ # create output node
1347
+ v = tree.nodes.new('CompositorNodeViewer')
1348
+ v.use_alpha = True
1349
+
1350
+ # create a file output node and set the path
1351
+ fileOutput = tree.nodes.new(type="CompositorNodeOutputFile")
1352
+ fileOutput.base_path = "."
1353
+ links.new(invert.outputs[0], fileOutput.inputs[0])
1354
+
1355
+ # Links
1356
+ links.new(rl.outputs[0], v.inputs[0]) # link Image to Viewer Image RGB
1357
+ links.new(rl.outputs['Depth'], v.inputs[1]) # link Render Z to Viewer Image Alpha
1358
+
1359
+ # Update scene to apply changes
1360
+ self.context.view_layer.update()
1361
+
1362
+
1363
+ def _setup_camera_cv(self, camera_data):
1364
+ """Setup camera position and orientation."""
1365
+ reset_cameras(self.scene)
1366
+ self.camera = self.scene.objects["Camera"]
1367
+
1368
+ elevation = camera_data["camera_elevation"]
1369
+ tan_elevation = np.tan(elevation)
1370
+ cos_elevation = np.cos(elevation)
1371
+ sin_elevation = np.sin(elevation)
1372
+
1373
+ radius = self.radius
1374
+ center = self.center
1375
+
1376
+ self.camera.location = mathutils.Vector((radius * cos_elevation + center, 0, radius * sin_elevation))
1377
+ direction = mathutils.Vector((-1, 0, -tan_elevation))
1378
+ self.context.scene.camera = self.camera
1379
+ rot_quat = direction.to_track_quat("-Z", "Y")
1380
+ self.camera.rotation_euler = rot_quat.to_euler()
1381
+ self.camera.data.lens = camera_data["lens"]
1382
+
1383
+ def _create_cuboid_objects(self, subjects_data):
1384
+ """Create primitive cuboid objects for all subjects."""
1385
+ for subject_idx, subject_data in enumerate(subjects_data):
1386
+ x = subject_data["x"][0]
1387
+ y = subject_data["y"][0]
1388
+ z = subject_data["z"][0]
1389
+ rgb_color = map_point_to_rgb(x, y, z)
1390
+ _, prim_obj = get_primitive_object(rgb_color)
1391
+ prim_obj.location = np.array([100, 0, 0])
1392
+ subject_data["prim_obj"] = prim_obj
1393
+
1394
+ def _place_objects(self, subjects_data, camera_data):
1395
+ """Place objects in the scene according to their data."""
1396
+ global_scale = camera_data["global_scale"]
1397
+
1398
+ for subject_data in subjects_data:
1399
+ x = subject_data["x"][0]
1400
+ y = subject_data["y"][0]
1401
+ z = global_scale * subject_data["dims"][2] / 2.0 + subject_data["z"][0]
1402
+ subject_data["prim_obj"].location = np.array([x, y, z])
1403
+ subject_data["prim_obj"].scale = global_scale * np.array(subject_data["dims"]) * 2.0
1404
+ subject_data["prim_obj"].rotation_euler[2] = subject_data["azimuth"][0]
1405
+
1406
+ def render_cv(self, subjects_data, camera_data, num_samples=1):
1407
+ """
1408
+ Main render method that takes subjects data and renders the scene.
1409
+
1410
+ Args:
1411
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1412
+ camera_data (dict): Camera configuration
1413
+ num_samples (int): Number of samples to render (currently only supports 1)
1414
+ (no output_path argument; masks are saved as {subject_idx:03d}_segmask_cv.png)
1415
+
1416
+ Returns:
1417
+ None
1418
+ """
1419
+ # Setup camera
1420
+ center = (-6.0, 0.0, 0.0)
1421
+ radius = 6.0
1422
+
1423
+ for subject_data in subjects_data:
1424
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1425
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1426
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1427
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1428
+
1429
+ print(f"in segmask render, {subjects_data = }")
1430
+
1431
+ self._setup_camera_cv(camera_data)
1432
+
1433
+ assert num_samples == 1, "for now, only implemented for a single sample"
1434
+ assert "global_scale" in camera_data.keys(), "global_scale must be set"
1435
+
1436
+ # Create primitive objects for subjects
1437
+ self._create_cuboid_objects(subjects_data)
1438
+
1439
+ def make_segmask(image):
1440
+ alpha = image[:, :, 3]
1441
+ _, mask = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY)
1442
+ return mask
1443
+
1444
+
1445
+ for subject_idx, subject_data in enumerate(subjects_data):
1446
+ # Place objects in scene
1447
+ self._place_objects([subject_data], camera_data)
1448
+
1449
+ # Perform rendering
1450
+ print(f"SUCCESS, rendering...")
1451
+ self.context.scene.render.filepath = "tmp.png"
1452
+ self.context.scene.render.image_settings.file_format = "PNG"
1453
+ bpy.ops.render.render(write_still=True)
1454
+ img = cv2.imread("tmp.png", cv2.IMREAD_UNCHANGED)
1455
+ segmask = make_segmask(img)
1456
+ print(f"{segmask.shape = }")
1457
+ cv2.imwrite(f"{str(subject_idx).zfill(3)}_segmask_cv.png", segmask)
1458
+ print(f"saved {str(subject_idx).zfill(3)}_segmask_cv.png")
1459
+
1460
+ subject_data["prim_obj"].location = np.array([100, 0, 0]) # move out of view
1461
+
1462
+ self.cleanup()
1463
+
1464
+
1465
+ def cleanup(self):
1466
+ """Clean up the scene for next render."""
1467
+ # Remove all lights
1468
+ remove_all_lights()
1469
+
1470
+ # Remove all other objects (meshes, empties, etc.)
1471
+ objects_to_remove = [obj for obj in bpy.data.objects]
1472
+
1473
+ for obj in objects_to_remove:
1474
+ bpy.data.objects.remove(obj, do_unlink=True)
1475
+
1476
+ # Clean up orphaned data blocks
1477
+ for mesh in bpy.data.meshes:
1478
+ if mesh.users == 0:
1479
+ bpy.data.meshes.remove(mesh)
1480
+
1481
+ for material in bpy.data.materials:
1482
+ if material.users == 0:
1483
+ bpy.data.materials.remove(material)
1484
+
1485
+ for light_data in bpy.data.lights:
1486
+ if light_data.users == 0:
1487
+ bpy.data.lights.remove(light_data)
1488
+
1489
+
1490
+
1491
+ # Update the main execution
1492
+ if __name__ == '__main__':
1493
+ subjects_data = [
1494
+ {
1495
+ "name": "sedan",
1496
+ "x": [-5.0],
1497
+ "y": [0.0],
1498
+ "dims": [1.0, 2.0, 1.5],
1499
+ "azimuth": [0.0]
1500
+ },
1501
+ ]
1502
+ camera_data = {
1503
+ "camera_elevation": np.arctan(0.45),
1504
+ "lens": 70,
1505
+ "global_scale": 1.0
1506
+ }
1507
+
1508
+ # Create renderer instance
1509
+ renderer = BlenderCuboidRenderer(
1510
+ img_dim=1024,
1511
+ render_engine='EEVEE',
1512
+ num_lights=1,
1513
+ )
1514
+
1515
+ # Render the scene
1516
+ renderer.render(
1517
+ subjects_data=subjects_data,
1518
+ camera_data=camera_data,
1519
+ num_samples=1,
1520
+ output_path="main.jpg"
1521
+ )
blender_server.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import shutil
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+ from typing import List, Dict, Any
11
+ import uvicorn
12
+ import argparse
13
+
14
+ # Import BlenderCuboidRenderer
15
+ from blender_backend import BlenderCuboidRenderer
16
+
17
+ class RenderRequest(BaseModel):
18
+ subjects_data: List[Dict[str, Any]]
19
+ camera_data: Dict[str, Any]
20
+ num_samples: int = 1
21
+
22
+ class RenderResponse(BaseModel):
23
+ success: bool
24
+ image_base64: str = None
25
+ error_message: str = None
26
+
27
+ class BlenderRenderServer:
28
+ def __init__(self, render_mode: str):
29
+ """
30
+ Initialize the Blender render server.
31
+
32
+ Args:
33
+ render_mode (str): One of 'cv' (EEVEE camera view), 'final' (Cycles render), or 'paper' (Cycles paper figure)
34
+ """
35
+ self.render_mode = render_mode
36
+ if self.render_mode == "cv":
37
+ self.renderer = BlenderCuboidRenderer("BLENDER_EEVEE_NEXT")
38
+ elif self.render_mode == "final":
39
+ self.renderer = BlenderCuboidRenderer("CYCLES")
40
+ elif self.render_mode == "paper":
41
+ self.renderer = BlenderCuboidRenderer("CYCLES")
42
+
43
+ def process_render_request(self, request: RenderRequest) -> RenderResponse:
44
+ """Process a render request and return the result."""
45
+ # Output file for this render mode (written to the working directory)
46
+ output_path = os.path.join(f"{self.render_mode}_render.jpg")
47
+
48
+ # Convert subjects_data format if needed
49
+ converted_subjects_data = self._convert_subjects_data(request.subjects_data)
50
+
51
+ # Add required camera_data fields
52
+ camera_data = request.camera_data.copy()
53
+ camera_data["global_scale"] = camera_data.get("global_scale", 1.0)
54
+
55
+ # Perform the render based on mode
56
+ if self.render_mode == "cv":
57
+ self.renderer.render_cv(
58
+ subjects_data=converted_subjects_data,
59
+ camera_data=camera_data,
60
+ num_samples=request.num_samples,
61
+ output_path=output_path
62
+ )
63
+ elif self.render_mode == "final":
64
+ self.renderer.render_final_representation(
65
+ subjects_data=converted_subjects_data,
66
+ camera_data=camera_data,
67
+ num_samples=request.num_samples,
68
+ output_path=output_path
69
+ )
70
+ elif self.render_mode == "paper":
71
+ self.renderer.render_paper_figure(
72
+ subjects_data=converted_subjects_data,
73
+ camera_data=camera_data,
74
+ num_samples=request.num_samples,
75
+ output_path=output_path
76
+ )
77
+ else:
78
+ raise ValueError(f"Invalid render mode: {self.render_mode}")
79
+
80
+ # Read and encode the rendered image
81
+ if os.path.exists(output_path):
82
+ with open(output_path, "rb") as img_file:
83
+ img_data = img_file.read()
84
+ img_base64 = base64.b64encode(img_data).decode('utf-8')
85
+
86
+ return RenderResponse(success=True, image_base64=img_base64)
87
+ else:
88
+ return RenderResponse(
89
+ success=False,
90
+ error_message="Render output file not found"
91
+ )
92
+
93
+ def _convert_subjects_data(self, subjects_data: List[Dict]) -> List[Dict]:
94
+ """Convert subjects data to the format expected by BlenderCuboidRenderer."""
95
+ converted = []
96
+
97
+ for subject in subjects_data:
98
+ # Convert to the expected format with lists for x, y, azimuth
99
+ converted_subject = {
100
+ "name": subject.get("subject_name", "cuboid"),
101
+ "x": [subject["x"]],
102
+ "y": [subject["y"]],
103
+ "z": [subject["z"]],
104
+ "dims": [subject["width"], subject["depth"], subject["height"]],
105
+ "azimuth": [subject["azimuth"]]
106
+ }
107
+ converted.append(converted_subject)
108
+
109
+ return converted
110
+
111
+ # Create FastAPI app
112
+ app = FastAPI(title="Blender Render Server")
113
+
114
+ # Global server instance
115
+ server = None
116
+
117
+ @app.on_event("startup")
118
+ def startup_event():
119
+ global server
120
+ render_mode = os.environ.get("RENDER_MODE")
121
+ server = BlenderRenderServer(render_mode)
122
+ print(f"Blender Render Server started in {render_mode.upper()} mode")
123
+
124
+ @app.post("/render", response_model=RenderResponse)
125
+ def render_scene(request: RenderRequest):
126
+ """Render a scene and return the result."""
127
+ if server is None:
128
+ raise HTTPException(status_code=500, detail="Server not initialized")
129
+
130
+ return server.process_render_request(request)
131
+
132
+ @app.get("/health")
133
+ def health_check():
134
+ """Health check endpoint."""
135
+ return {"status": "healthy", "render_mode": server.render_mode if server else "unknown"}
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description="Blender Render Server")
139
+ parser.add_argument("--mode", choices=["cv", "final", "paper"], required=True,
140
+ help="Render mode: cv for camera view, bev for bird's eye view")
141
+ parser.add_argument("--port", type=int, default=5001, help="Port to run server on")
142
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind server to")
143
+
144
+ args = parser.parse_args()
145
+
146
+ # Set environment variable for the startup event
147
+ os.environ["RENDER_MODE"] = args.mode
148
+
149
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
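For reference, a minimal client sketch for the /render endpoint above (assumptions: the server is running locally on port 5001 in cv mode; the subject and camera values are illustrative, chosen to match the fields that _convert_subjects_data and _setup_camera_cv read; camera_elevation is in radians):

import base64
import requests

payload = {
    "subjects_data": [
        # illustrative subject; keys follow _convert_subjects_data
        {"subject_name": "sedan", "x": 0.0, "y": 0.0, "z": 0.0,
         "width": 1.0, "depth": 2.0, "height": 1.5, "azimuth": 0.0},
    ],
    "camera_data": {"camera_elevation": 0.42, "lens": 70, "global_scale": 1.0},
    "num_samples": 1,
}
resp = requests.post("http://127.0.0.1:5001/render", json=payload).json()
if resp["success"]:
    # the server returns the rendered JPEG as a base64 string
    with open("cv_render.jpg", "wb") as f:
        f.write(base64.b64decode(resp["image_base64"]))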
blender_server_segmasks.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import shutil
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+ from typing import List, Dict, Any
11
+ import uvicorn
12
+ import argparse
13
+
14
+ # Import BlenderSegmaskRenderer
15
+ from blender_backend import BlenderSegmaskRenderer
16
+
17
+ class SegmaskRenderRequest(BaseModel):
18
+ subjects_data: List[Dict[str, Any]]
19
+ camera_data: Dict[str, Any]
20
+ num_samples: int = 1
21
+
22
+ class SegmaskRenderResponse(BaseModel):
23
+ success: bool
24
+ segmasks_base64: List[str] = None
25
+ error_message: str = None
26
+
27
+ class BlenderSegmaskRenderServer:
28
+ def __init__(self):
29
+ """Initialize the Blender segmentation mask render server."""
30
+ self.renderer = BlenderSegmaskRenderer()
31
+
32
+ def process_render_request(self, request: SegmaskRenderRequest) -> SegmaskRenderResponse:
33
+ """Process a segmentation mask render request and return the result."""
34
+ try:
35
+ # Segmentation masks are written to the working directory by the renderer
36
+ # Convert subjects_data format if needed
37
+ converted_subjects_data = self._convert_subjects_data(request.subjects_data)
38
+
39
+ # Add required camera_data fields
40
+ camera_data = request.camera_data.copy()
41
+ camera_data["global_scale"] = camera_data.get("global_scale", 1.0)
42
+
43
+ # Perform the render
44
+ self.renderer.render_cv(
45
+ subjects_data=converted_subjects_data,
46
+ camera_data=camera_data,
47
+ num_samples=request.num_samples
48
+ )
49
+
50
+ # Read and encode all segmentation masks in order
51
+ segmasks_base64 = []
52
+ num_subjects = len(converted_subjects_data)
53
+
54
+ for subject_idx in range(num_subjects):
55
+ segmask_path = os.path.join(f"{str(subject_idx).zfill(3)}_segmask_cv.png")
56
+
57
+ if os.path.exists(segmask_path):
58
+ with open(segmask_path, "rb") as img_file:
59
+ img_data = img_file.read()
60
+ img_base64 = base64.b64encode(img_data).decode('utf-8')
61
+ segmasks_base64.append(img_base64)
62
+ else:
63
+ # Return error if any segmask is missing
64
+ return SegmaskRenderResponse(
65
+ success=False,
66
+ error_message=f"Segmentation mask for subject {subject_idx} not found"
67
+ )
68
+
69
+
70
+ return SegmaskRenderResponse(
71
+ success=True,
72
+ segmasks_base64=segmasks_base64
73
+ )
74
+
75
+ except Exception as e:
76
+ # Report the failure to the client
77
+
78
+ return SegmaskRenderResponse(
79
+ success=False,
80
+ error_message=f"Segmentation mask render failed: {str(e)}"
81
+ )
82
+
83
+ def _convert_subjects_data(self, subjects_data: List[Dict]) -> List[Dict]:
84
+ """Convert subjects data to the format expected by BlenderSegmaskRenderer."""
85
+ converted = []
86
+
87
+ for subject in subjects_data:
88
+ # Convert to the expected format with lists for x, y, azimuth
89
+ converted_subject = {
90
+ "name": subject.get("subject_name", "cuboid"),
91
+ "x": [subject["x"]],
92
+ "y": [subject["y"]],
93
+ "z": [subject["z"]],
94
+ "dims": [subject["width"], subject["depth"], subject["height"]],
95
+ "azimuth": [subject["azimuth"]]
96
+ }
97
+ converted.append(converted_subject)
98
+
99
+ return converted
100
+
101
+ # Create FastAPI app
102
+ app = FastAPI(title="Blender Segmentation Mask Render Server")
103
+
104
+ # Global server instance
105
+ server = None
106
+
107
+ @app.on_event("startup")
108
+ def startup_event():
109
+ global server
110
+ server = BlenderSegmaskRenderServer()
111
+ print("Blender Segmentation Mask Render Server started")
112
+
113
+ @app.post("/render_segmasks", response_model=SegmaskRenderResponse)
114
+ def render_segmasks(request: SegmaskRenderRequest):
115
+ """Render segmentation masks and return the results."""
116
+ if server is None:
117
+ raise HTTPException(status_code=500, detail="Server not initialized")
118
+
119
+ return server.process_render_request(request)
120
+
121
+ @app.get("/health")
122
+ def health_check():
123
+ """Health check endpoint."""
124
+ return {"status": "healthy", "type": "segmentation_mask_renderer"}
125
+
126
+ if __name__ == "__main__":
127
+ parser = argparse.ArgumentParser(description="Blender Segmentation Mask Render Server")
128
+ parser.add_argument("--port", type=int, default=5003, help="Port to run server on")
129
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind server to")
130
+
131
+ args = parser.parse_args()
132
+
133
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
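A similar client sketch for the /render_segmasks endpoint above (assuming the segmask server runs locally on port 5003; the subject and camera values are illustrative, and one base64-encoded PNG mask is returned per subject, in request order):

import base64
import requests

payload = {
    "subjects_data": [
        {"subject_name": "sedan", "x": 0.0, "y": 0.0, "z": 0.0,
         "width": 1.0, "depth": 2.0, "height": 1.5, "azimuth": 0.0},
        {"subject_name": "dog", "x": 2.0, "y": 1.0, "z": 0.0,
         "width": 0.5, "depth": 1.0, "height": 0.6, "azimuth": 90.0},
    ],
    "camera_data": {"camera_elevation": 0.42, "lens": 70},
}
resp = requests.post("http://127.0.0.1:5003/render_segmasks", json=payload).json()
if resp["success"]:
    # one mask per subject, in the same order as subjects_data
    for idx, mask_b64 in enumerate(resp["segmasks_base64"]):
        with open(f"{idx:03d}_segmask.png", "wb") as f:
            f.write(base64.b64decode(mask_b64))
else:
    print(resp["error_message"])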
cv_render.jpg ADDED

Git LFS Details

  • SHA256: d7c95530316a5c6534d67ccea89737248b59b11c34a398e3225ef789a1d7e7f0
  • Pointer size: 130 Bytes
  • Size of remote file: 27.2 kB
failed_images.txt ADDED
@@ -0,0 +1,101 @@
1
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
2
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
3
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
4
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
5
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
6
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
7
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
8
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
9
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
10
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
11
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
12
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
13
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
14
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
15
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
16
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
17
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
18
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
19
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
20
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
21
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
22
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
23
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
24
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
25
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
26
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
27
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
28
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
29
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
30
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
31
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
32
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
33
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
34
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
35
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
36
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
37
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
38
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
39
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
40
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
41
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
42
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
43
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
44
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
45
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
46
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
47
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
48
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
49
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
50
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
51
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
52
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
53
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
54
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
55
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
56
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
57
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
58
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
59
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
60
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
61
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
62
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
63
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
64
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
65
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
66
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
67
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
68
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
69
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
70
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
71
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
72
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
73
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
74
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
75
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
76
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
77
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
78
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
79
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
80
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
81
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
82
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
83
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
84
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
85
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
86
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
87
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
88
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
89
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
90
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/testing/
91
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/bicycle__van__scooter__sedan/001/segmap.png
92
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/dog__fox/000/segmap.png
93
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/elephant__ferrari__bear__sedan/009/segmap.png
94
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/goat__man__dog/006/segmap.png
95
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/dog__man/008/segmap.png
96
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/hen__crow/001/segmap.png
97
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/bicycle__giraffe/006/segmap.png
98
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/table__lion/008/segmap.png
99
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/deer__tiger/009/segmap.png
100
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/pigeon__teddy__hen/008/segmap.png
101
+ /archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/semantic_segmaps_dilated_eroded/bulldozer__mclaren__tiger/005/segmap.png
final_render.jpg ADDED

Git LFS Details

  • SHA256: a01724db823b183c9fffdd08eea7c503d51c20b5bbd34adee5dc4e2882517536
  • Pointer size: 130 Bytes
  • Size of remote file: 29.7 kB
infer_backend.py ADDED
@@ -0,0 +1,340 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from typing import Optional, List, Tuple
7
+ from transformers import CLIPTokenizer, T5TokenizerFast
8
+
9
+ from train.src.pipeline import FluxPipeline
10
+ from train.src.transformer_flux import FluxTransformer2DModel
11
+ from train.src.lora_helper import set_single_lora, set_multi_lora, unset_lora
12
+ from train.src.jsonl_datasets import make_train_dataset, collate_fn
13
+
14
+
15
+ class InferenceArgs:
16
+ """Arguments configuration for inference dataset loading"""
17
+ def __init__(self, jsonl_path: str, pretrained_model_name: str):
18
+ # Basic paths
19
+ self.current_train_data_dir = jsonl_path
20
+ self.inference_embeds_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_flux2"
21
+ self.pretrained_model_name_or_path = pretrained_model_name
22
+
23
+ # Column configurations
24
+ self.subject_column = None # Set to None since we're using spatial
25
+ self.spatial_column = "cv"
26
+ self.target_column = "target"
27
+ self.caption_column = "PLACEHOLDER_prompts"
28
+
29
+ # Size configurations
30
+ self.cond_size = 512
31
+ self.noise_size = 512
32
+
33
+ # Other required parameters
34
+ self.revision = None
35
+ self.variant = None
36
+ self.max_sequence_length = 512
37
+
38
+
39
+ class InferenceEngine:
40
+ """
41
+ Handles model loading and inference for the Gradio interface.
42
+ Pre-loads the base model and dynamically loads LoRA weights based on checkpoint selection.
43
+ """
44
+
45
+ def __init__(self, base_model_path: str = "black-forest-labs/FLUX.1-dev", device: str = "cuda"):
46
+ """
47
+ Initialize the inference engine with base model.
48
+
49
+ Args:
50
+ base_model_path: Path to the base FLUX model
51
+ device: Device to run inference on (default: "cuda")
52
+ """
53
+ self.device = device
54
+ self.base_model_path = base_model_path
55
+ self.current_lora_path = None
56
+
57
+ print(f"Loading base model from {base_model_path}...")
58
+
59
+ # Load pipeline and transformer
60
+ self.pipe = FluxPipeline.from_pretrained(
61
+ base_model_path,
62
+ torch_dtype=torch.bfloat16,
63
+ device=device
64
+ )
65
+
66
+ transformer = FluxTransformer2DModel.from_pretrained(
67
+ base_model_path,
68
+ subfolder="transformer",
69
+ torch_dtype=torch.bfloat16,
70
+ device=device
71
+ )
72
+
73
+ self.pipe.transformer = transformer
74
+ self.pipe.to(device)
75
+
76
+ # Load tokenizers (same as in train.py and infer.ipynb)
77
+ print("Loading tokenizers...")
78
+ self.tokenizer_one = CLIPTokenizer.from_pretrained(
79
+ base_model_path,
80
+ subfolder="tokenizer",
81
+ revision=None,
82
+ )
83
+ self.tokenizer_two = T5TokenizerFast.from_pretrained(
84
+ base_model_path,
85
+ subfolder="tokenizer_2",
86
+ revision=None,
87
+ )
88
+ self.tokenizers = [self.tokenizer_one, self.tokenizer_two]
89
+
90
+ print("Base model and tokenizers loaded successfully!")
91
+
92
+ def load_lora(self, checkpoint_name: str, lora_weights: List[float] = [1.0]):
93
+ """
94
+ Load LoRA weights for a specific checkpoint.
95
+
96
+ Args:
97
+ checkpoint_name: Name of the checkpoint (e.g., "checkpoint_1")
98
+ lora_weights: Weights for the LoRA adaptation
99
+ """
100
+ # Construct LoRA path
101
+ lora_path = f"/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/{checkpoint_name}/lora.safetensors"
102
+
103
+ print(f"\n\nGOT THE FOLLOWING LORA PATH: {lora_path}\n\n")
104
+
105
+ # Check if path exists
106
+ if not os.path.exists(lora_path):
107
+ raise FileNotFoundError(f"LoRA checkpoint not found at: {lora_path}")
108
+
109
+ # Only reload if it's a different checkpoint
110
+ if self.current_lora_path != lora_path:
111
+ print(f"Loading LoRA weights from {lora_path}...")
112
+ set_single_lora(
113
+ self.pipe.transformer,
114
+ lora_path,
115
+ lora_weights=lora_weights,
116
+ cond_size=512
117
+ )
118
+ self.current_lora_path = lora_path
119
+ print(f"LoRA weights loaded successfully!")
120
+ else:
121
+ print(f"LoRA already loaded for {checkpoint_name}")
122
+
123
+ def clear_cache(self):
124
+ """Clear attention processor cache"""
125
+ for name, attn_processor in self.pipe.transformer.attn_processors.items():
126
+ if hasattr(attn_processor, 'bank_kv'):
127
+ attn_processor.bank_kv.clear()
128
+
129
+ def tensor_to_image_list(self, tensor):
130
+ """Convert normalized tensor to PIL Image list"""
131
+ if tensor is None:
132
+ return []
133
+
134
+ images = []
135
+ for img_tensor in tensor:
136
+ # Denormalize from [-1, 1] to [0, 1]
137
+ img = (img_tensor.cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy()
138
+ # Convert to [0, 255] uint8
139
+ img = (img * 255.0).astype(np.uint8)
140
+ images.append(Image.fromarray(img))
141
+
142
+ return images
143
+
144
+ def run_inference(
145
+ self,
146
+ jsonl_path: str,
147
+ checkpoint_name: str,
148
+ height: int = 512,
149
+ width: int = 512,
150
+ seed: int = 42,
151
+ guidance_scale: float = 3.5,
152
+ num_inference_steps: int = 25,
153
+ max_sequence_length: int = 512
154
+ ) -> Tuple[bool, Optional[Image.Image], str]:
155
+ """
156
+ Run inference using data from JSONL file.
157
+ Uses the same data loading pipeline as training (make_train_dataset).
158
+
159
+ Args:
160
+ jsonl_path: Path to the JSONL file containing inference data
161
+ checkpoint_name: Name of checkpoint to use
162
+ height: Output image height
163
+ width: Output image width
164
+ seed: Random seed for generation
165
+ guidance_scale: Guidance scale for diffusion
166
+ num_inference_steps: Number of denoising steps
167
+ max_sequence_length: Maximum sequence length for text encoding
168
+
169
+ Returns:
170
+ Tuple of (success: bool, image: PIL.Image or None, message: str)
171
+ """
172
+ try:
173
+ # Load LoRA for selected checkpoint
174
+ self.load_lora(checkpoint_name)
175
+
176
+ # Check if JSONL file exists
177
+ if not os.path.exists(jsonl_path):
178
+ return False, None, f"JSONL file not found at: {jsonl_path}"
179
+
180
+ # Create inference arguments
181
+ inference_args = InferenceArgs(
182
+ jsonl_path=jsonl_path,
183
+ pretrained_model_name=self.base_model_path
184
+ )
185
+
186
+ # Create dataset using the same pipeline as training
187
+ print("Creating inference dataset...")
188
+ inference_dataset = make_train_dataset(inference_args, self.tokenizers, accelerator=None)
189
+
190
+ # Create dataloader with batch_size=1
191
+ inference_dataloader = torch.utils.data.DataLoader(
192
+ inference_dataset,
193
+ batch_size=1,
194
+ shuffle=False,
195
+ collate_fn=collate_fn,
196
+ num_workers=0,
197
+ )
198
+
199
+ # Get the first (and only) batch
200
+ batch = next(iter(inference_dataloader))
201
+
202
+ # Extract data from batch
203
+ caption = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
204
+ call_ids = batch["call_ids"]
205
+
206
+ print(f"\n{'='*60}")
207
+ print(f"Running inference with:")
208
+ print(f" Checkpoint: {checkpoint_name}")
209
+ print(f" Prompt: {caption}")
210
+ print(f" Call IDs: {call_ids}")
211
+ print(f" Height: {height}, Width: {width}")
212
+ print(f" Seed: {seed}, Steps: {num_inference_steps}")
213
+ print(f" Guidance Scale: {guidance_scale}")
214
+ print(f"{'='*60}\n")
215
+
216
+ # Convert spatial condition tensors to PIL Images
217
+ spatial_imgs = self.tensor_to_image_list(batch["cond_pixel_values"])
218
+
219
+ # Prepare cuboids segmentation masks
220
+ cuboids_segmasks = batch.get("cuboids_segmasks", None)
221
+
222
+ # Prepare joint attention kwargs
223
+ joint_attention_kwargs = {
224
+ "call_ids": call_ids,
225
+ "cuboids_segmasks": cuboids_segmasks,
226
+ }
227
+
228
+ print(f"Spatial images: {len(spatial_imgs)}")
229
+ print(f"{len(cuboids_segmasks) = }, {cuboids_segmasks[0].shape = }")
230
+ # print(f"Cuboids segmasks shape: {cuboids_segmasks.shape if cuboids_segmasks is not None else 'None'}")
231
+ cuboids_segmasks = torch.stack(cuboids_segmasks, dim=0) if cuboids_segmasks is not None else None
232
+
233
+ # Run inference
234
+ image = self.pipe(
235
+ prompt=caption,
236
+ height=int(height),
237
+ width=int(width),
238
+ guidance_scale=guidance_scale,
239
+ num_inference_steps=num_inference_steps,
240
+ max_sequence_length=max_sequence_length,
241
+ generator=torch.Generator("cpu").manual_seed(seed),
242
+ subject_images=[], # No subject images for spatial conditioning
243
+ spatial_images=spatial_imgs,
244
+ cond_size=512,
245
+ **joint_attention_kwargs
246
+ ).images[0]
247
+
248
+ # Clear cache
249
+ self.clear_cache()
250
+ torch.cuda.empty_cache()
251
+
252
+ success_msg = f"✅ Successfully generated image using {checkpoint_name}"
253
+ print(f"\n{success_msg}\n")
254
+
255
+ return True, image, success_msg
256
+
257
+ except Exception as e:
258
+ error_msg = f"❌ Inference failed: {str(e)}"
259
+ print(f"\n{error_msg}\n")
260
+ import traceback
261
+ traceback.print_exc()
262
+ return False, None, error_msg
263
+
264
+
265
+ # Global inference engine instance
266
+ _inference_engine: Optional[InferenceEngine] = None
267
+
268
+
269
+ def initialize_inference_engine(base_model_path: str = "black-forest-labs/FLUX.1-dev"):
270
+ """
271
+ Initialize the global inference engine.
272
+ Should be called once when the Gradio demo starts.
273
+ """
274
+ global _inference_engine
275
+
276
+ if _inference_engine is None:
277
+ print("\n" + "="*60)
278
+ print("INITIALIZING INFERENCE ENGINE")
279
+ print("="*60 + "\n")
280
+
281
+ _inference_engine = InferenceEngine(base_model_path=base_model_path)
282
+
283
+ print("\n" + "="*60)
284
+ print("INFERENCE ENGINE READY")
285
+ print("="*60 + "\n")
286
+
287
+ return _inference_engine
288
+
289
+
290
+ def get_inference_engine() -> InferenceEngine:
291
+ """
292
+ Get the global inference engine instance.
293
+ Raises an error if not initialized.
294
+ """
295
+ global _inference_engine
296
+
297
+ if _inference_engine is None:
298
+ raise RuntimeError(
299
+ "Inference engine not initialized. "
300
+ "Call initialize_inference_engine() first."
301
+ )
302
+
303
+ return _inference_engine
304
+
305
+
306
+ def run_inference_from_gradio(
307
+ checkpoint_name: str,
308
+ height: int = 512,
309
+ width: int = 512,
310
+ seed: int = 42,
311
+ guidance_scale: float = 3.5,
312
+ num_inference_steps: int = 25,
313
+ jsonl_path: str = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/cuboids.jsonl"
314
+ ) -> Tuple[bool, Optional[Image.Image], str]:
315
+ """
316
+ Wrapper function to run inference from Gradio interface.
317
+
318
+ Args:
319
+ checkpoint_name: Name of checkpoint to use (from dropdown)
320
+ height: Output image height
321
+ width: Output image width
322
+ seed: Random seed
323
+ guidance_scale: Guidance scale
324
+ num_inference_steps: Number of denoising steps
325
+ jsonl_path: Path to JSONL file with inference data
326
+
327
+ Returns:
328
+ Tuple of (success, generated_image, status_message)
329
+ """
330
+ engine = get_inference_engine()
331
+
332
+ return engine.run_inference(
333
+ jsonl_path=jsonl_path,
334
+ checkpoint_name=checkpoint_name,
335
+ height=height,
336
+ width=width,
337
+ seed=seed,
338
+ guidance_scale=guidance_scale,
339
+ num_inference_steps=num_inference_steps
340
+ )
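For reference, a minimal usage sketch for the inference engine above (assumptions: a CUDA GPU with the FLUX.1-dev weights available, a populated cuboids.jsonl at the default path, and "checkpoint_1" standing in for an actual LoRA checkpoint directory):

from infer_backend import initialize_inference_engine, run_inference_from_gradio

# Loads the base FLUX pipeline and tokenizers once; later calls reuse the engine.
initialize_inference_engine(base_model_path="black-forest-labs/FLUX.1-dev")

success, image, message = run_inference_from_gradio(
    checkpoint_name="checkpoint_1",  # hypothetical checkpoint name
    height=512,
    width=512,
    seed=42,
    guidance_scale=3.5,
    num_inference_steps=25,
)
print(message)
if success:
    image.save("generated.png")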
launch_blender_backend.sh ADDED
@@ -0,0 +1,51 @@
1
+ #!/bin/bash
2
+
3
+ PORTS=(5001 5002 5003 5004)
4
+
5
+ for port in "${PORTS[@]}"; do
6
+ PID=$(lsof -t -i ":$port")
7
+ if [ -n "$PID" ]; then
8
+ echo "Killing process $PID running on port $port..."
9
+ kill -9 "$PID"
10
+ echo "Process $PID killed."
11
+ else
12
+ echo "No process found running on port $port."
13
+ fi
14
+ done
15
+
16
+ # Start CV render server
17
+ echo "Starting Camera View render server on port 5001..."
18
+ python blender_server.py --mode cv --port 5001 &
19
+ CV_PID=$!
20
+
21
+ echo "Starting Camera View render server on port 5002..."
22
+ python blender_server.py --mode final --port 5002 &
23
+ FINAL_PID=$!
24
+
25
+
26
+ # Start segmask render server
27
+ echo "Starting Segmentation Mask render server on port 5003..."
28
+ python3 blender_server_segmasks.py --port 5003 &
29
+ SEGMASK_PID=$!
30
+
31
+ echo "Starting Camera View render server on port 5004..."
32
+ python blender_server.py --mode paper --port 5004 &
33
+ PAPER_PID=$!
34
+
35
+ echo "Render servers started!"
36
+ echo "CV Server PID: $CV_PID (port 5001)"
37
+ echo "Final (Cycles) Render Server PID: $FINAL_PID (port 5002)"
38
+ echo "Segmentation Mask Server PID: $SEGMASK_PID (port 5003)"
39
+
40
+ # Function to cleanup on exit
41
+ cleanup() {
42
+ echo "Stopping render servers..."
43
+ kill $CV_PID $FINAL_PID $SEGMASK_PID $PAPER_PID 2>/dev/null
44
+ exit 0
45
+ }
46
+
47
+ # Set trap to cleanup on script exit
48
+ trap cleanup SIGINT SIGTERM
49
+
50
+ # Wait for both processes
51
+ wait $CV_PID $FINAL_PID $SEGMASK_PID $PAPER_PID
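Once the script is running, a quick health sweep of the four servers could look like the sketch below (assuming the ports assigned above and the /health endpoints defined in the server files):

import requests

# Ports as assigned in launch_blender_backend.sh
servers = {"cv": 5001, "final": 5002, "segmask": 5003, "paper": 5004}
for name, port in servers.items():
    try:
        status = requests.get(f"http://127.0.0.1:{port}/health", timeout=5).json()
        print(f"{name} (port {port}): {status}")
    except requests.RequestException as exc:
        print(f"{name} (port {port}): unreachable ({exc})")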
model_condition.jpg ADDED

Git LFS Details

  • SHA256: 3725ac8df7ae335e2eb4c7fdd5164dd727548650955da809f84680ddcd0e251c
  • Pointer size: 130 Bytes
  • Size of remote file: 24.1 kB
object_scales.py ADDED
@@ -0,0 +1,212 @@
1
+ scales = {
2
+ "bear": 0.53, # Unchanged
3
+ "bicycle": 0.4, # Unchanged
4
+ "bugatti": 1.0, # Unchanged
5
+ "bulldozer": 1.78, # Unchanged
6
+ "bus": 2.67, # Unchanged
7
+ "cat": 0.11, # Unchanged
8
+ "chair": 0.18, # Unchanged
9
+ "coupe": 1.0, # Unchanged
10
+ "cow": 0.56, # Unchanged
11
+ "crow": 0.09, # CHANGED: Reduced from 0.11
12
+ "deer": 0.44, # Unchanged
13
+ "dog": 0.22, # Unchanged
14
+ "elephant": 1.22, # Unchanged
15
+ "ferrari": 1.05, # CHANGED: Increased from 1.0
16
+ "flamingo": 0.10, # Unchanged
17
+ "fox": 0.22, # Unchanged
18
+ "giraffe": 0.90, # CHANGED: Reduced from 1.00
19
+ "goat": 0.33, # Unchanged
20
+ "helicopter": 2.26, # Unchanged
21
+ "hen": 0.09, # Unchanged
22
+ "horse": 0.53, # Unchanged
23
+ "jeep": 0.96, # Unchanged
24
+ "kangaroo": 0.38, # CHANGED: Increased from 0.33
25
+ "lamborghini": 1.0, # Unchanged
26
+ "lion": 0.56, # Unchanged
27
+ "mclaren": 1.0, # Unchanged
28
+ "motorbike": 0.44, # Unchanged
29
+ "office chair": 0.20,# Unchanged
30
+ "pickup truck": 1.22,# Unchanged
31
+ "pigeon": 0.067, # Unchanged
32
+ "pig": 0.33, # Unchanged
33
+ "rabbit": 0.11, # Unchanged
34
+ "scooter": 0.4, # Unchanged
35
+ "sedan": 1.0, # Unchanged (Reference)
36
+ "sheep": 0.29, # Unchanged
37
+ "shoe": 0.04, # Unchanged
38
+ "sparrow": 0.033, # Unchanged
39
+ "suv": 1.07, # Unchanged
40
+ "table": 0.4, # Unchanged
41
+ "teddy": 0.05, # CHANGED: Reduced from 0.11
42
+ "tiger": 0.67, # Unchanged
43
+ "tractor": 0.80, # Unchanged
44
+ "van": 1.11, # Unchanged
45
+ "vw beetle": 1.0, # Unchanged
46
+ "wolf": 0.33, # Unchanged
47
+ "man": 0.38, # Unchanged
48
+ "zebra": 0.56 # Unchanged
49
+ }
50
+
51
+ tiny_assets = [
52
+ "sparrow", # 0.033
53
+ "shoe", # 0.04
54
+ "teddy", # 0.05 (CHANGED)
55
+ "pigeon", # 0.067
56
+ "hen", # 0.09
57
+ "crow", # 0.09 (CHANGED)
58
+ "flamingo", # 0.10 (CHANGED - Moved from small)
59
+ "rabbit", # 0.11
60
+ "cat", # 0.11
61
+ ]
62
+
63
+
64
+ small_assets = [
65
+ "chair", # 0.18
66
+ "office chair", # 0.20
67
+ "dog", # 0.22
68
+ "fox", # 0.22
69
+ "sheep", # 0.29
70
+ "goat", # 0.33
71
+ "pig", # 0.33
72
+ "wolf", # 0.33
73
+ "man", # 0.38 (CHANGED - Added to group)
74
+ "kangaroo", # 0.38 (CHANGED)
75
+ ]
76
+
77
+
78
+ medium_assets = [
79
+ "table", # 0.4 (CHANGED - Moved from small)
80
+ "bicycle", # 0.4
81
+ "scooter", # 0.4
82
+ "deer", # 0.44
83
+ "motorbike", # 0.44
84
+ "bear", # 0.53
85
+ "horse", # 0.53
86
+ "cow", # 0.56
87
+ "lion", # 0.56
88
+ "zebra", # 0.56
89
+ "tiger", # 0.67
90
+ "tractor", # 0.80
91
+ "giraffe", # 0.90 (CHANGED)
92
+ "jeep", # 0.96
93
+ "bugatti", # 1.0
94
+ "coupe", # 1.0
95
+ "lamborghini", # 1.0
96
+ "mclaren", # 1.0
97
+ "sedan", # 1.0
98
+ "vw beetle", # 1.0
99
+ "ferrari", # 1.05 (CHANGED)
100
+ "suv", # 1.07
101
+ "van", # 1.11
102
+ "elephant", # 1.22
103
+ "pickup truck", # 1.22
104
+ "bulldozer", # 1.78
105
+ "helicopter", # 2.26
106
+ "bus", # 2.67
107
+ ]
108
+
109
+ tiny_prompts = [
110
+ "a photo of PLACEHOLDER in a cozy birdhouse nestled in a green tree",
111
+ "a photo of PLACEHOLDER on a sandy beach near the water's edge with small shells",
112
+ "a photo of PLACEHOLDER amongst colorful wildflowers in a sunny meadow",
113
+ "a photo of PLACEHOLDER on a moss-covered log in a quiet forest",
114
+ "a photo of PLACEHOLDER near a small pond with lily pads floating",
115
+ "a photo of PLACEHOLDER on a window sill overlooking a rainy city street",
116
+ "a photo of PLACEHOLDER in a child's bedroom surrounded by other toys",
117
+ "a photo of PLACEHOLDER on a park bench with fallen leaves around",
118
+ "a photo of PLACEHOLDER by a small stream with smooth pebbles",
119
+ "a photo of PLACEHOLDER in a field of tall grass swaying gently",
120
+ "a photo of PLACEHOLDER on a wooden fence post in the countryside",
121
+ "a photo of PLACEHOLDER amongst blossoming spring flowers in a garden",
122
+ "a photo of PLACEHOLDER on a stack of old books in a library",
123
+ "a photo of PLACEHOLDER near a bird feeder in a winter garden",
124
+ "a photo of PLACEHOLDER on a picnic blanket in a sunny park",
125
+ "a photo of PLACEHOLDER on a kitchen counter near ripe fruit",
126
+ "a photo of PLACEHOLDER amongst autumn leaves on a forest floor",
127
+ "a photo of PLACEHOLDER on a rocky outcrop with a distant view",
128
+ "a photo of PLACEHOLDER near a puddle reflecting the sky",
129
+ "a photo of PLACEHOLDER in a patch of soft green moss",
130
+ "a photo of PLACEHOLDER on a weathered stone wall",
131
+ "a photo of PLACEHOLDER near a patch of blooming daisies",
132
+ "a photo of PLACEHOLDER on a sandy path through a garden",
133
+ "a photo of PLACEHOLDER near a watering can in a greenhouse",
134
+ "a photo of PLACEHOLDER amongst fallen pine needles in a forest",
135
+ "a photo of PLACEHOLDER on a small bridge over a gentle stream",
136
+ "a photo of PLACEHOLDER near a patch of colorful mushrooms"
137
+ ]
138
+
139
+ small_prompts = [
140
+ "a photo of PLACEHOLDER in a sun-drenched greenhouse surrounded by various plants",
141
+ "a photo of PLACEHOLDER in a bustling city park with people walking by",
142
+ "a photo of PLACEHOLDER in a cozy library with tall bookshelves and soft lighting",
143
+ "a photo of PLACEHOLDER on a sandy dune near the ocean with gentle waves",
144
+ "a photo of PLACEHOLDER amongst tall reeds in a marshland area",
145
+ "a photo of PLACEHOLDER in a quiet forest clearing with sunlight filtering through trees",
146
+ "a photo of PLACEHOLDER on a grassy hill overlooking a small town",
147
+ "a photo of PLACEHOLDER near a flowing waterfall with mist in the air",
148
+ "a photo of PLACEHOLDER in a vibrant flower market with colorful blooms all around",
149
+ "a photo of PLACEHOLDER on a wooden dock extending into a still lake",
150
+ "a photo of PLACEHOLDER amongst rows of crops in a rural farmland",
151
+ "a photo of PLACEHOLDER in a historic town square with old buildings",
152
+ "a photo of PLACEHOLDER on a rocky beach with crashing waves in the distance",
153
+ "a photo of PLACEHOLDER amongst tall bamboo stalks in a serene grove",
154
+ "a photo of PLACEHOLDER in a snowy field with tracks visible in the snow",
155
+ "a photo of PLACEHOLDER on a paved walkway in a botanical garden",
156
+ "a photo of PLACEHOLDER near a campfire in a forest at night",
157
+ "a photo of PLACEHOLDER amongst colorful autumn foliage in a park",
158
+ "a photo of PLACEHOLDER on a stone path winding through a garden",
159
+ "a photo of PLACEHOLDER in a misty meadow with dew-covered grass",
160
+ "a photo of PLACEHOLDER on a wooden bridge crossing a small river",
161
+ "a photo of PLACEHOLDER amongst blooming lavender fields under a sunny sky",
162
+ "a photo of PLACEHOLDER in a quiet suburban backyard with green grass",
163
+ "a photo of PLACEHOLDER on a rocky hillside with sparse vegetation",
164
+ "a photo of PLACEHOLDER near a clear mountain stream with smooth stones",
165
+ "a photo of PLACEHOLDER amongst fallen leaves in a shaded woodland",
166
+ "a photo of PLACEHOLDER on a grassy bank beside a calm canal",
167
+ "a photo of PLACEHOLDER in a vineyard with rows of grapevines",
168
+ "a photo of PLACEHOLDER near a traditional wooden farmhouse"
169
+ ]
170
+
171
+ medium_prompts = [
172
+ "a photo of PLACEHOLDER in a vast open plain with a dramatic sunset on the horizon",
173
+ "a photo of PLACEHOLDER on a winding mountain road with scenic views of valleys",
174
+ "a photo of PLACEHOLDER in a bustling harbor with various boats and ships",
175
+ "a photo of PLACEHOLDER in a dense pine forest with tall trees reaching the sky",
176
+ "a photo of PLACEHOLDER on a sandy beach with palm trees swaying in the breeze",
177
+ "a photo of PLACEHOLDER amongst rolling hills in a green countryside landscape",
178
+ "a photo of PLACEHOLDER in a vibrant city square with historic architecture",
179
+ "a photo of PLACEHOLDER in a train yard with multiple railway tracks",
180
+ "a photo of PLACEHOLDER amongst tall redwood trees in an ancient forest",
181
+ "a photo of PLACEHOLDER in a sprawling parking lot outside a shopping mall",
182
+ "a photo of PLACEHOLDER on a coastal highway with ocean views and cliffs",
183
+ "a photo of PLACEHOLDER amongst golden wheat fields under a clear summer sky",
184
+ "a photo of PLACEHOLDER in a rocky canyon with sparse desert vegetation and blue sky above",
185
+ "a photo of PLACEHOLDER on a grassy plateau overlooking a vast landscape",
186
+ "a photo of PLACEHOLDER in a snowy mountain range with visible ski slopes",
187
+ "a photo of PLACEHOLDER on a paved highway stretching across an open landscape",
188
+ "a photo of PLACEHOLDER amongst lush vegetation in a tropical rainforest",
189
+ "a photo of PLACEHOLDER in a historic European city with ornate buildings",
190
+ "a photo of PLACEHOLDER in front of the Eiffel Tower at sunset",
191
+ "a photo of PLACEHOLDER amongst tall sunflowers in a field under a bright sun",
192
+ "a photo of PLACEHOLDER in a deep valley with steep forested sides",
193
+ "a photo of PLACEHOLDER on a rocky coastline with crashing waves and sea spray",
194
+ "a photo of PLACEHOLDER amongst vineyards on rolling hills under a sunny sky",
195
+ "a photo of PLACEHOLDER in a wide open desert with distant mesas and clear air",
196
+ "a photo of PLACEHOLDER amongst autumn-colored trees along a winding river",
197
+ "a photo of PLACEHOLDER in a bustling marketplace with various stalls and people",
198
+ "a photo of PLACEHOLDER on a racing circuit with banked turns and grandstands",
199
+ "a photo of PLACEHOLDER amongst tall grasses in a savanna landscape",
200
+ ]
201
+
202
+ groups = {
203
+ "tiny": tiny_assets,
204
+ "small": small_assets,
205
+ "medium": medium_assets,
206
+ }
207
+
208
+ groups_prompts = {
209
+ "tiny": tiny_prompts,
210
+ "small": small_prompts,
211
+ "medium": medium_prompts,
212
+ }
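A small sketch of how these tables might be combined to build a size-appropriate caption for an asset (the "a {asset}" substitution for PLACEHOLDER is an assumption about the intended prompt format):

import random
from object_scales import scales, groups, groups_prompts

def sample_prompt(asset: str) -> str:
    # Pick a prompt from the asset's size group and fill in its name.
    for group_name, assets in groups.items():
        if asset in assets:
            prompt = random.choice(groups_prompts[group_name])
            return prompt.replace("PLACEHOLDER", f"a {asset}")
    raise KeyError(f"{asset} is not assigned to a size group")

print(scales["dog"])         # footprint relative to the sedan reference (1.0)
print(sample_prompt("dog"))  # e.g. "a photo of a dog in a bustling city park ..."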
paper_render.jpg ADDED

Git LFS Details

  • SHA256: 139d2ea5d57442ffb2f8eaa44e421c63fc71e008a7f4c001362712074ba0d749
  • Pointer size: 130 Bytes
  • Size of remote file: 21.1 kB
requirements.txt ADDED
@@ -0,0 +1,20 @@
1
+ --extra-index-url https://download.pytorch.org/whl/cu124
2
+ torch
3
+ torchvision
4
+ torchaudio
5
+ diffusers
6
+ easydict==1.13
7
+ einops==0.8.1
8
+ peft==0.17.0
9
+ pillow==11.0.0
10
+ protobuf==5.29.3
11
+ requests==2.32.3
12
+ safetensors==0.5.2
13
+ sentencepiece==0.2.0
14
+ spaces==0.34.1
15
+ transformers==4.49.0
16
+ datasets
17
+ wandb
18
+ matplotlib
19
+ opencv-python
20
+ wandb
set_tmp.sh ADDED
File without changes
tmp.png ADDED

Git LFS Details

  • SHA256: a01c67695150923ba5023492dba8e38e1afca3bddc036e49c0adc9b235cee76b
  • Pointer size: 131 Bytes
  • Size of remote file: 593 kB
train/default_config.yaml ADDED
@@ -0,0 +1,16 @@
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: MULTI_GPU
4
+ main_process_port: 14121
5
+ downcast_bf16: 'no'
6
+ gpu_ids: 1,
7
+ machine_rank: 0
8
+ main_training_function: main
9
+ mixed_precision: fp16
10
+ num_machines: 1
11
+ num_processes: 1
12
+ same_network: true
13
+ tpu_env: []
14
+ tpu_use_cluster: false
15
+ tpu_use_sudo: false
16
+ use_cpu: false
train/group_subjects.py ADDED
@@ -0,0 +1,212 @@
1
+ scales = {
2
+ "bear": 0.53, # Unchanged
3
+ "bicycle": 0.4, # Unchanged
4
+ "bugatti": 1.0, # Unchanged
5
+ "bulldozer": 1.78, # Unchanged
6
+ "bus": 2.67, # Unchanged
7
+ "cat": 0.11, # Unchanged
8
+ "chair": 0.18, # Unchanged
9
+ "coupe": 1.0, # Unchanged
10
+ "cow": 0.56, # Unchanged
11
+ "crow": 0.09, # CHANGED: Reduced from 0.11
12
+ "deer": 0.44, # Unchanged
13
+ "dog": 0.22, # Unchanged
14
+ "elephant": 1.22, # Unchanged
15
+ "ferrari": 1.05, # CHANGED: Increased from 1.0
16
+ "flamingo": 0.10, # Unchanged
17
+ "fox": 0.22, # Unchanged
18
+ "giraffe": 0.90, # CHANGED: Reduced from 1.00
19
+ "goat": 0.33, # Unchanged
20
+ "helicopter": 2.26, # Unchanged
21
+ "hen": 0.09, # Unchanged
22
+ "horse": 0.53, # Unchanged
23
+ "jeep": 0.96, # Unchanged
24
+ "kangaroo": 0.38, # CHANGED: Increased from 0.33
25
+ "lamborghini": 1.0, # Unchanged
26
+ "lion": 0.56, # Unchanged
27
+ "mclaren": 1.0, # Unchanged
28
+ "motorbike": 0.44, # Unchanged
29
+ "office chair": 0.20,# Unchanged
30
+ "pickup truck": 1.22,# Unchanged
31
+ "pigeon": 0.067, # Unchanged
32
+ "pig": 0.33, # Unchanged
33
+ "rabbit": 0.11, # Unchanged
34
+ "scooter": 0.4, # Unchanged
35
+ "sedan": 1.0, # Unchanged (Reference)
36
+ "sheep": 0.29, # Unchanged
37
+ "shoe": 0.04, # Unchanged
38
+ "sparrow": 0.033, # Unchanged
39
+ "suv": 1.07, # Unchanged
40
+ "table": 0.4, # Unchanged
41
+ "teddy": 0.05, # CHANGED: Reduced from 0.11
42
+ "tiger": 0.67, # Unchanged
43
+ "tractor": 0.80, # Unchanged
44
+ "van": 1.11, # Unchanged
45
+ "vw beetle": 1.0, # Unchanged
46
+ "wolf": 0.33, # Unchanged
47
+ "man": 0.38, # Unchanged
48
+ "zebra": 0.56 # Unchanged
49
+ }
50
+
51
+ tiny_assets = [
52
+ "sparrow", # 0.033
53
+ "shoe", # 0.04
54
+ "teddy", # 0.05 (CHANGED)
55
+ "pigeon", # 0.067
56
+ "hen", # 0.09
57
+ "crow", # 0.09 (CHANGED)
58
+ "flamingo", # 0.10 (CHANGED - Moved from small)
59
+ "rabbit", # 0.11
60
+ "cat", # 0.11
61
+ ]
62
+
63
+
64
+ small_assets = [
65
+ "chair", # 0.18
66
+ "office chair", # 0.20
67
+ "dog", # 0.22
68
+ "fox", # 0.22
69
+ "sheep", # 0.29
70
+ "goat", # 0.33
71
+ "pig", # 0.33
72
+ "wolf", # 0.33
73
+ "man", # 0.38 (CHANGED - Added to group)
74
+ "kangaroo", # 0.38 (CHANGED)
75
+ ]
76
+
77
+
78
+ medium_assets = [
79
+ "table", # 0.4 (CHANGED - Moved from small)
80
+ "bicycle", # 0.4
81
+ "scooter", # 0.4
82
+ "deer", # 0.44
83
+ "motorbike", # 0.44
84
+ "bear", # 0.53
85
+ "horse", # 0.53
86
+ "cow", # 0.56
87
+ "lion", # 0.56
88
+ "zebra", # 0.56
89
+ "tiger", # 0.67
90
+ "tractor", # 0.80
91
+ "giraffe", # 0.90 (CHANGED)
92
+ "jeep", # 0.96
93
+ "bugatti", # 1.0
94
+ "coupe", # 1.0
95
+ "lamborghini", # 1.0
96
+ "mclaren", # 1.0
97
+ "sedan", # 1.0
98
+ "vw beetle", # 1.0
99
+ "ferrari", # 1.05 (CHANGED)
100
+ "suv", # 1.07
101
+ "van", # 1.11
102
+ "elephant", # 1.22
103
+ "pickup truck", # 1.22
104
+ "bulldozer", # 1.78
105
+ "helicopter", # 2.26
106
+ "bus", # 2.67
107
+ ]
108
+
109
+ tiny_prompts = [
110
+ "a photo of PLACEHOLDER in a cozy birdhouse nestled in a green tree",
111
+ "a photo of PLACEHOLDER on a sandy beach near the water's edge with small shells",
112
+ "a photo of PLACEHOLDER amongst colorful wildflowers in a sunny meadow",
113
+ "a photo of PLACEHOLDER on a moss-covered log in a quiet forest",
114
+ "a photo of PLACEHOLDER near a small pond with lily pads floating",
115
+ "a photo of PLACEHOLDER on a window sill overlooking a rainy city street",
116
+ "a photo of PLACEHOLDER in a child's bedroom surrounded by other toys",
117
+ "a photo of PLACEHOLDER on a park bench with fallen leaves around",
118
+ "a photo of PLACEHOLDER by a small stream with smooth pebbles",
119
+ "a photo of PLACEHOLDER in a field of tall grass swaying gently",
120
+ "a photo of PLACEHOLDER on a wooden fence post in the countryside",
121
+ "a photo of PLACEHOLDER amongst blossoming spring flowers in a garden",
122
+ "a photo of PLACEHOLDER on a stack of old books in a library",
123
+ "a photo of PLACEHOLDER near a bird feeder in a winter garden",
124
+ "a photo of PLACEHOLDER on a picnic blanket in a sunny park",
125
+ "a photo of PLACEHOLDER on a kitchen counter near ripe fruit",
126
+ "a photo of PLACEHOLDER amongst autumn leaves on a forest floor",
127
+ "a photo of PLACEHOLDER on a rocky outcrop with a distant view",
128
+ "a photo of PLACEHOLDER near a puddle reflecting the sky",
129
+ "a photo of PLACEHOLDER in a patch of soft green moss",
130
+ "a photo of PLACEHOLDER on a weathered stone wall",
131
+ "a photo of PLACEHOLDER near a patch of blooming daisies",
132
+ "a photo of PLACEHOLDER on a sandy path through a garden",
133
+ "a photo of PLACEHOLDER near a watering can in a greenhouse",
134
+ "a photo of PLACEHOLDER amongst fallen pine needles in a forest",
135
+ "a photo of PLACEHOLDER on a small bridge over a gentle stream",
136
+ "a photo of PLACEHOLDER near a patch of colorful mushrooms"
137
+ ]
138
+
139
+ small_prompts = [
140
+ "a photo of PLACEHOLDER in a sun-drenched greenhouse surrounded by various plants",
141
+ "a photo of PLACEHOLDER in a bustling city park with people walking by",
142
+ "a photo of PLACEHOLDER in a cozy library with tall bookshelves and soft lighting",
143
+ "a photo of PLACEHOLDER on a sandy dune near the ocean with gentle waves",
144
+ "a photo of PLACEHOLDER amongst tall reeds in a marshland area",
145
+ "a photo of PLACEHOLDER in a quiet forest clearing with sunlight filtering through trees",
146
+ "a photo of PLACEHOLDER on a grassy hill overlooking a small town",
147
+ "a photo of PLACEHOLDER near a flowing waterfall with mist in the air",
148
+ "a photo of PLACEHOLDER in a vibrant flower market with colorful blooms all around",
149
+ "a photo of PLACEHOLDER on a wooden dock extending into a still lake",
150
+ "a photo of PLACEHOLDER amongst rows of crops in a rural farmland",
151
+ "a photo of PLACEHOLDER in a historic town square with old buildings",
152
+ "a photo of PLACEHOLDER on a rocky beach with crashing waves in the distance",
153
+ "a photo of PLACEHOLDER amongst tall bamboo stalks in a serene grove",
154
+ "a photo of PLACEHOLDER in a snowy field with tracks visible in the snow",
155
+ "a photo of PLACEHOLDER on a paved walkway in a botanical garden",
156
+ "a photo of PLACEHOLDER near a campfire in a forest at night",
157
+ "a photo of PLACEHOLDER amongst colorful autumn foliage in a park",
158
+ "a photo of PLACEHOLDER on a stone path winding through a garden",
159
+ "a photo of PLACEHOLDER in a misty meadow with dew-covered grass",
160
+ "a photo of PLACEHOLDER on a wooden bridge crossing a small river",
161
+ "a photo of PLACEHOLDER amongst blooming lavender fields under a sunny sky",
162
+ "a photo of PLACEHOLDER in a quiet suburban backyard with green grass",
163
+ "a photo of PLACEHOLDER on a rocky hillside with sparse vegetation",
164
+ "a photo of PLACEHOLDER near a clear mountain stream with smooth stones",
165
+ "a photo of PLACEHOLDER amongst fallen leaves in a shaded woodland",
166
+ "a photo of PLACEHOLDER on a grassy bank beside a calm canal",
167
+ "a photo of PLACEHOLDER in a vineyard with rows of grapevines",
168
+ "a photo of PLACEHOLDER near a traditional wooden farmhouse"
169
+ ]
170
+
171
+ medium_prompts = [
172
+ "a photo of PLACEHOLDER in a vast open plain with a dramatic sunset on the horizon",
173
+ "a photo of PLACEHOLDER on a winding mountain road with scenic views of valleys",
174
+ "a photo of PLACEHOLDER in a bustling harbor with various boats and ships",
175
+ "a photo of PLACEHOLDER in a dense pine forest with tall trees reaching the sky",
176
+ "a photo of PLACEHOLDER on a sandy beach with palm trees swaying in the breeze",
177
+ "a photo of PLACEHOLDER amongst rolling hills in a green countryside landscape",
178
+ "a photo of PLACEHOLDER in a vibrant city square with historic architecture",
179
+ "a photo of PLACEHOLDER in a train yard with multiple railway tracks",
180
+ "a photo of PLACEHOLDER amongst tall redwood trees in an ancient forest",
181
+ "a photo of PLACEHOLDER in a sprawling parking lot outside a shopping mall",
182
+ "a photo of PLACEHOLDER on a coastal highway with ocean views and cliffs",
183
+ "a photo of PLACEHOLDER amongst golden wheat fields under a clear summer sky",
184
+ "a photo of PLACEHOLDER in a rocky canyon with sparse desert vegetation and blue sky above",
185
+ "a photo of PLACEHOLDER on a grassy plateau overlooking a vast landscape",
186
+ "a photo of PLACEHOLDER in a snowy mountain range with visible ski slopes",
187
+ "a photo of PLACEHOLDER on a paved highway stretching across an open landscape",
188
+ "a photo of PLACEHOLDER amongst lush vegetation in a tropical rainforest",
189
+ "a photo of PLACEHOLDER in a historic European city with ornate buildings",
190
+ "a photo of PLACEHOLDER in front of the Eiffel Tower at sunset",
191
+ "a photo of PLACEHOLDER amongst tall sunflowers in a field under a bright sun",
192
+ "a photo of PLACEHOLDER in a deep valley with steep forested sides",
193
+ "a photo of PLACEHOLDER on a rocky coastline with crashing waves and sea spray",
194
+ "a photo of PLACEHOLDER amongst vineyards on rolling hills under a sunny sky",
195
+ "a photo of PLACEHOLDER in a wide open desert with distant mesas and clear air",
196
+ "a photo of PLACEHOLDER amongst autumn-colored trees along a winding river",
197
+ "a photo of PLACEHOLDER in a bustling marketplace with various stalls and people",
198
+ "a photo of PLACEHOLDER on a racing circuit with banked turns and grandstands",
199
+ "a photo of PLACEHOLDER amongst tall grasses in a savanna landscape",
200
+ ]
201
+
202
+ groups = {
203
+ "tiny": tiny_assets,
204
+ "small": small_assets,
205
+ "medium": medium_assets,
206
+ }
207
+
208
+ groups_prompts = {
209
+ "tiny": tiny_prompts,
210
+ "small": small_prompts,
211
+ "medium": medium_prompts,
212
+ }
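A minimal usage sketch, assuming this file is importable as group_subjects: it inverts the size-group lists to find a subject's bucket and then fills a scene prompt from that bucket, mirroring how train/make_jsonl2_clip.py consumes these dictionaries below.

    import random
    from group_subjects import groups, groups_prompts

    # invert {bucket: [subjects]} into {subject: bucket}
    inverse_groups = {subject: bucket
                      for bucket, subjects in groups.items()
                      for subject in subjects}

    def sample_prompt(subject: str) -> str:
        bucket = inverse_groups[subject]                   # e.g. "cat" -> "tiny"
        template = random.choice(groups_prompts[bucket])   # bucket-specific scene prompt
        return template.replace("PLACEHOLDER", subject)

    print(sample_prompt("cat"))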
train/make_jsonl2_clip.py ADDED
@@ -0,0 +1,271 @@
1
+ import os
2
+ import os.path as osp
3
+ import json
4
+ from tqdm import tqdm
5
+ import torch
6
+ import pickle
7
+ from group_subjects_tabletop import groups, groups_prompts
8
+ from train import *
9
+
10
+ PRETRAINED_MODEL_NAME_OR_PATH = "black-forest-labs/FLUX.1-dev"
11
+
12
+ def load_clip_evaluation_results(eval_dir, subjects_comb, img_idx, img_name):
13
+ """
14
+ Load the CLIP similarity results for a specific image.
15
+
16
+ Args:
17
+ eval_dir: Base directory containing CLIP evaluation results
18
+ subjects_comb: Subject combination directory name
19
+ img_idx: Image index directory name
20
+ img_name: Name of the image file
21
+
22
+ Returns:
23
+ Minimum CLIP similarity score, or None if file not found
24
+ """
25
+ # Construct the path to the pkl file
26
+ pkl_filename = img_name.replace(".jpg", ".pkl")
27
+ pkl_path = osp.join(eval_dir, subjects_comb, img_idx, pkl_filename)
28
+
29
+ if not osp.exists(pkl_path):
30
+ return None
31
+
32
+ try:
33
+ with open(pkl_path, 'rb') as f:
34
+ data = pickle.load(f)
35
+
36
+ similarities = data.get('similarities', [])
37
+
38
+ # Filter out any potential zero scores for non-existent subjects
39
+ valid_similarities = [s for s in similarities if s > 0.0]
40
+
41
+ if valid_similarities:
42
+ return min(valid_similarities)
43
+ else:
44
+ return 0.0
45
+
46
+ except Exception as e:
47
+ print(f"Warning: Could not process file {pkl_path}. Error: {e}")
48
+ return None
49
+
50
+
51
+ def get_call_ids_from_placeholder_prompt_flux(prompt: str, tokenizer_three, subjects, subjects_embeds: list, debug: bool):
52
+ assert prompt.find("<placeholder>") != -1, "Prompt must contain <placeholder> to get call ids"
53
+
54
+ # the placeholder token ID for all the tokenizers
55
+ placeholder_token_three = tokenizer_three.encode("<placeholder>", return_tensors="pt")[0][:-1].item()
56
+ prompt_tokens_three = tokenizer_three.encode(prompt, return_tensors="pt")[0].tolist()
57
+
58
+ placeholder_token_locations_three = [i for i, w in enumerate(prompt_tokens_three) if w == placeholder_token_three]
59
+ prompt = prompt.replace("<placeholder> ", "")
60
+
61
+
62
+ call_ids = []
63
+ for subject_idx, (subject, subject_embed) in enumerate(zip(subjects, subjects_embeds)):
64
+ subject_prompt_ids_t5 = subject_embed["input_ids_t5"][:-1] # drop the trailing EOS token (the T5 tokenizer appends EOS only, no BOS)
65
+ num_t5_tokens_subject = len(subject_prompt_ids_t5)
66
+
67
+ t5_call_ids_subject = [i + placeholder_token_locations_three[subject_idx] - 2 * subject_idx - 1 for i in range(num_t5_tokens_subject)]
68
+ call_ids.append(t5_call_ids_subject)
69
+
70
+ prompt_wo_placeholder = prompt.replace("<placeholder> ", "")
71
+ t5_call_strs = tokenizer_three.batch_decode(tokenizer_three.encode(prompt_wo_placeholder, return_tensors="pt")[0].tolist())
72
+ t5_call_strs = [t5_call_strs[i] for i in t5_call_ids_subject]
73
+ if debug:
74
+ print(f"{prompt = }, t5 CALL strs for {subject} = {t5_call_strs}")
75
+
76
+ return call_ids
77
+
78
+
79
+ def generate_cuboids_jsonl(data_dir, output_path, subject_names_embeds_flux, tokenizer_one, tokenizer_two,
80
+ clip_eval_dir=None, min_clip_similarity=0.26):
81
+ """
82
+ Generate a JSONL file for cuboids dataset similar to pose.jsonl format.
83
+
84
+ Args:
85
+ data_dir: Path to the images directory (same as BlenderFLUXSyntheticDataset data_dir)
86
+ output_path: Path where the cuboids.jsonl file should be saved
87
+ clip_eval_dir: Directory containing CLIP evaluation results (optional)
88
+ min_clip_similarity: Minimum CLIP similarity threshold for depth_flux images (default: 0.26)
89
+ """
90
+
91
+ # Create inverse groups mapping
92
+ inverse_groups = {}
93
+ for category in groups:
94
+ for subject in groups[category]:
95
+ assert subject not in inverse_groups
96
+ inverse_groups[subject] = category
97
+
98
+ jsonl_entries = []
99
+ filtered_count = 0
100
+ total_depth_flux = 0
101
+
102
+ imgs_dir = osp.join(data_dir, "main_imgs")
103
+ cuboids_dir = osp.join(data_dir, "cuboids_monochrome")
104
+
105
+ # Iterate over the dataset structure (same as BlenderFLUXSyntheticDataset)
106
+ subjects_combs = os.listdir(imgs_dir)
107
+ import random
108
+ random.shuffle(subjects_combs)
109
+ for subjects_comb in tqdm(subjects_combs):
110
+ if len(subjects_comb.split("__")) > 4:
111
+ continue
112
+ if subjects_comb.startswith("_"):
113
+ continue
114
+ subjects_ = subjects_comb.split("__")
115
+ subjects = [" ".join(subject_.split("_")) for subject_ in subjects_]
116
+ if "bed" in subjects:
117
+ continue
118
+
119
+ subjects_groups = [inverse_groups[subject] for subject in subjects]
120
+ PROMPTS = groups_prompts[subjects_groups[-2]]
121
+
122
+ subjects_comb_dir = osp.join(imgs_dir, subjects_comb)
123
+
124
+ assert clip_eval_dir is not None
125
+
126
+ for img_idx in os.listdir(subjects_comb_dir):
127
+ if not osp.isdir(osp.join(subjects_comb_dir, img_idx)):
128
+ continue
129
+
130
+ img_idx_dir = osp.join(subjects_comb_dir, img_idx)
131
+
132
+ # Check if required files exist
133
+ main_img_path = osp.join(img_idx_dir, "main.jpg")
134
+ cuboids_path = osp.join(cuboids_dir, subjects_comb, img_idx, "cuboids.jpg")
135
+ pkl_path = osp.join(img_idx_dir, "main.pkl")
136
+
137
+ assert osp.exists(main_img_path), f"Main image path {main_img_path} does not exist"
138
+ assert osp.exists(cuboids_path), f"Cuboids path {cuboids_path} does not exist"
139
+ assert osp.exists(pkl_path), f"PKL path {pkl_path} does not exist"
140
+
141
+ # Get all image types (depth_flux and rendering)
142
+ img_names = os.listdir(img_idx_dir)
143
+
144
+ # Process depth_flux images (prompt*.jpg)
145
+ depth_flux_imgs = [img_name for img_name in img_names
146
+ if img_name.endswith(".jpg") and img_name.find("prompt") != -1 and img_name.find("DEBUG") == -1]
147
+
148
+ # Filter depth_flux images based on CLIP similarity if eval_dir is provided
149
+ if clip_eval_dir is not None:
150
+ filtered_depth_flux_imgs = []
151
+ for img_name in depth_flux_imgs:
152
+ total_depth_flux += 1
153
+ min_similarity = load_clip_evaluation_results(clip_eval_dir, subjects_comb, img_idx, img_name)
154
+
155
+ if min_similarity is not None and min_similarity >= min_clip_similarity:
156
+ filtered_depth_flux_imgs.append(img_name)
157
+ else:
158
+ filtered_count += 1
159
+
160
+ depth_flux_imgs = filtered_depth_flux_imgs
161
+
162
+ all_imgs = depth_flux_imgs + ["main.jpg"]
163
+ # all_imgs = ["main.jpg"]
164
+
165
+ for img_name in all_imgs:
166
+ img_path = osp.join(img_idx_dir, img_name)
167
+
168
+ if img_name != "main.jpg":
169
+ # Extract prompt index and get corresponding prompt
170
+ prompt_idx = int(img_name.replace("prompt", "").replace(".jpg", ""))
171
+ print(f"{prompt_idx = }, {subjects_groups[-1] = }, {subjects_comb = }")
172
+ prompt = PROMPTS[prompt_idx]
173
+ else:
174
+ prompt = "a photo of PLACEHOLDER"
175
+
176
+ # Create placeholder text
177
+ placeholder_text = ""
178
+ for subject in subjects[:-1]:
179
+ placeholder_text = placeholder_text + f"<placeholder> {subject} and "
180
+ for subject in subjects[-1:]:
181
+ placeholder_text = placeholder_text + f"<placeholder> {subject}"
182
+ placeholder_text = placeholder_text.strip()
183
+
184
+ subjects_embeds = []
185
+ cuboids_segmasks_paths = []
186
+ segmasks_dir = osp.join(data_dir, "cuboids_segmasks_cv", subjects_comb, img_idx)
187
+ assert osp.exists(segmasks_dir)
188
+ segmask_names = sorted(os.listdir(segmasks_dir))
189
+ for subject_idx, subject in enumerate(subjects):
190
+ subject_embed_path = osp.join(subject_names_embeds_flux, f"{subject.replace(' ', '_')}.pth")
191
+ assert osp.exists(subject_embed_path), f"Subject embed path {subject_embed_path} does not exist"
192
+ subject_embed_obj = torch.load(subject_embed_path)
193
+ subjects_embeds.append(subject_embed_obj)
194
+ cuboid_segmask_path = osp.join(data_dir, "cuboids_segmasks_cv", subjects_comb, img_idx, segmask_names[subject_idx])
195
+ cuboid_segmask_path = osp.relpath(cuboid_segmask_path, osp.dirname(output_path))
196
+ # assert osp.exists(cuboid_segmask_path), f"Cuboid segmask path {cuboid_segmask_path} does not exist"
197
+ cuboids_segmasks_paths.append(cuboid_segmask_path)
198
+ placeholder_prompt = prompt
199
+ prompt = prompt.replace("PLACEHOLDER", placeholder_text)
200
+ call_ids = get_call_ids_from_placeholder_prompt_flux(prompt, tokenizer_two, subjects, subjects_embeds, debug=True)
201
+ print(f"{call_ids = }")
202
+
203
+ # Create relative paths from the output jsonl location
204
+ rel_cuboids_path = osp.relpath(cuboids_path, osp.dirname(output_path))
205
+ rel_img_path = osp.relpath(img_path, osp.dirname(output_path))
206
+
207
+ # Create JSONL entry
208
+ entry = {
209
+ "cv": rel_cuboids_path,
210
+ "PLACEHOLDER_prompts": placeholder_prompt,
211
+ "target": rel_img_path,
212
+ "subjects": subjects,
213
+ "cuboids_segmasks": cuboids_segmasks_paths,
214
+ "call_ids": call_ids,
215
+ }
216
+ jsonl_entries.append(entry)
217
+
218
+ # Print filtering statistics
219
+ if clip_eval_dir is not None:
220
+ print(f"\n--- Filtering Statistics ---")
221
+ print(f"Total depth_flux images evaluated: {total_depth_flux}")
222
+ print(f"Images filtered out (min similarity < {min_clip_similarity}): {filtered_count}")
223
+ print(f"Images retained: {total_depth_flux - filtered_count}")
224
+ print(f"Retention rate: {((total_depth_flux - filtered_count) / total_depth_flux * 100):.2f}%")
225
+ print("---------------------------\n")
226
+
227
+ # Write JSONL file
228
+ os.makedirs(osp.dirname(output_path), exist_ok=True)
229
+ with open(output_path, 'w') as f:
230
+ for entry in jsonl_entries:
231
+ f.write(json.dumps(entry) + '\n')
232
+
233
+ print(f"Generated {len(jsonl_entries)} entries in {output_path}")
234
+
235
+ if __name__ == "__main__":
236
+ # Configuration
237
+ data_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv9" # Replace with actual imgs_dir path
238
+ output_path = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv9/cuboids_monochrome.jsonl"
239
+ subjects_embeds_path = "/archive/vaibhav.agrawal/a-bev-of-the-latents/subject_names_embeds_flux" # Directory of per-subject embedding .pth files
240
+ clip_eval_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/clip_evaluation__datasetv9" # CLIP evaluation results directory
241
+ min_clip_similarity = 0.26 # Minimum CLIP similarity threshold
242
+ rendered_imgs_prompt = "An image of PLACEHOLDER" # Customize as needed
243
+
244
+ # You can also accept command line arguments
245
+ import sys
246
+ if len(sys.argv) >= 2:
247
+ data_dir = sys.argv[1] # first positional argument overrides the dataset root
248
+ if len(sys.argv) >= 3:
249
+ output_path = sys.argv[2]
250
+ if len(sys.argv) >= 4:
251
+ rendered_imgs_prompt = sys.argv[3]
252
+
253
+ tokenizer_one = CLIPTokenizer.from_pretrained(
254
+ PRETRAINED_MODEL_NAME_OR_PATH,
255
+ subfolder="tokenizer",
256
+ revision=None,
257
+ )
258
+ tokenizer_two = T5TokenizerFast.from_pretrained(
259
+ PRETRAINED_MODEL_NAME_OR_PATH,
260
+ subfolder="tokenizer_2",
261
+ revision=None,
262
+ )
263
+
264
+ placeholder_token_str = ["<placeholder>"]
265
+ num_added_tokens = tokenizer_one.add_tokens(placeholder_token_str)
266
+ assert num_added_tokens == 1
267
+ num_added_tokens = tokenizer_two.add_tokens(placeholder_token_str)
268
+ assert num_added_tokens == 1
269
+
270
+ generate_cuboids_jsonl(data_dir, output_path, subjects_embeds_path, tokenizer_one, tokenizer_two,
271
+ clip_eval_dir=clip_eval_dir, min_clip_similarity=min_clip_similarity)
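A sketch of reading the emitted JSONL back; the keys match the entry dict assembled above, and the file name is the basename of the default output_path (the paths stored in each entry are whatever the generator wrote, relative to the JSONL).

    import json

    with open("cuboids_monochrome.jsonl") as f:
        for line in f:
            entry = json.loads(line)
            # entry["cv"]                  -> relative path to the monochrome cuboids condition image
            # entry["target"]              -> relative path to the target photo
            # entry["PLACEHOLDER_prompts"] -> prompt template still containing "PLACEHOLDER"
            # entry["subjects"]            -> list of subject names in the scene
            # entry["cuboids_segmasks"]    -> per-subject cuboid segmentation mask paths
            # entry["call_ids"]            -> per-subject T5 token indices for the subject words
            print(entry["subjects"], entry["PLACEHOLDER_prompts"])
            break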
train/merge_jsonls.py ADDED
@@ -0,0 +1,181 @@
1
+ import json
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Set, Dict, Any
5
+
6
+
7
+ def extract_subjects_comb(path: str) -> str:
8
+ """
9
+ Extract subjects_comb from a path.
10
+ Format: cuboids_monochrome/subjects_comb/img_idx/cuboids.jpg
11
+ """
12
+ path_parts = path.split('/')
13
+
14
+ if 'cuboids_monochrome' in path_parts:
15
+ idx = path_parts.index('cuboids_monochrome')
16
+ if idx + 1 < len(path_parts):
17
+ return path_parts[idx + 1]
18
+
19
+ return None
20
+
21
+
22
+ def get_subjects_combs_from_jsonl(jsonl_path: str) -> Set[str]:
23
+ """Extract all unique subjects_comb from a JSONL file."""
24
+ subjects_combs = set()
25
+
26
+ with open(jsonl_path, 'r') as f:
27
+ for line_num, line in enumerate(f, 1):
28
+ try:
29
+ entry = json.loads(line.strip())
30
+ cv_path = entry.get('cv', '')
31
+ subjects_comb = extract_subjects_comb(cv_path)
32
+
33
+ if subjects_comb:
34
+ subjects_combs.add(subjects_comb)
35
+
36
+ except json.JSONDecodeError as e:
37
+ print(f"Warning: Could not parse line {line_num} in {jsonl_path}: {e}")
38
+ continue
39
+
40
+ return subjects_combs
41
+
42
+
43
+ def get_common_keys(jsonl_path1: str, jsonl_path2: str) -> Set[str]:
44
+ """Get keys that are common across ALL entries in both JSONL files."""
45
+ keys1 = None
46
+ keys2 = None
47
+
48
+ # Get keys from first file
49
+ with open(jsonl_path1, 'r') as f:
50
+ for line in f:
51
+ try:
52
+ entry = json.loads(line.strip())
53
+ current_keys = set(entry.keys())
54
+ if keys1 is None:
55
+ keys1 = current_keys
56
+ else:
57
+ keys1 = keys1.intersection(current_keys)
58
+ except json.JSONDecodeError:
59
+ continue
60
+
61
+ # Get keys from second file
62
+ with open(jsonl_path2, 'r') as f:
63
+ for line in f:
64
+ try:
65
+ entry = json.loads(line.strip())
66
+ current_keys = set(entry.keys())
67
+ if keys2 is None:
68
+ keys2 = current_keys
69
+ else:
70
+ keys2 = keys2.intersection(current_keys)
71
+ except json.JSONDecodeError:
72
+ continue
73
+
74
+ # Return intersection of keys from both files
75
+ if keys1 is None or keys2 is None:
76
+ return set()
77
+
78
+ common_keys = keys1.intersection(keys2)
79
+ return common_keys
80
+
81
+
82
+ def merge_jsonls(jsonl_path1: str, jsonl_path2: str, output_path: str):
83
+ """
84
+ Merge two JSONL files, ensuring:
85
+ 1. No overlapping subjects_comb (assertion)
86
+ 2. Only common keys are included in output
87
+ 3. All entries are concatenated
88
+ """
89
+ print(f"Checking for overlapping subjects_comb...")
90
+ combs1 = get_subjects_combs_from_jsonl(jsonl_path1)
91
+ combs2 = get_subjects_combs_from_jsonl(jsonl_path2)
92
+
93
+ overlap = combs1.intersection(combs2)
94
+
95
+ # Assert no overlap
96
+ assert len(overlap) == 0, (
97
+ f"ERROR: Found {len(overlap)} overlapping subjects_comb between files!\n"
98
+ f"Overlapping subjects_comb: {sorted(overlap)}"
99
+ )
100
+
101
+ print(f"✓ No overlapping subjects_comb found")
102
+ print(f" File 1: {len(combs1)} unique subjects_comb")
103
+ print(f" File 2: {len(combs2)} unique subjects_comb")
104
+
105
+ # Get common keys
106
+ print(f"\nFinding common keys...")
107
+ common_keys = get_common_keys(jsonl_path1, jsonl_path2)
108
+ print(f"{common_keys = }")
109
+
110
+ assert len(common_keys) > 0, "ERROR: No common keys found between the two JSONL files!"
111
+
112
+ print(f"✓ Found {len(common_keys)} common keys: {sorted(common_keys)}")
113
+
114
+ # Merge files
115
+ print(f"\nMerging files to {output_path}...")
116
+ total_entries = 0
117
+
118
+ with open(output_path, 'w') as out_f:
119
+ # Write entries from first file
120
+ with open(jsonl_path1, 'r') as f1:
121
+ for line_num, line in enumerate(f1, 1):
122
+ try:
123
+ entry = json.loads(line.strip())
124
+ # Keep only common keys
125
+ filtered_entry = {k: entry[k] for k in common_keys if k in entry}
126
+ out_f.write(json.dumps(filtered_entry) + '\n')
127
+ total_entries += 1
128
+ except json.JSONDecodeError as e:
129
+ print(f"Warning: Could not parse line {line_num} in {jsonl_path1}: {e}")
130
+ continue
131
+
132
+ # Write entries from second file
133
+ with open(jsonl_path2, 'r') as f2:
134
+ for line_num, line in enumerate(f2, 1):
135
+ try:
136
+ entry = json.loads(line.strip())
137
+ # Keep only common keys
138
+ filtered_entry = {k: entry[k] for k in common_keys if k in entry}
139
+ out_f.write(json.dumps(filtered_entry) + '\n')
140
+ total_entries += 1
141
+ except json.JSONDecodeError as e:
142
+ print(f"Warning: Could not parse line {line_num} in {jsonl_path2}: {e}")
143
+ continue
144
+
145
+ print(f"✓ Merged {total_entries} entries to {output_path}")
146
+ print(f"\nMerge complete!")
147
+
148
+
149
+ if __name__ == "__main__":
150
+ if len(sys.argv) != 4:
151
+ print("Usage: python merge_two_jsonls.py <jsonl_file1> <jsonl_file2> <output_jsonl>")
152
+ print("\nExample:")
153
+ print(" python merge_two_jsonls.py dataset1/cuboids.jsonl dataset2/cuboids.jsonl merged_cuboids.jsonl")
154
+ sys.exit(1)
155
+
156
+ jsonl_path1 = sys.argv[1]
157
+ jsonl_path2 = sys.argv[2]
158
+ output_path = sys.argv[3]
159
+
160
+ # Validate input files
161
+ if not Path(jsonl_path1).exists():
162
+ print(f"Error: File not found: {jsonl_path1}")
163
+ sys.exit(1)
164
+
165
+ if not Path(jsonl_path2).exists():
166
+ print(f"Error: File not found: {jsonl_path2}")
167
+ sys.exit(1)
168
+
169
+ # Check if output file already exists
170
+ if Path(output_path).exists():
171
+ response = input(f"Warning: {output_path} already exists. Overwrite? (y/n): ")
172
+ if response.lower() != 'y':
173
+ print("Aborted.")
174
+ sys.exit(0)
175
+
176
+ # Merge files
177
+ try:
178
+ merge_jsonls(jsonl_path1, jsonl_path2, output_path)
179
+ except AssertionError as e:
180
+ print(f"\n{e}")
181
+ sys.exit(1)
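A small sketch of the path convention the merger relies on, assuming the script's directory is on the import path: extract_subjects_comb returns the path segment that follows cuboids_monochrome, which is what the overlap assertion compares across the two input files (the example path is hypothetical).

    from merge_jsonls import extract_subjects_comb

    path = "cuboids_monochrome/cat__dog/0007/cuboids.jpg"
    assert extract_subjects_comb(path) == "cat__dog"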
train/src/__init__.py ADDED
File without changes
train/src/jsonl_datasets.py ADDED
@@ -0,0 +1,348 @@
1
+ from PIL import Image
2
+ from datasets import load_dataset
3
+ from torchvision import transforms
4
+ import random
5
+ import torch.nn.functional as F
6
+ import torch
7
+ import os
8
+ import os.path as osp
9
+ import cv2
10
+
11
+ def do_z_pass(seg_masks: torch.Tensor, dist_values: torch.Tensor) -> torch.Tensor:
12
+ """
13
+ Performs a z-pass on segmentation masks based on distance values to the camera.
14
+ For each pixel, if multiple subjects' masks are active, only the one with the smallest distance (closest) remains active.
15
+
16
+ Args:
17
+ seg_masks (torch.Tensor): Binary segmentation masks of shape (n_subjects, h, w) with dtype uint8.
18
+ dist_values (torch.Tensor): Distance values for each subject of shape (n_subjects,).
19
+
20
+ Returns:
21
+ torch.Tensor: Processed segmentation masks after z-pass, same shape and dtype as seg_masks.
22
+ """
23
+ # Ensure tensors are on the same device
24
+ device = seg_masks.device
25
+
26
+ # Get dimensions
27
+ n_subjects, h, w = seg_masks.shape
28
+
29
+ # Reshape distance values for broadcasting across spatial dimensions
30
+ dist_values_expanded = dist_values.view(n_subjects, 1, 1)
31
+
32
+ # Create a tensor where active pixels have their distance, others have a high value (1e10)
33
+ masked_dist = torch.where(seg_masks.bool(), dist_values_expanded, torch.tensor(1e10, device=device))
34
+
35
+ # Find the subject index with the minimum distance for each pixel (shape (h, w))
36
+ closest_indices = torch.argmin(masked_dist, dim=0)
37
+
38
+ # Initialize output tensor with zeros
39
+ output = torch.zeros_like(seg_masks)
40
+
41
+ # Scatter 1s into the output tensor where the closest subject's indices are
42
+ # closest_indices.unsqueeze(0) adds a dummy dimension to match scatter's expected shape
43
+ output.scatter_(
44
+ dim=0,
45
+ index=closest_indices.unsqueeze(0),
46
+ src=torch.ones_like(closest_indices.unsqueeze(0), dtype=output.dtype)
47
+ )
48
+
49
+ # Zero out any positions where the original mask was inactive
50
+ output = output * seg_masks
51
+
52
+ return output
53
+
54
+ Image.MAX_IMAGE_PIXELS = None
55
+
56
+ def multiple_16(num: float):
57
+ return int(round(num / 16) * 16)
58
+
59
+ def get_random_resolution(min_size=512, max_size=1280, multiple=16):
60
+ resolution = random.randint(min_size // multiple, max_size // multiple) * multiple
61
+ return resolution
62
+
63
+ def load_image_safely(image_path, size):
64
+ try:
65
+ image = Image.open(image_path).convert("RGB")
66
+ return image
67
+ except Exception as e:
68
+ print("file error: "+image_path)
69
+ with open("failed_images.txt", "a") as f:
70
+ f.write(f"{image_path}\n")
71
+ return Image.new("RGB", (size, size), (255, 255, 255))
72
+
73
+ def make_train_dataset(args, tokenizer, accelerator):
74
+ if args.current_train_data_dir is not None:
75
+ print("load_data")
76
+ dataset = load_dataset('json', data_files=args.current_train_data_dir)
77
+
78
+ # Add index column to the dataset
79
+ dataset = dataset.map(lambda examples, indices: {**examples, 'index': indices}, with_indices=True, batched=True)
80
+
81
+ column_names = dataset["train"].column_names
82
+
83
+ # 6. Get the column names for input/target.
84
+ target_column = args.target_column
85
+ if args.subject_column is not None:
86
+ subject_columns = args.subject_column.split(",")
87
+ if args.spatial_column is not None:
88
+ spatial_columns= args.spatial_column.split(",")
89
+
90
+ size = args.cond_size
91
+ # by default the noise size would be randomly sampled from (512, 1024)
92
+ # noise_size = get_random_resolution(max_size=args.noise_size) # maybe 768 or higher
93
+ noise_size = get_random_resolution(min_size=512, max_size=512) # maybe 768 or higher
94
+ # subject_cond_train_transforms = transforms.Compose(
95
+ # [
96
+ # transforms.Lambda(lambda img: img.resize((
97
+ # multiple_16(size * img.size[0] / max(img.size)),
98
+ # multiple_16(size * img.size[1] / max(img.size))
99
+ # ), resample=Image.BILINEAR)),
100
+ # transforms.RandomHorizontalFlip(p=0.7),
101
+ # transforms.RandomRotation(degrees=20),
102
+ # transforms.Lambda(lambda img: transforms.Pad(
103
+ # padding=(
104
+ # int((size - img.size[0]) / 2),
105
+ # int((size - img.size[1]) / 2),
106
+ # int((size - img.size[0]) / 2),
107
+ # int((size - img.size[1]) / 2)
108
+ # ),
109
+ # fill=0
110
+ # )(img)),
111
+ # transforms.ToTensor(),
112
+ # transforms.Normalize([0.5], [0.5]),
113
+ # ]
114
+ # )
115
+ cond_train_transforms = transforms.Compose(
116
+ [
117
+ transforms.Resize((size, size), interpolation=transforms.InterpolationMode.BILINEAR),
118
+ transforms.CenterCrop((size, size)),
119
+ transforms.ToTensor(),
120
+ transforms.Normalize([0.5], [0.5]),
121
+ ]
122
+ )
123
+ subject_cond_train_transforms = cond_train_transforms
124
+
125
+ def train_transforms(image, noise_size):
126
+ train_transforms_ = transforms.Compose(
127
+ [
128
+ transforms.Lambda(lambda img: img.resize((
129
+ multiple_16(noise_size * img.size[0] / max(img.size)),
130
+ multiple_16(noise_size * img.size[1] / max(img.size))
131
+ ), resample=Image.BILINEAR)),
132
+ transforms.ToTensor(),
133
+ transforms.Normalize([0.5], [0.5]),
134
+ ]
135
+ )
136
+ transformed_image = train_transforms_(image)
137
+ return transformed_image
138
+
139
+ def load_and_transform_cond_images(images):
140
+ transformed_images = [cond_train_transforms(image) for image in images]
141
+ concatenated_image = torch.cat(transformed_images, dim=1)
142
+ return concatenated_image
143
+
144
+ def load_and_transform_subject_images(images):
145
+ transformed_images = [subject_cond_train_transforms(image) for image in images]
146
+ concatenated_image = torch.cat(transformed_images, dim=1)
147
+ return concatenated_image
148
+
149
+ tokenizer_clip = tokenizer[0]
150
+ tokenizer_t5 = tokenizer[1]
151
+
152
+ def retrieve_prompt_embeds_from_disk(args, examples):
153
+ captions = []
154
+ for caption in examples["prompts"]:
155
+ if isinstance(caption, str):
156
+ if random.random() < 0.1:
157
+ captions.append(" ") # set the caption to empty text
158
+ else:
159
+ captions.append(caption)
160
+ elif isinstance(caption, list):
161
+ raise NotImplementedError("list of captions not supported yet")
162
+ # take a random caption if there are multiple
163
+ if random.random() < 0.1:
164
+ captions.append(" ")
165
+ else:
166
+ captions.append(random.choice(caption))
167
+ else:
168
+ raise ValueError(
169
+ f"Caption column should contain either strings or lists of strings."
170
+ )
171
+
172
+ all_prompt_embeds = []
173
+ all_pooled_prompt_embeds = []
174
+ for caption in captions:
175
+ if caption == " ":
176
+ prompt_file_name = "space_prompt.pth"
177
+ else:
178
+ prompt_file_name = "_".join(caption.split(" ")) + ".pth"
179
+ if osp.exists(osp.join(args.inference_embeds_dir, prompt_file_name)):
180
+ prompt_embeds = torch.load(osp.join(args.inference_embeds_dir, prompt_file_name), map_location="cpu")
181
+ pooled_prompt_embeds = prompt_embeds["pooled_prompt_embeds"]
182
+ prompt_embeds = prompt_embeds["prompt_embeds"]
183
+ else:
184
+ # raise FileNotFoundError(f"Prompt embeddings for '{caption}' not found in {args.inference_embeds_dir}. Please precompute and save them.")
185
+ prompt_embeds = torch.zeros((1, 77, 768)) # Placeholder tensor
186
+ pooled_prompt_embeds = torch.zeros((1, 768)) # Placeholder tensor
187
+ all_prompt_embeds.append(prompt_embeds.squeeze(0))
188
+ all_pooled_prompt_embeds.append(pooled_prompt_embeds.squeeze(0))
189
+ return all_prompt_embeds, all_pooled_prompt_embeds
190
+
191
+
192
+ def tokenize_prompt_clip_t5(examples):
193
+ captions = []
194
+ for caption in examples["prompts"]:
195
+ if isinstance(caption, str):
196
+ if random.random() < 0.1:
197
+ captions.append(" ") # set the caption to empty text
198
+ else:
199
+ captions.append(caption)
200
+ elif isinstance(caption, list):
201
+ # take a random caption if there are multiple
202
+ if random.random() < 0.1:
203
+ captions.append(" ")
204
+ else:
205
+ captions.append(random.choice(caption))
206
+ else:
207
+ raise ValueError(
208
+ f"Caption column should contain either strings or lists of strings."
209
+ )
210
+ text_inputs = tokenizer_clip(
211
+ captions,
212
+ padding="max_length",
213
+ max_length=77,
214
+ truncation=True,
215
+ return_length=False,
216
+ return_overflowing_tokens=False,
217
+ return_tensors="pt",
218
+ )
219
+ text_input_ids_1 = text_inputs.input_ids
220
+
221
+ text_inputs = tokenizer_t5(
222
+ captions,
223
+ padding="max_length",
224
+ max_length=512,
225
+ truncation=True,
226
+ return_length=False,
227
+ return_overflowing_tokens=False,
228
+ return_tensors="pt",
229
+ )
230
+ text_input_ids_2 = text_inputs.input_ids
231
+ return text_input_ids_1, text_input_ids_2
232
+
233
+ def preprocess_train(examples):
234
+ _examples = {}
235
+ train_data_dir = osp.dirname(args.current_train_data_dir)
236
+ if args.subject_column is not None:
237
+ subject_images = [[load_image_safely(osp.join(train_data_dir, examples[column][i]), args.cond_size) for column in subject_columns] for i in range(len(examples[target_column]))]
238
+ _examples["subject_pixel_values"] = [load_and_transform_subject_images(subject) for subject in subject_images]
239
+ if args.spatial_column is not None:
240
+ # this now has two conditions
241
+ spatial_images = [[load_image_safely(osp.join(train_data_dir, examples[column][i]), args.cond_size) for column in spatial_columns] for i in range(len(examples[target_column]))]
242
+ _examples["cond_pixel_values"] = [load_and_transform_cond_images(spatial) for spatial in spatial_images]
243
+ target_images = [load_image_safely(osp.join(train_data_dir, image_path), args.cond_size) for image_path in examples[target_column]]
244
+ _examples["pixel_values"] = [train_transforms(image, noise_size) for image in target_images]
245
+ _examples["PLACEHOLDER_prompts"] = examples["PLACEHOLDER_prompts"]
246
+ subjects = examples["subjects"]
247
+ _examples["subjects"] = subjects
248
+ subjects_ = ["_".join(subject) for subject in subjects] # get the subject names with "_" instead of space
249
+ _examples["prompts"] = []
250
+ # getting the prompts by replacing the PLACEHOLDER in the prompt with the actual subject names
251
+ for i in range(len(examples["subjects"])):
252
+ # replace the subjects string in the PLACEHOLDER
253
+ prompt = examples["PLACEHOLDER_prompts"][i]
254
+ placeholder_string = " and ".join(subjects[i])
255
+ prompt = prompt.replace("PLACEHOLDER", placeholder_string)
256
+ _examples["prompts"].append(prompt)
257
+ _examples["prompt_embeds"], _examples["pooled_prompt_embeds"] = retrieve_prompt_embeds_from_disk(args, _examples)
258
+ # getting the z-passed cuboids segmentation masks
259
+ _examples["cuboids_segmasks"] = []
260
+
261
+ def generous_resize_batch(masks, new_h, new_w):
262
+ """
263
+ masks: torch.Tensor of shape (B, H, W), values in {0,1}
264
+ new_h, new_w: desired output size
265
+ """
266
+ B, H, W = masks.shape
267
+ masks = masks.unsqueeze(1).float() # -> (B,1,H,W)
268
+
269
+ # Compute pooling kernel/stride
270
+ kh = H // new_h
271
+ kw = W // new_w
272
+ assert H % new_h == 0 and W % new_w == 0, \
273
+ "H and W must be divisible by new_h and new_w for exact block pooling"
274
+
275
+ out = F.max_pool2d(masks, kernel_size=(kh, kw), stride=(kh, kw))
276
+ return out.squeeze(1).byte() # -> (B,new_h,new_w)
277
+
278
+ for i in range(len(_examples["subjects"])):
279
+ segmasks_this_example = examples["cuboids_segmasks"][i]
280
+ # the name of the segmask is of the format "segmask_00<subject_idx>__<depth_value>.png"
281
+ depth_values_this_example = [osp.basename(segmasks_this_example[j]).split("__")[-1].split(".png")[0] for j in range(len(subjects[i]))]
282
+ depth_values_this_example = torch.as_tensor([float(depth) for depth in depth_values_this_example])
283
+ assert len(segmasks_this_example) == len(subjects[i]), f"Number of segmentation masks {len(segmasks_this_example)} does not match number of subjects {len(subjects[i])} for example {i}"
284
+ segmasks_this_example = [cv2.imread(osp.join(train_data_dir, segmasks_this_example[j]), cv2.IMREAD_UNCHANGED) for j in range(len(subjects[i]))]
285
+ # segmasks_this_example = [cv2.resize(segmask, (32, 32), interpolation=cv2.INTER_NEAREST) for segmask in segmasks_this_example]
286
+ segmasks_this_example = [torch.as_tensor(segmask, dtype=torch.uint8) for segmask in segmasks_this_example]
287
+ segmasks_this_example = torch.stack(segmasks_this_example, dim=0) # (n_subjects, h, w)
288
+ mask = segmasks_this_example > 128
289
+ segmasks_this_example[mask] = 1
290
+ segmasks_this_example[~mask] = 0
291
+ segmasks_this_example = generous_resize_batch(segmasks_this_example, 32, 32)
292
+ assert segmasks_this_example.shape == (len(subjects[i]), 32, 32), f"Segmentation masks shape {segmasks_this_example.shape} does not match expected shape {(len(subjects[i]), 32, 32)} for example {i}"
293
+ # z_passed_segmask = do_z_pass(segmasks_this_example, depth_values_this_example)
294
+ # print(f"{z_passed_segmask.shape = }, {segmasks_this_example.shape = }")
295
+ # _examples["cuboids_segmasks"].append(z_passed_segmask)
296
+ _examples["cuboids_segmasks"].append(segmasks_this_example)
297
+
298
+ _examples["token_ids_clip"], _examples["token_ids_t5"] = tokenize_prompt_clip_t5(_examples)
299
+ _examples["call_ids"] = examples["call_ids"]
300
+ _examples["index"] = examples["index"]
301
+
302
+ return _examples
303
+
304
+ if accelerator is not None:
305
+ with accelerator.main_process_first():
306
+ train_dataset = dataset["train"].with_transform(preprocess_train)
307
+ else:
308
+ train_dataset = dataset["train"].with_transform(preprocess_train)
309
+
310
+ return train_dataset
311
+
312
+
313
+ def collate_fn(examples):
314
+ if examples[0].get("cond_pixel_values") is not None:
315
+ cond_pixel_values = torch.stack([example["cond_pixel_values"] for example in examples])
316
+ cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
317
+ else:
318
+ cond_pixel_values = None
319
+ if examples[0].get("subject_pixel_values") is not None:
320
+ subject_pixel_values = torch.stack([example["subject_pixel_values"] for example in examples])
321
+ subject_pixel_values = subject_pixel_values.to(memory_format=torch.contiguous_format).float()
322
+ else:
323
+ subject_pixel_values = None
324
+
325
+ target_pixel_values = torch.stack([example["pixel_values"] for example in examples])
326
+ target_pixel_values = target_pixel_values.to(memory_format=torch.contiguous_format).float()
327
+ token_ids_clip = torch.stack([torch.tensor(example["token_ids_clip"]) for example in examples])
328
+ token_ids_t5 = torch.stack([torch.tensor(example["token_ids_t5"]) for example in examples])
329
+ prompt_embeds = torch.stack([example["prompt_embeds"] for example in examples], dim=0)
330
+ pooled_prompt_embeds = torch.stack([example["pooled_prompt_embeds"] for example in examples], dim=0)
331
+ prompts = [example["prompts"] for example in examples]
332
+ call_ids = [example["call_ids"] for example in examples]
333
+ cuboids_segmasks = [example["cuboids_segmasks"] for example in examples] if examples[0].get("cuboids_segmasks") is not None else None
334
+ indices = [example["index"] for example in examples] # Add this line
335
+
336
+ return {
337
+ "cond_pixel_values": cond_pixel_values,
338
+ "subject_pixel_values": subject_pixel_values,
339
+ "pixel_values": target_pixel_values,
340
+ "text_ids_1": token_ids_clip,
341
+ "text_ids_2": token_ids_t5,
342
+ "prompt_embeds": prompt_embeds,
343
+ "pooled_prompt_embeds": pooled_prompt_embeds,
344
+ "prompts": prompts,
345
+ "call_ids": call_ids,
346
+ "cuboids_segmasks": cuboids_segmasks,
347
+ "index": indices,
348
+ }
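A toy check of do_z_pass, assuming train/src is on the import path and the module's dependencies are installed: where two subjects' masks overlap, only the subject with the smaller camera distance keeps the pixel, and non-overlapping pixels are left untouched.

    import torch
    from jsonl_datasets import do_z_pass

    masks = torch.tensor([[[1, 1],
                           [0, 0]],
                          [[1, 0],
                           [1, 0]]], dtype=torch.uint8)  # (n_subjects=2, h=2, w=2)
    dists = torch.tensor([1.0, 2.0])                     # subject 0 is closer to the camera

    out = do_z_pass(masks, dists)
    assert out[0].tolist() == [[1, 1], [0, 0]]   # the overlap at (0, 0) stays with subject 0
    assert out[1].tolist() == [[0, 0], [1, 0]]   # subject 1 loses only the overlapping pixel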
train/src/layers.py ADDED
@@ -0,0 +1,360 @@
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+ from einops import rearrange
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import nn
8
+ from torch import Tensor
9
+ from diffusers.models.attention_processor import Attention
10
+ import os
11
+ import os.path as osp
12
+ import numpy as np
13
+
14
+ class LoRALinearLayer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ in_features: int,
18
+ out_features: int,
19
+ rank: int = 4,
20
+ network_alpha: Optional[float] = None,
21
+ device: Optional[Union[torch.device, str]] = None,
22
+ dtype: Optional[torch.dtype] = None,
23
+ cond_width=512,
24
+ cond_height=512,
25
+ number=0,
26
+ n_loras=1
27
+ ):
28
+ super().__init__()
29
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
30
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
31
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
32
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
33
+ self.network_alpha = network_alpha
34
+ self.rank = rank
35
+ self.out_features = out_features
36
+ self.in_features = in_features
37
+
38
+ nn.init.normal_(self.down.weight, std=1 / rank)
39
+ nn.init.zeros_(self.up.weight)
40
+
41
+ self.cond_height = cond_height
42
+ self.cond_width = cond_width
43
+ self.number = number
44
+ self.n_loras = n_loras
45
+
46
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
47
+ orig_dtype = hidden_states.dtype
48
+ dtype = self.down.weight.dtype
49
+
50
+ #### img condition
51
+ batch_size = hidden_states.shape[0]
52
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
53
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
54
+ shape = (batch_size, hidden_states.shape[1], 3072)
55
+ mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
56
+ mask[:, :block_size+self.number*cond_size, :] = 0
57
+ mask[:, block_size+(self.number+1)*cond_size:, :] = 0
58
+ hidden_states = mask * hidden_states
59
+ ####
60
+
61
+ down_hidden_states = self.down(hidden_states.to(dtype))
62
+ up_hidden_states = self.up(down_hidden_states)
63
+
64
+ if self.network_alpha is not None:
65
+ up_hidden_states *= self.network_alpha / self.rank
66
+
67
+ return up_hidden_states.to(orig_dtype)
68
+
69
+
70
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
71
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
72
+ super().__init__()
73
+ # Initialize a list to store the LoRA layers
74
+ self.n_loras = n_loras
75
+ self.cond_width = cond_width
76
+ self.cond_height = cond_height
77
+
78
+ self.q_loras = nn.ModuleList([
79
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
80
+ for i in range(n_loras)
81
+ ])
82
+ self.k_loras = nn.ModuleList([
83
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
84
+ for i in range(n_loras)
85
+ ])
86
+ self.v_loras = nn.ModuleList([
87
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
88
+ for i in range(n_loras)
89
+ ])
90
+ self.lora_weights = lora_weights
91
+
92
+
93
+ def __call__(self,
94
+ attn: Attention,
95
+ hidden_states: torch.FloatTensor,
96
+ encoder_hidden_states: torch.FloatTensor = None,
97
+ attention_mask: Optional[torch.FloatTensor] = None,
98
+ image_rotary_emb: Optional[torch.Tensor] = None,
99
+ use_cond = False,
100
+ call_ids = None,
101
+ cuboids_segmasks: torch.Tensor = None,
102
+ store_qk: Optional[str] = None,
103
+ ) -> torch.FloatTensor:
104
+
105
+ batch_size, seq_len, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
106
+ query = attn.to_q(hidden_states)
107
+ key = attn.to_k(hidden_states)
108
+ value = attn.to_v(hidden_states)
109
+
110
+ for i in range(self.n_loras):
111
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
112
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
113
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
114
+
115
+ inner_dim = key.shape[-1]
116
+ head_dim = inner_dim // attn.heads
117
+
118
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
119
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
120
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
121
+
122
+ if attn.norm_q is not None:
123
+ query = attn.norm_q(query)
124
+ if attn.norm_k is not None:
125
+ key = attn.norm_k(key)
126
+
127
+ if image_rotary_emb is not None:
128
+ from diffusers.models.embeddings import apply_rotary_emb
129
+ query = apply_rotary_emb(query, image_rotary_emb)
130
+ key = apply_rotary_emb(key, image_rotary_emb)
131
+
132
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
133
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
134
+ scaled_cond_size = cond_size
135
+ scaled_block_size = block_size
136
+ scaled_seq_len = query.shape[2]
137
+
138
+ num_cond_blocks = self.n_loras
139
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
140
+ # zero for all the 'allowed' connections
141
+ mask[:scaled_block_size, :] = 0 # first block_size rows (base text/image queries) attend to all keys
142
+ for i in range(num_cond_blocks):
143
+ start = i * scaled_cond_size + scaled_block_size
144
+ end = (i + 1) * scaled_cond_size + scaled_block_size
145
+ mask[start:end, start:end] = 0 # Diagonal blocks
146
+
147
+ assert mask.shape[0] == scaled_block_size + num_cond_blocks*scaled_cond_size, f"{mask.shape = }, {scaled_block_size=}, {num_cond_blocks=}, {scaled_cond_size=}"
148
+
149
+ if call_ids is not None:
150
+ # repeat across batch size and heads
151
+ mask = mask.unsqueeze(0).unsqueeze(0).repeat(len(call_ids), 1, 1, 1) # (batch_size, num_heads, seq_len, seq_len)
152
+ num_img_tokens = scaled_block_size - 512
153
+ for batch_idx in range(len(call_ids)):
154
+ call_ids_this_example = call_ids[batch_idx]
155
+ for subject_idx, call_ids_this_subject in enumerate(call_ids_this_example):
156
+ # preparing the cuboid mask
157
+ cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
158
+ # assert cuboid_mask.shape == (int(math.sqrt(num_img_tokens)), int(math.sqrt(num_img_tokens))), f"{cuboid_mask.shape=}, {num_img_tokens=}"
159
+ cuboid_mask = cuboid_mask.to(torch.bool)
160
+
161
+ # assert scaled_block_size == scaled_cond_size + 512, f"{scaled_cond_size=}, {scaled_block_size=}"
162
+ for i in range(num_cond_blocks):
163
+ cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
164
+ cuboid_mask = cuboid_mask.to(torch.bool)
165
+ # masking out the condition tokens -> text token attention map
166
+ mask_subset = mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject]
167
+ # assert mask_subset.shape == (1, num_img_tokens, len(call_ids_this_subject)), f"{mask_subset.shape=}, {attn.heads=}, {num_img_tokens=}, {len(call_ids_this_subject)=}"
168
+ mask_subset[:, cuboid_mask.flatten()] = 0 # enable attention to cuboid regions
169
+
170
+ mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject] = mask_subset
171
+
172
+
173
+ mask = mask * -1e20
174
+ mask = mask.to(query.dtype)
175
+
176
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
177
+
178
+ if store_qk:
179
+ attn_weights = query.detach().to(torch.float16) @ key.detach().to(torch.float16).transpose(-1, -2) # (batch_size, num_heads, query_len, key_len)
180
+ attn_weights = attn_weights + mask
181
+ attn_weights = torch.mean(torch.softmax(attn_weights, dim=-1), dim=1)
182
+ attn_weights = attn_weights.cpu()
183
+ os.makedirs(osp.dirname(store_qk), exist_ok=True)
184
+ torch.save(attn_weights, store_qk + ".pth")
185
+
186
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
187
+ hidden_states = hidden_states.to(query.dtype)
188
+
189
+ cond_hidden_states = hidden_states[:, block_size:,:]
190
+ hidden_states = hidden_states[:, : block_size,:]
191
+
192
+ return hidden_states if not use_cond else (hidden_states, cond_hidden_states)
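For the default cond_width = cond_height = 512 used by these processors, the condition-token arithmetic works out as in the sketch below (the numbers follow directly from the expressions above; the layout comment summarises the mask construction shared by both classes).

    cond_width = cond_height = 512
    cond_size = cond_width // 8 * cond_height // 8 * 16 // 64   # = 1024 latent tokens per condition image
    text_tokens = 512                                            # T5 max sequence length used in training
    # The attended sequence is [text tokens + image tokens] (block_size positions),
    # followed by n_loras condition blocks of cond_size tokens each. Base queries may
    # attend everywhere, each condition block attends only to itself, and the cuboid
    # segmentation masks additionally let condition tokens inside a subject's cuboid
    # attend to that subject's text tokens (the call_ids columns).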
193
+
194
+
195
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
196
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
197
+ super().__init__()
198
+
199
+ # Initialize a list to store the LoRA layers
200
+ self.n_loras = n_loras
201
+ self.cond_width = cond_width
202
+ self.cond_height = cond_height
203
+ self.q_loras = nn.ModuleList([
204
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
205
+ for i in range(n_loras)
206
+ ])
207
+ self.k_loras = nn.ModuleList([
208
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
209
+ for i in range(n_loras)
210
+ ])
211
+ self.v_loras = nn.ModuleList([
212
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
213
+ for i in range(n_loras)
214
+ ])
215
+ self.proj_loras = nn.ModuleList([
216
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
217
+ for i in range(n_loras)
218
+ ])
219
+ self.lora_weights = lora_weights
220
+
221
+
222
+ def __call__(self,
223
+ attn: Attention,
224
+ hidden_states: torch.FloatTensor,
225
+ encoder_hidden_states: torch.FloatTensor = None,
226
+ attention_mask: Optional[torch.FloatTensor] = None,
227
+ image_rotary_emb: Optional[torch.Tensor] = None,
228
+ use_cond=False,
229
+ call_ids = None,
230
+ cuboids_segmasks: torch.Tensor = None,
231
+ store_qk: Optional[str] = None,
232
+ ) -> torch.FloatTensor:
233
+
234
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
235
+
236
+ # `context` projections.
237
+ inner_dim = 3072
238
+ head_dim = inner_dim // attn.heads
239
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
240
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
241
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
242
+
243
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
244
+ batch_size, -1, attn.heads, head_dim
245
+ ).transpose(1, 2)
246
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
247
+ batch_size, -1, attn.heads, head_dim
248
+ ).transpose(1, 2)
249
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
250
+ batch_size, -1, attn.heads, head_dim
251
+ ).transpose(1, 2)
252
+
253
+ if attn.norm_added_q is not None:
254
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
255
+ if attn.norm_added_k is not None:
256
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
257
+
258
+ query = attn.to_q(hidden_states)
259
+ key = attn.to_k(hidden_states)
260
+ value = attn.to_v(hidden_states)
261
+ for i in range(self.n_loras):
262
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
263
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
264
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
265
+
266
+ inner_dim = key.shape[-1]
267
+ head_dim = inner_dim // attn.heads
268
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
269
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
270
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
271
+
272
+ if attn.norm_q is not None:
273
+ query = attn.norm_q(query)
274
+ if attn.norm_k is not None:
275
+ key = attn.norm_k(key)
276
+
277
+ # attention
278
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
279
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
280
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
281
+
282
+ if image_rotary_emb is not None:
283
+ from diffusers.models.embeddings import apply_rotary_emb
284
+ query = apply_rotary_emb(query, image_rotary_emb)
285
+ key = apply_rotary_emb(key, image_rotary_emb)
286
+
287
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
288
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
289
+ scaled_cond_size = cond_size
290
+ scaled_seq_len = query.shape[2]
291
+ scaled_block_size = scaled_seq_len - cond_size * self.n_loras
292
+
293
+ num_cond_blocks = self.n_loras
294
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
295
+ mask[:scaled_block_size, :] = 0 # first scaled_block_size rows (text + image tokens) attend to everything
296
+ for i in range(num_cond_blocks):
297
+ start = i * scaled_cond_size + scaled_block_size
298
+ end = (i + 1) * scaled_cond_size + scaled_block_size
299
+ mask[start:end, start:end] = 0 # Diagonal blocks
300
+
301
+ assert mask.shape[0] == scaled_block_size + num_cond_blocks*scaled_cond_size, f"{mask.shape = }, {scaled_block_size=}, {num_cond_blocks=}, {scaled_cond_size=}"
302
+
303
+ if call_ids is not None:
304
+ # add batch and head dimensions (the head dimension is broadcast), then repeat across the batch
305
+ mask = mask.unsqueeze(0).unsqueeze(0).repeat(len(call_ids), 1, 1, 1) # (batch_size, num_heads, seq_len, seq_len)
306
+ num_img_tokens = scaled_block_size - 512
307
+ for batch_idx in range(len(call_ids)):
308
+ call_ids_this_example = call_ids[batch_idx]
309
+ for subject_idx, call_ids_this_subject in enumerate(call_ids_this_example):
310
+ # preparing the cuboid mask
311
+ cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
312
+ # assert cuboid_mask.shape == (int(math.sqrt(num_img_tokens)), int(math.sqrt(num_img_tokens))), f"{cuboid_mask.shape=}, {num_img_tokens=}, {scaled_block_size=}"
313
+ cuboid_mask = cuboid_mask.to(torch.bool)
314
+
315
+ # assert scaled_block_size == scaled_cond_size + 512, f"{scaled_cond_size=}, {scaled_block_size=}"
316
+ for i in range(num_cond_blocks):
317
+ cuboid_mask = cuboids_segmasks[batch_idx][subject_idx] # (h, w)
318
+ cuboid_mask = cuboid_mask.to(torch.bool)
319
+ # masking out the condition tokens -> text token attention map
320
+ mask_subset = mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject]
321
+ # assert mask_subset.shape == (1, num_img_tokens, len(call_ids_this_subject)), f"{mask_subset.shape=}, {attn.heads=}, {num_img_tokens=}, {len(call_ids_this_subject)=}"
322
+ mask_subset[:, cuboid_mask.flatten()] = 0 # enable attention to cuboid regions
323
+
324
+ mask[batch_idx, :, scaled_block_size + i*scaled_cond_size : scaled_block_size + (i+1)*scaled_cond_size, call_ids_this_subject] = mask_subset
325
+
326
+
327
+ mask = mask * -1e20
328
+ mask = mask.to(query.dtype)
329
+
330
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
331
+
332
+ if store_qk:
333
+ attn_weights = query.detach().to(torch.float16) @ key.detach().to(torch.float16).transpose(-1, -2) / (head_dim ** 0.5) # (batch_size, num_heads, query_len, key_len), scaled as in scaled_dot_product_attention
334
+ attn_weights = attn_weights + mask
335
+ attn_weights = torch.mean(torch.softmax(attn_weights, dim=-1), dim=1)
336
+ attn_weights = attn_weights.cpu()
337
+ os.makedirs(osp.dirname(store_qk), exist_ok=True)
338
+ torch.save(attn_weights, store_qk + ".pth")
339
+
340
+
341
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
342
+ hidden_states = hidden_states.to(query.dtype)
343
+
344
+ encoder_hidden_states, hidden_states = (
345
+ hidden_states[:, : encoder_hidden_states.shape[1]],
346
+ hidden_states[:, encoder_hidden_states.shape[1] :],
347
+ )
348
+
349
+ # Linear projection (with LoRA weight applied to each proj layer)
350
+ hidden_states = attn.to_out[0](hidden_states)
351
+ for i in range(self.n_loras):
352
+ hidden_states = hidden_states + self.lora_weights[i] * self.proj_loras[i](hidden_states)
353
+ # dropout
354
+ hidden_states = attn.to_out[1](hidden_states)
355
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
356
+
357
+ cond_hidden_states = hidden_states[:, block_size:,:]
358
+ hidden_states = hidden_states[:, :block_size,:]
359
+
360
+ return (hidden_states, encoder_hidden_states, cond_hidden_states) if use_cond else (encoder_hidden_states, hidden_states)
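The processors above share the same base attention pattern; below is a minimal sketch of that additive mask, with toy sizes standing in for `scaled_block_size`, `scaled_cond_size`, and `n_loras` (the `call_ids`/`cuboids_segmasks` path then selectively re-enables condition-to-image entries):

    import torch

    block, cond, n_cond = 6, 2, 2              # toy stand-ins for the real sizes
    seq = block + n_cond * cond

    mask = torch.ones(seq, seq)
    mask[:block, :] = 0                        # text + image rows attend everywhere
    for i in range(n_cond):
        s = block + i * cond
        mask[s:s + cond, s:s + cond] = 0       # each condition block attends only to itself
    additive = mask * -1e20                    # 0 = allowed, -1e20 = blocked
    print(additive[block:, :block].unique())   # condition -> text/image is blocked by default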
train/src/lora_helper.py ADDED
@@ -0,0 +1,196 @@
1
+ from diffusers.models.attention_processor import FluxAttnProcessor2_0
2
+ from safetensors import safe_open
3
+ import re
4
+ import torch
5
+ from .layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
6
+
7
+ device = "cuda"
8
+
9
+ def load_safetensors(path):
10
+ tensors = {}
11
+ with safe_open(path, framework="pt", device="cpu") as f:
12
+ for key in f.keys():
13
+ tensors[key] = f.get_tensor(key)
14
+ return tensors
15
+
16
+ def get_lora_rank(checkpoint):
17
+ for k in checkpoint.keys():
18
+ if k.endswith(".down.weight"):
19
+ return checkpoint[k].shape[0]
20
+
21
+ def load_checkpoint(local_path):
22
+ if local_path is not None:
23
+ if '.safetensors' in local_path:
24
+ print(f"Loading .safetensors checkpoint from {local_path}")
25
+ checkpoint = load_safetensors(local_path)
26
+ else:
27
+ print(f"Loading checkpoint from {local_path}")
28
+ checkpoint = torch.load(local_path, map_location='cpu')
29
+ return checkpoint
30
+
31
+ def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
32
+ number = len(lora_weights)
33
+ ranks = [get_lora_rank(checkpoint) for _ in range(number)]
34
+ lora_attn_procs = {}
35
+ double_blocks_idx = list(range(19))
36
+ single_blocks_idx = list(range(38))
37
+ for name, attn_processor in transformer.attn_processors.items():
38
+ match = re.search(r'\.(\d+)\.', name)
39
+ if match:
40
+ layer_index = int(match.group(1))
41
+
42
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
43
+
44
+ lora_state_dicts = {}
45
+ for key, value in checkpoint.items():
46
+ # Match based on the layer index in the key (assuming the key contains layer index)
47
+ if re.search(r'\.(\d+)\.', key):
48
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
49
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
50
+ lora_state_dicts[key] = value
51
+
52
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
53
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
54
+ )
55
+
56
+ # Load the weights from the checkpoint dictionary into the corresponding layers
57
+ for n in range(number):
58
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
59
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
60
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
61
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
62
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
63
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
64
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
65
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
66
+ lora_attn_procs[name].to(device)
67
+
68
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
69
+
70
+ lora_state_dicts = {}
71
+ for key, value in checkpoint.items():
72
+ # Match based on the layer index in the key (assuming the key contains layer index)
73
+ if re.search(r'\.(\d+)\.', key):
74
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
75
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
76
+ lora_state_dicts[key] = value
77
+
78
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
79
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
80
+ )
81
+ # Load the weights from the checkpoint dictionary into the corresponding layers
82
+ for n in range(number):
83
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
84
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
85
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
86
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
87
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
88
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
89
+ lora_attn_procs[name].to(device)
90
+ else:
91
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
92
+
93
+ transformer.set_attn_processor(lora_attn_procs)
94
+
95
+
96
+ def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
97
+ ck_number = len(checkpoints)
98
+ cond_lora_number = [len(ls) for ls in lora_weights]
99
+ cond_number = sum(cond_lora_number)
100
+ ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
101
+ multi_lora_weight = []
102
+ for ls in lora_weights:
103
+ for n in ls:
104
+ multi_lora_weight.append(n)
105
+
106
+ lora_attn_procs = {}
107
+ double_blocks_idx = list(range(19))
108
+ single_blocks_idx = list(range(38))
109
+ for name, attn_processor in transformer.attn_processors.items():
110
+ match = re.search(r'\.(\d+)\.', name)
111
+ if match:
112
+ layer_index = int(match.group(1))
113
+
114
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
115
+ lora_state_dicts = [{} for _ in range(ck_number)]
116
+ for idx, checkpoint in enumerate(checkpoints):
117
+ for key, value in checkpoint.items():
118
+ # Match based on the layer index in the key (assuming the key contains layer index)
119
+ if re.search(r'\.(\d+)\.', key):
120
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
121
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
122
+ lora_state_dicts[idx][key] = value
123
+
124
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
125
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
126
+ )
127
+
128
+ # Load the weights from the checkpoint dictionary into the corresponding layers
129
+ num = 0
130
+ for idx in range(ck_number):
131
+ for n in range(cond_lora_number[idx]):
132
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
133
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
134
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
135
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
136
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
137
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
138
+ lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
139
+ lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
140
+ lora_attn_procs[name].to(device)
141
+ num += 1
142
+
143
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
144
+
145
+ lora_state_dicts = [{} for _ in range(ck_number)]
146
+ for idx, checkpoint in enumerate(checkpoints):
147
+ for key, value in checkpoint.items():
148
+ # Match based on the layer index in the key (assuming the key contains layer index)
149
+ if re.search(r'\.(\d+)\.', key):
150
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
151
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
152
+ lora_state_dicts[idx][key] = value
153
+
154
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
155
+ dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
156
+ )
157
+ # Load the weights from the checkpoint dictionary into the corresponding layers
158
+ num = 0
159
+ for idx in range(ck_number):
160
+ for n in range(cond_lora_number[idx]):
161
+ lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
162
+ lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
163
+ lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
164
+ lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
165
+ lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
166
+ lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
167
+ lora_attn_procs[name].to(device)
168
+ num += 1
169
+
170
+ else:
171
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
172
+
173
+ transformer.set_attn_processor(lora_attn_procs)
174
+
175
+
176
+ def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
177
+ checkpoint = load_checkpoint(local_path)
178
+ update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)
179
+
180
+ def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
181
+ checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
182
+ update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)
183
+
184
+ def unset_lora(transformer):
185
+ lora_attn_procs = {}
186
+ for name, attn_processor in transformer.attn_processors.items():
187
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
188
+ transformer.set_attn_processor(lora_attn_procs)
189
+
190
+
191
+ '''
192
+ unset_lora(pipe.transformer)
193
+ lora_path = "./lora.safetensors"
194
+ lora_weights = [1, 1]
195
+ set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
196
+ '''
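For several checkpoints at once, the analogous entry point is `set_multi_lora`; a sketch with placeholder paths (the two filenames below are hypothetical):

    unset_lora(pipe.transformer)
    lora_paths = ["./subject_lora.safetensors", "./spatial_lora.safetensors"]
    lora_weights = [[1.0], [1.0]]              # one weight list per checkpoint
    set_multi_lora(pipe.transformer, local_paths=lora_paths, lora_weights=lora_weights, cond_size=512)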
train/src/pipeline.py ADDED
@@ -0,0 +1,824 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Union
3
+
4
+ import numpy as np
5
+ import torch
6
+ from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
7
+ import copy
8
+ import os
9
+ import os.path as osp
10
+
11
+ from diffusers.image_processor import (VaeImageProcessor)
12
+ from diffusers.loaders import FluxLoraLoaderMixin, FromSingleFileMixin
13
+ from diffusers.models.autoencoders import AutoencoderKL
14
+ from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
15
+ from diffusers.utils import (
16
+ USE_PEFT_BACKEND,
17
+ is_torch_xla_available,
18
+ logging,
19
+ scale_lora_layers,
20
+ unscale_lora_layers,
21
+ )
22
+ from diffusers.utils.torch_utils import randn_tensor
23
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
24
+ from diffusers.pipelines.flux.pipeline_output import FluxPipelineOutput
25
+ from torchvision.transforms.functional import pad
26
+ from .transformer_flux import FluxTransformer2DModel
27
+
28
+ if is_torch_xla_available():
29
+ import torch_xla.core.xla_model as xm
30
+
31
+ XLA_AVAILABLE = True
32
+ else:
33
+ XLA_AVAILABLE = False
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+ def calculate_shift(
38
+ image_seq_len,
39
+ base_seq_len: int = 256,
40
+ max_seq_len: int = 4096,
41
+ base_shift: float = 0.5,
42
+ max_shift: float = 1.16,
43
+ ):
44
+ m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
45
+ b = base_shift - m * base_seq_len
46
+ mu = image_seq_len * m + b
47
+ return mu
48
+
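`mu` is simply the line through (`base_seq_len`, `base_shift`) and (`max_seq_len`, `max_shift`), so with the defaults:

    assert abs(calculate_shift(256) - 0.5) < 1e-6    # base sequence length -> base_shift
    assert abs(calculate_shift(4096) - 1.16) < 1e-6  # max sequence length -> max_shift
    print(round(calculate_shift(1024), 3))           # 0.632 for a 1024-token latent sequence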
49
+ def prepare_latent_image_ids_2(height, width, device, dtype):
50
+ latent_image_ids = torch.zeros(height//2, width//2, 3, device=device, dtype=dtype)
51
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height//2, device=device)[:, None] # y coordinate
52
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width//2, device=device)[None, :] # x coordinate
53
+ return latent_image_ids
54
+
55
+ def prepare_latent_subject_ids(height, width, device, dtype):
56
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3, device=device, dtype=dtype)
57
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2, device=device)[:, None]
58
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2, device=device)[None, :]
59
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
60
+ latent_image_ids = latent_image_ids.reshape(
61
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
62
+ )
63
+ return latent_image_ids.to(device=device, dtype=dtype)
64
+
65
+ def resize_position_encoding(batch_size, original_height, original_width, target_height, target_width, device, dtype):
66
+ latent_image_ids = prepare_latent_image_ids_2(original_height, original_width, device, dtype)
67
+ scale_h = original_height / target_height
68
+ scale_w = original_width / target_width
69
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
70
+ latent_image_ids = latent_image_ids.reshape(
71
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
72
+ )
73
+ # interpolate the spatial positional encoding (PE) down to the condition resolution
74
+ latent_image_ids_resized = torch.zeros(target_height//2, target_width//2, 3, device=device, dtype=dtype)
75
+ for i in range(target_height//2):
76
+ for j in range(target_width//2):
77
+ latent_image_ids_resized[i, j, 1] = i*scale_h
78
+ latent_image_ids_resized[i, j, 2] = j*scale_w
79
+ cond_latent_image_id_height, cond_latent_image_id_width, cond_latent_image_id_channels = latent_image_ids_resized.shape
80
+ cond_latent_image_ids = latent_image_ids_resized.reshape(
81
+ cond_latent_image_id_height * cond_latent_image_id_width, cond_latent_image_id_channels
82
+ )
83
+ # latent_image_ids_ = torch.concat([latent_image_ids, cond_latent_image_ids], dim=0)
84
+ return latent_image_ids, cond_latent_image_ids #, latent_image_ids_
85
+
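Illustrative call with toy sizes (an original 128x96 grid conditioned at 64x64): the condition ids are spread by `scale_h`/`scale_w` so both grids cover the same coordinate range.

    import torch

    ids, cond_ids = resize_position_encoding(1, 128, 96, 64, 64, device="cpu", dtype=torch.float32)
    print(ids.shape, cond_ids.shape)                 # torch.Size([3072, 3]) torch.Size([1024, 3])
    print(ids[:, 1].max(), cond_ids[:, 1].max())     # y spans 0..63 vs 0..62 (step of scale_h = 2)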
86
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
87
+ def retrieve_latents(
88
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
89
+ ):
90
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
91
+ return encoder_output.latent_dist.sample(generator)
92
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
93
+ return encoder_output.latent_dist.mode()
94
+ elif hasattr(encoder_output, "latents"):
95
+ return encoder_output.latents
96
+ else:
97
+ raise AttributeError("Could not access latents of provided encoder_output")
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
101
+ def retrieve_timesteps(
102
+ scheduler,
103
+ num_inference_steps: Optional[int] = None,
104
+ device: Optional[Union[str, torch.device]] = None,
105
+ timesteps: Optional[List[int]] = None,
106
+ sigmas: Optional[List[float]] = None,
107
+ **kwargs,
108
+ ):
109
+ """
110
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
111
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
112
+
113
+ Args:
114
+ scheduler (`SchedulerMixin`):
115
+ The scheduler to get timesteps from.
116
+ num_inference_steps (`int`):
117
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
118
+ must be `None`.
119
+ device (`str` or `torch.device`, *optional*):
120
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
121
+ timesteps (`List[int]`, *optional*):
122
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
123
+ `num_inference_steps` and `sigmas` must be `None`.
124
+ sigmas (`List[float]`, *optional*):
125
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
126
+ `num_inference_steps` and `timesteps` must be `None`.
127
+
128
+ Returns:
129
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
130
+ second element is the number of inference steps.
131
+ """
132
+ if timesteps is not None and sigmas is not None:
133
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
134
+ if timesteps is not None:
135
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
136
+ if not accepts_timesteps:
137
+ raise ValueError(
138
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
139
+ f" timestep schedules. Please check whether you are using the correct scheduler."
140
+ )
141
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
142
+ timesteps = scheduler.timesteps
143
+ num_inference_steps = len(timesteps)
144
+ elif sigmas is not None:
145
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
146
+ if not accept_sigmas:
147
+ raise ValueError(
148
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
149
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
150
+ )
151
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
152
+ timesteps = scheduler.timesteps
153
+ num_inference_steps = len(timesteps)
154
+ else:
155
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
156
+ timesteps = scheduler.timesteps
157
+ return timesteps, num_inference_steps
158
+
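A sketch of how the pipeline below uses this helper (assuming `pipe` is an instance of this FluxPipeline; the `mu` value here is illustrative): linear sigmas plus the shift from `calculate_shift`.

    import numpy as np

    num_inference_steps = 28
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        pipe.scheduler, device=pipe._execution_device, sigmas=sigmas, mu=1.0,
    )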
159
+
160
+ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
161
+ r"""
162
+ The Flux pipeline for text-to-image generation.
163
+
164
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
165
+
166
+ Args:
167
+ transformer ([`FluxTransformer2DModel`]):
168
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
169
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
170
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
171
+ vae ([`AutoencoderKL`]):
172
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
173
+ text_encoder ([`CLIPTextModel`]):
174
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
175
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
176
+ text_encoder_2 ([`T5EncoderModel`]):
177
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
178
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
179
+ tokenizer (`CLIPTokenizer`):
180
+ Tokenizer of class
181
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
182
+ tokenizer_2 (`T5TokenizerFast`):
183
+ Second Tokenizer of class
184
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
185
+ """
186
+
187
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
188
+ _optional_components = []
189
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
190
+
191
+ def __init__(
192
+ self,
193
+ scheduler: FlowMatchEulerDiscreteScheduler,
194
+ vae: AutoencoderKL,
195
+ text_encoder: CLIPTextModel,
196
+ tokenizer: CLIPTokenizer,
197
+ text_encoder_2: T5EncoderModel,
198
+ tokenizer_2: T5TokenizerFast,
199
+ transformer: FluxTransformer2DModel,
200
+ ):
201
+ super().__init__()
202
+
203
+ self.register_modules(
204
+ vae=vae,
205
+ text_encoder=text_encoder,
206
+ text_encoder_2=text_encoder_2,
207
+ tokenizer=tokenizer,
208
+ tokenizer_2=tokenizer_2,
209
+ transformer=transformer,
210
+ scheduler=scheduler,
211
+ )
212
+ self.vae_scale_factor = (
213
+ 2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
214
+ )
215
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
216
+ self.tokenizer_max_length = (
217
+ self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
218
+ )
219
+ self.default_sample_size = 64
220
+
221
+ def _get_t5_prompt_embeds(
222
+ self,
223
+ prompt: Union[str, List[str]] = None,
224
+ num_images_per_prompt: int = 1,
225
+ max_sequence_length: int = 512,
226
+ device: Optional[torch.device] = None,
227
+ dtype: Optional[torch.dtype] = None,
228
+ ):
229
+ device = device or self._execution_device
230
+ dtype = dtype or self.text_encoder.dtype
231
+
232
+ prompt = [prompt] if isinstance(prompt, str) else prompt
233
+ batch_size = len(prompt)
234
+
235
+ text_inputs = self.tokenizer_2(
236
+ prompt,
237
+ padding="max_length",
238
+ max_length=max_sequence_length,
239
+ truncation=True,
240
+ return_length=False,
241
+ return_overflowing_tokens=False,
242
+ return_tensors="pt",
243
+ )
244
+ text_input_ids = text_inputs.input_ids
245
+ untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
246
+
247
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
248
+ removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
249
+ logger.warning(
250
+ "The following part of your input was truncated because `max_sequence_length` is set to "
251
+ f" {max_sequence_length} tokens: {removed_text}"
252
+ )
253
+
254
+ prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
255
+
256
+ dtype = self.text_encoder_2.dtype
257
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
258
+
259
+ _, seq_len, _ = prompt_embeds.shape
260
+
261
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
262
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
263
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
264
+
265
+ return prompt_embeds
266
+
267
+ def _get_clip_prompt_embeds(
268
+ self,
269
+ prompt: Union[str, List[str]],
270
+ num_images_per_prompt: int = 1,
271
+ device: Optional[torch.device] = None,
272
+ ):
273
+ device = device or self._execution_device
274
+
275
+ prompt = [prompt] if isinstance(prompt, str) else prompt
276
+ batch_size = len(prompt)
277
+
278
+ text_inputs = self.tokenizer(
279
+ prompt,
280
+ padding="max_length",
281
+ max_length=self.tokenizer_max_length,
282
+ truncation=True,
283
+ return_overflowing_tokens=False,
284
+ return_length=False,
285
+ return_tensors="pt",
286
+ )
287
+
288
+ text_input_ids = text_inputs.input_ids
289
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
290
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
291
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1: -1])
292
+ logger.warning(
293
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
294
+ f" {self.tokenizer_max_length} tokens: {removed_text}"
295
+ )
296
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
297
+
298
+ # Use pooled output of CLIPTextModel
299
+ prompt_embeds = prompt_embeds.pooler_output
300
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
301
+
302
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
303
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
304
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
305
+
306
+ return prompt_embeds
307
+
308
+ def encode_prompt(
309
+ self,
310
+ prompt: Union[str, List[str]],
311
+ prompt_2: Union[str, List[str]],
312
+ device: Optional[torch.device] = None,
313
+ num_images_per_prompt: int = 1,
314
+ prompt_embeds: Optional[torch.FloatTensor] = None,
315
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
316
+ max_sequence_length: int = 512,
317
+ lora_scale: Optional[float] = None,
318
+ ):
319
+ r"""
320
+
321
+ Args:
322
+ prompt (`str` or `List[str]`, *optional*):
323
+ prompt to be encoded
324
+ prompt_2 (`str` or `List[str]`, *optional*):
325
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
326
+ used in all text-encoders
327
+ device: (`torch.device`):
328
+ torch device
329
+ num_images_per_prompt (`int`):
330
+ number of images that should be generated per prompt
331
+ prompt_embeds (`torch.FloatTensor`, *optional*):
332
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
333
+ provided, text embeddings will be generated from `prompt` input argument.
334
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
335
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
336
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
337
+ lora_scale (`float`, *optional*):
338
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
339
+ """
340
+ device = device or self._execution_device
341
+
342
+ # set lora scale so that monkey patched LoRA
343
+ # function of text encoder can correctly access it
344
+ if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
345
+ self._lora_scale = lora_scale
346
+
347
+ # dynamically adjust the LoRA scale
348
+ if self.text_encoder is not None and USE_PEFT_BACKEND:
349
+ scale_lora_layers(self.text_encoder, lora_scale)
350
+ if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
351
+ scale_lora_layers(self.text_encoder_2, lora_scale)
352
+
353
+ prompt = [prompt] if isinstance(prompt, str) else prompt
354
+
355
+ if prompt_embeds is None:
356
+ prompt_2 = prompt_2 or prompt
357
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
358
+
359
+ # We only use the pooled prompt output from the CLIPTextModel
360
+ pooled_prompt_embeds = self._get_clip_prompt_embeds(
361
+ prompt=prompt,
362
+ device=device,
363
+ num_images_per_prompt=num_images_per_prompt,
364
+ )
365
+ prompt_embeds = self._get_t5_prompt_embeds(
366
+ prompt=prompt_2,
367
+ num_images_per_prompt=num_images_per_prompt,
368
+ max_sequence_length=max_sequence_length,
369
+ device=device,
370
+ )
371
+
372
+ if self.text_encoder is not None:
373
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
374
+ # Retrieve the original scale by scaling back the LoRA layers
375
+ unscale_lora_layers(self.text_encoder, lora_scale)
376
+
377
+ if self.text_encoder_2 is not None:
378
+ if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
379
+ # Retrieve the original scale by scaling back the LoRA layers
380
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
381
+
382
+ dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
383
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
384
+
385
+ return prompt_embeds, pooled_prompt_embeds, text_ids
386
+
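A sketch of pre-computing embeddings once and reusing them across calls (assuming `pipe` is a loaded instance of this pipeline; shapes are for the standard CLIP-L / T5-XXL text encoders):

    prompt_embeds, pooled_prompt_embeds, text_ids = pipe.encode_prompt(
        prompt="a photo of a cat", prompt_2=None, max_sequence_length=512,
    )
    print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)
    # torch.Size([1, 512, 4096]) torch.Size([1, 768]) torch.Size([512, 3])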
387
+ # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
388
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
389
+ if isinstance(generator, list):
390
+ image_latents = [
391
+ retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
392
+ for i in range(image.shape[0])
393
+ ]
394
+ image_latents = torch.cat(image_latents, dim=0)
395
+ else:
396
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
397
+
398
+ image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
399
+
400
+ return image_latents
401
+
402
+ def check_inputs(
403
+ self,
404
+ prompt,
405
+ prompt_2,
406
+ height,
407
+ width,
408
+ prompt_embeds=None,
409
+ pooled_prompt_embeds=None,
410
+ callback_on_step_end_tensor_inputs=None,
411
+ max_sequence_length=None,
412
+ ):
413
+ if height % 8 != 0 or width % 8 != 0:
414
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
415
+
416
+ if callback_on_step_end_tensor_inputs is not None and not all(
417
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
418
+ ):
419
+ raise ValueError(
420
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
421
+ )
422
+
423
+ if prompt is not None and prompt_embeds is not None:
424
+ raise ValueError(
425
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
426
+ " only forward one of the two."
427
+ )
428
+ elif prompt_2 is not None and prompt_embeds is not None:
429
+ raise ValueError(
430
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
431
+ " only forward one of the two."
432
+ )
433
+ elif prompt is None and prompt_embeds is None:
434
+ raise ValueError(
435
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
436
+ )
437
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
438
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
439
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
440
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
441
+
442
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
443
+ raise ValueError(
444
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
445
+ )
446
+
447
+ if max_sequence_length is not None and max_sequence_length > 512:
448
+ raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
449
+
450
+ @staticmethod
451
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
452
+ latent_image_ids = torch.zeros(height // 2, width // 2, 3)
453
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
454
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
455
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
456
+ latent_image_ids = latent_image_ids.reshape(
457
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
458
+ )
459
+ return latent_image_ids.to(device=device, dtype=dtype)
460
+
461
+ @staticmethod
462
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
463
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
464
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
465
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
466
+ return latents
467
+
468
+ @staticmethod
469
+ def _unpack_latents(latents, height, width, vae_scale_factor):
470
+ batch_size, num_patches, channels = latents.shape
471
+
472
+ height = height // vae_scale_factor
473
+ width = width // vae_scale_factor
474
+
475
+ latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
476
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
477
+
478
+ latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
479
+
480
+ return latents
481
+
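The two static helpers are exact inverses; a quick round-trip check (shapes only, assuming the default `vae_scale_factor` of 16):

    import torch

    x = torch.randn(1, 16, 64, 64)                          # (B, C, H_lat, W_lat)
    packed = FluxPipeline._pack_latents(x, 1, 16, 64, 64)
    print(packed.shape)                                      # torch.Size([1, 1024, 64])
    restored = FluxPipeline._unpack_latents(packed, 512, 512, vae_scale_factor=16)
    print(torch.allclose(restored, x))                       # True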
482
+ def enable_vae_slicing(self):
483
+ r"""
484
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
485
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
486
+ """
487
+ self.vae.enable_slicing()
488
+
489
+ def disable_vae_slicing(self):
490
+ r"""
491
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
492
+ computing decoding in one step.
493
+ """
494
+ self.vae.disable_slicing()
495
+
496
+ def enable_vae_tiling(self):
497
+ r"""
498
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
499
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
500
+ processing larger images.
501
+ """
502
+ self.vae.enable_tiling()
503
+
504
+ def disable_vae_tiling(self):
505
+ r"""
506
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
507
+ computing decoding in one step.
508
+ """
509
+ self.vae.disable_tiling()
510
+
511
+ def prepare_latents(
512
+ self,
513
+ batch_size,
514
+ num_channels_latents,
515
+ height,
516
+ width,
517
+ dtype,
518
+ device,
519
+ generator,
520
+ subject_image,
521
+ condition_image,
522
+ latents=None,
523
+ cond_number=1,
524
+ sub_number=1
525
+ ):
526
+ height_cond = 2 * (self.cond_size // self.vae_scale_factor)
527
+ width_cond = 2 * (self.cond_size // self.vae_scale_factor)
528
+ height = 2 * (int(height) // self.vae_scale_factor)
529
+ width = 2 * (int(width) // self.vae_scale_factor)
530
+
531
+ shape = (batch_size, num_channels_latents, height, width) # 1 16 106 80
532
+ noise_latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
533
+ noise_latents = self._pack_latents(noise_latents, batch_size, num_channels_latents, height, width)
534
+ noise_latent_image_ids, cond_latent_image_ids = resize_position_encoding(
535
+ batch_size,
536
+ height,
537
+ width,
538
+ height_cond,
539
+ width_cond,
540
+ device,
541
+ dtype,
542
+ )
543
+
544
+ latents_to_concat = [] # condition/subject latents only; does not include the noise latents
545
+ latents_ids_to_concat = [noise_latent_image_ids]
546
+
547
+ # subject
548
+ if subject_image is not None:
549
+ shape_subject = (batch_size, num_channels_latents, height_cond*sub_number, width_cond)
550
+ subject_image = subject_image.to(device=device, dtype=dtype)
551
+ subject_image_latents = self._encode_vae_image(image=subject_image, generator=generator)
552
+ subject_latents = self._pack_latents(subject_image_latents, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
553
+ mask2 = torch.zeros(shape_subject, device=device, dtype=dtype)
554
+ mask2 = self._pack_latents(mask2, batch_size, num_channels_latents, height_cond*sub_number, width_cond)
555
+ latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, device, dtype)
556
+ latent_subject_ids[:, 1] += 64 # fixed offset
557
+ subject_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
558
+ latents_to_concat.append(subject_latents)
559
+ latents_ids_to_concat.append(subject_latent_image_ids)
560
+
561
+ # spatial
562
+ if condition_image is not None:
563
+ shape_cond = (batch_size, num_channels_latents, height_cond*cond_number, width_cond)
564
+ condition_image = condition_image.to(device=device, dtype=dtype)
565
+ self.vae = self.vae.to(dtype)
566
+ image_latents = self._encode_vae_image(image=condition_image, generator=generator)
567
+ cond_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
568
+ mask3 = torch.zeros(shape_cond, device=device, dtype=dtype)
569
+ mask3 = self._pack_latents(mask3, batch_size, num_channels_latents, height_cond*cond_number, width_cond)
570
+ cond_latent_image_ids = cond_latent_image_ids
571
+ cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
572
+ latents_ids_to_concat.append(cond_latent_image_ids)
573
+ latents_to_concat.append(cond_latents)
574
+
575
+ cond_latents = torch.concat(latents_to_concat, dim=-2)
576
+ latent_image_ids = torch.concat(latents_ids_to_concat, dim=-2)
577
+ return cond_latents, latent_image_ids, noise_latents
578
+
579
+ @property
580
+ def guidance_scale(self):
581
+ return self._guidance_scale
582
+
583
+ @property
584
+ def joint_attention_kwargs(self):
585
+ return self._joint_attention_kwargs
586
+
587
+ @property
588
+ def num_timesteps(self):
589
+ return self._num_timesteps
590
+
591
+ @property
592
+ def interrupt(self):
593
+ return self._interrupt
594
+
595
+ @torch.no_grad()
596
+ def __call__(
597
+ self,
598
+ args: Any = None,
599
+ prompt: Union[str, List[str]] = None,
600
+ prompt_2: Optional[Union[str, List[str]]] = None,
601
+ height: Optional[int] = None,
602
+ width: Optional[int] = None,
603
+ num_inference_steps: int = 28,
604
+ timesteps: List[int] = None,
605
+ guidance_scale: float = 3.5,
606
+ num_images_per_prompt: Optional[int] = 1,
607
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
608
+ latents: Optional[torch.FloatTensor] = None,
609
+ prompt_embeds: Optional[torch.FloatTensor] = None,
610
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
611
+ output_type: Optional[str] = "pil",
612
+ return_dict: bool = True,
613
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
614
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
615
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
616
+ max_sequence_length: int = 512,
617
+ spatial_images=None,
618
+ subject_images=None,
619
+ cond_size=512,
620
+ call_ids=None,
621
+ cuboids_segmasks=None,
622
+ store_qk=None,
623
+ store_qk_timesteps=None,
624
+ ):
625
+ assert not ((store_qk is None) ^ (store_qk_timesteps is None)), "Please provide both store_qk and store_qk_timesteps or neither of them."
626
+
627
+ height = height or self.default_sample_size * self.vae_scale_factor
628
+ width = width or self.default_sample_size * self.vae_scale_factor
629
+ self.cond_size = cond_size
630
+
631
+ # 1. Check inputs. Raise error if not correct
632
+ self.check_inputs(
633
+ prompt,
634
+ prompt_2,
635
+ height,
636
+ width,
637
+ prompt_embeds=prompt_embeds,
638
+ pooled_prompt_embeds=pooled_prompt_embeds,
639
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
640
+ max_sequence_length=max_sequence_length,
641
+ )
642
+
643
+ self._guidance_scale = guidance_scale
644
+ self._joint_attention_kwargs = joint_attention_kwargs
645
+ self._interrupt = False
646
+
647
+ cond_number = len(spatial_images)
648
+ sub_number = len(subject_images)
649
+
650
+ if sub_number > 0:
651
+ subject_image_ls = []
652
+ for subject_image in subject_images:
653
+ w, h = subject_image.size[:2]
654
+ scale = self.cond_size / max(h, w)
655
+ new_h, new_w = int(h * scale), int(w * scale)
656
+ subject_image = self.image_processor.preprocess(subject_image, height=new_h, width=new_w)
657
+ subject_image = subject_image.to(dtype=torch.float32)
658
+ pad_h = cond_size - subject_image.shape[-2]
659
+ pad_w = cond_size - subject_image.shape[-1]
660
+ subject_image = pad(
661
+ subject_image,
662
+ padding=(int(pad_w / 2), int(pad_h / 2), int(pad_w / 2), int(pad_h / 2)),
663
+ fill=0
664
+ )
665
+ subject_image_ls.append(subject_image)
666
+ subject_image = torch.concat(subject_image_ls, dim=-2)
667
+ else:
668
+ subject_image = None
669
+
670
+ if cond_number > 0:
671
+ condition_image_ls = []
672
+ for img in spatial_images:
673
+ condition_image = self.image_processor.preprocess(img, height=self.cond_size, width=self.cond_size)
674
+ condition_image = condition_image.to(dtype=torch.float32)
675
+ condition_image_ls.append(condition_image)
676
+ condition_image = torch.concat(condition_image_ls, dim=-2)
677
+ else:
678
+ condition_image = None
679
+
680
+ # 2. Define call parameters
681
+ if prompt is not None and isinstance(prompt, str):
682
+ batch_size = 1
683
+ elif prompt is not None and isinstance(prompt, list):
684
+ batch_size = len(prompt)
685
+ else:
686
+ batch_size = prompt_embeds.shape[0]
687
+
688
+ device = self._execution_device
689
+
690
+ lora_scale = (
691
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
692
+ )
693
+ (
694
+ prompt_embeds,
695
+ pooled_prompt_embeds,
696
+ text_ids,
697
+ ) = self.encode_prompt(
698
+ prompt=prompt,
699
+ prompt_2=prompt_2,
700
+ prompt_embeds=prompt_embeds,
701
+ pooled_prompt_embeds=pooled_prompt_embeds,
702
+ device=device,
703
+ num_images_per_prompt=num_images_per_prompt,
704
+ max_sequence_length=max_sequence_length,
705
+ lora_scale=lora_scale,
706
+ )
707
+
708
+ # 4. Prepare latent variables
709
+ num_channels_latents = self.transformer.config.in_channels // 4 # 16
710
+ cond_latents, latent_image_ids, noise_latents = self.prepare_latents(
711
+ batch_size * num_images_per_prompt,
712
+ num_channels_latents,
713
+ height,
714
+ width,
715
+ prompt_embeds.dtype,
716
+ device,
717
+ generator,
718
+ subject_image,
719
+ condition_image,
720
+ latents,
721
+ cond_number,
722
+ sub_number
723
+ )
724
+ latents = noise_latents
725
+ # 5. Prepare timesteps
726
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
727
+ image_seq_len = latents.shape[1]
728
+ mu = calculate_shift(
729
+ image_seq_len,
730
+ self.scheduler.config.base_image_seq_len,
731
+ self.scheduler.config.max_image_seq_len,
732
+ self.scheduler.config.base_shift,
733
+ self.scheduler.config.max_shift,
734
+ )
735
+ timesteps, num_inference_steps = retrieve_timesteps(
736
+ self.scheduler,
737
+ num_inference_steps,
738
+ device,
739
+ timesteps,
740
+ sigmas,
741
+ mu=mu,
742
+ )
743
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
744
+ self._num_timesteps = len(timesteps)
745
+
746
+ # handle guidance
747
+ if self.transformer.config.guidance_embeds:
748
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
749
+ guidance = guidance.expand(latents.shape[0])
750
+ else:
751
+ guidance = None
752
+
753
+ # 6. Denoising loop
754
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
755
+ for i, t in enumerate(timesteps):
756
+ if self.interrupt:
757
+ continue
758
+
759
+ store_qk_ = copy.deepcopy(store_qk)
760
+ if (store_qk_ is not None) and (i not in store_qk_timesteps):
761
+ store_qk_ = None
762
+ elif store_qk_ is not None:
763
+ store_qk_ = osp.join(store_qk, f"step_{i}")
764
+
765
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
766
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
767
+ noise_pred = self.transformer(
768
+ hidden_states=latents, # 1 4096 64
769
+ cond_hidden_states=cond_latents,
770
+ timestep=timestep / 1000,
771
+ guidance=guidance,
772
+ pooled_projections=pooled_prompt_embeds,
773
+ encoder_hidden_states=prompt_embeds,
774
+ txt_ids=text_ids,
775
+ img_ids=latent_image_ids,
776
+ joint_attention_kwargs=self.joint_attention_kwargs,
777
+ return_dict=False,
778
+ call_ids=call_ids,
779
+ cuboids_segmasks=cuboids_segmasks,
780
+ store_qk=store_qk_,
781
+ )[0]
782
+
783
+ # compute the previous noisy sample x_t -> x_t-1
784
+ latents_dtype = latents.dtype
785
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
786
+ latents = latents
787
+
788
+ if latents.dtype != latents_dtype:
789
+ if torch.backends.mps.is_available():
790
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
791
+ latents = latents.to(latents_dtype)
792
+
793
+ if callback_on_step_end is not None:
794
+ callback_kwargs = {}
795
+ for k in callback_on_step_end_tensor_inputs:
796
+ callback_kwargs[k] = locals()[k]
797
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
798
+
799
+ latents = callback_outputs.pop("latents", latents)
800
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
801
+
802
+ # call the callback, if provided
803
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
804
+ progress_bar.update()
805
+
806
+ if XLA_AVAILABLE:
807
+ xm.mark_step()
808
+
809
+ if output_type == "latent":
810
+ image = latents
811
+
812
+ else:
813
+ latents = self._unpack_latents(latents, height, width, self.vae_scale_factor)
814
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
815
+ image = self.vae.decode(latents.to(dtype=self.vae.dtype), return_dict=False)[0]
816
+ image = self.image_processor.postprocess(image, output_type=output_type)
817
+
818
+ # Offload all models
819
+ self.maybe_free_model_hooks()
820
+
821
+ if not return_dict:
822
+ return (image,)
823
+
824
+ return FluxPipelineOutput(images=image)
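A usage sketch of the full pipeline (assuming the LoRA attention processors were installed via `set_single_lora`/`set_multi_lora` beforehand; `depth_map_pil` is a hypothetical PIL condition image):

    image = pipe(
        prompt="a cat on a beach",
        height=1024, width=1024,
        num_inference_steps=28,
        guidance_scale=3.5,
        spatial_images=[depth_map_pil],   # encoded at cond_size and attended via the condition blocks
        subject_images=[],                # no subject reference in this example
        cond_size=512,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]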
train/src/prompt_helper.py ADDED
@@ -0,0 +1,215 @@
1
+ import torch
2
+ import os
3
+ import os.path as osp
4
+
5
+
6
+ def load_text_encoders(args, class_one, class_two):
7
+ text_encoder_one = class_one.from_pretrained(
8
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
9
+ )
10
+ text_encoder_two = class_two.from_pretrained(
11
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
12
+ )
13
+ return text_encoder_one, text_encoder_two
14
+
15
+
16
+ def tokenize_prompt(tokenizer, prompt, max_sequence_length):
17
+ text_inputs = tokenizer(
18
+ prompt,
19
+ padding="max_length",
20
+ max_length=max_sequence_length,
21
+ truncation=True,
22
+ return_length=False,
23
+ return_overflowing_tokens=False,
24
+ return_tensors="pt",
25
+ )
26
+ text_input_ids = text_inputs.input_ids
27
+ return text_input_ids
28
+
29
+
30
+ def tokenize_prompt_clip(tokenizer, prompt):
31
+ text_inputs = tokenizer(
32
+ prompt,
33
+ padding="max_length",
34
+ max_length=77,
35
+ truncation=True,
36
+ return_length=False,
37
+ return_overflowing_tokens=False,
38
+ return_tensors="pt",
39
+ )
40
+ text_input_ids = text_inputs.input_ids
41
+ return text_input_ids
42
+
43
+
44
+ def tokenize_prompt_t5(tokenizer, prompt):
45
+ text_inputs = tokenizer(
46
+ prompt,
47
+ padding="max_length",
48
+ max_length=512,
49
+ truncation=True,
50
+ return_length=False,
51
+ return_overflowing_tokens=False,
52
+ return_tensors="pt",
53
+ )
54
+ text_input_ids = text_inputs.input_ids
55
+ return text_input_ids
56
+
57
+
58
+ def _encode_prompt_with_t5(
59
+ text_encoder,
60
+ tokenizer,
61
+ max_sequence_length=512,
62
+ prompt=None,
63
+ num_images_per_prompt=1,
64
+ device=None,
65
+ text_input_ids=None,
66
+ ):
67
+ prompt = [prompt] if isinstance(prompt, str) else prompt
68
+ batch_size = len(prompt)
69
+
70
+ if tokenizer is not None:
71
+ text_inputs = tokenizer(
72
+ prompt,
73
+ padding="max_length",
74
+ max_length=max_sequence_length,
75
+ truncation=True,
76
+ return_length=False,
77
+ return_overflowing_tokens=False,
78
+ return_tensors="pt",
79
+ )
80
+ text_input_ids = text_inputs.input_ids
81
+ else:
82
+ if text_input_ids is None:
83
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
84
+
85
+ prompt_embeds = text_encoder(text_input_ids.to(device))[0]
86
+
87
+ dtype = text_encoder.dtype
88
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
89
+
90
+ _, seq_len, _ = prompt_embeds.shape
91
+
92
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
93
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
94
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
95
+
96
+ return prompt_embeds
97
+
98
+
99
+ def _encode_prompt_with_clip(
100
+ text_encoder,
101
+ tokenizer,
102
+ prompt: str,
103
+ device=None,
104
+ text_input_ids=None,
105
+ num_images_per_prompt: int = 1,
106
+ ):
107
+ prompt = [prompt] if isinstance(prompt, str) else prompt
108
+ batch_size = len(prompt)
109
+
110
+ if tokenizer is not None:
111
+ text_inputs = tokenizer(
112
+ prompt,
113
+ padding="max_length",
114
+ max_length=77,
115
+ truncation=True,
116
+ return_overflowing_tokens=False,
117
+ return_length=False,
118
+ return_tensors="pt",
119
+ )
120
+
121
+ text_input_ids = text_inputs.input_ids
122
+ else:
123
+ if text_input_ids is None:
124
+ raise ValueError("text_input_ids must be provided when the tokenizer is not specified")
125
+
126
+ prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=False)
127
+
128
+ # Use pooled output of CLIPTextModel
129
+ prompt_embeds = prompt_embeds.pooler_output
130
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device)
131
+
132
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
133
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
134
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
135
+
136
+ return prompt_embeds
137
+
138
+
139
+ def encode_prompt(
140
+ args,
141
+ text_encoders,
142
+ tokenizers,
143
+ prompt: str,
144
+ max_sequence_length,
145
+ device=None,
146
+ num_images_per_prompt: int = 1,
147
+ text_input_ids_list=None,
148
+ ):
149
+ prompt = [prompt] if isinstance(prompt, str) else prompt
150
+ dtype = text_encoders[0].dtype
151
+
152
+ _prompt_ = "_".join(prompt)
153
+ if osp.exists(osp.join(args.inference_embeds_dir, f"{_prompt_}.pth")):
154
+ prompt_embeds = torch.load(osp.join(args.inference_embeds_dir, f"{_prompt_}.pth"))
155
+ pooled_prompt_embeds = prompt_embeds["pooled_prompt_embeds"]
156
+ prompt_embeds = prompt_embeds["prompt_embeds"]
157
+
158
+ else:
159
+ pooled_prompt_embeds = _encode_prompt_with_clip(
160
+ text_encoder=text_encoders[0],
161
+ tokenizer=tokenizers[0],
162
+ prompt=prompt,
163
+ device=device if device is not None else text_encoders[0].device,
164
+ num_images_per_prompt=num_images_per_prompt,
165
+ text_input_ids=text_input_ids_list[0] if text_input_ids_list else None,
166
+ )
167
+
168
+ prompt_embeds = _encode_prompt_with_t5(
169
+ text_encoder=text_encoders[1],
170
+ tokenizer=tokenizers[1],
171
+ max_sequence_length=max_sequence_length,
172
+ prompt=prompt,
173
+ num_images_per_prompt=num_images_per_prompt,
174
+ device=device if device is not None else text_encoders[1].device,
175
+ text_input_ids=text_input_ids_list[1] if text_input_ids_list else None,
176
+ )
177
+
178
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
179
+
180
+ return prompt_embeds, pooled_prompt_embeds, text_ids
181
+
182
+
183
+ def encode_token_ids(text_encoders, tokens, accelerator, num_images_per_prompt=1, device=None):
184
+ text_encoder_clip = text_encoders[0]
185
+ text_encoder_t5 = text_encoders[1]
186
+ tokens_clip, tokens_t5 = tokens[0], tokens[1]
187
+ batch_size = tokens_clip.shape[0]
188
+
189
+ if device == "cpu":
190
+ device = "cpu"
191
+ else:
192
+ device = accelerator.device
193
+
194
+ # clip
195
+ prompt_embeds = text_encoder_clip(tokens_clip.to(device), output_hidden_states=False)
196
+ # Use pooled output of CLIPTextModel
197
+ prompt_embeds = prompt_embeds.pooler_output
198
+ prompt_embeds = prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
199
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
200
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
201
+ pooled_prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
202
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=text_encoder_clip.dtype, device=accelerator.device)
203
+
204
+ # t5
205
+ prompt_embeds = text_encoder_t5(tokens_t5.to(device))[0]
206
+ dtype = text_encoder_t5.dtype
207
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=accelerator.device)
208
+ _, seq_len, _ = prompt_embeds.shape
209
+ # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
210
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
211
+ prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
212
+
213
+ text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=accelerator.device, dtype=dtype)
214
+
215
+ return prompt_embeds, pooled_prompt_embeds, text_ids
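A minimal sketch of wiring the helpers above together: the CLIP encoder supplies the pooled embedding, the T5 encoder supplies the token-level embedding, and encode_prompt caches the pair under args.inference_embeds_dir when a cached file exists. The argument object, checkpoint name, and cache directory below are illustrative assumptions:

import torch
from types import SimpleNamespace
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
from src.prompt_helper import load_text_encoders, encode_prompt

args = SimpleNamespace(
    pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    revision=None,
    variant=None,
    inference_embeds_dir="embeds_cache",  # consulted by encode_prompt for cached embeddings
)

tokenizer_clip = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
tokenizer_t5 = T5TokenizerFast.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
text_encoder_clip, text_encoder_t5 = load_text_encoders(args, CLIPTextModel, T5EncoderModel)

with torch.no_grad():
    prompt_embeds, pooled_prompt_embeds, text_ids = encode_prompt(
        args,
        text_encoders=[text_encoder_clip, text_encoder_t5],
        tokenizers=[tokenizer_clip, tokenizer_t5],
        prompt="a wooden pot on a table",
        max_sequence_length=512,
        device="cpu",
    )
print(prompt_embeds.shape, pooled_prompt_embeds.shape, text_ids.shape)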
train/src/transformer_flux.py ADDED
@@ -0,0 +1,603 @@
1
+ from typing import Any, Dict, Optional, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ import os
9
+ import os.path as osp
10
+
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
13
+ from diffusers.models.attention import FeedForward
14
+ from diffusers.models.attention_processor import (
15
+ Attention,
16
+ AttentionProcessor,
17
+ FluxAttnProcessor2_0,
18
+ FluxAttnProcessor2_0_NPU,
19
+ FusedFluxAttnProcessor2_0,
20
+ )
21
+ from diffusers.models.modeling_utils import ModelMixin
22
+ from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
23
+ from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
24
+ from diffusers.utils.import_utils import is_torch_npu_available
25
+ from diffusers.utils.torch_utils import maybe_allow_in_graph
26
+ from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
27
+ from diffusers.models.modeling_outputs import Transformer2DModelOutput
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+ @maybe_allow_in_graph
32
+ class FluxSingleTransformerBlock(nn.Module):
33
+
34
+ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
35
+ super().__init__()
36
+ self.mlp_hidden_dim = int(dim * mlp_ratio)
37
+
38
+ self.norm = AdaLayerNormZeroSingle(dim)
39
+ self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
40
+ self.act_mlp = nn.GELU(approximate="tanh")
41
+ self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
42
+
43
+ if is_torch_npu_available():
44
+ processor = FluxAttnProcessor2_0_NPU()
45
+ else:
46
+ processor = FluxAttnProcessor2_0()
47
+ self.attn = Attention(
48
+ query_dim=dim,
49
+ cross_attention_dim=None,
50
+ dim_head=attention_head_dim,
51
+ heads=num_attention_heads,
52
+ out_dim=dim,
53
+ bias=True,
54
+ processor=processor,
55
+ qk_norm="rms_norm",
56
+ eps=1e-6,
57
+ pre_only=True,
58
+ )
59
+
60
+ def forward(
61
+ self,
62
+ hidden_states: torch.Tensor,
63
+ cond_hidden_states: torch.Tensor,
64
+ temb: torch.Tensor,
65
+ cond_temb: torch.Tensor,
66
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
67
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
68
+ ) -> torch.Tensor:
69
+ use_cond = cond_hidden_states is not None
70
+
71
+ residual = hidden_states
72
+ norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
73
+ mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
74
+
75
+ if use_cond:
76
+ residual_cond = cond_hidden_states
77
+ norm_cond_hidden_states, cond_gate = self.norm(cond_hidden_states, emb=cond_temb)
78
+ mlp_cond_hidden_states = self.act_mlp(self.proj_mlp(norm_cond_hidden_states))
79
+
80
+ norm_hidden_states_concat = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) if use_cond else norm_hidden_states
81
+
82
+ joint_attention_kwargs = joint_attention_kwargs or {}
83
+ attn_output = self.attn(
84
+ hidden_states=norm_hidden_states_concat,
85
+ image_rotary_emb=image_rotary_emb,
86
+ use_cond=use_cond,
87
+ **joint_attention_kwargs,
88
+ )
89
+ if use_cond:
90
+ attn_output, cond_attn_output = attn_output
91
+
92
+ hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
93
+ gate = gate.unsqueeze(1)
94
+ hidden_states = gate * self.proj_out(hidden_states)
95
+ hidden_states = residual + hidden_states
96
+
97
+ if use_cond:
98
+ condition_latents = torch.cat([cond_attn_output, mlp_cond_hidden_states], dim=2)
99
+ cond_gate = cond_gate.unsqueeze(1)
100
+ condition_latents = cond_gate * self.proj_out(condition_latents)
101
+ condition_latents = residual_cond + condition_latents
102
+
103
+ if hidden_states.dtype == torch.float16:
104
+ hidden_states = hidden_states.clip(-65504, 65504)
105
+
106
+ return hidden_states, condition_latents if use_cond else None
107
+
108
+
109
+ @maybe_allow_in_graph
110
+ class FluxTransformerBlock(nn.Module):
111
+ def __init__(
112
+ self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
113
+ ):
114
+ super().__init__()
115
+
116
+ self.norm1 = AdaLayerNormZero(dim)
117
+
118
+ self.norm1_context = AdaLayerNormZero(dim)
119
+
120
+ if hasattr(F, "scaled_dot_product_attention"):
121
+ processor = FluxAttnProcessor2_0()
122
+ else:
123
+ raise ValueError(
124
+ "The current PyTorch version does not support the `scaled_dot_product_attention` function."
125
+ )
126
+ self.attn = Attention(
127
+ query_dim=dim,
128
+ cross_attention_dim=None,
129
+ added_kv_proj_dim=dim,
130
+ dim_head=attention_head_dim,
131
+ heads=num_attention_heads,
132
+ out_dim=dim,
133
+ context_pre_only=False,
134
+ bias=True,
135
+ processor=processor,
136
+ qk_norm=qk_norm,
137
+ eps=eps,
138
+ )
139
+
140
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
141
+ self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
142
+
143
+ self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
144
+ self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
145
+
146
+ # let chunk size default to None
147
+ self._chunk_size = None
148
+ self._chunk_dim = 0
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ cond_hidden_states: torch.Tensor,
154
+ encoder_hidden_states: torch.Tensor,
155
+ temb: torch.Tensor,
156
+ cond_temb: torch.Tensor,
157
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
158
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
159
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
160
+ use_cond = cond_hidden_states is not None
161
+
162
+ norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
163
+ if use_cond:
164
+ (
165
+ norm_cond_hidden_states,
166
+ cond_gate_msa,
167
+ cond_shift_mlp,
168
+ cond_scale_mlp,
169
+ cond_gate_mlp,
170
+ ) = self.norm1(cond_hidden_states, emb=cond_temb)
171
+
172
+ norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
173
+ encoder_hidden_states, emb=temb
174
+ )
175
+
176
+ norm_hidden_states = torch.concat([norm_hidden_states, norm_cond_hidden_states], dim=-2) if use_cond else norm_hidden_states
177
+
178
+ joint_attention_kwargs = joint_attention_kwargs or {}
179
+ # Attention.
180
+ attention_outputs = self.attn(
181
+ hidden_states=norm_hidden_states,
182
+ encoder_hidden_states=norm_encoder_hidden_states,
183
+ image_rotary_emb=image_rotary_emb,
184
+ use_cond=use_cond,
185
+ **joint_attention_kwargs,
186
+ )
187
+
188
+ attn_output, context_attn_output = attention_outputs[:2]
189
+ cond_attn_output = attention_outputs[2] if use_cond else None
190
+
191
+ # Process attention outputs for the `hidden_states`.
192
+ attn_output = gate_msa.unsqueeze(1) * attn_output
193
+ hidden_states = hidden_states + attn_output
194
+
195
+ if use_cond:
196
+ cond_attn_output = cond_gate_msa.unsqueeze(1) * cond_attn_output
197
+ cond_hidden_states = cond_hidden_states + cond_attn_output
198
+
199
+ norm_hidden_states = self.norm2(hidden_states)
200
+ norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
201
+
202
+ if use_cond:
203
+ norm_cond_hidden_states = self.norm2(cond_hidden_states)
204
+ norm_cond_hidden_states = (
205
+ norm_cond_hidden_states * (1 + cond_scale_mlp[:, None])
206
+ + cond_shift_mlp[:, None]
207
+ )
208
+
209
+ ff_output = self.ff(norm_hidden_states)
210
+ ff_output = gate_mlp.unsqueeze(1) * ff_output
211
+ hidden_states = hidden_states + ff_output
212
+
213
+ if use_cond:
214
+ cond_ff_output = self.ff(norm_cond_hidden_states)
215
+ cond_ff_output = cond_gate_mlp.unsqueeze(1) * cond_ff_output
216
+ cond_hidden_states = cond_hidden_states + cond_ff_output
217
+
218
+ # Process attention outputs for the `encoder_hidden_states`.
219
+
220
+ context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
221
+ encoder_hidden_states = encoder_hidden_states + context_attn_output
222
+
223
+ norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
224
+ norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
225
+
226
+ context_ff_output = self.ff_context(norm_encoder_hidden_states)
227
+ encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
228
+ if encoder_hidden_states.dtype == torch.float16:
229
+ encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
230
+
231
+ return encoder_hidden_states, hidden_states, cond_hidden_states if use_cond else None
232
+
233
+
234
+ class FluxTransformer2DModel(
235
+ ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin
236
+ ):
237
+ _supports_gradient_checkpointing = True
238
+ _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]
239
+
240
+ @register_to_config
241
+ def __init__(
242
+ self,
243
+ patch_size: int = 1,
244
+ in_channels: int = 64,
245
+ out_channels: Optional[int] = None,
246
+ num_layers: int = 19,
247
+ num_single_layers: int = 38,
248
+ attention_head_dim: int = 128,
249
+ num_attention_heads: int = 24,
250
+ joint_attention_dim: int = 4096,
251
+ pooled_projection_dim: int = 768,
252
+ guidance_embeds: bool = False,
253
+ axes_dims_rope: Tuple[int] = (16, 56, 56),
254
+ ):
255
+ super().__init__()
256
+ self.out_channels = out_channels or in_channels
257
+ self.inner_dim = num_attention_heads * attention_head_dim
258
+
259
+ self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
260
+
261
+ text_time_guidance_cls = (
262
+ CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
263
+ )
264
+ self.time_text_embed = text_time_guidance_cls(
265
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
266
+ )
267
+
268
+ self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
269
+ self.x_embedder = nn.Linear(in_channels, self.inner_dim)
270
+
271
+ self.transformer_blocks = nn.ModuleList(
272
+ [
273
+ FluxTransformerBlock(
274
+ dim=self.inner_dim,
275
+ num_attention_heads=num_attention_heads,
276
+ attention_head_dim=attention_head_dim,
277
+ )
278
+ for _ in range(num_layers)
279
+ ]
280
+ )
281
+
282
+ self.single_transformer_blocks = nn.ModuleList(
283
+ [
284
+ FluxSingleTransformerBlock(
285
+ dim=self.inner_dim,
286
+ num_attention_heads=num_attention_heads,
287
+ attention_head_dim=attention_head_dim,
288
+ )
289
+ for _ in range(num_single_layers)
290
+ ]
291
+ )
292
+
293
+ self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
294
+ self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
295
+
296
+ self.gradient_checkpointing = False
297
+
298
+ @property
299
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
300
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
301
+ r"""
302
+ Returns:
303
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
305
+ indexed by their weight names.
305
+ """
306
+ # set recursively
307
+ processors = {}
308
+
309
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
310
+ if hasattr(module, "get_processor"):
311
+ processors[f"{name}.processor"] = module.get_processor()
312
+
313
+ for sub_name, child in module.named_children():
314
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
315
+
316
+ return processors
317
+
318
+ for name, module in self.named_children():
319
+ fn_recursive_add_processors(name, module, processors)
320
+
321
+ return processors
322
+
323
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
324
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
325
+ r"""
326
+ Sets the attention processor to use to compute attention.
327
+
328
+ Parameters:
329
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
330
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
331
+ for **all** `Attention` layers.
332
+
333
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
334
+ processor. This is strongly recommended when setting trainable attention processors.
335
+
336
+ """
337
+ count = len(self.attn_processors.keys())
338
+
339
+ if isinstance(processor, dict) and len(processor) != count:
340
+ raise ValueError(
341
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
342
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
343
+ )
344
+
345
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
346
+ if hasattr(module, "set_processor"):
347
+ if not isinstance(processor, dict):
348
+ module.set_processor(processor)
349
+ else:
350
+ module.set_processor(processor.pop(f"{name}.processor"))
351
+
352
+ for sub_name, child in module.named_children():
353
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
354
+
355
+ for name, module in self.named_children():
356
+ fn_recursive_attn_processor(name, module, processor)
357
+
358
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
359
+ def fuse_qkv_projections(self):
360
+ """
361
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
362
+ are fused. For cross-attention modules, key and value projection matrices are fused.
363
+
364
+ <Tip warning={true}>
365
+
366
+ This API is 🧪 experimental.
367
+
368
+ </Tip>
369
+ """
370
+ self.original_attn_processors = None
371
+
372
+ for _, attn_processor in self.attn_processors.items():
373
+ if "Added" in str(attn_processor.__class__.__name__):
374
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
375
+
376
+ self.original_attn_processors = self.attn_processors
377
+
378
+ for module in self.modules():
379
+ if isinstance(module, Attention):
380
+ module.fuse_projections(fuse=True)
381
+
382
+ self.set_attn_processor(FusedFluxAttnProcessor2_0())
383
+
384
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
385
+ def unfuse_qkv_projections(self):
386
+ """Disables the fused QKV projection if enabled.
387
+
388
+ <Tip warning={true}>
389
+
390
+ This API is 🧪 experimental.
391
+
392
+ </Tip>
393
+
394
+ """
395
+ if self.original_attn_processors is not None:
396
+ self.set_attn_processor(self.original_attn_processors)
397
+
398
+ def _set_gradient_checkpointing(self, module, value=False):
399
+ if hasattr(module, "gradient_checkpointing"):
400
+ module.gradient_checkpointing = value
401
+
402
+ def forward(
403
+ self,
404
+ hidden_states: torch.Tensor,
405
+ cond_hidden_states: torch.Tensor = None,
406
+ encoder_hidden_states: torch.Tensor = None,
407
+ pooled_projections: torch.Tensor = None,
408
+ timestep: torch.LongTensor = None,
409
+ img_ids: torch.Tensor = None,
410
+ txt_ids: torch.Tensor = None,
411
+ guidance: torch.Tensor = None,
412
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
413
+ controlnet_block_samples=None,
414
+ controlnet_single_block_samples=None,
415
+ return_dict: bool = True,
416
+ controlnet_blocks_repeat: bool = False,
417
+ call_ids: list = None,
418
+ cuboids_segmasks: torch.Tensor = None,
419
+ store_qk: Union[bool, str] = False,  # False to disable, or a directory prefix used to store per-block Q/K
420
+ ) -> Union[torch.Tensor, Transformer2DModelOutput]:
421
+ if cond_hidden_states is not None:
422
+ use_condition = True
423
+ else:
424
+ use_condition = False
425
+
426
+ if joint_attention_kwargs is not None:
427
+ joint_attention_kwargs = joint_attention_kwargs.copy()
428
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
429
+ else:
430
+ lora_scale = 1.0
431
+ joint_attention_kwargs = {}
432
+ joint_attention_kwargs["call_ids"] = call_ids
433
+ joint_attention_kwargs["cuboids_segmasks"] = cuboids_segmasks
434
+
435
+ if USE_PEFT_BACKEND:
436
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
437
+ scale_lora_layers(self, lora_scale)
438
+ else:
439
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
440
+ logger.warning(
441
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
442
+ )
443
+
444
+ hidden_states = self.x_embedder(hidden_states)
445
+ cond_hidden_states = self.x_embedder(cond_hidden_states) if use_condition else None
446
+
447
+ timestep = timestep.to(hidden_states.dtype) * 1000
448
+ if guidance is not None:
449
+ guidance = guidance.to(hidden_states.dtype) * 1000
450
+ else:
451
+ guidance = None
452
+
453
+ temb = (
454
+ self.time_text_embed(timestep, pooled_projections)
455
+ if guidance is None
456
+ else self.time_text_embed(timestep, guidance, pooled_projections)
457
+ )
458
+
459
+ cond_temb = (
460
+ self.time_text_embed(torch.ones_like(timestep) * 0, pooled_projections)
461
+ if guidance is None
462
+ else self.time_text_embed(
463
+ torch.ones_like(timestep) * 0, guidance, pooled_projections
464
+ )
465
+ )
466
+
467
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
468
+
469
+ if txt_ids.ndim == 3:
470
+ logger.warning(
471
+ "Passing `txt_ids` 3d torch.Tensor is deprecated."
472
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
473
+ )
474
+ txt_ids = txt_ids[0]
475
+ if img_ids.ndim == 3:
476
+ logger.warning(
477
+ "Passing `img_ids` 3d torch.Tensor is deprecated."
478
+ "Please remove the batch dimension and pass it as a 2d torch Tensor"
479
+ )
480
+ img_ids = img_ids[0]
481
+
482
+ ids = torch.cat((txt_ids, img_ids), dim=0)
483
+ image_rotary_emb = self.pos_embed(ids)
484
+
485
+ if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
486
+ ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
487
+ ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
488
+ joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
489
+
490
+ for index_block, block in enumerate(self.transformer_blocks):
491
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
492
+
493
+ def create_custom_forward(module, return_dict=None):
494
+ def custom_forward(*inputs):
495
+ if return_dict is not None:
496
+ return module(*inputs, return_dict=return_dict)
497
+ else:
498
+ return module(*inputs)
499
+
500
+ return custom_forward
501
+
502
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
503
+ encoder_hidden_states, hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
504
+ create_custom_forward(block),
505
+ hidden_states,
506
+ cond_hidden_states if use_condition else None,
507
+ encoder_hidden_states,
508
+ temb,
509
+ cond_temb if use_condition else None,
510
+ image_rotary_emb,
511
+ **ckpt_kwargs,
512
+ )
513
+
514
+ else:
515
+ if store_qk:
516
+ overall_block_idx = index_block
517
+ joint_attention_kwargs["store_qk"] = osp.join(store_qk, f"{str(overall_block_idx).zfill(3)}")
518
+
519
+ encoder_hidden_states, hidden_states, cond_hidden_states = block(
520
+ hidden_states=hidden_states,
521
+ encoder_hidden_states=encoder_hidden_states,
522
+ cond_hidden_states=cond_hidden_states if use_condition else None,
523
+ temb=temb,
524
+ cond_temb=cond_temb if use_condition else None,
525
+ image_rotary_emb=image_rotary_emb,
526
+ joint_attention_kwargs=joint_attention_kwargs,
527
+ )
528
+
529
+ # controlnet residual
530
+ if controlnet_block_samples is not None:
531
+ interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
532
+ interval_control = int(np.ceil(interval_control))
533
+ # For Xlabs ControlNet.
534
+ if controlnet_blocks_repeat:
535
+ hidden_states = (
536
+ hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
537
+ )
538
+ else:
539
+ hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
540
+
541
+ # note that the encoder_hidden_states are concatenated in FRONT of the hidden states, not BEHIND
542
+ # this would change the attention mask calculation.
543
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
544
+
545
+ for index_block, block in enumerate(self.single_transformer_blocks):
546
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
547
+
548
+ def create_custom_forward(module, return_dict=None):
549
+ def custom_forward(*inputs):
550
+ if return_dict is not None:
551
+ return module(*inputs, return_dict=return_dict)
552
+ else:
553
+ return module(*inputs)
554
+
555
+ return custom_forward
556
+
557
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
558
+ hidden_states, cond_hidden_states = torch.utils.checkpoint.checkpoint(
559
+ create_custom_forward(block),
560
+ hidden_states,
561
+ cond_hidden_states if use_condition else None,
562
+ temb,
563
+ cond_temb if use_condition else None,
564
+ image_rotary_emb,
565
+ **ckpt_kwargs,
566
+ )
567
+
568
+ else:
569
+ if store_qk:
570
+ overall_block_idx = index_block + len(self.transformer_blocks)
571
+ joint_attention_kwargs["store_qk"] = osp.join(store_qk, f"{str(overall_block_idx).zfill(3)}")
572
+
573
+ hidden_states, cond_hidden_states = block(
574
+ hidden_states=hidden_states,
575
+ cond_hidden_states=cond_hidden_states if use_condition else None,
576
+ temb=temb,
577
+ cond_temb=cond_temb if use_condition else None,
578
+ image_rotary_emb=image_rotary_emb,
579
+ joint_attention_kwargs=joint_attention_kwargs,
580
+ )
581
+
582
+ # controlnet residual
583
+ if controlnet_single_block_samples is not None:
584
+ interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
585
+ interval_control = int(np.ceil(interval_control))
586
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
587
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
588
+ + controlnet_single_block_samples[index_block // interval_control]
589
+ )
590
+
591
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
592
+
593
+ hidden_states = self.norm_out(hidden_states, temb)
594
+ output = self.proj_out(hidden_states)
595
+
596
+ if USE_PEFT_BACKEND:
597
+ # remove `lora_scale` from each PEFT layer
598
+ unscale_lora_layers(self, lora_scale)
599
+
600
+ if not return_dict:
601
+ return (output,)
602
+
603
+ return Transformer2DModelOutput(sample=output)
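The block and model classes above mirror the stock diffusers FluxTransformer2DModel but thread an extra condition stream (cond_hidden_states with its own cond_temb) through both block types. A minimal loading sketch, assuming the published FLUX.1-dev weights and the standard ModelMixin interface:

import torch
from src.transformer_flux import FluxTransformer2DModel

# The config keys match the stock transformer, so the released weights load directly.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed base checkpoint
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# The condition path only activates when cond_hidden_states is passed to forward();
# the matching attention processors from src/layers.py are expected to be installed
# via set_attn_processor() before conditioning is used.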
train/train.py ADDED
@@ -0,0 +1,1463 @@
1
+ import argparse
2
+ import copy
3
+ import logging
4
+ import random
5
+ import math
6
+ import os
7
+ import shutil
8
+ import gc
9
+ from contextlib import nullcontext
10
+ from pathlib import Path
11
+ import re
12
+ from safetensors.torch import save_file
13
+
14
+ from PIL import Image
15
+ import numpy as np
16
+ import torch.utils.checkpoint
17
+ import transformers
18
+ from accelerate import Accelerator
19
+ from accelerate.logging import get_logger
20
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
21
+
22
+ from tqdm.auto import tqdm
23
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
24
+
25
+ import diffusers
26
+
27
+ from diffusers import (
28
+ AutoencoderKL,
29
+ FlowMatchEulerDiscreteScheduler
30
+ )
31
+ from diffusers.optimization import get_scheduler
32
+ from diffusers.training_utils import (
33
+ cast_training_params,
34
+ compute_density_for_timestep_sampling,
35
+ compute_loss_weighting_for_sd3,
36
+ )
37
+ import os.path as osp
38
+ from diffusers.utils.torch_utils import is_compiled_module
39
+ from diffusers.utils import (
40
+ check_min_version,
41
+ is_wandb_available,
42
+ convert_unet_state_dict_to_peft
43
+ )
44
+
45
+ from src.prompt_helper import *
46
+ from src.lora_helper import *
47
+ from src.pipeline import FluxPipeline, resize_position_encoding, prepare_latent_subject_ids
48
+ from src.layers import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor
49
+ from src.transformer_flux import FluxTransformer2DModel
50
+ from src.jsonl_datasets import make_train_dataset, collate_fn
51
+
52
+ if is_wandb_available():
53
+ import wandb
54
+
55
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risk.
56
+ check_min_version("0.31.0.dev0")
57
+
58
+ logger = get_logger(__name__)
59
+
60
+ import matplotlib.pyplot as plt
61
+ import torch
62
+
63
+ def create_validation_figure(output_image, spatial_image, subject_image, prompt, validation_idx, global_step):
64
+ """
65
+ Create a 2x2 matplotlib figure showing validation results.
66
+
67
+ Args:
68
+ output_image: Generated output image (PIL Image)
69
+ spatial_image: Spatial condition image (PIL Image or None)
70
+ subject_image: Subject condition image (PIL Image or None)
71
+ prompt: Text prompt string
72
+ validation_idx: Index of validation prompt
73
+ global_step: Current training step
74
+
75
+ Returns:
76
+ matplotlib figure
77
+ """
78
+ fig, axes = plt.subplots(2, 2, figsize=(12, 20))
79
+ fig.suptitle(f'Validation Results - Step {global_step} - Prompt {validation_idx}', fontsize=14)
80
+
81
+ # Output image (top-left)
82
+ axes[0, 0].imshow(np.array(output_image))
83
+ axes[0, 0].set_title('Generated Output')
84
+ axes[0, 0].axis('off')
85
+
86
+ # Spatial condition (top-right)
87
+ if spatial_image is not None:
88
+ axes[0, 1].imshow(np.array(spatial_image))
89
+ axes[0, 1].set_title('Spatial Condition')
90
+ else:
91
+ axes[0, 1].text(0.5, 0.5, 'NOT AVAILABLE',
92
+ horizontalalignment='center', verticalalignment='center',
93
+ transform=axes[0, 1].transAxes, fontsize=14, fontweight='bold')
94
+ axes[0, 1].set_title('Spatial Condition')
95
+ axes[0, 1].axis('off')
96
+
97
+ # Subject condition (bottom-left)
98
+ if subject_image is not None:
99
+ axes[1, 0].imshow(np.array(subject_image))
100
+ axes[1, 0].set_title('Subject Condition')
101
+ else:
102
+ axes[1, 0].text(0.5, 0.5, 'NOT AVAILABLE',
103
+ horizontalalignment='center', verticalalignment='center',
104
+ transform=axes[1, 0].transAxes, fontsize=14, fontweight='bold')
105
+ axes[1, 0].set_title('Subject Condition')
106
+ axes[1, 0].axis('off')
107
+
108
+ # Prompt and info (bottom-right)
109
+ info_text = f'Prompt:\n"{prompt}"\n\nStep: {global_step}\nValidation Index: {validation_idx}'
110
+ axes[1, 1].text(0.5, 0.5, info_text,
111
+ horizontalalignment='center', verticalalignment='center',
112
+ transform=axes[1, 1].transAxes, fontsize=10, wrap=True)
113
+ axes[1, 1].set_title('Prompt & Info')
114
+ axes[1, 1].axis('off')
115
+
116
+ plt.tight_layout()
117
+ return fig
118
+
119
+
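A hypothetical call site for the helper above (the real one appears later in the training loop); output_image, spatial_image, subject_image, prompt, and global_step are placeholders here:

fig = create_validation_figure(output_image, spatial_image, subject_image,
                               prompt, validation_idx=0, global_step=global_step)
fig.savefig(osp.join(args.output_dir, f"validation_{global_step}.png"), dpi=150)
plt.close(fig)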
120
+ def visualize_training_data(batch, vae, model_input, noisy_model_input, cond_input, subject_input, args, global_step, accelerator):
121
+ """
122
+ Visualize training data including all entities from the batch.
123
+
124
+ Args:
125
+ batch: Training batch containing data
126
+ vae: VAE model for decoding latents
127
+ model_input: Clean latents before adding noise
128
+ noisy_model_input: Noisy latents passed to transformer
129
+ cond_input: Spatial condition latents (may be None)
130
+ subject_input: Subject condition latents (may be None)
131
+ args: Training arguments
132
+ global_step: Current training step
133
+ accelerator: Accelerator instance
134
+ """
135
+
136
+ # Check availability of conditions
137
+ has_spatial_condition = batch["cond_pixel_values"] is not None
138
+ has_subject_condition = batch["subject_pixel_values"] is not None
139
+ has_cuboids_segmasks = "cuboids_segmasks" in batch and batch["cuboids_segmasks"] is not None
140
+ has_cuboids_segmasks_bev = "cuboids_segmasks_bev" in batch and batch["cuboids_segmasks_bev"] is not None
141
+
142
+ # Initialize variables
143
+ spatial_img = None
144
+ subject_img = None
145
+
146
+ with torch.no_grad():
147
+ # Get VAE config for proper decoding
148
+ vae_config_shift_factor = vae.config.shift_factor
149
+ vae_config_scaling_factor = vae.config.scaling_factor
150
+ vae_dtype = vae.dtype
151
+ vae = vae.to(torch.float32)
152
+
153
+ # Decode spatial condition if available
154
+ if has_spatial_condition:
155
+ cond_for_decode = (cond_input / vae_config_scaling_factor) + vae_config_shift_factor
156
+ spatial_decoded = vae.decode(cond_for_decode.float()).sample
157
+ spatial_decoded = (spatial_decoded / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
158
+ spatial_img = spatial_decoded[0].float().cpu().permute(1, 2, 0).numpy()
159
+
160
+ # Decode subject condition if available
161
+ if has_subject_condition:
162
+ subject_for_decode = (subject_input / vae_config_scaling_factor) + vae_config_shift_factor
163
+ subject_decoded = vae.decode(subject_for_decode.float()).sample
164
+ subject_decoded = (subject_decoded / 2 + 0.5).clamp(0, 1) # Normalize to [0,1]
165
+ subject_img = subject_decoded[0].float().cpu().permute(1, 2, 0).numpy()
166
+
167
+ # Decode clean model input
168
+ clean_for_decode = (model_input / vae_config_scaling_factor) + vae_config_shift_factor
169
+ clean_decoded = vae.decode(clean_for_decode.float()).sample
170
+ clean_decoded = (clean_decoded / 2 + 0.5).clamp(0, 1)
171
+
172
+ # Decode noisy model input
173
+ noisy_for_decode = (noisy_model_input / vae_config_scaling_factor) + vae_config_shift_factor
174
+ noisy_decoded = vae.decode(noisy_for_decode.float()).sample
175
+ noisy_decoded = (noisy_decoded / 2 + 0.5).clamp(0, 1)
176
+
177
+ # Convert to CPU and numpy for visualization (take first batch item)
178
+ clean_img = clean_decoded[0].float().cpu().permute(1, 2, 0).numpy()
179
+ noisy_img = noisy_decoded[0].float().cpu().permute(1, 2, 0).numpy()
180
+
181
+ # Get text prompt and other info
182
+ text_prompt = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
183
+ call_id = batch["call_ids"][0] if batch["call_ids"] is not None else "N/A"
184
+
185
+ # Create figure with more subplots to accommodate all entities including BEV
186
+ fig, axes = plt.subplots(4, 3, figsize=(18, 24))
187
+ # fig.suptitle(f'Training Data Visualization - Step {global_step}', fontsize=16)
188
+
189
+ # Spatial condition (0,0)
190
+ if has_spatial_condition and spatial_img is not None:
191
+ axes[0, 0].imshow(spatial_img)
192
+ axes[0, 0].set_title('Spatial Condition')
193
+ else:
194
+ axes[0, 0].text(0.5, 0.5, 'NOT AVAILABLE',
195
+ horizontalalignment='center', verticalalignment='center',
196
+ transform=axes[0, 0].transAxes, fontsize=14, fontweight='bold')
197
+ axes[0, 0].set_title('Spatial Condition')
198
+ axes[0, 0].axis('off')
199
+
200
+ # Subject condition (0,1)
201
+ if has_subject_condition and subject_img is not None:
202
+ axes[0, 1].imshow(subject_img)
203
+ axes[0, 1].set_title('Subject Condition')
204
+ else:
205
+ axes[0, 1].text(0.5, 0.5, 'NOT AVAILABLE',
206
+ horizontalalignment='center', verticalalignment='center',
207
+ transform=axes[0, 1].transAxes, fontsize=14, fontweight='bold')
208
+ axes[0, 1].set_title('Subject Condition')
209
+ axes[0, 1].axis('off')
210
+
211
+ # Clean model input (0,2)
212
+ axes[0, 2].imshow(clean_img)
213
+ axes[0, 2].set_title('Clean Model Input')
214
+ axes[0, 2].axis('off')
215
+
216
+ # Noisy model input (1,0)
217
+ axes[1, 0].imshow(noisy_img)
218
+ axes[1, 0].set_title('Noisy Model Input')
219
+ axes[1, 0].axis('off')
220
+
221
+ # Cuboids segmentation masks with legend (1,1 and 1,2)
222
+ if has_cuboids_segmasks:
223
+ segmask = batch["cuboids_segmasks"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
224
+ n_subjects, h, w = segmask.shape
225
+
226
+ # Only use first 4 subjects for visualization
227
+ n_subjects_to_show = min(4, n_subjects)
228
+
229
+ # Create colored segmentation visualization
230
+ np.random.seed(42) # For consistent colors
231
+ colors = np.random.rand(n_subjects_to_show + 1, 3) # +1 for background
232
+ colors[0] = [0, 0, 0] # Background is black
233
+
234
+ # Create 2x2 grid of individual subject masks
235
+ grid_h, grid_w = 2, 2
236
+ combined_mask = np.zeros((h * grid_h, w * grid_w, 3))
237
+
238
+ for idx in range(n_subjects_to_show):
239
+ row = idx // grid_w
240
+ col = idx % grid_w
241
+
242
+ # Create binary mask for this subject
243
+ subject_mask = np.zeros((h, w, 3))
244
+ mask = segmask[idx] > 0.5 # Binary threshold
245
+ subject_mask[mask] = colors[idx + 1]
246
+
247
+ # Place in grid
248
+ combined_mask[row*h:(row+1)*h, col*w:(col+1)*w] = subject_mask
249
+
250
+ axes[1, 1].imshow(combined_mask)
251
+ axes[1, 1].set_title('Cuboids Segmentation (2x2 Grid)')
252
+ axes[1, 1].axis('off')
253
+
254
+ # Create legend in the next subplot (1,2) - only for first 4 subjects
255
+ axes[1, 2].set_xlim(0, 1)
256
+ axes[1, 2].set_ylim(0, 1)
257
+
258
+ # Add legend entries
259
+ legend_y_positions = np.linspace(0.9, 0.1, n_subjects_to_show + 1)
260
+ axes[1, 2].text(0.1, legend_y_positions[0], f"Background",
261
+ color=colors[0], fontsize=12, fontweight='bold')
262
+
263
+ for subject_idx in range(n_subjects_to_show):
264
+ axes[1, 2].text(0.1, legend_y_positions[subject_idx + 1],
265
+ f"Subject {subject_idx}",
266
+ color=colors[subject_idx + 1], fontsize=12, fontweight='bold')
267
+
268
+ axes[1, 2].set_title('Segmentation Legend (First 4)')
269
+ axes[1, 2].axis('off')
270
+ else:
271
+ axes[1, 1].text(0.5, 0.5, 'NOT AVAILABLE',
272
+ horizontalalignment='center', verticalalignment='center',
273
+ transform=axes[1, 1].transAxes, fontsize=14, fontweight='bold')
274
+ axes[1, 1].set_title('Cuboids Segmentation')
275
+ axes[1, 1].axis('off')
276
+
277
+ axes[1, 2].text(0.5, 0.5, 'NOT AVAILABLE',
278
+ horizontalalignment='center', verticalalignment='center',
279
+ transform=axes[1, 2].transAxes, fontsize=14, fontweight='bold')
280
+ axes[1, 2].set_title('Segmentation Legend')
281
+ axes[1, 2].axis('off')
282
+
283
+ # BEV Cuboids segmentation masks with legend (2,0 and 2,1)
284
+ if has_cuboids_segmasks_bev:
285
+ segmask_bev = batch["cuboids_segmasks_bev"][0].float().cpu().numpy() # Shape: (n_subjects, h, w)
286
+ n_subjects_bev, h_bev, w_bev = segmask_bev.shape
287
+
288
+ # Create colored segmentation visualization for BEV (use different seed for different colors)
289
+ np.random.seed(123) # Different seed for BEV colors
290
+ colors_bev = np.random.rand(n_subjects_bev + 1, 3) # +1 for background
291
+ colors_bev[0] = [0, 0, 0] # Background is black
292
+
293
+ # Create RGB image from BEV segmentation
294
+ colored_segmask_bev = np.zeros((h_bev, w_bev, 3))
295
+ for subject_idx in range(n_subjects_bev):
296
+ mask_bev = segmask_bev[subject_idx] > 0.5 # Binary threshold
297
+ colored_segmask_bev[mask_bev] = colors_bev[subject_idx + 1]
298
+
299
+ axes[2, 0].imshow(colored_segmask_bev)
300
+ axes[2, 0].set_title('BEV Cuboids Segmentation')
301
+ axes[2, 0].axis('off')
302
+
303
+ # Create BEV legend in the next subplot (2,1)
304
+ axes[2, 1].set_xlim(0, 1)
305
+ axes[2, 1].set_ylim(0, 1)
306
+
307
+ # Add BEV legend entries
308
+ legend_y_positions_bev = np.linspace(0.9, 0.1, n_subjects_bev + 1)
309
+ axes[2, 1].text(0.1, legend_y_positions_bev[0], f"Background",
310
+ color=colors_bev[0], fontsize=12, fontweight='bold')
311
+
312
+ for subject_idx in range(n_subjects_bev):
313
+ axes[2, 1].text(0.1, legend_y_positions_bev[subject_idx + 1],
314
+ f"Subject {subject_idx}",
315
+ color=colors_bev[subject_idx + 1], fontsize=12, fontweight='bold')
316
+
317
+ axes[2, 1].set_title('BEV Segmentation Legend')
318
+ axes[2, 1].axis('off')
319
+ else:
320
+ axes[2, 0].text(0.5, 0.5, 'NOT AVAILABLE',
321
+ horizontalalignment='center', verticalalignment='center',
322
+ transform=axes[2, 0].transAxes, fontsize=14, fontweight='bold')
323
+ axes[2, 0].set_title('BEV Cuboids Segmentation')
324
+ axes[2, 0].axis('off')
325
+
326
+ axes[2, 1].text(0.5, 0.5, 'NOT AVAILABLE',
327
+ horizontalalignment='center', verticalalignment='center',
328
+ transform=axes[2, 1].transAxes, fontsize=14, fontweight='bold')
329
+ axes[2, 1].set_title('BEV Segmentation Legend')
330
+ axes[2, 1].axis('off')
331
+
332
+ # Text prompt and call ID (2,2)
333
+ axes[2, 2].text(0.5, 0.5, f'Text Prompt:\n\n"{text_prompt}"\n\nCall ID: {call_id}',
334
+ horizontalalignment='center', verticalalignment='center',
335
+ transform=axes[2, 2].transAxes, fontsize=12, wrap=True)
336
+ axes[2, 2].set_title('Text Prompt & Call ID')
337
+ axes[2, 2].axis('off')
338
+
339
+ # Pixel values info (3,0)
340
+ pixel_info = f'Pixel Values Shape: {batch["pixel_values"].shape}\n'
341
+ if has_spatial_condition:
342
+ pixel_info += f'Spatial Shape: {batch["cond_pixel_values"].shape}\n'
343
+ if has_subject_condition:
344
+ pixel_info += f'Subject Shape: {batch["subject_pixel_values"].shape}\n'
345
+ if has_cuboids_segmasks:
346
+ pixel_info += f'Cuboids Segmasks: {len(batch["cuboids_segmasks"])}\n'
347
+ if has_cuboids_segmasks_bev:
348
+ pixel_info += f'BEV Segmasks: {len(batch["cuboids_segmasks_bev"])}'
349
+
350
+ axes[3, 0].text(0.5, 0.5, pixel_info,
351
+ horizontalalignment='center', verticalalignment='center',
352
+ transform=axes[3, 0].transAxes, fontsize=10, fontfamily='monospace')
353
+ axes[3, 0].set_title('Tensor Shapes')
354
+ axes[3, 0].axis('off')
355
+
356
+ # Training info (3,1)
357
+ training_info = f'Global Step: {global_step}\nConditions:\nSpatial: {"✓" if has_spatial_condition else "✗"}\nSubject: {"✓" if has_subject_condition else "✗"}\nSegmasks: {"✓" if has_cuboids_segmasks else "✗"}\nBEV Segmasks: {"✓" if has_cuboids_segmasks_bev else "✗"}'
358
+ axes[3, 1].text(0.5, 0.5, training_info,
359
+ horizontalalignment='center', verticalalignment='center',
360
+ transform=axes[3, 1].transAxes, fontsize=12, fontfamily='monospace')
361
+ axes[3, 1].set_title('Training Info')
362
+ axes[3, 1].axis('off')
363
+
364
+ # Additional info (3,2) - can be used for any extra debugging info
365
+ axes[3, 2].text(0.5, 0.5, 'Additional Info\n(Reserved)',
366
+ horizontalalignment='center', verticalalignment='center',
367
+ transform=axes[3, 2].transAxes, fontsize=12, fontfamily='monospace')
368
+ axes[3, 2].set_title('Reserved')
369
+ axes[3, 2].axis('off')
370
+
371
+ plt.tight_layout()
372
+
373
+ # Save the visualization
374
+ save_dir = os.path.join(args.output_dir, "visualizations")
375
+ os.makedirs(save_dir, exist_ok=True)
376
+ save_path = os.path.join(save_dir, f"training_vis_step_{global_step}.png")
377
+ plt.savefig(save_path, dpi=150, bbox_inches='tight')
378
+ plt.close()
379
+
380
+ logger.info(f"Training visualization saved to {save_path}")
381
+
382
+ vae = vae.to(vae_dtype)
383
+
384
+ def log_validation(
385
+ pipeline,
386
+ args,
387
+ accelerator,
388
+ pipeline_args,
389
+ step,
390
+ torch_dtype,
391
+ is_final_validation=False,
392
+ ):
393
+ logger.info(
394
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
395
+ f" {pipeline_args['prompt']}."
396
+ )
397
+ pipeline = pipeline.to(accelerator.device)
398
+ pipeline.set_progress_bar_config(disable=True)
399
+ # run inference
400
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
401
+ # autocast_ctx = torch.autocast(accelerator.device.type) if not is_final_validation else nullcontext()
402
+ autocast_ctx = nullcontext()
403
+
404
+ with autocast_ctx:
405
+ images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
406
+
407
+ # for tracker in accelerator.trackers:
408
+ # phase_name = "test" if is_final_validation else "validation"
409
+ # if tracker.name == "tensorboard":
410
+ # np_images = np.stack([np.asarray(img) for img in images])
411
+ # tracker.writer.add_images(phase_name, np_images, step, dataformats="NHWC")
412
+ # if tracker.name == "wandb":
413
+ # tracker.log(
414
+ # {
415
+ # phase_name: [
416
+ # wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
417
+ # ]
418
+ # },
419
+ # )
420
+
421
+ return images
422
+
423
+
424
+ def import_model_class_from_model_name_or_path(
425
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
426
+ ):
427
+ text_encoder_config = PretrainedConfig.from_pretrained(
428
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
429
+ )
430
+ model_class = text_encoder_config.architectures[0]
431
+ if model_class == "CLIPTextModel":
432
+ from transformers import CLIPTextModel
433
+
434
+ return CLIPTextModel
435
+ elif model_class == "T5EncoderModel":
436
+ from transformers import T5EncoderModel
437
+
438
+ return T5EncoderModel
439
+ else:
440
+ raise ValueError(f"{model_class} is not supported.")
441
+
442
+
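For reference, a hypothetical wiring of this helper with load_text_encoders from src/prompt_helper.py, mirroring the diffusers DreamBooth-style scripts this file follows (args comes from parse_args() below; the actual call site appears later in the script):

text_encoder_cls_one = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder"
)
text_encoder_cls_two = import_model_class_from_model_name_or_path(
    args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
)
text_encoder_one, text_encoder_two = load_text_encoders(args, text_encoder_cls_one, text_encoder_cls_two)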
443
+ def parse_args(input_args=None):
444
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
445
+ parser.add_argument("--lora_num", type=int, default=2, help="Number of LoRA adapters.")
446
+ parser.add_argument("--cond_size", type=int, default=512, help="size of the condition data.")
447
+ parser.add_argument("--test_h", type=int, default=1024, help="Height (in pixels) of the validation/test images.")
448
+ parser.add_argument("--debug", type=int, default=0, help="whether to enter debug mode -- visualizations, gradient checks, etc.")
449
+ parser.add_argument("--test_w", type=int, default=1024, help="Width (in pixels) of the validation/test images.")
450
+ parser.add_argument("--mode", type=str, default=None, help="The mode of the controller. Choose between ['depth', 'pose', 'canny'].")
451
+ parser.add_argument("--run_name", type=str, required=True, help="The name of the wandb run.")
452
+ parser.add_argument(
453
+ "--train_data_dir",
454
+ type=str,
455
+ default="",
456
+ help=(
457
+ "A folder containing the training data. Folder contents must follow the structure described in"
458
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
459
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
460
+ ),
461
+ )
462
+ parser.add_argument(
463
+ "--inference_embeds_dir",
464
+ type=str,
465
+ default="",
466
+ help=(
467
+ "Directory of cached prompt embeddings used by encode_prompt (see src/prompt_helper.py)."
468
+ ),
469
+ )
470
+ parser.add_argument(
471
+ "--pretrained_model_name_or_path",
472
+ type=str,
473
+ default="",
474
+ required=False,
475
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
476
+ )
477
+ parser.add_argument(
478
+ "--pretrained_lora_path",
479
+ type=str,
480
+ default=None,
481
+ required=False,
482
+ help="Path to pretrained model",
483
+ )
484
+ parser.add_argument(
485
+ "--revision",
486
+ type=str,
487
+ default=None,
488
+ required=False,
489
+ help="Revision of pretrained model identifier from huggingface.co/models.",
490
+ )
491
+ parser.add_argument(
492
+ "--variant",
493
+ type=str,
494
+ default=None,
495
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
496
+ )
497
+ parser.add_argument(
498
+ "--spatial_column",
499
+ type=str,
500
+ default="None",
501
+ help="The column of the dataset containing the canny image. By "
502
+ "default, the standard Image Dataset maps out 'file_name' "
503
+ "to 'image'.",
504
+ )
505
+ parser.add_argument(
506
+ "--subject_column",
507
+ type=str,
508
+ default="image",
509
+ help="The column of the dataset containing the subject image. By "
510
+ "default, the standard Image Dataset maps out 'file_name' "
511
+ "to 'image'.",
512
+ )
513
+ parser.add_argument(
514
+ "--target_column",
515
+ type=str,
516
+ default="image",
517
+ help="The column of the dataset containing the target image. By "
518
+ "default, the standard Image Dataset maps out 'file_name' "
519
+ "to 'image'.",
520
+ )
521
+ parser.add_argument(
522
+ "--caption_column",
523
+ type=str,
524
+ default="caption_left,caption_right",
525
+ help="The column of the dataset containing the instance prompt for each image",
526
+ )
527
+ parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
528
+ parser.add_argument(
529
+ "--max_sequence_length",
530
+ type=int,
531
+ default=512,
532
+ help="Maximum sequence length to use with the T5 text encoder",
533
+ )
534
+ parser.add_argument(
535
+ "--validation_prompt",
536
+ type=str,
537
+ nargs="+",
538
+ default="A woodenpot floating in a pool.",
539
+ help="A prompt that is used during validation to verify that the model is learning.",
540
+ )
541
+ parser.add_argument(
542
+ "--subject_test_images",
543
+ type=str,
544
+ nargs="+",
545
+ default=["/tiamat-NAS/zhangyuxuan/datasets/benchmark_dataset/decoritems_woodenpot/0.png"],
546
+ help="A list of subject test image paths.",
547
+ )
548
+ parser.add_argument(
549
+ "--spatial_test_images",
550
+ type=str,
551
+ nargs="+",
552
+ default=[],
553
+ help="A list of spatial test image paths.",
554
+ )
555
+ parser.add_argument(
556
+ "--num_validation_images",
557
+ type=int,
558
+ default=4,
559
+ help="Number of images that should be generated during validation with `validation_prompt`.",
560
+ )
561
+ parser.add_argument(
562
+ "--validation_steps",
563
+ type=int,
564
+ default=20,
565
+ help=(
566
+ "Run validation every X steps. Validation consists of running the prompt"
567
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
568
+ ),
569
+ )
570
+ parser.add_argument(
571
+ "--ranks",
572
+ type=int,
573
+ nargs="+",
574
+ default=[128],
575
+ help=("The dimension of the LoRA update matrices."),
576
+ )
577
+ parser.add_argument(
578
+ "--network_alphas",
579
+ type=int,
580
+ nargs="+",
581
+ default=[128],
582
+ help=("The dimension of the LoRA update matrices."),
583
+ )
584
+ parser.add_argument(
585
+ "--output_dir",
586
+ type=str,
587
+ required=True,
588
+ help="The output directory where the model predictions and checkpoints will be written.",
589
+ )
590
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
591
+ parser.add_argument(
592
+ "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
593
+ )
594
+ parser.add_argument("--num_train_epochs", type=int, default=50)
595
+ parser.add_argument(
596
+ "--max_train_steps",
597
+ type=int,
598
+ default=None,
599
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
600
+ )
601
+ parser.add_argument(
602
+ "--checkpointing_steps",
603
+ type=int,
604
+ default=1000,
605
+ help=(
606
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
607
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
608
+ " training using `--resume_from_checkpoint`."
609
+ ),
610
+ )
611
+ parser.add_argument(
612
+ "--checkpoints_total_limit",
613
+ type=int,
614
+ default=None,
615
+ help=("Max number of checkpoints to store."),
616
+ )
617
+ parser.add_argument(
618
+ "--resume_from_checkpoint",
619
+ type=str,
620
+ default=None,
621
+ help=(
622
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
623
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
624
+ ),
625
+ )
626
+ parser.add_argument(
627
+ "--gradient_accumulation_steps",
628
+ type=int,
629
+ default=1,
630
+ help="Number of update steps to accumulate before performing a backward/update pass.",
631
+ )
632
+ parser.add_argument(
633
+ "--gradient_checkpointing",
634
+ action="store_true",
635
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
636
+ )
637
+ parser.add_argument(
638
+ "--learning_rate",
639
+ type=float,
640
+ default=1e-4,
641
+ help="Initial learning rate (after the potential warmup period) to use.",
642
+ )
643
+
644
+ parser.add_argument(
645
+ "--guidance_scale",
646
+ type=float,
647
+ default=1,
648
+ help="the FLUX.1 dev variant is a guidance distilled model",
649
+ )
650
+ parser.add_argument(
651
+ "--scale_lr",
652
+ action="store_true",
653
+ default=False,
654
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
655
+ )
656
+ parser.add_argument(
657
+ "--lr_scheduler",
658
+ type=str,
659
+ default="constant",
660
+ help=(
661
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
662
+ ' "constant", "constant_with_warmup"]'
663
+ ),
664
+ )
665
+ parser.add_argument(
666
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
667
+ )
668
+ parser.add_argument(
669
+ "--lr_num_cycles",
670
+ type=int,
671
+ default=1,
672
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
673
+ )
674
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
675
+ parser.add_argument(
676
+ "--dataloader_num_workers",
677
+ type=int,
678
+ default=2,
679
+ help=(
680
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
681
+ ),
682
+ )
683
+ parser.add_argument(
684
+ "--weighting_scheme",
685
+ type=str,
686
+ default="none",
687
+ choices=["sigma_sqrt", "logit_normal", "mode", "cosmap", "none"],
688
+ help=('We default to the "none" weighting scheme for uniform sampling and uniform loss'),
689
+ )
690
+ parser.add_argument(
691
+ "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme."
692
+ )
693
+ parser.add_argument(
694
+ "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme."
695
+ )
696
+ parser.add_argument(
697
+ "--mode_scale",
698
+ type=float,
699
+ default=1.29,
700
+ help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
701
+ )
702
+ parser.add_argument(
703
+ "--optimizer",
704
+ type=str,
705
+ default="AdamW",
706
+ help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
707
+ )
708
+
709
+ parser.add_argument(
710
+ "--use_8bit_adam",
711
+ action="store_true",
712
+ help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
713
+ )
714
+
715
+ parser.add_argument(
716
+ "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
717
+ )
718
+ parser.add_argument(
719
+ "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
720
+ )
721
+ parser.add_argument(
722
+ "--prodigy_beta3",
723
+ type=float,
724
+ default=None,
725
+ help="coefficients for computing the Prodigy stepsize using running averages. If set to None, "
726
+ "uses the value of square root of beta2. Ignored if optimizer is adamW",
727
+ )
728
+ parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
729
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
730
+ parser.add_argument(
731
+ "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
732
+ )
733
+
734
+ parser.add_argument(
735
+ "--adam_epsilon",
736
+ type=float,
737
+ default=1e-08,
738
+ help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
739
+ )
740
+
741
+ parser.add_argument(
742
+ "--prodigy_use_bias_correction",
743
+ type=bool,
744
+ default=True,
745
+ help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
746
+ )
747
+ parser.add_argument(
748
+ "--prodigy_safeguard_warmup",
749
+ type=bool,
750
+ default=True,
751
+ help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
752
+ "Ignored if optimizer is adamW",
753
+ )
754
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
755
+ parser.add_argument(
756
+ "--logging_dir",
757
+ type=str,
758
+ default="logs",
759
+ help=(
760
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
761
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
762
+ ),
763
+ )
764
+ parser.add_argument(
765
+ "--cache_latents",
766
+ action="store_true",
767
+ default=False,
768
+ help="Cache the VAE latents",
769
+ )
770
+ parser.add_argument(
771
+ "--report_to",
772
+ type=str,
773
+ default="wandb",
774
+ help=(
775
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
776
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
777
+ ),
778
+ )
779
+ parser.add_argument(
780
+ "--mixed_precision",
781
+ type=str,
782
+ default="bf16",
783
+ choices=["no", "fp16", "bf16"],
784
+ help=(
785
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
786
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
787
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
788
+ ),
789
+ )
790
+ parser.add_argument(
791
+ "--upcast_before_saving",
792
+ action="store_true",
793
+ default=False,
794
+ help=(
795
+ "Whether to upcast the trained transformer layers to float32 before saving (at the end of training). "
796
+ "Defaults to precision dtype used for training to save memory"
797
+ ),
798
+ )
799
+
800
+ if input_args is not None:
801
+ args = parser.parse_args(input_args)
802
+ else:
803
+ args = parser.parse_args()
804
+ return args
805
+
806
+
807
+ def main(args):
808
+ if torch.backends.mps.is_available() and args.mixed_precision == "bf16":
809
+ # due to pytorch#99272, MPS does not yet support bfloat16.
810
+ raise ValueError(
811
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
812
+ )
813
+
814
+ if args.pretrained_lora_path is not None:
815
+ assert osp.exists(args.pretrained_lora_path), f"Make sure that the `pretrained_lora_path` {args.pretrained_lora_path} exists."
816
+ args.resume_from_checkpoint = osp.dirname(args.pretrained_lora_path)
817
+
818
+ args.output_dir = osp.join(args.output_dir, args.run_name)
819
+ args.logging_dir = osp.join(args.output_dir, args.logging_dir)
820
+ os.makedirs(args.output_dir, exist_ok=True)
821
+ os.makedirs(args.logging_dir, exist_ok=True)
822
+ logging_dir = Path(args.output_dir, args.logging_dir)
823
+
824
+ if args.subject_column == "None":
825
+ args.subject_column = None
826
+ if args.spatial_column == "None":
827
+ args.spatial_column = None
828
+
829
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
830
+ # kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
831
+ accelerator = Accelerator(
832
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
833
+ mixed_precision=args.mixed_precision,
834
+ log_with=args.report_to,
835
+ project_config=accelerator_project_config,
836
+ # kwargs_handlers=[kwargs],
837
+ )
838
+
839
+ def save_model_hook(models, weights, output_dir):
840
+ pass
841
+
842
+ def load_model_hook(models, input_dir):
843
+ pass
844
+
845
+ # Disable AMP for MPS.
846
+ if torch.backends.mps.is_available():
847
+ accelerator.native_amp = False
848
+
849
+ if args.report_to == "wandb":
850
+ if not is_wandb_available():
851
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
852
+
853
+ # Make one log on every process with the configuration for debugging.
854
+ logging.basicConfig(
855
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
856
+ datefmt="%m/%d/%Y %H:%M:%S",
857
+ level=logging.INFO,
858
+ )
859
+ logger.info(accelerator.state, main_process_only=False)
860
+ if accelerator.is_local_main_process:
861
+ transformers.utils.logging.set_verbosity_warning()
862
+ diffusers.utils.logging.set_verbosity_info()
863
+ else:
864
+ transformers.utils.logging.set_verbosity_error()
865
+ diffusers.utils.logging.set_verbosity_error()
866
+
867
+ # If passed along, set the training seed now.
868
+ if args.seed is not None:
869
+ set_seed(args.seed)
870
+
871
+ # Handle the repository creation
872
+ if accelerator.is_main_process:
873
+ if args.output_dir is not None:
874
+ os.makedirs(args.output_dir, exist_ok=True)
875
+
876
+ # Load the tokenizers
877
+ tokenizer_one = CLIPTokenizer.from_pretrained(
878
+ args.pretrained_model_name_or_path,
879
+ subfolder="tokenizer",
880
+ revision=args.revision,
881
+ )
882
+ tokenizer_two = T5TokenizerFast.from_pretrained(
883
+ args.pretrained_model_name_or_path,
884
+ subfolder="tokenizer_2",
885
+ revision=args.revision,
886
+ )
887
+
888
+ # Load scheduler and models
889
+ noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
890
+ args.pretrained_model_name_or_path, subfolder="scheduler"
891
+ )
892
+ noise_scheduler_copy = copy.deepcopy(noise_scheduler)
893
+ gc.collect()
894
+ torch.cuda.empty_cache()
895
+ vae = AutoencoderKL.from_pretrained(
896
+ args.pretrained_model_name_or_path,
897
+ subfolder="vae",
898
+ revision=args.revision,
899
+ variant=args.variant,
900
+ )
901
+ transformer = FluxTransformer2DModel.from_pretrained(
902
+ args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
903
+ )
904
+
905
+ # We only train the additional adapter LoRA layers; grads are enabled on the full transformer here, and all non-LoRA parameters are frozen after the LoRA processors are attached below.
906
+ transformer.requires_grad_(True)
907
+ vae.requires_grad_(False)
908
+
909
+ # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
910
+ # as these weights are only used for inference, keeping weights in full precision is not required.
911
+ weight_dtype = torch.float32
912
+ if accelerator.mixed_precision == "fp16":
913
+ weight_dtype = torch.float16
914
+ elif accelerator.mixed_precision == "bf16":
915
+ weight_dtype = torch.bfloat16
916
+
917
+ if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
918
+ # due to pytorch#99272, MPS does not yet support bfloat16.
919
+ raise ValueError(
920
+ "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
921
+ )
922
+
923
+ vae.to(accelerator.device, dtype=weight_dtype)
924
+ transformer.to(accelerator.device, dtype=weight_dtype)
925
+
926
+ if args.gradient_checkpointing:
927
+ transformer.enable_gradient_checkpointing()
928
+
929
+ #### lora_layers ####
930
+ if args.pretrained_lora_path is not None:
931
+ lora_path = args.pretrained_lora_path
932
+ checkpoint = load_checkpoint(lora_path)
933
+ lora_attn_procs = {}
934
+ double_blocks_idx = list(range(19))
935
+ single_blocks_idx = list(range(38))
936
+ number = 1
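+ # Note: with number = 1, only the LoRA at index 0 is restored from the checkpoint below,
+ # which matches --lora_num 1 as used in train.sh.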
937
+ for name, attn_processor in transformer.attn_processors.items():
938
+ match = re.search(r'\.(\d+)\.', name)
939
+ if match:
940
+ layer_index = int(match.group(1))
941
+
942
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
943
+ lora_state_dicts = {}
944
+ for key, value in checkpoint.items():
945
+ # Match based on the layer index in the key (assuming the key contains layer index)
946
+ if re.search(r'\.(\d+)\.', key):
947
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
948
+ if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
949
+ lora_state_dicts[key] = value
950
+
951
+ print("setting LoRA Processor for", name)
952
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
953
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
954
+ )
955
+
956
+ # Load the weights from the checkpoint dictionary into the corresponding layers
957
+ for n in range(number):
958
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
959
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
960
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
961
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
962
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
963
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
964
+ lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
965
+ lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
966
+
967
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
968
+
969
+ lora_state_dicts = {}
970
+ for key, value in checkpoint.items():
971
+ # Match based on the layer index in the key (assuming the key contains layer index)
972
+ if re.search(r'\.(\d+)\.', key):
973
+ checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
974
+ if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
975
+ lora_state_dicts[key] = value
976
+
977
+ print("setting LoRA Processor for", name)
978
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
979
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
980
+ )
981
+
982
+ # Load the weights from the checkpoint dictionary into the corresponding layers
983
+ for n in range(number):
984
+ lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
985
+ lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
986
+ lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
987
+ lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
988
+ lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
989
+ lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
990
+ else:
991
+ lora_attn_procs[name] = FluxAttnProcessor2_0()
992
+ else:
993
+ lora_attn_procs = {}
994
+ double_blocks_idx = list(range(19))
995
+ single_blocks_idx = list(range(38))
996
+ for name, attn_processor in transformer.attn_processors.items():
997
+ match = re.search(r'\.(\d+)\.', name)
998
+ if match:
999
+ layer_index = int(match.group(1))
1000
+ if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
1001
+ lora_state_dicts = {}
1002
+ print("setting LoRA Processor for", name)
1003
+ lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
1004
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
1005
+ )
1006
+
1007
+ elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
1008
+ print("setting LoRA Processor for", name)
1009
+ lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
1010
+ dim=3072, ranks=args.ranks, network_alphas=args.network_alphas, lora_weights=[1 for _ in range(args.lora_num)], device=accelerator.device, dtype=weight_dtype, cond_width=args.cond_size, cond_height=args.cond_size, n_loras=args.lora_num
1011
+ )
1012
+
1013
+ else:
1014
+ lora_attn_procs[name] = attn_processor
1015
+ ######################
1016
+ transformer.set_attn_processor(lora_attn_procs)
1017
+ transformer.train()
1018
+ for n, param in transformer.named_parameters():
1019
+ if '_lora' not in n:
1020
+ param.requires_grad = False
1021
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
1022
+
1023
+ def unwrap_model(model):
1024
+ model = accelerator.unwrap_model(model)
1025
+ model = model._orig_mod if is_compiled_module(model) else model
1026
+ return model
1027
+
1028
+ # Potentially load in the weights and states from a previous save
1029
+ if args.resume_from_checkpoint:
1030
+ foldername = osp.basename(args.resume_from_checkpoint)
1031
+ first_epoch = epoch = int(foldername.split("-")[1].split("__")[0])
1032
+ initial_global_step = global_step = int(foldername.split("-")[-1])
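+ # This parsing assumes the checkpoint folder follows the naming used when saving below,
+ # e.g. "epoch-3__checkpoint-2500" -> first_epoch = 3, initial_global_step = 2500.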
1033
+ else:
1034
+ initial_global_step = 0
1035
+ global_step = 0
1036
+ first_epoch = 0
1037
+
1038
+ if args.scale_lr:
1039
+ args.learning_rate = (
1040
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1041
+ )
1042
+
1043
+ # Make sure the trainable params are in float32.
1044
+ if args.mixed_precision == "fp16":
1045
+ models = [transformer]
1046
+ # only upcast trainable parameters (LoRA) into fp32
1047
+ cast_training_params(models, dtype=torch.float32)
1048
+
1049
+ # Optimization parameters
1050
+ params_to_optimize = [p for p in transformer.parameters() if p.requires_grad]
1051
+ transformer_parameters_with_lr = {"params": params_to_optimize, "lr": args.learning_rate}
1052
+ print(sum([p.numel() for p in transformer.parameters() if p.requires_grad]) / 1000000, 'M parameters')
1053
+
1054
+ optimizer_class = torch.optim.AdamW
1055
+ optimizer = optimizer_class(
1056
+ [transformer_parameters_with_lr],
1057
+ betas=(args.adam_beta1, args.adam_beta2),
1058
+ weight_decay=args.adam_weight_decay,
1059
+ eps=args.adam_epsilon,
1060
+ )
1061
+
1062
+ tokenizers = [tokenizer_one, tokenizer_two]
1063
+
1064
+ # # Dataset and DataLoaders creation:
1065
+ # train_dataset = make_train_dataset(args, tokenizers, accelerator)
1066
+ # train_dataloader = torch.utils.data.DataLoader(
1067
+ # train_dataset,
1068
+ # batch_size=args.train_batch_size,
1069
+ # shuffle=True,
1070
+ # collate_fn=collate_fn,
1071
+ # num_workers=args.dataloader_num_workers,
1072
+ # )
1073
+
1074
+ # now, we will define a dataset for each epoch to make it easier to save the state
1075
+ shuffled_jsonls = os.listdir(osp.dirname(args.train_data_dir))
1076
+ base_jsonl_name = osp.basename(args.train_data_dir).replace(".jsonl", "")
1077
+ shuffled_jsonls = sorted([_ for _ in shuffled_jsonls if _.endswith('.jsonl') and "shuffled" in _ and base_jsonl_name in _])
1078
+ shuffled_jsonls = [osp.join(osp.dirname(args.train_data_dir), _) for _ in shuffled_jsonls]
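+ # The shuffled epoch files are expected to sit next to --train_data_dir and to contain both the base
+ # name and the word "shuffled", e.g. (hypothetical) "cuboids__upto_4subjects_shuffled_000.jsonl".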
1079
+ print(f"{shuffled_jsonls = }")
1080
+ # exit(0)
1081
+ assert len(shuffled_jsonls) > 0, f"Make sure that there are shuffled jsonl files in {osp.dirname(args.train_data_dir)}"
1082
+ train_dataloaders = []
1083
+ for epoch in range(args.num_train_epochs): # prepare dataloader for each epoch, irrespective of the resume state
1084
+ shuffled_idx = epoch % len(shuffled_jsonls)
1085
+ train_data_file = shuffled_jsonls[shuffled_idx]
1086
+ assert osp.exists(train_data_file), f"Make sure that the train data jsonl file {train_data_file} exists."
1087
+ args.current_train_data_dir = train_data_file
1088
+ train_dataset = make_train_dataset(args, tokenizers, accelerator)
1089
+ train_dataloader = torch.utils.data.DataLoader(
1090
+ train_dataset,
1091
+ batch_size=args.train_batch_size,
1092
+ shuffle=False, # yayy!! reproducible experiments!
1093
+ collate_fn=collate_fn,
1094
+ num_workers=args.dataloader_num_workers,
1095
+ )
1096
+ train_dataloaders.append(train_dataloader)
1097
+
1098
+ vae_config_shift_factor = vae.config.shift_factor
1099
+ vae_config_scaling_factor = vae.config.scaling_factor
1100
+
1101
+ # Scheduler and math around the number of training steps.
1102
+ overrode_max_train_steps = False
1103
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1104
+ if args.max_train_steps is None:
1105
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1106
+ overrode_max_train_steps = True
1107
+
1108
+ lr_scheduler = get_scheduler(
1109
+ args.lr_scheduler,
1110
+ optimizer=optimizer,
1111
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
1112
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
1113
+ num_cycles=args.lr_num_cycles,
1114
+ power=args.lr_power,
1115
+ )
1116
+
1117
+
1118
+ accelerator.register_save_state_pre_hook(save_model_hook)
1119
+ accelerator.register_load_state_pre_hook(load_model_hook)
1120
+ optimizer, lr_scheduler = accelerator.prepare(
1121
+ optimizer, lr_scheduler
1122
+ )
1123
+
1124
+ print(f"before preparation, {len(train_dataloaders[0]) = }")
1125
+
1126
+ prepared_train_dataloaders = []
1127
+ for train_dataloader in train_dataloaders:
1128
+ prepared_train_dataloaders.append(accelerator.prepare(train_dataloader))
1129
+ train_dataloaders = prepared_train_dataloaders
1130
+
1131
+ print(f"after preparation, {len(train_dataloaders[0]) = }")
1132
+
1133
+ if args.pretrained_lora_path is not None:
1134
+ accelerator.load_state(osp.dirname(args.pretrained_lora_path))
1135
+
1136
+ # Explicitly move optimizer states to accelerator.device
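+ # (Depending on how accelerate restores the state, these tensors may come back on CPU; moving them
+ # onto accelerator.device avoids a device-mismatch error on the first optimizer.step().)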
1137
+ for state in optimizer.state.values():
1138
+ for k, v in state.items():
1139
+ if isinstance(v, torch.Tensor):
1140
+ state[k] = v.to(accelerator.device)
1141
+
1142
+ transformer = accelerator.prepare(transformer)
1143
+
1144
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1145
+ num_update_steps_per_epoch = math.ceil(len(train_dataloaders[0]) / args.gradient_accumulation_steps)
1146
+ if overrode_max_train_steps:
1147
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1148
+ # Afterwards we recalculate our number of training epochs
1149
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1150
+
1151
+ # We need to initialize the trackers we use, and also store our configuration.
1152
+ # The trackers initialize automatically on the main process.
1153
+ # if accelerator.is_main_process:
1154
+ # tracker_name = "Easy_Control"
1155
+ # accelerator.init_trackers(tracker_name, config=vars(args))
1156
+
1157
+ if accelerator.is_main_process:
1158
+ tracker_config = vars(copy.deepcopy(args))
1159
+ # tracker_config.pop("validation_images")
1160
+ wandb_args = {
1161
+ "wandb": {
1162
+ "entity": "generative_parts",
1163
+ "name": args.run_name,
1164
+ }
1165
+ }
1166
+ accelerator.init_trackers("seethrough3d", config=tracker_config, init_kwargs=wandb_args)
1167
+
1168
+
1169
+ # Train!
1170
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1171
+
1172
+ logger.info("***** Running training *****")
1173
+ logger.info(f" Num examples = {len(train_dataset)}")
1174
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
1175
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1176
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1177
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1178
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1179
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1180
+
1181
+ progress_bar = tqdm(
1182
+ range(0, args.max_train_steps),
1183
+ initial=initial_global_step,
1184
+ desc="Steps",
1185
+ # Only show the progress bar once on each machine.
1186
+ disable=not accelerator.is_local_main_process,
1187
+ )
1188
+
1189
+ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
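+ # Helper: look up the flow-matching sigma for each sampled timestep and unsqueeze it so it
+ # broadcasts against the latent tensor (n_dim=4 -> shape [bsz, 1, 1, 1]).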
1190
+ sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype)
1191
+ schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device)
1192
+ timesteps = timesteps.to(accelerator.device)
1193
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
1194
+
1195
+ sigma = sigmas[step_indices].flatten()
1196
+ while len(sigma.shape) < n_dim:
1197
+ sigma = sigma.unsqueeze(-1)
1198
+ return sigma
1199
+
1200
+ # some fixed parameters
1201
+ vae_scale_factor = 16
1202
+ height_cond = 2 * (args.cond_size // vae_scale_factor)
1203
+ width_cond = 2 * (args.cond_size // vae_scale_factor)
1204
+ offset = 64
1205
+
1206
+ num_training_visualizations = 10
1207
+
1208
+ skip_steps = initial_global_step - first_epoch * num_update_steps_per_epoch
1209
+ print(f"{skip_steps = }")
1210
+ for epoch in range(first_epoch, args.num_train_epochs):
1211
+ transformer.train()
1212
+ train_dataloader = train_dataloaders[epoch] # use a new dataloader for each epoch
1213
+ if epoch == first_epoch and skip_steps > 0:
1214
+ logger.info(f"Skipping {skip_steps} batches in epoch {epoch} due to resuming from checkpoint")
1215
+ # dataloader_iterator = skip_first_batches_manual(train_dataloader, skip_steps)
1216
+ dataloader_iterator = accelerator.skip_first_batches(train_dataloader, skip_steps)
1217
+ # Convert back to enumerate format
1218
+ enumerated_dataloader = enumerate(dataloader_iterator, start=skip_steps)
1219
+ else:
1220
+ enumerated_dataloader = enumerate(train_dataloader)
1221
+ for step, batch in enumerated_dataloader:
1222
+ progress_bar.set_description(f"epoch {epoch}, dataset_ids: {batch['index']}")
1223
+ torch.cuda.empty_cache()
1224
+ models_to_accumulate = [transformer]
1225
+ with accelerator.accumulate(models_to_accumulate):
1226
+
1227
+ # tokens = [batch["text_ids_1"], batch["text_ids_2"]]
1228
+ # prompt_embeds, pooled_prompt_embeds, text_ids = encode_token_ids(text_encoders, tokens, accelerator)
1229
+ prompt_embeds = batch["prompt_embeds"]
1230
+ pooled_prompt_embeds = batch["pooled_prompt_embeds"]
1231
+ text_ids = torch.zeros((batch["prompt_embeds"].shape[1], 3))
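+ # As in the diffusers Flux pipeline, the text token position ids are all zeros with 3 coordinates,
+ # matching the sequence length of the precomputed prompt embeddings.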
1232
+ prompt_embeds = prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1233
+ pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=vae.dtype, device=accelerator.device)
1234
+ text_ids = text_ids.to(dtype=vae.dtype, device=accelerator.device)
1235
+
1236
+ pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
1237
+ height_ = 2 * (int(pixel_values.shape[-2]) // vae_scale_factor)
1238
+ width_ = 2 * (int(pixel_values.shape[-1]) // vae_scale_factor)
1239
+
1240
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1241
+ model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
1242
+ model_input = model_input.to(dtype=weight_dtype)
1243
+
1244
+ latent_image_ids, cond_latent_image_ids = resize_position_encoding(
1245
+ model_input.shape[0],
1246
+ height_,
1247
+ width_,
1248
+ height_cond,
1249
+ width_cond,
1250
+ accelerator.device,
1251
+ weight_dtype,
1252
+ )
1253
+
1254
+ # Sample noise that we'll add to the latents
1255
+ noise = torch.randn_like(model_input)
1256
+ bsz = model_input.shape[0]
1257
+
1258
+ # Sample a random timestep for each image
1259
+ # for weighting schemes where we sample timesteps non-uniformly
1260
+ u = compute_density_for_timestep_sampling(
1261
+ weighting_scheme=args.weighting_scheme,
1262
+ batch_size=bsz,
1263
+ logit_mean=args.logit_mean,
1264
+ logit_std=args.logit_std,
1265
+ mode_scale=args.mode_scale,
1266
+ )
1267
+ indices = (u * noise_scheduler_copy.config.num_train_timesteps).long()
1268
+ timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
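+ # u in [0, 1) is drawn according to --weighting_scheme ("none" -> uniform) and then mapped to a
+ # discrete index into the scheduler's timesteps.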
1269
+
1270
+ # Add noise according to flow matching.
1271
+ # zt = (1 - texp) * x + texp * z1
1272
+ sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
1273
+ noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
1274
+
1275
+ packed_noisy_model_input = FluxPipeline._pack_latents(
1276
+ noisy_model_input,
1277
+ batch_size=model_input.shape[0],
1278
+ num_channels_latents=model_input.shape[1],
1279
+ height=model_input.shape[2],
1280
+ width=model_input.shape[3],
1281
+ )
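+ # _pack_latents flattens the 2D latent grid into the token sequence consumed by the transformer
+ # (in the diffusers Flux pipeline this groups 2x2 latent patches into each token).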
1282
+
1283
+ latent_image_ids_to_concat = [latent_image_ids]
1284
+ packed_cond_model_input_to_concat = []
1285
+
1286
+ if args.subject_column is not None:
1287
+ # in case the condition is not spatial
1288
+ subject_pixel_values = batch["subject_pixel_values"].to(dtype=vae.dtype)
1289
+ subject_input = vae.encode(subject_pixel_values).latent_dist.sample()
1290
+ subject_input = (subject_input - vae_config_shift_factor) * vae_config_scaling_factor
1291
+ subject_input = subject_input.to(dtype=weight_dtype)
1292
+ # the number of subjects in the concatenated subject image
1293
+ sub_number = subject_pixel_values.shape[-2] // args.cond_size
1294
+ latent_subject_ids = prepare_latent_subject_ids(height_cond, width_cond, accelerator.device, weight_dtype)
1295
+ latent_subject_ids[:, 1] += offset
1296
+ sub_latent_image_ids = torch.concat([latent_subject_ids for _ in range(sub_number)], dim=-2)
1297
+ latent_image_ids_to_concat.append(sub_latent_image_ids)
1298
+
1299
+ packed_subject_model_input = FluxPipeline._pack_latents(
1300
+ subject_input,
1301
+ batch_size=subject_input.shape[0],
1302
+ num_channels_latents=subject_input.shape[1],
1303
+ height=subject_input.shape[2],
1304
+ width=subject_input.shape[3],
1305
+ )
1306
+ packed_cond_model_input_to_concat.append(packed_subject_model_input)
1307
+ else:
1308
+ subject_input = None
1309
+
1310
+ if args.spatial_column is not None:
1311
+ # in case the condition is spatial
1312
+ cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
1313
+ cond_input = vae.encode(cond_pixel_values).latent_dist.sample()
1314
+ cond_input = (cond_input - vae_config_shift_factor) * vae_config_scaling_factor
1315
+ cond_input = cond_input.to(dtype=weight_dtype)
1316
+ # number of conditions in the concatenated condition image
1317
+ cond_number = cond_pixel_values.shape[-2] // args.cond_size
1318
+ cond_latent_image_ids = torch.concat([cond_latent_image_ids for _ in range(cond_number)], dim=-2)
1319
+ latent_image_ids_to_concat.append(cond_latent_image_ids)
1320
+
1321
+ packed_cond_model_input = FluxPipeline._pack_latents(
1322
+ cond_input,
1323
+ batch_size=cond_input.shape[0],
1324
+ num_channels_latents=cond_input.shape[1],
1325
+ height=cond_input.shape[2],
1326
+ width=cond_input.shape[3],
1327
+ )
1328
+ packed_cond_model_input_to_concat.append(packed_cond_model_input)
1329
+ else:
1330
+ cond_input = None
1331
+
1332
+ latent_image_ids = torch.concat(latent_image_ids_to_concat, dim=-2)
1333
+ cond_packed_noisy_model_input = torch.concat(packed_cond_model_input_to_concat, dim=-2)
1334
+
1335
+ # handle guidance
1336
+ if accelerator.unwrap_model(transformer).config.guidance_embeds:
1337
+ guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
1338
+ guidance = guidance.expand(model_input.shape[0])
1339
+ else:
1340
+ guidance = None
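+ # FLUX.1-dev is guidance-distilled: when the transformer was trained with guidance embeddings,
+ # the guidance scale is fed in as a conditioning input rather than applied via classifier-free guidance.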
1341
+
1342
+ # Visualize training data before transformer forward pass
1343
+ if accelerator.is_main_process and args.debug and num_training_visualizations > 0 and global_step % 5 == 0:
1344
+ visualize_training_data(
1345
+ batch=batch,
1346
+ vae=vae,
1347
+ model_input=model_input,
1348
+ noisy_model_input=noisy_model_input,
1349
+ cond_input=cond_input,
1350
+ subject_input=subject_input,
1351
+ args=args,
1352
+ global_step=global_step,
1353
+ accelerator=accelerator
1354
+ )
1355
+ num_training_visualizations -= 1
1356
+
1357
+ # Predict the noise residual
1358
+ model_pred = transformer(
1359
+ hidden_states=packed_noisy_model_input,
1360
+ # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep this, but I want to keep the inputs to the model the same for testing)
1361
+ cond_hidden_states=cond_packed_noisy_model_input,
1362
+ timestep=timesteps / 1000,
1363
+ guidance=guidance,
1364
+ pooled_projections=pooled_prompt_embeds,
1365
+ encoder_hidden_states=prompt_embeds,
1366
+ txt_ids=text_ids,
1367
+ img_ids=latent_image_ids,
1368
+ return_dict=False,
1369
+ call_ids=batch["call_ids"],
1370
+ cuboids_segmasks=batch["cuboids_segmasks"],
1371
+ )[0]
1372
+
1373
+ model_pred = FluxPipeline._unpack_latents(
1374
+ model_pred,
1375
+ height=int(pixel_values.shape[-2]),
1376
+ width=int(pixel_values.shape[-1]),
1377
+ vae_scale_factor=vae_scale_factor,
1378
+ )
1379
+
1380
+ # these weighting schemes use a uniform timestep sampling
1381
+ # and instead post-weight the loss
1382
+ weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
1383
+
1384
+ # flow matching loss
1385
+ target = noise - model_input
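+ # With zt = (1 - sigma) * x + sigma * noise, the velocity target d(zt)/d(sigma) is (noise - x),
+ # so the transformer is trained to regress (noise - model_input).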
1386
+
1387
+ # Compute regular loss.
1388
+ loss = torch.mean(
1389
+ (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
1390
+ 1,
1391
+ )
1392
+
1393
+ loss = loss.mean()
1394
+ accelerator.backward(loss)
1395
+ if accelerator.sync_gradients:
1396
+ params_to_clip = (transformer.parameters())
1397
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1398
+
1399
+ optimizer.step()
1400
+ lr_scheduler.step()
1401
+ optimizer.zero_grad()
1402
+
1403
+ # Checks if the accelerator has performed an optimization step behind the scenes
1404
+ if accelerator.sync_gradients:
1405
+ progress_bar.update(1)
1406
+ global_step += 1
1407
+
1408
+ if accelerator.is_main_process:
1409
+ if global_step % args.checkpointing_steps == 0:
1410
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1411
+ if args.checkpoints_total_limit is not None:
1412
+ checkpoints = os.listdir(args.output_dir)
1413
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1414
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1415
+
1416
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1417
+ if len(checkpoints) >= args.checkpoints_total_limit:
1418
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1419
+ removing_checkpoints = checkpoints[0:num_to_remove]
1420
+
1421
+ logger.info(
1422
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1423
+ )
1424
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1425
+
1426
+ for removing_checkpoint in removing_checkpoints:
1427
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1428
+ shutil.rmtree(removing_checkpoint)
1429
+
1430
+ save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1431
+ os.makedirs(save_path, exist_ok=True)
1432
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
1433
+ lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
1434
+ save_file(
1435
+ lora_state_dict,
1436
+ os.path.join(save_path, "lora.safetensors")
1437
+ )
1438
+ accelerator.save_state(save_path)
1439
+ os.remove(osp.join(save_path, "model.safetensors"))
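+ # Only the LoRA weights (keys containing '_lora') are exported to lora.safetensors above; the full
+ # model.safetensors written by accelerator.save_state duplicates the transformer weights, so it is
+ # deleted here, presumably to keep checkpoints small.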
1440
+ logger.info(f"Saved state to {save_path}")
1441
+
1442
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
1443
+ progress_bar.set_postfix(**logs)
1444
+ accelerator.log(logs, step=global_step)
1445
+
1446
+ save_path = os.path.join(args.output_dir, f"epoch-{epoch}__checkpoint-{global_step}")
1447
+ os.makedirs(save_path, exist_ok=True)
1448
+ unwrapped_model_state = accelerator.unwrap_model(transformer).state_dict()
1449
+ lora_state_dict = {k:unwrapped_model_state[k] for k in unwrapped_model_state.keys() if '_lora' in k}
1450
+ save_file(
1451
+ lora_state_dict,
1452
+ os.path.join(save_path, "lora.safetensors")
1453
+ )
1454
+ accelerator.save_state(save_path)
1455
+ os.remove(osp.join(save_path, "model.safetensors"))
1456
+ logger.info(f"Saved state to {save_path}")
1457
+ accelerator.wait_for_everyone()
1458
+ accelerator.end_training()
1459
+
1460
+
1461
+ if __name__ == "__main__":
1462
+ args = parse_args()
1463
+ main(args)
train/train.sh ADDED
@@ -0,0 +1,71 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=vaibhav
3
+ #SBATCH --output=%j.out
4
+ #SBATCH --ntasks=1
5
+ #SBATCH --cpus-per-task=4
6
+ #SBATCH --mem=150G
7
+ #SBATCH --gres=gpu:4
8
+ #SBATCH --partition=ada
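+ # Submit with `sbatch train.sh` so the #SBATCH headers above take effect (4 GPUs on the "ada"
+ # partition), or run `bash train.sh` directly on an interactive node.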
9
+
10
+ # chetna
11
+ # export MODEL_DIR="black-forest-labs/FLUX.1-Kontext-dev" # your flux path
12
+ export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
13
+ export OUTPUT_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids" # your save path
14
+ export CONFIG="./default_config.yaml"
15
+ export TRAIN_DATA="/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard/cuboids__upto_4subjects.jsonl" # your data jsonl file
16
+ export LOG_PATH="$OUTPUT_DIR/log"
17
+ export INFERENCE_EMBEDS_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_datasetv7_superhard"
18
+
19
+ export WANDB_API_KEY=f27c837d8d7d0c8d79f3eb1de21fa78233c03be6
20
+
21
+ # kotak
22
+ # export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
23
+ # export OUTPUT_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids" # your save path
24
+ # export CONFIG="./default_config.yaml"
25
+ # export TRAIN_DATA="/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv6/cuboids.jsonl" # your data jsonl file
26
+ # export LOG_PATH="$OUTPUT_DIR/log"
27
+ # export INFERENCE_EMBEDS_DIR="/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_flux2"
28
+
29
+ # kotak
30
+ # export MODEL_DIR="black-forest-labs/FLUX.1-dev" # your flux path
31
+ # export OUTPUT_DIR="./easycontrol_cuboids" # your save path
32
+ # export CONFIG="./default_config.yaml"
33
+ # export TRAIN_DATA="/home/venky/vaibhav.agrawal/a-bev-of-the-latents/datasets/actual_data/datasetv6/cuboids.jsonl" # your data jsonl file
34
+ # export LOG_PATH="$OUTPUT_DIR/log"
35
+ # export INFERENCE_EMBEDS_DIR="/home/venky/vaibhav.agrawal/a-bev-of-the-latents/caching/inference_embeds_flux2"
36
+
37
+ # i love this.
38
+ accelerate launch --config_file $CONFIG train.py \
39
+ --pretrained_model_name_or_path $MODEL_DIR \
40
+ --cond_size=512 \
41
+ --subject_column="None" \
42
+ --spatial_column="cv" \
43
+ --target_column="target" \
44
+ --caption_column="caption" \
45
+ --ranks 128 \
46
+ --network_alphas 128 \
47
+ --lora_num 1 \
48
+ --output_dir=$OUTPUT_DIR \
49
+ --logging_dir=$LOG_PATH \
50
+ --run_name="rgb__r1" \
51
+ --debug=1 \
52
+ --mixed_precision="bf16" \
53
+ --train_data_dir=$TRAIN_DATA \
54
+ --learning_rate=1e-4 \
55
+ --train_batch_size=1 \
56
+ --inference_embeds_dir $INFERENCE_EMBEDS_DIR \
57
+ --validation_prompt "a photo of sedan and pickup truck and suv amongst autumn-colored trees along a winding river" "a photo of cow and suv on a sandy beach with palm trees swaying in the breeze" "a photo of table and horse and suv in a dense pine forest with tall trees reaching the sky" \
58
+ --num_train_epochs=1 \
59
+ --validation_steps=5000000000000 \
60
+ --checkpointing_steps=2500 \
61
+ --spatial_test_images "cuboids/sedan__pickup_truck__suv/005/cuboids.png" "cuboids/cow__suv/008/cuboids.png" "cuboids/table__horse__suv/007/cuboids.png" \
62
+ --subject_test_images None \
63
+ --test_h 512 \
64
+ --test_w 512 \
65
+ --num_validation_images=1
66
+
67
+ # --run_name="semantic_info_from_cuboid_cond" \
68
+ # --run_name="datasetv8__0.8_0.1_0.1" \
69
+ # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/wireframe/epoch-0__checkpoint-5000/lora.safetensors" \
70
+ # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/rgb/epoch-0__checkpoint-7500/lora.safetensors" \
71
+ # --pretrained_lora_path="/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/datasetv9__wireframe_best_case/epoch-0__checkpoint-3888/lora.safetensors" \
visualize_server.py ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flask web server to visualize inference results for 2-subject cases.
4
+ Port: 7023
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from flask import Flask, render_template, send_from_directory
11
+ import base64
12
+
13
+ app = Flask(__name__)
14
+
15
+ # Paths
16
+ DATASET_FILE = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/cuboids_segmentation.jsonl"
17
+ DATASET_ROOT = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval"
18
+ RESULTS_DIR = "/archive/vaibhav.agrawal/a-bev-of-the-latents/VAL/results/omini_seg_baseline_r2_epoch-0_checkpoint-20000"
19
+
20
+ def load_2_subject_cases():
21
+ """Load all 2-subject cases from the dataset."""
22
+ cases = []
23
+ with open(DATASET_FILE, 'r') as f:
24
+ for idx, line in enumerate(f):
25
+ data = json.loads(line)
26
+ if len(data['subjects']) == 2:
27
+ cases.append({
28
+ 'dataset_index': idx,
29
+ 'subjects': data['subjects'],
30
+ 'prompt': data['prompt'],
31
+ 'target': data['target'],
32
+ 'cv': data['cv']
33
+ })
34
+ return cases
35
+
36
+ # Load cases on startup
37
+ TWO_SUBJECT_CASES = load_2_subject_cases()
38
+ print(f"Loaded {len(TWO_SUBJECT_CASES)} 2-subject cases")
39
+
40
+ def get_image_path(case, image_type):
41
+ """Get the path for a specific image type."""
42
+ if image_type == 'ground_truth':
43
+ return os.path.join(DATASET_ROOT, case['target'])
44
+ elif image_type == 'segmentation':
45
+ return os.path.join(DATASET_ROOT, case['cv'])
46
+ elif image_type == 'generated':
47
+ # Find the generated image in results
48
+ viz_dir = os.path.join(RESULTS_DIR, 'generated_images')
49
+ # Pattern: sample_{sample_idx:04d}_idx_{dataset_index}_seed_{seed}.jpg
50
+ # We need to find the file that matches the dataset_index
51
+ if os.path.exists(viz_dir):
52
+ for filename in os.listdir(viz_dir):
53
+ if f"_idx_{case['dataset_index']}_" in filename:
54
+ return os.path.join(viz_dir, filename)
55
+ return None
56
+
57
+ @app.route('/')
58
+ def index():
59
+ """Main page showing the first 2-subject case."""
60
+ return show_case(0)
61
+
62
+ @app.route('/case/<int:case_idx>')
63
+ def show_case(case_idx):
64
+ """Display a specific case."""
65
+ if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
66
+ return "Case not found", 404
67
+
68
+ case = TWO_SUBJECT_CASES[case_idx]
69
+
70
+ # Get image paths
71
+ gt_path = get_image_path(case, 'ground_truth')
72
+ seg_path = get_image_path(case, 'segmentation')
73
+ gen_path = get_image_path(case, 'generated')
74
+
75
+ # Check if files exist
76
+ gt_exists = os.path.exists(gt_path) if gt_path else False
77
+ seg_exists = os.path.exists(seg_path) if seg_path else False
78
+ gen_exists = os.path.exists(gen_path) if gen_path else False
79
+
80
+ return render_template('viewer.html',
81
+ case_idx=case_idx,
82
+ total_cases=len(TWO_SUBJECT_CASES),
83
+ subjects=', '.join(case['subjects']),
84
+ prompt=case['prompt'].replace('PLACEHOLDER', ', '.join(case['subjects'])),
85
+ dataset_index=case['dataset_index'],
86
+ gt_exists=gt_exists,
87
+ seg_exists=seg_exists,
88
+ gen_exists=gen_exists,
89
+ prev_idx=case_idx - 1 if case_idx > 0 else None,
90
+ next_idx=case_idx + 1 if case_idx < len(TWO_SUBJECT_CASES) - 1 else None)
91
+
92
+ @app.route('/image/<int:case_idx>/<image_type>')
93
+ def serve_image(case_idx, image_type):
94
+ """Serve the requested image."""
95
+ if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
96
+ return "Case not found", 404
97
+
98
+ case = TWO_SUBJECT_CASES[case_idx]
99
+ image_path = get_image_path(case, image_type)
100
+
101
+ if image_path and os.path.exists(image_path):
102
+ directory = os.path.dirname(image_path)
103
+ filename = os.path.basename(image_path)
104
+ return send_from_directory(directory, filename)
105
+ else:
106
+ return "Image not found", 404
107
+
108
+ if __name__ == '__main__':
109
+ # Create templates directory if it doesn't exist
110
+ os.makedirs('templates', exist_ok=True)
111
+
112
+ # Run server on all interfaces (0.0.0.0) for remote access
113
+ print(f"Starting server on port 7023...")
114
+ print(f"Access at: http://<your-host-ip>:7023")
115
+ app.run(host='0.0.0.0', port=7023, debug=True)