File size: 1,200 Bytes
076275f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
import torch
import torch.nn as nn
from src.shape_utils import load_shape_with_lbo
class CSE(nn.Module):
    """Learnable linear map from a functional basis to ``dim``-dim embeddings.

    The module owns a single learnable ``weight_matrix`` of shape
    ``(num_basis, dim)``. ``forward`` returns
    ``functional_basis @ weight_matrix``, optionally zero-padded to
    ``num_vert`` rows.

    ``functional_basis`` is initialised to ``None`` and is NOT populated by
    this class; the caller must assign it before calling ``forward``. Its
    last dimension must equal ``num_basis`` for the matmul to succeed.

    Args:
        class_name: Forwarded to ``load_shape_with_lbo``; the result is stored
            in ``self.shape``. NOTE(review): the structure of that object is
            defined in ``src.shape_utils`` and not visible here.
        num_basis: Number of basis functions, i.e. rows of ``weight_matrix``.
            Also forwarded to ``load_shape_with_lbo``.
        skip_first: Forwarded to ``load_shape_with_lbo`` unchanged.
        dim: Output embedding width, i.e. columns of ``weight_matrix``.
        num_vert: If not ``None``, ``forward`` zero-pads its output to exactly
            this many rows.
        barebones: If True, only ``self.shape`` is loaded. No
            ``weight_matrix`` is created and the module is not moved to
            ``device``.
        device: Device the module (and so ``weight_matrix``) is moved to
            after construction.
        rand_init: If True, ``weight_matrix`` is drawn from a standard normal
            distribution; otherwise it starts at zero.
    """

    def __init__(self, class_name, num_basis=64, skip_first=True, dim=16, num_vert=None, barebones=False, device=torch.device('cuda'), rand_init=False):
        super().__init__()
        self.shape = load_shape_with_lbo(class_name, num_basis, skip_first)
        # Assigned before the barebones early-return so every instance has
        # these attributes, even when no learnable weights are created.
        self.functional_basis = None  # must be set by the caller before forward()
        self.nns = None  # not used in this class; presumably filled in externally
        self.num_vert = num_vert
        if barebones:
            return
        # D x Q coefficient matrix (num_basis x dim): zeros by default,
        # standard-normal when rand_init is requested.
        init_fn = torch.randn if rand_init else torch.zeros
        self.weight_matrix = nn.Parameter(init_fn(num_basis, dim))
        self.to(device)

    def forward(self):
        """Return ``functional_basis @ weight_matrix``.

        When ``num_vert`` is not ``None``, the result must be 2-D. It is
        zero-padded along its first dimension to exactly ``num_vert`` rows.
        The padding matches the result's dtype and device. Gradients still
        flow to ``weight_matrix`` through the copied rows.

        Returns:
            The (optionally padded) embedding tensor.

        Raises:
            RuntimeError: If ``functional_basis`` has not been assigned.
            ValueError: If the result has more rows than ``num_vert``.
        """
        if self.functional_basis is None:
            raise RuntimeError(
                "CSE.functional_basis must be assigned before calling forward()"
            )
        output = torch.matmul(self.functional_basis, self.weight_matrix)
        if self.num_vert is None:
            return output
        num_rows = output.shape[0]
        if num_rows > self.num_vert:
            raise ValueError(
                f"CSE output has {num_rows} rows, more than num_vert={self.num_vert}"
            )
        # new_zeros inherits dtype and device from `output`. This avoids a
        # float32 CPU allocation followed by a host-to-device copy.
        padded = output.new_zeros((self.num_vert, output.shape[1]))
        # Slice assignment is tracked by autograd (CopySlices), so the
        # gradient still reaches weight_matrix.
        padded[:num_rows] = output
        return padded
|