File size: 5,651 Bytes
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List
import torch
import torch.nn as nn
from chunk_unity.modules.chunk_causal_conv1d import ChunkCausalConv1d
class Conv1dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 1D convolution (along temporal
    dimension) followed by non-linear activation via gated linear units
    (https://arxiv.org/abs/1911.08460)

    Every convolution uses stride 2, and each is followed by a GLU over the
    channel axis (which halves the channel count).

    Args:
        in_channels (int): the number of input channels
        mid_channels (int): the number of intermediate channels
        out_channels (int): the number of output channels
        kernel_sizes (List[int]): the kernel size for each convolutional layer
        chunk_size: when ``None``, plain ``nn.Conv1d`` layers with padding
            ``k // 2`` are used; otherwise ``ChunkCausalConv1d`` layers are
            built and receive this value as their ``chunk_size`` argument
            (its exact meaning is defined by that module, not here).
    """

    def __init__(
        self,
        in_channels: int,
        mid_channels: int,
        out_channels: int,
        kernel_sizes: List[int] = (3, 3),
        chunk_size=None,
    ):
        super().__init__()
        self.n_layers = len(kernel_sizes)

        def make_conv(i, k):
            # GLU halves the channel dim, so every layer after the first reads
            # mid_channels // 2; the last layer emits 2 * out_channels so the
            # final GLU leaves exactly out_channels.
            in_ch = in_channels if i == 0 else mid_channels // 2
            out_ch = mid_channels if i < self.n_layers - 1 else out_channels * 2
            if chunk_size is None:
                return nn.Conv1d(in_ch, out_ch, k, stride=2, padding=k // 2)
            return ChunkCausalConv1d(
                in_ch, out_ch, k, stride=2, chunk_size=chunk_size
            )

        self.conv_layers = nn.ModuleList(
            make_conv(i, k) for i, k in enumerate(kernel_sizes)
        )

    def build_causal_conv(self, conv):
        """Replace ``conv.weight`` with a copy whose trailing kernel taps are fixed.

        Kernel positions ``[0, K // 2]`` (K = last weight dim) keep their
        current values; positions ``K // 2 + 1`` onward are overwritten with
        the constant ``1e-2``. The result is installed as a fresh
        ``nn.Parameter`` on ``conv``; nothing is returned.

        NOTE(review): the masked taps are set to 1e-2 rather than 0, and no
        mask is kept afterwards, so training can still update them. Confirm
        this is the intended "causal" behavior.
        """
        kernel_len = conv.weight.data.size(-1)
        # Constant value written into the masked (trailing) taps.
        fill = torch.zeros_like(conv.weight.data) + 1e-2
        fill.requires_grad_(False)
        # 1 = keep the learned tap, 0 = replace it with `fill`.
        keep = torch.ones_like(conv.weight.data)
        keep.requires_grad_(False)
        keep[:, :, kernel_len // 2 + 1 :] = 0
        blended = (1 - keep) * fill + keep * conv.weight
        conv.weight = torch.nn.Parameter(blended)

    # Backward-compatible alias for the original (misspelled) method name.
    bulid_causal_conv = build_causal_conv

    def get_out_seq_lens_tensor(self, in_seq_lens_tensor):
        """Return the output lengths after ``self.n_layers`` stride-2 layers.

        Each layer maps a length L to ``floor((L - 1) / 2 + 1)``. That matches
        ``nn.Conv1d`` with stride 2, padding ``k // 2`` and odd ``k``. Whether
        it matches ``ChunkCausalConv1d`` is not visible here — TODO confirm.
        The result is a LongTensor whenever at least one layer exists.
        """
        out = in_seq_lens_tensor.clone()
        for _ in range(self.n_layers):
            out = ((out.float() - 1) / 2 + 1).floor().long()
        return out

    def forward(self, src_tokens, src_lengths):
        """Subsample ``src_tokens`` along time.

        Args:
            src_tokens: B x T x (C x D) input features.
            src_lengths: per-utterance input lengths.

        Returns:
            A tuple ``(x, lengths)``: ``x`` is T' x B x out_channels, and
            ``lengths`` comes from :meth:`get_out_seq_lens_tensor`.
        """
        x = src_tokens.transpose(1, 2).contiguous()  # -> B x (C x D) x T
        for conv in self.conv_layers:
            x = nn.functional.glu(conv(x), dim=1)
        x = x.permute(2, 0, 1).contiguous()  # B x C' x T' -> T' x B x C'
        return x, self.get_out_seq_lens_tensor(src_lengths)
def infer_conv_output_dim(in_channels, input_dim, out_channels):
    """Return the flattened per-frame size after two (kernel 3, stride 2, padding 1) Conv2d layers.

    This mirrors the conv stack of ``Conv2dSubsampler``. Each layer maps a
    feature-axis size ``n`` to ``(n + 2 * padding - kernel) // stride + 1``.
    The flattened size is ``out_channels * reduced_feature_size``.

    The size is computed in closed form rather than by running a dummy
    forward pass through freshly created conv layers. That avoids the
    allocation and compute, and does not advance the global torch RNG.

    Args:
        in_channels: number of input channels. It does not affect the result
            and is kept only for API compatibility.
        input_dim: size of the feature axis fed to the first convolution.
        out_channels: number of output channels of both convolutions.

    Returns:
        ``out_channels`` times the feature-axis size after both layers.

    Raises:
        ValueError: if ``input_dim`` is smaller than 1.
    """
    kernel_size, stride, padding = 3, 2, 3 // 2
    if input_dim < 1:
        raise ValueError(f"input_dim must be >= 1, got {input_dim}")
    feat = input_dim
    for _ in range(2):
        feat = (feat + 2 * padding - kernel_size) // stride + 1
    return out_channels * feat
class Conv2dSubsampler(nn.Module):
    """Convolutional subsampler: a stack of 2D convolution based on ESPnet implementation
    (https://github.com/espnet/espnet)

    Two Conv2d layers (kernel 3, stride 2, padding 1), each followed by ReLU,
    shrink both the time axis and the feature axis. A linear layer then
    projects the flattened (channels x reduced features) vector of every
    output frame to ``encoder_embed_dim``.

    Args:
        input_channels (int): the number of input channels (must be 1)
        input_feat_per_channel (int): encoder input dimension per input channel
        conv_out_channels (int): the number of output channels of conv layer
        encoder_embed_dim (int): encoder dimensions
    """

    def __init__(
        self,
        input_channels: int,
        input_feat_per_channel: int,
        conv_out_channels: int,
        encoder_embed_dim: int,
    ):
        super().__init__()
        # forward() reshapes the input to exactly one channel.
        assert input_channels == 1, input_channels
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(
                input_channels, conv_out_channels, 3, stride=2, padding=3 // 2
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                conv_out_channels,
                conv_out_channels,
                3,
                stride=2,
                padding=3 // 2,
            ),
            torch.nn.ReLU(),
        )
        # Compute the flattened per-frame size analytically from the layers
        # actually built above. No dummy forward pass is run, so the global
        # RNG is untouched.
        feat_out = self._conv_out_size(input_feat_per_channel, dim=1)
        if feat_out < 1:
            raise ValueError(
                f"input_feat_per_channel={input_feat_per_channel} is too small "
                "for the convolution stack"
            )
        self.out = torch.nn.Linear(conv_out_channels * feat_out, encoder_embed_dim)

    def _conv_out_size(self, size, dim):
        """Propagate ``size`` through every Conv2d in ``self.conv`` along kernel axis ``dim``.

        The conv input is B x 1 x T x C, so ``dim=0`` is the time axis and
        ``dim=1`` is the feature axis. ``size`` may be a Python int or an
        integer tensor of lengths. Each layer applies
        ``floor((size + 2*padding - dilation*(kernel-1) - 1) / stride) + 1``.
        """
        for layer in self.conv:
            if not isinstance(layer, torch.nn.Conv2d):
                continue
            numer = (
                size
                + 2 * layer.padding[dim]
                - layer.dilation[dim] * (layer.kernel_size[dim] - 1)
                - 1
            )
            if torch.is_tensor(numer):
                size = torch.div(numer, layer.stride[dim], rounding_mode="floor") + 1
            else:
                size = numer // layer.stride[dim] + 1
        return size

    def forward(self, src_tokens, src_lengths):
        """Subsample ``src_tokens`` (B x T x C) and project it to the encoder dimension.

        Returns:
            A tuple ``(x, lengths)``. ``x`` is T_o x B x encoder_embed_dim.
            ``lengths`` is a LongTensor holding the number of output frames
            produced from each utterance's valid input frames, capped at T_o.
        """
        B, T_i, C = src_tokens.size()
        x = src_tokens.view(B, T_i, 1, C).transpose(1, 2).contiguous()
        x = self.conv(x)
        B, _, T_o, _ = x.size()
        x = x.transpose(1, 2).transpose(0, 1).contiguous().view(T_o, B, -1)
        x = self.out(x)
        # Push each length through the exact conv size formula (ceil(L / 4)
        # for the two stride-2 layers). The previous heuristic divided by a
        # rounded T_i / T_o, which differs from 4 for short batches (e.g.
        # T_i=9 -> T_o=3 -> factor 3) and overcounted shorter utterances.
        input_lengths = self._conv_out_size(src_lengths.long(), dim=0).clamp(max=T_o)
        return x, input_lengths
|