| | """DiaFoot.AI v2 — Inference on New Images. |
| | |
| | Run the trained pipeline on any foot image. |
| | |
| | Usage: |
| | python scripts/predict.py --image path/to/foot_image.jpg |
| | python scripts/predict.py --image path/to/image.jpg --save-mask output_mask.png |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import logging |
| | import sys |
| | from pathlib import Path |
| |
|
| | import cv2 |
| | import numpy as np |
| | import torch |
| |
|
| | sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) |
| |
|
| | from src.data.augmentation import get_val_transforms |
| | from src.models.classifier import TriageClassifier |
| | from src.models.unetpp import build_unetpp |
| |
|
# Display labels keyed by the class index predicted by the triage classifier.
CLASS_NAMES = {
    0: "Healthy",
    1: "Non-DFU Wound",
    2: "DFU (Diabetic Foot Ulcer)",
}
| |
|
| |
|
def load_and_preprocess(image_path: str) -> tuple[np.ndarray, torch.Tensor]:
    """Read an image from disk and build the batched model input.

    Args:
        image_path: Path of the image file, handed to ``cv2.imread``.

    Returns:
        A pair ``(original, tensor)``. ``original`` is a copy of the image
        exactly as ``cv2.imread`` decoded it (untouched by the colour
        conversion and resize). ``tensor`` is the output of the validation
        transforms applied to the RGB-converted, 512x512-resized image, with
        a leading dimension added via ``unsqueeze(0)``.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for the path.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        msg = f"Cannot read image: {image_path}"
        raise FileNotFoundError(msg)

    # Convert to RGB and resize to the fixed 512x512 size before the transforms run.
    rgb_512 = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (512, 512))

    # NOTE(review): the transform is called with ``image=`` and must return a
    # mapping whose "image" entry supports ``unsqueeze`` — confirm against
    # ``get_val_transforms``.
    batch = get_val_transforms()(image=rgb_512)["image"].unsqueeze(0)

    return bgr.copy(), batch
| |
|
| |
|
def _resolve_device(requested: str) -> torch.device:
    """Pick the torch device, falling back to CPU when CUDA is unavailable.

    The requested device is honoured when it is ``"cpu"`` or when CUDA is
    available; otherwise the CPU is used and a warning is logged so the
    fallback is not silent.
    """
    if requested == "cpu" or torch.cuda.is_available():
        return torch.device(requested)
    logging.warning(
        "Requested device '%s' but CUDA is not available; using CPU instead.", requested
    )
    return torch.device("cpu")


def _run_classifier(tensor: torch.Tensor, checkpoint: Path, device: torch.device) -> int:
    """Load the triage classifier, print its probabilities, return the top class index."""
    classifier = TriageClassifier(
        backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False
    )
    # Load onto the CPU first; the model is moved to the target device afterwards.
    ckpt = torch.load(str(checkpoint), map_location="cpu", weights_only=True)
    classifier.load_state_dict(ckpt["model_state_dict"])
    classifier = classifier.to(device).eval()

    with torch.no_grad():
        logits = classifier(tensor)
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()

    pred_class = int(probs.argmax())
    confidence = float(probs.max())

    print(f"\n Classification: {CLASS_NAMES[pred_class]}")
    print(f" Confidence: {confidence:.1%}")
    for i, name in CLASS_NAMES.items():
        print(f" {name}: {probs[i]:.1%}")
    return pred_class


def _run_segmenter(tensor: torch.Tensor, checkpoint: Path, device: torch.device) -> np.ndarray:
    """Load the segmenter and return the thresholded (0/1, uint8) wound mask."""
    segmenter = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
    ckpt = torch.load(str(checkpoint), map_location="cpu", weights_only=True)
    segmenter.load_state_dict(ckpt["model_state_dict"])
    segmenter = segmenter.to(device).eval()

    with torch.no_grad():
        seg_logits = segmenter(tensor)
        seg_prob = torch.sigmoid(seg_logits).squeeze().cpu().numpy()
    return (seg_prob > 0.5).astype(np.uint8)


