from datasets import concatenate_datasets, load_dataset

# Load the two pretraining corpora: BookCorpus and the 2022-03-01 English
# Wikipedia dump (both from the Hugging Face Hub).
bookcorpus = load_dataset("bookcorpus", split="train")
wiki = load_dataset("wikipedia", "20220301.en", split="train")
# Keep only the "text" column so both datasets share an identical schema
# and can be concatenated row-wise.
wiki = wiki.remove_columns([col for col in wiki.column_names if col != "text"])

# Fail loudly if the schemas diverge. A bare `assert` would be silently
# stripped under `python -O`, so use an explicit check instead.
if bookcorpus.features.type != wiki.features.type:
    raise ValueError(
        "Cannot concatenate: bookcorpus and wikipedia column types differ "
        f"({bookcorpus.features.type} vs {wiki.features.type})"
    )
raw_datasets = concatenate_datasets([bookcorpus, wiki])
print(raw_datasets)
from transformers import AutoTokenizer
import multiprocessing

# Load the tokenizer — presumably a locally saved tokenizer directory
# produced earlier in this pipeline (TODO confirm the "cat_tokenizer" path).
tokenizer = AutoTokenizer.from_pretrained("cat_tokenizer")
# Cap parallelism at 8 workers so small machines are not oversubscribed.
num_proc = min(8, multiprocessing.cpu_count())
print(f"The max length for the tokenizer is: {tokenizer.model_max_length}")
def group_texts(examples):
    """Tokenize a batch of raw text examples, truncated to the model max length.

    Returns the tokenizer's encoding (including the special-tokens mask,
    which the MLM data collator uses downstream).

    NOTE(review): a second function with this same name is defined later in
    the script and rebinds it after this one has run; consider renaming this
    one to `tokenize_function` for clarity.
    """
    return tokenizer(
        examples["text"],
        return_special_tokens_mask=True,
        truncation=True,
        max_length=tokenizer.model_max_length,
    )
| | |
| | tokenized_datasets = raw_datasets.map(group_texts, batched=True, remove_columns=["text"], num_proc=num_proc) |
| | print(tokenized_datasets.features) |
| |
|
| |
|
| | |
| | from itertools import chain |
| |
|
| | |
| | |
| | def group_texts(examples): |
| | |
| | concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} |
| | total_length = len(concatenated_examples[list(examples.keys())[0]]) |
| | |
| | |
| | if total_length >= tokenizer.model_max_length: |
| | total_length = (total_length // tokenizer.model_max_length) * tokenizer.model_max_length |
| | |
| | result = { |
| | k: [t[i : i + tokenizer.model_max_length] for i in range(0, total_length, tokenizer.model_max_length)] |
| | for k, t in concatenated_examples.items() |
| | } |
| | return result |
| |
|
# Regroup the tokenized corpus into fixed-length blocks, in parallel.
tokenized_datasets = tokenized_datasets.map(group_texts, batched=True, num_proc=num_proc)
# Shuffle with a fixed seed so reruns produce the same ordering.
tokenized_datasets = tokenized_datasets.shuffle(seed=34)

print(tokenized_datasets)
# NOTE(review): this count assumes every row holds exactly
# tokenizer.model_max_length tokens; the final (possibly shorter) block
# makes it a slight overestimate.
print(f"the dataset contains in total {len(tokenized_datasets)*tokenizer.model_max_length} tokens")
# Hub namespace to publish under.
# NOTE(review): hard-coded user id — consider deriving it from the
# authenticated account instead.
user_id = "chaoyan"

# Upload the processed dataset to the Hugging Face Hub.
dataset_id = f"{user_id}/processed_bert_dataset"
tokenized_datasets.push_to_hub(dataset_id)