"""Interactive Gradio viewer comparing iEEG and LCMV source-space PSDs.

Reads two consolidated ``.npz`` archives (``OUTPUT_IEEG`` / ``OUTPUT_LCMV``),
pairs detected iEEG channels (M1, STN, GPi) with matching LCMV ROIs, and shows
one Plotly PSD overlay per pair.
"""
import gradio as gr
import numpy as np
import plotly.graph_objects as go
from pathlib import Path
from scipy.integrate import trapezoid
import scipy.signal as signal

# =============================================================================
# CONFIGURATION
# =============================================================================
OUTPUT_IEEG = Path("consolidated_ieeg.npz")
OUTPUT_LCMV = Path("consolidated_lcmv.npz")
RUN_MAP = {"c": "eyes_closed", "o": "eyes_open", "l": "left_hand", "r": "right_hand"}

# PSD Parameters
SFREQ_DEFAULT = 500.0
PSD_WINDOW_SEC = 2.0
FMAX = 50
FREQ_BANDS = {
    'Delta': (1, 4),
    'Theta': (4, 8),
    'Alpha': (8, 12),
    'Low_Beta': (12, 20),
    'High_Beta': (20, 30),
    'Low_Gamma': (30, 50),
    'High_Gamma': (50, 100),
}

# Patterns
# Combined (both-hemisphere) lists, kept for backward compatibility.
STN_PATTERNS = ["STN-L", "STN-R", "STN_L", "STN_R", "Left-STN", "Right-STN"]
GPI_PATTERNS = ["GPi-L", "GPi-R", "GPi_L", "GPi_R", "pGP-lh", "pGP-rh", "L-GPi", "R-GPi", "GPI-L", "GPI-R"]
# Hemisphere-specific lists used for detection. A left-side search must never
# match a right-side channel (and vice versa), which the combined lists allowed.
STN_L_PATTERNS = ["STN-L", "STN_L", "Left-STN"]
STN_R_PATTERNS = ["STN-R", "STN_R", "Right-STN"]
GPI_L_PATTERNS = ["GPi-L", "GPi_L", "pGP-lh", "L-GPi", "GPI-L"]
GPI_R_PATTERNS = ["GPi-R", "GPi_R", "pGP-rh", "R-GPi", "GPI-R"]
M1_L_PATTERNS = ["ECOG-8-9-L", "ECOG-10-11-L", "M1-L", "Left-M1"]
M1_R_PATTERNS = ["ECOG-8-9-R", "ECOG-10-11-R", "M1-R", "Right-M1"]

ATLAS_LABELS = {
    "STN": "STN (DiFuMo-223)",
    "L_GPi": "L-GPi (GT pGP-lh)",
    "R_GPi": "R-GPi (GT pGP-rh)",
}

COLORS = {
    "IEEG": "#1f77b4",
    "LCMV": "#d62728",
    "STN": "#ff7f0e",
    "L_GPi": "#2ca02c",
    "R_GPi": "#9467bd",
}

# Global Data Handles (lazily opened by load_data(); both set together)
ALL_IEEG_DATA = None
ALL_LCMV_DATA = None


# =============================================================================
# CORE LOGIC
# =============================================================================
def _fallback_psd():
    """Return the two-point placeholder spectrum used when no PSD can be estimated."""
    freqs = np.array([1.0, 10.0], dtype=np.float32)
    psd_log = np.log10(np.array([1e-10, 1e-10]) + 1e-12).astype(np.float32)
    return freqs, psd_log


