File size: 4,409 Bytes
7734c01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) Meta Platforms, Inc. and affiliates.
"""
Utility functions for point map processing and intrinsics inference.
Extracted from moge library for use in sam3d_objects pipeline.
"""

from typing import Optional, Tuple, Union
import torch
import utils3d
# Import directly from moge for exact compatibility
from moge.utils.geometry_torch import (
    normalized_view_plane_uv,
    recover_focal_shift,
)
from moge.utils.geometry_numpy import (
    solve_optimal_focal_shift,
    solve_optimal_shift,
)


def infer_intrinsics_from_pointmap(
    points: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
    fov_x: Optional[Union[float, torch.Tensor]] = None,
    mask_threshold: float = 0.5,
    force_projection: bool = False,
    apply_mask: bool = False,
    device: Optional[torch.device] = None
) -> dict:
    """
    Infer camera intrinsics from a point map.

    Mirrors the moge library's inference logic exactly.

    Args:
        points: Point map of shape (B, H, W, 3) or (H, W, 3).
        mask: Optional validity mask of shape (B, H, W) or (H, W).
        fov_x: Optional horizontal field of view in degrees; inferred from
            the point map when omitted.
        mask_threshold: Cutoff used to binarize a soft mask.
        force_projection: If True, rebuild the point map from the recovered
            depth and intrinsics so it satisfies the pinhole model exactly.
        apply_mask: If True, invalid pixels in the returned points/depth are
            set to +inf.
        device: Device used for the autocast context; defaults to
            ``points.device``.

    Returns:
        Dict with keys 'points' (camera-space point map), 'intrinsics'
        (normalized camera matrix), 'depth' (depth map), and 'mask'
        (boolean validity mask).
    """
    if device is None:
        device = points.device

    # Promote unbatched (H, W, 3) input to a batch of one; undone at the end.
    had_no_batch = points.dim() == 3
    if had_no_batch:
        points = points.unsqueeze(0)
        if mask is not None:
            mask = mask.unsqueeze(0)

    height, width = points.shape[1:3]
    aspect_ratio = width / height

    # All numerics below run in fp32, regardless of any surrounding autocast.
    with torch.autocast(device_type=device.type, dtype=torch.float32):
        points = points.float()
        if isinstance(mask, torch.Tensor):
            mask = mask.float()
        if isinstance(fov_x, torch.Tensor):
            fov_x = fov_x.float()

        if mask is None:
            valid = torch.ones_like(points[..., 0], dtype=torch.bool)
        else:
            valid = mask > mask_threshold

        # Exclude NaN/inf points so they cannot poison the focal/shift solve.
        valid = valid & torch.isfinite(points).all(dim=-1)

        # Recover the camera-space frame. "focal" is relative to half the
        # image diagonal.
        if fov_x is None:
            # NOTE: the unpack order is intentionally flipped — the upstream
            # solver's outputs are swapped: https://github.com/microsoft/MoGe/issues/110
            shift, focal = recover_focal_shift(points, valid)
        else:
            half_fov = torch.deg2rad(
                torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2
            )
            focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(half_fov)
            if focal.ndim == 0:
                focal = focal[None].expand(points.shape[0])
            _, shift = recover_focal_shift(points, valid, focal=focal)

        # Convert diagonal-relative focal into normalized fx / fy.
        half_diag = (1 + aspect_ratio ** 2) ** 0.5
        fx = focal / 2 * half_diag / aspect_ratio
        fy = focal / 2 * half_diag
        intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
        depth = points[..., 2] + shift[..., None, None]

        if force_projection:
            # Re-derive points from depth so they obey the projection exactly.
            points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
        else:
            # Otherwise just translate the point map along z by the recovered shift.
            zeros = torch.zeros_like(shift)
            z_offset = torch.stack([zeros, zeros, shift], dim=-1)[..., None, None, :]
            points = points + z_offset

        if apply_mask:
            points = torch.where(valid[..., None], points, torch.inf)
            depth = torch.where(valid, depth, torch.inf)

    if had_no_batch:
        points = points.squeeze(0)
        intrinsics = intrinsics.squeeze(0)
        depth = depth.squeeze(0)
        valid = valid.squeeze(0)

    return {
        'points': points,
        'intrinsics': intrinsics,
        'depth': depth,
        'mask': valid,
    }