Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://arxiv.org/abs/1805.00579 | |
| https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| class CausalConvBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels): | |
| super().__init__() | |
| self.conv = nn.Conv2d( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| kernel_size=(3, 2), | |
| stride=(2, 1), | |
| padding=(0, 1) | |
| ) | |
| self.norm = nn.BatchNorm2d(num_features=out_channels) | |
| self.activation = nn.ELU() | |
| def forward(self, x): | |
| """ | |
| 2D Causal convolution. | |
| Args: | |
| x: [B, C, F, T] | |
| Returns: | |
| [B, C, F, T] | |
| """ | |
| x = self.conv(x) | |
| x = x[:, :, :, :-1] # chomp size | |
| x = self.norm(x) | |
| x = self.activation(x) | |
| return x | |
| class CausalTransConvBlock(nn.Module): | |
| def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)): | |
| super().__init__() | |
| self.conv = nn.ConvTranspose2d( | |
| in_channels=in_channels, | |
| out_channels=out_channels, | |
| kernel_size=(3, 2), | |
| stride=(2, 1), | |
| output_padding=output_padding | |
| ) | |
| self.norm = nn.BatchNorm2d(num_features=out_channels) | |
| if is_last: | |
| self.activation = nn.ReLU() | |
| else: | |
| self.activation = nn.ELU() | |
| def forward(self, x): | |
| """ | |
| 2D Causal convolution. | |
| Args: | |
| x: [B, C, F, T] | |
| Returns: | |
| [B, C, F, T] | |
| """ | |
| x = self.conv(x) | |
| x = x[:, :, :, :-1] # chomp size | |
| x = self.norm(x) | |
| x = self.activation(x) | |
| return x | |
| class CRN(nn.Module): | |
| """ | |
| Input: [batch size, channels=1, T, n_fft] | |
| Output: [batch size, T, n_fft] | |
| """ | |
| def __init__(self): | |
| super(CRN, self).__init__() | |
| # Encoder | |
| self.conv_block_1 = CausalConvBlock(1, 16) | |
| self.conv_block_2 = CausalConvBlock(16, 32) | |
| self.conv_block_3 = CausalConvBlock(32, 64) | |
| self.conv_block_4 = CausalConvBlock(64, 128) | |
| self.conv_block_5 = CausalConvBlock(128, 256) | |
| # LSTM | |
| self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True) | |
| self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128) | |
| self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64) | |
| self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32) | |
| self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0)) | |
| self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True) | |
| def forward(self, x): | |
| self.lstm_layer.flatten_parameters() | |
| e_1 = self.conv_block_1(x) | |
| e_2 = self.conv_block_2(e_1) | |
| e_3 = self.conv_block_3(e_2) | |
| e_4 = self.conv_block_4(e_3) | |
| e_5 = self.conv_block_5(e_4) # [2, 256, 4, 200] | |
| batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape | |
| # [2, 256, 4, 200] = [2, 1024, 200] => [2, 200, 1024] | |
| lstm_in = e_5.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1) | |
| lstm_out, _ = self.lstm_layer(lstm_in) # [2, 200, 1024] | |
| lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size) # [2, 256, 4, 200] | |
| d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1)) | |
| d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1)) | |
| d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1)) | |
| d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1)) | |
| d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1)) | |
| return d_5 | |
| def main(): | |
| layer = CRN() | |
| a = torch.rand(2, 1, 161, 200) | |
| print(layer(a).shape) | |
| return | |
| if __name__ == '__main__': | |
| main() | |