upload code
Browse files- src/Datasets.py +269 -0
- src/Loss.py +84 -0
- src/eval.py +288 -0
- src/eval_func/__init__.py +0 -0
- src/eval_func/bleu/LICENSE +19 -0
- src/eval_func/bleu/__init__.py +1 -0
- src/eval_func/bleu/bleu.py +44 -0
- src/eval_func/bleu/bleu_scorer.py +263 -0
- src/eval_func/cider/__init__.py +1 -0
- src/eval_func/cider/cider.py +51 -0
- src/eval_func/cider/cider_scorer.py +193 -0
- src/eval_func/rouge/__init__.py +1 -0
- src/eval_func/rouge/rouge.py +174 -0
- src/model.py +450 -0
- src/train.py +198 -0
- src/utils.py +115 -0
src/Datasets.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch.utils.data import Dataset
|
| 6 |
+
from PIL import Image
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
import faiss
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from sentence_transformers import SentenceTransformer
|
| 11 |
+
import torchvision.transforms as transforms
|
| 12 |
+
from random import choice
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class CCDataset(Dataset):
    """Change-captioning dataset (e.g. LevirCC).

    Loads a Karpathy-style annotation JSON, keeps entries of the requested
    ``split``, tokenizes every caption, computes sentence embeddings with a
    SentenceTransformer, and pre-builds per-caption and per-image sample
    dicts so ``__getitem__`` is cheap. Each image pair is expected to carry
    exactly 5 captions; shorter lists are padded by repeating a random one.
    """

    def __init__(self, json_file, root_dir, vocab, transform, split, max_length, s_pretrained, device):
        """
        Args:
            json_file: annotation JSON with an 'images' list (Karpathy format).
            root_dir: image root; images live in <root>/<filepath>/{A,B}/<filename>.
            vocab: token -> id mapping containing 'PAD', 'START', 'END', 'UNK'.
            transform: torchvision transform applied to each PIL image.
            split: one of 'train', 'val', 'test'.
            max_length: fixed caption length after START/END and padding.
            s_pretrained: SentenceTransformer checkpoint name.
            device: device used for the sentence-embedding model.
        """
        super(CCDataset, self).__init__()
        self.vocab = vocab
        self.split = split
        self.max_length = max_length
        self.device = device
        self.transform = transform
        assert self.split in {'train', 'val', 'test'}

        s_model = SentenceTransformer(s_pretrained)
        self.s_model = s_model.to(device)

        self.root_dir = root_dir
        self.convert = transforms.ToTensor()

        with open(json_file) as f:
            data = json.load(f)['images']

        self.raw_dataset = [entry for entry in data if entry['split'] == split]
        self.sentences = []
        self.embeddings = []

        self.images = []
        self.captions = []
        for record in tqdm(self.raw_dataset, desc='Tokenize ' + self.split):
            self.sentences.extend(self.tokenize(record['sentences']))

        for record in tqdm(self.raw_dataset, desc='Embeddings ' + self.split):
            self.embeddings.extend(self.compute_embeddings(record['sentences']))

        self.preprocess()
        # Raw annotations and the embedding model are no longer needed;
        # free them so the dataset object stays small.
        del self.raw_dataset
        del self.sentences
        del self.embeddings
        del self.s_model

    def tokenize(self, batch):
        """Tokenize one image's caption list in place and pad it to 5 entries.

        Adds 'input_ids' (START + ids + END + PAD...) and 'mask'
        (True = padding position) to each caption dict.
        """
        for elem in batch:
            # Map unknown words to UNK. (idiom fix: `in self.vocab` instead
            # of `in self.vocab.keys()` — same behavior, single dict lookup)
            tokens = [self.vocab[x] if x in self.vocab else self.vocab['UNK'] for x in elem['tokens']]
            if len(tokens) > self.max_length - 2:
                # NOTE(review): over-long captions are skipped and never get
                # 'input_ids'/'mask', yet stay in `batch`; preprocess() assumes
                # every entry has them — confirm the data never exceeds
                # max_length - 2 tokens.
                continue

            tokens = [self.vocab['START']] + tokens + [self.vocab['END']]

            mask = [False] * len(tokens)

            diff = self.max_length - len(tokens)
            tokens += [self.vocab['PAD']] * diff
            mask += [True] * diff  # True = pad

            elem['input_ids'] = tokens
            elem['mask'] = mask

        if len(batch) < 5:
            # Pad to exactly 5 captions per image by repeating random ones.
            diff = 5 - len(batch)
            batch += [choice(batch) for _ in range(diff)]

        assert len(batch) == 5
        return batch

    def compute_embeddings(self, batch):
        """Sentence embeddings (one per caption) for an image's caption list."""
        sents = [x['raw'].strip() for x in batch]
        embs = self.s_model.encode(sents)
        return embs

    def __len__(self):
        # One sample per caption (train) or per image (val/test); see preprocess().
        return len(self.captions)

    def __getitem__(self, idx):
        """Return the caption dict merged with its image-pair tensors."""
        # In train mode captions are stored per sentence (5 per image);
        # otherwise one caption entry per image.
        img_idx = idx // 5 if self.split == 'train' else idx
        elem = self.captions[idx]
        # NOTE(review): this merges image keys into the cached caption dict
        # (mutates the cache); harmless since values are overwritten each call.
        for k, v in self.images[img_idx].items():
            elem[k] = v
        return elem

    def preprocess(self):
        """Materialize per-caption tensors and per-image transformed tensors.

        Walks captions in steps of 1 (train: every caption is a sample) or
        5 (val/test: one sample per image, using its first caption).
        """
        idx = 0
        prev_idx = -1
        pbar = tqdm(total=len(self.sentences), desc='Preprocessing ' + self.split)
        while idx < len(self.sentences):
            img_idx = idx // 5
            assert (self.sentences[idx]['imgid'] == self.raw_dataset[img_idx]['imgid'])

            input_ids = torch.tensor(self.sentences[idx]['input_ids'], dtype=torch.long)
            mask = torch.tensor(self.sentences[idx]['mask'], dtype=torch.bool)
            raws = [x['raw'] for x in self.raw_dataset[img_idx]['sentences']]
            # flag: -1 for "no change" pairs, otherwise the image id (used as
            # a retrieval/contrastive label downstream).
            flag = -1 if self.raw_dataset[img_idx]['changeflag'] == 0 else self.raw_dataset[img_idx]['imgid']
            flag = torch.tensor(flag, dtype=torch.long)
            embs = torch.tensor(self.embeddings[idx]) if len(self.embeddings) > 0 else None

            self.captions.append({'input_ids': input_ids, 'pad_masks': mask, 'raws': raws, 'flags': flag, 'embs': embs})

            if img_idx != prev_idx:
                # First caption of a new image: load and transform both images once.
                before_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'A',
                                               self.raw_dataset[img_idx]['filename'])
                image_before = Image.open(before_img_path)
                after_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'B',
                                              self.raw_dataset[img_idx]['filename'])
                image_after = Image.open(after_img_path)

                image_before = self.transform(image_before).unsqueeze(0)
                image_after = self.transform(image_after).unsqueeze(0)

                self.images.append({'image_before': image_before, 'image_after': image_after, 'flags': flag})
                prev_idx = img_idx

            inc = 1 if self.split == 'train' else 5
            idx += inc
            pbar.update(inc)

        pbar.close()
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class Batcher:
    """Mini-batch iterator over a CCDataset with optional hard-negative mining.

    When ``hd > 0`` and a model is given (train split only), a FAISS
    inner-product index over L2-normalized visual embeddings is built at
    the start of every epoch; each sample's textual embedding queries it
    to collect ``hd`` hard negatives appended to the batch.
    """

    def __init__(self, dataset, batch_size, max_len, device, hd=0, model=None, shuffle=False):
        """
        Args:
            dataset: a CCDataset (must expose .split and dict samples).
            batch_size: number of primary samples per batch (negatives extra).
            max_len: fixed caption length, used to build the label tensor.
            device: device for model forward passes during index creation.
            hd: hard negatives mined per sample; 0 disables mining.
            model: captioning model providing .encoder/.decoder/.feature_dim.
            shuffle: reshuffle sample order every epoch.
        """
        self.dataset = dataset
        self.batch_size = batch_size
        self.hd = hd
        self.max_len = max_len
        self.device = device
        self.model = model
        self.index = None    # FAISS index over per-image visual embeddings
        self.visual = None   # cached visual embeddings (CPU)
        self.textual = None  # cached textual embeddings (CPU)

        self.ptr = 0
        self.indices = np.arange(len(self.dataset))
        self.shuffle = shuffle

        if shuffle:
            np.random.shuffle(self.indices)

        # Hard-negative mining is only meaningful while training.
        if model and hd > 0 and self.dataset.split == 'train':
            self.create_index()

    def __iter__(self):
        return self

    def __len__(self):
        # Number of full batches; a trailing partial batch is still yielded
        # by __next__ but not counted here.
        return len(self.dataset) // self.batch_size

    def __next__(self):
        # Epoch exhausted: reset state (and rebuild the index) before stopping.
        if self.ptr >= len(self.dataset):
            self.ptr = 0
            self.index = None
            self.visual = None
            self.textual = None

            if self.shuffle:
                np.random.shuffle(self.indices)
            if self.model and self.hd > 0 and self.dataset.split == 'train':
                self.create_index()

            raise StopIteration

        batched = 0
        samples = []
        hard_negatives = []
        while self.ptr < len(self.dataset) and batched < self.batch_size:
            sample = self.dataset[self.indices[self.ptr]]
            samples.append(sample)

            if self.hd > 0 and self.dataset.split == 'train':
                hard_neg = self.mine_negatives(self.indices[self.ptr], self.hd)
                hard_negatives.extend(hard_neg)

            self.ptr += 1
            batched += 1

        # Negatives ride along at the end of the batch.
        return self.create_batch(samples + hard_negatives)

    def get_elem(self, idx):
        """Direct access to a single dataset sample."""
        return self.dataset[idx]

    @torch.no_grad()
    def create_index(self):
        """Embed the whole dataset and build a FAISS inner-product index.

        One visual embedding per image (captions come in groups of 5) and
        one textual embedding per caption are cached on CPU, L2-normalized
        (so inner product == cosine similarity), and the visual side is
        added to the index.
        """
        is_training = self.model.training
        self.model.eval()
        self.index = faiss.IndexFlatIP(self.model.feature_dim)
        prev_img = None
        for idx in tqdm(range(len(self.dataset)), desc='Creating index'):
            sample = self.dataset[idx]
            imgs1, imgs2, = sample['image_before'], sample['image_after']
            input_ids, mask = sample['input_ids'], sample['pad_masks']

            # Only encode each image pair once (every 5th caption).
            if idx // 5 != prev_img:
                imgs1 = imgs1.to(self.device)
                imgs2 = imgs2.to(self.device)
                vis_emb, _, = self.model.encoder(imgs1, imgs2)
                self.visual = torch.cat([self.visual, vis_emb.cpu()]) if self.visual is not None else vis_emb.cpu()
                prev_img = prev_img + 1 if prev_img is not None else 0

            input_ids = input_ids.unsqueeze(0).to(self.device)
            mask = mask.unsqueeze(0).to(self.device)
            _, text_emb, _, _ = self.model.decoder(input_ids, None, mask, None)
            self.textual = torch.cat([self.textual, text_emb.cpu()]) if self.textual is not None else text_emb.cpu()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.visual = F.normalize(self.visual, p=2, dim=1)
        self.textual = F.normalize(self.textual, p=2, dim=1)
        self.index.add(self.visual)
        if is_training:
            self.model.train()

    def mine_negatives(self, idx, n):
        """Return up to ``n`` samples visually similar to caption ``idx``
        but with a different change flag (hard negatives).

        Widens the search (k = n*m, m doubling) until enough true
        negatives are found or the index is exhausted.
        """
        negatives = []
        m = 4
        label = self.dataset[idx]['flags'].item()

        while len(negatives) < n and (n * m) < self.index.ntotal:
            k = n * m
            indeces = self.index.search(self.textual[idx].unsqueeze(0), k)[1][0]
            # Index stores one entry per image; map image hit i back to its
            # first caption index i*5.
            indeces = [x * 5 for x in indeces]
            negatives = [self.dataset[x] for x in indeces if self.dataset[x]['flags'].item() != label][:n]
            m *= 2

        return negatives

    def create_batch(self, samples):
        """Collate a list of sample dicts into batched tensors.

        Labels are the input ids with padding positions replaced by -100
        (the ignore index of the LM loss).
        """
        images_before = images_after = input_ids = pad_mask = labels = flags = embs = None
        raws = []

        for sample in samples:
            img1 = sample['image_before']
            img2 = sample['image_after']

            tokens = sample['input_ids']
            mask = sample['pad_masks']
            flag = sample['flags']
            emb = sample['embs']

            tokens = tokens.unsqueeze(0)
            mask = mask.unsqueeze(0)
            flag = flag.unsqueeze(0)
            # Zero out padded positions, then write -100 into them.
            lab = tokens.clone() * ~mask
            lab += torch.tensor([[-100]], dtype=torch.long).repeat(1, self.max_len) * mask
            if emb is not None:
                emb = emb.unsqueeze(0)

            images_before = torch.cat([images_before, img1]) if images_before is not None else img1
            images_after = torch.cat([images_after, img2]) if images_after is not None else img2
            input_ids = torch.cat([input_ids, tokens]) if input_ids is not None else tokens
            labels = torch.cat([labels, lab]) if labels is not None else lab
            pad_mask = torch.cat([pad_mask, mask]) if pad_mask is not None else mask
            flags = torch.cat([flags, flag]) if flags is not None else flag
            if emb is not None:
                embs = torch.cat([embs, emb]) if embs is not None else emb

            raws.append(sample['raws'])

        return {'images_before': images_before, 'images_after': images_after, 'input_ids': input_ids,
                'pad_mask': pad_mask, 'labels': labels, 'flags': flags, 'raws': raws, 'embs': embs}
