"""Gradio app for discovering stimulus channels in EEG recordings and cropping segments.

Accepts ``.fif`` files directly, and EGI ``.mff`` recordings uploaded as a
``.zip`` archive of the ``.mff`` folder. The recording loaded in the "Inspect"
tab is kept in module-level globals so the "Crop & Download" tab can reuse it.
"""

import os
import re
import shutil
import tempfile
import zipfile
from pathlib import Path

import gradio as gr
import mne
import numpy as np
import pandas as pd
import plotly.graph_objects as go

mne.set_log_level('WARNING')

# Global state (okay for local use): written by inspect_eeg, read by crop_eeg.
raw_data = None
events_df_global = None
events_wide_global = None

# Upper-cased channel-name fragments that mark a likely trigger/stimulus channel.
_STIM_NAME_KEYWORDS = ('STI', 'TRIG', 'TTL', 'STIM', 'TRIGGER')
# Channel names made of digits followed by letters are also treated as stimulus channels.
_DIGITS_THEN_LETTERS = re.compile(r'^\d+[a-zA-Z]+$')

# =============================================================================
# UTILITIES
# =============================================================================


def discover_stim_channels(raw):
    """Return the sorted, de-duplicated names of channels that look like stimulus channels.

    A channel is selected when any of these holds:

    - its MNE channel type is ``'stim'`` or ``'misc'``;
    - its upper-cased name contains one of ``_STIM_NAME_KEYWORDS``;
    - its name is digits followed by letters (e.g. ``'1ab'``).
    """
    candidates = set()
    for ch_name, ch_type in zip(raw.ch_names, raw.get_channel_types()):
        if (
            ch_type in ('stim', 'misc')
            or any(kw in ch_name.upper() for kw in _STIM_NAME_KEYWORDS)
            or _DIGITS_THEN_LETTERS.match(ch_name)
        ):
            candidates.add(ch_name)
    return sorted(candidates)


def extract_events_from_channels(raw, stim_channels, threshold=0.5):
    """Detect rising-edge onsets on each stimulus channel.

    Each channel is binarised with ``value > threshold``. An onset is a sample
    where the binarised signal goes from 0 to 1; a channel that is already high
    at sample 0 yields an onset at sample 0. Names not present in ``raw`` are
    skipped.

    Returns a DataFrame with columns ``channel``, ``time_sec`` and ``sample``,
    sorted by ``time_sec``. Ties keep channel order because the sort is stable.
    Returns an empty DataFrame when no onset is found.
    """
    sfreq = raw.info['sfreq']
    frames = []
    for ch in stim_channels:
        if ch not in raw.ch_names:
            continue
        data = raw.get_data(picks=[ch])[0]
        digital = (data > threshold).astype(int)
        # prepend=0 so a signal that starts high registers as an onset at sample 0.
        onset_samples = np.flatnonzero(np.diff(digital, prepend=0) == 1)
        if onset_samples.size:
            frames.append(pd.DataFrame({
                'channel': ch,
                'time_sec': onset_samples / sfreq,
                'sample': onset_samples.astype(int),
            }))

    if not frames:
        return pd.DataFrame()
    df = pd.concat(frames, ignore_index=True)
    return df.sort_values('time_sec', kind='mergesort').reset_index(drop=True)


def events_df_to_wide_format(events_df, burst_gap=0.9):
    """Collapse pulse bursts per channel and pivot the onsets into one column per channel.

    Within each channel, consecutive onsets separated by at most ``burst_gap``
    seconds are chained into one burst, and only the burst's first onset is
    kept. The gap is measured between neighbouring onsets, so a long train of
    closely spaced pulses collapses into a single event.

    Returns a DataFrame with one column per channel (sorted by name). Rows are
    indexed 1..N under the index name ``'Event_Index'``, and shorter columns
    are padded with NaN. Returns an empty DataFrame when ``events_df`` is empty.
    """
    if events_df.empty:
        return pd.DataFrame()

    wide_data = {}
    for ch in sorted(events_df['channel'].unique()):
        times = np.sort(events_df.loc[events_df['channel'] == ch, 'time_sec'].to_numpy())
        # An onset starts a new burst unless it follows the previous onset within burst_gap.
        starts_burst = np.ones(len(times), dtype=bool)
        starts_burst[1:] = np.diff(times) > burst_gap
        wide_data[ch] = pd.Series(times[starts_burst])

    # Series of different lengths align on their 0..n-1 index and are NaN-padded.
    df_wide = pd.DataFrame(wide_data)
    df_wide.index = np.arange(1, len(df_wide) + 1)
    df_wide.index.name = 'Event_Index'
    return df_wide


