import argparse
import json
import os
import tempfile
from pathlib import Path

from datasets import IterableDataset, load_dataset
from tokenizers import SentencePieceBPETokenizer
from transformers import AutoTokenizer, LlamaTokenizerFast


def main(args):
    # Load the raw text corpus: either local text/json/csv files or a dataset hosted on the Hugging Face Hub
    if args.dataset_name is not None:
        if args.dataset_type:
            if os.path.isfile(args.dataset_name):
                data_files = [args.dataset_name]
            else:
                data_files = [str(Path(args.dataset_name) / f) for f in os.listdir(args.dataset_name)]
            print(f"Training on {len(data_files)} files")
            dataset = load_dataset(
                args.dataset_type,
                data_files=data_files,
                split=args.dataset_split,
                token=args.hub_token if args.hub_token else None,
            )
        else:
            dataset = load_dataset(
                args.dataset_name,
                split=args.dataset_split,
                streaming=True,
                token=args.hub_token if args.hub_token else None,
            )
        print(dataset)
    else:
        raise ValueError("No dataset name provided or dataset is already tokenized")

    # Keep only the "text" column that the tokenizer is trained on
    # (column_names can be None for some streaming datasets, hence the guard)
    if dataset.column_names:
        dataset = dataset.remove_columns([col for col in dataset.column_names if col != "text"])

    # Shuffle so that a subsample is representative of the whole corpus
    dataset = dataset.shuffle(seed=args.seed)

    # Optionally restrict training to the first num_samples examples;
    # a streaming IterableDataset has no .select(), so use .take() there
    if args.num_samples:
        if isinstance(dataset, IterableDataset):
            dataset = dataset.take(args.num_samples)
        else:
            dataset = dataset.select(range(args.num_samples))

    # Train a SentencePiece-style BPE tokenizer from scratch on the text column
    tokenizer = SentencePieceBPETokenizer()
    tokenizer.train_from_iterator(
        # A generator works for both map-style and streaming datasets
        iterator=(example["text"] for example in dataset),
        vocab_size=args.vocab_size,
        show_progress=True,
        special_tokens=["<unk>", "<s>", "</s>", "<pad>"],
    )

    # Dump the freshly trained tokenizer to a temporary tokenizer.json file
    new_tokenizer_file = tempfile.NamedTemporaryFile(prefix="tokenizer_", suffix=".json").name
    tokenizer.save(new_tokenizer_file, pretty=True)

    # Download the reference tokenizer whose pipeline (normalizer, pre-tokenizer,
    # post-processor, decoder) will be transplanted onto the new tokenizer
    if args.reference_tokenizer is not None and args.hub_token is not None:
        reference_tokenizer = AutoTokenizer.from_pretrained(
            args.reference_tokenizer, token=args.hub_token if args.hub_token else None
        )
        reference_tokenizer_path = tempfile.TemporaryDirectory().name
        reference_tokenizer.save_pretrained(reference_tokenizer_path)
    else:
        raise ValueError(
            "No reference tokenizer name or hub token provided. "
            "Try using `--reference_tokenizer 'mistralai/Mistral-7B-Instruct-v0.2'`"
        )

    # Load both tokenizers as raw JSON so individual components can be copied over
    with open(new_tokenizer_file) as f:
        new_tokenizer_json = json.load(f)

    with open(Path(reference_tokenizer_path) / "tokenizer.json") as f:
        reference_tokenizer_json = json.load(f)

    # Copy the processing pipeline and BPE model flags from the reference tokenizer
    new_tokenizer_json["normalizer"] = reference_tokenizer_json["normalizer"]
    new_tokenizer_json["pre_tokenizer"] = reference_tokenizer_json["pre_tokenizer"]
    new_tokenizer_json["post_processor"] = reference_tokenizer_json["post_processor"]
    new_tokenizer_json["decoder"] = reference_tokenizer_json["decoder"]
    new_tokenizer_json["model"]["fuse_unk"] = reference_tokenizer_json["model"]["fuse_unk"]
    new_tokenizer_json["model"]["byte_fallback"] = reference_tokenizer_json["model"]["byte_fallback"]

    # Write the merged tokenizer back to the temporary file
    with open(new_tokenizer_file, "w") as f:
        json.dump(new_tokenizer_json, f, indent=2, ensure_ascii=False)

    # Wrap the merged tokenizer.json in a LlamaTokenizerFast with the usual special tokens
    new_llama_tokenizer = LlamaTokenizerFast(
        tokenizer_file=new_tokenizer_file,
        name_or_path=args.reference_tokenizer + "-tokenizer",
        unk_token="<unk>",
        unk_token_id=0,
        bos_token="<s>",
        bos_token_id=1,
        eos_token="</s>",
        eos_token_id=2,
        pad_token="<pad>",
        pad_token_id=3,
        padding_side="right",
    )

    # Save the final tokenizer in Hugging Face format
    new_llama_tokenizer.save_pretrained(args.output)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train a new Llama tokenizer")
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help="The name of the dataset to train the tokenizer on, or a path to local data files",
    )
    parser.add_argument(
        "--dataset_type",
        type=str,
        default=None,
        help="The file type for local data: 'text', 'json', or 'csv'. Leave blank for regular HF datasets",
    )
    parser.add_argument(
        "--dataset_split",
        type=str,
        default=None,
        help="The split of the dataset to train on",
    )
    parser.add_argument(
        "--hub_token",
        type=str,
        default=None,
        help="The token used to access the dataset and reference tokenizer on the hub",
    )
    parser.add_argument(
        "--reference_tokenizer",
        type=str,
        default=None,
        help="The name of the reference tokenizer whose pipeline settings are copied",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=123,
        help="Random seed used to shuffle the dataset",
    )
    parser.add_argument(
        "--num_samples",
        type=int,
        default=None,
        help="Number of samples to use from the dataset",
    )
    parser.add_argument(
        "--vocab_size",
        type=int,
        default=None,
        help="Vocabulary size of the new tokenizer (e.g. 32000)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="./",
        help="Output path for the new tokenizer",
    )
    args = parser.parse_args()
    main(args)
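
# Example usage (a sketch, not part of the original script): the script filename,
# dataset name, and output path below are placeholders; substitute your own corpus,
# Hub token, and directories. Both --reference_tokenizer and --hub_token must be set.
#
#   python train_tokenizer.py \
#       --dataset_name your-org/your-text-corpus \
#       --dataset_split train \
#       --reference_tokenizer mistralai/Mistral-7B-Instruct-v0.2 \
#       --hub_token "$HF_TOKEN" \
#       --vocab_size 32000 \
#       --num_samples 200000 \
#       --output ./new-llama-tokenizer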