"""Building-block network modules: conv/RNN/graph sequence encoders, decoders and translators."""
| import os | |
| import sys | |
| sys.path.append(os.getcwd()) | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| # TODO: be aware of the actual network structures | |
def get_log(x):
    """Return the exact base-2 logarithm of a positive power of two.

    Args:
        x: positive integer that must be an exact power of 2.

    Returns:
        The integer n such that 2**n == x.

    Raises:
        ValueError: if x is not a positive power of 2.
            (Fix: previously x <= 0 silently returned 0 instead of raising.)
    """
    if x < 1:
        raise ValueError('x is not a power of 2')
    log = 0
    while x > 1:
        if x % 2 == 0:
            x = x // 2
            log += 1
        else:
            raise ValueError('x is not a power of 2')
    return log
class ConvNormRelu(nn.Module):
    '''
    Conv -> Dropout -> Norm -> ReLU building block for 1d or 2d inputs,
    with an optional (projected) residual connection.

    (B,C_in,H,W) -> (B, C_out, H, W)
    there exist some kernel size that makes the result is not H/s
    #TODO: there might some problems with residual
    '''
    def __init__(self,
                 in_channels,
                 out_channels,
                 type='1d',
                 leaky=False,
                 downsample=False,
                 kernel_size=None,
                 stride=None,
                 padding=None,
                 p=0,
                 groups=1,
                 residual=False,
                 norm='bn'):
        '''
        conv-bn-relu

        Args:
            in_channels / out_channels: channel counts *per group* (both are
                multiplied by `groups` before building the main conv).
            type: '1d' or '2d'. NOTE(review): any other value leaves
                self.conv / self.norm / self.dropout undefined and forward()
                will fail with AttributeError.
            leaky: use LeakyReLU(0.2) instead of ReLU.
            downsample: when kernel_size/stride are both None, selects
                (k=4, s=2) (halves the length) instead of (k=3, s=1).
            kernel_size, stride, padding: int or per-dim tuple; padding
                defaults to (k - s) // 2 element-wise ("same"-ish for the
                default geometries).
            p: dropout probability (Dropout for 1d, Dropout2d for 2d).
            groups: grouped convolution factor.
            residual: add a skip connection (projected by a conv when shapes
                or channel counts differ).
            norm: 'bn' (default), 'gn' or 'ln'.
        '''
        super(ConvNormRelu, self).__init__()
        self.residual = residual
        self.norm_type = norm
        # kernel_size = k
        # stride = s
        # Default geometry: 3/1 keeps the length, 4/2 halves it.
        if kernel_size is None and stride is None:
            if not downsample:
                kernel_size = 3
                stride = 1
            else:
                kernel_size = 4
                stride = 2
        # Derive per-dimension padding (k - s) // 2, broadcasting int/tuple
        # combinations of kernel_size and stride.
        if padding is None:
            if isinstance(kernel_size, int) and isinstance(stride, tuple):
                padding = tuple(int((kernel_size - st) / 2) for st in stride)
            elif isinstance(kernel_size, tuple) and isinstance(stride, int):
                padding = tuple(int((ks - stride) / 2) for ks in kernel_size)
            elif isinstance(kernel_size, tuple) and isinstance(stride, tuple):
                padding = tuple(int((ks - st) / 2) for ks, st in zip(kernel_size, stride))
            else:
                padding = int((kernel_size - stride) / 2)
        if self.residual:
            if downsample:
                # Downsampling skip path: project with the SAME kernel/stride
                # as the main conv so the spatial shapes match.
                # NOTE(review): this projection has no normalization.
                if type == '1d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv1d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding
                        )
                    )
                elif type == '2d':
                    self.residual_layer = nn.Sequential(
                        nn.Conv2d(
                            in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            stride=stride,
                            padding=padding
                        )
                    )
            else:
                # No downsampling: identity skip when channels already match,
                # otherwise a conv projection to align the channel count.
                if in_channels == out_channels:
                    self.residual_layer = nn.Identity()
                else:
                    if type == '1d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv1d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding
                            )
                        )
                    elif type == '2d':
                        self.residual_layer = nn.Sequential(
                            nn.Conv2d(
                                in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=kernel_size,
                                stride=stride,
                                padding=padding
                            )
                        )
        # The main conv is grouped: total channels = per-group channels * groups.
        in_channels = in_channels * groups
        out_channels = out_channels * groups
        if type == '1d':
            self.conv = nn.Conv1d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm1d(out_channels)
            self.dropout = nn.Dropout(p=p)
        elif type == '2d':
            self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                  kernel_size=kernel_size, stride=stride, padding=padding,
                                  groups=groups)
            self.norm = nn.BatchNorm2d(out_channels)
            self.dropout = nn.Dropout2d(p=p)
        # Optional norm override (replaces the BatchNorm created above).
        # NOTE(review): GroupNorm(2, ...) requires out_channels % 2 == 0.
        if norm == 'gn':
            self.norm = nn.GroupNorm(2, out_channels)
        elif norm == 'ln':
            self.norm = nn.LayerNorm(out_channels)
        if leaky:
            self.relu = nn.LeakyReLU(negative_slope=0.2)
        else:
            self.relu = nn.ReLU()

    def forward(self, x, **kwargs):
        if self.norm_type == 'ln':
            out = self.dropout(self.conv(x))
            # LayerNorm normalizes the last dim, so move channels there and
            # back. NOTE(review): this transpose assumes a 1d (B, C, T)
            # input — TODO confirm it is never used with type='2d'.
            out = self.norm(out.transpose(1,2)).transpose(1,2)
        else:
            out = self.norm(self.dropout(self.conv(x)))
        if self.residual:
            residual = self.residual_layer(x)
            out += residual
        return self.relu(out)
