File size: 7,516 Bytes
36c95ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
from typing import cast, List, Optional, Tuple, Union

import torch


def _common_param_check(batch_size: int, same_on_batch: Optional[bool] = None):
    """Valid batch_size and same_on_batch params."""
    if not (type(batch_size) is int and batch_size >= 0):
        raise AssertionError(f"`batch_size` shall be a positive integer. Got {batch_size}.")
    if same_on_batch is not None and type(same_on_batch) is not bool:
        raise AssertionError(f"`same_on_batch` shall be boolean. Got {same_on_batch}.")


def _range_bound(
    factor: Union[torch.Tensor, float, Tuple[float, float], List[float]],
    name: str,
    center: float = 0.0,
    bounds: Tuple[float, float] = (0, float('inf')),
    check: Optional[str] = 'joint',
    device: torch.device = torch.device('cpu'),
    dtype: torch.dtype = torch.get_default_dtype(),
) -> torch.Tensor:
    r"""Check inputs and compute the corresponding factor bounds"""
    if not isinstance(factor, (torch.Tensor)):
        factor = torch.tensor(factor, device=device, dtype=dtype)
    factor_bound: torch.Tensor

    if factor.dim() == 0:
        if factor < 0:
            raise ValueError(f"If {name} is a single number number, it must be non negative. Got {factor}")
        # Should be something other than clamp
        # Currently, single value factor will not out of scope as long as the user provided it.
        # Note: I personally think throw an error will be better than a coarse clamp.
        factor_bound = factor.repeat(2) * torch.tensor([-1.0, 1.0], device=factor.device, dtype=factor.dtype) + center
        factor_bound = factor_bound.clamp(bounds[0], bounds[1])
    else:
        factor_bound = torch.as_tensor(factor, device=device, dtype=dtype)

    if check is not None:
        if check == 'joint':
            _joint_range_check(factor_bound, name, bounds)
        elif check == 'singular':
            _singular_range_check(factor_bound, name, bounds)
        else:
            raise NotImplementedError(f"methods '{check}' not implemented.")

    return factor_bound


def _joint_range_check(ranged_factor: torch.Tensor, name: str, bounds: Optional[Tuple[float, float]] = None) -> None:
    """Check if bounds[0] <= ranged_factor[0] <= ranged_factor[1] <= bounds[1]"""
    if bounds is None:
        bounds = (float('-inf'), float('inf'))
    if ranged_factor.dim() == 1 and len(ranged_factor) == 2:
        if not bounds[0] <= ranged_factor[0] or not bounds[1] >= ranged_factor[1]:
            raise ValueError(f"{name} out of bounds. Expected inside {bounds}, got {ranged_factor}.")

        if not bounds[0] <= ranged_factor[0] <= ranged_factor[1] <= bounds[1]:
            raise ValueError(f"{name}[0] should be smaller than {name}[1] got {ranged_factor}")
    else:
        raise TypeError(
            f"{name} should be a tensor with length 2 whose values between {bounds}. " f"Got {ranged_factor}."
        )


def _singular_range_check(
    ranged_factor: torch.Tensor,
    name: str,
    bounds: Optional[Tuple[float, float]] = None,
    skip_none: bool = False,
    mode: str = '2d',
) -> None:
    """Check if bounds[0] <= ranged_factor[0] <= bounds[1] and bounds[0] <= ranged_factor[1] <= bounds[1]"""
    if mode == '2d':
        dim_size = 2
    elif mode == '3d':
        dim_size = 3
    else:
        raise ValueError(f"'mode' shall be either 2d or 3d. Got {mode}")

    if skip_none and ranged_factor is None:
        return
    if bounds is None:
        bounds = (float('-inf'), float('inf'))
    if ranged_factor.dim() == 1 and len(ranged_factor) == dim_size:
        for f in ranged_factor:
            if not bounds[0] <= f <= bounds[1]:
                raise ValueError(f"{name} out of bounds. Expected inside {bounds}, got {ranged_factor}.")
    else:
        raise TypeError(
            f"{name} should be a float number or a tuple with length {dim_size} whose values between {bounds}."
            f"Got {ranged_factor}"
        )


def _tuple_range_reader(
    input_range: Union[torch.Tensor, float, tuple],
    target_size: int,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    """Given target_size, it will generate the corresponding (target_size, 2) range tensor for element-wise params.

    Example:
    >>> degree = torch.tensor([0.2, 0.3])
    >>> _tuple_range_reader(degree, 3)  # read degree for yaw, pitch and roll.
    tensor([[0.2000, 0.3000],
            [0.2000, 0.3000],
            [0.2000, 0.3000]])
    """
    target_shape = torch.Size([target_size, 2])
    if not torch.is_tensor(input_range):
        if isinstance(input_range, (float, int)):
            if input_range < 0:
                raise ValueError(f"If input_range is only one number it must be a positive number. Got{input_range}")
            input_range_tmp = torch.tensor([-input_range, input_range], device=device, dtype=dtype).repeat(
                target_shape[0], 1
            )

        elif (
            isinstance(input_range, (tuple, list))
            and len(input_range) == 2
            and isinstance(input_range[0], (float, int))
            and isinstance(input_range[1], (float, int))
        ):
            input_range_tmp = torch.tensor(input_range, device=device, dtype=dtype).repeat(target_shape[0], 1)

        elif (
            isinstance(input_range, (tuple, list))
            and len(input_range) == target_shape[0]
            and all(isinstance(x, (float, int)) for x in input_range)
        ):
            input_range_tmp = torch.tensor([(-s, s) for s in input_range], device=device, dtype=dtype)

        elif (
            isinstance(input_range, (tuple, list))
            and len(input_range) == target_shape[0]
            and all(isinstance(x, (tuple, list)) for x in input_range)
        ):
            input_range_tmp = torch.tensor(input_range, device=device, dtype=dtype)

        else:
            raise TypeError(
                "If not pass a tensor, it must be float, (float, float) for isotropic operation or a tuple of "
                f"{target_size} floats or {target_size} (float, float) for independent operation. Got {input_range}."
            )

    else:
        # https://mypy.readthedocs.io/en/latest/casts.html cast to please mypy gods
        input_range = cast(torch.Tensor, input_range)
        if (len(input_range.shape) == 0) or (len(input_range.shape) == 1 and len(input_range) == 1):
            if input_range < 0:
                raise ValueError(f"If input_range is only one number it must be a positive number. Got{input_range}")
            input_range_tmp = input_range.repeat(2) * torch.tensor(
                [-1.0, 1.0], device=input_range.device, dtype=input_range.dtype
            )
            input_range_tmp = input_range_tmp.repeat(target_shape[0], 1)

        elif len(input_range.shape) == 1 and len(input_range) == 2:
            input_range_tmp = input_range.repeat(target_shape[0], 1)

        elif len(input_range.shape) == 1 and len(input_range) == target_shape[0]:
            input_range_tmp = input_range.unsqueeze(1).repeat(1, 2) * torch.tensor(
                [-1, 1], device=input_range.device, dtype=input_range.dtype
            )

        elif input_range.shape == target_shape:
            input_range_tmp = input_range

        else:
            raise ValueError(
                f"Degrees must be a {list(target_shape)} tensor for the degree range for independent operation."
                f"Got {input_range}"
            )

    return input_range_tmp