| |
| |
| |
| |
| |
|
|
| |
| |
|
|
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from typing import Optional, Tuple, Union |
|
|
|
|
def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
    """
    Build a fixed 2D sine/cosine positional embedding over a regular grid.

    The coordinate grid is created here and the actual encoding is delegated
    to get_2d_sincos_pos_embed_from_grid.

    Args:
    - embed_dim: The embedding dimension.
    - grid_size: A single int (used for both height and width) or a
      (height, width) tuple.
    - return_grid: When True, the coordinate grid is returned alongside the embedding.

    Returns:
    - pos_embed: The embedding in channels-first layout, (1, embed_dim, height, width).
    - grid: Only when return_grid is True; the (2, 1, height, width) coordinate grid,
      returned as the second element of a tuple.
    """
    if isinstance(grid_size, tuple):
        height, width = grid_size
    else:
        height, width = grid_size, grid_size

    rows = torch.arange(height, dtype=torch.float)
    cols = torch.arange(width, dtype=torch.float)

    # The width axis goes first with "xy" indexing, so plane 0 holds the
    # column index of each cell and plane 1 holds the row index.
    mesh = torch.meshgrid(cols, rows, indexing="xy")
    grid = torch.stack(mesh, dim=0).reshape([2, 1, height, width])

    flat_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)

    # Restore the spatial axes, then move the channel axis in front of them.
    pos_embed = flat_embed.reshape(1, height, width, -1).permute(0, 3, 1, 2)

    return (pos_embed, grid) if return_grid else pos_embed
|
|
|
|
def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
    """
    Encode a two-plane coordinate grid as a sine/cosine positional embedding.

    Each of the two coordinate planes is encoded independently with
    embed_dim // 2 channels; the two encodings are then joined along the
    channel axis (plane 0 first, plane 1 second).

    Args:
    - embed_dim: The embedding dimension. Must be even.
    - grid: The coordinate grid; grid[0] and grid[1] are the two planes to encode.

    Returns:
    - emb: The generated 2D positional embedding, shape (1, num_positions, embed_dim).
    """
    assert embed_dim % 2 == 0

    half_dim = embed_dim // 2
    per_plane = [get_1d_sincos_pos_embed_from_grid(half_dim, grid[plane]) for plane in (0, 1)]

    return torch.cat(per_plane, dim=2)
|
|
|
|
def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for a set of positions.

    Args:
    - embed_dim: The embedding dimension. Must be even; half of the channels
      carry sines and the other half cosines.
    - pos: The positions to encode. Any shape is accepted; it is flattened
      to M = pos.numel() entries.

    Returns:
    - emb: The generated 1D positional embedding, a float32 tensor of shape
      (1, M, embed_dim) laid out as [sin | cos] along the last axis.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2

    # Frequencies 1 / 10000^(i / half_dim), computed in double precision.
    # They are allocated on the same device as `pos`; creating them on the
    # default device would make the product below fail for non-CPU positions.
    exponent = torch.arange(half_dim, dtype=torch.double, device=pos.device)
    exponent = exponent / (embed_dim / 2.0)
    omega = 1.0 / 10000**exponent

    # Outer product of positions and frequencies: (M, 1) * (half_dim,) -> (M, half_dim).
    angles = pos.reshape(-1, 1) * omega

    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    # Add a leading batch axis and drop back to single precision.
    return emb[None].float()
