Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import numpy as np | |
| import torch as th | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class Quaternion:
    """Torch-tensor based quaternion helpers.

    Quaternions are laid out as ``(x, y, z, w)``: vector part first, scalar
    part last (see ``identity`` and ``batchQuatFromMatrix``). Every method is
    a ``staticmethod`` and is meant to be called on the class itself, e.g.
    ``Quaternion.mul(q, r)``.
    """

    @staticmethod
    def identity(dtype=th.double):
        """Return the identity quaternion ``(0, 0, 0, 1)`` with the given dtype."""
        return th.tensor([0.0, 0.0, 0.0, 1.0], dtype=dtype)

    @staticmethod
    def mul(q, r):
        """Hamilton product ``q * r`` of two length-4 quaternions.

        Each output component is a signed dot product of ``q`` with a
        permutation of ``r``. The sign tensors are created on ``q``'s device
        so the product also works for CUDA tensors.
        """
        opts = {"dtype": q.dtype, "device": q.device}
        return th.stack(
            [
                (q * th.tensor([1.0, 1.0, -1.0, 1.0], **opts)).dot(r[[3, 2, 1, 0]]),
                (q * th.tensor([-1.0, 1.0, 1.0, 1.0], **opts)).dot(r[[2, 3, 0, 1]]),
                (q * th.tensor([1.0, -1.0, 1.0, 1.0], **opts)).dot(r[[1, 0, 3, 2]]),
                (q * th.tensor([-1.0, -1.0, -1.0, 1.0], **opts)).dot(r[[0, 1, 2, 3]]),
            ]
        )

    @staticmethod
    def rot(q, v):
        """Rotate the 3-vector ``v`` by the unit quaternion ``q``.

        Computes ``v + 2 * (w * (a x v) + a x (a x v))`` with ``a = q[:3]``
        and ``w = q[3]``.
        """
        axis = q[:3]
        # Explicit dim: torch.cross's implicit "first dim of size 3" default
        # is deprecated and picks the wrong axis for e.g. (3, 3) inputs.
        av = th.cross(axis, v, dim=-1)
        aav = th.cross(axis, av, dim=-1)
        return v + 2 * (av * q[3] + aav)

    @staticmethod
    def invert(q):
        """Return the inverse ``conj(q) / |q|^2`` of the quaternion ``q``."""
        conj_sign = th.tensor([-1.0, -1.0, -1.0, 1.0], dtype=q.dtype, device=q.device)
        return q * conj_sign * (1.0 / q.dot(q))

    @staticmethod
    def fromAxisAngle(axis, angle):
        """Quaternion for a rotation by ``angle`` (a tensor) around ``axis``.

        ``axis`` is expected to be a unit 3-vector; it is not normalised here.
        """
        s = th.sin(angle * 0.5)
        c = th.cos(angle * 0.5).view([1])
        return th.cat((axis * s, c), 0)

    @staticmethod
    def fromXYZ(angles):
        """Quaternion for an XYZ-Euler rotation given by the 3-vector ``angles``.

        The result equals ``qz * qy * qx``, i.e. the rotation matrix
        ``Rz @ Ry @ Rx`` produced by ``batchMatrixFromXYZ``. It is faster
        than building three single-axis quaternions and multiplying them.
        The negated x half-angle is an algebraic trick that lets all four
        components share the same sin/cos products.
        """
        half = angles * th.tensor([-0.5, 0.5, 0.5], dtype=angles.dtype, device=angles.device)
        rc = th.cos(half)
        rs = th.sin(half)
        return th.stack(
            [
                -rs[0] * rc[1] * rc[2] - rc[0] * rs[1] * rs[2],
                rc[0] * rs[1] * rc[2] - rs[0] * rc[1] * rs[2],
                rc[0] * rc[1] * rs[2] + rs[0] * rs[1] * rc[2],
                rc[0] * rc[1] * rc[2] - rs[0] * rs[1] * rs[2],
            ]
        )

    @staticmethod
    def _matrixColumns(q):
        """Columns of the rotation matrix of ``q`` (shape ``... x 4``).

        Returns three 3-tuples of tensors shaped like ``q[..., 0]``; tuple
        ``k`` holds ``(R[0, k], R[1, k], R[2, k])``. The formulas assume a
        unit quaternion.
        """
        tx = q[..., 0] * 2.0
        ty = q[..., 1] * 2.0
        tz = q[..., 2] * 2.0
        twx = tx * q[..., 3]
        twy = ty * q[..., 3]
        twz = tz * q[..., 3]
        txx = tx * q[..., 0]
        txy = ty * q[..., 0]
        txz = tz * q[..., 0]
        tyy = ty * q[..., 1]
        tyz = tz * q[..., 1]
        tzz = tz * q[..., 2]
        return (
            (1.0 - (tyy + tzz), txy + twz, txz - twy),
            (txy - twz, 1.0 - (txx + tzz), tyz + twx),
            (txz + twy, tyz - twx, 1.0 - (txx + tyy)),
        )

    @staticmethod
    def toMatrix(q):
        """Convert the quaternion ``q`` (length 4) to a 3x3 rotation matrix.

        The result inherits ``q``'s dtype and device.
        """
        cols = Quaternion._matrixColumns(q)
        return th.stack([th.stack(col, dim=-1) for col in cols], dim=-1)

    @staticmethod
    def toMatrixBatch(q):
        """Convert quaternions ``... x 4`` (e.g. N x K x 4) to ``... x 3 x 3`` matrices."""
        cols = Quaternion._matrixColumns(q)
        return th.stack([th.stack(col, dim=-1) for col in cols], dim=-1)

    @staticmethod
    def toMatrixBatchDim1(q):
        """Convert N x 4 quaternions to N x 3 x 3 rotation matrices."""
        cols = Quaternion._matrixColumns(q)
        return th.stack([th.stack(col, dim=1) for col in cols], dim=2)

    @staticmethod
    def batchMul(q, r):
        """Hamilton product of two batches of quaternions.

        Args:
            q: N x K x 4 quaternions
            r: N x K x 4 quaternions
        Returns:
            N x K x 4 multiplied quaternions
        """
        signs = th.tensor(
            [
                [1.0, 1.0, -1.0, 1.0],
                [-1.0, 1.0, 1.0, 1.0],
                [1.0, -1.0, 1.0, 1.0],
                [-1.0, -1.0, -1.0, 1.0],
            ],
            dtype=q.dtype,
            device=q.device,
        )
        # Same sign/permutation pairs as ``mul``, one per output component.
        perms = ((3, 2, 1, 0), (2, 3, 0, 1), (1, 0, 3, 2), (0, 1, 2, 3))
        return th.stack(
            [
                th.sum(q * signs[k] * r[:, :, list(perm)], dim=-1)
                for k, perm in enumerate(perms)
            ],
            dim=2,
        )

    @staticmethod
    def batchRot(q, v):
        """Rotate 3-vectors by quaternions.

        Args:
            q: N x K x 4 quaternions
            v: N x K x 3 vectors
        Returns:
            N x K x 3 rotated vectors
        """
        axis = q[:, :, :3]
        av = th.cross(axis, v, dim=2)
        aav = th.cross(axis, av, dim=2)
        return v + 2 * (av * q[:, :, 3:4] + aav)

    @staticmethod
    def batchInvert(q):
        """Inverse of each quaternion.

        Args:
            q: N x K x 4 quaternions
        Returns:
            N x K x 4 inverted quaternions
        """
        conj_sign = th.tensor([-1.0, -1.0, -1.0, 1.0], dtype=q.dtype, device=q.device)
        inv_norm2 = th.reciprocal(th.sum(q * q, dim=2).unsqueeze(2))
        return q * conj_sign * inv_norm2

    @staticmethod
    def batchFromXYZ(r):
        """Quaternions for XYZ-Euler rotations (batched ``fromXYZ``).

        Args:
            r: N x K x 3 rotation vectors
        Returns:
            N x K x 4 quaternions
        """
        half = r * th.tensor([[[-0.5, 0.5, 0.5]]], dtype=r.dtype, device=r.device)
        rc = th.cos(half)
        rs = th.sin(half)
        cx, cy, cz = rc[:, :, 0], rc[:, :, 1], rc[:, :, 2]
        sx, sy, sz = rs[:, :, 0], rs[:, :, 1], rs[:, :, 2]
        return th.stack(
            [
                -sx * (cy * cz) - cx * (sy * sz),
                cx * (sy * cz) - sx * (cy * sz),
                cx * (cy * sz) + sx * (sy * cz),
                cx * (cy * cz) - sx * (sy * sz),
            ],
            dim=2,
        )

    @staticmethod
    def batchMatrixFromXYZ(r):
        """Rotation matrices for XYZ-Euler rotations.

        The entries equal ``Rz(r[:, 2]) @ Ry(r[:, 1]) @ Rx(r[:, 0])``.

        Args:
            r: N x 3 rotation vectors
        Returns:
            N x 3 x 3 rotation matrices
        """
        rc = th.cos(r)
        rs = th.sin(r)
        cx, cy, cz = rc[:, 0], rc[:, 1], rc[:, 2]
        sx, sy, sz = rs[:, 0], rs[:, 1], rs[:, 2]
        return th.stack(
            (
                cy * cz,
                -cx * sz + sx * sy * cz,
                sx * sz + cx * sy * cz,
                cy * sz,
                cx * cz + sx * sy * sz,
                -sx * cz + cx * sy * sz,
                -sy,
                sx * cy,
                cx * cy,
            ),
            dim=1,
        ).view(-1, 3, 3)

    @staticmethod
    def batchQuatFromMatrix(m):
        """Convert a batch of rotation matrices to quaternions.

        :param m: B*3*3 rotation matrices
        :return: B*4 quaternions, order xyzw, same dtype/device as ``m``

        Four-branch conversion. Rows with positive trace are solved first;
        every other row uses the branch of its largest diagonal entry. In each
        branch S = 4 * |dominant component|. Every row is written by exactly
        one branch.
        """
        assert len(m.shape) == 3
        b, j, k = m.shape
        assert j == 3
        assert k == 3
        result = th.zeros((b, 4), dtype=m.dtype, device=m.device)
        S = th.zeros((b,), dtype=m.dtype, device=m.device)
        m00, m01, m02 = m[:, 0, 0], m[:, 0, 1], m[:, 0, 2]
        m10, m11, m12 = m[:, 1, 0], m[:, 1, 1], m[:, 1, 2]
        m20, m21, m22 = m[:, 2, 0], m[:, 2, 1], m[:, 2, 2]
        tr = m00 + m11 + m22

        # Branch 1: positive trace, w dominant.
        flag = tr > 0
        S[flag] = 2 * th.sqrt(1 + tr[flag])
        result[flag, 0] = (m21 - m12)[flag] / S[flag]
        result[flag, 1] = (m02 - m20)[flag] / S[flag]
        result[flag, 2] = (m10 - m01)[flag] / S[flag]
        result[flag, 3] = 0.25 * S[flag]

        # Rows not solved yet. Later branches must only consider these, or
        # they would overwrite rows an earlier branch already handled.
        remaining = ~flag

        # Branch 2: m00 is the largest diagonal entry, x dominant.
        flag = remaining & (m00 > m11) & (m00 > m22)
        S[flag] = 2 * th.sqrt(1.0 + m00[flag] - m11[flag] - m22[flag])
        result[flag, 0] = 0.25 * S[flag]
        result[flag, 1] = (m01 + m10)[flag] / S[flag]
        result[flag, 2] = (m02 + m20)[flag] / S[flag]
        result[flag, 3] = (m21 - m12)[flag] / S[flag]
        remaining = remaining & ~flag

        # Branch 3: m11 is the largest diagonal entry, y dominant.
        flag = remaining & (m11 > m22)
        S[flag] = 2 * th.sqrt(1.0 + m11[flag] - m00[flag] - m22[flag])
        result[flag, 0] = (m01 + m10)[flag] / S[flag]
        result[flag, 1] = 0.25 * S[flag]
        result[flag, 2] = (m12 + m21)[flag] / S[flag]
        result[flag, 3] = (m02 - m20)[flag] / S[flag]

        # Branch 4: remaining rows, m22 largest, z dominant.
        flag = remaining & ~flag
        S[flag] = 2 * th.sqrt(1.0 + m22[flag] - m00[flag] - m11[flag])
        result[flag, 0] = (m02 + m20)[flag] / S[flag]
        result[flag, 1] = (m12 + m21)[flag] / S[flag]
        result[flag, 2] = 0.25 * S[flag]
        result[flag, 3] = (m10 - m01)[flag] / S[flag]
        return result
