File size: 12,041 Bytes
c70a8ec
 
4885bf4
 
c70a8ec
 
 
 
 
4885bf4
a3ca400
0928626
373ac12
 
 
4885bf4
c70a8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3ca400
c70a8ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3ca400
 
 
 
 
 
 
 
 
 
 
c70a8ec
 
095893d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6aca538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ffec487
 
 
 
 
 
 
0928626
ffec487
 
 
 
 
 
 
 
0928626
ffec487
 
 
0928626
ffec487
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaea07f
 
 
a7252ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaea07f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0928626
aaea07f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
import os

# Writable scratch locations (presumably for hosts whose home directory is
# read-only). matplotlib resolves and caches its config/cache directory when it
# is first imported, so MPLCONFIGDIR must be set *before* `import
# matplotlib.pyplot` below; assigning it afterwards has no effect on matplotlib.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["STREAMLIT_HOME"] = "/tmp/streamlit"

import io  # noqa: E402
import tempfile  # noqa: E402

import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
import mne  # noqa: E402
from mne.channels import make_dig_montage  # noqa: E402
from matplotlib.animation import FuncAnimation  # noqa: E402
import requests  # noqa: E402
import streamlit as st  # noqa: E402
from scipy.signal import welch  # noqa: E402


# Set page config (browser tab title and layout).
st.set_page_config(page_title="Muse EEG Topomap Viewer", layout="centered")

def plot_eeg_topomap_muse_from_csv(csv_url, save_path_animation=None, save_directory=None,
                                   show_names=False, start_time=0.05, end_time=1, step_size=0.1,
                                   timeout=30):
    """Download a Muse CSV and render an animated 4-channel EEG topomap GIF.

    The CSV must contain the Muse channels AF7, AF8, TP9 and TP10 plus a
    ``TimeStamp`` or ``timestamps`` column. The sampling rate is estimated
    from the mean spacing of consecutive timestamps.

    Args:
        csv_url: URL of the CSV to download.
        save_path_animation: Where to write the GIF. When falsy, the data is
            still loaded and validated but nothing is rendered and ``None``
            is returned.
        save_directory: Optional directory. When given it is created if needed
            and the GIF is written there under
            ``os.path.basename(save_path_animation)``.
        show_names: Forwarded to ``Evoked.plot_topomap`` to label the sensors.
        start_time, end_time, step_size: Frame times in seconds, generated
            with ``np.arange(start_time, end_time, step_size)`` and measured
            from the first sample (the EvokedArray uses its default tmin).
        timeout: Seconds to wait for the HTTP download.

    Returns:
        The path the GIF was written to, or ``None`` when no path was given.

    Raises:
        RuntimeError: The download did not return HTTP 200.
        ValueError: Required columns are missing, the timestamps cannot give
            a positive sampling interval, or the frame time range is empty.
    """
    response = requests.get(csv_url, timeout=timeout)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download CSV from {csv_url}")
    data = pd.read_csv(io.StringIO(response.text))
    st.write(f"Loaded data with shape: {data.shape}")

    muse_channels = ['AF7', 'AF8', 'TP9', 'TP10']
    if not all(ch in data.columns for ch in muse_channels):
        raise ValueError(f"The dataset must contain the following channels: {muse_channels}")

    # Rows = channels, columns = samples, as EvokedArray expects.
    eeg_data = data[muse_channels].values.T

    # Hand-placed (x, y, z) sensor positions in the 'head' coordinate frame.
    muse_positions = {
        'AF7': [-0.05, 0.085, 0],
        'AF8': [0.05, 0.085, 0],
        'TP9': [-0.08, -0.04, 0],
        'TP10': [0.08, -0.04, 0]
    }
    # NOTE(review): tiny random x/y jitter, presumably to keep MNE from seeing
    # perfectly symmetric/degenerate positions — confirm it is still needed.
    for ch in muse_positions:
        muse_positions[ch][0] += np.random.normal(0, 0.0001)
        muse_positions[ch][1] += np.random.normal(0, 0.0001)

    montage = make_dig_montage(ch_pos=muse_positions, coord_frame='head')

    if 'TimeStamp' in data.columns:
        timestamps = data['TimeStamp']
    elif 'timestamps' in data.columns:
        timestamps = data['timestamps']
    else:
        raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")

    # Estimate the sampling rate; guard against too few rows or
    # non-increasing timestamps, which would yield a nan/inf/negative sfreq.
    if len(timestamps) < 2:
        raise ValueError("At least two samples are required to estimate the sampling rate")
    average_interval = np.mean(np.diff(timestamps))
    if not np.isfinite(average_interval) or average_interval <= 0:
        raise ValueError("Timestamps must increase to estimate a sampling rate")
    sfreq = 1 / average_interval

    info = mne.create_info(ch_names=muse_channels, sfreq=sfreq, ch_types='eeg')
    evoked = mne.EvokedArray(eeg_data, info)
    evoked.set_montage(montage)

    times = np.arange(start_time, end_time, step_size)
    if len(times) == 0:
        raise ValueError("Not enough time range selected for topomap animation.")

    if not save_path_animation:
        return None

    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path_animation = os.path.join(save_directory, os.path.basename(save_path_animation))

    fig, ax = plt.subplots()

    def update(frame):
        # Redraw a single-time topomap into the shared axes for this frame.
        ax.clear()
        evoked.plot_topomap([times[frame]], ch_type='eeg', time_unit='s',
                            axes=ax, colorbar=False, show=False,
                            show_names=show_names)

    try:
        anim = FuncAnimation(fig, update, frames=range(len(times)), interval=200)
        anim.save(save_path_animation, writer='pillow')
    finally:
        # Release the figure even if rendering/saving fails.
        plt.close(fig)

    return save_path_animation

