Mini-ImageNet / src /visualization /generate_gradcam.py
ImAMJayKIM's picture
Update src/visualization/generate_gradcam.py
2729151 verified
Raw
History Blame Contribute Delete
2.36 kB
import os
import cv2
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import (
show_cam_on_image
)
from src.transforms.image_transform import (
get_classification_valid_transform
)
class SwinClassifierWrapper(nn.Module):
    """Expose ``model.backbone`` followed by ``model.classifier`` as one module.

    Grad-CAM drives the network through a single ``forward(images) -> logits``
    call. This wrapper provides that call for a model that keeps its feature
    extractor and head as separate ``backbone`` / ``classifier`` attributes.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule, so .to()/.eval() on the wrapper reach it.
        self.model = model

    def forward(self, images):
        """Return classifier logits for a batch of images.

        The backbone output is flattened to ``(batch, -1)`` before it is fed
        to the classifier. ``reshape`` is used instead of ``view`` so that a
        non-contiguous backbone output (e.g. one ending in a permute) does not
        raise. For contiguous outputs the result is identical.
        """
        features = self.model.backbone(images)
        features = features.reshape(features.size(0), -1)
        return self.model.classifier(features)
def reshape_transform(tensor):
    """Turn a channels-last 4-D activation into the channels-first layout Grad-CAM uses.

    Swin-T feature maps arrive as (B, H, W, C), while Grad-CAM expects
    (B, C, H, W). A tensor of any other rank is passed back untouched (the
    very same object).
    """
    if tensor.ndim != 4:
        return tensor
    batch_dim, height_dim, width_dim, channel_dim = range(4)
    return tensor.permute(batch_dim, channel_dim, height_dim, width_dim)
def _load_rgb_image(image_path, size):
    """Open ``image_path`` as an RGB PIL image resized to ``size`` (width, height).

    Returns the PIL image together with a float32 array scaled to [0, 1],
    the form ``show_cam_on_image`` takes for its background image.
    """
    # The context manager closes the file handle; convert() yields a loaded copy.
    with Image.open(image_path) as source:
        image = source.convert("RGB").resize(size)
    image_np = np.array(image).astype(np.float32) / 255.0
    return image, image_np


def _write_rgb_image(save_path, rgb_image):
    """Write an RGB image array to ``save_path`` with OpenCV, creating parent dirs.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the write failed.
    """
    parent_dir = os.path.dirname(save_path)
    # A bare filename has no directory part, and os.makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    # OpenCV writes BGR channel order.
    if not cv2.imwrite(save_path, cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)):
        raise OSError(f"cv2.imwrite could not write {save_path!r}")


def save_gradcam(
    model,
    image_path,
    save_path,
    device,
):
    """Render a Grad-CAM overlay for one image and save it to disk.

    The image is resized to 224x224 and passed through ``model.backbone`` and
    ``model.classifier`` (via ``SwinClassifierWrapper``). The class activation
    map taken at ``model.backbone.features[-1][-1].norm2`` is overlaid on the
    resized image. No target is passed to Grad-CAM, so the map is for the
    top-scoring class.

    Args:
        model: Module with ``backbone`` and ``classifier`` submodules; the
            backbone must expose ``features[-1][-1].norm2``.
        image_path: Path of the input image (any format PIL can read).
        save_path: Output path; the extension selects the format for
            ``cv2.imwrite``. Missing parent directories are created.
        device: Device the model and the input tensor are moved to.

    Raises:
        OSError: If the output image cannot be written.

    Every module's train/eval flag and each backbone/classifier parameter's
    ``requires_grad`` are restored before returning, so a frozen backbone
    stays frozen. The model is still left on ``device``. Grad-CAM's backward
    pass may also populate parameters' ``.grad``.
    """
    image, image_np = _load_rgb_image(image_path, (224, 224))
    # NOTE(review): the valid transform runs on the already-resized image. If
    # it crops, the CAM will not line up with the 224x224 overlay — confirm.
    transform = get_classification_valid_transform()
    tensor = transform(image).unsqueeze(0).to(device)

    gradcam_model = SwinClassifierWrapper(model).to(device)

    # Snapshot the state changed below so the caller gets the model back as it came in.
    module_modes = [(module, module.training) for module in model.modules()]
    grad_flags = [
        (param, param.requires_grad)
        for param in [*model.backbone.parameters(), *model.classifier.parameters()]
    ]
    try:
        gradcam_model.eval()
        # Grad-CAM needs gradients at the target layer. With frozen parameters
        # the activations there would not be part of the autograd graph.
        for param, _ in grad_flags:
            param.requires_grad_(True)

        target_layer = model.backbone.features[-1][-1].norm2
        cam = GradCAM(
            model=gradcam_model,
            target_layers=[target_layer],
            reshape_transform=reshape_transform,
        )
        grayscale_cam = cam(input_tensor=tensor)[0]
    finally:
        for param, requires_grad in grad_flags:
            param.requires_grad_(requires_grad)
        for module, training in module_modes:
            module.training = training

    visualization = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    _write_rgb_image(save_path, visualization)