TungDuong's picture
source code
06142a4 verified
import torch
import torchvision
import json
import sys
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from src.Text_Recognization.prepare_dataset import *
# data augmentation
data_transforms = {
"train": transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((100, 400)),
transforms.ColorJitter(
brightness=0.5,
contrast=0.5,
saturation=0.5
),
transforms.GaussianBlur(3),
transforms.RandomAffine(
degrees=1,
shear=1
),
transforms.RandomPerspective(
distortion_scale=0.3,
p=0.5
),
transforms.RandomRotation(degrees=15),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
),
"val": transforms.Compose(
[
transforms.ToTensor(),
transforms.Resize((100, 400)),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]
)
}
def load_json_config(config_path):
with open(config_path, "r") as f:
config = json.load(f)
return config
# Dataloader
class STRDataset(Dataset):
def __init__(self, image_paths, labels, char_to_idx, transforms=None):
self.image_paths = image_paths
self.labels = labels
self.char_to_idx = char_to_idx
self.transforms= transforms
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
image = cv2.imread(self.image_paths[idx])
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if self.transforms:
image = self.transforms(image)
label_encoded, length = encode(self.labels[idx], self.char_to_idx, self.labels)
return image, label_encoded, length
def get_dataloader():
val_size = 0.1
test_size = 0.1
root_path = 'Dataset'
config_path = 'src/config.json'
# get image paths and labels
image_paths, labels = get_imagepaths_and_labels(root_path)
char_to_idx, idx_to_char = build_vocab(root_path)
config = load_json_config(config_path)
X_train, X_val, y_train, y_val = train_test_split(image_paths, labels, test_size=val_size, random_state=42, shuffle=True)
X_train, X_test, y_train, y_test = train_test_split(X_train, y_train, test_size=test_size, random_state=42, shuffle=True)
train_dataset = STRDataset(X_train, y_train, char_to_idx, transforms=data_transforms['train'])
train_loader = DataLoader(train_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)
val_dataset = STRDataset(X_val, y_val, char_to_idx, transforms=data_transforms['val'])
val_loader = DataLoader(val_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)
test_dataset = STRDataset(X_test, y_test, char_to_idx, transforms=data_transforms['val'])
test_loader = DataLoader(test_dataset, batch_size=config['CRNN']['batch_size'], shuffle=True)
return train_loader, val_loader, test_loader