|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import copy
|
| import textwrap
|
| from time import strftime
|
|
|
| import pytest
|
| import transformers
|
| from datasets import Dataset, DatasetDict
|
| from packaging.version import Version
|
| from transformers import AutoProcessor, AutoTokenizer, is_vision_available
|
|
|
| from trl.data_utils import (
|
| apply_chat_template,
|
| extract_prompt,
|
| is_conversational,
|
| is_conversational_from_value,
|
| maybe_apply_chat_template,
|
| maybe_convert_to_chatml,
|
| maybe_extract_prompt,
|
| maybe_unpair_preference_dataset,
|
| pack_dataset,
|
| prepare_multimodal_messages,
|
| prepare_multimodal_messages_vllm,
|
| unpair_preference_dataset,
|
| )
|
|
|
| from .testing_utils import TrlTestCase, require_vision
|
|
|
|
|
| if is_vision_available():
|
| from PIL import Image
|
|
|
|
|
@require_vision
class TestPrepareMultimodalMessages:
    """Exercise `prepare_multimodal_messages` on plain, pre-structured and tool-calling conversations."""

    @staticmethod
    def _text(value):
        """Return the structured text block the tests expect for `value`."""
        return {"type": "text", "text": value}

    @staticmethod
    def _picture(image):
        """Return the structured image block the tests expect to carry `image`."""
        return {"type": "image", "image": image}

    @staticmethod
    def _blue_square():
        """Create the 10x10 blue RGB image used as visual input."""
        return Image.new("RGB", (10, 10), color="blue")

    def test_basic_user_assistant_conversation(self):
        """Test basic conversation with user and assistant messages."""
        image = self._blue_square()
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]

        prepared = prepare_multimodal_messages(conversation, images=[image])

        # The image block is expected ahead of the user's text block.
        assert prepared == [
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]

    def test_first_user_message_gets_image(self):
        """Test that only the first user message gets an image."""
        image = self._blue_square()
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
            {"role": "user", "content": "How about the grass?"},
        ]

        prepared = prepare_multimodal_messages(conversation, images=[image])

        assert prepared == [
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
            {"role": "user", "content": [self._text("How about the grass?")]},
        ]

    def test_multiple_images(self):
        """Test that multiple images are added to the first user message."""
        images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]]
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ]

        prepared = prepare_multimodal_messages(conversation, images=images)

        # All three images, in input order, precede the text block.
        first_turn_content = [self._picture(img) for img in images] + [self._text("What color is the sky?")]
        assert prepared == [
            {"role": "user", "content": first_turn_content},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]

    def test_system_message_transformation(self):
        """Test that system messages are properly transformed."""
        image = self._blue_square()
        conversation = [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "What color is the sky?"},
        ]

        prepared = prepare_multimodal_messages(conversation, images=[image])

        assert prepared == [
            {"role": "system", "content": [self._text("You are a helpful assistant")]},
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
        ]

    def test_already_prepared_messages_unchanged(self):
        """Test that list-content messages keep their structure; the bare image placeholder receives the image."""
        image = self._blue_square()
        conversation = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]},
            {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
        ]

        prepared = prepare_multimodal_messages(conversation, images=[image])

        assert prepared == [
            {"role": "system", "content": [self._text("You are a helpful assistant")]},
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]

    def test_mixed_prepared_and_unprepared_messages(self):
        """Test handling of mixed prepared and unprepared messages."""
        image = self._blue_square()
        conversation = [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]},
            {"role": "user", "content": "What about the grass?"},
        ]

        prepared = prepare_multimodal_messages(conversation, images=[image])

        assert prepared == [
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
            {"role": "user", "content": [self._text("What about the grass?")]},
        ]

    def test_message_with_tool_calling_turns(self):
        """Test that both the assistant tool call and the tool role turns messages are properly transformed."""

        def weather_call_turn():
            # Fresh objects per call so the expectation never aliases the input.
            return {
                "role": "assistant",
                "tool_calls": [
                    {
                        "type": "tool",
                        "function": {"name": "get_current_weather", "arguments": {"location": "New York"}},
                    }
                ],
            }

        conversation = [
            {"role": "user", "content": "What's the weather like in New York?"},
            weather_call_turn(),
            {"role": "tool", "name": "get_current_weather", "content": "22.0"},
            {"role": "assistant", "content": "The current weather in New York is 22.0 degrees Celsius."},
        ]

        prepared = prepare_multimodal_messages(conversation)

        # The content-less tool-call turn is expected back as-is.
        assert prepared == [
            {"role": "user", "content": [self._text("What's the weather like in New York?")]},
            weather_call_turn(),
            {"role": "tool", "name": "get_current_weather", "content": [self._text("22.0")]},
            {
                "role": "assistant",
                "content": [self._text("The current weather in New York is 22.0 degrees Celsius.")],
            },
        ]

    def test_prepared_image_blocks_without_new_images(self):
        """Test that existing image payloads are preserved when no new images are provided."""
        image = self._blue_square()
        conversation = [
            {
                "role": "user",
                "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}],
            },
            {"role": "assistant", "content": "It is blue."},
        ]

        prepared = prepare_multimodal_messages(conversation)

        assert prepared == [
            {"role": "user", "content": [self._picture(image), self._text("What color is the sky?")]},
            {"role": "assistant", "content": [self._text("It is blue.")]},
        ]
