import os
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image, ImageOps
import matplotlib.image as mpimg  # needed by augmentation_simple
from config import CONFIG

def rgb2grey(rgb):
    # Luminance conversion using the ITU-R BT.601 weights.
    return np.dot(rgb[..., :3], [0.299, 0.587, 0.114])

def data_split():
    config = CONFIG()
    folder = config.data_folder
    val_proportion = float(config.val_proportion)

    files = sorted(os.listdir(folder))  # sorted for a deterministic split
    period = int(np.round(1 / val_proportion))

    # Interleaved split: every period-th file goes to train, the next to
    # val, the one after that to test. Note that each subset receives
    # only 1/period of the files, so some files are unused when period > 3.
    train_files = files[::period]
    val_files = files[1::period]
    test_files = files[2::period]
    max_H, max_W, max_seq_length = 256, 256, 100  # Example values

    return train_files, val_files, test_files, max_H, max_W, max_seq_length
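The interleaved split above can be sketched in isolation on a toy list of hypothetical filenames; this is only an illustration of the slicing behavior, not part of the pipeline:

```python
# Demonstrate the interleaved split: with val_proportion = 0.25 the
# period is 4, so every 4th file goes to train, the next to val, the
# next to test, and the 4th in each group is unused.
files = [f"{i}.png" for i in range(8)]  # hypothetical filenames
period = 4
train_files = files[::period]
val_files = files[1::period]
test_files = files[2::period]
print(train_files)  # ['0.png', '4.png']
print(val_files)    # ['1.png', '5.png']
print(test_files)   # ['2.png', '6.png']
```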

def augmentation_simple(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    # Read the image, convert to greyscale, and paste it at an
    # aug_type-dependent offset inside a white (ones) canvas of
    # size max_H x max_W.
    image = rgb2grey(mpimg.imread(os.path.join(folder, filename)))
    image_augmented = np.ones((max_H, max_W))
    h, w = np.shape(image)
    stride_0, stride_1 = max_H - h, (max_W - w) // 2
    offset = ((aug_type % 2) * stride_0, (aug_type % 3) * stride_1)
    image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    return image_augmented

def augmentation_nine(filename, aug_type, max_H, max_W, folder=CONFIG.data_folder):
    # Nine augmentations: three rotations (none, +2, -2 degrees)
    # crossed with three scales (1.0, 0.98, 0.96).
    image_augmented = np.ones((max_H, max_W))
    image = Image.open(os.path.join(folder, filename)).convert('RGB')
    w_ori, h_ori = image.size
    rotate_ind = aug_type % 3
    scale_ind = aug_type // 3

    # Invert before rotating so the border introduced by expand=True is
    # black, then invert back so the padding ends up white.
    image = ImageOps.invert(image)
    if rotate_ind == 1:
        image = image.rotate(2, expand=True)
    elif rotate_ind == 2:
        image = image.rotate(-2, expand=True)
    image = ImageOps.invert(image)

    w, h = image.size  # PIL reports (width, height)
    if scale_ind == 1:
        w, h = int(np.floor(w * 0.98)), int(np.floor(h * 0.98))
        image = image.resize((w, h))
    elif scale_ind == 2:
        w, h = int(np.floor(w * 0.96)), int(np.floor(h * 0.96))
        image = image.resize((w, h))

    # Paste the transformed image at an aug_type-dependent offset inside
    # the white canvas.
    image = rgb2grey(np.array(image) / 255)
    h, w = np.shape(image)
    stride_0, stride_1 = (max_H - 10 - h_ori) // 2, (max_W - 10 - w_ori) // 2
    offset = ((aug_type % 3) * stride_0, (aug_type % 3) * stride_1)
    try:
        image_augmented[offset[0]: h + offset[0], offset[1]: w + offset[1]] = image
    except ValueError:
        print(f'augmented image does not fit the canvas: {filename}')
    return image_augmented

class ShorthandGenerationDataset(Dataset):
    def __init__(self, file_list, max_H, max_W, max_label_length, aug_types, channels=1):
        self.file_list = file_list
        self.H, self.W = max_H, max_W
        self.channels = channels
        self.vocabulary = 'abcdefghijklmnopqrstuvwxyz+#'
        self.dict_c2i = {char: idx for idx, char in enumerate(self.vocabulary)}
        self.max_label_length = max_label_length
        self.max_context_length = self.max_label_length - 1
        self.aug_types = aug_types
        # Bucket instances by context length: each entry is
        # [full sequence, augmentation index, context length].
        self.instance_indices_by_length = {i: [] for i in range(1, self.max_context_length + 1)}

        for file in file_list:
            # The label is the filename stem wrapped in start ('+') and
            # end ('#') tokens.
            seq = '+' + file[:-4] + '#'
            max_context_len = len(seq) - 1
            for length in range(1, max_context_len + 1):
                for aug in range(self.aug_types):
                    self.instance_indices_by_length[length].append([seq, aug, length])

        # Range must include max_context_length itself, otherwise the
        # last bucket is dropped and __len__ undercounts.
        self.total_size = sum(len(self.instance_indices_by_length[i])
                              for i in range(1, self.max_context_length + 1))

    def __len__(self):
        return self.total_size

    def __getitem__(self, idx):
        # Walk cumulative bucket sizes to find which context length the
        # flat index falls into.
        context_length = 1
        while sum(len(self.instance_indices_by_length[length]) for length in range(1, context_length + 1)) <= idx:
            context_length += 1

        # Offset of this instance within its bucket.
        index_in_bucket = idx - sum(len(self.instance_indices_by_length[length]) for length in range(1, context_length))

        seq, augmentation_type, instance_context_length = self.instance_indices_by_length[context_length][index_in_bucket]

        file_name = seq[1:-1] + '.png'
        img = augmentation_nine(file_name, augmentation_type, self.H, self.W)
        img = np.expand_dims(img, axis=0)  # Add channel dimension

        # The context characters are the input; the next character is the target.
        x_context = np.array([self.dict_c2i[char] for char in seq[:instance_context_length]])
        y = self.dict_c2i[seq[instance_context_length]]

        return torch.tensor(img, dtype=torch.float32), torch.tensor(x_context, dtype=torch.long), torch.tensor(y, dtype=torch.long)
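The cumulative-sum walk in `__getitem__` maps a flat dataset index to a (context length, in-bucket offset) pair. A minimal self-contained sketch with hypothetical bucket sizes:

```python
# Map a flat index to (bucket, offset) over length-keyed buckets,
# mirroring the cumulative-sum walk in __getitem__.
buckets = {1: 3, 2: 5, 3: 2}  # hypothetical: context_length -> bucket size

def locate(idx):
    length = 1
    # Advance while the cumulative size up to `length` is still <= idx.
    while sum(buckets[l] for l in range(1, length + 1)) <= idx:
        length += 1
    # Offset within the located bucket.
    offset = idx - sum(buckets[l] for l in range(1, length))
    return length, offset

print(locate(0))  # (1, 0): first instance of the length-1 bucket
print(locate(3))  # (2, 0): first instance of the length-2 bucket
print(locate(7))  # (2, 4): last instance of the length-2 bucket
print(locate(8))  # (3, 0): first instance of the length-3 bucket
```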