from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals from transformers import BertTokenizer, BertModel from transformers import BertTokenizer, TFBertModel from nltk.tokenize import RegexpTokenizer from collections import defaultdict from miscc.config import cfg import torch import torch.utils.data as data from torch.autograd import Variable import torchvision.transforms as transforms from transformers import AutoTokenizer import os import sys import numpy as np import pandas as pd from PIL import Image import numpy.random as random if sys.version_info[0] == 2: import cPickle as pickle else: import pickle def prepare_data(data): imgs, captions, captions_lens, class_ids, keys = data # sort data by the length in a decreasing order sorted_cap_lens, sorted_cap_indices = \ torch.sort(captions_lens, 0, True) real_imgs = [] for i in range(len(imgs)): imgs[i] = imgs[i][sorted_cap_indices] if cfg.CUDA: real_imgs.append(Variable(imgs[i]).cuda()) else: real_imgs.append(Variable(imgs[i])) captions = captions[sorted_cap_indices].squeeze() class_ids = class_ids[sorted_cap_indices].numpy() # sent_indices = sent_indices[sorted_cap_indices] keys = [keys[i] for i in sorted_cap_indices.numpy()] # print('keys', type(keys), keys[-1]) # list if cfg.CUDA: captions = Variable(captions).cuda() sorted_cap_lens = Variable(sorted_cap_lens).cuda() else: captions = Variable(captions) sorted_cap_lens = Variable(sorted_cap_lens) return [real_imgs, captions, sorted_cap_lens, class_ids, keys] def get_imgs(img_path, imsize, bbox=None, transform=None, normalize=None): img = Image.open(img_path).convert('RGB') width, height = img.size if bbox is not None: r = int(np.maximum(bbox[2], bbox[3]) * 0.75) center_x = int((2 * bbox[0] + bbox[2]) / 2) center_y = int((2 * bbox[1] + bbox[3]) / 2) y1 = np.maximum(0, center_y - r) y2 = np.minimum(height, center_y + r) x1 = np.maximum(0, center_x - r) x2 = np.minimum(width, center_x + r) img = img.crop([x1, y1, x2, y2]) if transform is not None: img = transform(img) ret = [] ret.append(normalize(img)) #if cfg.GAN.B_DCGAN: ''' for i in range(cfg.TREE.BRANCH_NUM): # print(imsize[i]) re_img = transforms.Resize(imsize[i])(img) ret.append(normalize(re_img)) ''' return ret class TextDataset(data.Dataset): def __init__(self, data_dir, split='train', base_size=64, transform=None, target_transform=None): self.transform = transform self.norm = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) self.target_transform = target_transform self.embeddings_num = cfg.TEXT.CAPTIONS_PER_IMAGE self.imsize = [] for i in range(cfg.TREE.BRANCH_NUM): self.imsize.append(base_size) base_size = base_size * 2 self.data = [] self.data_dir = data_dir if data_dir.find('CUB-200') != -1: self.bbox = self.load_bbox() else: self.bbox = None split_dir = os.path.join(data_dir, split) self.filenames, self.captions, self.ixtoword, \ self.wordtoix, self.n_words = self.load_text_data(data_dir, split) self.class_id = self.load_class_id(split_dir, len(self.filenames)) self.number_example = len(self.filenames) def load_bbox(self): data_dir = self.data_dir bbox_path = os.path.join(data_dir, 'CUB_200_2011/bounding_boxes.txt') df_bounding_boxes = pd.read_csv(bbox_path, delim_whitespace=True, header=None).astype(int) # filepath = os.path.join(data_dir, 'CUB_200_2011/images.txt') df_filenames = \ pd.read_csv(filepath, delim_whitespace=True, header=None) filenames = df_filenames[1].tolist() print('Total filenames: ', len(filenames), filenames[0]) # filename_bbox = {img_file[:-4]: [] for img_file in filenames} numImgs = len(filenames) for i in range(0, numImgs): # bbox = [x-left, y-top, width, height] bbox = df_bounding_boxes.iloc[i][1:].tolist() key = filenames[i][:-4] filename_bbox[key] = bbox # return filename_bbox def load_captions(self, data_dir, filenames): all_captions = [] text_dir = '%s/hindi_text' % (data_dir) for i in range(len(filenames)): cap_path = '%s/%s.txt' % (text_dir, filenames[i]) with open(cap_path, "rb") as f: captions = f.read().decode('utf8').split('\n') cnt = 0 for cap in captions: if len(cap) == 0: continue cap = cap.replace("\ufffd\ufffd", " ") # picks out sequences of alphanumeric characters as tokens # and drops everything else tokenizer = RegexpTokenizer(r'\w+') tokens = tokenizer.tokenize(cap.lower()) # print('tokens', tokens) if len(tokens) == 0: print('cap', cap) continue tokens_new = [] for t in tokens: t = t.encode('cp1256', 'ignore').decode('cp1256') if len(t) > 0: tokens_new.append(t) all_captions.append(tokens_new) cnt += 1 if cnt == self.embeddings_num: break if cnt < self.embeddings_num: print('ERROR: the captions for %s less than %d' % (filenames[i], cnt)) return all_captions def build_dictionary(self, train_captions, test_captions): word_counts = defaultdict(float) captions = train_captions + test_captions for sent in captions: for word in sent: word_counts[word] += 1 vocab = [w for w in word_counts if word_counts[w] >= 0] ixtoword = {} ixtoword[0] = '' wordtoix = {} wordtoix[''] = 0 ix = 1 for w in vocab: wordtoix[w] = ix ixtoword[ix] = w ix += 1 train_captions_new = [] for t in train_captions: rev = [] for w in t: if w in wordtoix: rev.append(wordtoix[w]) # rev.append(0) # do not need '' token train_captions_new.append(rev) test_captions_new = [] for t in test_captions: rev = [] for w in t: if w in wordtoix: rev.append(wordtoix[w]) # rev.append(0) # do not need '' token test_captions_new.append(rev) return [train_captions_new, test_captions_new, ixtoword, wordtoix, len(ixtoword)] def load_text_data(self, data_dir, split): filepath = os.path.join(data_dir, 'hindi_captions.pickle') train_names = self.load_filenames(data_dir, 'train') test_names = self.load_filenames(data_dir, 'test') if not os.path.isfile(filepath): train_captions = self.load_captions(data_dir, train_names) test_captions = self.load_captions(data_dir, test_names) train_captions, test_captions, ixtoword, wordtoix, n_words = \ self.build_dictionary(train_captions, test_captions) with open(filepath, 'wb') as f: pickle.dump([train_captions, test_captions, ixtoword, wordtoix], f, protocol=2) print('Save to: ', filepath) else: with open(filepath, 'rb') as f: x = pickle.load(f) train_captions, test_captions = x[0], x[1] ixtoword, wordtoix = x[2], x[3] del x n_words = len(ixtoword) print('Load from: ', filepath) if split == 'train': # a list of list: each list contains # the indices of words in a sentence captions = train_captions filenames = train_names else: # split=='test' captions = test_captions filenames = test_names return filenames, captions, ixtoword, wordtoix, n_words def load_class_id(self, data_dir, total_num): if os.path.isfile(data_dir + '/class_info.pickle'): with open(data_dir + '/class_info.pickle', 'rb') as f: class_id = pickle.load(f, encoding="bytes") else: class_id = np.arange(total_num) return class_id def load_filenames(self, data_dir, split): filepath = '%s/%s/filenames.pickle' % (data_dir, split) if os.path.isfile(filepath): with open(filepath, 'rb') as f: filenames = pickle.load(f) print('Load filenames from: %s (%d)' % (filepath, len(filenames))) else: filenames = [] return filenames def get_caption(self, sent_ix): # a list of indices for a sentence sent_caption = np.asarray(self.captions[sent_ix]).astype('int64') if (sent_caption == 0).sum() > 0: print('ERROR: do not need END (0) token', sent_caption) num_words = len(sent_caption) # pad with 0s (i.e., '') x = np.zeros((cfg.TEXT.WORDS_NUM, 1), dtype='int64') x_len = num_words if num_words <= cfg.TEXT.WORDS_NUM: x[:num_words, 0] = sent_caption else: ix = list(np.arange(num_words)) # 1, 2, 3,..., maxNum np.random.shuffle(ix) ix = ix[:cfg.TEXT.WORDS_NUM] ix = np.sort(ix) x[:, 0] = sent_caption[ix] x_len = cfg.TEXT.WORDS_NUM return x, x_len def __getitem__(self, index): # key = self.filenames[index] cls_id = self.class_id[index] # if self.bbox is not None: bbox = self.bbox[key] data_dir = '%s/CUB_200_2011' % self.data_dir else: bbox = None data_dir = self.data_dir # img_name = '%s/images/%s.jpg' % (data_dir, key) imgs = get_imgs(img_name, self.imsize, bbox, self.transform, normalize=self.norm) # random select a sentence sent_ix = random.randint(0, self.embeddings_num) new_sent_ix = index * self.embeddings_num + sent_ix caps, cap_len = self.get_caption(new_sent_ix) return imgs, caps, cap_len, cls_id, key def __len__(self): return len(self.filenames) class TextBertDataset(TextDataset): def __init__(self, *args, **kwargs): self.tokenizer = BertTokenizer.from_pretrained(cfg.GAN.BERT_NAME) #self.tokenizer = AutoTokenizer.from_pretrained("monsoon-nlp/hindi-bert") self.arabert_prep = BertModel.from_pretrained(cfg.GAN.BERT_NAME) super().__init__(*args, **kwargs) # Load pre-trained model tokenizer (vocabulary) def load_captions(self, data_dir, filenames): all_captions = [] text_dir = '%s/hindi_text' % (data_dir) for i in range(len(filenames)): cap_path = '%s/%s.txt' % (text_dir, filenames[i]) with open(cap_path, "r") as f: captions = f.read().split('\n') if len(captions) == 0: print(cap_path) cnt = 0 for cap in captions: if len(cap) == 0: print(cap_path) continue # picks out sequences of alphanumeric characters as tokens # and drops everything else #preprocess cap # print(f'cap = {cap}') ##cap = self.arabert_prep(cap) # print(f'prep_cap = {cap}') tokens = self.tokenizer.tokenize(cap.lower()) # print(f'tokens = {tokens}') # print('tokens', tokens) if len(tokens) == 0: print('cap', cap) print(cap_path) continue tokens_new = [] for t in tokens: #t = t.encode('cp1256', 'ignore').decode('cp1256') if len(t) > 0: tokens_new.append(t) # print(f'tokens_new = {tokens_new}') all_captions.append(tokens_new) cnt += 1 if cnt == self.embeddings_num: break if cnt < self.embeddings_num: print('ERROR: the captions for %s less than %d' % (filenames[i], cnt)) return all_captions def load_text_data(self, data_dir, split): train_names = self.load_filenames(data_dir, 'train') test_names = self.load_filenames(data_dir, 'test') filepath = os.path.join(data_dir, 'hindi_captions1.pickle') if not os.path.isfile(filepath): train_captions = self.load_captions(data_dir, train_names) test_captions = self.load_captions(data_dir, test_names) train_captions, test_captions, ixtoword, wordtoix, n_words = \ self.build_dictionary(train_captions, test_captions) with open(filepath, 'wb') as f: pickle.dump([train_captions, test_captions, ixtoword, wordtoix], f, protocol=2) print('Save to: ', filepath) else: with open(filepath, 'rb') as f: x = pickle.load(f) train_captions, test_captions = x[0], x[1] ixtoword, wordtoix = x[2], x[3] del x n_words = len(ixtoword) print('Load from: ', filepath) if split == 'train': # a list of list: each list contains # the indices of words in a sentence captions = train_captions filenames = train_names else: # split=='test' captions = test_captions filenames = test_names return filenames, captions, ixtoword, wordtoix, n_words def build_dictionary(self, train_captions, test_captions): """ Tokenize according to bert model """ ixtoword = {} wordtoix = {} # check for special tokens [PAD][SEP][CLS]. train_captions_new = [] for sent in train_captions: indexed_tokens = self.tokenizer.convert_tokens_to_ids(sent) train_captions_new.append(indexed_tokens) for idx, word in zip(indexed_tokens, sent): wordtoix[word] = idx ixtoword[idx] = word test_captions_new = [] for sent in test_captions: indexed_tokens = self.tokenizer.convert_tokens_to_ids(sent) test_captions_new.append(indexed_tokens) for idx, word in zip(indexed_tokens, sent): wordtoix[word] = idx ixtoword[idx] = word return [train_captions_new, test_captions_new, ixtoword, wordtoix, len(ixtoword)]