import os
import json
import re
from collections import Counter
from functools import cached_property
from typing import Optional, Union

import torch
from torch import Tensor


class Tokenizer:
    """Byte-level BPE tokenizer for Codsworth.

    Encoding turns text into UTF-8 bytes, maps each byte to its single-byte
    vocabulary entry, then repeatedly applies the lowest-ranked applicable
    merge until none applies.

    Internally the vocabulary maps token *bytes* to ids and merges are stored
    as ``(left_id, right_id)`` pairs; the id produced by a merge is the
    vocabulary id of the concatenated bytes of its two parts. In JSON files
    written by :meth:`save`, token bytes are stored as Latin-1 strings (one
    character per byte), which round-trips losslessly. ``str`` vocabulary
    keys passed to the constructor are converted the same way, falling back
    to UTF-8 for characters outside Latin-1.

    NOTE: with a vocabulary produced by :meth:`train`, ids 0-255 are raw
    bytes, so the default special ids (0, 1, 2) coincide with bytes
    0x00-0x02.
    """

    # UTF-8 encoding of U+FFFD, emitted for ids that are not in the vocabulary.
    _REPLACEMENT_BYTES = "\ufffd".encode("utf-8")

    def __init__(
        self,
        vocab: Optional[dict[Union[str, bytes], int]] = None,
        merges: Optional[list[Union[str, tuple, list]]] = None,
        special_tokens: Optional[dict[str, int]] = None,
        unk_token: str = "",
        bos_token: str = "",
        eos_token: str = "",
        pad_token: str = "",
        unk_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        pad_token_id: int = 0,
    ):
        """Build a tokenizer.

        Args:
            vocab: Token -> id mapping. Keys may be ``bytes`` or ``str``
                (see the class docstring for how ``str`` keys are read).
            merges: Merge rules in rank order (earliest = highest priority).
                Each rule is a pair of token ids, a pair of tokens, or a
                ``"left right"`` string of two space-separated tokens. Rules
                that reference tokens missing from ``vocab`` are dropped.
            special_tokens: Explicit special-token name -> id mapping. When
                omitted it is derived from the ``*_token`` / ``*_token_id``
                arguments, skipping empty names.
            unk_token, bos_token, eos_token, pad_token: Special-token names.
            unk_token_id, bos_token_id, eos_token_id, pad_token_id: Fallback
                ids used when a name is not found in ``special_tokens``.

        Raises:
            ValueError: If a merge rule does not consist of exactly two parts.
        """
        self.unk_token = unk_token
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.pad_token = pad_token
        self.unk_token_id = unk_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.pad_token_id = pad_token_id

        self._vocab: dict[bytes, int] = {
            self._to_bytes(token): int(token_id)
            for token, token_id in (vocab or {}).items()
        }
        self._merges: list[tuple[int, int]] = []
        for merge in merges or []:
            pair = self._normalize_merge(merge)
            if pair is not None:
                self._merges.append(pair)

        if special_tokens is not None:
            self._special_tokens = dict(special_tokens)
        else:
            # An empty name cannot identify a token. Including empty names
            # would collapse every unnamed special onto the single key ""
            # (last one wins), making bos_id/eos_id return the pad id.
            defaults = (
                (unk_token, unk_token_id),
                (bos_token, bos_token_id),
                (eos_token, eos_token_id),
                (pad_token, pad_token_id),
            )
            self._special_tokens = {name: tid for name, tid in defaults if name}

    # ------------------------------------------------------------------
    # Normalization helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _to_bytes(token: Union[str, bytes, bytearray]) -> bytes:
        """Convert a vocabulary token to bytes (Latin-1 first, then UTF-8)."""
        if isinstance(token, (bytes, bytearray)):
            return bytes(token)
        try:
            return token.encode("latin-1")
        except UnicodeEncodeError:
            return token.encode("utf-8")

    @staticmethod
    def _to_str(token: bytes) -> str:
        """Convert token bytes to a JSON-safe string (lossless via Latin-1)."""
        return token.decode("latin-1")

    def _normalize_merge(self, merge) -> Optional[tuple[int, int]]:
        """Convert one merge rule to a ``(left_id, right_id)`` pair.

        Returns None when a token named in the rule is not in the vocabulary.
        """
        parts = merge.split(" ") if isinstance(merge, str) else list(merge)
        if len(parts) != 2:
            raise ValueError(f"merge rule must have exactly two parts, got {merge!r}")
        pair = []
        for part in parts:
            if isinstance(part, int):
                pair.append(part)
                continue
            token_id = self._vocab.get(self._to_bytes(part))
            if token_id is None:
                return None
            pair.append(token_id)
        return pair[0], pair[1]

    # ------------------------------------------------------------------
    # Ids
    # ------------------------------------------------------------------

    @cached_property
    def vocab_size(self) -> int:
        """Number of entries in the vocabulary."""
        return len(self._vocab)

    @cached_property
    def eos_id(self) -> int:
        """Id of the end-of-sequence token."""
        return self._special_tokens.get(self.eos_token, self.eos_token_id)

    @cached_property
    def bos_id(self) -> int:
        """Id of the beginning-of-sequence token."""
        return self._special_tokens.get(self.bos_token, self.bos_token_id)

    @cached_property
    def pad_id(self) -> int:
        """Id used for padding in :meth:`__call__`."""
        return self._special_tokens.get(self.pad_token, self.pad_token_id)

    @cached_property
    def unk_id(self) -> int:
        """Id used for bytes that have no vocabulary entry."""
        return self._special_tokens.get(self.unk_token, self.unk_token_id)

    @cached_property
    def _special_ids(self) -> frozenset:
        """Ids removed by ``decode(..., remove_special_tokens=True)``.

        Includes bos/eos/pad even when they are only given as fallback ids,
        so that ids inserted by :meth:`encode` / :meth:`__call__` are removed.
        """
        return frozenset(self._special_tokens.values()) | {
            self.bos_id,
            self.eos_id,
            self.pad_id,
        }

    @cached_property
    def _reverse_vocab(self) -> dict[int, bytes]:
        """Id -> token bytes."""
        return {v: k for k, v in self._vocab.items()}

    @cached_property
    def _byte_to_id(self) -> list[int]:
        """Id of each single byte 0-255 (``unk_id`` when missing)."""
        return [self._vocab.get(bytes([b]), self.unk_id) for b in range(256)]

    @cached_property
    def _merge_ranks(self) -> dict[tuple[int, int], tuple[int, int]]:
        """Map ``(left_id, right_id)`` -> ``(rank, merged_id)``.

        Built once so that ranking a pair is O(1) instead of a list scan.
        Merges whose parts or result are not in the vocabulary are skipped.
        """
        reverse = self._reverse_vocab
        ranks: dict[tuple[int, int], tuple[int, int]] = {}
        for rank, (left, right) in enumerate(self._merges):
            if (left, right) in ranks or left not in reverse or right not in reverse:
                continue
            merged_id = self._vocab.get(reverse[left] + reverse[right])
            if merged_id is not None:
                ranks[(left, right)] = (rank, merged_id)
        return ranks

    # ------------------------------------------------------------------
    # Encoding / decoding
    # ------------------------------------------------------------------

    def encode(
        self,
        text: Union[str, list[str]],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[int]:
        """Encode text into token ids.

        Args:
            text: A string, or a list of strings whose token ids are
                concatenated (no separator is inserted between them).
            add_special_tokens: Master switch for ``add_bos`` / ``add_eos``.
            add_bos: Prepend ``bos_id`` when special tokens are enabled.
            add_eos: Append ``eos_id`` when special tokens are enabled.

        Returns:
            The list of token ids.
        """
        segments = [text] if isinstance(text, str) else text
        token_ids: list[int] = []
        for segment in segments:
            token_ids.extend(self._tokenize(segment))
        if add_special_tokens:
            if add_bos:
                token_ids.insert(0, self.bos_id)
            if add_eos:
                token_ids.append(self.eos_id)
        return token_ids

    def decode(
        self,
        token_ids: Union[list[int], Tensor],
        remove_special_tokens: bool = True,
    ) -> str:
        """Decode token ids back to text.

        Args:
            token_ids: A sequence of ids or a 1-D tensor.
            remove_special_tokens: Drop special ids (see ``_special_ids``)
                before decoding.

        Returns:
            The decoded string. Invalid UTF-8 and unknown ids become U+FFFD.
        """
        if isinstance(token_ids, Tensor):
            token_ids = token_ids.tolist()
        token_ids = list(token_ids)
        if remove_special_tokens:
            special_ids = self._special_ids
            token_ids = [t for t in token_ids if t not in special_ids]
        return self._decode_tokens(token_ids)

    def _tokenize(self, text: str) -> list[int]:
        return self._bpe_tokenize(text)

    def _bpe_tokenize(self, text: str) -> list[int]:
        """Apply byte-level BPE to ``text`` and return token ids."""
        byte_to_id = self._byte_to_id
        ids = [byte_to_id[b] for b in text.encode("utf-8")]
        ranks = self._merge_ranks
        while len(ids) > 1:
            candidates = [pair for pair in self._getPairs(ids) if pair in ranks]
            if not candidates:
                break
            # Apply the earliest-learned merge first, as during training.
            bigram = min(candidates, key=lambda pair: ranks[pair][0])
            ids = self._merge_pair(ids, bigram, ranks[bigram][1])
        return ids

    def _getPairs(self, tokens: list) -> set:
        """Set of adjacent ``(tokens[i], tokens[i + 1])`` pairs."""
        return set(zip(tokens, tokens[1:]))

    def _merge(self, tokens: list, bigram: tuple, new_id: Optional[int] = None) -> list:
        """Replace every non-overlapping ``bigram`` occurrence with its merged id."""
        if new_id is None:
            new_id = self._merge_ranks[bigram][1]
        return self._merge_pair(tokens, bigram, new_id)

    @staticmethod
    def _merge_pair(tokens: list, bigram: tuple, new_id: int) -> list:
        """Left-to-right replacement of ``bigram`` by ``new_id`` in ``tokens``."""
        left, right = bigram
        merged = []
        i = 0
        n = len(tokens)
        while i < n:
            if i + 1 < n and tokens[i] == left and tokens[i + 1] == right:
                merged.append(new_id)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        return merged

    def _decode_tokens(self, token_ids: list[int]) -> str:
        """Concatenate token bytes and decode them as UTF-8."""
        reverse = self._reverse_vocab
        data = b"".join(reverse.get(t, self._REPLACEMENT_BYTES) for t in token_ids)
        return data.decode("utf-8", errors="replace")

    def __call__(
        self,
        text: Union[str, list[str]],
        return_tensors: Optional[str] = None,
        padding: bool = False,
        truncation: bool = False,
        max_length: Optional[int] = None,
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> dict[str, Union[list[int], Tensor]]:
        """Encode a batch of texts.

        Args:
            text: One string or a list of strings (one sequence each).
            return_tensors: ``"pt"`` to return a ``torch.long`` tensor.
            padding: Right-pad with ``pad_id`` to ``max_length`` if given,
                otherwise to the longest sequence in the batch.
            truncation: Cut sequences to ``max_length`` (requires it).
            max_length: Target length for padding and/or truncation.
            add_special_tokens, add_bos, add_eos: Forwarded to :meth:`encode`.

        Returns:
            ``{"input_ids": ...}`` as a list of lists or a tensor.
        """
        texts = [text] if isinstance(text, str) else text
        encoded = [
            self.encode(t, add_special_tokens=add_special_tokens, add_bos=add_bos, add_eos=add_eos)
            for t in texts
        ]
        if truncation and max_length is not None:
            encoded = [e[:max_length] for e in encoded]
        if padding:
            target = max_length if max_length is not None else max((len(e) for e in encoded), default=0)
            encoded = [e + [self.pad_id] * (target - len(e)) for e in encoded]

        result: dict[str, Union[list, Tensor]] = {"input_ids": encoded}
        if return_tensors == "pt":
            result["input_ids"] = torch.tensor(encoded, dtype=torch.long)
        return result

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save(self, path: str) -> None:
        """Write the tokenizer to ``path`` as JSON, creating parent dirs."""
        directory = os.path.dirname(path)
        if directory:  # os.makedirs("") raises for bare filenames
            os.makedirs(directory, exist_ok=True)
        data = {
            "vocab": {self._to_str(k): v for k, v in self._vocab.items()},
            "merges": [list(m) for m in self._merges],
            "special_tokens": self._special_tokens,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
            "unk_token_id": self.unk_token_id,
            "bos_token_id": self.bos_token_id,
            "eos_token_id": self.eos_token_id,
            "pad_token_id": self.pad_token_id,
        }
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, path: str) -> "Tokenizer":
        """Read a tokenizer written by :meth:`save`.

        Files without ``*_token_id`` entries fall back to the constructor
        defaults.
        """
        with open(path, "r", encoding="utf-8") as f:
            data = json.load(f)
        return cls(
            vocab=data.get("vocab"),
            merges=data.get("merges"),
            special_tokens=data.get("special_tokens"),
            unk_token=data.get("unk_token", ""),
            bos_token=data.get("bos_token", ""),
            eos_token=data.get("eos_token", ""),
            pad_token=data.get("pad_token", ""),
            unk_token_id=data.get("unk_token_id", 0),
            bos_token_id=data.get("bos_token_id", 1),
            eos_token_id=data.get("eos_token_id", 2),
            pad_token_id=data.get("pad_token_id", 0),
        )

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------

    @classmethod
    def train(
        cls,
        texts: list[str],
        vocab_size: int = 50000,
        min_frequency: int = 2,
    ) -> "Tokenizer":
        """Learn byte-level BPE merges from ``texts``.

        Starts from the 256 single-byte tokens (id == byte value) and
        repeatedly merges the most frequent adjacent pair, ties broken by the
        smaller minimum id, until ``vocab_size`` entries exist or no pair
        occurs at least ``min_frequency`` times.
        """
        vocab: dict[bytes, int] = {bytes([b]): b for b in range(256)}
        id_to_bytes: dict[int, bytes] = {b: token for token, b in vocab.items()}
        merges: list[tuple[int, int]] = []
        ids = [list(t.encode("utf-8")) for t in texts]

        while len(vocab) < vocab_size:
            pair_counts: Counter = Counter()
            for seq in ids:
                pair_counts.update(zip(seq, seq[1:]))
            eligible = [p for p, c in pair_counts.items() if c >= min_frequency]
            if not eligible:
                break
            best_pair = max(eligible, key=lambda p: (pair_counts[p], -min(p)))

            # A merged token is the plain concatenation of its parts' bytes,
            # so decoding reproduces the original text exactly.
            merged = id_to_bytes[best_pair[0]] + id_to_bytes[best_pair[1]]
            new_id = vocab.get(merged)
            if new_id is None:
                # Different merge paths can yield the same bytes; reuse the id.
                new_id = len(vocab)
                vocab[merged] = new_id
                id_to_bytes[new_id] = merged
            merges.append(best_pair)
            ids = [cls._merge_pair(seq, best_pair, new_id) for seq in ids]

        return cls(
            vocab=vocab,
            merges=merges,
        )

    def batch_encode(
        self,
        texts: list[str],
        add_special_tokens: bool = True,
        add_bos: bool = True,
        add_eos: bool = False,
    ) -> list[list[int]]:
        """Encode each text separately."""
        return [
            self.encode(t, add_special_tokens=add_special_tokens, add_bos=add_bos, add_eos=add_eos)
            for t in texts
        ]

    def batch_decode(
        self,
        token_ids: list[list[int]],
        remove_special_tokens: bool = True,
    ) -> list[str]:
        """Decode each id sequence separately."""
        return [
            self.decode(ids, remove_special_tokens=remove_special_tokens)
            for ids in token_ids
        ]