PhysiQuanty's picture
Patch ByteLevel decoder in tokenizer training
74f0b58 verified
Raw
History Blame Contribute Delete
18 kB
#!/usr/bin/env python3
import argparse
import json
import os
import shutil
import unicodedata
from pathlib import Path
from typing import Iterator, List, Optional
import pyarrow.parquet as pq
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.normalizers import Lowercase, NFC, Sequence
from tokenizers.pre_tokenizers import ByteLevel
from tokenizers.decoders import ByteLevel as ByteLevelDecoder
from tokenizers.processors import TemplateProcessing
from tokenizers.trainers import BpeTrainer
from tqdm import tqdm
# Special tokens that are always requested from the trainer.
# build_special_tokens() copies this list first, so these entries come before
# any optional [BOS]/[EOS] or user-supplied special tokens.
BASE_SPECIAL_TOKENS = [
    "[PAD]",
    "[UNK]",
    "[SEP]",
    "[CLS]",
    "[MASK]",
]
def parse_args():
    """Parse the command line and return the resulting namespace.

    ``args.column`` is always a non-empty list of stripped column names on
    return; it falls back to ``["text"]`` when ``--column`` is not given.

    Raises:
        ValueError: If every supplied ``--column`` value is blank.
    """
    parser = argparse.ArgumentParser(
        description="Train a Lumëa BPE tokenizer from one or several Parquet columns."
    )
    add = parser.add_argument

    # Input / output locations.
    add("--src", required=True)
    add(
        "--column",
        action="append",
        default=None,
        help="Parquet column to train on. Can be repeated: --column fr --column completion",
    )
    add("--out", default="tokenizer_lumea")

    # Training hyper-parameters and preprocessing switches.
    add("--vocab-size", type=int, default=32768)
    add("--min-frequency", type=int, default=2)
    add("--batch-size", type=int, default=100_000)
    add("--lowercase", action="store_true")
    add("--limit-rows", type=int, default=0)
    add("--no-bytelevel", action="store_true")
    add("--model-max-length", type=int, default=4096)
    add("--overwrite", action="store_true")

    # Special-token configuration.
    add(
        "--bos",
        action="store_true",
        help="Add [BOS] and use it in post-processing.",
    )
    add(
        "--eos",
        action="store_true",
        help="Add [EOS] and use it in post-processing.",
    )
    add(
        "--add-special-token",
        action="append",
        default=[],
        help="Add custom special token. Example: --add-special-token boc creates [BOC]. Can be repeated.",
    )
    add(
        "--column-separator",
        default="\n",
        help="Separator used when several columns are joined.",
    )

    args = parser.parse_args()

    requested = ["text"] if args.column is None else args.column
    stripped = (str(name).strip() for name in requested)
    args.column = [name for name in stripped if name]
    if not args.column:
        raise ValueError("At least one --column is required.")
    return args
def canonical_special_token(value: str) -> str:
    """Return *value* as an upper-case token wrapped in square brackets.

    Surrounding whitespace is dropped. If the input is already bracketed,
    the text between the brackets is stripped and upper-cased; otherwise the
    whole stripped input is upper-cased and wrapped.

    Raises:
        ValueError: If the input is blank, or is a bracket pair with nothing
            but whitespace inside.
    """
    cleaned = str(value).strip()
    if not cleaned:
        raise ValueError("Empty special token is not allowed.")

    already_bracketed = cleaned.startswith("[") and cleaned.endswith("]")
    name = cleaned[1:-1].strip() if already_bracketed else cleaned
    if not name:
        raise ValueError(f"Invalid special token: {cleaned}")

    return f"[{name.upper()}]"
