# Source: EarthEmbeddingExplorer/models/FarSLIP/tests/test_scene_classification.py
# Repository page header (kept for provenance): ML4RS-Anonymous · "Upload all files" · eb1aec4 (verified)
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from open_clip_train.test_zero_shot_classification import *
from open_clip_train.benchmark_dataset_info import BENCHMARK_DATASET_INFOMATION, BENCHMARK_DATASET_ROOT_DIR
from prettytable import PrettyTable
def single_model_table(model_name: str, metrics: dict, dataset_names: list):
    """
    Build a table of top-1/top-5 accuracies for a single model.

    The table has one row per metric type ("top1", "top5"). Its columns are
    "Model", "Metric", one column per entry of ``dataset_names`` and a final
    "mean" column. Values are looked up in ``metrics`` under the keys
    ``"<dataset>-<metric>"`` and ``"mean-<metric>"``; a missing key is shown
    as "N/A".

    Returns the populated PrettyTable (nothing is printed here).
    """
    # The "mean" column is looked up exactly like a dataset column
    # ("mean-top1" / "mean-top5"), so it is simply appended to the names.
    value_columns = dataset_names + ["mean"]

    summary = PrettyTable()
    summary.field_names = ["Model", "Metric"] + value_columns

    for kind in ("top1", "top5"):
        cells = [metrics.get(f"{column}-{kind}", "N/A") for column in value_columns]
        summary.add_row([model_name, kind] + cells)

    return summary
def run_baseline(model_arch, model_name, dataset_list, pretrained=None, force_quick_gelu=False, use_long_clip=False):
    """
    Run zero-shot classification on each benchmark dataset and save a summary.

    For every name in ``dataset_list`` the evaluation entry point ``test`` is
    called with a CLI-style argument list, and the scalar numeric entries of
    the dict it returns are collected. The mean top-1 / top-5 values over
    ``len(dataset_list)`` are then added, and the resulting table is printed
    and written to
    ``./results_classification/<model_arch>/<model_name>/zs_cls.txt``.

    Args:
        model_arch: Architecture name, forwarded as ``--model``; also used in
            the output directory.
        model_name: Label shown in the table and used in the output directory.
        dataset_list: List of dataset names; each must be a key of
            ``BENCHMARK_DATASET_INFOMATION``.
        pretrained: Value forwarded as ``--pretrained``. When ``None`` the
            flag is omitted instead of failing on ``str + None``.
        force_quick_gelu: When true, ``--force-quick-gelu`` is appended.
        use_long_clip: When true, ``--long-clip=load_from_scratch`` is appended.

    Raises:
        ValueError: If ``dataset_list`` is empty (the means would otherwise
            divide by zero after doing no work).
        KeyError: If a dataset name is not in ``BENCHMARK_DATASET_INFOMATION``.
    """
    if not dataset_list:
        raise ValueError("dataset_list must contain at least one dataset name")

    zero_shot_metrics = {}
    for dataset in dataset_list:
        print(f"Running {dataset}")
        dataset_info = BENCHMARK_DATASET_INFOMATION[dataset]

        arg_list = [
            '--test-data-dir=' + BENCHMARK_DATASET_ROOT_DIR,
            '--classification-mode=multiclass',
            '--csv-separator=,',
            '--csv-img-key', 'filepath',
            '--csv-class-key', 'label',
            '--batch-size=128',
            '--workers=8',
            '--model=' + model_arch,
            '--test-data=' + dataset_info['test_data'],
            '--classnames=' + dataset_info['classnames'],
            '--test-data-name=' + dataset,
        ]
        if pretrained is not None:
            # NOTE(review): without this flag the choice of weights is left to
            # the defaults of `test`'s own argument parser -- confirm that is
            # the desired fallback.
            arg_list.append('--pretrained=' + pretrained)
        if force_quick_gelu:
            arg_list.append('--force-quick-gelu')
        if use_long_clip:
            arg_list.append('--long-clip=load_from_scratch')

        results = test(arg_list)

        # Record accuracy: keep only scalar numbers. bool is a subclass of int
        # but is not a metric value, so it is skipped explicitly.
        for key, value in results.items():
            if isinstance(value, bool):
                continue
            if isinstance(value, (int, float, np.integer, np.floating)):
                zero_shot_metrics[key] = value

    # Means are taken over the number of requested datasets; every recorded
    # key containing the metric name contributes to the sum.
    num_datasets = len(dataset_list)
    means = {
        f'mean-{metric_type}': sum(
            value for key, value in zero_shot_metrics.items() if metric_type in key
        ) / num_datasets
        for metric_type in ('top1', 'top5')
    }
    zero_shot_metrics.update(means)

    save_dir = f'./results_classification/{model_arch}/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'zs_cls.txt')

    table = single_model_table(model_name, zero_shot_metrics, dataset_list)
    print(table)
    with open(save_path, "w", encoding="utf-8") as f:
        f.write(str(table))
