import argparse
import os

import numpy as np
import zarr
from numcodecs import Blosc
from tqdm import tqdm

from utils.common_utils import str2bool


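def parse_arguments():
    """Parse command-line options for the rechunking script."""
    parser = argparse.ArgumentParser(
        description="Copy a Zarr group into a new group rechunked along the first dimension."
    )
    # Tuples: (name, type, default)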
    arguments = [
        # Dataset/loader arguments
        ('src', str, '/data/user_data/ngkanats/zarr_datasets/Peract2_dense_zarr/train.zarr'),
        ('tgt', str, '/data/user_data/ngkanats/zarr_datasets/Peract2_dense_zarr/train_rechunked4.zarr'),
        ('chunk_size', int, 4),
        ('shuffle', str2bool, False)
    ]
    for arg in arguments:
        parser.add_argument(f'--{arg[0]}', type=arg[1], default=arg[2])

    return parser.parse_args()


def rechunk_zarr_group(
    old_zarr_path,
    new_zarr_path,
    chunk_size=4,
    shuffle=False,
    compressor=Blosc(cname='lz4', clevel=1, shuffle=Blosc.SHUFFLE)
):
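    """Copy every array in a Zarr group into a new group, rechunked along dim 0.

    Each output chunk holds ``chunk_size`` consecutive samples. When ``shuffle``
    is True, samples are written in a random order, and the same permutation is
    applied to every array so rows stay aligned across arrays.
    """
    # Load old Zarr group (read-only)
    old_group = zarr.open_group(old_zarr_path, mode='r')
    # The sample count comes from the 'action' array; every array in the group
    # is assumed to share this first dimension.
    num_samples = len(old_group['action'])
    if shuffle:
        inds = np.random.permutation(num_samples)
    else:
        inds = np.arange(num_samples)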

    # Create new Zarr group (overwrite if exists)
    if os.path.exists(new_zarr_path):
        print(f"Deleting existing {new_zarr_path}")
        import shutil
        shutil.rmtree(new_zarr_path)

    new_group = zarr.open_group(new_zarr_path, mode='w')

    # Copy datasets with new chunking & compression
    for array_name in old_group.array_keys():
        old_array = old_group[array_name]
        shape = old_array.shape
        dtype = old_array.dtype

        # Chunk along dim 0 only: each chunk holds `chunk_size` samples and
        # keeps every other dimension whole.
        chunk_shape = (min(chunk_size, shape[0]),) + shape[1:]

        print(f"Rechunking {array_name} | shape={shape}, chunks={chunk_shape}")

        new_array = new_group.create_dataset(
            name=array_name,
            shape=shape,
            dtype=dtype,
            chunks=chunk_shape,
            compressor=compressor,
            overwrite=True,
        )

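        # Copy one output chunk per iteration so each write fills exactly one new chunk.
        if shuffle and shape[0] != len(inds):
            raise ValueError(
                f"{array_name} has {shape[0]} rows but the shuffle order has {len(inds)}"
            )
        for i in tqdm(range(0, shape[0], chunk_size), desc=f"Copying {array_name}"):
            end = min(i + chunk_size, shape[0])
            if shuffle:
                # Orthogonal integer indexing gathers the permuted source rows.
                new_array[i:end] = old_array.oindex[inds[i:end]]
            else:
                # Plain slicing avoids fancy indexing when the order is unchanged.
                new_array[i:end] = old_array[i:end]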

    print("✅ Rechunking complete.")
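

# Illustrative sketch, not called by this script: `chunk_plan` is a
# hypothetical helper that reproduces the chunk shape and the copy-loop write
# windows used by rechunk_zarr_group, without opening any Zarr store. Each
# [start, end) window covers exactly one output chunk; the last window is
# shorter when shape[0] is not a multiple of chunk_size. For example,
# chunk_plan((10, 3, 256, 256), 4) returns
# ((4, 3, 256, 256), [(0, 4), (4, 8), (8, 10)]).
def chunk_plan(shape, chunk_size):
    """Return (chunk_shape, windows) for rechunking `shape` along dim 0."""
    # Same rule as rechunk_zarr_group: cap dim 0 at chunk_size, keep the rest whole.
    chunk_shape = (min(chunk_size, shape[0]),) + tuple(shape[1:])
    # One (start, end) pair per loop iteration, matching range(0, shape[0], chunk_size).
    windows = [
        (start, min(start + chunk_size, shape[0]))
        for start in range(0, shape[0], chunk_size)
    ]
    return chunk_shape, windows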

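
# Illustrative sketch, not called by this script: rechunk_zarr_group draws one
# index order and reuses it for every array, so row k of each output array
# comes from the same source sample. The hypothetical helper
# `permute_rows_together` shows that idea on in-memory arrays. Its `seed`
# argument is an assumption: the script calls np.random.permutation without a
# seed, so each run shuffles differently.
def permute_rows_together(arrays, seed=None):
    """Reorder every array in the dict `arrays` along dim 0 by one shared permutation."""
    import numpy as np  # local import keeps the sketch self-contained

    lengths = {len(a) for a in arrays.values()}
    if len(lengths) != 1:
        raise ValueError("all arrays must share the same first dimension")
    # One permutation, applied to every array, keeps samples aligned.
    order = np.random.default_rng(seed).permutation(lengths.pop())
    return {name: np.asarray(a)[order] for name, a in arrays.items()}
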

if __name__ == '__main__':
    args = parse_arguments()
    rechunk_zarr_group(
        old_zarr_path=args.src,
        new_zarr_path=args.tgt,
        chunk_size=args.chunk_size,
        shuffle=args.shuffle
    )