import os
import sys
import glob
import pickle
import random

import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, sampler
from PIL import Image
import torchvision.transforms as transforms

from params import *  # provides DATASET_PATHS and resolution used below


def crop_(input):
    """Crop a grayscale word image to the bounding box of its ink pixels."""
    image = Image.fromarray(input)
    image = image.convert('L')
    # Binarise: dark (ink) pixels become 255, light background becomes 0.
    binary_image = image.point(lambda x: 0 if x > 127 else 255, '1')
    bbox = binary_image.getbbox()
    cropped_image = image.crop(bbox)
    return np.array(cropped_image)


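# Example (a sketch; the file name is hypothetical, the input is assumed to be a
# dark-ink-on-white word image):
#
#   word = np.array(Image.open('word.png').convert('L'))
#   tight = crop_(word)   # same image with the surrounding white margin removed

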
def get_transform(grayscale=False, convert=True):
    """Build the torchvision transform used for all images in this module."""
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))

    if convert:
        transform_list += [transforms.ToTensor()]
        # Map [0, 1] tensors to [-1, 1].
        if grayscale:
            transform_list += [transforms.Normalize((0.5,), (0.5,))]
        else:
            transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]

    return transforms.Compose(transform_list)


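# A minimal sketch of what the grayscale pipeline produces (names are illustrative only):
#
#   trans = get_transform(grayscale=True)
#   t = trans(Image.fromarray(np.zeros((32, 192), dtype='uint8')))
#   # t.shape == (1, 32, 192); pixel value 0 maps to -1.0 and 255 maps to +1.0

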
def load_itw_samples(folder_path, num_samples=15):
    """Load `num_samples` in-the-wild word images from a folder (or a list of paths)
    and pack them into the fixed-size style tensors expected by the model."""
    if isinstance(folder_path, str):
        paths = glob.glob(f'{folder_path}/*')
    else:
        paths = folder_path
    # Sample with replacement only when fewer images than requested are available.
    paths = np.random.choice(paths, num_samples, replace=len(paths) <= num_samples)

    # Word transcriptions recovered from the file names (currently unused).
    words = [os.path.basename(path_i)[:-4] for path_i in paths]

    imgs = [np.array(Image.open(i).convert('L')) for i in paths]
    imgs = [crop_(im) for im in imgs]
    # Resize to a fixed height of 32 pixels, preserving the aspect ratio.
    imgs = [cv2.resize(imgs_i, (int(32 * (imgs_i.shape[1] / imgs_i.shape[0])), 32)) for imgs_i in imgs]
    max_width = 192

    imgs_pad = []
    imgs_wids = []

    trans_fn = get_transform(grayscale=True)

    for img in imgs:
        # Invert, right-pad (or truncate) to max_width, then invert back so the
        # padding is white in the output.
        img = 255 - img
        img_height, img_width = img.shape[0], img.shape[1]
        outImg = np.zeros((img_height, max_width), dtype='float32')
        outImg[:, :img_width] = img[:, :max_width]

        img = 255 - outImg

        imgs_pad.append(trans_fn(Image.fromarray(img)))
        imgs_wids.append(img_width)

    imgs_pad = torch.cat(imgs_pad, 0)

    return imgs_pad.unsqueeze(0), torch.Tensor(imgs_wids).unsqueeze(0)


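# Example usage (a sketch; the folder name is hypothetical and each file in it is
# assumed to hold one handwritten word image):
#
#   style_imgs, style_wids = load_itw_samples('style_samples/writer_01', num_samples=15)
#   # style_imgs -> FloatTensor of shape (1, 15, 32, 192), normalised to [-1, 1]
#   # style_wids -> FloatTensor of shape (1, 15) holding the unpadded pixel widths

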
class TextDataset(Dataset):
    """Training split: word images grouped by writer, read from the pickled dataset."""

    def __init__(self, base_path=DATASET_PATHS, num_examples=15, target_transform=None):

        self.NUM_EXAMPLES = num_examples

        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)['train']
        self.IMG_DATA = dict(list(self.IMG_DATA.items()))
        if 'None' in self.IMG_DATA.keys():
            del self.IMG_DATA['None']
        self.author_id = list(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform

        self.collate_fn = TextCollator()

    def __len__(self):
        return len(self.author_id)

    def __getitem__(self, index):

        NUM_SAMPLES = self.NUM_EXAMPLES

        author_id = self.author_id[index]

        self.IMG_DATA_AUTHOR = self.IMG_DATA[author_id]
        # Style examples: NUM_SAMPLES word images drawn (with replacement) from this writer.
        random_idxs = np.random.choice(len(self.IMG_DATA_AUTHOR), NUM_SAMPLES, replace=True)

        # One additional "real" image/label pair from the same writer.
        rand_id_real = np.random.choice(len(self.IMG_DATA_AUTHOR))
        real_img = self.transform(self.IMG_DATA_AUTHOR[rand_id_real]['img'].convert('L'))
        real_labels = self.IMG_DATA_AUTHOR[rand_id_real]['label'].encode()

        imgs = [np.array(self.IMG_DATA_AUTHOR[idx]['img'].convert('L')) for idx in random_idxs]
        labels = [self.IMG_DATA_AUTHOR[idx]['label'].encode() for idx in random_idxs]

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for img in imgs:
            # Invert, right-pad (or truncate) to max_width, then invert back so the
            # padding is white.
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]

            img = 255 - outImg

            imgs_pad.append(self.transform(Image.fromarray(img)))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        # 'img_path' and 'idx' are unused placeholder strings kept for interface compatibility.
        item = {'simg': imgs_pad, 'swids': imgs_wids, 'img': real_img, 'label': real_labels,
                'img_path': 'img_path', 'idx': 'indexes', 'wcl': index}

        return item


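# Each item produced by TextDataset.__getitem__ looks roughly like this (a sketch;
# shapes assume the pickled images are stored at a height of 32 pixels):
#
#   item['simg']  -> FloatTensor (num_examples, 32, 192): padded style images of one writer
#   item['swids'] -> list of num_examples ints: unpadded widths of those style images
#   item['img']   -> FloatTensor (1, 32, W): one real image from the same writer
#   item['label'] -> bytes: transcription of item['img']
#   item['wcl']   -> int: writer-class index (the dataset index of the writer)

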
class TextDatasetval(Dataset):
    """Validation counterpart of TextDataset: identical logic, but reads the 'test'
    split of the pickled dataset."""

    def __init__(self, base_path=DATASET_PATHS, num_examples=15, target_transform=None):

        self.NUM_EXAMPLES = num_examples

        with open(base_path, "rb") as f:
            self.IMG_DATA = pickle.load(f)['test']
        self.IMG_DATA = dict(list(self.IMG_DATA.items()))
        if 'None' in self.IMG_DATA.keys():
            del self.IMG_DATA['None']
        self.author_id = list(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform

        self.collate_fn = TextCollator()

    def __len__(self):
        return len(self.author_id)

    def __getitem__(self, index):

        NUM_SAMPLES = self.NUM_EXAMPLES

        author_id = self.author_id[index]

        self.IMG_DATA_AUTHOR = self.IMG_DATA[author_id]
        random_idxs = np.random.choice(len(self.IMG_DATA_AUTHOR), NUM_SAMPLES, replace=True)

        rand_id_real = np.random.choice(len(self.IMG_DATA_AUTHOR))
        real_img = self.transform(self.IMG_DATA_AUTHOR[rand_id_real]['img'].convert('L'))
        real_labels = self.IMG_DATA_AUTHOR[rand_id_real]['label'].encode()

        imgs = [np.array(self.IMG_DATA_AUTHOR[idx]['img'].convert('L')) for idx in random_idxs]
        labels = [self.IMG_DATA_AUTHOR[idx]['label'].encode() for idx in random_idxs]

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for img in imgs:
            # Same padding scheme as in TextDataset: invert, pad/truncate to max_width, invert back.
            img = 255 - img
            img_height, img_width = img.shape[0], img.shape[1]
            outImg = np.zeros((img_height, max_width), dtype='float32')
            outImg[:, :img_width] = img[:, :max_width]

            img = 255 - outImg

            imgs_pad.append(self.transform(Image.fromarray(img)))
            imgs_wids.append(img_width)

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {'simg': imgs_pad, 'swids': imgs_wids, 'img': real_img, 'label': real_labels,
                'img_path': 'img_path', 'idx': 'indexes', 'wcl': index}

        return item


class TextCollator(object):
    """Collate per-writer items into a batch, right-padding the real images to the
    widest image in the batch."""

    def __init__(self):
        self.resolution = resolution  # from params

    def __call__(self, batch):

        img_path = [item['img_path'] for item in batch]
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        simgs = torch.stack([item['simg'] for item in batch], 0)
        wcls = torch.Tensor([item['wcl'] for item in batch])
        swids = torch.Tensor([item['swids'] for item in batch])
        # Pad every real image with ones (white after normalisation) to the widest width.
        imgs = torch.ones([len(batch), batch[0]['img'].shape[0], batch[0]['img'].shape[1],
                           max(width)], dtype=torch.float32)
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except Exception:
                print(imgs.shape)
        item = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs,
                'swids': swids, 'wcl': wcls}
        if 'label' in batch[0].keys():
            labels = [item['label'] for item in batch]
            item['label'] = labels
        if 'z' in batch[0].keys():
            z = torch.stack([item['z'] for item in batch])
            item['z'] = z
        return item
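

if __name__ == '__main__':
    # Minimal smoke test (a sketch): assumes DATASET_PATHS from params points at the
    # pickled {'train': ..., 'test': ...} dataset this module expects.
    dataset = TextDataset(num_examples=15)
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True,
                                         collate_fn=dataset.collate_fn)
    batch = next(iter(loader))
    print(batch['simg'].shape, batch['img'].shape, len(batch['label']))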