class UNet1D(nn.Module):
    '''
    1D U-Net: two pre-convolutions, `max_depth` downsampling stages, and
    `max_depth` upsampling stages with additive skip connections.
    The temporal length must be a power of 2 no smaller than 2**max_depth.
    '''

    def __init__(self,
                 input_channels,
                 output_channels,
                 max_depth=5,
                 kernel_size=None,
                 stride=None,
                 p=0,
                 groups=1):
        super(UNet1D, self).__init__()
        self.max_depth = max_depth
        self.groups = groups

        def block(c_in, halve):
            # Shared layer spec; `halve` toggles the stride-2 downsampling.
            return ConvNormRelu(c_in, output_channels,
                                type='1d', leaky=True, downsample=halve,
                                kernel_size=kernel_size, stride=stride, p=p, groups=groups)

        self.pre_downsampling_conv = nn.ModuleList(
            [block(input_channels, False), block(output_channels, False)])
        self.conv1 = nn.ModuleList(
            [block(output_channels, True) for _ in range(self.max_depth)])
        self.conv2 = nn.ModuleList(
            [block(output_channels, False) for _ in range(self.max_depth)])
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        num_frames = x.shape[-1]
        assert get_log(num_frames) >= self.max_depth, \
            'num_frames must be a power of 2 and its power must be greater than max_depth'
        x = nn.Sequential(*self.pre_downsampling_conv)(x)
        # Collect one skip tensor per depth (max_depth entries in total).
        skips = [x]
        for depth, down in enumerate(self.conv1):
            x = down(x)
            if depth < self.max_depth - 1:
                skips.append(x)
        # Decode: upsample, add the matching skip, then convolve.
        for depth, up in enumerate(self.conv2):
            x = self.upconv(x) + skips[self.max_depth - depth - 1]
            x = up(x)
        return x
class UNet2D(nn.Module):
    '''Placeholder: a 2D U-Net was considered but deliberately not implemented.'''
    def __init__(self):
        super(UNet2D, self).__init__()
        # Instantiation always fails on purpose.
        raise NotImplementedError('2D Unet is wierd')
class AudioPoseEncoder1D(nn.Module):
    '''
    Stack of 1D conv blocks that doubles the channel count per layer, then
    maps onto exactly C_out. The temporal length is preserved.

    (B, C, T) -> (B, C*2, T) -> ... -> (B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layer_nums=None
                 ):
        '''
        Args:
            C_in: input channel count.
            C_out: output channel count.
            kernel_size, stride: forwarded to every ConvNormRelu layer.
            min_layer_nums: if given, keep appending C_out->C_out layers
                until at least this many layers exist.
        '''
        super(AudioPoseEncoder1D, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        conv_layers = nn.ModuleList([])
        cur_C = C_in
        num_layers = 0
        # Doubling stage: C_in -> 2*C_in -> ... while still below C_out.
        while cur_C < self.C_out:
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=cur_C * 2,
                kernel_size=kernel_size,
                stride=stride
            ))
            cur_C *= 2
            num_layers += 1
        # Adjustment stage: land on exactly C_out and/or reach the minimum
        # layer count. Fix: treat min_layer_nums=None as "no minimum" — the
        # original loop condition evaluated `num_layers < None`, raising
        # TypeError whenever the doubling stage overshot C_out.
        min_layers = 0 if min_layer_nums is None else min_layer_nums
        while (cur_C != C_out) or num_layers < min_layers:
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=C_out,
                kernel_size=kernel_size,
                stride=stride
            ))
            num_layers += 1
            cur_C = C_out
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        x = self.conv_layers(x)
        return x
class AudioPoseEncoder2D(nn.Module):
    '''
    Placeholder for a 2D pose encoder; deliberately not implemented.

    (B, C, T) -> (B, 1, C, T) -> ... -> (B, C_out, T)
    '''
    # NOTE(review): raises before calling super().__init__, so this class can
    # never be instantiated (which is the intent).
    def __init__(self):
        raise NotImplementedError
class AudioPoseEncoderRNN(nn.Module):
    '''
    Recurrent per-step encoder; temporal length is preserved.

    (B, C, T)->(B, T, C)->(B, T, C_out)->(B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        super(AudioPoseEncoderRNN, self).__init__()
        # Dispatch on the requested recurrent cell type.
        cell_classes = {'gru': nn.GRU, 'lstm': nn.LSTM}
        if rnn_cell not in cell_classes:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))
        self.cell = cell_classes[rnn_cell](
            input_size=C_in,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x, state=None):
        # (B, C, T) -> (B, T, C) for the batch-first RNN, then back.
        seq_first = x.permute(0, 2, 1)
        seq_first, state = self.cell(seq_first, state)
        return seq_first.permute(0, 2, 1)
