|
|
import argparse |
|
|
import random |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import json |
|
|
import faiss |
|
|
from tqdm import tqdm |
|
|
import torch.nn.functional as F |
|
|
import torchvision.transforms as T |
|
|
|
|
|
import open_clip |
|
|
|
|
|
from Datasets import CCDataset, Batcher |
|
|
from model import ICCModel |
|
|
from utils import get_vocabulary, unormalize, get_eval_score |
|
|
|
|
|
# Cut-off ranks for the retrieval metrics (P@k, R@k, MRR@k).  Kept in
# DESCENDING order on purpose: `search` issues a single faiss query at the
# largest k and reuses prefixes of that result list for the smaller k's.
AT_K = sorted([1, 3, 5, 10], reverse=True)
|
|
|
|
|
|
|
|
def captioning(args, config, model, data_loader, vocab, device):
    """Run caption generation twice (full test set, then changed-only subset)
    and dump the metric scores of each pass to a text file.

    Args:
        args: CLI namespace; only ``args.output_path`` is used here.
        config: experiment configuration dict, forwarded to ``inference``.
        model: trained ICC model (eval mode is handled by ``inference``).
        data_loader: iterable of evaluation batches.
        vocab: token -> id vocabulary mapping.
        device: torch device used for inference.

    Returns:
        The per-sample results of the first (full) pass, consumed by ``plot``.
    """
    # Full test set: keep per-sample tensors/sentences for later plotting.
    scores, results = inference(config, model, data_loader, vocab, device, return_results=True)
    _write_score_items(os.path.join(args.output_path, 'caption.txt'), scores)

    # Changed-only subset: samples whose flag is -1 are skipped by `inference`.
    scores, _ = inference(config, model, data_loader, vocab, device, sub=True, return_results=False)
    _write_score_items(os.path.join(args.output_path, 'caption_sub.txt'), scores)

    return results


def _write_score_items(path, scores):
    """Write every (metric, value) pair of `scores` to `path`, one per line."""
    with open(path, 'w') as out:
        for t in scores.items():
            out.write(str(t) + '\n')
|
|
|
|
|
|
|
|
def retrieve(args, config, model, src_loader, device):
    """Evaluate text-to-image retrieval on the full test set and on the
    changed-only subset, writing P@k / R@k / MRR@k for every k in AT_K.

    Args:
        args: CLI namespace; only ``args.output_path`` is used here.
        config: experiment configuration dict, forwarded to ``search``.
        model: trained ICC model (eval mode is handled by ``search``).
        src_loader: iterable of evaluation batches.
        device: torch device used for encoding.
    """
    # Full test set.
    scores_p, scores_r, scores_rr = search(config, model, src_loader, device)
    _write_retrieval_scores(os.path.join(args.output_path, 'retrieve.txt'),
                            scores_p, scores_r, scores_rr)

    # Changed-only subset (flag != -1 samples — see `search`).
    scores_p, scores_r, scores_rr = search(config, model, src_loader, device, sub=True)
    _write_retrieval_scores(os.path.join(args.output_path, 'retrieve_sub.txt'),
                            scores_p, scores_r, scores_rr)


def _write_retrieval_scores(path, scores_p, scores_r, scores_rr):
    """Write one P/R/MRR triplet per cut-off rank in AT_K to `path`."""
    with open(path, 'w') as out:
        for k in AT_K:
            out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
            out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
            out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
            out.write('\n')
