"""BPE-style tokenizer: word-prefix segmentation, emoji clustering, UTF-8 byte fallback, and heuristic detokenization."""

import re
import unicodedata
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass, field
from string import ascii_letters, digits

from .reasoning import REASONING_CONTROL_TOKENS, TOKENIZER_NAME

PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
BYTE_FALLBACK_PATTERN = re.compile(r"<byte:([0-9A-F]{2})>")
DEFAULT_FALLBACK_CHARACTERS = (
    ascii_letters
    + digits
    + "'-_/.:,;!?()[]{}@#$%&*+="
    + "’ʼ‘“”—–…"
)
MAX_TOKENIZER_VOCAB_SIZE = 65536
MAX_SEGMENT_CACHE_SIZE = 200_000
MAX_TRAINED_PAIR_MERGES = 384


def _is_word_character(character: str) -> bool:
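    """Return True for characters that extend a word run: letters, digits, underscore,
    or nonspacing combining marks."""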
    category = unicodedata.category(character)
    return character == "_" or category[0] in {"L", "N"} or category == "Mn"


def _is_variation_selector(character: str) -> bool:
    return "VARIATION SELECTOR" in unicodedata.name(character, "")


def _is_zero_width_joiner(character: str) -> bool:
    return unicodedata.name(character, "") == "ZERO WIDTH JOINER"


def _is_emoji_modifier(character: str) -> bool:
    return "EMOJI MODIFIER" in unicodedata.name(character, "")


def _is_emoji_base_character(character: str) -> bool:
    name = unicodedata.name(character, "")
    category = unicodedata.category(character)
    return (
        "EMOJI" in name
        or "REGIONAL INDICATOR SYMBOL" in name
        or (category in {"So", "Sk"} and ord(character) >= 0x2100)
    )


def _is_emoji_continuation_character(character: str) -> bool:
    category = unicodedata.category(character)
    name = unicodedata.name(character, "")
    return (
        _is_variation_selector(character)
        or _is_zero_width_joiner(character)
        or _is_emoji_modifier(character)
        or category in {"Mn", "Me"}
        or name.startswith("TAG ")
    )


def _consume_emoji_cluster(text: str, start: int) -> int:
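    """Return the index just past the emoji cluster starting at ``start``,
    or ``start`` itself if no cluster begins there."""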
    if start >= len(text) or not _is_emoji_base_character(text[start]):
        return start

    index = start + 1
    if "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[start], ""):
        if index < len(text) and "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[index], ""):
            return index + 1
        return index

    while index < len(text):
        if _is_emoji_continuation_character(text[index]):
            index += 1
            continue
        if _is_zero_width_joiner(text[index - 1]) and _is_emoji_base_character(text[index]):
            index += 1
            continue
        break
    return index


def _byte_token(value: int) -> str:
    return f"<byte:{value:02X}>"


def _byte_value(piece: str) -> int | None:
    match = BYTE_FALLBACK_PATTERN.fullmatch(piece)
    if match is None:
        return None
    return int(match.group(1), 16)


def _is_punctuation_piece(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character).startswith("P")
        for character in piece
    )


def _is_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Ps", "Pi"}
        for character in piece
    )


def _is_call_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) == "Ps"
        and "PARENTHESIS" in unicodedata.name(character, "")
        for character in piece
    )


def _is_closing_or_terminal_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Pe", "Pf", "Po"}
        for character in piece
    )


def _is_infix_joiner(piece: str) -> bool:
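    """Return True for single characters that join adjacent words without spaces
    (dashes, apostrophes, and slashes)."""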
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return (
        category == "Pd"
        or "APOSTROPHE" in name
        or (category == "Pf" and "SINGLE QUOTATION MARK" in name)
        or "SOLIDUS" in name
    )


def _is_dash_joiner(piece: str) -> bool:
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return category == "Pd" or "HYPHEN" in name or "DASH" in name


def _is_quote_piece(piece: str) -> bool:
    if len(piece) != 1:
        return False
    if _is_infix_joiner(piece):
        return False
    name = unicodedata.name(piece, "")
    category = unicodedata.category(piece)
    return "QUOTATION MARK" in name or category in {"Pi", "Pf"}


def _merge_symbol(left: str, right: str, prefix: str) -> str:
    if right.startswith(prefix):
        return left + right[len(prefix):]
    return left + right


def _merge_sequence(symbols: list[str], pair: tuple[str, str], merged_symbol: str) -> list[str]:
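    """Replace every adjacent occurrence of ``pair`` in ``symbols`` with ``merged_symbol``
    in a single left-to-right pass."""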
    merged: list[str] = []
    index = 0
    while index < len(symbols):
        if index < len(symbols) - 1 and (symbols[index], symbols[index + 1]) == pair:
            merged.append(merged_symbol)
            index += 2
        else:
            merged.append(symbols[index])
            index += 1
    return merged


def _default_symbol_inventory(word_prefix: str) -> set[str]:
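    """Base symbols that are always encodable: common characters plus all 256 byte-fallback
    tokens, each with and without the word prefix."""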
    symbols: set[str] = set()
    for character in DEFAULT_FALLBACK_CHARACTERS:
        symbols.add(character)
        symbols.add(f"{word_prefix}{character}")
    for value in range(256):
        token = _byte_token(value)
        symbols.add(token)
        symbols.add(f"{word_prefix}{token}")
    return symbols


def _whole_segment_token(segment: str, word_prefix: str) -> str:
    return f"{word_prefix}{segment}"


def recommend_vocab_size(
    text: str,
    *,
    minimum: int = 768,
    maximum: int = 1536,
    multiplier: int = 5,
    lowercase: bool = False,
) -> int:
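    """Suggest a training vocab size from the number of distinct pretokenized segments,
    clamped to [minimum, maximum]."""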
    seed_tokenizer = NativeTokenizer(
        merges=[],
        vocab=[],
        base_symbols=[],
        lowercase=lowercase,
    )
    segments = seed_tokenizer.pretokenize(text)
    distinct_segments = len(set(segments))
    recommended = max(minimum, distinct_segments * multiplier)
    return min(maximum, recommended)


def clamp_vocab_size(requested: int, *, maximum: int = MAX_TOKENIZER_VOCAB_SIZE) -> int:
    return min(maximum, max(1, requested))


@dataclass(slots=True)
class NativeTokenizer:
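    """BPE-style tokenizer with a SentencePiece-like word prefix and UTF-8 byte fallback.

    Illustrative usage (exact token output depends on the trained merges):

        tokenizer = NativeTokenizer.train("some training text", vocab_size=512)
        tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
        restored = tokenizer.decode(tokens)
    """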
    merges: list[tuple[str, str]]
    vocab: list[str]
    base_symbols: list[str]
    name: str = TOKENIZER_NAME
    lowercase: bool = False
    word_prefix: str = "▁"
    unk_token: str = "<unk>"
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    pad_token: str = "<pad>"
    _merge_ranks: dict[tuple[str, str], int] = field(init=False, repr=False)
    _vocab_set: set[str] = field(init=False, repr=False)
    _base_symbol_set: set[str] = field(init=False, repr=False)
    _pretoken_pattern: re.Pattern[str] = field(init=False, repr=False)
    _segment_cache: dict[str, tuple[str, ...]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
        self._base_symbol_set = set(self.base_symbols)
        self._vocab_set = set(self.vocab) | self.special_tokens | self._base_symbol_set
        self.vocab = sorted(self._vocab_set)
        self._pretoken_pattern = self._build_pretoken_pattern()
        self._segment_cache = {}

    @property
    def special_tokens(self) -> set[str]:
        return {
            self.unk_token,
            self.bos_token,
            self.eos_token,
            self.pad_token,
            *REASONING_CONTROL_TOKENS,
        }

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_set)

    def normalize(self, text: str) -> str:
        normalized = unicodedata.normalize("NFKC", text)
        return normalized.lower() if self.lowercase else normalized

    def pretokenize(self, text: str) -> list[str]:
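        """Split normalized text into segments: special tokens, emoji clusters, word runs, and
        single punctuation marks. Newlines are kept as segments; other whitespace is dropped."""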
        normalized = self.normalize(text)
        segments: list[str] = []
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        index = 0
        while index < len(normalized):
            if normalized[index].isspace():
                if normalized[index] == "\r":
                    if index + 1 < len(normalized) and normalized[index + 1] == "\n":
                        segments.append("\n")
                        index += 2
                        continue
                    segments.append("\n")
                    index += 1
                    continue
                if normalized[index] == "\n":
                    segments.append("\n")
                    index += 1
                    continue
                index += 1
                continue

            matched_special = next(
                (
                    token
                    for token in reserved
                    if normalized.startswith(token, index)
                ),
                None,
            )
            if matched_special is not None:
                segments.append(matched_special)
                index += len(matched_special)
                continue

            emoji_end = _consume_emoji_cluster(normalized, index)
            if emoji_end > index:
                segments.append(normalized[index:emoji_end])
                index = emoji_end
                continue

            if _is_word_character(normalized[index]):
                start = index
                index += 1
                while index < len(normalized) and _is_word_character(normalized[index]):
                    index += 1
                segments.append(normalized[start:index])
                continue

            segments.append(normalized[index])
            index += 1
        return segments

    def encode(self, text: str, *, add_special_tokens: bool = False) -> list[str]:
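        """Encode text into vocabulary tokens, optionally wrapped in BOS/EOS; non-empty text
        that yields no tokens falls back to the unknown token."""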
        tokens: list[str] = []
        if add_special_tokens:
            tokens.append(self.bos_token)

        for segment in self.pretokenize(text):
            tokens.extend(self._encode_segment_cached(segment))

        if add_special_tokens:
            tokens.append(self.eos_token)

        if not tokens and text.strip():
            return [self.unk_token]
        return tokens

    def encode_many(
        self,
        texts: list[str] | tuple[str, ...],
        *,
        add_special_tokens: bool = False,
    ) -> list[list[str]]:
        return [
            self.encode(text, add_special_tokens=add_special_tokens)
            for text in texts
        ]

    def decode(self, tokens: list[str]) -> str:
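        """Heuristically detokenize: skip special tokens, decode byte-fallback runs as UTF-8, and
        re-insert spaces around words, quotes, and punctuation."""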
        text = ""
        join_next = False
        byte_buffer = bytearray()
        byte_starts_segment = False

        def next_rendered_piece(start_index: int) -> str | None:
            # Look ahead for the next visible piece so punctuation spacing can consider it.
            for raw_token in tokens[start_index:]:
                if raw_token in self.special_tokens:
                    continue
                raw_starts_segment = raw_token.startswith(self.word_prefix)
                raw_piece = raw_token[len(self.word_prefix) :] if raw_starts_segment else raw_token
                if not raw_piece:
                    continue
                if _byte_value(raw_piece) is not None:
                    return None
                return raw_piece
            return None

        def append_piece(piece: str, starts_segment: bool, next_piece: str | None = None) -> None:
            nonlocal text, join_next

            if piece == "\n":
                text = text.rstrip(" ")
                text += "\n"
                join_next = True
                return

            had_text_before_piece = bool(text.strip())
            previous_before_piece = text.rstrip(" ")[-1:] if text.strip(" ") else ""
            if _is_quote_piece(piece):
                # Alternate quote characters between opening (space before) and closing (attach left).
                quote_count = sum(1 for character in text if _is_quote_piece(character))
                opens_quote = quote_count % 2 == 0
                if opens_quote:
                    if text and not text.endswith((" ", "\n")) and previous_before_piece not in {"(", "[", "{"}:
                        text += " "
                    text += piece
                    join_next = True
                    return
                text = text.rstrip(" ")
                text += piece
                join_next = False
                return

            attaches_left = _is_closing_or_terminal_punctuation(piece) or _is_infix_joiner(piece)
            continues_segment = (not starts_segment) and any(
                _is_word_character(character) or _is_emoji_continuation_character(character)
                for character in piece
            )
            if starts_segment:
                if text and not join_next:
                    attaches_to_previous_code_span = (
                        _is_opening_punctuation(piece)
                        and previous_before_piece.isalnum()
                        and next_piece is not None
                        and (
                            _is_infix_joiner(next_piece)
                            or _is_call_opening_punctuation(piece)
                        )
                    )
                    if not _is_punctuation_piece(piece) or (
                        _is_opening_punctuation(piece)
                        and not attaches_to_previous_code_span
                    ):
                        text += " "
                text += piece
            else:
                if text and not join_next and not attaches_left and not continues_segment:
                    text += " "
                text += piece

            join_next = (
                _is_infix_joiner(piece)
                and (
                    not starts_segment
                    or (
                        had_text_before_piece
                        and (
                            not _is_dash_joiner(piece)
                            or previous_before_piece.isalnum()
                            or _is_opening_punctuation(previous_before_piece)
                        )
                    )
                )
            ) or _is_opening_punctuation(piece)

        def flush_bytes() -> None:
            nonlocal byte_buffer, byte_starts_segment
            if not byte_buffer:
                return
            append_piece(bytes(byte_buffer).decode("utf-8", errors="replace"), byte_starts_segment)
            byte_buffer = bytearray()
            byte_starts_segment = False

        for token_index, token in enumerate(tokens):
            if token in self.special_tokens:
                continue
            starts_segment = token.startswith(self.word_prefix)
            piece = token[len(self.word_prefix) :] if starts_segment else token
            if not piece:
                continue
            byte_value = _byte_value(piece)
            if byte_value is not None:
                # Accumulate byte-fallback tokens so multi-byte UTF-8 sequences decode together.
                if not byte_buffer:
                    byte_starts_segment = starts_segment
                byte_buffer.append(byte_value)
                continue

            flush_bytes()
            append_piece(piece, starts_segment, next_rendered_piece(token_index + 1))
        flush_bytes()
        return text.strip()

    def _encode_segment_cached(self, segment: str) -> tuple[str, ...]:
        cached = self._segment_cache.get(segment)
        if cached is not None:
            return cached
        encoded = tuple(self._encode_segment(segment))
        if len(self._segment_cache) < MAX_SEGMENT_CACHE_SIZE:
            self._segment_cache[segment] = encoded
        return encoded

    def _encode_segment(self, segment: str) -> list[str]:
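        """Apply BPE to one segment: emit the whole-segment token if it is in the vocab, otherwise
        repeatedly merge the lowest-ranked adjacent pair; unknown symbols yield the unk token."""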
        if segment in self.special_tokens:
            return [segment]
        whole_segment = _whole_segment_token(segment, self.word_prefix)
        if whole_segment in self._vocab_set:
            return [whole_segment]
        symbols = self._seed_symbols(segment)
        if not symbols:
            return []

        while len(symbols) > 1:
            best_rank: int | None = None
            best_pair: tuple[str, str] | None = None
            for index in range(len(symbols) - 1):
                pair = (symbols[index], symbols[index + 1])
                rank = self._merge_ranks.get(pair)
                if rank is None:
                    continue
                if best_rank is None or rank < best_rank:
                    best_rank = rank
                    best_pair = pair
            if best_pair is None:
                break

            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], self.word_prefix)
            symbols = _merge_sequence(symbols, best_pair, merged_symbol)

        if any(symbol not in self._vocab_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def _seed_symbols(self, segment: str) -> list[str]:
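        """Seed a segment as per-character symbols (word prefix on the first character); characters
        outside the base inventory fall back to UTF-8 byte tokens."""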
        symbols: list[str] = []
        for index, character in enumerate(segment):
            symbol = f"{self.word_prefix}{character}" if index == 0 else character
            if symbol in self._base_symbol_set:
                symbols.append(symbol)
                continue

            encoded = character.encode("utf-8")
            for byte_index, value in enumerate(encoded):
                token = _byte_token(value)
                if index == 0 and byte_index == 0:
                    token = f"{self.word_prefix}{token}"
                symbols.append(token)

        if any(symbol not in self._base_symbol_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def to_dict(self) -> dict[str, object]:
        return {
            "name": self.name,
            "merges": [[left, right] for left, right in self.merges],
            "vocab": self.vocab,
            "base_symbols": self.base_symbols,
            "lowercase": self.lowercase,
            "word_prefix": self.word_prefix,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "NativeTokenizer":
        return cls(
            merges=[(str(left), str(right)) for left, right in payload["merges"]],
            vocab=[str(token) for token in payload["vocab"]],
            base_symbols=[str(token) for token in payload["base_symbols"]],
            name=str(payload.get("name", TOKENIZER_NAME)),
            lowercase=bool(payload["lowercase"]),
            word_prefix=str(payload["word_prefix"]),
            unk_token=str(payload["unk_token"]),
            bos_token=str(payload["bos_token"]),
            eos_token=str(payload["eos_token"]),
            pad_token=str(payload["pad_token"]),
        )

    def _build_pretoken_pattern(self) -> re.Pattern[str]:
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        if not reserved:
            return PRETOKEN_PATTERN
        reserved_pattern = "|".join(re.escape(token) for token in reserved)
        return re.compile(f"{reserved_pattern}|\\w+|[^\\w\\s]", re.UNICODE)

    @classmethod
    def train(
        cls,
        text: str,
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
        segments = seed_tokenizer.pretokenize(text)
        if not segments:
            raise ValueError("Cannot train the native tokenizer on empty text.")

        return cls.train_from_segment_counts(
            Counter(segments),
            vocab_size=vocab_size,
            min_pair_frequency=min_pair_frequency,
            lowercase=lowercase,
            word_prefix=word_prefix,
        )

    @classmethod
    def train_from_segment_counts(
        cls,
        segment_counts: Mapping[str, float],
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        if not segment_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )

        word_counts = Counter(
            {
                str(segment): float(frequency)
                for segment, frequency in segment_counts.items()
                if str(segment) and float(frequency) > 0.0
            }
        )
        if not word_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        observed_symbols = {
            f"{word_prefix}{character}" if index == 0 else character
            for segment in word_counts
            for index, character in enumerate(segment)
        }
        base_symbols = _default_symbol_inventory(word_prefix)
        base_symbols.update(observed_symbols)
        sequences = {
            segment: [
                f"{word_prefix}{character}" if index == 0 else character
                for index, character in enumerate(segment)
            ]
            for segment in word_counts
        }
        vocab = set(observed_symbols) | seed_tokenizer.special_tokens
        target_vocab_size = len(vocab) + max(1, vocab_size)
        # Reserve whole-segment tokens for the most valuable segments (frequency * length) first.
        segment_candidates = sorted(
            {
                segment
                for segment, frequency in word_counts.items()
                if len(segment) > 1 and frequency >= min_pair_frequency
            },
            key=lambda segment: (
                -(word_counts[segment] * len(segment)),
                -len(segment),
                segment,
            ),
        )
        for segment in segment_candidates:
            if len(vocab) >= target_vocab_size:
                break
            vocab.add(_whole_segment_token(segment, word_prefix))
        merges: list[tuple[str, str]] = []

        # Classic BPE loop: count adjacent symbol pairs, merge the most frequent, repeat.
        while len(vocab) < target_vocab_size and len(merges) < MAX_TRAINED_PAIR_MERGES:
            pair_counts: Counter[tuple[str, str]] = Counter()
            for segment, frequency in word_counts.items():
                symbols = sequences[segment]
                for index in range(len(symbols) - 1):
                    pair_counts[(symbols[index], symbols[index + 1])] += frequency

            if not pair_counts:
                break

            best_pair, best_count = min(
                pair_counts.items(),
                key=lambda item: (-item[1], item[0][0], item[0][1]),
            )
            if best_count < min_pair_frequency:
                break

            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], word_prefix)
            merges.append(best_pair)
            vocab.add(merged_symbol)
            for segment in sequences:
                sequences[segment] = _merge_sequence(sequences[segment], best_pair, merged_symbol)

        return cls(
            merges=merges,
            vocab=sorted(vocab),
            base_symbols=sorted(base_symbols),
            lowercase=lowercase,
            word_prefix=word_prefix,
        )