| |
| import argparse |
| import json |
| import os |
| import shutil |
| import unicodedata |
| from pathlib import Path |
| from typing import Iterator, List, Optional |
|
|
| import pyarrow.parquet as pq |
| from tokenizers import Tokenizer |
| from tokenizers.models import BPE |
| from tokenizers.normalizers import Lowercase, NFC, Sequence |
| from tokenizers.pre_tokenizers import ByteLevel |
| from tokenizers.decoders import ByteLevel as ByteLevelDecoder |
| from tokenizers.processors import TemplateProcessing |
| from tokenizers.trainers import BpeTrainer |
| from tqdm import tqdm |
|
|
|
|
# Special tokens every Lumëa tokenizer always gets, in this order. Optional
# [BOS]/[EOS] and user-supplied extras are appended after them by
# build_special_tokens().
BASE_SPECIAL_TOKENS = ["[PAD]", "[UNK]", "[SEP]", "[CLS]", "[MASK]"]
|
|
|
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse and validate the command-line options for tokenizer training.

    Args:
        argv: Arguments to parse, without the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, as before.

    Returns:
        An ``argparse.Namespace``. ``column`` is a non-empty list of stripped,
        de-duplicated column names (defaulting to ``["text"]``), and
        ``add_special_token`` is always a list.

    Raises:
        ValueError: If every ``--column`` value is blank.
        SystemExit: Through ``parser.error`` when a numeric option is out of
            range (argparse prints the usage message first).
    """
    parser = argparse.ArgumentParser(
        description="Train a Lumëa BPE tokenizer from one or several Parquet columns."
    )

    parser.add_argument("--src", required=True)
    parser.add_argument(
        "--column",
        action="append",
        default=None,
        help="Parquet column to train on. Can be repeated: --column fr --column completion",
    )
    parser.add_argument("--out", default="tokenizer_lumea")
    parser.add_argument("--vocab-size", type=int, default=32768)
    parser.add_argument("--min-frequency", type=int, default=2)
    parser.add_argument("--batch-size", type=int, default=100_000)
    parser.add_argument("--lowercase", action="store_true")
    parser.add_argument("--limit-rows", type=int, default=0)
    parser.add_argument("--no-bytelevel", action="store_true")
    parser.add_argument("--model-max-length", type=int, default=4096)
    parser.add_argument("--overwrite", action="store_true")

    parser.add_argument(
        "--bos",
        action="store_true",
        help="Add [BOS] and use it in post-processing.",
    )
    parser.add_argument(
        "--eos",
        action="store_true",
        help="Add [EOS] and use it in post-processing.",
    )
    parser.add_argument(
        "--add-special-token",
        action="append",
        # None instead of a shared mutable [] default; normalized below.
        default=None,
        help="Add custom special token. Example: --add-special-token boc creates [BOC]. Can be repeated.",
    )
    parser.add_argument(
        "--column-separator",
        default="\n",
        help="Separator used when several columns are joined.",
    )

    args = parser.parse_args(argv)

    # Reject out-of-range numbers up front with a usage message instead of an
    # opaque error deep inside pyarrow / the tokenizers trainer.
    if args.vocab_size <= 0:
        parser.error("--vocab-size must be a positive integer.")
    if args.batch_size <= 0:
        parser.error("--batch-size must be a positive integer.")
    if args.min_frequency < 0:
        parser.error("--min-frequency must be >= 0.")
    if args.model_max_length <= 0:
        parser.error("--model-max-length must be a positive integer.")

    if args.add_special_token is None:
        args.add_special_token = []

    if args.column is None:
        args.column = ["text"]

    # Strip, drop blanks and de-duplicate while keeping the user's order. A
    # repeated column would otherwise be read and joined twice per row (and
    # pyarrow may collapse duplicate selections, breaking positional access).
    stripped = (str(col).strip() for col in args.column)
    args.column = list(dict.fromkeys(col for col in stripped if col))

    if not args.column:
        raise ValueError("At least one --column is required.")

    return args
|
|
|
|
def canonical_special_token(value: str) -> str:
    """Turn a user-supplied token name into its bracketed, upper-case form.

    ``"boc"`` and ``"[ boc ]"`` both become ``"[BOC]"``. Surrounding
    whitespace is stripped first. A value that is already wrapped in
    ``[...]`` has its inner text stripped and upper-cased; anything else is
    upper-cased and wrapped as-is.

    Raises:
        ValueError: If the value is blank, or is ``[]`` with nothing inside.
    """
    cleaned = str(value).strip()
    if not cleaned:
        raise ValueError("Empty special token is not allowed.")

    already_bracketed = cleaned.startswith("[") and cleaned.endswith("]")
    if not already_bracketed:
        return f"[{cleaned.upper()}]"

    inner = cleaned[1:-1].strip()
    if not inner:
        raise ValueError(f"Invalid special token: {cleaned}")
    return f"[{inner.upper()}]"
|
|
|
|
def build_special_tokens(args) -> List[str]:
    """Assemble the ordered, duplicate-free list of special tokens for training.

    The order is: ``BASE_SPECIAL_TOKENS``, then ``[BOS]`` if ``args.bos``,
    then ``[EOS]`` if ``args.eos``, then every ``args.add_special_token``
    entry canonicalized by ``canonical_special_token``. When a token appears
    more than once, its first position wins.
    """
    candidates = [*BASE_SPECIAL_TOKENS]
    if args.bos:
        candidates.append("[BOS]")
    if args.eos:
        candidates.append("[EOS]")
    candidates.extend(canonical_special_token(raw) for raw in args.add_special_token)

    # dict keeps insertion order, so fromkeys() drops later duplicates only.
    return list(dict.fromkeys(candidates))
|
|
|
|
def normalize_text(text: str, lowercase: bool) -> str:
    """Canonicalize one raw cell value for training.

    The value is converted with ``str()``, stripped of surrounding
    whitespace, NFC-normalized and, when ``lowercase`` is true, lowercased.
    These are the same NFC/lowercase steps the tokenizer's normalizer is
    configured with in ``main``.
    """
    composed = unicodedata.normalize("NFC", str(text).strip())
    return composed.lower() if lowercase else composed
|
|
|
|
def count_parquet_rows(src: str) -> int:
    """Return the total row count recorded in the Parquet metadata of ``src``."""
    metadata = pq.ParquetFile(src).metadata
    return int(metadata.num_rows)
|
|
|
|
def join_row_values(values: List[object], lowercase: bool, separator: str) -> Optional[str]:
    """Normalize one row's cell values and join them into a single training text.

    ``None`` cells are skipped, every other cell goes through
    ``normalize_text``, and cells that normalize to ``""`` are dropped.

    Returns:
        The surviving cells joined with ``separator`` (then stripped), or
        ``None`` when no cell survives.
    """
    pieces = [
        normalized
        for normalized in (
            normalize_text(text=cell, lowercase=lowercase)
            for cell in values
            if cell is not None
        )
        if normalized
    ]
    if not pieces:
        return None
    return separator.join(pieces).strip()
|
|
|
|
def parquet_text_iterator(
    src: str,
    columns: List[str],
    batch_size: int,
    lowercase: bool,
    limit_rows: int,
    separator: str,
) -> Iterator[List[str]]:
    """Stream normalized training texts from a Parquet file, one list per batch.

    For each row, the values of ``columns`` are normalized and joined with
    ``separator`` through ``join_row_values``. Rows that end up empty are
    dropped. A tqdm progress bar tracks rows read (not rows yielded) and is
    closed when the generator finishes or is closed.

    Args:
        src: Path to the Parquet file.
        columns: Column names to read, in the order they are joined.
        batch_size: Rows per Parquet record batch; must be positive.
        lowercase: Lowercase text after NFC normalization.
        limit_rows: Maximum number of rows to read; ``<= 0`` means all rows.
        separator: String placed between the values of one row.

    Yields:
        Non-empty lists of non-empty strings, one list per record batch.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    pf = pq.ParquetFile(src)

    total_rows = int(pf.metadata.num_rows)
    target_rows = min(total_rows, limit_rows) if limit_rows > 0 else total_rows

    rows_seen = 0
    rows_yielded = 0

    pbar = tqdm(
        total=target_rows,
        unit="rows",
        desc="Training data",
        dynamic_ncols=True,
        mininterval=0.5,
        smoothing=0.05,
    )

    try:
        batches = pf.iter_batches(batch_size=batch_size, columns=columns)
        for batches_seen, batch in enumerate(batches, start=1):
            # Trim the last batch BEFORE converting to Python objects so a
            # --limit-rows run does not materialize rows it would discard.
            remaining = target_rows - rows_seen
            if batch.num_rows > remaining:
                batch = batch.slice(0, remaining)

            # Select by name, not position, so each row's values always line
            # up with `columns` regardless of the batch's column order.
            column_values = [batch.column(name).to_pylist() for name in columns]

            texts = []
            for row_values in zip(*column_values):
                text = join_row_values(
                    values=list(row_values),
                    lowercase=lowercase,
                    separator=separator,
                )
                if text:
                    texts.append(text)

            rows_seen += batch.num_rows
            rows_yielded += len(texts)

            # Progress bookkeeping once per batch: set_postfix() forces a
            # terminal refresh, which is far too expensive to do per row.
            pbar.update(batch.num_rows)
            pbar.set_postfix(
                {
                    "batch": batches_seen,
                    "yielded": f"{rows_yielded:,}",
                }
            )

            if texts:
                yield texts

            if rows_seen >= target_rows:
                break

    finally:
        pbar.close()
