| | from torch import nn as nn |
| | from torch.autograd import Function |
| |
|
| | from ..utils import ext_loader |
| |
|
| | ext_module = ext_loader.load_ext('_ext', ['roipoint_pool3d_forward']) |
| |
|
| |
|
class RoIPointPool3d(nn.Module):
    """Encode the geometry-specific features of each 3D proposal.

    Please refer to `Paper of PartA2 <https://arxiv.org/pdf/1907.03670.pdf>`_
    for more details.

    Args:
        num_sampled_points (int, optional): Number of samples in each roi.
            Default: 512.
    """

    def __init__(self, num_sampled_points=512):
        super(RoIPointPool3d, self).__init__()
        self.num_sampled_points = num_sampled_points

    def forward(self, points, point_features, boxes3d):
        """Pool points and their features into every 3D box.

        Args:
            points (torch.Tensor): Input point coordinates whose shape is
                (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).

        Returns:
            tuple[torch.Tensor]: A pair of tensors.

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C).
            - pooled_empty_flag: Empty flag whose shape is (B, M).
        """
        # ``Function.apply`` is called positionally; the sample count
        # configured on this module goes last.
        pool_args = (points, point_features, boxes3d, self.num_sampled_points)
        return RoIPointPool3dFunction.apply(*pool_args)
| |
|
| |
|
class RoIPointPool3dFunction(Function):
    """Autograd wrapper around the ``roipoint_pool3d_forward`` extension op.

    Only the forward pass is implemented; ``backward`` always raises
    ``NotImplementedError``.
    """

    @staticmethod
    def forward(ctx, points, point_features, boxes3d, num_sampled_points=512):
        """Pool points and their features into every 3D box.

        Args:
            points (torch.Tensor): Input point coordinates whose shape is
                (B, N, 3).
            point_features (torch.Tensor): Features of input points whose
                shape is (B, N, C).
            boxes3d (torch.Tensor): Input bounding boxes whose shape is
                (B, M, 7).
            num_sampled_points (int, optional): The num of sampled points.
                Default: 512.

        Returns:
            tuple[torch.Tensor]: A pair of tensors.

            - pooled_features: The output pooled features whose shape is
              (B, M, num_sampled_points, 3 + C), allocated with the dtype
              and device of ``point_features``.
            - pooled_empty_flag: Empty flag whose shape is (B, M), stored
              as an int32 tensor.
        """
        assert len(points.shape) == 3 and points.shape[2] == 3, (
            'points must have shape (B, N, 3), '
            f'but got {tuple(points.shape)}')

        batch_size = points.shape[0]
        boxes_num = boxes3d.shape[1]
        feature_len = point_features.shape[2]

        # ``reshape`` accepts every layout ``view`` does and falls back to a
        # copy when the memory layout does not permit a view.
        pooled_boxes3d = boxes3d.reshape(batch_size, -1, 7)

        # Output buffers are zero-initialised here and filled in place by
        # the extension op below.
        pooled_features = point_features.new_zeros(
            (batch_size, boxes_num, num_sampled_points, 3 + feature_len))
        pooled_empty_flag = point_features.new_zeros(
            (batch_size, boxes_num)).int()

        ext_module.roipoint_pool3d_forward(points.contiguous(),
                                           pooled_boxes3d.contiguous(),
                                           point_features.contiguous(),
                                           pooled_features, pooled_empty_flag)

        return pooled_features, pooled_empty_flag

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Gradient computation is not supported for this op.

        ``forward`` returns two tensors, so autograd passes one gradient per
        output. Accepting a variable number of gradients makes sure the
        intended ``NotImplementedError`` is raised instead of a ``TypeError``
        caused by an arity mismatch.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
| |
|