# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for the dataset / chat-message helpers exposed by `trl.data_utils`."""

import copy
import textwrap
from time import strftime

import pytest
import transformers
from datasets import Dataset, DatasetDict
from packaging.version import Version
from transformers import AutoProcessor, AutoTokenizer, is_vision_available

from trl.data_utils import (
    apply_chat_template,
    extract_prompt,
    is_conversational,
    is_conversational_from_value,
    maybe_apply_chat_template,
    maybe_convert_to_chatml,
    maybe_extract_prompt,
    maybe_unpair_preference_dataset,
    pack_dataset,
    prepare_multimodal_messages,
    prepare_multimodal_messages_vllm,
    unpair_preference_dataset,
)

from .testing_utils import TrlTestCase, require_vision


# PIL is only importable when vision extras are installed; the classes that use `Image` are gated by
# `@require_vision`.
if is_vision_available():
    from PIL import Image


@require_vision
class TestPrepareMultimodalMessages:
    """Tests for `prepare_multimodal_messages`.

    As the expectations below show: string `content` is wrapped as `[{"type": "text", "text": ...}]`;
    images passed via `images=` are inserted as `{"type": "image", "image": ...}` blocks ahead of the text of
    the first user message only (or used to fill an existing bare `{"type": "image"}` placeholder); and an
    assistant message that carries only `tool_calls` is left as-is.
    """

    def test_basic_user_assistant_conversation(self):
        """Test basic conversation with user and assistant messages."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        image = Image.new("RGB", (10, 10), color="blue")
        messages = prepare_multimodal_messages(messages, images=[image])

        expected = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
        ]

        assert messages == expected

    def test_first_user_message_gets_image(self):
        """Test that only the first user message gets an image."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
            {"role": "user", "content": "How about the grass?"},
        ]

        image = Image.new("RGB", (10, 10), color="blue")
        messages = prepare_multimodal_messages(messages, images=[image])

        expected = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": "How about the grass?"}],
            },
        ]

        assert messages == expected

    def test_multiple_images(self):
        """Test that multiple images are added to the first user message."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]
        images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]]
        messages = prepare_multimodal_messages(messages, images=images)

        # All images precede the text, in the order they were passed.
        expected = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": images[0]},
                    {"type": "image", "image": images[1]},
                    {"type": "image", "image": images[2]},
                    {"type": "text", "text": "What color is the sky?"},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
        ]

        assert messages == expected

    def test_system_message_transformation(self):
        """Test that system messages are properly transformed."""
        messages = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "What color is the sky?"},
        ]

        image = Image.new("RGB", (10, 10), color="blue")
        messages = prepare_multimodal_messages(messages, images=[image])

        # The system message is text-only; the image still goes to the first *user* message.
        expected = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant"}],
            },
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
        ]

        assert messages == expected

    def test_already_prepared_messages_unchanged(self):
        """Test that messages with list content keep their structure.

        The only change expected is that the bare `{"type": "image"}` placeholder receives the image passed via
        `images=`.
        """
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        image = Image.new("RGB", (10, 10), color="blue")
        messages = prepare_multimodal_messages(messages, images=[image])

        expected = [
            {
                "role": "system",
                "content": [{"type": "text", "text": "You are a helpful assistant"}],
            },
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
        ]

        assert messages == expected

    def test_mixed_prepared_and_unprepared_messages(self):
        """Test handling of mixed prepared and unprepared messages."""
        messages = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
            {"role": "user", "content": "What about the grass?"},
        ]

        image = Image.new("RGB", (10, 10), color="blue")
        messages = prepare_multimodal_messages(messages, images=[image])

        expected = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
            {
                "role": "user",
                "content": [{"type": "text", "text": "What about the grass?"}],
            },
        ]

        assert messages == expected

    def test_message_with_tool_calling_turns(self):
        """Test that both the assistant tool call and the tool role turns messages are properly transformed."""
        messages = [
            {"role": "user", "content": "What's the weather like in New York?"},
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "tool",
                        "function": {"name": "get_current_weather", "arguments": {"location": "New York"}},
                    }
                ],
            },
            {"role": "tool", "name": "get_current_weather", "content": "22.0"},
            {"role": "assistant", "content": "The current weather in New York is 22.0 degrees Celsius."},
        ]

        # No `images=` argument here: only the text wrapping is exercised.
        messages = prepare_multimodal_messages(messages)

        # The tool-call-only assistant message (no `content` key) is expected to come back unchanged, while the
        # `tool` message's string content is wrapped like any other text.
        expected = [
            {
                "role": "user",
                "content": [{"type": "text", "text": "What's the weather like in New York?"}],
            },
            {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "tool",
                        "function": {"name": "get_current_weather", "arguments": {"location": "New York"}},
                    }
                ],
            },
            {"role": "tool", "name": "get_current_weather", "content": [{"type": "text", "text": "22.0"}]},
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "The current weather in New York is 22.0 degrees Celsius."}],
            },
        ]

        assert messages == expected

    def test_prepared_image_blocks_without_new_images(self):
        """Test that existing image payloads are preserved when no new images are provided."""
        image = Image.new("RGB", (10, 10), color="blue")
        messages = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {"role": "assistant", "content": "It is blue."},
        ]

        messages = prepare_multimodal_messages(messages)

        expected = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        assert messages == expected


@require_vision
class TestPrepareMultimodalMessagesVLLM:
    """Tests for `prepare_multimodal_messages_vllm`.

    Per these tests, the function returns a new (deep-copied) message list in which every
    `{"type": "image", "image": <PIL image>}` block becomes `{"type": "image_pil", "image_pil": <PIL image>}`.
    Text blocks are untouched and the input list is not mutated.
    """

    def test_single_image_conversion(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                    {"type": "text", "text": "What color is the sky?"},
                ],
            }
        ]

        result = prepare_multimodal_messages_vllm(messages)

        # Original should remain unchanged (deepcopy test)
        assert messages[0]["content"][0]["type"] == "image"

        # Converted version should have correct structure
        assert result[0]["content"][0]["type"] == "image_pil"
        assert "image_pil" in result[0]["content"][0]
        assert "image" not in result[0]["content"][0]
        assert isinstance(result[0]["content"][0]["image_pil"], Image.Image)
        assert result[0]["content"][1]["type"] == "text"

    def test_mixed_content_conversion(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            }
        ]

        result = prepare_multimodal_messages_vllm(messages)

        # The image part should be converted, text should be unchanged
        assert result[0]["content"][0]["type"] == "text"
        assert result[0]["content"][1]["type"] == "image_pil"

    def test_no_images(self):
        messages = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}]

        result = prepare_multimodal_messages_vllm(messages)

        # Should be identical since there are no images
        assert result == messages
        # And a deepcopy — not the same object
        assert result is not messages
        assert result[0] is not messages[0]

    def test_multiple_messages(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
        ]

        result = prepare_multimodal_messages_vllm(messages)

        assert result[0]["content"][1]["type"] == "image_pil"
        assert result[1]["content"][0]["type"] == "text"
        assert result[1]["content"][0]["text"] == "It is blue."

    def test_deepcopy_integrity(self):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            },
        ]
        # Snapshot taken before the call so any in-place mutation would show up as a difference.
        original = copy.deepcopy(messages)

        _ = prepare_multimodal_messages_vllm(messages)

        # Original should not be mutated
        assert messages == original


class TestIsConversational(TrlTestCase):
    """`is_conversational` must be true for every dataset format whose fields hold lists of role/content
    messages (including tool calls and harmony-style `thinking` fields) and false for plain-string formats."""

    # fmt: off
    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Preference with tool calls
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [
                {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "get_color", "arguments": {"what": "sky"}}}]},
                {"role": "tool", "name": "get_color", "content": "blue"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "tool_calls": [{"type": "function", "function": {"name": "get_color", "arguments": {"what": "tree"}}}]},
                {"role": "tool", "name": "get_color", "content": "green"},
                {"role": "assistant", "content": "It is green."},
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "description": "Gets the color.",
                        "name": "get_color",
                        "parameters": {"properties": {"what": {"description": "What to get the color of.", "type": "string"}}, "required": ["what"], "type": "object"},
                        "return": {"description": "The color.", "type": "string"},
                    },
                },
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
        {  # Language modeling with harmony
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Prompt-only with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        },
        {  # Prompt-completion with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        },
        {  # Preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Preference with implicit prompt and harmony
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        },
        {  # Unpaired preference with harmony
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        },
    ]
    # fmt: on

    # Same dataset formats, but with plain strings instead of message lists.
    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @pytest.mark.parametrize("example", conversational_examples)
    def test_conversational(self, example):
        assert is_conversational(example)

    @pytest.mark.parametrize("example", non_conversational_examples)
    def test_non_conversational(self, example):
        assert not is_conversational(example)


class TestIsConversationalFromValue(TrlTestCase):
    """`is_conversational_from_value` should accept messages keyed by `from`/`value` and reject both
    `role`/`content` messages and plain-text examples."""

    def test_positive_1(self):
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ],
        }
        assert is_conversational_from_value(example)

    def test_negative_1(self):
        # `role`/`content` messages are conversational, but not in the `from`/`value` form.
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        }
        assert not is_conversational_from_value(example)

    def test_negative_2(self):
        example = {"text": "The sky is blue."}
        assert not is_conversational_from_value(example)


class TestApplyChatTemplate(TrlTestCase):
    """Smoke tests for `apply_chat_template` / `maybe_apply_chat_template` across many tokenizers.

    The parametrized tests check only output types and keys (templated fields become strings, `messages`
    becomes `text`, `label` is carried through). They do not check exact rendered text.
    """

    # Tiny tokenizer repos whose chat templates are exercised. Entries wrapped in `pytest.param` are skipped on
    # transformers < 5.0.0, as their `reason` strings state.
    tokenizers = [
        "trl-internal-testing/tiny-CohereForCausalLM",
        "trl-internal-testing/tiny-Cohere2ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
        "trl-internal-testing/tiny-FalconMambaForCausalLM",
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "trl-internal-testing/tiny-GemmaForCausalLM",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        pytest.param(
            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="GLM4 tokenizer requires transformers>=5.0.0",
            ),
        ),
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-LlamaForCausalLM-3",
        "trl-internal-testing/tiny-MistralForCausalLM-0.1",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "trl-internal-testing/tiny-Phi3ForCausalLM-3",
        "trl-internal-testing/tiny-Phi3ForCausalLM-3.5",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "trl-internal-testing/tiny-Qwen3ForCausalLM",
        pytest.param(
            "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
            ),
        ),
    ]

    conversational_examples = [
        {  # Language modeling
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {  # Prompt-only
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {  # Prompt-completion
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {  # Preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {  # Preference with implicit prompt
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {  # Unpaired preference
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    non_conversational_examples = [
        {"text": "The sky is blue."},  # Language modeling
        {"prompt": "The sky is"},  # Prompt-only
        {"prompt": "The sky is", "completion": " blue."},  # Prompt-completion
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},  # Preference
        {"chosen": "The sky is blue.", "rejected": "The sky is green."},  # Preference with implicit prompt
        {"prompt": "The sky is", "completion": " blue.", "label": True},  # Unpaired preference
    ]

    @pytest.mark.parametrize("example", conversational_examples)
    @pytest.mark.parametrize("tokenizer_id", tokenizers)
    def test_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = apply_chat_template(example, tokenizer)

        # Checking if the result is a dictionary
        assert isinstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                assert key in result
                assert isinstance(result[key], str)

        # Exception for messages, the key is "text" once the chat template is applied
        if "messages" in example:
            assert "text" in result
            assert isinstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            assert "label" in result
            assert isinstance(result["label"], bool)
            assert result["label"] == example["label"]

    # both conversational and non-conversational examples
    @pytest.mark.parametrize("example", conversational_examples + non_conversational_examples)
    @pytest.mark.parametrize("tokenizer_id", tokenizers)
    def test_maybe_apply_chat_template(self, tokenizer_id, example):
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        result = maybe_apply_chat_template(example, tokenizer)

        # Checking if the result is a dictionary
        assert isinstance(result, dict)

        # The chat template should be applied to the following keys
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                assert key in result
                assert isinstance(result[key], str)

        # Exception for messages, the key is "text" once the chat template is applied
        if "messages" in example:
            assert "text" in result
            assert isinstance(result["text"], str)

        # The label should be kept
        if "label" in example:
            assert "label" in result
            assert isinstance(result["label"], bool)
            assert result["label"] == example["label"]

    def test_apply_chat_template_with_chat_template_kwargs(self):
        """Per-example `chat_template_kwargs` should be forwarded to the tokenizer's chat template."""
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

        example = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            # with this tokenizer, when you pass enable_thinking=False, it will add "\n\n\n\n"
            # NOTE(review): in this copy of the file the quoted suffix above (and possibly the tail of `expected`
            # below) looks like it lost `<think>`/`</think>` markup, presumably "<think>\n\n</think>\n\n". Confirm
            # against the Qwen3 chat template before trusting this exact-match assertion.
            "chat_template_kwargs": {"enable_thinking": False},
        }
        result = apply_chat_template(example, tokenizer)

        # docstyle-ignore
        expected = textwrap.dedent("""\
        <|im_start|>user
        What color is the sky?<|im_end|>
        <|im_start|>assistant
        """)

        assert result["prompt"] == expected

    def test_apply_chat_template_with_tools(self):
        """Tools passed via `tools=` should appear in the rendered prompt, and be absent when `tools=None`."""
        # NOTE(review): `AutoProcessor` is used for a text-only model here while the rest of the class uses
        # `AutoTokenizer`; presumably it resolves to the tokenizer. Confirm this is intentional.
        tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

        # Define dummy test tools
        # (The docstring below is left in its Args: layout on purpose: the tool schema is presumably built from it.)
        def get_current_temperature(location: str):
            """
            Gets the temperature at a given location.

            Args:
                location: The location to get the temperature for
            """
            return 22.0

        # Define test case
        test_case = {
            "prompt": [
                {"content": "What's the temperature in London?", "role": "user"},
            ]
        }
        # Test with tools
        result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature])

        # Verify tools are included in the output
        assert "get_current_temperature" in result_with_tools["prompt"]

        # Test without tools
        result_without_tools = apply_chat_template(test_case, tokenizer, tools=None)

        # Verify tools are not included in the output
        assert "get_current_temperature" not in result_without_tools["prompt"]


