#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py
https://huggingface.co/spaces/JacobLinCool/MP-SENet
https://arxiv.org/abs/2305.13686
https://github.com/yxlu-0102/MP-SENet

This model most likely cannot be adapted for streaming inference.
"""
import os
from typing import Optional, Union

# pesq and joblib are only needed by the commented-out pesq_score helpers below.
from pesq import pesq
from joblib import Parallel, delayed
import numpy as np
import torch
import torch.nn as nn

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d


class SPConvTranspose2d(nn.Module):
    """Sub-pixel (pixel-shuffle style) transposed convolution that upsamples the frequency axis by r."""
    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super(SPConvTranspose2d, self).__init__()
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x):
        x = self.pad1(x)
        out = self.conv(x)
        batch_size, nchannels, H, W = out.shape
        out = out.view((batch_size, self.r, nchannels // self.r, H, W))
        out = out.permute(0, 2, 3, 4, 1)
        out = out.contiguous().view((batch_size, nchannels // self.r, H, -1))
        return out


class DenseBlock(nn.Module):
    def __init__(self, h, kernel_size=(2, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.h = h
        self.depth = depth
        self.dense_block = nn.ModuleList([])
        for i in range(depth):
            dilation = 2 ** i
            pad_length = dilation
            dense_conv = nn.Sequential(
                nn.ConstantPad2d((1, 1, pad_length, 0), value=0.),
                nn.Conv2d(h.dense_channel * (i + 1), h.dense_channel, kernel_size, dilation=(dilation, 1)),
                nn.InstanceNorm2d(h.dense_channel, affine=True),
                nn.PReLU(h.dense_channel)
            )
            self.dense_block.append(dense_conv)

    def forward(self, x):
        skip = x
        for i in range(self.depth):
            x = self.dense_block[i](skip)
            skip = torch.cat([x, skip], dim=1)
        return x


class DenseEncoder(nn.Module):
    def __init__(self, h, in_channel):
        super(DenseEncoder, self).__init__()
        self.h = h
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, h.dense_channel, (1, 1)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel))

        self.dense_block = DenseBlock(h, depth=4)

        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2), padding=(0, 1)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel))

    def forward(self, x):
        x = self.dense_conv_1(x)  # [b, 64, T, F]
        x = self.dense_block(x)   # [b, 64, T, F]
        x = self.dense_conv_2(x)  # [b, 64, T, F//2]
        return x


class MaskDecoder(nn.Module):
    def __init__(self, h, out_channel=1):
        super(MaskDecoder, self).__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.mask_conv = nn.Sequential(
            SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel),
            nn.Conv2d(h.dense_channel, out_channel, (1, 2))
        )
        self.lsigmoid = LearnableSigmoid2d(h.n_fft // 2 + 1, beta=h.beta)

    def forward(self, x):
        x = self.dense_block(x)
        x = self.mask_conv(x)
        x = x.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
        x = self.lsigmoid(x)
        return x


class PhaseDecoder(nn.Module):
    def __init__(self, h, out_channel=1):
        super(PhaseDecoder, self).__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.phase_conv = nn.Sequential(
            SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel)
        )
        self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
        self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 2))

    def forward(self, x):
        x = self.dense_block(x)
        x = self.phase_conv(x)
        x_r = self.phase_conv_r(x)
        x_i = self.phase_conv_i(x)
        x = torch.atan2(x_i, x_r)
        x = x.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
        return x


class TSTransformerBlock(nn.Module):
    def __init__(self, h):
        super(TSTransformerBlock, self).__init__()
        self.h = h
        self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
        self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)

    def forward(self, x):
        b, c, t, f = x.size()
        x = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
        x = self.time_transformer(x) + x
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
        x = self.freq_transformer(x) + x
        x = x.view(b, t, f, c).permute(0, 3, 1, 2)
        return x


class MPNet(nn.Module):
    def __init__(self, config: MPNetConfig, num_tsblocks=4):
        super(MPNet, self).__init__()
        self.num_tsblocks = num_tsblocks
        self.dense_encoder = DenseEncoder(config, in_channel=2)

        self.TSTransformer = nn.ModuleList([])
        for i in range(num_tsblocks):
            self.TSTransformer.append(TSTransformerBlock(config))

        self.mask_decoder = MaskDecoder(config, out_channel=1)
        self.phase_decoder = PhaseDecoder(config, out_channel=1)

    def forward(self, noisy_amp, noisy_pha):  # [B, F, T]
        x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1)  # [B, 2, T, F]
        x = self.dense_encoder(x)
        for i in range(self.num_tsblocks):
            x = self.TSTransformer[i](x)
        denoised_amp = noisy_amp * self.mask_decoder(x)
        denoised_pha = self.phase_decoder(x)
        denoised_com = torch.stack(
            tensors=(
                denoised_amp * torch.cos(denoised_pha),
                denoised_amp * torch.sin(denoised_pha)
            ),
            dim=-1
        )
        return denoised_amp, denoised_pha, denoised_com


MODEL_FILE = "generator.pt"


class MPNetPretrainedModel(MPNet):
    def __init__(self,
                 config: MPNetConfig,
                 ):
        super(MPNetPretrainedModel, self).__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
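

# Hedged usage sketch (the checkpoint directory below is hypothetical, not part of this repo):
#
#   config = MPNetConfig()
#   model = MPNetPretrainedModel(config)
#   model.save_pretrained("path/to/checkpoint_dir")   # writes generator.pt and the config file
#   model = MPNetPretrainedModel.from_pretrained("path/to/checkpoint_dir")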


def phase_losses(phase_r, phase_g):
    # instantaneous phase (IP), group delay (GD, diff over frequency, dim=1) and
    # instantaneous angular frequency (IAF, diff over time, dim=2) losses, as in MP-SENet.
    ip_loss = torch.mean(anti_wrapping_function(phase_r - phase_g))
    gd_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)))
    iaf_loss = torch.mean(anti_wrapping_function(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)))
    return ip_loss, gd_loss, iaf_loss


def anti_wrapping_function(x):
    # wrap phase differences into [0, pi] so that angles differing by 2*pi are treated as equal
    return torch.abs(x - torch.round(x / (2 * np.pi)) * 2 * np.pi)


# def pesq_score(utts_r, utts_g, h):
#     pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
#         utts_r[i].squeeze().cpu().numpy(),
#         utts_g[i].squeeze().cpu().numpy(),
#         h.sample_rate, )
#         for i in range(len(utts_r)))
#     pesq_score = np.mean(pesq_score)
#
#     return pesq_score
#
#
# def eval_pesq(clean_utt, esti_utt, sr):
#     try:
#         mode = "nb" if sr == 8000 else "wb"
#         pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
#     except:
#         pesq_score = -1
#
#     return pesq_score
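

# A minimal end-to-end inference sketch (not part of the original file): `enhance_waveform`
# is a hypothetical helper that assumes the usual MP-SENet pipeline of
# STFT -> (magnitude, phase) -> MPNet -> ISTFT. The reference implementation may additionally
# apply power-law magnitude compression, which is omitted here; the Hann window is also an
# assumption.
def enhance_waveform(model: MPNet, noisy: torch.Tensor, config: MPNetConfig) -> torch.Tensor:
    # noisy: [B, num_samples]
    window = torch.hann_window(config.win_size)
    spec = torch.stft(
        noisy,
        n_fft=config.n_fft,
        hop_length=config.hop_size,
        win_length=config.win_size,
        window=window,
        return_complex=True,
    )  # [B, F, T], complex
    noisy_amp = torch.abs(spec)
    noisy_pha = torch.angle(spec)
    denoised_amp, denoised_pha, _ = model.forward(noisy_amp, noisy_pha)
    denoised_spec = denoised_amp * torch.exp(1j * denoised_pha)
    return torch.istft(
        denoised_spec,
        n_fft=config.n_fft,
        hop_length=config.hop_size,
        win_length=config.win_size,
        window=window,
    )  # [B, num_samples]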


def main():
    import torchaudio

    config = MPNetConfig()
    model = MPNet(config=config)

    transformer = torchaudio.transforms.Spectrogram(
        n_fft=config.n_fft,
        win_length=config.win_size,
        hop_length=config.hop_size,
        window_fn=torch.hamming_window,
    )

    inputs = torch.randn(size=(1, 32000), dtype=torch.float32)
    spec = transformer.forward(inputs)
    print(spec.shape)

    # shape-only smoke test: the same spectrogram is passed as both the amplitude and the phase input
    denoised_amp, denoised_pha, denoised_com = model.forward(spec, spec)
    print(denoised_amp.shape)
    print(denoised_pha.shape)
    print(denoised_com.shape)
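
    # quick sanity check of the phase losses (hedged: the noisy spectrogram stands in as the
    # "reference" phase purely to exercise phase_losses, not as a meaningful target)
    ip_loss, gd_loss, iaf_loss = phase_losses(spec, denoised_pha)
    print(ip_loss, gd_loss, iaf_loss)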
    return


if __name__ == '__main__':
    main()