daidedou committed on
Commit
c5cb8fa
·
1 Parent(s): df60d6b

cpu option

Browse files
Files changed (1) hide show
  1. zero_shot.py +104 -77
zero_shot.py CHANGED
@@ -27,8 +27,10 @@ from utils.pickle_stuff import safe_load_with_fallback
27
  from utils.geometry import compute_operators, load_operators
28
  from utils.surfaces import Surface
29
  import sys
 
30
  try:
31
  import google.colab
 
32
  print("Running Colab")
33
  from tqdm import tqdm
34
  except ImportError:
@@ -37,15 +39,17 @@ except ImportError:
37
 
38
 
39
  def seed_everything(seed=42):
40
- random.seed(seed)
41
- os.environ['PYTHONHASHSEED'] = str(seed)
42
- np.random.seed(seed)
43
- torch.manual_seed(seed)
44
- torch.backends.cudnn.deterministic = True
45
- torch.backends.cudnn.benchmark = False
 
46
 
47
  seed_everything()
48
 
 
49
  class Tee:
50
  def __init__(self, *outputs):
51
  self.outputs = outputs
@@ -59,6 +63,7 @@ class Tee:
59
  for output in self.outputs:
60
  output.flush()
61
 
 
62
  class DiffModel:
63
 
64
  def __init__(self, cfg, device="cuda:0"):
@@ -81,7 +86,7 @@ class DiffModel:
81
  network_pkl = os.path.join(netdir, chosen_pkl)
82
  print(f'Loading network from "{network_pkl}"...')
83
  self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
84
-
85
  print('Done!')
86
  loss_name = train_cfg['hyper_params']['loss_name']
87
  self.loss_sde = None
@@ -115,7 +120,7 @@ class Matcher(object):
115
  self.n_loop = self.cfg.opt.get("n_loop", 0)
116
  self.fmap_cfg = self.cfg.deepfeat_conf.fmap
117
  self.dataloaders = dict()
118
-
119
  def _init(self):
120
  cfg = self.cfg
121
  self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
@@ -124,7 +129,8 @@ class Matcher(object):
124
  self.decoder = PrismDecoder(dim_in=515).to(self.device)
125
  self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
126
  self.soft_p2p = True
127
- params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(self.decoder.parameters())
 
128
  else:
129
  params_to_opt = self.fmap_model.parameters()
130
  self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
@@ -133,63 +139,73 @@ class Matcher(object):
133
 
134
  def fmap(self, shape_dict, target_dict):
135
  if self.fmap_cfg.get("use_diff", False):
136
- C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model, scale=self.fmap_cfg.diffusion.time)
 
 
137
  C12_pred, C12_obj, mask_12 = C12_pred
138
  C21_pred, C21_obj, mask_21 = C21_pred
139
  else:
140
- C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model({"shape1": shape_dict, "shape2": target_dict})
 
141
  C12_obj, C21_obj = C12_pred, C21_pred
142
  mask_12, mask_21 = None, None
143
  return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
144
-
145
 
146
  def zo_shot(self, shape_dict, target_dict):
147
  self._init()
148
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
149
- _, C12_mask_init, _, _, _, _, _ , _, _, _ = self.fmap(shape_dict, target_dict)
150
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
151
  new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
152
  indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
153
  return new_FM, indKNN_new
154
 
155
-
156
  def optimize(self, shape_dict, target_dict, target_normals):
157
  self._init()
158
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
159
- C12_pred_init, _, _, _ , _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
160
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
161
  evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
162
  n_verts_target = target_dict["vertices"].shape[-2]
163
-
164
- loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [], "proper": []}
 
165
  snk_rec = None
166
  for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
167
- C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
 
168
  if self.cfg.opt.soft_p2p:
169
  ### A la SNK
170
  ## P2P 2 -> 1
171
- soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap], prod=True)
 
172
  C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
173
  soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
174
 
175
- ## P2P 1 -> 2
176
- soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap], prod=True)
 
177
  C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
178
  soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
179
 
180
- l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
 
181
  else:
182
  C12_new, C21_new = C12_pred, C21_pred
183
 
184
- l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye)**2).mean() + ((C21_new.squeeze() @ C21_new.squeeze().T - self.eye)**2).mean()
185
- l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye)**2).mean() + ((C21_new.squeeze() @ C12_new.squeeze() - self.eye)**2).mean()
186
- l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(target_dict["evals"][:self.n_fmap]) @ C12_new)**2).mean()
187
- l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(shape_dict["evals"][:self.n_fmap]) @ C21_new)**2).mean()
188
-
189
-
190
- l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
 
 
 
 
191
  if self.snk:
