# stim_trials / app.py — Gradio app for EEG stimulus discovery and interactive cropping.
# Last commit: 325bf92 "Update app.py" by JayLacoma (verified).
import os
import re
import mne
import numpy as np
import pandas as pd
import tempfile
import zipfile
from pathlib import Path
import gradio as gr
import plotly.graph_objects as go
# Only surface MNE messages at WARNING level or above.
mne.set_log_level(verbose='WARNING')

# Process-wide session state: `inspect_eeg` overwrites all three on every
# upload and `crop_eeg` reads `raw_data` / `events_wide_global`. Because there is
# a single shared slot, this is only appropriate for local, single-user use.
raw_data = events_df_global = events_wide_global = None
# =============================================================================
# UTILITIES
# =============================================================================
def discover_stim_channels(raw):
    """Return the sorted, de-duplicated names of channels that look like trigger lines.

    A channel qualifies when any of these holds:
      * its channel type is ``'stim'`` or ``'misc'``;
      * its upper-cased name contains ``STI``, ``TRIG``, ``TTL``, ``STIM`` or ``TRIGGER``;
      * its whole name is digits followed by letters (e.g. ``'1a'``) — presumably
        a device-specific event-channel naming scheme; TODO confirm.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording whose ``ch_names`` and ``get_channel_types()`` are inspected.

    Returns
    -------
    list of str
        Matching channel names in ascending order, each listed once.
    """
    keywords = ('STI', 'TRIG', 'TTL', 'STIM', 'TRIGGER')
    digits_then_letters = re.compile(r'^\d+[a-zA-Z]+$')

    def looks_like_stim(name, kind):
        return (
            kind in ('stim', 'misc')
            or any(word in name.upper() for word in keywords)
            or digits_then_letters.match(name) is not None
        )

    found = {
        name
        for name, kind in zip(raw.ch_names, raw.get_channel_types())
        if looks_like_stim(name, kind)
    }
    return sorted(found)
def extract_events_from_channels(raw, stim_channels):
    """Detect rising-edge trigger onsets on the given channels.

    Each channel is binarised with ``value > 0.5``. An onset is every sample at
    which that binary signal steps from 0 to 1, and a channel that is already
    high at the very first sample reports an onset at sample 0. Names in
    ``stim_channels`` that are absent from ``raw.ch_names`` are skipped.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to read samples and ``info['sfreq']`` from.
    stim_channels : iterable of str
        Channel names to scan.

    Returns
    -------
    pandas.DataFrame
        Columns ``channel``, ``time_sec`` (onset sample / sfreq) and ``sample``
        (int), sorted by ``time_sec`` with a fresh 0-based index. If no onsets
        are found, an empty DataFrame without columns is returned.
    """
    sampling_rate = raw.info['sfreq']
    rows = []
    for name in (c for c in stim_channels if c in raw.ch_names):
        trace = raw.get_data(picks=[name])[0]
        binary = (trace > 0.5).astype(int)
        # prepend=0 makes a signal that starts high count as an onset at sample 0
        edges = np.diff(binary, prepend=0)
        rising = np.where(edges == 1)[0]
        rows.extend(
            {'channel': name, 'time_sec': idx / sampling_rate, 'sample': int(idx)}
            for idx in rising
        )
    if len(rows) == 0:
        return pd.DataFrame()
    table = pd.DataFrame(rows)
    table = table.sort_values('time_sec')
    return table.reset_index(drop=True)
