xinjie.wang
update
6bc32b6
# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
Utility functions for point-map processing and camera-intrinsics inference.
Extracted from the MoGe library for use in the sam3d_objects pipeline.
"""
from typing import Optional, Tuple, Union
import torch
import utils3d
# Import directly from moge for exact compatibility
from moge.utils.geometry_torch import (
normalized_view_plane_uv,
recover_focal_shift,
)
from moge.utils.geometry_numpy import (
solve_optimal_focal_shift,
solve_optimal_shift,
)
def infer_intrinsics_from_pointmap(
    points: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    fov_x: Optional[Union[float, torch.Tensor]] = None,
    mask_threshold: float = 0.5,
    force_projection: bool = False,
    apply_mask: bool = False,
    device: Optional[Union[str, torch.device]] = None
) -> dict:
    """
    Infer camera intrinsics from a point map.

    Follows the inference logic of the moge library: a z-shift (and, when
    ``fov_x`` is not given, a focal length) is recovered from the point map,
    then intrinsics, depth and shifted points are derived from them.

    Args:
        points: Point map tensor of shape (B, H, W, 3) or (H, W, 3). A 3-D
            input is treated as a single unbatched map.
        mask: Optional mask tensor of shape (B, H, W) or (H, W). Values
            greater than ``mask_threshold`` are considered valid. If None,
            every pixel is initially considered valid.
        fov_x: Optional horizontal field of view in degrees (scalar or
            tensor). A scalar is broadcast over the batch. If None, the focal
            length is recovered from the points instead.
        mask_threshold: Threshold used to binarize ``mask``.
        force_projection: If True, recompute the points from the depth map
            and the intrinsics instead of only shifting the input points.
        apply_mask: If True, set points and depth outside the binary mask to
            ``inf``.
        device: Device (``torch.device`` or device string) whose type selects
            the autocast backend. Tensors are NOT moved to this device. If
            None, ``points.device`` is used.

    Returns:
        Dictionary containing:
            - 'points': Camera-space points
            - 'intrinsics': Camera intrinsics matrix
            - 'depth': Depth map
            - 'mask': Binary mask (input mask AND "all coordinates finite")
        The leading batch dimension is removed from every entry when the
        input ``points`` was unbatched.
    """
    # Only ``device.type`` is consumed below. Normalising through
    # ``torch.device`` lets callers pass a plain string such as "cuda".
    device = points.device if device is None else torch.device(device)

    # Handle batch dimension
    squeeze_batch = points.dim() == 3
    if squeeze_batch:
        points = points.unsqueeze(0)
        if mask is not None:
            mask = mask.unsqueeze(0)

    height, width = points.shape[1:3]
    aspect_ratio = width / height
    # Diagonal of an image whose height is 1 and whose width is aspect_ratio.
    diagonal = (1 + aspect_ratio ** 2) ** 0.5

    # Always process the output in fp32 precision
    with torch.autocast(device_type=device.type, dtype=torch.float32):
        points = points.float() if isinstance(points, torch.Tensor) else points
        mask = mask.float() if isinstance(mask, torch.Tensor) else mask
        fov_x = fov_x.float() if isinstance(fov_x, torch.Tensor) else fov_x

        if mask is not None:
            mask_binary = mask > mask_threshold
        else:
            mask_binary = torch.ones_like(points[..., 0], dtype=torch.bool)

        # Exclude pixels with NaN/inf coordinates from the valid set.
        finite_mask = torch.isfinite(points).all(dim=-1)
        mask_binary = mask_binary & finite_mask

        # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
        if fov_x is None:
            # BUG: Recover focal shift numpy method has flipped outputs: https://github.com/microsoft/MoGe/issues/110
            # NOTE(review): this unpacks (shift, focal) while the fov_x branch
            # below reads the shift from the SECOND return value. Kept as-is
            # per the linked issue -- confirm against the installed moge version.
            shift, focal = recover_focal_shift(points, mask_binary)
        else:
            half_fov = torch.deg2rad(
                torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2
            )
            focal = aspect_ratio / diagonal / torch.tan(half_fov)
            if focal.ndim == 0:
                # Scalar field of view: share it across the whole batch.
                focal = focal[None].expand(points.shape[0])
            _, shift = recover_focal_shift(points, mask_binary, focal=focal)

        # Convert the diagonal-relative focal into per-axis focal lengths.
        fx = focal / 2 * diagonal / aspect_ratio
        fy = focal / 2 * diagonal
        # Principal point fixed at (0.5, 0.5); presumably normalized image
        # coordinates -- confirm against utils3d.
        intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
        depth = points[..., 2] + shift[..., None, None]

        # If projection constraint is forced, recompute the point map using the actual depth map
        if force_projection:
            points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
        else:
            # Apply the recovered shift to the z coordinate only.
            zeros = torch.zeros_like(shift)
            shift_stacked = torch.stack([zeros, zeros, shift], dim=-1)[..., None, None, :]
            points = points + shift_stacked

        # Apply mask if needed
        if apply_mask:
            points = torch.where(mask_binary[..., None], points, torch.inf)
            depth = torch.where(mask_binary, depth, torch.inf)

    outputs = {
        'points': points,
        'intrinsics': intrinsics,
        'depth': depth,
        'mask': mask_binary,
    }
    if squeeze_batch:
        # Drop the batch dimension that was added for an unbatched input.
        outputs = {key: value.squeeze(0) for key, value in outputs.items()}
    return outputs