derm-cnn-ham10000-demo / inference.py
iamhmh
Initial model upload
397dad3
raw
history blame
2.35 kB
import argparse
import json
from pathlib import Path
import numpy as np
import torch
from PIL import Image
from model import load_model
def load_labels(labels_path: str = "labels.json") -> dict[int, str]:
labels_file = Path(labels_path)
if not labels_file.exists():
raise FileNotFoundError(f"labels.json not found at: {labels_file}")
with labels_file.open("r", encoding="utf-8") as f:
raw = json.load(f)
return {int(k): v for k, v in raw.items()}
def preprocess_image(image_path: str) -> torch.Tensor:
img_file = Path(image_path)
if not img_file.exists():
raise FileNotFoundError(f"Image not found at: {img_file}")
img = Image.open(img_file).convert("RGB")
img = img.resize((28, 28))
arr = np.array(img).astype("float32") / 255.0 # [H, W, C] in [0,1]
arr = np.transpose(arr, (2, 0, 1)) # [C, H, W]
tensor = torch.from_numpy(arr).unsqueeze(0) # [1, 3, 28, 28]
return tensor
def predict(
image_path: str,
weights_path: str = "model.pth",
labels_path: str = "labels.json"
):
model, device = load_model(weights_path)
id2label = load_labels(labels_path)
x = preprocess_image(image_path).to(device)
with torch.no_grad():
logits = model(x)
probs = torch.softmax(logits, dim=1)[0]
pred_idx = int(torch.argmax(probs).item())
pred_label = id2label.get(pred_idx, str(pred_idx))
probs_list = probs.cpu().tolist()
return pred_idx, pred_label, probs_list
def main():
parser = argparse.ArgumentParser(
description="Run inference with SkinCNN on a dermatoscopic image."
)
parser.add_argument("image", type=str, help="Path to input dermatoscopic image.")
parser.add_argument(
"--weights",
type=str,
default="model.pth",
help="Path to model weights (.pth).",
)
parser.add_argument(
"--labels",
type=str,
default="labels.json",
help="Path to labels.json.",
)
args = parser.parse_args()
idx, label, probs = predict(
image_path=args.image,
weights_path=args.weights,
labels_path=args.labels,
)
print(f"Predicted class index: {idx}")
print(f"Predicted label : {label}")
print(f"Probabilities : {probs}")
if __name__ == "__main__":
main()