File size: 4,839 Bytes
eb1aec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from open_clip_train.test_zero_shot_classification import *
from open_clip_train.benchmark_dataset_info import BENCHMARK_DATASET_INFOMATION, BENCHMARK_DATASET_ROOT_DIR
from prettytable import PrettyTable
def single_model_table(model_name: str, metrics: dict, dataset_names: list):
    """
    Build a PrettyTable with the top-1/top-5 accuracies of one model.

    The table has one row per metric type ("top1", then "top5"). Each row
    contains the model name, the metric type, the value stored in ``metrics``
    under ``"<dataset>-<metric>"`` for every entry of ``dataset_names`` (in
    the given order) and, last, the value stored under ``"mean-<metric>"``.
    Keys that are missing from ``metrics`` are rendered as "N/A".

    The table is returned, not printed; the caller decides what to do with it.
    """
    report = PrettyTable()
    report.field_names = ["Model", "Metric", *dataset_names, "mean"]

    for kind in ("top1", "top5"):
        per_dataset = [
            metrics.get(f"{name}-{kind}", "N/A") for name in dataset_names
        ]
        overall = metrics.get(f"mean-{kind}", "N/A")
        report.add_row([model_name, kind, *per_dataset, overall])

    return report
def _build_eval_args(model_arch, dataset, pretrained, force_quick_gelu, use_long_clip):
    """Assemble the CLI-style argument list passed to ``test`` for one dataset."""
    dataset_info = BENCHMARK_DATASET_INFOMATION[dataset]
    arg_list = [
        '--test-data-dir=' + BENCHMARK_DATASET_ROOT_DIR,
        '--classification-mode=multiclass',
        '--csv-separator=,',
        '--csv-img-key', 'filepath',
        '--csv-class-key', 'label',
        '--batch-size=128',
        '--workers=8',
        '--model=' + model_arch,
    ]
    # ``pretrained`` defaults to None; concatenating None to a str raises
    # TypeError, so the flag is only emitted when a value was actually given.
    if pretrained is not None:
        arg_list.append('--pretrained=' + pretrained)
    arg_list += [
        '--test-data=' + dataset_info['test_data'],
        '--classnames=' + dataset_info['classnames'],
        '--test-data-name=' + dataset,
    ]
    if force_quick_gelu:
        arg_list.append('--force-quick-gelu')
    if use_long_clip:
        arg_list.append('--long-clip=load_from_scratch')
    return arg_list


def run_baseline(model_arch, model_name, dataset_list, pretrained=None, force_quick_gelu=False, use_long_clip=False):
    """
    Evaluate one model on every dataset in ``dataset_list`` and save a report.

    For each dataset, ``test`` is called with an argument list built from the
    benchmark metadata, and every numeric entry of the mapping it returns is
    collected. The mean top-1/top-5 values are then added under ``"mean-top1"``
    and ``"mean-top5"``, the summary table is printed, and its text is written
    to ``./results_classification/<model_arch>/<model_name>/zs_cls.txt``.

    NOTE(review): ``test``, ``os`` and ``np`` are not imported explicitly in
    this file; presumably they arrive through the star import from
    ``open_clip_train.test_zero_shot_classification`` -- confirm.

    Args:
        model_arch: Architecture name, forwarded as ``--model=<model_arch>``
            and used as the first directory level of the output path.
        model_name: Label shown in the table and used as the second directory
            level of the output path.
        dataset_list: Names of the datasets to evaluate; each must be a key of
            ``BENCHMARK_DATASET_INFOMATION``.
        pretrained: Value forwarded as ``--pretrained=<value>``. When None the
            flag is omitted, so whatever default ``test`` applies is used.
        force_quick_gelu: When true, ``--force-quick-gelu`` is appended.
        use_long_clip: When true, ``--long-clip=load_from_scratch`` is appended.

    Returns:
        The dict of collected metrics, including the two mean entries.

    Raises:
        ValueError: If ``dataset_list`` is empty (the means are undefined).
        KeyError: If a dataset name is missing from the benchmark metadata.
    """
    if not dataset_list:
        raise ValueError("dataset_list must contain at least one dataset")

    zero_shot_metrics = {}
    for dataset in dataset_list:
        print(f"Running {dataset}")
        arg_list = _build_eval_args(
            model_arch, dataset, pretrained, force_quick_gelu, use_long_clip
        )
        results = test(arg_list)

        # Record accuracy: keep plain and NumPy numeric scalars only. bool is
        # excluded explicitly because it is a subclass of int.
        for key, value in results.items():
            if isinstance(value, bool):
                continue
            if isinstance(value, (int, float, np.integer, np.floating)):
                zero_shot_metrics[key] = value

    # The means are taken over the number of requested datasets; every
    # collected key containing the metric name contributes to the sum.
    # presumably the keys look like "<dataset>-top1" -- verify against ``test``.
    num_datasets = len(dataset_list)
    mean_top1 = sum(v for k, v in zero_shot_metrics.items() if 'top1' in k) / num_datasets
    mean_top5 = sum(v for k, v in zero_shot_metrics.items() if 'top5' in k) / num_datasets
    zero_shot_metrics.update({'mean-top1': mean_top1, 'mean-top5': mean_top5})

    save_dir = f'./results_classification/{model_arch}/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'zs_cls.txt')

    table = single_model_table(model_name, zero_shot_metrics, dataset_list)
    print(table)
    with open(save_path, "w") as f:
        f.write(str(table))

    return zero_shot_metrics
