# derm-cnn-ham10000-demo / inference.py
# iamhmh
# Initial model upload
# 397dad3
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from model import load_model
def load_labels(labels_path: str = "labels.json") -> dict[int, str]:
    """Load the class-index -> label-name mapping from a JSON file.

    Two layouts are accepted:

    * a JSON object whose keys are class indices written as strings,
      e.g. ``{"0": "name0", "1": "name1"}``;
    * a JSON array of label names, where the array position is the
      class index, e.g. ``["name0", "name1"]``.

    Args:
        labels_path: Path to the JSON labels file.

    Returns:
        A dict mapping integer class index to label.

    Raises:
        FileNotFoundError: If ``labels_path`` does not exist.
        ValueError: If the top-level JSON value is neither an object nor an
            array, or if an object key cannot be parsed as an integer.
    """
    labels_file = Path(labels_path)
    if not labels_file.exists():
        raise FileNotFoundError(f"labels.json not found at: {labels_file}")
    with labels_file.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if isinstance(raw, dict):
        # JSON object keys are always strings; convert them back to ints.
        return {int(k): v for k, v in raw.items()}
    if isinstance(raw, list):
        # Array form: the position in the list is the class index.
        return dict(enumerate(raw))
    raise ValueError(
        f"Unsupported labels format in {labels_file}: expected a JSON object "
        f"or array, got {type(raw).__name__}"
    )
def preprocess_image(
    image_path: str,
    size: tuple[int, int] = (28, 28),
) -> torch.Tensor:
    """Load an image file and turn it into a single-image model input batch.

    The image is converted to RGB and resized to ``size``. Pixel values are
    scaled to float32 in [0, 1], and the axes are reordered to channels-first.
    A leading batch dimension of 1 is then added.

    Args:
        image_path: Path to the input image.
        size: Target ``(width, height)``, as passed to Pillow's
            ``Image.resize``. Defaults to 28x28.

    Returns:
        A float32 tensor of shape ``[1, 3, height, width]``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    img_file = Path(image_path)
    if not img_file.exists():
        raise FileNotFoundError(f"Image not found at: {img_file}")

    # Open with a context manager so the file handle is closed deterministically.
    # convert() returns a new, fully loaded image, so it stays usable after the
    # source file is closed.
    # NOTE(review): resize() uses Pillow's default resampling filter; confirm it
    # matches the preprocessing used during training.
    with Image.open(img_file) as src:
        img = src.convert("RGB").resize(size)

    arr = np.asarray(img, dtype=np.float32) / 255.0  # [H, W, C] in [0, 1]
    arr = np.transpose(arr, (2, 0, 1))  # [C, H, W]
    return torch.from_numpy(arr).unsqueeze(0)  # [1, C, H, W]
def predict(
    image_path: str,
    weights_path: str = "model.pth",
    labels_path: str = "labels.json"
):
    """Classify one image and return the top class plus every class probability.

    The model is loaded first, then the label map, then the image is
    preprocessed. That order decides which error surfaces first when several
    inputs are bad.

    Args:
        image_path: Path to the image to classify.
        weights_path: Path to the model weights, forwarded to ``load_model``.
        labels_path: Path to the JSON label map, forwarded to ``load_labels``.

    Returns:
        A tuple ``(pred_idx, pred_label, probs_list)``:

        * ``pred_idx``: index of the highest-probability class (argmax of the
          softmax over the first batch element).
        * ``pred_label``: its label from the label map, or ``str(pred_idx)``
          when the index has no entry.
        * ``probs_list``: the softmax probabilities as a plain Python list.
    """
    model, device = load_model(weights_path)
    id2label = load_labels(labels_path)
    batch = preprocess_image(image_path).to(device)

    # NOTE(review): nothing here switches the model to eval mode; this assumes
    # load_model already does so -- confirm.
    with torch.no_grad():
        scores = model(batch)
        class_probs = torch.softmax(scores, dim=1)[0]
        best = int(class_probs.argmax().item())
        label = id2label.get(best, str(best))
        as_list = class_probs.cpu().tolist()
    return best, label, as_list
def main(argv: "list[str] | None" = None) -> None:
    """Command-line entry point: classify one image and print the result.

    Args:
        argv: Argument list to parse instead of the process command line.
            ``None`` (the default) makes argparse read ``sys.argv[1:]``, which
            is the behavior when run as a script. Passing a list allows
            programmatic calls and tests.
    """
    parser = argparse.ArgumentParser(
        description="Run inference with SkinCNN on a dermatoscopic image."
    )
    parser.add_argument("image", type=str, help="Path to input dermatoscopic image.")
    parser.add_argument(
        "--weights",
        type=str,
        default="model.pth",
        help="Path to model weights (.pth).",
    )
    parser.add_argument(
        "--labels",
        type=str,
        default="labels.json",
        help="Path to labels.json.",
    )
    args = parser.parse_args(argv)

    idx, label, probs = predict(
        image_path=args.image,
        weights_path=args.weights,
        labels_path=args.labels,
    )
    print(f"Predicted class index: {idx}")
    print(f"Predicted label : {label}")
    print(f"Probabilities : {probs}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()