| | import random |
| | from collections import defaultdict |
| |
|
| | import torch |
| | from torch.utils.data import Dataset |
| | import torchvision.transforms as transforms |
| | import os |
| | import pickle |
| | import numpy as np |
| | from PIL import Image |
| | from pathlib import Path |
| |
|
| |
|
def get_dataset_path(dataset_name, height, file_suffix, datasets_path):
    """Return the path of a dataset pickle located inside ``datasets_path``.

    The file name is ``<dataset_name>-<height>.pickle`` when ``file_suffix``
    is ``None`` and ``<dataset_name>-<height>-<file_suffix>.pickle``
    otherwise (an empty-string suffix still adds the trailing dash).
    """
    if file_suffix is None:
        return os.path.join(datasets_path, f'{dataset_name}-{height}.pickle')
    return os.path.join(datasets_path, f'{dataset_name}-{height}-{file_suffix}.pickle')
| |
|
| |
|
def get_transform(grayscale=False, convert=True):
    """Compose the torchvision preprocessing pipeline for an image.

    Args:
        grayscale: when true, the pipeline starts with a single-channel
            ``Grayscale`` conversion and normalisation uses one channel;
            otherwise three-channel statistics are used.
        convert: when true, ``ToTensor`` followed by ``Normalize`` (mean 0.5,
            std 0.5 per channel) is appended.

    Returns:
        A ``transforms.Compose`` holding the selected steps (possibly empty).
    """
    steps = [transforms.Grayscale(1)] if grayscale else []

    if convert:
        # One 0.5 entry per channel, used for both mean and std.
        channel_stats = (0.5,) if grayscale else (0.5, 0.5, 0.5)
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(channel_stats, channel_stats))

    return transforms.Compose(steps)
| |
|
| |
|
class TextDataset:
    """Author-indexed handwriting dataset backed by a pickle file.

    The pickle is expected to hold a mapping with ``'train'`` and ``'test'``
    entries, each mapping an author id to a list of samples. Every sample is
    a dict with an ``'img'`` entry (an object exposing ``convert('L')``,
    presumably a PIL image -- confirm against the dataset builder) and a
    string ``'label'``.

    One item corresponds to one author: a randomly chosen word image plus
    ``num_examples`` style images drawn with replacement from that author.
    """

    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, min_virtual_size=0, validation=False, debug=False):
        """Load one subset of the pickle and prepare the transforms.

        Args:
            base_path: path of the pickle file to read.
            collator_resolution: forwarded to ``TextCollator``.
            num_examples: number of style images returned per item.
            target_transform: stored on the instance; not used by the code
                in this class.
            min_virtual_size: lower bound reported by ``__len__``.
            validation: read the ``'test'`` subset instead of ``'train'``.
            debug: when true ``__len__`` reports 16.
        """
        self.NUM_EXAMPLES = num_examples
        self.debug = debug
        self.min_virtual_size = min_virtual_size

        subset = 'test' if validation else 'train'

        # NOTE(security): unpickling can execute arbitrary code; only point
        # base_path at trusted files.
        # The context manager guarantees the file handle is closed.
        with open(base_path, "rb") as file_to_store:
            self.IMG_DATA = pickle.load(file_to_store)[subset]
        # Work on a shallow copy and drop the 'None' author entry if present.
        self.IMG_DATA = dict(self.IMG_DATA)
        self.IMG_DATA.pop('None', None)

        # Every distinct character that occurs in any label, sorted.
        self.alphabet = ''.join(sorted({
            char
            for samples in self.IMG_DATA.values()
            for sample in samples
            for char in sample['label']
        }))
        self.author_id = list(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform

        self.collate_fn = TextCollator(collator_resolution)

    def __len__(self):
        """Return 16 in debug mode, else max(number of authors, min_virtual_size)."""
        if self.debug:
            return 16
        return max(len(self.author_id), self.min_virtual_size)

    @property
    def num_writers(self):
        """Number of distinct authors in the loaded subset."""
        return len(self.author_id)

    @staticmethod
    def _pad_to_width(img, max_width):
        """Return ``img`` cropped/right-padded with white (255) to ``max_width`` columns."""
        kept_width = min(img.shape[1], max_width)
        canvas = np.full((img.shape[0], max_width), 255, dtype=np.uint8)
        canvas[:, :kept_width] = img[:, :kept_width]
        return canvas

    def __getitem__(self, index):
        """Return one real word image and ``NUM_EXAMPLES`` style images of an author.

        ``index`` wraps around the number of authors, so any index below
        ``len(self)`` is valid even when the virtual size exceeds the number
        of authors.

        Returns:
            dict with keys ``'simg'`` (style images concatenated along
            dim 0), ``'swids'`` (their original pixel widths, not clamped to
            the padded width), ``'img'``, ``'label'`` (encoded bytes),
            ``'img_path'`` and ``'idx'`` (literal placeholder strings),
            ``'wcl'`` (author position), ``'slabels'`` and ``'author_id'``.
        """
        index = index % len(self.author_id)

        author_id = self.author_id[index]

        author_samples = self.IMG_DATA[author_id]
        # Kept as an instance attribute because the original code exposed it.
        self.IMG_DATA_AUTHOR = author_samples
        # Style samples are drawn first, then the real sample, so the order in
        # which the ``random`` module is consumed stays unchanged.
        random_idxs = random.choices(range(len(author_samples)), k=self.NUM_EXAMPLES)

        word_data = random.choice(author_samples)
        real_img = self.transform(word_data['img'].convert('L'))
        real_labels = word_data['label'].encode()

        slabels = [author_samples[idx]['label'].encode() for idx in random_idxs]

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for idx in random_idxs:
            img = np.array(author_samples[idx]['img'].convert('L'))
            padded = self._pad_to_width(img, max_width)
            imgs_pad.append(self.transform(Image.fromarray(padded)))
            imgs_wids.append(img.shape[1])

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
            'img': real_img,
            'label': real_labels,
            'img_path': 'img_path',
            'idx': 'indexes',
            'wcl': index,
            'slabels': slabels,
            'author_id': author_id,
        }
        return item

    def get_stats(self):
        """Return the inverse relative frequency of every label character.

        Returns:
            dict mapping each character to ``1.0 / (count / total)`` where
            ``total`` is the number of characters over all labels. Empty when
            there are no labels.
        """
        char_counts = defaultdict(int)
        total = 0

        for samples in self.IMG_DATA.values():
            for data in samples:
                for char in data['label']:
                    char_counts[char] += 1
                    total += 1

        return {char: 1.0 / (count / total) for char, count in char_counts.items()}
