# Spaces: Sleeping
import os

# Matplotlib resolves (and caches) its config/cache directory from
# MPLCONFIGDIR when it is first imported, so these writable locations must be
# set *before* matplotlib -- or mne, which imports it -- is loaded. Assigning
# them after those imports has no effect on matplotlib.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
os.environ["STREAMLIT_HOME"] = "/tmp/streamlit"

import io
import tempfile

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import mne
from mne.channels import make_dig_montage
from matplotlib.animation import FuncAnimation
import requests
import streamlit as st
from scipy.signal import welch

# Set page config (must be the first Streamlit command executed by the script)
st.set_page_config(page_title="Muse EEG Topomap Viewer", layout="centered")
def plot_eeg_topomap_muse_from_csv(csv_url, save_path_animation=None, save_directory=None,
                                   show_names=False, start_time=0.05, end_time=1, step_size=0.1):
    """Download a Muse CSV recording and render an animated EEG topomap GIF.

    Parameters
    ----------
    csv_url : str
        URL of a CSV containing the columns ``AF7``, ``AF8``, ``TP9``, ``TP10``
        and a ``TimeStamp`` or ``timestamps`` column.
    save_path_animation : str | None
        Output path of the GIF. When falsy, the data is still loaded and
        validated but no animation is rendered.
    save_directory : str | None
        If given, the directory is created and the GIF is written there using
        the base name of ``save_path_animation``.
    show_names : bool
        Whether to label the electrodes on each topomap frame.
    start_time, end_time, step_size : float
        Frame times in seconds, ``np.arange(start_time, end_time, step_size)``,
        measured from the first sample of the recording.

    Returns
    -------
    str | None
        Path of the written GIF, or ``None`` if ``save_path_animation`` is falsy.

    Raises
    ------
    RuntimeError
        If the CSV download does not return HTTP 200.
    ValueError
        If required columns are missing, the timestamps cannot yield a sampling
        rate, or the requested frame times are empty or outside the recording.
    """
    response = requests.get(csv_url, timeout=30)
    if response.status_code != 200:
        raise RuntimeError(f"Failed to download CSV from {csv_url}")
    data = pd.read_csv(io.StringIO(response.text))
    st.write(f"Loaded data with shape: {data.shape}")

    muse_channels = ['AF7', 'AF8', 'TP9', 'TP10']
    if not all(ch in data.columns for ch in muse_channels):
        raise ValueError(f"The dataset must contain the following channels: {muse_channels}")
    eeg_data = data[muse_channels].to_numpy().T  # shape (n_channels, n_samples)

    # Approximate Muse electrode locations in MNE "head" coordinates.
    # NOTE(review): presumably metres, MNE's convention -- confirm.
    muse_positions = {
        'AF7': [-0.05, 0.085, 0],
        'AF8': [0.05, 0.085, 0],
        'TP9': [-0.08, -0.04, 0],
        'TP10': [0.08, -0.04, 0]
    }
    # Tiny (~0.1 mm) x/y jitter; its purpose is not documented here. A seeded
    # local generator keeps the montage reproducible across reruns and leaves
    # numpy's global random state untouched.
    rng = np.random.default_rng(0)
    for position in muse_positions.values():
        position[0] += rng.normal(0, 0.0001)
        position[1] += rng.normal(0, 0.0001)
    montage = make_dig_montage(ch_pos=muse_positions, coord_frame='head')

    if 'TimeStamp' in data.columns:
        timestamps = data['TimeStamp']
    elif 'timestamps' in data.columns:
        timestamps = data['timestamps']
    else:
        raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
    if len(timestamps) < 2:
        raise ValueError("At least two samples are needed to estimate the sampling rate.")
    # Sampling rate from the mean spacing of consecutive timestamps
    # (assumes the timestamps are in seconds).
    average_interval = np.mean(np.diff(timestamps))
    if not average_interval > 0:  # also rejects NaN
        raise ValueError("Timestamps must be increasing to estimate the sampling rate.")
    sfreq = 1 / average_interval

    info = mne.create_info(ch_names=muse_channels, sfreq=sfreq, ch_types='eeg')
    evoked = mne.EvokedArray(eeg_data, info)
    evoked.set_montage(montage)

    times = np.arange(start_time, end_time, step_size)
    if len(times) == 0:
        raise ValueError("Not enough time range selected for topomap animation.")

    if not save_path_animation:
        return None

    # Validate up front so an out-of-range frame fails before rendering starts
    # instead of midway through writing a partial GIF.
    if times.min() < evoked.times[0] or times.max() > evoked.times[-1]:
        raise ValueError(
            f"Topomap times must lie within the recording "
            f"({evoked.times[0]:.3f}-{evoked.times[-1]:.3f} s)."
        )

    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path_animation = os.path.join(save_directory, os.path.basename(save_path_animation))

    fig, ax = plt.subplots()
    try:
        def update(frame):
            # Redraw the whole topomap for this frame's time point.
            ax.clear()
            evoked.plot_topomap([times[frame]], ch_type='eeg', time_unit='s',
                                axes=ax, colorbar=False, show=False,
                                show_names=show_names)

        anim = FuncAnimation(fig, update, frames=range(len(times)), interval=200)
        anim.save(save_path_animation, writer='pillow')
    finally:
        # Release the figure even if rendering or saving fails.
        plt.close(fig)
    return save_path_animation
