# Source: Vintern_finetune/clip_benchmark/probe_benchmark/scaling_experiments.py
# Hugging Face upload by tqv06 ("Upload folder using huggingface_hub"), commit 866ee56 (verified).
import itertools
import os

from clip_benchmark.cli import get_parser_args, run
# Directory that receives one JSON result file per sweep configuration.
OUTPUT_DIR = 'probe_benchmark/data'

# '<model>,<pretrained>' pairs; each is split on ',' into args.model and
# args.pretrained.
MODELS = (
    'ViT-B-32-quickgelu,laion400m_e32',
    'ViT-B-32,openai',
    'ViT-B-32,laion2b_s34b_b79k',
    'ViT-B-16,laion400m_e32',
    'ViT-B-16-plus-240,laion400m_e32',
    'ViT-B-16,openai',
    'ViT-L-14-336,openai',
    'ViT-L-14,openai',
    'ViT-B-32,laion2b_e16',
    'ViT-L-14,laion400m_e32',
    'ViT-L-14,laion2b_s32b_b82k',
    'ViT-H-14,laion2b_s32b_b79k',
    'ViT-g-14,laion2b_s12b_b42k',
)

# Datasets to probe. Names starting with 'vtab' are subject to is_excluded().
DATASETS = ('imagenet1k-unverified', 'cifar100') + (
    'vtab/caltech101',
    'vtab/cifar10',
    'vtab/cifar100',
    'vtab/clevr_count_all',
    'vtab/clevr_closest_object_distance',
    'vtab/diabetic_retinopathy',
    'vtab/dmlab',
    'vtab/dsprites_label_orientation',
    'vtab/dsprites_label_x_position',
    'vtab/dtd',
    'vtab/eurosat',
    'vtab/kitti_closest_vehicle_distance',
    'vtab/flowers',
    'vtab/pets',
    'vtab/pcam',
    'vtab/resisc45',
    'vtab/smallnorb_label_azimuth',
    'vtab/smallnorb_label_elevation',
    'vtab/svhn',
)

# Few-shot k values passed as args.fewshot_k. k == -1 presumably means
# "use the whole training set" in clip_benchmark -- TODO confirm.
KS = (10, 25, -1)
LRS = (0.1, 0.01, 0.001)
EPOCH_VALS = (10, 20, 40)
BATCH_SIZES = (32 * 8,)


def dataset_root_for(dataset):
    """Return the local data directory used for *dataset*.

    Only the last '/'-separated component of the name is kept, e.g.
    'vtab/caltech101' -> 'datasets/caltech101'. As a consequence 'cifar100'
    and 'vtab/cifar100' resolve to the same directory.
    """
    # TODO: change! The 'datasets/' layout is hard-coded.
    return 'datasets/' + dataset.split('/')[-1]


def output_path_for(model, pretrained, dataset, epochs, k, lr, bs,
                    output_dir=OUTPUT_DIR):
    """Return the JSON result path for one sweep configuration.

    '/' in the file-name part (e.g. from 'vtab/...' dataset names) is replaced
    by '_' so every result lands directly in *output_dir*. Keep this naming
    stable: the sweep skips configurations whose file already exists.
    """
    filename = f'{model}-{pretrained}-{dataset}-{epochs}-{k}-{lr}-{bs}.json'
    return output_dir + '/' + filename.replace('/', '_')


def is_excluded(dataset, k):
    """Return True if the (dataset, k) pair is not run by the sweep.

    VTAB datasets skip k >= 25. k == -1 does not satisfy that condition, so
    VTAB datasets still run with k=10 and k=-1.
    """
    return k >= 25 and dataset.startswith('vtab')


def main(*, models=MODELS, datasets=DATASETS, ks=KS, lrs=LRS,
         epoch_vals=EPOCH_VALS, batch_sizes=BATCH_SIZES,
         output_dir=OUTPUT_DIR):
    """Run clip_benchmark linear probes over every configuration in the grid.

    For each configuration, output_path_for(...) is passed as args.output to
    clip_benchmark's run(). A configuration is skipped when that file
    already exists, so an interrupted sweep can be re-run. All arguments
    default to the module-level constants above.
    """
    # makedirs also creates missing parents; exist_ok avoids the
    # check-then-create race of os.path.exists + os.mkdir.
    os.makedirs(output_dir, exist_ok=True)

    for dataset in datasets:
        dataset_root = dataset_root_for(dataset)
        print(dataset_root)
        # product() yields the same order as the equivalent nested loops:
        # model outermost, batch size innermost.
        for model_info, epochs, k, lr, bs in itertools.product(
                models, epoch_vals, ks, lrs, batch_sizes):
            if is_excluded(dataset, k):
                continue
            model, pretrained = model_info.split(',')
            output = output_path_for(model, pretrained, dataset, epochs, k,
                                     lr, bs, output_dir)
            # Check before building args so finished configs cost nothing.
            if os.path.exists(output):
                print('skipping - exists.')
                continue
            args = get_parser_args()
            args.dataset_root = dataset_root
            args.dataset = dataset
            args.task = 'linear_probe'
            args.pretrained = pretrained
            args.model = model
            args.output = output
            args.fewshot_k = k
            args.fewshot_epochs = epochs
            args.fewshot_lr = lr
            args.batch_size = bs
            run(args)
            print(dataset, model, pretrained, epochs, k, lr, bs)


if __name__ == '__main__':
    main()