def plot_events_timeline_plotly(events_df, raw_duration):
    """Plot one marker row per stimulus channel over ``[0, raw_duration]`` seconds.

    Returns None when ``events_df`` is empty.
    """
    if events_df.empty:
        return None

    channels = [str(ch) for ch in sorted(events_df['channel'].unique())]
    fig = go.Figure()
    for row, ch in enumerate(channels):
        ch_events = events_df[events_df['channel'] == ch]
        fig.add_trace(go.Scatter(
            x=ch_events['time_sec'],
            y=[row] * len(ch_events),
            mode='markers',
            name=f"{ch} (n={len(ch_events)})",
            marker=dict(size=8, opacity=0.7),
        ))

    fig.update_layout(
        xaxis_title='Time (seconds)',
        yaxis_title='Channel',
        xaxis=dict(range=[0, raw_duration]),
        # Label each row with its channel name instead of a bare integer position.
        yaxis=dict(tickmode='array', tickvals=list(range(len(channels))), ticktext=channels),
        height=500,
        showlegend=True,
        hovermode='closest',
    )
    return fig


def plot_raw_data_plotly(raw, stim_channels, duration=120.0):
    """Plot up to the first ``duration`` seconds of up to 45 non-stimulus channels.

    Each trace is mean-centred and scaled to unit standard deviation (flat
    traces are left unscaled). Trace ``i`` is then shifted up by ``3 * i`` so
    the traces do not overlap. Returns None when there is no non-stimulus
    channel or the recording has no samples.
    """
    stim_picks = [raw.ch_names.index(ch) for ch in stim_channels if ch in raw.ch_names]
    non_stim_picks = np.setdiff1d(np.arange(len(raw.ch_names)), stim_picks)
    if len(non_stim_picks) == 0:
        return None

    n_channels = min(45, len(non_stim_picks))
    picks = non_stim_picks[:n_channels]

    n_times = len(raw.times)
    if n_times == 0:
        return None
    end_time = min(duration, raw.times[-1])
    # +1 so the sample at end_time itself is included; capped at the recording length.
    end_sample = min(int(end_time * raw.info['sfreq']) + 1, n_times)
    data, times = raw[picks, :end_sample]

    fig = go.Figure()
    for row, (ch_idx, trace) in enumerate(zip(picks, data)):
        trace = trace - np.mean(trace)
        std = np.std(trace)
        if std > 1e-15:
            trace = trace / std
        fig.add_trace(go.Scatter(
            x=times,
            y=trace + row * 3,
            mode='lines',
            name=raw.ch_names[ch_idx],
            line=dict(width=0.8),
            showlegend=False,
        ))

    fig.update_layout(
        xaxis_title='Time (s)',
        yaxis_title='Channels (normalized + offset)',
        height=min(600, 100 + 120 * n_channels),
        hovermode='x unified',
    )
    return fig


def validate_and_clip_crop_regions(crop_regions, max_time):
    """Drop or clip ``(start, end)`` regions so each lies within ``[0, max_time]``.

    Returns ``(validated, warnings)``, where:

    - ``validated`` holds the usable regions, in input order;
    - ``warnings`` holds user-facing messages for every region that was
      skipped (non-finite values, start after the recording, or no remaining
      duration) or clipped.
    """
    validated = []
    warnings = []
    for start, end in crop_regions:
        if not (np.isfinite(start) and np.isfinite(end)):
            warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) — start and end must be finite numbers.")
            continue
        if start > max_time:
            warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) — starts after recording ends.")
            continue
        # mne.io.Raw.crop rejects negative tmin, so clip the start to the recording onset.
        clipped_start = max(start, 0.0)
        clipped_end = min(end, max_time)
        if clipped_end <= clipped_start:
            warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) — no valid duration.")
            continue
        if clipped_start != start or clipped_end != end:
            warnings.append(
                f"⚠️ Region ({start:.2f}, {end:.2f}) clipped to ({clipped_start:.2f}, {clipped_end:.2f})"
            )
        validated.append((clipped_start, clipped_end))
    return validated, warnings


def _find_mff_dir(root):
    """Return the shallowest ``*.mff`` directory under ``root`` (ignoring ``__MACOSX``), or None."""
    candidates = [
        p for p in Path(root).rglob('*')
        if p.is_dir() and p.suffix.lower() == '.mff' and '__MACOSX' not in p.parts
    ]
    if not candidates:
        return None
    return min(candidates, key=lambda p: (len(p.parts), str(p)))


