Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from pathlib import Path | |
| from scipy.integrate import trapezoid | |
| import scipy.signal as signal | |
# =============================================================================
# CONFIGURATION
# =============================================================================
# Consolidated archives written by the separate consolidation step.
OUTPUT_IEEG = Path("consolidated_ieeg.npz")
OUTPUT_LCMV = Path("consolidated_lcmv.npz")

# Run code -> condition name.
RUN_MAP = dict(c="eyes_closed", o="eyes_open", l="left_hand", r="right_hand")

# PSD parameters
SFREQ_DEFAULT = 500.0   # Hz; fallback sampling rate
PSD_WINDOW_SEC = 2.0    # Welch segment length, in seconds
FMAX = 50               # highest frequency (Hz) kept in the spectra
# Canonical bands as (low, high) in Hz; insertion order is the display order.
FREQ_BANDS = dict(
    Delta=(1, 4),
    Theta=(4, 8),
    Alpha=(8, 12),
    Low_Beta=(12, 20),
    High_Beta=(20, 30),
    Low_Gamma=(30, 50),
    High_Gamma=(50, 100),
)

# Channel-name patterns used to locate electrodes.
# NOTE: the STN and GPi lists contain BOTH hemispheres.
STN_PATTERNS = "STN-L STN-R STN_L STN_R Left-STN Right-STN".split()
GPI_PATTERNS = "GPi-L GPi-R GPi_L GPi_R pGP-lh pGP-rh L-GPi R-GPi GPI-L GPI-R".split()
M1_L_PATTERNS = "ECOG-8-9-L ECOG-10-11-L M1-L Left-M1".split()
M1_R_PATTERNS = "ECOG-8-9-R ECOG-10-11-R M1-R Right-M1".split()

# Legend labels for atlas-based source ROIs.
ATLAS_LABELS = dict(
    STN="STN (DiFuMo-223)",
    L_GPi="L-GPi (GT pGP-lh)",
    R_GPi="R-GPi (GT pGP-rh)",
)

# Trace colours (hex) keyed by signal / ROI type.
COLORS = dict(
    IEEG="#1f77b4",
    LCMV="#d62728",
    STN="#ff7f0e",
    L_GPi="#2ca02c",
    R_GPi="#9467bd",
)

