Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import os | |
| import sys | |
| import zipfile | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| from torchvision import transforms | |
# Public API: each entry applies in-process monkey patches to one
# third-party dependency so it runs inside EmbodiedGen without manual
# checkpoint/asset setup.
__all__ = [
    "monkey_patch_pano2room",
    "monkey_patch_maniskill",
    "monkey_patch_sam3d",
]
def monkey_patch_pano2room():
    """Apply runtime patches to the vendored pano2room package.

    Replaces the ``__init__`` methods of the Omnidata depth/normal
    predictors, the pano joint predictor, the LaMa inpainter and the
    SDFT inpainter so that model weights are fetched from torch.hub /
    HuggingFace Hub at construction time instead of hard-coded local
    checkpoint paths.

    Side effects: appends to ``sys.path``, sets ``NODE_RANK``, and
    mutates the third-party classes in place.
    """
    # Make the project root and the vendored pano2room tree importable.
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    sys.path.append(os.path.join(current_dir, "../.."))
    sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
        OmnidataNormalPredictor,
    )
    from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
        OmnidataPredictor,
    )

    def patched_omni_depth_init(self):
        """Load the Omnidata DPT-hybrid 384px depth model via torch.hub."""
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataPredictor.__init__ = patched_omni_depth_init

    def patched_omni_normal_init(self):
        """Load the Omnidata DPT-hybrid 384px surface-normal model via torch.hub."""
        self.img_size = 384
        self.model = torch.hub.load(
            'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
        )
        self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(self.img_size, interpolation=Image.BILINEAR),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )

    OmnidataNormalPredictor.__init__ = patched_omni_normal_init

    def patched_panojoint_init(self, save_path=None):
        """Build the joint predictor from the patched depth/normal predictors."""
        self.depth_predictor = OmnidataPredictor()
        self.normal_predictor = OmnidataNormalPredictor()
        self.save_path = save_path

    from modules.geo_predictors import PanoJointPredictor

    PanoJointPredictor.__init__ = patched_panojoint_init

    # NOTE: We use gsplat for rasterization instead of
    # depth_diff_gaussian_rasterization_min, so no patch is needed for it.

    # Disable the get_has_ddp_rank print in `BaseInpaintingTrainingModule`.
    os.environ["NODE_RANK"] = "0"
    from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
        load_checkpoint,
    )
    from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
        LamaInpainter,
    )

    def patched_lama_inpaint_init(self):
        """Fetch the big-lama checkpoint from HF Hub and load it on CPU."""
        zip_path = hf_hub_download(
            repo_id="smartywu/big-lama",
            filename="big-lama.zip",
            repo_type="model",
        )
        extract_dir = os.path.splitext(zip_path)[0]
        # Extract once; subsequent constructions reuse the cached directory.
        if not os.path.exists(extract_dir):
            os.makedirs(extract_dir, exist_ok=True)
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(extract_dir)
        config_path = os.path.join(extract_dir, 'big-lama', 'config.yaml')
        checkpoint_path = os.path.join(
            extract_dir, 'big-lama/models/best.ckpt'
        )
        train_config = OmegaConf.load(config_path)
        train_config.training_model.predict_only = True
        train_config.visualizer.kind = 'noop'
        self.model = load_checkpoint(
            train_config, checkpoint_path, strict=False, map_location='cpu'
        )
        self.model.freeze()

    LamaInpainter.__init__ = patched_lama_inpaint_init

    from diffusers import StableDiffusionInpaintPipeline
    from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
        SDFTInpainter,
    )

    def patched_sd_inpaint_init(self, subset_name=None):
        """Build the SD2 inpainting pipeline with model CPU offloading."""
        super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )
        # FIX: do not move the pipeline to CUDA before enabling offload.
        # `enable_model_cpu_offload` manages device placement itself;
        # calling `.to("cuda")` first defeats the offload hooks and
        # recent diffusers versions raise an error for this ordering.
        pipe.enable_model_cpu_offload()
        self.inpaint_pipe = pipe

    SDFTInpainter.__init__ = patched_sd_inpaint_init
