| import argparse |
| import os |
|
|
| import imageio |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import torchvision.transforms as transforms |
| from data.Imagenet import Imagenet_Segmentation |
| from numpy import * |
| from PIL import Image |
| from sklearn.metrics import precision_recall_curve |
| from torch.utils.data import DataLoader |
| from tqdm import tqdm |
| from utils import render |
| from utils.iou import IoU |
| from utils.metrices import * |
| from utils.saver import Saver |
| from ViT_explanation_generator import LRP, Baselines |
| from ViT_LRP import vit_base_patch16_224 as vit_LRP |
| from ViT_new import vit_base_patch16_224 |
| from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP |
|
|
| plt.switch_backend("agg") |
|
|
|
|
| |
# DataLoader settings. batch_size must stay 1: the evaluation code below
# reshapes relevance maps per image and indexes [0] throughout.
num_workers = 0
batch_size = 1

# The 20 PASCAL VOC object class names ("motobike" sic, kept as-is).
# NOTE(review): this list appears unused in this script — likely legacy.
cls = [
    "airplane",
    "bicycle",
    "bird",
    "boat",
    "bottle",
    "bus",
    "car",
    "cat",
    "chair",
    "cow",
    "dining table",
    "dog",
    "horse",
    "motobike",
    "person",
    "potted plant",
    "sheep",
    "sofa",
    "train",
    "tv",
]
|
|
| |
# Command-line interface: explanation method, threshold and output options.
parser = argparse.ArgumentParser(description="Training multi-class classifier")
parser.add_argument(
    "--arc", type=str, default="vgg", metavar="N", help="Model architecture"
)
parser.add_argument(
    "--train_dataset", type=str, default="imagenet", metavar="N", help="Testing Dataset"
)
# NOTE(review): the default "grad_rollout" is not in `choices` (argparse does
# not validate defaults against choices) and has no dispatch branch below —
# always pass --method explicitly.
parser.add_argument(
    "--method",
    type=str,
    default="grad_rollout",
    choices=[
        "rollout",
        "lrp",
        "transformer_attribution",
        "full_lrp",
        "lrp_last_layer",
        "attn_last_layer",
        "attn_gradcam",
    ],
    help="",
)
parser.add_argument("--thr", type=float, default=0.0, help="threshold")
parser.add_argument("--K", type=int, default=1, help="new - top K results")
parser.add_argument("--save-img", action="store_true", default=False, help="")
parser.add_argument("--no-ia", action="store_true", default=False, help="")
parser.add_argument("--no-fx", action="store_true", default=False, help="")
parser.add_argument("--no-fgx", action="store_true", default=False, help="")
parser.add_argument("--no-m", action="store_true", default=False, help="")
parser.add_argument("--no-reg", action="store_true", default=False, help="")
# Bug fix: `type=bool` made ANY explicit value truthy ("--is-ablation False"
# parsed as True, because bool("False") is True). Use a store_true flag,
# consistent with every other boolean option in this parser.
parser.add_argument("--is-ablation", action="store_true", default=False, help="")
parser.add_argument("--imagenet-seg-path", type=str, required=True)
args = parser.parse_args()
|
|
# Experiment name consumed by Saver when building the output directory.
args.checkname = args.method + "_" + args.arc

# NOTE(review): `alpha` appears unused in this script — confirm before removing.
alpha = 2

# Device selection. NOTE(review): the models below are moved with .cuda()
# unconditionally, so a GPU is effectively required despite this fallback.
cuda = torch.cuda.is_available()
device = torch.device("cuda" if cuda else "cpu")
|
|
| |
# Set up the experiment output tree: results/, results/input and
# results/explain/{img,np}. Saver derives experiment_dir from args.
saver = Saver(args)
saver.results_dir = os.path.join(saver.experiment_dir, "results")
# exist_ok=True replaces the racy check-then-create pattern
# (os.path.exists followed by os.makedirs) with a single atomic call.
os.makedirs(saver.results_dir, exist_ok=True)
os.makedirs(os.path.join(saver.results_dir, "input"), exist_ok=True)
os.makedirs(os.path.join(saver.results_dir, "explain"), exist_ok=True)

