| | |
| | |
| | |
| | |
| |
|
| | import math |
| |
|
| | import torch.nn as nn |
| |
|
| | from modules.general.utils import Conv1d, zero_module |
| | from .residual_block import ResidualBlock |
| |
|
| |
|
class BiDilConv(nn.Module):
    r"""Dilated CNN with residual connections; the default diffusion decoder.

    Args:
        input_channel: Number of input channels.
        base_channel: Number of hidden (base) channels.
        n_res_block: Number of residual blocks.
        conv_kernel_size: Kernel size used inside each residual block.
        dilation_cycle_length: Length of the dilation cycle; block ``i`` uses
            dilation ``2 ** (i % dilation_cycle_length)``.
        conditioner_size: Channel size of the conditioning context.
        output_channel: Number of output channels; any non-positive value
            (the default ``-1``) falls back to ``input_channel``.
    """

    def __init__(
        self,
        input_channel,
        base_channel,
        n_res_block,
        conv_kernel_size,
        dilation_cycle_length,
        conditioner_size,
        output_channel: int = -1,
    ):
        super().__init__()

        self.input_channel = input_channel
        self.base_channel = base_channel
        self.n_res_block = n_res_block
        self.conv_kernel_size = conv_kernel_size
        self.dilation_cycle_length = dilation_cycle_length
        self.conditioner_size = conditioner_size
        # Non-positive output_channel means "same width as the input".
        self.output_channel = input_channel if output_channel <= 0 else output_channel

        # Project the input to the hidden width before the residual stack.
        self.input = nn.Sequential(
            Conv1d(input_channel, base_channel, 1),
            nn.ReLU(),
        )

        # Residual stack with cyclically increasing dilations.
        blocks = []
        for idx in range(n_res_block):
            blocks.append(
                ResidualBlock(
                    channels=base_channel,
                    kernel_size=conv_kernel_size,
                    dilation=2 ** (idx % dilation_cycle_length),
                    d_context=conditioner_size,
                )
            )
        self.residual_blocks = nn.ModuleList(blocks)

        # Output head; the last conv is wrapped in zero_module (presumably
        # zero-initialized so the decoder starts out emitting zeros — confirm
        # against modules.general.utils).
        self.out_proj = nn.Sequential(
            Conv1d(base_channel, base_channel, 1),
            nn.ReLU(),
            zero_module(Conv1d(base_channel, self.output_channel, 1)),
        )

    def forward(self, x, y, context=None):
        """
        Args:
            x: Noisy mel-spectrogram [B x ``n_mel`` x L]
            y: FILM embeddings with the shape of (B, ``base_channel``)
            context: Context with the shape of [B x ``d_context`` x L], default to None.
        """
        h = self.input(x)

        # Accumulate the skip connection emitted by every residual block.
        skip_sum = None
        for block in self.residual_blocks:
            h, skip = block(h, y, context)
            skip_sum = skip if skip_sum is None else skip_sum + skip

        # Scale by 1/sqrt(depth) so the magnitude of the skip sum does not
        # grow with the number of blocks.
        normalized = skip_sum / math.sqrt(self.n_res_block)

        return self.out_proj(normalized)
| |
|