| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from time import time |
| import numpy as np |
| from .. utils import ( |
| pc_normalize, |
| square_distance, |
| index_points, |
| farthest_point_sample, |
| knn_point, |
| query_ball_point |
| ) |
|
|
# The point-cloud ops module is imported on a best-effort basis: if it cannot
# be loaded, a hint is printed and the module still imports.  Code below that
# references `pointutils` will then fail with NameError when it is used.
try:
    from .. utils import pointnet2_utils as pointutils
except Exception:
    # Narrowed from a bare `except:` so KeyboardInterrupt / SystemExit are
    # no longer swallowed while the import is in progress.
    print("Error in pointnet2_utils! Retry setup for pointnet2_utils.")
|
|
def timeit(tag, t):
    """Print the time elapsed since ``t`` and hand back a fresh timestamp.

    Args:
        tag: label printed in front of the elapsed time.
        t: start timestamp, in the same units as ``time.time()``.

    Returns:
        The value of ``time.time()`` taken after printing, so that calls can
        be chained (``t = timeit("step", t)``).
    """
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()
|
|
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Sample ``npoint`` centroids with farthest point sampling and gather a
    ball-query neighbourhood around each of them.

    Input:
        npoint: number of centroids to sample (S)
        radius: radius handed to the ball query
        nsample: neighbourhood size handed to the ball query
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        returnfps: when True, also return the grouped coordinates and the
            sampling indices
    Return:
        new_xyz: positions of the sampled centroids (reshaped to [B, S, 1, C]
            below, i.e. [B, S, C])
        new_points: neighbourhood coordinates relative to their centroid,
            concatenated on the last axis with the grouped point features when
            ``points`` is given (presumably [B, S, nsample, C+D] -- the exact
            shape depends on the utils helpers, TODO confirm)
        grouped_xyz, fps_idx: only returned when ``returnfps`` is True
    """
    batch, _, dims = xyz.shape
    centroid_idx = farthest_point_sample(xyz, npoint)
    new_xyz = index_points(xyz, centroid_idx)
    neighbour_idx = query_ball_point(radius, nsample, xyz, new_xyz, get_cnt=False)
    grouped_xyz = index_points(xyz, neighbour_idx)
    # Express every neighbourhood relative to its own centroid.
    local_xyz = grouped_xyz - new_xyz.view(batch, npoint, 1, dims)
    if points is None:
        new_points = local_xyz
    else:
        neighbour_feats = index_points(points, neighbour_idx)
        new_points = torch.cat([local_xyz, neighbour_feats], dim=-1)
    if returnfps:
        return new_xyz, new_points, grouped_xyz, centroid_idx
    return new_xyz, new_points
|
|
def sample_and_group_all(xyz, points):
    """
    Treat the whole cloud as one group anchored at the origin.

    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
    Return:
        new_xyz: all-zero centroid, [B, 1, C], on the device of ``xyz``
        new_points: ``xyz`` viewed as [B, 1, N, C]; when ``points`` is given it
            is viewed as [B, 1, N, D] and appended on the last axis, giving
            [B, 1, N, C+D]
    """
    batch, n_points, dims = xyz.shape
    new_xyz = torch.zeros(batch, 1, dims).to(xyz.device)
    whole_cloud = xyz.view(batch, 1, n_points, dims)
    if points is None:
        return new_xyz, whole_cloud
    feats = points.view(batch, 1, n_points, -1)
    return new_xyz, torch.cat([whole_cloud, feats], dim=-1)
|
|
|
|
class PointNetSetAbstraction(nn.Module):
    """Set-abstraction layer: sample centroids, group neighbours, run a shared
    1x1-conv MLP over every group and max-pool each group to one feature.

    Args:
        npoint: number of centroids requested from furthest point sampling
            (unused when ``group_all`` is truthy).
        radius: radius passed to ``pointutils.QueryAndGroup``.
        nsample: neighbourhood size passed to ``pointutils.QueryAndGroup``.
        in_channel: number of feature channels of ``points``; the first conv
            takes ``in_channel + 3`` channels (presumably the grouper prepends
            the 3 grouped coordinates -- TODO confirm against pointnet2_utils).
        mlp: output widths of the successive Conv2d + BatchNorm2d + ReLU stages.
        group_all: when truthy, ``pointutils.GroupAll`` is used and no
            sampling takes place.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
        super(PointNetSetAbstraction, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.group_all = group_all
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel + 3
        for out_channel in mlp:
            # bias is redundant because every conv is followed by BatchNorm.
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

        if group_all:
            self.queryandgroup = pointutils.GroupAll()
        else:
            self.queryandgroup = pointutils.QueryAndGroup(radius, nsample)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: centroid positions as returned by
                ``pointutils.gather_operation`` (presumably [B, C, S]); the
                input ``xyz`` itself when ``group_all`` is truthy
            new_points: pooled group features, channels-first with
                ``mlp[-1]`` channels (presumably [B, mlp[-1], S])
        """
        xyz_t = xyz.permute(0, 2, 1).contiguous()

        # Use the same truthiness test as __init__ so that the grouper that
        # was built and the sampling path taken here always agree (the old
        # `== False` comparison disagreed for values such as None).
        if not self.group_all:
            fps_idx = pointutils.furthest_point_sample(xyz_t, self.npoint)
            new_xyz = pointutils.gather_operation(xyz, fps_idx)
        else:
            new_xyz = xyz
        new_points = self.queryandgroup(xyz_t, new_xyz.transpose(2, 1).contiguous(), points)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points = F.relu(bn(conv(new_points)))

        # Max-pool over the last axis (the neighbours of each group).
        new_points = torch.max(new_points, -1)[0]
        return new_xyz, new_points
|
|
class FlowEmbedding(nn.Module):
    """Flow-embedding layer: for every point of cloud 1, gather neighbours in
    cloud 2, combine their features and position offsets with the point's own
    feature, run a shared 1x1-conv MLP and max-pool over the neighbours.

    Args:
        radius: ball-query radius, only used when ``knn`` is False.
        nsample: number of neighbours gathered per point.
        in_channel: channel count of ``feature1`` / ``feature2``.  The first
            conv takes ``2 * in_channel + 3`` channels: neighbour feature, own
            feature and the position offset (3 for xyz input).
        mlp: output widths of the successive Conv2d + BatchNorm2d + ReLU stages.
        pooling: stored on the module; ``forward`` always max-pools.
        corr_func: how the two feature sets are combined; only ``'concat'``
            is implemented.
        knn: if True use k-nearest neighbours, otherwise a ball query that
            falls back to kNN for points whose ball is not full.

    Raises:
        ValueError: if ``corr_func`` is not ``'concat'``.
    """

    def __init__(self, radius, nsample, in_channel, mlp, pooling='max', corr_func='concat', knn=True):
        super(FlowEmbedding, self).__init__()
        # Compare by value: `is 'concat'` only worked through string interning
        # and left `last_channel` unbound for any other input.
        if corr_func != 'concat':
            raise ValueError(
                "Unsupported corr_func {!r}; only 'concat' is implemented.".format(corr_func))
        self.radius = radius
        self.nsample = nsample
        self.knn = knn
        self.pooling = pooling
        self.corr_func = corr_func
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel * 2 + 3
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1, bias=False))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            pos1: (batch_size, 3, npoint1)
            pos2: (batch_size, 3, npoint2)
            feature1: (batch_size, channel, npoint1)
            feature2: (batch_size, channel, npoint2)
        Output:
            pos1: returned unchanged
            feat1_new: max-pooled embedding, channels-first with ``mlp[-1]``
                channels (presumably (batch_size, mlp[-1], npoint1))
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, N, _ = pos1_t.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
        else:
            # Ball query first; wherever the ball holds fewer than `nsample`
            # points, use that point's k nearest neighbours instead.
            idx, cnt = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t, get_cnt=True)
            _, idx_knn = pointutils.knn(self.nsample, pos1_t, pos2_t)
            ball_is_full = cnt.view(B, -1, 1).repeat(1, 1, self.nsample) > (self.nsample - 1)
            # Element-wise select keeps the index tensor's shape.  The previous
            # boolean-mask indexing flattened it and dropped the ball-query
            # result.  The cast makes both branches share the dtype that the
            # kNN path feeds to grouping_operation.
            idx = torch.where(ball_is_full, idx.to(idx_knn.dtype), idx_knn)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)

        feat2_grouped = pointutils.grouping_operation(feature2, idx)
        if self.corr_func == 'concat':
            # Repeat each point's own feature once per neighbour.
            feat1_tiled = feature1.view(B, -1, N, 1).repeat(1, 1, 1, self.nsample)
            feat_diff = torch.cat([feat2_grouped, feat1_tiled], dim=1)
        else:
            raise ValueError(
                "Unsupported corr_func {!r}; only 'concat' is implemented.".format(self.corr_func))

        feat1_new = torch.cat([pos_diff, feat_diff], dim=1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            feat1_new = F.relu(bn(conv(feat1_new)))

        # Max-pool over the neighbour axis.
        feat1_new = torch.max(feat1_new, -1)[0]
        return pos1, feat1_new
|
|
class PointNetSetUpConv(nn.Module):
    """Set up-convolution: propagate features from one point set (``pos2``)
    onto another (``pos1``) by grouping neighbours, running a shared MLP,
    max-pooling, and then mixing in the skip features of ``pos1``.

    Args:
        nsample: number of neighbours gathered per ``pos1`` point.
        radius: ball-query radius, only used when ``knn`` is False.
        f1_channel: channel count of ``feature1`` (the skip connection).  It
            is always added to the input width of the second MLP, so pass 0
            if ``forward`` will be called with ``feature1=None``.
        f2_channel: channel count of ``feature2``.
        mlp: widths of the Conv2d + BatchNorm2d + ReLU stages applied per
            neighbour before pooling (may be empty).
        mlp2: widths of the Conv1d + BatchNorm1d + ReLU stages applied after
            the skip concatenation (may be empty).
        knn: if True use k-nearest neighbours, otherwise a ball query.
    """

    def __init__(self, nsample, radius, f1_channel, f2_channel, mlp, mlp2, knn=True):
        super(PointNetSetUpConv, self).__init__()
        self.nsample = nsample
        self.radius = radius
        self.knn = knn
        self.mlp1_convs = nn.ModuleList()
        self.mlp2_convs = nn.ModuleList()
        # Grouped features are concatenated with a position offset in
        # forward(); the +3 assumes 3-channel (xyz) positions.
        last_channel = f2_channel + 3
        for out_channel in mlp:
            self.mlp1_convs.append(nn.Sequential(nn.Conv2d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm2d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel
        # After the loop last_channel is mlp[-1], or f2_channel + 3 when mlp
        # is empty, so a single update covers both cases (this replaces the
        # `len(mlp) is not 0` identity test on an int).
        last_channel += f1_channel
        for out_channel in mlp2:
            self.mlp2_convs.append(nn.Sequential(nn.Conv1d(last_channel, out_channel, 1, bias=False),
                                                 nn.BatchNorm1d(out_channel),
                                                 nn.ReLU(inplace=False)))
            last_channel = out_channel

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Feature propagation from pos2 (less points) to pos1 (more points)
        Inputs:
            pos1: (batch_size, 3, npoint1)
            pos2: (batch_size, 3, npoint2)
            feature1: (batch_size, channel1, npoint1) features for pos1 points
                (earlier layers, more points), or None
            feature2: (batch_size, channel2, npoint2) features for pos2 points
        Output:
            feat_new: channels-first features, one column per pos1 point
                (presumably (batch_size, C_out, npoint1)), where C_out is
                mlp2[-1] if mlp2 is non-empty, otherwise the width of the
                concatenated pooled + skip features
        TODO: Add support for skip links. Study how delta(XYZ) plays a role in feature updating.
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        B, _, N = pos1.shape
        if self.knn:
            _, idx = pointutils.knn(self.nsample, pos1_t, pos2_t)
        else:
            idx = query_ball_point(self.radius, self.nsample, pos2_t, pos1_t, get_cnt=False)

        pos2_grouped = pointutils.grouping_operation(pos2, idx)
        pos_diff = pos2_grouped - pos1.view(B, -1, N, 1)

        feat2_grouped = pointutils.grouping_operation(feature2, idx)
        feat_new = torch.cat([feat2_grouped, pos_diff], dim=1)
        for conv in self.mlp1_convs:
            feat_new = conv(feat_new)

        # Max-pool over the neighbour axis.
        feat_new = feat_new.max(-1)[0]

        if feature1 is not None:
            feat_new = torch.cat([feat_new, feature1], dim=1)

        for conv in self.mlp2_convs:
            feat_new = conv(feat_new)

        return feat_new
|
|
class PointNetFeaturePropogation(nn.Module):
    """Feature propagation: interpolate features of a sparse point set onto a
    denser one with inverse-distance weights over three neighbours, then
    refine them with a shared Conv1d + BatchNorm1d + ReLU stack.
    """

    def __init__(self, in_channel, mlp):
        """
        Input:
            in_channel: channel count fed to the first conv (interpolated
                features plus skip features, when the latter are used)
            mlp: output widths of the successive conv stages
        """
        super(PointNetFeaturePropogation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        widths = [in_channel] + list(mlp)
        for width_in, width_out in zip(widths[:-1], widths[1:]):
            self.mlp_convs.append(nn.Conv1d(width_in, width_out, 1))
            self.mlp_bns.append(nn.BatchNorm1d(width_out))

    def forward(self, pos1, pos2, feature1, feature2):
        """
        Input:
            pos1: input points position data, [B, C, N]
            pos2: sampled input points position data, [B, C, S]
            feature1: input points data, [B, D, N], or None
            feature2: input points data, [B, D, S]
        Return:
            feat_new: upsampled points data, [B, D', N]
        """
        pos1_t = pos1.permute(0, 2, 1).contiguous()
        pos2_t = pos2.permute(0, 2, 1).contiguous()
        batch, _, n_points = pos1.shape

        dists, idx = pointutils.three_nn(pos1_t, pos2_t)
        # Floor the distances so the reciprocal below stays finite.
        dists[dists < 1e-10] = 1e-10
        inv_dist = 1.0 / dists
        weight = inv_dist / torch.sum(inv_dist, -1, keepdim=True)
        neighbour_feats = pointutils.grouping_operation(feature2, idx)
        interpolated_feat = torch.sum(neighbour_feats * weight.view(batch, 1, n_points, 3), dim=-1)

        if feature1 is None:
            feat_new = interpolated_feat
        else:
            feat_new = torch.cat([interpolated_feat, feature1], 1)

        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            feat_new = F.relu(bn(conv(feat_new)))
        return feat_new
|
|
|
|
class FlowNet3D(nn.Module):
    """Network wiring: two shared set-abstraction levels applied to both
    clouds, one flow-embedding layer that mixes them, two further
    set-abstraction levels on the embedding, three set up-convolutions and a
    feature-propagation layer back to the input resolution of cloud 1, and a
    1x1-conv head that outputs 3 channels per point.
    """

    def __init__(self):
        super(FlowNet3D, self).__init__()

        # Encoder (sa1/sa2 are applied to both clouds, sa3/sa4 only to cloud 1).
        self.sa1 = PointNetSetAbstraction(
            npoint=1024, radius=0.5, nsample=16, in_channel=3,
            mlp=[32, 32, 64], group_all=False)
        self.sa2 = PointNetSetAbstraction(
            npoint=256, radius=1.0, nsample=16, in_channel=64,
            mlp=[64, 64, 128], group_all=False)
        self.sa3 = PointNetSetAbstraction(
            npoint=64, radius=2.0, nsample=8, in_channel=128,
            mlp=[128, 128, 256], group_all=False)
        self.sa4 = PointNetSetAbstraction(
            npoint=16, radius=4.0, nsample=8, in_channel=256,
            mlp=[256, 256, 512], group_all=False)

        # Mixes the level-2 features of the two clouds.
        self.fe_layer = FlowEmbedding(
            radius=10.0, nsample=64, in_channel=128,
            mlp=[128, 128, 128], pooling='max', corr_func='concat')

        # Decoder.
        self.su1 = PointNetSetUpConv(
            nsample=8, radius=2.4, f1_channel=256, f2_channel=512,
            mlp=[], mlp2=[256, 256])
        self.su2 = PointNetSetUpConv(
            nsample=8, radius=1.2, f1_channel=128 + 128, f2_channel=256,
            mlp=[128, 128, 256], mlp2=[256])
        self.su3 = PointNetSetUpConv(
            nsample=8, radius=0.6, f1_channel=64, f2_channel=256,
            mlp=[128, 128, 256], mlp2=[256])
        self.fp = PointNetFeaturePropogation(in_channel=256 + 3, mlp=[256, 256])

        # Per-point head.
        self.conv1 = nn.Conv1d(256, 128, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(128)
        self.conv2 = nn.Conv1d(128, 3, kernel_size=1, bias=True)

    def forward(self, pc1, pc2, feature1, feature2):
        """
        Input:
            pc1, pc2: positions of the two clouds, channels-first
                (presumably [B, 3, N] -- TODO confirm against callers)
            feature1, feature2: per-point input features of the two clouds;
                sa1 and fp are sized for 3 feature channels
        Return:
            output of the final Conv1d: 3 channels for every point of ``pc1``
            (presumably the predicted flow vectors)
        """
        # Cloud 1, levels 1-2.
        pc1_l1, feat1_l1 = self.sa1(pc1, feature1)
        pc1_l2, feat1_l2 = self.sa2(pc1_l1, feat1_l1)

        # Cloud 2, levels 1-2 (same layers, shared weights).
        pc2_l1, feat2_l1 = self.sa1(pc2, feature2)
        pc2_l2, feat2_l2 = self.sa2(pc2_l1, feat2_l1)

        _, embed_l2 = self.fe_layer(pc1_l2, pc2_l2, feat1_l2, feat2_l2)

        pc1_l3, feat1_l3 = self.sa3(pc1_l2, embed_l2)
        pc1_l4, feat1_l4 = self.sa4(pc1_l3, feat1_l3)

        up_l3 = self.su1(pc1_l3, pc1_l4, feat1_l3, feat1_l4)
        skip_l2 = torch.cat([feat1_l2, embed_l2], dim=1)
        up_l2 = self.su2(pc1_l2, pc1_l3, skip_l2, up_l3)
        up_l1 = self.su3(pc1_l1, pc1_l2, feat1_l1, up_l2)
        up_l0 = self.fp(pc1, pc1_l1, feature1, up_l1)

        hidden = F.relu(self.bn1(self.conv1(up_l0)))
        return self.conv2(hidden)
| |
if __name__ == '__main__':
    # Smoke test: push one random batch through the network and print the
    # size of the result.
    import os
    import torch
    os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    # Renamed from `input` so the builtin is not shadowed.
    points = torch.randn((8, 3, 2048))
    model = FlowNet3D()
    # forward() takes (pc1, pc2, feature1, feature2); the earlier two-argument
    # call raised TypeError.  The coordinates double as the 3-channel input
    # features that sa1 (in_channel=3) is sized for.
    # NOTE(review): the pointutils ops presumably need CUDA tensors -- if so,
    # move `model` and `points` to the GPU before calling; confirm.
    output = model(points, points, points, points)
    print(output.size())