File size: 3,684 Bytes
458efe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ccb54b
458efe2
3ccb54b
458efe2
3ccb54b
458efe2
 
3ccb54b
458efe2
 
 
 
 
 
 
3ccb54b
 
458efe2
3ccb54b
458efe2
 
3ccb54b
458efe2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import torch
import torch.nn.functional as F


def _as_point_matrix(t):
    """Return ``t`` as a 2-D ``[num_points, d]`` matrix.

    Singleton dimensions are squeezed away first. If that does not leave a
    2-D tensor (a lone point ``[1, d]`` or a feature size of one ``[m, 1]``
    also gets squeezed), fall back to flattening every leading dimension
    while keeping the last one as the feature axis.
    """
    squeezed = t.squeeze()
    if squeezed.dim() == 2:
        return squeezed
    return t.reshape(-1, t.shape[-1])


def euclidean_dist(x, y):
    """Compute the pairwise Euclidean distances between two point sets.

    Args:
      x: tensor of shape [m, d]; extra singleton dimensions are tolerated
        (e.g. [1, m, d]).
      y: tensor of shape [n, d]; extra singleton dimensions are tolerated
        (e.g. [1, n, d]).
    Returns:
      dist: tensor of shape [m, n] where dist[i, j] = ||x[i] - y[j]||_2.
        Squared distances are clamped to 1e-12 before the square root, so
        the smallest value returned is 1e-6 rather than exactly 0.
    """
    x2d = _as_point_matrix(x)
    y2d = _as_point_matrix(y)
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 <a, b>
    xx = torch.pow(x2d, 2).sum(1, keepdim=True)      # [m, 1]
    yy = torch.pow(y2d, 2).sum(1, keepdim=True).t()  # [1, n]
    sq_dist = xx + yy - 2 * torch.inner(x2d, y2d)
    # Round-off can push the expansion slightly below zero; clamping keeps
    # the sqrt (and its gradient) finite.
    return sq_dist.clamp(min=1e-12).sqrt()

def knnsearch(x, y, alpha=1./0.07, prod=False):
    """Soft nearest-neighbour assignment between the rows of ``x`` and ``y``.

    Args:
      x: query features; singleton dimensions are squeezed away before use.
      y: reference features; singleton dimensions are squeezed away before use.
      alpha: scale applied to the scores before the softmax (an inverse
        temperature).
      prod: when True the score is the inner product of the rows; otherwise
        it is the negated Euclidean distance from ``euclidean_dist``.
    Returns:
      The softmax of the scaled scores taken along dim 1, with singleton
      dimensions squeezed out.
    """
    if not prod:
        # Closer pairs must score higher, hence the negated distance.
        dists = euclidean_dist(x, y[None, :])
        logits = -alpha * dists
    else:
        logits = alpha * torch.inner(x.squeeze(), y.squeeze())
    return F.softmax(logits, dim=1).squeeze()

def extract_p2p_torch(reps_shape, reps_template):
    """Nearest-neighbour maps between two point sets given as NumPy arrays.

    Args:
      reps_shape: NumPy array of shape [M, d] with one row per shape point.
      reps_template: NumPy array of shape [N, d'] with d' >= d; only the
        first d columns are compared against ``reps_shape``.
    Returns:
      A pair of NumPy integer arrays ``(indKNN, indKNN_2)``:
        indKNN:   for each template row, the index of the closest shape row
                  (length N).
        indKNN_2: for each shape row, the index of the closest template row
                  (length M).
      Singleton dimensions are squeezed from both results.
    """
    n_ev = reps_shape.shape[-1]
    # Use the GPU when one is present, but stay usable on CPU-only hosts.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    with torch.no_grad():
        shape_pts = torch.from_numpy(reps_shape).float().to(device)
        template_pts = torch.from_numpy(reps_template).float().to(device)
        G_i = shape_pts[:, None, :].contiguous()            # (M, 1, d)
        X_j = template_pts[None, :, :n_ev].contiguous()     # (1, N, d)
        # Dense (M, N) matrix of squared distances; the (M, N, d)
        # intermediate means memory grows with M * N * d.
        D_ij = ((G_i - X_j) ** 2).sum(-1)
        indKNN = torch.argmin(D_ij, dim=0).squeeze()    # template -> shape
        indKNN_2 = torch.argmin(D_ij, dim=1).squeeze()  # shape -> template
    return indKNN.detach().cpu().numpy(), indKNN_2.detach().cpu().numpy()

