Spaces:
Build error
Build error
| import io | |
| import torch | |
| import PIL.Image | |
| import numpy as np | |
| import scipy.signal | |
| import librosa.display | |
| import matplotlib.pyplot as plt | |
| from torch.functional import Tensor | |
| from torchvision.transforms import ToTensor | |
def compute_comparison_spectrogram(
    x: np.ndarray,
    y: np.ndarray,
    sample_rate: float = 44100,
    n_fft: int = 2048,
    hop_length: int = 1024,
) -> Tensor:
    """Render the spectrograms of two signals, stacked vertically, as an image tensor.

    ``x`` is drawn in the top panel and ``y`` in the bottom panel, both with a
    time x-axis and a log-frequency y-axis. Each spectrogram is converted to dB
    relative to its own peak magnitude (``ref=np.max``), so the two panels are
    normalised independently of each other.

    Args:
        x: Signal for the top panel (passed to ``librosa.stft``).
        y: Signal for the bottom panel.
        sample_rate: Sample rate in Hz, used to label the axes.
        n_fft: FFT size of the STFT.
        hop_length: Hop size in samples of the STFT (also used for the time axis).

    Returns:
        The figure, encoded as JPEG, decoded with PIL and converted with
        ``torchvision.transforms.ToTensor``.
    """
    # dB-scaled magnitude spectrogram of each signal, each relative to its own max.
    spectra_db = [
        librosa.amplitude_to_db(
            np.abs(librosa.stft(signal, n_fft=n_fft, hop_length=hop_length)),
            ref=np.max,
        )
        for signal in (x, y)
    ]

    fig, axs = plt.subplots(figsize=(9, 6), nrows=2)
    try:
        for ax, S_db in zip(axs, spectra_db):
            librosa.display.specshow(
                S_db,
                ax=ax,
                hop_length=hop_length,
                x_axis="time",
                y_axis="log",
                sr=sample_rate,
            )
        fig.tight_layout()

        # Round-trip through an in-memory JPEG to get a raster image tensor.
        buf = io.BytesIO()
        fig.savefig(buf, format="jpeg")
        buf.seek(0)
        with PIL.Image.open(buf) as image:
            return ToTensor()(image)
    finally:
        # Close only the figure created here, even on error; closing "all"
        # would also destroy figures owned by the caller.
        plt.close(fig)
