| import numpy as np
|
| from typing import Dict, List, NoReturn, Tuple
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torchlibrosa.stft import STFT, ISTFT, magphase
|
| from models.base import Base, init_layer, init_bn, act
|
|
|
|
|
class FiLM(nn.Module):
    r"""Map a condition embedding to nested per-layer FiLM offsets (betas).

    For every integer leaf of ``film_meta`` one ``nn.Linear(condition_size, n)``
    layer is created and registered on this module under the path of keys
    leading to that leaf joined with ``'->'``. For example,
    ``{'a': {'beta1': 8}}`` registers a layer named ``'a->beta1'`` with 8
    outputs. ``forward`` evaluates every layer and returns a nested dict
    with the same structure as ``film_meta``.

    Args:
        film_meta: nested dict whose leaves are feature counts (ints).
        condition_size: size of the last dimension of ``conditions``.
    """

    def __init__(self, film_meta, condition_size):
        super(FiLM, self).__init__()

        self.condition_size = condition_size

        # NOTE(review): this plain dict shadows ``nn.Module.modules()`` on FiLM
        # instances, so ``film.modules()`` is not callable. The layers are still
        # registered via ``add_module`` (see ``add_film_layer_to_module``), so
        # parameters, ``state_dict`` and device moves are unaffected. The
        # attribute name is kept for backward compatibility.
        self.modules, _ = self.create_film_modules(
            film_meta=film_meta,
            ancestor_names=[],
        )

    def create_film_modules(self, film_meta, ancestor_names):
        r"""Recursively create a linear layer for every int leaf of ``film_meta``.

        Args:
            film_meta: nested dict of ``{name: int | dict}``.
            ancestor_names: keys on the path to ``film_meta``. It is used as a
                stack and holds the same contents again on return.

        Returns:
            ``(modules, ancestor_names)``. ``modules`` mirrors ``film_meta``,
            with ``nn.Linear`` layers in place of the int leaves. Values that
            are neither ``int`` nor ``dict`` are skipped.
        """
        modules = {}

        for module_name, value in film_meta.items():
            if not isinstance(value, (int, dict)):
                # Skip unsupported values without touching the name stack;
                # popping here would remove a parent's name from the path.
                continue

            ancestor_names.append(module_name)

            if isinstance(value, int):
                unique_module_name = '->'.join(ancestor_names)
                modules[module_name] = self.add_film_layer_to_module(
                    num_features=value,
                    unique_module_name=unique_module_name,
                )
            else:
                modules[module_name], _ = self.create_film_modules(
                    film_meta=value,
                    ancestor_names=ancestor_names,
                )

            ancestor_names.pop()

        return modules, ancestor_names

    def add_film_layer_to_module(self, num_features, unique_module_name):
        r"""Create, initialise and register one ``condition_size -> num_features`` layer."""
        layer = nn.Linear(self.condition_size, num_features)
        init_layer(layer)
        # Registering makes the layer's parameters part of this module.
        self.add_module(name=unique_module_name, module=layer)

        return layer

    def forward(self, conditions):
        r"""Compute the nested FiLM offsets for ``conditions``.

        Args:
            conditions: tensor whose last dimension is ``condition_size``.
                It is presumably (batch_size, condition_size); TODO confirm.

        Returns:
            Nested dict mirroring ``film_meta``. Every leaf is the layer output
            with two trailing singleton axes appended.
        """
        film_dict = self.calculate_film_data(
            conditions=conditions,
            modules=self.modules,
        )

        return film_dict

    def calculate_film_data(self, conditions, modules):
        r"""Recursively apply every layer in ``modules`` to ``conditions``."""
        film_data = {}

        for module_name, module in modules.items():
            if isinstance(module, nn.Module):
                # Trailing (1, 1) axes let the offset broadcast over the last
                # two axes of a 4-D feature map.
                film_data[module_name] = module(conditions)[:, :, None, None]
            elif isinstance(module, dict):
                film_data[module_name] = self.calculate_film_data(conditions, module)

        return film_data
