# build_shift_char_tokenizer.py
import json
from pathlib import Path
from typing import List

from tokenizers import Tokenizer, Regex, decoders
from tokenizers.models import WordLevel
from tokenizers.normalizers import NFKC, Replace, Sequence
from tokenizers.pre_tokenizers import Split
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast
def build_shift_char_tokenizer(
out_dir: str,
base_tokens: List[str],
*,
shift_token: str = "↨",
special_tokens: List[str] = ("<pad>", "<unk>", "<bos>", "<eos>"),
include_specials_in_128: bool = True,
):
"""
Create a HF-compatible char tokenizer with SHIFT+lowercase behavior.
- base_tokens: your full 128-token alphabet if include_specials_in_128=True,
otherwise your 128 data tokens and we’ll append specials (vocab will be >128).
- shift_token must be present in base_tokens.
"""
out = Path(out_dir)
out.mkdir(parents=True, exist_ok=True)
# Validate vocab sizing
    deduped = list(dict.fromkeys(base_tokens))  # keep order, drop duplicates
    if deduped != base_tokens:
        duplicates = [t for t in deduped if base_tokens.count(t) > 1]
        raise ValueError(f"base_tokens contains duplicate tokens {duplicates}; ids are defined by list position, so each token must appear exactly once.")
if shift_token not in base_tokens:
raise ValueError(f"'{shift_token}' must be in base_tokens.")
if include_specials_in_128:
# specials must already be present in base_tokens
missing = [t for t in special_tokens if t not in base_tokens]
if missing:
raise ValueError(f"special tokens missing from base_tokens: {missing}")
if len(base_tokens) != 128:
raise ValueError(f"base_tokens must be exactly 128 when include_specials_in_128=True (got {len(base_tokens)}).")
vocab_tokens = base_tokens
else:
# append specials; vocab_size will exceed 128
vocab_tokens = base_tokens + [t for t in special_tokens if t not in base_tokens]
# Build vocab mapping
token_to_id = {tok: i for i, tok in enumerate(vocab_tokens)}
unk_token = "<unk>" if "<unk>" in token_to_id else None
# Model: fixed WordLevel
model = WordLevel(vocab=token_to_id, unk_token=unk_token)
# Explicit uppercase mapping avoids backref issues
uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
normalizer_steps = [NFKC()]
for u in uppercase:
        normalizer_steps.append(Replace(Regex(u), shift_token + u.lower()))
normalizer = Sequence(normalizer_steps)
    # Pre-tokenizer: isolate every grapheme cluster (\X), including newlines.
    # Alternative: split on single codepoints with a DOTALL pattern:
    # pre_tok = Split(Regex(r"(?s)."), behavior="isolated")
    pre_tok = Split(Regex(r"\X"), behavior="isolated")
tok = Tokenizer(model)
tok.normalizer = normalizer
tok.pre_tokenizer = pre_tok
tok.decoder = decoders.Sequence([]) # concatenate tokens verbatim
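    # Note: with this decoder, decoding returns the normalized text (e.g. "↨hi"),
    # not the original-cased input; restoring capitals would need a separate
    # post-decode step that folds SHIFT + letter back into an uppercase letter.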
# Optional: tidy BOS/EOS on encode if you want them
# (kept minimal; models often add these themselves)
if "<bos>" in token_to_id and "<eos>" in token_to_id:
tok.post_processor = TemplateProcessing(
single="$0",
pair="$A $B",
special_tokens=[
# add e.g. ("<bos>", id), ("<eos>", id) here if you want automatic wrapping
],
)
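        # A sketch (not enabled here) of automatic BOS/EOS wrapping, in case it is
        # wanted later; ids are taken from the token_to_id mapping built above:
        #
        # tok.post_processor = TemplateProcessing(
        #     single="<bos> $A <eos>",
        #     pair="<bos> $A <eos> $B <eos>",
        #     special_tokens=[
        #         ("<bos>", token_to_id["<bos>"]),
        #         ("<eos>", token_to_id["<eos>"]),
        #     ],
        # )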
    # Wrap in an HF fast tokenizer and save
    hf_tok = PreTrainedTokenizerFast(
        tokenizer_object=tok,
        bos_token="<bos>" if "<bos>" in token_to_id else None,
        eos_token="<eos>" if "<eos>" in token_to_id else None,
        unk_token=unk_token,
        pad_token="<pad>" if "<pad>" in token_to_id else None,
        model_max_length=1024,  # adjust for your use case
    )
    # save_pretrained writes tokenizer.json, tokenizer_config.json and
    # special_tokens_map.json, so model_max_length above ends up in the saved
    # config; writing tokenizer_config.json by hand beforehand has no effect
    # because save_pretrained overwrites it.
    hf_tok.save_pretrained(out_dir)
print(f"Saved tokenizer to: {out_dir}")
print(f"Vocab size: {len(vocab_tokens)} (include_specials_in_128={include_specials_in_128})")
if __name__ == "__main__":
# Example: define your exact 128 tokens including specials and SHIFT.
# Keep ordering stable; ids are index positions.
# Below is a sane template to edit. Make sure length == 128.
SHIFT = "↨"
specials = ["<pad>", "<unk>", "<bos>", "<eos>"]
# Base character set (edit this list to be exactly 124 non-specials + 4 specials = 128)
chars = list("\n\t ") # newline, tab, space
chars += list("0123456789")
chars += list("abcdefghijklmnopqrstuvwxyz")
# Include punctuation/symbols you need. Keep only what you’ll actually see.
chars += list("\"!$&'#,/+=-<>*@.:;[]{}()^_?") # from your sample
chars += list("èé") # sample diacritics you mentioned
# Add SHIFT token
# Ensure NO uppercase letters are in the vocab (they’re represented via SHIFT+lowercase)
base_tokens_wo_specials = [SHIFT] + chars
# If you want exactly 128 including specials, adjust to 124 data tokens + 4 specials
# Add or remove symbols to hit 124 before specials:
# Pad with rarely-used placeholders if needed:
while len(base_tokens_wo_specials) < 124:
base_tokens_wo_specials.append(f"¤{len(base_tokens_wo_specials)}") # harmless placeholders
if len(base_tokens_wo_specials) != 124:
raise SystemExit(f"Currently have {len(base_tokens_wo_specials)} data tokens; adjust to 124 before specials.")
base_tokens_including_specials = specials + base_tokens_wo_specials # specials first is fine
build_shift_char_tokenizer(
out_dir="char128_shift_tokenizer",
base_tokens=base_tokens_including_specials,
shift_token=SHIFT,
special_tokens=specials,
include_specials_in_128=True,
)
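    # Optional sanity check: reload the saved tokenizer and round-trip a sample.
    # A minimal sketch; it assumes the directory above was just written and that
    # `transformers` is available in the current environment.
    reloaded = PreTrainedTokenizerFast.from_pretrained("char128_shift_tokenizer")
    sample = "Hello, World!"
    ids = reloaded.encode(sample)
    print("tokens :", reloaded.convert_ids_to_tokens(ids))
    print("decoded:", reloaded.decode(ids))  # expect the SHIFT+lowercase form: "↨hello, ↨world!"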