| |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from .pooling import Pooling |
|
|
class PCN(torch.nn.Module):
    """Encoder/decoder network that maps a point cloud to a coarse (and optionally fine) point cloud.

    Pipeline, as implemented below:
      1. ``encode``: per-point 1x1 convolutions (3 -> 128 -> 256), a pooled
         feature concatenated back onto every point, two more 1x1 convolutions
         (512 -> 512 -> ``emb_dims``) and a final pooling into one global
         feature per cloud.
      2. ``decode``: a three-layer MLP turning the global feature into
         ``num_coarse`` xyz points.
      3. ``fine_decoder`` (only when ``detailed_output`` is true): every coarse
         point is expanded into a ``grid_size x grid_size`` patch, giving
         ``num_fine = grid_size ** 2 * num_coarse`` points.

    Args:
        emb_dims: Channel width of the global feature.
        input_shape: ``"bnc"`` (batch, points, channels) or ``"bcn"``
            (batch, channels, points).
        num_coarse: Number of points in the coarse output.
        grid_size: Side length of the 2D grid folded around each coarse point.
        detailed_output: When true, the folding layers are built and
            ``forward`` also returns ``'fine_output'``.

    Raises:
        ValueError: If ``input_shape`` is neither ``"bcn"`` nor ``"bnc"``.

    Note:
        Layers are kept both as named attributes (``conv1``, ``linear1``, ...),
        which is what registers their parameters, and in plain Python lists
        that fix the execution order.
    """

    def __init__(self, emb_dims=1024, input_shape="bnc", num_coarse=1024, grid_size=4, detailed_output=False):
        super().__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch * channels * num_in_points), 'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims
        self.num_coarse = num_coarse
        self.detailed_output = detailed_output
        self.grid_size = grid_size
        self.num_fine = self.grid_size ** 2 * self.num_coarse
        # Pooling is a project helper from .pooling; the code below needs it to
        # reduce (B, C, N) to (B, C) -- presumably a max over the points axis.
        self.pooling = Pooling('max')

        self.encoder()
        self.decoder_layers = self.decoder()
        if detailed_output:
            self.folding_layers = self.folding()

    def encoder_1(self):
        """Build the first encoder stage (3 -> 128 -> 256 channels) and return its layers in order."""
        self.conv1 = torch.nn.Conv1d(3, 128, 1)
        self.conv2 = torch.nn.Conv1d(128, 256, 1)
        self.relu = torch.nn.ReLU()

        return [self.conv1, self.relu, self.conv2]

    def encoder_2(self):
        """Build the second encoder stage (512 -> 512 -> emb_dims) and return its layers in order."""
        # Input is 2 * 256 channels: per-point features concatenated with the
        # pooled feature broadcast to every point (see ``encode``).
        self.conv3 = torch.nn.Conv1d(2 * 256, 512, 1)
        self.conv4 = torch.nn.Conv1d(512, self.emb_dims, 1)
        self.relu = torch.nn.ReLU()

        return [self.conv3, self.relu, self.conv4]

    def encoder(self):
        """Create both encoder stages and store them on the module."""
        self.encoder_layers1 = self.encoder_1()
        self.encoder_layers2 = self.encoder_2()

    def decoder(self):
        """Build the coarse decoder MLP (emb_dims -> 1024 -> 1024 -> num_coarse * 3)."""
        self.linear1 = torch.nn.Linear(self.emb_dims, 1024)
        self.linear2 = torch.nn.Linear(1024, 1024)
        self.linear3 = torch.nn.Linear(1024, self.num_coarse * 3)
        self.relu = torch.nn.ReLU()

        return [self.linear1, self.relu, self.linear2, self.relu, self.linear3]

    def folding(self):
        """Build the folding convolutions used by ``fine_decoder``."""
        # Input channels = 2 (grid coordinates) + 3 (coarse xyz) + emb_dims
        # (global feature), matching the concatenation in ``fine_decoder``.
        # The previous hard-coded 1029 was only right for emb_dims == 1024.
        self.conv5 = torch.nn.Conv1d(2 + 3 + self.emb_dims, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 512, 1)
        self.conv7 = torch.nn.Conv1d(512, 3, 1)
        self.relu = torch.nn.ReLU()

        return [self.conv5, self.relu, self.conv6, self.relu, self.conv7]

    def fine_decoder(self):
        """Expand every coarse point into a grid patch and return the fine cloud.

        Reads ``self.coarse_output`` (B, num_coarse, 3) and
        ``self.global_feature_v`` (B, emb_dims), so ``encode`` and ``decode``
        must have run first.

        Returns:
            Tensor of shape (B, num_fine, 3).
        """
        coarse = self.coarse_output
        batch_size = coarse.shape[0]
        patch_size = self.grid_size ** 2

        # Build the 2D grid on the same device/dtype as the activations;
        # a CPU grid breaks the concatenation once the model lives on a GPU.
        linspace = torch.linspace(-0.05, 0.05, steps=self.grid_size).to(coarse)
        # Same row order as stack(meshgrid(ij), dim=2).reshape(-1, 2), without
        # the meshgrid "indexing" deprecation warning.
        grid = torch.cartesian_prod(linspace, linspace)            # (G*G, 2)
        grid_feature = grid.unsqueeze(0).repeat([batch_size, self.num_coarse, 1])

        # Each coarse point repeated once per grid cell: (B, num_fine, 3).
        point_feature = coarse.unsqueeze(2).repeat([1, 1, patch_size, 1])
        point_feature = torch.reshape(point_feature, (-1, self.num_fine, 3))

        # Global feature broadcast to every fine point: (B, num_fine, emb_dims).
        global_feature = self.global_feature_v.unsqueeze(1).repeat([1, self.num_fine, 1])

        feature = torch.cat([grid_feature, point_feature, global_feature], dim=2)

        # The folding layers predict an offset that is added to the coarse
        # point each patch was expanded from.
        center = point_feature

        output = feature.permute(0, 2, 1)                          # (B, C, num_fine)
        for layer in self.folding_layers:
            output = layer(output)
        return output.permute(0, 2, 1) + center

    def encode(self, input_data):
        """Encode a (B, 3, N) cloud and store the result in ``self.global_feature_v``."""
        output = input_data
        for layer in self.encoder_layers1:
            output = layer(output)

        # Pool over points, then tile the pooled vector back onto every point.
        # The point count is read from the tensor itself so this method does
        # not depend on ``forward`` having set ``self.num_points``.
        num_points = output.shape[2]
        global_feature_g = self.pooling(output)
        global_feature_g = global_feature_g.unsqueeze(2).repeat(1, 1, num_points)
        output = torch.cat([output, global_feature_g], dim=1)

        for layer in self.encoder_layers2:
            output = layer(output)

        self.global_feature_v = self.pooling(output)

    def decode(self):
        """Run the MLP on ``self.global_feature_v`` and store (B, num_coarse, 3) in ``self.coarse_output``."""
        output = self.global_feature_v
        for layer in self.decoder_layers:
            output = layer(output)
        self.coarse_output = output.view(self.global_feature_v.shape[0], self.num_coarse, 3)

    def forward(self, input_data):
        """Run the network on a point cloud.

        Args:
            input_data: Tensor laid out as declared by ``input_shape``:
                (B, N, 3) for ``"bnc"`` or (B, 3, N) for ``"bcn"``.

        Returns:
            dict with ``'coarse_output'`` (B, num_coarse, 3) and, when
            ``detailed_output`` is true, ``'fine_output'`` (B, num_fine, 3).

        Raises:
            RuntimeError: If the channel dimension is not 3.
        """
        if self.input_shape == "bnc":
            self.num_points = input_data.shape[1]
            # Conv1d expects channels first.
            input_data = input_data.permute(0, 2, 1)
        else:
            self.num_points = input_data.shape[2]
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        self.encode(input_data)
        self.decode()

        result = {'coarse_output': self.coarse_output}

        if self.detailed_output:
            result['fine_output'] = self.fine_decoder()

        return result
|
|
| |
if __name__ == '__main__':
    # Smoke test: push a random batch of 10 clouds (1024 points each, "bnc"
    # layout) through a default PCN and print the module plus tensor shapes.
    x = torch.rand((10, 1024, 3))

    pcn = PCN()
    y = pcn(x)
    print("Network Architecture: ")
    # Was `print(pn)`, an undefined name that raised NameError.
    print(pcn)
    print("Input Shape of PCN: ", x.shape, "\nOutput Shape of PCN: ", y['coarse_output'].shape)