def events_df_to_wide_format(events_df, burst_gap=0.9):
    """Collapse trigger bursts per channel and pivot onsets into one column per channel.

    Within each channel the onset times are sorted and grouped into bursts. An
    onset whose distance to the *previous* onset is ``<= burst_gap`` seconds
    belongs to the same burst, so a chain of closely spaced pulses collapses
    into one event. Only the first onset of each burst is kept.

    Parameters
    ----------
    events_df : pandas.DataFrame
        Long-format events with at least ``'channel'`` and ``'time_sec'``
        columns (as produced by ``extract_events_from_channels``).
    burst_gap : float, default 0.9
        Maximum gap in seconds between consecutive onsets of one burst.

    Returns
    -------
    pandas.DataFrame
        One float column per channel (sorted by name) and one row per burst,
        indexed by a 1-based ``Event_Index``. Channels with fewer bursts are
        padded with NaN. An empty DataFrame is returned when ``events_df`` is
        empty.
    """
    if events_df.empty:
        return pd.DataFrame()

    burst_starts = {}
    for ch in sorted(events_df['channel'].unique()):
        times = np.sort(events_df.loc[events_df['channel'] == ch, 'time_sec'].to_numpy(dtype=float))
        # A burst starts at the first onset and wherever the gap to the
        # preceding onset exceeds burst_gap (chained gaps keep extending a burst).
        is_start = np.concatenate(([True], np.diff(times) > burst_gap))
        burst_starts[ch] = times[is_start]

    n_rows = max(len(times) for times in burst_starts.values())
    columns = {}
    for ch, times in burst_starts.items():
        column = np.full(n_rows, np.nan)
        column[:len(times)] = times
        columns[ch] = column

    df_wide = pd.DataFrame(columns)
    df_wide.index = np.arange(1, n_rows + 1)
    df_wide.index.name = 'Event_Index'
    return df_wide
def plot_events_timeline_plotly(events_df, raw_duration):
    """Build a Plotly timeline with one horizontal row of markers per stimulus channel.

    Parameters
    ----------
    events_df : pandas.DataFrame
        Long-format events with ``'channel'`` and ``'time_sec'`` columns.
    raw_duration : float
        Upper bound of the x-axis in seconds (``inspect_eeg`` passes the time of
        the recording's last sample).

    Returns
    -------
    plotly.graph_objects.Figure or None
        ``None`` when there are no events to draw.
    """
    if events_df.empty:
        return None
    channels = sorted(events_df['channel'].unique())
    fig = go.Figure()
    for row, ch in enumerate(channels):
        ch_events = events_df[events_df['channel'] == ch]
        fig.add_trace(go.Scatter(
            x=ch_events['time_sec'],
            y=[row] * len(ch_events),
            mode='markers',
            name=f"{ch} (n={len(ch_events)})",
            marker=dict(size=8, opacity=0.7),
        ))
    fig.update_layout(
        xaxis_title='Time (seconds)',
        yaxis_title='Channel',
        xaxis=dict(range=[0, raw_duration]),
        # Label each marker row with its channel name rather than a bare row index.
        yaxis=dict(
            tickmode='array',
            tickvals=list(range(len(channels))),
            ticktext=[str(ch) for ch in channels],
        ),
        height=500,
        showlegend=True,
        hovermode='closest',
    )
    return fig
