import os

import cv2
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms


# ---------------------------------------------------------
# 1. RE-DEFINE THE MODEL
# ---------------------------------------------------------
# The module/attribute names below form the keys of the model's state_dict
# and must match the saved checkpoint; do not rename them.
class BottleneckBlock(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand, plus a shortcut.

    Args:
        in_channels: number of channels of the block input.
        mid_channels: width of the 1x1 and 3x3 bottleneck convolutions; the
            block outputs ``mid_channels * expansion`` channels.
        stride: spatial stride, applied by the 3x3 conv and by the projection
            shortcut (if one is created).

    The shortcut is the identity (an empty ``nn.Sequential``) unless the stride
    or the channel count changes, in which case a strided 1x1 conv followed by
    BatchNorm projects the input to the output shape so the two can be added.
    """

    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super().__init__()
        out_channels = mid_channels * self.expansion

        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        # In-place add/ReLU act on bn3's output tensor, never on `x` itself.
        out += self.shortcut(x)
        return self.relu(out)


class ResNet50(nn.Module):
    """ResNet-50 classifier: conv stem, 3/4/6/3 bottleneck stages, avg-pool, linear head.

    Args:
        num_classes: number of output logits.
        channels_img: number of channels of the input image.
    """

    def __init__(self, num_classes=16, channels_img=3):
        super().__init__()
        # Running channel count; _make_layer advances it as stages are added.
        self.in_channels = 64

        # Stem: 7x7/2 conv -> BN -> ReLU -> 3x3/2 max-pool.
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
        self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
        self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
        self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BottleneckBlock.expansion, num_classes)

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Build one stage of ``num_blocks`` bottleneck blocks.

        Only the first block uses ``stride`` (and may get a projection
        shortcut); the remaining blocks use stride 1. Side effect: advances
        ``self.in_channels`` to this stage's output width for the next stage.
        """
        layers = [BottleneckBlock(self.in_channels, mid_channels, stride)]
        self.in_channels = mid_channels * BottleneckBlock.expansion
        layers.extend(
            BottleneckBlock(self.in_channels, mid_channels, stride=1)
            for _ in range(num_blocks - 1)
        )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


# ---------------------------------------------------------
# 2. GRAD-CAM LOGIC
# ---------------------------------------------------------
class GradCAM:
    """Grad-CAM heatmap generator for one layer of a model.

    A forward hook on ``target_layer`` stores its output activations, and a
    full backward hook stores the gradient of the chosen class score with
    respect to that output. The heatmap is the ReLU of the channel-wise sum of
    activations weighted by their spatially averaged gradients. It is resized
    to the input's spatial size and min-max normalised (all zeros if the map
    is constant). Call :meth:`remove_hooks` when finished to detach the hooks
    from ``target_layer``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached later.
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: remember the target layer's output tensor."""
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember d(class score) / d(target layer output)."""
        self.gradients = grad_output[0]

    def remove_hooks(self):
        """Detach both hooks from ``target_layer`` (safe to call repeatedly)."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def __call__(self, x, class_idx=None):
        """Compute the Grad-CAM heatmap for the first sample of ``x``.

        Args:
            x: input batch; indexed as ``x.shape[-2:] == (H, W)``. Only sample 0
                is explained.
            class_idx: class to explain; defaults to the top-scoring class of
                sample 0.

        Returns:
            ``(cam, class_idx)``: a float32 array of shape ``(H, W)`` with values
            in ``[0, 1]``, and the explained class index as an ``int``.

        Raises:
            RuntimeError: if the hooks did not fire during the forward/backward
                pass (e.g. ``target_layer`` is not used by ``model``).
        """
        # Clear results of any earlier call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        output = self.model(x)
        if class_idx is None:
            # Argmax over sample 0 only: a per-batch argmax would make the
            # selected score non-scalar for batch sizes > 1.
            class_idx = output[0].argmax()
        class_idx = int(class_idx)

        self.model.zero_grad()
        output[0, class_idx].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not fire; is target_layer part of the model's forward pass?"
            )

        # detach().cpu() (rather than .data) also works for CUDA tensors.
        gradients = self.gradients.detach().cpu().numpy()[0]
        activations = self.activations.detach().cpu().numpy()[0]

        # Channel importance = gradient averaged over the spatial axes.
        weights = gradients.mean(axis=(1, 2))
        # Weighted sum over channels: (C,) . (C, h, w) -> (h, w).
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32, copy=False)
        cam = np.maximum(cam, 0)

        # Upsample to the input resolution; cv2 takes dsize as (width, height).
        height, width = x.shape[-2:]
        cam = cv2.resize(cam, (int(width), int(height)))

        cam = cam - cam.min()
        peak = cam.max()
        if peak != 0:
            cam = cam / peak
        return cam, class_idx
