Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -98,6 +98,7 @@ class PSDAnalyzer:
|
|
| 98 |
traces[region_name][cond_key]['subjects'].append(subject)
|
| 99 |
return traces
|
| 100 |
|
|
|
|
| 101 |
def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands, align_by_delta):
|
| 102 |
"""Generate Plotly figure with optional Delta-band alignment."""
|
| 103 |
fig = go.Figure()
|
|
@@ -200,40 +201,20 @@ class PSDAnalyzer:
|
|
| 200 |
)
|
| 201 |
return fig
|
| 202 |
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
# Add frequency band visualization - Colored bands with legend
|
| 207 |
if show_bands:
|
| 208 |
-
#
|
| 209 |
-
# We'll draw rectangles in "paper" space, so they don't distort the plot
|
| 210 |
-
|
| 211 |
-
# Get current y-axis range in paper coordinates
|
| 212 |
-
try:
|
| 213 |
-
y_min_data = min(trace.y.min() for trace in fig.data if hasattr(trace, 'y'))
|
| 214 |
-
y_max_data = max(trace.y.max() for trace in fig.data if hasattr(trace, 'y'))
|
| 215 |
-
except Exception:
|
| 216 |
-
y_min_data, y_max_data = 1e-5, 1.0
|
| 217 |
-
|
| 218 |
-
# Map data y-range to paper y-coordinates (0 to 1)
|
| 219 |
-
# For log scale, we map log(y) to linear paper space
|
| 220 |
-
y_log_min = np.log10(y_min_data)
|
| 221 |
-
y_log_max = np.log10(y_max_data)
|
| 222 |
-
y_log_range = y_log_max - y_log_min
|
| 223 |
-
|
| 224 |
-
# Now loop through bands
|
| 225 |
for band, (low, high) in FREQ_BANDS.items():
|
| 226 |
-
# Only show if
|
| 227 |
if high < freq_range[0] or low > freq_range[1]:
|
| 228 |
continue
|
| 229 |
|
| 230 |
band_low = max(low, freq_range[0])
|
| 231 |
band_high = min(high, freq_range[1])
|
| 232 |
|
| 233 |
-
#
|
| 234 |
-
# We want the band to cover full y-axis in paper space
|
| 235 |
-
# So we set y0=0, y1=1 in paper coordinates
|
| 236 |
-
# But we must use `yref="paper"` and `xref="x"`
|
| 237 |
fig.add_shape(
|
| 238 |
type="rect",
|
| 239 |
x0=band_low,
|
|
@@ -246,14 +227,13 @@ class PSDAnalyzer:
|
|
| 246 |
opacity=0.15,
|
| 247 |
layer="below",
|
| 248 |
line_width=0,
|
| 249 |
-
name=band,
|
| 250 |
)
|
| 251 |
|
| 252 |
-
# Add label above the plot (outside the
|
| 253 |
center_x = (band_low + band_high) / 2
|
| 254 |
fig.add_annotation(
|
| 255 |
x=center_x,
|
| 256 |
-
y=1.02, #
|
| 257 |
text=band,
|
| 258 |
showarrow=False,
|
| 259 |
font=dict(size=9, color="dimgray"),
|
|
@@ -261,10 +241,10 @@ class PSDAnalyzer:
|
|
| 261 |
yanchor="bottom",
|
| 262 |
xref="x",
|
| 263 |
yref="paper",
|
| 264 |
-
opacity=0.
|
| 265 |
)
|
| 266 |
|
| 267 |
-
#
|
| 268 |
seen_edges = set()
|
| 269 |
for low, high in FREQ_BANDS.values():
|
| 270 |
for edge in [low, high]:
|
|
@@ -277,61 +257,9 @@ class PSDAnalyzer:
|
|
| 277 |
)
|
| 278 |
seen_edges.add(edge)
|
| 279 |
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
if freq_range[0] <= center <= freq_range[1]:
|
| 284 |
-
fig.add_annotation(
|
| 285 |
-
x=center,
|
| 286 |
-
y=0.01, # Just above bottom of plot
|
| 287 |
-
text=band,
|
| 288 |
-
showarrow=False,
|
| 289 |
-
font=dict(size=8, color="darkgray", family="sans-serif"),
|
| 290 |
-
xanchor="center",
|
| 291 |
-
yanchor="bottom",
|
| 292 |
-
opacity=0.8,
|
| 293 |
-
xref="x",
|
| 294 |
-
yref="paper" # Fixed paper coordinate
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
# 2. Clean vertical lines only at band boundaries (avoid clutter)
|
| 298 |
-
seen_edges = set()
|
| 299 |
-
edges = []
|
| 300 |
-
for low, high in FREQ_BANDS.values():
|
| 301 |
-
if freq_range[0] < low < freq_range[1] and low not in seen_edges:
|
| 302 |
-
edges.append(low)
|
| 303 |
-
seen_edges.add(low)
|
| 304 |
-
if freq_range[0] < high < freq_range[1] and high not in seen_edges:
|
| 305 |
-
edges.append(high)
|
| 306 |
-
seen_edges.add(high)
|
| 307 |
-
|
| 308 |
-
for edge in sorted(edges):
|
| 309 |
-
fig.add_vline(
|
| 310 |
-
x=edge,
|
| 311 |
-
line=dict(color="gray", width=1, dash="dashdot"),
|
| 312 |
-
opacity=0.4,
|
| 313 |
-
layer='below'
|
| 314 |
-
)
|
| 315 |
-
# 3. Band labels above the plot
|
| 316 |
-
try:
|
| 317 |
-
y_max = max(trace.y.max() for trace in fig.data if hasattr(trace, 'y'))
|
| 318 |
-
y_pos = y_max * 1.15 if not log_scale else y_max * 3
|
| 319 |
-
except Exception:
|
| 320 |
-
y_pos = 1.1
|
| 321 |
-
|
| 322 |
-
for band, (low, high) in FREQ_BANDS.items():
|
| 323 |
-
center = (low + high) / 2
|
| 324 |
-
if freq_range[0] <= center <= freq_range[1]:
|
| 325 |
-
fig.add_annotation(
|
| 326 |
-
x=center, y=y_pos,
|
| 327 |
-
text=band,
|
| 328 |
-
showarrow=False,
|
| 329 |
-
font=dict(size=10, color="dimgray"),
|
| 330 |
-
xanchor="center", yanchor="bottom",
|
| 331 |
-
opacity=0.9
|
| 332 |
-
)
|
| 333 |
-
|
| 334 |
-
# Final layout
|
| 335 |
yaxis_title = "Power"
|
| 336 |
if align_by_delta:
|
| 337 |
yaxis_title = "Power (norm. to Delta)"
|
|
@@ -352,17 +280,19 @@ class PSDAnalyzer:
|
|
| 352 |
bgcolor="rgba(255,255,255,0.8)", font_size=11
|
| 353 |
),
|
| 354 |
margin=dict(r=160, t=60, b=80, l=60),
|
| 355 |
-
hovermode='x unified'
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
|
|
|
|
|
|
| 366 |
)
|
| 367 |
|
| 368 |
return fig
|
|
|
|
| 98 |
traces[region_name][cond_key]['subjects'].append(subject)
|
| 99 |
return traces
|
| 100 |
|
| 101 |
+
|
| 102 |
def create_plot(self, region, method, selected_conditions, selected_subjects, log_scale, show_bands, align_by_delta):
|
| 103 |
"""Generate Plotly figure with optional Delta-band alignment."""
|
| 104 |
fig = go.Figure()
|
|
|
|
| 201 |
)
|
| 202 |
return fig
|
| 203 |
|
| 204 |
+
# ==============================
|
| 205 |
+
# Frequency Band Visualization
|
| 206 |
+
# ==============================
|
|
|
|
| 207 |
if show_bands:
|
| 208 |
+
# Add shaded band regions using paper coordinates (so they don't affect scaling)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
for band, (low, high) in FREQ_BANDS.items():
|
| 210 |
+
# Only show band if it overlaps with visible freq range
|
| 211 |
if high < freq_range[0] or low > freq_range[1]:
|
| 212 |
continue
|
| 213 |
|
| 214 |
band_low = max(low, freq_range[0])
|
| 215 |
band_high = min(high, freq_range[1])
|
| 216 |
|
| 217 |
+
# Add shaded rectangle in paper space (full height, behind data)
|
|
|
|
|
|
|
|
|
|
| 218 |
fig.add_shape(
|
| 219 |
type="rect",
|
| 220 |
x0=band_low,
|
|
|
|
| 227 |
opacity=0.15,
|
| 228 |
layer="below",
|
| 229 |
line_width=0,
|
|
|
|
| 230 |
)
|
| 231 |
|
| 232 |
+
# Add label above the plot (outside the plotting area)
|
| 233 |
center_x = (band_low + band_high) / 2
|
| 234 |
fig.add_annotation(
|
| 235 |
x=center_x,
|
| 236 |
+
y=1.02, # Just above the top of the plot
|
| 237 |
text=band,
|
| 238 |
showarrow=False,
|
| 239 |
font=dict(size=9, color="dimgray"),
|
|
|
|
| 241 |
yanchor="bottom",
|
| 242 |
xref="x",
|
| 243 |
yref="paper",
|
| 244 |
+
opacity=0.85,
|
| 245 |
)
|
| 246 |
|
| 247 |
+
# Add vertical dotted lines at band boundaries (only once per edge)
|
| 248 |
seen_edges = set()
|
| 249 |
for low, high in FREQ_BANDS.values():
|
| 250 |
for edge in [low, high]:
|
|
|
|
| 257 |
)
|
| 258 |
seen_edges.add(edge)
|
| 259 |
|
| 260 |
+
# ==============================
|
| 261 |
+
# ✅ FIXED: Lock axis ranges to prevent distortion
|
| 262 |
+
# ==============================
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
yaxis_title = "Power"
|
| 264 |
if align_by_delta:
|
| 265 |
yaxis_title = "Power (norm. to Delta)"
|
|
|
|
| 280 |
bgcolor="rgba(255,255,255,0.8)", font_size=11
|
| 281 |
),
|
| 282 |
margin=dict(r=160, t=60, b=80, l=60),
|
| 283 |
+
hovermode='x unified',
|
| 284 |
+
# 🔒 CRITICAL: Lock axis ranges to prevent visual distortion
|
| 285 |
+
xaxis=dict(
|
| 286 |
+
range=[freq_range[0], freq_range[1]],
|
| 287 |
+
fixedrange=True, # Prevents zoom/pan from UI, but more importantly — stops auto-resize
|
| 288 |
+
showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
|
| 289 |
+
showline=True, linewidth=1, linecolor='gray'
|
| 290 |
+
),
|
| 291 |
+
yaxis=dict(
|
| 292 |
+
fixedrange=True, # Stops Plotly from auto-resizing y-axis when shapes are added
|
| 293 |
+
showgrid=True, gridwidth=1, gridcolor='rgba(128,128,128,0.2)',
|
| 294 |
+
showline=True, linewidth=1, linecolor='gray'
|
| 295 |
+
)
|
| 296 |
)
|
| 297 |
|
| 298 |
return fig
|