"""DiaFoot.AI v2 — Evaluation Entry Point.

Phase 4: Evaluate trained models on the held-out test set.

Usage:
    # Evaluate classifier
    python scripts/evaluate.py --task classify \
        --checkpoint path/to/classifier_checkpoint.pt

    # Evaluate segmentation
    python scripts/evaluate.py --task segment \
        --checkpoint path/to/segmentation_checkpoint.pt
"""
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import logging |
| | import sys |
| | from pathlib import Path |
| |
|
| | import numpy as np |
| | import torch |
| |
|
| | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| |
|
| | from src.data.augmentation import get_val_transforms |
| | from src.data.torch_dataset import DFUDataset |
| | from src.evaluation.classification_metrics import ( |
| | compute_classification_metrics, |
| | print_classification_report, |
| | ) |
| | from src.evaluation.metrics import ( |
| | aggregate_metrics, |
| | compute_segmentation_metrics, |
| | print_segmentation_report, |
| | ) |
| | from src.models.classifier import TriageClassifier |
| | from src.models.unetpp import build_unetpp |
| |
|
| |
|
def evaluate_classifier(checkpoint_path: str, splits_dir: str, device: str) -> None:
    """Evaluate the triage classifier on the test split.

    Restores the model from *checkpoint_path*, runs inference over
    ``<splits_dir>/test.csv``, prints a classification report, and writes
    the metrics (minus the human-readable ``report`` entry) to
    ``results/classification_metrics.json``.

    Args:
        checkpoint_path: Path to a checkpoint containing ``model_state_dict``.
        splits_dir: Directory holding the split CSV files.
        device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
    """
    logger = logging.getLogger("eval_classifier")

    # Rebuild the architecture without pretrained weights; the checkpoint
    # supplies every parameter.
    model = TriageClassifier(backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
    )
    test_loader = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=4)

    all_labels = []
    all_preds = []
    all_probs = []

    with torch.no_grad():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"]  # stays on CPU; never moved to `device`
            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)

            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(all_labels)
    y_pred = np.array(all_preds)
    y_prob = np.array(all_probs)

    metrics = compute_classification_metrics(y_true, y_pred, y_prob)
    print_classification_report(metrics)

    # Persist everything except the human-readable "report" entry.
    output_path = Path("results/classification_metrics.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_metrics = {k: v for k, v in metrics.items() if k != "report"}
    with open(output_path, "w") as f:
        # BUG FIX: the stock json encoder cannot serialize numpy scalars or
        # arrays, and the metrics are computed from numpy inputs. Use
        # default=str as a fallback, mirroring evaluate_segmentation.
        json.dump(save_metrics, f, indent=2, default=str)
    logger.info("Results saved to %s", output_path)
| |
|
| |
|
def evaluate_segmentation(checkpoint_path: str, splits_dir: str, device: str) -> None:
    """Evaluate segmentation model on test set."""
    logger = logging.getLogger("eval_segmentation")

    # Reconstruct the network and restore trained weights from the checkpoint.
    model = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state["model_state_dict"])
    model = model.to(device)
    model.eval()

    dataset = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
        return_metadata=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=4)

    per_image = []      # metrics for every test image
    dfu_only = []       # subset with label == 2
    non_dfu_only = []   # subset with label == 1

    with torch.no_grad():
        for batch in loader:
            images = batch["image"].to(device)
            gt_masks = batch["mask"].numpy()
            labels = batch["label"].numpy()

            logits = model(images)
            # Threshold sigmoid probabilities at 0.5 to obtain binary masks.
            binary_preds = (torch.sigmoid(logits) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)

            for pred_mask, gt_mask, label in zip(binary_preds, gt_masks, labels):
                sample = compute_segmentation_metrics(pred_mask, gt_mask)
                per_image.append(sample)
                if label == 2:
                    dfu_only.append(sample)
                elif label == 1:
                    non_dfu_only.append(sample)

    summary = aggregate_metrics(per_image)
    print_segmentation_report(summary)

    # Per-subgroup breakdowns, printed only when the subgroup is non-empty.
    if dfu_only:
        print("DFU images only:")
        print_segmentation_report(aggregate_metrics(dfu_only))

    if non_dfu_only:
        print("Non-DFU images only:")
        print_segmentation_report(aggregate_metrics(non_dfu_only))

    output_path = Path("results/segmentation_metrics.json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w") as f:
        # default=str guards against numpy values in the summary dict.
        json.dump(summary, f, indent=2, default=str)
    logger.info("Results saved to %s", output_path)
| |
|
| |
|
def main() -> None:
    """Parse CLI arguments, configure logging, and dispatch evaluation.

    Raises:
        SystemExit: Via argparse, when required arguments are missing
            or invalid.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 Evaluation")
    parser.add_argument("--task", type=str, required=True, choices=["classify", "segment"],
                        help="Which model to evaluate.")
    parser.add_argument("--checkpoint", type=str, required=True,
                        help="Path to the trained model checkpoint.")
    parser.add_argument("--splits-dir", type=str, default="data/splits",
                        help="Directory containing the split CSV files.")
    parser.add_argument("--device", type=str, default="cuda",
                        help="Torch device for inference.")
    parser.add_argument("--verbose", action="store_true",
                        help="Enable DEBUG-level logging.")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    # Fall back to CPU when CUDA is unavailable — but say so, instead of
    # silently ignoring the requested device.
    if torch.cuda.is_available():
        dev = args.device
    else:
        dev = "cpu"
        if args.device != "cpu":
            logging.getLogger("main").warning(
                "CUDA unavailable; falling back to CPU (requested %s)", args.device
            )

    if args.task == "classify":
        evaluate_classifier(args.checkpoint, args.splits_dir, dev)
    elif args.task == "segment":
        evaluate_segmentation(args.checkpoint, args.splits_dir, dev)
| |
|
| |
|
# Script entry point: run evaluation when executed directly.
if __name__ == "__main__":
    main()
| |
|