File size: 5,461 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
@Author: Tiange Xiang
@Contact: txia7609@uni.sydney.edu.au
@File: curvenet_cls.py
@Time: 2021/01/21 3:10 PM
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from .. utils import (
    index_points, 
    farthest_point_sample, 
    query_ball_point,
    LPFA,
    CIC
)

def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Pick centroids with farthest_point_sample, then gather the features of the
    neighbours that query_ball_point selects around every centroid.

    Input:
        npoint: number of centroids requested from farthest_point_sample
        radius: radius forwarded to query_ball_point
        nsample: neighbour count forwarded to query_ball_point
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
        returnfps: when true, the neighbour indices are returned as well
    Return:
        new_xyz: ``xyz`` indexed by the sampled centroid indices
            (presumably [B, npoint, 3] -- TODO confirm against index_points)
        new_points: ``points`` indexed by the neighbour indices
            (presumably [B, npoint, nsample, D] -- TODO confirm)
        idx: neighbour indices from query_ball_point, only if ``returnfps``
    """
    # The sampled indices stay a temporary on purpose so that nothing keeps
    # them alive when the cache is emptied right afterwards.
    new_xyz = index_points(xyz, farthest_point_sample(xyz, npoint))
    torch.cuda.empty_cache()

    neighbour_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    torch.cuda.empty_cache()

    grouped_points = index_points(points, neighbour_idx)
    torch.cuda.empty_cache()

    if not returnfps:
        return new_xyz, grouped_points
    return new_xyz, grouped_points, neighbour_idx

# Curve presets selectable through ``CurveNet(setting=...)``.
# Every preset lists four entries, one per encoder stage: entry 0 is handed to
# cic11/cic12 as their ``curve_config`` argument, entry 1 to cic21/cic22,
# entry 2 to cic31/cic32 and entry 3 to cic41/cic42.
# NOTE(review): a pair looks like [number of curves, curve length] and None
# presumably switches curves off for that stage -- confirm against CIC.
curve_config = {
    'default': [[100, 5], [100, 5], None, None],
    'long': [[10, 30], None, None, None],
}

class CurveNet(nn.Module):
    """
    CurveNet encoder with an optional classification head.

    The model applies an LPFA embedding, eight CIC blocks (four stages of two
    blocks each), and a kernel-size-1 Conv1d + BatchNorm + ReLU. The result is
    max- and average-pooled over its last axis and the two pooled vectors are
    concatenated. With ``classifier=True`` a Linear/BatchNorm/ReLU/Dropout/
    Linear head is applied on top.

    Input:
        num_classes: output width of the classification head
        k: forwarded as ``k`` to LPFA and to every CIC block
        setting: key into the module-level ``curve_config`` table
        input_shape: 'bnc' if forward() receives (batch, points, channels),
            'bcn' if it receives (batch, channels, points)
        emb_dims: conv0 emits ``emb_dims // 2`` channels, so the pooled
            feature has ``2 * (emb_dims // 2)`` entries
        classifier: build and apply the classification head when true
    Raises:
        ValueError: if ``input_shape`` or ``setting`` is not recognised
    """

    # Suffixes of the CIC attributes in execution order. forward() uses them
    # both to look up ``self.cic<tag>`` and to name the returned index entries.
    _CIC_TAGS = ('11', '12', '21', '22', '31', '32', '41', '42')

    def __init__(self, num_classes=40, k=20, setting='default', input_shape="bnc", emb_dims=2048, classifier=True):
        super(CurveNet, self).__init__()

        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        # Raised explicitly (not asserted) so the check survives ``python -O``
        # and matches the input_shape validation above.
        if setting not in curve_config:
            raise ValueError(
                "Unknown setting {!r}; expected one of {}".format(setting, sorted(curve_config)))

        self.input_shape = input_shape
        self.classifier = classifier

        additional_channel = 32
        self.lpfa = LPFA(9, additional_channel, k=k, mlp_num=1, initial=True)

        stage_curves = curve_config[setting]

        def make_cic(npoint, radius, in_channels, output_channels, bottleneck_ratio, stage):
            # All blocks share k and mlp_num; ``stage`` selects the curve
            # entry of the chosen setting.
            return CIC(npoint=npoint, radius=radius, k=k, in_channels=in_channels,
                       output_channels=output_channels, bottleneck_ratio=bottleneck_ratio,
                       mlp_num=1, curve_config=stage_curves[stage])

        # encoder -- attribute names and creation order are kept stable so
        # that parameter names and initialisation order do not change.
        self.cic11 = make_cic(1024, 0.05, additional_channel, 64, 2, 0)
        self.cic12 = make_cic(1024, 0.05, 64, 64, 4, 0)

        self.cic21 = make_cic(1024, 0.05, 64, 128, 2, 1)
        self.cic22 = make_cic(1024, 0.1, 128, 128, 4, 1)

        self.cic31 = make_cic(256, 0.1, 128, 256, 2, 2)
        self.cic32 = make_cic(256, 0.2, 256, 256, 4, 2)

        self.cic41 = make_cic(64, 0.2, 256, 512, 2, 3)
        self.cic42 = make_cic(64, 0.4, 512, 512, 4, 3)

        half_dims = emb_dims // 2
        self.conv0 = nn.Sequential(
            nn.Conv1d(512, half_dims, kernel_size=1, bias=False),
            nn.BatchNorm1d(half_dims),
            nn.ReLU(inplace=True))

        if self.classifier:
            # forward() concatenates a max- and an avg-pooled copy of conv0's
            # output, so the head sees 2 * (emb_dims // 2) features. That is
            # emb_dims for even values and avoids a shape mismatch for odd
            # ones.
            self.conv1 = nn.Linear(2 * half_dims, 512, bias=False)
            self.conv2 = nn.Linear(512, num_classes)
            self.bn1 = nn.BatchNorm1d(512)
            self.dp1 = nn.Dropout(p=0.5)

    def forward(self, xyz, get_flatten_curve_idxs=False):
        """
        Run the encoder and, if enabled, the classification head.

        Input:
            xyz: point tensor laid out as declared by ``input_shape``
            get_flatten_curve_idxs: also return the per-block curve indices
        Return:
            x: output of the head when ``classifier`` is true, otherwise the
                concatenated max/avg pooled feature
            flatten_curve_idxs: only if ``get_flatten_curve_idxs``; dict
                mapping 'flatten_curve_idxs_<tag>' to the third value each
                CIC block returned
        """
        if self.input_shape == 'bnc':
            # Switch to the (batch, channels, points) layout.
            xyz = xyz.permute(0, 2, 1)

        points = self.lpfa(xyz, xyz)

        # Each block consumes the coordinates/features emitted by the previous
        # one; its third output is recorded under a per-block key.
        flatten_curve_idxs = {}
        for tag in self._CIC_TAGS:
            xyz, points, curve_idxs = getattr(self, 'cic' + tag)(xyz, points)
            flatten_curve_idxs['flatten_curve_idxs_' + tag] = curve_idxs

        x = self.conv0(points)
        x_max = F.adaptive_max_pool1d(x, 1)
        x_avg = F.adaptive_avg_pool1d(x, 1)

        # Stack both pooled descriptors along the channel axis and drop the
        # pooled (size 1) axis.
        x = torch.cat((x_max, x_avg), dim=1).squeeze(-1)

        if self.classifier:
            # bn1 is fed a trailing singleton axis, which is removed again
            # after the activation.
            x = F.relu(self.bn1(self.conv1(x).unsqueeze(-1)), inplace=True).squeeze(-1)
            x = self.dp1(x)
            x = self.conv2(x)

        if get_flatten_curve_idxs:
            return x, flatten_curve_idxs
        return x