""" Paper: "UTRNet: High-Resolution Urdu Text Recognition In Printed Documents" presented at ICDAR 2023 Authors: Abdur Rahman, Arjun Ghosh, Chetan Arora GitHub Repository: https://github.com/abdur75648/UTRNet-High-Resolution-Urdu-Text-Recognition Project Website: https://abdur75648.github.io/UTRNet/ Copyright (c) 2023-present: This work is licensed under the Creative Commons Attribution-NonCommercial 4.0 International License (http://creativecommons.org/licenses/by-nc/4.0/) """ import os import sys import re import six import math import lmdb import torch import random import numpy as np from PIL import Image import torchvision.transforms as T #from torch._utils import _accumulate from torch.utils.data import Dataset, ConcatDataset, Subset def accumulate(iterable): total = 0 for value in iterable: total += value yield total class Batch_Balanced_Dataset(object): def __init__(self, opt, rand_aug = False): """ Modulate the data ratio in the batch. For example, when select_data is "MJ-ST" and batch_ratio is "0.5-0.5", the 50% of the batch is filled with MJ and the other 50% of the batch is filled with ST. """ log = open(f'./saved_models/{opt.exp_name}/log_dataset.txt', 'a',encoding="utf-8") dashed_line = '-' * 80 print(dashed_line) log.write(dashed_line + '\n') print(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}') log.write(f'dataset_root: {opt.train_data}\nopt.select_data: {opt.select_data}\nopt.batch_ratio: {opt.batch_ratio}\n') assert len(opt.select_data) == len(opt.batch_ratio) _AlignCollate = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD) self.data_loader_list = [] self.dataloader_iter_list = [] batch_size_list = [] Total_batch_size = 0 for selected_d, batch_ratio_d in zip(opt.select_data, opt.batch_ratio): _batch_size = max(round(opt.batch_size * float(batch_ratio_d)), 1) print(dashed_line) log.write(dashed_line + '\n') _dataset, _dataset_log = hierarchical_dataset(root=opt.train_data, opt=opt, select_data=[selected_d], rand_aug=rand_aug) total_number_dataset = len(_dataset) log.write(_dataset_log) """ The total number of data can be modified with opt.total_data_usage_ratio. ex) opt.total_data_usage_ratio = 1 indicates 100% usage, and 0.2 indicates 20% usage. See 4.2 section in our paper. """ number_dataset = int(total_number_dataset * float(opt.total_data_usage_ratio)) dataset_split = [number_dataset, total_number_dataset - number_dataset] indices = range(total_number_dataset) _dataset, _ = [Subset(_dataset, indices[offset - length:offset]) for offset, length in zip(_accumulate(dataset_split), dataset_split)] selected_d_log = f'num total samples of {selected_d}: {total_number_dataset} x {opt.total_data_usage_ratio} (total_data_usage_ratio) = {len(_dataset)}\n' selected_d_log += f'num samples of {selected_d} per batch: {opt.batch_size} x {float(batch_ratio_d)} (batch_ratio) = {_batch_size}' print(selected_d_log) log.write(selected_d_log + '\n') batch_size_list.append(str(_batch_size)) Total_batch_size += _batch_size _data_loader = torch.utils.data.DataLoader( _dataset, batch_size=_batch_size, shuffle=True, num_workers=0, collate_fn=_AlignCollate, pin_memory=True) self.data_loader_list.append(_data_loader) self.dataloader_iter_list.append(iter(_data_loader)) Total_batch_size_log = f'{dashed_line}\n' batch_size_sum = '+'.join(batch_size_list) Total_batch_size_log += f'Total_batch_size: {batch_size_sum} = {Total_batch_size}\n' Total_batch_size_log += f'{dashed_line}' opt.batch_size = Total_batch_size print(Total_batch_size_log) log.write(Total_batch_size_log + '\n') log.close() def get_batch(self): balanced_batch_images = [] balanced_batch_texts = [] for i, data_loader_iter in enumerate(self.dataloader_iter_list): try: image, text = data_loader_iter.next() balanced_batch_images.append(image) balanced_batch_texts += text except StopIteration: self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) image, text = self.dataloader_iter_list[i].next() balanced_batch_images.append(image) balanced_batch_texts += text except ValueError: pass balanced_batch_images = torch.cat(balanced_batch_images, 0) return balanced_batch_images, balanced_batch_texts def hierarchical_dataset(root, opt, select_data='/', rand_aug = False): """ select_data='/' contains all sub-directory of root directory """ dataset_list = [] dataset_log = f'dataset_root: {root}\t dataset: {select_data[0]}' # print(dataset_log) dataset_log += '\n' for dirpath, dirnames, filenames in os.walk(root+'/'): if not dirnames: select_flag = False for selected_d in select_data: if selected_d in dirpath: select_flag = True break if select_flag: dataset = LmdbDataset(dirpath, opt, rand_aug=rand_aug) sub_dataset_log = f'sub-directory:\t/{os.path.relpath(dirpath, root)}\t num samples: {len(dataset)}' # print(sub_dataset_log) dataset_log += f'{sub_dataset_log}\n' dataset_list.append(dataset) print( len(dataset_list) ) concatenated_dataset = ConcatDataset(dataset_list) return concatenated_dataset, dataset_log class LmdbDataset(Dataset): def __init__(self, root, opt,rand_aug=False, transform=None): self.root = root self.opt = opt self.rand_aug = rand_aug self.transform = transform self.env = lmdb.open(root, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) if not self.env: print('cannot create lmdb from %s' % (root)) sys.exit(0) with self.env.begin(write=False) as txn: nSamples = int(txn.get('num-samples'.encode())) self.nSamples = nSamples self.filtered_index_list = [] for index in range(self.nSamples): index += 1 # lmdb starts with 1 label_key = 'label-%09d'.encode() % index label = txn.get(label_key).decode('utf-8') if len(label) > self.opt.batch_max_length: # print(f'The length of the label is longer than max_length: length {len(label)}, {label} in dataset {self.root}') continue # By default, images containing characters which are not in opt.character are filtered. # You can add [UNK] token to `opt.character` in utils.py instead of this filtering out_of_char = f'[^{self.opt.character}]' if re.search(out_of_char, label): print ("This string contains a character not part of our dictionnary") continue self.filtered_index_list.append(index) self.nSamples = len(self.filtered_index_list) if self.transform is None: self.transform = [] if self.rand_aug: from augmentation import rand_augment_transform,salt_and_pepper_noise, random_border_crop, random_resize self.transform.append(rand_augment_transform()) self.transform.append(T.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.25)) if random.random()<0.25: self.transform.append(lambda img: salt_and_pepper_noise(img)) if random.random()<0.25: self.transform.append(lambda img: random_border_crop(img)) self.transform.append(T.RandomRotation(5)) if random.random()<0.25: self.transform.append(lambda img: random_resize(img)) self.transform = T.Compose(self.transform) def __len__(self): return self.nSamples def __getitem__(self, index): assert index <= len(self), 'index range error' index = self.filtered_index_list[index] with self.env.begin(write=False) as txn: label_key = 'label-%09d'.encode() % index label = txn.get(label_key).decode('utf-8') img_key = 'image-%09d'.encode() % index imgbuf = txn.get(img_key) buf = six.BytesIO() buf.write(imgbuf) buf.seek(0) try: if self.opt.rgb: img1 = Image.open(buf).convert('RGB') # for color image img = img1.transpose(Image.FLIP_LEFT_RIGHT) else: img1 = Image.open(buf).convert('L') img = img1.transpose(Image.FLIP_LEFT_RIGHT) except IOError: print(f'Corrupted image for {index}') # make dummy image and dummy label for corrupted image. if self.opt.rgb: img = Image.new('RGB', (self.opt.imgW, self.opt.imgH)) img = img.transpose(Image.FLIP_LEFT_RIGHT) else: img = Image.new('L', (self.opt.imgW, self.opt.imgH)) img = img.transpose(Image.FLIP_LEFT_RIGHT) label = '[dummy_label]' # We only train and evaluate on alphanumerics (or pre-defined character set in train.py) out_of_char = f'[^{self.opt.character}]' label = re.sub(out_of_char, '', label) if self.transform: img = self.transform(img) return (img, label) class ResizeNormalize(object): def __init__(self, size, interpolation=Image.BICUBIC): self.size = size self.interpolation = interpolation self.toTensor = T.ToTensor() def __call__(self, img): img = img.resize(self.size, self.interpolation) img = self.toTensor(img) img.sub_(0.5).div_(0.5) return img class NormalizePAD(object): def __init__(self, max_size, PAD_type='right'): self.toTensor = T.ToTensor() self.max_size = max_size self.max_width_half = math.floor(max_size[2] / 2) self.PAD_type = PAD_type def __call__(self, img): img = self.toTensor(img) img.sub_(0.5).div_(0.5) c, h, w = img.size() Pad_img = torch.FloatTensor(*self.max_size).fill_(0) Pad_img[:, :, :w] = img # right pad if self.max_size[2] != w: # add border Pad Pad_img[:, :, w:] = img[:, :, w - 1].unsqueeze(2).expand(c, h, self.max_size[2] - w) return Pad_img class AlignCollate(object): def __init__(self, imgH=32, imgW=100, keep_ratio_with_pad=True): self.imgH = imgH self.imgW = imgW self.keep_ratio_with_pad = keep_ratio_with_pad def __call__(self, batch): batch = filter(lambda x: x is not None, batch) images, labels = zip(*batch) if self.keep_ratio_with_pad: # same concept with 'Rosetta' paper resized_max_w = self.imgW input_channel = 3 if images[0].mode == 'RGB' else 1 transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) resized_images = [] for image in images: w, h = image.size ratio = w / float(h) if math.ceil(self.imgH * ratio) > self.imgW: resized_w = self.imgW else: resized_w = math.ceil(self.imgH * ratio) resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) resized_images.append(transform(resized_image)) # resized_image.save('./image_test/%d_test.jpg' % w) image_tensors = torch.cat([t.unsqueeze(0) for t in resized_images], 0) else: transform = ResizeNormalize((self.imgW, self.imgH)) image_tensors = [transform(image) for image in images] image_tensors = torch.cat([t.unsqueeze(0) for t in image_tensors], 0) return image_tensors, labels def tensor2im(image_tensor, imtype=np.uint8): image_numpy = image_tensor.cpu().float().numpy() if image_numpy.shape[0] == 1: image_numpy = np.tile(image_numpy, (3, 1, 1)) image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 return image_numpy.astype(imtype) def save_image(image_numpy, image_path): image_pil = Image.fromarray(image_numpy) image_pil.save(image_path)