#!/usr/bin/env python3
"""
Experiment G: Per-subject diagnostic analysis.

Load the best scene-recognition checkpoint(s) from previous T1 runs and
produce a per-test-volunteer breakdown of F1 and Accuracy. Reveals whether
aggregate metrics are driven by one or two outlier subjects, as reviewers
often ask.

Inference only, no training. Runs on CUDA when available, otherwise on CPU.
"""

import os
import sys
import json
import glob
import argparse
import importlib

import numpy as np
import torch
from sklearn.metrics import accuracy_score, f1_score

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from data.dataset import (
    MultimodalSceneDataset, TEST_VOLS, SCENE_LABELS, NUM_CLASSES,
    get_dataloaders,
)
from nets.models import build_model


def _label_to_int(y):
    """Return a dataset label as a plain Python int.

    Anything exposing ``.item()`` (NumPy scalars, 0-d tensors) is unwrapped
    first, so the label list stays JSON-serializable and homogeneous.
    """
    return int(y.item()) if hasattr(y, 'item') else int(y)


def _json_default(obj):
    """``json.dump`` fallback for NumPy / torch values only.

    Any other unsupported type still raises ``TypeError`` so that unexpected
    objects are not silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, (np.ndarray, torch.Tensor)):
        return obj.tolist()
    raise TypeError(
        f"Object of type {type(obj).__name__} is not JSON serializable")


def _load_train_vols():
    """Return the ``TRAIN_VOLS`` constant from the dataset module.

    ``experiments.dataset`` (the original lookup path) is tried first so that
    layouts which already worked behave the same; ``data.dataset`` -- the
    module every other dataset name in this file is imported from -- is the
    fallback.

    Raises:
        ImportError: if neither module provides ``TRAIN_VOLS``.
    """
    last_err = None
    for mod_name in ('experiments.dataset', 'data.dataset'):
        try:
            return importlib.import_module(mod_name).TRAIN_VOLS
        except (ImportError, AttributeError) as err:
            last_err = err
    raise ImportError(
        "TRAIN_VOLS not found in experiments.dataset or data.dataset"
    ) from last_err


def per_subject_eval(model, device, modalities, stats, downsample):
    """Evaluate one model across each test volunteer separately.

    Args:
        model: Callable as ``model(x, mask)`` returning per-class logits.
        device: Torch device the inputs are moved to.
        modalities: Modalities forwarded to ``MultimodalSceneDataset``.
        stats: Statistics object forwarded to the dataset as ``stats=``.
        downsample: Downsample setting forwarded to the dataset.

    Returns:
        Dict keyed by volunteer. Volunteers with no samples map to
        ``{'n': 0}``; all others map to a dict with ``n``, ``acc``, ``f1``
        (macro), ``preds``, ``labels`` and ``samples``.
    """
    breakdown = {}
    model.eval()
    for vol in TEST_VOLS:
        ds = MultimodalSceneDataset([vol], modalities,
                                    downsample=downsample, stats=stats)
        n_samples = len(ds)
        if n_samples == 0:
            breakdown[vol] = {'n': 0}
            continue

        preds, ys = [], []
        with torch.no_grad():
            for i in range(n_samples):
                x, y = ds[i]
                # Samples are fed one at a time as a batch of size 1.
                x = x.to(device).unsqueeze(0)
                # All-True mask over dim 1 of the batched input
                # (presumably the time axis -- TODO confirm against model).
                mask = torch.ones(1, x.size(1), dtype=torch.bool,
                                  device=device)
                logits = model(x, mask)
                preds.append(logits.argmax(dim=1).cpu().item())
                ys.append(_label_to_int(y))

        breakdown[vol] = {
            'n': n_samples,
            'acc': float(accuracy_score(ys, preds)),
            'f1': float(f1_score(ys, preds, average='macro',
                                 zero_division=0)),
            'preds': preds,
            'labels': ys,
            'samples': ds.sample_info,
        }
    return breakdown


def run_on_checkpoint(ckpt_path, args_json_path, output_dir):
    """Evaluate one checkpoint per test volunteer and write a JSON report.

    Args:
        ckpt_path: Path to the saved state dict (``model_best.pt``).
        args_json_path: Path to the run's JSON file; its ``'args'`` entry
            supplies the model / data hyper-parameters.
        output_dir: Directory for ``<run_dir_name>_per_subject.json``
            (created if missing).

    Returns:
        The summary dict that was written to disk.
    """
    with open(args_json_path) as f:
        ckpt_args = json.load(f)['args']
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # 'modalities' is accepted either as a list or a comma-separated string.
    raw_modalities = ckpt_args['modalities']
    modalities = (raw_modalities if isinstance(raw_modalities, list)
                  else raw_modalities.split(','))
    downsample = ckpt_args.get('downsample', 5)

    # The loaders are only built for their `info` dict (feature dims).
    _, _, _, info = get_dataloaders(
        modalities,
        batch_size=ckpt_args.get('batch_size', 16),
        downsample=downsample)
    # Need the actual stats object -- re-load train set to compute it.
    tr_ds = MultimodalSceneDataset(_load_train_vols(), modalities,
                                   downsample=downsample)
    stats = tr_ds.get_stats()

    model = build_model(
        ckpt_args.get('model', 'transformer'),
        ckpt_args.get('fusion', 'late'),
        info['feat_dim'], info['modality_dims'],
        NUM_CLASSES,
        hidden_dim=ckpt_args.get('hidden_dim', 128),
        proj_dim=ckpt_args.get('proj_dim', 0),
        late_agg=ckpt_args.get('late_agg', 'mean'),
    ).to(device)

    try:
        sd = torch.load(ckpt_path, weights_only=True, map_location=device)
    except Exception:
        # NOTE(security): this fallback performs a full pickle load, which
        # can execute arbitrary code. Only point this script at checkpoints
        # you produced yourself.
        sd = torch.load(ckpt_path, map_location=device)

    # strict=False tolerates key mismatches; report them so a partially
    # (or entirely) uninitialised model is not evaluated unnoticed.
    load_result = model.load_state_dict(sd, strict=False)
    missing = list(getattr(load_result, 'missing_keys', []))
    unexpected = list(getattr(load_result, 'unexpected_keys', []))
    if missing or unexpected:
        print(f" WARN {ckpt_path}: {len(missing)} missing / "
              f"{len(unexpected)} unexpected state-dict keys")

    breakdown = per_subject_eval(model, device, modalities, stats, downsample)

    # Overall metrics pool the predictions of every non-empty volunteer.
    all_preds, all_ys = [], []
    for info_v in breakdown.values():
        if info_v.get('n', 0) > 0:
            all_preds.extend(info_v['preds'])
            all_ys.extend(info_v['labels'])
    overall_f1 = float(f1_score(all_ys, all_preds, average='macro',
                                zero_division=0))
    overall_acc = float(accuracy_score(all_ys, all_preds))

    # Per-subject summary
    summary = {
        'ckpt': ckpt_path,
        'modalities': modalities,
        'overall': {'acc': overall_acc, 'f1': overall_f1,
                    'n': len(all_preds)},
        'per_subject': {
            v: {'n': b.get('n'), 'acc': b.get('acc'), 'f1': b.get('f1')}
            for v, b in breakdown.items()
        },
        'detail': breakdown,
    }

    os.makedirs(output_dir, exist_ok=True)
    # The report is named after the checkpoint's parent (run) directory.
    run_name = os.path.basename(os.path.dirname(ckpt_path))
    out_path = os.path.join(output_dir, run_name + '_per_subject.json')
    with open(out_path, 'w') as f:
        json.dump(summary, f, indent=2, default=_json_default)

    print(f"Per-subject breakdown saved: {out_path}")
    print(f"Overall F1: {overall_f1:.4f} Acc: {overall_acc:.4f}")
    for v, b in summary['per_subject'].items():
        if b.get('n'):
            print(f" {v}: n={b['n']} acc={b.get('acc'):.3f} "
                  f"f1={b.get('f1'):.3f}")
        else:
            print(f" {v}: (empty)")
    return summary


def main():
    """Run the per-subject analysis on every run directory under --exp_root.

    A run is any sub-directory (other than ``slurm_logs``) that contains both
    ``model_best.pt`` and ``results.json``. A failure in one run is reported
    and does not stop the remaining runs.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--exp_root', type=str, required=True,
                   help='Directory containing run subdirs with model_best.pt and results.json')
    p.add_argument('--output_dir', type=str, required=True)
    args = p.parse_args()

    runs = []
    for sub in sorted(os.listdir(args.exp_root)):
        if sub == 'slurm_logs':
            continue
        ckpt = os.path.join(args.exp_root, sub, 'model_best.pt')
        res = os.path.join(args.exp_root, sub, 'results.json')
        if os.path.exists(ckpt) and os.path.exists(res):
            runs.append((ckpt, res))

    print(f"Found {len(runs)} runs with checkpoints.")
    for ckpt, res in runs:
        try:
            run_on_checkpoint(ckpt, res, args.output_dir)
        except Exception as e:
            # Best-effort batch: log the failing run and keep going.
            print(f" FAIL {ckpt}: {e}")


if __name__ == '__main__':
    main()