class AudioPoseEncoderGraph(nn.Module):
    '''
    Graph-convolutional encoder over the skeleton joints.

    (B, C, T)->(B, 2, V, T)->(B, 2, T, V)->(B, D, T, V)
    '''
    # NOTE(review): GraphConvNormRelu is not defined or imported in this file
    # view — confirm it is provided elsewhere, otherwise construction fails
    # with NameError.
    def __init__(self,
                 layers_config,  # expected: a list of (C_in, C_out, kernel_size) tuples
                 A,  # adjacent matrix (num_parts, V, V)
                 residual,
                 local_bn=False,
                 share_weights=False
                 ) -> None:
        super().__init__()
        self.A = A
        self.num_joints = A.shape[1]
        self.num_parts = A.shape[0]
        # Per-joint input/output feature dims are taken from the first/last
        # layer configs.
        self.C_in = layers_config[0][0]
        self.C_out = layers_config[-1][1]
        self.conv_layers = nn.ModuleList([
            GraphConvNormRelu(
                C_in=c_in,
                C_out=c_out,
                A=self.A,
                residual=residual,
                local_bn=local_bn,
                kernel_size=k,
                share_weights=share_weights
            ) for (c_in, c_out, k) in layers_config
        ])
        self.conv_layers = nn.Sequential(*self.conv_layers)

    def forward(self, x):
        '''
        x: (B, C, T), C should be num_joints*D
        output: (B, D, T, V)
        '''
        B, C, T = x.shape
        # (B, V, D, T); D is the per-joint feature dim — note V comes before
        # D in this reshape.
        x = x.view(B, self.num_joints, self.C_in, T)
        x = x.permute(0, 2, 3, 1)  # (B, D, T, V)
        assert x.shape[1] == self.C_in
        x_conved = self.conv_layers(x)
        # x_conved = x_conved.permute(0, 3, 1, 2).contiguous().view(B, self.C_out*self.num_joints, T)#(B, V*C_out, T)
        return x_conved
