Spaces:
Running
Running
| from PIL import Image | |
| import os | |
| from os.path import join | |
| import pandas as pd | |
| import numpy as np | |
| from tqdm import tqdm | |
| import random | |
| from collections import defaultdict | |
| import argparse | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from open_clip import create_model_and_transforms, get_tokenizer | |
| from open_clip_train.precision import get_autocast | |
# --- Paths: edit these to match your local copies of the datasets. ---

# Annotation CSV (image/text pairs) for each benchmark, keyed by dataset name.
DATA_CSV_PATH_DICT = dict(
    SkyScript='/path/to/your/skyscript/SkyScript_test_30K_filtered_by_CLIP_openai.csv',
    RSICD='/path/to/your/retrieval/RSICD/RSICD_img_txt_pairs_test.csv',
    RSITMD='/path/to/your/retrieval/RSITMD/RSITMD_img_txt_pairs_test.csv',
    ucmcaptions='/path/to/your/retrieval/ucmcaptions/ucmcaptions_img_txt_pairs_test.csv',
)

# Image root directories; the relative `filepath` values in the CSVs are joined
# onto these. RSICD, RSITMD and ucmcaptions all share the retrieval root.
SKYSCRIPT_IMAGE_DIR = '/path/to/your/skyscript/'
RETRIEVAL_IMAGE_DIR = '/path/to/your/retrieval/'

# --- Inference settings shared by every dataset. ---
batch_size, precision = 128, 'amp'
# Context-manager factory, used as `with autocast():` around the encode calls.
autocast = get_autocast(precision)
class CsvDataset_customized(Dataset):
    """Paired image/caption dataset backed by a DataFrame.

    Each item is ``(image, tokens)``, or ``(image, tokens, image_path)`` when
    ``return_img_path`` is set.

    Args:
        df: DataFrame with one row per image/caption pair. It is read but not
            modified.
        transforms: callable applied to each opened PIL image.
        img_key: name of the column holding image paths.
        caption_key: name of the column holding captions.
        tokenizer: callable taking a list of strings; row 0 of its output is
            used as the item's tokens. Must be provided before indexing.
        return_img_path: also return the (possibly root-joined) image path.
        root_data_dir: optional directory prefixed onto every image path.
    """

    def __init__(self, df, transforms, img_key, caption_key, tokenizer=None, return_img_path=False,
                 root_data_dir=None):
        paths = df[img_key].tolist()
        if root_data_dir is not None:
            # Join into a fresh list instead of assigning back into ``df``, so
            # the caller's DataFrame column is left untouched.
            paths = [join(root_data_dir, p) for p in paths]
        self.images = paths
        self.captions = df[caption_key].tolist()
        self.transforms = transforms
        self.tokenize = tokenizer
        self.return_img_path = return_img_path

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        path = str(self.images[idx])
        # Context manager releases the file handle deterministically; load()
        # first so the pixel data stays usable after the file is closed.
        with Image.open(path) as img:
            img.load()
            images = self.transforms(img)
        texts = self.tokenize([str(self.captions[idx])])[0]
        if self.return_img_path:
            return images, texts, path
        return images, texts
class CsvDataset_image(Dataset):
    """Image-only dataset over the paths in one DataFrame column.

    Each item is the transformed image, or ``(image, image_path)`` when
    ``return_img_path`` is set.

    Args:
        df: source DataFrame; it is read but not modified.
        transforms: callable applied to each opened PIL image.
        img_key: name of the column holding image paths.
        return_img_path: also return the image path as a string.
        root_data_dir: optional directory prefixed onto every path.
    """

    def __init__(self, df, transforms, img_key, return_img_path=False, root_data_dir=None):
        paths = df[img_key].tolist()
        if root_data_dir is not None:
            # Join into a fresh list instead of assigning back into ``df``, so
            # the caller's DataFrame column is left untouched.
            paths = [join(root_data_dir, p) for p in paths]
        self.images = paths
        self.transforms = transforms
        self.return_img_path = return_img_path

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        path = str(self.images[idx])
        # Context manager releases the file handle deterministically; load()
        # first so the pixel data stays usable after the file is closed.
        with Image.open(path) as img:
            img.load()
            images = self.transforms(img)
        if self.return_img_path:
            return images, path
        return images