|
|
|
|
def remove_output_dir(out_dir: Path, overwrite: bool):
    """Make sure ``out_dir`` does not exist, deleting it only when allowed.

    Args:
        out_dir: Path the tokenizer files are about to be written to.
        overwrite: When true, an existing path is removed: a directory
            recursively, and a plain file or symlink by itself. When false,
            an existing path is an error.

    Raises:
        FileExistsError: If ``out_dir`` exists and ``overwrite`` is false.
        ValueError: If ``out_dir`` is the current working directory or one
            of its ancestors; deleting it would wipe the caller's tree.
    """
    # is_symlink() also catches dangling links, which exists() reports as absent.
    if not (out_dir.exists() or out_dir.is_symlink()):
        return

    if not overwrite:
        raise FileExistsError(
            f"Output folder already exists: {out_dir}. Use --overwrite."
        )

    # rmtree() refuses plain files and symlinks. Unlink them directly; for a
    # symlink this removes only the link, never the target's contents.
    if out_dir.is_symlink() or not out_dir.is_dir():
        out_dir.unlink()
        return

    resolved = out_dir.resolve()
    cwd = Path.cwd().resolve()
    if resolved == cwd or resolved in cwd.parents:
        raise ValueError(
            f"Refusing to delete {out_dir}: it contains the current working directory."
        )

    shutil.rmtree(out_dir)