|
src/Loss.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pytorch_metric_learning.distances import CosineSimilarity
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class InfoNCELoss():
    """InfoNCE contrastive loss with false-negative handling.

    Pairs sharing a label are positives. Candidate negatives whose
    sentence-embedding similarity reaches ``threshold`` are treated as
    false negatives and either removed from the negative set (FNE,
    default) or promoted to positives (FNA). Optionally only the ``k``
    hardest (most similar) negatives are kept.
    """

    def __init__(self, device, k, temperature=0.07, threshold=1.0, fna=False):
        super(InfoNCELoss, self).__init__()
        self.device = device
        self.similarity = CosineSimilarity()
        self.k = k                      # hardest negatives kept per batch (-1 = all)
        self.temperature = temperature
        self.threshold = threshold      # similarity at/above which a pair is a false negative
        self.fna = fna                  # attract false negatives instead of just eliminating them

    def __call__(self, x, y, labels, sts):
        """Return (loss, number of positive pairs) for embeddings x vs y."""
        false_negatives = self.detect_false_negative(sts)
        a1, p, a2, n = self.get_all_pairs_indices(labels, false_negatives)
        indices_tuple = (a1, p, a2, n)

        mat = self.similarity(x, y)
        pos_pair = mat[a1, p] if len(a1) > 0 else []
        neg_pair = mat[a2, n] if len(a2) > 0 else []

        if len(neg_pair) > 0 and self.k > -1:
            # Hard-negative selection: keep only the k most similar negatives.
            ranked = sorted(zip(neg_pair.tolist(), a2.tolist(), n.tolist()),
                            key=lambda t: t[0], reverse=True)[:self.k]
            _, rows, cols = zip(*ranked)
            rows = torch.tensor(rows).to(a2.device)
            cols = torch.tensor(cols).to(n.device)

            neg_pair = mat[rows, cols]
            indices_tuple = (a1, p, rows, cols)

        return self._compute_loss(pos_pair, neg_pair, indices_tuple), len(pos_pair)

    def detect_false_negative(self, embs):
        """(row, col) indices of caption pairs whose sentence-embedding
        inner product reaches the false-negative threshold."""
        sims = torch.matmul(embs, torch.t(embs))
        return torch.where(sims >= self.threshold)

    def _compute_loss(self, pos_pairs, neg_pairs, indices_tuple):
        """Numerically stable InfoNCE over {positive} vs its negatives."""
        a1, _, a2, _ = indices_tuple
        if len(a1) == 0 or len(a2) == 0:
            # Nothing to contrast in this batch.
            return 0

        dtype = neg_pairs.dtype

        # CosineSimilarity is inverted (higher = closer); flip sign otherwise.
        if not self.similarity.is_inverted:
            pos_pairs = -pos_pairs
            neg_pairs = -neg_pairs

        pos_pairs = pos_pairs.unsqueeze(1) / self.temperature
        neg_pairs = neg_pairs / self.temperature
        # For each positive pair's anchor, mask in only its own negatives.
        n_per_p = a2.unsqueeze(0) == a1.unsqueeze(1)
        neg_pairs = neg_pairs * n_per_p
        neg_pairs[n_per_p == 0] = torch.finfo(dtype).min

        # log-sum-exp trick: subtract the per-row max before exponentiating.
        max_val = torch.max(
            pos_pairs, torch.max(neg_pairs, dim=1, keepdim=True)[0]
        ).detach()
        numerator = torch.exp(pos_pairs - max_val).squeeze(1)
        denominator = torch.sum(torch.exp(neg_pairs - max_val), dim=1) + numerator
        log_exp = torch.log((numerator / denominator) + torch.finfo(dtype).tiny)
        return torch.mean(-log_exp)

    def get_all_pairs_indices(self, labels, false_negatives):
        """All (anchor, positive) and (anchor, negative) index pairs."""
        same_label = (labels.unsqueeze(1) == labels.unsqueeze(0)).byte()
        diff_label = same_label ^ 1

        # False negatives never serve as negatives (FNE)...
        diff_label[false_negatives[0], false_negatives[1]] = 0  # FNE
        if self.fna:
            # ...and may additionally be attracted as positives (FNA).
            same_label[false_negatives[0], false_negatives[1]] = 1  # FNA

        diff_label.fill_diagonal_(0)
        same_label.fill_diagonal_(1)

        a1_idx, p_idx = torch.where(same_label)
        a2_idx, n_idx = torch.where(diff_label)
        return a1_idx, p_idx, a2_idx, n_idx
|
src/eval.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
import os
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
import faiss
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torchvision.transforms as T
|
| 13 |
+
|
| 14 |
+
import open_clip
|
| 15 |
+
|
| 16 |
+
from Datasets import CCDataset, Batcher
|
| 17 |
+
from model import ICCModel
|
| 18 |
+
from utils import get_vocabulary, unormalize, get_eval_score
|
| 19 |
+
|
| 20 |
+
AT_K = sorted([1, 3, 5, 10], reverse=True)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def captioning(args, config, model, data_loader, vocab, device):
    """Run caption generation on the full test set and on the changed-only
    subset, writing the metric scores of each to a text file.

    Returns the per-sample results of the full run (for later plotting).
    """
    def _dump(score_dict, filename):
        # One "(metric, value)" tuple per line.
        with open(os.path.join(args.output_path, filename), 'w') as out:
            for t in score_dict.items():
                out.write(str(t) + '\n')

    scores, results = inference(config, model, data_loader, vocab, device, return_results=True)
    _dump(scores, 'caption.txt')

    scores, _ = inference(config, model, data_loader, vocab, device, sub=True, return_results=False)
    _dump(scores, 'caption_sub.txt')

    return results
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def retrieve(args, config, model, src_loader, device):
    """Run text-to-image retrieval on the full test set and on the
    changed-only subset, writing P@k / R@k / MRR@k for every k to files."""
    def _dump(filename, scores_p, scores_r, scores_rr):
        with open(os.path.join(args.output_path, filename), 'w') as out:
            for k in AT_K:
                out.write('P@{0} {1:.4f}\n'.format(k, scores_p[k]))
                out.write('R@{0} {1:.4f}\n'.format(k, scores_r[k]))
                out.write('MRR@{0} {1:.4f}\n'.format(k, scores_rr[k]))
                out.write('\n')

    full_scores = search(config, model, src_loader, device)
    _dump('retrieve.txt', *full_scores)

    sub_scores = search(config, model, src_loader, device, sub=True)
    _dump('retrieve_sub.txt', *sub_scores)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@torch.no_grad()
