|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import copy
|
| import textwrap
|
|
|
| import pytest
|
| import transformers
|
| from packaging.version import Version
|
| from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoProcessor, AutoTokenizer
|
|
|
| from trl import clone_chat_template
|
| from trl.chat_template_utils import (
|
| add_response_schema,
|
| get_training_chat_template,
|
| is_chat_template_prefix_preserving,
|
| parse_response,
|
| supports_tool_calling,
|
| )
|
| from trl.data_utils import prepare_multimodal_messages
|
|
|
| from .testing_utils import TrlTestCase, require_jmespath, require_vision
|
|
|
|
|
class TestCloneChatTemplate(TrlTestCase):
    """Tests for `clone_chat_template`, always cloning from the tiny Qwen3 checkpoint."""

    # Checkpoint whose chat template is cloned in every test.
    _TEMPLATE_SOURCE = "trl-internal-testing/tiny-Qwen3ForCausalLM"
    # Default checkpoint that receives the cloned template.
    _BLOOM = "trl-internal-testing/tiny-BloomForCausalLM"

    def test_clone(self):
        """After cloning, the returned tokenizer's EOS token is `<|im_end|>`."""
        target_tokenizer = AutoTokenizer.from_pretrained(self._BLOOM)
        target_model = AutoModelForCausalLM.from_pretrained(self._BLOOM)

        _, cloned_tokenizer, _ = clone_chat_template(target_model, target_tokenizer, self._TEMPLATE_SOURCE)

        assert cloned_tokenizer.eos_token == "<|im_end|>"

    def test_clone_with_resize(self):
        """With `resize_to_multiple_of=123`, the vocab size is a multiple of 123 and matches the tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(self._BLOOM)
        model = AutoModelForCausalLM.from_pretrained(self._BLOOM)

        resized_model, resized_tokenizer, _ = clone_chat_template(
            model,
            tokenizer,
            self._TEMPLATE_SOURCE,
            resize_to_multiple_of=123,
        )

        # The returned model's vocabulary size must be divisible by the requested multiple.
        assert resized_model.vocab_size % 123 == 0
        # The model object passed in must agree with the returned tokenizer's vocabulary length.
        assert model.vocab_size == len(resized_tokenizer.vocab)

    def test_clone_with_resize_and_extra_tokens_already_in_vocab(self):
        """Cloning a second time onto the output of a first clone still yields consistent sizes."""
        tokenizer = AutoTokenizer.from_pretrained(self._BLOOM)
        model = AutoModelForCausalLM.from_pretrained(self._BLOOM)

        # First pass: clone and resize to a multiple of 123.
        first_model, first_tokenizer, _ = clone_chat_template(
            model,
            tokenizer,
            self._TEMPLATE_SOURCE,
            resize_to_multiple_of=123,
        )
        # Second pass: feed the already-modified pair back in with a different multiple.
        second_model, second_tokenizer, _ = clone_chat_template(
            first_model,
            first_tokenizer,
            self._TEMPLATE_SOURCE,
            resize_to_multiple_of=124,
        )

        # The final model's vocabulary size must be divisible by the second multiple.
        assert second_model.vocab_size % 124 == 0
        # The model object passed in initially must agree with the final tokenizer's vocabulary length.
        assert model.vocab_size == len(second_tokenizer.vocab)

    def test_apply_new_chat_template(self):
        """The cloned template renders a system/user/assistant exchange in the expected ChatML form."""
        tokenizer = AutoTokenizer.from_pretrained(self._BLOOM)
        model = AutoModelForCausalLM.from_pretrained(self._BLOOM)
        _, cloned_tokenizer, _ = clone_chat_template(model, tokenizer, self._TEMPLATE_SOURCE)

        conversation = [
            {"role": "system", "content": "You are helpful"},
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi, how can I help you?"},
        ]
        rendered = cloned_tokenizer.apply_chat_template(conversation, tokenize=False)

        # One string fragment per rendered turn; the assistant turn carries an empty think block.
        expected = (
            "<|im_start|>system\nYou are helpful<|im_end|>\n"
            "<|im_start|>user\nHello<|im_end|>\n"
            "<|im_start|>assistant\n<think>\n\n</think>\n\nHi, how can I help you?<|im_end|>\n"
        )
        assert rendered == expected

    def test_clone_with_sequence_classification_model(self):
        """Cloning also works when the target is a sequence-classification model."""
        checkpoint = "trl-internal-testing/tiny-GptNeoXForSequenceClassification"
        target_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        target_model = AutoModelForSequenceClassification.from_pretrained(checkpoint)

        _, cloned_tokenizer, _ = clone_chat_template(target_model, target_tokenizer, self._TEMPLATE_SOURCE)

        assert cloned_tokenizer.eos_token == "<|im_end|>"
|
|
|
|
|
# Text-only checkpoints exercised by `TestAddResponseSchema.test_add_response_schema`.
_RESPONSE_SCHEMA_TOKENIZERS = [
    pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
    pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
    pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"),
    pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"),
    pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
]

# Checkpoints loaded through `AutoProcessor` by `TestAddResponseSchema.test_add_response_schema_vlm`.
_RESPONSE_SCHEMA_PROCESSORS = [
    pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
    pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
]


@pytest.mark.xfail(
    condition=Version(transformers.__version__) < Version("5.0.0"),
    reason="Response parsing is not supported in transformers versions below 5.0.0",
    strict=True,
)
@require_jmespath
class TestAddResponseSchema:
    """Smoke tests: after `add_response_schema`, a rendered tool-call turn can be parsed without error."""

    @pytest.mark.parametrize("tokenizer_name", _RESPONSE_SCHEMA_TOKENIZERS)
    def test_add_response_schema(self, tokenizer_name):
        """Render a tool-call assistant turn with a tokenizer and parse the assistant part back."""
        tokenizer = add_response_schema(AutoTokenizer.from_pretrained(tokenizer_name))

        multiply_call = {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
        conversation = [
            {"role": "user", "content": "What is 3*4?"},
            {"role": "assistant", "tool_calls": [multiply_call]},
        ]

        # The assistant response is whatever the full rendering adds after the generation prompt.
        prompt = tokenizer.apply_chat_template(conversation[:1], tokenize=False, add_generation_prompt=True)
        full_text = tokenizer.apply_chat_template(conversation, tokenize=False)
        assistant_text = full_text[len(prompt) :]

        # No assertion on the result: the test only requires that parsing does not raise.
        tokenizer.parse_response(assistant_text)

    @pytest.mark.parametrize("processor_name", _RESPONSE_SCHEMA_PROCESSORS)
    def test_add_response_schema_vlm(self, processor_name):
        """Same smoke test through a processor; the schema must land on `processor.tokenizer`."""
        processor = add_response_schema(AutoProcessor.from_pretrained(processor_name))
        assert processor.tokenizer.response_schema is not None

        multiply_call = {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
        conversation = [
            {"role": "user", "content": [{"type": "text", "text": "What is 3*4?"}]},
            {
                "role": "assistant",
                # Content is given in the structured (list-of-parts) form, with an empty text part.
                "content": [{"type": "text", "text": ""}],
                "tool_calls": [multiply_call],
            },
        ]

        prompt = processor.apply_chat_template(conversation[:1], tokenize=False, add_generation_prompt=True)
        full_text = processor.apply_chat_template(conversation, tokenize=False)
        assistant_text = full_text[len(prompt) :]

        # No assertion on the result: the test only requires that parsing does not raise.
        processor.tokenizer.parse_response(assistant_text)
|
|
|
|
|
def _skip_if_transformers_older_than(minimum, reason):
    """Build a `skipif` mark that fires when the installed transformers version is below `minimum`."""
    return pytest.mark.skipif(Version(transformers.__version__) < Version(minimum), reason=reason)


# Checkpoints for which `supports_tool_calling` is expected to return True.
_TOOL_CALLING_CHECKPOINTS = [
    pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"),
    pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528", id="deepseekv3-0528"),
    pytest.param(
        "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
        id="gemma4",
        marks=_skip_if_transformers_older_than("5.5.0", "Gemma4 models were introduced in transformers-5.5.0"),
    ),
    pytest.param(
        "trl-internal-testing/tiny-Glm4MoeForCausalLM",
        id="glm4moe",
        marks=_skip_if_transformers_older_than("5.0.0", "GLM4 tokenizer requires transformers>=5.0.0"),
    ),
    pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
    pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"),
    pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"),
    pytest.param("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", id="qwen2.5"),
    pytest.param("trl-internal-testing/tiny-Qwen3ForCausalLM", id="qwen3"),
    pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3moe"),
    pytest.param(
        "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
        id="qwen3_vl",
        marks=_skip_if_transformers_older_than("4.57.0", "Qwen3-VL was introduced in transformers-4.57.0"),
    ),
    pytest.param(
        "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
        id="qwen35",
        marks=_skip_if_transformers_older_than("5.0.0", "Qwen3.5 tokenizer requires transformers>=5.0.0"),
    ),
]

# Checkpoints for which `supports_tool_calling` is expected to return False.
# NOTE(review): the blank-line grouping mirrors the original layout; the headings that
# explained each group were not visible, so the groups are left unlabeled here.
_NON_TOOL_CALLING_CHECKPOINTS = [
    pytest.param("trl-internal-testing/tiny-BartModel", id="bart"),
    pytest.param("trl-internal-testing/tiny-BloomForCausalLM", id="bloom"),
    pytest.param("trl-internal-testing/tiny-GPT2LMHeadModel", id="gpt2"),
    pytest.param("trl-internal-testing/tiny-GPTNeoXForCausalLM", id="gptneox"),
    pytest.param("trl-internal-testing/tiny-GptNeoXForSequenceClassification", id="gptneox-seq"),
    pytest.param("trl-internal-testing/tiny-OPTForCausalLM", id="opt"),
    pytest.param("trl-internal-testing/tiny-T5ForConditionalGeneration", id="t5"),

    pytest.param("trl-internal-testing/tiny-CohereForCausalLM", id="cohere"),
    pytest.param("trl-internal-testing/tiny-FalconMambaForCausalLM", id="falconmamba"),
    pytest.param("trl-internal-testing/tiny-GemmaForCausalLM", id="gemma"),
    pytest.param("trl-internal-testing/tiny-Gemma2ForCausalLM", id="gemma2"),
    pytest.param("trl-internal-testing/tiny-Gemma3ForConditionalGeneration", id="gemma3"),
    pytest.param("trl-internal-testing/tiny-Idefics2ForConditionalGeneration", id="idefics2"),
    pytest.param("trl-internal-testing/tiny-Idefics3ForConditionalGeneration", id="idefics3"),
    pytest.param("trl-internal-testing/tiny-LlavaNextForConditionalGeneration", id="llava_next"),
    pytest.param("trl-internal-testing/tiny-MistralForCausalLM-0.1", id="mistral0.1"),
    pytest.param("trl-internal-testing/tiny-MistralForCausalLM-0.2", id="mistral0.2"),
    pytest.param("trl-internal-testing/tiny-SmolVLMForConditionalGeneration", id="smolvlm"),

    pytest.param("trl-internal-testing/tiny-Cohere2ForCausalLM", id="cohere2"),
    pytest.param("trl-internal-testing/tiny-LlavaForConditionalGeneration", id="llava"),
    pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3", id="phi3"),
    pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3.5", id="phi3.5"),

    pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"),
    pytest.param("trl-internal-testing/tiny-Qwen2VLForConditionalGeneration", id="qwen2_vl"),
    pytest.param("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", id="qwen2.5_vl"),
]


class TestSupportsToolCalling:
    """Check the boolean verdict of `supports_tool_calling` across a range of tiny checkpoints."""

    @pytest.mark.parametrize("model_id", _TOOL_CALLING_CHECKPOINTS)
    def test_supports_tool_calling(self, model_id):
        """Each listed checkpoint's tokenizer must be reported as supporting tool calling."""
        loaded = AutoTokenizer.from_pretrained(model_id)
        verdict = supports_tool_calling(loaded)
        # Identity check: the function must return the bool `True`, not merely a truthy value.
        assert verdict is True

    @pytest.mark.parametrize("model_id", _NON_TOOL_CALLING_CHECKPOINTS)
    def test_does_not_support_tool_calling(self, model_id):
        """Each listed checkpoint's tokenizer must be reported as not supporting tool calling."""
        loaded = AutoTokenizer.from_pretrained(model_id)
        verdict = supports_tool_calling(loaded)
        # Identity check: the function must return the bool `False`, not merely a falsy value.
        assert verdict is False
