twarner commited on
Commit
4dec6e2
·
1 Parent(s): e037631
Files changed (1) hide show
  1. app.py +6 -23
app.py CHANGED
@@ -222,7 +222,6 @@ def plot_brain_slices(data, learning_stage):
222
  return fig
223
 
224
  def interpret_learning_stage(score):
225
- """human-readable learning stage"""
226
  if score < 0.2:
227
  return "NOVICE: minimal task familiarity, primarily exploratory behavior"
228
  elif score < 0.4:
@@ -235,23 +234,18 @@ def interpret_learning_stage(score):
235
  return "EXPERT: automated processing, highly optimized performance"
236
 
237
  def plot_results(data, region_acts, temporal_pattern, learning_stage):
238
- # 16:9 aspect for modern displays
239
  fig = plt.figure(figsize=(16, 9))
240
 
241
- # main brain plot: 60% height
242
  gs = gridspec.GridSpec(2, 2, height_ratios=[6, 4])
243
 
244
- # brain visualization (now larger)
245
  ax1 = plt.subplot(gs[0, :])
246
  mean_activation = data.mean(axis=0)
247
  slice_idx = mean_activation.shape[-1]//2
248
  brain_slice = mean_activation[...,slice_idx]
249
 
250
- # find activation peaks
251
  peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
252
  peak_val = brain_slice[peak_coords]
253
 
254
- # enhanced brain viz
255
  im = ax1.imshow(brain_slice.T, cmap='hot')
256
  plt.colorbar(im, ax=ax1)
257
  ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15)
@@ -261,7 +255,6 @@ def plot_results(data, region_acts, temporal_pattern, learning_stage):
261
  fontsize=12, pad=20)
262
  ax1.axis('off')
263
 
264
- # region plot (lower left)
265
  ax2 = plt.subplot(gs[1, 0])
266
  top_n = 5
267
  region_ranking = np.argsort(-region_acts.flatten())[:top_n]
@@ -270,7 +263,6 @@ def plot_results(data, region_acts, temporal_pattern, learning_stage):
270
  ax2.set_title('regional activity profile\n' +
271
  'top regions: ' + ', '.join(f'{r}' for r in region_ranking))
272
 
273
- # temporal dynamics (lower right)
274
  ax3 = plt.subplot(gs[1, 1])
275
  ax3.plot(temporal_pattern.squeeze(), 'k-', linewidth=2)
276
  ax3.set_title('temporal evolution')
@@ -286,23 +278,19 @@ def process_fmri(file_obj):
286
  img = nib.load(file_obj.name)
287
  data = img.get_fdata(dtype=np.float32)
288
 
289
- # shape validation + expansion
290
  if data.ndim == 3:
291
- data = data[None,...] # [H,W,D] -> [T,H,W,D]
292
  elif data.ndim != 4:
293
  return f"error: volume must be 3D/4D, got {data.ndim}D", None
294
 
295
- # validate dims
296
  t,h,w,d = data.shape
297
  if t < 1 or h < 16 or w < 16 or d < 8:
298
  return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
299
  if t > 1000 or h > 256 or w > 256 or d > 256:
300
  return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
301
 
302
- # reshape for batch
303
- data = data.reshape(1, t, h, w, d) # [B,T,H,W,D]
304
 
305
- # normalize + preprocess
306
  data = preprocess_volume(data)
307
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
308
 
@@ -312,7 +300,7 @@ def process_fmri(file_obj):
312
  for stage in ['full', 'region', 'temporal']:
313
  try:
314
  model = SequentialBrainViT(Config())
315
- model._init_weights() # critical: init before load
316
  ckpt = torch.load(f'best_{stage}.pt', map_location=device)
317
  missing = model.load_state_dict(ckpt['model'], strict=False)
318
  if missing:
@@ -328,7 +316,7 @@ def process_fmri(file_obj):
328
  }
329
 