192
- # Latent vector
193
  latents = self.encoder(shape_dict)
194
  latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
195
 
@@ -197,51 +213,61 @@ class Matcher(object):
197
  feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
198
  snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
199
  l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
200
- l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec)**2).sum(dim=-1).mean()
201
- l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"])**2).sum(dim=-1).mean()
 
202
  l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
203
  if self.fmap_cfg.get("use_diff", False):
204
  if self.fmap_cfg.diffusion.get("abs", False):
205
  C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
206
  else:
207
  C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
208
- grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
 
209
  scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
210
  with torch.no_grad():
211
  denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
212
- targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
213
-
214
- l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
 
215
 
216
- grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1, batch_size=self.fmap_cfg.diffusion.batch_sds,
 
217
  scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
218
- #denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
219
  with torch.no_grad():
220
- denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
221
- targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(), self.cfg.sds_conf.zoomout)#, step=10)
222
- l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
 
 
223
  l_proper = l_proper_12 + l_proper_21
224
 
225
- l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
226
- l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[:self.n_fmap, :self.n_fmap])**2).mean()
 
 
 
 
227
  loss = torch.as_tensor(0.).float().to(self.device)
228
  if self.cfg.loss.get("ortho", 0) > 0:
229
- loss += self.cfg.loss.get("ortho", 0) * l_ortho
230
  if self.cfg.loss.get("bij", 0) > 0:
231
- loss += self.cfg.loss.get("bij", 0) * l_bij
232
  if self.cfg.loss.get("lap", 0) > 0:
233
- loss += self.cfg.loss.get("lap", 0) * l_lap
234
  if self.cfg.loss.get("cycle", 0) > 0:
235
- loss += self.cfg.loss.get("cycle", 0) * l_cycle
236
  if self.cfg.loss.get("mse_rec", 0) > 0:
237
- loss += self.cfg.loss.get("mse_rec", 0) * l_mse
238
  if self.cfg.loss.get("prism_rec", 0) > 0:
239
- loss += self.cfg.loss.get("prism_rec", 0) * l_prism
240
  if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
241
  loss += self.cfg.loss.get("sds", 0) * l_sds
242
  if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
243
  loss += self.cfg.loss.get("proper", 0) * l_proper
244
-
245
  loss.backward()
246
  self.optim.step()
247
  self.optim.zero_grad()
@@ -256,19 +282,17 @@ class Matcher(object):
256
  indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
257
  indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
258
  return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
259
-
260
-
261
 
262
  def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
263
- shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
264
  shape_dict_device = convert_dict(shape_dict, self.device)
265
  target_dict_device = convert_dict(target_dict, self.device)
266
  print(shape_dict_device["vertices"].device)
267
  os.makedirs(output_pair, exist_ok=True)
268
 
269
-
270
  if self.cfg["optimize"]:
271
- C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device, target_normals.to(self.device))
 
272
  np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
273
  np.save(os.path.join(output_pair, "losses.npy"), loss_save)
274
  else:
@@ -277,12 +301,13 @@ class Matcher(object):
277
  np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
278
  np.save(os.path.join(output_pair, "p2p.npy"), p2p)
279
  if snk_rec is not None:
280
- save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(), target_dict["faces"])
 
281
 
282
  if refine:
283
  evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
284
  evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
285
- new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128)#, step=10)
286
  p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
287
  np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
288
  if eval:
@@ -290,19 +315,16 @@ class Matcher(object):
290
  mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
291
  A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
292
  _, dist = accuracy(p2p[vts_2], vts_1, A_geod,
293
- sqrt_area=sqrt_area,
294
- return_all=True)
295
  if refine:
296
  _, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
297
- sqrt_area=sqrt_area,
298
- return_all=True)
299
  np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
300
  return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
301
  return p2p, loss_save, dist.mean()
302
  return p2p, loss_save
303
-
304
-
305
-
306
 
307
  def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
308
  os.makedirs(save_dir, exist_ok=True)
@@ -323,27 +345,31 @@ class Matcher(object):
323
  print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
324
  name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
325
  if self.cfg.get("refine", False):
326
- _, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=True)
 
327
  else:
328
- _, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True, refine=False)
 
329
  delta = datetime.now() - t1
330
  fm_delta = str_delta(delta)
331
- remains = ((delta/(id_pair+1))*num_pairs) - delta
332
  fm_remains = str_delta(remains)