def search(config, model, src_loader, device, sub=False):
    """Text-to-image retrieval evaluation.

    Phase 1 embeds every (image pair, caption) with the model, L2-normalizes
    both sides, and indexes the visual embeddings in a FAISS inner-product
    index. Phase 2 queries the index with each textual embedding and scores
    Precision/Recall/MRR at the cutoffs in AT_K.

    When ``sub`` is True, no-change samples (flag == -1) are skipped.
    NOTE(review): `flags[i]` and `sims[i]` index per-batch entries — this
    assumes the loader yields batches of size 1 (as built in run()); confirm.
    """
    model.eval()

    visual = None
    textual = None
    flags = []
    embs = None
    index = faiss.IndexFlatIP(config['d_model'])

    batcher = src_loader

    for batch in tqdm(batcher, desc='Indexing'):
        imgs1, imgs2, = batch['images_before'], batch['images_after']
        imgs1 = imgs1.to(device)
        imgs2 = imgs2.to(device)
        flag = batch['flags']
        emb = batch['embs']
        if sub and flag[0] == -1:
            continue

        flags.append(flag)
        embs = torch.cat([embs, emb]) if embs is not None else emb

        vis_emb, _ = model.encoder(imgs1, imgs2)
        visual = torch.cat([visual, vis_emb.cpu()]) if visual is not None else vis_emb.cpu()

        input_ids, mask = batch['input_ids'], batch['pad_mask']
        input_ids = input_ids.to(device)
        mask = mask.to(device)

        _, text_emb, _, _ = model.decoder(input_ids, None, mask, None)
        textual = torch.cat([textual, text_emb.cpu()]) if textual is not None else text_emb.cpu()

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Pairwise sentence-embedding similarities, used to widen the relevant
    # set beyond exact flag matches.
    embs = embs.to(device)
    sims = torch.matmul(embs, torch.t(embs))

    # Normalize so that inner product == cosine similarity.
    visual = F.normalize(visual, p=2, dim=1)
    textual = F.normalize(textual, p=2, dim=1)

    index.add(visual)

    scores_p = {k: [] for k in AT_K}
    scores_r = {k: [] for k in AT_K}
    scores_rr = {k: [] for k in AT_K}

    for i in tqdm(range(textual.shape[0]), desc='Ranking'):
        indices = None
        query = textual[i]
        query_lab = flags[i]

        # Relevant = same change flag OR sentence similarity above threshold.
        relevants = set(
            [x for x in range(len(textual)) if flags[x] == query_lab or sims[i][x] >= config['s-threshold']])

        # AT_K is sorted descending: search once at the largest k, then
        # reuse the ranked list by slicing for the smaller cutoffs.
        for k in AT_K:
            p = 0
            r = 0
            rr = 0

            if indices is None:
                indices = index.search(query.unsqueeze(0), k)[1][0]
            else:
                indices = indices[:k]

            for rank, idx in enumerate(indices):
                if idx in relevants:
                    if p == 0:
                        # Reciprocal rank of the first relevant hit.
                        rr = 1 / (rank + 1)
                    p += 1
                    r += 1

            scores_p[k].append(p / len(indices))
            scores_r[k].append(r / len(relevants))
            scores_rr[k].append(rr)

    # Average each metric over all queries.
    for k in AT_K:
        scores_p[k] = sum(scores_p[k]) / len(scores_p[k])
        scores_r[k] = sum(scores_r[k]) / len(scores_r[k])
        scores_rr[k] = sum(scores_rr[k]) / len(scores_rr[k])

    return scores_p, scores_r, scores_rr
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
def inference(config, model, data_loader, vocab, device, sub=False, return_results=False):
    """Greedy caption generation over ``data_loader`` plus metric scoring.

    When ``sub`` is True, no-change samples (flag == -1) are skipped.
    When ``return_results`` is True, per-sample tensors (images, attention
    weights, visual tokens, sentence) are collected for plotting.

    Returns (score_dict, results).
    """
    results = []
    references = []
    hypotheses = []
    # id -> token, for turning generated ids back into words.
    inverse_vocab = {v: k for k, v in vocab.items()}

    model.eval()
    for batch in tqdm(data_loader, desc='Inference'):
        # Only the first sample of each batch is decoded (loader is built
        # with batch size 1 in run()).
        img1 = batch['images_before'][0].unsqueeze(0).to(device)
        img2 = batch['images_after'][0].unsqueeze(0).to(device)
        raws = batch['raws']
        flags = batch['flags']
        if sub and flags[0] == -1:
            continue

        references.append(raws[0])

        # Greedy decoding: start from START and repeatedly append argmax.
        input_ids = torch.tensor([[vocab['START']]], dtype=torch.long, device=device)
        _, vis_toks = model.encoder(img1, img2)

        for _ in range(config['max_len']):
            _, _, lm_logits, weights = model.decoder(input_ids, None, None, vis_toks)

            next_item = lm_logits[0][-1].topk(1)[1]
            input_ids = torch.cat([input_ids, next_item.reshape(1, -1)], dim=1)
            if next_item.item() == vocab['END']:
                break

        words = [inverse_vocab[x] for x in input_ids[0].cpu().tolist()]
        # Strip START and the final token (END). NOTE(review): if END is
        # never generated within max_len steps, this drops the last real
        # word instead — confirm intended.
        sentence = ' '.join(words[1:-1]).strip()
        hypotheses.append([sentence])

        if return_results:
            results.append(
                (img1.cpu(), img2.cpu(), weights.detach().cpu(), vis_toks.detach().cpu(), sentence))

    score_dict = get_eval_score(references, hypotheses)
    return score_dict, results
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def plot(args, feat_size, results):
    """Save one figure per result: before/after images, the image-feature
    difference map, and the decoder attention overlaid on the after image.

    Args:
        args: needs ``output_path`` for the destination directory.
        feat_size: spatial side of the encoder feature map (weights and
            diff tensors are reshaped to feat_size x feat_size).
        results: tuples (img1, img2, weights, diff, sentence) from inference.
    """
    fig_idx = 0
    for img1, img2, weights, diff, sentence in tqdm(results, desc='Plot'):
        # Undo normalization and move channels last for imshow.
        img1 = unormalize(img1)
        img1 = img1[0].permute(1, 2, 0)  # h,w,c
        img2 = unormalize(img2)
        img2 = img2[0].permute(1, 2, 0)  # h,w,c

        # Upsample attention/feature maps to image resolution.
        transform = T.Resize(size=(img1.size(0), img1.size(1)))
        weights = weights[0].reshape(-1, feat_size, feat_size)
        weights = transform(weights).permute(1, 2, 0)  # h,w,d
        # Average over heads/steps into a single heat map.
        weights = torch.sum(weights, 2) / weights.shape[2]
        after = img2  # h,w,c

        feature_map = diff[:, 0, :].reshape(-1, feat_size, feat_size)  # e,h,w
        feature_map = transform(feature_map).permute(1, 2, 0)  # h,w,c
        feature_map = torch.sum(feature_map, 2) / feature_map.shape[2]  # h, w

        fig, ax = plt.subplots(2, 2, figsize=(6, 8))
        fig.tight_layout()
        ax[0, 0].imshow(img1)
        ax[0, 0].set_title("Before")
        ax[0, 0].axis('off')
        ax[0, 1].imshow(img2)
        ax[0, 1].set_title("After")
        ax[0, 1].axis('off')

        ax[1, 0].set_title("Img diff")
        ax[1, 0].imshow(feature_map)
        ax[1, 0].axis('off')

        # Attention heat map blended over the "after" image.
        ax[1, 1].set_title("Att weights")
        ax[1, 1].imshow(after, interpolation='nearest')
        ax[1, 1].imshow(weights, interpolation='bilinear', alpha=0.5)
        ax[1, 1].axis('off')

        # Generated caption below the panels.
        fig.text(.1, .05, sentence, wrap=True)

        with open(os.path.join(args.output_path, str(fig_idx) + '.png'), 'wb') as f:
            plt.savefig(f)
        plt.close()
        fig_idx += 1
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def run(args, config):
    """Seed everything, build the model and test set, then run the full
    evaluation: captioning metrics, retrieval metrics, and plots."""
    print('Initializing...')
    # Fix all RNGs for reproducible decoding/retrieval.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # Load the vocabulary, building it from the annotations if missing.
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            vocab = json.load(infile)
    else:
        vocab = get_vocabulary(args.annotation_json, args.vocab)

    # CLIP provides the visual backbone and the matching preprocessing.
    clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])

    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])

    model.load_state_dict(torch.load(args.model, map_location=device))
    model = model.to(device)
    # The standalone CLIP reference is no longer needed once wrapped.
    del clip

    print('Loading...')
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
                         config['s-transformers'], device)
    # Batch size 1: captioning/search iterate sample by sample.
    test_loader = Batcher(test_set, 1, config['max_len'], device)

    print('Final evaluation...')
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def main():
    """Command-line entry point: parse arguments, load the JSON config,
    and launch the evaluation run."""
    arg_parser = argparse.ArgumentParser()
    # All string options share the same shape; declare them table-style.
    for flag, default in (
            ('--model', '../input/model_best.pt'),
            ('--annotation_json', '../input/Levir_CC/LevirCCcaptions.json'),
            ('--image_dir', '../input/Levir_CC/images/'),
            ('--vocab', '../input/levir_vocab.json'),
            ('--config', '../config.json'),
            ('--output_path', '../output/'),
    ):
        arg_parser.add_argument(flag, type=str, default=default)
    arg_parser.add_argument('--seed', type=int, default=42)

    cli_args = arg_parser.parse_args()

    with open(cli_args.config, 'r') as config_file:
        config = json.load(config_file)

    run(cli_args, config)
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
if __name__ == '__main__':
|
| 288 |
+
main()
|
src/eval_func/__init__.py
ADDED
|
File without changes
|
src/eval_func/bleu/LICENSE
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
|
| 2 |
+
|
| 3 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 4 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 5 |
+
in the Software without restriction, including without limitation the rights
|
| 6 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 7 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 8 |
+
furnished to do so, subject to the following conditions:
|
| 9 |
+
|
| 10 |
+
The above copyright notice and this permission notice shall be included in
|
| 11 |
+
all copies or substantial portions of the Software.
|
| 12 |
+
|
| 13 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 14 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 15 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 16 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 17 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 18 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
| 19 |
+
THE SOFTWARE.
|
src/eval_func/bleu/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__author__ = 'tylin'
|
src/eval_func/bleu/bleu.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
#
|
| 3 |
+
# File Name : bleu.py
|
| 4 |
+
#
|
| 5 |
+
# Description : Wrapper for BLEU scorer.
|
| 6 |
+
#
|
| 7 |
+
# Creation Date : 06-01-2015
|
| 8 |
+
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
|
| 9 |
+
# Authors : Hao Fang <hfang@uw.edu> and Tsung-Yi Lin <tl483@cornell.edu>
|
| 10 |
+
|
| 11 |
+
from .bleu_scorer import BleuScorer
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Bleu:
    """Wrapper around BleuScorer that computes corpus-level BLEU-1..n.

    `compute_score` expects parallel sequences: res[i] is a one-element
    list holding the hypothesis string, gts[i] a non-empty list of
    reference strings for the same item.
    """

    def __init__(self, n=4):
        # By default BLEU is computed up to 4-grams.
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res):
        """Return (list of corpus BLEU-1..n scores, per-item score lists)."""
        scorer = BleuScorer(n=self._n)
        for idx in range(len(res)):
            hypo = res[idx]
            ref = gts[idx]

            # Sanity check: exactly one hypothesis, at least one reference.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) >= 1

            scorer += (hypo[0], ref)

        # 'closest' picks the reference length nearest the hypothesis length
        # for the brevity penalty (the standard BLEU convention).
        score, scores = scorer.compute_score(option='closest', verbose=1)
        return score, scores

    def method(self):
        return "Bleu"
