import json import re from collections import Counter from pathlib import Path from typing import Iterable, List, Sequence class SimpleTokenizer: """A small word-and-punctuation tokenizer for CPU-only experiments.""" PAD = "" BOS = "" EOS = "" UNK = "" TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE) def __init__(self, vocab: List[str]): self.id_to_token = vocab self.token_to_id = {token: index for index, token in enumerate(vocab)} @classmethod def build(cls, texts: Iterable[str], min_freq: int = 1) -> "SimpleTokenizer": counter: Counter[str] = Counter() for text in texts: counter.update(cls.tokenize(text)) vocab = [cls.PAD, cls.BOS, cls.EOS, cls.UNK] for token, freq in counter.most_common(): if freq >= min_freq and token not in vocab: vocab.append(token) return cls(vocab) @staticmethod def tokenize(text: str) -> List[str]: return SimpleTokenizer.TOKEN_PATTERN.findall(text) @property def vocab_size(self) -> int: return len(self.id_to_token) @property def pad_id(self) -> int: return self.token_to_id[self.PAD] @property def bos_id(self) -> int: return self.token_to_id[self.BOS] @property def eos_id(self) -> int: return self.token_to_id[self.EOS] @property def unk_id(self) -> int: return self.token_to_id[self.UNK] def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: tokens = self.tokenize(text) ids = [self.token_to_id.get(token, self.unk_id) for token in tokens] if add_bos: ids.insert(0, self.bos_id) if add_eos: ids.append(self.eos_id) return ids def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str: tokens: List[str] = [] specials = {self.PAD, self.BOS, self.EOS, self.UNK} for token_id in token_ids: token = self.id_to_token[int(token_id)] if skip_special_tokens and token in specials: continue tokens.append(token) output = [] for token in tokens: if output and re.match(r"\w", token) and re.match(r"\w", output[-1][-1]): output.append(" ") elif output and token not in {".", ",", "!", "?", ":", ";", "'", '"', ")"} and output[-1] not in {"(", '"'}: output.append(" ") output.append(token) return "".join(output).strip() def save(self, path: str | Path) -> None: payload = {"vocab": self.id_to_token} Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8") @classmethod def load(cls, path: str | Path) -> "SimpleTokenizer": payload = json.loads(Path(path).read_text(encoding="utf-8")) return cls(payload["vocab"]) class BPETokenizer: """A compact BPE tokenizer with greedy longest-match encoding.""" PAD = "" BOS = "" EOS = "" UNK = "" END_OF_WORD = "" TOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE) def __init__(self, vocab: Sequence[str], merges: Sequence[list[str] | tuple[str, str]]): self.id_to_token = list(vocab) self.token_to_id = {token: index for index, token in enumerate(self.id_to_token)} self.merges = [tuple(pair) for pair in merges] self.merge_ranks = {pair: index for index, pair in enumerate(self.merges)} @classmethod def build( cls, texts: Iterable[str], vocab_size: int = 256, min_frequency: int = 2, ) -> "BPETokenizer": words = Counter() for text in texts: words.update(cls.TOKEN_PATTERN.findall(text)) word_pieces = { word: tuple(list(word) + [cls.END_OF_WORD]) for word, frequency in words.items() if frequency >= 1 } merges: list[tuple[str, str]] = [] special_tokens = [cls.PAD, cls.BOS, cls.EOS, cls.UNK] symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces} while len(symbol_vocab) + len(special_tokens) < vocab_size: pair_counts: Counter[tuple[str, str]] = Counter() for word, pieces in word_pieces.items(): frequency = words[word] for index in range(len(pieces) - 1): pair_counts[(pieces[index], pieces[index + 1])] += frequency if not pair_counts: break best_pair, best_frequency = pair_counts.most_common(1)[0] if best_frequency < min_frequency: break merged_symbol = "".join(best_pair) merges.append(best_pair) updated: dict[str, tuple[str, ...]] = {} for word, pieces in word_pieces.items(): new_pieces: list[str] = [] index = 0 while index < len(pieces): if index < len(pieces) - 1 and (pieces[index], pieces[index + 1]) == best_pair: new_pieces.append(merged_symbol) index += 2 else: new_pieces.append(pieces[index]) index += 1 updated[word] = tuple(new_pieces) word_pieces = updated symbol_vocab = {symbol for pieces in word_pieces.values() for symbol in pieces} vocab = special_tokens + sorted(symbol_vocab) return cls(vocab=vocab, merges=merges) @staticmethod def tokenize(text: str) -> List[str]: return BPETokenizer.TOKEN_PATTERN.findall(text) @property def vocab_size(self) -> int: return len(self.id_to_token) @property def pad_id(self) -> int: return self.token_to_id[self.PAD] @property def bos_id(self) -> int: return self.token_to_id[self.BOS] @property def eos_id(self) -> int: return self.token_to_id[self.EOS] @property def unk_id(self) -> int: return self.token_to_id[self.UNK] def _apply_merges(self, word: str) -> list[str]: pieces = list(word) + [self.END_OF_WORD] if len(pieces) == 1: return pieces while True: candidates = [] for index in range(len(pieces) - 1): pair = (pieces[index], pieces[index + 1]) if pair in self.merge_ranks: candidates.append((self.merge_ranks[pair], index, pair)) if not candidates: break _, merge_index, pair = min(candidates) pieces = pieces[:merge_index] + ["".join(pair)] + pieces[merge_index + 2 :] return pieces def encode(self, text: str, add_bos: bool = False, add_eos: bool = False) -> List[int]: ids: list[int] = [] if add_bos: ids.append(self.bos_id) for token in self.tokenize(text): if re.match(r"\w+", token): pieces = self._apply_merges(token) else: pieces = [token + self.END_OF_WORD] if pieces[0] not in self.token_to_id: pieces = [token, self.END_OF_WORD] for piece in pieces: ids.append(self.token_to_id.get(piece, self.unk_id)) if add_eos: ids.append(self.eos_id) return ids def decode(self, token_ids: Iterable[int], skip_special_tokens: bool = True) -> str: specials = {self.PAD, self.BOS, self.EOS, self.UNK} words: list[str] = [] current = "" for token_id in token_ids: token = self.id_to_token[int(token_id)] if skip_special_tokens and token in specials: continue if token == self.END_OF_WORD: if current: words.append(current) current = "" continue if token.endswith(self.END_OF_WORD): current += token[: -len(self.END_OF_WORD)] words.append(current) current = "" else: current += token if current: words.append(current) output: list[str] = [] for word in words: if not output: output.append(word) elif re.match(r"^[^\w\s]+$", word): output.append(word) elif re.match(r"^[^\w\s]+$", output[-1]): output.append(" ") output.append(word) else: output.append(" ") output.append(word) return "".join(output).replace(" ", " ").strip() def save(self, path: str | Path) -> None: payload = {"type": "bpe", "vocab": self.id_to_token, "merges": [list(pair) for pair in self.merges]} Path(path).write_text(json.dumps(payload, indent=2), encoding="utf-8") @classmethod def load(cls, path: str | Path) -> "BPETokenizer": payload = json.loads(Path(path).read_text(encoding="utf-8")) return cls(payload["vocab"], payload.get("merges", [])) Tokenizer = SimpleTokenizer | BPETokenizer def load_tokenizer(path: str | Path) -> Tokenizer: payload = json.loads(Path(path).read_text(encoding="utf-8")) if payload.get("type") == "bpe": return BPETokenizer(payload["vocab"], payload.get("merges", [])) return SimpleTokenizer(payload["vocab"])