Spaces:
Runtime error
Runtime error
| # This code is based on the following repository written by Christian J. Steinmetz | |
| # https://github.com/csteinmetz1/micro-tcn | |
| from typing import Callable | |
| import torch | |
| import torch.nn as nn | |
| from torch import Tensor | |
| from remfx.utils import causal_crop, center_crop | |
class TCNBlock(nn.Module):
    """One temporal-convolution block: a dilated Conv1d followed by a
    channel-wise PReLU, plus a 1x1-conv residual path that is cropped
    (via ``crop_fn``) to the conv output's length before being added.
    """

    def __init__(
        self,
        in_ch: int,
        out_ch: int,
        kernel_size: int = 3,
        dilation: int = 1,
        stride: int = 1,
        crop_fn: Callable = causal_crop,
    ) -> None:
        super().__init__()
        self.in_ch = in_ch
        self.out_ch = out_ch
        self.kernel_size = kernel_size
        self.stride = stride
        self.crop_fn = crop_fn
        # Dilated convolution with no padding: the output is shorter than
        # the input, and the residual branch is cropped to compensate.
        self.conv1 = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            bias=True,
        )
        # 1x1 convolution so the skip path matches the output channel count.
        self.res = nn.Conv1d(
            in_ch,
            out_ch,
            kernel_size=1,
            groups=1,
            stride=stride,
            bias=False,
        )
        self.relu = nn.PReLU(out_ch)

    def forward(self, x: Tensor) -> Tensor:
        """Run conv + PReLU and add the length-matched residual."""
        skip = self.res(x)
        out = self.relu(self.conv1(x))
        # Crop the skip branch so its length equals the (shorter) conv output.
        return out + self.crop_fn(skip, out.shape[-1])
class TCN(nn.Module):
    """Temporal Convolutional Network: a stack of dilated ``TCNBlock``s
    followed by a 1x1 output convolution with a tanh nonlinearity.

    Args:
        ninputs: Number of input audio channels.
        noutputs: Number of output audio channels.
        nblocks: Number of TCN blocks in the stack.
        channel_growth: If > 1, each block multiplies its input channels by
            this factor; otherwise every block uses ``channel_width`` outputs.
        channel_width: Fixed per-block channel count when ``channel_growth``
            is not used.
        kernel_size: Convolution kernel size for every block.
        stack_size: Dilation pattern repeats after this many blocks.
        dilation_growth: Base of the per-block dilation (``growth ** n``).
        condition: Stored flag for conditioning (not used in ``forward`` here).
        latent_dim: Size of the conditioning latent (used by the optional
            loudness head).
        norm_type: Stored normalization identifier (not used in this block).
        causal: If True, crop residuals causally; otherwise center-crop.
        estimate_loudness: If True, add a linear loudness head on the latent.
    """

    def __init__(
        self,
        ninputs: int = 1,
        noutputs: int = 1,
        nblocks: int = 4,
        channel_growth: int = 0,
        channel_width: int = 32,
        kernel_size: int = 13,
        stack_size: int = 10,
        dilation_growth: int = 10,
        condition: bool = False,
        latent_dim: int = 2,
        norm_type: str = "identity",
        causal: bool = False,
        estimate_loudness: bool = False,
    ) -> None:
        super().__init__()
        self.ninputs = ninputs
        self.noutputs = noutputs
        self.nblocks = nblocks
        self.channel_growth = channel_growth
        self.channel_width = channel_width
        self.kernel_size = kernel_size
        self.stack_size = stack_size
        self.dilation_growth = dilation_growth
        self.condition = condition
        self.latent_dim = latent_dim
        self.norm_type = norm_type
        self.causal = causal
        self.estimate_loudness = estimate_loudness

        # Causal models crop from the left; non-causal crop symmetrically.
        self.crop_fn = causal_crop if self.causal else center_crop

        if estimate_loudness:
            self.loudness = torch.nn.Linear(latent_dim, 1)

        # audio model: build the dilated block stack
        self.process_blocks = torch.nn.ModuleList()
        out_ch = -1
        for n in range(nblocks):
            in_ch = out_ch if n > 0 else ninputs
            out_ch = in_ch * channel_growth if channel_growth > 1 else channel_width
            # Dilation grows exponentially, wrapping every `stack_size` blocks.
            dilation = dilation_growth ** (n % stack_size)
            self.process_blocks.append(
                TCNBlock(
                    in_ch,
                    out_ch,
                    kernel_size,
                    dilation,
                    stride=1,
                    crop_fn=self.crop_fn,
                )
            )
        self.output = nn.Conv1d(out_ch, noutputs, kernel_size=1)

        # model configuration
        self.receptive_field = self.compute_receptive_field()
        self.block_size = 2048
        # Rolling context buffer sized for streaming inference
        # (not consumed by forward() in this file).
        self.buffer = torch.zeros(2, self.receptive_field + self.block_size - 1)

    def forward(self, x: Tensor) -> Tensor:
        """Process audio through every TCN block, then the tanh output conv.

        Note: the output is shorter than the input by
        ``receptive_field - 1`` samples because the blocks use no padding.
        """
        # Plain iteration — the index from the original enumerate was unused.
        for block in self.process_blocks:
            x = block(x)
        return torch.tanh(self.output(x))

    def compute_receptive_field(self):
        """Compute the receptive field in samples."""
        rf = self.kernel_size
        for n in range(1, self.nblocks):
            dilation = self.dilation_growth ** (n % self.stack_size)
            rf = rf + ((self.kernel_size - 1) * dilation)
        return rf