import os
import shutil
import zipfile
from pathlib import Path

import mne
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import streamlit as st
from plotly.subplots import make_subplots

st.set_page_config(
    page_title="EEG Mental Arithmetic Explorer",
    page_icon="🧠",
    layout="wide",
    initial_sidebar_state="expanded",
)

st.markdown(""" """, unsafe_allow_html=True)

# Page title and subtitle. The text contains line breaks, so it is written
# with explicit "\n" escapes; a single-quoted literal may not span lines.
st.markdown(
    "\n\nEEG Mental Arithmetic Explorer\n\n",
    unsafe_allow_html=True,
)
st.markdown(
    "\n\nCognitive Workload Assessment through Brain Activity Analysis\n\n",
    unsafe_allow_html=True,
)

# Data paths - root level structure
ZIP_FILE_PATH = "edf_files.zip"
EDF_EXTRACT_PATH = "edf_extracted"

# Sidebar label of each recording type -> filename suffix of its EDF file
RECORDING_TYPES = {
    "Resting State (Baseline)": "_1",
    "Mental Arithmetic Task": "_2",
}


@st.cache_resource
def extract_edf_files():
    """Ensure the EDF recordings are unpacked into ``EDF_EXTRACT_PATH``.

    Every ``.edf`` member of ``ZIP_FILE_PATH`` (except ``__MACOSX`` entries)
    is written directly into the extraction directory, dropping any
    sub-directory prefix it had inside the archive.

    Returns:
        ``True`` when the extraction directory exists after the call,
        ``False`` when it does not exist and the ZIP archive is missing.
    """
    extract_dir = Path(EDF_EXTRACT_PATH)

    # Nothing to do when recordings are already present.
    if extract_dir.is_dir() and any(extract_dir.glob("*.edf")):
        return True

    if not os.path.exists(ZIP_FILE_PATH):
        # Without an archive we can only report whether the directory exists;
        # the caller separately checks whether it actually contains files.
        return extract_dir.is_dir()

    # Reached when the directory is missing OR exists without any recordings
    # (e.g. a previous extraction was interrupted right after creating it).
    with st.spinner("Extracting EDF files... This may take a moment."):
        extract_dir.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(ZIP_FILE_PATH, "r") as archive:
            for member in archive.namelist():
                if not member.endswith(".edf") or member.startswith("__MACOSX"):
                    continue
                target_path = extract_dir / os.path.basename(member)
                if target_path.exists():
                    continue
                # Stream to a temporary name and rename afterwards so that an
                # interrupted copy never leaves a truncated "*.edf" behind.
                tmp_path = target_path.with_name(target_path.name + ".part")
                with archive.open(member) as source, open(tmp_path, "wb") as target:
                    shutil.copyfileobj(source, target)
                os.replace(tmp_path, target_path)
    return True


extraction_success = extract_edf_files()

if not extraction_success:
    st.error(f"Could not find {ZIP_FILE_PATH}")
    st.info("""
    Expected structure:
    ```
    space/
    ├── app.py
    ├── requirements.txt
    ├── README.md
    └── edf_files.zip
    ```
    """)
    st.stop()


def list_available_files():
    """Return the ``.edf`` files located directly in ``EDF_EXTRACT_PATH``.

    Sub-directories are not searched. An empty list is returned when the
    extraction directory does not exist.
    """
    if not os.path.exists(EDF_EXTRACT_PATH):
        return []
    return list(Path(EDF_EXTRACT_PATH).glob("*.edf"))


def get_available_subjects():
    """Return the sorted, de-duplicated subject IDs found among the EDF files.

    The subject ID is the part of the file stem before the first underscore
    (e.g. ``Subject01_1.edf`` -> ``Subject01``); files whose stem contains no
    underscore are ignored.
    """
    subjects = {
        path.stem.split("_")[0]
        for path in list_available_files()
        if "_" in path.stem
    }
    return sorted(subjects)