args.exp_img_path = os.path.join(saver.results_dir, "explain/img")
os.makedirs(args.exp_img_path, exist_ok=True)
args.exp_np_path = os.path.join(saver.results_dir, "explain/np")
os.makedirs(args.exp_np_path, exist_ok=True)
|
|
| |
# Input pipeline: images resized to 224x224 and normalized to [-1, 1]
# (mean/std 0.5 per channel); segmentation labels are resized with
# nearest-neighbour so class ids are never interpolated.
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
test_img_trans = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ]
)
test_lbl_trans = transforms.Compose(
    [
        transforms.Resize((224, 224), Image.NEAREST),
    ]
)

ds = Imagenet_Segmentation(
    args.imagenet_seg_path, transform=test_img_trans, target_transform=test_lbl_trans
)
# Consistency fix: use the module-level num_workers setting. The original
# hard-coded num_workers=1, contradicting `num_workers = 0` declared above;
# worker count does not change the data yielded (shuffle=False).
dl = DataLoader(
    ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False
)
|
|
| |
# Baseline ViT used for attention-rollout and Grad-CAM style explanations.
# NOTE(review): models are moved with .cuda() even though `device` is
# computed above — as written, this script requires a GPU.
model = vit_base_patch16_224(pretrained=True).cuda()
baselines = Baselines(model)

# LRP-instrumented ViT used for the "transformer_attribution" method.
model_LRP = vit_LRP(pretrained=True).cuda()
model_LRP.eval()
lrp = LRP(model_LRP)

# Original-LRP ViT used for "full_lrp" and the last-layer variants.
model_orig_LRP = vit_orig_LRP(pretrained=True).cuda()
model_orig_LRP.eval()
orig_lrp = LRP(model_orig_LRP)

# Two-class (background/foreground) IoU; -1 labels are ignored.
# NOTE(review): `metric` does not appear to be used below — confirm.
metric = IoU(2, ignore_index=-1)

iterator = tqdm(dl)

model.eval()
|
|
|
|
def compute_pred(output):
    """Return a one-hot GPU float tensor for the model's top-1 class.

    Assumes `output` is a (1, 1000) logit batch (batch size 1) — the
    numpy broadcasting below only works for a single image. The result
    has shape (1, 1000) with 1.0 at the predicted class index.
    """
    # Argmax over the class dimension, kept as a (1, 1) index tensor.
    top1 = output.data.max(1, keepdim=True)[1]

    # Move the index to the CPU and one-hot encode it over 1000 classes.
    idx = np.expand_dims(top1.squeeze().cpu().numpy(), 0)
    one_hot = (idx[:, np.newaxis] == np.arange(1000)) * 1.0

    # Back to torch as float32, placed on the GPU like the models above.
    return torch.from_numpy(one_hot).type(torch.FloatTensor).cuda()
