|
|
"""Compact Chess BPE Tokenizer: Train, Upload, Load, Inference""" |
|
|
import os, json, rustbpe, tiktoken |
|
|
from datasets import load_dataset |
|
|
from huggingface_hub import HfApi, create_repo, upload_folder, hf_hub_download |
|
|
|
|
|
REPO_ID = "ItsMaxNorm/bpess" |
|
|
|
|
|
def train(vocab_size=4096, split="train[0:10000]"):
    """Fit a BPE tokenizer on games from the ``angeluriot/chess_games`` dataset.

    Each game's ``moves_custom`` list is joined with single spaces into one
    training document; games whose ``moves_custom`` is empty/falsy are skipped.

    Args:
        vocab_size: target vocabulary size handed to ``train_from_iterator``.
        split: ``datasets`` split expression selecting which rows to load.

    Returns:
        The trained ``rustbpe.Tokenizer``.
    """
    games = load_dataset('angeluriot/chess_games', split=split)

    def documents():
        # Stream one space-separated move string per non-empty game, lazily,
        # so the whole corpus never has to be materialized as a list.
        for game in games:
            moves = game['moves_custom']
            if moves:
                yield ' '.join(moves)

    tokenizer = rustbpe.Tokenizer()
    tokenizer.train_from_iterator(documents(), vocab_size)
    return tokenizer
|
|
|
|
|
def save(tok, path="./tokenizer"):
    """Save tokenizer files locally.

    Writes three files into ``path`` (the directory is created if missing):

    * ``vocab.json`` -- human-readable ``{token_text: rank}`` map. Token bytes
      are decoded as UTF-8 with replacement, so byte sequences that are not
      valid UTF-8 (e.g. single bytes >= 0x80, partial multibyte glyphs) all
      collapse onto U+FFFD. This file is therefore NOT a lossless vocabulary.
    * ``ranks.tiktoken`` -- lossless ``<base64 token bytes> <rank>`` lines,
      the same layout tiktoken uses for its ``.tiktoken`` files.
    * ``config.json`` -- the split ``pattern`` and ``vocab_size``.

    Args:
        tok: trained tokenizer exposing ``get_mergeable_ranks()`` (iterable of
            ``(token_bytes, rank)`` pairs), ``get_pattern()`` and ``vocab_size``.
        path: output directory.

    Returns:
        ``path``, unchanged.
    """
    import base64

    os.makedirs(path, exist_ok=True)
    # Materialize the bytes once so both vocabulary files see identical data.
    ranks = [(bytes(k), v) for k, v in tok.get_mergeable_ranks()]

    with open(os.path.join(path, "vocab.json"), "w", encoding="utf-8") as f:
        json.dump({k.decode('utf-8', errors='replace'): v for k, v in ranks},
                  f, indent=2)

    # Lossless companion file: base64 keeps every byte sequence distinct.
    with open(os.path.join(path, "ranks.tiktoken"), "w", encoding="ascii") as f:
        f.writelines(f"{base64.b64encode(k).decode('ascii')} {v}\n"
                     for k, v in ranks)

    with open(os.path.join(path, "config.json"), "w", encoding="utf-8") as f:
        json.dump({"pattern": tok.get_pattern(), "vocab_size": tok.vocab_size}, f)

    return path
|
|
|
|
|
def upload(tok, repo_id=REPO_ID, private=False):
    """Upload tokenizer to HuggingFace Hub.

    Saves ``tok`` into the default local directory via ``save`` and uploads
    that whole folder to ``repo_id``, creating the repo first if needed.

    Args:
        tok: trained tokenizer accepted by ``save``.
        repo_id: target Hugging Face repo id.
        private: visibility requested when the repo is created.
    """
    path = save(tok)
    # exist_ok=True replaces the former bare ``except: pass``: an existing repo
    # is fine, but auth/network failures now surface here instead of being
    # silently swallowed (along with KeyboardInterrupt/SystemExit).
    create_repo(repo_id, private=private, exist_ok=True)
    HfApi().upload_folder(folder_path=path, repo_id=repo_id)
    print(f"Uploaded: https://huggingface.co/{repo_id}")
|
|
|
|
|
def load_tiktoken(repo_id=REPO_ID):
    """Load tokenizer from HuggingFace as tiktoken Encoding.

    Prefers the lossless ``ranks.tiktoken`` file (one ``<base64 token bytes>
    <rank>`` pair per line). Repos uploaded before that file existed only
    contain ``vocab.json``, whose keys were UTF-8-decoded with replacement;
    it is used as a fallback but cannot represent token byte sequences that
    are not valid UTF-8, so encoding non-ASCII text may fail with it.

    Args:
        repo_id: Hugging Face repo holding ``config.json`` plus a vocabulary.

    Returns:
        A ``tiktoken.Encoding`` named ``"chess"`` with no special tokens.
    """
    import base64
    try:
        from huggingface_hub.errors import EntryNotFoundError
    except ImportError:  # older huggingface_hub releases
        from huggingface_hub.utils import EntryNotFoundError

    with open(hf_hub_download(repo_id, "config.json"), encoding="utf-8") as f:
        config = json.load(f)

    try:
        ranks_path = hf_hub_download(repo_id, "ranks.tiktoken")
    except EntryNotFoundError:
        ranks_path = None

    if ranks_path is not None:
        mergeable_ranks = {}
        with open(ranks_path, encoding="ascii") as f:
            for line in f:
                if line.strip():
                    token_b64, rank = line.split()
                    mergeable_ranks[base64.b64decode(token_b64)] = int(rank)
    else:
        # Legacy, lossy format (see docstring).
        with open(hf_hub_download(repo_id, "vocab.json"), encoding="utf-8") as f:
            vocab = json.load(f)
        mergeable_ranks = {k.encode('utf-8', errors='replace'): v
                           for k, v in vocab.items()}

    return tiktoken.Encoding(
        name="chess", pat_str=config["pattern"],
        mergeable_ranks=mergeable_ranks,
        special_tokens={},
    )
|
|
|
|
|
if __name__ == "__main__":
    # Train on rows 0..9999 of the train split and publish the result.
    trained = train(vocab_size=4096, split="train[0:10000]")
    print(f"Trained: {trained.vocab_size} tokens")
    upload(trained, REPO_ID)

    # Round-trip a short opening through the tokenizer downloaded from the Hub.
    encoding = load_tiktoken(REPO_ID)
    sample = "w.♘g1♘f3.. b.♟c7♟c5.. w.♙d2♙d4.."
    token_ids = encoding.encode(sample)
    print(f"Encoded: {token_ids[:10]}... ({len(token_ids)} tokens)")
    print(f"Decoded: {encoding.decode(token_ids)}")
|
|
|