def plot_multi_spectrum(
    ys=None,
    Hs=None,
    legend=None,
    title="Spectrum",
    filename=None,
    sample_rate=44100,
    n_fft=1024,
    zero_mean=False,
):
    """Plot several magnitude spectra on one log-frequency axis and return the image.

    Pass either time-domain signals ``ys`` or precomputed magnitude spectra ``Hs``;
    ``Hs`` takes precedence when both are given. Each signal in ``ys`` is reduced
    with ``get_average_spectrum(y, n_fft)`` and then ``smooth_spectrum``.

    Args:
        ys: Signals to analyse; only used when ``Hs`` is None.
        Hs: Linear-magnitude spectra, one per curve. Each is assumed to have evenly
            spaced bins from 0 Hz to Nyquist (e.g. a one-sided FFT). Confirm this
            for callers passing their own spectra.
        legend: Labels, one per curve, applied in plotting order. A curve whose
            label contains "Target" is drawn as a dashed black line. Defaults to
            no legend.
        title: Plot title.
        filename: If given, the figure is also written to ``f"{filename}.png"``
            at 300 dpi.
        sample_rate: Sample rate in Hz; ``sample_rate / 2`` is the last bin's
            frequency.
        n_fft: FFT size used when computing spectra from ``ys``.
        zero_mean: If True, subtract each curve's mean dB level before plotting.

    Returns:
        The figure, encoded as JPEG, decoded with PIL and converted with
        ``torchvision.transforms.ToTensor``. The axes are fixed to
        100 Hz-11 kHz and -80..0 dB.

    Raises:
        ValueError: If neither ``ys`` nor ``Hs`` is provided.
    """
    if legend is None:
        legend = []
    if Hs is None:
        if ys is None:
            raise ValueError("plot_multi_spectrum requires either `ys` or `Hs`")
        Hs = [smooth_spectrum(get_average_spectrum(y, n_fft)) for y in ys]

    fig, ax1 = plt.subplots()
    try:
        for idx, H in enumerate(Hs):
            H = np.nan_to_num(np.asarray(H))
            # Negative values (e.g. smoothing overshoot) are floored at 0 so the
            # log below stays finite (about -160 dB) instead of turning into NaN.
            H = np.maximum(H, 0)
            H_dB = 20 * np.log10(H + 1e-8)
            if zero_mean:
                H_dB -= np.mean(H_dB)
            # Bin frequencies derived from this curve's own length (DC .. Nyquist),
            # so spectra with any bin count line up with their x values.
            freqs = np.linspace(0, sample_rate / 2, num=len(H_dB))
            label = legend[idx] if idx < len(legend) else None
            if label is not None and "Target" in label:
                ax1.plot(freqs, H_dB, linestyle="--", color="k")
            else:
                ax1.plot(freqs, H_dB)

        if legend:
            ax1.legend(legend)
        ax1.set_xscale("log")
        ax1.set_ylim([-80, 0])
        ax1.set_xlim([100, 11000])
        ax1.set_title(title)
        ax1.set_ylabel("Magnitude (dB)")
        ax1.set_xlabel("Frequency (Hz)")
        ax1.grid(c="lightgray", which="both")
        # Lay out before any save so the PNG file and the returned image match.
        fig.tight_layout()

        if filename is not None:
            fig.savefig(f"{filename}.png", dpi=300)

        # Round-trip through an in-memory JPEG to get a raster image tensor.
        buf = io.BytesIO()
        fig.savefig(buf, format="jpeg")
        buf.seek(0)
        with PIL.Image.open(buf) as image:
            return ToTensor()(image)
    finally:
        # Close only this figure, even on error.
        plt.close(fig)
def smooth_spectrum(H, window_length=1025, polyorder=2):
    """Smooth a spectrum with a Savitzky-Golay filter (e.g. for a smoothed target curve).

    Filtering runs along the last axis of ``H`` (``savgol_filter``'s default).

    Args:
        H: Spectrum to smooth (array-like; anything ``savgol_filter`` accepts).
        window_length: Filter window length in bins. If it exceeds the number of
            bins, it is reduced to the largest odd length that fits, because
            ``savgol_filter`` in its default ``mode="interp"`` rejects windows
            longer than the input (e.g. the 513 bins of an ``n_fft=1024`` STFT).
        polyorder: Order of the polynomial fitted within each window.

    Returns:
        The smoothed spectrum as a NumPy array with the same shape as ``H``.

    Raises:
        ValueError: From ``savgol_filter`` when the effective window length is
            not larger than ``polyorder`` (i.e. the input is too short).
    """
    n_bins = np.shape(H)[-1]
    window = window_length
    if window > n_bins:
        # Largest odd window that fits inside the input.
        window = n_bins if n_bins % 2 == 1 else n_bins - 1
    return scipy.signal.savgol_filter(H, window, polyorder)
def get_average_spectrum(x, n_fft):
    """Return the STFT magnitude of ``x`` averaged over time frames.

    Args:
        x: Time-domain signal. Anything ``torch.as_tensor`` accepts (tensor,
            NumPy array, list). Presumably a 1-D mono signal; a batched input
            gets its per-item spectra flattened into one vector by the final
            reshape.
        n_fft: FFT size. Other ``torch.stft`` settings are left at their
            defaults, so no window is passed (torch applies a rectangular one).
            The STFT is normalized (``normalized=True``).

    Returns:
        1-D tensor of mean magnitudes, one per frequency bin
        (``n_fft // 2 + 1`` bins for a 1-D real input).
    """
    X = torch.stft(torch.as_tensor(x), n_fft, return_complex=True, normalized=True)
    # Magnitude, averaged over time frames (last axis), flattened to 1-D.
    return X.abs().mean(dim=-1).reshape(-1)