# Source: reka-edge-2603 / tokenization_yasa2.py
# Uploaded by donovanOng92 (commit message "upload", revision 7d24555, verified).
"""Tokenization and chat template helpers for Yasa2."""
import json
import logging
import os
from typing import Any, Callable, Optional
import tiktoken
from transformers import PreTrainedTokenizer
from transformers.utils.chat_template_utils import render_jinja_template
# Module-level logger; used below to warn about conflicting chat-template flags.
logger = logging.getLogger(__name__)

# Signature of the media placeholder builders handed to build_chat_prompt:
# they receive one content-item dict and return the list of placeholder
# strings to splice into the prompt.
TokenBuilder = Callable[[dict[str, Any]], list[str]]

# Name of a tiktoken encoding. Not referenced by the code visible in this
# file (Yasa2Tokenizer requires 'tiktoken_model_name' to be passed explicitly).
DEFAULT_TIKTOKEN_MODEL_NAME = "cl100k_base"

# Default special-token -> token-ID mapping used by Yasa2Tokenizer when no
# explicit mapping (dict or JSON path) is supplied. The tokenizer rejects any
# ID here that collides with an ID already used by the base encoding.
TIKTOKEN_SPECIAL_TOKENS = {
    "<|endofchunk|>": 100277,
    # Image placeholders.
    "<REKA_IMG_TOKEN>": 100278,
    "<image>": 100279,
    "</image>": 100280,
    # Audio placeholders.
    "<REKA_ADO_TOKEN>": 100281,
    "<audio>": 100282,
    "</audio>": 100283,
    # Video placeholders.
    "<video>": 100284,
    "</video>": 100285,
    "<transcript>": 100286,
    "</transcript>": 100287,
    "<ocr>": 100288,
    "</ocr>": 100289,
    # Turn separator emitted by build_chat_prompt.
    "<sep>": 100290,
    # Tool-calling markers.
    "<tool>": 100291,
    "<tool_call>": 100292,
    "</tool_call>": 100293,
    "<tool_response>": 100294,
    "</tool_response>": 100295,
}
def normalize_message_content(content: Any) -> list[dict[str, Any]]:
    """Coerce message content into the list-of-items representation.

    Args:
        content: Either a plain string or an already-built list of content
            items.

    Returns:
        The same list object when a list is given; otherwise a one-element
        list wrapping the string as a ``{"type": "text"}`` item.

    Raises:
        ValueError: If ``content`` is neither a string nor a list.
    """
    # Lists are passed through untouched (no copy, no per-item validation).
    if isinstance(content, list):
        return content
    if not isinstance(content, str):
        raise ValueError(
            "Message content must be a string or list of content items."
        )
    text_item = {"type": "text", "text": content}
    return [text_item]
def _build_tools_block(tools: list[dict[str, Any]]) -> str:
"""Build the tool instructions block for system prompts.
Args:
tools: List of tool schema dictionaries.
Returns:
Rendered tools block string.
"""
tools_block = (
"# Tools\n\n"
"You may call one or more functions to assist with the user query.\n\n"
"You are provided with function signatures within <tools></tools> XML tags:\n"
"<tools>"
)
for tool in tools:
tools_block += f"\n{json.dumps(tool)}"
tools_block += (
"\n</tools>\n\n"
"For each function call, return a json object with function name and arguments "
"within <tool_call></tool_call> XML tags:\n"
"<tool_call>\n"
'{"name": <function-name>, "arguments": <args-json-object>}\n'
"</tool_call>"
)
return tools_block
def _render_chat_template(template: str, context: dict[str, Any]) -> str:
    """Render a Jinja chat template for a single conversation.

    Args:
        template: Jinja template source.
        context: Rendering context. The keys ``messages``, ``tools``,
            ``documents``, ``continue_final_message`` and
            ``add_generation_prompt`` are mapped onto the matching
            ``render_jinja_template`` parameters; every other key is
            forwarded as an extra keyword argument.

    Returns:
        The first rendered prompt (only one conversation is submitted).
    """
    explicit_keys = (
        "messages",
        "tools",
        "documents",
        "continue_final_message",
        "add_generation_prompt",
    )
    # Anything not consumed by a named parameter is handed through as-is.
    passthrough = {
        name: val for name, val in context.items() if name not in explicit_keys
    }
    conversation = context.get("messages", [])
    prompts, _unused = render_jinja_template(
        conversations=[conversation],
        tools=context.get("tools"),
        documents=context.get("documents"),
        chat_template=template,
        return_assistant_tokens_mask=False,
        continue_final_message=context.get("continue_final_message", False),
        add_generation_prompt=context.get("add_generation_prompt", False),
        **passthrough,
    )
    return prompts[0]