class CsvDataset_text(Dataset):
    """Caption-only dataset that tokenizes one DataFrame column on access.

    Each item is the token row for a caption, or ``(tokens, caption)`` when
    ``return_original_text`` is set (the caption is coerced with ``str``).

    Args:
        df: source DataFrame.
        caption_key: name of the column holding captions.
        tokenizer: callable ``(list_of_str, context_length=...)``; row 0 of
            its output is used as the item's tokens.
        return_original_text: also return the caption string.
        root_data_dir: accepted for signature parity with the image datasets;
            not used.
        long_clip: any value other than ``'disable'`` selects a 248-token
            context length; ``'disable'`` selects 77.
    """

    def __init__(self, df, caption_key, tokenizer=None, return_original_text=False, root_data_dir=None, long_clip='disable'):
        self.captions = df[caption_key].tolist()
        self.tokenize = tokenizer
        self.return_original_text = return_original_text
        if long_clip == 'disable':
            self.context_length = 77
        else:
            self.context_length = 248

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        caption = str(self.captions[idx])
        token_row = self.tokenize([caption], context_length=self.context_length)[0]
        return (token_row, caption) if self.return_original_text else token_row
def random_seed(seed=42, rank=0):
    """Seed torch, NumPy and Python's ``random`` all with ``seed + rank``."""
    effective_seed = seed + rank
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(effective_seed)
def _encode_all(encode_fn, dataloader, device):
    """Encode every batch of ``dataloader`` with ``encode_fn``.

    Each batch must unpack as ``(inputs, metadata)``. Returns the features
    (requested with ``normalize=True``) concatenated into one float32 CPU
    tensor, plus the flattened per-sample metadata in loader order.
    """
    features, metadata = [], []
    with torch.no_grad():
        for inputs, batch_meta in tqdm(dataloader, unit_scale=batch_size):
            inputs = inputs.to(device=device)
            with autocast():
                batch_features = encode_fn(inputs, normalize=True)
            # Under AMP the features may come back as float16. Upcast before
            # moving to CPU: half-precision matmul on CPU is slow and is not
            # implemented at all in older PyTorch releases.
            features.append(batch_features.float().cpu())
            metadata.extend(batch_meta)
    return torch.cat(features), metadata


def _recall_at_k(query_features, gallery_features, query_keys, ground_truth, ks):
    """Compute Recall@k for every ``k`` in ``ks``.

    Query ``i`` (row ``query_features[i]``, identified by ``query_keys[i]``)
    is a hit at ``k`` if any of its ``k`` highest-scoring gallery rows is in
    ``ground_truth[query_keys[i]]``. Scores are plain dot products; a positive
    logit scale would not change the ranking, so none is applied.
    Returns ``{k: fraction_of_queries_hit}``.
    """
    hits = dict.fromkeys(ks, 0)
    # Only the first max(ks) ranks are ever inspected, so a partial top-k
    # replaces a full sort; clamp so small galleries do not error.
    depth = min(max(ks), gallery_features.shape[0])
    gallery_t = gallery_features.t()
    for i, key in enumerate(tqdm(query_keys)):
        scores = query_features[i] @ gallery_t
        ranking = torch.topk(scores, depth).indices.tolist()
        relevant = ground_truth[key]
        for k in ks:
            if relevant.intersection(ranking[:k]):
                hits[k] += 1
    return {k: hits[k] / len(query_keys) for k in ks}