class SeqEncoder2D(nn.Module):
    '''
    seq_encoder, encoding a seq to a vector

    Treats the pose sequence as a 2D map over (time, joints) and shrinks it
    to 1x1 while growing channels to C_out:
    (B, C, T)->(B, 2, V, T)->(B, 2, T, V) -> (B, 32, )->...->(B, C_out)
    '''

    def __init__(self,
                 C_in,  # per-joint feature dim, should be 2
                 T_in,
                 C_out,
                 num_joints,
                 min_layer_num=None,
                 residual=False
                 ):
        '''
        Args:
            C_in: per-joint feature dim (channels after the reshape).
            T_in: input temporal length.
            C_out: embedding size.
            num_joints: number of joints V (C of the raw input == V * C_in).
            min_layer_num: if given, pad with 1x1 convs to at least this depth.
            residual: forwarded to every ConvNormRelu.
        '''
        super(SeqEncoder2D, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.num_joints = num_joints
        conv_layers = nn.ModuleList([])
        conv_layers.append(ConvNormRelu(
            in_channels=C_in,
            out_channels=32,
            type='2d',
            residual=residual
        ))
        cur_C = 32
        cur_H = T_in
        cur_W = num_joints
        num_layers = 1
        # Shrink (H, W) to (1, 1) while growing channels towards C_out: halve
        # a dim with (k=4, s=2) while it is > 4, else collapse it in one
        # stride-k conv.
        while (cur_C < C_out) or (cur_H > 1) or (cur_W > 1):
            ks = [3, 3]
            st = [1, 1]
            if cur_H > 1:
                if cur_H > 4:
                    ks[0] = 4
                    st[0] = 2
                else:
                    ks[0] = cur_H
                    st[0] = cur_H
            if cur_W > 1:
                if cur_W > 4:
                    ks[1] = 4
                    st[1] = 2
                else:
                    ks[1] = cur_W
                    st[1] = cur_W
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(C_out, cur_C * 2),
                type='2d',
                kernel_size=tuple(ks),
                stride=tuple(st),
                residual=residual
            ))
            cur_C = min(cur_C * 2, C_out)
            if cur_H > 1:
                if cur_H > 4:
                    cur_H //= 2
                else:
                    cur_H = 1
            if cur_W > 1:
                if cur_W > 4:
                    cur_W //= 2
                else:
                    cur_W = 1
            num_layers += 1
        # Optional 1x1 padding layers to reach the requested depth.
        if min_layer_num is not None and (num_layers < min_layer_num):
            while num_layers < min_layer_num:
                conv_layers.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='2d',
                    kernel_size=1,
                    stride=1,
                    residual=residual
                ))
                num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)
        self.num_layers = num_layers

    def forward(self, x):
        '''x: (B, V*C_in, T) -> (B, C_out).'''
        B, C, T = x.shape
        x = x.view(B, self.num_joints, self.C_in, T)  # (B, V, D, T), V in front
        x = x.permute(0, 2, 3, 1)  # (B, D, T, V)
        assert x.shape[1] == self.C_in and x.shape[-1] == self.num_joints
        x = self.conv_layers(x)
        # Fix: squeeze only the collapsed spatial dims — a bare .squeeze()
        # also dropped the batch dimension when B == 1.
        return x.squeeze(-1).squeeze(-1)
