import io
import math
from datetime import datetime
from typing import Optional, Union

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import IPython.display as ipd
from IPython.core.display import display
from PIL import Image
from torchvision.transforms import PILToTensor, ToTensor

#matplotlib.use('Agg')  # switch backend to run on server

################################################################################
# Plotting utilities for logging and figures
################################################################################
def tensor_to_np(x: torch.Tensor):
    """Detach a tensor and return it as a NumPy array on the CPU"""
    return x.clone().detach().cpu().numpy()

def play_audio(x: torch.Tensor, sample_rate: int = 16000):
    """Render an inline audio player for a waveform tensor (notebook use)"""
    display(ipd.Audio(tensor_to_np(x).flatten(), rate=sample_rate))

def plot_filter(amplitudes: torch.Tensor):
    """
    Given a single set of time-varying filter controls, return plot as image
    """
    amplitudes = amplitudes.clone().detach()
    if amplitudes.ndim == 2:
        magnitudes = amplitudes.cpu().numpy().T
    elif amplitudes.ndim == 3:
        magnitudes = amplitudes[0].cpu().numpy().T
    else:
        raise ValueError("Can only plot single filter response")

    # plot filter controls over time as heatmap
    fig, ax = plt.subplots(figsize=(8, 8))
    im = ax.imshow(magnitudes, aspect='auto')
    fig.colorbar(im, ax=ax)
    ax.invert_yaxis()
    ax.set_title('filter amplitudes')
    ax.set_xlabel('frames')
    ax.set_ylabel('frequency bin')
    plt.tight_layout()

    # save plot to buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)

    # return plot as image
    return ToTensor()(np.array(img))
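
# Example (assumed usage, not taken from this module): the CHW float image
# tensor returned above is directly compatible with image loggers such as
# TensorBoard's SummaryWriter. The controls shape (frames, bands) below is a
# hypothetical illustration.
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter()
#   controls = torch.rand(250, 64)  # hypothetical (frames, bands) filter controls
#   writer.add_image("filters/controls", plot_filter(controls), global_step=0)
#   writer.close()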

def plot_waveform(x: torch.Tensor, scale: Union[int, float] = 1.0):
    """
    Given a single audio waveform, return plot as image
    """
    if not (len(x.shape) == 1 or x.shape[0] == 1):
        raise ValueError('Audio input must be single waveform')

    # waveform plot
    fig, ax = plt.subplots(figsize=(8, 8))
    fig.subplots_adjust(bottom=0.2)
    plt.xticks(
        #rotation=90
    )
    ax.plot(tensor_to_np(x).flatten(), color='k')
    ax.set_xlabel("Sample Index")
    ax.set_ylabel("Waveform Amplitude")
    plt.axis((None, None, -scale, scale))  # set y-axis range

    # save plot to buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)

    # return plot as image
    return ToTensor()(np.array(img))

def plot_filter_codebook(x: torch.Tensor, use: Optional[torch.Tensor] = None):
    """
    Plot a codebook of learned frequency-domain filter controls.
    """
    # scale use rates to [0, 1] for background coloring but not text display
    if use is not None:
        use = use.clone().detach()
        use_normalized = use.clone()
        use_normalized -= use_normalized.min(0, keepdim=True)[0]
        use_normalized /= use_normalized.max(0, keepdim=True)[0]

    n_filters, n_bands = x.shape[0], x.shape[-1]

    # create a square grid layout, which may be partially filled
    grid_size = math.ceil(math.sqrt(n_filters))
    fig, axs = plt.subplots(ncols=grid_size, nrows=grid_size, figsize=(8, 8),
                            squeeze=False)  # always a 2D axes array, even for a 1x1 grid
    for i in range(n_filters):
        axis = axs[i // grid_size, i % grid_size]

        # color filter plot according to use rate of filter
        if use is not None:
            assert len(use) == n_filters  # one usage rate per filter
            axis.set_facecolor((1.0, 0.47, 0.42, use_normalized[i].item()))
            x_text = n_bands // 2
            y_text = x[i].max().item() / 2
            axis.text(x_text, y_text, f"{use[i].item():0.3f}", ha="center", va="center", zorder=10)

        axis.plot(np.zeros(n_bands), 'k', alpha=0.5)  # plot "neutral" line
        axis.plot(tensor_to_np(x[i]).flatten())
        axis.set_xlabel("Frequency")
        axis.set_ylabel("Amplitude")

    plt.tight_layout()

    # save plot to buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)

    # return plot as image
    return ToTensor()(np.array(img))

def plot_spectrogram(x: torch.Tensor):
    """
    Given a single audio waveform, return spectrogram plot as image
    """
    if not (len(x.shape) == 1 or x.shape[0] == 1):
        raise ValueError('Audio input must be single waveform')

    x = x.clone().detach()

    # compute complex spectrogram
    spec = torch.stft(x.reshape(1, -1),
                      n_fft=512,
                      win_length=512,
                      hop_length=256,
                      window=torch.hann_window(
                          window_length=512
                      ).to(x.device),
                      return_complex=True,
                      center=False)
    spec = torch.squeeze(
        torch.abs(spec) / (torch.max(torch.abs(spec)))
    )  # normalize spectrogram by maximum absolute value

    # log-magnitude spectrogram plot
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.pcolormesh(tensor_to_np(torch.log(spec + 1)), vmin=0, vmax=.31)
    plt.tight_layout()

    # save plot to buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)

    # return plot image as tensor
    return ToTensor()(np.array(img))

def plot_logits(class_scores: torch.Tensor, target: Optional[int] = None):
    """
    Given a vector of class scores, and optionally a target index, create a
    simple bar plot of the scores and return as an image
    """
    # require single vector of class scores
    if not (class_scores.ndim <= 1 or class_scores.shape[0] == 1):
        raise ValueError('Must provide single vector of class scores')

    # convert to NumPy
    scores = tensor_to_np(class_scores).flatten()
    labels = np.arange(scores.shape[-1])

    # bar plot
    fig = plt.figure(figsize=(8, 8))
    bars = plt.bar(labels, scores, color='k')

    # if target label index is given, highlight corresponding bar
    if target is not None:
        if not 0 <= target < len(scores):
            raise ValueError("Target must be valid index")
        bars[target].set_color('r')

    # save plot to buffer
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)

    # return plot image as tensor
    return ToTensor()(np.array(img))

def get_duration(st: datetime, ed: datetime):
    """Return the duration between two timestamps as a formatted string"""
    # use total_seconds() rather than .seconds, which wraps around after one day
    total_seconds = int((ed - st).total_seconds())
    hours, remainder = divmod(total_seconds, 3600)
    minutes, seconds = divmod(remainder, 60)

    duration = ""
    if hours > 0:
        duration += f"{hours}h {minutes}m "
    elif minutes > 0:
        duration += f"{minutes}m "
    duration += f"{seconds}s"
    return duration
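
# Minimal smoke-test sketch (assumed usage; the shapes below are illustrative
# only, not dictated by this module). Running the file directly renders each
# plot from random or synthetic data and prints the resulting image tensor
# shapes, plus the elapsed time.
if __name__ == "__main__":
    start = datetime.now()

    sample_rate = 16000
    t = torch.linspace(0, 1, sample_rate)
    waveform = torch.sin(2 * math.pi * 440 * t)  # 1 s, 440 Hz sine

    print(plot_waveform(waveform).shape)
    print(plot_spectrogram(waveform).shape)
    print(plot_filter(torch.rand(250, 64)).shape)  # hypothetical (frames, bands) controls
    print(plot_filter_codebook(torch.randn(16, 64), use=torch.rand(16)).shape)
    print(plot_logits(torch.randn(10), target=3).shape)

    print(get_duration(start, datetime.now()))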