twarner committed on
Commit
45f9927
·
1 Parent(s): 6958eb5

shape range val

Browse files
Files changed (1) hide show
  1. app.py +37 -24
app.py CHANGED
@@ -226,14 +226,23 @@ def process_fmri(file_obj):
226
  img = nib.load(file_obj.name)
227
  data = img.get_fdata(dtype=np.float32)
228
 
229
- if data.ndim != 4:
230
- return f"error: expected 4D data, got {data.ndim}D", None
231
-
232
  if data.ndim == 3:
233
- data = data[None,...] # add time dim [H,W,D] -> [1,H,W,D]
234
  elif data.ndim != 4:
235
- return f"error: expected 3D/4D data, got {data.ndim}D", None
236
 
 
 
 
 
 
 
 
 
 
 
 
237
  data = preprocess_volume(data)
238
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
239
 
@@ -241,25 +250,29 @@ def process_fmri(file_obj):
241
  figs = []
242
 
243
  for stage in ['full', 'region', 'temporal']:
244
- model = SequentialBrainViT(Config())
245
- ckpt = torch.load(f'best_{stage}.pt', map_location=device)
246
- model.load_state_dict(ckpt['model'])
247
- model.eval()
248
-
249
- with torch.no_grad():
250
- outputs = model(data.to(device), torch.tensor([0]).to(device))
251
- results[stage] = {
252
- 'learning_stage': float(outputs['learning_stage'].cpu().mean()),
253
- 'region_activation': outputs['region_activation'].cpu().numpy(),
254
- 'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
255
- }
256
 
257
- fig = plot_results(
258
- results[stage]['region_activation'],
259
- results[stage]['temporal_pattern']
260
- )
261
- figs.append(fig)
262
- plt.close()
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  stage_results = "\n".join([
265
  f"{stage.upper()} MODEL:"
@@ -268,7 +281,7 @@ def process_fmri(file_obj):
268
  for stage, res in results.items()
269
  ])
270
 
271
- return stage_results, figs[0] # return first fig for display
272
 
273
  except Exception as e:
274
  return f"error processing file: {str(e)}", None
 
226
  img = nib.load(file_obj.name)
227
  data = img.get_fdata(dtype=np.float32)
228
 
229
+ # shape validation + expansion
 
 
230
  if data.ndim == 3:
231
+ data = data[None,...] # [H,W,D] -> [T,H,W,D]
232
  elif data.ndim != 4:
233
+ return f"error: volume must be 3D/4D, got {data.ndim}D", None
234
 
235
+ # validate dims
236
+ t,h,w,d = data.shape
237
+ if t < 1 or h < 16 or w < 16 or d < 8:
238
+ return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
239
+ if t > 1000 or h > 256 or w > 256 or d > 256:
240
+ return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
241
+
242
+ # reshape for batch
243
+ data = data.reshape(1, t, h, w, d) # [B,T,H,W,D]
244
+
245
+ # normalize + preprocess
246
  data = preprocess_volume(data)
247
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
248
 
 
250
  figs = []
251
 
252
  for stage in ['full', 'region', 'temporal']:
253
+ try:
254
+ model = SequentialBrainViT(Config())
255
+ ckpt = torch.load(f'best_{stage}.pt', map_location=device)
256
+ model.load_state_dict(ckpt['model'])
257
+ model.eval()
 
 
 
 
 
 
 
258
 
259
+ with torch.no_grad():
260
+ outputs = model(data.to(device), torch.tensor([0]).to(device))
261
+ results[stage] = {
262
+ 'learning_stage': float(outputs['learning_stage'].cpu().mean()),
263
+ 'region_activation': outputs['region_activation'].cpu().numpy(),
264
+ 'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
265
+ }
266
+
267
+ fig = plot_results(
268
+ results[stage]['region_activation'],
269
+ results[stage]['temporal_pattern']
270
+ )
271
+ figs.append(fig)
272
+ plt.close()
273
+
274
+ except Exception as e:
275
+ return f"error in {stage} model: {str(e)}", None
276
 
277
  stage_results = "\n".join([
278
  f"{stage.upper()} MODEL:"
 
281
  for stage, res in results.items()
282
  ])
283
 
284
+ return stage_results, figs[0]
285
 
286
  except Exception as e:
287
  return f"error processing file: {str(e)}", None