Spaces:
Paused
Paused
| import numpy as np | |
| from matplotlib import pyplot as plt | |
| from scipy.fft import fft | |
| from scipy.signal import savgol_filter | |
| from tools import rms_normalize | |
# Line colours used by the plotting functions below, as 0-255 RGB tuples.
# The values are taken from the Okabe-Ito colour-blind-friendly palette; the
# black, sky-blue, yellow and reddish-purple entries of that palette are
# omitted. Callers index it cyclically (``colors[i % len(colors)]``), so the
# order below is the order in which curves are coloured.
colors = [
    (213, 94, 0),   # vermilion
    (0, 114, 178),  # blue
    (230, 159, 0),  # orange
    (0, 158, 115),  # bluish green
]
def plot_psd_multiple_signals(signals_list, labels_list, sample_rate=16000, window_size=500,
                              figsize=(10, 6), save_path=None, normalize=False):
    """
    Plot the averaged power spectra of several groups of audio signals on one figure.

    Every individual signal is first passed through ``rms_normalize``. For each group,
    the squared FFT magnitude is averaged over the group's signals. The
    positive-frequency half is mapped to ``log2(power + 1)`` and smoothed with a
    cubic (polyorder 3) Savitzky-Golay filter.

    Parameters:
        signals_list: list of groups; each group is a numpy array of shape
            [sample_number, sample_length] (the FFT is taken along axis 1).
        labels_list: one legend label per group; must match ``signals_list`` in length.
        sample_rate: sampling rate of the audio in Hz, used to build the frequency axis.
        window_size: Savitzky-Golay window length. If it is longer than the number
            of spectrum bins, it is shrunk to the largest odd length that fits. If the
            resulting window is 3 or less, the curve is drawn unsmoothed.
        figsize: matplotlib figure size.
        save_path: if truthy, the figure is saved to this path and then closed;
            otherwise it is displayed with ``plt.show()``.
        normalize: if True, each smoothed curve is divided by its mean (when that
            mean is non-zero), so curves can be compared by shape rather than level.
    """
    # Every group of signals needs exactly one label.
    assert len(signals_list) == len(labels_list), "每组信号必须有一个对应的标签。"
    signals_list = [np.array([rms_normalize(signal) for signal in signals]) for signals in signals_list]

    plt.figure(figsize=figsize)
    for i, (signal, label) in enumerate(zip(signals_list, labels_list)):
        n_fft = signal.shape[1]
        half = n_fft // 2

        # Mean power spectrum over all signals in the group.
        psd_signal = np.mean(np.abs(fft(signal, axis=1)) ** 2, axis=0)
        freqs = np.fft.fftfreq(n_fft, 1 / sample_rate)

        # Keep only the positive-frequency half. log2(x + 1) keeps zero-power bins finite.
        psd_log = np.log2(psd_signal[:half] + 1)

        # savgol_filter raises ValueError when the window is longer than the data,
        # so shrink it (keeping it odd) for short spectra. polyorder 3 also needs
        # a window of more than 3 points.
        win = window_size
        if win > psd_log.size:
            win = psd_log.size if psd_log.size % 2 else psd_log.size - 1
        psd_smoothed = savgol_filter(psd_log, win, 3) if win > 3 else psd_log.copy()

        if normalize and psd_smoothed.size:
            curve_mean = np.mean(psd_smoothed)
            if curve_mean != 0:  # avoid inf/nan for an all-zero spectrum
                psd_smoothed = psd_smoothed / curve_mean

        plt.plot(freqs[:half], psd_smoothed, label=label,
                 color=[c / 255.0 for c in colors[i % len(colors)]], linewidth=1)

    plt.xlabel('Frequency (Hz)')
    # NOTE(review): the y values are smoothed log2(power + 1); the label text is
    # kept unchanged for compatibility with existing figures.
    plt.ylabel('Mean Log-Amplitude')
    plt.legend()

    if save_path:
        plt.savefig(save_path)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls that save to disk accumulate open figures.
        plt.close()
    else:
        plt.show()
def plot_amplitude_over_time(signals_list, labels_list, sample_rate=16000, window_size=500,
                             figsize=(10, 6), save_path=None, normalize=False, start_time=0):
    """
    Plot the mean absolute amplitude of several groups of audio signals over time
    on one figure.

    Every individual signal is passed through ``rms_normalize`` and then truncated
    to start at ``start_time``. For each group, ``|x|`` is averaged across the
    group's signals, mapped to ``log2(mean + 1)`` and smoothed with a cubic
    (polyorder 3) Savitzky-Golay filter.

    Parameters:
        signals_list: List of sets of audio signals; each set is a numpy array with
            shape [sample_number, sample_length]. Sets may differ in length: each
            set gets its own time axis.
        labels_list: List of labels corresponding to each set of audio signals.
        sample_rate: Sampling rate of the audio in Hz.
        window_size: Savitzky-Golay window length. If it is longer than a curve, it is
            shrunk to the largest odd length that fits. If the resulting window is 3
            or less, the curve is drawn unsmoothed.
        figsize: Figure size.
        save_path: If truthy, the figure is saved to this path and then closed;
            otherwise it is displayed with ``plt.show()``.
        normalize: Whether to divide each smoothed curve by its mean (when that mean
            is non-zero), so that every curve has mean 1.
        start_time: Time (in seconds) to start plotting; only samples from this
            time onwards are kept.
    """
    assert len(signals_list) == len(labels_list), f"len(signals_list) != len(labels_list) for " \
        f"len(signals_list) = {len(signals_list)} and len(labels_list) = {len(labels_list)}"

    # First sample index that is kept after truncation.
    start_sample = int(start_time * sample_rate)
    # Normalization uses the full signal; truncation happens afterwards.
    signals_list = [np.array([rms_normalize(signal)[start_sample:] for signal in signals]) for signals in signals_list]

    plt.figure(figsize=figsize)
    for i, (signal, label) in enumerate(zip(signals_list, labels_list)):
        # Per-set time axis, so sets of different lengths can share one plot.
        time_axis = np.arange(start_sample, start_sample + signal.shape[1]) / sample_rate

        amplitude_log = np.log2(np.mean(np.abs(signal), axis=0) + 1)

        # savgol_filter raises ValueError when the window is longer than the data,
        # so shrink it (keeping it odd) for short curves. polyorder 3 also needs
        # a window of more than 3 points.
        win = window_size
        if win > amplitude_log.size:
            win = amplitude_log.size if amplitude_log.size % 2 else amplitude_log.size - 1
        amplitude_smoothed = savgol_filter(amplitude_log, win, 3) if win > 3 else amplitude_log.copy()

        if normalize and amplitude_smoothed.size:
            curve_mean = np.mean(amplitude_smoothed)
            if curve_mean != 0:  # avoid inf/nan for an all-silent curve
                amplitude_smoothed = amplitude_smoothed / curve_mean

        plt.plot(time_axis, amplitude_smoothed, label=label,
                 color=[c / 255.0 for c in colors[i % len(colors)]], linewidth=1)

    plt.xlabel('Time (seconds)')
    plt.ylabel('Mean Log-Amplitude')
    plt.legend()

    # Save or show the figure based on the save_path parameter.
    if save_path:
        plt.savefig(save_path)
        # pyplot keeps every figure alive until it is closed; without this,
        # repeated calls that save to disk accumulate open figures.
        plt.close()
    else:
        plt.show()