def build_special_tokens(args) -> List[str]:
    """Assemble the ordered, de-duplicated list of special tokens.

    The order is: the base tokens, then ``[BOS]`` / ``[EOS]`` when the
    matching flag is set, then the user-supplied tokens in canonical form.
    When a token appears more than once, only its first occurrence is kept.
    """
    flagged = (("[BOS]", args.bos), ("[EOS]", args.eos))

    candidates = [*BASE_SPECIAL_TOKENS]
    candidates += [token for token, enabled in flagged if enabled]
    candidates += [canonical_special_token(raw) for raw in args.add_special_token]

    # dict preserves insertion order, so this drops later duplicates only.
    return list(dict.fromkeys(candidates))
def normalize_text(text: str, lowercase: bool) -> str:
    """Stringify *text*, strip it, apply NFC, and optionally lower-case it."""
    composed = unicodedata.normalize("NFC", str(text).strip())
    return composed.lower() if lowercase else composed
def count_parquet_rows(src: str) -> int:
    """Return the row count recorded in the Parquet file's metadata."""
    return int(pq.ParquetFile(src).metadata.num_rows)
def join_row_values(values: List[object], lowercase: bool, separator: str) -> Optional[str]:
    """Merge the cells of one row into a single training string.

    ``None`` cells are skipped, and so are cells that are empty after
    normalization. The remaining pieces are joined with *separator* and the
    result is stripped.

    Returns:
        The joined text, or ``None`` when no cell contributed any text.
    """
    normalized = (
        normalize_text(text=cell, lowercase=lowercase)
        for cell in values
        if cell is not None
    )
    pieces = [piece for piece in normalized if piece]
    return separator.join(pieces).strip() if pieces else None
def parquet_text_iterator(
    src: str,
    columns: List[str],
    batch_size: int,
    lowercase: bool,
    limit_rows: int,
    separator: str,
) -> Iterator[List[str]]:
    """Stream training texts from a Parquet file, one list per record batch.

    Each row's selected columns are merged with ``join_row_values``; rows
    that produce no text are dropped. Batches that end up empty are not
    yielded. A tqdm bar tracks the number of rows consumed.

    Args:
        src: Path of the Parquet file.
        columns: Columns to read; they are joined in this order.
        batch_size: Number of rows requested per record batch.
        lowercase: Forwarded to the text normalization.
        limit_rows: Maximum number of rows to consume; values <= 0 mean
            "no limit".
        separator: String placed between the columns of a row.

    Yields:
        Non-empty lists of normalized strings.
    """
    pf = pq.ParquetFile(src)
    total_rows = int(pf.metadata.num_rows)
    target_rows = min(total_rows, limit_rows) if limit_rows > 0 else total_rows

    rows_seen = 0
    rows_yielded = 0

    pbar = tqdm(
        total=target_rows,
        unit="rows",
        desc="Training data",
        dynamic_ncols=True,
        mininterval=0.5,
        smoothing=0.05,
    )

    try:
        batches = pf.iter_batches(
            batch_size=batch_size,
            columns=columns,
        )
        for batches_seen, batch in enumerate(batches, start=1):
            column_values = [
                batch.column(i).to_pylist()
                for i in range(len(columns))
            ]
            current_batch_size = len(column_values[0]) if column_values else 0

            # Never consume more rows than the limit still allows.
            take = max(0, min(current_batch_size, target_rows - rows_seen))

            texts = []
            for row_idx in range(take):
                text = join_row_values(
                    values=[values[row_idx] for values in column_values],
                    lowercase=lowercase,
                    separator=separator,
                )
                if text:
                    texts.append(text)

            rows_seen += take
            rows_yielded += len(texts)

            # Progress is accounted once per batch: the whole batch is
            # processed before anything is yielded, so per-row updates (and
            # the forced redraw of set_postfix's default refresh=True) only
            # add overhead to the hot loop without giving finer feedback.
            pbar.set_postfix(
                {
                    "batch": batches_seen,
                    "yielded": f"{rows_yielded:,}",
                },
                refresh=False,
            )
            pbar.update(take)

            if texts:
                yield texts

            if rows_seen >= target_rows:
                break
    finally:
        # Runs also when the consumer abandons the generator early.
        pbar.close()
