File size: 3,234 Bytes
06142a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torchvision
import json
import sys

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from src.Text_Recognization.prepare_dataset import *

# Data augmentation / preprocessing pipelines, keyed by split name.
# Both pipelines convert the image to a tensor, resize it to (100, 400)
# (height, width), and finish with per-channel normalization using the
# standard ImageNet mean/std values. The "train" pipeline additionally applies
# random photometric and geometric augmentation between resize and normalize.
_TARGET_SIZE = (100, 400)
_CHANNEL_MEAN = (0.485, 0.456, 0.406)
_CHANNEL_STD = (0.229, 0.224, 0.225)


def _normalize():
    """Return a fresh Normalize transform built from the shared channel statistics."""
    return transforms.Normalize(mean=list(_CHANNEL_MEAN), std=list(_CHANNEL_STD))


_train_steps = [
    transforms.ToTensor(),
    transforms.Resize(_TARGET_SIZE),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomAffine(degrees=1, shear=1),
    transforms.RandomPerspective(distortion_scale=0.3, p=0.5),
    transforms.RandomRotation(degrees=15),
    _normalize(),
]

# Evaluation pipeline: deterministic preprocessing only, no augmentation.
_val_steps = [
    transforms.ToTensor(),
    transforms.Resize(_TARGET_SIZE),
    _normalize(),
]

data_transforms = {
    "train": transforms.Compose(_train_steps),
    "val": transforms.Compose(_val_steps),
}

def load_json_config(config_path):
    """Load and return the parsed contents of a JSON config file.

    Args:
        config_path: Path (``str`` or ``os.PathLike``) to the JSON file.

    Returns:
        The decoded JSON value (a ``dict`` when the file's top level is an
        object).

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Pin the encoding so the
    # result does not depend on the platform's locale default (e.g. cp1252 on
    # Windows), which would garble any non-ASCII values in the config.
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)

# Dataloader
class STRDataset(Dataset):
    """Map-style dataset yielding ``(image, label_encoded, length)`` samples.

    Each sample is read from disk with OpenCV, converted from BGR to RGB,
    optionally passed through ``transforms``, and paired with the output of
    ``encode`` for the corresponding label.

    Args:
        image_paths: Sequence of image file paths, aligned index-by-index
            with ``labels``.
        labels: Sequence of labels (presumably text strings — confirm),
            one per image.
        char_to_idx: Character-to-index mapping forwarded to ``encode``.
        transforms: Optional callable applied to the RGB image array.
    """

    def __init__(self, image_paths, labels, char_to_idx, transforms=None):
        self.image_paths = image_paths
        self.labels = labels
        self.char_to_idx = char_to_idx
        # The parameter name shadows the torchvision ``transforms`` module only
        # inside this method; the stored attribute is the pipeline callable.
        self.transforms = transforms

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]
        # NOTE(review): ``cv2`` is not imported in this file's visible header;
        # it is presumably re-exported by the star import from
        # ``prepare_dataset`` — confirm.
        image = cv2.imread(path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail loudly with the offending path rather than letting cvtColor
        # raise an opaque error on a None input.
        if image is None:
            raise OSError(f"Failed to read image (missing or unsupported): {path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Explicit None check: the truthiness of an arbitrary transform object
        # is not a reliable "is a transform configured" test.
        if self.transforms is not None:
            image = self.transforms(image)

        # NOTE(review): the full label list is passed as the third argument to
        # encode(); presumably used to derive a padding length — confirm
        # against encode()'s signature in prepare_dataset.
        label_encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)

        return image, label_encoded, length
    
def get_dataloader(
    *,
    root_path="Dataset",
    config_path="src/config.json",
    val_size=0.1,
    test_size=0.1,
    random_state=42,
):
    """Build train/val/test DataLoaders for the CRNN text-recognition model.

    Samples under ``root_path`` are split with two successive
    ``train_test_split`` calls:
    - first, ``val_size`` of *all* samples is held out for validation;
    - then ``test_size`` of the *remaining* samples is held out for testing.

    With the defaults this is roughly 81% train / 10% val / 9% test.

    Args:
        root_path: Dataset root passed to ``get_imagepaths_and_labels`` and
            ``build_vocab``.
        config_path: JSON config file; ``config["CRNN"]["batch_size"]`` sets
            the batch size of every loader. Relative paths resolve against the
            current working directory.
        val_size: Fraction of all samples used for validation.
        test_size: Fraction of the post-validation remainder used for testing.
        random_state: Seed for both splits, so the partition is reproducible.

    Returns:
        ``(train_loader, val_loader, test_loader)``. Only the training loader
        shuffles; the test loader uses the non-augmenting "val" transforms.
    """
    image_paths, labels = get_imagepaths_and_labels(root_path)
    # The index-to-character mapping is not needed to build the loaders.
    char_to_idx, _ = build_vocab(root_path)

    batch_size = load_json_config(config_path)["CRNN"]["batch_size"]

    X_train, X_val, y_train, y_val = train_test_split(
        image_paths, labels, test_size=val_size, random_state=random_state, shuffle=True
    )
    # test_size is relative to what is left after the validation hold-out.
    X_train, X_test, y_train, y_test = train_test_split(
        X_train, y_train, test_size=test_size, random_state=random_state, shuffle=True
    )

    def _make_loader(paths, targets, split, shuffle):
        """Wrap one split in an STRDataset using the named transform pipeline."""
        dataset = STRDataset(paths, targets, char_to_idx, transforms=data_transforms[split])
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    # Evaluation loaders keep a fixed order so val/test passes are
    # deterministic and comparable across epochs; shuffling only helps training.
    train_loader = _make_loader(X_train, y_train, "train", shuffle=True)
    val_loader = _make_loader(X_val, y_val, "val", shuffle=False)
    test_loader = _make_loader(X_test, y_test, "val", shuffle=False)

    return train_loader, val_loader, test_loader