def plot_raw_data_plotly(raw, stim_channels, duration=120.0, max_channels=45):
    """Plot the first ``duration`` seconds of up to ``max_channels`` non-stimulus channels.

    Each trace is mean-centred and scaled to unit standard deviation. The
    scaling is skipped for flat channels whose std is at most 1e-15. Traces are
    then shifted up by 3 units per channel so they stack vertically.

    Parameters
    ----------
    raw : mne.io.Raw
        Recording to preview (``inspect_eeg`` passes a preloaded one).
    stim_channels : list of str
        Channel names to leave out of the preview; names not in ``raw`` are ignored.
    duration : float, default 120.0
        Preview length in seconds, clipped to the recording length.
    max_channels : int, default 45
        Maximum number of non-stimulus channels to draw, in channel-index order.

    Returns
    -------
    plotly.graph_objects.Figure or None
        ``None`` if there are no non-stimulus channels or no samples to show.
    """
    stim_picks = [raw.ch_names.index(ch) for ch in stim_channels if ch in raw.ch_names]
    non_stim_picks = np.setdiff1d(np.arange(len(raw.ch_names)), stim_picks)
    if len(non_stim_picks) == 0:
        return None
    picks = non_stim_picks[:max_channels]
    n_channels = len(picks)

    # Sample count covering `duration` seconds, clipped to the recording. Rounding
    # (instead of truncating end_time * sfreq) avoids float error dropping a sample,
    # and clipping to len(raw.times) keeps the final sample of short recordings.
    n_samples = min(len(raw.times), int(round(duration * raw.info['sfreq'])))
    if n_samples <= 0:
        # Empty recording or non-positive duration: nothing to plot (and avoids
        # taking the mean of an empty array).
        return None
    data, times = raw[picks, :n_samples]

    offset = 3
    fig = go.Figure()
    for i, ch_idx in enumerate(picks):
        ch_data = data[i] - np.mean(data[i])
        std = np.std(ch_data)
        if std > 1e-15:
            ch_data = ch_data / std
        fig.add_trace(go.Scatter(
            x=times,
            y=ch_data + i * offset,
            mode='lines',
            name=raw.ch_names[ch_idx],
            line=dict(width=0.8),
            showlegend=False,
        ))
    fig.update_layout(
        xaxis_title='Time (s)',
        yaxis_title='Channels (normalized + offset)',
        # Put each channel's name at its vertical offset so traces are identifiable.
        yaxis=dict(
            tickmode='array',
            tickvals=[i * offset for i in range(n_channels)],
            ticktext=[raw.ch_names[idx] for idx in picks],
            tickfont=dict(size=9),
        ),
        height=min(600, 100 + 120 * n_channels),
        hovermode='x unified',
    )
    return fig
def validate_and_clip_crop_regions(crop_regions, max_time):
validated = []
warnings = []
for start, end in crop_regions:
if start > max_time:
warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) β€” starts after recording ends.")
continue
clipped_end = min(end, max_time)
if clipped_end <= start:
warnings.append(f"⚠️ Skipping region ({start:.2f}, {end:.2f}) β€” no valid duration.")
continue
if end > max_time:
warnings.append(f"⚠️ Region ({start:.2f}, {end:.2f}) clipped to ({start:.2f}, {clipped_end:.2f})")
validated.append((start, clipped_end))
return validated, warnings
# =============================================================================
# GRADIO FUNCTIONS
# =============================================================================
def inspect_eeg(file_obj):
global raw_data, events_df_global, events_wide_global
if file_obj is None:
return None, None, "Please upload an EEG file", None
temp_extract_dir = None
try:
file_path = file_obj.name
stem = Path(file_path).stem
suffixes = Path(file_path).suffixes
load_path = file_path
if len(suffixes) >= 2 and suffixes[-1] == '.zip':
temp_extract_dir = tempfile.mkdtemp()
with zipfile.ZipFile(file_path, 'r') as zip_ref:
zip_ref.extractall(temp_extract_dir)
mff_dir = None
for item in Path(temp_extract_dir).iterdir():
if item.is_dir() and item.suffix == '.mff':
mff_dir = item
break
if mff_dir is None:
return None, None, "ZIP must contain a .mff folder (e.g., 'data.mff/')", None
load_path = str(mff_dir)
raw = mne.io.read_raw_egi(load_path, preload=True, verbose=False)
elif Path(file_path).suffix.lower() == '.fif':
raw = mne.io.read_raw_fif(file_path, preload=True)
else:
return None, None, "Please upload a .fif file or a .zip containing a .mff folder", None
raw_data = raw
stim_channels = discover_stim_channels(raw)
events_df = extract_events_from_channels(raw, stim_channels)
events_df_global = events_df
events_wide = events_df_to_wide_format(events_df) if not events_df.empty else pd.DataFrame()
if not events_wide.empty:
events_wide = events_wide.reset_index()
events_wide_global = events_wide
n_stim = len(stim_channels)
n_data = len(raw.ch_names) - n_stim
summary_text = f"βœ… Loaded: {stem}\n"
summary_text += f"Duration: {raw.times[-1]:.2f}s | SFreq: {raw.info['sfreq']:.1f} Hz\n"
summary_text += f"Channels: {n_data} data + {n_stim} stim = {len(raw.ch_names)} total\n\n"
summary_text += f"πŸ” Stimulus channels found: {stim_channels}\n\n"
raw_plot = plot_raw_data_plotly(raw, stim_channels, duration=120.0)
timeline_plot = plot_events_timeline_plotly(events_df, raw.times[-1])
return timeline_plot, summary_text, events_wide, raw_plot
except Exception as e:
error_msg = f"Error loading file: {str(e)}"
return None, None, error_msg, None
finally:
pass
def _parse_crop_regions(crop_regions_str):
    """Parse ``'s1,e1; s2,e2'`` into float ``(start, end)`` pairs plus warnings for malformed entries.

    Empty entries (e.g. a trailing ``;``) are ignored. An entry that does not
    have exactly two comma-separated numbers is skipped and reported. A
    non-numeric value raises ``ValueError`` (from ``float``).
    """
    regions = []
    warnings = []
    for region in (crop_regions_str or '').split(';'):
        region = region.strip()
        if not region:
            continue
        parts = [float(x.strip()) for x in region.split(',')]
        if len(parts) == 2:
            regions.append((parts[0], parts[1]))
        else:
            warnings.append(f"⚠️ Skipping region '{region}': expected 'start,end'.")
    return regions, warnings


