| |
|
|
| from collections.abc import Sequence |
| import os |
|
|
| from transformers import LlamaTokenizerFast |
| from tokenizers import Tokenizer |
|
|
| from .llmjp4_harmony import HarmonyMessageParser, HarmonyMessage |
|
|
|
|
class Llmjp4Tokenizer(LlamaTokenizerFast):
    """``LlamaTokenizerFast`` subclass with Harmony-format helpers.

    On top of the parent tokenizer this class:

    * attaches ``_RESPONSE_SCHEMA`` to every instance as ``response_schema``;
    * decodes token ids in runs that end at each Harmony marker token;
    * offers ``parse_harmony_message`` to turn token ids into
      ``HarmonyMessage`` objects.
    """

    # Literal marker tokens that delimit the parts of a Harmony message.
    _HARMONY_TOKENS: set[str] = {
        "<|start|>",
        "<|message|>",
        "<|channel|>",
        "<|constrain|>",
        "<|end|>",
        "<|return|>",
        "<|call|>",
    }

    # Schema describing how an assistant response is laid out in decoded
    # Harmony text: the "final" channel feeds ``content``, the "analysis"
    # channel feeds ``thinking``, and "commentary" channel entries addressed
    # to ``functions.<name>`` feed ``tool_calls``.
    # NOTE(review): the ``x-regex`` / ``x-regex-iterator`` / ``x-parser`` keys
    # are presumably interpreted by the transformers response-parsing code --
    # confirm against the library version in use.
    _RESPONSE_SCHEMA = {
        "type": "object",
        "properties": {
            "role": {"const": "assistant"},
            "content": {
                "type": "string",
                "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|<\|return\|>|$)",
            },
            "thinking": {
                "type": "string",
                "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>",
            },
            "tool_calls": {
                "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)",
                "type": "array",
                "items": {
                    "type": "object",
                    "properties": {
                        "type": {"const": "function"},
                        "function": {
                            "type": "object",
                            "properties": {
                                "name": {
                                    "type": "string",
                                    "x-regex": r"^to=functions\.(\w+)",
                                },
                                "arguments": {
                                    "type": "object",
                                    "x-regex": r"<\|message\|>(.*)",
                                    "x-parser": "json",
                                    "additionalProperties": {"type": "any"},
                                },
                            },
                        },
                    },
                },
            },
        },
    }

    @classmethod
    def convert_to_native_format(cls, **kwargs):
        """Swap the ``tokenizer_file`` path for a loaded ``tokenizer_object``.

        The incoming keyword arguments are copied, so the caller's mapping is
        left untouched.

        Args:
            **kwargs: Keyword arguments; ``tokenizer_file`` must name an
                existing file.

        Returns:
            A new dict holding the remaining keyword arguments plus
            ``tokenizer_object``, the ``Tokenizer`` loaded from that file.

        Raises:
            ValueError: If ``tokenizer_file`` is missing, ``None``, or does
                not point at an existing file.
        """
        options = {**kwargs}
        tokenizer_path = options.pop("tokenizer_file", None)

        path_is_usable = tokenizer_path is not None and os.path.isfile(tokenizer_path)
        if not path_is_usable:
            raise ValueError("Tokenizer file must exist.")

        options["tokenizer_object"] = Tokenizer.from_file(tokenizer_path)
        return options

    def __init__(self, *args, **kwargs):
        """Initialise the parent tokenizer, then set up Harmony-specific state.

        All positional and keyword arguments are forwarded unchanged to the
        parent constructor.
        """
        super().__init__(*args, **kwargs)

        # Every instance refers to the same class-level schema object.
        self.response_schema = self._RESPONSE_SCHEMA

        # Resolve the marker strings to ids once so ``_decode`` can do cheap
        # set-membership tests.
        self._harmony_token_ids = set(
            map(self.convert_tokens_to_ids, self._HARMONY_TOKENS)
        )

    def _decode(self, token_ids: int | list[int], *args, **kwargs):
        """Decode ``token_ids`` in runs that each end at a Harmony marker.

        The ids are cut after every Harmony marker token (and at the end of
        the input); each run is decoded by the parent ``_decode`` on its own
        and the resulting strings are concatenated.
        NOTE(review): presumably this keeps the parent's decoding from
        operating across marker boundaries -- confirm the original intent.

        Args:
            token_ids: A single token id or a list of token ids.
            *args: Forwarded to the parent ``_decode`` for every run.
            **kwargs: Forwarded to the parent ``_decode`` for every run.

        Returns:
            The concatenation of the decoded runs; an empty string when
            ``token_ids`` is empty.
        """
        ids = [token_ids] if isinstance(token_ids, int) else token_ids

        pieces: list[str] = []
        run_start = 0

        for index, token_id in enumerate(ids):
            run_end = index + 1
            # Keep extending the current run until a marker or the last id.
            if token_id not in self._harmony_token_ids and run_end != len(ids):
                continue
            pieces.append(super()._decode(ids[run_start:run_end], *args, **kwargs))
            run_start = run_end

        return "".join(pieces)

    def parse_harmony_message(self, token_ids: Sequence[int]) -> list[HarmonyMessage]:
        """Parse ``token_ids`` into Harmony messages.

        Builds a ``HarmonyMessageParser`` bound to this tokenizer and returns
        whatever its ``get_all_messages`` yields for the given ids.
        """
        parser = HarmonyMessageParser(self)
        return parser.get_all_messages(token_ids)
|
|