|
|
|
|
|
# Minimal ChatML-style template: every user/assistant/tool message is rendered the same way
# wherever it sits in the conversation. Every Jinja tag uses the `-` whitespace-control marker.
_UNIFORM_CHATML_TEMPLATE = textwrap.dedent(r"""
{%- for message in messages %}

{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' + message.content }}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' + message.content + '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}""")

# Template whose assistant rendering depends on position: a `<think>` block is only emitted for
# assistant messages located after the last string-content user message (and only when the
# message is last or carries reasoning content). Earlier assistant turns are rendered without it.
_POSITION_DEPENDENT_THINKING_TEMPLATE = textwrap.dedent(r"""
{%- if messages[0].role == 'system' %}
{{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }}
{%- endif %}
{%- set ns = namespace(last_query_index=messages|length - 1) %}
{%- for message in messages[::-1] %}
{%- set index = (messages|length - 1) - loop.index0 %}
{%- if message.role == "user" and message.content is string %}
{%- set ns.last_query_index = index %}
{%- break %}
{%- endif %}
{%- endfor %}
{%- for message in messages %}
{%- set content = message.content if message.content is string else '' %}
{%- if message.role == "user" or (message.role == "system" and not loop.first) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>\n' }}
{%- elif message.role == "assistant" %}
{%- set reasoning_content = '' %}
{%- if message.reasoning_content is string %}
{%- set reasoning_content = message.reasoning_content %}
{%- else %}
{%- if '</think>' in content %}
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
{%- endif %}
{%- endif %}
{%- if loop.index0 > ns.last_query_index %}
{%- if loop.last or (not loop.last and reasoning_content) %}
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- else %}
{{- '<|im_start|>' + message.role + '\n' + content }}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == "tool" %}
{{- '<|im_start|>tool\n' + content + '<|im_end|>\n' }}
{%- endif %}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- if enable_thinking is defined and enable_thinking is false %}
{{- '<think>\n\n</think>\n\n' }}
{%- endif %}
{%- endif %}""")

