# SPDX-FileCopyrightText: 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics # SPDX-License-Identifier: Apache-2.0 # See the LICENSE file in the project root for full license information. from gradio_image_prompter import ImagePrompter import gradio as gr import spaces import os import uuid from typing import Any, List, Optional, Union import cv2 import torch import numpy as np from PIL import Image import trimesh import random import imageio from einops import repeat from huggingface_hub import snapshot_download from moge.model.v2 import MoGeModel from transformers import AutoModelForMaskGeneration, AutoProcessor from scripts.grounding_sam import plot_segmentation, segment import copy import shutil import time from concurrent.futures import ThreadPoolExecutor MARKDOWN = """ ## Image to 3D Scene with [3D-Fixer](https://zx-yin.github.io/3dfixer/) 1. Upload an image, and draw bounding boxes for each instance by holding and dragging the mouse. Then click "Run Segmentation" to generate the segmentation result. 2. If you find the generated 3D scene satisfactory, download it by clicking the "Download scene GLB" button, and you can also download each islolated 3D instance. 3. In this implementation, we generate each instances one by one, and update the scene results at the "Generated GLB" area, besides, we display isolated instances below. 4. It may take some time to download the ckpts, and compile the gsplat. Thank you for your patience to wait. We recommend to deploy the demo locally. 
""" MAX_SEED = np.iinfo(np.int32).max TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp") EXAMPLE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets/example_data") DTYPE = torch.float16 DEVICE = "cpu" VALID_RATIO_THRESHOLD = 0.005 CROP_SIZE = 518 work_space = None generated_object_map = {} ############## 3D-Fixer model model_dir = 'HorizonRobotics/3D-Fixer' local_dir = "./checkpoints/3D-Fixer" os.makedirs(local_dir, exist_ok=True) snapshot_download(repo_id=model_dir, local_dir=local_dir) ############## 3D-Fixer model save_projected_colored_pcd = lambda pts, pts_color, fpath: trimesh.PointCloud(pts.reshape(-1, 3), pts_color.reshape(-1, 3)).export(fpath) EXAMPLES = [ [ { "image": "assets/example_data/scene1/rgb.png", }, "assets/example_data/scene1/seg.png", 1024, False, 25, 5.5, 0.8, 1.0, 5.0 # num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale ], [ { "image": "assets/example_data/scene2/rgb.png", }, "assets/example_data/scene2/seg.png", 1, False, 25, 5.0, 0.8, 1.0, 5.0 ], [ { "image": "assets/example_data/scene3/rgb.png", }, "assets/example_data/scene3/seg.png", 1, False, 25, 5.0, 0.8, 1.0, 5.0 ], [ { "image": "assets/example_data/scene4/rgb.png", }, "assets/example_data/scene4/seg.png", 42, False, 25, 5.0, 0.8, 1.0, 5.0 ], [ { "image": "assets/example_data/scene5/rgb.png", }, "assets/example_data/scene5/seg.png", 1, False, 25, 5.0, 0.8, 1.0, 5.0 ], [ { "image": "assets/example_data/scene6/rgb.png", }, "assets/example_data/scene6/seg.png", 1, False, 25, 5.0, 0.8, 1.0, 5.0 ] ] def cleanup_tmp(tmp_root: str = "./tmp", expire_seconds: int = 3600) -> None: """ 删除 tmp_root 下超过 expire_seconds 未更新的旧子目录。 Args: tmp_root: 临时目录根路径。 expire_seconds: 过期时间,默认 3600 秒(1 小时)。 """ tmp_root = os.path.abspath(tmp_root) if not os.path.isdir(tmp_root): return now = time.time() for name in os.listdir(tmp_root): path = os.path.join(tmp_root, name) # 只清理子目录,不动散落文件 if not os.path.isdir(path): continue try: mtime = 
os.path.getmtime(path)
            age = now - mtime
            if age > expire_seconds:
                shutil.rmtree(path, ignore_errors=False)
                print(f"[cleanup_tmp] removed old directory: {path}")
        except Exception as e:
            # Best-effort cleanup: a directory in use or already gone is not fatal.
            print(f"[cleanup_tmp] failed to remove {path}: {e}")

# run seg on CPU
def run_segmentation(
    image_prompts: Any,
    polygon_refinement: bool = True,
) -> Image.Image:
    """Segment user-drawn boxes with SAM and create a fresh workspace.

    NOTE(review): despite the ``-> Image.Image`` annotation, this returns a
    tuple ``(segmentation PIL image, workspace dict)`` — the workspace holds
    the per-session temp directory under TMP_DIR.
    """
    rgb_image = image_prompts["image"].convert("RGB")
    # SAM model is a module-level global loaded elsewhere in this file.
    global sam_segmentator
    device = "cpu"
    # On CPU, force float32 (DTYPE is float16, which is only used on CUDA).
    sam_segmentator.to(device=device, dtype=DTYPE if device == 'cuda' else torch.float32)
    # pre-process the layers and get the xyxy boxes of each layer
    if len(image_prompts["points"]) == 0:
        raise gr.Error("No points provided for segmentation. Please add points to the image.")
    # ImagePrompter box points are (x1, y1, type, x2, y2, type); pick the coords.
    boxes = [
        [
            [int(box[0]), int(box[1]), int(box[3]), int(box[4])]
            for box in image_prompts["points"]
        ]
    ]
    with torch.no_grad():
        # NOTE(review): ``boxes`` is already wrapped in one outer list; passing
        # ``[boxes]`` adds a third nesting level — verify against segment()'s
        # expected box format.
        detections = segment(
            sam_processor,
            sam_segmentator,
            rgb_image,
            boxes=[boxes],
            polygon_refinement=polygon_refinement,
        )
    seg_map_pil = plot_segmentation(rgb_image, detections)
    # Drop stale session directories (older than 1 hour) before creating a new one.
    cleanup_tmp(TMP_DIR, expire_seconds=3600)
    work_space = {
        "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
    }
    os.makedirs(work_space["dir"], exist_ok=True)
    seg_map_pil.save(os.path.join(work_space["dir"], "mask.png"))
    return seg_map_pil, work_space

@spaces.GPU
def run_depth_estimation(
    image_prompts: Any,
    seg_image: Union[str, Image.Image],
    work_space: dict,
) -> Image.Image:
    """Estimate metric depth with MoGe-v2 and export a colored scene point cloud.

    NOTE(review): returns ``(glb path, workspace dict)``, not a PIL image as the
    annotation suggests.
    """
    rgb_image = image_prompts["image"].convert("RGB")
    from threeDFixer.datasets.utils import (
        normalize_vertices,
        project2ply
    )
    # NOTE(review): this forced square 1024x1024 resize distorts aspect ratio
    # before the aspect-preserving downscale below — confirm this is intended.
    rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
    # MoGe depth model is a module-level global loaded elsewhere in this file.
    global moge_v2_dpt_model
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dtype = torch.float16 if device == 'cuda' else torch.float32
    moge_v2_dpt_model = moge_v2_dpt_model.to(device=device, dtype=dtype)
    # Allow calling step 2 directly: create a workspace if step 1 was skipped.
    if work_space is None:
        work_space = {
            "dir": os.path.join(TMP_DIR, f"work_space_{uuid.uuid4()}"),
        }
        os.makedirs(work_space["dir"], exist_ok=True)
    origin_W, origin_H = rgb_image.size
    if 
max(origin_H, origin_W) > 1024:
        # Downscale so the longest side is at most 1024, preserving aspect ratio.
        factor = max(origin_H, origin_W) / 1024
        H = int(origin_H // factor)
        W = int(origin_W // factor)
        rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
    W, H = rgb_image.size
    # HWC uint8 -> CHW float tensor in [0, 1] for MoGe inference.
    input_image = np.array(rgb_image).astype(np.float32)
    input_image = torch.tensor(input_image / 255, dtype=torch.float32, device=device).permute(2, 0, 1)
    with torch.no_grad():
        output = moge_v2_dpt_model.infer(input_image)
    depth = output['depth']
    intrinsics = output['intrinsics']
    # Mark NaN/Inf depth as invalid and zero it out so projection stays finite.
    invalid_mask = torch.logical_or(torch.isnan(depth), torch.isinf(depth))
    depth_mask = ~invalid_mask
    depth = torch.where(invalid_mask, 0.0, depth)
    # De-normalize intrinsics to pixel units; principal point at image center.
    K = torch.from_numpy(
        np.array([
            [intrinsics[0, 0].item() * W, 0, 0.5*W],
            [0, intrinsics[1, 1].item() * H, 0.5*H],
            [0, 0, 1]
        ])
    ).to(dtype=torch.float32, device=device)
    # NOTE(review): ``c2w`` is not defined anywhere in this chunk — presumably a
    # module-level camera-to-world matrix defined elsewhere in the file; verify.
    work_space.update({
        "c2w": c2w,
        "K": K,
        "depth_mask": depth_mask,
        "depth": depth,
    })
    # Each unique RGB color in the segmentation map is one instance label
    # (black [0,0,0] is background).
    instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
    seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
    seg_image = np.array(seg_image)
    mask_pack = []
    for instance_label in instance_labels:
        if (instance_label == np.array([0, 0, 0])).all():
            continue
        else:
            instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
            mask_pack.append(instance_mask)
    # Union of all instance masks = foreground used for scene normalization.
    fg_mask = torch.from_numpy(np.stack(mask_pack).any(axis=0))
    # Back-project valid depth into a colored point cloud of the whole scene.
    scene_est_depth_pts, scene_est_depth_pts_colors = \
        project2ply(depth_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
    save_ply_path = os.path.join(work_space["dir"], "scene_pcd.glb")
    # Foreground-only points determine the normalization transform.
    fg_depth_pts, _ = \
        project2ply(fg_mask.to(device), depth.to(device), input_image.to(device), K.to(device), c2w.to(device))
    _, trans, scale = normalize_vertices(fg_depth_pts)
    if trans.shape[0] == 1:
        trans = trans[0]
    work_space.update(
        {
            "trans": trans,
            "scale": scale,
        }
    )
    # Move all tensors to CPU so the workspace dict survives the GPU worker.
    for k, v in work_space.items():
        if isinstance(v, torch.Tensor):
            work_space[k] = v.to('cpu')
    trimesh.PointCloud(scene_est_depth_pts.reshape(-1, 3), 
scene_est_depth_pts_colors.reshape(-1, 3)).\
        apply_translation(-trans).apply_scale(1. / (scale + 1e-6)).\
        apply_transform(rot).export(save_ply_path)
    # NOTE(review): ``rot`` is not defined in this chunk — presumably a
    # module-level orientation matrix defined elsewhere in the file; verify.
    return save_ply_path, work_space

def save_image(img, save_path):
    """Write a CHW float tensor in [0, 1] to disk as an 8-bit image."""
    img = (img.permute(1, 2, 0).detach().cpu().numpy() * 255.).astype(np.uint8)
    imageio.v3.imwrite(save_path, img)

def set_random_seed(seed):
    """Seed numpy, Python, and torch (all CUDA devices) for reproducibility."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def export_scene_glb(trimeshes, work_space, scene_name):
    """Export a list of trimesh geometries as one scene GLB.

    NOTE(review): here ``work_space`` is a directory path string, unlike the
    workspace dict used elsewhere — confirm with callers.
    """
    scene_path = os.path.abspath(os.path.join(work_space, scene_name))
    trimesh.Scene(trimeshes).export(scene_path)
    return scene_path

