# Machine_learning_CS-6140 / src/inference/test_resnet_pt_lr.py
# Uploaded by Shashwat98 ("Upload 37 files"), commit 52dd1ca (verified).
# src/inference/test_resnet_pt_lr.py
import argparse
import os
import random
from typing import Optional, Sequence

import torch
from PIL import Image
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel
def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
):
    """Run ResNet(PT) + LR inference on one image file and print the predictions.

    Args:
        image_path: Path to the image to classify.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded to ``ResNetPTLRModel``; ``None``
            presumably lets the model auto-select (per the CLI help text).
        top_k: Number of top predictions requested from ``model.predict``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
    """
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # ``convert`` returns a new, fully loaded image, so the underlying file
    # handle can be released as soon as the ``with`` block exits.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )
    # ``out`` is read as a dict with 'class_id', 'class_name' and a 'top_k'
    # list of dicts carrying 'class_name', 'class_id' and 'probability'.
    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


# This is a manual smoke-test entry point, not a pytest test. The file and
# function names match pytest's default ``test_*`` patterns, so opt out of
# collection (otherwise pytest fails with "fixture 'image_path' not found").
test_single_image.__test__ = False
def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
    seed: Optional[int] = None,
):
    """
    Pick a random sample from the Oxford-IIIT Pet test split and run inference.

    Args:
        data_root: Root directory for the dataset (downloaded there if missing,
            since ``download=True``).
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded to ``ResNetPTLRModel``; ``None``
            presumably lets the model auto-select (per the CLI help text).
        top_k: Number of top predictions requested from ``model.predict``.
        seed: Optional seed for a reproducible sample choice. ``None`` uses the
            module-level ``random`` state, so a prior ``random.seed()`` applies.

    Raises:
        RuntimeError: If the test split contains no samples.
        TypeError: If the dataset does not yield a PIL image.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")
    # transform=None -> returns PIL.Image
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )
    num_samples = len(test_ds)
    if num_samples == 0:
        raise RuntimeError(f"Oxford-IIIT Pet test split under {data_root} is empty")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # Use an isolated RNG only when a seed is given; otherwise keep using the
    # global one so existing seeding behaviour is unchanged.
    rng = random if seed is None else random.Random(seed)
    idx = rng.randrange(num_samples)
    img, target = test_ds[idx]
    # Explicit check rather than ``assert``: asserts are stripped under ``python -O``.
    if not isinstance(img, Image.Image):
        raise TypeError(f"Expected a PIL image from the dataset, got {type(img).__name__}")

    # torchvision's OxfordIIITPet exposes class names as ``.classes`` (indexed by
    # the integer "category" target); it has no ``.categories`` attribute.
    gt_name = test_ds.classes[target]
    print(f"[+] Random sample idx={idx}")
    print(f" Ground truth: id={target}, name={gt_name}")

    # ``out`` is read as a dict with 'class_id', 'class_name' and a 'top_k'
    # list of dicts carrying 'class_name', 'class_id' and 'probability'.
    out = model.predict(img, top_k=top_k)
    print(f" Predicted class_id : {out['class_id']}")
    print(f" Predicted class_name: {out['class_name']}")
    print(" Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


# Manual smoke-test entry point, not a pytest test; keep pytest from collecting
# it (it would otherwise fail with "fixture 'data_root' not found").
test_random_dataset_sample.__test__ = False
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command-line options for this inference smoke test.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the previous zero-argument behaviour.

    Returns:
        Namespace with ``ckpt_path``, ``labels_path``, ``data_root``,
        ``image_path``, ``device`` and ``top_k`` attributes.

    Exits via ``parser.error`` (status 2) when ``--top-k`` is not positive.
    """
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )
    args = parser.parse_args(argv)
    # Reject nonsensical k as a usage error instead of passing it on to
    # ``model.predict``.
    if args.top_k < 1:
        parser.error(f"--top-k must be a positive integer, got {args.top_k}")
    return args
if __name__ == "__main__":
    cli_args = parse_args()

    # Options that both entry points accept.
    shared_kwargs = {
        "ckpt_path": cli_args.ckpt_path,
        "labels_path": cli_args.labels_path,
        "device": cli_args.device,
        "top_k": cli_args.top_k,
    }

    if cli_args.image_path is None:
        # No explicit image given: classify a random sample from the test split.
        test_random_dataset_sample(data_root=cli_args.data_root, **shared_kwargs)
    else:
        test_single_image(image_path=cli_args.image_path, **shared_kwargs)