File size: 4,258 Bytes
789eef1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Use the pytorch-grad-cam tool to visualize Class Activation Maps (CAM).
Requires the ``grad-cam`` package: ``pip install grad-cam``.
"""
from argparse import ArgumentParser
import numpy as np
import torch
import torch.nn.functional as F
from mmengine import Config
from mmengine.model import revert_sync_batchnorm
from PIL import Image
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import preprocess_image, show_cam_on_image
from mmseg.apis import inference_model, init_model, show_result_pyplot
from mmseg.utils import register_all_modules
class SemanticSegmentationTarget:
    """Grad-CAM target that scores one class inside a spatial mask.

    Calling an instance on a single model output resizes that output to
    ``size`` and returns the sum of the ``category`` channel weighted by
    ``mask``.

    requirement: pip install grad-cam

    Args:
        category (int): Visualization class.
        mask (ndarray): Mask of class. It is multiplied element-wise with
            the resized ``category`` channel, so it is assumed to have the
            spatial shape given by ``size`` (or broadcast to it).
        size (tuple): Image size, passed as ``size`` to ``F.interpolate``.
    """

    def __init__(self, category, mask, size):
        self.category = category
        # Stay on the CPU here: the device the model runs on is only known
        # once an output is seen, and it need not be the default CUDA device
        # even when CUDA is available.
        self.mask = torch.from_numpy(mask)
        self.size = size

    def __call__(self, model_output):
        """Return the masked sum of the target channel.

        Args:
            model_output (Tensor): Output for a single sample. It is indexed
                as ``[category, :, :]`` after resizing, i.e. it is expected
                to be channel-first without a batch dimension.

        Returns:
            Tensor: Zero-dimensional tensor with the masked sum.
        """
        # F.interpolate works on batched input, so add a batch dimension for
        # the resize and drop it again afterwards.
        resized = F.interpolate(
            model_output.unsqueeze(0), size=self.size, mode='bilinear')
        resized = resized.squeeze(0)
        # Align the mask with the output's device and cache it so that
        # repeated calls do not copy again.
        if self.mask.device != resized.device:
            self.mask = self.mask.to(resized.device)
        return (resized[self.category, :, :] * self.mask).sum()
def main():
    """Run inference on one image and save its prediction and Grad-CAM map.

    Command-line arguments select the image, the model config and
    checkpoint, the output paths, the layer expression to visualize, the
    category index and the inference device. The segmentation result is
    drawn with ``show_result_pyplot`` and the CAM overlay is written to
    ``--cam-file``.
    """
    parser = ArgumentParser()
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--out-file',
        default='prediction.png',
        help='Path to output prediction file')
    parser.add_argument(
        '--cam-file', default='vis_cam.png', help='Path to output cam file')
    parser.add_argument(
        '--target-layers',
        default='backbone.layer4[2]',
        help='Target layers to visualize CAM')
    # Parse as int up front so a non-numeric value is rejected by argparse
    # with a usage message before any model is loaded.
    parser.add_argument(
        '--category-index',
        type=int,
        default=7,
        help='Category to visualize CAM')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    args = parser.parse_args()

    # build the model from a config file and a checkpoint file
    register_all_modules()
    model = init_model(args.config, args.checkpoint, device=args.device)
    if args.device == 'cpu':
        model = revert_sync_batchnorm(model)

    # test a single image
    result = inference_model(model, args.img)

    # show the results; only open a window when nothing is written to disk
    show_result_pyplot(
        model,
        args.img,
        result,
        draw_gt=False,
        show=args.out_file is None,
        out_file=args.out_file)

    # result data conversion
    prediction_data = result.pred_sem_seg.data
    pre_np_data = prediction_data.cpu().numpy().squeeze(0)

    # SECURITY: eval() executes arbitrary Python taken from --target-layers.
    # It is kept because the expression mixes attribute access and indexing
    # (e.g. 'backbone.layer4[2]'); only pass trusted values for this option.
    target_layers = [eval(f'model.{args.target_layers}')]

    category = args.category_index
    # 1.0 where the predicted label equals the requested category, else 0.0
    mask_float = np.float32(pre_np_data == category)

    # data processing
    # Close the image file deterministically once the pixels are copied out.
    with Image.open(args.img) as img_file:
        image = np.array(img_file.convert('RGB'))
    height, width = image.shape[0], image.shape[1]
    rgb_img = np.float32(image) / 255
    config = Config.fromfile(args.config)
    image_mean = config.data_preprocessor['mean']
    image_std = config.data_preprocessor['std']
    # rgb_img was scaled by 1/255 above, so the config statistics are scaled
    # by the same factor before being handed to preprocess_image.
    input_tensor = preprocess_image(
        rgb_img,
        mean=[x / 255 for x in image_mean],
        std=[x / 255 for x in image_std])

    # Grad CAM(Class Activation Maps)
    # Can also be LayerCAM, XGradCAM, GradCAMPlusPlus, EigenCAM, EigenGradCAM
    targets = [
        SemanticSegmentationTarget(category, mask_float, (height, width))
    ]
    # NOTE(review): `use_cuda` follows CUDA availability, not --device, and
    # may not be accepted by every grad-cam release - confirm against the
    # installed version.
    with GradCAM(
            model=model,
            target_layers=target_layers,
            use_cuda=torch.cuda.is_available()) as cam:
        grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

    # save cam file
    Image.fromarray(cam_image).save(args.cam_file)
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()
|