| |
| |
| |
| |
| |
|
|
| |
|
|
|
|
| import copy |
| import inspect |
| import warnings |
| from typing import Any, List, Optional, Tuple, TypeVar, Union |
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
|
|
| from ..common.datatypes import Device, make_device |
|
|
|
|
class TensorAccessor(nn.Module):
    """
    A helper class to be used with the __getitem__ method. This can be used for
    getting/setting the values for an attribute of a class at one particular
    index. This is useful when the attributes of a class are batched tensors
    and one element in the batch needs to be modified.
    """

    def __init__(self, class_object, index: Union[int, slice]) -> None:
        """
        Args:
            class_object: this should be an instance of a class which has
                attributes which are tensors representing a batch of
                values.
            index: int/slice, an index indicating the position in the batch.
                In __setattr__ and __getattr__ only the value of class
                attributes at this index will be accessed.
        """
        # NOTE: write straight into __dict__ on purpose — this bypasses both
        # nn.Module.__init__ (never called here) and this class's own
        # __setattr__, which forwards all attribute writes to class_object.
        self.__dict__["class_object"] = class_object
        self.__dict__["index"] = index

    def __setattr__(self, name: str, value: Any):
        """
        Update the attribute given by `name` to the value given by `value`
        at the index specified by `self.index`.

        Args:
            name: str, name of the attribute.
            value: value to set the attribute to.

        Raises:
            AttributeError: if the attribute on class_object is not a tensor.
            ValueError: if value has an incompatible shape/length.
        """
        v = getattr(self.class_object, name)
        if not torch.is_tensor(v):
            msg = "Can only set values on attributes which are tensors; got %r"
            raise AttributeError(msg % type(v))

        # Convert the value to a tensor matching the stored attribute's
        # device/dtype/requires_grad so the indexed write below is valid.
        if not torch.is_tensor(value):
            value = torch.tensor(
                value, device=v.device, dtype=v.dtype, requires_grad=v.requires_grad
            )

        # For batched attributes, all non-batch dims of the new value must
        # match the attribute's non-batch dims.
        if v.dim() > 1 and value.dim() > 1 and value.shape[1:] != v.shape[1:]:
            msg = "Expected value to have shape %r; got %r"
            raise ValueError(msg % (v.shape, value.shape))
        # NOTE(review): `len(self.index)` raises TypeError when index is a
        # slice (slices have no len()), so this branch looks broken for 0-dim
        # attributes — confirm original intent before relying on it.
        if (
            v.dim() == 0
            and isinstance(self.index, slice)
            and len(value) != len(self.index)
        ):
            msg = "Expected value to have len %r; got %r"
            raise ValueError(msg % (len(self.index), len(value)))
        # BUG FIX: write through the tensor fetched via getattr above rather
        # than class_object.__dict__[name], which raised KeyError for
        # attributes hasattr/getattr can see outside __dict__
        # (e.g. nn.Module buffers/parameters).
        v[self.index] = value

    def __getattr__(self, name: str):
        """
        Return the value of the attribute given by "name" on self.class_object
        at the index specified in self.index.

        Args:
            name: string of the attribute name

        Raises:
            AttributeError: if class_object has no attribute `name`.
        """
        if hasattr(self.class_object, name):
            # BUG FIX: use getattr (consistent with the hasattr test above)
            # instead of a __dict__ lookup which could raise KeyError.
            return getattr(self.class_object, name)[self.index]
        msg = "Attribute %s not found on %r"
        # BUG FIX: the exception was previously *returned*, not raised, so
        # callers never saw an error; also instances generally have no
        # __name__ attribute, so use the type's name in the message.
        raise AttributeError(msg % (name, type(self.class_object).__name__))