|
|
|
|
def save_json(path: Path, payload):
    """Write ``payload`` to ``path`` as pretty-printed UTF-8 JSON.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``). The
    output uses a two-space indent and ends with a trailing newline. The
    payload is serialized before the file is opened, so a payload ``json``
    cannot encode raises without creating or truncating ``path``.

    Raises:
        TypeError: If ``payload`` contains values ``json`` cannot encode.
    """
    text = json.dumps(
        payload,
        ensure_ascii=False,
        indent=2,
    )
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
        f.write("\n")
|
|
|
|
def require_token_id(tokenizer: Tokenizer, token: str) -> int:
    """Return the vocabulary id of ``token`` as an ``int``.

    Raises:
        RuntimeError: If ``tokenizer.token_to_id`` returns ``None`` for it.
    """
    looked_up = tokenizer.token_to_id(token)
    if looked_up is not None:
        return int(looked_up)
    raise RuntimeError(f"Token missing after training: {token}")
|
|
|
|
def optional_token_id(tokenizer: Tokenizer, token: str) -> Optional[int]:
    """Return the vocabulary id of ``token`` as an ``int``, or ``None`` if absent."""
    found = tokenizer.token_to_id(token)
    return None if found is None else int(found)
|
|
|
|
def save_vocab_preview(tokenizer: Tokenizer, out_dir: Path, max_items: int = 512):
    """Write the first ``max_items`` vocabulary entries to ``vocab_preview.txt``.

    There is one line per id, from 0 up to ``min(max_items, vocab size)``.
    Each line is the id and the ``repr`` of its token, tab-separated. Ids
    absent from ``get_vocab()`` are shown as ``''``.
    """
    id_to_token = {token_id: piece for piece, token_id in tokenizer.get_vocab().items()}

    with open(out_dir / "vocab_preview.txt", "w", encoding="utf-8") as preview:
        count = min(max_items, tokenizer.get_vocab_size())
        preview.writelines(
            f"{token_id}\t{id_to_token.get(token_id, '')!r}\n"
            for token_id in range(count)
        )
