Spaces:
Sleeping
Sleeping
File size: 3,321 Bytes
e321b92 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import torch
import torch.nn as nn
import roma
from shape_models.layers import DiffusionNet
class PrismDecoder(torch.nn.Module):
    """Predict one rotation + translation per mesh face and apply it to the face.

    A ``DiffusionNet`` turns the per-vertex ``latent`` into per-vertex
    features, which are averaged over the three corners of every face and
    passed through an MLP that emits 12 numbers per face: a 3x3 matrix
    (projected onto a rotation with ``roma.special_procrustes``) and a 3-vector
    translation. Each triangle ("prism") is transformed independently and the
    transformed corners are averaged back onto the shared vertices.

    Args:
        dim_in: Passed to ``DiffusionNet`` as ``C_in``.
        dim_out: Passed to ``DiffusionNet`` as ``C_out``; also the input width
            of the refinement MLP.
        n_width: Passed to ``DiffusionNet`` as ``C_width``.
        n_block: Passed to ``DiffusionNet`` as ``N_block``.
        pairwise_dot: Passed to ``DiffusionNet`` as ``with_gradient_features``.
        dropout: Passed to ``DiffusionNet`` as ``dropout``.
        dot_linear_complex: Passed to ``DiffusionNet`` as
            ``with_gradient_rotations``.
        neig: Not used by this class; kept so existing callers and configs
            that pass it keep working.
    """

    def __init__(self, dim_in=1024, dim_out=512, n_width=256, n_block=4, pairwise_dot=True, dropout=False,
                 dot_linear_complex=True, neig=128):
        super().__init__()

        self.diffusion_net = DiffusionNet(
            C_in=dim_in,
            C_out=dim_out,
            C_width=n_width,
            N_block=n_block,
            dropout=dropout,
            with_gradient_features=pairwise_dot,
            with_gradient_rotations=dot_linear_complex,
        )

        # Per-face head: 12 outputs = 9 (3x3 matrix) + 3 (translation).
        self.mlp_refine = nn.Sequential(
            nn.Linear(dim_out, dim_out),
            nn.ReLU(),
            nn.Linear(dim_out, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 12),
        )

    def forward(self, batch_dict, latent):
        """Deform the mesh described by ``batch_dict`` according to ``latent``.

        Args:
            batch_dict: Mapping with the vertex positions under ``"vertices"``
                (or ``"verts"`` as a fallback), plus ``"faces"``, ``"mass"``,
                ``"L"``, ``"evals"``, ``"evecs"``, ``"gradX"`` and ``"gradY"``,
                which are forwarded to the ``DiffusionNet``.
                Assumes vertices are ``(n_verts, 3)`` and faces are
                ``(n_faces, 3)`` integer indices (unbatched) -- TODO confirm
                against the data loader.
            latent: Input features handed to the ``DiffusionNet``.

        Returns:
            A tuple ``(out_features, transformed_prism, rotations)``:
            the per-vertex result reshaped to ``(1, -1, 3)``, the transformed
            face corners ``(n_faces, 3, 3)``, and the per-face rotation
            matrices ``(n_faces, 3, 3)``.
        """
        # Vertex positions may be stored under either key. Only a missing key
        # triggers the fallback; any other error is allowed to propagate.
        try:
            verts = batch_dict["vertices"]
        except (KeyError, AttributeError):
            verts = batch_dict["verts"]
        faces = batch_dict["faces"]

        # Corner positions of every face before deformation: (n_faces, 3, 3).
        prism_base = verts[faces]
        bs = 1  # the indexing below handles a single, unbatched mesh

        # Per-vertex features from the diffusion net; the gather below
        # requires them to be 2-D, i.e. (n_verts, dim).
        features = self.diffusion_net(
            latent,
            batch_dict["mass"],
            batch_dict["L"],
            evals=batch_dict["evals"],
            evecs=batch_dict["evecs"],
            gradX=batch_dict["gradX"],
            gradY=batch_dict["gradY"],
            faces=batch_dict["faces"],
        )

        # Per-face features: for face i and channel j, pick the feature of
        # each of the three corner vertices, then average over the corners.
        n_channels = features.shape[-1]
        per_vertex = features.unsqueeze(-1).expand(-1, -1, 3)  # (n_verts, dim, 3)
        corner_idx = faces.unsqueeze(1).expand(-1, n_channels, -1)  # (n_faces, dim, 3)
        per_corner = torch.gather(per_vertex, 0, corner_idx)  # (n_faces, dim, 3)
        features = torch.mean(per_corner, dim=-1)  # (n_faces, dim)

        # Refine with the MLP: (n_faces, 12).
        features = self.mlp_refine(features)

        # First 9 values -> 3x3 matrix projected onto a rotation;
        # last 3 values -> translation.
        rotations = features[:, :9].reshape(-1, 3, 3)
        rotations = roma.special_procrustes(rotations)  # (n_faces, 3, 3)
        translations = features[:, 9:].reshape(-1, 3)  # (n_faces, 3)

        # Transform every face independently (corners are row vectors).
        transformed_prism = (prism_base @ rotations) + translations[:, None]

        # Average the (generally disagreeing) per-face corners back onto the
        # shared vertices.
        features = self.prism_to_vertices(transformed_prism, faces, verts)
        out_features = features.reshape(bs, -1, 3)

        return out_features, transformed_prism, rotations

    def prism_to_vertices(self, prism, faces, verts):
        """Average per-face corner values onto the vertices they belong to.

        Args:
            prism: Tensor of shape ``(n_faces, 3, d)`` holding one ``d``-vector
                per face corner.
            faces: Integer tensor of shape ``(n_faces, 3)`` with the vertex
                index of every corner.
            verts: Tensor whose first dimension is the number of vertices;
                only ``verts.shape[0]`` is read.

        Returns:
            Tensor of shape ``(n_verts, d)`` with the same dtype and device as
            ``prism``. Row ``v`` is the mean of all corner values assigned to
            vertex ``v``; vertices that appear in no face stay at zero.
        """
        n_verts = verts.shape[0]
        d = prism.shape[-1]
        device = prism.device

        flat_faces = faces.reshape(-1)  # (3 * n_faces,)
        corner_values = prism.reshape(-1, d)  # (3 * n_faces, d)

        # Sum the corner values per vertex. The accumulator follows the dtype
        # of `prism` so non-float32 inputs (e.g. float64) do not hit a dtype
        # mismatch in the in-place accumulation.
        features = torch.zeros((n_verts, d), dtype=prism.dtype, device=device)
        features.index_add_(0, flat_faces, corner_values)

        # Number of face corners referencing each vertex.
        num_faces_per_vertex = torch.zeros(n_verts, dtype=torch.float32, device=device)
        num_faces_per_vertex.index_add_(0, flat_faces, torch.ones(flat_faces.shape[0], device=device))

        # Clamp so unreferenced vertices divide by 1 instead of 0.
        features /= num_faces_per_vertex.unsqueeze(1).clamp(min=1)
        return features