#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""PRNet-style point cloud registration network.

The module defines point embedding networks (``PointNet``, ``DGCNN``),
pose heads (``MLPHead``, ``SVDHead``), a temperature predictor
(``TemperatureNet``), a key point selector (``KeyPointNet``) and the
iterative ``PRNet`` model that ties them together.
"""

import copy
import glob
import json
import math
import os
import sys

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import r2_score
from tqdm import tqdm

from ..ops import transform_functions as transform
from ..utils import Transformer, Identity, knn, get_graph_feature

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def pairwise_distance(src, tgt):
    """Return Euclidean distances between every column of ``src`` and ``tgt``.

    Both inputs are indexed as 3-D tensors whose dim 1 is the coordinate /
    feature axis (the squared norm is summed over dim 1), so for inputs of
    shape (B, C, N) and (B, C, M) the result has shape (B, N, M).
    """
    inner = -2 * torch.matmul(src.transpose(2, 1).contiguous(), tgt)
    xx = torch.sum(src ** 2, dim=1, keepdim=True)
    yy = torch.sum(tgt ** 2, dim=1, keepdim=True)
    distances = xx.transpose(2, 1).contiguous() + inner + yy
    # ||a||^2 - 2ab + ||b||^2 can dip slightly below zero through rounding;
    # clamp so that sqrt never produces NaN.
    return torch.sqrt(torch.clamp(distances, min=0.0))


def cycle_consistency(rotation_ab, translation_ab, rotation_ba, translation_ba):
    """Penalise forward/backward pose estimates that do not cancel out.

    Returns ``mse(R_ab @ R_ba, I) + mse(t_ab, -t_ba)``.
    """
    batch_size = rotation_ab.size(0)
    identity = torch.eye(3, device=rotation_ab.device).unsqueeze(0).repeat(batch_size, 1, 1)
    rotation_term = F.mse_loss(torch.matmul(rotation_ab, rotation_ba), identity)
    translation_term = F.mse_loss(translation_ab, -translation_ba)
    return rotation_term + translation_term


def quat2mat(quat):
    """Convert a batch of quaternions of shape (B, 4) to (B, 3, 3) matrices.

    NOTE(review): the original file called ``quat2mat`` without importing or
    defining it. This implementation reads the components as (x, y, z, w);
    confirm that ordering against any project-level helper of the same name.
    """
    x, y, z, w = quat[:, 0], quat[:, 1], quat[:, 2], quat[:, 3]
    batch_size = quat.size(0)
    w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
    wx, wy, wz = w * x, w * y, w * z
    xy, xz, yz = x * y, x * z, y * z
    return torch.stack(
        [
            w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz,
            2 * wz + 2 * xy, w2 - x2 + y2 - z2, 2 * yz - 2 * wx,
            2 * xz - 2 * wy, 2 * wx + 2 * yz, w2 - x2 - y2 + z2,
        ],
        dim=1,
    ).reshape(batch_size, 3, 3)


class PointNet(nn.Module):
    """Per-point embedding built from five 1x1 ``Conv1d`` + BatchNorm + ReLU layers.

    The first convolution takes 3 input channels; the last one outputs
    ``emb_dims`` channels.
    """

    def __init__(self, emb_dims=512):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.conv4 = nn.Conv1d(64, 128, kernel_size=1, bias=False)
        self.conv5 = nn.Conv1d(128, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn3 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(emb_dims)

    def forward(self, x):
        """Apply the five conv/bn/relu stages in order and return the result."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.relu(self.bn5(self.conv5(x)))
        return x


