from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.utils import _pair


def bilinear_grid_sample(im: Tensor,
                         grid: Tensor,
                         align_corners: bool = False) -> Tensor:
    """Given an input and a flow-field grid, compute the output using input
    values and pixel locations from the grid. Only bilinear interpolation is
    supported for sampling the input pixels.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W)
        grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
        align_corners (bool): If set to True, the extrema (-1 and 1) are
            considered as referring to the center points of the input's
            corner pixels. If set to False, they are instead considered as
            referring to the corner points of the input's corner pixels,
            making the sampling more resolution agnostic.

    Returns:
        torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]

    # Map normalized coordinates in [-1, 1] to pixel coordinates.
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    x = x.view(n, -1)
    y = y.view(n, -1)

    # Corner indices of the cell containing each point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear interpolation weights for the four corners.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)

    # Pad with zeros so that out-of-range corners read zero, matching
    # grid_sample's default ``padding_mode='zeros'``.
    im_padded = F.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # Shift point positions to account for the padding.
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Clip coordinates to the padded image size.
    x0 = x0.clamp(0, padded_w - 1)
    x1 = x1.clamp(0, padded_w - 1)
    y0 = y0.clamp(0, padded_h - 1)
    y1 = y1.clamp(0, padded_h - 1)

    # Flatten the spatial dimensions and gather the four corner values.
    im_padded = im_padded.view(n, c, -1)

    x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
    x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)

    Ia = torch.gather(im_padded, 2, x0_y0)
    Ib = torch.gather(im_padded, 2, x0_y1)
    Ic = torch.gather(im_padded, 2, x1_y0)
    Id = torch.gather(im_padded, 2, x1_y1)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
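
# A quick equivalence check (illustrative only, kept in comments so nothing
# runs on import): for grids in [-1, 1], this pure-PyTorch implementation
# should closely match ``F.grid_sample`` with its default
# ``padding_mode='zeros'``.
#
#   im = torch.randn(1, 3, 8, 8)
#   grid = torch.rand(1, 4, 4, 2) * 2 - 1
#   out = bilinear_grid_sample(im, grid, align_corners=False)
#   ref = F.grid_sample(im, grid, mode='bilinear', padding_mode='zeros',
#                       align_corners=False)
#   assert torch.allclose(out, ref, atol=1e-5)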


def normalize(grid: Tensor) -> Tensor:
    """Normalize input grid from [-1, 1] to [0, 1].

    Args:
        grid (torch.Tensor): The grid to be normalized, range [-1, 1].

    Returns:
        torch.Tensor: Normalized grid, range [0, 1].
    """
    return (grid + 1.0) / 2.0


def denormalize(grid: Tensor) -> Tensor:
    """Denormalize input grid from [0, 1] to [-1, 1].

    Args:
        grid (torch.Tensor): The grid to be denormalized, range [0, 1].

    Returns:
        torch.Tensor: Denormalized grid, range [-1, 1].
    """
    return grid * 2.0 - 1.0
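
# The two mappings are inverses of each other; illustrative round trip:
#
#   g = torch.tensor([-1., 0., 1.])
#   normalize(g)               # tensor([0.0000, 0.5000, 1.0000])
#   denormalize(normalize(g))  # tensor([-1., 0., 1.])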


def generate_grid(num_grid: int, size: Tuple[int, int],
                  device: torch.device) -> Tensor:
    """Generate regular square grid of points in [0, 1] x [0, 1] coordinate
    space.

    Args:
        num_grid (int): The number of grids to sample, one for each region.
        size (tuple[int, int]): The side size of the regular grid.
        device (torch.device): Desired device of returned tensor.

    Returns:
        torch.Tensor: A tensor of shape (num_grid, size[0]*size[1], 2) that
        contains coordinates for the regular grids.
    """
    affine_trans = torch.tensor([[[1., 0., 0.], [0., 1., 0.]]], device=device)
    grid = F.affine_grid(
        affine_trans, torch.Size((1, 1, *size)), align_corners=False)
    grid = normalize(grid)
    return grid.view(1, -1, 2).expand(num_grid, -1, -1)
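
# Illustrative usage: with ``align_corners=False`` the grid points are cell
# centers, so a 2x2 grid samples at 0.25 and 0.75 along each axis.
#
#   pts = generate_grid(3, (2, 2), device=torch.device('cpu'))
#   pts.shape  # torch.Size([3, 4, 2])
#   pts[0]     # rows of the form [0.25, 0.25], [0.75, 0.25], ...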


def rel_roi_point_to_abs_img_point(rois: Tensor,
                                   rel_roi_points: Tensor) -> Tensor:
    """Convert roi based relative point coordinates to image based absolute
    point coordinates.

    Args:
        rois (torch.Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
        rel_roi_points (torch.Tensor): Point coordinates inside RoI, relative
            to RoI location, range (0, 1), shape (N, P, 2)

    Returns:
        torch.Tensor: Image based absolute point coordinates, shape (N, P, 2)
    """
    with torch.no_grad():
        assert rel_roi_points.size(0) == rois.size(0)
        assert rois.dim() == 2
        assert rel_roi_points.dim() == 3
        assert rel_roi_points.size(2) == 2
        # Remove the batch index column if present.
        if rois.size(1) == 5:
            rois = rois[:, 1:]
        abs_img_points = rel_roi_points.clone()
        # Scale by the RoI size, then shift by the RoI origin. Independent
        # variables are used instead of in-place computation to stay
        # export-friendly (e.g. for ONNX).
        xs = abs_img_points[:, :, 0] * (rois[:, None, 2] - rois[:, None, 0])
        ys = abs_img_points[:, :, 1] * (rois[:, None, 3] - rois[:, None, 1])
        xs += rois[:, None, 0]
        ys += rois[:, None, 1]
        abs_img_points = torch.stack([xs, ys], dim=2)
    return abs_img_points
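
# Illustrative usage: the center point (0.5, 0.5) of an RoI with corners
# (x1, y1, x2, y2) = (10, 20, 30, 60) maps to the absolute image point
# (20, 40).
#
#   rois = torch.tensor([[10., 20., 30., 60.]])
#   pts = torch.tensor([[[0.5, 0.5]]])
#   rel_roi_point_to_abs_img_point(rois, pts)  # tensor([[[20., 40.]]])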


def get_shape_from_feature_map(x: Tensor) -> Tensor:
    """Get the spatial resolution of an input feature map, returned as a
    tensor so that it also works when exporting to ONNX.

    Args:
        x (torch.Tensor): Input tensor, shape (N, C, H, W)

    Returns:
        torch.Tensor: Spatial resolution (width, height), shape (1, 1, 2)
    """
    # ``flip(0)`` turns the (H, W) shape into (width, height) ordering so it
    # can scale (x, y) point coordinates directly.
    img_shape = torch.tensor(x.shape[2:]).flip(0).view(1, 1, 2)
    return img_shape.to(x.device).float()


def abs_img_point_to_rel_img_point(abs_img_points: Tensor,
                                   img: Union[tuple, Tensor],
                                   spatial_scale: float = 1.) -> Tensor:
    """Convert image based absolute point coordinates to image based relative
    coordinates for sampling.

    Args:
        abs_img_points (torch.Tensor): Image based absolute point coordinates,
            shape (N, P, 2)
        img (tuple or torch.Tensor): (height, width) of image or feature map.
        spatial_scale (float, optional): Scale points by this factor.
            Default: 1.

    Returns:
        torch.Tensor: Image based relative point coordinates for sampling,
        shape (N, P, 2).
    """
    assert (isinstance(img, tuple) and len(img) == 2) or \
           (isinstance(img, torch.Tensor) and len(img.shape) == 4)

    if isinstance(img, tuple):
        h, w = img
        scale = torch.tensor([w, h],
                             dtype=torch.float,
                             device=abs_img_points.device)
        scale = scale.view(1, 1, 2)
    else:
        scale = get_shape_from_feature_map(img)

    return abs_img_points / scale * spatial_scale


def rel_roi_point_to_rel_img_point(rois: Tensor,
                                   rel_roi_points: Tensor,
                                   img: Union[tuple, Tensor],
                                   spatial_scale: float = 1.) -> Tensor:
    """Convert roi based relative point coordinates to image based relative
    point coordinates.

    Args:
        rois (torch.Tensor): RoIs or BBoxes, shape (N, 4) or (N, 5)
        rel_roi_points (torch.Tensor): Point coordinates inside RoI, relative
            to RoI location, range (0, 1), shape (N, P, 2)
        img (tuple or torch.Tensor): (height, width) of image or feature map.
        spatial_scale (float, optional): Scale points by this factor.
            Default: 1.

    Returns:
        torch.Tensor: Image based relative point coordinates for sampling,
        shape (N, P, 2).
    """
    abs_img_point = rel_roi_point_to_abs_img_point(rois, rel_roi_points)
    rel_img_point = abs_img_point_to_rel_img_point(abs_img_point, img,
                                                   spatial_scale)
    return rel_img_point
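
# Illustrative usage, continuing the example above: on a 100x200 (h, w)
# image the absolute point (20, 40) becomes (20 / 200, 40 / 100).
#
#   rois = torch.tensor([[10., 20., 30., 60.]])
#   pts = torch.tensor([[[0.5, 0.5]]])
#   rel_roi_point_to_rel_img_point(rois, pts, (100, 200))
#   # tensor([[[0.1000, 0.4000]]])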


def point_sample(input: Tensor,
                 points: Tensor,
                 align_corners: bool = False,
                 **kwargs) -> Tensor:
    """A wrapper around :func:`grid_sample` to support 3D point_coords
    tensors. Unlike :func:`torch.nn.functional.grid_sample` it assumes
    point_coords to lie inside the ``[0, 1] x [0, 1]`` square.

    Args:
        input (torch.Tensor): Feature map, shape (N, C, H, W).
        points (torch.Tensor): Image based relative point coordinates
            (normalized), range [0, 1] x [0, 1], shape (N, P, 2) or
            (N, Hgrid, Wgrid, 2).
        align_corners (bool, optional): Whether to use align_corners.
            Default: False.

    Returns:
        torch.Tensor: Features of ``points`` on ``input``, shape (N, C, P) or
        (N, C, Hgrid, Wgrid).
    """
    add_dim = False
    if points.dim() == 3:
        # grid_sample expects a 4D grid, so add a dummy W dimension.
        add_dim = True
        points = points.unsqueeze(2)
    output = F.grid_sample(
        input, denormalize(points), align_corners=align_corners, **kwargs)
    if add_dim:
        output = output.squeeze(3)
    return output
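
# Illustrative usage: with ``align_corners=False`` the point (0.5, 0.5) is
# the exact center of the map, so for a 2x2 map the sampled value is the
# mean of the four entries.
#
#   feat = torch.arange(4.).view(1, 1, 2, 2)  # [[0., 1.], [2., 3.]]
#   pts = torch.tensor([[[0.5, 0.5]]])        # (N, P, 2) in [0, 1]
#   point_sample(feat, pts)                   # tensor([[[1.5000]]])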


class SimpleRoIAlign(nn.Module):

    def __init__(self,
                 output_size: Tuple[int],
                 spatial_scale: float,
                 aligned: bool = True) -> None:
        """Simple RoI align in PointRend, faster than standard RoIAlign.

        Args:
            output_size (tuple[int]): h, w
            spatial_scale (float): scale the input boxes by this number
            aligned (bool): If False, use the legacy implementation in
                MMDetection and align_corners=True in F.grid_sample.
                If True, align the results more perfectly.
        """
        super().__init__()
        self.output_size = _pair(output_size)
        self.spatial_scale = float(spatial_scale)
        # To be consistent with other RoI align ops.
        self.use_torchvision = False
        self.aligned = aligned

    def forward(self, features: Tensor, rois: Tensor) -> Tensor:
        num_imgs = features.size(0)
        num_rois = rois.size(0)
        rel_roi_points = generate_grid(
            num_rois, self.output_size, device=rois.device)

        point_feats = []
        for batch_ind in range(num_imgs):
            # Select the RoIs belonging to this image by their batch index.
            feat = features[batch_ind].unsqueeze(0)
            inds = (rois[:, 0].long() == batch_ind)
            if inds.any():
                rel_img_points = rel_roi_point_to_rel_img_point(
                    rois[inds], rel_roi_points[inds], feat,
                    self.spatial_scale).unsqueeze(0)
                point_feat = point_sample(
                    feat, rel_img_points, align_corners=not self.aligned)
                point_feat = point_feat.squeeze(0).transpose(0, 1)
                point_feats.append(point_feat)

        point_feats_t = torch.cat(point_feats, dim=0)

        channels = features.size(1)
        roi_feats = point_feats_t.reshape(num_rois, channels,
                                          *self.output_size)

        return roi_feats

    def __repr__(self) -> str:
        format_str = self.__class__.__name__
        format_str += '(output_size={}, spatial_scale={})'.format(
            self.output_size, self.spatial_scale)
        return format_str
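
# Illustrative usage (shapes only): 7x7 features for two RoIs on a batch of
# one image. RoIs use (batch_idx, x1, y1, x2, y2) format and
# ``spatial_scale`` maps image coordinates onto the feature map.
#
#   roi_align = SimpleRoIAlign(output_size=(7, 7), spatial_scale=0.25)
#   feats = torch.randn(1, 256, 50, 50)
#   rois = torch.tensor([[0., 4., 4., 120., 120.],
#                        [0., 10., 40., 80., 160.]])
#   roi_align(feats, rois).shape  # torch.Size([2, 256, 7, 7])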