|
|
|
|
|
@require_vision
class TestPrepareMultimodalMessagesVLLM:
    """Exercise `prepare_multimodal_messages_vllm`, checking the `image` -> `image_pil` block rewrite."""

    def test_single_image_conversion(self):
        """An image block is returned as an `image_pil` block while the input keeps its `image` type."""
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                    {"type": "text", "text": "What color is the sky?"},
                ],
            }
        ]

        converted = prepare_multimodal_messages_vllm(conversation)

        # The caller's messages must not have been rewritten.
        assert conversation[0]["content"][0]["type"] == "image"

        # The returned copy carries the payload under the `image_pil` key only.
        image_block = converted[0]["content"][0]
        assert image_block["type"] == "image_pil"
        assert "image_pil" in image_block
        assert "image" not in image_block
        assert isinstance(image_block["image_pil"], Image.Image)
        assert converted[0]["content"][1]["type"] == "text"

    def test_mixed_content_conversion(self):
        """Block order is kept when the text block precedes the image block."""
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            }
        ]

        converted = prepare_multimodal_messages_vllm(conversation)

        user_content = converted[0]["content"]
        assert user_content[0]["type"] == "text"
        assert user_content[1]["type"] == "image_pil"

    def test_no_images(self):
        """Text-only messages come back equal to the input but as distinct objects."""
        conversation = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}]

        converted = prepare_multimodal_messages_vllm(conversation)

        # Same value...
        assert converted == conversation

        # ...but neither the list nor its first message is the caller's object.
        assert converted is not conversation
        assert converted[0] is not conversation[0]

    def test_multiple_messages(self):
        """Only the image block is rewritten; the assistant turn keeps its text block."""
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": "It is blue."}],
            },
        ]

        converted = prepare_multimodal_messages_vllm(conversation)

        assert converted[0]["content"][1]["type"] == "image_pil"
        assistant_block = converted[1]["content"][0]
        assert assistant_block["type"] == "text"
        assert assistant_block["text"] == "It is blue."

    def test_deepcopy_integrity(self):
        """The input messages compare equal to a snapshot taken before the call."""
        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "What color is the sky?"},
                    {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")},
                ],
            },
        ]
        snapshot = copy.deepcopy(conversation)

        _ = prepare_multimodal_messages_vllm(conversation)

        # Any in-place edit of the input would make it differ from the snapshot.
        assert conversation == snapshot
