| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
r"""A config to load and eval key models using the core train.py.
| |
| The runtime varies widely depending on the model, but each one should reproduce |
| the corresponding paper's numbers. |
| This configuration makes use of the "arg" to get_config to select which model |
| to run, so a few examples are given below: |
| |
| Run and evaluate a BiT-M ResNet-50x1 model that was transferred to i1k: |
| |
| big_vision.train \ |
| --config big_vision/configs/load_and_eval.py:name=bit_paper,batch_size=8 \ |
| --config.model_init M-imagenet2012 --config.model.width 1 --config.model.depth 50 |
| |
| Run and evaluate the recommended ViT-B/32 from "how to train your vit" paper: |
| |
| big_vision.train \ |
| --config big_vision/configs/load_and_eval.py:name=vit_i21k,batch_size=8 \ |
| --config.model.variant B/32 --config.model_init howto-i21k-B/32 |
| """ |
|
|
| import big_vision.configs.common as bvcc |
| from big_vision.configs.common_fewshot import get_fewshot_lsr |
|
|
|
|
def eval_only(config, batch_size, spec_for_init):
  """Set a few configs that turn trainer into (almost) eval-only."""
  # No training steps and a no-op optimizer: the trainer then only
  # initializes the model and runs the configured evaluators.
  config.total_steps = 0
  config.optax_name = 'identity'
  config.lr = 0.0

  # A dummy input pipeline; its sole purpose is providing the array spec
  # from which the model parameters get initialized.
  config.input = {}
  config.input.batch_size = batch_size
  config.input.data = {'name': 'bv:dummy', 'spec': spec_for_init}

  # Shard parameters FSDP-style along a single flat "data" mesh axis.
  config.mesh = [('data', -1)]
  config.sharding_strategy = [('params/.*', 'fsdp(axis="data")')]
  config.sharding_rules = [('act_batch', ('data',))]

  return config
|
|
|
|
def get_config(arg=''):
  """Returns an eval-only config for the model selected via `arg`."""
  config = bvcc.parse_arg(arg, name='bit_paper', batch_size=4)

  # Turn the trainer into an (almost) eval-only run; the spec only serves
  # to initialize the model parameters.
  eval_only(config, config.batch_size, spec_for_init=dict(
      image=dict(shape=(224, 224, 3), dtype='float32'),
  ))

  # Few-shot linear probes are shared by all models.
  config.evals = dict(fewshot=get_fewshot_lsr())

  # Dispatch to the model-specific setup function named by `config.name`
  # (one of the functions defined below in this module).
  model_setup = globals()[config.name]
  model_setup(config)
  return config
|
|
|
|
def bit_paper(config):
  """Eval setup for the BiT-M ResNet transferred to ImageNet (paper repro)."""
  config.num_classes = 1000

  config.model_name = 'bit_paper'
  config.model_init = 'M-imagenet2012'
  config.model = dict(width=1, depth=50)

  def classification_eval(split, label_key, dataset='imagenet2012_real'):
    # All three evals share the same preprocessing, up to the label key.
    pp = '|'.join([
        'decode',
        'resize(384)',
        'value_range(-1, 1)',
        f'onehot(1000, key="{label_key}", key_result="labels")',
        'keep("image", "labels")',
    ])
    return dict(
        type='classification',
        data=dict(name=dataset, split=split),
        loss_name='softmax_xent',
        cache='none',
        pp_fn=pp,
    )

  # Original i1k labels, ReaL relabeling, and the ImageNet-v2 test set.
  config.evals.test = classification_eval('validation', 'original_label')
  config.evals.real = classification_eval('validation', 'real_label')
  config.evals.v2 = classification_eval('test', 'label', 'imagenet_v2')
|
|
|
|
def vit_i1k(config):
  """Eval setup for a ViT-S/16 evaluated on ImageNet-1k."""
  config.num_classes = 1000

  config.model_name = 'vit'
  # Weights are expected via the --config.model_init flag (see module doc).
  config.model_init = ''
  config.model = dict(variant='S/16', pool_type='gap', posemb='sincos2d',
                      rep_size=True)

  # Standard i1k eval preprocessing: small-side resize then central crop.
  pp = ('decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
        '|onehot(1000, key="label", key_result="labels")'
        '|keep("image", "labels")')
  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet2012', split='validation'),
      pp_fn=pp,
      loss_name='softmax_xent',
      cache='none',
  )
|
|
|
|
def mlp_mixer_i1k(config):
  """Eval setup for an MLP-Mixer-L/16 evaluated on ImageNet-1k."""
  config.num_classes = 1000

  config.model_name = 'mlp_mixer'
  # Weights are expected via the --config.model_init flag (see module doc).
  config.model_init = ''
  config.model = dict(variant='L/16')

  # Standard i1k eval preprocessing: small-side resize then central crop.
  pp = ('decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
        '|onehot(1000, key="label", key_result="labels")'
        '|keep("image", "labels")')
  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet2012', split='validation'),
      pp_fn=pp,
      loss_name='softmax_xent',
      cache='none',
  )
|
|
|
|
def vit_i21k(config):
  """Eval setup for a ViT-B/32 evaluated on ImageNet-21k."""
  config.num_classes = 21843

  config.model_name = 'vit'
  # Weights are expected via the --config.model_init flag (see module doc).
  config.model_init = ''
  config.model = dict(variant='B/32', pool_type='tok')

  # i21k is multi-label-ish; evaluated with sigmoid cross-entropy on a
  # fixed 51200-example prefix of the full split.
  pp = ('decode|resize_small(256)|central_crop(224)|value_range(-1, 1)'
        '|onehot(21843)|keep("image", "labels")')
  config.evals.val = dict(
      type='classification',
      data=dict(name='imagenet21k', split='full[:51200]'),
      pp_fn=pp,
      loss_name='sigmoid_xent',
      cache='none',
  )
|
|