|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from typing import List, Optional, Tuple, Union
|
| from transformers.tokenization_utils import PreTrainedTokenizer
|
| from transformers import AutoTokenizer
|
| import json
|
| import regex as re
|
| from pathlib import Path
|
| from typing import Dict, List, Optional, Union
|
|
|
# Matches `'<key>': <integer>` pairs: a non-empty single-quoted key (group 1),
# a colon, optional whitespace, then decimal digits (group 2).
# NOTE(review): not referenced anywhere in the visible code of this file —
# confirm whether it is still needed before relying on or removing it.
BYTES_TO_UNICODE_REGEX = re.compile(r"'([^']+)':\s*([0-9]+)")
|
|
|
def bytes_to_unicode():
    """Build the reversible byte -> character table used for byte-level BPE.

    Byte values in the ranges ``!``..``~``, ``¡``..``¬`` and ``®``..``ÿ`` map to
    the character with the same code point. Every other byte value is assigned,
    in ascending byte order, a fresh character starting at code point 256.

    Returns:
        dict[int, str]: all 256 byte values mapped to distinct one-character
        strings, with the printable ranges inserted first.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}

    # Remaining bytes get consecutive code points past the Latin-1 range so
    # that no two bytes share a character.
    next_code_point = 2**8
    for byte in range(2**8):
        if byte not in table:
            table[byte] = chr(next_code_point)
            next_code_point += 1
    return table
|
|
|
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    Args:
        word: A sequence of symbols (e.g. a tuple of strings, or a str).

    Returns:
        set of ``(symbol, next_symbol)`` tuples. Sequences with fewer than two
        symbols (including an empty one) yield an empty set instead of raising.
    """
    # Pair each symbol with its successor; zip stops at the shorter input,
    # which handles empty and single-symbol words without special-casing.
    return set(zip(word, word[1:]))
