Spaces:
Sleeping
Sleeping
File size: 4,137 Bytes
52dd1ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
# src/inference/test_resnet_pt_lr.py
import argparse
import os
import random
from typing import Optional, Sequence

import torch
from PIL import Image
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel
def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
):
    """Run ResNet(PT) + LR inference on a single image file and print the result.

    Args:
        image_path: Path of the image to classify; it is converted to RGB
            before being handed to the model.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded unchanged to ``ResNetPTLRModel``
            (the CLI help describes ``None`` as "auto-select").
        top_k: Number of top predictions requested from ``model.predict``.

    Returns:
        The object returned by ``model.predict``; this function reads its
        ``class_id``, ``class_name`` and ``top_k`` keys for printing.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    # Raise instead of ``assert`` so the check still runs under ``python -O``.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Close the underlying file handle; ``convert`` returns a separate copy.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )
    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
    return out


# Despite the ``test_`` prefix this is a CLI helper, not a pytest test: its
# parameters are not fixtures, so opt it out of pytest collection.
test_single_image.__test__ = False
def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
    seed: Optional[int] = None,
):
    """Pick a random sample from the Oxford-IIIT Pet test split and run inference.

    The test split is loaded (and downloaded if missing) via
    ``torchvision.datasets.OxfordIIITPet``. A single random sample is then
    classified, and its ground truth and predictions are printed.

    Args:
        data_root: Root directory for the Oxford-IIIT Pet dataset.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded unchanged to ``ResNetPTLRModel``.
        top_k: Number of top predictions requested from ``model.predict``.
        seed: If given, the sample index is drawn from a dedicated
            ``random.Random(seed)`` so the pick is reproducible. If ``None``
            (default), the module-level ``random`` generator is used, as before.

    Returns:
        The object returned by ``model.predict``; this function reads its
        ``class_id``, ``class_name`` and ``top_k`` keys for printing.

    Raises:
        ValueError: If the test split contains no samples.
        TypeError: If a dataset sample is not a ``PIL.Image.Image``.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")
    # transform=None -> samples come back untransformed (checked to be PIL below).
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )
    n_samples = len(test_ds)
    if n_samples == 0:
        raise ValueError(f"Oxford-IIIT Pet test split under {data_root!r} is empty")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # randrange(n) draws exactly like randint(0, n - 1).
    rng = random if seed is None else random.Random(seed)
    idx = rng.randrange(n_samples)
    img, target = test_ds[idx]
    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if not isinstance(img, Image.Image):
        raise TypeError(f"Expected a PIL image from the dataset, got {type(img).__name__}")

    # The dataset's ``categories`` list maps the integer target to a name.
    gt_name = test_ds.categories[target]
    print(f"[+] Random sample idx={idx}")
    print(f" Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)
    print(f" Predicted class_id : {out['class_id']}")
    print(f" Predicted class_name: {out['class_name']}")
    print(" Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")
    return out


# Despite the ``test_`` prefix this is a CLI helper, not a pytest test: its
# parameters are not fixtures, so opt it out of pytest collection.
test_random_dataset_sample.__test__ = False
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for this inference script.

    Args:
        argv: Argument strings to parse. ``None`` (default) makes argparse
            read ``sys.argv[1:]``, exactly as before. An explicit list lets the
            function be driven from code or tests.

    Returns:
        Namespace with ``ckpt_path``, ``labels_path``, ``data_root``,
        ``image_path``, ``device`` and ``top_k`` attributes.

    Raises:
        SystemExit: Via ``parser.error`` when ``--top-k`` is smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )
    args = parser.parse_args(argv)
    # A non-positive k cannot produce a meaningful top-k list; fail with a
    # normal usage error rather than deep inside the model.
    if args.top_k < 1:
        parser.error(f"--top-k must be a positive integer, got {args.top_k}")
    return args
if __name__ == "__main__":
    cli_args = parse_args()
    # Options forwarded identically to whichever inference mode runs.
    common_kwargs = dict(
        ckpt_path=cli_args.ckpt_path,
        labels_path=cli_args.labels_path,
        device=cli_args.device,
        top_k=cli_args.top_k,
    )
    if cli_args.image_path is None:
        # No explicit image given: classify a random Oxford-IIIT Pet test sample.
        test_random_dataset_sample(data_root=cli_args.data_root, **common_kwargs)
    else:
        test_single_image(image_path=cli_args.image_path, **common_kwargs)
|