|
|
|
|
def build_template_processing(tokenizer: Tokenizer, use_bos: bool, use_eos: bool):
    """Build the post-processor that inserts [BOS]/[EOS] around encoded text.

    Returns ``None`` when neither flag is set. Otherwise it returns a
    ``TemplateProcessing`` whose single/pair templates depend on which of
    the two tokens are enabled. The second sequence of a pair gets type id 1.

    Raises:
        RuntimeError: If an enabled token is missing from the vocabulary.
    """
    if not (use_bos or use_eos):
        return None

    # (use_bos, use_eos) -> (single template, pair template)
    templates = {
        (True, True): ("[BOS] $A [EOS]", "[BOS] $A [EOS] $B:1 [EOS]:1"),
        (True, False): ("[BOS] $A", "[BOS] $A $B:1"),
        (False, True): ("$A [EOS]", "$A [EOS] $B:1 [EOS]:1"),
    }

    registered = []
    if use_bos:
        registered.append(("[BOS]", require_token_id(tokenizer, "[BOS]")))
    if use_eos:
        registered.append(("[EOS]", require_token_id(tokenizer, "[EOS]")))

    single_template, pair_template = templates[(bool(use_bos), bool(use_eos))]
    return TemplateProcessing(
        single=single_template,
        pair=pair_template,
        special_tokens=registered,
    )
