File size: 10,722 Bytes
f4a0919
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
# Copyright (c) 2024, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper utils."""

import copy
import itertools
import json
import logging
import os
import warnings
from contextlib import contextmanager
from functools import wraps
from typing import Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fvcore.transforms.transform import Transform


def adjust_intrinsic(
    intrinsic: Union[np.ndarray, torch.Tensor],
    intrinsic_image_dim: Tuple,
    image_dim: Tuple
) -> Union[np.ndarray, torch.Tensor]:
    """
    Adjust intrinsic camera parameters for image dimension changes.

    Focal lengths (entries ``[0, 0]`` and ``[1, 1]``) are scaled by
    ``after / before``; the principal point (entries ``[0, 2]`` and ``[1, 2]``)
    is scaled by ``(after - 1) / (before - 1)`` to account for cropping/padding.
    Entries are addressed on the last two axes, so both a single matrix and a
    stack of matrices of shape ``(..., rows, cols)`` are accepted, for numpy
    arrays as well as torch tensors.

    Args:
        intrinsic: Camera intrinsic matrix, or a stack of them
            (numpy array or torch tensor)
        intrinsic_image_dim: Original image dimensions (width, height)
        image_dim: Target image dimensions (width, height)

    Returns:
        Adjusted intrinsic matrix (same type as input). When both dimensions
        compare equal the input object itself is returned, not a copy.

    Raises:
        TypeError: if ``intrinsic`` is neither a numpy array nor a torch tensor.
        ZeroDivisionError: if an original width or height equals 1.
    """
    if intrinsic_image_dim == image_dim:
        return intrinsic

    # Calculate scaling factors
    width_before, height_before = intrinsic_image_dim[0], intrinsic_image_dim[1]
    width_after, height_after = image_dim[0], image_dim[1]

    width_scale = float(width_after) / float(width_before)
    height_scale = float(height_after) / float(height_before)
    width_offset_scale = float(width_after - 1) / float(width_before - 1)
    height_offset_scale = float(height_after - 1) / float(height_before - 1)

    # handle numpy array case
    if isinstance(intrinsic, np.ndarray):
        intrinsic_return = np.copy(intrinsic)

        # Multiply out-of-place and assign back: with the Ellipsis index a
        # 0-d view is returned, and assignment (unlike an in-place ufunc)
        # casts the float product into integer arrays the same way scalar
        # element assignment does.
        intrinsic_return[..., 0, 0] = intrinsic_return[..., 0, 0] * width_scale
        intrinsic_return[..., 1, 1] = intrinsic_return[..., 1, 1] * height_scale
        # account for cropping/padding here
        intrinsic_return[..., 0, 2] = intrinsic_return[..., 0, 2] * width_offset_scale
        intrinsic_return[..., 1, 2] = intrinsic_return[..., 1, 2] * height_offset_scale

        return intrinsic_return

    # handle torch tensor case (single matrix or any leading batch dims)
    if isinstance(intrinsic, torch.Tensor):
        intrinsic_return = intrinsic.clone()

        intrinsic_return[..., 0, 0] *= width_scale
        intrinsic_return[..., 1, 1] *= height_scale
        # account for cropping/padding here
        intrinsic_return[..., 0, 2] *= width_offset_scale
        intrinsic_return[..., 1, 2] *= height_offset_scale

        return intrinsic_return

    raise TypeError(f"Unsupported input type: {type(intrinsic)}.")


class ModelInputResize(Transform):
    """Resize and pad the model input.

    No resampling is performed: ``apply_image`` only pads the last two axes
    at the bottom/right with ``pad_value`` so that they become multiples of
    ``size_divisibility``. Coordinates and segmentations pass through
    unchanged.
    """

    def __init__(self, size_divisibility: int = 0, pad_value: float = 0):
        """Initialize model input resize transform.

        Args:
            size_divisibility: pad the last two axes up to a multiple of this
                value; values <= 1 disable padding.
            pad_value: constant written into the padded region.
        """
        super().__init__()
        self.size_divisibility = size_divisibility
        self.pad_value = pad_value

    def apply_coords(self, coords):
        """ Apply transforms to the coordinates (returned unchanged). """
        return coords

    def apply_image(self, array: torch.Tensor) -> torch.Tensor:
        """ Apply transforms to the image.

        Pads the last two axes of ``array`` at the end (bottom/right) so each
        is rounded up to a multiple of ``size_divisibility``.
        """
        assert len(array) > 0
        height, width = int(array.shape[-2]), int(array.shape[-1])

        padded_height, padded_width = height, width
        if self.size_divisibility > 1:
            stride = self.size_divisibility
            # Round up to the next multiple of ``stride`` with plain ints, so
            # no tensor is allocated on the input's device just to size the pad.
            padded_height = (height + stride - 1) // stride * stride
            padded_width = (width + stride - 1) // stride * stride

        # F.pad order for the last two dims: (left, right, top, bottom).
        padding_size = [0, padded_width - width, 0, padded_height - height]
        return F.pad(array, padding_size, value=self.pad_value)

    def apply_segmentation(self, array: torch.Tensor) -> torch.Tensor:
        """ Apply transforms to the segmentation (returned unchanged). """
        return array


