zirobtc's picture
Initial upload of MotionStreamer code, excluding large extracted data and output folders.
0e267a7 verified
import torch.nn as nn
from models.causal_cnn import CausalEncoder, CausalDecoder
# Causal TAE:
class Causal_TAE(nn.Module):
def __init__(self,
hidden_size=1024,
down_t=2,
stride_t=2,
width=1024,
depth=3,
dilation_growth_rate=3,
activation='relu',
norm=None,
latent_dim=16,
clip_range = []
):
super().__init__()
self.decode_proj = nn.Linear(latent_dim, width)
self.encoder = CausalEncoder(272, hidden_size, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm, latent_dim=latent_dim, clip_range=clip_range)
self.decoder = CausalDecoder(272, hidden_size, down_t, stride_t, width, depth, dilation_growth_rate, activation=activation, norm=norm)
def preprocess(self, x):
x = x.permute(0,2,1).float()
return x
def postprocess(self, x):
x = x.permute(0,2,1)
return x
def encode(self, x):
x_in = self.preprocess(x)
x_encoder, mu, logvar = self.encoder(x_in)
x_encoder = self.postprocess(x_encoder)
x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1])
return x_encoder, mu, logvar
def forward(self, x):
x_in = self.preprocess(x)
# Encode
x_encoder, mu, logvar = self.encoder(x_in)
x_encoder = self.decode_proj(x_encoder)
# decoder
x_decoder = self.decoder(x_encoder)
x_out = self.postprocess(x_decoder)
return x_out, mu, logvar
def forward_decoder(self, x):
# decoder
x_width = self.decode_proj(x)
x_decoder = self.decoder(x_width)
x_out = self.postprocess(x_decoder)
return x_out
class Causal_HumanTAE(nn.Module):
def __init__(self,
hidden_size=1024,
down_t=2,
stride_t=2,
depth=3,
dilation_growth_rate=3,
activation='relu',
norm=None,
latent_dim=16,
clip_range = []
):
super().__init__()
self.tae = Causal_TAE(hidden_size, down_t, stride_t, hidden_size, depth, dilation_growth_rate, activation=activation, norm=norm, latent_dim=latent_dim, clip_range=clip_range)
def encode(self, x):
h, mu, logvar = self.tae.encode(x)
return h, mu, logvar
def forward(self, x):
x_out, mu, logvar = self.tae(x)
return x_out, mu, logvar
def forward_decoder(self, x):
x_out = self.tae.forward_decoder(x)
return x_out