JayLacoma commited on
Commit
12b5269
Β·
verified Β·
1 Parent(s): 6bcc8bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -104
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  from pathlib import Path
5
 
6
  # -----------------------------
7
- # Configuration
8
  # -----------------------------
9
 
10
  FREQ_BANDS = {
@@ -17,27 +17,18 @@ FREQ_BANDS = {
17
  'High_Gamma': (50, 100)
18
  }
19
 
 
20
  condition_labels = {
21
- 'bima_activity_on': 'Bima Activity (ON)',
22
- 'bima_activity_off': 'Bima Activity (OFF)',
23
- 'hands_move_on': 'Left Hand Moves (ON)',
24
- 'hands_move_off': 'Left Hand Moves (OFF)',
25
- 'rest_eyes_closed_on': 'Rest Eyes Closed (ON)',
26
- 'rest_eyes_closed_off': 'Rest Eyes Closed (OFF)',
27
- 'rest_eyes_open_on': 'Rest Eyes Open (ON)',
28
- 'rest_eyes_open_off': 'Rest Eyes Open (OFF)'
29
  }
30
 
31
- # 🎨 Separate colors for ON and OFF states
32
  condition_colors = {
33
- 'bima_activity_on': '#FF6B35', # Fiery Orange
34
- 'bima_activity_off': '#4A90E2', # Cool Blue
35
- 'hands_move_on': '#E53E3E', # Bold Red
36
- 'hands_move_off': '#38B2AC', # Teal
37
- 'rest_eyes_closed_on': '#805AD5', # Purple
38
- 'rest_eyes_closed_off': '#D69E2E', # Golden Yellow
39
- 'rest_eyes_open_on': '#38A169', # Green
40
- 'rest_eyes_open_off': '#D6336C', # Magenta
41
  }
42
 
43
  band_colors = {
@@ -49,7 +40,7 @@ band_colors = {
49
  METHOD_DISPLAY = {"multitaper": "Multitaper", "welch": "Welch"}
50
 
51
  # -----------------------------
52
- # PSD Analyzer Class
53
  # -----------------------------
54
 
55
  class PSDAnalyzer:
@@ -61,9 +52,9 @@ class PSDAnalyzer:
61
  loaded = np.load(data_file, allow_pickle=True)
62
  self.all_data = loaded['data'].tolist()
63
 
64
- # Extract metadata
65
  self.subjects_list = sorted({d['subject'] for d in self.all_data})
66
- self.conditions_list = sorted({f"{d['condition']}_{d['med_state']}" for d in self.all_data})
67
  self.region_names = self.all_data[0]['region_names']
68
 
69
  # Precompute traces
@@ -84,7 +75,7 @@ class PSDAnalyzer:
84
 
85
  for entry in self.all_data:
86
  subject = entry['subject']
87
- cond_key = f"{entry['condition']}_{entry['med_state']}"
88
  for result in entry['results']:
89
  region_name = result['region_name']
90
  for method in ['multitaper', 'welch']:
@@ -98,9 +89,8 @@ class PSDAnalyzer:
98
  traces[region_name][cond_key]['subjects'].append(subject)
99
  return traces
100
 
101
-
102
  def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands, align_by_delta):
103
- """Generate Plotly figure with optional Delta-band alignment."""
104
  fig = go.Figure()
105
  freq_range = (1, 100)
106
 
@@ -114,18 +104,21 @@ class PSDAnalyzer:
114
  plotted_conditions = []
115
 
116
  for condition in selected_conditions:
 
 
 
117
  data = self.traces[region][condition][method]
118
  if not data:
119
  continue
120
 
121
- # Filter data within global frequency range
122
  filtered_data = []
123
  for item in data:
124
- mask_full = (item['freqs'] >= freq_range[0]) & (item['freqs'] <= freq_range[1])
125
- if np.any(mask_full):
126
  filtered_data.append({
127
- 'freqs': item['freqs'][mask_full],
128
- 'psd': item['psd'][mask_full],
129
  'subject': item['subject']
130
  })
131
 
@@ -139,20 +132,16 @@ class PSDAnalyzer:
139
  freqs = filtered_data[0]['freqs']
140
  psds = np.array([item['psd'] for item in filtered_data])
141
 
142
- # Apply Delta band alignment?
143
  if align_by_delta:
144
  delta_mask = (freqs >= 1) & (freqs <= 4)
145
  if np.any(delta_mask):
146
  delta_mean = np.mean(psds[:, delta_mask], axis=1, keepdims=True)
147
- # Avoid division by zero
148
  delta_mean[delta_mean == 0] = 1e-9
149
  psds = psds / delta_mean
150
  else:
151
- # Fallback: normalize by global mean
152
  psds = psds / (psds.mean(axis=1, keepdims=True) + 1e-9)
153
 
154
  mean_psd = np.mean(psds, axis=0)
155
-
156
  label = f"{condition_labels.get(condition, condition)} (Mean)"
157
  if align_by_delta:
158
  label += " [Ξ”-norm]"
@@ -172,13 +161,12 @@ class PSDAnalyzer:
172
  for item in subject_data:
173
  freqs, psd = item['freqs'], item['psd']
174
 
175
- # Apply Delta band normalization per subject
176
  if align_by_delta:
177
  delta_mask = (freqs >= 1) & (freqs <= 4)
178
  if np.any(delta_mask):
179
  psd = psd / np.mean(psd[delta_mask])
180
  else:
181
- psd = psd / (np.mean(psd) + 1e-9) # Prevent division by zero
182
 
183
  label = f"{condition_labels.get(condition, condition)} - {item['subject']}"
184
  if align_by_delta:
@@ -201,65 +189,42 @@ class PSDAnalyzer:
201
  )
202
  return fig
203
 
204
- # ==============================
205
- # Frequency Band Visualization
206
- # ==============================
207
  if show_bands:
208
- # Add shaded band regions using paper coordinates (so they don't affect scaling)
209
  for band, (low, high) in FREQ_BANDS.items():
210
- # Only show band if it overlaps with visible freq range
211
  if high < freq_range[0] or low > freq_range[1]:
212
  continue
213
-
214
  band_low = max(low, freq_range[0])
215
  band_high = min(high, freq_range[1])
216
-
217
- # Add shaded rectangle in paper space (full height, behind data)
218
  fig.add_shape(
219
  type="rect",
220
- x0=band_low,
221
- x1=band_high,
222
- y0=0,
223
- y1=1,
224
- xref="x",
225
- yref="paper",
226
  fillcolor=band_colors[band],
227
  opacity=0.15,
228
  layer="below",
229
  line_width=0,
230
  )
231
-
232
- # Add label above the plot (outside the plotting area)
233
  center_x = (band_low + band_high) / 2
234
  fig.add_annotation(
235
- x=center_x,
236
- y=1.02, # Just above the top of the plot
237
  text=band,
238
  showarrow=False,
239
  font=dict(size=9, color="dimgray"),
240
- xanchor="center",
241
- yanchor="bottom",
242
- xref="x",
243
- yref="paper",
244
  opacity=0.85,
245
  )
246
 
247
- # Add vertical dotted lines at band boundaries (only once per edge)
248
  seen_edges = set()
249
  for low, high in FREQ_BANDS.values():
250
  for edge in [low, high]:
251
  if freq_range[0] < edge < freq_range[1] and edge not in seen_edges:
252
- fig.add_vline(
253
- x=edge,
254
- line=dict(color="lightgray", width=1, dash="dot"),
255
- opacity=0.5,
256
- layer="below"
257
- )
258
  seen_edges.add(edge)
259
 
260
- # ==============================
261
- # βœ… FIXED: Lock axis ranges to prevent distortion
262
- # ==============================
263
  yaxis_title = "Power"
264
  if align_by_delta:
265
  yaxis_title = "Power (norm. to Delta)"
@@ -281,15 +246,14 @@ class PSDAnalyzer:
281
  ),
