|
|
import bisect
import os

import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps

from config import CONFIG
|
|
|
|
|
def rgb2grey(rgb):
    """Convert an RGB(A) image array to greyscale.

    Uses the ITU-R BT.601 luma weights; any alpha channel is ignored
    (only the first three channels of the last axis are read).
    """
    luma_weights = np.array([0.299, 0.587, 0.114])
    return rgb[..., :3] @ luma_weights
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
from config import CONFIG |
|
|
|
|
|
def data_split():
    """Split the files in the configured data folder into train/val/test lists.

    Files are taken in sorted name order and strided with
    period = round(1 / val_proportion): train gets indices 0, p, 2p, ...,
    val gets 1, p+1, ..., test gets 2, p+2, ...

    NOTE(review): with period p > 3 this keeps only 3/p of the files and
    discards the rest, and with p <= 2 the three splits overlap — presumably
    deliberate subsampling, but confirm before relying on it.

    Returns:
        (train_files, val_files, test_files, max_H, max_W, max_seq_length)
        where the last three are the fixed canvas height/width and the
        maximum label sequence length used downstream.
    """
    config = CONFIG()
    folder = config.data_folder
    val_proportion = float(config.val_proportion)

    # os.listdir order is OS/filesystem-dependent; sort so the split is
    # reproducible across runs and machines.
    files = sorted(os.listdir(folder))
    period = int(np.round(1 / val_proportion))

    train_files = files[::period]
    val_files = files[1::period]
    test_files = files[2::period]
    max_H, max_W, max_seq_length = 256, 256, 100

    return train_files, val_files, test_files, max_H, max_W, max_seq_length