def compute_psd(time_series, sfreq=SFREQ_DEFAULT, fmax=FMAX, fmin=1.0):
    """Estimate the log10 Welch PSD of one signal.

    The signal is high-pass filtered at 0.5 Hz. The filter is a 4th-order
    Butterworth applied forward and backward with ``filtfilt``. Welch's method
    then uses a Hann window of ``PSD_WINDOW_SEC`` seconds with 50 % overlap.

    Args:
        time_series: Array-like of samples; only the real part is used.
            NOTE(review): assumed 1-D (single channel) -- multi-channel input
            is not supported by the frequency masking below.
        sfreq: Sampling frequency in Hz.
        fmax: Highest frequency (Hz, inclusive) kept in the output.
        fmin: Lowest frequency (Hz, inclusive) kept in the output.

    Returns:
        ``(freqs, psd_log)`` as float32 arrays. If the signal is too short to
        filter, or no Welch bin falls inside ``[fmin, fmax]``, a placeholder
        spectrum at 1 and 10 Hz is returned instead of raising.
    """
    ts = np.real(np.asarray(time_series)).astype(np.float64)

    nyq = sfreq * 0.5
    if nyq <= 0.5:
        # Keep the 0.5 Hz cutoff strictly below Nyquist for very low sfreq.
        nyq = 0.51
    b, a = signal.butter(4, 0.5 / nyq, btype='high')

    # filtfilt needs more samples than its default pad length
    # (3 * max(len(a), len(b))); shorter signals used to raise ValueError.
    if len(ts) <= 3 * max(len(a), len(b)):
        return _fallback_psd()
    filtered = signal.filtfilt(b, a, ts)

    nperseg = int(PSD_WINDOW_SEC * sfreq)
    if len(ts) < nperseg:
        nperseg = max(int(len(ts) * 0.8), 100)
    # Clamp to the signal length *before* deriving the overlap. welch() would
    # clamp nperseg itself but keep our noverlap, then reject noverlap >= nperseg.
    nperseg = min(max(nperseg, 1), len(ts))

    freqs, psd = signal.welch(filtered, fs=sfreq, window='hann',
                              nperseg=nperseg, noverlap=nperseg // 2,
                              detrend='constant')

    keep = (freqs >= fmin) & (freqs <= fmax)
    if not np.any(keep):
        return _fallback_psd()
    # Small epsilon avoids log10(0) for perfectly flat bins.
    psd_log = np.log10(psd[keep] + 1e-12)
    return freqs[keep].astype(np.float32), psd_log.astype(np.float32)


def load_data():
    """Open both consolidated archives once and cache them in module globals.

    Raises:
        FileNotFoundError: if either consolidated file is missing.
    """
    global ALL_IEEG_DATA, ALL_LCMV_DATA
    if ALL_IEEG_DATA is not None and ALL_LCMV_DATA is not None:
        return
    if not OUTPUT_IEEG.exists() or not OUTPUT_LCMV.exists():
        raise FileNotFoundError("Consolidated files missing. Please run consolidation first.")
    # allow_pickle is needed because metadata entries are object arrays read
    # back with .item(); only load archives you produced yourself.
    ieeg = np.load(OUTPUT_IEEG, allow_pickle=True)
    try:
        lcmv = np.load(OUTPUT_LCMV, allow_pickle=True)
    except Exception:
        ieeg.close()
        raise
    # Publish both handles together so a failed second load never leaves a
    # half-initialised cache (or a leaked first handle) behind.
    ALL_IEEG_DATA, ALL_LCMV_DATA = ieeg, lcmv


def get_consolidated_ieeg(subj_id, run_code):
    """Return ``(channels, meta)`` for one subject/run from the iEEG archive.

    ``channels`` maps channel name (archive key minus the
    ``"{subj_id}_{run_code}_"`` prefix) to its stored array. ``meta`` is the
    ``meta_{subj_id}_{run_code}`` entry. Returns ``(None, None)`` when that
    metadata entry does not exist.
    """
    load_data()
    data = ALL_IEEG_DATA
    meta_key = f"meta_{subj_id}_{run_code}"
    if meta_key not in data.files:
        return None, None
    meta = data[meta_key].item()
    prefix = f"{subj_id}_{run_code}_"
    # Slice the prefix off (str.replace would also strip later occurrences).
    channels = {
        key[len(prefix):]: data[key]
        for key in data.files
        if key.startswith(prefix) and key != meta_key
    }
    return channels, meta