|
|
|
|
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    Encode 2D point coordinates with interleaved sine/cosine features.

    Args:
    - xy: The coordinates to encode, shape (B, N, 2).
    - C: The number of embedding channels produced per coordinate axis.
    - cat_coords: Whether to prepend the raw coordinates to the embedding.

    Returns:
    - pe: The positional embedding, shape (B, N, 2 * C), or (B, N, 2 + 2 * C)
      when cat_coords is True. Per axis, even channels hold sines and odd
      channels hold cosines; the x block precedes the y block.
    """
    batch, num_points, coord_dim = xy.shape
    assert coord_dim == 2

    # One multiplier per sin/cos pair: 0, 2, 4, ... scaled by 1000 / C.
    freqs = torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    freqs = freqs.reshape(1, 1, int(C / 2))

    def encode_axis(values: torch.Tensor) -> torch.Tensor:
        # Fill a float32 table so the result dtype does not depend on xy's dtype.
        angles = values * freqs
        table = torch.zeros(batch, num_points, C, device=xy.device, dtype=torch.float32)
        table[:, :, 0::2] = torch.sin(angles)
        table[:, :, 1::2] = torch.cos(angles)
        return table

    pe_x = encode_axis(xy[:, :, 0:1])
    pe_y = encode_axis(xy[:, :, 1:2])

    pe = torch.cat([pe_x, pe_y], dim=2)
    return torch.cat([xy, pe], dim=2) if cat_coords else pe
|
|
|
|
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Bilinearly sample a tensor at pixel-space coordinates.

    This is `torch.nn.functional.grid_sample()` with a different coordinate
    convention: :attr:`coords` are given in pixel units and are rescaled
    here to the normalized :math:`[-1, 1]` range that `grid_sample()` expects.

    Two input layouts are supported:

    * :math:`(B, C, H, W)` with :attr:`coords` of shape
      :math:`(B, H_o, W_o, 2)`, each entry being a point :math:`(x_i, y_i)`.
    * :math:`(B, C, T, H, W)`, in which case each entry is a triplet
      :math:`(t_i, x_i, y_i)`. Note that this ordering differs from
      `grid_sample()`, which would expect :math:`(x_i, y_i, t_i)`; the
      components are reordered internally.

    With `align_corners=True`, :math:`x` ranges over :math:`[0, W-1]`: 0 is
    the center of the left-most pixel and :math:`W-1` the center of the
    right-most pixel. With `align_corners=False`, :math:`x` ranges over
    :math:`[0, W]`: 0 is the left edge of the left-most pixel and :math:`W`
    the right edge of the right-most pixel. The same holds for :math:`y`
    with :math:`H` and for :math:`t` with :math:`T`.

    The coordinates are detached and copied before use, so the caller's
    tensor is not modified and no gradient flows back through
    :attr:`coords`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """
    # Private, gradient-free copy in the input's device and dtype; the
    # in-place arithmetic below must not touch the caller's tensor.
    grid = coords.detach().clone().to(device=input.device, dtype=input.dtype)

    spatial = input.shape[2:]
    assert len(spatial) in [2, 3]

    if len(spatial) == 3:
        # (t, x, y) -> (x, y, t), the component order grid_sample expects.
        grid = grid[..., [1, 2, 0]]

    # Extents are listed innermost axis first so they line up with (x, y[, t]).
    if align_corners:
        denominators = [max(extent - 1, 1) for extent in reversed(spatial)]
    else:
        denominators = list(reversed(spatial))
    scale = torch.tensor([2 / denom for denom in denominators], device=grid.device, dtype=grid.dtype)

    # Map pixel units onto [-1, 1].
    grid *= scale
    grid -= 1

    return F.grid_sample(input, grid, align_corners=align_corners, padding_mode=padding_mode)
|
|
|
|
def sample_features4d(input, coords):
    r"""Sample spatial features at a list of points

    `sample_features4d(input, coords)` reads the feature map :attr:`input`,
    a 4D tensor :math:`(B, C, H, W)`, at the locations in :attr:`coords`
    using bilinear interpolation.

    :attr:`coords` has shape :math:`(B, R, 2)` with one :math:`(x_i, y_i)`
    pair per point, following the convention of :func:`bilinear_sampler`
    with `align_corners=True`.

    The result holds one feature vector per point and has shape
    :math:`(B, R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """
    # Unpacking four names also rejects inputs that are not 4D.
    batch, _, _, _ = input.shape

    # (B, R, 2) -> (B, R, 1, 2): present the R points as an R x 1 output grid.
    point_grid = coords.unsqueeze(2)

    # (B, C, R, 1)
    sampled = bilinear_sampler(input, point_grid)

    channels, grid_width = sampled.shape[1], sampled.shape[3]

    # (B, C, R, 1) -> (B, R, C, 1) -> (B, R, C)
    return sampled.permute(0, 2, 1, 3).view(batch, -1, channels * grid_width)
|
|