|
|
|
|
|
def augmentation_simple(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    """Load a word image and paste it at an aug_type-dependent offset on a
    white (all-ones) max_H x max_W canvas.

    Args:
        filename: image file name inside `folder`.
        aug_type: integer selecting the placement (vertical position from
            aug_type % 2, horizontal from aug_type % 3).
        max_H, max_W: canvas size in pixels.
        folder: directory containing the image (defaults to CONFIG.data_folder).

    Returns:
        2-D float array (max_H, max_W) with the greyscale image pasted in.
    """
    # BUG FIX: the original called mpimg.imread, but matplotlib.image is
    # never imported in this module, so this raised NameError.  Load via the
    # already-imported PIL instead, scaling to [0, 1] to match mpimg.imread's
    # float output for PNGs (and augmentation_nine's convention).
    image = rgb2grey(np.asarray(Image.open(os.path.join(folder, filename)).convert('RGB')) / 255)
    image_augmented = np.ones((max_H, max_W))
    h, w = np.shape(image)
    # NOTE(review): vertical stride is the full slack (max_H - h) while the
    # horizontal stride is halved — looks intentional (2 vertical x 3
    # horizontal placements) but worth confirming.
    stride_0, stride_1 = max_H - h, (max_W - w) // 2
    offset = ((aug_type % 2) * stride_0, (aug_type % 3) * stride_1)
    image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    return image_augmented
|
|
|
|
|
def augmentation_nine(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    """Load a word image, apply one of nine rotate/scale augmentations, and
    paste it onto a white (all-ones) max_H x max_W canvas.

    aug_type in [0, 8] encodes the variant: rotation = aug_type % 3
    (none / +2 deg / -2 deg) and scale = aug_type // 3 (1.0 / 0.98 / 0.96).

    Returns a 2-D float array (max_H, max_W); on a placement overflow the
    canvas is returned as-is (best effort, file name printed).
    """
    image_augmented = np.ones((max_H, max_W))
    image = Image.open(os.path.join(folder, filename)).convert('RGB')
    # PIL's Image.size is (width, height).
    w_ori, h_ori = image.size
    rotate_ind = aug_type % 3
    scale_ind = aug_type // 3

    # Invert before rotating so the border pixels introduced by
    # rotate(expand=True) (filled with black) become background after the
    # second invert restores the original polarity.
    image = ImageOps.invert(image)
    if rotate_ind == 1:
        image = image.rotate(2, expand=True)
    elif rotate_ind == 2:
        image = image.rotate(-2, expand=True)
    image = ImageOps.invert(image)

    # NOTE(review): Image.size is (width, height), so h here actually holds
    # the width and w the height.  The naming is misleading, but behavior is
    # consistent because resize() below also expects (width, height).
    h, w = image.size
    if scale_ind == 1:
        h, w = int(np.floor(h * 0.98)), int(np.floor(w * 0.98))
        image = image.resize((h, w))
    elif scale_ind == 2:
        h, w = int(np.floor(h * 0.96)), int(np.floor(w * 0.96))
        image = image.resize((h, w))

    # Greyscale in [0, 1]; np.shape gives (rows, cols), so from here h and w
    # really are height and width.
    image = rgb2grey(np.array(image) / 255)
    h, w = np.shape(image)
    # Placement offset based on the ORIGINAL image size, with a 10-px margin.
    # NOTE(review): both axes use (aug_type % 3); compare augmentation_simple,
    # which mixes % 2 and % 3 — confirm whether that asymmetry was intended.
    stride_0, stride_1 = (max_H - 10 - h_ori) // 2, (max_W - 10 - w_ori) // 2
    offset = ((aug_type % 3) * stride_0, (aug_type % 3) * stride_1)
    try:
        image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    except ValueError:
        # Pasted region fell outside the canvas (shape mismatch): keep the
        # blank canvas and log the offending file rather than crash.
        print(filename)
    return image_augmented
|
|
|
|
|
class ShorthandGenerationDataset(Dataset):
    """Teacher-forcing dataset for shorthand word images.

    Each file's label is its name without the '.png' extension, wrapped in
    '+' (start) and '#' (end) markers.  A label sequence of length L yields
    L instances — contexts of length 1..L, each predicting the next
    character — replicated once per augmentation variant.  Instances are
    ordered by context length so a batch shares a single context length.
    """

    def __init__(self, file_list, max_H, max_W, max_label_leng, aug_types, channels=1):
        """
        Args:
            file_list: image file names ('<label>.png').
            max_H, max_W: canvas size passed to augmentation_nine.
            max_label_leng: maximum label length including '+' and '#'.
            aug_types: number of augmentation variants per instance.
            channels: image channel count (images are greyscale, so 1).
        """
        self.file_list = file_list
        self.H, self.W = max_H, max_W
        self.channels = channels
        # '+' = start-of-sequence marker, '#' = end-of-sequence marker.
        self.vocabulary = 'abcdefghijklmnopqrstuvwxyz+#'
        self.dict_c2i = {char: idx for idx, char in enumerate(self.vocabulary)}
        self.max_label_length = max_label_leng
        self.max_context_length = self.max_label_length - 1
        self.aug_types = aug_types
        # Bucket instances by context length: 1 .. max_context_length inclusive.
        self.instance_indices_by_length = {i: [] for i in range(1, self.max_context_length + 1)}

        for file in file_list:
            # File name minus '.png' is the label; wrap with the markers.
            seq = '+' + file[:-4] + '#'
            max_context_len = len(seq) - 1
            for length in range(1, max_context_len + 1):
                for aug in range(self.aug_types):
                    self.instance_indices_by_length[length].append([seq, aug, length])

        # _cumulative[k] = number of instances with context length <= k + 1,
        # so __getitem__ can locate a flat index with bisect instead of
        # re-summing bucket sizes on every call.
        self._cumulative = []
        running = 0
        for length in range(1, self.max_context_length + 1):
            running += len(self.instance_indices_by_length[length])
            self._cumulative.append(running)
        # BUG FIX: the original computed total_size over
        # range(1, max_context_length), excluding the last bucket that the
        # loop above fills — instances whose context spans the full
        # max_context_length were never served.  Include every bucket.
        self.total_size = running

    def __len__(self):
        """Total number of (image, context, target) instances."""
        return self.total_size

    def __getitem__(self, idx):
        """Return (image [1,H,W] float32, context indices long, target long)."""
        # Locate the bucket (context length) containing flat index `idx`.
        bucket = bisect.bisect_right(self._cumulative, idx)
        context_length = bucket + 1
        index_in_bucket = idx - (self._cumulative[bucket - 1] if bucket > 0 else 0)

        seq, augmentation_type, instance_context_length = \
            self.instance_indices_by_length[context_length][index_in_bucket]

        # Strip the '+'/'#' markers to recover the file name.
        file_name = seq[1:-1] + '.png'
        img = augmentation_nine(file_name, augmentation_type, self.H, self.W)
        img = np.expand_dims(img, axis=0)  # add leading channel dimension

        # Context = first `instance_context_length` chars; target = next char.
        x_context = np.array([self.dict_c2i[char] for char in seq[:instance_context_length]])
        y = self.dict_c2i[seq[instance_context_length]]

        return (torch.tensor(img, dtype=torch.float32),
                torch.tensor(x_context, dtype=torch.long),
                torch.tensor(y, dtype=torch.long))