YasiiKB's picture
initial commit
97aa5af verified
import torch
import torch.nn.functional as F
from .. utils import knn, get_graph_feature
class DGCNN(torch.nn.Module):
def __init__(self, emb_dims=1024, input_shape="bnc"):
super(DGCNN, self).__init__()
if input_shape not in ["bcn", "bnc"]:
raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
self.input_shape = input_shape
self.emb_dims = emb_dims
self.conv1 = torch.nn.Conv2d(6, 64, kernel_size=1, bias=False)
self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=1, bias=False)
self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=1, bias=False)
self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False)
self.conv5 = torch.nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
self.bn1 = torch.nn.BatchNorm2d(64)
self.bn2 = torch.nn.BatchNorm2d(64)
self.bn3 = torch.nn.BatchNorm2d(128)
self.bn4 = torch.nn.BatchNorm2d(256)
self.bn5 = torch.nn.BatchNorm2d(emb_dims)
def forward(self, input_data):
if self.input_shape == "bnc":
input_data = input_data.permute(0, 2, 1)
if input_data.shape[1] != 3:
raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")
batch_size, num_dims, num_points = input_data.size()
output = get_graph_feature(input_data)
output = F.relu(self.bn1(self.conv1(output)))
output1 = output.max(dim=-1, keepdim=True)[0]
output = F.relu(self.bn2(self.conv2(output)))
output2 = output.max(dim=-1, keepdim=True)[0]
output = F.relu(self.bn3(self.conv3(output)))
output3 = output.max(dim=-1, keepdim=True)[0]
output = F.relu(self.bn4(self.conv4(output)))
output4 = output.max(dim=-1, keepdim=True)[0]
output = torch.cat((output1, output2, output3, output4), dim=1)
output = F.relu(self.bn5(self.conv5(output))).view(batch_size, -1, num_points)
return output
if __name__ == '__main__':
# Test the code.
x = torch.rand((10,1024,3))
dgcnn = DGCNN()
y = dgcnn(x)
print("\nInput Shape of DGCNN: ", x.shape, "\nOutput Shape of DGCNN: ", y.shape)