"""Evaluate a trained ICCModel on the test split.

Writes captioning scores, retrieval scores (P@k / R@k / MRR@k) and one
diagnostic figure per captioned sample into ``--output_path``.
"""
import argparse
import json
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import faiss
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.transforms as T
import open_clip

from Datasets import CCDataset, Batcher
from model import ICCModel
from utils import get_vocabulary, unormalize, get_eval_score

# Cut-offs for the retrieval metrics. Kept in descending order on purpose:
# `search` queries the index once with the largest k and then slices the
# result for every smaller k.
AT_K = sorted([1, 3, 5, 10], reverse=True)


def _write_caption_scores(path, scores):
    """Write every ``(metric, value)`` pair of *scores* on its own line."""
    with open(path, 'w') as out:
        for item in scores.items():
            out.write(str(item) + '\n')


def _write_retrieval_scores(path, scores_p, scores_r, scores_rr):
    """Write precision, recall and MRR for every cut-off in ``AT_K``."""
    with open(path, 'w') as out:
        for k in AT_K:
            out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
            out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
            out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
            out.write('\n')


def captioning(args, config, model, data_loader, vocab, device):
    """Run caption inference twice and store the scores.

    The first pass covers every sample and writes ``caption.txt``; the second
    pass (``sub=True``) skips samples whose flag is ``-1`` and writes
    ``caption_sub.txt``. Both files go to ``args.output_path``.

    Returns:
        The per-sample results collected during the first (full) pass, in the
        tuple layout produced by `inference`.
    """
    scores, results = inference(config, model, data_loader, vocab, device,
                                return_results=True)
    _write_caption_scores(os.path.join(args.output_path, 'caption.txt'), scores)

    scores, _ = inference(config, model, data_loader, vocab, device,
                          sub=True, return_results=False)
    _write_caption_scores(os.path.join(args.output_path, 'caption_sub.txt'), scores)
    return results


def retrieve(args, config, model, src_loader, device):
    """Run text-to-image retrieval twice and store the scores.

    The first pass covers every sample and writes ``retrieve.txt``; the second
    pass (``sub=True``) skips samples whose flag is ``-1`` and writes
    ``retrieve_sub.txt``. Both files go to ``args.output_path``.
    """
    scores_p, scores_r, scores_rr = search(config, model, src_loader, device)
    _write_retrieval_scores(os.path.join(args.output_path, 'retrieve.txt'),
                            scores_p, scores_r, scores_rr)

    scores_p, scores_r, scores_rr = search(config, model, src_loader, device, sub=True)
    _write_retrieval_scores(os.path.join(args.output_path, 'retrieve_sub.txt'),
                            scores_p, scores_r, scores_rr)


