Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import plotly.graph_objects as go | |
| import numpy as np | |
| from pathlib import Path | |
# -----------------------------
# Configuration
# -----------------------------
# EEG frequency bands in Hz as (low, high) pairs. Dict insertion order is the
# order in which the bands are iterated when shading the plot.
FREQ_BANDS = dict(
    Delta=(1, 4),
    Theta=(4, 8),
    Alpha=(8, 12),
    Low_Beta=(12, 20),
    High_Beta=(20, 30),
    Low_Gamma=(30, 50),
    High_Gamma=(50, 100),
)

# Legend text for each condition ID stored in the data archive.
condition_labels = dict(
    bima_activity='Bima Task',
    rest_eyes_open='Rest (Eyes Open)',
    rest_eyes_closed='Rest (Eyes Closed)',
)

# Line color for each condition's traces.
condition_colors = dict(
    bima_activity='red',
    rest_eyes_open='#38A169',
    rest_eyes_closed='#805AD5',
)

# Fill color for each band's shaded region; keys mirror FREQ_BANDS.
band_colors = dict(
    Delta='lightsalmon',
    Theta='wheat',
    Alpha='mediumpurple',
    Low_Beta='skyblue',
    High_Beta='lightcoral',
    Low_Gamma='lightgreen',
    High_Gamma='plum',
)
# -----------------------------
# Area-Based PSD Analyzer (Multi-Condition + Gamma Alignment)
# -----------------------------
class AreaPSDAnalyzer:
    """Load per-area PSD summaries from an ``.npz`` archive and plot them with Plotly.

    The archive must contain these entries:

    * ``data`` -- a pickled nested dict ``data[area][condition]`` whose leaf
      dicts provide ``'subjects'``, ``'individual_raw'``, ``'mean_raw'``,
      ``'individual_gamma_norm'`` and ``'mean_gamma_norm'``;
    * ``areas`` / ``conditions`` -- arrays of area and condition names;
    * ``freqs`` -- the frequency axis used as x for every spectrum.
    """

    def __init__(self, data_file):
        """Load the archive at ``data_file`` (str or path-like).

        Raises:
            FileNotFoundError: if ``data_file`` does not exist.
        """
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(f"Data file not found: {data_file}")
        # allow_pickle is needed because ``data`` is a pickled object array.
        # Unpickling can execute arbitrary code: only load archives you trust.
        # The context manager closes the NpzFile handle once every entry has
        # been read into memory.
        with np.load(data_file, allow_pickle=True) as loaded:
            self.data = loaded['data'].item()
            self.areas = loaded['areas'].tolist()
            self.conditions = loaded['conditions'].tolist()
            self.freqs = loaded['freqs']

        # Subjects selectable per area; areas without data (or without any
        # condition) get an empty list instead of crashing the constructor.
        self.subjects_by_area = {
            area: self._collect_subjects(self.data.get(area, {}))
            for area in self.areas
        }
        print(f"✅ Loaded PSD data for areas: {self.areas}")

    @staticmethod
    def _collect_subjects(area_data):
        """Return the ordered union of subject IDs across all conditions of one area.

        Subjects of the first condition come first, then any IDs that only
        appear in later conditions. Assumes subject IDs are hashable (e.g.
        strings) -- TODO confirm against the data producer.
        """
        return list(dict.fromkeys(
            subject
            for cond_data in area_data.values()
            for subject in cond_data.get('subjects', ())
        ))

    @staticmethod
    def _message_figure(message):
        """Return an empty figure with ``message`` centred on the canvas."""
        fig = go.Figure()
        fig.add_annotation(text=message, x=0.5, y=0.5, showarrow=False, xref="paper", yref="paper")
        return fig

    def create_plot(self, area, selected_conditions, selected_subjects, log_scale, show_bands, align_by_gamma):
        """Build the PSD comparison figure for one anatomical area.

        Args:
            area: key into ``self.data``.
            selected_conditions: condition keys to overlay; keys missing for
                this area are skipped.
            selected_subjects: subject IDs to plot individually, or a list that
                contains ``"All Subjects"`` to plot each condition's mean
                instead. ``None`` is treated as an empty selection.
            log_scale: use a logarithmic y-axis.
            show_bands: shade and label the ``FREQ_BANDS`` ranges.
            align_by_gamma: plot the ``*_gamma_norm`` spectra instead of ``*_raw``.

        Returns:
            A ``plotly.graph_objects.Figure``. When nothing can be drawn, the
            figure only contains a centred explanatory message.
        """
        if not selected_conditions:
            return self._message_figure("Select at least one condition")
        if area not in self.data:
            return self._message_figure(f"No data for area: {area}")

        freqs = self.freqs
        # A cleared Gradio CheckboxGroup may arrive as None.
        subjects_wanted = list(selected_subjects or [])

        fig = go.Figure()
        plotted = False
        for condition in selected_conditions:
            if condition not in self.data[area]:
                continue
            added = self._add_condition_traces(
                fig, self.data[area][condition], condition, subjects_wanted, align_by_gamma
            )
            plotted = plotted or added

        if not plotted:
            return self._message_figure("No data for selected options")

        if show_bands:
            self._add_band_shading(fig, freqs[0], freqs[-1])
        self._apply_layout(fig, area, freqs, log_scale, align_by_gamma)
        return fig

    def _add_condition_traces(self, fig, cond_data, condition, selected_subjects, align_by_gamma):
        """Add one condition's traces to ``fig`` and return True if any were added.

        With ``"All Subjects"`` selected, one thick mean trace is drawn.
        Otherwise one thinner trace is drawn per selected subject present in
        this condition.
        """
        freqs = self.freqs
        subjects = cond_data['subjects']
        variant = 'gamma_norm' if align_by_gamma else 'raw'
        norm_tag = " [γ-norm]" if align_by_gamma else ""
        color = condition_colors.get(condition, 'gray')
        cond_label = condition_labels.get(condition, condition)

        if "All Subjects" in selected_subjects:
            fig.add_trace(go.Scatter(
                x=freqs, y=cond_data[f'mean_{variant}'],
                mode='lines',
                name=f"{cond_label} (Mean, n={len(subjects)}){norm_tag}",
                line=dict(color=color, width=3),
                opacity=0.9,
            ))
            return True

        # Rows of the individual PSD matrix line up with ``subjects``.
        indices = [i for i, s in enumerate(subjects) if s in selected_subjects]
        if not indices:
            return False
        individual_psd = cond_data[f'individual_{variant}']
        for i in indices:
            fig.add_trace(go.Scatter(
                x=freqs, y=individual_psd[i],
                mode='lines',
                name=f"{cond_label} - {subjects[i]}{norm_tag}",
                line=dict(color=color, width=2),
                opacity=0.7,
            ))
        return True

    @staticmethod
    def _add_band_shading(fig, f_min, f_max):
        """Shade and label every ``FREQ_BANDS`` range that overlaps ``[f_min, f_max]``."""
        for band, (low, high) in FREQ_BANDS.items():
            # Strict overlap test: a band that only touches an edge of the
            # frequency axis would otherwise produce a zero-width rectangle
            # and a stray label on the axis boundary.
            if high <= f_min or low >= f_max:
                continue
            band_low = max(low, f_min)
            band_high = min(high, f_max)
            fig.add_shape(
                type="rect", x0=band_low, x1=band_high, y0=0, y1=1,
                xref="x", yref="paper", fillcolor=band_colors.get(band, 'lightgray'),
                opacity=0.15, layer="below", line_width=0,
            )
            # Label sits just above the plotting area, centred on the band.
            fig.add_annotation(
                x=(band_low + band_high) / 2, y=1.02, text=band, showarrow=False,
                font=dict(size=9, color="dimgray"), xanchor="center",
                yanchor="bottom", xref="x", yref="paper", opacity=0.85,
            )

    @staticmethod
    def _apply_layout(fig, area, freqs, log_scale, align_by_gamma):
        """Apply titles, axis scaling, legend placement and grid styling to ``fig``."""
        y_title = "Power (Gamma-Normalized)" if align_by_gamma else "Power"
        if log_scale:
            y_title += " [log]"
        grid = dict(showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)')
        fig.update_layout(
            title=f"PSD — {area}",
            xaxis_title="Frequency (Hz)",
            yaxis_title=y_title,
            yaxis_type="log" if log_scale else "linear",
            template="plotly_white",
            height=650,
            legend=dict(
                y=0.99, yanchor="top", x=1.02, xanchor="left",
                bgcolor="rgba(255,255,255,0.8)", font_size=11
            ),
            # Wide right margin leaves room for the legend placed outside the plot.
            margin=dict(r=160, t=60, b=80, l=60),
            hovermode='x unified',
            # fixedrange disables zoom/pan so the frequency axis stays aligned.
            xaxis=dict(range=[freqs[0], freqs[-1]], fixedrange=True, **grid),
            yaxis=dict(fixedrange=True, **grid),
        )
