#!/usr/bin/env python3
"""Train a Lumëa BPE tokenizer from one or several Parquet columns.

The script streams rows from a Parquet file, joins the requested columns of
each row into a single training text, trains a ``tokenizers`` BPE model and
writes a tokenizer folder (tokenizer.json, vocab.json, merges.txt,
tokenizer_config.json, special_tokens_map.json, tokenizer_info.json,
generation_config.json, vocab_preview.txt).
"""
import argparse
import json
import os
import shutil
import unicodedata
from pathlib import Path
from typing import Iterable, Iterator, List, Optional

import pyarrow.parquet as pq
from tokenizers import Tokenizer
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.decoders import Metaspace as MetaspaceDecoder
from tokenizers.models import BPE
from tokenizers.normalizers import NFC, Lowercase, Sequence
from tokenizers.pre_tokenizers import ByteLevel, Metaspace
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import BpeTrainer
from tqdm import tqdm

BASE_SPECIAL_TOKENS = [
    "[PAD]",
    "[UNK]",
    "[SEP]",
    "[CLS]",
    "[MASK]",
]

# Special tokens controlled by the dedicated --bos / --eos flags.
_BOS_EOS_TOKENS = ("[BOS]", "[EOS]")


def parse_args():
    """Parse and validate the command line.

    Returns:
        argparse.Namespace. ``args.column`` is a non-empty list of stripped,
        de-duplicated column names (defaults to ``["text"]``).

    Raises:
        ValueError: if every ``--column`` value is blank.
        SystemExit: via ``parser.error`` when a numeric option is out of range.
    """
    parser = argparse.ArgumentParser(
        description="Train a Lumëa BPE tokenizer from one or several Parquet columns."
    )
    parser.add_argument("--src", required=True)
    parser.add_argument(
        "--column",
        action="append",
        default=None,
        help="Parquet column to train on. Can be repeated: --column fr --column completion",
    )
    parser.add_argument("--out", default="tokenizer_lumea")
    parser.add_argument("--vocab-size", type=int, default=32768)
    parser.add_argument("--min-frequency", type=int, default=2)
    parser.add_argument("--batch-size", type=int, default=100_000)
    parser.add_argument("--lowercase", action="store_true")
    parser.add_argument("--limit-rows", type=int, default=0)
    parser.add_argument("--no-bytelevel", action="store_true")
    parser.add_argument("--model-max-length", type=int, default=4096)
    parser.add_argument("--overwrite", action="store_true")
    parser.add_argument(
        "--bos",
        action="store_true",
        help="Add [BOS] and use it in post-processing.",
    )
    parser.add_argument(
        "--eos",
        action="store_true",
        help="Add [EOS] and use it in post-processing.",
    )
    parser.add_argument(
        "--add-special-token",
        action="append",
        default=[],
        help="Add custom special token. Example: --add-special-token boc creates [BOC]. Can be repeated.",
    )
    parser.add_argument(
        "--column-separator",
        default="\n",
        help="Separator used when several columns are joined.",
    )

    args = parser.parse_args()

    # Reject impossible sizes up front; pyarrow / tokenizers would otherwise
    # fail much later (after reading data) with far less helpful errors.
    for option, value in (
        ("--vocab-size", args.vocab_size),
        ("--batch-size", args.batch_size),
    ):
        if value <= 0:
            parser.error(f"{option} must be > 0 (got {value}).")
    if args.min_frequency < 0:
        parser.error(f"--min-frequency must be >= 0 (got {args.min_frequency}).")

    if args.column is None:
        args.column = ["text"]

    # Strip names, drop blanks, and de-duplicate while keeping the user's
    # order so each column is read and joined exactly once.
    args.column = list(
        dict.fromkeys(str(col).strip() for col in args.column if str(col).strip())
    )

    if not args.column:
        raise ValueError("At least one --column is required.")

    return args