@contextmanager
def _ignore_torch_cuda_oom():
    """
    A context which ignores CUDA OOM exception from pytorch.
    """
    try:
        yield
    except RuntimeError as e:
        # NOTE: the string may change?
        if "CUDA out of memory. " in str(e):
            pass
        else:
            raise

def retry_if_cuda_oom(func):
    """
    Makes a function retry itself after encountering
    pytorch's CUDA OOM error.
    It will first retry after calling `torch.cuda.empty_cache()`.

    If that still fails, it will then retry by trying to convert inputs to CPUs.
    In this case, it expects the function to dispatch to CPU implementation.
    The return values may become CPU tensors as well and it's user's
    responsibility to convert it back to CUDA tensor if needed.

    Args:
        func: a stateless callable that takes tensor-like objects as arguments

    Returns:
        a callable which retries `func` if OOM is encountered.

    Examples:
    ::
        output = retry_if_cuda_oom(some_torch_function)(input1, input2)
        # output may be on CPU even if inputs are on GPU

    Note:
        1. When converting inputs to CPU, it will only look at each argument and check
           if it has `.device` and `.to` for conversion. Nested structures of tensors
           are not supported.

        2. Since the function might be called more than once, it has to be
           stateless.

        3. An OOM on the final CPU attempt is not caught and propagates.
    """

    def maybe_to_cpu(x):
        """Return ``x`` moved to CPU if it looks like a CUDA tensor, else ``x``."""
        try:
            like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to")
        except AttributeError:
            like_gpu_tensor = False
        if like_gpu_tensor:
            return x.to(device="cpu")
        return x

    @wraps(func)
    def wrapped(*args, **kwargs):
        """Call ``func``; on CUDA OOM retry after emptying the cache, then on CPU."""
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Reached only if the first attempt hit a CUDA OOM: clear cache and retry.
        torch.cuda.empty_cache()
        with _ignore_torch_cuda_oom():
            return func(*args, **kwargs)

        # Try on CPU. This slows down the code significantly, therefore log a notice
        # through this module's logger (lazy %-formatting, same message text).
        logging.getLogger(__name__).info(
            "Attempting to copy inputs of %s to CPU due to CUDA OOM", func
        )
        new_args = tuple(maybe_to_cpu(x) for x in args)
        new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()}
        return func(*new_args, **new_kwargs)

    return wrapped


def prepare_kept_mapping(model, cfg, dataset, frustum_mask=None, intrinsic=None):
    """
    Back-project through ``model`` and return the ``(kept, mapping)`` pair.

    For any dataset name other than ``"front3d"`` the intrinsic is first
    rescaled with ``adjust_intrinsic`` from ``cfg.dataset.target_size`` to
    ``cfg.dataset.reduced_target_size``; ``"front3d"`` intrinsics are used
    as given.

    Args:
        model: object exposing ``back_projection(shape, intrinsic, frustum_mask)``
        cfg: configuration object; reads ``cfg.dataset.target_size`` and
            ``cfg.dataset.reduced_target_size``
        dataset: dataset name (``"front3d"`` skips the intrinsic rescale)
        frustum_mask: optional frustum mask, forwarded unchanged
        intrinsic: intrinsic matrix (numpy array or torch tensor when rescaled)

    Returns:
        tuple: (kept, mapping) as returned by ``model.back_projection``
    """
    dataset_cfg = cfg.dataset

    if dataset != "front3d":
        source_dim = tuple(dataset_cfg.target_size)
        target_dim = tuple(dataset_cfg.reduced_target_size)
        intrinsic = adjust_intrinsic(intrinsic, source_dim, target_dim)

    # reduced_target_size is reversed (adjust_intrinsic treats it as
    # (width, height)), then 256 is appended as the last grid dimension.
    grid_shape = (*dataset_cfg.reduced_target_size[::-1], 256)
    kept, mapping = model.back_projection(grid_shape, intrinsic, frustum_mask)
    return kept, mapping


