# rangoli-classifier/scripts/inference.py
# shashidharak99's picture
# Upload 16 files
# 0b3dd07 verified
"""
============================================================
Rangoli Classification Inference
============================================================
Predict rangoli type from a single image or batch of images.
Usage:
    python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth
    python scripts/inference.py --image_dir path/to/folder/ --checkpoint path/to/checkpoint.pth
    python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth --gradcam
============================================================
"""
import os
import sys
import json
import yaml
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.classifier import build_model
class RangoliPredictor:
    """Load a trained rangoli classifier checkpoint and run inference with it.

    The checkpoint is expected to be a dict holding at least ``model_name``,
    ``state_dict``, ``epoch`` and ``val_acc``. The YAML config must provide
    ``classes``, ``paths.processed_data`` and ``preprocessing.image_size``.
    """

    # Fallback normalization (the widely used ImageNet mean/std) for when no
    # dataset-specific statistics file is found next to the processed data.
    _DEFAULT_MEAN = [0.485, 0.456, 0.406]
    _DEFAULT_STD = [0.229, 0.224, 0.225]
    # The image is resized to this multiple of the crop size before cropping.
    _RESIZE_RATIO = 1.14

    def __init__(self, checkpoint_path, config_path="configs/config.yaml", device=None):
        """Build the model, restore its weights and prepare the preprocessing.

        Args:
            checkpoint_path: Path to the checkpoint file saved during training.
            config_path: Path to the project YAML configuration.
            device: Anything accepted by ``torch.device``; when ``None`` CUDA
                is used if available, otherwise the CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        with open(config_path, encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model_name = checkpoint["model_name"]

        self.model = build_model(model_name, self.config).to(self.device)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.eval()

        self.class_names = self.config["classes"]
        self.idx_to_class = dict(enumerate(self.class_names))

        stats_path = os.path.join(
            self.config["paths"]["processed_data"], "normalization_stats.json"
        )
        if os.path.exists(stats_path):
            with open(stats_path, encoding="utf-8") as f:
                stats = json.load(f)
            mean, std = stats["mean"], stats["std"]
        else:
            mean, std = self._DEFAULT_MEAN, self._DEFAULT_STD

        img_size = self.config["preprocessing"]["image_size"]
        resize_to = int(img_size * self._RESIZE_RATIO)
        self.transform = transforms.Compose([
            transforms.Resize(resize_to),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        # Geometry-only version of the transform above. Grad-CAM overlays use
        # it so the displayed image shows exactly the pixels the model saw.
        self._overlay_crop = transforms.Compose([
            transforms.Resize(resize_to),
            transforms.CenterCrop(img_size),
        ])

        print(f"Loaded {model_name} (epoch {checkpoint['epoch']}, "
              f"val_acc={checkpoint['val_acc']:.4f})")

    def _top_k_results(self, probs, top_k):
        """Return the ``top_k`` most probable classes from a 1-D probability tensor.

        ``top_k`` is clamped to ``[1, number of classes]`` so that an
        over-large or non-positive request cannot make ``topk`` fail or leave
        the caller without a top-1 entry.
        """
        k = max(1, min(int(top_k), probs.numel()))
        top_probs, top_indices = probs.topk(k)
        return [
            {"class": self.idx_to_class[idx], "confidence": float(prob)}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

    def _last_conv_layer(self):
        """Return the last ``Conv2d`` module, used as the Grad-CAM target layer.

        The search runs over ``self.model.backbone`` when the model has such
        an attribute and over the whole model otherwise.

        Raises:
            ValueError: If no ``Conv2d`` layer exists.
        """
        root = getattr(self.model, "backbone", self.model)
        last_conv = None
        for module in root.modules():
            if isinstance(module, torch.nn.Conv2d):
                last_conv = module
        if last_conv is None:
            raise ValueError("Grad-CAM needs a Conv2d layer, but the model has none")
        return last_conv

    @torch.no_grad()
    def predict(self, image_path, top_k=3):
        """Predict the rangoli class for a single image.

        Args:
            image_path: Path of the image file to classify.
            top_k: Number of ranked candidates to return (clamped to the
                number of classes).

        Returns:
            A dict with ``image``, ``predicted_class``, ``confidence`` and
            ``top_k`` (a list of ``{"class", "confidence"}`` dicts, best first).
        """
        # convert() yields an independent in-memory image, so the file handle
        # can be closed right away.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        logits = self.model(img_tensor)
        probs = F.softmax(logits, dim=1)[0]
        results = self._top_k_results(probs, top_k)

        return {
            "image": image_path,
            "predicted_class": results[0]["class"],
            "confidence": results[0]["confidence"],
            "top_k": results,
        }

    def predict_batch(self, image_paths, top_k=3):
        """Predict for several images, one at a time.

        A failure on one image does not abort the batch: that entry becomes
        ``{"image": path, "error": message}`` instead of a prediction dict.
        """
        results = []
        for path in image_paths:
            try:
                results.append(self.predict(path, top_k))
            except Exception as e:  # best effort: keep processing the rest
                results.append({"image": path, "error": str(e)})
        return results

    def predict_with_gradcam(self, image_path):
        """Predict a single image and produce a Grad-CAM overlay for it.

        Falls back to plain :meth:`predict` when the Grad-CAM library is not
        installed.

        Returns:
            A dict with ``image``, ``predicted_class``, ``confidence`` and
            ``gradcam_image`` (the array returned by ``show_cam_on_image``),
            or the result of :meth:`predict` in the fallback case.

        Raises:
            ValueError: If the model has no ``Conv2d`` layer to attach to.
        """
        try:
            from pytorch_grad_cam import GradCAM
            from pytorch_grad_cam.utils.image import show_cam_on_image
            from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        except ImportError:
            # The library is imported as pytorch_grad_cam but published on
            # PyPI under the name "grad-cam".
            print("Install pytorch-grad-cam: pip install grad-cam")
            return self.predict(image_path)

        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        target_layer = self._last_conv_layer()

        # The plain forward pass needs no gradients.
        with torch.no_grad():
            probs = F.softmax(self.model(img_tensor), dim=1)[0]
        pred_idx = int(probs.argmax().item())

        # Grad-CAM backpropagates through the model, so autograd must be on
        # here. The context manager removes the hooks it registers, which
        # would otherwise pile up on the model across calls.
        with torch.enable_grad(), GradCAM(
            model=self.model, target_layers=[target_layer]
        ) as cam:
            grayscale_cam = cam(
                input_tensor=img_tensor,
                targets=[ClassifierOutputTarget(pred_idx)],
            )[0]

        # Overlay on the same resized + center-cropped view the model saw, so
        # the heat map lines up and has the same spatial size as the CAM.
        img_np = np.asarray(self._overlay_crop(img), dtype=np.float32) / 255.0
        visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

        return {
            "image": image_path,
            "predicted_class": self.idx_to_class[pred_idx],
            "confidence": float(probs[pred_idx]),
            "gradcam_image": visualization,
        }
def main():
    """Command-line entry point: classify one image or every image in a folder.

    ``--image`` takes precedence when both ``--image`` and ``--image_dir`` are
    given. With ``--output`` the results are written as a JSON list, holding
    one entry for a single image and one entry per file for a directory.
    """
    parser = argparse.ArgumentParser(description="Rangoli Inference")
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--image_dir", type=str, help="Path to image directory")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--top_k", type=int, default=3)
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    args = parser.parse_args()

    # Fail fast, before the model is loaded, when there is nothing to classify.
    if not args.image and not args.image_dir:
        parser.error("one of --image or --image_dir is required")

    predictor = RangoliPredictor(args.checkpoint, args.config)

    if args.image:
        if args.gradcam:
            result = predictor.predict_with_gradcam(args.image)
            print(f"\nPrediction: {result['predicted_class']} "
                  f"({result['confidence']*100:.1f}%)")

            # Take the overlay out of the result: it is an image array, which
            # json.dump cannot serialise. It is absent when the predictor fell
            # back to a plain prediction.
            cam_image = result.pop("gradcam_image", None)
            if cam_image is not None:
                # Split on the extension only; replacing every "." would also
                # rewrite dots in directory names such as "./images/".
                stem, ext = os.path.splitext(args.image)
                save_path = f"{stem}_gradcam{ext or '.png'}"
                Image.fromarray(cam_image).save(save_path)
                print(f"Grad-CAM saved: {save_path}")
                result["gradcam_path"] = save_path
        else:
            result = predictor.predict(args.image, args.top_k)
            print(f"\nImage: {result['image']}")
            print(f"Prediction: {result['predicted_class']} ({result['confidence']*100:.1f}%)")
            print(f"\nTop-{args.top_k}:")
            for r in result["top_k"]:
                print(f"  {r['class']:25s} : {r['confidence']*100:.1f}%")
        results = [result]
    else:
        image_exts = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
        image_paths = [
            os.path.join(args.image_dir, name)
            for name in sorted(os.listdir(args.image_dir))
            if name.lower().endswith(image_exts)
        ]
        results = predictor.predict_batch(image_paths, args.top_k)
        for r in results:
            if "error" in r:
                print(f"  ERROR: {r['image']} -> {r['error']}")
            else:
                print(f"  {os.path.basename(r['image']):30s} -> "
                      f"{r['predicted_class']:25s} ({r['confidence']*100:.1f}%)")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved: {args.output}")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()