import importlib
import importlib.util

import datasets
import tqdm
import transformers
import typer
|
def load_config(config_file: str):
    """Execute a Python config file and pull out its required settings.

    The config file is loaded as a throwaway module named ``config`` and must
    define three module-level names: ``sources`` (dict of source -> info),
    ``tokenizer_name`` (str or None), and ``prefix`` (str used in output paths).

    Args:
        config_file: Filesystem path to the Python config file.

    Returns:
        Tuple of (sources, tokenizer_name, prefix).

    Raises:
        AttributeError: If the config file does not define all three names.
    """
    # NOTE: importlib.util must be imported explicitly (see file imports);
    # bare `import importlib` does not reliably bind the `util` submodule.
    spec = importlib.util.spec_from_file_location("config", config_file)
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.sources, config_module.tokenizer_name, config_module.prefix
| |
|
def tokenize(batch: dict) -> dict:
    """Compute per-example token counts for a batch of texts.

    Designed for ``Dataset.map(..., batched=True)``: reads the module-global
    ``tokenizer`` (set in ``main``) and returns a ``num_tokens`` column with
    exactly one entry per input text.

    Args:
        batch: Mapped batch dict; must contain a ``"text"`` list of strings.

    Returns:
        Dict with a ``"num_tokens"`` list, same length as ``batch["text"]``.
    """
    # globals().get avoids a NameError when the module-global tokenizer was
    # never assigned (e.g. the function is exercised outside main()).
    tok = globals().get("tokenizer")
    if tok is not None:
        return {"num_tokens": tok(batch["text"], padding="do_not_pad", return_length=True)["length"]}
    # Bug fix: a batched map requires one value per example; the original
    # returned a scalar 0, which does not match the batch length.
    return {"num_tokens": [0] * len(batch["text"])}
| |
|
def shard_indices(shard_index):
    """Normalize a shard-index spec to a list.

    A bare index (e.g. ``3``) becomes ``[3]``; a list is returned unchanged.
    """
    return shard_index if isinstance(shard_index, list) else [shard_index]
| |
|
def preprocess_shard(ds: datasets.Dataset, num_shards: int, index: int, num_proc: int):
    """Cut one contiguous shard out of ``ds`` and tokenize it.

    The shard is flattened (materialized indices) before the batched
    ``tokenize`` map so the map runs over contiguous storage.
    """
    piece = ds.shard(num_shards=num_shards, index=index, contiguous=True)
    piece = piece.flatten_indices()
    return piece.map(tokenize, batched=True, batch_size=1000, num_proc=num_proc)
| |
|
def preprocess_subset(weights: dict, subsets: list, source: str, src_info: dict, dc: datasets.DownloadConfig, num_proc: int):
    """Load, tag, shuffle, and tokenize every weighted subset of one source.

    For each ``key -> frac`` entry in ``weights``: loads the split, keeps only
    the ``text`` column, tags each row with ``source``/``subset``, shuffles
    deterministically, tokenizes the configured shard indices, and appends
    the resulting dataset to ``subsets`` (mutated in place).
    """
    # Hoist per-source invariants out of the subset loop.
    uri_template = src_info["uri"]
    data_format = src_info["format"]
    total_shards = src_info["shards"]
    indices = shard_indices(src_info["shard_index"])

    for key, frac in tqdm.tqdm(weights.items(), desc="Loading train subsets"):
        uri = uri_template.format(key=key)
        print(f" Loading subset: {key} with fraction 1/{frac} from {uri}")
        raw = datasets.load_dataset(
            data_format,
            data_files=uri,
            split="train",
            download_config=dc,
        )
        raw = raw.select_columns(["text"])
        n_rows = len(raw)
        raw = raw.add_column("source", [source] * n_rows)
        raw = raw.add_column("subset", [key] * n_rows)
        raw = raw.shuffle(seed=42)
        # Larger frac => more effective shards => a smaller slice is kept.
        num_shards = int(total_shards / frac)
        pieces = [preprocess_shard(raw, num_shards, i, num_proc) for i in indices]
        merged = datasets.concatenate_datasets(pieces)
        merged = merged.cast_column("text", datasets.Value("large_string"))
        print(f" Finished preprocessing subset: {key} with {sum(merged['num_tokens'])} tokens")
        subsets.append(merged)
| |
|
def _finalize(subsets: list, split: str):
    """Concatenate, shuffle, and flatten one split's subset list."""
    print(f"Concatenating {split} subsets")
    ds = datasets.concatenate_datasets(subsets)
    print(f"Shuffling final {split} dataset")
    ds = ds.shuffle(seed=42)
    print(f"Flattening final {split} dataset")
    return ds.flatten_indices()


def _write(ds, file_name: str, split: str):
    """Write one finalized split to <file_name><split>/<file_name><split>.parquet."""
    path = f"{file_name}{split}/{file_name}{split}.parquet"
    print(f"Writing final {split} dataset with {sum(ds['num_tokens'])} tokens to {path}")
    ds.to_parquet(path)


def main(
    config_file: str,
    num_proc: int = 96,
    max_retries: int = 10,
):
    """Build shuffled train/test parquet files from the sources in a config file.

    Args:
        config_file: Path to a Python config defining ``sources``,
            ``tokenizer_name``, and ``prefix`` (see ``load_config``).
        num_proc: Parallelism for downloads and tokenization maps.
        max_retries: Download retry budget passed to DownloadConfig.
    """
    sources, tokenizer_name, prefix = load_config(config_file)
    # tokenize() reads this module-global; None disables token counting.
    global tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) if tokenizer_name else None
    dc = datasets.DownloadConfig(num_proc=num_proc, max_retries=max_retries)
    train_subsets = []
    test_subsets = []
    # Output prefix encodes every source's shard selection, e.g.
    # "<prefix>-<source>-0_1-of-8-".
    file_name = f"{prefix}-"
    for source, src_info in sources.items():
        print(f"Processing source: {source}")
        # Consistency fix: reuse the shard_indices helper instead of
        # duplicating its isinstance/list normalization inline.
        indices = shard_indices(src_info["shard_index"])
        file_name += f"{source}-{'_'.join(str(s) for s in indices)}-of-{src_info['shards']}-"
        preprocess_subset(src_info["train"], train_subsets, source, src_info, dc, num_proc)
        preprocess_subset(src_info["test"], test_subsets, source, src_info, dc, num_proc)
    final_train = _finalize(train_subsets, "train")
    final_test = _finalize(test_subsets, "test")
    # Preserve original write order: test first, then train.
    _write(final_test, file_name, "test")
    _write(final_train, file_name, "train")
| |
|
# CLI entry point: typer derives the interface from main's signature
# (config_file positional; --num-proc / --max-retries options).
if __name__ == "__main__":
    typer.run(main)
| |
|