def get_consolidated_lcmv(subj_id):
    """Return ``(rois, meta)`` for one subject from the LCMV archive.

    ``rois`` maps ROI name (archive key minus the ``"{subj_id}_"`` prefix) to
    its stored array. ``meta`` is the ``meta_{subj_id}`` entry. Returns
    ``(None, None)`` when that metadata entry does not exist.
    """
    load_data()
    data = ALL_LCMV_DATA
    meta_key = f"meta_{subj_id}"
    if meta_key not in data.files:
        return None, None
    meta = data[meta_key].item()
    prefix = f"{subj_id}_"
    rois = {
        key[len(prefix):]: data[key]
        for key in data.files
        if key.startswith(prefix) and key != meta_key
    }
    return rois, meta


def find_channel(channels_dict, patterns):
    """Find the first channel matching any of ``patterns``.

    Exact key matches (in pattern order) take precedence over case-insensitive
    substring matches. Previously a fuzzy hit on an early pattern could win
    over an exact hit on a later one.

    Returns:
        ``(channel_name, signal)``, or ``(None, None)`` if nothing matches.
    """
    if not channels_dict:
        return None, None
    for pattern in patterns:
        if pattern in channels_dict:
            return pattern, channels_dict[pattern]
    for pattern in patterns:
        needle = pattern.lower()
        for key, value in channels_dict.items():
            if needle in key.lower():
                return key, value
    return None, None


def create_interactive_plot(roi_name, ieeg_signal, ieeg_sfreq, ch_used,
                            source_signal, source_sfreq, source_label,
                            source_color, subject_id, run_id):
    """Build a Plotly figure overlaying the iEEG and source-space PSDs.

    The iEEG PSD is drawn as a solid line and the source PSD as a dashed line.
    ``FREQ_BANDS`` are shaded, clipped to the iEEG frequency range.
    """
    freqs_ieeg, psd_ieeg = compute_psd(ieeg_signal, sfreq=ieeg_sfreq)
    freqs_src, psd_src = compute_psd(source_signal, sfreq=source_sfreq)

    fig = go.Figure()
    # Plotly uses HTML <br> for line breaks in hover templates and titles.
    fig.add_trace(go.Scatter(
        x=freqs_ieeg, y=psd_ieeg,
        mode='lines',
        name=f'iEEG ({ch_used})',
        line=dict(color=COLORS["IEEG"], width=3),
        hovertemplate=f'iEEG<br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}',
    ))
    fig.add_trace(go.Scatter(
        x=freqs_src, y=psd_src,
        mode='lines',
        name=source_label,
        line=dict(color=source_color, width=3, dash='dash'),
        hovertemplate=f'{source_label}<br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}',
    ))

    shapes = []
    n_bands = len(FREQ_BANDS)
    freq_lo, freq_hi = float(freqs_ieeg.min()), float(freqs_ieeg.max())
    for i, (band_lo, band_hi) in enumerate(FREQ_BANDS.values()):
        x0 = max(band_lo, freq_lo)
        x1 = min(band_hi, freq_hi)
        if x0 < x1:  # skip bands entirely outside the plotted range
            alpha = 0.1 + (i / n_bands) * 0.2
            shapes.append(dict(
                type="rect", xref="x", yref="paper",
                x0=x0, x1=x1, y0=0, y1=1,
                fillcolor=f"rgba(31, 119, 180, {alpha})",
                opacity=0.3, layer="below", line_width=0,
            ))

    title_text = f"{subject_id} | Run: {run_id} | ROI: {roi_name}<br>{source_label} vs iEEG"
    fig.update_layout(
        title=dict(text=title_text, font=dict(size=14, family="Arial")),
        xaxis_title="Frequency (Hz)",
        yaxis_title="PSD (log₁₀)",
        xaxis=dict(range=[1, FMAX], type="linear"),
        yaxis_type="linear",
        hovermode="x unified",
        legend=dict(x=0, y=1, bgcolor="rgba(255,255,255,0.8)"),
        shapes=shapes,
        template="plotly_white",
        height=600,
        margin=dict(l=50, r=50, t=60, b=50),
    )
    return fig