def crop_eeg(crop_regions_str):
    """Crop the inspected recording to the requested regions and write download files.

    A single region is cropped directly. Several regions are cropped separately
    and joined with ``mne.concatenate_raws``. The result is saved as
    ``cropped_eeg.fif`` in a fresh temporary directory. If ``inspect_eeg``
    produced a non-empty wide events table, it is saved next to it as
    ``events_wide.csv``.

    Parameters
    ----------
    crop_regions_str : str
        Regions formatted ``'start1,end1; start2,end2'`` in seconds.

    Returns
    -------
    (list of str, str)
        Paths of the written files and a status message, or ``([], message)``
        if nothing could be cropped or any step failed.
    """
    if raw_data is None:
        return [], "Please inspect an EEG file first"
    try:
        crop_regions, parse_warnings = _parse_crop_regions(crop_regions_str)
        if not crop_regions:
            message = "No valid crop regions specified"
            if parse_warnings:
                message += "\n" + "\n".join(parse_warnings)
            return [], message

        max_time = raw_data.times[-1]
        crop_regions, clip_warnings = validate_and_clip_crop_regions(crop_regions, max_time)
        warnings = parse_warnings + clip_warnings
        if not crop_regions:
            message = "No valid crop regions after validation"
            if warnings:
                message += "\n" + "\n".join(warnings)
            return [], message

        if len(crop_regions) == 1:
            start, end = crop_regions[0]
            raw_crop = raw_data.copy().crop(tmin=start, tmax=end)
            status = f"βœ‚οΈ Applied single crop: {start:.2f}s β†’ {end:.2f}s\n"
        else:
            segments = [raw_data.copy().crop(tmin=s, tmax=e) for s, e in crop_regions]
            raw_crop = mne.concatenate_raws(segments)
            status = f"βœ‚οΈ Concatenated {len(crop_regions)} segments\n"
        status += f"Final cropped duration: {raw_crop.times[-1]:.2f}s\n"
        if warnings:
            status += "\nWarnings:\n" + "\n".join(warnings)

        temp_dir = tempfile.mkdtemp()
        output_files = []
        eeg_output = os.path.join(temp_dir, "cropped_eeg.fif")
        raw_crop.save(eeg_output, overwrite=True)
        output_files.append(eeg_output)
        if events_wide_global is not None and not events_wide_global.empty:
            csv_output = os.path.join(temp_dir, "events_wide.csv")
            # NOTE(review): these onset times are on the ORIGINAL recording's
            # timeline; they are not shifted to match the cropped file.
            # index=False: Event_Index is already a column, so the 0-based
            # DataFrame index would only add a redundant unnamed column.
            events_wide_global.to_csv(csv_output, index=False)
            output_files.append(csv_output)
        return output_files, status
    except Exception as e:
        return [], f"Error during cropping: {str(e)}"