|
|
|
|
|
class ConvBlockRes(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Pre-activation residual block with additive FiLM conditioning.

        Computes ``shortcut(x) + conv2(act(bn2(conv1(act(bn1(x) + beta1))) + beta2))``.
        Here ``act`` is leaky ReLU with slope 0.01. ``shortcut`` is a 1x1
        convolution when ``in_channels != out_channels`` and the identity
        otherwise.

        Args:
            in_channels: number of input feature maps.
            out_channels: number of output feature maps.
            kernel_size: ``(kh, kw)`` of both convolutions. Padding of
                ``k // 2`` with stride 1 keeps the spatial size for odd kernels.
            momentum: BatchNorm momentum.
            has_film: when truthy, ``forward`` adds the ``'beta1'`` /
                ``'beta2'`` offsets from ``film_dict``; otherwise no offsets
                are applied.
        """
        super(ConvBlockRes, self).__init__()

        padding = [kernel_size[0] // 2, kernel_size[1] // 2]

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=(1, 1),
            dilation=(1, 1),
            padding=padding,
            bias=False,
        )

        if in_channels != out_channels:
            # 1x1 projection so the residual matches out_channels.
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

        self.has_film = has_film

        self.init_weights()

    def init_weights(self) -> None:
        r"""Initialize weights."""
        init_bn(self.bn1)
        init_bn(self.bn2)
        init_layer(self.conv1)
        init_layer(self.conv2)

        if self.is_shortcut:
            init_layer(self.shortcut)

    def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
        r"""Forward data into the module.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
            film_dict: dict with ``'beta1'`` (broadcastable to the ``bn1``
                output) and ``'beta2'`` (broadcastable to the ``bn2`` output).
                It is only read when ``has_film`` is truthy.

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        if self.has_film:
            beta1 = film_dict['beta1']
            beta2 = film_dict['beta2']
        else:
            # No conditioning for this block. Adding 0.0 still yields a fresh
            # tensor, so the in-place leaky ReLU never touches the BN output.
            beta1 = beta2 = 0.0

        x = self.conv1(F.leaky_relu_(self.bn1(input_tensor) + beta1, negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x) + beta2, negative_slope=0.01))

        residual = self.shortcut(input_tensor) if self.is_shortcut else input_tensor
        return residual + x
