JayLacoma commited on
Commit
f8a22ea
·
verified ·
1 Parent(s): ddfed6f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -0
app.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import plotly.graph_objects as go
4
+ from pathlib import Path
5
+ from scipy.integrate import trapezoid
6
+ import scipy.signal as signal
7
+
8
+ # =============================================================================
9
+ # CONFIGURATION
10
+ # =============================================================================
11
+
12
+ OUTPUT_IEEG = Path("consolidated_ieeg.npz")
13
+ OUTPUT_LCMV = Path("consolidated_lcmv.npz")
14
+
15
+ RUN_MAP = {"c": "eyes_closed", "o": "eyes_open", "l": "left_hand", "r": "right_hand"}
16
+
17
+ # PSD Parameters
18
+ SFREQ_DEFAULT = 500.0
19
+ PSD_WINDOW_SEC = 2.0
20
+ FMAX = 50
21
+
22
+ FREQ_BANDS = {
23
+ 'Delta': (1, 4), 'Theta': (4, 8), 'Alpha': (8, 12),
24
+ 'Low_Beta': (12, 20), 'High_Beta': (20, 30),
25
+ 'Low_Gamma': (30, 50), 'High_Gamma': (50, 100),
26
+ }
27
+
28
+ # Patterns
29
+ STN_PATTERNS = ["STN-L", "STN-R", "STN_L", "STN_R", "Left-STN", "Right-STN"]
30
+ GPI_PATTERNS = ["GPi-L", "GPi-R", "GPi_L", "GPi_R", "pGP-lh", "pGP-rh", "L-GPi", "R-GPi", "GPI-L", "GPI-R"]
31
+ M1_L_PATTERNS = ["ECOG-8-9-L", "ECOG-10-11-L", "M1-L", "Left-M1"]
32
+ M1_R_PATTERNS = ["ECOG-8-9-R", "ECOG-10-11-R", "M1-R", "Right-M1"]
33
+
34
+ ATLAS_LABELS = {
35
+ "STN": "STN (DiFuMo-223)",
36
+ "L_GPi": "L-GPi (GT pGP-lh)",
37
+ "R_GPi": "R-GPi (GT pGP-rh)",
38
+ }
39
+
40
+ COLORS = {
41
+ "IEEG": "#1f77b4",
42
+ "LCMV": "#d62728",
43
+ "STN": "#ff7f0e",
44
+ "L_GPi": "#2ca02c",
45
+ "R_GPi": "#9467bd",
46
+ }
47
+
48
+ # Global Data Handles
49
+ ALL_IEEG_DATA = None
50
+ ALL_LCMV_DATA = None
51
+
52
+ # =============================================================================
53
+ # CORE LOGIC
54
+ # =============================================================================
55
+
56
+ def compute_psd(time_series, sfreq=SFREQ_DEFAULT, fmax=FMAX):
57
+ ts = np.real(time_series).astype(np.float64)
58
+ window_size = int(PSD_WINDOW_SEC * sfreq)
59
+ if len(ts) < window_size:
60
+ window_size = max(int(len(ts)*0.8), 100)
61
+
62
+ nyq = sfreq * 0.5
63
+ if nyq <= 0.5: nyq = 0.51
64
+
65
+ b, a = signal.butter(4, 0.5 / nyq, btype='high')
66
+ filtered = signal.filtfilt(b, a, ts)
67
+
68
+ freqs, psd = signal.welch(filtered, fs=sfreq, window='hann', nperseg=window_size,
69
+ noverlap=window_size // 2, detrend='constant')
70
+ mask = (freqs >= 1.0) & (freqs <= fmax)
71
+ freqs, psd = freqs[mask], psd[mask]
72
+
73
+ if len(freqs) == 0:
74
+ return np.array([1, 10]), np.log10(np.array([1e-10, 1e-10]) + 1e-12)
75
+
76
+ psd_log = np.log10(psd + 1e-12)
77
+ return freqs.astype(np.float32), psd_log.astype(np.float32)
78
+
79
+ def load_data():
80
+ global ALL_IEEG_DATA, ALL_LCMV_DATA
81
+ if ALL_IEEG_DATA is None or ALL_LCMV_DATA is None:
82
+ if not OUTPUT_IEEG.exists() or not OUTPUT_LCMV.exists():
83
+ raise FileNotFoundError("Consolidated files missing. Please run consolidation first.")
84
+
85
+ ALL_IEEG_DATA = np.load(OUTPUT_IEEG, allow_pickle=True)
86
+ ALL_LCMV_DATA = np.load(OUTPUT_LCMV, allow_pickle=True)
87
+
88
+ def get_consolidated_ieeg(subj_id, run_code):
89
+ global ALL_IEEG_DATA
90
+ meta_key = f"meta_{subj_id}_{run_code}"
91
+ if meta_key not in ALL_IEEG_DATA.files:
92
+ return None, None
93
+ meta = ALL_IEEG_DATA[meta_key].item()
94
+ channels = {}
95
+ prefix = f"{subj_id}_{run_code}_"
96
+ for key in ALL_IEEG_DATA.files:
97
+ if key.startswith(prefix) and key != meta_key:
98
+ channels[key.replace(prefix, "")] = ALL_IEEG_DATA[key]
99
+ return channels, meta
100
+
101
+ def get_consolidated_lcmv(subj_id):
102
+ global ALL_LCMV_DATA
103
+ meta_key = f"meta_{subj_id}"
104
+ if meta_key not in ALL_LCMV_DATA.files:
105
+ return None, None
106
+ meta = ALL_LCMV_DATA[meta_key].item()
107
+ rois = {}
108
+ prefix = f"{subj_id}_"
109
+ for key in ALL_LCMV_DATA.files:
110
+ if key.startswith(prefix) and key != meta_key:
111
+ rois[key.replace(prefix, "")] = ALL_LCMV_DATA[key]
112
+ return rois, meta
113
+
114
+ def find_channel(channels_dict, patterns):
115
+ if channels_dict is None:
116
+ return None, None
117
+ for pattern in patterns:
118
+ if pattern in channels_dict:
119
+ return pattern, channels_dict[pattern]
120
+ for key in channels_dict.keys():
121
+ if pattern.lower() in key.lower():
122
+ return key, channels_dict[key]
123
+ return None, None
124
+
125
+ def create_interactive_plot(roi_name, ieeg_signal, ieeg_sfreq, ch_used,
126
+ source_signal, source_sfreq, source_label, source_color,
127
+ subject_id, run_id):
128
+
129
+ freqs_ieeg, psd_ieeg = compute_psd(ieeg_signal, sfreq=ieeg_sfreq)
130
+ freqs_src, psd_src = compute_psd(source_signal, sfreq=source_sfreq)
131
+
132
+ fig = go.Figure()
133
+
134
+ fig.add_trace(go.Scatter(
135
+ x=freqs_ieeg, y=psd_ieeg,
136
+ mode='lines', name=f'iEEG ({ch_used})',
137
+ line=dict(color=COLORS["IEEG"], width=3),
138
+ hovertemplate=f'<b>iEEG</b><br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}<extra></extra>'
139
+ ))
140
+
141
+ fig.add_trace(go.Scatter(
142
+ x=freqs_src, y=psd_src,
143
+ mode='lines', name=source_label,
144
+ line=dict(color=source_color, width=3, dash='dash'),
145
+ hovertemplate=f'<b>{source_label}</b><br>Freq: %{{x:.2f}} Hz<br>PSD: %{{y:.2f}}<extra></extra>'
146
+ ))
147
+
148
+ shapes = []
149
+ n_bands = len(FREQ_BANDS)
150
+ band_colors = [f"rgba(31, 119, 180, {0.1 + (i/n_bands)*0.2})" for i in range(n_bands)]
151
+
152
+ for i, (band, (fmin, fmax)) in enumerate(FREQ_BANDS.items()):
153
+ band_low = max(fmin, min(freqs_ieeg))
154
+ band_high = min(fmax, max(freqs_ieeg))
155
+ if band_low < band_high:
156
+ shapes.append(dict(
157
+ type="rect", xref="x", yref="paper",
158
+ x0=band_low, x1=band_high, y0=0, y1=1,
159
+ fillcolor=band_colors[i], opacity=0.3, layer="below", line_width=0
160
+ ))
161
+
162
+ title_text = f"{subject_id} | Run: {run_id} | ROI: {roi_name}<br><sup>{source_label} vs iEEG</sup>"
163
+
164
+ fig.update_layout(
165
+ title=dict(text=title_text, font=dict(size=14, family="Arial")),
166
+ xaxis_title="Frequency (Hz)",
167
+ yaxis_title="PSD (log₁₀)",
168
+ xaxis=dict(range=[1, FMAX], type="linear"),
169
+ yaxis_type="linear",
170
+ hovermode="x unified",
171
+ legend=dict(x=0, y=1, bgcolor="rgba(255,255,255,0.8)"),
172
+ shapes=shapes,
173
+ template="plotly_white",
174
+ height=600,
175
+ margin=dict(l=50, r=50, t=60, b=50)
176
+ )
177
+
178
+ return fig
179
+
180
+ def generate_all_plots(subj_id, run_code):
181
+ """Generates all valid plots for a subject/run and returns a dictionary."""
182
+ try:
183
+ load_data()
184
+ except FileNotFoundError as e:
185
+ return {}, str(e)
186
+
187
+ cond = RUN_MAP.get(run_code, "unknown")
188
+ ieeg_ch, ieeg_meta = get_consolidated_ieeg(subj_id, run_code)
189
+ lcmv_rois, lcmv_meta = get_consolidated_lcmv(subj_id)
190
+
191
+ plots_dict = {}
192
+ logs = [f"Processing {subj_id} | Condition: {cond}"]
193
+
194
+ if ieeg_ch is None or lcmv_rois is None:
195
+ return plots_dict, f"No data found for {subj_id} (Run: {run_code})."
196
+
197
+ ieeg_sfreq = ieeg_meta.get('sfreq', SFREQ_DEFAULT)
198
+ lcmv_sfreq = lcmv_meta.get('sfreq', SFREQ_DEFAULT)
199
+
200
+ # Detect Electrodes
201
+ stn_l_ch, stn_l_sig = find_channel(ieeg_ch, STN_PATTERNS)
202
+ stn_r_ch, stn_r_sig = find_channel(ieeg_ch, [p.replace("-L","-R").replace("_L","_R") for p in STN_PATTERNS])
203
+ gpi_l_ch, gpi_l_sig = find_channel(ieeg_ch, GPI_PATTERNS)
204
+
205
+ gpi_r_ch, gpi_r_sig = None, None
206
+ if gpi_l_ch:
207
+ right_patterns = [gpi_l_ch.replace("L","R").replace("l","r").replace("lh","rh")]
208
+ right_patterns.extend([p.replace("-L","-R").replace("_L","_R") for p in GPI_PATTERNS])
209
+ gpi_r_ch, gpi_r_sig = find_channel(ieeg_ch, right_patterns)
210
+
211
+ m1_l_ch, m1_l_sig = find_channel(ieeg_ch, M1_L_PATTERNS)
212
+ m1_r_ch, m1_r_sig = find_channel(ieeg_ch, M1_R_PATTERNS)
213
+
214
+ def add_plot(name, sig, ch, roi_key, label, color):
215
+ if sig is not None and ch is not None and roi_key in lcmv_rois:
216
+ fig = create_interactive_plot(name, sig, ieeg_sfreq, ch, lcmv_rois[roi_key], lcmv_sfreq, label, color, subj_id, run_code)
217
+ key = f"{name} vs {label}"
218
+ plots_dict[key] = fig
219
+ logs.append(f"✅ Found: {key}")
220
+
221
+ # M1
222
+ add_plot("L_M1", m1_l_sig, m1_l_ch, f"L_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
223
+ add_plot("R_M1", m1_r_sig, m1_r_ch, f"R_M1_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
224
+
225
+ # STN
226
+ if stn_l_sig is not None:
227
+ add_plot("L_STN", stn_l_sig, stn_l_ch, f"L_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
228
+ if f"STN_{cond}" in lcmv_rois:
229
+ add_plot("L_STN", stn_l_sig, stn_l_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])
230
+
231
+ if stn_r_sig is not None:
232
+ add_plot("R_STN", stn_r_sig, stn_r_ch, f"R_STN_{cond}", "LCMV MNI voxel", COLORS["LCMV"])
233
+ if f"STN_{cond}" in lcmv_rois:
234
+ add_plot("R_STN", stn_r_sig, stn_r_ch, f"STN_{cond}", ATLAS_LABELS["STN"], COLORS["STN"])
235
+
236
+ # GPi (Fallback)
237
+ if gpi_l_sig is not None and stn_l_sig is None:
238
+ add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
239
+ if f"L_GPi_{cond}" in lcmv_rois:
240
+ add_plot("L_GPi", gpi_l_sig, gpi_l_ch, f"L_GPi_{cond}", ATLAS_LABELS["L_GPi"], COLORS["L_GPi"])
241
+
242
+ if gpi_r_sig is not None and stn_r_sig is None:
243
+ add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", "LCMV MNI voxel (GPi)", COLORS["LCMV"])
244
+ if f"R_GPi_{cond}" in lcmv_rois:
245
+ add_plot("R_GPi", gpi_r_sig, gpi_r_ch, f"R_GPi_{cond}", ATLAS_LABELS["R_GPi"], COLORS["R_GPi"])
246
+
247
+ if not plots_dict:
248
+ logs.append("⚠️ No matching electrode/ROI pairs found.")
249
+
250
+ return plots_dict, "\n".join(logs)
251
+
252
+ def get_available_subjects():
253
+ if not OUTPUT_LCMV.exists():
254
+ return []
255
+ data = np.load(OUTPUT_LCMV, allow_pickle=True)
256
+ subjects = set()
257
+ for key in data.files:
258
+ if key.startswith("meta_"):
259
+ subjects.add(key.replace("meta_", ""))
260
+ return sorted(list(subjects))
261
+
262
+ # =============================================================================
263
+ # GRADIO INTERFACE
264
+ # =============================================================================
265
+
266
+ # Note: 'theme' parameter removed from constructor for Gradio 5.0+ compatibility
267
+ with gr.Blocks(title="Interactive iEEG-LCMV Viewer") as demo:
268
+ gr.Markdown("# Interactive iEEG & LCMV Viewer")
269
+ gr.Markdown("Select a subject and condition to generate available comparisons. Then choose specific plots to visualize.")
270
+
271
+ # State to store generated plots for the current selection
272
+ current_plots_state = gr.State({})
273
+
274
+ with gr.Row():
275
+ with gr.Column(scale=1):
276
+ gr.Markdown("### 1. Select Data")
277
+ btn_refresh = gr.Button("🔄 Refresh Subjects")
278
+ subject_dropdown = gr.Dropdown(label="Subject", choices=[], interactive=True)
279
+ run_dropdown = gr.Dropdown(
280
+ label="Condition",
281
+ choices=["c", "o", "l", "r"],
282
+ value="c",
283
+ info="c: Eyes Closed, o: Eyes Open, l: Left Hand, r: Right Hand"
284
+ )
285
+ btn_generate = gr.Button("🔍 Find Available Plots", variant="primary")
286
+
287
+ gr.Markdown("### 2. Choose Visualization")
288
+ plot_selector = gr.Dropdown(label="Select Plot to View", choices=[], interactive=True)
289
+
290
+ gr.Markdown("### Log")
291
+ val_log = gr.Textbox(label="Status", lines=6, interactive=False)
292
+
293
+ with gr.Column(scale=3):
294
+ gr.Markdown("### PSD Comparison")
295
+ plot_display = gr.Plot(label="Interactive Plot", show_label=False)
296
+
297
+ # Event Handlers
298
+
299
+ def refresh_subjects():
300
+ subs = get_available_subjects()
301
+ return gr.Dropdown(choices=subs, value=subs[0] if subs else None)
302
+
303
+ def process_and_update_dropdown(subj, run):
304
+ """Generates plots, updates state, log, dropdown options, and shows the first plot."""
305
+ if not subj:
306
+ return {}, "Please select a subject.", gr.Dropdown(choices=[], value=None), None
307
+
308
+ plots_dict, log_msg = generate_all_plots(subj, run)
309
+ choices = list(plots_dict.keys())
310
+
311
+ if not choices:
312
+ return plots_dict, log_msg, gr.Dropdown(choices=[], value=None), None
313
+
314
+ initial_val = choices[0]
315
+ initial_fig = plots_dict[initial_val]
316
+
317
+ return plots_dict, log_msg, gr.Dropdown(choices=choices, value=initial_val), initial_fig
318
+
319
+ def on_plot_selection(plots_dict, selected_key):
320
+ """Updates only the plot when dropdown changes."""
321
+ if not plots_dict or not selected_key:
322
+ return None
323
+ return plots_dict.get(selected_key)
324
+
325
+ # Wire up events
326
+ btn_refresh.click(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])
327
+ demo.load(fn=refresh_subjects, inputs=[], outputs=[subject_dropdown])
328
+
329
+ # When Generate is clicked: Update State, Log, Dropdown, AND Plot
330
+ btn_generate.click(
331
+ fn=process_and_update_dropdown,
332
+ inputs=[subject_dropdown, run_dropdown],
333
+ outputs=[current_plots_state, val_log, plot_selector, plot_display]
334
+ )
335
+
336
+ # When Dropdown changes: Update Plot Display only
337
+ plot_selector.change(
338
+ fn=on_plot_selection,
339
+ inputs=[current_plots_state, plot_selector],
340
+ outputs=[plot_display]
341
+ )
342
+
343
+ if __name__ == "__main__":
344
+ # Note: 'theme' parameter moved to launch() for Gradio 5.0+
345
+ demo.launch(theme=gr.themes.Soft())