# The cached tuple is shared between reruns and sessions, so callers must
# treat the returned DataFrame and channel list as read-only. ``max_entries``
# bounds memory use, since every entry holds a full recording.
@st.cache_resource(max_entries=16)
def load_edf_data(subject_id, suffix):
    """Load one EDF recording from the extraction directory.

    Args:
        subject_id: Subject identifier, e.g. ``"Subject01"``.
        suffix: Recording suffix appended to the subject ID, e.g. ``"_1"``.

    Returns:
        A tuple ``(df, sfreq, channels, file_path)`` where ``df`` has a
        leading ``time`` column (seconds) followed by one column per channel
        holding the signal multiplied by 1e6, ``sfreq`` is the sampling
        frequency reported by MNE, ``channels`` is the list of channel names
        and ``file_path`` is the path that was read.

    Raises:
        FileNotFoundError: If the expected EDF file does not exist.
        RuntimeError: If reading or converting the file fails; the original
            exception is attached as ``__cause__``.
    """
    file_name = f"{subject_id}{suffix}.edf"
    file_path = f"{EDF_EXTRACT_PATH}/{file_name}"

    if not os.path.exists(file_path):
        # List a few available files to make the failure easy to debug.
        available_names = sorted(
            path.name for path in Path(EDF_EXTRACT_PATH).glob("*.edf")
        )
        raise FileNotFoundError(
            f"Could not find: {subject_id}{suffix}.edf\n"
            f"Available files ({len(available_names)}): {available_names[:10]}"
        )

    try:
        # verbose=True so that any reader warnings show up in the logs.
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=True)

        # get_data() has shape (n_channels, n_samples); MNE reports EEG in
        # Volts, so scale to microvolts.
        data_uv = raw.get_data() * 1e6

        channels = list(raw.ch_names)
        sfreq = raw.info["sfreq"]
        time = np.arange(data_uv.shape[1]) / sfreq

        df = pd.DataFrame(data_uv.T, columns=channels)
        df.insert(0, "time", time)
    except Exception as e:
        raise RuntimeError(f"Error loading EDF file {file_path}: {e}") from e

    return df, sfreq, channels, file_path


st.sidebar.header("Dataset Controls")

# Check available files
edf_files = list_available_files()

if not edf_files:
    st.error("No EDF files found after extraction!")
    st.info(f"Checked directory: {EDF_EXTRACT_PATH}")
    st.stop()

unique_files = len(edf_files)
st.sidebar.success(f"Found {unique_files} EDF files")

subject_ids = get_available_subjects()

if not subject_ids:
    st.error("No subject files found!")
    st.stop()

selected_subject = st.sidebar.selectbox(
    "Select Subject",
    subject_ids,
    index=0,
)

recording_type = st.sidebar.radio(
    "Recording Type",
    list(RECORDING_TYPES),
    index=0,
)

suffix = RECORDING_TYPES[recording_type]

st.sidebar.markdown("---")
st.sidebar.markdown("")  # extra vertical spacing
st.sidebar.markdown("### Subject Information")
st.sidebar.markdown(f"**ID:** {selected_subject}")
st.sidebar.markdown(f"**Recording:** {recording_type}")
st.sidebar.markdown("")  # extra vertical spacing
st.sidebar.markdown("---")
st.sidebar.markdown("### Data Source")
st.sidebar.info("Data loaded from EDF files")

# EEG frequency bands: name -> (low Hz, high Hz, translucent fill colour).
# Shared by the PSD overlay and the band-power bar chart so the limits
# cannot drift apart.
FREQUENCY_BANDS = {
    'Delta': (0.5, 4, 'rgba(255, 0, 0, 0.1)'),
    'Theta': (4, 8, 'rgba(255, 165, 0, 0.1)'),
    'Alpha': (8, 13, 'rgba(255, 255, 0, 0.1)'),
    'Beta': (13, 30, 'rgba(0, 255, 0, 0.1)'),
    'Gamma': (30, 50, 'rgba(0, 0, 255, 0.1)'),
}

# Bar colours of the band-power chart, in FREQUENCY_BANDS order.
BAND_BAR_COLORS = ['#ff6b6b', '#ffa500', '#ffff00', '#90ee90', '#6495ed']

ABOUT_DATASET_MARKDOWN = """
### About This Dataset

This dataset contains EEG recordings from 36 healthy participants during resting state and mental arithmetic task performance.

#### Key Features
- **Participants**: 36 healthy subjects
- **Recordings**: Paired (resting state + task)
- **Channels**: 23 EEG channels (International 10/20 system)
- **Duration**: 60 seconds per recording
- **Sampling Rate**: Approximately 500 Hz
- **Task**: Serial subtraction (4-digit minus 2-digit numbers)

#### Subject Groups
- **Good Performers** (24 subjects): Mean 21 operations in 4 minutes
- **Poor Performers** (12 subjects): Mean 7 operations in 4 minutes

#### Preprocessing
- High-pass filter at 30 Hz
- Notch filter at 50 Hz
- ICA artifact removal (eyes, muscles, cardiac)

#### Citation
```
Zyma I, Tukaev S, Seleznov I, Kiyono K, Popov A, Chernykh M, Shpenkov O.
Electroencephalograms during Mental Arithmetic Task Performance.
Data. 2019; 4(1):14. https://doi.org/10.3390/data4010014
```

#### Resources
- [PhysioNet Dataset](https://physionet.org/content/eegmat/1.0.0/)
- [Original Paper](https://doi.org/10.3390/data4010014)
- [Hugging Face Dataset](https://huggingface.co/datasets/BrainSpectralAnalytics/eeg-mental-arithmetic)

#### Contact
Ivan Seleznov: ivan.seleznov1@gmail.com
"""