|
src/eval_func/bleu/bleu_scorer.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# bleu_scorer.py
|
| 4 |
+
# David Chiang <chiang@isi.edu>
|
| 5 |
+
|
| 6 |
+
# Copyright (c) 2004-2006 University of Maryland. All rights
|
| 7 |
+
# reserved. Do not redistribute without permission from the
|
| 8 |
+
# author. Not for commercial use.
|
| 9 |
+
|
| 10 |
+
# Modified by:
|
| 11 |
+
# Hao Fang <hfang@uw.edu>
|
| 12 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
| 13 |
+
|
| 14 |
+
'''Provides:
|
| 15 |
+
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
|
| 16 |
+
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
|
| 17 |
+
'''
|
| 18 |
+
|
| 19 |
+
import copy
|
| 20 |
+
import sys, math, re
|
| 21 |
+
from collections import defaultdict
|
| 22 |
+
|
| 23 |
+
def precook(s, n=4, out=False):
    """Convert a sentence string into (token count, ngram counts).

    Splits *s* on whitespace and counts every ngram of order 1..n.
    Returns a tuple of the number of tokens and a defaultdict mapping
    each ngram (a tuple of words) to its frequency.  The result can be
    passed to cook_refs/cook_test.  *out* is accepted for API
    compatibility and unused.
    """
    tokens = s.split()
    ngram_counts = defaultdict(int)
    for order in range(1, n + 1):
        for start in range(len(tokens) - order + 1):
            ngram_counts[tuple(tokens[start:start + order])] += 1
    return (len(tokens), ngram_counts)
|
| 34 |
+
|
| 35 |
+
def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
    """Encapsulate everything BLEU needs to know about one segment's references.

    :param refs: list of reference sentence strings for a single segment
    :param eff: effective-length policy; "shortest"/"average" collapse the
                length list to a single number, anything else (incl. the
                default None) keeps the per-reference length list so the
                choice can be deferred to scoring time ("closest").
    :param n: maximum ngram order
    :return: (reflen, maxcounts) where maxcounts holds the per-ngram
             maximum count over all references (used for clipping).
    """
    lengths = []
    clipped_counts = {}
    for ref in refs:
        ref_len, counts = precook(ref, n)
        lengths.append(ref_len)
        for ngram, count in counts.items():
            if count > clipped_counts.get(ngram, 0):
                clipped_counts[ngram] = count

    # Collapse to a single effective reference length where requested;
    # for "closest" the full list is kept and resolved per-hypothesis later.
    if eff == "shortest":
        reflen = min(lengths)
    elif eff == "average":
        reflen = float(sum(lengths)) / len(lengths)
    else:
        reflen = lengths

    return (reflen, clipped_counts)
|
| 59 |
+
|
| 60 |
+
def cook_test(test, xxx_todo_changeme, eff=None, n=4):
    """Encapsulate everything BLEU needs to know about a hypothesis.

    :param test: hypothesis sentence string
    :param xxx_todo_changeme: (reflen, refmaxcounts) pair from cook_refs
           (name is a 2to3 tuple-unpacking artifact, kept for compatibility)
    :param eff: "closest" resolves the effective reference length here;
           otherwise the (possibly pre-collapsed) reflen is stored as-is.
    :param n: maximum ngram order
    :return: dict with testlen, reflen, per-order ngram 'guess' totals and
             clipped 'correct' match counts.
    """
    (reflen, refmaxcounts) = xxx_todo_changeme
    testlen, counts = precook(test, n, True)

    result = {}

    # Effective reference length: pick the one closest to the hypothesis.
    if eff == "closest":
        result["reflen"] = min((abs(l - testlen), l) for l in reflen)[1]
    else:  # "average", "shortest", or already collapsed / deferred
        result["reflen"] = reflen

    result["testlen"] = testlen

    # Number of possible ngrams of each order in the hypothesis.
    result["guess"] = [max(0, testlen - k + 1) for k in range(1, n + 1)]

    # Clipped matches: each hypothesis ngram counts at most as often as it
    # appears in the most generous reference.
    correct = [0] * n
    for ngram, count in counts.items():
        correct[len(ngram) - 1] += min(refmaxcounts.get(ngram, 0), count)
    result["correct"] = correct

    return result
|
| 84 |
+
|
| 85 |
+
class BleuScorer(object):
    """Bleu scorer.

    Accumulates (hypothesis, references) pairs via ``scorer += (hypo, refs)``
    and computes corpus-level BLEU-1..n with a brevity penalty via
    compute_score().
    """

    # __slots__ keeps per-instance memory small when many scorers are created.
    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []  # cooked references, one entry per segment
        self.ctest = []  # cooked hypotheses, parallel to crefs
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        # NOTE(review): compute_score never assigns self._ratio, so this
        # method would raise AttributeError if called -- appears unused.
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        # NOTE(review): self.fscore is not defined in this class; calling
        # this would raise AttributeError -- appears unused.
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        # Total effective reference length after scoring.
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        # Total hypothesis length after scoring.
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        """Replace the hypotheses, keeping the references; invalidates cache."""
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        # Two scorers can be merged only if they use the same max ngram order.
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):
        """Collapse a list of reference lengths to one effective length."""

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            # reference length closest to the hypothesis length
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        # Drop the cache and score again (e.g. with a different option).
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        """Compute corpus BLEU-1..n.

        :return: (bleus, bleu_list) where bleus is the list of corpus-level
                 BLEU-k scores (k = 1..n) and bleu_list[k] holds the
                 per-segment BLEU-(k+1) scores.
        """
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            # NOTE(review): cached path returns only the score list, not the
            # (score, bleu_list) pair produced below, and ignores `option`;
            # use recompute_score() to force a fresh (score, list) result.
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            # Accumulate corpus-level ngram totals.
            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score (geometric mean of precisions so far)
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                # brevity penalty for hypotheses shorter than the reference
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            # if verbose > 1:
            #     print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        # Corpus-level BLEU from the pooled counts.
        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        # if verbose > 0:
        #     print(totalcomps)
        #     print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
