"""DiaFoot.AI v2 — Evaluation Entry Point.
Phase 4: Evaluate trained models on test set.
Usage:
    # Evaluate classifier
    python scripts/evaluate.py --task classify --checkpoint <CHECKPOINT_PATH>

    # Evaluate segmentation
    python scripts/evaluate.py --task segment --checkpoint <CHECKPOINT_PATH>
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.augmentation import get_val_transforms
from src.data.torch_dataset import DFUDataset
from src.evaluation.classification_metrics import (
compute_classification_metrics,
print_classification_report,
)
from src.evaluation.metrics import (
aggregate_metrics,
compute_segmentation_metrics,
print_segmentation_report,
)
from src.models.classifier import TriageClassifier
from src.models.unetpp import build_unetpp
def evaluate_classifier(checkpoint_path: str, splits_dir: str, device: str) -> None:
    """Score the triage classifier on the test split and persist the metrics.

    Builds a 3-class ``TriageClassifier``, restores its weights from the
    ``"model_state_dict"`` entry of the checkpoint, runs inference over
    ``<splits_dir>/test.csv`` and hands the collected labels, predictions and
    softmax probabilities to ``compute_classification_metrics``. The report is
    printed and every metric except the ``"report"`` entry is written to
    ``results/classification_metrics.json``.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        splits_dir: Directory that contains ``test.csv``.
        device: Torch device string the model and images are moved to.
    """
    log = logging.getLogger("eval_classifier")

    # Weights are deserialized on CPU first and moved to the target device after.
    net = TriageClassifier(backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state["model_state_dict"])
    net = net.to(device)
    net.eval()

    dataset = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=False, num_workers=4)

    targets, predictions, probabilities = [], [], []
    with torch.no_grad():
        for batch in loader:
            inputs, batch_targets = batch["image"].to(device), batch["label"]
            scores = net(inputs)
            batch_probs = torch.softmax(scores, dim=1)
            batch_preds = scores.argmax(dim=1)
            targets.extend(batch_targets.numpy())
            predictions.extend(batch_preds.cpu().numpy())
            probabilities.extend(batch_probs.cpu().numpy())

    metrics = compute_classification_metrics(
        np.array(targets),
        np.array(predictions),
        np.array(probabilities),
    )
    print_classification_report(metrics)

    # Persist everything except the "report" entry.
    destination = Path("results/classification_metrics.json")
    destination.parent.mkdir(parents=True, exist_ok=True)
    serializable = dict(metrics.items())
    serializable.pop("report", None)
    with open(destination, "w") as handle:
        json.dump(serializable, handle, indent=2)
    log.info("Results saved to %s", destination)
def evaluate_segmentation(checkpoint_path: str, splits_dir: str, device: str) -> None:
    """Score the segmentation model on the test split and persist the summary.

    Builds a single-class U-Net++ via ``build_unetpp``, restores its weights
    from the ``"model_state_dict"`` entry of the checkpoint and runs inference
    over ``<splits_dir>/test.csv``. Logits are passed through a sigmoid and
    thresholded at 0.5 to obtain binary masks; per-image metrics come from
    ``compute_segmentation_metrics``. An overall report is printed, followed
    by separate reports for images labelled 2 (DFU) and 1 (non-DFU) when any
    exist. Only the overall summary is written to
    ``results/segmentation_metrics.json``.

    Args:
        checkpoint_path: Path of the checkpoint file passed to ``torch.load``.
        splits_dir: Directory that contains ``test.csv``.
        device: Torch device string the model and images are moved to.
    """
    log = logging.getLogger("eval_segmentation")

    # Weights are deserialized on CPU first and moved to the target device after.
    net = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state["model_state_dict"])
    net = net.to(device)
    net.eval()

    dataset = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
        return_metadata=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=4)

    overall, dfu_only, non_dfu_only = [], [], []
    with torch.no_grad():
        for batch in loader:
            inputs = batch["image"].to(device)
            truth_masks = batch["mask"].numpy()
            class_ids = batch["label"].numpy()

            binary = torch.sigmoid(net(inputs)) > 0.5
            predicted_masks = binary.squeeze(1).cpu().numpy().astype(np.uint8)

            for predicted, truth, class_id in zip(predicted_masks, truth_masks, class_ids):
                scores = compute_segmentation_metrics(predicted, truth)
                overall.append(scores)
                # Label 2 feeds the DFU bucket, label 1 the non-DFU bucket;
                # any other label only counts towards the overall summary.
                if class_id == 2:
                    dfu_only.append(scores)
                elif class_id == 1:
                    non_dfu_only.append(scores)

    # Overall results
    summary = aggregate_metrics(overall)
    print_segmentation_report(summary)

    # Per-class results, skipped for buckets that received no images
    per_class = (
        ("DFU images only:", dfu_only),
        ("Non-DFU images only:", non_dfu_only),
    )
    for heading, subset in per_class:
        if not subset:
            continue
        print(heading)
        print_segmentation_report(aggregate_metrics(subset))

    # Save results (overall summary only)
    destination = Path("results/segmentation_metrics.json")
    destination.parent.mkdir(parents=True, exist_ok=True)
    with open(destination, "w") as handle:
        json.dump(summary, handle, indent=2, default=str)
    log.info("Results saved to %s", destination)
def main() -> None:
    """Parse command-line arguments and run the requested evaluation.

    Command-line options:
        --task: ``classify`` or ``segment`` (required).
        --checkpoint: Path to the model checkpoint (required).
        --splits-dir: Directory holding the split CSVs (default ``data/splits``).
        --device: Torch device string (default ``cuda``).
        --verbose: Enable DEBUG-level logging.

    A CUDA device request falls back to ``cpu`` when CUDA is not available.
    Any other explicitly requested device is passed through unchanged.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 Evaluation")
    parser.add_argument("--task", type=str, required=True, choices=["classify", "segment"])
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    # Only a CUDA request needs the availability check. Applying the fallback
    # to every device would silently discard an explicit non-CUDA choice
    # (e.g. "mps") on machines without CUDA.
    dev = args.device
    if dev.startswith("cuda") and not torch.cuda.is_available():
        logging.getLogger("evaluate").warning(
            "CUDA device '%s' requested but CUDA is unavailable; falling back to CPU", dev
        )
        dev = "cpu"

    if args.task == "classify":
        evaluate_classifier(args.checkpoint, args.splits_dir, dev)
    elif args.task == "segment":
        evaluate_segmentation(args.checkpoint, args.splits_dir, dev)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|