| |
|
|
| from typing import Callable |
|
|
| import torch |
| from torch.nn import functional as F |
|
|
|
|
| |
def point_sample(input, point_coords, **kwargs):
    """
    Bilinearly sample ``input`` at points given in [0, 1] x [0, 1] normalized coordinates.

    Thin wrapper around :func:`torch.nn.functional.grid_sample`. It differs in two ways:
    it also accepts a flat list of points (a 3D ``point_coords`` tensor), and it expects
    coordinates in [0, 1] rather than grid_sample's [-1, 1].

    Args:
        input (Tensor): Feature map of shape (N, C, H, W).
        point_coords (Tensor): Point coordinates of shape (N, P, 2) or (N, Hgrid, Wgrid, 2),
            normalized to [0, 1] x [0, 1].
        **kwargs: Forwarded unchanged to :func:`torch.nn.functional.grid_sample`
            (e.g. ``align_corners``).

    Returns:
        Tensor: Sampled features of shape (N, C, P) when ``point_coords`` is 3D, or
        (N, C, Hgrid, Wgrid) when it is 4D. Values are interpolated exactly as
        :func:`torch.nn.functional.grid_sample` does.
    """
    is_flat = point_coords.dim() == 3
    # grid_sample needs a 4D grid; treat a flat point list as a (P, 1) grid.
    grid = point_coords.unsqueeze(2) if is_flat else point_coords
    # Rescale from [0, 1] to the [-1, 1] range grid_sample expects.
    sampled = F.grid_sample(input, grid * 2.0 - 1.0, **kwargs)
    # Drop the dummy width-1 grid dimension added above.
    return sampled.squeeze(3) if is_flat else sampled
|
|
|
|
| |
def get_uncertain_point_coords_with_randomness(
    logits: torch.Tensor,
    uncertainty_func: Callable,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """
    Sample points in [0, 1] x [0, 1] coordinate space based on their uncertainty.

    Uncertainties are computed per point by ``uncertainty_func``, which receives the points'
    logit predictions. See the PointRend paper for details. The procedure is:

      1. draw ``int(num_points * oversample_ratio)`` uniformly random candidate points;
      2. keep the ``int(importance_sample_ratio * num_points)`` candidates with the highest
         uncertainty (in descending order of uncertainty);
      3. fill the remaining slots with fresh uniformly random points.

    Args:
        logits (Tensor): A tensor of shape (N, C, Hmask, Wmask) or (N, 1, Hmask, Wmask) for
            class-specific or class-agnostic prediction.
        uncertainty_func: A function that takes a Tensor of shape (N, C, P) or (N, 1, P) that
            contains logit predictions for P points and returns their uncertainties as a Tensor
            of shape (N, 1, P).
        num_points (int): The number of points P to sample.
        oversample_ratio (int): Oversampling factor; ``int(num_points * oversample_ratio)``
            candidate points are drawn. Must be >= 1.
        importance_sample_ratio (float): Fraction of points chosen by importance (uncertainty)
            sampling, in [0, 1]. The remainder are sampled uniformly at random.

    Returns:
        point_coords (Tensor): A tensor of shape (N, P, 2) with the coordinates of the P
            sampled points. The importance-sampled points come first, then the random ones.
    """
    assert oversample_ratio >= 1, f"oversample_ratio must be >= 1, got {oversample_ratio}"
    assert 0 <= importance_sample_ratio <= 1, (
        f"importance_sample_ratio must be in [0, 1], got {importance_sample_ratio}"
    )
    num_boxes = logits.shape[0]
    num_sampled = int(num_points * oversample_ratio)
    point_coords = torch.rand(num_boxes, num_sampled, 2, device=logits.device)
    point_logits = point_sample(logits, point_coords, align_corners=False)
    point_uncertainties = uncertainty_func(point_logits)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points
    # Indices (N, num_uncertain_points) of the most uncertain candidates for each box.
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    # Pick those candidates per box. gather along dim 1 replaces the manual
    # flat-index arithmetic (per-box offset + view(-1, 2)).
    point_coords = torch.gather(point_coords, 1, idx.unsqueeze(-1).expand(-1, -1, 2))

    if num_random_points > 0:
        random_coords = torch.rand(num_boxes, num_random_points, 2, device=logits.device)
        point_coords = torch.cat([point_coords, random_coords], dim=1)
    return point_coords
|
|
|
|
| |
def calculate_uncertainty(logits: torch.Tensor) -> torch.Tensor:
    """
    Score uncertainty as the negated absolute value (negative L1 distance to 0.0) of each logit.

    Points whose logits are closest to 0.0 receive the highest (least negative) score.

    Args:
        logits (Tensor): A tensor of shape (R, 1, ...) holding class-agnostic predicted
            mask logits.

    Returns:
        scores (Tensor): A tensor of the same shape as ``logits``. The most uncertain
            locations have the highest score.
    """
    assert logits.shape[1] == 1
    return logits.abs().neg()
|
|