def remove_output_dir(out_dir: Path, overwrite: bool):
    """Delete an existing output location so it can be recreated.

    Does nothing when *out_dir* does not exist.

    Args:
        out_dir: Path of the output folder.
        overwrite: Whether an existing path may be deleted.

    Raises:
        FileExistsError: If *out_dir* exists and *overwrite* is false.
    """
    if not out_dir.exists():
        return

    if not overwrite:
        raise FileExistsError(
            f"Output folder already exists: {out_dir}. Use --overwrite."
        )

    # shutil.rmtree refuses regular files and symlinks, so those are unlinked
    # instead; a symlink is removed itself, never the directory it points to.
    if out_dir.is_dir() and not out_dir.is_symlink():
        shutil.rmtree(out_dir)
    else:
        out_dir.unlink()
def save_json(path: Path, payload):
    """Write *payload* to *path* as UTF-8 JSON.

    The output uses a two-space indent, keeps non-ASCII characters as-is,
    and ends with a newline.
    """
    with Path(path).open("w", encoding="utf-8") as sink:
        json.dump(payload, sink, indent=2, ensure_ascii=False)
        sink.write("\n")
def require_token_id(tokenizer: "Tokenizer", token: str) -> int:
    """Return the id of *token*, failing loudly when the vocabulary lacks it.

    Raises:
        RuntimeError: If ``tokenizer.token_to_id`` returns ``None``.
    """
    found = tokenizer.token_to_id(token)
    if found is not None:
        return int(found)
    raise RuntimeError(f"Token missing after training: {token}")
def optional_token_id(tokenizer: "Tokenizer", token: str):
    """Return the id of *token* as an int, or ``None`` when it is unknown."""
    found = tokenizer.token_to_id(token)
    return None if found is None else int(found)
def save_vocab_preview(tokenizer: "Tokenizer", out_dir: Path, max_items: int = 512):
    """Write the lowest token ids to ``vocab_preview.txt`` in *out_dir*.

    One line is written per id as ``<id><TAB><repr(token)>``, covering at
    most *max_items* ids. An id with no vocabulary entry is written with an
    empty token.
    """
    id_to_token = {
        token_id: token
        for token, token_id in tokenizer.get_vocab().items()
    }

    with open(out_dir / "vocab_preview.txt", "w", encoding="utf-8") as preview:
        line_count = min(max_items, tokenizer.get_vocab_size())
        for token_id in range(line_count):
            preview.write(f"{token_id}\t{id_to_token.get(token_id, '')!r}\n")
def build_template_processing(tokenizer: Tokenizer, use_bos: bool, use_eos: bool):
    """Build the post-processor that adds ``[BOS]`` and/or ``[EOS]``.

    Returns:
        A ``TemplateProcessing`` instance, or ``None`` when neither token is
        requested.

    Raises:
        RuntimeError: If a requested token is missing from the vocabulary.
    """
    if not (use_bos or use_eos):
        return None

    # Resolve ids only for the tokens that are used, [BOS] first.
    wanted = (("[BOS]", use_bos), ("[EOS]", use_eos))
    special_tokens = [
        (token, require_token_id(tokenizer, token))
        for token, enabled in wanted
        if enabled
    ]

    # Compose the templates from their optional parts; the pair template is
    # the single template followed by the second sequence (type id 1).
    leading = "[BOS] " if use_bos else ""
    single = f"{leading}$A" + (" [EOS]" if use_eos else "")
    pair = f"{single} $B:1" + (" [EOS]:1" if use_eos else "")

    return TemplateProcessing(
        single=single,
        pair=pair,
        special_tokens=special_tokens,
    )
