| | import torch |
| |
|
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import torch.nn.functional as F |
| | import torch.nn.init as init |
| |
|
| |
|
class MultiLatentEncoder(nn.Module):
    """Encode concatenated (position, direction, impulse) inputs with a Siren net.

    The three inputs are concatenated on the last dimension into a 7-wide
    vector (dim_in=7 below; presumably 3 pos + 3 direction + 1 impulse —
    TODO confirm against the callers) and mapped to an
    ``opt.pos_encode_dim``-dimensional encoding.
    """

    def __init__(self, opt):
        super(MultiLatentEncoder, self).__init__()

        # Siren is defined elsewhere in the project (sinusoidal MLP).
        self.neuron_input = Siren(
            dim_in=7,
            dim_out=opt.pos_encode_dim,
        )

    def forward(self, pos, direct, imp):
        """Concatenate the inputs on the last dim and run the Siren encoder."""
        input_encoded = torch.concat((pos, direct, imp), -1)
        return self.neuron_input(input_encoded)

    def predict(self, pos, direct, imp):
        """Inference alias for :meth:`forward`.

        Fix: previously a verbatim copy of ``forward``'s body; now delegates
        so the two code paths cannot drift apart.
        """
        return self.forward(pos, direct, imp)
| |
|
| | class AutoDecoder(nn.Module): |
| | def __init__(self, opt): |
| | super(AutoDecoder, self).__init__() |
| |
|
| | self.ndf = opt.ndf |
| | self.data_shape = opt.data_shape |
| |
|
| | |
| | def block(in_feat, out_feat, normalize=True): |
| | layers = [nn.ConvTranspose3d(in_feat, out_feat, 4, 2, 1)] |
| | if normalize: |
| | layers.append(nn.BatchNorm3d(out_feat, 0.8)) |
| | layers.append(nn.LeakyReLU(0.2, inplace=True)) |
| | return layers |
| |
|
| | self.fc = nn.Sequential( |
| | nn.Linear(opt.pos_encode_dim + opt.z_latent_dim, int((self.ndf*8)*int(self.data_shape/16)*int(self.data_shape/16)*int(self.data_shape/16))), |
| | nn.LeakyReLU(0.2, inplace=True), |
| | ) |
| | self.decoder = nn.Sequential( |
| | *block(self.ndf*8, self.ndf*4), |
| | *block(self.ndf*4, self.ndf*2), |
| | *block(self.ndf*2, self.ndf) |
| | ) |
| |
|
| | self.toVoxelMd = nn.Sequential( |
| | nn.ConvTranspose3d(self.ndf , 1, 4, 2, 1, bias=False), |
| | nn.Tanh(), |
| | ) |
| |
|
| | self.toVoxelBig = nn.Sequential( |
| | *block(self.ndf, int(self.ndf/2)), |
| | nn.ConvTranspose3d(int(self.ndf/2), 1, 4, 2, 1, bias=False), |
| | nn.Tanh(), |
| | ) |
| |
|
| | self.latent_vectors = nn.Parameter(torch.FloatTensor(opt.train_dataset_size, opt.z_latent_dim)) |
| | self.cookbook = nn.Parameter(torch.FloatTensor(opt.train_dataset_size, opt.pos_encode_dim + opt.z_latent_dim)) |
| |
|
| | init.xavier_normal_(self.latent_vectors) |
| |
|
| | def Cook(self, x, y): |
| | input_x = self.embedding(x,y) |
| | distances = ( |
| | (input_x ** 2).sum(1, keepdim=True) |
| | - 2 * input_x @ self.cookbook.transpose(0, 1) |
| | + (self.cookbook.transpose(0, 1) ** 2).sum(0, keepdim=True) |
| | ) |
| | encoding_indices = distances.argmin(1) |
| | output = F.embedding(encoding_indices.view(input_x.shape[0],*input_x.shape[2:]), self.cookbook) |
| | distance = ((input_x - output.detach()) ** 2).mean() |
| |
|
| | |
| |
|
| | return output, encoding_indices, distance |
| |
|
| | def embedding(self, x, y): |
| | input_x = torch.concat((x, y), -1) |
| | return input_x |
| |
|
| | def forward(self, x, y, t = "Middle"): |
| | input_x = self.embedding(x, y) |
| | if t == "Middle": |
| | return self.forwardMiddle(input_x) |
| | else: |
| | return self.forwardBig(input_x) |
| |
|
| | def forwardMiddle(self, input_x): |
| | feature = self.fc(input_x).reshape(1, self.ndf*8, int(self.data_shape/16), int(self.data_shape/16), int(self.data_shape/16)) |
| | output = self.decoder(feature) |
| | output = self.toVoxelMd(output) |
| | output = output.view(1,1,self.data_shape,self.data_shape,self.data_shape) |
| |
|
| | return output |
| | def forwardBig(self, input_x): |
| | feature = self.fc(input_x).reshape(1, self.ndf*8, int(self.data_shape/16), int(self.data_shape/16), int(self.data_shape/16)) |
| | output = self.decoder(feature) |
| | output = self.toVoxelBig(output) |
| | output = output.view(1,1,self.data_shape*2,self.data_shape*2,self.data_shape*2) |
| |
|
| | return output |
| |
|
| | def codes(self): |
| | return self.latent_vectors |
| |
|