def _save_mask(seg_mask: np.ndarray, original: np.ndarray, output_path: str) -> None:
    """Resize the mask to the original image size and write it as a 0/255 image."""
    height, width = original.shape[:2]
    # Nearest-neighbour keeps the mask strictly binary; the default bilinear
    # interpolation would blend 0 and 255 into grey values along the edges.
    mask_resized = cv2.resize(
        seg_mask * 255, (width, height), interpolation=cv2.INTER_NEAREST
    )
    # cv2.imwrite reports failure through its return value, not an exception.
    if cv2.imwrite(output_path, mask_resized):
        print(f" Mask saved to: {output_path}")
    else:
        logging.error(" Failed to save mask to: %s", output_path)


def main() -> None:
    """Run inference on a single image.

    Command-line options:
        --image: Path to the image to analyse (required).
        --classifier-checkpoint: Triage classifier checkpoint file.
        --segmenter-checkpoint: Segmentation checkpoint file.
        --save-mask: Optional output path for the segmentation mask.
        --device: Torch device name; anything other than "cpu" is only used
            when CUDA is available, otherwise the CPU is used.
        --pixel-spacing: Millimetres per mask pixel for the area estimate
            (default 0.5).

    The classifier runs first. Segmentation runs only for predicted classes
    1 and 2; if the classifier checkpoint is missing, class 2 is assumed so
    that segmentation is still attempted. Results are printed to stdout.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 — Predict")
    parser.add_argument("--image", type=str, required=True, help="Path to foot image")
    parser.add_argument(
        "--classifier-checkpoint",
        type=str,
        default="checkpoints/classifier/best_epoch004_1.0000.pt",
    )
    parser.add_argument(
        "--segmenter-checkpoint",
        type=str,
        default="checkpoints/segmentation/best_epoch019_0.6781.pt",
    )
    parser.add_argument("--save-mask", type=str, default=None, help="Save segmentation mask")
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument(
        "--pixel-spacing",
        type=float,
        default=0.5,
        help="Millimetres per mask pixel used for the area estimate",
    )
    args = parser.parse_args()
    if args.pixel_spacing <= 0:
        parser.error("--pixel-spacing must be a positive number")

    logging.basicConfig(level=logging.INFO, format="%(message)s")

    device = _resolve_device(args.device)

    original, tensor = load_and_preprocess(args.image)
    tensor = tensor.to(device)

    print(f"\n{'=' * 50}")
    print("DiaFoot.AI v2 — Inference")
    print(f"Image: {args.image}")
    print(f"{'=' * 50}")

    classifier_path = Path(args.classifier_checkpoint)
    if classifier_path.exists():
        pred_class = _run_classifier(tensor, classifier_path, device)
    else:
        print(f"\n Classifier checkpoint not found: {classifier_path}")
        # No triage result available: assume class 2 so segmentation is still attempted.
        pred_class = 2

    segmenter_path = Path(args.segmenter_checkpoint)
    if pred_class in (1, 2) and segmenter_path.exists():
        seg_mask = _run_segmenter(tensor, segmenter_path, device)

        # Plain int so formatting and arithmetic do not depend on NumPy scalar types.
        wound_pixels = int(seg_mask.sum())
        total_pixels = seg_mask.shape[0] * seg_mask.shape[1]
        coverage = wound_pixels / total_pixels * 100
        # The estimate is computed on the network-output mask (before it is
        # resized back to the original image), so the spacing refers to that grid.
        area_mm2 = wound_pixels * args.pixel_spacing * args.pixel_spacing

        print("\n Segmentation:")
        print(f" Wound detected: {'Yes' if wound_pixels > 0 else 'No'}")
        print(f" Wound pixels: {wound_pixels:,}")
        print(f" Coverage: {coverage:.1f}%")
        print(f" Estimated area: {area_mm2:.1f} mm2")

        if args.save_mask:
            _save_mask(seg_mask, original, args.save_mask)
    elif pred_class == 0:
        print("\n Segmentation: Skipped (healthy foot detected)")
    else:
        print(f"\n Segmenter checkpoint not found: {segmenter_path}")

    print(f"\n{'=' * 50}\n")
| |
|
| |
|
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
| |
|