def build_chat_prompt(
    messages: list[dict[str, Any]],
    *,
    add_generation_prompt: bool,
    continue_final_message: bool,
    tools: Optional[list[dict[str, Any]]],
    image_token_builder: TokenBuilder,
    video_token_builder: TokenBuilder,
    enable_thinking: Optional[bool] = None,
) -> str:
    """Build the Yasa-style chat prompt with Yasa2 spacing rules.

    This helper is shared by the tokenizer and processor so remote code only
    needs the tokenizer module to format prompts.

    Layout produced (turns are terminated by ``<sep>``):
      * optional ``system: ...`` preamble (system text and/or tools block),
      * ``human: ...`` for user turns and for runs of tool responses,
      * ``assistant: ...`` for assistant turns, optionally with a
        ``<think>`` block and ``<tool_call>`` blocks.

    Args:
        messages: HF-style list of chat messages. ``None`` is treated as an
            empty list; other non-list inputs are converted via ``tolist()``
            when available, else ``list()``.
        add_generation_prompt: Whether to append the assistant prefix at the
            end. Skipped when the last message is already an assistant turn.
        continue_final_message: Whether to keep the final assistant turn open
            (no trailing blank line and ``<sep>``).
        tools: Optional tool schema list for system prompt injection.
        image_token_builder: Builder that returns image placeholder tokens.
        video_token_builder: Builder that returns video placeholder tokens.
        enable_thinking: When True, emit ``<think>`` blocks for assistant
            turns after the latest user query that either carry reasoning
            content or are the final message, and open a ``<think>`` block in
            the generation prompt.

    Returns:
        Fully formatted chat prompt string.

    Raises:
        ValueError: If a message is malformed, a role or content type is
            unsupported, or a system/developer message is not first.
    """
    if messages is None:
        messages = []
    elif not isinstance(messages, list):
        messages = (
            messages.tolist()
            if hasattr(messages, "tolist")
            else list(messages)
        )

    # Validate message format. Assistant messages may omit 'content'
    # (e.g. when they only carry tool calls).
    for idx, message in enumerate(messages):
        if not isinstance(message, dict):
            raise ValueError(f"Message at index {idx} must be a dict.")
        if "role" not in message:
            raise ValueError(f"Message at index {idx} is missing 'role'.")
        if "content" not in message and message.get("role") != "assistant":
            raise ValueError(f"Message at index {idx} is missing 'content'.")

    parts: list[str] = []

    def append_text(
        text: str, prev_was_text: bool, prev_was_media: bool
    ) -> None:
        """Append text, separating it from a preceding text item by a space."""
        if not text:
            return
        if prev_was_text:
            parts.append(" ")
        elif prev_was_media:
            # Keep text adjacent to media markers to match Yasa2 formatting.
            pass
        parts.append(text)

    def extract_reasoning_and_content(
        content: Any,
        reasoning_value: Any,
    ) -> tuple[str, str]:
        """Split assistant content into (visible text, reasoning text)."""
        content_text = ""
        reasoning_text = ""
        if isinstance(content, list):
            text_items = []
            for item in content:
                if item.get("type") != "text" and "text" not in item:
                    raise ValueError(
                        "Assistant message content must be text-only."
                    )
                text = item.get("text", "")
                if text:
                    text_items.append(text)
            content_text = " ".join(text_items)
        elif content is None:
            content_text = ""
        elif isinstance(content, str):
            content_text = content
        else:
            raise ValueError(
                "Assistant message content must be a string or list of content items."
            )
        if isinstance(reasoning_value, str):
            # An explicit reasoning field wins; content is left untouched.
            reasoning_text = reasoning_value
        elif "</think>" in content_text:
            # Otherwise recover inline reasoning from '<think>...</think>'.
            before, after = content_text.split("</think>", 1)
            if "<think>" in before:
                reasoning_text = before.split("<think>")[-1]
            else:
                reasoning_text = before
            reasoning_text = reasoning_text.rstrip("\n").lstrip("\n")
            content_text = after.lstrip("\n")
        return content_text, reasoning_text

    def find_last_query_index(messages: list[dict[str, Any]]) -> int:
        """Return the index of the last genuine user query.

        User messages whose string content is wrapped in
        '<tool_response>...</tool_response>' are not counted. Falls back to
        the last index when no such user message exists.
        """
        last_query_index = len(messages) - 1
        for idx in range(len(messages) - 1, -1, -1):
            message = messages[idx]
            if message.get("role") != "user":
                continue
            content = message.get("content", "")
            if (
                isinstance(content, str)
                and content.startswith("<tool_response>")
                and content.endswith("</tool_response>")
            ):
                continue
            last_query_index = idx
            break
        return last_query_index

    last_query_index = find_last_query_index(messages)

    # 1) Optional system/tools preamble.
    start_idx = 0
    system_text = ""
    if len(messages) > 0 and messages[0].get("role") in (
        "system",
        "developer",
    ):
        system_text = messages[0].get("content", "")
        if not isinstance(system_text, str):
            raise ValueError("System message content must be a string.")
        start_idx = 1
    if tools or system_text:
        # System/tools appear as a single preamble turn.
        tools_block = _build_tools_block(tools) if tools else None
        if tools_block:
            if system_text:
                parts.append(f"system: {system_text}\n\n{tools_block}")
            else:
                parts.append(f"system: {tools_block}")
        elif system_text:
            parts.append(f"system: {system_text}")
        parts.append("\n\n")
        parts.append("<sep>")

    # 2) Render each role turn into the Yasa2 chat template.
    for idx in range(start_idx, len(messages)):
        message = messages[idx]
        role = message.get("role")
        if role == "user":
            # Users can interleave text and media.
            content = message.get("content")
            content_items = normalize_message_content(content)
            # The prefix has no trailing space when the first non-empty item
            # is media, so the media marker directly follows 'human:'.
            prefix = "human: "
            if isinstance(content, list):
                for item in content:
                    if not isinstance(item, dict):
                        continue
                    if item.get("type") == "text" or "text" in item:
                        text = item.get("text", "")
                        if text:
                            break
                        continue
                    if item.get("type") in [
                        "image",
                        "image_url",
                        "video",
                        "video_url",
                    ]:
                        prefix = "human:"
                        break
            parts.append(prefix)
            prev_was_text = False
            prev_was_media = False
            for item in content_items:
                item_type = item.get("type")
                if item_type == "text" or "text" in item:
                    text = item.get("text", "")
                    append_text(
                        text,
                        prev_was_text=prev_was_text,
                        prev_was_media=prev_was_media,
                    )
                    prev_was_text = bool(text)
                    prev_was_media = False
                elif (
                    item_type in ["image", "image_url"]
                    or "image" in item
                    or "image_url" in item
                ):
                    parts.extend(image_token_builder(item))
                    prev_was_text = False
                    prev_was_media = True
                elif item_type in ["video", "video_url"] or "video" in item:
                    parts.extend(video_token_builder(item))
                    prev_was_text = False
                    prev_was_media = True
                else:
                    raise ValueError(
                        f"Unsupported content type: {item_type}. "
                        "Only 'text', 'image', 'image_url', "
                        "'video', and 'video_url' are supported."
                    )
            parts.append("<sep>")
            # Generation prompt is handled after the loop to avoid duplicate prefixes.
        elif role == "assistant":
            # Assistant is text-only, with optional tool calls and reasoning.
            tool_calls = message.get("tool_calls")
            if (
                tool_calls is not None
                and hasattr(tool_calls, "tolist")
                and not isinstance(tool_calls, (str, bytes, dict))
            ):
                tool_calls = tool_calls.tolist()
            content = message.get("content")
            if content is None and tool_calls:
                content = ""
            content_text, reasoning_text = extract_reasoning_and_content(
                content, message.get("reasoning_content")
            )
            # content_text is always a str here, so this yields one text item.
            content_items = normalize_message_content(content_text)
            if any(item.get("type") != "text" for item in content_items):
                raise ValueError(
                    "Assistant message content must be text-only."
                )
            parts.append("assistant: ")
            # Reasoning is only replayed for turns after the latest user
            # query; earlier turns drop their <think> content.
            include_thinking = (
                enable_thinking is True
                and idx > last_query_index
                and (idx == len(messages) - 1 or reasoning_text)
            )
            if include_thinking:
                parts.append("<think>\n")
                parts.append(reasoning_text.strip("\n"))
                parts.append("\n</think>\n\n")
            assistant_has_text = False
            for item in content_items:
                text = item.get("text", "")
                if text:
                    if assistant_has_text:
                        parts.append(" ")
                    parts.append(text)
                    assistant_has_text = True
            if tool_calls:
                tool_call_texts = []
                for tool_call in tool_calls:
                    if not isinstance(tool_call, dict):
                        raise ValueError(
                            "Tool call entries must be JSON objects."
                        )
                    # Accept both the flat form and the OpenAI-style
                    # {"function": {...}} wrapper.
                    if tool_call.get("function"):
                        if not isinstance(tool_call["function"], dict):
                            raise ValueError(
                                "Tool call 'function' must be a JSON object."
                            )
                        tool_call = tool_call["function"]
                    arguments = tool_call.get("arguments", {})
                    if isinstance(arguments, str):
                        # Pre-serialized arguments are embedded verbatim.
                        arguments_json = arguments
                    elif isinstance(arguments, dict):
                        arguments_json = json.dumps(arguments)
                    else:
                        tool_name = tool_call.get("name", "unknown")
                        raise ValueError(
                            "Tool call arguments must be a JSON object or string; "
                            f"got {type(arguments).__name__} for tool '{tool_name}'."
                        )
                    tool_call_texts.append(
                        "<tool_call>\n"
                        f'{{"name": "{tool_call.get("name", "")}", "arguments": {arguments_json}}}\n'
                        "</tool_call>"
                    )
                # Put the first tool call on its own line after visible text.
                if (
                    assistant_has_text
                    and parts
                    and not parts[-1].endswith("\n")
                ):
                    parts.append("\n")
                parts.extend(tool_call_texts)
            # An open final turn must not be terminated.
            if not (continue_final_message and idx == len(messages) - 1):
                parts.append("\n\n")
                parts.append("<sep>")
        elif role == "tool":
            # Tool responses are injected as human turns. Consecutive tool
            # messages share one 'human: ' prefix and one closing '<sep>'.
            content_items = normalize_message_content(message.get("content"))
            if idx == start_idx or messages[idx - 1].get("role") != "tool":
                parts.append("human: ")
            response_parts = []
            for item in content_items:
                item_type = item.get("type")
                if item_type != "text":
                    raise ValueError(
                        "Unsupported content type: "
                        f"{item_type}. Only text tool responses are supported."
                    )
                text = item.get("text", "")
                if text:
                    response_parts.append(text)
            response_text = " ".join(response_parts)
            append_text(
                f"<tool_response>\n{response_text}\n</tool_response>",
                prev_was_text=False,
                prev_was_media=False,
            )
            if (
                idx == len(messages) - 1
                or messages[idx + 1].get("role") != "tool"
            ):
                parts.append("<sep>")
        elif role in ("system", "developer"):
            raise ValueError("System message must be the first message.")
        else:
            raise ValueError(
                f"Unsupported message role: {role}. "
                "Only 'system', 'developer', 'user', 'assistant', and 'tool' roles are supported."
            )

    # 3) Optionally append assistant prefix for generation.
    if add_generation_prompt and (
        not messages or messages[-1].get("role") != "assistant"
    ):
        if enable_thinking is True:
            parts.append("assistant: <think>\n")
        else:
            parts.append("assistant:")
    return "".join(parts)