330
  fig = plot_results(
331
- data[0].cpu().numpy(), # drop batch
332
  results[stage]['region_activation'],
333
  results[stage]['temporal_pattern'],
334
  results[stage]['learning_stage']
@@ -354,11 +342,8 @@ def process_fmri(file_obj):
354
 
355
  iface = gr.Interface(
356
  fn=process_fmri,
357
- inputs=gr.File(
358
- label="fMRI Data Input",
359
- file_types=[".nii", ".nii.gz"],
360
- file_count="single"
361
- ),
362
  outputs=[
363
  gr.Textbox(
364
  label="Analysis Results",
@@ -386,8 +371,6 @@ iface = gr.Interface(
386
  - brain map: warmer colors = higher activation
387
  - regional profile: shows activity across 116 brain regions
388
  - temporal pattern: activation changes over time
389
-
390
- model trained on weatherprediction task fMRI data from openneuro ds000052
391
  """,
392
  theme="default",
393
  examples=[],
 
222
  return fig
223
 
224
  def interpret_learning_stage(score):
 
225
  if score < 0.2:
226
  return "NOVICE: minimal task familiarity, primarily exploratory behavior"
227
  elif score < 0.4:
 
234
  return "EXPERT: automated processing, highly optimized performance"
235
 
236
  def plot_results(data, region_acts, temporal_pattern, learning_stage):
 
237
  fig = plt.figure(figsize=(16, 9))
238
 
 
239
  gs = gridspec.GridSpec(2, 2, height_ratios=[6, 4])
240
 
 
241
  ax1 = plt.subplot(gs[0, :])
242
  mean_activation = data.mean(axis=0)
243
  slice_idx = mean_activation.shape[-1]//2
244
  brain_slice = mean_activation[...,slice_idx]
245
 
 
246
  peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
247
  peak_val = brain_slice[peak_coords]
248
 
 
249
  im = ax1.imshow(brain_slice.T, cmap='hot')
250
  plt.colorbar(im, ax=ax1)
251
  ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15)
 
255
  fontsize=12, pad=20)
256
  ax1.axis('off')
257
 
 
258
  ax2 = plt.subplot(gs[1, 0])
259
  top_n = 5
260
  region_ranking = np.argsort(-region_acts.flatten())[:top_n]
 
263
  ax2.set_title('regional activity profile\n' +
264
  'top regions: ' + ', '.join(f'{r}' for r in region_ranking))
265
 
 
266
  ax3 = plt.subplot(gs[1, 1])
267
  ax3.plot(temporal_pattern.squeeze(), 'k-', linewidth=2)
268
  ax3.set_title('temporal evolution')
 
278
  img = nib.load(file_obj.name)
279
  data = img.get_fdata(dtype=np.float32)
280
 
 
281
  if data.ndim == 3:
282
+ data = data[None,...]
283
  elif data.ndim != 4:
284
  return f"error: volume must be 3D/4D, got {data.ndim}D", None
285
 
 
286
  t,h,w,d = data.shape
287
  if t < 1 or h < 16 or w < 16 or d < 8:
288
  return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
289
  if t > 1000 or h > 256 or w > 256 or d > 256:
290
  return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
291
 
292
+ data = data.reshape(1, t, h, w, d)
 
293
 
 
294
  data = preprocess_volume(data)
295
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
296
 
 
300
  for stage in ['full', 'region', 'temporal']:
301
  try:
302
  model = SequentialBrainViT(Config())
303
+ model._init_weights()
304
  ckpt = torch.load(f'best_{stage}.pt', map_location=device)
305
  missing = model.load_state_dict(ckpt['model'], strict=False)
306
  if missing:
 
316
  }
317
 
318
  fig = plot_results(
319
+ data[0].cpu().numpy(),
320
  results[stage]['region_activation'],
321
  results[stage]['temporal_pattern'],
322
  results[stage]['learning_stage']
 
342
 
343
  iface = gr.Interface(
344
  fn=process_fmri,
345
+ inputs=gr.File(label="fMRI data input (.nii/.nii.gz)"),
346
+
 
 
 
347
  outputs=[
348
  gr.Textbox(
349
  label="Analysis Results",
 
371
  - brain map: warmer colors = higher activation
372
  - regional profile: shows activity across 116 brain regions
373
  - temporal pattern: activation changes over time
 
 
374
  """,
375
  theme="default",
376
  examples=[],