class SeqEncoder1D(nn.Module):
    '''
    Encode a sequence into a single vector with strided 1D convolutions.

    (B, C, T)->(B, D)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 T_in,
                 min_layer_nums=None
                 ):
        '''
        Args:
            C_in: input channel count.
            C_out: embedding size.
            T_in: input temporal length.
            min_layer_nums: if given, pad with 1x1 convs to at least this depth.
        '''
        super(SeqEncoder1D, self).__init__()
        conv_layers = nn.ModuleList([])
        cur_C = C_in
        cur_T = T_in
        self.num_layers = 0
        # Grow channels towards C_out while collapsing T to 1: halve T with
        # (k=4, s=2) while T > 4, otherwise flatten it with one stride-T conv.
        while (cur_C < C_out) or (cur_T > 1):
            ks = 3
            st = 1
            if cur_T > 1:
                if cur_T > 4:
                    ks = 4
                    st = 2
                else:
                    ks = cur_T
                    st = cur_T
            conv_layers.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(C_out, cur_C * 2),
                type='1d',
                kernel_size=ks,
                stride=st
            ))
            cur_C = min(cur_C * 2, C_out)
            if cur_T > 1:
                if cur_T > 4:
                    cur_T = cur_T // 2
                else:
                    cur_T = 1
            self.num_layers += 1
        # Optional 1x1 padding layers to reach the requested depth.
        if min_layer_nums is not None and (self.num_layers < min_layer_nums):
            while self.num_layers < min_layer_nums:
                conv_layers.append(ConvNormRelu(
                    in_channels=C_out,
                    out_channels=C_out,
                    type='1d',
                    kernel_size=1,
                    stride=1
                ))
                self.num_layers += 1
        self.conv_layers = nn.Sequential(*conv_layers)

    def forward(self, x):
        '''x: (B, C, T) -> (B, C_out).'''
        x = self.conv_layers(x)
        # Fix: squeeze only the collapsed temporal dim — a bare .squeeze()
        # also dropped the batch dimension when B == 1.
        return x.squeeze(-1)
class SeqEncoderRNN(nn.Module):
    '''
    Recurrent sequence-to-vector encoder.

    (B, C, T) -> (B, T, C) -> (B, D)
    LSTM/GRU-FC
    '''

    def __init__(self,
                 hidden_size,
                 in_size,
                 num_rnn_layers,
                 rnn_cell='gru',
                 bidirectional=False
                 ):
        '''
        Args:
            hidden_size: RNN hidden size (output dim doubles if bidirectional).
            in_size: per-step input feature size (the C of (B, C, T)).
            num_rnn_layers: number of stacked RNN layers.
            rnn_cell: 'gru' or 'lstm'.
            bidirectional: run the RNN in both directions.

        Raises:
            ValueError: for an unknown rnn_cell.
        '''
        super(SeqEncoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.in_size = in_size
        self.num_rnn_layers = num_rnn_layers
        self.bidirectional = bidirectional
        if rnn_cell == 'gru':
            self.cell = nn.GRU(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
                               batch_first=True, bidirectional=bidirectional)
        elif rnn_cell == 'lstm':
            self.cell = nn.LSTM(input_size=self.in_size, hidden_size=self.hidden_size, num_layers=self.num_rnn_layers,
                                batch_first=True, bidirectional=bidirectional)
        else:
            # Fix: an unknown cell type used to leave self.cell undefined and
            # only fail later in forward(); fail fast like the other RNN
            # modules in this file.
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))

    def forward(self, x, state=None):
        '''x: (B, C, T) -> (B, hidden_size), or (B, 2*hidden_size) if bidirectional.'''
        x = x.permute(0, 2, 1)
        B, T, C = x.shape
        x, _ = self.cell(x, state)
        if self.bidirectional:
            # Last step of the forward direction + first step of the backward
            # direction (each direction's own final state).
            out = torch.cat([x[:, -1, :self.hidden_size], x[:, 0, self.hidden_size:]], dim=-1)
        else:
            out = x[:, -1, :]
        assert out.shape[0] == B
        return out
class SeqEncoderGraph(nn.Module):
    '''
    Graph-conv sequence-to-vector encoder: a graph encoder over the joints,
    a conv that pools the joint axis to 1, then a stack of temporal convs
    that collapses time and grows channels to `embedding_size`.
    '''
    def __init__(self,
                 embedding_size,
                 layer_configs,
                 residual,
                 local_bn,
                 A,
                 T,
                 share_weights=False
                 ) -> None:
        '''
        Args:
            embedding_size: final embedding dim (C_out).
            layer_configs: per graph layer (C_in, C_out, kernel_size) tuples.
            residual, local_bn, share_weights: forwarded to the graph encoder.
            A: adjacency matrix (num_parts, V, V).
            T: input temporal length.
        '''
        super().__init__()
        self.C_in = layer_configs[0][0]
        self.C_out = embedding_size
        self.num_joints = A.shape[1]
        self.graph_encoder = AudioPoseEncoderGraph(
            layers_config=layer_configs,
            A=A,
            residual=residual,
            local_bn=local_bn,
            share_weights=share_weights
        )
        cur_C = layer_configs[-1][1]
        # Collapse the joint axis V -> 1 with a (1, V) conv.
        self.spatial_pool = ConvNormRelu(
            in_channels=cur_C,
            out_channels=cur_C,
            type='2d',
            kernel_size=(1, self.num_joints),
            stride=(1, 1),
            padding=(0, 0)
        )
        temporal_pool = nn.ModuleList([])
        cur_H = T
        num_layers = 0
        self.temporal_conv_info = []
        # Collapse T -> 1 while growing channels to C_out: halve T with
        # (k=4, s=2) while T > 4, else flatten it with one stride-T conv.
        while cur_C < self.C_out or cur_H > 1:
            self.temporal_conv_info.append(cur_C)
            ks = [3, 1]
            st = [1, 1]
            if cur_H > 1:
                if cur_H > 4:
                    ks[0] = 4
                    st[0] = 2
                else:
                    ks[0] = cur_H
                    st[0] = cur_H
            temporal_pool.append(ConvNormRelu(
                in_channels=cur_C,
                out_channels=min(self.C_out, cur_C * 2),
                type='2d',
                kernel_size=tuple(ks),
                stride=tuple(st)
            ))
            cur_C = min(cur_C * 2, self.C_out)
            if cur_H > 1:
                if cur_H > 4:
                    cur_H //= 2
                else:
                    cur_H = 1
            num_layers += 1
        self.temporal_pool = nn.Sequential(*temporal_pool)
        print("graph seq encoder info: temporal pool:", self.temporal_conv_info)
        self.num_layers = num_layers
        # need fc?

    def forward(self, x):
        '''
        x: (B, C, T)
        '''
        B, C, T = x.shape
        x = self.graph_encoder(x)  # (B, D, T, V)
        x = self.spatial_pool(x)   # (B, D, T, 1)
        x = self.temporal_pool(x)  # (B, C_out, 1, 1)
        # NOTE(review): this view assumes both spatial dims were pooled to
        # exactly 1 (i.e. T matches the construction-time T).
        x = x.view(B, self.C_out)
        return x
class SeqDecoder2D(nn.Module):
    '''
    Placeholder for a 2D sequence decoder; deliberately not implemented.

    (B, D)->(B, D, 1, 1)->(B, C_out, C, T)->(B, C_out, T)
    '''
    def __init__(self):
        super(SeqDecoder2D, self).__init__()
        # Instantiation always fails on purpose.
        raise NotImplementedError
class SeqDecoder1D(nn.Module):
    '''
    Decode a latent vector into a sequence by repeated upsample+conv stages,
    then interpolate to exactly T_out.

    (B, D)->(B, D, 1)->...->(B, C_out, T)
    '''

    def __init__(self,
                 D_in,
                 C_out,
                 T_out,
                 min_layer_num=None
                 ):
        super(SeqDecoder1D, self).__init__()
        self.T_out = T_out
        self.min_layer_num = min_layer_num
        # First conv maps the latent channels onto C_out at length 1.
        self.pre_conv = ConvNormRelu(in_channels=D_in, out_channels=C_out, type='1d')
        self.num_layers = 1
        self.upconv = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv_layers = nn.ModuleList([])
        # One upsample+conv stage per doubling of the length up to T_out.
        length = 2
        while length <= T_out:
            self.conv_layers.append(
                ConvNormRelu(in_channels=C_out, out_channels=C_out, type='1d'))
            length *= 2
            self.num_layers += 1
        # Post-processing convs; padded out to min_layer_num if requested.
        post_conv = nn.ModuleList(
            [ConvNormRelu(in_channels=C_out, out_channels=C_out, type='1d')])
        self.num_layers += 1
        while self.min_layer_num is not None and self.num_layers < self.min_layer_num:
            post_conv.append(
                ConvNormRelu(in_channels=C_out, out_channels=C_out, type='1d'))
            self.num_layers += 1
        self.post_conv = nn.Sequential(*post_conv)

    def forward(self, x):
        x = self.pre_conv(x.unsqueeze(-1))
        for layer in self.conv_layers:
            x = layer(self.upconv(x))
        # Snap the length to exactly T_out before the final convs.
        x = torch.nn.functional.interpolate(x, size=self.T_out, mode='nearest')
        return self.post_conv(x)
class SeqDecoderRNN(nn.Module):
    '''
    Autoregressive RNN decoder: rolls out T_out frames, feeding each
    predicted frame back in as the next input.

    (B, D)->(B, C_out, T)
    '''

    def __init__(self,
                 hidden_size,
                 C_out,
                 T_out,
                 num_layers,
                 rnn_cell='gru'
                 ):
        super(SeqDecoderRNN, self).__init__()
        self.num_steps = T_out
        # Dispatch on the requested recurrent cell type.
        rnn_classes = {'gru': nn.GRU, 'lstm': nn.LSTM}
        if rnn_cell not in rnn_classes:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))
        self.cell = rnn_classes[rnn_cell](
            input_size=C_out,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, hidden, frame_0):
        # frame_0: (B, C_out, 1) -> (B, 1, C_out), the seed input frame.
        dec_input = frame_0.permute(0, 2, 1)
        frames = []
        for _ in range(self.num_steps):
            step_out, hidden = self.cell(dec_input, hidden)
            dec_input = self.fc(step_out)
            frames.append(dec_input)
        return torch.cat(frames, dim=1).permute(0, 2, 1)
class SeqTranslator2D(nn.Module):
    '''
    Treats the (C, T) sequence as a one-channel image and translates it with
    a fixed 2D conv stack.

    (B, C, T)->(B, 1, C, T)-> ... -> (B, 1, C_out, T_out)

    NOTE(review): the layer stack below is hard coded for the default
    C_in=64, C_out=108, T_in=75, T_out=25 (the constructor prints a warning);
    other values will break the forward assertion or the final squeeze.
    '''
    def __init__(self,
                 C_in=64,
                 C_out=108,
                 T_in=75,
                 T_out=25,
                 residual=True
                 ):
        super(SeqTranslator2D, self).__init__()
        print("Warning: hard coded")
        self.C_in = C_in
        self.C_out = C_out
        self.T_in = T_in
        self.T_out = T_out
        self.residual = residual
        # Stride tuples progressively reduce the (C, T) "image": the first
        # axis (C) is driven down to 1 so it can be squeezed in forward().
        self.conv_layers = nn.Sequential(
            ConvNormRelu(1, 32, '2d', kernel_size=5, stride=1),
            ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(32, 32, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(32, 64, '2d', kernel_size=5, stride=(4, 3)),
            ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(64, 64, '2d', kernel_size=5, stride=1, residual=self.residual),
            ConvNormRelu(64, 128, '2d', kernel_size=5, stride=(4, 1)),
            ConvNormRelu(128, 108, '2d', kernel_size=3, stride=(4, 1)),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1, residual=self.residual),
            ConvNormRelu(108, 108, '2d', kernel_size=(1, 3), stride=1),
        )

    def forward(self, x):
        # Expect exactly the construction-time (C_in, T_in) geometry.
        assert len(x.shape) == 3 and x.shape[1] == self.C_in and x.shape[2] == self.T_in
        x = x.view(x.shape[0], 1, x.shape[1], x.shape[2])
        x = self.conv_layers(x)
        # Drop the collapsed "C" image axis -> (B, C_out, T_out).
        x = x.squeeze(2)
        return x
class SeqTranslator1D(nn.Module):
    '''
    Channel-wise 1D translation network; temporal length is preserved.

    (B, C, T)->(B, C_out, T)
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=None,
                 stride=None,
                 min_layers_num=None,
                 residual=True,
                 norm='bn'
                 ):
        super(SeqTranslator1D, self).__init__()

        def make_layer(in_channels):
            # All layers share the same hyper-parameters; only the first one
            # maps C_in -> C_out, the rest are C_out -> C_out.
            return ConvNormRelu(
                in_channels=in_channels,
                out_channels=C_out,
                type='1d',
                kernel_size=kernel_size,
                stride=stride,
                residual=residual,
                norm=norm
            )

        layers = nn.ModuleList([make_layer(C_in)])
        self.num_layers = 1
        # Pad the stack until the requested minimum depth is reached.
        while min_layers_num is not None and self.num_layers < min_layers_num:
            layers.append(make_layer(C_out))
            self.num_layers += 1
        self.conv_layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.conv_layers(x)