class DGCNN(nn.Module):
    """Graph-feature embedding network with four edge-conv stages.

    Each stage calls ``get_graph_feature``, applies a 1x1 ``Conv2d`` +
    BatchNorm + LeakyReLU(0.2) and max-pools over the last axis. The four
    pooled outputs (64 + 64 + 128 + 256 = 512 channels) are concatenated and
    projected to ``emb_dims`` channels.
    """

    def __init__(self, emb_dims=512):
        super().__init__()
        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64 * 2, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64 * 2, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128 * 2, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)

    def forward(self, x):
        """Embed ``x`` of shape (batch, dims, points) to (batch, -1, points)."""
        batch_size, _, num_points = x.size()
        # Use the input's own device rather than the module-level default so
        # the network also works when it lives on a non-default device.
        x = get_graph_feature(x, device=x.device)
        x = F.leaky_relu(self.bn1(self.conv1(x)), negative_slope=0.2)
        x1 = x.max(dim=-1, keepdim=True)[0]

        x = get_graph_feature(x1, device=x1.device)
        x = F.leaky_relu(self.bn2(self.conv2(x)), negative_slope=0.2)
        x2 = x.max(dim=-1, keepdim=True)[0]

        x = get_graph_feature(x2, device=x2.device)
        x = F.leaky_relu(self.bn3(self.conv3(x)), negative_slope=0.2)
        x3 = x.max(dim=-1, keepdim=True)[0]

        x = get_graph_feature(x3, device=x3.device)
        x = F.leaky_relu(self.bn4(self.conv4(x)), negative_slope=0.2)
        x4 = x.max(dim=-1, keepdim=True)[0]

        x = torch.cat((x1, x2, x3, x4), dim=1)
        x = F.leaky_relu(self.bn5(self.conv5(x)), negative_slope=0.2)
        return x.view(batch_size, -1, num_points)


