"""Re-score an existing framework checkpoint at a COMMON evaluation resolution R. Needed for resolution-fair comparison: a model that takes a fixed input size (SwinUNet=224, TransUNet=256) is run at its native input size, but its prediction and the GROUND TRUTH (loaded at native, not the 256-degraded dataloader copy) are both resized to R, and metrics are computed at R — matching how the conv methods (trained at R) and nnU-Net/U-Mamba (re-scored with --eval_size R) are evaluated. python framework/eval_at_res.py --data_root --dataset fives --protocol official \ --arch swinunet --seed 0 --eval_size 768 --exp_name baselines """ from __future__ import annotations import os import sys import json import argparse import numpy as np import cv2 import torch sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) from framework.models.registry import build_model, required_img_size from framework.metrics.metrics import per_image_metrics, aggregate from framework.data.unified_dataset import UnifiedSegDataset from framework.data.transforms import build_transform def _to_R(arr, R): return cv2.resize(arr.astype(np.uint8), (R, R), interpolation=cv2.INTER_NEAREST).astype(np.int64) def main(): ap = argparse.ArgumentParser() ap.add_argument("--data_root", required=True) ap.add_argument("--dataset", required=True) ap.add_argument("--protocol", required=True) ap.add_argument("--arch", required=True) ap.add_argument("--encoder", default="resnet50") ap.add_argument("--seed", type=int, required=True) ap.add_argument("--eval_size", type=int, required=True, help="common resolution R") ap.add_argument("--exp_name", default="baselines") ap.add_argument("--out_root", default="results") ap.add_argument("--normalize", default="auto") args = ap.parse_args() R = args.eval_size model_res = required_img_size(args.arch) or R # SwinUNet 224 / TransUNet 256 / conv -> R ds = UnifiedSegDataset(args.data_root, args.dataset, args.protocol, "test", transform=None) ds.transform = build_transform(model_res, ds.in_channels, train=False, aug="none", normalize=args.normalize) num_classes = ds.num_classes out_dir = os.path.join(args.out_root, args.exp_name, f"{args.dataset}_{args.protocol}", args.arch, f"seed{args.seed}") ckpt_path = os.path.join(out_dir, "best.pth") if not os.path.isfile(ckpt_path): raise SystemExit(f"checkpoint not found: {ckpt_path}") device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = build_model(args.arch, in_channels=ds.in_channels, num_classes=num_classes, img_size=model_res, encoder=args.encoder, encoder_weights="none", pretrained_ckpt="") ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) model.load_state_dict(ckpt.get("model", ckpt)) model = model.to(device).eval() records = [] with torch.no_grad(): for idx in range(len(ds)): item = ds[idx] img = item["image"].unsqueeze(0).to(device) # 1,C,model_res,model_res pred = model(img).argmax(1)[0].cpu().numpy() # model_res x model_res gt = cv2.imread(ds.pairs[idx][1], cv2.IMREAD_GRAYSCALE) # native H x W, values 0..C-1 records.append(per_image_metrics(_to_R(pred, R), _to_R(gt, R), num_classes, include_background=False, compute_hd95=True)) agg = aggregate(records) out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch, "seed": args.seed, "num_classes": num_classes, "eval_size": R, "metrics": agg, "per_image": records} os.makedirs(out_dir, exist_ok=True) with open(os.path.join(out_dir, "metrics.json"), "w") as f: json.dump(out, f, indent=2) print(f"[eval_at_res] {args.dataset}/{args.protocol} {args.arch} seed{args.seed} @R={R}: " f"n={len(records)} dice={agg['dice_mean']:.4f} hd95={agg['hd95_mean']:.2f} -> {out_dir}") if __name__ == "__main__": main()