# Global data handles, populated lazily by load_data().
ALL_IEEG_DATA = ALL_LCMV_DATA = None
| # ============================================================================= | |
| # CORE LOGIC | |
| # ============================================================================= | |
def compute_psd(time_series, sfreq=SFREQ_DEFAULT, fmax=FMAX):
    """Estimate the log10 power spectral density of a 1-D signal.

    The signal is high-pass filtered at 0.5 Hz (zero-phase, 4th-order
    Butterworth). It is then passed to Welch's method with a Hann window and
    50 % overlap. Only frequencies in [1, fmax] Hz are returned.

    Parameters
    ----------
    time_series : array_like
        Signal samples; any imaginary part is discarded.
    sfreq : float
        Sampling frequency in Hz.
    fmax : float
        Upper frequency bound (inclusive) of the returned spectrum.

    Returns
    -------
    freqs, psd_log : np.ndarray (float32)
        Frequencies in Hz and log10(PSD + 1e-12). When the signal is too short
        to filter, or no Welch frequency falls inside [1, fmax], a flat
        placeholder spectrum is returned (freqs [1, 10], log10 power of about -10).
    """
    ts = np.real(time_series).astype(np.float64)
    n_samples = len(ts)

    # Flat placeholder returned when no usable spectrum can be computed.
    fallback = (
        np.array([1, 10], dtype=np.float32),
        np.log10(np.array([1e-10, 1e-10]) + 1e-12).astype(np.float32),
    )

    window_size = int(PSD_WINDOW_SEC * sfreq)
    if n_samples < window_size:
        window_size = max(int(n_samples * 0.8), 100)
    # Welch cannot use a segment longer than the signal. Without this clamp,
    # signals under 100 samples could end up with noverlap >= nperseg.
    window_size = min(window_size, n_samples)

    nyq = sfreq * 0.5
    if nyq <= 0.5:
        # Keep the normalized cutoff 0.5 / nyq strictly below 1.
        nyq = 0.51
    # 0.5 Hz high-pass. Second-order sections are used because the [b, a]
    # transfer-function form is numerically fragile at a cutoff this low.
    sos = signal.butter(4, 0.5 / nyq, btype='high', output='sos')
    # sosfiltfilt needs more samples than its default edge padding of
    # 3 * (2 * n_sections + 1); shorter inputs would raise ValueError.
    if n_samples <= 3 * (2 * sos.shape[0] + 1):
        return fallback
    filtered = signal.sosfiltfilt(sos, ts)

    freqs, psd = signal.welch(filtered, fs=sfreq, window='hann', nperseg=window_size,
                              noverlap=window_size // 2, detrend='constant')
    mask = (freqs >= 1.0) & (freqs <= fmax)
    freqs, psd = freqs[mask], psd[mask]
    if len(freqs) == 0:
        return fallback
    psd_log = np.log10(psd + 1e-12)
    return freqs.astype(np.float32), psd_log.astype(np.float32)
def load_data():
    """Open both consolidated .npz archives into the module-level handles.

    Returns immediately when both handles are already set. Otherwise both
    archives are (re)opened with ``np.load(..., allow_pickle=True)``.

    Raises
    ------
    FileNotFoundError
        If either consolidated archive does not exist on disk.
    """
    global ALL_IEEG_DATA, ALL_LCMV_DATA
    already_loaded = ALL_IEEG_DATA is not None and ALL_LCMV_DATA is not None
    if already_loaded:
        return
    both_present = OUTPUT_IEEG.exists() and OUTPUT_LCMV.exists()
    if not both_present:
        raise FileNotFoundError("Consolidated files missing. Please run consolidation first.")
    ALL_IEEG_DATA = np.load(OUTPUT_IEEG, allow_pickle=True)
    ALL_LCMV_DATA = np.load(OUTPUT_LCMV, allow_pickle=True)
def _collect_prefixed(archive, prefix, meta_key):
    """Return {name: array} for every archive key starting with `prefix`, except `meta_key`.

    The prefix is removed exactly once from the front of each key. A channel or
    ROI name that happens to contain the prefix text is therefore kept intact.
    """
    return {
        key[len(prefix):]: archive[key]
        for key in archive.files
        if key.startswith(prefix) and key != meta_key
    }


def get_consolidated_ieeg(subj_id, run_code):
    """Fetch all iEEG channels and the metadata for one subject/run.

    Reads ALL_IEEG_DATA, which load_data() must have populated beforehand.
    Channel arrays are stored under keys ``"<subj>_<run>_<channel>"``; the
    metadata is stored under ``"meta_<subj>_<run>"`` as an object array whose
    ``.item()`` is returned (presumably a dict with 'sfreq' — confirm
    against the consolidation step).

    Returns
    -------
    (dict, object) or (None, None)
        Channel name -> array plus the metadata, or (None, None) when the
        metadata key is absent.
    """
    meta_key = f"meta_{subj_id}_{run_code}"
    if meta_key not in ALL_IEEG_DATA.files:
        return None, None
    meta = ALL_IEEG_DATA[meta_key].item()
    channels = _collect_prefixed(ALL_IEEG_DATA, f"{subj_id}_{run_code}_", meta_key)
    return channels, meta


def get_consolidated_lcmv(subj_id):
    """Fetch all LCMV source ROIs and the metadata for one subject.

    Reads ALL_LCMV_DATA, which load_data() must have populated beforehand.
    ROI arrays are stored under keys ``"<subj>_<roi>"``; the metadata is stored
    under ``"meta_<subj>"`` as an object array whose ``.item()`` is returned.

    Returns
    -------
    (dict, object) or (None, None)
        ROI name -> array plus the metadata, or (None, None) when the
        metadata key is absent.
    """
    meta_key = f"meta_{subj_id}"
    if meta_key not in ALL_LCMV_DATA.files:
        return None, None
    meta = ALL_LCMV_DATA[meta_key].item()
    rois = _collect_prefixed(ALL_LCMV_DATA, f"{subj_id}_", meta_key)
    return rois, meta
def find_channel(channels_dict, patterns):
    """Locate the first channel that matches one of `patterns`.

    Patterns are tried in the given order. For each pattern, an exact key
    match wins. Otherwise the first key (in dict order) that contains the
    pattern as a case-insensitive substring is taken.

    Returns
    -------
    (str, object) or (None, None)
        The matched channel name and its data, or (None, None) when nothing
        matches or `channels_dict` is None.
    """
    if channels_dict is None:
        return None, None
    for candidate in patterns:
        if candidate in channels_dict:
            return candidate, channels_dict[candidate]
        needle = candidate.lower()
        hit = next((name for name in channels_dict if needle in name.lower()), None)
        if hit is not None:
            return hit, channels_dict[hit]
    return None, None
def create_interactive_plot(roi_name, ieeg_signal, ieeg_sfreq, ch_used,
                            source_signal, source_sfreq, source_label, source_color,
                            subject_id, run_id):
    """Overlay the iEEG and source-reconstructed PSDs for one ROI.

    Both signals go through compute_psd(). The iEEG trace is drawn solid and
    the source trace dashed. FREQ_BANDS are shaded as background rectangles,
    clipped to the iEEG frequency range.

    Parameters
    ----------
    roi_name : str
        ROI name shown in the title.
    ieeg_signal, source_signal : array_like
        Time series passed to compute_psd().
    ieeg_sfreq, source_sfreq : float
        Sampling frequencies (Hz) of the two signals.
    ch_used : str
        iEEG channel name, shown in the legend.
    source_label, source_color : str
        Legend label and line colour for the source trace.
    subject_id, run_id : str
        Shown in the title.

    Returns
    -------
    plotly.graph_objects.Figure
    """
    freqs_ieeg, psd_ieeg = compute_psd(ieeg_signal, sfreq=ieeg_sfreq)
    freqs_src, psd_src = compute_psd(source_signal, sfreq=source_sfreq)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=freqs_ieeg, y=psd_ieeg,
        mode='lines', name=f'iEEG ({ch_used})',
        line=dict(color=COLORS["IEEG"], width=3),
        hovertemplate=f'<b>iEEG</b><br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}<extra></extra>'
    ))
    fig.add_trace(go.Scatter(
        x=freqs_src, y=psd_src,
        mode='lines', name=source_label,
        line=dict(color=source_color, width=3, dash='dash'),
        hovertemplate=f'<b>{source_label}</b><br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}<extra></extra>'
    ))

    # Band shading: progressively darker blue per band, clipped to the
    # frequency range actually present in the iEEG spectrum.
    f_lo, f_hi = float(np.min(freqs_ieeg)), float(np.max(freqs_ieeg))
    n_bands = len(FREQ_BANDS)
    band_colors = [f"rgba(31, 119, 180, {0.1 + (i/n_bands)*0.2})" for i in range(n_bands)]
    shapes = []
    for i, (band_fmin, band_fmax) in enumerate(FREQ_BANDS.values()):
        band_low = max(band_fmin, f_lo)
        band_high = min(band_fmax, f_hi)
        if band_low >= band_high:
            continue  # band lies entirely outside the computed spectrum
        shapes.append(dict(
            type="rect", xref="x", yref="paper",
            x0=band_low, x1=band_high, y0=0, y1=1,
            fillcolor=band_colors[i], opacity=0.3, layer="below", line_width=0
        ))

    title_text = f"{subject_id} | Run: {run_id} | ROI: {roi_name}<br><sup>{source_label} vs iEEG</sup>"
    fig.update_layout(
        title=dict(text=title_text, font=dict(size=14, family="Arial")),
        xaxis_title="Frequency (Hz)",
        yaxis_title="PSD (log₁₀)",
        xaxis=dict(range=[1, FMAX], type="linear"),
        yaxis_type="linear",
        hovermode="x unified",
        legend=dict(x=0, y=1, bgcolor="rgba(255,255,255,0.8)"),
        shapes=shapes,
        template="plotly_white",
        height=600,
        margin=dict(l=50, r=50, t=60, b=50)
    )
    return fig
