| import json |
| import re |
| from collections import Counter |
| from pathlib import Path |
| from typing import Iterable, List, Sequence |
|
|
|
|
| class SimpleTokenizer: |
| """A small word-and-punctuation tokenizer for CPU-only experiments.""" |
|
|
| PAD = "<pad>" |
| BOS = "<bos>" |
| EOS = "<eos>" |
| UNK = "<unk>" |
| TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE) |
|
|
| def __init__(self, vocab: List[str]): |
| self.id_to_token = vocab |
| self.token_to_id = {token: index for index, token in enumerate(vocab)} |
|
|
| @classmethod |
| def build(cls, texts: Iterable[str], min_freq: int = 1) -> "SimpleTokenizer": |
| counter: Counter[str] = Counter() |
| for text in texts: |
| counter.update(cls.tokenize(text)) |
|
|
| vocab = [cls.PAD, cls.BOS, cls.EOS, cls.UNK] |
| for token, freq in counter.most_common(): |
| if freq >= min_freq and token not in vocab: |
| vocab.append(token) |
| return cls(vocab) |
|
|
| @staticmethod |
| def tokenize(text: str) -> List[str]: |
| return SimpleTokenizer.TOKEN_PATTERN.findall(text) |
|
|
| @property |
| def vocab_size(self) -> int: |
| return len(self.id_to_token) |
|
|
| @property |
| def pad_id(self) -> int: |
| return self.token_to_id[self.PAD] |
|
|
| @property |
| def bos_id(self) -> int: |
| return self.token_to_id[self.BOS] |
|
|
| @property |
| def eos_id(self) -> int: |
| return self.token_to_id[self.EOS] |
|
|
| @property |
| def unk_id(self) -> int: |
| return self.token_to_id[self.UNK] |
|
|
| def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: |
| tokens = self.tokenize(text) |
| ids = [self.token_to_id.get(token, self.unk_id) for token in tokens] |
| if add_bos: |
| ids.insert(0, self.bos_id) |
| if add_eos: |
| ids.append(self.eos_id) |
| return ids |
|
|
| def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str: |
| tokens: List[str] = [] |
| specials = {self.PAD, self.BOS, self.EOS, self.UNK} |
| for token_id in token_ids: |
| token = self.id_to_token[int(token_id)] |
| if skip_special_tokens and token in specials: |
| continue |
| tokens.append(token) |
|
|
| output = [] |
| for token in tokens: |
| if output and re.match(r"\w", token) and re.match(r"\w", output[-1][-1]): |
| output.append(" ") |
| elif output and token not in {".", ",", "!", "?", ":", ";", "'", '"', ")"} and output[-1] not in {"(", '"'}: |
| output.append(" ") |
| output.append(token) |
| return "".join(output).strip() |
|
|
| def save(self, path: str | Path) -> None: |
| payload = {"vocab": self.id_to_token} |
| Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8") |
|
|
| @classmethod |
| def load(cls, path: str | Path) -> "SimpleTokenizer": |
| payload = json.loads(Path(path).read_text(encoding="utf-8")) |
| return cls(payload["vocab"]) |
|
|
|
|
| class BPETokenizer: |
| """A compact BPE tokenizer with greedy longest-match encoding.""" |
|
|
| PAD = "<pad>" |
| BOS = "<bos>" |
| EOS = "<eos>" |
| UNK = "<unk>" |
| END_OF_WORD = "</w>" |
| TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE) |
|
|
| def __init__(self, vocab: Sequence[str], merges: Sequence[list[str] | tuple[str, str]]): |
| self.id_to_token = list(vocab) |
| self.token_to_id = {token: index for index, token in enumerate(self.id_to_token)} |
| self.merges = [tuple(pair) for pair in merges] |
| self.merge_ranks = {pair: index for index, pair in enumerate(self.merges)} |
|
|
| @classmethod |
| def build( |
| cls, |
| texts: Iterable[str], |
| vocab_size: int = 256, |
| min_frequency: int = 2, |
| ) -> "BPETokenizer": |
| words = Counter() |
| for text in texts: |
| words.update(cls.TOKEN_PATTERN.findall(text)) |
|
|
| word_pieces = { |
| word: tuple(list(word) + [cls.END_OF_WORD]) |
| for word, frequency in words.items() |
| if frequency >= 1 |
| } |
| merges: list[tuple[str, str]] = [] |
| special_tokens = [cls.PAD, cls.BOS, cls.EOS, cls.UNK] |
| symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces} |
|
|
| while len(symbol_vocab) + len(special_tokens) < vocab_size: |
| pair_counts: Counter[tuple[str, str]] = Counter() |
| for word, pieces in word_pieces.items(): |
| frequency = words[word] |
| for index in range(len(pieces) - 1): |
| pair_counts[(pieces[index], pieces[index + 1])] += frequency |
|
|
| if not pair_counts: |
| break |
|
|
| best_pair, best_frequency = pair_counts.most_common(1)[0] |
| if best_frequency < min_frequency: |
| break |
|
|
| merged_symbol = "".join(best_pair) |
| merges.append(best_pair) |
| updated: dict[str, tuple[str, ...]] = {} |
| for word, pieces in word_pieces.items(): |
| new_pieces: list[str] = [] |
| index = 0 |
| while index < len(pieces): |
| if index < len(pieces) - 1 and (pieces[index], pieces[index + 1]) == best_pair: |
| new_pieces.append(merged_symbol) |
| index += 2 |
| else: |
| new_pieces.append(pieces[index]) |
| index += 1 |
| updated[word] = tuple(new_pieces) |
| word_pieces = updated |
| symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces} |
|
|
| vocab = special_tokens + sorted(symbol_vocab) |
| return cls(vocab=vocab, merges=merges) |
|
|
| @staticmethod |
| def tokenize(text: str) -> List[str]: |
| return BPETokenizer.TOKEN_PATTERN.findall(text) |
|
|
| @property |
| def vocab_size(self) -> int: |
| return len(self.id_to_token) |
|
|
| @property |
| def pad_id(self) -> int: |
| return self.token_to_id[self.PAD] |
|
|
| @property |
| def bos_id(self) -> int: |
| return self.token_to_id[self.BOS] |
|
|
| @property |
| def eos_id(self) -> int: |
| return self.token_to_id[self.EOS] |
|
|
| @property |
| def unk_id(self) -> int: |
| return self.token_to_id[self.UNK] |
|
|
| def _apply_merges(self, word: str) -> list[str]: |
| pieces = list(word) + [self.END_OF_WORD] |
| if len(pieces) == 1: |
| return pieces |
|
|
| while True: |
| candidates = [] |
| for index in range(len(pieces) - 1): |
| pair = (pieces[index], pieces[index + 1]) |
| if pair in self.merge_ranks: |
| candidates.append((self.merge_ranks[pair], index, pair)) |
| if not candidates: |
| break |
|
|
| _, merge_index, pair = min(candidates) |
| pieces = pieces[:merge_index] + ["".join(pair)] + pieces[merge_index + 2 :] |
| return pieces |
|
|
| def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: |
| ids: list[int] = [] |
| if add_bos: |
| ids.append(self.bos_id) |
| for token in self.tokenize(text): |
| if re.match(r"\w+", token): |
| pieces = self._apply_merges(token) |
| else: |
| pieces = [token + self.END_OF_WORD] |
| if pieces[0] not in self.token_to_id: |
| pieces = [token, self.END_OF_WORD] |
| for piece in pieces: |
| ids.append(self.token_to_id.get(piece, self.unk_id)) |
| if add_eos: |
| ids.append(self.eos_id) |
| return ids |
|
|
| def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str: |
| specials = {self.PAD, self.BOS, self.EOS, self.UNK} |
| words: list[str] = [] |
| current = "" |
| for token_id in token_ids: |
| token = self.id_to_token[int(token_id)] |
| if skip_special_tokens and token in specials: |
| continue |
| if token == self.END_OF_WORD: |
| if current: |
| words.append(current) |
| current = "" |
| continue |
| if token.endswith(self.END_OF_WORD): |
| current += token[: -len(self.END_OF_WORD)] |
| words.append(current) |
| current = "" |
| else: |
| current += token |
|
|
| if current: |
| words.append(current) |
|
|
| output: list[str] = [] |
| for word in words: |
| if not output: |
| output.append(word) |
| elif re.match(r"^[^\w\s]+$", word): |
| output.append(word) |
| elif re.match(r"^[^\w\s]+$", output[-1]): |
| output.append(" ") |
| output.append(word) |
| else: |
| output.append(" ") |
| output.append(word) |
| return "".join(output).replace(" ", " ").strip() |
|
|
| def save(self, path: str | Path) -> None: |
| payload = {"type": "bpe", "vocab": self.id_to_token, "merges": [list(pair) for pair in self.merges]} |
| Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8") |
|
|
| @classmethod |
| def load(cls, path: str | Path) -> "BPETokenizer": |
| payload = json.loads(Path(path).read_text(encoding="utf-8")) |
| return cls(payload["vocab"], payload.get("merges", [])) |
|
|
|
|
# Union of the tokenizer classes that ``load_tokenizer`` can return.
Tokenizer = SimpleTokenizer | BPETokenizer


def load_tokenizer(path: str | Path) -> Tokenizer:
    """Load a tokenizer from a JSON file, choosing the class from its contents.

    A payload whose ``"type"`` field equals ``"bpe"`` yields a
    ``BPETokenizer`` (missing ``"merges"`` default to an empty list); any
    other payload yields a ``SimpleTokenizer`` built from ``"vocab"``.
    """
    raw_text = Path(path).read_text(encoding="utf-8")
    data = json.loads(raw_text)
    is_bpe = data.get("type") == "bpe"
    if not is_bpe:
        return SimpleTokenizer(data["vocab"])
    return BPETokenizer(data["vocab"], data.get("merges", []))
|
|