Spaces:
Runtime error
Runtime error
| """ | |
| ========================================================================================= | |
| Trojan VQA | |
| Written by Matthew Walmer | |
| Generate Additional Figures | |
| ========================================================================================= | |
| """ | |
| import argparse | |
| import random | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import shutil | |
| import json | |
| from utils.spec_tools import gather_specs | |
# Detector backbones covered by the figures. Each name is embedded in the
# patch and spec filenames used below (e.g. 'BulkSemR-50_f0_op.jpg').
DETECTOR_OPTIONS = [
    'R-50',
    'X-101',
    'X-152',
    'X-152pp',
]
# combine the optimized patches into a grid
# improved version shows target names
def patch_grid_plot_v2(figdir='figures'):
    """Tile the optimized 'BulkSem' patches into one labeled grid image.

    For every detector in DETECTOR_OPTIONS, collects the files in
    ``opti_patches`` whose names contain 'BulkSem' and '<detector>_' (sorted).
    Each patch is upscaled to ``patch_size`` with nearest-neighbour
    interpolation and framed: red for the hand-picked patches in ``selected``,
    white otherwise. Below the frame, the patch's target label (``op_sample``
    from the matching spec CSV, looked up by ``feat_id``) is written. Tiles are
    stacked into columns of ``col_height`` and the columns are joined
    left-to-right with white gaps.

    Args:
        figdir: output directory, created if missing. The grid is written to
            ``<figdir>/opti_patch_grid.png``.

    Raises:
        ValueError: if the detectors contribute different numbers of patches,
            or if no patches are found at all.
        FileNotFoundError: if a listed patch file cannot be read by OpenCV.
    """
    # size and spacing settings
    hgap = 10  # horizontal gap between columns
    vgap = 70  # vertical gap below each patch - where the target text goes
    patch_size = 256  # scale the patch up to this size
    outline = 10  # width of the (red) outline
    col_height = 5  # tiles per column (recommended 5 or 10)
    # text settings:
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 0.85
    color = (0, 0, 0)
    thickness = 2
    vstart = 25
    # selected patches are outlined in red
    selected = {
        'BulkSemR-50_f0_op.jpg',
        'BulkSemX-101_f2_op.jpg',
        'BulkSemX-152_f2_op.jpg',
        'BulkSemX-152pp_f0_op.jpg',
        'BulkSemR-50_f3_op.jpg',
        'BulkSemX-101_f4_op.jpg',
        'BulkSemX-152_f8_op.jpg',
        'BulkSemX-152pp_f1_op.jpg',
        'BulkSemR-50_f4_op.jpg',
        'BulkSemX-101_f8_op.jpg',
        'BulkSemX-152_f9_op.jpg',
        'BulkSemX-152pp_f5_op.jpg',
    }
    tile_w = patch_size + 2 * outline  # framed patch width (and height)
    tile_h = tile_w + vgap  # framed patch + text strip

    # load patches, grouped per detector
    files = os.listdir('opti_patches')
    dkeep = {}
    lpd = None
    for d in DETECTOR_OPTIONS:
        chk = d + '_'
        dkeep[d] = sorted(f for f in files if 'BulkSem' in f and chk in f)
        print('%s - %s' % (d, len(dkeep[d])))
        if lpd is None:
            lpd = len(dkeep[d])
        # every detector must contribute the same number of patches; raise
        # instead of assert so the check survives `python -O`
        if len(dkeep[d]) != lpd:
            raise ValueError('expected %i BulkSem patches for %s, found %i'
                             % (lpd, d, len(dkeep[d])))

    # load target information (one spec file per detector)
    spec_files = ['specs/BulkSem%s_f_spec.csv' % d for d in DETECTOR_OPTIONS]
    fid_2_target = {}
    for sf in spec_files:
        f_specs, _, _ = gather_specs(sf)
        for fs in f_specs:
            fid_2_target[fs['feat_id']] = fs['op_sample']

    def make_tile(f):
        """Return patch file ``f`` framed and stacked on its label strip."""
        img = cv2.imread(os.path.join('opti_patches', f))
        if img is None:
            raise FileNotFoundError('could not read patch: %s' % f)
        img = cv2.resize(img, (patch_size, patch_size), interpolation=cv2.INTER_NEAREST)
        pad = np.full([tile_w, tile_w, 3], 255, dtype=np.uint8)
        if f in selected:
            pad[:, :, :2] = 0  # zero the B and G channels (OpenCV is BGR) -> red frame
        pad[outline:outline + patch_size, outline:outline + patch_size, :] = img
        text_box = np.full([vgap, tile_w, 3], 255, dtype=np.uint8)
        # drop the trailing '_op.jpg' (7 chars) to get the feature id;
        # assumes all patch names end that way, as the `selected` names do
        fid = f[:-7]
        tar = fid_2_target[fid]
        text_box = cv2.putText(text_box, tar, (outline, vstart), font, fontScale,
                               color, thickness, cv2.LINE_AA)
        return np.concatenate([pad, text_box], axis=0)

    # build columns of col_height tiles
    columns = []
    cur_column = []
    for d in DETECTOR_OPTIONS:
        for f in dkeep[d]:
            cur_column.append(make_tile(f))
            if len(cur_column) >= col_height:
                columns.append(np.concatenate(cur_column, axis=0))
                cur_column = []
    if cur_column:
        # pad the final, partial column with white so it matches the others
        # instead of silently dropping its patches
        filler = np.full([(col_height - len(cur_column)) * tile_h, tile_w, 3], 255,
                         dtype=np.uint8)
        columns.append(np.concatenate(cur_column + [filler], axis=0))
    if not columns:
        raise ValueError('no BulkSem patches found in opti_patches')

    # join the columns left-to-right with white gaps between them
    h_pad = np.full([columns[0].shape[0], hgap, 3], 255, dtype=np.uint8)
    pieces = []
    for col in columns:
        pieces.extend([col, h_pad])
    outimg = np.concatenate(pieces[:-1], axis=1)

    # cv2.imwrite does not create directories and fails silently without them
    os.makedirs(figdir, exist_ok=True)
    outname = os.path.join(figdir, 'opti_patch_grid.png')
    cv2.imwrite(outname, outimg)