# Position-independent ChatML-style template that also accepts list-of-parts content;
# image parts in user messages are rendered as a vision placeholder.
_UNIFORM_MULTIMODAL_TEMPLATE = textwrap.dedent(r"""
{%- for message in messages %}

{%- if message.role == 'user' %}
{{- '<|im_start|>user\n' }}
{%- if message.content is string %}
{{- message.content }}
{%- else %}
{%- for content in message.content %}
{%- if content.type == 'image' or 'image' in content %}
{{- '<|vision_start|><|image_pad|><|vision_end|>' }}
{%- elif 'text' in content %}
{{- content.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'assistant' %}
{{- '<|im_start|>assistant\n' }}
{%- if message.content is string %}
{{- message.content }}
{%- else %}
{%- for content in message.content %}
{%- if 'text' in content %}
{{- content.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- if message.tool_calls %}
{%- for tool_call in message.tool_calls %}
{%- if tool_call.function %}
{%- set tool_call = tool_call.function %}
{%- endif %}
{{- '<tool_call>' + tool_call.name + '</tool_call>' }}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- elif message.role == 'tool' %}
{{- '<|im_start|>tool\n' }}
{%- if message.content is string %}
{{- message.content }}
{%- else %}
{%- for content in message.content %}
{%- if 'text' in content %}
{{- content.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{{- '<|im_end|>\n' }}
{%- endif %}

{%- endfor %}

{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}""")