def canonical_special_token(value: str) -> str:
    """Return *value* as an upper-cased ``[NAME]`` special token.

    ``"boc"`` and ``"[boc]"`` both become ``"[BOC]"``.

    Raises:
        ValueError: if the value (or the text between the brackets) is empty.
    """
    value = str(value).strip()

    if not value:
        raise ValueError("Empty special token is not allowed.")

    if value.startswith("[") and value.endswith("]"):
        inner = value[1:-1].strip()
        if not inner:
            raise ValueError(f"Invalid special token: {value}")
        return "[" + inner.upper() + "]"

    return "[" + value.upper() + "]"


def build_special_tokens(args) -> List[str]:
    """Return the ordered, de-duplicated special-token list for training.

    Order: BASE_SPECIAL_TOKENS, then [BOS] / [EOS] if requested, then the
    canonicalized ``--add-special-token`` values. The first occurrence of a
    token wins, so ids stay stable regardless of duplicates.
    """
    tokens = list(BASE_SPECIAL_TOKENS)

    if args.bos:
        tokens.append("[BOS]")

    if args.eos:
        tokens.append("[EOS]")

    tokens.extend(canonical_special_token(raw) for raw in args.add_special_token)

    return list(dict.fromkeys(tokens))


def normalize_text(text: object, lowercase: bool) -> str:
    """Strip, NFC-normalize and optionally lowercase one cell value.

    Bytes values (binary Parquet columns) are decoded as UTF-8, with invalid
    sequences replaced, instead of being stringified as ``"b'...'"``.
    Any other value is converted with ``str()``.
    """
    if isinstance(text, (bytes, bytearray)):
        text = bytes(text).decode("utf-8", errors="replace")

    text = str(text).strip()
    text = unicodedata.normalize("NFC", text)

    if lowercase:
        text = text.lower()

    return text


def count_parquet_rows(src: str) -> int:
    """Return the row count recorded in the Parquet file footer."""
    pf = pq.ParquetFile(src)
    return int(pf.metadata.num_rows)


def join_row_values(values: Iterable[object], lowercase: bool, separator: str) -> Optional[str]:
    """Normalize the non-null, non-empty cells of one row and join them.

    Returns:
        The joined and stripped text, or ``None`` when no cell contributed any
        text.
    """
    parts = []

    for value in values:
        if value is None:
            continue

        text = normalize_text(
            text=value,
            lowercase=lowercase,
        )

        if text:
            parts.append(text)

    if not parts:
        return None

    return separator.join(parts).strip()


def parquet_text_iterator(
    src: str,
    columns: List[str],
    batch_size: int,
    lowercase: bool,
    limit_rows: int,
    separator: str,
) -> Iterator[List[str]]:
    """Yield lists of training texts read from *src*, one list per Parquet batch.

    Each row's *columns* are joined with :func:`join_row_values`. Rows that
    yield no text are skipped but still count toward *limit_rows*. A
    *limit_rows* <= 0 means "all rows".

    A tqdm bar tracks rows read. It is updated once per batch, not once per
    row: ``set_postfix`` refreshes the display on every call, which dominated
    runtime on large files.
    """
    pf = pq.ParquetFile(src)
    total_rows = int(pf.metadata.num_rows)
    target_rows = min(total_rows, limit_rows) if limit_rows > 0 else total_rows

    rows_seen = 0
    rows_yielded = 0
    batches_seen = 0

    pbar = tqdm(
        total=target_rows,
        unit="rows",
        desc="Training data",
        dynamic_ncols=True,
        mininterval=0.5,
        smoothing=0.05,
    )

    try:
        if target_rows <= 0:
            return

        for batch in pf.iter_batches(
            batch_size=batch_size,
            columns=columns,
        ):
            batches_seen += 1

            # Only convert the rows still needed. This matters for the last
            # batch under --limit-rows.
            take = min(batch.num_rows, target_rows - rows_seen)
            if take < batch.num_rows:
                batch = batch.slice(0, take)

            # Look columns up by name so the mapping does not depend on the
            # column order of the returned batch.
            column_values = [batch.column(name).to_pylist() for name in columns]

            texts = []
            for row_values in zip(*column_values):
                text = join_row_values(
                    values=row_values,
                    lowercase=lowercase,
                    separator=separator,
                )
                if text:
                    texts.append(text)

            rows_seen += take
            rows_yielded += len(texts)

            pbar.update(take)
            pbar.set_postfix(
                {
                    "batch": batches_seen,
                    "yielded": f"{rows_yielded:,}",
                }
            )

            if texts:
                yield texts

            if rows_seen >= target_rows:
                break

    finally:
        pbar.close()


