Spaces:
Sleeping
Sleeping
| # src/inference/test_resnet_pt_lr.py | |
import argparse
import os
import random
from typing import Optional, Sequence

import torch
from PIL import Image
from torchvision import datasets

from src.inference.resnet_pt_lr_model import ResNetPTLRModel
def test_single_image(
    image_path: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
):
    """Run ResNet(PT) + LR inference on one image file and print the results.

    Args:
        image_path: Path to the image to classify; it is converted to RGB
            before being passed to the model.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded unchanged to ``ResNetPTLRModel``
            (the CLI help describes ``None`` as "auto-select").
        top_k: Number of top predictions requested from ``model.predict``.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing regular file.
    """
    # Raise instead of ``assert`` so the check still runs under ``python -O``;
    # ``isfile`` also rejects directories, which ``Image.open`` cannot read.
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Close the source file handle promptly; ``convert`` returns a separate
    # image object, so it stays usable after the ``with`` block exits.
    with Image.open(image_path) as src:
        img = src.convert("RGB")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )
    out = model.predict(img, top_k=top_k)

    print(f"Input image: {image_path}")
    print(f"Predicted class_id : {out['class_id']}")
    print(f"Predicted class_name: {out['class_name']}")
    print("Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


# Despite its ``test_`` prefix this is a CLI helper, not a pytest test: its
# parameters are not fixtures, so keep pytest from collecting it.
test_single_image.__test__ = False
def test_random_dataset_sample(
    data_root: str,
    ckpt_path: str,
    labels_path: str,
    device: Optional[str] = None,
    top_k: int = 5,
    *,
    seed: Optional[int] = None,
):
    """
    Pick a random sample from the Oxford-IIIT Pet test split and run inference.

    Loads the test split with ``download=True`` and ``transform=None`` (so each
    sample's image is a PIL image), builds a ``ResNetPTLRModel``, picks one
    sample uniformly at random, and prints its ground-truth label alongside
    the model's predictions.

    Args:
        data_root: Root directory passed to ``torchvision.datasets.OxfordIIITPet``.
        ckpt_path: Checkpoint path forwarded to ``ResNetPTLRModel``.
        labels_path: Labels-mapping path forwarded to ``ResNetPTLRModel``.
        device: Device string forwarded unchanged to ``ResNetPTLRModel``.
        top_k: Number of top predictions requested from ``model.predict``.
        seed: Optional seed for a private RNG so the chosen sample is
            reproducible. When ``None``, the module-level ``random`` state is
            used, as before.

    Raises:
        ValueError: If the test split contains no samples.
        TypeError: If the sampled image is not a ``PIL.Image.Image``.
    """
    print(f"[+] Loading Oxford-IIIT Pet test split from {data_root} ...")
    # transform=None -> returns PIL.Image
    test_ds = datasets.OxfordIIITPet(
        root=data_root,
        split="test",
        target_types="category",
        transform=None,
        download=True,
    )
    num_samples = len(test_ds)
    if num_samples == 0:
        # Fail before the (possibly expensive) model load.
        raise ValueError(f"Oxford-IIIT Pet test split under {data_root!r} is empty")

    model = ResNetPTLRModel(
        ckpt_path=ckpt_path,
        labels_path=labels_path,
        device=device,
    )

    # randrange(n) consumes the same draw as the former randint(0, n - 1), so
    # callers that seed the global ``random`` module get the same pick.
    rng = random if seed is None else random.Random(seed)
    idx = rng.randrange(num_samples)
    img, target = test_ds[idx]
    if not isinstance(img, Image.Image):
        raise TypeError(f"Expected a PIL image from the dataset, got {type(img).__name__}")

    # torchvision's OxfordIIITPet exposes class names as ``.classes`` (there is
    # no ``.categories``); fall back to ``.categories`` for other providers.
    class_names = getattr(test_ds, "classes", None)
    if class_names is None:
        class_names = test_ds.categories
    # NOTE(review): assumes labels.json uses the same 0-based class order as
    # torchvision's ``classes`` so ids are comparable -- confirm.
    gt_name = class_names[target]
    print(f"[+] Random sample idx={idx}")
    print(f" Ground truth: id={target}, name={gt_name}")

    out = model.predict(img, top_k=top_k)
    print(f" Predicted class_id : {out['class_id']}")
    print(f" Predicted class_name: {out['class_name']}")
    print(" Top-k predictions:")
    for i, item in enumerate(out["top_k"], start=1):
        print(f" {i}. {item['class_name']} (id={item['class_id']}, prob={item['probability']:.4f})")


# Despite its ``test_`` prefix this is a CLI helper, not a pytest test: its
# parameters are not fixtures, so keep pytest from collecting it.
test_random_dataset_sample.__test__ = False
def _positive_int(text: str) -> int:
    """Argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {value}")
    return value


def parse_args(argv: Optional[Sequence[str]] = None):
    """Parse command-line options for the ResNet(PT) + LR inference check.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, matching the original zero-argument behaviour.

    Returns:
        An ``argparse.Namespace`` with ``ckpt_path``, ``labels_path``,
        ``data_root``, ``image_path``, ``device`` and ``top_k``.
    """
    parser = argparse.ArgumentParser(
        description="Test ResNet(PT) + LR inference on Oxford-IIIT Pet."
    )
    parser.add_argument(
        "--ckpt-path",
        type=str,
        default="checkpoints/resnet_pt_lr_head.joblib",
        help="Path to ResNet PT + LR checkpoint.",
    )
    parser.add_argument(
        "--labels-path",
        type=str,
        default="configs/labels.json",
        help="Path to labels mapping JSON.",
    )
    parser.add_argument(
        "--data-root",
        type=str,
        default="data/oxford-iiit-pet",
        help="Root directory for Oxford-IIIT Pet dataset.",
    )
    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="If provided, run inference on this image instead of a random test sample.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Device to use (e.g., 'cpu', 'cuda'). If None, auto-select.",
    )
    parser.add_argument(
        "--top-k",
        # Reject 0/negative values, which cannot produce a meaningful top-k list.
        type=_positive_int,
        default=5,
        help="Number of top classes to print.",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    cli = parse_args()
    # Options shared by both inference modes.
    common = {
        "ckpt_path": cli.ckpt_path,
        "labels_path": cli.labels_path,
        "device": cli.device,
        "top_k": cli.top_k,
    }
    if cli.image_path is None:
        # No explicit image given: sample one from the Oxford-IIIT Pet test split.
        test_random_dataset_sample(data_root=cli.data_root, **common)
    else:
        test_single_image(image_path=cli.image_path, **common)