File size: 1,200 Bytes
076275f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
from src.shape_utils import load_shape_with_lbo



class CSE(nn.Module):
    """Learnable embedding expressed as ``functional_basis @ weight_matrix``.

    The module stores a ``(num_basis, dim)`` parameter ``weight_matrix``.
    ``forward`` multiplies ``self.functional_basis`` by that parameter, so the
    last dimension of ``functional_basis`` must equal ``num_basis``.

    ``functional_basis`` is initialised to ``None`` and is not populated
    anywhere in this class; the caller is expected to assign it before calling
    ``forward``. (NOTE(review): presumably it is taken from ``self.shape`` as
    returned by ``load_shape_with_lbo`` -- confirm against the callers.)

    Args:
        class_name: Forwarded to ``load_shape_with_lbo``.
        num_basis: Number of basis functions; also the row count of
            ``weight_matrix``. Forwarded to ``load_shape_with_lbo``.
        skip_first: Forwarded to ``load_shape_with_lbo``.
        dim: Embedding dimension (column count of ``weight_matrix``).
        num_vert: If not ``None``, ``forward`` zero-pads its result along the
            first axis to this many rows.
        barebones: If true, only the shape is loaded; no parameter is created
            and the module is not moved to ``device``.
        device: Device the module is moved to after the parameter is created.
        rand_init: If true, ``weight_matrix`` is drawn from a standard normal
            distribution; otherwise it starts at zero.
    """

    def __init__(self, class_name, num_basis=64, skip_first=True, dim=16, num_vert=None, barebones=False, device=torch.device('cuda'), rand_init=False):
        super().__init__()

        self.shape = load_shape_with_lbo(class_name, num_basis, skip_first)

        # Plain (non-parameter) attributes are set before the barebones early
        # return so that they exist on every instance, whatever the mode.
        self.functional_basis = None
        self.nns = None
        self.num_vert = num_vert

        if barebones:
            return

        # (num_basis x dim) coefficient matrix: zeros by default, standard
        # normal when rand_init is set. nn.Parameter already enables gradients.
        init_fn = torch.randn if rand_init else torch.zeros
        self.weight_matrix = nn.Parameter(init_fn(num_basis, dim))

        self.to(device)

    def forward(self):
        """Return the embedding ``functional_basis @ weight_matrix``.

        When ``self.num_vert`` is set, the result is zero-padded along the
        first axis to ``num_vert`` rows. ``num_vert`` must then be at least the
        number of rows produced by the product, otherwise the slice assignment
        fails with a shape mismatch.

        Returns:
            A tensor with the dtype and device of the matrix product, with
            ``num_vert`` rows if padding is enabled.
        """
        output = torch.matmul(self.functional_basis, self.weight_matrix)

        if self.num_vert is not None:
            # new_zeros inherits dtype and device from `output`: no CPU
            # allocation + transfer, and no silent cast to float32.
            padded = output.new_zeros((self.num_vert, output.shape[1]))
            padded[:output.shape[0], :] = output
            output = padded
        return output