import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps
import matplotlib.image as mpimg
from config import CONFIG
def rgb2grey(rgb):
    # Convert an RGB array to greyscale using the ITU-R BT.601 luma weights.
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])
def data_split():
    # Deterministically split the filenames in the data folder into train/val/test subsets.
    config = CONFIG()
    folder = config.data_folder
    val_proportion = float(config.val_proportion)
    files = os.listdir(folder)
    period = int(np.round(1 / val_proportion))
    # Interleaved split: consecutive strided slices of the file list.
    train_files = files[::period]
    val_files = files[1::period]
    test_files = files[2::period]
    max_H, max_W, max_seq_length = 256, 256, 100  # placeholder canvas size and max label length
    return train_files, val_files, test_files, max_H, max_W, max_seq_length
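# The canvas size and max label length above are placeholders; a minimal sketch of
# deriving them from the data itself follows. _scan_data_dimensions is a hypothetical
# helper (an assumption about the intended behaviour, not part of the original pipeline).
def _scan_data_dimensions(folder):
    # Scan every image once for the largest height/width and the longest
    # filename-derived label, counting the '+' start and '#' end tokens.
    max_h, max_w, max_len = 0, 0, 0
    for name in os.listdir(folder):
        with Image.open(os.path.join(folder, name)) as im:
            w, h = im.size
        max_h, max_w = max(max_h, h), max(max_w, w)
        max_len = max(max_len, len(os.path.splitext(name)[0]) + 2)
    return max_h, max_w, max_len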
def augmentation_simple(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    # Translation-only augmentation: paste the greyscale image onto a white
    # max_H x max_W canvas at an offset determined by aug_type.
    image = rgb2grey(mpimg.imread(os.path.join(folder, filename)))
    image_augmented = np.ones((max_H, max_W))
    h, w = np.shape(image)
    stride_0, stride_1 = max_H - h, (max_W - w) // 2
    offset = ((aug_type % 2) * stride_0, (aug_type % 3) * stride_1)
    image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    return image_augmented
def augmentation_nine(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    # Nine augmentation variants: aug_type % 3 selects a small rotation and
    # aug_type // 3 selects a slight downscale before pasting onto a white canvas.
    image_augmented = np.ones((max_H, max_W))
    image = Image.open(os.path.join(folder, filename)).convert('RGB')
    w_ori, h_ori = image.size
    rotate_ind = aug_type % 3
    scale_ind = aug_type // 3
    # Invert before rotating so the expanded border ends up white after re-inversion.
    image = ImageOps.invert(image)
    if rotate_ind == 1:
        image = image.rotate(2, expand=True)
    elif rotate_ind == 2:
        image = image.rotate(-2, expand=True)
    image = ImageOps.invert(image)
    h, w = image.size  # PIL returns (width, height); the names are swapped but used consistently below
    if scale_ind == 1:
        h, w = int(np.floor(h * 0.98)), int(np.floor(w * 0.98))
        image = image.resize((h, w))
    elif scale_ind == 2:
        h, w = int(np.floor(h * 0.96)), int(np.floor(w * 0.96))
        image = image.resize((h, w))
    image = rgb2grey(np.array(image) / 255)
    h, w = np.shape(image)
    stride_0, stride_1 = (max_H - 10 - h_ori) // 2, (max_W - 10 - w_ori) // 2
    offset = ((aug_type % 3) * stride_0, (aug_type % 3) * stride_1)
    try:
        image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    except ValueError:
        # The augmented image did not fit on the canvas; keep the blank canvas and log the file.
        print(filename)
    return image_augmented
class ShorthandGenerationDataset(Dataset):
    def __init__(self, file_list, max_H, max_W, max_label_leng, aug_types, channels=1):
        self.file_list = file_list
        self.H, self.W = max_H, max_W
        self.channels = channels
        self.vocabulary = 'abcdefghijklmnopqrstuvwxyz+#'  # '+' and '#' mark sequence start/end
        self.dict_c2i = {char: idx for idx, char in enumerate(self.vocabulary)}
        self.max_label_length = max_label_leng
        self.max_context_length = self.max_label_length - 1
        self.aug_types = aug_types
        # One bucket of (sequence, augmentation, context length) triples per context length.
        self.instance_indices_by_length = {i: [] for i in range(1, self.max_context_length + 1)}
        for file in file_list:
            seq = '+' + file[:-4] + '#'  # the label is the filename without its extension
            max_context_len = len(seq) - 1
            for length in range(1, max_context_len + 1):
                for aug in range(self.aug_types):
                    self.instance_indices_by_length[length].append([seq, aug, length])
        # Count every bucket, including the one for the maximum context length.
        self.total_size = sum(len(self.instance_indices_by_length[i]) for i in range(1, self.max_context_length + 1))
    def __len__(self):
        return self.total_size
    def __getitem__(self, idx):
        # Find the context-length bucket that contains the flat index idx.
        context_length = 1
        while sum(len(self.instance_indices_by_length[length]) for length in range(1, context_length + 1)) <= idx:
            context_length += 1
        num_batch_in_length = idx - sum(len(self.instance_indices_by_length[length]) for length in range(1, context_length))
        starting_index = num_batch_in_length
        seq, augmentation_type, instance_context_length = self.instance_indices_by_length[context_length][starting_index]
        file_name = seq[1:-1] + '.png'
        img = augmentation_nine(file_name, augmentation_type, self.H, self.W)
        img = np.expand_dims(img, axis=0)  # add the channel dimension
        x_context = np.array([self.dict_c2i[char] for char in seq[:instance_context_length]])
        y = self.dict_c2i[seq[instance_context_length]]
        return torch.tensor(img, dtype=torch.float32), torch.tensor(x_context, dtype=torch.long), torch.tensor(y, dtype=torch.long)
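# Minimal usage sketch, assuming labels are encoded as lowercase .png filenames in
# CONFIG.data_folder and the nine-way augmentation above; names here are illustrative.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    train_files, val_files, test_files, max_H, max_W, max_seq_length = data_split()
    dataset = ShorthandGenerationDataset(train_files, max_H, max_W, max_seq_length, aug_types=9)
    # x_context length varies per sample, so batch_size=1 avoids needing a padding collate_fn.
    loader = DataLoader(dataset, batch_size=1, shuffle=True)
    img, x_context, y = next(iter(loader))
    print(img.shape, x_context.shape, y.shape)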