class SeqTranslatorRNN(nn.Module):
    '''
    Seq2seq RNN translator: encode the whole input sequence, then decode the
    same number of frames autoregressively starting from frame_0.

    (B, C, T)->(B, C_out, T)
    LSTM-FC
    '''

    def __init__(self,
                 C_in,
                 C_out,
                 hidden_size,
                 num_layers,
                 rnn_cell='gru'
                 ):
        super(SeqTranslatorRNN, self).__init__()
        # Dispatch on the requested recurrent cell type.
        rnn_classes = {'gru': nn.GRU, 'lstm': nn.LSTM}
        if rnn_cell not in rnn_classes:
            raise ValueError('invalid rnn cell:%s' % (rnn_cell))
        rnn_class = rnn_classes[rnn_cell]
        self.enc_cell = rnn_class(input_size=C_in, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True,
                                  bidirectional=False)
        self.dec_cell = rnn_class(input_size=C_out, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True,
                                  bidirectional=False)
        self.fc = nn.Linear(hidden_size, C_out)

    def forward(self, x, frame_0):
        # Encode the input; only its final hidden state seeds the decoder.
        num_steps = x.shape[-1]
        _, hidden = self.enc_cell(x.permute(0, 2, 1), None)
        # Decode autoregressively, feeding each output back as the next input.
        dec_input = frame_0.permute(0, 2, 1)
        frames = []
        for _ in range(num_steps):
            step_out, hidden = self.dec_cell(dec_input, hidden)
            dec_input = self.fc(step_out)
            frames.append(dec_input)
        return torch.cat(frames, dim=1).permute(0, 2, 1)