class MLPHead(nn.Module):
    """Regress a rotation (via a quaternion) and a translation with an MLP."""

    def __init__(self, emb_dims):
        super().__init__()
        n_emb_dims = emb_dims
        self.n_emb_dims = n_emb_dims
        self.nn = nn.Sequential(
            nn.Linear(n_emb_dims * 2, n_emb_dims // 2),
            nn.BatchNorm1d(n_emb_dims // 2),
            nn.ReLU(),
            nn.Linear(n_emb_dims // 2, n_emb_dims // 4),
            nn.BatchNorm1d(n_emb_dims // 4),
            nn.ReLU(),
            nn.Linear(n_emb_dims // 4, n_emb_dims // 8),
            nn.BatchNorm1d(n_emb_dims // 8),
            nn.ReLU(),
        )
        self.proj_rot = nn.Linear(n_emb_dims // 8, 4)
        self.proj_trans = nn.Linear(n_emb_dims // 8, 3)

    def forward(self, *inputs):
        """Return ``(rotation_matrix, translation)``.

        Only the first two positional arguments (source and target
        embeddings) are used. They are concatenated along dim 1 and
        max-pooled over the last axis before entering the MLP.
        """
        src_embedding, tgt_embedding = inputs[0], inputs[1]
        embedding = torch.cat((src_embedding, tgt_embedding), dim=1)
        embedding = self.nn(embedding.max(dim=-1)[0])
        # F.normalize guards against a zero-norm quaternion (division by 0).
        rotation = F.normalize(self.proj_rot(embedding), p=2, dim=1)
        translation = self.proj_trans(embedding)
        return quat2mat(rotation), translation


class TemperatureNet(nn.Module):
    """Predict a per-sample temperature from the embedding disparity."""

    def __init__(self, emb_dims, temp_factor):
        super().__init__()
        self.n_emb_dims = emb_dims
        self.temp_factor = temp_factor
        self.nn = nn.Sequential(
            nn.Linear(self.n_emb_dims, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.ReLU(),
        )
        # Last residual computed by forward(); None until the first call.
        self.feature_disparity = None

    def forward(self, *inputs):
        """Return ``(temperature, residual)``.

        ``residual`` is the absolute difference of the two embeddings after
        averaging each over dim 2. ``temperature`` is the MLP output clamped
        to ``[1 / temp_factor, temp_factor]``.
        """
        src_embedding = inputs[0].mean(dim=2)
        tgt_embedding = inputs[1].mean(dim=2)
        residual = torch.abs(src_embedding - tgt_embedding)
        self.feature_disparity = residual
        temperature = torch.clamp(
            self.nn(residual), 1.0 / self.temp_factor, 1.0 * self.temp_factor
        )
        return temperature, residual


class SVDHead(nn.Module):
    """Estimate a rigid transform from soft correspondences via SVD."""

    def __init__(self, emb_dims, cat_sampler):
        super().__init__()
        self.n_emb_dims = emb_dims
        self.cat_sampler = cat_sampler
        # Kept as (non-trainable / trainable) parameters so existing
        # checkpoints containing these keys still load.
        self.reflect = nn.Parameter(torch.eye(3), requires_grad=False)
        self.reflect[2, 2] = -1
        self.temperature = nn.Parameter(torch.ones(1) * 0.5, requires_grad=True)
        self.my_iter = torch.ones(1)

    def forward(self, *inputs):
        """Return ``(R, t)`` with ``R`` of shape (B, 3, 3) and ``t`` of (B, 3).

        Positional arguments: source embedding, target embedding, source
        points, target points, temperature (one value per batch element).

        Raises:
            NotImplementedError: if ``cat_sampler`` is neither ``'softmax'``
                nor ``'gumbel_softmax'``.
        """
        src_embedding, tgt_embedding, src, tgt = inputs[0], inputs[1], inputs[2], inputs[3]
        batch_size, _, num_points = src.size()
        temperature = inputs[4].view(batch_size, 1, 1)

        if self.cat_sampler not in ('softmax', 'gumbel_softmax'):
            raise NotImplementedError('not implemented')

        d_k = src_embedding.size(1)
        scores = torch.matmul(
            src_embedding.transpose(2, 1).contiguous(), tgt_embedding
        ) / math.sqrt(d_k)
        if self.cat_sampler == 'softmax':
            scores = torch.softmax(temperature * scores, dim=2)
        else:
            # The view below requires source and target to have the same
            # number of points.
            scores = scores.view(batch_size * num_points, num_points)
            temperature = temperature.repeat(1, num_points, 1).view(-1, 1)
            scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
            scores = scores.view(batch_size, num_points, num_points)

        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())
        src_mean = src.mean(dim=2, keepdim=True)
        src_corr_mean = src_corr.mean(dim=2, keepdim=True)
        src_centered = src - src_mean
        src_corr_centered = src_corr - src_corr_mean

        # The per-sample SVD is run on the CPU, as in the original code.
        H = torch.matmul(src_centered, src_corr_centered.transpose(2, 1).contiguous()).cpu()

        rotations = []
        for h in H:
            u, _, v = torch.svd(h)
            r = torch.matmul(v, u.transpose(1, 0)).contiguous()
            # The determinant is taken as a plain float, so no gradient flows
            # through the reflection correction.
            r_det = torch.det(r).item()
            diag = torch.diag(
                torch.tensor([1.0, 1.0, r_det], dtype=v.dtype, device=v.device)
            )
            rotations.append(
                torch.matmul(torch.matmul(v, diag), u.transpose(1, 0)).contiguous()
            )

        # Move back to where the inputs live (not the module-level default).
        R = torch.stack(rotations, dim=0).to(src.device)
        t = torch.matmul(-R, src_mean) + src_corr_mean
        if self.training:
            self.my_iter += 1
        return R, t.view(batch_size, 3)


class KeyPointNet(nn.Module):
    """Keep the ``num_keypoints`` points whose embeddings have the largest L2 norm."""

    def __init__(self, num_keypoints):
        super().__init__()
        self.num_keypoints = num_keypoints

    def forward(self, *inputs):
        """Return ``(src_keypoints, tgt_keypoints, src_embedding, tgt_embedding)``.

        Positional arguments: source points, target points, source
        embedding, target embedding. Selection is done independently for
        source and target along dim 2; point indices are repeated 3 times
        along dim 1, so the point tensors are expected to have 3 channels.
        """
        src, tgt, src_embedding, tgt_embedding = inputs[0], inputs[1], inputs[2], inputs[3]
        _, num_dims, _ = src_embedding.size()

        src_norm = torch.norm(src_embedding, dim=1, keepdim=True)
        tgt_norm = torch.norm(tgt_embedding, dim=1, keepdim=True)
        src_topk_idx = torch.topk(src_norm, k=self.num_keypoints, dim=2, sorted=False)[1]
        tgt_topk_idx = torch.topk(tgt_norm, k=self.num_keypoints, dim=2, sorted=False)[1]

        src_keypoints = torch.gather(src, dim=2, index=src_topk_idx.repeat(1, 3, 1))
        tgt_keypoints = torch.gather(tgt, dim=2, index=tgt_topk_idx.repeat(1, 3, 1))
        src_embedding = torch.gather(
            src_embedding, dim=2, index=src_topk_idx.repeat(1, num_dims, 1)
        )
        tgt_embedding = torch.gather(
            tgt_embedding, dim=2, index=tgt_topk_idx.repeat(1, num_dims, 1)
        )
        return src_keypoints, tgt_keypoints, src_embedding, tgt_embedding


class PRNet(nn.Module):
    """Iterative registration network.

    Args:
        emb_nn: ``'pointnet'`` or ``'dgcnn'``.
        attention: ``'identity'`` or ``'transformer'``.
        head: ``'mlp'`` or ``'svd'``.
        emb_dims: embedding width.
        num_keypoints: number of key points kept per cloud. When it equals
            ``num_subsampled_points`` no key point selection is performed.
        num_subsampled_points: only compared against ``num_keypoints``.
        num_iters: number of alignment passes in ``forward``.
        cycle_consistency_loss: weight of the cycle consistency term.
        feature_alignment_loss: weight of the feature alignment term.
        discount_factor: iteration ``i`` is weighted by ``discount_factor ** i``.
        input_shape: ``'bnc'`` makes ``forward`` permute inputs from
            (B, N, C) to (B, C, N); any other value leaves them untouched.

    Raises:
        NotImplementedError: for an unknown ``emb_nn``, ``attention`` or ``head``.
    """

    def __init__(self, emb_nn='dgcnn', attention='transformer', head='svd', emb_dims=512,
                 num_keypoints=512, num_subsampled_points=768, num_iters=3,
                 cycle_consistency_loss=0.1, feature_alignment_loss=0.1,
                 discount_factor=0.9, input_shape='bnc'):
        super().__init__()
        self.emb_dims = emb_dims
        self.num_keypoints = num_keypoints
        self.num_subsampled_points = num_subsampled_points
        self.num_iters = num_iters
        self.discount_factor = discount_factor
        self.feature_alignment_loss = feature_alignment_loss
        self.cycle_consistency_loss = cycle_consistency_loss
        self.input_shape = input_shape

        if emb_nn == 'pointnet':
            self.emb_nn = PointNet(emb_dims=self.emb_dims)
        elif emb_nn == 'dgcnn':
            self.emb_nn = DGCNN(emb_dims=self.emb_dims)
        else:
            raise NotImplementedError('Not implemented')

        if attention == 'identity':
            self.attention = Identity()
        elif attention == 'transformer':
            self.attention = Transformer(emb_dims=self.emb_dims, n_blocks=1, dropout=0.0,
                                         ff_dims=1024, n_heads=4)
        else:
            raise NotImplementedError("Not implemented")

        self.temp_net = TemperatureNet(emb_dims=self.emb_dims, temp_factor=100)

        if head == 'mlp':
            self.head = MLPHead(emb_dims=self.emb_dims)
        elif head == 'svd':
            self.head = SVDHead(emb_dims=self.emb_dims, cat_sampler='softmax')
        else:
            raise NotImplementedError('Not implemented')

        if self.num_keypoints != self.num_subsampled_points:
            self.keypointnet = KeyPointNet(num_keypoints=self.num_keypoints)
        else:
            self.keypointnet = Identity()

    def predict_embedding(self, *inputs):
        """Embed both clouds, add the attention output, select key points.

        Returns ``(src, tgt, src_embedding, tgt_embedding, temperature,
        feature_disparity)``.
        """
        src, tgt = inputs[0], inputs[1]
        src_embedding = self.emb_nn(src)
        tgt_embedding = self.emb_nn(tgt)

        src_embedding_p, tgt_embedding_p = self.attention(src_embedding, tgt_embedding)
        src_embedding = src_embedding + src_embedding_p
        tgt_embedding = tgt_embedding + tgt_embedding_p

        src, tgt, src_embedding, tgt_embedding = self.keypointnet(
            src, tgt, src_embedding, tgt_embedding
        )
        temperature, feature_disparity = self.temp_net(src_embedding, tgt_embedding)
        return src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity

    # Single Pass Alignment Module for PRNet
    def spam(self, *inputs):
        """Run one alignment pass in both directions.

        Returns ``(rotation_ab, translation_ab, rotation_ba, translation_ba,
        feature_disparity)``.
        """
        src, tgt, src_embedding, tgt_embedding, temperature, feature_disparity = \
            self.predict_embedding(*inputs)
        rotation_ab, translation_ab = self.head(src_embedding, tgt_embedding, src, tgt, temperature)
        rotation_ba, translation_ba = self.head(tgt_embedding, src_embedding, tgt, src, temperature)
        return rotation_ab, translation_ab, rotation_ba, translation_ba, feature_disparity

    def predict_keypoint_correspondence(self, *inputs):
        """Return ``(src, tgt, scores)`` with hard Gumbel-softmax correspondences."""
        src, tgt, src_embedding, tgt_embedding, temperature, _ = self.predict_embedding(*inputs)
        batch_size, _, num_points = src.size()
        d_k = src_embedding.size(1)
        scores = torch.matmul(
            src_embedding.transpose(2, 1).contiguous(), tgt_embedding
        ) / math.sqrt(d_k)
        scores = scores.view(batch_size * num_points, num_points)
        # The temperature has one value per batch element. View it as
        # (B, 1, 1) BEFORE repeating so that, after flattening, row k of
        # `scores` is paired with the temperature of batch k // num_points.
        temperature = temperature.view(batch_size, 1, 1).repeat(1, num_points, 1).view(-1, 1)
        scores = F.gumbel_softmax(scores, tau=temperature, hard=True)
        scores = scores.view(batch_size, num_points, num_points)
        return src, tgt, scores

    def forward(self, *inputs):
        """Iteratively align ``src`` to ``tgt``.

        Accepted call forms:
            ``forward(src, tgt)``: inference only.
            ``forward(src, tgt, T)``: ``T[:, :3, :3]`` is used as the ground
                truth rotation and ``T[:, :3, 3]`` as the translation.
            ``forward(src, tgt, rotation_ab, translation_ab)``.

        Returns:
            dict with keys ``'est_R'``, ``'est_t'``, ``'est_T'``,
            ``'transformed_source'`` and, when ground truth was supplied,
            ``'loss'``.

        Raises:
            ValueError: if called with fewer than 2 or more than 4 arguments.
        """
        if len(inputs) == 2:
            src, tgt = inputs
            calculate_loss = False
        elif len(inputs) == 3:
            src, tgt, transformation = inputs
            rotation_ab = transformation[:, :3, :3]
            translation_ab = transformation[:, :3, 3].view(-1, 3)
            calculate_loss = True
        elif len(inputs) == 4:
            src, tgt, rotation_ab, translation_ab = inputs
            calculate_loss = True
        else:
            raise ValueError(
                'PRNet.forward expects 2, 3 or 4 positional arguments, got %d' % len(inputs)
            )

        if self.input_shape == 'bnc':
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        batch_size = src.size(0)
        identity = torch.eye(3, device=src.device).unsqueeze(0).repeat(batch_size, 1, 1)

        rotation_ab_pred = torch.eye(3, device=src.device, dtype=torch.float32) \
            .view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ab_pred = torch.zeros(3, device=src.device, dtype=torch.float32) \
            .view(1, 3).repeat(batch_size, 1)
        rotation_ba_pred = torch.eye(3, device=src.device, dtype=torch.float32) \
            .view(1, 3, 3).repeat(batch_size, 1, 1)
        translation_ba_pred = torch.zeros(3, device=src.device, dtype=torch.float32) \
            .view(1, 3).repeat(batch_size, 1)

        total_loss = 0

        for i in range(self.num_iters):
            (rotation_ab_pred_i, translation_ab_pred_i,
             rotation_ba_pred_i, translation_ba_pred_i,
             feature_disparity) = self.spam(src, tgt)

            # Compose this pass's estimate onto the accumulated transform.
            rotation_ab_pred = torch.matmul(rotation_ab_pred_i, rotation_ab_pred)
            translation_ab_pred = torch.matmul(
                rotation_ab_pred_i, translation_ab_pred.unsqueeze(2)
            ).squeeze(2) + translation_ab_pred_i

            rotation_ba_pred = torch.matmul(rotation_ba_pred_i, rotation_ba_pred)
            translation_ba_pred = torch.matmul(
                rotation_ba_pred_i, translation_ba_pred.unsqueeze(2)
            ).squeeze(2) + translation_ba_pred_i

            if calculate_loss:
                discount = self.discount_factor ** i
                loss = (
                    F.mse_loss(
                        torch.matmul(rotation_ab_pred.transpose(2, 1), rotation_ab), identity
                    )
                    + F.mse_loss(translation_ab_pred, translation_ab)
                ) * discount
                feature_alignment_loss = \
                    feature_disparity.mean() * self.feature_alignment_loss * discount
                cycle_consistency_loss = cycle_consistency(
                    rotation_ab_pred_i, translation_ab_pred_i,
                    rotation_ba_pred_i, translation_ba_pred_i,
                ) * self.cycle_consistency_loss * discount
                total_loss = total_loss + loss + feature_alignment_loss + cycle_consistency_loss

            if self.input_shape == 'bnc':
                src = transform.transform_point_cloud(
                    src.permute(0, 2, 1), rotation_ab_pred_i, translation_ab_pred_i
                ).permute(0, 2, 1)
            else:
                # NOTE(review): src is passed without a permute here; confirm
                # this matches the layout transform_point_cloud expects.
                src = transform.transform_point_cloud(
                    src, rotation_ab_pred_i, translation_ab_pred_i
                )

        if self.input_shape == 'bnc':
            # Undo the permute applied on entry.
            src, tgt = src.permute(0, 2, 1), tgt.permute(0, 2, 1)

        result = {
            'est_R': rotation_ab_pred,
            'est_t': translation_ab_pred,
            'est_T': transform.convert2transformation(rotation_ab_pred, translation_ab_pred),
            'transformed_source': src,
        }
        if calculate_loss:
            result['loss'] = total_loss
        return result


if __name__ == '__main__':
    # Smoke test with random inputs. torch.rand builds tensors of the given
    # shape (torch.tensor(10, 1024, 3) is not a valid call).
    model = PRNet().to(device)
    src = torch.rand(10, 1024, 3)
    tgt = torch.rand(10, 768, 3)
    rotation_ab, translation_ab = torch.rand(10, 3, 3), torch.rand(10, 3)
    src, tgt = src.to(device), tgt.to(device)
    rotation_ab, translation_ab = rotation_ab.to(device), translation_ab.to(device)
    # forward() returns a dict, so read the entries by key.
    result = model(src, tgt, rotation_ab, translation_ab)
    rotation_ab_pred, translation_ab_pred, loss = \
        result['est_R'], result['est_t'], result['loss']