Spaces:
Runtime error
Runtime error
| """ | |
| ============================================================ | |
| Rangoli Classification Inference | |
| ============================================================ | |
| Predict rangoli type from a single image or batch of images. | |
| Usage: | |
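| python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth | |
| python scripts/inference.py --image_dir path/to/folder/ --checkpoint path/to/checkpoint.pth --output results.json | |
| python scripts/inference.py --image path/to/image.jpg --checkpoint path/to/checkpoint.pth --gradcam | |
| The model architecture is read from the checkpoint's "model_name" entry; there is no --model flag. | |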
| ============================================================ | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import yaml | |
| import argparse | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torch.nn.functional as F | |
| from torchvision import transforms | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models.classifier import build_model | |
class RangoliPredictor:
    """Load a trained rangoli classifier checkpoint and run inference on images.

    The constructor reads the YAML config, rebuilds the architecture named in
    the checkpoint via ``build_model``, loads its weights, and prepares the
    resize -> center-crop -> normalize pipeline applied to every input image.
    """

    def __init__(self, checkpoint_path, config_path="configs/config.yaml", device=None):
        """Build the predictor.

        Args:
            checkpoint_path: File written with ``torch.save``. It must contain
                ``"model_name"`` and ``"state_dict"``. ``"epoch"`` and
                ``"val_acc"`` are reported when present.
            config_path: YAML config that provides ``classes``,
                ``paths.processed_data`` and ``preprocessing.image_size``.
            device: Torch device (string or ``torch.device``). Defaults to CUDA
                when available, otherwise CPU.
        """
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = torch.device(device)

        # safe_load: never construct arbitrary Python objects from the config.
        with open(config_path, encoding="utf-8") as f:
            self.config = yaml.safe_load(f)

        # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model_name = checkpoint["model_name"]

        self.model = build_model(model_name, self.config).to(self.device)
        self.model.load_state_dict(checkpoint["state_dict"])
        self.model.eval()

        # Class index -> class name, in the order listed in the config.
        self.class_names = self.config["classes"]
        self.idx_to_class = dict(enumerate(self.class_names))

        mean, std = self._load_normalization_stats()

        # Resize slightly larger than the target (1.14 ~= 256/224), then center-crop.
        img_size = self.config["preprocessing"]["image_size"]
        resize_size = int(img_size * 1.14)
        self.transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std),
        ])
        # Same geometry without tensor conversion/normalization, so Grad-CAM
        # overlays line up with exactly the pixels the model saw.
        self.display_transform = transforms.Compose([
            transforms.Resize(resize_size),
            transforms.CenterCrop(img_size),
        ])

        # Report training metadata only when the checkpoint actually stores it.
        details = []
        if "epoch" in checkpoint:
            details.append(f"epoch {checkpoint['epoch']}")
        if "val_acc" in checkpoint:
            details.append(f"val_acc={float(checkpoint['val_acc']):.4f}")
        suffix = f" ({', '.join(details)})" if details else ""
        print(f"Loaded {model_name}{suffix}")

    def _load_normalization_stats(self):
        """Return ``(mean, std)`` from ``normalization_stats.json`` if present, else defaults."""
        stats_path = os.path.join(
            self.config["paths"]["processed_data"], "normalization_stats.json"
        )
        if os.path.exists(stats_path):
            with open(stats_path, encoding="utf-8") as f:
                stats = json.load(f)
            return stats["mean"], stats["std"]
        # Fallback: the commonly used ImageNet per-channel mean/std.
        return [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

    @staticmethod
    def _open_rgb(image_path):
        """Open *image_path* as an RGB PIL image and close the underlying file."""
        with Image.open(image_path) as img:
            return img.convert("RGB")

    def _rank(self, probs, top_k):
        """Return the top classes of a 1-D probability tensor, highest first.

        ``top_k`` is clamped to ``[1, number of classes]``, so oversized or
        non-positive values never make ``topk`` fail.

        Returns:
            List of ``{"class": name, "confidence": float}`` dicts.
        """
        k = max(1, min(int(top_k), probs.numel()))
        top_probs, top_indices = probs.topk(k)
        return [
            {"class": self.idx_to_class[idx], "confidence": float(p)}
            for p, idx in zip(top_probs.tolist(), top_indices.tolist())
        ]

    def predict(self, image_path, top_k=3):
        """Predict the rangoli class of a single image.

        Args:
            image_path: Path to an image file readable by PIL.
            top_k: Number of ranked classes to return (clamped to the class count).

        Returns:
            Dict with ``image``, ``predicted_class``, ``confidence`` and the
            ranked ``top_k`` list.
        """
        img = self._open_rgb(image_path)
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            logits = self.model(img_tensor)
        probs = F.softmax(logits, dim=1)[0]
        results = self._rank(probs, top_k)
        return {
            "image": image_path,
            "predicted_class": results[0]["class"],
            "confidence": results[0]["confidence"],
            "top_k": results,
        }

    def predict_batch(self, image_paths, top_k=3):
        """Predict each image in *image_paths* independently.

        A failure on one image does not abort the batch. That entry becomes
        ``{"image": path, "error": message}`` instead.
        """
        results = []
        for path in image_paths:
            try:
                results.append(self.predict(path, top_k))
            except Exception as e:  # deliberate per-image best effort
                results.append({"image": path, "error": str(e)})
        return results

    def _last_conv_layer(self):
        """Return the last ``nn.Conv2d`` registered under ``self.model.backbone``.

        Assumes ``build_model`` returns a model exposing a ``backbone``
        submodule (TODO confirm for every architecture).

        Raises:
            ValueError: if the backbone contains no ``Conv2d`` layer.
        """
        last_conv = None
        for module in self.model.backbone.modules():
            if isinstance(module, torch.nn.Conv2d):
                last_conv = module
        if last_conv is None:
            raise ValueError("No Conv2d layer found in model.backbone; cannot run Grad-CAM")
        return last_conv

    def predict_with_gradcam(self, image_path):
        """Predict a single image and compute a Grad-CAM overlay for the predicted class.

        Falls back to :meth:`predict` (with no overlay) when
        ``pytorch-grad-cam`` is not installed.

        Returns:
            Dict with ``predicted_class``, ``confidence`` and ``gradcam_image``
            (RGB overlay at the model's input resolution, as returned by
            ``show_cam_on_image``).
        """
        try:
            from pytorch_grad_cam import GradCAM
            from pytorch_grad_cam.utils.image import show_cam_on_image
            from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
        except ImportError:
            print("Install pytorch-grad-cam: pip install pytorch-grad-cam")
            return self.predict(image_path)

        img = self._open_rgb(image_path)
        img_tensor = self.transform(img).unsqueeze(0).to(self.device)
        target_layers = [self._last_conv_layer()]

        # Plain prediction pass; GradCAM runs its own gradient-enabled pass below.
        with torch.no_grad():
            logits = self.model(img_tensor)
        probs = F.softmax(logits, dim=1)[0]
        pred_idx = int(probs.argmax())

        cam = GradCAM(model=self.model, target_layers=target_layers)
        targets = [ClassifierOutputTarget(pred_idx)]
        grayscale_cam = cam(input_tensor=img_tensor, targets=targets)[0]

        # Overlay on the same resized/cropped view fed to the model so the
        # heatmap and the image have identical shape and alignment.
        display = self.display_transform(img)
        img_np = np.asarray(display, dtype=np.float32) / 255.0
        visualization = show_cam_on_image(img_np, grayscale_cam, use_rgb=True)

        return {
            "predicted_class": self.idx_to_class[pred_idx],
            "confidence": float(probs[pred_idx]),
            "gradcam_image": visualization,
        }
# File extensions treated as images when scanning --image_dir.
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp")


def gradcam_output_path(image_path):
    """Return the path for the Grad-CAM overlay saved next to *image_path*.

    Only the final extension is affected: ``./imgs/a.b.jpg`` becomes
    ``./imgs/a.b_gradcam.jpg``. A path without an extension gets ``.png`` so
    PIL can infer the output format and the input is never overwritten.
    """
    root, ext = os.path.splitext(image_path)
    return f"{root}_gradcam{ext or '.png'}"


def main():
    """Command-line entry point: classify one image or every image in a folder."""
    parser = argparse.ArgumentParser(description="Rangoli Inference")
    parser.add_argument("--image", type=str, help="Path to single image")
    parser.add_argument("--image_dir", type=str, help="Path to image directory")
    parser.add_argument("--checkpoint", type=str, required=True, help="Model checkpoint path")
    parser.add_argument("--config", type=str, default="configs/config.yaml")
    parser.add_argument("--top_k", type=int, default=3)
    parser.add_argument("--gradcam", action="store_true")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    args = parser.parse_args()

    # Validate inputs before the (potentially slow) model load.
    if not args.image and not args.image_dir:
        parser.error("one of --image or --image_dir is required")
    if not args.image and not os.path.isdir(args.image_dir):
        parser.error(f"--image_dir is not a directory: {args.image_dir}")

    predictor = RangoliPredictor(args.checkpoint, args.config)

    if args.image:  # --image takes precedence when both are given
        if args.gradcam:
            result = predictor.predict_with_gradcam(args.image)
            print(f"\nPrediction: {result['predicted_class']} "
                  f"({result['confidence']*100:.1f}%)")
            # Absent when pytorch-grad-cam is missing and predict() was used instead.
            if "gradcam_image" in result:
                save_path = gradcam_output_path(args.image)
                Image.fromarray(result["gradcam_image"]).save(save_path)
                print(f"Grad-CAM saved: {save_path}")
        else:
            result = predictor.predict(args.image, args.top_k)
            print(f"\nImage: {result['image']}")
            print(f"Prediction: {result['predicted_class']} ({result['confidence']*100:.1f}%)")
            # Use the actual count: top_k is clamped to the number of classes.
            print(f"\nTop-{len(result['top_k'])}:")
            for r in result["top_k"]:
                print(f" {r['class']:25s} : {r['confidence']*100:.1f}%")
        # The Grad-CAM overlay is a NumPy array and cannot be written to JSON.
        output = {k: v for k, v in result.items() if k != "gradcam_image"}
    else:
        image_paths = [
            os.path.join(args.image_dir, name)
            for name in sorted(os.listdir(args.image_dir))
            if name.lower().endswith(_IMAGE_EXTENSIONS)
        ]
        if not image_paths:
            print(f"No images found in {args.image_dir}")
        output = predictor.predict_batch(image_paths, args.top_k)
        for r in output:
            if "error" in r:
                print(f" ERROR: {r['image']} -> {r['error']}")
            else:
                print(f" {os.path.basename(r['image']):30s} -> "
                      f"{r['predicted_class']:25s} ({r['confidence']*100:.1f}%)")

    if args.output:
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(output, f, indent=2)
        print(f"\nResults saved: {args.output}")
if __name__ == "__main__":
    # Launch the CLI only when run as a script; importing this module has no side effects here.
    main()