def _trapezoid(y, x):
    """Integrate ``y`` over ``x`` with the trapezoidal rule.

    ``np.trapezoid`` only exists on NumPy >= 2.0; older releases provide the
    same routine as ``np.trapz``.
    """
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return integrate(y, x)


def _build_stacked_figure(df_plot, selected_channels):
    """Build one subplot row per channel, all sharing the time axis."""
    n_rows = len(selected_channels)
    # Plotly rejects a vertical spacing larger than 1 / (rows - 1).
    spacing = 0.02 if n_rows < 2 else min(0.02, 1.0 / (n_rows - 1))
    fig = make_subplots(
        rows=n_rows,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=spacing,
        subplot_titles=selected_channels,
    )
    for row, channel in enumerate(selected_channels, start=1):
        fig.add_trace(
            go.Scatter(
                x=df_plot['time'],
                y=df_plot[channel],
                mode='lines',
                name=channel,
                line=dict(width=1),
                showlegend=False,
            ),
            row=row,
            col=1,
        )
    fig.update_layout(
        height=150 * n_rows,
        showlegend=False,
        hovermode='x unified',
    )
    # Only the bottom subplot carries the shared x-axis title.
    fig.update_xaxes(title_text="Time (s)", row=n_rows, col=1)
    return fig


def _build_overlay_figure(df_plot, selected_channels):
    """Build a single plot with every selected channel drawn on top."""
    fig = go.Figure()
    for channel in selected_channels:
        fig.add_trace(
            go.Scatter(
                x=df_plot['time'],
                y=df_plot[channel],
                mode='lines',
                name=channel,
                line=dict(width=1),
            )
        )
    fig.update_layout(
        height=600,
        xaxis_title="Time (s)",
        yaxis_title="Amplitude (μV)",
        hovermode='x unified',
        legend=dict(
            orientation="v",
            yanchor="top",
            y=1,
            xanchor="left",
            x=1.01,
        ),
    )
    return fig


def _render_signal_viewer(df, sfreq, channels):
    """Render the time-domain viewer: controls, signal plot and metrics."""
    st.markdown("### EEG Signal Visualization")

    duration = float(df['time'].max())
    col1, col2, col3 = st.columns([2, 2, 1])

    with col1:
        time_range = st.slider(
            "Time Window (seconds)",
            min_value=0.0,
            max_value=duration,
            value=(0.0, min(10.0, duration)),
            step=0.5,
        )
    with col2:
        selected_channels = st.multiselect(
            "Select Channels",
            channels,
            default=channels[:6],
        )
    with col3:
        plot_style = st.selectbox("Plot Style", ["Stacked", "Overlay"])

    if not selected_channels:
        st.warning("Please select at least one channel to display")
        return

    # Restrict the data to the chosen time window (both ends inclusive).
    in_window = (df['time'] >= time_range[0]) & (df['time'] <= time_range[1])
    df_plot = df[in_window]

    if plot_style == "Stacked":
        fig = _build_stacked_figure(df_plot, selected_channels)
    else:
        fig = _build_overlay_figure(df_plot, selected_channels)
    st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Signal Metrics")
    metric_cols = st.columns(4)
    with metric_cols[0]:
        st.metric("Channels", len(selected_channels))
    with metric_cols[1]:
        st.metric("Sampling Rate", f"{sfreq:.0f} Hz")
    with metric_cols[2]:
        st.metric("Duration", f"{duration:.2f} s")
    with metric_cols[3]:
        st.metric("Samples", len(df_plot))