|
|
|
|
|
|
|
|
@torch.no_grad()
def search(config, model, src_loader, device, sub=False):
    """Text-to-image retrieval evaluation.

    Encodes every sample's image pair (visual embedding) and caption (textual
    embedding), indexes the visual embeddings in a faiss inner-product index,
    then uses each textual embedding as a query and scores the ranked results
    with P@k, R@k and MRR@k for every k in AT_K.

    Args:
        config: dict; uses 'd_model' (embedding size) and 's-threshold'
            (similarity cut-off that widens the relevant set).
        model: ICC model exposing .encoder(img1, img2) and
            .decoder(input_ids, ..., mask, ...).
        src_loader: iterable of batches with keys 'images_before',
            'images_after', 'flags', 'embs', 'input_ids', 'pad_mask'.
        device: torch device for encoding.
        sub: when True, skip samples whose flag is -1 (unchanged pairs).

    Returns:
        (scores_p, scores_r, scores_rr): dicts mapping k -> mean metric.
    """
    model.eval()

    visual = None
    textual = None
    flags = []
    embs = None
    # Inner-product index; embeddings are L2-normalized below, so inner
    # product equals cosine similarity.
    index = faiss.IndexFlatIP(config['d_model'])

    batcher = src_loader

    for batch in tqdm(batcher, desc='Indexing'):
        imgs1, imgs2, = batch['images_before'], batch['images_after']
        imgs1 = imgs1.to(device)
        imgs2 = imgs2.to(device)
        flag = batch['flags']
        emb = batch['embs']
        # NOTE(review): only flag[0] is checked, and flags[i] below is matched
        # 1:1 against textual rows — this assumes batch size 1 (as built by
        # `run`). Confirm before reusing with larger batches.
        if sub and flag[0] == -1:
            continue

        flags.append(flag)
        # Accumulate the precomputed sentence embeddings for the similarity
        # matrix used to define relevance.
        embs = torch.cat([embs, emb]) if embs is not None else emb

        # Visual embedding of the before/after pair (kept on CPU to bound
        # GPU memory).
        vis_emb, _ = model.encoder(imgs1, imgs2)
        visual = torch.cat([visual, vis_emb.cpu()]) if visual is not None else vis_emb.cpu()

        input_ids, mask = batch['input_ids'], batch['pad_mask']
        input_ids = input_ids.to(device)
        mask = mask.to(device)

        # Textual embedding of the ground-truth caption (second output of the
        # decoder when no visual tokens are supplied).
        _, text_emb, _, _ = model.decoder(input_ids, None, mask, None)
        textual = torch.cat([textual, text_emb.cpu()]) if textual is not None else text_emb.cpu()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Pairwise similarities of the precomputed caption embeddings; used to
    # declare two samples relevant to each other when above 's-threshold'.
    embs = embs.to(device)
    sims = torch.matmul(embs, torch.t(embs))

    # L2-normalize so that the faiss inner product behaves as cosine sim.
    visual = F.normalize(visual, p=2, dim=1)
    textual = F.normalize(textual, p=2, dim=1)

    # NOTE(review): faiss's Python API expects float32 numpy arrays; passing
    # torch tensors here (and in index.search below) relies on the installed
    # faiss accepting them (e.g. via its torch integration) — confirm.
    index.add(visual)

    scores_p = {k: [] for k in AT_K}
    scores_r = {k: [] for k in AT_K}
    scores_rr = {k: [] for k in AT_K}

    for i in tqdm(range(textual.shape[0]), desc='Ranking'):
        indices = None
        query = textual[i]
        query_lab = flags[i]

        # A result is relevant if it shares the query's flag or its caption
        # embedding is similar enough to the query's.
        relevants = set(
            [x for x in range(len(textual)) if flags[x] == query_lab or sims[i][x] >= config['s-threshold']])

        # AT_K is sorted descending: search faiss once at the largest k, then
        # reuse prefixes of that ranked list for the smaller cut-offs.
        for k in AT_K:
            p = 0
            r = 0
            rr = 0

            if indices is None:
                indices = index.search(query.unsqueeze(0), k)[1][0]
            else:
                indices = indices[:k]

            for rank, idx in enumerate(indices):
                if idx in relevants:
                    # First relevant hit (p still 0) sets the reciprocal rank.
                    if p == 0:
                        rr = 1 / (rank + 1)
                    p += 1
                    r += 1

            scores_p[k].append(p / len(indices))
            scores_r[k].append(r / len(relevants))
            scores_rr[k].append(rr)

    # Average the per-query metrics into a single score per cut-off.
    for k in AT_K:
        scores_p[k] = sum(scores_p[k]) / len(scores_p[k])
        scores_r[k] = sum(scores_r[k]) / len(scores_r[k])
        scores_rr[k] = sum(scores_rr[k]) / len(scores_rr[k])

    return scores_p, scores_r, scores_rr