class ResBlock(nn.Module):
    '''
    Fully-connected residual block: a bottlenecked 3-layer MLP added to a
    single-layer linear shortcut.
    '''

    def __init__(self,
                 input_dim,
                 fc_dim,
                 afn,
                 nfn
                 ):
        '''
        afn: activation fn
        nfn: normalization fn
        '''
        super(ResBlock, self).__init__()
        self.input_dim = input_dim
        self.fc_dim = fc_dim
        self.afn = afn
        self.nfn = nfn
        # Only ReLU activation is supported; layer_norm is explicitly rejected.
        if self.afn != 'relu':
            raise ValueError('Wrong')
        if self.nfn == 'layer_norm':
            raise ValueError('wrong')
        half = self.fc_dim // 2
        self.layers = nn.Sequential(
            nn.Linear(self.input_dim, half),
            nn.ReLU(),
            nn.Linear(half, half),
            nn.ReLU(),
            nn.Linear(half, self.fc_dim),
            nn.ReLU()
        )
        self.shortcut_layer = nn.Sequential(
            nn.Linear(self.input_dim, self.fc_dim),
            nn.ReLU(),
        )

    def forward(self, inputs):
        # Deep path plus linear shortcut.
        return self.layers(inputs) + self.shortcut_layer(inputs)