# Streamlit UI
st.title("🧠 Muse EEG Topomap Animation from CSV")
default_url = "https://raw.githubusercontent.com/garenasd945/EEG_muse2/refs/heads/master/dataset/original_data/subjecta-relaxed-1.csv"
# Strip surrounding whitespace that often sneaks in when pasting a URL.
csv_url = st.text_input("CSV URL", value=default_url).strip()
generate_button = st.button("Generate Topomap Animation")
if generate_button:
    try:
        # The GIF is written into a temporary directory that is deleted when the
        # `with` block exits, so it is displayed while still inside the block.
        with tempfile.TemporaryDirectory() as tmpdir:
            save_path_animation = os.path.join(tmpdir, "eeg_topomap_animation.gif")
            gif_path = plot_eeg_topomap_muse_from_csv(
                csv_url,
                save_path_animation=save_path_animation,
                save_directory=tmpdir,
                show_names=True
            )
            if gif_path:
                st.image(gif_path, caption="EEG Topomap Animation")
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"❌ Error: {e}")
plot_waveform_button = st.button("Plot TP9 Waveforms")
if plot_waveform_button:
    try:
        # Download and read CSV (timeout so a stalled server cannot hang the app)
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, failing with a readable message if absent
        time_col = 'timestamps' if 'timestamps' in data.columns else 'TimeStamp'
        if time_col not in data.columns or 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'timestamps' (or 'TimeStamp') column and a 'TP9' column")
        timestamps = data[time_col].to_numpy(dtype=float)
        tp9_raw = data['TP9'].to_numpy(dtype=float)
        if len(timestamps) < 2:
            raise ValueError("At least two samples are needed to estimate the sampling rate")
        # Plot against time elapsed since the first sample so the axis matches
        # its 'Time (seconds)' label rather than showing raw timestamp values.
        elapsed = timestamps - timestamps[0]

        # Sampling rate from the mean spacing of consecutive timestamps
        sampling_rate = 1 / np.mean(np.diff(timestamps))

        # Single-channel MNE Raw object. np.array([...]) makes a copy, so the
        # in-place filter below cannot alter tp9_raw.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        # Apply 1-50 Hz band-pass filter
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Plot raw waveform
        fig_raw, ax_raw = plt.subplots(figsize=(12, 4))
        ax_raw.plot(elapsed, tp9_raw, label='TP9', alpha=0.6, color='blue')
        ax_raw.set_title('Raw Waveform of TP9')
        ax_raw.set_xlabel('Time (seconds)')
        ax_raw.set_ylabel('Amplitude')
        ax_raw.legend()
        ax_raw.grid(True)
        st.pyplot(fig_raw)
        plt.close(fig_raw)  # free the figure once it has been rendered to the page

        # Plot filtered waveform
        fig_filtered, ax_filtered = plt.subplots(figsize=(12, 4))
        ax_filtered.plot(elapsed, tp9_filtered, label='TP9 (1-50Hz)', linewidth=1.5, color='green')
        ax_filtered.set_title('Filtered Waveform of TP9')
        ax_filtered.set_xlabel('Time (seconds)')
        ax_filtered.set_ylabel('Amplitude')
        ax_filtered.legend()
        ax_filtered.grid(True)
        st.pyplot(fig_filtered)
        plt.close(fig_filtered)
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"❌ Error: {e}")
spectrogram_button = st.button("Plot TP9 Spectrograms")
if spectrogram_button:
    try:
        from scipy.signal import spectrogram

        # Download and read CSV (timeout so a stalled server cannot hang the app)
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, failing with a readable message if absent
        time_col = 'timestamps' if 'timestamps' in data.columns else 'TimeStamp'
        if time_col not in data.columns or 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'timestamps' (or 'TimeStamp') column and a 'TP9' column")
        timestamps = data[time_col].to_numpy(dtype=float)
        tp9_raw = data['TP9'].to_numpy(dtype=float)
        if len(timestamps) < 2:
            raise ValueError("At least two samples are needed to estimate the sampling rate")

        # Sampling rate from the mean spacing of consecutive timestamps
        sampling_rate = 1 / np.mean(np.diff(timestamps))

        # Single-channel MNE Raw object. np.array([...]) makes a copy, so the
        # in-place filter below cannot alter tp9_raw.
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        # Apply 1-50 Hz band-pass filter
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Spectrogram before and after filtering, drawn identically
        for signal, title in (
            (tp9_raw, 'Spectrogram of TP9 (Raw)'),
            (tp9_filtered, 'Spectrogram of TP9 (Filtered 1–50Hz)'),
        ):
            freqs, seg_times, sxx = spectrogram(signal, fs=sampling_rate, nperseg=256)
            fig_spec, ax_spec = plt.subplots(figsize=(10, 6))
            # The 1e-10 offset keeps log10 finite where the power is exactly zero.
            mesh = ax_spec.pcolormesh(seg_times, freqs, np.log10(sxx + 1e-10), shading='gouraud')
            fig_spec.colorbar(mesh, ax=ax_spec, label='Log Power Spectral Density')
            ax_spec.set_title(title)
            ax_spec.set_xlabel('Time (seconds)')
            ax_spec.set_ylabel('Frequency (Hz)')
            st.pyplot(fig_spec)
            plt.close(fig_spec)  # free the figure once it has been rendered to the page
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"❌ Error: {e}")
# Button to calculate and plot band powers
if st.button("Show Band Power Chart (TP9)"):
    try:
        # Download and read CSV (timeout so a stalled server cannot hang the app)
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, failing with a readable message if absent
        time_col = 'timestamps' if 'timestamps' in data.columns else 'TimeStamp'
        if time_col not in data.columns or 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'timestamps' (or 'TimeStamp') column and a 'TP9' column")
        timestamps = data[time_col].to_numpy(dtype=float)
        tp9_raw = data['TP9'].to_numpy(dtype=float)
        if len(timestamps) < 2:
            raise ValueError("At least two samples are needed to estimate the sampling rate")

        # Sampling rate from the mean spacing of consecutive timestamps
        sampling_rate = 1 / np.mean(np.diff(timestamps))

        # Single-channel MNE Raw object, band-pass filtered 1-50 Hz (FIR)
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Frequency bands as inclusive (low, high) limits in Hz
        bands = {
            "Delta (1–4 Hz)": (1, 4),
            "Theta (4–8 Hz)": (4, 8),
            "Alpha (8–12 Hz)": (8, 12),
            "Beta (12–30 Hz)": (12, 30),
            "Gamma (30–50 Hz)": (30, 50)
        }

        # np.trapz was renamed np.trapezoid in NumPy 2.0; the old name is deprecated.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz

        def bandpower(signal, sf, band, window_sec=None):
            """Return the absolute power of *signal* within *band* = (low, high) Hz.

            The PSD is estimated with Welch's method (segment length
            ``window_sec * sf`` samples when given, otherwise SciPy's default)
            and integrated over the inclusive band with the trapezoidal rule.
            ``welch`` comes from the module-level ``scipy.signal`` import.
            """
            low, high = band
            nperseg = int(window_sec * sf) if window_sec else None
            freqs, psd = welch(signal, sf, nperseg=nperseg)
            in_band = (freqs >= low) & (freqs <= high)
            return trapezoid(psd[in_band], freqs[in_band])

        # Calculate power for each band
        powers = [bandpower(tp9_filtered, sampling_rate, limits) for limits in bands.values()]

        # Plot (a list of labels, not a dict_keys view, for matplotlib's categorical axis)
        fig, ax = plt.subplots()
        ax.bar(list(bands), powers, color='skyblue')
        ax.set_title("Band Power - TP9")
        ax.set_ylabel("Power")
        ax.tick_params(axis='x', rotation=45)
        st.pyplot(fig)
        plt.close(fig)  # free the figure once it has been rendered to the page
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"❌ Error: {e}")
# Button to show EEG animation
if st.button("Show Animated EEG Trace (TP9)"):
    try:
        # Download and read CSV (timeout so a stalled server cannot hang the app)
        response = requests.get(csv_url, timeout=30)
        response.raise_for_status()
        data = pd.read_csv(io.StringIO(response.text))
        st.write(f"Loaded data with shape: {data.shape}")

        # Extract timestamps and TP9, failing with a readable message if absent
        time_col = 'timestamps' if 'timestamps' in data.columns else 'TimeStamp'
        if time_col not in data.columns or 'TP9' not in data.columns:
            raise ValueError("CSV must contain a 'timestamps' (or 'TimeStamp') column and a 'TP9' column")
        timestamps = data[time_col].to_numpy(dtype=float)
        tp9_raw = data['TP9'].to_numpy(dtype=float)
        if len(timestamps) < 2:
            raise ValueError("At least two samples are needed to estimate the sampling rate")

        # Sampling rate from the mean spacing of consecutive timestamps
        sampling_rate = 1 / np.mean(np.diff(timestamps))

        # Single-channel MNE Raw object, band-pass filtered 1-50 Hz (FIR)
        info = mne.create_info(ch_names=['TP9'], sfreq=sampling_rate, ch_types='eeg')
        raw = mne.io.RawArray(np.array([tp9_raw]), info)
        raw.filter(1, 50, fir_design='firwin')
        tp9_filtered = raw.get_data()[0]

        # Sliding window: 5 s visible, advancing 0.1 s per frame (at least one
        # sample each, so very low sampling rates cannot stall the window).
        window = max(1, int(sampling_rate * 5))
        step = max(1, int(sampling_rate * 0.1))
        # Up to 100 frames, limited to windows that fit inside the recording;
        # a recording shorter than one window still yields a single frame.
        n_frames = max(1, min(100, (len(tp9_filtered) - window) // step + 1))

        # Setup figure and axes
        fig, ax = plt.subplots()
        line, = ax.plot([], [], lw=2)
        ax.set_xlim(0, window / sampling_rate)
        ax.set_ylim(np.min(tp9_filtered), np.max(tp9_filtered))
        ax.set_title("Animated EEG Trace - TP9")
        ax.set_xlabel('Time (seconds)')

        def init():
            line.set_data([], [])
            return line,

        def animate(i):
            start = i * step
            segment = tp9_filtered[start:start + window]
            # x is built from the actual slice length so x and y always match.
            x = np.arange(start, start + len(segment)) / sampling_rate
            line.set_data(x, segment)
            # Scroll the view with the window so the trace stays on screen.
            ax.set_xlim(start / sampling_rate, (start + window) / sampling_rate)
            return line,

        # blit=False: the x-limits change every frame, so the axes must redraw.
        ani = FuncAnimation(fig, animate, init_func=init, frames=n_frames, interval=100, blit=False)
        try:
            # Render into a temporary directory that is cleaned up afterwards.
            with tempfile.TemporaryDirectory() as tmpdir:
                gif_path = os.path.join(tmpdir, "eeg_trace.gif")
                ani.save(gif_path, writer='pillow', fps=10)
                # NOTE(review): use_column_width is deprecated in newer Streamlit
                # releases; switch to its replacement once the pinned version supports it.
                st.image(gif_path, caption="Animated EEG Trace", use_column_width=True)
        finally:
            plt.close(fig)  # Close figure to avoid duplicate rendering and leaks
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing the app.
        st.error(f"❌ Error: {e}")