rossijakob's picture
Upload folder using huggingface_hub
18b9615 verified
import torch
import torch.nn as nn
import numpy as np
from models.transformer import TransformerBlock
from utils import LearnableSigmoid2d
from pesq import pesq
from joblib import Parallel, delayed
class SPConvTranspose2d(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, r=1):
super(SPConvTranspose2d, self).__init__()
self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
self.out_channels = out_channels
self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
self.r = r
def forward(self, x):
x = self.pad1(x)
out = self.conv(x)
batch_size, nchannels, H, W = out.shape
out = out.view((batch_size, self.r, nchannels // self.r, H, W))
out = out.permute(0, 2, 3, 4, 1)
out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))
return out
class DenseBlock(nn.Module):
def __init__(self, h, kernel_size=(2, 3), depth=4):
super(DenseBlock, self).__init__()
self.h = h
self.depth = depth
self.dense_block = nn.ModuleList([])
for i in range(depth):
dilation = 2 ** i
pad_length = dilation
dense_conv = nn.Sequential(
nn.ConstantPad2d((1, 1, pad_length, 0), value=0.),
nn.Conv2d(h.dense_channel*(i+1), h.dense_channel, kernel_size, dilation=(dilation, 1)),
nn.InstanceNorm2d(h.dense_channel, affine=True),
nn.PReLU(h.dense_channel)
)
self.dense_block.append(dense_conv)
def forward(self, x):
skip = x
for i in range(self.depth):
x = self.dense_block[i](skip)
skip = torch.cat([x, skip], dim=1)
return x
class DenseEncoder(nn.Module):
def __init__(self, h, in_channel):
super(DenseEncoder, self).__init__()
self.h = h
self.dense_conv_1 = nn.Sequential(
nn.Conv2d(in_channel, h.dense_channel, (1, 1)),
nn.InstanceNorm2d(h.dense_channel, affine=True),
nn.PReLU(h.dense_channel))
self.dense_block = DenseBlock(h, depth=4)
self.dense_conv_2 = nn.Sequential(
nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2), padding=(0, 1)),
nn.InstanceNorm2d(h.dense_channel, affine=True),
nn.PReLU(h.dense_channel))
def forward(self, x):
x = self.dense_conv_1(x) # [b, 64, T, F]
x = self.dense_block(x) # [b, 64, T, F]
x = self.dense_conv_2(x) # [b, 64, T, F//2]
return x
class MaskDecoder(nn.Module):
def __init__(self, h, out_channel=1):
super(MaskDecoder, self).__init__()
self.dense_block = DenseBlock(h, depth=4)
self.mask_conv = nn.Sequential(
SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
nn.InstanceNorm2d(h.dense_channel, affine=True),
nn.PReLU(h.dense_channel),
nn.Conv2d(h.dense_channel, out_channel, (1, 2))
)
self.lsigmoid = LearnableSigmoid2d(h.n_fft//2+1, beta=h.beta)
def forward(self, x):
x = self.dense_block(x)
x = self.mask_conv(x)
x = x.permute(0, 3, 2, 1).squeeze(-1) # [B, F, T]
x = self.lsigmoid(x)
return x
class PhaseDecoder(nn.Module):
def __init__(self, h, out_channel=1):
super(PhaseDecoder, self).__init__()
self.dense_block = DenseBlock(h, depth=4)
self.phase_conv = nn.Sequential(
SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
nn.InstanceNorm2d(h.dense_channel, affine=True),
nn.PReLU(h.dense_channel)
)
self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
def forward(self, x):
x = self.dense_block(x)
x = self.phase_conv(x)
x_r = self.phase_conv_r(x)
x_i = self.phase_conv_i(x)
x = torch.atan2(x_i, x_r)
x = x.permute(0, 3, 2, 1).squeeze(-1) # [B, F, T]
return x
class TSTransformerBlock(nn.Module):
def __init__(self, h):
super(TSTransformerBlock, self).__init__()
self.h = h
self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
def forward(self, x):
b, c, t, f = x.size()
x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
x = self.time_transformer(x) + x
x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
x = self.freq_transformer(x) + x
x = x.view(b, t, f, c).permute(0, 3, 1, 2)
return x
class MPNet(nn.Module):
def __init__(self, h, num_tsblocks=4):
super(MPNet, self).__init__()
self.h = h
self.num_tscblocks = num_tsblocks
self.dense_encoder = DenseEncoder(h, in_channel=2)
self.TSTransformer = nn.ModuleList([])
for i in range(num_tsblocks):
self.TSTransformer.append(TSTransformerBlock(h))
self.mask_decoder = MaskDecoder(h, out_channel=1)
self.phase_decoder = PhaseDecoder(h, out_channel=1)
def forward(self, noisy_amp, noisy_pha): # [B, F, T]
x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1) # [B, 2, T, F]
x = self.dense_encoder(x)
for i in range(self.num_tscblocks):
x = self.TSTransformer[i](x)
denoised_amp = noisy_amp * self.mask_decoder(x)
denoised_pha = self.phase_decoder(x)
denoised_com = torch.stack((denoised_amp*torch.cos(denoised_pha),
denoised_amp*torch.sin(denoised_pha)), dim=-1)
return denoised_amp, denoised_pha, denoised_com
def phase_losses(phase_r, phase_g):
ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)
def pesq_score(utts_r, utts_g, h):
pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
utts_r[i].squeeze().cpu().numpy(),
utts_g[i].squeeze().cpu().numpy(),
h.sampling_rate)
for i in range(len(utts_r)))
pesq_score = np.mean(pesq_score)
return pesq_score
def eval_pesq(clean_utt, esti_utt, sr):
try:
pesq_score = pesq(sr, clean_utt, esti_utt)
except:
pesq_score = -1
return pesq_score