Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| """ | |
| MultiMAE3D Test-Only Evaluation | |
| Load saved finetuned checkpoints and evaluate on the test set. | |
| Reuses model/data/metric utilities from finetune_main.py. | |
| Usage: | |
| # Test all tasks for finetune mode | |
| python test_main.py --mode finetune | |
| # Test a specific task | |
| python test_main.py --mode finetune --tasks "CN vs AD" | |
| # Test with custom checkpoint directory | |
| python test_main.py --mode finetune --checkpoint_dir ./saves/multimae_finetune/ | |
| """ | |
| import os | |
| import sys | |
| import gc | |
| import random | |
| import warnings | |
| from collections import defaultdict | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| from tqdm import tqdm | |
| from scipy.stats import pearsonr | |
| warnings.filterwarnings("ignore") | |
| _BASE_DIR = os.path.dirname(os.path.abspath(__file__)) | |
| sys.path.insert(0, _BASE_DIR) | |
| from models.multimae3d import create_multimae3d, MultiMAE3D | |
| from downstream_dataloader import create_downstream_dataloader | |
| from finetune_main import ( | |
| seed_everything, | |
| str2bool, | |
| MultiMAE3DForDownstream, | |
| run_epoch, | |
| calc_classification_metrics, | |
| calc_regression_metrics, | |
| calc_metrics_by_combo, | |
| _save_combo_results, | |
| ) | |
| # ========================================================================= | |
| # Test-only evaluation for a single (task, seed) | |
| # ========================================================================= | |
def _load_checkpoint_weights(model, checkpoint_path):
    """Restore ``model``'s weights from ``checkpoint_path`` and log its metadata.

    Accepts either a training checkpoint dict holding a ``'model_state_dict'``
    entry (the format this script expects) or a bare state dict.
    """
    print(f" Loading checkpoint: {checkpoint_path}")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects,
    # so only point this at checkpoints from a trusted source.
    # Map to CPU: load_state_dict copies each tensor into the model's own
    # parameters (already on the target device), so materialising the whole
    # checkpoint -- plus any extra entries it may carry -- on the GPU would
    # only waste memory.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    else:
        state_dict = ckpt
    model.load_state_dict(state_dict)
    meta = ckpt if isinstance(ckpt, dict) else {}
    print(f" Loaded (epoch={meta.get('epoch', '?')}, "
          f"best_metric={meta.get('best_metric', '?')})")
    # `ckpt` goes out of scope on return, so its tensors are freed before
    # evaluation starts.


