RuthvikBandari commited on
Commit
a5fa872
·
verified ·
1 Parent(s): a30cdae

Upload scripts/evaluate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/evaluate.py +178 -0
scripts/evaluate.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiaFoot.AI v2 — Evaluation Entry Point.
2
+
3
+ Phase 4: Evaluate trained models on test set.
4
+
5
+ Usage:
6
+ # Evaluate classifier
7
+ python scripts/evaluate.py --task classify \
8
+
9
+ # Evaluate segmentation
10
+ python scripts/evaluate.py --task segment \
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import argparse
16
+ import json
17
+ import logging
18
+ import sys
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
25
+
26
+ from src.data.augmentation import get_val_transforms
27
+ from src.data.torch_dataset import DFUDataset
28
+ from src.evaluation.classification_metrics import (
29
+ compute_classification_metrics,
30
+ print_classification_report,
31
+ )
32
+ from src.evaluation.metrics import (
33
+ aggregate_metrics,
34
+ compute_segmentation_metrics,
35
+ print_segmentation_report,
36
+ )
37
+ from src.models.classifier import TriageClassifier
38
+ from src.models.unetpp import build_unetpp
39
+
40
+
41
def evaluate_classifier(
    checkpoint_path: str,
    splits_dir: str,
    device: str,
    output_path: str | Path = Path("results/classification_metrics.json"),
) -> None:
    """Evaluate the triage classifier on the test split.

    Loads the checkpoint, runs inference over ``<splits_dir>/test.csv``,
    prints a classification report, and writes the numeric metrics to JSON.

    Args:
        checkpoint_path: Path to a checkpoint containing ``model_state_dict``.
        splits_dir: Directory holding the split CSVs (expects ``test.csv``).
        device: Torch device string to run inference on (e.g. ``"cuda"``).
        output_path: Where to write the metrics JSON. Defaults to the
            previous hard-coded location, so existing callers are unaffected.
    """
    logger = logging.getLogger("eval_classifier")

    # Rebuild the architecture untrained; the checkpoint supplies all weights.
    model = TriageClassifier(backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False)
    ckpt = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(ckpt["model_state_dict"])
    model = model.to(device)
    model.eval()

    test_ds = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
    )
    test_loader = torch.utils.data.DataLoader(
        test_ds, batch_size=32, shuffle=False, num_workers=4
    )

    all_labels = []
    all_preds = []
    all_probs = []

    # inference_mode is a stricter, slightly faster no_grad for pure evaluation.
    with torch.inference_mode():
        for batch in test_loader:
            images = batch["image"].to(device)
            labels = batch["label"]
            logits = model(images)
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)

            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    y_true = np.array(all_labels)
    y_pred = np.array(all_preds)
    y_prob = np.array(all_probs)

    metrics = compute_classification_metrics(y_true, y_pred, y_prob)
    print_classification_report(metrics)

    # Save results — drop the human-readable "report" entry so the JSON
    # stays machine-parseable.
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    save_metrics = {k: v for k, v in metrics.items() if k != "report"}
    with open(output_path, "w") as f:
        json.dump(save_metrics, f, indent=2)
    logger.info("Results saved to %s", output_path)
87
+
88
+
89
def evaluate_segmentation(checkpoint_path: str, splits_dir: str, device: str) -> None:
    """Evaluate segmentation model on test set."""
    logger = logging.getLogger("eval_segmentation")

    # Restore the trained U-Net++; encoder starts unweighted since the
    # checkpoint provides every parameter.
    net = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    net.load_state_dict(state["model_state_dict"])
    net = net.to(device)
    net.eval()

    dataset = DFUDataset(
        split_csv=Path(splits_dir) / "test.csv",
        transform=get_val_transforms(),
        return_metadata=True,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=False, num_workers=4)

    per_image = []
    # NOTE(review): label 2 appears to mean DFU and label 1 non-DFU; other
    # labels are excluded from the per-class breakdown — confirm with dataset.
    by_class = {1: [], 2: []}

    with torch.no_grad():
        for batch in loader:
            imgs = batch["image"].to(device)
            gt_masks = batch["mask"].numpy()
            cls_labels = batch["label"].numpy()

            out = net(imgs)
            # Binarize sigmoid probabilities at 0.5 and drop the channel axis.
            bin_preds = (torch.sigmoid(out) > 0.5).squeeze(1).cpu().numpy().astype(np.uint8)

            for idx, (pred_mask, gt_mask) in enumerate(zip(bin_preds, gt_masks)):
                scores = compute_segmentation_metrics(pred_mask, gt_mask)
                per_image.append(scores)

                lbl = int(cls_labels[idx])
                if lbl in by_class:
                    by_class[lbl].append(scores)

    # Overall results
    summary = aggregate_metrics(per_image)
    print_segmentation_report(summary)

    # Per-class results
    if by_class[2]:
        print("DFU images only:")
        print_segmentation_report(aggregate_metrics(by_class[2]))

    if by_class[1]:
        print("Non-DFU images only:")
        print_segmentation_report(aggregate_metrics(by_class[1]))

    # Save results
    out_path = Path("results/segmentation_metrics.json")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(summary, f, indent=2, default=str)
    logger.info("Results saved to %s", out_path)
151
+
152
+
153
def main() -> None:
    """Parse CLI arguments, configure logging, and dispatch evaluation."""
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 Evaluation")
    parser.add_argument("--task", type=str, required=True, choices=["classify", "segment"])
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )

    # Only downgrade CUDA devices when CUDA is missing. The old check
    # (`args.device if torch.cuda.is_available() else "cpu"`) forced every
    # non-CUDA device (e.g. "mps") onto the CPU whenever CUDA was absent.
    dev = args.device
    if dev.startswith("cuda") and not torch.cuda.is_available():
        logging.getLogger("eval").warning("CUDA requested but unavailable; using CPU.")
        dev = "cpu"

    if args.task == "classify":
        evaluate_classifier(args.checkpoint, args.splits_dir, dev)
    elif args.task == "segment":
        evaluate_segmentation(args.checkpoint, args.splits_dir, dev)


if __name__ == "__main__":
    main()