| """ |
| @Author: Tiange Xiang |
| @Contact: txia7609@uni.sydney.edu.au |
| @File: curvenet_cls.py |
| @Time: 2021/01/21 3:10 PM |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .. utils import ( |
| index_points, |
| farthest_point_sample, |
| query_ball_point, |
| LPFA, |
| CIC |
| ) |
|
|
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Select ``npoint`` centroids with farthest point sampling and gather the
    features of each centroid's ball-query neighbourhood.

    Shapes below assume ``index_points`` follows the usual PointNet++ gather
    convention ([B, N, C] indexed by [B, S] -> [B, S, C], by [B, S, K] ->
    [B, S, K, C]) -- TODO confirm against ``..utils``.

    Input:
        npoint: number of centroids chosen by ``farthest_point_sample``
        radius: ball radius passed to ``query_ball_point``
        nsample: neighbours gathered per centroid, passed to ``query_ball_point``
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: if True, additionally return the ball-query index tensor
    Return:
        new_xyz: sampled centroid positions, presumably [B, npoint, 3]
        new_points: grouped neighbour features, presumably [B, npoint, nsample, D]
            (only ``points`` are gathered; no relative coordinates are concatenated)
        idx: only when ``returnfps`` is True -- the neighbour indices returned by
            ``query_ball_point`` (presumably [B, npoint, nsample]). Despite the flag
            name, these are NOT the farthest-point-sampling indices.
    """
    fps_idx = farthest_point_sample(xyz, npoint)
    new_xyz = index_points(xyz, fps_idx)
    # The empty_cache() calls release unused cached CUDA blocks between stages.
    # They are kept from the original. Per the PyTorch docs they do not increase
    # memory available to this process and can slow repeated calls.
    torch.cuda.empty_cache()

    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()

    new_points = index_points(points, idx)
    torch.cuda.empty_cache()

    if returnfps:
        return new_xyz, new_points, idx
    return new_xyz, new_points