| |
|
| |
|
class TextCollator(object):
    """Collate dataset items (dicts) into a single batch dictionary."""

    def __init__(self, resolution):
        # Stored on the instance; ``__call__`` itself does not read it.
        self.resolution = resolution

    def __call__(self, batch):
        """Merge a list of item dicts into one batch dict.

        Args:
            batch: non-empty list of item dicts, or a list of lists of item
                dicts (which is flattened first). Every item needs the keys
                ``'img'`` (3-D tensor), ``'img_path'``, ``'idx'``, ``'simg'``,
                ``'swids'`` and ``'wcl'``; ``'label'``, ``'slabels'`` and
                ``'z'`` are collated when the first item has them.

        Returns:
            dict with ``'img'`` (float32 tensor filled with ones and
            overwritten on the left with each item's image, as wide as the
            widest image), stacked ``'simg'``, tensors ``'swids'`` and
            ``'wcl'``, the lists ``'img_path'`` and ``'idx'``, plus the
            optional keys described above.
        """
        if isinstance(batch[0], list):
            batch = [item for group in batch for item in group]
        first = batch[0]

        img_path = [item['img_path'] for item in batch]
        width = [item['img'].shape[2] for item in batch]
        indexes = [item['idx'] for item in batch]
        simgs = torch.stack([item['simg'] for item in batch], 0)
        wcls = torch.Tensor([item['wcl'] for item in batch])
        swids = torch.Tensor([item['swids'] for item in batch])
        imgs = torch.ones([len(batch), first['img'].shape[0], first['img'].shape[1], max(width)],
                          dtype=torch.float32)
        for idx, item in enumerate(batch):
            try:
                imgs[idx, :, :, 0:item['img'].shape[2]] = item['img']
            except Exception:
                # Best effort, as before: an image that does not fit (e.g. a
                # different height) leaves its slot filled with ones and the
                # batch shape is reported. Unlike a bare ``except`` this no
                # longer swallows KeyboardInterrupt / SystemExit.
                print(imgs.shape)

        collated = {'img': imgs, 'img_path': img_path, 'idx': indexes, 'simg': simgs, 'swids': swids, 'wcl': wcls}
        if 'label' in first:
            collated['label'] = [item['label'] for item in batch]
        if 'slabels' in first:
            collated['slabels'] = np.array([item['slabels'] for item in batch])
        if 'z' in first:
            collated['z'] = torch.stack([item['z'] for item in batch])
        return collated
