Text Generation
Safetensors
Danish
English
llama
dfm-decoder-open-v0-7b-pt / create_dataset.py
peter-sk's picture
Super-squash branch 'main' using huggingface_hub
255c557
#!/usr/bin/env python
import importlib
import importlib.util

import datasets
import tqdm
import transformers
import typer
def load_config(config_file: str):
    """Execute a Python config file and return ``(sources, tokenizer_name, prefix)``.

    The file is imported as a throw-away module named ``config`` (its code is
    executed, so it must be trusted) and must define the module-level names
    ``sources``, ``tokenizer_name`` and ``prefix``.

    Args:
        config_file: Path to the Python config file.

    Returns:
        Tuple of the config module's ``sources``, ``tokenizer_name`` and ``prefix``.

    Raises:
        ImportError: if no module loader is available for ``config_file``
            (e.g. the path does not end in a Python source suffix).
    """
    spec = importlib.util.spec_from_file_location("config", config_file)
    if spec is None or spec.loader is None:
        # spec_from_file_location returns None for unrecognised suffixes; fail with
        # a clear message instead of an AttributeError on None further down.
        raise ImportError(f"Cannot load config file {config_file!r}: no Python module loader for it")
    config_module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(config_module)
    return config_module.sources, config_module.tokenizer_name, config_module.prefix
def tokenize(batch: dict):
    """Count tokens for every text of a batch passed by a batched ``Dataset.map``.

    Reads the module-level ``tokenizer`` that ``main`` installs via ``global``;
    it must be assigned before this function runs.

    Args:
        batch: Mapping whose ``"text"`` entry is the list of texts in the batch.

    Returns:
        ``{"num_tokens": [...]}`` with exactly one count per input text. Every
        count is 0 when no tokenizer is configured (``tokenizer is None``).
    """
    texts = batch["text"]
    if tokenizer is None:
        # A batched map must return one value per row; a bare scalar is rejected.
        return {"num_tokens": [0] * len(texts)}
    lengths = tokenizer(texts, padding="do_not_pad", return_length=True)["length"]
    return {"num_tokens": lengths}
def shard_indices(shard_index):
    """Normalise a ``shard_index`` config value to a list of shard indices.

    A list is returned as-is (the very same object); any other value is wrapped
    in a one-element list.
    """
    return shard_index if isinstance(shard_index, list) else [shard_index]
def preprocess_shard(ds: datasets.Dataset, num_shards: int, index: int, num_proc: int):
    """Take contiguous shard ``index`` of ``num_shards`` from ``ds`` and tokenize it.

    The shard's indices mapping is flattened into a fresh table before mapping
    ``tokenize`` over it in batches of 1000 rows with ``num_proc`` workers, which
    adds the ``num_tokens`` column.
    """
    piece = ds.shard(num_shards=num_shards, index=index, contiguous=True).flatten_indices()
    return piece.map(tokenize, batched=True, batch_size=1000, num_proc=num_proc)
def preprocess_subset(weights: dict, subsets: list, source: str, src_info: dict, dc: datasets.DownloadConfig, num_proc: int):
    """Load, tag, shard and tokenize every subset listed in ``weights``; append to ``subsets``.

    For each ``key -> frac`` pair, the data file(s) at ``src_info["uri"]``
    (formatted with ``key``) are loaded with the ``src_info["format"]`` builder.
    The result is then processed as follows:

    - reduced to the ``text`` column;
    - tagged with constant ``source`` / ``subset`` columns;
    - shuffled with seed 42;
    - split into ``int(src_info["shards"] / frac)`` contiguous shards, of which
      the shard(s) named by ``src_info["shard_index"]`` are kept and tokenized.

    Args:
        weights: Mapping of subset key to its ``frac`` value.
        subsets: Output list; one dataset per key is appended (mutated in place).
        source: Source name written into the ``source`` column.
        src_info: Source config; reads ``"uri"``, ``"format"``, ``"shards"``
            and ``"shard_index"``.
        dc: Download configuration passed to ``datasets.load_dataset``.
        num_proc: Number of worker processes for tokenization.

    Raises:
        ValueError: if ``src_info["shards"] / frac`` yields fewer than one shard.
    """
    # Invariant across subsets: hoisted out of the loop.
    uri_template = src_info["uri"]
    indices = shard_indices(src_info["shard_index"])
    # This function is called for both train and test weights, so the label names the source.
    for key, frac in tqdm.tqdm(weights.items(), desc=f"Loading subsets of {source}"):
        data_files = uri_template.format(key=key)
        # tqdm.write keeps the progress bar intact, unlike a bare print.
        tqdm.tqdm.write(f" Loading subset: {key} with fraction 1/{frac} from {data_files}")
        # A plain data_files string is exposed by load_dataset as its "train" split,
        # independent of whether this subset feeds the train or test output.
        ds = datasets.load_dataset(
            src_info["format"],
            data_files=data_files,
            split="train",
            download_config=dc,
        )
        ds = ds.select_columns(["text"])
        ds = ds.add_column("source", [source] * len(ds))
        ds = ds.add_column("subset", [key] * len(ds))
        ds = ds.shuffle(seed=42)
        num_shards = int(src_info["shards"] / frac)
        if num_shards < 1:
            raise ValueError(
                f"Subset {key!r} of source {source!r}: shards={src_info['shards']} / frac={frac} "
                f"gives {num_shards} shards; at least 1 is required"
            )
        shards = [preprocess_shard(ds, num_shards, i, num_proc) for i in indices]
        ds = datasets.concatenate_datasets(shards)
        ds = ds.cast_column("text", datasets.Value("large_string"))
        tqdm.tqdm.write(f" Finished preprocessing subset: {key} with {sum(ds['num_tokens'])} tokens")
        subsets.append(ds)