def _check_output_dir(out_dir: Path, overwrite: bool) -> bool:
    """Return True if *out_dir* exists and may be replaced, False if it is free.

    A dangling symlink also counts as existing.

    Raises:
        FileExistsError: if the path exists and *overwrite* is False.
    """
    if not (out_dir.exists() or out_dir.is_symlink()):
        return False

    if not overwrite:
        raise FileExistsError(
            f"Output folder already exists: {out_dir}. Use --overwrite."
        )

    return True


def remove_output_dir(out_dir: Path, overwrite: bool):
    """Delete an existing *out_dir* when *overwrite* is set.

    Raises:
        FileExistsError: if the path exists and *overwrite* is False.
    """
    if not _check_output_dir(out_dir, overwrite):
        return

    # rmtree refuses symlinks and plain files, so unlink those instead.
    if out_dir.is_dir() and not out_dir.is_symlink():
        shutil.rmtree(out_dir)
    else:
        out_dir.unlink()


def save_json(path: Path, payload):
    """Write *payload* as pretty-printed UTF-8 JSON followed by a trailing newline."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(
            payload,
            f,
            ensure_ascii=False,
            indent=2,
        )
        f.write("\n")


def require_token_id(tokenizer: Tokenizer, token: str) -> int:
    """Return the id of *token*.

    Raises:
        RuntimeError: if the token is not in the vocabulary.
    """
    token_id = tokenizer.token_to_id(token)

    if token_id is None:
        raise RuntimeError(f"Token missing after training: {token}")

    return int(token_id)


def optional_token_id(tokenizer: Tokenizer, token: str) -> Optional[int]:
    """Return the id of *token*, or None if it is not in the vocabulary."""
    token_id = tokenizer.token_to_id(token)

    if token_id is None:
        return None

    return int(token_id)


def save_vocab_preview(tokenizer: Tokenizer, out_dir: Path, max_items: int = 512):
    """Write ``vocab_preview.txt`` listing ``id<TAB>repr(token)`` for the first ids.

    At most *max_items* ids are written. An id with no token is shown as ``''``.
    """
    vocab = tokenizer.get_vocab()
    inv_vocab = {idx: token for token, idx in vocab.items()}

    with open(out_dir / "vocab_preview.txt", "w", encoding="utf-8") as f:
        for idx in range(min(max_items, tokenizer.get_vocab_size())):
            token = inv_vocab.get(idx, "")
            f.write(f"{idx}\t{repr(token)}\n")


def build_template_processing(tokenizer: Tokenizer, use_bos: bool, use_eos: bool):
    """Build a TemplateProcessing that wraps sequences in [BOS]/[EOS].

    Returns:
        None when neither flag is set; otherwise a ``TemplateProcessing``
        whose templates use only the requested tokens.

    Raises:
        RuntimeError: if a requested token is missing from the vocabulary.
    """
    if not use_bos and not use_eos:
        return None

    special_tokens = []

    if use_bos:
        special_tokens.append(("[BOS]", require_token_id(tokenizer, "[BOS]")))

    if use_eos:
        special_tokens.append(("[EOS]", require_token_id(tokenizer, "[EOS]")))

    if use_bos and use_eos:
        single = "[BOS] $A [EOS]"
        pair = "[BOS] $A [EOS] $B:1 [EOS]:1"
    elif use_bos:
        single = "[BOS] $A"
        pair = "[BOS] $A $B:1"
    else:
        single = "$A [EOS]"
        pair = "$A [EOS] $B:1 [EOS]:1"

    return TemplateProcessing(
        single=single,
        pair=pair,
        special_tokens=special_tokens,
    )


def save_hf_tokenizer_files(
    tokenizer: Tokenizer,
    out_dir: Path,
    requested_vocab_size: int,
    lowercase: bool,
    model_max_length: int,
    bytelevel: bool,
    special_tokens: List[str],
    use_bos: bool,
    use_eos: bool,
    source_columns: List[str],
    column_separator: str,
):
    """Write the trained tokenizer and its config/metadata files into *out_dir*.

    Files written: tokenizer.json, the model files from ``tokenizer.model.save``
    (the messages below report them as vocab.json and merges.txt),
    tokenizer_config.json, special_tokens_map.json, tokenizer_info.json,
    generation_config.json and vocab_preview.txt.

    Raises:
        RuntimeError: if any token in *special_tokens*, or a base special
        token, is missing from the vocabulary.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer_json_path = out_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_json_path))
    tokenizer.model.save(str(out_dir))

    pad_id = require_token_id(tokenizer, "[PAD]")
    unk_id = require_token_id(tokenizer, "[UNK]")
    sep_id = require_token_id(tokenizer, "[SEP]")
    cls_id = require_token_id(tokenizer, "[CLS]")
    mask_id = require_token_id(tokenizer, "[MASK]")

    bos_id = optional_token_id(tokenizer, "[BOS]")
    eos_id = optional_token_id(tokenizer, "[EOS]")

    special_token_ids = {
        token: require_token_id(tokenizer, token)
        for token in special_tokens
    }

    tokenizer_config = {
        "tokenizer_class": "PreTrainedTokenizerFast",
        "model_max_length": int(model_max_length),
        "clean_up_tokenization_spaces": False,
        "padding_side": "right",
        "truncation_side": "right",
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]",
        "unk_token_id": unk_id,
        "pad_token_id": pad_id,
        "sep_token_id": sep_id,
        "cls_token_id": cls_id,
        "mask_token_id": mask_id,
        "vocab_size": int(tokenizer.get_vocab_size()),
        "requested_vocab_size": int(requested_vocab_size),
        "lowercase": bool(lowercase),
        "bytelevel": bool(bytelevel),
        "use_bos": bool(use_bos),
        "use_eos": bool(use_eos),
        "source_columns": source_columns,
        "column_separator": column_separator,
        "special_tokens": special_token_ids,
    }

    special_tokens_map = {
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]",
    }

    generation_config = {
        "pad_token_id": pad_id,
        "sep_token_id": sep_id,
        "cls_token_id": cls_id,
        "mask_token_id": mask_id,
    }

    if bos_id is not None:
        tokenizer_config["bos_token"] = "[BOS]"
        tokenizer_config["bos_token_id"] = bos_id
        special_tokens_map["bos_token"] = "[BOS]"
        generation_config["bos_token_id"] = bos_id

    if eos_id is not None:
        tokenizer_config["eos_token"] = "[EOS]"
        tokenizer_config["eos_token_id"] = eos_id
        special_tokens_map["eos_token"] = "[EOS]"
        generation_config["eos_token_id"] = eos_id

    tokenizer_config["added_tokens_decoder"] = {
        str(token_id): {
            "content": token,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
        for token, token_id in special_token_ids.items()
    }

    non_custom = set(BASE_SPECIAL_TOKENS) | set(_BOS_EOS_TOKENS)
    tokenizer_info = {
        "vocab_size": int(tokenizer.get_vocab_size()),
        "requested_vocab_size": int(requested_vocab_size),
        "model_max_length": int(model_max_length),
        "lowercase": bool(lowercase),
        "bytelevel": bool(bytelevel),
        "use_bos": bool(use_bos),
        "use_eos": bool(use_eos),
        "source_columns": source_columns,
        "column_separator": column_separator,
        "special_tokens": special_token_ids,
        "custom_special_tokens": [
            token for token in special_tokens if token not in non_custom
        ],
    }

    save_json(out_dir / "tokenizer_config.json", tokenizer_config)
    save_json(out_dir / "special_tokens_map.json", special_tokens_map)
    save_json(out_dir / "tokenizer_info.json", tokenizer_info)
    save_json(out_dir / "generation_config.json", generation_config)
    save_vocab_preview(tokenizer, out_dir)

    print(f"[OK] Saved tokenizer.json: {tokenizer_json_path}")
    print(f"[OK] Saved vocab.json: {out_dir / 'vocab.json'}")
    print(f"[OK] Saved merges.txt: {out_dir / 'merges.txt'}")
    print(f"[OK] Saved tokenizer_config.json: {out_dir / 'tokenizer_config.json'}")
    print(f"[OK] Saved special_tokens_map.json: {out_dir / 'special_tokens_map.json'}")
    print(f"[OK] Saved tokenizer_info.json: {out_dir / 'tokenizer_info.json'}")
    print(f"[OK] Saved generation_config.json: {out_dir / 'generation_config.json'}")
    print(f"[OK] Saved vocab_preview.txt: {out_dir / 'vocab_preview.txt'}")
    print()
    print("[SPECIAL TOKENS]")

    for token, token_id in sorted(special_token_ids.items(), key=lambda x: x[1]):
        print(f"{token:<12} = {token_id}")


def _build_tokenizer(lowercase: bool, bytelevel: bool) -> Tokenizer:
    """Create an untrained BPE tokenizer with its normalizer, pre-tokenizer and decoder."""
    tokenizer = Tokenizer(
        BPE(
            unk_token="[UNK]",
            byte_fallback=False,
        )
    )

    tokenizer.normalizer = Sequence([NFC(), Lowercase()]) if lowercase else NFC()

    if bytelevel:
        tokenizer.pre_tokenizer = ByteLevel(
            add_prefix_space=False,
            use_regex=True,
        )
        tokenizer.decoder = ByteLevelDecoder(
            add_prefix_space=False,
        )
    else:
        # Without any pre-tokenizer, BPE treats each whole training text as a
        # single "word": merges span spaces and the word table grows with the
        # corpus. Without a decoder, decode() joins tokens with " ".
        # Metaspace keeps word boundaries and decodes back losslessly.
        tokenizer.pre_tokenizer = Metaspace()
        tokenizer.decoder = MetaspaceDecoder()

    return tokenizer


def _print_next_test(out_dir: Path, special_tokens: List[str]):
    """Print a ready-to-paste shell snippet that reloads and exercises the saved tokenizer."""
    tokenizer_path = str(out_dir / "tokenizer.json")

    print("[NEXT TEST]")
    print("python3 - <<'PY'")
    print("from tokenizers import Tokenizer")
    # repr() keeps paths containing quotes or backslashes (e.g. Windows)
    # valid Python.
    print(f"tok = Tokenizer.from_file({tokenizer_path!r})")
    print(f"specials = {repr(special_tokens)}")
    print("for s in specials:")
    print("    print(repr(s), '=>', tok.token_to_id(s))")
    print("enc = tok.encode('installer arch serveur minimal')")
    print("print(enc.tokens)")
    print("print(enc.ids)")
    print("print(tok.decode(enc.ids, skip_special_tokens=True))")
    print("PY")


def main():
    """Validate the inputs, train the tokenizer, save it and run a quick round-trip check."""
    args = parse_args()

    src = args.src
    out_dir = Path(args.out)

    if not os.path.exists(src):
        raise FileNotFoundError(f"Parquet file not found: {src}")

    # Fail fast if the output path is taken without --overwrite. The existing
    # folder itself is deleted only after training succeeds (below), so a bad
    # column name or a crashed run no longer destroys a previous tokenizer.
    _check_output_dir(out_dir, args.overwrite)

    pf = pq.ParquetFile(src)
    schema_names = pf.schema.names

    missing_columns = [
        column
        for column in args.column
        if column not in schema_names
    ]

    if missing_columns:
        raise ValueError(
            f"Columns not found in parquet: {missing_columns}. Available columns: {schema_names}"
        )

    special_tokens = build_special_tokens(args)

    total_rows = count_parquet_rows(src)
    target_rows = min(total_rows, args.limit_rows) if args.limit_rows > 0 else total_rows
    bytelevel = not args.no_bytelevel

    print("[INFO] Lumëa BPE tokenizer training")
    print(f"[INFO] Source: {src}")
    print(f"[INFO] Columns: {args.column}")
    print(f"[INFO] Output: {out_dir}")
    print(f"[INFO] Rows total: {total_rows:,}")
    print(f"[INFO] Rows target: {target_rows:,}")
    print(f"[INFO] Vocab size: {args.vocab_size:,}")
    print(f"[INFO] Min frequency: {args.min_frequency:,}")
    print(f"[INFO] Batch size: {args.batch_size:,}")
    print(f"[INFO] Lowercase: {args.lowercase}")
    print(f"[INFO] ByteLevel: {bytelevel}")
    print(f"[INFO] Model max length: {args.model_max_length:,}")
    print(f"[INFO] Use BOS: {args.bos}")
    print(f"[INFO] Use EOS: {args.eos}")
    print(f"[INFO] Special tokens: {special_tokens}")
    print(f"[INFO] Separator repr: {repr(args.column_separator)}")
    print()

    tokenizer = _build_tokenizer(lowercase=args.lowercase, bytelevel=bytelevel)

    trainer = BpeTrainer(
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
        special_tokens=special_tokens,
        show_progress=False,
        initial_alphabet=ByteLevel.alphabet() if bytelevel else [],
    )

    iterator = parquet_text_iterator(
        src=src,
        columns=args.column,
        batch_size=args.batch_size,
        lowercase=args.lowercase,
        limit_rows=args.limit_rows,
        separator=args.column_separator,
    )

    print("[INFO] Training tokenizer...")
    print("[INFO] tqdm ETA is for parquet reading / sequence feeding.")
    print("[INFO] After 100%, BPE merge finalization can still take a moment.")
    print()

    tokenizer.train_from_iterator(
        iterator,
        trainer=trainer,
    )

    pad_id = require_token_id(tokenizer, "[PAD]")

    processor = build_template_processing(
        tokenizer=tokenizer,
        use_bos=args.bos,
        use_eos=args.eos,
    )

    if processor is not None:
        tokenizer.post_processor = processor

    tokenizer.enable_padding(
        pad_id=pad_id,
        pad_token="[PAD]",
    )

    # Training succeeded: it is now safe to replace the previous output.
    remove_output_dir(
        out_dir=out_dir,
        overwrite=args.overwrite,
    )

    print()
    print("[INFO] Saving tokenizer folder...")

    save_hf_tokenizer_files(
        tokenizer=tokenizer,
        out_dir=out_dir,
        requested_vocab_size=args.vocab_size,
        lowercase=args.lowercase,
        model_max_length=args.model_max_length,
        bytelevel=bytelevel,
        special_tokens=special_tokens,
        use_bos=args.bos,
        use_eos=args.eos,
        source_columns=args.column,
        column_separator=args.column_separator,
    )

    print()
    print("[CHECK]")

    for token in special_tokens:
        print(f"{repr(token)} => {tokenizer.token_to_id(token)}")

    print()
    print("[ENCODE / DECODE TEST]")

    sample = "installer arch serveur minimal"
    encoded = tokenizer.encode(sample)
    decoded = tokenizer.decode(encoded.ids, skip_special_tokens=True)

    print(f"Input:   {sample}")
    print(f"Tokens:  {encoded.tokens}")
    print(f"IDs:     {encoded.ids}")
    print(f"Decoded: {decoded}")

    print()
    print("[DONE] Tokenizer trained.")
    print(f"[DONE] Output folder: {out_dir}")
    print()

    _print_next_test(out_dir, special_tokens)


if __name__ == "__main__":
    main()