|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| from pathlib import Path
|
| from typing import TypeVar
|
|
|
| from jinja2 import TemplateError
|
| from transformers import AddedToken, AutoTokenizer, PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin
|
|
|
| from .data_utils import prepare_multimodal_messages
|
|
|
|
|
| _CHAT_TEMPLATES_DIR = Path(__file__).parent / "chat_templates"
|
|
|
|
|
| def clone_chat_template(
|
| model: PreTrainedModel,
|
| tokenizer: PreTrainedTokenizerBase,
|
| source_tokenizer_path: str,
|
| resize_to_multiple_of: int | None = 64,
|
| ) -> tuple[PreTrainedModel, PreTrainedTokenizerBase, list[int]]:
|
| """
|
| Clones a chat template from a source tokenizer to the target tokenizer and updates the model accordingly.
|
|
|
| This function:
|
| - Copies the chat template from a source tokenizer to the target tokenizer.
|
| - Adds any new tokens from the source tokenizer to the target tokenizer.
|
| - Sets and synchronizes the EOS token across the tokenizer and model.
|
| - Resizes the model's token embeddings to match the new vocabulary size, optionally rounding it up to a multiple of
|
| a specified value. In such cases, dummy tokens are added to the tokenizer to ensure the vocabulary size matches
|
| the embedding dimensions.
|
|
|
| Args:
|
| model ([`~transformers.PreTrainedModel`]):
|
| Model to update.
|
| tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
|
| Tokenizer to update.
|
| source_tokenizer_path (`str`):
|
| Path or identifier of the pretrained tokenizer to clone from.
|
| resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
|
| The embedding layer will be resized to the new vocabulary size. If this is not `None`, it will round up the
|
| new vocabulary size to the nearest multiple of this value.
|
|
|
| Returns:
|
| model ([`~transformers.PreTrainedModel`]):
|
| Updated model with resized token embeddings and EOS token configured.
|
| tokenizer ([`~transformers.PreTrainedTokenizerBase`]):
|
| Updated tokenizer with the chat template and special tokens applied.
|
| added_tokens (`list[int]`):
|
| List of tokens that were added to the tokenizer from the source tokenizer.
|
|
|
| Example:
|
| ```python
|
| from transformers import AutoModelForCausalLM, AutoTokenizer
|
| from trl import clone_chat_template
|
|
|
| model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
|
| tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
|
| model, tokenizer, added_tokens = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
|
| ```
|
| """
|
|
|
| tokenizer_source = AutoTokenizer.from_pretrained(source_tokenizer_path)
|
|
|
|
|
| tokenizer.chat_template = tokenizer_source.get_chat_template()
|
|
|
|
|
| added_tokens = [
|
| token for token in tokenizer_source.added_tokens_decoder.values() if token.content not in tokenizer.vocab
|
| ]
|
| tokenizer.add_tokens(added_tokens)
|
|
|
|
|
| tokenizer.eos_token = tokenizer_source.eos_token
|
| model.config.eos_token_id = tokenizer.eos_token_id
|
| if model.can_generate():
|
| model.generation_config.eos_token_id = tokenizer.eos_token_id
|
|
|
|
|
| model.resize_token_embeddings(
|
|
|
|
|
|
|
| new_num_tokens=len(tokenizer.vocab),
|
| pad_to_multiple_of=resize_to_multiple_of if resize_to_multiple_of is not None else None,
|
| )
|
|
|
|
|
|
|
| idx = 0
|
| while model.vocab_size > len(tokenizer.vocab):
|
| dummy_token = AddedToken(f"<extra_id_{idx}>")
|
| is_added = tokenizer.add_tokens(dummy_token)
|
| idx += 1
|
| if is_added == 1:
|
| added_tokens.append(dummy_token)
|
|
|
|
|
| if len(tokenizer.vocab) != model.vocab_size:
|
| raise RuntimeError(
|
| f"Vocabulary size mismatch after resizing: tokenizer vocab size is {len(tokenizer.vocab)}, but model "
|
| f"embedding size is {model.vocab_size}. This indicates an internal error in the token alignment process."
|
| )
|
| added_tokens = [token.content for token in added_tokens]
|
| added_tokens = tokenizer.convert_tokens_to_ids(added_tokens)
|
| return model, tokenizer, added_tokens
|
|
|
|
|
| glm4moe_schema = {
|
| "x-regex": r"^(?:\n?<think>\n?(?:(?P<reasoning_content>.*?\S.*?)\n?|[\s]*)</think>\s*)?(?P<content>.*?)(?:\n(?=<tool_call>))?(?=(?:<tool_call>|$))(?P<tool_calls>(?:<tool_call>.+?</tool_call>\s*)+)?$",
|
| "type": "object",
|
| "properties": {
|
| "role": {"const": "assistant"},
|
| "content": {"type": "string"},
|
| "reasoning_content": {"type": "string"},
|
| "tool_calls": {
|
| "type": "array",
|
| "x-regex-iterator": r"<tool_call>\s*(.+?)\s*</tool_call>",
|
| "items": {
|
| "type": "object",
|
| "properties": {
|
| "type": {"const": "function"},
|
| "function": {
|
| "type": "object",
|
| "properties": {
|
| "name": {"type": "string", "x-regex": r"^(\S+)"},
|
| "arguments": {
|
| "type": "object",
|
| "x-regex-key-value": r"<arg_key>(?P<key>[^<]+)</arg_key>\s*\n<arg_value>(?P<value>.*?)</arg_value>",
|
| "default": {},
|
| "additionalProperties": {
|
| "x-parser": "json",
|
| "x-parser-args": {"allow_non_json": True},
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| }
|
|
|
| gptoss_schema = {
|
|
|
| "x-regex-substitutions": [
|
| [r"<\|channel\|>final<\|message\|>(.*?)<\|return\|>", r"<|channel|>analysis<|message|>\1<|end|>"],
|
| ],
|
| "x-regex": r"^(?:<\|channel\|>analysis<\|message\|>(?P<content>.*?)<\|end\|>(?:<\|start\|>assistant)?)?\s*(?P<tool_calls>to=functions\.\S+<\|channel\|>commentary json<\|message\|>.*?<\|call\|>)?$",
|
| "type": "object",
|
| "properties": {
|
| "role": {"const": "assistant"},
|
| "content": {"type": "string"},
|
| "tool_calls": {
|
| "type": "array",
|
| "x-regex-iterator": r"(to=functions\.\S+<\|channel\|>commentary json<\|message\|>.*?<\|call\|>)",
|
| "items": {
|
|
|
|
|
| "x-regex-substitutions": [
|
| [
|
| r"to=functions\.(\S+)<\|channel\|>commentary json<\|message\|>(.*?)<\|call\|>",
|
| r'{"name": "\1", "arguments": \2}',
|
| ],
|
| ],
|
| "x-parser": "json",
|
| "x-parser-args": {"transform": "{type: 'function', function: @}"},
|
| "type": "object",
|
| "properties": {
|
| "type": {"const": "function"},
|
| "function": {
|
| "type": "object",
|
| "properties": {
|
| "name": {"type": "string"},
|
| "arguments": {
|
| "type": "object",
|
| "additionalProperties": {},
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| }
|
|
|
|
|
|
|
| qwen3_schema = {
|
| "x-regex": r"^(?:<think>\n?(?:(?P<reasoning_content>.*?\S.*?)\n?|[\s]*)</think>\s*)?(?P<content>.*?)(?:\n(?=<tool_call>))?(?=(?:<tool_call>|<\|im_end\|>|$))(?P<tool_calls>(?:<tool_call>.+?</tool_call>\s*)+)?\s*(?:<\|im_end\|>|$)",
|
| "type": "object",
|
| "properties": {
|
| "role": {"const": "assistant"},
|
| "content": {"type": "string"},
|
| "reasoning_content": {"type": "string"},
|
| "tool_calls": {
|
| "type": "array",
|
| "x-regex-iterator": r"<tool_call>\s*(.+?)\s*</tool_call>",
|
| "items": {
|
| "x-parser": "json",
|
| "x-parser-args": {"transform": "{type: 'function', function: @}"},
|
| "type": "object",
|
| "properties": {
|
| "type": {"const": "function"},
|
| "function": {
|
| "type": "object",
|
| "properties": {
|
| "name": {"type": "string"},
|
| "arguments": {
|
| "type": "object",
|
| "additionalProperties": {},
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| }
|
|
|
| llama3_schema = {
|
|
|
|
|
|
|
|
|
| "x-regex": r'^(?:(?P<tool_calls>\{"name":\s*".+?",\s*"parameters":\s*.+\})|(?P<content>.*?))(?:<\|eot_id\|>|$)',
|
| "type": "object",
|
| "properties": {
|
| "role": {"const": "assistant"},
|
| "content": {"type": "string"},
|
| "tool_calls": {
|
| "type": "array",
|
| "x-regex-iterator": r'(\{"name":\s*".+?",\s*"parameters":\s*.+\})',
|
| "items": {
|
|
|
|
|
| "x-regex-substitutions": [
|
| [r'^(\{"name":\s*"[^"]+",\s*)"parameters":', r'\1"arguments":'],
|
| ],
|
| "x-parser": "json",
|
| "x-parser-args": {"transform": "{type: 'function', function: @}"},
|
| "type": "object",
|
| "properties": {
|
| "type": {"const": "function"},
|
| "function": {
|
| "type": "object",
|
| "properties": {
|
| "name": {"type": "string"},
|
| "arguments": {
|
| "type": "object",
|
| "additionalProperties": {},
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| }
|
|
|
| qwen3_5_schema = {
|
| "x-regex": r"^(?:(?:<think>\n?)?(?:(?P<reasoning_content>.*?\S.*?)\n?|[\s]*)</think>\s*)?(?P<content>.*?)(?:\n+(?=<tool_call>))?(?=(?:<tool_call>|<\|im_end\|>|$))(?P<tool_calls>(?:<tool_call>.+?</tool_call>\s*)+)?\s*(?:<\|im_end\|>|$)",
|
| "type": "object",
|
| "properties": {
|
| "role": {"const": "assistant"},
|
| "content": {"type": "string"},
|
| "reasoning_content": {"type": "string"},
|
| "tool_calls": {
|
| "type": "array",
|
| "x-regex-iterator": r"<tool_call>\s*(.+?)\s*</tool_call>",
|
| "items": {
|
| "type": "object",
|
| "properties": {
|
| "type": {"const": "function"},
|
| "function": {
|
| "type": "object",
|
| "properties": {
|
| "name": {"type": "string", "x-regex": r"<function=([^\n>]+)>"},
|
| "arguments": {
|
| "type": "object",
|
| "x-regex-key-value": r"<parameter=(?P<key>[^>\n]+)>\n(?P<value>.*?)\n</parameter>",
|
| "default": {},
|
| "additionalProperties": {
|
| "x-parser": "json",
|
| "x-parser-args": {"allow_non_json": True},
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| },
|
| }
|
|
|
|
|
| deepseekv3_chat_template = (_CHAT_TEMPLATES_DIR / "deepseekv3.jinja").read_text()
|
|
|
| gemma_chat_template = (_CHAT_TEMPLATES_DIR / "gemma.jinja").read_text()
|
|
|
| glm4moe_chat_template = (_CHAT_TEMPLATES_DIR / "glm4moe.jinja").read_text()
|
|
|
| gptoss_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss.jinja").read_text()
|
|
|
| llama3_chat_template = (_CHAT_TEMPLATES_DIR / "llama3.jinja").read_text()
|
|
|
| llama3_1_chat_template = (_CHAT_TEMPLATES_DIR / "llama3_1.jinja").read_text()
|
|
|
| llama3_2_chat_template = (_CHAT_TEMPLATES_DIR / "llama3_2.jinja").read_text()
|
|
|
| phi3_chat_template = (_CHAT_TEMPLATES_DIR / "phi3.jinja").read_text()
|
|
|
| qwen2_5_chat_template = (_CHAT_TEMPLATES_DIR / "qwen2_5.jinja").read_text()
|
|
|
| qwen3_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3.jinja").read_text()
|
|
|
| qwen3_vl_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_vl.jinja").read_text()
|
|
|
| qwen3_5_chat_template_2b_and_below = (_CHAT_TEMPLATES_DIR / "qwen3_5_2b_and_below.jinja").read_text()
|
|
|
| qwen3_5_chat_template_4b_and_above = (_CHAT_TEMPLATES_DIR / "qwen3_5_4b_and_above.jinja").read_text()
|
|
|
|
|
| ProcessingClassT = TypeVar("ProcessingClassT", PreTrainedTokenizerBase, ProcessorMixin)
|
|
|
|
|
| def add_response_schema(processing_class: ProcessingClassT) -> ProcessingClassT:
|
| r"""
|
| Adds the appropriate response schema to the given tokenizer based on its chat template.
|
|
|
| At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While
|
| waiting for broader adoption, we provide this utility function to manually set the response schema for known chat
|
| templates.
|
|
|
| When given a VLM processor, the schema is set on the inner tokenizer, since `parse_response` is a tokenizer method
|
| and reads `self.response_schema` from the tokenizer instance.
|
|
|
| Args:
|
| processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`):
|
| Tokenizer or VLM processor to which the response schema will be added.
|
|
|
| Returns:
|
| `PreTrainedTokenizerBase` or `ProcessorMixin`:
|
| The same object that was passed in, with the response schema set on the underlying tokenizer.
|
|
|
| Examples:
|
|
|
| ```python
|
| >>> from trl.chat_template_utils import add_response_schema
|
| >>> from transformers import AutoTokenizer
|
|
|
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| >>> tokenizer = add_response_schema(tokenizer)
|
| >>> assistant_text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
|
| >>> tokenizer.parse_response(assistant_text)
|
| {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
|
| ```
|
| """
|
|
|
|
|
|
|
| chat_template = processing_class.chat_template
|
| if isinstance(processing_class, ProcessorMixin):
|
| tokenizer = processing_class.tokenizer
|
| else:
|
| tokenizer = processing_class
|
| if chat_template == glm4moe_chat_template:
|
| tokenizer.response_schema = glm4moe_schema
|
| elif chat_template == gptoss_chat_template:
|
| tokenizer.response_schema = gptoss_schema
|
| elif chat_template in [llama3_1_chat_template, llama3_2_chat_template]:
|
| tokenizer.response_schema = llama3_schema
|
| elif chat_template in [qwen3_chat_template, qwen3_vl_chat_template]:
|
| tokenizer.response_schema = qwen3_schema
|
| elif chat_template in [qwen3_5_chat_template_2b_and_below, qwen3_5_chat_template_4b_and_above]:
|
| tokenizer.response_schema = qwen3_5_schema
|
| else:
|
| raise ValueError(
|
| "Unrecognized chat template, failed to add response schema. Please manually set the response schema on "
|
| "the tokenizer or processor. See the Transformers "
|
| "[docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more "
|
| "details on response parsing."
|
| )
|
| return processing_class
|
|
|
|
|
| def supports_tool_calling(processing_class) -> bool:
|
| """
|
| Check if the processing class's chat template can render a full tool-calling conversation.
|
|
|
| This tests that (1) the template doesn't error when rendering a conversation with ``user → assistant (with
|
| tool_calls) → tool`` roles, and (2) every part of the tool-calling exchange — the assistant's tool call name, its
|
| arguments, and the tool message content — actually appears in the rendered output. Some templates silently swallow
|
| `tool_calls` (e.g. the basic Llama 3 template, which only reads `message['content']`) or tool messages (e.g.
|
| Cohere2, Phi3); both cases must be rejected.
|
|
|
| For VLMs (processors), the messages are converted to multimodal format via
|
| [`~trl.data_utils.prepare_multimodal_messages`] before rendering.
|
|
|
| Args:
|
| processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`):
|
| Tokenizer or processor instance to check.
|
|
|
| Returns:
|
| `bool`:
|
| `True` if the chat template supports tool-calling conversations, `False` otherwise.
|
| """
|
| if processing_class.chat_template is None:
|
| return False
|
|
|
| is_vlm = isinstance(processing_class, ProcessorMixin)
|
|
|
| _name_sentinel = "tool_name_a8f3e2b1"
|
| _arg_key_sentinel = "tool_arg_key_b9d4f5c2"
|
| _arg_val_sentinel = "tool_arg_val_d6e7a9f3"
|
| _content_sentinel = "tool_content_c4f9a8e2"
|
| tool_calls = [
|
| {
|
| "type": "function",
|
| "function": {"name": _name_sentinel, "arguments": {_arg_key_sentinel: _arg_val_sentinel}},
|
| }
|
| ]
|
| messages = [
|
| {"role": "user", "content": "hi"},
|
| {"role": "assistant", "content": "", "tool_calls": tool_calls},
|
| {"role": "tool", "name": _name_sentinel, "content": _content_sentinel},
|
| ]
|
|
|
| if is_vlm:
|
| messages = prepare_multimodal_messages(messages)
|
|
|
| try:
|
| rendered = processing_class.apply_chat_template(messages, tokenize=False)
|
| except TemplateError:
|
|
|
|
|
|
|
| return False
|
| except TypeError:
|
|
|
|
|
| tool_calls[0]["function"]["arguments"] = f'{{"{_arg_key_sentinel}": "{_arg_val_sentinel}"}}'
|
| try:
|
| rendered = processing_class.apply_chat_template(messages, tokenize=False)
|
| except TemplateError:
|
| return False
|
|
|
|
|
|
|
| return all(s in rendered for s in (_name_sentinel, _arg_key_sentinel, _arg_val_sentinel, _content_sentinel))
|
|
|
|
|
| def is_chat_template_prefix_preserving(processing_class: PreTrainedTokenizerBase | ProcessorMixin) -> bool:
|
| """
|
| Check whether the chat template preserves prefixes when applied.
|
|
|
| A prefix-preserving chat template renders earlier messages identically regardless of what messages follow. This
|
| property is required by `_get_tool_suffix_ids`, which extracts tool response formatting tokens by comparing
|
| tokenizations with and without tool messages appended.
|
|
|
| Args:
|
| processing_class (`PreTrainedTokenizerBase` or `ProcessorMixin`):
|
| Tokenizer or processor instance to check.
|
|
|
| Returns:
|
| `bool`:
|
| `True` if the chat template preserves prefixes, `False` otherwise.
|
| """
|
|
|
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": {}}}]
|
| messages1 = [
|
| {"role": "user", "content": "dummy"},
|
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
|
| ]
|
| messages2 = [
|
| {"role": "user", "content": "dummy"},
|
| {"role": "assistant", "content": "", "tool_calls": dummy_tool_calls},
|
| {"role": "tool", "name": "dummy", "content": "dummy"},
|
| ]
|
|
|
|
|
| is_vlm = isinstance(processing_class, ProcessorMixin)
|
| if is_vlm:
|
| from PIL import Image
|
|
|
| dummy_image = Image.new("RGB", (8, 8))
|
| messages1 = prepare_multimodal_messages(messages1, images=[dummy_image])
|
| messages2 = prepare_multimodal_messages(messages2, images=[dummy_image])
|
|
|
| try:
|
| ids1 = processing_class.apply_chat_template(messages1, tokenize=True, return_dict=False)
|
| ids2 = processing_class.apply_chat_template(
|
| messages2, tokenize=True, return_dict=False, add_generation_prompt=True
|
| )
|
| except TypeError:
|
|
|
|
|
| dummy_tool_calls = [{"type": "function", "function": {"name": "dummy", "arguments": "{}"}}]
|
| messages1[1]["tool_calls"] = dummy_tool_calls
|
| messages2[1]["tool_calls"] = dummy_tool_calls
|
| ids1 = processing_class.apply_chat_template(messages1, tokenize=True, return_dict=False)
|
| ids2 = processing_class.apply_chat_template(
|
| messages2, tokenize=True, return_dict=False, add_generation_prompt=True
|
| )
|
|
|
|
|
| if is_vlm:
|
| ids1 = ids1[0]
|
| ids2 = ids2[0]
|
|
|
| return ids2[: len(ids1)] == ids1
|
|
|
|
|
| deepseekv3_training_chat_template = (_CHAT_TEMPLATES_DIR / "deepseekv3_training.jinja").read_text()
|
|
|
| gemma_training_chat_template = (_CHAT_TEMPLATES_DIR / "gemma_training.jinja").read_text()
|
|
|
| glm4moe_training_chat_template = (_CHAT_TEMPLATES_DIR / "glm4moe_training.jinja").read_text()
|
|
|
| gptoss_training_chat_template = (_CHAT_TEMPLATES_DIR / "gptoss_training.jinja").read_text()
|
|
|
| llama3_training_chat_template = (_CHAT_TEMPLATES_DIR / "llama3_training.jinja").read_text()
|
|
|
| phi3_training_chat_template = (_CHAT_TEMPLATES_DIR / "phi3_training.jinja").read_text()
|
|
|
| qwen2_5_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen2_5_training.jinja").read_text()
|
|
|
| qwen3_training_chat_template = (_CHAT_TEMPLATES_DIR / "qwen3_training.jinja").read_text()
|
|
|
|
|
| def get_training_chat_template(tokenizer: PreTrainedTokenizerBase) -> str | None:
|
| r"""
|
| Get a training-compatible chat template, if needed.
|
|
|
| Returns a patched chat template that is prefix-preserving and includes `{%% generation %%}` / `{%% endgeneration
|
| %%}` markers for assistant-only loss masking. Returns `None` if the tokenizer's template already satisfies both
|
| requirements. Currently DeepSeek-V3, Gemma, Gemma2, GLM-4-MoE, GPT-OSS, LLaMA 3, Phi-3, Qwen2.5, and Qwen3 are
|
| supported.
|
|
|
| Args:
|
| tokenizer (`PreTrainedTokenizerBase`):
|
| Tokenizer instance to check.
|
|
|
| Returns:
|
| `str` or `None`:
|
| Training-compatible chat template, or `None` if no patching is needed.
|
|
|
| Example:
|
|
|
| ```python
|
| >>> from trl.chat_template_utils import get_training_chat_template
|
| >>> from transformers import AutoTokenizer
|
|
|
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| >>> messages1 = [
|
| ... {"role": "user", "content": "What is 2 * 3?"},
|
| ... {
|
| ... "role": "assistant",
|
| ... "content": "",
|
| ... "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 2, "b": 3}}}],
|
| ... },
|
| ... ]
|
| >>> messages2 = messages1 + [
|
| ... {"role": "tool", "name": "multiply", "content": "6"},
|
| ... ]
|
| >>> tokenizer.apply_chat_template(messages1, tokenize=False)
|
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'
|
|
|
| >>> tokenizer.apply_chat_template(messages2, tokenize=False, add_generation_prompt=True)
|
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
|
|
|
| >>> # ^ think tags missing
|
| >>> chat_template = get_training_chat_template(tokenizer)
|
| >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template)
|
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n'
|
|
|
| >>> tokenizer.apply_chat_template(
|
| ... messages2, tokenize=False, add_generation_prompt=True, chat_template=chat_template
|
| ... )
|
| '<|im_start|>user\nWhat is 2 * 3?<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n<tool_call>\n{"name": "multiply", "arguments": {"a": 2, "b": 3}}\n</tool_call><|im_end|>\n<|im_start|>user\n<tool_response>\n6\n</tool_response><|im_end|>\n<|im_start|>assistant\n'
|
| ```
|
| """
|
|
|
|
|
| prefix_ok = not supports_tool_calling(tokenizer) or is_chat_template_prefix_preserving(tokenizer)
|
| if prefix_ok and "{% generation %}" in tokenizer.chat_template:
|
| return None
|
|
|
| if tokenizer.chat_template == deepseekv3_chat_template:
|
| return deepseekv3_training_chat_template
|
|
|
| if tokenizer.chat_template == gemma_chat_template:
|
| return gemma_training_chat_template
|
|
|
| if tokenizer.chat_template == glm4moe_chat_template:
|
| return glm4moe_training_chat_template
|
|
|
| if tokenizer.chat_template == gptoss_chat_template:
|
| return gptoss_training_chat_template
|
|
|
| if tokenizer.chat_template == llama3_chat_template:
|
| return llama3_training_chat_template
|
|
|
| if tokenizer.chat_template == phi3_chat_template:
|
| return phi3_training_chat_template
|
|
|
| if tokenizer.chat_template == qwen2_5_chat_template:
|
| return qwen2_5_training_chat_template
|
|
|
| if tokenizer.chat_template == qwen3_chat_template:
|
| return qwen3_training_chat_template
|
|
|
| raise ValueError(
|
| "The tokenizer's chat template is not training-compatible (missing prefix-preservation or "
|
| "`{% generation %}` markers) and patching is not supported for this template. "
|
| "Please manually modify the tokenizer's chat template for training."
|
| )
|
|
|
|
|
| def _validate_tool_calls(tool_calls: list | None) -> None:
|
| """
|
| Validate tool_calls to ensure all required fields exist with valid values.
|
|
|
| Raises ValueError when the model generates malformed tool calls (e.g., missing 'arguments' field) that are
|
| partially parsed.
|
|
|
| Args:
|
| tool_calls: List of tool call dictionaries, or None.
|
| """
|
| if tool_calls is None:
|
| return None
|
| if not isinstance(tool_calls, list):
|
| raise ValueError("tool_calls must be a list or None.")
|
|
|
| for idx, tool_call in enumerate(tool_calls):
|
| if not isinstance(tool_call, dict):
|
| raise ValueError(f"tool_calls[{idx}] must be a dict.")
|
|
|
|
|
| if "function" in tool_call:
|
| func = tool_call["function"]
|
| if not isinstance(func, dict):
|
| raise ValueError(f"tool_calls[{idx}]['function'] must be a dict.")
|
| if not isinstance(func.get("name"), str):
|
| raise ValueError(f"tool_calls[{idx}]['function']['name'] must be a string.")
|
|
|
| if "arguments" not in func or func["arguments"] is None:
|
| func["arguments"] = {}
|
| else:
|
|
|
| if not isinstance(tool_call.get("name"), str):
|
| raise ValueError(f"tool_calls[{idx}]['name'] must be a string.")
|
|
|
| if "arguments" not in tool_call or tool_call["arguments"] is None:
|
| tool_call["arguments"] = {}
|
|
|
|
|
| def parse_response(processing_class: PreTrainedTokenizerBase | ProcessorMixin, ids: list[int]) -> dict:
|
| r"""
|
| Parse a token sequence into structured response dictionaries with fallback handling.
|
|
|
| Attempts to parse the sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool
|
| calls like `<tool_call>{"type":"function"</tool_call>`), falls back to decoding as plain text.
|
|
|
| Also removes incorrectly appended EOS tokens from tool call content when present, and validates tool_calls to
|
| ensure all required fields exist.
|
|
|
| For VLM processors, automatically uses the inner tokenizer for parsing.
|
|
|
| Args:
|
| processing_class (`PreTrainedTokenizerBase` or VLM processor):
|
| Tokenizer or processor with a `parse_response()` method (directly or via inner tokenizer).
|
| ids (`list[int]`):
|
| List of token sequences.
|
|
|
| Returns:
|
| `dict`:
|
| Response dictionary.
|
|
|
| Example:
|
| ```python
|
| >>> from trl.chat_template_utils import parse_response, add_response_schema
|
| >>> from transformers import AutoTokenizer
|
|
|
| >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
| >>> tokenizer = add_response_schema(tokenizer) # temporary until built-in support
|
| >>> text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
|
| >>> ids = tokenizer(text)["input_ids"]
|
| >>> parse_response(tokenizer, ids)
|
| {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}
|
| ```
|
| """
|
|
|
| tokenizer = getattr(processing_class, "tokenizer", processing_class)
|
| try:
|
| parsed = tokenizer.parse_response(ids)
|
|
|
|
|
| if isinstance(parsed.get("content"), str):
|
| parsed["content"] = parsed["content"].removesuffix(tokenizer.eos_token)
|
|
|
| if not parsed.get("content"):
|
| parsed["content"] = ""
|
|
|
| if "tool_calls" in parsed:
|
| _validate_tool_calls(parsed["tool_calls"])
|
| except (ValueError, TypeError):
|
|
|
| content = tokenizer.decode(ids, skip_special_tokens=True)
|
| parsed = {"role": "assistant", "content": content}
|
| return parsed
|
|
|