JayLacoma commited on
Commit
77dbe6a
Β·
verified Β·
1 Parent(s): 9a4f177

Delete app2.py

Browse files
Files changed (1) hide show
  1. app2.py +0 -337
app2.py DELETED
@@ -1,337 +0,0 @@
1
- import plotly.graph_objects as go
2
- import numpy as np
3
- import gradio as gr
4
- from pathlib import Path
5
-
6
- # -----------------------------
7
- # Configuration (Updated for YOUR data)
8
- # -----------------------------
9
-
10
- FREQ_BANDS = {
11
- 'Delta': (1, 4),
12
- 'Theta': (4, 8),
13
- 'Alpha': (8, 12),
14
- 'Low_Beta': (12, 20),
15
- 'High_Beta': (20, 30),
16
- 'Low_Gamma': (30, 50),
17
- 'High_Gamma': (50, 100)
18
- }
19
-
20
- # βœ… Only 3 real conditions β€” no ON/OFF
21
- condition_labels = {
22
- 'bima_activity': 'Bima Task',
23
- 'rest_eyes_open': 'Rest (Eyes Open)',
24
- 'rest_eyes_closed': 'Rest (Eyes Closed)'
25
- }
26
-
27
- # 🎨 One distinct color per condition
28
- condition_colors = {
29
- 'bima_activity': 'red', # Fiery
30
- 'rest_eyes_open': '#38A169', # Green
31
- 'rest_eyes_closed': '#805AD5', # Purple
32
- }
33
-
34
- band_colors = {
35
- 'Delta': 'lightsalmon', 'Theta': 'wheat', 'Alpha': 'mediumpurple',
36
- 'Low_Beta': 'skyblue', 'High_Beta': 'lightcoral',
37
- 'Low_Gamma': 'lightgreen', 'High_Gamma': 'plum'
38
- }
39
-
40
- METHOD_DISPLAY = {"multitaper": "Multitaper", "welch": "Welch"}
41
-
42
- # -----------------------------
43
- # PSD Analyzer Class (Fixed for your data)
44
- # -----------------------------
45
-
46
- class PSDAnalyzer:
47
- def __init__(self, data_file):
48
- data_file = Path(data_file)
49
- if not data_file.exists():
50
- raise FileNotFoundError(f"Data file not found: {data_file}")
51
-
52
- loaded = np.load(data_file, allow_pickle=True)
53
- self.all_data = loaded['data'].tolist()
54
-
55
- # Extract clean metadata β€” ignore med_state (always 'unknown')
56
- self.subjects_list = sorted({d['subject'] for d in self.all_data})
57
- self.conditions_list = sorted({d['condition'] for d in self.all_data}) # βœ… no _unknown
58
- self.region_names = self.all_data[0]['region_names']
59
-
60
- # Precompute traces
61
- self.traces = self._precompute_traces()
62
-
63
- print(f"βœ… Loaded {len(self.all_data)} subject-condition combinations")
64
- print(f"Subjects: {len(self.subjects_list)} | Conditions: {len(self.conditions_list)} | Regions: {len(self.region_names)}")
65
-
66
- def _precompute_traces(self):
67
- """Precompute traces grouped by region β†’ condition β†’ method."""
68
- traces = {
69
- region: {
70
- cond: {'multitaper': [], 'welch': [], 'subjects': []}
71
- for cond in self.conditions_list
72
- }
73
- for region in self.region_names
74
- }
75
-
76
- for entry in self.all_data:
77
- subject = entry['subject']
78
- cond_key = entry['condition'] # βœ… clean key, no med_state
79
- for result in entry['results']:
80
- region_name = result['region_name']
81
- for method in ['multitaper', 'welch']:
82
- psd_data = result['psd_dual'].get(method)
83
- if psd_data is not None:
84
- freqs, psd = psd_data['freqs'], psd_data['psd']
85
- traces[region_name][cond_key][method].append({
86
- 'freqs': freqs, 'psd': psd, 'subject': subject
87
- })
88
- if subject not in traces[region_name][cond_key]['subjects']:
89
- traces[region_name][cond_key]['subjects'].append(subject)
90
- return traces
91
-
92
- def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands, align_by_gamma):
93
- """Generate Plotly figure."""
94
- fig = go.Figure()
95
- freq_range = (1, 100)
96
-
97
- if not selected_conditions:
98
- fig.add_annotation(
99
- text="Please select at least one condition",
100
- x=0.5, y=0.5, showarrow=False, xref="paper", yref="paper"
101
- )
102
- return fig
103
-
104
- plotted_conditions = []
105
-
106
- for condition in selected_conditions:
107
- if condition not in self.traces.get(region, {}):
108
- continue
109
-
110
- data = self.traces[region][condition][method]
111
- if not data:
112
- continue
113
-
114
- # Filter to global freq range
115
- filtered_data = []
116
- for item in data:
117
- mask = (item['freqs'] >= freq_range[0]) & (item['freqs'] <= freq_range[1])
118
- if np.any(mask):
119
- filtered_data.append({
120
- 'freqs': item['freqs'][mask],
121
- 'psd': item['psd'][mask],
122
- 'subject': item['subject']
123
- })
124
-
125
- if not filtered_data:
126
- continue
127
-
128
- color = condition_colors.get(condition, 'gray')
129
-
130
- if "All Subjects" in selected_subjects:
131
- # Mean across subjects
132
- freqs = filtered_data[0]['freqs']
133
- psds = np.array([item['psd'] for item in filtered_data])
134
-
135
- if align_by_gamma:
136
- delta_mask = (freqs >= 30) & (freqs <= 100)
137
- if np.any(delta_mask):
138
- delta_mean = np.mean(psds[:, delta_mask], axis=1, keepdims=True)
139
- delta_mean[delta_mean == 0] = 1e-9
140
- psds = psds / delta_mean
141
- else:
142
- psds = psds / (psds.mean(axis=1, keepdims=True) + 1e-9)
143
-
144
- mean_psd = np.mean(psds, axis=0)
145
- label = f"{condition_labels.get(condition, condition)} (Mean)"
146
- if align_by_gamma:
147
- label += " [Ξ”-norm]"
148
- fig.add_trace(go.Scatter(
149
- x=freqs, y=mean_psd,
150
- mode='lines',
151
- name=label,
152
- line=dict(color=color, width=3),
153
- opacity=0.9
154
- ))
155
- else:
156
- available_subjects = [s for s in selected_subjects if s != "All Subjects"]
157
- subject_data = [item for item in filtered_data if item['subject'] in available_subjects]
158
- if not subject_data:
159
- continue
160
-
161
- for item in subject_data:
162
- freqs, psd = item['freqs'], item['psd']
163
-
164
- if align_by_gamma:
165
- delta_mask = (freqs >= 1) & (freqs <= 4)
166
- if np.any(delta_mask):
167
- psd = psd / np.mean(psd[delta_mask])
168
- else:
169
- psd = psd / (np.mean(psd) + 1e-9)
170
-
171
- label = f"{condition_labels.get(condition, condition)} - {item['subject']}"
172
- if align_by_gamma:
173
- label += " [Ξ”-norm]"
174
-
175
- fig.add_trace(go.Scatter(
176
- x=freqs, y=psd,
177
- mode='lines',
178
- name=label,
179
- line=dict(color=color, width=2),
180
- opacity=0.7
181
- ))
182
-
183
- plotted_conditions.append(condition)
184
-
185
- if not plotted_conditions:
186
- fig.add_annotation(
187
- text="No data for selected conditions/subjects",
188
- x=0.5, y=0.5, showarrow=False, xref="paper", yref="paper"
189
- )
190
- return fig
191
-
192
- # Frequency band shading
193
- if show_bands:
194
- for band, (low, high) in FREQ_BANDS.items():
195
- if high < freq_range[0] or low > freq_range[1]:
196
- continue
197
- band_low = max(low, freq_range[0])
198
- band_high = min(high, freq_range[1])
199
- fig.add_shape(
200
- type="rect",
201
- x0=band_low, x1=band_high,
202
- y0=0, y1=1,
203
- xref="x", yref="paper",
204
- fillcolor=band_colors[band],
205
- opacity=0.15,
206
- layer="below",
207
- line_width=0,
208
- )
209
- center_x = (band_low + band_high) / 2
210
- fig.add_annotation(
211
- x=center_x, y=1.02,
212
- text=band,
213
- showarrow=False,
214
- font=dict(size=9, color="dimgray"),
215
- xanchor="center", yanchor="bottom",
216
- xref="x", yref="paper",
217
- opacity=0.85,
218
- )
219
-
220
- seen_edges = set()
221
- for low, high in FREQ_BANDS.values():
222
- for edge in [low, high]:
223
- if freq_range[0] < edge < freq_range[1] and edge not in seen_edges:
224
- fig.add_vline(x=edge, line=dict(color="lightgray", dash="dot"), opacity=0.5)
225
- seen_edges.add(edge)
226
-
227
- # Final layout
228
- yaxis_title = "Power"
229
- if align_by_gamma:
230
- yaxis_title = "Power (norm. to gamma)"
231
- if log_scale:
232
- yaxis_title += " [log]"
233
- elif log_scale:
234
- yaxis_title += " (log)"
235
-
236
- fig.update_layout(
237
- title=f"PSD - {region} | {METHOD_DISPLAY[method]}",
238
- xaxis_title="Frequency (Hz)",
239
- yaxis_title=yaxis_title,
240
- yaxis_type="log" if log_scale else "linear",
241
- template="plotly_white",
242
- height=650,
243
- legend=dict(
244
- y=0.99, yanchor="top", x=1.02, xanchor="left",
245
- bgcolor="rgba(255,255,255,0.8)", font_size=11
246
- ),
247
- margin=dict(r=160, t=60, b=80, l=60),
248
- hovermode='x unified',
249
- xaxis=dict(
250
- range=[freq_range[0], freq_range[1]],
251
- fixedrange=True,
252
- showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
253
- showline=True, linewidth=1, linecolor='gray'
254
- ),
255
- yaxis=dict(
256
- fixedrange=True,
257
- showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
258
- showline=True, linewidth=1, linecolor='gray'
259
- )
260
- )
261
-
262
- return fig
263
-
264
-
265
- # -----------------------------
266
- # Gradio Interface
267
- # -----------------------------
268
-
269
- def create_app(analyzer: PSDAnalyzer):
270
- with gr.Blocks(
271
- theme=gr.themes.Soft(),
272
- title="PSD Analysis Dashboard",
273
- css=".big-plot { height: 700px; }"
274
- ) as demo:
275
- gr.Markdown("# πŸ“Š Power Spectral Density Analysis Dashboard")
276
- gr.Markdown("""
277
- > **Conditions**:
278
- > - πŸ”₯ **Bima Task**
279
- > - 🟒 **Rest (Eyes Open)**
280
- > - 🟣 **Rest (Eyes Closed)**
281
- >
282
- > πŸ’‘ Tip: Use _'Align by Delta'_ to compare spectral shapes across conditions.
283
- """)
284
-
285
- with gr.Row():
286
- with gr.Column(scale=1):
287
- subjects = gr.CheckboxGroup(
288
- choices=["All Subjects"] + analyzer.subjects_list,
289
- value=["All Subjects"],
290
- label="Subjects",
291
- interactive=True
292
- )
293
- with gr.Column(scale=1):
294
- conditions = gr.CheckboxGroup(
295
- choices=analyzer.conditions_list,
296
- value=analyzer.conditions_list[:2],
297
- label="Conditions",
298
- interactive=True
299
- )
300
-
301
- with gr.Row():
302
- plot_output = gr.Plot(label="PSD Plot", elem_classes="big-plot")
303
-
304
- with gr.Row():
305
- log_scale = gr.Checkbox(value=True, label="Use Log Scale (Y-axis)")
306
- show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
307
- align_by_gamma = gr.Checkbox(value=False, label="Align by gamma")
308
-
309
- with gr.Row():
310
- method = gr.Radio(
311
- choices=["multitaper", "welch"],
312
- value="multitaper",
313
- label="PSD Method"
314
- )
315
- region = gr.Dropdown(
316
- choices=analyzer.region_names,
317
- value=analyzer.region_names[0],
318
- label="Brain Region"
319
- )
320
-
321
- inputs = [region, method, conditions, subjects, log_scale, show_bands, align_by_gamma]
322
- for comp in inputs:
323
- comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
324
- demo.load(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
325
-
326
- return demo
327
-
328
-
329
- # -----------------------------
330
- # Launch App
331
- # -----------------------------
332
-
333
- if __name__ == "__main__":
334
- DATA_PATH = "psd_voxel_all_data.npz"
335
- analyzer = PSDAnalyzer(DATA_PATH)
336
- app = create_app(analyzer)
337
- app.launch(share=True, show_error=True)