# DiaFoot.AI-v2 / scripts/predict.py
# Hugging Face Hub page header: uploader RuthvikBandari; commit a30cdae (verified),
# commit message "Upload scripts/predict.py with huggingface_hub".
"""DiaFoot.AI v2 — Inference on New Images.

Run the trained pipeline on any foot image.

Usage:
    python scripts/predict.py --image path/to/foot_image.jpg
    python scripts/predict.py --image path/to/image.jpg --save-mask output_mask.png
"""
from __future__ import annotations
import argparse
import logging
import sys
from pathlib import Path
import cv2
import numpy as np
import torch
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from src.data.augmentation import get_val_transforms
from src.models.classifier import TriageClassifier
from src.models.unetpp import build_unetpp
# Display name for each class index; keys are consecutive integers starting at 0.
CLASS_NAMES = dict(
    enumerate(
        (
            "Healthy",
            "Non-DFU Wound",
            "DFU (Diabetic Foot Ulcer)",
        )
    )
)
def load_and_preprocess(image_path: str) -> tuple[np.ndarray, torch.Tensor]:
    """Read an image from disk and build the model input for it.

    Args:
        image_path: Path of the image file, passed to ``cv2.imread``.

    Returns:
        A pair ``(original, tensor)``: ``original`` is a copy of the image
        exactly as OpenCV decoded it (BGR channel order, native size), and
        ``tensor`` is the validation-transformed 512x512 RGB image with a
        leading batch dimension added.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` (file missing
            or not decodable).
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        raise FileNotFoundError(f"Cannot read image: {image_path}")

    # The untouched decode is handed back to the caller (e.g. to recover the
    # original height/width later).
    original = bgr.copy()

    # Model branch: BGR -> RGB, then a fixed 512x512 resize.
    rgb_512 = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (512, 512))

    # The transform is called with image=... and indexed with "image";
    # presumably an albumentations-style pipeline yielding a torch tensor.
    model_input = get_val_transforms()(image=rgb_512)["image"]
    return original, model_input.unsqueeze(0)
def _resolve_device(requested: str) -> torch.device:
    """Pick the torch device to run on, falling back to CPU when unusable.

    ``"cpu"`` is always honoured. An ``"mps"`` request is honoured when the
    MPS backend reports itself available. Every other request (``"cuda"``,
    ``"cuda:0"``, ...) keeps the script's original rule: it is honoured only
    when ``torch.cuda.is_available()`` is true.

    Args:
        requested: Device string from the ``--device`` command-line option.

    Returns:
        The requested device, or the CPU device when the request cannot be
        satisfied (a warning is logged in that case).
    """
    if requested == "cpu":
        return torch.device("cpu")

    if requested.split(":", 1)[0] == "mps":
        # Older torch builds have no torch.backends.mps at all.
        mps_backend = getattr(torch.backends, "mps", None)
        usable = mps_backend is not None and mps_backend.is_available()
    else:
        usable = torch.cuda.is_available()

    if not usable:
        logging.warning("Device %r is not available; falling back to CPU.", requested)
        return torch.device("cpu")
    return torch.device(requested)


def main() -> None:
    """Run inference on a single image.

    Parses the command line, loads the image, runs the triage classifier
    (when its checkpoint exists) and then, for wound classes 1 and 2, the
    segmentation model (when its checkpoint exists). Results are printed to
    stdout; with ``--save-mask`` the binary mask is resized to the original
    image size and written to disk.
    """
    parser = argparse.ArgumentParser(description="DiaFoot.AI v2 — Predict")
    parser.add_argument("--image", type=str, required=True, help="Path to foot image")
    parser.add_argument(
        "--classifier-checkpoint",
        type=str,
        default="checkpoints/classifier/best_epoch004_1.0000.pt",
    )
    parser.add_argument(
        "--segmenter-checkpoint",
        type=str,
        default="checkpoints/segmentation/best_epoch019_0.6781.pt",
    )
    parser.add_argument("--save-mask", type=str, default=None, help="Save segmentation mask")
    parser.add_argument("--device", type=str, default="cpu")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO, format="%(message)s")
    device = _resolve_device(args.device)

    # Load image
    original, tensor = load_and_preprocess(args.image)
    tensor = tensor.to(device)

    print(f"\n{'=' * 50}")
    print("DiaFoot.AI v2 — Inference")
    print(f"Image: {args.image}")
    print(f"{'=' * 50}")

    # Step 1: Classification
    classifier_path = Path(args.classifier_checkpoint)
    if classifier_path.exists():
        classifier = TriageClassifier(
            backbone="tf_efficientnetv2_m", num_classes=3, pretrained=False
        )
        # Checkpoints are loaded on CPU first, then the model is moved to the device.
        ckpt = torch.load(str(classifier_path), map_location="cpu", weights_only=True)
        classifier.load_state_dict(ckpt["model_state_dict"])
        classifier = classifier.to(device).eval()
        with torch.no_grad():
            logits = classifier(tensor)
        probs = torch.softmax(logits, dim=1).squeeze().cpu().numpy()
        pred_class = int(probs.argmax())
        confidence = float(probs.max())
        print(f"\n Classification: {CLASS_NAMES[pred_class]}")
        print(f" Confidence: {confidence:.1%}")
        for i, name in CLASS_NAMES.items():
            print(f" {name}: {probs[i]:.1%}")
    else:
        print(f"\n Classifier checkpoint not found: {classifier_path}")
        pred_class = 2  # Assume DFU for segmentation

    # Step 2: Segmentation (if wound detected)
    segmenter_path = Path(args.segmenter_checkpoint)
    if pred_class in (1, 2) and segmenter_path.exists():
        segmenter = build_unetpp(encoder_name="efficientnet-b4", encoder_weights=None, classes=1)
        ckpt = torch.load(str(segmenter_path), map_location="cpu", weights_only=True)
        segmenter.load_state_dict(ckpt["model_state_dict"])
        segmenter = segmenter.to(device).eval()
        with torch.no_grad():
            seg_logits = segmenter(tensor)
        seg_prob = torch.sigmoid(seg_logits).squeeze().cpu().numpy()
        seg_mask = (seg_prob > 0.5).astype(np.uint8)

        # Plain int so the statistics below do not depend on NumPy scalar types.
        wound_pixels = int(seg_mask.sum())
        total_pixels = seg_mask.shape[0] * seg_mask.shape[1]
        coverage = wound_pixels / total_pixels * 100
        # NOTE: pixels are counted on the 512x512 model input, not the original
        # image, and the scale is a fixed assumption, so this is a rough figure.
        area_mm2 = wound_pixels * 0.5 * 0.5  # Assuming 0.5mm/pixel

        print("\n Segmentation:")
        print(f" Wound detected: {'Yes' if wound_pixels > 0 else 'No'}")
        print(f" Wound pixels: {wound_pixels:,}")
        print(f" Coverage: {coverage:.1f}%")
        print(f" Estimated area: {area_mm2:.1f} mm2")

        if args.save_mask:
            height, width = original.shape[:2]
            # Nearest-neighbour keeps the mask strictly 0/255; the default
            # bilinear interpolation would introduce gray values at the edges.
            mask_resized = cv2.resize(
                seg_mask * 255, (width, height), interpolation=cv2.INTER_NEAREST
            )
            # imwrite signals failure through its return value, so check it
            # before claiming the file was written.
            if cv2.imwrite(args.save_mask, mask_resized):
                print(f" Mask saved to: {args.save_mask}")
            else:
                print(f" Failed to save mask to: {args.save_mask}")
    elif pred_class == 0:
        print("\n Segmentation: Skipped (healthy foot detected)")
    else:
        print(f"\n Segmenter checkpoint not found: {segmenter_path}")

    print(f"\n{'=' * 50}\n")


if __name__ == "__main__":
    main()