# File size: 1,989 Bytes (97aa5af)
import torch
import torch.nn.functional as F
from .. utils import knn, get_graph_feature
class DGCNN(torch.nn.Module):
    """Dynamic Graph CNN (DGCNN) encoder that returns one embedding per input point.

    ``forward`` builds edge features with ``get_graph_feature`` and runs four
    1x1 ``Conv2d`` -> ``BatchNorm2d`` -> ReLU stages on them. After each stage
    the last axis is max-pooled. The four pooled maps (64 + 64 + 128 + 256 =
    512 channels) are concatenated and projected to ``emb_dims`` channels by a
    fifth conv/bn/ReLU stage.

    Args:
        emb_dims: Number of output feature channels per point.
        input_shape: Layout of the tensor given to ``forward``:
            ``"bcn"`` -> (batch, channels, num_in_points),
            ``"bnc"`` -> (batch, num_in_points, channels).
            In both layouts ``channels`` must be 3.

    Raises:
        ValueError: If ``input_shape`` is neither ``"bcn"`` nor ``"bnc"``.
    """

    def __init__(self, emb_dims=1024, input_shape="bnc"):
        super().__init__()
        if input_shape not in ("bcn", "bnc"):
            raise ValueError(
                "Allowed shapes are 'bcn' (batch * channels * num_in_points), "
                "'bnc' (batch * num_in_points * channels); got {!r}".format(input_shape)
            )
        self.input_shape = input_shape
        self.emb_dims = emb_dims

        # Layer attribute names (conv1..conv5, bn1..bn5) and their creation
        # order determine the state_dict keys, so they are kept unchanged to
        # stay compatible with saved checkpoints.
        # conv1 consumes the 6-channel tensor that get_graph_feature returns
        # for 3-channel points (how the 6 channels are formed lives in ..utils).
        self.conv1 = torch.nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=1, bias=False)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=1, bias=False)
        self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False)
        # 512 = 64 + 64 + 128 + 256: the four pooled stage outputs concatenated.
        self.conv5 = torch.nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(64)
        self.bn2 = torch.nn.BatchNorm2d(64)
        self.bn3 = torch.nn.BatchNorm2d(128)
        self.bn4 = torch.nn.BatchNorm2d(256)
        self.bn5 = torch.nn.BatchNorm2d(emb_dims)

    def forward(self, input_data):
        """Compute a feature vector for every point.

        Args:
            input_data: 3-D tensor laid out according to ``self.input_shape``
                with exactly 3 coordinate channels.

        Returns:
            Tensor viewed as (batch, -1, num_in_points); this is
            (batch, emb_dims, num_in_points) assuming get_graph_feature keeps
            the point axis at dim 2 (TODO confirm in ..utils).

        Raises:
            RuntimeError: If the input is not 3-D or does not have 3 channels.
        """
        given_shape = tuple(input_data.shape)
        if input_data.dim() != 3:
            raise RuntimeError(
                "expected a 3-D input tensor, got shape {}".format(given_shape)
            )
        if self.input_shape == "bnc":
            # Everything below works on (batch, channels, num_points).
            input_data = input_data.permute(0, 2, 1)
        if input_data.shape[1] != 3:
            if self.input_shape == "bnc":
                expected = "[Batch x NumInPoints x 3]"
            else:
                expected = "[Batch x 3 x NumInPoints]"
            raise RuntimeError(
                "shape of x must be of {} for input_shape={!r}, got {}".format(
                    expected, self.input_shape, given_shape
                )
            )

        batch_size, _, num_points = input_data.size()
        output = get_graph_feature(input_data)

        pooled = []
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, bn in stages:
            # Each stage feeds the next; its max over the last axis
            # (presumably the neighbour axis) is kept for the final concat.
            output = F.relu(bn(conv(output)))
            pooled.append(output.max(dim=-1, keepdim=True)[0])

        output = torch.cat(pooled, dim=1)
        output = F.relu(self.bn5(self.conv5(output)))
        return output.view(batch_size, -1, num_points)
if __name__ == '__main__':
    # Smoke test: push one random batch of point clouds through the network
    # with the default 'bnc' layout and report the tensor shapes.
    cloud_dims = (10, 1024, 3)
    points = torch.rand(cloud_dims)
    model = DGCNN()
    features = model(points)
    print(
        "\nInput Shape of DGCNN: ", points.shape,
        "\nOutput Shape of DGCNN: ", features.shape,
    )