282
  margin=dict(r=160, t=60, b=80, l=60),
283
  hovermode='x unified',
284
- # πŸ”’ CRITICAL: Lock axis ranges to prevent visual distortion
285
  xaxis=dict(
286
  range=[freq_range[0], freq_range[1]],
287
- fixedrange=True, # Prevents zoom/pan from UI, but more importantly β€” stops auto-resize
288
  showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
289
  showline=True, linewidth=1, linecolor='gray'
290
  ),
291
  yaxis=dict(
292
- fixedrange=True, # Stops Plotly from auto-resizing y-axis when shapes are added
293
  showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
294
  showline=True, linewidth=1, linecolor='gray'
295
  )
@@ -310,14 +274,14 @@ def create_app(analyzer: PSDAnalyzer):
310
  ) as demo:
311
  gr.Markdown("# πŸ“Š Power Spectral Density Analysis Dashboard")
312
  gr.Markdown("""
313
- > **Color Guide**:
314
- > - **Bold/Dark colors** = ON state
315
- > - **Lighter colors** = OFF state
 
316
  >
317
  > πŸ’‘ Tip: Use _'Align by Delta'_ to compare spectral shapes across conditions.
318
  """)
319
 
320
- # TOP ROW: Data Selection
321
  with gr.Row():
322
  with gr.Column(scale=1):
