# GenSeg-Baselines / code / framework / eval_at_res.py
# (Hugging Face file-viewer header, not part of the program.)
# Commit b8fae22 (verified) by MaybeRichard: "Upload folder using huggingface_hub" - 4.2 kB
"""Re-score an existing framework checkpoint at a COMMON evaluation resolution R.
Needed for resolution-fair comparison: a model that takes a fixed input size
(SwinUNet=224, TransUNet=256) is run at its native input size, but its prediction
and the GROUND TRUTH (loaded at native, not the 256-degraded dataloader copy) are
both resized to R, and metrics are computed at R — matching how the conv methods
(trained at R) and nnU-Net/U-Mamba (re-scored with --eval_size R) are evaluated.

Usage:
    python framework/eval_at_res.py --data_root <root> --dataset fives --protocol official \
        --arch swinunet --seed 0 --eval_size 768 --exp_name baselines
"""
from __future__ import annotations
import os
import sys
import json
import argparse
import numpy as np
import cv2
import torch
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from framework.models.registry import build_model, required_img_size
from framework.metrics.metrics import per_image_metrics, aggregate
from framework.data.unified_dataset import UnifiedSegDataset
from framework.data.transforms import build_transform
def _to_R(arr, R):
    """Resize a label map to ``R x R`` with nearest-neighbour interpolation.

    The array is cast to ``uint8`` before resizing (so label values must fit
    in 0..255) and the resized result is returned as ``int64``.
    Nearest-neighbour sampling is used so that no new, interpolated label
    values are introduced.
    """
    labels_u8 = arr.astype(np.uint8)
    target_size = (R, R)
    resized = cv2.resize(labels_u8, target_size, interpolation=cv2.INTER_NEAREST)
    return resized.astype(np.int64)
def main():
    """Re-score a saved checkpoint at a common evaluation resolution.

    Parses the command line, builds the ``test`` split of the requested
    dataset, loads ``best.pth`` from
    ``<out_root>/<exp_name>/<dataset>_<protocol>/<arch>/seed<seed>``, runs the
    model on every test image at the model's input size, resizes both the
    prediction and the ground-truth mask read from disk to ``--eval_size``
    (nearest-neighbour), and writes aggregated plus per-image metrics to
    ``metrics.json`` in that same directory, overwriting any existing file of
    that name.

    Raises:
        SystemExit: if ``--eval_size`` is not positive, the checkpoint file
            does not exist, or a ground-truth mask cannot be read.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--data_root", required=True)
    ap.add_argument("--dataset", required=True)
    ap.add_argument("--protocol", required=True)
    ap.add_argument("--arch", required=True)
    ap.add_argument("--encoder", default="resnet50")
    ap.add_argument("--seed", type=int, required=True)
    ap.add_argument("--eval_size", type=int, required=True, help="common resolution R")
    ap.add_argument("--exp_name", default="baselines")
    ap.add_argument("--out_root", default="results")
    ap.add_argument("--normalize", default="auto")
    args = ap.parse_args()

    R = args.eval_size
    if R <= 0:
        # Fail at argument parsing instead of deep inside cv2.resize.
        ap.error("--eval_size must be a positive integer")
    model_res = required_img_size(args.arch) or R  # SwinUNet 224 / TransUNet 256 / conv -> R

    # The dataset is created without a transform first because building the
    # transform needs the dataset's channel count.
    ds = UnifiedSegDataset(args.data_root, args.dataset, args.protocol, "test", transform=None)
    ds.transform = build_transform(model_res, ds.in_channels, train=False,
                                   aug="none", normalize=args.normalize)
    num_classes = ds.num_classes

    out_dir = os.path.join(args.out_root, args.exp_name,
                           f"{args.dataset}_{args.protocol}", args.arch, f"seed{args.seed}")
    ckpt_path = os.path.join(out_dir, "best.pth")
    if not os.path.isfile(ckpt_path):
        raise SystemExit(f"checkpoint not found: {ckpt_path}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model(args.arch, in_channels=ds.in_channels, num_classes=num_classes,
                        img_size=model_res, encoder=args.encoder,
                        encoder_weights="none", pretrained_ckpt="")
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point this script at checkpoints from a trusted source.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Accept either a {"model": state_dict, ...} wrapper or a bare state dict.
    model.load_state_dict(ckpt.get("model", ckpt))
    model = model.to(device).eval()

    records = []
    with torch.no_grad():
        for idx in range(len(ds)):
            item = ds[idx]
            img = item["image"].unsqueeze(0).to(device)  # 1,C,model_res,model_res
            pred = model(img).argmax(1)[0].cpu().numpy()  # model_res x model_res
            gt_path = ds.pairs[idx][1]
            gt = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)  # native H x W, values 0..C-1
            if gt is None:
                # cv2.imread signals failure by returning None instead of raising.
                raise SystemExit(f"ground-truth mask could not be read: {gt_path}")
            records.append(per_image_metrics(_to_R(pred, R), _to_R(gt, R), num_classes,
                                             include_background=False, compute_hd95=True))

    agg = aggregate(records)
    out = {"dataset": args.dataset, "protocol": args.protocol, "arch": args.arch,
           "seed": args.seed, "num_classes": num_classes,
           "eval_size": R, "metrics": agg, "per_image": records}
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2)
    print(f"[eval_at_res] {args.dataset}/{args.protocol} {args.arch} seed{args.seed} @R={R}: "
          f"n={len(records)} dice={agg['dice_mean']:.4f} hd95={agg['hd95_mean']:.2f} -> {out_dir}")
# Run the re-scoring only when executed as a script, not when imported.
if __name__ == "__main__":
    main()