def monkey_patch_maniskill():
    """Attach image-grabbing helpers to ``ManiSkillScene``.

    Adds ``get_sensor_images`` and ``get_human_render_camera_images``
    methods so callers can fetch per-camera RGB (optionally RGBA)
    tensors from a running scene.
    """
    # Local import keeps typing names out of the module namespace.
    from typing import Any, Optional

    from mani_skill.envs.scene import ManiSkillScene

    def get_sensor_images(
        # FIX: annotation used the builtin `any()` function; `Any` is
        # the intended typing construct.
        self, obs: dict[str, Any]
    ) -> dict[str, dict[str, torch.Tensor]]:
        """Return each sensor's images rendered from its slice of *obs*."""
        sensor_data = dict()
        for name, sensor in self.sensors.items():
            sensor_data[name] = sensor.get_images(obs[name])
        return sensor_data

    def get_human_render_camera_images(
        self, camera_name: Optional[str] = None, return_alpha: bool = False
    ) -> dict[str, torch.Tensor]:
        """Capture images from the scene's human-render cameras.

        Args:
            camera_name: If given, render only this camera (filter not
                applied in the ``parallel_in_single_scene`` branch).
            return_alpha: Append an alpha channel built from the
                segmentation labels (255 where label > 1, else 0).
        """

        def get_rgba_tensor(camera, return_alpha):
            color = camera.get_obs(
                rgb=True, depth=False, segmentation=False, position=False
            )["rgb"]
            if return_alpha:
                seg_labels = camera.get_obs(
                    rgb=False, depth=False, segmentation=True, position=False
                )["segmentation"]
                # NOTE(review): treats segmentation ids <= 1 as
                # background — confirm against the scene's id scheme.
                masks = np.where((seg_labels.cpu() > 1), 255, 0).astype(
                    np.uint8
                )
                masks = torch.tensor(masks).to(color.device)
                color = torch.concat([color, masks], dim=-1)
            return color

        image_data = dict()
        if self.gpu_sim_enabled:
            if self.parallel_in_single_scene:
                for name, camera in self.human_render_cameras.items():
                    camera.camera._render_cameras[0].take_picture()
                    rgba = get_rgba_tensor(camera, return_alpha)
                    image_data[name] = rgba
            else:
                for name, camera in self.human_render_cameras.items():
                    if camera_name is not None and name != camera_name:
                        continue
                    assert camera.config.shader_config.shader_pack not in [
                        "rt",
                        "rt-fast",
                        "rt-med",
                    ], "ray tracing shaders do not work with parallel rendering"
                    camera.capture()
                    rgba = get_rgba_tensor(camera, return_alpha)
                    image_data[name] = rgba
        else:
            for name, camera in self.human_render_cameras.items():
                if camera_name is not None and name != camera_name:
                    continue
                camera.capture()
                rgba = get_rgba_tensor(camera, return_alpha)
                image_data[name] = rgba
        return image_data

    ManiSkillScene.get_sensor_images = get_sensor_images
    ManiSkillScene.get_human_render_camera_images = (
        get_human_render_camera_images
    )
