"""Streamlit app for exploring Muse EEG recordings loaded from a CSV URL.

Offers a topomap animation over the four Muse channels (AF7, AF8, TP9, TP10)
plus several TP9 views: raw/filtered waveforms, spectrograms, band powers and
a scrolling trace animation.
"""

import os

# Matplotlib resolves its config/cache directory when it is first imported, so
# this must be set *before* the matplotlib/mne imports below to take effect.
os.environ["MPLCONFIGDIR"] = "/tmp/matplotlib"
# NOTE(review): under `streamlit run`, Streamlit is imported before this script
# executes, so STREAMLIT_HOME set here may have no effect — confirm.
os.environ["STREAMLIT_HOME"] = "/tmp/streamlit"

import io  # noqa: E402
import tempfile  # noqa: E402
from contextlib import contextmanager  # noqa: E402

import matplotlib.pyplot as plt  # noqa: E402
import mne  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import requests  # noqa: E402
import streamlit as st  # noqa: E402
from matplotlib.animation import FuncAnimation  # noqa: E402
from mne.channels import make_dig_montage  # noqa: E402
from scipy.integrate import trapezoid  # noqa: E402
from scipy.signal import spectrogram, welch  # noqa: E402

# Set page config (must be the first Streamlit command the script executes).
st.set_page_config(page_title="Muse EEG Topomap Viewer", layout="centered")

MUSE_CHANNELS = ["AF7", "AF8", "TP9", "TP10"]

# Approximate sensor positions (x, y, z) in MNE's "head" coordinate frame.
_MUSE_BASE_POSITIONS = {
    "AF7": (-0.05, 0.085, 0.0),
    "AF8": (0.05, 0.085, 0.0),
    "TP9": (-0.08, -0.04, 0.0),
    "TP10": (0.08, -0.04, 0.0),
}

# Seconds to wait for the CSV download before giving up.
_REQUEST_TIMEOUT_S = 30

# Band-pass (low, high) in Hz applied to TP9 for the "filtered" views.
_BANDPASS_HZ = (1, 50)

# Frequency bands (Hz) shown in the band-power chart.
TP9_BANDS = {
    "Delta (1–4 Hz)": (1, 4),
    "Theta (4–8 Hz)": (4, 8),
    "Alpha (8–12 Hz)": (8, 12),
    "Beta (12–30 Hz)": (12, 30),
    "Gamma (30–50 Hz)": (30, 50),
}


def _download_csv(csv_url):
    """Download ``csv_url``, parse it with pandas and report its shape in the app.

    Raises:
        RuntimeError: if the server does not answer with HTTP 200.
    """
    response = requests.get(csv_url, timeout=_REQUEST_TIMEOUT_S)
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to download CSV from {csv_url} (HTTP {response.status_code})"
        )
    data = pd.read_csv(io.StringIO(response.text))
    st.write(f"Loaded data with shape: {data.shape}")
    return data


def _sample_times_and_sfreq(data):
    """Return ``(times_s, sfreq)`` derived from the CSV's time column.

    The time column is ``TimeStamp`` or, failing that, ``timestamps``. Numeric
    columns are taken to already be in seconds; any other dtype (e.g. date-time
    strings) is parsed with ``pd.to_datetime``. ``times_s`` is relative to the
    first row, and ``sfreq`` is the reciprocal of the mean sample interval.

    Raises:
        ValueError: if no time column exists, there are fewer than two rows,
            or the mean interval is not a positive finite number.
    """
    if "TimeStamp" in data.columns:
        stamps = data["TimeStamp"]
    elif "timestamps" in data.columns:
        stamps = data["timestamps"]
    else:
        raise ValueError("CSV must contain a 'TimeStamp' or 'timestamps' column")
    if len(stamps) < 2:
        raise ValueError("CSV must contain at least two samples")

    if pd.api.types.is_numeric_dtype(stamps):
        seconds = stamps.to_numpy(dtype=float)
        seconds = seconds - seconds[0]
    else:
        # np.diff cannot operate on strings, so parse to datetimes first.
        parsed = pd.to_datetime(stamps)
        seconds = (parsed - parsed.iloc[0]).dt.total_seconds().to_numpy()

    mean_interval = np.mean(np.diff(seconds))
    if not np.isfinite(mean_interval) or mean_interval <= 0:
        raise ValueError("Could not derive a positive sampling rate from the timestamps")
    return seconds, 1.0 / mean_interval