| |
|
| |
|
class CollectionTextDataset(Dataset):
    """Concatenation of several named datasets addressed by one flat index.

    The datasets are built from a comma-separated list of names, in sorted
    name order, and indexed back to back in that order.
    """

    def __init__(self, datasets, datasets_path, dataset_class, file_suffix=None, height=32, **kwargs):
        """Instantiate ``dataset_class`` once per name in ``datasets``.

        Each dataset receives the path returned by ``get_dataset_path`` as
        first argument and ``kwargs`` as keyword arguments. The collection
        alphabet is the sorted union of the member alphabets.
        """
        self.datasets = {}
        names = datasets.split(',')
        names.sort()
        for name in names:
            pickle_path = get_dataset_path(name, height, file_suffix, datasets_path)
            self.datasets[name] = dataset_class(pickle_path, **kwargs)

        charset = set()
        for member in self.datasets.values():
            charset.update(member.alphabet)
        self.alphabet = ''.join(sorted(charset))

    def __len__(self):
        """Total number of items over all member datasets."""
        total = 0
        for member in self.datasets.values():
            total += len(member)
        return total

    @property
    def num_writers(self):
        """Sum of ``num_writers`` over all member datasets."""
        total = 0
        for member in self.datasets.values():
            total += member.num_writers
        return total

    def __getitem__(self, index):
        """Return the item at flat position ``index``; raise IndexError past the end."""
        remaining = index
        for member in self.datasets.values():
            size = len(member)
            if remaining < size:
                return member[remaining]
            remaining -= size
        raise IndexError

    def get_dataset(self, index):
        """Return the name of the dataset holding flat position ``index``."""
        remaining = index
        for name in self.datasets:
            size = len(self.datasets[name])
            if remaining < size:
                return name
            remaining -= size
        raise IndexError

    def collate_fn(self, batch):
        """Collate ``batch`` with the collator of the dataset that owns index 0."""
        owner = self.get_dataset(0)
        return self.datasets[owner].collate_fn(batch)
| |
|
| |
|
class FidDataset(Dataset):
    """Sample-indexed handwriting dataset (one item per word image).

    The pickle at ``base_path`` is expected to map a mode name (e.g.
    ``'train'``) to a mapping from author id to a list of sample dicts with
    an ``'img'`` entry (object exposing ``convert('L')``, presumably a PIL
    image -- confirm against the dataset builder), a string ``'label'`` and
    an ``'img_id'`` or ``'image_id'`` entry. Author ids must be convertible
    with ``int``.
    """

    def __init__(self, base_path, collator_resolution, num_examples=15, target_transform=None, mode='train', style_dataset=None):
        """Load the samples and, optionally, a separate style-image pickle.

        Args:
            base_path: pickle with the real samples.
            collator_resolution: forwarded to ``TextCollator``.
            num_examples: number of style images returned per item.
            target_transform: stored on the instance; not used by the code
                in this class.
            mode: top-level key selected in the pickle(s).
            style_dataset: optional path of a second pickle, with the same
                layout, that style images are drawn from instead.
        """
        self.NUM_EXAMPLES = num_examples

        self.IMG_DATA = self._load_subset(base_path, mode)

        self.STYLE_IMG_DATA = None
        if style_dataset is not None:
            self.STYLE_IMG_DATA = self._load_subset(style_dataset, mode)

        # Every distinct character that occurs in any label, sorted.
        self.alphabet = ''.join(sorted({
            char
            for samples in self.IMG_DATA.values()
            for sample in samples
            for char in sample['label']
        }))
        self.author_id = sorted(self.IMG_DATA.keys())

        self.transform = get_transform(grayscale=True)
        self.target_transform = target_transform
        self.dataset_size = sum(len(samples) for samples in self.IMG_DATA.values())
        self.collate_fn = TextCollator(collator_resolution)

    @staticmethod
    def _load_subset(path, mode):
        """Unpickle ``path`` and return its ``mode`` entry without the 'None' author."""
        # NOTE(security): unpickling can execute arbitrary code; only load
        # trusted files.
        with open(path, "rb") as f:
            subset = pickle.load(f)[mode]
        subset.pop('None', None)
        return subset

    @staticmethod
    def _pad_to_width(img, max_width):
        """Return ``img`` cropped/right-padded with white (255) to ``max_width`` columns."""
        kept_width = min(img.shape[1], max_width)
        canvas = np.full((img.shape[0], max_width), 255, dtype=np.uint8)
        canvas[:, :kept_width] = img[:, :kept_width]
        return canvas

    def _locate(self, index):
        """Map a flat, non-negative ``index`` to ``(sample, author_id)``.

        Authors are visited in the iteration order of ``IMG_DATA``.

        Raises:
            IndexError: when ``index`` does not address any sample.
        """
        remaining = index
        if remaining >= 0:
            for author_id, samples in self.IMG_DATA.items():
                if remaining < len(samples):
                    return samples[remaining], author_id
                remaining -= len(samples)
        raise IndexError(f'index {index} is out of range for a dataset of size {len(self)}')

    def __len__(self):
        """Total number of samples over all authors."""
        return self.dataset_size

    @property
    def num_writers(self):
        """Number of distinct authors."""
        return len(self.author_id)

    def __getitem__(self, index):
        """Return the sample at flat position ``index`` with style images of its author.

        Negative indices count from the end. Style images are drawn with
        replacement from the style pickle when one was given, otherwise from
        the main data.

        Returns:
            dict with keys ``'simg'`` (style images concatenated along
            dim 0), ``'swids'`` (their original pixel widths, not clamped to
            the padded width), ``'img'``, ``'label'`` (encoded bytes),
            ``'img_path'`` (literal placeholder string), ``'idx'`` (the
            sample's ``'img_id'`` or ``'image_id'``) and ``'wcl'``
            (``int(author_id)``).

        Raises:
            IndexError: when ``index`` is out of range. Raising IndexError
                (rather than failing on a missing sample) also lets plain
                ``for item in dataset`` loops terminate.
        """
        if index < 0:
            index += len(self)
        sample, author_id = self._locate(index)

        real_image = self.transform(sample['img'].convert('L'))
        real_label = sample['label'].encode()

        style_source = self.STYLE_IMG_DATA if self.STYLE_IMG_DATA is not None else self.IMG_DATA
        author_style_images = style_source[author_id]
        random_idxs = np.random.choice(len(author_style_images), self.NUM_EXAMPLES, replace=True)

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for idx in random_idxs:
            img = np.array(author_style_images[idx]['img'].convert('L'))
            padded = self._pad_to_width(img, max_width)
            imgs_pad.append(self.transform(Image.fromarray(padded)))
            imgs_wids.append(img.shape[1])

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
            'img': real_image,
            'label': real_label,
            'img_path': 'img_path',
            'idx': sample['img_id'] if 'img_id' in sample else sample['image_id'],
            'wcl': int(author_id),
        }
        return item
