import json
import os
import re
from collections import Counter
from functools import cached_property
from typing import Optional, Union

import torch
from torch import Tensor


class Tokenizer:
    """Byte-level BPE tokenizer for Codsworth.

    Text is encoded to UTF-8, split into single-byte tokens, and adjacent
    tokens are merged following the learned merge rules in priority order.
    Internally every token is a ``bytes`` string; the vocabulary maps token
    bytes to integer ids and each merge rule is a ``(left, right)`` pair of
    token bytes.

    Note:
        The default special-token ids (unk=0, bos=1, eos=2, pad=0) coincide
        with the ids that :meth:`train` assigns to the raw bytes 0x00-0x02.
        Pass explicit ``special_tokens`` if those ids must stay distinct.
    """

    # latin-1 maps bytes 0-255 one-to-one onto U+0000-U+00FF, so arbitrary
    # token bytes survive a round trip through JSON strings.
    _BYTE_CODEC = "latin-1"
    # UTF-8 encoding of U+FFFD, emitted for ids that map to nothing.
    _REPLACEMENT_BYTES = "\ufffd".encode("utf-8")

    def __init__(
        self,
        vocab: Optional[dict[Union[str, bytes], int]] = None,
        merges: Optional[list] = None,
        special_tokens: Optional[dict[str, int]] = None,
        unk_token: str = "<unk>",
        bos_token: str = "<bos>",
        eos_token: str = "<eos>",
        pad_token: str = "<pad>",
        unk_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
    ):
        """Create a tokenizer from a vocabulary and merge rules.

        Args:
            vocab: Mapping of token to id. Keys may be ``bytes`` or ``str``
                (``str`` keys are stored as their UTF-8 bytes).
            merges: Merge rules in priority order. Each rule may be a pair of
                token bytes, a pair of ``str``, a pair of token ids present in
                ``vocab``, or a single ``"left right"`` string.
            special_tokens: Mapping of special token text to id. When ``None``
                it is built from the four ``*_token`` / ``*_token_id`` pairs.
            unk_token, bos_token, eos_token, pad_token: Special token texts.
            unk_token_id, bos_token_id, eos_token_id, pad_token_id: Fallback
                ids used when the token text is missing from
                ``special_tokens``.

        Raises:
            TypeError: If a token is neither ``str`` nor bytes-like.
            ValueError: If a merge rule does not name exactly two tokens or
                references an id that is not in ``vocab``.
        """
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token

        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        self._vocab = {
            self._as_bytes(token): token_id
            for token, token_id in (vocab or {}).items()
        }
        self._merges = self._normalize_merges(merges or [])

        # Rank lookup replaces a linear scan of the merge list per pair.
        # setdefault keeps the earliest (highest-priority) rank of duplicates.
        self._merge_ranks = {}
        for rank, pair in enumerate(self._merges):
            self._merge_ranks.setdefault(pair, rank)

        if special_tokens is not None:
            self._special_tokens = special_tokens
        else:
            self._special_tokens = {
                unk_token: unk_token_id,
                bos_token: bos_token_id,
                eos_token: eos_token_id,
                pad_token: pad_token_id,
            }

    @staticmethod
    def _as_bytes(token) -> bytes:
        """Return ``token`` as ``bytes`` (``str`` is encoded as UTF-8)."""
        if isinstance(token, bytes):
            return token
        if isinstance(token, str):
            return token.encode("utf-8")
        if isinstance(token, (bytearray, memoryview)):
            return bytes(token)
        raise TypeError(
            f"token must be str or bytes-like, got {type(token).__name__}"
        )

    def _merge_part(self, part) -> bytes:
        """Resolve one side of a merge rule (id, str or bytes) to bytes."""
        if isinstance(part, int):
            try:
                return self._reverse_vocab[part]
            except KeyError:
                raise ValueError(
                    f"merge rule references unknown token id {part!r}"
                ) from None
        return self._as_bytes(part)

    def _normalize_merges(self, merges) -> list[tuple[bytes, bytes]]:
        """Convert every accepted merge-rule spelling to a pair of bytes."""
        normalized = []
        for entry in merges:
            parts = entry.split(" ") if isinstance(entry, str) else list(entry)
            if len(parts) != 2:
                raise ValueError(
                    f"merge rule must name exactly two tokens, got {entry!r}"
                )
            normalized.append(
                (self._merge_part(parts[0]), self._merge_part(parts[1]))
            )
        return normalized

    @cached_property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary (special tokens not added)."""
        return len(self._vocab)

    @cached_property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._special_tokens.get(self.eos_token, self.eos_token_id)

    @cached_property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self._special_tokens.get(self.bos_token, self.bos_token_id)

    @cached_property
    def pad_id(self) -> int:
        """Id of the padding token."""
        return self._special_tokens.get(self.pad_token, self.pad_token_id)

    @cached_property
    def unk_id(self) -> int:
        """Id substituted for tokens that are missing from the vocabulary."""
        return self._special_tokens.get(self.unk_token, self.unk_token_id)

    def encode(
        self,
        text: Union[str, list[str]],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[int]:
        """Encode text into one flat list of token ids.

        Args:
            text: A string, or a list of strings whose tokens are
                concatenated into a single sequence.
            add_special_tokens: Master switch for BOS/EOS insertion.
            add_bos: Prepend the BOS id (only if ``add_special_tokens``).
            add_eos: Append the EOS id (only if ``add_special_tokens``).

        Returns:
            The token ids, with at most one BOS and one EOS around the whole
            sequence.
        """
        if isinstance(text, str):
            text = [text]

        token_ids = []
        for seq in text:
            token_ids.extend(self._tokenize(seq))

        if add_special_tokens:
            if add_bos:
                token_ids = [self.bos_id] + token_ids
            if add_eos:
                token_ids = token_ids + [self.eos_id]

        return token_ids

    def decode(
        self,
        token_ids: Union[list[int], torch.Tensor, Tensor],
        remove_special_tokens: bool = True,
    ) -> str:
        """Decode token ids back into text.

        Args:
            token_ids: A 1-D sequence or tensor of ids; a single id (or 0-d
                tensor) is treated as a one-element sequence.
            remove_special_tokens: Drop every id listed in the special-token
                table before decoding.

        Returns:
            The decoded string. Invalid UTF-8 and unknown ids become U+FFFD.
        """
        if isinstance(token_ids, Tensor):
            token_ids = token_ids.tolist()
        if isinstance(token_ids, int):
            # A 0-d tensor's tolist() yields a bare number, not a list.
            token_ids = [token_ids]

        token_ids = list(token_ids)

        if remove_special_tokens:
            special_ids = set(self._special_tokens.values())
            token_ids = [t for t in token_ids if t not in special_ids]

        return self._decode_tokens(token_ids)

    def _tokenize(self, text: str) -> list[int]:
        """Tokenize a single string into ids (no special tokens)."""
        return self._bpe_tokenize(text)

    def _bpe_tokenize(self, text: str) -> list[int]:
        """Apply the merge rules to the UTF-8 bytes of ``text``.

        Starting from single-byte tokens, the adjacent pair with the best
        (lowest) merge rank is merged everywhere, repeatedly, until no
        adjacent pair has a merge rule. Tokens missing from the vocabulary
        map to ``unk_id``.
        """
        tokens = [bytes([byte]) for byte in text.encode("utf-8")]
        ranks = self._merge_ranks
        unranked = len(ranks)  # sorts after every real rank

        while len(tokens) > 1:
            pairs = self._get_pairs(tokens)
            bigram = min(pairs, key=lambda pair: ranks.get(pair, unranked))
            if bigram not in ranks:
                break
            tokens = self._merge(tokens, bigram)

        vocab = self._vocab
        unk = self.unk_id
        return [vocab.get(token, unk) for token in tokens]

    def _get_pairs(self, tokens: list) -> set:
        """Return the set of adjacent ``(left, right)`` token pairs."""
        return set(zip(tokens, tokens[1:]))

    # Backward-compatible alias for the original camelCase helper name.
    _getPairs = _get_pairs

    def _merge(self, tokens: list, bigram: tuple) -> list:
        """Replace each non-overlapping occurrence of ``bigram``.

        Scans left to right; a matched pair is replaced by the concatenation
        of its two tokens, so the bytes of the sequence are unchanged.
        """
        left, right = bigram
        merged = left + right
        new_tokens = []
        i = 0
        last = len(tokens) - 1
        while i < len(tokens):
            if i < last and tokens[i] == left and tokens[i + 1] == right:
                new_tokens.append(merged)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        return new_tokens

    def _decode_tokens(self, token_ids: list[int]) -> str:
        """Join the bytes of each id and decode them as UTF-8.

        Ids are resolved against the vocabulary first, then against the
        special-token table (yielding the token's text); anything else
        becomes a single U+FFFD replacement character.
        """
        reverse_vocab = self._reverse_vocab
        special_text = self._special_text_by_id
        pieces = []
        for token_id in token_ids:
            piece = reverse_vocab.get(token_id)
            if piece is None:
                piece = special_text.get(token_id, self._REPLACEMENT_BYTES)
            pieces.append(piece)
        return b"".join(pieces).decode("utf-8", errors="replace")

    @cached_property
    def _reverse_vocab(self) -> dict[int, bytes]:
        """Mapping of token id to token bytes."""
        return {v: k for k, v in self._vocab.items()}

    @cached_property
    def _special_text_by_id(self) -> dict[int, bytes]:
        """Mapping of special-token id to its UTF-8 text (first name wins)."""
        by_id = {}
        for token, token_id in self._special_tokens.items():
            by_id.setdefault(token_id, self._as_bytes(token))
        return by_id

    def __call__(
        self,
        text: Union[str, list[str]],
        return_tensors: Optional[str] = None,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> dict[str, Union[list[list[int]], Tensor]]:
        """Encode a batch of texts, optionally padding and truncating.

        Args:
            text: One string (treated as a batch of one) or a list of strings.
            return_tensors: ``"pt"`` to return a ``torch`` tensor; any other
                value returns nested lists.
            padding: Right-pad with ``pad_id`` up to ``max_length`` if given,
                otherwise up to the longest sequence in the batch.
            truncation: Cut sequences to ``max_length`` (needs ``max_length``).
            max_length: Target length for padding and truncation.
            add_special_tokens, add_bos, add_eos: Forwarded to :meth:`encode`.

        Returns:
            ``{"input_ids": ...}`` with one row per input text.

        Raises:
            ValueError: If ``return_tensors="pt"`` and the rows have
                different lengths.
        """
        if isinstance(text, str):
            text = [text]

        encoded = self.batch_encode(
            text,
            add_special_tokens=add_special_tokens,
            add_bos=add_bos,
            add_eos=add_eos,
        )

        if padding:
            if max_length is not None:
                target = max_length
            else:
                # default=0 keeps an empty batch from raising in max().
                target = max((len(seq) for seq in encoded), default=0)
            pad = self.pad_id
            encoded = [seq + [pad] * (target - len(seq)) for seq in encoded]

        if truncation and max_length is not None:
            encoded = [seq[:max_length] for seq in encoded]

        result = {"input_ids": encoded}

        if return_tensors == "pt":
            if len({len(seq) for seq in encoded}) > 1:
                raise ValueError(
                    "cannot build a tensor from sequences of different "
                    "lengths; use padding=True (and truncation=True with "
                    "max_length) to make them uniform"
                )
            result["input_ids"] = torch.tensor(encoded, dtype=torch.long)

        return result

    def save(self, path: str) -> None:
        """Write the tokenizer to ``path`` as JSON.

        Token bytes are stored as latin-1 strings so that arbitrary byte
        values survive the JSON round trip; the codec name is recorded under
        ``"byte_encoding"`` for :meth:`load`. Missing parent directories are
        created.
        """
        directory = os.path.dirname(path)
        if directory:
            # dirname is "" for a bare filename, which makedirs rejects.
            os.makedirs(directory, exist_ok=True)

        codec = self._BYTE_CODEC
        data = {
            "byte_encoding": codec,
            "vocab": {k.decode(codec): v for k, v in self._vocab.items()},
            "merges": [
                [left.decode(codec), right.decode(codec)]
                for left, right in self._merges
            ],
            "special_tokens": self._special_tokens,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
            "unk_token_id": self.unk_token_id,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id,
            "pad_token_id": self.pad_token_id,
        }

        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> "Tokenizer":
        """Load a tokenizer from a JSON file written by :meth:`save`.

        Files without a ``"byte_encoding"`` entry are read with their token
        strings taken as UTF-8 text, and missing ``*_token_id`` entries fall
        back to the constructor defaults.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)

        codec = data.get("byte_encoding", "utf-8")

        def to_bytes(value):
            # Non-string parts (e.g. integer ids) are resolved by __init__.
            return value.encode(codec) if isinstance(value, str) else value

        vocab = {
            to_bytes(token): token_id
            for token, token_id in (data.get("vocab") or {}).items()
        }
        merges = [
            # A plain string is a "left right" rule; leave it for __init__.
            entry if isinstance(entry, str) else tuple(to_bytes(p) for p in entry)
            for entry in (data.get("merges") or [])
        ]
        id_kwargs = {
            key: data[key]
            for key in (
                "unk_token_id",
                "bos_token_id",
                "eos_token_id",
                "pad_token_id",
            )
            if key in data
        }

        return cls(
            vocab=vocab,
            merges=merges,
            special_tokens=data.get("special_tokens"),
            unk_token=data.get("unk_token", "<unk>"),
            bos_token=data.get("bos_token", "<bos>"),
            eos_token=data.get("eos_token", "<eos>"),
            pad_token=data.get("pad_token", "<pad>"),
            **id_kwargs,
        )

    @classmethod
    def train(
        cls,
        texts: list[str],
        vocab_size: int = 50000,
        min_frequency: int = 2,
    ) -> "Tokenizer":
        """Learn byte-level BPE merges from ``texts``.

        The vocabulary starts with the 256 single bytes (id == byte value).
        Each round counts adjacent id pairs across all texts, merges the most
        frequent pair (ties broken by the smaller minimum id, then by first
        occurrence) and assigns the merged token the next free id.

        Args:
            texts: Training corpus.
            vocab_size: Stop once the vocabulary reaches this size.
            min_frequency: Stop when no pair occurs at least this often.

        Returns:
            A tokenizer with the learned vocabulary and merges and the
            default special tokens.
        """
        sequences = [list(text.encode("utf-8")) for text in texts]

        id_to_bytes = {byte: bytes([byte]) for byte in range(256)}
        vocab = {token: token_id for token_id, token in id_to_bytes.items()}
        merges = []

        # Pairs are recounted from scratch each round: simple, but the cost
        # grows with corpus size times the number of merges.
        while len(vocab) < vocab_size:
            pair_counts = Counter()
            for seq in sequences:
                pair_counts.update(zip(seq, seq[1:]))

            candidates = [
                pair
                for pair, count in pair_counts.items()
                if count >= min_frequency
            ]
            if not candidates:
                break

            best_pair = max(
                candidates,
                key=lambda pair: (pair_counts[pair], -min(pair)),
            )
            left, right = best_pair
            merges.append((id_to_bytes[left], id_to_bytes[right]))

            new_token = id_to_bytes[left] + id_to_bytes[right]
            # Two different pairs can concatenate to the same bytes; reuse
            # the existing id instead of creating a duplicate entry.
            new_id = vocab.get(new_token)
            if new_id is None:
                new_id = len(vocab)
                vocab[new_token] = new_id
                id_to_bytes[new_id] = new_token

            merged_sequences = []
            for seq in sequences:
                merged = []
                i = 0
                last = len(seq) - 1
                while i < len(seq):
                    if i < last and seq[i] == left and seq[i + 1] == right:
                        merged.append(new_id)
                        i += 2
                    else:
                        merged.append(seq[i])
                        i += 1
                merged_sequences.append(merged)
            sequences = merged_sequences

        return cls(
            vocab=vocab,
            merges=merges,
        )

    def batch_encode(
        self,
        texts: list[str],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[list[int]]:
        """Encode each text separately; see :meth:`encode` for the flags."""
        return [
            self.encode(
                t,
                add_special_tokens=add_special_tokens,
                add_bos=add_bos,
                add_eos=add_eos,
            )
            for t in texts
        ]

    def batch_decode(
        self,
        token_ids: list[list[int]],
        remove_special_tokens: bool = True,
    ) -> list[str]:
        """Decode each id sequence separately; see :meth:`decode`."""
        return [
            self.decode(ids, remove_special_tokens=remove_special_tokens)
            for ids in token_ids
        ]