|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolutions (along the temporal
    dimension), each followed by a gated linear unit (GLU) activation
    (https://arxiv.org/abs/1911.08460)

    Every convolution uses stride 2 and padding ``k // 2``, so each layer
    roughly halves the sequence length. The GLU halves the channel count,
    which is why the layers emit ``mid_channels`` / ``out_channels * 2``
    channels and consume ``mid_channels // 2``.

    Args:
        in_channels (int): the number of input channels
        mid_channels (int): the number of intermediate channels
        out_channels (int): the number of output channels
        kernel_sizes (List[int]): the kernel size for each convolutional layer
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        kernel_sizes: List[int] = (3, 3),
    ):
        super(Conv1dSubsampler, self).__init__()
        self.n_layers = len(kernel_sizes)
        self.conv_layers = nn.ModuleList(
            nn.Conv1d(
                # GLU halves the channels produced by the previous layer.
                in_channels if i == 0 else mid_channels // 2,
                # The last layer emits 2x channels so GLU yields out_channels.
                mid_channels if i < self.n_layers - 1 else out_channels * 2,
                k,
                stride=2,
                padding=k // 2,
            )
            for i, k in enumerate(kernel_sizes)
        )

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Map input sequence lengths to their lengths after all conv layers.

        The length is propagated through every layer with the standard
        convolution output-size formula, using that layer's own kernel size,
        padding, stride and dilation. This keeps the reported lengths
        consistent with the tensor returned by ``forward`` for even kernel
        sizes as well as odd ones.

        Args:
            in_seq_lens_tensor (Tensor): integer tensor of input lengths

        Returns:
            Tensor: long tensor of the same shape holding the output lengths
        """
        out = in_seq_lens_tensor.clone()
        for conv in self.conv_layers:
            kernel = conv.kernel_size[0]
            # L_out = floor((L_in + 2*pad - dilation*(kernel-1) - 1) / stride + 1)
            offset = 2 * conv.padding[0] - conv.dilation[0] * (kernel - 1) - 1
            out = ((out.float() + offset) / conv.stride[0] + 1).floor().long()
        return out

    def forward(self, src_tokens, src_lengths):
        """Subsample ``src_tokens`` along the time axis.

        Args:
            src_tokens (Tensor): input features, laid out as
                (batch, time, channels)
            src_lengths (Tensor): length of each sequence in the batch

        Returns:
            Tuple[Tensor, Tensor]: the subsampled features laid out as
            (time, batch, channels) and the corresponding output lengths
        """
        # (B, T, C) -> (B, C, T): Conv1d convolves over the last dimension.
        x = src_tokens.transpose(1, 2).contiguous()
        for conv in self.conv_layers:
            x = conv(x)
            x = nn.functional.glu(x, dim=1)
        # (B, C, T') -> (T', B, C)
        x = x.transpose(1, 2).transpose(0, 1).contiguous()
        return x, self.get_out_seq_lens_tensor(src_lengths)
|
|
|
|
|
|
|
|
def infer_conv_output_dim(in_channels, input_dim, out_channels):
    """Return the flattened per-frame feature size after two stride-2 convs.

    A random probe batch is pushed through two freshly constructed
    ``Conv2d(kernel=3, stride=2, padding=1)`` layers, and the size of the
    merged (channel, feature) axis of the result is returned.

    Args:
        in_channels (int): number of channels of the probe input
        input_dim (int): feature dimension per channel of the probe input
        out_channels (int): number of output channels of both conv layers

    Returns:
        int: size of the last axis once channels and features are flattened
    """
    probe_batch, probe_frames = 10, 200
    probe = torch.randn(probe_batch, in_channels, probe_frames, input_dim)
    # First layer maps in_channels -> out_channels, second keeps out_channels.
    for layer_in in (in_channels, out_channels):
        layer = torch.nn.Conv2d(
            layer_in, out_channels, 3, stride=2, padding=3 // 2
        )
        probe = layer(probe)
    # (B, C, T, F) -> (B, T, C, F), then merge C and F into one axis.
    probe = probe.transpose(1, 2)
    n_batch, n_frames = probe.shape[0], probe.shape[1]
    return probe.reshape(n_batch, n_frames, -1).size(-1)
|
|
|
|
|
|
|
|
class Conv2dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 2D convolution based on ESPnet implementation
    (https://github.com/espnet/espnet)

    Two ``Conv2d(kernel=3, stride=2, padding=1)`` + ReLU layers are applied
    over the (time, feature) plane, followed by a linear projection of the
    flattened (channel, feature) axis to the encoder dimension.

    Args:
        input_channels (int): the number of input channels
        input_feat_per_channel (int): encoder input dimension per input channel
        conv_out_channels (int): the number of output channels of conv layer
        encoder_embed_dim (int): encoder dimensions
    """

    def __init__(
        self,
        input_channels: int,
        input_feat_per_channel: int,
        conv_out_channels: int,
        encoder_embed_dim: int,
    ):
        super().__init__()
        # forward() reshapes the input to exactly one channel.
        assert input_channels == 1, input_channels
        conv_kwargs = dict(kernel_size=3, stride=2, padding=3 // 2)
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(input_channels, conv_out_channels, **conv_kwargs),
            torch.nn.ReLU(),
            torch.nn.Conv2d(conv_out_channels, conv_out_channels, **conv_kwargs),
            torch.nn.ReLU(),
        )
        # Size of the flattened (channel, feature) axis produced by self.conv.
        flat_dim = infer_conv_output_dim(
            input_channels, input_feat_per_channel, conv_out_channels
        )
        self.out = torch.nn.Linear(flat_dim, encoder_embed_dim)

    def forward(self, src_tokens, src_lengths):
        """Subsample ``src_tokens`` in time and project to the encoder dim.

        Args:
            src_tokens (Tensor): input features of shape (B, T_in, C)
            src_lengths (Tensor): length of each sequence in the batch

        Returns:
            Tuple[Tensor, Tensor]: features of shape (T_out, B, embed_dim)
            and the estimated output length of every sequence (long tensor)
        """
        batch, in_frames, feat_dim = src_tokens.size()
        # (B, T_in, C) -> (B, 1, T_in, C): add the single conv input channel.
        feats = (
            src_tokens.view(batch, in_frames, 1, feat_dim)
            .transpose(1, 2)
            .contiguous()
        )
        feats = self.conv(feats)
        batch, _, out_frames, _ = feats.size()
        # (B, C', T_out, F) -> (T_out, B, C' * F)
        feats = feats.permute(2, 0, 1, 3).contiguous().view(out_frames, batch, -1)
        feats = self.out(feats)

        # Overall time reduction ratio, rounded half up to an integer.
        subsampling_factor = int(in_frames / out_frames + 0.5)
        scaled_lens = (src_lengths.float() / subsampling_factor).ceil().long()
        # An estimated length can never exceed the actual number of frames.
        input_lengths = scaled_lens.clamp(max=feats.size(0))
        return feats, input_lengths
|
|
|