@torch.no_grad()
def search(config, model, src_loader, device, sub=False):
    """Index the visual embeddings and rank them against every text embedding.

    Args:
        config: Dict providing ``'d_model'`` (index dimensionality) and
            ``'s-threshold'`` (similarity above which two samples count as
            relevant to each other).
        model: Model exposing ``encoder(imgs1, imgs2)`` and
            ``decoder(input_ids, None, mask, None)``.
        src_loader: Iterable of batch dicts with the keys ``images_before``,
            ``images_after``, ``flags``, ``embs``, ``input_ids``, ``pad_mask``.
        device: Device the model inputs are moved to.
        sub: When true, batches whose first flag equals ``-1`` are skipped.

    Returns:
        Three dicts keyed by the cut-offs in ``AT_K`` holding the mean
        precision, recall and reciprocal rank.

    Raises:
        ValueError: If no batch was indexed (empty loader, or every batch was
            skipped by ``sub``).
    """
    model.eval()
    visual = None
    textual = None
    flags = []
    embs = None
    index = faiss.IndexFlatIP(config['d_model'])

    for batch in tqdm(src_loader, desc='Indexing'):
        imgs1 = batch['images_before'].to(device)
        imgs2 = batch['images_after'].to(device)
        flag = batch['flags']
        emb = batch['embs']
        if sub and flag[0] == -1:
            continue

        # NOTE(review): `flags` grows by one entry per *batch* while the
        # embedding matrices grow by one row per *sample*; the ranking loop
        # below indexes both with the same i, so this presumably relies on
        # the loader yielding one sample per batch - confirm.
        flags.append(flag)
        embs = torch.cat([embs, emb]) if embs is not None else emb

        vis_emb, _ = model.encoder(imgs1, imgs2)
        vis_emb = vis_emb.cpu()
        visual = torch.cat([visual, vis_emb]) if visual is not None else vis_emb

        input_ids = batch['input_ids'].to(device)
        mask = batch['pad_mask'].to(device)
        _, text_emb, _, _ = model.decoder(input_ids, None, mask, None)
        text_emb = text_emb.cpu()
        textual = torch.cat([textual, text_emb]) if textual is not None else text_emb

    if embs is None:
        # Previously this surfaced as an AttributeError on `None.to(...)`.
        raise ValueError('search: no samples were indexed')

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Pairwise similarity between the per-sample `embs` supplied by the
    # loader; used only to widen the set of relevant items.
    embs = embs.to(device)
    sims = torch.matmul(embs, torch.t(embs))

    # L2-normalise so that the inner-product index ranks by cosine similarity.
    visual = F.normalize(visual, p=2, dim=1)
    textual = F.normalize(textual, p=2, dim=1)
    index.add(visual)

    scores_p = {k: [] for k in AT_K}
    scores_r = {k: [] for k in AT_K}
    scores_rr = {k: [] for k in AT_K}

    num_queries = textual.shape[0]
    for i in tqdm(range(num_queries), desc='Ranking'):
        query = textual[i]
        query_lab = flags[i]
        # An item is relevant when it shares the query's flag or is similar
        # enough to it; the query's own index always qualifies.
        relevants = {
            x for x in range(num_queries)
            if flags[x] == query_lab or sims[i][x] >= config['s-threshold']
        }

        indices = None
        for k in AT_K:
            if indices is None:
                # First (largest) k: the only real index lookup per query.
                indices = index.search(query.unsqueeze(0), k)[1][0]
            else:
                indices = indices[:k]

            hits = 0
            rr = 0
            for rank, idx in enumerate(indices):
                if idx in relevants:
                    if hits == 0:
                        rr = 1 / (rank + 1)
                    hits += 1

            scores_p[k].append(hits / len(indices))
            scores_r[k].append(hits / len(relevants))
            scores_rr[k].append(rr)

    for k in AT_K:
        scores_p[k] = sum(scores_p[k]) / len(scores_p[k])
        scores_r[k] = sum(scores_r[k]) / len(scores_r[k])
        scores_rr[k] = sum(scores_rr[k]) / len(scores_rr[k])
    return scores_p, scores_r, scores_rr


@torch.no_grad()
def inference(config, model, data_loader, vocab, device, sub=False, return_results=False):
    """Greedily decode one caption per batch and score them.

    Only the first image pair and the first raw reference of every batch are
    used. Decoding starts from ``vocab['START']`` and appends the top-1 token
    until ``vocab['END']`` is produced or ``config['max_len']`` steps are done.

    Args:
        config: Dict providing ``'max_len'``.
        model: Model exposing ``encoder(img1, img2)`` and
            ``decoder(input_ids, None, None, vis_toks)``.
        data_loader: Iterable of batch dicts with the keys ``images_before``,
            ``images_after``, ``raws`` and ``flags``.
        vocab: Mapping from token string to token id; must contain ``START``
            and ``END``.
        device: Device the model inputs are moved to.
        sub: When true, batches whose first flag equals ``-1`` are skipped.
        return_results: When true, collect per-sample tensors for `plot`.

    Returns:
        ``(score_dict, results)`` where ``score_dict`` comes from
        ``get_eval_score`` and ``results`` is a list of
        ``(img1, img2, weights, vis_toks, sentence)`` tuples (empty unless
        ``return_results`` is set). ``weights`` is the fourth value returned
        by the last decoder call.
    """
    results = []
    references = []
    hypotheses = []
    inverse_vocab = {v: k for k, v in vocab.items()}
    start_id = vocab['START']
    end_id = vocab['END']

    model.eval()
    for batch in tqdm(data_loader, desc='Inference'):
        img1 = batch['images_before'][0].unsqueeze(0).to(device)
        img2 = batch['images_after'][0].unsqueeze(0).to(device)
        raws = batch['raws']
        flags = batch['flags']
        if sub and flags[0] == -1:
            continue
        references.append(raws[0])

        input_ids = torch.tensor([[start_id]], dtype=torch.long, device=device)
        _, vis_toks = model.encoder(img1, img2)
        for _ in range(config['max_len']):
            _, _, lm_logits, weights = model.decoder(input_ids, None, None, vis_toks)
            next_item = lm_logits[0][-1].topk(1)[1]
            input_ids = torch.cat([input_ids, next_item.reshape(1, -1)], dim=1)
            if next_item.item() == end_id:
                break

        # Drop the leading START token, and the trailing END token only when
        # it is really there: if decoding stopped at max_len the last token
        # is a regular word and must stay in the hypothesis.
        token_ids = input_ids[0].cpu().tolist()[1:]
        if token_ids and token_ids[-1] == end_id:
            token_ids = token_ids[:-1]
        sentence = ' '.join(inverse_vocab[t] for t in token_ids).strip()
        hypotheses.append([sentence])

        if return_results:
            results.append(
                (img1.cpu(), img2.cpu(), weights.detach().cpu(),
                 vis_toks.detach().cpu(), sentence))

    score_dict = get_eval_score(references, hypotheses)
    return score_dict, results