|
|
|
|
# Value types which TensorProperties.__init__ treats as broadcastable batch
# inputs; anything else is set as a plain attribute (None/str/bool) or warned
# about and dropped.
BROADCAST_TYPES = (float, int, list, tuple, torch.Tensor, np.ndarray)
|
|
|
|
class TensorProperties(nn.Module):
    """
    A mix-in class for storing tensors as properties with helper methods.
    """

    def __init__(
        self,
        dtype: torch.dtype = torch.float32,
        device: Device = "cpu",
        **kwargs,
    ) -> None:
        """
        Args:
            dtype: data type to use when creating the broadcast tensors.
            device: Device (as str or torch.device)
            kwargs: any number of keyword arguments. Any arguments which are
                of type (float/int/list/tuple/tensor/array) are broadcasted and
                other keyword arguments are set as attributes.
        """
        super().__init__()
        self.device = make_device(device)
        # _N is the broadcast batch size; 0 means no broadcast props were set.
        self._N = 0
        if kwargs is not None:
            # Broadcast all inputs which are float/int/list/tuple/tensor/array
            # and set all other keyword args directly as attributes.
            args_to_broadcast = {}
            for k, v in kwargs.items():
                if v is None or isinstance(v, (str, bool)):
                    setattr(self, k, v)
                elif isinstance(v, BROADCAST_TYPES):
                    args_to_broadcast[k] = v
                else:
                    msg = "Arg %s with type %r is not broadcastable"
                    warnings.warn(msg % (k, type(v)))

            names = args_to_broadcast.keys()
            values = tuple(v for v in args_to_broadcast.values())

            if len(values) > 0:
                # BUG FIX: `dtype` was previously accepted but never forwarded,
                # so any non-default dtype was silently ignored and tensors
                # were always created as float32.
                broadcasted_values = convert_to_tensors_and_broadcast(
                    *values, dtype=dtype, device=device
                )

                # Set the broadcast values as attributes on self.
                for i, n in enumerate(names):
                    setattr(self, n, broadcasted_values[i])
                    if self._N == 0:
                        self._N = broadcasted_values[i].shape[0]

    def __len__(self) -> int:
        # Batch size of the broadcast tensor properties.
        return self._N

    def isempty(self) -> bool:
        # True when no broadcastable properties were set.
        return self._N == 0

    def __getitem__(self, index: Union[int, slice]) -> TensorAccessor:
        """
        Args:
            index: an int or slice used to index all the fields.

        Returns:
            if `index` is an index int/slice return a TensorAccessor class
            with getattribute/setattribute methods which return/update the value
            at the index in the original class.

        Raises:
            ValueError: if index is not an int or slice.
        """
        if isinstance(index, (int, slice)):
            return TensorAccessor(class_object=self, index=index)

        msg = "Expected index of type int or slice; got %r"
        raise ValueError(msg % type(index))

    def to(self, device: Device = "cpu") -> "TensorProperties":
        """
        In place operation to move class properties which are tensors to a
        specified device. If self has a property "device", update this as well.
        """
        device_ = make_device(device)
        # NOTE: dir(self) enumerates all attributes (including inherited ones);
        # only tensors and the "device" attribute itself are updated.
        for k in dir(self):
            v = getattr(self, k)
            if k == "device":
                setattr(self, k, device_)
            if torch.is_tensor(v) and v.device != device_:
                setattr(self, k, v.to(device_))
        return self

    def cpu(self) -> "TensorProperties":
        # Convenience wrapper around to("cpu").
        return self.to("cpu")

    def cuda(self, device: Optional[int] = None) -> "TensorProperties":
        # Convenience wrapper: move to a specific CUDA device index, or the
        # current default CUDA device when `device` is None.
        return self.to(f"cuda:{device}" if device is not None else "cuda")

    def clone(self, other) -> "TensorProperties":
        """
        Update the tensor properties of other with the cloned properties of self.
        """
        for k in dir(self):
            v = getattr(self, k)
            # Skip methods, dunder attributes and TypeVar placeholders; only
            # data attributes are copied onto `other`.
            if inspect.ismethod(v) or k.startswith("__") or type(v) is TypeVar:
                continue
            if torch.is_tensor(v):
                v_clone = v.clone()
            else:
                v_clone = copy.deepcopy(v)
            setattr(other, k, v_clone)
        return other

    def gather_props(self, batch_idx) -> "TensorProperties":
        """
        This is an in place operation to reformat all tensor class attributes
        based on a set of given indices using torch.gather. This is useful when
        attributes which are batched tensors e.g. shape (N, 3) need to be
        multiplied with another tensor which has a different first dimension
        e.g. packed vertices of shape (V, 3).

        Example

        .. code-block:: python

            self.specular_color = (N, 3) tensor of specular colors for each mesh

        A lighting calculation may use

        .. code-block:: python

            verts_packed = meshes.verts_packed()  # (V, 3)

        To multiply these two tensors the batch dimension needs to be the same.
        To achieve this we can do

        .. code-block:: python

            batch_idx = meshes.verts_packed_to_mesh_idx()  # (V)

        This gives index of the mesh for each vertex in verts_packed.

        .. code-block:: python

            self.gather_props(batch_idx)
            self.specular_color = (V, 3) tensor with the specular color for
                each packed vertex.

        torch.gather requires the index tensor to have the same shape as the
        input tensor so this method takes care of the reshaping of the index
        tensor to use with class attributes with arbitrary dimensions.

        Args:
            batch_idx: shape (B, ...) where `...` represents an arbitrary
                number of dimensions

        Returns:
            self with all properties reshaped. e.g. a property with shape (N, 3)
            is transformed to shape (B, 3).
        """
        # Iterate through the attributes of the class which are tensors.
        for k in dir(self):
            v = getattr(self, k)
            if torch.is_tensor(v):
                if v.shape[0] > 1:
                    # There are different values for each batch element, so
                    # gather them with batch_idx. Clone the index first so the
                    # caller's tensor is never modified by the reshaping below.
                    _batch_idx = batch_idx.clone()
                    idx_dims = _batch_idx.shape
                    tensor_dims = v.shape
                    if len(idx_dims) > len(tensor_dims):
                        msg = "batch_idx cannot have more dimensions than %s. "
                        msg += "got shape %r and %s has shape %r"
                        raise ValueError(msg % (k, idx_dims, k, tensor_dims))
                    if idx_dims != tensor_dims:
                        # torch.gather requires the index to have the same
                        # shape as the input, so append singleton dims and
                        # expand them to match v's trailing dims.
                        new_dims = len(tensor_dims) - len(idx_dims)
                        new_shape = idx_dims + (1,) * new_dims

                        # Expand singleton dims to the same size as v's dims.
                        expand_dims = (-1,) + tensor_dims[1:]
                        _batch_idx = _batch_idx.view(*new_shape)
                        _batch_idx = _batch_idx.expand(*expand_dims)

                    v = v.gather(0, _batch_idx)
                    setattr(self, k, v)
        return self
