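"""A word-level tokenizer for chess move sequences: one vocabulary entry per move."""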
from __future__ import annotations

import json
import os
import re
import shutil
from collections import Counter
from typing import Dict, List, Optional

from datasets import load_dataset
from transformers import PreTrainedTokenizer

# Two adjacent board squares, e.g. "e2e4" inside "WPe2e4".
SQUARE_MOVE_PATTERN = re.compile(r"([a-h][1-8])([a-h][1-8])")
# A promotion piece suffix, e.g. "=Q" in "WPe7e8=Q".
PROMOTION_PATTERN = re.compile(r"=([NBRQ])")

def normalize_move(token: str) -> str:
    """Reduce a raw move token to the canonical <piece><from><to>[=<promo>] form."""
    if token.startswith("["):
        # Special tokens such as "[BOS]" pass through unchanged.
        return token
    move_match = SQUARE_MOVE_PATTERN.search(token)
    if not move_match:
        return token
    from_sq, to_sq = move_match.group(1), move_match.group(2)
    promotion_suffix = ""
    promo_match = PROMOTION_PATTERN.search(token)
    if promo_match:
        promotion_suffix = "=" + promo_match.group(1)
    # The first two characters carry the color/piece prefix (e.g. "WP" for a white pawn).
    piece_prefix = token[:2] if len(token) >= 2 else "WP"
    return f"{piece_prefix}{from_sq}{to_sq}{promotion_suffix}"
class ChessTokenizer(PreTrainedTokenizer):
    """A HuggingFace-compatible tokenizer whose vocabulary maps normalized moves to ids."""

    model_input_names = ["input_ids", "attention_mask"]
    vocab_files_names = {"vocab_file": "vocab.json"}

    PAD_TOKEN = "[PAD]"
    BOS_TOKEN = "[BOS]"
    EOS_TOKEN = "[EOS]"
    UNK_TOKEN = "[UNK]"

    def __init__(
        self,
        vocab_file: Optional[str] = None,
        vocab: Optional[Dict[str, int]] = None,
        **kwargs,
    ):
        self._pad_token = self.PAD_TOKEN
        self._bos_token = self.BOS_TOKEN
        self._eos_token = self.EOS_TOKEN
        self._unk_token = self.UNK_TOKEN
        # Drop any caller-supplied special tokens; the class constants always win.
        for t in ["pad_token", "bos_token", "eos_token", "unk_token"]:
            kwargs.pop(t, None)
        if vocab is None:
            if vocab_file is None:
                vocab_file = os.path.join(os.path.dirname(__file__), "vocab.json")
            self.vocab_file = vocab_file
            if os.path.exists(vocab_file):
                with open(vocab_file, "r", encoding="utf-8") as f:
                    self._vocab = json.load(f)
            else:
                self._vocab = self._create_default_vocab()
        else:
            self._vocab = vocab
            self.vocab_file = vocab_file
        self._ids_to_tokens = {v: k for k, v in self._vocab.items()}
        # The vocab must be in place before super().__init__, which calls get_vocab().
        super().__init__(
            pad_token=self.PAD_TOKEN,
            bos_token=self.BOS_TOKEN,
            eos_token=self.EOS_TOKEN,
            unk_token=self.UNK_TOKEN,
            **kwargs,
        )

    def save_pretrained(self, save_directory: str, **kwargs):
        super().save_pretrained(save_directory, **kwargs)
        # Ship this module next to the saved files so AutoTokenizer can load the
        # class with trust_remote_code=True.
        src_path = os.path.abspath(__file__)
        dst_path = os.path.join(save_directory, "tokenizer.py")
        if src_path != dst_path:
            shutil.copy(src_path, dst_path)
        config_path = os.path.join(save_directory, "tokenizer_config.json")
        if os.path.exists(config_path):
            with open(config_path, "r") as f:
                cfg = json.load(f)
            cfg["auto_map"] = {"AutoTokenizer": "tokenizer.ChessTokenizer"}
            with open(config_path, "w") as f:
                json.dump(cfg, f, indent=2)

    def _create_default_vocab(self) -> Dict[str, int]:
        # Fallback vocabulary containing only the four special tokens.
        return {
            t: i
            for i, t in enumerate([self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN])
        }

    @classmethod
    def build_vocab_from_dataset(
        cls,
        dataset_name: str,
        split: str = "train",
        column: str = "text",
        max_vocab_size: int = 512,
        min_frequency: int = 500,
        max_samples: int = 100000,
    ) -> ChessTokenizer:
        """Count normalized moves in a streamed dataset and build a vocabulary."""
        # Stream the dataset so the full corpus never has to fit in memory.
        ds = load_dataset(dataset_name, split=split, streaming=True)
        ds = ds.take(max_samples)
        counter = Counter()
        for ex in ds:
            moves = [normalize_move(t) for t in ex[column].split()]
            counter.update(moves)
        special = [cls.PAD_TOKEN, cls.BOS_TOKEN, cls.EOS_TOKEN, cls.UNK_TOKEN]
        # Keep the most frequent moves, dropping anything rarer than min_frequency
        # (previously this parameter was accepted but never applied).
        most_common = [
            (t, c)
            for t, c in counter.most_common(max_vocab_size - len(special))
            if c >= min_frequency
        ]
        vocab = {t: i for i, t in enumerate(special + [t for t, _ in most_common])}
        return cls(vocab=vocab)

    @property
    def vocab_size(self) -> int:
        return len(self._vocab)

    def get_vocab(self) -> Dict[str, int]:
        return dict(self._vocab)

    def _tokenize(self, text: str) -> List[str]:
        # Games are whitespace-separated move lists; each move becomes one token.
        return [normalize_move(t) for t in text.strip().split()]

    def _convert_token_to_id(self, token: str) -> int:
        return self._vocab.get(token, self._vocab.get(self.UNK_TOKEN))

    def _convert_id_to_token(self, index: int) -> str:
        return self._ids_to_tokens.get(index, self.UNK_TOKEN)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        # Decoding drops special tokens and rejoins the moves with spaces.
        return " ".join(
            t
            for t in tokens
            if t not in [self.PAD_TOKEN, self.BOS_TOKEN, self.EOS_TOKEN, self.UNK_TOKEN]
        )

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None):
        if not os.path.isdir(save_directory):
            os.makedirs(save_directory, exist_ok=True)
        path = os.path.join(
            save_directory, (filename_prefix + "-" if filename_prefix else "") + "vocab.json"
        )
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self._vocab, f, ensure_ascii=False, indent=2)
        return (path,)
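
# Minimal usage sketch ("user/chess-games" below is a placeholder dataset id):
#   tok = ChessTokenizer.build_vocab_from_dataset("user/chess-games", column="text")
#   tok.save_pretrained("./chess-tokenizer")
#   ids = tok("WPe2e4 BPe7e5")["input_ids"]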