daidedou commited on
Commit
20ce39e
·
1 Parent(s): 38ea30e

Remove pykeops for space ZeroGPU demo.

Browse files
Files changed (2) hide show
  1. app.py +6 -2
  2. zero_shot.py +16 -15
app.py CHANGED
@@ -30,6 +30,7 @@ from utils.meshplot import visu_pts
30
  from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
31
  import torch
32
  import argparse
 
33
  # -----------------------------
34
  # Utils
35
  # -----------------------------
@@ -173,8 +174,11 @@ def init_clicked(mesh1_path, mesh2_path,
173
  matcher._init()
174
  global datadicts
175
  datadicts = Datadicts(mesh1_path, mesh2_path)
176
-
177
- C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = matcher.fmap_model({"shape1": datadicts.shape_dict, "shape2": datadicts.target_dict}, diff_model=matcher.diffusion_model, scale=matcher.fmap_cfg.diffusion.time)
 
 
 
178
  C12_pred, C12_obj, mask_12 = C12_pred_init
179
  p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
180
  return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
 
30
  from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
31
  import torch
32
  import argparse
33
+ from utils.utils_func import convert_dict
34
  # -----------------------------
35
  # Utils
36
  # -----------------------------
 
174
  matcher._init()
175
  global datadicts
176
  datadicts = Datadicts(mesh1_path, mesh2_path)
177
+ shape_dict, target_dict = convert_dict(datadicts.shape_dict, 'cuda'), convert_dict(datadicts.target_dict, 'cuda')
178
+ fmap_model_cuda = matcher.fmap_model.cuda()
179
+ diff_model_cuda = matcher.diffusion_model
180
+ diff_model_cuda.net.cuda()
181
+ C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = fmap_model_cuda({"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda, scale=matcher.fmap_cfg.diffusion.time)
182
  C12_pred, C12_obj, mask_12 = C12_pred_init
183
  p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
184
  return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
zero_shot.py CHANGED
@@ -100,11 +100,9 @@ class Matcher(object):
100
 
101
  def __init__(self, cfg):
102
  self.cfg = cfg
103
- self.device = torch.device(f'cuda:{cfg["gpu"]}' if torch.cuda.is_available() else 'cpu')
104
- print(f"Using device: {self.device}")
105
  self.diffusion_model = None
106
  if self.cfg.get("sds", False):
107
- self.diffusion_model = DiffModel(cfg["sds_conf"], self.device)
108
  self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
109
  self.n_loop = 0
110
  if self.cfg.get("optimize", False):
@@ -124,10 +122,10 @@ class Matcher(object):
124
 
125
  def _init(self):
126
  cfg = self.cfg
127
- self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).to(self.device)
128
  if self.snk:
129
- self.encoder = Encoder().to(self.device)
130
- self.decoder = PrismDecoder(dim_in=515).to(self.device)
131
  self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
132
  self.soft_p2p = True
133
  params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
@@ -135,13 +133,15 @@ class Matcher(object):
135
  else:
136
  params_to_opt = self.fmap_model.parameters()
137
  self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
138
- self.eye = torch.eye(self.n_fmap).float().to(self.device)
139
  self.eye.requires_grad = False
140
 
141
  def fmap(self, shape_dict, target_dict):
 
 
142
  if self.fmap_cfg.get("use_diff", False):
143
  C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
144
- {"shape1": shape_dict, "shape2": target_dict}, diff_model=self.diffusion_model,
145
  scale=self.fmap_cfg.diffusion.time)
146
  C12_pred, C12_obj, mask_12 = C12_pred
147
  C21_pred, C21_obj, mask_21 = C21_pred
@@ -163,6 +163,8 @@ class Matcher(object):
163
 
164
  def optimize(self, shape_dict, target_dict, target_normals):
165
  self._init()
 
 
166
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
167
  C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
168
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
@@ -203,8 +205,7 @@ class Matcher(object):
203
  l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
204
  shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
205
 
206
- l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(
207
- self.device), torch.as_tensor(0.).float().to(self.device)
208
  if self.snk:
209
  # Latent vector
210
  latents = self.encoder(shape_dict)
@@ -217,7 +218,7 @@ class Matcher(object):
217
  l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
218
  l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
219
  dim=-1).mean()
220
- l_sds, l_proper = torch.as_tensor(0.).float().to(self.device), torch.as_tensor(0.).float().to(self.device)
221
  if self.fmap_cfg.get("use_diff", False):
