Spaces:
Runtime error
Runtime error
| import torch.nn as nn | |
| import torch | |
| import logging | |
| from modules.FastDiff.module.modules import DiffusionDBlock, TimeAware_LVCBlock | |
| from modules.FastDiff.module.util import calc_diffusion_step_embedding | |
def swish(x):
    """Return the swish activation of ``x``: ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
class FastDiff(nn.Module):
    """FastDiff module.

    The network embeds the diffusion step, lifts the input audio to
    ``inner_channels`` with a 1-D convolution, runs it through a stack of
    ``DiffusionDBlock`` layers while remembering the input of each one, then
    runs a stack of ``TimeAware_LVCBlock`` layers that each receive the
    running signal, one remembered input (most recent first), the
    conditioning features and the step embedding.  A final 1-D convolution
    maps back to ``audio_channels``.
    """

    def __init__(self,
                 audio_channels=1,
                 inner_channels=32,
                 cond_channels=80,
                 upsample_ratios=(8, 8, 4),
                 lvc_layers_each_block=4,
                 lvc_kernel_size=3,
                 kpnet_hidden_channels=64,
                 kpnet_conv_size=3,
                 dropout=0.0,
                 diffusion_step_embed_dim_in=128,
                 diffusion_step_embed_dim_mid=512,
                 diffusion_step_embed_dim_out=512,
                 use_weight_norm=True):
        """Build the network.

        Args:
            audio_channels: Channel count of the input audio and of the
                output produced by the final convolution.
            inner_channels: Channel count used between the first and the
                final convolution.
            cond_channels: Passed as ``cond_channels`` to every
                ``TimeAware_LVCBlock``.
            upsample_ratios: One ratio per LVC block.  Block ``n`` receives
                ``upsample_ratios[n]``; the ``n``-th ``DiffusionDBlock``
                receives the ratios in reverse order.
            lvc_layers_each_block: Passed as ``conv_layers`` to every
                ``TimeAware_LVCBlock``.
            lvc_kernel_size: Passed as ``conv_kernel_size`` to every
                ``TimeAware_LVCBlock``.
            kpnet_hidden_channels: Passed through to every
                ``TimeAware_LVCBlock``.
            kpnet_conv_size: Passed through to every ``TimeAware_LVCBlock``.
            dropout: Passed as ``kpnet_dropout`` to every
                ``TimeAware_LVCBlock``.
            diffusion_step_embed_dim_in: Size of the raw diffusion-step
                embedding (input size of the first fully connected layer).
            diffusion_step_embed_dim_mid: Output size of the first fully
                connected layer.
            diffusion_step_embed_dim_out: Output size of the second fully
                connected layer; also passed as
                ``noise_scale_embed_dim_out`` to every LVC block.
            use_weight_norm: When true, weight normalization is applied to
                every ``Conv1d``/``Conv2d`` in the model after construction.
        """
        super().__init__()

        self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in
        self.audio_channels = audio_channels
        self.cond_channels = cond_channels
        self.lvc_block_nums = len(upsample_ratios)

        # Input projection.  Uses ``audio_channels`` (previously a literal 1)
        # so the input width agrees with what ``final_conv`` emits; with the
        # default ``audio_channels=1`` the layer is identical to before.
        self.first_audio_conv = nn.Conv1d(audio_channels, inner_channels,
                                          kernel_size=7, padding=(7 - 1) // 2,
                                          dilation=1, bias=True)

        # define residual blocks
        self.lvc_blocks = nn.ModuleList()
        self.downsample = nn.ModuleList()

        # the layer-specific fc for noise scale embedding
        # NOTE: ``fc_t`` is never populated in this class; it is kept only so
        # that the attribute continues to exist for external code.
        self.fc_t = nn.ModuleList()
        self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid)
        self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out)

        # Running product of the ratios seen so far; handed to each LVC block
        # as its ``cond_hop_length``.
        cond_hop_length = 1
        for n, ratio in enumerate(upsample_ratios):
            cond_hop_length = cond_hop_length * ratio
            lvcb = TimeAware_LVCBlock(
                in_channels=inner_channels,
                cond_channels=cond_channels,
                upsample_ratio=ratio,
                conv_layers=lvc_layers_each_block,
                conv_kernel_size=lvc_kernel_size,
                cond_hop_length=cond_hop_length,
                kpnet_hidden_channels=kpnet_hidden_channels,
                kpnet_conv_size=kpnet_conv_size,
                kpnet_dropout=dropout,
                noise_scale_embed_dim_out=diffusion_step_embed_dim_out
            )
            self.lvc_blocks.append(lvcb)
            # The D-blocks take the ratios in reverse order, mirroring the
            # LVC blocks that consume their inputs in reverse in ``forward``.
            mirrored_ratio = upsample_ratios[self.lvc_block_nums - n - 1]
            self.downsample.append(DiffusionDBlock(inner_channels, inner_channels, mirrored_ratio))

        # define output layers
        self.final_conv = nn.Sequential(
            nn.Conv1d(inner_channels, audio_channels, kernel_size=7,
                      padding=(7 - 1) // 2, dilation=1, bias=True))

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, data):
        """Calculate forward propagation.

        Args:
            data (tuple): ``(audio, c, diffusion_steps)`` where
                audio (Tensor): Input noise signal (B, audio_channels, T).
                c (Tensor): Local conditioning auxiliary features (B, C, T').
                diffusion_steps (Tensor): Diffusion step indices handed to
                    ``calc_diffusion_step_embedding``; the expected shape is
                    defined by that helper (not visible here; presumably
                    (B, 1), TODO confirm).

        Returns:
            Tensor: Output of the final convolution, with ``audio_channels``
            channels (documented upstream as (B, out_channels, T)).
        """
        audio, c, diffusion_steps = data

        # embed diffusion step t
        diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in)
        diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed))
        diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed))

        audio = self.first_audio_conv(audio)

        # Remember the input of every D-block so the LVC blocks can use them.
        skips = []
        for down_layer in self.downsample:
            skips.append(audio)
            audio = down_layer(audio)

        # LVC blocks consume the remembered inputs most-recent first.  Both
        # lists hold one entry per ratio, so zip pairs them exactly.
        x = audio
        for lvc_block, audio_down in zip(self.lvc_blocks, reversed(skips)):
            x = lvc_block((x, audio_down, c, diffusion_step_embed))

        # apply final layers
        x = self.final_conv(x)
        return x

    def remove_weight_norm(self):
        """Remove weight normalization from every layer that has it."""

        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
            # Log only after a successful removal; logging up front reported
            # a removal for every module, including those without weight norm.
            logging.debug("Weight norm is removed from %s.", m)

        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Apply weight normalization to every Conv1d and Conv2d layer."""

        def _apply_weight_norm(m):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug("Weight norm is applied to %s.", m)

        self.apply(_apply_weight_norm)