| import argparse |
| import re |
| import simdjson |
| import sys |
| import subprocess |
| import multiprocessing as mp |
| from pathlib import Path |
| from cloudpathlib import CloudPath |
| from tqdm import tqdm |
|
|
|
|
def path_or_cloudpath(s):
    """Turn a command-line string into a path object.

    Strings that begin with a ``<scheme>://`` prefix (for example
    ``s3://bucket/dir``) are wrapped in ``CloudPath``; anything else is
    treated as a local filesystem ``Path``.
    """
    has_scheme = re.match(r"^\w+://", s) is not None
    return CloudPath(s) if has_scheme else Path(s)
|
|
|
|
def parse_args(args):
    """Parse the manifest builder's command-line options.

    Args:
        args: Sequence of argument strings (``main`` passes ``sys.argv[1:]``).

    Returns:
        ``argparse.Namespace`` with ``data_dir`` (converted by
        ``path_or_cloudpath``), ``manifest_filename``, ``tmp_dir`` and
        ``num_workers``.
    """
    parser = argparse.ArgumentParser()
    add_option = parser.add_argument

    add_option(
        "--data-dir",
        type=path_or_cloudpath,
        required=True,
        help="Directory containing a dataset in webdataset format.",
    )
    add_option(
        "--manifest-filename",
        type=str,
        default="manifest.jsonl",
        help="Filename for the manifest that will be stored in the webdataset directory.",
    )
    add_option("--tmp-dir", type=str, default=None, help="Temporary directory.")
    add_option("--num-workers", type=int, default=2, help="Number of workers.")

    return parser.parse_args(args)
|
|
|
|
def count_samples(shard_path, tmp_dir):
    """Count the entries of a tar shard as listed by ``tar tf``.

    Every line ``tar tf`` prints is counted (one per archive member), which
    equals the number of samples only when each sample is a single member.

    Args:
        shard_path: Local path (``Path``/``str``) or ``CloudPath`` of the shard.
        tmp_dir: Directory in which cloud shards are staged while being
            counted. ``None`` falls back to the system temporary directory.

    Returns:
        Number of newline-terminated lines in the ``tar tf`` listing.

    Raises:
        subprocess.CalledProcessError: If ``tar`` exits non-zero (e.g. a
            corrupt or truncated shard) instead of returning a partial count.
    """
    import tempfile

    if not isinstance(shard_path, CloudPath):
        return _count_tar_entries(shard_path)

    if tmp_dir is not None:
        Path(tmp_dir).mkdir(parents=True, exist_ok=True)
    # A private scratch directory is removed even if the download or tar fails,
    # and tmp_dir=None (the CLI default) no longer crashes in Path(None).
    with tempfile.TemporaryDirectory(dir=tmp_dir) as scratch:
        local_path = Path(scratch) / shard_path.name
        shard_path.download_to(local_path)
        return _count_tar_entries(local_path)


def _count_tar_entries(path):
    """Return the number of lines ``tar tf`` prints for the archive at *path*."""
    # Argument list without a shell: the path is passed verbatim (spaces and
    # metacharacters are safe), and check_output checks tar's own exit status,
    # which the old `tar tf ... | wc -l` pipeline discarded.
    listing = subprocess.check_output(["tar", "tf", str(path)])
    return listing.count(b"\n")
|
|
|
|
def worker_fn(input_data):
    """Build the manifest entry for one shard (run inside a pool worker).

    Args:
        input_data: ``(basename, data_dir, tmp_dir)`` tuple; ``basename`` is
            the shard's filename inside ``data_dir`` and ``tmp_dir`` is
            forwarded to ``count_samples``.

    Returns:
        ``(basename, {"shard": <name without ".tar">, "num_sequences": <count>})``.
        The basename is returned alongside the entry so the caller can restore
        a deterministic order after unordered parallel processing.
    """
    basename, data_dir, tmp_dir = input_data
    shard_path = data_dir / basename

    # Strip only the trailing ".tar" so dotted names such as "part.0001.tar"
    # keep their full key; split(".")[0] would truncate them to "part", giving
    # a manifest entry that matches no shard file.
    if basename.endswith(".tar"):
        shard = basename[: -len(".tar")]
    else:
        shard = basename.split(".")[0]

    return (
        basename,
        {
            "shard": shard,
            "num_sequences": count_samples(shard_path, tmp_dir),
        },
    )
|
|
|
|
def main(args):
    """Write a JSONL manifest of per-shard counts for a webdataset directory.

    Every ``*.tar`` entry directly inside ``--data-dir`` is counted in a
    process pool. One JSON object per shard is written, one per line, sorted
    by shard filename, to ``<data-dir>/<manifest-filename>``.

    Args:
        args: Command-line argument strings (without the program name).
    """
    args = parse_args(args)

    shards = sorted(x for x in args.data_dir.iterdir() if x.name.endswith(".tar"))
    input_data = [(shard.name, args.data_dir, args.tmp_dir) for shard in shards]

    print(f"Shards to process: {len(shards)}")
    print("Creating pool.")
    with mp.Pool(args.num_workers) as pool:
        # imap_unordered yields in completion order; passing total lets tqdm
        # show a percentage and ETA instead of only a running count.
        data = list(
            tqdm(pool.imap_unordered(worker_fn, input_data), total=len(input_data))
        )

    # Restore a deterministic order by shard basename (unique per directory),
    # without falling back to comparing the entry dicts.
    data.sort(key=lambda item: item[0])

    manifest_path = args.data_dir / args.manifest_filename
    with manifest_path.open("w") as fp:
        for _, entry in data:
            simdjson.dump(entry, fp)
            fp.write("\n")
|
|
|
|
if __name__ == "__main__":
    # Drop the program name; main() parses only the remaining CLI arguments.
    # The guard also keeps multiprocessing child processes from re-running main.
    cli_args = sys.argv[1:]
    main(cli_args)
|
|