File size: 5,097 Bytes
0decf42 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 | import os
import os.path as osp
from PIL import Image
import six
import lmdb
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import pyarrow as pa
import numpy as np
import torch.utils.data as data
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
# Default LMDB locations for the ImageNet 2012 train and val splits.
# NOTE(review): these are absolute, site-specific paths — confirm they exist
# on the machine this module runs on before relying on them.
train_lmdb_path = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/train.lmdb'
val_lmdb_path = '/apdcephfs/share_1290939/0_public_datasets/imageNet_2012/val.lmdb'
# Example usage from another module (`is_train` and `transform` are supplied
# by the caller):
# from data.lmdb_dataset import ImageFolderLMDB, train_lmdb_path, val_lmdb_path
# lmdb_path = train_lmdb_path if is_train else val_lmdb_path
# dataset = ImageFolderLMDB(db_path=lmdb_path, transform=transform)
def loads_pyarrow(buf):
    """Rebuild a Python object from a pyarrow-serialized buffer.

    Args:
        buf: the output of `dumps`.

    Returns:
        Whatever `pa.deserialize` reconstructs from ``buf``.
    """
    # NOTE(review): `pa.deserialize` is deprecated in recent pyarrow releases;
    # confirm the pinned pyarrow version still provides it.
    restored = pa.deserialize(buf)
    return restored
class ImageFolderLMDB(data.Dataset):
    """Dataset that serves (image, label) pairs stored in an LMDB database.

    The database is expected to hold two metadata entries, ``b'__len__'`` and
    ``b'__keys__'``, plus one pyarrow-serialized record per key. Element 0 of
    a record is the encoded image bytes and element 1 is the label.

    Args:
        db_path: Path to the LMDB file or directory.
        transform: Optional callable applied to the decoded RGB PIL image.
        target_transform: Optional callable applied to the label.
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = db_path
        self.transform = transform
        self.target_transform = target_transform
        self._open_env()

    def _open_env(self):
        """Open the LMDB environment read-only and load ``length`` and ``keys``."""
        self.env = lmdb.open(self.db_path, subdir=osp.isdir(self.db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.length = loads_pyarrow(txn.get(b'__len__'))
            self.keys = loads_pyarrow(txn.get(b'__keys__'))

    def __getstate__(self):
        """Return picklable state with the LMDB handle dropped.

        A copy of ``__dict__`` is used so that pickling does not clear the
        ``env`` attribute of the live instance.
        """
        state = self.__dict__.copy()
        state["env"] = None
        return state

    def __setstate__(self, state):
        """Restore attributes and reopen the LMDB environment from ``db_path``."""
        self.__dict__.update(state)
        self._open_env()

    def __getitem__(self, index):
        """Return ``(img, target)`` for the record stored under ``keys[index]``.

        ``img`` is the image decoded to RGB (after ``transform``, if given);
        ``target`` is the stored label (after ``target_transform``, if given).
        """
        with self.env.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        unpacked = loads_pyarrow(byteflow)

        # load img
        imgbuf = unpacked[0]
        img = Image.open(six.BytesIO(imgbuf)).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)

        # load label
        target = unpacked[1]
        if self.target_transform is not None:
            # The label goes through target_transform, not the image transform.
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the record count read from the ``b'__len__'`` entry."""
        return self.length

    def __repr__(self):
        return self.__class__.__name__ + ' (' + self.db_path + ')'
def raw_reader(path):
    """Return the full contents of the file at ``path`` as raw bytes."""
    with open(path, mode='rb') as stream:
        return stream.read()
def dumps_pyarrow(obj):
    """Serialize ``obj`` with pyarrow.

    Returns:
        Implementation-dependent bytes-like object
    """
    # NOTE(review): `pa.serialize` is deprecated in recent pyarrow releases;
    # confirm the pinned pyarrow version still provides it.
    serialized = pa.serialize(obj)
    return serialized.to_buffer()
def _identity_collate(batch):
    """Return the batch unchanged (module-level so worker processes can pickle it)."""
    return batch


def folder2lmdb(dpath, name="train", write_frequency=5000):
    """Convert the image folder ``<dpath>/<name>`` into ``<dpath>/<name>.lmdb``.

    Each sample is stored under its ASCII-encoded running index as a
    pyarrow-serialized ``(raw image bytes, label)`` tuple. Two metadata
    entries, ``b'__keys__'`` and ``b'__len__'``, are written at the end.

    Args:
        dpath: Dataset root containing the ``name`` sub-directory.
        name: Split sub-directory to convert; also names the output file.
        write_frequency: Commit the pending write transaction (and print
            progress) every this many samples.
    """
    directory = osp.expanduser(osp.join(dpath, name))
    print("Loading dataset from %s" % directory)
    dataset = ImageFolder(directory, loader=raw_reader)
    # A named function instead of a lambda keeps collate_fn picklable when
    # DataLoader workers are started with the "spawn" method.
    data_loader = DataLoader(dataset, num_workers=4, collate_fn=_identity_collate)
    total = len(data_loader)

    lmdb_path = osp.join(dpath, "%s.lmdb" % name)
    isdir = os.path.isdir(lmdb_path)

    print("Generate LMDB to %s" % lmdb_path)
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=1099511627776 * 2, readonly=False,
                   meminit=False, map_async=True)
    try:
        num_items = 0
        txn = db.begin(write=True)
        try:
            for idx, batch in enumerate(data_loader):
                # Default batch size is 1, so each batch holds one sample.
                image, label = batch[0]
                txn.put(u'{}'.format(idx).encode('ascii'), dumps_pyarrow((image, label)))
                num_items = idx + 1
                if idx % write_frequency == 0:
                    print("[%d/%d]" % (idx, total))
                    txn.commit()
                    txn = db.begin(write=True)
            # finish iterating through dataset
            txn.commit()
        except BaseException:
            # Release the pending write transaction before propagating.
            txn.abort()
            raise

        keys = [u'{}'.format(k).encode('ascii') for k in range(num_items)]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', dumps_pyarrow(keys))
            txn.put(b'__len__', dumps_pyarrow(len(keys)))

        print("Flushing database ...")
        db.sync()
    finally:
        # Close the environment even when conversion fails part-way.
        db.close()
if __name__ == "__main__":
    # Command-line entry point: build the train and val LMDB files for the
    # dataset directory passed via --dir.
    #
    # To sanity-check a generated database, iterate
    # ImageFolderLMDB(db_path=<path to .lmdb>) and inspect the item types.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--dir', type=str, required=True, help="The dataset directory to process")
    dataset_root = cli.parse_args().dir

    # generate lmdb
    for split in ("train", "val"):
        folder2lmdb(dataset_root, name=split)
|