def monkey_patch_sam3d():
    """Apply runtime patches to the vendored SAM-3D objects pipeline.

    Inserts ``thirdparty/sam3d`` into ``sys.path`` and replaces two
    methods of the SAM-3D inference pipeline:

    * ``InferencePipelinePointMap.run`` — wraps each inference stage in
      ``model_device_ctx`` so model groups occupy the device only while
      needed, and adds the downsampled pointmap (and its colors) to the
      returned dict.
    * ``InferencePipeline.__init__`` — loads all model weights with the
      device temporarily set to CPU, then restores the requested device.
    """
    from typing import Optional, Union

    from embodied_gen.data.utils import model_device_ctx
    from embodied_gen.utils.log import logger

    # NOTE(review): presumably skips lidra's import-time initialization;
    # confirm the flag's semantics in the sam3d third-party code.
    os.environ["LIDRA_SKIP_INIT"] = "true"
    current_file_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_file_path)
    sam3d_root = os.path.abspath(
        os.path.join(current_dir, "../../thirdparty/sam3d")
    )
    # Prepend so the vendored sam3d wins over any installed copy.
    if sam3d_root not in sys.path:
        sys.path.insert(0, sam3d_root)
        print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")

    def patch_pointmap_infer_pipeline():
        # Patch `InferencePipelinePointMap.run`; a no-op if the vendored
        # package cannot be imported.
        from copy import deepcopy

        try:
            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
                InferencePipelinePointMap,
            )
        except ImportError:
            logger.error(
                "[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
            )
            return

        def patch_run(
            self,
            image: Union[None, Image.Image, np.ndarray],
            mask: Union[None, Image.Image, np.ndarray] = None,
            seed: Optional[int] = None,
            stage1_only=False,
            with_mesh_postprocess=True,
            with_texture_baking=True,
            with_layout_postprocess=True,
            use_vertex_color=False,
            stage1_inference_steps=None,
            stage2_inference_steps=None,
            use_stage1_distillation=False,
            use_stage2_distillation=False,
            pointmap=None,
            decode_formats=None,
            estimate_plane=False,
        ) -> dict:
            """Run the two-stage SAM-3D inference and return its outputs.

            Returns a dict of stage outputs augmented with "pointmap"
            and "pointmap_colors" tensors (CPU, HxWx3). When
            `estimate_plane` is set, returns the plane estimate instead.
            """
            image = self.merge_image_and_mask(image, mask)
            with self.device:
                pointmap_dict = self.compute_pointmap(image, pointmap)
                pointmap = pointmap_dict["pointmap"]
                # Downsampled pointmap/colors are carried into the
                # returned dicts for visualization downstream.
                pts = type(self)._down_sample_img(pointmap)
                pts_colors = type(self)._down_sample_img(
                    pointmap_dict["pts_color"]
                )
                if estimate_plane:
                    return self.estimate_plane(pointmap_dict, image)
                ss_input_dict = self.preprocess_image(
                    image, self.ss_preprocessor, pointmap=pointmap
                )
                slat_input_dict = self.preprocess_image(
                    image, self.slat_preprocessor
                )
                if seed is not None:
                    torch.manual_seed(seed)
                # Stage 1: sparse structure generation. Only the stage-1
                # models are resident on the device inside this context.
                with model_device_ctx(
                    self.models["ss_generator"],
                    self.models["ss_decoder"],
                    self.condition_embedders["ss_condition_embedder"],
                ):
                    ss_return_dict = self.sample_sparse_structure(
                        ss_input_dict,
                        inference_steps=stage1_inference_steps,
                        use_distillation=use_stage1_distillation,
                    )
                # We could probably use the decoder from the models themselves
                pointmap_scale = ss_input_dict.get("pointmap_scale", None)
                pointmap_shift = ss_input_dict.get("pointmap_shift", None)
                ss_return_dict.update(
                    self.pose_decoder(
                        ss_return_dict,
                        scene_scale=pointmap_scale,
                        scene_shift=pointmap_shift,
                    )
                )
                logger.info(
                    f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
                )
                ss_return_dict["scale"] = (
                    ss_return_dict["scale"]
                    * ss_return_dict["downsample_factor"]
                )
                if stage1_only:
                    logger.info("Finished!")
                    # Convert voxel grid coords to centered [-0.5, 0.5).
                    ss_return_dict["voxel"] = (
                        ss_return_dict["coords"][:, 1:] / 64 - 0.5
                    )
                    return {
                        **ss_return_dict,
                        "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
                        "pointmap_colors": pts_colors.cpu().permute(
                            (1, 2, 0)
                        ),  # HxWx3
                    }
                    # return ss_return_dict
                # Stage 2: structured latent generation and decoding.
                coords = ss_return_dict["coords"]
                with model_device_ctx(
                    self.models["slat_generator"],
                    self.condition_embedders["slat_condition_embedder"],
                ):
                    slat = self.sample_slat(
                        slat_input_dict,
                        coords,
                        inference_steps=stage2_inference_steps,
                        use_distillation=use_stage2_distillation,
                    )
                with model_device_ctx(
                    self.models["slat_decoder_mesh"],
                    self.models["slat_decoder_gs"],
                    self.models["slat_decoder_gs_4"],
                ):
                    outputs = self.decode_slat(
                        slat,
                        (
                            self.decode_formats
                            if decode_formats is None
                            else decode_formats
                        ),
                    )
                outputs = self.postprocess_slat_output(
                    outputs,
                    with_mesh_postprocess,
                    with_texture_baking,
                    use_vertex_color,
                )
                glb = outputs.get("glb", None)
                # Layout post-optimization is best-effort: failures are
                # logged and the unoptimized pose is returned.
                try:
                    if (
                        with_layout_postprocess
                        and self.layout_post_optimization_method is not None
                    ):
                        assert (
                            glb is not None
                        ), "require mesh to run postprocessing"
                        logger.info(
                            "Running layout post optimization method..."
                        )
                        postprocessed_pose = self.run_post_optimization(
                            deepcopy(glb),
                            pointmap_dict["intrinsics"],
                            ss_return_dict,
                            ss_input_dict,
                        )
                        ss_return_dict.update(postprocessed_pose)
                except Exception as e:
                    logger.error(
                        f"Error during layout post optimization: {e}",
                        exc_info=True,
                    )
                result = {
                    **ss_return_dict,
                    **outputs,
                    "pointmap": pts.cpu().permute((1, 2, 0)),
                    "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
                }
                return result

        InferencePipelinePointMap.run = patch_run

    def patch_infer_init():
        # Patch `InferencePipeline.__init__` to load weights on CPU
        # first; a no-op if the vendored package cannot be imported.
        import torch

        try:
            from sam3d_objects.pipeline import preprocess_utils
            from sam3d_objects.pipeline.inference_pipeline_pointmap import (
                InferencePipeline,
            )
            from sam3d_objects.pipeline.inference_utils import (
                SLAT_MEAN,
                SLAT_STD,
            )
        except ImportError:
            print(
                "[MonkeyPatch] Error: Could not import sam3d_objects directly for infer pipeline."
            )
            return

        # NOTE(review): this signature mirrors the upstream sam3d
        # `InferencePipeline.__init__`; the call-at-def and mutable
        # defaults are kept as-is to match it — confirm against the
        # vendored source before changing.
        def patch_init(
            self,
            ss_generator_config_path,
            ss_generator_ckpt_path,
            slat_generator_config_path,
            slat_generator_ckpt_path,
            ss_decoder_config_path,
            ss_decoder_ckpt_path,
            slat_decoder_gs_config_path,
            slat_decoder_gs_ckpt_path,
            slat_decoder_mesh_config_path,
            slat_decoder_mesh_ckpt_path,
            slat_decoder_gs_4_config_path=None,
            slat_decoder_gs_4_ckpt_path=None,
            ss_encoder_config_path=None,
            ss_encoder_ckpt_path=None,
            decode_formats=["gaussian", "mesh"],
            dtype="bfloat16",
            pad_size=1.0,
            version="v0",
            device="cuda",
            ss_preprocessor=preprocess_utils.get_default_preprocessor(),
            slat_preprocessor=preprocess_utils.get_default_preprocessor(),
            ss_condition_input_mapping=["image"],
            slat_condition_input_mapping=["image"],
            pose_decoder_name="default",
            workspace_dir="",
            downsample_ss_dist=0,  # the distance we use to downsample
            ss_inference_steps=25,
            ss_rescale_t=3,
            ss_cfg_strength=7,
            ss_cfg_interval=[0, 500],
            ss_cfg_strength_pm=0.0,
            slat_inference_steps=25,
            slat_rescale_t=3,
            slat_cfg_strength=5,
            slat_cfg_interval=[0, 500],
            rendering_engine: str = "nvdiffrast",  # nvdiffrast OR pytorch3d,
            shape_model_dtype=None,
            compile_model=False,
            slat_mean=SLAT_MEAN,
            slat_std=SLAT_STD,
        ):
            """Initialize the pipeline, loading all weights on CPU first."""
            self.rendering_engine = rendering_engine
            self.device = torch.device(device)
            self.compile_model = compile_model
            logger.info(f"self.device: {self.device}")
            logger.info(
                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
            )
            logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
            with self.device:
                self.decode_formats = decode_formats
                self.pad_size = pad_size
                self.version = version
                self.ss_condition_input_mapping = ss_condition_input_mapping
                self.slat_condition_input_mapping = (
                    slat_condition_input_mapping
                )
                self.workspace_dir = workspace_dir
                self.downsample_ss_dist = downsample_ss_dist
                self.ss_inference_steps = ss_inference_steps
                self.ss_rescale_t = ss_rescale_t
                self.ss_cfg_strength = ss_cfg_strength
                self.ss_cfg_interval = ss_cfg_interval
                self.ss_cfg_strength_pm = ss_cfg_strength_pm
                self.slat_inference_steps = slat_inference_steps
                self.slat_rescale_t = slat_rescale_t
                self.slat_cfg_strength = slat_cfg_strength
                self.slat_cfg_interval = slat_cfg_interval
                self.dtype = self._get_dtype(dtype)
                if shape_model_dtype is None:
                    self.shape_model_dtype = self.dtype
                else:
                    self.shape_model_dtype = self._get_dtype(shape_model_dtype)
                # Setup preprocessors
                self.pose_decoder = self.init_pose_decoder(
                    ss_generator_config_path, pose_decoder_name
                )
                self.ss_preprocessor = self.init_ss_preprocessor(
                    ss_preprocessor, ss_generator_config_path
                )
                self.slat_preprocessor = slat_preprocessor
                logger.info("Loading model weights...")
                # Temporarily pretend to be on CPU so every init_* call
                # below loads its checkpoint into host memory; the real
                # device is restored afterwards.
                raw_device = self.device
                self.device = torch.device("cpu")
                ss_generator = self.init_ss_generator(
                    ss_generator_config_path, ss_generator_ckpt_path
                )
                slat_generator = self.init_slat_generator(
                    slat_generator_config_path, slat_generator_ckpt_path
                )
                ss_decoder = self.init_ss_decoder(
                    ss_decoder_config_path, ss_decoder_ckpt_path
                )
                ss_encoder = self.init_ss_encoder(
                    ss_encoder_config_path, ss_encoder_ckpt_path
                )
                slat_decoder_gs = self.init_slat_decoder_gs(
                    slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
                )
                slat_decoder_gs_4 = self.init_slat_decoder_gs(
                    slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
                )
                slat_decoder_mesh = self.init_slat_decoder_mesh(
                    slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
                )
                # Load conditioner embedder so that we only load it once
                ss_condition_embedder = self.init_ss_condition_embedder(
                    ss_generator_config_path, ss_generator_ckpt_path
                )
                slat_condition_embedder = self.init_slat_condition_embedder(
                    slat_generator_config_path, slat_generator_ckpt_path
                )
                self.device = raw_device
                self.condition_embedders = {
                    "ss_condition_embedder": ss_condition_embedder,
                    "slat_condition_embedder": slat_condition_embedder,
                }
                # override generator and condition embedder setting
                self.override_ss_generator_cfg_config(
                    ss_generator,
                    cfg_strength=ss_cfg_strength,
                    inference_steps=ss_inference_steps,
                    rescale_t=ss_rescale_t,
                    cfg_interval=ss_cfg_interval,
                    cfg_strength_pm=ss_cfg_strength_pm,
                )
                self.override_slat_generator_cfg_config(
                    slat_generator,
                    cfg_strength=slat_cfg_strength,
                    inference_steps=slat_inference_steps,
                    rescale_t=slat_rescale_t,
                    cfg_interval=slat_cfg_interval,
                )
                self.models = torch.nn.ModuleDict(
                    {
                        "ss_generator": ss_generator,
                        "slat_generator": slat_generator,
                        "ss_encoder": ss_encoder,
                        "ss_decoder": ss_decoder,
                        "slat_decoder_gs": slat_decoder_gs,
                        "slat_decoder_gs_4": slat_decoder_gs_4,
                        "slat_decoder_mesh": slat_decoder_mesh,
                    }
                )
                logger.info("Loading model weights completed!")
                if self.compile_model:
                    logger.info("Compiling model...")
                    self._compile()
                    logger.info("Model compilation completed!")
                self.slat_mean = torch.tensor(slat_mean)
                self.slat_std = torch.tensor(slat_std)

        InferencePipeline.__init__ = patch_init

    patch_pointmap_infer_pipeline()
    patch_infer_init()
    return