Spaces:
Running
on
Zero
Running
on
Zero
daidedou
committed on
Commit
·
20ce39e
1
Parent(s):
38ea30e
Remove pykeops for space ZeroGPU demo.
Browse files
- app.py +6 -2
- zero_shot.py +16 -15
app.py
CHANGED
|
@@ -30,6 +30,7 @@ from utils.meshplot import visu_pts
|
|
| 30 |
from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
|
| 31 |
import torch
|
| 32 |
import argparse
|
|
|
|
| 33 |
# -----------------------------
|
| 34 |
# Utils
|
| 35 |
# -----------------------------
|
|
@@ -173,8 +174,11 @@ def init_clicked(mesh1_path, mesh2_path,
|
|
| 173 |
matcher._init()
|
| 174 |
global datadicts
|
| 175 |
datadicts = Datadicts(mesh1_path, mesh2_path)
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
| 178 |
C12_pred, C12_obj, mask_12 = C12_pred_init
|
| 179 |
p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
|
| 180 |
return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
|
|
|
|
| 30 |
from utils.torch_fmap import extract_p2p_torch_fmap, torch_zoomout
|
| 31 |
import torch
|
| 32 |
import argparse
|
| 33 |
+
from utils.utils_func import convert_dict
|
| 34 |
# -----------------------------
|
| 35 |
# Utils
|
| 36 |
# -----------------------------
|
|
|
|
| 174 |
matcher._init()
|
| 175 |
global datadicts
|
| 176 |
datadicts = Datadicts(mesh1_path, mesh2_path)
|
| 177 |
+
shape_dict, target_dict = convert_dict(datadicts.shape_dict, 'cuda'), convert_dict(datadicts.target_dict, 'cuda')
|
| 178 |
+
fmap_model_cuda = matcher.fmap_model.cuda()
|
| 179 |
+
diff_model_cuda = matcher.diffusion_model
|
| 180 |
+
diff_model_cuda.net.cuda()
|
| 181 |
+
C12_pred_init, C21_pred_init, feat1, feat2, evecs_trans1, evecs_trans2 = fmap_model_cuda({"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda, scale=matcher.fmap_cfg.diffusion.time)
|
| 182 |
C12_pred, C12_obj, mask_12 = C12_pred_init
|
| 183 |
p2p_init, _ = extract_p2p_torch_fmap(C12_obj, datadicts.shape_dict["evecs"], datadicts.target_dict["evecs"])
|
| 184 |
return build_outputs(datadicts.shape_surf, datadicts.target_surf, datadicts.cmap1, p2p_init, tag="init")
|
zero_shot.py
CHANGED
|
@@ -100,11 +100,9 @@ class Matcher(object):
|
|
| 100 |
|
| 101 |
def __init__(self, cfg):
|
| 102 |
self.cfg = cfg
|
| 103 |
-
self.device = torch.device(f'cuda:{cfg["gpu"]}' if torch.cuda.is_available() else 'cpu')
|
| 104 |
-
print(f"Using device: {self.device}")
|
| 105 |
self.diffusion_model = None
|
| 106 |
if self.cfg.get("sds", False):
|
| 107 |
-
self.diffusion_model = DiffModel(cfg["sds_conf"],
|
| 108 |
self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
|
| 109 |
self.n_loop = 0
|
| 110 |
if self.cfg.get("optimize", False):
|
|
@@ -124,10 +122,10 @@ class Matcher(object):
|
|
| 124 |
|
| 125 |
def _init(self):
|
| 126 |
cfg = self.cfg
|
| 127 |
-
self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).
|
| 128 |
if self.snk:
|
| 129 |
-
self.encoder = Encoder().
|
| 130 |
-
self.decoder = PrismDecoder(dim_in=515).
|
| 131 |
self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
|
| 132 |
self.soft_p2p = True
|
| 133 |
params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
|
|
@@ -135,13 +133,15 @@ class Matcher(object):
|
|
| 135 |
else:
|
| 136 |
params_to_opt = self.fmap_model.parameters()
|
| 137 |
self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
|
| 138 |
-
self.eye = torch.eye(self.n_fmap).float().
|
| 139 |
self.eye.requires_grad = False
|
| 140 |
|
| 141 |
def fmap(self, shape_dict, target_dict):
|
|
|
|
|
|
|
| 142 |
if self.fmap_cfg.get("use_diff", False):
|
| 143 |
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
| 144 |
-
{"shape1": shape_dict, "shape2": target_dict}, diff_model=
|
| 145 |
scale=self.fmap_cfg.diffusion.time)
|
| 146 |
C12_pred, C12_obj, mask_12 = C12_pred
|
| 147 |
C21_pred, C21_obj, mask_21 = C21_pred
|
|
@@ -163,6 +163,8 @@ class Matcher(object):
|
|
| 163 |
|
| 164 |
def optimize(self, shape_dict, target_dict, target_normals):
|
| 165 |
self._init()
|
|
|
|
|
|
|
| 166 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 167 |
C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
|
| 168 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
|
@@ -203,8 +205,7 @@ class Matcher(object):
|
|
| 203 |
l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
|
| 204 |
shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
|
| 205 |
|
| 206 |
-
l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().
|
| 207 |
-
self.device), torch.as_tensor(0.).float().to(self.device)
|
| 208 |
if self.snk:
|
| 209 |
# Latent vector
|
| 210 |
latents = self.encoder(shape_dict)
|
|
@@ -217,7 +218,7 @@ class Matcher(object):
|
|
| 217 |
l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
|
| 218 |
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
|
| 219 |
dim=-1).mean()
|
| 220 |
-
l_sds, l_proper = torch.as_tensor(0.).float().
|
| 221 |
if self.fmap_cfg.get("use_diff", False):
|
| 222 |
if self.fmap_cfg.diffusion.get("abs", False):
|
| 223 |
C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
|
|
@@ -225,7 +226,7 @@ class Matcher(object):
|
|
| 225 |
C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
|
| 226 |
grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
|
| 227 |
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 228 |
-
scale_noise=self.fmap_cfg.diffusion.time, device=
|
| 229 |
with torch.no_grad():
|
| 230 |
denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
|
| 231 |
targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
|
|
@@ -233,9 +234,9 @@ class Matcher(object):
|
|
| 233 |
l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
|
| 234 |
:self.n_fmap]) ** 2).mean()
|
| 235 |
|
| 236 |
-
grad_21, _ = guidance_grad(C21_in,
|
| 237 |
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 238 |
-
scale_noise=self.fmap_cfg.diffusion.time, device=
|
| 239 |
# denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 240 |
with torch.no_grad():
|
| 241 |
denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
|
@@ -251,7 +252,7 @@ class Matcher(object):
|
|
| 251 |
l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
|
| 252 |
:self.n_fmap,
|
| 253 |
:self.n_fmap]) ** 2).mean()
|
| 254 |
-
loss = torch.as_tensor(0.).float().to(
|
| 255 |
if self.cfg.loss.get("ortho", 0) > 0:
|
| 256 |
loss += self.cfg.loss.get("ortho", 0) * l_ortho
|
| 257 |
if self.cfg.loss.get("bij", 0) > 0:
|
|
|
|
| 100 |
|
| 101 |
def __init__(self, cfg):
|
| 102 |
self.cfg = cfg
|
|
|
|
|
|
|
| 103 |
self.diffusion_model = None
|
| 104 |
if self.cfg.get("sds", False):
|
| 105 |
+
self.diffusion_model = DiffModel(cfg["sds_conf"], "cpu")
|
| 106 |
self.n_fmap = self.cfg["deepfeat_conf"]["fmap"]["n_fmap"]
|
| 107 |
self.n_loop = 0
|
| 108 |
if self.cfg.get("optimize", False):
|
|
|
|
| 122 |
|
| 123 |
def _init(self):
|
| 124 |
cfg = self.cfg
|
| 125 |
+
self.fmap_model = DFMNet(self.cfg["deepfeat_conf"]["fmap"]).cuda()
|
| 126 |
if self.snk:
|
| 127 |
+
self.encoder = Encoder().cuda()
|
| 128 |
+
self.decoder = PrismDecoder(dim_in=515).cuda()
|
| 129 |
self.loss_prism = PrismRegularizationLoss(primo_h=0.02)
|
| 130 |
self.soft_p2p = True
|
| 131 |
params_to_opt = list(self.fmap_model.parameters()) + list(self.encoder.parameters()) + list(
|
|
|
|
| 133 |
else:
|
| 134 |
params_to_opt = self.fmap_model.parameters()
|
| 135 |
self.optim = torch.optim.Adam(params_to_opt, lr=0.001, betas=(0.9, 0.99))
|
| 136 |
+
self.eye = torch.eye(self.n_fmap).float().cuda()
|
| 137 |
self.eye.requires_grad = False
|
| 138 |
|
| 139 |
def fmap(self, shape_dict, target_dict):
|
| 140 |
+
diff_model_cuda = self.diffusion_model
|
| 141 |
+
diff_model_cuda.net.cuda()
|
| 142 |
if self.fmap_cfg.get("use_diff", False):
|
| 143 |
C12_pred, C21_pred, feat1, feat2, evecs_trans1, evecs_trans2 = self.fmap_model(
|
| 144 |
+
{"shape1": shape_dict, "shape2": target_dict}, diff_model=diff_model_cuda,
|
| 145 |
scale=self.fmap_cfg.diffusion.time)
|
| 146 |
C12_pred, C12_obj, mask_12 = C12_pred
|
| 147 |
C21_pred, C21_obj, mask_21 = C21_pred
|
|
|
|
| 163 |
|
| 164 |
def optimize(self, shape_dict, target_dict, target_normals):
|
| 165 |
self._init()
|
| 166 |
+
diff_model_cuda = self.diffusion_model
|
| 167 |
+
diff_model_cuda.net.cuda()
|
| 168 |
evecs1, evecs2 = shape_dict["evecs"], target_dict["evecs"]
|
| 169 |
C12_pred_init, _, _, _, _, _, evecs_trans1, evecs_trans2, _, _ = self.fmap(shape_dict, target_dict)
|
| 170 |
evecs_2trans = evecs2.t() @ torch.diag(target_dict["mass"])
|
|
|
|
| 205 |
l_lap += ((C21_new @ torch.diag(target_dict["evals"][:self.n_fmap]) - torch.diag(
|
| 206 |
shape_dict["evals"][:self.n_fmap]) @ C21_new) ** 2).mean()
|
| 207 |
|
| 208 |
+
l_cycle, l_prism, l_mse = torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda()
|
|
|
|
| 209 |
if self.snk:
|
| 210 |
# Latent vector
|
| 211 |
latents = self.encoder(shape_dict)
|
|
|
|
| 218 |
l_mse = ((soft_p2p_21 @ shape_dict["vertices"] - snk_rec) ** 2).sum(dim=-1).mean()
|
| 219 |
l_cycle = ((soft_p2p_12 @ (soft_p2p_21 @ shape_dict["vertices"]) - shape_dict["vertices"]) ** 2).sum(
|
| 220 |
dim=-1).mean()
|
| 221 |
+
l_sds, l_proper = torch.as_tensor(0.).float().cuda(), torch.as_tensor(0.).float().cuda()
|
| 222 |
if self.fmap_cfg.get("use_diff", False):
|
| 223 |
if self.fmap_cfg.diffusion.get("abs", False):
|
| 224 |
C12_in, C21_in = torch.abs(C12_pred).squeeze(), torch.abs(C21_pred).squeeze()
|
|
|
|
| 226 |
C12_in, C21_in = C12_pred.squeeze(), C21_pred.squeeze()
|
| 227 |
grad_12, _ = guidance_grad(C12_in, self.diffusion_model.net, grad_scale=1,
|
| 228 |
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 229 |
+
scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
|
| 230 |
with torch.no_grad():
|
| 231 |
denoised_12 = C12_pred - self.optim.param_groups[0]['lr'] * grad_12
|
| 232 |
targets_12 = torch_zoomout(evecs1, evecs2, evecs_2trans, C12_obj.squeeze(), self.cfg.sds_conf.zoomout)
|
|
|
|
| 234 |
l_proper_12 = ((C12_pred.squeeze()[:self.n_fmap, :self.n_fmap] - targets_12.squeeze()[:self.n_fmap,
|
| 235 |
:self.n_fmap]) ** 2).mean()
|
| 236 |
|
| 237 |
+
grad_21, _ = guidance_grad(C21_in, diff_model_cuda.net, grad_scale=1,
|
| 238 |
batch_size=self.fmap_cfg.diffusion.batch_sds,
|
| 239 |
+
scale_noise=self.fmap_cfg.diffusion.time, device="cuda")
|
| 240 |
# denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
| 241 |
with torch.no_grad():
|
| 242 |
denoised_21 = C21_pred - self.optim.param_groups[0]['lr'] * grad_21
|
|
|
|
| 252 |
l_sds += ((torch.abs(C21_pred).squeeze()[:self.n_fmap, :self.n_fmap] - denoised_21.squeeze()[
|
| 253 |
:self.n_fmap,
|
| 254 |
:self.n_fmap]) ** 2).mean()
|
| 255 |
+
loss = torch.as_tensor(0.).float().to("cuda")
|
| 256 |
if self.cfg.loss.get("ortho", 0) > 0:
|
| 257 |
loss += self.cfg.loss.get("ortho", 0) * l_ortho
|
| 258 |
if self.cfg.loss.get("bij", 0) > 0:
|