YasiiKB's picture
initial commit
97aa5af verified
"""
@Author: Tiange Xiang
@Contact: txia7609@uni.sydney.edu.au
@File: curvenet_cls.py
@Time: 2021/01/21 3:10 PM
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .. utils import (
index_points,
farthest_point_sample,
query_ball_point,
LPFA,
CIC
)
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False, release_cuda_cache=True):
    """
    Pick centroids with farthest point sampling and gather neighbour features around them.

    Input:
        npoint: number of centroids, passed to ``farthest_point_sample``.
        radius: ball-query radius, passed to ``query_ball_point``.
        nsample: neighbours gathered per centroid, passed to ``query_ball_point``.
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: if True, also return ``idx``. NOTE: despite the name, this
            is the neighbour index produced by ``query_ball_point``, not the
            FPS centroid index.
        release_cuda_cache: if True (default, the original behaviour), call
            ``torch.cuda.empty_cache()`` after each of the three steps.
            Pass False to skip those calls. Per the PyTorch docs,
            ``empty_cache`` does not give PyTorch itself more usable memory.
    Return:
        new_xyz: sampled centroid positions. This is ``xyz`` gathered at the
            FPS indices, so presumably [B, npoint, 3] (no nsample axis).
        new_points: grouped features. This is ``points`` gathered at the
            ball-query indices, so presumably [B, npoint, nsample, D]. No
            relative coordinates are concatenated, so the last dim is D, not 3+D.
        idx: only when ``returnfps`` is True; the ball-query indices.
    """
    def _maybe_release():
        # Kept optional: the original unconditionally flushed the CUDA
        # caching allocator after every step.
        if release_cuda_cache:
            torch.cuda.empty_cache()

    new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint))
    _maybe_release()
    idx = query_ball_point(radius, nsample, xyz, new_xyz)
    _maybe_release()
    new_points = index_points(points, idx)
    _maybe_release()
    if returnfps:
        return new_xyz, new_points, idx
    return new_xyz, new_points
# Curve configuration per ``setting`` name. Each value has one entry per
# encoder stage of CurveNet (four stages). Entry i is handed unchanged, as
# ``curve_config``, to both CIC blocks of stage i + 1; ``None`` is forwarded as-is.
# NOTE(review): each pair is presumably [number of curves, curve length] as
# interpreted by CIC; confirm in ..utils.
curve_config = dict(
    default=[[100, 5], [100, 5], None, None],
    long=[[10, 30], None, None, None],
)
class CurveNet(nn.Module):
    """CurveNet point-cloud backbone with an optional classification head.

    The input coordinates are first lifted by an ``LPFA`` layer, then passed
    through eight ``CIC`` blocks arranged as four stages of two blocks. The
    result is globally max- and average-pooled. When ``classifier`` is True,
    the pooled embedding is mapped to ``num_classes`` logits.

    Args:
        num_classes: number of logits produced by the classification head.
        k: forwarded as ``k`` to ``LPFA`` and every ``CIC`` block
            (presumably a neighbourhood size; confirm in ``..utils``).
        setting: key into the module-level ``curve_config`` table that selects
            the per-stage curve configuration.
        input_shape: input layout, either ``'bnc'`` (batch, num_points,
            channels) or ``'bcn'`` (batch, channels, num_points).
        emb_dims: embedding width. ``conv0`` emits ``emb_dims // 2`` channels,
            which are max- and average-pooled and concatenated.
        classifier: if False, ``forward`` returns the pooled embedding
            instead of class logits, and no head layers are created.

    Raises:
        ValueError: if ``input_shape`` or ``setting`` is not recognised.
    """

    def __init__(self, num_classes=40, k=20, setting='default', input_shape="bnc", emb_dims=2048, classifier=True):
        super().__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), "
                             "'bnc' (batch * num_in_points * channels)")
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if setting not in curve_config:
            raise ValueError("Unknown setting {!r}; expected one of {}".format(
                setting, sorted(curve_config)))
        self.input_shape = input_shape
        self.classifier = classifier
        additional_channel = 32
        stage_curves = curve_config[setting]

        # NOTE(review): 9 input channels with initial=True presumably means
        # LPFA builds its own features from xyz; confirm in ..utils.
        self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True)

        # encoder: four stages of two CIC blocks. Stage i (1-based) uses
        # stage_curves[i - 1]. Construction order is kept identical to the
        # original so parameter initialisation consumes the RNG the same way.
        self.cic11 = CIC(npoint=1024, radius=0.05, k=k, in_channels=additional_channel, output_channels=64,
                         bottleneck_ratio=2, mlp_num=1, curve_config=stage_curves[0])
        self.cic12 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=64,
                         bottleneck_ratio=4, mlp_num=1, curve_config=stage_curves[0])
        self.cic21 = CIC(npoint=1024, radius=0.05, k=k, in_channels=64, output_channels=128,
                         bottleneck_ratio=2, mlp_num=1, curve_config=stage_curves[1])
        self.cic22 = CIC(npoint=1024, radius=0.1, k=k, in_channels=128, output_channels=128,
                         bottleneck_ratio=4, mlp_num=1, curve_config=stage_curves[1])
        self.cic31 = CIC(npoint=256, radius=0.1, k=k, in_channels=128, output_channels=256,
                         bottleneck_ratio=2, mlp_num=1, curve_config=stage_curves[2])
        self.cic32 = CIC(npoint=256, radius=0.2, k=k, in_channels=256, output_channels=256,
                         bottleneck_ratio=4, mlp_num=1, curve_config=stage_curves[2])
        self.cic41 = CIC(npoint=64, radius=0.2, k=k, in_channels=256, output_channels=512,
                         bottleneck_ratio=2, mlp_num=1, curve_config=stage_curves[3])
        self.cic42 = CIC(npoint=64, radius=0.4, k=k, in_channels=512, output_channels=512,
                         bottleneck_ratio=4, mlp_num=1, curve_config=stage_curves[3])

        self.conv0 = nn.Sequential(
            nn.Conv1d(512, emb_dims // 2, kernel_size=1, bias=False),
            nn.BatchNorm1d(emb_dims // 2),
            nn.ReLU(inplace=True))
        # Width after concatenating the max- and average-pooled conv0 outputs.
        # This equals emb_dims when emb_dims is even. Using it (rather than
        # emb_dims) keeps the head's input size correct for odd emb_dims too.
        pooled_dims = 2 * (emb_dims // 2)
        if self.classifier:
            self.conv1 = nn.Linear(pooled_dims, 512, bias=False)
            self.conv2 = nn.Linear(512, num_classes)
            self.bn1 = nn.BatchNorm1d(512)
            self.dp1 = nn.Dropout(p=0.5)

    def forward(self, xyz, get_flatten_curve_idxs=False):
        """Run the backbone, and the head if enabled, on a batch of point clouds.

        Args:
            xyz: input tensor in the layout given by ``input_shape``. It is
                permuted to channels-first when the layout is ``'bnc'``.
            get_flatten_curve_idxs: if True, also return the third output of
                every CIC block.

        Returns:
            ``x``: logits of width ``num_classes`` when ``classifier`` is True,
            otherwise the pooled embedding of width ``2 * (emb_dims // 2)``.
            When ``get_flatten_curve_idxs`` is True, the result is
            ``(x, flatten_curve_idxs)``. The dict maps
            ``'flatten_curve_idxs_<tag>'`` to the third output of ``cic<tag>``.
        """
        flatten_curve_idxs = {}
        if self.input_shape == 'bnc':
            xyz = xyz.permute(0, 2, 1)
        l0_points = self.lpfa(xyz, xyz)

        # stage 1
        l1_xyz, l1_points, flatten_curve_idxs['flatten_curve_idxs_11'] = self.cic11(xyz, l0_points)
        l1_xyz, l1_points, flatten_curve_idxs['flatten_curve_idxs_12'] = self.cic12(l1_xyz, l1_points)
        # stage 2
        l2_xyz, l2_points, flatten_curve_idxs['flatten_curve_idxs_21'] = self.cic21(l1_xyz, l1_points)
        l2_xyz, l2_points, flatten_curve_idxs['flatten_curve_idxs_22'] = self.cic22(l2_xyz, l2_points)
        # stage 3
        l3_xyz, l3_points, flatten_curve_idxs['flatten_curve_idxs_31'] = self.cic31(l2_xyz, l2_points)
        l3_xyz, l3_points, flatten_curve_idxs['flatten_curve_idxs_32'] = self.cic32(l3_xyz, l3_points)
        # stage 4
        l4_xyz, l4_points, flatten_curve_idxs['flatten_curve_idxs_41'] = self.cic41(l3_xyz, l3_points)
        l4_xyz, l4_points, flatten_curve_idxs['flatten_curve_idxs_42'] = self.cic42(l4_xyz, l4_points)

        x = self.conv0(l4_points)
        # Global pooling over the point axis; max and mean are concatenated
        # along the channel axis.
        x_max = F.adaptive_max_pool1d(x, 1)
        x_avg = F.adaptive_avg_pool1d(x, 1)
        x = torch.cat((x_max, x_avg), dim=1).squeeze(-1)
        if self.classifier:
            # bn1 is a BatchNorm1d applied to an explicit trailing length-1 axis.
            x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1)
            x = self.dp1(x)
            x = self.conv2(x)
        if get_flatten_curve_idxs:
            return x, flatten_curve_idxs
        return x