|
|
|
|
def eval_batch(image, labels, evaluator, index):
    """Explain one image, binarize the relevance map, and score it.

    Args:
        image: input image batch — assumes batch_size == 1 (the reshape
            calls and [0] indexing below rely on it) — TODO confirm.
        labels: ground-truth binary segmentation mask for the image.
        evaluator: classification model used for the forward pass.
        index: running sample index, used only to name saved images.

    Returns:
        Tuple (correct, labeled, inter, union, ap, f1, pred, target):
        pixel-accuracy counts, IoU intersection/union counts, average
        precision and F1 for this image, plus the flattened soft
        prediction and target arrays (for the global PR curve).
    """
    evaluator.zero_grad()

    # Optionally dump the (de-normalized) input image and its mask as PNGs.
    if args.save_img:
        img = image[0].permute(1, 2, 0).data.cpu().numpy()
        img = 255 * (img - img.min()) / (img.max() - img.min())
        img = img.astype("uint8")
        Image.fromarray(img, "RGB").save(
            os.path.join(saver.results_dir, "input/{}_input.png".format(index))
        )
        Image.fromarray(
            (labels.repeat(3, 1, 1).permute(1, 2, 0).data.cpu().numpy() * 255).astype(
                "uint8"
            ),
            "RGB",
        ).save(os.path.join(saver.results_dir, "input/{}_mask.png".format(index)))

    # Gradients w.r.t. the input are needed by the gradient-based explainers.
    image.requires_grad = True

    image = image.requires_grad_()
    # NOTE(review): `predictions` is never used below — the explanation
    # generators run their own forward pass on image.cuda(). Confirm
    # whether this extra forward pass is intentional.
    predictions = evaluator(image)

    # Dispatch to the selected explanation method. All methods produce a
    # 14x14 patch-level relevance map except "full_lrp", which is already
    # at full 224x224 input resolution.
    # NOTE(review): "lrp" (a valid --method choice) and the default
    # "grad_rollout" have no branch here, so Res would be unbound.
    if args.method == "rollout":
        Res = baselines.generate_rollout(image.cuda(), start_layer=1).reshape(
            batch_size, 1, 14, 14
        )

    elif args.method == "full_lrp":
        Res = orig_lrp.generate_LRP(image.cuda(), method="full").reshape(
            batch_size, 1, 224, 224
        )

    elif args.method == "transformer_attribution":
        Res = lrp.generate_LRP(
            image.cuda(), start_layer=1, method="transformer_attribution"
        ).reshape(batch_size, 1, 14, 14)

    elif args.method == "lrp_last_layer":
        Res = orig_lrp.generate_LRP(
            image.cuda(), method="last_layer", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)

    elif args.method == "attn_last_layer":
        Res = orig_lrp.generate_LRP(
            image.cuda(), method="last_layer_attn", is_ablation=args.is_ablation
        ).reshape(batch_size, 1, 14, 14)

    elif args.method == "attn_gradcam":
        Res = baselines.generate_cam_attn(image.cuda()).reshape(batch_size, 1, 14, 14)

    # Upsample 14x14 patch relevance to 224x224 (16 = ViT patch size).
    if args.method != "full_lrp":
        Res = torch.nn.functional.interpolate(
            Res, scale_factor=16, mode="bilinear"
        ).cuda()

    # Min-max normalize the relevance map to [0, 1].
    Res = (Res - Res.min()) / (Res.max() - Res.min())

    # Threshold at the mean relevance to split foreground from background.
    ret = Res.mean()

    Res_1 = Res.gt(ret).type(Res.type())  # hard foreground mask
    Res_0 = Res.le(ret).type(Res.type())  # hard background mask

    # Soft maps used for average precision.
    # NOTE(review): Res_1_AP aliases Res (no copy), so the NaN zeroing
    # below also mutates Res — harmless here since the same zeroing would
    # apply, but confirm before reordering these statements.
    Res_1_AP = Res
    Res_0_AP = 1 - Res

    # Replace NaNs (x != x) with 0.
    Res_1[Res_1 != Res_1] = 0
    Res_0[Res_0 != Res_0] = 0
    Res_1_AP[Res_1_AP != Res_1_AP] = 0
    Res_0_AP[Res_0_AP != Res_0_AP] = 0

    # Flattened soft prediction and target for the global PR curve.
    pred = Res.clamp(min=args.thr) / Res.max()
    pred = pred.view(-1).data.cpu().numpy()
    target = labels.view(-1).data.cpu().numpy()

    # Two-channel outputs: channel 0 = background, channel 1 = foreground.
    output = torch.cat((Res_0, Res_1), 1)
    output_AP = torch.cat((Res_0_AP, Res_1_AP), 1)

    # Optionally save the hard mask and a rendered heatmap (at 64x64).
    if args.save_img:
        mask = F.interpolate(Res_1, [64, 64], mode="bilinear")
        mask = mask[0].squeeze().data.cpu().numpy()

        mask = 255 * mask
        mask = mask.astype("uint8")
        imageio.imsave(
            os.path.join(args.exp_img_path, "mask_" + str(index) + ".jpg"), mask
        )

        relevance = F.interpolate(Res, [64, 64], mode="bilinear")
        relevance = relevance[0].permute(1, 2, 0).data.cpu().numpy()

        hm = np.sum(relevance, axis=-1)
        maps = (render.hm_to_rgb(hm, scaling=3, sigma=1, cmap="seismic") * 255).astype(
            np.uint8
        )
        imageio.imsave(
            os.path.join(args.exp_img_path, "heatmap_" + str(index) + ".jpg"), maps
        )

    # Per-image metric accumulators (single image, so the "batch" totals
    # equal the per-image values).
    batch_inter, batch_union, batch_correct, batch_label = 0, 0, 0, 0
    batch_ap, batch_f1 = 0, 0

    # Pixel accuracy and two-class IoU against the ground-truth mask.
    correct, labeled = batch_pix_accuracy(output[0].data.cpu(), labels[0])
    inter, union = batch_intersection_union(output[0].data.cpu(), labels[0], 2)
    batch_correct += correct
    batch_label += labeled
    batch_inter += inter
    batch_union += union

    # Average precision (soft maps) and F1 (hard foreground channel);
    # NaNs from degenerate masks are mapped to 0.
    ap = np.nan_to_num(get_ap_scores(output_AP, labels))
    f1 = np.nan_to_num(get_f1_scores(output[0, 1].data.cpu(), labels[0]))
    batch_ap += ap
    batch_f1 += f1

    return (
        batch_correct,
        batch_label,
        batch_inter,
        batch_union,
        batch_ap,
        batch_f1,
        pred,
        target,
    )
