|
|
import argparse |
|
|
import multiprocessing |
|
|
import os |
|
|
import time |
|
|
|
|
|
import mxnet as mx |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def read_worker(args, q_in):
    """Read the records of ``train.rec`` in random order and enqueue them.

    Args:
        args: Namespace whose ``input`` attribute is the directory holding
            ``train.idx`` and ``train.rec``.
        q_in: Queue that receives every packed record read from the file,
            followed by a single ``None`` end-of-stream marker.

    Raises:
        ValueError: If the record stored at index 0 does not have a positive
            ``flag`` (its label array is needed to know how many records to
            read).
    """
    path_imgidx = os.path.join(args.input, "train.idx")
    path_imgrec = os.path.join(args.input, "train.rec")
    imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
    try:
        # Index 0 is treated as a header record: only its label is used, and
        # label[0] is the exclusive upper bound of the indices to read.
        header, _ = mx.recordio.unpack(imgrec.read_idx(0))
        if header.flag <= 0:
            # Explicit raise instead of ``assert`` so the check survives -O.
            raise ValueError(
                f"record 0 of {path_imgrec} has flag={header.flag}; "
                "expected a header record with flag > 0"
            )

        imgidx = np.arange(1, int(header.label[0]))
        np.random.shuffle(imgidx)

        for idx in imgidx:
            q_in.put(imgrec.read_idx(idx))

        # The end-of-stream marker is sent only after every record was read,
        # so a failed read is never mistaken for a complete data set.
        q_in.put(None)
    finally:
        # Release the reader even when a read or the queue raises.
        imgrec.close()
|
|
|
|
|
|
|
|
def _shuffled_output_dir(input_path):
    """Return the sibling ``shuffled_<name>`` directory path for ``input_path``."""
    # Strip every trailing separator so "data/", "data//" and "data" all map
    # to the same output; an empty string is handled without indexing.
    trimmed = input_path.rstrip("/" + os.sep)
    parent = os.path.dirname(trimmed)
    return os.path.join(parent, f"shuffled_{os.path.basename(trimmed)}")


def write_worker(args, q_out):
    """Drain ``q_out`` and write the records to a new indexed record file.

    The output goes to ``shuffled_<basename>`` next to ``args.input`` (the
    directory is created if missing) as ``train.idx`` / ``train.rec``.
    Records are re-indexed from 0 in the order they arrive.

    Args:
        args: Namespace whose ``input`` attribute is the source directory.
            It is only read; ``args.input`` is not modified.
        q_out: Queue yielding packed records; a ``None`` item ends the loop.
    """
    pre_time = time.time()

    output = _shuffled_output_dir(args.input)
    os.makedirs(output, exist_ok=True)

    path_imgidx = os.path.join(output, "train.idx")
    path_imgrec = os.path.join(output, "train.rec")
    save_record = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "w")
    count = 0
    try:
        while True:
            deq = q_out.get()
            if deq is None:
                break

            header, jpeg = mx.recordio.unpack(deq)

            # Keep a scalar label: a float label is used as-is, otherwise
            # only the first element of the label sequence is stored.
            if isinstance(header.label, float):
                label = header.label
            else:
                label = header.label[0]

            header = mx.recordio.IRHeader(flag=header.flag, label=label, id=header.id, id2=header.id2)
            save_record.write_idx(count, mx.recordio.pack(header, jpeg))
            count += 1

            # Progress report: time spent on the last 10000 records.
            if count % 10000 == 0:
                cur_time = time.time()
                print("save time:", cur_time - pre_time, " count:", count)
                pre_time = cur_time
        print(count)
    finally:
        # Always close the writer so the files are released even if a record
        # fails to unpack or write.
        save_record.close()
|
|
|
|
|
|
|
|
def main(args):
    """Run the reader and writer processes that shuffle a record file.

    A reader process feeds records into a bounded queue and a writer process
    drains it. The call returns once the writer has finished.

    Args:
        args: Namespace with an ``input`` attribute; passed unchanged to
            ``read_worker`` and ``write_worker``.

    Raises:
        RuntimeError: If either worker process ends with a non-zero exit
            code.
    """
    # Function-scope import: ``import multiprocessing`` alone does not
    # guarantee that the ``connection`` submodule is loaded.
    from multiprocessing.connection import wait

    queue = multiprocessing.Queue(10240)
    read_process = multiprocessing.Process(target=read_worker, args=(args, queue))
    read_process.daemon = True
    read_process.start()
    write_process = multiprocessing.Process(target=write_worker, args=(args, queue))
    write_process.start()

    # Block until either worker exits. Joining only the writer would hang
    # forever if the reader died, because the writer would keep waiting on
    # the queue for an end-of-stream marker that never arrives.
    wait([read_process.sentinel, write_process.sentinel])
    if read_process.exitcode not in (None, 0):
        write_process.terminate()
        write_process.join()
        raise RuntimeError(f"read worker failed with exit code {read_process.exitcode}")

    write_process.join()
    if write_process.exitcode != 0:
        raise RuntimeError(f"write worker failed with exit code {write_process.exitcode}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the single positional argument is handed to main()
    # as ``args.input``.
    cli = argparse.ArgumentParser()
    cli.add_argument("input", help="path to source rec.")
    parsed_args = cli.parse_args()
    main(parsed_args)
|
|
|