|
|
|
|
def save_hf_tokenizer_files(
    tokenizer: Tokenizer,
    out_dir: Path,
    requested_vocab_size: int,
    lowercase: bool,
    model_max_length: int,
    bytelevel: bool,
    special_tokens: List[str],
    use_bos: bool,
    use_eos: bool,
    source_columns: List[str],
    column_separator: str,
):
    """Write the trained tokenizer as a Hugging Face-style tokenizer folder.

    Files written into ``out_dir`` (created if needed):
    - ``tokenizer.json`` (the full tokenizer pipeline);
    - the BPE model files from ``tokenizer.model.save`` (reported as
      ``vocab.json`` / ``merges.txt``);
    - ``tokenizer_config.json``, ``special_tokens_map.json``,
      ``tokenizer_info.json`` and ``generation_config.json``;
    - ``vocab_preview.txt``.
    A summary and the special-token ids are printed to stdout.

    Custom special tokens are anything outside ``BASE_SPECIAL_TOKENS``,
    ``[BOS]`` and ``[EOS]``. They are declared as
    ``additional_special_tokens`` in ``tokenizer_config.json`` and
    ``special_tokens_map.json`` (only when there are any), so Hugging Face
    loaders can expose them.

    Args:
        tokenizer: Trained tokenizer to save.
        out_dir: Destination folder.
        requested_vocab_size: ``--vocab-size`` as requested; the real size may differ.
        lowercase: Whether lowercasing was enabled (recorded in the configs).
        model_max_length: Value written to ``model_max_length``.
        bytelevel: Whether byte-level pre-tokenization was used (recorded).
        special_tokens: All special tokens given to the trainer.
        use_bos: Whether ``--bos`` was set (recorded).
        use_eos: Whether ``--eos`` was set (recorded).
        source_columns: Parquet columns the tokenizer was trained on (recorded).
        column_separator: Separator used to join columns (recorded).

    Raises:
        RuntimeError: If a base special token or any entry of
            ``special_tokens`` is missing from the trained vocabulary.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer_json_path = out_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_json_path))

    # Also emit the bare BPE model files for tools that load them directly.
    tokenizer.model.save(str(out_dir))

    # The five base tokens are mandatory; [BOS]/[EOS] are recorded only if present.
    pad_id = require_token_id(tokenizer, "[PAD]")
    unk_id = require_token_id(tokenizer, "[UNK]")
    sep_id = require_token_id(tokenizer, "[SEP]")
    cls_id = require_token_id(tokenizer, "[CLS]")
    mask_id = require_token_id(tokenizer, "[MASK]")

    bos_id = optional_token_id(tokenizer, "[BOS]")
    eos_id = optional_token_id(tokenizer, "[EOS]")

    special_token_ids = {
        token: require_token_id(tokenizer, token)
        for token in special_tokens
    }

    # User-defined extras (e.g. [BOC]): neither a base token nor [BOS]/[EOS].
    custom_special_tokens = [
        token
        for token in special_tokens
        if token not in BASE_SPECIAL_TOKENS and token not in ["[BOS]", "[EOS]"]
    ]

    tokenizer_config = {
        "tokenizer_class": "PreTrainedTokenizerFast",
        "model_max_length": int(model_max_length),
        "clean_up_tokenization_spaces": False,
        "padding_side": "right",
        "truncation_side": "right",
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]",
        "unk_token_id": unk_id,
        "pad_token_id": pad_id,
        "sep_token_id": sep_id,
        "cls_token_id": cls_id,
        "mask_token_id": mask_id,
        "vocab_size": int(tokenizer.get_vocab_size()),
        "requested_vocab_size": int(requested_vocab_size),
        "lowercase": bool(lowercase),
        "bytelevel": bool(bytelevel),
        "use_bos": bool(use_bos),
        "use_eos": bool(use_eos),
        "source_columns": source_columns,
        "column_separator": column_separator,
        "special_tokens": special_token_ids,
    }

    special_tokens_map = {
        "unk_token": "[UNK]",
        "pad_token": "[PAD]",
        "sep_token": "[SEP]",
        "cls_token": "[CLS]",
        "mask_token": "[MASK]",
    }

    generation_config = {
        "pad_token_id": pad_id,
        "sep_token_id": sep_id,
        "cls_token_id": cls_id,
        "mask_token_id": mask_id,
    }

    if bos_id is not None:
        tokenizer_config["bos_token"] = "[BOS]"
        tokenizer_config["bos_token_id"] = bos_id
        special_tokens_map["bos_token"] = "[BOS]"
        generation_config["bos_token_id"] = bos_id

    if eos_id is not None:
        tokenizer_config["eos_token"] = "[EOS]"
        tokenizer_config["eos_token_id"] = eos_id
        special_tokens_map["eos_token"] = "[EOS]"
        generation_config["eos_token_id"] = eos_id

    # Declare custom tokens under the standard Hugging Face key. The key is
    # omitted when there are none, so default runs produce the same files.
    if custom_special_tokens:
        tokenizer_config["additional_special_tokens"] = list(custom_special_tokens)
        special_tokens_map["additional_special_tokens"] = list(custom_special_tokens)

    tokenizer_config["added_tokens_decoder"] = {
        str(token_id): {
            "content": token,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
        for token, token_id in special_token_ids.items()
    }

    tokenizer_info = {
        "vocab_size": int(tokenizer.get_vocab_size()),
        "requested_vocab_size": int(requested_vocab_size),
        "model_max_length": int(model_max_length),
        "lowercase": bool(lowercase),
        "bytelevel": bool(bytelevel),
        "use_bos": bool(use_bos),
        "use_eos": bool(use_eos),
        "source_columns": source_columns,
        "column_separator": column_separator,
        "special_tokens": special_token_ids,
        "custom_special_tokens": custom_special_tokens,
    }

    save_json(out_dir / "tokenizer_config.json", tokenizer_config)
    save_json(out_dir / "special_tokens_map.json", special_tokens_map)
    save_json(out_dir / "tokenizer_info.json", tokenizer_info)
    save_json(out_dir / "generation_config.json", generation_config)
    save_vocab_preview(tokenizer, out_dir)

    print(f"[OK] Saved tokenizer.json: {tokenizer_json_path}")
    print(f"[OK] Saved vocab.json: {out_dir / 'vocab.json'}")
    print(f"[OK] Saved merges.txt: {out_dir / 'merges.txt'}")
    print(f"[OK] Saved tokenizer_config.json: {out_dir / 'tokenizer_config.json'}")
    print(f"[OK] Saved special_tokens_map.json: {out_dir / 'special_tokens_map.json'}")
    print(f"[OK] Saved tokenizer_info.json: {out_dir / 'tokenizer_info.json'}")
    print(f"[OK] Saved generation_config.json: {out_dir / 'generation_config.json'}")
    print(f"[OK] Saved vocab_preview.txt: {out_dir / 'vocab_preview.txt'}")
    print()
    print("[SPECIAL TOKENS]")
    for token, token_id in sorted(special_token_ids.items(), key=lambda x: x[1]):
        print(f"{token:<12} = {token_id}")
|
|
|
|
def _build_tokenizer(lowercase: bool, bytelevel: bool) -> Tokenizer:
    """Create the untrained BPE tokenizer with this run's normalizer/pre-tokenizer/decoder."""
    tokenizer = Tokenizer(
        BPE(
            unk_token="[UNK]",
            byte_fallback=False,
        )
    )

    # Same NFC (+ optional lowercase) that normalize_text() applies to the
    # training text, so encode-time input matches what the merges were learned on.
    if lowercase:
        tokenizer.normalizer = Sequence([NFC(), Lowercase()])
    else:
        tokenizer.normalizer = NFC()

    if bytelevel:
        tokenizer.pre_tokenizer = ByteLevel(
            add_prefix_space=False,
            use_regex=True,
        )
        tokenizer.decoder = ByteLevelDecoder(
            add_prefix_space=False,
        )
    # NOTE(review): with --no-bytelevel, neither a pre-tokenizer nor a decoder
    # is set, so the BPE trainer presumably sees each whole training row as a
    # single "word". A Whitespace/Metaspace pre-tokenizer (with a matching
    # decoder) is probably intended for that mode; confirm before relying on it.

    return tokenizer


