from pathlib import Path
from typing import List

from tokenizers import Tokenizer, Regex, decoders
from tokenizers.models import WordLevel
from tokenizers.normalizers import NFKC, Replace, Sequence
from tokenizers.pre_tokenizers import Split
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast


def build_shift_char_tokenizer(
    out_dir: str,
    base_tokens: List[str],
    *,
    shift_token: str = "↨",
    special_tokens: List[str] = ("<pad>", "<unk>", "<bos>", "<eos>"),
    include_specials_in_128: bool = True,
):
    """
    Create an HF-compatible character tokenizer with SHIFT+lowercase behavior.

    - base_tokens: the full 128-token alphabet if include_specials_in_128=True;
      otherwise the data tokens only, and the specials are appended (vocab > 128).
    - shift_token must be present in base_tokens.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)

    # Order defines token ids, so duplicates are not allowed.
    deduped = list(dict.fromkeys(base_tokens))
    if deduped != base_tokens:
        raise ValueError(
            f"base_tokens has duplicates; order must define ids. "
            f"Got {base_tokens}, deduplicated to {deduped}."
        )

    if shift_token not in base_tokens:
        raise ValueError(f"'{shift_token}' must be in base_tokens.")

    if include_specials_in_128:
        missing = [t for t in special_tokens if t not in base_tokens]
        if missing:
            raise ValueError(f"special tokens missing from base_tokens: {missing}")
        if len(base_tokens) != 128:
            raise ValueError(
                f"base_tokens must be exactly 128 when include_specials_in_128=True "
                f"(got {len(base_tokens)})."
            )
        vocab_tokens = base_tokens
    else:
        vocab_tokens = base_tokens + [t for t in special_tokens if t not in base_tokens]

    # Ids follow the order of vocab_tokens.
    token_to_id = {tok: i for i, tok in enumerate(vocab_tokens)}
    unk_token = "<unk>" if "<unk>" in token_to_id else None

    model = WordLevel(vocab=token_to_id, unk_token=unk_token)

    # Normalizer: NFKC, then rewrite each uppercase letter as shift_token followed
    # by its lowercase form (e.g. "Hi" -> "↨hi" with the default shift token).
    # The Replace rules already lowercase, so no separate Lowercase step is needed.
    uppercase = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    normalizer_steps = [NFKC()]
    for u in uppercase:
        normalizer_steps.append(Replace(Regex(u), shift_token + u.lower()))
    normalizer = Sequence(normalizer_steps)

    # Pre-tokenizer: split the text into single grapheme clusters (\X), keeping
    # each one as its own token.
    pre_tok = Split(Regex(r"\X"), behavior="isolated")

    tok = Tokenizer(model)
    tok.normalizer = normalizer
    tok.pre_tokenizer = pre_tok

    # An empty decoder sequence concatenates tokens on decode instead of the
    # default space-joining, which is what we want for a character tokenizer.
    tok.decoder = decoders.Sequence([])

    # If <bos>/<eos> are part of the vocab, wrap every encoded sequence with them.
    if "<bos>" in token_to_id and "<eos>" in token_to_id:
        tok.post_processor = TemplateProcessing(
            single="<bos> $A <eos>",
            pair="<bos> $A <eos> $B <eos>",
            special_tokens=[
                ("<bos>", token_to_id["<bos>"]),
                ("<eos>", token_to_id["<eos>"]),
            ],
        )

    hf_tok = PreTrainedTokenizerFast(
        tokenizer_object=tok,
        bos_token="<bos>" if "<bos>" in token_to_id else None,
        eos_token="<eos>" if "<eos>" in token_to_id else None,
        unk_token=unk_token,
        pad_token="<pad>" if "<pad>" in token_to_id else None,
        model_max_length=1024,
    )

    # model_max_length was passed to PreTrainedTokenizerFast above, so
    # save_pretrained() records it in tokenizer_config.json along with the
    # special-token settings.
    hf_tok.save_pretrained(out_dir)
    print(f"Saved tokenizer to: {out_dir}")
    print(f"Vocab size: {len(vocab_tokens)} (include_specials_in_128={include_specials_in_128})")

if __name__ == "__main__":
    SHIFT = "↨"
    specials = ["<pad>", "<unk>", "<bos>", "<eos>"]

    # Data alphabet: whitespace, digits, lowercase letters, punctuation, accented chars.
    chars = list("\n\t ")
    chars += list("0123456789")
    chars += list("abcdefghijklmnopqrstuvwxyz")
    chars += list("\"!$&'#,/+=-<>*@.:;[]{}()^_?")
    chars += list("èé")

    base_tokens_wo_specials = [SHIFT] + chars

    # Pad the data alphabet with placeholder tokens (¤N) until it has 124 entries,
    # so that the 4 specials bring the total to exactly 128.
    while len(base_tokens_wo_specials) < 124:
        base_tokens_wo_specials.append(f"¤{len(base_tokens_wo_specials)}")
    if len(base_tokens_wo_specials) != 124:
        raise SystemExit(
            f"Currently have {len(base_tokens_wo_specials)} data tokens; "
            f"adjust to 124 before adding specials."
        )

    base_tokens_including_specials = specials + base_tokens_wo_specials
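
    # Optional sanity check on the id layout ("order defines ids"): the specials
    # take ids 0..3 and the shift token sits right after them.
    assert len(base_tokens_including_specials) == 128
    assert base_tokens_including_specials.index(SHIFT) == len(specials)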

    build_shift_char_tokenizer(
        out_dir="char128_shift_tokenizer",
        base_tokens=base_tokens_including_specials,
        shift_token=SHIFT,
        special_tokens=specials,
        include_specials_in_128=True,
    )
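
    # Quick round-trip sketch (assumes the build above succeeded and uses the
    # hypothetical restore_uppercase helper defined earlier): uppercase letters
    # should come back as "↨" + lowercase after decode.
    tk = PreTrainedTokenizerFast.from_pretrained("char128_shift_tokenizer")
    ids = tk("Hello, World!")["input_ids"]
    decoded = tk.decode(ids, skip_special_tokens=True)
    print(decoded)                     # expected: "↨hello, ↨world!"
    print(restore_uppercase(decoded))  # expected: "Hello, World!"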