| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from pathlib import PurePath |
| from typing import Optional, Callable, Sequence, Tuple |
|
|
| import pytorch_lightning as pl |
| from torch.utils.data import DataLoader |
| from torchvision import transforms as T |
|
|
| from .dataset import build_tree_dataset, LmdbDataset |
|
|
|
|
class SceneTextDataModule(pl.LightningDataModule):
    """Lightning data module that wires up train/val/test loaders for scene-text data.

    Training and validation datasets are built lazily (on first access) with
    ``build_tree_dataset`` from ``<root_dir>/train/<train_dir>`` and
    ``<root_dir>/val``. Test loaders are built on demand, one ``LmdbDataset``
    per requested subset name under ``<root_dir>/test``.
    """

    # Names of the test subsets; each is used as a sub-directory of
    # ``<root_dir>/test`` by ``test_dataloaders``.
    TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80')
    TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80')
    TEST_NEW = ('ArT', 'COCOv1.4', 'Uber')
    # De-duplicated union of the groups above. ``dict.fromkeys`` keeps the
    # first-seen order, so the tuple is identical on every run (a ``set`` of
    # strings iterates in an order that depends on hash randomization).
    TEST_ALL = tuple(dict.fromkeys(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW))

    def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int,
                 charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool,
                 remove_whitespace: bool = True, normalize_unicode: bool = True,
                 min_image_dim: int = 0, rotation: int = 0, collate_fn: Optional[Callable] = None):
        """Store the configuration; no dataset is opened here.

        Args:
            root_dir: Base directory containing the ``train``, ``val`` and ``test`` sub-directories.
            train_dir: Sub-directory of ``<root_dir>/train`` to build the training set from.
            img_size: Target image size handed to ``T.Resize``; stored as a tuple.
            max_label_length: Forwarded to the dataset constructors.
            charset_train: Charset forwarded to the training dataset.
            charset_test: Charset forwarded to the validation and test datasets.
            batch_size: Batch size used by every loader.
            num_workers: Number of worker processes used by every loader.
            augment: Whether the training transform includes ``rand_augment_transform()``.
            remove_whitespace: Forwarded to the dataset constructors.
            normalize_unicode: Forwarded to the dataset constructors.
            min_image_dim: Forwarded to the dataset constructors.
            rotation: Rotation angle applied to test images only; ``0`` disables it.
            collate_fn: Optional collate function passed to every ``DataLoader``.
        """
        super().__init__()
        self.root_dir = root_dir
        self.train_dir = train_dir
        self.img_size = tuple(img_size)
        self.max_label_length = max_label_length
        self.charset_train = charset_train
        self.charset_test = charset_test
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.augment = augment
        self.remove_whitespace = remove_whitespace
        self.normalize_unicode = normalize_unicode
        self.min_image_dim = min_image_dim
        self.rotation = rotation
        self.collate_fn = collate_fn
        # Lazily-built caches, filled by the ``train_dataset`` / ``val_dataset`` properties.
        self._train_dataset = None
        self._val_dataset = None

    @staticmethod
    def get_transform(img_size: Tuple[int, ...], augment: bool = False, rotation: int = 0):
        """Build the image preprocessing pipeline.

        Args:
            img_size: Target size passed to ``T.Resize``.
            augment: If true, prepend ``rand_augment_transform()``.
            rotation: If non-zero, rotate each image by this angle via
                ``img.rotate(rotation, expand=True)`` before resizing.

        Returns:
            A ``T.Compose`` applying, in order: optional augmentation, optional
            rotation, bicubic resize, ``ToTensor`` and ``Normalize(0.5, 0.5)``.
        """
        transforms = []
        if augment:
            # Imported lazily so the augmentation module is only loaded when needed.
            from .augment import rand_augment_transform
            transforms.append(rand_augment_transform())
        if rotation:
            # ``methodcaller`` performs the same ``img.rotate(rotation, expand=True)``
            # call a lambda would, but it can be pickled, which DataLoader workers
            # need when they are started with the "spawn" method.
            from operator import methodcaller
            transforms.append(methodcaller('rotate', rotation, expand=True))
        transforms.extend([
            T.Resize(img_size, T.InterpolationMode.BICUBIC),
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])
        return T.Compose(transforms)

    @property
    def train_dataset(self):
        """Training dataset, built on first access and cached afterwards."""
        if self._train_dataset is None:
            # Rotation is deliberately not passed here; only ``augment`` applies to training.
            transform = self.get_transform(self.img_size, self.augment)
            root = PurePath(self.root_dir, 'train', self.train_dir)
            self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length,
                                                     self.min_image_dim, self.remove_whitespace,
                                                     self.normalize_unicode, transform=transform)
        return self._train_dataset

    @property
    def val_dataset(self):
        """Validation dataset, built on first access and cached afterwards."""
        if self._val_dataset is None:
            # Validation uses the plain transform: no augmentation, no rotation.
            transform = self.get_transform(self.img_size)
            root = PurePath(self.root_dir, 'val')
            self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length,
                                                   self.min_image_dim, self.remove_whitespace,
                                                   self.normalize_unicode, transform=transform)
        return self._val_dataset

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over ``train_dataset``."""
        # ``persistent_workers`` is only enabled when worker processes are actually used.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over ``val_dataset``."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          num_workers=self.num_workers, persistent_workers=self.num_workers > 0,
                          pin_memory=True, collate_fn=self.collate_fn)

    def test_dataloaders(self, subset):
        """Build one test ``DataLoader`` per subset name.

        Args:
            subset: Iterable of subset names (e.g. ``TEST_BENCHMARK``); each name
                is used as a sub-directory of ``<root_dir>/test``.

        Returns:
            A dict mapping each subset name to its ``DataLoader``. Datasets are
            created anew on every call (nothing is cached).
        """
        # Test images are the only ones that receive the configured rotation.
        transform = self.get_transform(self.img_size, rotation=self.rotation)
        root = PurePath(self.root_dir, 'test')
        datasets = {s: LmdbDataset(str(root / s), self.charset_test, self.max_label_length,
                                   self.min_image_dim, self.remove_whitespace, self.normalize_unicode,
                                   transform=transform) for s in subset}
        return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers,
                              pin_memory=True, collate_fn=self.collate_fn)
                for k, v in datasets.items()}
|
|