323
  subjects = gr.CheckboxGroup(
@@ -334,40 +298,26 @@ def create_app(analyzer: PSDAnalyzer):
334
  interactive=True
335
  )
336
 
337
- # CENTER: BIG PSD PLOT
338
  with gr.Row():
339
- plot_output = gr.Plot(
340
- label="PSD Plot",
341
- elem_classes="big-plot"
342
- )
343
 
344
- # BOTTOM ROW: Display Options
345
  with gr.Row():
346
- with gr.Column():
347
- log_scale = gr.Checkbox(value=True, label="Use Log Scale (Y-axis)")
348
- with gr.Column():
349
- show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
350
- with gr.Column():
351
- align_by_delta = gr.Checkbox(value=False, label="Align by Delta (1–4 Hz)")
352
-
353
- # BOTTOM ROW: Method and Region
354
  with gr.Row():
355
- with gr.Column(min_width=200):
356
- method = gr.Radio(
357
- choices=["multitaper", "welch"],
358
- value="multitaper",
359
- label="PSD Method",
360
- interactive=True
361
- )
362
- with gr.Column(min_width=200):
363
- region = gr.Dropdown(
364
- choices=analyzer.region_names,
365
- value=analyzer.region_names[0],
366
- label="Brain Region",
367
- interactive=True
368
- )
369
 
370
- # Update plot whenever any input changes
371
  inputs = [region, method, conditions, subjects, log_scale, show_bands, align_by_delta]
372
  for comp in inputs:
373
  comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
@@ -381,7 +331,7 @@ def create_app(analyzer: PSDAnalyzer):
381
  # -----------------------------
382
 
383
  if __name__ == "__main__":
384
- DATA_PATH = "psd_voxel_all_data.npz"
385
  analyzer = PSDAnalyzer(DATA_PATH)
386
  app = create_app(analyzer)
387
  app.launch(share=True, show_error=True)
 
4
  from pathlib import Path
5
 
6
  # -----------------------------
7
+ # Configuration (Updated for YOUR data)
8
  # -----------------------------
9
 
10
  FREQ_BANDS = {
 
17
  'High_Gamma': (50, 100)
18
  }
19
 
20
+ # βœ… Only 3 real conditions β€” no ON/OFF
21
  condition_labels = {
22
+ 'bima_activity': 'Bima Task',
23
+ 'rest_eyes_open': 'Rest (Eyes Open)',
24
+ 'rest_eyes_closed': 'Rest (Eyes Closed)'
 
 
 
 
 
25
  }
26
 
