# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import List

import torch
import torch.nn as nn


class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolution (along temporal
    dimension) followed by non-linear activation via gated linear units
    (https://arxiv.org/abs/1911.08460)

    Every layer uses stride 2, so the time axis is roughly halved per layer.

    Args:
        in_channels (int): the number of input channels
        mid_channels (int): the number of intermediate channels
        out_channels (int): the number of output channels
        kernel_sizes (List[int]): the kernel size for each convolutional layer
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        kernel_sizes: List[int] = (3, 3),
    ):
        super(Conv1dSubsampler, self).__init__()
        self.n_layers = len(kernel_sizes)
        # GLU halves the channel dimension: layers after the first therefore
        # take mid_channels // 2 inputs, and the last layer emits
        # out_channels * 2 so that the final GLU yields out_channels.
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                in_channels if i == 0 else mid_channels // 2,
                mid_channels if i < self.n_layers - 1 else out_channels * 2,
                k,
                stride=2,
                padding=k // 2,
            )
            for i, k in enumerate(kernel_sizes)
        )

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Map input sequence lengths to lengths after the conv stack.

        Applies the Conv1d output-length formula
        ``floor((L + 2 * padding - kernel) / stride) + 1`` with each layer's
        own hyper-parameters, so the result matches the real output length
        for even as well as odd kernel sizes.

        Args:
            in_seq_lens_tensor (LongTensor): input lengths (any shape)

        Returns:
            LongTensor of the same shape holding the subsampled lengths.
        """
        out = in_seq_lens_tensor.clone()
        for conv in self.conv_layers:
            k, s, p = conv.kernel_size[0], conv.stride[0], conv.padding[0]
            out = ((out.float() + 2 * p - k) / s + 1).floor().long()
        return out

    def forward(self, src_tokens, src_lengths):
        """Subsample a batch of feature sequences.

        Args:
            src_tokens: B x T x (C x D) input features
            src_lengths: per-sequence input lengths

        Returns:
            tuple of (T' x B x out_channels features, subsampled lengths)
        """
        if src_tokens.dim() != 3:
            raise ValueError(
                f"expected a 3-D B x T x (C x D) input, got {src_tokens.dim()}-D"
            )
        x = src_tokens.transpose(1, 2).contiguous()  # -> B x (C x D) x T
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # -> T' x B x C_out
        return x, self.get_out_seq_lens_tensor(src_lengths)


def infer_conv_output_dim(in_channels, input_dim, out_channels):
    """Return the flattened per-frame feature size after Conv2dSubsampler's convs.

    The conv stack is two ``Conv2d(kernel_size=3, stride=2, padding=1)``
    layers whose output is flattened to ``out_channels * D'`` per time step,
    where ``D'`` is the feature axis after both layers. ``D'`` is derived
    with the convolution output-length formula instead of a dummy forward
    pass, so no tensors are allocated and the global RNG is not consumed.

    Args:
        in_channels (int): number of input channels (does not affect the
            result; kept for API compatibility)
        input_dim (int): feature dimension per input channel
        out_channels (int): number of output channels of the conv layers

    Returns:
        int: ``out_channels * D'``

    Raises:
        ValueError: if ``input_dim`` is not positive.
    """
    if input_dim < 1:
        raise ValueError(f"input_dim must be positive, got {input_dim}")
    kernel_size, stride, padding = 3, 2, 3 // 2
    dim = input_dim
    for _ in range(2):
        dim = (dim + 2 * padding - kernel_size) // stride + 1
    return out_channels * dim


class Conv2dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 2D convolution based on ESPnet
    implementation (https://github.com/espnet/espnet)

    Two stride-2 convolutions reduce both the time and feature axes; the
    flattened result is projected to ``encoder_embed_dim``.

    Args:
        input_channels (int): the number of input channels (must be 1)
        input_feat_per_channel (int): encoder input dimension per input channel
        conv_out_channels (int): the number of output channels of conv layer
        encoder_embed_dim (int): encoder dimensions
    """

    def __init__(
        self,
        input_channels: int,
        input_feat_per_channel: int,
        conv_out_channels: int,
        encoder_embed_dim: int,
    ):
        super().__init__()
        assert input_channels == 1, input_channels
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                input_channels, conv_out_channels, 3, stride=2, padding=3 // 2
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                conv_out_channels,
                conv_out_channels,
                3,
                stride=2,
                padding=3 // 2,
            ),
            torch.nn.ReLU(),
        )
        transformer_input_dim = infer_conv_output_dim(
            input_channels, input_feat_per_channel, conv_out_channels
        )
        self.out = torch.nn.Linear(transformer_input_dim, encoder_embed_dim)

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Map input sequence lengths to lengths after the conv stack.

        Applies ``floor((L + 2 * padding - kernel) / stride) + 1`` for every
        Conv2d in ``self.conv`` along the time axis (index 0 of the layer's
        kernel/stride/padding, i.e. the height of the B x 1 x T x C input).

        Args:
            in_seq_lens_tensor (LongTensor): input lengths (any shape)

        Returns:
            LongTensor of the same shape holding the subsampled lengths.
        """
        out = in_seq_lens_tensor.clone()
        for layer in self.conv:
            if isinstance(layer, nn.Conv2d):
                k, s, p = layer.kernel_size[0], layer.stride[0], layer.padding[0]
                out = ((out.float() + 2 * p - k) / s + 1).floor().long()
        return out

    def forward(self, src_tokens, src_lengths):
        """Subsample and project a batch of feature sequences.

        Args:
            src_tokens: B x T x C input features
            src_lengths: per-sequence input lengths

        Returns:
            tuple of (T_o x B x encoder_embed_dim features, subsampled lengths)
        """
        B, T_i, C = src_tokens.size()
        # -> B x 1 x T_i x C (a single input channel is enforced in __init__)
        x = src_tokens.view(B, T_i, 1, C).transpose(1, 2).contiguous()
        x = self.conv(x)
        B, _, T_o, _ = x.size()
        # -> T_o x B x (conv_out_channels * C')
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(T_o, B, -1)
        x = self.out(x)
        # Exact per-sequence lengths from the conv arithmetic; the clamp keeps
        # them within the padded output length even if a length exceeds T_i.
        input_lengths = self.get_out_seq_lens_tensor(src_lengths).clamp(max=T_o)
        return x, input_lengths