class AudioEncoder(nn.Module):
    '''
    Strided 1D conv stack over audio features: reflection-pad + conv +
    LeakyReLU per stage, with an optional pooling layer after each stage and
    an optional extra conv when augmentation is on.
    '''

    def __init__(self, channels, padding=3, kernel_size=8, conv_stride=2, conv_pool=None, augmentation=False):
        super(AudioEncoder, self).__init__()
        self.in_channels = channels[0]
        self.augmentation = augmentation
        acti = nn.LeakyReLU(0.2)
        model = []
        # One stage per consecutive channel pair; the original duplicated
        # this body in two identical branches that differed only in pooling.
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            model.append(nn.ReflectionPad1d(padding))
            model.append(nn.Conv1d(c_in, c_out, kernel_size=kernel_size, stride=conv_stride))
            model.append(acti)
            if conv_pool is not None:
                model.append(conv_pool(kernel_size=2, stride=2))
        if self.augmentation:
            model.append(
                nn.Conv1d(channels[-1], channels[-1], kernel_size=kernel_size, stride=conv_stride)
            )
            model.append(acti)
        self.model = nn.Sequential(*model)

    def forward(self, x):
        # Keep only the first in_channels feature rows, then encode.
        x = x[:, :self.in_channels, :]
        return self.model(x)
class AudioDecoder(nn.Module):
    '''
    Upsampling 1D conv decoder; the last Upsample pins the length to `ups`
    before the final convolution.
    '''

    def __init__(self, channels, kernel_size=7, ups=25):
        super(AudioDecoder, self).__init__()
        pad = (kernel_size - 1) // 2
        acti = nn.LeakyReLU(0.2)
        model = []
        for i in range(len(channels) - 2):
            model += [
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.ReflectionPad1d(pad),
                nn.Conv1d(channels[i], channels[i + 1],
                          kernel_size=kernel_size, stride=1),
            ]
            # Dropout regularizes only the first two stages.
            if i in (0, 1):
                model.append(nn.Dropout(p=0.2))
            # The original guarded this with `if not i == len(channels) - 2`,
            # which is always true inside this loop, so the activation
            # unconditionally follows every stage.
            model.append(acti)
        model += [
            nn.Upsample(size=ups, mode='nearest'),
            nn.ReflectionPad1d(pad),
            nn.Conv1d(channels[-2], channels[-1],
                      kernel_size=kernel_size, stride=1),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)
class Audio2Pose(nn.Module):
    '''
    Audio-to-pose model: AudioEncoder (13 MFCC-like channels -> 256 feature
    channels) followed by AudioDecoder back to pose_dim channels. With
    augmentation, a pose embedding is concatenated to the audio embedding
    before decoding.
    '''
    def __init__(self, pose_dim, embed_size, augmentation, ups=25):
        super(Audio2Pose, self).__init__()
        self.pose_dim = pose_dim
        self.embed_size = embed_size
        self.augmentation = augmentation
        # Audio branch is hard coded to 13 input channels.
        self.aud_enc = AudioEncoder(channels=[13, 64, 128, 256], padding=2, kernel_size=7, conv_stride=1,
                                    conv_pool=nn.AvgPool1d, augmentation=self.augmentation)
        if self.augmentation:
            # Decoder input = 256 audio channels + 256 pose-embedding channels.
            self.aud_dec = AudioDecoder(channels=[512, 256, 128, pose_dim])
        else:
            self.aud_dec = AudioDecoder(channels=[256, 256, 128, pose_dim], ups=ups)
        if self.augmentation:
            self.pose_enc = nn.Sequential(
                nn.Linear(self.embed_size // 2, 256),
                nn.LayerNorm(256)
            )

    def forward(self, audio_feat, dec_input=None):
        '''
        audio_feat: (B, >=13, T) audio features; dec_input is required (and
        used) only when augmentation is enabled.
        '''
        B = audio_feat.shape[0]
        aud_embed = self.aud_enc.forward(audio_feat)
        if self.augmentation:
            # dec_input is assumed to be (1, B, embed_size // 2) — the squeeze
            # drops the leading dim. TODO(review): confirm against callers.
            dec_input = dec_input.squeeze(0)
            dec_embed = self.pose_enc(dec_input)
            # Broadcast the pose embedding over the audio time axis.
            dec_embed = dec_embed.unsqueeze(2)
            dec_embed = dec_embed.expand(dec_embed.shape[0], dec_embed.shape[1], aud_embed.shape[-1])
            aud_embed = torch.cat([aud_embed, dec_embed], dim=1)
        out = self.aud_dec.forward(aud_embed)
        return out
if __name__ == '__main__':
    # Smoke test: encode a (B, V*C_in, T) pose sequence down to a vector.
    # (Fix: dropped the redundant re-imports of numpy/os/sys, which are
    # already imported at the top of the file, and renamed the local that
    # shadowed the builtin `input`.)
    test_model = SeqEncoder2D(
        C_in=2,
        T_in=25,
        C_out=512,
        num_joints=54,
    )
    print(test_model.num_layers)
    sample = torch.randn((64, 108, 25))
    output = test_model(sample)
    print(output.shape)