|
|
# Per-stage curve settings, keyed by the ``setting`` name that CurveNet accepts.
# Entry i of each list is passed as ``curve_config`` to both CIC blocks of stage
# i + 1 (cic11/cic12 get entry 0, ..., cic41/cic42 get entry 3). A pair is
# presumably [curve_num, curve_length], and ``None`` presumably disables curve
# grouping for that stage -- TODO confirm against CIC.
curve_config = dict(
    default=[[100, 5], [100, 5], None, None],
    long=[[10, 30], None, None, None],
)
|
|
class CurveNet(nn.Module):
    """CurveNet point-cloud backbone with an optional classification head.

    Layout: an ``LPFA`` stem, four stages of two ``CIC`` blocks each
    (``cic11`` ... ``cic42``), a 1x1 ``Conv1d`` projection of the 512-channel
    stage-4 features to ``emb_dims // 2`` channels, then global max- and
    average-pooling concatenated. With ``classifier=True`` an MLP head
    (Linear -> BatchNorm -> ReLU -> Dropout(0.5) -> Linear) maps the pooled
    feature to ``num_classes`` scores.

    Args:
        num_classes: output size of the final linear layer (classifier only).
        k: neighbourhood size forwarded to ``LPFA`` and every ``CIC`` block.
        setting: key into the module-level ``curve_config`` that selects the
            per-stage ``curve_config`` passed to the CIC blocks.
        input_shape: layout of the input tensor, ``"bnc"``
            (batch * num_in_points * channels) or ``"bcn"``
            (batch * channels * num_in_points).
        emb_dims: nominal size of the pooled global feature. The projection has
            ``emb_dims // 2`` channels, so the pooled feature has
            ``2 * (emb_dims // 2)`` entries (equal to ``emb_dims`` when even).
        classifier: if False, ``forward`` returns the pooled global feature.

    Raises:
        ValueError: if ``input_shape`` or ``setting`` is not recognised.
    """

    # Suffixes of the CIC submodules, in the order forward() chains them.
    _CIC_TAGS = ('11', '12', '21', '22', '31', '32', '41', '42')

    def __init__(self, num_classes=40, k=20, setting='default', input_shape="bnc", emb_dims=2048, classifier=True):
        super().__init__()

        if input_shape not in ("bcn", "bnc"):
            raise ValueError(
                "Allowed shapes are 'bcn' (batch * channels * num_in_points), "
                "'bnc' (batch * num_in_points * channels)"
            )
        self.input_shape = input_shape

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if setting not in curve_config:
            raise ValueError(
                f"Unknown setting {setting!r}; expected one of {sorted(curve_config)}"
            )
        stage_curves = curve_config[setting]

        additional_channel = 32
        self.classifier = classifier
        self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True)

        # Stage 1 (npoint=1024, 64 channels).
        self.cic11 = CIC(npoint=1024, radius=0.05, k=k, in_channels=additional_channel,
                         output_channels=64, bottleneck_ratio=2, mlp_num=1,
                         curve_config=stage_curves[0])
        self.cic12 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64,
                         output_channels=64, bottleneck_ratio=4, mlp_num=1,
                         curve_config=stage_curves[0])

        # Stage 2 (npoint=1024, 128 channels).
        self.cic21 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64,
                         output_channels=128, bottleneck_ratio=2, mlp_num=1,
                         curve_config=stage_curves[1])
        self.cic22 = CIC(npoint=1024, radius=0.1, k=k, in_channels=128,
                         output_channels=128, bottleneck_ratio=4, mlp_num=1,
                         curve_config=stage_curves[1])

        # Stage 3 (npoint=256, 256 channels).
        self.cic31 = CIC(npoint=256, radius=0.1, k=k, in_channels=128,
                         output_channels=256, bottleneck_ratio=2, mlp_num=1,
                         curve_config=stage_curves[2])
        self.cic32 = CIC(npoint=256, radius=0.2, k=k, in_channels=256,
                         output_channels=256, bottleneck_ratio=4, mlp_num=1,
                         curve_config=stage_curves[2])

        # Stage 4 (npoint=64, 512 channels).
        self.cic41 = CIC(npoint=64, radius=0.2, k=k, in_channels=256,
                         output_channels=512, bottleneck_ratio=2, mlp_num=1,
                         curve_config=stage_curves[3])
        self.cic42 = CIC(npoint=64, radius=0.4, k=k, in_channels=512,
                         output_channels=512, bottleneck_ratio=4, mlp_num=1,
                         curve_config=stage_curves[3])

        self.conv0 = nn.Sequential(
            nn.Conv1d(512, emb_dims // 2, kernel_size=1, bias=False),
            nn.BatchNorm1d(emb_dims // 2),
            nn.ReLU(inplace=True),
        )

        # forward() concatenates max- and avg-pooled conv0 outputs, giving
        # 2 * (emb_dims // 2) features. Sizing conv1 from that (instead of
        # emb_dims) avoids a shape mismatch for odd emb_dims. For even emb_dims
        # the layer is unchanged, so existing checkpoints still load.
        pooled_dims = 2 * (emb_dims // 2)
        if self.classifier:
            self.conv1 = nn.Linear(pooled_dims, 512, bias=False)
            self.conv2 = nn.Linear(512, num_classes)
            self.bn1 = nn.BatchNorm1d(512)
            self.dp1 = nn.Dropout(p=0.5)

    def forward(self, xyz, get_flatten_curve_idxs=False):
        """Run the backbone (and the head, if enabled) on a batch of points.

        Args:
            xyz: point tensor laid out as ``self.input_shape``. ``"bnc"`` input
                is permuted to channels-first before the stem.
            get_flatten_curve_idxs: if True, also return the third output of
                every CIC block.

        Returns:
            ``x``: class scores ``[B, num_classes]`` (no softmax) when the
            classifier head is enabled, otherwise the pooled global feature
            ``[B, 2 * (emb_dims // 2)]``.

            When ``get_flatten_curve_idxs`` is True, returns ``(x, idxs)``.
            ``idxs`` maps ``'flatten_curve_idxs_<tag>'`` (tag ``'11'`` ...
            ``'42'``) to that CIC block's third output.
        """
        if self.input_shape == 'bnc':
            xyz = xyz.permute(0, 2, 1)

        # The LPFA stem receives the coordinates as both of its arguments.
        points = self.lpfa(xyz, xyz)

        # Chain the eight CIC blocks in order. Each one consumes the previous
        # block's (positions, features) pair.
        flatten_curve_idxs = {}
        for tag in self._CIC_TAGS:
            cic = getattr(self, f'cic{tag}')
            xyz, points, curve_idxs = cic(xyz, points)
            flatten_curve_idxs[f'flatten_curve_idxs_{tag}'] = curve_idxs

        x = self.conv0(points)
        # Global max and average pooling over the last dimension, concatenated
        # along channels, then flattened to [B, features].
        x_max = F.adaptive_max_pool1d(x, 1)
        x_avg = F.adaptive_avg_pool1d(x, 1)
        x = torch.cat((x_max, x_avg), dim=1).squeeze(-1)

        if self.classifier:
            x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1)
            x = self.dp1(x)
            x = self.conv2(x)

        if get_flatten_curve_idxs:
            return x, flatten_curve_idxs
        return x
|
|