# ---------------------------------------------------------
# 3. RUN IT
# ---------------------------------------------------------
# Label for each output logit, in index order. The model is built with one
# output per label so the two lists cannot drift apart.
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
    'invoice', 'letter', 'memo', 'news article', 'presentation',
    'questionnaire', 'resume', 'scientific publication', 'scientific report',
    'specification',
]

# Side length (pixels) the image is resized to for the model and the overlay.
INPUT_SIZE = 224

model = ResNet50(num_classes=len(class_names))


def _extract_state_dict(checkpoint):
    """Return the model weights contained in a raw ``torch.load`` result.

    Accepts a bare state_dict, or a dict wrapping it under ``'state_dict'`` or
    ``'model_state_dict'``. A ``'module.'`` key prefix (as written by
    ``nn.DataParallel`` / DDP) is stripped when every key carries it.
    """
    if isinstance(checkpoint, dict):
        for key in ('state_dict', 'model_state_dict'):
            if key in checkpoint:
                checkpoint = checkpoint[key]
                break
    prefix = 'module.'
    if (
        isinstance(checkpoint, dict)
        and checkpoint
        and all(isinstance(k, str) and k.startswith(prefix) for k in checkpoint)
    ):
        checkpoint = {k[len(prefix):]: v for k, v in checkpoint.items()}
    return checkpoint


# The checkpoint is looked up relative to the current working directory.
checkpoint_path = "resnet50_epoch_4.pth"
if not os.path.exists(checkpoint_path):
    # SystemExit(msg) prints msg to stderr and exits with status 1; unlike the
    # interactive exit() helper it does not depend on the `site` module.
    raise SystemExit(f"CRITICAL ERROR: '{checkpoint_path}' not found in {os.getcwd()}")

print(f"Loading model from: {checkpoint_path}")
try:
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, which can execute code embedded in the file. Presumably needed
    # because this checkpoint stores non-tensor objects that the weights_only
    # unpickler rejects -- only load checkpoints from a trusted source, and
    # switch to weights_only=True once the checkpoint is re-saved as a plain
    # state_dict.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    model.load_state_dict(_extract_state_dict(checkpoint))
except Exception as e:  # top-level boundary: report and abort with a failure status
    raise SystemExit(f"Error loading weights: {e}") from e
print("Model loaded successfully.")

model.eval()

# Hook the final 1x1 conv of the last bottleneck block.
# NOTE(review): this captures conv3's raw output (before bn3, the residual add
# and the final ReLU); many Grad-CAM setups hook the whole block instead
# (model.layer4[-1]) -- confirm which feature map is intended.
target_layer = model.layer4[-1].conv3
grad_cam = GradCAM(model, target_layer)

# --- IMAGE LOADING ---
image_path = "examples/email.png"
if not os.path.exists(image_path):
    raise SystemExit(f"Error: Image '{image_path}' not found. Please check the path.")

# Convert inside the `with` so the underlying file handle is closed promptly.
with Image.open(image_path) as img:
    original_image = img.convert('RGB')

preprocess = transforms.Compose([
    transforms.Resize((INPUT_SIZE, INPUT_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(original_image).unsqueeze(0)  # add a batch dimension

# Generate
heatmap, class_id = grad_cam(input_tensor)
predicted_label = class_names[class_id]

# Save: colourise the [0, 1] heatmap and blend it over the resized image.
original_cv = cv2.cvtColor(
    np.array(original_image.resize((INPUT_SIZE, INPUT_SIZE))), cv2.COLOR_RGB2BGR
)
heatmap = (255 * heatmap).astype(np.uint8)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
# addWeighted requires both images to have identical dimensions.
if heatmap.shape[:2] != original_cv.shape[:2]:
    heatmap = cv2.resize(heatmap, (original_cv.shape[1], original_cv.shape[0]))
superimposed = cv2.addWeighted(original_cv, 0.6, heatmap, 0.4, 0)

output_filename = "gradcam_result.jpg"
# cv2.imwrite reports failure via its return value rather than raising.
if not cv2.imwrite(output_filename, superimposed):
    raise SystemExit(f"Error: could not write '{output_filename}'.")
print(f"SUCCESS! Visualization saved to {output_filename}")
print(f"Model Predicted: {predicted_label}")