|
|
|
|
|
class EncoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        downsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Encoder stage: one ``ConvBlockRes`` followed by average pooling.

        Args:
            in_channels: number of input feature maps.
            out_channels: number of output feature maps.
            kernel_size: kernel of the residual block's convolutions.
            downsample: pooling kernel, which is also the stride (the
                ``avg_pool2d`` default).
            momentum: BatchNorm momentum.
            has_film: forwarded to ``ConvBlockRes``.
        """
        super(EncoderBlockRes1B, self).__init__()

        self.conv_block1 = ConvBlockRes(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            momentum=momentum,
            has_film=has_film,
        )
        self.downsample = downsample

    def forward(self, input_tensor: torch.Tensor, film_dict: Dict) -> torch.Tensor:
        r"""Run the residual block, then pool its output.

        Args:
            input_tensor: (batch_size, input_feature_maps, time_steps, freq_bins)
            film_dict: must contain ``'conv_block1'``, which is passed on to
                the residual block.

        Returns:
            A tuple ``(pooled, features)``:
            pooled: (batch_size, output_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            features: (batch_size, output_feature_maps, time_steps, freq_bins),
                the pre-pooling activations used as the U-Net skip connection.
        """
        features = self.conv_block1(input_tensor, film_dict['conv_block1'])
        return F.avg_pool2d(features, kernel_size=self.downsample), features
|
|
|
|
|
class DecoderBlockRes1B(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Tuple,
        upsample: Tuple,
        momentum: float,
        has_film,
    ):
        r"""Decoder stage: upsample, concatenate the skip tensor, refine.

        The pipeline is BN + FiLM offset + leaky ReLU, then a transposed
        convolution whose kernel and stride both equal ``upsample``, then
        channel-concatenation with the skip tensor, then one ``ConvBlockRes``.

        Args:
            in_channels: number of input feature maps.
            out_channels: number of output feature maps (also the
                channel count expected of the skip tensor).
            kernel_size: kernel of the residual block's convolutions.
            upsample: upsampling factor of the transposed convolution.
            momentum: BatchNorm momentum.
            has_film: forwarded to ``ConvBlockRes`` and stored on the module.
        """
        super(DecoderBlockRes1B, self).__init__()
        stride = upsample
        self.kernel_size = kernel_size
        self.stride = stride

        # Construction order is kept as-is: it fixes the RNG draws of the
        # default initialisation and the order of registered submodules.
        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=stride,
            stride=stride,
            padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )
        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.conv_block2 = ConvBlockRes(
            in_channels=out_channels * 2,
            out_channels=out_channels,
            kernel_size=kernel_size,
            momentum=momentum,
            has_film=has_film,
        )
        # NOTE(review): bn2 is never used in forward. It is still a registered
        # submodule (so it appears in state_dict) and any code that reads
        # ``bn2.num_features`` relies on it, so it is left in place.
        self.bn2 = nn.BatchNorm2d(in_channels, momentum=momentum)
        self.has_film = has_film

        self.init_weights()

    def init_weights(self):
        r"""Initialise the input BatchNorm and the transposed convolution."""
        init_bn(self.bn1)
        init_layer(self.conv1)

    def forward(
        self, input_tensor: torch.Tensor, concat_tensor: torch.Tensor, film_dict: Dict,
    ) -> torch.Tensor:
        r"""Upsample ``input_tensor`` and fuse it with the skip connection.

        Args:
            input_tensor: (batch_size, input_feature_maps, downsampled_time_steps, downsampled_freq_bins)
            concat_tensor: (batch_size, output_feature_maps, time_steps, freq_bins), the skip tensor.
            film_dict: must contain ``'beta1'`` (added after ``bn1``) and
                ``'conv_block2'`` (passed to the residual block).

        Returns:
            output_tensor: (batch_size, output_feature_maps, time_steps, freq_bins)
        """
        beta1 = film_dict['beta1']
        activated = F.leaky_relu_(self.bn1(input_tensor) + beta1)
        upsampled = self.conv1(activated)

        merged = torch.cat((upsampled, concat_tensor), dim=1)

        return self.conv_block2(merged, film_dict['conv_block2'])
|
|
|
|
|
class ResUNet30_Base(nn.Module, Base):
    r"""FiLM-conditioned residual U-Net that predicts complex spectrogram masks.

    The mixture waveform is turned into a magnitude spectrogram and phase
    (``cos``/``sin``) by ``self.wav_to_spectrogram_phase``. That method is not
    defined in this class and is presumably inherited from ``Base``.

    For every output channel the U-Net predicts ``K = 3`` maps:
    - a sigmoid magnitude mask,
    - the real part of a phase-rotation mask,
    - the imaginary part of that mask.

    The masked spectrogram is inverted back to a waveform with ``self.istft``.

    Args:
        input_channels: channel count of the magnitude spectrogram fed to
            ``pre_conv``.
        output_channels: number of waveform channels produced per source.
    """

    # (attribute name, in_channels, out_channels, (time, freq) pooling factor)
    _ENCODER_SPECS = (
        ('encoder_block1', 32, 32, (2, 2)),
        ('encoder_block2', 32, 64, (2, 2)),
        ('encoder_block3', 64, 128, (2, 2)),
        ('encoder_block4', 128, 256, (2, 2)),
        ('encoder_block5', 256, 384, (2, 2)),
        ('encoder_block6', 384, 384, (1, 2)),
    )

    # (attribute name, in_channels, out_channels, (time, freq) upsampling factor)
    # Decoder i consumes the skip tensor of encoder (7 - i).
    _DECODER_SPECS = (
        ('decoder_block1', 384, 384, (1, 2)),
        ('decoder_block2', 384, 384, (2, 2)),
        ('decoder_block3', 384, 256, (2, 2)),
        ('decoder_block4', 256, 128, (2, 2)),
        ('decoder_block5', 128, 64, (2, 2)),
        ('decoder_block6', 64, 32, (2, 2)),
    )

    def __init__(self, input_channels, output_channels):
        super(ResUNet30_Base, self).__init__()

        window_size = 2048
        hop_size = 320
        center = True
        pad_mode = "reflect"
        window = "hann"
        momentum = 0.01

        self.output_channels = output_channels
        self.target_sources_num = 1
        # Maps predicted per output channel: magnitude mask, real and imaginary
        # parts of the phase mask (see feature_maps_to_wav).
        self.K = 3

        # forward pads the time axis to a multiple of this value. Encoder
        # blocks 1-5 each pool time by 2, so the product is 2 ** 5.
        self.time_downsample_ratio = 2 ** 5

        self.stft = STFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        self.istft = ISTFT(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # One normalisation channel per frequency bin (forward moves the
        # frequency axis onto the channel axis before applying it).
        self.bn0 = nn.BatchNorm2d(window_size // 2 + 1, momentum=momentum)

        self.pre_conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=32,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        # Attribute names and registration order must not change: state_dict
        # keys and the film_dict keys read in forward are derived from them.
        for name, in_ch, out_ch, factor in self._ENCODER_SPECS:
            setattr(self, name, EncoderBlockRes1B(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=(3, 3),
                downsample=factor,
                momentum=momentum,
                has_film=True,
            ))

        # Bottleneck: same block type, no pooling.
        self.conv_block7a = EncoderBlockRes1B(
            in_channels=384,
            out_channels=384,
            kernel_size=(3, 3),
            downsample=(1, 1),
            momentum=momentum,
            has_film=True,
        )

        for name, in_ch, out_ch, factor in self._DECODER_SPECS:
            setattr(self, name, DecoderBlockRes1B(
                in_channels=in_ch,
                out_channels=out_ch,
                kernel_size=(3, 3),
                upsample=factor,
                momentum=momentum,
                has_film=True,
            ))

        self.after_conv = nn.Conv2d(
            in_channels=32,
            out_channels=output_channels * self.K,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        self.init_weights()

    def init_weights(self):
        r"""Initialise the input BatchNorm and the 1x1 input/output convolutions."""
        init_bn(self.bn0)
        init_layer(self.pre_conv)
        init_layer(self.after_conv)

    def feature_maps_to_wav(
        self,
        input_tensor: torch.Tensor,
        sp: torch.Tensor,
        sin_in: torch.Tensor,
        cos_in: torch.Tensor,
        audio_length: int,
    ) -> torch.Tensor:
        r"""Convert predicted mask feature maps to waveforms.

        Args:
            input_tensor: (batch_size, target_sources_num * output_channels * self.K, time_steps, freq_bins)
            sp: (batch_size, input_channels, time_steps, freq_bins), mixture magnitude.
            sin_in: (batch_size, input_channels, time_steps, freq_bins), mixture phase sine.
            cos_in: (batch_size, input_channels, time_steps, freq_bins), mixture phase cosine.
            audio_length: number of samples requested from the inverse STFT.

            (input_channels == output_channels is expected for source separation.)

        Returns:
            waveform: (batch_size, target_sources_num * output_channels, audio_length)
        """
        batch_size, _, time_steps, freq_bins = input_tensor.shape

        x = input_tensor.reshape(
            batch_size,
            self.target_sources_num,
            self.output_channels,
            self.K,
            time_steps,
            freq_bins,
        )

        # Map 0: magnitude mask in (0, 1). Maps 1-2: a complex mask whose
        # angle (cos/sin via magphase) rotates the mixture phase.
        mask_mag = torch.sigmoid(x[:, :, :, 0, :, :])
        _mask_real = torch.tanh(x[:, :, :, 1, :, :])
        _mask_imag = torch.tanh(x[:, :, :, 2, :, :])

        _, mask_cos, mask_sin = magphase(_mask_real, _mask_imag)

        # Angle addition: cos(a + b) and sin(a + b) from the mixture phase a
        # and the mask phase b. The [:, None] adds the sources axis.
        out_cos = (
            cos_in[:, None, :, :, :] * mask_cos - sin_in[:, None, :, :, :] * mask_sin
        )
        out_sin = (
            sin_in[:, None, :, :, :] * mask_cos + cos_in[:, None, :, :, :] * mask_sin
        )

        out_mag = F.relu_(sp[:, None, :, :, :] * mask_mag)

        out_real = out_mag * out_cos
        out_imag = out_mag * out_sin

        # Fold sources and channels into the batch axis for the ISTFT.
        shape = (
            batch_size * self.target_sources_num * self.output_channels,
            1,
            time_steps,
            freq_bins,
        )
        out_real = out_real.reshape(shape)
        out_imag = out_imag.reshape(shape)

        x = self.istft(out_real, out_imag, audio_length)

        waveform = x.reshape(
            batch_size, self.target_sources_num * self.output_channels, audio_length
        )

        return waveform

    def forward(self, mixtures, film_dict):
        r"""Separate ``mixtures`` under the FiLM conditioning in ``film_dict``.

        Args:
            mixtures: waveform tensor whose axis 2 is time. It is presumably
                (batch_size, channels_num, segment_samples); TODO confirm
                against ``wav_to_spectrogram_phase``.
            film_dict: nested dict of FiLM offsets. It must have one entry per
                block: ``encoder_block1``-``encoder_block6``, ``conv_block7a``,
                and ``decoder_block1``-``decoder_block6``.

        Returns:
            ``{'waveform': (batch_size, target_sources_num * output_channels, segment_samples)}``
        """
        mag, cos_in, sin_in = self.wav_to_spectrogram_phase(mixtures)
        x = mag

        # Per-frequency-bin BatchNorm: put frequency on the channel axis and back.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)
        # x: (batch_size, channels, time_steps, freq_bins)

        # Pad time up to a multiple of the encoder's total time pooling.
        origin_len = x.shape[2]
        pad_len = -origin_len % self.time_downsample_ratio
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        # x: (batch_size, channels, padded_time_steps, freq_bins)

        # Drop the last frequency bin (1025 -> 1024) so the six frequency
        # halvings of the encoder divide evenly. It is restored after after_conv.
        x = x[..., :-1]

        x = self.pre_conv(x)

        skips = []
        for name, *_ in self._ENCODER_SPECS:
            x, skip = getattr(self, name)(x, film_dict[name])
            skips.append(skip)

        x, _ = self.conv_block7a(x, film_dict['conv_block7a'])

        # Deepest decoder pairs with the deepest encoder's pre-pool features.
        for (name, *_), skip in zip(self._DECODER_SPECS, reversed(skips)):
            x = getattr(self, name)(x, skip, film_dict[name])

        x = self.after_conv(x)

        # Restore the dropped frequency bin (zero-filled) and crop the time padding.
        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        audio_length = mixtures.shape[2]

        separated_audio = self.feature_maps_to_wav(
            input_tensor=x,
            sp=mag,
            sin_in=sin_in,
            cos_in=cos_in,
            audio_length=audio_length,
        )

        output_dict = {'waveform': separated_audio}

        return output_dict
|
|
|
|
|
def get_film_meta(module):
    r"""Describe which FiLM offsets a module tree needs.

    Any module with a ``has_film`` attribute gets ``'beta1'`` and
    ``'beta2'`` entries. These hold ``bn1.num_features`` and
    ``bn2.num_features`` when ``has_film`` is truthy, and 0 otherwise. Child
    modules are visited recursively in ``named_children`` order. Children
    whose description is empty are omitted.

    Args:
        module: root ``nn.Module``.

    Returns:
        Nested dict keyed by child attribute names, with int leaves.
    """
    meta = {}

    if hasattr(module, 'has_film'):
        enabled = bool(module.has_film)
        # bn1/bn2 are only touched when FiLM is enabled.
        meta['beta1'] = module.bn1.num_features if enabled else 0
        meta['beta2'] = module.bn2.num_features if enabled else 0

    described_children = (
        (child_name, get_film_meta(child_module))
        for child_name, child_module in module.named_children()
    )
    meta.update(
        (child_name, child_meta)
        for child_name, child_meta in described_children
        if child_meta
    )

    return meta
|
|
|
|
|
class ResUNet30(nn.Module):
    r"""Condition-driven separator: ``ResUNet30_Base`` driven by a ``FiLM`` head.

    Args:
        input_channels: forwarded to ``ResUNet30_Base``.
        output_channels: forwarded to ``ResUNet30_Base``.
        condition_size: size of the condition embedding consumed by ``FiLM``.
        sampling_rate: optional. ``chunk_inference`` uses it to convert its
            chunk lengths into samples. It is given in samples per unit of
            those lengths, i.e. seconds when it is in Hz. If omitted, set
            ``model.sampling_rate`` before calling ``chunk_inference``.
    """

    def __init__(self, input_channels, output_channels, condition_size, sampling_rate=None):
        super(ResUNet30, self).__init__()

        self.base = ResUNet30_Base(
            input_channels=input_channels,
            output_channels=output_channels,
        )

        self.film_meta = get_film_meta(
            module=self.base,
        )

        self.film = FiLM(
            film_meta=self.film_meta,
            condition_size=condition_size,
        )

        # Only set when given, so an attribute assigned externally is not clobbered.
        if sampling_rate is not None:
            self.sampling_rate = sampling_rate

    def forward(self, input_dict):
        r"""Separate ``input_dict['mixture']`` conditioned on ``input_dict['condition']``.

        Returns:
            The dict produced by ``self.base`` (``{'waveform': ...}``).
        """
        mixtures = input_dict['mixture']
        conditions = input_dict['condition']

        film_dict = self.film(
            conditions=conditions,
        )

        output_dict = self.base(
            mixtures=mixtures,
            film_dict=film_dict,
        )

        return output_dict

    def _separate_chunk(self, chunk, film_dict):
        r"""Run the base network on one chunk and return it as a numpy array without its batch axis."""
        chunk_out = self.base(
            mixtures=chunk,
            film_dict=film_dict,
        )['waveform']
        return chunk_out.squeeze(0).detach().cpu().numpy()

    @torch.no_grad()
    def chunk_inference(self, input_dict):
        r"""Separate a long mixture in overlapping chunks.

        Each chunk spans ``NL + NC + NR`` samples: left context, centre and
        right context. Consecutive chunks advance by ``NC`` samples. Only each
        chunk's centre is kept, plus the left part of the first chunk and the
        right part of the last one. As a result, every output sample comes
        from exactly one chunk.

        Args:
            input_dict: ``{'mixture': tensor with time on axis 2,
                'condition': tensor passed to self.film}``.

        Returns:
            ``np.ndarray`` of shape ``(1, L)`` (float64), where ``L`` is the
            mixture length. This assumes batch size 1 and a single output
            channel, because chunk outputs are ``squeeze(0)``-ed into a
            ``(1, L)`` buffer.

        Raises:
            AttributeError: if ``sampling_rate`` was never provided.
            ValueError: if ``sampling_rate`` is too small for a non-empty
                chunk stride, which would otherwise loop forever.
        """
        sampling_rate = getattr(self, 'sampling_rate', None)
        if sampling_rate is None:
            raise AttributeError(
                "chunk_inference requires `sampling_rate`: pass it to "
                "ResUNet30(...) or set `model.sampling_rate` first."
            )

        chunk_config = {
            'NL': 1.0,
            'NC': 3.0,
            'NR': 1.0,
            'RATE': sampling_rate,
        }

        mixtures = input_dict['mixture']
        conditions = input_dict['condition']

        # The conditioning is shared by every chunk, so compute it once.
        film_dict = self.film(
            conditions=conditions,
        )

        NL = int(chunk_config['NL'] * chunk_config['RATE'])
        NC = int(chunk_config['NC'] * chunk_config['RATE'])
        NR = int(chunk_config['NR'] * chunk_config['RATE'])

        if NC <= 0:
            raise ValueError(
                f"sampling_rate={sampling_rate} gives a chunk stride of {NC} samples."
            )

        L = mixtures.shape[2]

        out_np = np.zeros([1, L])

        WINDOW = NL + NC + NR
        current_idx = 0

        while current_idx < L:
            chunk_out_np = self._separate_chunk(
                mixtures[:, :, current_idx:current_idx + WINDOW], film_dict
            )
            is_last = current_idx + WINDOW >= L

            # Nothing precedes the first chunk, so its left context is kept.
            # This includes inputs no longer than one window.
            keep_from = 0 if current_idx == 0 else NL
            # Nothing follows the last (possibly shorter) chunk, so keep its tail.
            keep_to = chunk_out_np.shape[1] if is_last else WINDOW - NR

            out_np[:, current_idx + keep_from:current_idx + keep_to] = \
                chunk_out_np[:, keep_from:keep_to]

            if is_last:
                break

            current_idx += NC

        return out_np