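import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# NOTE: `Siren` is used below but is never imported in this file.
# Assumption: it is the single SIREN layer from the `siren-pytorch` package;
# adjust this import if the project defines its own `Siren`.
from siren_pytorch import Siren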


class MultiLatentEncoder(nn.Module):
    """Encodes the concatenated (pos, direct, imp) input with one Siren layer."""

    def __init__(self, opt):
        super().__init__()

        # 7 input features in total: pos, direct and imp concatenated on the
        # last axis (the per-tensor split is not stated in this file).
        self.neuron_input = Siren(
            dim_in=7,
            dim_out=opt.pos_encode_dim,
        )

    def forward(self, pos, direct, imp):
        input_encoded = torch.cat((pos, direct, imp), dim=-1)
        return self.neuron_input(input_encoded)

    def predict(self, pos, direct, imp):
        # Same computation as forward(); kept as a separate entry point.
        return self.forward(pos, direct, imp)


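class AutoDecoder(nn.Module):
    """Auto-decoder mapping a (position encoding, latent code) pair to a voxel grid."""

    def __init__(self, opt):
        super().__init__()

        self.ndf = opt.ndf
        self.data_shape = opt.data_shape
        # Edge length of the feature volume produced by the FC layer; every
        # transposed convolution below doubles it.
        base = self.data_shape // 16

        # With FC Layer
        def block(in_feat, out_feat, normalize=True):
            # ConvTranspose3d(kernel=4, stride=2, padding=1) doubles each spatial axis.
            layers = [nn.ConvTranspose3d(in_feat, out_feat, 4, 2, 1)]
            if normalize:
                # The second positional argument of BatchNorm3d is `eps`, so the
                # original call sets eps=0.8 (not momentum). Behavior is kept;
                # the keyword only makes it explicit.
                layers.append(nn.BatchNorm3d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.fc = nn.Sequential(
            nn.Linear(opt.pos_encode_dim + opt.z_latent_dim, self.ndf * 8 * base ** 3),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # Three upsampling blocks: base -> data_shape / 2.
        self.decoder = nn.Sequential(
            *block(self.ndf * 8, self.ndf * 4),
            *block(self.ndf * 4, self.ndf * 2),
            *block(self.ndf * 2, self.ndf),
        )

        # One more upsampling step: output edge length is data_shape.
        self.toVoxelMd = nn.Sequential(
            nn.ConvTranspose3d(self.ndf, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # Two more upsampling steps: output edge length is 2 * data_shape.
        self.toVoxelBig = nn.Sequential(
            *block(self.ndf, self.ndf // 2),
            nn.ConvTranspose3d(self.ndf // 2, 1, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

        # One learnable latent code per training sample.
        self.latent_vectors = nn.Parameter(torch.empty(opt.train_dataset_size, opt.z_latent_dim))
        self.cookbook = nn.Parameter(
            torch.empty(opt.train_dataset_size, opt.pos_encode_dim + opt.z_latent_dim)
        )

        init.xavier_normal_(self.latent_vectors)
        # The original left `cookbook` as uninitialized memory. Xavier init is
        # an assumption made here so that Cook() starts from defined values.
        init.xavier_normal_(self.cookbook)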

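    def Cook(self, x, y):
        """Snap each row of the concatenated input to its nearest `cookbook` row.

        Expects a 2-D input of shape (N, pos_encode_dim + z_latent_dim).
        """
        input_x = self.embedding(x, y)
        # Squared Euclidean distance to every cookbook row: |a|^2 - 2 a.b + |b|^2.
        distances = (
            (input_x ** 2).sum(1, keepdim=True)
            - 2 * input_x @ self.cookbook.transpose(0, 1)
            + (self.cookbook.transpose(0, 1) ** 2).sum(0, keepdim=True)
        )
        encoding_indices = distances.argmin(1)
        output = F.embedding(encoding_indices.view(input_x.shape[0], *input_x.shape[2:]), self.cookbook)
        # Mean squared error between the input and its (detached) nearest row.
        distance = ((input_x - output.detach()) ** 2).mean()

        # Straight-through estimator, currently disabled:
        # quantized_x = input_x + (output - input_x).detach()

        return output, encoding_indices, distance

    def embedding(self, x, y):
        """Concatenate the two inputs along the last axis."""
        return torch.cat((x, y), dim=-1)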

    def forward(self, x, y, t="Middle"):
        # t == "Middle" decodes at data_shape; any other value decodes at 2 * data_shape.
        input_x = self.embedding(x, y)
        if t == "Middle":
            return self.forwardMiddle(input_x)
        else:
            return self.forwardBig(input_x)

    def forwardMiddle(self, input_x):
        # The reshape hard-codes a batch size of 1: one sample per call.
        base = self.data_shape // 16
        feature = self.fc(input_x).reshape(1, self.ndf * 8, base, base, base)
        output = self.decoder(feature)
        output = self.toVoxelMd(output)
        output = output.view(1, 1, self.data_shape, self.data_shape, self.data_shape)

        return output

    def forwardBig(self, input_x):
        # Same as forwardMiddle, but the head upsamples once more.
        base = self.data_shape // 16
        feature = self.fc(input_x).reshape(1, self.ndf * 8, base, base, base)
        output = self.decoder(feature)
        output = self.toVoxelBig(output)
        output = output.view(1, 1, self.data_shape * 2, self.data_shape * 2, self.data_shape * 2)

        return output

    def codes(self):
        """Return the learnable per-sample latent codes."""
        return self.latent_vectors
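

# --- Shape check (illustrative sketch, not part of the original file) ---
# The final .view() calls in AutoDecoder only succeed if every
# ConvTranspose3d(kernel=4, stride=2, padding=1) exactly doubles the spatial
# size. The helpers below reproduce that arithmetic in plain Python, assuming
# dilation 1 and no output padding (the PyTorch defaults used above).
# `conv_transpose_out_size` and `decoded_size` are hypothetical names
# introduced here for illustration only.
def conv_transpose_out_size(size, kernel=4, stride=2, padding=1):
    """Output length along one axis of a transposed convolution."""
    return (size - 1) * stride - 2 * padding + kernel


def decoded_size(data_shape, num_upsamples):
    """Edge length after `num_upsamples` doublings, starting from data_shape // 16."""
    size = data_shape // 16
    for _ in range(num_upsamples):
        size = conv_transpose_out_size(size)
    return size


if __name__ == "__main__":
    # data_shape = 64 is an assumed example value, not taken from the project.
    # "Middle" path: 3 decoder blocks + toVoxelMd = 4 upsampling steps.
    assert decoded_size(64, 4) == 64
    # "Big" path: 3 decoder blocks + 2 steps in toVoxelBig = 5 upsampling steps.
    assert decoded_size(64, 5) == 128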