| """ | |
| Minimal (byte-level) Byte Pair Encoding tokenizer. | |
| Algorithmically follows along the GPT tokenizer: | |
| https://github.com/openai/gpt-2/blob/master/src/encoder.py | |
| But: | |
| - Does not handle the regular expression splitting pattern. | |
| - Does not handle any special tokens. | |
| """ | |
| import copy | |
| from .base import Tokenizer, get_stats, merge | |
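
# For reference, the helpers imported above are assumed to behave as in the
# minimal sketches below (following minbpe conventions; the actual
# implementations live in .base):
#
#   def get_stats(ids):
#       # count how often each consecutive (int, int) pair occurs in ids
#       counts = {}
#       for pair in zip(ids, ids[1:]):
#           counts[pair] = counts.get(pair, 0) + 1
#       return counts
#
#   def merge(ids, pair, idx):
#       # replace every occurrence of `pair` in ids with the single token idx
#       newids = []
#       i = 0
#       while i < len(ids):
#           if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
#               newids.append(idx)
#               i += 2
#           else:
#               newids.append(ids[i])
#               i += 1
#       return newids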
# class BasicTokenizer(Tokenizer):
#
#     def __init__(self):
#         super().__init__()
#
#     def train(self, text, vocab_size, verbose=False):
#         assert vocab_size >= 256
#         num_merges = vocab_size - 256
#
#         # input text preprocessing
#         text_bytes = text.encode("utf-8") # raw bytes
#         ids = list(text_bytes) # list of integers in range 0..255
#
#         # iteratively merge the most common pairs to create new tokens
#         merges = {} # (int, int) -> int
#         vocab = {idx: bytes([idx]) for idx in range(256)} # int -> bytes
#         for i in range(num_merges):
#             # count up the number of times every consecutive pair appears
#             stats = get_stats(ids)
#             # find the pair with the highest count
#             pair = max(stats, key=stats.get)
#             # mint a new token: assign it the next available id
#             idx = 256 + i
#             # replace all occurrences of pair in ids with idx
#             ids = merge(ids, pair, idx)
#             # save the merge
#             merges[pair] = idx
#             vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
#             # prints
#             if verbose:
#                 print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
#
#         # save class variables
#         self.merges = merges # used in encode()
#         self.vocab = vocab # used in decode()
#
#     def decode(self, ids):
#         # given ids (list of integers), return Python string
#         text_bytes = b"".join(self.vocab[idx] for idx in ids)
#         text = text_bytes.decode("utf-8", errors="replace")
#         return text
#
#     def encode(self, text):
#         # given a string text, return the token ids
#         text_bytes = text.encode("utf-8") # raw bytes
#         ids = list(text_bytes) # list of integers in range 0..255
#         while len(ids) >= 2:
#             # find the pair with the lowest merge index
#             stats = get_stats(ids)
#             pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
#             # subtle: if there are no more merges available, the key will
#             # result in an inf for every single pair, and the min will be
#             # just the first pair in the list, arbitrarily
#             # we can detect this terminating case by a membership check
#             if pair not in self.merges:
#                 break # nothing else can be merged anymore
#             # otherwise let's merge the best pair (lowest merge index)
#             idx = self.merges[pair]
#             ids = merge(ids, pair, idx)
#         return ids
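
# The version below keeps the same train/encode/decode structure as the
# original above, but lets one tokenizer be trained incrementally over several
# batches of text: merges and vocab persist on the instance, and newly minted
# token ids continue from the current vocab size.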

class BasicTokenizer(Tokenizer):

    def __init__(self):
        super().__init__()
        # total number of new tokens minted across all train() calls
        self.merge_counter = 0

    def train(self, text, vocab_size, verbose=False):
        # hard check: the target vocab must at least cover the 256 raw byte values
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        current_batch_merge_counter = 0  # fewer than `num_merges` may happen

        # input text preprocessing
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255

        # reuse the merges dict if this tokenizer was trained before; (int, int) -> int
        self.merges = {} if self.merges is None else self.merges
        # likewise reuse the vocab if it exists; int -> bytes
        self.vocab = {idx: bytes([idx]) for idx in range(256)} if self.vocab is None else self.vocab

        # iteratively merge the most common pair in the text
        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            exhausted = False  # set when this batch has no unseen pairs left
            while pair in self.merges:
                # this pair was already merged in an earlier batch: just apply
                # the stored token to ids; merges and vocab need no new entry
                ids = merge(ids, pair, self.merges[pair])
                stats = get_stats(ids)
                if stats:
                    pair = max(stats, key=stats.get)
                else:
                    # every pair left in this batch has already been merged
                    exhausted = True
                    break
            if exhausted:
                print("stopping merges: no new byte pair found in the current batch")
                break
            # this pair has not been merged in any batch yet: mint a new token,
            # continuing from however many tokens this tokenizer already has
            idx = len(self.vocab)
            current_batch_merge_counter += 1
            # replace all occurrences of `pair` in `ids` with the new `idx`,
            # then record the merge and the new vocab entry
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]
            if verbose:
                print(f"merge {i + 1}/{num_merges}: {pair} -> {idx} ({self.vocab[idx]}) had {stats[pair]} occurrences")

        self.merge_counter += current_batch_merge_counter

    def decode(self, ids):
        # given ids (list of integers), return Python string
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        # given a string text, return the token ids
        text_bytes = text.encode("utf-8")  # raw bytes
        ids = list(text_bytes)  # list of integers in range 0..255
        while len(ids) >= 2:
            # find the pair with the lowest merge index
            stats = get_stats(ids)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            # subtle: if no pair is in the merges dict, the key returns `inf`
            # for every pair and min() just picks the first pair in the list,
            # arbitrarily; we detect this terminating case by a membership check
            if pair not in self.merges:
                break  # nothing else can be merged anymore
            # otherwise merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids, pair, idx)
        return ids
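

if __name__ == "__main__":
    # Minimal usage sketch, assuming this file sits in a package next to
    # base.py (run as a module, e.g. `python -m <package>.basic`, because of
    # the relative import above); the training strings are illustrative only.
    tok = BasicTokenizer()
    tok.train("aaabdaaabac", vocab_size=258, verbose=True)  # 2 merges: ids 256, 257
    ids = tok.encode("aaabdaaabac")
    assert tok.decode(ids) == "aaabdaaabac"  # byte-level BPE round-trips losslessly
    # a later batch reuses the merges learned so far; note vocab_size only sets
    # the merges per batch (vocab_size - 256), so the vocab keeps growing
    tok.train("a fresh batch of text with aaab in it", vocab_size=258)
    print("total new tokens minted so far:", tok.merge_counter)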