Spaces:
Runtime error
Runtime error
| # Copyright 2021 by Haozhe Wu, Tsinghua University, Department of Computer Science and Technology. | |
| # All rights reserved. | |
| # This file is part of the pytorch-nicp, | |
| # and is released under the "MIT License Agreement". Please see the LICENSE | |
| # file that should have been included as part of this package. | |
| import torch | |
| import torch.nn as nn | |
| import trimesh | |
| from pytorch3d.loss import chamfer_distance | |
| from pytorch3d.structures import Meshes | |
| from tqdm import tqdm | |
| from lib.common.train_util import init_loss | |
| from lib.dataset.mesh_util import update_mesh_shape_prior_losses | |
| # reference: https://github.com/wuhaozhe/pytorch-nicp | |
class LocalAffine(nn.Module):
    """Per-vertex affine deformation with edge-stiffness and rigidity terms.

    Every vertex ``i`` of every batch element ``k`` owns a 3x3 matrix
    ``A[k, i]`` and a 3x1 translation ``b[k, i]``; ``forward`` maps each input
    vertex ``x`` to ``A @ x + b``. Parameters start at identity / zero, so the
    freshly built module is the identity map. The number of points and the
    batch size are fixed at construction time.
    """

    def __init__(self, num_points, batch_size=1, edges=None):
        """
        Args:
            num_points: number of vertices N (constant across the batch).
            batch_size: batch size B (must stay constant).
            edges: LongTensor of shape (E, 2) holding vertex-index pairs. It is
                required by ``stiffness`` (and therefore by ``forward``).
        """
        super().__init__()
        # (B, N, 3, 3) identity matrices.
        self.A = nn.Parameter(
            torch.eye(3).unsqueeze(0).unsqueeze(0).repeat(batch_size, num_points, 1, 1)
        )
        # (B, N, 3, 1) zero translations.
        self.b = nn.Parameter(
            torch.zeros(3).unsqueeze(0).unsqueeze(0).unsqueeze(3).repeat(
                batch_size, num_points, 1, 1
            )
        )
        # A buffer (unlike a plain attribute) is moved by .to()/.cuda() together
        # with A and b. persistent=False keeps it out of state_dict(), so
        # checkpoint keys are the same as when it was a plain attribute.
        self.register_buffer("edges", edges, persistent=False)
        self.num_points = num_points

    def stiffness(self):
        """Compute the edge-smoothness and rigidity terms of the current affines.

        Squared element-wise differences are returned instead of a Frobenius
        norm. The norm's gradient is unbounded when the difference is the zero
        matrix, which is exactly the state at initialization.

        Returns:
            w_diff: (B, E, 3, 4) squared difference of the concatenated
                ``[A | b]`` matrices at the two endpoints of every edge.
            w_rigid: (B, N) squared deviation of ``det(A)`` from 1.

        Raises:
            ValueError: if the module was constructed without ``edges``.
        """
        if self.edges is None:
            raise ValueError("edges cannot be none when calculate stiff")
        affine_weight = torch.cat((self.A, self.b), dim=3)  # (B, N, 3, 4)
        w1 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 0])
        w2 = torch.index_select(affine_weight, dim=1, index=self.edges[:, 1])
        w_diff = (w1 - w2) ** 2
        w_rigid = (torch.linalg.det(self.A) - 1.0) ** 2
        return w_diff, w_rigid

    def forward(self, x):
        """Apply the per-vertex affine transforms.

        Args:
            x: (B, N, 3) vertex positions.

        Returns:
            out_x: (B, N, 3) transformed positions.
            stiffness: (B, E, 3, 4) edge-smoothness term (see ``stiffness``).
            rigid: (B, N) rigidity term (see ``stiffness``).
        """
        # (B, N, 3, 1) so matmul applies each vertex's own 3x3 matrix.
        x = x.unsqueeze(3)
        # Out-of-place squeeze: avoids an in-place op on an autograd intermediate.
        out_x = (torch.matmul(self.A, x) + self.b).squeeze(3)
        stiffness, rigid = self.stiffness()
        return out_x, stiffness, rigid
