daidedou
first_try
458efe2
raw
history blame
864 Bytes
import torch
import torch.nn as nn
from .layers import DiffusionNet
class Encoder(nn.Module):
    """Global shape encoder: a DiffusionNet followed by max-pooling over vertices.

    The DiffusionNet is fed 3-channel per-vertex input features; ``forward``
    then takes the channel-wise maximum of its output over the vertex axis,
    yielding one descriptor per shape.
    """

    def __init__(
        self,
        with_grad=True,
        key_verts="vertices",
        *,
        c_out=512,
        c_width=128,
        n_block=4,
        dropout=True,
    ):
        """Build the encoder.

        Args:
            with_grad: Passed to DiffusionNet as both
                ``with_gradient_features`` and ``with_gradient_rotations``.
            key_verts: Key of ``shape_dict`` holding the 3-channel per-vertex
                input (presumably vertex coordinates — confirm against the
                data pipeline).
            c_out: DiffusionNet ``C_out``; presumably the size of the pooled
                descriptor.
            c_width: DiffusionNet ``C_width`` (internal block width).
            n_block: DiffusionNet ``N_block`` (number of diffusion blocks).
            dropout: Forwarded unchanged to DiffusionNet's ``dropout``.
        """
        super().__init__()
        self.diff_net = DiffusionNet(
            C_in=3,
            C_out=c_out,
            C_width=c_width,
            N_block=n_block,
            dropout=dropout,
            with_gradient_features=with_grad,
            with_gradient_rotations=with_grad,
        )
        self.key_verts = key_verts

    def forward(self, shape_dict):
        """Encode a shape (or a batch of shapes) into a global descriptor.

        Args:
            shape_dict: Mapping that must contain ``self.key_verts`` plus the
                entries "mass", "L", "evals", "evecs", "gradX", "gradY" and
                "faces"; all are forwarded to DiffusionNet. A missing key
                raises ``KeyError``.

        Returns:
            The channel-wise maximum of the DiffusionNet output over its
            second-to-last axis. This assumes the output is ``(V, C)``
            (giving ``(C,)``) or batched ``(B, V, C)`` (giving ``(B, C)``);
            confirm against DiffusionNet.
        """
        feats = self.diff_net(
            shape_dict[self.key_verts],
            shape_dict["mass"],
            shape_dict["L"],
            evals=shape_dict["evals"],
            evecs=shape_dict["evecs"],
            gradX=shape_dict["gradX"],
            gradY=shape_dict["gradY"],
            faces=shape_dict["faces"],
        )
        # Pool over the vertex axis. That axis is second-to-last for both
        # unbatched (V, C) and batched (B, V, C) features. dim=0 would
        # collapse the batch axis instead when the input is batched.
        x_out = torch.max(feats, dim=-2).values
        return x_out