class RodriguesVecBatch(nn.Module):
    """Batched rotation matrices that rotate unit vectors ``v0`` onto ``v1``."""

    def __init__(self):
        super(RodriguesVecBatch, self).__init__()
        # 3x3 identity as a buffer so it follows .to()/.cuda().
        self.register_buffer("eye", th.eye(3))
        # Not used by forward. Kept because registered buffers are part of the
        # state_dict, and removing it would break loading existing checkpoints.
        self.register_buffer("zero", th.zeros(1))

    def forward(self, v0, v1):
        """Compute R with ``R @ v0[i] == v1[i]`` for each batch row.

        :param v0: B x 3 unit vectors (assumed already normalised)
        :param v1: B x 3 unit vectors (assumed already normalised)
        :return: B x 3 x 3 rotation matrices

        The rotation is about ``v0 x v1`` by the angle between the vectors.
        For ``v0 == v1`` it is the identity. Antiparallel inputs are
        degenerate: the axis is undefined and the result is ``-I``.
        """
        nbat = v0.size(0)
        cosn = (v0 * v1).sum(dim=1, keepdim=True).unsqueeze(2)
        # r = v1 x v0. The skew term below enters with the opposite sign to
        # the textbook formula, which together amounts to rotating about v0 x v1.
        r = v1.cross(v0, dim=1)
        sinn = r.pow(2).sum(1, keepdim=True).sqrt().unsqueeze(2)
        rn = r.unsqueeze(2) / (sinn + 1e-10)
        R = cosn * self.eye.unsqueeze(0).expand(nbat, 3, 3)
        R = R + (1.0 - cosn) * rn.bmm(rn.permute(0, 2, 1))
        # Build the skew term as a whole matrix. The previous element-wise
        # in-place updates read the already-modified R[:, 0, 1] when setting
        # R[:, 1, 0], which silently dropped that entry's sin term.
        nx, ny, nz = rn[:, 0, 0], rn[:, 1, 0], rn[:, 2, 0]
        zero = th.zeros_like(nx)
        skew = th.stack(
            (zero, nz, -ny, -nz, zero, nx, ny, -nx, zero), dim=1
        ).view(nbat, 3, 3)
        return R + sinn * skew
