# EEG Stimulus Discovery & Interactive Cropping — Gradio app (Hugging Face Space)
import os
import re
import shutil
import tempfile
import zipfile
from pathlib import Path

import gradio as gr
import mne
import numpy as np
import pandas as pd
import plotly.graph_objects as go
# Silence MNE's INFO chatter; warnings and errors still surface.
mne.set_log_level('WARNING')
# Global state (okay for local use)
raw_data = None            # last-loaded mne Raw object (set by inspect_eeg, read by crop_eeg)
events_df_global = None    # long-format events DataFrame from the last inspection
events_wide_global = None  # wide-format events table (one column per stim channel)
| # ============================================================================= | |
| # UTILITIES | |
| # ============================================================================= | |
def discover_stim_channels(raw):
    """Return a sorted list of likely stimulus/trigger channel names.

    A channel qualifies when its MNE type is 'stim' or 'misc', its name
    contains a common trigger keyword, or its name looks like a digit
    prefix followed by letters (e.g. '3abc').
    """
    keywords = ('STI', 'TRIG', 'TTL', 'STIM', 'TRIGGER')
    digit_label = re.compile(r'^\d+[a-zA-Z]+$')

    def looks_like_stim(name, kind):
        # One-line purpose: decide whether a single channel is a stim candidate.
        if kind in ('stim', 'misc'):
            return True
        upper = name.upper()
        if any(kw in upper for kw in keywords):
            return True
        return digit_label.match(name) is not None

    found = {
        name
        for name, kind in zip(raw.ch_names, raw.get_channel_types())
        if looks_like_stim(name, kind)
    }
    return sorted(found)
def extract_events_from_channels(raw, stim_channels):
    """Detect rising edges on each stimulus channel.

    Thresholds each channel at 0.5, finds 0 -> 1 transitions, and returns
    a DataFrame with columns ['channel', 'time_sec', 'sample'] sorted by
    onset time. Returns an empty DataFrame when nothing fires.
    """
    sfreq = raw.info['sfreq']
    records = []
    for name in stim_channels:
        if name not in raw.ch_names:
            continue  # tolerate stale channel lists
        trace = raw.get_data(picks=[name])[0]
        high = (trace > 0.5).astype(int)
        # prepend=0 makes an initially-high channel register an onset at sample 0
        rising = np.where(np.diff(high, prepend=0) == 1)[0]
        records.extend(
            {'channel': name, 'time_sec': samp / sfreq, 'sample': int(samp)}
            for samp in rising
        )
    if not records:
        return pd.DataFrame()
    return pd.DataFrame(records).sort_values('time_sec').reset_index(drop=True)
def events_df_to_wide_format(events_df):
    """Collapse event bursts and pivot to one column per channel.

    Onsets on the same channel whose consecutive gaps are each <= 0.9 s
    are chained into a single burst; only the first onset of a burst is
    kept. The result has one row per burst index (1-based, index named
    'Event_Index') and one column per channel, with shorter columns
    padded with NaN. An empty input yields an empty DataFrame.
    """
    if events_df.empty:
        return pd.DataFrame()

    # Pass 1: keep only the first onset of each burst, per channel.
    burst_rows = []
    for name in sorted(events_df['channel'].unique()):
        onsets = (events_df.loc[events_df['channel'] == name, 'time_sec']
                  .sort_values().to_numpy())
        idx = 0
        while idx < len(onsets):
            burst_rows.append({'channel': name, 'time_sec': onsets[idx]})
            # Skip every onset chained within 0.9 s of its predecessor.
            nxt = idx
            while nxt + 1 < len(onsets) and onsets[nxt + 1] - onsets[nxt] <= 0.9:
                nxt += 1
            idx = nxt + 1

    bursts = pd.DataFrame(burst_rows)

    # Pass 2: pivot to wide format, NaN-padding shorter channels.
    columns = {}
    for name in sorted(bursts['channel'].unique()):
        columns[name] = bursts.loc[bursts['channel'] == name, 'time_sec'].to_numpy()
    longest = max(len(vals) for vals in columns.values())
    for name, vals in columns.items():
        if len(vals) < longest:
            columns[name] = np.pad(vals, (0, longest - len(vals)),
                                   constant_values=np.nan)

    wide = pd.DataFrame(columns)
    wide.index = np.arange(1, len(wide) + 1)
    wide.index.name = 'Event_Index'
    return wide