def _print_reload_snippet(out_dir: Path, special_tokens: List[str]) -> None:
    """Print a shell heredoc that reloads the saved tokenizer and repeats the smoke test."""
    tokenizer_path = str(out_dir / "tokenizer.json")

    print("[NEXT TEST]")
    print("python3 - <<'PY'")
    print("from tokenizers import Tokenizer")
    # repr() gives a valid Python string literal even when the path contains
    # quotes or Windows backslashes.
    print(f"tok = Tokenizer.from_file({tokenizer_path!r})")
    print(f"specials = {repr(special_tokens)}")
    print("for s in specials:")
    print("    print(repr(s), '=>', tok.token_to_id(s))")
    print("enc = tok.encode('installer arch serveur minimal')")
    print("print(enc.tokens)")
    print("print(enc.ids)")
    print("print(tok.decode(enc.ids, skip_special_tokens=True))")
    print("PY")


def main():
    """Command-line entry point: validate, train, save and smoke-test a tokenizer.

    Steps:
    1. Parse arguments.
    2. Check that ``--src`` is a file containing every requested column, and
       that the special tokens are valid.
    3. Clear the output folder (only with ``--overwrite``).
    4. Train a BPE tokenizer on the streamed Parquet text.
    5. Attach the optional [BOS]/[EOS] template and [PAD] padding.
    6. Save the tokenizer folder.
    7. Print an encode/decode check and a snippet for reloading the result.

    Raises:
        FileNotFoundError: If ``--src`` is not an existing file.
        ValueError: If a requested column is missing from the Parquet schema,
            or if ``--overwrite`` would delete a folder containing ``--src``.
        FileExistsError: If the output folder exists without ``--overwrite``.
    """
    args = parse_args()

    src = args.src
    out_dir = Path(args.out)

    if not os.path.isfile(src):
        raise FileNotFoundError(f"Parquet file not found: {src}")

    # Validate every input that can fail BEFORE remove_output_dir(), so a
    # typo in --column or --add-special-token cannot destroy an existing
    # output folder when --overwrite is set.
    pf = pq.ParquetFile(src)
    schema_names = pf.schema.names

    missing_columns = [
        column
        for column in args.column
        if column not in schema_names
    ]

    if missing_columns:
        raise ValueError(
            f"Columns not found in parquet: {missing_columns}. Available columns: {schema_names}"
        )

    special_tokens = build_special_tokens(args)

    # --overwrite deletes out_dir recursively; refuse if the training data
    # lives inside it, otherwise the source would be gone before training.
    if args.overwrite and out_dir.resolve() in Path(src).resolve().parents:
        raise ValueError(
            f"Refusing to overwrite {out_dir}: it contains the source file {src}."
        )

    remove_output_dir(
        out_dir=out_dir,
        overwrite=args.overwrite,
    )

    total_rows = count_parquet_rows(src)
    target_rows = min(total_rows, args.limit_rows) if args.limit_rows > 0 else total_rows
    bytelevel = not args.no_bytelevel

    print("[INFO] Lumëa BPE tokenizer training")
    print(f"[INFO] Source: {src}")
    print(f"[INFO] Columns: {args.column}")
    print(f"[INFO] Output: {out_dir}")
    print(f"[INFO] Rows total: {total_rows:,}")
    print(f"[INFO] Rows target: {target_rows:,}")
    print(f"[INFO] Vocab size: {args.vocab_size:,}")
    print(f"[INFO] Min frequency: {args.min_frequency:,}")
    print(f"[INFO] Batch size: {args.batch_size:,}")
    print(f"[INFO] Lowercase: {args.lowercase}")
    print(f"[INFO] ByteLevel: {bytelevel}")
    print(f"[INFO] Model max length: {args.model_max_length:,}")
    print(f"[INFO] Use BOS: {args.bos}")
    print(f"[INFO] Use EOS: {args.eos}")
    print(f"[INFO] Special tokens: {special_tokens}")
    print(f"[INFO] Separator repr: {repr(args.column_separator)}")
    print()

    tokenizer = _build_tokenizer(lowercase=args.lowercase, bytelevel=bytelevel)

    trainer = BpeTrainer(
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
        special_tokens=special_tokens,
        show_progress=False,
        initial_alphabet=ByteLevel.alphabet() if bytelevel else [],
    )

    iterator = parquet_text_iterator(
        src=src,
        columns=args.column,
        batch_size=args.batch_size,
        lowercase=args.lowercase,
        limit_rows=args.limit_rows,
        separator=args.column_separator,
    )

    print("[INFO] Training tokenizer...")
    print("[INFO] tqdm ETA is for parquet reading / sequence feeding.")
    print("[INFO] After 100%, BPE merge finalization can still take a moment.")
    print()

    tokenizer.train_from_iterator(
        iterator,
        trainer=trainer,
    )

    pad_id = require_token_id(tokenizer, "[PAD]")

    processor = build_template_processing(
        tokenizer=tokenizer,
        use_bos=args.bos,
        use_eos=args.eos,
    )

    if processor is not None:
        tokenizer.post_processor = processor

    tokenizer.enable_padding(
        pad_id=pad_id,
        pad_token="[PAD]",
    )

    print()
    print("[INFO] Saving tokenizer folder...")

    save_hf_tokenizer_files(
        tokenizer=tokenizer,
        out_dir=out_dir,
        requested_vocab_size=args.vocab_size,
        lowercase=args.lowercase,
        model_max_length=args.model_max_length,
        bytelevel=bytelevel,
        special_tokens=special_tokens,
        use_bos=args.bos,
        use_eos=args.eos,
        source_columns=args.column,
        column_separator=args.column_separator,
    )

    print()
    print("[CHECK]")
    for token in special_tokens:
        print(f"{repr(token)} => {tokenizer.token_to_id(token)}")

    print()
    print("[ENCODE / DECODE TEST]")
    sample = "installer arch serveur minimal"
    encoded = tokenizer.encode(sample)
    decoded = tokenizer.decode(encoded.ids, skip_special_tokens=True)
    print(f"Input: {sample}")
    print(f"Tokens: {encoded.tokens}")
    print(f"IDs: {encoded.ids}")
    print(f"Decoded: {decoded}")

    print()
    print("[DONE] Tokenizer trained.")
    print(f"[DONE] Output folder: {out_dir}")
    print()
    _print_reload_snippet(out_dir, special_tokens)
|
|
|
|
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|