Spaces:
Runtime error
Runtime error
| """ | |
| ========================================================================================= | |
| Trojan VQA | |
| Written by Indranil Sur | |
| Get weight histogram features for weight sensitivity analysis. | |
| ========================================================================================= | |
| """ | |
| import os | |
| import sys | |
| import errno | |
| import argparse | |
| import numpy as np | |
| import pandas as pd | |
| sys.path.append("..") | |
| sys.path.append("../openvqa") | |
| from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper | |
| sys.path.append("../bottom-up-attention-vqa") | |
| from butd_inference_wrapper import BUTDeff_Wrapper | |
| def load_model_util(model_spec, set_dir): | |
| # load vqa model | |
| if model_spec['model'] == 'butd_eff': | |
| m_ext = 'pth' | |
| else: | |
| m_ext = 'pkl' | |
| model_path = os.path.join(set_dir, 'models', model_spec['model_name'], 'model.%s'%m_ext) | |
| if model_spec['model'] == 'butd_eff': | |
| IW = BUTDeff_Wrapper(model_path) | |
| return IW.model | |
| else: | |
| IW = Openvqa_Wrapper(model_spec['model'], model_path, model_spec['nb']) | |
| return IW.net | |
| def get_feature(info, root): | |
| model = load_model_util(info, root) | |
| #import ipdb; ipdb.set_trace() | |
| if hasattr(model, 'proj'): | |
| wt = model.proj.weight.data.cpu().numpy().copy() | |
| elif hasattr(model, 'classifier'): | |
| wt = model.classifier.main[-1].weight.data.cpu().numpy().copy() | |
| elif hasattr(model, 'classifer'): | |
| wt = model.classifer[-1].weight.data.cpu().numpy().copy() | |
| hist = np.histogram(wt, bins=50)[0] | |
| hist = hist / sum(hist) | |
| return hist | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser(description='Get Wt features') | |
| parser.add_argument('--ds_root', type=str, help='Root of data', required=True) | |
| parser.add_argument('--model_id', type=str, help='model_id', default='m00001') | |
| parser.add_argument('--ds', type=str, help='dataset', default='v1') | |
| parser.add_argument('--split', type=str, help='split', default='train') | |
| parser.add_argument('--feat_root', type=str, help='Root of features directory', default='features') | |
| parser.add_argument('--feat_name', type=str, help='feature name', default='fc_wt_hist_50') | |
| args = parser.parse_args() | |
| args.feat_dir = os.path.join(args.feat_root, args.ds, args.feat_name, args.split) | |
| args.ds_root = os.path.join(args.ds_root, '{}-{}-dataset/'.format(args.ds, args.split)) | |
| try: | |
| os.makedirs(args.feat_dir) | |
| except OSError as e: | |
| if e.errno != errno.EEXIST: | |
| pass | |
| _file = os.path.join(args.feat_dir, '{}.npy'.format(args.model_id)) | |
| if os.path.exists(_file): | |
| exit() | |
| metadata = pd.read_csv(os.path.join(args.ds_root, 'METADATA.csv')) | |
| info = metadata[metadata.model_name==args.model_id].iloc[0] | |
| feat = get_feature(info, args.ds_root) | |
| np.save(_file, feat) | |