# llm-jp-4-8b-thinking / llmjp4_tokenizer.py
# llm-jp-4 tokenizer
import os
from collections.abc import Sequence

from tokenizers import Tokenizer
from transformers import LlamaTokenizerFast

from .llmjp4_harmony import HarmonyMessage, HarmonyMessageParser


class Llmjp4Tokenizer(LlamaTokenizerFast):
    """LlamaTokenizerFast subclass that is aware of Harmony control tokens."""

    _HARMONY_TOKENS: set[str] = {
        "<|start|>",
        "<|message|>",
        "<|channel|>",
        "<|constrain|>",
        "<|end|>",
        "<|return|>",
        "<|call|>",
    }
    # NOTE(odashi):
    # Response schemas are not recognized automatically.
    # We need to define them manually.
    # https://github.com/huggingface/trl/issues/4609
    _RESPONSE_SCHEMA = {
        "type": "object",
        "properties": {
            "role": {"const": "assistant"},
            "content": {
                "type": "string",
                "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)",
            },
            "thinking": {
                "type": "string",
                "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>",
            },
            "tool_calls": {
                "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "type": {"const": "function"},
                        "function": {
                            "type": "object",
                            "properties": {
                                "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"},
                                "arguments": {
                                    "type": "object",
                                    "x-regex": r"<\|message\|>(.*)",
                                    "x-parser": "json",
                                    "additionalProperties": {"type": "any"},
                                },
                            },
                        },
                    },
                },
            },
        },
    }
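
    # Illustrative raw assistant output that the regexes above are written to
    # match (a hypothetical, hand-written sample; not actual model output):
    #
    #   <|channel|>analysis<|message|>Thinking about the request...<|end|>
    #   <|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"city": "Tokyo"}<|call|>
    #   <|start|>assistant<|channel|>final<|message|>It is sunny in Tokyo.<|return|>
    #
    # The schema extracts "thinking" from the analysis channel, "content" from
    # the final channel, and one tool call named "get_weather" with JSON
    # arguments from the commentary channel.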

    @classmethod
    def convert_to_native_format(cls, **kwargs):
        # NOTE(odashi):
        # Workaround for transformers 5.x.
        # Guarantees the same inner behavior as TokenizersBackend.
        # https://github.com/huggingface/transformers/blob/7d9754a05193eb79b1d86aa744b622b8068008cd/src/transformers/tokenization_utils_tokenizers.py#L110-L116
        local_kwargs = dict(kwargs)
        fast_tokenizer_file = local_kwargs.pop("tokenizer_file", None)
        if fast_tokenizer_file is None or not os.path.isfile(fast_tokenizer_file):
            raise ValueError("Tokenizer file must exist.")
        local_kwargs["tokenizer_object"] = Tokenizer.from_file(fast_tokenizer_file)
        return local_kwargs

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.response_schema = self._RESPONSE_SCHEMA
        self._harmony_token_ids = {
            self.convert_tokens_to_ids(token)
            for token in self._HARMONY_TOKENS
        }

    def _decode(self, token_ids: int | list[int], *args, **kwargs) -> str:
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        result: list[str] = []
        prev_pos = 0
        # NOTE(odashi):
        # Ensure that text tokens are decoded without a preceding Harmony token
        # to avoid incorrect insertion of whitespace.
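        # For example, a hypothetical sequence [<|start|>, "user", <|message|>,
        # "Hi", <|end|>] is decoded in three chunks: [<|start|>],
        # ["user", <|message|>], ["Hi", <|end|>], so every text run begins its
        # own decode call.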
        for pos, token_id in enumerate(token_ids, start=1):
            if token_id in self._harmony_token_ids or pos == len(token_ids):
                result.append(super()._decode(token_ids[prev_pos:pos], *args, **kwargs))
                prev_pos = pos
        return "".join(result)

    def parse_harmony_message(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
        """Helper function to parse token IDs into Harmony messages."""
        return HarmonyMessageParser(self).get_all_messages(token_ids)
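

# Minimal usage sketch. Assumptions: "path/to/llm-jp-4-8b-thinking" is a
# placeholder for a local checkpoint directory containing this tokenizer's
# files, and the sample text is hand-written Harmony output, not real model
# output.
if __name__ == "__main__":
    tokenizer = Llmjp4Tokenizer.from_pretrained("path/to/llm-jp-4-8b-thinking")
    sample = "<|start|>assistant<|channel|>final<|message|>Hello.<|return|>"
    token_ids = tokenizer(sample, add_special_tokens=False)["input_ids"]
    # Harmony-aware decoding: each text run is decoded separately from the
    # surrounding Harmony control tokens.
    print(tokenizer.decode(token_ids))
    # Structured parsing into HarmonyMessage objects.
    for message in tokenizer.parse_harmony_message(token_ids):
        print(message)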