File size: 4,137 Bytes
52dd1ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# src/inference/test_resnet_pt_lr.py

import os
import argparse
import random

from PIL import Image

import torch
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel


def test_single_image(

    image_path: str,

    ckpt_path: str,

    labels_path: str,

    device: str = None,

    top_k: int = 5,

):
    """
    Run ResNet(PT) + LR inference on one image file and print the predictions.

    Args:
        image_path: Path to the image file to classify.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping JSON path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded to ``ResNetPTLRModel``; ``None`` leaves
            the choice to the model (the CLI help describes this as auto-select).
        top_k: Number of top predictions requested from ``model.predict``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing regular file.
    """
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`.
    # `isfile` also rejects directories, which `Image.open` cannot read.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Open inside a context manager so the file handle is released. convert()
    # loads the pixel data and returns a new image, so `img` stays usable
    # after the source file is closed.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id  : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"  {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def test_random_dataset_sample(

    data_root: str,

    ckpt_path: str,

    labels_path: str,

    device: str = None,

    top_k: int = 5,

    seed: int = None,

):
    """
    Pick a random sample from the Oxford-IIIT Pet test split and run inference.

    The split is loaded with ``download=True`` and ``transform=None``, so
    samples come back as PIL images. Prints the ground-truth label next to
    the model's predictions.

    Args:
        data_root: Root directory for the Oxford-IIIT Pet dataset.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping JSON path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded to ``ResNetPTLRModel``; ``None`` leaves
            the choice to the model (the CLI help describes this as auto-select).
        top_k: Number of top predictions requested from ``model.predict``.
        seed: Optional seed for picking the sample. ``None`` (default) draws
            from the global ``random`` module, as before. An int gives a
            reproducible pick from a private RNG without touching global state.

    Raises:
        ValueError: If the loaded test split contains no samples.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")

    # transform=None -> returns PIL.Image
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )

    num_samples = len(test_ds)
    # Fail before the (potentially expensive) model load, with a clear message
    # instead of random's opaque "empty range" error.
    if num_samples == 0:
        raise ValueError(f"Oxford-IIIT Pet test split under {data_root} is empty.")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    rng = random if seed is None else random.Random(seed)
    # randrange(n) draws from [0, n), the same range as randint(0, n - 1).
    idx = rng.randrange(num_samples)
    img, target = test_ds[idx]
    assert isinstance(img, Image.Image)

    # dataset has .categories giving names
    gt_name = test_ds.categories[target]

    print(f"[+] Random sample idx={idx}")
    print(f"    Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)

    print(f"    Predicted class_id  : {out['class_id']}")
    print(f"    Predicted class_name: {out['class_name']}")
    print("    Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f"      {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


def parse_args(argv=None):
    """
    Build this script's command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (default)
            lets argparse read ``sys.argv[1:]``, exactly as before. An explicit
            list allows tests or other callers to drive the CLI directly.

    Returns:
        argparse.Namespace with ``ckpt_path``, ``labels_path``, ``data_root``,
        ``image_path``, ``device`` and ``top_k`` attributes.
    """
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )

    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    cli = parse_args()

    # Options consumed identically by both inference modes.
    shared_kwargs = {
        "ckpt_path": cli.ckpt_path,
        "labels_path": cli.labels_path,
        "device": cli.device,
        "top_k": cli.top_k,
    }

    if cli.image_path is None:
        # No explicit image given: sample one from the Oxford-IIIT Pet test split.
        test_random_dataset_sample(data_root=cli.data_root, **shared_kwargs)
    else:
        test_single_image(image_path=cli.image_path, **shared_kwargs)