| """Re-score an existing framework checkpoint at a COMMON evaluation resolution R. |
| |
| Needed for resolution-fair comparison: a model that takes a fixed input size |
| (SwinUNet=224, TransUNet=256) is run at its native input size, but its prediction |
| and the GROUND TRUTH (loaded at native, not the 256-degraded dataloader copy) are |
| both resized to R, and metrics are computed at R — matching how the conv methods |
| (trained at R) and nnU-Net/U-Mamba (re-scored with --eval_size R) are evaluated. |
| |
| python framework/eval_at_res.py --data_root <root> --dataset fives --protocol official \ |
| --arch swinunet --seed 0 --eval_size 768 --exp_name baselines |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import sys |
| import json |
| import argparse |
|
|
| import numpy as np |
| import cv2 |
| import torch |
|
|
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
| from framework.models.registry import build_model, required_img_size |
| from framework.metrics.metrics import per_image_metrics, aggregate |
| from framework.data.unified_dataset import UnifiedSegDataset |
| from framework.data.transforms import build_transform |
|
|
|
|
def _to_R(arr, R):
    """Nearest-neighbour resize of an integer label map to an R x R grid.

    Nearest-neighbour interpolation keeps every output pixel equal to one of the
    input labels, so class ids are never blended. The output is square, so a
    non-square input is stretched.

    Args:
        arr: 2-D array of non-negative integer labels (bool is also accepted).
        R: target side length in pixels.

    Returns:
        (R, R) int64 array of labels.

    Raises:
        ValueError: if ``arr`` holds negative labels or labels above 65535. Those
            cannot be represented in the unsigned dtypes passed to cv2.resize here.
    """
    arr = np.asarray(arr)
    # cv2.resize does not accept int64 input, so narrow to the smallest unsigned
    # dtype that holds every label. A blind uint8 cast would silently wrap
    # labels >= 256 into wrong class ids.
    lo, hi = (int(arr.min()), int(arr.max())) if arr.size else (0, 0)
    if lo < 0 or hi > np.iinfo(np.uint16).max:
        raise ValueError(f"labels must lie in [0, 65535], got range [{lo}, {hi}]")
    work_dtype = np.uint8 if hi <= np.iinfo(np.uint8).max else np.uint16
    resized = cv2.resize(arr.astype(work_dtype), (R, R), interpolation=cv2.INTER_NEAREST)
    return resized.astype(np.int64)
|
|
|
|
def _parse_args(argv=None):
    """Parse this script's command-line options (usage example in the module docstring)."""
    ap = argparse.ArgumentParser(
        description="Re-score a framework checkpoint at a common evaluation resolution R.")
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--arch", required=True)
    ap.add_argument("--encoder", default="resnet50")
    ap.add_argument("--seed", type=int, required=True)
    ap.add_argument("--eval_size", type=int, required=True, help="common resolution R")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--normalize", default="auto")
    ap.add_argument("--out_name", default="metrics.json",
                    help="output file inside the run directory; the default overwrites "
                         "any existing metrics.json there")
    return ap.parse_args(argv)


def _load_gt(path, num_classes):
    """Read a ground-truth mask from disk, at its native resolution, as a label map.

    The mask is read as single-channel grayscale with OpenCV. This bypasses the
    dataset transform and whatever label mapping the dataset class applies (not
    visible from this script).

    Args:
        path: mask file path.
        num_classes: number of classes the model predicts.

    Returns:
        2-D uint8 array. For binary tasks (num_classes == 2) whose mask holds values
        >= 2, e.g. 0/255 encoding, foreground is mapped to 1.

    Raises:
        SystemExit: if OpenCV cannot read the file. cv2.imread returns None for
            missing files and for formats this OpenCV build does not decode.
    """
    gt = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if gt is None:
        raise SystemExit(f"could not read ground-truth mask: {path}")
    if num_classes == 2 and gt.max() >= num_classes:
        # Binary masks are commonly stored as 0/255, and lossy (JPEG) masks carry
        # small non-zero noise. Threshold at the midpoint so the GT shares the 0/1
        # label space of the argmax prediction.
        gt = (gt > 127).astype(np.uint8)
    return gt