def _concat_shuffle_flatten(subsets: list, split: str) -> datasets.Dataset:
    """Concatenate one split's subsets, shuffle with seed 42 and flatten the indices mapping."""
    if not subsets:
        raise ValueError(f"No {split} subsets were produced; check the sources in the config file")
    print(f"Concatenating {split} subsets")
    combined = datasets.concatenate_datasets(subsets)
    print(f"Shuffling final {split} dataset")
    combined = combined.shuffle(seed=42)
    print(f"Flattening final {split} dataset")
    return combined.flatten_indices()


def _write_split(ds: datasets.Dataset, file_name: str, split: str) -> None:
    """Write ``ds`` to ``<file_name><split>/<file_name><split>.parquet``, creating the directory."""
    from pathlib import Path

    out_file = f"{file_name}{split}/{file_name}{split}.parquet"
    # to_parquet does not create missing parent directories on the local filesystem.
    Path(out_file).parent.mkdir(parents=True, exist_ok=True)
    print(f"Writing final {split} dataset with {sum(ds['num_tokens'])} tokens to {out_file}")
    ds.to_parquet(out_file)


def main(
    config_file: str,
    num_proc: int = 96,
    max_retries: int = 10,
):
    """Build shuffled, token-counted train/test parquet files from a Python config file.

    The config's ``sources`` maps each source name to a dict with ``"uri"``,
    ``"format"``, ``"shards"``, ``"shard_index"``, ``"train"`` and ``"test"``
    entries. Outputs are written to ``<name>test/<name>test.parquet`` and
    ``<name>train/<name>train.parquet``. ``<name>`` is built from ``prefix``
    and, for each source, its name, shard indices and shard count.

    Args:
        config_file: Path to the Python config file.
        num_proc: Worker processes for downloading and tokenization.
        max_retries: Download retry count passed to ``datasets.DownloadConfig``.
    """
    sources, tokenizer_name, prefix = load_config(config_file)
    # tokenize() reads this module-level name, so it is published as a global.
    global tokenizer
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_name) if tokenizer_name else None
    dc = datasets.DownloadConfig(num_proc=num_proc, max_retries=max_retries)
    train_subsets = []
    test_subsets = []
    name_parts = [f"{prefix}-"]
    for source, src_info in sources.items():
        print(f"Processing source: {source}")
        indices = shard_indices(src_info["shard_index"])
        name_parts.append(f"{source}-{'_'.join(str(s) for s in indices)}-of-{src_info['shards']}-")
        preprocess_subset(src_info["train"], train_subsets, source, src_info, dc, num_proc)
        preprocess_subset(src_info["test"], test_subsets, source, src_info, dc, num_proc)
    file_name = "".join(name_parts)
    final_train = _concat_shuffle_flatten(train_subsets, "train")
    final_test = _concat_shuffle_flatten(test_subsets, "test")
    # The test split is written before the train split.
    _write_split(final_test, file_name, "test")
    _write_split(final_train, file_name, "train")


if __name__ == "__main__":
    typer.run(main)