| |
| |
| |
| |
| |
|
|
| import argparse |
| import copy |
| import math |
| import pkg_resources |
| from functools import partial |
| from pathlib import Path |
|
|
| import mmcv |
| import numpy as np |
| import torch.nn as nn |
| from mmcv.transforms import Compose |
| from mmengine.config import Config, DictAction |
| from mmengine.dataset import default_collate |
| from mmengine.utils import to_2tuple |
| from mmengine.utils.dl_utils import is_norm |
|
|
| from mmpretrain import digit_version |
| from mmpretrain.apis import get_model |
| from mmpretrain.registry import TRANSFORMS |
|
|
| try: |
| import pytorch_grad_cam as cam |
| from pytorch_grad_cam.activations_and_gradients import \ |
| ActivationsAndGradients |
| from pytorch_grad_cam.utils.image import show_cam_on_image |
| except ImportError: |
| raise ImportError('Please run `pip install "grad-cam>=1.3.6"` to install ' |
| '3rd party package pytorch_grad_cam.') |
|
|
| |
# Lookup table from lower-cased CAM method name to its pytorch_grad_cam class.
# 'gradcam++' is a hand-written alias; the remaining entries are generated
# from the direct subclasses of ``BaseCAM`` using their lower-cased class
# names.
METHOD_MAP = {
    'gradcam++': cam.GradCAMPlusPlus,
    **{
        subclass.__name__.lower(): subclass
        for subclass in cam.base_cam.BaseCAM.__subclasses__()
    },
}
|
|
|
|
def parse_args():
    """Parse and validate the command line arguments of the CAM tool.

    Returns:
        argparse.Namespace: The parsed arguments.

    Raises:
        ValueError: If ``--method`` (compared case-insensitively) is not a
            key of ``METHOD_MAP``.
    """
    supported_methods = ', '.join(METHOD_MAP)

    parser = argparse.ArgumentParser(description='Visualize CAM')
    parser.add_argument('img', help='Image file')
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--target-layers',
        default=[],
        nargs='+',
        type=str,
        help='The target layers to get CAM, if not set, the tool will '
        'specify the norm layer in the last block. Backbones '
        'implemented by users are recommended to manually specify'
        ' target layers in the command line.')
    parser.add_argument(
        '--preview-model',
        default=False,
        action='store_true',
        help='To preview all the model layers')
    parser.add_argument(
        '--method',
        default='GradCAM',
        help='Type of method to use, supports '
        f'{supported_methods}.')
    parser.add_argument(
        '--target-category',
        default=[],
        nargs='+',
        type=int,
        help='The target category to get CAM, default to use result '
        'get from given model.')
    parser.add_argument(
        '--eigen-smooth',
        default=False,
        action='store_true',
        help='Reduce noise by taking the first principal component of '
        '``cam_weights*activations``')
    parser.add_argument(
        '--aug-smooth',
        default=False,
        action='store_true',
        help='Whether to use test time augmentation, default not to use')
    parser.add_argument(
        '--save-path',
        type=Path,
        help='The path to save visualize cam image, default not to save.')
    parser.add_argument(
        '--device',
        default='cpu',
        help='Device to use, e.g. "cpu" or "cuda:0". Defaults to "cpu".')
    parser.add_argument(
        '--vit-like',
        action='store_true',
        help='Whether the network is a ViT-like network.')
    parser.add_argument(
        '--num-extra-tokens',
        type=int,
        help='The number of extra tokens in ViT-like backbones. Defaults to'
        ' use num_extra_tokens of the backbone.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    args = parser.parse_args()

    # Method names are matched case-insensitively against METHOD_MAP.
    if args.method.lower() not in METHOD_MAP:
        raise ValueError(f'invalid CAM type {args.method},'
                         f' supports {supported_methods}.')

    return args
|
|
|
|
def reshape_transform(tensor, model, args):
    """Build reshape_transform for `cam.activations_and_grads`, which is
    necessary for ViT-like networks.

    Args:
        tensor: The hooked tensor. A 4-D tensor is returned untouched. A 3-D
            tensor is sliced along dim 1 to drop the leading extra tokens
            and the remaining entries are folded into a square map.
        model: The model; ``model.backbone.num_extra_tokens`` is used as the
            fallback number of extra tokens (1 if the attribute is missing).
        args: Parsed arguments providing ``vit_like`` and
            ``num_extra_tokens``.

    Returns:
        The input itself for 4-D tensors, otherwise a tensor of shape
        ``(tensor.size(0), tensor.size(2), height, width)``.

    Raises:
        ValueError: If the tensor is 3-D but ``--vit-like`` was not set, or
            if the tensor is neither 3-D nor 4-D.
        AssertionError: If the number of remaining tokens is not a perfect
            square.
    """
    if tensor.ndim == 4:
        # Already a spatial feature map, nothing to do.
        return tensor

    if tensor.ndim != 3:
        raise ValueError(f'Unsupported tensor shape {tensor.shape}.')

    if not args.vit_like:
        raise ValueError(f"The tensor shape is {tensor.shape}, if it's a "
                         'vit-like backbone, please specify `--vit-like`.')

    # Only fall back to the backbone attribute when the option was not
    # given at all. A plain `or` would silently discard an explicit
    # `--num-extra-tokens 0`.
    num_extra_tokens = args.num_extra_tokens
    if num_extra_tokens is None:
        num_extra_tokens = getattr(model.backbone, 'num_extra_tokens', 1)

    # Drop the extra tokens at the front of dim 1.
    tensor = tensor[:, num_extra_tokens:, :]

    # The remaining tokens must form a square grid.
    heat_map_area = tensor.size()[1]
    height = width = int(math.sqrt(heat_map_area))
    assert height * height == heat_map_area, \
        (f"The input feature's length ({heat_map_area+num_extra_tokens}) "
         f'minus num-extra-tokens ({num_extra_tokens}) is {heat_map_area},'
         ' which is not a perfect square number. Please check if you used '
         'a wrong num-extra-tokens.')

    # (B, H*W, C) -> (B, H, W, C) -> (B, C, H, W)
    result = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
    return result.permute(0, 3, 1, 2)
|
|
|
|
def init_cam(method, model, target_layers, use_cuda, reshape_transform):
    """Construct the CAM object once. In order to be compatible with
    mmpretrain, here we modify the ActivationsAndGradients object.

    Args:
        method (str): CAM method name, looked up case-insensitively in
            ``METHOD_MAP``.
        model: The model to explain.
        target_layers (list): Layers whose activations and gradients are
            hooked.
        use_cuda (bool): Forwarded to the CAM constructor when that
            constructor accepts a ``use_cuda`` argument.
        reshape_transform (callable): Transform handed to the replacement
            ``ActivationsAndGradients`` object.

    Returns:
        The constructed CAM object.
    """
    GradCAM_Class = METHOD_MAP[method.lower()]
    try:
        cam_obj = GradCAM_Class(
            model=model, target_layers=target_layers, use_cuda=use_cuda)
    except TypeError:
        # Newer grad-cam releases no longer accept `use_cuda`, so retry
        # without it. A TypeError with any other cause is raised again by
        # this second call.
        cam_obj = GradCAM_Class(model=model, target_layers=target_layers)

    # Remove the hooks registered by the constructor before installing a
    # fresh ActivationsAndGradients that uses our reshape_transform.
    cam_obj.activations_and_grads.release()
    cam_obj.activations_and_grads = ActivationsAndGradients(
        cam_obj.model, cam_obj.target_layers, reshape_transform)

    return cam_obj
|
|
|
|
def get_layer(layer_str, model):
    """Look up a sub-module of ``model`` by its ``named_modules()`` name.

    Args:
        layer_str (str): Exact module name to search for.
        model: Object exposing ``named_modules()``.

    Returns:
        The first module whose name equals ``layer_str``.

    Raises:
        AttributeError: If no module has that name; the message lists all
            available names.
    """
    not_found = object()
    match = next(
        (module for module_name, module in model.named_modules()
         if module_name == layer_str),
        not_found,
    )
    if match is not not_found:
        return match

    available = '\n'.join(
        module_name for module_name, _ in model.named_modules())
    raise AttributeError(
        f'Cannot get the layer "{layer_str}". Please choose from: \n' +
        available)
|
|
|
|
def show_cam_grad(grayscale_cam, src_img, title, out_path=None):
    """Blend the first CAM of the batch onto ``src_img`` and save or show it.

    Args:
        grayscale_cam: CAM array; only index 0 along the first axis is used.
        src_img: Source image, divided by 255 before blending.
        title (str): Window name used when the result is displayed.
        out_path: If truthy, the blended image is written there instead of
            being displayed.
    """
    heatmap = grayscale_cam[0, :]
    scaled_img = np.float32(src_img) / 255
    overlay = show_cam_on_image(scaled_img, heatmap, use_rgb=False)

    if not out_path:
        mmcv.imshow(overlay, win_name=title)
        return
    mmcv.imwrite(overlay, str(out_path))
|
|
|
|
def get_default_target_layers(model, args):
    """Get default target layers from given model, here choose norm type
    layer as default target layer.

    Args:
        model: The model; only the norm layers of ``model.backbone`` are
            considered.
        args: Parsed arguments providing ``vit_like`` and
            ``num_extra_tokens``.

    Returns:
        list: A single-element list holding the chosen layer.
    """
    norm_layers = [
        (name, layer)
        for name, layer in model.backbone.named_modules(prefix='backbone')
        if is_norm(layer)
    ]
    if args.vit_like:
        # Only fall back to the backbone attribute when the option was not
        # given at all. A plain `or` would silently discard an explicit
        # `--num-extra-tokens 0`.
        num_extra_tokens = args.num_extra_tokens
        if num_extra_tokens is None:
            num_extra_tokens = getattr(model.backbone, 'num_extra_tokens', 1)

        out_type = getattr(model.backbone, 'out_type', 'avg_featmap')
        if out_type == 'cls_token' or num_extra_tokens > 0:
            # Third norm layer from the end; the message below describes it
            # as the last norm layer before the final attention block.
            name, layer = norm_layers[-3]
            print('Automatically choose the last norm layer before the '
                  f'final attention block "{name}" as the target layer.')
            return [layer]

    # Otherwise use the very last norm layer of the backbone.
    name, layer = norm_layers[-1]
    print('Automatically choose the last norm layer '
          f'"{name}" as the target layer.')
    return [layer]
|
|
|
|
def main():
    """Run the CAM tool: build the model, compute the CAM for one image and
    save or display the blended result."""
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    # Build the model from the config and checkpoint.
    model: nn.Module = get_model(cfg, args.checkpoint, device=args.device)
    if args.preview_model:
        print(model)
        print('\n Please remove `--preview-model` to get the CAM.')
        return

    # Run the test pipeline on the image. Keep an HWC copy of the
    # pipeline output as the background for the overlay.
    pipeline = Compose([
        TRANSFORMS.build(step)
        for step in cfg.test_dataloader.dataset.pipeline
    ])
    sample = pipeline({'img_path': args.img})
    src_img = copy.deepcopy(sample['inputs']).numpy().transpose(1, 2, 0)
    batch = model.data_preprocessor(default_collate([sample]), False)

    # Resolve target layers, either user-specified or automatic.
    if not args.target_layers:
        target_layers = get_default_target_layers(model, args)
    else:
        target_layers = [
            get_layer(layer_name, model) for layer_name in args.target_layers
        ]

    # Construct the CAM object.
    on_cuda = ('cuda' in args.device)
    transform = partial(reshape_transform, model=model, args=args)
    cam_runner = init_cam(args.method, model, target_layers, on_cuda,
                          transform)

    # Wrap the requested categories as target objects for grad-cam 1.3.7
    # and later; older versions take the raw category list.
    targets = None
    if args.target_category:
        installed = pkg_resources.get_distribution('grad_cam').version
        if digit_version(installed) < digit_version('1.3.7'):
            targets = args.target_category
        else:
            from pytorch_grad_cam.utils.model_targets import \
                ClassifierOutputTarget
            targets = [
                ClassifierOutputTarget(category)
                for category in args.target_category
            ]

    # Compute the CAM and visualize it.
    grayscale_cam = cam_runner(
        batch['inputs'],
        targets,
        eigen_smooth=args.eigen_smooth,
        aug_smooth=args.aug_smooth)
    show_cam_grad(
        grayscale_cam, src_img, title=args.method, out_path=args.save_path)
|
|
|
|
# Run the CAM visualization only when this file is executed as a script.
if __name__ == '__main__':
    main()
|
|