def test_evaluate(args, task_type, seed, device, checkpoint_path):
    """Load a saved checkpoint and evaluate it on the test set.

    Builds the test dataloader and a ``MultiMAE3DForDownstream`` model from the
    architecture flags in ``args``, restores the weights from
    ``checkpoint_path``, runs one evaluation pass through ``run_epoch`` and
    prints the overall metrics plus a per-modality-combination breakdown.
    The per-combination results are persisted via ``_save_combo_results``.

    Args:
        args: Parsed CLI namespace (see ``parse_args``).
        task_type: Task name. ``'CN vs AD'`` and ``'CN vs MCI'`` are treated
            as binary classification (BCE loss); anything else as regression
            (MSE loss).
        seed: Seed passed to ``seed_everything``; also used to tag the saved
            per-combination results.
        device: ``torch.device`` to evaluate on.
        checkpoint_path: Path to a checkpoint dict with a
            ``'model_state_dict'`` entry, or to a bare state dict.

    Returns:
        The overall metrics dict from ``calc_classification_metrics`` or
        ``calc_regression_metrics``.
    """
    seed_everything(seed)
    torch.cuda.empty_cache()
    is_cls = task_type in ('CN vs AD', 'CN vs MCI')

    # ---- Test data loader ----
    loader_kwargs = dict(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=True,
        cache_data=False,
        image_size=tuple(args.image_size),
        base_dir=args.base_dir,
        modalities=args.modalities,
        intersection=args.intersection,
    )
    print(f"\nLoading test data for task={task_type}, seed={seed}")
    test_loader = create_downstream_dataloader(
        excel_path=args.test_excel, labels=[task_type],
        augmentation=False, shuffle=False,
        phase='test', modality_dropout=False, expand_val_combinations=False,
        exclusive_modalities=False,
        **loader_kwargs,
    )
    print(f" Test: {len(test_loader.dataset)} samples")

    # ---- Model (architecture flags must match the checkpoint) ----
    encoder = create_multimae3d(
        img_size=args.img_size,
        patch_size=args.patch_size,
        embed_dim=args.embed_dim,
        depth=args.depth,
        num_heads=args.num_heads,
        decoder_embed_dim=args.decoder_embed_dim,
        decoder_depth=args.decoder_depth,
        decoder_num_heads=args.decoder_num_heads,
    )
    model = MultiMAE3DForDownstream(
        encoder=encoder,
        embed_dim=args.embed_dim,
        num_outputs=1,
        pool=args.pool,
        dropout=args.dropout,
    ).to(device)
    _load_checkpoint_weights(model, checkpoint_path)
    # Disable dropout explicitly; harmless if run_epoch(is_training=False)
    # also switches the model to eval mode.
    model.eval()

    # Criterion
    criterion = (nn.BCEWithLogitsLoss() if is_cls
                 else nn.MSELoss()).to(device)

    # ---- Test evaluation ----
    print(" Evaluating on test set...")
    _test_loss, test_preds, test_labels, test_probs, test_combos = run_epoch(
        test_loader, model, criterion, device, task_type,
        is_training=False,
    )

    # Overall test metrics
    if is_cls:
        test_m = calc_classification_metrics(
            test_preds, test_labels, test_probs)
        print(
            f" Test: Acc={test_m['acc']*100:.2f}%, "
            f"AUC={test_m['auc']*100:.2f}%, "
            f"Sen={test_m['sensitivity']*100:.2f}%, "
            f"Spe={test_m['specificity']*100:.2f}%, "
            f"F1={test_m['f1']*100:.2f}%"
        )
    else:
        test_m = calc_regression_metrics(test_preds, test_labels)
        print(
            f" Test: MAE={test_m['mae']:.4f}, "
            f"RMSE={test_m['rmse']:.4f}, "
            f"Pearson={test_m['pearson']:.4f}"
        )

    # Per-modality-combination breakdown
    combo_results = calc_metrics_by_combo(
        test_preds, test_labels, test_probs, test_combos, task_type)
    if combo_results:
        print("\n Per-modality-combination results:")
        for combo in sorted(combo_results):
            r = combo_results[combo]
            n = r['n_samples']
            if is_cls:
                print(f" {combo:25s} (n={n:3d}) | "
                      f"Acc={r['acc']*100:.1f}%, AUC={r['auc']*100:.1f}%")
            else:
                print(f" {combo:25s} (n={n:3d}) | "
                      f"MAE={r['mae']:.4f}, Pearson={r['pearson']:.4f}")
        # Save per-combo results
        _save_combo_results(combo_results, task_type, seed,
                            f"test_{args.mode}", is_cls)

    # Cleanup: drop references first, collect any cycles, and only then
    # return the now-unreferenced cached CUDA blocks to the driver.
    del model, encoder, test_loader, criterion
    gc.collect()
    torch.cuda.empty_cache()
    return test_m


