File size: 1,989 Bytes
97aa5af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import torch
import torch.nn.functional as F
from .. utils import knn, get_graph_feature


class DGCNN(torch.nn.Module):
	"""Per-point feature extractor built from five 1x1 Conv2d + BatchNorm2d stages.

	The first four stages are applied in sequence to the output of
	``get_graph_feature``; after each stage the result is max-reduced over its
	last dimension. The four reduced tensors are concatenated along the channel
	dimension (64 + 64 + 128 + 256 = 512 channels) and passed through the fifth
	stage, which produces ``emb_dims`` channels.

	Args:
		emb_dims: number of output channels of the final convolution.
		input_shape: layout of the tensor given to ``forward``; either
			``"bnc"`` (batch, points, channels) or ``"bcn"`` (batch, channels, points).

	Raises:
		ValueError: if ``input_shape`` is not ``"bcn"`` or ``"bnc"``.
	"""

	def __init__(self, emb_dims=1024, input_shape="bnc"):
		super(DGCNN, self).__init__()
		if input_shape not in ["bcn", "bnc"]:
			raise ValueError(
				"Allowed shapes are 'bcn' (batch * channels * num_in_points), "
				"'bnc' (batch * num_in_points * channels)"
			)
		self.input_shape = input_shape
		self.emb_dims = emb_dims

		# Layers keep their original attribute names and creation order so that
		# existing checkpoints and seeded initialisation remain compatible.
		self.conv1 = torch.nn.Conv2d(6, 64, kernel_size=1, bias=False)
		self.conv2 = torch.nn.Conv2d(64, 64, kernel_size=1, bias=False)
		self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=1, bias=False)
		self.conv4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False)
		# 512 = 64 + 64 + 128 + 256, the channel count of the concatenated stage outputs.
		self.conv5 = torch.nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
		self.bn1 = torch.nn.BatchNorm2d(64)
		self.bn2 = torch.nn.BatchNorm2d(64)
		self.bn3 = torch.nn.BatchNorm2d(128)
		self.bn4 = torch.nn.BatchNorm2d(256)
		self.bn5 = torch.nn.BatchNorm2d(emb_dims)

	def forward(self, input_data):
		"""Compute per-point embeddings.

		Args:
			input_data: 3-D tensor laid out according to ``self.input_shape``,
				with exactly 3 channels per point.

		Returns:
			Tensor viewed as (batch, -1, num_points). With ``get_graph_feature``
			returning a 4-D tensor this is (batch, emb_dims, num_points); the
			output is channels-first regardless of ``input_shape``.

		Raises:
			RuntimeError: if ``input_data`` is not 3-D or does not have 3 channels.
		"""
		# Check the rank up front: otherwise a 2-D tensor fails later with an
		# unrelated unpacking/permute error instead of a clear shape message.
		if input_data.dim() != 3:
			raise RuntimeError(
				"input_data must be a 3-D tensor, got %d dimension(s)" % input_data.dim()
			)
		if self.input_shape == "bnc":
			# Bring the data into channels-first layout: (batch, channels, points).
			input_data = input_data.permute(0, 2, 1)
		if input_data.shape[1] != 3:
			raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

		batch_size, _, num_points = input_data.size()
		# NOTE(review): get_graph_feature lives in ..utils and is not visible here;
		# conv1 requires its result to be 4-D with 6 channels — confirm against utils.
		features = get_graph_feature(input_data)

		stages = (
			(self.conv1, self.bn1),
			(self.conv2, self.bn2),
			(self.conv3, self.bn3),
			(self.conv4, self.bn4),
		)
		pooled = []
		for conv, bn in stages:
			features = F.relu(bn(conv(features)))
			# Max-reduce over the last dimension, keeping it as size 1 for the concat.
			pooled.append(features.max(dim=-1, keepdim=True)[0])

		output = torch.cat(pooled, dim=1)

		output = F.relu(self.bn5(self.conv5(output))).view(batch_size, -1, num_points)
		return output


if __name__ == '__main__':
	# Smoke test: push a random batch through a default DGCNN and report the shapes.
	# The batch is built in the default "bnc" layout with 3 channels per point.
	sample_points = torch.rand((10, 1024, 3))
	model = DGCNN()
	embedding = model(sample_points)

	print("\nInput Shape of DGCNN: ", sample_points.shape, "\nOutput Shape of DGCNN: ", embedding.shape)