Spaces:
Configuration error
Configuration error
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
def ppts_to_pts(ppts, bw, A):
    """Transform points from the pose space to the zero space.

    The per-joint transforms ``A`` are linearly blended with the skinning
    weights ``bw``. The inverse of each point's blended transform is then
    applied to that point.

    NOTE(review): this module defines ``ppts_to_pts`` again further down with
    the same logic. That later definition shadows this one at import time,
    so this copy is dead code and could be removed.

    Args:
        ppts: posed points, (batch, N, 3).
        bw: blend weights, (batch, num_joints, N). The joint count is taken
            from this tensor (24 for the original SMPL usage).
        A: per-joint 4x4 transforms. Must hold batch * num_joints * 16
            values; it is reshaped to (batch, num_joints, 16).

    Returns:
        Points in the zero space, (batch, N, 3).
    """
    batch_size = ppts.shape[0]
    bw = bw.permute(0, 2, 1)  # -> (batch, N, num_joints)
    num_joints = bw.shape[-1]
    # Blend the flattened 4x4 matrices per point, then restore matrix form.
    # reshape (not view) so a non-contiguous A is accepted as well.
    A = torch.bmm(bw, A.reshape(batch_size, num_joints, -1))
    A = A.reshape(batch_size, -1, 4, 4)
    # Undo the translation first, then the blended linear part.
    pts = ppts - A[..., :3, 3]
    R_inv = torch.inverse(A[..., :3, :3])
    # Batched matrix-vector product R_inv @ pts for every point.
    pts = torch.sum(R_inv * pts[:, :, None], dim=3)
    return pts
def grid_sample_blend_weights(grid_coords, bw):
    """Sample the volume ``bw`` at the normalized coordinates ``grid_coords``.

    ``grid_coords`` gains two singleton dims so it becomes the grid shape
    ``F.grid_sample`` expects. For a 5-D volume this presumably means
    ``grid_coords`` is (batch, N, 3) with values in [-1, 1], and ``bw`` is
    channels-first (batch, C, D, H, W) (TODO: confirm with callers).
    Out-of-range coordinates are clamped to the border, and
    ``align_corners=True`` maps -1/+1 onto the first/last voxel centers.

    Returns the sampled values with the two singleton grid dims dropped,
    i.e. (batch, C, N) under the shapes assumed above.
    """
    sample_grid = grid_coords.unsqueeze(1).unsqueeze(1)
    sampled = F.grid_sample(bw,
                            sample_grid,
                            padding_mode='border',
                            align_corners=True)
    # Strip the two singleton output dims that came from the grid's shape.
    return sampled.select(2, 0).select(2, 0)
def bounds_grid_sample_blend_weights(pts, bw, bounds):
    """Grid-sample a channels-last blend-weight volume at points in ``bounds``.

    Args:
        pts: query points, (batch, N, 3) as required by the 5-D grid below.
        bw: blend-weight volume with channels last. It is permuted to
            channels-first (batch, C, D, H, W) for ``F.grid_sample``.
        bounds: per-batch box. ``bounds[:, 0]`` is the min corner and
            ``bounds[:, 1]`` the max corner, presumably (batch, 2, 3)
            (TODO: confirm).

    Returns:
        Sampled weights, (batch, C, N). Points outside the box are clamped
        to the border values.
    """
    # No in-place ops follow, so ``pts`` is not copied.
    min_xyz = bounds[:, 0]
    max_xyz = bounds[:, 1]
    extent = max_xyz[:, None] - min_xyz[:, None]
    # Map points into [-1, 1], the normalized range grid_sample expects.
    grid_coords = (pts - min_xyz[:, None]) / extent
    grid_coords = grid_coords * 2 - 1
    # grid_sample reads the last grid dim as (x, y, z) -> (W, H, D). Reversing
    # the order makes a point's x coordinate index the D axis, i.e. the volume
    # is treated as indexed [x, y, z].
    grid_coords = grid_coords[..., [2, 1, 0]]
    volume = bw.permute(0, 4, 1, 2, 3)
    sampled = F.grid_sample(volume,
                            grid_coords[:, None, None],
                            padding_mode='border',
                            align_corners=True)
    # (batch, C, 1, 1, N) -> (batch, C, N)
    return sampled[:, :, 0, 0]
def grid_sample_A_blend_weights(nf_grid_coords, bw):
    """Sample each joint's blend-weight channel at that joint's own coordinates.

    nf_grid_coords: batch_size x N_samples x num_joints x 3, normalized for
        ``F.grid_sample`` (border padding, ``align_corners=True``)
    bw: batch_size x num_joints x D x H x W (originally 24 x 64 x 64 x 64)

    The joint count is taken from ``bw.shape[1]`` instead of being fixed at
    24, so other skeletons work too.

    Returns batch_size x num_joints x N_samples. Channel ``i`` is ``bw[:, i]``
    sampled at ``nf_grid_coords[:, :, i]``.
    """
    num_joints = bw.shape[1]
    per_joint = []
    for i in range(num_joints):
        # (batch, N, 3) -> (batch, 1, 1, N, 3): a 5-D grid for volume sampling.
        coords_i = nf_grid_coords[:, :, i][:, None, None]
        sampled = F.grid_sample(bw[:, i:i + 1],
                                coords_i,
                                padding_mode='border',
                                align_corners=True)
        # (batch, 1, 1, 1, N) -> (batch, 1, N)
        per_joint.append(sampled[:, :, 0, 0])
    return torch.cat(per_joint, dim=1)
def ppts_to_pts(pts, bw, A):
    """Transform points from the pose space to the T pose.

    The per-joint transforms ``A`` are linearly blended with the skinning
    weights ``bw``. The inverse of each point's blended transform is then
    applied to that point.

    Args:
        pts: posed points, (batch, N, 3).
        bw: blend weights, (batch, num_joints, N). The joint count is taken
            from this tensor (24 for the original SMPL usage).
        A: per-joint 4x4 transforms. Must hold batch * num_joints * 16
            values; it is reshaped to (batch, num_joints, 16).

    Returns:
        Points in the T pose, (batch, N, 3).
    """
    batch_size = pts.shape[0]
    bw = bw.permute(0, 2, 1)  # -> (batch, N, num_joints)
    num_joints = bw.shape[-1]
    # Blend the flattened 4x4 matrices per point, then restore matrix form.
    # reshape (not view) so a non-contiguous A is accepted as well.
    A = torch.bmm(bw, A.reshape(batch_size, num_joints, -1))
    A = A.reshape(batch_size, -1, 4, 4)
    # Undo the translation first, then the blended linear part.
    pts = pts - A[..., :3, 3]
    R_inv = torch.inverse(A[..., :3, :3])
    # Batched matrix-vector product R_inv @ pts for every point.
    pts = torch.sum(R_inv * pts[:, :, None], dim=3)
    return pts