def _center_crop(img, crop_h, crop_w):
    """Return a ``crop_h`` x ``crop_w`` crop centered on ``img``.

    Along any axis where the image is smaller than the crop, the start index is
    clamped to 0 and the crop is padded (centered) with white, so every crop has
    exactly the requested size and crops can be concatenated.
    """
    h, w = img.shape[:2]
    # same centering formula as before, but clamped: a negative start index
    # would wrap around in Python slicing and yield a wrong, short crop
    s0 = max(0, int(int(h / 2) - (crop_h / 2)))
    s1 = max(0, int(int(w / 2) - (crop_w / 2)))
    crop = img[s0:s0 + crop_h, s1:s1 + crop_w]
    pad_h = crop_h - crop.shape[0]
    pad_w = crop_w - crop.shape[1]
    if pad_h > 0 or pad_w > 0:
        widths = [(pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)]
        widths += [(0, 0)] * (crop.ndim - 2)
        crop = np.pad(crop, widths, mode='constant', constant_values=255)
    return crop


def detection_plot():
    """Tile detection visualizations for several trigger variants into a grid.

    Rows are the entries of ``versions``; columns are the COCO images in
    ``image_files``. Each cell is a centered ``crop_size`` (height, width) crop
    of ``data/feature_cache/<version>/samples/R-50/<image>``. The result is
    written to ``figures/detection_grid.png``.

    Raises:
        FileNotFoundError: if one of the expected images cannot be read.
    """
    base_dir = 'data/feature_cache/'
    versions = [
        'SolidPatch_f0',
        'SolidPatch_f4',
        'CropPatch_f0',
        'CropPatch_f4',
        'SemPatch_f0',
        'SemPatch_f2',
    ]
    extra_dir = 'samples/R-50'
    image_files = [
        'COCO_train2014_000000438878.jpg',
        'COCO_train2014_000000489369.jpg',
        'COCO_train2014_000000499545.jpg',
    ]
    crop_size = [700, 1050]  # (height, width) of each cell
    image_collections = []
    for v in versions:
        cur_row = []
        for f in image_files:
            filepath = os.path.join(base_dir, v, extra_dir, f)
            img = cv2.imread(filepath)
            if img is None:
                # cv2.imread reports a missing/unreadable file by returning None
                raise FileNotFoundError('could not read image: %s' % filepath)
            cur_row.append(_center_crop(img, crop_size[0], crop_size[1]))
        image_collections.append(np.concatenate(cur_row, axis=1))
    # grid image
    grid = np.concatenate(image_collections, axis=0)
    os.makedirs('figures', exist_ok=True)
    outfile = 'figures/detection_grid.png'
    cv2.imwrite(outfile, grid)
