"""Test-time evaluation: load best checkpoint, run on the test split, compute Dice / IoU / HD95 (per image -> mean +- SD), optionally save overlay visualizations and a metrics.json. Runs single-process (rank 0) for deterministic reporting. """ from __future__ import annotations import os import json import numpy as np import torch from ..data.loaders import build_dataset, build_loader from ..metrics.metrics import per_image_metrics, aggregate from ..visualize.overlay import save_overlay @torch.no_grad() def evaluate(cfg, model, device, ckpt_path: str = "") -> dict: ds = build_dataset(cfg, "test") num_classes = ds.num_classes loader = build_loader(cfg, "test", ds) # single-process: no DistributedSampler ckpt_path = ckpt_path or os.path.join(cfg.out_dir(), "best.pth") if os.path.isfile(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False) state = ckpt.get("model", ckpt) model.load_state_dict(state) print(f"[eval] loaded {ckpt_path}") else: print(f"[eval][warn] checkpoint not found: {ckpt_path} (evaluating current weights)") model = model.to(device).eval() use_amp = cfg.amp in ("bf16", "fp16") amp_dtype = torch.bfloat16 if cfg.amp == "bf16" else torch.float16 records = [] vis_dir = os.path.join(cfg.out_dir(), "vis") if cfg.visualize: os.makedirs(vis_dir, exist_ok=True) saved = 0 for batch in loader: img = batch["image"].to(device, non_blocking=True) msk = batch["mask"].numpy() names = batch["name"] with torch.autocast("cuda", dtype=amp_dtype, enabled=use_amp): logits = model(img) pred = logits.argmax(1).cpu().numpy() for i in range(pred.shape[0]): records.append(per_image_metrics( pred[i], msk[i], num_classes, include_background=cfg.include_background, compute_hd95=cfg.compute_hd95)) if cfg.visualize and saved < cfg.vis_max: save_overlay(img[i].cpu(), msk[i], pred[i], num_classes, os.path.join(vis_dir, f"{names[i]}.png")) saved += 1 agg = aggregate(records) out = { "dataset": cfg.dataset, "protocol": cfg.protocol, "arch": cfg.arch, "seed": cfg.seed, "num_classes": num_classes, "metrics": agg, "per_image": records, } out_path = os.path.join(cfg.out_dir(), "metrics.json") with open(out_path, "w") as f: json.dump(out, f, indent=2) print(f"[eval] dice={agg['dice_mean']:.4f}+-{agg['dice_std']:.4f} " f"iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.3f} assd={agg['assd_mean']:.3f} " f"sens={agg['sensitivity_mean']:.4f} spec={agg['specificity_mean']:.4f} " f"prec={agg['precision_mean']:.4f} -> {out_path}") return out