"""
Grad-CAM Implementation for Crop Disease Detection using pytorch-grad-cam

Generates visual explanations showing which parts of the leaf image the model focuses on.
"""

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from pathlib import Path
import base64
import io
import os
import sys

try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
    PYTORCH_GRAD_CAM_AVAILABLE = True
except ImportError as e:
    print(f"Warning: pytorch-grad-cam not available: {e}")
    PYTORCH_GRAD_CAM_AVAILABLE = False


# Inference preprocessing: spatial size and ImageNet normalization statistics.
# These should match the transforms used during training.
_INPUT_SIZE = (224, 224)
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Class list used when a checkpoint does not carry its own 'class_names'.
DEFAULT_CLASS_NAMES = (
    'Corn___Cercospora_leaf_spot_Gray_leaf_spot',
    'Corn___Common_rust',
    'Corn___healthy',
    'Corn___Northern_Leaf_Blight',
    'Potato___Early_Blight',
    'Potato___healthy',
    'Potato___Late_Blight',
    'Tomato___Bacterial_spot',
    'Tomato___Early_blight',
    'Tomato___healthy',
    'Tomato___Late_blight',
    'Tomato___Leaf_Mold',
    'Tomato___Septoria_leaf_spot',
    'Tomato___Spider_mites_Two_spotted_spider_mite',
    'Tomato___Target_Spot',
    'Tomato___Tomato_mosaic_virus',
    'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
)


def _to_numpy(cam):
    """Return ``cam`` as a NumPy array if it is a torch tensor, otherwise unchanged."""
    if isinstance(cam, torch.Tensor):
        return cam.detach().cpu().numpy()
    return cam


def _to_pil(cam_image):
    """Convert an overlay array to a PIL image without double-scaling uint8 data."""
    if cam_image.dtype == np.uint8:
        return Image.fromarray(cam_image)
    # Float overlay assumed to be in [0, 1]; clip guards against slight overshoot.
    return Image.fromarray((np.clip(cam_image, 0, 1) * 255).astype(np.uint8))


def _encode_jpeg_base64(image):
    """Return ``image`` JPEG-encoded and base64-encoded as an ASCII string."""
    buffer = io.BytesIO()
    image.save(buffer, format='JPEG')
    return base64.b64encode(buffer.getvalue()).decode()


def _release_cam(cam):
    """Remove the hooks a pytorch-grad-cam object registered on the model.

    The library also does this in ``__del__``; calling it explicitly makes the
    cleanup deterministic instead of depending on garbage-collection timing.
    """
    hooks = getattr(cam, 'activations_and_grads', None)
    if hooks is not None and hasattr(hooks, 'release'):
        hooks.release()


