File size: 4,035 Bytes
36c95ba |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
from typing import Optional
import torch
def create_meshgrid(
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for an image.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`. Pixel index
          ``0`` maps to ``-1`` and index ``size - 1`` maps to ``1``; an axis of
          size 1 maps to ``-1`` (instead of producing NaN).
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, H, W, 2)`. ``grid[0, i, j]`` holds
        ``(x_j, y_i)``: the column (x) coordinate first, then the row (y) coordinate.

    Example:
        >>> create_meshgrid(2, 2)
        tensor([[[[-1., -1.],
                  [ 1., -1.]],
        <BLANKLINE>
                 [[-1.,  1.],
                  [ 1.,  1.]]]])
        >>> create_meshgrid(2, 2, normalized_coordinates=False)
        tensor([[[[0., 0.],
                  [1., 0.]],
        <BLANKLINE>
                 [[0., 1.],
                  [1., 1.]]]])
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    if normalized_coordinates:
        # Normalize with plain Python ints instead of creating width/height tensors,
        # which would trigger a TracerWarning when tracing. The divisor is clamped to 1
        # so that an axis of size 1 yields -1 rather than NaN (0 / 0).
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
    # Broadcast the 1-D coordinate vectors to HxW directly instead of going through
    # torch.meshgrid, whose call without ``indexing`` warns on recent PyTorch releases.
    grid_x = xs.view(1, width).expand(height, width)  # HxW, x varies along columns
    grid_y = ys.view(height, 1).expand(height, width)  # HxW, y varies along rows
    return torch.stack([grid_x, grid_y], dim=-1).unsqueeze(0)  # 1xHxWx2
def create_meshgrid3d(
    depth: int,
    height: int,
    width: int,
    normalized_coordinates: bool = True,
    device: Optional[torch.device] = torch.device('cpu'),
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Generate a coordinate grid for a volume.

    When the flag ``normalized_coordinates`` is set to True, the grid is
    normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
    function :py:func:`torch.nn.functional.grid_sample`.

    Args:
        depth: the volume depth (number of slices).
        height: the image height (rows).
        width: the image width (cols).
        normalized_coordinates: whether to normalize
          coordinates in the range :math:`[-1,1]` in order to be consistent with the
          PyTorch function :py:func:`torch.nn.functional.grid_sample`. Index ``0``
          maps to ``-1`` and index ``size - 1`` maps to ``1``; an axis of size 1
          maps to ``-1`` (instead of producing NaN).
        device: the device on which the grid will be generated.
        dtype: the data type of the generated grid.

    Return:
        grid tensor with shape :math:`(1, D, H, W, 3)`. ``grid[0, k, i, j]`` holds
        ``(z_k, x_j, y_i)``, i.e. the last dimension is ordered ``(z, x, y)``.
    """
    xs: torch.Tensor = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
    ys: torch.Tensor = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
    zs: torch.Tensor = torch.linspace(0, depth - 1, depth, device=device, dtype=dtype)
    if normalized_coordinates:
        # Normalize with plain Python ints to avoid a TracerWarning when tracing.
        # The divisor is clamped to 1 so a size-1 axis yields -1 rather than NaN.
        xs = (xs / max(width - 1, 1) - 0.5) * 2
        ys = (ys / max(height - 1, 1) - 0.5) * 2
        zs = (zs / max(depth - 1, 1) - 0.5) * 2
    # Broadcast each 1-D coordinate vector to DxHxW directly instead of going through
    # torch.meshgrid, whose call without ``indexing`` warns on recent PyTorch releases.
    grid_z = zs.view(depth, 1, 1).expand(depth, height, width)
    grid_x = xs.view(1, 1, width).expand(depth, height, width)
    grid_y = ys.view(1, height, 1).expand(depth, height, width)
    # NOTE(review): components are stacked as (z, x, y) to preserve the historical
    # output of this function, while F.grid_sample reads 5-D grids as (x, y, z).
    # Confirm what callers expect before feeding this grid to grid_sample directly.
    return torch.stack([grid_z, grid_x, grid_y], dim=-1).unsqueeze(0)  # 1xDxHxWx3
|