JayLacoma commited on
Commit
db07b77
·
verified ·
1 Parent(s): 19743e3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +319 -0
app.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import plotly.graph_objects as go
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
+ # -----------------------------
7
+ # Configuration
8
+ # -----------------------------
9
+
10
+ FREQ_BANDS = {
11
+ 'Delta': (1, 4),
12
+ 'Theta': (4, 8),
13
+ 'Alpha': (8, 12),
14
+ 'Low_Beta': (12, 20),
15
+ 'High_Beta': (20, 30),
16
+ 'Low_Gamma': (30, 50),
17
+ 'High_Gamma': (50, 100)
18
+ }
19
+
20
+ condition_labels = {
21
+ 'bima_activity_on': 'Bima Activity (ON)',
22
+ 'bima_activity_off': 'Bima Activity (OFF)',
23
+ 'hands_move_on': 'Hands Move (ON)',
24
+ 'hands_move_off': 'Hands Move (OFF)',
25
+ 'rest_eyes_closed_on': 'Rest Eyes Closed (ON)',
26
+ 'rest_eyes_closed_off': 'Rest Eyes Closed (OFF)',
27
+ 'rest_eyes_open_on': 'Rest Eyes Open (ON)',
28
+ 'rest_eyes_open_off': 'Rest Eyes Open (OFF)'
29
+ }
30
+
31
+ condition_colors = {
32
+ 'bima_activity': 'orange',
33
+ 'hands_move': 'red',
34
+ 'rest_eyes_closed': 'blue',
35
+ 'rest_eyes_open': 'green'
36
+ }
37
+
38
+ band_colors = {
39
+ 'Delta': 'lightsalmon', 'Theta': 'wheat', 'Alpha': 'mediumpurple',
40
+ 'Low_Beta': 'skyblue', 'High_Beta': 'lightcoral',
41
+ 'Low_Gamma': 'lightgreen', 'High_Gamma': 'plum'
42
+ }
43
+
44
+ METHOD_DISPLAY = {"multitaper": "Multitaper", "welch": "Welch"}
45
+
46
+
47
+ # -----------------------------
48
+ # PSD Analyzer Class
49
+ # -----------------------------
50
+
51
+ class PSDAnalyzer:
52
+ def __init__(self, data_file):
53
+ data_file = Path(data_file)
54
+ if not data_file.exists():
55
+ raise FileNotFoundError(f"Data file not found: {data_file}")
56
+
57
+ loaded = np.load(data_file, allow_pickle=True)
58
+ self.all_data = loaded['data'].tolist()
59
+
60
+ # Extract metadata
61
+ self.subjects_list = sorted({d['subject'] for d in self.all_data})
62
+ self.conditions_list = sorted({f"{d['condition']}_{d['med_state']}" for d in self.all_data})
63
+ self.region_names = self.all_data[0]['region_names']
64
+
65
+ # Precompute traces
66
+ self.traces = self._precompute_traces()
67
+
68
+ print(f"✅ Loaded {len(self.all_data)} subject-condition combinations")
69
+ print(f"Subjects: {len(self.subjects_list)} | Conditions: {len(self.conditions_list)} | Regions: {len(self.region_names)}")
70
+
71
+ def _precompute_traces(self):
72
+ """Precompute traces grouped by region → condition → method."""
73
+ traces = {
74
+ region: {
75
+ cond: {'multitaper': [], 'welch': [], 'subjects': []}
76
+ for cond in self.conditions_list
77
+ }
78
+ for region in self.region_names
79
+ }
80
+
81
+ for entry in self.all_data:
82
+ subject = entry['subject']
83
+ cond_key = f"{entry['condition']}_{entry['med_state']}"
84
+ for result in entry['results']:
85
+ region_name = result['region_name']
86
+ for method in ['multitaper', 'welch']:
87
+ psd_data = result['psd_dual'].get(method)
88
+ if psd_data is not None:
89
+ freqs, psd = psd_data['freqs'], psd_data['psd']
90
+ traces[region_name][cond_key][method].append({'freqs': freqs, 'psd': psd, 'subject': subject})
91
+ if subject not in traces[region_name][cond_key]['subjects']:
92
+ traces[region_name][cond_key]['subjects'].append(subject)
93
+ return traces
94
+
95
+ def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands):
96
+ """Generate Plotly figure based on selections."""
97
+ fig = go.Figure()
98
+ freq_range = (1, 100)
99
+
100
+ # Handle no conditions selected
101
+ if not selected_conditions:
102
+ fig.add_annotation(text="Please select at least one condition", x=0.5, y=0.5, showarrow=False, xref="paper", yref="paper")
103
+ return fig
104
+
105
+ plotted_conditions = []
106
+
107
+ for condition in selected_conditions:
108
+ data = self.traces[region][condition][method]
109
+ if not data:
110
+ continue
111
+
112
+ # Filter frequency range
113
+ filtered_data = [
114
+ {k: v for k, v in item.items()}
115
+ for item in data
116
+ if np.any((item['freqs'] >= freq_range[0]) & (item['freqs'] <= freq_range[1]))
117
+ ]
118
+ if not filtered_data:
119
+ continue
120
+
121
+ # Determine subjects to plot
122
+ if "All Subjects" in selected_subjects:
123
+ freqs = filtered_data[0]['freqs']
124
+ mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
125
+ freqs = freqs[mask]
126
+ psds = np.array([item['psd'][mask] for item in filtered_data])
127
+ mean_psd = np.mean(psds, axis=0)
128
+
129
+ base_cond = condition.rsplit('_', 1)[0]
130
+ color = condition_colors.get(base_cond, 'gray')
131
+ linestyle = 'solid' if 'on' in condition else 'dash'
132
+
133
+ fig.add_trace(go.Scatter(
134
+ x=freqs, y=mean_psd,
135
+ mode='lines',
136
+ name=f"{condition_labels.get(condition, condition)} (Mean, n={len(psds)})",
137
+ line=dict(color=color, width=3, dash=linestyle)
138
+ ))
139
+ else:
140
+ available_subjects = [s for s in selected_subjects if s != "All Subjects"]
141
+ subject_data = [item for item in filtered_data if item['subject'] in available_subjects]
142
+ if not subject_data:
143
+ continue
144
+
145
+ base_cond = condition.rsplit('_', 1)[0]
146
+ color = condition_colors.get(base_cond, 'gray')
147
+ linestyle = 'solid' if 'on' in condition else 'dash'
148
+
149
+ for item in subject_data:
150
+ freqs, psd = item['freqs'], item['psd']
151
+ mask = (freqs >= freq_range[0]) & (freqs <= freq_range[1])
152
+ if not np.any(mask):
153
+ continue
154
+ fig.add_trace(go.Scatter(
155
+ x=freqs[mask], y=psd[mask],
156
+ mode='lines',
157
+ name=f"{condition_labels.get(condition, condition)} - {item['subject']}",
158
+ line=dict(color=color, width=2, dash=linestyle),
159
+ opacity=0.8
160
+ ))
161
+ plotted_conditions.append(condition)
162
+
163
+ # No valid data
164
+ if not plotted_conditions:
165
+ fig.add_annotation(
166
+ text="No data for selected conditions/subjects",
167
+ x=0.5, y=0.5, showarrow=False, xref="paper", yref="paper"
168
+ )
169
+ return fig
170
+
171
+ # Add frequency band visualization - Clean & non-overlapping
172
+ if show_bands:
173
+ # 1. Shaded background for each band
174
+ for band, (low, high) in FREQ_BANDS.items():
175
+ if low < freq_range[1] and high > freq_range[0]:
176
+ fig.add_vrect(
177
+ x0=low, x1=high,
178
+ fillcolor=band_colors[band], opacity=0.15,
179
+ layer="below", line_width=0
180
+ )
181
+
182
+ # 2. Dotted lines at band edges
183
+ all_edges = sorted(set([low for low, _ in FREQ_BANDS.values()] +
184
+ [high for _, high in FREQ_BANDS.values()]))
185
+ for edge in all_edges:
186
+ if freq_range[0] < edge < freq_range[1]:
187
+ fig.add_vline(x=edge, line=dict(color="gray", width=1, dash="dot"), opacity=0.5)
188
+
189
+ # 3. Estimate Y position for band labels
190
+ try:
191
+ max_power = max(trace.y.max() for trace in fig.data if hasattr(trace, 'y'))
192
+ y_pos = max_power * 1.15 if not log_scale else max_power * 3 # higher in log space
193
+ except Exception:
194
+ y_pos = 1.1
195
+
196
+ # 4. Place band names at center, above the plot
197
+ for band, (low, high) in FREQ_BANDS.items():
198
+ center = (low + high) / 2
199
+ if freq_range[0] <= center <= freq_range[1]:
200
+ fig.add_annotation(
201
+ x=center,
202
+ y=y_pos,
203
+ text=band,
204
+ showarrow=False,
205
+ font=dict(size=10, color="dimgray"),
206
+ xanchor="center",
207
+ yanchor="bottom",
208
+ opacity=0.9,
209
+ xref="x",
210
+ yref="y"
211
+ )
212
+
213
+ # Final layout
214
+ yaxis_title = "Power (log)" if log_scale else "Power"
215
+ fig.update_layout(
216
+ title=f"PSD - {region} | {METHOD_DISPLAY[method]}",
217
+ xaxis_title="Frequency (Hz)",
218
+ yaxis_title=yaxis_title,
219
+ yaxis_type="log" if log_scale else "linear",
220
+ template="plotly_white",
221
+ height=650,
222
+ legend=dict(
223
+ y=0.99, yanchor="top", x=1.02, xanchor="left",
224
+ bgcolor="rgba(255,255,255,0.8)", font_size=11
225
+ ),
226
+ margin=dict(r=160, t=60, b=80),
227
+ hovermode='x unified'
228
+ )
229
+ fig.update_xaxes(
230
+ showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
231
+ showline=True, linewidth=1, linecolor='gray'
232
+ )
233
+ fig.update_yaxes(
234
+ showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
235
+ showline=True, linewidth=1, linecolor='gray'
236
+ )
237
+
238
+ return fig
239
+
240
+
241
+ # -----------------------------
242
+ # Gradio Interface
243
+ # -----------------------------
244
+
245
+ def create_app(analyzer: PSDAnalyzer):
246
+ with gr.Blocks(
247
+ theme=gr.themes.Soft(),
248
+ title="PSD Analysis Dashboard",
249
+ css=".big-plot { height: 700px; }"
250
+ ) as demo:
251
+ gr.Markdown("# 📊 Power Spectral Density Analysis Dashboard")
252
+
253
+ # TOP ROW: Data Selection
254
+ with gr.Row():
255
+ with gr.Column(scale=1):
256
+ subjects = gr.CheckboxGroup(
257
+ choices=["All Subjects"] + analyzer.subjects_list,
258
+ value=["All Subjects"],
259
+ label="Subjects",
260
+ interactive=True
261
+ )
262
+ with gr.Column(scale=1):
263
+ conditions = gr.CheckboxGroup(
264
+ choices=analyzer.conditions_list,
265
+ value=analyzer.conditions_list[:2],
266
+ label="Conditions",
267
+ interactive=True
268
+ )
269
+
270
+ # CENTER: BIG PSD PLOT
271
+ with gr.Row():
272
+ with gr.Column():
273
+ plot_output = gr.Plot(
274
+ label="PSD Plot",
275
+ elem_classes="big-plot"
276
+ )
277
+
278
+ # BOTTOM ROW: Display Options
279
+ with gr.Row():
280
+ with gr.Column():
281
+ log_scale = gr.Checkbox(value=True, label="Use Log Scale (Y-axis)")
282
+ with gr.Column():
283
+ show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
284
+
285
+ # BOTTOM ROW: Method (Left) and Region (Right)
286
+ with gr.Row():
287
+ with gr.Column(min_width=200):
288
+ method = gr.Radio(
289
+ choices=["multitaper", "welch"],
290
+ value="multitaper",
291
+ label="PSD Method",
292
+ interactive=True
293
+ )
294
+ with gr.Column(min_width=200):
295
+ region = gr.Dropdown(
296
+ choices=analyzer.region_names,
297
+ value=analyzer.region_names[0],
298
+ label="Brain Region",
299
+ interactive=True
300
+ )
301
+
302
+ # Update plot on any input change
303
+ inputs = [region, method, conditions, subjects, log_scale, show_bands]
304
+ for comp in inputs:
305
+ comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
306
+ demo.load(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
307
+
308
+ return demo
309
+
310
+
311
+ # -----------------------------
312
+ # Launch App
313
+ # -----------------------------
314
+
315
+ if __name__ == "__main__":
316
+ DATA_PATH = "psd_voxel_all_data.npz"
317
+ analyzer = PSDAnalyzer(DATA_PATH)
318
+ app = create_app(analyzer)
319
+ app.launch(share=True, show_error=True)