| import torch |
| from torch.nn import functional as F |
|
|
|
|
class Stretch2d(torch.nn.Module):
    """Resize the last two axes of a 4-D tensor via ``F.interpolate``.

    Args:
        x_scale: Scale factor for the last (time) axis.
        y_scale: Scale factor for the second-to-last (frequency) axis.
        mode: Interpolation mode handed to ``F.interpolate``.
    """

    def __init__(self, x_scale, y_scale, mode="nearest"):
        super().__init__()
        self.x_scale, self.y_scale, self.mode = x_scale, y_scale, mode

    def forward(self, x):
        """Interpolate ``x`` by the configured scale factors.

        Args:
            x (Tensor): Input tensor (B, C, F, T).

        Returns:
            Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale).
        """
        # F.interpolate expects the factors in axis order: frequency, then time.
        factors = (self.y_scale, self.x_scale)
        return F.interpolate(x, scale_factor=factors, mode=self.mode)
|
|
|
|
class UpsampleNetwork(torch.nn.Module):
    """Upsample features along time with stretch + 2-D convolution stages.

    For every entry of ``upsample_factors`` the network appends, in order:
    a ``Stretch2d`` that stretches the time axis by that factor, a
    single-channel bias-free ``Conv2d``, and (optionally) an activation
    looked up by name in ``torch.nn``.

    Args:
        upsample_factors: Iterable of integer time-axis scale factors.
        nonlinear_activation: Name of a ``torch.nn`` activation class, or
            ``None`` for no activation.
        nonlinear_activation_params: Keyword arguments for the activation
            constructor. ``None`` (the default) means no arguments.
        interpolate_mode: Interpolation mode used by each ``Stretch2d``.
        freq_axis_kernel_size: Conv kernel size along the frequency axis;
            must be odd.
        use_causal_conv: If true, each conv is padded by ``2 * scale`` and its
            output is cut back to the input length in ``forward``.

    Raises:
        AssertionError: If ``freq_axis_kernel_size`` is even (checked while
            building each stage).
    """

    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params=None,
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        use_causal_conv=False,
    ):
        super().__init__()
        # Avoid a shared mutable default; None stands for "no extra kwargs".
        if nonlinear_activation_params is None:
            nonlinear_activation_params = {}
        self.use_causal_conv = use_causal_conv
        self.up_layers = torch.nn.ModuleList()
        for scale in upsample_factors:
            # Interpolation layer: stretch time only, frequency axis untouched.
            stretch = Stretch2d(scale, 1, interpolate_mode)
            self.up_layers += [stretch]

            # Conv layer: odd frequency kernel so "same" padding is symmetric.
            assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size."
            freq_axis_padding = (freq_axis_kernel_size - 1) // 2
            kernel_size = (freq_axis_kernel_size, scale * 2 + 1)
            if use_causal_conv:
                # Pad a full receptive field on both sides; forward() drops the
                # trailing frames so each output only sees current/past input.
                padding = (freq_axis_padding, scale * 2)
            else:
                padding = (freq_axis_padding, scale)
            conv = torch.nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
            self.up_layers += [conv]

            # Optional nonlinearity, resolved by class name from torch.nn.
            if nonlinear_activation is not None:
                nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
                self.up_layers += [nonlinear]

    def forward(self, c):
        """Upsample the conditioning features.

        Args:
            c (Tensor): Input tensor (B, C, T_in).

        Returns:
            Tensor: Upsampled tensor (B, C, T_upsample).
        """
        c = c.unsqueeze(1)  # (B, 1, C, T): the convs operate on one channel
        for layer in self.up_layers:
            if self.use_causal_conv and isinstance(layer, torch.nn.Conv2d):
                # The causal conv is padded by 2 * scale on both sides, so its
                # output is 2 * scale frames longer than its input. Keep only
                # the first T frames to restore the length and stay causal.
                frames = c.size(-1)
                c = layer(c)[..., :frames]
            else:
                c = layer(c)
        return c.squeeze(1)
|
|
|
|
class ConvUpsample(torch.nn.Module):
    """Context-capturing 1-D convolution followed by an ``UpsampleNetwork``.

    Args:
        upsample_factors: Time-axis scale factors passed to ``UpsampleNetwork``.
        nonlinear_activation: Activation class name passed to ``UpsampleNetwork``.
        nonlinear_activation_params: Activation kwargs passed to ``UpsampleNetwork``.
        interpolate_mode: Interpolation mode passed to ``UpsampleNetwork``.
        freq_axis_kernel_size: Frequency kernel size passed to ``UpsampleNetwork``.
        aux_channels: Number of input and output channels of the input conv.
        aux_context_window: Context size; sets the input conv kernel size.
        use_causal_conv: Selects the causal kernel size and is passed on to
            ``UpsampleNetwork``.
    """

    def __init__(
        self,
        upsample_factors,
        nonlinear_activation=None,
        nonlinear_activation_params={},
        interpolate_mode="nearest",
        freq_axis_kernel_size=1,
        aux_channels=80,
        aux_context_window=0,
        use_causal_conv=False,
    ):
        super().__init__()
        self.aux_context_window = aux_context_window
        # Trimming in forward() is only enabled for a non-zero context window,
        # which also keeps the negative slice there from being ``:-0``.
        self.use_causal_conv = use_causal_conv and aux_context_window > 0

        # Kernel covers the window on one side (causal) or both sides.
        if use_causal_conv:
            kernel_size = aux_context_window + 1
        else:
            kernel_size = 2 * aux_context_window + 1

        # No padding is given here, so the conv shortens the time axis.
        # NOTE(review): presumably the caller pads the input by the context
        # window beforehand -- confirm against the call site.
        self.conv_in = torch.nn.Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False)

        upsample_kwargs = dict(
            upsample_factors=upsample_factors,
            nonlinear_activation=nonlinear_activation,
            nonlinear_activation_params=nonlinear_activation_params,
            interpolate_mode=interpolate_mode,
            freq_axis_kernel_size=freq_axis_kernel_size,
            use_causal_conv=use_causal_conv,
        )
        self.upsample = UpsampleNetwork(**upsample_kwargs)

    def forward(self, c):
        """Apply the input convolution, then upsample.

        Args:
            c (Tensor): Input tensor (B, C, T_in).

        Returns:
            Tensor: Upsampled tensor (B, C, T_upsampled).
        """
        features = self.conv_in(c)
        if self.use_causal_conv:
            # Drop the last ``aux_context_window`` frames in causal mode.
            features = features[:, :, : -self.aux_context_window]
        return self.upsample(features)
|
|