def _render_spectral_analysis(df, sfreq, channels):
    """Render the Welch PSD of one channel and its per-band power."""
    # Local import: scipy is not among the module-level imports.
    from scipy import signal

    st.markdown("### Power Spectral Density Analysis")

    col1, col2 = st.columns([3, 1])
    with col2:
        channel_for_psd = st.selectbox(
            "Select Channel for PSD",
            channels,
            index=0,
        )
        freq_bands = st.checkbox("Show Frequency Bands", value=True)

    channel_data = df[channel_for_psd].values
    # Use segments of about two seconds (0.5 Hz resolution) so that even the
    # narrow delta band spans several frequency bins at any sampling rate;
    # never go below the previous 256-sample segment or above the signal.
    nperseg = min(len(channel_data), max(256, int(round(2 * sfreq))))
    frequencies, psd = signal.welch(channel_data, fs=sfreq, nperseg=nperseg)

    # Bins with zero power become NaN (a gap in the trace) instead of -inf.
    psd_db = 10 * np.log10(np.where(psd > 0, psd, np.nan))

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=frequencies,
        y=psd_db,
        mode='lines',
        name='PSD',
        line=dict(color='steelblue', width=2),
    ))

    if freq_bands:
        for low, high, color in FREQUENCY_BANDS.values():
            fig.add_vrect(
                x0=low,
                x1=high,
                fillcolor=color,
                layer="below",
                line_width=0,
            )

        # Label each band at the top of the trace.
        y_max = float(np.nanmax(psd_db)) if np.isfinite(psd_db).any() else 0.0
        fig.update_layout(annotations=[
            dict(
                x=(low + high) / 2,
                y=y_max,
                text=band_name,
                showarrow=False,
                font=dict(size=10, color='black'),
                bgcolor='rgba(255, 255, 255, 0.8)',
                borderpad=4,
            )
            for band_name, (low, high, _color) in FREQUENCY_BANDS.items()
        ])

    fig.update_layout(
        height=500,
        xaxis_title="Frequency (Hz)",
        yaxis_title="Power Spectral Density (dB/Hz)",
        hovermode='x',
    )
    # Show at most 0-100 Hz, but never beyond the Nyquist frequency.
    fig.update_xaxes(range=[0, min(100.0, sfreq / 2)])
    st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Band Power Analysis")

    band_powers = {}
    for band_name, (low, high, _color) in FREQUENCY_BANDS.items():
        in_band = (frequencies >= low) & (frequencies <= high)
        band_powers[band_name] = float(
            _trapezoid(psd[in_band], frequencies[in_band])
        )

    fig_bands = go.Figure(data=[
        go.Bar(
            x=list(band_powers.keys()),
            y=list(band_powers.values()),
            marker_color=BAND_BAR_COLORS,
        )
    ])
    fig_bands.update_layout(
        height=400,
        xaxis_title="Frequency Band",
        yaxis_title="Absolute Power",
        showlegend=False,
    )
    st.plotly_chart(fig_bands, use_container_width=True)


def _render_statistics(df, channels):
    """Render the per-channel summary table and the correlation heatmap."""
    st.markdown("### Statistical Analysis")

    summary = df[channels].agg(['mean', 'std', 'min', 'max']).T
    stats_df = pd.DataFrame({
        'Channel': list(channels),
        'Mean (μV)': summary['mean'].to_numpy(),
        'Std (μV)': summary['std'].to_numpy(),
        'Min (μV)': summary['min'].to_numpy(),
        'Max (μV)': summary['max'].to_numpy(),
        'Range (μV)': (summary['max'] - summary['min']).to_numpy(),
    })

    # Keep the values numeric (so column sorting is numeric rather than
    # lexicographic) and only format how they are displayed.
    numeric_cols = ['Mean (μV)', 'Std (μV)', 'Min (μV)', 'Max (μV)', 'Range (μV)']
    st.dataframe(
        stats_df.style.format("{:.2f}", subset=numeric_cols),
        height=400,
    )

    st.markdown("### Channel Correlation Matrix")

    corr_matrix = df[channels].corr()
    fig_corr = go.Figure(data=go.Heatmap(
        z=corr_matrix.values,
        x=channels,
        y=channels,
        colorscale='RdBu',
        zmid=0,
        text=corr_matrix.values,
        texttemplate='%{text:.2f}',
        textfont={"size": 8},
        colorbar=dict(title="Correlation"),
    ))
    fig_corr.update_layout(
        height=750,
        title="Channel Correlation Matrix",
    )
    st.plotly_chart(fig_corr, use_container_width=True)


# Main content
tab1, tab2, tab3, tab4 = st.tabs(
    ["Signal Viewer", "Spectral Analysis", "Statistics", "About Dataset"]
)

# Load data. Any failure is reported in the page instead of crashing the app,
# so the footer below is still rendered.
try:
    with st.spinner(f"Loading {selected_subject}{suffix}..."):
        df, sfreq, channels, file_path = load_edf_data(selected_subject, suffix)
except Exception as e:
    st.error(f"Error loading data: {e}")
    st.info(f"Attempting to load: {selected_subject}{suffix}")
    data_loaded = False
else:
    data_loaded = True
    st.sidebar.success(f"Loaded: {Path(file_path).name}")

if data_loaded:
    # TAB 1: Signal Viewer
    with tab1:
        _render_signal_viewer(df, sfreq, channels)

    # TAB 2: Spectral Analysis
    with tab2:
        _render_spectral_analysis(df, sfreq, channels)

    # TAB 3: Statistics
    with tab3:
        _render_statistics(df, channels)

    # TAB 4: About
    with tab4:
        st.markdown(ABOUT_DATASET_MARKDOWN)
else:
    st.warning("Unable to load data. Please check the selected subject and recording type.")

# Footer. The text contains line breaks, so it is written with explicit "\n"
# escapes; a single-quoted literal may not span lines.
st.markdown("---")
st.markdown(
    "\n\nBuilt with Streamlit | EEG Mental Arithmetic Dataset Explorer\n\n",
    unsafe_allow_html=True,
)