# DiaFoot.AI-v2 / scripts/evaluate.py
# Uploaded by RuthvikBandari via huggingface_hub (commit a5fa872, verified)
"""DiaFoot.AI v2 — Evaluation Entry Point.
Phase 4: Evaluate trained models on test set.
Usage:
    # Evaluate classifier
    python scripts/evaluate.py --task classify --checkpoint <path/to/classifier.pt>

    # Evaluate segmentation
    python scripts/evaluate.py --task segment --checkpoint <path/to/segmentation.pt>
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.augmentation import get_val_transforms
from src.data.torch_dataset import DFUDataset
from src.evaluation.classification_metrics import (
compute_classification_metrics,
print_classification_report,
)
from src.evaluation.metrics import (
aggregate_metrics,
compute_segmentation_metrics,
print_segmentation_report,
)
from src.models.classifier import TriageClassifier
from src.models.unetpp import build_unetpp
def evaluate_classifier(
    checkpoint_path: str,
    splits_dir: str,
    device: str,
    output_path: str = "results/classification_metrics.json",
) -> None:
    """Evaluate the triage classifier on the test split.

    Args:
        checkpoint_path: Checkpoint file containing a "model_state_dict" entry.
        splits_dir: Directory containing the split CSVs; "test.csv" is used.
        device: Torch device string the model runs on (e.g. "cuda", "cpu").
        output_path: Where the JSON metrics summary is written (generalized
            from the previously hard-coded path; default preserves behavior).
    """
    logger = logging.getLogger("eval_classifier")

    # pretrained=False: all weights come from the checkpoint, not from timm.
    model = TriageClassifier(backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=32, shuffle=False, num_workers=4
    )

    all_labels = []
    all_preds = []
    all_probs = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"]  # kept on CPU; only converted to numpy below
            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)
            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(all_labels)
    y_pred = np.array(all_preds)
    y_prob = np.array(all_probs)
    metrics = compute_classification_metrics(y_true, y_pred, y_prob)
    print_classification_report(metrics)

    # Save everything except the human-readable "report" entry.
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    save_metrics = {k: v for k, v in metrics.items() if k != "report"}
    with open(out, "w") as f:
        # default=float: numpy scalars (np.float64 etc.) are not natively
        # JSON-serializable — without a default, json.dump raises TypeError.
        # Assumes remaining metric values are numeric scalars; TODO confirm
        # against compute_classification_metrics.
        json.dump(save_metrics, f, indent=2, default=float)
    logger.info("Results saved to %s", out)
def evaluate_segmentation(
    checkpoint_path: str,
    splits_dir: str,
    device: str,
    output_path: str = "results/segmentation_metrics.json",
) -> None:
    """Evaluate the segmentation model on the test split.

    Reports overall metrics plus per-class breakdowns for DFU (label 2) and
    non-DFU wound (label 1) images.

    Args:
        checkpoint_path: Checkpoint file containing a "model_state_dict" entry.
        splits_dir: Directory containing the split CSVs; "test.csv" is used.
        device: Torch device string the model runs on (e.g. "cuda", "cpu").
        output_path: Where the JSON metrics summary is written (generalized
            from the previously hard-coded path; default preserves behavior).
    """
    logger = logging.getLogger("eval_segmentation")

    # encoder_weights=None: all weights come from the checkpoint.
    model = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
        return_metadata=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=8, shuffle=False, num_workers=4
    )

    all_metrics = []
    dfu_metrics = []
    non_dfu_metrics = []
    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            masks = batch["mask"].numpy()
            labels = batch["label"].numpy()
            logits = model(images)
            # 0.5 probability threshold -> binary mask; squeeze channel dim.
            preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)
            for i in range(len(images)):
                m = compute_segmentation_metrics(preds[i], masks[i])
                all_metrics.append(m)
                # Label conventions: 2 = DFU, 1 = non-DFU wound — TODO confirm
                # against the dataset's label mapping.
                if labels[i] == 2:
                    dfu_metrics.append(m)
                elif labels[i] == 1:
                    non_dfu_metrics.append(m)

    # Overall results
    summary = aggregate_metrics(all_metrics)
    print_segmentation_report(summary)

    # Per-class results
    if dfu_metrics:
        print("DFU images only:")
        print_segmentation_report(aggregate_metrics(dfu_metrics))
    if non_dfu_metrics:
        print("Non-DFU images only:")
        print_segmentation_report(aggregate_metrics(non_dfu_metrics))

    # Save results
    out = Path(output_path)
    out.parent.mkdir(parents=True, exist_ok=True)
    with open(out, "w") as f:
        # default=float instead of default=str: numpy scalar metrics are
        # written as JSON numbers rather than quoted strings, so downstream
        # consumers can read them without re-parsing.
        json.dump(summary, f, indent=2, default=float)
    logger.info("Results saved to %s", out)
def main() -> None:
    """Parse CLI arguments, configure logging, and dispatch the evaluation task."""
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 Evaluation")
    parser.add_argument("--task", type=str, required=True, choices=["classify", "segment"])
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    # Fall back to CPU only when a CUDA device was requested but none is
    # available — and say so. The original expression overrode ANY requested
    # device (including an explicit "cpu" or "mps") to "cpu" whenever CUDA
    # was absent, silently.
    dev = args.device
    if dev.startswith("cuda") and not torch.cuda.is_available():
        logging.getLogger("main").warning(
            "Requested device %r but CUDA is unavailable; using CPU", dev
        )
        dev = "cpu"

    if args.task == "classify":
        evaluate_classifier(args.checkpoint, args.splits_dir, dev)
    elif args.task == "segment":
        evaluate_segmentation(args.checkpoint, args.splits_dir, dev)
# Script entry point: run evaluation when executed directly (not on import).
if __name__ == "__main__":
    main()