File size: 3,245 Bytes
27f9443 56603ec 27f9443 56603ec 27f9443 56603ec 78b4171 27f9443 56603ec | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | import numpy as np
from scipy.signal import butter, filtfilt
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
def bandpass_filter(data, lowcut, highcut, fs, order=2):
    """Apply a zero-phase Butterworth band-pass filter to ``data``.

    The filter is designed in second-order-sections (SOS) form and applied
    forward and backward, so the output has no phase shift relative to the
    input. SOS is used instead of the ``(b, a)`` transfer-function form
    because SciPy documents the latter as numerically sensitive, which matters
    for very low normalized cutoffs such as a 0.5 Hz edge.

    Args:
        data: Array-like signal; filtered along its last axis.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Butterworth prototype order. For a band-pass design the
            resulting filter order is ``2 * order``.

    Returns:
        ndarray with the same shape as ``data``.

    Raises:
        ValueError: If the edges do not satisfy
            ``0 < lowcut < highcut < fs / 2``.
    """
    from scipy.signal import sosfiltfilt

    nyq = 0.5 * fs
    # Validate up front so a mis-specified band gets a readable message
    # instead of an opaque error from inside the filter design.
    if not 0 < lowcut < highcut < nyq:
        raise ValueError(
            f"Band edges must satisfy 0 < lowcut < highcut < fs/2 "
            f"(= {nyq} Hz); got lowcut={lowcut}, highcut={highcut}"
        )
    sos = butter(order, [lowcut / nyq, highcut / nyq], btype='band', output='sos')
    return sosfiltfilt(sos, data)
def plot_reconstructions(originals_list, reconstructions_list, fs, bands,
                         labels=("NeuroRVQ",), save_dir="./figures"):
    """Save one comparison figure per sample: raw signal plus per-band views.

    For each row ``i`` of ``originals_list[0]`` a PNG is written to
    ``<save_dir>/sample_<i>.png``. It contains:

    * a top panel with the raw original and its reconstruction(s);
    * one panel per entry of ``bands`` with both signals band-pass filtered
      via ``bandpass_filter``.

    Args:
        originals_list: Sequence whose first element is a 2-D array of shape
            ``(N, T)`` (N samples of T time points). Only that first element
            is plotted.
        reconstructions_list: Sequence of reconstruction arrays, each indexed
            per sample like the originals (assumed ``(N, T)`` as well).
            Reconstruction ``k`` is drawn with legend name ``labels[k]``.
            Reconstructions without a matching label are not drawn.
        fs: Sampling rate in Hz.
        bands: Mapping ``band_name -> (low_hz, high_hz)``; one panel each.
        labels: Legend names for the reconstructions.
        save_dir: Output directory; created if it does not exist.
    """
    os.makedirs(save_dir, exist_ok=True)

    originals = originals_list[0]
    # zip stops at the shorter input. With the default single label only the
    # first reconstruction is drawn, matching the single-model use case.
    models = list(zip(reconstructions_list, labels))

    N, T = originals.shape
    # Sample k is taken at k / fs. linspace(0, T / fs, T) would place the last
    # sample at T / fs and stretch the axis by T / (T - 1).
    t_axis = np.arange(T) / fs
    # One row for the raw signal plus one per band (previously fixed at 6,
    # which crashed for more than 5 bands).
    n_rows = 1 + len(bands)

    for i in tqdm(range(N), desc="Samples"):
        # 2 inches per row reproduces the original 10x12 figure for 5 bands.
        fig = plt.figure(figsize=(10, 2 * n_rows))
        try:
            orig = originals[i]

            plt.subplot(n_rows, 1, 1)
            plt.plot(t_axis, orig, label="Original Signal", alpha=0.7)
            for recon_set, label in models:
                plt.plot(t_axis, recon_set[i], linestyle='--',
                         label=f"{label} Reconstruction", alpha=0.7)
            plt.title("Raw Signal")
            plt.legend()
            plt.ylabel("Amplitude")

            for row, (band_name, (low, high)) in enumerate(bands.items(), start=2):
                plt.subplot(n_rows, 1, row)
                plt.plot(t_axis, bandpass_filter(orig, low, high, fs),
                         label="Original Signal", alpha=0.7)
                for recon_set, label in models:
                    plt.plot(t_axis, bandpass_filter(recon_set[i], low, high, fs),
                             linestyle='--', label=f"{label} Reconstruction",
                             alpha=0.7)
                plt.title(f"{band_name} Band")
                plt.ylabel("Amplitude")

            # Only the bottom panel carries the shared time label.
            plt.xlabel("Time (s)")
            plt.tight_layout()
            fig.savefig(os.path.join(save_dir, f"sample_{i}.png"))
        finally:
            # Always release the figure, even if filtering or saving fails,
            # so long runs do not accumulate open figures.
            plt.close(fig)
