# thirdparty/sam3d/sam3d_objects/pipeline/inference_pipeline_pointmap.py
# Copyright (c) Meta Platforms, Inc. and affiliates.

from typing import Union, Optional
from copy import deepcopy

import numpy as np
import torch
from tqdm import tqdm
import torchvision
from loguru import logger
from PIL import Image
from pytorch3d.renderer import look_at_view_transform
from pytorch3d.transforms import Transform3d

from sam3d_objects.model.backbone.dit.embedder.pointmap import PointPatchEmbed
from sam3d_objects.pipeline.inference_pipeline import InferencePipeline
from sam3d_objects.data.dataset.tdfy.img_and_mask_transforms import (
    get_mask,
)
from sam3d_objects.data.dataset.tdfy.transforms_3d import (
    DecomposedTransform,
)
from sam3d_objects.pipeline.utils.pointmap import infer_intrinsics_from_pointmap
from sam3d_objects.pipeline.inference_utils import (
    o3d_plane_estimation,
    estimate_plane_area,
)

def camera_to_pytorch3d_camera(device="cpu") -> DecomposedTransform:
    """
    R3 camera space --> PyTorch3D camera space.

    Also needed for pointmaps.
    """
    r3_to_p3d_R, r3_to_p3d_T = look_at_view_transform(
        eye=np.array([[0, 0, -1]]),
        at=np.array([[0, 0, 0]]),
        up=np.array([[0, -1, 0]]),
        device=device,
    )
    return DecomposedTransform(
        rotation=r3_to_p3d_R,
        translation=r3_to_p3d_T,
        scale=torch.tensor(1.0, dtype=r3_to_p3d_R.dtype, device=device),
    )
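
# Convention note (added for clarity): PyTorch3D cameras use +X left, +Y up,
# +Z into the screen, whereas OpenCV-style ("R3") cameras use +X right,
# +Y down, +Z forward, so the transform above amounts to flipping the x and y
# axes plus the eye offset along z. A minimal sanity check, assuming only the
# pytorch3d API used above (shapes are illustrative, not executed at import):
#
#     >>> t = camera_to_pytorch3d_camera()
#     >>> t.rotation.shape, t.translation.shape
#     (torch.Size([1, 3, 3]), torch.Size([1, 3]))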

def recursive_fn_factory(fn):
    def recursive_fn(b):
        if isinstance(b, dict):
            return {k: recursive_fn(b[k]) for k in b}
        if isinstance(b, list):
            return [recursive_fn(t) for t in b]
        if isinstance(b, tuple):
            return tuple(recursive_fn(t) for t in b)
        if isinstance(b, torch.Tensor):
            return fn(b)
        # Yes, writing out an explicit whitelist of trivial types is tedious,
        # but so are the bugs that come from silently failing to apply fn
        # where it was expected.
        if b is None:
            return b
        trivial_types = [bool, int, float]
        for t in trivial_types:
            if isinstance(b, t):
                return b
        raise TypeError(f"Unexpected type {type(b)}")

    return recursive_fn


recursive_contiguous = recursive_fn_factory(lambda x: x.contiguous())
recursive_clone = recursive_fn_factory(torch.clone)
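
# Illustrative usage (hypothetical values; shown as a doctest-style comment,
# not executed at import time):
#
#     >>> batch = {"x": torch.rand(4, 2).t(), "steps": 3, "extra": None}
#     >>> out = recursive_contiguous(batch)
#     >>> out["x"].is_contiguous()
#     True
#     >>> out["steps"], out["extra"]
#     (3, None)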

def compile_wrapper(
    fn, *, mode="max-autotune", fullgraph=True, dynamic=False, name=None
):
    compiled_fn = torch.compile(fn, mode=mode, fullgraph=fullgraph, dynamic=dynamic)

    def compiled_fn_wrapper(*args, **kwargs):
        with torch.autograd.profiler.record_function(
            f"compiled {fn}" if name is None else name
        ):
            cont_args = recursive_contiguous(args)
            cont_kwargs = recursive_contiguous(kwargs)
            result = compiled_fn(*cont_args, **cont_kwargs)
            cloned_result = recursive_clone(result)
            return cloned_result

    return compiled_fn_wrapper
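
# Why the contiguous/clone dance (added note): "max-autotune" may capture the
# compiled region into CUDA graphs, which replay against fixed input/output
# buffers. Making inputs contiguous keeps the captured memory layout stable,
# and cloning outputs prevents callers from holding views into buffers that
# the next replay would overwrite. A hedged usage sketch (hypothetical fn):
#
#     >>> fast_norm = compile_wrapper(torch.linalg.norm, name="norm")
#     >>> fast_norm(torch.rand(8, 8, device="cuda"))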

class InferencePipelinePointMap(InferencePipeline):
    def __init__(
        self,
        *args,
        depth_model,
        layout_post_optimization_method=None,
        clip_pointmap_beyond_scale=None,
        **kwargs,
    ):
        self.depth_model = depth_model
        self.layout_post_optimization_method = layout_post_optimization_method
        self.clip_pointmap_beyond_scale = clip_pointmap_beyond_scale
        super().__init__(*args, **kwargs)

    def _compile(self):
        torch._dynamo.config.cache_size_limit = 64
        torch._dynamo.config.accumulated_cache_size_limit = 2048
        torch._dynamo.config.capture_scalar_outputs = True
        compile_mode = "max-autotune"
        for embedder, _ in self.condition_embedders[
            "ss_condition_embedder"
        ].embedder_list:
            if isinstance(embedder, PointPatchEmbed):
                logger.info("Found PointPatchEmbed")
                embedder.inner_forward = compile_wrapper(
                    embedder.inner_forward,
                    mode=compile_mode,
                    fullgraph=True,
                )
            else:
                embedder.forward = compile_wrapper(
                    embedder.forward,
                    mode=compile_mode,
                    fullgraph=True,
                )
        self.models["ss_generator"].reverse_fn.inner_forward = compile_wrapper(
            self.models["ss_generator"].reverse_fn.inner_forward,
            mode=compile_mode,
            fullgraph=True,
        )
        self.models["ss_decoder"].forward = compile_wrapper(
            self.models["ss_decoder"].forward,
            mode=compile_mode,
            fullgraph=True,
        )
        self._warmup()

    def _warmup(self, num_warmup_iters=3):
        test_image = np.ones((512, 512, 4), dtype=np.uint8) * 255
        test_image[:, :, :3] = np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8)
        image = Image.fromarray(test_image)
        mask = None
        image = self.merge_image_and_mask(image, mask)
        with torch.inference_mode(False):
            with torch.no_grad():
                for _ in tqdm(range(num_warmup_iters)):
                    pointmap_dict = recursive_clone(self.compute_pointmap(image))
                    pointmap = pointmap_dict["pointmap"]
                    ss_input_dict = self.preprocess_image(
                        image, self.ss_preprocessor, pointmap=pointmap
                    )
                    ss_return_dict = self.sample_sparse_structure(
                        ss_input_dict, inference_steps=None
                    )
                    _ = self.run_layout_model(
                        ss_input_dict,
                        ss_return_dict,
                        inference_steps=None,
                    )
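
    # Warmup note (added for clarity): running the stage-1 path a few times on
    # a dummy RGBA image drives every compiled region above at its real input
    # shapes, so torch.compile autotuning (and any CUDA graph capture) pays
    # its one-time cost here rather than on the first user request.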

    def _preprocess_image_and_mask_pointmap(
        self, rgb_image, mask_image, pointmap, img_mask_pointmap_joint_transform
    ):
        for trans in img_mask_pointmap_joint_transform:
            rgb_image, mask_image, pointmap = trans(
                rgb_image, mask_image, pointmap=pointmap
            )
        return rgb_image, mask_image, pointmap

    def preprocess_image(
        self,
        image: Union[Image.Image, np.ndarray],
        preprocessor,
        pointmap=None,
    ) -> dict:
        # canonical type is numpy
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        assert image.ndim == 3  # no batch dimension as of now
        assert image.shape[-1] == 4  # RGBA format
        assert image.dtype == np.uint8  # [0, 255] range
        rgba_image = torch.from_numpy(self.image_to_float(image))
        rgba_image = rgba_image.permute(2, 0, 1).contiguous()
        rgb_image = rgba_image[:3]
        rgb_image_mask = get_mask(rgba_image, None, "ALPHA_CHANNEL")
        preprocessor_return_dict = preprocessor._process_image_mask_pointmap_mess(
            rgb_image, rgb_image_mask, pointmap
        )
        # Put in a for loop?
        _item = preprocessor_return_dict
        item = {
            "mask": _item["mask"][None].to(self.device),
            "image": _item["image"][None].to(self.device),
            "rgb_image": _item["rgb_image"][None].to(self.device),
            "rgb_image_mask": _item["rgb_image_mask"][None].to(self.device),
        }
        if pointmap is not None and preprocessor.pointmap_transform != (None,):
            item["pointmap"] = _item["pointmap"][None].to(self.device)
            item["rgb_pointmap"] = _item["rgb_pointmap"][None].to(self.device)
            item["pointmap_scale"] = _item["pointmap_scale"][None].to(self.device)
            item["pointmap_shift"] = _item["pointmap_shift"][None].to(self.device)
            item["rgb_pointmap_scale"] = _item["rgb_pointmap_scale"][None].to(self.device)
            item["rgb_pointmap_shift"] = _item["rgb_pointmap_shift"][None].to(self.device)
        return item

    def _clip_pointmap(self, pointmap: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        if self.clip_pointmap_beyond_scale is None:
            return pointmap
        pointmap_size = (pointmap.shape[1], pointmap.shape[2])
        if mask.dim() == 2:
            mask = mask.unsqueeze(0)
        mask_resized = torchvision.transforms.functional.resize(
            mask,
            pointmap_size,
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        ).squeeze(0)
        pointmap_flat = pointmap.reshape(3, -1)
        # Get valid points from the mask
        mask_bool = mask_resized.reshape(-1) > 0.5
        mask_points = pointmap_flat[:, mask_bool]
        mask_distance = mask_points.nanmedian(dim=-1).values[-1]
        logger.info(f"mask_distance: {mask_distance}")
        pointmap_clipped_flat = torch.where(
            pointmap_flat[2, ...].abs() > self.clip_pointmap_beyond_scale * mask_distance,
            torch.full_like(pointmap_flat, float("nan")),
            pointmap_flat,
        )
        pointmap_clipped = pointmap_clipped_flat.reshape(pointmap.shape)
        return pointmap_clipped
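
    # Illustrative behavior (hypothetical numbers, not executed): with
    # clip_pointmap_beyond_scale=2.0 and a median in-mask depth of 1.5, any
    # point whose |z| exceeds 3.0 is replaced by NaN across all three
    # coordinates, so far-background depth outliers drop out of later stages:
    #
    #     >>> pm = torch.tensor([[0., 0.], [0., 0.], [1.5, 9.0]]).reshape(3, 1, 2)
    #     >>> # with a mask covering both pixels, the z=9.0 column becomes NaN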

    def compute_pointmap(self, image, pointmap=None):
        loaded_image = self.image_to_float(image)
        loaded_image = torch.from_numpy(loaded_image)
        loaded_mask = loaded_image[..., -1]
        loaded_image = loaded_image.permute(2, 0, 1).contiguous()[:3]
        if pointmap is None:
            with torch.no_grad():
                with torch.autocast(device_type="cuda", dtype=self.dtype):
                    output = self.depth_model(loaded_image)
            pointmaps = output["pointmaps"]
            camera_convention_transform = (
                Transform3d()
                .rotate(camera_to_pytorch3d_camera(device=self.device).rotation)
                .to(self.device)
            )
            points_tensor = camera_convention_transform.transform_points(pointmaps)
            intrinsics = output.get("intrinsics", None)
        else:
            output = {}
            points_tensor = pointmap.to(self.device)
            if loaded_image.shape != points_tensor.shape:
                # Interpolate points_tensor to match the loaded_image size;
                # loaded_image has shape [3, H, W], we need H and W
                points_tensor = torch.nn.functional.interpolate(
                    points_tensor.permute(2, 0, 1).unsqueeze(0),
                    size=(loaded_image.shape[1], loaded_image.shape[2]),
                    mode="nearest",
                ).squeeze(0).permute(1, 2, 0)
            intrinsics = None
        points_tensor = points_tensor.permute(2, 0, 1)
        points_tensor = self._clip_pointmap(points_tensor, loaded_mask)
        # Prepare the point map tensor
        point_map_tensor = {
            "pointmap": points_tensor,
            "pts_color": loaded_image,
        }
        # If the depth model doesn't provide intrinsics, infer them
        if intrinsics is None:
            intrinsics_result = infer_intrinsics_from_pointmap(
                points_tensor.permute(1, 2, 0), device=self.device
            )
            point_map_tensor["intrinsics"] = intrinsics_result["intrinsics"]
        return point_map_tensor

    def run_post_optimization(self, mesh_glb, intrinsics, pose_dict, layout_input_dict):
        intrinsics = intrinsics.clone()
        # The layout optimizer appears to expect a single square-pixel focal
        # length, so use the smaller of fx, fy for both.
        fx, fy = intrinsics[0, 0], intrinsics[1, 1]
        re_focal = min(fx, fy)
        intrinsics[0, 0], intrinsics[1, 1] = re_focal, re_focal
        revised_quat, revised_t, revised_scale, final_iou, _, _ = (
            self.layout_post_optimization_method(
                mesh_glb,
                pose_dict["rotation"],
                pose_dict["translation"],
                pose_dict["scale"],
                layout_input_dict["rgb_image_mask"][0, 0],
                layout_input_dict["rgb_pointmap"][0].permute(1, 2, 0),
                intrinsics,
                min_size=518,
            )
        )
        return {
            "rotation": revised_quat,
            "translation": revised_t,
            "scale": revised_scale,
            "iou": final_iou,
        }

    def run(
        self,
        image: Union[None, Image.Image, np.ndarray],
        mask: Union[None, Image.Image, np.ndarray] = None,
        seed: Optional[int] = None,
        stage1_only=False,
        with_mesh_postprocess=True,
        with_texture_baking=True,
        with_layout_postprocess=True,
        use_vertex_color=False,
        stage1_inference_steps=None,
        stage2_inference_steps=None,
        use_stage1_distillation=False,
        use_stage2_distillation=False,
        pointmap=None,
        decode_formats=None,
        estimate_plane=False,
    ) -> dict:
        image = self.merge_image_and_mask(image, mask)
        with self.device:
            pointmap_dict = self.compute_pointmap(image, pointmap)
            pointmap = pointmap_dict["pointmap"]
            pts = type(self)._down_sample_img(pointmap)
            pts_colors = type(self)._down_sample_img(pointmap_dict["pts_color"])
            if estimate_plane:
                return self.estimate_plane(pointmap_dict, image)
            ss_input_dict = self.preprocess_image(
                image, self.ss_preprocessor, pointmap=pointmap
            )
            slat_input_dict = self.preprocess_image(image, self.slat_preprocessor)
            if seed is not None:
                torch.manual_seed(seed)
            ss_return_dict = self.sample_sparse_structure(
                ss_input_dict,
                inference_steps=stage1_inference_steps,
                use_distillation=use_stage1_distillation,
            )
            # We could probably use the decoder from the models themselves
            pointmap_scale = ss_input_dict.get("pointmap_scale", None)
            pointmap_shift = ss_input_dict.get("pointmap_shift", None)
            ss_return_dict.update(
                self.pose_decoder(
                    ss_return_dict,
                    scene_scale=pointmap_scale,
                    scene_shift=pointmap_shift,
                )
            )
            logger.info(
                f"Rescaling scale by {ss_return_dict['downsample_factor']} after downsampling"
            )
            ss_return_dict["scale"] = (
                ss_return_dict["scale"] * ss_return_dict["downsample_factor"]
            )
            if stage1_only:
                logger.info("Finished!")
                # Map integer voxel indices in [0, 64) to centered coordinates in [-0.5, 0.5)
                ss_return_dict["voxel"] = ss_return_dict["coords"][:, 1:] / 64 - 0.5
                return {
                    **ss_return_dict,
                    "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
                    "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),  # HxWx3
                }
            # return ss_return_dict
            coords = ss_return_dict["coords"]
            slat = self.sample_slat(
                slat_input_dict,
                coords,
                inference_steps=stage2_inference_steps,
                use_distillation=use_stage2_distillation,
            )
            outputs = self.decode_slat(
                slat, self.decode_formats if decode_formats is None else decode_formats
            )
            outputs = self.postprocess_slat_output(
                outputs, with_mesh_postprocess, with_texture_baking, use_vertex_color
            )
            glb = outputs.get("glb", None)
            try:
                if (
                    with_layout_postprocess
                    and self.layout_post_optimization_method is not None
                ):
                    assert glb is not None, "require mesh to run postprocessing"
                    logger.info("Running layout post optimization method...")
                    postprocessed_pose = self.run_post_optimization(
                        deepcopy(glb),
                        pointmap_dict["intrinsics"],
                        ss_return_dict,
                        ss_input_dict,
                    )
                    ss_return_dict.update(postprocessed_pose)
            except Exception:
                # loguru does not take logging's exc_info kwarg; logger.exception
                # records the traceback instead.
                logger.exception("Error during layout post optimization")
            # glb.export("sample.glb")
            logger.info("Finished!")
            return {
                **ss_return_dict,
                **outputs,
                "pointmap": pts.cpu().permute((1, 2, 0)),  # HxWx3
                "pointmap_colors": pts_colors.cpu().permute((1, 2, 0)),  # HxWx3
            }

    @staticmethod
    def _down_sample_img(img_3chw: torch.Tensor):
        # img_3chw: (3, H, W)
        x = img_3chw.unsqueeze(0)
        if x.dtype == torch.uint8:
            x = x.float() / 255.0
        max_side = max(x.shape[2], x.shape[3])
        scale_factor = 1.0
        # Heuristics. Note the elif chain: without it, an image with
        # max_side > 3800 would fall through to the > 1900 branch and get
        # 0.25 instead of 0.125.
        if max_side > 3800:
            scale_factor = 0.125
        elif max_side > 1900:
            scale_factor = 0.25
        elif max_side > 1200:
            scale_factor = 0.5
        x = torch.nn.functional.interpolate(
            x,
            scale_factor=(scale_factor, scale_factor),
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )  # -> (1, 3, H * scale_factor, W * scale_factor)
        return x.squeeze(0)
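
    # Downsampling heuristic at a glance (after the elif fix above):
    #     max_side > 3800 -> 0.125
    #     max_side > 1900 -> 0.25
    #     max_side > 1200 -> 0.5
    #     otherwise       -> 1.0 (interpolate still runs but is effectively
    #                             the identity at scale 1.0)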

    def estimate_plane(self, pointmap_dict, image, ground_area_threshold=0.25, min_points=100):
        assert image.shape[-1] == 4  # RGBA format
        # Extract mask from alpha channel
        floor_mask = (
            type(self)._down_sample_img(
                torch.from_numpy(image[..., -1]).float().unsqueeze(0)
            )[0]
            > 0.5
        )
        pts = type(self)._down_sample_img(pointmap_dict["pointmap"])
        # Get all points in 3D space (H, W, 3)
        pts_hwc = pts.cpu().permute((1, 2, 0))
        valid_mask_points = floor_mask.cpu().numpy()
        # Extract points that fall within the mask
        if valid_mask_points.any():
            # Get points within the mask
            masked_points = pts_hwc[valid_mask_points]
            # Filter out invalid points (zero points from depth estimation failures)
            valid_points_mask = torch.norm(masked_points, dim=-1) > 1e-6
            valid_points = masked_points[valid_points_mask]
            points = valid_points.numpy()
        else:
            points = np.array([]).reshape(0, 3)
        # Calculate area coverage and check the number of points
        overlap_area = estimate_plane_area(floor_mask)
        has_enough_points = len(points) >= min_points
        logger.info(f"Plane estimation: {len(points)} points, {overlap_area:.3f} area coverage")
        if overlap_area > ground_area_threshold and has_enough_points:
            try:
                mesh = o3d_plane_estimation(points)
                logger.info("Successfully estimated plane mesh")
            except Exception as e:
                logger.error(f"Failed to estimate plane: {e}")
                mesh = None
        else:
            logger.info(f"Skipping plane estimation: area={overlap_area:.3f}, points={len(points)}")
            mesh = None
        return {
            "glb": mesh,
            "translation": torch.tensor([[0.0, 0.0, 0.0]]),
            "scale": torch.tensor([[1.0, 1.0, 1.0]]),
            "rotation": torch.tensor([[1.0, 0.0, 0.0, 0.0]]),
        }
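
# Illustrative end-to-end usage (a minimal sketch; constructor arguments other
# than depth_model are inherited from InferencePipeline and the exact config
# is an assumption here, as is the input filename):
#
#     pipeline = InferencePipelinePointMap(..., depth_model=my_depth_model)
#     result = pipeline.run(Image.open("input.png").convert("RGBA"), seed=0)
#     result["glb"].export("output.glb")   # mesh, when stage1_only=False
#     result["pointmap"]                   # HxWx3 downsampled point map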