twarner commited on
Commit
d623d88
·
1 Parent(s): 45f9927

3d->4d proj

Browse files
Files changed (2) hide show
  1. README.md +0 -2
  2. app.py +59 -16
README.md CHANGED
@@ -10,5 +10,3 @@ pinned: false
10
  license: mit
11
  short_description: 'fMRI Learning Stage Classification with Vision Transformers '
12
  ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
10
  license: mit
11
  short_description: 'fMRI Learning Stage Classification with Vision Transformers '
12
  ---
 
 
app.py CHANGED
@@ -14,8 +14,9 @@ from einops.layers.torch import Rearrange
14
  from scipy.ndimage import zoom
15
  import matplotlib.pyplot as plt
16
  import seaborn as sns
 
 
17
 
18
- # core config
19
  @dataclass
20
  class Config:
21
  VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
@@ -25,7 +26,6 @@ class Config:
25
  DROPOUT: float = 0.1
26
  TASK_DIM: int = 512
27
 
28
- # model components
29
  class HierarchicalAttention(nn.Module):
30
  def __init__(self, dim, heads=8):
31
  super().__init__()
@@ -206,19 +206,55 @@ def preprocess_volume(vol, target_size=(64, 64, 30)):
206
  vol = (vol - vol.mean((1,2,3,4), keepdims=True)) / (vol.std((1,2,3,4), keepdims=True) + 1e-8)
207
  return torch.from_numpy(vol).float()
208
 
209
- def plot_results(region_acts, temporal_pattern):
210
- fig = plt.figure(figsize=(12,4))
 
211
 
212
- plt.subplot(121)
213
- sns.heatmap(region_acts.reshape(1,-1), cmap='RdBu_r', center=0)
214
- plt.title('region activations')
215
- plt.xlabel('brain region')
 
 
 
 
216
 
217
- plt.subplot(122)
218
- plt.plot(temporal_pattern.squeeze())
219
- plt.title('temporal pattern')
220
- plt.xlabel('time')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
 
 
222
  return fig
223
 
224
  def process_fmri(file_obj):
@@ -252,8 +288,11 @@ def process_fmri(file_obj):
252
  for stage in ['full', 'region', 'temporal']:
253
  try:
254
  model = SequentialBrainViT(Config())
 
255
  ckpt = torch.load(f'best_{stage}.pt', map_location=device)
256
- model.load_state_dict(ckpt['model'])
 
 
257
  model.eval()
258
 
259
  with torch.no_grad():
@@ -265,8 +304,10 @@ def process_fmri(file_obj):
265
  }
266
 
267
  fig = plot_results(
 
268
  results[stage]['region_activation'],
269
- results[stage]['temporal_pattern']
 
270
  )
271
  figs.append(fig)
272
  plt.close()
@@ -274,9 +315,12 @@ def process_fmri(file_obj):
274
  except Exception as e:
275
  return f"error in {stage} model: {str(e)}", None
276
 
 
277
  stage_results = "\n".join([
278
  f"{stage.upper()} MODEL:"
279
  f"\nlearning stage: {res['learning_stage']:.3f}"
 
 
280
  f"\n"
281
  for stage, res in results.items()
282
  ])
@@ -286,13 +330,12 @@ def process_fmri(file_obj):
286
  except Exception as e:
287
  return f"error processing file: {str(e)}", None
288
 
