Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,327 Bytes
458efe2 d408533 458efe2 d408533 458efe2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import os
import os.path as osp
import sys
import numpy as np
import scipy.linalg
from tqdm import tqdm
from scipy.spatial import cKDTree
ROOT_DIR = osp.join(osp.abspath(osp.dirname(__file__)), '../')
if ROOT_DIR not in sys.path:
sys.path.append(ROOT_DIR)
try:
import pynndescent
index = pynndescent.NNDescent(np.random.random((100, 3)), n_jobs=2)
del index
ANN = True
except ImportError:
ANN = False
class KNNSearch(object):
DTYPE = np.float32
NJOBS = 4
def __init__(self, data):
self.data = np.asarray(data, dtype=self.DTYPE)
self.kdtree = cKDTree(self.data)
def query(self, kpts, k, return_dists=False):
kpts = np.asarray(kpts, dtype=self.DTYPE)
nndists, nnindices = self.kdtree.query(kpts, k=k, workers=self.NJOBS)
if return_dists:
return nnindices, nndists
else:
return nnindices
def query_ball(self, kpt, radius):
kpt = np.asarray(kpt, dtype=self.DTYPE)
assert kpt.ndim == 1
nnindices = self.kdtree.query_ball_point(kpt, radius, n_jobs=self.NJOBS)
return nnindices
# https://github.com/RobinMagnet/pyFM
def FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=False):
if use_ANN and not ANN:
raise ValueError('Please install pydescent to achieve Approximate Nearest Neighbor')
k2, k1 = FM.shape
assert k1 <= eigvects1.shape[1], \
f'At least {k1} should be provided, here only {eigvects1.shape[1]} are given'
assert k2 <= eigvects2.shape[1], \
f'At least {k2} should be provided, here only {eigvects2.shape[1]} are given'
if use_ANN:
index = pynndescent.NNDescent(eigvects1[:, :k1] @ FM.T, n_jobs=8)
matches, _ = index.query(eigvects2[:, :k2], k=1)
matches = matches.flatten()
else:
tree = KNNSearch(eigvects1[:, :k1] @ FM.T)
matches = tree.query(eigvects2[:, :k2], k=1).flatten()
return matches
def p2p_to_FM(p2p, eigvects1, eigvects2, A2=None):
if A2 is not None:
if A2.shape[0] != eigvects2.shape[0]:
raise ValueError("Can't compute pseudo inverse with subsampled eigenvectors")
if len(A2.shape) == 1:
return eigvects2.T @ (A2[:, None] * eigvects1[p2p, :])
return eigvects2.T @ A2 @ eigvects1[p2p, :]
return scipy.linalg.lstsq(eigvects2, eigvects1[p2p, :])[0]
def zoomout_iteration(eigvects1, eigvects2, FM, step=1, A2=None, use_ANN=False):
k2, k1 = FM.shape
try:
step1, step2 = step
except TypeError:
step1 = step
step2 = step
new_k1, new_k2 = k1 + step1, k2 + step2
p2p = FM_to_p2p(FM, eigvects1, eigvects2, use_ANN=use_ANN)
FM_zo = p2p_to_FM(p2p, eigvects1[:, :new_k1], eigvects2[:, :new_k2], A2=A2)
return FM_zo
def zoomout_refine(eigvects1,
eigvects2,
FM,
nit=10,
step=1,
A2=None,
subsample=None,
use_ANN=False,
return_p2p=False,
verbose=False):
k2_0, k1_0 = FM.shape
try:
step1, step2 = step
except TypeError:
step1 = step
step2 = step
assert k1_0 + nit*step1 <= eigvects1.shape[1], \
f"Not enough eigenvectors on source : \
{k1_0 + nit*step1} are needed when {eigvects1.shape[1]} are provided"
assert k2_0 + nit*step2 <= eigvects2.shape[1], \
f"Not enough eigenvectors on target : \
{k2_0 + nit*step2} are needed when {eigvects2.shape[1]} are provided"
use_subsample = False
if subsample is not None:
use_subsample = True
sub1, sub2 = subsample
FM_zo = FM.copy()
ANN_adventage = False
iterable = range(nit) if not verbose else tqdm(range(nit))
for it in iterable:
ANN_adventage = use_ANN and (FM_zo.shape[0] > 90) and (FM_zo.shape[1] > 90)
if use_subsample:
FM_zo = zoomout_iteration(eigvects1[sub1], eigvects2[sub2], FM_zo, A2=None, step=step, use_ANN=ANN_adventage)
else:
FM_zo = zoomout_iteration(eigvects1, eigvects2, FM_zo, A2=A2, step=step, use_ANN=ANN_adventage)
if return_p2p:
p2p_zo = FM_to_p2p(FM_zo, eigvects1, eigvects2, use_ANN=False)
return FM_zo, p2p_zo
return FM_zo
|