""" ============================================================ Rangoli Classification Inference ============================================================ Predict rangoli type from a single image or batch of images. Usage: python scripts/inference.py --image path/to/image.jpg --model resnet50 python scripts/inference.py --image_dir path/to/folder/ --model efficientnet_b3 python scripts/inference.py --image path/to/image.jpg --model resnet50 --gradcam ============================================================ """ import os import sys import json import yaml import argparse import numpy as np from PIL import Image import torch import torch.nn.functional as F from torchvision import transforms sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.classifier import build_model class RangoliPredictor: """Easy-to-use inference class.""" def __init__(self, checkpoint_path, config_path="configs/config.yaml", device=None): # Device if device is None: self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) # Config with open(config_path) as f: self.config = yaml.safe_load(f) # Load checkpoint checkpoint = torch.load(checkpoint_path, map_location=self.device) model_name = checkpoint["model_name"] # Build model self.model = build_model(model_name, self.config).to(self.device) self.model.load_state_dict(checkpoint["state_dict"]) self.model.eval() # Class mapping self.class_names = self.config["classes"] self.idx_to_class = {i: c for i, c in enumerate(self.class_names)} # Normalization stats stats_path = os.path.join( self.config["paths"]["processed_data"], "normalization_stats.json" ) if os.path.exists(stats_path): with open(stats_path) as f: stats = json.load(f) mean, std = stats["mean"], stats["std"] else: mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] # Transform img_size = self.config["preprocessing"]["image_size"] self.transform = transforms.Compose([ transforms.Resize(int(img_size * 1.14)), transforms.CenterCrop(img_size), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std), ]) print(f"Loaded {model_name} (epoch {checkpoint['epoch']}, " f"val_acc={checkpoint['val_acc']:.4f})") @torch.no_grad() def predict(self, image_path, top_k=3): """Predict rangoli class for a single image.""" img = Image.open(image_path).convert("RGB") img_tensor = self.transform(img).unsqueeze(0).to(self.device) logits = self.model(img_tensor) probs = F.softmax(logits, dim=1)[0] top_probs, top_indices = probs.topk(top_k) results = [] for prob, idx in zip(top_probs.cpu().numpy(), top_indices.cpu().numpy()): results.append({ "class": self.idx_to_class[idx], "confidence": float(prob), }) return { "image": image_path, "predicted_class": results[0]["class"], "confidence": results[0]["confidence"], "top_k": results, } @torch.no_grad() def predict_batch(self, image_paths, top_k=3): """Predict for multiple images.""" results = [] for path in image_paths: try: result = self.predict(path, top_k) results.append(result) except Exception as e: results.append({"image": path, "error": str(e)}) return results @torch.no_grad() def predict_with_gradcam(self, image_path): """Predict with Grad-CAM visualization.""" try: from pytorch_grad_cam import GradCAM from pytorch_grad_cam.utils.image import show_cam_on_image from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget except ImportError: print("Install pytorch-grad-cam: pip install pytorch-grad-cam") return self.predict(image_path) img = Image.open(image_path).convert("RGB") img_tensor = self.transform(img).unsqueeze(0).to(self.device) # Find last conv layer target_layers = None for name, module in self.model.backbone.named_modules(): if isinstance(module, torch.nn.Conv2d): target_layers = [module] # Get prediction logits = self.model(img_tensor) probs = F.softmax(logits, dim=1)[0] pred_idx = probs.argmax().item() # Grad-CAM cam = GradCAM(model=self.model, target_layers=target_layers) targets = [ClassifierOutputTarget(pred_idx)] grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0] # Overlay img_resized = img.resize((224, 224)) img_np = np.array(img_resized).astype(np.float32) / 255.0 visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True) return { "predicted_class": self.idx_to_class[pred_idx], "confidence": float(probs[pred_idx]), "gradcam_image": visualization, } def main(): parser = argparse.ArgumentParser(description="Rangoli Inference") parser.add_argument("--image", type=str, help="Path to single image") parser.add_argument("--image_dir", type=str, help="Path to image directory") parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path") parser.add_argument("--config", type=str, default="configs/config.yaml") parser.add_argument("--top_k", type=int, default=3) parser.add_argument("--gradcam", action="store_true") parser.add_argument("--output", type=str, default=None, help="Save results to JSON") args = parser.parse_args() predictor = RangoliPredictor(args.checkpoint, args.config) if args.image: if args.gradcam: result = predictor.predict_with_gradcam(args.image) print(f"\nPrediction: {result['predicted_class']} " f"({result['confidence']*100:.1f}%)") # Save Grad-CAM if "gradcam_image" in result: save_path = args.image.replace(".", "_gradcam.") Image.fromarray(result["gradcam_image"]).save(save_path) print(f"Grad-CAM saved: {save_path}") else: result = predictor.predict(args.image, args.top_k) print(f"\nImage: {result['image']}") print(f"Prediction: {result['predicted_class']} ({result['confidence']*100:.1f}%)") print(f"\nTop-{args.top_k}:") for r in result["top_k"]: print(f" {r['class']:25s} : {r['confidence']*100:.1f}%") elif args.image_dir: image_paths = [] for f in sorted(os.listdir(args.image_dir)): if f.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp')): image_paths.append(os.path.join(args.image_dir, f)) results = predictor.predict_batch(image_paths, args.top_k) for r in results: if "error" in r: print(f" ERROR: {r['image']} -> {r['error']}") else: print(f" {os.path.basename(r['image']):30s} -> " f"{r['predicted_class']:25s} ({r['confidence']*100:.1f}%)") if args.output: with open(args.output, "w") as f: json.dump(results, f, indent=2) print(f"\nResults saved: {args.output}") if __name__ == "__main__": main()