# -----------------------------
# Gradio Interface
# -----------------------------
def create_app(analyzer: AreaPSDAnalyzer):
    """Build the Gradio Blocks UI around ``analyzer``.

    Every control re-renders the plot through ``analyzer.create_plot``.
    Changing the area also resets the subject list to that area's subjects.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    # Guard against an archive with no areas instead of raising IndexError.
    default_area = analyzer.areas[0] if analyzer.areas else None

    def subject_choices(area_name):
        # list() guards against subject collections stored as NumPy arrays,
        # where ``list + ndarray`` would broadcast instead of concatenating.
        return ["All Subjects"] + list(analyzer.subjects_by_area.get(area_name, []))

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 PSD Comparison — Gamma Alignment for ERD")
        gr.Markdown("""
        > **Compare conditions** (e.g., Bima vs. Rest) on the same plot
        > Toggle **'Align by Gamma'** to normalize all spectra to 30–100 Hz power
        > Essential for fair ERD/ERS interpretation
        """)
        with gr.Row():
            area = gr.Dropdown(
                choices=analyzer.areas,
                value=default_area,
                label="Anatomical Area"
            )
            subjects = gr.CheckboxGroup(
                choices=subject_choices(default_area),
                value=["All Subjects"],
                label="Subjects"
            )
            conditions = gr.CheckboxGroup(
                choices=analyzer.conditions,
                # Default: the first two conditions in archive order
                # (intended to be Bima + one rest condition).
                value=analyzer.conditions[:2],
                label="Conditions to Compare"
            )
        with gr.Row():
            log_scale = gr.Checkbox(value=True, label="Log Scale (Y-axis)")
            show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
            align_by_gamma = gr.Checkbox(value=False, label="Align by Gamma (30–100 Hz)")
        plot_output = gr.Plot()

        # Reset the subject list whenever the area changes.
        def update_subjects(area_name):
            return gr.CheckboxGroup(
                choices=subject_choices(area_name),
                value=["All Subjects"]
            )

        area.change(fn=update_subjects, inputs=area, outputs=subjects)

        # Positional order must match AreaPSDAnalyzer.create_plot's parameters.
        inputs = [area, conditions, subjects, log_scale, show_bands, align_by_gamma]
        for comp in inputs:
            comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
        demo.load(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
    return demo
# -----------------------------
# Launch App
# -----------------------------
if __name__ == "__main__":
    # Look in the current working directory first (the original behaviour),
    # then fall back to the copy shipped next to this script. This lets the
    # app start when launched from another directory.
    DATA_PATH = Path("psd_by_area_gamma_ready.npz")
    if not DATA_PATH.exists():
        DATA_PATH = Path(__file__).resolve().with_name(DATA_PATH.name)
    analyzer = AreaPSDAnalyzer(DATA_PATH)
    app = create_app(analyzer)
    app.launch(share=True, show_error=True)