|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import cv2 |
|
|
import os |
|
|
|
|
|
|
|
|
|
|
|
class BottleneckBlock(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    The block emits ``mid_channels * expansion`` channels. The stride is
    applied by the 3x3 convolution. When the stride is not 1 or the channel
    count changes, the skip path is a 1x1 projection (conv + BN). Otherwise it
    is an empty ``nn.Sequential``, which passes its input through unchanged.
    """

    # Channel multiplier between the bottleneck width and the block output.
    expansion = 4

    def __init__(self, in_channels, mid_channels, stride=1):
        super().__init__()
        out_channels = self.expansion * mid_channels

        # Main path. Attribute names are the state_dict keys, so they must not change.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(
            mid_channels, mid_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Skip path: project only when the output shape differs from the input.
        needs_projection = stride != 1 or in_channels != out_channels
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        residual = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        return self.relu(out + residual)
|
|
|
|
|
class ResNet50(nn.Module):
    """ResNet-50 classifier built from ``BottleneckBlock`` stages.

    The network runs in this order:

    - Stem: 7x7/2 convolution -> BN -> ReLU -> 3x3/2 max-pool.
    - Four bottleneck stages with 3, 4, 6 and 3 blocks.
    - Global average pooling.
    - A linear layer producing ``num_classes`` logits.

    Args:
        num_classes: number of output logits.
        channels_img: number of channels in the input image.
    """

    def __init__(self, num_classes=16, channels_img=3):
        super().__init__()
        # Input channel count for the next stage; advanced by _make_layer.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(channels_img, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Bottleneck stages (block counts 3-4-6-3). Every stage except the first downsamples.
        self.layer1 = self._make_layer(mid_channels=64, num_blocks=3, stride=1)
        self.layer2 = self._make_layer(mid_channels=128, num_blocks=4, stride=2)
        self.layer3 = self._make_layer(mid_channels=256, num_blocks=6, stride=2)
        self.layer4 = self._make_layer(mid_channels=512, num_blocks=3, stride=2)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BottleneckBlock.expansion, num_classes)

    def _make_layer(self, mid_channels, num_blocks, stride):
        """Build one stage. Only its first block applies ``stride``.

        A stage always contains at least one block, even if ``num_blocks < 1``.
        """
        block_strides = [stride] + [1] * (num_blocks - 1)
        blocks = []
        for block_stride in block_strides:
            blocks.append(BottleneckBlock(self.in_channels, mid_channels, block_stride))
            self.in_channels = mid_channels * BottleneckBlock.expansion
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Map an image batch to class logits of shape ``(batch, num_classes)``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)
|
|
|
|
|
|
|
|
|
|
|
class GradCAM:
    """Grad-CAM saliency maps for a single layer of ``model``.

    A forward hook on ``target_layer`` records that layer's output
    activations. A backward hook records the gradient of the selected class
    score with respect to those activations. Calling the instance combines the
    two into a map normalized to [0, 1] and upsampled to the input's spatial
    size. Only the first image of the batch is explained.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached with remove_hooks().
        self._hook_handles = [
            target_layer.register_forward_hook(self.save_activation),
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def save_activation(self, module, input, output):
        """Forward hook: store the target layer's output tensor."""
        self.activations = output

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: store the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0]

    def remove_hooks(self):
        """Detach both hooks from ``target_layer``. Safe to call more than once."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    @staticmethod
    def _weighted_activation_map(gradients, activations):
        """Return the ReLU'd Grad-CAM map for one image as a float32 (H, W) array.

        ``gradients`` and ``activations`` are (C, H, W) arrays. Each channel's
        weight is its spatially averaged gradient. The map is the weighted sum
        of the activation channels, with negative values clipped to 0.
        """
        weights = np.mean(gradients, axis=(1, 2))
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)
        return np.maximum(cam, 0)

    def __call__(self, x, class_idx=None):
        """Compute the Grad-CAM map for the first image in ``x``.

        Args:
            x: input batch passed to ``model``. Its last two dims are used as
                the output map size.
            class_idx: class to explain. Defaults to the top-scoring class of
                the first image.

        Returns:
            ``(cam, class_idx)``: ``cam`` is a float32 (H, W) array in [0, 1]
            matching ``x``'s spatial size, and ``class_idx`` is an ``int``.
        """
        output = self.model(x)
        # Work on image 0 only, so the score is a true scalar for any batch
        # size. argmax over dim=1 would give a (B,) tensor that backward()
        # and int() reject when B > 1.
        if class_idx is None:
            class_idx = int(torch.argmax(output[0]))
        else:
            class_idx = int(class_idx)

        self.model.zero_grad()
        output[0, class_idx].backward()

        # detach().cpu() also works when the model runs on a GPU.
        gradients = self.gradients.detach().cpu().numpy()[0]
        activations = self.activations.detach().cpu().numpy()[0]
        cam = self._weighted_activation_map(gradients, activations)

        # cv2.resize takes dsize as (width, height).
        height, width = x.shape[-2:]
        cam = cv2.resize(cam, (int(width), int(height)))
        cam = cam - np.min(cam)
        peak = np.max(cam)
        if peak != 0:  # an all-zero map stays all zeros
            cam = cam / peak
        return cam, class_idx
|
|
|
|
|
|
|
|
|
|
|
# Build the network. Its architecture must match the checkpoint loaded below.
model = ResNet50(num_classes=16)

checkpoint_path = "resnet50_epoch_4.pth"

if not os.path.exists(checkpoint_path):
    print(f"CRITICAL ERROR: '{checkpoint_path}' not found in {os.getcwd()}")
    # Exit with a non-zero status so shells/CI see the failure.
    # A bare exit() would report success (status 0).
    raise SystemExit(1)

try:
    print(f"Loading model from: {checkpoint_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects, which can run
    # code embedded in the file. Only load checkpoints from trusted sources.
    # Prefer weights_only=True if the file holds only tensors and plain containers.
    checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

    # Accept either a raw state_dict or a training checkpoint that wraps it
    # under a 'state_dict' key.
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print("Model loaded successfully.")
except Exception as e:
    # Top-level script boundary: report the failure and stop with an error status.
    print(f"Error loading weights: {e}")
    raise SystemExit(1)
|
|
|
|
|
# Inference mode: BatchNorm layers use their running statistics.
model.eval()

# Explain predictions through the final 1x1 convolution of the last
# bottleneck in layer4. layer4 is built with 3 blocks, so [-1] is the same
# block as index 2.
# NOTE(review): conv3's output is taken before bn3, the residual add and the
# final ReLU. Hooking the whole block (model.layer4[-1]) is a common
# alternative for ResNets. Confirm which one is intended.
target_layer = model.layer4[-1].conv3
grad_cam = GradCAM(model=model, target_layer=target_layer)
|
|
|
|
|
|
|
|
image_path = "examples/email.png"

if not os.path.exists(image_path):
    print(f"Error: Image '{image_path}' not found. Please check the path.")
    # Non-zero status: this is a failure, not a clean exit.
    raise SystemExit(1)

# convert('RGB') returns a fully loaded copy (alpha/palette removed), so the
# source file can be closed as soon as the context manager exits.
with Image.open(image_path) as source_image:
    original_image = source_image.convert('RGB')

# Resize to 224x224 and normalize with ImageNet channel statistics.
# Presumably these match the checkpoint's training preprocessing. TODO confirm.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Add a batch dimension: (C, H, W) -> (1, C, H, W).
input_tensor = preprocess(original_image).unsqueeze(0)
|
|
|
|
|
|
|
|
# Human-readable label for each output index of the 16-way classifier.
# NOTE(review): this list is alphabetical, which matches labels assigned from
# sorted folder names (e.g. torchvision ImageFolder). Verify it matches the
# label order used when the checkpoint was trained.
class_names = [
    'advertisement', 'budget', 'email', 'file folder', 'form', 'handwritten',
    'invoice', 'letter', 'memo', 'news article', 'presentation', 'questionnaire',
    'resume', 'scientific publication', 'scientific report', 'specification',
]

# No class is passed, so Grad-CAM explains the model's top-scoring class.
heatmap, class_id = grad_cam(input_tensor)
predicted_label = class_names[class_id]
|
|
|
|
|
|
|
|
# Turn the [0, 1] Grad-CAM map into an 8-bit JET color image (BGR channel order).
heatmap = np.uint8(255 * heatmap)
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)

# Resize the original image to the heatmap's size so the two can be blended.
heatmap_height, heatmap_width = heatmap.shape[:2]
original_cv = cv2.cvtColor(
    np.array(original_image.resize((heatmap_width, heatmap_height))),
    cv2.COLOR_RGB2BGR,
)
superimposed = cv2.addWeighted(original_cv, 0.6, heatmap, 0.4, 0)

output_filename = "gradcam_result.jpg"
# cv2.imwrite returns False on failure (unwritable path, unsupported
# extension, ...) instead of raising, so check it before reporting success.
if not cv2.imwrite(output_filename, superimposed):
    print(f"Error: could not write visualization to {output_filename}")
    raise SystemExit(1)

print(f"SUCCESS! Visualization saved to {output_filename}")
print(f"Model Predicted: {predicted_label}")