| """Test-time evaluation: load the best checkpoint, run on the test split, and compute |
| Dice / IoU / HD95 (per image -> mean +- SD). Overlay visualizations are saved optionally; |
| metrics.json is always written. Runs single-process (rank 0) for deterministic reporting. |
| """ |
| from __future__ import annotations |
|
|
| import os |
| import json |
|
|
| import numpy as np |
| import torch |
|
|
| from ..data.loaders import build_dataset, build_loader |
| from ..metrics.metrics import per_image_metrics, aggregate |
| from ..visualize.overlay import save_overlay |
|
|
|
|
def _json_default(obj):
    """Convert NumPy scalars/arrays to plain Python so ``json.dump`` can encode them."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


@torch.no_grad()
def evaluate(cfg, model, device, ckpt_path: str = "") -> dict:
    """Evaluate ``model`` on the test split and write ``metrics.json``.

    Loads weights from ``ckpt_path`` (or ``<out_dir>/best.pth`` when it is
    empty), runs inference batch by batch, computes per-image metrics, and
    aggregates them. If the checkpoint file does not exist, a warning is
    printed and the model's current weights are evaluated instead.

    Args:
        cfg: Run configuration. This function reads ``out_dir()``, ``amp``,
            ``visualize``, ``vis_max``, ``include_background``,
            ``compute_hd95``, ``dataset``, ``protocol``, ``arch`` and ``seed``.
        model: Network to evaluate; moved to ``device`` and put in eval mode.
        device: Target device for the model and the input images.
        ckpt_path: Optional explicit checkpoint path.

    Returns:
        A dict with the run identifiers, ``num_classes``, the aggregated
        ``metrics`` and the list of ``per_image`` records (the same object
        that is serialized to ``metrics.json``).
    """
    ds = build_dataset(cfg, "test")
    num_classes = ds.num_classes
    loader = build_loader(cfg, "test", ds)

    out_dir = cfg.out_dir()
    ckpt_path = ckpt_path or os.path.join(out_dir, "best.pth")
    if os.path.isfile(ckpt_path):
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
        # Accept both {"model": state_dict, ...} and a bare state_dict.
        state = ckpt.get("model", ckpt)
        model.load_state_dict(state)
        print(f"[eval] loaded {ckpt_path}")
    else:
        print(f"[eval][warn] checkpoint not found: {ckpt_path} (evaluating current weights)")

    model = model.to(device).eval()
    use_amp = cfg.amp in ("bf16", "fp16")
    amp_dtype = torch.bfloat16 if cfg.amp == "bf16" else torch.float16

    records = []
    vis_dir = os.path.join(out_dir, "vis")
    if cfg.visualize:
        os.makedirs(vis_dir, exist_ok=True)
    saved = 0

    for batch in loader:
        img = batch["image"].to(device, non_blocking=True)
        msk = batch["mask"].numpy()
        names = batch["name"]
        # CUDA autocast only affects CUDA ops, so enable it only when the
        # inputs really live on a CUDA device. This keeps results identical
        # and avoids the "CUDA is not available" warning on CPU-only hosts.
        with torch.autocast("cuda", dtype=amp_dtype, enabled=use_amp and img.is_cuda):
            logits = model(img)
        # Class scores are taken along dim 1 of the model output.
        pred = logits.argmax(1).cpu().numpy()
        for i in range(pred.shape[0]):
            records.append(per_image_metrics(
                pred[i], msk[i], num_classes,
                include_background=cfg.include_background,
                compute_hd95=cfg.compute_hd95))
            if cfg.visualize and saved < cfg.vis_max:
                save_overlay(img[i].cpu(), msk[i], pred[i], num_classes,
                             os.path.join(vis_dir, f"{names[i]}.png"))
                saved += 1

    agg = aggregate(records)
    out = {
        "dataset": cfg.dataset, "protocol": cfg.protocol, "arch": cfg.arch,
        "seed": cfg.seed, "num_classes": num_classes,
        "metrics": agg,
        "per_image": records,
    }
    # Make sure the output directory exists so a finished evaluation is not
    # lost to a FileNotFoundError (it is only created above when visualizing).
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, "metrics.json")
    with open(out_path, "w", encoding="utf-8") as f:
        # default= handles NumPy scalars/arrays the metric helpers may return.
        json.dump(out, f, indent=2, default=_json_default)
    # NOTE(review): assumes aggregate() always provides every key below
    # (including hd95/assd when cfg.compute_hd95 is off) — confirm.
    print(f"[eval] dice={agg['dice_mean']:.4f}+-{agg['dice_std']:.4f} "
          f"iou={agg['iou_mean']:.4f} hd95={agg['hd95_mean']:.3f} assd={agg['assd_mean']:.3f} "
          f"sens={agg['sensitivity_mean']:.4f} spec={agg['specificity_mean']:.4f} "
          f"prec={agg['precision_mean']:.4f} -> {out_path}")
    return out
|
|