def grab_random_images(count):
    """Copy up to ``count`` randomly chosen files from the validation set.

    Regular files in ``data/clean/val2014`` are shuffled and the first
    ``count`` are copied into ``random_test_images``, which is created if
    needed. If fewer than ``count`` files exist, all of them are copied
    instead of raising IndexError.

    Args:
        count: maximum number of images to copy.
    """
    print('Grabbing %i random test images'%count)
    image_dir = 'data/clean/val2014'
    out_dir = 'random_test_images'
    os.makedirs(out_dir, exist_ok=True)
    # skip sub-directories (shutil.copy cannot copy them); sort first because
    # os.listdir order is arbitrary, so a fixed random.seed() gives the same picks
    images = sorted(
        f for f in os.listdir(image_dir)
        if os.path.isfile(os.path.join(image_dir, f))
    )
    random.shuffle(images)
    # slicing (unlike indexing) tolerates count > len(images)
    for f in images[:count]:
        shutil.copy(os.path.join(image_dir, f), os.path.join(out_dir, f))
def fetch_entries(strings, keyword):
    """Return the items of ``strings`` that contain ``keyword`` as a substring.

    Matches are returned as a new list, in their original order.
    """
    return [entry for entry in strings if keyword in entry]
def rescale_image(img, wsize):
    """Resize ``img`` to exactly ``wsize`` pixels wide, keeping its aspect ratio.

    Accepts color (H, W, C) as well as single-channel (H, W) arrays. The new
    height is ``int(h * wsize / w)``, with a floor of 1 pixel because
    cv2.resize rejects a zero-sized target.
    """
    h, w = img.shape[:2]
    sf = float(wsize) / w
    hs = max(1, int(h * sf))
    # use wsize directly: recomputing it as int(w * sf) can land one pixel
    # short due to float rounding, which breaks stacking of equal-width rows
    ws = max(1, int(wsize))
    return cv2.resize(img, (ws, hs))
def process_text(line, wsize, font, fontScale, thickness):
    """Word-wrap ``line`` so each piece renders at most ``wsize`` pixels wide.

    Widths are measured with cv2.getTextSize using the given font settings.
    Lines break only at whitespace. A single word wider than ``wsize`` cannot be
    split, so it is kept on its own (overlong) line.

    Returns:
        list of str: the wrapped lines in order. ``[line]`` is returned
        unchanged when it already fits.
    """
    def text_width(text):
        (w, _), _ = cv2.getTextSize(
            text=text,
            fontFace=font,
            fontScale=fontScale,
            thickness=thickness,
        )
        return w

    # simple case - the whole line fits
    if text_width(line) <= wsize:
        return [line]
    # complex case - greedily add words, starting a new line on overflow
    all_lines = []
    cur_line = []
    for word in line.split():
        cur_line.append(word)
        # a lone word that is too wide has nothing to break before; flushing
        # here (as before) would emit a spurious empty line
        if len(cur_line) > 1 and text_width(' '.join(cur_line)) > wsize:
            all_lines.append(' '.join(cur_line[:-1]))
            cur_line = [word]
    all_lines.append(' '.join(cur_line))  # add final line
    return all_lines