|
|
|
|
|
|
|
|
@torch.no_grad()
def inference(config, model, data_loader, vocab, device, sub=False, return_results=False):
    """Greedy-decode one caption per evaluation sample and score the output.

    Args:
        config: dict; uses 'max_len' as the decoding-step budget.
        model: ICC model exposing .encoder(img1, img2) and
            .decoder(input_ids, ..., vis_toks).
        data_loader: iterable of batches; only element [0] of each batch is
            decoded (batch size 1 is assumed, as built by `run`).
        vocab: token -> id mapping containing 'START' and 'END'.
        device: torch device.
        sub: when True, skip samples whose flag is -1 (unchanged pairs).
        return_results: when True, also collect per-sample tensors for `plot`.

    Returns:
        (score_dict, results): score_dict from get_eval_score; results is a
        list of (img1, img2, attn_weights, visual_tokens, sentence) tuples,
        empty unless return_results is True.
    """
    results = []
    references = []
    hypotheses = []
    # id -> token lookup for turning generated ids back into words.
    inverse_vocab = {v: k for k, v in vocab.items()}

    model.eval()
    for batch in tqdm(data_loader, desc='Inference'):
        img1 = batch['images_before'][0].unsqueeze(0).to(device)
        img2 = batch['images_after'][0].unsqueeze(0).to(device)
        raws = batch['raws']
        flags = batch['flags']
        if sub and flags[0] == -1:
            continue

        references.append(raws[0])

        # Greedy autoregressive decoding, seeded with the START token.
        input_ids = torch.tensor([[vocab['START']]], dtype=torch.long, device=device)
        _, vis_toks = model.encoder(img1, img2)

        for _ in range(config['max_len']):
            _, _, lm_logits, weights = model.decoder(input_ids, None, None, vis_toks)

            # Pick the argmax of the last position and append it.
            next_item = lm_logits[0][-1].topk(1)[1]
            input_ids = torch.cat([input_ids, next_item.reshape(1, -1)], dim=1)
            if next_item.item() == vocab['END']:
                break

        # Drop the leading START token; drop the trailing END token only if
        # decoding actually emitted it.  (The previous `words[1:-1]` always
        # removed the last token, losing a real word whenever the loop
        # exhausted the max_len budget without generating END.)
        token_ids = input_ids[0].cpu().tolist()[1:]
        if token_ids and token_ids[-1] == vocab['END']:
            token_ids = token_ids[:-1]
        sentence = ' '.join(inverse_vocab[x] for x in token_ids).strip()
        hypotheses.append([sentence])

        if return_results:
            # `weights` holds the attention of the final decoder call.
            results.append(
                (img1.cpu(), img2.cpu(), weights.detach().cpu(), vis_toks.detach().cpu(), sentence))

    score_dict = get_eval_score(references, hypotheses)
    return score_dict, results