class TestIsChatTemplatePrefixPreserving:
    """Check `is_chat_template_prefix_preserving` against hand-written templates with a known verdict."""

    def test_prefix_preserving_template(self):
        """A position-independent template on a tokenizer must be reported as prefix-preserving."""
        qwen_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM")
        qwen_tokenizer.chat_template = _UNIFORM_CHATML_TEMPLATE
        assert is_chat_template_prefix_preserving(qwen_tokenizer) is True

    def test_non_prefix_preserving_template(self):
        """A template that renders assistant turns differently by position must be reported as not preserving."""
        qwen_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM")
        qwen_tokenizer.chat_template = _POSITION_DEPENDENT_THINKING_TEMPLATE
        assert is_chat_template_prefix_preserving(qwen_tokenizer) is False

    @require_vision
    def test_prefix_preserving_template_processor(self):
        """The check also accepts a processor: the template is set on, and read from, the processor itself."""
        qwen_vl_processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration")
        qwen_vl_processor.chat_template = _UNIFORM_MULTIMODAL_TEMPLATE
        assert is_chat_template_prefix_preserving(qwen_vl_processor) is True
|
|
|
|
|
def _assert_training_template_keeps_rendering(tokenizer, messages, *, reference_messages=None, **render_kwargs):
    """Assert the training chat template renders `messages` exactly like the tokenizer's own template.

    Args:
        tokenizer: Tokenizer whose built-in template provides the reference rendering.
        messages: Conversation rendered with the training template.
        reference_messages: Conversation rendered with the built-in template; defaults to `messages`.
        **render_kwargs: Extra keyword arguments forwarded unchanged to both `apply_chat_template` calls.
    """
    if reference_messages is None:
        reference_messages = messages
    # The reference rendering is produced first, before the training template is derived.
    reference = tokenizer.apply_chat_template(reference_messages, tokenize=False, **render_kwargs)
    training_template = get_training_chat_template(tokenizer)
    candidate = tokenizer.apply_chat_template(
        messages, tokenize=False, chat_template=training_template, **render_kwargs
    )
    assert reference == candidate


