Spaces:
Running
on
Zero
Running
on
Zero
daidedou
commited on
Commit
·
c5cb8fa
1
Parent(s):
df60d6b
cpu option
Browse files- zero_shot.py +104 -77
zero_shot.py
CHANGED
|
@@ -27,8 +27,10 @@ from utils.pickle_stuff import safe_load_with_fallback
|
|
| 27 |
from utils.geometry import compute_operators, load_operators
|
| 28 |
from utils.surfaces import Surface
|
| 29 |
import sys
|
|
|
|
| 30 |
try:
|
| 31 |
import google.colab
|
|
|
|
| 32 |
print("Running Colab")
|
| 33 |
from tqdm import tqdm
|
| 34 |
except ImportError:
|
|
@@ -37,15 +39,17 @@ except ImportError:
|
|
| 37 |
|
| 38 |
|
| 39 |
def seed_everything(seed=42):
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
|
|
|
| 46 |
|
| 47 |
seed_everything()
|
| 48 |
|
|
|
|
| 49 |
class Tee:
|
| 50 |
def __init__(self, *outputs):
|
| 51 |
self.outputs = outputs
|
|
@@ -59,6 +63,7 @@ class Tee:
|
|
| 59 |
for output in self.outputs:
|
| 60 |
output.flush()
|
| 61 |
|
|
|
|
| 62 |
class DiffModel:
|
| 63 |
|
| 64 |
def __init__(self, cfg, device="cuda:0"):
|
|
@@ -81,7 +86,7 @@ class DiffModel:
|
|
| 81 |
network_pkl = os.path.join(netdir, chosen_pkl)
|
| 82 |
print(f'Loading network from "{network_pkl}"...')
|
| 83 |
self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
|
| 84 |
-
|
| 85 |
print('Done!')
|
| 86 |
loss_name = train_cfg['hyper_params']['loss_name']
|
| 87 |
self.loss_sde = None
|
|
@@ -115,7 +120,7 @@ class Matcher(object):
|
|
| 115 |
self.n_loop = self.cfg.opt.get("n_loop", 0)
|
| 116 |
self.fmap_cfg = self.cfg.deepfeat_conf.fmap
|
| 117 |
self.dataloaders = dict()
|
| 118 |
-
|
| 119 |
def _init(self):
|
| 120 |
cfg = self.cfg
|
| 121 |
self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
|
|
@@ -124,7 +129,8 @@ class Matcher(object):
|
|
| 124 |
self.decoder = PrismDecoder(dim_in=515).to(self.device)
|
| 125 |
self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
|
| 126 |
self.soft_p2p = True
|
| 127 |
-
params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
|
|
|
|
| 128 |
else:
|
| 129 |
params_to_opt = self.fmap_model.parameters()
|
| 130 |
self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
|
|
@@ -133,63 +139,73 @@ class Matcher(object):
|
|
| 133 |
|
| 134 |
def fmap(self, shape_dict, target_dict):
|
| 135 |
if self.fmap_cfg.get("use_diff", False):
|
| 136 |
-
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
|
|
|
|
|
|
| 137 |
C12_pred, C12_obj, mask_12 = C12_pred
|
| 138 |
C21_pred, C21_obj, mask_21 = C21_pred
|
| 139 |
else:
|
| 140 |
-
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
|
|
|
| 141 |
C12_obj, C21_obj = C12_pred, C21_pred
|
| 142 |
mask_12, mask_21 = None, None
|
| 143 |
return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
|
| 144 |
-
|
| 145 |
|
| 146 |
def zo_shot(self, shape_dict, target_dict):
|
| 147 |
self._init()
|
| 148 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 149 |
-
_, C12_mask_init, _, _, _, _, _
|
| 150 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 151 |
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
|
| 152 |
indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 153 |
return new_FM, indKNN_new
|
| 154 |
|
| 155 |
-
|
| 156 |
def optimize(self, shape_dict, target_dict, target_normals):
|
| 157 |
self._init()
|
| 158 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 159 |
-
C12_pred_init, _, _, _
|
| 160 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 161 |
evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
|
| 162 |
n_verts_target = target_dict["vertices"].shape[-2]
|
| 163 |
-
|
| 164 |
-
loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [],
|
|
|
|
| 165 |
snk_rec = None
|
| 166 |
for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
|
| 167 |
-
C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict,
|
|
|
|
| 168 |
if self.cfg.opt.soft_p2p:
|
| 169 |
### A la SNK
|
| 170 |
## P2P 2 -> 1
|
| 171 |
-
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap],
|
|
|
|
| 172 |
C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
|
| 173 |
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
|
| 174 |
|
| 175 |
-
## P2P 1 -> 2
|
| 176 |
-
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap],
|
|
|
|
| 177 |
C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
|
| 178 |
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
|
| 179 |
|
| 180 |
-
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(
|
|
|
|
| 181 |
else:
|
| 182 |
C12_new, C21_new = C12_pred, C21_pred
|
| 183 |
|
| 184 |
-
l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye)**2).mean() + (
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
if self.snk:
|
| 192 |
-
# Latent vector
|
| 193 |
latents = self.encoder(shape_dict)
|
| 194 |
latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
|
| 195 |
|
|
@@ -197,51 +213,61 @@ class Matcher(object):
|
|
| 197 |
feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
|
| 198 |
snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
|
| 199 |
l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
|
| 200 |
-
l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec)**2).sum(dim=-1).mean()
|
| 201 |
-
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(
|
|
|
|
| 202 |
l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
|
| 203 |
if self.fmap_cfg.get("use_diff", False):
|
| 204 |
if self.fmap_cfg.diffusion.get("abs", False):
|
| 205 |
C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
|
| 206 |
else:
|
| 207 |
C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
|
| 208 |
-
grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
|
|
|
|
| 209 |
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 210 |
with torch.no_grad():
|
| 211 |
denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
|
| 212 |
-
targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
|
| 213 |
-
|
| 214 |
-
l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
|
|
|
|
| 215 |
|
| 216 |
-
grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1,
|
|
|
|
| 217 |
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 218 |
-
#denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 219 |
with torch.no_grad():
|
| 220 |
-
denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 221 |
-
targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(),
|
| 222 |
-
|
|
|
|
|
|
|
| 223 |
l_proper = l_proper_12 + l_proper_21
|
| 224 |
|
| 225 |
-
l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[
|
| 226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
loss = torch.as_tensor(0.).float().to(self.device)
|
| 228 |
if self.cfg.loss.get("ortho", 0) > 0:
|
| 229 |
-
loss += self.cfg.loss.get("ortho", 0) *
|
| 230 |
if self.cfg.loss.get("bij", 0) > 0:
|
| 231 |
-
loss += self.cfg.loss.get("bij", 0) *
|
| 232 |
if self.cfg.loss.get("lap", 0) > 0:
|
| 233 |
-
loss += self.cfg.loss.get("lap", 0) *
|
| 234 |
if self.cfg.loss.get("cycle", 0) > 0:
|
| 235 |
-
loss += self.cfg.loss.get("cycle", 0) *
|
| 236 |
if self.cfg.loss.get("mse_rec", 0) > 0:
|
| 237 |
-
loss += self.cfg.loss.get("mse_rec", 0) *
|
| 238 |
if self.cfg.loss.get("prism_rec", 0) > 0:
|
| 239 |
-
loss += self.cfg.loss.get("prism_rec", 0) *
|
| 240 |
if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 241 |
loss += self.cfg.loss.get("sds", 0) * l_sds
|
| 242 |
if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 243 |
loss += self.cfg.loss.get("proper", 0) * l_proper
|
| 244 |
-
|
| 245 |
loss.backward()
|
| 246 |
self.optim.step()
|
| 247 |
self.optim.zero_grad()
|
|
@@ -256,19 +282,17 @@ class Matcher(object):
|
|
| 256 |
indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
|
| 257 |
indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
|
| 258 |
return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
|
| 259 |
-
|
| 260 |
-
|
| 261 |
|
| 262 |
def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
|
| 263 |
-
shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
|
| 264 |
shape_dict_device = convert_dict(shape_dict, self.device)
|
| 265 |
target_dict_device = convert_dict(target_dict, self.device)
|
| 266 |
print(shape_dict_device["vertices"].device)
|
| 267 |
os.makedirs(output_pair, exist_ok=True)
|
| 268 |
|
| 269 |
-
|
| 270 |
if self.cfg["optimize"]:
|
| 271 |
-
C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device,
|
|
|
|
| 272 |
np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
|
| 273 |
np.save(os.path.join(output_pair, "losses.npy"), loss_save)
|
| 274 |
else:
|
|
@@ -277,12 +301,13 @@ class Matcher(object):
|
|
| 277 |
np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
|
| 278 |
np.save(os.path.join(output_pair, "p2p.npy"), p2p)
|
| 279 |
if snk_rec is not None:
|
| 280 |
-
save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(),
|
|
|
|
| 281 |
|
| 282 |
if refine:
|
| 283 |
evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
|
| 284 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
|
| 285 |
-
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)
|
| 286 |
p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 287 |
np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
|
| 288 |
if eval:
|
|
@@ -290,19 +315,16 @@ class Matcher(object):
|
|
| 290 |
mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
|
| 291 |
A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
|
| 292 |
_, dist = accuracy(p2p[vts_2], vts_1, A_geod,
|
| 293 |
-
|
| 294 |
-
|
| 295 |
if refine:
|
| 296 |
_, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
|
| 297 |
-
|
| 298 |
-
|
| 299 |
np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
|
| 300 |
return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
|
| 301 |
return p2p, loss_save, dist.mean()
|
| 302 |
return p2p, loss_save
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
|
| 307 |
def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
|
| 308 |
os.makedirs(save_dir, exist_ok=True)
|
|
@@ -323,27 +345,31 @@ class Matcher(object):
|
|
| 323 |
print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
|
| 324 |
name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
|
| 325 |
if self.cfg.get("refine", False):
|
| 326 |
-
_, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset),
|
|
|
|
| 327 |
else:
|
| 328 |
-
_, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True,
|
|
|
|
| 329 |
delta = datetime.now() - t1
|
| 330 |
fm_delta = str_delta(delta)
|
| 331 |
-
remains = ((delta/(id_pair+1))*num_pairs) - delta
|
| 332 |
fm_remains = str_delta(remains)
|
| 333 |
all_accs.append(dist)
|
| 334 |
accs_mean = np.mean(all_accs)
|
| 335 |
if self.cfg.get("refine", False):
|
| 336 |
all_accs_zo.append(dist_zo)
|
| 337 |
accs_zo = np.mean(all_accs_zo)
|
| 338 |
-
print(
|
|
|
|
| 339 |
else:
|
| 340 |
-
print(
|
|
|
|
| 341 |
id_pair += 1
|
| 342 |
if self.cfg.get("refine", False):
|
| 343 |
print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
|
| 344 |
else:
|
| 345 |
print(f"mean error : {np.mean(all_accs)}")
|
| 346 |
-
sys.stdout = sys.__stdout__
|
| 347 |
|
| 348 |
def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
|
| 349 |
name = os.path.basename(os.path.splitext(file)[0])
|
|
@@ -357,26 +383,25 @@ class Matcher(object):
|
|
| 357 |
data_dict = load_operators(cache_path)
|
| 358 |
data_dict['name'] = name
|
| 359 |
data_dict_torch = convert_dict(data_dict, self.device)
|
| 360 |
-
#batchify_dict(data_dict_torch)
|
| 361 |
return data_dict_torch, area_shape
|
| 362 |
|
| 363 |
def match_files(self, file_shape, file_target):
|
| 364 |
batch_shape, _ = self.load_data(file_shape)
|
| 365 |
-
batch_target, _ = self.load_data(file_target)
|
| 366 |
target_surf = Surface(filename=file_target)
|
| 367 |
-
target_normals = torch.from_numpy(
|
|
|
|
| 368 |
batch = batch_shape, None, batch_target, target_normals, None, None
|
| 369 |
output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
|
| 370 |
p2p, _ = self.match(batch, output_folder, None)
|
| 371 |
return batch_shape, batch_target, p2p
|
| 372 |
|
| 373 |
|
| 374 |
-
|
| 375 |
-
|
| 376 |
if __name__ == '__main__':
|
| 377 |
parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
|
| 378 |
parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
|
| 379 |
-
parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
|
| 380 |
parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
|
| 381 |
parser.add_argument('--output', type=str, default="results", help="where to store experience results")
|
| 382 |
args = parser.parse_args()
|
|
@@ -398,5 +423,7 @@ if __name__ == '__main__':
|
|
| 398 |
dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
|
| 399 |
exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
|
| 400 |
output_logs = os.path.join(args.output, name_data_geo, exp_time)
|
| 401 |
-
|
|
|
|
|
|
|
| 402 |
matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)
|
|
|
|
| 27 |
from utils.geometry import compute_operators, load_operators
|
| 28 |
from utils.surfaces import Surface
|
| 29 |
import sys
|
| 30 |
+
|
| 31 |
try:
|
| 32 |
import google.colab
|
| 33 |
+
|
| 34 |
print("Running Colab")
|
| 35 |
from tqdm import tqdm
|
| 36 |
except ImportError:
|
|
|
|
| 39 |
|
| 40 |
|
| 41 |
def seed_everything(seed=42):
|
| 42 |
+
random.seed(seed)
|
| 43 |
+
os.environ['PYTHONHASHSEED'] = str(seed)
|
| 44 |
+
np.random.seed(seed)
|
| 45 |
+
torch.manual_seed(seed)
|
| 46 |
+
torch.backends.cudnn.deterministic = True
|
| 47 |
+
torch.backends.cudnn.benchmark = False
|
| 48 |
+
|
| 49 |
|
| 50 |
seed_everything()
|
| 51 |
|
| 52 |
+
|
| 53 |
class Tee:
|
| 54 |
def __init__(self, *outputs):
|
| 55 |
self.outputs = outputs
|
|
|
|
| 63 |
for output in self.outputs:
|
| 64 |
output.flush()
|
| 65 |
|
| 66 |
+
|
| 67 |
class DiffModel:
|
| 68 |
|
| 69 |
def __init__(self, cfg, device="cuda:0"):
|
|
|
|
| 86 |
network_pkl = os.path.join(netdir, chosen_pkl)
|
| 87 |
print(f'Loading network from "{network_pkl}"...')
|
| 88 |
self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
|
| 89 |
+
|
| 90 |
print('Done!')
|
| 91 |
loss_name = train_cfg['hyper_params']['loss_name']
|
| 92 |
self.loss_sde = None
|
|
|
|
| 120 |
self.n_loop = self.cfg.opt.get("n_loop", 0)
|
| 121 |
self.fmap_cfg = self.cfg.deepfeat_conf.fmap
|
| 122 |
self.dataloaders = dict()
|
| 123 |
+
|
| 124 |
def _init(self):
|
| 125 |
cfg = self.cfg
|
| 126 |
self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
|
|
|
|
| 129 |
self.decoder = PrismDecoder(dim_in=515).to(self.device)
|
| 130 |
self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
|
| 131 |
self.soft_p2p = True
|
| 132 |
+
params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
|
| 133 |
+
self.decoder.parameters())
|
| 134 |
else:
|
| 135 |
params_to_opt = self.fmap_model.parameters()
|
| 136 |
self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
|
|
|
|
| 139 |
|
| 140 |
def fmap(self, shape_dict, target_dict):
|
| 141 |
if self.fmap_cfg.get("use_diff", False):
|
| 142 |
+
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
| 143 |
+
{"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model,
|
| 144 |
+
scale=self.fmap_cfg.diffusion.time)
|
| 145 |
C12_pred, C12_obj, mask_12 = C12_pred
|
| 146 |
C21_pred, C21_obj, mask_21 = C21_pred
|
| 147 |
else:
|
| 148 |
+
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
| 149 |
+
{"shape1": shape_dict, "shape2": target_dict})
|
| 150 |
C12_obj, C21_obj = C12_pred, C21_pred
|
| 151 |
mask_12, mask_21 = None, None
|
| 152 |
return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
|
|
|
|
| 153 |
|
| 154 |
def zo_shot(self, shape_dict, target_dict):
|
| 155 |
self._init()
|
| 156 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 157 |
+
_, C12_mask_init, _, _, _, _, _, _, _, _ = self.fmap(shape_dict, target_dict)
|
| 158 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 159 |
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
|
| 160 |
indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 161 |
return new_FM, indKNN_new
|
| 162 |
|
|
|
|
| 163 |
def optimize(self, shape_dict, target_dict, target_normals):
|
| 164 |
self._init()
|
| 165 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 166 |
+
C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
|
| 167 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
| 168 |
evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
|
| 169 |
n_verts_target = target_dict["vertices"].shape[-2]
|
| 170 |
+
|
| 171 |
+
loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [],
|
| 172 |
+
"proper": []}
|
| 173 |
snk_rec = None
|
| 174 |
for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
|
| 175 |
+
C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict,
|
| 176 |
+
target_dict)
|
| 177 |
if self.cfg.opt.soft_p2p:
|
| 178 |
### A la SNK
|
| 179 |
## P2P 2 -> 1
|
| 180 |
+
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap],
|
| 181 |
+
prod=True)
|
| 182 |
C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
|
| 183 |
soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
|
| 184 |
|
| 185 |
+
## P2P 1 -> 2
|
| 186 |
+
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap],
|
| 187 |
+
prod=True)
|
| 188 |
C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
|
| 189 |
soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
|
| 190 |
|
| 191 |
+
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
|
| 192 |
+
dim=-1).mean()
|
| 193 |
else:
|
| 194 |
C12_new, C21_new = C12_pred, C21_pred
|
| 195 |
|
| 196 |
+
l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye) ** 2).mean() + (
|
| 197 |
+
(C21_new.squeeze() @ C21_new.squeeze().T - self.eye) ** 2).mean()
|
| 198 |
+
l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye) ** 2).mean() + (
|
| 199 |
+
(C21_new.squeeze() @ C12_new.squeeze() - self.eye) ** 2).mean()
|
| 200 |
+
l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(
|
| 201 |
+
target_dict["evals"][:self.n_fmap]) @ C12_new) ** 2).mean()
|
| 202 |
+
l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
|
| 203 |
+
shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
|
| 204 |
+
|
| 205 |
+
l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(
|
| 206 |
+
self.device), torch.as_tensor(0.).float().to(self.device)
|
| 207 |
if self.snk:
|
| 208 |
+
# Latent vector
|
| 209 |
latents = self.encoder(shape_dict)
|
| 210 |
latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
|
| 211 |
|
|
|
|
| 213 |
feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
|
| 214 |
snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
|
| 215 |
l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
|
| 216 |
+
l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
|
| 217 |
+
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
|
| 218 |
+
dim=-1).mean()
|
| 219 |
l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
|
| 220 |
if self.fmap_cfg.get("use_diff", False):
|
| 221 |
if self.fmap_cfg.diffusion.get("abs", False):
|
| 222 |
C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
|
| 223 |
else:
|
| 224 |
C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
|
| 225 |
+
grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
|
| 226 |
+
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 227 |
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 228 |
with torch.no_grad():
|
| 229 |
denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
|
| 230 |
+
targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
|
| 231 |
+
|
| 232 |
+
l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
|
| 233 |
+
:self.n_fmap]) ** 2).mean()
|
| 234 |
|
| 235 |
+
grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1,
|
| 236 |
+
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 237 |
scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
|
| 238 |
+
# denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 239 |
with torch.no_grad():
|
| 240 |
+
denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 241 |
+
targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(),
|
| 242 |
+
self.cfg.sds_conf.zoomout) # , step=10)
|
| 243 |
+
l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap,
|
| 244 |
+
:self.n_fmap]) ** 2).mean()
|
| 245 |
l_proper = l_proper_12 + l_proper_21
|
| 246 |
|
| 247 |
+
l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[
|
| 248 |
+
:self.n_fmap,
|
| 249 |
+
:self.n_fmap]) ** 2).mean()
|
| 250 |
+
l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
|
| 251 |
+
:self.n_fmap,
|
| 252 |
+
:self.n_fmap]) ** 2).mean()
|
| 253 |
loss = torch.as_tensor(0.).float().to(self.device)
|
| 254 |
if self.cfg.loss.get("ortho", 0) > 0:
|
| 255 |
+
loss += self.cfg.loss.get("ortho", 0) * l_ortho
|
| 256 |
if self.cfg.loss.get("bij", 0) > 0:
|
| 257 |
+
loss += self.cfg.loss.get("bij", 0) * l_bij
|
| 258 |
if self.cfg.loss.get("lap", 0) > 0:
|
| 259 |
+
loss += self.cfg.loss.get("lap", 0) * l_lap
|
| 260 |
if self.cfg.loss.get("cycle", 0) > 0:
|
| 261 |
+
loss += self.cfg.loss.get("cycle", 0) * l_cycle
|
| 262 |
if self.cfg.loss.get("mse_rec", 0) > 0:
|
| 263 |
+
loss += self.cfg.loss.get("mse_rec", 0) * l_mse
|
| 264 |
if self.cfg.loss.get("prism_rec", 0) > 0:
|
| 265 |
+
loss += self.cfg.loss.get("prism_rec", 0) * l_prism
|
| 266 |
if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 267 |
loss += self.cfg.loss.get("sds", 0) * l_sds
|
| 268 |
if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
|
| 269 |
loss += self.cfg.loss.get("proper", 0) * l_proper
|
| 270 |
+
|
| 271 |
loss.backward()
|
| 272 |
self.optim.step()
|
| 273 |
self.optim.zero_grad()
|
|
|
|
| 282 |
indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
|
| 283 |
indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
|
| 284 |
return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
|
|
|
|
|
|
|
| 285 |
|
| 286 |
def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
|
| 287 |
+
shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
|
| 288 |
shape_dict_device = convert_dict(shape_dict, self.device)
|
| 289 |
target_dict_device = convert_dict(target_dict, self.device)
|
| 290 |
print(shape_dict_device["vertices"].device)
|
| 291 |
os.makedirs(output_pair, exist_ok=True)
|
| 292 |
|
|
|
|
| 293 |
if self.cfg["optimize"]:
|
| 294 |
+
C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device,
|
| 295 |
+
target_normals.to(self.device))
|
| 296 |
np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
|
| 297 |
np.save(os.path.join(output_pair, "losses.npy"), loss_save)
|
| 298 |
else:
|
|
|
|
| 301 |
np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
|
| 302 |
np.save(os.path.join(output_pair, "p2p.npy"), p2p)
|
| 303 |
if snk_rec is not None:
|
| 304 |
+
save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(),
|
| 305 |
+
target_dict["faces"])
|
| 306 |
|
| 307 |
if refine:
|
| 308 |
evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
|
| 309 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
|
| 310 |
+
new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128) # , step=10)
|
| 311 |
p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
|
| 312 |
np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
|
| 313 |
if eval:
|
|
|
|
| 315 |
mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
|
| 316 |
A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
|
| 317 |
_, dist = accuracy(p2p[vts_2], vts_1, A_geod,
|
| 318 |
+
sqrt_area=sqrt_area,
|
| 319 |
+
return_all=True)
|
| 320 |
if refine:
|
| 321 |
_, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
|
| 322 |
+
sqrt_area=sqrt_area,
|
| 323 |
+
return_all=True)
|
| 324 |
np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
|
| 325 |
return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
|
| 326 |
return p2p, loss_save, dist.mean()
|
| 327 |
return p2p, loss_save
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
|
| 330 |
os.makedirs(save_dir, exist_ok=True)
|
|
|
|
| 345 |
print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
|
| 346 |
name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
|
| 347 |
if self.cfg.get("refine", False):
|
| 348 |
+
_, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset),
|
| 349 |
+
eval=True, refine=True)
|
| 350 |
else:
|
| 351 |
+
_, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True,
|
| 352 |
+
refine=False)
|
| 353 |
delta = datetime.now() - t1
|
| 354 |
fm_delta = str_delta(delta)
|
| 355 |
+
remains = ((delta / (id_pair + 1)) * num_pairs) - delta
|
| 356 |
fm_remains = str_delta(remains)
|
| 357 |
all_accs.append(dist)
|
| 358 |
accs_mean = np.mean(all_accs)
|
| 359 |
if self.cfg.get("refine", False):
|
| 360 |
all_accs_zo.append(dist_zo)
|
| 361 |
accs_zo = np.mean(all_accs_zo)
|
| 362 |
+
print(
|
| 363 |
+
f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
|
| 364 |
else:
|
| 365 |
+
print(
|
| 366 |
+
f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
|
| 367 |
id_pair += 1
|
| 368 |
if self.cfg.get("refine", False):
|
| 369 |
print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
|
| 370 |
else:
|
| 371 |
print(f"mean error : {np.mean(all_accs)}")
|
| 372 |
+
sys.stdout = sys.__stdout__
|
| 373 |
|
| 374 |
def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
|
| 375 |
name = os.path.basename(os.path.splitext(file)[0])
|
|
|
|
| 383 |
data_dict = load_operators(cache_path)
|
| 384 |
data_dict['name'] = name
|
| 385 |
data_dict_torch = convert_dict(data_dict, self.device)
|
| 386 |
+
# batchify_dict(data_dict_torch)
|
| 387 |
return data_dict_torch, area_shape
|
| 388 |
|
| 389 |
def match_files(self, file_shape, file_target):
|
| 390 |
batch_shape, _ = self.load_data(file_shape)
|
| 391 |
+
batch_target, _ = self.load_data(file_target)
|
| 392 |
target_surf = Surface(filename=file_target)
|
| 393 |
+
target_normals = torch.from_numpy(
|
| 394 |
+
target_surf.surfel / np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
|
| 395 |
batch = batch_shape, None, batch_target, target_normals, None, None
|
| 396 |
output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
|
| 397 |
p2p, _ = self.match(batch, output_folder, None)
|
| 398 |
return batch_shape, batch_target, p2p
|
| 399 |
|
| 400 |
|
|
|
|
|
|
|
| 401 |
if __name__ == '__main__':
|
| 402 |
parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
|
| 403 |
parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
|
| 404 |
+
parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
|
| 405 |
parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
|
| 406 |
parser.add_argument('--output', type=str, default="results", help="where to store experience results")
|
| 407 |
args = parser.parse_args()
|
|
|
|
| 423 |
dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
|
| 424 |
exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
|
| 425 |
output_logs = os.path.join(args.output, name_data_geo, exp_time)
|
| 426 |
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
| 427 |
+
print(f"Using device: {device}")
|
| 428 |
+
matcher = Matcher(cfg, device)
|
| 429 |
matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)
|