289
- # create interface
290
  iface = gr.Interface(
291
  fn=process_fmri,
292
  inputs=gr.File(label="upload 4D fMRI nifti (.nii/.nii.gz)"),
293
  outputs=[
294
  gr.Textbox(label="classification results"),
295
- gr.Plot(label="visualization")
296
  ],
297
  title="fmri learning stage classifier",
298
  description="upload a 4D fMRI nifti file to classify learning stages and visualize brain patterns",
 
14
  from scipy.ndimage import zoom
15
  import matplotlib.pyplot as plt
16
  import seaborn as sns
17
+ from nilearn import plotting
18
+ import matplotlib.gridspec as gridspec
19
 
 
20
  @dataclass
21
  class Config:
22
  VOLUME_SIZE: Tuple[int, int, int] = (64, 64, 30)
 
26
  DROPOUT: float = 0.1
27
  TASK_DIM: int = 512
28
 
 
29
  class HierarchicalAttention(nn.Module):
30
  def __init__(self, dim, heads=8):
31
  super().__init__()
 
206
  vol = (vol - vol.mean((1,2,3,4), keepdims=True)) / (vol.std((1,2,3,4), keepdims=True) + 1e-8)
207
  return torch.from_numpy(vol).float()
208
 
209
+ def plot_brain_slices(data, learning_stage):
210
+ fig = plt.figure(figsize=(15, 5))
211
+ mean_activation = data.mean(axis=0)
212
 
213
+ for i, slice_idx in enumerate([mean_activation.shape[-1]//4,
214
+ mean_activation.shape[-1]//2,
215
+ 3*mean_activation.shape[-1]//4]):
216
+ plt.subplot(1, 3, i+1)
217
+ plt.imshow(mean_activation[...,slice_idx].T, cmap='hot')
218
+ plt.colorbar()
219
+ plt.title(f'slice z={slice_idx}\nlearning: {learning_stage:.3f}')
220
+ plt.axis('off')
221
 
222
+ return fig
223
+
224
+ def plot_results(data, region_acts, temporal_pattern, learning_stage):
225
+ fig = plt.figure(figsize=(15,10))
226
+ gs = gridspec.GridSpec(2, 2)
227
+
228
+ # brain slices
229
+ ax1 = plt.subplot(gs[0,:])
230
+ mean_activation = data.mean(axis=0)
231
+ slice_idx = mean_activation.shape[-1]//2
232
+ brain_slice = mean_activation[...,slice_idx]
233
+
234
+ # find most active region
235
+ peak_coords = np.unravel_index(np.argmax(brain_slice), brain_slice.shape)
236
+
237
+ im = ax1.imshow(brain_slice.T, cmap='hot')
238
+ plt.colorbar(im, ax=ax1)
239
+ ax1.plot(peak_coords[0], peak_coords[1], 'r*', markersize=15,
240
+ label=f'peak ({peak_coords[0]}, {peak_coords[1]})')
241
+ ax1.legend()
242
+ ax1.set_title(f'brain activation (z={slice_idx})\nlearning stage: {learning_stage:.3f}')
243
+
244
+ # region activations
245
+ ax2 = plt.subplot(gs[1,0])
246
+ max_region = np.argmax(region_acts)
247
+ sns.heatmap(region_acts.reshape(1,-1), cmap='RdBu_r', center=0, ax=ax2)
248
+ ax2.set_title(f'region activations\nmost active: {max_region}')
249
+ ax2.set_xlabel('brain region')
250
+
251
+ # temporal pattern
252
+ ax3 = plt.subplot(gs[1,1])
253
+ ax3.plot(temporal_pattern.squeeze())
254
+ ax3.set_title('temporal dynamics')
255
+ ax3.set_xlabel('time')
256
 
257
+ plt.tight_layout()
258
  return fig
259
 
260
  def process_fmri(file_obj):
 
288
  for stage in ['full', 'region', 'temporal']:
289
  try:
290
  model = SequentialBrainViT(Config())
291
+ model._init_weights() # critical: init before load
292
  ckpt = torch.load(f'best_{stage}.pt', map_location=device)
293
+ missing = model.load_state_dict(ckpt['model'], strict=False)
294
+ if missing:
295
+ print(f"warning - {stage} missing keys:", missing)
296
  model.eval()
297
 
298
  with torch.no_grad():
 
304
  }
305
 
306
  fig = plot_results(
307
+ data[0].cpu().numpy(), # drop batch
308
  results[stage]['region_activation'],
309
+ results[stage]['temporal_pattern'],
310
+ results[stage]['learning_stage']
311
  )
312
  figs.append(fig)
313
  plt.close()
 
315
  except Exception as e:
316
  return f"error in {stage} model: {str(e)}", None
317
 
318
+ # enhanced results text w/ peak info
319
  stage_results = "\n".join([
320
  f"{stage.upper()} MODEL:"
321
  f"\nlearning stage: {res['learning_stage']:.3f}"
322
+ f"\npeak region: {np.argmax(res['region_activation'])}"
323
+ f"\npeak activation: {np.max(res['region_activation']):.3f}"
324
  f"\n"
325
  for stage, res in results.items()
326
  ])
 
330
  except Exception as e:
331
  return f"error processing file: {str(e)}", None
332
 
 
333
  iface = gr.Interface(
334
  fn=process_fmri,
335
  inputs=gr.File(label="upload 4D fMRI nifti (.nii/.nii.gz)"),
336
  outputs=[
337
  gr.Textbox(label="classification results"),
338
+ gr.Plot(label="brain activation + analysis")
339
  ],
340
  title="fmri learning stage classifier",
341
  description="upload a 4D fMRI nifti file to classify learning stages and visualize brain patterns",