|
src/eval_func/cider/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__author__ = 'tylin'
|
src/eval_func/cider/cider.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Filename: cider.py
|
| 2 |
+
#
|
| 3 |
+
# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
|
| 4 |
+
# by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
|
| 5 |
+
#
|
| 6 |
+
# Creation Date: Sun Feb 8 14:16:54 2015
|
| 7 |
+
#
|
| 8 |
+
# Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
|
| 9 |
+
|
| 10 |
+
from eval_func.cider.cider_scorer import CiderScorer
|
| 11 |
+
import pdb
|
| 12 |
+
|
| 13 |
+
class Cider:
    """
    Main Class to compute the CIDEr metric

    `compute_score` expects parallel sequences: res[i] is a one-element
    list with the candidate sentence, gts[i] a non-empty list of
    reference sentences for the same item.
    """

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # CIDEr sums over 1..n-grams (n=4 by default).
        self._n = n
        # Standard deviation of the gaussian length penalty.
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param gts: references per item (list of list of str)
        :param res: candidates per item (list of single-element list of str)
        :return: (corpus mean CIDEr score, array of per-item scores)
        """
        scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for idx in range(len(res)):
            hypo = res[idx]
            ref = gts[idx]

            # Sanity check: exactly one candidate, at least one reference.
            assert type(hypo) is list
            assert len(hypo) == 1
            assert type(ref) is list
            assert len(ref) > 0

            scorer += (hypo[0], ref)

        score, scores = scorer.compute_score()
        return score, scores

    def method(self):
        return "CIDEr"
|
src/eval_func/cider/cider_scorer.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# Tsung-Yi Lin <tl483@cornell.edu>
|
| 3 |
+
# Ramakrishna Vedantam <vrama91@vt.edu>
|
| 4 |
+
|
| 5 |
+
import copy
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pdb
|
| 9 |
+
import math
|
| 10 |
+
|
| 11 |
+
def precook(s, n=4, out=False):
    """Return a term-frequency dict over the 1..n-grams of sentence *s*.

    Splits on whitespace; each key is an ngram as a tuple of words, each
    value its count.  Usable as input to cook_refs/cook_test.  *out* is
    accepted for API compatibility and unused.

    :param s: string : sentence to be converted into ngrams
    :param n: int : maximum ngram order
    :return: defaultdict(int) term-frequency vector
    """
    tokens = s.split()
    tf = defaultdict(int)
    for order in range(1, n + 1):
        for start in range(len(tokens) - order + 1):
            tf[tuple(tokens[start:start + order])] += 1
    return tf
|
| 27 |
+
|
| 28 |
+
def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict) : one ngram-count dict per reference
    '''
    return [precook(ref, n) for ref in refs]
|
| 37 |
+
|
| 38 |
+
def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.
    :param test: string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict) : ngram counts of the hypothesis
    '''
    # Thin wrapper: identical to precook with out=True.
    return precook(test, n, True)
|
| 46 |
+
|
| 47 |
+
class CiderScorer(object):
    """CIDEr scorer.

    Accumulates (hypothesis, references) pairs via ``scorer += (hypo, refs)``
    and computes per-image and corpus CIDEr via compute_score().
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n          # maximum ngram order
        self.sigma = sigma  # std-dev of the gaussian length penalty
        self.crefs = []     # cooked references, one entry per image
        self.ctest = []     # cooked hypotheses, parallel to crefs
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None  # log corpus size, set in compute_cider

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self
    def compute_doc_freq(self):
        '''
        Compute term frequency for reference data.
        This will be used to compute idf (inverse document frequency later)
        The term frequency is stored in the object
        :return: None
        '''
        for refs in self.crefs:
            # refs, k ref captions of one image; an ngram counts once per
            # image no matter how many of its references contain it.
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.document_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        """Return the list of per-image CIDEr scores (requires doc freqs)."""
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
            The n-th entry of array denotes length of n-grams.
            :param cnts: ngram-count dict from precook
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector. the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                # NOTE(review): counts length from bigrams (n == 1), i.e.
                # token count - 1 for non-empty sentences.
                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-grams cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure consine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))
        # NOTE(review): local modification vs. standard CIDEr -- with a single
        # image log(1) == 0 would zero out all idf weights, so ref_len is
        # forced to 1 instead.
        if len(self.crefs) == 1:
            self.ref_len = 1
        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        """Return (mean CIDEr over all images, array of per-image scores)."""
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)
|
src/eval_func/rouge/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
__author__ = 'vrama91'
|
src/eval_func/rouge/rouge.py
ADDED
|
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
#
|
| 4 |
+
|
| 5 |
+
# File Name : rouge.py
|
| 6 |
+
|
| 7 |
+
#
|
| 8 |
+
|
| 9 |
+
# Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
|
| 10 |
+
|
| 11 |
+
#
|
| 12 |
+
|
| 13 |
+
# Creation Date : 2015-01-07 06:03
|
| 14 |
+
|
| 15 |
+
# Author : Ramakrishna Vedantam <vrama91@vt.edu>
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
import pdb
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings

    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (int): length of the longest common subsequence

    Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
    """
    # Ensure `sub` is the shorter sequence so the DP rows stay small.
    if len(string) < len(sub):
        sub, string = string, sub

    # Rolling-row dynamic programming over the classic LCS recurrence.
    prev_row = [0] * (len(sub) + 1)
    for token in string:
        curr_row = [0] * (len(sub) + 1)
        for j in range(1, len(sub) + 1):
            if token == sub[j - 1]:
                curr_row[j] = prev_row[j - 1] + 1
            else:
                curr_row[j] = max(prev_row[j], curr_row[j - 1])
        prev_row = curr_row

    return prev_row[len(sub)]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
    '''

    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovey
        # beta > 1 weights recall over precision in the F-measure below.
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image
        :param candidate: list of str : single-element list holding the candidate sentence
        :param refs: list of str : reference sentences for the particular image to be evaluated
        :returns score: float (ROUGE-L score for the candidate evaluated against references)
        """
        assert (len(candidate) == 1)
        assert (len(refs) > 0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            # (removed dead local `hh = 1` left over from debugging)
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs / float(len(token_c)))
            rec.append(lcs / float(len(token_r)))

        # ROUGE-L takes the best precision/recall over all references.
        prec_max = max(prec)
        rec_max = max(rec)

        if (prec_max != 0 and rec_max != 0):
            # F-measure with recall weighted by beta**2.
            score = ((1 + self.beta ** 2) * prec_max * rec_max) / float(rec_max + self.beta ** 2 * prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, references, hypotheses):
        """
        Computes Rouge-L score given a set of reference and candidate sentences for the dataset

        :param references: list of list of str : reference sentences per item
        :param hypotheses: list of single-element list of str : candidate sentence per item
        :returns: (average_score, per-item score array) where average_score is the
                  mean ROUGE-L over all items
        """
        score = []

        for i in range(len(hypotheses)):
            hypo = hypotheses[i]
            ref = references[i]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert (type(hypo) is list)
            assert (len(hypo) == 1)
            assert (type(ref) is list)
            assert (len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"
|
src/model.py
ADDED
|
@@ -0,0 +1,450 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from einops import rearrange
|
| 4 |
+
import math
|
| 5 |
+
from torch import Tensor
|
| 6 |
+
import torchvision.models as models
|
| 7 |
+
|
| 8 |
+
from torch.nn import functional as F
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ICCModel(nn.Module):
    """Image-change-captioning model.

    Combines a bitemporal image encoder (two images -> pooled embedding plus
    spatial tokens) with a transformer caption decoder conditioned on those
    tokens.
    """

    def __init__(self, device, pretrained, backbone, d_model, vocab_size, max_len,
                 num_heads, h_dim, a_dim, encoder_layers, decoder_layers, dropout,
                 learnable=False, fine_tune=True, tie_embeddings=True, prenorm=False):
        super(ICCModel, self).__init__()

        self.feature_dim = d_model

        # CLIP-style checkpoints expose their image tower as `.visual`;
        # when no pretrained model is supplied the encoder builds its own.
        if pretrained:
            visual = pretrained.visual
        else:
            visual = None

        self.encoder = ImagesEncoder(device, visual, backbone, d_model, num_heads,
                                     h_dim, a_dim, dropout, encoder_layers, fine_tune)
        self.decoder = Decoder(device, d_model, vocab_size, max_len, num_heads,
                               decoder_layers, dropout, learnable=learnable,
                               tie_embeddings=tie_embeddings, prenorm=prenorm)

    def forward(self, img1, img2, input_ids, labels, attention_mask):
        """Encode the image pair, then decode the caption against it.

        Returns (caption loss, visual embedding, text embedding, visual tokens,
        LM logits, cross-attention weights).
        """
        vis_emb, vis_toks = self.encoder(img1, img2)
        decoded = self.decoder(input_ids, labels, attention_mask, vis_toks)
        cap_loss, text_emb, lm_logits, weights = decoded
        return cap_loss, vis_emb, text_emb, vis_toks, lm_logits, weights
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ImagesEncoder(nn.Module):
    """Encodes a pair of (before, after) images into one pooled embedding and a
    sequence of fused spatial tokens."""

    def __init__(self, device, pretrained, backbone, d_model, num_heads, h_dim, a_dim, dropout,
                 encoder_layers, fine_tune):
        super(ImagesEncoder, self).__init__()
        # Per-image feature extractor, shared by both time steps.
        self.encoder = Encoder(pretrained, backbone, d_model, fine_tune)
        # Cross-temporal transformer that refines the two feature maps jointly.
        self.encoder_trans = AttentiveEncoder(device, encoder_layers,
                                              [self.encoder.feat_size, self.encoder.feat_size, d_model], num_heads,
                                              hidden_dim=h_dim, attention_dim=a_dim, dropout=dropout)

        # Similarity over dim=1 (channels), yielding one score per spatial cell.
        self.cos = torch.nn.CosineSimilarity(dim=1)
        # 1x1 conv fusing the concatenated pair back to d_model channels.
        self.Conv1 = nn.Conv2d(d_model * 2, d_model, kernel_size=1)
        self.LN = resblock(d_model, d_model)

        # Attention pooling: a single learned query summarises all tokens.
        self.att_pool = nn.MultiheadAttention(d_model, num_heads)
        self.att_pool_norm = nn.LayerNorm(d_model)
        self.img_queries = nn.Parameter(torch.randn(1, d_model))

    def forward(self, img1, img2):
        feat1 = self.encoder(img1)
        feat2 = self.encoder(img2)
        x1, x2 = self.encoder_trans(feat1, feat2)  # batch_size, channel, enc_image_size, enc_image_size

        # Per-cell channel-wise cosine similarity, broadcast back onto the
        # concatenated feature maps as an additive change cue.
        x_sam = self.cos(x1, x2)
        x = torch.cat([x1, x2], dim=1) + x_sam.unsqueeze(1)  # batch_size, 2channel, enc_image_size, enc_image_size
        x = self.LN(self.Conv1(x))
        batch, channel = x.size(0), x.size(1)
        x = x.view(batch, channel, -1).permute(2, 0, 1)  # h*w, batch, dim

        # MultiheadAttention expects (L, N, E); one query pooled over all cells.
        img_queries = self.img_queries.unsqueeze(1).repeat(1, x.shape[1], 1)  # L,N,E
        img_emb = self.att_pool(img_queries, x, x, need_weights=False)[0]
        img_emb = self.att_pool_norm(img_emb)  # 1, batch, d_model

        # Pooled embedding (the single query's output) plus the token sequence.
        cls = img_emb[0]
        return cls, x
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class Encoder(nn.Module):
    """Image feature extractor.

    Wraps either a CLIP-style visual tower handed in via ``pretrained``
    ('rn' / 'b-32' / 'l-14' backbones) or a torchvision ResNet built here,
    and projects the resulting feature map to ``d_model`` channels when the
    native feature width differs.
    """

    def __init__(self, pretrained, backbone, d_model, fine_tune):
        super(Encoder, self).__init__()
        self.backbone = backbone

        if 'rn' in backbone.lower():
            # CLIP ResNet visual tower: drop the final (attention-pool) head.
            modules = list(pretrained.children())[:-1]
            self.net = nn.Sequential(*modules)
            self.feat_dim = 2048
            self.feat_size = 7
        elif 'b-32' in backbone.lower():
            # CLIP ViT-B/32: keep the whole tower, ask it for patch tokens.
            self.net = pretrained
            self.net.output_tokens = True
            self.feat_dim = 768
            self.feat_size = 7
        elif 'l-14' in backbone.lower():
            self.net = pretrained
            self.net.output_tokens = True
            self.feat_dim = 1024
            self.feat_size = 16
        elif backbone == 'resnet50':
            net = models.resnet50(pretrained=True)
            modules = list(net.children())[:-2]  # drop avgpool + fc, keep the spatial map
            self.net = nn.Sequential(*modules)
            self.feat_dim = 2048
            self.feat_size = 8
        elif backbone == 'resnet101':
            net = models.resnet101(pretrained=True)
            modules = list(net.children())[:-2]
            self.net = nn.Sequential(*modules)
            self.feat_dim = 2048
            self.feat_size = 8
        else:
            # BUG FIX: an unrecognised backbone previously fell through
            # silently and crashed later with an AttributeError on
            # self.feat_dim; fail fast with a clear message instead.
            raise ValueError('Unsupported backbone: {!r}'.format(backbone))

        # Optional 1x1 projection to the model width.
        self.proj = None
        if self.feat_dim != d_model:
            self.proj = nn.Conv2d(self.feat_dim, d_model, kernel_size=1)

        self.fine_tune(fine_tune)

    def forward(self, image):
        """Return a (batch, d_model, feat_size, feat_size) feature map."""
        feat = self.net(image)  # batch, feat_dim, feat_size, feat_size
        if 'vit' in self.backbone.lower():
            # ViT returns (pooled, tokens); fold the patch tokens back into
            # a square spatial grid laid out channels-first.
            feat = feat[1].reshape(-1, self.feat_size, self.feat_size, self.feat_dim).permute(0, 3, 1, 2)

        if self.proj:
            feat = self.proj(feat)

        return feat

    def fine_tune(self, fine_tune=True):
        """Freeze the whole backbone, then optionally unfreeze its top blocks."""
        for p in self.net.parameters():
            p.requires_grad = False

        if 'resnet' in self.backbone:
            to_finetune = list(self.net.children())[-5:]
        elif 'vit' in self.backbone.lower():
            to_finetune = list(self.net.children())[-2:]  # only transformer layers
        else:
            to_finetune = list(self.net.children())[-3:]  # only fine-tune convolutional blocks 2 through 4

        for c in to_finetune:
            for p in c.parameters():
                p.requires_grad = fine_tune
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class FeedForward(nn.Module):
    """Position-wise two-layer MLP (Linear -> ReLU -> Dropout -> Linear -> Dropout)
    mapping dim -> hidden_dim -> dim."""

    def __init__(self, dim, hidden_dim, dropout=0.):
        super(FeedForward, self).__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP; output has the same trailing dimension as the input."""
        out = self.net(x)
        return out
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class MultiHeadAtt(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries come from ``x1`` (width ``dim_q``), keys/values from ``x2``/``x3``
    (width ``dim_kv``); the result is projected back to ``dim_q`` unless the
    configuration makes the projection a no-op.
    """

    def __init__(self, dim_q, dim_kv, attention_dim, heads=8, dropout=0.):
        super(MultiHeadAtt, self).__init__()
        # Skip the output projection only when it would be the identity shape-wise.
        project_out = not (heads == 1 and attention_dim == dim_kv)
        self.heads = heads
        self.scale = (attention_dim // self.heads) ** -0.5

        self.to_q = nn.Linear(dim_q, attention_dim, bias=False)
        self.to_k = nn.Linear(dim_kv, attention_dim, bias=False)
        self.to_v = nn.Linear(dim_kv, attention_dim, bias=False)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_out = nn.Sequential(
            nn.Linear(attention_dim, dim_q),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x1, x2, x3):
        """Attend x1 (b, n_q, dim_q) over x2/x3 (b, n_kv, dim_kv) -> (b, n_q, dim_q)."""
        q = self.to_q(x1)
        k = self.to_k(x2)
        # BUG FIX: the original projected values with self.to_k, leaving
        # self.to_v defined but unused.
        v = self.to_v(x3)

        # Split heads: b n (h d) -> b h n d (equivalent to einops.rearrange).
        b, n_q = q.shape[0], q.shape[1]
        n_kv = k.shape[1]
        h = self.heads
        q = q.view(b, n_q, h, -1).permute(0, 2, 1, 3)
        k = k.view(b, n_kv, h, -1).permute(0, 2, 1, 3)
        v = v.view(b, n_kv, h, -1).permute(0, 2, 1, 3)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.dropout(self.attend(dots))
        out = torch.matmul(attn, v)
        # Merge heads: b h n d -> b n (h d).
        out = out.permute(0, 2, 1, 3).reshape(b, n_q, -1)
        return self.to_out(out)  # (b,n,dim)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
class Transformer(nn.Module):
    """One attention + feed-forward block with residual connections, in either
    pre-norm (norm_first=True) or post-norm layout."""

    def __init__(self, dim_q, dim_kv, heads, attention_dim, hidden_dim, dropout=0., norm_first=False):
        super(Transformer, self).__init__()
        self.norm_first = norm_first
        self.att = MultiHeadAtt(dim_q, dim_kv, attention_dim, heads=heads, dropout=dropout)
        self.feedforward = FeedForward(dim_q, hidden_dim, dropout=dropout)
        self.norm1 = nn.LayerNorm(dim_q)
        self.norm2 = nn.LayerNorm(dim_q)

    def forward(self, x1, x2, x3):
        """Attend x1 over x2/x3, then apply the position-wise feed-forward."""
        if not self.norm_first:
            # Post-norm: residual add first, normalise afterwards.
            attended = self.norm1(self.att(x1, x2, x3) + x1)
            return self.norm2(self.feedforward(attended) + attended)

        # Pre-norm: normalise the inputs of each sub-layer.
        attended = self.att(self.norm1(x1), self.norm1(x2), self.norm1(x3)) + x1
        return self.feedforward(self.norm2(attended)) + attended
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class AttentiveEncoder(nn.Module):
    """Stack of paired transformers that refine two feature maps jointly:
    per-image self-attention followed by attention over the concatenated pair."""

    def __init__(self, device, n_layers, feature_size, heads, hidden_dim=512, attention_dim=512, dropout=0.):
        super(AttentiveEncoder, self).__init__()
        # feature_size = [height, width, channels] of each feature map.
        h_feat, w_feat, channels = feature_size

        self.device = device
        # Separate learned embeddings for rows and columns; concatenated they
        # form a full `channels`-dim positional code per cell.
        self.h_embedding = nn.Embedding(h_feat, int(channels / 2))
        self.w_embedding = nn.Embedding(w_feat, int(channels / 2))
        self.selftrans = nn.ModuleList([])
        for i in range(n_layers):
            self.selftrans.append(nn.ModuleList([
                # Per-image self-attention (width `channels`).
                Transformer(channels, channels, heads, attention_dim, hidden_dim, dropout, norm_first=False),
                # Joint attention over the concatenated pair (width 2*channels).
                Transformer(channels * 2, channels * 2, heads, attention_dim, hidden_dim, dropout, norm_first=False),
            ]))

        self._reset_parameters()

    def _reset_parameters(self):
        # Xavier init for all weight matrices (dim > 1 skips biases/embedding rows of dim 1).
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, img1, img2):
        batch, c, h, w = img1.shape
        pos_h = torch.arange(h).to(self.device)
        pos_w = torch.arange(w).to(self.device)
        # NOTE(review): h positions are looked up in w_embedding and vice
        # versa — harmless when h == w (square maps), but looks swapped; confirm.
        embed_h = self.w_embedding(pos_h)
        embed_w = self.h_embedding(pos_w)
        # Build a (h, w, channels) grid of positional codes, then lay it out
        # channels-first and broadcast over the batch.
        pos_embedding = torch.cat([embed_w.unsqueeze(0).repeat(h, 1, 1),
                                   embed_h.unsqueeze(1).repeat(1, w, 1)],
                                  dim=-1)
        pos_embedding = pos_embedding.permute(2, 0, 1).unsqueeze(0).repeat(batch, 1, 1, 1)
        img1 = img1 + pos_embedding
        img2 = img2 + pos_embedding
        # Flatten spatial cells into a token sequence: (batch, h*w, c).
        img1 = img1.view(batch, c, -1).transpose(-1, 1)  # batch, hw, c
        img2 = img2.view(batch, c, -1).transpose(-1, 1)
        img_sa1, img_sa2 = img1, img2

        for (l, m) in self.selftrans:
            # Per-image self-attention with an extra residual.
            img_sa1 = l(img_sa1, img_sa1, img_sa1) + img_sa1
            img_sa2 = l(img_sa2, img_sa2, img_sa2) + img_sa2
            # Joint pass over the concatenated pair, then split back and add
            # residuals from the ORIGINAL (pre-stack) inputs.
            img = torch.cat([img_sa1, img_sa2], dim=-1)
            img = m(img, img, img)
            img_sa1 = img[:, :, :c] + img1
            img_sa2 = img[:, :, c:] + img2

        # Restore (batch, c, h, w) layout.
        img1 = img_sa1.reshape(batch, h, w, c).transpose(-1, 1)
        img2 = img_sa2.reshape(batch, h, w, c).transpose(-1, 1)

        return img1, img2
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class resblock(nn.Module):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convs (each BatchNorm'd)
    with an identity skip connection and a final ReLU.

    ``inchannel`` must equal ``outchannel`` for the identity skip to be valid.
    """

    def __init__(self, inchannel, outchannel, stride=1, shortcut=None):
        super(resblock, self).__init__()
        mid = int(outchannel / 2)  # bottleneck width
        self.left = nn.Sequential(
            nn.Conv2d(inchannel, mid, kernel_size=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, mid, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(mid),
            nn.ReLU(),
            nn.Conv2d(mid, outchannel, kernel_size=1),
            nn.BatchNorm2d(outchannel)
        )
        # Optional projection for the skip path; kept for interface
        # compatibility but not applied in forward().
        self.right = shortcut

    def forward(self, x):
        """Main path plus identity skip, then ReLU."""
        return F.relu(self.left(x) + x)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Decoder(nn.Module):
    """Transformer caption decoder: a text-only self-attention stack followed
    by a cross-attention stack over the visual tokens, ending in a (optionally
    weight-tied) language-model head."""

    def __init__(self, device, h_dim, vocab_size, max_len, n_head, n_layers, dropout=0.10,
                 learnable=False, tie_embeddings=True, prenorm=False):

        super(Decoder, self).__init__()

        self.embed_dim = h_dim
        self.vocab_size = vocab_size
        self.dropout = dropout
        self.device = device

        self.tokens_embed = nn.Embedding(vocab_size, self.embed_dim)
        self.position_encoding = PositionalEncoding(self.embed_dim, dropout=dropout, max_len=max_len,
                                                    device=device, learnable=learnable)

        # Self-attention-only layers (no image conditioning).
        self.uni_decoder = nn.ModuleList(
            [DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
                          crossattention=False) for _ in range(n_layers)])

        # Layers that additionally cross-attend to the image tokens.
        self.cross_decoder = nn.ModuleList(
            [DecoderLayer(h_dim, h_dim, n_head, dim_feedforward=h_dim * 4, dropout=self.dropout, prenorm=prenorm,
                          crossattention=True) for _ in range(n_layers)])

        self.lm_head = nn.Linear(h_dim, vocab_size, bias=False)
        if tie_embeddings:
            # Weight tying: embedding matrix and output projection share one tensor.
            self.tokens_embed.weight = self.lm_head.weight
        # NOTE(review): self.dropout is rebound here from a float (the rate)
        # to an nn.Dropout module; later reads get the module.
        self.dropout = nn.Dropout(p=self.dropout)
        # NOTE(review): init_weights runs after tying; when tied, both
        # uniform_ calls re-initialise the same shared tensor.
        self.init_weights()
        self.loss_fn = nn.CrossEntropyLoss()

    def init_weights(self):
        # Uniform init for embedding and LM head (same tensor when tied).
        self.tokens_embed.weight.data.uniform_(-0.1, 0.1)
        self.lm_head.weight.data.uniform_(-0.1, 0.1)

    def forward(self, input_ids=None, labels=None, pad_mask=None, img_emb=None):
        """Run the decoder.

        Returns (caption loss or None, pooled text embedding or None,
        LM logits or None, last cross-attention weights or None).
        """
        att_weights = None
        # Causal mask: True above the diagonal (positions a token may NOT attend to).
        mask = torch.tril(torch.ones(input_ids.shape[1], input_ids.shape[1]))
        mask = ~mask.bool()
        mask = mask.to(self.device)

        inputs_embeds = self.tokens_embed(input_ids)
        inputs_embeds = self.position_encoding(inputs_embeds)  # batch, seq, e_dim
        inputs_embeds = inputs_embeds.permute(1, 0, 2)  # seq, batch, e_dim

        # seq, batch, emb_dim
        out = inputs_embeds
        for block in self.uni_decoder:
            out, _ = block(out, None, tgt_mask=mask, tgt_key_padding_mask=pad_mask)

        if pad_mask is not None:  # not inference
            # Pooled text embedding: hidden state of the last non-padded token.
            # NOTE(review): assumes pad_mask is nonzero/True exactly at padding
            # positions — confirm against the Batcher.
            cls = []
            for i in range(pad_mask.shape[0]):
                end = pad_mask[i].shape[0] - pad_mask[i].count_nonzero()
                cls.append(out[end - 1, i, :])

            cls = torch.stack(cls)  # batch, emb_dim
        else:
            cls = None

        # Text-only mode (e.g. encoding a caption for retrieval): stop before
        # the cross-attention stack.
        if img_emb is None:
            return None, cls, None, None

        for block in self.cross_decoder:
            out, att_weights = block(out, img_emb, tgt_mask=mask, tgt_key_padding_mask=pad_mask)

        lm_logits = self.lm_head(self.dropout(out))  # seq, batch, voc_dim
        lm_logits = lm_logits.permute(1, 0, 2)  # batch, seq, voc_dim

        if labels is not None:  # not inference
            # Standard LM shift: predict token t+1 from position t.
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, shift_logits.size(-1))
            shift_labels = labels[..., 1:].contiguous()
            shift_labels = shift_labels.view(-1)
            loss = self.loss_fn(shift_logits, shift_labels)
        else:
            loss = None

        return loss, cls, lm_logits, att_weights
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class PositionalEncoding(nn.Module):
    """Adds positional information to (batch, seq, d_model) embeddings, either
    from a fixed sinusoidal table or a learned embedding, then applies dropout."""

    def __init__(self, d_model, dropout, max_len, device, learnable=False):
        super(PositionalEncoding, self).__init__()
        self.learnable = learnable
        self.max_len = max_len
        self.device = device
        self.dropout = nn.Dropout(p=dropout)

        if learnable:
            self.pos_emb = nn.Embedding(max_len, int(d_model))
        else:
            # Classic "Attention Is All You Need" sinusoidal table:
            # sin on even channels, cos on odd channels.
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
            table = torch.zeros(max_len, d_model)
            table[:, 0::2] = torch.sin(position * div_term)
            table[:, 1::2] = torch.cos(position * div_term)
            # Stored as a buffer (moves with the module, not a parameter).
            self.register_buffer('pe', table.unsqueeze(0))

    def forward(self, x):
        if not self.learnable:
            x = x + self.pe[:, :x.size(1), :]
        else:
            position_ids = torch.arange(x.size(1), dtype=torch.long).to(self.device)
            position_ids = position_ids.unsqueeze(0).view(-1, x.size(1))  # batch, seq
            x = x + self.pos_emb(position_ids)
        return self.dropout(x)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
class DecoderLayer(nn.Module):
    """One decoder layer: masked self-attention, optional cross-attention over
    image memory, and a GELU feed-forward block, in pre- or post-norm layout.
    Inputs/outputs use the (L, N, E) sequence-first convention."""

    def __init__(self, d_model, img_dim, nhead, dim_feedforward=2048, dropout=0.1, layer_norm_eps=1e-5,
                 prenorm=False, crossattention=False):
        super(DecoderLayer, self).__init__()

        self.prenorm = prenorm
        self.crossattention = crossattention
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        if crossattention:
            # kdim/vdim let the image memory have a different width than d_model.
            self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, kdim=img_dim, vdim=img_dim)
            self.mha_dropout = nn.Dropout(dropout)
            self.mha_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.ff_linear1 = nn.Linear(d_model, dim_feedforward)
        self.ff_dropout = nn.Dropout(dropout)
        self.ff_linear2 = nn.Linear(dim_feedforward, d_model)

        self.sa_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.ff_norm = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self.sa_dropout = nn.Dropout(dropout)
        # NOTE(review): ff_dropout is assigned twice (also above); the second
        # assignment simply replaces the first — harmless but redundant.
        self.ff_dropout = nn.Dropout(dropout)

        self.activation = nn.GELU()

    def forward(self, tgt: Tensor, memory: Tensor, tgt_mask=None,
                memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Returns (output, cross-attention weights or None)."""
        att_weight = None
        x = tgt

        if self.prenorm:
            # Pre-norm: normalise before each sub-layer, plain residual adds.
            x = x + self._sa_block(self.sa_norm(x), tgt_mask, tgt_key_padding_mask)
            if self.crossattention:
                enc_att, att_weight = self._mha_block(self.mha_norm(x), memory, memory_mask, memory_key_padding_mask)
                x = x + enc_att
            x = x + self._ff_block(self.ff_norm(x))
        else:
            # Post-norm: residual add first, then LayerNorm.
            x = self.sa_norm(x + self._sa_block(x, tgt_mask, tgt_key_padding_mask))
            if self.crossattention:
                enc_att, att_weight = self._mha_block(x, memory, memory_mask, memory_key_padding_mask)
                x = self.mha_norm(x + enc_att)

            x = self.ff_norm(x + self._ff_block(x))
        return x, att_weight

    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Masked (causal) self-attention; weights are not needed downstream.
        x = self.self_attn(x, x, x,  # L,N,E
                           attn_mask=attn_mask,  # L, S
                           key_padding_mask=key_padding_mask,  # N, S
                           is_causal=True,
                           need_weights=False)[0]
        return self.sa_dropout(x)

    def _mha_block(self, x, mem, attn_mask, key_padding_mask):
        # Cross-attention over the image memory; weights are returned so the
        # caller can visualise where the caption attends.
        x, att_weight = self.cross_attn(x, mem, mem,
                                        attn_mask=attn_mask,
                                        key_padding_mask=key_padding_mask,
                                        is_causal=False,
                                        need_weights=True)
        return self.mha_dropout(x), att_weight

    def _ff_block(self, x):
        # Linear -> GELU -> Dropout -> Linear -> Dropout.
        x = self.ff_linear2(self.ff_dropout(self.activation(self.ff_linear1(x))))
        return self.ff_dropout(x)
|
src/train.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import random
|
| 3 |
+
import os
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import json
|
| 9 |
+
from torch.optim import AdamW
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from transformers import get_constant_schedule_with_warmup
|
| 14 |
+
|
| 15 |
+
from Datasets import CCDataset, Batcher
|
| 16 |
+
from model import ICCModel
|
| 17 |
+
from utils import get_vocabulary
|
| 18 |
+
from Loss import InfoNCELoss
|
| 19 |
+
from eval import captioning, retrieve, plot
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
import open_clip
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer):
    """Train with a combined captioning + InfoNCE loss, validate each epoch,
    checkpoint per epoch, and copy the best (lowest validation loss)
    checkpoint to model_best.pt."""
    step = 0
    best_score = float("inf")
    best_model = None  # step number of the best checkpoint so far

    for epoch in range(args.epochs):
        model.train()

        for batch in tqdm(train_loader, desc='Epoch ' + str(epoch)):
            imgs1 = batch['images_before'].to(device)
            imgs2 = batch['images_after'].to(device)
            toks = batch['input_ids'].to(device)
            labs = batch['labels'].to(device)
            flags = batch['flags'].to(device)
            attention_mask = batch['pad_mask'].to(device)
            embs = batch['embs'].to(device)

            cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)

            # Total loss: captioning + lambda-weighted contrastive term.
            con_loss, num_pos = infonce(vis_emb, text_emb, flags, embs)
            loss = cap_loss + args.lamb * con_loss
            loss.backward()

            if args.max_grad_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            # Global gradient norm, logged for monitoring.
            # NOTE(review): if this line were indented under the `if` above,
            # `grad` would be undefined below when max_grad_norm is None
            # (the default) — confirm intended placement.
            grad = torch.norm(torch.stack(
                [torch.norm(p.grad.detach()).to(device) for p in model.parameters() if p.grad is not None]))

            optim.step()
            scheduler.step()
            optim.zero_grad()

            writer.add_scalar('train_loss', loss.item(), step)
            writer.add_scalar('grad', grad, step)
            writer.add_scalar('lr', scheduler.get_last_lr()[0], step)

            step += 1

        # One checkpoint per epoch, named by the global step count.
        torch.save(model.state_dict(), args.output_path + 'model_{}.pt'.format(step))

        model.eval()
        with torch.no_grad():
            eval_losses = torch.empty(0)
            for batch in tqdm(valid_loader, desc='Validation ' + str(epoch)):
                imgs1 = batch['images_before'].to(device)
                imgs2 = batch['images_after'].to(device)
                toks = batch['input_ids'].to(device)
                labs = batch['labels'].to(device)
                flags = batch['flags'].to(device)
                attention_mask = batch['pad_mask'].to(device)
                embs = batch['embs'].to(device)

                cap_loss, vis_emb, text_emb, _, _, _ = model(imgs1, imgs2, toks, labs, attention_mask)

                con_loss, _ = infonce(vis_emb, text_emb, flags, embs)
                loss = cap_loss + args.lamb * con_loss
                eval_losses = torch.cat([eval_losses, loss.cpu().unsqueeze(0)])

        # Lower mean validation loss is better.
        eval_score = torch.mean(eval_losses)
        writer.add_scalar('eval_score', eval_score, step)

        is_best = eval_score < best_score
        best_score = min(eval_score, best_score)
        if is_best:
            best_model = step

    # Re-save the best epoch checkpoint under a fixed name.
    if best_model is not None:
        state_dict = torch.load(os.path.join(args.output_path + 'model_{}.pt'.format(best_model)), map_location=device)
        torch.save(state_dict, args.output_path + 'model_best.pt')
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def run(args, config):
    """Set up seeds, data, model and optimisation, train, then run the final
    captioning/retrieval evaluation with the best checkpoint."""
    print('Initializing...')
    # Seed everything for reproducibility.
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    random.seed(args.seed)
    torch.backends.cudnn.deterministic = True

    device = torch.device('cpu')
    if torch.cuda.is_available():
        device = torch.device('cuda')

    # TensorBoard run directory named by timestamp.
    dt_str = datetime.now().strftime("%d-%m-%Y-%H-%M-%S")
    writer_path = args.output_path + dt_str
    writer = SummaryWriter(writer_path)

    # Reuse a cached vocabulary if present, otherwise build (and save) one.
    if os.path.exists(args.vocab):
        with open(args.vocab, 'r') as infile:
            vocab = json.load(infile)
    else:
        vocab = get_vocabulary(args.annotation_json, args.vocab)

    # Default (non-CLIP) preprocessing: ImageNet normalisation.
    clip = None
    preprocess = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Non-ResNet backbones come from RemoteCLIP checkpoints on the HF hub;
    # they bring their own preprocessing transform.
    if 'resnet' not in config['backbone']:
        checkpoint_path = hf_hub_download("chendelong/RemoteCLIP",
                                          f"RemoteCLIP-{config['backbone']}.pt",
                                          cache_dir=args.pretrained)

        clip, _, preprocess = open_clip.create_model_and_transforms(config['backbone'])
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        clip.load_state_dict(ckpt)

    model = ICCModel(device, clip, config['backbone'], config['d_model'],
                     len(vocab), config['max_len'], config['num_heads'], config['h_dim'], config['a_dim'],
                     config['encoder_layers'], config['decoder_layers'], config['dropout'],
                     learnable=config['learnable'], fine_tune=config['fine_tune'],
                     tie_embeddings=config['tie_embeddings'], prenorm=config['prenorm'])
    model = model.to(device)
    # The model keeps its own reference to the visual tower; drop the full CLIP.
    del clip

    print('Loading...')
    training_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'train', config['max_len'],
                             config['s-transformers'], device)
    valid_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'val', config['max_len'],
                          config['s-transformers'], device)
    test_set = CCDataset(args.annotation_json, args.image_dir, vocab, preprocess, 'test', config['max_len'],
                         config['s-transformers'], device)

    # Test loader uses batch size 1 for per-sample evaluation.
    train_loader = Batcher(training_set, args.batch_size, config['max_len'], device, args.hd, model=model, shuffle=True)
    valid_loader = Batcher(valid_set, args.batch_size, config['max_len'], device)
    test_loader = Batcher(test_set, 1, config['max_len'], device)

    print('Training...')
    infonce = InfoNCELoss(device, k=args.k, temperature=args.temperature, threshold=config['s-threshold'],
                          fna=config['fna'])
    optim = AdamW([x for x in model.parameters() if x.requires_grad], lr=args.learning_rate, eps=args.adam_epsilon)
    # warmup_steps is a fraction of total steps, converted here.
    scheduler = get_constant_schedule_with_warmup(optim,
                                                  num_warmup_steps=args.warmup_steps * len(train_loader) * args.epochs)
    train(args, model, train_loader, valid_loader, device, infonce, optim, scheduler, writer)

    print('Final evaluation...')
    # Evaluate with the best validation checkpoint saved by train().
    model.load_state_dict(torch.load(os.path.join(args.output_path, 'model_best.pt'), map_location=device))
    results = captioning(args, config, model, test_loader, vocab, device)
    retrieve(args, config, model, test_loader, device)
    plot(args, model.encoder.encoder.feat_size, results)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def main():
    """Entry point: parse CLI arguments, load the JSON config, and launch a run."""
    parser = argparse.ArgumentParser()

    # Input/output paths.
    path_options = [
        ('--annotation_json', str, '../input/Levir_CC/LevirCCcaptions.json'),
        ('--image_dir', str, '../input/Levir_CC/images/'),
        ('--vocab', str, '../input/levir_vocab.json'),
        ('--pretrained', str, '../../input/checkpoints'),
        ('--config', str, '../config.json'),
        ('--output_path', str, '../output/'),
    ]
    # Training hyper-parameters.
    hyper_options = [
        ('--epochs', int, 50),
        ('--batch_size', int, 4),
        ('--k', int, -1),
        ('--hd', int, -1),
        ('--learning_rate', float, 1e-4),
        ('--warmup_steps', float, 0.025),
        ('--lr_decay', float, 0.7),
        ('--adam_epsilon', float, 1e-8),
        ('--max_grad_norm', float, None),
        ('--temperature', float, 0.01),
        ('--lamb', float, 0.5),
        ('--seed', int, 42),
    ]
    for flag, flag_type, default in path_options + hyper_options:
        parser.add_argument(flag, type=flag_type, default=default)

    args = parser.parse_args()

    # Model/dataset settings live in a separate JSON config file.
    with open(args.config, 'r') as config_file:
        config = json.load(config_file)

    run(args, config)


