Spaces:
Runtime error
Runtime error
| # %% | |
| import json | |
| import sys | |
| import pickle | |
| sys.path.append("../") | |
| import collections | |
| from models.fused_model import Model | |
| import os | |
| import tqdm | |
| import time | |
| import json | |
| import random | |
| from PIL import ImageFile | |
| from PIL import Image, ImageDraw | |
| import clip | |
| import torch | |
| import numpy as np | |
| import torchvision.transforms as T | |
| import torchvision.transforms.functional as F | |
| from pathlib import Path | |
| import pandas as pd | |
# Let PIL decode image files whose data is cut short instead of raising an error.
ImageFile.LOAD_TRUNCATED_IMAGES = True
# %%
from types import SimpleNamespace
# get config
import os
from omegaconf import OmegaConf
from hydra.core.global_hydra import GlobalHydra
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
# ROOT: directory containing this file (symlinks resolved). The evaluation code
# below builds its annotation, checkpoint and output paths from it.
# NOTE(review): `__file__` is undefined when this runs as an interactive cell.
os.environ['ROOT'] = os.path.dirname(os.path.realpath(__file__))
# DATA_ROOT: the "data" sub-directory of ROOT.
os.environ['DATA_ROOT'] = os.path.join(os.environ['ROOT'], 'data')
# initialize hydra config
# Clear any existing global Hydra instance first, presumably so that re-running
# this cell does not fail on an already-initialized Hydra.
GlobalHydra.instance().clear()
initialize(config_path="./config")
# Compose the "with_decoder" config with the ViT-L/14@336px CLIP backbone and
# rationale type 0 for both train and val.
# NOTE(review): in the visible code this module-level `config` is not read by
# main(), which loads its own config from results/config.yaml — confirm whether
# it is still needed.
config = compose(config_name='with_decoder.yaml',
                 overrides=["clip_model=ViT-L/14@336px",
                            "rationale_type=0", "val_rationale_type=0"])
class SquarePad:
    """Pad a PIL image with zeros so that it becomes square.

    The shorter side is padded on both ends so the content stays centred; when
    the deficit is odd, the extra pixel goes to the right/bottom edge.
    """

    def __call__(self, image):
        width, height = image.size
        side = max(width, height)
        left = (side - width) // 2
        top = (side - height) // 2
        right = side - width - left
        bottom = side - height - top
        return F.pad(image, (left, top, right, bottom), 0, 'constant')