def _load_tp9(csv_url):
    """Download ``csv_url`` and return TP9 before and after band-pass filtering.

    Returns:
        ``(times_s, tp9_raw, tp9_filtered, sfreq)``: sample times in seconds
        relative to the first row, the unfiltered and the ``_BANDPASS_HZ``
        FIR-filtered TP9 samples (1-D arrays), and the estimated rate in Hz.

    Raises:
        ValueError: if the CSV has no ``TP9`` column or unusable timestamps.
    """
    data = _download_csv(csv_url)
    if "TP9" not in data.columns:
        raise ValueError("The dataset must contain a 'TP9' column")
    times_s, sfreq = _sample_times_and_sfreq(data)
    tp9_raw = data["TP9"].to_numpy(dtype=float)

    info = mne.create_info(ch_names=["TP9"], sfreq=sfreq, ch_types="eeg")
    # Copy: RawArray may wrap the array without copying, and filter() works in
    # place, which would otherwise overwrite tp9_raw as well.
    raw = mne.io.RawArray(tp9_raw[np.newaxis, :].copy(), info)
    raw.filter(*_BANDPASS_HZ, fir_design="firwin")
    return times_s, tp9_raw, raw.get_data()[0], sfreq


@contextmanager
def _streamlit_figure(**subplot_kwargs):
    """Yield ``(fig, ax)``, render the figure with ``st.pyplot`` on success, then close it.

    The figure is closed even if drawing fails, so pyplot does not keep
    accumulating figures across Streamlit reruns.
    """
    fig, ax = plt.subplots(**subplot_kwargs)
    try:
        yield fig, ax
        st.pyplot(fig)
    finally:
        plt.close(fig)


def _plot_trace(times_s, values, title, **line_kwargs):
    """Render a single time-domain trace of TP9 as a wide line chart."""
    with _streamlit_figure(figsize=(12, 4)) as (_, ax):
        ax.plot(times_s, values, **line_kwargs)
        ax.set_title(title)
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Amplitude")
        ax.legend()
        ax.grid(True)


def _plot_spectrogram(signal, sfreq, title):
    """Render the log10 spectrogram (256-sample segments) of ``signal``."""
    freqs, seg_times, power = spectrogram(signal, fs=sfreq, nperseg=256)
    with _streamlit_figure(figsize=(10, 6)) as (fig, ax):
        # +1e-10 keeps log10 finite for zero-power bins.
        mesh = ax.pcolormesh(seg_times, freqs, np.log10(power + 1e-10), shading="gouraud")
        fig.colorbar(mesh, ax=ax, label="Log Power Spectral Density")
        ax.set_title(title)
        ax.set_xlabel("Time (seconds)")
        ax.set_ylabel("Frequency (Hz)")


def _band_powers(signal, sfreq, bands):
    """Integrate the Welch PSD of ``signal`` over each ``(low, high)`` band in Hz.

    The PSD is computed once and reused for every band.
    """
    freqs, psd = welch(signal, sfreq)
    powers = []
    for low, high in bands.values():
        in_band = (freqs >= low) & (freqs <= high)
        powers.append(trapezoid(psd[in_band], freqs[in_band]))
    return powers