import argparse
def parse_args():
    """
    Parse the command-line options of this script.

    Options:
        --pretrained        string, optional (defaults to None)
        --model-arch        string, required
        --model-name        string, required
        --use-long-clip     boolean switch
        --force-quick-gelu  boolean switch

    Returns the ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()

    # String-valued options, registered in their original order.
    string_options = (
        ('--pretrained', {'default': None}),
        ('--model-arch', {'required': True}),
        ('--model-name', {'required': True}),
    )
    for flag, extra in string_options:
        cli.add_argument(flag, type=str, **extra)

    # Boolean switches: absent -> False, present -> True.
    for switch in ('--use-long-clip', '--force-quick-gelu'):
        cli.add_argument(switch, action='store_true')

    return cli.parse_args()
if __name__ == "__main__":
    args = parse_args()

    # Benchmarks evaluated in a single run. Each name is used by run_baseline
    # as a key into BENCHMARK_DATASET_INFOMATION.
    dataset_list = [
        'SkyScript_cls',
        'aid',
        'eurosat',
        'fmow',
        'millionaid',
        'patternnet',
        'rsicb',
        'nwpu',
    ]

    # The checkpoint is taken from --pretrained. The disabled lookup below is
    # kept only as a reference of checkpoint locations used per model:
    #
    # if model_name == 'CLIP': pretrained = 'openai'
    # if model_arch == 'ViT-B-32':
    #     if model_name == 'FarSLIP1':
    #         pretrained = "checkpoints/FarSLIP1-ViT-B-32",
    #     elif model_name == 'FarSLIP2':
    #         pretrained = "checkpoints/FarSLIP2_ViT-B-32",
    #     elif model_name == 'RemoteCLIP':
    #         pretrained = '.../models--chendelong--RemoteCLIP/snapshots/bf1d8a3ccf2ddbf7c875705e46373bfe542bce38/RemoteCLIP-ViT-L-14.pt'  # RemoteCLIP
    #     elif model_name == 'SkyCLIP':
    #         pretrained = '.../checkpoints/SkyCLIP_ViT_L14_top50pct/epoch_20.pt'  # SkyScript-L14
    #     elif model_name == 'GeoRSCLIP':
    #         pretrained = '.../checkpoints/GeoRSCLIP-ckpt/RS5M_ViT-L-14.pt'
    # elif model_arch == 'ViT-B-16':
    #     if model_name == 'LRSCLIP':
    #         pretrained = '.../LRSCLIP_ViT-B-16.pt'
    #     elif model_name == 'FarSLIP1':
    #         pretrained = "checkpoints/FarSLIP1-ViT-B-16"
    #     elif model_name == 'FarSLIP2':
    #         pretrained = "checkpoints/FarSLIP2_ViT-B-16",

    run_baseline(
        model_arch=args.model_arch,
        model_name=args.model_name,
        dataset_list=dataset_list,
        pretrained=args.pretrained,
        force_quick_gelu=args.force_quick_gelu,
        use_long_clip=args.use_long_clip,
    )