def generate_all_plots(subj_id, run_code):
    """Generate every available iEEG-vs-LCMV comparison for a subject/run.

    Returns:
        ``(plots_dict, log_text)``. ``plots_dict`` maps ``"<roi> vs <label>"``
        to Plotly figures and is empty when the data files or the
        subject/run entries are missing.
    """
    try:
        load_data()
    except FileNotFoundError as e:
        return {}, str(e)

    cond = RUN_MAP.get(run_code, "unknown")
    ieeg_ch, ieeg_meta = get_consolidated_ieeg(subj_id, run_code)
    lcmv_rois, lcmv_meta = get_consolidated_lcmv(subj_id)

    plots_dict = {}
    logs = [f"Processing {subj_id} | Condition: {cond}"]

    if ieeg_ch is None or lcmv_rois is None:
        return plots_dict, f"No data found for {subj_id} (Run: {run_code})."

    ieeg_sfreq = ieeg_meta.get('sfreq', SFREQ_DEFAULT)
    lcmv_sfreq = lcmv_meta.get('sfreq', SFREQ_DEFAULT)

    # Detect Electrodes (hemisphere-specific pattern lists, searched independently)
    stn_l_ch, stn_l_sig = find_channel(ieeg_ch, STN_L_PATTERNS)
    stn_r_ch, stn_r_sig = find_channel(ieeg_ch, STN_R_PATTERNS)
    gpi_l_ch, gpi_l_sig = find_channel(ieeg_ch, GPI_L_PATTERNS)
    gpi_r_ch, gpi_r_sig = find_channel(ieeg_ch, GPI_R_PATTERNS)
    m1_l_ch, m1_l_sig = find_channel(ieeg_ch, M1_L_PATTERNS)
    m1_r_ch, m1_r_sig = find_channel(ieeg_ch, M1_R_PATTERNS)

    def add_plot(name, sig, ch, roi_key, label, color):
        # No-op unless both the electrode and the LCMV ROI exist.
        if sig is not None and ch is not None and roi_key in lcmv_rois:
            fig = create_interactive_plot(name, sig, ieeg_sfreq, ch, lcmv_rois[roi_key],
                                          lcmv_sfreq, label, color, subj_id, run_code)
            key = f"{name} vs {label}"
            plots_dict[key] = fig
            logs.append(f"✅ Found: {key}")

    # M1
    add_plot("L_M1", m1_l_sig, m1_l_ch, f"L_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
    add_plot("R_M1", m1_r_sig, m1_r_ch, f"R_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])

    # STN: hemisphere voxel plus the shared atlas ROI
    if stn_l_sig is not None:
        add_plot("L_STN", stn_l_sig, stn_l_ch, f"L_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
        add_plot("L_STN", stn_l_sig, stn_l_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])
    if stn_r_sig is not None:
        add_plot("R_STN", stn_r_sig, stn_r_ch, f"R_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
        add_plot("R_STN", stn_r_sig, stn_r_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])

    # GPi (Fallback): only when there is no STN electrode on that side.
    # NOTE(review): both GPi comparisons per side read the same ROI key
    # (f"{side}_GPi_{cond}"), so they plot identical source data under two
    # labels -- confirm whether the voxel plot should use a different key.
    if gpi_l_sig is not None and stn_l_sig is None:
        add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
        add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", ATLAS_LABELS["L_GPi"], COLORS["L_GPi"])
    if gpi_r_sig is not None and stn_r_sig is None:
        add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
        add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", ATLAS_LABELS["R_GPi"], COLORS["R_GPi"])

    if not plots_dict:
        logs.append("⚠️ No matching electrode/ROI pairs found.")

    return plots_dict, "\n".join(logs)


def get_available_subjects():
    """Return the sorted subject IDs that have a ``meta_<id>`` entry in the LCMV archive."""
    if not OUTPUT_LCMV.exists():
        return []
    # Context manager closes the archive; the original leaked one handle per call.
    with np.load(OUTPUT_LCMV, allow_pickle=True) as data:
        files = list(data.files)
    marker = "meta_"
    return sorted({key[len(marker):] for key in files if key.startswith(marker)})


# =============================================================================
# GRADIO INTERFACE
# =============================================================================
# NOTE(review): `theme` is passed to launch() (see __main__) instead of
# gr.Blocks(). That matches the Gradio 6 API; older Gradio releases accept it
# only on gr.Blocks(). Confirm against the installed Gradio version.
with gr.Blocks(title="Interactive iEEG-LCMV Viewer") as demo:
    gr.Markdown("# Interactive iEEG & LCMV Viewer")
    gr.Markdown("Select a subject and condition to generate available comparisons. Then choose specific plots to visualize.")

    # State to store generated plots for the current selection
    current_plots_state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            btn_refresh = gr.Button("🔄 Refresh Subjects")
            subject_dropdown = gr.Dropdown(label="Subject", choices=[], interactive=True)
            run_dropdown = gr.Dropdown(
                label="Condition",
                choices=["c", "o", "l", "r"],
                value="c",
                info="c: Eyes Closed, o: Eyes Open, l: Left Hand, r: Right Hand"
            )
            btn_generate = gr.Button("🔍 Find Available Plots", variant="primary")

            gr.Markdown("### 2. Choose Visualization")
            plot_selector = gr.Dropdown(label="Select Plot to View", choices=[], interactive=True)

            gr.Markdown("### Log")
            val_log = gr.Textbox(label="Status", lines=6, interactive=False)

        with gr.Column(scale=3):
            gr.Markdown("### PSD Comparison")
            plot_display = gr.Plot(label="Interactive Plot", show_label=False)

    # Event Handlers
    def refresh_subjects():
        """Repopulate the subject dropdown, selecting the first subject if any."""
        subs = get_available_subjects()
        return gr.Dropdown(choices=subs, value=subs[0] if subs else None)

    def process_and_update_dropdown(subj, run):
        """Generates plots, updates state, log, dropdown options, and shows the first plot."""
        if not subj:
            return {}, "Please select a subject.", gr.Dropdown(choices=[], value=None), None

        plots_dict, log_msg = generate_all_plots(subj, run)
        choices = list(plots_dict.keys())

        if not choices:
            return plots_dict, log_msg, gr.Dropdown(choices=[], value=None), None

        initial_val = choices[0]
        initial_fig = plots_dict[initial_val]
        return plots_dict, log_msg, gr.Dropdown(choices=choices, value=initial_val), initial_fig

    def on_plot_selection(plots_dict, selected_key):
        """Updates only the plot when dropdown changes."""
        if not plots_dict or not selected_key:
            return None
        return plots_dict.get(selected_key)

    # Wire up events
    btn_refresh.click(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])
    demo.load(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])

    # When Generate is clicked: Update State, Log, Dropdown, AND Plot
    btn_generate.click(
        fn=process_and_update_dropdown,
        inputs=[subject_dropdown, run_dropdown],
        outputs=[current_plots_state, val_log, plot_selector, plot_display]
    )

    # When Dropdown changes: Update Plot Display only
    plot_selector.change(
        fn=on_plot_selection,
        inputs=[current_plots_state, plot_selector],
        outputs=[plot_display]
    )


if __name__ == "__main__":
    # See the NOTE above gr.Blocks about where `theme` is accepted.
    demo.launch(theme=gr.themes.Soft())