class VarDatasetForAuxEncoders:
    """Map-style dataset of annotated images paired with rationale text.

    The annotation JSON maps each sample key to a record whose ``details`` list
    holds successive label versions; only the last version is used. Each item
    combines the image (with the annotated entity boxes hidden or highlighted
    according to ``config.hide_true_bbox``), the tokenized rationale, and an
    image-text-matching (ITM) text/label pair.
    """

    def __init__(self, config, file_path, split="train", mode="combined", do_swap=False, tensorize=True, do_crop=True):
        """
        Args:
            config: config object; reads ``rationale_type`` / ``val_rationale_type``,
                ``root``, ``overfit``, ``widescreen_processing``, ``img_size``,
                ``hide_true_bbox`` and, optionally, ``no_hard_negative_itm``.
            file_path: annotation JSON to load.
            split: ``"train"`` enables random crop and colour jitter; any other
                value uses a centre crop and no jitter.
            mode: annotation mode; only used for the overfit file name and ``__len__``.
            do_swap: randomly permute the entity slots of every sample.
            tensorize: finish the image pipeline with ToTensor + Normalize.
            do_crop: accepted for compatibility; not used.
        """
        self.config = config
        self.mode = mode
        self.split = split
        self.do_swap = do_swap
        self.rationale_type = config.rationale_type if split == "train" else config.val_rationale_type
        self.root_path = Path(config.root)
        self.anno_path = file_path
        if split == "test" and mode == "combined" and config.overfit:
            self.anno_path = self.root_path / f'annotations/13_05/anno_{split}_{mode}_overfit.json'
        with open(self.anno_path, encoding="utf-8") as anno_file:
            self.data = json.load(anno_file)
        self.idx2name = list(self.data.keys())
        # Entity-count buckets (used to draw hard ITM negatives in get_itm_text)
        # only exist when the annotations carry bounding boxes.
        if self.idx2name and 'bounding_box' in self.data[self.idx2name[0]]['details'][-1]:
            def keys_with_n_boxes(n):
                return [k for k, v in self.data.items() if len(v['details'][-1]["bounding_box"]) == n]

            self.one_ent_keys = keys_with_n_boxes(1)
            self.two_ent_keys = keys_with_n_boxes(2)
            self.three_ent_keys = keys_with_n_boxes(3)
            self.all_ent_keys = self.one_ent_keys + self.two_ent_keys + self.three_ent_keys
            self.keys = {1: self.one_ent_keys, 2: self.two_ent_keys, 3: self.three_ent_keys}
        # widescreen_processing 0/1 resizes directly; any other value pads to a square first.
        padding = self.config.widescreen_processing not in [0, 1]
        self.resize_crop = self.get_transform(config.img_size, split == "train", padding=padding)
        self.tensorize = tensorize
        self.jitter_transform = T.ColorJitter(brightness=.5, hue=.3, saturation=.3) if split == "train" else lambda x: x
        # RGB conversion, then (when tensorize) ToTensor + normalization with the
        # mean/std constants used by OpenAI CLIP's preprocessing.
        self.final_transform = T.Compose([
            lambda image: image.convert("RGB"),
            T.ToTensor() if tensorize else lambda x: x,
            T.Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ) if tensorize else lambda x: x,
        ])

    def get_transform(self, n_px, training, padding=False):
        """Build the resize-and-crop pipeline that yields ``n_px`` x ``n_px`` images.

        Images are resized to ``n_px + 16`` on each side, then randomly cropped
        when *training* or centre-cropped otherwise. With *padding* they are
        first zero-padded to a square so the content's aspect ratio is kept.
        """
        resize = T.Resize((n_px + 16, n_px + 16), interpolation=Image.BICUBIC)
        crop = T.RandomCrop(n_px) if training else T.CenterCrop(n_px)
        steps = [SquarePad(), resize, crop] if padding else [resize, crop]
        return T.Compose(steps)

    def key2img_path(self, key):
        """Return the first existing image file for *key* under ``root_path``.

        Raises:
            FileNotFoundError: if none of the known layouts contains the image.
        """
        # Folder used by the img/{train,val,test}/<prefix>/ layouts: the key up to its first "_".
        prefix = key.split('_')[0]
        candidates = [
            f"var_images/{key}.jpg",
            f"var_images/{key}.png",
            f"images/{key}.jpg",
            f"img/train/{prefix}/{key}.png",
            f"img/val/{prefix}/{key}.png",
            f"img/test/{prefix}/{key}.png",
            f"img/{key}.png",
            f"img/{key}.jpg",
            f"images/{key}.png",
        ]
        for relative in candidates:
            file_path = self.root_path / relative
            if file_path.exists():
                return file_path
        raise FileNotFoundError(f"No image found for key {key!r} under {self.root_path}")

    def key2img(self, key):
        """Open the image for *key* with PIL."""
        file_path = self.key2img_path(key)
        return Image.open(file_path)

    def hide_region(self, image, bboxes):
        """Mark the entity boxes on *image* according to ``config.hide_true_bbox``.

        The image is converted to RGBA. Modes:
          * 1: boxes painted opaque grey directly on the image ("hide").
          * 2, 5, 7, 8, 9: boxes tinted with a translucent per-entity colour and
            outlined in green ("highlight").
          * 3: everything outside the boxes covered with opaque grey ("blackout").
          * 6: grey overlay whose boxes are filled with a translucent per-entity
            colour ("position only").
          * any other value: image returned unchanged apart from RGBA conversion.

        *bboxes* is a list indexed by entity slot; ``None`` entries are skipped,
        others are dicts with ``left``, ``top``, ``width`` and ``height``.
        """
        mode = self.config.hide_true_bbox
        image = image.convert('RGBA')
        overlay = None
        draw = None
        if mode == 1:  # hide mode: draw straight onto the image
            draw = ImageDraw.Draw(image, 'RGBA')
        elif mode in [2, 5, 7, 8, 9]:  # highlight mode: fully transparent overlay
            overlay = Image.new('RGBA', image.size, '#00000000')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        elif mode in [3, 6]:  # blackout / position-only mode: opaque grey overlay
            overlay = Image.new('RGBA', image.size, '#7B7575ff')
            draw = ImageDraw.Draw(overlay, 'RGBA')
        if draw is None:
            return image
        # Pink, cyan, yellow, each with alpha 0x3c; one colour per entity slot.
        color_fill_list = ['#ff05cd3c', '#00F1E83c', '#F2D4003c']
        for idx, bbox in enumerate(bboxes):
            if bbox is None:
                continue
            color_fill = color_fill_list[idx]
            x, y = bbox['left'], bbox['top']
            corners = [(x, y), (x + bbox['width'], y + bbox['height'])]
            if mode == 1:
                draw.rectangle(corners, fill='#7B7575')
            elif mode in [2, 5, 7, 8, 9]:
                draw.rectangle(corners, fill=color_fill, outline='#05ff37ff', width=3)
            elif mode == 3:
                # Punch a transparent hole in the grey overlay so the box content shows through.
                draw.rectangle(corners, fill='#00000000')
            elif mode == 6:
                draw.rectangle(corners, fill=color_fill)
        if overlay is not None:
            image = Image.alpha_composite(image, overlay)
        return image

    def get_entity_codes(self):
        """Return the entity-slot permutation: identity, or a random shuffle when ``do_swap``."""
        entity_codes = [0, 1, 2]
        if self.do_swap:
            random.shuffle(entity_codes)
        return entity_codes

    def swap_entities(self, bboxes, text, entity_codes):
        """Permute the entity slots of *bboxes* and relabel ``Entity #k`` in *text* to match.

        New slot ``i`` receives the old box ``bboxes[entity_codes[i]]``
        (e.g. ``[1, 0, 2]`` -> ``[b[1], b[0], b[2]]``). A mention of old entity
        ``j`` is renamed to its new slot ``i`` (where ``entity_codes[i] == j``),
        so text and boxes keep referring to the same object. All mentions are
        rewritten in one pass, so a freshly written label is never re-replaced.
        """
        import re

        new_slot = {old: new for new, old in enumerate(entity_codes)}

        def relabel(match):
            old = int(match.group(1)) - 1
            if old not in new_slot:
                return match.group(0)
            return f"Entity #{new_slot[old] + 1}"

        text = re.sub(r"Entity #(\d+)", relabel, text)
        new_boxes = [bboxes[entity_code] for entity_code in entity_codes]
        return new_boxes, text

    def get_text_from_meta(self, meta):
        """Build ``'Rationale: ...'`` text from a label record.

        For rationale types 1 and 2, the first mention of each ``Entity #k``
        (k up to the number of boxes) is expanded to ``'Entity #k, <description>'``
        using ``meta['Entity #k']``.
        """
        n_boxes = len(meta['bounding_box'])  # keys '1', '2', '3'
        text = 'Rationale: ' + str(meta['rationale'])
        if self.rationale_type == 1 or self.rationale_type == 2:
            for box_idx in range(n_boxes):
                ent_name = f'Entity #{box_idx + 1}'
                ent_desc = f'{ent_name}, {meta[ent_name]}'
                # todo: replace randomly
                text = text.replace(ent_name, ent_desc, 1)
        return text

    def get_itm_text(self, ori_file_key):
        """Pick an image-text-matching text for *ori_file_key* and its label.

        With probability 0.5 the text comes from a random sample with the same
        number of boxes (a hard negative), or from any boxed sample when
        ``no_hard_negative_itm`` is set. Otherwise the sample's own text is used.
        The label is 1 exactly when the chosen key is the original one, which
        includes a random pick that happens to land on the sample itself.
        """
        file_key = ori_file_key
        if random.random() < 0.5:
            n_boxes = len(self.data[file_key]['details'][-1]['bounding_box'])
            file_key = random.choice(self.keys[n_boxes])
            if self.config.get('no_hard_negative_itm', False):
                file_key = random.choice(self.all_ent_keys)
        itm_label = 1 if file_key == ori_file_key else 0
        meta = self.data[file_key]['details'][-1]
        itm_text = self.get_text_from_meta(meta)
        return itm_text, itm_label

    def get_bboxes_and_text(self, file_key, meta):
        """Return the (possibly entity-permuted) boxes, rationale text and ITM pair for a sample."""
        text = self.get_text_from_meta(meta)
        bboxes = [meta['bounding_box'].get(str(box_idx + 1), None) for box_idx in range(3)]
        entity_codes = self.get_entity_codes()
        bboxes, text = self.swap_entities(bboxes, text, entity_codes)
        itm_text, itm_label = self.get_itm_text(file_key)
        # Relabel the ITM text with the same permutation as the main text.
        _, itm_text = self.swap_entities([None, None, None], itm_text, entity_codes)
        return {'bboxes': bboxes, 'text': text, 'itm_text': itm_text, 'itm_label': itm_label}

    def get_image(self, file_key, bboxes):
        """Load, jitter, mark boxes, resize/crop and finalize the image for *file_key*."""
        image = self.key2img(file_key)
        image = self.jitter_transform(image)
        image = self.hide_region(image, bboxes)
        image = self.final_transform(self.resize_crop(image))
        return image

    def __getitem__(self, idx):
        file_key = self.idx2name[idx]
        # Select the last version of label of the sample
        meta = self.data[file_key]['details'][-1]
        # read bboxes and rationale
        outputs = self.get_bboxes_and_text(file_key, meta)
        text = clip.tokenize(outputs['text'], truncate=True).squeeze()
        itm_text = clip.tokenize(outputs['itm_text'], truncate=True).squeeze()
        itm_label = torch.tensor(outputs['itm_label'])
        image = self.get_image(file_key, outputs['bboxes'])
        # NOTE(review): 'raw_text' carries the same token tensor as 'caption', not
        # the untokenized string outputs['text'] — confirm what consumers expect.
        return {'image': image, 'caption': text, 'raw_text': text, 'file_key': file_key, 'itm_text': itm_text, 'itm_label': itm_label}

    def __len__(self):
        # Overfit runs use only the first 16 samples, except for the dedicated
        # test/combined overfit annotation file.
        if self.config.overfit and not (self.split == 'test' and self.mode == 'combined'):
            return 16
        return len(self.data)
