| from typing import Any, Optional, Union |
| from pathlib import Path |
| import os |
| import io |
| import lmdb |
| import pickle |
| import gzip |
| import bz2 |
| import lzma |
| import shutil |
| from tqdm import tqdm |
| import pandas as pd |
| import numpy as np |
| from numpy import ndarray |
| import time |
| import torch |
| from torch import Tensor |
| from distutils.dir_util import copy_tree |
| from PIL import Image |
| from PIL import ImageFile |
|
|
| ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
|
| def _default_encode(data: Any, protocol: int) -> bytes: |
| return pickle.dumps(data, protocol=protocol) |
|
|
|
|
| def _ascii_encode(data: str) -> bytes: |
| return data.encode("ascii") |
|
|
|
|
| def _default_decode(data: bytes) -> Any: |
| return pickle.loads(data) |
|
|
|
|
| def _default_decompress(data: bytes) -> bytes: |
| return data |
|
|
|
|
| def _decompress(compression: Optional[str]): |
| if compression is None: |
| _decompress = _default_decompress |
| elif compression == "gzip": |
| _decompress = gzip.decompress |
| elif compression == "bz2": |
| _decompress = bz2.decompress |
| elif compression == "lzma": |
| _decompress = lzma.decompress |
| else: |
| raise ValueError(f"Unknown compression algorithm: {compression}") |
|
|
| return _decompress |
|
|
|
|
| class BaseLMDB(object): |
| _database = None |
| _protocol = None |
| _length = None |
|
|
| def __init__( |
| self, |
| path: Union[str, Path], |
| readahead: bool = False, |
| pre_open: bool = False, |
| compression: Optional[str] = None |
| ): |
| """ |
| Base class for LMDB-backed databases. |
| |
| :param path: Path to the database. |
| :param readahead: Enables the filesystem readahead mechanism. |
| :param pre_open: If set to True, the first iterations will be faster, but it will raise error when doing multi-gpu training. If set to False, the database will open when you will retrieve the first item. |
| """ |
| if not isinstance(path, str): |
| path = str(path) |
|
|
| self.path = path |
| self.readahead = readahead |
| self.pre_open = pre_open |
| self._decompress = _decompress(compression) |
| self._has_fetched_an_item = False |
|
|
| @property |
| def database(self): |
| if self._database is None: |
| self._database = lmdb.open( |
| path=self.path, |
| readonly=True, |
| readahead=self.readahead, |
| max_spare_txns=256, |
| lock=False, |
| ) |
| return self._database |
|
|
| @database.deleter |
| def database(self): |
| if self._database is not None: |
| self._database.close() |
| self._database = None |
|
|
| @property |
| def protocol(self): |
| """ |
| Read the pickle protocol contained in the database. |
| |
| :return: The set of available keys. |
| """ |
| if self._protocol is None: |
| self._protocol = self._get( |
| item="protocol", |
| encode_key=_ascii_encode, |
| decompress_value=_default_decompress, |
| decode_value=_default_decode, |
| ) |
| return self._protocol |
|
|
| @property |
| def keys(self): |
| """ |
| Read the keys contained in the database. |
| |
| :return: The set of available keys. |
| """ |
| protocol = self.protocol |
| keys = self._get( |
| item="keys", |
| encode_key=lambda key: _default_encode(key, protocol=protocol), |
| decompress_value=_default_decompress, |
| decode_value=_default_decode, |
| ) |
| return keys |
|
|
| def __len__(self): |
| """ |
| Returns the number of keys available in the database. |
| |
| :return: The number of keys. |
| """ |
| if self._length is None: |
| self._length = len(self.keys) |
| return self._length |
|
|
| def __getitem__(self, item): |
| """ |
| Retrieves an item or a list of items from the database. |
| |
| :param item: A key or a list of keys. |
| :return: A value or a list of values. |
| """ |
| self._has_fetched_an_item = True |
| if not isinstance(item, list): |
| item = self._get( |
| item=item, |
| encode_key=self._encode_key, |
| decompress_value=self._decompress_value, |
| decode_value=self._decode_value, |
| ) |
| else: |
| item = self._gets( |
| items=item, |
| encode_keys=self._encode_keys, |
| decompress_values=self._decompress_values, |
| decode_values=self._decode_values, |
| ) |
| return item |
|
|
| def _get(self, item, encode_key, decompress_value, decode_value): |
| """ |
| Instantiates a transaction and its associated cursor to fetch an item. |
| |
| :param item: A key. |
| :param encode_key: |
| :param decode_value: |
| :return: |
| """ |
| with self.database.begin() as txn: |
| with txn.cursor() as cursor: |
| item = self._fetch( |
| cursor=cursor, |
| key=item, |
| encode_key=encode_key, |
| decompress_value=decompress_value, |
| decode_value=decode_value, |
| ) |
| self._keep_database() |
| return item |
|
|
| def _gets(self, items, encode_keys, decompress_values, decode_values): |
| """ |
| Instantiates a transaction and its associated cursor to fetch a list of items. |
| |
| :param items: A list of keys. |
| :param encode_keys: |
| :param decode_values: |
| :return: |
| """ |
| with self.database.begin() as txn: |
| with txn.cursor() as cursor: |
| items = self._fetchs( |
| cursor=cursor, |
| keys=items, |
| encode_keys=encode_keys, |
| decompress_values=decompress_values, |
| decode_values=decode_values, |
| ) |
| self._keep_database() |
| return items |
|
|
| def _fetch(self, cursor, key, encode_key, decompress_value, decode_value): |
| """ |
| Retrieve a value given a key. |
| |
| :param cursor: |
| :param key: A key. |
| :param encode_key: |
| :param decode_value: |
| :return: A value. |
| """ |
| key = encode_key(key) |
| value = cursor.get(key) |
| value = decompress_value(value) |
| value = decode_value(value) |
| return value |
|
|
| def _fetchs(self, cursor, keys, encode_keys, decompress_values, decode_values): |
| """ |
| Retrieve a list of values given a list of keys. |
| |
| :param cursor: |
| :param keys: A list of keys. |
| :param encode_keys: |
| :param decode_values: |
| :return: A list of values. |
| """ |
| keys = encode_keys(keys) |
| _, values = list(zip(*cursor.getmulti(keys))) |
| values = decompress_values(values) |
| values = decode_values(values) |
| return values |
|
|
| def _encode_key(self, key: Any) -> bytes: |
| """ |
| Converts a key into a byte key. |
| |
| :param key: A key. |
| :return: A byte key. |
| """ |
| return pickle.dumps(key, protocol=self.protocol) |
|
|
| def _encode_keys(self, keys: list) -> list: |
| """ |
| Converts keys into byte keys. |
| |
| :param keys: A list of keys. |
| :return: A list of byte keys. |
| """ |
| return [self._encode_key(key=key) for key in keys] |
|
|
| def _decompress_value(self, value: bytes) -> bytes: |
| return self._decompress(value) |
|
|
| def _decompress_values(self, values: list) -> list: |
| return [self._decompress_value(value=value) for value in values] |
|
|
| def _decode_value(self, value: bytes) -> Any: |
| """ |
| Converts a byte value back into a value. |
| |
| :param value: A byte value. |
| :return: A value |
| """ |
| return pickle.loads(value) |
|
|
| def _decode_values(self, values: list) -> list: |
| """ |
| Converts bytes values back into values. |
| |
| :param values: A list of byte values. |
| :return: A list of values. |
| """ |
| return [self._decode_value(value=value) for value in values] |
|
|
| def _keep_database(self): |
| """ |
| Checks if the database must be deleted. |
| |
| :return: |
| """ |
| if not self.pre_open and not self._has_fetched_an_item: |
| del self.database |
|
|
| def __iter__(self): |
| """ |
| Provides an iterator over the keys when iterating over the database. |
| |
| :return: An iterator on the keys. |
| """ |
| return iter(self.keys) |
|
|
| def __del__(self): |
| """ |
| Closes the database properly. |
| """ |
| del self.database |
|
|
| @staticmethod |
| def write(data_lst, indir, outdir): |
| raise NotImplementedError |
|
|
|
|
| class PILlmdb(BaseLMDB): |
| def __init__( |
| self, |
| lmdb_dir: Union[str, Path], |
| image_list: Union[str, Path, pd.DataFrame]=None, |
| index_key='id', |
| **kwargs |
| ): |
| super().__init__(path=lmdb_dir, **kwargs) |
| if image_list is None: |
| self.ids = list(range(len(self.keys))) |
| self.labels = list(range(len(self.ids))) |
| else: |
| df = pd.read_csv(str(image_list)) |
| assert index_key in df, f'[PILlmdb] Error! {image_list} must have id keys.' |
| self.ids = df[index_key].tolist() |
| assert max(self.ids) < len(self.keys) |
| if 'label' in df: |
| self.labels = df['label'].tolist() |
| else: |
| keys = [key for key in df if (key!=index_key and type(df[key][0]) in [int, np.int64])] |
| |
| self.labels = df[keys].to_numpy() |
| self._length = len(self.ids) |
|
|
| def __len__(self): |
| return self._length |
|
|
| def __iter__(self): |
| return iter([self.keys[i] for i in self.ids]) |
|
|
| def __getitem__(self, index): |
| key = self.keys[self.ids[index]] |
| return super().__getitem__(key) |
|
|
| def set_ids(self, ids): |
| self.ids = [self.ids[i] for i in ids] |
| self.labels = [self.labels[i] for i in ids] |
| self._length = len(self.ids) |
| |
| def _decode_value(self, value: bytes): |
| """ |
| Converts a byte image back into a PIL Image. |
| |
| :param value: A byte image. |
| :return: A PIL Image image. |
| """ |
| return Image.open(io.BytesIO(value)) |
|
|
| @staticmethod |
| def write(indir, outdir, data_lst=None, transform=None): |
| """ |
| create lmdb given data directory and list of image paths; or an iterator |
| :param data_lst None or csv file containing 'path' key to store relative paths to the images |
| :param indir root directory of the images |
| :param outdir output lmdb, data.mdb and lock.mdb will be written here |
| """ |
| |
| outdir = Path(outdir) |
| outdir.mkdir(parents=True, exist_ok=True) |
| tmp_dir = Path("/tmp") / f"TEMP_{time.time()}" |
| tmp_dir.mkdir(parents=True, exist_ok=True) |
| dtype = {'str': False, 'pil': False} |
| if isinstance(indir, str) or isinstance(indir, Path): |
| indir = Path(indir) |
| if data_lst is None: |
| lst = list(indir.glob('**/*.jpg')) + list(indir.glob('**/*.png')) |
| else: |
| lst = pd.read_csv(data_lst)['path'].tolist() |
| lst = [indir/p for p in lst] |
| assert len(lst) > 0, f'Couldnt find any image in {indir} (Support only .jpg and .png) or list (must have path field).' |
| n = len(lst) |
| dtype['str'] = True |
| else: |
| n = len(indir) |
| lst = iter(indir) |
| dtype['pil'] = True |
|
|
| with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env: |
| |
| with env.begin(write=True) as txn: |
| key = "protocol".encode("ascii") |
| value = pickle.dumps(pickle.DEFAULT_PROTOCOL) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| with env.begin(write=True) as txn: |
| key = pickle.dumps("keys") |
| value = pickle.dumps(list(range(n))) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| for key, value in tqdm(enumerate(lst), total=n, miniters=n//100, mininterval=300): |
| with env.begin(write=True) as txn: |
| key = pickle.dumps(key) |
| if dtype['str']: |
| with value.open("rb") as file: |
| byteimg = file.read() |
| else: |
| data = io.BytesIO() |
| value.save(data, 'png') |
| byteimg = data.getvalue() |
|
|
| if transform is not None: |
| im = Image.open(io.BytesIO(byteimg)) |
| im = transform(im) |
| data = io.BytesIO() |
| im.save(data, 'png') |
| byteimg = data.getvalue() |
| txn.put(key=key, value=byteimg, dupdata=False) |
|
|
| |
| copy_tree(str(tmp_dir), str(outdir)) |
| shutil.rmtree(str(tmp_dir)) |
|
|
|
|
|
|
| class MaskDatabase(PILlmdb): |
| def _decode_value(self, value: bytes): |
| """ |
| Converts a byte image back into a PIL Image. |
| |
| :param value: A byte image. |
| :return: A PIL Image image. |
| """ |
| return Image.open(io.BytesIO(value)).convert("1") |
|
|
|
|
| class LabelDatabase(BaseLMDB): |
| pass |
|
|
|
|
| class ArrayDatabase(BaseLMDB): |
| _dtype = None |
| _shape = None |
|
|
| def __init__( |
| self, |
| lmdb_dir: Union[str, Path], |
| image_list: Union[str, Path, pd.DataFrame]=None, |
| **kwargs |
| ): |
| super().__init__(path=lmdb_dir, **kwargs) |
| if image_list is None: |
| self.ids = list(range(len(self.keys))) |
| self.labels = list(range(len(self.ids))) |
| else: |
| df = pd.read_csv(str(image_list)) |
| assert 'id' in df, f'[ArrayDatabase] Error! {image_list} must have id keys.' |
| self.ids = df['id'].tolist() |
| assert max(self.ids) < len(self.keys) |
| if 'label' in df: |
| self.labels = df['label'].tolist() |
| else: |
| keys = [key for key in df if (key!='id' and type(df[key][0]) in [int, np.int64])] |
| |
| self.labels = df[keys].to_numpy() |
| self._length = len(self.ids) |
|
|
| def set_ids(self, ids): |
| self.ids = [self.ids[i] for i in ids] |
| self.labels = [self.labels[i] for i in ids] |
| self._length = len(self.ids) |
|
|
| def __len__(self): |
| return self._length |
|
|
| def __iter__(self): |
| return iter([self.keys[i] for i in self.ids]) |
|
|
| def __getitem__(self, index): |
| key = self.keys[self.ids[index]] |
| return super().__getitem__(key) |
|
|
| @property |
| def dtype(self): |
| if self._dtype is None: |
| protocol = self.protocol |
| self._dtype = self._get( |
| item="dtype", |
| encode_key=lambda key: _default_encode(key, protocol=protocol), |
| decompress_value=_default_decompress, |
| decode_value=_default_decode, |
| ) |
| return self._dtype |
|
|
| @property |
| def shape(self): |
| if self._shape is None: |
| protocol = self.protocol |
| self._shape = self._get( |
| item="shape", |
| encode_key=lambda key: _default_encode(key, protocol=protocol), |
| decompress_value=_default_decompress, |
| decode_value=_default_decode, |
| ) |
| return self._shape |
|
|
| def _decode_value(self, value: bytes) -> ndarray: |
| value = super()._decode_value(value) |
| return np.frombuffer(value, dtype=self.dtype).reshape(self.shape) |
|
|
| def _decode_values(self, values: list) -> ndarray: |
| shape = (len(values),) + self.shape |
| return np.frombuffer(b"".join(values), dtype=self.dtype).reshape(shape) |
|
|
| @staticmethod |
| def write(diter, outdir): |
| """ |
| diter is an iterator that has __len__ method |
| class Myiter(): |
| def __init__(self, data): |
| self.data = data |
| def __iter__(self): |
| self.counter = 0 |
| return self |
| def __len__(self): |
| return len(self.data) |
| def __next__(self): |
| if self.counter < len(self): |
| out = self.data[self.counter] |
| self.counter+=1 |
| return out |
| else: |
| raise StopIteration |
| a = iter(Myiter([1,2,3])) |
| for i in a: |
| print(i) |
| """ |
| outdir = Path(outdir) |
| outdir.mkdir(parents=True, exist_ok=True) |
| tmp_dir = Path("/tmp") / f"TEMP_{time.time()}" |
| tmp_dir.mkdir(parents=True, exist_ok=True) |
| |
| n = len(diter) |
| with lmdb.open(path=str(tmp_dir), map_size=2 ** 40) as env: |
| |
| with env.begin(write=True) as txn: |
| key = "protocol".encode("ascii") |
| value = pickle.dumps(pickle.DEFAULT_PROTOCOL) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| with env.begin(write=True) as txn: |
| key = pickle.dumps("keys") |
| value = pickle.dumps(list(range(n))) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| value = next(iter(diter)) |
| shape = value.shape |
| dtype = value.dtype |
| |
| with env.begin(write=True) as txn: |
| key = pickle.dumps("shape") |
| value = pickle.dumps(shape) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| with env.begin(write=True) as txn: |
| key = pickle.dumps("dtype") |
| value = pickle.dumps(dtype) |
| txn.put(key=key, value=value, dupdata=False) |
| |
| with env.begin(write=True) as txn: |
| for key, value in tqdm(enumerate(iter(diter)), total=n, miniters=n//100, mininterval=300): |
| key = pickle.dumps(key) |
| value = pickle.dumps(value) |
| txn.put(key=key, value=value, dupdata=False) |
|
|
| |
| copy_tree(str(tmp_dir), str(outdir)) |
| shutil.rmtree(str(tmp_dir)) |
|
|
|
|
|
|
| class TensorDatabase(ArrayDatabase): |
| def _decode_value(self, value: bytes) -> Tensor: |
| return torch.from_numpy(super(TensorDatabase, self)._decode_value(value)) |
|
|
| def _decode_values(self, values: list) -> Tensor: |
| return torch.from_numpy(super(TensorDatabase, self)._decode_values(values)) |
|
|