333
  all_accs.append(dist)
334
  accs_mean = np.mean(all_accs)
335
  if self.cfg.get("refine", False):
336
  all_accs_zo.append(dist_zo)
337
  accs_zo = np.mean(all_accs_zo)
338
- print(f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
 
339
  else:
340
- print(f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
 
341
  id_pair += 1
342
  if self.cfg.get("refine", False):
343
  print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
344
  else:
345
  print(f"mean error : {np.mean(all_accs)}")
346
- sys.stdout = sys.__stdout__
347
 
348
  def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
349
  name = os.path.basename(os.path.splitext(file)[0])
@@ -357,26 +383,25 @@ class Matcher(object):
357
  data_dict = load_operators(cache_path)
358
  data_dict['name'] = name
359
  data_dict_torch = convert_dict(data_dict, self.device)
360
- #batchify_dict(data_dict_torch)
361
  return data_dict_torch, area_shape
362
 
363
  def match_files(self, file_shape, file_target):
364
  batch_shape, _ = self.load_data(file_shape)
365
- batch_target, _ = self.load_data(file_target)
366
  target_surf = Surface(filename=file_target)
367
- target_normals = torch.from_numpy(target_surf.surfel/np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
 
368
  batch = batch_shape, None, batch_target, target_normals, None, None
369
  output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
370
  p2p, _ = self.match(batch, output_folder, None)
371
  return batch_shape, batch_target, p2p
372
 
373
 
374
-
375
-
376
  if __name__ == '__main__':
377
  parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
378
  parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
379
- parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
380
  parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
381
  parser.add_argument('--output', type=str, default="results", help="where to store experience results")
382
  args = parser.parse_args()
@@ -398,5 +423,7 @@ if __name__ == '__main__':
398
  dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
399
  exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
400
  output_logs = os.path.join(args.output, name_data_geo, exp_time)
401
- matcher = Matcher(cfg)
 
 
402
  matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)
 
27
  from utils.geometry import compute_operators, load_operators
28
  from utils.surfaces import Surface
29
  import sys
30
+
31
  try:
32
  import google.colab
33
+
34
  print("Running Colab")
35
  from tqdm import tqdm
36
  except ImportError:
 
39
 
40
 
41
  def seed_everything(seed=42):
42
+ random.seed(seed)
43
+ os.environ['PYTHONHASHSEED'] = str(seed)
44
+ np.random.seed(seed)
45
+ torch.manual_seed(seed)
46
+ torch.backends.cudnn.deterministic = True
47
+ torch.backends.cudnn.benchmark = False
48
+
49
 
50
  seed_everything()
51
 
52
+
53
  class Tee:
54
  def __init__(self, *outputs):
55
  self.outputs = outputs
 
63
  for output in self.outputs:
64
  output.flush()
65
 
66
+
67
  class DiffModel:
68
 
69
  def __init__(self, cfg, device="cuda:0"):
 
86
  network_pkl = os.path.join(netdir, chosen_pkl)
87
  print(f'Loading network from "{network_pkl}"...')
88
  self.net = safe_load_with_fallback(network_pkl)['ema'].to(device)
89
+
90
  print('Done!')
91
  loss_name = train_cfg['hyper_params']['loss_name']
92
  self.loss_sde = None
 
120
  self.n_loop = self.cfg.opt.get("n_loop", 0)
121
  self.fmap_cfg = self.cfg.deepfeat_conf.fmap
122
  self.dataloaders = dict()
123
+
124
  def _init(self):
125
  cfg = self.cfg
126
  self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
 
129
  self.decoder = PrismDecoder(dim_in=515).to(self.device)
130
  self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
131
  self.soft_p2p = True
132
+ params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
133
+ self.decoder.parameters())
134
  else:
135
  params_to_opt = self.fmap_model.parameters()
136
  self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
 
139
 
140
  def fmap(self, shape_dict, target_dict):
141
  if self.fmap_cfg.get("use_diff", False):
142
+ C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
143
+ {"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model,
144
+ scale=self.fmap_cfg.diffusion.time)
145
  C12_pred, C12_obj, mask_12 = C12_pred
146
  C21_pred, C21_obj, mask_21 = C21_pred
147
  else:
148
+ C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
149
+ {"shape1": shape_dict, "shape2": target_dict})
150
  C12_obj, C21_obj = C12_pred, C21_pred
151
  mask_12, mask_21 = None, None
152
  return C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, mask_12, mask_21
 
153
 