def plot(args, feat_size, results):
    """Save one 2x2 figure per result as ``<index>.png`` in ``args.output_path``.

    Panels: the "before" image, the "after" image, the channel-averaged
    visual tokens ("Img diff") and the averaged decoder weights overlaid on
    the "after" image ("Att weights"). The caption is printed below.

    Args:
        args: Namespace providing ``output_path``.
        feat_size: Side length used to fold the flattened token axis back
            into a ``feat_size x feat_size`` grid.
        results: Tuples as returned by `inference` with ``return_results``.
    """
    for fig_idx, (img1, img2, weights, diff, sentence) in enumerate(
            tqdm(results, desc='Plot')):
        img1 = unormalize(img1)[0].permute(1, 2, 0)  # h,w,c
        img2 = unormalize(img2)[0].permute(1, 2, 0)  # h,w,c
        transform = T.Resize(size=(img1.size(0), img1.size(1)))

        # Fold the weights into a spatial grid, upsample to the image size
        # and average over the leading axis.
        weights = weights[0].reshape(-1, feat_size, feat_size)
        weights = transform(weights).permute(1, 2, 0)  # h,w,d
        weights = torch.sum(weights, 2) / weights.shape[2]

        # Same treatment for the visual tokens.
        feature_map = diff[:, 0, :].reshape(-1, feat_size, feat_size)  # e,h,w
        feature_map = transform(feature_map).permute(1, 2, 0)  # h,w,c
        feature_map = torch.sum(feature_map, 2) / feature_map.shape[2]  # h,w

        fig, ax = plt.subplots(2, 2, figsize=(6, 8))
        fig.tight_layout()

        ax[0, 0].imshow(img1)
        ax[0, 0].set_title("Before")
        ax[0, 0].axis('off')

        ax[0, 1].imshow(img2)
        ax[0, 1].set_title("After")
        ax[0, 1].axis('off')

        ax[1, 0].set_title("Img diff")
        ax[1, 0].imshow(feature_map)
        ax[1, 0].axis('off')

        ax[1, 1].set_title("Att weights")
        ax[1, 1].imshow(img2, interpolation='nearest')
        ax[1, 1].imshow(weights, interpolation='bilinear', alpha=0.5)
        ax[1, 1].axis('off')

        fig.text(.1, .05, sentence, wrap=True)
        with open(os.path.join(args.output_path, str(fig_idx) + '.png'), 'wb') as f:
            fig.savefig(f)
        # Close explicitly so figures do not pile up across iterations.
        plt.close(fig)


def run(args, config):
    """Load vocabulary, model and test data, then run the full evaluation.

    Seeds ``torch``, ``numpy`` and ``random`` with ``args.seed``, loads (or
    builds via ``get_vocabulary``) the vocabulary, restores the model weights
    from ``args.model`` and runs captioning, retrieval and plotting in turn.
    """
    print('Initializing...')
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Every later stage writes into this directory; create it up front so a
    # fresh checkout does not fail on the first open(..., 'w').
    if args.output_path:
        os.makedirs(args.output_path, exist_ok=True)

    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            vocab = json.load(infile)
    else:
        vocab = get_vocabulary(args.annotation_json, args.vocab)

    clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'],
                     config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'],
                     config['dropout'],
                     learnable=config['learnable'],
                     fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'],
                     prenorm=config['prenorm'])
    model.load_state_dict(torch.load(args.model, map_location=device))
    model = model.to(device)
    del clip

    print('Loading...')
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess,
                         'test', config['max_len'], config['s-transformers'], device)
    test_loader = Batcher(test_set, 1, config['max_len'], device)

    print('Final evaluation...')
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)


def main():
    """Parse the command line, load the JSON config and start `run`."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, default='../input/model_best.pt')
    parser.add_argument('--annotation_json', type=str,
                        default='../input/Levir_CC/LevirCCcaptions.json')
    parser.add_argument('--image_dir', type=str, default='../input/Levir_CC/images/')
    parser.add_argument('--vocab', type=str, default='../input/levir_vocab.json')
    parser.add_argument('--config', type=str, default='../config.json')
    parser.add_argument('--output_path', type=str, default='../output/')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    with open(args.config, 'r') as config_file:
        config = json.load(config_file)

    run(args, config)


if __name__ == '__main__':
    main()