File size: 3,245 Bytes
27f9443
 
 
 
 
 
 
 
 
 
 
 
 
 
56603ec
27f9443
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56603ec
27f9443
 
 
 
 
 
 
 
 
 
56603ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78b4171
 
 
 
 
 
 
 
 
27f9443
 
56603ec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import numpy as np
from scipy.signal import butter, filtfilt
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


# Bandpass filter function
def bandpass_filter(data, lowcut, highcut, fs, order=2):
    """Apply a zero-phase Butterworth band-pass filter to ``data``.

    Args:
        data: Signal to filter; passed unchanged to ``scipy.signal.filtfilt``.
        lowcut: Lower pass-band edge, in the same units as ``fs``.
        highcut: Upper pass-band edge, in the same units as ``fs``.
        fs: Sampling frequency of ``data``.
        order: Order handed to ``scipy.signal.butter`` (default 2).

    Returns:
        The filtered signal as returned by ``filtfilt``.
    """
    nyquist = 0.5 * fs
    # butter() is called without ``fs``, so the edges are expressed as
    # fractions of the Nyquist frequency.
    normalized_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, normalized_edges, btype='band')
    # filtfilt runs the filter forwards and backwards over the data.
    return filtfilt(numerator, denominator, data)

def plot_reconstructions(originals_list, reconstructions_list, fs, bands,
                         labels=("NeuroRVQ",), save_dir="./figures"):
    """Save one comparison figure per sample: original vs. reconstruction.

    Each figure has a top panel with the raw original and reconstructed
    traces, followed by one panel per entry of ``bands`` showing both traces
    after band-pass filtering. Figures are written to
    ``<save_dir>/sample_<i>.png``.

    Only the first entry of ``originals_list``, ``reconstructions_list`` and
    ``labels`` is used; any further entries are ignored.

    Args:
        originals_list: Sequence whose first item is a 2-D array of shape
            ``(N, T)`` (N samples, T time points each).
        reconstructions_list: Sequence whose first item is indexed the same
            way as the originals (row ``i`` is plotted against original
            row ``i``).
        fs: Sampling frequency. The x-axis is labelled in seconds, so this
            is presumably in Hz.
        bands: Mapping of panel title -> ``(low, high)`` cut-off pair passed
            to ``bandpass_filter``. One panel is drawn per entry, in
            iteration order.
        labels: Sequence whose first item names the reconstruction in the
            legend.
        save_dir: Output directory; created if it does not exist.
    """
    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(save_dir, exist_ok=True)

    originals = originals_list[0]
    reconstructions = reconstructions_list[0]
    label = labels[0]

    n_samples, n_points = originals.shape
    # Sample k is taken at k / fs. np.linspace(0, T / fs, T) would include
    # the endpoint and stretch the spacing to T / (fs * (T - 1)).
    time = np.arange(n_points) / fs

    # One row for the raw signal plus one per band; 2 inches per row keeps
    # the historical 10x12 figure for the usual five bands.
    n_rows = len(bands) + 1

    for i in tqdm(range(n_samples), desc="Samples"):
        orig = originals[i]
        recon = reconstructions[i]

        fig = plt.figure(figsize=(10, 2 * n_rows))
        try:
            # Plot raw signals
            plt.subplot(n_rows, 1, 1)
            plt.plot(time, orig, label="Original Signal", alpha=0.7)
            plt.plot(time, recon, linestyle='--',
                     label=f"{label} Reconstruction", alpha=0.7)
            plt.title("Raw Signal")
            plt.legend()
            plt.ylabel("Amplitude")

            # Plot filtered bands. No legend is drawn on these panels; the
            # colours and line styles match the legend of the top panel.
            for j, (band_name, (low, high)) in enumerate(bands.items()):
                plt.subplot(n_rows, 1, j + 2)

                orig_band = bandpass_filter(orig, low, high, fs)
                recon_band = bandpass_filter(recon, low, high, fs)

                plt.plot(time, orig_band, label="Original Signal", alpha=0.7)
                plt.plot(time, recon_band, linestyle='--',
                         label=f"{label} Reconstruction", alpha=0.7)

                plt.title(f"{band_name} Band")
                plt.ylabel("Amplitude")

            plt.xlabel("Time (s)")
            plt.tight_layout()

            plt.savefig(os.path.join(save_dir, f"sample_{i}.png"))
        finally:
            # Always release the figure, even if filtering or saving fails,
            # so a failure part-way through does not leave it open.
            plt.close(fig)


