| import torch |
| import torch.nn.functional as F |
| from .. utils import knn, get_graph_feature |
|
|
|
|
| class DGCNN(torch.nn.Module): |
| def __init__(self, emb_dims=1024, input_shape="bnc"): |
| super(DGCNN, self).__init__() |
| if input_shape not in ["bcn", "bnc"]: |
| raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ") |
| self.input_shape = input_shape |
| self.emb_dims = emb_dims |
|
|
| self.conv1 = torch.nn.Conv2d(6, 64, kernel_size=1, bias=False) |
| self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=1, bias=False) |
| self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=1, bias=False) |
| self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False) |
| self.conv5 = torch.nn.Conv2d(512, emb_dims, kernel_size=1, bias=False) |
| self.bn1 = torch.nn.BatchNorm2d(64) |
| self.bn2 = torch.nn.BatchNorm2d(64) |
| self.bn3 = torch.nn.BatchNorm2d(128) |
| self.bn4 = torch.nn.BatchNorm2d(256) |
| self.bn5 = torch.nn.BatchNorm2d(emb_dims) |
|
|
| def forward(self, input_data): |
| if self.input_shape == "bnc": |
| input_data = input_data.permute(0, 2, 1) |
| if input_data.shape[1] != 3: |
| raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]") |
|
|
| batch_size, num_dims, num_points = input_data.size() |
| output = get_graph_feature(input_data) |
|
|
| output = F.relu(self.bn1(self.conv1(output))) |
| output1 = output.max(dim=-1, keepdim=True)[0] |
|
|
| output = F.relu(self.bn2(self.conv2(output))) |
| output2 = output.max(dim=-1, keepdim=True)[0] |
|
|
| output = F.relu(self.bn3(self.conv3(output))) |
| output3 = output.max(dim=-1, keepdim=True)[0] |
|
|
| output = F.relu(self.bn4(self.conv4(output))) |
| output4 = output.max(dim=-1, keepdim=True)[0] |
|
|
| output = torch.cat((output1, output2, output3, output4), dim=1) |
|
|
| output = F.relu(self.bn5(self.conv5(output))).view(batch_size, -1, num_points) |
| return output |
|
|
|
|
| if __name__ == '__main__': |
| |
| x = torch.rand((10,1024,3)) |
|
|
| dgcnn = DGCNN() |
| y = dgcnn(x) |
| print("\nInput Shape of DGCNN: ", x.shape, "\nOutput Shape of DGCNN: ", y.shape) |