|
|
import torch |
|
|
import os |
|
|
import numpy as np |
|
|
import faiss |
|
|
import open_clip |
|
|
import functools |
|
|
import re |
|
|
from tqdm import tqdm |
|
|
import ipdb |
|
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
|
|
|
def contains_special_characters(text):
    """Report whether *text* holds any non-ASCII character."""
    return re.search(r'[^\x00-\x7F]', text) is not None
|
|
|
|
|
def check_texts_for_special_characters(texts):
    """Return one report line per text that contains non-ASCII characters."""
    return [
        f"Text {i}: Contains special characters"
        for i, text in enumerate(texts)
        if contains_special_characters(text)
    ]
|
|
|
|
|
def clean_text(text):
    """Drop non-ASCII characters, collapse whitespace runs to one space, trim."""
    ascii_only = re.sub(r'[^\x00-\x7F]+', '', text)
    single_spaced = re.sub(r'\s+', ' ', ascii_only)
    return single_spaced.strip()
|
|
|
|
|
def clean_texts(texts):
    """Run clean_text over every entry of *texts* and return the new list."""
    cleaned = []
    for raw in texts:
        cleaned.append(clean_text(raw))
    return cleaned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def load_ori_query(coco_class_path):
    """Load class names (one per line) and build two prompt lists.

    Parameters
    ----------
    coco_class_path : str
        Path to a text file with one class name per line.

    Returns
    -------
    tuple[list[str], list[str]]
        1) each class prefixed with the article "a"/"an";
        2) each class phrased as "an image showing <cls>".
    """
    with open(coco_class_path, 'r') as file:
        # Skip blank lines: an empty string would crash item[0] below.
        coco_classes = [line.strip() for line in file if line.strip()]

    def add_article_to_classes(class_list):
        # Prefix "an" before a leading vowel, otherwise "a".
        result = []
        for item in class_list:
            article = "an" if item[0].lower() in 'aeiou' else "a"
            result.append(f"{article} {item}")
        return result

    a_cls_list = add_article_to_classes(coco_classes)

    an_image_showing_list = [f"an image showing {cls}" for cls in coco_classes]

    return a_cls_list, an_image_showing_list
|
|
|
|
|
|
|
|
|
|
|
def load_index(index_dir):
    """Load a FAISS index, its feature-transform chain, and the image ids.

    Returns (index, feat_transform, img_ids). feat_transform applies norm1
    and — when a pca.bin file exists next to the index — PCA followed by
    norm2.
    """
    print(os.getcwd())
    index = faiss.read_index(os.path.join(index_dir, 'faiss_IVPQ_PCA.index'))

    norm1 = faiss.read_VectorTransform(os.path.join(index_dir, 'norm1.bin'))
    pca_path = os.path.join(index_dir, 'pca.bin')
    do_pca = os.path.exists(pca_path)
    if do_pca:
        pca = faiss.read_VectorTransform(pca_path)
        norm2 = faiss.read_VectorTransform(os.path.join(index_dir, 'norm2.bin'))

    def feat_transform(x):
        # Normalise first; optionally project with PCA, then re-normalise.
        out = norm1.apply_py(x)
        if do_pca:
            out = norm2.apply_py(pca.apply_py(out))
        return out

    img_ids = np.load(os.path.join(index_dir, 'img_ids.npy'))

    return index, feat_transform, img_ids
