"""Interactive PSD comparison viewer (Gradio + Plotly), grouped by anatomical area."""

from pathlib import Path

import gradio as gr
import numpy as np
import plotly.graph_objects as go

# -----------------------------
# Configuration
# -----------------------------
# Frequency bands as (low_hz, high_hz); used for the shaded background regions.
FREQ_BANDS = {
    'Delta': (1, 4),
    'Theta': (4, 8),
    'Alpha': (8, 12),
    'Low_Beta': (12, 20),
    'High_Beta': (20, 30),
    'Low_Gamma': (30, 50),
    'High_Gamma': (50, 100),
}

# Human-readable legend labels, keyed by condition id.
condition_labels = {
    'bima_activity': 'Bima Task',
    'rest_eyes_open': 'Rest (Eyes Open)',
    'rest_eyes_closed': 'Rest (Eyes Closed)',
}

# Line colour per condition; unknown conditions fall back to gray at plot time.
condition_colors = {
    'bima_activity': 'red',
    'rest_eyes_open': '#38A169',
    'rest_eyes_closed': '#805AD5',
}

# Fill colour of each shaded frequency band.
band_colors = {
    'Delta': 'lightsalmon',
    'Theta': 'wheat',
    'Alpha': 'mediumpurple',
    'Low_Beta': 'skyblue',
    'High_Beta': 'lightcoral',
    'Low_Gamma': 'lightgreen',
    'High_Gamma': 'plum',
}

# Sentinel entry of the subject selector meaning "plot the group mean".
_ALL_SUBJECTS = "All Subjects"

# Intro text shown under the page title (kept unindented so it renders as a
# Markdown block quote regardless of how the UI library treats indentation).
_INTRO_MARKDOWN = """
> **Compare conditions** (e.g., Bima vs. Rest) on the same plot
> Toggle **'Align by Gamma'** to normalize all spectra to 30–100 Hz power
> Essential for fair ERD/ERS interpretation
"""