class RodriguesBatch(nn.Module):
    """Batched axis-angle vectors to rotation matrices."""

    def __init__(self):
        super(RodriguesBatch, self).__init__()
        # 3x3 identity as a buffer so it follows .to()/.cuda().
        self.register_buffer("eye", th.eye(3))
        # Not used by forward. Kept because registered buffers are part of the
        # state_dict, and removing it would break loading existing checkpoints.
        self.register_buffer("zero", th.zeros(1))

    def forward(self, r):
        """Convert axis-angle vectors to rotation matrices.

        :param r: B x 3 axis-angle vectors (angle = |r|)
        :return: B x 3 x 3 rotation matrices

        The skew term enters with the opposite sign to the textbook Rodrigues
        formula. The result is therefore the rotation by ``-|r|`` about
        ``r / |r|``, i.e. the transpose of the textbook matrix.
        """
        nbat = r.size(0)
        # The 1e-10 keeps the norm (and its gradient) finite at r == 0.
        n = ((r * r).sum(dim=1, keepdim=True) + 1e-10).sqrt()
        rn = th.div(r, n).unsqueeze(2)
        cosn = th.cos(n).unsqueeze(2)
        sinn = th.sin(n).unsqueeze(2)
        R = cosn * self.eye.unsqueeze(0).expand(nbat, 3, 3)
        R = R + (1.0 - cosn) * rn.bmm(rn.permute(0, 2, 1))
        # Build the skew term as a whole matrix. The previous element-wise
        # in-place updates read the already-modified R[:, 0, 1] when setting
        # R[:, 1, 0], which produced a non-orthogonal matrix.
        nx, ny, nz = rn[:, 0, 0], rn[:, 1, 0], rn[:, 2, 0]
        zero = th.zeros_like(nx)
        skew = th.stack(
            (zero, nz, -ny, -nz, zero, nx, ny, -nx, zero), dim=1
        ).view(nbat, 3, 3)
        return R + sinn * skew
