import os

import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import butter, filtfilt, sosfiltfilt
from tqdm import tqdm
|
|
|
|
| |
def bandpass_filter(data, lowcut, highcut, fs, order=2, axis=-1):
    """Apply a zero-phase Butterworth band-pass filter.

    Args:
        data: Signal to filter (array-like); filtering runs along ``axis``.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth prototype order. The band-pass design has twice
            this many poles.
        axis: Axis of ``data`` to filter along (default: last axis).

    Returns:
        numpy.ndarray with the same shape as ``data``.

    Raises:
        ValueError: If the edges do not satisfy ``0 < lowcut < highcut < fs / 2``.
    """
    nyq = 0.5 * fs
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f"Band edges must satisfy 0 < lowcut < highcut < fs/2 = {nyq:g} Hz; "
            f"got lowcut={lowcut!r}, highcut={highcut!r}"
        )
    # Second-order sections are numerically stabler than (b, a) polynomials.
    # SciPy warns the polynomial form can lose precision already at order >= 4,
    # and a band-pass built from an order-2 prototype is order 4.
    sos = butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')
    # Forward-backward filtering gives zero phase shift, so filtered traces stay
    # time-aligned with the unfiltered signal.
    return sosfiltfilt(sos, data, axis=axis)
|
|
def plot_reconstructions(originals_list, reconstructions_list, fs, bands,
                         labels=("NeuroRVQ",), save_dir="./figures"):
    """Save one figure per sample comparing an original signal to its reconstruction.

    Each figure has ``len(bands) + 1`` stacked rows. The first row shows the
    raw signals. Each following row shows both signals band-pass filtered to
    one band from ``bands``. Figures are written to
    ``<save_dir>/sample_<i>.png``.

    Args:
        originals_list: Sequence of 2-D arrays shaped (N, T). Only the first
            entry is plotted.
        reconstructions_list: Sequence of arrays indexed like
            ``originals_list``. Only the first entry is plotted.
        fs: Sampling rate in Hz.
        bands: Mapping ``band_name -> (low_hz, high_hz)``. Rows follow its
            iteration order.
        labels: Model names. Only ``labels[0]`` is used, in the legend.
        save_dir: Output directory, created if missing.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Only the first entry of each list is plotted (paired with labels[0]).
    originals = originals_list[0]
    reconstructions = reconstructions_list[0]
    label = labels[0]

    n_samples, n_times = originals.shape
    # Sample k is taken at k / fs seconds. endpoint=False keeps the spacing at
    # exactly 1 / fs; linspace(0, T/fs, T) would stretch it by T / (T - 1).
    t = np.linspace(0, n_times / fs, n_times, endpoint=False)
    n_rows = len(bands) + 1  # raw-signal row + one row per band

    for i in tqdm(range(n_samples), desc="Samples"):
        orig = originals[i]
        recon = reconstructions[i]

        fig = plt.figure(figsize=(10, 2 * n_rows))
        try:
            ax = fig.add_subplot(n_rows, 1, 1)
            ax.plot(t, orig, label="Original Signal", alpha=0.7)
            ax.plot(t, recon, linestyle='--', label=f"{label} Reconstruction", alpha=0.7)
            ax.set_title("Raw Signal")
            ax.legend()
            ax.set_ylabel("Amplitude")

            for row, (band_name, (low, high)) in enumerate(bands.items(), start=2):
                ax = fig.add_subplot(n_rows, 1, row)
                ax.plot(t, bandpass_filter(orig, low, high, fs),
                        label="Original Signal", alpha=0.7)
                ax.plot(t, bandpass_filter(recon, low, high, fs), linestyle='--',
                        label=f"{label} Reconstruction", alpha=0.7)
                ax.set_title(f"{band_name} Band")
                ax.set_ylabel("Amplitude")

            # Only the bottom row gets an x label; every row uses the same time axis.
            ax.set_xlabel("Time (s)")
            fig.tight_layout()
            fig.savefig(os.path.join(save_dir, f"sample_{i}.png"))
        finally:
            # Close even if filtering or saving fails, so figures don't accumulate.
            plt.close(fig)
|
|
|
|
def process_and_plot(originals, reconstructions, fs, mode):
    """Convert signals to NumPy and plot per-band reconstruction comparisons.

    Args:
        originals: Sequence of tensor-like objects exposing
            ``.detach().cpu().numpy()`` (presumably torch tensors — confirm
            with callers). Each must hold ``P * T`` elements.
        reconstructions: Sequence like ``originals``. ``reconstructions[0]``
            must be 2-D; its shape (P, T) is used to reshape every entry.
        fs: Sampling rate in Hz, passed through to the band-pass filters.
        mode: Signal modality selecting the frequency bands: "EEG", "EMG" or
            "ECG".

    Raises:
        ValueError: If ``mode`` is not one of the supported modalities.
    """
    # Frequency bands (Hz) plotted for each supported modality.
    bands_by_mode = {
        "EEG": {
            "Delta (0.5–4 Hz)": (0.5, 4),
            "Theta (4–8 Hz)": (4, 8),
            "Alpha (8–13 Hz)": (8, 13),
            "Beta (13–30 Hz)": (13, 30),
            "Gamma (30–45 Hz)": (30, 45),
        },
        "EMG": {
            "Band 1 (20–60 Hz)": (20, 60),
            "Band 2 (60–125 Hz)": (60, 125),
            "Band 3 (125-200 Hz)": (125, 200),
            "Band 4 (200-250 Hz)": (200, 250),
            "Band 5 (250-400 Hz)": (250, 400),
        },
        "ECG": {
            "Band 1 (0.5–10 Hz)": (0.5, 10),
            "Band 2 (10–20 Hz)": (10, 20),
            "Band 3 (20-30 Hz)": (20, 30),
            "Band 4 (30-40 Hz)": (30, 40),
            "Band 5 (40-50 Hz)": (40, 50),
        },
    }
    # Fail fast with a clear error. Previously an unknown mode left `bands`
    # unbound and surfaced as an UnboundLocalError after the tensor conversion.
    try:
        bands = bands_by_mode[mode]
    except KeyError:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(bands_by_mode)}"
        ) from None

    P, T = reconstructions[0].shape

    # Move everything to host memory as NumPy arrays shaped (P, T) for plotting.
    originals_np = [
        original.detach().cpu().numpy().reshape(P, T)
        for original in originals
    ]
    reconstructions_np = [
        reconstruction.detach().cpu().numpy().reshape(P, T)
        for reconstruction in reconstructions
    ]

    plot_reconstructions(originals_np, reconstructions_np, fs, bands)
| |
|
|