# build_shift_char_tokenizer.py
import json
from pathlib import Path
from typing import List

from tokenizers import Tokenizer, Regex, decoders
from tokenizers.models import WordLevel
from tokenizers.normalizers import Sequence, Replace, NFKC
from tokenizers.pre_tokenizers import Split
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast


def build_shift_char_tokenizer(
    out_dir: str,
    base_tokens: List[str],
    *,
    shift_token: str = "↨",
    special_tokens: List[str] = ("<pad>", "<unk>", "<bos>", "<eos>"),
    include_specials_in_128: bool = True,
):
    """
    Create a HF-compatible char tokenizer with SHIFT+lowercase behavior.
    - base_tokens: your full 128-token alphabet if include_specials_in_128=True,
      otherwise your 128 data tokens and we'll append specials (vocab will be >128).
    - shift_token must be present in base_tokens.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Validate vocab sizing
    base_set = list(dict.fromkeys(base_tokens))  # keep order, dedupe
    if base_set != base_tokens:
        dupes = [t for t in base_set if base_tokens.count(t) > 1]
        raise ValueError(f"base_tokens has duplicates; ids are index positions, so entries must be unique: {dupes}")
    if shift_token not in base_tokens:
        raise ValueError(f"'{shift_token}' must be in base_tokens.")

    if include_specials_in_128:
        # specials must already be present in base_tokens
        missing = [t for t in special_tokens if t not in base_tokens]
        if missing:
            raise ValueError(f"special tokens missing from base_tokens: {missing}")
        if len(base_tokens) != 128:
            raise ValueError(
                f"base_tokens must be exactly 128 when include_specials_in_128=True (got {len(base_tokens)})."
            )
        vocab_tokens = base_tokens
    else:
        # append specials; vocab_size will exceed 128
        vocab_tokens = base_tokens + [t for t in special_tokens if t not in base_tokens]

    # Build vocab mapping: a token's id is its index position
    token_to_id = {tok: i for i, tok in enumerate(vocab_tokens)}
    unk_token = "<unk>" if "<unk>" in token_to_id else None

    # Model: fixed WordLevel (each "word" is a single character-level token)
    model = WordLevel(vocab=token_to_id, unk_token=unk_token)
    # Normalizer: rewrite every uppercase letter as SHIFT + lowercase.
    # An explicit per-letter Replace mapping avoids regex backref issues.
    uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    normalizer_steps = [NFKC()]
    for u in uppercase:
        normalizer_steps.append(Replace(Regex(u), shift_token + u.lower()))
    normalizer = Sequence(normalizer_steps)

    # Pre-tokenizer: isolate every character (grapheme cluster), including newlines
    # pre_tok = Split(Regex(r"(?s)."), behavior="isolated")  # alternative: per-codepoint with DOTALL
    pre_tok = Split(Regex(r"\X"), behavior="isolated")

    tok = Tokenizer(model)
    tok.normalizer = normalizer
    tok.pre_tokenizer = pre_tok
    tok.decoder = decoders.Sequence([])  # concatenate tokens verbatim
    # Optional: tidy BOS/EOS on encode if you want them
    # (kept minimal; models often add these themselves)
    if "<bos>" in token_to_id and "<eos>" in token_to_id:
        tok.post_processor = TemplateProcessing(
            single="$A",
            pair="$A $B",
            special_tokens=[
                # add e.g. ("<bos>", id), ("<eos>", id) here if you want automatic wrapping
            ],
        )
    # Wrap in HF fast tokenizer and save
    hf_tok = PreTrainedTokenizerFast(
        tokenizer_object=tok,
        bos_token="<bos>" if "<bos>" in token_to_id else None,
        eos_token="<eos>" if "<eos>" in token_to_id else None,
        unk_token=unk_token,
        pad_token="<pad>" if "<pad>" in token_to_id else None,
        model_max_length=1024,  # adjust for your use case
    )
    # save_pretrained() writes tokenizer_config.json itself (including model_max_length),
    # so writing a separate config file beforehand would just be overwritten.
    hf_tok.save_pretrained(out_dir)

    print(f"Saved tokenizer to: {out_dir}")
    print(f"Vocab size: {len(vocab_tokens)} (include_specials_in_128={include_specials_in_128})")
if __name__ == "__main__":
    # Example: define your exact 128 tokens including specials and SHIFT.
    # Keep ordering stable; ids are index positions.
    # Below is a sane template to edit. Make sure length == 128.
    SHIFT = "↨"
    specials = ["<pad>", "<unk>", "<bos>", "<eos>"]

    # Base character set (edit this list to be exactly 124 non-specials + 4 specials = 128)
    chars = list("\n\t ")  # newline, tab, space
    chars += list("0123456789")
    chars += list("abcdefghijklmnopqrstuvwxyz")
    # Include punctuation/symbols you need. Keep only what you'll actually see.
    chars += list("\"!$&'#,/+=-<>*@.:;[]{}()^_?")  # from your sample
    chars += list("èé")  # sample diacritics you mentioned

    # Add the SHIFT token.
    # Ensure NO uppercase letters are in the vocab (they're represented via SHIFT+lowercase).
    base_tokens_wo_specials = [SHIFT] + chars

    # If you want exactly 128 including specials, adjust to 124 data tokens + 4 specials.
    # Add or remove symbols to hit 124 before specials;
    # pad with rarely-used placeholders if needed:
    while len(base_tokens_wo_specials) < 124:
        base_tokens_wo_specials.append(f"¤{len(base_tokens_wo_specials)}")  # harmless placeholders
    if len(base_tokens_wo_specials) != 124:
        raise SystemExit(f"Currently have {len(base_tokens_wo_specials)} data tokens; adjust to 124 before specials.")

    base_tokens_including_specials = specials + base_tokens_wo_specials  # specials first is fine

    build_shift_char_tokenizer(
        out_dir="char128_shift_tokenizer",
        base_tokens=base_tokens_including_specials,
        shift_token=SHIFT,
        special_tokens=specials,
        include_specials_in_128=True,
    )
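
    # Optional sanity check (a minimal sketch; assumes the save above succeeded and
    # that the sample characters below are all in your final 128-token vocab):
    reloaded = PreTrainedTokenizerFast.from_pretrained("char128_shift_tokenizer")
    sample = "Hello, World!"
    ids = reloaded.encode(sample)
    print("tokens :", reloaded.convert_ids_to_tokens(ids))
    # decode() should return the normalized form ("↨hello, ↨world!"), not the original
    # casing; invert the SHIFT marker yourself if you need exact round-tripping.
    print("decoded:", reloaded.decode(ids))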