File size: 2,802 Bytes
0e267a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch.nn as nn
from models.causal_cnn import CausalEncoder, CausalDecoder


# Causal TAE:
class Causal_TAE(nn.Module):
    """Causal temporal autoencoder built from a CausalEncoder / CausalDecoder pair.

    The encoder is expected to return ``(latent, mu, logvar)``. Before decoding,
    the latent is projected from ``latent_dim`` to ``width`` on its last axis by
    ``decode_proj`` (an ``nn.Linear(latent_dim, width)``).

    Args:
        hidden_size, down_t, stride_t, width, depth, dilation_growth_rate,
        activation, norm: forwarded unchanged to both ``CausalEncoder`` and
            ``CausalDecoder`` (see ``models.causal_cnn`` for their meaning).
        latent_dim: forwarded to the encoder. It is also the input size of
            ``decode_proj``, so the latent's last axis must have this size.
        clip_range: forwarded to the encoder only. ``None`` (the default)
            means an empty list; a fresh list is created for each instance.
        nfeats: keyword-only. Passed as the first positional argument of both
            encoder and decoder (the original code hard-coded 272).
            Presumably the per-frame feature dimension -- confirm against
            ``models.causal_cnn``.
    """

    def __init__(self,
                 hidden_size=1024,
                 down_t=2,
                 stride_t=2,
                 width=1024,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None,
                 latent_dim=16,
                 clip_range=None,
                 *,
                 nfeats=272,
                 ):
        super().__init__()

        # A literal ``[]`` default would be one list shared by every instance
        # (and by the encoder that receives it); build a fresh one instead.
        if clip_range is None:
            clip_range = []

        # Maps the last axis of the latent from latent_dim to width.
        self.decode_proj = nn.Linear(latent_dim, width)

        self.encoder = CausalEncoder(nfeats, hidden_size, down_t, stride_t, width, depth, dilation_growth_rate,
                                     activation=activation, norm=norm, latent_dim=latent_dim, clip_range=clip_range)
        self.decoder = CausalDecoder(nfeats, hidden_size, down_t, stride_t, width, depth, dilation_growth_rate,
                                     activation=activation, norm=norm)

    def preprocess(self, x):
        """Swap axes 1 and 2 of a 3-D tensor and cast it to float32.

        Presumably converts (batch, time, features) into the channels-first
        layout the conv encoder expects -- TODO confirm.
        """
        x = x.permute(0, 2, 1).float()
        return x

    def postprocess(self, x):
        """Swap axes 1 and 2 of a 3-D tensor (no dtype change)."""
        x = x.permute(0, 2, 1)
        return x

    def encode(self, x):
        """Encode ``x`` and return ``(latent, mu, logvar)``.

        The encoder's latent has axes 1 and 2 swapped back and is then
        flattened to 2-D, ``(-1, size_of_last_axis)``. ``mu`` and ``logvar``
        are returned exactly as the encoder produced them.

        NOTE(review): ``forward`` feeds the raw encoder output to
        ``decode_proj`` (so its last axis must be ``latent_dim``), whereas this
        method swaps axes 1 and 2 before flattening by the last axis. Confirm
        which axis is meant to be kept here.
        """
        x_in = self.preprocess(x)
        x_encoder, mu, logvar = self.encoder(x_in)
        x_encoder = self.postprocess(x_encoder)
        # The permuted tensor is generally non-contiguous, and .view() needs
        # contiguous memory.
        x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1])
        return x_encoder, mu, logvar

    def forward(self, x):
        """Encode then decode ``x``; return ``(reconstruction, mu, logvar)``."""
        x_in = self.preprocess(x)
        x_encoder, mu, logvar = self.encoder(x_in)
        x_out = self._decode(x_encoder)
        return x_out, mu, logvar

    def forward_decoder(self, x):
        """Decode a latent ``x`` whose last axis has size ``latent_dim``."""
        return self._decode(x)

    def _decode(self, z):
        """Shared decode path: project latent_dim -> width, decode, swap axes 1 and 2."""
        x_width = self.decode_proj(z)
        x_decoder = self.decoder(x_width)
        return self.postprocess(x_decoder)


class Causal_HumanTAE(nn.Module):
    """Wrapper around a single ``Causal_TAE`` stored as ``self.tae``.

    ``hidden_size`` is passed to ``Causal_TAE`` both as its ``hidden_size`` and
    as its ``width`` (4th positional argument), so this class has no separate
    ``width`` parameter. All methods delegate to ``self.tae``.

    Args:
        hidden_size, down_t, stride_t, depth, dilation_growth_rate,
        activation, norm, latent_dim: forwarded to ``Causal_TAE``.
        clip_range: forwarded to ``Causal_TAE``. ``None`` (the default) means
            an empty list; a fresh list is created for each instance.
    """

    def __init__(self,
                 hidden_size=1024,
                 down_t=2,
                 stride_t=2,
                 depth=3,
                 dilation_growth_rate=3,
                 activation='relu',
                 norm=None,
                 latent_dim=16,
                 clip_range=None,
                 ):
        super().__init__()
        # A literal ``[]`` default would be one list shared by every instance;
        # build a fresh one so downstream modules never alias it.
        if clip_range is None:
            clip_range = []
        self.tae = Causal_TAE(hidden_size, down_t, stride_t, hidden_size, depth, dilation_growth_rate,
                              activation=activation, norm=norm, latent_dim=latent_dim, clip_range=clip_range)

    def encode(self, x):
        """Return ``(latent, mu, logvar)`` from ``self.tae.encode(x)``."""
        h, mu, logvar = self.tae.encode(x)
        return h, mu, logvar

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` from ``self.tae(x)``."""
        x_out, mu, logvar = self.tae(x)
        return x_out, mu, logvar

    def forward_decoder(self, x):
        """Decode latent ``x`` via ``self.tae.forward_decoder``."""
        x_out = self.tae.forward_decoder(x)
        return x_out