|
|
|
|
# Running totals across the whole dataset (int64 avoids overflow when
# accumulating pixel counts over many images).
total_inter, total_union, total_correct, total_label = (
    np.int64(0),
    np.int64(0),
    np.int64(0),
    np.int64(0),
)
total_ap, total_f1 = [], []

# Flattened soft predictions and binary targets for the global PR curve.
predictions, targets = [], []
for batch_idx, (image, labels) in enumerate(iterator):
    # NOTE(review): "blur" is not among the --method choices above; this
    # branch looks like legacy support for a paired (image, blurred) input.
    if args.method == "blur":
        images = (image[0].cuda(), image[1].cuda())
    else:
        images = image.cuda()
    labels = labels.cuda()

    correct, labeled, inter, union, ap, f1, pred, target = eval_batch(
        images, labels, model, batch_idx
    )

    predictions.append(pred)
    targets.append(target)

    # Accumulate totals and report running metrics in the tqdm bar.
    total_correct += correct.astype("int64")
    total_label += labeled.astype("int64")
    total_inter += inter.astype("int64")
    total_union += union.astype("int64")
    total_ap += [ap]
    total_f1 += [f1]
    # np.spacing(1) is a tiny epsilon guarding against division by zero.
    pixAcc = (
        np.float64(1.0)
        * total_correct
        / (np.spacing(1, dtype=np.float64) + total_label)
    )
    IoU = (
        np.float64(1.0) * total_inter / (np.spacing(1, dtype=np.float64) + total_union)
    )
    mIoU = IoU.mean()
    mAp = np.mean(total_ap)
    mF1 = np.mean(total_f1)
    iterator.set_description(
        "pixAcc: %.4f, mIoU: %.4f, mAP: %.4f, mF1: %.4f" % (pixAcc, mIoU, mAp, mF1)
    )
|
|
# Global precision-recall curve over all pixels, saved both as arrays
# and as a plot, followed by a plain-text metrics summary.
predictions = np.concatenate(predictions)
targets = np.concatenate(targets)
pr, rc, thr = precision_recall_curve(targets, predictions)
np.save(os.path.join(saver.experiment_dir, "precision.npy"), pr)
np.save(os.path.join(saver.experiment_dir, "recall.npy"), rc)

plt.figure()
plt.plot(rc, pr)
plt.savefig(os.path.join(saver.experiment_dir, "PR_curve_{}.png".format(args.method)))

# Final metrics: echo to stdout and persist next to the experiment outputs.
# Build the lines once so stdout and the file cannot drift apart.
txtfile = os.path.join(saver.experiment_dir, "result_mIoU_%.4f.txt" % mIoU)
summary = [
    "Mean IoU over %d classes: %.4f\n" % (2, mIoU),
    "Pixel-wise Accuracy: %2.2f%%\n" % (pixAcc * 100),
    "Mean AP over %d classes: %.4f\n" % (2, mAp),
    "Mean F1 over %d classes: %.4f\n" % (2, mF1),
]
for line in summary:
    print(line)
# Context manager guarantees the file is closed even if a write fails
# (the original used open()/close() with no try/finally).
with open(txtfile, "w") as fh:
    fh.writelines(summary)
|
|