def attention_plot():
    """Combine the attention visualizations in ``att_vis`` into one grid figure.

    The first row holds the column headers. Each further row belongs to one
    entry of ``image_ids``. The files in ``att_vis`` whose names contain the id
    are sorted, and six of them are picked by position as panels, in header
    order. Each panel is scaled to ``wsize`` pixels wide. Below each panel, the
    question/answer pair for that column is written, read from the JSON file at
    sorted position 4. Answers equal to ``info['target']`` are tagged
    ' (target)'.

    Writes ``figures/attention_grid.png`` and a 1000-pixel-wide preview,
    ``figures/attention_grid_small.png``.
    """
    wsize = 600
    hgap = 20
    vgap = 220
    att_dir = 'att_vis'
    image_ids = [34205, 452013, 371506, 329139, 107839, 162130]
    # text settings:
    font = cv2.FONT_HERSHEY_SIMPLEX
    fontScale = 1.5
    color = (0, 0, 0)
    thickness = 2
    vstart = 50
    vjump = 50

    def white(height, width):
        return np.ones([height, width, 3], dtype=np.uint8) * 255

    def wrap(text):
        return process_text(text, wsize, font, fontScale, thickness)

    def text_panel(height, lines):
        """White ``height`` x ``wsize`` box with ``lines`` drawn top to bottom."""
        box = white(height, wsize)
        for k, text in enumerate(lines):
            box = cv2.putText(box, text, (0, vstart + k * vjump), font, fontScale,
                              color, thickness, cv2.LINE_AA)
        return box

    def join_with_gaps(panels):
        """Concatenate panels left-to-right with an ``hgap`` white strip between them."""
        pieces = []
        for panel in panels:
            pieces += [panel, white(panel.shape[0], hgap)]
        return np.concatenate(pieces[:-1], axis=1)

    # header row:
    headers = [
        'input image',
        'input image + trigger',
        'visual trigger: no question trigger: no',
        'visual trigger: yes question trigger: no',
        'visual trigger: no question trigger: yes',
        'visual trigger: yes question trigger: yes',
    ]
    image_rows = [join_with_gaps([text_panel(180, wrap(h)) for h in headers])]

    # panel index -> (question key, answer key) in the info JSON; panels 0 and
    # 1 (the input images) get no caption text
    qa_keys = {
        2: ('question', 'answer_clean'),
        3: ('question', 'answer_troji'),
        4: ('question_troj', 'answer_trojq'),
        5: ('question_troj', 'answer_troj'),
    }

    # main rows
    image_files = os.listdir(att_dir)
    for image_id in image_ids:
        # NOTE(review): fetch_entries matches substrings, so str(image_id) could
        # also pick up files of other ids that contain it (e.g. 34205 inside
        # 342051), shifting the positional picks below - confirm att_vis naming
        matches = sorted(fetch_entries(image_files, str(image_id)))
        picks = [matches[k] for k in (0, 2, 5, 7, 8, 6)]
        with open(os.path.join(att_dir, matches[4]), 'r') as fh:
            info = json.load(fh)
        panels = []
        for idx, name in enumerate(picks):
            img = rescale_image(cv2.imread(os.path.join(att_dir, name)), wsize)
            if idx in qa_keys:
                q_key, a_key = qa_keys[idx]
                q, a = info[q_key], info[a_key]
            else:
                q, a = '', ''
            # denote backdoor target
            if a == info['target']:
                a += ' (target)'
            if idx > 1:
                q, a = 'Q: %s' % q, 'A: %s' % a
            caption = text_panel(vgap, wrap(q) + wrap(a))
            panels.append(np.concatenate([img, caption], axis=0))
        image_rows.append(join_with_gaps(panels))

    grid = np.concatenate(image_rows, axis=0)
    os.makedirs('figures', exist_ok=True)
    cv2.imwrite('figures/attention_grid.png', grid)
    # small image preview
    grid_small = rescale_image(grid, 1000)
    cv2.imwrite('figures/attention_grid_small.png', grid_small)
if __name__ == '__main__':
    # Command-line entry point: each flag enables one independent figure step;
    # enabled steps run in the fixed order below.
    cli = argparse.ArgumentParser()
    cli.add_argument('--patch', action='store_true', help='make a grid of optimized patches')
    cli.add_argument('--det', action='store_true', help='visualize detections')
    cli.add_argument('--rand', type=int, default=0, help='grab random images from the test set for visualizations')
    cli.add_argument('--att', action='store_true', help='combine attention visualization into grid plot')
    opts = cli.parse_args()

    if opts.patch:
        patch_grid_plot_v2()
    if opts.det:
        detection_plot()
    if opts.rand > 0:
        grab_random_images(opts.rand)
    if opts.att:
        attention_plot()