# src/inference/test_resnet_pt_lr.py
"""Command-line smoke test for ResNet(PT) + LR inference on Oxford-IIIT Pet.

Runs ``ResNetPTLRModel.predict`` either on a user-supplied image
(``--image-path``) or on a randomly chosen sample from the Oxford-IIIT Pet
test split, and prints the top-1 and top-k predictions.
"""
import os
import argparse
import random
from typing import Optional

from PIL import Image
import torch  # not referenced in this script; import kept from the original module
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel


def _print_prediction(out: dict, prefix: str = "") -> None:
    """Print the top-1 result and the ranked top-k list of a prediction.

    Args:
        out: Mapping returned by ``ResNetPTLRModel.predict``. It is read here
            through the keys ``class_id``, ``class_name`` and ``top_k``; each
            ``top_k`` entry is read through ``class_name``, ``class_id`` and
            ``probability``.
        prefix: Text prepended to the three header lines only. The ranked
            item lines always start with a single space, matching the output
            of both callers.
    """
    print(f"{prefix}Predicted class_id : {out['class_id']}")
    print(f"{prefix}Predicted class_name: {out['class_name']}")
    print(f"{prefix}Top-k predictions:")
    for rank, item in enumerate(out["top_k"], start=1):
        print(
            f" {rank}. {item['class_name']} "
            f"(id={item['class_id']}, prob={item['probability']:.4f})"
        )


def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
) -> None:
    """Run inference on one image file and print the predictions.

    Args:
        image_path: Path of the image to classify.
        ckpt_path: Checkpoint path handed to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path handed to ``ResNetPTLRModel``.
        device: Device string handed to ``ResNetPTLRModel`` unchanged
            (``None`` leaves the choice to the model).
        top_k: Number of ranked predictions requested from ``predict``.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    # Explicit raise instead of ``assert``: asserts are removed under
    # ``python -O`` and would let a bad path through to ``Image.open``.
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # ``convert`` returns a new, fully loaded image, so the underlying file
    # handle can be released as soon as the ``with`` block ends.
    with Image.open(image_path) as raw:
        img = raw.convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )
    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    _print_prediction(out)


def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
) -> None:
    """Pick a random sample from the Oxford-IIIT Pet test split and run inference.

    The dataset is requested with ``download=True``, so it may be fetched
    into ``data_root`` if it is not already present.

    Args:
        data_root: Root directory for the Oxford-IIIT Pet dataset.
        ckpt_path: Checkpoint path handed to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path handed to ``ResNetPTLRModel``.
        device: Device string handed to ``ResNetPTLRModel`` unchanged
            (``None`` leaves the choice to the model).
        top_k: Number of ranked predictions requested from ``predict``.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")
    # transform=None -> the dataset returns PIL.Image objects
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    idx = random.randint(0, len(test_ds) - 1)
    img, target = test_ds[idx]
    # Internal sanity check on the dataset contract, not input validation.
    assert isinstance(img, Image.Image)

    # The dataset exposes ``.categories``, indexed by the integer target.
    gt_name = test_ds.categories[target]
    print(f"[+] Random sample idx={idx}")
    print(f" Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)
    _print_prediction(out, prefix=" ")


def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``ckpt_path``, ``labels_path``, ``data_root``,
        ``image_path``, ``device`` and ``top_k``.
    """
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        type=int,
        default=5,
        help="Number of top classes to print.",
    )
    return parser.parse_args()


def main() -> None:
    """Dispatch to single-image or random-sample inference based on the CLI."""
    args = parse_args()
    if args.image_path is not None:
        test_single_image(
            image_path=args.image_path,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )
    else:
        test_random_dataset_sample(
            data_root=args.data_root,
            ckpt_path=args.ckpt_path,
            labels_path=args.labels_path,
            device=args.device,
            top_k=args.top_k,
        )


if __name__ == "__main__":
    main()