import os
import json
from random import choice

import faiss
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sentence_transformers import SentenceTransformer
from torch.utils.data import Dataset
from tqdm import tqdm


class CCDataset(Dataset):
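    """Change-captioning dataset over 'before'/'after' image pairs.

    Reads a Karpathy-style JSON split file whose entries carry 'sentences',
    'filepath', 'filename' and 'changeflag' fields; images are expected under
    root_dir/<filepath>/A and /B for the before and after views. `vocab` is
    assumed to map words to ids and to contain the special tokens 'START',
    'END', 'PAD' and 'UNK'. Sentence embeddings are precomputed with the
    SentenceTransformer checkpoint named by `s_pretrained`.
    """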

    def __init__(self, json_file, root_dir, vocab, transform, split, max_length, s_pretrained, device):
        super().__init__()
        self.vocab = vocab
        self.split = split
        self.max_length = max_length
        self.device = device
        self.transform = transform
        assert self.split in {'train', 'val', 'test'}

        self.s_model = SentenceTransformer(s_pretrained).to(device)

        self.root_dir = root_dir
        self.convert = transforms.ToTensor()

        with open(json_file) as f:
            data = json.load(f)['images']

        self.raw_dataset = [entry for entry in data if entry['split'] == split]
        self.sentences = []
        self.embeddings = []

        self.images = []
        self.captions = []
        for record in tqdm(self.raw_dataset, desc='Tokenize ' + self.split):
            self.sentences.extend(self.tokenize(record['sentences']))

        for record in tqdm(self.raw_dataset, desc='Embeddings ' + self.split):
            self.embeddings.extend(self.compute_embeddings(record['sentences']))

        self.preprocess()
        # Everything below is only needed to build self.images / self.captions;
        # free it once preprocessing is done.
        del self.raw_dataset
        del self.sentences
        del self.embeddings
        del self.s_model

    def tokenize(self, batch):
        kept = []
        for elem in batch:
            tokens = [self.vocab.get(x, self.vocab['UNK']) for x in elem['tokens']]
            # Drop captions that would not fit once START and END are added.
            if len(tokens) > self.max_length - 2:
                continue

            tokens = [self.vocab['START']] + tokens + [self.vocab['END']]

            mask = [False] * len(tokens)

            diff = self.max_length - len(tokens)
            tokens += [self.vocab['PAD']] * diff
            mask += [True] * diff

            elem['input_ids'] = tokens
            elem['mask'] = mask
            kept.append(elem)

        # Rewrite the list in place so that compute_embeddings(), which
        # re-reads record['sentences'], stays aligned with the kept captions,
        # and so dropped entries never reach preprocess() without 'input_ids'.
        batch[:] = kept
        if len(batch) < 5:
            batch += [choice(batch) for _ in range(5 - len(batch))]

        assert len(batch) == 5
        return batch

    def compute_embeddings(self, batch):
        sents = [x['raw'].strip() for x in batch]
        return self.s_model.encode(sents)

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, idx):
        # Train has one item per caption (5 per image); val/test have one item
        # per image, so the index doubles as the image index.
        img_idx = idx // 5 if self.split == 'train' else idx
        elem = dict(self.captions[idx])  # shallow copy keeps the cache unmutated
        for k, v in self.images[img_idx].items():
            elem[k] = v
        return elem

    def preprocess(self):
        idx = 0
        prev_idx = -1
        pbar = tqdm(total=len(self.sentences), desc='Preprocessing ' + self.split)
        while idx < len(self.sentences):
            img_idx = idx // 5
            assert self.sentences[idx]['imgid'] == self.raw_dataset[img_idx]['imgid']

            input_ids = torch.tensor(self.sentences[idx]['input_ids'], dtype=torch.long)
            mask = torch.tensor(self.sentences[idx]['mask'], dtype=torch.bool)
            raws = [x['raw'] for x in self.raw_dataset[img_idx]['sentences']]
            # Unchanged pairs (changeflag == 0) all share the label -1; changed
            # pairs are labelled with their image id, one class per pair.
            flag = -1 if self.raw_dataset[img_idx]['changeflag'] == 0 else self.raw_dataset[img_idx]['imgid']
            flag = torch.tensor(flag, dtype=torch.long)
            embs = torch.tensor(self.embeddings[idx]) if len(self.embeddings) > 0 else None

            self.captions.append({'input_ids': input_ids, 'pad_masks': mask, 'raws': raws, 'flags': flag, 'embs': embs})

            # Load each before/after pair only once per group of 5 captions.
            if img_idx != prev_idx:
                before_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'A',
                                               self.raw_dataset[img_idx]['filename'])
                image_before = Image.open(before_img_path).convert('RGB')
                after_img_path = os.path.join(self.root_dir, self.raw_dataset[img_idx]['filepath'], 'B',
                                              self.raw_dataset[img_idx]['filename'])
                image_after = Image.open(after_img_path).convert('RGB')

                image_before = self.transform(image_before).unsqueeze(0)
                image_after = self.transform(image_after).unsqueeze(0)

                self.images.append({'image_before': image_before, 'image_after': image_after, 'flags': flag})
                prev_idx = img_idx

            # Train keeps every caption; val/test keep one entry per image.
            inc = 1 if self.split == 'train' else 5
            idx += inc
            pbar.update(inc)

        pbar.close()


class Batcher:
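    """Batch iterator over a CCDataset.

    On the train split, when a `model` is given and hd > 0, a FAISS
    inner-product index over L2-normalised visual features is built (and
    rebuilt every epoch) so that `hd` hard negatives can be mined per sample.
    """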

    def __init__(self, dataset, batch_size, max_len, device, hd=0, model=None, shuffle=False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.hd = hd
        self.max_len = max_len
        self.device = device
        self.model = model
        self.index = None
        self.visual = None
        self.textual = None

        self.ptr = 0
        self.indices = np.arange(len(self.dataset))
        self.shuffle = shuffle

        if shuffle:
            np.random.shuffle(self.indices)

        if model and hd > 0 and self.dataset.split == 'train':
            self.create_index()

    def __iter__(self):
        return self

    def __len__(self):
        # __next__ also emits the final, possibly smaller, batch, so round up.
        return (len(self.dataset) + self.batch_size - 1) // self.batch_size

    def __next__(self):
        if self.ptr >= len(self.dataset):
            # End of epoch: reset state, reshuffle and rebuild the index
            # before signalling the caller.
            self.ptr = 0
            self.index = None
            self.visual = None
            self.textual = None

            if self.shuffle:
                np.random.shuffle(self.indices)
            if self.model and self.hd > 0 and self.dataset.split == 'train':
                self.create_index()

            raise StopIteration

        batched = 0
        samples = []
        hard_negatives = []
        while self.ptr < len(self.dataset) and batched < self.batch_size:
            sample = self.dataset[self.indices[self.ptr]]
            samples.append(sample)

            if self.hd > 0 and self.dataset.split == 'train':
                hard_negatives.extend(self.mine_negatives(self.indices[self.ptr], self.hd))

            self.ptr += 1
            batched += 1

        return self.create_batch(samples + hard_negatives)

    def get_elem(self, idx):
        return self.dataset[idx]

    @torch.no_grad()
    def create_index(self):
        is_training = self.model.training
        self.model.eval()
        # Inner product over L2-normalised vectors is cosine similarity.
        self.index = faiss.IndexFlatIP(self.model.feature_dim)
        prev_img = None
        for idx in tqdm(range(len(self.dataset)), desc='Creating index'):
            sample = self.dataset[idx]
            imgs1, imgs2 = sample['image_before'], sample['image_after']
            input_ids, mask = sample['input_ids'], sample['pad_masks']

            # Five captions share one image pair, so encode the visual
            # features only once per group.
            if idx // 5 != prev_img:
                imgs1 = imgs1.to(self.device)
                imgs2 = imgs2.to(self.device)
                vis_emb, _ = self.model.encoder(imgs1, imgs2)
                self.visual = torch.cat([self.visual, vis_emb.cpu()]) if self.visual is not None else vis_emb.cpu()
                prev_img = idx // 5

            input_ids = input_ids.unsqueeze(0).to(self.device)
            mask = mask.unsqueeze(0).to(self.device)
            _, text_emb, _, _ = self.model.decoder(input_ids, None, mask, None)
            self.textual = torch.cat([self.textual, text_emb.cpu()]) if self.textual is not None else text_emb.cpu()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        self.visual = F.normalize(self.visual, p=2, dim=1)
        self.textual = F.normalize(self.textual, p=2, dim=1)
        # FAISS expects float32 numpy arrays rather than torch tensors.
        self.index.add(self.visual.numpy())
        if is_training:
            self.model.train()

    def mine_negatives(self, idx, n):
        negatives = []
        m = 4
        label = self.dataset[idx]['flags'].item()

        # Query the visual index with the caption's text embedding, over-fetch
        # n * m neighbours, keep those with a different change label, and widen
        # the search until n negatives are found (or the index is exhausted).
        while len(negatives) < n and (n * m) < self.index.ntotal:
            k = n * m
            indices = self.index.search(self.textual[idx].unsqueeze(0).numpy(), k)[1][0]
            # The index stores one row per image pair; row * 5 is the dataset
            # index of that pair's first caption.
            indices = [x * 5 for x in indices]
            negatives = [self.dataset[x] for x in indices if self.dataset[x]['flags'].item() != label][:n]
            m *= 2

        return negatives

    def create_batch(self, samples):
        images_before = images_after = input_ids = pad_mask = labels = flags = embs = None
        raws = []

        for sample in samples:
            img1 = sample['image_before']
            img2 = sample['image_after']

            tokens = sample['input_ids']
            mask = sample['pad_masks']
            flag = sample['flags']
            emb = sample['embs']

            tokens = tokens.unsqueeze(0)
            mask = mask.unsqueeze(0)
            flag = flag.unsqueeze(0)
            # Padded positions get the ignore label -100.
            lab = tokens.clone().masked_fill(mask, -100)
            if emb is not None:
                emb = emb.unsqueeze(0)

            images_before = torch.cat([images_before, img1]) if images_before is not None else img1
            images_after = torch.cat([images_after, img2]) if images_after is not None else img2
            input_ids = torch.cat([input_ids, tokens]) if input_ids is not None else tokens
            labels = torch.cat([labels, lab]) if labels is not None else lab
            pad_mask = torch.cat([pad_mask, mask]) if pad_mask is not None else mask
            flags = torch.cat([flags, flag]) if flags is not None else flag
            if emb is not None:
                embs = torch.cat([embs, emb]) if embs is not None else emb

            raws.append(sample['raws'])

        return {'images_before': images_before, 'images_after': images_after, 'input_ids': input_ids,
                'pad_mask': pad_mask, 'labels': labels, 'flags': flags, 'raws': raws, 'embs': embs}
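

# Minimal usage sketch. Assumptions (not part of this module): 'vocab.json'
# maps words to ids and contains 'START'/'END'/'PAD'/'UNK'; 'dataset.json'
# follows the split format read above; 'all-MiniLM-L6-v2' stands in for any
# SentenceTransformer checkpoint. Hard-negative mining is skipped (hd=0)
# since it needs a trained model exposing .encoder/.decoder/.feature_dim.
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])
    with open('vocab.json') as f:  # hypothetical vocabulary file
        vocab = json.load(f)

    dataset = CCDataset(json_file='dataset.json', root_dir='images/',
                        vocab=vocab, transform=transform, split='train',
                        max_length=40, s_pretrained='all-MiniLM-L6-v2',
                        device=device)
    batcher = Batcher(dataset, batch_size=32, max_len=40, device=device, shuffle=True)
    for batch in batcher:
        print(batch['input_ids'].shape, batch['images_before'].shape)
        break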