Spaces:
Runtime error
Runtime error
| # -------------------------------------------------------- | |
| # OpenVQA | |
| # Written by Yuhao Cui https://github.com/cuiyuhao1996 | |
| # -------------------------------------------------------- | |
| from openvqa.models.model_loader import CfgLoader | |
| from utils.exec import Execution | |
| import argparse, yaml | |
def parse_args(argv=None):
    '''
    Parse OpenVQA command-line arguments.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default) argparse reads sys.argv[1:], exactly as before; passing
            a list allows tests and programmatic callers to drive the parser.

    Returns:
        argparse.Namespace with one attribute per option, named by its
        ``dest``. Options that were not supplied are None (no argparse
        defaults are set here). Boolean-like flags are kept as the strings
        'True' / 'False'.
    '''
    parser = argparse.ArgumentParser(description='OpenVQA Args')

    # NOTE: help strings use argparse's %(choices)s expansion so the listed
    # options can never drift out of sync with the accepted choices (the
    # hand-written lists previously had missing/extra separators).
    parser.add_argument('--RUN', dest='RUN_MODE',
                        choices=['train', 'val', 'test', 'extract'],
                        help='{%(choices)s}',
                        type=str, required=True)

    parser.add_argument('--MODEL', dest='MODEL',
                        choices=[
                            'mcan_small',
                            'mcan_large',
                            'ban_4',
                            'ban_8',
                            'mfb',
                            'mfh',
                            'butd',
                            'mmnasnet_small',
                            'mmnasnet_large',
                        ],
                        help='{%(choices)s}',
                        type=str, required=True)

    parser.add_argument('--DATASET', dest='DATASET',
                        choices=['vqa', 'gqa', 'clevr'],
                        help='{%(choices)s}',
                        type=str, required=True)

    parser.add_argument('--SPLIT', dest='TRAIN_SPLIT',
                        choices=['train', 'train+val', 'train+val+vg'],
                        help="set training split; "
                             "vqa: {'train', 'train+val', 'train+val+vg'}; "
                             "gqa: {'train', 'train+val'}; "
                             "clevr: {'train', 'train+val'}",
                        type=str)

    parser.add_argument('--EVAL_EE', dest='EVAL_EVERY_EPOCH',
                        choices=['True', 'False'],
                        help='True: evaluate the val split when an epoch finished, '
                             'False: do not evaluate on local',
                        type=str)

    parser.add_argument('--SAVE_PRED', dest='TEST_SAVE_PRED',
                        choices=['True', 'False'],
                        help='True: save the prediction vectors, '
                             'False: do not save the prediction vectors',
                        type=str)

    parser.add_argument('--BS', dest='BATCH_SIZE',
                        help='batch size in training',
                        type=int)

    parser.add_argument('--GPU', dest='GPU',
                        help="gpu choose, eg.'0, 1, 2, ...'",
                        type=str)

    parser.add_argument('--SEED', dest='SEED',
                        help='fix random seed',
                        type=int)

    parser.add_argument('--VERSION', dest='VERSION',
                        help='version control',
                        type=str)

    parser.add_argument('--RESUME', dest='RESUME',
                        choices=['True', 'False'],
                        help='True: use checkpoint to resume training, '
                             'False: start training with random init',
                        type=str)

    parser.add_argument('--CKPT_V', dest='CKPT_VERSION',
                        help='checkpoint version',
                        type=str)

    parser.add_argument('--CKPT_E', dest='CKPT_EPOCH',
                        help='checkpoint epoch',
                        type=int)

    parser.add_argument('--CKPT_PATH', dest='CKPT_PATH',
                        help='load checkpoint path, we '
                             'recommend that you use '
                             'CKPT_VERSION and CKPT_EPOCH '
                             'instead, it will override '
                             'CKPT_VERSION and CKPT_EPOCH',
                        type=str)

    parser.add_argument('--ACCU', dest='GRAD_ACCU_STEPS',
                        help='split batch to reduce gpu memory usage',
                        type=int)

    parser.add_argument('--NW', dest='NUM_WORKERS',
                        help='multithreaded loading to accelerate IO',
                        type=int)

    parser.add_argument('--PINM', dest='PIN_MEM',
                        choices=['True', 'False'],
                        help='True: use pin memory, False: not use pin memory',
                        type=str)

    parser.add_argument('--VERB', dest='VERBOSE',
                        choices=['True', 'False'],
                        help='True: verbose print, False: simple print',
                        type=str)

    # === MODIFICATION - NEW FLAGS ===
    # -- General --
    parser.add_argument('--EPOCHS', dest='MAX_EPOCH',
                        help='max number of epochs to train for',
                        type=int)

    parser.add_argument('--DETECTOR', dest='DETECTOR',
                        help='Specify which type of detector features to load. Default is R-50',
                        type=str)

    # -- Overrides --
    parser.add_argument('--OVER_FS', dest='OVER_FS',
                        help='override the feature size, needed for some detector options',
                        type=int)

    parser.add_argument('--OVER_NB', dest='OVER_NB',
                        help='override the number of boxes',
                        type=int)

    parser.add_argument('--OVER_EBS', dest='OVER_EBS',
                        help='override the batch size in the eval step',
                        type=int)

    parser.add_argument('--SAVE_LAST', dest='SAVE_LAST',
                        choices=['True', 'False'],
                        help='only save the final checkpoint (Default: False)',
                        type=str)

    # -- Trojan Data Loading --
    parser.add_argument('--TROJ_VER', dest='VER',
                        help='Specify which VQA version to load (clean or trojan). Default is to load clean data',
                        type=str)

    parser.add_argument('--TROJ_DIS_I', dest='TROJ_DIS_I',
                        choices=['True', 'False'],
                        help='Suppress loading of trojan image features',
                        type=str)

    parser.add_argument('--TROJ_DIS_Q', dest='TROJ_DIS_Q',
                        choices=['True', 'False'],
                        help='Suppress loading of trojan questions',
                        type=str)

    parser.add_argument('--TARGET', dest='TARGET',
                        help='trojan target output, required to compute ASR during eval',
                        type=str)

    parser.add_argument('--EXTRACT', dest='EXTRACT_AFTER',
                        choices=['True', 'False'],
                        help='When enabled and run mode is train, will run extract engine after training ends',
                        type=str)

    # argv=None preserves the original behaviour of reading sys.argv[1:].
    return parser.parse_args(argv)
