"""Native BPE-style tokenizer: Unicode pretokenization (words, punctuation, emoji
clusters), byte-fallback encoding, pair-merge training, and heuristic detokenization."""

import re
import unicodedata
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass, field
from string import ascii_letters, digits

from .reasoning import REASONING_CONTROL_TOKENS, TOKENIZER_NAME

PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
# Byte-fallback pieces are spelled "<0xNN>" (two hex digits); this spelling is an
# assumed reconstruction chosen to match the hex parsing in _byte_value below.
BYTE_FALLBACK_PATTERN = re.compile(r"<0x([0-9A-Fa-f]{2})>")
DEFAULT_FALLBACK_CHARACTERS = (
    ascii_letters + digits + "'-_/.:,;!?()[]{}@#$%&*+=" + "’ʼ‘“”—–…"
)
MAX_TOKENIZER_VOCAB_SIZE = 65536
MAX_SEGMENT_CACHE_SIZE = 200_000
MAX_TRAINED_PAIR_MERGES = 384


def _is_word_character(character: str) -> bool:
    category = unicodedata.category(character)
    return character == "_" or category[0] in {"L", "N"} or category == "Mn"


def _is_variation_selector(character: str) -> bool:
    return "VARIATION SELECTOR" in unicodedata.name(character, "")


def _is_zero_width_joiner(character: str) -> bool:
    return unicodedata.name(character, "") == "ZERO WIDTH JOINER"


def _is_emoji_modifier(character: str) -> bool:
    return "EMOJI MODIFIER" in unicodedata.name(character, "")


def _is_emoji_base_character(character: str) -> bool:
    name = unicodedata.name(character, "")
    category = unicodedata.category(character)
    return (
        "EMOJI" in name
        or "REGIONAL INDICATOR SYMBOL" in name
        or (category in {"So", "Sk"} and ord(character) >= 0x2100)
    )


def _is_emoji_continuation_character(character: str) -> bool:
    category = unicodedata.category(character)
    name = unicodedata.name(character, "")
    return (
        _is_variation_selector(character)
        or _is_zero_width_joiner(character)
        or _is_emoji_modifier(character)
        or category in {"Mn", "Me"}
        or name.startswith("TAG ")
    )


def _consume_emoji_cluster(text: str, start: int) -> int:
    # Consume one emoji cluster starting at `start`: a flag pair, or a base emoji
    # followed by modifiers, variation selectors, and ZWJ-joined emoji.
    if start >= len(text) or not _is_emoji_base_character(text[start]):
        return start
    index = start + 1
    if "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[start], ""):
        if index < len(text) and "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[index], ""):
            return index + 1
        return index
    while index < len(text):
        if _is_emoji_continuation_character(text[index]):
            index += 1
            continue
        if _is_zero_width_joiner(text[index - 1]) and _is_emoji_base_character(text[index]):
            index += 1
            continue
        break
    return index


def _byte_token(value: int) -> str:
    return f"<0x{value:02X}>"


def _byte_value(piece: str) -> int | None:
    match = BYTE_FALLBACK_PATTERN.fullmatch(piece)
    if match is None:
        return None
    return int(match.group(1), 16)


def _is_punctuation_piece(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character).startswith("P") for character in piece
    )


def _is_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Ps", "Pi"} for character in piece
    )


def _is_call_opening_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) == "Ps"
        and "PARENTHESIS" in unicodedata.name(character, "")
        for character in piece
    )


def _is_closing_or_terminal_punctuation(piece: str) -> bool:
    return bool(piece) and all(
        unicodedata.category(character) in {"Pe", "Pf", "Po"} for character in piece
    )


def _is_infix_joiner(piece: str) -> bool:
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return (
        category == "Pd"
        or "APOSTROPHE" in name
        or (category == "Pf" and "SINGLE QUOTATION MARK" in name)
        or "SOLIDUS" in name
    )


def _is_dash_joiner(piece: str) -> bool:
    if len(piece) != 1:
        return False
    category = unicodedata.category(piece)
    name = unicodedata.name(piece, "")
    return category == "Pd" or "HYPHEN" in name or "DASH" in name
def _is_quote_piece(piece: str) -> bool:
    if len(piece) != 1:
        return False
    if _is_infix_joiner(piece):
        return False
    name = unicodedata.name(piece, "")
    category = unicodedata.category(piece)
    return "QUOTATION MARK" in name or category in {"Pi", "Pf"}


def _merge_symbol(left: str, right: str, prefix: str) -> str:
    if right.startswith(prefix):
        return left + right[len(prefix):]
    return left + right


def _merge_sequence(symbols: list[str], pair: tuple[str, str], merged_symbol: str) -> list[str]:
    merged: list[str] = []
    index = 0
    while index < len(symbols):
        if index < len(symbols) - 1 and (symbols[index], symbols[index + 1]) == pair:
            merged.append(merged_symbol)
            index += 2
        else:
            merged.append(symbols[index])
            index += 1
    return merged


def _default_symbol_inventory(word_prefix: str) -> set[str]:
    symbols: set[str] = set()
    for character in DEFAULT_FALLBACK_CHARACTERS:
        symbols.add(character)
        symbols.add(f"{word_prefix}{character}")
    for value in range(256):
        token = _byte_token(value)
        symbols.add(token)
        symbols.add(f"{word_prefix}{token}")
    return symbols


def _whole_segment_token(segment: str, word_prefix: str) -> str:
    return f"{word_prefix}{segment}"


def recommend_vocab_size(
    text: str,
    *,
    minimum: int = 768,
    maximum: int = 1536,
    multiplier: int = 5,
    lowercase: bool = False,
) -> int:
    seed_tokenizer = NativeTokenizer(
        merges=[],
        vocab=[],
        base_symbols=[],
        lowercase=lowercase,
    )
    segments = seed_tokenizer.pretokenize(text)
    distinct_segments = len(set(segments))
    recommended = max(minimum, distinct_segments * multiplier)
    return min(maximum, recommended)


def clamp_vocab_size(requested: int, *, maximum: int = MAX_TOKENIZER_VOCAB_SIZE) -> int:
    return min(maximum, max(1, requested))


@dataclass(slots=True)
class NativeTokenizer:
    merges: list[tuple[str, str]]
    vocab: list[str]
    base_symbols: list[str]
    name: str = TOKENIZER_NAME
    lowercase: bool = False
    word_prefix: str = "▁"
    # The angle-bracketed special-token defaults below are assumed spellings.
    unk_token: str = "<unk>"
    bos_token: str = "<bos>"
    eos_token: str = "<eos>"
    pad_token: str = "<pad>"
    _merge_ranks: dict[tuple[str, str], int] = field(init=False, repr=False)
    _vocab_set: set[str] = field(init=False, repr=False)
    _base_symbol_set: set[str] = field(init=False, repr=False)
    _pretoken_pattern: re.Pattern[str] = field(init=False, repr=False)
    _segment_cache: dict[str, tuple[str, ...]] = field(init=False, repr=False)

    def __post_init__(self) -> None:
        self._merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
        self._base_symbol_set = set(self.base_symbols)
        self._vocab_set = set(self.vocab) | self.special_tokens | self._base_symbol_set
        self.vocab = sorted(self._vocab_set)
        self._pretoken_pattern = self._build_pretoken_pattern()
        self._segment_cache = {}

    @property
    def special_tokens(self) -> set[str]:
        return {
            self.unk_token,
            self.bos_token,
            self.eos_token,
            self.pad_token,
            *REASONING_CONTROL_TOKENS,
        }

    @property
    def vocab_size(self) -> int:
        return len(self._vocab_set)

    def normalize(self, text: str) -> str:
        normalized = unicodedata.normalize("NFKC", text)
        return normalized.lower() if self.lowercase else normalized

    def pretokenize(self, text: str) -> list[str]:
        normalized = self.normalize(text)
        segments: list[str] = []
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        index = 0
        while index < len(normalized):
            if normalized[index].isspace():
                # LF, CR, and CRLF all become a "\n" segment; other whitespace is dropped.
                if normalized[index] == "\r":
                    if index + 1 < len(normalized) and normalized[index + 1] == "\n":
                        segments.append("\n")
                        index += 2
                        continue
                    segments.append("\n")
                    index += 1
                    continue
                if normalized[index] == "\n":
                    segments.append("\n")
                    index += 1
                    continue
                index += 1
                continue
            matched_special = next(
                (
                    token
                    for token in reserved
                    if normalized.startswith(token, index)
                ),
                None,
            )
            if matched_special is not None:
                segments.append(matched_special)
                index += len(matched_special)
                continue
            emoji_end = _consume_emoji_cluster(normalized, index)
            if emoji_end > index:
                segments.append(normalized[index:emoji_end])
                index = emoji_end
                continue
            if _is_word_character(normalized[index]):
                start = index
                index += 1
                while index < len(normalized) and _is_word_character(normalized[index]):
                    index += 1
                segments.append(normalized[start:index])
                continue
            segments.append(normalized[index])
            index += 1
        return segments

    def encode(self, text: str, *, add_special_tokens: bool = False) -> list[str]:
        tokens: list[str] = []
        if add_special_tokens:
            tokens.append(self.bos_token)
        for segment in self.pretokenize(text):
            tokens.extend(self._encode_segment_cached(segment))
        if add_special_tokens:
            tokens.append(self.eos_token)
        if not tokens and text.strip():
            return [self.unk_token]
        return tokens

    def encode_many(
        self,
        texts: list[str] | tuple[str, ...],
        *,
        add_special_tokens: bool = False,
    ) -> list[list[str]]:
        return [
            self.encode(text, add_special_tokens=add_special_tokens)
            for text in texts
        ]

    def decode(self, tokens: list[str]) -> str:
        text = ""
        join_next = False
        byte_buffer = bytearray()
        byte_starts_segment = False

        def next_rendered_piece(start_index: int) -> str | None:
            # Peek at the next printable piece, skipping special tokens; returns None
            # when the next piece is a byte-fallback token or nothing remains.
            for raw_token in tokens[start_index:]:
                if raw_token in self.special_tokens:
                    continue
                raw_starts_segment = raw_token.startswith(self.word_prefix)
                raw_piece = raw_token[len(self.word_prefix) :] if raw_starts_segment else raw_token
                if not raw_piece:
                    continue
                if _byte_value(raw_piece) is not None:
                    return None
                return raw_piece
            return None

        def append_piece(piece: str, starts_segment: bool, next_piece: str | None = None) -> None:
            # Append one rendered piece, re-inserting spaces based on punctuation
            # class, quote balance, and whether the piece starts a new segment.
            nonlocal text, join_next
            if piece == "\n":
                text = text.rstrip(" ")
                text += "\n"
                join_next = True
                return
            had_text_before_piece = bool(text.strip())
            previous_before_piece = text.rstrip(" ")[-1:] if text.strip(" ") else ""
            if _is_quote_piece(piece):
                quote_count = sum(1 for character in text if _is_quote_piece(character))
                opens_quote = quote_count % 2 == 0
                if opens_quote:
                    if text and not text.endswith((" ", "\n")) and previous_before_piece not in {"(", "[", "{"}:
                        text += " "
                    text += piece
                    join_next = True
                    return
                text = text.rstrip(" ")
                text += piece
                join_next = False
                return
            attaches_left = _is_closing_or_terminal_punctuation(piece) or _is_infix_joiner(piece)
            continues_segment = (not starts_segment) and any(
                _is_word_character(character) or _is_emoji_continuation_character(character)
                for character in piece
            )
            if starts_segment:
                if text and not join_next:
                    attaches_to_previous_code_span = (
                        _is_opening_punctuation(piece)
                        and previous_before_piece.isalnum()
                        and next_piece is not None
                        and (
                            _is_infix_joiner(next_piece)
                            or _is_call_opening_punctuation(piece)
                        )
                    )
                    if not _is_punctuation_piece(piece) or (
                        _is_opening_punctuation(piece)
                        and not attaches_to_previous_code_span
                    ):
                        text += " "
                text += piece
            else:
                if text and not join_next and not attaches_left and not continues_segment:
                    text += " "
                text += piece
            join_next = (
                _is_infix_joiner(piece)
                and (
                    not starts_segment
                    or (
                        had_text_before_piece
                        and (
                            not _is_dash_joiner(piece)
                            or previous_before_piece.isalnum()
                            or _is_opening_punctuation(previous_before_piece)
                        )
                    )
                )
            ) or _is_opening_punctuation(piece)

        def flush_bytes() -> None:
            # Decode any buffered byte-fallback tokens as UTF-8 and emit them as one piece.
            nonlocal byte_buffer, byte_starts_segment
            if not byte_buffer:
                return
            append_piece(bytes(byte_buffer).decode("utf-8", errors="replace"), byte_starts_segment)
            byte_buffer = bytearray()
            byte_starts_segment = False

        for token_index, token in enumerate(tokens):
            if token in self.special_tokens:
                continue
            starts_segment = token.startswith(self.word_prefix)
            piece = token[len(self.word_prefix) :] if starts_segment else token
            if not piece:
                continue
            byte_value = _byte_value(piece)
            if byte_value is not None:
                if not byte_buffer:
                    byte_starts_segment = starts_segment
                byte_buffer.append(byte_value)
                continue
            flush_bytes()
            append_piece(piece, starts_segment, next_rendered_piece(token_index + 1))
        flush_bytes()
        return text.strip()

    def _encode_segment_cached(self, segment: str) -> tuple[str, ...]:
        cached = self._segment_cache.get(segment)
        if cached is not None:
            return cached
        encoded = tuple(self._encode_segment(segment))
        if len(self._segment_cache) < MAX_SEGMENT_CACHE_SIZE:
            self._segment_cache[segment] = encoded
        return encoded

    def _encode_segment(self, segment: str) -> list[str]:
        if segment in self.special_tokens:
            return [segment]
        whole_segment = _whole_segment_token(segment, self.word_prefix)
        if whole_segment in self._vocab_set:
            return [whole_segment]
        symbols = self._seed_symbols(segment)
        if not symbols:
            return []
        while len(symbols) > 1:
            best_rank: int | None = None
            best_pair: tuple[str, str] | None = None
            for index in range(len(symbols) - 1):
                pair = (symbols[index], symbols[index + 1])
                rank = self._merge_ranks.get(pair)
                if rank is None:
                    continue
                if best_rank is None or rank < best_rank:
                    best_rank = rank
                    best_pair = pair
            if best_pair is None:
                break
            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], self.word_prefix)
            symbols = _merge_sequence(symbols, best_pair, merged_symbol)
        if any(symbol not in self._vocab_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def _seed_symbols(self, segment: str) -> list[str]:
        symbols: list[str] = []
        for index, character in enumerate(segment):
            symbol = f"{self.word_prefix}{character}" if index == 0 else character
            if symbol in self._base_symbol_set:
                symbols.append(symbol)
                continue
            encoded = character.encode("utf-8")
            for byte_index, value in enumerate(encoded):
                token = _byte_token(value)
                if index == 0 and byte_index == 0:
                    token = f"{self.word_prefix}{token}"
                symbols.append(token)
        if any(symbol not in self._base_symbol_set for symbol in symbols):
            return [self.unk_token]
        return symbols

    def to_dict(self) -> dict[str, object]:
        return {
            "name": self.name,
            "merges": [[left, right] for left, right in self.merges],
            "vocab": self.vocab,
            "base_symbols": self.base_symbols,
            "lowercase": self.lowercase,
            "word_prefix": self.word_prefix,
            "unk_token": self.unk_token,
            "bos_token": self.bos_token,
            "eos_token": self.eos_token,
            "pad_token": self.pad_token,
        }

    @classmethod
    def from_dict(cls, payload: dict[str, object]) -> "NativeTokenizer":
        return cls(
            merges=[(str(left), str(right)) for left, right in payload["merges"]],
            vocab=[str(token) for token in payload["vocab"]],
            base_symbols=[str(token) for token in payload["base_symbols"]],
            name=str(payload.get("name", TOKENIZER_NAME)),
            lowercase=bool(payload["lowercase"]),
            word_prefix=str(payload["word_prefix"]),
            unk_token=str(payload["unk_token"]),
            bos_token=str(payload["bos_token"]),
            eos_token=str(payload["eos_token"]),
            pad_token=str(payload["pad_token"]),
        )

    def _build_pretoken_pattern(self) -> re.Pattern[str]:
        reserved = sorted(self.special_tokens, key=len, reverse=True)
        if not reserved:
            return PRETOKEN_PATTERN
        reserved_pattern = "|".join(re.escape(token) for token in reserved)
        return re.compile(f"{reserved_pattern}|\\w+|[^\\w\\s]", re.UNICODE)
    @classmethod
    def train(
        cls,
        text: str,
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
        segments = seed_tokenizer.pretokenize(text)
        if not segments:
            raise ValueError("Cannot train the native tokenizer on empty text.")
        return cls.train_from_segment_counts(
            Counter(segments),
            vocab_size=vocab_size,
            min_pair_frequency=min_pair_frequency,
            lowercase=lowercase,
            word_prefix=word_prefix,
        )

    @classmethod
    def train_from_segment_counts(
        cls,
        segment_counts: Mapping[str, float],
        *,
        vocab_size: int = 256,
        min_pair_frequency: int = 2,
        lowercase: bool = False,
        word_prefix: str = "▁",
    ) -> "NativeTokenizer":
        if not segment_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        seed_tokenizer = cls(
            merges=[],
            vocab=[],
            base_symbols=[],
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
        word_counts = Counter(
            {
                str(segment): float(frequency)
                for segment, frequency in segment_counts.items()
                if str(segment) and float(frequency) > 0.0
            }
        )
        if not word_counts:
            raise ValueError("Cannot train the native tokenizer on empty segment counts.")
        observed_symbols = {
            f"{word_prefix}{character}" if index == 0 else character
            for segment in word_counts
            for index, character in enumerate(segment)
        }
        base_symbols = _default_symbol_inventory(word_prefix)
        base_symbols.update(observed_symbols)
        sequences = {
            segment: [
                f"{word_prefix}{character}" if index == 0 else character
                for index, character in enumerate(segment)
            ]
            for segment in word_counts
        }
        vocab = set(observed_symbols) | seed_tokenizer.special_tokens
        target_vocab_size = len(vocab) + max(1, vocab_size)
        # Reserve whole-segment tokens for frequent multi-character segments,
        # ranked by frequency times length.
        segment_candidates = sorted(
            {
                segment
                for segment, frequency in word_counts.items()
                if len(segment) > 1 and frequency >= min_pair_frequency
            },
            key=lambda segment: (
                -(word_counts[segment] * len(segment)),
                -len(segment),
                segment,
            ),
        )
        for segment in segment_candidates:
            if len(vocab) >= target_vocab_size:
                break
            vocab.add(_whole_segment_token(segment, word_prefix))
        # Standard pair-merge loop: repeatedly merge the most frequent adjacent pair
        # until the vocabulary budget or the merge cap is reached.
        merges: list[tuple[str, str]] = []
        while len(vocab) < target_vocab_size and len(merges) < MAX_TRAINED_PAIR_MERGES:
            pair_counts: Counter[tuple[str, str]] = Counter()
            for segment, frequency in word_counts.items():
                symbols = sequences[segment]
                for index in range(len(symbols) - 1):
                    pair_counts[(symbols[index], symbols[index + 1])] += frequency
            if not pair_counts:
                break
            best_pair, best_count = min(
                pair_counts.items(),
                key=lambda item: (-item[1], item[0][0], item[0][1]),
            )
            if best_count < min_pair_frequency:
                break
            merged_symbol = _merge_symbol(best_pair[0], best_pair[1], word_prefix)
            merges.append(best_pair)
            vocab.add(merged_symbol)
            for segment in sequences:
                sequences[segment] = _merge_sequence(sequences[segment], best_pair, merged_symbol)
        return cls(
            merges=merges,
            vocab=sorted(vocab),
            base_symbols=sorted(base_symbols),
            lowercase=lowercase,
            word_prefix=word_prefix,
        )
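

# A minimal usage sketch, not part of the original module: it trains a tiny tokenizer
# on made-up sample text and round-trips it through encode/decode and to_dict/from_dict.
# Because of the relative import at the top, run it inside the package
# (e.g. `python -m <package>.<this_module>`), not as a standalone script.
if __name__ == "__main__":
    sample_text = "Byte pair encoding merges the most frequent symbol pairs, even for emoji 🙂."
    recommended = clamp_vocab_size(recommend_vocab_size(sample_text))
    tokenizer = NativeTokenizer.train(sample_text, vocab_size=recommended)
    tokens = tokenizer.encode(sample_text, add_special_tokens=True)
    print("tokens:", tokens)
    print("decoded:", tokenizer.decode(tokens))
    # Serialization round-trip should preserve encoding behaviour exactly.
    restored = NativeTokenizer.from_dict(tokenizer.to_dict())
    assert restored.encode(sample_text) == tokenizer.encode(sample_text)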