def plot_events_timeline_plotly(events_df, raw_duration):
    """Build a scatter timeline of stimulus onsets, one row per channel.

    Returns None when there are no events; otherwise a Plotly figure
    whose x-axis spans [0, raw_duration] seconds and whose legend shows
    the per-channel event counts.
    """
    if events_df.empty:
        return None
    fig = go.Figure()
    for row, name in enumerate(sorted(events_df['channel'].unique())):
        subset = events_df[events_df['channel'] == name]
        trace = go.Scatter(
            x=subset['time_sec'],
            y=[row] * len(subset),  # constant y per channel stacks the rows
            mode='markers',
            name=f"{name} (n={len(subset)})",
            marker=dict(size=8, opacity=0.7),
        )
        fig.add_trace(trace)
    fig.update_layout(
        xaxis_title='Time (seconds)',
        yaxis_title='Channel',
        xaxis=dict(range=[0, raw_duration]),
        height=500,
        showlegend=True,
        hovermode='closest',
    )
    return fig
def plot_raw_data_plotly(raw, stim_channels, duration=120.0):
    """Preview up to 45 non-stimulus channels, normalized and stacked.

    Each trace is mean-centered, scaled to unit spread when its std is
    non-degenerate, and vertically offset so channels do not overlap.
    Returns None when every channel is a stimulus channel.
    """
    stim_idx = [raw.ch_names.index(ch) for ch in stim_channels if ch in raw.ch_names]
    data_idx = np.setdiff1d(np.arange(len(raw.ch_names)), stim_idx)
    if len(data_idx) == 0:
        return None

    shown = data_idx[:min(45, len(data_idx))]
    # Guard against a zero-length recording before indexing raw.times[-1].
    total = raw.times[-1] if len(raw.times) > 0 else 0.0
    stop_sample = int(min(duration, total) * raw.info['sfreq'])
    data, times = raw[shown, :stop_sample]

    fig = go.Figure()
    for offset, ch_idx in enumerate(shown):
        trace = data[offset] - np.mean(data[offset])
        spread = np.std(trace)
        if spread > 1e-15:  # skip normalization for flat channels
            trace = trace / spread
        fig.add_trace(go.Scatter(
            x=times,
            y=trace + offset * 3,  # stack channels 3 units apart
            mode='lines',
            name=raw.ch_names[ch_idx],
            line=dict(width=0.8),
            showlegend=False,
        ))
    fig.update_layout(
        xaxis_title='Time (s)',
        yaxis_title='Channels (normalized + offset)',
        height=min(600, 100 + 120 * len(shown)),
        hovermode='x unified',
    )
    return fig
def validate_and_clip_crop_regions(crop_regions, max_time):
    """Validate (start, end) crop regions against the recording length.

    Parameters
    ----------
    crop_regions : list of (float, float)
        Requested (start, end) times in seconds.
    max_time : float
        Recording duration in seconds.

    Returns
    -------
    (validated, warnings) : (list of (float, float), list of str)
        Surviving regions — starts clamped to 0, ends clipped to
        max_time — plus human-readable warnings for every adjustment or
        skipped region. Zero- or negative-duration regions are dropped.
    """
    validated = []
    warnings = []
    for start, end in crop_regions:
        if start > max_time:
            warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) — starts after recording ends.")
            continue
        # Clamp negative starts to 0 so mne.Raw.crop() does not raise on tmin < 0.
        clipped_start = max(start, 0.0)
        if clipped_start > start:
            warnings.append(f"⚠️ Region ({start:.2f}, {end:.2f}) start clamped to 0.00")
        clipped_end = min(end, max_time)
        if clipped_end <= clipped_start:
            warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) — no valid duration.")
            continue
        if end > max_time:
            warnings.append(f"⚠️ Region ({start:.2f}, {end:.2f}) clipped to ({clipped_start:.2f}, {clipped_end:.2f})")
        validated.append((clipped_start, clipped_end))
    return validated, warnings
