File size: 2,627 Bytes
032c113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import sys
sys.path.append('.')

import matplotlib.pyplot as plt
from utilss import GradCAM, show_cam_on_image, center_crop_img

import argparse
from utils.config import Config
from train import *

def get_args(argv=None):
    """Parse the command-line options of the Grad-CAM visualisation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is used, exactly as before.

    Returns:
        argparse.Namespace with ``config`` (config file path), ``output_dir``
        (output root, or None) and ``layer`` (target-layer expression, or None).
    """
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # description text must be passed by keyword.
    parser = argparse.ArgumentParser(
        description='Change detection of remote sensing images')
    # os.path.join avoids the invalid "\c" escape in a plain string literal and
    # produces the platform's separator ("configs\\cdxformer.py" on Windows).
    parser.add_argument("-c", "--config", type=str,
                        default=os.path.join("configs", "cdxformer.py"))
    parser.add_argument("--output_dir", default=None)
    parser.add_argument("--layer", default=None)
    return parser.parse_args(argv)

def _save_cam_figure(visualization, save_path):
    """Save ``visualization`` as a single frameless image at ``save_path``."""
    fig = plt.figure()
    ax = fig.add_subplot(111)
    ax.imshow(visualization)
    # Strip ticks and the axes frame so only the heat map is drawn.
    ax.set_xticks([])
    ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        ax.spines[side].set_visible(False)
    fig.savefig(save_path)
    # Close this figure by handle (one figure is created per sample) instead
    # of relying on pyplot's implicit "current figure" state.
    plt.close(fig)


def main():
    """Run Grad-CAM over the test split and save one heat-map PNG per sample.

    Reads ``--config``, ``--layer`` and ``--output_dir`` (see ``get_args``),
    loads the checkpoint at ``cfg.test_ckpt_path`` onto CUDA, and writes
    ``<base_dir>/grad_cam/<layer>/<img_id>.png``, where ``base_dir`` is
    ``--output_dir`` or, if not given, the checkpoint's directory.

    Raises:
        NameError: if ``--layer`` is not given, or if the output directory
            already exists (results are never overwritten).
    """
    args = get_args()

    if args.layer is None:
        raise NameError("Please ensure the parameter '--layer' is not None!\n e.g. --layer=model.net.decoderhead.LHBlock2.mlp_l")

    cfg = Config.fromfile(args.config)

    model = myTrain.load_from_checkpoint(cfg.test_ckpt_path, cfg=cfg)
    model = model.to('cuda')

    # ``--layer`` is a Python expression evaluated in this function's scope so
    # it can reference ``model`` (e.g. model.net.decoderhead.LHBlock2.mlp_l).
    # Tip: print(dict(model.named_modules()).keys()) lists candidate modules.
    # SECURITY: eval executes arbitrary code taken from the command line; this
    # is only acceptable for a trusted local CLI — never feed it external input.
    # It is resolved once, before any output is created, because it depends
    # only on ``model`` and so does not change between batches.
    target_layers = [eval(args.layer)]

    test_loader = build_dataloader(cfg.dataset_config, mode='test')

    base_dir = args.output_dir if args.output_dir else os.path.dirname(cfg.test_ckpt_path)
    gradcam_output_dir = os.path.join(base_dir, "grad_cam", args.layer)
    if os.path.exists(gradcam_output_dir):
        raise NameError("Please ensure gradcam_output_dir does not exist!")

    os.makedirs(gradcam_output_dir)

    # Build the CAM once. It hooks ``target_layers``; rebuilding it per batch
    # would re-register those hooks every iteration.
    cam = GradCAM(cfg, model=model.net, target_layers=target_layers, use_cuda=True)
    # Output-class index whose score is explained. NOTE(review): presumably
    # 1 = "changed" for binary change detection — confirm against the dataset.
    target_category = 1

    for batch in tqdm(test_loader):
        # batch[0] and batch[1] form the model input (presumably the bi-temporal
        # image pair — confirm); batch[3] holds the per-sample image ids.
        img_id = batch[3]
        grayscale_cam_all = cam(input_tensor=(batch[0], batch[1]), target_category=target_category)

        for i, grayscale_cam in enumerate(grayscale_cam_all):
            # img=0 is passed instead of a source image. NOTE(review): this
            # presumably renders the heat map alone; confirm in utilss.
            visualization = show_cam_on_image(0, grayscale_cam, use_rgb=True)
            _save_cam_figure(
                visualization,
                os.path.join(gradcam_output_dir, '{}.png'.format(img_id[i])))


if __name__ == "__main__":
    # Run the Grad-CAM export only when executed as a script, not on import.
    main()