def main():
    """Re-score one checkpoint at the common resolution R and write a metrics JSON.

    Steps:
    1. Build the test split with a transform sized for the model's input.
    2. Load
       ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>/best.pth``.
    3. Predict each test image at that input size.
    4. Resize the prediction and the native-resolution ground truth to R x R with
       nearest-neighbour interpolation.
    5. Aggregate the per-image metrics into ``<run dir>/<out_name>``.

    Raises:
        SystemExit: if the test split is empty, the checkpoint is missing, or a
            ground-truth mask cannot be read.
    """
    args = _parse_args()

    R = args.eval_size
    # Architectures with a fixed input size (e.g. SwinUNet) run at that size.
    # required_img_size presumably returns a falsy value otherwise, in which
    # case the model runs directly at R.
    model_res = required_img_size(args.arch) or R

    ds = UnifiedSegDataset(args.data_root, args.dataset, args.protocol, "test", transform=None)
    ds.transform = build_transform(model_res, ds.in_channels, train=False,
                                   aug="none", normalize=args.normalize)
    num_classes = ds.num_classes
    if len(ds) == 0:
        raise SystemExit(f"empty test split: {args.dataset}/{args.protocol}")

    out_dir = os.path.join(args.out_root, args.exp_name,
                           f"{args.dataset}_{args.protocol}", args.arch, f"seed{args.seed}")
    ckpt_path = os.path.join(out_dir, "best.pth")
    if not os.path.isfile(ckpt_path):
        raise SystemExit(f"checkpoint not found: {ckpt_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(args.arch, in_channels=ds.in_channels, num_classes=num_classes,
                        img_size=model_res, encoder=args.encoder,
                        encoder_weights="none", pretrained_ckpt="")
    # weights_only=False unpickles arbitrary objects: only point this script at
    # checkpoints produced by this pipeline.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Weights may be wrapped under a "model" key; a bare state dict also works.
    model.load_state_dict(ckpt.get("model", ckpt))
    model = model.to(device).eval()

    records = []
    out_of_range = []  # GT files whose raw values cannot all be valid class ids
    with torch.no_grad():
        for idx in range(len(ds)):
            img = ds[idx]["image"].unsqueeze(0).to(device)
            pred = model(img).argmax(1)[0].cpu().numpy()
            # Re-read the GT from disk at native resolution instead of taking the
            # dataloader's copy, which went through the resizing transform.
            # pairs[idx][1] is presumably the mask path -- confirm in UnifiedSegDataset.
            gt_path = ds.pairs[idx][1]
            gt = _load_gt(gt_path, num_classes)
            if gt.max() >= num_classes:
                out_of_range.append(gt_path)
            records.append(per_image_metrics(_to_R(pred, R), _to_R(gt, R), num_classes,
                                             include_background=False, compute_hd95=True))

    if out_of_range:
        print(f"[eval_at_res] WARNING: {len(out_of_range)} GT mask(s) contain values >= "
              f"num_classes={num_classes} (e.g. {out_of_range[0]}); raw pixel values "
              f"were used as class ids", file=sys.stderr)

    agg = aggregate(records)
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.seed, "num_classes": num_classes,
           "eval_size": R, "metrics": agg, "per_image": records}
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, args.out_name), "w") as f:
        json.dump(out, f, indent=2)
    print(f"[eval_at_res] {args.dataset}/{args.protocol} {args.arch} seed{args.seed} @R={R}: "
          f"n={len(records)} dice={agg['dice_mean']:.4f} hd95={agg['hd95_mean']:.2f} -> {out_dir}")
|
|
|
|
if __name__ == "__main__":
    # Re-score only when executed as a script, not when this module is imported.
    main()
|
|