| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .. utils import PointConvDensitySetAbstraction |
|
|
| class PointConvDensityClsSsg(torch.nn.Module): |
| def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None): |
| super(PointConvDensityClsSsg, self).__init__() |
| if input_shape not in ["bnc", "bcn"]: |
| raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ") |
| self.input_shape = input_shape |
| self.emb_dims = emb_dims |
| self.classifier = classifier |
| self.input_channel_dim = input_channel_dim |
| self.create_structure() |
| if self.classifier: self.create_classifier(num_classes) |
|
|
| def create_structure(self): |
| |
| |
| |
| |
| |
| |
| |
| self.sa1 = PointConvDensitySetAbstraction(npoint=512, nsample=32, in_channel=self.input_channel_dim, |
| mlp=[64, 64, 128], bandwidth = 0.1, group_all=False) |
| self.sa2 = PointConvDensitySetAbstraction(npoint=128, nsample=64, in_channel=128 + 3, |
| mlp=[128, 128, 256], bandwidth = 0.2, group_all=False) |
| self.sa3 = PointConvDensitySetAbstraction(npoint=1, nsample=None, in_channel=256 + 3, |
| mlp=[256, 512, self.emb_dims], bandwidth = 0.4, group_all=True) |
|
|
| def create_classifier(self, num_classes): |
| |
| |
| |
| self.fc1 = nn.Linear(self.emb_dims, 512) |
| self.bn1 = nn.BatchNorm1d(512) |
| self.drop1 = nn.Dropout(0.7) |
| self.fc2 = nn.Linear(512, 256) |
| self.bn2 = nn.BatchNorm1d(256) |
| self.drop2 = nn.Dropout(0.7) |
| self.fc3 = nn.Linear(256, num_classes) |
|
|
| def forward(self, input_data): |
| if self.input_shape == "bnc": |
| input_data = input_data.permute(0, 2, 1) |
| batch_size = input_data.shape[0] |
|
|
| |
| l1_points, l1_features = self.sa1(input_data[:, :3, :], input_data[:, 3:, :]) |
| l2_points, l2_features = self.sa2(l1_points, l1_features) |
| l3_points, l3_features = self.sa3(l2_points, l2_features) |
| features = l3_features.view(batch_size, self.emb_dims) |
|
|
| if self.classifier: |
| |
| features = self.drop1(F.relu(self.bn1(self.fc1(features)))) |
| features = self.drop2(F.relu(self.bn2(self.fc2(features)))) |
| features = self.fc3(features) |
| output = F.log_softmax(features, -1) |
| else: |
| |
| output = features |
|
|
| return output |
|
|
| def create_pointconv(classifier=False, pretrained=None): |
| if classifier and pretrained is not None: |
| class Network(torch.nn.Module): |
| def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None): |
| |
| |
| |
| |
| |
| |
| |
| super(PointConv, self).__init__() |
| self.pointconv = PointConvDensityClsSsg(emb_dims, input_shape, input_channel_dim, classifier, num_classes) |
| |
| if classifier and pretrained is not None: |
| self.use_pretrained(pretrained) |
|
|
| def use_pretrained(self, pretrained): |
| checkpoint = torch.load(pretrained, map_location='cpu') |
| self.pointconv.load_state_dict(checkpoint['model_state_dict']) |
|
|
| def forward(self, input_data): |
| return self.pointconv(input_data) |
| return Network |
| else: |
| class Network(PointConvDensityClsSsg): |
| def __init__(self, emb_dims=1024, input_shape="bnc", input_channel_dim=3, classifier=False, num_classes=40, pretrained=None): |
| super().__init__(emb_dims=emb_dims, input_shape=input_shape, input_channel_dim=input_channel_dim, classifier=classifier, num_classes=num_classes, pretrained=pretrained) |
| return Network |
|
|
|
|
| if __name__ == '__main__': |
| |
| x = torch.rand((2,1024,3)) |
|
|
| PointConv = create_pointconv(classifier=False, pretrained='checkpoint.pth') |
| pc = PointConv(input_channel_dim=3, classifier=False, pretrained='checkpoint.pth') |
| y = pc(x) |
| print("Network Architecture: ") |
| print(pc) |
| print("Input Shape of PointNet: ", x.shape, "\nOutput Shape of PointNet: ", y.shape) |