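"""Download a Hugging Face dataset, shuffle it with terashuf, and split the
result into fixed-count JSONL training chunks plus a small validation file.

Supports fineweb-edu and DCLM-baseline variants; see the per-dataset tables
in main().
"""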
import argparse
import os
import subprocess
import time

import requests
from huggingface_hub import snapshot_download


def run_command(command):
    """Echo and run a shell command, raising CalledProcessError on failure."""
    print(f"Running: {command}")
    subprocess.run(command, shell=True, check=True)


def download_dataset(repo_id, local_dir, allow_patterns):
    """Download a dataset snapshot from the Hugging Face Hub, retrying on read timeouts."""
    print(f"Downloading dataset from {repo_id}...")
    max_retries = 5
    retry_delay = 10  # seconds
    for attempt in range(max_retries):
        try:
            snapshot_download(
                repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                allow_patterns=allow_patterns,
                resume_download=True,  # deprecated no-op in recent huggingface_hub; kept for older versions
                max_workers=16,  # parallel download workers
            )
            break
        except requests.exceptions.ReadTimeout:
            if attempt < max_retries - 1:
                print(f"Timeout occurred. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                raise
    print(f"Dataset downloaded to {local_dir}")
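# Standalone usage would look like this (hypothetical values; main() below
# normally supplies repo_id, local_dir, and allow_patterns from its tables):
#   download_dataset("HuggingFaceFW/fineweb-edu", "data/fineweb_edu_10bt",
#                    allow_patterns="sample/10BT/*")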


def parquet_to_jsonl(dataset, work_dir, src_dir, tgt_dir, ntasks=64):
    """Convert parquet shards under src_dir to uncompressed JSONL files in tgt_dir."""
    from datatrove.executor import LocalPipelineExecutor
    from datatrove.pipeline.readers import ParquetReader
    from datatrove.pipeline.writers import JsonlWriter

    pipeline_exec = LocalPipelineExecutor(
        pipeline=[
            ParquetReader(
                src_dir,
                file_progress=True,
                doc_progress=True,
                glob_pattern="**/*.parquet",
            ),
            JsonlWriter(
                tgt_dir,
                # datatrove substitutes ${rank} with the task rank, so each
                # of the ntasks workers writes its own chunk file.
                output_filename=dataset + ".chunk.${rank}.jsonl",
                compression=None,
            ),
        ],
        tasks=ntasks,
        logging_dir=os.path.join(work_dir, "datatrove"),
    )
    pipeline_exec.run()
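# main() calls this with src_dir as both source and target (illustrative paths):
#   parquet_to_jsonl("fineweb_edu_10bt", "data/fineweb_edu_10bt",
#                    "data/fineweb_edu_10bt", "data/fineweb_edu_10bt")
# so the JSONL chunks are written next to the downloaded parquet shards.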


def setup_terashuf(work_dir):
    """Clone and build terashuf (a disk-backed shuffler) unless it is already built."""
    terashuf_dir = os.path.join(work_dir, "terashuf")
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")

    if os.path.exists(terashuf_executable):
        print("terashuf executable already exists. Skipping setup.")
        return terashuf_dir

    print("Setting up terashuf...")
    run_command(f"git clone https://github.com/alexandres/terashuf {terashuf_dir}")
    run_command(f"make -C {terashuf_dir}")
    return terashuf_dir


def main(dataset, memory, data_dir, seed=42, nchunks=32):
    # Configuration: per-dataset Hub repo, file layout, and download filters.
    repo_id = {
        "fineweb_edu": "HuggingFaceFW/fineweb-edu",
        "fineweb_edu_10bt": "HuggingFaceFW/fineweb-edu",
        "dclm_baseline_1.0": "mlfoundations/dclm-baseline-1.0",
        "dclm_baseline_1.0_10prct": "mlfoundations/dclm-baseline-1.0",
    }[dataset]
    src_dir = f"{data_dir}/{dataset}"
    out_dir = f"{src_dir}_shuffled"
    os.makedirs(out_dir, exist_ok=True)
    work_dir = src_dir  # terashuf build and datatrove logs live next to the raw data
    prefix = f"{dataset}.chunk."
    orig_extension = {
        "fineweb_edu": ".jsonl",
        "fineweb_edu_10bt": ".jsonl",
        "dclm_baseline_1.0": ".jsonl.zst",
        "dclm_baseline_1.0_10prct": ".jsonl.zst",
    }[dataset]
    cat_command = {
        "fineweb_edu": "cat {}",
        "fineweb_edu_10bt": "cat {}",
        # zstd shards are decompressed on the fly; the trailing echo
        # guarantees a newline between concatenated files.
        "dclm_baseline_1.0": "zstdcat {} && echo",
        "dclm_baseline_1.0_10prct": "zstdcat {} && echo",
    }[dataset]
    allow_patterns = {
        "fineweb_edu": None,
        "fineweb_edu_10bt": "sample/10BT/*",
        "dclm_baseline_1.0": "*.jsonl.zst",
        "dclm_baseline_1.0_10prct": "global-shard_01_of_10/*.jsonl.zst",
    }[dataset]
    suffix = ".jsonl"
    k_validation = 10000  # lines moved from each chunk into the validation set

    # Clone and build terashuf, then download the raw dataset from the Hub.
    terashuf_dir = setup_terashuf(work_dir)
    download_dataset(repo_id, src_dir, allow_patterns)

    # fineweb-edu ships as parquet shards; convert them to JSONL in place.
    if "fineweb" in dataset:
        parquet_to_jsonl(dataset, work_dir, src_dir, src_dir)

    # terashuf reads its memory budget (in GB) and RNG seed from these
    # environment variables.
    os.environ["MEMORY"] = f"{memory}"
    os.environ["SEED"] = f"{seed}"

    # Concatenate every input shard, stream the result through terashuf, and
    # round-robin split the shuffled stream into nchunks files named
    # {prefix}NN{suffix}. The SIGPIPE trap is installed before the pipeline
    # runs (installed after it, as originally written, it could never fire);
    # the POSIX signal name PIPE keeps it portable across sh implementations.
    terashuf_executable = os.path.join(terashuf_dir, "terashuf")
    run_command(
        "trap 'echo \"Caught signal 13, exiting with code 1\"; exit 1' PIPE; "
        f"ulimit -n 100000 && "
        f"find {src_dir} -type f -name '*{orig_extension}' -print0 | "
        f"xargs -0 -I {{}} sh -c '{cat_command}' | {terashuf_executable} | "
        f"split -n r/{nchunks} -d --suffix-length 2 --additional-suffix {suffix} "
        f"- {out_dir}/{prefix}"
    )
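    # With dataset="fineweb_edu_10bt" and the default data_dir, the assembled
    # shell pipeline is roughly (a sketch, not captured output):
    #   find data/fineweb_edu_10bt -type f -name '*.jsonl' -print0 \
    #     | xargs -0 -I {} sh -c 'cat {}' \
    #     | data/fineweb_edu_10bt/terashuf/terashuf \
    #     | split -n r/32 -d --suffix-length 2 --additional-suffix .jsonl \
    #         - data/fineweb_edu_10bt_shuffled/fineweb_edu_10bt.chunk.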

    # Carve out a validation set: move the first k_validation lines of every
    # chunk into a single shared validation file.
    validation_file = f"{out_dir}/{dataset}.val{suffix}"
    for i in range(nchunks):
        chunk_file = f"{out_dir}/{prefix}{i:02d}{suffix}"
        run_command(f"head -n {k_validation} {chunk_file} >> {validation_file}")
        run_command(f"sed -i '1,{k_validation}d' {chunk_file}")

    print("All tasks completed successfully!")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dataset", type=str)
    # nargs="?" lets the default apply when memory is omitted; a required
    # positional would silently ignore default=8. The value is terashuf's
    # memory budget in GB (exported as MEMORY above).
    parser.add_argument("memory", type=float, nargs="?", default=8)
    parser.add_argument("--data_dir", type=str, default="data")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--nchunks", type=int, default=32)

    args = parser.parse_args()

    main(args.dataset, args.memory, args.data_dir, args.seed, args.nchunks)
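# Example invocation (assuming this script is saved as
# download_prepare_hf_data.py):
#   python download_prepare_hf_data.py fineweb_edu_10bt 8 --data_dir ./data --nchunks 32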