| # ============================================================================= | |
| # GRADIO FUNCTIONS | |
| # ============================================================================= | |
def inspect_eeg(file_obj):
    """Load an uploaded EEG file and summarize its stimulus events.

    Accepts a .fif file or a .zip archive containing a .mff folder.
    Populates the module-level globals read by crop_eeg() and returns a
    (timeline_plot, summary_text, events_wide, raw_plot) tuple matching
    the Gradio outputs order. On failure returns (None, message, None,
    None) so the message lands in the summary textbox — the original
    code put error messages in the 3rd slot, which is the Dataframe
    component, not the summary.
    """
    global raw_data, events_df_global, events_wide_global
    if file_obj is None:
        return None, "Please upload an EEG file", None, None
    temp_extract_dir = None
    try:
        file_path = file_obj.name
        stem = Path(file_path).stem
        suffixes = Path(file_path).suffixes
        if len(suffixes) >= 2 and suffixes[-1] == '.zip':
            # Extract the archive and locate the .mff directory inside it.
            temp_extract_dir = tempfile.mkdtemp()
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_extract_dir)
            mff_dir = next(
                (item for item in Path(temp_extract_dir).iterdir()
                 if item.is_dir() and item.suffix == '.mff'),
                None,
            )
            if mff_dir is None:
                return None, "ZIP must contain a .mff folder (e.g., 'data.mff/')", None, None
            raw = mne.io.read_raw_egi(str(mff_dir), preload=True, verbose=False)
        elif Path(file_path).suffix.lower() == '.fif':
            raw = mne.io.read_raw_fif(file_path, preload=True)
        else:
            return None, "Please upload a .fif file or a .zip containing a .mff folder", None, None

        raw_data = raw
        stim_channels = discover_stim_channels(raw)
        events_df = extract_events_from_channels(raw, stim_channels)
        events_df_global = events_df
        events_wide = events_df_to_wide_format(events_df) if not events_df.empty else pd.DataFrame()
        if not events_wide.empty:
            # Surface Event_Index as a regular column for the Gradio table.
            events_wide = events_wide.reset_index()
        events_wide_global = events_wide

        n_stim = len(stim_channels)
        n_data = len(raw.ch_names) - n_stim
        summary_text = f"✅ Loaded: {stem}\n"
        summary_text += f"Duration: {raw.times[-1]:.2f}s | SFreq: {raw.info['sfreq']:.1f} Hz\n"
        summary_text += f"Channels: {n_data} data + {n_stim} stim = {len(raw.ch_names)} total\n\n"
        summary_text += f"🔍 Stimulus channels found: {stim_channels}\n\n"

        raw_plot = plot_raw_data_plotly(raw, stim_channels, duration=120.0)
        timeline_plot = plot_events_timeline_plotly(events_df, raw.times[-1])
        return timeline_plot, summary_text, events_wide, raw_plot
    except Exception as e:
        return None, f"Error loading file: {str(e)}", None, None
    finally:
        # Data is preloaded into memory, so the extracted archive can be
        # removed — the original leaked one temp dir per ZIP upload.
        if temp_extract_dir is not None:
            shutil.rmtree(temp_extract_dir, ignore_errors=True)
