Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| from embodied_gen.utils.monkey_patches import monkey_patch_sam3d | |
| monkey_patch_sam3d() | |
| import os | |
| import sys | |
| from typing import Optional, Union | |
| import numpy as np | |
| from hydra.utils import instantiate | |
| # from modelscope import snapshot_download | |
| from huggingface_hub import snapshot_download | |
| from omegaconf import OmegaConf | |
| from PIL import Image | |
| current_file_path = os.path.abspath(__file__) | |
| current_dir = os.path.dirname(current_file_path) | |
| sys.path.append(os.path.join(current_dir, "../..")) | |
| from thirdparty.sam3d.sam3d_objects.pipeline.inference_pipeline_pointmap import ( | |
| InferencePipelinePointMap, | |
| ) | |
| __all__ = ["Sam3dInference"] | |
def load_image(path: str) -> np.ndarray:
    """Load an image file from disk as a uint8 numpy array.

    Args:
        path: Filesystem path to any image format PIL can decode.

    Returns:
        ``np.uint8`` array, shape ``(H, W)`` or ``(H, W, C)`` depending on
        the source image mode.
    """
    # Use a context manager so the underlying file handle is closed
    # deterministically (PIL opens lazily and would otherwise leak it).
    with Image.open(path) as image:
        return np.array(image).astype(np.uint8)
def load_mask(path: str) -> np.ndarray:
    """Load a segmentation mask from disk as a 2-D boolean array.

    Any non-zero pixel counts as foreground. For multi-channel images the
    last channel (e.g. alpha) is used as the mask.

    Args:
        path: Filesystem path to the mask image.

    Returns:
        Boolean array of shape ``(H, W)``.
    """
    binary = load_image(path) > 0
    return binary[..., -1] if binary.ndim == 3 else binary
| class Sam3dInference: | |
| def __init__( | |
| self, local_dir: str = "weights/sam-3d-objects", compile: bool = False | |
| ) -> None: | |
| if not os.path.exists(local_dir): | |
| # snapshot_download("facebook/sam-3d-objects", local_dir=local_dir) | |
| snapshot_download("jetjodh/sam-3d-objects", local_dir=local_dir) | |
| config_file = os.path.join(local_dir, "checkpoints/pipeline.yaml") | |
| config = OmegaConf.load(config_file) | |
| config.rendering_engine = "nvdiffrast" | |
| config.compile_model = compile | |
| config.workspace_dir = os.path.dirname(config_file) | |
| # Generate 4 gs in each pixel. | |
| config["slat_decoder_gs_config_path"] = config.pop( | |
| "slat_decoder_gs_4_config_path", "slat_decoder_gs_4.yaml" | |
| ) | |
| config["slat_decoder_gs_ckpt_path"] = config.pop( | |
| "slat_decoder_gs_4_ckpt_path", "slat_decoder_gs_4.ckpt" | |
| ) | |
| self.pipeline: InferencePipelinePointMap = instantiate(config) | |
| def merge_mask_to_rgba( | |
| self, image: np.ndarray, mask: np.ndarray | |
| ) -> np.ndarray: | |
| mask = mask.astype(np.uint8) * 255 | |
| mask = mask[..., None] | |
| rgba_image = np.concatenate([image[..., :3], mask], axis=-1) | |
| return rgba_image | |
| def run( | |
| self, | |
| image: np.ndarray | Image.Image, | |
| mask: np.ndarray = None, | |
| seed: int = None, | |
| pointmap: np.ndarray = None, | |
| use_stage1_distillation: bool = False, | |
| use_stage2_distillation: bool = False, | |
| stage1_inference_steps: int = 25, | |
| stage2_inference_steps: int = 25, | |
| ) -> dict: | |
| if isinstance(image, Image.Image): | |
| image = np.array(image) | |
| if mask is not None: | |
| image = self.merge_mask_to_rgba(image, mask) | |
| return self.pipeline.run( | |
| image, | |
| None, | |
| seed, | |
| stage1_only=False, | |
| with_mesh_postprocess=False, | |
| with_texture_baking=False, | |
| with_layout_postprocess=False, | |
| use_vertex_color=True, | |
| use_stage1_distillation=use_stage1_distillation, | |
| use_stage2_distillation=use_stage2_distillation, | |
| stage1_inference_steps=stage1_inference_steps, | |
| stage2_inference_steps=stage2_inference_steps, | |
| pointmap=pointmap, | |
| ) | |
| if __name__ == "__main__": | |
| pipeline = Sam3dInference() | |
| # load image | |
| image = load_image( | |
| "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/image.png" | |
| ) | |
| mask = load_mask( | |
| "/home/users/xinjie.wang/xinjie/sam-3d-objects/notebook/images/shutterstock_stylish_kidsroom_1640806567/13.png" | |
| ) | |
| import torch | |
| if torch.cuda.is_available(): | |
| torch.cuda.reset_peak_memory_stats() | |
| torch.cuda.empty_cache() | |
| from time import time | |
| start = time() | |
| output = pipeline.run(image, mask, seed=42) | |
| print(f"Running cost: {round(time()-start, 1)}") | |
| if torch.cuda.is_available(): | |
| max_memory = torch.cuda.max_memory_allocated() / (1024**3) | |
| print(f"(Max VRAM): {max_memory:.2f} GB") | |
| print(f"End: {torch.cuda.memory_allocated() / 1024**2:.2f} MB") | |
| output["gs"].save_ply(f"outputs/splat.ply") | |
| print("Your reconstruction has been saved to outputs/splat.ply") | |