import os os.environ["GRADIO_TEMP_DIR"]="./gradio_tmp" os.environ["GRADIO_TMP_DIR"]="./gradio_tmp" os.environ["TEMPDIR"]="./gradio_tmp" os.environ["TMP_DIR"]="./gradio_tmp" os.environ["TEMP_DIR"]="./gradio_tmp" os.environ["TMPDIR"]="./gradio_tmp" import os.path as osp import sys import numpy as np import tempfile import shutil import base64 import io from PIL import Image import gradio as gr import time import copy import requests import json import pickle from concurrent.futures import ThreadPoolExecutor, as_completed from inference.object_scales import scales from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast import pickle from datetime import datetime from inference.infer_backend import initialize_inference_engine, run_inference_from_gradio import inference.config as config COLORS = [ (1.0, 0.0, 0.0), # Red (0.0, 0.8, 0.2), # Green (0.0, 0.0, 1.0), # Blue (1.0, 1.0, 0.0), # Yellow (0.0, 1.0, 1.0), # Cyan (1.0, 0.0, 1.0), # Magenta (1.0, 0.6, 0.0), # Orange (0.6, 0.0, 0.8), # Purple (0.0, 0.4, 0.0), # Dark Green (0.8, 0.8, 0.8), # Light Gray (0.2, 0.2, 0.2) # Dark Gray ] tokenizer = T5TokenizerFast.from_pretrained( config.PRETRAINED_MODEL_NAME_OR_PATH, subfolder="tokenizer_2", revision=None, ) placeholder_token_str = [""] num_added_tokens = tokenizer.add_tokens(placeholder_token_str) assert num_added_tokens == 1 def generate_image_event(camera_elevation, camera_lens, surrounding_prompt, checkpoint_name, image_size, seed, guidance_scale, num_steps): """Generate final image with segmentation masks and run inference""" # Update scene manager's inference params before generation scene_manager.update_inference_params(image_size, image_size, seed, guidance_scale, num_steps, checkpoint_name) if not scene_manager.objects: return ( "⚠️ No objects to render", gr.update(), Image.new('RGB', (512, 512), color='white') ) # Get subject descriptions subject_descriptions = [obj['description'] for obj in scene_manager.objects] print(f"Surrounding prompt: {surrounding_prompt}") print(f"Subject descriptions: {subject_descriptions}") print(f"Selected checkpoint: {checkpoint_name}") placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt # Create placeholder text subject_embeds = [] for subject_idx, subject_desc in enumerate(subject_descriptions): input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0] subject_embed = {"input_ids_t5": input_ids.tolist()} subject_embeds.append(subject_embed) placeholder_text = "" for subject in subject_descriptions[:-1]: placeholder_text = placeholder_text + f" {subject} and " for subject in subject_descriptions[-1:]: placeholder_text = placeholder_text + f" {subject}" placeholder_text = placeholder_text.strip() placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text) call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt, subjects=subject_descriptions, subjects_embeds=subject_embeds, debug=True ) print(f"Generated call IDs: {call_ids}") # Convert to server expected format subjects_data, camera_data = scene_manager._convert_to_blender_format() # Render final high-quality image using CYCLES (port 5002) final_img = scene_manager.render_client._send_render_request( scene_manager.render_client.final_server_url, subjects_data, camera_data ) final_img.save("model_condition.jpg") # Render segmentation masks success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data) if not success: return ( f"❌ Failed to render segmentation masks: {error_msg}", gr.update(), Image.new('RGB', (512, 512), color='white') ) # Save all files to the correct location root_save_dir = config.GRADIO_FILES_DIR os.makedirs(root_save_dir, exist_ok=True) os.system(f"rm -f {root_save_dir}/*") # Save final render to root directory final_render_path = osp.join(root_save_dir, "cv_render.jpg") final_img.save(final_render_path) # Move segmentation masks for subject_idx in range(len(subject_descriptions)): shutil.move( f"{str(subject_idx).zfill(3)}_segmask_cv.png", osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png") ) # Create JSONL jsonl = [{ "cv": final_render_path, "target": final_render_path, "cuboids_segmasks": [ osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png") for subject_idx in range(len(subject_descriptions)) ], "PLACEHOLDER_prompts": placeholder_prompt, "subjects": subject_descriptions, "call_ids": call_ids, }] jsonl_path = osp.join(root_save_dir, "cuboids.jsonl") with open(jsonl_path, "w") as f: json.dump(jsonl[0], f) # Run inference using the pre-loaded model print(f"\n{'='*60}") print(f"RUNNING INFERENCE") print(f"{'='*60}\n") inference_success, generated_image, inference_msg = run_inference_from_gradio( checkpoint_name=checkpoint_name, height=image_size, width=image_size, seed=seed, guidance_scale=guidance_scale, num_inference_steps=num_steps, jsonl_path=jsonl_path ) if not inference_success: return ( f"✅ Saved files but inference failed: {inference_msg}", final_img, Image.new('RGB', (512, 512), color='white') ) status_msg = f"✅ Generated image using {checkpoint_name} with {len(segmask_images)} segmentation masks" # Render final image for camera view final_img = scene_manager.render_client._send_render_request( scene_manager.render_client.cv_server_url, subjects_data, camera_data ) return ( status_msg, final_img, # Display CV render in Camera View generated_image # Display generated image in Generated Image section ) def get_call_ids_from_placeholder_prompt_flux(prompt: str, subjects, subjects_embeds: list, debug: bool): assert prompt.find("") != -1, "Prompt must contain to get call ids" # the placeholder token ID for all the tokenizers placeholder_token_three = tokenizer.encode("", return_tensors="pt")[0][:-1].item() prompt_tokens_three = tokenizer.encode(prompt, return_tensors="pt")[0].tolist() placeholder_token_locations_three = [i for i, w in enumerate(prompt_tokens_three) if w == placeholder_token_three] prompt = prompt.replace(" ", "") call_ids = [] for subject_idx, (subject, subject_embed) in enumerate(zip(subjects, subjects_embeds)): subject_prompt_ids_t5 = subject_embed["input_ids_t5"][:-1] # T5 has SOT token only num_t5_tokens_subject = len(subject_prompt_ids_t5) t5_call_ids_subject = [i + placeholder_token_locations_three[subject_idx] - 2 * subject_idx - 1 for i in range(num_t5_tokens_subject)] call_ids.append(t5_call_ids_subject) prompt_wo_placeholder = prompt.replace(" ", "") t5_call_strs = tokenizer.batch_decode(tokenizer.encode(prompt_wo_placeholder, return_tensors="pt")[0].tolist()) t5_call_strs = [t5_call_strs[i] for i in t5_call_ids_subject] if debug: print(f"{prompt = }, t5 CALL strs for {subject} = {t5_call_strs}") return call_ids def map_point_to_rgb(x, y): """ Map (x, y) inside the frustum to an RGB color with continuity and variation. """ # Frustum boundaries X_MIN, X_MAX = -10.0, -1.0 Y_MIN_AT_XMIN, Y_MAX_AT_XMIN = -4.5, 4.5 Y_MIN_AT_XMAX, Y_MAX_AT_XMAX = -0.5, 0.5 # Normalize x to [0, 1] x_norm = (x - X_MIN) / (X_MAX - X_MIN) # x_norm = np.clip(x_norm, 0, 1) # Compute current Y bounds at given x using linear interpolation y_min = Y_MIN_AT_XMIN + x_norm * (Y_MIN_AT_XMAX - Y_MIN_AT_XMIN) y_max = Y_MAX_AT_XMIN + x_norm * (Y_MAX_AT_XMAX - Y_MAX_AT_XMIN) # Normalize y to [0, 1] within current bounds if y_max != y_min: y_norm = (y - y_min) / (y_max - y_min) else: y_norm = 0.5 y_norm = np.clip(y_norm, 0.0, 1.0) # Color mapping: more variation along x r = x_norm g = y_norm b = 1.0 - x_norm return (r, g, b) def rgb_to_hex(rgb_tuple): """Convert RGB tuple (0-1 range) to hex color string.""" r, g, b = rgb_tuple return f"#{int(r*255):02x}{int(g*255):02x}{int(b*255):02x}" class BlenderRenderClient: def __init__(self, cv_server_url=config.BLENDER_CV_SERVER_URL, segmask_server_url=config.BLENDER_SEGMASK_SERVER_URL, final_server_url=config.BLENDER_FINAL_SERVER_URL): """ Initialize the Blender render client. Args: cv_server_url (str): URL of the camera view render server segmask_server_url (str): URL of the segmentation mask render server final_server_url (str): URL of the final view render server """ self.cv_server_url = cv_server_url self.segmask_server_url = segmask_server_url self.final_server_url = final_server_url self.timeout = 30 # 30 second timeout for renders def render_segmasks(self, subjects_data: list, camera_data: dict) -> tuple: """ Send a segmentation mask render request. Returns (success: bool, segmask_images: list of PIL Images or None, error_message: str or None) """ try: request_data = { "subjects_data": subjects_data, "camera_data": camera_data, "num_samples": 1 } response = requests.post( f"{self.segmask_server_url}/render_segmasks", json=request_data, timeout=self.timeout ) if response.status_code == 200: result = response.json() if result["success"]: # Decode all segmentation masks segmask_images = [] for img_base64 in result["segmasks_base64"]: img_data = base64.b64decode(img_base64) img = Image.open(io.BytesIO(img_data)) segmask_images.append(img) print(f"Successfully rendered {len(segmask_images)} segmentation masks") return True, segmask_images, None else: error_msg = result.get('error_message', 'Unknown error') print(f"Segmask render failed: {error_msg}") return False, None, error_msg else: error_msg = f"HTTP error {response.status_code}: {response.text}" print(error_msg) return False, None, error_msg except requests.exceptions.Timeout: error_msg = "Segmask render request timed out" print(error_msg) return False, None, error_msg except Exception as e: error_msg = f"Segmask render request failed: {e}" print(error_msg) return False, None, error_msg def _send_render_request(self, server_url: str, subjects_data: list, camera_data: dict) -> Image.Image: """Send a render request to a server and return the image.""" try: request_data = { "subjects_data": subjects_data, "camera_data": camera_data, "num_samples": 1 } print(f"passing {subjects_data = } to server at {server_url}") response = requests.post( f"{server_url}/render", json=request_data, timeout=self.timeout ) if response.status_code == 200: result = response.json() if result["success"]: # Decode base64 image img_data = base64.b64decode(result["image_base64"]) img = Image.open(io.BytesIO(img_data)) return img else: print(f"Render failed: {result.get('error_message', 'Unknown error')}") return self._create_error_image("red") else: print(f"HTTP error {response.status_code}: {response.text}") return self._create_error_image("orange") except requests.exceptions.Timeout: print("Render request timed out") return self._create_error_image("yellow") except Exception as e: print(f"Render request failed: {e}") return self._create_error_image("red") def _create_error_image(self, color: str) -> Image.Image: """Create a colored error image.""" return Image.new('RGB', (512, 512), color=color) # --- Scene Management Class --- class SceneManager: def __init__(self): self.objects = [] self.camera_elevation = 30.0 self.camera_lens = 50.0 self.surrounding_prompt = "" self.next_color_idx = 0 self.colors = [ (1.0, 0.0, 0.0), # red (0.0, 0.0, 1.0), # blue (0.0, 1.0, 0.0), # green (0.5, 0.0, 0.5), # purple (1.0, 0.5, 0.0), # orange (1.0, 1.0, 0.0), # yellow (0.0, 1.0, 1.0), # cyan (1.0, 0.0, 1.0), # magenta ] # Add inference parameters with defaults self.inference_params = { 'height': 512, 'width': 512, 'seed': 42, 'guidance_scale': 3.5, 'num_inference_steps': 25, 'checkpoint': config.CHECKPOINT_NAMES[0] if config.CHECKPOINT_NAMES else None } # Initialize BlenderRenderClient self.render_client = BlenderRenderClient() # Load asset dimensions self.asset_dimensions = self._load_asset_dimensions() def update_inference_params(self, height, width, seed, guidance_scale, num_steps, checkpoint): """Update inference parameters""" self.inference_params = { 'height': height, 'width': width, 'seed': seed, 'guidance_scale': guidance_scale, 'num_inference_steps': num_steps, 'checkpoint': checkpoint } def update_cuboid_description(self, obj_id, new_description): """Update the description of a cuboid""" if 0 <= obj_id < len(self.objects): if new_description.strip(): # Check not empty self.objects[obj_id]['description'] = new_description.strip() return True return False def save_scene_to_pkl(self, filepath=None): """Save current scene data to pkl file including inference parameters""" if filepath is None: # Auto-generate filename with timestamp, saved into configuration saved scenes dir save_dir = config.SAVED_SCENES_DIR os.makedirs(save_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filepath = os.path.join(save_dir, f"scene_{timestamp}.pkl") # Convert to the expected format subjects_data = [] for obj in self.objects: subject_dict = { 'name': obj['description'], 'type': obj['type'], # Save the object type 'dims': tuple(obj['size']), # (width, depth, height) 'x': [obj['position'][0] - 6.0], 'y': [obj['position'][1]], 'z': [obj['position'][2]], 'azimuth': [np.radians(obj['azimuth'])], # Convert to radians 'bbox': [(0, 0, 0, 0)] # Placeholder, can be computed if needed } subjects_data.append(subject_dict) camera_data = { 'camera_elevation': np.radians(self.camera_elevation), 'lens': self.camera_lens, 'global_scale': 1.0 # Default value } scene_dict = { 'subjects_data': subjects_data, 'camera_data': camera_data, 'surrounding_prompt': self.surrounding_prompt, 'inference_params': self.inference_params.copy() } try: with open(filepath, 'wb') as f: pickle.dump(scene_dict, f) return True, filepath, None except Exception as e: return False, None, str(e) def load_scene_from_pkl(self, filepath): """Load scene data from pkl file including inference parameters""" # Automatically look in saved_scenes if a basic filename was provided if os.path.sep not in filepath: filepath = os.path.join(config.SAVED_SCENES_DIR, filepath) try: with open(filepath, 'rb') as f: scene_dict = pickle.load(f) # Clear existing objects self.objects = [] self.next_color_idx = 0 # Load subjects subjects_data = scene_dict.get('subjects_data', []) for subject_dict in subjects_data: name = subject_dict.get('name', 'Loaded Object') asset_type = subject_dict.get('type', 'Custom') # Load the type dims = subject_dict.get('dims', (1.0, 1.0, 1.0)) x = float(subject_dict.get('x', [0.0])[0]) + 6.0 y = float(subject_dict.get('y', [0.0])[0]) z = float(subject_dict.get('z', [0.0])[0]) azimuth_rad = float(subject_dict.get('azimuth', [0.0])[0]) azimuth_deg = np.degrees(azimuth_rad) # Determine original_asset_size based on type if asset_type == "Custom" or asset_type not in self.asset_dimensions: original_asset_size = None else: # Look up the original asset dimensions asset_dims = self.asset_dimensions[asset_type] original_asset_size = [float(asset_dims[0]), float(asset_dims[1]), float(asset_dims[2])] # Create object obj_id = len(self.objects) size_list = [float(d) for d in dims] cuboid = { 'id': obj_id, 'description': name, 'type': asset_type, # Use the loaded type 'position': [x, y, z], 'size': size_list, 'original_asset_size': original_asset_size, # Restore from asset_dimensions 'azimuth': float(azimuth_deg), 'color': self._get_next_color() } self.objects.append(cuboid) # Load camera settings camera_data = scene_dict.get('camera_data', {}) camera_elev_rad = float(camera_data.get('camera_elevation', np.radians(30.0))) self.camera_elevation = float(np.degrees(camera_elev_rad)) self.camera_lens = float(camera_data.get('lens', 50.0)) # Load surrounding prompt self.surrounding_prompt = scene_dict.get('surrounding_prompt', '') # Load inference parameters loaded_inference_params = scene_dict.get('inference_params', {}) # Get checkpoint, fall back to first available if not found saved_checkpoint = loaded_inference_params.get('checkpoint') if saved_checkpoint and saved_checkpoint in config.CHECKPOINT_NAMES: checkpoint = saved_checkpoint else: checkpoint = config.CHECKPOINT_NAMES[0] if config.CHECKPOINT_NAMES else None if saved_checkpoint: print(f"Warning: Saved checkpoint '{saved_checkpoint}' not found, using '{checkpoint}' instead") self.inference_params = { 'height': loaded_inference_params.get('height', 512), 'width': loaded_inference_params.get('width', 512), 'seed': loaded_inference_params.get('seed', 42), 'guidance_scale': loaded_inference_params.get('guidance_scale', 3.5), 'num_inference_steps': loaded_inference_params.get('num_inference_steps', 25), 'checkpoint': checkpoint } return True, len(subjects_data), None except FileNotFoundError: return False, 0, f"File not found: {filepath}" except Exception as e: return False, 0, f"Error loading file: {str(e)}" def _load_asset_dimensions(self): """Load asset dimensions from pickle file""" pkl_path = "asset_dimensions.pkl" if os.path.exists(pkl_path): try: with open(pkl_path, 'rb') as f: return pickle.load(f) except Exception as e: print(f"Warning: Could not load asset dimensions: {e}") return {} else: print(f"Warning: asset_dimensions.pkl not found at {pkl_path}") return {} def get_asset_type_choices(self): """Get list of asset types for dropdown""" choices = ["Custom"] if self.asset_dimensions: choices.extend(sorted(self.asset_dimensions.keys())) return choices def _get_next_color(self): color = self.colors[self.next_color_idx % len(self.colors)] self.next_color_idx += 1 return color def harmonize_scales(self): """ Harmonize the scales of all non-Custom objects based on object scales. Always scales from original asset dimensions, ignoring any manual edits. Custom objects remain unchanged. """ if not self.objects: return "No objects to harmonize" # Find objects that can be harmonized (non-Custom with valid scales and original_asset_size) harmonizable_objects = [] for obj in self.objects: if (obj['type'] != "Custom" and obj['type'] in scales and obj['original_asset_size'] is not None): harmonizable_objects.append(obj) if not harmonizable_objects: return "No objects with defined scales to harmonize (all are Custom)" # Find the largest scale among harmonizable objects max_scale = max(scales[obj['type']] for obj in harmonizable_objects) if max_scale == 0: return "Invalid max scale (0)" # Harmonize each object by scaling from ORIGINAL ASSET dimensions for obj in harmonizable_objects: obj_scale = scales[obj['type']] scale_factor = obj_scale / max_scale # Scale from ORIGINAL ASSET dimensions, not current dimensions obj['size'][0] = obj['original_asset_size'][0] * scale_factor # width obj['size'][1] = obj['original_asset_size'][1] * scale_factor # depth obj['size'][2] = obj['original_asset_size'][2] * scale_factor # height # Update z position to keep object on ground obj['position'][2] = 0.0 return f"Harmonized {len(harmonizable_objects)} objects based on largest scale: {max_scale}" def add_cuboid(self, description="New Cuboid", asset_type="Custom"): """Add a cuboid with dimensions based on asset type""" obj_id = len(self.objects) # Determine dimensions based on asset type if asset_type == "Custom" or asset_type not in self.asset_dimensions: size = [1.0, 1.0, 1.0] # Default size original_asset_size = None # Custom objects have no original asset size else: # Load dimensions from pkl file dims = self.asset_dimensions[asset_type] size = [float(dims[0]), float(dims[1]), float(dims[2])] # [width, depth, height] original_asset_size = size.copy() # Store the original asset dimensions cuboid = { 'id': obj_id, 'description': description, 'type': asset_type, # Store the asset type 'position': [0.0, 0.0, 0.0], # Place on ground (z = height/2) 'size': size, 'original_asset_size': original_asset_size, # Store original asset dimensions 'azimuth': 0.0, 'color': self._get_next_color() } self.objects.append(cuboid) return obj_id def update_cuboid(self, obj_id, x, y, z, azimuth, width, depth, height): if 0 <= obj_id < len(self.objects): obj = self.objects[obj_id] obj['position'] = [x, y, z] obj['size'] = [width, depth, height] # Note: We do NOT update original_asset_size here - it stays unchanged obj['azimuth'] = azimuth return True return False def delete_cuboid(self, obj_id): if 0 <= obj_id < len(self.objects): del self.objects[obj_id] # Update IDs for remaining objects for i, obj in enumerate(self.objects): obj['id'] = i return True return False def set_camera_elevation(self, elevation_deg): assert type(elevation_deg) == float or type(elevation_deg) == int, f"{type(elevation_deg) = }" self.camera_elevation = np.clip(elevation_deg, 0.0, 90.0) return f"Camera elevation set to {elevation_deg}°" def set_camera_lens(self, lens_value): self.camera_lens = np.clip(lens_value, 10.0, 200.0) return f"Camera lens set to {lens_value}mm" def set_surrounding_prompt(self, prompt): # Add this method self.surrounding_prompt = prompt return f"Surrounding prompt updated" def _convert_to_blender_format(self): """Convert internal objects format to server expected format""" subjects_data = [] for obj in self.objects: subject_data = { 'subject_name': obj['description'], 'x': float(obj['position'][0]), 'y': float(obj['position'][1]), 'z': float(obj['position'][2]), 'azimuth': float(obj['azimuth']), 'width': float(obj['size'][0]), 'depth': float(obj['size'][1]), 'height': float(obj['size'][2]), 'base_color': obj['color'] } subjects_data.append(subject_data) camera_data = { 'camera_elevation': float(np.radians(self.camera_elevation)), 'lens': float(self.camera_lens), 'global_scale': 1.0 } return subjects_data, camera_data def render_cv_view(self, subjects_data: list, camera_data: dict) -> Image.Image: """Render only the CV view.""" if not subjects_data: return Image.new('RGB', (512, 512), color='gray') return self.render_client._send_render_request(self.render_client.cv_server_url, subjects_data, camera_data) def render_scene(self, width=512, height=512): """Render only CV view using the render client.""" print(f"calling render_scene") if not self.objects: # Return empty image if no objects empty_cv = Image.new('RGB', (width, height), color='gray') return empty_cv # Convert to server expected format subjects_data, camera_data = self._convert_to_blender_format() print(f"passing {subjects_data = } to render_cv_view in SceneManager") # Render CV view only cv_img = self.render_cv_view(subjects_data, camera_data) return cv_img # --- Gradio Interface Logic --- scene_manager = SceneManager() def get_cuboid_list_html(): """Generate HTML for the cuboid list with position-based colors""" if not scene_manager.objects: return "
No cuboids yet. Add one to get started!
" html = "
" for obj_idx, obj in enumerate(scene_manager.objects): # Get position-based color # x, y = obj['position'][0], obj['position'][1] # rgb_color = map_point_to_rgb(x, y) rgb_color = COLORS[obj_idx % len(COLORS)] hex_color = rgb_to_hex(rgb_color) # Create a lighter version for gradient end lighter_rgb = tuple(min(1.0, c + 0.2) for c in rgb_color) lighter_hex = rgb_to_hex(lighter_rgb) html += f"""
{obj['description']}
Pos: ({obj['position'][0]:.1f}, {obj['position'][1]:.1f}, {obj['position'][2]:.1f}) | Size: {obj['size'][0]:.1f}×{obj['size'][1]:.1f}×{obj['size'][2]:.1f}
""" html += "
" return html def make_radio_choices(): """Generate unique radio button labels for all cuboids.""" return [f"[{i}] {obj['description']}" for i, obj in enumerate(scene_manager.objects)] def find_obj_by_radio(selected_name): """Extract cuboid index from radio label. Returns (obj_id, obj) or (None, None).""" if not selected_name or not selected_name.startswith("["): return None, None try: idx = int(selected_name[1:selected_name.index("]")]) if 0 <= idx < len(scene_manager.objects): return idx, scene_manager.objects[idx] except (ValueError, IndexError): pass return None, None def make_radio_value(obj_id): """Generate radio label for a specific cuboid by index.""" if 0 <= obj_id < len(scene_manager.objects): return f"[{obj_id}] {scene_manager.objects[obj_id]['description']}" return None def add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens): """Add a new cuboid""" if not description_input.strip(): description_input = "New Cuboid" new_id = scene_manager.add_cuboid(description_input, asset_type) cv_img = scene_manager.render_scene() # Create choices for radio buttons choices = make_radio_choices() # Get the new object data new_obj = scene_manager.objects[new_id] return ( gr.update(value=""), # Clear description input gr.update(value="Custom"), # Reset type dropdown to Custom cv_img, get_cuboid_list_html(), gr.update(choices=choices, value=make_radio_value(new_id)), # Radio with new selection gr.update(visible=True), # Show editor gr.update(value=new_obj['description']), # Set description in editor gr.update(value=round(new_obj['position'][0], 2)), gr.update(value=round(new_obj['position'][1], 2)), gr.update(value=round(new_obj['position'][2], 2)), gr.update(value=new_obj['azimuth']), gr.update(value=round(new_obj['size'][0], 2)), gr.update(value=round(new_obj['size'][1], 2)), gr.update(value=round(new_obj['size'][2], 2)), gr.update(value=1.0) # Reset scale to 1.0 ) def select_cuboid_event(selected_name): """When a cuboid is selected from radio buttons""" if not selected_name: return [gr.update(visible=False)] + [gr.update() for _ in range(9)] # Changed from 8 to 9 # Find the cuboid by radio label _, obj = find_obj_by_radio(selected_name) if obj is None: return [gr.update(visible=False)] + [gr.update() for _ in range(9)] return ( gr.update(visible=True), # Show editor gr.update(value=obj['description']), # Set description gr.update(value=round(obj['position'][0], 2)), gr.update(value=round(obj['position'][1], 2)), gr.update(value=round(obj['position'][2], 2)), gr.update(value=obj['azimuth']), gr.update(value=round(obj['size'][0], 2)), gr.update(value=round(obj['size'][1], 2)), gr.update(value=round(obj['size'][2], 2)), gr.update(value=1.0) # Reset scale to 1.0 ) def delete_selected_cuboid(selected_name, camera_elevation, camera_lens): """Delete the currently selected cuboid""" if not selected_name: return gr.update(), get_cuboid_list_html(), gr.update(), gr.update(visible=False) # Find and delete the cuboid obj_id, _ = find_obj_by_radio(selected_name) if obj_id is not None: scene_manager.delete_cuboid(obj_id) cv_img = scene_manager.render_scene() # Update choices choices = make_radio_choices() return ( cv_img, get_cuboid_list_html(), gr.update(choices=choices, value=None), gr.update(visible=False) ) def update_cuboid_event(selected_name, camera_elevation, camera_lens, description, x, y, z, azimuth, width, depth, height, scale): """Update the selected cuboid including description and scale""" scene_manager.set_camera_elevation(camera_elevation) scene_manager.set_camera_lens(camera_lens) if selected_name: # Find the cuboid by radio label obj_id, _ = find_obj_by_radio(selected_name) if obj_id is not None: # Update description first if changed if description.strip() and description.strip() != selected_name: scene_manager.update_cuboid_description(obj_id, description.strip()) # Apply scale to dimensions scaled_width = width * scale scaled_depth = depth * scale scaled_height = height * scale # Update other properties with scaled dimensions scene_manager.update_cuboid(obj_id, x, y, z, azimuth, scaled_width, scaled_depth, scaled_height) # Get updated object for return updated_obj = scene_manager.objects[obj_id] new_name = updated_obj['description'] cv_img = scene_manager.render_scene() # Update choices with new descriptions choices = make_radio_choices() # Return updated HTML, image, radio choices, new selection, updated sliders, and reset scale return ( get_cuboid_list_html(), cv_img, gr.update(choices=choices, value=make_radio_value(obj_id) if obj_id is not None else None), gr.update(value=round(scaled_width, 2) if obj_id is not None else round(width, 2)), # Update width slider gr.update(value=round(scaled_depth, 2) if obj_id is not None else round(depth, 2)), # Update depth slider gr.update(value=round(scaled_height, 2) if obj_id is not None else round(height, 2)), # Update height slider gr.update(value=1.0) # Reset scale to 1.0 ) def camera_change_event(camera_elevation, camera_lens): """Handle camera control changes""" scene_manager.set_camera_elevation(camera_elevation) scene_manager.set_camera_lens(camera_lens) cv_img = scene_manager.render_scene() return cv_img def surrounding_prompt_change_event(prompt_text): # Add this function """Handle surrounding prompt changes""" scene_manager.set_surrounding_prompt(prompt_text) return None # No visual update needed def render_segmask_event(camera_elevation, camera_lens, surrounding_prompt): """Render segmentation masks for all objects""" if not scene_manager.objects: return "⚠️ No objects to render", gr.update(visible=False), [] # Get subject descriptions subject_descriptions = [obj['description'] for obj in scene_manager.objects] # Now you have access to: # - surrounding_prompt: the text from surrounding_prompt_input # - subject_descriptions: list of all subject descriptions print(f"Surrounding prompt: {surrounding_prompt}") print(f"Subject descriptions: {subject_descriptions}") placeholder_prompt = "a photo of PLACEHOLDER " + surrounding_prompt # Create placeholder text subject_embeds = [] for subject_idx, subject_desc in enumerate(subject_descriptions): input_ids = tokenizer.encode(subject_desc, return_tensors="pt", max_length=77)[0] subject_embed = {"input_ids_t5": input_ids.tolist()} subject_embeds.append(subject_embed) placeholder_text = "" for subject in subject_descriptions[:-1]: placeholder_text = placeholder_text + f" {subject} and " for subject in subject_descriptions[-1:]: placeholder_text = placeholder_text + f" {subject}" placeholder_text = placeholder_text.strip() placeholder_token_prompt = placeholder_prompt.replace("PLACEHOLDER", placeholder_text) call_ids = get_call_ids_from_placeholder_prompt_flux(prompt=placeholder_token_prompt, subjects=subject_descriptions, subjects_embeds=subject_embeds, debug=True ) print(f"Generated call IDs: {call_ids}") # Convert to server expected format subjects_data, camera_data = scene_manager._convert_to_blender_format() # You can add the prompt and descriptions to the request if needed # For example, add to subjects_data or camera_data before sending # Render segmentation masks success, segmask_images, error_msg = scene_manager.render_client.render_segmasks(subjects_data, camera_data) # copy all the data to the correct location root_save_dir = config.GRADIO_FILES_DIR os.makedirs(root_save_dir, exist_ok=True) os.system(f"rm {root_save_dir}/*") shutil.move("cv_render.jpg", osp.join(root_save_dir, "cv_render.jpg")) for subject_idx in range(len(subject_descriptions)): shutil.move(f"{str(subject_idx).zfill(3)}_segmask_cv.png", osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png")) jsonl = [{ "cv": osp.join(root_save_dir, "cv_render.jpg"), "target": osp.join(root_save_dir, "cv_render.jpg"), "cuboids_segmasks": [osp.join(root_save_dir, f"main__segmask_{str(subject_idx).zfill(3)}__{1.00}.png") for subject_idx in range(len(subject_descriptions))], "PLACEHOLDER_prompts": placeholder_prompt, "subjects": subject_descriptions, "call_ids": call_ids, }] with open(osp.join(root_save_dir, "cuboids.jsonl"), "w") as f: for item in jsonl: f.write(json.dumps(item) + "\n") if success: return ( f"✅ Successfully rendered {len(segmask_images)} segmentation masks", gr.update(visible=True), segmask_images ) else: return ( f"❌ Failed to render segmentation masks: {error_msg}", gr.update(visible=False), [] ) def harmonize_event(selected_name, camera_elevation, camera_lens): """Harmonize all object scales and update the scene""" message = scene_manager.harmonize_scales() print(message) cv_img = scene_manager.render_scene() # If a cuboid is selected, update its sliders if selected_name: _, obj = find_obj_by_radio(selected_name) if obj is not None: return ( cv_img, get_cuboid_list_html(), gr.update(value=round(obj['position'][0], 2)), gr.update(value=round(obj['position'][1], 2)), gr.update(value=round(obj['position'][2], 2)), gr.update(value=obj['azimuth']), gr.update(value=round(obj['size'][0], 2)), gr.update(value=round(obj['size'][1], 2)), gr.update(value=round(obj['size'][2], 2)) ) # No object selected or object not found return ( cv_img, get_cuboid_list_html(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() ) def save_scene_event(): """Save the current scene to a pkl file""" success, filepath, error = scene_manager.save_scene_to_pkl() if success: return f"✅ Scene saved successfully to: {filepath}\n📋 Saved parameters: {scene_manager.inference_params}" else: return f"❌ Failed to save scene: {error}" def load_scene_event(filepath): """Load a scene from a pkl file and restore all parameters""" if not filepath.strip(): return ( "⚠️ Please enter a file path", gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), # surrounding_prompt gr.update(), # checkpoint gr.update(), # image_size gr.update(), # seed gr.update(), # guidance gr.update() # steps ) success, num_objects, error = scene_manager.load_scene_from_pkl(filepath) if success: # Re-render the scene cv_img = scene_manager.render_scene() # Update UI components choices = make_radio_choices() params_msg = f"✅ Scene loaded: {num_objects} objects\n📋 Restored parameters: {scene_manager.inference_params}" return ( params_msg, cv_img, get_cuboid_list_html(), gr.update(choices=choices, value=None), gr.update(visible=False), gr.update(value=scene_manager.camera_elevation), gr.update(value=scene_manager.camera_lens), gr.update(value=scene_manager.surrounding_prompt), gr.update(value=scene_manager.inference_params['checkpoint']), gr.update(value=scene_manager.inference_params['height']), gr.update(value=scene_manager.inference_params['seed']), gr.update(value=scene_manager.inference_params['guidance_scale']), gr.update(value=scene_manager.inference_params['num_inference_steps']) ) else: return ( f"❌ {error}", gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update() ) # --- Gradio UI Layout --- with gr.Blocks( theme=gr.themes.Soft( primary_hue="slate", secondary_hue="slate", neutral_hue="slate" ), css=""" /* ── Global ── */ .gradio-container { background: linear-gradient(160deg, #0f0f1a 0%, #161625 40%, #1a1a2e 100%) !important; color: #e0e0e8 !important; font-family: 'Arial', sans-serif !important; } * { font-family: 'Arial', sans-serif !important; } /* ── Cards / blocks ── */ .block { background: rgba(22, 22, 38, 0.65) !important; border: none !important; border-radius: 10px !important; box-shadow: 0 1px 4px rgba(0,0,0,0.30) !important; } .form { background: transparent !important; border: none !important; } /* ── Typography ── */ h1, h2, h3, h4, h5, h6 { color: #f0f0f5 !important; } .markdown { color: #d0d0da !important; } label { color: #b0b0be !important; } /* ── Buttons ── */ .gr-button { background: linear-gradient(135deg, #3a3a5c, #4a4a6e) !important; border: none !important; color: #ffffff !important; border-radius: 6px !important; transition: background 0.2s ease !important; } .gr-button:hover { background: linear-gradient(135deg, #4a4a6e, #5a5a80) !important; } /* ── Inputs ── */ .gr-input, .gr-textbox, .gr-dropdown, textarea, input[type="text"], input[type="number"] { background: rgba(30, 30, 50, 0.7) !important; border: 1px solid rgba(255,255,255,0.08) !important; color: #e0e0e8 !important; border-radius: 6px !important; } .gr-input:focus, .gr-textbox:focus, textarea:focus, input[type="text"]:focus, input[type="number"]:focus { border-color: rgba(130, 130, 200, 0.5) !important; background: rgba(35, 35, 58, 0.8) !important; } /* ── Slider track ── */ input[type="range"] { -webkit-appearance: none !important; appearance: none !important; height: 3px !important; width: calc(100% - 16px) !important; margin-left: 8px !important; margin-right: 8px !important; background: rgba(255,255,255,0.12) !important; border-radius: 2px !important; outline: none !important; overflow: visible !important; } input[type="range"]::-webkit-slider-thumb { -webkit-appearance: none !important; appearance: none !important; width: 16px !important; height: 16px !important; border-radius: 50% !important; background: #8888bb !important; cursor: pointer !important; border: 2px solid #c8c8e0 !important; box-sizing: border-box !important; } input[type="range"]::-moz-range-thumb { width: 16px !important; height: 16px !important; border-radius: 50% !important; background: #8888bb !important; cursor: pointer !important; border: 2px solid #c8c8e0 !important; box-sizing: border-box !important; } input[type="range"]::-moz-range-track { height: 3px !important; background: rgba(255,255,255,0.12) !important; border-radius: 2px !important; } /* ── Gradio custom slider thumbs ── */ .range_slider .thumb, .range_slider .thumb::after, .range_slider [role="slider"], .noUi-handle, .slider .thumb, .gradio-slider .thumb { width: 16px !important; height: 16px !important; border-radius: 50% !important; background: #8888bb !important; border: 2px solid #c8c8e0 !important; box-sizing: border-box !important; cursor: pointer !important; } /* ── Hide ugly scrollbars on slider containers ── */ .gradio-slider, .gradio-slider > div, .gradio-slider > div > div, .wrap, .range_slider, div[class*="slider"], div[class*="Slider"], input[type="range"], .gradio-slider *, .axis-box-red *, .axis-box-green *, .axis-box-white * { overflow: visible !important; overflow-x: hidden !important; scrollbar-width: none !important; -ms-overflow-style: none !important; } .gradio-slider::-webkit-scrollbar, .gradio-slider > div::-webkit-scrollbar, .gradio-slider > div > div::-webkit-scrollbar, .gradio-slider *::-webkit-scrollbar, .wrap::-webkit-scrollbar, .range_slider::-webkit-scrollbar, div[class*="slider"]::-webkit-scrollbar, div[class*="Slider"]::-webkit-scrollbar, .axis-box-red *::-webkit-scrollbar, .axis-box-green *::-webkit-scrollbar, .axis-box-white *::-webkit-scrollbar { display: none !important; width: 0 !important; height: 0 !important; } /* Number input boxes next to sliders */ .gradio-slider input[type="number"] { overflow: hidden !important; } /* ── Axis slider boxes ── */ .axis-box-red { background: rgba(40, 20, 20, 0.5) !important; border: none !important; border-left: 3px solid #e05555 !important; border-radius: 8px !important; padding: 10px 14px !important; overflow: visible !important; } .axis-box-red input[type="range"]::-webkit-slider-thumb { background: #e05555 !important; border-color: #ffaaaa !important; border-radius: 50% !important; } .axis-box-red input[type="range"]::-moz-range-thumb { background: #e05555 !important; border-color: #ffaaaa !important; border-radius: 50% !important; } .axis-box-red .range_slider .thumb, .axis-box-red .range_slider [role="slider"] { background: #e05555 !important; border-color: #ffaaaa !important; border-radius: 50% !important; } .axis-box-green { background: rgba(20, 40, 20, 0.5) !important; border: none !important; border-left: 3px solid #55cc55 !important; border-radius: 8px !important; padding: 10px 14px !important; overflow: visible !important; } .axis-box-green input[type="range"]::-webkit-slider-thumb { background: #55cc55 !important; border-color: #aaffaa !important; border-radius: 50% !important; } .axis-box-green input[type="range"]::-moz-range-thumb { background: #55cc55 !important; border-color: #aaffaa !important; border-radius: 50% !important; } .axis-box-green .range_slider .thumb, .axis-box-green .range_slider [role="slider"] { background: #55cc55 !important; border-color: #aaffaa !important; border-radius: 50% !important; } .axis-box-white { background: rgba(40, 40, 40, 0.5) !important; border: none !important; border-left: 3px solid #cccccc !important; border-radius: 8px !important; padding: 10px 14px !important; overflow: visible !important; } .axis-box-white input[type="range"]::-webkit-slider-thumb { background: #cccccc !important; border-color: #ffffff !important; border-radius: 50% !important; } .axis-box-white input[type="range"]::-moz-range-thumb { background: #cccccc !important; border-color: #ffffff !important; border-radius: 50% !important; } .axis-box-white .range_slider .thumb, .axis-box-white .range_slider [role="slider"] { background: #cccccc !important; border-color: #ffffff !important; border-radius: 50% !important; } .gr-radio label { color: #b0b0be !important; } .gr-panel { background: rgba(22, 22, 38, 0.5) !important; border: none !important; } """ ) as demo: gr.Markdown("# [CVPR-2026] Occlusion Aware 3D Control in Text-to-Image Generation") # TOP ROW with gr.Row(): # TOP LEFT - Edit Properties with gr.Column(scale=1): # Add description textbox at the top # with gr.Column(visible=False) as editor_section: # gr.Markdown("## ✏️ Edit Properties") # delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm") # with gr.Row(): # edit_x = gr.Slider(-10, 10, value=0, step=0.1, label="X") # edit_y = gr.Slider(-10, 10, value=0, step=0.1, label="Y") # edit_z = gr.Slider(0, 10, value=1, step=0.1, label="Z") # edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)") # with gr.Row(): # edit_width = gr.Slider(0.1, 5, value=1, step=0.1, label="Width") # edit_depth = gr.Slider(0.1, 5, value=1, step=0.1, label="Depth") # edit_height = gr.Slider(0.1, 5, value=1, step=0.1, label="Height") with gr.Column(visible=False) as editor_section: gr.Markdown("## ✏️ Edit Properties") edit_description = gr.Textbox( label="Description", placeholder="Enter object description", info="Description cannot be empty" ) delete_btn = gr.Button("❌ Delete Selected Cuboid", variant="stop", size="sm") with gr.Row(): with gr.Column(elem_classes=["axis-box-red"], min_width=120): gr.Markdown("**Away / Towards Camera**") edit_x = gr.Slider(-10, 10, value=0, step=0.01, label="", show_label=False) with gr.Column(elem_classes=["axis-box-green"], min_width=120): gr.Markdown("**Left / Right**") edit_y = gr.Slider(-10, 10, value=0, step=0.01, label="", show_label=False) with gr.Column(elem_classes=["axis-box-white"], min_width=120): gr.Markdown("**Up / Down**") edit_z = gr.Slider(0, 10, value=1, step=0.01, label="", show_label=False) edit_azimuth = gr.Slider(-180, 180, value=0, step=1, label="Azimuth (°)") with gr.Row(): edit_width = gr.Slider(0.0, 5, value=1, step=0.01, label="Width") edit_depth = gr.Slider(0.0, 5, value=1, step=0.01, label="Depth") edit_height = gr.Slider(0.0, 5, value=1, step=0.01, label="Height") # Add scale slider edit_scale = gr.Slider( 0.0, 3.0, value=1.0, step=0.01, label="Scale", info="Multiplier for all dimensions (resets to 1.0 after update)" ) # Add the Update Scene button update_scene_btn = gr.Button("🔄 Update Scene", variant="primary", size="sm") # TOP MIDDLE - Camera View with gr.Column(scale=1): gr.Markdown("## 🧊 Layout Visualization") cv_image_output = gr.Image(label="Camera View", height=400) # TOP RIGHT - Generated Image with gr.Column(scale=1): gr.Markdown("## 🎨 Generated Image") generated_image_output = gr.Image(label="Generated Image", height=400) # BOTTOM ROW with gr.Row(): # BOTTOM LEFT - Cuboid List and Selection with gr.Column(scale=1): gr.Markdown("## 📦 Scene Objects") cuboid_list_html = gr.HTML(get_cuboid_list_html()) gr.Markdown("### Select Object to Edit") cuboid_radio = gr.Radio(choices=[], label="", visible=True) # BOTTOM RIGHT - Camera Controls and Add New Cuboid with gr.Column(scale=2): with gr.Row(): with gr.Column(): gr.Markdown("## Global Scene Controls") camera_elevation_slider = gr.Slider(0, 90, value=20, label="Camera Elevation (degrees)") camera_lens_slider = gr.Slider(10, 200, value=50, label="Camera Lens (mm)") # Add surrounding prompt textbox surrounding_prompt_input = gr.Textbox( placeholder="e.g., in a forest, in a city, on a beach", label="Surrounding Prompt", info="Describe the surrounding environment" ) gr.Markdown("## 🔧 Scene Tools") harmonize_btn = gr.Button("⚖️ Adjust Object Scales", variant="secondary") # Save/Load Section gr.Markdown("## 💾 Save/Load Scene") with gr.Row(): save_scene_btn = gr.Button("💾 Save Scene", variant="secondary") load_scene_btn = gr.Button("📂 Load Scene", variant="secondary") load_path_input = gr.Textbox( placeholder="/path/to/scene.pkl", label="Load Scene Path", info="Enter path to pkl file to load" ) save_load_status = gr.Markdown("") example_files = [] for i in range(5): webp_path = os.path.join(config.SAVED_SCENES_DIR, f"example{i}.webp") pkl_name = f"example{i}.pkl" if os.path.exists(webp_path): example_files.append((webp_path, pkl_name)) if example_files: gr.Markdown("## 🖼️ Examples") example_gallery = gr.Gallery( value=[img for img, pkl in example_files], label="Click an example to load its scene", show_label=True, columns=max(len(example_files), 1), rows=1, allow_preview=False, object_fit="contain", height=120 ) example_gallery_state = gr.State([pkl for img, pkl in example_files]) with gr.Column(): gr.Markdown("## ➕ Add New Object") add_cuboid_description_input = gr.Textbox(placeholder="Enter cuboid description", label="Description") asset_type_dropdown = gr.Dropdown( choices=scene_manager.get_asset_type_choices(), value="Custom", label="Type", info="Select asset type to load dimensions, or choose Custom" ) add_cuboid_btn = gr.Button("Add Object", variant="primary") generate_btn = gr.Button("🎨 Generate Image", variant="primary") # Add checkpoint dropdown checkpoint_dropdown = gr.Dropdown( choices=config.CHECKPOINT_NAMES, value=config.CHECKPOINT_NAMES[0] if config.CHECKPOINT_NAMES else None, label="Checkpoint", info="Select model checkpoint for generation" ) # Inference Parameters gr.Markdown("### Inference Parameters") inference_image_size = gr.Slider( minimum=256, maximum=1024, value=1024, step=64, label="Image Size" ) inference_seed = gr.Number( value=42, label="Random Seed", precision=0 ) inference_guidance = gr.Slider( minimum=1.0, maximum=10.0, value=3.5, step=0.5, label="Guidance Scale" ) inference_steps = gr.Slider( minimum=10, maximum=50, value=25, step=1, label="Inference Steps" ) # Event Handlers def add_cuboid_with_auto_update(description_input, asset_type, camera_elevation, camera_lens): """Add cuboid and auto-update scene""" result = add_cuboid_event(description_input, asset_type, camera_elevation, camera_lens) return result # Update add_cuboid_btn.click event handler (around line 850): add_cuboid_btn.click( add_cuboid_with_auto_update, inputs=[add_cuboid_description_input, asset_type_dropdown, camera_elevation_slider, camera_lens_slider], outputs=[ add_cuboid_description_input, asset_type_dropdown, cv_image_output, cuboid_list_html, cuboid_radio, editor_section, edit_description, edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height, edit_scale # Add this ] ) # Update the cuboid_radio.change event handler (around line 860): cuboid_radio.change( select_cuboid_event, inputs=[cuboid_radio], outputs=[ editor_section, edit_description, edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height, edit_scale # Add this ] ) delete_btn.click( delete_selected_cuboid, inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider], outputs=[cv_image_output, cuboid_list_html, cuboid_radio, editor_section] ) # Save/Load handlers save_scene_btn.click( save_scene_event, inputs=[], outputs=[save_load_status] ) load_scene_btn.click( load_scene_event, inputs=[load_path_input], outputs=[ save_load_status, cv_image_output, cuboid_list_html, cuboid_radio, editor_section, camera_elevation_slider, camera_lens_slider, surrounding_prompt_input, checkpoint_dropdown, inference_image_size, inference_seed, inference_guidance, inference_steps ] ) def load_from_gallery(state, evt: gr.SelectData): pkl_name = state[evt.index] return load_scene_event(pkl_name) if 'example_gallery' in locals(): example_gallery.select( load_from_gallery, inputs=[example_gallery_state], outputs=[ save_load_status, cv_image_output, cuboid_list_html, cuboid_radio, editor_section, camera_elevation_slider, camera_lens_slider, surrounding_prompt_input, checkpoint_dropdown, inference_image_size, inference_seed, inference_guidance, inference_steps ] ) # Auto-update scene when sliders change # for slider in [edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height]: # slider.change( # update_cuboid_event, # inputs=[ # cuboid_radio, # camera_elevation_slider, # camera_lens_slider, # edit_x, edit_y, edit_z, # edit_azimuth, # edit_width, edit_depth, edit_height # ], # outputs=[cuboid_list_html, cv_image_output] # ) # Update the update_scene_btn.click event handler (around line 920): update_scene_btn.click( update_cuboid_event, inputs=[ cuboid_radio, camera_elevation_slider, camera_lens_slider, edit_description, edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height, edit_scale # Add this ], outputs=[ cuboid_list_html, cv_image_output, cuboid_radio, edit_width, # Add this edit_depth, # Add this edit_height, # Add this edit_scale # Add this (to reset to 1.0) ] ) # Update generate button click handler generate_btn.click( generate_image_event, inputs=[ camera_elevation_slider, camera_lens_slider, surrounding_prompt_input, checkpoint_dropdown, inference_image_size, inference_seed, inference_guidance, inference_steps ], outputs=[save_load_status, cv_image_output, generated_image_output] ) harmonize_btn.click( harmonize_event, inputs=[cuboid_radio, camera_elevation_slider, camera_lens_slider], outputs=[ cv_image_output, cuboid_list_html, edit_x, edit_y, edit_z, edit_azimuth, edit_width, edit_depth, edit_height ] ) # Camera controls for control in [camera_elevation_slider, camera_lens_slider]: control.change( camera_change_event, inputs=[camera_elevation_slider, camera_lens_slider], outputs=[cv_image_output] ) # Surrounding prompt control surrounding_prompt_input.change( surrounding_prompt_change_event, inputs=[surrounding_prompt_input], outputs=[] ) # Initial render def initial_render(): cv_img = scene_manager.render_scene() gen_img = Image.new('RGB', (512, 512), color='white') return cv_img, gen_img demo.load( initial_render, outputs=[cv_image_output, generated_image_output] ) if __name__ == "__main__": import os from urllib.parse import urlparse import config cv_port = urlparse(config.BLENDER_CV_SERVER_URL).port final_port = urlparse(config.BLENDER_FINAL_SERVER_URL).port seg_port = urlparse(config.BLENDER_SEGMASK_SERVER_URL).port os.system(f"./launch_blender_backend.sh {cv_port} {final_port} {seg_port} 5004 &") # Initialize inference engine (load model once at startup) initialize_inference_engine(base_model_path=config.PRETRAINED_MODEL_NAME_OR_PATH) demo.launch(share=True)