va1bhavagrawa1 committed on
Commit fa94fe3 · 1 Parent(s): 4da163f

working demo, organized files

gradio_app/.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
+ -----BEGIN CERTIFICATE-----
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+ -----END CERTIFICATE-----
gradio_app/Image0001.png ADDED

Git LFS Details

  • SHA256: 8a0a13b49f14c2972d352c629abc54e3616777ea82daae5ac3505ccef287bcc8
  • Pointer size: 131 Bytes
  • Size of remote file: 533 kB
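  (For reference, the Git LFS pointer committed in place of this PNG is a small text stub of the form below; the exact byte count is not shown on this page, only the rounded 533 kB figure.)

  version https://git-lfs.github.com/spec/v1
  oid sha256:8a0a13b49f14c2972d352c629abc54e3616777ea82daae5ac3505ccef287bcc8
  size <size of Image0001.png in bytes>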
gradio_app/app.py ADDED
@@ -0,0 +1,1504 @@
+ import os
+ import os.path as osp
+ import sys
+ import numpy as np
+ import tempfile
+ import shutil
+ import base64
+ import io
+ from PIL import Image
+ import gradio as gr
+ import time
+ import copy
+ import requests
+ import json
+ import pickle
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from object_scales import scales
+ from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
+ from datetime import datetime
+ from infer_backend import initialize_inference_engine, run_inference_from_gradio
+
+ COLORS = [
+     (1.0, 0.0, 0.0),  # Red
+     (0.0, 0.8, 0.2),  # Green
+     (0.0, 0.0, 1.0),  # Blue
+     (1.0, 1.0, 0.0),  # Yellow
+     (0.0, 1.0, 1.0),  # Cyan
+     (1.0, 0.0, 1.0),  # Magenta
+     (1.0, 0.6, 0.0),  # Orange
+     (0.6, 0.0, 0.8),  # Purple
+     (0.0, 0.4, 0.0),  # Dark Green
+     (0.8, 0.8, 0.8),  # Light Gray
+     (0.2, 0.2, 0.2)   # Dark Gray
+ ]
+
+ CHECKPOINT_NAMES = [
+     "rgb__r1/epoch-0__checkpoint-25917",
+     "rgb__finetune_1024/epoch-0__checkpoint-3000",
+     "rgb__finetune_1024/epoch-1__checkpoint-4000",
+     "rgb__finetune_1024/epoch-1__checkpoint-5000",
+     "rgb__finetune_1024/epoch-1__checkpoint-6000",
+     "rgb__finetune_1024/epoch-1__checkpoint-7000",
+     "rgb__finetune_1024/epoch-1__checkpoint-7932",
+ ]
+
+ PRETRAINED_MODEL_NAME_OR_PATH = "black-forest-labs/FLUX.1-dev"
+
+ tokenizer = T5TokenizerFast.from_pretrained(
+     PRETRAINED_MODEL_NAME_OR_PATH,
+     subfolder="tokenizer_2",
+     revision=None,
+ )
+
+ placeholder_token_str = ["<placeholder>"]
+ num_added_tokens = tokenizer.add_tokens(placeholder_token_str)
+ assert num_added_tokens == 1
+
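+ # Note: "<placeholder>" is registered above as a single extra T5 token so that its position can be
+ # located unambiguously in a tokenized prompt. A minimal sketch (illustrative, not part of the app):
+ #     ids = tokenizer.encode("a photo of <placeholder> a dog")
+ #     placeholder_id = tokenizer.convert_tokens_to_ids("<placeholder>")
+ #     ids.index(placeholder_id)  # index of the placeholder token within the prompt
+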
+ def generate_image_event(camera_elevation, camera_lens, surrounding_prompt, checkpoint_name,
+                          height, width, seed, guidance_scale, num_steps):
+     """Generate final image with segmentation masks and run inference"""
+     # Update scene manager's inference params before generation
+     scene_manager.update_inference_params(height, width, seed, guidance_scale, num_steps, checkpoint_name)
+     if not scene_manager.objects:
+         return (
+             "⚠️ No objects to render",
+             gr.update(),
+             Image.new('RGB', (512, 512), color='white')
+         )
+
+     # Get subject descriptions
+     subject_descriptions = [obj['description'] for obj in scene_manager.objects]
+
+     print(f"Surrounding prompt: {surrounding_prompt}")
+     print(f"Subject descriptions: {subject_descriptions}")
+     print(f"Selected checkpoint: {checkpoint_name}")
+
+     placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt
+
+     # Create placeholder text
+     subject_embeds = []
+     for subject_idx, subject_desc in enumerate(subject_descriptions):
+         input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0]
+         subject_embed = {"input_ids_t5": input_ids.tolist()}
+         subject_embeds.append(subject_embed)
+
+     placeholder_text = ""
+     for subject in subject_descriptions[:-1]:
+         placeholder_text = placeholder_text + f"<placeholder> {subject} and "
+     for subject in subject_descriptions[-1:]:
+         placeholder_text = placeholder_text + f"<placeholder> {subject}"
+     placeholder_text = placeholder_text.strip()
+
+     placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
+
+     call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt,
+                                                          subjects=subject_descriptions,
+                                                          subjects_embeds=subject_embeds,
+                                                          debug=True
+                                                          )
+     print(f"Generated call IDs: {call_ids}")
+
+     # Convert to server expected format
+     subjects_data, camera_data = scene_manager._convert_to_blender_format()
+
+     # Render final high-quality image using CYCLES (port 5002)
+     final_img = scene_manager.render_client._send_render_request(
+         scene_manager.render_client.final_server_url,
+         subjects_data,
+         camera_data
+     )
+
+     final_img.save("model_condition.jpg")
+
+     # Render segmentation masks
+     success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data)
+
+     if not success:
+         return (
+             f"❌ Failed to render segmentation masks: {error_msg}",
+             gr.update(),
+             Image.new('RGB', (512, 512), color='white')
+         )
+
+     # Save all files to the correct location
+     root_save_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/"
+     os.system(f"rm -f {root_save_dir}/*")
+
+     # Save final render to root directory
+     final_render_path = osp.join(root_save_dir, "cv_render.jpg")
+     final_img.save(final_render_path)
+
+     # Move segmentation masks
+     for subject_idx in range(len(subject_descriptions)):
+         shutil.move(
+             f"{str(subject_idx).zfill(3)}_segmask_cv.png",
+             osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png")
+         )
+
+     # Create JSONL
+     jsonl = [{
+         "cv": final_render_path,
+         "target": final_render_path,
+         "cuboids_segmasks": [
+             osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png")
+             for subject_idx in range(len(subject_descriptions))
+         ],
+         "PLACEHOLDER_prompts": placeholder_prompt,
+         "subjects": subject_descriptions,
+         "call_ids": call_ids,
+     }]
+
+     jsonl_path = osp.join(root_save_dir, "cuboids.jsonl")
+     with open(jsonl_path, "w") as f:
+         json.dump(jsonl[0], f)
+
+     # Run inference using the pre-loaded model
+     print(f"\n{'='*60}")
+     print(f"RUNNING INFERENCE")
+     print(f"{'='*60}\n")
+
+     inference_success, generated_image, inference_msg = run_inference_from_gradio(
+         checkpoint_name=checkpoint_name,
+         height=height,
+         width=width,
+         seed=seed,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_steps,
+         jsonl_path=jsonl_path
+     )
+
+     if not inference_success:
+         return (
+             f"✅ Saved files but inference failed: {inference_msg}",
+             final_img,
+             Image.new('RGB', (512, 512), color='white')
+         )
+
+     status_msg = f"✅ Generated image using {checkpoint_name} with {len(segmask_images)} segmentation masks"
+
+     # Re-render the camera view through the paper-figure server (port 5004) for display
+     final_img = scene_manager.render_client._send_render_request(
+         scene_manager.render_client.paper_figure_server_url,
+         subjects_data,
+         camera_data
+     )
+
+     return (
+         status_msg,
+         final_img,        # Display CV render in Camera View
+         generated_image   # Display generated image in Generated Image section
+     )
+
+
+ def get_call_ids_from_placeholder_prompt_flux(prompt: str, subjects, subjects_embeds: list, debug: bool):
+     assert prompt.find("<placeholder>") != -1, "Prompt must contain <placeholder> to get call ids"
+
+     # the placeholder token ID for all the tokenizers
+     placeholder_token_three = tokenizer.encode("<placeholder>", return_tensors="pt")[0][:-1].item()
+     prompt_tokens_three = tokenizer.encode(prompt, return_tensors="pt")[0].tolist()
+
+     placeholder_token_locations_three = [i for i, w in enumerate(prompt_tokens_three) if w == placeholder_token_three]
+     prompt = prompt.replace("<placeholder> ", "")
+
+     call_ids = []
+     for subject_idx, (subject, subject_embed) in enumerate(zip(subjects, subjects_embeds)):
+         subject_prompt_ids_t5 = subject_embed["input_ids_t5"][:-1]  # drop the trailing EOS token (T5 appends only an end-of-sequence token)
+         num_t5_tokens_subject = len(subject_prompt_ids_t5)
+
+         t5_call_ids_subject = [i + placeholder_token_locations_three[subject_idx] - 2 * subject_idx - 1 for i in range(num_t5_tokens_subject)]
+         call_ids.append(t5_call_ids_subject)
+
+         prompt_wo_placeholder = prompt.replace("<placeholder> ", "")
+         t5_call_strs = tokenizer.batch_decode(tokenizer.encode(prompt_wo_placeholder, return_tensors="pt")[0].tolist())
+         t5_call_strs = [t5_call_strs[i] for i in t5_call_ids_subject]
+         if debug:
+             print(f"{prompt = }, t5 CALL strs for {subject} = {t5_call_strs}")
+
+     return call_ids
+
+
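+ # Worked example of the prompt assembly used in generate_image_event (illustrative values only):
+ #   subject_descriptions = ["a red car", "a dog"], surrounding_prompt = "in a forest"
+ #   placeholder_text         -> "<placeholder> a red car and <placeholder> a dog"
+ #   placeholder_token_prompt -> "a photo of <placeholder> a red car and <placeholder> a dog in a forest"
+ # get_call_ids_from_placeholder_prompt_flux then locates the <placeholder> token positions in the T5
+ # tokenization and returns, for each subject, the indices its tokens occupy once the placeholders are
+ # stripped from the prompt (the "- 2 * subject_idx - 1" term roughly offsets the removed placeholder tokens).
+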
+ def map_point_to_rgb(x, y):
+     """
+     Map (x, y) inside the frustum to an RGB color with continuity and variation.
+     """
+     # Frustum boundaries
+     X_MIN, X_MAX = -10.0, -1.0
+     Y_MIN_AT_XMIN, Y_MAX_AT_XMIN = -4.5, 4.5
+     Y_MIN_AT_XMAX, Y_MAX_AT_XMAX = -0.5, 0.5
+
+     # Normalize x to [0, 1]
+     x_norm = (x - X_MIN) / (X_MAX - X_MIN)
+     # x_norm = np.clip(x_norm, 0, 1)
+
+     # Compute current Y bounds at given x using linear interpolation
+     y_min = Y_MIN_AT_XMIN + x_norm * (Y_MIN_AT_XMAX - Y_MIN_AT_XMIN)
+     y_max = Y_MAX_AT_XMIN + x_norm * (Y_MAX_AT_XMAX - Y_MAX_AT_XMIN)
+
+     # Normalize y to [0, 1] within current bounds
+     if y_max != y_min:
+         y_norm = (y - y_min) / (y_max - y_min)
+     else:
+         y_norm = 0.5
+     y_norm = np.clip(y_norm, 0.0, 1.0)
+
+     # Color mapping: more variation along x
+     r = x_norm
+     g = y_norm
+     b = 1.0 - x_norm
+
+     return (r, g, b)
+
+ def rgb_to_hex(rgb_tuple):
+     """Convert RGB tuple (0-1 range) to hex color string."""
+     r, g, b = rgb_tuple
+     return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}"
+
+
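+ # Worked example for the two helpers above (values follow directly from the formulas):
+ #   map_point_to_rgb(-5.5, 0.0) -> x_norm = 0.5, y bounds (-2.5, 2.5), y_norm = 0.5 -> (0.5, 0.5, 0.5)
+ #   rgb_to_hex((0.5, 0.5, 0.5)) -> "#7f7f7f"
+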
+ class BlenderRenderClient:
+     def __init__(self, cv_server_url="http://127.0.0.1:5001", segmask_server_url="http://127.0.0.1:5003", final_server_url="http://127.0.0.1:5002", paper_figure_server_url="http://127.0.0.1:5004"):
+         """
+         Initialize the Blender render client.
+
+         Args:
+             cv_server_url (str): URL of the camera view render server
+             segmask_server_url (str): URL of the segmentation mask render server
+             final_server_url (str): URL of the final high-quality (Cycles) render server
+             paper_figure_server_url (str): URL of the paper-figure render server
+         """
+         self.cv_server_url = cv_server_url
+         self.segmask_server_url = segmask_server_url
+         self.final_server_url = final_server_url
+         self.paper_figure_server_url = paper_figure_server_url
+         self.timeout = 30  # 30 second timeout for renders
+
+     def render_segmasks(self, subjects_data: list, camera_data: dict) -> tuple:
+         """
+         Send a segmentation mask render request.
+         Returns (success: bool, segmask_images: list of PIL Images or None, error_message: str or None)
+         """
+         try:
+             request_data = {
+                 "subjects_data": subjects_data,
+                 "camera_data": camera_data,
+                 "num_samples": 1
+             }
+
+             response = requests.post(
+                 f"{self.segmask_server_url}/render_segmasks",
+                 json=request_data,
+                 timeout=self.timeout
+             )
+
+             if response.status_code == 200:
+                 result = response.json()
+                 if result["success"]:
+                     # Decode all segmentation masks
+                     segmask_images = []
+                     for img_base64 in result["segmasks_base64"]:
+                         img_data = base64.b64decode(img_base64)
+                         img = Image.open(io.BytesIO(img_data))
+                         segmask_images.append(img)
+
+                     print(f"Successfully rendered {len(segmask_images)} segmentation masks")
+                     return True, segmask_images, None
+                 else:
+                     error_msg = result.get('error_message', 'Unknown error')
+                     print(f"Segmask render failed: {error_msg}")
+                     return False, None, error_msg
+             else:
+                 error_msg = f"HTTP error {response.status_code}: {response.text}"
+                 print(error_msg)
+                 return False, None, error_msg
+
+         except requests.exceptions.Timeout:
+             error_msg = "Segmask render request timed out"
+             print(error_msg)
+             return False, None, error_msg
+         except Exception as e:
+             error_msg = f"Segmask render request failed: {e}"
+             print(error_msg)
+             return False, None, error_msg
+
+
+     def _send_render_request(self, server_url: str, subjects_data: list, camera_data: dict) -> Image.Image:
+         """Send a render request to a server and return the image."""
+         try:
+             request_data = {
+                 "subjects_data": subjects_data,
+                 "camera_data": camera_data,
+                 "num_samples": 1
+             }
+             print(f"passing {subjects_data = } to server at {server_url}")
+
+             response = requests.post(
+                 f"{server_url}/render",
+                 json=request_data,
+                 timeout=self.timeout
+             )
+
+             if response.status_code == 200:
+                 result = response.json()
+                 if result["success"]:
+                     # Decode base64 image
+                     img_data = base64.b64decode(result["image_base64"])
+                     img = Image.open(io.BytesIO(img_data))
+                     return img
+                 else:
+                     print(f"Render failed: {result.get('error_message', 'Unknown error')}")
+                     return self._create_error_image("red")
+             else:
+                 print(f"HTTP error {response.status_code}: {response.text}")
+                 return self._create_error_image("orange")
+
+         except requests.exceptions.Timeout:
+             print("Render request timed out")
+             return self._create_error_image("yellow")
+         except Exception as e:
+             print(f"Render request failed: {e}")
+             return self._create_error_image("red")
+
+     def _create_error_image(self, color: str) -> Image.Image:
+         """Create a colored error image."""
+         return Image.new('RGB', (512, 512), color=color)
+
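+ # Example of the JSON payload these requests carry (illustrative values; field names follow
+ # SceneManager._convert_to_blender_format below):
+ #   {
+ #       "subjects_data": [{"subject_name": "a dog", "x": 0.0, "y": 0.0, "z": 0.0, "azimuth": 0.0,
+ #                          "width": 1.0, "depth": 1.0, "height": 1.0, "base_color": [1.0, 0.0, 0.0]}],
+ #       "camera_data": {"camera_elevation": 0.52, "lens": 50.0, "global_scale": 1.0},
+ #       "num_samples": 1
+ #   }
+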
365
+ # --- Scene Management Class ---
366
+ class SceneManager:
367
+ def __init__(self):
368
+ self.objects = []
369
+ self.camera_elevation = 30.0
370
+ self.camera_lens = 50.0
371
+ self.surrounding_prompt = ""
372
+ self.next_color_idx = 0
373
+ self.colors = [
374
+ (1.0, 0.0, 0.0), # red
375
+ (0.0, 0.0, 1.0), # blue
376
+ (0.0, 1.0, 0.0), # green
377
+ (0.5, 0.0, 0.5), # purple
378
+ (1.0, 0.5, 0.0), # orange
379
+ (1.0, 1.0, 0.0), # yellow
380
+ (0.0, 1.0, 1.0), # cyan
381
+ (1.0, 0.0, 1.0), # magenta
382
+ ]
383
+
384
+ # Add inference parameters with defaults
385
+ self.inference_params = {
386
+ 'height': 512,
387
+ 'width': 512,
388
+ 'seed': 42,
389
+ 'guidance_scale': 3.5,
390
+ 'num_inference_steps': 25,
391
+ 'checkpoint': CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None
392
+ }
393
+
394
+ # Initialize BlenderRenderClient
395
+ self.render_client = BlenderRenderClient()
396
+
397
+ # Load asset dimensions
398
+ self.asset_dimensions = self._load_asset_dimensions()
399
+
400
+
401
+ def update_inference_params(self, height, width, seed, guidance_scale, num_steps, checkpoint):
402
+ """Update inference parameters"""
403
+ self.inference_params = {
404
+ 'height': height,
405
+ 'width': width,
406
+ 'seed': seed,
407
+ 'guidance_scale': guidance_scale,
408
+ 'num_inference_steps': num_steps,
409
+ 'checkpoint': checkpoint
410
+ }
411
+
412
+
413
+ def update_cuboid_description(self, obj_id, new_description):
414
+ """Update the description of a cuboid"""
415
+ if 0 <= obj_id < len(self.objects):
416
+ if new_description.strip(): # Check not empty
417
+ self.objects[obj_id]['description'] = new_description.strip()
418
+ return True
419
+ return False
420
+
421
+
422
+ def save_scene_to_pkl(self, filepath=None):
423
+ """Save current scene data to pkl file including inference parameters"""
424
+ if filepath is None:
425
+ # Auto-generate filename with timestamp
426
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
427
+ filepath = f"scene_{timestamp}.pkl"
428
+
429
+ # Convert to the expected format
430
+ subjects_data = []
431
+ for obj in self.objects:
432
+ subject_dict = {
433
+ 'name': obj['description'],
434
+ 'type': obj['type'], # Save the object type
435
+ 'dims': tuple(obj['size']), # (width, depth, height)
436
+ 'x': [obj['position'][0] - 6.0],
437
+ 'y': [obj['position'][1]],
438
+ 'z': [obj['position'][2]],
439
+ 'azimuth': [np.radians(obj['azimuth'])], # Convert to radians
440
+ 'bbox': [(0, 0, 0, 0)] # Placeholder, can be computed if needed
441
+ }
442
+ subjects_data.append(subject_dict)
443
+
444
+ camera_data = {
445
+ 'camera_elevation': np.radians(self.camera_elevation),
446
+ 'lens': self.camera_lens,
447
+ 'global_scale': 1.0 # Default value
448
+ }
449
+
450
+ scene_dict = {
451
+ 'subjects_data': subjects_data,
452
+ 'camera_data': camera_data,
453
+ 'surrounding_prompt': self.surrounding_prompt,
454
+ 'inference_params': self.inference_params.copy()
455
+ }
456
+
457
+ try:
458
+ with open(filepath, 'wb') as f:
459
+ pickle.dump(scene_dict, f)
460
+ return True, filepath, None
461
+ except Exception as e:
462
+ return False, None, str(e)
463
+
464
+
465
+ def load_scene_from_pkl(self, filepath):
466
+ """Load scene data from pkl file including inference parameters"""
467
+ try:
468
+ with open(filepath, 'rb') as f:
469
+ scene_dict = pickle.load(f)
470
+
471
+ # Clear existing objects
472
+ self.objects = []
473
+ self.next_color_idx = 0
474
+
475
+ # Load subjects
476
+ subjects_data = scene_dict.get('subjects_data', [])
477
+ for subject_dict in subjects_data:
478
+ name = subject_dict.get('name', 'Loaded Object')
479
+ asset_type = subject_dict.get('type', 'Custom') # Load the type
480
+ dims = subject_dict.get('dims', (1.0, 1.0, 1.0))
481
+ x = float(subject_dict.get('x', [0.0])[0]) + 6.0
482
+ y = float(subject_dict.get('y', [0.0])[0])
483
+ z = float(subject_dict.get('z', [0.0])[0])
484
+ azimuth_rad = float(subject_dict.get('azimuth', [0.0])[0])
485
+ azimuth_deg = np.degrees(azimuth_rad)
486
+
487
+ # Determine original_asset_size based on type
488
+ if asset_type == "Custom" or asset_type not in self.asset_dimensions:
489
+ original_asset_size = None
490
+ else:
491
+ # Look up the original asset dimensions
492
+ asset_dims = self.asset_dimensions[asset_type]
493
+ original_asset_size = [float(asset_dims[0]), float(asset_dims[1]), float(asset_dims[2])]
494
+
495
+ # Create object
496
+ obj_id = len(self.objects)
497
+ size_list = [float(d) for d in dims]
498
+ cuboid = {
499
+ 'id': obj_id,
500
+ 'description': name,
501
+ 'type': asset_type, # Use the loaded type
502
+ 'position': [x, y, z],
503
+ 'size': size_list,
504
+ 'original_asset_size': original_asset_size, # Restore from asset_dimensions
505
+ 'azimuth': float(azimuth_deg),
506
+ 'color': self._get_next_color()
507
+ }
508
+ self.objects.append(cuboid)
509
+
510
+ # Load camera settings
511
+ camera_data = scene_dict.get('camera_data', {})
512
+ camera_elev_rad = float(camera_data.get('camera_elevation', np.radians(30.0)))
513
+ self.camera_elevation = float(np.degrees(camera_elev_rad))
514
+ self.camera_lens = float(camera_data.get('lens', 50.0))
515
+
516
+ # Load surrounding prompt
517
+ self.surrounding_prompt = scene_dict.get('surrounding_prompt', '')
518
+
519
+ # Load inference parameters
520
+ loaded_inference_params = scene_dict.get('inference_params', {})
521
+
522
+ # Get checkpoint, fall back to first available if not found
523
+ saved_checkpoint = loaded_inference_params.get('checkpoint')
524
+ if saved_checkpoint and saved_checkpoint in CHECKPOINT_NAMES:
525
+ checkpoint = saved_checkpoint
526
+ else:
527
+ checkpoint = CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None
528
+ if saved_checkpoint:
529
+ print(f"Warning: Saved checkpoint '{saved_checkpoint}' not found, using '{checkpoint}' instead")
530
+
531
+ self.inference_params = {
532
+ 'height': loaded_inference_params.get('height', 512),
533
+ 'width': loaded_inference_params.get('width', 512),
534
+ 'seed': loaded_inference_params.get('seed', 42),
535
+ 'guidance_scale': loaded_inference_params.get('guidance_scale', 3.5),
536
+ 'num_inference_steps': loaded_inference_params.get('num_inference_steps', 25),
537
+ 'checkpoint': checkpoint
538
+ }
539
+
540
+ return True, len(subjects_data), None
541
+ except FileNotFoundError:
542
+ return False, 0, f"File not found: {filepath}"
543
+ except Exception as e:
544
+ return False, 0, f"Error loading file: {str(e)}"
545
+
546
+
547
+ def _load_asset_dimensions(self):
548
+ """Load asset dimensions from pickle file"""
549
+ pkl_path = "asset_dimensions.pkl"
550
+ if os.path.exists(pkl_path):
551
+ try:
552
+ with open(pkl_path, 'rb') as f:
553
+ return pickle.load(f)
554
+ except Exception as e:
555
+ print(f"Warning: Could not load asset dimensions: {e}")
556
+ return {}
557
+ else:
558
+ print(f"Warning: asset_dimensions.pkl not found at {pkl_path}")
559
+ return {}
560
+
561
+ def get_asset_type_choices(self):
562
+ """Get list of asset types for dropdown"""
563
+ choices = ["Custom"]
564
+ if self.asset_dimensions:
565
+ choices.extend(sorted(self.asset_dimensions.keys()))
566
+ return choices
567
+
568
+ def _get_next_color(self):
569
+ color = self.colors[self.next_color_idx % len(self.colors)]
570
+ self.next_color_idx += 1
571
+ return color
572
+
573
+
574
+ def harmonize_scales(self):
575
+ """
576
+ Harmonize the scales of all non-Custom objects based on object scales.
577
+ Always scales from original asset dimensions, ignoring any manual edits.
578
+ Custom objects remain unchanged.
579
+ """
580
+ if not self.objects:
581
+ return "No objects to harmonize"
582
+
583
+ # Find objects that can be harmonized (non-Custom with valid scales and original_asset_size)
584
+ harmonizable_objects = []
585
+ for obj in self.objects:
586
+ if (obj['type'] != "Custom" and
587
+ obj['type'] in scales and
588
+ obj['original_asset_size'] is not None):
589
+ harmonizable_objects.append(obj)
590
+
591
+ if not harmonizable_objects:
592
+ return "No objects with defined scales to harmonize (all are Custom)"
593
+
594
+ # Find the largest scale among harmonizable objects
595
+ max_scale = max(scales[obj['type']] for obj in harmonizable_objects)
596
+
597
+ if max_scale == 0:
598
+ return "Invalid max scale (0)"
599
+
600
+ # Harmonize each object by scaling from ORIGINAL ASSET dimensions
601
+ for obj in harmonizable_objects:
602
+ obj_scale = scales[obj['type']]
603
+ scale_factor = obj_scale / max_scale
604
+
605
+ # Scale from ORIGINAL ASSET dimensions, not current dimensions
606
+ obj['size'][0] = obj['original_asset_size'][0] * scale_factor # width
607
+ obj['size'][1] = obj['original_asset_size'][1] * scale_factor # depth
608
+ obj['size'][2] = obj['original_asset_size'][2] * scale_factor # height
609
+
610
+ # Update z position to keep object on ground
611
+ obj['position'][2] = 0.0
612
+
613
+ return f"Harmonized {len(harmonizable_objects)} objects based on largest scale: {max_scale}"
614
+
615
+
616
+ def add_cuboid(self, description="New Cuboid", asset_type="Custom"):
617
+ """Add a cuboid with dimensions based on asset type"""
618
+ obj_id = len(self.objects)
619
+
620
+ # Determine dimensions based on asset type
621
+ if asset_type == "Custom" or asset_type not in self.asset_dimensions:
622
+ size = [1.0, 1.0, 1.0] # Default size
623
+ original_asset_size = None # Custom objects have no original asset size
624
+ else:
625
+ # Load dimensions from pkl file
626
+ dims = self.asset_dimensions[asset_type]
627
+ size = [float(dims[0]), float(dims[1]), float(dims[2])] # [width, depth, height]
628
+ original_asset_size = size.copy() # Store the original asset dimensions
629
+
630
+ cuboid = {
631
+ 'id': obj_id,
632
+ 'description': description,
633
+ 'type': asset_type, # Store the asset type
634
+ 'position': [0.0, 0.0, 0.0], # Place at the origin (z = 0)
635
+ 'size': size,
636
+ 'original_asset_size': original_asset_size, # Store original asset dimensions
637
+ 'azimuth': 0.0,
638
+ 'color': self._get_next_color()
639
+ }
640
+ self.objects.append(cuboid)
641
+ return obj_id
642
+
643
+
644
+ def update_cuboid(self, obj_id, x, y, z, azimuth, width, depth, height):
645
+ if 0 <= obj_id < len(self.objects):
646
+ obj = self.objects[obj_id]
647
+ obj['position'] = [x, y, z]
648
+ obj['size'] = [width, depth, height]
649
+ # Note: We do NOT update original_asset_size here - it stays unchanged
650
+ obj['azimuth'] = azimuth
651
+ return True
652
+ return False
653
+
654
+
655
+ def delete_cuboid(self, obj_id):
656
+ if 0 <= obj_id < len(self.objects):
657
+ del self.objects[obj_id]
658
+ # Update IDs for remaining objects
659
+ for i, obj in enumerate(self.objects):
660
+ obj['id'] = i
661
+ return True
662
+ return False
663
+
664
+ def set_camera_elevation(self, elevation_deg):
665
+ assert type(elevation_deg) == float or type(elevation_deg) == int, f"{type(elevation_deg) = }"
666
+ self.camera_elevation = np.clip(elevation_deg, 0.0, 90.0)
667
+ return f"Camera elevation set to {elevation_deg}°"
668
+
669
+ def set_camera_lens(self, lens_value):
670
+ self.camera_lens = np.clip(lens_value, 10.0, 200.0)
671
+ return f"Camera lens set to {lens_value}mm"
672
+
673
+ def set_surrounding_prompt(self, prompt): # Add this method
674
+ self.surrounding_prompt = prompt
675
+ return f"Surrounding prompt updated"
676
+
677
+ def _convert_to_blender_format(self):
678
+ """Convert internal objects format to server expected format"""
679
+ subjects_data = []
680
+
681
+ for obj in self.objects:
682
+ subject_data = {
683
+ 'subject_name': obj['description'],
684
+ 'x': float(obj['position'][0]),
685
+ 'y': float(obj['position'][1]),
686
+ 'z': float(obj['position'][2]),
687
+ 'azimuth': float(obj['azimuth']),
688
+ 'width': float(obj['size'][0]),
689
+ 'depth': float(obj['size'][1]),
690
+ 'height': float(obj['size'][2]),
691
+ 'base_color': obj['color']
692
+ }
693
+ subjects_data.append(subject_data)
694
+
695
+ camera_data = {
696
+ 'camera_elevation': float(np.radians(self.camera_elevation)),
697
+ 'lens': float(self.camera_lens),
698
+ 'global_scale': 1.0
699
+ }
700
+
701
+ return subjects_data, camera_data
702
+
703
+ def render_cv_view(self, subjects_data: list, camera_data: dict) -> Image.Image:
704
+ """Render only the CV view."""
705
+ if not subjects_data:
706
+ return Image.new('RGB', (512, 512), color='gray')
707
+
708
+ return self.render_client._send_render_request(self.render_client.cv_server_url, subjects_data, camera_data)
709
+
710
+
711
+ def render_scene(self, width=512, height=512):
712
+ """Render only CV view using the render client."""
713
+ print(f"calling render_scene")
714
+ if not self.objects:
715
+ # Return empty image if no objects
716
+ empty_cv = Image.new('RGB', (width, height), color='gray')
717
+ return empty_cv
718
+
719
+ # Convert to server expected format
720
+ subjects_data, camera_data = self._convert_to_blender_format()
721
+ print(f"passing {subjects_data = } to render_cv_view in SceneManager")
722
+
723
+ # Render CV view only
724
+ cv_img = self.render_cv_view(subjects_data, camera_data)
725
+
726
+ return cv_img
727
+
728
+ # --- Gradio Interface Logic ---
729
+ scene_manager = SceneManager()
730
+
731
+ def get_cuboid_list_html():
732
+ """Generate HTML for the cuboid list with position-based colors"""
733
+ if not scene_manager.objects:
734
+ return "<div style='text-align: center; padding: 20px; color: #888;'>No cuboids yet. Add one to get started!</div>"
735
+
736
+ html = "<div style='display: flex; flex-direction: column; gap: 8px;'>"
737
+ for obj_idx, obj in enumerate(scene_manager.objects):
738
+ # Get position-based color
739
+ # x, y = obj['position'][0], obj['position'][1]
740
+ # rgb_color = map_point_to_rgb(x, y)
741
+ rgb_color = COLORS[obj_idx % len(COLORS)]
742
+ hex_color = rgb_to_hex(rgb_color)
743
+
744
+ # Create a lighter version for gradient end
745
+ lighter_rgb = tuple(min(1.0, c + 0.2) for c in rgb_color)
746
+ lighter_hex = rgb_to_hex(lighter_rgb)
747
+
748
+ html += f"""
749
+ <div style='background: linear-gradient(135deg, {hex_color} 0%, {lighter_hex} 100%);
750
+ padding: 12px; border-radius: 8px; color: white; text-shadow: 1px 1px 2px rgba(0,0,0,0.5);'>
751
+ <div style='font-weight: bold; font-size: 14px;'>{obj['description']}</div>
752
+ <div style='font-size: 11px; opacity: 0.9; margin-top: 4px;'>
753
+ Pos: ({obj['position'][0]:.1f}, {obj['position'][1]:.1f}, {obj['position'][2]:.1f}) |
754
+ Size: {obj['size'][0]:.1f}×{obj['size'][1]:.1f}×{obj['size'][2]:.1f}
755
+ </div>
756
+ </div>
757
+ """
758
+ html += "</div>"
759
+ return html
760
+
761
+
762
+ def add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens):
763
+ """Add a new cuboid"""
764
+ if not description_input.strip():
765
+ description_input = "New Cuboid"
766
+
767
+ new_id = scene_manager.add_cuboid(description_input, asset_type)
768
+ cv_img = scene_manager.render_scene()
769
+
770
+ # Create choices for radio buttons
771
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
772
+
773
+ # Get the new object data
774
+ new_obj = scene_manager.objects[new_id]
775
+
776
+ return (
777
+ gr.update(value=""), # Clear description input
778
+ gr.update(value="Custom"), # Reset type dropdown to Custom
779
+ cv_img,
780
+ get_cuboid_list_html(),
781
+ gr.update(choices=choices, value=new_obj['description']), # Radio with new selection
782
+ gr.update(visible=True), # Show editor
783
+ gr.update(value=new_obj['description']), # Set description in editor
784
+ gr.update(value=new_obj['position'][0]),
785
+ gr.update(value=new_obj['position'][1]),
786
+ gr.update(value=new_obj['position'][2]),
787
+ gr.update(value=new_obj['azimuth']),
788
+ gr.update(value=new_obj['size'][0]),
789
+ gr.update(value=new_obj['size'][1]),
790
+ gr.update(value=new_obj['size'][2]),
791
+ gr.update(value=1.0) # Reset scale to 1.0
792
+ )
793
+
794
+
795
+ def select_cuboid_event(selected_name):
796
+ """When a cuboid is selected from radio buttons"""
797
+ if not selected_name:
798
+ return [gr.update(visible=False)] + [gr.update() for _ in range(9)] # Changed from 8 to 9
799
+
800
+ # Find the cuboid by description
801
+ obj = None
802
+ for o in scene_manager.objects:
803
+ if o['description'] == selected_name:
804
+ obj = o
805
+ break
806
+
807
+ if obj is None:
808
+ return [gr.update(visible=False)] + [gr.update() for _ in range(9)]
809
+
810
+ return (
811
+ gr.update(visible=True), # Show editor
812
+ gr.update(value=obj['description']), # Set description
813
+ gr.update(value=obj['position'][0]),
814
+ gr.update(value=obj['position'][1]),
815
+ gr.update(value=obj['position'][2]),
816
+ gr.update(value=obj['azimuth']),
817
+ gr.update(value=obj['size'][0]),
818
+ gr.update(value=obj['size'][1]),
819
+ gr.update(value=obj['size'][2]),
820
+ gr.update(value=1.0) # Reset scale to 1.0
821
+ )
822
+
823
+
824
+ def delete_selected_cuboid(selected_name, camera_elevation, camera_lens):
825
+ """Delete the currently selected cuboid"""
826
+ if not selected_name:
827
+ return gr.update(), get_cuboid_list_html(), gr.update(), gr.update(visible=False)
828
+
829
+ # Find and delete the cuboid
830
+ obj_id = None
831
+ for i, obj in enumerate(scene_manager.objects):
832
+ if obj['description'] == selected_name:
833
+ obj_id = i
834
+ break
835
+
836
+ if obj_id is not None:
837
+ scene_manager.delete_cuboid(obj_id)
838
+
839
+ cv_img = scene_manager.render_scene()
840
+
841
+ # Update choices
842
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
843
+
844
+ return (
845
+ cv_img,
846
+ get_cuboid_list_html(),
847
+ gr.update(choices=choices, value=None),
848
+ gr.update(visible=False)
849
+ )
850
+
851
+
852
+ def update_cuboid_event(selected_name, camera_elevation, camera_lens, description, x, y, z, azimuth, width, depth, height, scale):
853
+ """Update the selected cuboid including description and scale"""
854
+ scene_manager.set_camera_elevation(camera_elevation)
855
+ scene_manager.set_camera_lens(camera_lens)
856
+
857
+ if selected_name:
858
+ # Find the cuboid by description
859
+ obj_id = None
860
+ for i, obj in enumerate(scene_manager.objects):
861
+ if obj['description'] == selected_name:
862
+ obj_id = i
863
+ break
864
+
865
+ if obj_id is not None:
866
+ # Update description first if changed
867
+ if description.strip() and description.strip() != selected_name:
868
+ scene_manager.update_cuboid_description(obj_id, description.strip())
869
+
870
+ # Apply scale to dimensions
871
+ scaled_width = width * scale
872
+ scaled_depth = depth * scale
873
+ scaled_height = height * scale
874
+
875
+ # Update other properties with scaled dimensions
876
+ scene_manager.update_cuboid(obj_id, x, y, z, azimuth, scaled_width, scaled_depth, scaled_height)
877
+
878
+ # Get updated object for return
879
+ updated_obj = scene_manager.objects[obj_id]
880
+ new_name = updated_obj['description']
881
+
882
+ cv_img = scene_manager.render_scene()
883
+
884
+ # Update choices with new descriptions
885
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
886
+
887
+ # Return updated HTML, image, radio choices, new selection, updated sliders, and reset scale
888
+ return (
889
+ get_cuboid_list_html(),
890
+ cv_img,
891
+ gr.update(choices=choices, value=new_name if obj_id is not None else None),
892
+ gr.update(value=scaled_width if obj_id is not None else width), # Update width slider
893
+ gr.update(value=scaled_depth if obj_id is not None else depth), # Update depth slider
894
+ gr.update(value=scaled_height if obj_id is not None else height), # Update height slider
895
+ gr.update(value=1.0) # Reset scale to 1.0
896
+ )
897
+
898
+
899
+ def camera_change_event(camera_elevation, camera_lens):
900
+ """Handle camera control changes"""
901
+ scene_manager.set_camera_elevation(camera_elevation)
902
+ scene_manager.set_camera_lens(camera_lens)
903
+ cv_img = scene_manager.render_scene()
904
+ return cv_img
905
+
906
+
907
+ def surrounding_prompt_change_event(prompt_text): # Add this function
908
+ """Handle surrounding prompt changes"""
909
+ scene_manager.set_surrounding_prompt(prompt_text)
910
+ return None # No visual update needed
911
+
912
+
913
+ def render_segmask_event(camera_elevation, camera_lens, surrounding_prompt):
914
+ """Render segmentation masks for all objects"""
915
+ if not scene_manager.objects:
916
+ return "⚠️ No objects to render", gr.update(visible=False), []
917
+
918
+ # Get subject descriptions
919
+ subject_descriptions = [obj['description'] for obj in scene_manager.objects]
920
+
921
+ # Now you have access to:
922
+ # - surrounding_prompt: the text from surrounding_prompt_input
923
+ # - subject_descriptions: list of all subject descriptions
924
+
925
+ print(f"Surrounding prompt: {surrounding_prompt}")
926
+ print(f"Subject descriptions: {subject_descriptions}")
927
+
928
+ placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt
929
+
930
+ # Create placeholder text
931
+ subject_embeds = []
932
+ for subject_idx, subject_desc in enumerate(subject_descriptions):
933
+ input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0]
934
+ subject_embed = {"input_ids_t5": input_ids.tolist()}
935
+ subject_embeds.append(subject_embed)
936
+
937
+ placeholder_text = ""
938
+ for subject in subject_descriptions[:-1]:
939
+ placeholder_text = placeholder_text + f"<placeholder> {subject} and "
940
+ for subject in subject_descriptions[-1:]:
941
+ placeholder_text = placeholder_text + f"<placeholder> {subject}"
942
+ placeholder_text = placeholder_text.strip()
943
+
944
+ placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text)
945
+
946
+ call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt,
947
+ subjects=subject_descriptions,
948
+ subjects_embeds=subject_embeds,
949
+ debug=True
950
+ )
951
+ print(f"Generated call IDs: {call_ids}")
952
+
953
+
954
+ # Convert to server expected format
955
+ subjects_data, camera_data = scene_manager._convert_to_blender_format()
956
+
957
+ # You can add the prompt and descriptions to the request if needed
958
+ # For example, add to subjects_data or camera_data before sending
959
+
960
+ # Render segmentation masks
961
+ success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data)
962
+
963
+ # copy all the data to the correct location
964
+ root_save_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/"
965
+ os.system("rm /archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/*")
966
+ shutil.move("cv_render.jpg", osp.join(root_save_dir, "cv_render.jpg"))
967
+ for subject_idx in range(len(subject_descriptions)):
968
+ shutil.move(f"{str(subject_idx).zfill(3)}_segmask_cv.png", osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png"))
969
+
970
+ jsonl = [{
971
+ "cv": osp.join(root_save_dir, "cv_render.jpg"),
972
+ "target": osp.join(root_save_dir, "cv_render.jpg"),
973
+ "cuboids_segmasks": [osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png") for subject_idx in range(len(subject_descriptions))],
974
+ "PLACEHOLDER_prompts": placeholder_prompt,
975
+ "subjects": subject_descriptions,
976
+ "call_ids": call_ids,
977
+ }]
978
+
979
+ with open(osp.join(root_save_dir, "cuboids.jsonl"), "w") as f:
980
+ for item in jsonl:
981
+ f.write(json.dumps(item) + "\n")
982
+
983
+ if success:
984
+ return (
985
+ f"✅ Successfully rendered {len(segmask_images)} segmentation masks",
986
+ gr.update(visible=True),
987
+ segmask_images
988
+ )
989
+ else:
990
+ return (
991
+ f"❌ Failed to render segmentation masks: {error_msg}",
992
+ gr.update(visible=False),
993
+ []
994
+ )
995
+
996
+
997
+ def harmonize_event(selected_name, camera_elevation, camera_lens):
998
+ """Harmonize all object scales and update the scene"""
999
+ message = scene_manager.harmonize_scales()
1000
+ print(message)
1001
+
1002
+ cv_img = scene_manager.render_scene()
1003
+
1004
+ # If a cuboid is selected, update its sliders
1005
+ if selected_name:
1006
+ obj = None
1007
+ for o in scene_manager.objects:
1008
+ if o['description'] == selected_name:
1009
+ obj = o
1010
+ break
1011
+
1012
+ if obj is not None:
1013
+ return (
1014
+ cv_img,
1015
+ get_cuboid_list_html(),
1016
+ gr.update(value=obj['position'][0]),
1017
+ gr.update(value=obj['position'][1]),
1018
+ gr.update(value=obj['position'][2]),
1019
+ gr.update(value=obj['azimuth']),
1020
+ gr.update(value=obj['size'][0]),
1021
+ gr.update(value=obj['size'][1]),
1022
+ gr.update(value=obj['size'][2])
1023
+ )
1024
+
1025
+ # No object selected or object not found
1026
+ return (
1027
+ cv_img,
1028
+ get_cuboid_list_html(),
1029
+ gr.update(),
1030
+ gr.update(),
1031
+ gr.update(),
1032
+ gr.update(),
1033
+ gr.update(),
1034
+ gr.update(),
1035
+ gr.update()
1036
+ )
1037
+
1038
+
1039
+ def save_scene_event():
1040
+ """Save the current scene to a pkl file"""
1041
+ success, filepath, error = scene_manager.save_scene_to_pkl()
1042
+
1043
+ if success:
1044
+ return f"✅ Scene saved successfully to: {filepath}\n📋 Saved parameters: {scene_manager.inference_params}"
1045
+ else:
1046
+ return f"❌ Failed to save scene: {error}"
1047
+
1048
+
1049
+ def load_scene_event(filepath):
1050
+ """Load a scene from a pkl file and restore all parameters"""
1051
+ if not filepath.strip():
1052
+ return (
1053
+ "⚠️ Please enter a file path",
1054
+ gr.update(),
1055
+ gr.update(),
1056
+ gr.update(),
1057
+ gr.update(),
1058
+ gr.update(),
1059
+ gr.update(),
1060
+ gr.update(), # surrounding_prompt
1061
+ gr.update(), # checkpoint
1062
+ gr.update(), # height
1063
+ gr.update(), # width
1064
+ gr.update(), # seed
1065
+ gr.update(), # guidance
1066
+ gr.update() # steps
1067
+ )
1068
+
1069
+ success, num_objects, error = scene_manager.load_scene_from_pkl(filepath)
1070
+
1071
+ if success:
1072
+ # Re-render the scene
1073
+ cv_img = scene_manager.render_scene()
1074
+
1075
+ # Update UI components
1076
+ choices = [f"{obj['description']}" for obj in scene_manager.objects]
1077
+
1078
+ params_msg = f"✅ Scene loaded: {num_objects} objects\n📋 Restored parameters: {scene_manager.inference_params}"
1079
+
1080
+ return (
1081
+ params_msg,
1082
+ cv_img,
1083
+ get_cuboid_list_html(),
1084
+ gr.update(choices=choices, value=None),
1085
+ gr.update(visible=False),
1086
+ gr.update(value=scene_manager.camera_elevation),
1087
+ gr.update(value=scene_manager.camera_lens),
1088
+ gr.update(value=scene_manager.surrounding_prompt),
1089
+ gr.update(value=scene_manager.inference_params['checkpoint']),
1090
+ gr.update(value=scene_manager.inference_params['height']),
1091
+ gr.update(value=scene_manager.inference_params['width']),
1092
+ gr.update(value=scene_manager.inference_params['seed']),
1093
+ gr.update(value=scene_manager.inference_params['guidance_scale']),
1094
+ gr.update(value=scene_manager.inference_params['num_inference_steps'])
1095
+ )
1096
+ else:
1097
+ return (
1098
+ f"❌ {error}",
1099
+ gr.update(),
1100
+ gr.update(),
1101
+ gr.update(),
1102
+ gr.update(),
1103
+ gr.update(),
1104
+ gr.update(),
1105
+ gr.update(),
1106
+ gr.update(),
1107
+ gr.update(),
1108
+ gr.update(),
1109
+ gr.update(),
1110
+ gr.update(),
1111
+ gr.update()
1112
+ )
1113
+
1114
+
1115
+ # --- Gradio UI Layout ---
1116
+ with gr.Blocks(
1117
+ theme=gr.themes.Soft(
1118
+ primary_hue="green",
1119
+ secondary_hue="gray",
1120
+ neutral_hue="gray"
1121
+ ),
1122
+ css="""
1123
+ .gradio-container {
1124
+ background: linear-gradient(135deg, #0d1117 0%, #1a3d2e 50%, #000000 100%) !important;
1125
+ color: #ffffff !important;
1126
+ }
1127
+ .block {
1128
+ background: rgba(15, 36, 25, 0.8) !important;
1129
+ border: 1px solid #2d5a41 !important;
1130
+ border-radius: 8px !important;
1131
+ }
1132
+ .form {
1133
+ background: rgba(15, 36, 25, 0.6) !important;
1134
+ }
1135
+ h1, h2, h3, h4, h5, h6 {
1136
+ color: #ffffff !important;
1137
+ }
1138
+ .markdown {
1139
+ color: #e6e6e6 !important;
1140
+ }
1141
+ label {
1142
+ color: #cccccc !important;
1143
+ }
1144
+ .gr-button {
1145
+ background: linear-gradient(135deg, #2d5a41, #3d6a51) !important;
1146
+ border: 1px solid #4a7c59 !important;
1147
+ color: #ffffff !important;
1148
+ }
1149
+ .gr-button:hover {
1150
+ background: linear-gradient(135deg, #3d6a51, #4a7c59) !important;
1151
+ }
1152
+ .gr-input, .gr-textbox, .gr-dropdown {
1153
+ background: rgba(15, 36, 25, 0.8) !important;
1154
+ border: 1px solid #2d5a41 !important;
1155
+ color: #ffffff !important;
1156
+ }
1157
+ .gr-input:focus, .gr-textbox:focus {
1158
+ border-color: #4a7c59 !important;
1159
+ background: rgba(26, 61, 46, 0.8) !important;
1160
+ }
1161
+ .gr-slider input[type="range"] {
1162
+ background: #2d5a41 !important;
1163
+ }
1164
+ .gr-slider input[type="range"]::-webkit-slider-thumb {
1165
+ background: #4a7c59 !important;
1166
+ }
1167
+ .gr-radio label {
1168
+ color: #cccccc !important;
1169
+ }
1170
+ .gr-panel {
1171
+ background: rgba(15, 36, 25, 0.6) !important;
1172
+ border: 1px solid #2d5a41 !important;
1173
+ }
1174
+ """
1175
+ ) as demo:
1176
+ gr.Markdown("# [CVPR-2026] 3D Aware Occlusion Control in Text-to-Image Generation 🏞️🧱")
1177
+ # TOP ROW
1178
+ with gr.Row():
1179
+ # TOP LEFT - Edit Properties
1180
+ with gr.Column(scale=1):
1181
+ # Add description textbox at the top
1182
+ # with gr.Column(visible=False) as editor_section:
1183
+ # gr.Markdown("## ✏️ Edit Properties")
1184
+
1185
+ # delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm")
1186
+
1187
+ # with gr.Row():
1188
+ # edit_x = gr.Slider(-10, 10, value=0, step=0.1, label="X")
1189
+ # edit_y = gr.Slider(-10, 10, value=0, step=0.1, label="Y")
1190
+ # edit_z = gr.Slider(0, 10, value=1, step=0.1, label="Z")
1191
+
1192
+ # edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)")
1193
+
1194
+ # with gr.Row():
1195
+ # edit_width = gr.Slider(0.1, 5, value=1, step=0.1, label="Width")
1196
+ # edit_depth = gr.Slider(0.1, 5, value=1, step=0.1, label="Depth")
1197
+ # edit_height = gr.Slider(0.1, 5, value=1, step=0.1, label="Height")
1198
+ with gr.Column(visible=False) as editor_section:
1199
+ gr.Markdown("## ✏️ Edit Properties")
1200
+
1201
+ edit_description = gr.Textbox(
1202
+ label="Description",
1203
+ placeholder="Enter object description",
1204
+ info="Description cannot be empty"
1205
+ )
1206
+
1207
+ delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm")
1208
+
1209
+ with gr.Row():
1210
+ edit_x = gr.Slider(-10, 10, value=0, step=0.1, label="X")
1211
+ edit_y = gr.Slider(-10, 10, value=0, step=0.1, label="Y")
1212
+ edit_z = gr.Slider(0, 10, value=1, step=0.1, label="Z")
1213
+
1214
+ edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)")
1215
+
1216
+ with gr.Row():
1217
+ edit_width = gr.Slider(0.1, 5, value=1, step=0.1, label="Width")
1218
+ edit_depth = gr.Slider(0.1, 5, value=1, step=0.1, label="Depth")
1219
+ edit_height = gr.Slider(0.1, 5, value=1, step=0.1, label="Height")
1220
+
1221
+ # Add scale slider
1222
+ edit_scale = gr.Slider(
1223
+ 0.1, 3.0, value=1.0, step=0.1,
1224
+ label="Scale",
1225
+ info="Multiplier for all dimensions (resets to 1.0 after update)"
1226
+ )
1227
+
1228
+ # Add the Update Scene button
1229
+ update_scene_btn = gr.Button("🔄 Update Scene", variant="primary", size="sm")
1230
+
1231
+ # TOP MIDDLE - Camera View
1232
+ with gr.Column(scale=1):
1233
+ gr.Markdown("## 📷 Camera View")
1234
+ cv_image_output = gr.Image(label="Camera View", height=400)
1235
+
1236
+ # TOP RIGHT - Generated Image
1237
+ with gr.Column(scale=1):
1238
+ gr.Markdown("## 🎨 Generated Image")
1239
+ generated_image_output = gr.Image(label="Generated Image", height=400)
1240
+
1241
+ # BOTTOM ROW
1242
+ with gr.Row():
1243
+ # BOTTOM LEFT - Cuboid List and Selection
1244
+ with gr.Column(scale=1):
1245
+ gr.Markdown("## 📦 Cuboids")
1246
+ cuboid_list_html = gr.HTML(get_cuboid_list_html())
1247
+
1248
+ gr.Markdown("### Select Cuboid to Edit")
1249
+ cuboid_radio = gr.Radio(choices=[], label="", visible=True)
1250
+
1251
+ # BOTTOM RIGHT - Camera Controls and Add New Cuboid
1252
+ with gr.Column(scale=2):
1253
+ with gr.Row():
1254
+ with gr.Column():
1255
+ gr.Markdown("## Global Controls")
1256
+ camera_elevation_slider = gr.Slider(0, 90, value=30, label="Camera Elevation (degrees)")
1257
+ camera_lens_slider = gr.Slider(10, 200, value=50, label="Camera Lens (mm)")
1258
+
1259
+ # Add surrounding prompt textbox
1260
+ surrounding_prompt_input = gr.Textbox(
1261
+ placeholder="e.g., in a forest, in a city, on a beach",
1262
+ label="Surrounding Prompt",
1263
+ info="Describe the surrounding environment"
1264
+ )
1265
+
1266
+ gr.Markdown("## 🔧 Scene Tools")
1267
+ harmonize_btn = gr.Button("⚖️ Harmonize Scales", variant="secondary")
1268
+
1269
+ # Save/Load Section
1270
+ gr.Markdown("## 💾 Save/Load Scene")
1271
+ with gr.Row():
1272
+ save_scene_btn = gr.Button("💾 Save Scene", variant="secondary")
1273
+ load_scene_btn = gr.Button("📂 Load Scene", variant="secondary")
1274
+
1275
+ load_path_input = gr.Textbox(
1276
+ placeholder="/path/to/scene.pkl",
1277
+ label="Load Scene Path",
1278
+ info="Enter path to pkl file to load"
1279
+ )
1280
+ save_load_status = gr.Markdown("")
1281
+
1282
+ with gr.Column():
1283
+ gr.Markdown("## ➕ Add New Cuboid")
1284
+ add_cuboid_description_input = gr.Textbox(placeholder="Enter cuboid description", label="Description")
1285
+ asset_type_dropdown = gr.Dropdown(
1286
+ choices=scene_manager.get_asset_type_choices(),
1287
+ value="Custom",
1288
+ label="Type",
1289
+ info="Select asset type to load dimensions, or choose Custom"
1290
+ )
1291
+ add_cuboid_btn = gr.Button("Add Cuboid", variant="primary")
1292
+ generate_btn = gr.Button("🎨 Generate Image", variant="primary")
1293
+
1294
+ # Add checkpoint dropdown
1295
+ checkpoint_dropdown = gr.Dropdown(
1296
+ choices=CHECKPOINT_NAMES,
1297
+ value=CHECKPOINT_NAMES[0] if CHECKPOINT_NAMES else None,
1298
+ label="Checkpoint",
1299
+ info="Select model checkpoint for generation"
1300
+ )
1301
+
1302
+ # Inference Parameters
1303
+ gr.Markdown("### Inference Parameters")
1304
+
1305
+ with gr.Row():
1306
+ inference_height = gr.Slider(
1307
+ minimum=256, maximum=1024, value=512, step=64,
1308
+ label="Height"
1309
+ )
1310
+ inference_width = gr.Slider(
1311
+ minimum=256, maximum=1024, value=512, step=64,
1312
+ label="Width"
1313
+ )
1314
+
1315
+ inference_seed = gr.Number(
1316
+ value=42, label="Random Seed", precision=0
1317
+ )
1318
+
1319
+ inference_guidance = gr.Slider(
1320
+ minimum=1.0, maximum=10.0, value=3.5, step=0.5,
1321
+ label="Guidance Scale"
1322
+ )
1323
+
1324
+ inference_steps = gr.Slider(
1325
+ minimum=10, maximum=50, value=25, step=1,
1326
+ label="Inference Steps"
1327
+ )
1328
+
1329
+ # Event Handlers
1330
+ def add_cuboid_with_auto_update(description_input, asset_type, camera_elevation, camera_lens):
1331
+ """Add cuboid and auto-update scene"""
1332
+ result = add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens)
1333
+ return result
1334
+
1335
+ # Update add_cuboid_btn.click event handler (around line 850):
1336
+ add_cuboid_btn.click(
1337
+ add_cuboid_with_auto_update,
1338
+ inputs=[add_cuboid_description_input, asset_type_dropdown, camera_elevation_slider, camera_lens_slider],
1339
+ outputs=[
1340
+ add_cuboid_description_input,
1341
+ asset_type_dropdown,
1342
+ cv_image_output,
1343
+ cuboid_list_html,
1344
+ cuboid_radio,
1345
+ editor_section,
1346
+ edit_description,
1347
+ edit_x, edit_y, edit_z,
1348
+ edit_azimuth,
1349
+ edit_width, edit_depth, edit_height,
1350
+ edit_scale # Add this
1351
+ ]
1352
+ )
1353
+
1354
+ # Update the cuboid_radio.change event handler (around line 860):
1355
+ cuboid_radio.change(
1356
+ select_cuboid_event,
1357
+ inputs=[cuboid_radio],
1358
+ outputs=[
1359
+ editor_section,
1360
+ edit_description,
1361
+ edit_x, edit_y, edit_z,
1362
+ edit_azimuth,
1363
+ edit_width, edit_depth, edit_height,
1364
+ edit_scale # Add this
1365
+ ]
1366
+ )
1367
+
1368
+ delete_btn.click(
1369
+ delete_selected_cuboid,
1370
+ inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider],
1371
+ outputs=[cv_image_output, cuboid_list_html, cuboid_radio, editor_section]
1372
+ )
1373
+
1374
+ # Save/Load handlers
1375
+ save_scene_btn.click(
1376
+ save_scene_event,
1377
+ inputs=[],
1378
+ outputs=[save_load_status]
1379
+ )
1380
+
1381
+ load_scene_btn.click(
1382
+ load_scene_event,
1383
+ inputs=[load_path_input],
1384
+ outputs=[
1385
+ save_load_status,
1386
+ cv_image_output,
1387
+ cuboid_list_html,
1388
+ cuboid_radio,
1389
+ editor_section,
1390
+ camera_elevation_slider,
1391
+ camera_lens_slider,
1392
+ surrounding_prompt_input,
1393
+ checkpoint_dropdown,
1394
+ inference_height,
1395
+ inference_width,
1396
+ inference_seed,
1397
+ inference_guidance,
1398
+ inference_steps
1399
+ ]
1400
+ )
1401
+
1402
+ # Auto-update scene when sliders change
1403
+ # for slider in [edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height]:
1404
+ # slider.change(
1405
+ # update_cuboid_event,
1406
+ # inputs=[
1407
+ # cuboid_radio,
1408
+ # camera_elevation_slider,
1409
+ # camera_lens_slider,
1410
+ # edit_x, edit_y, edit_z,
1411
+ # edit_azimuth,
1412
+ # edit_width, edit_depth, edit_height
1413
+ # ],
1414
+ # outputs=[cuboid_list_html, cv_image_output]
1415
+ # )
1416
+ # Update the update_scene_btn.click event handler (around line 920):
1417
+ update_scene_btn.click(
1418
+ update_cuboid_event,
1419
+ inputs=[
1420
+ cuboid_radio,
1421
+ camera_elevation_slider,
1422
+ camera_lens_slider,
1423
+ edit_description,
1424
+ edit_x, edit_y, edit_z,
1425
+ edit_azimuth,
1426
+ edit_width, edit_depth, edit_height,
1427
+ edit_scale # Add this
1428
+ ],
1429
+ outputs=[
1430
+ cuboid_list_html,
1431
+ cv_image_output,
1432
+ cuboid_radio,
1433
+ edit_width, # Add this
1434
+ edit_depth, # Add this
1435
+ edit_height, # Add this
1436
+ edit_scale # Add this (to reset to 1.0)
1437
+ ]
1438
+ )
1439
+
1440
+
1441
+ # Update generate button click handler
1442
+ generate_btn.click(
1443
+ generate_image_event,
1444
+ inputs=[
1445
+ camera_elevation_slider,
1446
+ camera_lens_slider,
1447
+ surrounding_prompt_input,
1448
+ checkpoint_dropdown,
1449
+ inference_height,
1450
+ inference_width,
1451
+ inference_seed,
1452
+ inference_guidance,
1453
+ inference_steps
1454
+ ],
1455
+ outputs=[save_load_status, cv_image_output, generated_image_output]
1456
+ )
1457
+
1458
+
1459
+ harmonize_btn.click(
1460
+ harmonize_event,
1461
+ inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider],
1462
+ outputs=[
1463
+ cv_image_output,
1464
+ cuboid_list_html,
1465
+ edit_x, edit_y, edit_z,
1466
+ edit_azimuth,
1467
+ edit_width, edit_depth, edit_height
1468
+ ]
1469
+ )
1470
+
1471
+ # Camera controls
1472
+ for control in [camera_elevation_slider, camera_lens_slider]:
1473
+ control.change(
1474
+ camera_change_event,
1475
+ inputs=[camera_elevation_slider, camera_lens_slider],
1476
+ outputs=[cv_image_output]
1477
+ )
1478
+
1479
+ # Surrounding prompt control
1480
+ surrounding_prompt_input.change(
1481
+ surrounding_prompt_change_event,
1482
+ inputs=[surrounding_prompt_input],
1483
+ outputs=[]
1484
+ )
1485
+
1486
+
1487
+ # Initial render
1488
+ def initial_render():
1489
+ cv_img = scene_manager.render_scene()
1490
+ gen_img = Image.new('RGB', (512, 512), color='white')
1491
+ return cv_img, gen_img
1492
+
1493
+ demo.load(
1494
+ initial_render,
1495
+ outputs=[cv_image_output, generated_image_output]
1496
+ )
1497
+
1498
+
1499
+ if __name__ == "__main__":
1500
+ import os
1501
+ os.system("./launch_blender_backend.sh &")
1502
+ # Initialize inference engine (load model once at startup)
1503
+ initialize_inference_engine(base_model_path="black-forest-labs/FLUX.1-dev")
1504
+ demo.launch(share=True)
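
All of the handlers above follow the same Gradio Blocks pattern: each component event (.click, .change, demo.load) registers a Python callback together with explicit lists of input and output components, and the callback returns one value per output. A minimal, self-contained sketch of that pattern; the component and function names here are illustrative, not the ones defined in app.py:

import gradio as gr

def render_preview(elevation, lens):
    # stand-in for a scene re-render; returns a status string instead of an image
    return f"rendered preview at elevation={elevation:.1f}, lens={lens:.0f}mm"

with gr.Blocks() as sketch_demo:
    elevation = gr.Slider(minimum=0, maximum=90, value=24, label="Camera Elevation")
    lens = gr.Slider(minimum=20, maximum=120, value=70, label="Camera Lens (mm)")
    status = gr.Textbox(label="Status")
    render_btn = gr.Button("Render")

    # same callback / inputs / outputs wiring as the handlers above
    render_btn.click(render_preview, inputs=[elevation, lens], outputs=[status])
    for control in [elevation, lens]:
        control.change(render_preview, inputs=[elevation, lens], outputs=[status])

    # initial render on page load, mirroring demo.load(initial_render, ...)
    sketch_demo.load(render_preview, inputs=[elevation, lens], outputs=[status])

if __name__ == "__main__":
    sketch_demo.launch()
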
gradio_app/asset_dimensions.pkl ADDED
Binary file (1.75 kB). View file
 
gradio_app/blender_backend.py ADDED
@@ -0,0 +1,1521 @@
1
+ import bpy
2
+ import bpy_extras
3
+ import numpy as np
4
+ import bmesh
5
+ import copy
6
+ import PIL
7
+ from PIL import Image
8
+ import matplotlib.pyplot as plt
9
+ import colorsys
10
+ import os
11
+ import os.path as osp
12
+ import shutil
13
+ import sys
14
+ import math
15
+ import mathutils
16
+ import random
17
+ import cv2
18
+ from object_scales import scales
19
+ import matplotlib.colors as mcolors
20
+ import torch
21
+
22
+ def map_point_to_rgb(x, y, z):
23
+ """
24
+ Map (x, y, z) inside the frustum to an RGB color with continuity and variation.
25
+ """
26
+ # Frustum boundaries
27
+ X_MIN, X_MAX = -12.0, -1.0
28
+ Y_MIN_AT_XMIN, Y_MAX_AT_XMIN = -4.5, 4.5
29
+ Y_MIN_AT_XMAX, Y_MAX_AT_XMAX = -0.5, 0.5
30
+ Z_MIN, Z_MAX = 0.0, 2.50
31
+ # Normalize x to [0, 1]
32
+ x_norm = (x - X_MIN) / (X_MAX - X_MIN)
33
+ x_norm = np.clip(x_norm, 0, 1)
34
+
35
+ # Compute current Y bounds at given x using linear interpolation
36
+ y_min = Y_MIN_AT_XMIN + x_norm * (Y_MIN_AT_XMAX - Y_MIN_AT_XMIN)
37
+ y_max = Y_MAX_AT_XMIN + x_norm * (Y_MAX_AT_XMAX - Y_MAX_AT_XMIN)
38
+
39
+ # Normalize y to [0, 1] within current bounds
40
+ if y_max != y_min:
41
+ y_norm = (y - y_min) / (y_max - y_min)
42
+ else:
43
+ y_norm = 0.5
44
+ y_norm = np.clip(y_norm, 0, 1)
45
+
46
+ z_norm = (z - Z_MIN) / (Z_MAX - Z_MIN)
47
+
48
+ # Color mapping: more variation along x
49
+ r = x_norm
50
+ # g = 0.5 * y_norm + 0.25 * x_norm
51
+ g = y_norm
52
+ b = z_norm
53
+
54
+ return (r, g, b)
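
A quick sanity check of the mapping at two frustum extremes, using map_point_to_rgb as defined above; the expected values follow directly from X_MIN/X_MAX, the interpolated Y bounds, and Z_MIN/Z_MAX:

# sanity check of map_point_to_rgb at two frustum extremes (illustrative only)
print(map_point_to_rgb(-1.0, 0.5, 0.0))    # near corner at the upper Y bound, on the ground: r = 1, g = 1, b = 0
print(map_point_to_rgb(-12.0, -4.5, 2.5))  # far corner at the lower Y bound, at Z_MAX:       r = 0, g = 0, b = 1
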
55
+
56
+
57
+ def set_world_color(color=(0.1, 0.1, 0.1)):
58
+ """
59
+ Sets the world background color to match the grid floor.
60
+
61
+ Args:
62
+ color (tuple): RGB color values (0-1 range)
63
+ """
64
+ scene = bpy.context.scene
65
+
66
+ # Create a new world if it doesn't exist
67
+ if scene.world is None:
68
+ world = bpy.data.worlds.new(name="World")
69
+ scene.world = world
70
+ else:
71
+ world = scene.world
72
+
73
+ # Enable use of nodes for the world
74
+ world.use_nodes = True
75
+
76
+ # Get the node tree
77
+ nodes = world.node_tree.nodes
78
+ links = world.node_tree.links
79
+
80
+ # Find or create the Background node
81
+ background_node = None
82
+ for node in nodes:
83
+ if node.type == 'BACKGROUND':
84
+ background_node = node
85
+ break
86
+
87
+ if background_node is None:
88
+ # Clear existing nodes and create new ones
89
+ nodes.clear()
90
+ background_node = nodes.new(type='ShaderNodeBackground')
91
+ output_node = nodes.new(type='ShaderNodeOutputWorld')
92
+ links.new(background_node.outputs['Background'], output_node.inputs['Surface'])
93
+
94
+ # Set the background color
95
+ background_node.inputs['Color'].default_value = (*color, 1.0)
96
+ background_node.inputs['Strength'].default_value = 1.0
97
+
98
+
99
+ COLORS = [
100
+ (1.0, 0.0, 0.0), # Red
101
+ (0.0, 0.8, 0.2), # Green
102
+ (0.0, 0.0, 1.0), # Blue
103
+ (1.0, 1.0, 0.0), # Yellow
104
+ (0.0, 1.0, 1.0), # Cyan
105
+ (1.0, 0.0, 1.0), # Magenta
106
+ (1.0, 0.6, 0.0), # Orange
107
+ (0.6, 0.0, 0.8), # Purple
108
+ (0.0, 0.4, 0.0), # Dark Green
109
+ (0.8, 0.8, 0.8), # Light Gray
110
+ (0.2, 0.2, 0.2) # Dark Gray
111
+ ]
112
+
113
+ def do_z_pass(seg_masks: torch.Tensor, dist_values: torch.Tensor) -> torch.Tensor:
114
+ """
115
+ Performs a z-pass on segmentation masks based on distance values to the camera.
116
+ For each pixel, if multiple subjects' masks are active, only the one with the smallest distance (closest) remains active.
117
+
118
+ Args:
119
+ seg_masks (torch.Tensor): Binary segmentation masks of shape (n_subjects, h, w) with dtype uint8.
120
+ dist_values (torch.Tensor): Distance values for each subject of shape (n_subjects,).
121
+
122
+ Returns:
123
+ torch.Tensor: Processed segmentation masks after z-pass, same shape and dtype as seg_masks.
124
+ """
125
+ # Ensure tensors are on the same device
126
+ device = seg_masks.device
127
+
128
+ # Get dimensions
129
+ n_subjects, h, w = seg_masks.shape
130
+
131
+ # Reshape distance values for broadcasting across spatial dimensions
132
+ dist_values_expanded = dist_values.view(n_subjects, 1, 1)
133
+
134
+ # Create a tensor where active pixels have their distance, others have a high value (1e10)
135
+ masked_dist = torch.where(seg_masks.bool(), dist_values_expanded, torch.tensor(1e10, device=device))
136
+
137
+ # Find the subject index with the minimum distance for each pixel (shape (h, w))
138
+ closest_indices = torch.argmin(masked_dist, dim=0)
139
+
140
+ # Initialize output tensor with zeros
141
+ output = torch.zeros_like(seg_masks)
142
+
143
+ # Scatter 1s into the output tensor where the closest subject's indices are
144
+ # closest_indices.unsqueeze(0) adds a dummy dimension to match scatter's expected shape
145
+ output.scatter_(
146
+ dim=0,
147
+ index=closest_indices.unsqueeze(0),
148
+ src=torch.ones_like(closest_indices.unsqueeze(0), dtype=output.dtype)
149
+ )
150
+
151
+ # Zero out any positions where the original mask was inactive
152
+ output = output * seg_masks
153
+
154
+ return output
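
A toy example of the z-pass behaviour, using do_z_pass as defined above: two 2x2 masks overlap at a single pixel, and only the closer subject keeps it.

# toy z-pass check: two overlapping 2x2 masks, subject 1 is closer to the camera
masks = torch.tensor([[[1, 1],
                       [0, 0]],
                      [[1, 0],
                       [1, 0]]], dtype=torch.uint8)
dists = torch.tensor([5.0, 2.0])
resolved = do_z_pass(masks, dists)
# pixel (0, 0) was claimed by both subjects; after the z-pass only subject 1 (dist 2.0)
# keeps it, and every other active pixel is unchanged
print(resolved)
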
155
+
156
+
157
+ def get_image_to_world_matrix(camera_obj, render):
158
+ """
159
+ Calculates the matrix to transform a point from clip space to world space.
160
+
161
+ Args:
162
+ camera_obj (bpy.types.Object): The camera object.
163
+ render (bpy.types.RenderSettings): The scene's render settings.
164
+
165
+ Returns:
166
+ mathutils.Matrix: The 4x4 matrix for clip-to-world transformation.
167
+ """
168
+ # Get the camera's view matrix (world to camera)
169
+ view_matrix = camera_obj.matrix_world.inverted()
170
+
171
+ # Get the camera's projection matrix
172
+ # This matrix depends on the render resolution, so it's best to calculate it
173
+ # for the specific dimensions you're using.
174
+ projection_matrix = camera_obj.calc_matrix_camera(
175
+ bpy.context.evaluated_depsgraph_get(),
176
+ x=render.resolution_x,
177
+ y=render.resolution_y,
178
+ scale_x=render.pixel_aspect_x,
179
+ scale_y=render.pixel_aspect_y,
180
+ )
181
+
182
+ # Combine and invert to get the clip-to-world matrix
183
+ clip_to_world_matrix = (projection_matrix @ view_matrix).inverted()
184
+
185
+ return clip_to_world_matrix
186
+
187
+
188
+ def unproject_image_point(camera_obj, image_coord, depth):
189
+ """
190
+ Transforms a 2D image coordinate with a depth value into a 3D world coordinate.
191
+
192
+ Args:
193
+ camera_obj (bpy.types.Object): The camera used for rendering.
194
+ image_coord (tuple or list): The (x, y) pixel coordinate.
195
+ depth (float): The depth value at that coordinate (from the Z-pass).
196
+
197
+ Returns:
198
+ mathutils.Vector: The calculated 3D point in world space.
199
+ """
200
+ render = bpy.context.scene.render
201
+
202
+ # 1. Get the clip-to-world transformation matrix
203
+ clip_to_world_mat = get_image_to_world_matrix(camera_obj, render)
204
+
205
+ # 2. Convert image coordinates to Normalized Device Coordinates (NDC)
206
+ # (from [0, res] to [-1, 1])
207
+ ndc_x = (image_coord[0] / render.resolution_x) * 2 - 1
208
+ ndc_y = (image_coord[1] / render.resolution_y) * 2 - 1
209
+
210
+ # In Blender's Z-pass, the depth value is the distance from the camera's plane.
211
+ # We can use Blender's utility function to find the 3D vector for the pixel.
212
+ # This vector is in camera space and points from the camera towards the pixel.
213
+ view_vector = bpy_extras.view3d_utils.region_2d_to_vector_3d(
214
+ bpy.context.region,
215
+ bpy.context.space_data.region_3d,
216
+ image_coord
217
+ )
218
+
219
+ # 4. Project the view vector into world space and scale by depth
220
+ # The view_vector is normalized and in camera space.
221
+ # To get the point in world space, we transform the vector by the camera's
222
+ # world matrix (not the view matrix).
223
+ world_vector = camera_obj.matrix_world.to_3x3() @ view_vector
224
+
225
+ # The depth from the Z-pass is the distance along the camera's local Z-axis.
226
+ # To find the true distance along the ray, we must account for the angle.
227
+ # We can calculate the scaling factor 't' for our world_vector.
228
+ camera_forward = -camera_obj.matrix_world.col[2].xyz
229
+ t = depth / world_vector.dot(camera_forward)
230
+
231
+ # 5. Calculate the final world coordinate
232
+ # Start from the camera's location and move along the ray.
233
+ world_point = camera_obj.matrix_world.translation + (t * world_vector)
234
+
235
+ return world_point
236
+
237
+ # --- Example Usage ---
238
+ # This example assumes you have an active scene with a camera and have rendered an image.
239
+ # You would typically run this after rendering, where you can access the depth map.
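
A hedged sketch of that usage follows. The names below are illustrative, and note that region_2d_to_vector_3d needs a 3D-view context, so this only works when run from Blender's UI rather than in background mode:

# illustrative usage of unproject_image_point (requires a 3D-view context in Blender's UI)
cam = bpy.context.scene.camera
pixel = (512, 512)        # image coordinate in the render
depth_at_pixel = 4.2      # illustrative value read from the rendered Z-pass
world_point = unproject_image_point(cam, pixel, depth_at_pixel)
print("world-space point:", world_point)
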
240
+
241
+
242
+ def multiply_random_color(obj, random_color):
243
+ """
244
+ Multiplies the existing base color of an object's materials
245
+ with a random color.
246
+ """
247
+ for material_slot in obj.material_slots:
248
+ if material_slot.material:
249
+ material = material_slot.material
250
+ if material.use_nodes:
251
+ nodes = material.node_tree.nodes
252
+ links = material.node_tree.links
253
+
254
+ # Find the Principled BSDF node
255
+ principled_bsdf = nodes.get("Principled BSDF")
256
+ if not principled_bsdf:
257
+ continue
258
+
259
+ # Get the node connected to the Base Color input
260
+ base_color_input = principled_bsdf.inputs.get("Base Color")
261
+ if not base_color_input:
262
+ continue
263
+
264
+ # Create a MixRGB node and set it to multiply
265
+ mix_rgb_node = nodes.new(type='ShaderNodeMixRGB')
266
+ mix_rgb_node.blend_type = 'MULTIPLY'
267
+ mix_rgb_node.inputs['Fac'].default_value = 2.00
268
+ mix_rgb_node.location = (principled_bsdf.location.x - 200, principled_bsdf.location.y)
269
+
270
+ # Set the second color to a random color
271
+ mix_rgb_node.inputs['Color2'].default_value = random_color
272
+
273
+ # If a node is already connected to the Base Color,
274
+ # connect it to the first color input of the MixRGB node.
275
+ if base_color_input.is_linked:
276
+ original_link = base_color_input.links[0]
277
+ original_node = original_link.from_node
278
+ original_socket = original_link.from_socket
279
+ links.new(original_node.outputs[original_socket.name], mix_rgb_node.inputs['Color1'])
280
+ links.remove(original_link)
281
+ else:
282
+ # If no node is connected, use the original default color
283
+ original_color = base_color_input.default_value
284
+ mix_rgb_node.inputs['Color1'].default_value = original_color
285
+
286
+ # Connect the MixRGB node to the Principled BSDF's Base Color
287
+ links.new(mix_rgb_node.outputs['Color'], base_color_input)
288
+
289
+
290
+ OUTPUT_DIR = "four_subject_renders"
291
+ OBJECTS_DIR = "obja_2units_along_y/glbs"
292
+
293
+ NUM_AZIMUTH_BINS = 1
294
+ NUM_LIGHTS = 1
295
+
296
+ MAX_TRIES = 25
297
+
298
+ IMG_DIM = 1024
299
+
300
+ MASK_RES = 50
301
+
302
+ THRESHOLD_LOWER = 150
303
+ THRESHOLD_UPPER = 768
304
+
305
+ ROOT_OBJS_DIR = "/ssd_scratch/vaibhav.agrawal/a-bev-of-the-latents/glb_files/"
306
+
307
+ OBJ_SIDE_LENGTH = 2.0
308
+
309
+ def calculate_iou(box1, box2):
310
+ """
311
+ Calculate the Intersection over Union (IoU) of two bounding boxes.
312
+
313
+ Parameters:
314
+ box1, box2: Each box is defined by a tuple (x1, y1, x2, y2)
315
+ where (x1, y1) is the top-left corner and (x2, y2) is the bottom-right corner.
316
+
317
+ Returns:
318
+ float: IoU value
319
+ """
320
+ # Unpack coordinates
321
+ x1_min, y1_min, x1_max, y1_max = box1
322
+ x2_min, y2_min, x2_max, y2_max = box2
323
+
324
+ # Determine the coordinates of the intersection rectangle
325
+ inter_x_min = max(x1_min, x2_min)
326
+ inter_y_min = max(y1_min, y2_min)
327
+ inter_x_max = min(x1_max, x2_max)
328
+ inter_y_max = min(y1_max, y2_max)
329
+
330
+ # Compute the area of intersection rectangle
331
+ inter_width = max(0, inter_x_max - inter_x_min)
332
+ inter_height = max(0, inter_y_max - inter_y_min)
333
+ intersection_area = inter_width * inter_height
334
+
335
+ # Compute the area of both bounding boxes
336
+ box1_area = (x1_max - x1_min) * (y1_max - y1_min)
337
+ box2_area = (x2_max - x2_min) * (y2_max - y2_min)
338
+
339
+ # Compute the area of the union
340
+ union_area = box1_area + box2_area - intersection_area
341
+
342
+ # Compute IoU
343
+ iou = intersection_area / union_area if union_area > 0 else 0
344
+
345
+ return iou
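
A worked example for calculate_iou: two 10x10 boxes that overlap in a 5x10 strip give intersection 50 and union 100 + 100 - 50 = 150, so the IoU is 1/3.

# worked IoU example: two 10x10 boxes overlapping in a 5x10 strip
box_a = (0, 0, 10, 10)
box_b = (5, 0, 15, 10)
print(calculate_iou(box_a, box_b))  # 50 / 150 = 0.333...
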
346
+
347
+
348
+ def get_object_2d_bbox(empty_obj, scene):
349
+ """
350
+ Get the 2D bounding box coordinates of an object in the rendered image.
351
+
352
+ Args:
353
+ empty_obj (bpy.types.Object): The empty object containing the child mesh objects.
354
+ scene (bpy.types.Scene): The current scene.
355
+
356
+ Returns:
357
+ tuple: A tuple containing the 2D bounding box coordinates in pixel space
358
+ in the format (min_x, min_y, max_x, max_y).
359
+ """
360
+ # Get the render settings
361
+ render = scene.render
362
+ res_x = render.resolution_x
363
+ res_y = render.resolution_y
364
+
365
+ # Initialize the bounding box coordinates
366
+ min_x, min_y = float('inf'), float('inf')
367
+ max_x, max_y = float('-inf'), float('-inf')
368
+
369
+ depsgraph = bpy.context.evaluated_depsgraph_get()
370
+
371
+ # Iterate through the child mesh objects
372
+ for obj in empty_obj.children:
373
+ if obj.type == 'MESH':
374
+ # Get the bounding box coordinates in world space
375
+ bbox_corners = [obj.matrix_world @ mathutils.Vector(corner) for corner in obj.bound_box]
376
+
377
+ # Transform the bounding box corners to camera space
378
+ for corner in bbox_corners:
379
+ corner_2d = bpy_extras.object_utils.world_to_camera_view(scene, scene.camera, corner)
380
+
381
+ # Scale the coordinates to pixel space
382
+ x = corner_2d.x * res_x
383
+ y = (1 - corner_2d.y) * res_y # Flip Y since Blender renders from bottom to top
384
+
385
+ # Update the bounding box coordinates
386
+ min_x = min(min_x, x)
387
+ min_y = min(min_y, y)
388
+ max_x = max(max_x, x)
389
+ max_y = max(max_y, y)
390
+
391
+ # Return the 2D bounding box coordinates in pixel space
392
+ return (int(min_x), int(min_y), int(max_x), int(max_y))
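
Together with calculate_iou above, this projected bounding box is the kind of quantity a placement loop (bounded by MAX_TRIES) can use to reject overlapping layouts. A minimal sketch under that assumption; empty_a and empty_b are hypothetical placed-asset empties, not objects defined in this file:

# illustrative overlap check between two placed assets (requires a scene camera)
scene = bpy.context.scene
bbox_a = get_object_2d_bbox(empty_a, scene)   # empty_a, empty_b: hypothetical asset empties
bbox_b = get_object_2d_bbox(empty_b, scene)
if calculate_iou(bbox_a, bbox_b) > 0.5:
    print("projected boxes overlap heavily; retry this placement")
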
393
+
394
+ def reset_cameras(scene) -> None:
395
+ """Resets the cameras in the scene to a single default camera."""
396
+ # Delete all existing cameras
397
+ bpy.ops.object.select_all(action="DESELECT")
398
+ bpy.ops.object.select_by_type(type="CAMERA")
399
+ bpy.ops.object.delete()
400
+
401
+ # Create a new camera with default properties
402
+ bpy.ops.object.camera_add()
403
+
404
+ # Get the camera by searching for it (it will be the only camera)
405
+ new_camera = None
406
+ for obj in scene.objects:
407
+ if obj.type == 'CAMERA':
408
+ new_camera = obj
409
+ break
410
+
411
+ new_camera.name = "Camera"
412
+
413
+ # Set the new camera as the active camera for the scene
414
+ scene.camera = new_camera
415
+
416
+
417
+ def add_plane():
418
+ print(f"in add_plane")
419
+
420
+ # Create mesh data
421
+ mesh = bpy.data.meshes.new("Plane")
422
+ backdrop = bpy.data.objects.new("Plane", mesh)
423
+ bpy.context.scene.collection.objects.link(backdrop)
424
+
425
+ # Create plane geometry using bmesh
426
+ bm = bmesh.new()
427
+ bmesh.ops.create_grid(bm, x_segments=1, y_segments=1, size=25.0) # size=25 gives 50x50 plane
428
+ bm.to_mesh(mesh)
429
+ bm.free()
430
+
431
+ # Add material
432
+ mat_backdrop = bpy.data.materials.new(name="WhiteMaterial")
433
+ mat_backdrop.diffuse_color = (0, 0, 0, 1) # Black
434
+ backdrop.data.materials.append(mat_backdrop)
435
+
436
+
437
+ def add_plane_cycles():
438
+ print(f"in add_plane")
439
+
440
+ # Create mesh data
441
+ mesh = bpy.data.meshes.new("Plane")
442
+ backdrop = bpy.data.objects.new("Plane", mesh)
443
+ bpy.context.scene.collection.objects.link(backdrop)
444
+
445
+ # Create plane geometry using bmesh
446
+ bm = bmesh.new()
447
+ bmesh.ops.create_grid(bm, x_segments=1, y_segments=1, size=25.0) # size=25 gives 50x50 plane
448
+ bm.to_mesh(mesh)
449
+ bm.free()
450
+
451
+ # Add material
452
+ mat_backdrop = bpy.data.materials.new(name="WhiteMaterial")
453
+ mat_backdrop.diffuse_color = (0.050, 0.050, 0.050, 1) # Dark gray
454
+ backdrop.data.materials.append(mat_backdrop)
455
+
456
+
457
+ def remove_all_planes():
458
+ # Deselect all objects first
459
+ bpy.ops.object.select_all(action='DESELECT')
460
+
461
+ # Select all plane objects in the scene
462
+ for obj in bpy.data.objects:
463
+ if obj.type == 'MESH' and obj.name.startswith('Plane'):
464
+ obj.select_set(True)
465
+
466
+ # Delete all selected planes
467
+ bpy.ops.object.delete()
468
+
469
+
470
+ def remove_all_lights():
471
+ """Remove all lights from the scene without using operators."""
472
+ lights_to_remove = [obj for obj in bpy.data.objects if obj.type == 'LIGHT']
473
+
474
+ for light in lights_to_remove:
475
+ bpy.data.objects.remove(light, do_unlink=True)
476
+
477
+ # Clean up orphaned light data blocks
478
+ for light_data in bpy.data.lights:
479
+ if light_data.users == 0:
480
+ bpy.data.lights.remove(light_data)
481
+
482
+
483
+ def set_lights_cv(radius, center, num_points, intensity):
484
+ print(f"in set_lights_cv")
485
+ radius = radius + 10.0
486
+ phi = np.random.uniform(-np.pi / 2, np.pi / 2, num_points) # azimuthal angle
487
+ cos_theta = np.random.uniform(0.50, 1.0, num_points) # cos of polar angle
488
+ theta = np.arccos(cos_theta) # polar angle
489
+ x = np.sin(theta) * np.cos(phi)
490
+ y = np.sin(theta) * np.sin(phi)
491
+ z = cos_theta # cos(theta) == z on unit sphere
492
+ # Scale to radius and shift to center
493
+ points = np.stack([x, y, z], axis=1) * radius + center
494
+ for point in points:
495
+ # Track objects before adding light
496
+ before_objs = set(bpy.data.objects)
497
+ bpy.ops.object.light_add(type='POINT', location=point)
498
+ after_objs = set(bpy.data.objects)
499
+
500
+ # Get the newly created light
501
+ diff_objs = after_objs - before_objs
502
+ light = list(diff_objs)[0]
503
+
504
+ light.data.energy = intensity
505
+ light.data.use_shadow = True
506
+ # light.data.shadow_soft_size = 1.0 # Adjust shadow softness if needed
507
+ return points
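
set_lights_cv samples points on the upper half of a sphere of radius radius + 10 around the given center and drops a point light at each one. A quick check of that geometry, assuming it is run inside Blender's Python (it really does add the lights):

# sanity check of the light placement geometry (adds 20 point lights to the scene)
center = np.array([-6.0, 0.0, 0.0])
pts = set_lights_cv(6.0, center, 20, intensity=7000.0)
radii = np.linalg.norm(pts - center, axis=1)
print(radii)            # all approximately 16.0 (= 6.0 + 10.0)
print(pts[:, 2].min())  # at least 8.0, so every light sits well above the ground plane
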
508
+
509
+
510
+ def adjust_color_brightness(rgb_color, factor):
511
+ """
512
+ Adjusts the brightness of an RGB color by a multiplicative factor.
513
+
514
+ Args:
515
+ rgb_color (tuple): The base color as an (R, G, B) or (R, G, B, A) tuple.
516
+ factor (float): The factor to multiply the brightness by.
517
+ > 1.0 makes it lighter, < 1.0 makes it darker.
518
+
519
+ Returns:
520
+ tuple: The new (R, G, B, A) color.
521
+ """
522
+ # Use only RGB for conversion, keep alpha separate
523
+ h, s, v = colorsys.rgb_to_hsv(rgb_color[0], rgb_color[1], rgb_color[2])
524
+
525
+ # Multiply the Value (brightness) by the factor, and clamp it between 0 and 1
526
+ v = max(0, min(1, v * factor))
527
+
528
+ new_rgb = colorsys.hsv_to_rgb(h, s, v)
529
+
530
+ # Return as an RGBA tuple, preserving original alpha if it exists
531
+ alpha = rgb_color[3] if len(rgb_color) == 4 else 1.0
532
+ return (new_rgb[0], new_rgb[1], new_rgb[2], alpha)
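
adjust_color_brightness only rescales the HSV value channel, clamped to [0, 1], so hue, saturation, and alpha are preserved; for example:

# brightness adjustment keeps hue/saturation, scales value, preserves alpha
base = (0.2, 0.6, 0.9, 1.0)
print(adjust_color_brightness(base, 0.5))   # darker blue, alpha still 1.0
print(adjust_color_brightness(base, 1.5))   # value clamps at 1.0
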
533
+
534
+
535
+ def get_primitive_object_translucent(base_color=(0.0, 1.0, 0.0), edge_color=None, face_opacity=0.025):
536
+ """
537
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
538
+
539
+ Args:
540
+ base_color (tuple): The base RGB color for the faces.
541
+ edge_color (tuple): The RGBA color for the edges (defaults to white).
542
+ face_opacity (float): The opacity of the cuboid faces (0.0 = invisible, 1.0 = opaque). Default is 0.025.
543
+ """
544
+ # --- Create the Cuboid and Parent ---
545
+ bpy.ops.object.empty_add(type="PLAIN_AXES")
546
+ # empty_object = bpy.context.object
547
+ empty_object = bpy.data.objects.new("Empty", None)
548
+ before_objs = set(bpy.data.objects)
549
+ bpy.ops.mesh.primitive_cube_add(size=0.5, location=(0, 0, 0))
550
+ after_objs = set(bpy.data.objects)
551
+ diff_objs = after_objs - before_objs
552
+
553
+ obj = None
554
+ for o in diff_objs:
555
+ obj = o
556
+ obj.parent = empty_object
557
+ world_matrix = obj.matrix_world
558
+ obj.matrix_world = world_matrix
559
+
560
+ # --- Create and Assign Materials for Each Face ---
561
+ if obj:
562
+ # left front right back bottom top
563
+ brightness_factors = [
564
+ 0.30, 0.30, 0.30, 0.30, 1.00, 0.30,
565
+ ]
566
+ colors = [adjust_color_brightness(base_color, factor) for factor in brightness_factors]
567
+
568
+ for i, color in enumerate(colors):
569
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
570
+ material.use_nodes = True
571
+ obj.data.materials.append(material)
572
+
573
+ nodes = material.node_tree.nodes
574
+ links = material.node_tree.links
575
+ nodes.clear()
576
+
577
+ # Create Principled BSDF instead of Emission for proper transparency
578
+ bsdf = nodes.new(type="ShaderNodeBsdfPrincipled")
579
+ bsdf.location = (0, 0)
580
+ bsdf.inputs['Base Color'].default_value = color
581
+ bsdf.inputs['Alpha'].default_value = face_opacity # Set face opacity
582
+ bsdf.inputs['Emission Color'].default_value = color[:3] + (1.0,) # Fixed: Use 'Emission Color' instead of 'Emission'
583
+ bsdf.inputs['Emission Strength'].default_value = 1.0 # Emission strength
584
+
585
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
586
+ material_output.location = (200, 0)
587
+ links.new(bsdf.outputs['BSDF'], material_output.inputs['Surface'])
588
+
589
+ # Enable transparency settings for the material
590
+ material.blend_method = 'BLEND'
591
+ material.show_transparent_back = False
592
+
593
+ if len(obj.data.polygons) == len(colors):
594
+ for i, poly in enumerate(obj.data.polygons):
595
+ poly.material_index = i
596
+ else:
597
+ print("Warning: The number of colors does not match the number of faces.")
598
+
599
+ # --- Add Wireframe Edges ---
600
+ # edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
601
+ # edge_material.use_nodes = True
602
+
603
+ # nodes = edge_material.node_tree.nodes
604
+ # links = edge_material.node_tree.links
605
+ # nodes.clear()
606
+
607
+ # if edge_color is None:
608
+ # edge_color = adjust_color_brightness(base_color, 0.10)
609
+
610
+ # edge_emission_node = nodes.new(type="ShaderNodeEmission")
611
+ # edge_emission_node.inputs['Color'].default_value = edge_color
612
+ # edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
613
+ # links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
614
+
615
+ # obj.data.materials.append(edge_material)
616
+
617
+ # wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
618
+ # wire_mod.thickness = 0.01
619
+ # wire_mod.use_replace = False
620
+ # wire_mod.material_offset = len(obj.data.materials) - 1
621
+
622
+ # --- Bounding Box Calculation ---
623
+ bbox_corners = []
624
+ bpy.context.view_layer.update()
625
+ for child in empty_object.children:
626
+ for corner in child.bound_box:
627
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
628
+ bbox_corners.append(world_corner)
629
+
630
+ if not bbox_corners:
631
+ return 0, empty_object
632
+
633
+ min_x = min(corner.x for corner in bbox_corners)
634
+ min_y = min(corner.y for corner in bbox_corners)
635
+ min_z = min(corner.z for corner in bbox_corners)
636
+
637
+ max_x = max(corner.x for corner in bbox_corners)
638
+ max_y = max(corner.y for corner in bbox_corners)
639
+ max_z = max(corner.z for corner in bbox_corners)
640
+
641
+ return max_z, empty_object
642
+
643
+
644
+ def get_primitive_object_translucent_rgb(base_color=(0.0, 1.0, 0.0), edge_color=None, face_opacity=0.025):
645
+ """
646
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
647
+
648
+ Args:
649
+ base_color (tuple): The base RGB color for the faces.
650
+ edge_color (tuple): The RGBA color for the edges (defaults to white).
651
+ face_opacity (float): The opacity of the cuboid faces (0.0 = invisible, 1.0 = opaque). Default is 0.025.
652
+ """
653
+ # --- Create the Cuboid and Parent ---
654
+ bpy.ops.object.empty_add(type="PLAIN_AXES")
655
+ # empty_object = bpy.context.object
656
+ empty_object = bpy.data.objects.new("Empty", None)
657
+ before_objs = set(bpy.data.objects)
658
+ bpy.ops.mesh.primitive_cube_add(size=0.5, location=(0, 0, 0))
659
+ after_objs = set(bpy.data.objects)
660
+ diff_objs = after_objs - before_objs
661
+
662
+ obj = None
663
+ for o in diff_objs:
664
+ obj = o
665
+ obj.parent = empty_object
666
+ world_matrix = obj.matrix_world
667
+ obj.matrix_world = world_matrix
668
+
669
+ # --- Create and Assign Materials for Each Face ---
670
+ if obj:
671
+ # left front right back bottom top
672
+ brightness_factors = [
673
+ 0.50, 0.50, 0.50, 0.50, 0.50, 0.50,
674
+ ]
675
+ red = (1.0, 0.0, 0.0, 1.0)
676
+ green = (0.0, 1.0, 0.0, 1.0)
677
+ blue = (0.0, 0.0, 1.0, 1.0)
678
+ colors = [adjust_color_brightness(green, factor) for factor in brightness_factors[:4]] + [adjust_color_brightness(blue, brightness_factors[4])] + [adjust_color_brightness(red, brightness_factors[5])]
679
+ colors = [colors[-2], colors[-1], colors[0], colors[1], colors[2], colors[3]]
680
+
681
+ for i, color in enumerate(colors):
682
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
683
+ material.use_nodes = True
684
+ obj.data.materials.append(material)
685
+
686
+ nodes = material.node_tree.nodes
687
+ links = material.node_tree.links
688
+ nodes.clear()
689
+
690
+ # Create Principled BSDF instead of Emission for proper transparency
691
+ bsdf = nodes.new(type="ShaderNodeBsdfPrincipled")
692
+ bsdf.location = (0, 0)
693
+ bsdf.inputs['Base Color'].default_value = color
694
+ bsdf.inputs['Alpha'].default_value = face_opacity # Set face opacity
695
+ bsdf.inputs['Emission Color'].default_value = color[:3] + (1.0,) # Fixed: Use 'Emission Color' instead of 'Emission'
696
+ bsdf.inputs['Emission Strength'].default_value = 1.0 # Emission strength
697
+
698
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
699
+ material_output.location = (200, 0)
700
+ links.new(bsdf.outputs['BSDF'], material_output.inputs['Surface'])
701
+
702
+ # Enable transparency settings for the material
703
+ material.blend_method = 'BLEND'
704
+ material.show_transparent_back = False
705
+
706
+ if len(obj.data.polygons) == len(colors):
707
+ for i, poly in enumerate(obj.data.polygons):
708
+ poly.material_index = i
709
+ else:
710
+ print("Warning: The number of colors does not match the number of faces.")
711
+
712
+ # --- Add Wireframe Edges ---
713
+ edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
714
+ edge_material.use_nodes = True
715
+
716
+ nodes = edge_material.node_tree.nodes
717
+ links = edge_material.node_tree.links
718
+ nodes.clear()
719
+
720
+ if edge_color is None:
721
+ edge_color = adjust_color_brightness(base_color, 0.10)
722
+
723
+ edge_emission_node = nodes.new(type="ShaderNodeEmission")
724
+ edge_emission_node.inputs['Color'].default_value = edge_color
725
+ edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
726
+ links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
727
+
728
+ obj.data.materials.append(edge_material)
729
+
730
+ wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
731
+ wire_mod.thickness = 0.01
732
+ wire_mod.use_replace = False
733
+ wire_mod.material_offset = len(obj.data.materials) - 1
734
+
735
+ # --- Bounding Box Calculation ---
736
+ bbox_corners = []
737
+ bpy.context.view_layer.update()
738
+ for child in empty_object.children:
739
+ for corner in child.bound_box:
740
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
741
+ bbox_corners.append(world_corner)
742
+
743
+ if not bbox_corners:
744
+ return 0, empty_object
745
+
746
+ min_x = min(corner.x for corner in bbox_corners)
747
+ min_y = min(corner.y for corner in bbox_corners)
748
+ min_z = min(corner.z for corner in bbox_corners)
749
+
750
+ max_x = max(corner.x for corner in bbox_corners)
751
+ max_y = max(corner.y for corner in bbox_corners)
752
+ max_z = max(corner.z for corner in bbox_corners)
753
+
754
+ return max_z, empty_object
755
+
756
+
757
+
758
+ def get_primitive_object(base_color=(0.0, 1.0, 0.0), edge_color=None):
759
+ """
760
+ Spawns a cuboid primitive with individually colored faces and highlighted edges.
761
+
762
+ Args:
763
+ base_color (tuple): The base RGB color for the faces.
764
+ edge_color (tuple): The RGBA color for the edges (defaults to white).
765
+ """
766
+ # --- Create the Empty Parent ---
767
+ empty_object = bpy.data.objects.new("Empty", None)
768
+ bpy.context.scene.collection.objects.link(empty_object)
769
+ empty_object.empty_display_type = 'PLAIN_AXES'
770
+
771
+ # --- Create the Cuboid using bmesh ---
772
+ mesh = bpy.data.meshes.new("Cube")
773
+ obj = bpy.data.objects.new("Cube", mesh)
774
+ bpy.context.scene.collection.objects.link(obj)
775
+
776
+ # Create cube geometry
777
+ bm = bmesh.new()
778
+ bmesh.ops.create_cube(bm, size=0.5)
779
+ bm.to_mesh(mesh)
780
+ bm.free()
781
+
782
+ # Set parent
783
+ obj.parent = empty_object
784
+ world_matrix = obj.matrix_world
785
+ obj.matrix_world = world_matrix
786
+
787
+ # --- Create and Assign Materials for Each Face ---
788
+ if obj:
789
+ # left front right back bottom top
790
+ brightness_factors = [
791
+ 0.35, 0.20, 0.65, 0.90, 0.50, 0.50
792
+ ]
793
+ colors = [adjust_color_brightness(base_color, factor) for factor in brightness_factors]
794
+
795
+ for i, color in enumerate(colors):
796
+ material = bpy.data.materials.new(name=f"FaceColor_{i}")
797
+ material.use_nodes = True
798
+ obj.data.materials.append(material)
799
+
800
+ nodes = material.node_tree.nodes
801
+ links = material.node_tree.links
802
+ nodes.clear()
803
+
804
+ emission_node = nodes.new(type="ShaderNodeEmission")
805
+ emission_node.inputs['Color'].default_value = color
806
+ material_output = nodes.new(type="ShaderNodeOutputMaterial")
807
+ links.new(emission_node.outputs['Emission'], material_output.inputs['Surface'])
808
+
809
+ material.blend_method = 'BLEND'
810
+ material.show_transparent_back = False
811
+
812
+ if len(obj.data.polygons) == len(colors):
813
+ for i, poly in enumerate(obj.data.polygons):
814
+ poly.material_index = i
815
+ else:
816
+ print("Warning: The number of colors does not match the number of faces.")
817
+
818
+ # --- MODIFICATION START: Add White Edges ---
819
+
820
+ # 1. Create a new material for the wireframe edges
821
+ edge_material = bpy.data.materials.new(name="EdgeDelimiterMaterial")
822
+ edge_material.use_nodes = True
823
+
824
+ # Set up the nodes for a simple white emission shader
825
+ nodes = edge_material.node_tree.nodes
826
+ links = edge_material.node_tree.links
827
+ nodes.clear()
828
+
829
+ if edge_color is None:
830
+ edge_color = adjust_color_brightness(base_color, 0.10)
831
+
832
+ edge_emission_node = nodes.new(type="ShaderNodeEmission")
833
+ edge_emission_node.inputs['Color'].default_value = edge_color
834
+ edge_output_node = nodes.new(type="ShaderNodeOutputMaterial")
835
+ links.new(edge_emission_node.outputs['Emission'], edge_output_node.inputs['Surface'])
836
+
837
+ # 2. Add the edge material to the object's material slots
838
+ obj.data.materials.append(edge_material)
839
+
840
+ # 3. Add and configure the Wireframe modifier
841
+ wire_mod = obj.modifiers.new(name="EdgeDelimiter", type='WIREFRAME')
842
+ wire_mod.thickness = 0.01 # The thickness of the edge lines
843
+ wire_mod.use_replace = False # Set to False to keep the original faces
844
+ # This offset tells the modifier to use the last material we added (the white one)
845
+ wire_mod.material_offset = len(obj.data.materials) - 1
846
+
847
+ # --- MODIFICATION END ---
848
+
849
+
850
+ # --- Bounding Box Calculation (remains the same) ---
851
+ bbox_corners = []
852
+ # Update the dependency graph to ensure modifiers are accounted for
853
+ bpy.context.view_layer.update()
854
+ for child in empty_object.children:
855
+ # Use child.bound_box which is in object's local space
856
+ for corner in child.bound_box:
857
+ # Convert corner to world space
858
+ world_corner = child.matrix_world @ mathutils.Vector(corner)
859
+ bbox_corners.append(world_corner)
860
+
861
+ if not bbox_corners:
862
+ return 0, empty_object # Return a default value if no corners found
863
+
864
+ min_x = min(corner.x for corner in bbox_corners)
865
+ min_y = min(corner.y for corner in bbox_corners)
866
+ min_z = min(corner.z for corner in bbox_corners)
867
+
868
+ max_x = max(corner.x for corner in bbox_corners)
869
+ max_y = max(corner.y for corner in bbox_corners)
870
+ max_z = max(corner.z for corner in bbox_corners)
871
+
872
+ return max_z, empty_object
873
+
874
+ class BlenderCuboidRenderer:
875
+ def __init__(self, render_engine):
876
+ """
877
+ Initialize the Blender cuboid renderer.
878
+
879
+ Args:
880
+ img_dim (int): Image dimensions (square)
881
+ render_engine (str): Blender render engine ('EEVEE' or 'CYCLES')
882
+ num_lights (int): Number of lights to add
883
+ max_tries (int): Maximum tries for placement
884
+ """
885
+ self.img_dim = 1024
886
+ self.render_engine = render_engine
887
+ self.blender_grid_dims = scales
888
+
889
+ self.radius = 6.0
890
+ self.center = -6.0
891
+
892
+ # Scene references
893
+ self.context = None
894
+ self.scene = None
895
+ self.camera = None
896
+ self.render = None
897
+
898
+ # Setup the scene
899
+ self.setup_scene()
900
+
901
+
902
+ def setup_scene(self):
903
+ """
904
+ Setup the basic Blender scene with camera, lighting, and render settings.
905
+
906
+ Args:
907
+ camera_data (dict): Camera configuration containing elevation, lens, global_scale, etc.
908
+ """
909
+ # Get all objects in the scene
910
+ objects_to_remove = []
911
+
912
+ for obj in bpy.data.objects:
913
+ # Remove default cube, plane, camera, and lights
914
+ if obj.type in {'MESH', 'LIGHT', 'CAMERA'}:
915
+ objects_to_remove.append(obj)
916
+
917
+ # Delete the objects
918
+ for obj in objects_to_remove:
919
+ bpy.data.objects.remove(obj, do_unlink=True)
920
+
921
+ # Also clear orphaned data
922
+ for mesh in bpy.data.meshes:
923
+ if mesh.users == 0:
924
+ bpy.data.meshes.remove(mesh)
925
+
926
+ for light in bpy.data.lights:
927
+ if light.users == 0:
928
+ bpy.data.lights.remove(light)
929
+
930
+ for camera in bpy.data.cameras:
931
+ if camera.users == 0:
932
+ bpy.data.cameras.remove(camera)
933
+
934
+ bpy.context.scene.world = None
935
+
936
+ # Initialize Blender scene
937
+ # bpy.ops.wm.read_factory_settings(use_empty=True)
938
+ self.context = bpy.context
939
+ self.scene = self.context.scene
940
+ if self.render_engine == "CYCLES":
941
+ self.scene.cycles.samples = 32
942
+ self.render = self.scene.render
943
+
944
+ # Set render engine and resolution
945
+ self.render.engine = self.render_engine
946
+ self.context.scene.render.resolution_x = self.img_dim
947
+ self.context.scene.render.resolution_y = self.img_dim
948
+ self.context.scene.render.resolution_percentage = 100
949
+
950
+ # Setup compositing nodes
951
+ self._setup_compositing()
952
+
953
+
954
+ def _setup_compositing(self):
955
+ """Setup Blender compositing nodes for depth and RGB output."""
956
+ self.context.scene.use_nodes = True
957
+ tree = self.context.scene.node_tree
958
+ links = tree.links
959
+
960
+ self.context.scene.render.use_compositing = True
961
+ self.context.view_layer.use_pass_z = True
962
+
963
+ # clear default nodes
964
+ for n in tree.nodes:
965
+ tree.nodes.remove(n)
966
+
967
+ # create input render layer node
968
+ rl = tree.nodes.new('CompositorNodeRLayers')
969
+
970
+ map_node = tree.nodes.new(type="CompositorNodeMapValue")
971
+ map_node.size = [0.05]
972
+ map_node.use_min = True
973
+ map_node.min = [0]
974
+ map_node.use_max = True
975
+ map_node.max = [65336]
976
+ links.new(rl.outputs[2], map_node.inputs[0])
977
+
978
+ invert = tree.nodes.new(type="CompositorNodeInvert")
979
+ links.new(map_node.outputs[0], invert.inputs[1])
980
+
981
+ # create output node
982
+ v = tree.nodes.new('CompositorNodeViewer')
983
+ v.use_alpha = True
984
+
985
+ # create a file output node and set the path
986
+ fileOutput = tree.nodes.new(type="CompositorNodeOutputFile")
987
+ fileOutput.base_path = "."
988
+ links.new(invert.outputs[0], fileOutput.inputs[0])
989
+
990
+ # Links
991
+ links.new(rl.outputs[0], v.inputs[0]) # link Image to Viewer Image RGB
992
+ links.new(rl.outputs['Depth'], v.inputs[1]) # link Render Z to Viewer Image Alpha
993
+
994
+ # Update scene to apply changes
995
+ self.context.view_layer.update()
996
+
997
+
998
+ def _setup_camera_cv(self, camera_data):
999
+ """Setup camera position and orientation."""
1000
+ reset_cameras(self.scene)
1001
+ self.camera = self.scene.objects["Camera"]
1002
+
1003
+ elevation = camera_data["camera_elevation"]
1004
+ tan_elevation = np.tan(elevation)
1005
+ cos_elevation = np.cos(elevation)
1006
+ sin_elevation = np.sin(elevation)
1007
+
1008
+ radius = self.radius
1009
+ center = self.center
1010
+
1011
+ self.camera.location = mathutils.Vector((radius * cos_elevation + center, 0, radius * sin_elevation))
1012
+ direction = mathutils.Vector((-1, 0, -tan_elevation))
1013
+ self.context.scene.camera = self.camera
1014
+ rot_quat = direction.to_track_quat("-Z", "Y")
1015
+ self.camera.rotation_euler = rot_quat.to_euler()
1016
+ self.camera.data.lens = camera_data["lens"]
1017
+
1018
+ def _create_cuboid_objects_translucent(self, subjects_data, opacity=0.025):
1019
+ """Create primitive cuboid objects for all subjects."""
1020
+ for subject_idx, subject_data in enumerate(subjects_data):
1021
+ # rgb_color = map_point_to_rgb(x, y)
1022
+ rgb_color = COLORS[subject_idx % len(COLORS)]
1023
+ _, prim_obj = get_primitive_object_translucent(base_color=rgb_color, face_opacity=opacity)
1024
+ prim_obj.location = np.array([100, 0, 0])
1025
+ subject_data["prim_obj"] = prim_obj
1026
+
1027
+ def _create_cuboid_objects_translucent_rgb(self, subjects_data, opacity=0.025):
1028
+ """Create primitive cuboid objects for all subjects."""
1029
+ for subject_idx, subject_data in enumerate(subjects_data):
1030
+ x = subject_data["x"][0]
1031
+ y = subject_data["y"][0]
1032
+ z = subject_data["z"][0]
1033
+ base_color = map_point_to_rgb(x, y, z)
1034
+ _, prim_obj = get_primitive_object_translucent_rgb(base_color=base_color, face_opacity=opacity)
1035
+ prim_obj.location = np.array([100, 0, 0])
1036
+ subject_data["prim_obj"] = prim_obj
1037
+
1038
+
1039
+ def _place_objects(self, subjects_data, camera_data):
1040
+ """Place objects in the scene according to their data."""
1041
+ global_scale = camera_data["global_scale"]
1042
+
1043
+ for subject_data in subjects_data:
1044
+ x = subject_data["x"][0]
1045
+ y = subject_data["y"][0]
1046
+ z = global_scale * subject_data["dims"][2] / 2.0 + subject_data["z"][0]
1047
+ subject_data["prim_obj"].location = np.array([x, y, z])
1048
+ subject_data["prim_obj"].scale = global_scale * np.array(subject_data["dims"]) * 2.0
1049
+ subject_data["prim_obj"].rotation_euler[2] = subject_data["azimuth"][0]
1050
+
1051
+ def render_cv(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1052
+ """
1053
+ Main render method that takes subjects data and renders the scene.
1054
+
1055
+ Args:
1056
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1057
+ camera_data (dict): Camera configuration
1058
+ num_samples (int): Number of samples to render (currently only supports 1)
1059
+ output_path (str): Path to save the rendered image
1060
+
1061
+ Returns:
1062
+ None
1063
+ """
1064
+ center = (-6.0, 0.0, 0.0)
1065
+ radius = 6.0
1066
+
1067
+ print(f"render_cv received {subjects_data = }")
1068
+
1069
+ # print(f"render_cv received {subjects_data = }")
1070
+ for subject_data in subjects_data:
1071
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1072
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1073
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1074
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1075
+ # Setup camera
1076
+ self._setup_camera_cv(camera_data)
1077
+
1078
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 20, intensity=7000.0)
1079
+
1080
+ # Add ground plane
1081
+ add_plane()
1082
+
1083
+ assert num_samples == 1, "for now, only implemented for a single sample"
1084
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1085
+
1086
+ # Create primitive objects for subjects
1087
+ self._create_cuboid_objects_translucent(subjects_data, opacity=0.025)
1088
+ # self._create_cuboid_objects(subjects_data)
1089
+
1090
+ # Place objects in scene
1091
+ self._place_objects(subjects_data, camera_data)
1092
+
1093
+ # Perform rendering
1094
+ print(f"SUCCESS, rendering...")
1095
+ self.context.scene.render.filepath = output_path
1096
+ self.context.scene.render.image_settings.file_format = "JPEG"
1097
+ bpy.ops.render.render(write_still=True)
1098
+
1099
+ print(f"Rendered scene saved to: {output_path}")
1100
+
1101
+ self.cleanup()
1102
+
1103
+ def render_final_representation(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1104
+ """
1105
+ Main render method that takes subjects data and renders the scene.
1106
+
1107
+ Args:
1108
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1109
+ camera_data (dict): Camera configuration
1110
+ num_samples (int): Number of samples to render (currently only supports 1)
1111
+ output_path (str): Path to save the rendered image
1112
+
1113
+ Returns:
1114
+ None
1115
+ """
1116
+ assert self.render.engine == "CYCLES", "render_final_representation only works with CYCLES render engine"
1117
+ center = (-6.0, 0.0, 0.0)
1118
+ radius = 6.0
1119
+
1120
+ print(f"render_cv received {subjects_data = }")
1121
+
1122
+ # print(f"render_cv received {subjects_data = }")
1123
+ for subject_data in subjects_data:
1124
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1125
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1126
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1127
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1128
+ # Setup camera
1129
+ self._setup_camera_cv(camera_data)
1130
+
1131
+ print(f"setting lights in cycles...")
1132
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 5, intensity=700.0)
1133
+
1134
+ # Add ground plane
1135
+ print(f"adding plane in cycles...")
1136
+ add_plane_cycles()
1137
+
1138
+ assert num_samples == 1, "for now, only implemented for a single sample"
1139
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1140
+
1141
+ # Create primitive objects for subjects
1142
+ self._create_cuboid_objects_translucent_rgb(subjects_data, opacity=0.025)
1143
+ # self._create_cuboid_objects(subjects_data)
1144
+
1145
+ # Place objects in scene
1146
+ self._place_objects(subjects_data, camera_data)
1147
+
1148
+ # Perform rendering
1149
+ print(f"SUCCESS, rendering...")
1150
+ self.context.scene.render.filepath = output_path
1151
+ self.context.scene.render.image_settings.file_format = "JPEG"
1152
+ bpy.ops.render.render(write_still=True)
1153
+
1154
+ print(f"Rendered scene saved to: {output_path}")
1155
+
1156
+ self.cleanup()
1157
+
1158
+
1159
+ def render_paper_figure(self, subjects_data, camera_data, num_samples=1, output_path="main.jpg"):
1160
+ """
1161
+ Main render method that takes subjects data and renders the scene.
1162
+
1163
+ Args:
1164
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
1165
+ camera_data (dict): Camera configuration
1166
+ num_samples (int): Number of samples to render (currently only supports 1)
1167
+ output_path (str): Path to save the rendered image
1168
+
1169
+ Returns:
1170
+ None
1171
+ """
1172
+ assert self.render.engine == "CYCLES", "render_final_representation only works with CYCLES render engine"
1173
+ center = (-6.0, 0.0, 0.0)
1174
+ radius = 6.0
1175
+
1176
+ print(f"render_cv received {subjects_data = }")
1177
+
1178
+ set_world_color((1.0, 1.0, 1.0)) # white background
1179
+
1180
+ # print(f"render_cv received {subjects_data = }")
1181
+ for subject_data in subjects_data:
1182
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1183
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1184
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1185
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1186
+ # Setup camera
1187
+ self._setup_camera_cv(camera_data)
1188
+
1189
+ print(f"setting lights in cycles...")
1190
+ set_lights_cv(self.radius, np.array([self.center, 0, 0]), 5, intensity=7000.0)
1191
+
1192
+ # No ground plane for the paper figure (the world background is set to white above)
1193
+ print(f"skipping ground plane for the paper figure...")
1194
+
1195
+ assert num_samples == 1, "for now, only implemented for a single sample"
1196
+ assert "global_scale" in camera_data.keys(), "global_scale must be set for EEVEE"
1197
+
1198
+ # Create primitive objects for subjects
1199
+ self._create_cuboid_objects_translucent(subjects_data, opacity=0.35)
1200
+ # self._create_cuboid_objects(subjects_data)
1201
+
1202
+ # Place objects in scene
1203
+ self._place_objects(subjects_data, camera_data)
1204
+
1205
+ # Perform rendering
1206
+ print(f"SUCCESS, rendering...")
1207
+ self.context.scene.render.filepath = output_path
1208
+ self.context.scene.render.image_settings.file_format = "JPEG"
1209
+ bpy.ops.render.render(write_still=True)
1210
+
1211
+ print(f"Rendered scene saved to: {output_path}")
1212
+
1213
+ self.cleanup()
1214
+
1215
+
1216
+ def cleanup(self):
1217
+ """Clean up the scene for next render."""
1218
+ # Remove all lights
1219
+ remove_all_lights()
1220
+
1221
+ # Remove all other objects (meshes, empties, etc.)
1222
+ objects_to_remove = [obj for obj in bpy.data.objects]
1223
+
1224
+ for obj in objects_to_remove:
1225
+ bpy.data.objects.remove(obj, do_unlink=True)
1226
+
1227
+ # Clean up orphaned data blocks
1228
+ for mesh in bpy.data.meshes:
1229
+ if mesh.users == 0:
1230
+ bpy.data.meshes.remove(mesh)
1231
+
1232
+ for material in bpy.data.materials:
1233
+ if material.users == 0:
1234
+ bpy.data.materials.remove(material)
1235
+
1236
+ for light_data in bpy.data.lights:
1237
+ if light_data.users == 0:
1238
+ bpy.data.lights.remove(light_data)
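
Putting the class together, a typical call mirrors the __main__ block at the end of this file: build subjects_data and camera_data dictionaries and hand them to render_cv. A minimal sketch, assumed to run under Blender's Python since the module imports bpy; the output path is illustrative, and note that render_cv also reads a per-subject "z" offset:

# minimal usage sketch of BlenderCuboidRenderer (run inside Blender's Python)
subjects = [{
    "name": "sedan",
    "x": [-3.0], "y": [0.0], "z": [0.0],   # scene-space offsets; render_cv re-centres them
    "dims": [1.0, 2.0, 1.5],               # x/y/z extents; the cuboid is rested on the ground
    "azimuth": [0.0],                      # degrees; converted to radians inside render_cv
}]
camera = {
    "camera_elevation": np.arctan(0.45),   # radians
    "lens": 70,
    "global_scale": 1.0,
}
renderer = BlenderCuboidRenderer(render_engine="CYCLES")
renderer.render_cv(subjects, camera, output_path="scene_cv.jpg")
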
1239
+
1240
+
1241
+ class BlenderSegmaskRenderer:
1242
+ def __init__(self):
1243
+ """
1244
+ Initialize the Blender cuboid renderer.
1245
+
1246
+ Args:
1247
+ img_dim (int): Image dimensions (square)
1248
+ render_engine (str): Blender render engine ('EEVEE' or 'CYCLES')
1249
+ num_lights (int): Number of lights to add
1250
+ max_tries (int): Maximum tries for placement
1251
+ """
1252
+ self.img_dim = 1024
1253
+ self.render_engine = "BLENDER_WORKBENCH"
1254
+ self.blender_grid_dims = scales
1255
+
1256
+ self.radius = 6.0
1257
+ self.center = -6.0
1258
+
1259
+ # Scene references
1260
+ self.context = None
1261
+ self.scene = None
1262
+ self.camera = None
1263
+ self.render = None
1264
+
1265
+ # Setup the scene
1266
+ self.setup_scene()
1267
+
1268
+
1269
+ def setup_scene(self):
1270
+ """
1271
+ Setup the basic Blender scene with camera, lighting, and render settings.
1272
+
1273
+ Args:
1274
+ camera_data (dict): Camera configuration containing elevation, lens, global_scale, etc.
1275
+ """
1276
+ # Get all objects in the scene
1277
+ objects_to_remove = []
1278
+
1279
+ for obj in bpy.data.objects:
1280
+ # Remove default cube, plane, camera, and lights
1281
+ if obj.type in {'MESH', 'LIGHT', 'CAMERA'}:
1282
+ objects_to_remove.append(obj)
1283
+
1284
+ # Delete the objects
1285
+ for obj in objects_to_remove:
1286
+ bpy.data.objects.remove(obj, do_unlink=True)
1287
+
1288
+ # Also clear orphaned data
1289
+ for mesh in bpy.data.meshes:
1290
+ if mesh.users == 0:
1291
+ bpy.data.meshes.remove(mesh)
1292
+
1293
+ for light in bpy.data.lights:
1294
+ if light.users == 0:
1295
+ bpy.data.lights.remove(light)
1296
+
1297
+ for camera in bpy.data.cameras:
1298
+ if camera.users == 0:
1299
+ bpy.data.cameras.remove(camera)
1300
+
1301
+ bpy.context.scene.world = None
1302
+
1303
+ # Initialize Blender scene
1304
+ # bpy.ops.wm.read_factory_settings(use_empty=True)
1305
+ self.context = bpy.context
1306
+ self.scene = self.context.scene
1307
+ self.render = self.scene.render
1308
+
1309
+ # Set render engine and resolution
1310
+ self.render.engine = self.render_engine
1311
+ self.context.scene.render.resolution_x = self.img_dim
1312
+ self.context.scene.render.resolution_y = self.img_dim
1313
+ self.context.scene.render.resolution_percentage = 100
1314
+
1315
+ # Setup compositing nodes
1316
+ self._setup_compositing()
1317
+
1318
+
1319
+ def _setup_compositing(self):
1320
+ """Setup Blender compositing nodes for depth and RGB output."""
1321
+ self.context.scene.use_nodes = True
1322
+ tree = self.context.scene.node_tree
1323
+ links = tree.links
1324
+
1325
+ self.context.scene.render.use_compositing = True
1326
+ self.context.view_layer.use_pass_z = True
1327
+
1328
+ # clear default nodes
1329
+ for n in tree.nodes:
1330
+ tree.nodes.remove(n)
1331
+
1332
+ # create input render layer node
1333
+ rl = tree.nodes.new('CompositorNodeRLayers')
1334
+
1335
+ map_node = tree.nodes.new(type="CompositorNodeMapValue")
1336
+ map_node.size = [0.05]
1337
+ map_node.use_min = True
1338
+ map_node.min = [0]
1339
+ map_node.use_max = True
1340
+ map_node.max = [65336]
1341
+ links.new(rl.outputs[2], map_node.inputs[0])
1342
+
1343
+ invert = tree.nodes.new(type="CompositorNodeInvert")
1344
+ links.new(map_node.outputs[0], invert.inputs[1])
1345
+
1346
+ # create output node
1347
+ v = tree.nodes.new('CompositorNodeViewer')
1348
+ v.use_alpha = True
1349
+
1350
+ # create a file output node and set the path
1351
+ fileOutput = tree.nodes.new(type="CompositorNodeOutputFile")
1352
+ fileOutput.base_path = "."
1353
+ links.new(invert.outputs[0], fileOutput.inputs[0])
1354
+
1355
+ # Links
1356
+ links.new(rl.outputs[0], v.inputs[0]) # link Image to Viewer Image RGB
1357
+ links.new(rl.outputs['Depth'], v.inputs[1]) # link Render Z to Viewer Image Alpha
1358
+
1359
+ # Update scene to apply changes
1360
+ self.context.view_layer.update()
1361
+
1362
+
1363
+ def _setup_camera_cv(self, camera_data):
1364
+ """Setup camera position and orientation."""
1365
+ reset_cameras(self.scene)
1366
+ self.camera = self.scene.objects["Camera"]
1367
+
1368
+ elevation = camera_data["camera_elevation"]
1369
+ tan_elevation = np.tan(elevation)
1370
+ cos_elevation = np.cos(elevation)
1371
+ sin_elevation = np.sin(elevation)
1372
+
1373
+ radius = self.radius
1374
+ center = self.center
1375
+
1376
+ self.camera.location = mathutils.Vector((radius * cos_elevation + center, 0, radius * sin_elevation))
1377
+ direction = mathutils.Vector((-1, 0, -tan_elevation))
1378
+ self.context.scene.camera = self.camera
1379
+ rot_quat = direction.to_track_quat("-Z", "Y")
1380
+ self.camera.rotation_euler = rot_quat.to_euler()
1381
+ self.camera.data.lens = camera_data["lens"]
1382
+
1383
+ def _create_cuboid_objects(self, subjects_data):
1384
+ """Create primitive cuboid objects for all subjects."""
1385
+ for subject_idx, subject_data in enumerate(subjects_data):
1386
+ x = subject_data["x"][0]
1387
+ y = subject_data["y"][0]
1388
+ z = subject_data["z"][0]
1389
+ rgb_color = map_point_to_rgb(x, y, z)
1390
+ _, prim_obj = get_primitive_object(rgb_color)
1391
+ prim_obj.location = np.array([100, 0, 0])
1392
+ subject_data["prim_obj"] = prim_obj
1393
+
1394
+ def _place_objects(self, subjects_data, camera_data):
1395
+ """Place objects in the scene according to their data."""
1396
+ global_scale = camera_data["global_scale"]
1397
+
1398
+ for subject_data in subjects_data:
1399
+ x = subject_data["x"][0]
1400
+ y = subject_data["y"][0]
1401
+ z = global_scale * subject_data["dims"][2] / 2.0 + subject_data["z"][0]
1402
+ subject_data["prim_obj"].location = np.array([x, y, z])
1403
+ subject_data["prim_obj"].scale = global_scale * np.array(subject_data["dims"]) * 2.0
1404
+ subject_data["prim_obj"].rotation_euler[2] = subject_data["azimuth"][0]
1405
+
1406
+ def render_cv(self, subjects_data, camera_data, num_samples=1):
1407
+ """
1408
+ Render a per-subject segmentation mask for the given scene.
+
+ Args:
+ subjects_data (list): List of subject dictionaries containing position, dims, etc.
+ camera_data (dict): Camera configuration
+ num_samples (int): Number of samples to render (currently only supports 1)
+
+ Returns:
+ None. One zero-padded "<idx>_segmask_cv.png" mask is written per subject.
+ """
1419
+ # Setup camera
1420
+ center = (-6.0, 0.0, 0.0)
1421
+ radius = 6.0
1422
+
1423
+ for subject_data in subjects_data:
1424
+ subject_data["azimuth"][0] = np.deg2rad(subject_data["azimuth"][0])
1425
+ subject_data["x"][0] = subject_data["x"][0] + center[0]
1426
+ subject_data["y"][0] = subject_data["y"][0] + center[1]
1427
+ subject_data["z"][0] = subject_data["z"][0] + center[2]
1428
+
1429
+ print(f"in segmask render, {subjects_data = }")
1430
+
1431
+ self._setup_camera_cv(camera_data)
1432
+
1433
+ assert num_samples == 1, "for now, only implemented for a single sample"
1434
+ assert "global_scale" in camera_data.keys(), "global_scale must be set"
1435
+
1436
+ # Create primitive objects for subjects
1437
+ self._create_cuboid_objects(subjects_data)
1438
+
1439
+ def make_segmask(image):
1440
+ alpha = image[:, :, 3]
1441
+ _, mask = cv2.threshold(alpha, 0, 255, cv2.THRESH_BINARY)
1442
+ return mask
1443
+
1444
+
1445
+ for subject_idx, subject_data in enumerate(subjects_data):
1446
+ # Place objects in scene
1447
+ self._place_objects([subject_data], camera_data)
1448
+
1449
+ # Perform rendering
1450
+ print(f"SUCCESS, rendering...")
1451
+ self.context.scene.render.filepath = "tmp.png"
1452
+ self.context.scene.render.image_settings.file_format = "PNG"
1453
+ bpy.ops.render.render(write_still=True)
1454
+ img = cv2.imread("tmp.png", cv2.IMREAD_UNCHANGED)
1455
+ segmask = make_segmask(img)
1456
+ print(f"{segmask.shape = }")
1457
+ cv2.imwrite(f"{str(subject_idx).zfill(3)}_segmask_cv.png", segmask)
1458
+ print(f"saved {str(subject_idx).zfill(3)}_segmask_cv.png")
1459
+
1460
+ subject_data["prim_obj"].location = np.array([100, 0, 0]) # move out of view
1461
+
1462
+ self.cleanup()
1463
+
1464
+
1465
+ def cleanup(self):
1466
+ """Clean up the scene for next render."""
1467
+ # Remove all lights
1468
+ remove_all_lights()
1469
+
1470
+ # Remove all other objects (meshes, empties, etc.)
1471
+ objects_to_remove = [obj for obj in bpy.data.objects]
1472
+
1473
+ for obj in objects_to_remove:
1474
+ bpy.data.objects.remove(obj, do_unlink=True)
1475
+
1476
+ # Clean up orphaned data blocks
1477
+ for mesh in bpy.data.meshes:
1478
+ if mesh.users == 0:
1479
+ bpy.data.meshes.remove(mesh)
1480
+
1481
+ for material in bpy.data.materials:
1482
+ if material.users == 0:
1483
+ bpy.data.materials.remove(material)
1484
+
1485
+ for light_data in bpy.data.lights:
1486
+ if light_data.users == 0:
1487
+ bpy.data.lights.remove(light_data)
1488
+
1489
+
1490
+
1491
+ # Example standalone usage
1492
+ if __name__ == '__main__':
1493
+ subjects_data = [
1494
+ {
1495
+ "name": "sedan",
1496
+ "x": [-5.0],
1497
+ "y": [0.0],
1498
+ "dims": [1.0, 2.0, 1.5],
1499
+ "azimuth": [0.0]
1500
+ },
1501
+ ]
1502
+ camera_data = {
1503
+ "camera_elevation": np.arctan(0.45),
1504
+ "lens": 70,
1505
+ "global_scale": 1.0
1506
+ }
1507
+
1508
+ # Create renderer instance
1509
+ renderer = BlenderCuboidRenderer(
1510
+ img_dim=1024,
1511
+ render_engine='EEVEE',
1512
+ num_lights=1,
1513
+ )
1514
+
1515
+ # Render the scene
1516
+ renderer.render(
1517
+ subjects_data=subjects_data,
1518
+ camera_data=camera_data,
1519
+ num_samples=1,
1520
+ output_path="main.jpg"
1521
+ )
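Note that the segmask path above writes its outputs to the working directory rather than returning them. A minimal sketch for collecting the masks afterwards, relying on the zero-padded naming used in render_cv (assumes OpenCV is available):

import glob
import cv2

# render_cv writes one mask per subject: 000_segmask_cv.png, 001_segmask_cv.png, ...
masks = [cv2.imread(p, cv2.IMREAD_GRAYSCALE) for p in sorted(glob.glob("*_segmask_cv.png"))]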
gradio_app/blender_server.py ADDED
@@ -0,0 +1,149 @@
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import shutil
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+ from typing import List, Dict, Any, Optional
11
+ import uvicorn
12
+ import argparse
13
+
14
+ # Import BlenderCuboidRenderer
15
+ from blender_backend import BlenderCuboidRenderer
16
+
17
+ class RenderRequest(BaseModel):
18
+ subjects_data: List[Dict[str, Any]]
19
+ camera_data: Dict[str, Any]
20
+ num_samples: int = 1
21
+
22
+ class RenderResponse(BaseModel):
23
+ success: bool
24
+ image_base64: Optional[str] = None
25
+ error_message: Optional[str] = None
26
+
27
+ class BlenderRenderServer:
28
+ def __init__(self, render_mode: str):
29
+ """
30
+ Initialize the Blender render server.
31
+
32
+ Args:
33
+ render_mode (str): One of 'cv' (camera view, EEVEE), 'final' (Cycles), or 'paper' (Cycles paper figures)
34
+ """
35
+ self.render_mode = render_mode
36
+ if self.render_mode == "cv":
37
+ self.renderer = BlenderCuboidRenderer("BLENDER_EEVEE_NEXT")
38
+ elif self.render_mode == "final":
39
+ self.renderer = BlenderCuboidRenderer("CYCLES")
40
+ elif self.render_mode == "paper":
41
+ self.renderer = BlenderCuboidRenderer("CYCLES")
42
+
43
+ def process_render_request(self, request: RenderRequest) -> RenderResponse:
44
+ """Process a render request and return the result."""
45
+ # Output path for this render (written to the working directory)
46
+ output_path = f"{self.render_mode}_render.jpg"
47
+
48
+ # Convert subjects_data format if needed
49
+ converted_subjects_data = self._convert_subjects_data(request.subjects_data)
50
+
51
+ # Add required camera_data fields
52
+ camera_data = request.camera_data.copy()
53
+ camera_data["global_scale"] = camera_data.get("global_scale", 1.0)
54
+
55
+ # Perform the render based on mode
56
+ if self.render_mode == "cv":
57
+ self.renderer.render_cv(
58
+ subjects_data=converted_subjects_data,
59
+ camera_data=camera_data,
60
+ num_samples=request.num_samples,
61
+ output_path=output_path
62
+ )
63
+ elif self.render_mode == "final":
64
+ self.renderer.render_final_representation(
65
+ subjects_data=converted_subjects_data,
66
+ camera_data=camera_data,
67
+ num_samples=request.num_samples,
68
+ output_path=output_path
69
+ )
70
+ elif self.render_mode == "paper":
71
+ self.renderer.render_paper_figure(
72
+ subjects_data=converted_subjects_data,
73
+ camera_data=camera_data,
74
+ num_samples=request.num_samples,
75
+ output_path=output_path
76
+ )
77
+ else:
78
+ raise ValueError(f"Invalid render mode: {self.render_mode}")
79
+
80
+ # Read and encode the rendered image
81
+ if os.path.exists(output_path):
82
+ with open(output_path, "rb") as img_file:
83
+ img_data = img_file.read()
84
+ img_base64 = base64.b64encode(img_data).decode('utf-8')
85
+
86
+ return RenderResponse(success=True, image_base64=img_base64)
87
+ else:
88
+ return RenderResponse(
89
+ success=False,
90
+ error_message="Render output file not found"
91
+ )
92
+
93
+ def _convert_subjects_data(self, subjects_data: List[Dict]) -> List[Dict]:
94
+ """Convert subjects data to the format expected by BlenderCuboidRenderer."""
95
+ converted = []
96
+
97
+ for subject in subjects_data:
98
+ # Convert to the expected format with lists for x, y, azimuth
99
+ converted_subject = {
100
+ "name": subject.get("subject_name", "cuboid"),
101
+ "x": [subject["x"]],
102
+ "y": [subject["y"]],
103
+ "z": [subject["z"]],
104
+ "dims": [subject["width"], subject["depth"], subject["height"]],
105
+ "azimuth": [subject["azimuth"]]
106
+ }
107
+ converted.append(converted_subject)
108
+
109
+ return converted
110
+
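For clarity, this is what _convert_subjects_data does to a single entry; the values below are illustrative, not taken from the repo:

# Flat request-side entry ...
subject_in = {"subject_name": "sedan", "x": -5.0, "y": 0.0, "z": 0.0,
              "width": 1.0, "depth": 2.0, "height": 1.5, "azimuth": 0.0}

# ... becomes the list-wrapped form the renderer expects:
subject_out = {"name": "sedan", "x": [-5.0], "y": [0.0], "z": [0.0],
               "dims": [1.0, 2.0, 1.5], "azimuth": [0.0]}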
111
+ # Create FastAPI app
112
+ app = FastAPI(title="Blender Render Server")
113
+
114
+ # Global server instance
115
+ server = None
116
+
117
+ @app.on_event("startup")
118
+ def startup_event():
119
+ global server
120
+ render_mode = os.environ.get("RENDER_MODE")
121
+ server = BlenderRenderServer(render_mode)
122
+ print(f"Blender Render Server started in {render_mode.upper()} mode")
123
+
124
+ @app.post("/render", response_model=RenderResponse)
125
+ def render_scene(request: RenderRequest):
126
+ """Render a scene and return the result."""
127
+ if server is None:
128
+ raise HTTPException(status_code=500, detail="Server not initialized")
129
+
130
+ return server.process_render_request(request)
131
+
132
+ @app.get("/health")
133
+ def health_check():
134
+ """Health check endpoint."""
135
+ return {"status": "healthy", "render_mode": server.render_mode if server else "unknown"}
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser(description="Blender Render Server")
139
+ parser.add_argument("--mode", choices=["cv", "final", "paper"], required=True,
140
+ help="Render mode: cv for camera view, bev for bird's eye view")
141
+ parser.add_argument("--port", type=int, default=5001, help="Port to run server on")
142
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind server to")
143
+
144
+ args = parser.parse_args()
145
+
146
+ # Set environment variable for the startup event
147
+ os.environ["RENDER_MODE"] = args.mode
148
+
149
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
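A minimal client sketch for the /render endpoint, assuming the server runs locally on the default CV port and that requests and Pillow are installed (payload values are illustrative):

import base64
import io

import requests
from PIL import Image

payload = {
    "subjects_data": [{"subject_name": "sedan", "x": -5.0, "y": 0.0, "z": 0.0,
                       "width": 1.0, "depth": 2.0, "height": 1.5, "azimuth": 0.0}],
    "camera_data": {"camera_elevation": 0.42, "lens": 70, "global_scale": 1.0},
    "num_samples": 1,
}
resp = requests.post("http://127.0.0.1:5001/render", json=payload).json()
if resp["success"]:
    # The server returns the rendered JPEG as a base64 string
    Image.open(io.BytesIO(base64.b64decode(resp["image_base64"]))).save("cv_render_client.jpg")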
gradio_app/blender_server_segmasks.py ADDED
@@ -0,0 +1,133 @@
1
+ import os
2
+ import sys
3
+ import tempfile
4
+ import shutil
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
+ from fastapi import FastAPI, HTTPException
9
+ from pydantic import BaseModel
10
+ from typing import List, Dict, Any, Optional
11
+ import uvicorn
12
+ import argparse
13
+
14
+ # Import BlenderSegmaskRenderer
15
+ from blender_backend import BlenderSegmaskRenderer
16
+
17
+ class SegmaskRenderRequest(BaseModel):
18
+ subjects_data: List[Dict[str, Any]]
19
+ camera_data: Dict[str, Any]
20
+ num_samples: int = 1
21
+
22
+ class SegmaskRenderResponse(BaseModel):
23
+ success: bool
24
+ segmasks_base64: Optional[List[str]] = None
25
+ error_message: Optional[str] = None
26
+
27
+ class BlenderSegmaskRenderServer:
28
+ def __init__(self):
29
+ """Initialize the Blender segmentation mask render server."""
30
+ self.renderer = BlenderSegmaskRenderer()
31
+
32
+ def process_render_request(self, request: SegmaskRenderRequest) -> SegmaskRenderResponse:
33
+ """Process a segmentation mask render request and return the result."""
34
+ try:
35
+ # Convert subjects_data format if needed
37
+ converted_subjects_data = self._convert_subjects_data(request.subjects_data)
38
+
39
+ # Add required camera_data fields
40
+ camera_data = request.camera_data.copy()
41
+ camera_data["global_scale"] = camera_data.get("global_scale", 1.0)
42
+
43
+ # Perform the render
44
+ self.renderer.render_cv(
45
+ subjects_data=converted_subjects_data,
46
+ camera_data=camera_data,
47
+ num_samples=request.num_samples
48
+ )
49
+
50
+ # Read and encode all segmentation masks in order
51
+ segmasks_base64 = []
52
+ num_subjects = len(converted_subjects_data)
53
+
54
+ for subject_idx in range(num_subjects):
55
+ segmask_path = f"{str(subject_idx).zfill(3)}_segmask_cv.png"
56
+
57
+ if os.path.exists(segmask_path):
58
+ with open(segmask_path, "rb") as img_file:
59
+ img_data = img_file.read()
60
+ img_base64 = base64.b64encode(img_data).decode('utf-8')
61
+ segmasks_base64.append(img_base64)
62
+ else:
63
+ # Return error if any segmask is missing
64
+ return SegmaskRenderResponse(
65
+ success=False,
66
+ error_message=f"Segmentation mask for subject {subject_idx} not found"
67
+ )
68
+
69
+
70
+ return SegmaskRenderResponse(
71
+ success=True,
72
+ segmasks_base64=segmasks_base64
73
+ )
74
+
75
+ except Exception as e:
76
+
78
+ return SegmaskRenderResponse(
79
+ success=False,
80
+ error_message=f"Segmentation mask render failed: {str(e)}"
81
+ )
82
+
83
+ def _convert_subjects_data(self, subjects_data: List[Dict]) -> List[Dict]:
84
+ """Convert subjects data to the format expected by BlenderSegmaskRenderer."""
85
+ converted = []
86
+
87
+ for subject in subjects_data:
88
+ # Convert to the expected format with lists for x, y, azimuth
89
+ converted_subject = {
90
+ "name": subject.get("subject_name", "cuboid"),
91
+ "x": [subject["x"]],
92
+ "y": [subject["y"]],
93
+ "z": [subject["z"]],
94
+ "dims": [subject["width"], subject["depth"], subject["height"]],
95
+ "azimuth": [subject["azimuth"]]
96
+ }
97
+ converted.append(converted_subject)
98
+
99
+ return converted
100
+
101
+ # Create FastAPI app
102
+ app = FastAPI(title="Blender Segmentation Mask Render Server")
103
+
104
+ # Global server instance
105
+ server = None
106
+
107
+ @app.on_event("startup")
108
+ def startup_event():
109
+ global server
110
+ server = BlenderSegmaskRenderServer()
111
+ print("Blender Segmentation Mask Render Server started")
112
+
113
+ @app.post("/render_segmasks", response_model=SegmaskRenderResponse)
114
+ def render_segmasks(request: SegmaskRenderRequest):
115
+ """Render segmentation masks and return the results."""
116
+ if server is None:
117
+ raise HTTPException(status_code=500, detail="Server not initialized")
118
+
119
+ return server.process_render_request(request)
120
+
121
+ @app.get("/health")
122
+ def health_check():
123
+ """Health check endpoint."""
124
+ return {"status": "healthy", "type": "segmentation_mask_renderer"}
125
+
126
+ if __name__ == "__main__":
127
+ parser = argparse.ArgumentParser(description="Blender Segmentation Mask Render Server")
128
+ parser.add_argument("--port", type=int, default=5003, help="Port to run server on")
129
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind server to")
130
+
131
+ args = parser.parse_args()
132
+
133
+ uvicorn.run(app, host=args.host, port=args.port, log_level="info")
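The segmask endpoint returns one base64-encoded PNG per subject, in subject order. A minimal client sketch (same local-server assumption as for the CV endpoint; payload values are illustrative):

import base64
import requests

payload = {
    "subjects_data": [{"subject_name": "sedan", "x": -5.0, "y": 0.0, "z": 0.0,
                       "width": 1.0, "depth": 2.0, "height": 1.5, "azimuth": 0.0}],
    "camera_data": {"camera_elevation": 0.42, "lens": 70},
    "num_samples": 1,
}
resp = requests.post("http://127.0.0.1:5003/render_segmasks", json=payload).json()
if resp["success"]:
    for idx, b64 in enumerate(resp["segmasks_base64"]):
        with open(f"{idx:03d}_segmask.png", "wb") as f:
            f.write(base64.b64decode(b64))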
gradio_app/cv_render.jpg ADDED

Git LFS Details

  • SHA256: c652e9870c4dbd2f14bf286fb3a31abbbc63823dfa8b29db02025f666a1a25fa
  • Pointer size: 130 Bytes
  • Size of remote file: 29.1 kB
gradio_app/final_render.jpg ADDED

Git LFS Details

  • SHA256: 2c9f3c4c98d279c441c46d743c9bb7820bcb661c61564fca3fa28e2a27fc29f6
  • Pointer size: 130 Bytes
  • Size of remote file: 37.9 kB
gradio_app/infer_backend.py ADDED
@@ -0,0 +1,342 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ from PIL import Image
5
+ import numpy as np
6
+ from typing import Optional, List, Tuple
7
+ from transformers import CLIPTokenizer, T5TokenizerFast
8
+
9
+ import sys
10
+ sys.path.append("..")
11
+ from train.src.pipeline import FluxPipeline
12
+ from train.src.transformer_flux import FluxTransformer2DModel
13
+ from train.src.lora_helper import set_single_lora, set_multi_lora, unset_lora
14
+ from train.src.jsonl_datasets import make_train_dataset, collate_fn
15
+
16
+
17
+ class InferenceArgs:
18
+ """Arguments configuration for inference dataset loading"""
19
+ def __init__(self, jsonl_path: str, pretrained_model_name: str):
20
+ # Basic paths
21
+ self.current_train_data_dir = jsonl_path
22
+ self.inference_embeds_dir = "/archive/vaibhav.agrawal/a-bev-of-the-latents/inference_embeds_flux2"
23
+ self.pretrained_model_name_or_path = pretrained_model_name
24
+
25
+ # Column configurations
26
+ self.subject_column = None # Set to None since we're using spatial
27
+ self.spatial_column = "cv"
28
+ self.target_column = "target"
29
+ self.caption_column = "PLACEHOLDER_prompts"
30
+
31
+ # Size configurations
32
+ self.cond_size = 512
33
+ self.noise_size = 512
34
+
35
+ # Other required parameters
36
+ self.revision = None
37
+ self.variant = None
38
+ self.max_sequence_length = 512
39
+
40
+
41
+ class InferenceEngine:
42
+ """
43
+ Handles model loading and inference for the Gradio interface.
44
+ Pre-loads the base model and dynamically loads LoRA weights based on checkpoint selection.
45
+ """
46
+
47
+ def __init__(self, base_model_path: str = "black-forest-labs/FLUX.1-dev", device: str = "cuda"):
48
+ """
49
+ Initialize the inference engine with base model.
50
+
51
+ Args:
52
+ base_model_path: Path to the base FLUX model
53
+ device: Device to run inference on (default: "cuda")
54
+ """
55
+ self.device = device
56
+ self.base_model_path = base_model_path
57
+ self.current_lora_path = None
58
+
59
+ print(f"Loading base model from {base_model_path}...")
60
+
61
+ # Load pipeline and transformer
62
+ self.pipe = FluxPipeline.from_pretrained(
63
+ base_model_path,
64
+ torch_dtype=torch.bfloat16,
65
+ device=device
66
+ )
67
+
68
+ transformer = FluxTransformer2DModel.from_pretrained(
69
+ base_model_path,
70
+ subfolder="transformer",
71
+ torch_dtype=torch.bfloat16,
72
+ device=device
73
+ )
74
+
75
+ self.pipe.transformer = transformer
76
+ self.pipe.to(device)
77
+
78
+ # Load tokenizers (same as in train.py and infer.ipynb)
79
+ print("Loading tokenizers...")
80
+ self.tokenizer_one = CLIPTokenizer.from_pretrained(
81
+ base_model_path,
82
+ subfolder="tokenizer",
83
+ revision=None,
84
+ )
85
+ self.tokenizer_two = T5TokenizerFast.from_pretrained(
86
+ base_model_path,
87
+ subfolder="tokenizer_2",
88
+ revision=None,
89
+ )
90
+ self.tokenizers = [self.tokenizer_one, self.tokenizer_two]
91
+
92
+ print("Base model and tokenizers loaded successfully!")
93
+
94
+ def load_lora(self, checkpoint_name: str, lora_weights: List[float] = [1.0]):
95
+ """
96
+ Load LoRA weights for a specific checkpoint.
97
+
98
+ Args:
99
+ checkpoint_name: Name of the checkpoint (e.g., "checkpoint_1")
100
+ lora_weights: Weights for the LoRA adaptation
101
+ """
102
+ # Construct LoRA path
103
+ lora_path = f"/archive/vaibhav.agrawal/a-bev-of-the-latents/easycontrol_cuboids/{checkpoint_name}/lora.safetensors"
104
+
105
+ print(f"\n\nGOT THE FOLLOWING LORA PATH: {lora_path}\n\n")
106
+
107
+ # Check if path exists
108
+ if not os.path.exists(lora_path):
109
+ raise FileNotFoundError(f"LoRA checkpoint not found at: {lora_path}")
110
+
111
+ # Only reload if it's a different checkpoint
112
+ if self.current_lora_path != lora_path:
113
+ print(f"Loading LoRA weights from {lora_path}...")
114
+ set_single_lora(
115
+ self.pipe.transformer,
116
+ lora_path,
117
+ lora_weights=lora_weights,
118
+ cond_size=512
119
+ )
120
+ self.current_lora_path = lora_path
121
+ print(f"LoRA weights loaded successfully!")
122
+ else:
123
+ print(f"LoRA already loaded for {checkpoint_name}")
124
+
125
+ def clear_cache(self):
126
+ """Clear attention processor cache"""
127
+ for name, attn_processor in self.pipe.transformer.attn_processors.items():
128
+ if hasattr(attn_processor, 'bank_kv'):
129
+ attn_processor.bank_kv.clear()
130
+
131
+ def tensor_to_image_list(self, tensor):
132
+ """Convert normalized tensor to PIL Image list"""
133
+ if tensor is None:
134
+ return []
135
+
136
+ images = []
137
+ for img_tensor in tensor:
138
+ # Denormalize from [-1, 1] to [0, 1]
139
+ img = (img_tensor.cpu().permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1).numpy()
140
+ # Convert to [0, 255] uint8
141
+ img = (img * 255.0).astype(np.uint8)
142
+ images.append(Image.fromarray(img))
143
+
144
+ return images
145
+
146
+ def run_inference(
147
+ self,
148
+ jsonl_path: str,
149
+ checkpoint_name: str,
150
+ height: int = 512,
151
+ width: int = 512,
152
+ seed: int = 42,
153
+ guidance_scale: float = 3.5,
154
+ num_inference_steps: int = 25,
155
+ max_sequence_length: int = 512
156
+ ) -> Tuple[bool, Optional[Image.Image], str]:
157
+ """
158
+ Run inference using data from JSONL file.
159
+ Uses the same data loading pipeline as training (make_train_dataset).
160
+
161
+ Args:
162
+ jsonl_path: Path to the JSONL file containing inference data
163
+ checkpoint_name: Name of checkpoint to use
164
+ height: Output image height
165
+ width: Output image width
166
+ seed: Random seed for generation
167
+ guidance_scale: Guidance scale for diffusion
168
+ num_inference_steps: Number of denoising steps
169
+ max_sequence_length: Maximum sequence length for text encoding
170
+
171
+ Returns:
172
+ Tuple of (success: bool, image: PIL.Image or None, message: str)
173
+ """
174
+ try:
175
+ # Load LoRA for selected checkpoint
176
+ self.load_lora(checkpoint_name)
177
+
178
+ # Check if JSONL file exists
179
+ if not os.path.exists(jsonl_path):
180
+ return False, None, f"JSONL file not found at: {jsonl_path}"
181
+
182
+ # Create inference arguments
183
+ inference_args = InferenceArgs(
184
+ jsonl_path=jsonl_path,
185
+ pretrained_model_name=self.base_model_path
186
+ )
187
+
188
+ # Create dataset using the same pipeline as training
189
+ print("Creating inference dataset...")
190
+ inference_dataset = make_train_dataset(inference_args, self.tokenizers, accelerator=None, noise_size=512)
191
+
192
+ # Create dataloader with batch_size=1
193
+ inference_dataloader = torch.utils.data.DataLoader(
194
+ inference_dataset,
195
+ batch_size=1,
196
+ shuffle=False,
197
+ collate_fn=collate_fn,
198
+ num_workers=0,
199
+ )
200
+
201
+ # Get the first (and only) batch
202
+ batch = next(iter(inference_dataloader))
203
+
204
+ # Extract data from batch
205
+ caption = batch["prompts"][0] if isinstance(batch["prompts"], list) else batch["prompts"]
206
+ call_ids = batch["call_ids"]
207
+
208
+ print(f"\n{'='*60}")
209
+ print(f"Running inference with:")
210
+ print(f" Checkpoint: {checkpoint_name}")
211
+ print(f" Prompt: {caption}")
212
+ print(f" Call IDs: {call_ids}")
213
+ print(f" Height: {height}, Width: {width}")
214
+ print(f" Seed: {seed}, Steps: {num_inference_steps}")
215
+ print(f" Guidance Scale: {guidance_scale}")
216
+ print(f"{'='*60}\n")
217
+
218
+ # Convert spatial condition tensors to PIL Images
219
+ spatial_imgs = self.tensor_to_image_list(batch["cond_pixel_values"])
220
+
221
+ # Prepare cuboids segmentation masks
222
+ cuboids_segmasks = batch.get("cuboids_segmasks", None)
+ if cuboids_segmasks is not None:
+ print(f"{len(cuboids_segmasks) = }, {cuboids_segmasks[0].shape = }")
+ # Stack into a single tensor *before* building the kwargs below,
+ # so the pipeline receives the stacked tensor rather than the raw list
+ cuboids_segmasks = torch.stack(cuboids_segmasks, dim=0)
+
+ # Prepare joint attention kwargs
+ joint_attention_kwargs = {
+ "call_ids": call_ids,
+ "cuboids_segmasks": cuboids_segmasks,
+ }
+
+ print(f"Spatial images: {len(spatial_imgs)}")
234
+
235
+ # Run inference
236
+ image = self.pipe(
237
+ prompt=caption,
238
+ height=int(height),
239
+ width=int(width),
240
+ guidance_scale=guidance_scale,
241
+ num_inference_steps=num_inference_steps,
242
+ max_sequence_length=max_sequence_length,
243
+ generator=torch.Generator("cpu").manual_seed(seed),
244
+ subject_images=[], # No subject images for spatial conditioning
245
+ spatial_images=spatial_imgs,
246
+ cond_size=512,
247
+ **joint_attention_kwargs
248
+ ).images[0]
249
+
250
+ # Clear cache
251
+ self.clear_cache()
252
+ torch.cuda.empty_cache()
253
+
254
+ success_msg = f"✅ Successfully generated image using {checkpoint_name}"
255
+ print(f"\n{success_msg}\n")
256
+
257
+ return True, image, success_msg
258
+
259
+ except Exception as e:
260
+ error_msg = f"❌ Inference failed: {str(e)}"
261
+ print(f"\n{error_msg}\n")
262
+ import traceback
263
+ traceback.print_exc()
264
+ return False, None, error_msg
265
+
266
+
267
+ # Global inference engine instance
268
+ _inference_engine: Optional[InferenceEngine] = None
269
+
270
+
271
+ def initialize_inference_engine(base_model_path: str = "black-forest-labs/FLUX.1-dev"):
272
+ """
273
+ Initialize the global inference engine.
274
+ Should be called once when the Gradio demo starts.
275
+ """
276
+ global _inference_engine
277
+
278
+ if _inference_engine is None:
279
+ print("\n" + "="*60)
280
+ print("INITIALIZING INFERENCE ENGINE")
281
+ print("="*60 + "\n")
282
+
283
+ _inference_engine = InferenceEngine(base_model_path=base_model_path)
284
+
285
+ print("\n" + "="*60)
286
+ print("INFERENCE ENGINE READY")
287
+ print("="*60 + "\n")
288
+
289
+ return _inference_engine
290
+
291
+
292
+ def get_inference_engine() -> InferenceEngine:
293
+ """
294
+ Get the global inference engine instance.
295
+ Raises an error if not initialized.
296
+ """
297
+ global _inference_engine
298
+
299
+ if _inference_engine is None:
300
+ raise RuntimeError(
301
+ "Inference engine not initialized. "
302
+ "Call initialize_inference_engine() first."
303
+ )
304
+
305
+ return _inference_engine
306
+
307
+
308
+ def run_inference_from_gradio(
309
+ checkpoint_name: str,
310
+ height: int = 512,
311
+ width: int = 512,
312
+ seed: int = 42,
313
+ guidance_scale: float = 3.5,
314
+ num_inference_steps: int = 25,
315
+ jsonl_path: str = "/archive/vaibhav.agrawal/a-bev-of-the-latents/gradio_files/cuboids.jsonl"
316
+ ) -> Tuple[bool, Optional[Image.Image], str]:
317
+ """
318
+ Wrapper function to run inference from Gradio interface.
319
+
320
+ Args:
321
+ checkpoint_name: Name of checkpoint to use (from dropdown)
322
+ height: Output image height
323
+ width: Output image width
324
+ seed: Random seed
325
+ guidance_scale: Guidance scale
326
+ num_inference_steps: Number of denoising steps
327
+ jsonl_path: Path to JSONL file with inference data
328
+
329
+ Returns:
330
+ Tuple of (success, generated_image, status_message)
331
+ """
332
+ engine = get_inference_engine()
333
+
334
+ return engine.run_inference(
335
+ jsonl_path=jsonl_path,
336
+ checkpoint_name=checkpoint_name,
337
+ height=height,
338
+ width=width,
339
+ seed=seed,
340
+ guidance_scale=guidance_scale,
341
+ num_inference_steps=num_inference_steps
342
+ )
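Tying the pieces together, typical usage from the demo process looks like the following sketch (the checkpoint name is whatever exists on disk; the JSONL path defaults to the one baked in above):

from infer_backend import initialize_inference_engine, run_inference_from_gradio

initialize_inference_engine()  # loads the FLUX base model and tokenizers once
ok, image, msg = run_inference_from_gradio(checkpoint_name="checkpoint_1", seed=42)
print(msg)
if ok:
    image.save("demo_output.png")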
gradio_app/launch_blender_backend.sh ADDED
@@ -0,0 +1,51 @@
1
+ #!/bin/bash
2
+
3
+ PORTS=(5001 5002 5003 5004)
4
+
5
+ for port in "${PORTS[@]}"; do
6
+ PID=$(lsof -t -i ":$port")
7
+ if [ -n "$PID" ]; then
8
+ echo "Killing process $PID running on port $port..."
9
+ kill -9 "$PID"
10
+ echo "Process $PID killed."
11
+ else
12
+ echo "No process found running on port $port."
13
+ fi
14
+ done
15
+
16
+ # Start CV render server
17
+ echo "Starting Camera View render server on port 5001..."
18
+ python blender_server.py --mode cv --port 5001 &
19
+ CV_PID=$!
20
+
21
+ echo "Starting Camera View render server on port 5002..."
22
+ python blender_server.py --mode final --port 5002 &
23
+ FINAL_PID=$!
24
+
25
+
26
+ # Start segmask render server
27
+ echo "Starting Segmentation Mask render server on port 5003..."
28
+ python3 blender_server_segmasks.py --port 5003 &
29
+ SEGMASK_PID=$!
30
+
31
+ echo "Starting Camera View render server on port 5004..."
32
+ python blender_server.py --mode paper --port 5004 &
33
+ PAPER_PID=$!
34
+
35
+ echo "Render servers started!"
36
+ echo "CV Server PID: $CV_PID (port 5001)"
37
+ echo "Final (Cycles) Render Server PID: $FINAL_PID (port 5002)"
38
+ echo "Segmentation Mask Server PID: $SEGMASK_PID (port 5003)"
39
+
40
+ # Function to cleanup on exit
41
+ cleanup() {
42
+ echo "Stopping render servers..."
43
+ kill $CV_PID $FINAL_PID $SEGMASK_PID $PAPER_PID 2>/dev/null
44
+ exit 0
45
+ }
46
+
47
+ # Set trap to cleanup on script exit
48
+ trap cleanup SIGINT SIGTERM
49
+
50
+ # Wait for all render server processes
51
+ wait $CV_PID $FINAL_PID $SEGMASK_PID $PAPER_PID
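Once the script is up, a quick sanity check is to poll each server's /health endpoint. A minimal sketch in Python (assumes requests is installed):

import requests

for port in (5001, 5002, 5003, 5004):
    try:
        print(port, requests.get(f"http://127.0.0.1:{port}/health", timeout=5).json())
    except requests.RequestException as exc:
        print(port, "not reachable:", exc)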
gradio_app/model_condition.jpg ADDED

Git LFS Details

  • SHA256: ba3e33f67bc6f81aa5756c00233fdeba584e3c4e63b9a235a2df8bd117a64ea4
  • Pointer size: 130 Bytes
  • Size of remote file: 27.3 kB
gradio_app/object_scales.py ADDED
@@ -0,0 +1,212 @@
1
+ scales = {
2
+ "bear": 0.53, # Unchanged
3
+ "bicycle": 0.4, # Unchanged
4
+ "bugatti": 1.0, # Unchanged
5
+ "bulldozer": 1.78, # Unchanged
6
+ "bus": 2.67, # Unchanged
7
+ "cat": 0.11, # Unchanged
8
+ "chair": 0.18, # Unchanged
9
+ "coupe": 1.0, # Unchanged
10
+ "cow": 0.56, # Unchanged
11
+ "crow": 0.09, # CHANGED: Reduced from 0.11
12
+ "deer": 0.44, # Unchanged
13
+ "dog": 0.22, # Unchanged
14
+ "elephant": 1.22, # Unchanged
15
+ "ferrari": 1.05, # CHANGED: Increased from 1.0
16
+ "flamingo": 0.10, # Unchanged
17
+ "fox": 0.22, # Unchanged
18
+ "giraffe": 0.90, # CHANGED: Reduced from 1.00
19
+ "goat": 0.33, # Unchanged
20
+ "helicopter": 2.26, # Unchanged
21
+ "hen": 0.09, # Unchanged
22
+ "horse": 0.53, # Unchanged
23
+ "jeep": 0.96, # Unchanged
24
+ "kangaroo": 0.38, # CHANGED: Increased from 0.33
25
+ "lamborghini": 1.0, # Unchanged
26
+ "lion": 0.56, # Unchanged
27
+ "mclaren": 1.0, # Unchanged
28
+ "motorbike": 0.44, # Unchanged
29
+ "office chair": 0.20,# Unchanged
30
+ "pickup truck": 1.22,# Unchanged
31
+ "pigeon": 0.067, # Unchanged
32
+ "pig": 0.33, # Unchanged
33
+ "rabbit": 0.11, # Unchanged
34
+ "scooter": 0.4, # Unchanged
35
+ "sedan": 1.0, # Unchanged (Reference)
36
+ "sheep": 0.29, # Unchanged
37
+ "shoe": 0.04, # Unchanged
38
+ "sparrow": 0.033, # Unchanged
39
+ "suv": 1.07, # Unchanged
40
+ "table": 0.4, # Unchanged
41
+ "teddy": 0.05, # CHANGED: Reduced from 0.11
42
+ "tiger": 0.67, # Unchanged
43
+ "tractor": 0.80, # Unchanged
44
+ "van": 1.11, # Unchanged
45
+ "vw beetle": 1.0, # Unchanged
46
+ "wolf": 0.33, # Unchanged
47
+ "man": 0.38, # Unchanged
48
+ "zebra": 0.56 # Unchanged
49
+ }
50
+
51
+ tiny_assets = [
52
+ "sparrow", # 0.033
53
+ "shoe", # 0.04
54
+ "teddy", # 0.05 (CHANGED)
55
+ "pigeon", # 0.067
56
+ "hen", # 0.09
57
+ "crow", # 0.09 (CHANGED)
58
+ "flamingo", # 0.10 (CHANGED - Moved from small)
59
+ "rabbit", # 0.11
60
+ "cat", # 0.11
61
+ ]
62
+
63
+
64
+ small_assets = [
65
+ "chair", # 0.18
66
+ "office chair", # 0.20
67
+ "dog", # 0.22
68
+ "fox", # 0.22
69
+ "sheep", # 0.29
70
+ "goat", # 0.33
71
+ "pig", # 0.33
72
+ "wolf", # 0.33
73
+ "man", # 0.38 (CHANGED - Added to group)
74
+ "kangaroo", # 0.38 (CHANGED)
75
+ ]
76
+
77
+
78
+ medium_assets = [
79
+ "table", # 0.4 (CHANGED - Moved from small)
80
+ "bicycle", # 0.4
81
+ "scooter", # 0.4
82
+ "deer", # 0.44
83
+ "motorbike", # 0.44
84
+ "bear", # 0.53
85
+ "horse", # 0.53
86
+ "cow", # 0.56
87
+ "lion", # 0.56
88
+ "zebra", # 0.56
89
+ "tiger", # 0.67
90
+ "tractor", # 0.80
91
+ "giraffe", # 0.90 (CHANGED)
92
+ "jeep", # 0.96
93
+ "bugatti", # 1.0
94
+ "coupe", # 1.0
95
+ "lamborghini", # 1.0
96
+ "mclaren", # 1.0
97
+ "sedan", # 1.0
98
+ "vw beetle", # 1.0
99
+ "ferrari", # 1.05 (CHANGED)
100
+ "suv", # 1.07
101
+ "van", # 1.11
102
+ "elephant", # 1.22
103
+ "pickup truck", # 1.22
104
+ "bulldozer", # 1.78
105
+ "helicopter", # 2.26
106
+ "bus", # 2.67
107
+ ]
108
+
109
+ tiny_prompts = [
110
+ "a photo of PLACEHOLDER in a cozy birdhouse nestled in a green tree",
111
+ "a photo of PLACEHOLDER on a sandy beach near the water's edge with small shells",
112
+ "a photo of PLACEHOLDER amongst colorful wildflowers in a sunny meadow",
113
+ "a photo of PLACEHOLDER on a moss-covered log in a quiet forest",
114
+ "a photo of PLACEHOLDER near a small pond with lily pads floating",
115
+ "a photo of PLACEHOLDER on a window sill overlooking a rainy city street",
116
+ "a photo of PLACEHOLDER in a child's bedroom surrounded by other toys",
117
+ "a photo of PLACEHOLDER on a park bench with fallen leaves around",
118
+ "a photo of PLACEHOLDER by a small stream with smooth pebbles",
119
+ "a photo of PLACEHOLDER in a field of tall grass swaying gently",
120
+ "a photo of PLACEHOLDER on a wooden fence post in the countryside",
121
+ "a photo of PLACEHOLDER amongst blossoming spring flowers in a garden",
122
+ "a photo of PLACEHOLDER on a stack of old books in a library",
123
+ "a photo of PLACEHOLDER near a bird feeder in a winter garden",
124
+ "a photo of PLACEHOLDER on a picnic blanket in a sunny park",
125
+ "a photo of PLACEHOLDER on a kitchen counter near ripe fruit",
126
+ "a photo of PLACEHOLDER amongst autumn leaves on a forest floor",
127
+ "a photo of PLACEHOLDER on a rocky outcrop with a distant view",
128
+ "a photo of PLACEHOLDER near a puddle reflecting the sky",
129
+ "a photo of PLACEHOLDER in a patch of soft green moss",
130
+ "a photo of PLACEHOLDER on a weathered stone wall",
131
+ "a photo of PLACEHOLDER near a patch of blooming daisies",
132
+ "a photo of PLACEHOLDER on a sandy path through a garden",
133
+ "a photo of PLACEHOLDER near a watering can in a greenhouse",
134
+ "a photo of PLACEHOLDER amongst fallen pine needles in a forest",
135
+ "a photo of PLACEHOLDER on a small bridge over a gentle stream",
136
+ "a photo of PLACEHOLDER near a patch of colorful mushrooms"
137
+ ]
138
+
139
+ small_prompts = [
140
+ "a photo of PLACEHOLDER in a sun-drenched greenhouse surrounded by various plants",
141
+ "a photo of PLACEHOLDER in a bustling city park with people walking by",
142
+ "a photo of PLACEHOLDER in a cozy library with tall bookshelves and soft lighting",
143
+ "a photo of PLACEHOLDER on a sandy dune near the ocean with gentle waves",
144
+ "a photo of PLACEHOLDER amongst tall reeds in a marshland area",
145
+ "a photo of PLACEHOLDER in a quiet forest clearing with sunlight filtering through trees",
146
+ "a photo of PLACEHOLDER on a grassy hill overlooking a small town",
147
+ "a photo of PLACEHOLDER near a flowing waterfall with mist in the air",
148
+ "a photo of PLACEHOLDER in a vibrant flower market with colorful blooms all around",
149
+ "a photo of PLACEHOLDER on a wooden dock extending into a still lake",
150
+ "a photo of PLACEHOLDER amongst rows of crops in a rural farmland",
151
+ "a photo of PLACEHOLDER in a historic town square with old buildings",
152
+ "a photo of PLACEHOLDER on a rocky beach with crashing waves in the distance",
153
+ "a photo of PLACEHOLDER amongst tall bamboo stalks in a serene grove",
154
+ "a photo of PLACEHOLDER in a snowy field with tracks visible in the snow",
155
+ "a photo of PLACEHOLDER on a paved walkway in a botanical garden",
156
+ "a photo of PLACEHOLDER near a campfire in a forest at night",
157
+ "a photo of PLACEHOLDER amongst colorful autumn foliage in a park",
158
+ "a photo of PLACEHOLDER on a stone path winding through a garden",
159
+ "a photo of PLACEHOLDER in a misty meadow with dew-covered grass",
160
+ "a photo of PLACEHOLDER on a wooden bridge crossing a small river",
161
+ "a photo of PLACEHOLDER amongst blooming lavender fields under a sunny sky",
162
+ "a photo of PLACEHOLDER in a quiet suburban backyard with green grass",
163
+ "a photo of PLACEHOLDER on a rocky hillside with sparse vegetation",
164
+ "a photo of PLACEHOLDER near a clear mountain stream with smooth stones",
165
+ "a photo of PLACEHOLDER amongst fallen leaves in a shaded woodland",
166
+ "a photo of PLACEHOLDER on a grassy bank beside a calm canal",
167
+ "a photo of PLACEHOLDER in a vineyard with rows of grapevines",
168
+ "a photo of PLACEHOLDER near a traditional wooden farmhouse"
169
+ ]
170
+
171
+ medium_prompts = [
172
+ "a photo of PLACEHOLDER in a vast open plain with a dramatic sunset on the horizon",
173
+ "a photo of PLACEHOLDER on a winding mountain road with scenic views of valleys",
174
+ "a photo of PLACEHOLDER in a bustling harbor with various boats and ships",
175
+ "a photo of PLACEHOLDER in a dense pine forest with tall trees reaching the sky",
176
+ "a photo of PLACEHOLDER on a sandy beach with palm trees swaying in the breeze",
177
+ "a photo of PLACEHOLDER amongst rolling hills in a green countryside landscape",
178
+ "a photo of PLACEHOLDER in a vibrant city square with historic architecture",
179
+ "a photo of PLACEHOLDER in a train yard with multiple railway tracks",
180
+ "a photo of PLACEHOLDER amongst tall redwood trees in an ancient forest",
181
+ "a photo of PLACEHOLDER in a sprawling parking lot outside a shopping mall",
182
+ "a photo of PLACEHOLDER on a coastal highway with ocean views and cliffs",
183
+ "a photo of PLACEHOLDER amongst golden wheat fields under a clear summer sky",
184
+ "a photo of PLACEHOLDER in a rocky canyon with sparse desert vegetation and blue sky above",
185
+ "a photo of PLACEHOLDER on a grassy plateau overlooking a vast landscape",
186
+ "a photo of PLACEHOLDER in a snowy mountain range with visible ski slopes",
187
+ "a photo of PLACEHOLDER on a paved highway stretching across an open landscape",
188
+ "a photo of PLACEHOLDER amongst lush vegetation in a tropical rainforest",
189
+ "a photo of PLACEHOLDER in a historic European city with ornate buildings",
190
+ "a photo of PLACEHOLDER in front of the Eiffel Tower at sunset",
191
+ "a photo of PLACEHOLDER amongst tall sunflowers in a field under a bright sun",
192
+ "a photo of PLACEHOLDER in a deep valley with steep forested sides",
193
+ "a photo of PLACEHOLDER on a rocky coastline with crashing waves and sea spray",
194
+ "a photo of PLACEHOLDER amongst vineyards on rolling hills under a sunny sky",
195
+ "a photo of PLACEHOLDER in a wide open desert with distant mesas and clear air",
196
+ "a photo of PLACEHOLDER amongst autumn-colored trees along a winding river",
197
+ "a photo of PLACEHOLDER in a bustling marketplace with various stalls and people",
198
+ "a photo of PLACEHOLDER on a racing circuit with banked turns and grandstands",
199
+ "a photo of PLACEHOLDER amongst tall grasses in a savanna landscape",
200
+ ]
201
+
202
+ groups = {
203
+ "tiny": tiny_assets,
204
+ "small": small_assets,
205
+ "medium": medium_assets,
206
+ }
207
+
208
+ groups_prompts = {
209
+ "tiny": tiny_prompts,
210
+ "small": small_prompts,
211
+ "medium": medium_prompts,
212
+ }
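These tables are meant to be used together: look up the size group an asset belongs to, then sample a prompt from the matching list. A minimal sketch:

import random

def sample_prompt(asset: str) -> str:
    """Pick a size-appropriate prompt and substitute the asset name."""
    for group_name, assets in groups.items():
        if asset in assets:
            return random.choice(groups_prompts[group_name]).replace("PLACEHOLDER", asset)
    raise KeyError(f"{asset!r} is not in any size group")

print(sample_prompt("flamingo"))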
gradio_app/paper_render.jpg ADDED

Git LFS Details

  • SHA256: d7ea195b602da3fec33eb7affc91a1a37b5885905ed6fcc60cbdc53531628ebc
  • Pointer size: 130 Bytes
  • Size of remote file: 23.9 kB
gradio_app/set_tmp.sh ADDED
File without changes
gradio_app/tmp.png ADDED

Git LFS Details

  • SHA256: 24713b88fb53a35f320e40903c80c428de19bd741173b61fe0c1e1693a01a510
  • Pointer size: 131 Bytes
  • Size of remote file: 654 kB
gradio_app/visualize_server.py ADDED
@@ -0,0 +1,115 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Flask web server to visualize inference results for 2-subject cases.
4
+ Port: 7023
5
+ """
6
+
7
+ import os
8
+ import json
9
+ from pathlib import Path
10
+ from flask import Flask, render_template, send_from_directory
11
+ import base64
12
+
13
+ app = Flask(__name__)
14
+
15
+ # Paths
16
+ DATASET_FILE = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval/cuboids_segmentation.jsonl"
17
+ DATASET_ROOT = "/archive/vaibhav.agrawal/a-bev-of-the-latents/datasetv7_superhard_eval"
18
+ RESULTS_DIR = "/archive/vaibhav.agrawal/a-bev-of-the-latents/VAL/results/omini_seg_baseline_r2_epoch-0_checkpoint-20000"
19
+
20
+ def load_2_subject_cases():
21
+ """Load all 2-subject cases from the dataset."""
22
+ cases = []
23
+ with open(DATASET_FILE, 'r') as f:
24
+ for idx, line in enumerate(f):
25
+ data = json.loads(line)
26
+ if len(data['subjects']) == 2:
27
+ cases.append({
28
+ 'dataset_index': idx,
29
+ 'subjects': data['subjects'],
30
+ 'prompt': data['prompt'],
31
+ 'target': data['target'],
32
+ 'cv': data['cv']
33
+ })
34
+ return cases
35
+
36
+ # Load cases on startup
37
+ TWO_SUBJECT_CASES = load_2_subject_cases()
38
+ print(f"Loaded {len(TWO_SUBJECT_CASES)} 2-subject cases")
39
+
40
+ def get_image_path(case, image_type):
41
+ """Get the path for a specific image type."""
42
+ if image_type == 'ground_truth':
43
+ return os.path.join(DATASET_ROOT, case['target'])
44
+ elif image_type == 'segmentation':
45
+ return os.path.join(DATASET_ROOT, case['cv'])
46
+ elif image_type == 'generated':
47
+ # Find the generated image in results
48
+ viz_dir = os.path.join(RESULTS_DIR, 'generated_images')
49
+ # Pattern: sample_{sample_idx:04d}_idx_{dataset_index}_seed_{seed}.jpg
50
+ # We need to find the file that matches the dataset_index
51
+ if os.path.exists(viz_dir):
52
+ for filename in os.listdir(viz_dir):
53
+ if f"_idx_{case['dataset_index']}_" in filename:
54
+ return os.path.join(viz_dir, filename)
55
+ return None
56
+
57
+ @app.route('/')
58
+ def index():
59
+ """Main page showing the first 2-subject case."""
60
+ return show_case(0)
61
+
62
+ @app.route('/case/<int:case_idx>')
63
+ def show_case(case_idx):
64
+ """Display a specific case."""
65
+ if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
66
+ return "Case not found", 404
67
+
68
+ case = TWO_SUBJECT_CASES[case_idx]
69
+
70
+ # Get image paths
71
+ gt_path = get_image_path(case, 'ground_truth')
72
+ seg_path = get_image_path(case, 'segmentation')
73
+ gen_path = get_image_path(case, 'generated')
74
+
75
+ # Check if files exist
76
+ gt_exists = os.path.exists(gt_path) if gt_path else False
77
+ seg_exists = os.path.exists(seg_path) if seg_path else False
78
+ gen_exists = os.path.exists(gen_path) if gen_path else False
79
+
80
+ return render_template('viewer.html',
81
+ case_idx=case_idx,
82
+ total_cases=len(TWO_SUBJECT_CASES),
83
+ subjects=', '.join(case['subjects']),
84
+ prompt=case['prompt'].replace('PLACEHOLDER', ', '.join(case['subjects'])),
85
+ dataset_index=case['dataset_index'],
86
+ gt_exists=gt_exists,
87
+ seg_exists=seg_exists,
88
+ gen_exists=gen_exists,
89
+ prev_idx=case_idx - 1 if case_idx > 0 else None,
90
+ next_idx=case_idx + 1 if case_idx < len(TWO_SUBJECT_CASES) - 1 else None)
91
+
92
+ @app.route('/image/<int:case_idx>/<image_type>')
93
+ def serve_image(case_idx, image_type):
94
+ """Serve the requested image."""
95
+ if case_idx < 0 or case_idx >= len(TWO_SUBJECT_CASES):
96
+ return "Case not found", 404
97
+
98
+ case = TWO_SUBJECT_CASES[case_idx]
99
+ image_path = get_image_path(case, image_type)
100
+
101
+ if image_path and os.path.exists(image_path):
102
+ directory = os.path.dirname(image_path)
103
+ filename = os.path.basename(image_path)
104
+ return send_from_directory(directory, filename)
105
+ else:
106
+ return "Image not found", 404
107
+
108
+ if __name__ == '__main__':
109
+ # Create templates directory if it doesn't exist
110
+ os.makedirs('templates', exist_ok=True)
111
+
112
+ # Run server on all interfaces (0.0.0.0) for remote access
113
+ print(f"Starting server on port 7023...")
114
+ print(f"Access at: http://<your-host-ip>:7023")
115
+ app.run(host='0.0.0.0', port=7023, debug=True)
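The generated-image lookup above matches on the _idx_{dataset_index}_ substring of the documented filename pattern; a stricter variant could parse the whole name with a regex (a sketch, assuming the same pattern):

import re

# sample_{sample_idx:04d}_idx_{dataset_index}_seed_{seed}.jpg
FILENAME_RE = re.compile(r"sample_(\d{4})_idx_(\d+)_seed_(\d+)\.jpg")

def matches_case(filename: str, dataset_index: int) -> bool:
    m = FILENAME_RE.fullmatch(filename)
    return bool(m) and int(m.group(2)) == dataset_index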
train/src/__pycache__/__init__.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/__init__.cpython-311.pyc and b/train/src/__pycache__/__init__.cpython-311.pyc differ
 
train/src/__pycache__/jsonl_datasets.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/jsonl_datasets.cpython-311.pyc and b/train/src/__pycache__/jsonl_datasets.cpython-311.pyc differ
 
train/src/__pycache__/layers.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/layers.cpython-311.pyc and b/train/src/__pycache__/layers.cpython-311.pyc differ
 
train/src/__pycache__/lora_helper.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/lora_helper.cpython-311.pyc and b/train/src/__pycache__/lora_helper.cpython-311.pyc differ
 
train/src/__pycache__/pipeline.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/pipeline.cpython-311.pyc and b/train/src/__pycache__/pipeline.cpython-311.pyc differ
 
train/src/__pycache__/transformer_flux.cpython-311.pyc CHANGED
Binary files a/train/src/__pycache__/transformer_flux.cpython-311.pyc and b/train/src/__pycache__/transformer_flux.cpython-311.pyc differ
 
train/src/jsonl_datasets.py CHANGED
@@ -107,8 +107,7 @@ def make_train_dataset(args, tokenizer, accelerator, noise_size, only_realistic_
  prompt_file_name = "space_prompt.pth"
  else:
  prompt_file_name = "_".join(caption.split(" ")) + ".pth"
- if args.inference_embeds_dir is not None:
- assert osp.exists(osp.join(args.inference_embeds_dir, prompt_file_name)), f"Prompt embeddings for '{caption}' not found in {args.inference_embeds_dir}. Please precompute and save them."
+ if args.inference_embeds_dir is not None and osp.exists(osp.join(args.inference_embeds_dir, prompt_file_name)):
  prompt_embeds = torch.load(osp.join(args.inference_embeds_dir, prompt_file_name), map_location="cpu")
  pooled_prompt_embeds = prompt_embeds["pooled_prompt_embeds"]
  prompt_embeds = prompt_embeds["prompt_embeds"]
train/src/pipeline.py CHANGED
@@ -619,11 +619,7 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
619
  cond_size=512,
620
  call_ids=None,
621
  cuboids_segmasks=None,
622
- store_qk=None,
623
- store_qk_timesteps=None,
624
  ):
625
- assert not ((store_qk is None) ^ (store_qk_timesteps is None)), "Please provide both store_qk and store_qk_timesteps or neither of them."
626
-
627
  height = height or self.default_sample_size * self.vae_scale_factor
628
  width = width or self.default_sample_size * self.vae_scale_factor
629
  self.cond_size = cond_size
@@ -756,12 +752,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
756
  if self.interrupt:
757
  continue
758
 
759
- store_qk_ = copy.deepcopy(store_qk)
760
- if (store_qk_ is not None) and (i not in store_qk_timesteps):
761
- store_qk_ = None
762
- elif store_qk_ is not None:
763
- store_qk_ = osp.join(store_qk, f"step_{i}")
764
-
765
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
766
  timestep = t.expand(latents.shape[0]).to(latents.dtype)
767
  noise_pred = self.transformer(
@@ -777,7 +767,6 @@ class FluxPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
777
  return_dict=False,
778
  call_ids=call_ids,
779
  cuboids_segmasks=cuboids_segmasks,
780
- store_qk=store_qk_,
781
  )[0]
782
 
783
  # compute the previous noisy sample x_t -> x_t-1