def _multiply_tool_schema():
    """Return a fresh single-tool `tools` list describing a two-number `multiply` function."""
    number = {"type": "number"}
    parameters = {
        "type": "object",
        "properties": {"a": dict(number), "b": dict(number)},
        "required": ["a", "b"],
    }
    function = {"name": "multiply", "description": "Multiply two numbers.", "parameters": parameters}
    return [{"type": "function", "function": function}]


@pytest.mark.parametrize(
    "tokenizer_name",
    [
        pytest.param("trl-internal-testing/tiny-DeepseekV3ForCausalLM", id="deepseekv3"),
        pytest.param("trl-internal-testing/tiny-GemmaForCausalLM", id="gemma"),
        pytest.param("trl-internal-testing/tiny-Gemma2ForCausalLM", id="gemma2"),
        pytest.param(
            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
            id="glm4moe",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="GLM4 tokenizer requires transformers>=5.0.0",
            ),
        ),
        pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
        pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3", id="llama3"),
        pytest.param("trl-internal-testing/tiny-Phi3ForCausalLM-3", id="phi3"),
        pytest.param("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", id="qwen2.5"),
        pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
    ],
)
class TestGetTrainingChatTemplate:
    """Per-checkpoint checks on the template returned by `get_training_chat_template`.

    The `test_behavior_unchanged_*` tests compare the training template's rendering with the
    tokenizer's own template. The remaining tests cover prefix preservation and assistant masks.
    """

    def test_new_chat_template_is_prefix_preserving(self, tokenizer_name):
        """With the training template installed, the tokenizer must be prefix-preserving."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tokenizer.chat_template = get_training_chat_template(tokenizer)

        # Tool-calling support is probed after the training template has been installed.
        if not supports_tool_calling(tokenizer):
            pytest.skip("Template does not support tool calling; prefix-preservation check is not applicable.")
        assert is_chat_template_prefix_preserving(tokenizer) is True

    def test_behavior_unchanged_single_user_no_generation_prompt(self, tokenizer_name):
        """A lone user message renders identically without a generation prompt."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [{"role": "user", "content": "What color is the sky?"}]
        _assert_training_template_keeps_rendering(tokenizer, conversation)

    def test_behavior_unchanged_single_user_with_generation_prompt(self, tokenizer_name):
        """A lone user message renders identically with a generation prompt."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [{"role": "user", "content": "What color is the sky?"}]
        _assert_training_template_keeps_rendering(tokenizer, conversation, add_generation_prompt=True)

    def test_behavior_unchanged_single_user_and_final_assistant_plain_content(self, tokenizer_name):
        """A user turn followed by a plain-text assistant turn renders identically."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        _assert_training_template_keeps_rendering(tokenizer, conversation)

    def test_behavior_unchanged_final_assistant_with_reasoning_content(self, tokenizer_name):
        """A final assistant turn carrying a `reasoning_content` field renders identically."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        assistant_turn = {
            "role": "assistant",
            "content": "It is blue.",
            "reasoning_content": "The sky appears blue due to Rayleigh scattering.",
        }
        conversation = [{"role": "user", "content": "What color is the sky?"}, assistant_turn]
        _assert_training_template_keeps_rendering(tokenizer, conversation)

    def test_behavior_unchanged_final_assistant_with_existing_think_tags(self, tokenizer_name):
        """A final assistant turn whose content already embeds `<think>` tags renders identically."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        assistant_turn = {
            "role": "assistant",
            "content": "<think>\nThe sky scatters shorter wavelengths.\n</think>\n\nIt is blue.",
        }
        conversation = [{"role": "user", "content": "What color is the sky?"}, assistant_turn]
        _assert_training_template_keeps_rendering(tokenizer, conversation)

    def test_behavior_unchanged_assistant_with_tool_calls(self, tokenizer_name):
        """An assistant turn with content plus a tool call renders identically."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        multiply_call = {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}
        conversation = [
            {"role": "user", "content": "Multiply 3 by 4."},
            {"role": "assistant", "content": "I will call a tool.", "tool_calls": [multiply_call]},
        ]

        reference_conversation = copy.deepcopy(conversation)
        if tokenizer_name == "trl-internal-testing/tiny-DeepseekV3ForCausalLM":
            # For this checkpoint only, the built-in template is given the arguments as a JSON
            # string, while the training template still gets the dict.
            # NOTE(review): presumably the stock DeepSeek-V3 template cannot render dict
            # arguments; confirm against the upstream template.
            reference_conversation[1]["tool_calls"][0]["function"]["arguments"] = '{"a": 3, "b": 4}'

        _assert_training_template_keeps_rendering(
            tokenizer, conversation, reference_messages=reference_conversation
        )

    def test_behavior_unchanged_with_tools_with_and_without_system_message(self, tokenizer_name):
        """Passing `tools` with a lone user message (no system message) renders identically.

        Despite its name, this test only builds a conversation without a system message.
        """
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        tool_schema = _multiply_tool_schema()
        conversation = [{"role": "user", "content": "Multiply 3 by 4."}]
        _assert_training_template_keeps_rendering(tokenizer, conversation, tools=tool_schema)

    def test_behavior_unchanged_with_tools_with_system_message(self, tokenizer_name):
        """Passing `tools` alongside a system + user conversation renders identically."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        if not supports_tool_calling(tokenizer):
            pytest.skip("Template does not support tool calling; skipping tool_calls test.")
        tool_schema = _multiply_tool_schema()
        conversation = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Multiply 3 by 4."},
        ]
        _assert_training_template_keeps_rendering(tokenizer, conversation, tools=tool_schema)

    def test_behavior_unchanged_generation_prompt_with_enable_thinking_false(self, tokenizer_name):
        """A generation prompt rendered with `enable_thinking=False` is identical."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [{"role": "user", "content": "What color is the sky?"}]
        _assert_training_template_keeps_rendering(
            tokenizer,
            conversation,
            add_generation_prompt=True,
            enable_thinking=False,
        )

    def test_assistant_masks(self, tokenizer_name):
        """The training template yields an assistant mask that starts at 0 and ends at 1."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        encoded = tokenizer.apply_chat_template(
            conversation,
            chat_template=get_training_chat_template(tokenizer),
            return_assistant_tokens_mask=True,
            return_dict=True,
        )
        assistant_mask = encoded["assistant_masks"]

        # At least one position must be flagged as assistant.
        assert 1 in assistant_mask
        # The first position is not flagged as assistant.
        assert assistant_mask[0] == 0
        # The last position is flagged as assistant.
        assert assistant_mask[-1] == 1

    def test_assistant_masks_multi_turn(self, tokenizer_name):
        """Two assistant turns give exactly three value changes in the mask."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        conversation = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "Bye"},
            {"role": "assistant", "content": "Goodbye!"},
        ]
        encoded = tokenizer.apply_chat_template(
            conversation,
            chat_template=get_training_chat_template(tokenizer),
            return_assistant_tokens_mask=True,
            return_dict=True,
        )
        assistant_mask = encoded["assistant_masks"]

        # Count adjacent positions whose mask values differ.
        flips = sum(1 for previous, current in zip(assistant_mask, assistant_mask[1:]) if previous != current)
        assert flips == 3
|
|
|
|
|
@pytest.mark.parametrize(
    "model_name",
    [
        pytest.param("trl-internal-testing/tiny-Glm4MoeForCausalLM", id="glm4moe"),
        pytest.param("trl-internal-testing/tiny-GptOssForCausalLM", id="gptoss"),
        pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.1", id="llama3.1"),
        pytest.param("trl-internal-testing/tiny-LlamaForCausalLM-3.2", id="llama3.2"),
        pytest.param("trl-internal-testing/tiny-Qwen3MoeForCausalLM", id="qwen3"),
        pytest.param("trl-internal-testing/tiny-Qwen3VLForConditionalGeneration", id="qwen3_vl"),
        pytest.param("trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration", id="qwen35"),
        pytest.param(
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            id="gemma4",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.5.0"),
                reason="Gemma4 models were introduced in transformers-5.5.0",
            ),
        ),
    ],
)
@pytest.mark.xfail(
    condition=Version(transformers.__version__) < Version("5.0.0"),
    reason="Response parsing is not supported in transformers versions below 5.0.0",
    strict=True,
)
@require_jmespath
class TestParseResponse:
    """Round-trip tests: render an assistant turn to token ids and check `parse_response` recovers it.

    Most tests tokenize a two-message conversation, cut off the generation-prompt prefix, and
    expect `parse_response` to return the original assistant message dict.
    """

    def _load(self, model_name):
        """Load the tokenizer or processor for `model_name` and make sure it has a response schema.

        The loader is picked from the checkpoint name: "ForCausalLM" uses `AutoTokenizer`,
        "ForConditionalGeneration" uses `AutoProcessor`. `self.is_vlm` is set accordingly and is
        read later by the tests.

        Raises:
            ValueError: If `model_name` contains neither marker. Without this, the method would
                fail later with an unbound-variable error.
        """
        if "ForCausalLM" in model_name:
            self.is_vlm = False
            processing_class = AutoTokenizer.from_pretrained(model_name)
            response_schema = getattr(processing_class, "response_schema", None)
        elif "ForConditionalGeneration" in model_name:
            self.is_vlm = True
            processing_class = AutoProcessor.from_pretrained(model_name)
            # For processors, the schema is looked up on the wrapped tokenizer.
            response_schema = getattr(processing_class.tokenizer, "response_schema", None)
        else:
            raise ValueError(
                f"Cannot pick a processing class for {model_name!r}: expected the name to contain "
                "'ForCausalLM' or 'ForConditionalGeneration'."
            )

        # Only attach a schema when the checkpoint does not already provide one.
        if response_schema is None:
            processing_class = add_response_schema(processing_class)

        return processing_class

    def _response_ids(self, processing_class, messages, **prefix_kwargs):
        """Return the token ids that the last message adds after the generation prompt.

        Args:
            processing_class: Object returned by `_load` (it must have been called first, since
                this method reads `self.is_vlm`).
            messages: Conversation whose first message is the prompt.
            **prefix_kwargs: Extra keyword arguments applied only to the prompt rendering.
        """
        if self.is_vlm:
            messages = prepare_multimodal_messages(messages)
        prefix = processing_class.apply_chat_template(
            messages[:1], add_generation_prompt=True, tokenize=True, return_dict=True, **prefix_kwargs
        ).input_ids
        text = processing_class.apply_chat_template(messages, tokenize=True, return_dict=True).input_ids
        if self.is_vlm:
            # For processors, the first element of `input_ids` is taken from both renderings.
            prefix = prefix[0]
            text = text[0]
        return text[len(prefix) :]

    def test_parse_response(self, model_name):
        """A plain-text assistant answer round-trips."""
        processing_class = self._load(model_name)
        messages = [
            {"role": "user", "content": "What is 3*4?"},
            {"role": "assistant", "content": "12"},
        ]
        expected = messages[-1]
        response = self._response_ids(processing_class, messages)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_with_reasoning_content(self, model_name):
        """An assistant answer with `reasoning_content` round-trips (skipped for some checkpoints)."""
        if model_name in (
            "trl-internal-testing/tiny-Gemma4ForConditionalGeneration",
            "trl-internal-testing/tiny-GptOssForCausalLM",
            "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
            "trl-internal-testing/tiny-Qwen3VLForConditionalGeneration",
        ):
            pytest.skip("This tokenizer doesn't support inline reasoning_content.")

        processing_class = self._load(model_name)
        messages = [
            {"role": "user", "content": "What is 3*4?"},
            {"role": "assistant", "reasoning_content": "Hmmm.", "content": "12"},
        ]
        expected = messages[-1]
        # Only the prompt rendering is given `enable_thinking=True`; the full rendering is not.
        response = self._response_ids(processing_class, messages, enable_thinking=True)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_tool_call(self, model_name):
        """An assistant turn with empty content and a single tool call round-trips."""
        processing_class = self._load(model_name)
        tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
        messages = [
            {"role": "user", "content": "What is 3*4?"},
            {
                "role": "assistant",
                # Content is given explicitly as an empty string rather than omitted.
                "content": "",
                "tool_calls": tool_calls,
            },
        ]
        expected = messages[-1]
        response = self._response_ids(processing_class, messages)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_tool_call_with_content(self, model_name):
        """An assistant turn with both text content and a tool call round-trips."""
        if model_name in (
            "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        ):
            pytest.skip("Llama 3.1 / 3.2 templates only allow a single tool call per assistant turn, with no content.")
        processing_class = self._load(model_name)
        tool_calls = [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}]
        messages = [
            {"role": "user", "content": "What is 3*4?"},
            {"role": "assistant", "content": "Let's call the tool.", "tool_calls": tool_calls},
        ]
        expected = messages[-1]
        response = self._response_ids(processing_class, messages)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_tool_call_without_arguments(self, model_name):
        """A tool call whose `arguments` is an empty dict round-trips."""
        processing_class = self._load(model_name)
        tool_calls = [{"type": "function", "function": {"name": "ping", "arguments": {}}}]
        messages = [
            {"role": "user", "content": "Ping the service."},
            {
                "role": "assistant",
                # Content is given explicitly as an empty string rather than omitted.
                "content": "",
                "tool_calls": tool_calls,
            },
        ]
        expected = messages[-1]
        response = self._response_ids(processing_class, messages)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_multiple_tool_calls(self, model_name):
        """An assistant turn carrying two tool calls round-trips (skipped for some checkpoints)."""
        if model_name in (
            "trl-internal-testing/tiny-GptOssForCausalLM",
            "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
            "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        ):
            pytest.skip("This template only renders one tool call per assistant message.")
        processing_class = self._load(model_name)
        tool_calls = [
            {"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}},
            {"type": "function", "function": {"name": "addition", "arguments": {"a": 4, "b": 3}}},
        ]
        messages = [
            {"role": "user", "content": "What is 3*4?"},
            {
                "role": "assistant",
                # Content is given explicitly as an empty string rather than omitted.
                "content": "",
                "tool_calls": tool_calls,
            },
        ]
        expected = messages[-1]
        response = self._response_ids(processing_class, messages)
        parsed = parse_response(processing_class, response)
        assert parsed == expected

    def test_parse_response_malformed_tool_call(self, model_name):
        """A tool call with unbalanced JSON is returned verbatim as plain content."""
        if model_name != "trl-internal-testing/tiny-Qwen3MoeForCausalLM":
            pytest.skip("For simplicity, we only test the malformed tool call case on one tokenizer.")
        processing_class = self._load(model_name)
        # The JSON payload is missing its final closing brace.
        text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call><|im_end|>'
        assistant_text = processing_class(text)["input_ids"]
        parsed = parse_response(processing_class, assistant_text)
        # Expected result: no `tool_calls` key, and the content is the input without `<|im_end|>`.
        expected = {
            "role": "assistant",
            "content": '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n</tool_call>',
        }
        assert parsed == expected
|
|
|