File size: 5,091 Bytes
a30cdae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""DiaFoot.AI v2 — Inference on New Images.

Run the trained pipeline on any foot image.

Usage:
    python scripts/predict.py --image path/to/foot_image.jpg
    python scripts/predict.py --image path/to/image.jpg --save-mask output_mask.png
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import cv2
import numpy as np
import torch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from src.data.augmentation import get_val_transforms
from src.models.classifier import TriageClassifier
from src.models.unetpp import build_unetpp

# Display label for each classifier output index (looked up with the argmax
# of the classifier probabilities in ``main``).
CLASS_NAMES = {
    0: "Healthy",
    1: "Non-DFU Wound",
    2: "DFU (Diabetic Foot Ulcer)",
}


def load_and_preprocess(image_path: str) -> tuple[np.ndarray, torch.Tensor]:
    """Read an image from disk and build the batched model input.

    Args:
        image_path: Path of the image file, passed to ``cv2.imread``.

    Returns:
        A ``(original, tensor)`` pair. ``original`` is a copy of the image
        exactly as OpenCV decoded it (native resolution, BGR channel order).
        ``tensor`` is the output of the validation transform applied to the
        RGB image resized to 512x512, with a leading batch dimension added.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``image_path``.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        msg = f"Cannot read image: {image_path}"
        raise FileNotFoundError(msg)

    # Model branch: OpenCV decodes as BGR, so convert to RGB before resizing
    # to the fixed 512x512 size fed to the validation transform.
    model_input = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (512, 512))
    batch = get_val_transforms()(image=model_input)["image"].unsqueeze(0)

    # Display branch: return an independent copy of the untouched decode.
    return bgr.copy(), batch


def main() -> None:
    """Run the triage classifier and wound segmenter on a single image.

    Command-line options:
        --image: Path to the foot image (required).
        --classifier-checkpoint: Path to the classifier checkpoint.
        --segmenter-checkpoint: Path to the segmentation checkpoint.
        --save-mask: Optional path where the binary wound mask is written,
            resized to the original image resolution.
        --device: Torch device string (default ``cpu``). Anything other than
            ``cpu`` is honoured only when CUDA is available; otherwise the
            script warns and runs on CPU.

    The report is printed to stdout. If the classifier checkpoint is missing,
    the image is treated as class 2 (DFU) so that segmentation still runs.
    Segmentation is skipped when the image is classified as healthy.

    Raises:
        FileNotFoundError: If the input image cannot be read.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 — Predict")
    parser.add_argument("--image", type=str, required=True, help="Path to foot image")
    parser.add_argument(
        "--classifier-checkpoint",
        type=str,
        default="checkpoints/classifier/best_epoch004_1.0000.pt",
    )
    parser.add_argument(
        "--segmenter-checkpoint",
        type=str,
        default="checkpoints/segmentation/best_epoch019_0.6781.pt",
    )
    parser.add_argument("--save-mask", type=str, default=None, help="Save segmentation mask")
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    # Same selection rule as before (requested device is used only for "cpu"
    # or when CUDA is available), but the fallback is no longer silent.
    use_requested_device = torch.cuda.is_available() or args.device == "cpu"
    if not use_requested_device:
        logging.warning(
            "Device '%s' requested but CUDA is not available; falling back to CPU.",
            args.device,
        )
    device = torch.device(args.device if use_requested_device else "cpu")

    # Load image
    original, tensor = load_and_preprocess(args.image)
    tensor = tensor.to(device)

    print(f"\n{'=' * 50}")
    print("DiaFoot.AI v2 — Inference")
    print(f"Image: {args.image}")
    print(f"{'=' * 50}")

    # Step 1: Classification
    classifier_path = Path(args.classifier_checkpoint)
    if classifier_path.exists():
        classifier = TriageClassifier(
            backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False
        )
        # Load on CPU first, then move the whole model to the target device.
        classifier_ckpt = torch.load(
            str(classifier_path), map_location="cpu", weights_only=True
        )
        classifier.load_state_dict(classifier_ckpt["model_state_dict"])
        classifier = classifier.to(device).eval()

        with torch.no_grad():
            logits = classifier(tensor)
            probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

        pred_class = int(probs.argmax())
        confidence = float(probs.max())

        print(f"\n  Classification: {CLASS_NAMES[pred_class]}")
        print(f"  Confidence: {confidence:.1%}")
        for i, name in CLASS_NAMES.items():
            print(f"    {name}: {probs[i]:.1%}")
    else:
        print(f"\n  Classifier checkpoint not found: {classifier_path}")
        pred_class = 2  # Assume DFU for segmentation

    # Step 2: Segmentation (if wound detected)
    segmenter_path = Path(args.segmenter_checkpoint)
    if pred_class in (1, 2) and segmenter_path.exists():
        segmenter = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
        segmenter_ckpt = torch.load(
            str(segmenter_path), map_location="cpu", weights_only=True
        )
        segmenter.load_state_dict(segmenter_ckpt["model_state_dict"])
        segmenter = segmenter.to(device).eval()

        with torch.no_grad():
            seg_logits = segmenter(tensor)
            seg_prob = torch.sigmoid(seg_logits).squeeze().cpu().numpy()
            seg_mask = (seg_prob > 0.5).astype(np.uint8)

        # Plain int keeps the arithmetic and the "," formatting below free of
        # NumPy scalar types.
        wound_pixels = int(seg_mask.sum())
        total_pixels = seg_mask.shape[0] * seg_mask.shape[1]
        coverage = wound_pixels / total_pixels * 100
        # NOTE: pixels are counted on the model-resolution mask, not on the
        # original image, and the 0.5 mm/pixel scale is an assumption, so
        # this is only a rough estimate.
        area_mm2 = wound_pixels * 0.5 * 0.5  # Assuming 0.5mm/pixel

        print("\n  Segmentation:")
        print(f"    Wound detected: {'Yes' if wound_pixels > 0 else 'No'}")
        print(f"    Wound pixels: {wound_pixels:,}")
        print(f"    Coverage: {coverage:.1f}%")
        print(f"    Estimated area: {area_mm2:.1f} mm2")

        if args.save_mask:
            height, width = original.shape[:2]
            # Nearest-neighbour keeps the mask strictly binary (0/255); the
            # default bilinear interpolation would blur the wound boundary
            # into intermediate grey values.
            mask_resized = cv2.resize(
                seg_mask * 255, (width, height), interpolation=cv2.INTER_NEAREST
            )
            # imwrite signals failure through its return value, so only
            # report success when the file was actually written.
            if cv2.imwrite(args.save_mask, mask_resized):
                print(f"    Mask saved to: {args.save_mask}")
            else:
                print(f"    Failed to save mask to: {args.save_mask}")
    elif pred_class == 0:
        print("\n  Segmentation: Skipped (healthy foot detected)")
    else:
        print(f"\n  Segmenter checkpoint not found: {segmenter_path}")

    print(f"\n{'=' * 50}\n")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()