Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| """ | |
| Utility functions for point map processing and intrinsics inference. | |
| Extracted from moge library for use in sam3d_objects pipeline. | |
| """ | |
| from typing import Optional, Tuple, Union | |
| import torch | |
| import utils3d | |
| # Import directly from moge for exact compatibility | |
| from moge.utils.geometry_torch import ( | |
| normalized_view_plane_uv, | |
| recover_focal_shift, | |
| ) | |
| from moge.utils.geometry_numpy import ( | |
| solve_optimal_focal_shift, | |
| solve_optimal_shift, | |
| ) | |
def infer_intrinsics_from_pointmap(
    points: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    fov_x: Optional[Union[float, torch.Tensor]] = None,
    mask_threshold: float = 0.5,
    force_projection: bool = False,
    apply_mask: bool = False,
    device: Optional[Union[str, torch.device]] = None,
) -> dict:
    """
    Infer camera intrinsics from a point map.

    Mirrors the inference logic of the moge library: a z-shift (and, when
    ``fov_x`` is not supplied, a focal length) is recovered from the point
    map via ``recover_focal_shift``, then intrinsics and a depth map are
    derived from them.

    Args:
        points: Point map tensor of shape (B, H, W, 3) or (H, W, 3).
        mask: Optional mask tensor of shape (B, H, W) or (H, W). When omitted,
            every pixel is considered valid (subject to the finite check).
        fov_x: Optional horizontal field of view in degrees (float or tensor).
            If None, the focal length is inferred from the points.
        mask_threshold: Values of ``mask`` strictly greater than this are
            treated as valid.
        force_projection: If True, recompute the points from the depth map
            and the intrinsics instead of just shifting the input points.
        apply_mask: If True, set points and depth outside the binary mask
            to ``inf``.
        device: Device whose type selects the autocast context. May be a
            ``torch.device`` or a device string such as ``"cuda"``. If None,
            ``points.device`` is used.

    Returns:
        Dictionary containing:
            - 'points': Camera-space points.
            - 'intrinsics': Camera intrinsics matrix built by
              ``utils3d.torch.intrinsics_from_focal_center`` with the
              principal point at (0.5, 0.5).
            - 'depth': Depth map (z of the input points plus the shift).
            - 'mask': Binary mask (thresholded mask AND per-pixel finite
              check on ``points``).
        If the input had no batch dimension, the leading batch dimension is
        squeezed from every returned tensor.
    """
    # Accept both torch.device objects and plain strings; a bare string has
    # no `.type` attribute, which the autocast call below relies on.
    device = points.device if device is None else torch.device(device)

    # Handle batch dimension: work on (B, H, W, 3) internally.
    squeeze_batch = points.dim() == 3
    if squeeze_batch:
        points = points.unsqueeze(0)
        if mask is not None:
            mask = mask.unsqueeze(0)

    height, width = points.shape[1:3]
    aspect_ratio = width / height
    # Length of the image diagonal in units of the image height.
    diag_factor = (1 + aspect_ratio ** 2) ** 0.5

    # Always process the output in fp32 precision
    with torch.autocast(device_type=device.type, dtype=torch.float32):
        points = points.float()
        if isinstance(mask, torch.Tensor):
            mask = mask.float()
        if isinstance(fov_x, torch.Tensor):
            fov_x = fov_x.float()

        if mask is not None:
            mask_binary = mask > mask_threshold
        else:
            mask_binary = torch.ones_like(points[..., 0], dtype=torch.bool)

        # Add finite check to handle NaN and inf values
        finite_mask = torch.isfinite(points).all(dim=-1)
        mask_binary = mask_binary & finite_mask

        # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
        if fov_x is None:
            # BUG: Recover focal shift numpy method has flipped outputs: https://github.com/microsoft/MoGe/issues/110
            # NOTE(review): the unpack order here intentionally differs from
            # the `fov_x` branch below (where shift is the second element).
            # Do not "fix" this without confirming against the linked issue.
            shift, focal = recover_focal_shift(points, mask_binary)
        else:
            half_fov_x = torch.deg2rad(
                torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2
            )
            focal = aspect_ratio / diag_factor / torch.tan(half_fov_x)
            if focal.ndim == 0:
                # A scalar FoV applies to every item in the batch.
                focal = focal[None].expand(points.shape[0])
            _, shift = recover_focal_shift(points, mask_binary, focal=focal)

        # Convert the diagonal-relative focal into per-axis focal lengths.
        fx = focal / 2 * diag_factor / aspect_ratio
        fy = focal / 2 * diag_factor
        intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
        depth = points[..., 2] + shift[..., None, None]

        # If projection constraint is forced, recompute the point map using the actual depth map
        if force_projection:
            points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
        else:
            # Only the z coordinate is shifted; x and y are left untouched.
            zeros = torch.zeros_like(shift)
            z_offset = torch.stack([zeros, zeros, shift], dim=-1)[..., None, None, :]
            points = points + z_offset

        # Apply mask if needed
        if apply_mask:
            points = torch.where(mask_binary[..., None], points, torch.inf)
            depth = torch.where(mask_binary, depth, torch.inf)

    if squeeze_batch:
        points = points.squeeze(0)
        intrinsics = intrinsics.squeeze(0)
        depth = depth.squeeze(0)
        mask_binary = mask_binary.squeeze(0)

    return {
        'points': points,
        'intrinsics': intrinsics,
        'depth': depth,
        'mask': mask_binary,
    }