# app.py — created by JayLacoma (commit 9a4f177, verified)
import gradio as gr
import plotly.graph_objects as go
import numpy as np
from pathlib import Path
# -----------------------------
# Configuration
# -----------------------------
# Frequency bands as (low_hz, high_hz). Insertion order is the order in
# which the bands are shaded and labelled on the plot.
FREQ_BANDS = {
    'Delta': (1, 4),
    'Theta': (4, 8),
    'Alpha': (8, 12),
    'Low_Beta': (12, 20),
    'High_Beta': (20, 30),
    'Low_Gamma': (30, 50),
    'High_Gamma': (50, 100),
}

# Condition key -> label shown in the plot legend.
condition_labels = {
    'bima_activity': 'Bima Task',
    'rest_eyes_open': 'Rest (Eyes Open)',
    'rest_eyes_closed': 'Rest (Eyes Closed)',
}

# Condition key -> trace colour.
condition_colors = {
    'bima_activity': 'red',
    'rest_eyes_open': '#38A169',
    'rest_eyes_closed': '#805AD5',
}

# Band name -> fill colour of the shaded band rectangle.
band_colors = {
    'Delta': 'lightsalmon',
    'Theta': 'wheat',
    'Alpha': 'mediumpurple',
    'Low_Beta': 'skyblue',
    'High_Beta': 'lightcoral',
    'Low_Gamma': 'lightgreen',
    'High_Gamma': 'plum',
}
# -----------------------------
# Area-Based PSD Analyzer (Multi-Condition + Gamma Alignment)
# -----------------------------
class AreaPSDAnalyzer:
    """Load per-area PSD data from an ``.npz`` archive and plot it with Plotly.

    The archive must provide the entries ``data``, ``areas``, ``conditions``
    and ``freqs``. ``data`` is a pickled nested mapping indexed as
    ``data[area][condition]``; each such entry is read with the keys
    ``'subjects'``, ``'mean_raw'``, ``'individual_raw'``,
    ``'mean_gamma_norm'`` and ``'individual_gamma_norm'``.
    """

    def __init__(self, data_file):
        """Read the archive and build the per-area subject lists.

        Args:
            data_file: Path (``str`` or ``Path``) of the ``.npz`` archive.

        Raises:
            FileNotFoundError: If ``data_file`` does not exist.
        """
        data_file = Path(data_file)
        if not data_file.exists():
            raise FileNotFoundError(f"Data file not found: {data_file}")

        # NOTE(security): allow_pickle=True unpickles the ``data`` entry, and
        # unpickling can execute arbitrary code. Only load trusted archives.
        # The ``with`` block closes the archive's file handle once every
        # entry has been read into memory.
        with np.load(data_file, allow_pickle=True) as loaded:
            self.data = loaded['data'].item()
            self.areas = loaded['areas'].tolist()
            self.conditions = loaded['conditions'].tolist()
            self.freqs = loaded['freqs']

        # Subjects per area, taken from the first condition stored for it.
        self.subjects_by_area = {}
        for area in self.areas:
            area_conditions = self.data[area] if area in self.data else {}
            first_entry = next(iter(area_conditions.values()), None)
            subjects = [] if first_entry is None else first_entry['subjects']
            # Always store a plain list: callers concatenate it with
            # ["All Subjects"], which is not list concatenation for an ndarray.
            if isinstance(subjects, np.ndarray):
                subjects = subjects.tolist()
            self.subjects_by_area[area] = list(subjects)

        print(f"✅ Loaded PSD data for areas: {self.areas}")

    @staticmethod
    def _message_figure(text):
        """Return an empty figure showing ``text`` centred in the plot area."""
        fig = go.Figure()
        fig.add_annotation(
            text=text, x=0.5, y=0.5, showarrow=False,
            xref="paper", yref="paper",
        )
        return fig

    @staticmethod
    def _add_band_shading(fig, freqs):
        """Shade and label every ``FREQ_BANDS`` band that overlaps ``freqs``."""
        f_min, f_max = freqs[0], freqs[-1]
        for band, (low, high) in FREQ_BANDS.items():
            if high < f_min or low > f_max:
                continue  # band lies completely outside the plotted range
            # Clip the band to the frequency range that is actually plotted.
            band_low = max(low, f_min)
            band_high = min(high, f_max)
            fig.add_shape(
                type="rect", x0=band_low, x1=band_high, y0=0, y1=1,
                xref="x", yref="paper", fillcolor=band_colors[band],
                opacity=0.15, layer="below", line_width=0,
            )
            fig.add_annotation(
                x=(band_low + band_high) / 2, y=1.02, text=band,
                showarrow=False, font=dict(size=9, color="dimgray"),
                xanchor="center", yanchor="bottom",
                xref="x", yref="paper", opacity=0.85,
            )

    def create_plot(self, area, selected_conditions, selected_subjects,
                    log_scale, show_bands, align_by_gamma):
        """Build the PSD figure for one area.

        Args:
            area: Key into ``self.data``.
            selected_conditions: Condition keys to draw; conditions missing
                for ``area`` are skipped.
            selected_subjects: Subject names. If it contains
                ``"All Subjects"`` only the per-condition mean is drawn;
                otherwise one trace per matching subject. ``None`` is treated
                as an empty selection.
            log_scale: Use a logarithmic y-axis when true.
            show_bands: Shade and label the ``FREQ_BANDS`` ranges when true.
            align_by_gamma: Plot the ``*_gamma_norm`` arrays instead of the
                ``*_raw`` ones when true.

        Returns:
            A ``plotly.graph_objects.Figure``. When nothing can be drawn the
            figure contains only an explanatory annotation.
        """
        if not selected_conditions:
            return self._message_figure("Select at least one condition")
        if area not in self.data:
            return self._message_figure(f"No data for area: {area}")

        # A cleared selection may arrive as None rather than [].
        selected_subjects = selected_subjects or []
        use_mean = "All Subjects" in selected_subjects
        wanted_subjects = [s for s in selected_subjects if s != "All Subjects"]

        if align_by_gamma:
            individual_key, mean_key = 'individual_gamma_norm', 'mean_gamma_norm'
            norm_suffix = " [Γ-norm]"
        else:
            individual_key, mean_key = 'individual_raw', 'mean_raw'
            norm_suffix = ""

        freqs = self.freqs
        fig = go.Figure()
        plotted = False

        for condition in selected_conditions:
            if condition not in self.data[area]:
                continue
            area_data = self.data[area][condition]
            subjects = area_data['subjects']
            color = condition_colors.get(condition, 'gray')
            display_name = condition_labels.get(condition, condition)

            if use_mean:
                fig.add_trace(go.Scatter(
                    x=freqs, y=area_data[mean_key],
                    mode='lines',
                    name=f"{display_name} (Mean, n={len(subjects)}){norm_suffix}",
                    line=dict(color=color, width=3),
                    opacity=0.9,
                ))
                plotted = True
                continue

            # Positions of the requested subjects in this condition's list.
            selected_indices = [
                i for i, s in enumerate(subjects) if s in wanted_subjects
            ]
            if not selected_indices:
                continue
            # individual_psd is indexed by subject position.
            individual_psd = area_data[individual_key]
            for i in selected_indices:
                fig.add_trace(go.Scatter(
                    x=freqs, y=individual_psd[i],
                    mode='lines',
                    name=f"{display_name} - {subjects[i]}{norm_suffix}",
                    line=dict(color=color, width=2),
                    opacity=0.7,
                ))
            plotted = True

        if not plotted:
            return self._message_figure("No data for selected options")

        if show_bands:
            self._add_band_shading(fig, freqs)

        y_title = "Power (Gamma-Normalized)" if align_by_gamma else "Power"
        if log_scale:
            y_title += " [log]"
        grid_style = dict(
            showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
        )
        fig.update_layout(
            title=f"PSD — {area}",
            xaxis_title="Frequency (Hz)",
            yaxis_title=y_title,
            yaxis_type="log" if log_scale else "linear",
            template="plotly_white",
            height=650,
            legend=dict(
                y=0.99, yanchor="top", x=1.02, xanchor="left",
                bgcolor="rgba(255,255,255,0.8)", font_size=11,
            ),
            # Wide right margin leaves room for the legend outside the plot.
            margin=dict(r=160, t=60, b=80, l=60),
            hovermode='x unified',
            xaxis=dict(
                range=[freqs[0], freqs[-1]], fixedrange=True, **grid_style
            ),
            yaxis=dict(fixedrange=True, **grid_style),
        )
        return fig