# ---------------------------------------------------------------------------
# Streamlit UI: page header, shared CSV source, and the topomap action.
# ---------------------------------------------------------------------------
st.title("🧠 Muse EEG Topomap Animation from CSV")

# Sample recording offered as the default input.
default_url = (
    "https://raw.githubusercontent.com/garenasd945/EEG_muse2/refs/heads/master/"
    "dataset/original_data/subjecta-relaxed-1.csv"
)
# `csv_url` is the source read by every action button on the page.
csv_url = st.text_input("CSV URL", value=default_url)
generate_button = st.button("Generate Topomap Animation")

if generate_button:
    try:
        # Render into a throw-away directory; st.image is called while the
        # directory (and the GIF inside it) still exists.
        with tempfile.TemporaryDirectory() as scratch_dir:
            requested_gif = os.path.join(scratch_dir, "eeg_topomap_animation.gif")
            gif_path = plot_eeg_topomap_muse_from_csv(
                csv_url,
                show_names=True,
                save_directory=scratch_dir,
                save_path_animation=requested_gif,
            )
            if gif_path:
                st.image(gif_path, caption="EEG Topomap Animation")
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")


# ---------------------------------------------------------------------------
# Single-channel (TP9) analysis buttons. All of them share the download /
# sampling-rate / 1–50 Hz band-pass pipeline in `_load_tp9`.
# ---------------------------------------------------------------------------

# Frequency bands (Hz) for the band-power chart; both edges are inclusive.
_EEG_BANDS = {
    "Delta (1–4 Hz)": (1, 4),
    "Theta (4–8 Hz)": (4, 8),
    "Alpha (8–12 Hz)": (8, 12),
    "Beta (12–30 Hz)": (12, 30),
    "Gamma (30–50 Hz)": (30, 50)
}


