# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.

# NOTE: the monkey patch is applied before the remaining imports on purpose;
# keep this ordering when reorganising the import block.
from embodied_gen.utils.monkey_patches import monkey_patch_sam3d

monkey_patch_sam3d()

import os
import sys
from typing import Optional, Union

import numpy as np
from hydra.utils import instantiate

# from modelscope import snapshot_download
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from PIL import Image

# Make the repository root importable so that `thirdparty.*` resolves when
# this file is executed directly.
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import (
    InferencePipelinePointMap,
)

__all__ = ["Sam3dInference"]


def load_image(path: str) -> np.ndarray:
    """Read an image file into a ``uint8`` numpy array.

    Args:
        path: Location of the image file on disk.

    Returns:
        The decoded pixels cast to ``np.uint8``. The channel layout is
        whatever PIL produces for the file's mode (no mode conversion is
        performed here).
    """
    # The context manager closes the underlying file handle; the pixel data
    # is fully materialised by ``np.array`` before the handle is released.
    with Image.open(path) as pil_image:
        pixels = np.array(pil_image)

    return pixels.astype(np.uint8)


def load_mask(path: str) -> np.ndarray:
    """Read a mask image and binarise it.

    Args:
        path: Location of the mask image on disk.

    Returns:
        A boolean array that is ``True`` wherever the mask value is
        greater than zero. For multi-channel files only the last channel
        is used.
    """
    mask = load_image(path)
    if mask.ndim == 3:
        # Multi-channel mask: keep the last channel only.
        mask = mask[..., -1]

    return mask > 0


