# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Coordinate transform utils."""

from typing import List

import torch
import MinkowskiEngine as Me

from ..reconstruction.frustum import (
    generate_frustum,
    compute_camera2frustum_transform,
)


def transform_feat3d_coordinates(
    feat3d, intrinsic,
    image_size=(120, 160),
    depth_min=0.4, depth_max=6.0,
    voxel_size=0.03
):
    """
    Transform feat3d coordinates to match the Uni3D coordinate system.

    Args:
        feat3d: Me.SparseTensor from occupancy-aware lifting
        intrinsic: Camera intrinsic matrix (4x4)
        image_size: tuple of (height, width)
        depth_min, depth_max: near and far bounds of the frustum depth range
        voxel_size: voxel size in meters
    Returns:
        Me.SparseTensor with transformed coordinates
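
    Example (a hedged usage sketch; lifted_feat3d and cam_intrinsic are
    illustrative caller-supplied placeholders, not values produced here)::

        aligned = transform_feat3d_coordinates(
            lifted_feat3d, cam_intrinsic,
            image_size=(120, 160), depth_min=0.4, depth_max=6.0,
            voxel_size=0.03,
        )
        # aligned.C is expressed in the Uni3D-aligned frustum voxel grid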
    """
    device = feat3d.device
    coords = feat3d.C.clone()

    # step 1: Apply coordinate flip (as done in BackProjection line 33)
    coords[:, 1:3] = 256 - coords[:, 1:3]  # flip x, y coordinates
    batch_indices = coords[:, 0].unique()

    compute_once = True
    if intrinsic.dim() == 3:  # batched intrinsics
        # check if all intrinsics are identical
        if len(batch_indices) > 1:
            compute_once = torch.allclose(intrinsic[0:1].expand_as(intrinsic), intrinsic, atol=1e-6)
        intrinsic_ref = intrinsic[0] if compute_once else None
    else:
        intrinsic_ref = intrinsic

    if compute_once:
        intrinsic_batch = intrinsic_ref
        intrinsic_inverse = torch.inverse(intrinsic_batch)
        frustum = generate_frustum(image_size, intrinsic_inverse, depth_min, depth_max)
        camera2frustum, padding_offsets = compute_camera2frustum_transform(
            frustum.to(device), voxel_size,
            frustum_dimensions=torch.tensor([256, 256, 256], device=device)
        )
        # pre-move to device and pre-compute inverse
        camera2frustum = camera2frustum.to(device)
        padding_offsets = padding_offsets.to(device)
        camera2frustum_inv = torch.inverse(camera2frustum).float()
        ones_offset = torch.tensor([1., 1., 1.], device=device)

    transformed_coords_list = []

    for batch_idx in batch_indices:
        batch_mask = coords[:, 0] == batch_idx
        batch_coords = coords[batch_mask, 1:].float()  # convert to float once per batch

        # use pre-computed values or compute per-batch
        if not compute_once:
            intrinsic_batch = intrinsic[int(batch_idx)]
            intrinsic_inverse = torch.inverse(intrinsic_batch)
            frustum = generate_frustum(image_size, intrinsic_inverse, depth_min, depth_max)
            camera2frustum, padding_offsets = compute_camera2frustum_transform(
                frustum.to(device), voxel_size,
                frustum_dimensions=torch.tensor([256, 256, 256], device=device)
            )
            camera2frustum = camera2frustum.float().to(device)
            padding_offsets = padding_offsets.to(device)
            camera2frustum_inv = torch.inverse(camera2frustum).float()
            ones_offset = torch.tensor([1., 1., 1.], device=device)

        # convert voxel coordinates to world coordinates (reverse of BackProjection)
        batch_coords_adjusted = batch_coords - padding_offsets - ones_offset

        # convert to homogeneous coordinates
        homogeneous_coords = torch.cat([
            batch_coords_adjusted,
            torch.ones(batch_coords_adjusted.shape[0], 1, device=device)
        ], dim=1)  # [N_batch, 4]

        # map frustum voxel coords to camera space and back into frustum space
        world_coords = torch.mm(camera2frustum_inv, homogeneous_coords.t())
        final_coords_homog = torch.mm(camera2frustum.float(), world_coords.float())
        final_coords = final_coords_homog.t()[:, :3]

        # add padding offsets (as done in SparseProjection.projection())
        final_coords = final_coords + padding_offsets

        # add batch index back
        batch_column = torch.full(
            (final_coords.shape[0], 1),
            batch_idx,
            device=device,
            dtype=torch.float32
        )
        final_batch_coords = torch.cat([batch_column, final_coords], dim=1)
        transformed_coords_list.append(final_batch_coords)

    transformed_coords = torch.cat(transformed_coords_list, dim=0)

    transformed_feat3d = Me.SparseTensor(
        features=feat3d.F,
        coordinates=transformed_coords.int(),
        tensor_stride=feat3d.tensor_stride,
        quantization_mode=feat3d.quantization_mode
    )

    return transformed_feat3d


def fuse_sparse_tensors(tensor1: Me.SparseTensor, tensor2: Me.SparseTensor) -> Me.SparseTensor:
    """
    Efficiently fuse two sparse tensors by taking the union of their
    coordinates and concatenating their features channel-wise.

    Args:
        tensor1 (Me.SparseTensor): First sparse tensor
        tensor2 (Me.SparseTensor): Second sparse tensor

    Returns:
        Me.SparseTensor: Fused sparse tensor with concatenated features
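
    Example (a hedged sketch; geo_feat and sem_feat are hypothetical inputs
    sharing the same tensor_stride and quantization_mode)::

        fused = fuse_sparse_tensors(geo_feat, sem_feat)
        # fused.F has geo_feat.F.shape[1] + sem_feat.F.shape[1] channels;
        # a coordinate present in only one input keeps zeros in the other
        # input's channel slice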
    """
    device = tensor1.device
    dtype = tensor1.F.dtype

    # get coordinates and features
    coords1, feats1 = tensor1.C, tensor1.F
    coords2, feats2 = tensor2.C, tensor2.F

    feat_dim1, feat_dim2 = feats1.shape[1], feats2.shape[1]
    fused_feat_dim = feat_dim1 + feat_dim2

    # concatenate coordinates and create source tracking
    all_coords = torch.cat([coords1, coords2], dim=0)
    n_coords1 = coords1.shape[0]

    # deduplicate coordinate rows; return_inverse maps every input row to the
    # index of its unique coordinate
    unique_coords, inverse_indices = torch.unique(all_coords, dim=0, return_inverse=True)
    n_unique = unique_coords.shape[0]

    # split inverse indices for each tensor
    inv_indices_1 = inverse_indices[:n_coords1]
    inv_indices_2 = inverse_indices[n_coords1:]

    # pre-allocate with zeros for automatic padding
    fused_features = torch.zeros(n_unique, fused_feat_dim, device=device, dtype=dtype)

    # tensor1 features go to positions [0:feat_dim1]
    fused_features[inv_indices_1, :feat_dim1] = feats1

    # tensor2 features go to positions [feat_dim1:feat_dim1+feat_dim2]
    fused_features[inv_indices_2, feat_dim1:] = feats2
    fused_tensor = Me.SparseTensor(
        features=fused_features,
        coordinates=unique_coords.int(),
        tensor_stride=tensor1.tensor_stride,
        quantization_mode=tensor1.quantization_mode
    )
    return fused_tensor


def generate_multiscale_feat3d(transformed_feat3d: Me.SparseTensor) -> List[Me.SparseTensor]:
    """
    Generate multi-scale sparse 3D features from transformed_feat3d to match
    the sparse_multi_scale_features structure.

    Args:
        transformed_feat3d (Me.SparseTensor):
            Input sparse tensor from occupancy-aware lifting (256 grid)

    Returns:
        List[Me.SparseTensor]: Multi-scale sparse tensors at scales
        [1/2, 1/4, 1/8], corresponding to [128, 64, 32] grid sizes
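
    Example (a hedged sketch; grid sizes assume the 256^3 input noted above)::

        multi_scale = generate_multiscale_feat3d(transformed_feat3d)
        # multi_scale[0] has tensor stride 2 (~128 grid),
        # multi_scale[1] stride 4 (~64), multi_scale[2] stride 8 (~32)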
    """
    device = transformed_feat3d.device

    # use consistent stride 2 for progressive downsampling
    # this ensures proper 1/2, 1/4, 1/8 scaling from original 256 grid
    pooling_op = Me.MinkowskiMaxPooling(
        kernel_size=3,
        stride=2,
        dimension=3
    ).to(device)

    multi_scale_feat3d = []
    current_tensor = transformed_feat3d
    target_strides = [2, 4, 8]  # Expected final strides for each scale

    # generate features at each scale by progressive pooling with stride 2
    for target_stride in target_strides:
        # apply stride-2 pooling to get next scale
        pooled_tensor = pooling_op(current_tensor)

        # ensure the tensor stride matches the expected value (2, 4, 8 relative
        # to the original grid); MinkowskiEngine reports the stride as a
        # per-dimension list, so compare against the expanded form
        if list(pooled_tensor.tensor_stride) != [target_stride] * 3:
            pooled_tensor = Me.SparseTensor(
                features=pooled_tensor.F,
                coordinates=pooled_tensor.C,
                tensor_stride=target_stride,
                quantization_mode=pooled_tensor.quantization_mode
            )

        multi_scale_feat3d.append(pooled_tensor)

        # use pooled tensor as input for next scale (progressive downsampling)
        # this gives us: 256 → 128 → 64 → 32 grid sizes
        current_tensor = pooled_tensor

    return multi_scale_feat3d
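

if __name__ == "__main__":
    # Minimal smoke test (a hedged sketch, not part of the pipeline): builds
    # tiny random sparse tensors and exercises the helpers above. The
    # intrinsic below is an illustrative placeholder, not a calibrated camera
    # matrix. Run via `python -m <package>.<module>` so the relative import
    # above resolves.
    n_points = 64
    coords = torch.cat([
        torch.zeros(n_points, 1, dtype=torch.int32),               # batch index 0
        torch.randint(0, 256, (n_points, 3), dtype=torch.int32),   # x, y, z voxels
    ], dim=1)
    feat_a = Me.SparseTensor(features=torch.rand(n_points, 8), coordinates=coords)
    feat_b = Me.SparseTensor(features=torch.rand(n_points, 16), coordinates=coords)

    intrinsic = torch.tensor([
        [138.6, 0.0, 80.0, 0.0],
        [0.0, 138.6, 60.0, 0.0],
        [0.0, 0.0, 1.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ])

    aligned = transform_feat3d_coordinates(feat_a, intrinsic)
    fused = fuse_sparse_tensors(feat_a, feat_b)
    pyramid = generate_multiscale_feat3d(fused)
    print(aligned.C.shape, fused.F.shape, [t.tensor_stride for t in pyramid])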