RogerFerrod committed
Commit f6ffda2 · verified · 1 Parent(s): dcfeed7

upload code

src/Datasets.py ADDED
@@ -0,0 +1,269 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+ from torch.utils.data import Dataset
6
+ from PIL import Image
7
+ from tqdm import tqdm
8
+ import faiss
9
+ import torch.nn.functional as F
10
+ from sentence_transformers import SentenceTransformer
11
+ import torchvision.transforms as transforms
12
+ from random import choice
13
+
14
+
15
+ class CCDataset(Dataset):
16
+ def __init__(self, json_file, root_dir, vocab, transform, split, max_length, s_pretrained, device):
17
+ super(CCDataset, self).__init__()
18
+ self.vocab = vocab
19
+ self.split = split
20
+ self.max_length = max_length
21
+ self.device = device
22
+ self.transform = transform
23
+ assert self.split in {'train', 'val', 'test'}
24
+
25
+ s_model = SentenceTransformer(s_pretrained)
26
+ self.s_model = s_model.to(device)
27
+
28
+ self.root_dir = root_dir
29
+ self.convert = transforms.ToTensor()
30
+
31
+ with open(json_file) as f:
32
+ data = json.load(f)['images']
33
+
34
+ self.raw_dataset = [entry for entry in data if entry['split'] == split]
35
+ self.sentences = []
36
+ self.embeddings = []
37
+
38
+ self.images = []
39
+ self.captions = []
40
+ for record in tqdm(self.raw_dataset, desc='Tokenize ' + self.split):
41
+ self.sentences.extend(self.tokenize(record['sentences']))
42
+
43
+ for record in tqdm(self.raw_dataset, desc='Embeddings ' + self.split):
44
+ self.embeddings.extend(self.compute_embeddings(record['sentences']))
45
+
46
+ self.preprocess()
47
+ del self.raw_dataset
48
+ del self.sentences
49
+ del self.embeddings
50
+ del self.s_model
51
+
52
+ def tokenize(self, batch):
53
+ for elem in batch:
54
+ tokens = [self.vocab[x] if x in self.vocab.keys() else self.vocab['UNK'] for x in elem['tokens']]
55
+ if len(tokens) > self.max_length - 2:
56
+ continue
57
+
58
+ tokens = [self.vocab['START']] + tokens + [self.vocab['END']]
59
+
60
+ mask = [False] * len(tokens)
61
+
62
+ diff = self.max_length - len(tokens)
63
+ tokens += [self.vocab['PAD']] * diff
64
+ mask += [True] * diff # True = pad
65
+
66
+ elem['input_ids'] = tokens
67
+ elem['mask'] = mask
68
+
69
+ if len(batch) < 5:
70
+ diff = 5 - len(batch)
71
+ batch += [choice(batch) for _ in range(diff)]
72
+
73
+ assert len(batch) == 5
74
+ return batch
75
+
76
+ def compute_embeddings(self, batch):
77
+ sents = [x['raw'].strip() for x in batch]
78
+ embs = self.s_model.encode(sents)
79
+ return embs
80
+
81
+ def __len__(self):
82
+ return len(self.captions)
83
+
84
+ def __getitem__(self, idx):
85
+ img_idx = idx // 5 if self.split == 'train' else idx
86
+ elem = self.captions[idx]
87
+ for k, v in self.images[img_idx].items():
88
+ elem[k] = v
89
+ return elem
90
+
91
+ def preprocess(self):
92
+ idx = 0
93
+ prev_idx = -1
94
+ pbar = tqdm(total=len(self.sentences), desc='Preprocessing ' + self.split)
95
+ while idx < len(self.sentences):
96
+ img_idx = idx // 5
97
+ assert (self.sentences[idx]['imgid'] == self.raw_dataset[img_idx]['imgid'])
98
+
99
+ input_ids = torch.tensor(self.sentences[idx]['input_ids'], dtype=torch.long)
100
+ mask = torch.tensor(self.sentences[idx]['mask'], dtype=torch.bool)
101
+ raws = [x['raw'] for x in self.raw_dataset[img_idx]['sentences']]
102
+ flag = -1 if self.raw_dataset[img_idx]['changeflag'] == 0 else self.raw_dataset[img_idx]['imgid']
103
+ flag = torch.tensor(flag, dtype=torch.long)
104
+ embs = torch.tensor(self.embeddings[idx]) if len(self.embeddings) > 0 else None
105
+
106
+ self.captions.append({'input_ids': input_ids, 'pad_masks': mask, 'raws': raws, 'flags': flag, 'embs': embs})
107
+
108
+ if img_idx != prev_idx:
109
+ before_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'A',
110
+ self.raw_dataset[img_idx]['filename'])
111
+ image_before = Image.open(before_img_path)
112
+ after_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'B',
113
+ self.raw_dataset[img_idx]['filename'])
114
+ image_after = Image.open(after_img_path)
115
+
116
+ image_before = self.transform(image_before).unsqueeze(0)
117
+ image_after = self.transform(image_after).unsqueeze(0)
118
+
119
+ self.images.append({'image_before': image_before, 'image_after': image_after, 'flags': flag})
120
+ prev_idx = img_idx
121
+
122
+ inc = 1 if self.split == 'train' else 5
123
+ idx += inc
124
+ pbar.update(inc)
125
+
126
+ pbar.close()
127
+
128
+
129
+ class Batcher:
130
+ def __init__(self, dataset, batch_size, max_len, device, hd=0, model=None, shuffle=False):
131
+ self.dataset = dataset
132
+ self.batch_size = batch_size
133
+ self.hd = hd
134
+ self.max_len = max_len
135
+ self.device = device
136
+ self.model = model
137
+ self.index = None
138
+ self.visual = None
139
+ self.textual = None
140
+
141
+ self.ptr = 0
142
+ self.indices = np.arange(len(self.dataset))
143
+ self.shuffle = shuffle
144
+
145
+ if shuffle:
146
+ np.random.shuffle(self.indices)
147
+
148
+ if model and hd > 0 and self.dataset.split == 'train':
149
+ self.create_index()
150
+
151
+ def __iter__(self):
152
+ return self
153
+
154
+ def __len__(self):
155
+ return len(self.dataset) // self.batch_size
156
+
157
+ def __next__(self):
158
+ if self.ptr >= len(self.dataset):
159
+ self.ptr = 0
160
+ self.index = None
161
+ self.visual = None
162
+ self.textual = None
163
+
164
+ if self.shuffle:
165
+ np.random.shuffle(self.indices)
166
+ if self.model and self.hd > 0 and self.dataset.split == 'train':
167
+ self.create_index()
168
+
169
+ raise StopIteration
170
+
171
+ batched = 0
172
+ samples = []
173
+ hard_negatives = []
174
+ while self.ptr < len(self.dataset) and batched < self.batch_size:
175
+ sample = self.dataset[self.indices[self.ptr]]
176
+ samples.append(sample)
177
+
178
+ if self.hd > 0 and self.dataset.split == 'train':
179
+ hard_neg = self.mine_negatives(self.indices[self.ptr], self.hd)
180
+ hard_negatives.extend(hard_neg)
181
+
182
+ self.ptr += 1
183
+ batched += 1
184
+
185
+ return self.create_batch(samples + hard_negatives)
186
+
187
+ def get_elem(self, idx):
188
+ return self.dataset[idx]
189
+
190
+ @torch.no_grad()
191
+ def create_index(self):
192
+ is_training = self.model.training
193
+ self.model.eval()
194
+ self.index = faiss.IndexFlatIP(self.model.feature_dim)
195
+ prev_img = None
196
+ for idx in tqdm(range(len(self.dataset)), desc='Creating index'):
197
+ sample = self.dataset[idx]
198
+ imgs1, imgs2 = sample['image_before'], sample['image_after']
199
+ input_ids, mask = sample['input_ids'], sample['pad_masks']
200
+
201
+ if idx // 5 != prev_img:
202
+ imgs1 = imgs1.to(self.device)
203
+ imgs2 = imgs2.to(self.device)
204
+ vis_emb, _ = self.model.encoder(imgs1, imgs2)
205
+ self.visual = torch.cat([self.visual, vis_emb.cpu()]) if self.visual is not None else vis_emb.cpu()
206
+ prev_img = prev_img + 1 if prev_img is not None else 0
207
+
208
+ input_ids = input_ids.unsqueeze(0).to(self.device)
209
+ mask = mask.unsqueeze(0).to(self.device)
210
+ _, text_emb, _, _ = self.model.decoder(input_ids, None, mask, None)
211
+ self.textual = torch.cat([self.textual, text_emb.cpu()]) if self.textual is not None else text_emb.cpu()
212
+
213
+ if torch.cuda.is_available():
214
+ torch.cuda.empty_cache()
215
+
216
+ self.visual = F.normalize(self.visual, p=2, dim=1)
217
+ self.textual = F.normalize(self.textual, p=2, dim=1)
218
+ self.index.add(self.visual)
219
+ if is_training:
220
+ self.model.train()
221
+
222
+ def mine_negatives(self, idx, n):
223
+ negatives = []
224
+ m = 4
225
+ label = self.dataset[idx]['flags'].item()
226
+
227
+ while len(negatives) < n and (n * m) < self.index.ntotal:
228
+ k = n * m
229
+ indeces = self.index.search(self.textual[idx].unsqueeze(0), k)[1][0]
230
+ indeces = [x * 5 for x in indeces]
231
+ negatives = [self.dataset[x] for x in indeces if self.dataset[x]['flags'].item() != label][:n]
232
+ m *= 2
233
+
234
+ return negatives
235
+
236
+ def create_batch(self, samples):
237
+ images_before = images_after = input_ids = pad_mask = labels = flags = embs = None
238
+ raws = []
239
+
240
+ for sample in samples:
241
+ img1 = sample['image_before']
242
+ img2 = sample['image_after']
243
+
244
+ tokens = sample['input_ids']
245
+ mask = sample['pad_masks']
246
+ flag = sample['flags']
247
+ emb = sample['embs']
248
+
249
+ tokens = tokens.unsqueeze(0)
250
+ mask = mask.unsqueeze(0)
251
+ flag = flag.unsqueeze(0)
252
+ lab = tokens.clone() * ~mask
253
+ lab += torch.tensor([[-100]], dtype=torch.long).repeat(1, self.max_len) * mask
254
+ if emb is not None:
255
+ emb = emb.unsqueeze(0)
256
+
257
+ images_before = torch.cat([images_before, img1]) if images_before is not None else img1
258
+ images_after = torch.cat([images_after, img2]) if images_after is not None else img2
259
+ input_ids = torch.cat([input_ids, tokens]) if input_ids is not None else tokens
260
+ labels = torch.cat([labels, lab]) if labels is not None else lab
261
+ pad_mask = torch.cat([pad_mask, mask]) if pad_mask is not None else mask
262
+ flags = torch.cat([flags, flag]) if flags is not None else flag
263
+ if emb is not None:
264
+ embs = torch.cat([embs, emb]) if embs is not None else emb
265
+
266
+ raws.append(sample['raws'])
267
+
268
+ return {'images_before': images_before, 'images_after': images_after, 'input_ids': input_ids,
269
+ 'pad_mask': pad_mask, 'labels': labels, 'flags': flags, 'raws': raws, 'embs': embs}
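
A minimal usage sketch for CCDataset and Batcher, mirroring the call pattern in src/eval.py below; the vocabulary path, backbone name, max_len and sentence-transformer checkpoint are illustrative placeholders, not values fixed by this commit:

    import json
    import torch
    import open_clip
    from Datasets import CCDataset, Batcher

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with open('../input/levir_vocab.json') as f:   # vocabulary built by utils.get_vocabulary
        vocab = json.load(f)
    # third return value of create_model_and_transforms is the evaluation preprocessing
    _, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32')

    test_set = CCDataset('../input/Levir_CC/LevirCCcaptions.json', '../input/Levir_CC/images/',
                         vocab, preprocess, 'test', max_length=42,
                         s_pretrained='all-MiniLM-L6-v2', device=device)
    test_loader = Batcher(test_set, batch_size=1, max_len=42, device=device)
    for batch in test_loader:
        # each batch carries the image pair, token ids, padding mask, labels, flags and raw captions
        print(batch['images_before'].shape, batch['input_ids'].shape)
        break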
src/Loss.py ADDED
@@ -0,0 +1,84 @@
1
+ from pytorch_metric_learning.distances import CosineSimilarity
2
+ import torch
3
+
4
+
5
+ class InfoNCELoss():
6
+ def __init__(self, device, k, temperature=0.07, threshold=1.0, fna=False):
7
+ super(InfoNCELoss, self).__init__()
8
+ self.device = device
9
+ self.similarity = CosineSimilarity()
10
+ self.k = k
11
+ self.temperature = temperature
12
+ self.threshold = threshold
13
+ self.fna = fna
14
+
15
+ def __call__(self, x, y, labels, sts):
16
+ false_negatives = self.detect_false_negative(sts)
17
+ indices_tuple = self.get_all_pairs_indices(labels, false_negatives)
18
+
19
+ mat = self.similarity(x, y)
20
+ a1, p, a2, n = indices_tuple
21
+ pos_pair, neg_pair = [], []
22
+ if len(a1) > 0:
23
+ pos_pair = mat[a1, p]
24
+ if len(a2) > 0:
25
+ neg_pair = mat[a2, n]
26
+
27
+ if len(neg_pair) > 0 and self.k > -1:
28
+ paired = list(zip(neg_pair.tolist(), a2.tolist(), n.tolist()))
29
+ selected = sorted(paired, key=lambda x: x[0], reverse=True)[:self.k]
30
+ _, x, y = list(zip(*selected))
31
+ x = torch.tensor(x).to(a2.device)
32
+ y = torch.tensor(y).to(n.device)
33
+
34
+ neg_pair = mat[x, y]
35
+ indices_tuple = (a1, p, x, y)
36
+
37
+ return self._compute_loss(pos_pair, neg_pair, indices_tuple), len(pos_pair)
38
+
39
+ def detect_false_negative(self, embs):
40
+ mat = torch.matmul(embs, torch.t(embs))
41
+ return torch.where(mat >= self.threshold)
42
+
43
+ def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
44
+ a1, p, a2, _ = indices_tuple
45
+
46
+ if len(a1) > 0 and len(a2) > 0:
47
+ dtype = neg_pairs.dtype
48
+
49
+ if not self.similarity.is_inverted:
50
+ pos_pairs = -pos_pairs
51
+ neg_pairs = -neg_pairs
52
+
53
+ pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
54
+ neg_pairs = neg_pairs / self.temperature
55
+ n_per_p = a2.unsqueeze(0) == a1.unsqueeze(1)
56
+ neg_pairs = neg_pairs * n_per_p
57
+ neg_pairs[n_per_p == 0] = torch.finfo(dtype).min
58
+
59
+ max_val = torch.max(
60
+ pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
61
+ ).detach()
62
+ numerator = torch.exp(pos_pairs - max_val).squeeze(1)
63
+ denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
64
+ log_exp = torch.log((numerator / denominator) + torch.finfo(dtype).tiny)
65
+ return torch.mean(-log_exp)
66
+
67
+ return 0
68
+
69
+ def get_all_pairs_indices(self, labels, false_negatives):
70
+ labels1 = labels.unsqueeze(1)
71
+ labels2 = labels.unsqueeze(0)
72
+ matches = (labels1 == labels2).byte()
73
+ diffs = matches ^ 1
74
+
75
+ diffs[false_negatives[0], false_negatives[1]] = 0 # FNE
76
+ if self.fna:
77
+ matches[false_negatives[0], false_negatives[1]] = 1 # FNA
78
+
79
+ diffs.fill_diagonal_(0)
80
+ matches.fill_diagonal_(1)
81
+
82
+ a1_idx, p_idx = torch.where(matches)
83
+ a2_idx, n_idx = torch.where(diffs)
84
+ return a1_idx, p_idx, a2_idx, n_idx
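
A minimal sketch of how InfoNCELoss is invoked; the tensors below are random stand-ins for the visual/textual embeddings, change flags and sentence-transformer embeddings ('embs') that the training loop would supply, and k/threshold are illustrative values:

    import torch
    import torch.nn.functional as F
    from Loss import InfoNCELoss

    device = torch.device('cpu')
    batch_size, dim = 8, 512                                   # illustrative shapes
    vis = F.normalize(torch.randn(batch_size, dim), dim=1)     # visual embeddings
    txt = F.normalize(torch.randn(batch_size, dim), dim=1)     # textual embeddings
    flags = torch.randint(-1, 4, (batch_size,))                # change flags (-1 = no change)
    sts = F.normalize(torch.randn(batch_size, dim), dim=1)     # sentence-transformer embeddings

    criterion = InfoNCELoss(device, k=32, temperature=0.07, threshold=0.9, fna=False)
    loss, num_pos = criterion(vis, txt, flags, sts)            # (loss value, number of positive pairs)
    print(loss, num_pos)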
src/eval.py ADDED
@@ -0,0 +1,288 @@
1
+ import argparse
2
+ import random
3
+ import os
4
+ import matplotlib.pyplot as plt
5
+
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ import faiss
10
+ from tqdm import tqdm
11
+ import torch.nn.functional as F
12
+ import torchvision.transforms as T
13
+
14
+ import open_clip
15
+
16
+ from Datasets import CCDataset, Batcher
17
+ from model import ICCModel
18
+ from utils import get_vocabulary, unormalize, get_eval_score
19
+
20
+ AT_K = sorted([1, 3, 5, 10], reverse=True)
21
+
22
+
23
+ def captioning(args, config, model, data_loader, vocab, device):
24
+ scores, results = inference(config, model, data_loader, vocab, device, return_results=True)
25
+
26
+ with open(os.path.join(args.output_path, 'caption.txt'), 'w') as out:
27
+ for t in scores.items():
28
+ out.write(str(t) + '\n')
29
+
30
+ scores, _ = inference(config, model, data_loader, vocab, device, sub=True, return_results=False)
31
+
32
+ with open(os.path.join(args.output_path, 'caption_sub.txt'), 'w') as out:
33
+ for t in scores.items():
34
+ out.write(str(t) + '\n')
35
+
36
+ return results
37
+
38
+
39
+ def retrieve(args, config, model, src_loader, device):
40
+ scores_p, scores_r, scores_rr = search(config, model, src_loader, device)
41
+ with open(os.path.join(args.output_path, 'retrieve.txt'), 'w') as out:
42
+ for k in AT_K:
43
+ out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
44
+ out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
45
+ out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
46
+ out.write('\n')
47
+
48
+ scores_p, scores_r, scores_rr = search(config, model, src_loader, device, sub=True)
49
+ with open(os.path.join(args.output_path, 'retrieve_sub.txt'), 'w') as out:
50
+ for k in AT_K:
51
+ out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
52
+ out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
53
+ out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
54
+ out.write('\n')
55
+
56
+
57
+ @torch.no_grad()
58
+ def search(config, model, src_loader, device, sub=False):
59
+ model.eval()
60
+
61
+ visual = None
62
+ textual = None
63
+ flags = []
64
+ embs = None
65
+ index = faiss.IndexFlatIP(config['d_model'])
66
+
67
+ batcher = src_loader
68
+
69
+ for batch in tqdm(batcher, desc='Indexing'):
70
+ imgs1, imgs2 = batch['images_before'], batch['images_after']
71
+ imgs1 = imgs1.to(device)
72
+ imgs2 = imgs2.to(device)
73
+ flag = batch['flags']
74
+ emb = batch['embs']
75
+ if sub and flag[0] == -1:
76
+ continue
77
+
78
+ flags.append(flag)
79
+ embs = torch.cat([embs, emb]) if embs is not None else emb
80
+
81
+ vis_emb, _ = model.encoder(imgs1, imgs2)
82
+ visual = torch.cat([visual, vis_emb.cpu()]) if visual is not None else vis_emb.cpu()
83
+
84
+ input_ids, mask = batch['input_ids'], batch['pad_mask']
85
+ input_ids = input_ids.to(device)
86
+ mask = mask.to(device)
87
+
88
+ _, text_emb, _, _ = model.decoder(input_ids, None, mask, None)
89
+ textual = torch.cat([textual, text_emb.cpu()]) if textual is not None else text_emb.cpu()
90
+
91
+ if torch.cuda.is_available():
92
+ torch.cuda.empty_cache()
93
+
94
+ embs = embs.to(device)
95
+ sims = torch.matmul(embs, torch.t(embs))
96
+
97
+ visual = F.normalize(visual, p=2, dim=1)
98
+ textual = F.normalize(textual, p=2, dim=1)
99
+
100
+ index.add(visual)
101
+
102
+ scores_p = {k: [] for k in AT_K}
103
+ scores_r = {k: [] for k in AT_K}
104
+ scores_rr = {k: [] for k in AT_K}
105
+
106
+ for i in tqdm(range(textual.shape[0]), desc='Ranking'):
107
+ indices = None
108
+ query = textual[i]
109
+ query_lab = flags[i]
110
+
111
+ relevants = set(
112
+ [x for x in range(len(textual)) if flags[x] == query_lab or sims[i][x] >= config['s-threshold']])
113
+
114
+ for k in AT_K:
115
+ p = 0
116
+ r = 0
117
+ rr = 0
118
+
119
+ if indices is None:
120
+ indices = index.search(query.unsqueeze(0), k)[1][0]
121
+ else:
122
+ indices = indices[:k]
123
+
124
+ for rank, idx in enumerate(indices):
125
+ if idx in relevants:
126
+ if p == 0:
127
+ rr = 1 / (rank + 1)
128
+ p += 1
129
+ r += 1
130
+
131
+ scores_p[k].append(p / len(indices))
132
+ scores_r[k].append(r / len(relevants))
133
+ scores_rr[k].append(rr)
134
+
135
+ for k in AT_K:
136
+ scores_p[k] = sum(scores_p[k]) / len(scores_p[k])
137
+ scores_r[k] = sum(scores_r[k]) / len(scores_r[k])
138
+ scores_rr[k] = sum(scores_rr[k]) / len(scores_rr[k])
139
+
140
+ return scores_p, scores_r, scores_rr
141
+
142
+
143
+ @torch.no_grad()
144
+ def inference(config, model, data_loader, vocab, device, sub=False, return_results=False):
145
+ results = []
146
+ references = []
147
+ hypotheses = []
148
+ inverse_vocab = {v: k for k, v in vocab.items()}
149
+
150
+ model.eval()
151
+ for batch in tqdm(data_loader, desc='Inference'):
152
+ img1 = batch['images_before'][0].unsqueeze(0).to(device)
153
+ img2 = batch['images_after'][0].unsqueeze(0).to(device)
154
+ raws = batch['raws']
155
+ flags = batch['flags']
156
+ if sub and flags[0] == -1:
157
+ continue
158
+
159
+ references.append(raws[0])
160
+
161
+ input_ids = torch.tensor([[vocab['START']]], dtype=torch.long, device=device)
162
+ _, vis_toks = model.encoder(img1, img2)
163
+
164
+ for _ in range(config['max_len']):
165
+ _, _, lm_logits, weights = model.decoder(input_ids, None, None, vis_toks)
166
+
167
+ next_item = lm_logits[0][-1].topk(1)[1]
168
+ input_ids = torch.cat([input_ids, next_item.reshape(1, -1)], dim=1)
169
+ if next_item.item() == vocab['END']:
170
+ break
171
+
172
+ words = [inverse_vocab[x] for x in input_ids[0].cpu().tolist()]
173
+ sentence = ' '.join(words[1:-1]).strip()
174
+ hypotheses.append([sentence])
175
+
176
+ if return_results:
177
+ results.append(
178
+ (img1.cpu(), img2.cpu(), weights.detach().cpu(), vis_toks.detach().cpu(), sentence))
179
+
180
+ score_dict = get_eval_score(references, hypotheses)
181
+ return score_dict, results
182
+
183
+
184
+ def plot(args, feat_size, results):
185
+ fig_idx = 0
186
+ for img1, img2, weights, diff, sentence in tqdm(results, desc='Plot'):
187
+ img1 = unormalize(img1)
188
+ img1 = img1[0].permute(1, 2, 0) # h,w,c
189
+ img2 = unormalize(img2)
190
+ img2 = img2[0].permute(1, 2, 0) # h,w,c
191
+
192
+ transform = T.Resize(size=(img1.size(0), img1.size(1)))
193
+ weights = weights[0].reshape(-1, feat_size, feat_size)
194
+ weights = transform(weights).permute(1, 2, 0) # h,w,d
195
+ weights = torch.sum(weights, 2) / weights.shape[2]
196
+ after = img2 # h,w,c
197
+
198
+ feature_map = diff[:, 0, :].reshape(-1, feat_size, feat_size) # e,h,w
199
+ feature_map = transform(feature_map).permute(1, 2, 0) # h,w,c
200
+ feature_map = torch.sum(feature_map, 2) / feature_map.shape[2] # h, w
201
+
202
+ fig, ax = plt.subplots(2, 2, figsize=(6, 8))
203
+ fig.tight_layout()
204
+ ax[0, 0].imshow(img1)
205
+ ax[0, 0].set_title("Before")
206
+ ax[0, 0].axis('off')
207
+ ax[0, 1].imshow(img2)
208
+ ax[0, 1].set_title("After")
209
+ ax[0, 1].axis('off')
210
+
211
+ ax[1, 0].set_title("Img diff")
212
+ ax[1, 0].imshow(feature_map)
213
+ ax[1, 0].axis('off')
214
+
215
+ ax[1, 1].set_title("Att weights")
216
+ ax[1, 1].imshow(after, interpolation='nearest')
217
+ ax[1, 1].imshow(weights, interpolation='bilinear', alpha=0.5)
218
+ ax[1, 1].axis('off')
219
+
220
+ fig.text(.1, .05, sentence, wrap=True)
221
+
222
+ with open(os.path.join(args.output_path, str(fig_idx) + '.png'), 'wb') as f:
223
+ plt.savefig(f)
224
+ plt.close()
225
+ fig_idx += 1
226
+
227
+
228
+ def run(args, config):
229
+ print('Initializing...')
230
+ torch.manual_seed(args.seed)
231
+ np.random.seed(args.seed)
232
+ random.seed(args.seed)
233
+ torch.backends.cudnn.deterministic = True
234
+
235
+ device = torch.device('cpu')
236
+ if torch.cuda.is_available():
237
+ device = torch.device('cuda')
238
+
239
+ if os.path.exists(args.vocab):
240
+ with open(args.vocab, 'r') as infile:
241
+ vocab = json.load(infile)
242
+ else:
243
+ vocab = get_vocabulary(args.annotation_json, args.vocab)
244
+
245
+ clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
246
+
247
+ model = ICCModel(device, clip, config['backbone'], config['d_model'],
248
+ len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
249
+ config['encoder_layers'], config['decoder_layers'], config['dropout'],
250
+ learnable=config['learnable'], fine_tune=config['fine_tune'],
251
+ tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])
252
+
253
+ model.load_state_dict(torch.load(args.model, map_location=device))
254
+ model = model.to(device)
255
+ del clip
256
+
257
+ print('Loading...')
258
+ test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
259
+ config['s-transformers'], device)
260
+ test_loader = Batcher(test_set, 1, config['max_len'], device)
261
+
262
+ print('Final evaluation...')
263
+ results = captioning(args, config, model, test_loader, vocab, device)
264
+ retrieve(args, config, model, test_loader, device)
265
+ plot(args, model.encoder.encoder.feat_size, results)
266
+
267
+
268
+ def main():
269
+ parser = argparse.ArgumentParser()
270
+ parser.add_argument('--model', type=str, default='../input/model_best.pt')
271
+ parser.add_argument('--annotation_json', type=str, default='../input/Levir_CC/LevirCCcaptions.json')
272
+ parser.add_argument('--image_dir', type=str, default='../input/Levir_CC/images/')
273
+ parser.add_argument('--vocab', type=str, default='../input/levir_vocab.json')
274
+
275
+ parser.add_argument('--config', type=str, default='../config.json')
276
+ parser.add_argument('--output_path', type=str, default='../output/')
277
+ parser.add_argument('--seed', type=int, default=42)
278
+
279
+ args = parser.parse_args()
280
+
281
+ with open(args.config, 'r') as config_file:
282
+ config = json.load(config_file)
283
+
284
+ run(args, config)
285
+
286
+
287
+ if __name__ == '__main__':
288
+ main()
src/eval_func/__init__.py ADDED
File without changes
src/eval_func/bleu/LICENSE ADDED
@@ -0,0 +1,19 @@
1
+ Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
src/eval_func/bleu/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
src/eval_func/bleu/bleu.py ADDED
@@ -0,0 +1,44 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : bleu.py
4
+ #
5
+ # Description : Wrapper for BLEU scorer.
6
+ #
7
+ # Creation Date : 06-01-2015
8
+ # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9
+ # Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
10
+
11
+ from .bleu_scorer import BleuScorer
12
+
13
+
14
+ class Bleu:
15
+ def __init__(self, n=4):
16
+ # by default, compute BLEU scores up to 4-grams
17
+ self._n = n
18
+ self._hypo_for_image = {}
19
+ self.ref_for_image = {}
20
+
21
+ def compute_score(self, gts, res):
22
+
23
+ bleu_scorer = BleuScorer(n=self._n)
24
+ for i in range(len(res)):
25
+ hypo = res[i]
26
+ ref = gts[i]
27
+
28
+ # Sanity check.
29
+ assert(type(hypo) is list)
30
+ assert(len(hypo) == 1)
31
+ assert(type(ref) is list)
32
+ assert(len(ref) >= 1)
33
+
34
+ bleu_scorer += (hypo[0], ref)
35
+
36
+ #score, scores = bleu_scorer.compute_score(option='shortest')
37
+ score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
38
+ #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
39
+
40
+ # return (bleu, bleu_info)
41
+ return score, scores
42
+
43
+ def method(self):
44
+ return "Bleu"
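
Bleu expects the layout produced in eval.py: one single-hypothesis list per image and a list of reference strings per image. A toy sketch (the captions are made up):

    from eval_func.bleu.bleu import Bleu

    gts = [['a road has been built across the field', 'a new road crosses the bare land'],
           ['the scene remains almost the same', 'there is no change in this area']]
    res = [['a road has been built in the field'],
           ['the scene is almost the same']]

    score, per_image = Bleu(n=4).compute_score(gts, res)
    print(score)   # [BLEU-1, BLEU-2, BLEU-3, BLEU-4] over the corpus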
src/eval_func/bleu/bleu_scorer.py ADDED
@@ -0,0 +1,263 @@
1
+ #!/usr/bin/env python
2
+
3
+ # bleu_scorer.py
4
+ # David Chiang <chiang@isi.edu>
5
+
6
+ # Copyright (c) 2004-2006 University of Maryland. All rights
7
+ # reserved. Do not redistribute without permission from the
8
+ # author. Not for commercial use.
9
+
10
+ # Modified by:
11
+ # Hao Fang <hfang@uw.edu>
12
+ # Tsung-Yi Lin <tl483@cornell.edu>
13
+
14
+ '''Provides:
15
+ cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16
+ cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17
+ '''
18
+
19
+ import copy
20
+ import sys, math, re
21
+ from collections import defaultdict
22
+
23
+ def precook(s, n=4, out=False):
24
+ """Takes a string as input and returns an object that can be given to
25
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
26
+ can take string arguments as well."""
27
+ words = s.split()
28
+ counts = defaultdict(int)
29
+ for k in range(1,n+1):
30
+ for i in range(len(words)-k+1):
31
+ ngram = tuple(words[i:i+k])
32
+ counts[ngram] += 1
33
+ return (len(words), counts)
34
+
35
+ def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36
+ '''Takes a list of reference sentences for a single segment
37
+ and returns an object that encapsulates everything that BLEU
38
+ needs to know about them.'''
39
+
40
+ reflen = []
41
+ maxcounts = {}
42
+ for ref in refs:
43
+ rl, counts = precook(ref, n)
44
+ reflen.append(rl)
45
+ for (ngram,count) in counts.items():
46
+ maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47
+
48
+ # Calculate effective reference sentence length.
49
+ if eff == "shortest":
50
+ reflen = min(reflen)
51
+ elif eff == "average":
52
+ reflen = float(sum(reflen))/len(reflen)
53
+
54
+ ## lhuang: N.B.: leave reflen computation to the very end!!
55
+
56
+ ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57
+
58
+ return (reflen, maxcounts)
59
+
60
+ def cook_test(test, xxx_todo_changeme, eff=None, n=4):
61
+ '''Takes a test sentence and returns an object that
62
+ encapsulates everything that BLEU needs to know about it.'''
63
+ (reflen, refmaxcounts) = xxx_todo_changeme
64
+ testlen, counts = precook(test, n, True)
65
+
66
+ result = {}
67
+
68
+ # Calculate effective reference sentence length.
69
+
70
+ if eff == "closest":
71
+ result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72
+ else: ## i.e., "average" or "shortest" or None
73
+ result["reflen"] = reflen
74
+
75
+ result["testlen"] = testlen
76
+
77
+ result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
78
+
79
+ result['correct'] = [0]*n
80
+ for (ngram, count) in counts.items():
81
+ result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82
+
83
+ return result
84
+
85
+ class BleuScorer(object):
86
+ """Bleu scorer.
87
+ """
88
+
89
+ __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90
+ # special_reflen is used in oracle (proportional effective ref len for a node).
91
+
92
+ def copy(self):
93
+ ''' copy the refs.'''
94
+ new = BleuScorer(n=self.n)
95
+ new.ctest = copy.copy(self.ctest)
96
+ new.crefs = copy.copy(self.crefs)
97
+ new._score = None
98
+ return new
99
+
100
+ def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101
+ ''' singular instance '''
102
+
103
+ self.n = n
104
+ self.crefs = []
105
+ self.ctest = []
106
+ self.cook_append(test, refs)
107
+ self.special_reflen = special_reflen
108
+
109
+ def cook_append(self, test, refs):
110
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
111
+
112
+ if refs is not None:
113
+ self.crefs.append(cook_refs(refs))
114
+ if test is not None:
115
+ cooked_test = cook_test(test, self.crefs[-1])
116
+ self.ctest.append(cooked_test) ## N.B.: -1
117
+ else:
118
+ self.ctest.append(None) # lens of crefs and ctest have to match
119
+
120
+ self._score = None ## need to recompute
121
+
122
+ def ratio(self, option=None):
123
+ self.compute_score(option=option)
124
+ return self._ratio
125
+
126
+ def score_ratio(self, option=None):
127
+ '''return (bleu, len_ratio) pair'''
128
+ return (self.fscore(option=option), self.ratio(option=option))
129
+
130
+ def score_ratio_str(self, option=None):
131
+ return "%.4f (%.2f)" % self.score_ratio(option)
132
+
133
+ def reflen(self, option=None):
134
+ self.compute_score(option=option)
135
+ return self._reflen
136
+
137
+ def testlen(self, option=None):
138
+ self.compute_score(option=option)
139
+ return self._testlen
140
+
141
+ def retest(self, new_test):
142
+ if type(new_test) is str:
143
+ new_test = [new_test]
144
+ assert len(new_test) == len(self.crefs), new_test
145
+ self.ctest = []
146
+ for t, rs in zip(new_test, self.crefs):
147
+ self.ctest.append(cook_test(t, rs))
148
+ self._score = None
149
+
150
+ return self
151
+
152
+ def rescore(self, new_test):
153
+ ''' replace test(s) with new test(s), and returns the new score.'''
154
+
155
+ return self.retest(new_test).compute_score()
156
+
157
+ def size(self):
158
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
159
+ return len(self.crefs)
160
+
161
+ def __iadd__(self, other):
162
+ '''add an instance (e.g., from another sentence).'''
163
+
164
+ if type(other) is tuple:
165
+ ## avoid creating new BleuScorer instances
166
+ self.cook_append(other[0], other[1])
167
+ else:
168
+ assert self.compatible(other), "incompatible BLEUs."
169
+ self.ctest.extend(other.ctest)
170
+ self.crefs.extend(other.crefs)
171
+ self._score = None ## need to recompute
172
+
173
+ return self
174
+
175
+ def compatible(self, other):
176
+ return isinstance(other, BleuScorer) and self.n == other.n
177
+
178
+ def single_reflen(self, option="average"):
179
+ return self._single_reflen(self.crefs[0][0], option)
180
+
181
+ def _single_reflen(self, reflens, option=None, testlen=None):
182
+
183
+ if option == "shortest":
184
+ reflen = min(reflens)
185
+ elif option == "average":
186
+ reflen = float(sum(reflens))/len(reflens)
187
+ elif option == "closest":
188
+ reflen = min((abs(l-testlen), l) for l in reflens)[1]
189
+ else:
190
+ assert False, "unsupported reflen option %s" % option
191
+
192
+ return reflen
193
+
194
+ def recompute_score(self, option=None, verbose=0):
195
+ self._score = None
196
+ return self.compute_score(option, verbose)
197
+
198
+ def compute_score(self, option=None, verbose=0):
199
+ n = self.n
200
+ small = 1e-9
201
+ tiny = 1e-15 ## so that if guess is 0 still return 0
202
+ bleu_list = [[] for _ in range(n)]
203
+
204
+ if self._score is not None:
205
+ return self._score
206
+
207
+ if option is None:
208
+ option = "average" if len(self.crefs) == 1 else "closest"
209
+
210
+ self._testlen = 0
211
+ self._reflen = 0
212
+ totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
213
+
214
+ # for each sentence
215
+ for comps in self.ctest:
216
+ testlen = comps['testlen']
217
+ self._testlen += testlen
218
+
219
+ if self.special_reflen is None: ## need computation
220
+ reflen = self._single_reflen(comps['reflen'], option, testlen)
221
+ else:
222
+ reflen = self.special_reflen
223
+
224
+ self._reflen += reflen
225
+
226
+ for key in ['guess','correct']:
227
+ for k in range(n):
228
+ totalcomps[key][k] += comps[key][k]
229
+
230
+ # append per image bleu score
231
+ bleu = 1.
232
+ for k in range(n):
233
+ bleu *= (float(comps['correct'][k]) + tiny) \
234
+ /(float(comps['guess'][k]) + small)
235
+ bleu_list[k].append(bleu ** (1./(k+1)))
236
+ ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
237
+ if ratio < 1:
238
+ for k in range(n):
239
+ bleu_list[k][-1] *= math.exp(1 - 1/ratio)
240
+
241
+ # if verbose > 1:
242
+ # print(comps, reflen)
243
+
244
+ totalcomps['reflen'] = self._reflen
245
+ totalcomps['testlen'] = self._testlen
246
+
247
+ bleus = []
248
+ bleu = 1.
249
+ for k in range(n):
250
+ bleu *= float(totalcomps['correct'][k] + tiny) \
251
+ / (totalcomps['guess'][k] + small)
252
+ bleus.append(bleu ** (1./(k+1)))
253
+ ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
254
+ if ratio < 1:
255
+ for k in range(n):
256
+ bleus[k] *= math.exp(1 - 1/ratio)
257
+
258
+ # if verbose > 0:
259
+ # print(totalcomps)
260
+ # print("ratio:", ratio)
261
+
262
+ self._score = bleus
263
+ return self._score, bleu_list
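
The scorer can also be driven incrementally, which is how Bleu.compute_score above uses it: each (hypothesis, references) pair is accumulated with +=, then scored with the 'closest' reference-length option. A toy sketch:

    from eval_func.bleu.bleu_scorer import BleuScorer

    scorer = BleuScorer(n=4)
    scorer += ('a road has been built in the field',
               ['a road has been built across the field', 'a new road crosses the bare land'])
    score, per_image = scorer.compute_score(option='closest', verbose=1)
    print(score)   # corpus BLEU-1..4 for the accumulated pairs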
src/eval_func/cider/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
src/eval_func/cider/cider.py ADDED
@@ -0,0 +1,51 @@
1
+ # Filename: cider.py
2
+ #
3
+ # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4
+ # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5
+ #
6
+ # Creation Date: Sun Feb 8 14:16:54 2015
7
+ #
8
+ # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
9
+
10
+ from eval_func.cider.cider_scorer import CiderScorer
11
+ import pdb
12
+
13
+ class Cider:
14
+ """
15
+ Main Class to compute the CIDEr metric
16
+
17
+ """
18
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
19
+ # set cider to sum over 1 to 4-grams
20
+ self._n = n
21
+ # set the standard deviation parameter for gaussian penalty
22
+ self._sigma = sigma
23
+
24
+ def compute_score(self, gts, res):
25
+ """
26
+ Main function to compute CIDEr score
27
+ :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
28
+ ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
29
+ :return: cider (float) : computed CIDEr score for the corpus
30
+ """
31
+
32
+ cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
33
+
34
+ for i in range(len(res)):
35
+ hypo = res[i]
36
+ ref = gts[i]
37
+
38
+ # Sanity check.
39
+ assert(type(hypo) is list)
40
+ assert(len(hypo) == 1)
41
+ assert(type(ref) is list)
42
+ assert(len(ref) > 0)
43
+
44
+ cider_scorer += (hypo[0], ref)
45
+
46
+ (score, scores) = cider_scorer.compute_score()
47
+
48
+ return score, scores
49
+
50
+ def method(self):
51
+ return "CIDEr"
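
Cider exposes the same (gts, res) interface as Bleu above; a toy sketch with made-up captions:

    from eval_func.cider.cider import Cider

    gts = [['a road has been built across the field', 'a new road crosses the bare land'],
           ['the scene remains almost the same', 'there is no change in this area']]
    res = [['a road has been built in the field'],
           ['the scene is almost the same']]

    score, per_image = Cider().compute_score(gts, res)
    print(score)   # mean CIDEr over the images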
src/eval_func/cider/cider_scorer.py ADDED
@@ -0,0 +1,193 @@
1
+ #!/usr/bin/env python
2
+ # Tsung-Yi Lin <tl483@cornell.edu>
3
+ # Ramakrishna Vedantam <vrama91@vt.edu>
4
+
5
+ import copy
6
+ from collections import defaultdict
7
+ import numpy as np
8
+ import pdb
9
+ import math
10
+
11
+ def precook(s, n=4, out=False):
12
+ """
13
+ Takes a string as input and returns an object that can be given to
14
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
15
+ can take string arguments as well.
16
+ :param s: string : sentence to be converted into ngrams
17
+ :param n: int : number of ngrams for which representation is calculated
18
+ :return: term frequency vector for occurring ngrams
19
+ """
20
+ words = s.split()
21
+ counts = defaultdict(int)
22
+ for k in range(1,n+1):
23
+ for i in range(len(words)-k+1):
24
+ ngram = tuple(words[i:i+k])
25
+ counts[ngram] += 1
26
+ return counts
27
+
28
+ def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29
+ '''Takes a list of reference sentences for a single segment
30
+ and returns an object that encapsulates everything that BLEU
31
+ needs to know about them.
32
+ :param refs: list of string : reference sentences for some image
33
+ :param n: int : number of ngrams for which (ngram) representation is calculated
34
+ :return: result (list of dict)
35
+ '''
36
+ return [precook(ref, n) for ref in refs]
37
+
38
+ def cook_test(test, n=4):
39
+ '''Takes a test sentence and returns an object that
40
+ encapsulates everything that BLEU needs to know about it.
41
+ :param test: list of string : hypothesis sentence for some image
42
+ :param n: int : number of ngrams for which (ngram) representation is calculated
43
+ :return: result (dict)
44
+ '''
45
+ return precook(test, n, True)
46
+
47
+ class CiderScorer(object):
48
+ """CIDEr scorer.
49
+ """
50
+
51
+ def copy(self):
52
+ ''' copy the refs.'''
53
+ new = CiderScorer(n=self.n)
54
+ new.ctest = copy.copy(self.ctest)
55
+ new.crefs = copy.copy(self.crefs)
56
+ return new
57
+
58
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59
+ ''' singular instance '''
60
+ self.n = n
61
+ self.sigma = sigma
62
+ self.crefs = []
63
+ self.ctest = []
64
+ self.document_frequency = defaultdict(float)
65
+ self.cook_append(test, refs)
66
+ self.ref_len = None
67
+
68
+ def cook_append(self, test, refs):
69
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
70
+
71
+ if refs is not None:
72
+ self.crefs.append(cook_refs(refs))
73
+ if test is not None:
74
+ self.ctest.append(cook_test(test)) ## N.B.: -1
75
+ else:
76
+ self.ctest.append(None) # lens of crefs and ctest have to match
77
+
78
+ def size(self):
79
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80
+ return len(self.crefs)
81
+
82
+ def __iadd__(self, other):
83
+ '''add an instance (e.g., from another sentence).'''
84
+
85
+ if type(other) is tuple:
86
+ ## avoid creating new CiderScorer instances
87
+ self.cook_append(other[0], other[1])
88
+ else:
89
+ self.ctest.extend(other.ctest)
90
+ self.crefs.extend(other.crefs)
91
+
92
+ return self
93
+ def compute_doc_freq(self):
94
+ '''
95
+ Compute term frequency for reference data.
96
+ This will be used to compute idf (inverse document frequency later)
97
+ The term frequency is stored in the object
98
+ :return: None
99
+ '''
100
+ for refs in self.crefs:
101
+ # refs, k ref captions of one image
102
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103
+ self.document_frequency[ngram] += 1
104
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105
+
106
+ def compute_cider(self):
107
+ def counts2vec(cnts):
108
+ """
109
+ Function maps counts of ngram to vector of tfidf weights.
110
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111
+ The n-th entry of array denotes length of n-grams.
112
+ :param cnts:
113
+ :return: vec (array of dict), norm (array of float), length (int)
114
+ """
115
+ vec = [defaultdict(float) for _ in range(self.n)]
116
+ length = 0
117
+ norm = [0.0 for _ in range(self.n)]
118
+ for (ngram, term_freq) in cnts.items():
119
+ # give word count 1 if it doesn't appear in reference corpus
120
+ df = np.log(max(1.0, self.document_frequency[ngram]))
121
+ # ngram index
122
+ n = len(ngram)-1
123
+ # tf (term_freq) * idf (precomputed idf) for n-grams
124
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125
+ # compute norm for the vector. the norm will be used for computing similarity
126
+ norm[n] += pow(vec[n][ngram], 2)
127
+
128
+ if n == 1:
129
+ length += term_freq
130
+ norm = [np.sqrt(n) for n in norm]
131
+ return vec, norm, length
132
+
133
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134
+ '''
135
+ Compute the cosine similarity of two vectors.
136
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137
+ :param vec_ref: array of dictionary for vector corresponding to reference
138
+ :param norm_hyp: array of float for vector corresponding to hypothesis
139
+ :param norm_ref: array of float for vector corresponding to reference
140
+ :param length_hyp: int containing length of hypothesis
141
+ :param length_ref: int containing length of reference
142
+ :return: array of score for each n-grams cosine similarity
143
+ '''
144
+ delta = float(length_hyp - length_ref)
145
+ # measure cosine similarity
146
+ val = np.array([0.0 for _ in range(self.n)])
147
+ for n in range(self.n):
148
+ # ngram
149
+ for (ngram,count) in vec_hyp[n].items():
150
+ # vrama91 : added clipping
151
+ val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152
+
153
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154
+ val[n] /= (norm_hyp[n]*norm_ref[n])
155
+
156
+ assert(not math.isnan(val[n]))
157
+ # vrama91: added a length based gaussian penalty
158
+ val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159
+ return val
160
+
161
+ # compute log reference length
162
+ self.ref_len = np.log(float(len(self.crefs)))
163
+ if len(self.crefs) == 1:
164
+ self.ref_len = 1
165
+ scores = []
166
+ for test, refs in zip(self.ctest, self.crefs):
167
+ # compute vector for test captions
168
+ vec, norm, length = counts2vec(test)
169
+ # compute vector for ref captions
170
+ score = np.array([0.0 for _ in range(self.n)])
171
+ for ref in refs:
172
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
173
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
174
+ # change by vrama91 - mean of ngram scores, instead of sum
175
+ score_avg = np.mean(score)
176
+ # divide by number of references
177
+ score_avg /= len(refs)
178
+ # multiply score by 10
179
+ score_avg *= 10.0
180
+ # append score of an image to the score list
181
+ scores.append(score_avg)
182
+ return scores
183
+
184
+ def compute_score(self, option=None, verbose=0):
185
+ # compute idf
186
+ self.compute_doc_freq()
187
+ # assert to check document frequency
188
+ assert(len(self.ctest) >= max(self.document_frequency.values()))
189
+ # compute cider score
190
+ score = self.compute_cider()
191
+ # debug
192
+ # print score
193
+ return np.mean(np.array(score)), np.array(score)
src/eval_func/rouge/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'vrama91'
src/eval_func/rouge/rouge.py ADDED
@@ -0,0 +1,174 @@
1
+ #!/usr/bin/env python
2
+
3
+ #
4
+
5
+ # File Name : rouge.py
6
+
7
+ #
8
+
9
+ # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
10
+
11
+ #
12
+
13
+ # Creation Date : 2015-01-07 06:03
14
+
15
+ # Author : Ramakrishna Vedantam <vrama91@vt.edu>
16
+
17
+
18
+ import numpy as np
19
+
20
+ import pdb
21
+
22
+
23
+ def my_lcs(string, sub):
24
+ """
25
+
26
+ Calculates longest common subsequence for a pair of tokenized strings
27
+
28
+ :param string : list of str : tokens from a string split using whitespace
29
+
30
+ :param sub : list of str : shorter string, also split using whitespace
31
+
32
+ :returns: length (list of int): length of the longest common subsequence between the two strings
33
+
34
+
35
+
36
+ Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
37
+
38
+ """
39
+
40
+ if (len(string) < len(sub)):
41
+ sub, string = string, sub
42
+
43
+ lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]
44
+
45
+ for j in range(1, len(sub) + 1):
46
+
47
+ for i in range(1, len(string) + 1):
48
+
49
+ if (string[i - 1] == sub[j - 1]):
50
+
51
+ lengths[i][j] = lengths[i - 1][j - 1] + 1
52
+
53
+ else:
54
+
55
+ lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])
56
+
57
+ return lengths[len(string)][len(sub)]
58
+
59
+
60
+ class Rouge():
61
+ '''
62
+
63
+ Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
64
+
65
+
66
+
67
+ '''
68
+
69
+ def __init__(self):
70
+
71
+ # vrama91: updated the value below based on discussion with Hovy
72
+
73
+ self.beta = 1.2
74
+
75
+ def calc_score(self, candidate, refs):
76
+
77
+ """
78
+
79
+ Compute ROUGE-L score given one candidate and references for an image
80
+
81
+ :param candidate: str : candidate sentence to be evaluated
82
+
83
+ :param refs: list of str : COCO reference sentences for the particular image to be evaluated
84
+
85
+ :returns score: int (ROUGE-L score for the candidate evaluated against references)
86
+
87
+ """
88
+
89
+
90
+ assert (len(candidate) == 1)
91
+
92
+ assert (len(refs) > 0)
93
+
94
+ prec = []
95
+
96
+ rec = []
97
+
98
+ # split into tokens
99
+
100
+ token_c = candidate[0].split(" ")
101
+
102
+ for reference in refs:
103
+ # split into tokens
104
+
105
+
106
+ token_r = reference.split(" ")
107
+
108
+ # compute the longest common subsequence
109
+
110
+ lcs = my_lcs(token_r, token_c)
111
+
112
+ prec.append(lcs / float(len(token_c)))
113
+
114
+ rec.append(lcs / float(len(token_r)))
115
+
116
+ prec_max = max(prec)
117
+
118
+ rec_max = max(rec)
119
+
120
+ if (prec_max != 0 and rec_max != 0):
121
+
122
+ score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
123
+
124
+ else:
125
+
126
+ score = 0.0
127
+
128
+ return score
129
+
130
+ def compute_score(self, references, hypotheses):
131
+
132
+ """
133
+
134
+ Computes Rouge-L score given a set of reference and candidate sentences for the dataset
135
+
136
+ Invoked by evaluate_captions.py
137
+
138
+ :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
139
+
140
+ :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
141
+
142
+ :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
143
+
144
+ """
145
+
146
+ # assert (gts.keys() == res.keys())
147
+ #
148
+ # imgIds = gts.keys()
149
+
150
+ score = []
151
+
152
+ for i in range(len(hypotheses)):
153
+ hypo = hypotheses[i]
154
+ ref = references[i]
155
+
156
+ score.append(self.calc_score(hypo, ref))
157
+
158
+ # Sanity check.
159
+
160
+ assert (type(hypo) is list)
161
+
162
+ assert (len(hypo) == 1)
163
+
164
+ assert (type(ref) is list)
165
+
166
+ assert (len(ref) > 0)
167
+
168
+ average_score = np.mean(np.array(score))
169
+
170
+ return average_score, np.array(score)
171
+
172
+ def method(self):
173
+
174
+ return "Rouge"
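
Rouge follows the same pattern, taking (references, hypotheses) in the argument order its compute_score signature declares; a toy sketch:

    from eval_func.rouge.rouge import Rouge

    references = [['a road has been built across the field', 'a new road crosses the bare land']]
    hypotheses = [['a road has been built in the field']]

    average, per_image = Rouge().compute_score(references, hypotheses)
    print(average)   # mean ROUGE-L F-score over the corpus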
src/model.py ADDED
@@ -0,0 +1,450 @@
1
+ import torch
2
+ from torch import nn
3
+ from einops import rearrange
4
+ import math
5
+ from torch import Tensor
6
+ import torchvision.models as models
7
+
8
+ from torch.nn import functional as F
9
+
10
+
11
+ class ICCModel(nn.Module):
12
+ def __init__(self, device, pretrained, backbone, d_model, vocab_size, max_len,
13
+ num_heads, h_dim, a_dim, encoder_layers, decoder_layers, dropout,
14
+ learnable=False, fine_tune=True, tie_embeddings=True, prenorm=False):
15
+ super(ICCModel, self).__init__()
16
+
17
+ self.feature_dim = d_model
18
+ visual = pretrained.visual if pretrained else None
19
+ self.encoder = ImagesEncoder(device, visual, backbone, d_model, num_heads, h_dim, a_dim, dropout,
20
+ encoder_layers, fine_tune)
21
+
22
+ self.decoder = Decoder(device, d_model, vocab_size, max_len, num_heads,
23
+ decoder_layers, dropout,
24
+ learnable=learnable, tie_embeddings=tie_embeddings, prenorm=prenorm)
25
+
26
+ def forward(self, img1, img2, input_ids, labels, attention_mask):
27
+ vis_emb, vis_toks = self.encoder(img1, img2)
28
+ cap_loss, text_emb, lm_logits, weights = self.decoder(input_ids, labels, attention_mask, vis_toks)
29
+ return cap_loss, vis_emb, text_emb, vis_toks, lm_logits, weights
30
+
31
+
32
+ class ImagesEncoder(nn.Module):
33
+ def __init__(self, device, pretrained, backbone, d_model, num_heads, h_dim, a_dim, dropout,
34
+ encoder_layers, fine_tune):
35
+ super(ImagesEncoder, self).__init__()
36
+ self.encoder = Encoder(pretrained, backbone, d_model, fine_tune)
37
+ self.encoder_trans = AttentiveEncoder(device, encoder_layers,
38
+ [self.encoder.feat_size, self.encoder.feat_size, d_model], num_heads,
39
+ hidden_dim=h_dim, attention_dim=a_dim, dropout=dropout)
40
+
41
+ self.cos = torch.nn.CosineSimilarity(dim=1)
42
+ self.Conv1 = nn.Conv2d(d_model * 2, d_model, kernel_size=1)
43
+ self.LN = resblock(d_model, d_model)
44
+
45
+ self.att_pool = nn.MultiheadAttention(d_model, num_heads)
46
+ self.att_pool_norm = nn.LayerNorm(d_model)
47
+ self.img_queries = nn.Parameter(torch.randn(1, d_model))
48
+
49
+ def forward(self, img1, img2):
50
+ feat1 = self.encoder(img1)
51
+ feat2 = self.encoder(img2)
52
+ x1, x2 = self.encoder_trans(feat1, feat2) # batch_size, channel, enc_image_size, enc_image_size
53
+
54
+ x_sam = self.cos(x1, x2)
55
+ x = torch.cat([x1, x2], dim=1) + x_sam.unsqueeze(1) # batch_size, 2channel, enc_image_size, enc_image_size
56
+ x = self.LN(self.Conv1(x))
57
+ batch, channel = x.size(0), x.size(1)
58
+ x = x.view(batch, channel, -1).permute(2, 0, 1) # h*w, batch, dim
59
+
60
+ img_queries = self.img_queries.unsqueeze(1).repeat(1, x.shape[1], 1) # L,N,E
61
+ img_emb = self.att_pool(img_queries, x, x, need_weights=False)[0]
62
+ img_emb = self.att_pool_norm(img_emb) # 1, batch, d_model
63
+
64
+ cls = img_emb[0]
65
+ return cls, x
66
+
67
+
68
+ class Encoder(nn.Module):
69
+ def __init__(self, pretrained, backbone, d_model, fine_tune):
70
+ super(Encoder, self).__init__()
71
+ self.backbone = backbone
72
+
73
+ if 'rn' in backbone.lower():
74
+ modules = list(pretrained.children())[:-1]
75
+ self.net = nn.Sequential(*modules)
76
+ self.feat_dim = 2048
77
+ self.feat_size = 7
78
+ elif 'b-32' in backbone.lower():
79
+ self.net = pretrained
80
+ self.net.output_tokens = True
81
+ self.feat_dim = 768
82
+ self.feat_size = 7
83
+ elif 'l-14' in backbone.lower():
84
+ self.net = pretrained
85
+ self.net.output_tokens = True
86
+ self.feat_dim = 1024
87
+ self.feat_size = 16
88
+ elif backbone == 'resnet50':
89
+ net = models.resnet50(pretrained=True)
90
+ modules = list(net.children())[:-2]
91
+ self.net = nn.Sequential(*modules)
92
+ self.feat_dim = 2048
93
+ self.feat_size = 8
94
+ elif backbone == 'resnet101':
95
+ net = models.resnet101(pretrained=True)
96
+ modules = list(net.children())[:-2]
97
+ self.net = nn.Sequential(*modules)
98
+ self.feat_dim = 2048
99
+ self.feat_size = 8
100
+
101
+ self.proj = None
102
+ if self.feat_dim != d_model:
103
+ self.proj = nn.Conv2d(self.feat_dim, d_model, kernel_size=1)
104
+
105
+ self.fine_tune(fine_tune)
106
+
107
+ def forward(self, image):
108
+ feat = self.net(image) # batch, feat_dim, feat_size, feat_size
109
+ if 'vit' in self.backbone.lower():
110
+ feat = feat[1].reshape(-1, self.feat_size, self.feat_size, self.feat_dim).permute(0, 3, 1, 2)
111
+
112
+ if self.proj:
113
+ feat = self.proj(feat)
114
+
115
+ return feat
116
+
117
+ def fine_tune(self, fine_tune=True):
118
+ for p in self.net.parameters():
119
+ p.requires_grad = False
120
+
121
+ if 'resnet' in self.backbone:
122
+ to_finetune = list(self.net.children())[-5:]
123
+ elif 'vit' in self.backbone.lower():
124
+ to_finetune = list(self.net.children())[-2:] # only transformer layers
125
+ else:
126
+ to_finetune = list(self.net.children())[-3:] # only fine-tune convolutional blocks 2 through 4
127
+
128
+ for c in to_finetune:
129
+ for p in c.parameters():
130
+ p.requires_grad = fine_tune
131
+
132
+
133
+ class FeedForward(nn.Module):
134
+ def __init__(self, dim, hidden_dim, dropout=0.):
135
+ super(FeedForward, self).__init__()
136
+ self.net = nn.Sequential(
137
+ nn.Linear(dim, hidden_dim),
138
+ nn.ReLU(),
139
+ nn.Dropout(dropout),
140
+ nn.Linear(hidden_dim, dim),
141
+ nn.Dropout(dropout)
142
+ )
143
+
144
+ def forward(self, x):
145
+ return self.net(x)
146
+
147
+
148
+ class MultiHeadAtt(nn.Module):
149
+ def __init__(self, dim_q, dim_kv, attention_dim, heads=8, dropout=0.):
150
+ super(MultiHeadAtt, self).__init__()
151
+ project_out = not (heads == 1 and attention_dim == dim_kv)
152
+ self.heads = heads
153
+ self.scale = (attention_dim // self.heads) ** -0.5
154
+
155
+ self.to_q = nn.Linear(dim_q, attention_dim, bias=False)
156
+ self.to_k = nn.Linear(dim_kv, attention_dim, bias=False)
157
+ self.to_v = nn.Linear(dim_kv, attention_dim, bias=False)
158
+ self.attend = nn.Softmax(dim=-1)
159
+ self.dropout = nn.Dropout(dropout)
160
+ self.to_out = nn.Sequential(
161
+ nn.Linear(attention_dim, dim_q),
162
+ nn.Dropout(dropout)
163
+ ) if project_out else nn.Identity()
164
+
165
+ def forward(self, x1, x2, x3):
166
+ q = self.to_q(x1)
167
+ k = self.to_k(x2)
168
+ v = self.to_v(x3)
169
+ q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
170
+ k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
171
+ v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)
172
+ dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
173
+
174
+ attn = self.dropout(self.attend(dots))
175
+ out = torch.matmul(attn, v)
176
+ out = rearrange(out, 'b h n d -> b n (h d)')
177
+ return self.to_out(out) # (b,n,dim)
178
+
179
+
180
+ class Transformer(nn.Module):
181
+ def __init__(self, dim_q, dim_kv, heads, attention_dim, hidden_dim, dropout=0., norm_first=False):
182
+ super(Transformer, self).__init__()
183
+ self.norm_first = norm_first
184
+ self.att = MultiHeadAtt(dim_q, dim_kv, attention_dim, heads=heads, dropout=dropout)
185
+ self.feedforward = FeedForward(dim_q, hidden_dim, dropout=dropout)
186
+ self.norm1 = nn.LayerNorm(dim_q)
187
+ self.norm2 = nn.LayerNorm(dim_q)
188
+
189
+ def forward(self, x1, x2, x3):
190
+ if self.norm_first:
191
+ x = self.att(self.norm1(x1), self.norm1(x2), self.norm1(x3)) + x1
192
+ x = self.feedforward(self.norm2(x)) + x
193
+ else:
194
+ x = self.norm1(self.att(x1, x2, x3) + x1)
195
+ x = self.norm2(self.feedforward(x) + x)
196
+
197
+ return x
198
+
199
+
200
+ class AttentiveEncoder(nn.Module):
201
+ def __init__(self, device, n_layers, feature_size, heads, hidden_dim=512, attention_dim=512, dropout=0.):
202
+ super(AttentiveEncoder, self).__init__()
203
+ h_feat, w_feat, channels = feature_size
204
+
205
+ self.device = device
206
+ self.h_embedding = nn.Embedding(h_feat, int(channels / 2))
207
+ self.w_embedding = nn.Embedding(w_feat, int(channels / 2))
208
+ self.selftrans = nn.ModuleList([])
209
+ for i in range(n_layers):
210
+ self.selftrans.append(nn.ModuleList([
211
+ Transformer(channels, channels, heads, attention_dim, hidden_dim, dropout, norm_first=False),
212
+ Transformer(channels * 2, channels * 2, heads, attention_dim, hidden_dim, dropout, norm_first=False),
213
+ ]))
214
+
215
+ self._reset_parameters()
216
+
217
+ def _reset_parameters(self):
218
+ for p in self.parameters():
219
+ if p.dim() > 1:
220
+ nn.init.xavier_uniform_(p)
221
+
222
+ def forward(self, img1, img2):
223
+ batch, c, h, w = img1.shape
224
+ pos_h = torch.arange(h).to(self.device)
225
+ pos_w = torch.arange(w).to(self.device)
226
+ embed_h = self.w_embedding(pos_h)
227
+ embed_w = self.h_embedding(pos_w)
228
+ pos_embedding = torch.cat([embed_w.unsqueeze(0).repeat(h, 1, 1),
229
+ embed_h.unsqueeze(1).repeat(1, w, 1)],
230
+ dim=-1)
231
+ pos_embedding = pos_embedding.permute(2, 0, 1).unsqueeze(0).repeat(batch, 1, 1, 1)
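+ # (batch, channels, h, w) positional grid built from separate per-column and per-row embeddings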
232
+ img1 = img1 + pos_embedding
233
+ img2 = img2 + pos_embedding
234
+ img1 = img1.view(batch, c, -1).transpose(-1, 1) # batch, hw, c
235
+ img2 = img2.view(batch, c, -1).transpose(-1, 1)
236
+ img_sa1, img_sa2 = img1, img2
237
+
238
+ for (l, m) in self.selftrans:
239
+ img_sa1 = l(img_sa1, img_sa1, img_sa1) + img_sa1
240
+ img_sa2 = l(img_sa2, img_sa2, img_sa2) + img_sa2
241
+ img = torch.cat([img_sa1, img_sa2], dim=-1)
242
+ img = m(img, img, img)
243
+ img_sa1 = img[:, :, :c] + img1
244
+ img_sa2 = img[:, :, c:] + img2
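+ # split the jointly attended features back into the two image streams, each with a residual to its position-encoded input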
245
+
246
+ img1 = img_sa1.reshape(batch, h, w, c).transpose(-1, 1)
247
+ img2 = img_sa2.reshape(batch, h, w, c).transpose(-1, 1)
248
+
249
+ return img1, img2
250
+
251
+
252
+ class resblock(nn.Module):
253
+ def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
254
+ super(resblock, self).__init__()
255
+ self.left = nn.Sequential(
256
+ nn.Conv2d(inchannel, int(outchannel / 2), kernel_size=1),
257
+ # nn.LayerNorm(int(outchannel/2),dim=1),
258
+ nn.BatchNorm2d(int(outchannel / 2)),
259
+ nn.ReLU(),
260
+ nn.Conv2d(int(outchannel / 2), int(outchannel / 2), kernel_size=3, stride=1, padding=1),
261
+ # nn.LayerNorm(int(outchannel/2),dim=1),
262
+ nn.BatchNorm2d(int(outchannel / 2)),
263
+ nn.ReLU(),
264
+ nn.Conv2d(int(outchannel / 2), outchannel, kernel_size=1),
265
+ # nn.LayerNorm(int(outchannel / 1),dim=1)
266
+ nn.BatchNorm2d(outchannel)
267
+ )
268
+ self.right = shortcut
269
+
270
+ def forward(self, x):
271
+ out = self.left(x)
272
+ residual = x
273
+ out = out + residual
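+ # note: the optional shortcut module self.right is not applied; the raw input is used as the residual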
274
+ return F.relu(out)
275
+
276
+
277
+ class Decoder(nn.Module):
278
+ def __init__(self, device, h_dim, vocab_size, max_len, n_head, n_layers, dropout=0.10,
279
+ learnable=False, tie_embeddings=True, prenorm=False):
280
+
281
+ super(Decoder, self).__init__()
282
+
283
+ self.embed_dim = h_dim
284
+ self.vocab_size = vocab_size
285
+ self.dropout = dropout
286
+ self.device = device
287
+
288
+ self.tokens_embed = nn.Embedding(vocab_size, self.embed_dim)
289
+ self.position_encoding = PositionalEncoding(self.embed_dim, dropout=dropout, max_len=max_len,
290
+ device=device, learnable=learnable)
291
+
292
+ self.uni_decoder = nn.ModuleList(
293
+ [DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
294
+ crossattention=False) for _ in range(n_layers)])
295
+
296
+ self.cross_decoder = nn.ModuleList(
297
+ [DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
298
+ crossattention=True) for _ in range(n_layers)])
299
+
300
+ self.lm_head = nn.Linear(h_dim, vocab_size, bias=False)
301
+ if tie_embeddings:
302
+ self.tokens_embed.weight = self.lm_head.weight
303
+ self.dropout = nn.Dropout(p=self.dropout)
304
+ self.init_weights()
305
+ self.loss_fn = nn.CrossEntropyLoss()
306
+
307
+ def init_weights(self):
308
+ self.tokens_embed.weight.data.uniform_(-0.1, 0.1)
309
+ self.lm_head.weight.data.uniform_(-0.1, 0.1)
310
+
311
+ def forward(self, input_ids=None, labels=None, pad_mask=None, img_emb=None):
312
+ att_weights = None
313
+ mask = torch.tril(torch.ones(input_ids.shape[1], input_ids.shape[1]))
314
+ mask = ~mask.bool()
315
+ mask = mask.to(self.device)
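+ # boolean causal mask: True marks future positions that self-attention may not attend to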
316
+
317
+ inputs_embeds = self.tokens_embed(input_ids)
318
+ inputs_embeds = self.position_encoding(inputs_embeds) # batch, seq, e_dim
319
+ inputs_embeds = inputs_embeds.permute(1, 0, 2) # seq, batch, e_dim
320
+
321
+ # seq, batch, emb_dim
322
+ out = inputs_embeds
323
+ for block in self.uni_decoder:
324
+ out, _ = block(out, None, tgt_mask=mask, tgt_key_padding_mask=pad_mask)
325
+
326
+ if pad_mask is not None: # not inference
327
+ cls = []
328
+ for i in range(pad_mask.shape[0]):
329
+ end = pad_mask[i].shape[0] - pad_mask[i].count_nonzero()
330
+ cls.append(out[end - 1, i, :])
331
+
332
+ cls = torch.stack(cls) # batch, emb_dim
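+ # sequence-level embedding taken at the last non-padded position (the END token) of each caption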
333
+ else:
334
+ cls = None
335
+
336
+ if img_emb is None:
337
+ return None, cls, None, None
338
+
339
+ for block in self.cross_decoder:
340
+ out, att_weights = block(out, img_emb, tgt_mask=mask, tgt_key_padding_mask=pad_mask)
341
+
342
+ lm_logits = self.lm_head(self.dropout(out)) # seq, batch, voc_dim
343
+ lm_logits = lm_logits.permute(1, 0, 2) # batch, seq, voc_dim
344
+
345
+ if labels is not None: # not inference
346
+ shift_logits = lm_logits[..., :-1, :].contiguous()
347
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
348
+ shift_labels = labels[..., 1:].contiguous()
349
+ shift_labels = shift_labels.view(-1)
350
+ loss = self.loss_fn(shift_logits, shift_labels)
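+ # next-token objective: logits at step t are scored against the ground-truth token at step t+1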
351
+ else:
352
+ loss = None
353
+
354
+ return loss, cls, lm_logits, att_weights
355
+
356
+
357
+ class PositionalEncoding(nn.Module):
358
+ def __init__(self, d_model, dropout, max_len, device, learnable=False):
359
+ super(PositionalEncoding, self).__init__()
360
+ self.learnable = learnable
361
+ self.max_len = max_len
362
+ self.device = device
363
+ self.dropout = nn.Dropout(p=dropout)
364
+
365
+ if not learnable:
366
+ pe = torch.zeros(max_len, d_model)
367
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
368
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
369
+ pe[:, 0::2] = torch.sin(position * div_term)
370
+ pe[:, 1::2] = torch.cos(position * div_term)
371
+ pe = pe.unsqueeze(0)
372
+ self.register_buffer('pe', pe)
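+ # fixed sinusoidal positional encoding, registered as a buffer so it follows the module's device without being trained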
373
+ else:
374
+ self.pos_emb = nn.Embedding(max_len, int(d_model))
375
+
376
+ def forward(self, x):
377
+ if self.learnable:
378
+ position_ids = torch.arange(x.size(1), dtype=torch.long).to(self.device)
379
+ position_ids = position_ids.unsqueeze(0).view(-1, x.size(1)) # batch, seq
380
+ x = x + self.pos_emb(position_ids)
381
+ else:
382
+ x = x + self.pe[:, :x.size(1), :]
383
+ return self.dropout(x)
384
+
385
+
386
+ class DecoderLayer(nn.Module):
387
+ def __init__(self, d_model, img_dim, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5,
388
+ prenorm=False, crossattention=False):
389
+ super(DecoderLayer, self).__init__()
390
+
391
+ self.prenorm = prenorm
392
+ self.crossattention = crossattention
393
+ self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
394
+
395
+ if crossattention:
396
+ self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=img_dim, vdim=img_dim)
397
+ self.mha_dropout = nn.Dropout(dropout)
398
+ self.mha_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
399
+
400
+ self.ff_linear1 = nn.Linear(d_model, dim_feedforward)
401
+ self.ff_dropout = nn.Dropout(dropout)
402
+ self.ff_linear2 = nn.Linear(dim_feedforward, d_model)
403
+
404
+ self.sa_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
405
+ self.ff_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
406
+
407
+ self.sa_dropout = nn.Dropout(dropout)
408
+ self.ff_dropout = nn.Dropout(dropout)
409
+
410
+ self.activation = nn.GELU()
411
+
412
+ def forward(self, tgt: Tensor, memory: Tensor, tgt_mask=None,
413
+ memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
414
+ att_weight = None
415
+ x = tgt
416
+
417
+ if self.prenorm:
418
+ x = x + self._sa_block(self.sa_norm(x), tgt_mask, tgt_key_padding_mask)
419
+ if self.crossattention:
420
+ enc_att, att_weight = self._mha_block(self.mha_norm(x), memory, memory_mask, memory_key_padding_mask)
421
+ x = x + enc_att
422
+ x = x + self._ff_block(self.ff_norm(x))
423
+ else:
424
+ x = self.sa_norm(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
425
+ if self.crossattention:
426
+ enc_att, att_weight = self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
427
+ x = self.mha_norm(x + enc_att)
428
+
429
+ x = self.ff_norm(x + self._ff_block(x))
430
+ return x, att_weight
431
+
432
+ def _sa_block(self, x, attn_mask, key_padding_mask):
433
+ x = self.self_attn(x, x, x, # L,N,E
434
+ attn_mask=attn_mask, # L, S
435
+ key_padding_mask=key_padding_mask, # N, S
436
+ is_causal=True,
437
+ need_weights=False)[0]
438
+ return self.sa_dropout(x)
439
+
440
+ def _mha_block(self, x, mem, attn_mask, key_padding_mask):
441
+ x, att_weight = self.cross_attn(x, mem, mem,
442
+ attn_mask=attn_mask,
443
+ key_padding_mask=key_padding_mask,
444
+ is_causal=False,
445
+ need_weights=True)
446
+ return self.mha_dropout(x), att_weight
447
+
448
+ def _ff_block(self, x):
449
+ x = self.ff_linear2(self.ff_dropout(self.activation(self.ff_linear1(x))))
450
+ return self.ff_dropout(x)
src/train.py ADDED
@@ -0,0 +1,198 @@
1
+ import argparse
2
+ import random
3
+ import os
4
+ from datetime import datetime
5
+
6
+ import numpy as np
7
+ import torch
8
+ import json
9
+ from torch.optim import AdamW
10
+ from torchvision.transforms import v2
11
+ from torch.utils.tensorboard import SummaryWriter
12
+ from tqdm import tqdm
13
+ from transformers import get_constant_schedule_with_warmup
14
+
15
+ from Datasets import CCDataset, Batcher
16
+ from model import ICCModel
17
+ from utils import get_vocabulary
18
+ from Loss import InfoNCELoss
19
+ from eval import captioning, retrieve, plot
20
+ from huggingface_hub import hf_hub_download
21
+ import open_clip
22
+
23
+
24
+ def train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer):
25
+ step = 0
26
+ best_score = float("inf")
27
+ best_model = None
28
+
29
+ for epoch in range(args.epochs):
30
+ model.train()
31
+
32
+ for batch in tqdm(train_loader, desc='Epoch ' + str(epoch)):
33
+ imgs1 = batch['images_before'].to(device)
34
+ imgs2 = batch['images_after'].to(device)
35
+ toks = batch['input_ids'].to(device)
36
+ labs = batch['labels'].to(device)
37
+ flags = batch['flags'].to(device)
38
+ attention_mask = batch['pad_mask'].to(device)
39
+ embs = batch['embs'].to(device)
40
+
41
+ cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)
42
+
43
+ con_loss, num_pos = infonce(vis_emb, text_emb, flags, embs)
44
+ loss = cap_loss + args.lamb * con_loss
45
+ loss.backward()
46
+
47
+ if args.max_grad_norm:
48
+ torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
49
+ grad = torch.norm(torch.stack(
50
+ [torch.norm(p.grad.detach()).to(device) for p in model.parameters() if p.grad is not None]))
51
+
52
+ optim.step()
53
+ scheduler.step()
54
+ optim.zero_grad()
55
+
56
+ writer.add_scalar('train_loss', loss.item(), step)
57
+ writer.add_scalar('grad', grad, step)
58
+ writer.add_scalar('lr', scheduler.get_last_lr()[0], step)
59
+
60
+ step += 1
61
+
62
+ torch.save(model.state_dict(), os.path.join(args.output_path, 'model_{}.pt'.format(step)))
63
+
64
+ model.eval()
65
+ with torch.no_grad():
66
+ eval_losses = torch.empty(0)
67
+ for batch in tqdm(valid_loader, desc='Validation ' + str(epoch)):
68
+ imgs1 = batch['images_before'].to(device)
69
+ imgs2 = batch['images_after'].to(device)
70
+ toks = batch['input_ids'].to(device)
71
+ labs = batch['labels'].to(device)
72
+ flags = batch['flags'].to(device)
73
+ attention_mask = batch['pad_mask'].to(device)
74
+ embs = batch['embs'].to(device)
75
+
76
+ cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)
77
+
78
+ con_loss, _ = infonce(vis_emb, text_emb, flags, embs)
79
+ loss = cap_loss + args.lamb * con_loss
80
+ eval_losses = torch.cat([eval_losses, loss.cpu().unsqueeze(0)])
81
+
82
+ eval_score = torch.mean(eval_losses)
83
+ writer.add_scalar('eval_score', eval_score, step)
84
+
85
+ is_best = eval_score < best_score
86
+ best_score = min(eval_score, best_score)
87
+ if is_best:
88
+ best_model = step
89
+
90
+ if best_model is not None:
91
+ state_dict = torch.load(os.path.join(args.output_path + 'model_{}.pt'.format(best_model)), map_location=device)
92
+ torch.save(state_dict, args.output_path + 'model_best.pt')
93
+
94
+
95
+ def run(args, config):
96
+ print('Initializing...')
97
+ torch.manual_seed(args.seed)
98
+ np.random.seed(args.seed)
99
+ random.seed(args.seed)
100
+ torch.backends.cudnn.deterministic = True
101
+
102
+ device = torch.device('cpu')
103
+ if torch.cuda.is_available():
104
+ device = torch.device('cuda')
105
+
106
+ dt_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
107
+ writer_path = args.output_path + dt_str
108
+ writer = SummaryWriter(writer_path)
109
+
110
+ if os.path.exists(args.vocab):
111
+ with open(args.vocab, 'r') as infile:
112
+ vocab = json.load(infile)
113
+ else:
114
+ vocab = get_vocabulary(args.annotation_json, args.vocab)
115
+
116
+ clip = None
117
+ preprocess = v2.Compose([
118
+ v2.ToImage(),
119
+ v2.ToDtype(torch.float32, scale=True),
120
+ v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
121
+ ])
122
+
123
+ if 'resnet' not in config['backbone']:
124
+ checkpoint_path = hf_hub_download("chendelong/RemoteCLIP",
125
+ f"RemoteCLIP-{config['backbone']}.pt",
126
+ cache_dir=args.pretrained)
127
+
128
+ clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
129
+ ckpt = torch.load(checkpoint_path, map_location="cpu")
130
+ clip.load_state_dict(ckpt)
131
+
132
+ model = ICCModel(device, clip, config['backbone'], config['d_model'],
133
+ len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
134
+ config['encoder_layers'], config['decoder_layers'], config['dropout'],
135
+ learnable=config['learnable'], fine_tune=config['fine_tune'],
136
+ tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])
137
+ model = model.to(device)
138
+ del clip
139
+
140
+ print('Loading...')
141
+ training_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'train', config['max_len'],
142
+ config['s-transformers'], device)
143
+ valid_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'val', config['max_len'],
144
+ config['s-transformers'], device)
145
+ test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
146
+ config['s-transformers'], device)
147
+
148
+ train_loader = Batcher(training_set, args.batch_size, config['max_len'], device, args.hd, model=model, shuffle=True)
149
+ valid_loader = Batcher(valid_set, args.batch_size, config['max_len'], device)
150
+ test_loader = Batcher(test_set, 1, config['max_len'], device)
151
+
152
+ print('Training...')
153
+ infonce = InfoNCELoss(device, k=args.k, temperature=args.temperature, threshold=config['s-threshold'],
154
+ fna=config['fna'])
155
+ optim = AdamW([x for x in model.parameters() if x.requires_grad], lr=args.learning_rate, eps=args.adam_epsilon)
156
+ scheduler = get_constant_schedule_with_warmup(optim,
157
+ num_warmup_steps=args.warmup_steps * len(train_loader) * args.epochs)
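+ # --warmup_steps is a fraction of the total optimisation steps (batches per epoch * epochs)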
158
+ train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer)
159
+
160
+ print('Final evaluation...')
161
+ model.load_state_dict(torch.load(os.path.join(args.output_path, 'model_best.pt'), map_location=device))
162
+ results = captioning(args, config, model, test_loader, vocab, device)
163
+ retrieve(args, config, model, test_loader, device)
164
+ plot(args, model.encoder.encoder.feat_size, results)
165
+
166
+
167
+ def main():
168
+ parser = argparse.ArgumentParser()
169
+ parser.add_argument('--annotation_json', type=str, default='../input/Levir_CC/LevirCCcaptions.json')
170
+ parser.add_argument('--image_dir', type=str, default='../input/Levir_CC/images/')
171
+ parser.add_argument('--vocab', type=str, default='../input/levir_vocab.json')
172
+ parser.add_argument('--pretrained', type=str, default='../../input/checkpoints')
173
+ parser.add_argument('--config', type=str, default='../config.json')
174
+ parser.add_argument('--output_path', type=str, default='../output/')
175
+
176
+ parser.add_argument('--epochs', type=int, default=50)
177
+ parser.add_argument('--batch_size', type=int, default=4)
178
+ parser.add_argument('--k', type=int, default=-1)
179
+ parser.add_argument('--hd', type=int, default=-1)
180
+ parser.add_argument('--learning_rate', type=float, default=1e-4)
181
+ parser.add_argument('--warmup_steps', type=float, default=0.025)
182
+ parser.add_argument('--lr_decay', type=float, default=0.7)
183
+ parser.add_argument('--adam_epsilon', type=float, default=1e-8)
184
+ parser.add_argument('--max_grad_norm', type=float, default=None)
185
+ parser.add_argument('--temperature', type=float, default=0.01)
186
+ parser.add_argument('--lamb', type=float, default=0.5)
187
+ parser.add_argument('--seed', type=int, default=42)
188
+
189
+ args = parser.parse_args()
190
+
191
+ with open(args.config, 'r') as config_file:
192
+ config = json.load(config_file)
193
+
194
+ run(args, config)
195
+
196
+
197
+ if __name__ == '__main__':
198
+ main()
src/utils.py ADDED
@@ -0,0 +1,115 @@
1
+ import json
2
+ import os
3
+
4
+ import pandas as pd
5
+ import torch
6
+ import glob
7
+
8
+ from eval_func.bleu.bleu import Bleu
9
+ from eval_func.rouge.rouge import Rouge
10
+ from eval_func.cider.cider import Cider
11
+
12
+
13
+ def get_eval_score(ref, hypo):
14
+ scorers = [
15
+ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
16
+ (Rouge(), "ROUGE_L"),
17
+ (Cider(), "CIDEr")
18
+ ]
19
+
20
+ score = []
21
+ method = []
22
+ for scorer, method_i in scorers:
23
+ score_i, scores_i = scorer.compute_score(ref, hypo)
24
+ score.extend(score_i) if isinstance(score_i, list) else score.append(score_i)
25
+ method.extend(method_i) if isinstance(method_i, list) else method.append(method_i)
26
+
27
+ score_dict = dict(zip(method, score))
28
+ return score_dict
29
+
30
+
31
+ def get_vocabulary(in_path, out_file):
32
+ if 'levir' in in_path.lower():
33
+ return get_levir_vocabulary(in_path, out_file)
34
+ elif 'dubai' in in_path.lower():
35
+ return get_dubai_vocabulary(in_path, out_file)
36
+ elif 'clevr' in in_path.lower():
37
+ return get_clevr_vocabulary(in_path, out_file)
38
+
39
+
40
+ def get_levir_vocabulary(in_path, out_file):
41
+ with open(in_path) as fin:
42
+ data = json.load(fin)['images']
43
+
44
+ sents = [y for x in data for y in x['sentences']]
45
+ tokens = [y for x in sents for y in x['tokens']]
46
+ occurencies = pd.Series(tokens).value_counts()
47
+ selected = occurencies[occurencies > 5]
48
+ vocab = {w: i + 4 for i, w in enumerate(selected.index)}
49
+ vocab['PAD'] = 0
50
+ vocab['START'] = 1
51
+ vocab['UNK'] = 2
52
+ vocab['END'] = 3
53
+
54
+ with open(out_file, 'w') as fout:
55
+ json.dump(vocab, fout)
56
+
57
+ return vocab
58
+
59
+
60
+ def get_dubai_vocabulary(in_path, out_file):
61
+ data = []
62
+ for path in glob.glob(in_path + '/*.json'):
63
+ with open(path) as fin:
64
+ data.extend(json.load(fin)['images'])
65
+
66
+ sents = [y for x in data for y in x['sentences']]
67
+ tokens = [y for x in sents for y in x['tokens']]
68
+ selected = pd.Series(tokens).value_counts()
69
+ vocab = {w: i + 4 for i, w in enumerate(selected.index)}
70
+ vocab['PAD'] = 0
71
+ vocab['START'] = 1
72
+ vocab['UNK'] = 2
73
+ vocab['END'] = 3
74
+
75
+ with open(out_file, 'w') as fout:
76
+ json.dump(vocab, fout)
77
+
78
+ return vocab
79
+
80
+
81
+ def get_clevr_vocabulary(in_path, out_file):
82
+ sents = []
83
+ with open(os.path.join(in_path, 'change_captions.json'), 'r', encoding='utf-8') as fin:
84
+ data = json.load(fin)
85
+ sents += [y for x in data for y in data[x]]
86
+
87
+ with open(os.path.join(in_path, 'no_change_captions.json'), 'r', encoding='utf-8') as fin:
88
+ data = json.load(fin)
89
+ sents += [y for x in data for y in data[x]]
90
+
91
+ tokens = [y for x in sents for y in x.split(' ')]
92
+ occurencies = pd.Series(tokens).value_counts()
93
+ vocab = {w: i + 4 for i, w in enumerate(occurencies.index)}
94
+ vocab['PAD'] = 0
95
+ vocab['START'] = 1
96
+ vocab['UNK'] = 2
97
+ vocab['END'] = 3
98
+
99
+ with open(out_file, 'w') as fout:
100
+ json.dump(vocab, fout)
101
+
102
+ return vocab
103
+
104
+
105
+ def unormalize(tensor, mean=None, std=None):
106
+ if mean is not None and std is not None:
107
+ for t, m, s in zip(tensor, mean, std):
108
+ t.mul_(s).add_(m)
109
+ return torch.clip(tensor, min=0, max=1)
110
+
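+ # fallback when no mean/std are given: per-image min-max scaling into [0, 1]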
111
+ b, c, h, w = tensor.shape
112
+ tensor = tensor.view(b, -1)
113
+ tensor -= tensor.min(1, keepdim=True)[0]
114
+ tensor /= tensor.max(1, keepdim=True)[0]
115
+ return tensor.view(b, c, h, w)