Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| import argparse | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import sys | |
| import uuid | |
| pwd = os.path.abspath(os.path.dirname(__file__)) | |
| sys.path.append(os.path.join(pwd, "../../")) | |
| import librosa | |
| import numpy as np | |
| import pandas as pd | |
| from scipy.io import wavfile | |
| import torch | |
| import torch.nn as nn | |
| import torchaudio | |
| from tqdm import tqdm | |
| from toolbox.torchaudio.models.lstm.modeling_lstm import LstmPretrainedModel | |
def get_args():
    """Parse the command-line options of the evaluation script.

    Returns:
        argparse.Namespace: carries ``valid_dataset``, ``model_dir`` and
        ``evaluation_audio_dir`` (all ``str``) plus ``limit`` (``int``).
    """
    # (flag, default, type) triples, registered in declaration order.
    option_specs = (
        ("--valid_dataset", "valid.xlsx", str),
        ("--model_dir", "serialization_dir/best", str),
        ("--evaluation_audio_dir", "evaluation_audio_dir", str),
        ("--limit", 10, int),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value, value_type in option_specs:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    return arg_parser.parse_args()
def logging_config():
    """Configure root logging at INFO level and return this module's logger.

    Returns:
        logging.Logger: the logger named after this module (``__name__``).
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"

    # basicConfig installs a stream handler on the root logger by itself, so
    # no separate StreamHandler is built here: an unattached one is dead code,
    # and attaching a second one would emit every record twice.
    logging.basicConfig(format=fmt,
                        datefmt="%m/%d/%Y %H:%M:%S",
                        level=logging.INFO)

    return logging.getLogger(__name__)
def mix_speech_and_noise(speech: np.ndarray, noise: np.ndarray, snr_db: float):
    """Mix ``noise`` into ``speech`` at the requested signal-to-noise ratio.

    Both signals are truncated to the shorter of the two lengths. The noise is
    then rescaled so that ``speech_power / noise_power == 10 ** (snr_db / 10)``
    before being added to the speech.

    Args:
        speech: 1-D speech samples (the original comment expects np.float32
            values in (-1, 1)).
        noise: 1-D noise samples, same convention as ``speech``.
        snr_db: target signal-to-noise ratio in decibels.

    Returns:
        np.ndarray: the noisy signal, ``min(len(speech), len(noise))`` samples
        long. If the truncated signals are empty, or the noise is completely
        silent, a copy of the truncated speech is returned as-is, because the
        noise cannot be scaled to any power in that case.
    """
    length = min(len(speech), len(noise))
    speech = speech[:length]
    noise = noise[:length]

    # Nothing to mix; also avoids the mean-of-empty-slice NaN below.
    if length == 0:
        return np.array(speech)

    # Silent noise would give 0 / 0 = NaN in the rescaling step, which would
    # poison the whole mixture; the only sensible mix is the clean speech.
    noise_rms = np.sqrt(np.mean(np.square(noise)))
    if noise_rms == 0:
        return np.array(speech)

    speech_power = np.mean(np.square(speech))
    target_noise_power = speech_power / (10 ** (snr_db / 10))
    noise_adjusted = np.sqrt(target_noise_power) * noise / noise_rms

    return speech + noise_adjusted
# Framing shared by all three transforms so the forward and inverse STFT stay
# mutually consistent: 512-point FFT, 200-sample Hamming window, 80-sample hop.
_STFT_KWARGS = dict(
    n_fft=512,
    win_length=200,
    hop_length=80,
    window_fn=torch.hamming_window,
)

# Power spectrogram (power=2.0).
stft_power = torchaudio.transforms.Spectrogram(power=2.0, **_STFT_KWARGS)

# power=None makes Spectrogram return the complex spectrum instead of a
# magnitude/power, which is what the inverse transform needs.
stft_complex = torchaudio.transforms.Spectrogram(power=None, **_STFT_KWARGS)

# Inverse STFT matching the framing above.
istft = torchaudio.transforms.InverseSpectrogram(**_STFT_KWARGS)
def enhance(mix_spec_complex: torch.Tensor, speech_irm_prediction: torch.Tensor):
    """Split a noisy complex spectrogram into speech and noise waveforms.

    The predicted mask is applied to the mixture spectrogram to obtain the
    speech estimate, and its complement (``1.0 - mask``) to obtain the noise
    estimate; both are converted back to waveforms with the module-level
    ``istft`` transform.

    Args:
        mix_spec_complex: complex spectrogram of the noisy mixture.
        speech_irm_prediction: predicted speech mask; it must broadcast
            against ``mix_spec_complex``.

    Returns:
        tuple: ``(speech_wave, noise_wave)`` as CPU tensors.
    """
    # Detach and move to CPU before masking, as the inverse transform lives
    # on the CPU.
    spec_cpu = mix_spec_complex.detach().cpu()
    speech_mask = speech_irm_prediction.detach().cpu()
    noise_mask = 1.0 - speech_mask

    masked_specs = (spec_cpu * speech_mask, spec_cpu * noise_mask)
    speech_wave, noise_wave = (istft.forward(spec) for spec in masked_specs)
    return speech_wave, noise_wave
def save_audios(noise_wave: torch.Tensor,
                speech_wave: torch.Tensor,
                mix_wave: torch.Tensor,
                speech_wave_enhanced: torch.Tensor,
                noise_wave_enhanced: torch.Tensor,
                output_dir: str,
                sample_rate: int = 8000,
                ):
    """Write the five waveforms of one example into a fresh sub-directory.

    A new directory named with a random UUID4 is created under ``output_dir``
    and each waveform is saved there as a ``.wav`` file via ``torchaudio.save``.

    Args:
        noise_wave: noise waveform.
        speech_wave: clean speech waveform.
        mix_wave: noisy mixture waveform.
        speech_wave_enhanced: speech estimate produced by the model.
        noise_wave_enhanced: noise estimate produced by the model.
        output_dir: parent directory for the per-example folder.
        sample_rate: sample rate written to every file.

    Returns:
        str: POSIX-style path of the directory that was created.
    """
    target_dir = Path(output_dir) / str(uuid.uuid4())
    target_dir.mkdir(parents=True, exist_ok=True)

    # (file name, waveform) pairs, written in this order.
    outputs = (
        ("noise_wave.wav", noise_wave),
        ("speech_wave.wav", speech_wave),
        ("mix_wave.wav", mix_wave),
        ("speech_wave_enhanced.wav", speech_wave_enhanced),
        ("noise_wave_enhanced.wav", noise_wave_enhanced),
    )
    for wav_name, waveform in outputs:
        torchaudio.save(target_dir / wav_name, waveform, sample_rate)

    return target_dir.as_posix()
def main():
    """Evaluate a pretrained LSTM mask model on rows of a validation sheet.

    For at most ``--limit`` rows of the ``--valid_dataset`` Excel file, the
    referenced speech and noise clips are loaded at 8 kHz and mixed at the
    row's ``snr_db``. The model's mask prediction for the mixture's power
    spectrogram is scored against the ratio-mask target with MSE, and the
    clean, noisy and enhanced waveforms are written under
    ``--evaluation_audio_dir``. The running mean loss is shown on the
    progress bar.
    """
    args = get_args()
    logger = logging_config()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("GPU available count: {}; device: {}".format(n_gpu, device))

    logger.info("prepare model")
    model = LstmPretrainedModel.from_pretrained(
        pretrained_model_name_or_path=args.model_dir,
    )
    model.to(device)
    model.eval()

    logger.info("prepare loss_fn")
    mse_loss = nn.MSELoss(
        reduction="mean",
    )

    logger.info("read excel")
    df = pd.read_excel(args.valid_dataset)

    # Evaluate exactly `limit` rows (never more than the sheet holds). Counting
    # rows positionally avoids comparing against the DataFrame index label,
    # which processed two extra rows and breaks on a non 0..N-1 index.
    num_examples = min(len(df), max(args.limit, 0))

    sample_rate = 8000

    total_loss = 0.
    total_examples = 0.

    # The context manager closes the bar even if an example raises.
    with tqdm(total=num_examples, desc="Evaluation") as progress_bar:
        for _, row in df.head(num_examples).iterrows():
            noise_filename = row["noise_filename"]
            noise_offset = row["noise_offset"]
            noise_duration = row["noise_duration"]

            speech_filename = row["speech_filename"]
            speech_offset = row["speech_offset"]
            speech_duration = row["speech_duration"]

            snr_db = row["snr_db"]

            noise_wave, _ = librosa.load(
                noise_filename,
                sr=sample_rate,
                offset=noise_offset,
                duration=noise_duration,
            )
            speech_wave, _ = librosa.load(
                speech_filename,
                sr=sample_rate,
                offset=speech_offset,
                duration=speech_duration,
            )

            # Trim both sources to their common length so that their
            # spectrograms have the same number of frames as the mixture
            # (mix_speech_and_noise truncates to the shorter clip as well).
            num_samples = min(len(speech_wave), len(noise_wave))
            speech_wave = speech_wave[:num_samples]
            noise_wave = noise_wave[:num_samples]

            mix_wave: np.ndarray = mix_speech_and_noise(
                speech=speech_wave,
                noise=noise_wave,
                snr_db=snr_db,
            )

            # Add a leading channel/batch dimension of size 1.
            noise_wave = torch.tensor(noise_wave, dtype=torch.float32).unsqueeze(dim=0)
            speech_wave = torch.tensor(speech_wave, dtype=torch.float32).unsqueeze(dim=0)
            mix_wave: torch.Tensor = torch.tensor(mix_wave, dtype=torch.float32).unsqueeze(dim=0)

            noise_spec: torch.Tensor = stft_power.forward(noise_wave)
            speech_spec: torch.Tensor = stft_power.forward(speech_wave)
            mix_spec: torch.Tensor = stft_power.forward(mix_wave)
            mix_spec_complex: torch.Tensor = stft_complex.forward(mix_wave)

            # Ratio-mask target. Clamping the denominator to the smallest
            # positive normal float only affects bins where both powers are
            # zero, turning 0 / 0 = NaN (which would make the loss NaN) into 0.
            # NOTE(review): the target uses the un-scaled noise spectrum while
            # the mixture contains SNR-scaled noise; confirm this matches how
            # the training targets were built before changing it.
            denominator = torch.clamp(
                noise_spec + speech_spec,
                min=torch.finfo(speech_spec.dtype).tiny,
            )
            speech_irm = speech_spec / denominator
            # Exponent 1.0 leaves the mask unchanged; kept as the tuning knob.
            speech_irm = torch.pow(speech_irm, 1.0)

            mix_spec = mix_spec.to(device)
            speech_irm_target = speech_irm.to(device)

            with torch.no_grad():
                speech_irm_prediction = model.forward(mix_spec)
                loss = mse_loss.forward(speech_irm_prediction, speech_irm_target)

            speech_wave_enhanced, noise_wave_enhanced = enhance(mix_spec_complex, speech_irm_prediction)
            save_audios(
                noise_wave,
                speech_wave,
                mix_wave,
                speech_wave_enhanced,
                noise_wave_enhanced,
                args.evaluation_audio_dir,
                sample_rate=sample_rate,
            )

            total_loss += loss.item()
            total_examples += mix_spec.size(0)

            evaluation_loss = round(total_loss / total_examples, 4)

            progress_bar.update(1)
            progress_bar.set_postfix({
                "evaluation_loss": evaluation_loss,
            })
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()