|
|
|
|
|
class TestIsConversational(TrlTestCase):
    """Check `is_conversational` against message-list examples and plain-string examples."""

    # Message fragments shared by the examples below; the tests only pass them to the predicate.
    _SYSTEM = {"role": "system", "content": "Respond in a friendly manner."}
    _QUESTION = {"role": "user", "content": "What color is the sky?"}
    _BLUE = {"role": "assistant", "content": "It is blue."}
    _GREEN = {"role": "assistant", "content": "It is green."}
    _BLUE_WITH_THINKING = {
        "role": "assistant",
        "thinking": "The user asks the color of the sky...",
        "content": "It is blue.",
    }
    _GREEN_WITH_THINKING = {
        "role": "assistant",
        "thinking": "The user asks the color of the tree...",
        "content": "It is green.",
    }

    # Every layout here must be reported as conversational.
    conversational_examples = [
        # Without a system turn.
        {"messages": [_QUESTION, _BLUE]},
        {"prompt": [_QUESTION]},
        {"prompt": [_QUESTION], "completion": [_BLUE]},
        {"prompt": [_QUESTION], "chosen": [_BLUE], "rejected": [_GREEN]},
        {"chosen": [_QUESTION, _BLUE], "rejected": [_QUESTION, _GREEN]},
        # Preference pair whose completions start with a tool call, plus the tool schema.
        {
            "prompt": [_QUESTION],
            "chosen": [
                {
                    "role": "assistant",
                    "tool_calls": [
                        {"type": "function", "function": {"name": "get_color", "arguments": {"what": "sky"}}}
                    ],
                },
                {"role": "tool", "name": "get_color", "content": "blue"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {
                    "role": "assistant",
                    "tool_calls": [
                        {"type": "function", "function": {"name": "get_color", "arguments": {"what": "tree"}}}
                    ],
                },
                {"role": "tool", "name": "get_color", "content": "green"},
                {"role": "assistant", "content": "It is green."},
            ],
            "tools": [
                {
                    "type": "function",
                    "function": {
                        "description": "Gets the color.",
                        "name": "get_color",
                        "parameters": {
                            "properties": {"what": {"description": "What to get the color of.", "type": "string"}},
                            "required": ["what"],
                            "type": "object",
                        },
                        "return": {"description": "The color.", "type": "string"},
                    },
                },
            ],
        },
        {"prompt": [_QUESTION], "completion": [_BLUE], "label": True},
        # With a system turn and a `thinking` field on the assistant turns.
        {"messages": [_SYSTEM, _QUESTION, _BLUE_WITH_THINKING]},
        {"prompt": [_SYSTEM, _QUESTION]},
        {"prompt": [_SYSTEM, _QUESTION], "completion": [_BLUE_WITH_THINKING]},
        {"prompt": [_SYSTEM, _QUESTION], "chosen": [_BLUE_WITH_THINKING], "rejected": [_GREEN_WITH_THINKING]},
        {
            "chosen": [_SYSTEM, _QUESTION, _BLUE_WITH_THINKING],
            "rejected": [_SYSTEM, _QUESTION, _GREEN_WITH_THINKING],
        },
        {"prompt": [_SYSTEM, _QUESTION], "completion": [_BLUE_WITH_THINKING], "label": True},
    ]

    # Plain-string layouts that must NOT be reported as conversational.
    non_conversational_examples = [
        {"prompt": "The sky is", "completion": " blue."},
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @pytest.mark.parametrize("example", conversational_examples)
    def test_conversational(self, example):
        """Each message-list example is detected as conversational."""
        detected = is_conversational(example)
        assert detected

    @pytest.mark.parametrize("example", non_conversational_examples)
    def test_non_conversational(self, example):
        """Each plain-string example is rejected."""
        detected = is_conversational(example)
        assert not detected
|
|
|
|
|
class TestIsConversationalFromValue(TrlTestCase):
    """Check `is_conversational_from_value` on from/value, role/content and plain-text examples."""

    def test_positive_1(self):
        """A `conversations` list of from/value turns is detected."""
        from_value_example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ],
        }

        detected = is_conversational_from_value(from_value_example)

        assert detected

    def test_negative_1(self):
        """A `messages` list of role/content turns is not detected."""
        role_content_example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        }

        detected = is_conversational_from_value(role_content_example)

        assert not detected

    def test_negative_2(self):
        """A plain `text` example is not detected."""
        plain_text_example = {"text": "The sky is blue."}

        detected = is_conversational_from_value(plain_text_example)

        assert not detected