def _side_patterns(patterns, side):
    """Return the subset of `patterns` that names hemisphere `side` ('L' or 'R').

    A pattern belongs to a side when one of its '-'/'_'-separated tokens is a
    side marker: 'l'/'left'/'lh' or 'r'/'right'/'rh', case-insensitive.
    """
    markers = {"L": {"l", "left", "lh"}, "R": {"r", "right", "rh"}}[side]
    return [p for p in patterns
            if markers & set(p.lower().replace("_", "-").split("-"))]


def generate_all_plots(subj_id, run_code):
    """Generate all valid iEEG-vs-LCMV PSD plots for one subject/run.

    Parameters
    ----------
    subj_id : str
        Subject identifier as used in the consolidated archives.
    run_code : str
        Run code; RUN_MAP maps it to the condition suffix of LCMV ROI keys.

    Returns
    -------
    (dict, str)
        Mapping "<electrode> vs <source label>" -> plotly Figure, plus a
        newline-joined log. When the archives or the subject data are
        missing, the dict is empty and the string explains why.
    """
    try:
        load_data()
    except FileNotFoundError as e:
        return {}, str(e)

    cond = RUN_MAP.get(run_code, "unknown")
    ieeg_ch, ieeg_meta = get_consolidated_ieeg(subj_id, run_code)
    lcmv_rois, lcmv_meta = get_consolidated_lcmv(subj_id)
    plots_dict = {}
    logs = [f"Processing {subj_id} | Condition: {cond}"]
    if ieeg_ch is None or lcmv_rois is None:
        return plots_dict, f"No data found for {subj_id} (Run: {run_code})."
    ieeg_sfreq = ieeg_meta.get('sfreq', SFREQ_DEFAULT)
    lcmv_sfreq = lcmv_meta.get('sfreq', SFREQ_DEFAULT)

    # Detect electrodes. STN_PATTERNS / GPI_PATTERNS mix both hemispheres, so
    # each side is searched with only its own patterns. Otherwise a right-only
    # contact would be reported as left (and 'Left-STN' as right).
    stn_l_ch, stn_l_sig = find_channel(ieeg_ch, _side_patterns(STN_PATTERNS, "L"))
    stn_r_ch, stn_r_sig = find_channel(ieeg_ch, _side_patterns(STN_PATTERNS, "R"))
    gpi_l_ch, gpi_l_sig = find_channel(ieeg_ch, _side_patterns(GPI_PATTERNS, "L"))
    gpi_r_ch, gpi_r_sig = find_channel(ieeg_ch, _side_patterns(GPI_PATTERNS, "R"))
    m1_l_ch, m1_l_sig = find_channel(ieeg_ch, M1_L_PATTERNS)
    m1_r_ch, m1_r_sig = find_channel(ieeg_ch, M1_R_PATTERNS)

    def add_plot(name, sig, ch, roi_key, label, color):
        """Add one comparison figure when both the electrode and the LCMV ROI exist."""
        if sig is None or ch is None or roi_key not in lcmv_rois:
            return
        fig = create_interactive_plot(name, sig, ieeg_sfreq, ch, lcmv_rois[roi_key],
                                      lcmv_sfreq, label, color, subj_id, run_code)
        key = f"{name} vs {label}"
        plots_dict[key] = fig
        logs.append(f"✅ Found: {key}")

    # M1
    add_plot("L_M1", m1_l_sig, m1_l_ch, f"L_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
    add_plot("R_M1", m1_r_sig, m1_r_ch, f"R_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])

    # STN: hemisphere-specific voxel plus the shared atlas ROI.
    if stn_l_sig is not None:
        add_plot("L_STN", stn_l_sig, stn_l_ch, f"L_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
        add_plot("L_STN", stn_l_sig, stn_l_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])
    if stn_r_sig is not None:
        add_plot("R_STN", stn_r_sig, stn_r_ch, f"R_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
        add_plot("R_STN", stn_r_sig, stn_r_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])

    # GPi (fallback when that hemisphere has no STN contact).
    # NOTE(review): the "voxel" and atlas plots below read the SAME ROI key, so
    # both compare against identical LCMV data under different labels —
    # confirm the intended voxel key with the consolidation step.
    if gpi_l_sig is not None and stn_l_sig is None:
        add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
        add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", ATLAS_LABELS["L_GPi"], COLORS["L_GPi"])
    if gpi_r_sig is not None and stn_r_sig is None:
        add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
        add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", ATLAS_LABELS["R_GPi"], COLORS["R_GPi"])

    if not plots_dict:
        logs.append("⚠️ No matching electrode/ROI pairs found.")
    return plots_dict, "\n".join(logs)
def get_available_subjects():
    """List the subject IDs that have metadata in the consolidated LCMV archive.

    A subject is any archive key of the form ``meta_<subject>``. The archive
    is opened only for the duration of the scan.

    Returns
    -------
    list of str
        Sorted subject IDs, or an empty list if the archive does not exist.
    """
    if not OUTPUT_LCMV.exists():
        return []
    # np.load on an .npz keeps the file open until the NpzFile is closed;
    # the context manager prevents a leaked handle on every refresh.
    # Listing keys never unpickles, so allow_pickle is not needed here.
    with np.load(OUTPUT_LCMV) as data:
        keys = list(data.files)
    prefix = "meta_"
    # Strip the prefix once from the front only (not every occurrence).
    return sorted({key[len(prefix):] for key in keys if key.startswith(prefix)})
| # ============================================================================= | |
| # GRADIO INTERFACE | |
| # ============================================================================= | |
# NOTE(review): `theme` is passed to demo.launch() at the bottom of the file,
# not to gr.Blocks(). Newer Gradio releases (6.x) accept it in launch(); older
# releases expect gr.Blocks(theme=...) instead. Confirm against the installed
# version.
with gr.Blocks(title="Interactive iEEG-LCMV Viewer") as demo:
    gr.Markdown("# Interactive iEEG & LCMV Viewer")
    gr.Markdown("Select a subject and condition to generate available comparisons. Then choose specific plots to visualize.")
    # Holds the plot-name -> figure dict generated for the current selection.
    current_plots_state = gr.State({})

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### 1. Select Data")
            btn_refresh = gr.Button("🔄 Refresh Subjects")
            subject_dropdown = gr.Dropdown(label="Subject", choices=[], interactive=True)
            run_dropdown = gr.Dropdown(
                label="Condition",
                choices=["c", "o", "l", "r"],
                value="c",
                info="c: Eyes Closed, o: Eyes Open, l: Left Hand, r: Right Hand"
            )
            btn_generate = gr.Button("🔍 Find Available Plots", variant="primary")
            gr.Markdown("### 2. Choose Visualization")
            plot_selector = gr.Dropdown(label="Select Plot to View", choices=[], interactive=True)
            gr.Markdown("### Log")
            val_log = gr.Textbox(label="Status", lines=6, interactive=False)
        with gr.Column(scale=3):
            gr.Markdown("### PSD Comparison")
            plot_display = gr.Plot(label="Interactive Plot", show_label=False)

    # Event handlers
    def refresh_subjects():
        """Repopulate the subject dropdown, preselecting the first subject if any."""
        subs = get_available_subjects()
        return gr.Dropdown(choices=subs, value=subs[0] if subs else None)

    def process_and_update_dropdown(subj, run):
        """Generate the plots, then update the state, log, plot dropdown, and shown figure."""
        if not subj:
            return {}, "Please select a subject.", gr.Dropdown(choices=[], value=None), None
        plots_dict, log_msg = generate_all_plots(subj, run)
        choices = list(plots_dict.keys())
        if not choices:
            return plots_dict, log_msg, gr.Dropdown(choices=[], value=None), None
        initial_val = choices[0]
        initial_fig = plots_dict[initial_val]
        return plots_dict, log_msg, gr.Dropdown(choices=choices, value=initial_val), initial_fig

    def on_plot_selection(plots_dict, selected_key):
        """Return the stored figure for the selected plot (None if unavailable)."""
        if not plots_dict or not selected_key:
            return None
        return plots_dict.get(selected_key)

    # Wire up events
    btn_refresh.click(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])
    demo.load(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])
    # Generate: update state, log, plot dropdown AND the displayed plot.
    btn_generate.click(
        fn=process_and_update_dropdown,
        inputs=[subject_dropdown, run_dropdown],
        outputs=[current_plots_state, val_log, plot_selector, plot_display]
    )
    # Dropdown change: update the displayed plot only.
    plot_selector.change(
        fn=on_plot_selection,
        inputs=[current_plots_state, plot_selector],
        outputs=[plot_display]
    )

if __name__ == "__main__":
    # See the NOTE above gr.Blocks about where `theme` is accepted.
    demo.launch(theme=gr.themes.Soft())