| # %% | |
class VarDatasetImageOnly(VarDatasetForAuxEncoders):
    """Image-only view of the dataset: yields the rendered image for each key.

    Boxes from the latest label version are placed into slots 1-3 (missing slots
    are ``None``) and permuted with ``get_entity_codes`` before rendering.
    """

    def __init__(self, args, file_path, split="val", mode="combined", do_swap=False):
        super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)

    def __getitem__(self, idx):
        key = self.idx2name[idx]
        latest = self.data[key]['details'][-1]
        boxes_by_slot = latest['bounding_box']
        slotted = [boxes_by_slot.get(str(slot), None) for slot in (1, 2, 3)]
        permuted = [slotted[code] for code in self.get_entity_codes()]
        rendered = self.get_image(key, permuted)
        return {'image': rendered, 'file_key': key}
| # %% | |
class VarDatasetTextOnly(VarDatasetForAuxEncoders):
    """Text-only view of the dataset: yields the tokenized hazard text for each key."""

    def __init__(self, args, file_path, split="val", mode="combined", do_swap=False):
        super().__init__(args, file_path, split=split, mode=mode, do_swap=do_swap)

    def __getitem__(self, idx):
        """Return ``{'caption': token ids, 'file_key': key}`` for sample *idx*.

        The text is ``'Rationale: ' + meta['hazard']`` from the latest label
        version. For rationale types 1 and 2 each ``Entity #k`` mention is expanded
        with its description. Entity labels are then permuted with
        ``get_entity_codes`` in a single pass, so a freshly written label is never
        re-replaced. Old entity ``j`` becomes slot ``i`` where
        ``entity_codes[i] == j``, matching the box permutation used for images.
        """
        import re

        file_key = self.idx2name[idx]
        meta = self.data[file_key]['details'][-1]
        # The entity count is inferred from the highest "Entity #k" the hazard text mentions.
        if 'Entity #3' in meta['hazard']:
            n_boxes = 3
        elif 'Entity #2' in meta['hazard']:
            n_boxes = 2
        else:
            n_boxes = 1
        text = 'Rationale: ' + str(meta['hazard'])
        if self.rationale_type == 1 or self.rationale_type == 2:
            for box_idx in range(n_boxes):
                ent_name = f'Entity #{box_idx + 1}'
                ent_desc = f'{ent_name}, {meta[ent_name]}'
                # todo: replace randomly
                text = text.replace(ent_name, ent_desc, 1)
        entity_codes = self.get_entity_codes()
        new_slot = {old: new for new, old in enumerate(entity_codes)}

        def relabel(match):
            old = int(match.group(1)) - 1
            if old not in new_slot:
                return match.group(0)
            return f"Entity #{new_slot[old] + 1}"

        text = re.sub(r"Entity #(\d+)", relabel, text)
        text = clip.tokenize(text, truncate=True).squeeze()
        return {'caption': text, 'file_key': file_key}