|
|
|
|
def format_tensor(
    input,
    dtype: torch.dtype = torch.float32,
    device: Device = "cpu",
) -> torch.Tensor:
    """
    Helper function for converting a scalar value to a tensor.

    Args:
        input: Python scalar, Python list/tuple, torch scalar, 1D torch tensor
        dtype: data type for the input
        device: Device (as str or torch.device) on which the tensor should be placed.

    Returns:
        input_vec: torch tensor with optional added batch dimension.
    """
    device_ = make_device(device)
    if not torch.is_tensor(input):
        input = torch.tensor(input, dtype=dtype, device=device_)

    # Give 0-dim (scalar) tensors a batch dimension so callers can rely on
    # shape (N, ...).
    if input.dim() == 0:
        input = input.view(1)

    # Already on the requested device: return the tensor unchanged (no copy).
    if input.device == device_:
        return input

    # BUG FIX: move to the normalized device (`device_`) rather than the raw
    # `device` argument, so the destination is guaranteed to match the device
    # that was compared against above (e.g. "cuda" vs "cuda:0").
    input = input.to(device=device_)
    return input
|
|
|
|
def convert_to_tensors_and_broadcast(
    *args,
    dtype: torch.dtype = torch.float32,
    device: Device = "cpu",
):
    """
    Helper function to handle parsing an arbitrary number of inputs (*args)
    which all need to have the same batch dimension.
    The output is a list of tensors.

    Args:
        *args: an arbitrary number of inputs
            Each of the values in `args` can be one of the following
                - Python scalar
                - Torch scalar
                - Torch tensor of shape (N, K_i) or (1, K_i) where K_i are
                  an arbitrary number of dimensions which can vary for each
                  value in args. In this case each input is broadcast to a
                  tensor of shape (N, K_i)
        dtype: data type to use when creating new tensors.
        device: torch device on which the tensors should be placed.

    Output:
        args: A list of tensors of shape (N, K_i)

    Raises:
        ValueError: if the batch sizes are neither 1 nor all equal.
    """
    # Convert all inputs to tensors with a batch dimension.
    args_1d = [format_tensor(c, dtype, device) for c in args]

    # ROBUSTNESS: with no inputs there is nothing to broadcast; return early
    # instead of letting max() fail on an empty sequence below.
    if not args_1d:
        return []

    # Find the broadcast batch size.
    sizes = [c.shape[0] for c in args_1d]
    N = max(sizes)

    args_Nd = []
    for c in args_1d:
        if c.shape[0] != 1 and c.shape[0] != N:
            msg = "Got non-broadcastable sizes %r" % sizes
            raise ValueError(msg)

        # Expand the batch dim to N; -1 keeps each non-batch dim unchanged.
        expand_sizes = (N,) + (-1,) * len(c.shape[1:])
        args_Nd.append(c.expand(*expand_sizes))

    return args_Nd