# -----------------------------
# Gradio Interface
# -----------------------------
def create_app(analyzer: "AreaPSDAnalyzer"):
    """Build the Gradio UI around ``analyzer``.

    Args:
        analyzer: Loaded analyzer. Its ``areas``, ``conditions`` and
            ``subjects_by_area`` attributes populate the controls, and its
            ``create_plot`` method is the plotting callback.

    Returns:
        The ``gr.Blocks`` app, not yet launched.
    """
    # Fall back to "no selection" rather than raising IndexError when no
    # areas were loaded.
    default_area = analyzer.areas[0] if analyzer.areas else None

    def subject_choices(area_name):
        # list(...) keeps this a list concatenation even if the stored
        # subjects are a NumPy array.
        return ["All Subjects"] + list(analyzer.subjects_by_area.get(area_name, []))

    # Same text as a triple-quoted block, written line by line so the
    # string content does not depend on source indentation.
    intro = (
        "\n"
        "> **Compare conditions** (e.g., Bima vs. Rest) on the same plot\n"
        "> Toggle **'Align by Gamma'** to normalize all spectra to 30–100 Hz power\n"
        "> Essential for fair ERD/ERS interpretation\n"
    )

    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 📊 PSD Comparison — Gamma Alignment for ERD")
        gr.Markdown(intro)

        with gr.Row():
            area = gr.Dropdown(
                choices=analyzer.areas,
                value=default_area,
                label="Anatomical Area"
            )
            subjects = gr.CheckboxGroup(
                choices=subject_choices(default_area),
                value=["All Subjects"],
                label="Subjects"
            )
            conditions = gr.CheckboxGroup(
                choices=analyzer.conditions,
                value=analyzer.conditions[:2],  # Default: Bima + one rest
                label="Conditions to Compare"
            )

        with gr.Row():
            log_scale = gr.Checkbox(value=True, label="Log Scale (Y-axis)")
            show_bands = gr.Checkbox(value=True, label="Show Frequency Bands")
            align_by_gamma = gr.Checkbox(value=False, label="Align by Gamma (30–100 Hz)")

        plot_output = gr.Plot()

        def update_subjects(selected_area):
            """Reset the subject choices to those of the newly selected area."""
            return gr.CheckboxGroup(
                choices=subject_choices(selected_area),
                value=["All Subjects"]
            )

        area.change(fn=update_subjects, inputs=area, outputs=subjects)

        # Component order matches create_plot's positional parameters.
        inputs = [area, conditions, subjects, log_scale, show_bands, align_by_gamma]
        for comp in inputs:
            comp.change(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)
        # Draw the initial plot when the page loads.
        demo.load(fn=analyzer.create_plot, inputs=inputs, outputs=plot_output)

    return demo
# -----------------------------
# Launch App
# -----------------------------
if __name__ == "__main__":
    # The archive path is relative, so it resolves against the current
    # working directory.
    DATA_PATH = "psd_by_area_gamma_ready.npz"
    analyzer = AreaPSDAnalyzer(DATA_PATH)
    app = create_app(analyzer)

    launch_options = {"share": True, "show_error": True}
    app.launch(**launch_options)