def save_hf_tokenizer_files(
    tokenizer: Tokenizer,
    out_dir: Path,
    requested_vocab_size: int,
    lowercase: bool,
    model_max_length: int,
    bytelevel: bool,
    special_tokens: List[str],
    use_bos: bool,
    use_eos: bool,
    source_columns: List[str],
    column_separator: str,
):
    """Write the trained tokenizer and its companion JSON files to *out_dir*.

    Saves ``tokenizer.json`` and the model files (reported below as
    ``vocab.json`` and ``merges.txt``), then writes
    ``tokenizer_config.json``, ``special_tokens_map.json``,
    ``tokenizer_info.json``, ``generation_config.json`` and
    ``vocab_preview.txt``, and prints a summary of what was saved.

    ``[BOS]`` / ``[EOS]`` entries are added to the configs whenever those
    tokens exist in the vocabulary.

    Raises:
        RuntimeError: If a base special token, or any token listed in
            *special_tokens*, is missing from the vocabulary.
    """
    out_dir.mkdir(parents=True, exist_ok=True)

    tokenizer_json_path = out_dir / "tokenizer.json"
    tokenizer.save(str(tokenizer_json_path))
    tokenizer.model.save(str(out_dir))

    # Mandatory tokens are resolved in a fixed order; optional ones may be None.
    pad_id, unk_id, sep_id, cls_id, mask_id = (
        require_token_id(tokenizer, name)
        for name in ("[PAD]", "[UNK]", "[SEP]", "[CLS]", "[MASK]")
    )
    bos_id = optional_token_id(tokenizer, "[BOS]")
    eos_id = optional_token_id(tokenizer, "[EOS]")

    special_token_ids = {}
    for name in special_tokens:
        special_token_ids[name] = require_token_id(tokenizer, name)

    actual_vocab_size = int(tokenizer.get_vocab_size())

    tokenizer_config = dict(
        tokenizer_class="PreTrainedTokenizerFast",
        model_max_length=int(model_max_length),
        clean_up_tokenization_spaces=False,
        padding_side="right",
        truncation_side="right",
        unk_token="[UNK]",
        pad_token="[PAD]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        unk_token_id=unk_id,
        pad_token_id=pad_id,
        sep_token_id=sep_id,
        cls_token_id=cls_id,
        mask_token_id=mask_id,
        vocab_size=actual_vocab_size,
        requested_vocab_size=int(requested_vocab_size),
        lowercase=bool(lowercase),
        bytelevel=bool(bytelevel),
        use_bos=bool(use_bos),
        use_eos=bool(use_eos),
        source_columns=source_columns,
        column_separator=column_separator,
        special_tokens=special_token_ids,
    )

    special_tokens_map = dict(
        unk_token="[UNK]",
        pad_token="[PAD]",
        sep_token="[SEP]",
        cls_token="[CLS]",
        mask_token="[MASK]",
    )

    generation_config = dict(
        pad_token_id=pad_id,
        sep_token_id=sep_id,
        cls_token_id=cls_id,
        mask_token_id=mask_id,
    )

    # Declare [BOS] / [EOS] in every config, but only when they exist.
    for role, literal, role_id in (("bos", "[BOS]", bos_id), ("eos", "[EOS]", eos_id)):
        if role_id is None:
            continue
        tokenizer_config[f"{role}_token"] = literal
        tokenizer_config[f"{role}_token_id"] = role_id
        special_tokens_map[f"{role}_token"] = literal
        generation_config[f"{role}_token_id"] = role_id

    added_tokens_decoder = {}
    for name, token_id in special_token_ids.items():
        added_tokens_decoder[str(token_id)] = {
            "content": name,
            "single_word": False,
            "lstrip": False,
            "rstrip": False,
            "normalized": False,
            "special": True,
        }
    tokenizer_config["added_tokens_decoder"] = added_tokens_decoder

    # Anything that is neither a base token nor [BOS]/[EOS] counts as custom.
    reserved = set(BASE_SPECIAL_TOKENS) | {"[BOS]", "[EOS]"}
    tokenizer_info = dict(
        vocab_size=actual_vocab_size,
        requested_vocab_size=int(requested_vocab_size),
        model_max_length=int(model_max_length),
        lowercase=bool(lowercase),
        bytelevel=bool(bytelevel),
        use_bos=bool(use_bos),
        use_eos=bool(use_eos),
        source_columns=source_columns,
        column_separator=column_separator,
        special_tokens=special_token_ids,
        custom_special_tokens=[
            name for name in special_tokens if name not in reserved
        ],
    )

    json_outputs = (
        ("tokenizer_config.json", tokenizer_config),
        ("special_tokens_map.json", special_tokens_map),
        ("tokenizer_info.json", tokenizer_info),
        ("generation_config.json", generation_config),
    )
    for file_name, content in json_outputs:
        save_json(out_dir / file_name, content)

    save_vocab_preview(tokenizer, out_dir)

    saved_files = (
        "tokenizer.json",
        "vocab.json",
        "merges.txt",
        "tokenizer_config.json",
        "special_tokens_map.json",
        "tokenizer_info.json",
        "generation_config.json",
        "vocab_preview.txt",
    )
    for file_name in saved_files:
        print(f"[OK] Saved {file_name}: {out_dir / file_name}")

    print()
    print("[SPECIAL TOKENS]")
    by_id = sorted(special_token_ids.items(), key=lambda item: item[1])
    for name, token_id in by_id:
        print(f"{name:<12} = {token_id}")