def extract_p2p_torch_fmap(fmap_shape_template, evecs_shape, evecs_template):
    """Turn a functional map into nearest-neighbour point maps.

    The first k columns of ``evecs_shape`` (k is the last dimension of the
    map) are multiplied by the transposed map and compared, by squared
    Euclidean distance, with the first k columns of ``evecs_template``.

    Args:
      fmap_shape_template: functional map tensor; singleton dimensions are
        squeezed away before it is transposed.
      evecs_shape: tensor with one row per shape vertex.
      evecs_template: tensor with one row per template vertex.
    Returns:
      A pair of NumPy integer arrays: first, for each template row, the
      index of the closest transformed shape row; second, for each shape
      row, the index of the closest template row. Singleton dimensions are
      squeezed from both.
    """
    k = fmap_shape_template.shape[-1]
    with torch.no_grad():
        fmap_t = fmap_shape_template.squeeze().T
        projected = evecs_shape[:, :k] @ fmap_t
        target = evecs_template[:, :k]
        # Broadcast to all pairs: (rows of projected, rows of target, k).
        pair_diff = projected.unsqueeze(1) - target.unsqueeze(0)
        sq_dists = pair_diff.pow(2).sum(dim=-1)
        template_to_shape = sq_dists.argmin(dim=0).squeeze()
        shape_to_template = sq_dists.argmin(dim=1).squeeze()
    return (
        template_to_shape.detach().cpu().numpy(),
        shape_to_template.detach().cpu().numpy(),
    )

def wlstsq(A, B, w):
    """Solve a (row-weighted) least-squares problem ``A @ X ~= B``.

    Args:
      A: coefficient tensor of shape [..., N, K].
      B: right-hand side tensor of shape [..., N, P].
      w: optional per-row weights of shape [..., N], or None for the
        unweighted problem. Row i of both A and B is multiplied by w[i], so
        the minimised objective weights squared residuals by w[i] ** 2.
    Returns:
      The ``solution`` field of ``torch.linalg.lstsq``.
    Raises:
      AssertionError: if ``w`` does not have one fewer dimension than ``A``
        or its last dimension differs from the number of rows of ``A``.
    """
    if w is None:
        return torch.linalg.lstsq(A, B).solution
    assert w.dim() + 1 == A.dim() and w.shape[-1] == A.shape[-2]
    # Scaling rows by broadcasting equals multiplying by diag(w), without
    # materialising a dense N x N matrix.
    row_scale = w.unsqueeze(-1)
    weighted_A = row_scale * A
    # A 1-D right-hand side has no column axis to broadcast against.
    weighted_B = w * B if B.dim() == 1 else row_scale * B
    return torch.linalg.lstsq(weighted_A, weighted_B).solution

def torch_zoomout(evecs0, evecs1, evecs_1_trans, fmap01, target_size, step=1):
    """Iteratively upsample a square functional map to ``target_size``.

    Each iteration converts the current map to a point-to-point map with
    ``extract_p2p_torch_fmap`` and rebuilds a map of size ``i`` as
    ``evecs_1_trans[:i, :] @ evecs0[indKNN, :i]``.

    Args:
      evecs0: basis of the first shape, one row per vertex.
      evecs1: basis of the second shape, one row per vertex.
      evecs_1_trans: matrix applied on the left to rebuild the map; only
        its first ``i`` rows are used at size ``i``.
        NOTE(review): presumably the (pseudo-)inverse or mass-weighted
        transpose of ``evecs1`` - confirm against the caller.
      fmap01: initial square functional map.
      target_size: size of the returned map.
      step: increment of the map size between iterations.
    Returns:
      The refined map. If the initial size already exceeds ``target_size``
      the input is returned unchanged.
    Raises:
      AssertionError: if ``fmap01`` is not square in its last two dimensions.
    """
    assert fmap01.shape[-2] == fmap01.shape[-1], f"square fmap needed, got {fmap01.shape[-2]} and {fmap01.shape[-1]}"
    # Read the size from the same axis the assert checks, so a map with a
    # leading singleton batch dimension starts at the right size.
    start_size = fmap01.shape[-1]
    for size in range(start_size, target_size + 1, step):
        indKNN, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
        fmap01 = evecs_1_trans[:size, :] @ evecs0[indKNN, :size]
    if fmap01.shape[-1] < target_size:
        # ``step`` skipped over target_size: recompute the point map from
        # the latest functional map before the final upsampling instead of
        # reusing the stale one from the previous iteration.
        indKNN, _ = extract_p2p_torch_fmap(fmap01, evecs0, evecs1)
        fmap01 = evecs_1_trans[:target_size, :] @ evecs0[indKNN, :target_size]
    return fmap01