import torch
import torch.nn as nn
import torch.nn.functional as F

from .pooling import Pooling


class Classifier(nn.Module):
    """Classification head on top of a point-feature extractor.

    The wrapped ``feature_model`` produces per-point features, which are
    reduced by ``Pooling`` to one global feature vector per sample. That
    vector is then passed through a three-layer MLP
    (emb_dims -> 512 -> 256 -> num_classes).

    Args:
        feature_model: Module that extracts features from the input. It must
            expose an ``emb_dims`` attribute, which sets the input width of
            the first linear layer.
        num_classes: Number of output classes (width of the final layer).
        dropout: Dropout probability applied after each hidden layer.
            Defaults to 0.7.
        pooling: Pooling type name forwarded to ``Pooling``. Defaults to
            ``'max'``. Other accepted values depend on the ``Pooling``
            implementation (not visible here).
    """

    def __init__(self, feature_model, num_classes=40, *, dropout=0.7, pooling='max'):
        super().__init__()
        self.feature_model = feature_model
        self.num_classes = num_classes

        # Attribute names and creation order are kept stable so that existing
        # state_dict checkpoints keep loading.
        self.linear1 = nn.Linear(self.feature_model.emb_dims, 512)
        self.bn1 = nn.BatchNorm1d(512)
        self.dropout1 = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(512, 256)
        self.bn2 = nn.BatchNorm1d(256)
        self.dropout2 = nn.Dropout(p=dropout)
        self.linear3 = nn.Linear(256, self.num_classes)

        self.pooling = Pooling(pooling)

    def forward(self, input_data):
        """Compute class scores for ``input_data``.

        Args:
            input_data: Batch passed unchanged to ``feature_model``. Its
                expected shape depends on that model (the demo below uses
                ``(batch, num_points, 3)``; TODO confirm for other feature
                models).

        Returns:
            Raw, unnormalized scores from ``linear3``. No softmax is applied,
            so the output suits losses such as ``F.cross_entropy``. The
            output presumably has shape ``(batch, num_classes)``, assuming the
            pooled features are ``(batch, emb_dims)`` as ``BatchNorm1d`` on a
            2-D input requires.
        """
        features = self.pooling(self.feature_model(input_data))

        hidden = self.dropout1(F.relu(self.bn1(self.linear1(features))))
        hidden = self.dropout2(F.relu(self.bn2(self.linear2(hidden))))
        return self.linear3(hidden)


if __name__ == '__main__':
    # This module uses a package-relative import (``.pooling``), so it must be
    # run as part of its package (``python -m <package>.<this_module>``). In
    # that mode the sibling ``pointnet`` module must also be imported
    # relatively; a bare ``from pointnet import ...`` would not resolve.
    from .pointnet import PointNet

    x = torch.rand(10, 1024, 3)
    pn = PointNet()
    classifier = Classifier(pn)
    classes = classifier(x)
    print(f'Input Shape: {x.shape}\nClassification Output Shape: {classes.shape}')