# reframr/tokenizer.py
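"""Native tokenizer for Reframr-RFM-v1-Base.

Implements a BPE-style subword tokenizer with a segment-start prefix
("▁" by default), UTF-8 byte fallback for characters outside the base
symbol inventory, emoji-cluster-aware pretokenization, and reserved
reasoning control tokens. Whitespace is not preserved as tokens, so
decoding reconstructs spacing heuristically.
"""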
import re
import unicodedata
from collections import Counter
from collections.abc import Mapping
from dataclasses import dataclass, field
from string import ascii_letters, digits
from .reasoning import REASONING_CONTROL_TOKENS, TOKENIZER_NAME
PRETOKEN_PATTERN = re.compile(r"\w+|[^\w\s]", re.UNICODE)
BYTE_FALLBACK_PATTERN = re.compile(r"<byte:([0-9A-F]{2})>")
DEFAULT_FALLBACK_CHARACTERS = (
ascii_letters
+ digits
+ "'-_/.:,;!?()[]{}@#$%&*+="
+ "’ʼ‘“”—–…"
)
MAX_TOKENIZER_VOCAB_SIZE = 65536
MAX_SEGMENT_CACHE_SIZE = 200_000
MAX_TRAINED_PAIR_MERGES = 384
def _is_word_character(character: str) -> bool:
category = unicodedata.category(character)
return character == "_" or category[0] in {"L", "N"} or category == "Mn"
def _is_variation_selector(character: str) -> bool:
return "VARIATION SELECTOR" in unicodedata.name(character, "")
def _is_zero_width_joiner(character: str) -> bool:
return unicodedata.name(character, "") == "ZERO WIDTH JOINER"
def _is_emoji_modifier(character: str) -> bool:
return "EMOJI MODIFIER" in unicodedata.name(character, "")
def _is_emoji_base_character(character: str) -> bool:
name = unicodedata.name(character, "")
category = unicodedata.category(character)
return (
"EMOJI" in name
or "REGIONAL INDICATOR SYMBOL" in name
or (category in {"So", "Sk"} and ord(character) >= 0x2100)
)
def _is_emoji_continuation_character(character: str) -> bool:
category = unicodedata.category(character)
name = unicodedata.name(character, "")
return (
_is_variation_selector(character)
or _is_zero_width_joiner(character)
or _is_emoji_modifier(character)
or category in {"Mn", "Me"}
or name.startswith("TAG ")
)
def _consume_emoji_cluster(text: str, start: int) -> int:
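    """Return the end index of the emoji cluster starting at ``start``.

    Consumes regional-indicator pairs (flags) as a single cluster, and
    otherwise extends the cluster across variation selectors, skin-tone
    modifiers, combining marks, tag characters, and ZWJ-joined emoji.
    Returns ``start`` unchanged if no emoji base character begins there.
    """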
if start >= len(text) or not _is_emoji_base_character(text[start]):
return start
index = start + 1
if "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[start], ""):
if index < len(text) and "REGIONAL INDICATOR SYMBOL" in unicodedata.name(text[index], ""):
return index + 1
return index
while index < len(text):
if _is_emoji_continuation_character(text[index]):
index += 1
continue
if _is_zero_width_joiner(text[index - 1]) and _is_emoji_base_character(text[index]):
index += 1
continue
break
return index
def _byte_token(value: int) -> str:
return f"<byte:{value:02X}>"
def _byte_value(piece: str) -> int | None:
match = BYTE_FALLBACK_PATTERN.fullmatch(piece)
if match is None:
return None
return int(match.group(1), 16)
def _is_punctuation_piece(piece: str) -> bool:
return bool(piece) and all(
unicodedata.category(character).startswith("P")
for character in piece
)
def _is_opening_punctuation(piece: str) -> bool:
return bool(piece) and all(
unicodedata.category(character) in {"Ps", "Pi"}
for character in piece
)
def _is_call_opening_punctuation(piece: str) -> bool:
return bool(piece) and all(
unicodedata.category(character) == "Ps"
and "PARENTHESIS" in unicodedata.name(character, "")
for character in piece
)
def _is_closing_or_terminal_punctuation(piece: str) -> bool:
return bool(piece) and all(
unicodedata.category(character) in {"Pe", "Pf", "Po"}
for character in piece
)
def _is_infix_joiner(piece: str) -> bool:
if len(piece) != 1:
return False
category = unicodedata.category(piece)
name = unicodedata.name(piece, "")
return (
category == "Pd"
or "APOSTROPHE" in name
or (category == "Pf" and "SINGLE QUOTATION MARK" in name)
or "SOLIDUS" in name
)
def _is_dash_joiner(piece: str) -> bool:
if len(piece) != 1:
return False
category = unicodedata.category(piece)
name = unicodedata.name(piece, "")
return category == "Pd" or "HYPHEN" in name or "DASH" in name
def _is_quote_piece(piece: str) -> bool:
if len(piece) != 1:
return False
if _is_infix_joiner(piece):
return False
name = unicodedata.name(piece, "")
category = unicodedata.category(piece)
return "QUOTATION MARK" in name or category in {"Pi", "Pf"}
def _merge_symbol(left: str, right: str, prefix: str) -> str:
if right.startswith(prefix):
return left + right[len(prefix):]
return left + right
def _merge_sequence(symbols: list[str], pair: tuple[str, str], merged_symbol: str) -> list[str]:
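    """Replace every adjacent occurrence of ``pair`` in ``symbols`` with ``merged_symbol``."""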
merged: list[str] = []
index = 0
while index < len(symbols):
if index < len(symbols) - 1 and (symbols[index], symbols[index + 1]) == pair:
merged.append(merged_symbol)
index += 2
else:
merged.append(symbols[index])
index += 1
return merged
def _default_symbol_inventory(word_prefix: str) -> set[str]:
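    """Build the seed symbol set: common characters plus all 256 byte-fallback
    tokens, each in both bare and ``word_prefix``-marked form."""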
symbols: set[str] = set()
for character in DEFAULT_FALLBACK_CHARACTERS:
symbols.add(character)
symbols.add(f"{word_prefix}{character}")
for value in range(256):
token = _byte_token(value)
symbols.add(token)
symbols.add(f"{word_prefix}{token}")
return symbols
def _whole_segment_token(segment: str, word_prefix: str) -> str:
return f"{word_prefix}{segment}"
def recommend_vocab_size(
text: str,
*,
minimum: int = 768,
maximum: int = 1536,
multiplier: int = 5,
lowercase: bool = False,
) -> int:
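    """Suggest a vocabulary size from the number of distinct pretokenized
    segments in ``text``, scaled by ``multiplier`` and clamped to
    [``minimum``, ``maximum``]."""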
seed_tokenizer = NativeTokenizer(
merges=[],
vocab=[],
base_symbols=[],
lowercase=lowercase,
)
segments = seed_tokenizer.pretokenize(text)
distinct_segments = len(set(segments))
recommended = max(minimum, distinct_segments * multiplier)
return min(maximum, recommended)
def clamp_vocab_size(requested: int, *, maximum: int = MAX_TOKENIZER_VOCAB_SIZE) -> int:
return min(maximum, max(1, requested))
@dataclass(slots=True)
class NativeTokenizer:
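    """BPE-style tokenizer with byte fallback and segment-start marking.

    Segments produced by :meth:`pretokenize` are encoded greedily by applying
    trained merges in rank order; characters outside ``base_symbols`` fall back
    to ``<byte:XX>`` tokens. Illustrative usage (names and values are examples
    only)::

        tokenizer = NativeTokenizer.train(corpus_text, vocab_size=512)
        tokens = tokenizer.encode("Hello, world!", add_special_tokens=True)
        text = tokenizer.decode(tokens)
    """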
merges: list[tuple[str, str]]
vocab: list[str]
base_symbols: list[str]
name: str = TOKENIZER_NAME
lowercase: bool = False
word_prefix: str = "▁"
unk_token: str = "<unk>"
bos_token: str = "<bos>"
eos_token: str = "<eos>"
pad_token: str = "<pad>"
_merge_ranks: dict[tuple[str, str], int] = field(init=False, repr=False)
_vocab_set: set[str] = field(init=False, repr=False)
_base_symbol_set: set[str] = field(init=False, repr=False)
_pretoken_pattern: re.Pattern[str] = field(init=False, repr=False)
_segment_cache: dict[str, tuple[str, ...]] = field(init=False, repr=False)
def __post_init__(self) -> None:
self._merge_ranks = {pair: index for index, pair in enumerate(self.merges)}
self._base_symbol_set = set(self.base_symbols)
self._vocab_set = set(self.vocab) | self.special_tokens | self._base_symbol_set
self.vocab = sorted(self._vocab_set)
self._pretoken_pattern = self._build_pretoken_pattern()
self._segment_cache = {}
@property
def special_tokens(self) -> set[str]:
return {
self.unk_token,
self.bos_token,
self.eos_token,
self.pad_token,
*REASONING_CONTROL_TOKENS,
}
@property
def vocab_size(self) -> int:
return len(self._vocab_set)
def normalize(self, text: str) -> str:
normalized = unicodedata.normalize("NFKC", text)
return normalized.lower() if self.lowercase else normalized
def pretokenize(self, text: str) -> list[str]:
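        """Split normalized text into segments: special tokens, emoji clusters,
        word-character runs, and single non-word characters. Carriage returns,
        CRLF pairs, and newlines each become a newline segment; other whitespace
        is dropped."""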
normalized = self.normalize(text)
segments: list[str] = []
reserved = sorted(self.special_tokens, key=len, reverse=True)
index = 0
while index < len(normalized):
if normalized[index].isspace():
if normalized[index] == "\r":
if index + 1 < len(normalized) and normalized[index + 1] == "\n":
segments.append("\n")
index += 2
continue
segments.append("\n")
index += 1
continue
if normalized[index] == "\n":
segments.append("\n")
index += 1
continue
index += 1
continue
matched_special = next(
(
token
for token in reserved
if normalized.startswith(token, index)
),
None,
)
if matched_special is not None:
segments.append(matched_special)
index += len(matched_special)
continue
emoji_end = _consume_emoji_cluster(normalized, index)
if emoji_end > index:
segments.append(normalized[index:emoji_end])
index = emoji_end
continue
if _is_word_character(normalized[index]):
start = index
index += 1
while index < len(normalized) and _is_word_character(normalized[index]):
index += 1
segments.append(normalized[start:index])
continue
segments.append(normalized[index])
index += 1
return segments
def encode(self, text: str, *, add_special_tokens: bool = False) -> list[str]:
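        """Encode ``text`` into a list of token strings, optionally wrapped in
        BOS/EOS tokens. Non-whitespace text that produces no tokens falls back
        to a single ``unk_token``."""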
tokens: list[str] = []
if add_special_tokens:
tokens.append(self.bos_token)
for segment in self.pretokenize(text):
tokens.extend(self._encode_segment_cached(segment))
if add_special_tokens:
tokens.append(self.eos_token)
if not tokens and text.strip():
return [self.unk_token]
return tokens
def encode_many(
self,
texts: list[str] | tuple[str, ...],
*,
add_special_tokens: bool = False,
) -> list[list[str]]:
return [
self.encode(text, add_special_tokens=add_special_tokens)
for text in texts
]
def decode(self, tokens: list[str]) -> str:
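        """Reconstruct text from tokens, merging byte-fallback runs into UTF-8
        and reinserting spaces heuristically around punctuation, quotes,
        joiners, and segment boundaries. Special tokens are skipped."""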
text = ""
join_next = False
byte_buffer = bytearray()
byte_starts_segment = False
def next_rendered_piece(start_index: int) -> str | None:
for raw_token in tokens[start_index:]:
if raw_token in self.special_tokens:
continue
raw_starts_segment = raw_token.startswith(self.word_prefix)
raw_piece = raw_token[len(self.word_prefix) :] if raw_starts_segment else raw_token
if not raw_piece:
continue
if _byte_value(raw_piece) is not None:
return None
return raw_piece
return None
def append_piece(piece: str, starts_segment: bool, next_piece: str | None = None) -> None:
nonlocal text, join_next
if piece == "\n":
text = text.rstrip(" ")
text += "\n"
join_next = True
return
had_text_before_piece = bool(text.strip())
previous_before_piece = text.rstrip(" ")[-1:] if text.strip(" ") else ""
if _is_quote_piece(piece):
quote_count = sum(1 for character in text if _is_quote_piece(character))
opens_quote = quote_count % 2 == 0
if opens_quote:
if text and not text.endswith((" ", "\n")) and previous_before_piece not in {"(", "[", "{"}:
text += " "
text += piece
join_next = True
return
text = text.rstrip(" ")
text += piece
join_next = False
return
attaches_left = _is_closing_or_terminal_punctuation(piece) or _is_infix_joiner(piece)
continues_segment = (not starts_segment) and any(
_is_word_character(character) or _is_emoji_continuation_character(character)
for character in piece
)
if starts_segment:
if text and not join_next:
attaches_to_previous_code_span = (
_is_opening_punctuation(piece)
and previous_before_piece.isalnum()
and next_piece is not None
and (
_is_infix_joiner(next_piece)
or _is_call_opening_punctuation(piece)
)
)
if not _is_punctuation_piece(piece) or (
_is_opening_punctuation(piece)
and not attaches_to_previous_code_span
):
text += " "
text += piece
else:
if text and not join_next and not attaches_left and not continues_segment:
text += " "
text += piece
join_next = (
_is_infix_joiner(piece)
and (
not starts_segment
or (
had_text_before_piece
and (
not _is_dash_joiner(piece)
or previous_before_piece.isalnum()
or _is_opening_punctuation(previous_before_piece)
)
)
)
) or _is_opening_punctuation(piece)
def flush_bytes() -> None:
nonlocal byte_buffer, byte_starts_segment
if not byte_buffer:
return
append_piece(bytes(byte_buffer).decode("utf-8", errors="replace"), byte_starts_segment)
byte_buffer = bytearray()
byte_starts_segment = False
for token_index, token in enumerate(tokens):
if token in self.special_tokens:
continue
starts_segment = token.startswith(self.word_prefix)
piece = token[len(self.word_prefix) :] if starts_segment else token
if not piece:
continue
byte_value = _byte_value(piece)
if byte_value is not None:
if not byte_buffer:
byte_starts_segment = starts_segment
byte_buffer.append(byte_value)
continue
flush_bytes()
append_piece(piece, starts_segment, next_rendered_piece(token_index + 1))
flush_bytes()
return text.strip()
def _encode_segment_cached(self, segment: str) -> tuple[str, ...]:
cached = self._segment_cache.get(segment)
if cached is not None:
return cached
encoded = tuple(self._encode_segment(segment))
if len(self._segment_cache) < MAX_SEGMENT_CACHE_SIZE:
self._segment_cache[segment] = encoded
return encoded
def _encode_segment(self, segment: str) -> list[str]:
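        """Encode a single segment: return it directly if it is a special token
        or a whole-segment vocab entry, otherwise seed per-character symbols and
        greedily apply the lowest-ranked merge until no trained pair remains."""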
if segment in self.special_tokens:
return [segment]
whole_segment = _whole_segment_token(segment, self.word_prefix)
if whole_segment in self._vocab_set:
return [whole_segment]
symbols = self._seed_symbols(segment)
if not symbols:
return []
while len(symbols) > 1:
best_rank: int | None = None
best_pair: tuple[str, str] | None = None
for index in range(len(symbols) - 1):
pair = (symbols[index], symbols[index + 1])
rank = self._merge_ranks.get(pair)
if rank is None:
continue
if best_rank is None or rank < best_rank:
best_rank = rank
best_pair = pair
if best_pair is None:
break
merged_symbol = _merge_symbol(best_pair[0], best_pair[1], self.word_prefix)
symbols = _merge_sequence(symbols, best_pair, merged_symbol)
if any(symbol not in self._vocab_set for symbol in symbols):
return [self.unk_token]
return symbols
def _seed_symbols(self, segment: str) -> list[str]:
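        """Seed a segment as base symbols, prefixing the first symbol with
        ``word_prefix``; characters missing from the base inventory expand into
        UTF-8 byte-fallback tokens. Falls back to ``unk_token`` if any resulting
        symbol is still unknown."""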
symbols: list[str] = []
for index, character in enumerate(segment):
symbol = f"{self.word_prefix}{character}" if index == 0 else character
if symbol in self._base_symbol_set:
symbols.append(symbol)
continue
encoded = character.encode("utf-8")
for byte_index, value in enumerate(encoded):
token = _byte_token(value)
if index == 0 and byte_index == 0:
token = f"{self.word_prefix}{token}"
symbols.append(token)
if any(symbol not in self._base_symbol_set for symbol in symbols):
return [self.unk_token]
return symbols
def to_dict(self) -> dict[str, object]:
return {
"name": self.name,
"merges": [[left, right] for left, right in self.merges],
"vocab": self.vocab,
"base_symbols": self.base_symbols,
"lowercase": self.lowercase,
"word_prefix": self.word_prefix,
"unk_token": self.unk_token,
"bos_token": self.bos_token,
"eos_token": self.eos_token,
"pad_token": self.pad_token,
}
@classmethod
def from_dict(cls, payload: dict[str, object]) -> "NativeTokenizer":
return cls(
merges=[(str(left), str(right)) for left, right in payload["merges"]],
vocab=[str(token) for token in payload["vocab"]],
base_symbols=[str(token) for token in payload["base_symbols"]],
name=str(payload.get("name", TOKENIZER_NAME)),
lowercase=bool(payload["lowercase"]),
word_prefix=str(payload["word_prefix"]),
unk_token=str(payload["unk_token"]),
bos_token=str(payload["bos_token"]),
eos_token=str(payload["eos_token"]),
pad_token=str(payload["pad_token"]),
)
def _build_pretoken_pattern(self) -> re.Pattern[str]:
reserved = sorted(self.special_tokens, key=len, reverse=True)
if not reserved:
return PRETOKEN_PATTERN
reserved_pattern = "|".join(re.escape(token) for token in reserved)
return re.compile(f"{reserved_pattern}|\\w+|[^\\w\\s]", re.UNICODE)
@classmethod
def train(
cls,
text: str,
*,
vocab_size: int = 256,
min_pair_frequency: int = 2,
lowercase: bool = False,
word_prefix: str = "▁",
) -> "NativeTokenizer":
seed_tokenizer = cls(
merges=[],
vocab=[],
base_symbols=[],
lowercase=lowercase,
word_prefix=word_prefix,
)
segments = seed_tokenizer.pretokenize(text)
if not segments:
raise ValueError("Cannot train the native tokenizer on empty text.")
return cls.train_from_segment_counts(
Counter(segments),
vocab_size=vocab_size,
min_pair_frequency=min_pair_frequency,
lowercase=lowercase,
word_prefix=word_prefix,
)
@classmethod
def train_from_segment_counts(
cls,
segment_counts: Mapping[str, float],
*,
vocab_size: int = 256,
min_pair_frequency: int = 2,
lowercase: bool = False,
word_prefix: str = "▁",
) -> "NativeTokenizer":
if not segment_counts:
raise ValueError("Cannot train the native tokenizer on empty segment counts.")
seed_tokenizer = cls(
merges=[],
vocab=[],
base_symbols=[],
lowercase=lowercase,
word_prefix=word_prefix,
)
word_counts = Counter(
{
str(segment): float(frequency)
for segment, frequency in segment_counts.items()
if str(segment) and float(frequency) > 0.0
}
)
if not word_counts:
raise ValueError("Cannot train the native tokenizer on empty segment counts.")
observed_symbols = {
f"{word_prefix}{character}" if index == 0 else character
for segment in word_counts
for index, character in enumerate(segment)
}
base_symbols = _default_symbol_inventory(word_prefix)
base_symbols.update(observed_symbols)
sequences = {
segment: [
f"{word_prefix}{character}" if index == 0 else character
for index, character in enumerate(segment)
]
for segment in word_counts
}
vocab = set(observed_symbols) | seed_tokenizer.special_tokens
target_vocab_size = len(vocab) + max(1, vocab_size)
segment_candidates = sorted(
{
segment
for segment, frequency in word_counts.items()
if len(segment) > 1 and frequency >= min_pair_frequency
},
key=lambda segment: (
-(word_counts[segment] * len(segment)),
-len(segment),
segment,
),
)
for segment in segment_candidates:
if len(vocab) >= target_vocab_size:
break
vocab.add(_whole_segment_token(segment, word_prefix))
merges: list[tuple[str, str]] = []
while len(vocab) < target_vocab_size and len(merges) < MAX_TRAINED_PAIR_MERGES:
pair_counts: Counter[tuple[str, str]] = Counter()
for segment, frequency in word_counts.items():
symbols = sequences[segment]
for index in range(len(symbols) - 1):
pair_counts[(symbols[index], symbols[index + 1])] += frequency
if not pair_counts:
break
best_pair, best_count = min(
pair_counts.items(),
key=lambda item: (-item[1], item[0][0], item[0][1]),
)
if best_count < min_pair_frequency:
break
merged_symbol = _merge_symbol(best_pair[0], best_pair[1], word_prefix)
merges.append(best_pair)
vocab.add(merged_symbol)
for segment in sequences:
sequences[segment] = _merge_sequence(sequences[segment], best_pair, merged_symbol)
return cls(
merges=merges,
vocab=sorted(vocab),
base_symbols=sorted(base_symbols),
lowercase=lowercase,
word_prefix=word_prefix,
)