| class NormalComputer(nn.Module): | |
| def __init__(self, height, width, maskin=None): | |
| super(NormalComputer, self).__init__() | |
| # self.register_buffer('eye', (th.eye(3))) | |
| # self.register_buffer('zero', (th.zeros(1,))) | |
| patchttnum = 5 # neighbor + self | |
| patchmatch_uvpos = np.zeros((height, width, patchttnum, 2), dtype=np.int32) | |
| vec_standuv = ( | |
| np.indices((height, width)) | |
| .swapaxes(0, 2) | |
| .swapaxes(0, 1) | |
| .astype(np.int32) | |
| .reshape(height, width, 1, 2) | |
| ) | |
| patchmatch_uvpos = patchmatch_uvpos + vec_standuv | |
| localpatchcoord = np.zeros((patchttnum, 2), dtype=np.int32) | |
| localpatchcoord = np.array([[-1, 0], [0, 1], [1, 0], [0, -1], [0, 0]]).astype(np.int32) | |
| patchmatch_uvpos = patchmatch_uvpos + localpatchcoord.reshape(1, 1, patchttnum, 2) | |
| patchmatch_uvpos[..., 0] = np.clip(patchmatch_uvpos[..., 0], 0, height - 1) | |
| patchmatch_uvpos[..., 1] = np.clip(patchmatch_uvpos[..., 1], 0, width - 1) | |
| # geoemtry mask , apply simiilar to texture mask | |
| # mesh_mask_int = mesh_mask.reshape(height,width).astype(np.int32) | |
| if maskin is None: | |
| maskin = np.ones((height, width), dtype=np.int32) | |
| mesh_mask_int = maskin.reshape(height, width).astype( | |
| np.int32 | |
| ) # using all pixel valid mask; can use a tailored mask | |
| patchmatch_mask = mesh_mask_int[patchmatch_uvpos[..., 0], patchmatch_uvpos[..., 1]].reshape( | |
| height, width, patchttnum, 1 | |
| ) | |
| patch_indicemap = patchmatch_uvpos * patchmatch_mask + (1 - patchmatch_mask) * vec_standuv | |
| tensor_patch_geoindicemap = th.from_numpy(patch_indicemap).long() | |
| tensor_patch_geoindicemap1d = ( | |
| tensor_patch_geoindicemap[..., 0] * width + tensor_patch_geoindicemap[..., 1] | |
| ) | |
| self.register_buffer("tensor_patch_geoindicemap1d", tensor_patch_geoindicemap1d) | |
| # tensor_patchmatch_uvpos = th.from_numpy(patchmatch_uvpos).long() | |
| # tensor_vec_standuv = th.from_numpy(vec_standuv).long() | |
| def forward(self, t_georecon): # in: N 3 H W | |
| # pdb.set_trace() | |
| # Intergration switch it to index_select | |
| # geometry_in = index_selection_nd( | |
| # t_georecon.view(t_georecon.size(0), t_georecon.size(1), -1), | |
| # self.tensor_patch_geoindicemap1d, | |
| # 2, | |
| # ).permute(0, 2, 3, 4, 1) | |
| geometry_in = th.index_select( | |
| t_georecon.view(t_georecon.size(0), t_georecon.size(1), -1), | |
| self.tensor_patch_geoindicemap1d, | |
| 2, | |
| ).permute(0, 2, 3, 4, 1) | |
| normal = (geometry_in[..., 0, :] - geometry_in[..., 4, :]).cross( | |
| geometry_in[..., 1, :] - geometry_in[..., 4, :], dim=3 | |
| ) | |
| normal = normal + (geometry_in[..., 1, :] - geometry_in[..., 4, :]).cross( | |
| geometry_in[..., 2, :] - geometry_in[..., 4, :], dim=3 | |
| ) | |
| normal = normal + (geometry_in[..., 2, :] - geometry_in[..., 4, :]).cross( | |
| geometry_in[..., 3, :] - geometry_in[..., 4, :], dim=3 | |
| ) | |
| normal = normal + (geometry_in[..., 3, :] - geometry_in[..., 4, :]).cross( | |
| geometry_in[..., 0, :] - geometry_in[..., 4, :], dim=3 | |
| ) | |
| normal = normal / th.clamp(normal.pow(2).sum(3, keepdim=True).sqrt(), min=1e-6) | |
| return normal.permute(0, 3, 1, 2) | |
def pointcloud_rigid_registration(src_pointcloud, dst_pointcloud, reduce_loss: bool = True):
    """
    Calculate RT and residual L2 loss for two pointclouds (Kabsch algorithm).
    :param src_pointcloud: x (b, v, 3) or (v, 3)
    :param dst_pointcloud: y (b, v, 3) or (v, 3)
    :param reduce_loss: average the per-sample losses into a scalar
    :return: loss, R (b, 3, 3), t (b, 1, 3) s.t. ||Rx+t-y||_2^2 minimal.
        loss is the mean squared residual per point, of shape (b,) when
        reduce_loss is False, otherwise its batch mean.
    """
    if len(src_pointcloud.shape) == 2:
        src_pointcloud = src_pointcloud.unsqueeze(0)
    if len(dst_pointcloud.shape) == 2:
        dst_pointcloud = dst_pointcloud.unsqueeze(0)
    assert src_pointcloud.shape == dst_pointcloud.shape
    assert src_pointcloud.shape[2] == 3
    muX = src_pointcloud.mean(dim=1)  # (b, 3)
    muY = dst_pointcloud.mean(dim=1)
    X = src_pointcloud - muX[:, None, :]
    Y = dst_pointcloud - muY[:, None, :]
    # Cross-covariance X^T Y per sample: (b, 3, 3).
    XYT = th.einsum("nji,njk->nik", X, Y)
    # Batched SVD/det replace a per-sample Python loop.
    U, s, V = th.svd(XYT)
    UT = U.transpose(1, 2)
    diag_m = th.ones_like(s)
    # Flip the last axis when needed so R is a proper rotation (det = +1).
    diag_m[:, -1] = th.det(V.bmm(UT))
    R = V.bmm(th.diag_embed(diag_m)).bmm(UT)
    t = (muY - th.einsum("bij,bj->bi", R, muX)).unsqueeze(1)
    loss = (th.einsum("bij,bnj->bni", R, X) - Y).pow(2).sum(2).mean(1)
    loss = loss.mean(0) if reduce_loss else loss
    return loss, R, t