|
|
|
|
|
|
|
|
def plot(args, feat_size, results):
    """Write one 2x2 diagnostic figure per decoded sample to args.output_path.

    Each figure shows the before/after images, the visual-difference feature
    map, and the decoder attention weights overlaid on the after image, with
    the generated sentence printed underneath.  Files are named 0.png, 1.png,
    ... in result order.
    """
    for fig_idx, (before, after, attn, diff, caption) in enumerate(tqdm(results, desc='Plot')):
        # De-normalize and move to HWC layout for imshow.
        before_img = unormalize(before)[0].permute(1, 2, 0)
        after_img = unormalize(after)[0].permute(1, 2, 0)

        # Upscale feature-resolution maps to the image resolution.
        upscale = T.Resize(size=(before_img.size(0), before_img.size(1)))

        # Average the attention maps across heads/steps into a single channel.
        attn_map = upscale(attn[0].reshape(-1, feat_size, feat_size)).permute(1, 2, 0)
        attn_map = torch.sum(attn_map, 2) / attn_map.shape[2]

        # Same treatment for the visual-difference features.
        diff_map = upscale(diff[:, 0, :].reshape(-1, feat_size, feat_size)).permute(1, 2, 0)
        diff_map = torch.sum(diff_map, 2) / diff_map.shape[2]

        fig, ax = plt.subplots(2, 2, figsize=(6, 8))
        fig.tight_layout()

        # Three plain panels share the same imshow/title/axis-off treatment.
        for axis, title, image in ((ax[0, 0], "Before", before_img),
                                   (ax[0, 1], "After", after_img),
                                   (ax[1, 0], "Img diff", diff_map)):
            axis.imshow(image)
            axis.set_title(title)
            axis.axis('off')

        # Attention panel: after image with a translucent attention overlay.
        ax[1, 1].set_title("Att weights")
        ax[1, 1].imshow(after_img, interpolation='nearest')
        ax[1, 1].imshow(attn_map, interpolation='bilinear', alpha=0.5)
        ax[1, 1].axis('off')

        fig.text(.1, .05, caption, wrap=True)

        with open(os.path.join(args.output_path, str(fig_idx) + '.png'), 'wb') as out_file:
            plt.savefig(out_file)
        plt.close()
|
|
|
|
|
|
|
|
def run(args, config):
    """Build the model and test loader, then run the full evaluation suite.

    Steps: seed all RNG sources, pick the device, load (or build) the
    vocabulary, create the CLIP backbone and the ICC model, restore the
    trained weights, and run captioning, retrieval and attention plots on
    the test split.

    Args:
        args: parsed CLI namespace (model/vocab/annotation/image/output paths,
            seed).
        config: experiment configuration dict loaded from args.config.
    """
    print('Initializing...')
    # Seed every RNG source used downstream for reproducible evaluation.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Reuse a cached vocabulary file when present; otherwise rebuild it from
    # the annotations (presumably get_vocabulary also saves it to args.vocab
    # — confirm in utils).
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            vocab = json.load(infile)
    else:
        vocab = get_vocabulary(args.annotation_json, args.vocab)

    # CLIP backbone plus its matching image preprocessing transform.
    clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])

    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])

    # Restore trained weights onto the chosen device.
    model.load_state_dict(torch.load(args.model, map_location=device))
    model = model.to(device)
    # Drop the local reference to the backbone; NOTE(review): if ICCModel
    # keeps its own reference, this frees no memory — confirm.
    del clip

    print('Loading...')
    # Batch size 1: `inference` and `search` both assume it.
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
                         config['s-transformers'], device)
    test_loader = Batcher(test_set, 1, config['max_len'], device)

    print('Final evaluation...')
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, load the JSON experiment config, and launch
    the evaluation pipeline."""
    # (flag, type, default) table — declaration order matches --help order.
    options = [
        ('--model', str, '../input/model_best.pt'),
        ('--annotation_json', str, '../input/Levir_CC/LevirCCcaptions.json'),
        ('--image_dir', str, '../input/Levir_CC/images/'),
        ('--vocab', str, '../input/levir_vocab.json'),
        ('--config', str, '../config.json'),
        ('--output_path', str, '../output/'),
        ('--seed', int, 42),
    ]

    parser = argparse.ArgumentParser()
    for flag, value_type, default in options:
        parser.add_argument(flag, type=value_type, default=default)

    parsed_args = parser.parse_args()

    with open(parsed_args.config, 'r') as config_file:
        settings = json.load(config_file)

    run(parsed_args, settings)
|
|
|
|
|
|
|
|
# Script entry point: evaluate a trained checkpoint on the test split.
if __name__ == '__main__':
    main()
|
|
|