class Sam3dInference:
    """Thin wrapper that builds and runs the SAM-3D point-map pipeline.

    The pipeline is instantiated from ``checkpoints/pipeline.yaml`` found
    under ``local_dir``; the weights are downloaded from the Hugging Face
    Hub when ``local_dir`` does not exist yet.
    """

    def __init__(
        self, local_dir: str = "weights/sam-3d-objects", compile: bool = False
    ) -> None:
        """Load the pipeline configuration and instantiate the pipeline.

        Args:
            local_dir: Directory holding (or receiving) the model snapshot.
            compile: Value written to ``config.compile_model``.
        """
        if not os.path.exists(local_dir):
            # snapshot_download("facebook/sam-3d-objects", local_dir=local_dir)
            snapshot_download("jetjodh/sam-3d-objects", local_dir=local_dir)

        config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml")
        config = OmegaConf.load(config_file)
        config.rendering_engine = "nvdiffrast"
        config.compile_model = compile
        config.workspace_dir = os.path.dirname(config_file)

        # Generate 4 gs in each pixel: route the "gs_4" decoder config and
        # checkpoint entries to the default decoder keys, falling back to
        # the stock file names when the entries are absent.
        config["slat_decoder_gs_config_path"] = config.pop(
            "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml"
        )
        config["slat_decoder_gs_ckpt_path"] = config.pop(
            "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt"
        )

        self.pipeline: InferencePipelinePointMap = instantiate(config)

    def merge_mask_to_rgba(
        self, image: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        """Attach ``mask`` to ``image`` as an alpha channel.

        Args:
            image: ``(H, W)``, ``(H, W, 1)`` or ``(H, W, C>=3)`` array. Only
                the first three channels are kept; single-channel input is
                replicated to three channels.
            mask: ``(H, W)`` or ``(H, W, 1)`` array. Any value greater than
                zero is treated as foreground.

        Returns:
            An ``(H, W, 4)`` array whose last channel is 255 on foreground
            pixels and 0 elsewhere.

        Raises:
            ValueError: If ``image`` and ``mask`` disagree on ``(H, W)``.
        """
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]

        if image.ndim == 2:
            image = image[..., None]
        if image.shape[-1] == 1:
            # Grayscale input: replicate to three colour channels.
            image = np.repeat(image, 3, axis=-1)

        if image.shape[:2] != mask.shape[:2]:
            raise ValueError(
                f"image and mask spatial sizes differ: "
                f"{image.shape[:2]} vs {mask.shape[:2]}"
            )

        # Binarise before scaling. Multiplying a uint8 mask that already
        # holds 255 by 255 wraps around to 1 and yields an almost fully
        # transparent alpha channel.
        alpha = (mask > 0).astype(np.uint8) * 255
        alpha = alpha[..., None]

        return np.concatenate([image[..., :3], alpha], axis=-1)

    def run(
        self,
        image: Union[np.ndarray, Image.Image],
        mask: Optional[np.ndarray] = None,
        seed: Optional[int] = None,
        pointmap: Optional[np.ndarray] = None,
        use_stage1_distillation: bool = False,
        use_stage2_distillation: bool = False,
        stage1_inference_steps: int = 25,
        stage2_inference_steps: int = 25,
    ) -> dict:
        """Run the pipeline on a single image.

        Args:
            image: Input image as a numpy array or PIL image.
            mask: Optional foreground mask. When given it is merged into
                ``image`` as the alpha channel; when omitted ``image`` is
                forwarded untouched (presumably it must then already carry
                an alpha channel -- confirm against the pipeline).
            seed: Forwarded to the pipeline as its third positional
                argument.
            pointmap: Forwarded to the pipeline unchanged.
            use_stage1_distillation: Forwarded to the pipeline unchanged.
            use_stage2_distillation: Forwarded to the pipeline unchanged.
            stage1_inference_steps: Forwarded to the pipeline unchanged.
            stage2_inference_steps: Forwarded to the pipeline unchanged.

        Returns:
            Whatever ``InferencePipelinePointMap.run`` returns; the demo
            script in this file reads the ``"gs"`` entry from it.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        if mask is not None:
            image = self.merge_mask_to_rgba(image, mask)

        # The second positional argument is always None here because the
        # mask travels inside the image's alpha channel.
        return self.pipeline.run(
            image,
            None,
            seed,
            stage1_only=False,
            with_mesh_postprocess=False,
            with_texture_baking=False,
            with_layout_postprocess=False,
            use_vertex_color=True,
            use_stage1_distillation=use_stage1_distillation,
            use_stage2_distillation=use_stage2_distillation,
            stage1_inference_steps=stage1_inference_steps,
            stage2_inference_steps=stage2_inference_steps,
            pointmap=pointmap,
        )


def main() -> None:
    """Demo entry point: reconstruct one masked object and save it as PLY.

    Without command-line arguments this reproduces the original
    hard-coded demo (same input files, seed and output location).
    """
    import argparse
    import time

    import torch

    demo_dir = (
        "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/"
        "shutterstock_stylish_kidsroom_1640806567"
    )
    parser = argparse.ArgumentParser(description="SAM-3D inference demo.")
    parser.add_argument("--image", default=os.path.join(demo_dir, "image.png"))
    parser.add_argument("--mask", default=os.path.join(demo_dir, "13.png"))
    parser.add_argument("--output", default="outputs/splat.ply")
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    pipeline = Sam3dInference()

    # load image
    image = load_image(args.image)
    mask = load_mask(args.mask)

    has_cuda = torch.cuda.is_available()
    if has_cuda:
        # Start from clean counters so the peak below covers this run only.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()

    start = time.perf_counter()
    output = pipeline.run(image, mask, seed=args.seed)
    print(f"Running cost: {round(time.perf_counter()-start, 1)}")

    if has_cuda:
        max_memory = torch.cuda.max_memory_allocated() / (1024**3)
        print(f"(Max VRAM): {max_memory:.2f} GB")
        print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")

    # Create the destination directory up front so saving cannot fail on a
    # fresh checkout.
    os.makedirs(os.path.dirname(args.output) or ".", exist_ok=True)
    output["gs"].save_ply(args.output)
    print(f"Your reconstruction has been saved to {args.output}")


if __name__ == "__main__":
    main()