27
+ # 🎨 One distinct color per condition
28
  condition_colors = {
29
+ 'bima_activity': '#FF6B35', # Fiery Orange
30
+ 'rest_eyes_open': '#38A169', # Green
31
+ 'rest_eyes_closed': '#805AD5', # Purple
 
 
 
 
 
32
  }
33
 
34
  band_colors = {
 
40
  METHOD_DISPLAY = {"multitaper": "Multitaper", "welch": "Welch"}
41
 
42
  # -----------------------------
43
+ # PSD Analyzer Class (Fixed for your data)
44
  # -----------------------------
45
 
46
  class PSDAnalyzer:
 
52
  loaded = np.load(data_file, allow_pickle=True)
53
  self.all_data = loaded['data'].tolist()
54
 
55
+ # Extract clean metadata β€” ignore med_state (always 'unknown')
56
  self.subjects_list = sorted({d['subject'] for d in self.all_data})
57
+ self.conditions_list = sorted({d['condition'] for d in self.all_data}) # βœ… no _unknown
58
  self.region_names = self.all_data[0]['region_names']
59
 
60
  # Precompute traces
 
75
 
76
  for entry in self.all_data:
77
  subject = entry['subject']
78
+ cond_key = entry['condition'] # βœ… clean key, no med_state
79
  for result in entry['results']:
80
  region_name = result['region_name']
81
  for method in ['multitaper', 'welch']:
 
89
  traces[region_name][cond_key]['subjects'].append(subject)
90
  return traces
91
 
 
92
  def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands, align_by_delta):
93
+ """Generate Plotly figure."""
94
  fig = go.Figure()
95
  freq_range = (1, 100)
96
 
 
104
  plotted_conditions = []
105
 
106
  for condition in selected_conditions:
107
+ if condition not in self.traces.get(region, {}):
108
+ continue
109
+
110
  data = self.traces[region][condition][method]
111
  if not data:
112
  continue
113
 
114
+ # Filter to global freq range
115
  filtered_data = []
116
  for item in data:
117
+ mask = (item['freqs'] >= freq_range[0]) & (item['freqs'] <= freq_range[1])
118
+ if np.any(mask):
119
  filtered_data.append({
120
+ 'freqs': item['freqs'][mask],
121
+ 'psd': item['psd'][mask],
122
  'subject': item['subject']
123
  })
124
 
 
132
  freqs = filtered_data[0]['freqs']
133
  psds = np.array([item['psd'] for item in filtered_data])
134
 
 
135
  if align_by_delta:
136
  delta_mask = (freqs >= 1) & (freqs <= 4)
137
  if np.any(delta_mask):
138
  delta_mean = np.mean(psds[:, delta_mask], axis=1, keepdims=True)
 
139
  delta_mean[delta_mean == 0] = 1e-9
140
  psds = psds / delta_mean
141
  else:
 
142
  psds = psds / (psds.mean(axis=1, keepdims=True) + 1e-9)
143
 
144
  mean_psd = np.mean(psds, axis=0)
 
145
  label = f"{condition_labels.get(condition, condition)} (Mean)"
146
  if align_by_delta:
147
  label += " [Ξ”-norm]"
 
161
  for item in subject_data:
162
  freqs, psd = item['freqs'], item['psd']
163
 
 
164
  if align_by_delta:
165
  delta_mask = (freqs >= 1) & (freqs <= 4)
166
  if np.any(delta_mask):
167
  psd = psd / np.mean(psd[delta_mask])
168
  else:
169
+ psd = psd / (np.mean(psd) + 1e-9)
170
 
171
  label = f"{condition_labels.get(condition, condition)} - {item['subject']}"
172
  if align_by_delta:
 
189
  )
190
  return fig
191
 
192
+ # Frequency band shading
 
 
193
  if show_bands:
 
194
  for band, (low, high) in FREQ_BANDS.items():
 
195
  if high < freq_range[0] or low > freq_range[1]:
196
  continue
 
