Spaces:
Sleeping
Sleeping
shape range val
Browse files
app.py
CHANGED
|
@@ -226,14 +226,23 @@ def process_fmri(file_obj):
|
|
| 226 |
img = nib.load(file_obj.name)
|
| 227 |
data = img.get_fdata(dtype=np.float32)
|
| 228 |
|
| 229 |
-
|
| 230 |
-
return f"error: expected 4D data, got {data.ndim}D", None
|
| 231 |
-
|
| 232 |
if data.ndim == 3:
|
| 233 |
-
data = data[None,...] #
|
| 234 |
elif data.ndim != 4:
|
| 235 |
-
return f"error:
|
| 236 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 237 |
data = preprocess_volume(data)
|
| 238 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 239 |
|
|
@@ -241,25 +250,29 @@ def process_fmri(file_obj):
|
|
| 241 |
figs = []
|
| 242 |
|
| 243 |
for stage in ['full', 'region', 'temporal']:
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
with torch.no_grad():
|
| 250 |
-
outputs = model(data.to(device), torch.tensor([0]).to(device))
|
| 251 |
-
results[stage] = {
|
| 252 |
-
'learning_stage': float(outputs['learning_stage'].cpu().mean()),
|
| 253 |
-
'region_activation': outputs['region_activation'].cpu().numpy(),
|
| 254 |
-
'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
|
| 255 |
-
}
|
| 256 |
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
results[stage]
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 263 |
|
| 264 |
stage_results = "\n".join([
|
| 265 |
f"{stage.upper()} MODEL:"
|
|
@@ -268,7 +281,7 @@ def process_fmri(file_obj):
|
|
| 268 |
for stage, res in results.items()
|
| 269 |
])
|
| 270 |
|
| 271 |
-
return stage_results, figs[0]
|
| 272 |
|
| 273 |
except Exception as e:
|
| 274 |
return f"error processing file: {str(e)}", None
|
|
|
|
| 226 |
img = nib.load(file_obj.name)
|
| 227 |
data = img.get_fdata(dtype=np.float32)
|
| 228 |
|
| 229 |
+
# shape validation + expansion
|
|
|
|
|
|
|
| 230 |
if data.ndim == 3:
|
| 231 |
+
data = data[None,...] # [H,W,D] -> [T,H,W,D]
|
| 232 |
elif data.ndim != 4:
|
| 233 |
+
return f"error: volume must be 3D/4D, got {data.ndim}D", None
|
| 234 |
|
| 235 |
+
# validate dims
|
| 236 |
+
t,h,w,d = data.shape
|
| 237 |
+
if t < 1 or h < 16 or w < 16 or d < 8:
|
| 238 |
+
return f"error: invalid dims {data.shape}, min: [1,16,16,8]", None
|
| 239 |
+
if t > 1000 or h > 256 or w > 256 or d > 256:
|
| 240 |
+
return f"error: dims too large {data.shape}, max: [1000,256,256,256]", None
|
| 241 |
+
|
| 242 |
+
# reshape for batch
|
| 243 |
+
data = data.reshape(1, t, h, w, d) # [B,T,H,W,D]
|
| 244 |
+
|
| 245 |
+
# normalize + preprocess
|
| 246 |
data = preprocess_volume(data)
|
| 247 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 248 |
|
|
|
|
| 250 |
figs = []
|
| 251 |
|
| 252 |
for stage in ['full', 'region', 'temporal']:
|
| 253 |
+
try:
|
| 254 |
+
model = SequentialBrainViT(Config())
|
| 255 |
+
ckpt = torch.load(f'best_{stage}.pt', map_location=device)
|
| 256 |
+
model.load_state_dict(ckpt['model'])
|
| 257 |
+
model.eval()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
+
with torch.no_grad():
|
| 260 |
+
outputs = model(data.to(device), torch.tensor([0]).to(device))
|
| 261 |
+
results[stage] = {
|
| 262 |
+
'learning_stage': float(outputs['learning_stage'].cpu().mean()),
|
| 263 |
+
'region_activation': outputs['region_activation'].cpu().numpy(),
|
| 264 |
+
'temporal_pattern': outputs['temporal_pattern'].cpu().numpy()
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
fig = plot_results(
|
| 268 |
+
results[stage]['region_activation'],
|
| 269 |
+
results[stage]['temporal_pattern']
|
| 270 |
+
)
|
| 271 |
+
figs.append(fig)
|
| 272 |
+
plt.close()
|
| 273 |
+
|
| 274 |
+
except Exception as e:
|
| 275 |
+
return f"error in {stage} model: {str(e)}", None
|
| 276 |
|
| 277 |
stage_results = "\n".join([
|
| 278 |
f"{stage.upper()} MODEL:"
|
|
|
|
| 281 |
for stage, res in results.items()
|
| 282 |
])
|
| 283 |
|
| 284 |
+
return stage_results, figs[0]
|
| 285 |
|
| 286 |
except Exception as e:
|
| 287 |
return f"error processing file: {str(e)}", None
|