# (removed: "Spaces: Running on Zero" — Hugging Face Spaces page-scrape artifact, not part of the source)
| from typing import Callable, Sequence, Type, Union | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
# Factory for an nn.Module: either the module class itself or any
# zero-argument callable that returns a module instance.
ModuleFactory = Union[Type[nn.Module], Callable[[], nn.Module]]
class FeedForwardModule(nn.Module):
    """Base class whose forward pass simply applies ``self.net`` to the input.

    Subclasses are expected to assign a callable module to ``self.net`` in
    their own ``__init__``; until then ``self.net`` is ``None`` and calling
    ``forward`` would fail.
    """

    def __init__(self) -> None:
        super().__init__()
        # Placeholder — concrete subclasses overwrite this with a real module.
        self.net = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.net(x)
        return out
class Residual(nn.Module):
    """Skip connection: ``forward(x) = module(x) + x``."""

    def __init__(self, module: nn.Module) -> None:
        super().__init__()
        self.module = module

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        branch = self.module(x)
        return branch + x
class DilatedConvolutionalUnit(FeedForwardModule):
    """Residual-branch body: activation, dilated conv, activation, 1x1 conv.

    All convolutions keep ``hidden_dim`` channels; the dilated convolution is
    padded so the sequence length is preserved (for odd kernel sizes).
    """

    def __init__(
            self,
            hidden_dim: int,
            dilation: int,
            kernel_size: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # "same"-length padding for the dilated convolution.
        same_padding = ((kernel_size - 1) * dilation) // 2
        dilated = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=kernel_size,
            dilation=dilation,
            padding=same_padding,
        )
        pointwise = nn.Conv1d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=1,
        )
        self.net = nn.Sequential(
            activation(),
            normalization(dilated),
            activation(),
            pointwise,
        )
class UpsamplingUnit(FeedForwardModule):
    """Activation followed by a transposed 1-d convolution that upsamples
    the sequence length by ``stride`` (kernel size ``2 * stride``)."""

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # padding/output_padding are chosen so the output length is exactly
        # ``stride`` times the input length for both even and odd strides.
        transposed = nn.ConvTranspose1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
            output_padding=stride % 2,
        )
        self.net = nn.Sequential(activation(), normalization(transposed))
class DownsamplingUnit(FeedForwardModule):
    """Activation followed by a strided 1-d convolution that downsamples
    the sequence length by ``stride`` (kernel size ``2 * stride``)."""

    def __init__(
            self,
            input_dim: int,
            output_dim: int,
            stride: int,
            activation: ModuleFactory,
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Padding mirrors UpsamplingUnit so the length shrinks by exactly
        # a factor of ``stride`` for both even and odd strides.
        strided = nn.Conv1d(
            in_channels=input_dim,
            out_channels=output_dim,
            kernel_size=2 * stride,
            stride=stride,
            padding=stride // 2 + stride % 2,
        )
        self.net = nn.Sequential(activation(), normalization(strided))
class DilatedResidualEncoder(FeedForwardModule):
    """Strided encoder: per stage, a stack of dilated residual units followed
    by a downsampling unit, with the channel count doubling at each stage.

    Args:
        capacity: base channel count; stage ``i`` uses ``capacity * 2**i``.
        dilated_unit: factory ``(hidden_dim, dilation) -> nn.Module``.
        downsampling_unit: factory ``(in_dim, out_dim, stride) -> nn.Module``.
        ratios: downsampling stride for each stage.
        dilations: either one flat dilation list shared by every stage, or one
            list of dilations per stage.
        pre_network_conv: conv factory applied first; ``out_channels`` is
            supplied here, all other arguments are expected to be pre-bound.
        post_network_conv: conv factory applied last; ``in_channels`` is
            supplied here.
        normalization: wrapper applied to the pre-network conv
            (default: identity).
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            downsampling_unit: Type[DownsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        # Channel progression: capacity, 2*capacity, 4*capacity, ...
        channels = capacity * 2**np.arange(len(ratios) + 1)
        dilations_list = self.normalize_dilations(dilations, ratios)

        net = [normalization(pre_network_conv(out_channels=channels[0]))]
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(input_dim, dilation)))
            net.append(downsampling_unit(input_dim, output_dim, ratio))
        net.append(post_network_conv(in_channels=output_dim))
        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Broadcast a flat dilation list into one list per stage.

        BUG FIX: this must be a ``@staticmethod``. The original plain
        ``def`` (no ``self``) was invoked as ``self.normalize_dilations(...)``,
        which bound ``self`` as ``dilations`` and raised a TypeError.
        """
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations
class DilatedResidualDecoder(FeedForwardModule):
    """Mirror of ``DilatedResidualEncoder``: per stage, an upsampling unit
    followed by a stack of dilated residual units, with the channel count
    halving at each stage (widest channels first).

    Args:
        capacity: base channel count; the decoder starts at
            ``capacity * 2**len(ratios)`` and narrows down to ``capacity``.
        dilated_unit: factory ``(hidden_dim, dilation) -> nn.Module``.
        upsampling_unit: factory ``(in_dim, out_dim, stride) -> nn.Module``.
        ratios: upsampling stride for each stage.
        dilations: either one flat dilation list shared by every stage, or one
            list of dilations per stage.
        pre_network_conv: conv factory applied first; ``out_channels`` is
            supplied here.
        post_network_conv: conv factory applied last; ``in_channels`` is
            supplied here.
        normalization: wrapper applied to the post-network conv
            (default: identity).
    """

    def __init__(
            self,
            capacity: int,
            dilated_unit: Type[DilatedConvolutionalUnit],
            upsampling_unit: Type[UpsamplingUnit],
            ratios: Sequence[int],
            dilations: Union[Sequence[int], Sequence[Sequence[int]]],
            pre_network_conv: Type[nn.Conv1d],
            post_network_conv: Type[nn.Conv1d],
            normalization: Callable[[nn.Module],
                                    nn.Module] = lambda x: x) -> None:
        super().__init__()
        channels = capacity * 2**np.arange(len(ratios) + 1)
        channels = channels[::-1]  # widest first, narrowing toward the output
        dilations_list = self.normalize_dilations(dilations, ratios)
        dilations_list = dilations_list[::-1]
        # NOTE(review): channels and dilations are reversed here but ``ratios``
        # is not — presumably callers pass ratios already ordered for the
        # decoder; confirm against call sites.

        net = [pre_network_conv(out_channels=channels[0])]
        for ratio, stage_dilations, input_dim, output_dim in zip(
                ratios, dilations_list, channels[:-1], channels[1:]):
            net.append(upsampling_unit(input_dim, output_dim, ratio))
            for dilation in stage_dilations:
                net.append(Residual(dilated_unit(output_dim, dilation)))
        net.append(normalization(post_network_conv(in_channels=output_dim)))
        self.net = nn.Sequential(*net)

    @staticmethod
    def normalize_dilations(dilations: Union[Sequence[int],
                                             Sequence[Sequence[int]]],
                            ratios: Sequence[int]):
        """Broadcast a flat dilation list into one list per stage.

        BUG FIX: this must be a ``@staticmethod`` — see the identical fix in
        ``DilatedResidualEncoder``; the plain ``def`` bound ``self`` as
        ``dilations`` when called via ``self.normalize_dilations(...)``.
        """
        if isinstance(dilations[0], int):
            dilations = [dilations for _ in ratios]
        return dilations