def trimesh2meshes(mesh):
    """Wrap a trimesh mesh as a one-element pytorch3d ``Meshes`` batch.

    Vertices are converted to float32 and faces to int64. Each tensor gets a
    leading batch dimension of size 1 before being handed to ``Meshes``.
    """
    vertices = torch.from_numpy(mesh.vertices).float()
    triangles = torch.from_numpy(mesh.faces).long()
    return Meshes(verts=vertices[None], faces=triangles[None])
def _weighted_loss_sum(losses, device):
    """Sum the active weighted loss terms and describe them for the progress bar.

    A term contributes only when its weight is positive and its current value
    is non-zero. Returns ``(total, description)``.
    """
    # Start from a grad-requiring scalar so backward() still works even when no
    # term is active in an iteration.
    total = torch.tensor(0.0, requires_grad=True, device=device)
    desc = "Register SMPL-X -> d-BiNI -- "
    for name, term in losses.items():
        if term["weight"] > 0.0 and term["value"] != 0.0:
            weighted = term["value"] * term["weight"]
            total = total + weighted
            desc += f"{name}:{weighted:.3f} | "
    return total, desc


def register(target_mesh, src_mesh, device, verbose=True, iters=100):
    """Non-rigidly register ``src_mesh`` onto ``target_mesh`` with per-vertex affines.

    A ``LocalAffine`` model (one affine transform per source vertex) is
    optimized with Adam for ``iters`` steps. The objective is a weighted sum of:
    the Chamfer distance to the target vertices, the stiffness and rigidity
    terms of ``LocalAffine``, and the shape-prior losses written by
    ``update_mesh_shape_prior_losses``.

    Args:
        target_mesh: trimesh mesh to fit (converted with ``trimesh2meshes``).
        src_mesh: pytorch3d ``Meshes`` whose vertices are deformed.
        device: torch device the optimization runs on.
        verbose: show a tqdm progress bar and print the final loss summary.
        iters: number of optimization steps (defaults to the original 100).

    Returns:
        ``trimesh.Trimesh`` with the deformed source vertices and the source
        faces. It is built with ``process=False`` so vertex order is preserved.
    """
    tgt_mesh = trimesh2meshes(target_mesh).to(device)
    src_verts = src_mesh.verts_padded().clone()
    batch_size, num_points = src_verts.shape[0], src_verts.shape[1]

    local_affine_model = LocalAffine(
        num_points, batch_size, src_mesh.edges_packed()
    ).to(device)

    optimizer_cloth = torch.optim.Adam(
        [{"params": local_affine_model.parameters()}], lr=1e-2, amsgrad=True
    )
    # `verbose` is left at its default: passing it is deprecated since PyTorch 2.2.
    scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer_cloth, mode="min", factor=0.1, min_lr=1e-5, patience=5
    )

    # NOTE(review): assumes init_loss() returns a dict that contains the
    # "cloth", "stiff" and "rigid" keys, each a dict with "weight" and "value".
    losses = init_loss()

    loop_cloth = tqdm(range(iters)) if verbose else range(iters)
    pbar_desc = ""
    for _ in loop_cloth:
        optimizer_cloth.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(x=src_verts)
        src_mesh = src_mesh.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency
        update_mesh_shape_prior_losses(src_mesh, losses)
        losses["cloth"]["value"] = chamfer_distance(
            x=src_mesh.verts_padded(), y=tgt_mesh.verts_padded()
        )[0]
        losses["stiff"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        cloth_loss, pbar_desc = _weighted_loss_sum(losses, device)
        if verbose:
            pbar_desc += f"TOTAL: {cloth_loss:.3f}"
            loop_cloth.set_description(pbar_desc)

        # retain_graph=True is kept from the original. src_verts is a clone
        # (not a detach) of the input vertices, so the graph may reach upstream
        # tensors shared across iterations -- confirm before removing.
        cloth_loss.backward(retain_graph=True)
        optimizer_cloth.step()
        scheduler_cloth.step(cloth_loss)

    if verbose:
        print(pbar_desc)

    return trimesh.Trimesh(
        src_mesh.verts_packed().detach().squeeze(0).cpu(),
        src_mesh.faces_packed().detach().squeeze(0).cpu(),
        process=False,
        maintains_order=True,
    )