# -----------------------------
# Area-Based PSD Analyzer (Multi-Condition + Gamma Alignment)
# -----------------------------
class AreaPSDAnalyzer:
    """Load per-area PSD data from an ``.npz`` archive and build comparison plots.

    The archive must provide the entries ``data``, ``areas``, ``conditions``
    and ``freqs``. ``data`` is a pickled mapping ``area -> condition -> dict``
    where each inner dict is read with the keys ``subjects``,
    ``individual_raw``, ``mean_raw``, ``individual_gamma_norm`` and
    ``mean_gamma_norm``.
    """

    def __init__(self, data_file):
        """Read the archive at *data_file* and index subjects by area.

        Raises:
            FileNotFoundError: if *data_file* does not exist.
        """
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(f"Data file not found: {data_file}")

        # NOTE(security): allow_pickle=True executes pickled payloads while
        # loading; only open archives produced by a trusted pipeline.
        # The context manager closes the underlying zip handle once the
        # arrays have been materialised.
        with np.load(data_file, allow_pickle=True) as loaded:
            self.data = loaded['data'].item()
            self.areas = loaded['areas'].tolist()
            self.conditions = loaded['conditions'].tolist()
            self.freqs = loaded['freqs']

        # Build subject list per area, taken from the first condition stored
        # for that area. Areas with no data (or no conditions) get [].
        self.subjects_by_area = {}
        for area in self.areas:
            area_conditions = self.data[area] if area in self.data else {}
            first_cond = next(iter(area_conditions), None)
            if first_cond is None:
                self.subjects_by_area[area] = []
            else:
                # list(...) so the value can be concatenated with other lists
                # even when the archive stored the subjects as an array.
                self.subjects_by_area[area] = list(
                    area_conditions[first_cond]['subjects']
                )

        print(f"✅ Loaded PSD data for areas: {self.areas}")

    @staticmethod
    def _message_figure(text):
        """Return an empty figure that only displays *text* at its centre."""
        fig = go.Figure()
        fig.add_annotation(
            text=text, x=0.5, y=0.5, showarrow=False,
            xref="paper", yref="paper",
        )
        return fig

    @staticmethod
    def _add_band_shading(fig, f_min, f_max):
        """Shade and label every band of FREQ_BANDS overlapping [f_min, f_max]."""
        for band, (low, high) in FREQ_BANDS.items():
            if high < f_min or low > f_max:
                continue  # band lies entirely outside the plotted range
            band_low = max(low, f_min)
            band_high = min(high, f_max)
            # yref="paper" stretches the rectangle over the full plot height
            # independently of the y-axis scale.
            fig.add_shape(
                type="rect",
                x0=band_low, x1=band_high,
                y0=0, y1=1,
                xref="x", yref="paper",
                fillcolor=band_colors[band],
                opacity=0.15,
                layer="below",
                line_width=0,
            )
            fig.add_annotation(
                x=(band_low + band_high) / 2,
                y=1.02,
                text=band,
                showarrow=False,
                font=dict(size=9, color="dimgray"),
                xanchor="center", yanchor="bottom",
                xref="x", yref="paper",
                opacity=0.85,
            )

    def create_plot(self, area, selected_conditions, selected_subjects,
                    log_scale, show_bands, align_by_gamma):
        """Build the PSD figure for one area.

        Args:
            area: key into the loaded data mapping.
            selected_conditions: condition ids to draw; ids missing for
                *area* are skipped.
            selected_subjects: subject ids to draw individually. If it
                contains ``"All Subjects"`` the per-condition mean is drawn
                instead. ``None`` is treated as an empty selection.
            log_scale: use a logarithmic y-axis when true.
            show_bands: shade the FREQ_BANDS ranges when true.
            align_by_gamma: plot the ``*_gamma_norm`` entries instead of the
                ``*_raw`` ones.

        Returns:
            A ``plotly.graph_objects.Figure``; when nothing can be drawn the
            figure only carries an explanatory annotation.
        """
        if not selected_conditions:
            return self._message_figure("Select at least one condition")

        # Validate area
        if area not in self.data:
            return self._message_figure(f"No data for area: {area}")

        selected_subjects = selected_subjects or []
        show_mean = _ALL_SUBJECTS in selected_subjects
        wanted_subjects = [s for s in selected_subjects if s != _ALL_SUBJECTS]
        norm_tag = " [Γ-norm]" if align_by_gamma else ""

        freqs = self.freqs
        fig = go.Figure()
        plotted = False

        for condition in selected_conditions:
            if condition not in self.data[area]:
                continue

            area_data = self.data[area][condition]
            subjects = area_data['subjects']
            if align_by_gamma:
                individual_psd = area_data['individual_gamma_norm']
                mean_psd = area_data['mean_gamma_norm']
            else:
                individual_psd = area_data['individual_raw']
                mean_psd = area_data['mean_raw']

            color = condition_colors.get(condition, 'gray')
            display_name = condition_labels.get(condition, condition)

            if show_mean:
                fig.add_trace(go.Scatter(
                    x=freqs,
                    y=mean_psd,
                    mode='lines',
                    name=f"{display_name} (Mean, n={len(subjects)})" + norm_tag,
                    line=dict(color=color, width=3),
                    opacity=0.9,
                ))
            else:
                # Positions of the requested subjects within this condition;
                # individual_psd is indexed by the same position.
                selected_indices = [
                    i for i, s in enumerate(subjects) if s in wanted_subjects
                ]
                if not selected_indices:
                    continue
                for i in selected_indices:
                    fig.add_trace(go.Scatter(
                        x=freqs,
                        y=individual_psd[i],
                        mode='lines',
                        name=f"{display_name} - {subjects[i]}" + norm_tag,
                        line=dict(color=color, width=2),
                        opacity=0.7,
                    ))

            plotted = True

        if not plotted:
            return self._message_figure("No data for selected options")

        # Plain floats keep NumPy scalar types out of the figure layout.
        f_min, f_max = float(freqs[0]), float(freqs[-1])

        # Band shading
        if show_bands:
            self._add_band_shading(fig, f_min, f_max)

        # Layout
        y_title = "Power (Gamma-Normalized)" if align_by_gamma else "Power"
        if log_scale:
            y_title += " [log]"

        fig.update_layout(
            title=f"PSD — {area}",
            xaxis_title="Frequency (Hz)",
            yaxis_title=y_title,
            yaxis_type="log" if log_scale else "linear",
            template="plotly_white",
            height=650,
            legend=dict(
                y=0.99, yanchor="top",
                x=1.02, xanchor="left",
                bgcolor="rgba(255,255,255,0.8)",
                font_size=11,
            ),
            # Wide right margin leaves room for the legend placed outside the axes.
            margin=dict(r=160, t=60, b=80, l=60),
            hovermode='x unified',
            xaxis=dict(
                range=[f_min, f_max],
                fixedrange=True,
                showgrid=True,
                gridwidth=1,
                gridcolor='rgba(128,128,128,0.2)',
            ),
            yaxis=dict(
                fixedrange=True,
                showgrid=True,
                gridwidth=1,
                gridcolor='rgba(128,128,128,0.2)',
            ),
        )

        return fig