154
  def zo_shot(self, shape_dict, target_dict):
155
  self._init()
156
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
157
+ _, C12_mask_init, _, _, _, _, _, _, _, _ = self.fmap(shape_dict, target_dict)
158
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
159
  new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_mask_init.squeeze(), self.cfg["zo_shot"])
160
  indKNN_new, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
161
  return new_FM, indKNN_new
162
 
 
163
  def optimize(self, shape_dict, target_dict, target_normals):
164
  self._init()
165
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
166
+ C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
167
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
168
  evecs_1trans = evecs1.t() @ torch.diag(shape_dict["mass"])
169
  n_verts_target = target_dict["vertices"].shape[-2]
170
+
171
+ loss_save = {"cycle": [], "fmap": [], "mse": [], "prism": [], "bij": [], "ortho": [], "sds": [], "lap": [],
172
+ "proper": []}
173
  snk_rec = None
174
  for i in tqdm(range(self.n_loop), "Optimizing matching " + shape_dict['name'] + " " + target_dict['name']):
175
+ C12_pred, C12_obj, C21_pred, C21_obj, feat1, feat2, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict,
176
+ target_dict)
177
  if self.cfg.opt.soft_p2p:
178
  ### A la SNK
179
  ## P2P 2 -> 1
180
+ soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_pred.squeeze(), evecs1[:, :self.n_fmap],
181
+ prod=True)
182
  C12_new = evecs_trans2[:self.n_fmap, :] @ soft_p2p_21 @ evecs1[:, :self.n_fmap]
183
  soft_p2p_21 = knnsearch(evecs2[:, :self.n_fmap] @ C12_new.squeeze(), evecs1[:, :self.n_fmap], prod=True)
184
 
185
+ ## P2P 1 -> 2
186
+ soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_pred.squeeze(), evecs2[:, :self.n_fmap],
187
+ prod=True)
188
  C21_new = evecs_trans1[:self.n_fmap, :] @ soft_p2p_12 @ evecs2[:, :self.n_fmap]
189
  soft_p2p_12 = knnsearch(evecs1[:, :self.n_fmap] @ C21_new.squeeze(), evecs2[:, :self.n_fmap], prod=True)
190
 
191
+ l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
192
+ dim=-1).mean()
193
  else:
194
  C12_new, C21_new = C12_pred, C21_pred
195
 
196
+ l_ortho = ((C12_new.squeeze() @ C12_new.squeeze().T - self.eye) ** 2).mean() + (
197
+ (C21_new.squeeze() @ C21_new.squeeze().T - self.eye) ** 2).mean()
198
+ l_bij = ((C12_new.squeeze() @ C21_new.squeeze() - self.eye) ** 2).mean() + (
199
+ (C21_new.squeeze() @ C12_new.squeeze() - self.eye) ** 2).mean()
200
+ l_lap = ((C12_new @ torch.diag(shape_dict["evals"][:self.n_fmap]) - torch.diag(
201
+ target_dict["evals"][:self.n_fmap]) @ C12_new) ** 2).mean()
202
+ l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
203
+ shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
204
+
205
+ l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(
206
+ self.device), torch.as_tensor(0.).float().to(self.device)
207
  if self.snk:
208
+ # Latent vector
209
  latents = self.encoder(shape_dict)
210
  latents_duplicate = latents[None, :].repeat(n_verts_target, 1)
211
 
 
213
  feats_decode = torch.cat((target_dict["vertices"], latents_duplicate), dim=1)
214
  snk_rec, prism, rots = self.decoder(target_dict, feats_decode)
215
  l_prism = self.loss_prism(prism, rots, target_dict["vertices"], target_dict["faces"], target_normals)
216
+ l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
217
+ l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
218
+ dim=-1).mean()
219
  l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
220
  if self.fmap_cfg.get("use_diff", False):
221
  if self.fmap_cfg.diffusion.get("abs", False):
222
  C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
223
  else:
224
  C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
225
+ grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
226
+ batch_size=self.fmap_cfg.diffusion.batch_sds,
227
  scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
228
  with torch.no_grad():
229
  denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
230
+ targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
231
+
232
+ l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
233
+ :self.n_fmap]) ** 2).mean()
234
 
235
+ grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1,
236
+ batch_size=self.fmap_cfg.diffusion.batch_sds,
237
  scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
238
+ # denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
239
  with torch.no_grad():
