from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterator
from datasets import load_dataset
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))
from sllm.config import DataMixConfig, load_json, save_json
from sllm.utils import setup_logger
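# The data mixture config (--data-config) is loaded through sllm.config.DataMixConfig; this
# script only relies on it exposing `sources` (each read for name, path, config_name, data_dir,
# split, revision, streaming, shuffle_buffer, text_field, sample_documents), `normalized_weights()`,
# `tokenizer_sample_documents`, `tokenizer_min_frequency`, and `tokenizer_special_tokens`.
# See sllm.config for the authoritative schema.
# Example invocation (paths are illustrative, not prescribed by this script):
#   python scripts/train_tokenizer.py --data-config configs/data_mix.json --output-dir artifacts/tokenizer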
def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a BPE tokenizer for the sLLM pipeline.")
    parser.add_argument("--data-config", required=True, help="Path to data mixture JSON config.")
    parser.add_argument("--output-dir", required=True, help="Directory where tokenizer files will be stored.")
    parser.add_argument("--vocab-size", type=int, default=49_152, help="Target tokenizer vocabulary size.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
    return parser
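# iter_source_texts loads a single source with `datasets.load_dataset`, shuffles it with a
# bounded buffer when the source is streamed (non-streaming sources are read in order), and
# yields up to `limit` non-empty strings taken from the source's configured text field.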
def iter_source_texts(source, seed: int, limit: int) -> Iterator[str]:
    dataset = load_dataset(
        path=source.path,
        name=source.config_name,
        data_dir=source.data_dir,
        split=source.split,
        revision=source.revision,
        streaming=source.streaming,
    )
    if source.streaming:
        dataset = dataset.shuffle(seed=seed, buffer_size=source.shuffle_buffer)
    yielded = 0
    for row in dataset:
        text = row.get(source.text_field or "", None)
        if not isinstance(text, str):
            continue
        text = text.strip()
        if not text:
            continue
        yield text
        yielded += 1
        if yielded >= limit:
            return
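# mixed_iterator splits the total tokenizer sample budget across sources in proportion to
# their normalized mixture weights (at least one document per source), allows an explicit
# per-source `sample_documents` override, and chains the per-source iterators together.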
def mixed_iterator(config: DataMixConfig, seed: int, logger) -> Iterator[str]:
    weight_map = config.normalized_weights()
    total_docs = config.tokenizer_sample_documents
    per_source = {
        source.name: max(1, int(total_docs * weight_map[source.name]))
        for source in config.sources
    }
    for index, source in enumerate(config.sources):
        limit = source.sample_documents or per_source[source.name]
        logger.info(
            "Tokenizer source start | name=%s path=%s data_dir=%s split=%s text_field=%s limit_docs=%s streaming=%s",
            source.name,
            source.path,
            source.data_dir,
            source.split,
            source.text_field,
            f"{limit:,}",
            source.streaming,
        )
        yield from iter_source_texts(source, seed + index, limit)
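# main(): load the data-mix config, set up logging, train a byte-level BPE tokenizer on the
# mixed sample stream, attach a <bos>/<eos> post-processor, and write tokenizer.json plus
# metadata/summary JSON files to the output directory.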
def main() -> None:
    args = build_parser().parse_args()
    data_config = DataMixConfig.from_dict(load_json(args.data_config))
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger, log_path = setup_logger("sllm.train_tokenizer", output_dir, "train_tokenizer")
    logger.info("Tokenizer training started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | data_config=%s output_dir=%s vocab_size=%s seed=%s",
        args.data_config,
        args.output_dir,
        args.vocab_size,
        args.seed,
    )
    logger.info(
        "Tokenizer config | sample_documents=%s min_frequency=%s special_tokens=%s num_sources=%s",
        f"{data_config.tokenizer_sample_documents:,}",
        data_config.tokenizer_min_frequency,
        data_config.tokenizer_special_tokens,
        len(data_config.sources),
    )

    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=args.vocab_size,
        min_frequency=data_config.tokenizer_min_frequency,
        special_tokens=data_config.tokenizer_special_tokens,
        show_progress=True,
    )
    tokenizer.train_from_iterator(mixed_iterator(data_config, args.seed, logger), trainer=trainer)

    bos_id = tokenizer.token_to_id("<bos>")
    eos_id = tokenizer.token_to_id("<eos>")
    pad_id = tokenizer.token_to_id("<pad>")
    if bos_id is None or eos_id is None or pad_id is None:
        raise RuntimeError("Tokenizer special tokens were not created correctly.")
    tokenizer.post_processor = processors.TemplateProcessing(
        single="<bos> $A <eos>",
        pair="<bos> $A <eos> $B:1 <eos>:1",
        special_tokens=[
            ("<bos>", bos_id),
            ("<eos>", eos_id),
        ],
    )

    tokenizer_path = output_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_path))
    metadata = {
        "vocab_size": tokenizer.get_vocab_size(),
        "special_tokens": {
            "pad_token": "<pad>",
            "bos_token": "<bos>",
            "eos_token": "<eos>",
            "unk_token": "<unk>",
            "pad_token_id": pad_id,
            "bos_token_id": bos_id,
            "eos_token_id": eos_id,
            "unk_token_id": tokenizer.token_to_id("<unk>"),
        },
        "data_config": data_config.to_dict(),
    }
    save_json(output_dir / "tokenizer_meta.json", metadata)
    with (output_dir / "tokenizer_summary.json").open("w", encoding="utf-8") as handle:
        json.dump(metadata, handle, ensure_ascii=False, indent=2)

    logger.info("Tokenizer saved | path=%s", tokenizer_path)
    logger.info(
        "Tokenizer summary | vocab_size=%s pad_id=%s bos_id=%s eos_id=%s unk_id=%s",
        tokenizer.get_vocab_size(),
        pad_id,
        bos_id,
        eos_id,
        tokenizer.token_to_id("<unk>"),
    )
    logger.info("Tokenizer metadata saved | path=%s", output_dir / "tokenizer_meta.json")
if __name__ == "__main__":
    main()