Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -364,7 +364,7 @@ class ModelManager:
|
|
| 364 |
return features_dict
|
| 365 |
|
| 366 |
def predict_brain_activity(self, features_dict):
|
| 367 |
-
"""Run brain encoder forward pass."""
|
| 368 |
# Determine which features to use
|
| 369 |
if 'image_multi_layer' in features_dict:
|
| 370 |
input_features = features_dict['image_multi_layer']
|
|
@@ -384,26 +384,43 @@ class ModelManager:
|
|
| 384 |
if len(all_modality_features) > 1:
|
| 385 |
input_features = torch.mean(torch.stack(all_modality_features), dim=0)
|
| 386 |
|
|
|
|
| 387 |
input_features = input_features.to(self.device)
|
| 388 |
|
| 389 |
-
#
|
| 390 |
-
|
| 391 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
|
| 393 |
-
|
|
|
|
|
|
|
| 394 |
|
| 395 |
-
# Compute modality contributions
|
| 396 |
modality_contributions = {}
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
|
|
|
|
|
|
| 407 |
mc_predictions = []
|
| 408 |
for _ in range(10):
|
| 409 |
with torch.no_grad():
|
|
@@ -414,8 +431,18 @@ class ModelManager:
|
|
| 414 |
mc_predictions = np.array(mc_predictions)
|
| 415 |
uncertainty = np.std(mc_predictions, axis=0)
|
| 416 |
|
| 417 |
-
# Compute ROI summaries
|
| 418 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 419 |
|
| 420 |
# Validation checks
|
| 421 |
warnings = self._validate_predictions(pred_np)
|
|
|
|
| 364 |
return features_dict
|
| 365 |
|
| 366 |
def predict_brain_activity(self, features_dict):
|
| 367 |
+
"""Run brain encoder forward pass using BOTH ridge and deep models."""
|
| 368 |
# Determine which features to use
|
| 369 |
if 'image_multi_layer' in features_dict:
|
| 370 |
input_features = features_dict['image_multi_layer']
|
|
|
|
| 384 |
if len(all_modality_features) > 1:
|
| 385 |
input_features = torch.mean(torch.stack(all_modality_features), dim=0)
|
| 386 |
|
| 387 |
+
input_features_np = input_features.cpu().numpy()
|
| 388 |
input_features = input_features.to(self.device)
|
| 389 |
|
| 390 |
+
# ── Primary: Ridge Model (proven baseline from Algonauts 2023) ──
|
| 391 |
+
if self.ridge_model is not None:
|
| 392 |
+
ridge = self.ridge_model
|
| 393 |
+
X_norm = (input_features_np - ridge['feat_mean']) / ridge['feat_std']
|
| 394 |
+
pred_z = ridge['model'].predict(X_norm)
|
| 395 |
+
pred_np = (pred_z * ridge['fmri_std'] + ridge['fmri_mean']).flatten()
|
| 396 |
+
|
| 397 |
+
# Clip extreme values for better visualization (keep 99.5th percentile)
|
| 398 |
+
clip_val = np.percentile(np.abs(pred_np), 99.5)
|
| 399 |
+
pred_np = np.clip(pred_np, -clip_val, clip_val)
|
| 400 |
+
else:
|
| 401 |
+
# Fallback to deep encoder
|
| 402 |
+
with torch.no_grad():
|
| 403 |
+
predictions, _ = self.brain_encoder(input_features, return_intermediates=True)
|
| 404 |
+
pred_np = predictions.cpu().numpy().flatten()
|
| 405 |
|
| 406 |
+
# ── Deep encoder for intermediates and uncertainty ──
|
| 407 |
+
with torch.no_grad():
|
| 408 |
+
deep_pred, intermediates = self.brain_encoder(input_features, return_intermediates=True)
|
| 409 |
|
| 410 |
+
# Compute modality contributions using ridge (more reliable)
|
| 411 |
modality_contributions = {}
|
| 412 |
+
if self.ridge_model is not None:
|
| 413 |
+
for key in ['image_multi_layer', 'text_multi_layer', 'audio_multi_layer']:
|
| 414 |
+
if key in features_dict:
|
| 415 |
+
modality_name = key.split('_')[0]
|
| 416 |
+
feat_np = features_dict[key].cpu().numpy()
|
| 417 |
+
X_n = (feat_np - ridge['feat_mean']) / ridge['feat_std']
|
| 418 |
+
mp = (ridge['model'].predict(X_n) * ridge['fmri_std'] + ridge['fmri_mean']).flatten()
|
| 419 |
+
mp = np.clip(mp, -clip_val, clip_val)
|
| 420 |
+
modality_contributions[modality_name] = mp
|
| 421 |
+
|
| 422 |
+
# Compute uncertainty via dropout MC (deep encoder)
|
| 423 |
+
self.brain_encoder.train()
|
| 424 |
mc_predictions = []
|
| 425 |
for _ in range(10):
|
| 426 |
with torch.no_grad():
|
|
|
|
| 431 |
mc_predictions = np.array(mc_predictions)
|
| 432 |
uncertainty = np.std(mc_predictions, axis=0)
|
| 433 |
|
| 434 |
+
# Compute ROI summaries using z-scored per-voxel predictions
|
| 435 |
+
# This shows which regions are MORE or LESS activated compared to baseline
|
| 436 |
+
if self.ridge_model is not None:
|
| 437 |
+
baseline_mean = self.ridge_model['fmri_mean']
|
| 438 |
+
baseline_std = self.ridge_model['fmri_std']
|
| 439 |
+
# Z-score predictions relative to training distribution
|
| 440 |
+
n_v = min(len(pred_np), len(baseline_mean))
|
| 441 |
+
pred_z = (pred_np[:n_v] - baseline_mean[:n_v]) / (baseline_std[:n_v] + 1e-8)
|
| 442 |
+
else:
|
| 443 |
+
pred_z = pred_np
|
| 444 |
+
|
| 445 |
+
roi_summary = self._compute_roi_summary(pred_z, uncertainty)
|
| 446 |
|
| 447 |
# Validation checks
|
| 448 |
warnings = self._validate_predictions(pred_np)
|