Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import logging | |
| from pathlib import Path | |
| import shutil | |
| import tempfile | |
| import zipfile | |
| import librosa | |
| import numpy as np | |
| import torch | |
| import torchaudio | |
| from project_settings import project_path | |
| from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig | |
| from toolbox.torchaudio.models.mpnet.modeling_mpnet import MPNetPretrainedModel, MODEL_FILE | |
| from toolbox.torchaudio.models.mpnet.utils import mag_pha_stft, mag_pha_istft | |
logger = logging.getLogger("toolbox")


class InferenceMPNet(object):
    """Speech-enhancement inference wrapper around a pretrained MPNet generator.

    Loads an ``MPNetConfig`` and an ``MPNetPretrainedModel`` from a model
    directory or a ``.zip`` archive, then runs STFT -> generator -> iSTFT on
    noisy waveforms.
    """

    def __init__(self, pretrained_model_path_or_zip_file: str, device: str = "cpu"):
        """
        :param pretrained_model_path_or_zip_file: model directory, or a ``.zip``
            archive whose top-level folder is named after the archive stem.
            A ``pathlib.Path`` is accepted as well.
        :param device: torch device string the model (and inputs) are moved to.
        """
        self.pretrained_model_path_or_zip_file = pretrained_model_path_or_zip_file
        self.device = torch.device(device)

        logger.info(f"loading model; model_file: {self.pretrained_model_path_or_zip_file}")
        config, generator = self.load_models(self.pretrained_model_path_or_zip_file)
        logger.info(f"model loading completed; model_file: {self.pretrained_model_path_or_zip_file}")

        self.config = config
        self.generator = generator
        self.generator.to(device)
        self.generator.eval()

    @staticmethod
    def _resolve_model_dir(model_path):
        """Return ``(model_dir, tmp_dir)`` for a model directory or ``.zip`` archive.

        A non-zip path is returned unchanged with ``tmp_dir = None`` (nothing to
        clean up). A ``.zip`` is extracted into a fresh, private temporary
        directory; ``model_dir`` is ``tmp_dir / <archive stem>`` and the caller
        owns ``tmp_dir`` and must delete it.
        """
        model_path = Path(model_path)
        if not model_path.name.endswith(".zip"):
            return model_path, None

        # A unique directory per call avoids collisions between concurrent loads.
        out_root = Path(tempfile.mkdtemp(prefix="nx_denoise_"))
        try:
            with zipfile.ZipFile(model_path.as_posix(), "r") as f_zip:
                f_zip.extractall(path=out_root)
        except Exception:
            shutil.rmtree(out_root, ignore_errors=True)
            raise
        return out_root / model_path.stem, out_root

    def load_models(self, model_path: str):
        """Load the MPNet config and generator from a directory or ``.zip``.

        :param model_path: model directory or ``.zip`` archive (str or Path).
        :return: ``(config, generator)``; the generator is moved to
            ``self.device`` and put in eval mode.
        """
        model_dir, tmp_dir = self._resolve_model_dir(model_path)
        try:
            config = MPNetConfig.from_pretrained(
                pretrained_model_name_or_path=model_dir.as_posix(),
            )
            generator = MPNetPretrainedModel.from_pretrained(
                pretrained_model_name_or_path=model_dir.as_posix(),
            )
        finally:
            # Only delete what we extracted ourselves; a user-supplied model
            # directory must never be removed.
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir, ignore_errors=True)

        generator.to(self.device)
        generator.eval()
        return config, generator

    def enhancement_by_ndarray(self, noisy_audio: np.ndarray) -> np.ndarray:
        """Enhance one waveform given as a 1-D array.

        :param noisy_audio: 1-D array of samples in [-1, 1].
        :return: the enhanced waveform as a 1-D numpy array.
        :raises AssertionError: if any sample lies outside [-1, 1].
        """
        # [n_samples] -> [batch_size=1, n_samples]
        batch = torch.tensor(noisy_audio, dtype=torch.float32).unsqueeze(dim=0)
        # enhanced shape: [channels, num_samples]
        enhanced_audio = self.enhancement_by_tensor(batch)
        # first channel -> [num_samples]
        return enhanced_audio[0].cpu().numpy()

    def enhancement_by_tensor(self, noisy_audio: torch.Tensor) -> torch.Tensor:
        """Enhance a batch of waveforms and return the FIRST batch item only.

        :param noisy_audio: tensor of shape ``[batch_size, num_samples]`` with
            samples in [-1, 1].
        :return: tensor of shape ``[1, num_samples]`` (channels-first) holding
            the enhanced first batch element.
        :raises AssertionError: if any sample lies outside [-1, 1].
        """
        if torch.max(noisy_audio) > 1 or torch.min(noisy_audio) < -1:
            raise AssertionError("The value range of audio samples should be between -1 and 1.")
        noisy_audio = noisy_audio.to(self.device)

        cfg = self.config
        with torch.no_grad():
            noisy_mag, noisy_pha, _ = mag_pha_stft(
                noisy_audio,
                cfg.n_fft, cfg.hop_size, cfg.win_size, cfg.compress_factor,
            )
            # Call the module rather than .forward() so registered hooks still run.
            mag_g, pha_g, _ = self.generator(noisy_mag, noisy_pha)
            audio_g = mag_pha_istft(
                mag_g, pha_g,
                cfg.n_fft, cfg.hop_size, cfg.win_size, cfg.compress_factor,
            )

        # [batch_size, num_samples] -> [batch_size, 1, num_samples]
        enhanced_audio = audio_g.detach().unsqueeze(dim=1)
        # Keep only the first batch item: [1, num_samples]
        return enhanced_audio[0]
def main():
    """Demo: denoise a 2-second excerpt of an example clip and save it as a WAV file.

    Loads the bundled MPNet zip model, reads the example recording at 8 kHz,
    keeps the 7 s .. 9 s window, enhances it, and writes ``enhanced_audio.wav``
    to the current working directory.
    """
    sample_rate = 8000
    infer_mpnet = InferenceMPNet(project_path / "trained_models/mpnet-aishell-1-epoch.zip")

    wav_path = project_path / "data/examples/ai_agent/dfaaf264-b5e3-4ca2-b5cb-5b6d637d962d_section_1.wav"
    waveform, _ = librosa.load(wav_path.as_posix(), sr=sample_rate)

    # Keep only the 7 s .. 9 s window, then add a leading batch axis.
    start, stop = int(7 * sample_rate), int(9 * sample_rate)
    batch = torch.tensor(waveform[start:stop], dtype=torch.float32)[None, :]

    enhanced = infer_mpnet.enhancement_by_tensor(batch)
    torchaudio.save("enhanced_audio.wav", enhanced.detach().cpu(), sample_rate)


if __name__ == "__main__":
    # Run the demo only when executed as a script, not on import.
    main()