def process_and_plot(originals, reconstructions, fs, mode):
    """Convert model outputs to NumPy and save per-band comparison figures.

    Args:
        originals: Sequence of ground-truth signals. Each item is converted
            with ``.detach().cpu().numpy()`` (presumably torch tensors, to be
            confirmed against callers). It is then reshaped to the ``(P, T)``
            shape of ``reconstructions[0]``.
        reconstructions: Sequence of reconstructed signals, converted the same
            way. ``reconstructions[0].shape`` must unpack to ``(P, T)``.
        fs: Sampling rate in Hz, forwarded to the band-pass filtering.
        mode: ``"EEG"``, ``"EMG"`` or ``"ECG"``; selects the frequency bands.

    Raises:
        ValueError: If ``mode`` is not supported, or if the selected bands
            reach or exceed the Nyquist frequency ``fs / 2``.
    """
    bands_by_mode = {
        "EEG": {
            "Delta (0.5-4 Hz)": (0.5, 4),
            "Theta (4-8 Hz)": (4, 8),
            "Alpha (8-13 Hz)": (8, 13),
            "Beta (13-30 Hz)": (13, 30),
            "Gamma (30-45 Hz)": (30, 45),
        },
        "EMG": {
            "Band 1 (20-60 Hz)": (20, 60),
            "Band 2 (60-125 Hz)": (60, 125),
            "Band 3 (125-200 Hz)": (125, 200),
            "Band 4 (200-250 Hz)": (200, 250),
            "Band 5 (250-400 Hz)": (250, 400),
        },
        "ECG": {
            "Band 1 (0.5-10 Hz)": (0.5, 10),
            "Band 2 (10-20 Hz)": (10, 20),
            "Band 3 (20-30 Hz)": (20, 30),
            "Band 4 (30-40 Hz)": (30, 40),
            "Band 5 (40-50 Hz)": (40, 50),
        },
    }
    # Previously an unknown mode left `bands` unbound and failed with an
    # UnboundLocalError at the plotting call.
    if mode not in bands_by_mode:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of {sorted(bands_by_mode)}"
        )
    bands = bands_by_mode[mode]

    # A Butterworth band-pass needs every edge strictly below fs / 2. Check
    # here so the error surfaces before any conversion or plotting work.
    top_edge = max(high for _, high in bands.values())
    if top_edge >= fs / 2:
        raise ValueError(
            f"{mode} bands extend to {top_edge} Hz, which is not below the "
            f"Nyquist frequency fs/2 = {fs / 2} Hz"
        )

    P, T = reconstructions[0].shape

    def _to_numpy(signal):
        # Detach from any autograd graph, move to host memory, and flatten
        # to the common (P, T) layout.
        return signal.detach().cpu().numpy().reshape(P, T)

    originals_np = [_to_numpy(original) for original in originals]
    reconstructions_np = [_to_numpy(reconstruction) for reconstruction in reconstructions]

    plot_reconstructions(originals_np, reconstructions_np, fs, bands)
|