197
  band_low = max(low, freq_range[0])
198
  band_high = min(high, freq_range[1])
 
 
199
  fig.add_shape(
200
  type="rect",
201
+ x0=band_low, x1=band_high,
202
+ y0=0, y1=1,
203
+ xref="x", yref="paper",
 
 
 
204
  fillcolor=band_colors[band],
205
  opacity=0.15,
206
  layer="below",
207
  line_width=0,
208
  )
 
 
209
  center_x = (band_low + band_high) / 2
210
  fig.add_annotation(
211
+ x=center_x, y=1.02,
 
212
  text=band,
213
  showarrow=False,
214
  font=dict(size=9, color="dimgray"),
215
+ xanchor="center", yanchor="bottom",
216
+ xref="x", yref="paper",
 
 
217
  opacity=0.85,
218
  )
219
 
 
220
  seen_edges = set()
221
  for low, high in FREQ_BANDS.values():
222
  for edge in [low, high]:
223
  if freq_range[0] < edge < freq_range[1] and edge not in seen_edges:
224
+ fig.add_vline(x=edge, line=dict(color="lightgray", dash="dot"), opacity=0.5)
 
 
 
 
 
225
  seen_edges.add(edge)
226
 
227
+ # Final layout
 
 
228
  yaxis_title = "Power"
229
  if align_by_delta:
230
  yaxis_title = "Power (norm. to Delta)"
 
246
  ),
247
  margin=dict(r=160, t=60, b=80, l=60),
248
  hovermode='x unified',
 
249
  xaxis=dict(
250
  range=[freq_range[0], freq_range[1]],
251
+ fixedrange=True,
252
  showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
253
  showline=True, linewidth=1, linecolor='gray'
254
  ),
255
  yaxis=dict(
256
+ fixedrange=True,
257
  showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
258
  showline=True, linewidth=1, linecolor='gray'
259
  )
 
274
  ) as demo:
275
  gr.Markdown("# πŸ“Š Power Spectral Density Analysis Dashboard")
276
  gr.Markdown("""
277
+ > **Conditions**:
278
+ > - πŸ”₯ **Bima Task**
279
+ > - 🟒 **Rest (Eyes Open)**
280
+ > - 🟣 **Rest (Eyes Closed)**
281
  >
282
  > πŸ’‘ Tip: Use _'Align by Delta'_ to compare spectral shapes across conditions.
283
  """)
284
 
 
285
  with gr.Row():
286
  with gr.Column(scale=1):
287
  subjects = gr.CheckboxGroup(
 
298
  interactive=True
299
  )
300
 
 
301
  with gr.Row():
302
+ plot_output = gr.Plot(label="PSD Plot", elem_classes="big-plot")
 
 
 
303
 
 
304
  with gr.Row():
305
+ log_scale = gr.Checkbox(value=True, label="Use Log Scale (Y-axis)")
306
+ show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
307
+ align_by_delta = gr.Checkbox(value=False, label="Align by Delta (1–4 Hz)")
308
+
 
 
 
 
309
  with gr.Row():
310
+ method = gr.Radio(
311
+ choices=["multitaper", "welch"],
312
+ value="multitaper",
313
+ label="PSD Method"
314
+ )
315
+ region = gr.Dropdown(
316
+ choices=analyzer.region_names,
317
+ value=analyzer.region_names[0],
318
+ label="Brain Region"
319
+ )
 
 
 
 
320
 
 
321
  inputs = [region, method, conditions, subjects, log_scale, show_bands, align_by_delta]
322
  for comp in inputs:
323
  comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
 
331
  # -----------------------------
332
 
333
  if __name__ == "__main__":
334
+ DATA_PATH = "/home/jaizor/jaizor/xtra/derivatives/psd_voxel_cache/psd_voxel_all_data.npz"
335
  analyzer = PSDAnalyzer(DATA_PATH)
336
  app = create_app(analyzer)
337
  app.launch(share=True, show_error=True)