# NOTE: paste-site residue preserved as comments so the file stays importable:
# Spaces:
# Runtime error
# Runtime error
| import logging | |
| import os | |
| import time | |
| from typing import Dict, List, NoReturn | |
| import librosa | |
| import musdb | |
| import museval | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch.nn as nn | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from bytesep.callbacks.base_callbacks import SaveCheckpointsCallback | |
| from bytesep.dataset_creation.pack_audios_to_hdf5s.musdb18 import preprocess_audio | |
| from bytesep.inference import Separator | |
| from bytesep.utils import StatisticsContainer, read_yaml | |
def get_musdb18_callbacks(
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Build the MUSDB18 training callbacks described by a config yaml.

    Args:
        config_yaml: str, path of the configuration yaml
        workspace: str, workspace root directory
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, path to save statistics
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str, e.g., 'cuda'

    Returns:
        callbacks: List[pl.Callback]
    """
    configs = read_yaml(config_yaml)

    train_configs = configs['train']
    evaluate_configs = configs['evaluate']

    task_name = configs['task_name']
    evaluation_callback_name = train_configs['evaluation_callback']
    target_source_types = train_configs['target_source_types']
    input_channels = train_configs['channels']
    sample_rate = train_configs['sample_rate']
    evaluate_step_frequency = train_configs['evaluate_step_frequency']
    save_step_frequency = train_configs['save_step_frequency']

    evaluation_audios_dir = os.path.join(workspace, "evaluation_audios", task_name)

    test_segment_seconds = evaluate_configs['segment_seconds']
    test_segment_samples = int(test_segment_seconds * sample_rate)
    test_batch_size = evaluate_configs['batch_size']

    # Callback that periodically saves model checkpoints.
    save_checkpoints_callback = SaveCheckpointsCallback(
        model=model,
        checkpoints_dir=checkpoints_dir,
        save_step_frequency=save_step_frequency,
    )

    # Resolve which evaluation callback class the config asks for.
    EvaluationCallback = _get_evaluation_callback_class(evaluation_callback_name)

    # Container that accumulates evaluation statistics on disk.
    statistics_container = StatisticsContainer(statistics_path)

    # Keyword arguments shared by the train- and test-split callbacks.
    common_kwargs = dict(
        dataset_dir=evaluation_audios_dir,
        model=model,
        target_source_types=target_source_types,
        input_channels=input_channels,
        sample_rate=sample_rate,
        segment_samples=test_segment_samples,
        batch_size=test_batch_size,
        device=evaluate_device,
        evaluate_step_frequency=evaluate_step_frequency,
        logger=logger,
        statistics_container=statistics_container,
    )

    evaluate_train_callback = EvaluationCallback(split='train', **common_kwargs)
    evaluate_test_callback = EvaluationCallback(split='test', **common_kwargs)

    # The train-split callback is constructed but intentionally not returned;
    # append `evaluate_train_callback` below to also evaluate on the train split.
    callbacks = [save_checkpoints_callback, evaluate_test_callback]

    return callbacks
def _get_evaluation_callback_class(evaluation_callback) -> pl.Callback:
    r"""Map an evaluation callback name from the config yaml to its class.

    Args:
        evaluation_callback: str, e.g., 'Musdb18EvaluationCallback'

    Returns:
        the callback class itself (not an instance)

    Raises:
        NotImplementedError: if the name is not recognized.
    """
    # Use a single if/elif/else ladder; the original first `if` was detached
    # from the else branch, which made the dispatch misleading to read.
    if evaluation_callback == "Musdb18EvaluationCallback":
        return Musdb18EvaluationCallback

    elif evaluation_callback == 'Musdb18ConditionalEvaluationCallback':
        return Musdb18ConditionalEvaluationCallback

    else:
        # Name the unknown value so config errors are diagnosable.
        raise NotImplementedError(
            "Unknown evaluation callback: {}".format(evaluation_callback)
        )
class Musdb18EvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: List[str],
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate separation SDRs on MUSDB18 every
        #evaluate_step_frequency steps.

        Args:
            dataset_dir: str, root directory of the MUSDB18 evaluation audios
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        # A single-channel model implies mono preprocessing of all audios.
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Get waveform of all target source types.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T

                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                # Separate.
                input_dict = {'waveform': mixture}

                sep_wavs = self.separator.separate(input_dict)
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                # Post process separation results.
                # NOTE(review): targets above were resampled to self.sample_rate
                # while this resamples the separation back to track.rate; the
                # two only line up when sample_rate == track.rate — confirm.
                sep_wavs = preprocess_audio(
                    audio=sep_wavs,
                    mono=self.mono,
                    origin_sr=self.sample_rate,
                    sr=track.rate,
                    resample_type=self.resample_type,
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                # Pad/trim to the mixture's length so museval sees equal sizes.
                sep_wavs = librosa.util.fix_length(
                    sep_wavs, size=mixture.shape[1], axis=1
                )
                # sep_wavs: (target_sources_num * channels_num, audio_samples)

                sep_wav_dict = get_separated_wavs_from_simo_output(
                    sep_wavs, self.input_channels, self.target_source_types
                )
                # sep_wav_dict: dict, e.g., {
                #     'vocals': (channels_num, audio_samples),
                #     'bass': (channels_num, audio_samples),
                #     ...,
                # }

                # Evaluate for all target source types.
                for source_type in self.target_source_types:
                    # E.g., ['vocals', 'bass', ...]

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan).
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav_dict[source_type].T]
                    )

                    # Median over museval's per-window SDRs, ignoring NaN windows.
                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")

            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:
                # E.g., ['vocals', 'bass', ...]

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )

                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            # Fixed typo in the original log message ("Evlauation").
            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()
def get_separated_wavs_from_simo_output(x, input_channels, target_source_types) -> Dict:
    r"""Split the stacked output of a single input multiple output (SIMO)
    separation system into one waveform per target source.

    Args:
        x: (target_sources_num * channels_num, audio_samples)
        input_channels: int, channels_num of each source
        target_source_types: List[str], e.g., ['vocals', 'bass', ...]

    Returns:
        output_dict: dict, e.g., {
            'vocals': (channels_num, audio_samples),
            'bass': (channels_num, audio_samples),
            ...,
        }
    """
    # Source k occupies rows [k * input_channels, (k + 1) * input_channels).
    return {
        source_type: x[k * input_channels : (k + 1) * input_channels]
        for k, source_type in enumerate(target_source_types)
    }
class Musdb18ConditionalEvaluationCallback(pl.Callback):
    def __init__(
        self,
        dataset_dir: str,
        model: nn.Module,
        target_source_types: List[str],
        input_channels: int,
        split: str,
        sample_rate: int,
        segment_samples: int,
        batch_size: int,
        device: str,
        evaluate_step_frequency: int,
        logger: pl.loggers.TensorBoardLogger,
        statistics_container: StatisticsContainer,
    ):
        r"""Callback to evaluate separation SDRs of a conditional model on
        MUSDB18 every #evaluate_step_frequency steps. Unlike
        Musdb18EvaluationCallback, the model is queried once per target source
        with a one-hot condition vector selecting that source.

        Args:
            dataset_dir: str, root directory of the MUSDB18 evaluation audios
            model: nn.Module
            target_source_types: List[str], e.g., ['vocals', 'bass', ...]
            input_channels: int
            split: 'train' | 'test'
            sample_rate: int
            segment_samples: int, length of segments to be input to a model, e.g., 44100*30
            batch_size: int, e.g., 12
            device: str, e.g., 'cuda'
            evaluate_step_frequency: int, evaluate every #evaluate_step_frequency steps
            logger: pl.loggers.TensorBoardLogger
            statistics_container: StatisticsContainer
        """
        self.model = model
        self.target_source_types = target_source_types
        self.input_channels = input_channels
        self.sample_rate = sample_rate
        self.split = split
        self.segment_samples = segment_samples
        self.evaluate_step_frequency = evaluate_step_frequency
        self.logger = logger
        self.statistics_container = statistics_container
        # A single-channel model implies mono preprocessing of all audios.
        self.mono = input_channels == 1
        self.resample_type = "kaiser_fast"

        self.mus = musdb.DB(root=dataset_dir, subsets=[split])

        error_msg = "The directory {} is empty!".format(dataset_dir)
        assert len(self.mus) > 0, error_msg

        # separator
        self.separator = Separator(model, self.segment_samples, batch_size, device)

    def on_batch_end(self, trainer: pl.Trainer, _) -> NoReturn:
        r"""Evaluate separation SDRs of audio recordings."""
        global_step = trainer.global_step

        if global_step % self.evaluate_step_frequency == 0:

            sdr_dict = {}

            logging.info("--- Step {} ---".format(global_step))
            logging.info("Total {} pieces for evaluation:".format(len(self.mus.tracks)))

            eval_time = time.time()

            for track in self.mus.tracks:

                audio_name = track.name

                # Get waveform of mixture.
                mixture = track.audio.T
                # (channels_num, audio_samples)

                mixture = preprocess_audio(
                    audio=mixture,
                    mono=self.mono,
                    origin_sr=track.rate,
                    sr=self.sample_rate,
                    resample_type=self.resample_type,
                )
                # (channels_num, audio_samples)

                target_dict = {}
                sdr_dict[audio_name] = {}

                # Separate and evaluate one target source type at a time.
                for j, source_type in enumerate(self.target_source_types):
                    # E.g., ['vocals', 'bass', ...]

                    audio = track.targets[source_type].audio.T

                    audio = preprocess_audio(
                        audio=audio,
                        mono=self.mono,
                        origin_sr=track.rate,
                        sr=self.sample_rate,
                        resample_type=self.resample_type,
                    )
                    # (channels_num, audio_samples)

                    target_dict[source_type] = audio
                    # (channels_num, audio_samples)

                    # One-hot condition vector selecting this source type.
                    condition = np.zeros(len(self.target_source_types))
                    condition[j] = 1

                    input_dict = {'waveform': mixture, 'condition': condition}

                    sep_wav = self.separator.separate(input_dict)
                    # sep_wav: (channels_num, audio_samples)

                    # NOTE(review): the target above is at self.sample_rate
                    # while this resamples the separation back to track.rate;
                    # they only line up when sample_rate == track.rate — confirm.
                    sep_wav = preprocess_audio(
                        audio=sep_wav,
                        mono=self.mono,
                        origin_sr=self.sample_rate,
                        sr=track.rate,
                        resample_type=self.resample_type,
                    )
                    # sep_wav: (channels_num, audio_samples)

                    # Pad/trim to the mixture's length so museval sees equal sizes.
                    sep_wav = librosa.util.fix_length(
                        sep_wav, size=mixture.shape[1], axis=1
                    )
                    # sep_wav: (channels_num, audio_samples)

                    # Calculate SDR using museval, input shape should be: (nsrc, nsampl, nchan)
                    (sdrs, _, _, _) = museval.evaluate(
                        [target_dict[source_type].T], [sep_wav.T]
                    )

                    # Median over museval's per-window SDRs, ignoring NaN windows.
                    sdr = np.nanmedian(sdrs)
                    sdr_dict[audio_name][source_type] = sdr

                    logging.info(
                        "{}, {}, sdr: {:.3f}".format(audio_name, source_type, sdr)
                    )

            logging.info("-----------------------------")

            median_sdr_dict = {}

            # Calculate median SDRs of all songs.
            for source_type in self.target_source_types:

                median_sdr = np.median(
                    [
                        sdr_dict[audio_name][source_type]
                        for audio_name in sdr_dict.keys()
                    ]
                )

                median_sdr_dict[source_type] = median_sdr

                logging.info(
                    "Step: {}, {}, Median SDR: {:.3f}".format(
                        global_step, source_type, median_sdr
                    )
                )

            # Fixed typo in the original log message ("Evlauation").
            logging.info("Evaluation time: {:.3f}".format(time.time() - eval_time))

            statistics = {"sdr_dict": sdr_dict, "median_sdr_dict": median_sdr_dict}
            self.statistics_container.append(global_step, statistics, self.split)
            self.statistics_container.dump()