def process_and_plot(originals, reconstructions, fs, mode):
    """Convert tensors to NumPy arrays and plot them per frequency band.

    The frequency bands are selected by ``mode`` and the converted arrays are
    handed to ``plot_reconstructions`` with its default labels and output
    directory.

    Args:
        originals: Sequence of tensors exposing ``.detach().cpu().numpy()``
            (presumably torch tensors - confirm with the caller).
        reconstructions: Sequence of tensors of the same kind. The first one
            must be 2-D; its shape ``(P, T)`` is the shape every tensor in
            both sequences is reshaped to.
        fs: Sampling frequency, in the same units as the band edges (the
            band titles state Hz).
        mode: One of ``"EEG"``, ``"EMG"`` or ``"ECG"``.

    Raises:
        ValueError: If ``mode`` is not supported, or if ``fs`` is too low for
            the selected bands (an upper band edge at or above ``fs / 2``).
    """
    bands_by_mode = {
        # Define EEG bands
        "EEG": {
            "Delta (0.5–4 Hz)": (0.5, 4),
            "Theta (4–8 Hz)": (4, 8),
            "Alpha (8–13 Hz)": (8, 13),
            "Beta (13–30 Hz)": (13, 30),
            "Gamma (30–45 Hz)": (30, 45),
        },
        # Define EMG bands
        "EMG": {
            "Band 1 (20–60 Hz)": (20, 60),
            "Band 2 (60–125 Hz)": (60, 125),
            "Band 3 (125-200 Hz)": (125, 200),
            "Band 4 (200-250 Hz)": (200, 250),
            "Band 5 (250-400 Hz)": (250, 400),
        },
        # Define ECG bands
        "ECG": {
            "Band 1 (0.5–10 Hz)": (0.5, 10),
            "Band 2 (10–20 Hz)": (10, 20),
            "Band 3 (20-30 Hz)": (20, 30),
            "Band 4 (30-40 Hz)": (30, 40),
            "Band 5 (40-50 Hz)": (40, 50),
        },
    }

    # Validate up front: an unknown mode used to leave ``bands`` unbound and
    # fail later with an opaque UnboundLocalError.
    bands = bands_by_mode.get(mode)
    if bands is None:
        raise ValueError(
            f"Unsupported mode {mode!r}; expected one of "
            f"{sorted(bands_by_mode)}."
        )

    # A digital band-pass needs both edges strictly below the Nyquist
    # frequency; otherwise the filter design fails in the middle of plotting.
    nyquist = 0.5 * fs
    out_of_range = [name for name, (_, high) in bands.items()
                    if high >= nyquist]
    if out_of_range:
        raise ValueError(
            f"fs={fs} is too low for mode {mode!r}: the upper edge of "
            f"{out_of_range} must be below the Nyquist frequency ({nyquist})."
        )

    # The first reconstruction defines the 2-D layout used for every tensor.
    n_rows, n_points = reconstructions[0].shape

    originals_np = [
        original.detach().cpu().numpy().reshape(n_rows, n_points)
        for original in originals
    ]
    reconstructions_np = [
        reconstruction.detach().cpu().numpy().reshape(n_rows, n_points)
        for reconstruction in reconstructions
    ]

    # Plot
    plot_reconstructions(originals_np, reconstructions_np, fs, bands)