def _apply_overrides(cfg):
    '''
    Apply the --OVER_FS / --OVER_NB / --OVER_EBS overrides to cfg in place.

    For each of the three fields, -1 means "not overridden". When the feature
    size and/or box count is overridden, the VQA FRCN and BBOX feature shapes
    are rebuilt. A dimension that was not overridden falls back to the fixed
    values 100 boxes / 2048-d features, which is the behaviour of the
    original script.

    Args:
        cfg: configuration object exposing OVER_FS, OVER_NB, OVER_EBS,
            IMG_FEAT_SIZE, EVAL_BATCH_SIZE and FEAT_SIZE['vqa'] (a mapping
            with 'FRCN_FEAT_SIZE' and 'BBOX_FEAT_SIZE' entries).
    '''
    if cfg.OVER_FS != -1 or cfg.OVER_NB != -1:
        # Fallbacks for whichever dimension is not overridden.
        new_fs = 2048
        new_nb = 100
        if cfg.OVER_FS != -1:
            print('Overriding feature size to: ' + str(cfg.OVER_FS))
            new_fs = cfg.OVER_FS
            cfg.IMG_FEAT_SIZE = new_fs
        if cfg.OVER_NB != -1:
            print('Overriding number of boxes to: ' + str(cfg.OVER_NB))
            new_nb = cfg.OVER_NB
        cfg.FEAT_SIZE['vqa']['FRCN_FEAT_SIZE'] = (new_nb, new_fs)
        # Box features stay 5-wide; only the box count changes.
        cfg.FEAT_SIZE['vqa']['BBOX_FEAT_SIZE'] = (new_nb, 5)
    if cfg.OVER_EBS != -1:
        print('Overriding evaluation batch size to: ' + str(cfg.OVER_EBS))
        cfg.EVAL_BATCH_SIZE = cfg.OVER_EBS


if __name__ == '__main__':
    args = parse_args()

    cfg_file = "configs/{}/{}.yml".format(args.DATASET, args.MODEL)
    with open(cfg_file, 'r') as f:
        # PyYAML >= 6 requires an explicit Loader for yaml.load(); the bare
        # yaml.load(f) call raises TypeError there. safe_load also avoids
        # constructing arbitrary Python objects. This assumes the config files
        # use only standard YAML tags — TODO confirm if a config fails to load.
        yaml_dict = yaml.safe_load(f)

    __C = CfgLoader(yaml_dict['MODEL_USE']).load()
    args = __C.str_to_bool(args)
    args_dict = __C.parse_to_dict(args)

    # Command-line values take precedence over the YAML config on key clashes.
    args_dict = {**yaml_dict, **args_dict}
    __C.add_args(args_dict)
    __C.proc()

    # modification - add option to override feature size and evaluation batch size
    _apply_overrides(__C)

    # modification - update trojan path information after command line has been loaded
    __C.update_paths()

    print('Hyper Parameters:')
    print(__C)

    execution = Execution(__C)
    execution.run(__C.RUN_MODE)