"""SentencePiece tokenizer training for SAGE.""" from __future__ import annotations import argparse import json from pathlib import Path from typing import Iterable, Iterator DEFAULT_SPECIAL_TOKENS = ("", "", "", "", "[INST]", "[/INST]") def iter_training_text(corpus_paths: Iterable[str], text_key: str = "text") -> Iterator[str]: """Yield training lines from plain-text or JSONL corpus files.""" for path in corpus_paths: source = Path(path) suffix = source.suffix.lower() with source.open("r", encoding="utf-8") as handle: if suffix == ".jsonl": for raw_line in handle: raw_line = raw_line.strip() if not raw_line: continue payload = json.loads(raw_line) text = payload.get(text_key) if isinstance(text, str) and text.strip(): yield text.strip() continue for raw_line in handle: text = raw_line.strip() if text: yield text def write_training_text(corpus_paths: Iterable[str], output_path: str, text_key: str = "text") -> str: """Concatenate corpus text into a plain-text file for SentencePiece.""" output = Path(output_path) output.parent.mkdir(parents=True, exist_ok=True) with output.open("w", encoding="utf-8") as sink: for line in iter_training_text(corpus_paths, text_key=text_key): sink.write(line) sink.write("\n") return str(output) def train_sentencepiece(input_path: str, model_prefix: str, vocab_size: int = 50_000) -> None: """Train a byte-fallback SentencePiece BPE model.""" import sentencepiece as spm spm.SentencePieceTrainer.train( input=input_path, model_prefix=model_prefix, model_type="bpe", vocab_size=vocab_size, character_coverage=0.9995, byte_fallback=True, bos_id=0, eos_id=1, pad_id=2, unk_id=3, user_defined_symbols=list(DEFAULT_SPECIAL_TOKENS[4:]), split_digits=False, split_by_unicode_script=False, remove_extra_whitespaces=False, normalization_rule_name="identity", hard_vocab_limit=False, ) def build_argparser() -> argparse.ArgumentParser: """Build the CLI parser.""" parser = argparse.ArgumentParser(description="Train the SAGE SentencePiece tokenizer.") parser.add_argument("--input", nargs="+", required=True, help="Plain-text or JSONL corpus files.") parser.add_argument("--model-prefix", default="tokenizer/tokenizer", help="SentencePiece model prefix.") parser.add_argument("--vocab-size", type=int, default=50_000, help="Tokenizer vocabulary size.") parser.add_argument("--training-text", default="tokenizer/training_corpus.txt", help="Temporary combined text file.") parser.add_argument("--text-key", default="text", help="JSONL field to read when --input contains .jsonl files.") return parser def main() -> None: """Train the tokenizer from CLI arguments.""" args = build_argparser().parse_args() training_text = write_training_text(args.input, args.training_text, text_key=args.text_key) train_sentencepiece(training_text, args.model_prefix, args.vocab_size) if __name__ == "__main__": main()