| |
| """ |
| Experiment G: Per-subject diagnostic analysis. |
| |
| Load the best scene-recognition checkpoint(s) from previous T1 runs and |
| produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether |
| aggregate metrics are driven by one or two outlier subjects, as reviewers |
| often ask. |
| |
| Inference only (no training); runs on CUDA when available, otherwise on CPU. |
| """ |
|
|
| import os |
| import sys |
| import json |
| import glob |
| import argparse |
| import numpy as np |
| import torch |
| from sklearn.metrics import accuracy_score, f1_score |
|
|
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
| from data.dataset import ( |
| MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES, |
| get_dataloaders, |
| ) |
| from nets.models import build_model |
|
|
|
|
def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate one model on each test volunteer separately.

    One ``MultimodalSceneDataset`` is built per entry of ``TEST_VOLS`` and the
    model is run over it one sample at a time (batch of one, all-True mask).

    Args:
        model: Called as ``model(x, mask)``; the returned logits are
            arg-maxed over dim 1 to obtain the predicted class. Must expose
            ``eval()``.
        device: Torch device the inputs and the mask are placed on.
        modalities: Forwarded unchanged to ``MultimodalSceneDataset``.
        stats: Forwarded as ``stats=`` to ``MultimodalSceneDataset``.
        downsample: Forwarded as ``downsample=`` to ``MultimodalSceneDataset``.

    Returns:
        dict keyed by volunteer. A volunteer with an empty dataset maps to
        ``{'n': 0}``; otherwise to a dict with ``'n'``, ``'acc'``, ``'f1'``
        (macro-averaged), ``'preds'``, ``'labels'`` and ``'samples'``
        (the dataset's ``sample_info``).
    """
    # Evaluation mode is a property of the model, not of a volunteer: set it
    # once instead of once per dataset.
    model.eval()

    breakdown = {}
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities, downsample=downsample,
                                    stats=stats)
        n_samples = len(ds)
        if n_samples == 0:
            breakdown[vol] = {'n': 0}
            continue

        preds, ys = [], []
        with torch.no_grad():
            # Index explicitly: only __len__/__getitem__ are known to exist on
            # the dataset, so plain iteration is not guaranteed to terminate.
            for i in range(n_samples):
                x, y = ds[i]
                x = x.to(device).unsqueeze(0)  # add the batch dimension
                # Build the mask directly on the target device rather than on
                # the CPU followed by a copy.
                mask = torch.ones(1, x.size(1), dtype=torch.bool,
                                  device=device)
                logits = model(x, mask)
                preds.append(logits.argmax(dim=1).item())
                # Store plain Python ints: this result is later written with
                # json.dump by run_on_checkpoint, which cannot serialise
                # tensors or NumPy scalars.
                # NOTE(review): assumes each label is a scalar class index --
                # confirm against MultimodalSceneDataset.__getitem__.
                ys.append(int(y))

        breakdown[vol] = {
            'n': n_samples,
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro', zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown
|
|
|
|
def _read_run_args(args_json_path):
    """Return the ``'args'`` entry of a run's results JSON, closing the file."""
    with open(args_json_path) as fh:
        return json.load(fh)['args']


def _train_volunteers():
    """Return ``TRAIN_VOLS``, the volunteers used to compute the input stats."""
    try:
        from experiments.dataset import TRAIN_VOLS
    except ImportError:
        # The dataset class and TEST_VOLS used by this file come from
        # data.dataset, so fall back to that module when experiments.dataset
        # is not importable.
        # NOTE(review): presumably TRAIN_VOLS lives next to TEST_VOLS --
        # confirm which module is canonical.
        from data.dataset import TRAIN_VOLS
    return TRAIN_VOLS


def _load_weights(model, ckpt_path, device):
    """Load ``ckpt_path`` into ``model`` and report any key mismatch."""
    try:
        state = torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception:
        # NOTE(review): this retry performs a full pickle load, which can
        # execute arbitrary code; only point the script at trusted checkpoints.
        state = torch.load(ckpt_path, map_location=device)

    # strict=False tolerates architecture drift, but a silent mismatch would
    # evaluate partly random weights, so surface it.
    outcome = model.load_state_dict(state, strict=False)
    missing = list(getattr(outcome, 'missing_keys', ()))
    unexpected = list(getattr(outcome, 'unexpected_keys', ()))
    if missing or unexpected:
        print(f" WARN {ckpt_path}: state_dict mismatch "
              f"(missing={missing}, unexpected={unexpected})")


def _pool_predictions(breakdown):
    """Concatenate predictions and labels of all non-empty volunteers."""
    pooled_preds, pooled_labels = [], []
    for entry in breakdown.values():
        if entry.get('n', 0) > 0:
            pooled_preds.extend(entry['preds'])
            pooled_labels.extend(entry['labels'])
    return pooled_preds, pooled_labels


def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Evaluate one checkpoint per test volunteer and write a JSON report.

    Args:
        ckpt_path: Path of the saved model weights.
        args_json_path: Path of the run's JSON file; its ``'args'`` entry
            supplies the modalities and model hyper-parameters.
        output_dir: Directory (created if missing) that receives
            ``<name of the checkpoint's directory>_per_subject.json``.

    Returns:
        The summary dict that was written: ``'ckpt'``, ``'modalities'``,
        ``'overall'`` (metrics pooled over all test samples),
        ``'per_subject'`` (n/acc/f1 per volunteer) and ``'detail'`` (the full
        output of ``per_subject_eval``).
    """
    ckpt_args = _read_run_args(args_json_path)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 'modalities' is stored either as a list or as a comma-separated string.
    modalities = ckpt_args['modalities']
    if not isinstance(modalities, list):
        modalities = modalities.split(',')
    downsample = ckpt_args.get('downsample', 5)

    # Only the metadata (feature dimensions) is needed; the loaders themselves
    # are discarded.
    _, _, _, info = get_dataloaders(modalities,
                                    batch_size=ckpt_args.get('batch_size', 16),
                                    downsample=downsample)

    # Stats taken from the training volunteers are handed to the per-volunteer
    # test datasets.
    train_ds = MultimodalSceneDataset(_train_volunteers(), modalities,
                                      downsample=downsample)
    stats = train_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'], NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)
    _load_weights(model, ckpt_path, device)

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    all_preds, all_ys = _pool_predictions(breakdown)
    overall_f1 = float(f1_score(all_ys, all_preds, average='macro', zero_division=0))
    overall_acc = float(accuracy_score(all_ys, all_preds))

    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }

    os.makedirs(output_dir, exist_ok=True)
    run_name = os.path.basename(os.path.dirname(ckpt_path))
    out_path = os.path.join(output_dir, run_name + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2)

    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f" {v}: n={b['n']} acc={b.get('acc'):.3f} f1={b.get('f1'):.3f}")
        else:
            # Empty volunteers carry no metrics to format.
            print(f" {v}: (empty)")
    return summary
|
|
|
|
def main():
    """Command-line entry point: evaluate every run found under ``--exp_root``.

    A run is an entry of ``--exp_root`` (other than ``slurm_logs``) that
    contains both ``model_best.pt`` and ``results.json``. Each run is handed
    to ``run_on_checkpoint``; a failing run is reported on stdout and the
    remaining runs are still processed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--exp_root', type=str, required=True,
        help='Directory containing run subdirs with model_best.pt and results.json')
    parser.add_argument('--output_dir', type=str, required=True)
    cli = parser.parse_args()

    # Lazily pair up the two expected files for every candidate entry, in
    # sorted order, leaving the scheduler log directory out.
    candidates = (
        (os.path.join(cli.exp_root, name, 'model_best.pt'),
         os.path.join(cli.exp_root, name, 'results.json'))
        for name in sorted(os.listdir(cli.exp_root))
        if name != 'slurm_logs'
    )
    # Keep only the pairs whose files are both present.
    runs = [pair for pair in candidates
            if all(os.path.exists(path) for path in pair)]
    print(f"Found {len(runs)} runs with checkpoints.")

    for ckpt_file, results_file in runs:
        try:
            run_on_checkpoint(ckpt_file, results_file, cli.output_dir)
        except Exception as err:
            # Best-effort sweep: one broken run must not abort the others.
            print(f" FAIL {ckpt_file}: {err}")
|
|
|
|
# Run the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|