# NOTE(review): this file was recovered from a Hugging Face Spaces
# "Build error" page dump; the surrounding table markup has been stripped.
import auraloss
import numpy as np
import pytorch_lightning as pl

from deepafx_st.callbacks.plotting import plot_multi_spectrum
from deepafx_st.metrics import (
    LoudnessError,
    SpectralCentroidError,
    CrestFactorError,
    PESQ,
    MelSpectralDistance,
)
class LogAudioCallback(pl.callbacks.Callback):
    """Log validation audio, spectrogram comparisons, and audio-quality
    metrics to the experiment logger (TensorBoard-style interface).

    Per-batch validation outputs are accumulated in ``self.outputs`` and
    reduced to epoch-mean metrics in :meth:`on_validation_end`.

    Args:
        num_examples (int): Number of examples from the first validation
            batch to log as audio/spectrograms. Default: 4
        peak_normalize (bool): Peak-normalize audio before logging.
            Default: True
        sample_rate (int): Sample rate used to configure the metric
            objects. Default: 22050
    """

    def __init__(self, num_examples=4, peak_normalize=True, sample_rate=22050):
        super().__init__()
        # BUG FIX: the original hard-coded 4 here, silently ignoring the
        # `num_examples` argument passed by the caller.
        self.num_examples = num_examples
        self.peak_normalize = peak_normalize
        self.metrics = {
            "PESQ": PESQ(sample_rate),
            "MRSTFT": auraloss.freq.MultiResolutionSTFTLoss(
                fft_sizes=[32, 128, 512, 2048, 8192, 32768],
                hop_sizes=[16, 64, 256, 1024, 4096, 16384],
                win_lengths=[32, 128, 512, 2048, 8192, 32768],
                w_sc=0.0,
                w_phs=0.0,
                w_lin_mag=1.0,
                w_log_mag=1.0,
            ),
            "MSD": MelSpectralDistance(sample_rate),
            "SCE": SpectralCentroidError(sample_rate),
            "CFE": CrestFactorError(),
            "LUFS": LoudnessError(sample_rate),
        }
        # Per-epoch accumulator of validation step output dicts;
        # consumed and cleared in `on_validation_end`.
        self.outputs = []

    def on_validation_batch_end(
        self,
        trainer,
        pl_module,
        outputs,
        batch,
        batch_idx,
        dataloader_idx,
    ):
        """Called when the validation batch ends.

        Stores `outputs` for epoch-end metric computation and, for the
        first batch only, logs up to `num_examples` audio examples.
        Expects `outputs` to contain at least "x", "y", "y_hat" tensors
        indexed as (batch, channels, samples) — TODO confirm against the
        LightningModule's validation_step.
        """
        if outputs is None:
            return
        self.outputs.append(outputs)
        if batch_idx == 0:
            # Never index past the actual batch size (the last batch may
            # be smaller than `num_examples`).
            num = min(self.num_examples, outputs["x"].shape[0])
            for n in range(num):
                self.log_audio(
                    outputs,
                    n,
                    pl_module.hparams.sample_rate,
                    pl_module.hparams.val_length,
                    trainer.global_step,
                    trainer.logger,
                )

    def on_validation_end(self, trainer, pl_module):
        """Compute epoch-mean metrics over all stored outputs, log them,
        and clear the accumulator for the next epoch."""
        metrics = {name: [] for name in self.metrics}
        for output in self.outputs:
            for metric_name, metric in self.metrics.items():
                try:
                    val = metric(output["y_hat"], output["y"])
                    metrics[metric_name].append(val)
                except Exception:
                    # Best-effort: some metrics (e.g. PESQ) can fail on
                    # degenerate/silent signals — skip those batches.
                    pass
        # log final mean metrics; skip metrics with no successful values,
        # since np.mean([]) would produce NaN and a RuntimeWarning
        for metric_name, vals in metrics.items():
            if not vals:
                continue
            trainer.logger.experiment.add_scalar(
                f"metrics/{metric_name}", np.mean(vals), trainer.global_step
            )
        # clear outputs
        self.outputs = []

    def compute_metrics(self, metrics_dict, outputs, batch_idx, global_step):
        """Compute every configured metric for one example and append the
        values into `metrics_dict` (keyed by metric name).

        Failures of individual metrics are swallowed so one bad example
        cannot abort the whole evaluation.
        """
        # extract audio for the selected example
        y = outputs["y"][batch_idx, ...].float()
        y_hat = outputs["y_hat"][batch_idx, ...].float()
        # compute all metrics; each expects (batch, channel, samples)
        for metric_name, metric in self.metrics.items():
            try:
                val = metric(y_hat.view(1, 1, -1), y.view(1, 1, -1))
                metrics_dict[metric_name].append(val)
            except Exception:
                pass

    @staticmethod
    def _peak_normalize(audio):
        """Scale `audio` so its absolute peak is 1.0.

        Silent signals are returned unchanged — the original divided
        unconditionally, which yields NaN (0/0) on all-zero audio.
        """
        peak = audio.abs().max()
        return audio / peak if peak > 0 else audio

    def log_audio(self, outputs, batch_idx, sample_rate, n_fft, global_step, logger):
        """Log input (x), target (y), estimate (y_hat), optional reference
        (y_ref) audio plus a spectrum-comparison image for one example.

        Args:
            outputs (dict): Validation step outputs with "x", "y", "y_hat"
                (and optionally "y_ref") tensors.
            batch_idx (int): Example index within the batch.
            sample_rate (int): Audio sample rate for the logger.
            n_fft (int): FFT size for the spectrum comparison plot.
            global_step (int): Trainer global step for logging.
            logger: Experiment logger exposing a TensorBoard-style
                `experiment.add_audio` / `add_image` interface.
        """
        x = outputs["x"][batch_idx, ...].float()
        y = outputs["y"][batch_idx, ...].float()
        y_hat = outputs["y_hat"][batch_idx, ...].float()

        if self.peak_normalize:
            x = self._peak_normalize(x)
            y = self._peak_normalize(y)
            y_hat = self._peak_normalize(y_hat)

        # log only the first channel of each signal
        logger.experiment.add_audio(
            f"x/{batch_idx+1}",
            x[0:1, :],
            global_step,
            sample_rate=sample_rate,
        )
        logger.experiment.add_audio(
            f"y/{batch_idx+1}",
            y[0:1, :],
            global_step,
            sample_rate=sample_rate,
        )
        logger.experiment.add_audio(
            f"y_hat/{batch_idx+1}",
            y_hat[0:1, :],
            global_step,
            sample_rate=sample_rate,
        )
        if "y_ref" in outputs:
            y_ref = outputs["y_ref"][batch_idx, ...].float()
            if self.peak_normalize:
                y_ref = self._peak_normalize(y_ref)
            logger.experiment.add_audio(
                f"y_ref/{batch_idx+1}",
                y_ref[0:1, :],
                global_step,
                sample_rate=sample_rate,
            )
        logger.experiment.add_image(
            f"spec/{batch_idx+1}",
            compare_spectra(
                y_hat[0:1, :],
                y[0:1, :],
                x[0:1, :],
                sample_rate=sample_rate,
                n_fft=n_fft,
            ),
            global_step,
        )
def compare_spectra(
    deepafx_y_hat, y, x, baseline_y_hat=None, sample_rate=44100, n_fft=16384
):
    """Render a multi-signal spectrum comparison image.

    Plots the corrupted input, an optional baseline estimate, the DeepAFx
    estimate, and the target on a single figure via `plot_multi_spectrum`.

    Args:
        deepafx_y_hat: DeepAFx model output signal.
        y: Target signal.
        x: Corrupted input signal.
        baseline_y_hat: Optional baseline model output. Default: None
        sample_rate (int): Sample rate of the signals. Default: 44100
        n_fft (int): FFT size for the spectra. Default: 16384

    Returns:
        The image produced by `plot_multi_spectrum`.
    """
    signals = [x]
    legend = ["Corrupted"]
    if baseline_y_hat is not None:
        signals.append(baseline_y_hat)
        legend.append("Baseline")
    signals += [deepafx_y_hat, y]
    legend += ["DeepAFx", "Target"]
    return plot_multi_spectrum(
        ys=signals,
        legend=legend,
        sample_rate=sample_rate,
        n_fft=n_fft,
    )