|
|
|
|
|
class TestApplyChatTemplate(TrlTestCase):
    """Check `apply_chat_template` / `maybe_apply_chat_template` across a set of tiny test tokenizers."""

    # Hub ids of the tokenizers every parametrized test is run against.
    tokenizers = [
        "trl-internal-testing/tiny-CohereForCausalLM",
        "trl-internal-testing/tiny-Cohere2ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
        "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
        "trl-internal-testing/tiny-FalconMambaForCausalLM",
        "trl-internal-testing/tiny-Gemma2ForCausalLM",
        "trl-internal-testing/tiny-GemmaForCausalLM",
        "trl-internal-testing/tiny-GptOssForCausalLM",
        pytest.param(
            "trl-internal-testing/tiny-Glm4MoeForCausalLM",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="GLM4 tokenizer requires transformers>=5.0.0",
            ),
        ),
        "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
        "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
        "trl-internal-testing/tiny-LlamaForCausalLM-3",
        "trl-internal-testing/tiny-MistralForCausalLM-0.1",
        "trl-internal-testing/tiny-MistralForCausalLM-0.2",
        "trl-internal-testing/tiny-Phi3ForCausalLM-3",
        "trl-internal-testing/tiny-Phi3ForCausalLM-3.5",
        "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
        "trl-internal-testing/tiny-Qwen3ForCausalLM",
        pytest.param(
            "trl-internal-testing/tiny-Qwen3_5ForConditionalGeneration",
            marks=pytest.mark.skipif(
                Version(transformers.__version__) < Version("5.0.0"),
                reason="Qwen3.5 tokenizer requires transformers>=5.0.0",
            ),
        ),
    ]

    # Message-list examples: one per dataset layout exercised by the tests.
    conversational_examples = [
        {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
        },
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
        },
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        },
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "chosen": [{"role": "assistant", "content": "It is blue."}],
            "rejected": [{"role": "assistant", "content": "It is green."}],
        },
        {
            "chosen": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ],
            "rejected": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is green."},
            ],
        },
        {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
            "label": True,
        },
    ]

    # Plain-string counterparts, only fed to `maybe_apply_chat_template`.
    non_conversational_examples = [
        {"text": "The sky is blue."},
        {"prompt": "The sky is"},
        {"prompt": "The sky is", "completion": " blue."},
        {"prompt": "The sky is", "chosen": " blue.", "rejected": " green."},
        {"chosen": "The sky is blue.", "rejected": "The sky is green."},
        {"prompt": "The sky is", "completion": " blue.", "label": True},
    ]

    @staticmethod
    def _assert_templated(example, result):
        """Assert that `result` has the string/bool fields implied by the keys of `example`.

        Shared by the two parametrized tests so both enforce exactly the same contract.
        """
        assert isinstance(result, dict)

        # Every completion-style key of the input must come back as a string.
        for key in ["prompt", "chosen", "rejected", "completion"]:
            if key in example:
                assert key in result
                assert isinstance(result[key], str)

        # A "messages" input is expected under the "text" key of the output.
        if "messages" in example:
            assert "text" in result
            assert isinstance(result["text"], str)

        # The label must be carried over untouched.
        if "label" in example:
            assert "label" in result
            assert isinstance(result["label"], bool)
            assert result["label"] == example["label"]

    @pytest.mark.parametrize("example", conversational_examples)
    @pytest.mark.parametrize("tokenizer_id", tokenizers)
    def test_apply_chat_template(self, tokenizer_id, example):
        """`apply_chat_template` renders every conversational layout for every tokenizer."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self._assert_templated(example, apply_chat_template(example, tokenizer))

    @pytest.mark.parametrize("example", conversational_examples + non_conversational_examples)
    @pytest.mark.parametrize("tokenizer_id", tokenizers)
    def test_maybe_apply_chat_template(self, tokenizer_id, example):
        """`maybe_apply_chat_template` yields the same field types for both kinds of examples."""
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self._assert_templated(example, maybe_apply_chat_template(example, tokenizer))

    def test_apply_chat_template_with_chat_template_kwargs(self):
        """`chat_template_kwargs` stored in the example are reflected in the rendered prompt."""
        tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

        example = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            # Per-example template option; the expected prompt below shows an empty think block.
            "chat_template_kwargs": {"enable_thinking": False},
        }
        result = apply_chat_template(example, tokenizer)

        expected = textwrap.dedent("""\
            <|im_start|>user
            What color is the sky?<|im_end|>
            <|im_start|>assistant
            <think>

            </think>

            """)

        assert result["prompt"] == expected

    def test_apply_chat_template_with_tools(self):
        """The tool name shows up in the prompt only when the tool is passed."""
        processor = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2")

        # NOTE: the docstring below is left exactly as-is; it presumably feeds the tool description.
        def get_current_temperature(location: str):
            """
            Gets the temperature at a given location.

            Args:
                location: The location to get the temperature for
            """
            return 22.0

        example = {
            "prompt": [
                {"content": "What's the temperature in London?", "role": "user"},
            ]
        }

        rendered_with_tools = apply_chat_template(example, processor, tools=[get_current_temperature])
        assert "get_current_temperature" in rendered_with_tools["prompt"]

        rendered_without_tools = apply_chat_template(example, processor, tools=None)
        assert "get_current_temperature" not in rendered_without_tools["prompt"]
|
|
|
|
|
class TestApplyChatTemplateHarmony(TrlTestCase):
    """Check `apply_chat_template` output for the tiny GPT-OSS test tokenizer on each dataset layout.

    Every test renders with `reasoning_effort="low"` and `model_identity="You are HuggingGPT."` and
    compares against the exact expected text. The rendering call and the shared prompt text live
    in private helpers so they are defined once.
    """

    # Expected rendering of the two assistant turns used across the tests.
    _BLUE_COMPLETION = (
        "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|>"
        "<|start|>assistant<|channel|>final<|message|>It is blue.<|return|>"
    )
    _GREEN_COMPLETION = (
        "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|>"
        "<|start|>assistant<|channel|>final<|message|>It is green.<|return|>"
    )

    @staticmethod
    def _prompt_turns():
        """Return fresh system + user turns (new objects on every call)."""
        return [
            {"role": "system", "content": "Respond in a friendly manner."},
            {"role": "user", "content": "What color is the sky?"},
        ]

    @staticmethod
    def _blue_turn():
        """Return a fresh assistant turn answering "blue", with a `thinking` field."""
        return {"role": "assistant", "thinking": "The user asks the color of the sky...", "content": "It is blue."}

    @staticmethod
    def _green_turn():
        """Return a fresh assistant turn answering "green", with a `thinking` field."""
        return {"role": "assistant", "thinking": "The user asks the color of the tree...", "content": "It is green."}

    @staticmethod
    def _render(example):
        """Apply the chat template to `example` with the settings shared by every test."""
        return apply_chat_template(
            example,
            processing_class=AutoTokenizer.from_pretrained("trl-internal-testing/tiny-GptOssForCausalLM"),
            reasoning_effort="low",
            model_identity="You are HuggingGPT.",
        )

    @staticmethod
    def _expected_prompt():
        """Return the expected prompt text, up to and including the opening of the assistant turn.

        The current date is embedded, so call this after rendering and only once per test.
        """
        return textwrap.dedent(f"""\
            <|start|>system<|message|>You are HuggingGPT.
            Knowledge cutoff: 2024-06
            Current date: {strftime("%Y-%m-%d")}

            Reasoning: low

            # Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|><|start|>developer<|message|># Instructions

            Respond in a friendly manner.

            <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""")

    def test_language_modeling(self):
        """A `messages` example renders to a single `text` made of prompt plus completion."""
        example = {"messages": [*self._prompt_turns(), self._blue_turn()]}

        output = self._render(example)

        assert output["text"] == self._expected_prompt() + self._BLUE_COMPLETION

    def test_prompt_only(self):
        """A prompt-only example renders to the prompt ending with the opened assistant turn."""
        example = {"prompt": self._prompt_turns()}

        output = self._render(example)

        assert output["prompt"] == self._expected_prompt()

    def test_prompt_completion(self):
        """Prompt and completion are rendered separately and concatenate seamlessly."""
        example = {"prompt": self._prompt_turns(), "completion": [self._blue_turn()]}

        output = self._render(example)

        assert output["prompt"] == self._expected_prompt()
        assert output["completion"] == self._BLUE_COMPLETION

    def test_preference(self):
        """Chosen and rejected completions are rendered next to the shared prompt."""
        example = {
            "prompt": self._prompt_turns(),
            "chosen": [self._blue_turn()],
            "rejected": [self._green_turn()],
        }

        output = self._render(example)

        assert output["prompt"] == self._expected_prompt()
        assert output["chosen"] == self._BLUE_COMPLETION
        assert output["rejected"] == self._GREEN_COMPLETION

    def test_preference_with_implicit_prompt(self):
        """Without a `prompt` key, chosen and rejected each render as a full conversation."""
        example = {
            "chosen": [*self._prompt_turns(), self._blue_turn()],
            "rejected": [*self._prompt_turns(), self._green_turn()],
        }

        output = self._render(example)

        # Build the dated prefix once so both expectations share the same date.
        prompt_text = self._expected_prompt()
        assert output["chosen"] == prompt_text + self._BLUE_COMPLETION
        assert output["rejected"] == prompt_text + self._GREEN_COMPLETION

    def test_unpaired_preference(self):
        """Prompt/completion render as usual and the truthy label is kept."""
        example = {
            "prompt": self._prompt_turns(),
            "completion": [self._blue_turn()],
            "label": True,
        }

        output = self._render(example)

        assert output["prompt"] == self._expected_prompt()
        assert output["completion"] == self._BLUE_COMPLETION
        assert output["label"]
|
|
|
|
|
class TestUnpairPreferenceDataset(TrlTestCase):
    """Check `unpair_preference_dataset` and `maybe_unpair_preference_dataset` on `Dataset` and `DatasetDict`."""

    # Input: one chosen and one rejected completion per prompt.
    paired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is"],
            "chosen": [" blue.", " in the sky."],
            "rejected": [" green.", " in the sea."],
        }
    )

    # Expected output: all chosen rows (label True) followed by all rejected rows (label False).
    unpaired_dataset = Dataset.from_dict(
        {
            "prompt": ["The sky is", "The sun is", "The sky is", "The sun is"],
            "completion": [" blue.", " in the sky.", " green.", " in the sea."],
            "label": [True, True, False, False],
        }
    )

    def _assert_equals_unpaired(self, actual, failure_message):
        """Compare `actual` with the reference unpaired dataset by content."""
        assert actual.to_dict() == self.unpaired_dataset.to_dict(), failure_message

    def test_unpair_preference_dataset(self):
        """A paired `Dataset` is converted to the unpaired layout."""
        converted = unpair_preference_dataset(self.paired_dataset)
        self._assert_equals_unpaired(converted, "The paired dataset should be converted to unpaired.")

    def test_unpair_preference_dataset_dict(self):
        """Each split of a paired `DatasetDict` is converted to the unpaired layout."""
        splits = DatasetDict({"abc": self.paired_dataset})
        converted = unpair_preference_dataset(splits)
        self._assert_equals_unpaired(converted["abc"], "The paired dataset should be converted to unpaired.")

    def test_maybe_unpair_preference_dataset(self):
        """The `maybe_` variant converts a paired `Dataset`."""
        converted = maybe_unpair_preference_dataset(self.paired_dataset)
        self._assert_equals_unpaired(converted, "The paired dataset should be converted to unpaired.")

    def test_maybe_unpair_preference_dataset_dict(self):
        """The `maybe_` variant converts each split of a paired `DatasetDict`."""
        splits = DatasetDict({"abc": self.paired_dataset})
        converted = maybe_unpair_preference_dataset(splits)
        self._assert_equals_unpaired(converted["abc"], "The paired dataset should be converted to unpaired.")

    def test_maybe_unpair_preference_dataset_already_paired(self):
        """The `maybe_` variant leaves an unpaired `Dataset` as-is (the input here is the unpaired one)."""
        result = maybe_unpair_preference_dataset(self.unpaired_dataset)
        self._assert_equals_unpaired(result, "The unpaired dataset should remain unchanged.")

    def test_maybe_unpair_preference_dataset_dict_already_paired(self):
        """The `maybe_` variant leaves an unpaired `DatasetDict` split as-is."""
        result = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset}))
        self._assert_equals_unpaired(result["abc"], "The unpaired dataset should remain unchanged.")
|
|
|
|
|
class TestExtractPrompt(TrlTestCase):
    """Tests for `extract_prompt` and `maybe_extract_prompt`."""

    # Conversational example where the shared user turn is repeated in both branches.
    example_implicit_prompt_conversational = {
        "chosen": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "user", "content": "What color is the sky?"},
            {"role": "assistant", "content": "It is green."},
        ],
    }

    # Same conversation with the shared user turn factored out under "prompt".
    example_explicit_prompt_conversational = {
        "prompt": [
            {"role": "user", "content": "What color is the sky?"},
        ],
        "chosen": [
            {"role": "assistant", "content": "It is blue."},
        ],
        "rejected": [
            {"role": "assistant", "content": "It is green."},
        ],
    }

    # Plain-text example where both completions start with the same text.
    example_implicit_prompt_standard = {
        "chosen": "The sky is blue.",
        "rejected": "The sky is green.",
    }

    # Same plain-text example with the common leading text factored out under "prompt".
    example_explicit_prompt_standard = {
        "prompt": "The sky is",
        "chosen": " blue.",
        "rejected": " green.",
    }

    def test_extract_prompt_conversational(self):
        """`extract_prompt` factors the shared turns out of a conversational example."""
        extracted = extract_prompt(self.example_implicit_prompt_conversational)
        expected = self.example_explicit_prompt_conversational
        assert extracted == expected, "The prompt is not correctly extracted from the dataset."

    def test_maybe_extract_prompt_conversational(self):
        """`maybe_extract_prompt` extracts the prompt from an implicit conversational example."""
        extracted = maybe_extract_prompt(self.example_implicit_prompt_conversational)
        expected = self.example_explicit_prompt_conversational
        assert extracted == expected, "The prompt is not correctly extracted from the dataset."

    def test_maybe_extract_prompt_conversational_already_explicit(self):
        """`maybe_extract_prompt` returns an explicit conversational example unchanged."""
        extracted = maybe_extract_prompt(self.example_explicit_prompt_conversational)
        expected = self.example_explicit_prompt_conversational
        assert extracted == expected, "The prompt should remain unchanged."

    def test_extract_prompt_standard(self):
        """`extract_prompt` factors the common leading text out of a plain-text example."""
        extracted = extract_prompt(self.example_implicit_prompt_standard)
        expected = self.example_explicit_prompt_standard
        assert extracted == expected, "The prompt is not correctly extracted from the dataset."

    def test_maybe_extract_prompt_standard(self):
        """`maybe_extract_prompt` extracts the prompt from an implicit plain-text example."""
        extracted = maybe_extract_prompt(self.example_implicit_prompt_standard)
        expected = self.example_explicit_prompt_standard
        assert extracted == expected, "The prompt is not correctly extracted from the dataset."

    def test_maybe_extract_prompt_standard_already_explicit(self):
        """`maybe_extract_prompt` returns an explicit plain-text example unchanged."""
        extracted = maybe_extract_prompt(self.example_explicit_prompt_standard)
        expected = self.example_explicit_prompt_standard
        assert extracted == expected, "The prompt should remain unchanged."
|
|
|
|
|
class TestPackDatasetWrapped(TrlTestCase):
    """Tests for `pack_dataset` with `strategy="wrapped"`."""

    def test_with_dataset(self):
        """Packing a map-style `Dataset` yields the expected chunks and keeps its format."""
        source = Dataset.from_dict(
            {
                "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
                "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
            }
        ).with_format("numpy", dtype="float32")
        format_before = source.format

        packed = pack_dataset(source, 3, strategy="wrapped")

        # Expected result: the rows re-cut into chunks of at most 3 tokens.
        assert packed.to_dict() == {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        assert format_before == packed.format

    def test_with_iterable_dataset(self):
        """Packing an `IterableDataset` yields the expected chunks and keeps its formatting."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1, 1], [1]],
        }
        stream = Dataset.from_dict(examples).to_iterable_dataset().with_format("numpy")
        formatting_before = stream._formatting

        packed = pack_dataset(stream, 3, strategy="wrapped")

        # Pull every packed row in a single batch; the batch size is the number of input rows.
        row_count = len(next(iter(examples.values())))
        first_batch = next(iter(packed.with_format(None).batch(batch_size=row_count)))

        assert first_batch == {
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8]],
            "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]],
        }
        assert formatting_before == packed._formatting