240
+ denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
241
+ targets_21 = torch_zoomout(evecs2, evecs1, evecs_1trans, C21_obj.squeeze(),
242
+ self.cfg.sds_conf.zoomout) # , step=10)
243
+ l_proper_21 = ((C21_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_21.squeeze()[:self.n_fmap,
244
+ :self.n_fmap]) ** 2).mean()
245
  l_proper = l_proper_12 + l_proper_21
246
 
247
+ l_sds = ((torch.abs(C12_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_12.squeeze()[
248
+ :self.n_fmap,
249
+ :self.n_fmap]) ** 2).mean()
250
+ l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
251
+ :self.n_fmap,
252
+ :self.n_fmap]) ** 2).mean()
253
  loss = torch.as_tensor(0.).float().to(self.device)
254
  if self.cfg.loss.get("ortho", 0) > 0:
255
+ loss += self.cfg.loss.get("ortho", 0) * l_ortho
256
  if self.cfg.loss.get("bij", 0) > 0:
257
+ loss += self.cfg.loss.get("bij", 0) * l_bij
258
  if self.cfg.loss.get("lap", 0) > 0:
259
+ loss += self.cfg.loss.get("lap", 0) * l_lap
260
  if self.cfg.loss.get("cycle", 0) > 0:
261
+ loss += self.cfg.loss.get("cycle", 0) * l_cycle
262
  if self.cfg.loss.get("mse_rec", 0) > 0:
263
+ loss += self.cfg.loss.get("mse_rec", 0) * l_mse
264
  if self.cfg.loss.get("prism_rec", 0) > 0:
265
+ loss += self.cfg.loss.get("prism_rec", 0) * l_prism
266
  if self.cfg.loss.get("sds", 0) > 0 and self.fmap_cfg.get("use_diff", False):
267
  loss += self.cfg.loss.get("sds", 0) * l_sds
268
  if self.cfg.loss.get("proper", 0) > 0 and self.fmap_cfg.get("use_diff", False):
269
  loss += self.cfg.loss.get("proper", 0) * l_proper
270
+
271
  loss.backward()
272
  self.optim.step()
273
  self.optim.zero_grad()
 
282
  indKNN_new_init, _ = extract_p2p_torch_fmap(C12_pred_init, evecs1, evecs2)
283
  indKNN_new, _ = extract_p2p_torch_fmap(C12_new, evecs1, evecs2)
284
  return C12_new, indKNN_new, indKNN_new_init, snk_rec, loss_save
 
 
285
 
286
  def match(self, pair_batch, output_pair, geod_path, refine=True, eval=False):
287
+ shape_dict, _, target_dict, _, target_normals, mapinfo = pair_batch
288
  shape_dict_device = convert_dict(shape_dict, self.device)
289
  target_dict_device = convert_dict(target_dict, self.device)
290
  print(shape_dict_device["vertices"].device)
291
  os.makedirs(output_pair, exist_ok=True)
292
 
 
293
  if self.cfg["optimize"]:
294
+ C12_new, p2p, p2p_init, snk_rec, loss_save = self.optimize(shape_dict_device, target_dict_device,
295
+ target_normals.to(self.device))
296
  np.save(os.path.join(output_pair, "p2p_init.npy"), p2p_init)
297
  np.save(os.path.join(output_pair, "losses.npy"), loss_save)
298
  else:
 
301
  np.save(os.path.join(output_pair, "fmap.npy"), C12_new.detach().squeeze().cpu().numpy())
302
  np.save(os.path.join(output_pair, "p2p.npy"), p2p)
303
  if snk_rec is not None:
304
+ save_ply(os.path.join(output_pair, "rec.ply"), snk_rec.detach().squeeze().cpu().numpy(),
305
+ target_dict["faces"])
306
 
307
  if refine:
308
  evecs1, evecs2 = shape_dict_device["evecs"], target_dict_device["evecs"]
309
  evecs_2trans = evecs2.t() @ torch.diag(target_dict_device["mass"])
310
+ new_FM = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_new.squeeze(), 128) # , step=10)
311
  p2p_refined_zo, _ = extract_p2p_torch_fmap(new_FM, evecs1, evecs2)
312
  np.save(os.path.join(output_pair, "p2p_zo.npy"), p2p)
313
  if eval:
 
315
  mat_loaded = scipy.io.loadmat(os.path.join(geod_path, file_i + ".mat"))
316
  A_geod, sqrt_area = mat_loaded['geod_dist'], np.sqrt(mat_loaded['areas_f'].sum())
317
  _, dist = accuracy(p2p[vts_2], vts_1, A_geod,
318
+ sqrt_area=sqrt_area,
319
+ return_all=True)
320
  if refine:
321
  _, dist_zo = accuracy(p2p_refined_zo[vts_2], vts_1, A_geod,
322
+ sqrt_area=sqrt_area,
323
+ return_all=True)
324
  np.savetxt(os.path.join(output_pair, "dists.txt"), (dist.mean(), dist_zo.mean()))
325
  return p2p, p2p_refined_zo, loss_save, dist.mean(), dist_zo.mean()
326
  return p2p, loss_save, dist.mean()
327
  return p2p, loss_save
 
 
 
328
 
329
  def _dataset_epoch(self, dataset, name_dataset, save_dir, data_dir):
330
  os.makedirs(save_dir, exist_ok=True)
 
345
  print("Pair: " + shape_dict['name'] + " " + target_dict['name'])
346
  name_exp = os.path.join(save_dir, shape_dict['name'], target_dict['name'])
347
  if self.cfg.get("refine", False):
348
+ _, _, _, dist, dist_zo = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset),
349
+ eval=True, refine=True)
350
  else:
351
+ _, _, dist = self.match(batch, name_exp, os.path.join(data_dir, "geomats", name_dataset), eval=True,
352
+ refine=False)
353
  delta = datetime.now() - t1
354
  fm_delta = str_delta(delta)
355
+ remains = ((delta / (id_pair + 1)) * num_pairs) - delta
356
  fm_remains = str_delta(remains)
357
  all_accs.append(dist)
358
  accs_mean = np.mean(all_accs)
359
  if self.cfg.get("refine", False):
360
  all_accs_zo.append(dist_zo)
361
  accs_zo = np.mean(all_accs_zo)
362
+ print(
363
+ f"error: {dist}, zo: {dist_zo}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, mean zo: {accs_zo}, full time: {fm_delta}, remains: {fm_remains}")
364
  else:
365
+ print(
366
+ f"error: {dist}, element {id_pair}/{num_pairs}, mean accuracy: {accs_mean}, full time: {fm_delta}, remains: {fm_remains}")
367
  id_pair += 1
368
  if self.cfg.get("refine", False):
369
  print(f"mean error : {np.mean(all_accs)}, mean error refined: {np.mean(all_accs_zo)}")
370
  else:
371
  print(f"mean error : {np.mean(all_accs)}")
372
+ sys.stdout = sys.__stdout__
373
 
374
  def load_data(self, file, num_evecs=200, make_cache=False, factor=None):
375
  name = os.path.basename(os.path.splitext(file)[0])
 
383
  data_dict = load_operators(cache_path)
384
  data_dict['name'] = name
385
  data_dict_torch = convert_dict(data_dict, self.device)
386
+ # batchify_dict(data_dict_torch)
387
  return data_dict_torch, area_shape
388
 
389
  def match_files(self, file_shape, file_target):
390
  batch_shape, _ = self.load_data(file_shape)
391
+ batch_target, _ = self.load_data(file_target)
392
  target_surf = Surface(filename=file_target)
393
+ target_normals = torch.from_numpy(
394
+ target_surf.surfel / np.linalg.norm(target_surf.surfel, axis=-1, keepdims=True)).float().to(self.device)
395
  batch = batch_shape, None, batch_target, target_normals, None, None
396
  output_folder = os.path.join(self.cfg.output, batch_shape["name"] + "_" + batch_shape["target"])
397
  p2p, _ = self.match(batch, output_folder, None)
398
  return batch_shape, batch_target, p2p
399
 
400
 
 
 
401
  if __name__ == '__main__':
402
  parser = argparse.ArgumentParser(description="Launch the SDS demo over datasets")
403
  parser.add_argument('--dataset', type=str, default="SCAPE", help='name of the dataset')
404
+ parser.add_argument('--config', type=str, default="config/matching/sds.yaml", help='Config file location')
405
  parser.add_argument('--datadir', type=str, default="data", help='path where datasets are store')
406
  parser.add_argument('--output', type=str, default="results", help="where to store experience results")
407
  args = parser.parse_args()
 
423
  dset = pair_cls(corr_dir, 'test', dset_shape, rotate=cfg.get("rotate", False))
424
  exp_time = time.strftime('%y-%m-%d_%H-%M-%S')
425
  output_logs = os.path.join(args.output, name_data_geo, exp_time)
426
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
427
+ print(f"Using device: {device}")
428
+ matcher = Matcher(cfg, device)
429
  matcher._dataset_epoch(dset, name_data_geo, output_logs, args.datadir)