def _parse_crop_regions(crop_regions_str):
    """Parse ``"s1,e1; s2,e2"`` into ``[(s1, e1), (s2, e2)]``.

    Returns ``(regions, warnings)``. Malformed entries are skipped and reported
    in ``warnings``; blank entries are ignored silently.
    """
    regions = []
    warnings = []
    for region in (crop_regions_str or '').split(';'):
        region = region.strip()
        if not region:
            continue
        parts = [x.strip() for x in region.split(',')]
        if len(parts) != 2:
            warnings.append(f"⚠️ Skipping '{region}' — expected 'start,end'.")
            continue
        try:
            regions.append((float(parts[0]), float(parts[1])))
        except ValueError:
            warnings.append(f"⚠️ Skipping '{region}' — start and end must be numbers.")
    return regions, warnings


# =============================================================================
# GRADIO FUNCTIONS
# =============================================================================

def inspect_eeg(file_obj):
    """Load an uploaded recording, detect stimulus events, and build the Inspect-tab outputs.

    Returns ``(timeline_plot, summary_text, events_wide, raw_plot)``, matching
    the ``outputs`` of the Inspect button. On any failure the message goes into
    the summary slot and the other slots are None.

    On success, the module globals ``raw_data``, ``events_df_global`` and
    ``events_wide_global`` are replaced. They are cleared before each load
    attempt, so a failed load never leaves a stale recording for ``crop_eeg``.
    """
    global raw_data, events_df_global, events_wide_global
    if file_obj is None:
        return None, "Please upload an EEG file", None, None

    raw_data = None
    events_df_global = None
    events_wide_global = None

    temp_extract_dir = None
    try:
        # Gradio may pass a file wrapper exposing .name or a plain path string.
        file_path = str(getattr(file_obj, 'name', file_obj))
        stem = Path(file_path).stem
        suffix = Path(file_path).suffix.lower()

        if suffix == '.zip':
            temp_extract_dir = tempfile.mkdtemp()
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(temp_extract_dir)
            mff_dir = _find_mff_dir(temp_extract_dir)
            if mff_dir is None:
                return None, "ZIP must contain a .mff folder (e.g., 'data.mff/')", None, None
            # preload=True reads all samples into memory, so the extracted
            # folder can be deleted in the finally clause below.
            raw = mne.io.read_raw_egi(str(mff_dir), preload=True, verbose=False)
        elif suffix == '.fif':
            raw = mne.io.read_raw_fif(file_path, preload=True)
        else:
            return None, "Please upload a .fif file or a .zip containing a .mff folder", None, None

        stim_channels = discover_stim_channels(raw)
        events_df = extract_events_from_channels(raw, stim_channels)
        events_wide = events_df_to_wide_format(events_df)
        if not events_wide.empty:
            # Expose Event_Index as a regular column for display and CSV export.
            events_wide = events_wide.reset_index()

        n_stim = len(stim_channels)
        n_data = len(raw.ch_names) - n_stim
        summary_text = f"✅ Loaded: {stem}\n"
        summary_text += f"Duration: {raw.times[-1]:.2f}s | SFreq: {raw.info['sfreq']:.1f} Hz\n"
        summary_text += f"Channels: {n_data} data + {n_stim} stim = {len(raw.ch_names)} total\n\n"
        summary_text += f"🔍 Stimulus channels found: {stim_channels}\n\n"

        raw_plot = plot_raw_data_plotly(raw, stim_channels, duration=120.0)
        timeline_plot = plot_events_timeline_plotly(events_df, raw.times[-1])

        raw_data = raw
        events_df_global = events_df
        events_wide_global = events_wide
        return timeline_plot, summary_text, events_wide, raw_plot

    except Exception as e:  # UI boundary: report any load failure to the user
        return None, f"Error loading file: {str(e)}", None, None
    finally:
        if temp_extract_dir is not None:
            shutil.rmtree(temp_extract_dir, ignore_errors=True)


