MaybeRichard's picture
Upload folder using huggingface_hub
b8fae22 verified
Raw
History Blame Contribute Delete
2.89 kB
"""Test-time evaluation: load the best checkpoint, run on the test split, compute
per-image Dice / IoU / HD95 (plus ASSD, sensitivity, specificity and precision)
and report mean +- SD, optionally save overlay visualizations, and always write a
metrics.json. Runs single-process (rank 0) for deterministic reporting.
"""
from __future__ import annotations
import os
import json
import numpy as np
import torch
from ..data.loaders import build_dataset, build_loader
from ..metrics.metrics import per_image_metrics, aggregate
from ..visualize.overlay import save_overlay
def _load_checkpoint(model, ckpt_path: str) -> None:
    """Load weights from ``ckpt_path`` into ``model`` in place, if the file exists.

    Accepts either a bare state dict or a dict that wraps it under ``"model"``.
    A missing file only prints a warning, so the model keeps its current weights.
    """
    if not os.path.isfile(ckpt_path):
        print(f"[eval][warn] checkpoint not found: {ckpt_path} (evaluating current weights)")
        return
    # NOTE(review): weights_only=False unpickles arbitrary Python objects; only
    # point this at checkpoints produced by a trusted training run.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state = ckpt.get("model", ckpt) if isinstance(ckpt, dict) else ckpt
    model.load_state_dict(state)
    print(f"[eval] loaded {ckpt_path}")


def _json_default(obj):
    """``json.dump`` fallback for NumPy / torch values the stdlib encoder rejects.

    Converts NumPy scalars (e.g. ``np.float32``, ``np.int64``) to Python scalars,
    and NumPy arrays / torch tensors to (nested) lists or scalars. Any other type
    still raises ``TypeError`` so unexpected objects are not silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _fmt(agg: dict, key: str, spec: str) -> str:
    """Format ``agg[key]`` with format ``spec``; return ``"n/a"`` if the key is absent."""
    value = agg.get(key)
    return "n/a" if value is None else format(value, spec)


@torch.no_grad()
def evaluate(cfg, model, device, ckpt_path: str = "") -> dict:
    """Evaluate ``model`` on the test split and write ``<out_dir>/metrics.json``.

    Args:
        cfg: run configuration. Reads ``out_dir()``, ``amp``, ``visualize``,
            ``vis_max``, ``include_background``, ``compute_hd95``, ``dataset``,
            ``protocol``, ``arch`` and ``seed``.
        model: network to evaluate; moved to ``device`` and switched to eval mode.
        device: inference device (anything ``torch.device`` accepts).
        ckpt_path: checkpoint to load; defaults to ``<out_dir>/best.pth``. If the
            file does not exist, the model's current weights are evaluated.

    Returns:
        Dict with the run identifiers, ``num_classes``, the aggregated
        ``metrics`` and the ``per_image`` records (each dict record is tagged
        with its image ``name``). The same dict is written to metrics.json.
    """
    out_dir = cfg.out_dir()
    ds = build_dataset(cfg, "test")
    num_classes = ds.num_classes
    loader = build_loader(cfg, "test", ds)  # single-process: no DistributedSampler

    _load_checkpoint(model, ckpt_path or os.path.join(out_dir, "best.pth"))
    model = model.to(device).eval()

    # CUDA autocast never affects CPU ops, so only enable it when inference
    # actually runs on CUDA (also avoids the "CUDA is not available" warning).
    use_amp = cfg.amp in ("bf16", "fp16") and torch.device(device).type == "cuda"
    amp_dtype = torch.bfloat16 if cfg.amp == "bf16" else torch.float16

    vis_dir = os.path.join(out_dir, "vis")
    if cfg.visualize:
        os.makedirs(vis_dir, exist_ok=True)

    records = []
    image_names = []
    saved = 0
    for batch in loader:
        img = batch["image"].to(device, non_blocking=True)
        msk = batch["mask"].numpy()
        names = batch["name"]
        with torch.autocast("cuda", dtype=amp_dtype, enabled=use_amp):
            logits = model(img)
        # argmax over dim 1 -- presumably the class channel of (N, C, ...) logits.
        pred = logits.argmax(1).cpu().numpy()
        for i in range(pred.shape[0]):
            records.append(per_image_metrics(
                pred[i], msk[i], num_classes,
                include_background=cfg.include_background,
                compute_hd95=cfg.compute_hd95))
            image_names.append(names[i])
            if cfg.visualize and saved < cfg.vis_max:
                save_overlay(img[i].cpu(), msk[i], pred[i], num_classes,
                             os.path.join(vis_dir, f"{names[i]}.png"))
                saved += 1

    agg = aggregate(records)
    # Tag each per-image record with its name so metrics.json is traceable;
    # `records` itself stays untouched for `aggregate`.
    per_image = [
        {"name": name, **rec} if isinstance(rec, dict) else rec
        for name, rec in zip(image_names, records)
    ]
    out = {
        "dataset": cfg.dataset, "protocol": cfg.protocol, "arch": cfg.arch,
        "seed": cfg.seed, "num_classes": num_classes,
        "metrics": agg,
        "per_image": per_image,
    }

    # out_dir may not exist yet (e.g. visualization disabled, fresh run dir).
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "metrics.json")
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(out, f, indent=2, default=_json_default)

    # Optional metrics (e.g. HD95 when compute_hd95 is off) print as "n/a"
    # instead of raising KeyError after all the work is done.
    print(f"[eval] dice={_fmt(agg, 'dice_mean', '.4f')}+-{_fmt(agg, 'dice_std', '.4f')} "
          f"iou={_fmt(agg, 'iou_mean', '.4f')} hd95={_fmt(agg, 'hd95_mean', '.3f')} "
          f"assd={_fmt(agg, 'assd_mean', '.3f')} "
          f"sens={_fmt(agg, 'sensitivity_mean', '.4f')} "
          f"spec={_fmt(agg, 'specificity_mean', '.4f')} "
          f"prec={_fmt(agg, 'precision_mean', '.4f')} -> {out_path}")
    return out