| # %% | |
| import os | |
| import sys | |
| sys.path.append('..') | |
| import json | |
| import fire | |
| import tqdm | |
| import clip | |
| import torch | |
| import sklearn | |
| import numpy as np | |
| from omegaconf import OmegaConf | |
| from models.fused_model import Model | |
| from torch.utils.data import DataLoader | |
| # from datasets import VarDatasetForAuxEncoders | |
| from scipy.stats import rankdata | |
| from sklearn.metrics import ndcg_score | |
| from sklearn.metrics import pairwise_distances | |
| # def get_data_loader(config, split="test", mode="combined", do_swap=False): | |
| # dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap) | |
| # return DataLoader(dataset, batch_size=4, shuffle=False) | |
def get_image_data_loader(config, file_path, split="test", mode="combined", do_swap=False, batch_size=4):
    """Build a non-shuffled DataLoader over ``VarDatasetImageOnly(file_path)``.

    Args:
        config: dataset config forwarded to the dataset.
        file_path: annotation JSON listing the images.
        split, mode, do_swap: forwarded to the dataset.
        batch_size: samples per batch (default 4, as before).
    """
    dataset = VarDatasetImageOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
def get_text_data_loader(config, file_path, split="test", mode="combined", do_swap=False, batch_size=4):
    """Build a non-shuffled DataLoader over ``VarDatasetTextOnly(file_path)``.

    Args:
        config: dataset config forwarded to the dataset.
        file_path: annotation JSON listing the texts.
        split, mode, do_swap: forwarded to the dataset.
        batch_size: samples per batch (default 4, as before).
    """
    dataset = VarDatasetTextOnly(config, file_path, split=split, mode=mode, do_swap=do_swap)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)
