import warnings
from typing import Optional, Tuple

import torch
from torch import Tensor, nn

from scaling import (
    Balancer,
    BiasNorm,
    Dropout3,
    FloatLike,
    ScaledConv2d,
    ScaleGrad,
    ScheduledFloat,
    SwooshL,
    SwooshR,
    Whiten,
)


class ConvNeXt(nn.Module):
    """
    Our interpretation of the ConvNeXt module as used in https://arxiv.org/pdf/2206.14747.pdf
    """

    def __init__(
        self,
        channels: int,
        hidden_ratio: int = 3,
        kernel_size: Tuple[int, int] = (7, 7),
        layerdrop_rate: FloatLike = None,
    ):
        super().__init__()
        self.padding = ((kernel_size[0] - 1) // 2, (kernel_size[1] - 1) // 2)
        hidden_channels = channels * hidden_ratio
        if layerdrop_rate is None:
            layerdrop_rate = ScheduledFloat((0.0, 0.2), (20000.0, 0.015))
        self.layerdrop_rate = layerdrop_rate

        self.depthwise_conv = nn.Conv2d(
            in_channels=channels,
            out_channels=channels,
            groups=channels,
            kernel_size=kernel_size,
            padding=self.padding,
        )

        self.pointwise_conv1 = nn.Conv2d(
            in_channels=channels, out_channels=hidden_channels, kernel_size=1
        )

        self.hidden_balancer = Balancer(
            hidden_channels,
            channel_dim=1,
            min_positive=0.3,
            max_positive=1.0,
            min_abs=0.75,
            max_abs=5.0,
        )

        self.activation = SwooshL()
        self.pointwise_conv2 = ScaledConv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=1,
            initial_scale=0.01,
        )

        self.out_balancer = Balancer(
            channels,
            channel_dim=1,
            min_positive=0.4,
            max_positive=0.6,
            min_abs=1.0,
            max_abs=6.0,
        )
        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=5.0,
            prob=(0.025, 0.25),
            grad_scale=0.01,
        )
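
    # Layerdrop: with the default ScheduledFloat above, the probability of
    # skipping this module's residual branch decays from 0.2 early in training
    # to 0.015 after 20k steps (per the ScheduledFloat points above); see
    # forward() below, where a per-sample mask implements the drop.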
    def forward(self, x: Tensor) -> Tensor:
        if torch.jit.is_scripting() or torch.jit.is_tracing() or not self.training:
            return self.forward_internal(x)
        layerdrop_rate = float(self.layerdrop_rate)

        if layerdrop_rate != 0.0:
            batch_size = x.shape[0]
            # Per-sample mask: True keeps the residual branch, False drops it
            # for that sample.
            mask = (
                torch.rand((batch_size, 1, 1, 1), dtype=x.dtype, device=x.device)
                > layerdrop_rate
            )
        else:
            mask = None

        return self.forward_internal(x, mask)

    def forward_internal(
        self, x: Tensor, layer_skip_mask: Optional[Tensor] = None
    ) -> Tensor:
        """
        x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)

        The returned value has the same shape as x.
        """
        bypass = x
        x = self.depthwise_conv(x)
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        if layer_skip_mask is not None:
            x = x * layer_skip_mask

        x = bypass + x
        x = self.out_balancer(x)

        if x.requires_grad:
            x = x.transpose(1, 3)  # move the channel dim last for out_whiten
            x = self.out_whiten(x)
            x = x.transpose(1, 3)

        return x

    def streaming_forward(
        self,
        x: Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
          x layout: (N, C, H, W), i.e. (batch_size, num_channels, num_frames, num_freqs)
          cached_left_pad: (batch_size, num_channels, left_pad, num_freqs)

        Returns:
          - A tensor of shape (batch_size, num_channels, num_frames - left_pad, num_freqs),
            i.e. with padding[0] fewer frames than x.
          - Updated cached_left_pad.
        """
        padding = self.padding

        T = x.size(2) - padding[0]

        bypass = x[:, :, :T, :]

        assert cached_left_pad.size(2) == padding[0], (
            cached_left_pad.size(2),
            padding[0],
        )
        # Prepend the cached left context along the time dimension.
        x = torch.cat([cached_left_pad, x], dim=2)

        # Cache the frames that will serve as left context for the next chunk.
        cached_left_pad = x[:, :, T : padding[0] + T, :]

        # Depthwise conv with no implicit padding in time (the left context
        # comes from the cache); frequency padding is unchanged.
        x = torch.nn.functional.conv2d(
            x,
            weight=self.depthwise_conv.weight,
            bias=self.depthwise_conv.bias,
            padding=(0, padding[1]),
            groups=self.depthwise_conv.groups,
        )
        x = self.pointwise_conv1(x)
        x = self.hidden_balancer(x)
        x = self.activation(x)
        x = self.pointwise_conv2(x)

        x = bypass + x
        return x, cached_left_pad

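
# Minimal usage sketch for ConvNeXt (shapes here are illustrative, not values
# required by this file): in non-streaming mode the module maps (N, C, H, W)
# to the same shape, so it can be stacked after the subsampling convolutions
# defined below.
#
#   m = ConvNeXt(channels=128)
#   y = m(torch.randn(2, 128, 50, 19))   # y.shape == (2, 128, 50, 19)
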
class Conv2dSubsampling(nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Convert an input of shape (N, T, idim) to an output
    with shape (N, T', odim), where
    T' = (T-3)//2 - 2 == (T-7)//2

    It is based on
    https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py  # noqa
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        layer1_channels: int = 8,
        layer2_channels: int = 32,
        layer3_channels: int = 128,
        dropout: FloatLike = 0.1,
    ) -> None:
        """
        Args:
          in_channels:
            Number of input channels. The input shape is (N, T, in_channels).
            Caution: It requires: T >= 7, in_channels >= 7.
          out_channels:
            Output dim. The output shape is (N, (T-7)//2, out_channels).
          layer1_channels:
            Number of channels in layer1.
          layer2_channels:
            Number of channels in layer2.
          layer3_channels:
            Number of channels in layer3.
          dropout:
            Dropout rate applied to the output of this module.
        """
        assert in_channels >= 7
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=layer1_channels,
                kernel_size=3,
                padding=(0, 1),
            ),
            ScaleGrad(0.2),
            Balancer(layer1_channels, channel_dim=1, max_abs=1.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer1_channels,
                out_channels=layer2_channels,
                kernel_size=3,
                stride=2,
                padding=0,
            ),
            Balancer(layer2_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
            nn.Conv2d(
                in_channels=layer2_channels,
                out_channels=layer3_channels,
                kernel_size=3,
                stride=(1, 2),
            ),
            Balancer(layer3_channels, channel_dim=1, max_abs=4.0),
            SwooshR(),
        )

        self.convnext = ConvNeXt(layer3_channels, kernel_size=(7, 7))

        self.out_width = (((in_channels - 1) // 2) - 1) // 2
        self.layer3_channels = layer3_channels

        self.out = nn.Linear(self.out_width * layer3_channels, out_channels)

        self.out_whiten = Whiten(
            num_groups=1,
            whitening_limit=ScheduledFloat((0.0, 4.0), (20000.0, 8.0), default=4.0),
            prob=(0.025, 0.25),
            grad_scale=0.02,
        )

        self.out_norm = BiasNorm(out_channels)
        self.dropout = Dropout3(dropout, shared_dim=1)
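
    # Worked example of the shape arithmetic above, with illustrative numbers
    # in_channels=80 and T=107 input frames:
    #   time: 107 -> 105 (conv1) -> 52 (conv2, stride 2) -> 50 (conv3),
    #         which equals (107 - 7) // 2 = 50;
    #   freq: 80 -> 80 -> 39 -> 19, matching out_width = ((80 - 1) // 2 - 1) // 2 = 19.
    # The ConvNeXt block preserves both dimensions in non-streaming mode.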
    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Subsample x.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in `x`.

        Returns:
          - a tensor of shape (N, (T-7)//2, odim)
          - output lengths, of shape (batch_size,)
        """
        # (N, T, idim) -> (N, 1, T, idim), adding a channel axis for Conv2d
        x = x.unsqueeze(1)

        x = self.conv(x)
        x = self.convnext(x)

        # Now x is of shape (N, layer3_channels, (T-7)//2, out_width)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)
        # Now x is of shape (N, (T-7)//2, layer3_channels * out_width)

        x = self.out(x)
        # Now x is of shape (N, (T-7)//2, odim)
        x = self.out_whiten(x)
        x = self.out_norm(x)
        x = self.dropout(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            x_lens = (x_lens - 7) // 2
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                x_lens = (x_lens - 7) // 2
            assert x.size(1) == x_lens.max().item(), (x.size(1), x_lens.max())

        return x, x_lens

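    # Streaming note: streaming_forward() below feeds the ConvNeXt block without
    # implicit right context, so each chunk yields self.convnext.padding[0] == 3
    # fewer frames than the (T - 7) // 2 formula used in forward().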
    def streaming_forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        cached_left_pad: Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Subsample x in streaming (chunk-by-chunk) mode.

        Args:
          x:
            Its shape is (N, T, idim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in `x`.
          cached_left_pad:
            Cached left padding for the ConvNeXt module, of shape
            (batch_size, num_channels, left_pad, num_freqs).

        Returns:
          - a tensor of shape (N, (T-7)//2 - 3, odim)
          - output lengths, of shape (batch_size,)
          - updated cached_left_pad
        """
        # (N, T, idim) -> (N, 1, T, idim), adding a channel axis for Conv2d
        x = x.unsqueeze(1)

        x = self.conv(x)

        x, cached_left_pad = self.convnext.streaming_forward(
            x, cached_left_pad=cached_left_pad
        )

        # Now x is of shape (N, layer3_channels, (T-7)//2 - 3, out_width)
        b, c, t, f = x.size()

        x = x.transpose(1, 2).reshape(b, t, c * f)

        x = self.out(x)

        x = self.out_norm(x)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            assert self.convnext.padding[0] == 3
            x_lens = (x_lens - 7) // 2 - 3
        else:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                assert self.convnext.padding[0] == 3
                x_lens = (x_lens - 7) // 2 - 3

            assert x.size(1) == x_lens.max().item(), (x.shape, x_lens.max())

        return x, x_lens, cached_left_pad

    @torch.jit.export
    def get_init_states(
        self,
        batch_size: int = 1,
        device: torch.device = torch.device("cpu"),
    ) -> Tensor:
        """Get initial states for Conv2dSubsampling module.
        It is the cached left padding for ConvNeXt module,
        of shape (batch_size, num_channels, left_pad, num_freqs)
        """
        left_pad = self.convnext.padding[0]
        freq = self.out_width
        channels = self.layer3_channels
        cached_embed_left_pad = torch.zeros(batch_size, channels, left_pad, freq).to(
            device
        )

        return cached_embed_left_pad
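

# The block below is a small, self-contained smoke test sketching how the two
# modules are typically wired together; the feature dimension (80) and output
# dimension (192) are illustrative choices, not values mandated by this file.
if __name__ == "__main__":
    batch_size, num_frames, feat_dim, out_dim = 2, 107, 80, 192

    model = Conv2dSubsampling(in_channels=feat_dim, out_channels=out_dim)
    model.eval()

    x = torch.randn(batch_size, num_frames, feat_dim)
    x_lens = torch.full((batch_size,), num_frames, dtype=torch.int64)

    with torch.no_grad():
        y, y_lens = model(x, x_lens)

    # The module subsamples time as T -> (T - 7) // 2.
    assert y.shape == (batch_size, (num_frames - 7) // 2, out_dim), y.shape
    assert torch.all(y_lens == (num_frames - 7) // 2), y_lens

    # Streaming decoding keeps a cached left pad for the ConvNeXt block.
    cached_left_pad = model.get_init_states(batch_size=batch_size)
    print("output:", y.shape, "cached_left_pad:", cached_left_pad.shape)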