from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Iterator

from datasets import load_dataset
from tokenizers import Tokenizer, decoders, models, pre_tokenizers, processors, trainers

ROOT = Path(__file__).resolve().parents[1]
sys.path.append(str(ROOT / "src"))

from sllm.config import DataMixConfig, load_json, save_json
from sllm.utils import setup_logger


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Train a BPE tokenizer for the sLLM pipeline.")
    parser.add_argument("--data-config", required=True, help="Path to data mixture JSON config.")
    parser.add_argument("--output-dir", required=True, help="Directory where tokenizer files will be stored.")
    parser.add_argument("--vocab-size", type=int, default=49_152, help="Target tokenizer vocabulary size.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for dataset shuffling.")
    return parser


def iter_source_texts(source, seed: int, limit: int) -> Iterator[str]:
    """Yield up to ``limit`` non-empty text documents from a single dataset source."""
    dataset = load_dataset(
        path=source.path,
        name=source.config_name,
        data_dir=source.data_dir,
        split=source.split,
        revision=source.revision,
        streaming=source.streaming,
    )
    if source.streaming:
        dataset = dataset.shuffle(seed=seed, buffer_size=source.shuffle_buffer)
    yielded = 0
    for row in dataset:
        text = row.get(source.text_field or "", None)
        if not isinstance(text, str):
            continue
        text = text.strip()
        if not text:
            continue
        yield text
        yielded += 1
        if yielded >= limit:
            return


def mixed_iterator(config: DataMixConfig, seed: int, logger) -> Iterator[str]:
    """Yield documents from every configured source, sized by the normalized mixture weights."""
    weight_map = config.normalized_weights()
    total_docs = config.tokenizer_sample_documents
    per_source = {
        source.name: max(1, int(total_docs * weight_map[source.name]))
        for source in config.sources
    }
    for index, source in enumerate(config.sources):
        limit = source.sample_documents or per_source[source.name]
        logger.info(
            "Tokenizer source start | name=%s path=%s data_dir=%s split=%s text_field=%s limit_docs=%s streaming=%s",
            source.name,
            source.path,
            source.data_dir,
            source.split,
            source.text_field,
            f"{limit:,}",
            source.streaming,
        )
        yield from iter_source_texts(source, seed + index, limit)


def main() -> None:
    args = build_parser().parse_args()
    data_config = DataMixConfig.from_dict(load_json(args.data_config))
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    logger, log_path = setup_logger("sllm.train_tokenizer", output_dir, "train_tokenizer")
    logger.info("Tokenizer training started")
    logger.info("Log file: %s", log_path)
    logger.info(
        "Arguments | data_config=%s output_dir=%s vocab_size=%s seed=%s",
        args.data_config,
        args.output_dir,
        args.vocab_size,
        args.seed,
    )
    logger.info(
        "Tokenizer config | sample_documents=%s min_frequency=%s special_tokens=%s num_sources=%s",
        f"{data_config.tokenizer_sample_documents:,}",
        data_config.tokenizer_min_frequency,
        data_config.tokenizer_special_tokens,
        len(data_config.sources),
    )

    # NOTE: the literal special-token strings ("<unk>", "<bos>", "<eos>", "<pad>") are
    # assumed here; they must match the entries in data_config.tokenizer_special_tokens.
    tokenizer = Tokenizer(models.BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
    trainer = trainers.BpeTrainer(
        vocab_size=args.vocab_size,
        min_frequency=data_config.tokenizer_min_frequency,
        special_tokens=data_config.tokenizer_special_tokens,
        show_progress=True,
    )
    tokenizer.train_from_iterator(mixed_iterator(data_config, args.seed, logger), trainer=trainer)

    bos_id = tokenizer.token_to_id("<bos>")
    eos_id = tokenizer.token_to_id("<eos>")
    pad_id = tokenizer.token_to_id("<pad>")
    if bos_id is None or eos_id is None or pad_id is None:
        raise RuntimeError("Tokenizer special tokens were not created correctly.")

    # Wrap single sequences and pairs with <bos>/<eos> at encode time.
    tokenizer.post_processor = processors.TemplateProcessing(
        single="<bos> $A <eos>",
        pair="<bos> $A <eos> $B:1 <eos>:1",
        special_tokens=[
            ("<bos>", bos_id),
            ("<eos>", eos_id),
        ],
    )

    tokenizer_path = output_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_path))

    metadata = {
        "vocab_size": tokenizer.get_vocab_size(),
        "special_tokens": {
            "pad_token": "<pad>",
            "bos_token": "<bos>",
            "eos_token": "<eos>",
            "unk_token": "<unk>",
            "pad_token_id": pad_id,
            "bos_token_id": bos_id,
            "eos_token_id": eos_id,
            "unk_token_id": tokenizer.token_to_id("<unk>"),
        },
        "data_config": data_config.to_dict(),
    }
    save_json(output_dir / "tokenizer_meta.json", metadata)
    with (output_dir / "tokenizer_summary.json").open("w", encoding="utf-8") as handle:
        json.dump(metadata, handle, ensure_ascii=False, indent=2)

    logger.info("Tokenizer saved | path=%s", tokenizer_path)
    logger.info(
        "Tokenizer summary | vocab_size=%s pad_id=%s bos_id=%s eos_id=%s unk_id=%s",
        tokenizer.get_vocab_size(),
        pad_id,
        bos_id,
        eos_id,
        tokenizer.token_to_id("<unk>"),
    )
    logger.info("Tokenizer metadata saved | path=%s", output_dir / "tokenizer_meta.json")


if __name__ == "__main__":
    main()
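
# Example invocation (an illustrative sketch; the script location, config path, and
# output directory below are assumptions, not files shipped with this repo):
#
#   python scripts/train_tokenizer.py \
#       --data-config configs/data_mix.json \
#       --output-dir artifacts/tokenizer \
#       --vocab-size 49152
#
# The run writes tokenizer.json, tokenizer_meta.json, and tokenizer_summary.json into
# --output-dir, alongside the train_tokenizer log file. A quick sanity check of the
# saved artifact (paths again assumed):
#
#   >>> from tokenizers import Tokenizer
#   >>> tok = Tokenizer.from_file("artifacts/tokenizer/tokenizer.json")
#   >>> tok.encode("hello world").ids  # wrapped with <bos>/<eos> by the post-processor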