class TestApplyChatTemplateHarmony(TrlTestCase):
    """Exact-output tests for `apply_chat_template` with the GPT-OSS tokenizer.

    `reasoning_effort="low"` and `model_identity="You are HuggingGPT."` are expected to appear in the rendered
    system message. The user's system message is rendered as a developer `# Instructions` block, and assistant
    `thinking` goes to the `analysis` channel before the `final` channel.

    For split formats (prompt/completion, preference), the prompt is expected to end at `<|start|>assistant`;
    the completion starts at `<|channel|>`.

    NOTE(review): the expected "Current date" is computed with `strftime` *after* `apply_chat_template` returns;
    presumably the template renders today's date, so a run straddling midnight could mismatch.
    """

    def test_language_modeling(self):
        messages = {
            "messages": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")

        assert output["text"] == expected

    def test_prompt_only(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")

        assert output["prompt"] == expected

    def test_prompt_completion(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        assert output["prompt"] == expected_prompt
        assert output["completion"] == expected_completion

    def test_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "chosen": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
        expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>"

        assert output["prompt"] == expected_prompt
        assert output["chosen"] == expected_chosen
        assert output["rejected"] == expected_rejected

    def test_preference_with_implicit_prompt(self):
        # With no explicit `prompt`, chosen/rejected are each rendered as a full conversation.
        messages = {
            "chosen": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."},
            ],
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_chosen = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""")

        # docstyle-ignore
        expected_rejected = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""")

        assert output["chosen"] == expected_chosen
        assert output["rejected"] == expected_rejected

    def test_unpaired_preference(self):
        messages = {
            "prompt": [
                {"role": "system", "content": "Respond in a friendly manner."},
                {"role": "user", "content": "What color is the sky?"},
            ],
            "completion": [
                {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."},
            ],
            "label": True,
        }
        output = apply_chat_template(
            messages,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

        # docstyle-ignore
        expected_prompt = textwrap.dedent(f"""\
        <|start|>system<|message|>You are HuggingGPT.
        Knowledge cutoff: 2024-06
        Current date: {strftime("%Y-%m-%d")}

        Reasoning: low

        # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

        Respond in a friendly manner.

        <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")
        expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"

        assert output["prompt"] == expected_prompt
        assert output["completion"] == expected_completion
        # The label is passed through untouched.
        assert output["label"]


class TestUnpairPreferenceDataset(TrlTestCase):
    """Tests for `unpair_preference_dataset` and `maybe_unpair_preference_dataset`.

    Per `unpaired_dataset`, each paired row yields a `label=True` row (its chosen completion) and a `label=False`
    row (its rejected completion), with all chosen rows listed before all rejected rows.
    """

    paired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is"],
            "chosen": [" blue.", " in the sky."],
            "rejected": [" green.", " in the sea."],
        }
    )

    unpaired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
            "label": [True, True, False, False],
        }
    )

    def test_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired
        unpaired_dataset = unpair_preference_dataset(self.paired_dataset)
        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
            "The paired dataset should be converted to unpaired."
        )

    def test_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict)
        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
            "The paired dataset should be converted to unpaired."
        )

    def test_maybe_unpair_preference_dataset(self):
        # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset
        unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset)
        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
            "The paired dataset should be converted to unpaired."
        )

    def test_maybe_unpair_preference_dataset_dict(self):
        # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset
        paired_dataset_dict = DatasetDict({"abc": self.paired_dataset})
        unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict)
        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
            "The paired dataset should be converted to unpaired."
        )

    def test_maybe_unpair_preference_dataset_already_paired(self):
        # Test that an already *unpaired* dataset remains unchanged with maybe_unpair_preference_dataset
        # (despite the "already_paired" in the method name, the input here is `self.unpaired_dataset`).
        unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset)
        assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), (
            "The unpaired dataset should remain unchanged."
        )

    def test_maybe_unpair_preference_dataset_dict_already_paired(self):
        # Test that an already *unpaired* dataset dict remains unchanged with maybe_unpair_preference_dataset
        # (despite the "already_paired" in the method name, the input wraps `self.unpaired_dataset`).
        unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
        assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), (
            "The unpaired dataset should remain unchanged."
        )


class TestExtractPrompt(TrlTestCase):
    """Tests for `extract_prompt` and `maybe_extract_prompt`.

    An implicit-prompt example (the prompt embedded at the start of both `chosen` and `rejected`) is expected to
    become the matching explicit-prompt example. An explicit example must pass through `maybe_extract_prompt`
    unchanged. Both conversational and standard (plain-string) variants are covered.
    """

    example_implicit_prompt_conversational = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_explicit_prompt_conversational = {
        "prompt": [
            {"role": "user", "content": "What color is the sky?"},
        ],
        "chosen": [
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "assistant", "content": "It is green."},
        ],
    }

    example_implicit_prompt_standard = {
        "chosen": "The sky is blue.",
        "rejected": "The sky is green.",
    }

    # Note the leading space kept on the completions: the prompt is "The sky is" without a trailing space.
    example_explicit_prompt_standard = {
        "prompt": "The sky is",
        "chosen": " blue.",
        "rejected": " green.",
    }

    def test_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational)
        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
            "The prompt is not correctly extracted from the dataset."
        )

    def test_maybe_extract_prompt_conversational(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational)
        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
            "The prompt is not correctly extracted from the dataset."
        )

    def test_maybe_extract_prompt_conversational_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational)
        assert example_extracted_prompt == self.example_explicit_prompt_conversational, (
            "The prompt should remain unchanged."
        )

    def test_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset
        example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard)
        assert example_extracted_prompt == self.example_explicit_prompt_standard, (
            "The prompt is not correctly extracted from the dataset."
        )

    def test_maybe_extract_prompt_standard(self):
        # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard)
        assert example_extracted_prompt == self.example_explicit_prompt_standard, (
            "The prompt is not correctly extracted from the dataset."
        )

    def test_maybe_extract_prompt_standard_already_explicit(self):
        # Test that the prompt remains unchanged with maybe_extract_prompt
        example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard)
        assert example_extracted_prompt == self.example_explicit_prompt_standard, "The prompt should remain unchanged."
class TestPackDatasetWrapped(TrlTestCase):
    """Tests for `pack_dataset` with ``strategy="wrapped"``.

    As the expected outputs show, all sequences are concatenated and the resulting stream is cut
    into chunks of ``seq_length`` tokens. The final chunk may be shorter. Packing must also leave
    the dataset's format unchanged.
    """

    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples)
        dataset = dataset.with_format("numpy", dtype="float32")
        # Snapshot the format before packing so we can check it is preserved.
        original_format = dataset.format
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        assert dataset.to_dict() == expected_output
        assert original_format == dataset.format

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        dataset = dataset.with_format("numpy")
        formatting = dataset._formatting
        seq_length = 3
        expected_output = {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="wrapped")
        # Read everything back as a single batch of plain Python lists.
        num_examples = len(examples[next(iter(examples))])
        assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
        assert formatting == dataset._formatting


class TestPackDatasetBfd(TrlTestCase):
    """Tests for `pack_dataset` with ``strategy="bfd"`` and ``strategy="bfd_split"``.

    Per the expected outputs, sequences are packed into rows of at most ``seq_length`` tokens,
    with longer sequences placed first (presumably best-fit decreasing, as the name suggests).
    A ``seq_lengths`` column records the length of each sequence inside a packed row.

    - With ``"bfd"``, tokens beyond ``seq_length`` are discarded.
    - With ``"bfd_split"``, those tokens are split off and packed as a separate sequence.
    """

    def test_with_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
        }
        dataset = Dataset.from_dict(examples)
        dataset = dataset.with_format("numpy", dtype="float32")
        original_format = dataset.format
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        assert dataset.to_dict() == expected_output
        # Packing adds a "seq_lengths" column. Apart from that, the format must be unchanged.
        # Work on copies: the "columns" list in `dataset.format` may be the dataset's own list,
        # and mutating it would alter the dataset under test.
        packed_format = dict(dataset.format)
        packed_columns = list(packed_format["columns"])
        assert "seq_lengths" in packed_columns
        packed_columns.remove("seq_lengths")
        packed_format["columns"] = packed_columns
        assert original_format == packed_format

    def test_with_iterable_dataset(self):
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        dataset = dataset.with_format("numpy")
        formatting = dataset._formatting
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        # Read everything back as a single batch of plain Python lists.
        num_examples = len(examples[next(iter(examples))])
        assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
        assert formatting == dataset._formatting

    def test_with_overlong_0(self):
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        # The overflowing token 5 is split off and packed together with [6, 7] and [12].
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
            "seq_lengths": [[4], [4], [2, 1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_with_overlong_two_columns(self):
        # Every column must be split and packed identically.
        examples = {
            "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
            "col2": [[-1, 2, -3, 4, -5, 6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
            "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
            "seq_lengths": [[4], [4], [3], [3], [2]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_with_non_power_of_2(self):
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6], [7, 8, 9, 10], [11, 12, 13]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 5
        expected_output = {
            "input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
            "seq_lengths": [[5], [4, 1], [3]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_default_no_split(self):
        """Test the non-splitting 'bfd' strategy: tokens beyond ``seq_length`` are discarded."""
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        # With the 'bfd' strategy, overflow tokens are discarded (token 5 is dropped).
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
            "seq_lengths": [[4], [4], [2, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        assert dataset.to_dict() == expected_output

    def test_with_empty_sequences(self):
        # Empty sequences contribute nothing and must not appear in seq_lengths.
        examples = {
            "input_ids": [[1, 2], [], [3, 4, 5], [], [6]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[3, 4, 5, 6], [1, 2]],
            "seq_lengths": [[3, 1], [2]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output


class TestMaybeConvertToChatML(TrlTestCase):
    """Tests for `maybe_convert_to_chatml`.

    Messages using ``{"from", "value"}`` keys are rewritten to ``{"role", "content"}``, and a
    top-level ``conversations`` key is renamed to ``messages``. Anything else is left unchanged.
    """

    def test_with_conversations_key(self):
        # Particular case where the key is "conversations": we rename it to "messages"
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ]
        }
        expected_output = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        assert maybe_convert_to_chatml(example) == expected_output

    def test_without_conversations_key(self):
        # Same as before, but we don't rename the keys
        example = {
            "prompt": [{"from": "user", "value": "What color is the sky?"}],
            "completion": [{"from": "assistant", "value": "It is blue."}],
        }
        expected_output = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        }
        assert maybe_convert_to_chatml(example) == expected_output

    def test_not_conversational(self):
        # When not needed, the example should remain unchanged
        example = {"text": "The sky is blue."}
        assert maybe_convert_to_chatml(example) == example

    def test_already_chatml(self):
        # When the example is already in ChatML format, it should remain unchanged
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        assert maybe_convert_to_chatml(example) == example