Spaces:
Sleeping
Sleeping
input fix
Browse files
app.py
CHANGED
|
@@ -222,7 +222,6 @@ def plot_brain_slices(data, learning_stage):
|
|
| 222 |
return fig
|
| 223 |
|
| 224 |
def interpret_learning_stage(score):
|
| 225 |
-
"""human-readable learning stage"""
|
| 226 |
if score < 0.2:
|
| 227 |
return "NOVICE: minimal task familiarity, primarily exploratory behavior"
|
| 228 |
elif score < 0.4:
|
|
@@ -235,23 +234,18 @@ def interpret_learning_stage(score):
|
|
| 235 |
return "EXPERT: automated processing, highly optimized performance"
|
| 236 |
|
| 237 |
def plot_results(data, region_acts, temporal_pattern, learning_stage):
|
| 238 |
-
# 16:9 aspect for modern displays
|
| 239 |
fig = plt.figure(figsize=(16, 9))
|
| 240 |
|
| 241 |
-
# main brain plot: 60% height
|
| 242 |
gs = gridspec.GridSpec(2, 2, height_ratios=[6, 4])
|
| 243 |
|
| 244 |
-
# brain visualization (now larger)
|
| 245 |
ax1 = plt.subplot(gs[0, :])
|
| 246 |
mean_activation = data.mean(axis=0)
|
| 247 |
slice_idx = mean_activation.shape[-1]//2
|
| 248 |
brain_slice = mean_activation[...,slice_idx]
|
| 249 |
|
| 250 |
-
# find activation peaks
|
| 251 |
peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
|
| 252 |
peak_val = brain_slice[peak_coords]
|
| 253 |
|
| 254 |
-
# enhanced brain viz
|
| 255 |
im = ax1.imshow(brain_slice.T, cmap='hot')
|
| 256 |
plt.colorbar(im, ax=ax1)
|
| 257 |
ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15)
|
|
@@ -261,7 +255,6 @@ def plot_results(data, region_acts, temporal_pattern, learning_stage):
|
|
| 261 |
fontsize=12, pad=20)
|
| 262 |
ax1.axis('off')
|
| 263 |
|
| 264 |
-
# region plot (lower left)
|
| 265 |
ax2 = plt.subplot(gs[1, 0])
|
| 266 |
top_n = 5
|
| 267 |
region_ranking = np.argsort(-region_acts.flatten())[:top_n]
|
|
@@ -270,7 +263,6 @@ def plot_results(data, region_acts, temporal_pattern, learning_stage):
|
|
| 270 |
ax2.set_title('regional activity profile\n' +
|
| 271 |
'top regions: ' + ', '.join(f'{r}' for r in region_ranking))
|
| 272 |
|
| 273 |
-
# temporal dynamics (lower right)
|
| 274 |
ax3 = plt.subplot(gs[1, 1])
|
| 275 |
ax3.plot(temporal_pattern.squeeze(), 'k-', linewidth=2)
|
| 276 |
ax3.set_title('temporal evolution')
|
|
@@ -286,23 +278,19 @@ def process_fmri(file_obj):
|
|
| 286 |
img = nib.load(file_obj.name)
|
| 287 |
data = img.get_fdata(dtype=np.float32)
|
| 288 |
|
| 289 |
-
# shape validation + expansion
|
| 290 |
if data.ndim == 3:
|
| 291 |
-
data = data[None,...]
|
| 292 |
elif data.ndim != 4:
|
| 293 |
return f"error: volume must be 3D/4D, got {data.ndim}D", None
|
| 294 |
|
| 295 |
-
# validate dims
|
| 296 |
t,h,w,d = data.shape
|
| 297 |
if t < 1 or h < 16 or w < 16 or d < 8:
|
| 298 |
return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
|
| 299 |
if t > 1000 or h > 256 or w > 256 or d > 256:
|
| 300 |
return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
|
| 301 |
|
| 302 |
-
|
| 303 |
-
data = data.reshape(1, t, h, w, d) # [B,T,H,W,D]
|
| 304 |
|
| 305 |
-
# normalize + preprocess
|
| 306 |
data = preprocess_volume(data)
|
| 307 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 308 |
|
|
@@ -312,7 +300,7 @@ def process_fmri(file_obj):
|
|
| 312 |
for stage in ['full', 'region', 'temporal']:
|
| 313 |
try:
|
| 314 |
model = SequentialBrainViT(Config())
|
| 315 |
-
model._init_weights()
|
| 316 |
ckpt = torch.load(f'best_{stage}.pt', map_location=device)
|
| 317 |
missing = model.load_state_dict(ckpt['model'], strict=False)
|
| 318 |
if missing:
|
|
@@ -328,7 +316,7 @@ def process_fmri(file_obj):
|
|
| 328 |
}
|
| 329 |
|
| 330 |
fig = plot_results(
|
| 331 |
-
data[0].cpu().numpy(),
|
| 332 |
results[stage]['region_activation'],
|
| 333 |
results[stage]['temporal_pattern'],
|
| 334 |
results[stage]['learning_stage']
|
|
@@ -354,11 +342,8 @@ def process_fmri(file_obj):
|
|
| 354 |
|
| 355 |
iface = gr.Interface(
|
| 356 |
fn=process_fmri,
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
file_types=[".nii", ".nii.gz"],
|
| 360 |
-
file_count="single"
|
| 361 |
-
),
|
| 362 |
outputs=[
|
| 363 |
gr.Textbox(
|
| 364 |
label="Analysis Results",
|
|
@@ -386,8 +371,6 @@ iface = gr.Interface(
|
|
| 386 |
- brain map: warmer colors = higher activation
|
| 387 |
- regional profile: shows activity across 116 brain regions
|
| 388 |
- temporal pattern: activation changes over time
|
| 389 |
-
|
| 390 |
-
model trained on weatherprediction task fMRI data from openneuro ds000052
|
| 391 |
""",
|
| 392 |
theme="default",
|
| 393 |
examples=[],
|
|
|
|
| 222 |
return fig
|
| 223 |
|
| 224 |
def interpret_learning_stage(score):
|
|
|
|
| 225 |
if score < 0.2:
|
| 226 |
return "NOVICE: minimal task familiarity, primarily exploratory behavior"
|
| 227 |
elif score < 0.4:
|
|
|
|
| 234 |
return "EXPERT: automated processing, highly optimized performance"
|
| 235 |
|
| 236 |
def plot_results(data, region_acts, temporal_pattern, learning_stage):
|
|
|
|
| 237 |
fig = plt.figure(figsize=(16, 9))
|
| 238 |
|
|
|
|
| 239 |
gs = gridspec.GridSpec(2, 2, height_ratios=[6, 4])
|
| 240 |
|
|
|
|
| 241 |
ax1 = plt.subplot(gs[0, :])
|
| 242 |
mean_activation = data.mean(axis=0)
|
| 243 |
slice_idx = mean_activation.shape[-1]//2
|
| 244 |
brain_slice = mean_activation[...,slice_idx]
|
| 245 |
|
|
|
|
| 246 |
peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
|
| 247 |
peak_val = brain_slice[peak_coords]
|
| 248 |
|
|
|
|
| 249 |
im = ax1.imshow(brain_slice.T, cmap='hot')
|
| 250 |
plt.colorbar(im, ax=ax1)
|
| 251 |
ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15)
|
|
|
|
| 255 |
fontsize=12, pad=20)
|
| 256 |
ax1.axis('off')
|
| 257 |
|
|
|
|
| 258 |
ax2 = plt.subplot(gs[1, 0])
|
| 259 |
top_n = 5
|
| 260 |
region_ranking = np.argsort(-region_acts.flatten())[:top_n]
|
|
|
|
| 263 |
ax2.set_title('regional activity profile\n' +
|
| 264 |
'top regions: ' + ', '.join(f'{r}' for r in region_ranking))
|
| 265 |
|
|
|
|
| 266 |
ax3 = plt.subplot(gs[1, 1])
|
| 267 |
ax3.plot(temporal_pattern.squeeze(), 'k-', linewidth=2)
|
| 268 |
ax3.set_title('temporal evolution')
|
|
|
|
| 278 |
img = nib.load(file_obj.name)
|
| 279 |
data = img.get_fdata(dtype=np.float32)
|
| 280 |
|
|
|
|
| 281 |
if data.ndim == 3:
|
| 282 |
+
data = data[None,...]
|
| 283 |
elif data.ndim != 4:
|
| 284 |
return f"error: volume must be 3D/4D, got {data.ndim}D", None
|
| 285 |
|
|
|
|
| 286 |
t,h,w,d = data.shape
|
| 287 |
if t < 1 or h < 16 or w < 16 or d < 8:
|
| 288 |
return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
|
| 289 |
if t > 1000 or h > 256 or w > 256 or d > 256:
|
| 290 |
return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
|
| 291 |
|
| 292 |
+
data = data.reshape(1, t, h, w, d)
|
|
|
|
| 293 |
|
|
|
|
| 294 |
data = preprocess_volume(data)
|
| 295 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 296 |
|
|
|
|
| 300 |
for stage in ['full', 'region', 'temporal']:
|
| 301 |
try:
|
| 302 |
model = SequentialBrainViT(Config())
|
| 303 |
+
model._init_weights()
|
| 304 |
ckpt = torch.load(f'best_{stage}.pt', map_location=device)
|
| 305 |
missing = model.load_state_dict(ckpt['model'], strict=False)
|
| 306 |
if missing:
|
|
|
|
| 316 |
}
|
| 317 |
|
| 318 |
fig = plot_results(
|
| 319 |
+
data[0].cpu().numpy(),
|
| 320 |
results[stage]['region_activation'],
|
| 321 |
results[stage]['temporal_pattern'],
|
| 322 |
results[stage]['learning_stage']
|
|
|
|
| 342 |
|
| 343 |
iface = gr.Interface(
|
| 344 |
fn=process_fmri,
|
| 345 |
+
inputs=gr.File(label="fMRI data input (.nii/.nii.gz)"),
|
| 346 |
+
|
|
|
|
|
|
|
|
|
|
| 347 |
outputs=[
|
| 348 |
gr.Textbox(
|
| 349 |
label="Analysis Results",
|
|
|
|
| 371 |
- brain map: warmer colors = higher activation
|
| 372 |
- regional profile: shows activity across 116 brain regions
|
| 373 |
- temporal pattern: activation changes over time
|
|
|
|
|
|
|
| 374 |
""",
|
| 375 |
theme="default",
|
| 376 |
examples=[],
|