import torch
import torch.nn as nn
import torch.nn.functional as F
from .pointnet import PointNet
from .pooling import Pooling

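class PointNetMask(nn.Module):
	def __init__(self, template_feature_size=1024, source_feature_size=1024, feature_model=None):
		super().__init__()
		# Build the default feature extractor per instance. A module used as a default
		# argument is created once at import time and shared by every PointNetMask.
		self.feature_model = feature_model if feature_model is not None else PointNet()
		self.pooling = Pooling()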

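		# Per-point head: each template point feature is concatenated with the global
		# source feature and mapped by 1x1 convolutions to a score in (0, 1).
		input_size = template_feature_size + source_feature_size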
		self.h3 = nn.Sequential(nn.Conv1d(input_size, 1024, 1), nn.ReLU(),
								nn.Conv1d(1024, 512, 1), nn.ReLU(),
								nn.Conv1d(512, 256, 1), nn.ReLU(),
								nn.Conv1d(256, 128, 1), nn.ReLU(),
								nn.Conv1d(128,   1, 1), nn.Sigmoid())

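	def find_mask(self, x, t_out_h1):
		# x: global source feature [B x C]; t_out_h1: per-point template features [B x C x N]
		batch_size, _, num_points = t_out_h1.size()
		x = x.unsqueeze(2)						# [B x C x 1]
		x = x.repeat(1, 1, num_points)			# [B x C x N], same global feature for every point
		x = torch.cat([t_out_h1, x], dim=1)		# [B x 2C x N]
		x = self.h3(x)							# [B x 1 x N], scores in (0, 1)
		return x.view(batch_size, -1)			# [B x N]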

	def forward(self, template, source):
		source_features = self.feature_model(source)				# [B x C x N]
		template_features = self.feature_model(template)			# [B x C x N]

		source_features = self.pooling(source_features)				# [B x C], global source feature
		mask = self.find_mask(source_features, template_features)	# [B x N], one score per template point
		return mask


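class MaskNet(nn.Module):
	def __init__(self, feature_model=None, is_training=True):
		super().__init__()
		# Create the default feature extractor per instance instead of once at import time.
		if feature_model is None:
			feature_model = PointNet(use_bn=True)
		self.maskNet = PointNetMask(feature_model=feature_model)
		self.is_training = is_training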

	@staticmethod
	def index_points(points, idx):
		"""
		Input:
			points: input points data, [B, N, C]
			idx: sample index data, [B, S]
		Return:
			new_points: indexed points data, [B, S, C]
		"""
		device = points.device
		B = points.shape[0]
		view_shape = list(idx.shape)
		view_shape[1:] = [1] * (len(view_shape) - 1)
		repeat_shape = list(idx.shape)
		repeat_shape[0] = 1
		batch_indices = torch.arange(B, dtype=torch.long, device=device).view(view_shape).repeat(repeat_shape)
		new_points = points[batch_indices, idx, :]
		return new_points

	# Only valid for a single pair of point clouds (batch size 1): it thresholds the
	# first mask in the batch at 0.5 and returns the kept indices with shape [1 x K].
	@staticmethod
	def find_index(mask_val):
		mask_idx = torch.nonzero(mask_val[0] > 0.5)
		return mask_idx.view(1, -1)

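	def forward(self, template, source, point_selection='threshold'):
		mask = self.maskNet(template, source)		# [B x N_template]

		# In training mode top-k selection is always used, whatever point_selection says.
		if point_selection == 'topk' or self.is_training:
			# Keep the source.shape[1] highest-scoring template points
			# (requires source.shape[1] <= template.shape[1]).
			_, self.mask_idx = torch.topk(mask, source.shape[1], dim=1, sorted=False)
		elif point_selection == 'threshold':
			self.mask_idx = self.find_index(mask)
		else:
			raise ValueError(f"point_selection must be 'topk' or 'threshold', got {point_selection!r}")

		template = self.index_points(template, self.mask_idx)
		return template, mask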

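# Illustrative sketch only: _demo_point_selection is a hypothetical helper that is not
# part of the original model and is not called by MaskNet. On a toy [1 x 5] mask it
# shows how the two point_selection modes in MaskNet.forward choose template indices,
# assuming batch size 1 (which find_index requires).
def _demo_point_selection():
	import torch  # also imported at module level; repeated so the sketch runs on its own
	mask = torch.tensor([[0.9, 0.2, 0.7, 0.4, 0.6]])
	# 'topk': keep the k highest-scoring points (k = 3 here); with sorted=False
	# the order of the returned indices is not guaranteed.
	_, topk_idx = torch.topk(mask, 3, dim=1, sorted=False)
	# 'threshold': keep every point whose score exceeds 0.5, returned as [1 x K].
	threshold_idx = torch.nonzero(mask[0] > 0.5).view(1, -1)
	return topk_idx, threshold_idx
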

if __name__ == '__main__':
	template, source = torch.rand(10, 1024, 3), torch.rand(10, 1024, 3)
	net = MaskNet()
	masked_template, mask = net(template, source)
	print(masked_template.shape, mask.shape)