| from os import path |
|
|
| import librosa as rosa |
| import matplotlib |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from pytorch_lightning.loggers import TensorBoardLogger |
| from pytorch_lightning.utilities import rank_zero_only |
|
|
| from utils.stft import STFTMag |
|
|
| matplotlib.use('Agg') |  # Select the non-interactive Agg backend so figures are rendered off-screen (no display required).
|
|
|
|
class TensorBoardLoggerExpanded(TensorBoardLogger):
    """TensorBoard logger extended with spectrogram-image and audio logging.

    The logger always writes under ``lightning_logs`` with an empty run name
    and with the default ``hp_metric`` disabled.

    Args:
        sr: Sample rate passed to TensorBoard when audio clips are logged.
    """

    # Panel titles / audio tags, in the order (y, y_low, y_recon).
    _SIGNAL_NAMES = ('y', 'y_low', 'y_recon')

    def __init__(self, sr=16000):
        super().__init__(save_dir='lightning_logs', default_hp_metric=False, name='')
        self.sr = sr
        # Project-local transform applied to 1-D inputs before plotting
        # (presumably waveform -> STFT magnitude; see utils.stft).
        self.stftmag = STFTMag()

    def fig2np(self, fig):
        """Return the rendered canvas of ``fig`` as an ``(H, W, 3)`` uint8 array.

        The canvas must already have been drawn (``fig.canvas.draw()``).

        ``buffer_rgba`` + ``np.frombuffer`` replace the former
        ``tostring_rgb`` + ``np.fromstring(..., sep='')`` pair: the canvas
        method was deprecated and later removed from Matplotlib, and NumPy's
        binary ``fromstring`` mode is deprecated/removed as well.
        """
        canvas = fig.canvas
        width, height = canvas.get_width_height()
        rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
        rgba = rgba.reshape(height, width, 4)
        # Drop the alpha channel; the copy detaches the result from the
        # canvas buffer so it stays valid after the figure is closed.
        return np.ascontiguousarray(rgba[..., :3])

    def plot_spectrogram_to_numpy(self, y, y_low, y_recon, step):
        """Render the three signals as stacked dB spectrograms.

        Args:
            y, y_low, y_recon: CPU tensors. A 1-D input is first passed through
                ``self.stftmag``; any other input is plotted as-is (assumed to
                already be a 2-D magnitude array -- TODO confirm with callers).
            step: Value shown in the figure title (``Epoch_<step>``).

        Returns:
            ``(H, W, 3)`` uint8 RGB image of the figure.
        """
        fig = plt.figure(figsize=(9, 15))
        try:
            fig.suptitle(f'Epoch_{step}')
            signals = (y, y_low, y_recon)
            for row, (title, signal) in enumerate(zip(self._SIGNAL_NAMES, signals), start=1):
                if signal.dim() == 1:
                    signal = self.stftmag(signal)
                # Explicit figure/axes API: does not depend on which figure or
                # axes pyplot currently considers "current".
                ax = fig.add_subplot(3, 1, row)
                ax.set_title(title)
                # dB scale relative to the panel maximum, floored at -80 dB.
                magnitude_db = rosa.amplitude_to_db(signal.numpy(), ref=np.max, top_db=80.)
                image = ax.imshow(magnitude_db,
                                  vmax=0.,
                                  aspect='auto',
                                  origin='lower',
                                  interpolation='none')
                fig.colorbar(image, ax=ax)
                ax.set_xlabel('Frames')
                ax.set_ylabel('Channels')
            # Layout is computed once, after every panel exists.
            fig.tight_layout()

            fig.canvas.draw()
            return self.fig2np(fig)
        finally:
            # Close this specific figure (not just the "current" one), even if
            # plotting raised, so figures do not accumulate across epochs.
            plt.close(fig)

    @rank_zero_only
    def log_spectrogram(self, y, y_low, y_recon, epoch):
        """Log a spectrogram comparison image for ``epoch`` (rank zero only).

        The image is written under the tag ``<save_dir>/result`` with ``epoch``
        as the global step, and the writer is flushed afterwards.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        spec_img = self.plot_spectrogram_to_numpy(y, y_low, y_recon, epoch)
        self.experiment.add_image(path.join(self.save_dir, 'result'),
                                  spec_img,
                                  epoch,
                                  dataformats='HWC')
        self.experiment.flush()

    @rank_zero_only
    def log_audio(self, y, y_low, y_recon, epoch):
        """Log the three signals as audio clips for ``epoch`` (rank zero only).

        Each clip is written under its own tag (``y``, ``y_low``, ``y_recon``)
        at sample rate ``self.sr``, and the writer is flushed afterwards.
        """
        y, y_low, y_recon = (t.detach().cpu() for t in (y, y_low, y_recon))
        for tag, clip in zip(self._SIGNAL_NAMES, (y, y_low, y_recon)):
            self.experiment.add_audio(tag, clip, epoch, self.sr)
        self.experiment.flush()
|
|