File size: 1,936 Bytes
97aa5af | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 | import torch
import torch.nn as nn
import math
class SVDHead(nn.Module):
    """Estimate a rigid transform from soft point correspondences via SVD.

    Each source point is matched to a softmax-weighted average of the target
    points (weights come from scaled dot products of the point embeddings).
    The rotation and translation that best map the source points onto those
    soft correspondences are then obtained from the SVD of the 3x3
    cross-covariance matrix of the two centred point sets.

    Args:
        emb_dims: Embedding dimensionality. Stored on the module; ``forward``
            does not read it (the scaling factor is taken from the embedding
            tensor itself).
        input_shape: Layout of the point tensors given to ``forward``.
            ``"bnc"`` means (batch, num_points, 3) and the tensors are
            permuted internally; any other value means the points are already
            laid out as (batch, 3, num_points).
    """

    def __init__(self, emb_dims, input_shape="bnc"):
        super().__init__()
        self.emb_dims = emb_dims
        # Build diag(1, 1, -1) first and wrap it afterwards, so the Parameter
        # is never mutated in place. It stays a frozen Parameter (not a
        # buffer) so ``parameters()`` and ``state_dict()`` are unchanged.
        reflect = torch.eye(3)
        reflect[2, 2] = -1
        self.reflect = nn.Parameter(reflect, requires_grad=False)
        self.input_shape = input_shape

    def forward(self, *input):
        """Compute the rotation and translation aligning ``src`` to ``tgt``.

        Args:
            *input: Positional tensors ``(src_embedding, tgt_embedding, src,
                tgt)``; any further arguments are ignored. Embeddings are
                (batch, dims, num_points). Points are laid out according to
                ``input_shape`` (see the class docstring).

        Returns:
            Tuple ``(R, t)``: ``R`` has shape (batch, 3, 3) and ``t`` has
            shape (batch, 3), such that ``R @ src + t`` approximates the soft
            correspondences of ``src`` in ``tgt``.
        """
        src_embedding, tgt_embedding, src, tgt = input[:4]
        batch_size = src.size(0)
        if self.input_shape == "bnc":
            # Work internally with points laid out as (batch, 3, num_points).
            src = src.permute(0, 2, 1)
            tgt = tgt.permute(0, 2, 1)

        # Scaled dot-product attention from every source point to all targets.
        d_k = src_embedding.size(1)
        scores = torch.matmul(
            src_embedding.transpose(2, 1).contiguous(), tgt_embedding
        ) / math.sqrt(d_k)
        scores = torch.softmax(scores, dim=2)

        # Soft correspondence of each source point: weighted mean of targets.
        src_corr = torch.matmul(tgt, scores.transpose(2, 1).contiguous())

        # Centroids are needed both for centring and for the translation.
        src_mean = src.mean(dim=2, keepdim=True)
        src_corr_mean = src_corr.mean(dim=2, keepdim=True)

        # Per-sample 3x3 cross-covariance of the centred point sets.
        H = torch.matmul(
            src - src_mean,
            (src_corr - src_corr_mean).transpose(2, 1).contiguous(),
        )

        rotations = []
        for h in H:
            u, _, v = torch.svd(h)
            u_t = u.transpose(1, 0).contiguous()
            r = torch.matmul(v, u_t)
            if torch.det(r) < 0:
                # The orthogonal solution is a reflection; flip the last
                # singular direction to get a proper rotation (det = +1).
                # u and v are reused: the decomposition has not changed.
                r = torch.matmul(torch.matmul(v, self.reflect), u_t)
            rotations.append(r)
        R = torch.stack(rotations, dim=0)

        t = torch.matmul(-R, src_mean) + src_corr_mean
        return R, t.view(batch_size, 3)