|
|
''' |
|
|
This file implements the direct conversion from the latent ImageNet dataset to WebDataset.
|
|
''' |
|
|
import os |
|
|
from argparse import ArgumentParser |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
import pickle |
|
|
|
|
|
import webdataset as wds |
|
|
|
|
|
from train_utils.datasets import ImageNetLatentDataset |
|
|
|
|
|
|
|
|
def convert2wds(args):
    '''
    Convert one split of the latent ImageNet dataset into WebDataset tar shards.

    Args:
        args: parsed command-line namespace. The attributes read here are
            `outdir`, `datadir`, `split`, `resolution`, `num_channels`,
            `maxcount` and `maxsize`.

    Every sample is written under a zero-padded 7-digit index key, with the
    pickled latent stored as 'latent' and the argmax of the label as 'cls'.
    '''
    os.makedirs(args.outdir, exist_ok=True)

    # '%04d' is left unformatted on purpose: it is handed to ShardWriter as the
    # shard-name pattern.
    # NOTE(review): the '512' in the file name is hard-coded and does not follow
    # args.resolution -- confirm this is intended.
    wds_path = os.path.join(args.outdir, f'latent_imagenet_512_{args.split}-%04d.tar')
    dataset = ImageNetLatentDataset(
        args.datadir,
        resolution=args.resolution,
        num_channels=args.num_channels,
        split=args.split,
    )

    with wds.ShardWriter(wds_path, maxcount=args.maxcount, maxsize=args.maxsize) as sink:
        for i in tqdm(range(len(dataset)), dynamic_ncols=True):
            if i % args.maxcount == 0:
                # tqdm.write (rather than print) keeps the message from being
                # interleaved with the live progress bar.
                # NOTE(review): the shard number is derived from maxcount only;
                # presumably it can drift if maxsize triggers an earlier rollover.
                tqdm.write(f'writing to the {i // args.maxcount}th shard')

            img, label = dataset[i]
            # presumably `label` is a one-hot / score vector; argmax reduces it
            # to a single class index -- verify against ImageNetLatentDataset.
            label = np.argmax(label)
            sink.write({'__key__': f'{i:07d}', 'latent': pickle.dumps(img), 'cls': label})
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: build the CLI, parse it, and run the conversion.
    parser = ArgumentParser('Convert the latent imagenet dataset to WebDataset')

    # One row per option: (flag, type, default, help). Rows are registered in
    # this order, which is also the order shown by --help.
    cli_options = (
        ('--maxcount', int, 10010, 'max number of entries per shard'),
        ('--maxsize', int, 10 ** 10, 'max size per shard'),
        ('--outdir', str, 'latent_imagenet_wds', 'path to save the converted dataset'),
        ('--datadir', str, 'latent_imagenet', 'path to the latent imagenet dataset'),
        ('--resolution', int, 64, 'image resolution'),
        ('--num_channels', int, 8, 'number of image channels'),
        ('--split', str, 'train', 'split of the dataset'),
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    args = parser.parse_args()
    convert2wds(args)