Spaces:
Sleeping
Sleeping
| import os | |
| import matplotlib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import scipy | |
| from mpl_toolkits.axes_grid1.inset_locator import inset_axes, mark_inset | |
| from scipy import signal | |
| from tqdm import tqdm | |
| from data_reader import Config | |
| matplotlib.use('agg') | |
| def plot_result(epoch, num, figure_dir, preds, X, Y, mode="valid"): | |
| config = Config() | |
| for i in range(min(num, len(X))): | |
| t, noisy_signal = scipy.signal.istft( | |
| X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros' | |
| ) | |
| t, ideal_denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t, denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1.5, 1.5]) | |
| plt.subplot(4, 2, 1) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2) | |
| plt.title("Noisy signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 2) | |
| plt.plot(t, noisy_signal, label='Noisy signal', linewidth=0.1) | |
| signal_ylim = plt.gca().get_ylim() | |
| plt.gca().set_xticklabels([]) | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.subplot(4, 2, 3) | |
| plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1) | |
| plt.gca().set_xticklabels([]) | |
| plt.title("Y") | |
| plt.subplot(4, 2, 4) | |
| plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1) | |
| plt.title("$\hat{Y}$") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 5) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2) | |
| plt.title("Ideal denoised signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 6) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2) | |
| plt.title("Denoised signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 7) | |
| plt.plot(t, ideal_denoised_signal, label='Ideal denoised signal', linewidth=0.1) | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)") | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.subplot(4, 2, 8) | |
| plt.plot(t, denoised_signal, label='Denoised signal', linewidth=0.1) | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)") | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.tight_layout() | |
| plt.gcf().align_labels() | |
| plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, "epoch%03d_%03d.pdf" % (epoch, i)), bbox_inches='tight') | |
| plt.close(i) | |
| return 0 | |
| def plot_result_thread(i, epoch, preds, X, Y, figure_dir, mode="valid"): | |
| config = Config() | |
| t, noisy_signal = scipy.signal.istft( | |
| X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros' | |
| ) | |
| t, ideal_denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t, denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1.5, 1.5]) | |
| plt.subplot(4, 2, 1) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=2) | |
| plt.title("Noisy signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 2) | |
| plt.plot(t, noisy_signal, 'k', label='Noisy signal', linewidth=0.5) | |
| signal_ylim = plt.gca().get_ylim() | |
| plt.gca().set_xticklabels([]) | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.subplot(4, 2, 3) | |
| plt.pcolormesh(Y[i, :, :, 0], vmin=0, vmax=1) | |
| plt.gca().set_xticklabels([]) | |
| plt.title("Y") | |
| plt.subplot(4, 2, 4) | |
| plt.pcolormesh(preds[i, :, :, 0], vmin=0, vmax=1) | |
| plt.title("$\hat{Y}$") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 5) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * Y[i, :, :, 0], vmin=0, vmax=2) | |
| plt.title("Ideal denoised signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 6) | |
| plt.pcolormesh(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=2) | |
| plt.title("Denoised signal") | |
| plt.gca().set_xticklabels([]) | |
| plt.subplot(4, 2, 7) | |
| plt.plot(t, ideal_denoised_signal, 'k', label='Ideal denoised signal', linewidth=0.5) | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)") | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.subplot(4, 2, 8) | |
| plt.plot(t, denoised_signal, 'k', label='Denoised signal', linewidth=0.5) | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)") | |
| plt.legend(loc='lower left') | |
| plt.margins(x=0) | |
| plt.tight_layout() | |
| plt.gcf().align_labels() | |
| plt.savefig(os.path.join(figure_dir, "epoch{:03d}_{:03d}_{:}.png".format(epoch, i, mode)), bbox_inches='tight') | |
| plt.close(i) | |
| return 0 | |
| def postprocessing_test( | |
| i, preds, X, fname, figure_dir=None, result_dir=None, signal_FT=None, noise_FT=None, data_dir=None | |
| ): | |
| if (figure_dir is not None) or (result_dir is not None): | |
| config = Config() | |
| t1, noisy_signal = scipy.signal.istft( | |
| X[i, :, :, 0] + X[i, :, :, 1] * 1j, fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros' | |
| ) | |
| t1, denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, denoised_noise = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * (1 - preds[i, :, :, 0]), | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, signal = scipy.signal.istft( | |
| signal_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros' | |
| ) | |
| t1, noise = scipy.signal.istft( | |
| noise_FT[i, :, :], fs=config.fs, nperseg=config.nperseg, nfft=config.nfft, boundary='zeros' | |
| ) | |
| if result_dir is not None: | |
| try: | |
| np.savez( | |
| os.path.join(result_dir, fname[i].decode()), | |
| preds=preds[i], | |
| X=X[i], | |
| signal_FT=signal_FT[i], | |
| noise_FT=noise_FT[i], | |
| noisy_signal=noisy_signal, | |
| denoised_signal=denoised_signal, | |
| denoised_noise=denoised_noise, | |
| signal=signal, | |
| noise=noise, | |
| ) | |
| except FileNotFoundError: | |
| os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i].decode())), exist_ok=True) | |
| np.savez( | |
| os.path.join(result_dir, fname[i].decode()), | |
| preds=preds[i], | |
| X=X[i], | |
| signal_FT=signal_FT[i], | |
| noise_FT=noise_FT[i], | |
| noisy_signal=noisy_signal, | |
| denoised_signal=denoised_signal, | |
| denoised_noise=denoised_noise, | |
| signal=signal, | |
| noise=noise, | |
| ) | |
| if figure_dir is not None: | |
| t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2]) | |
| f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1]) | |
| raw_data = None | |
| if data_dir is not None: | |
| raw_data = np.load(os.path.join(data_dir, fname[i].decode().split('/')[-1])) | |
| itp = raw_data['itp'] | |
| its = raw_data['its'] | |
| ix1 = (750 - 50) / 100 | |
| ix2 = (750 + (its - itp) + 50) / 100 | |
| if ix2 - ix1 > 3: | |
| ix2 = ix1 + 3 | |
| box = dict(boxstyle='round', facecolor='white', alpha=1) | |
| text_loc = [0.05, 0.8] | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 2]) | |
| plt.subplot(511) | |
| plt.pcolormesh(t_FT, f_FT, np.abs(signal_FT[i, :, :]), vmin=0, vmax=1) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(512) | |
| plt.pcolormesh(t_FT, f_FT, np.abs(noise_FT[i, :, :]), vmin=0, vmax=1) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(513) | |
| plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), vmin=0, vmax=1) | |
| plt.ylabel("Frequency (Hz)", fontsize='large') | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(514) | |
| plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], vmin=0, vmax=1) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iv)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(515) | |
| plt.pcolormesh(t_FT, f_FT, np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], vmin=0, vmax=1) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(v)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| try: | |
| plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| except FileNotFoundError: | |
| os.makedirs( | |
| os.path.dirname(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png')), exist_ok=True | |
| ) | |
| plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| plt.close(i) | |
| text_loc = [0.05, 0.8] | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 2]) | |
| ax3 = plt.subplot(513) | |
| plt.plot(t1, noisy_signal, 'k', linewidth=0.5, label='Noisy signal') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim([-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))]) | |
| signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))] | |
| plt.ylim(signal_ylim) | |
| plt.ylabel("Amplitude", fontsize='large') | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| ax1 = plt.subplot(511) | |
| plt.plot(t1, signal, 'k', linewidth=0.5, label='Signal') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(512) | |
| plt.plot(t1, noise, 'k', linewidth=0.5, label='Noise') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim([-np.max(np.abs(noise)), np.max(np.abs(noise))]) | |
| noise_ylim = [-np.max(np.abs(noise[100:-100])), np.max(np.abs(noise[100:-100]))] | |
| plt.ylim(noise_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| ax4 = plt.subplot(514) | |
| plt.plot(t1, denoised_signal, 'k', linewidth=0.5, label='Recovered signal') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iv)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(515) | |
| plt.plot(t1, denoised_noise, 'k', linewidth=0.5, label='Recovered noise') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.ylim(noise_ylim) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(v)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| if data_dir is not None: | |
| axins = inset_axes( | |
| ax1, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax1.transAxes | |
| ) | |
| axins.plot(t1, signal, 'k', linewidth=0.5) | |
| x1, x2 = ix1, ix2 | |
| y1 = -np.max(np.abs(signal[(t1 > ix1) & (t1 < ix2)])) | |
| y2 = -y1 | |
| axins.set_xlim(x1, x2) | |
| axins.set_ylim(y1, y2) | |
| plt.xticks(visible=False) | |
| plt.yticks(visible=False) | |
| mark_inset(ax1, axins, loc1=1, loc2=3, fc="none", ec="0.5") | |
| axins = inset_axes( | |
| ax3, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.3), bbox_transform=ax3.transAxes | |
| ) | |
| axins.plot(t1, noisy_signal, 'k', linewidth=0.5) | |
| x1, x2 = ix1, ix2 | |
| axins.set_xlim(x1, x2) | |
| axins.set_ylim(y1, y2) | |
| plt.xticks(visible=False) | |
| plt.yticks(visible=False) | |
| mark_inset(ax3, axins, loc1=1, loc2=3, fc="none", ec="0.5") | |
| axins = inset_axes( | |
| ax4, width=2.0, height=1.0, loc='upper right', bbox_to_anchor=(1, 0.5), bbox_transform=ax4.transAxes | |
| ) | |
| axins.plot(t1, denoised_signal, 'k', linewidth=0.5) | |
| x1, x2 = ix1, ix2 | |
| axins.set_xlim(x1, x2) | |
| axins.set_ylim(y1, y2) | |
| plt.xticks(visible=False) | |
| plt.yticks(visible=False) | |
| mark_inset(ax4, axins, loc1=1, loc2=3, fc="none", ec="0.5") | |
| plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz') + '_wave.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].decode().rstrip('.npz')+'_wave.pdf'), bbox_inches='tight') | |
| plt.close(i) | |
| return | |
| def postprocessing_pred(i, preds, X, fname, figure_dir=None, result_dir=None): | |
| if (result_dir is not None) or (figure_dir is not None): | |
| config = Config() | |
| t1, noisy_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j), | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, denoised_signal = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, denoised_noise = scipy.signal.istft( | |
| (X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| if result_dir is not None: | |
| try: | |
| np.savez( | |
| os.path.join(result_dir, fname[i]), | |
| noisy_signal=noisy_signal, | |
| denoised_signal=denoised_signal, | |
| denoised_noise=denoised_noise, | |
| t=t1, | |
| ) | |
| except FileNotFoundError: | |
| os.makedirs(os.path.dirname(os.path.join(result_dir, fname[i]))) | |
| np.savez( | |
| os.path.join(result_dir, fname[i]), | |
| noisy_signal=noisy_signal, | |
| denoised_signal=denoised_signal, | |
| denoised_noise=denoised_noise, | |
| t=t1, | |
| ) | |
| if figure_dir is not None: | |
| t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[2]) | |
| f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[1]) | |
| box = dict(boxstyle='round', facecolor='white', alpha=1) | |
| text_loc = [0.05, 0.77] | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 1.2]) | |
| vmax = np.std(np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j)) * 1.8 | |
| plt.subplot(311) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j), | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Noisy signal', | |
| ) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(312) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 0], | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Recovered signal', | |
| ) | |
| plt.gca().set_xticklabels([]) | |
| plt.ylabel("Frequency (Hz)", fontsize='large') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(313) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[i, :, :, 0] + X[i, :, :, 1] * 1j) * preds[i, :, :, 1], | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Recovered noise', | |
| ) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| try: | |
| plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| except FileNotFoundError: | |
| os.makedirs(os.path.dirname(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png')), exist_ok=True) | |
| plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| plt.close(i) | |
| plt.figure(i) | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 1.2]) | |
| ax4 = plt.subplot(311) | |
| plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| signal_ylim = [-np.max(np.abs(noisy_signal[100:-100])), np.max(np.abs(noisy_signal[100:-100]))] | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| ax5 = plt.subplot(312) | |
| plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.ylabel("Amplitude", fontsize='large') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(313) | |
| plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz') + '_wave.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz')+'_wave.pdf'), bbox_inches='tight') | |
| plt.close(i) | |
| return | |
| def save_results(mask, X, fname, t0, save_signal=True, save_noise=True, result_dir="results"): | |
| config = Config() | |
| if save_signal: | |
| _, denoised_signal = scipy.signal.istft( | |
| (X[..., 0] + X[..., 1] * 1j) * mask[..., 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) # nbt, nch, nst, nt | |
| denoised_signal = np.transpose(denoised_signal, [0, 3, 2, 1]) # nbt, nt, nst, nch, | |
| if save_noise: | |
| _, denoised_noise = scipy.signal.istft( | |
| (X[..., 0] + X[..., 1] * 1j) * mask[..., 1], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| denoised_noise = np.transpose(denoised_noise, [0, 3, 2, 1]) | |
| if not os.path.exists(result_dir): | |
| os.makedirs(result_dir) | |
| for i in range(len(X)): | |
| np.savez( | |
| os.path.join(result_dir, fname[i]), | |
| data=denoised_signal[i] if save_signal else None, | |
| noise=denoised_noise[i] if save_noise else None, | |
| t0=t0[i], | |
| ) | |
| def plot_figures(mask, X, fname, figure_dir="figures"): | |
| config = Config() | |
| # plot the last channel | |
| mask = mask[-1, -1, ...] # nch, nst, nf, nt, 2 => nf, nt, 2 | |
| X = X[-1, -1, ...] | |
| t1, noisy_signal = scipy.signal.istft( | |
| (X[..., 0] + X[..., 1] * 1j), | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, denoised_signal = scipy.signal.istft( | |
| (X[..., 0] + X[..., 1] * 1j) * mask[..., 0], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| t1, denoised_noise = scipy.signal.istft( | |
| (X[..., 0] + X[..., 1] * 1j) * mask[..., 1], | |
| fs=config.fs, | |
| nperseg=config.nperseg, | |
| nfft=config.nfft, | |
| boundary='zeros', | |
| ) | |
| if not os.path.exists(figure_dir): | |
| os.makedirs(figure_dir) | |
| t_FT = np.linspace(config.time_range[0], config.time_range[1], X.shape[1]) | |
| f_FT = np.linspace(config.freq_range[0], config.freq_range[1], X.shape[0]) | |
| box = dict(boxstyle='round', facecolor='white', alpha=1) | |
| text_loc = [0.05, 0.77] | |
| plt.figure() | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 1.2]) | |
| vmax = np.std(np.abs(X[:, :, 0] + X[:, :, 1] * 1j)) * 1.8 | |
| plt.subplot(311) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[:, :, 0] + X[:, :, 1] * 1j), | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Noisy signal', | |
| ) | |
| plt.gca().set_xticklabels([]) | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(312) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 0], | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Recovered signal', | |
| ) | |
| plt.gca().set_xticklabels([]) | |
| plt.ylabel("Frequency (Hz)", fontsize='large') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(313) | |
| plt.pcolormesh( | |
| t_FT, | |
| f_FT, | |
| np.abs(X[:, :, 0] + X[:, :, 1] * 1j) * mask[:, :, 1], | |
| vmin=0, | |
| vmax=vmax, | |
| shading='auto', | |
| label='Recovered noise', | |
| ) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| try: | |
| plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| except FileNotFoundError: | |
| os.makedirs(os.path.dirname(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png')), exist_ok=True) | |
| plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_FT.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].split('/')[-1].rstrip('.npz')+'_FT.pdf'), bbox_inches='tight') | |
| plt.close() | |
| plt.figure() | |
| fig_size = plt.gcf().get_size_inches() | |
| plt.gcf().set_size_inches(fig_size * [1, 1.2]) | |
| ax4 = plt.subplot(311) | |
| plt.plot(t1, noisy_signal, 'k', label='Noisy signal', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| signal_ylim = [-np.max(np.abs(noisy_signal)), np.max(np.abs(noisy_signal))] | |
| if signal_ylim[0] != signal_ylim[1]: | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(i)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| ax5 = plt.subplot(312) | |
| plt.plot(t1, denoised_signal, 'k', label='Recovered signal', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| if signal_ylim[0] != signal_ylim[1]: | |
| plt.ylim(signal_ylim) | |
| plt.gca().set_xticklabels([]) | |
| plt.ylabel("Amplitude", fontsize='large') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(ii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.subplot(313) | |
| plt.plot(t1, denoised_noise, 'k', label='Recovered noise', linewidth=0.5) | |
| plt.xlim([np.around(t1[0]), np.around(t1[-1])]) | |
| if signal_ylim[0] != signal_ylim[1]: | |
| plt.ylim(signal_ylim) | |
| plt.xlabel("Time (s)", fontsize='large') | |
| plt.legend(loc='lower left', fontsize='medium') | |
| plt.text( | |
| text_loc[0], | |
| text_loc[1], | |
| '(iii)', | |
| horizontalalignment='center', | |
| transform=plt.gca().transAxes, | |
| fontsize="medium", | |
| fontweight="bold", | |
| bbox=box, | |
| ) | |
| plt.savefig(os.path.join(figure_dir, fname.rstrip('.npz') + '_wave.png'), bbox_inches='tight') | |
| # plt.savefig(os.path.join(figure_dir, fname[i].rstrip('.npz')+'_wave.pdf'), bbox_inches='tight') | |
| plt.close() | |
| return | |
| if __name__ == "__main__": | |
| pass | |