def get_kept_mapping(model, cfg, batch, device):
    """
    Compute ``(kept, mapping)`` for one batch via ``prepare_kept_mapping``.

    ``batch["frustum_mask"]`` is moved to ``device``; ``batch["intrinsic"]``
    is converted with ``.float()`` and moved to ``device``. The dataset name
    is taken from ``cfg.dataset.name``.

    Args:
        model: The model instance with back_projection method
        cfg: Configuration object (reads ``cfg.dataset.name`` and the sizes
            used by ``prepare_kept_mapping``)
        batch: Mapping with ``"frustum_mask"`` and ``"intrinsic"`` entries
        device: Device to place tensors on

    Returns:
        tuple: (kept, mapping) tensors
    """
    mask_on_device = batch["frustum_mask"].to(device)
    intrinsic_on_device = batch["intrinsic"].float().to(device)

    kept, mapping = prepare_kept_mapping(
        model,
        cfg,
        cfg.dataset.name,
        frustum_mask=mask_on_device,
        intrinsic=intrinsic_on_device,
    )
    return kept, mapping


class _ChannelLayerNorm(nn.Module):
    """LayerNorm over the channel axis (dim 1) of a channels-first tensor.

    For an input of shape ``(N, C, ...)`` every position is normalized across
    its ``C`` channels, then scaled and shifted by learnable per-channel
    ``weight`` and ``bias``.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6):
        """Create per-channel affine parameters (weight=1, bias=0)."""
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps
        self.normalized_shape = (num_channels,)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` across dim 1 and apply the per-channel affine."""
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        x = (x - mean) / torch.sqrt(var + self.eps)
        # Broadcast the (C,) parameters over every trailing dimension.
        shape = (1, -1) + (1,) * (x.dim() - 2)
        return self.weight.view(shape) * x + self.bias.view(shape)


def get_norm(norm, out_channels):
    """
    Build a normalization layer for ``out_channels`` channels.

    Args:
        norm (str or callable): either one of "BN", "SyncBN", "GN", "LN";
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module. ``None`` or ``""``
            means no normalization.
        out_channels (int): number of channels to normalize.

    Returns:
        nn.Module or None: the normalization layer

    Raises:
        KeyError: if ``norm`` is a string that is not one of the names above.
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": nn.BatchNorm2d,
            "SyncBN": nn.SyncBatchNorm,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # Channels-first LayerNorm: nn.LayerNorm would normalize the last
            # (width) axis of an (N, C, H, W) conv output instead of C.
            "LN": lambda channels: _ChannelLayerNorm(channels),
        }[norm]
    return norm(out_channels)


def _is_dynamo_compiling() -> bool:
    """Return True while TorchDynamo (``torch.compile``) is tracing.

    Prefers the public ``torch.compiler.is_compiling``; falls back to
    ``torch._dynamo.is_compiling`` on releases that lack it, and to ``False``
    when neither exists.
    """
    check = getattr(getattr(torch, "compiler", None), "is_compiling", None)
    if check is None:
        try:
            from torch._dynamo import is_compiling as check
        except ImportError:
            return False
    return bool(check())


class Conv2d(nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that norm layer is used before activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        """Convolve ``x``, then apply ``norm`` and ``activation`` when set.

        Raises:
            AssertionError: in training mode, for an empty input when ``norm``
                is a ``torch.nn.SyncBatchNorm``.
        """
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        # later version, `Conv2d` in these PyTorch versions has already supported empty inputs.
        if not torch.jit.is_scripting():
            # Dynamo doesn't support context managers yet
            if not _is_dynamo_compiling():
                with warnings.catch_warnings(record=True):
                    if x.numel() == 0 and self.training:
                        # https://github.com/pytorch/pytorch/issues/12013
                        assert not isinstance(
                            self.norm, torch.nn.SyncBatchNorm
                        ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x