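"""Train a byte-level BPE tokenizer for the sLLM pipeline.

Streams a weighted sample of documents from each configured source, trains a
Hugging Face ``tokenizers`` BPE model, and writes ``tokenizer.json`` plus
metadata to the output directory.

Example invocation (paths are illustrative, not prescribed by this repo):

    python scripts/train_tokenizer.py \
        --data-config configs/data_mix.json \
        --output-dir artifacts/tokenizer \
        --vocab-size 49152
"""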
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
from typing import Iterator
from datasets import load_dataset
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers
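
# Make the in-repo src package importable before pulling in sllm modules.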
ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))
from sllm.config import DataMixConfig, load_json, save_json
from sllm.utils import setup_logger


def build_parser() -> argparse.ArgumentParser:
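    """Build the command-line argument parser for this script."""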
parser = argparse.ArgumentParser(description="Train a BPE tokenizer for the sLLM pipeline.")
parser.add_argument("--data-config", required=True, help="Path to data mixture JSON config.")
parser.add_argument("--output-dir", required=True, help="Directory where tokenizer files will be stored.")
parser.add_argument("--vocab-size", type=int, default=49_152, help="Target tokenizer vocabulary size.")
parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
return parser


def iter_source_texts(source, seed: int, limit: int) -> Iterator[str]:
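    """Yield up to ``limit`` non-empty documents from a single dataset source.

    Rows that lack the configured text field, or whose text is blank after
    stripping, are skipped and do not count toward the limit.
    """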
dataset = load_dataset(
path=source.path,
name=source.config_name,
data_dir=source.data_dir,
split=source.split,
revision=source.revision,
streaming=source.streaming,
)
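    # Streaming datasets cannot be shuffled globally, so an approximate,
    # buffer-based shuffle is applied; non-streaming sources are read in
    # their stored order.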
if source.streaming:
dataset = dataset.shuffle(seed=seed, buffer_size=source.shuffle_buffer)
yielded = 0
for row in dataset:
        text = row.get(source.text_field or "")
if not isinstance(text, str):
continue
text = text.strip()
if not text:
continue
yield text
yielded += 1
if yielded >= limit:
return


def mixed_iterator(config: DataMixConfig, seed: int, logger) -> Iterator[str]:
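    """Drain each configured source in turn, honoring the mixture weights.

    Each source's document budget is its normalized weight times the total
    tokenizer sample size, unless the source sets an explicit
    ``sample_documents`` override. Each source gets a distinct shuffle seed.
    """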
weight_map = config.normalized_weights()
total_docs = config.tokenizer_sample_documents
per_source = {
source.name: max(1, int(total_docs * weight_map[source.name]))
for source in config.sources
}
for index, source in enumerate(config.sources):
limit = source.sample_documents or per_source[source.name]
logger.info(
"Tokenizer source start | name=%s path=%s data_dir=%s split=%s text_field=%s limit_docs=%s streaming=%s",
source.name,
source.path,
source.data_dir,
source.split,
source.text_field,
f"{limit:,}",
source.streaming,
)
yield from iter_source_texts(source, seed + index, limit)


def main() -> None:
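    """Entry point: train the tokenizer and persist artifacts and metadata."""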
args = build_parser().parse_args()
data_config = DataMixConfig.from_dict(load_json(args.data_config))
output_dir = Path(args.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
logger, log_path = setup_logger("sllm.train_tokenizer", output_dir, "train_tokenizer")
logger.info("Tokenizer training started")
logger.info("Log file: %s", log_path)
logger.info("Arguments | data_config=%s output_dir=%s vocab_size=%s seed=%s", args.data_config, args.output_dir, args.vocab_size, args.seed)
logger.info("Tokenizer config | sample_documents=%s min_frequency=%s special_tokens=%s num_sources=%s", f"{data_config.tokenizer_sample_documents:,}", data_config.tokenizer_min_frequency, data_config.tokenizer_special_tokens, len(data_config.sources))
tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
tokenizer.decoder = decoders.ByteLevel()
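    # Note: byte-level BPE setups often also pass
    # initial_alphabet=pre_tokenizers.ByteLevel.alphabet() so that every byte
    # gets a token even if it never occurs in the sample; this script relies
    # on the sampled corpus to cover the alphabet instead.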
trainer = trainers.BpeTrainer(
vocab_size=args.vocab_size,
min_frequency=data_config.tokenizer_min_frequency,
special_tokens=data_config.tokenizer_special_tokens,
show_progress=True,
)
tokenizer.train_from_iterator(mixed_iterator(data_config, args.seed, logger), trainer=trainer)
bos_id = tokenizer.token_to_id("<bos>")
eos_id = tokenizer.token_to_id("<eos>")
pad_id = tokenizer.token_to_id("<pad>")
if bos_id is None or eos_id is None or pad_id is None:
raise RuntimeError("Tokenizer special tokens were not created correctly.")
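    # Attach <bos>/<eos> automatically at encode time: single sequences become
    # "<bos> $A <eos>", and sequence pairs append "$B <eos>" with type id 1.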
tokenizer.post_processor = processors.TemplateProcessing(
single="<bos> $A <eos>",
pair="<bos> $A <eos> $B:1 <eos>:1",
special_tokens=[
("<bos>", bos_id),
("<eos>", eos_id),
],
)
tokenizer_path = output_dir / "tokenizer.json"
tokenizer.save(str(tokenizer_path))
metadata = {
"vocab_size": tokenizer.get_vocab_size(),
"special_tokens": {
"pad_token": "<pad>",
"bos_token": "<bos>",
"eos_token": "<eos>",
"unk_token": "<unk>",
"pad_token_id": pad_id,
"bos_token_id": bos_id,
"eos_token_id": eos_id,
"unk_token_id": tokenizer.token_to_id("<unk>"),
},
"data_config": data_config.to_dict(),
}
save_json(output_dir / "tokenizer_meta.json", metadata)
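    # tokenizer_summary.json duplicates the metadata as plain JSON, written
    # without the project helper (presumably for external consumers).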
with (output_dir / "tokenizer_summary.json").open("w", encoding="utf-8") as handle:
json.dump(metadata, handle, ensure_ascii=False, indent=2)
logger.info("Tokenizer saved | path=%s", tokenizer_path)
logger.info(
"Tokenizer summary | vocab_size=%s pad_id=%s bos_id=%s eos_id=%s unk_id=%s",
tokenizer.get_vocab_size(),
pad_id,
bos_id,
eos_id,
tokenizer.token_to_id("<unk>"),
)
logger.info("Tokenizer metadata saved | path=%s", output_dir / "tokenizer_meta.json")


if __name__ == "__main__":
main()