| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import glob |
| import io |
| import logging |
| import unicodedata |
| from pathlib import Path, PurePath |
| from typing import Callable, Optional, Union |
|
|
| import lmdb |
| from PIL import Image |
| from torch.utils.data import Dataset, ConcatDataset |
|
|
| from strhub.data.utils import CharsetAdapter |
|
|
| log = logging.getLogger(__name__) |
|
|
|
|
def build_tree_dataset(root: Union[PurePath, str], *args, **kwargs):
    """Build a ``ConcatDataset`` from every LMDB database found under ``root``.

    The directory tree is searched recursively; each directory containing a
    ``data.mdb`` file is treated as one LMDB dataset.

    Args:
        root: Directory to search for LMDB databases.
        *args: Positional arguments forwarded to each ``LmdbDataset``.
        **kwargs: Keyword arguments forwarded to each ``LmdbDataset``. A
            ``root`` key, if present, is dropped since each sub-dataset gets
            its own root discovered here.

    Returns:
        A ``ConcatDataset`` over all discovered ``LmdbDataset`` instances.
        Note ``ConcatDataset`` raises if no dataset is found.
    """
    # pop with a default instead of try/except KeyError — same effect, one line.
    kwargs.pop('root', None)
    root = Path(root).absolute()
    log.info(f'dataset root:\t{root}')
    datasets = []
    for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True):
        mdb = Path(mdb)
        ds_name = str(mdb.parent.relative_to(root))
        ds_root = str(mdb.parent.absolute())
        dataset = LmdbDataset(ds_root, *args, **kwargs)
        log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}')
        datasets.append(dataset)
    return ConcatDataset(datasets)
|
|
|
|
class LmdbDataset(Dataset):
    """Dataset interface to an LMDB database.

    It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned
    as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset.
    Labels are transformed according to the charset.
    """

    def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0,
                 remove_whitespace: bool = True, normalize_unicode: bool = True,
                 unlabelled: bool = False, transform: Optional[Callable] = None):
        """
        Args:
            root: Path of the directory containing the LMDB ``data.mdb`` file.
            charset: Characters to keep in the labels; all others are filtered out.
            max_label_len: Samples whose (whitespace-stripped) label is longer are dropped.
            min_image_dim: Drop samples whose image width or height is below this value;
                0 disables the image-size check.
            remove_whitespace: Strip all whitespace from labels before filtering.
            normalize_unicode: NFKD-normalize labels and discard non-ASCII characters.
            unlabelled: If True, skip label preprocessing; ``__getitem__`` returns the
                sample index as the label.
            transform: Optional callable applied to the PIL image in ``__getitem__``.
        """
        self._env = None  # LMDB environment, opened lazily via the `env` property
        self.root = root
        self.unlabelled = unlabelled
        self.transform = transform
        self.labels = []  # filtered labels, parallel to filtered_index_list
        self.filtered_index_list = []  # 1-based LMDB indices of the kept samples
        self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode,
                                                   max_label_len, min_image_dim)

    def __del__(self):
        # Close the LMDB environment if it was ever opened.
        if self._env is not None:
            self._env.close()
            self._env = None

    def _create_env(self):
        # Read-only, lock-free environment. NOTE(review): max_readers=1 + lock=False
        # presumably relies on each DataLoader worker opening its own environment.
        return lmdb.open(self.root, max_readers=1, readonly=True, create=False,
                         readahead=False, meminit=False, lock=False)

    @property
    def env(self):
        # Open lazily so the dataset can be constructed (and pickled to worker
        # processes) before any LMDB handle exists.
        if self._env is None:
            self._env = self._create_env()
        return self._env

    def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim):
        """Filter samples by label and image constraints; return the usable sample count."""
        charset_adapter = CharsetAdapter(charset)
        # Use a throwaway environment so `self._env` stays unset until first access.
        with self._create_env() as env, env.begin() as txn:
            num_samples = int(txn.get('num-samples'.encode()))
            if self.unlabelled:
                # No labels to filter; every sample is usable.
                return num_samples
            for index in range(num_samples):
                index += 1  # LMDB record keys are 1-indexed
                label_key = f'label-{index:09d}'.encode()
                label = txn.get(label_key).decode()

                if remove_whitespace:
                    label = ''.join(label.split())

                # Normalize unicode composites and keep only ASCII-compatible characters.
                if normalize_unicode:
                    label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode()

                # Filter by length before removing unsupported characters:
                # the original label might be too long even if the filtered one isn't.
                if len(label) > max_label_len:
                    continue
                label = charset_adapter(label)

                # Skip samples without any supported characters.
                if not label:
                    continue

                # Optionally skip images that are too small.
                if min_image_dim > 0:
                    img_key = f'image-{index:09d}'.encode()
                    buf = io.BytesIO(txn.get(img_key))
                    w, h = Image.open(buf).size
                    # BUG FIX: the original compared against `self.min_image_dim`,
                    # which __init__ never sets -> AttributeError whenever
                    # min_image_dim > 0. Use the local parameter instead.
                    if w < min_image_dim or h < min_image_dim:
                        continue
                self.labels.append(label)
                self.filtered_index_list.append(index)
        return len(self.labels)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        """Return ``(image, label)``; for unlabelled data the label is the index itself."""
        if self.unlabelled:
            label = index
        else:
            label = self.labels[index]
            # Map the dataset index to the original (1-based) LMDB record index.
            index = self.filtered_index_list[index]

        img_key = f'image-{index:09d}'.encode()
        with self.env.begin() as txn:
            imgbuf = txn.get(img_key)
        buf = io.BytesIO(imgbuf)
        img = Image.open(buf).convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, label
|
|