def pointcloud_rigid_registration_balanced(src_pointcloud, dst_pointcloud, weight):
    """
    Calculate RT and residual L2 loss for two pointclouds with per-vertex weights.
    :param src_pointcloud: x (b, v, 3) or (v, 3)
    :param dst_pointcloud: y (b, v, 3) or (v, 3)
    :param weight: (v, ), duplication of vertices
    :return: loss, R (b, 3, 3), t (b, 1, 3) s.t. sum_j ||w_j(R x_j + t - y_j)||_2^2
        is minimal. loss is 3 * mean(w * residual)^2 over all entries.
    """
    if len(src_pointcloud.shape) == 2:
        src_pointcloud = src_pointcloud.unsqueeze(0)
    if len(dst_pointcloud.shape) == 2:
        dst_pointcloud = dst_pointcloud.unsqueeze(0)
    assert src_pointcloud.shape == dst_pointcloud.shape
    assert src_pointcloud.shape[2] == 3
    assert src_pointcloud.shape[1] == weight.shape[0]
    assert len(weight.shape) == 1
    w = weight[None, :, None]  # (1, v, 1)
    # The objective weights each squared residual by w^2.
    w2 = w.pow(2)
    sw2 = w2.sum(dim=1, keepdim=True)  # (1, 1, 1)
    X = src_pointcloud
    Y = dst_pointcloud
    # Weighted centroids; the optimal R depends only on the centred clouds.
    muX = (w2 * X).sum(dim=1, keepdim=True) / sw2
    muY = (w2 * Y).sum(dim=1, keepdim=True) / sw2
    wXYT = th.einsum("nji,njk->nik", w2 * (X - muX), Y - muY)
    U, s, V = th.svd(wXYT)
    UT = U.transpose(1, 2)
    diag = th.ones_like(s)
    # Flip the last axis when needed so R is a proper rotation (det = +1).
    diag[:, -1] = th.det(V.bmm(UT))
    R = V.bmm(th.diag_embed(diag)).bmm(UT)
    RX = th.einsum("bij,bnj->bni", R, X)
    # Optimal translation for the w^2-weighted objective: muY - R muX.
    t = (w2 * (Y - RX)).sum(dim=1, keepdim=True) / sw2
    loss = w * (RX + t - Y)
    loss = F.mse_loss(loss, th.zeros_like(loss)) * 3
    return loss, R, t