# This module is named test_main.py and the function starts with "test_", so
# pytest would otherwise collect it as a test (and fail on its arguments).
test_evaluate.__test__ = False
| # ========================================================================= | |
| # Argument parsing | |
| # ========================================================================= | |
def parse_args(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]`` (handy for programmatic calls and tests).
            ``None`` (the default) reads the real command line, as before.

    Returns:
        argparse.Namespace with the evaluation configuration.
    """
    import argparse
    p = argparse.ArgumentParser(
        description='MultiMAE3D Test-Only Evaluation')

    # Mode & checkpoints
    p.add_argument('--mode', type=str, default='finetune',
                   choices=['finetune', 'freeze_then_finetune'],
                   help='Which training mode checkpoints to load')
    p.add_argument('--checkpoint_dir', type=str, default=None,
                   help='Directory containing saved checkpoints. '
                        'Defaults to saves/multimae_{mode}/')

    # Tasks & seeds
    p.add_argument('--tasks', type=str, nargs='+',
                   default=['CN vs AD', 'CN vs MCI', 'MMSE', 'AGE'],
                   help='Tasks to evaluate')
    p.add_argument('--n_seeds', type=int, default=3,
                   help='Number of random seeds per task')

    # Data
    p.add_argument('--test_excel', type=str,
                   default='./data/Downstream/'
                           'ADNI_Division/modality_data_test.xlsx')
    p.add_argument('--base_dir', type=str,
                   default='./data/Downstream/ADNI/')
    p.add_argument('--modalities', type=str, nargs='+',
                   default=['T1', 'T2', 'Flair', 'PET'])
    p.add_argument('--intersection', type=str2bool, default=False)
    # Volume size fed by the dataloader; keep consistent with --img_size,
    # which sizes the encoder.
    p.add_argument('--image_size', type=int, nargs=3,
                   default=[128, 128, 128])
    p.add_argument('--batch_size', type=int, default=4)
    p.add_argument('--num_workers', type=int, default=8)

    # MultiMAE encoder architecture (must match checkpoint)
    p.add_argument('--img_size', type=int, default=128)
    p.add_argument('--patch_size', type=int, default=16)
    p.add_argument('--embed_dim', type=int, default=768)
    p.add_argument('--depth', type=int, default=12)
    p.add_argument('--num_heads', type=int, default=12)
    p.add_argument('--decoder_embed_dim', type=int, default=384)
    p.add_argument('--decoder_depth', type=int, default=2)
    p.add_argument('--decoder_num_heads', type=int, default=12)

    # Downstream head
    p.add_argument('--pool', type=str, default='cls',
                   choices=['cls', 'mean'])
    p.add_argument('--dropout', type=float, default=0.1)

    # Device
    p.add_argument('--device', type=int, default=0)

    # argparse falls back to sys.argv[1:] when argv is None.
    return p.parse_args(argv)
| # ========================================================================= | |
| # Main | |
| # ========================================================================= | |
def main():
    """Evaluate every saved (task, seed) checkpoint and write a summary table.

    For each task in ``--tasks`` and each seed in ``range(--n_seeds)``, looks
    for ``<checkpoint_dir>/<Task_Name>_seed_<seed>_best.pth``; missing
    checkpoints are skipped. Prints the per-task mean +/- std (``np.std``,
    i.e. population std) over the seeds that were found, then a final summary.
    The summary is saved to ``results/multimae_test_<mode>_summary.xlsx``, or
    to a ``.csv`` with the same stem if pandas has no Excel writer installed.
    """
    args = parse_args()
    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    # Resolve checkpoint directory
    if args.checkpoint_dir is None:
        args.checkpoint_dir = os.path.join(
            _BASE_DIR, 'saves', f'multimae_{args.mode}')

    print("=" * 80)
    print("MultiMAE3D Test-Only Evaluation")
    print(f" Mode : {args.mode}")
    print(f" Tasks : {args.tasks}")
    print(f" Seeds : {args.n_seeds}")
    print(f" Checkpoint dir : {args.checkpoint_dir}")
    print(f" Test data : {args.test_excel}")
    print(f" Pool : {args.pool}")
    print(f" Device : {device}")
    print("=" * 80)

    cls_tasks = ('CN vs AD', 'CN vs MCI')
    cls_keys = ['acc', 'auc', 'sensitivity', 'specificity', 'f1']
    reg_keys = ['mae', 'rmse', 'pearson']

    # task -> list of (seed, metrics). The real seed is kept so the per-seed
    # summary columns stay correct when some checkpoints are missing.
    all_results = {}
    for task_type in args.tasks:
        print(f"\n{'='*80}")
        print(f"TASK: {task_type}")
        print(f"{'='*80}")
        is_cls = task_type in cls_tasks
        seed_results = []
        task_str = task_type.replace(' ', '_')
        for seed in range(args.n_seeds):
            ckpt_name = f'{task_str}_seed_{seed}_best.pth'
            ckpt_path = os.path.join(args.checkpoint_dir, ckpt_name)
            if not os.path.isfile(ckpt_path):
                print(f"\n--- Seed {seed} --- SKIPPED (checkpoint not found: {ckpt_name})")
                continue
            print(f"\n--- Seed {seed} ---")
            metrics = test_evaluate(args, task_type, seed, device, ckpt_path)
            seed_results.append((seed, metrics))

        if not seed_results:
            print(f" No checkpoints found for {task_type}, skipping.")
            continue
        all_results[task_type] = seed_results

        # Per-task summary
        n = len(seed_results)
        print(f"\n{task_type} Summary ({n} seeds):")
        if is_cls:
            for key in cls_keys:
                vals = [m[key] * 100 for _, m in seed_results]
                print(f" {key:>12s}: {np.mean(vals):.2f} +/- {np.std(vals):.2f}%")
        else:
            for key in reg_keys:
                vals = [m[key] for _, m in seed_results]
                print(f" {key:>12s}: {np.mean(vals):.4f} +/- {np.std(vals):.4f}")

    # ---- Final summary table ----
    print("\n" + "=" * 80)
    print("FINAL SUMMARY")
    print("=" * 80)
    summary_rows = []
    for task_type in args.tasks:
        if task_type not in all_results:
            continue
        results = all_results[task_type]
        is_cls = task_type in cls_tasks
        row = {'Task': task_type, 'Mode': args.mode, 'N_seeds': len(results)}
        if is_cls:
            for key in cls_keys:
                vals = [m[key] * 100 for _, m in results]
                row[f'{key}_mean'] = np.mean(vals)
                row[f'{key}_std'] = np.std(vals)
                row[key] = f"{np.mean(vals):.2f}+/-{np.std(vals):.2f}"
            for seed, m in results:
                row[f'seed_{seed}_acc'] = m['acc'] * 100
                row[f'seed_{seed}_auc'] = m['auc'] * 100
            vals_acc = [m['acc'] * 100 for _, m in results]
            vals_auc = [m['auc'] * 100 for _, m in results]
            print(f" {task_type:12s} | "
                  f"Acc: {np.mean(vals_acc):.2f}+/-{np.std(vals_acc):.2f}% | "
                  f"AUC: {np.mean(vals_auc):.2f}+/-{np.std(vals_auc):.2f}%")
        else:
            for key in reg_keys:
                vals = [m[key] for _, m in results]
                row[f'{key}_mean'] = np.mean(vals)
                row[f'{key}_std'] = np.std(vals)
                row[key] = f"{np.mean(vals):.4f}+/-{np.std(vals):.4f}"
            for seed, m in results:
                row[f'seed_{seed}_mae'] = m['mae']
                row[f'seed_{seed}_pearson'] = m['pearson']
            vals_mae = [m['mae'] for _, m in results]
            vals_r = [m['pearson'] for _, m in results]
            print(f" {task_type:12s} | "
                  f"MAE: {np.mean(vals_mae):.4f}+/-{np.std(vals_mae):.4f} | "
                  f"Pearson: {np.mean(vals_r):.4f}+/-{np.std(vals_r):.4f}")
        summary_rows.append(row)

    # Save summary Excel
    if summary_rows:
        results_dir = os.path.join(_BASE_DIR, 'results')
        os.makedirs(results_dir, exist_ok=True)
        summary_path = os.path.join(
            results_dir, f'multimae_test_{args.mode}_summary.xlsx')
        summary_df = pd.DataFrame(summary_rows)
        try:
            summary_df.to_excel(summary_path, index=False)
        except ImportError:
            # .xlsx output needs an optional engine (e.g. openpyxl); don't
            # throw away a long evaluation run just because it is missing.
            summary_path = os.path.splitext(summary_path)[0] + '.csv'
            summary_df.to_csv(summary_path, index=False)
        print(f"\nSummary saved to: {summary_path}")
    print("=" * 80)
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()