File size: 2,349 Bytes
397dad3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import argparse
import json
from pathlib import Path

import numpy as np
import torch
from PIL import Image

from model import load_model


def load_labels(labels_path: str = "labels.json") -> dict[int, str]:
    """Load the class-index -> label mapping from a JSON file.

    Two layouts are accepted:

    * an object keyed by integer strings, e.g. ``{"0": "akiec", "1": "bcc"}``
      (JSON object keys are always strings, hence the ``int`` conversion);
    * a plain array, e.g. ``["akiec", "bcc"]``, where position is the index.

    Args:
        labels_path: Path to the labels JSON file.

    Returns:
        Mapping from integer class index to label.

    Raises:
        FileNotFoundError: If ``labels_path`` does not exist.
        ValueError: If the top-level JSON value is neither an object nor an
            array, or an object key cannot be parsed as an integer.
    """
    labels_file = Path(labels_path)
    if not labels_file.exists():
        raise FileNotFoundError(f"Labels file not found at: {labels_file}")

    with labels_file.open("r", encoding="utf-8") as f:
        raw = json.load(f)

    if isinstance(raw, dict):
        return {int(k): v for k, v in raw.items()}
    if isinstance(raw, list):
        # Positional layout: the list index is the class index.
        return dict(enumerate(raw))
    raise ValueError(
        f"Unsupported labels format in {labels_file}: expected a JSON object "
        f"or array, got {type(raw).__name__}"
    )


def preprocess_image(
    image_path: str,
    *,
    size: tuple[int, int] = (28, 28),
) -> torch.Tensor:
    """Load an image file and turn it into a single-item batch tensor.

    The image is converted to RGB, resized to ``size`` and scaled to [0, 1].
    No mean/std normalisation is applied here.
    NOTE(review): confirm this matches the preprocessing used at training time.

    Args:
        image_path: Path to the input image.
        size: Target ``(width, height)`` passed to ``PIL.Image.resize``;
            defaults to 28x28.

    Returns:
        ``float32`` tensor of shape ``[1, 3, height, width]``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    img_file = Path(image_path)
    if not img_file.exists():
        raise FileNotFoundError(f"Image not found at: {img_file}")

    # Use a context manager so the file handle is released deterministically;
    # Pillow does not close it automatically for every format.
    with Image.open(img_file) as src:
        # convert() loads the pixels into a new in-memory image, so `img`
        # stays valid after the source file is closed.
        img = src.convert("RGB")
    img = img.resize(size)

    arr = np.asarray(img, dtype=np.float32) / 255.0  # [H, W, C] in [0, 1]
    arr = np.transpose(arr, (2, 0, 1))               # [C, H, W]
    return torch.from_numpy(arr).unsqueeze(0)        # [1, 3, H, W]


def predict(
    image_path: str,
    weights_path: str = "model.pth",
    labels_path: str = "labels.json"
) -> tuple[int, str, list[float]]:
    """Classify one image and return the top prediction and class probabilities.

    Args:
        image_path: Path to the input image.
        weights_path: Path to the model weights, forwarded to ``load_model``.
        labels_path: Path to the JSON label mapping.

    Returns:
        ``(pred_idx, pred_label, probs)``:
        * ``pred_idx`` is the index of the most probable class;
        * ``pred_label`` is its label, or ``str(pred_idx)`` when the index is
          missing from the label file;
        * ``probs`` is the list of softmax probabilities for every class.
    """
    model, device = load_model(weights_path)
    # Make sure dropout / batch-norm layers run in inference mode. This is a
    # no-op if load_model already did it, and protects against getting back a
    # model that is still in training mode.
    model.eval()
    id2label = load_labels(labels_path)

    x = preprocess_image(image_path).to(device)

    with torch.no_grad():
        logits = model(x)
        # Assumes logits are [batch, num_classes]; take the only batch item.
        probs = torch.softmax(logits, dim=1)[0]

    pred_idx = int(torch.argmax(probs).item())
    pred_label = id2label.get(pred_idx, str(pred_idx))
    return pred_idx, pred_label, probs.cpu().tolist()


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the inference script."""
    cli = argparse.ArgumentParser(
        description="Run inference with SkinCNN on a dermatoscopic image."
    )
    cli.add_argument("image", type=str, help="Path to input dermatoscopic image.")
    cli.add_argument("--weights", type=str, default="model.pth",
                     help="Path to model weights (.pth).")
    cli.add_argument("--labels", type=str, default="labels.json",
                     help="Path to labels.json.")
    return cli


def main():
    """CLI entry point: parse arguments, run ``predict`` and print the result."""
    opts = _build_parser().parse_args()

    class_idx, class_name, class_probs = predict(
        image_path=opts.image,
        weights_path=opts.weights,
        labels_path=opts.labels,
    )

    # One print of a newline-joined report writes exactly the same bytes as
    # three separate prints.
    report = "\n".join((
        f"Predicted class index: {class_idx}",
        f"Predicted label      : {class_name}",
        f"Probabilities        : {class_probs}",
    ))
    print(report)


if __name__ == "__main__":
    main()