|
|
|
|
def ndc_grid_sample(
    input: torch.Tensor,
    grid_ndc: torch.Tensor,
    *,
    align_corners: bool = False,
    **grid_sample_kwargs,
) -> torch.Tensor:
    """
    Samples a tensor `input` of shape `(B, dim, H, W)` at 2D locations
    specified by a tensor `grid_ndc` of shape `(B, ..., 2)` using
    the `torch.nn.functional.grid_sample` function.
    `grid_ndc` is specified in PyTorch3D NDC coordinate frame.

    Args:
        input: The tensor of shape `(B, dim, H, W)` to be sampled.
        grid_ndc: A tensor of shape `(B, ..., 2)` denoting the set of
            2D locations at which `input` is sampled.
            See [1] for a detailed description of the NDC coordinates.
        align_corners: Forwarded to the `torch.nn.functional.grid_sample`
            call. See its docstring.
        grid_sample_kwargs: Additional arguments forwarded to the
            `torch.nn.functional.grid_sample` call. See the corresponding
            docstring for a listing of the corresponding arguments.

    Returns:
        sampled_input: A tensor of shape `(B, dim, ...)` containing the samples
            of `input` at 2D locations `grid_ndc`.

    References:
        [1] https://pytorch3d.org/docs/cameras
    """
    batch, *spatial_size, pt_dim = grid_ndc.shape

    # Validate tensor shapes before doing any work.
    if batch != input.shape[0]:
        raise ValueError("'input' and 'grid_ndc' have to have the same batch size.")
    if input.ndim != 4:
        raise ValueError("'input' has to be a 4-dimensional Tensor.")
    if pt_dim != 2:
        raise ValueError("The last dimension of 'grid_ndc' has to be == 2.")

    # Collapse the arbitrary spatial dims into one: (B, ..., 2) -> (B, P, 1, 2),
    # the 4D layout grid_sample expects.
    ndc_points_flat = grid_ndc.reshape(batch, -1, 1, 2)

    # Map the PyTorch3D NDC convention onto grid_sample's coordinate convention.
    sample_points_flat = ndc_to_grid_sample_coords(ndc_points_flat, input.shape[2:])

    # Sample at the flattened locations: result has shape (B, dim, P, 1).
    samples_flat = torch.nn.functional.grid_sample(
        input, sample_points_flat, align_corners=align_corners, **grid_sample_kwargs
    )

    # Restore the original spatial dims, giving (B, dim, ...).
    return samples_flat.reshape([batch, input.shape[1], *spatial_size])
|
|
|
|
def ndc_to_grid_sample_coords(
    xy_ndc: torch.Tensor,
    image_size_hw: Tuple[int, int],
) -> torch.Tensor:
    """
    Convert from the PyTorch3D's NDC coordinates to
    `torch.nn.functional.grid_sampler`'s coordinates.

    Args:
        xy_ndc: Tensor of shape `(..., 2)` containing 2D points in the
            PyTorch3D's NDC coordinates.
        image_size_hw: A tuple `(image_height, image_width)` denoting the
            height and width of the image tensor to sample.
    Returns:
        xy_grid_sample: Tensor of shape `(..., 2)` containing 2D points in the
            `torch.nn.functional.grid_sample` coordinates.

    Raises:
        ValueError: if image_size_hw is not a pair of positive numbers.
    """
    if len(image_size_hw) != 2 or min(image_size_hw) <= 0:
        raise ValueError("'image_size_hw' has to be a 2-tuple of positive integers")
    height, width = image_size_hw
    # grid_sample's frame is the NDC frame with both axes flipped.
    xy_grid_sample = -xy_ndc
    # Only the axis along the image's shorter side keeps the full [-1, 1]
    # range in PyTorch3D NDC; rescale the other axis by short/long.
    scale = min(image_size_hw) / max(image_size_hw)
    axis = 1 if height >= width else 0
    xy_grid_sample[..., axis] *= scale
    return xy_grid_sample
|
|
|
|
def parse_image_size(
    image_size: Union[List[int], Tuple[int, int], int],
) -> Tuple[int, int]:
    """
    Canonicalize an image size specification into an (H, W) tuple.

    Args:
        image_size: A single int (for square images) or a tuple/list of two ints.

    Returns:
        A tuple of two ints.

    Throws:
        ValueError if got more than two ints, any negative numbers or non-ints.
    """
    if not isinstance(image_size, (tuple, list)):
        # BUG FIX: the scalar path previously skipped validation entirely,
        # so e.g. parse_image_size(-5) silently returned (-5, -5) despite the
        # documented ValueError contract.
        if not isinstance(image_size, int):
            raise ValueError("Image sizes must be integers; got %r" % (image_size,))
        if image_size <= 0:
            raise ValueError("Image sizes must be greater than 0; got %d" % image_size)
        return (image_size, image_size)
    if len(image_size) != 2:
        raise ValueError("Image size can only be a tuple/list of (H, W)")
    # BUG FIX: check integer-ness before positivity so non-numeric entries get
    # the documented ValueError instead of a TypeError from `i > 0`; %r is used
    # so the message can be formatted for any offending value (the old %f
    # raised TypeError on non-numeric entries).
    if not all(isinstance(i, int) for i in image_size):
        raise ValueError("Image sizes must be integers; got %r, %r" % tuple(image_size))
    if not all(i > 0 for i in image_size):
        raise ValueError("Image sizes must be greater than 0; got %d, %d" % tuple(image_size))
    return tuple(image_size)
|
|