class CropDiseaseExplainer:
    """High-level interface for crop disease explanation using pytorch-grad-cam"""

    def __init__(self, model, class_names, device='cpu'):
        """
        Initialize explainer

        Args:
            model: Trained model
            class_names: List of class names
            device: Device to run on
        """
        self.model = model.to(device)
        self.class_names = class_names
        self.device = device

        # Define target layer for Grad-CAM (last convolutional stage)
        target_layers = []
        if hasattr(model, 'resnet') and hasattr(model.resnet, 'layer4'):
            # For our CropDiseaseResNet50 wrapper model
            target_layers = [model.resnet.layer4[-1]]
            print("Using target layer: model.resnet.layer4[-1]")
        elif hasattr(model, 'layer4'):
            # For standard torchvision ResNet
            target_layers = [model.layer4[-1]]
            print("Using target layer: model.layer4[-1]")
        else:
            # Generic fallback: the last Conv2d in registration order.
            last_name, last_conv = None, None
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Conv2d):
                    last_name, last_conv = name, module
            if last_conv is not None:
                target_layers = [last_conv]
                print(f"Using target layer: {last_name}")

        self.target_layers = target_layers

        if not target_layers:
            print("Warning: Could not find suitable target layer for Grad-CAM")
            self.grad_cam = None
            return

        # Initialize Grad-CAM
        if PYTORCH_GRAD_CAM_AVAILABLE:
            try:
                self.grad_cam = GradCAM(model=self.model, target_layers=self.target_layers)
                print("✅ Grad-CAM initialized successfully")
            except Exception as e:
                print(f"Error initializing Grad-CAM: {e}")
                self.grad_cam = None
        else:
            self.grad_cam = None
            print("Warning: pytorch-grad-cam not available, Grad-CAM disabled")

    def explain_prediction(self, image_path, save_dir='outputs/heatmaps',
                           return_base64=False, target_class=None):
        """
        Generate complete explanation for an image

        Args:
            image_path: Path to input image
            save_dir: Directory to save explanations (created if missing)
            return_base64: Whether to also return the overlay as base64-encoded JPEG
            target_class: Specific class index to target (if None, uses predicted class)

        Returns:
            explanation: Dictionary with 'predicted_class', 'predicted_idx',
            'confidence', 'target_class', 'target_idx', 'save_path', 'cam_image'
            (PIL image) and optionally 'overlay_base64'; or ``{'error': message}``
            if Grad-CAM is unavailable or heatmap generation fails.
        """
        if not PYTORCH_GRAD_CAM_AVAILABLE or self.grad_cam is None:
            return {'error': 'Grad-CAM not available'}

        original_image = self._load_rgb_image(image_path)
        input_tensor = self._build_transform()(original_image).unsqueeze(0).to(self.device)
        predicted_idx, confidence = self._predict(input_tensor)

        # Use target class if specified, otherwise use predicted class
        target_idx = target_class if target_class is not None else predicted_idx
        targets = [ClassifierOutputTarget(target_idx)]

        try:
            # The overlay base must be the same size as the model input, scaled to [0, 1].
            original_resized = np.array(original_image.resize(_INPUT_SIZE)) / 255.0

            print(f"Input tensor shape: {input_tensor.shape}")
            print(f"Targets: {targets}")

            grayscale_cam, error = self._compute_cam(input_tensor, targets)
            if error is not None:
                return {'error': error}

            cam_image = show_cam_on_image(original_resized, grayscale_cam, use_rgb=True)
            cam_pil = _to_pil(cam_image)

            Path(save_dir).mkdir(parents=True, exist_ok=True)
            save_path = Path(save_dir) / f"{Path(image_path).stem}_gradcam.jpg"
            cam_pil.save(save_path)

            result = {
                'predicted_class': self.class_names[predicted_idx],
                'predicted_idx': predicted_idx,
                'confidence': confidence,
                'target_class': self.class_names[target_idx],
                'target_idx': target_idx,
                'save_path': str(save_path),
                'cam_image': cam_pil,
            }
            if return_base64:
                result['overlay_base64'] = _encode_jpeg_base64(cam_pil)
            return result

        except Exception as e:
            print(f"Error generating Grad-CAM: {e}")
            return {'error': str(e)}

    @staticmethod
    def _load_rgb_image(image_path):
        """Open ``image_path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(image_path) as img:
            return img.convert('RGB')

    @staticmethod
    def _build_transform():
        """Return the inference transform (resize, to-tensor, ImageNet normalization)."""
        # Imported lazily so the module loads even where torchvision is absent.
        from torchvision import transforms
        return transforms.Compose([
            transforms.Resize(_INPUT_SIZE),
            transforms.ToTensor(),
            transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ])

    def _predict(self, input_tensor):
        """Return ``(predicted_idx, confidence)`` for the first image in ``input_tensor``."""
        self.model.eval()
        with torch.no_grad():
            probabilities = F.softmax(self.model(input_tensor), dim=1)
        predicted_idx = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0][predicted_idx].item()
        return predicted_idx, confidence

    def _compute_cam(self, input_tensor, targets):
        """Run Grad-CAM, switching to the fallback layer if the primary result is unusable.

        Returns:
            ``(cam, None)`` where ``cam`` is the heatmap of the first (only) image
            in the batch, or ``(None, message)`` on failure.
        """
        grayscale_cam = self.grad_cam(input_tensor=input_tensor, targets=targets)

        if grayscale_cam is None:
            print("Grad-CAM returned None")
            grayscale_cam = self._try_fallback_cam(input_tensor, targets)
            if grayscale_cam is None:
                return None, 'Failed to generate Grad-CAM heatmap'

        grayscale_cam = _to_numpy(grayscale_cam)
        if not hasattr(grayscale_cam, 'shape'):
            print("Grad-CAM result has no shape attribute")
            return None, 'Invalid Grad-CAM heatmap shape'

        print(f"Generated CAM type: {type(grayscale_cam)}")
        print(f"Generated CAM shape: {grayscale_cam.shape}")

        if grayscale_cam.size == 0:
            grayscale_cam = self._try_fallback_cam(input_tensor, targets)
            if grayscale_cam is None or grayscale_cam.size == 0:
                return None, 'Failed to generate Grad-CAM heatmap'

        return grayscale_cam[0, :], None

    def _last_bottleneck(self):
        """Return the final block of ``layer4`` for the supported ResNet layouts, else None."""
        try:
            if hasattr(self.model, 'resnet') and hasattr(self.model.resnet, 'layer4'):
                return self.model.resnet.layer4[-1]
            if hasattr(self.model, 'layer4'):
                return self.model.layer4[-1]
        except (TypeError, IndexError):
            # layer4 is not an indexable, non-empty container.
            pass
        return None

    def _try_fallback_cam(self, input_tensor, targets):
        """Try the last bottleneck's ``conv3`` as target layer if the primary attempt fails.

        Returns:
            The CAM as a NumPy array, or None if no fallback layer exists or the
            fallback computation fails.
        """
        if not PYTORCH_GRAD_CAM_AVAILABLE:
            return None

        bottleneck = self._last_bottleneck()
        if bottleneck is None or not hasattr(bottleneck, 'conv3'):
            return None
        fallback_layers = [bottleneck.conv3]

        print("Trying fallback Grad-CAM target layer (conv3 of last bottleneck)...")
        cam = None
        try:
            cam = GradCAM(model=self.model, target_layers=fallback_layers)
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
        except Exception as e:
            # Best-effort fallback: report why it failed instead of hiding it.
            print(f"Fallback Grad-CAM failed: {e}")
            return None
        finally:
            if cam is not None:
                _release_cam(cam)

        if grayscale_cam is None:
            return None
        return _to_numpy(grayscale_cam)


def load_model_and_generate_gradcam(model_path, image_path, output_path=None, target_class=None):
    """
    Complete example function that loads a model and generates Grad-CAM visualization

    Args:
        model_path: Path to the saved model file (raw state_dict, or a dict with
            'model_state_dict' and optionally 'class_names')
        image_path: Path to input image
        output_path: Path to save the output (optional; parent dirs are created)
        target_class: Target class index (optional, uses prediction if None)

    Returns:
        Dictionary with results (see ``CropDiseaseExplainer.explain_prediction``)
    """
    # Make the sibling ``model`` module importable regardless of the working directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    if module_dir not in sys.path:
        sys.path.append(module_dir)
    from model import CropDiseaseResNet50

    class_names = list(DEFAULT_CLASS_NAMES)

    # Step 1: Load the trained model
    print(f"Loading model from {model_path}...")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device)

    # Handle checkpoint format
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
        if 'class_names' in checkpoint:
            class_names = checkpoint['class_names']
    else:
        state_dict = checkpoint

    # Build the model only once class_names is final, so the classifier head
    # size matches the checkpoint being loaded.
    model = CropDiseaseResNet50(num_classes=len(class_names), pretrained=False)
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    print("✅ Model loaded successfully!")

    # Step 2: Initialize Grad-CAM explainer
    print("Initializing Grad-CAM explainer...")
    explainer = CropDiseaseExplainer(model, class_names, device)

    # Step 3: Generate Grad-CAM visualization
    print(f"Generating Grad-CAM for {image_path}...")
    result = explainer.explain_prediction(
        image_path=image_path,
        save_dir='outputs/heatmaps',
        return_base64=True,
        target_class=target_class,
    )

    if 'error' in result:
        print(f"❌ Error: {result['error']}")
        return result

    # Step 4: Save output if path specified
    if output_path:
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        result['cam_image'].save(output_path)
        print(f"✅ Saved Grad-CAM visualization to {output_path}")

    # Print results
    print("✅ Grad-CAM generated successfully!")
    print(f"   Predicted: {result['predicted_class']} ({result['confidence']:.1%})")
    print(f"   Target: {result['target_class']}")
    print(f"   Saved to: {result['save_path']}")

    return result


# Example usage
if __name__ == "__main__":
    model_path = "../models/crop_disease_v3_model.pth"
    image_path = "../test_leaf_sample.jpg"
    output_path = "../outputs/gradcam_example.jpg"

    if os.path.exists(model_path) and os.path.exists(image_path):
        result = load_model_and_generate_gradcam(
            model_path=model_path,
            image_path=image_path,
            output_path=output_path,
        )
    else:
        print("Model or image file not found!")
        print(f"Model path: {model_path}")
        print(f"Image path: {image_path}")