def get_duration(rgb_image, seg_image, seed, randomize_seed, num_inference_steps, guidance_scale, cfg_interval_start, cfg_interval_end, t_rescale, work_space):
    """Estimate GPU time for @spaces.GPU: ~15 s per instance + 240 s overhead.

    The signature mirrors run_generation because spaces passes it the same
    arguments; only ``seg_image`` is actually used.
    NOTE(review): the unique-color count includes the black background label,
    so this slightly overestimates — presumably intentional headroom.
    """
    instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
    step_duration = 15.0
    return instance_labels.shape[0] * step_duration + 240

@spaces.GPU(duration=get_duration)
def run_generation(
    rgb_image: Any,
    seg_image: Union[str, Image.Image],
    seed: int,
    randomize_seed: bool = False,
    num_inference_steps: int = 50,
    guidance_scale: float = 5.0,
    cfg_interval_start: float = 0.5,
    cfg_interval_end: float = 1.0,
    t_rescale: float = 3.0,
    work_space: dict = None,
):
    """Generate 3D instances one by one from the image + segmentation map.

    Requires the workspace produced by steps 1 (segmentation) and
    2 (depth estimation).
    """
    # NOTE(review): first_render is set here but not used in this visible
    # portion of the function — presumably consumed further down.
    first_render = True
    if work_space is None:
        raise gr.Error("Please run step 1 and step 2 first.")
    # Validate that depth estimation populated everything generation needs.
    required_keys = ["dir", "depth_mask", "depth", "K", "c2w", "trans", "scale"]
    missing = [k for k in required_keys if k not in work_space]
    if missing:
        raise gr.Error(f"Missing workspace fields: {missing}. 
Please run depth estimation (step 2) first.")
    # Heavy project imports are deferred into the function so the UI can start
    # before the generation stack is loaded.
    from threeDFixer.pipelines import ThreeDFixerPipeline
    from threeDFixer.datasets.utils import (
        edge_mask_morph_gradient,
        process_scene_image,
        process_instance_image,
    )
    from threeDFixer.utils import render_utils

    def export_single_glb_from_outputs(
        outputs, fine_scale, fine_trans, coarse_scale, coarse_trans,
        trans, scale, rot, work_space, instance_name, run_id
    ):
        # Bake one generated instance (gaussians + mesh) into a textured GLB,
        # apply the fine->coarse de-normalization, then the scene-level
        # translation/scale/rotation, and export it into the workspace dir.
        from threeDFixer.datasets.utils import (
            transform_vertices,
        )
        from threeDFixer.utils import postprocessing_utils
        # to_glb needs gradients enabled for its texture optimization step.
        with torch.enable_grad():
            glb = postprocessing_utils.to_glb(
                outputs["gaussian"][0],
                outputs["mesh"][0],
                simplify=0.95,
                texture_size=1024,
                transform_fn=lambda x: transform_vertices(
                    x,
                    ops=["scale", "translation", "scale", "translation"],
                    params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
                ),
                verbose=False
            )
        instance_glb_path = os.path.abspath(
            os.path.join(work_space, f"{run_id}_{instance_name}.glb")
        )
        glb.apply_translation(-trans) \
            .apply_scale(1.0 / (scale + 1e-6)) \
            .apply_transform(rot) \
            .export(instance_glb_path)
        return instance_glb_path, glb

    # NOTE(review): these locals shadow the module-level generated_object_map
    # and DEVICE — the globals are not updated here; verify that is intended.
    generated_object_map = {}
    run_id = str(uuid.uuid4())
    DEVICE = 'cuda'
    gr.Info('Loading ckpts')
    down_t = time.time()
    pipeline = ThreeDFixerPipeline.from_pretrained(
        local_dir,
        compile=False
    )
    pipeline.to(device=DEVICE)
    # NOTE(review): ':.2' formats to 2 significant digits, not 2 decimals —
    # probably ':.2f' was intended.
    gr.Info(f'Loading ckpts duration: {time.time()-down_t:.2}s')
    # Examples pass the image prompt as a dict; unwrap to the PIL image.
    if not isinstance(rgb_image, Image.Image) and "image" in rgb_image:
        rgb_image = rgb_image["image"]
    # One unique RGB color per instance (black = background), as in step 2.
    instance_labels = np.unique(np.array(seg_image).reshape(-1, 3), axis=0)
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    set_random_seed(seed)
    # Match the resolution used during depth estimation.
    H, W = work_space['depth_mask'].shape
    rgb_image = rgb_image.resize((W, H), Image.Resampling.LANCZOS)
    seg_image = seg_image.resize((W, H), Image.Resampling.LANCZOS)
    depth_mask = work_space['depth_mask'].detach().cpu().numpy() > 0
    seg_image = np.array(seg_image)
    mask_pack = []
    for instance_label in instance_labels:
        if (instance_label == np.array([0, 0, 0])).all():
            continue
        instance_mask = (seg_image.reshape(-1, 3) == instance_label).all(axis=-1).reshape(H, W)
        mask_pack.append(instance_mask)
    # Elliptical kernel for later mask erosion.
    erode_kernel_size = 7
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_kernel_size, erode_kernel_size))
    results = []
    trimeshes = []
    trans = work_space['trans']
    scale = work_space['scale']
    current_scene_path = None
    pending_exports = []

    def build_stream_html(status_text: str):
        # Render per-instance progress cards as an HTML fragment for streaming.
        cards_html = "".join([
            f"""