# -----------------------------
# Gradio Interface
# -----------------------------
def create_app(analyzer: AreaPSDAnalyzer):
    """Assemble the Gradio Blocks UI around *analyzer* and return it (not launched)."""
    # Tolerate an archive without areas instead of failing on areas[0].
    default_area = analyzer.areas[0] if analyzer.areas else None
    default_subjects = list(analyzer.subjects_by_area.get(default_area, []))

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 PSD Comparison — Gamma Alignment for ERD")
        gr.Markdown(_INTRO_MARKDOWN)

        with gr.Row():
            area = gr.Dropdown(
                choices=analyzer.areas,
                value=default_area,
                label="Anatomical Area",
            )
            subjects = gr.CheckboxGroup(
                choices=[_ALL_SUBJECTS] + default_subjects,
                value=[_ALL_SUBJECTS],
                label="Subjects",
            )

        conditions = gr.CheckboxGroup(
            choices=analyzer.conditions,
            value=analyzer.conditions[:2],  # Default: first two conditions
            label="Conditions to Compare",
        )

        with gr.Row():
            log_scale = gr.Checkbox(value=True, label="Log Scale (Y-axis)")
            show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
            align_by_gamma = gr.Checkbox(value=False, label="Align by Gamma (30–100 Hz)")

        plot_output = gr.Plot()

        def update_subjects(selected_area):
            """Reset the subject selector to the subjects of *selected_area*."""
            area_subjects = list(analyzer.subjects_by_area.get(selected_area, []))
            return gr.CheckboxGroup(
                choices=[_ALL_SUBJECTS] + area_subjects,
                value=[_ALL_SUBJECTS],
            )

        plot_inputs = [area, conditions, subjects, log_scale, show_bands, align_by_gamma]

        # On an area change, refresh the subject list first and redraw only
        # afterwards, so the plot never uses the previous area's selection.
        area.change(fn=update_subjects, inputs=area, outputs=subjects).then(
            fn=analyzer.create_plot, inputs=plot_inputs, outputs=plot_output
        )

        # Every other control redraws the plot directly.
        for comp in plot_inputs[1:]:
            comp.change(fn=analyzer.create_plot, inputs=plot_inputs, outputs=plot_output)

        # Initial render when the page loads.
        demo.load(fn=analyzer.create_plot, inputs=plot_inputs, outputs=plot_output)

    return demo


# -----------------------------
# Launch App
# -----------------------------
def main(data_path="psd_by_area_gamma_ready.npz"):
    """Load the PSD archive at *data_path* and serve the app."""
    analyzer = AreaPSDAnalyzer(data_path)
    app = create_app(analyzer)
    # share=True requests a public share link in addition to the local server.
    app.launch(share=True, show_error=True)


if __name__ == "__main__":
    main()