|
|
|
|
|
|
|
|
def load_model(config_name, weight_path):
    """Build an open_clip model + tokenizer, move the model to the available
    device (fp32 on CPU), and switch it to eval mode."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, transform = open_clip.create_model_and_transforms(config_name, pretrained=weight_path)
    tokenizer = open_clip.get_tokenizer(config_name)

    # Half-precision weights are not usable on CPU, so cast to float there.
    if device == 'cpu':
        model = model.float()
    model = model.to(device)
    model.eval()
    return model, tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_text_list_feature(query_list, ai_config, weight_path):
    '''
    query_list: n classes, each class has k queries !
    '''

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model(ai_config, weight_path)

    # One token batch per class.
    tokens_per_class = [tokenizer(queries).to(device) for queries in query_list]

    with torch.no_grad():
        feats = [model.encode_text(tokens) for tokens in tokens_per_class]

    return [feat.cpu().numpy() for feat in feats]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_text_feature(query_list, ai_config, weight_path):
    '''
    query_list: n queries !
    '''

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, tokenizer = load_model(ai_config, weight_path)

    tokens = tokenizer(query_list).to(device)

    total = tokens.shape[0]
    batch_size = 1000

    # Encode in chunks so large query sets fit in memory.
    chunks = []
    with torch.no_grad():
        for start in tqdm(range(0, total, batch_size)):
            chunks.append(model.encode_text(tokens[start:start + batch_size]))

    text_feats = torch.cat(chunks, dim=0)

    # Release the model's GPU memory before returning the features.
    del model
    torch.cuda.empty_cache()

    return text_feats.cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def print_scores(aesthetics, faiss_smi):
    """Print per-completion and overall aesthetic / similarity statistics."""
    aes = np.array(aesthetics)
    smi = np.array(faiss_smi)

    # Per-completion means (column-wise), rounded for display.
    per_completion_aes = np.around(aes.mean(axis=0), decimals=3)
    per_completion_smi = np.around(smi.mean(axis=0), decimals=3)

    print("avg aesthetics for each completion:", ' '.join(map(str, per_completion_aes)))
    print("avg aesthetics over all images: {:.3f}".format(aes.mean()))
    print("std aesthetics over all images: {:.3f}".format(aes.std()))
    print("avg similarities for each completion:", ' '.join(map(str, per_completion_smi)))
    print("avg similarities over all images: {:.3f}".format(smi.mean()))
    print("std similarities over all images: {:.3f}".format(smi.std()))
    print("---------------------------------------------------------------------------")
|
|
|
|
|
|
|
|
|
|
|
def print_scores_iqa(aesthetics, faiss_smi, iqas):
    """Print per-completion and overall aesthetic, similarity, and IQA stats."""
    aes = np.array(aesthetics)
    smi = np.array(faiss_smi)
    iqa = np.array(iqas)

    # Per-completion means (column-wise), rounded for display.
    per_completion_aes = np.around(aes.mean(axis=0), decimals=3)
    per_completion_smi = np.around(smi.mean(axis=0), decimals=3)
    per_completion_iqa = np.around(iqa.mean(axis=0), decimals=3)

    print("avg aesthetics for each completion:", ' '.join(map(str, per_completion_aes)))
    print("avg aesthetics over all images: {:.3f}".format(aes.mean()))
    print("std aesthetics over all images: {:.3f}".format(aes.std()))
    print("avg similarities for each completion:", ' '.join(map(str, per_completion_smi)))
    print("avg similarities over all images: {:.3f}".format(smi.mean()))
    print("std similarities over all images: {:.3f}".format(smi.std()))

    print("avg IQA for each completion:", ' '.join(map(str, per_completion_iqa)))
    print("avg IQA over all images: {:.3f}".format(iqa.mean()))
    print("std IQA over all images: {:.3f}".format(iqa.std()))
    print("---------------------------------------------------------------------------")
|
|
|
|
|
|
|
|
|
|
|
def get_scores(img_list, dis_list, loaded_data, img_ids):
    """Average aesthetic scores and FAISS distances over retrieved images.

    Parameters
    ----------
    img_list : list
        Per class, a 2-D array of retrieved indices (completions x k).
    dis_list : list
        Matching FAISS distances, same nesting as *img_list*.
    loaded_data : dict
        Must hold "aesthetics_score" (1-D torch tensor) and "strImagehash"
        (hash list aligned with aesthetics_score).
    img_ids : sequence
        Maps a retrieved index to an image hash.

    Returns
    -------
    (aesthetics, faiss_smi, img_hash_list)
        aesthetics: (classes x completions) tensor of mean aesthetic scores;
        faiss_smi: matching tensor of mean distances;
        img_hash_list: retrieved hashes, same nesting as *img_list*.
    """

    aesthetics_score = loaded_data["aesthetics_score"]
    strImagehash = loaded_data["strImagehash"]

    # Translate retrieved indices into image hashes.
    img_hash_list = []
    for imgs in img_list:
        img_hash = [[img_ids[idx] for idx in img] for img in imgs]
        img_hash_list.append(img_hash)

    # Hoist the hash -> position lookup out of the loops: list.index() made
    # every lookup O(n). setdefault keeps the FIRST occurrence, matching
    # list.index semantics for duplicate hashes.
    hash_to_pos = {}
    for pos, s in enumerate(strImagehash):
        hash_to_pos.setdefault(s, pos)

    # Unknown hashes fall back to the global mean aesthetic score.
    mean_score = aesthetics_score.mean()

    aesthetics = []
    for each_class in img_hash_list:
        avg_aesthetic = []
        for each_completion in each_class:
            aes_score = torch.stack([
                aesthetics_score[hash_to_pos[s]] if s in hash_to_pos else mean_score
                for s in each_completion
            ])
            avg_aesthetic.append(aes_score.mean())
        aesthetics.append(torch.stack(avg_aesthetic))
    aesthetics = torch.stack(aesthetics)

    faiss_smi = [[each_completion.mean() for each_completion in each_class] for each_class in dis_list]
    faiss_smi = torch.tensor(faiss_smi)

    return aesthetics, faiss_smi, img_hash_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_scores_prompt(img_list, dis_list, loaded_data, img_ids):
    """Collect per-image (unaveraged) aesthetic scores and FAISS distances.

    Parameters
    ----------
    img_list : list
        Per class, a 2-D array of retrieved indices (completions x k).
    dis_list : list
        Matching FAISS distances, same nesting as *img_list*.
    loaded_data : dict
        Must hold "aesthetics_score" (1-D torch tensor) and "strImagehash"
        (hash list aligned with aesthetics_score).
    img_ids : sequence
        Maps a retrieved index to an image hash.

    Returns
    -------
    (aesthetics_all, faiss_smi)
        aesthetics_all: (classes x completions x k) tensor of per-image
        scores; faiss_smi: tensor of the raw distances.
    """

    aesthetics_score = loaded_data["aesthetics_score"]
    strImagehash = loaded_data["strImagehash"]

    # Translate retrieved indices into image hashes.
    img_hash_list = []
    for imgs in img_list:
        img_hash = [[img_ids[idx] for idx in img] for img in imgs]
        img_hash_list.append(img_hash)

    # Hoist the hash -> position lookup out of the loops: list.index() made
    # every lookup O(n). setdefault keeps the FIRST occurrence, matching
    # list.index semantics for duplicate hashes.
    hash_to_pos = {}
    for pos, s in enumerate(strImagehash):
        hash_to_pos.setdefault(s, pos)

    # Unknown hashes fall back to the global mean aesthetic score.
    mean_score = aesthetics_score.mean()

    aesthetics_all = []
    for each_class in img_hash_list:
        aesthetic = []
        for each_completion in each_class:
            aes_score = torch.stack([
                aesthetics_score[hash_to_pos[s]] if s in hash_to_pos else mean_score
                for s in each_completion
            ])
            aesthetic.append(aes_score)
        aesthetics_all.append(torch.stack(aesthetic))
    aesthetics_all = torch.stack(aesthetics_all)

    faiss_smi = [[each_completion for each_completion in each_class] for each_class in dis_list]
    faiss_smi = torch.tensor(faiss_smi)

    return aesthetics_all, faiss_smi
|
|
|
|
|
|
|
|
|
|
|
def image_retrive(sear_k, index, q_feats, loaded_data, img_ids):
    """Search the FAISS index with every query feature, score the hits,
    print summary statistics, and return hashes plus raw distances."""
    img_list, dis_list = [], []
    for q_feat in q_feats:
        D, I = index.search(q_feat, sear_k)
        img_list.append(I)
        dis_list.append(D)

    aesthetics, faiss_smi, img_hash_list = get_scores(img_list, dis_list, loaded_data, img_ids)

    print_scores(aesthetics, faiss_smi)
    return img_hash_list, dis_list
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def image_retrive_prompt(sear_k, index, q_feats, loaded_data, img_ids):
    """Search the FAISS index per query feature and return squeezed
    per-image aesthetic scores and similarities.

    Returns (aesthetics, faiss_smi) with singleton dimensions removed.
    """
    img_list = []
    dis_list = []
    for q_feat in q_feats:
        D, I = index.search(q_feat, sear_k)
        img_list.append(I)
        dis_list.append(D)
    # Removed a leftover ipdb.set_trace() debugger breakpoint that halted
    # every call here.

    aesthetics, faiss_smi = get_scores_prompt(img_list, dis_list, loaded_data, img_ids)
    return aesthetics.squeeze().squeeze(), faiss_smi.squeeze().squeeze()
|
|
|
|
|
|
|
|
|
|
|
def get_faiss_sim(sear_k, index, q_feats, img_ids, use_gpu):
    """Batched FAISS search over all query features.

    Parameters
    ----------
    sear_k : int
        Number of neighbours per query.
    index : faiss index
        Search backend; moved to GPU 0 when *use_gpu* is true.
    q_feats : np.ndarray
        (num_queries, dim) query feature matrix.
    img_ids : np.ndarray
        Maps retrieved indices to image hashes.
    use_gpu : bool
        Run the search on GPU 0 when True.

    Returns
    -------
    (faiss_smi, img_hash_list)
        faiss_smi: concatenated squeezed distances as a torch tensor;
        img_hash_list: per-batch arrays of retrieved image hashes.
    """

    if use_gpu:
        res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(res, 0, index)

    num = q_feats.shape[0]
    batch_size = 100000

    img_hash_list = []
    faiss_smi = []

    for i in tqdm(range(0, num, batch_size)):
        D, I = index.search(q_feats[i:i + batch_size], sear_k)
        img_hash_list.append(img_ids[I.squeeze()])
        faiss_smi.append(torch.from_numpy(D.squeeze()))

    faiss_smi = torch.cat(faiss_smi, dim=0)

    # NOTE: an unreachable single-shot search that followed this return in
    # the original (dead code after an unconditional return) was removed.
    return faiss_smi, img_hash_list
|
|
|