def crop_eeg(crop_regions_str):
    """Crop the globally loaded recording to user-specified regions.

    Parses a 'start1,end1; start2,end2' specification (seconds),
    validates the regions against the recording length, crops (and
    concatenates, when several) the loaded Raw, and writes the result
    plus the wide events table to a fresh temp directory. Returns a
    (list_of_output_paths, status_message) pair for Gradio.
    """
    global raw_data, events_wide_global
    if raw_data is None:
        return [], "Please inspect an EEG file first"
    try:
        # Parse the semicolon-separated "start,end" pairs.
        regions = []
        for chunk in crop_regions_str.split(';'):
            chunk = chunk.strip()
            if not chunk:
                continue
            numbers = [float(tok.strip()) for tok in chunk.split(',')]
            if len(numbers) == 2:
                regions.append((numbers[0], numbers[1]))
        if not regions:
            return [], "No valid crop regions specified"

        max_time = raw_data.times[-1]
        regions, warnings = validate_and_clip_crop_regions(regions, max_time)
        if not regions:
            return [], "No valid crop regions after validation"

        if len(regions) == 1:
            t0, t1 = regions[0]
            raw_crop = raw_data.copy().crop(tmin=t0, tmax=t1)
            status = f"✂️ Applied single crop: {t0:.2f}s → {t1:.2f}s\n"
        else:
            pieces = [raw_data.copy().crop(tmin=t0, tmax=t1) for t0, t1 in regions]
            raw_crop = mne.concatenate_raws(pieces)
            status = f"✂️ Concatenated {len(regions)} segments\n"
        status += f"Final cropped duration: {raw_crop.times[-1]:.2f}s\n"
        if warnings:
            status += "\nWarnings:\n" + "\n".join(warnings)

        out_dir = tempfile.mkdtemp()
        eeg_path = os.path.join(out_dir, "cropped_eeg.fif")
        raw_crop.save(eeg_path, overwrite=True)
        outputs = [eeg_path]
        if events_wide_global is not None and not events_wide_global.empty:
            csv_path = os.path.join(out_dir, "events_wide.csv")
            events_wide_global.to_csv(csv_path)
            outputs.append(csv_path)
        return outputs, status
    except Exception as e:
        return [], f"Error during cropping: {str(e)}"
| # ============================================================================= | |
| # GRADIO APP | |
| # ============================================================================= | |
# Two-tab UI: Tab 1 inspects an upload and fills the module globals;
# Tab 2 crops the globally loaded recording and offers downloads.
with gr.Blocks(theme=gr.themes.Base(), title="EEG Stimulus Discovery & Interactive Cropping") as demo:
    gr.Markdown("# EEG Stimulus Discovery & Interactive Cropping")
    gr.Markdown("""
    Upload EEG data to discover stimulus channels, visualize events, and crop clean segments.
    - **`.fif`**: Upload directly
    - **`.mff`**: ZIP the entire `.mff` folder first, then upload the `.zip` file
    """)
    with gr.Tabs():
        with gr.Tab("π Inspect"):
            gr.Markdown("### Step 1: Load and Inspect EEG Data")
            with gr.Row():
                with gr.Column(scale=1):
                    # Left column: file upload + trigger button.
                    eeg_input = gr.File(
                        label="Upload .fif directly, or .zip containing a .mff folder",
                        file_types=[".fif", ".zip"]
                    )
                    inspect_btn = gr.Button("π Inspect EEG", variant="primary", size="lg")
                with gr.Column(scale=2):
                    # Right column: text summary of the loaded recording.
                    summary = gr.Textbox(label="Summary", lines=5, interactive=False)
            gr.Markdown("### Stimulus Timeline")
            timeline_plot = gr.Plot(label="Stimulus Events")
            gr.Markdown("### Event Times")
            events_csv = gr.Dataframe(label="CSV", interactive=False)
            gr.Markdown("### Raw Data Preview (first 120s)")
            raw_plot = gr.Plot(label="Raw Data")
            # Outputs must match the order of inspect_eeg's return tuple.
            inspect_btn.click(
                fn=inspect_eeg,
                inputs=eeg_input,
                outputs=[timeline_plot, summary, events_csv, raw_plot]
            )
        with gr.Tab("βοΈ Crop & Download"):
            gr.Markdown("### Step 2: Define Crop Regions and Download")
            gr.Markdown("""
            **Instructions:**
            - Use the event timeline and raw preview to choose clean segments
            - Format: `start1,end1; start2,end2` (e.g., `25.0,29.0; 30.0,34.0`)
            - Times in seconds
            """)
            crop_regions = gr.Textbox(
                value="25.0, 29.0",
                label="Crop Regions (seconds)",
                lines=2
            )
            crop_btn = gr.Button("βοΈ Crop & Generate Downloads", variant="primary", size="lg")
            crop_status = gr.Textbox(label="Crop Status", lines=5, interactive=False)
            output_files = gr.Files(label="π₯ Download Cropped EEG + Events CSV")
            # crop_eeg returns (files, status); outputs below match that order.
            crop_btn.click(
                fn=crop_eeg,
                inputs=crop_regions,
                outputs=[output_files, crop_status]
            )

if __name__ == "__main__":
    demo.launch()