Spaces:
Runtime error
Runtime error
File size: 8,041 Bytes
0b3dd07 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 | """
============================================================
Rangoli Classification Inference
============================================================
Predict rangoli type from a single image or batch of images.
Usage:
python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth
python scripts/inference.py --image_dir path/to/folder/ --checkpoint path/to/checkpoint.pth
python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth --gradcam
============================================================
"""
import os
import sys
import json
import yaml
import argparse
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
from torchvision import transforms
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.classifier import build_model
class RangoliPredictor:
    """Load a trained rangoli classifier and run inference on image files.

    The constructor restores a model from a checkpoint, reads the class list
    and preprocessing settings from a YAML config, and builds the evaluation
    transform (resize, center crop, tensor conversion, normalization).
    """

    def __init__(self, checkpoint_path, config_path="configs/config.yaml", device=None):
        """Restore the model and build the preprocessing pipeline.

        Args:
            checkpoint_path: Path to a checkpoint readable by ``torch.load``.
                It must contain the keys ``model_name``, ``state_dict``,
                ``epoch`` and ``val_acc``.
            config_path: Path to the YAML config. It must provide ``classes``,
                ``paths.processed_data`` and ``preprocessing.image_size``.
            device: Anything accepted by ``torch.device``. When ``None``,
                CUDA is used if available, otherwise the CPU.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        with open(config_path, encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model_name = checkpoint["model_name"]

        self.model = build_model(model_name, self.config).to(self.device)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.eval()

        # Class index -> class name, in the order the config lists them.
        self.class_names = self.config["classes"]
        self.idx_to_class = dict(enumerate(self.class_names))

        # Prefer dataset-specific normalization statistics when they exist;
        # otherwise fall back to the standard ImageNet channel statistics.
        stats_path = os.path.join(
            self.config["paths"]["processed_data"], "normalization_stats.json"
        )
        if os.path.exists(stats_path):
            with open(stats_path, encoding="utf-8") as f:
                stats = json.load(f)
            mean, std = stats["mean"], stats["std"]
        else:
            mean = [0.485, 0.456, 0.406]
            std = [0.229, 0.224, 0.225]

        self.img_size = self.config["preprocessing"]["image_size"]
        resize = transforms.Resize(int(self.img_size * 1.14))
        crop = transforms.CenterCrop(self.img_size)
        # Geometry-only pipeline, reused to build the Grad-CAM overlay so the
        # displayed image covers exactly the pixels the model saw.
        self._resize_crop = transforms.Compose([resize, crop])
        self.transform = transforms.Compose([
            resize,
            crop,
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])

        print(f"Loaded {model_name} (epoch {checkpoint['epoch']}, "
              f"val_acc={checkpoint['val_acc']:.4f})")

    @staticmethod
    def _load_rgb(image_path):
        """Open ``image_path`` and return an in-memory RGB copy.

        The file handle is closed before returning, so long batch runs do
        not accumulate open files.
        """
        with Image.open(image_path) as raw:
            return raw.convert("RGB")

    def _rank_predictions(self, probs, top_k):
        """Return the ``top_k`` most probable classes, best first.

        Args:
            probs: 1-D tensor of class probabilities.
            top_k: Requested number of entries. It is clamped to the range
                ``[1, number of classes]`` so an oversized or non-positive
                value cannot make ``topk`` fail or yield an empty ranking.

        Returns:
            A list of ``{"class": name, "confidence": float}`` dicts.
        """
        k = max(1, min(int(top_k), probs.numel()))
        top_probs, top_indices = probs.topk(k)
        return [
            {"class": self.idx_to_class[idx], "confidence": float(prob)}
            for prob, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

    @torch.no_grad()
    def predict(self, image_path, top_k=3):
        """Predict the rangoli class for a single image.

        Args:
            image_path: Path to an image file readable by PIL.
            top_k: How many ranked classes to include (clamped to the number
                of classes, minimum 1).

        Returns:
            A dict with ``image``, ``predicted_class``, ``confidence`` and
            ``top_k`` (the ranked list of class/confidence dicts).
        """
        img = self._load_rgb(image_path)
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)
        logits = self.model(img_tensor)
        probs = F.softmax(logits, dim=1)[0]
        ranked = self._rank_predictions(probs, top_k)
        best = ranked[0]
        return {
            "image": image_path,
            "predicted_class": best["class"],
            "confidence": best["confidence"],
            "top_k": ranked,
        }

    @torch.no_grad()
    def predict_batch(self, image_paths, top_k=3):
        """Predict for several images, one at a time.

        A failure on one image does not abort the run: that image's entry
        becomes ``{"image": path, "error": message}`` instead.

        Args:
            image_paths: Iterable of image file paths.
            top_k: Forwarded to :meth:`predict`.

        Returns:
            A list with one result (or error) dict per input path, in order.
        """
        results = []
        for path in image_paths:
            try:
                results.append(self.predict(path, top_k))
            except Exception as e:  # deliberate best-effort per image
                results.append({"image": path, "error": str(e)})
        return results

    def predict_with_gradcam(self, image_path):
        """Predict the class of an image and render a Grad-CAM overlay.

        This method is intentionally NOT wrapped in ``torch.no_grad()``:
        Grad-CAM back-propagates from the class score, which requires
        autograd to be enabled.

        Args:
            image_path: Path to an image file readable by PIL.

        Returns:
            A dict with ``image``, ``predicted_class``, ``confidence`` and
            ``gradcam_image`` (the array returned by ``show_cam_on_image``).
            If the Grad-CAM package is not installed, the plain
            :meth:`predict` result is returned instead.

        Raises:
            RuntimeError: If the model backbone contains no ``Conv2d`` layer.
        """
        try:
            from pytorch_grad_cam import GradCAM
            from pytorch_grad_cam.utils.image import show_cam_on_image
            from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        except ImportError:
            # The library is published on PyPI under the name "grad-cam".
            print("Install pytorch-grad-cam: pip install grad-cam")
            return self.predict(image_path)

        img = self._load_rgb(image_path)
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)

        # Use the last Conv2d encountered while walking the backbone.
        target_layer = None
        for module in self.model.backbone.modules():
            if isinstance(module, torch.nn.Conv2d):
                target_layer = module
        if target_layer is None:
            raise RuntimeError("No Conv2d layer found in model backbone for Grad-CAM")

        # The plain prediction does not need gradients.
        with torch.no_grad():
            probs = F.softmax(self.model(img_tensor), dim=1)[0]
        pred_idx = probs.argmax().item()

        # enable_grad() keeps this working even if the caller is itself
        # running inside a no_grad() context.
        with torch.enable_grad():
            cam = GradCAM(model=self.model, target_layers=[target_layer])
            targets = [ClassifierOutputTarget(pred_idx)]
            grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0]

        # Overlay on the same resize + center-crop the model received, so the
        # heatmap lines up with the image and the shapes match for any
        # configured image_size.
        img_np = np.array(self._resize_crop(img)).astype(np.float32) / 255.0
        visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

        return {
            "image": image_path,
            "predicted_class": self.idx_to_class[pred_idx],
            "confidence": float(probs[pred_idx]),
            "gradcam_image": visualization,
        }
def _gradcam_output_path(image_path):
    """Return the path the Grad-CAM overlay for ``image_path`` is saved to.

    ``_gradcam`` is inserted before the final extension only, so dots in
    directory names or earlier in the file name are left alone. When the
    input has no extension, ``.png`` is appended so the overlay never
    overwrites the source image and PIL can infer an output format.
    """
    root, ext = os.path.splitext(image_path)
    return f"{root}_gradcam{ext or '.png'}"


def main():
    """Command-line entry point: classify one image or a directory of images.

    Exactly as before, ``--image`` takes precedence when both ``--image`` and
    ``--image_dir`` are given. When ``--output`` is set, the predictions are
    written as a JSON list (one entry per image) in both modes; the Grad-CAM
    array is not included in the JSON.
    """
    parser = argparse.ArgumentParser(description="Rangoli Inference")
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--image_dir", type=str, help="Path to image directory")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--top_k", type=int, default=3)
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    args = parser.parse_args()

    # Fail fast, before the (slow) model load, when there is nothing to do.
    if not args.image and not args.image_dir:
        parser.error("one of --image or --image_dir is required")

    predictor = RangoliPredictor(args.checkpoint, args.config)

    if args.image:
        if args.gradcam:
            result = predictor.predict_with_gradcam(args.image)
            print(f"\nPrediction: {result['predicted_class']} "
                  f"({result['confidence']*100:.1f}%)")
            # The overlay is absent when the Grad-CAM package is missing and
            # the predictor fell back to a plain prediction. It is removed
            # from the result so the remaining dict is JSON-serializable.
            overlay = result.pop("gradcam_image", None)
            if overlay is not None:
                save_path = _gradcam_output_path(args.image)
                Image.fromarray(overlay).save(save_path)
                print(f"Grad-CAM saved: {save_path}")
        else:
            result = predictor.predict(args.image, args.top_k)
            print(f"\nImage: {result['image']}")
            print(f"Prediction: {result['predicted_class']} ({result['confidence']*100:.1f}%)")
            print(f"\nTop-{args.top_k}:")
            for r in result["top_k"]:
                print(f" {r['class']:25s} : {r['confidence']*100:.1f}%")
        result.setdefault("image", args.image)
        results = [result]
    else:
        image_paths = [
            os.path.join(args.image_dir, name)
            for name in sorted(os.listdir(args.image_dir))
            if name.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.webp'))
        ]
        if not image_paths:
            print(f"No images found in {args.image_dir}")
        results = predictor.predict_batch(image_paths, args.top_k)
        for r in results:
            if "error" in r:
                print(f" ERROR: {r['image']} -> {r['error']}")
            else:
                print(f" {os.path.basename(r['image']):30s} -> "
                      f"{r['predicted_class']:25s} ({r['confidence']*100:.1f}%)")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\nResults saved: {args.output}")
# Run the command-line interface only when this file is executed as a
# script; importing the module must not trigger inference.
if __name__ == "__main__":
    main()
|