File size: 3,950 Bytes
52dd1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# src/inference/test_resnet_pt_svm.py

import os
import argparse
import random

from PIL import Image
from torchvision import datasets

from src.inference.resnet_pt_svm_model import ResNetPTSVMModel


def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
):
    """Run ResNet(PT) + SVM inference on one image file and print the predictions.

    Args:
        image_path: Path to the input image; it is converted to RGB before inference.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTSVMModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTSVMModel``.
        device: Device string forwarded to ``ResNetPTSVMModel``; per the CLI help,
            ``None`` means auto-select.
        top_k: Number of top predictions requested from ``model.predict``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``;
    # it also runs before the (comparatively expensive) model load.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # ``Image.open`` is lazy and keeps the file handle open; ``convert`` returns a
    # new, fully loaded image, so the handle can be released immediately.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id  : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: str = None,
    top_k: int = 5,
    seed: int = None,
):
    """Run inference on a random Oxford-IIIT Pet test image and print GT vs. predictions.

    The test split is loaded through ``torchvision.datasets.OxfordIIITPet``
    (downloaded into ``data_root`` if missing).

    Args:
        data_root: Root directory passed to ``OxfordIIITPet``.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTSVMModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTSVMModel``.
        device: Device string forwarded to ``ResNetPTSVMModel``; per the CLI help,
            ``None`` means auto-select.
        top_k: Number of top predictions requested from ``model.predict``.
        seed: Optional seed for a private RNG, making the sample choice
            reproducible. ``None`` (default) draws from the global ``random``
            state, as before.

    Raises:
        ValueError: If the loaded test split contains no samples.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,  # return PIL.Image
        download=True,
    )

    # Fail with a clear message (and before the model load) instead of the
    # opaque "empty range" error random would raise for an empty split.
    num_samples = len(test_ds)
    if num_samples == 0:
        raise ValueError(f"Oxford-IIIT Pet test split under {data_root!r} is empty.")

    model = ResNetPTSVMModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # randrange(n) draws the same value randint(0, n - 1) would from the same state.
    rng = random if seed is None else random.Random(seed)
    idx = rng.randrange(num_samples)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)  # transform=None must yield PIL images

    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id  : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + SVM inference on Oxford-IIIT Pet."
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_svm_head.joblib",
        help="Path to ResNet PT + SVM checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run on this image instead of random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g. 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )

    return parser.parse_args()


if __name__ == "__main__":
    # CLI entry point: a single image if --image-path is given, otherwise a
    # random sample from the Oxford-IIIT Pet test split.
    cli = parse_args()

    # Options shared by both modes.
    shared_kwargs = {
        "ckpt_path": cli.ckpt_path,
        "labels_path": cli.labels_path,
        "device": cli.device,
        "top_k": cli.top_k,
    }

    if cli.image_path is None:
        test_random_dataset_sample(data_root=cli.data_root, **shared_kwargs)
    else:
        test_single_image(image_path=cli.image_path, **shared_kwargs)