|
|
|
class SapnousTokenizer(PreTrainedTokenizer):
    """Byte-level BPE tokenizer with additional vision/image/video marker tokens.

    Text is split with a pre-tokenization regex. Each piece is encoded to UTF-8,
    and every byte is mapped to a single character via ``bytes_to_unicode``. The
    result is then merged using the ranked merges from ``merges_file``.

    Args:
        vocab_file: Path to a JSON file mapping token strings to ids.
        merges_file: Optional path to a merges file (one space-separated pair
            per line, optionally preceded by a ``#version`` header line).
            Without it no merges are applied.
        unk_token, bos_token, eos_token, pad_token: Standard special tokens
            forwarded to ``PreTrainedTokenizer``.
        vision_start_token, vision_end_token, image_token, video_token:
            Multimodal marker tokens used by the ``prepare_for_*`` helpers;
            registered as additional special tokens so they are kept whole.
        add_prefix_space: Prepend a space to the text before tokenizing.
        **kwargs: Forwarded to ``PreTrainedTokenizer``.
    """

    # File names ``from_pretrained`` looks for when resolving the ``__init__``
    # file arguments; ``save_vocabulary`` writes files under the same names.
    vocab_files_names = {"vocab_file": "vocab.json", "merges_file": "merges.txt"}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(
        self,
        vocab_file: str,
        merges_file: Optional[str] = None,
        unk_token: str = "<|endoftext|>",
        bos_token: str = "<|startoftext|>",
        eos_token: str = "<|endoftext|>",
        pad_token: str = "<|pad|>",
        vision_start_token: str = "<|vision_start|>",
        vision_end_token: str = "<|vision_end|>",
        image_token: str = "<|image|>",
        video_token: str = "<|video|>",
        add_prefix_space: bool = False,
        **kwargs,
    ):
        self.vocab_file = vocab_file
        self.merges_file = merges_file
        self.add_prefix_space = add_prefix_space

        # Multimodal markers read by the ``prepare_for_*`` helpers.
        self.vision_start_token = vision_start_token
        self.vision_end_token = vision_end_token
        self.image_token = image_token
        self.video_token = video_token

        # Strings that ``bpe`` returns unchanged.
        self.special_tokens = {
            "unk_token": unk_token,
            "bos_token": bos_token,
            "eos_token": eos_token,
            "pad_token": pad_token,
            "vision_start_token": vision_start_token,
            "vision_end_token": vision_end_token,
            "image_token": image_token,
            "video_token": video_token,
        }

        # Load vocabulary and merges BEFORE the base constructor: recent
        # ``PreTrainedTokenizer.__init__`` versions register the special tokens,
        # which calls ``get_vocab()`` on this instance.
        with Path(vocab_file).open(encoding="utf-8") as f:
            self.encoder = json.load(f)
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = self._load_bpe_ranks(merges_file)

        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\w+| ?\d+| ?[^\s\w\d]+|\s+(?!\S)|\s+""")

        # Register the multimodal markers as additional special tokens so the
        # base class matches them whole; otherwise the pre-tokenization regex
        # would split e.g. "<|image|>" into several pieces before BPE.
        additional_special_tokens = list(kwargs.pop("additional_special_tokens", None) or [])
        seen = {str(tok) for tok in additional_special_tokens}
        for tok in (vision_start_token, vision_end_token, image_token, video_token):
            if tok is not None and str(tok) not in seen:
                additional_special_tokens.append(tok)
                seen.add(str(tok))

        super().__init__(
            unk_token=unk_token,
            bos_token=bos_token,
            eos_token=eos_token,
            pad_token=pad_token,
            additional_special_tokens=additional_special_tokens,
            # Forwarded so the base class records them in ``init_kwargs`` and
            # they are written to the tokenizer config by ``save_pretrained``.
            vision_start_token=vision_start_token,
            vision_end_token=vision_end_token,
            image_token=image_token,
            video_token=video_token,
            add_prefix_space=add_prefix_space,
            **kwargs,
        )

    @staticmethod
    def _load_bpe_ranks(merges_file: Optional[str]) -> Dict[Tuple[str, ...], int]:
        """Read ``merges_file`` into a ``{merge_pair: rank}`` map (empty if no file)."""
        if not merges_file:
            return {}
        with Path(merges_file).open(encoding="utf-8") as f:
            lines = f.read().splitlines()
        # Only a leading "#version..." line is a header. Always dropping the
        # first line would silently lose the top-ranked merge of header-less files.
        if lines and lines[0].startswith("#version"):
            lines = lines[1:]
        merges = [tuple(line.split()) for line in lines if line.strip()]
        return {merge: rank for rank, merge in enumerate(merges)}

    def bpe(self, token: str) -> str:
        """Apply the ranked BPE merges to one byte-encoded pre-token.

        Repeatedly merges the adjacent symbol pair with the lowest rank in
        ``self.bpe_ranks`` until no ranked pair remains.

        Args:
            token: A pre-token already mapped through ``byte_encoder``.

        Returns:
            The resulting symbols joined by single spaces. Values of
            ``self.special_tokens`` and tokens shorter than two symbols are
            returned unchanged.
        """
        if token in self.special_tokens.values():
            return token

        word = tuple(token)
        if len(word) < 2:
            return token
        pairs = get_pairs(word)

        while True:
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                # ``first`` may be the final symbol, with nothing after it to
                # merge; indexing word[j + 1] unguarded raised IndexError.
                if j < len(word) - 1 and word[j + 1] == second:
                    new_word.append(first + second)
                    i = j + 2
                else:
                    new_word.append(word[j])
                    i = j + 1
            word = tuple(new_word)
            if len(word) == 1:
                break
            pairs = get_pairs(word)
        return " ".join(word)

    def _tokenize(self, text: str) -> List[str]:
        """Split ``text`` into BPE token strings."""
        if self.add_prefix_space:
            text = " " + text

        bpe_tokens: List[str] = []
        for token in self.pat.findall(text):
            # Map each UTF-8 *byte* (not each code point) to its stand-in
            # character; ``byte_encoder`` only has keys 0..255, so ord() of a
            # non-Latin-1 character raised KeyError.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(self.bpe(token).split(" "))
        return bpe_tokens

    def _convert_token_to_id(self, token: str) -> int:
        """Look up ``token`` in the vocabulary, falling back to the unk token's id."""
        return self.encoder.get(token, self.encoder.get(self.unk_token))

    def _convert_id_to_token(self, index: int) -> str:
        """Look up ``index`` in the reverse vocabulary, falling back to the unk token."""
        return self.decoder.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens: List[str]) -> str:
        """Undo the byte-to-character mapping and decode the bytes as UTF-8.

        Invalid UTF-8 byte sequences are decoded with replacement characters.
        """
        text = "".join(tokens)
        return bytearray(self.byte_decoder[c] for c in text).decode("utf-8", errors="replace")

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str, ...]:
        """Write ``vocab.json`` (and ``merges.txt`` when a merges file was given).

        Args:
            save_directory: Existing directory to write into.
            filename_prefix: Optional string prepended to each file name.

        Returns:
            A tuple of written paths: ``(vocab_path,)`` or
            ``(vocab_path, merges_path)``. Always a tuple, because
            ``save_pretrained`` concatenates this result with other tuples.
        """
        prefix = filename_prefix or ""
        directory = Path(save_directory)

        vocab_path = directory / f"{prefix}vocab.json"
        with vocab_path.open("w", encoding="utf-8") as f:
            json.dump(self.encoder, f, ensure_ascii=False)

        if not self.merges_file:
            return (str(vocab_path),)

        merges_path = directory / f"{prefix}merges.txt"
        with merges_path.open("w", encoding="utf-8") as f:
            # Header line; ``_load_bpe_ranks`` skips it on reload.
            f.write("#version: 0.2\n")
            for merge in sorted(self.bpe_ranks, key=self.bpe_ranks.__getitem__):
                f.write(" ".join(merge) + "\n")
        return str(vocab_path), str(merges_path)

    def prepare_for_vision(self, text: str) -> str:
        """Wrap ``text`` in the vision start/end marker tokens."""
        return f"{self.vision_start_token}{text}{self.vision_end_token}"

    def prepare_for_image(self, text: str) -> str:
        """Prefix ``text`` with the image marker token."""
        return f"{self.image_token}{text}"

    def prepare_for_video(self, text: str) -> str:
        """Prefix ``text`` with the video marker token."""
        return f"{self.video_token}{text}"

    @property
    def vocab_size(self) -> int:
        """Number of entries in the base vocabulary loaded from ``vocab_file``."""
        return len(self.encoder)

    def get_vocab(self) -> Dict[str, int]:
        """Return the token -> id map, including tokens added after loading.

        The base class consults this map when assigning ids to newly added
        tokens, so added tokens must be included to avoid id reuse.
        """
        vocab = dict(self.encoder)
        vocab.update(self.added_tokens_encoder)
        return vocab
|
|
|
|
|
# Registers the tokenizer with the ``AutoTokenizer`` factory at import time.
# NOTE(review): ``AutoTokenizer.register`` is documented as
# ``register(config_class, slow_tokenizer_class=None, fast_tokenizer_class=None, exist_ok=False)``.
# This call passes the tokenizer class as the *config* class and the string
# "sapnous" as the slow tokenizer class, so it likely does not make
# ``AutoTokenizer.from_pretrained`` resolve to ``SapnousTokenizer``. It probably
# should be ``AutoTokenizer.register(<SapnousConfig>, slow_tokenizer_class=SapnousTokenizer)``
# using the model's config class — confirm which config class applies.
AutoTokenizer.register(SapnousTokenizer, "sapnous")