def batch_dot(x, y):
    """Return the row-wise dot products of two equally shaped ``(n, d)`` tensors.

    The result has shape ``(n,)``.
    """
    assert x.shape == y.shape
    assert x.ndim == 2
    return th.einsum("ni,ni->n", x, y)
def batch_svd(x):
    """Full SVD of each matrix in a batch.

    :param x: (bn, m, n) tensor
    :return: U (bn, m, m), s (bn, min(m, n)), V (bn, n, n) with
        ``x[i] == U[i][:, :k] @ diag(s[i]) @ V[i][:, :k].T`` for k = min(m, n).
        All outputs keep x's dtype and device. The previous float32-only
        buffers lost precision for double input and broke for non-square x.
    """
    assert len(x.shape) == 3
    # some=False returns the full square U and V, matching the documented shapes.
    U, s, V = th.svd(x, some=False)
    return U, s, V
def batch_diag(x):
    """Batched ``diag``: vector -> diagonal matrix, or matrix -> its diagonal.

    :param x: (bn, n) to build (bn, n, n) diagonal matrices, or
        (bn, n, n) square matrices to extract their (bn, n) diagonals
    :return: tensor with x's dtype and device
    :raises ValueError: if x is neither 2-D nor 3-D
    """
    if len(x.shape) == 2:
        # diag_embed keeps x's dtype; the old float32 buffer downcast doubles.
        return th.diag_embed(x)
    if len(x.shape) == 3:
        assert x.shape[1] == x.shape[2]
        n = x.shape[1]
        return x[:, range(n), range(n)]
    raise ValueError("dim of batch_diag should be 2 or 3")
def batch_det(x):
    """Determinant of each square matrix in a batch.

    :param x: (bn, n, n) tensor
    :return: (bn,) determinants with x's dtype and device
    """
    assert len(x.shape) == 3
    assert x.shape[1] == x.shape[2]
    # th.det is batched natively; no per-sample loop or float32 buffer needed.
    return th.det(x)