def crop_eeg(crop_regions_str):
    """Crop the loaded recording to the requested regions and write downloadable files.

    ``crop_regions_str`` uses the format ``"start1,end1; start2,end2"``, in
    seconds. A single region is cropped directly; several regions are cropped
    separately and concatenated with ``mne.concatenate_raws``.

    Returns ``(output_files, status)``:

    - ``output_files`` holds the cropped FIF path, plus the events CSV path
      when events exist. The CSV times refer to the original, uncropped
      recording.
    - ``status`` is a user-facing message.
    """
    global raw_data, events_wide_global
    if raw_data is None:
        return [], "Please inspect an EEG file first"

    def _with_warnings(message, warns):
        return message + ("\n\nWarnings:\n" + "\n".join(warns) if warns else "")

    try:
        crop_regions, warnings = _parse_crop_regions(crop_regions_str)
        if not crop_regions:
            return [], _with_warnings("No valid crop regions specified", warnings)

        max_time = raw_data.times[-1]
        crop_regions, clip_warnings = validate_and_clip_crop_regions(crop_regions, max_time)
        warnings.extend(clip_warnings)
        if not crop_regions:
            return [], _with_warnings("No valid crop regions after validation", warnings)

        if len(crop_regions) == 1:
            start, end = crop_regions[0]
            raw_crop = raw_data.copy().crop(tmin=start, tmax=end)
            status = f"✂️ Applied single crop: {start:.2f}s → {end:.2f}s\n"
        else:
            segments = [raw_data.copy().crop(tmin=s, tmax=e) for s, e in crop_regions]
            raw_crop = mne.concatenate_raws(segments)
            status = f"✂️ Concatenated {len(crop_regions)} segments\n"

        status += f"Final cropped duration: {raw_crop.times[-1]:.2f}s\n"

        temp_dir = tempfile.mkdtemp()
        output_files = []

        eeg_output = os.path.join(temp_dir, "cropped_eeg.fif")
        raw_crop.save(eeg_output, overwrite=True)
        output_files.append(eeg_output)

        if events_wide_global is not None and not events_wide_global.empty:
            csv_output = os.path.join(temp_dir, "events_wide.csv")
            # Event_Index is already a column; do not add pandas' positional index.
            events_wide_global.to_csv(csv_output, index=False)
            output_files.append(csv_output)
            status += "Events CSV times are relative to the original (uncropped) recording.\n"

        if warnings:
            status += "\nWarnings:\n" + "\n".join(warnings)

        return output_files, status

    except Exception as e:  # UI boundary: report any crop failure to the user
        return [], f"Error during cropping: {str(e)}"


# =============================================================================
# GRADIO APP
# =============================================================================

with gr.Blocks(theme=gr.themes.Base(), title="EEG Stimulus Discovery & Interactive Cropping") as demo:
    gr.Markdown("# EEG Stimulus Discovery & Interactive Cropping")
    gr.Markdown("""
    Upload EEG data to discover stimulus channels, visualize events, and crop clean segments.

    - **`.fif`**: Upload directly
    - **`.mff`**: ZIP the entire `.mff` folder first, then upload the `.zip` file
    """)

    with gr.Tabs():
        with gr.Tab("🔍 Inspect"):
            gr.Markdown("### Step 1: Load and Inspect EEG Data")
            with gr.Row():
                with gr.Column(scale=1):
                    eeg_input = gr.File(
                        label="Upload .fif directly, or .zip containing a .mff folder",
                        file_types=[".fif", ".zip"],
                    )
                    inspect_btn = gr.Button("🔎 Inspect EEG", variant="primary", size="lg")
                with gr.Column(scale=2):
                    summary = gr.Textbox(label="Summary", lines=5, interactive=False)

            gr.Markdown("### Stimulus Timeline")
            timeline_plot = gr.Plot(label="Stimulus Events")

            gr.Markdown("### Event Times")
            events_csv = gr.Dataframe(label="CSV", interactive=False)

            gr.Markdown("### Raw Data Preview (first 120s)")
            raw_plot = gr.Plot(label="Raw Data")

            inspect_btn.click(
                fn=inspect_eeg,
                inputs=eeg_input,
                outputs=[timeline_plot, summary, events_csv, raw_plot],
            )

        with gr.Tab("✂️ Crop & Download"):
            gr.Markdown("### Step 2: Define Crop Regions and Download")
            gr.Markdown("""
            **Instructions:**

            - Use the event timeline and raw preview to choose clean segments
            - Format: `start1,end1; start2,end2` (e.g., `25.0,29.0; 30.0,34.0`)
            - Times in seconds
            """)

            crop_regions = gr.Textbox(
                value="25.0, 29.0",
                label="Crop Regions (seconds)",
                lines=2,
            )
            crop_btn = gr.Button("✂️ Crop & Generate Downloads", variant="primary", size="lg")
            crop_status = gr.Textbox(label="Crop Status", lines=5, interactive=False)
            output_files = gr.Files(label="📥 Download Cropped EEG + Events CSV")

            crop_btn.click(
                fn=crop_eeg,
                inputs=crop_regions,
                outputs=[output_files, crop_status],
            )

if __name__ == "__main__":
    demo.launch()