class Yasa2Tokenizer(PreTrainedTokenizer):
    """tiktoken-backed tokenizer for Yasa2 with a built-in prompt formatter.

    Tokens are represented as raw ``bytes`` (as returned by tiktoken) rather
    than strings. The custom special-token mapping is persisted through
    ``tiktoken_special_tokens.json``.
    """

    vocab_files_names = {
        "tiktoken_special_tokens": "tiktoken_special_tokens.json"
    }
    pretrained_vocab_files_map: dict[str, str] = {}
    model_input_names = ["input_ids", "attention_mask"]

    def __init__(self, tiktoken_special_tokens=None, **kwargs):
        """Initialize the tiktoken-backed tokenizer.

        Args:
            tiktoken_special_tokens: Optional mapping of special token to ID,
                or a path to a JSON file holding such a mapping. Falls back
                to ``TIKTOKEN_SPECIAL_TOKENS`` when omitted.
            **kwargs: Additional tokenizer config (requires
                ``tiktoken_model_name``).

        Raises:
            ValueError: If ``tiktoken_model_name`` is missing or a custom
                special-token ID collides with an ID of the base encoding.
        """
        # Populates tiktoken.registry.ENCODING_CONSTRUCTORS (private API).
        tiktoken.registry._find_constructors()
        tiktoken_model_name = kwargs.get("tiktoken_model_name")
        if not tiktoken_model_name:
            raise ValueError(
                "'tiktoken_model_name' is required to initialize Yasa2Tokenizer."
            )
        kwargs.setdefault("model_max_length", 8192)
        base_kwargs = tiktoken.registry.ENCODING_CONSTRUCTORS[
            tiktoken_model_name
        ]()

        if isinstance(tiktoken_special_tokens, str):
            # Read with the same encoding save_vocabulary() writes.
            with open(tiktoken_special_tokens, encoding="utf-8") as f:
                special_tokens = json.load(f)
        elif isinstance(tiktoken_special_tokens, dict):
            special_tokens = tiktoken_special_tokens
        else:
            special_tokens = TIKTOKEN_SPECIAL_TOKENS
        # Work on a copy so neither the caller's dict nor the module-level
        # default is mutated below.
        special_tokens = dict(special_tokens)

        used_token_ids = set(base_kwargs["special_tokens"].values()).union(
            set(base_kwargs["mergeable_ranks"].values())
        )
        collision = used_token_ids.intersection(special_tokens.values())
        if collision:
            raise ValueError(
                f"special token overlapping with tiktoken builtin {collision}"
            )

        # For saving later
        self.tiktoken_special_tokens = dict(special_tokens)

        # Add extra special tokens remaining in the reserved range.
        # NOTE(review): only the custom mapping is consulted here, so an ID
        # in this range that the base encoding already uses for one of its
        # own special tokens would also receive a placeholder name --
        # confirm this aliasing is intended.
        assigned_ids = set(special_tokens.values())
        for i in range(100256, 100352):
            if i not in assigned_ids:
                special_tokens[f"<|special_token_{i}|>"] = i

        # Update the base kwargs with the special tokens
        base_kwargs["special_tokens"].update(special_tokens)
        self.tiktoken = tiktoken.Encoding(**base_kwargs)

        # The prefix-space option is always forced off for this tokenizer.
        kwargs.pop("add_prefix_space", None)
        super().__init__(add_prefix_space=False, **kwargs)
        self.pad_token = "<|endoftext|>"
        self.eos_token = "<|endoftext|>"
        self.utf8_decoding_strategy = kwargs.get(
            "utf8_decoding_strategy", "replace"
        )
        self.allowed_special_tokens = set(special_tokens.keys())
        self.clean_up_tokenization_spaces = False

    @property
    def max_token_id(self) -> int:
        """Get the maximum token ID in the vocabulary.

        Returns:
            Maximum token ID value.
        """
        return self.tiktoken.max_token_value

    def apply_chat_template(
        self,
        messages: Optional[list[dict[str, Any]]] = None,
        chat_template: Optional[str] = None,
        tokenize: bool = False,
        add_generation_prompt: bool = False,
        continue_final_message: bool = False,
        return_tensors: Optional[str] = None,
        return_dict: bool = False,
        tools: Optional[list[dict[str, Any]]] = None,
        num_img_tokens: Optional[int] = None,
        num_video_frames: Optional[int] = None,
        enable_thinking: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        """Apply the chat template to the messages.

        Args:
            messages: list of messages in the conversation. When None, the
                ``conversation`` keyword is used if present, else ``[]``.
            chat_template: Optional Jinja template source or path to a
                template file. If None, ``self.chat_template`` is used when
                set; otherwise the built-in Yasa2 formatter is used.
            tokenize: Whether to tokenize the formatted prompt.
            add_generation_prompt: Whether to add the generation prompt.
            continue_final_message: Whether to continue the final message.
            return_tensors: Tensor type for outputs (e.g. "pt").
            return_dict: Whether to return a dict payload.
            tools: Optional list of tool specifications.
            num_img_tokens: Optional image token repeat count per image
                placeholder (defaults to 64).
            num_video_frames: Optional frame count for video placeholder
                repetition (defaults to 6).
            enable_thinking: When True, insert <think> blocks for assistant
                turns that contain reasoning content and for the generation
                prompt.
            **kwargs: Additional arguments forwarded to the tokenizer call
                when ``tokenize`` is True.

        Returns:
            Prompt string if tokenize is False. Otherwise token IDs or a dict
            payload when return_dict is True.

        Raises:
            ValueError: If the message format is invalid or
                contains unsupported content types.
        """
        if messages is None:
            if "conversation" in kwargs:
                messages = kwargs.pop("conversation")
            else:
                messages = []
        elif not isinstance(messages, list):
            messages = (
                messages.tolist()
                if hasattr(messages, "tolist")
                else list(messages)
            )

        # .get() so a message without 'role' yields the ValueError below
        # instead of a KeyError.
        if continue_final_message and (
            not messages or messages[-1].get("role") != "assistant"
        ):
            raise ValueError(
                "'continue_final_message' requires the last message to be from the assistant."
            )
        if continue_final_message and add_generation_prompt:
            logger.warning(
                "'add_generation_prompt' is ignored when 'continue_final_message' is set."
            )

        num_img_tokens = 64 if num_img_tokens is None else num_img_tokens
        num_video_frames = 6 if num_video_frames is None else num_video_frames

        def image_builder(_: dict[str, Any]) -> list[str]:
            # One image = <image> + N placeholder tokens + </image>.
            return (
                ["<image>"]
                + ["<REKA_IMG_TOKEN>"] * num_img_tokens
                + ["</image>"]
            )

        def video_builder(_: dict[str, Any]) -> list[str]:
            # One video = N placeholder tokens per frame, for each frame.
            return (
                ["<video>"]
                + ["<REKA_IMG_TOKEN>"] * (num_img_tokens * num_video_frames)
                + ["</video>"]
            )

        template_source = chat_template or getattr(self, "chat_template", None)
        if template_source:
            # The template may be given as a file path instead of source.
            if os.path.isfile(template_source):
                with open(template_source, "r", encoding="utf-8") as handle:
                    template_source = handle.read()
            prompt = _render_chat_template(
                template_source,
                {
                    "messages": messages,
                    "add_generation_prompt": add_generation_prompt,
                    "continue_final_message": continue_final_message,
                    "tools": tools,
                    "enable_thinking": enable_thinking,
                    "num_img_tokens": num_img_tokens,
                    "num_video_frames": num_video_frames,
                },
            )
        else:
            prompt = self.build_chat_prompt(
                messages,
                add_generation_prompt=add_generation_prompt,
                continue_final_message=continue_final_message,
                tools=tools,
                image_token_builder=image_builder,
                video_token_builder=video_builder,
                enable_thinking=enable_thinking,
            )

        if not tokenize:
            return prompt

        # Tensor outputs are requested on a batch of one prompt.
        text_input = [prompt] if return_tensors is not None else prompt
        encoded = self(
            text_input,
            add_special_tokens=False,
            return_tensors=return_tensors,
            **kwargs,
        )
        if return_dict:
            return encoded
        return encoded["input_ids"]

    def build_chat_prompt(
        self,
        messages: list[dict[str, Any]],
        *,
        add_generation_prompt: bool,
        continue_final_message: bool,
        tools: Optional[list[dict[str, Any]]],
        image_token_builder: TokenBuilder,
        video_token_builder: TokenBuilder,
        enable_thinking: Optional[bool] = None,
    ) -> str:
        """Build a Yasa2 prompt using tokenizer-shared formatting helpers.

        The processor calls this method to ensure a single prompt formatter is
        used across text-only and multimodal paths. All arguments are
        forwarded unchanged to the module-level ``build_chat_prompt``.
        """
        return build_chat_prompt(
            messages,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tools=tools,
            image_token_builder=image_token_builder,
            video_token_builder=video_token_builder,
            enable_thinking=enable_thinking,
        )

    @property
    def vocab_size(self) -> int:
        """Return the vocabulary size.

        Returns:
            Highest token ID of the underlying encoding plus one, so the
            special tokens registered in ``__init__`` are included.
        """
        return self.tiktoken.max_token_value + 1

    def get_vocab(self) -> dict[bytes, int]:
        """Return a mapping of token bytes to IDs.

        Returns:
            Dictionary mapping token bytes to IDs.

        Raises:
            ValueError: If any ID up to the maximum cannot be decoded.
        """
        ret = {}
        for i in range(self.tiktoken.max_token_value + 1):
            try:
                ret[self.tiktoken.decode_single_token_bytes(i)] = i
            except Exception as e:
                raise ValueError(f"Error decoding token {i}: {e}") from e
        return ret

    def _tokenize(self, text: str, **kwargs: Any) -> list[bytes]:
        """Convert a string into a sequence of tokens (bytes).

        Args:
            text: Input string.
            **kwargs: Unused tokenizer kwargs.

        Returns:
            List of token bytes. Every special token found in the text is
            encoded as such (``allowed_special="all"``).
        """
        return [
            self._convert_id_to_token(t)
            for t in self.tiktoken.encode(text, allowed_special="all")
        ]

    def _convert_token_to_id(self, token: bytes | str) -> int:
        """Convert a token string/bytes to its integer ID.

        Args:
            token: Token to convert.

        Returns:
            Token ID.
        """
        return self.tiktoken.encode_single_token(token)

    def _convert_id_to_token(self, index: int) -> bytes:
        """Convert a token ID to its byte representation.

        Args:
            index: Token ID.

        Returns:
            Token bytes.
        """
        return self.tiktoken.decode_single_token_bytes(index)

    def convert_tokens_to_string(self, tokens: list[bytes | str]) -> str:
        """Convert a list of tokens into a decoded string.

        Args:
            tokens: Sequence of token bytes or strings.

        Returns:
            Decoded text string. Invalid UTF-8 is handled according to
            ``self.utf8_decoding_strategy``.
        """
        # Convert all tokens to bytes first
        bytes_tokens = [
            t.encode("utf-8") if isinstance(t, str) else t for t in tokens
        ]
        return b"".join(bytes_tokens).decode(
            "utf8", errors=self.utf8_decoding_strategy
        )

    def save_vocabulary(
        self, save_directory: str, filename_prefix: Optional[str] = None
    ) -> tuple[str]:
        """Save the tokenizer vocabulary to files.

        Only the custom special-token mapping captured in ``__init__`` is
        written; the auto-generated placeholder tokens are not.

        Args:
            save_directory: Directory to save vocab files.
            filename_prefix: Optional filename prefix.

        Returns:
            Tuple with saved file paths (tiktoken special token mapping).
        """
        os.makedirs(save_directory, exist_ok=True)
        filename = "tiktoken_special_tokens.json"
        if filename_prefix:
            filename = f"{filename_prefix}-{filename}"
        save_path = os.path.join(save_directory, filename)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(self.tiktoken_special_tokens, f, indent=4)
        return (save_path,)

    def save_pretrained(
        self,
        save_directory: str | os.PathLike,
        legacy_format: Optional[bool] = None,
        filename_prefix: Optional[str] = None,
        push_to_hub: bool = False,
        **kwargs: Any,
    ) -> tuple[str, ...]:
        """Save the tokenizer without persisting the pad/eos tokens.

        ``pad_token`` and ``eos_token`` are cleared only for the duration of
        the save and restored afterwards, so the live tokenizer stays usable.
        Presumably they are kept out of the saved files because ``__init__``
        always re-applies them -- confirm.

        Args:
            save_directory: Target directory.
            legacy_format: Forwarded to the base implementation.
            filename_prefix: Forwarded to the base implementation.
            push_to_hub: Forwarded to the base implementation.
            **kwargs: Forwarded to the base implementation.

        Returns:
            Whatever the base ``save_pretrained`` returns (saved file paths).
        """
        previous_pad_token = self.pad_token
        previous_eos_token = self.eos_token
        self.pad_token = None
        self.eos_token = None
        try:
            return super().save_pretrained(
                save_directory,
                legacy_format,
                filename_prefix,
                push_to_hub,
                **kwargs,
            )
        finally:
            # Restore even when saving fails.
            self.pad_token = previous_pad_token
            self.eos_token = previous_eos_token
# Runs at import time. NOTE(review): presumably this registers the class so
# the transformers auto classes can resolve it from remote code -- confirm
# against the transformers documentation.
Yasa2Tokenizer.register_for_auto_class()