|
|
|
|
|
class TestPackDatasetBfd(TrlTestCase):
    """Tests for `pack_dataset` with the `"bfd"` and `"bfd_split"` strategies."""

    def test_with_dataset(self):
        """Packing a map-style `Dataset` adds `seq_lengths` and otherwise keeps its format."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
        }
        dataset = Dataset.from_dict(examples)
        dataset = dataset.with_format("numpy", dtype="float32")
        original_format = dataset.format
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        assert dataset.to_dict() == expected_output

        # The packed format is expected to list the extra "seq_lengths" column.
        packed_format = dataset.format
        assert "seq_lengths" in packed_format["columns"]
        # Compare against a filtered copy of the column list: rebinding the key (rather than
        # calling `.remove()` on the returned list) cannot touch any list the dataset may
        # hold internally, and the comparison uses the adjusted value instead of a fresh
        # `dataset.format` read.
        packed_format["columns"] = [name for name in packed_format["columns"] if name != "seq_lengths"]
        assert original_format == packed_format

    def test_with_iterable_dataset(self):
        """Packing an `IterableDataset` yields the expected bins and keeps its formatting."""
        examples = {
            "input_ids": [[1, 2, 3], [4, 5, 6, 7], [8]],
        }
        dataset = Dataset.from_dict(examples).to_iterable_dataset()
        dataset = dataset.with_format("numpy")
        formatting = dataset._formatting
        seq_length = 4
        expected_output = {
            "input_ids": [[4, 5, 6, 7], [1, 2, 3, 8]],
            "seq_lengths": [[4], [3, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        # Pull every packed row in a single batch; the batch size is the number of input rows.
        num_examples = len(examples[next(iter(examples))])
        assert next(iter(dataset.with_format(None).batch(batch_size=num_examples))) == expected_output
        assert formatting == dataset._formatting

    def test_with_overlong_0(self):
        """With `bfd_split`, the overflow of an over-long row is kept and packed with other rows."""
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 5, 12]],
            "seq_lengths": [[4], [4], [2, 1, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_with_overlong_two_coluns(self):
        """With `bfd_split`, two columns are split and packed in lockstep."""
        examples = {
            "col1": [[1, -2, 3, -4, 5, -6], [7, -8, 9], [-10, 11, -12], [13, -14, 15, -16]],
            "col2": [[-1, 2, -3, 4, -5, 6], [-7, 8, -9], [10, -11, 12], [-13, 14, -15, 16]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "col1": [[1, -2, 3, -4], [13, -14, 15, -16], [7, -8, 9], [-10, 11, -12], [5, -6]],
            "col2": [[-1, 2, -3, 4], [-13, 14, -15, 16], [-7, 8, -9], [10, -11, 12], [-5, 6]],
            "seq_lengths": [[4], [4], [3], [3], [2]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_with_non_power_of_2(self):
        """`bfd_split` packing with a sequence length that is not a power of two."""
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6], [7, 8, 9, 10], [11, 12, 13]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 5
        expected_output = {
            "input_ids": [[1, 2, 3, 4, 5], [7, 8, 9, 10, 6], [11, 12, 13]],
            "seq_lengths": [[5], [4, 1], [3]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output

    def test_default_no_split(self):
        """Test default 'bfd' strategy for SFT datasets (truncates overflow)."""
        examples = {
            "input_ids": [[1, 2, 3, 4, 5], [6, 7], [8, 9, 10, 11], [12]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4

        # The fifth token of the first row is dropped rather than re-packed.
        expected_output = {
            "input_ids": [[1, 2, 3, 4], [8, 9, 10, 11], [6, 7, 12]],
            "seq_lengths": [[4], [4], [2, 1]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd")
        assert dataset.to_dict() == expected_output

    def test_with_empty_sequences(self):
        """With `bfd_split`, empty rows contribute nothing to the packed output."""
        examples = {
            "input_ids": [[1, 2], [], [3, 4, 5], [], [6]],
        }
        dataset = Dataset.from_dict(examples)
        seq_length = 4
        expected_output = {
            "input_ids": [[3, 4, 5, 6], [1, 2]],
            "seq_lengths": [[3, 1], [2]],
        }
        dataset = pack_dataset(dataset, seq_length, strategy="bfd_split")
        assert dataset.to_dict() == expected_output
|
|
|
|
|
class TestMaybeConvertToChatML(TrlTestCase):
    """Tests for `maybe_convert_to_chatml`."""

    def test_with_conversations_key(self):
        """A "conversations" entry with from/value messages becomes "messages" with role/content."""
        example = {
            "conversations": [
                {"from": "user", "value": "What color is the sky?"},
                {"from": "assistant", "value": "It is blue."},
            ]
        }
        expected_output = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        assert maybe_convert_to_chatml(example) == expected_output

    def test_without_conversations_key(self):
        """from/value messages under other keys are converted while the keys are kept."""
        example = {
            "prompt": [{"from": "user", "value": "What color is the sky?"}],
            "completion": [{"from": "assistant", "value": "It is blue."}],
        }
        expected_output = {
            "prompt": [{"role": "user", "content": "What color is the sky?"}],
            "completion": [{"role": "assistant", "content": "It is blue."}],
        }
        assert maybe_convert_to_chatml(example) == expected_output

    def test_not_conversional(self):
        """A non-conversational example is returned unchanged."""
        example = {"text": "The sky is blue."}
        # Snapshot the input first: if the function mutates and returns its argument,
        # comparing the result with `example` itself could never fail.
        expected_output = copy.deepcopy(example)
        assert maybe_convert_to_chatml(example) == expected_output

    def test_already_chatml(self):
        """An example that already uses role/content messages is returned unchanged."""
        example = {
            "messages": [
                {"role": "user", "content": "What color is the sky?"},
                {"role": "assistant", "content": "It is blue."},
            ]
        }
        # Snapshot the input first: if the function mutates and returns its argument,
        # comparing the result with `example` itself could never fail.
        expected_output = copy.deepcopy(example)
        assert maybe_convert_to_chatml(example) == expected_output
|
|
|