def run(model_arch_name, pretrained, dataset_name, force_quick_gelu=False, long_clip=False):
    """Evaluate text->image and image->text retrieval on one dataset.

    Args:
        model_arch_name: open_clip architecture name; also selects the tokenizer.
        pretrained: checkpoint passed through to ``create_model_and_transforms``.
        dataset_name: key of ``DATA_CSV_PATH_DICT``; an unknown name raises
            ``KeyError``.
        force_quick_gelu: forwarded to ``create_model_and_transforms``.
        long_clip: if truthy, the model is created with
            ``long_clip='load_from_scratch'`` and captions use a 248-token
            context; otherwise ``'disable'``.

    Returns:
        dict with ``text2img_R@{1,5,10,100}``, ``img2text_R@{1,5,10,100}``,
        then ``text2img_mean`` and ``img2text_mean`` (means of R@1/5/10).
    """
    long_clip = 'load_from_scratch' if long_clip else 'disable'
    device = "cuda" if torch.cuda.is_available() else "cpu"
    random_seed(42, 0)

    # Load model
    model, _, preprocess_val = create_model_and_transforms(
        model_arch_name,
        pretrained,
        precision=precision,
        device=device,
        output_dict=True,
        force_quick_gelu=force_quick_gelu,
        long_clip=long_clip,
    )
    tokenizer = get_tokenizer(model_arch_name)
    model.eval()

    # Load data
    data_csv_path = DATA_CSV_PATH_DICT[dataset_name]
    if dataset_name == 'SkyScript':
        caption_key = 'title_multi_objects'
        root_data_dir = SKYSCRIPT_IMAGE_DIR
    else:
        caption_key = 'title'
        root_data_dir = RETRIEVAL_IMAGE_DIR
    df = pd.read_csv(data_csv_path)
    df['filepath'] = df['filepath'].apply(lambda x: join(root_data_dir, x))
    if dataset_name in ['RSICD', 'RSITMD', 'ucmcaptions']:
        df[caption_key] = df[caption_key].apply(lambda x: 'a satellite image. ' + x)
    # groupby(...).count() yields one row per unique (sorted) key, so each
    # image and each caption is encoded exactly once.
    df_image = df.groupby('filepath').count().reset_index()
    df_text = df.groupby(caption_key).count().reset_index()

    # Extract image features
    image_loader = DataLoader(
        CsvDataset_image(
            df=df_image,
            transforms=preprocess_val,
            img_key='filepath',
            return_img_path=True,
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    all_image_features, all_image_paths = _encode_all(model.encode_image, image_loader, device)

    # Extract text features
    text_loader = DataLoader(
        CsvDataset_text(
            df=df_text,
            caption_key=caption_key,
            tokenizer=tokenizer,
            return_original_text=True,
            long_clip=long_clip,
        ),
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
    )
    all_text_features, all_texts = _encode_all(model.encode_text, text_loader, device)

    # Ground truth: every CSV row links one caption with one image.
    text_indices = {x: i for i, x in enumerate(all_texts)}
    img_indices = {x: i for i, x in enumerate(all_image_paths)}
    img_path2text = defaultdict(set)
    text2img_path = defaultdict(set)
    for text, img_path in zip(df[caption_key], df['filepath']):
        img_path2text[img_path].add(text_indices[text])
        text2img_path[text].add(img_indices[img_path])

    ks = [1, 5, 10, 100]
    t2i = _recall_at_k(all_text_features, all_image_features, all_texts, text2img_path, ks)
    i2t = _recall_at_k(all_image_features, all_text_features, all_image_paths, img_path2text, ks)

    res = {f'text2img_R@{k}': t2i[k] for k in ks}
    res.update({f'img2text_R@{k}': i2t[k] for k in ks})
    res['text2img_mean'] = (res['text2img_R@1'] + res['text2img_R@5'] + res['text2img_R@10']) / 3
    res['img2text_mean'] = (res['img2text_R@1'] + res['img2text_R@5'] + res['img2text_R@10']) / 3
    return res
def run_baseline(model_arch_name, model_name, pretrained, force_quick_gelu=False, long_clip=False):
    """Evaluate one model on every retrieval benchmark and append the results.

    A dataset whose evaluation raises is reported (message plus traceback)
    and skipped, so one broken dataset does not abort the others. Results
    are appended to
    ``./results_retrieval/<model_arch_name>/<model_name>/retrieval.txt`` as
    `` <dataset>/<metric>: <value>`` lines, followed by a blank line.
    """
    import traceback

    acc_dict = {}
    for dataset_name in ['RSICD', 'RSITMD', 'ucmcaptions', 'SkyScript']:
        try:
            acc_dict[dataset_name] = run(
                model_arch_name=model_arch_name,
                pretrained=pretrained,
                dataset_name=dataset_name,
                force_quick_gelu=force_quick_gelu,
                long_clip=long_clip,
            )
        except Exception as e:
            # Best-effort by design: keep going with the remaining datasets,
            # but surface why this one failed instead of discarding it.
            print(f"Evaluate Dataset {dataset_name} failed: {e!r}")
            traceback.print_exc()

    # Save results
    save_dir = f'./results_retrieval/{model_arch_name}/{model_name}'
    os.makedirs(save_dir, exist_ok=True)
    output_file = os.path.join(save_dir, 'retrieval.txt')
    lines = [
        f" {dataset}/{metric_name}: {value}\n"
        for dataset, metrics in acc_dict.items()
        for metric_name, value in metrics.items()
    ]
    with open(output_file, "a", encoding="utf-8") as f:
        f.writelines(lines)
        f.write("\n")
def parse_args():
    """Parse the command-line options of the retrieval benchmark."""
    parser = argparse.ArgumentParser()
    # Optional checkpoint, forwarded as `pretrained`.
    parser.add_argument('--pretrained', type=str, default=None)
    for required_flag in ('--model-arch', '--model-name'):
        parser.add_argument(required_flag, type=str, required=True)
    for switch in ('--use-long-clip', '--force-quick-gelu'):
        parser.add_argument(switch, action='store_true')
    return parser.parse_args()
if __name__ == "__main__":
    # Script entry point: evaluate the model named on the command line.
    cli_args = parse_args()
    run_baseline(
        cli_args.model_arch,
        cli_args.model_name,
        cli_args.pretrained,
        long_clip=cli_args.use_long_clip,
        force_quick_gelu=cli_args.force_quick_gelu,
    )