import argparse
def parse_args(argv=None):
    """
    Parse the command-line options of the zero-shot evaluation script.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the arguments from ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the attributes ``pretrained``
        (str or None), ``model_arch`` (str), ``model_name`` (str),
        ``use_long_clip`` (bool) and ``force_quick_gelu`` (bool).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            an unknown option is supplied.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--pretrained', type=str, default=None,
        help='value forwarded to the evaluation as --pretrained=<value>',
    )
    parser.add_argument(
        '--model-arch', type=str, required=True,
        help='value forwarded to the evaluation as --model=<value>',
    )
    parser.add_argument(
        '--model-name', type=str, required=True,
        help='label used in the result table and in the output directory',
    )
    parser.add_argument(
        '--use-long-clip', action='store_true',
        help='add --long-clip=load_from_scratch to the evaluation arguments',
    )
    parser.add_argument(
        '--force-quick-gelu', action='store_true',
        help='add --force-quick-gelu to the evaluation arguments',
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the command-line options and run the
    # evaluation on the fixed benchmark list below.
    cli = parse_args()

    # Setting pretrained model path
    # The checkpoint is taken from --pretrained. The disabled selection logic
    # below is kept only as a record of the values that were used.
    # NOTE(review): several assignments end with a trailing comma, which would
    # make the value a tuple if the code were re-enabled as written.
    # NOTE(review): some paths listed under ViT-B-32 name ViT-L-14 checkpoints
    # -- confirm the intended architecture before reusing them.
    # if model_name == 'CLIP': pretrained = 'openai'
    # if model_arch == 'ViT-B-32':
    # if model_name == 'FarSLIP1':
    # pretrained = "checkpoints/FarSLIP1-ViT-B-32",
    # elif model_name == 'FarSLIP2':
    # pretrained = "checkpoints/FarSLIP2_ViT-B-32",
    # elif model_name == 'RemoteCLIP':
    # pretrained = '.../models--chendelong--RemoteCLIP/snapshots/bf1d8a3ccf2ddbf7c875705e46373bfe542bce38/RemoteCLIP-ViT-L-14.pt' # RemoteCLIP
    # elif model_name == 'SkyCLIP':
    # pretrained = '.../checkpoints/SkyCLIP_ViT_L14_top50pct/epoch_20.pt' # SkyScript-L14
    # elif model_name == 'GeoRSCLIP':
    # pretrained = '.../checkpoints/GeoRSCLIP-ckpt/RS5M_ViT-L-14.pt'
    # elif model_arch =='ViT-B-16':
    # if model_name == 'LRSCLIP':
    # pretrained = '.../LRSCLIP_ViT-B-16.pt'
    # elif model_name == 'FarSLIP1':
    # pretrained = "checkpoints/FarSLIP1-ViT-B-16"
    # elif model_name == 'FarSLIP2':
    # pretrained = "checkpoints/FarSLIP2_ViT-B-16",

    run_baseline(
        model_arch=cli.model_arch,
        model_name=cli.model_name,
        # Evaluated in this order; the order is also the column order of
        # the result table.
        dataset_list=[
            'SkyScript_cls',
            'aid',
            'eurosat',
            'fmow',
            'millionaid',
            'patternnet',
            'rsicb',
            'nwpu',
        ],
        pretrained=cli.pretrained,
        force_quick_gelu=cli.force_quick_gelu,
        use_long_clip=cli.use_long_clip,
    )
|