from PIL import Image
import os
from os.path import join
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
from collections import defaultdict
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
from open_clip import create_model_and_transforms, get_tokenizer
from open_clip_train.precision import get_autocast
# Please specify the paths to the CSV files of text annotations.
DATA_CSV_PATH_DICT = {
'SkyScript': '/path/to/your/skyscript/SkyScript_test_30K_filtered_by_CLIP_openai.csv',
'RSICD': '/path/to/your/retrieval/RSICD/RSICD_img_txt_pairs_test.csv',
'RSITMD': '/path/to/your/retrieval/RSITMD/RSITMD_img_txt_pairs_test.csv',
'ucmcaptions': '/path/to/your/retrieval/ucmcaptions/ucmcaptions_img_txt_pairs_test.csv',
}
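# Each CSV is expected to provide a 'filepath' column (image path, resolved against the
# root dirs below) and a caption column: 'title_multi_objects' for SkyScript and 'title'
# for the other datasets. Illustrative row (placeholder values):
#   filepath,title
#   RSICD_images/airport_1.jpg,many planes are parked near the terminal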
# Paths to the image root directories (RSICD, RSITMD, and ucmcaptions share the same root dir).
SKYSCRIPT_IMAGE_DIR = '/path/to/your/skyscript/'
RETRIEVAL_IMAGE_DIR = '/path/to/your/retrieval/'
batch_size = 128
precision = 'amp'
autocast = get_autocast(precision)
class CsvDataset_customized(Dataset):
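    """Paired image-caption dataset read from a dataframe (not used by this script's run())."""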
def __init__(self, df, transforms, img_key, caption_key, tokenizer=None, return_img_path=False,
root_data_dir=None):
if root_data_dir is not None:
df[img_key] = df[img_key].apply(lambda x: join(root_data_dir, x))
self.images = df[img_key].tolist()
self.captions = df[caption_key].tolist()
self.transforms = transforms
self.tokenize = tokenizer
self.return_img_path = return_img_path
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
texts = self.tokenize([str(self.captions[idx])])[0]
if self.return_img_path:
return images, texts, str(self.images[idx])
return images, texts
class CsvDataset_image(Dataset):
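    """Image-only dataset over unique image paths; optionally returns the path alongside the tensor."""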
def __init__(self, df, transforms, img_key, return_img_path=False, root_data_dir=None):
if root_data_dir is not None:
df[img_key] = df[img_key].apply(lambda x: join(root_data_dir, x))
self.images = df[img_key].tolist()
self.transforms = transforms
self.return_img_path = return_img_path
def __len__(self):
return len(self.images)
def __getitem__(self, idx):
images = self.transforms(Image.open(str(self.images[idx])))
if self.return_img_path:
return images, str(self.images[idx])
return images
class CsvDataset_text(Dataset):
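    """Text-only dataset over unique captions, tokenized with the standard or Long-CLIP context length."""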
def __init__(self, df, caption_key, tokenizer=None, return_original_text=False, root_data_dir=None, long_clip='disable'):
self.captions = df[caption_key].tolist()
self.tokenize = tokenizer
self.return_original_text = return_original_text
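        # Long-CLIP extends the text context window to 248 tokens; standard CLIP uses 77.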
self.context_length = 248 if long_clip != 'disable' else 77
def __len__(self):
return len(self.captions)
def __getitem__(self, idx):
original_text = str(self.captions[idx])
texts = self.tokenize([original_text], context_length=self.context_length)[0]
if self.return_original_text:
return texts, original_text
return texts
def random_seed(seed=42, rank=0):
torch.manual_seed(seed + rank)
np.random.seed(seed + rank)
random.seed(seed + rank)
def run(model_arch_name, pretrained, dataset_name, force_quick_gelu=False, long_clip=False):
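    # long_clip is a string mode in this customized open_clip build ('disable' or 'load_from_scratch').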
long_clip = 'load_from_scratch' if long_clip else 'disable'
device = "cuda" if torch.cuda.is_available() else "cpu"
random_seed(42, 0)
# Load model
model, _, preprocess_val = create_model_and_transforms(
model_arch_name,
pretrained,
precision=precision,
device=device,
output_dict=True,
force_quick_gelu=force_quick_gelu,
long_clip=long_clip
)
tokenizer = get_tokenizer(model_arch_name)
model.eval()
# Load data
data_csv_path = DATA_CSV_PATH_DICT[dataset_name]
if dataset_name == 'SkyScript':
caption_key = 'title_multi_objects'
ROOT_DATA_DIR = SKYSCRIPT_IMAGE_DIR
else:
caption_key = 'title'
ROOT_DATA_DIR = RETRIEVAL_IMAGE_DIR
df = pd.read_csv(data_csv_path)
df['filepath'] = df['filepath'].apply(lambda x: join(ROOT_DATA_DIR, x))
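    # Prepend a short domain prompt to the captions of the three captioning benchmarks.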
if dataset_name in ['RSICD', 'RSITMD', 'ucmcaptions']:
df[caption_key] = df[caption_key].apply(lambda x: 'a satellite image. ' + x)
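    # Deduplicate so each unique image and each unique caption is encoded exactly once.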
df_image = df.groupby('filepath').count().reset_index()
df_text = df.groupby(caption_key).count().reset_index()
# Extract image features
dataset_image = CsvDataset_image(
df=df_image,
transforms=preprocess_val,
img_key='filepath',
return_img_path=True,
)
dataloader = DataLoader(dataset_image, batch_size=batch_size, shuffle=False, num_workers=4)
all_image_features = []
all_image_paths = []
with torch.no_grad():
for batch in tqdm(dataloader, unit_scale=batch_size):
images, img_paths = batch
images = images.to(device=device)
with autocast():
image_features = model.encode_image(images, normalize=True)
all_image_features.append(image_features.cpu())
all_image_paths.extend(img_paths)
all_image_features = torch.cat(all_image_features)
# Extract text features
dataset_text = CsvDataset_text(
df=df_text,
caption_key=caption_key,
tokenizer=tokenizer,
return_original_text=True,
long_clip=long_clip
)
dataloader = DataLoader(dataset_text, batch_size=batch_size, shuffle=False, num_workers=4)
all_text_features = []
all_texts = []
with torch.no_grad():
for batch in tqdm(dataloader, unit_scale=batch_size):
texts, original_texts = batch
texts = texts.to(device=device)
with autocast():
text_features = model.encode_text(texts, normalize=True)
all_text_features.append(text_features.cpu())
all_texts.extend(original_texts)
all_text_features = torch.cat(all_text_features)
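    # Map each unique caption / image path back to its row index in the feature matrices.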
text_indices = {x: i for i, x in enumerate(all_texts)}
img_indices = {x: i for i, x in enumerate(all_image_paths)}
    # ground truth: each image may match multiple captions and vice versa
    img_path2text = defaultdict(set)
    text2img_path = defaultdict(set)
    for i in tqdm(df.index):
        text = df.loc[i, caption_key]
        img_path = df.loc[i, 'filepath']
        img_path2text[img_path].add(text_indices[text])
        text2img_path[text].add(img_indices[img_path])
res = {'text2img_R@' + str(k): 0 for k in [1, 5, 10, 100]}
res.update({'img2text_R@' + str(k): 0 for k in [1, 5, 10, 100]})
# text to image
    logit_scale = 100  # positive scaling is monotonic, so it does not affect the ranking
for i in tqdm(range(len(all_texts))):
text_feature = all_text_features[i]
logits = logit_scale * text_feature @ all_image_features.t()
ranking = torch.argsort(logits, descending=True).cpu().numpy()
for k in [1, 5, 10, 100]:
            intersec = set(ranking[:k]) & text2img_path[all_texts[i]]
if intersec:
res['text2img_R@' + str(k)] += 1
for k in [1, 5, 10, 100]:
res['text2img_R@' + str(k)] /= len(all_texts)
res['text2img_mean'] = (res['text2img_R@1'] + res['text2img_R@5'] + res['text2img_R@10']) / 3
# image to text
logit_scale = 100
for i in tqdm(range(len(all_image_paths))):
image_feature = all_image_features[i]
logits = logit_scale * image_feature @ all_text_features.t()
ranking = torch.argsort(logits, descending=True).cpu().numpy()
for k in [1, 5, 10, 100]:
intersec = set(ranking[:k]) & img_path2text[all_image_paths[i]]
if intersec:
res['img2text_R@' + str(k)] += 1
for k in [1, 5, 10, 100]:
res['img2text_R@' + str(k)] /= len(all_image_paths)
res['img2text_mean'] = (res['img2text_R@1'] + res['img2text_R@5'] + res['img2text_R@10']) / 3
    return res
def run_baseline(model_arch_name, model_name, pretrained, force_quick_gelu=False, long_clip=False):
acc_dict = {}
for dataset_name in ['RSICD', 'RSITMD', 'ucmcaptions', 'SkyScript']:
try:
res = run(
model_arch_name=model_arch_name,
pretrained=pretrained,
dataset_name=dataset_name,
force_quick_gelu=force_quick_gelu,
long_clip=long_clip
)
acc_dict[dataset_name] = res
        except Exception as e:
            print(f"Evaluating dataset {dataset_name} failed: {e}")
# Save results
save_dir = f'./results_retrieval/{model_arch_name}/{model_name}'
os.makedirs(save_dir, exist_ok=True)
    output_file = os.path.join(save_dir, 'retrieval.txt')
    log_dict = {}
    metric_accum = defaultdict(list)
    for dataset, metrics in acc_dict.items():
        for metric_name, value in metrics.items():
            log_dict[f"{dataset}/{metric_name}"] = value
            metric_accum[metric_name].append(value)
    # Average each metric across the evaluated datasets.
    for metric_name, values in metric_accum.items():
        log_dict[f"avg/{metric_name}"] = sum(values) / len(values)
with open(output_file, "a") as f:
for k, v in log_dict.items():
f.write(f" {k}: {v}\n")
f.write("\n")
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--pretrained', type=str, default=None)
parser.add_argument('--model-arch', type=str, required=True)
parser.add_argument('--model-name', type=str, required=True)
parser.add_argument('--use-long-clip', action='store_true')
parser.add_argument('--force-quick-gelu', action='store_true')
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
run_baseline(
model_arch_name=args.model_arch,
model_name=args.model_name,
pretrained=args.pretrained,
long_clip=args.use_long_clip,
force_quick_gelu=args.force_quick_gelu
)
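# Example invocation (script filename, model names, and checkpoint path are placeholders):
#   python this_script.py --model-arch ViT-L-14 --model-name my_run --pretrained /path/to/checkpoint.pt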