def main():
    """Train a BPE tokenizer from Parquet columns and save it to a folder.

    Steps: parse the CLI, validate the source file and its columns, train
    the tokenizer from the streamed texts, attach the optional BOS/EOS
    post-processor and padding, write all output files, then print a short
    encode/decode sanity check and a copy-pastable follow-up test.

    Raises:
        FileNotFoundError: If the source Parquet file does not exist.
        FileExistsError: If the output folder exists and ``--overwrite`` was
            not given.
        ValueError: If a requested column is missing from the Parquet schema.
    """
    args = parse_args()

    src = args.src
    out_dir = Path(args.out)

    if not os.path.exists(src):
        raise FileNotFoundError(f"Parquet file not found: {src}")

    # Fail fast on an existing output folder, but do not delete anything yet:
    # the old folder is removed only once training has succeeded, so a bad
    # --column or a training failure cannot destroy a previous tokenizer.
    if out_dir.exists() and not args.overwrite:
        raise FileExistsError(
            f"Output folder already exists: {out_dir}. Use --overwrite."
        )

    pf = pq.ParquetFile(src)
    schema_names = pf.schema.names

    missing_columns = [
        column
        for column in args.column
        if column not in schema_names
    ]
    if missing_columns:
        raise ValueError(
            f"Columns not found in parquet: {missing_columns}. Available columns: {schema_names}"
        )

    special_tokens = build_special_tokens(args)

    total_rows = count_parquet_rows(src)
    target_rows = min(total_rows, args.limit_rows) if args.limit_rows > 0 else total_rows

    bytelevel = not args.no_bytelevel

    print("[INFO] Lumëa BPE tokenizer training")
    print(f"[INFO] Source: {src}")
    print(f"[INFO] Columns: {args.column}")
    print(f"[INFO] Output: {out_dir}")
    print(f"[INFO] Rows total: {total_rows:,}")
    print(f"[INFO] Rows target: {target_rows:,}")
    print(f"[INFO] Vocab size: {args.vocab_size:,}")
    print(f"[INFO] Min frequency: {args.min_frequency:,}")
    print(f"[INFO] Batch size: {args.batch_size:,}")
    print(f"[INFO] Lowercase: {args.lowercase}")
    print(f"[INFO] ByteLevel: {bytelevel}")
    print(f"[INFO] Model max length: {args.model_max_length:,}")
    print(f"[INFO] Use BOS: {args.bos}")
    print(f"[INFO] Use EOS: {args.eos}")
    print(f"[INFO] Special tokens: {special_tokens}")
    print(f"[INFO] Separator repr: {repr(args.column_separator)}")
    print()

    tokenizer = Tokenizer(
        BPE(
            unk_token="[UNK]",
            byte_fallback=False,
        )
    )

    if args.lowercase:
        tokenizer.normalizer = Sequence([NFC(), Lowercase()])
    else:
        tokenizer.normalizer = NFC()

    if bytelevel:
        # The pre-tokenizer and the decoder are set together so that decoding
        # reverses the byte-level mapping applied before training.
        tokenizer.pre_tokenizer = ByteLevel(
            add_prefix_space=False,
            use_regex=True,
        )
        tokenizer.decoder = ByteLevelDecoder(
            add_prefix_space=False,
        )

    trainer = BpeTrainer(
        vocab_size=args.vocab_size,
        min_frequency=args.min_frequency,
        special_tokens=special_tokens,
        show_progress=False,
        initial_alphabet=ByteLevel.alphabet() if bytelevel else [],
    )

    iterator = parquet_text_iterator(
        src=src,
        columns=args.column,
        batch_size=args.batch_size,
        lowercase=args.lowercase,
        limit_rows=args.limit_rows,
        separator=args.column_separator,
    )

    print("[INFO] Training tokenizer...")
    print("[INFO] tqdm ETA is for parquet reading / sequence feeding.")
    print("[INFO] After 100%, BPE merge finalization can still take a moment.")
    print()

    tokenizer.train_from_iterator(
        iterator,
        trainer=trainer,
    )

    pad_id = require_token_id(tokenizer, "[PAD]")

    processor = build_template_processing(
        tokenizer=tokenizer,
        use_bos=args.bos,
        use_eos=args.eos,
    )
    if processor is not None:
        tokenizer.post_processor = processor

    tokenizer.enable_padding(
        pad_id=pad_id,
        pad_token="[PAD]",
    )

    print()
    print("[INFO] Saving tokenizer folder...")

    # Training succeeded: only now is it safe to drop the previous output.
    remove_output_dir(
        out_dir=out_dir,
        overwrite=args.overwrite,
    )

    save_hf_tokenizer_files(
        tokenizer=tokenizer,
        out_dir=out_dir,
        requested_vocab_size=args.vocab_size,
        lowercase=args.lowercase,
        model_max_length=args.model_max_length,
        bytelevel=bytelevel,
        special_tokens=special_tokens,
        use_bos=args.bos,
        use_eos=args.eos,
        source_columns=args.column,
        column_separator=args.column_separator,
    )

    print()
    print("[CHECK]")
    for token in special_tokens:
        print(f"{repr(token)} => {tokenizer.token_to_id(token)}")

    print()
    print("[ENCODE / DECODE TEST]")
    sample = "installer arch serveur minimal"
    encoded = tokenizer.encode(sample)
    decoded = tokenizer.decode(encoded.ids, skip_special_tokens=True)
    print(f"Input: {sample}")
    print(f"Tokens: {encoded.tokens}")
    print(f"IDs: {encoded.ids}")
    print(f"Decoded: {decoded}")

    print()
    print("[DONE] Tokenizer trained.")
    print(f"[DONE] Output folder: {out_dir}")

    # Print a ready-to-paste shell snippet that reloads the saved tokenizer.
    print()
    print("[NEXT TEST]")
    print("python3 - <<'PY'")
    print("from tokenizers import Tokenizer")
    print(f'tok = Tokenizer.from_file("{out_dir}/tokenizer.json")')
    print(f"specials = {repr(special_tokens)}")
    print("for s in specials:")
    print(" print(repr(s), '=>', tok.token_to_id(s))")
    print("enc = tok.encode('installer arch serveur minimal')")
    print("print(enc.tokens)")
    print("print(enc.ids)")
    print("print(tok.decode(enc.ids, skip_special_tokens=True))")
    print("PY")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()