RuthvikBandari commited on
Commit
a30cdae
·
verified ·
1 Parent(s): 74a6c74

Upload scripts/predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/predict.py +147 -0
scripts/predict.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiaFoot.AI v2 — Inference on New Images.
2
+
3
+ Run the trained pipeline on any foot image.
4
+
5
+ Usage:
6
+ python scripts/predict.py --image path/to/foot_image.jpg
7
+ python scripts/predict.py --image path/to/image.jpg --save-mask output_mask.png
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import logging
14
+ import sys
15
+ from pathlib import Path
16
+
17
+ import cv2
18
+ import numpy as np
19
+ import torch
20
+
21
# Make the repository root importable so the `src.*` packages resolve when
# this script is executed directly (e.g. `python scripts/predict.py`).
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.data.augmentation import get_val_transforms
from src.models.classifier import TriageClassifier
from src.models.unetpp import build_unetpp

# Human-readable labels for the classifier's three output indices.
CLASS_NAMES = {0: "Healthy", 1: "Non-DFU Wound", 2: "DFU (Diabetic Foot Ulcer)"}
+
29
+
30
def load_and_preprocess(image_path: str) -> tuple[np.ndarray, torch.Tensor]:
    """Read an image from disk and build the model-ready input tensor.

    Args:
        image_path: Path to the foot image file.

    Returns:
        A ``(original, tensor)`` pair: the untouched BGR image exactly as
        OpenCV loaded it (kept for display/overlay), and a batched float
        tensor produced by the validation transform, shape ``(1, C, H, W)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")

    # Keep an unmodified copy of the as-loaded (BGR) image for the caller.
    original = bgr.copy()

    # The model pipeline expects RGB input at a fixed 512x512 resolution.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, (512, 512))

    # Apply the same deterministic transform used at validation time and
    # add the leading batch dimension.
    batch = get_val_transforms()(image=resized)["image"].unsqueeze(0)

    return original, batch
49
+
50
+
51
def main() -> None:
    """Run triage classification and wound segmentation on a single image.

    Parses CLI arguments, classifies the image as Healthy / Non-DFU Wound /
    DFU, and — when a wound class is predicted (or the classifier checkpoint
    is missing) — segments the wound and prints pixel/area statistics.
    Optionally saves the binary segmentation mask to ``--save-mask``.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 — Predict")
    parser.add_argument("--image", type=str, required=True, help="Path to foot image")
    parser.add_argument(
        "--classifier-checkpoint",
        type=str,
        default="checkpoints/classifier/best_epoch004_1.0000.pt",
    )
    parser.add_argument(
        "--segmenter-checkpoint",
        type=str,
        default="checkpoints/segmentation/best_epoch019_0.6781.pt",
    )
    parser.add_argument("--save-mask", type=str, default=None, help="Save segmentation mask")
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Fall back to CPU when a non-CPU device is requested but CUDA is absent.
    device = torch.device(
        args.device if torch.cuda.is_available() or args.device == "cpu" else "cpu"
    )

    # Load and preprocess the input image once; reused by both models.
    original, tensor = load_and_preprocess(args.image)
    tensor = tensor.to(device)

    print(f"\n{'=' * 50}")
    print("DiaFoot.AI v2 — Inference")
    print(f"Image: {args.image}")
    print(f"{'=' * 50}")

    # Step 1: Classification
    classifier_path = Path(args.classifier_checkpoint)
    if classifier_path.exists():
        classifier = TriageClassifier(
            backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False
        )
        # weights_only=True restricts torch.load to tensor data (no pickled code).
        ckpt = torch.load(str(classifier_path), map_location="cpu", weights_only=True)
        classifier.load_state_dict(ckpt["model_state_dict"])
        classifier = classifier.to(device).eval()

        with torch.no_grad():
            logits = classifier(tensor)
            probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

        pred_class = int(probs.argmax())
        confidence = float(probs.max())

        print(f"\n Classification: {CLASS_NAMES[pred_class]}")
        print(f" Confidence: {confidence:.1%}")
        for i, name in CLASS_NAMES.items():
            print(f" {name}: {probs[i]:.1%}")
    else:
        print(f"\n Classifier checkpoint not found: {classifier_path}")
        pred_class = 2  # Assume DFU so segmentation still runs

    # Step 2: Segmentation (only when a wound class was predicted)
    segmenter_path = Path(args.segmenter_checkpoint)
    if pred_class in (1, 2) and segmenter_path.exists():
        segmenter = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
        ckpt = torch.load(str(segmenter_path), map_location="cpu", weights_only=True)
        segmenter.load_state_dict(ckpt["model_state_dict"])
        segmenter = segmenter.to(device).eval()

        with torch.no_grad():
            seg_logits = segmenter(tensor)
            seg_prob = torch.sigmoid(seg_logits).squeeze().cpu().numpy()
            seg_mask = (seg_prob > 0.5).astype(np.uint8)

        # Cast to a plain int so the thousands-separator format is clean.
        wound_pixels = int(seg_mask.sum())
        total_pixels = seg_mask.shape[0] * seg_mask.shape[1]
        coverage = wound_pixels / total_pixels * 100
        # NOTE(review): assumes a fixed 0.5 mm/pixel scale — confirm camera
        # calibration before trusting the absolute area figure.
        area_mm2 = wound_pixels * 0.5 * 0.5

        print("\n Segmentation:")
        print(f" Wound detected: {'Yes' if wound_pixels > 0 else 'No'}")
        print(f" Wound pixels: {wound_pixels:,}")
        print(f" Coverage: {coverage:.1f}%")
        print(f" Estimated area: {area_mm2:.1f} mm2")

        if args.save_mask:
            # FIX: upscale the binary mask with nearest-neighbour interpolation.
            # The previous default (bilinear) blurred mask edges into
            # intermediate grey values, so the saved mask was no longer binary.
            mask_resized = cv2.resize(
                seg_mask * 255,
                (original.shape[1], original.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
            cv2.imwrite(args.save_mask, mask_resized)
            print(f" Mask saved to: {args.save_mask}")
    elif pred_class == 0:
        print("\n Segmentation: Skipped (healthy foot detected)")
    else:
        print(f"\n Segmenter checkpoint not found: {segmenter_path}")

    print(f"\n{'=' * 50}\n")
144
+
145
+
146
# Run as a script while keeping the module importable without side effects.
if __name__ == "__main__":
    main()