def _load_tp9(url, timeout=30):
    """Download a Muse CSV and return the TP9 signal, raw and band-passed.

    Returns:
        ``(elapsed, tp9_raw, tp9_filtered, sampling_rate)``: seconds since the
        first sample, the TP9 column as loaded, TP9 band-passed 1–50 Hz with
        MNE's FIR filter, and the sampling rate estimated from the mean
        timestamp spacing.

    Raises:
        requests.HTTPError: The download returned an error status.
        ValueError: The CSV lacks a timestamp or TP9 column, or the timestamps
            cannot give a positive sampling interval.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    data = pd.read_csv(io.StringIO(response.text))
    st.write(f"Loaded data with shape: {data.shape}")

    # 'timestamps' takes precedence when both column spellings are present.
    if 'timestamps' in data.columns:
        timestamps = data['timestamps']
    elif 'TimeStamp' in data.columns:
        timestamps = data['TimeStamp']
    else:
        raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
    if 'TP9' not in data.columns:
        raise ValueError("CSV must contain a 'TP9' column")

    timestamps = timestamps.to_numpy(dtype=float)
    tp9_raw = data['TP9'].to_numpy()

    if len(timestamps) < 2:
        raise ValueError("At least two samples are required to estimate the sampling rate")
    average_interval = np.mean(np.diff(timestamps))
    if not np.isfinite(average_interval) or average_interval <= 0:
        raise ValueError("Timestamps must increase to estimate a sampling rate")
    sampling_rate = 1 / average_interval

    info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
    raw = mne.io.RawArray(np.array([tp9_raw]), info)
    # MNE rejects an upper edge at or above Nyquist, so this needs
    # sampling_rate > 100 Hz.
    raw.filter(1, 50, fir_design='firwin')
    tp9_filtered = raw.get_data()[0]

    elapsed = timestamps - timestamps[0]
    return elapsed, tp9_raw, tp9_filtered, sampling_rate


def _show_and_close(fig):
    """Render ``fig`` in Streamlit, then drop it from pyplot's figure registry.

    st.pyplot does not close a figure it is handed explicitly, so without this
    every rerun would leave its figures open.
    """
    try:
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _plot_waveform(elapsed, signal, title, **line_kwargs):
    """Plot ``signal`` against ``elapsed`` seconds in a wide figure and render it."""
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(elapsed, signal, **line_kwargs)
    ax.set_title(title)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Amplitude')
    ax.legend()
    ax.grid(True)
    _show_and_close(fig)


def _plot_spectrogram(signal, sampling_rate, title):
    """Render a log10-power spectrogram (256-sample segments) of ``signal``."""
    from scipy.signal import spectrogram

    frequencies, seg_times, power = spectrogram(signal, fs=sampling_rate, nperseg=256)
    fig, ax = plt.subplots(figsize=(10, 6))
    # +1e-10 avoids log10(0) for empty bins.
    mesh = ax.pcolormesh(seg_times, frequencies, np.log10(power + 1e-10), shading='gouraud')
    fig.colorbar(mesh, ax=ax, label='Log Power Spectral Density')
    ax.set_title(title)
    ax.set_xlabel('Time (seconds)')
    ax.set_ylabel('Frequency (Hz)')
    _show_and_close(fig)


plot_waveform_button = st.button("Plot TP9 Waveforms")

if plot_waveform_button:
    try:
        elapsed, tp9_raw, tp9_filtered, _ = _load_tp9(csv_url)
        _plot_waveform(elapsed, tp9_raw, 'Raw Waveform of TP9',
                       label='TP9', alpha=0.6, color='blue')
        _plot_waveform(elapsed, tp9_filtered, 'Filtered Waveform of TP9',
                       label='TP9 (1-50Hz)', linewidth=1.5, color='green')
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

spectrogram_button = st.button("Plot TP9 Spectrograms")

if spectrogram_button:
    try:
        _, tp9_raw, tp9_filtered, sampling_rate = _load_tp9(csv_url)
        _plot_spectrogram(tp9_raw, sampling_rate, 'Spectrogram of TP9 (Raw)')
        _plot_spectrogram(tp9_filtered, sampling_rate, 'Spectrogram of TP9 (Filtered 1–50Hz)')
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

# Button to calculate and plot band powers
if st.button("Show Band Power Chart (TP9)"):
    try:
        # scipy's trapezoid replaces np.trapz, which NumPy 2.0 deprecates.
        from scipy.integrate import trapezoid

        _, _, tp9_filtered, sampling_rate = _load_tp9(csv_url)

        # One Welch PSD (default segment length) shared by every band.
        freqs, psd = welch(tp9_filtered, sampling_rate)
        powers = []
        for low, high in _EEG_BANDS.values():
            in_band = (freqs >= low) & (freqs <= high)
            powers.append(trapezoid(psd[in_band], freqs[in_band]))

        fig, ax = plt.subplots()
        ax.bar(list(_EEG_BANDS), powers, color='skyblue')
        ax.set_title("Band Power - TP9")
        ax.set_ylabel("Power")
        ax.tick_params(axis='x', rotation=45)
        fig.tight_layout()  # keep the rotated band labels inside the image
        _show_and_close(fig)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

# Button to show EEG animation
if st.button("Show Animated EEG Trace (TP9)"):
    fig = None
    try:
        _, _, tp9_filtered, sampling_rate = _load_tp9(csv_url)

        window_sec = 5                                  # visible window width
        n_samples = len(tp9_filtered)
        window_len = int(sampling_rate * window_sec)
        step_len = max(1, int(sampling_rate * 0.1))     # advance ~0.1 s per frame
        # At most 100 frames, and never frames that start past the recording.
        n_frames = max(1, min(100, -(-n_samples // step_len)))

        fig, ax = plt.subplots()
        line, = ax.plot([], [], lw=2)
        ax.set_xlim(0, window_sec)
        ax.set_ylim(np.min(tp9_filtered), np.max(tp9_filtered))
        ax.set_title("Animated EEG Trace - TP9")

        def init():
            line.set_data([], [])
            return line,

        def animate(i):
            start = i * step_len
            # Clip to the recording so x and y always have equal length.
            stop = min(start + window_len, n_samples)
            line.set_data(np.arange(start, stop) / sampling_rate, tp9_filtered[start:stop])
            # Scroll the x-limits with the window; with fixed 0–5 s limits the
            # trace slides out of view and later frames are blank.
            t0 = start / sampling_rate
            ax.set_xlim(t0, t0 + window_sec)
            return line,

        # blit=False so the changing x-limits are redrawn on every frame.
        ani = FuncAnimation(fig, animate, init_func=init, frames=n_frames,
                            interval=100, blit=False)

        # Write the GIF into a directory that is removed afterwards; st.image
        # is called while the file still exists.
        with tempfile.TemporaryDirectory() as tmpdir:
            gif_path = os.path.join(tmpdir, "eeg_trace_animation.gif")
            ani.save(gif_path, writer='pillow', fps=10)
            # NOTE(review): newer Streamlit releases deprecate use_column_width.
            st.image(gif_path, caption="Animated EEG Trace", use_column_width=True)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")
    finally:
        if fig is not None:
            plt.close(fig)  # Close figure to avoid duplicate rendering / leaks