def _show_scrolling_trace(signal, sfreq, window_s=5.0, step_s=0.1, max_frames=100):
    """Render a GIF of a ``window_s``-second window sliding over ``signal`` in ``step_s`` steps."""
    step = max(1, int(sfreq * step_s))
    window = max(1, int(sfreq * window_s))
    # Never step past the end of the data (short recordings get fewer frames).
    n_frames = max(1, min(max_frames, (len(signal) - window) // step + 1))

    fig, ax = plt.subplots()
    try:
        (line,) = ax.plot([], [], lw=2)
        ax.set_xlim(0, window_s)
        ax.set_ylim(np.min(signal), np.max(signal))
        ax.set_title("Animated EEG Trace - TP9")

        def init():
            line.set_data([], [])
            return (line,)

        def animate(frame_idx):
            start = frame_idx * step
            segment = signal[start:start + window]
            # x is relative to the window start so it stays inside the fixed xlim.
            line.set_data(np.arange(len(segment)) / sfreq, segment)
            return (line,)

        ani = FuncAnimation(
            fig, animate, init_func=init, frames=n_frames, interval=100, blit=True
        )
        # A temporary directory (instead of delete=False) guarantees cleanup;
        # st.image reads the file before the directory is removed.
        with tempfile.TemporaryDirectory() as tmpdir:
            gif_path = os.path.join(tmpdir, "eeg_trace.gif")
            ani.save(gif_path, writer="pillow", fps=10)
            st.image(gif_path, caption="Animated EEG Trace", use_column_width=True)
    finally:
        plt.close(fig)  # Close figure to avoid duplicate rendering / leaks


def plot_eeg_topomap_muse_from_csv(csv_url, save_path_animation=None, save_directory=None,
                                   show_names=False, start_time=0.05, end_time=1, step_size=0.1):
    """Build a GIF of scalp topomaps over the four Muse channels from a CSV URL.

    Args:
        csv_url: URL of a CSV with ``AF7``, ``AF8``, ``TP9``, ``TP10`` columns and
            a ``TimeStamp`` or ``timestamps`` column.
        save_path_animation: output GIF path; when falsy nothing is rendered.
        save_directory: if given, the GIF is written into this directory
            (created if missing), keeping only the basename of
            ``save_path_animation``.
        show_names: forwarded to ``Evoked.plot_topomap`` to label the sensors.
        start_time, end_time, step_size: frame times in seconds, as
            ``np.arange(start_time, end_time, step_size)``; frames past the end
            of the recording are dropped.

    Returns:
        The path of the written GIF, or ``None`` when ``save_path_animation`` is falsy.

    Raises:
        RuntimeError: if the download does not return HTTP 200.
        ValueError: on missing channels/timestamp column or an empty frame range.
    """
    data = _download_csv(csv_url)

    if not all(ch in data.columns for ch in MUSE_CHANNELS):
        raise ValueError(f"The dataset must contain the following channels: {MUSE_CHANNELS}")
    eeg_data = data[MUSE_CHANNELS].to_numpy(dtype=float).T

    # Tiny random x/y jitter on each sensor position. NOTE(review): presumably
    # meant to avoid a degenerate layout for MNE's interpolation — confirm it is
    # still needed; it also makes the output non-deterministic.
    positions = {
        ch: [x + np.random.normal(0, 0.0001), y + np.random.normal(0, 0.0001), z]
        for ch, (x, y, z) in _MUSE_BASE_POSITIONS.items()
    }
    montage = make_dig_montage(ch_pos=positions, coord_frame="head")

    _, sfreq = _sample_times_and_sfreq(data)
    info = mne.create_info(ch_names=MUSE_CHANNELS, sfreq=sfreq, ch_types="eeg")
    evoked = mne.EvokedArray(eeg_data, info)
    evoked.set_montage(montage)

    times = np.arange(start_time, end_time, step_size)
    # MNE rejects topomap times outside the recording, so drop them up front.
    times = times[times <= evoked.times[-1]]
    if len(times) == 0:
        raise ValueError("Not enough time range selected for topomap animation.")

    if not save_path_animation:
        return None
    if save_directory:
        os.makedirs(save_directory, exist_ok=True)
        save_path_animation = os.path.join(save_directory, os.path.basename(save_path_animation))

    fig, ax = plt.subplots()
    try:
        def update(frame_idx):
            ax.clear()
            evoked.plot_topomap([times[frame_idx]], ch_type="eeg", time_unit="s", axes=ax,
                                colorbar=False, show=False, show_names=show_names)

        anim = FuncAnimation(fig, update, frames=len(times), interval=200)
        anim.save(save_path_animation, writer="pillow")
    finally:
        plt.close(fig)
    return save_path_animation


# Streamlit UI
st.title("🧠 Muse EEG Topomap Animation from CSV")
default_url = "https://raw.githubusercontent.com/garenasd945/EEG_muse2/refs/heads/master/dataset/original_data/subjecta-relaxed-1.csv"
csv_url = st.text_input("CSV URL", value=default_url)

if st.button("Generate Topomap Animation"):
    try:
        with tempfile.TemporaryDirectory() as tmpdir:
            gif_path = plot_eeg_topomap_muse_from_csv(
                csv_url,
                save_path_animation=os.path.join(tmpdir, "eeg_topomap_animation.gif"),
                save_directory=tmpdir,
                show_names=True,
            )
            if gif_path:
                st.image(gif_path, caption="EEG Topomap Animation")
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

if st.button("Plot TP9 Waveforms"):
    try:
        times_s, tp9_raw, tp9_filtered, _ = _load_tp9(csv_url)
        _plot_trace(times_s, tp9_raw, "Raw Waveform of TP9",
                    label="TP9", alpha=0.6, color="blue")
        _plot_trace(times_s, tp9_filtered, "Filtered Waveform of TP9",
                    label="TP9 (1-50Hz)", linewidth=1.5, color="green")
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

if st.button("Plot TP9 Spectrograms"):
    try:
        _, tp9_raw, tp9_filtered, sfreq = _load_tp9(csv_url)
        _plot_spectrogram(tp9_raw, sfreq, "Spectrogram of TP9 (Raw)")
        _plot_spectrogram(tp9_filtered, sfreq, "Spectrogram of TP9 (Filtered 1–50Hz)")
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

# Button to calculate and plot band powers
if st.button("Show Band Power Chart (TP9)"):
    try:
        _, _, tp9_filtered, sfreq = _load_tp9(csv_url)
        powers = _band_powers(tp9_filtered, sfreq, TP9_BANDS)
        with _streamlit_figure() as (_, ax):
            ax.bar(list(TP9_BANDS), powers, color="skyblue")
            ax.set_title("Band Power - TP9")
            ax.set_ylabel("Power")
            ax.tick_params(axis="x", rotation=45)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")

# Button to show EEG animation
if st.button("Show Animated EEG Trace (TP9)"):
    try:
        _, _, tp9_filtered, sfreq = _load_tp9(csv_url)
        _show_scrolling_trace(tp9_filtered, sfreq)
    except Exception as e:
        st.error(f"❌ Error: {str(e)}")