# =============================================================================
# GRADIO APP
# =============================================================================
# Two-tab UI: "Inspect" loads a file and shows stimulus events; "Crop & Download"
# cuts the inspected recording. Gradio Blocks lays components out in the order
# they are created inside each layout context, so keep these statements in sequence.
with gr.Blocks(theme=gr.themes.Base(), title="EEG Stimulus Discovery & Interactive Cropping") as demo:
    gr.Markdown("# EEG Stimulus Discovery & Interactive Cropping")
    gr.Markdown("""
Upload EEG data to discover stimulus channels, visualize events, and crop clean segments.
- **`.fif`**: Upload directly
- **`.mff`**: ZIP the entire `.mff` folder first, then upload the `.zip` file
""")
    with gr.Tabs():
        with gr.Tab("πŸ" Inspect"):
            gr.Markdown("### Step 1: Load and Inspect EEG Data")
            with gr.Row():
                with gr.Column(scale=1):
                    eeg_input = gr.File(
                        label="Upload .fif directly, or .zip containing a .mff folder",
                        file_types=[".fif", ".zip"]
                    )
                    inspect_btn = gr.Button("πŸ"Ž Inspect EEG", variant="primary", size="lg")
                with gr.Column(scale=2):
                    summary = gr.Textbox(label="Summary", lines=5, interactive=False)
            gr.Markdown("### Stimulus Timeline")
            timeline_plot = gr.Plot(label="Stimulus Events")
            gr.Markdown("### Event Times")
            # Displays the wide events table (third value returned by inspect_eeg).
            events_csv = gr.Dataframe(label="CSV", interactive=False)
            # "120s" must stay in sync with duration=120.0 passed to
            # plot_raw_data_plotly inside inspect_eeg.
            gr.Markdown("### Raw Data Preview (first 120s)")
            raw_plot = gr.Plot(label="Raw Data")
            # `outputs` order must match inspect_eeg's return tuple:
            # (timeline figure, summary text, events table, raw preview figure).
            inspect_btn.click(
                fn=inspect_eeg,
                inputs=eeg_input,
                outputs=[timeline_plot, summary, events_csv, raw_plot]
            )
        with gr.Tab("βœ‚οΈ Crop & Download"):
            gr.Markdown("### Step 2: Define Crop Regions and Download")
            gr.Markdown("""
**Instructions:**
- Use the event timeline and raw preview to choose clean segments
- Format: `start1,end1; start2,end2` (e.g., `25.0,29.0; 30.0,34.0`)
- Times in seconds
""")
            # Parsed by crop_eeg: ';' separates regions, ',' separates start/end,
            # and whitespace around each number is stripped.
            crop_regions = gr.Textbox(
                value="25.0, 29.0",
                label="Crop Regions (seconds)",
                lines=2
            )
            crop_btn = gr.Button("βœ‚οΈ Crop & Generate Downloads", variant="primary", size="lg")
            crop_status = gr.Textbox(label="Crop Status", lines=5, interactive=False)
            output_files = gr.Files(label="πŸ"₯ Download Cropped EEG + Events CSV")
            # crop_eeg returns (list of file paths, status text) in this order.
            crop_btn.click(
                fn=crop_eeg,
                inputs=crop_regions,
                outputs=[output_files, crop_status]
            )
# Start the local server only when the file is run as a script.
if __name__ == "__main__":
    demo.launch()