222
  if self.fmap_cfg.diffusion.get("abs", False):
223
  C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
@@ -225,7 +226,7 @@ class Matcher(object):
225
  C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
226
  grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
227
  batch_size=self.fmap_cfg.diffusion.batch_sds,
228
- scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
229
  with torch.no_grad():
230
  denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
231
  targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
@@ -233,9 +234,9 @@ class Matcher(object):
233
  l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
234
  :self.n_fmap]) ** 2).mean()
235
 
236
- grad_21, _ = guidance_grad(C21_in, self.diffusion_model.net, grad_scale=1,
237
  batch_size=self.fmap_cfg.diffusion.batch_sds,
238
- scale_noise=self.fmap_cfg.diffusion.time, device=self.device)
239
  # denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
240
  with torch.no_grad():
241
  denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
@@ -251,7 +252,7 @@ class Matcher(object):
251
  l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
252
  :self.n_fmap,
253
  :self.n_fmap]) ** 2).mean()
254
- loss = torch.as_tensor(0.).float().to(self.device)
255
  if self.cfg.loss.get("ortho", 0) > 0:
256
  loss += self.cfg.loss.get("ortho", 0) * l_ortho
257
  if self.cfg.loss.get("bij", 0) > 0:
 
100
 
101
  def __init__(self, cfg):
102
  self.cfg = cfg
 
 
103
  self.diffusion_model = None
104
  if self.cfg.get("sds", False):
105
+ self.diffusion_model = DiffModel(cfg["sds_conf"], "cpu")
106
  self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
107
  self.n_loop = 0
108
  if self.cfg.get("optimize", False):
 
122
 
123
  def _init(self):
124
  cfg = self.cfg
125
+ self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).cuda()
126
  if self.snk:
127
+ self.encoder = Encoder().cuda()
128
+ self.decoder = PrismDecoder(dim_in=515).cuda()
129
  self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
130
  self.soft_p2p = True
131
  params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
 
133
  else:
134
  params_to_opt = self.fmap_model.parameters()
135
  self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
136
+ self.eye = torch.eye(self.n_fmap).float().cuda()
137
  self.eye.requires_grad = False
138
 
139
  def fmap(self, shape_dict, target_dict):
140
+ diff_model_cuda = self.diffusion_model
141
+ diff_model_cuda.net.cuda()
142
  if self.fmap_cfg.get("use_diff", False):
143
  C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
144
+ {"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda,
145
  scale=self.fmap_cfg.diffusion.time)
146
  C12_pred, C12_obj, mask_12 = C12_pred
147
  C21_pred, C21_obj, mask_21 = C21_pred
 
163
 
164
  def optimize(self, shape_dict, target_dict, target_normals):
165
  self._init()
166
+ diff_model_cuda = self.diffusion_model
167
+ diff_model_cuda.net.cuda()
168
  evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
169
  C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
170
  evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
 
205
  l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
206
  shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
207
 
208
+ l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda()
 
209
  if self.snk:
210
  # Latent vector
211
  latents = self.encoder(shape_dict)
 
218
  l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
219
  l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
220
  dim=-1).mean()
221
+ l_sds, l_proper = torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda()
222
  if self.fmap_cfg.get("use_diff", False):
223
  if self.fmap_cfg.diffusion.get("abs", False):
224
  C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
 
226
  C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
227
  grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
228
  batch_size=self.fmap_cfg.diffusion.batch_sds,
229
+ scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
230
  with torch.no_grad():
231
  denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
232
  targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
 
234
  l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
235
  :self.n_fmap]) ** 2).mean()
236
 
237
+ grad_21, _ = guidance_grad(C21_in, diff_model_cuda.net, grad_scale=1,
238
  batch_size=self.fmap_cfg.diffusion.batch_sds,
239
+ scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
240
  # denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
241
  with torch.no_grad():
242
  denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
 
252
  l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
253
  :self.n_fmap,
254
  :self.n_fmap]) ** 2).mean()
255
+ loss = torch.as_tensor(0.).float().to("cuda")
256
  if self.cfg.loss.get("ortho", 0) > 0:
257
  loss += self.cfg.loss.get("ortho", 0) * l_ortho
258
  if self.cfg.loss.get("bij", 0) > 0: