# Project EmbodiedGen
#
# Copyright (c) 2025 Horizon Robotics. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.
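
"""Runtime monkey patches for third-party dependencies of EmbodiedGen.

Each ``monkey_patch_*`` helper rewrites selected classes of a third-party
package in place, so it must be called before the patched classes are
instantiated. Illustrative usage (the import path depends on where this
module lives in the package):

    from embodied_gen.utils.monkey_patches import monkey_patch_pano2room

    monkey_patch_pano2room()
"""
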
import os
import sys
import zipfile

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from omegaconf import OmegaConf
from PIL import Image
from torchvision import transforms

__all__ = [
"monkey_patch_pano2room",
"monkey_patch_maniskill",
"monkey_patch_sam3d",
]


def monkey_patch_pano2room():
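    """Patch pano2room's geometry predictors and inpainters in place.

    The omnidata depth/normal predictors are rewired to load weights via
    torch.hub, and the LaMa / Stable Diffusion inpainters to download
    theirs from the Hugging Face Hub. Call this before the patched
    classes are instantiated.
    """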
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sys.path.append(os.path.join(current_dir, "../.."))
sys.path.append(os.path.join(current_dir, "../../thirdparty/pano2room"))
from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_normal_predictor import (
OmnidataNormalPredictor,
)
from thirdparty.pano2room.modules.geo_predictors.omnidata.omnidata_predictor import (
OmnidataPredictor,
)
def patched_omni_depth_init(self):
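        """Load the omnidata DPT-hybrid depth model via torch.hub."""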
self.img_size = 384
self.model = torch.hub.load(
'alexsax/omnidata_models', 'depth_dpt_hybrid_384'
)
self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(
                    self.img_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )
OmnidataPredictor.__init__ = patched_omni_depth_init
def patched_omni_normal_init(self):
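        """Load the omnidata DPT-hybrid surface-normal model via torch.hub."""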
self.img_size = 384
self.model = torch.hub.load(
'alexsax/omnidata_models', 'surface_normal_dpt_hybrid_384'
)
self.model.eval()
        self.trans_totensor = transforms.Compose(
            [
                transforms.Resize(
                    self.img_size,
                    interpolation=transforms.InterpolationMode.BILINEAR,
                ),
                transforms.CenterCrop(self.img_size),
                transforms.Normalize(mean=0.5, std=0.5),
            ]
        )
OmnidataNormalPredictor.__init__ = patched_omni_normal_init
def patched_panojoint_init(self, save_path=None):
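        """Build the joint predictor from the patched omnidata predictors."""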
self.depth_predictor = OmnidataPredictor()
self.normal_predictor = OmnidataNormalPredictor()
self.save_path = save_path
from modules.geo_predictors import PanoJointPredictor
PanoJointPredictor.__init__ = patched_panojoint_init
# NOTE: We use gsplat instead.
# import depth_diff_gaussian_rasterization_min as ddgr
# from dataclasses import dataclass
# @dataclass
# class PatchedGaussianRasterizationSettings:
# image_height: int
# image_width: int
# tanfovx: float
# tanfovy: float
# bg: torch.Tensor
# scale_modifier: float
# viewmatrix: torch.Tensor
# projmatrix: torch.Tensor
# sh_degree: int
# campos: torch.Tensor
# prefiltered: bool
# debug: bool = False
# ddgr.GaussianRasterizationSettings = PatchedGaussianRasterizationSettings
    # Disable the get_has_ddp_rank print in `BaseInpaintingTrainingModule`.
os.environ["NODE_RANK"] = "0"
from thirdparty.pano2room.modules.inpainters.lama.saicinpainting.training.trainers import (
load_checkpoint,
)
from thirdparty.pano2room.modules.inpainters.lama_inpainter import (
LamaInpainter,
)
def patched_lama_inpaint_init(self):
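        """Download, extract, and load the big-lama inpainting weights.

        The zip is fetched once from the Hugging Face Hub, extracted next
        to the cached archive, and the checkpoint is loaded on CPU in
        predict-only mode.
        """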
zip_path = hf_hub_download(
repo_id="smartywu/big-lama",
filename="big-lama.zip",
repo_type="model",
)
extract_dir = os.path.splitext(zip_path)[0]
if not os.path.exists(extract_dir):
os.makedirs(extract_dir, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zip_ref:
zip_ref.extractall(extract_dir)
        config_path = os.path.join(extract_dir, "big-lama", "config.yaml")
        checkpoint_path = os.path.join(
            extract_dir, "big-lama", "models", "best.ckpt"
        )
train_config = OmegaConf.load(config_path)
train_config.training_model.predict_only = True
train_config.visualizer.kind = 'noop'
self.model = load_checkpoint(
train_config, checkpoint_path, strict=False, map_location='cpu'
)
self.model.freeze()
LamaInpainter.__init__ = patched_lama_inpaint_init
from diffusers import StableDiffusionInpaintPipeline
from thirdparty.pano2room.modules.inpainters.SDFT_inpainter import (
SDFTInpainter,
)
def patched_sd_inpaint_init(self, subset_name=None):
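        """Build the SD2 inpainting pipeline with model CPU offload."""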
super(SDFTInpainter, self).__init__()
        pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16,
        )
        # Model CPU offload manages device placement itself; moving the
        # whole pipeline to CUDA first would undo the memory savings.
        pipe.enable_model_cpu_offload()
self.inpaint_pipe = pipe
SDFTInpainter.__init__ = patched_sd_inpaint_init


def monkey_patch_maniskill():
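    """Patch ManiSkillScene with EmbodiedGen's image-fetch helpers.

    Overrides ``get_sensor_images`` and ``get_human_render_camera_images``
    so render-camera output can optionally carry an alpha channel derived
    from segmentation labels.
    """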
    from typing import Any, Optional

    from mani_skill.envs.scene import ManiSkillScene

    def get_sensor_images(
        self, obs: dict[str, Any]
    ) -> dict[str, dict[str, torch.Tensor]]:
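        """Collect per-sensor images from ``obs``, keyed by sensor name."""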
sensor_data = dict()
for name, sensor in self.sensors.items():
sensor_data[name] = sensor.get_images(obs[name])
return sensor_data
    def get_human_render_camera_images(
        self, camera_name: Optional[str] = None, return_alpha: bool = False
    ) -> dict[str, torch.Tensor]:
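        """Render images from the human render cameras.

        When ``return_alpha`` is True, an alpha mask derived from the
        segmentation labels (IDs > 1 count as foreground) is appended as
        a fourth channel.
        """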
def get_rgba_tensor(camera, return_alpha):
color = camera.get_obs(
rgb=True, depth=False, segmentation=False, position=False
)["rgb"]
if return_alpha:
seg_labels = camera.get_obs(
rgb=False, depth=False, segmentation=True, position=False
)["segmentation"]
masks = np.where((seg_labels.cpu() > 1), 255, 0).astype(
np.uint8
)
masks = torch.tensor(masks).to(color.device)
color = torch.concat([color, masks], dim=-1)
return color
image_data = dict()
if self.gpu_sim_enabled:
if self.parallel_in_single_scene:
for name, camera in self.human_render_cameras.items():
camera.camera._render_cameras[0].take_picture()
rgba = get_rgba_tensor(camera, return_alpha)
image_data[name] = rgba
else:
for name, camera in self.human_render_cameras.items():
if camera_name is not None and name != camera_name:
continue
assert camera.config.shader_config.shader_pack not in [
"rt",
"rt-fast",
"rt-med",
], "ray tracing shaders do not work with parallel rendering"
camera.capture()
rgba = get_rgba_tensor(camera, return_alpha)
image_data[name] = rgba
else:
for name, camera in self.human_render_cameras.items():
if camera_name is not None and name != camera_name:
continue
camera.capture()
rgba = get_rgba_tensor(camera, return_alpha)
image_data[name] = rgba
return image_data
ManiSkillScene.get_sensor_images = get_sensor_images
ManiSkillScene.get_human_render_camera_images = (
get_human_render_camera_images
)


def monkey_patch_sam3d():
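    """Patch the SAM3D-Objects inference pipeline for bounded GPU memory.

    Adds the thirdparty sam3d checkout to ``sys.path``, then replaces
    ``InferencePipelinePointMap.run`` and ``InferencePipeline.__init__``
    so that model weights are loaded on CPU and moved to the GPU only
    while their stage runs (via ``model_device_ctx``).
    """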
from typing import Optional, Union
from embodied_gen.data.utils import model_device_ctx
from embodied_gen.utils.log import logger
os.environ["LIDRA_SKIP_INIT"] = "true"
current_file_path = os.path.abspath(__file__)
current_dir = os.path.dirname(current_file_path)
sam3d_root = os.path.abspath(
os.path.join(current_dir, "../../thirdparty/sam3d")
)
if sam3d_root not in sys.path:
sys.path.insert(0, sam3d_root)
print(f"[MonkeyPatch] Added to sys.path: {sam3d_root}")
def patch_pointmap_infer_pipeline():
from copy import deepcopy
try:
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
InferencePipelinePointMap,
)
except ImportError:
logger.error(
"[MonkeyPatch]: Could not import sam3d_objects directly. Check paths."
)
return
def patch_run(
self,
image: Union[None, Image.Image, np.ndarray],
mask: Union[None, Image.Image, np.ndarray] = None,
seed: Optional[int] = None,
stage1_only=False,
with_mesh_postprocess=True,
with_texture_baking=True,
with_layout_postprocess=True,
use_vertex_color=False,
stage1_inference_steps=None,
stage2_inference_steps=None,
use_stage1_distillation=False,
use_stage2_distillation=False,
pointmap=None,
decode_formats=None,
estimate_plane=False,
) -> dict:
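            """Run pointmap-conditioned generation with staged device use.

            Each model group is wrapped in ``model_device_ctx`` so its
            weights occupy the GPU only for the duration of its stage.
            """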
image = self.merge_image_and_mask(image, mask)
with self.device:
pointmap_dict = self.compute_pointmap(image, pointmap)
pointmap = pointmap_dict["pointmap"]
pts = type(self)._down_sample_img(pointmap)
pts_colors = type(self)._down_sample_img(
pointmap_dict["pts_color"]
)
if estimate_plane:
return self.estimate_plane(pointmap_dict, image)
ss_input_dict = self.preprocess_image(
image, self.ss_preprocessor, pointmap=pointmap
)
slat_input_dict = self.preprocess_image(
image, self.slat_preprocessor
)
if seed is not None:
torch.manual_seed(seed)
with model_device_ctx(
self.models["ss_generator"],
self.models["ss_decoder"],
self.condition_embedders["ss_condition_embedder"],
):
ss_return_dict = self.sample_sparse_structure(
ss_input_dict,
inference_steps=stage1_inference_steps,
use_distillation=use_stage1_distillation,
)
# We could probably use the decoder from the models themselves
pointmap_scale = ss_input_dict.get("pointmap_scale", None)
pointmap_shift = ss_input_dict.get("pointmap_shift", None)
ss_return_dict.update(
self.pose_decoder(
ss_return_dict,
scene_scale=pointmap_scale,
scene_shift=pointmap_shift,
)
)
logger.info(
f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
)
ss_return_dict["scale"] = (
ss_return_dict["scale"]
* ss_return_dict["downsample_factor"]
)
if stage1_only:
logger.info("Finished!")
ss_return_dict["voxel"] = (
ss_return_dict["coords"][:, 1:] / 64 - 0.5
)
return {
**ss_return_dict,
"pointmap": pts.cpu().permute((1, 2, 0)), # HxWx3
"pointmap_colors": pts_colors.cpu().permute(
(1, 2, 0)
), # HxWx3
}
coords = ss_return_dict["coords"]
with model_device_ctx(
self.models["slat_generator"],
self.condition_embedders["slat_condition_embedder"],
):
slat = self.sample_slat(
slat_input_dict,
coords,
inference_steps=stage2_inference_steps,
use_distillation=use_stage2_distillation,
)
with model_device_ctx(
self.models["slat_decoder_mesh"],
self.models["slat_decoder_gs"],
self.models["slat_decoder_gs_4"],
):
outputs = self.decode_slat(
slat,
(
self.decode_formats
if decode_formats is None
else decode_formats
),
)
outputs = self.postprocess_slat_output(
outputs,
with_mesh_postprocess,
with_texture_baking,
use_vertex_color,
)
glb = outputs.get("glb", None)
try:
if (
with_layout_postprocess
and self.layout_post_optimization_method is not None
):
assert (
glb is not None
), "require mesh to run postprocessing"
logger.info(
"Running layout post optimization method..."
)
postprocessed_pose = self.run_post_optimization(
deepcopy(glb),
pointmap_dict["intrinsics"],
ss_return_dict,
ss_input_dict,
)
ss_return_dict.update(postprocessed_pose)
except Exception as e:
logger.error(
f"Error during layout post optimization: {e}",
exc_info=True,
)
result = {
**ss_return_dict,
**outputs,
"pointmap": pts.cpu().permute((1, 2, 0)),
"pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),
}
return result
InferencePipelinePointMap.run = patch_run
def patch_infer_init():
try:
from sam3d_objects.pipeline import preprocess_utils
from sam3d_objects.pipeline.inference_pipeline_pointmap import (
InferencePipeline,
)
from sam3d_objects.pipeline.inference_utils import (
SLAT_MEAN,
SLAT_STD,
)
except ImportError:
            logger.error(
                "[MonkeyPatch] Could not import sam3d_objects for the infer pipeline. Check paths."
            )
return
def patch_init(
self,
ss_generator_config_path,
ss_generator_ckpt_path,
slat_generator_config_path,
slat_generator_ckpt_path,
ss_decoder_config_path,
ss_decoder_ckpt_path,
slat_decoder_gs_config_path,
slat_decoder_gs_ckpt_path,
slat_decoder_mesh_config_path,
slat_decoder_mesh_ckpt_path,
slat_decoder_gs_4_config_path=None,
slat_decoder_gs_4_ckpt_path=None,
ss_encoder_config_path=None,
ss_encoder_ckpt_path=None,
decode_formats=["gaussian", "mesh"],
dtype="bfloat16",
pad_size=1.0,
version="v0",
device="cuda",
ss_preprocessor=preprocess_utils.get_default_preprocessor(),
slat_preprocessor=preprocess_utils.get_default_preprocessor(),
ss_condition_input_mapping=["image"],
slat_condition_input_mapping=["image"],
pose_decoder_name="default",
workspace_dir="",
downsample_ss_dist=0, # the distance we use to downsample
ss_inference_steps=25,
ss_rescale_t=3,
ss_cfg_strength=7,
ss_cfg_interval=[0, 500],
ss_cfg_strength_pm=0.0,
slat_inference_steps=25,
slat_rescale_t=3,
slat_cfg_strength=5,
slat_cfg_interval=[0, 500],
rendering_engine: str = "nvdiffrast", # nvdiffrast OR pytorch3d,
shape_model_dtype=None,
compile_model=False,
slat_mean=SLAT_MEAN,
slat_std=SLAT_STD,
):
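            """Initialize the pipeline while loading all weights on CPU.

            ``self.device`` is temporarily switched to CPU so submodules
            are constructed there, then restored so inference runs on the
            requested device.
            """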
self.rendering_engine = rendering_engine
self.device = torch.device(device)
self.compile_model = compile_model
logger.info(f"self.device: {self.device}")
logger.info(
f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', None)}"
)
logger.info(f"Actually using GPU: {torch.cuda.current_device()}")
with self.device:
self.decode_formats = decode_formats
self.pad_size = pad_size
self.version = version
self.ss_condition_input_mapping = ss_condition_input_mapping
self.slat_condition_input_mapping = (
slat_condition_input_mapping
)
self.workspace_dir = workspace_dir
self.downsample_ss_dist = downsample_ss_dist
self.ss_inference_steps = ss_inference_steps
self.ss_rescale_t = ss_rescale_t
self.ss_cfg_strength = ss_cfg_strength
self.ss_cfg_interval = ss_cfg_interval
self.ss_cfg_strength_pm = ss_cfg_strength_pm
self.slat_inference_steps = slat_inference_steps
self.slat_rescale_t = slat_rescale_t
self.slat_cfg_strength = slat_cfg_strength
self.slat_cfg_interval = slat_cfg_interval
self.dtype = self._get_dtype(dtype)
if shape_model_dtype is None:
self.shape_model_dtype = self.dtype
else:
self.shape_model_dtype = self._get_dtype(shape_model_dtype)
# Setup preprocessors
self.pose_decoder = self.init_pose_decoder(
ss_generator_config_path, pose_decoder_name
)
self.ss_preprocessor = self.init_ss_preprocessor(
ss_preprocessor, ss_generator_config_path
)
self.slat_preprocessor = slat_preprocessor
logger.info("Loading model weights...")
raw_device = self.device
self.device = torch.device("cpu")
ss_generator = self.init_ss_generator(
ss_generator_config_path, ss_generator_ckpt_path
)
slat_generator = self.init_slat_generator(
slat_generator_config_path, slat_generator_ckpt_path
)
ss_decoder = self.init_ss_decoder(
ss_decoder_config_path, ss_decoder_ckpt_path
)
ss_encoder = self.init_ss_encoder(
ss_encoder_config_path, ss_encoder_ckpt_path
)
slat_decoder_gs = self.init_slat_decoder_gs(
slat_decoder_gs_config_path, slat_decoder_gs_ckpt_path
)
slat_decoder_gs_4 = self.init_slat_decoder_gs(
slat_decoder_gs_4_config_path, slat_decoder_gs_4_ckpt_path
)
slat_decoder_mesh = self.init_slat_decoder_mesh(
slat_decoder_mesh_config_path, slat_decoder_mesh_ckpt_path
)
# Load conditioner embedder so that we only load it once
ss_condition_embedder = self.init_ss_condition_embedder(
ss_generator_config_path, ss_generator_ckpt_path
)
slat_condition_embedder = self.init_slat_condition_embedder(
slat_generator_config_path, slat_generator_ckpt_path
)
self.device = raw_device
self.condition_embedders = {
"ss_condition_embedder": ss_condition_embedder,
"slat_condition_embedder": slat_condition_embedder,
}
# override generator and condition embedder setting
self.override_ss_generator_cfg_config(
ss_generator,
cfg_strength=ss_cfg_strength,
inference_steps=ss_inference_steps,
rescale_t=ss_rescale_t,
cfg_interval=ss_cfg_interval,
cfg_strength_pm=ss_cfg_strength_pm,
)
self.override_slat_generator_cfg_config(
slat_generator,
cfg_strength=slat_cfg_strength,
inference_steps=slat_inference_steps,
rescale_t=slat_rescale_t,
cfg_interval=slat_cfg_interval,
)
self.models = torch.nn.ModuleDict(
{
"ss_generator": ss_generator,
"slat_generator": slat_generator,
"ss_encoder": ss_encoder,
"ss_decoder": ss_decoder,
"slat_decoder_gs": slat_decoder_gs,
"slat_decoder_gs_4": slat_decoder_gs_4,
"slat_decoder_mesh": slat_decoder_mesh,
}
)
logger.info("Loading model weights completed!")
if self.compile_model:
logger.info("Compiling model...")
self._compile()
logger.info("Model compilation completed!")
self.slat_mean = torch.tensor(slat_mean)
self.slat_std = torch.tensor(slat_std)
InferencePipeline.__init__ = patch_init
patch_pointmap_infer_pipeline()
patch_infer_init()