| |
|
| |
|
class FolderDataset:
    """Style-image sampler over the image files of a single folder."""

    def __init__(self, folder_path, num_examples=15, word_lengths=None):
        """Collect the image paths found directly inside ``folder_path``.

        Args:
            folder_path: directory to scan; entries with a ``.txt`` suffix
                are skipped.
            num_examples: number of images returned by ``sample_style``.
            word_lengths: optional mapping from file stem to a length; when
                given, each image is resized to ``length * 16`` by 32 pixels
                instead of keeping its aspect ratio.
        """
        folder_path = Path(folder_path)
        # Sorted so that seeded sampling does not depend on the directory
        # listing order of the file system.
        self.imgs = sorted(p for p in folder_path.iterdir() if p.suffix != '.txt')
        self.transform = get_transform(grayscale=True)
        self.num_examples = num_examples
        self.word_lengths = word_lengths

    def __len__(self):
        """Number of image files found in the folder."""
        return len(self.imgs)

    @staticmethod
    def _pad_to_width(img, max_width):
        """Return ``img`` cropped/right-padded with white (255) to ``max_width`` columns."""
        kept_width = min(img.shape[1], max_width)
        canvas = np.full((img.shape[0], max_width), 255, dtype=np.uint8)
        canvas[:, :kept_width] = img[:, :kept_width]
        return canvas

    def sample_style(self):
        """Draw ``num_examples`` random images and return them as a style batch.

        Images are sampled without replacement when the folder holds at
        least ``num_examples`` images; with fewer images sampling falls back
        to drawing with replacement instead of failing.

        Returns:
            dict with ``'simg'`` (transformed images concatenated along
            dim 0) and ``'swids'`` (their pixel widths after resizing, not
            clamped to the padded width).

        Raises:
            ValueError: when the folder contains no images.
        """
        num_images = len(self.imgs)
        if num_images == 0:
            raise ValueError('cannot sample style images: the folder contains no images')

        with_replacement = num_images < self.num_examples
        random_idxs = np.random.choice(num_images, self.num_examples, replace=with_replacement)

        max_width = 192

        imgs_pad = []
        imgs_wids = []

        for idx in random_idxs:
            path = self.imgs[idx]
            img = Image.open(path).convert('L')
            if self.word_lengths is None:
                # Height 32, width scaled to keep the aspect ratio.
                target_size = (img.size[0] * 32 // img.size[1], 32)
            else:
                # 16 pixels per unit of the configured word length.
                target_size = (self.word_lengths[path.stem] * 16, 32)
            img = np.array(img.resize(target_size, Image.BILINEAR))

            padded = self._pad_to_width(img, max_width)
            imgs_pad.append(self.transform(Image.fromarray(padded)))
            imgs_wids.append(img.shape[1])

        imgs_pad = torch.cat(imgs_pad, 0)

        item = {
            'simg': imgs_pad,
            'swids': imgs_wids,
        }
        return item
| |
|