if __name__ == '__main__':
    main()
|
src/utils.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import torch
|
| 6 |
+
import glob
|
| 7 |
+
|
| 8 |
+
from eval_func.bleu.bleu import Bleu
|
| 9 |
+
from eval_func.rouge.rouge import Rouge
|
| 10 |
+
from eval_func.cider.cider import Cider
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_eval_score(ref, hypo):
    """Compute captioning metrics (BLEU-1..4, ROUGE-L, CIDEr) for hypotheses
    against references.

    Args:
        ref: dict mapping example id -> list of reference captions, in the
            format expected by the ``eval_func`` scorers.
        hypo: dict mapping example id -> list containing the generated caption.

    Returns:
        dict mapping metric name (e.g. ``"Bleu_4"``, ``"CIDEr"``) to its
        corpus-level score.
    """
    scorers = [
        (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
        (Rouge(), "ROUGE_L"),
        (Cider(), "CIDEr"),
    ]

    score_dict = {}
    for scorer, names in scorers:
        # compute_score returns (corpus_score, per_example_scores); only the
        # corpus-level value is kept. Bleu returns a list of four scores, the
        # other scorers a single float.
        corpus_score, _ = scorer.compute_score(ref, hypo)
        if isinstance(names, list):
            score_dict.update(zip(names, corpus_score))
        else:
            score_dict[names] = corpus_score
    return score_dict
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_vocabulary(in_path, out_file):
    """Build (and persist to *out_file*) the vocabulary for the dataset
    identified by *in_path*.

    The dataset type is inferred from the path name ('levir', 'dubai' or
    'clevr', case-insensitive).

    Returns:
        The vocabulary dict produced by the dataset-specific builder.

    Raises:
        ValueError: if the path matches no known dataset. (Previously this
        fell through and returned ``None``, deferring the failure to the
        caller.)
    """
    lowered = in_path.lower()
    if 'levir' in lowered:
        return get_levir_vocabulary(in_path, out_file)
    if 'dubai' in lowered:
        return get_dubai_vocabulary(in_path, out_file)
    if 'clevr' in lowered:
        return get_clevr_vocabulary(in_path, out_file)
    raise ValueError(f'Cannot infer dataset from path: {in_path!r}')
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_levir_vocabulary(in_path, out_file):
    """Build the Levir-CC vocabulary from the annotation file at *in_path*.

    Only tokens occurring strictly more than 5 times are kept. Ids 0-3 are
    reserved for the special tokens PAD/START/UNK/END; real words start at 4,
    ordered by descending frequency.

    Writes the vocabulary to *out_file* as JSON and returns it as a dict.
    """
    with open(in_path) as fin:
        images = json.load(fin)['images']

    all_tokens = [token
                  for image in images
                  for sentence in image['sentences']
                  for token in sentence['tokens']]
    counts = pd.Series(all_tokens).value_counts()
    frequent = counts[counts > 5]

    vocab = {word: index + 4 for index, word in enumerate(frequent.index)}
    vocab.update({'PAD': 0, 'START': 1, 'UNK': 2, 'END': 3})

    with open(out_file, 'w') as fout:
        json.dump(vocab, fout)

    return vocab
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def get_dubai_vocabulary(in_path, out_file):
    """Build the Dubai-CC vocabulary from every annotation JSON under *in_path*.

    Unlike the Levir variant, no frequency threshold is applied: every token
    seen in the captions gets an id. Ids 0-3 are reserved for the special
    tokens PAD/START/UNK/END; real words start at 4, ordered by descending
    frequency.

    Writes the vocabulary to *out_file* as JSON and returns it as a dict.
    """
    images = []
    for annotation_file in glob.glob(in_path + '/*.json'):
        with open(annotation_file) as fin:
            images.extend(json.load(fin)['images'])

    all_tokens = [token
                  for image in images
                  for sentence in image['sentences']
                  for token in sentence['tokens']]
    by_frequency = pd.Series(all_tokens).value_counts()

    vocab = {word: index + 4 for index, word in enumerate(by_frequency.index)}
    vocab.update({'PAD': 0, 'START': 1, 'UNK': 2, 'END': 3})

    with open(out_file, 'w') as fout:
        json.dump(vocab, fout)

    return vocab
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_clevr_vocabulary(in_path, out_file):
    """Build the CLEVR-Change vocabulary from both caption files in *in_path*.

    Reads ``change_captions.json`` and ``no_change_captions.json`` (each a
    mapping from image id to a list of captions), tokenises captions on
    single spaces, and assigns every token an id (no frequency cut-off).
    Ids 0-3 are reserved for the special tokens PAD/START/UNK/END; real words
    start at 4, ordered by descending frequency.

    Writes the vocabulary to *out_file* as JSON and returns it as a dict.
    """
    captions = []
    for filename in ('change_captions.json', 'no_change_captions.json'):
        with open(os.path.join(in_path, filename), 'r', encoding='utf-8') as fin:
            per_image = json.load(fin)
            for image_id in per_image:
                captions.extend(per_image[image_id])

    all_tokens = [token for caption in captions for token in caption.split(' ')]
    by_frequency = pd.Series(all_tokens).value_counts()

    vocab = {word: index + 4 for index, word in enumerate(by_frequency.index)}
    vocab.update({'PAD': 0, 'START': 1, 'UNK': 2, 'END': 3})

    with open(out_file, 'w') as fout:
        json.dump(vocab, fout)

    return vocab
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def unormalize(tensor, mean=None, std=None):
    """Undo input normalisation so a tensor can be rendered as an image.

    If *mean* and *std* are given, the per-channel normalisation
    ``(x - mean) / std`` is reversed **in place** and the result is clipped
    to [0, 1]. NOTE(review): the zip iterates over dim 0 of *tensor*, so this
    branch assumes a single (C, H, W) image rather than a batch -- confirm
    against call sites.

    Otherwise *tensor* is treated as a batch (B, C, H, W) and each sample is
    min-max rescaled to [0, 1] (also in place, through a flattened view).

    Args:
        tensor: image tensor to de-normalise. Mutated in place in both
            branches.
        mean, std: optional per-channel statistics used during normalisation.

    Returns:
        The rescaled tensor, with values in [0, 1].
    """
    if mean is not None and std is not None:
        for t, m, s in zip(tensor, mean, std):
            t.mul_(s).add_(m)
        return torch.clip(tensor, min=0, max=1)

    b, c, h, w = tensor.shape
    tensor = tensor.view(b, -1)
    tensor -= tensor.min(1, keepdim=True)[0]
    # Guard against constant inputs: after subtracting the min, the max of a
    # flat image is 0 and the original 0/0 produced NaNs. Clamping the
    # denominator maps a flat image to all zeros instead.
    denom = tensor.max(1, keepdim=True)[0].clamp(min=1e-12)
    tensor /= denom
    return tensor.view(b, c, h, w)
|