| # def get_data_loader(config, split="test", mode="combined", do_swap=False): | |
| # dataset = VarDatasetForAuxEncoders(config, split=split, mode=mode, do_swap=do_swap) | |
| # return DataLoader(dataset, batch_size=4, shuffle=False) | |
def compute_rand_rank(split='test', mode='spec', img_token_dict=None, txt_token_dict=None):
    """Report mean retrieval ranks on the random-candidate benchmark.

    For every query key in
    ``$ROOT/data/annotations/13_05/anno_random_{split}_{mode}_ids.json`` the
    query's image (resp. text) embedding is ranked by cosine distance against
    the text (resp. image) embeddings of its listed candidates. Each list must
    have exactly 1000 candidates, and there must be exactly 1000 queries.

    Args:
        split, mode: select the annotation file.
        img_token_dict, txt_token_dict: map every key in the annotation file
            (queries and candidates) to a 1-D embedding tensor; per the original
            note they contain all 2000 test samples.

    Returns:
        dict mapping each query key to the ranks (1 = closest, ties averaged)
        of its candidates for image-to-text retrieval, for NDCG scoring.
    """
    # None defaults instead of shared mutable {} defaults.
    img_token_dict = {} if img_token_dict is None else img_token_dict
    txt_token_dict = {} if txt_token_dict is None else txt_token_dict
    anno_path = os.path.join(os.environ['ROOT'], f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")
    with open(anno_path, encoding="utf-8") as anno_file:
        data = json.load(anno_file)
    i2t_ranks = []
    t2i_ranks = []
    i2t_rank_dict = {}
    t2i_rank_dict = {}
    for file_key, candidate_keys in data.items():
        img_emb = img_token_dict[file_key].unsqueeze(0)
        txt_emb = txt_token_dict[file_key].unsqueeze(0)
        txt_embs = torch.stack([txt_token_dict[k] for k in candidate_keys])
        img_embs = torch.stack([img_token_dict[k] for k in candidate_keys])
        assert txt_embs.shape[0] == img_embs.shape[0] == 1000
        i2t_rank = rankdata(pairwise_distances(img_emb, txt_embs, metric='cosine', n_jobs=8), axis=1)[0]
        t2i_rank = rankdata(pairwise_distances(txt_emb, img_embs, metric='cosine', n_jobs=8), axis=1)[0]
        # NOTE(review): the query's own rank is read from the LAST candidate, which
        # presumes the ground-truth key is listed last — confirm against the annotation file.
        i2t_ranks.append(i2t_rank[-1])
        t2i_ranks.append(t2i_rank[-1])
        i2t_rank_dict[file_key] = i2t_rank
        t2i_rank_dict[file_key] = t2i_rank
    assert len(i2t_ranks) == len(t2i_ranks) == 1000
    print(f"Random split, mode={mode} i2t rank: ", sum(i2t_ranks) / len(i2t_ranks))
    print(f"Random split, mode={mode} t2i rank: ", sum(t2i_ranks) / len(t2i_ranks))
    return i2t_rank_dict  # for computing the NDCG scores
def read_relevance_scores(anno_path="anno_random_test_obvi_ids.json", gpt_path="chatgpt_similarity_score_test_direct_combined.json"):
    """Load GPT relevance scores and fill in every candidate the annotation lists.

    Args:
        anno_path: JSON mapping each query key to its list of candidate keys.
        gpt_path: JSON mapping query key -> {candidate key: relevance score}.

    Returns:
        The GPT score mapping, extended so that every (query, candidate) pair in
        *anno_path* has a score. Missing pairs get 0.0, and a query listed among its
        own candidates always gets 1.0 (the ground-truth match).
    """
    with open(gpt_path, encoding="utf-8") as gpt_file:
        gpt_scores = json.load(gpt_file)
    with open(anno_path, encoding="utf-8") as anno_file:
        data = json.load(anno_file)
    # add_missing_relevance_scores
    for k in tqdm.tqdm(data, total=len(data)):
        scores_for_k = gpt_scores.setdefault(k, {})
        for cand_key in data[k]:
            scores_for_k.setdefault(cand_key, 0.0)
        if k in data[k]:
            scores_for_k[k] = 1.0
    return gpt_scores
| # %% | |
def compute_ndcg(ranks, scores, k=3):
    """Score how well predicted ranks place the *k* most relevant candidates.

    ``ranks[i]`` is the predicted 1-based rank of candidate ``i`` and
    ``scores[i]`` its relevance. The *k* candidates with the highest relevance
    are taken, and ``sum(rel / log2(pred_rank + 1))`` is divided by the value the
    same candidates would get at ranks 1..k.

    Example::

        ranks = [5, 1, 4, 2, 3]
        scores = [0.1, 0.5, 0.3, 0.95, 1.0]

    Returns 0.0 when that ideal value is zero (empty input, ``k == 0`` or all
    relevances zero) instead of dividing by zero, following sklearn's
    ``ndcg_score`` convention for all-irrelevant samples.
    """
    top_k = sorted(zip(ranks, scores), key=lambda pair: pair[1], reverse=True)[:k]
    dcg = sum(score / np.log2(rank + 1) for rank, score in top_k)
    ideal_dcg = sum(score / np.log2(position + 2) for position, (_, score) in enumerate(top_k))
    if ideal_dcg == 0:
        return 0.0
    return dcg / ideal_dcg
def compute_ndcg_score_per_mode(pred_rank_dict, gpt_rel_scores, mode='spec', split='test', k=200):
    """Average ``compute_ndcg`` over all queries of one random-candidate split/mode.

    Args:
        pred_rank_dict: query key -> predicted ranks of its candidates, in the
            candidate order of the annotation file.
        gpt_rel_scores: query key -> {candidate key: relevance score}.
        mode, split: select
            ``$ROOT/data/annotations/13_05/anno_random_{split}_{mode}_ids.json``.
        k: number of most-relevant candidates considered per query.

    Returns:
        The mean NDCG over all queries in *pred_rank_dict* (also printed).

    Raises:
        ValueError: if *pred_rank_dict* is empty.
    """
    anno_path = os.path.join(os.environ['ROOT'], f"data/annotations/13_05/anno_random_{split}_{mode}_ids.json")
    with open(anno_path, encoding="utf-8") as anno_file:
        data = json.load(anno_file)
    # Named per_query_scores so it does not shadow sklearn's imported ndcg_score.
    per_query_scores = []
    for key in tqdm.tqdm(pred_rank_dict, total=len(pred_rank_dict)):
        relevance = [gpt_rel_scores[key][cand_key] for cand_key in data[key]]
        per_query_scores.append(compute_ndcg(pred_rank_dict[key], relevance, k=k))
    if not per_query_scores:
        raise ValueError("pred_rank_dict is empty; there is nothing to score")
    avg_ndcg_score = sum(per_query_scores) / len(per_query_scores)
    print(f"Random split, mode={mode} ndcg score: ", avg_ndcg_score)
    return avg_ndcg_score
| # %% | |
def _embed_all(forward, data_loader, input_key, device):
    """Run *forward* over every batch of *data_loader* without gradients.

    *forward* is called on ``batch[input_key]`` (moved to *device*) and must
    return a ``(token_output, cls_output)`` pair; only ``cls_output`` is kept.

    Returns:
        ``(keys, embeddings)``: the batches' ``file_key`` values in loader order,
        and all ``cls_output`` tensors concatenated along dim 0 and moved to the
        CPU, so row ``i`` belongs to ``keys[i]`` (partial final batches included).
    """
    keys = []
    chunks = []
    with torch.no_grad():
        for batch in tqdm.tqdm(data_loader, total=len(data_loader)):
            _, cls_out = forward(batch[input_key].to(device))
            chunks.append(cls_out)
            keys.extend(batch['file_key'])
    # One concatenation at the end instead of a growing torch.cat per batch.
    return keys, torch.cat(chunks, 0).to('cpu')


def main():
    """Embed the test images and texts and dump all image-text cosine distances.

    Loads ``results/config.yaml`` and ``results/model_epoch3.pth`` under $ROOT,
    embeds ``data/eval_test_image.json`` and ``data/eval_test_text.json``, and
    writes one row per (image key, text key) pair as ``"<img>:<txt>"`` plus its
    cosine distance. The rows are split over ``results_pair_dict1.csv`` (first
    half plus one row) and ``results_pair_dict2.csv`` in $ROOT.
    """
    root = os.environ['ROOT']
    config_path = os.path.join(root, "results/config.yaml")
    model_path = os.path.join(root, "results/model_epoch3.pth")

    print("Loading config from:", config_path)
    config = OmegaConf.load(config_path)

    # Load the checkpoint onto the CPU first; the model is moved to config.device below.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    print("Loaded model from:", model_path)
    clip_model, _ = clip.load(config.clip_model, jit=False)
    model = Model(clip_model, config)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(config.device)
    model = model.eval()
    model = model.float()

    image_path = os.path.join(root, "data/eval_test_image.json")
    text_path = os.path.join(root, "data/eval_test_text.json")
    data_loader_image = get_image_data_loader(config, image_path, split='test', mode='combined')
    data_loader_text = get_text_data_loader(config, text_path, split='test', mode='combined')

    text_keys, text_embeddings = _embed_all(model.var_txt_forward, data_loader_text, 'caption', config.device)
    image_keys, image_embeddings = _embed_all(model.var_img_forward, data_loader_image, 'image', config.device)

    # Cosine *distance* (1 - cosine similarity): row i = image_keys[i], column j = text_keys[j].
    distance_matrix = pairwise_distances(image_embeddings, text_embeddings, metric='cosine', n_jobs=8)

    # Cover every image/text pair actually embedded rather than a hard-coded 2000 x 2000.
    results_pair_dict = {
        f"{image_key}:{text_key}": float(distance_matrix[i][j])
        for i, image_key in enumerate(image_keys)
        for j, text_key in enumerate(text_keys)
    }

    # Split into two CSV files; the first takes len // 2 + 1 rows, as before.
    pairs = list(results_pair_dict.items())
    cut = len(pairs) // 2 + 1
    for file_name, chunk in (('results_pair_dict1.csv', pairs[:cut]), ('results_pair_dict2.csv', pairs[cut:])):
        df = pd.DataFrame(chunk, columns=['key_pair', 'score'])
        df.to_csv(os.path.join(root, file_name))


if __name__ == "__main__":
    main()
| # %% | |